TowardsDataScience-2023-博客中文翻译-三十九-

TowardsDataScience 2023 博客中文翻译(三十九)

原文:TowardsDataScience

协议:CC BY-NC-SA 4.0

给你的 AB 测试实验调味

原文:towardsdatascience.com/seasoning-your-ab-testing-experiments-e585ab2ef2d2

盐如何帮助你的实验?

Mark EltsefonTowards Data Science Mark Eltsefon

·发表于Towards Data Science ·阅读时间 5 分钟·2023 年 3 月 13 日

--

照片由Manuel Asturias提供,来源于Unsplash

AB 测试是衡量新特性实施效果的最著名方法之一。其主要思想是将你的流量(或仅部分流量)随机分成两个或更多组。然而,确保分组真正随机且无偏见是非常重要的,这样才能对结果有信心。这就是盐的作用所在。盐的目的是消除可能影响 AB 测试结果的偏差或可预测性来源。如果没有盐,可能会有一些用户因为其特征、行为或其他因素而更可能被分配到网页或应用程序的特定变体,从而导致结果偏差。

哈希和分组。

让我们深入探讨如何将用户分成不同的实验和组。为了更清楚地说明,最好使用一个示例。我们可以将整个流量分成 10 个桶。

10 个原始桶。图片由作者提供

在实际将用户分配到桶之前,让我们回顾一下哈希函数是什么。

哈希函数用于将输入数据转换为固定大小的唯一值,该值可以表示原始数据。这有什么好处?它可以将用户的一个独特方面,通常是用户 ID,转换为一个数字,我们可以利用这个数字将他们分配到其中一个桶中。

user_id = "a4234jf3345" 

# Get the hash value of the user
hash_value = hash(user_id)

# Convert the hash value to a positive number
number = abs(hash_value)

# Print the number
print(number)

# Print the bucket that we assigned our user to 
print(number%10)

如果哈希值除以 10 的余数等于 1,则用户会被分配到桶 1,依此类推。

为什么我们需要盐?

首先,让我们讨论一下领域的概念。领域指的是你自己定义的产品的具体方面,例如购物篮、登录页面或结账页面。领域最重要的特点是,你可以自信地在不同领域之间重叠你的实验。这意味着在一个领域进行的任何实验(或大部分实验)不应对另一个领域产生任何变化。

使用领域盐帮助我们克服了两个挑战:

  1. 如果没有领域盐,用户在所有实验中都会进入相同的桶,导致实验结果的偏差。

  2. 我们受到流量的限制,但通过添加领域盐,我们可以根据领域的数量扩展我们可以进行的实验数量。对于每个领域,我们使用唯一的盐创建自己的桶。

在 Python 中,代码大致如下:

user_id = "a3d45f6g6j7"

# Get the hash value of the user with the domain salt
hash_value = hash(user_id + 'Basket')

# Convert the hash value to a positive number
number = abs(hash_value)

# Print the number
print(number)

用户分配

为了将用户分配到实验中,特定的桶被分配给每个实验。在给定的例子中,实验 #1 被分配了桶 1 和 2,其中 A 代表对照组,B 代表处理组。我们假设我们在一个领域内进行实验。

图片由作者提供

在频繁实验过程中,大多数桶通常分配给正在进行的实验。

红色 — #1 实验,绿色 — #2 实验,黄色 — #3 实验。图片由作者提供

然而,当我们的实验(#1 红色)结束后,会释放出一些可以分配给新实验的桶,这会发生什么呢?

假设现在 20% 的桶是空闲的(因为第一个实验已经结束),我们想要启动实验 #4。

图片由作者提供

我们是否可以简单地使用哈希和之前的领域盐?

答案是否定的,原因是延续效应

延续效应发生在用户接受某种处理后,可能会影响他们在之后的行为。实质上,人们往往会记住过去的经历,这可能导致未来实验结果的偏差。为了解决这个问题,我们引入了一种新盐——洗牌盐,它对每个实验都是独特的。

现在,我们的桶分配代码变为如下。

user_id = "a3d45f6g6j7"

# Get the hash value of the user with the domain and shuffle salts
hash_value = hash(user_id + 'Basket' + 'Experiment_with_basket')

# Convert the hash value to a positive number
number = abs(hash_value)

# Print the number of the bucket
print(number % 10)

洗牌盐是如何工作的?

我们将前两个桶细分为 10 个各自的桶,并将用户分配到这些子桶中,以提供更高的实验细节。

细分前两个桶。图片由作者提供

图片中的图像使用紫色和蓝色分别区分实验的对照组和处理组。

现在,它如何融入更大的图景。

实验 #4 的启动。图片由作者提供

在我们需要少于 20%流量的情况下,我们将失去一定比例的用户。例如,如果我们需要 15%的整体流量,我们必须获得 2 桶,因为我们不能得到 1.5 桶。因此,我们最终会损失 5%的流量。

不幸的是,没有办法完全消除这种损失。

实际上,大多数公司使用这两种盐。然而,一些公司也可能使用额外的盐,例如专用盐(用于测试重大变化)或更细粒度的领域盐(用于测试特定国家的变化)。

结论

从头开始构建一个 AB 测试系统并不是最简单的任务,因为有许多细微之处人们往往忽略。然而,你不应害怕尝试,并且不要忘记使用盐。

愿偏见与你无关!

成功细分的秘密

原文:towardsdatascience.com/secret-sauce-of-successful-segmentation-74e0c48d84ba

细分不仅仅是你工具箱中的另一项工具,而是你会经常使用的工具,以便了解你的客户。

Deepak Chopra | Talking Data ScienceTowards Data Science Deepak Chopra | Talking Data Science

·发布于Towards Data Science ·8 分钟阅读·2023 年 1 月 17 日

--

摄影作品来自Kaitlyn Chow,发布于Unsplash

大数据分析赋予了组织基于客户的思维、情感和行为提供超个性化服务的能力。随着一对一定位的承诺的出现,细分看起来可能显得粗糙和过时。然而,当正确使用时,细分仍然代表着一种巨大的竞争优势潜力,是倾听客户、了解他们的需求和愿望的一种手段。

分割让你的客户拥有强大的发言权,在设定战略优先事项和做出战术决策时,这种声音可以在 KPIs、统计数据和趋势的喧嚣中被淹没。

接下来,我将讨论成功细分的要素。但在此之前,让我们快速了解一下什么是细分。

请注意,我所指的是客户细分,因为这是细分的最常见用例,但相同的原则适用于任何实体的细分。

细分…什么和为什么

在今天这个竞争激烈的世界中,了解客户的需求和愿望是提供一种将来仍然受到青睐的服务(或产品)的关键,从而保持你的组织运转。

对于所有类型的组织,其客户基础通常由若干个个人或个人组(例如:其他组织/集团等)组成——这个数量可以从几百个(对于企业对企业服务组织)到几百万个(对于零售商/消费品公司),有时甚至到数十亿个(对于社交媒体平台)。

从实际角度来看,了解客户基础的每一个方面一对一是不可行的(……因为其规模巨大),也不是一个好的策略将其视为一个统一的整体(……因为实际上存在需要理解的差异)。

这导致我们需要一个折衷方案,即在一对一和一刀切方法之间的中间地带;即将客户基础分组为可管理的客户集——在每个集中的客户在关注的特征上较为相似,而不同集中的客户在同一关注特征上足够不同。

→ 细分市场 是将大量实体‘分组’或‘细分’成较少数量的实体;前提是一个细分市场内的实体是同质的 (彼此相似),而不同细分市场之间的实体是异质的 (彼此不同)

细分市场……成功的 5 大法则!

1. 实际需求——至少一个使用案例。

成功实施客户细分的第一个关键是清楚其用途是什么,并确保它适合这个任务。特别是在提出方案时,一个关键步骤是明确不同细分市场中的客户所重视的内容,这在我们将注意力集中在单一的、互不重叠的、完全包容的客户分类细分市场上时会更容易。

记住,通过创建一个细分市场,你是在将新的文献引入到整个组织中。

必须有一个真正的“需求”;即一个准备好“使用案例”,企业将在你开发细分后准备好行动。

组织必须避免“细分的死亡”;因此,首先请准备好一个使用案例。

每当你考虑构建一个“新”的细分市场以理解XYZ时,总是建议评估现有细分市场在多大程度上已经解释了XYZ的行为。

→ 如果现有细分市场在可接受的程度上解释了这些,那么你不需要构建“新”的细分市场,而是利用现有的细分市场来完成任务。

2. 建立层级细分市场。

将每个客户描述为他们所属的细分市场的优势在于,这通常比将每个人描述为与平均客户具有相同特征更准确、更连贯。

成本是为每个细分市场制定并执行不同策略所需的时间和精力。

业务的不同部分可能有或没有时间规划和执行策略,或创建自己的细分。我们也可能不希望他们创建自己的细分 (不要死于细分!!)

我们可以通过提供对细分的额外深度来满足不同业务团队的更广泛需求,通常通过将较大的细分拆分为较小的子细分来实现,尽可能做到合理。

→ 优势在于,不同团队仍然可以‘使用相同的语言’,而更精细的策略结果可以汇总到大多数业务关注的顶层细分目标上

例如,在下图中,知道哪些客户是高价值客户与低价值客户(第#1 层次细分)确实很重要。然而,在‘高价值’群体中,进一步区分‘低频率、高价值’客户与‘高频率、高价值’客户(第#2 层次细分)也非常有价值。对于进一步细分‘低价值’群体也是如此。

figure 1: 一个简单的层次细分视图(图片来源于作者)

3. 对客户基础的态度建模。

客户行为数据是通过技术手段在客户与产品和服务互动时捕获的数据。每一个细节,如点击、页面浏览、页面停留时间、加入/移除购物车、使用折扣、购买等都被记录下来。虽然行为数据本身非常丰富,但它无法捕捉客户的“感受”,如愿望和挫折感。

这种信息通常是通过对客户进行问卷调查收集的数据来获取的,这不仅可以深入了解客户对当前环境的感受,还可以了解他们希望改进的地方——这些信息可以共同转化为组织的战略细分。

尽管调查数据可以提供最丰富和全面的客户需求和愿望信息,但收集大量和详细数据可能昂贵或完全不可行

所以,挑战来了——我们是否应该以我们拥有回应的小样本客户为基础来获取洞察?

答案是……不!

我们可以将客户行为数据与调查数据结合起来,用于我们拥有这两部分信息的小样本客户。我们可以使用机器学习技术基于行为产品/服务使用详细数据建模调查基础的细分。 一旦准确度达到合理水平,基于调查的细分可以推广到整个客户基础。

— 我们不需要仅仅依赖小样本;而是要基于从整个客户群体中获得的洞察来制定策略。

4. 为你的利益相关者实现可视化。

成功的细分会嵌入业务中,并且被关键决策者很好地理解。他们需要理解它,认可它,并在正确的用例中依赖它。 在总部办公室内,几乎没有与客户的互动,良好细分的一个关键好处是通过图像、视频、引用等故事帮助决策者更接近客户,而不是仅仅用数据。

细分是一种新语言,是你向业务利益相关者介绍的新文献。

客户细分能够影响企业做出更好决策的最强大方式之一是帮助讲述一个关于客户的难忘故事。通过“去平均化”客户,细分的集体平均行为开始感觉像是一个有自身特点的连贯个体的行为。

  • 细分的力量: 用简单的术语向你的利益相关者解释你的细分如何区分客户,以及整体细分预测的客户行为。

  • 建立细分档案:提供一个额外的层次,讲述一个单一战略客户细分的详细且引人入胜的故事。通过强调他们是谁,他们做什么以及如何/为什么去做,来使你的细分生动起来。

为了说明‘新’文献,利用已经嵌入组织中的文献,使其具有相关性并且易于理解。将你的‘新’细分与已知的客户属性、产品行为、人口统计或其他现有细分进行叠加。

上述做法将有助于为你的利益相关者描绘‘细分’的图景;它将帮助他们通过访问产品和服务“可视化”这些细分。

5. 追踪它。

有意义的细分必须导致基于细分需求和欲望的策略。设定细分级别的目标是一个吸引人的方式,可以通过确保所有行动都以最初的客户目标为导向来推动客户变化。

当然,这些目标依赖于细分的构建方式,使其在时间上保持可比性。

为了确保可追踪性,你必须确保细分的可刷新性。也就是说,

  • 细分必须在一个合理的、预先决定的频率下进行刷新,以便定期考虑新客户或流失客户。

  • 为确保公平比较,将客户分配到细分的规则必须保持不变或尽可能相似。

  • 定义客户细分的数据应对广泛的人群具有可识别性和意义,以避免基于个人解释的任何偏见 (即客户应该能够准确地‘自我识别’为特定细分)

结论!

数据科学是“艺术”和“科学”的结合,这对于创建有意义且可操作的“细分”也同样适用。

构建成功的细分的关键可以总结为以下五条准则。

  1. 提前创建你的使用案例,并构建“适合”它的细分。

  2. 构建一个层级化的细分,以满足更多的使用案例。

  3. 利用数据科学算法对客户群体的态度进行建模

  4. 通过将其带入实际操作中,将其嵌入到你的利益相关者之间

  5. 不断更新并跟踪细分,以监控进展。

连接、学习与成长..

如果你喜欢这篇文章并对类似内容感兴趣,可以在MediumLinkedIn与我一对一联系加入我的邮件列表上关注我,并且(如果你还没有的话),成为Medium 家族的成员以获取数千篇有用的文章。(如果你使用上面的链接,我将获得你会费的~50%)

.. 继续学习,继续成长!

揭示对数损失的秘密

原文:towardsdatascience.com/secrets-of-log-loss-84c668f4024a

数学、理论和直观理解,专为机器学习工程师准备

Joseph Robinson, Ph.D.Towards Data Science Joseph Robinson, Ph.D.

· 发表在 Towards Data Science · 12 分钟阅读 · 2023 年 11 月 23 日

--

让我们深入探讨对数损失,并揭开这个关键机器学习目标的神秘面纱:它的数学严谨性、理论基础和直观方面。这个博客将提供深入见解,以更有效地优化你的模型,并理解对数损失在现实应用中的意义!

对数损失曲线:展示了预测概率与真实标签偏离时惩罚的增加。曲线越陡峭,错误的代价越高。图表由作者生成。

目录

· 介绍

· 对数损失的基础

· 对数损失背后的数学

· 支撑对数损失的理论

· 对数损失的直观理解

· 机器学习的实际应用

· 优化模型

· 常见陷阱及其避免方法

· 结论

介绍

神秘的对数损失既引人入胜又至关重要。它处于机器学习的核心,沐浴在数学的优雅之中。此外,对数损失是概率分类器的核心;它以更强大、更准确的模型为承诺吸引我们。

但让我们不要在惊叹和惊奇中耽搁。我们还有工作要做!

为什么作为机器学习工程师的你应该深入研究对数损失这一数学和概念的漩涡?很简单。对数损失是一把瑞士军刀。 更深刻的理解使你能够超越简单的准确度,细致审视分类器性能的细微差别。因此,对数损失不仅仅是一个数字——它是你机器学习模型稳健性的试金石,让你可以以其他指标难以达到的细腻程度进行微调和优化。

我们在这个博客中的目标是深入理解对数损失的复杂层次。我们的行程包括数学推导的严谨性,解开深藏其间的理论基础,探索直观,找出抽象中的可关联性。我们将探讨对数损失的基础,分解其数学成分,并揭示它与信息理论的关系。通过现实世界的应用和案例研究,强调这个指标的实际力量。我们将讨论一些陷阱——那些容易让你绊倒的细微差别——以及如何优雅地避免它们。最后,我们将使用可视化来更好地理解这一数学构造的理论和实践。

准备好了吗?让我们深入探讨。

对数损失基础

让我们深入探讨:什么是对数损失,何时它是你机器学习冒险中的骑士?

定义与公式

对数损失,正式称为逻辑损失或对数损失,是一个用于分类模型的性能指标,模型输出概率:数值越小越好,完美模型的对数损失为零。它在需要概率结果而非硬分类的场景中很受欢迎。对数损失量化了你的预测与实际结果之间的偏差,比单纯的准确率更具说明性。

在数学上,二分类器的对数损失通常表示为:

在这里,N 是样本的数量,y_i 是第 i^{th} 样本的真实标签,p_i 是第 i^{th} 样本被预测为类别 1 的概率。

这是一个简单的 Python 代码片段,使用 NumPy 计算对数损失:

import numpy as np
​
def calculate_log_loss(y_true, y_pred):
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
​
# Example usage
y_true = np.array([0, 1, 1, 0])
y_pred = np.array([0.1, 0.9, 0.8, 0.2])
log_loss = calculate_log_loss(y_true, y_pred)
print(f"Log Loss: {log_loss}")

输出:

Log Loss: 0.164252033486018

何时使用对数损失

对数损失在需要细致解读结果的问题中特别有用。比如医学诊断,不仅仅是标签(即生病或健康),而是生病的概率,这具有巨大的重要性。或者推荐系统:你不仅仅是对喜欢或不喜欢进行分类,还要决定用户是否可能点击推荐的项目。

对数损失与其他指标不同,后者往往笼统却忽略了细微差别(例如准确率)。准确率例如统计正确预测的事件数量,但并未说明置信度。还有 F1 分数,它是精度和召回率的调和均值,适用于不平衡数据集。然而,它也无法窥见分类的分级置信度。

另一方面,对数损失惩罚错误的同时也惩罚模型对错误的置信度。它对概率的敏感性使它成为在需要更为精细评估的情况下的宝贵指标。

对数损失景观:该 3D 图揭示了对数损失函数的复杂性。观察当预测偏离真实标签时,对数损失如何显著增加,直观地强调了自信错误预测的更陡罚款。同时请注意,对数损失对假阳性和假阴性有不同的惩罚。图表由作者生成。

因此,对数损失不仅仅是一个指标;它是衡量模型表现的叙事工具,尤其在风险高、细节丰富时效果显著。

对数损失背后的数学

数学是宇宙的语言,也是对数损失之谜存在的基础。

公式的推导

在统计学领域,我们通常从似然性这一概念开始,它是衡量模型如何解释观察数据的指标。对于二分类问题,似然性L可以表达如下。

这个公式直观但计算起来繁琐。概率的乘积可能会导致下溢,特别是当N很大时。

似然面可视化:该 3D 图捕捉了成功次数与估计成功概率之间的关系。表面上的峰值表示较高的似然性,强调了在不同概率估计和观察到的成功中,我们对预测成功的信心如何变化。图表由作者生成。

引入自然对数;我们使用自然对数将似然转化为和,我们称之为对数似然:

这里是对数损失公式的核心:该对数似然的负平均:

在这里,我们使用负号将最大化问题(即最大化对数似然)转化为最小化问题——这是优化中更为熟悉的领域——这一概念在下图中得以阐释。

对数在对数损失中的表现:一个双轴图展示了对数和负对数的行为,这两个是对数损失中的关键组成部分。图表由作者生成。

数学深度探讨

那么,为何使用对数? 你可能会问。对数不仅仅是数学上的方便工具:对数函数是单调的,保持了概率之间的顺序。而且,它放大了错误但自信预测的惩罚。例如,如果你预测的概率为 0.01,而真实标签为 1,那么对数将增加你的损失,促使你重新思考你错误的自信。

对数损失对异常值敏感且具有抗干扰性,这是一个让人着迷的悖论。预测一个接近 0 或 1 的极端概率;如果你错了,对数损失会变得惩罚性十足,不留情面。另一方面,它比其他度量标准(例如,均方误差,对极端值给予不成比例的权重)更不容易受到异常值的影响。

这是一个快速的 Python 代码片段,演示了异常值的影响:

# With an outlier
y_true_with_outlier = np.array([0, 1, 1, 0, 1])
y_pred_with_outlier = np.array([0.1, 0.9, 0.8, 0.2, 0.99])  # The 0.99 is the outlier
log_loss_with_outlier = calculate_log_loss(y_true_with_outlier, y_pred_with_outlier)
​
# Without the outlier
log_loss_without_outlier = calculate_log_loss(y_true, y_pred)
​
print(f"Log Loss with outlier: {log_loss_with_outlier}")
print(f"Log Loss without outlier: {log_loss_without_outlier}")

输出:

Log Loss with outlier: 0.1334116939595147
Log Loss without outlier: 0.164252033486018

注意,带有异常值的对数损失与没有异常值的情况相比,并没有显著偏离,展示了它对极端值的相对抗干扰性。

对数损失的数学不仅仅是一系列抽象符号;它是一种叙事。它讲述了可能性和对数,平衡了信心和惩罚。

对数损失的理论基础

现在,让我们超越数学,深入理论领域。与方程式搏斗是一回事,理解其知识基础并问道:“这个度量标准为什么存在?”这又是另一回事。你准备好进入理论的深度探讨了吗?

概率基础

从本质上讲,对数损失与信息理论紧密相连——一个量化信息的领域。信息理论说,“告诉我一些我不知道的事情,你就给了我信息。”对数损失是惊讶的度量(即,不确定性)。你的模型的预测越是偏离实际结果,人们会越“惊讶”(即传达的信息越多)。

因此,熵的概念量化了信息内容。对于一个概率为 p 的单一事件,熵 H 为:

更进一步,让我们看看交叉熵,它衡量真实分布 y 和预测分布 p 之间的距离。对于二分类,交叉熵为:

这张图展示了熵和交叉熵值如何随着真实概率 (p) 的变化而变化。熵随着概率接近 0 或 1 而减少,表示事件结果的确定性增加。交叉熵显示为一条稳定的线,强调它在测量两个概率分布之间差异中的作用。图表由作者生成。

于是,我们发现自己又回到了熟悉的对数损失,这是所有实例上交叉熵的平均值。

熵与对数损失的分布比较:直方图突出显示了这两种指标的频率分布,并叠加了表示其期望分布的曲线。熵显示了向高值集中,而对数损失则展示了更分散的分布,峰值在低值附近,体现了它们的固有差异以及数据在每种指标下的表现。图由作者生成。

在 Python 中,可以使用以下方式计算交叉熵:

def calculate_cross_entropy(y_true, y_pred):
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
​
# Example usage
cross_entropy = calculate_cross_entropy(y_true, y_pred)
print(f"Cross Entropy: {cross_entropy}")

输出:

Cross Entropy: 0.164252033486018

哲学方面

使用对数损失并非没有假设。它假设你的预测是概率,在 0 和 1 之间,并且标签是真正的二元。如果偏离这些条件,指标可能会误导你,产生难以解读的数字。

从本质上讲,对数损失衡量的是模型预测中的不确定性。它体现的哲学思想是校准概率。在理想的世界中,一个 90%信心的预测应该在 90%的时间内是正确的。因此,对数损失保持你的模型诚实,惩罚过于自信的错误答案和过于不自信的正确答案。

通过了解对数损失敏感的区域及曲线的性质,机器学习工程师可以更好地理解模型在实际环境中的表现。主要曲线(蓝色实线):显示对数损失如何随正类的真实概率变化。完美校准曲线(红色虚线):一个假设的线条,表示完美校准模型的对数损失应有的样子。黄色阴影区域:突出显示对数损失函数对真实概率变化的敏感性。在这个区域内的微小变化会导致对数损失的显著波动。注释和文本:提供了额外的见解,指出曲线上的特定点,使理解对数损失行为更加容易。图由作者生成。

对数损失的直观理解

是时候摆脱形式主义的束缚,以全新的视角探索对数损失的领域了。

类比与现实世界的例子

将对数损失视为谎言的代价。假设你在赛马中下注。你对某一结果赋予的概率越高,你如果判断错误损失的就越大。那如果是一个轻微的谎言,你本来就不太确定呢?相反,损失较少。如果是一个极大的谎言,而你非常有信心,那么损失就更大了!

回到推荐系统的应用中,使用对数损失作为用户的烦恼程度。以高度确定性推荐一部鲜有人观看的电影,你会让用户感到烦恼;若推荐正确,你则是英雄。在医疗保健中,考虑一下诊断测试。以低概率预测患者患有某种疾病,可能会导致严重后果。

这里是一个模仿基本医疗诊断模型的 Python 示例:

# Assume '1' means the patient has the disease, and '0' means they don't.
y_true = np.array([0, 1, 1, 0, 1])
y_pred = np.array([0.1, 0.9, 0.8, 0.2, 0.7])  # Predicted probabilities from the model
​
log_loss_healthcare = calculate_log_loss(y_true, y_pred)
print(f"Healthcare Log-Loss: {log_loss_healthcare}")

输出:

Healthcare Log-Loss: 0.20273661557656092

视觉化对数损失

可视化对数损失的最有效方法之一是通过绘制真实标签与预测概率的图表。当你偏离理想的对角线(预测概率与真实标签匹配的地方)时,你的对数损失会增加,直观地显示出模型的缺陷。在这个图表上,完美的模型将是一条从左下角到右上角的直对角线。

剖析对数损失: 虚线完美地匹配真实标签和预测概率。蓝点突出了实际预测与这一理想之间的偏差,说明了对数损失的概念,其中离线距离越大表示预测误差越高。图表由作者生成。

机器学习的实际影响

现在是时候将理论付诸实践了。我们已经穿越了数学领域,涉足了直观的湖泊。那么实际的土壤如何?对数损失的理解如何为机器学习项目的花园施肥?

优化模型

如何通过最小化对数损失来获得更强健的模型

让我们切入正题。最小化对数损失是一个模型校准的过程。把它看作是调音:你越接近完美音符(即真实标签),你的表现就越好。当你最小化对数损失时,你是在告诉你的模型对真实结果“惊讶”得更少,从而做出更准确、可靠的预测。

优化的技术和策略

现在,你如何给这个植物浇水?有很多方法:梯度下降法、超参数调整、集成技术等。一种广泛使用的方法是结合交叉验证的网格搜索。

from sklearn.model_selection import GridSearchCV
from sklearn.linear_model import LogisticRegression
​
params = {'C': [0.1, 1, 10], 'solver': ['liblinear', 'lbfgs']}
grid_search = GridSearchCV(LogisticRegression(), params, scoring='neg_log_loss', cv=5)
grid_search.fit(X_train, y_train)
​
print(f"Best Parameters: {grid_search.best_params_}")

通过微调这些拨轮,你可以优化对数损失,从而创建一个更强健的模型。

案例研究

对数损失优化的一个突破性应用可以在医疗行业找到,特别是在早期癌症检测中。通过降低对数损失,模型能更好地发现癌细胞,这对于早期开始治疗至关重要。

另一个案例来自金融领域,信用评分模型已经通过使用对数损失作为性能指标进行了微调。结果?更准确的风险评估和更聪明的借贷决策。

所以,我们现在站在理论与实践结合的沃土上。对数损失不仅仅是一个数学抽象或知识辩论的主题;它是一个强大、可操作的杠杆,可以将机器学习项目的轨迹从普通转向非凡。

常见陷阱及如何避免它们

数值稳定性

数值稳定性——或缺乏稳定性——是许多人跌入的陷阱。在计算概率的对数时,将那个数字不断推向零会导致数值不稳定,从而在计算中引发混乱。

为了缓解这个问题,通常会对预测的概率应用一个小的 epsilon ϵ:

这个图表全面展示了 epsilon 值如何影响对数损失计算的稳定性,这一点在处理接近 0 或 1 的概率时尤为重要。图表由作者生成。

你可以这样修改一个 Python 对数损失函数:

def calculate_stable_log_loss(y_true, y_pred, epsilon=1e-15):
    y_pred = np.clip(y_pred, epsilon, 1 - epsilon)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
​
# Example usage
log_loss_stable = calculate_stable_log_loss(y_true, y_pred)
print(f"Stable Log-Loss: {log_loss_stable}")

输出:

Stable Log-Loss: 0.20273661557656092

什么时候不使用对数损失

现在,即使是最锋利的刀子也不适用于每个任务。对数损失也是如此。在类别不平衡严重或假阳性和假阴性的成本差异很大的分类问题中,像 F1-score、精度或召回率这样的指标可能更适合。

另一个有价值的角度是对比平衡数据集与不平衡数据集中的对数损失变化。这是一个关键洞察,尤其是对处理常见类别不平衡的实际问题的机器学习工程师来说。在这些直方图中: 平衡数据集的对数损失分布:第一个图展示了类别平衡时对数损失值的分布。值倾向于分布在更广的范围内,反映出模型的不同确定性水平。不平衡数据集的对数损失分布:第二个图展示了不平衡数据集中对数损失值的分布。注意范围通常较窄,反映出由于不平衡,模型可能对其预测过于自信。理解平衡和不平衡数据集中对数损失的细微差别可以帮助机器学习工程师更有效地调整模型评估和调整策略。图表由作者生成。

另一种情况?对于具有两个以上标签的多类别问题,虽然可以将对数损失推广到这些情境中,但通常需要更加直接和易于解释。在这些情况下,像分类交叉熵或简单准确率这样的指标可能更有效。

结论

在这里,我们站在对数损失复杂领域智力之旅的顶峰。掌握了理论和实践智慧后,我们可以重温我们所开辟的路径。让我们回顾一下。

理解对数损失就像掌握调节复杂仪器的艺术。它让你具备了调整概率模型的能力,生成可靠且易于解释的预测。在数据时代,模型影响着从医疗保健到金融的方方面面,这种掌握不仅是美好之事,而是必需的。

知识在应用时最为有益。是时候卷起袖子,把手深入你项目的肥沃土壤中了。调整超参数,尝试不同的优化技术,不要畏惧采取计算过的风险。在实验的熔炉中,理论的金属被锻造成应用的利剑。

当我们结束这个博客时,我相信对更深层次理解的追求不会止步于此。如果知识的追求不是一段无尽的旅程,那它是什么呢?你是否将带着这份新获得的智慧,勇敢迈向你的项目?地平线在召唤。

联系

想要联系?关注罗宾逊博士的LinkedInTwitterFacebookInstagram。访问我的主页,获取论文、博客、邮件注册及更多内容!

[## 主页 | 乔·罗宾逊的网站 | 研究工程师 | 企业家

罗宾逊博士在计算机视觉、模式识别、MM 和多模态方面有超过 35 篇论文。曾在各个行业工作过……

www.jrobs-vision.com](https://www.jrobs-vision.com/?source=post_page-----84c668f4024a--------------------------------)

使用扩展的 Databricks MLFlow 保障 MLOps 的安全

原文:towardsdatascience.com/secure-mlops-with-extended-databricks-mlflow-ee9b7310c5b3?source=collection_archive---------31-----------------------#2023-01-10

安全管理模型目标环境

Luuk van der VeldenTowards Data Science Luuk van der Velden

·

关注 发表在 Towards Data Science ·6 分钟阅读·2023 年 1 月 10 日

--

MLFlow 是一个开源的 Databricks 产品,支持机器学习模型生命周期的部分功能。它的模型注册表允许将模型版本以模型名称注册,以便于模型部署。我们希望使用一个 MLFlow 实例和一个 Databricks 工作区来支持多个部署目标(接受、预发布、生产等),同时为生产模型提供安全保障。我们扩展了 MLFlow 客户端,以安全的方式在一个模型注册表中管理多个环境。该客户端与 Databricks 权限协同工作,我们将展示 Azure terraform 代码片段。

github 仓库:环境 MLFlow 客户端存储库

Aleksa Kalajdzic 拍摄的照片

1 为什么选择一个 Databricks 工作区?

安全的 Databricks 工作区部署在虚拟网络中,这要求控制平面和计算平面分别使用不同的子网。计算平面子网的 IP 范围限制了可以同时使用的并行计算节点的最大数量。由于公司范围内的 IP 范围大小限制,我们选择为一个 Databricks 工作区使用一个大的子网,而不是多个较小的子网绑定到多个 Databricks 工作区。因此,我们希望在一个 Databricks 工作区中管理多个逻辑环境(验收、预生产和生产)。

1.2 使用 Databricks 池进行扩展

Databricks 池允许在集群创建过程中重用闲置实例(虚拟机)。池提供了适中的启动和自动缩放速度的好处。池中的每个实例需要计算平面子网中的一个 IP。如果达到最大限制,下一个实例创建请求将失败,且没有优雅的错误处理或回退机制。Databricks 池的这个限制使得拥有一个大的子网对于你的活动至关重要,因为你随时可能达到最大限制,从而导致应用程序崩溃。

2 基于身份的安全

我们的安全策略基于身份,主要是 Azure 和 Databricks 用户的应用程序注册,背后由 Azure Active Directory (AAD) 用户支持。例如,我们在生产数据的相同 Databricks 工作区中运行验收测试,使用具有只读权限的独立身份,将生产数据复制到临时验收测试存储帐户。身份在 Azure 上非常方便,因为 Databricks 支持凭证透传,利用我们在身份上设置的所有 AAD 角色和组。我们还依赖于 Databricks 权限和组,因为我们在一个 Databricks 工作区中使用多个身份。

2.1 Databricks 组

我们的单个 Databricks 工作区中的每个逻辑环境都有一个名为 f”apps_{env_name}”(例如 apps_acc, apps_prod)的 Databricks 组,用于其应用程序主体,以及一个包含所有活动应用程序主体的组“apps_all”。这些组及其权限由 terraform 管理。

Databricks 组 terraform 摘录

3 环境 MLFlow 客户端

3.1 MLFlow 实验跟踪

MLFlow 实验存储被调整以支持多个逻辑环境,通过按环境管理存储位置。我们的解决方案分配了如“/experiments/acc”和“experiments/prod”的目录来存储实验数据。使用 Databricks 权限管理来授予相应 Databricks 组(在本例中为“apps_acc”和“apps_prod”)目录权限。这允许对每个逻辑环境的实验和模型进行安全日志记录,无需用户考虑。

MLFlow 实验 Databricks 权限 terraform 摘录

3.2 MLFlow 模型注册表

MLFlow 模型注册表是一个中央位置,用于注册模型和模型版本以供 ML 系统使用。在不同环境中的数据科学实验中记录和注册的模型最终会进入同一个模型注册表。MLFlow 模型注册表较难调整以支持多个逻辑环境。实际上,MLFlow 中的部署目标概念假设每个特定模型的版本可以具有不同的“阶段”。模型阶段是注册在 MLFlow 模型注册表中的模型版本的一个属性。它可以设置为“None”(无)、“Staging”(暂存)、“Production”(生产)和“Archived”(归档)。尽管有帮助,但模型阶段非常有限,因为我们不能为其定义自己的值,因此它不能完全支持我们定义各种逻辑环境的需求。我们确实使用它来管理与 Databricks 的模型版本权限,我们将在后文中描述。

模型命名

我们决定模型命名将作为区分任意数量逻辑环境的第一层。每个注册的模型名称都后缀带有环境标识符 f”{model_name}_{env_name}”。我们的 MLFlow 客户端在模型注册和检索过程中透明地管理这种命名,基于从系统环境变量获取的环境名称或传递给其构造函数的环境名称。与实验管理类似,由于我们在环境 MLFlow 客户端中的抽象,接口与原版 MLFlow 大致相同。

MLFlow 客户端摘录显示环境模型命名

模型权限

我们的目标是将生产模型与其他环境分离,并防止任何非生产主体注册或检索生产模型版本。如前所述,模型阶段可以设置为“None”(无)、“Staging”(暂存)、“Production”(生产)和“Archived”(归档)。Databricks 权限管理与模型阶段的值相关联。我们可以控制谁可以设置“Staging”和“Production”模型阶段值,以及谁可以使用“CAN_MANAGE_STAGING_VERSIONS”和“CAN_MANAGE_PRODUCTION_VERSIONS”权限管理具有这些模型阶段值的模型。“Production”阶段权限还允许主体访问“Staging”阶段模型版本并将其过渡到生产模型版本。

为了将生产模型与其他环境中的模型区分开来,我们在注册模型版本时自动分配模型阶段。所有非生产环境将模型阶段值分配为“Staging”。从生产环境注册的模型版本分配“Production”阶段值。我们利用分配给我们 Databricks 组的 Databricks 权限,安全地将我们的生产模型与其他逻辑环境分开管理。

Databricks MLFlow 模型权限 Azure terraform 摘录

3.3 汇总

除了对原生 MLFlow 客户端进行的抽象之外,我们还添加了一些封装各种操作的方法以方便使用。其中一个额外的方法是“log_model_helper”,它处理记录和注册模型版本以及设置适当模型阶段的各种步骤。

注册模型版本通常分为两个步骤,首先在实验运行期间记录一个模型,然后将记录的模型注册为已注册模型的模型版本。在实验运行期间记录模型会返回一个 ModelInfo 对象,其中包含一个指向本地工件位置的 model_uri。我们使用 model_uri 将模型版本注册到已注册模型名称下,并将其上传到 MLFlow 注册表。注册模型版本会返回一个 ModelVersion 对象,该对象告诉我们自动递增的模型版本和模型的 current_stage,创建后会始终是“None”。

在模型版本注册期间的第三步是根据逻辑环境将模型版本阶段从“None”过渡到“Staging”或“Production”。每个人都可以创建模型版本,但阶段转换受到 Databricks 权限的限制。这三个步骤都封装在“log_model_helper”方法中。使用这个辅助方法,我们可以假设所有已注册的模型版本都有环境感知的名称和适当的阶段值。

log_model_helper 摘录

3.4 测试我们的 MLFlow 客户端

为了维护我们自己的客户端,我们需要详细测试它,以检查是否符合我们的权限设计。对上游 MLFlow 客户端的任何更改也会同时进行测试,例如 MLFlow 模块的重构。我们在一个 PyTest 固定装置内启动一个本地 MLFlow 服务器,该装置的作用范围是整个测试会话。使用 Python tempfile 模块生成一个临时工件位置和一个临时 sqlite 数据库文件。我们测试会话的目标是记录和注册模型,并对其进行各种变更和检索。第一步是在空模型注册表中记录和注册一个模型版本。我们在 PyTest 固定装置内执行此操作,因为所有其他测试都依赖于它,尽管从技术上讲,它本身就是一个测试,因为它包括断言。

MLFlow 测试服务器固定装置,感谢 @Wessel Radstok 和 Rik Jongerius

模型注册的测试夹具

结论

我们扩展的环境 MLFlow 客户端允许我们将多个逻辑环境中的模型注册到同一个模型注册表中。它利用 Databricks 中的最小权限选项,在模型注册和部署检索期间安全地将开发环境与生产环境分开。此外,MLFlow API 在各环境中基本保持不变。

最初发布于 https://codebeez.nl

保护你的容器化模型和工作负载

原文:towardsdatascience.com/securing-your-containerised-models-and-workloads-3bff4d90a07b

切换到非根用户!

Jake TeoTowards Data Science Jake Teo

·发表于 Towards Data Science ·阅读时间 8 分钟·2023 年 10 月 24 日

--

容器化现在是部署许多应用程序的事实上的方法,其中 Docker 是推动其采用的前沿软件。随着其流行,也带来了攻击风险的增加 [1]。因此,我们需要确保保护我们的 Docker 应用程序。实现这一目标的最基本方法是确保我们在容器内设置非根用户。

CONTENTS
========

Why use non-root?

What you can & cannot do as a default non-root user

The Four Scenarios
  1) Serve a model from host (Read Only)
  2) Run data processing pipelines (Write within Container)
  3) Libraries automatically writing files (Write within Container)
  4) Save trained models (Write to Host)

Summary

为什么要使用非根用户?

或者说,为什么不使用根用户?让我们以下面的虚拟架构为例。

黑客进入一个具有根权限的容器。图片由作者提供

安全通常被视为多层防御。如果攻击者设法进入容器,作为用户的权限将是第一道防线。如果容器用户被分配了根权限,攻击者就可以自由控制容器内的一切。凭借如此广泛的访问权限,攻击者还可以利用任何潜在的漏洞,并可能通过这些漏洞逃脱到主机,获得对所有连接系统的完全访问权限。后果非常严重,包括以下几点:

  • 检索存储的机密

  • 拦截并扰乱你的流量

  • 运行恶意服务,如加密挖矿

  • 访问任何连接的敏感服务,如数据库

攻击者可能通过根权限横向渗透你的基础设施服务。图片由作者提供

真是太可怕了!不过解决方案很简单,将你的容器改为非根用户!

在继续阅读本文之前,如果你对 Linux 权限和访问控制没有很好的理解,请查看我之前的 文章 [2]。

作为默认的非根用户,你可以做什么和不能做什么

让我们尝试创建一个具有默认非 root 用户的简单 Docker 应用程序。使用下面的Dockerfile

# Dockerfile
FROM python:3.11-slim

WORKDIR /app

# create a dummy py file
RUN echo "print('I can run an existing py file')" > example.py

# create & switch to non-root user
RUN adduser --no-create-home nonroot
USER nonroot

构建镜像并使用它创建一个容器。

docker build -t test .
docker run -it test bash

现在你在容器内,让我们尝试几个命令。那么你不能做哪些事情呢?你可以看到各种写入和安装权限都是不允许的。

作为默认的非 root 用户你不能做的事情。作者截图。

在相反的情况下,我们可以运行各种读取权限。

读取命令是可以的。作者图片

由于我们安装了 python,它有些独特。如果我们ls -l $(which python),可以看到 python 解释器具有完全权限。因此,它可以执行像我们在Dockerfile中最初创建的example.py文件这样的现有 python 文件。我们甚至可以进入 python 控制台并运行简单命令。然而,当我们切换到非 root 用户时,其他系统写入权限已被移除,因此你会看到我们无法创建和修改脚本,或使用 python 执行写入命令。

可以执行现有的 python 脚本,但其他操作则不可行。作者图片

虽然系统范围的限制对安全性有好处,但在许多情况下,特定文件和目录的写入权限是必需的,我们需要考虑这些许可。

在接下来的章节中,我将给出机器学习操作生命周期中的四种场景的示例。通过这些示例,可以了解如何实现大多数其他情况。

四种场景

1) 从主机服务模型 — 只读

服务模型时,它涉及到一个推理和服务脚本,用于加载模型并通过 API(例如 Flask、FastAPI)暴露它以接受输入。模型有时从主机加载,并与镜像分开,以使镜像大小尽可能小,任何重新加载镜像都将尽可能快速而不重复下载模型。然后,模型通过一个bind-mount卷传递到容器中,以便加载和服务。

服务模型只需要读取权限。作者图片

这可能是实现非 root 用户的最简单方法,因为只需要读取权限,而所有用户默认都被授予此权限。下面是一个如何完成的示例Dockerfile

# Dockerfile
FROM python:3.11-slim

WORKDIR /app

COPY requirements.txt .
RUN pip3 install --no-cache-dir --upgrade pip~=23.2.1 \
    && pip3 install --no-cache-dir -r requirements.txt

COPY ./project/ /app

# add non-root user ---------------------
RUN adduser --no-create-home nonroot
# switch from root to non-root user -----
USER nonroot

CMD ["python", "inference.py"]

它有两个简单的命令,首先创建一个名为nonroot的新系统用户。其次,在最后一行CMD之前,从 root 切换到nonroot用户。这一点很重要,因为默认的非 root 用户没有任何写入和执行权限,因此不能安装、复制或操作在早期步骤中所需的文件。

现在我们知道如何在 Docker 中分配非 root 用户,接下来我们进入下一步。

2) 运行数据处理管道 — 在容器内写入

有时,我们只是想存储临时文件以执行一些任务,例如,一些数据预处理工作。这包括添加和删除文件。我们可以在容器内完成这些任务,因为文件不是持久的。

写入文件。作者图片

如果我们使用非 root 用户,我们需要写入权限。为此,我们需要使用命令chown(更改所有者)并将特定文件夹的所有权分配给nonroot用户。完成后,我们可以切换到nonroot用户。

# Dockerfile

# ....

# add non-root user & grant ownership to processing folder
RUN adduser --no-create-home nonroot && \
    mkdir processing && \
    chown nonroot processing

# switch from root to non-root user
USER nonroot

CMD ["python", "preprocess.py"]

3) 库自动写入文件 — 在容器内写入

之前的示例展示了如何写入我们自己创建的文件。然而,库自动创建文件和目录也是很常见的。你只有在尝试运行容器时才会知道它们已被创建,并且被拒绝写入权限。

我将给你展示两个这样的示例,一个来自supervisor,用于管理多个进程,另一个来自huggingface-hub,用于从 huggingface 下载模型。如果我们切换到非 root 用户,将会看到类似的权限错误。

Supervisor 阻止了创建日志文件。作者截图。

Supervisor 阻止了将 PID 存储在文件中。作者截图。

Huggingface Hub 阻止了下载模型文件。作者截图。

对于两个supervisor文件,我们可以先将它们创建为空文件,并赋予所有权。对于huggingface-hub的下载问题,错误日志中已提示我们可以通过TRANSFORMERS_CACHE变量更改下载目录,因此我们可以先设置目录变量,创建目录,然后分配所有权。

# Dockerfile

# ....

# add non-root user ................
# change huggingface dl dir
ENV TRANSFORMERS_CACHE=/app/model

RUN adduser --no-create-home nonroot && \
    # create supervisor files & huggingfacehub dir
    touch /app/supervisord.log /app/supervisord.pid && \
    mkdir $TRANSFORMERS_CACHE && \
    # grant supervisor & huggingfacehub write access
    chown nonroot /app/supervisord.log && \
    chown nonroot /app/supervisord.pid && \
    chown nonroot $TRANSFORMERS_CACHE
USER nonroot

CMD ["supervisord", "-c", "conf/supervisord.conf"]

当然,还会有其他示例可能与我这里展示的略有不同,但允许最少写入权限的概念将是相同的。

4) 保存训练好的模型 — 写入主机

假设我们使用容器训练模型,并希望将该模型写入主机,例如,以便被另一个任务用于基准测试或部署。在这种情况下,我们需要通过将容器目录链接到主机目录来写入模型文件,这也称为绑定挂载。

将模型文件写入主机。作者图片

首先,我们需要为nonroot创建一个组和用户,并为每个指定唯一的 ID,在此情况下,我们使用1001(1000 以上的任意数字都可以)。然后,创建一个模型目录来存储模型。

与情境 2 相比的不同之处在于,模型目录的写入不需要chown。为什么?

# Dockerfile

# ....
# add non-root group/user & create model folder
ENV UID=1001
RUN addgroup --gid $UID nonroot && \
    adduser --uid $UID --gid $UID --no-create-home nonroot && \
    mkdir model

# switch from root to non-root user
USER nonroot

CMD ["python", "train.py"]

这是因为绑定挂载目录的权限由主机目录决定。因此,我们需要在主机中再次创建相同的用户,确保用户 ID 相同。然后在主机中创建模型目录,并将nonroot用户授予所有者权限。

# in host terminal

# add the same user & group
addgroup --gid 1001
adduser --uid 1001 --gid 1001 --no-create-home nonroot
# create model dir to bind-mount & make nonroot an owner
mkdir /home/model
chown nonroot /home/model

绑定挂载通常在docker-compose.yml文件或docker run命令中指定,以启用更多灵活性。以下是前者的一个示例。

version: "3.5"

services:
    modeltraining:
        container_name: modeltraining
        build:
            dockerfile: Dockerfile
        volumes:
            - type: bind
              source: /home/model # host dir
              target: /app/model  # container dir

对于后者:

docker run -d --name modeltraining -v /home/model:/app/model <image_name>

运行其中之一,你会看到你的非根用户可以毫无问题地执行脚本。

摘要

我们已经看到如何分配非根用户并仍然使容器完成其所需任务。这主要在需要特定写入权限时相关。我们只需了解两个基本概念。

  • 在容器中进行写入权限的设置,可以使用Dockerfile中的chown

  • 对于绑定挂载的写入权限,请在主机中创建相同的非根用户,并在主机目录中使用chown

如果你需要进入 Docker 容器并以 root 用户身份运行一些测试,我们可以使用以下命令。

docker exec -it -u 0 <container_id/name> bash

参考文献

查看你使用 SAM 的分割效果

原文:towardsdatascience.com/see-what-you-sam-4eea9ad9a5de?source=collection_archive---------1-----------------------#2023-05-03

如何生成和可视化 Segment Anything Model 的预测结果

Jacob Marks, Ph.D.Towards Data Science Jacob Marks, Ph.D.

·

关注 发表在 Towards Data Science ·10 分钟阅读·2023 年 5 月 3 日

--

使用 Meta AI 的 Segment Anything Model (许可证) 对 Open Images V7 的样本进行分割 (许可证)。图片由作者提供。

在过去的几周里,Meta AI Research 的通用图像分割模型吸引了大量关注。这个模型,名为 Segment Anything Model (SAM) (Apache 许可证 2.0),是在一个包含 1100 万张图像和超过十亿个分割掩码的数据集上进行训练的。

SAM 功能强大。但像往常一样,在将模型投入生产之前,你需要了解模型在你的数据集上的表现。在计算机视觉的背景下,关键因素之一是可视化模型预测。

本博文旨在帮助你快速上手 SAM:我们将引导你如何使用 SAM 将分割掩码添加到你的数据集中,以及如何系统地可视化整个数据集中的这些分割掩码。通过可视化(和评估)这些预测,我们可以更好地了解 SAM 在我们数据集上的表现、其局限性以及将模型集成到我们管道中的潜在影响。

SAM 提供了多种生成分割掩码的途径:

  1. 自动:它自动工作,无需任何提示或提示

  2. 从边界框:给定一个边界框,SAM 对边界内的对象进行分割

  3. 从点:给定点标签,可能是正数或负数,SAM 推断需要分割的区域。

  4. 从点和框:你可以提供点边界框来提高性能

接下来,我们将明确介绍前三个。本文将按以下结构组织:

  • 设置

  • 使用 SAM 进行自动分割

  • 使用 SAM 进行语义分割

  • 使用 SAM 进行实例分割

设置

安装

本教程要求python≥3.8pytorch≥1.7torchvision≥0.8。如果你没有安装 Torch 或 Torchvision,请运行:

pip install torch torchvision

此外,我们将使用开源计算机视觉库FiftyOne来加载数据集和可视化预测。如果你没有安装 FiftyOne,可以运行:

pip install fiftyone

为了使用 SAM,你可以从源代码安装Segment Anything 库,使用:

pip install git+https://github.com/facebookresearch/segment-anything.git

然后,你将能够将库导入为segment_anything

然后,下载一个模型检查点。在这次演练中,我们将使用默认的ViT-H SAM 模型,即“huge”视觉变换器分割模型。如果你愿意,你也可以使用大型(ViT-L SAM)或基础(ViT-B SAM)模型。

导入模块

这是我们需要导入所有模块的头部代码:

import numpy as np
import PIL
import torch

import fiftyone as fo
import fiftyone.zoo as foz # for loading/downloading datasets

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

定义常量

我们还可以定义一些在所有分割应用中不会改变的元素:

sam_checkpoint = "path/to/ckpt.pth"
model_type = "vit_h"
device = "cuda"

sam = sam_model_registrymodel_type
sam.to(device)

加载数据集

对于本教程,我们将使用来自 Google 的Open Images V7Apache 许可证 2.0)数据集的图像。数据集已经为许多图像准备了实例分割掩码,但为了说明,我们只会加载点标签和目标检测边界框。关于如何处理 Open Images V7 中的点标签的全面教程,请查看此 Medium 文章

让我们从验证集中加载 100 个随机图像:

dataset = foz.load_zoo_dataset(
    "open-images-v7", 
    split="validation", 
    max_samples=100,
    label_types=["detections", "points"],
    shuffle=True,
)

我们将命名数据集并使其持久化。此外,我们还将通过运行compute_metadata()将图像的宽度和高度存储为像素,以便我们可以使用这些信息在绝对坐标和相对坐标之间进行转换:

dataset.name = "openimages_sam"
dataset.persistent = True
dataset.compute_metadata()

## visualize the dataset
session = fo.launch_app(dataset)

在我们开始添加 SAM 预测之前,数据集的外观如下所示:

Open Images V7 中的图像在 FiftyOne App 中可视化。图像由作者提供。

使用 SAM 进行自动分割

如果您没有任何现有的关键点或边界框来指导 Segment Anything 模型,您可以使用“自动分割”功能为图像中的任何对象和物品生成分割掩码。这是通过SamAutomaticMaskGenerator类完成的。请注意,这不是全景分割,因为掩码没有标注。

您可以实例化一个SamAutomaticMaskGenerator对象,并设置交并比(IoU)阈值、返回的掩码的最小面积和其他参数:

mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,
    pred_iou_thresh=0.9,
    stability_score_thresh=0.92,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=400
)

有关可接受参数的完整列表,请参阅此 SAM 笔记本

给定一个样本(位于sample.filepath的图像),我们可以使用 Pillow 读取图像,并调用我们的SamAutomaticMaskGenerator对象的generate()方法生成掩码:

image = np.array(PIL.Image.open(sample.filepath))
masks = mask_generator.generate(image)

这些掩码包含 2D 的“分割”数组,但没有标签。如果我们还想要标签,我们可以使用类似语义分割工具的库。为简单起见,我们将向您展示如何将所有这些组合成一个完整的图像掩码,并为我们的掩码生成器返回的每个单独的掩码分配一个不同的颜色。

要为单个样本添加“自动”分割掩码,我们可以将与该样本关联的图像传递给我们的掩码生成器。然后对于返回的每个掩码,我们可以将该掩码添加到完整的图像掩码中,乘以一个唯一的数字,使得显示颜色对应于该子掩码。然后,我们可以将这个完整的图像掩码作为Segmentation标签对象存储在我们的数据集中。

这包含在以下函数中:

def add_SAM_auto_segmentation(sample):
    image  = np.array(PIL.Image.open(sample.filepath))
    masks = mask_generator.generate(image)

    full_mask = np.zeros_like(masks[0]["segmentation"]).astype(int)
    for i in range(len(masks)):
        x, y = np.where(masks[i]['segmentation'])
        full_mask[x,y] = i + 1

    sample["auto_SAM"] = fo.Segmentation(mask=full_mask.astype(np.uint8))

只要在定义掩码生成器时设置crop_n_layers=1,添加步骤就是有效的。此代码可以处理最多 256 个唯一子掩码。

我们将遍历数据集中的样本,并在此过程中保存每个样本:

def add_SAM_auto_segmentations(dataset):
    for sample in dataset.iter_samples(autosave=True, progress=True):
        add_SAM_auto_segmentation(sample)

当我们在 FiftyOne App 中可视化结果时,我们看到的是:

Meta AI 的 Segment Anything Model 对 Open Images V7 样本的自动分割。图片由作者提供。

观察这些自动生成的掩码,我们可以看到有一些很小的斑点对我们并没有特别的意义。当我们定义掩码生成器时,我们将最小掩码区域面积设置为 400 像素。如果我们将此方法用作更大管道的一部分,我们可能需要考虑增加这一最小要求,或者根据图像中的像素数量为某些图像使用不同的最小值。

使用 SAM 进行语义分割

如果你的数据集中图像上有点标签(关键点),你可以使用这些点标签来提示 SAM 模型。这对于正负点标签都是适用的!本节将向你展示如何做到这一点。

在 FiftyOne 中,点标签表示为Keypoint对象。在 Open Images V7 中,图像上显示的每个单独点都存储在“points”字段中的自己的Keypoint对象中,因为它携带了额外的信息。

我们可以通过keypoints属性访问给定样本的点标签内容。例如,要获取数据集中第一个样本的第一个点标签,我们可以使用:

dataset.first().points.keypoints[0]
<Keypoint: {
    'id': '644c260d753fe20b7f60f9de',
    'attributes': {},
    'tags': [],
    'label': 'Rope',
    'points': [[0.11230469, 0.7114094]],
    'confidence': None,
    'index': None,
    'estimated_yes_no': 'no',
    'source': 'ih',
    'yes_votes': 0,
    'no_votes': 3,
    'unsure_votes': 0,
}>

这个点是类别Rope的负标签(estimated_yes_no字段),其结果由单个yesno投票数决定。在 Open Images V7 数据集中,点标签有estimated_yes_no("yes", "no", "unsure")。我们将忽略unsure点(这仅占总点标签的一小部分),并关注高确定性的点。

让我们实例化一个 SAM 预测器模型,用于语义和实例分割:

predictor = SamPredictor(sam)

为了初始化预测器,我们将通过point_coordspoint_labels参数传递图像中的点标签信息。

SamPredictor期望point_coords使用绝对坐标,而 FiftyOne 存储相对坐标。此外,point_labels接受01的数组,所以我们将从[yes, no]转换过来。以下函数接受给定图像的点标签列表、标签类别、图像宽度和高度,并返回所有相关点的point_coordspoint_labels

def generate_sam_points(keypoints, label, w, h):
    def scale_keypoint(p):
        return [p[0] * w, p[1] * h]

    sam_points, sam_labels = [], []
    for kp in keypoints:
        if kp.label == label and kp.estimated_yes_no != "unsure":
            sam_points.append(scale_keypoint(kp.points[0]))
            sam_labels.append(bool(kp.estimated_yes_no == "yes"))

    return np.array(sam_points), np.array(sam_labels)

对于单个样本,我们可以使用以下函数添加 SAM 语义分割掩码:

def add_SAM_semantic_segmentation(sample, n2i):
    image = np.array(PIL.Image.open(sample.filepath))
    predictor.set_image(image)

    if sample.points is None:
        return

    points = sample.points.keypoints
    labels = list(set([point.label for point in points]))

    w, h = sample.metadata.width, sample.metadata.height
    semantic_mask = np.zeros((h, w))
    for label in labels:
        sam_points, sam_labels = generate_sam_points(points, label, w, h)
        if not np.any(sam_labels):
            continue

        masks, scores, _ = predictor.predict(
            point_coords=sam_points,
            point_labels=sam_labels,
            multimask_output=True,
        )
        mask = masks[np.argmax(scores)].astype(int) ## get best guess

        semantic_mask *= (1 - mask)
        semantic_mask += mask * n2i[label]

    sample["semantic_SAM"] = fo.Segmentation(
        mask=semantic_mask.astype(np.uint8)
    )

这里,n2i 是一个字典,将类名映射到整数值,用于填充分割掩码。值得注意的是,当multimask_output=True时,预测器会为每个输入返回多个分割掩码的猜测。我们选择置信度最高的预测(最大 score)。

遍历数据集中的样本:

def add_SAM_semantic_segmentations(dataset):
    point_classes = dataset.distinct("points.keypoints.label")
    dataset.default_mask_targets = {i+1:n for i, n in enumerate(point_classes)}
    dataset.default_mask_targets[0] = "other"  # reserve 0 for background
    NAME_TO_INT = {n:i+1 for i, n in enumerate(point_classes)}
    dataset.save()

    for sample in dataset.iter_samples(autosave=True, progress=True):
        add_SAM_semantic_segmentation(sample, NAME_TO_INT)

我们可以为数据集生成分割掩码:

Meta AI 的 Segment Anything 模型对 Open Images V7 样本的语义分割。图片由作者提供。

当然,并不是所有的内容都进行了语义分割,因为图像中包含一些稀疏的点标签。向初始数据中添加更多点将导致数据集中图像的语义分割掩码更加密集。

我们还可以看到,虽然 SAM 在整个数据集上的表现相当不错,但在适当地分割摩托车轮子方面表现得比较困难。

使用 SAM 进行实例分割

如果你已经有了数据集中对象的边界框,你可以用这些边界框来提示 SAM 模型,并生成这些对象的分割掩码!方法如下:

与点标签一样,我们需要将边界框从相对坐标转换为绝对坐标。在 FiftyOne 中,边界框的存储格式为 [<top-left-x>, <top-left-y>, <width>, <height>],坐标在 [0,1] 范围内。而 SAM 边界框的格式为 [<top-left-x>, <top-left-y>, <top-right-x>, <top-right-y>],使用绝对坐标。以下函数将为我们执行转换:

def fo_to_sam(box, img_width, img_height):
    new_box = np.copy(np.array(box))
    new_box[0] *= img_width
    new_box[2] *= img_width
    new_box[1] *= img_height
    new_box[3] *= img_height
    new_box[2] += new_box[0]
    new_box[3] += new_box[1]
    return np.round(new_box).astype(int)

一旦我们为给定的对象检测生成了实例分割掩码,我们可以使用以下方式将掩码添加到检测对象中:

def add_SAM_mask_to_detection(detection, mask, img_width, img_height):
    y0, x0, y1, x1 = fo_to_sam(detection.bounding_box, img_width, img_height)    
    mask_trimmed = mask[x0:x1+1, y0:y1+1]
    detection["mask"] = np.array(mask_trimmed)
    return detection

要将实例分割掩码添加到图像中,我们需要遍历所有对象检测,使用 SamPredictor 对象与每个检测的边界框,并将生成的掩码添加到 FiftyOne Detection 对象中:

def add_SAM_instance_segmentation(sample):
    w, h = sample.metadata.width, sample.metadata.height
    image = np.array(PIL.Image.open(sample.filepath))
    predictor.set_image(image)

    if sample.detections is None:
        return

    dets = sample.detections.detections
    boxes = [d.bounding_box for d in dets]
    sam_boxes = np.array([fo_to_sam(box, w, h) for box in boxes])

    input_boxes = torch.tensor(sam_boxes, device=predictor.device)
    transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])

    masks, _, _ = predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )

    new_dets = []
    for i, det in enumerate(dets):
        mask = masks[i, 0]
        new_dets.append(add_SAM_mask_to_detection(det, mask, w, h))

    sample.detections = fo.Detections(detections = new_dets)

对于实例分割,将其扩展到整个数据集是很简单的:

def add_SAM_instance_segmentations(dataset):
    for sample in dataset.iter_samples(autosave=True, progress=True):
        add_SAM_instance_segmentation(sample)

按标签着色,我们得到的效果如下:

Meta AI 的 Segment Anything 模型对 Open Images V7 样本的实例分割。图片由作者提供。

注意:为了提高效率,你还可以批量处理这些预测!

全景分割

如果你希望使用 SAM 对数据集进行全景分割,你可以通过以下方式结合关键点和边界框的方法:

对于每一个边界对象,或者事物,如汽车或桌子:

  1. 生成对象周围的边界框,可以通过传统注释,或使用Grounding DINO,或者其他方法。

  2. 选择边界框的中心点作为该对象的默认关键点。如果发现这个点不在对象内部,请相应调整。

  3. 使用这些关键点和边界框来计算实例分割掩码

对于每个连续的物体区域(例如天空或草地):

  1. 添加一个或多个标记的关键点。

  2. 使用这些关键点计算语义分割掩码

填补空白:

  1. 在所有实例和语义分割掩码的基础上,识别重叠区域和没有任何掩码的区域。

  2. 使用适合你应用的策略处理这些区域。

结论

Meta AI 的 Segment Anything Model 功能强大且多才多艺。尽管如此,SAM 只是分割和提示/引导计算机视觉领域众多令人兴奋的进展中的一个。这个领域发展非常迅速!如果你对了解更多感兴趣,我建议你查看以下相关项目:

来源

通过声音看见世界:利用 GPT-4V(ision)和文本转语音技术赋能视觉障碍者

原文:towardsdatascience.com/seeing-with-sound-empowering-the-visually-impaired-with-gpt-4v-ision-and-text-to-speech-bb5807b4e08c

提升视觉障碍导航:将 GPT-4V(ision)和 TTS 集成以提供高级感官辅助

Luís RoqueTowards Data Science Luís Roque

·发表于Towards Data Science ·阅读时间 12 分钟·2023 年 11 月 16 日

--

本文由 Rafael Guedes 共同撰写。

介绍

OpenAI 的最新发展通过发布 GPT-4V(ision)和文本转语音(TTS) API,将 AI 的可用性提升到了一个全新的水平。为什么?让我们通过一个使用案例来激发它们的实际效用。对大多数人来说,走在街上是一个简单的任务,但对视觉障碍者来说,每一步都可能是一个挑战。传统的辅助工具如导盲犬和手杖虽然有用,但 AI 技术的融合开启了一个改善盲人社区独立性和移动性的全新篇章。配备隐蔽摄像头的简单眼镜足以彻底改变视觉障碍者体验周围环境的方式。我们将解释如何利用 OpenAI 的最新发布来实现这一点。

另一个有趣的使用案例是改变我们在博物馆和其他类似场所的体验。设想一下,博物馆中常见的音频导览系统被别在衣服上的隐蔽摄像头所取代。假设你正在参观一家艺术博物馆。当你在博物馆中漫步时,这项技术可以为你提供有关每幅画作的信息,并且可以按照你选择的特定风格进行讲解。假设你有点疲倦,需要一些轻松有趣的内容,你可以提示它‘给我一些关于这幅画的历史背景,但要有趣和引人入胜,甚至可以加些笑话’

增强现实(AR)呢?这种新技术能否改善甚至取代它?目前,AR 被视为我们可以叠加在对现实世界的视觉感知上的数字层。问题是,这可能很快变得嘈杂。这些新技术可能会在某些用例中取代 AR。在其他情况下,它可以使 AR 个性化,使我们能够以自己的节奏体验世界。

在这篇文章中,我们将探讨如何将 GPT-4V(视觉)和文本转语音(TTS)结合起来,使世界对视觉障碍者更具包容性和可导航性。我们将首先解释 GPT-4V(视觉)是如何工作的及其架构(我们将使用一些开源对应物来获取直觉,因为 OpenAI 不提供有关其模型的详细信息)。接下来,我们将解释什么是 TTS,以及用于将文本转化为音频信号的最常见模型架构。最后,我们将通过一步一步的实施,展示如何利用 GPT-4V(视觉)和 TTS 帮助视觉障碍者在葡萄牙波尔图的街道上导航。

图 1:OpenAI 发布了有关其 API 服务的多个更新,并将多模态引入 GPT-4(来源

一如既往,代码可以在我们的 Github 上找到。

什么是 GPT-4V(视觉)?

GPT-4 像 GPT-3.5 一样,是一个大型多模态模型,能够处理文本输入并生成文本输出 [1]。在最新的 OpenAI 公告中,他们表示 GPT-4 已被扩展为一个多模态大型语言模型(LLM)。这意味着该模型现在能够接收额外的输入模态,在这种情况下是图像。多模态 LLM 扩展了仅语言系统的影响,通过新的接口和能力,为更复杂的用例开辟了可能性。你可以在下图中看到使用 GPT-4V(视觉)的示例,其中视觉和推理能力一起工作,以检测图片中的一些复杂细微之处。

图 2:GPT-4 解释对人类来说不寻常的能力(来源

尽管 OpenAI 在其论文中明确表示 …

考虑到竞争环境以及像 GPT-4 这样的大规模模型的安全性,这份报告未包含有关架构(包括模型大小)、硬件、训练计算、数据集构建、训练方法或类似内容的更多细节。

…我们可以尝试估计 GPT-4V(视觉)的架构是什么样的。

我们知道 GPT-4V(视觉)接收文本和图像作为输入。因此,它很可能有三个主要组件:

  1. 编码器: 用于处理文本和图像数据的独立编码器

  2. 注意力机制: 它采用先进的注意力机制,使模型能够关注文本和图像输入中最相关的部分。

  3. 解码器: 根据编码器的潜在空间结合注意力层生成输出文本。

图 3:使用图像和文本作为输入的多模态模型的简单架构(图片由作者提供)

类似的架构可以在🦩 Flamingo 模型 [2]中找到,该模型由 DeepMind 创建。

Flamingo 旨在处理文本和视觉数据作为输入,以生成自由形式的文本作为输出。作者将其定义为视觉语言模型(VLM)。该模型有三个主要组件:输入处理、感知重采样器和整合两种数据类型并生成输出文本的层。

输入处理:Flamingo 接收视觉和文本数据。文本在进入语言模型(LM)之前经历常规的分词处理,而视觉输入由视觉编码器(VE)处理,将像素转换为更抽象的特征表示。

感知重采样器:此组件进一步处理视觉特征。它为图像添加时间感(在视频中尤为重要),并将数据压缩成更易于管理的格式。这对后续有效结合视觉和文本数据至关重要。

图 4:感知重采样器架构(来源

集成和输出: 处理过的视觉和文本数据随后被整合到 GATED XATTN-DENSE 层。该层采用了交叉注意力机制与门控函数结合,来有效地融合两种数据。该层的输出输入到 LM 层,最终由 Transformer 解码器生成自由形式的文本输出。

图 5:Flamingo 模型概述。Flamingo 模型是一系列视觉语言模型(VLM),可以同时接受视觉数据和文本作为输入,并生成自由形式的文本输出(来源)。

GPT-4V(ision) API

OpenAI 的 GPT-4V(ision) API 允许处理视觉和文本信息。我们在下面介绍使用该 API 的主要步骤。

设置环境

  • 在你的环境中安装 Python 依赖项,即 OpenAI 库。
pip install openai
  • 在你的 Python 脚本中导入必要的库。
import openai 
import os

配置 API 参数:利用**ChatCompletion**类,结合特定参数处理多模态(文本和图像)数据。

  • 模型参数:将其设置为**gpt-4-vision-preview**以启用对视觉和文本数据的处理。
params = {
    "model": "gpt-4-vision-preview",
    "messages": PROMPT_MESSAGES,
    "api_key": os.environ['OPENAI_API_KEY'],
    "headers": {"Openai-Version": "2020-11-07"},
    "max_tokens": 400,
}
  • 消息参数:这需要包括文本和图像。图像应以 base64 格式编码。
PROMPT_MESSAGES = [{
    "role": "user",
    "content": [
        "<Your Prompt>",
        {"image": image_in_base64_format}
    ],
}]

处理图像:在将其包含在 API 调用中之前,图像必须转换为 base64 格式。

import base64

def convert_image_to_base64(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')

执行 API 调用:设置好参数后,发起 API 调用以处理输入数据。

response = openai.ChatCompletion.create(**params)
print(response)

什么是文本到语音?

TTS [4] 技术将书面文字转换为口语。这一复杂过程涉及多个阶段,每个阶段由不同的模型处理。首先,一个 字形到音素 模型将书面文字转换为音素单位。接下来,一个 音素到 Mel 谱图 模型将这些音素转换为视觉频率表示。最后,一个 Mel-谱图到音频 模型或 声码器 从这种表示生成实际的口语音频。

图 6:TTS 架构(图片来源于作者)

  1. 字形到音素转换:第一步是将书写的单词(字形)转换为音素,即发音的基本单位。例如,短语 **"Swifts, flushed from chimneys"** 将被转换为音标表示,如 **"ˈswɪfts, ˈfɫəʃt ˈfɹəm ˈtʃɪmniz"**。一个常用的模型是 G2P-Conformer [5]。

  2. 音素到 Mel 谱图:接下来,这些音素会被处理以创建 mel-谱图,这是声音频率随时间变化的视觉表示。这通常通过编码器-解码器架构来实现,例如 Tacotron2 [6]。在这个模型中,卷积神经网络(CNN)嵌入音素,然后通过双向长短期记忆(LSTM)网络进行处理。生成的 mel-谱图是一个关键的中间步骤,充当音素与最终音频输出之间的桥梁。

  3. Mel-谱图到音频转换:最后阶段涉及将 mel-谱图转换为实际的音频。这时需要一个声码器,通常使用先进的生成模型。由 DeepMind 开发的 WaveNet [7] 是一个很好的例子。它使用带有扩张因果卷积的深度神经网络,以确保正确的序列处理。每个预测的音频样本会反馈到网络中以预测下一个,从而使模型能够捕捉音频信号中的长距离模式。

文本到语音 API

OpenAI 提供了一个通过 API 访问的 TTS 服务,提供两个质量等级和六种不同的声音,以满足不同的需求和偏好。

质量选项

  • **tts-1**:适用于实时应用,这个选项提供较低的质量但具有减少延迟的优点。

  • **tts-1-hd**:适合于音频质量较高且延迟不是问题的场景。

选择声音

  • OpenAI 的 TTS API 具有六种独特的声音:**alloy****echo****fable****onyx****nova****shimmer**

  • 每种声音都有其独特的特性;例如,**Fable** 类似于播客叙述者的声音。

  • 你可以在 OpenAI 的 文本到语音指南 上预览这些声音。

发起 API 请求:

要使用 OpenAI 的 TTS API,发送请求到 **https://api.openai.com/v1/audio/speech**。你需要以下内容:

  1. 模型规格:根据你的需求选择 **tts-1**(低质量,低延迟)或 **tts-1-hd**(高质量,高延迟)。

  2. 输入文本:你想要转换为语音的文本内容。

  3. 语音选择:从可用的语音中选择一个最适合你的内容和观众的声音。

这是一个关于如何结构化 API 请求的基本示例:

response = requests.post(
    "https://api.openai.com/v1/audio/speech",
    headers={
        "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}",
    },
    json={
        "model": "tts-1",
        "input": result.choices[0].message.content,
        "voice": "fable",
    },
)

如何实现一个帮助视障人士在街上行走时感到更安全的应用?

本节逐步描述了如何使用 OpenAI 的 GPT-4V(ision) 和 TTS 创建视频的音频描述。还涵盖了如何将生成的音频添加到视频中。

在我们的案例中,如介绍所述,我们创建了一个音频指南,帮助视障人士在街上行走,通过识别障碍物的位置并提供空间信息来辅助他们。

过程开始于导入必要的库并设置 Python 环境。我们使用如 cv2 进行视频处理和 openai 访问 GPT-4V(ision) 和 TTS 模型等库。

import base64
import cv2
import openai
import os
import requests
import time

from IPython.display import display, Image, Audio
from moviepy.editor import VideoFileClip, AudioFileClip
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip

接下来,我们加载并处理视频。视频被调整为可管理的分辨率,以确保不超过 OpenAI 的令牌限制。每一帧都被编码为 base64,这是 OpenAI API 所需的格式。

video = cv2.VideoCapture("video.mp4")

base64Frames = []
while video.isOpened():
    success, frame = video.read()
    if not success:
        break
    frame = cv2.resize(frame, (512,512))
    _, buffer = cv2.imencode(".jpg", frame)
    base64Frames.append(base64.b64encode(buffer).decode("utf-8"))

video.release()
print(len(base64Frames), "frames read.")

一个重要的步骤是为 GPT-4V(ision) 制作合适的提示。精心设计的提示会显著影响我们从模型中获得的输出。从我们的实验中,影响输出的两个主要组件是:

  • 动词如描述、叙述、讲述;

  • 确定我们想要的语音风格。

我们尝试的第一个提示之一是:‘这些是一个人在城市中行走的画面。描述周围的元素和障碍物,以帮助盲人成功导航。’ 这种结构使模型变得极其冗长和描述性。我们的使用案例需要输出更少噪音。

对我们有效的提示是:‘这些是一个人行走的画面。用 GPS 设备那样的简短句子简洁叙述周围的元素和障碍物,以帮助盲人成功导航。’ 这次模型给出了简短的句子,让我们获得了街道导航所需的基本信息。结果如下:

“在有纹理的路径上直行,将建筑物保持在右侧。继续前行,稍微向右转。保持直行,前方右侧有小悬挑。继续前进,经过悬挑,继续在平坦的路径上前行。直行,接近一个光线充足的区域。经过光线充足的区域后,过渡到有图案的人行道。跟随有引导线的瓷砖人行道直行。继续通过通道,保持柱子与自己平行。穿过通道,前方有小的下降。路径结束,准备在斑马线停下。站在斑马线前,等待听到过马路的信号。”

PROMPT_MESSAGES = [
    {
        "role": "user",
        "content": [
            "These are frames from a person walking. Narrate succinctly with short sentences like they do in GPS devices the elements and obstacles around you to help a blind person go through.",
            *map(lambda x: {"image": x}, base64Frames[0::100]),
        ],
    },
]

params = {
    "model": "gpt-4-vision-preview",
    "messages": PROMPT_MESSAGES,
    "api_key": os.environ['OPENAI_API_KEY'],
    "headers": {"Openai-Version": "2020-11-07"},
    "max_tokens": 350,
}

result = openai.ChatCompletion.create(**params)
print(result.choices[0].message.content)

一旦我们收到 GPT-4V(ision)的描述,下一步是将文本转换为音频。我们选择了寓言风格的声音,因为它的清晰度和叙述的相似性。使用 OpenAI 的 TTS API 将生成的文本转化为音频文件。

response = requests.post(
    "https://api.openai.com/v1/audio/speech",
    headers={
        "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}",
    },
    json={
        "model": "tts-1",
        "input": result.choices[0].message.content,
        "voice": "fable",
    },
)

audio = b""
for chunk in response.iter_content(chunk_size=1024 * 1024):
    audio += chunk

with open('audio.mp3', 'wb') as file:
    file.write(audio)

最后一步是将音频与原始视频合并。

# Open the video and audio
video_clip = VideoFileClip("video.mp4")
audio_clip = AudioFileClip("audio.mp3")

# Concatenate the video clip with the audio clip
final_clip = video_clip.set_audio(audio_clip)
final_clip.write_videofile("video_audio.mp4")

你可以在这里查看最终视频。

结论

OpenAI 的 GPT-4V(ision)和 TTS API 开辟了新的可能性,解决了可以改变许多人日常生活的重要用例。我们探讨了一个专注于视障人士的包容性的用例,但我们本可以选择其他许多用例。如在介绍中所述,它们还让我们能够提升人类体验(例如博物馆导览),并根据个人的偏好和情况进行更好的定制。

在实施过程中,我们发现提示工程对将解决方案量身定制到我们的特定用例有显著影响。未来,其他方法如微调或某种类型的检索增强生成(RAG)可能会适用于 VLMs。我们看到这些方法在某些任务和情境中使 LLMs 更为有效。尽管输出展示了这些新模型的潜力,但仍有待完善。如我们实验中所见,VLM “说话”的方式像是你可以看到它所说的“跟随有引导线的瓷砖人行道直行。” 它还难以准确区分左右,这是一个值得进一步探索的有趣事实。

尽管面临这些挑战,OpenAI 的最新进展向我们展示了一个更加包容的未来,AI 可以提升体验。

保持联系:LinkedInX/TwitterMedium

参考文献

[1] OpenAI. (2023). GPT-4 技术报告。arXiv 预印本 arXiv:2303.08774。

[2] Alayrac, J.-B., Donahue, J., Luc, P., Miech, A., Barr, I., Hasson, Y., Lenc, K., Mensch, A., Millican, K., Reynolds, M., Ring, R., Rutherford, E., Cabi, S., Han, T., Gong, Z., Samangooei, S., Monteiro, M., Menick, J., Borgeaud, S., Brock, A., Nematzadeh, A., Sharifzadeh, S., Binkowski, M., Barreira, R., Vinyals, O., Zisserman, A., & Simonyan, K. (2022). Flamingo: 一种用于少样本学习的视觉语言模型。arXiv 预印本 arXiv:2204.14198。

[3] Brock, A., De, S., Smith, S. L., & Simonyan, K. (2021). 无需归一化的高性能大规模图像识别。arXiv 预印本 arXiv:2102.06171。

[4] Maheshwari, H. (2021 年 5 月 11 日). 基础文本到语音,解析。Towards Data Science。取自 towardsdatascience.com/text-to-speech-explained-from-basic-498119aa38b5

[5] Gulati, A., Qin, J., Chiu, C.-C., Parmar, N., Zhang, Y., Yu, J., Han, W., Wang, S., Zhang, Z., Wu, Y., 等. (2020). Conformer: 用于语音识别的卷积增强变换器。arXiv 预印本 arXiv:2005.08100。

[6] Shen, J., Pang, R., Weiss, R. J., Schuster, M., Jaitly, N., Yang, Z., Chen, Z., Zhang, Y., Wang, Y., Skerry-Ryan, R. J., Saurous, R. A., Agiomyrgiannakis, Y., & Wu, Y. (2018). 通过将 WaveNet 条件化于 Mel 频谱预测进行自然 TTS 合成。arXiv 预印本 arXiv:1712.05884。

[7] van den Oord, A., Dieleman, S., Zen, H., Simonyan, K., Vinyals, O., Graves, A., Kalchbrenner, N., Senior, A., & Kavukcuoglu, K. (2016). WaveNet: 一种原始音频的生成模型。arXiv 预印本 arXiv:1609.03499。

Segment Anything 3D for Point Clouds: 完整指南 (SAM 3D)

原文:towardsdatascience.com/segment-anything-3d-for-point-clouds-complete-guide-sam-3d-80c06be99a18

3D Python

如何利用 SAM 和 Python 构建 3D 点云的语义分割应用程序。附加内容:投影和 3D 点与 2D 像素之间关系的代码。

Florent Poux, Ph.D.Towards Data Science Florent Poux, Ph.D.

·发表于 Towards Data Science ·阅读时间 27 分钟·2023 年 12 月 13 日

--

针对 3D 环境的 Segment Anything 模型。我们将使用 3D 点云数据集检测室内空间的物体。特别鸣谢 Mimatelier,这位才华横溢的插画师创作了这张图片。

科技的飞跃真是疯狂,特别是看到人工智能(AI)应用于 3D 挑战时。能够利用最新的前沿研究来进行高级 3D 应用是非常令人振奋的。尤其是在将人类级别的推理能力带入计算机时,明确从我们观察到的 3D 实体中提取出正式化的意义显得尤为重要。

在本教程中,我们旨在将令人惊叹的人工智能进展与利用 3D 点云的 3D 应用程序结合起来。— 🐲 Florent & Ville

这不是一件容易的事,但一旦掌握,3D 点云与深度学习的融合将开辟对我们视觉世界的新维度的理解和解释。

人工智能与 3D 点云。© F. Poux

在这些进展中,Segment Anything 模型是最近的创新灯塔,尤其是在无需监督的全自动化方面。

我们用于 3D 数据的 Segment Anything 模型架构。它包括一个图像编码器、图像嵌入以及一些预处理操作,最后传递给解码器和提示编码器,输出结果为掩模。© F. Poux

在这份终极指南中,我们将开始一个实际的旅程,从模型的诞生到实际分割应用。但这里的目标是什么?

任务 🥷

好了,任务简报时间到了!你现在是你国家特种部队的多类成员,你必须在不被发现的情况下找到隐藏在特定建筑物中的危险材料(这里是 ITC 大楼)。

利用你出色的互联网黑客技能,你成功找到了你感兴趣的建筑部分的 3D 扫描。你现在需要快速找到定义你危险材料回收队的路径的方法。之后,队伍可以在不被察觉的情况下进行材料回收,你也就成功拯救了世界!

经过仔细研究并利用你的各种技能,你开发了一个 3D 数据处理工作流程,其中包括设置 3D Python 代码环境,通过 Segment Anything 模型处理 3D 点云,以突出场景的组成,如下所示。

Segment Anything 3D 的工作流程。我们有五个主要步骤(3D 项目设置、Segment Anything 模型、3D 点云投影、无监督分割和定性分析),并在下图中进一步细化。© F. Poux

这将允许你生成一个 3D 语义地图,在队伍到达现场之前的九十分钟内,能够准确定位材料的位置。你准备好了吗?

🎵读者注意: 这个实践指南是* UTWENTE 与合著者F. Poux* V. Lehtola* 的联合工作的一部分。我们感谢来自数字双胞胎 @ITC 项目的资助,该项目由特温特大学 ITC 学院授予。

1. 3D 项目设置

在我们深入探讨 Segment Anything 模型的奇迹之前,建立一个稳固的基础至关重要。设置适当的环境可以确保我们在整个过程中顺利进行,从而实现无缝的实验和探索。

在这个阶段,我们要确保我们的编码环境设置正确,并且具有强大的库。准备好了吗?

🤠 Ville: 这是在行动开始之前。如果你是从零开始做这件事,比如可能需要更新 CUDA 驱动程序,请单独预留一个或两个小时。你将下载几 GB 的内容。

3D 项目设置。我们首先设置环境,然后附加基础库、深度学习库和 IDE 设置。© F. Poux

1.1. 3D 代码环境设置

现在是时候动手了!我们的目标是使用 Segment Anything Model 对 3D 点云进行语义分割。这绝非易事!所以,首先的反应是查看 Segment Anything 的依赖项:访问 SAM Github

从那里,我们检查包的必要前提条件:

Segment Anything 中突出的依赖项。

🦊 Florent无论何时处理深度学习库或深度学习的新研究代码,都必须检查依赖项和安装建议。这将极大地影响实验的后续进展和复制所需的时间。

如你所见,我们需要使用以下库版本:

python ≥ 3.8
pytorch ≥ 1.7
torchvision ≥ 0.8

现在这一点搞定后,我们将生成一个虚拟环境以确保顺利进行!如果你想详细了解这一过程,建议你查看以下指南:

## 3D Python Workflows for LiDAR City Models: A Step-by-Step Guide

解锁 3D 城市建模应用程序的精简工作流的终极指南。教程涵盖了 Python…

towardsdatascience.com

为了不让你感到无助,下面是另一种快速轻量级的设置策略,使用 Miniconda

💡 注意Miniconda 是一个免费的 Conda 最小安装程序。它是 Anaconda 的“微型”版本,仅包含最少的依赖项。这些包括 Conda 包管理器、一个 Python 版本、它们所依赖的包以及其他有价值的包如 pip 和 zlib。这使得我们可以以轻量的方式只安装我们需要的东西。

🤠Ville:虚拟环境的酷炫之处在于,你可以将其导出并在强大的 Linux 计算机和超级集群上直接运行你的代码!这对于训练网络非常有用!

在从 这里 下载适合你操作系统的 Miniconda 版本后(建议选择 Python 3.9 或 3.10 版本以确保与包的兼容性),你可以按照安装过程中的各种步骤进行安装。

miniconda 安装程序窗口。© F. Poux

就这样!你现在已经完成了最简单的 Python 安装,使用轻量级的 miniconda 使得隔离受控的虚拟环境变得非常容易。在继续下一步之前,我们启动 miniconda 及其命令行访问:

在 Windows 中,只需搜索“miniconda”即可找到

一旦进入 Anaconda 提示符,我们按照下面显示的简单四步过程进行操作。

设置 Python 环境以使用 3D Segment Anything Model 的工作流程。© F. Poux

  1. 要创建新环境,我们输入:conda create -n GEOSAM python=3.10

  2. 切换到新创建的环境,我们输入:conda activate GEOSAM

  3. 要检查 Python 版本,输入 python --version,检查已安装的软件包:conda list。这应分别显示 Python 3.10 和基础库列表。

  4. 在新环境中安装 pip,我们输入:conda install pip

就这些!我们现在准备安装必要的库以进行 SAM 的操作。

[## 3D 创新者通讯

每周提供实用内容、见解、代码和资源,掌握 3D 数据科学。我写关于点云、人工智能……

learngeodata.eu

1.2. 基础库

本教程中使用的基础库(Numpy、Matplotlib、Laspy)。© F. Poux

我们现在安装用于 SAM 的基础库:NumPyLasPyOpenCVMatplotlibNumPy 可能是最推荐的数值计算库,OpenCV 用于计算机视觉任务,Laspy 处理 LIDAR 数据,而 Matplotlib 是一个绘图和数据可视化库。

🦊 Florent这些库是任何 3D 项目的基础和坚实的基石。如果你想深入了解它们,我建议你去 这个教程 ,它探讨了其深奥的内容 🪸。

要安装这些库,我们可以用一行代码通过 pip:

pip install numpy matplotlib laspy opencv-python

很好;是时候设置深度学习库了!

1.2 深度学习库

深度学习库。

我们现在将着手安装深度学习库。当然,我们首先探索的是我迄今为止最喜欢的:Pytorch。自 2017 年推出以来,Pytorch 优先考虑其灵活性和可黑客性,其次是性能。因此,今天,使用 Pytorch 进行深度学习应用是绝佳的选择,如果你需要 (1) 高性能执行,(2) Pythonic 内部实现,以及 (3) 有价值任务的良好抽象。

🦊 Florent: 自 2017 年以来,硬件加速器(如 GPU)在计算任务中的速度提高了约 15 倍。你只能猜测接下来几年会发生什么。因此,必须关注灵活的库,它们可以快速适应,甚至对“内部”进行重构,如 Pytorch 所做的那样。

🤠Ville: SAM 作者推荐使用 8GB 内存的 GPU。然而,我们提供了一些如何在内存较少的情况下进行教程的技巧。如果你收到‘MemoryError’或‘Out-of-bounds memory access’或‘Illegal memory access’消息,请使用这些技巧。我使用 6GB 内存成功运行了它。

为了无忧地安装 Pytorch 的相关发行版,而不必为如何安装 CUDA(这并不简单)而烦恼,他们制作了一个简单的网页应用程序,生成代码以复制并粘贴到你的命令行中。为此,你可以访问这个 Pytorch 入门页面 并选择最相关的安装方式,如下所示。

如何根据你的操作系统和配置安装 Pytorch。

💡 注意: 我们希望充分利用我们的 GPU。因此,重要的是要注意我们希望进行 CUDA 安装。但这只有在你写这篇文章时拥有 Nvidia GPU 时才可能。如果没有,你可能需要使用 CPU 或切换到像 Google Colab 这样的云计算服务。

因此,我们的代码行如下:

conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia

这行代码将触发必要元素的检索和安装,以便 Pytorch 可以顺利运行。

Pytorch 的安装。

我们要使用的第二个深度学习库是 Segment Anything。在 Pytorch 安装的同时,我们可以下载并安装“软件”,这将使我们更容易管理版本和访问在线库。这就是Git,可以在这里访问:Git 官网

你可以下载并安装 git,一旦安装完成,Pytorch 也应该在你的环境中顺利安装。因此,为了安装 segment-anything,我们可以写如下代码:

pip install git+https://github.com/facebookresearch/segment-anything.git

这将再次需要一些时间,直到你看到如下消息。

安装 Pytorch 的 CLI 结果。

在这一阶段,我们已经安装了基础库以及深度学习库。在使用它们之前,让我们安装一个 IDE,以便一切顺利运行。

4. 设置 IDE

Jupyter lab IDE 安装后的界面。

我们设置的最后一步是安装一个 IDE。我们仍在环境中的命令行界面下,输入:pip install jupyterlab,这将会在我们的环境中安装 jupyterlab。

要在指定的本地文件夹中使用它,我们可以首先为我们的项目创建一个父目录(我们称之为 SAM),该目录将包含一个 CODE 文件夹和一个 DATA 文件夹。完成后,在控制台中,我们通过写入 cd C://COURSES/POUX/SAM 来切换到创建的目录。

我们通过在控制台输入 jupyter lab 来从这个位置启动 jupyterlab,这将会在你的网页浏览器(Chrome、Firefox 或 Safari)中打开一个新的本地主机页面。

在 Jupyter 中,你可以创建一个笔记本 (.ipynb),并在第一个单元格中写入导入语句,以使用所有已安装的软件包:

# The Base libraries
import numpy as np
import matplotlib.pyplot as plt
import cv2
import laspy

# The Deep Learning libraries
import torch
from segment_anything import sam_model_registry
from segment_anything import SamAutomaticMaskGenerator

好的!我们已经准备好了一切。在开始编码模型的其他步骤之前,现在正是提取我们 3D 数据集的好时机。

5. 3D 数据集整理

在之前的教程中,我们展示了多个 3D 数据集中的点云处理和网格化,其中一些使用了 AHN4 LiDAR 活动的航空 LiDAR。

## 3D 深度学习 Python 教程:PointNet 数据准备

终极 Python 指南,用于结构化大型 LiDAR 点云,以训练 3D 深度学习语义分割……

towardsdatascience.com

这次,我们将使用一个使用地面激光扫描仪收集的数据集:2023 年乌特勒支大学 ITC 的新建筑,如下所示。它包含一个室内绿色区域,内有精美的复杂树叶,分割后可以进行评估。

ITC UTwente 新建筑的 3D 点云,包括其室内“丛林”。© F. Poux

你可以从这里下载数据: 指南数据集(Google Drive),并将其放入你保存数据集的文件夹中(在我的例子中是“DATA”)。

在这个过程中阶段,我们有一个良好的编码设置,所有必要的库都在一个轻量级、隔离的 GEOSAM conda 环境中。

🦊 Florent: 到目前为止做得很好!如果你急于运行一些测试以检查 Pytorch 是否正常工作,即 CUDA 是否被识别,你可以写下以下代码行:

import torch
print('CUDA available -> ', torch.cuda.is_available())
print('CUDA GPU number -> ', torch.cuda.device_count())
print('GPU -> ', torch.cuda.get_device_name())

上述打印结果的配置。

现在是时候对数据进行分割了!

[## 3D 创新者通讯

每周实用内容、见解、代码和资源,以掌握 3D 数据科学。我写关于点云、人工智能等内容……

learngeodata.eu

2. 设置 Segment Anything 模型

我们的小冒险的核心是 Segment Anything Model,这是一种强大的创作,具有极好的 3D 点云语义分割潜力。凭借其创新的架构和训练过程,该模型是室内应用测试的理想候选者。让我们先来了解一下其核心概念。

2.1. Segment Anything 基础知识

MetaAI 已深入探讨自然语言处理(NLP)和计算机视觉的迷人领域,其 Segment Anything Model 使 零-shot少-shot 学习 在新数据集和任务上成为可能,使用基础模型。

🦊 Florent: 好吧,我承认有很多脏话。为了清晰起见,这里是我对每个复杂术语的总结尝试。零-shot 学习指的是在未见过某物的情况下识别它(零次见过)。类似地,少-shot 学习使用有限数量的标记示例来处理每个新类别,目标是根据这些少量的标记数据进行预测。

🤠Ville: 此外,所谓的 基础模型 是一个在大量数据上训练的模型。它如此庞大,可以适应来自不同场景的各种任务。

让我们为你拆解一下:

总体而言,SAM “AI” 算法可以显著减少进行图像分割所需的人力。为此,你需要向模型提供前景/背景点、粗略的框或掩码、一些文本或任何其他指示你想要在图像中分割的输入。Meta AI 团队已训练 Segment Anything Model 以生成合适的分割掩码。这个掩码是模型的输出,应该是一个适合划定提示可能指向的事物的掩码。例如,如果你在房子屋顶上标出一个点,输出应该正确识别你是指屋顶还是房子。

Segment Anything Model (SAM) 是如何工作的?解释分割提示以生成有效的掩码(以房屋为例)。© F. Poux

该分割任务可以用于模型预训练,并指导解决各种下游分割问题。

从技术角度来看,我们所称的图像编码器为每张图像创建了一个独特的嵌入(表示),而一个轻量级的编码器迅速将任何查询转换为嵌入向量。这两个数据源通过一个(轻量级的)掩码解码器合并,以预测分割掩码,如下所示。

Segment Anything Model 的工作流程图。图像经过图像编码器处理。然后它被嵌入,最后在使用提示和提示编码器后合并,以生成我们 3D 点云的最终掩码。© F. Poux

这种有效的架构,加上大规模的训练阶段,使 Segment Anything Model 达到了四个里程碑:

  • 轻松对象分割 🔥: 使用 SAM,用户可以通过简单选择要包括或排除的点来轻松分割对象。你还可以使用边界框作为模型的提示。

  • 处理不确定性 🔥: SAM 能够处理对象分割中的不确定情况。它可以生成多个有效的掩码,这对于有效解决实际的分割挑战至关重要。

  • 自动对象检测与掩膜 🔥: SAM 使得自动对象检测和掩膜变得轻而易举。它简化了这些任务,节省了你的时间和精力。

  • 实时交互 🔥: 得益于预计算的图像嵌入,SAM 可以即时提供任何提示的分割掩膜。这意味着你可以与模型进行实时交互。

既然这些都解决了,你准备好使用它了吗?

2.1. SAM 参数

SAM 模型可以通过三种不同的编码器加载:ViT-B、ViT-L 和 ViT-H。ViT-H 的结果优于 ViT-B,但与 ViT-L 相比仅有微小的提升。

|     Encoder          |   #parameters    |     Speed      |   Quality    |
|----------------------|------------------|----------------|--------------|
|   ViT-B   (basic)    |       91M        |     Fastest    |   Low        |
|   ViT-L   (large)    |       308M       |     Fast       |   High       |
|   ViT-H   (huge)     |       636M       |     Slow       |   Highest    |

🤠 Ville: 为了帮助选择,我在 NVIDIA GeForce GTX 1650、6 Gb VRAM 和 Win11 上测试了 ViT-B。

这三种编码器具有不同的参数数量,这为应用程序的调优提供了更多自由。ViT-B(最小)有 9100 万个参数,ViT-L 有 3.08 亿个参数,而 ViT-H(最大)有 6.36 亿个参数。

这种大小差异也会影响推断速度,因此这应有助于你为你的具体用例选择编码器。按照本指南,我们将使用重型武器:ViT-H,带有一个可以从Github(2.4 Gb)下载的模型检查点,并将其放置在你的当前父文件夹中,例如。

在这里,我们可以定义两个变量,以使你的代码在之后稍微更灵活一些:

MODEL = "../../MODELS/sam_vit_h_4b8939.pth"

#You can run the line below to test if you have the ability to leverage CUDA
torch.cuda.is_available()

#Choose between cpu or cuda training. For cpu, input 'cpu' instead 'cuda:0'
USED_D = torch.device('cuda:0')

从这里,我们可以用以下两行代码初始化我们的 SAM 模型:

sam = sam_model_registry"vit_h"

#Cast your model to a specific device (cuda or cpu)
sam.to(device = USED_D)

一切都准备好了!也许最后一步,试试看它在你桌面上的随机图像上的表现如何?

2.2 在 2D 图像上的性能

让我们测试一下在随机图像上的效果是否如预期。我们对地理空间应用感兴趣,所以我去Google Earth并放大一个感兴趣的点:

选择来自 Biscarosse 的图像数据集。© F. Poux

🦊 Florent: 这个点有偏见,对吧?希望这能给你一些法国假期的感觉,你很自豪地经历了这段美妙的岁月,充满了激动人心的项目!

从那里,我会截取一个感兴趣区域的屏幕截图:

一个来自 Biscarosse plage 区域的图像数据集。© F. Poux

然后我用 openCV 将图像加载到内存中:

#When loading an image with openCV, it is in bgr by default
loaded_img = cv2.imread("../DATA/biscarosse.jpg")

#Now we get the R,G,B image
image_rgb = cv2.cvtColor(loaded_img, cv2.COLOR_BGR2RGB)

🦚 注意: 如你所见,默认情况下,OpenCV 通过切换到蓝色、绿色和红色通道(BGR)来加载图像,我们通过第二行将其排序为 RGB,并存储在 image_rgb 变量中。

现在,我们可以用两行代码在图像上应用 SAM:

mask_generator = SamAutomaticMaskGenerator(sam)
result = mask_generator.generate(image_rgb)

大约 6 秒钟后,这将返回一个填充了字典的列表,每个字典代表一个特定对象的自动提取掩膜,并附有其分数和元数据。详细查看时,结果是一个字典列表,每个dict包含以下信息:

  • segmentation : 这会生成形状为(W, H)(和bool类型)的掩膜,其中W(宽度)和H(高度)针对原始图像尺寸;

  • area : 这是以像素为单位的掩膜面积

  • bbox : 这是xywh格式的边界框检测

  • predicted_iou : 模型对掩膜质量的预测 IoU 指标。

  • point_coords : 这是用于生成掩膜的采样输入点的列表

  • stability_score : 稳定性得分是掩膜质量的附加衡量指标。查看论文获取更多细节 😉

  • crop_box : 这是用于生成该掩膜的 crop_box 坐标列表,格式为xywh(可能与边界框不同)

现在你对我们正在处理的内容有了更好的了解,要查看结果,我们可以用以下函数在图像上绘制掩膜:

def sam_masks(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    c_mask=[]
    for ann in sorted_anns:
        m = ann['segmentation']
        img = np.ones((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random((1, 3)).tolist()[0]
        for i in range(3):
            img[:,:,i] = color_mask[i]
        ax.imshow(np.dstack((img, m*0.8)))
        c_mask.append(img)
    return c_mask

🦊 Florent: 我承认,这有点直接。但在这个函数中,我会按掩膜面积对它们进行排序,以随机颜色和透明度参数在图像上绘制它们。

🤠Ville: 内存错误可能会毁掉法国假期的氛围!记得使用 Google Colab 选项!如果重启不能解决问题且分配内存过高,以下代码可以清除 GPU 内存中的额外分配。用它来解决内存问题。

print('Mem allocated by other programs: ', torch.cuda.memory_allocated(), 'reserved:', torch.cuda.memory_reserved())
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
import gc
gc.collect()
torch.cuda.empty_cache()

If the GPU memory is not freed enough, try rebooting your (Windows) computer.
ALSO, try using the following line if memory problems persist
mask_generator = SamAutomaticMaskGenerator(sam, points_per_batch=16)

如果 GPU 内存未充分释放,请尝试重启你的(Windows)计算机。如果内存问题仍然存在,可以尝试使用以下行: mask_generator = SamAutomaticMaskGenerator(sam, points_per_batch=16)

现在,要绘制和导出图像,我们写下以下内容:

fig = plt.figure(figsize=(np.shape(image_rgb)[1]/72, np.shape(image_rgb)[0]/72))
fig.add_axes([0,0,1,1])
plt.imshow(image_rgb)
color_mask = sam_masks(result)
plt.axis('off')
plt.savefig("../test_result.jpg")

这导致:

在 Segment Anything Model 之前和之后。© F. Poux

因此,在这个阶段,我们已经有了令人兴奋的结果,SAM 工作得非常好!例如,你可以看到几乎所有的屋顶都是分段的一部分,三个泳池(两个蓝色和一个绿色)也是分段的一部分。因此,这可能是完全自动检测的起点

🦊 Florent: 根据你的计算机设置,绘制掩膜时可能会遇到内存错误。在这种情况下,加载一个更轻的 SAM 模型应该能解决你的问题。 😉

既然我们已经有了一个有效的 SAM 设置,让我们将这些辛苦获得的知识应用于 3D 点云。

3. 3D 点云到图像投影

为了理解复杂的 3D 世界,我们深入探讨点云投影的艺术。通过像正射和球面投影这样的技术,我们弥合了维度之间的差距,使我们能够在 2D 领域中可视化点云的复杂性,这正是 SAM 所需的输入。点云映射为这一投影过程增添了一层理解。

3.1 正射投影:平展维度,扩展洞察

让我们看看正射投影的变革性技术。此方法作为 3D 点云的多维复杂性与 2D 图像的可理解世界之间的绝佳桥梁。通过正射投影,我们“平展”维度,同时揭示了使用 SAM 进行分割的直接方法。

这个想法基本上是生成一个俯视视图平面,并生成一个不受单一视角限制的图像。你可以将正射投影视为将点云中的可见点(最高点)推送到持有空图像的平面上,以填充所有必要的像素。你可以看到与透视视图的不同之处,如下所示。

解释正射视图和透视视图在 3D 投影中的区别。透视视图与一个单独的视点相关,这会扭曲维度。© F. Poux

为了完成这个过程,我们可以定义一个 3D 到 2D 的投影函数,它将点云的点及其颜色和所需分辨率作为输入,计算正射投影并从点云中返回正射图像。这将转化为以下内容:

def cloud_to_image(pcd_np, resolution):
    minx = np.min(pcd_np[:, 0])
    maxx = np.max(pcd_np[:, 0])
    miny = np.min(pcd_np[:, 1])
    maxy = np.max(pcd_np[:, 1])
    width = int((maxx - minx) / resolution) + 1
    height = int((maxy - miny) / resolution) + 1
    image = np.zeros((height, width, 3), dtype=np.uint8)
    for point in pcd_np:
        x, y, *_ = point
        r, g, b = point[-3:]
        pixel_x = int((x - minx) / resolution)
        pixel_y = int((maxy - y) / resolution)
        image[pixel_y, pixel_x] = [r, g, b]
    return image

很好,现在是时候进行测试了,你同意吗?为此,让我们加载一个点云数据集,将其转换为 numpy 数组,应用这个函数,并导出这个点云的图像:

#Reading the point cloud with laspy
pcd = laspy.read("../DATA/34FN2_18.las")

#Transforming the point cloud to Numpy
pcd_np = np.vstack((pcd.x, pcd.y, pcd.z, (pcd.red/65535*255).astype(int), (pcd.green/65535*255).astype(int), (pcd.blue/65535*255).astype(int))).transpose()

#Ortho-Projection
orthoimage = cloud_to_image(pcd_np, 1.5)

#Plotting and exporting
fig = plt.figure(figsize=(np.shape(orthoimage)[1]/72, np.shape(orthoimage)[0]/72))
fig.add_axes([0,0,1,1])
plt.imshow(orthoimage)
plt.axis('off')
plt.savefig("../DATA/34FN2_18_orthoimage.jpg")

这允许我们获得以下内容:

从 3D 点云到正射图像的工作流程。我们首先按照正射投影模式对点云进行投影,然后确保包括点到像素的映射用于反投影。© F. Poux

让我们继续球面投影

3.2 3D 点云球面投影

我们的旅程在遇到球面投影时变得非常有趣。这种技术提供了独特的视角,使我们能够通过“模拟”虚拟扫描站来可视化数据。为此,我们通过以下四个步骤进行: (1) 考虑 3D 点云, (2) 将这些点投影到球体上, (3) 定义一个几何体来检索像素, (4) “平展”这个几何体以生成图像。

🤠 Ville: 球面投影就像是在 3D 点云内部,拍摄你看到的 360 度照片

3D 点云球面投影工作流。我们获取 3D 点云,创建一个 3D 投影球体,定义映射平面,并生成等距圆柱投影。© F. Poux

为了实现 3D 投影到球面,我们希望获得如下所示的点。

如何将 3D 点云的点投影到球面。© F. Poux

然后,我们将根据几何形状(圆柱体)展开,以获得等距圆柱图像,如下所示。

如何从球面到等距圆柱图像。我们有一个投影机制,允许我们将像素“展开”到圆柱体上。© F. Poux

现在让我详细介绍允许实现这一点的功能:

def generate_spherical_image(center_coordinates, point_cloud, colors, resolution_y=500):
    # Translate the point cloud by the negation of the center coordinates
    translated_points = point_cloud - center_coordinates

    # Convert 3D point cloud to spherical coordinates
    theta = np.arctan2(translated_points[:, 1], translated_points[:, 0])
    phi = np.arccos(translated_points[:, 2] / np.linalg.norm(translated_points, axis=1))

    # Map spherical coordinates to pixel coordinates
    x = (theta + np.pi) / (2 * np.pi) * (2 * resolution_y)
    y = phi / np.pi * resolution_y

     # Create the spherical image with RGB channels
    resolution_x = 2 * resolution_y
    image = np.zeros((resolution_y, resolution_x, 3), dtype=np.uint8)

    # Create the mapping between point cloud and image coordinates
    mapping = np.full((resolution_y, resolution_x), -1, dtype=int)

    # Assign points to the image pixels
    for i in range(len(translated_points)):
        ix = np.clip(int(x[i]), 0, resolution_x - 1)
        iy = np.clip(int(y[i]), 0, resolution_y - 1)
        if mapping[iy, ix] == -1 or np.linalg.norm(translated_points[i]) < np.linalg.norm(translated_points[mapping[iy, ix]]):
            mapping[iy, ix] = i
            image[iy, ix] = colors[i]
    return image

🌱 Growing消化这个功能是至关重要的。它看起来很简单,但在多个阶段有一些巧妙的技巧。例如,你对 3D 点云到球面坐标步骤有什么看法?映射的作用是什么?在将点分配给像素时,使用映射作为条件语句的意义何在?

现在,为了使用这个方便的功能,我们首先加载并准备 ITC 室内点云:

#Loading the las file from the disk
las = laspy.read("../DATA/ITC_BUILDING.las")

#Transforming to a numpy array
coords = np.vstack((las.x, las.y, las.z))
point_cloud = coords.transpose()

#Gathering the colors
r=(las.red/65535*255).astype(int)
g=(las.green/65535*255).astype(int)
b=(las.blue/65535*255).astype(int)
colors = np.vstack((r,g,b)).transpose()

准备好后,我们可以定义投影所需的参数。这些参数包括投影中心(基本上是我们希望虚拟扫描站的位置)和最终图像的分辨率(以像素表示,即图像的最小边)。

resolution = 500

#Defining the position in the point cloud to generate a panorama
center_coordinates = [189, 60, 2]

最后,我们可以调用新的函数,绘制并将结果导出为图像。

#Function Execution
spherical_image, mapping = generate_spherical_image(center_coordinates, point_cloud, colors, resolution)

#Plotting with matplotlib
fig = plt.figure(figsize=(np.shape(spherical_image)[1]/72, np.shape(spherical_image)[0]/72))
fig.add_axes([0,0,1,1])
plt.imshow(spherical_image)
plt.axis('off')

#Saving to the disk
plt.savefig("../DATA/ITC_BUILDING_spherical_projection.jpg")

所有这些过程会产生以下图像:

3D 点云通过投影转化为等距圆柱图像。© F. Poux

你觉得怎么样?你可以调整各种参数,如分辨率或投影中心,以确保在“无数据”像素和相关全景之间取得良好的平衡。

🦊 Florent你刚刚解锁了一项强大的新技能——将 3D 点云转换为等距圆柱图像。确实,它允许你在你认为有意义的地方生成虚拟扫描,并开启使用图像处理和深度学习技术处理图像的可能性。你还可以将提供的功能扩展到其他映射投影,以增加你的工具库。

🤠Ville我几乎可以看到讲座大厅和我的办公室,荷兰的工作氛围!

3.3 3D 点到像素的映射

我们将原始点数据转换为结构化的栅格表示,理清看似散乱的信息。点云映射是我们在 2D 投影中处理 3D 点云的指南针。好消息是:我们已经处理了这个映射。

确实,如果你仔细查看函数generate_spherical_image,你会发现我们返回了mapping变量并将其捕获到另一个变量中以便后续处理。这确保了我们可以拥有一致的 3D 点到像素的映射。

4. 使用 SAM 的无监督分割

无监督分割以 Segment Anything 模型的形式出现。在非标记输出的情况下,我们通过 SAM 的分割架构,这属于聚类应用。这与大多数监督学习方法提供标记输出的方式相对,如下所示。

无监督学习和监督学习之间的区别。在无监督学习中,我们旨在定义一些相似的数据“点”组,而在监督学习中,我们旨在满足监督需求(通常通过提供标记数据)。© F. Poux

因此,像素预测的转移,加上无缝的点云导出,展示了革新物体检测和场景理解等应用的潜力。

4.1. SAM 分割

要执行程序,我们可以重新执行我们用于测试 SAM 功能在 2D 图像上的代码片段,这些代码片段是:

sam = sam_model_registry"vit_h"
sam.to(device = USED_D)

mask_generator = SamAutomaticMaskGenerator(sam)

temp_img = cv2.imread("../DATA/ITC_BUILDING_spherical_projection.jpg")
image_rgb = cv2.cvtColor(temp_img, cv2.COLOR_BGR2RGB)

t0 = time.time()
result = mask_generator.generate(image_rgb)
t1 = time.time()

然后,我们可以在图像上绘制结果

fig = plt.figure(figsize=(np.shape(image_rgb)[1]/72, np.shape(image_rgb)[0]/72))
fig.add_axes([0,0,1,1])

plt.imshow(image_rgb)
color_mask = sam_masks(result)
plt.axis('off')
plt.savefig("../DATA/ITC_BUILDING_spherical_projection_segmented.jpg")

这将得到:

3D 点云投影上的 Segment Anything 模型的结果。© F. Poux

这看起来我们正在勾画出图像中的重要部分。让我们继续前进。

4.2. 点预测转移

让我们用这张图像为点云上色。因此我们定义一个上色函数:

def color_point_cloud(image_path, point_cloud, mapping):
    image = cv2.imread(image_path)
    h, w = image.shape[:2]
    modified_point_cloud = np.zeros((point_cloud.shape[0], point_cloud.shape[1]+3), dtype=np.float32)
    modified_point_cloud[:, :3] = point_cloud
    for iy in range(h):
        for ix in range(w):
            point_index = mapping[iy, ix]
            if point_index != -1:
                color = image[iy, ix]
                modified_point_cloud[point_index, 3:] = color
    return modified_point_cloud

这意味着要为我们的点云上色,我们可以使用以下代码行调用我们的新函数:

modified_point_cloud = color_point_cloud(image_path, point_cloud, mapping)

这一行返回一个 numpy 数组,该数组保存了点云。

现在是 3D 点云导出的时候了!

4.3. 点云导出

要导出点云,你可以使用 numpy 或 laspy 直接提取一个.las 文件。我们将采用第二种解决方案:

def export_point_cloud(cloud_path, modified_point_cloud):
    # 1\. Create a new header
    header = laspy.LasHeader(point_format=3, version="1.2")
    header.add_extra_dim(laspy.ExtraBytesParams(name="random", type=np.int32))

    # 2\. Create a Las
    las_o = laspy.LasData(header)
    las_o.x = modified_point_cloud[:,0]
    las_o.y = modified_point_cloud[:,1]
    las_o.z = modified_point_cloud[:,2]
    las_o.red = modified_point_cloud[:,3]
    las_o.green = modified_point_cloud[:,4]
    las_o.blue = modified_point_cloud[:,5]
    las_o.write(cloud_path)

    print("Export succesful at: ", cloud_path)
    return

这样,我们可以导出我们的 modified_point_cloud 变量:

export_point_cloud("../DATA/pcd_results.las", modified_point_cloud)

在这一阶段,我们成功地获取了各种来自 3D 点云投影过程的 2D 图像。我们对其应用了 SAM 算法,基于其预测对其上色,并导出了一个彩色点云。因此我们可以开始获取一些关于我们所得到的东西的见解。

🦊 Florent为了快速在 Python 之外分析结果,我建议使用 CloudCompare 开源软件。如果你想要一个清晰的使用指南,可以阅读下面的文章。

## 3D 深度学习 Python 教程:PointNet 数据准备

《终极 Python 指南》用于构建大型 LiDAR 点云以训练 3D 深度学习语义分割…

[towardsdatascience.com

5. 定性分析与讨论

随着我们的旅程接近巅峰,现在是关注定性分析的时候了。特别地,我们不会进行定量分析,因为在这个阶段我们需要适当的标签。

🤠 Ville: 没有标签?你刚刚做的是零样本学习(砰!)或少样本学习(砰!砰!)。我们不能确定是哪种,因为我们不知道 SAM 的训练方式。因此对我们来说它有点像黑箱,但没关系。

我们仔细检查光栅和点云结果,得出的见解揭示了 SAM 的性能。同时,让我们保持脚踏实地,承认模型的局限性,同时展望未来。

5.1 光栅结果

SAM 在我们实施下的成果通过下述光栅结果得到了生动的展示。这些视觉效果作为画布,快速评估 SAM 的分割,帮助我们理解模型对场景的 2D 表示。

Segment Anything Model 在 3D 点云投影上的另一个结果。© F. Poux

如你所见,即使在点分布不均和“黑区”下,SAM 仍能识别出点云的主要部分。具体而言,它可能突出了左侧的绿色区域,即危险材料所在的位置,以及为我们提取团队提供最直接路线的门窗!

5.2 点云结果

然而,正是在点云结果中,SAM 的真正能力得以显现。当我们在点云中穿梭时,SAM 的分割预测为经典的“点云混乱”带来了清晰度,展示了其在实际应用中的潜力。

使用 Segment Anything 3D 进行的 3D 点云无监督分割结果。我们看到构成场景的主要元素的明显区别。© F. Poux

如我们所见,我们可以直接链接到基础点,这真是极其棒!想想这能为你的应用程序解锁什么。一个拥有不到五个主要断点的 100%自动化分割过程?不错!

5.3. 局限性

但,我们的探险只有在承认过程中存在的粗糙点后才算完成。SAM,尽管令人印象深刻,也不例外地存在局限性。通过认识这些不足,我们为改进和成长铺平道路。

首先是所有“未见”的点仍然保持未标记(下图中的白点)。这可能会成为完整处理的一个限制,如果你使用基本模型或大模型,你会看到比使用巨大模型时更多的未标记点。

从中央视角 360°模拟扫描位置的第一次扫描中,未标记点与标记点的比例。 © F. Poux

此外,在这个阶段,我们使用了自动提示引擎,它触发了大约 50 个兴趣点,即分割任务的种子。虽然这对于获得直接结果非常好,但如果能够进行调整会更棒。

最后,此阶段的映射相对简单;它将大大受益于遮挡剔除和特定像素的点选择。

5.4. 视角

Segment Anything 模型只是 3D 点云分割更大领域中的一步。然而,如今的实现应该能够很好地适用于任何具有某种独特初始特征的 SAM 应用程序。正如下图所示,它也适用于俯视的航空点云。

Segment Anything 3D 在航空点云中的结果。© F. Poux

扩展到室内场景,你会发现也能得到一些相当不错和有趣的结果。这甚至对自动更换大厅灯具的灯泡是有用的,当然是由机器人自动完成的(更换一个灯泡需要多少机器人?)!

Segment Anything 3D 在另一个室内场景中的结果。© F. Poux

因此,除了泛化外,第一个视角是解锁生成全景图和融合不同视角预测的方法。当然,另一个视角是扩展到自定义提示,最终解决在 2D-3D 映射中提高点到像素精度的挑战。

结论

如果你是 13.37%中实际使代码正常工作的 3D 创作者中的一员,那么对你表示由衷的赞赏!

我们在这篇文章中覆盖的工作流程。 © F. Poux

这是一个巨大的成就,你现在拥有了一个非常强大的工具来处理 3D 场景理解的语义提取任务。通过 Segment Anything 模型,你可以在许多产品中封装创新,改变我们感知和解读 3D 点云的方式。

我们的探索应该为这一开创性模型从起步到其影响描绘了一个全面、实用的图景。你现在可以探索这些变体,并根据之前部分发现的限制扩展其相关性。

🦊 Florent我期待你未来的项目能加以利用!

🤠 Ville继续编码!

参考文献

  1. Kirillov, A.,Mintun, E.,Ravi, N.,Mao, H.,Rolland, C.,Gustafson, L.,Xiao, T.,Whitehead, S.,Berg, A.C.,Lo, W.Y. 和 Dollár, P.,2023. 分割任何东西。arXiv 预印本 arXiv:2304.02643

  2. Poux, Florent,Mattes, C.,Selman, Z. 和 Kobbelt, L.,2022. 用于大规模点云分割的自动区域生长系统。建筑自动化138,第 104250 页。 Elsevier 链接

  3. Lehtola, Ville,Kaartinen, H.,Nüchter, A.,Kaijaluoto, R.,Kukko, A.,Litkey, P.,Honkavaara, E.,Rosnell, T.,Vaaja, M.T.,Virtanen, J.P. 和 Kurkela, M.,2017. 对选定的先进 3D 室内扫描和点云生成方法的比较。遥感9(8),第 796 页。 MDPI 链接

🔷其他资源

🎓作者推荐

要构建完整的室内语义提取场景,你可以将这种方法与“3D 点云形状检测用于室内建模”文章中解释的方法结合起来:

## 3D 点云形状检测用于室内建模

一份 10 步 Python 指南,用于自动化 3D 形状检测、分割、聚类和体素化…

towardsdatascience.com [## 3D 创新者通讯

每周提供实用内容、见解、代码和资源,以掌握 3D 数据科学。我写关于点云、人工智能…

learngeodata.eu](https://learngeodata.eu/3d-newsletter/?source=post_page-----80c06be99a18--------------------------------)

Segment Anything: 可提示的任意对象分割

原文:towardsdatascience.com/segment-anything-promptable-segmentation-of-arbitrary-objects-f28958c5612d

🚀Sascha 的论文俱乐部

Segment Anything 由 A. Krillov 等人

Sascha KirchTowards Data Science Sascha Kirch

·发表于 Towards Data Science ·12 分钟阅读·2023 年 9 月 14 日

--

今天的论文讲解将是视觉化的!我们将分析 Segment Anything,这是 Meta AI 研究团队的一篇论文,它不仅在研究界引起了关注,也在各种深度学习从业者和支持者中引起了广泛关注。

Segment Anything 引入了可提示的分割任务,介绍了 segment anything 模型(SAM),并详细描述了生成一个包含超过 10 亿个掩膜的 1100 万张图片的新公开数据集。SAM 已被广泛采纳,并产生了一些新的最先进基础模型,如 Grounded-SAM,它将 Grounding DINO 与 SAM 结合起来。

图片来源于 出版物Sascha Kirch

论文: Segment AnythingAlexander Kirillov 等人,2023 年 4 月 5 日

资源: GitHub演示项目页面数据集HuggingFace

类别: 分割、零-shot 预测、计算机视觉、提示、大规模

其他教程:

[BYOL] — [CLIP] — [GLIP] — [Depth Anything] — [DINO] — [DDPM]

大纲

  1. 背景与背景

  2. SAM — Segment Anything Model

  3. SA-1B — 具有 10 亿个掩码的数据集

  4. 实验与消融

  5. 结论

  6. 进一步阅读与资源

背景与背景

《Segment Anything》的作者明确声明:“我们目标是建立一个图像分割的基础模型。” 基础模型源于自然语言处理(NLP)的巨大成功。这些模型在自监督的方式下经过了大规模的训练。通常,这些模型在零-shot 任务中表现非常好,即它们可以解决与训练时不同的任务,并表现得相当不错,甚至比其监督型对手更优秀。近年来,许多研究人员致力于将 NLP 基础模型的成功带到计算机视觉等其他领域。

模型如 CLIP 和 GLIP 使得可以根据文本提示对图像分类或对象检测任务进行条件限制,而不是固定的类别集合。其他模型,如 BYOL 或 DINO,提出了不同的技术来学习输入图像的语义丰富表示,这也是许多计算机视觉应用的关键要求。

《Segment Anything》论文旨在:

  1. 通过提示启用零-shot 分割

  2. 训练一个大规模模型(SAM)作为演示模型

  3. 收集并发布最大的公开可用分割数据集。

但为什么零-shot 性能如此重要? — 答案有两个方面。首先,最初计算机视觉模型是以监督方式训练的,这不仅需要数据,还需要大量的真实标签。收集这些数据是极其耗时和昂贵的。其次,模型可以预测的类别仅限于训练时使用的固定类别集合。如果你想向模型中添加一个新类别,你需要首先收集数据并重新训练模型。

如何对分割模型进行提示? — 你可能对来自 ChatGPT、CLIP 或 GLIP 等模型的文本提示比较熟悉。虽然 SAM 原则上也经过了文本提示的测试,但它主要通过掩码、点、框或点网格来进行提示,如下图所示。

图 1:不同输入提示和生成的掩码。照片由Terence Burke拍摄,发布于Unsplash + 掩码由Sascha KirchSAM生成

了解了 SAM 的背景后,让我们转到重点,详细了解 Segment Anything Model,即 SAM。

Sascha Kirch

Sascha Kirch

Sascha Kirch 的论文讲解

查看列表7 个故事“DDPM — 去噪扩散概率模型” 论文插图,Sascha Kirch“Depth Anything” 论文插图,Sascha Kirch

SAM — Segment Anything Model

Segment Anything Model(SAM)是一个多模态模型,它输入一张图像和一个或多个提示,并输出一个有效的分割掩码。该模型由三个主要模块组成:图像编码器、提示编码器和掩码解码器。

SAM 可以通过掩码、一组点、边界框或文本,或这些的任何组合来进行提示。

注意:尽管论文提到并实验了文本作为提示,但截至 2023 年 9 月,文本提示尚未在官方实现SAM 演示中发布。

图 2:SAM 架构。图片来源 + 注释由Sascha Kirch

图像编码器 — 为给定的输入图像输出图像嵌入。SAM 实现并适配了一个预训练的 ViT-H/16 掩码自编码器。这是一个相对较大的模型,性能强劲。

提示编码器 — 稀疏提示(例如点、框和文本)被转换为嵌入向量。文本提示在输入提示编码器之前,使用 CLIP 转换为文本嵌入。密集提示(例如掩码)则简单地通过步幅卷积下采样,并与图像嵌入相加。所有嵌入随后被送入最终阶段:掩码解码器。

掩码解码器 — 接受一组图像嵌入(可选地包含密集掩码嵌入)和一组提示嵌入,并输出有效的分割掩码。

还有两个细节我们应该讨论:提示的歧义性和性能。

简而言之,提示包含的上下文越少,就越模糊,对模型提供正确输出的难度也越大。对于文本提示,我们已经在 CLIP 和 GLIP 中看到了输入文本的具体性与模型性能之间的这种联系。同样,提供一个单点作为输入可能会产生多种可能的掩码。因此,SAM 输出一组三种掩码,分别对应于有效掩码的对象级别、部件级别和子部件级别,如下图所示。

图 3:单点提示的歧义性。图片来源 + 由 Sascha Kirch 注释

我想提到的第二个细节是推理速度方面的性能。你是否注意到图像编码器是 SAM 中最大的一部分?好吧,这个问题有点不公平,因为我之前没有告诉你,但 SAM 的设计目的是拥有语义丰富的图像嵌入(通常需要一个大型模型),然后通过一个轻量级的提示编码器和轻量级的掩码解码器来处理这些嵌入。好的一点是:每张图像只需运行一次图像编码器,然后可以使用相同的图像嵌入多次提示模型。这使得 SAM 可以在浏览器中运行,仅需 ~50ms 来预测给定提示的掩码(在图像嵌入计算后)。

让我们更详细地看看轻量级掩码解码器。它输入图像嵌入和提示嵌入,并输出一组带有相应分数的掩码。在内部,两个连续的解码器块通过自注意力和交叉注意力的组合生成图像与提示之间的强依赖关系。一个简单的上采样网络结合另一个交叉注意力块生成掩码和分数。

图 4:掩码解码器的详细架构。 图片来源 + Sascha Kirch 的注释

SA-1B — 具有 10 亿掩码的数据集

Segment Anything 的第二个重大成果是创建和发布了一个大规模的分割数据集。它包含 1100 万张高分辨率和许可的图像,大约有 11 亿个掩码。虽然数据集的原始版本平均有 3300x4950 像素,但发布版本经过下采样,使最短边为 1500 像素。它在不同场景和每张图像掩码数量上都具有多样性,范围从不到 50 个到超过 500 个。

图 5:来自 SA-1B 的不同掩码。 图片来源 + Sascha Kirch 的注释

该数据集是在一个三阶段数据引擎中创建的,该引擎结合了人工标注和 SAM 生成的自动标注。

阶段 1:辅助手动阶段 — 一组专业标注员在 SAM 的早期版本的帮助下对图像进行了标注,SAM 在常见的分割数据集上进行了训练。他们被要求标注最显著的对象,并被鼓励在 30 秒后继续。在此阶段结束时,SAM 通过新的标签进行重新训练(总计 12 万张图像和 430 万个掩码)。

阶段 2:半自动阶段 — 在这一阶段的目标是通过首先让 SAM 预测一些掩码,然后让标注员标注缺少的、不太显著的对象,以增加掩码的多样性。在此阶段结束时,SAM 再次进行重新训练,包括新的样本(总计 30 万张图像和 1020 万个掩码)。

阶段 3:完全自动阶段 — 在这一阶段,注释完全自动化。SAM 通过 32x32 的网格点生成掩码,并应用一些后处理。

数据集分析

现在让我们仔细看一下论文中关于 SA-1B 数据集的一些分析。

在第一次评估中,作者创建了掩码中心点的标准化分布。有趣的是,这些分布会受到摄影师的偏差,即大多数照片将感兴趣的对象置于图像的中心和主轴上。

图 6:图像中对象中心点位置的分布。 图片来源 + Sascha Kirch 的注释

SA-1B 的一个主要优点是每张图像的掩码数量相比其他数据集更高(见图 7 左)。这也意味着 SA-1B 有许多小掩码(见图 7 中)。比较掩码的凹凸度,作为复杂性的衡量标准,SA-1B 与其他手动标注的数据集非常相似(见图 7 右)。

图 7:SA-1B 的掩码属性与其他数据集的比较。 图片来源 + Sascha Kirch 注释

高度关注负责任的人工智能(RAI),在这里,不仅分析对某些人群的偏见,还尝试减轻这些偏见。如图 8 所示,世界上大多数国家的图像数量超过 1000 张,前 3 名国家来自不同地区。虽然低收入国家的样本相对较少(占所有样本的 0.9%),但绝对数量仍超过 900 万张,比其他分割数据集更多。

图 8:SA-1B 图像的估计地理分布。 图片来源 + Sascha Kirch 注释

作者进一步研究了感知性别展示、感知年龄组和感知肤色之间的性能差异。他们提供了预测掩码与真实掩码之间的平均 IoU(交并比)以及 95%的置信区间。SAM 的提示可以是单个点或三个点。主要信息是,在一个组内,结果非常相似(且置信区间重叠),这表明该组的任何成员都没有被偏袒。唯一的例外是感知年龄组中的老年人。

图 9:SAM 在感知性别展示、年龄组和肤色方面的分割性能。 图片来源 + Sascha Kirch 注释

## 每当 Sascha Kirch 发布新内容时获取电子邮件 🚀

每当 Sascha Kirch 发布新内容时获取电子邮件 🚀 想了解更多关于深度学习的知识或只是保持更新…

medium.com

实验和消融研究

Segment Anything 确实为我们提供了一系列实验,主要集中在其零-shot 性能上,因为这是作者的主要目标:找到一个可提示的零-shot 分割模型。同时,我们也知道其他模型如 CLIP 和 GLIP 的表现,提示调整几乎与模型微调在性能上等效。

为了进行实验,编制了一套包含 23 个多样化数据集的集合。它包含了来自各种数据分布的样本,如图 10 所示。

图 10:来自 23 个数据集的样本。图片来源 + 注释由 Sascha Kirch

零-Shot 单点有效掩码评估

记住,零-Shot 意味着模型从未在评估过程中接触过的数据上进行训练。同样,单点提示由于其模糊性,如图 3 所示,是一项相当困难的任务。

在这个第一次实验中,作者将 SAM 与 RITM进行了比较,RITM 是一种强大的交互式分割器,作者表示其在他们的基准测试中表现最佳。

记住,当用单个点进行提示时,SAM 会输出 3 个不同的掩码及其相关分数。在这个实验中,选择分数最高的掩码进行评估。由于这种情况有时会出现错误,作者还对最佳掩码进行了评估,通过将预测结果与真实掩码进行比较,选择重叠度最高的掩码。这些是“oracle”预测。

在 23 个数据集中,SAM 在 16 个数据集中中的零-Shot 单点有效掩码预测中表现优于 RITM。在进行 oracle 预测时,它在所有 23 个数据集中都优于 RITM。

图 11:在 23 个数据集上的 SAM 与 RITM 对比。图片来源 + 注释由 Sascha Kirch

零-Shot 文本到掩码

在这个实验中,SAM 通过文本进行提示。作者将此功能称为概念验证,因此既没有进行广泛的实验,也没有在其官方代码实现中发布此功能。

看图 12,你可以看到 SAM 能够为像“海狸牙齿格栅”这样的复杂对象返回正确的掩码。在其他一些情况下,模型仅通过文本提示失败,他们展示了在提供点的上下文时,SAM 能够正确预测单个或多个擦拭器,显示出不仅点被用于预测,文本也被考虑在内。

图 12:零-shot 文本到掩码。图片来源 + 注释由 Sascha Kirch

零-Shot 边缘检测

有趣的是,SAM 也可以用于边缘检测,这是一项它在训练过程中未被考虑的任务,也没有访问相关数据。

为了预测图像,SAM 首先使用 16x16 点的网格进行提示,生成 768 个预测的掩码(每个 256 个点的对象、部分和子部分)。然后对生成的掩码进行筛选和后处理,以获取边缘掩码。

如图 13 所示,与真实数据相比,SAM 预测了更多的细节。但为了公平起见,如果真实数据不完整或覆盖了不同的抽象层次,这种比较对我来说似乎不太公平。但总的来说,性能还是相当不错的!

图 13:SAM 的零样本边缘预测。 图片来源 + 注释由 Sascha Kirch

零样本实例分割

对于这个实验,SAM 以 COCOLVIS 上训练的完全监督的 ViTDet-H 的边界框输出作为提示。然后将生成的掩码连同初始边界框一起输入到 SAM 中,以精细化结果。图 14 显示了 ViTDet 和 SAM 的比较。

图 14:在 LVIS v1 上的零样本实例分割。 图片来源 + 注释由 Sascha Kirch

这里有两件事需要注意:如果你查看 COCOLVIS,你会发现掩码与对象的像素对齐并不完全。这种偏差在 ViTDet 中也存在,这就是为什么 SAM 的质量似乎更好的原因。由于基准真实值具有相同的偏差,而与差的 GT 相比,SAM 的表现可能更差。因此,他们要求人工进行视觉检查。其次,为什么这只大象只有 3 条腿 😅。无论我怎么努力,我都看不到第四条腿…

消融实验

在消融实验部分,作者主要关注于扩展数据集、提示点数量和图像编码器的大小(见图 13)。性能以平均 IoU 报告。

图 15:消融研究。 图片来源 + 注释由 Sascha Kirch

有趣的是,尽管数据扩展和模型规模扩展影响了 mIoU 性能,但它达到饱和状态。这可能表明模型已经足够好,没有太多改进的空间,或者可能是他们方法的局限性。

结论

Segment Anything 引入了可提示的 Segment Anything Model (SAM) 以及一个包含超过 10 亿个掩码的分割大规模数据集,涵盖超过 1100 万张图像。能够提示分割模型带来了很多灵活性,比如将训练好的模型适应于未见过的任务或检测未知类别。虽然有些人讨论 SAM 是否可以被视为基础模型,因为它是以监督方式训练的,但它仍然显示出了显著的成果,并已被广泛采用。

进一步阅读与资源

正如你自己可能知道的那样:深度学习领域正在以令人难以置信的速度发展。因此,SAM 发布后,许多新项目在其成功的基础上进一步改进了预测质量、减少了推理时间,或者使模型适用于边缘应用,这也就不足为奇了。

以下是一些有趣的资源,它们在 SAM 的基础上进行扩展:

  1. 基础分割任何内容

  2. 高质量分割任何内容

  3. 快速分割任何内容

  4. 更快的分割任何内容:朝着适用于移动应用的轻量级 SAM

在这里,我分享一些链接,如果你想亲自体验 SAM 和 SA-1B:

这里是我一些文章的链接,带你了解一些相关的基础模型:

## CLIP 基础模型

论文总结— 从自然语言监督中学习可迁移的视觉模型

towardsdatascience.com ## GLIP: 将语言-图像预训练引入物体检测

论文总结:基础语言-图像预训练

towardsdatascience.com ## BYOL - 对比自监督学习的替代方案

论文分析—Bootstrap Your Own Latent: 自监督学习的新方法

towardsdatascience.com

将文本分段成段落

原文:towardsdatascience.com/segmenting-text-into-paragraphs-e8bed99b6ebd

基于监督学习的统计 NLP 方法

Arun JagotaTowards Data Science Arun Jagota

·发表于 Towards Data Science ·11 min 阅读·2023 年 2 月 25 日

--

图片由 Gordon Johnson 提供,来自 Pixabay

在之前的 Medium 文章中,我们讨论了将文本分割成句子的问题[3]。现在我们来看一个相关问题:将文本分割成段落。

初看起来,这两个问题似乎本质上是相同的,只是在不同的分块层次上。实际上,将文本分段成段落的问题要有趣得多。

一方面,句子边界有明确的信号,如句号、问号或感叹号。通常,问题在于这些标记中的哪些是实际的边界,哪些是在句子内嵌入的边界。这就是假阳性的难题。

将文本分段成段落更为复杂。这样考虑一下。假设你有一个长句子序列,没有段落分隔符。应该在哪里设置段落边界?这不是一个容易解决的问题,也不一定有唯一的解决方案。这意味着可能存在多种将句子序列分段的合理分法。

将文本分段成段落可以看作是文本分割的一个特例[1]。文本片段是一个连续的段落,保持一定的连贯性,比如说在同一主题上。根据这种连贯性度量,段落会在主题发生变化时过渡到另一个段落。

更广泛的文本分段问题更难解决。原因有几个。其中之一是很难获取标记数据。对于段落分段,有大量的标记数据可用,这些数据以网页和包含段落分隔符的维基百科文章的形式存在。而对于更广泛的文本分段问题情况则不然。

在这篇文章中,我们认为一个能够建议合理分割边界的算法对写作中的人是有帮助的。就像 Grammarly 对写作的帮助一样。换句话说,精确率和召回率都不需要特别高。精确率需要合理;召回率甚至可以更低。

本文的其余部分应以此视角进行阅读。我们将满足于一个具有合理精确度的解决方案,可能接近 50%,以及较低的召回率,可能接近 10%。关键是即使这样也在类似 Grammarly 的环境中是有用的。

即使段落分割建议很少出现,只要它们具有合理的精确度,它们就会增加像 Grammarly 这样的产品的价值。

不用说,如果我们可以在付出最少努力的情况下获得更好的精确率或召回率,我们自然会选择这样做。

预测段落分割的概率模型

我们将从正式描述开始,用简单的英语解释其各个组件。

设 X1 和 X2 表示训练语料库中相邻的两个句子。

我们将与(X1, X2)关联一个二进制标签 Y。如果 X1 和 X2 之间有段落分割,则 Y 为 1,否则为 0。

我们将跟踪第三个预测变量i。X1 将是当前段落中的第i个句子。预测变量“i”将赋予我们的模型关注段落长度的能力。

我们的训练集将包含实例(X1, X2, i, Y)。

从训练集中,我们旨在学习一个模型P(Y | X1, X2, i)。

P(Y=1 | X1, X2, i)将表示在 X1 作为当前段落中的第i个句子的情况下,X1 和 X2 之间有段落分割的概率。

P(Y=0|X1, X2, i)表示在 X1 作为当前段落中的第i个句子的情况下,X2 应该扩展当前段落的概率。

模型 P(Y | X1, X2, i)非常复杂。这是因为 X1 和 X2 是句子,可能非常稀有或很长。这意味着,即使我们的训练集包含了几亿个标记实例,也可能没有足够的数据来估计这个模型。

我们需要做一些假设。

首先,让我们应用贝叶斯规则。

P(Y | X1, X2, i) = N(X1, X2, i, Y)/Z

其中 N(X1, X2, i, Y)等于P(X1, X2, i | Y) P(Y)。

Z 只是 N(X1, X2, i, 0) + N(X1, X2, i, 1)。

接下来,我们将对 N(X1, X2, i, Y)进行如下分解。

N(X1, X2, i, Y) = P(X1 | Y)P*(X2 | Y)P(i* | Y)**P*(Y)

P(X1 | Y=1)是段落中最后句子的分布。P(X1 | Y = 0)是段落中非最后句子的分布。

P(X2 | Y = 1)是段落中第一句子的分布。P(X2 | Y = 0)是段落中非第一句子的分布。

现在考虑P(i | Y)。

让我们提醒自己 X1 是当前段落中的 i 句。因此,P(i | Y=1) 实际上是段落长度的分布,因为段落必须在 X1,即当前段落中的第 i 句之后结束。

P(i | Y = 1) 将倾向于偏向于较小的 i。这是因为大多数段落较短。

P(i | Y = 0) 将倾向于偏向于更小。这是因为 Y = 0 表示当前段落中的第 i 句 X1 不会结束该段落。

概率模型 P(X1 | Y) 和 P(X2|Y) 都仍然过于复杂。这是因为句子的宇宙是无限的。也就是说,句子可以任意长。也可以任意稀有。

我们可以做进一步的简化假设吗?具体来说,使用的可能性不一定是整个句子,而是它们的前几个词。

让我们从实际例子开始。

首先,让我们看看继续短段落的句子的例子。

假设下一句以“例如”,“例子”,“更准确地说”等开头。如果当前段落足够短,比如一两句,这些前缀在下一句中的出现预测 Y = 0,即扩展段落。

为了支持这个假设,我们邀请读者阅读这些相邻句子的对,例如 en.wikipedia.org/wiki/Deep_learning

深度学习算法可以应用于无监督学习任务。这是一个重要的好处,因为未标记的数据比标记的数据更为丰富。例子 包括可以以无监督方式训练的深度结构,如深度置信网络。

深度学习是一类机器学习算法,它[8]: 199–200 使用多个层次逐步从原始输入中提取更高级的特征。例如,在图像处理过程中,较低的层次可能识别边缘,而较高的层次可能识别对人类相关的概念,如数字、字母或面孔。

“深度学习”中的“深度”指的是数据转换通过的层数。更准确地说,深度学习系统具有实质性的信用分配路径(CAP)深度。

你是否同意每个加粗的词序列预测段落的延续?

这些例子表明,考虑简化 P(X1 | Y) 为 P(X1 的前几个词 | Y) 是有意义的。

这就引出了“few”在这里的值是多少的问题?我们稍后会解决这个问题。

接下来,让我们看看一句话段落的例子。

为此,我要求 ChatGPT 给我一些一句话段落的例子。似乎它很字面地理解了这个问题。因此我重新表述了问题为

给我一些由一句话构成的段落的例子。

现在我得到了好的例子。

沉默。

停止。

再也不见。

为什么?

是的!

对不起。

足够了。

记住。

帮助!

再见。

我们期望这些句子中的每一个都有较高的 P(X1 | Y = 1) 的可能性。也就是说,对于其中一些或全部句子,P(X1 | Y = 0) 也可能相对较高。这意味着段落不会在它们之后立即结束。

尽管如此,我们在这里展示这些示例,因为它们确实表明 P(X1 | Y = 1) 对于这些句子是值得建模的。

接下来,让我们看看以新段落开始的句子的示例。我们从 en.wikipedia.org/wiki/Deep_learning 中挑选了一些段落,并展示了它们第一句话的前几个词。

深度学习是一个更广泛领域的一部分…

深度学习架构如…

人工神经网络(ANNs)是…

在深度学习中,每一层学会…

ANN 基于一组…

深度神经网络(DNNs)可以建模复杂的非线性…

这些示例表明,P(X2 | Y = 1) 可以简化为

P(X2 | Y = 1)

对于“少量”这一选择的合适情况。

让我们将从这些示例中学到的内容形式化。

我们可以简化 P(X1 | Y) 和 P(X2 | Y) 为

P(X1 以 w(1)、w(2)、…、w(k) | Y)

以及

P(X2 以 w’(1)、w’(2)、…、w’(k’) |Y)

分别。

这里 w(1)、w(2)、…、w(k) 和 w’(1)、w’(2)、…、w’(k’) 分别是 kk’ 的词序列。

显而易见的问题是 kk’ 应该是什么?

解决这个问题的一种方法是不要提前固定 k 和 k’,而是推迟决策直到推断时。

方法如下。

首先介绍一些术语。我们将称以句子开始的词序列为句子的前缀。

现在考虑 P(X1 | Y=y)。我们将按如下方式近似。

我们将首先找到 X 的最大前缀,称之为 P,并且具有足够的支持。我们将使用 P(X1 以 P 开始 | Y) 作为 P(X1 | Y) 的代理。

X1 前缀 P 对 P(X1 | Y) 的支持定义为训练集中 P 作为 X1 前缀的实例数量。

这一近似过程的理念是,我们应该使用 X1 的最长前缀,只要它在训练集中出现的次数足够多(作为 X1 的前缀)。

类似地,我们应该将 P(X2 | Y) 估计为 P(Q | Y),其中 Q 是 X2 的最大前缀,其支持度足够大。

我们推断出的形式对模型意味着什么?我们需要跟踪所有 X1 的前缀 P 的概率 P(X1 以 P 开始 | Y)。P(X2 | Y) 也类似。

内部,对于建模 P(X1 |Y) 和 P(X2 | Y),我们需要跟踪大量的词序列。

幸运的是,这些词序列可以收集到所谓的 Trie 数据结构中。这些结构被优化为紧凑地表示大量的词序列。

这些 Tries 在训练过程中如下构建。

我们将分别使用四个 Tries T10、T11、T20 和 T21。Tiyi = 1 或 2,将存储前缀序列及其对 Y=y 的计数,针对 Xi

Trie 中的每个节点将存储一个计数。

我们将所有的 Trie 初始化为从一个单独的节点开始,即根节点,其计数设置为零。

现在考虑一个训练集中的实例 (X1, X2, y)。我们忽略了 i,因为它不会影响 Trie。

将 X1 解释为单词序列时,我们将在 T1y 中查找 X1,必要时扩展 Trie 并添加由新节点组成的路径。每次创建新节点时,其计数将初始化为 0。

现在,在 Trie T1y 中,我们将沿着表示 X1 的路径递增所有节点的计数,每个节点增加一次。

处理 X2 时,我们将在 Trie T2y 上重复相同的过程。

数值示例

现在让我们说明这个过程。

四个 Trie 将被初始化为 T10、T11、T20 和 T21,每个 Trie 的初始状态为 {[]:0}。

现在假设我们呈现第一个训练实例为 ([a,b],[A,B,C],1)

T11 的新状态将是 {[]:1,[a]:1,[a,b]:1}

T21 的新状态将是

现在假设我们呈现这个训练实例:([a,d],[A,B,E],1)

T11 的新状态将是 {[]:2,[a]:2,[a,b]:1,[a,d]:1}

T21 的新状态将是

在上述示例中,为了视觉上的方便,我们将每个 Trie 表示为一个哈希映射,即 Python 中的字典。

实际上,我们可以通过利用(重复的)前缀结构将 Trie 更紧凑地表示为树。

将 Trie 表示为树对于查找与给定序列 X 的所有前缀相关的计数也更高效。我们只需沿着 Trie 包含的唯一路径向下查找 X 的最长前缀。我们说“最长前缀”是因为 X 可能未完全包含在 Trie 中,如果 X 在训练集中从未以这种上下文出现,它会被放置在这个 Trie 中。然而,Trie 中总是存在至少一个 X 的前缀的路径,即使只是空前缀。

使用 Trie 推理

假设训练已经完成。现在,对于给定的 (X1, X2, i),我们想计算 P(Y=y|X1,X2,i)。

这个计算中涉及 Trie 的部分是 P(X1|Y=y) 和 P(X2|Y=y)。

让我们演示如何计算其中之一,其他的计算过程也会类似。

让我们选择 P(X1|Y=y)。

我们遍历 Trie T10 和 T11 以找到 X1 的最长支持前缀。我们需要使用这两个 Trie,因为 X1 的前缀 P 的支持是 T10 和 T11 中 P 结束的节点上的计数之和。

让我们用 P(X1) 表示 X1 的最长支持前缀。

P(P(X1) | Y=y) 只是 P 在 Trie T1y 中结束的节点上的计数除以 Trie T1y 根节点上的计数。这仅仅是训练集中标签为 y 且 X1 以 P(X1) 开始的实例数量,除以标签为 y 的训练集中实例数量。

总结

在这篇文章中,我们介绍了将文本分割成段落的 NLP 问题。我们注意到,这个问题比将文本分割成句子的难度更大,但比将文本分割成以主题为单位的连贯单元的难度要小。

我们将这个问题框架设定为监督学习问题。有大量标记数据可以直接使用。输入是一对相邻句子。结果是这两句之间是否存在段落分隔。因此,这是一个监督学习问题,其中输入是一对序列,结果是二元的。

接下来,我们在简单贝叶斯假设下应用了贝叶斯规则,即在给定结果的情况下,预测变量的条件独立性。然后我们计算了结果公式中的似然性和先验项。

从这里我们注意到,即使在简单假设下,得到的模型也过于复杂。我们讨论了如何通过将输入的两个句子建模为从空前缀到整个序列的一组前缀来应对这种复杂性。在推理时,我们描述了如何使用“正确”的前缀来预测结果。

我们检查了多个实际文本中相邻句子的实例,以支持我们基于前缀而非完整句子的工作方法。

最后,我们注意到,处理句子的所有前缀而不是句子本身可能会导致模型规模激增。为此,我们提出了一种使用 Tries 的方案。在适当的 Tries 中,以紧凑方式表示相同上下文中的序列。我们详细讨论了 Tries 如何在训练过程中学习,以及 Tries 在推理过程中如何使用。

参考文献

  1. 文本分割的神经模型

  2. Grammarly

  3. towardsdatascience.com/segmenting-text-into-sentences-using-nlp-35d8ef55c0fd

  4. en.wikipedia.org/wiki/Trie

使用 NLP 将文本分割成句子

原文:towardsdatascience.com/segmenting-text-into-sentences-using-nlp-35d8ef55c0fd

特征工程、统计模型以及从反馈中学习

Arun JagotaTowards Data Science Arun Jagota

·发布在 Towards Data Science ·10 分钟阅读·2023 年 1 月 30 日

--

图片由 Nile 提供,来源于 Pixabay

在自然语言处理(NLP)中,将文本文档分割成句子是一个有用的基本操作。这是许多更复杂的 NLP 任务的第一步。例如,在写作过程中检测和纠正文本中的错误[1],或者检测命名实体[2]。

在前者中,想法是常见错误不会跨越句子边界。这对于后者也是成立的。命名实体也通常不会跨越句子边界。

在任何情况下,这都大大简化了问题。例如,训练和评估可以依赖于一个句子语料库。即使是从更长的文档中提取的句子也可以独立处理。

在这篇文章中,我们将讨论将文本分割成句子的问题。我们将采用一种“苏格拉底式方法”的风格。通过这种方法,我们将从头开始迭代地构建一个启发式预测器,采用一种“迭代假设精炼”的风格。适当的猜测和问题推动精炼。在这个过程中,ChatGPT 将非常有帮助。

我们最终采用的特定方法在精神上类似于流行的 Punkt 算法 [5]。

让我们开始吧。

首先,让我们明确一点:我们将从原始文本开始。没有诸如分词、转小写、去除停用词等预处理。这些操作可能会丢失一些对预测任务有用的信号。

我们都知道,大多数句子以句号结束。因此,我们最简单的预测器是

If the current character is a period
  Predict that we are at a sentence boundary

显而易见的问题是,“这在哪些方面会失败”?例如,句子可能以问号结束。或者以感叹号结束。好吧,让我们稍微抽象一下这个规则。

If the current character is boundary_marker
  Predict that we are at a sentence boundary

我们将 boundary_marker 设置为{句号 (.), 问号 (?), 感叹号 (!) }。

好吧,那么现在同样的问题是,“在哪儿失败”?考虑“Ph.D.”。我们不希望第一个句号被视为句子边界。

直观上,如果句号在一个单词中,我们就不是在句子边界。

所以下面的规则应该会更有效。

If the current character is boundary_marker and followed by a white space
  Predict that we are at a sentence boundary

好吧,这样可以过滤掉涉及单词内句号的误报,并且似乎没有丢失任何召回率。

在查看本文的句子时,我们还注意到紧接着的句子中的第一个词是大写的。呃,这是一个基本的语法规则。好吧,现在它也是句子边界的预测指标。

鉴于上述情况,让我们进一步优化规则。

If the current character is boundary_marker and followed by right_context
  Predict that we are at a sentence boundary

其中 right_context 是“空格后跟大写字母”。

通过添加“大写字母”,我们收紧了条件。这如何影响精确度和召回率?

精确度不可能下降。召回率可能下降。考虑以下情况。文本短且非正式,也许是匆忙写成的——例如在短信中。我们可能会错过实际的句子边界,因为新句子的第一个词可能没有大写。

好吧,这实际上让我们思考一下我们是否需要一个用于干净文本的句子边界预测器,还是用于不一定干净的文本。看来问题的复杂性发生了变化。

如果我们想在不一定干净的文本中预测句子边界,因为召回率可能会下降,我们需要问一下,当我们收紧条件时,精确度是否提高。

考虑这个例子,我在询问 ChatGPT 时设法诱导出来的。

give me examples of sentence boundary false positives involving periods

我不确定它是否理解了问题,但我确实在其回答中找到了一些我可以使用的答案。

The U.S. is a big country.

“句号后跟空格”的条件会在U.S之后立即标记为句子边界,这显然是错误的。

好吧,我们继续前进。接下来思考一下。如果我们在条件中添加左侧上下文呢?新的条件将是

If the current character is boundary_marker and is preceded by 
left_context and is followed by right_context
  Predict that we are at a sentence boundary

在不一定干净的文本中,可能会这样。我们可以从右侧上下文中去掉“大写字母”。这将有助于召回,但可能会降低精确度。我们可以通过添加左侧上下文来尝试提高精确度。

我问了 ChatGPT 这个问题。

give me examples of words that serve as the last word in a sentence.

其回答的前几行如下。

Run.
Stop.
Wait.

好吧,看来它没有完全理解我的问题,可能是因为表达得不够清楚。尽管如此,这些答案是有用的,即使非常有限。确实,在单词句子中,这些词是很好的左侧上下文。

这也帮助我完善了对 ChatGPT 的提问。首先,我将问题重新表述为

give me examples of words that serve as the last word in a multi-word 
sentence.

那没有帮助。它的回答仍然由单词句子组成,尽管实际的单词是很好的新候选词,适合在左侧上下文中使用。

以下问题

give me examples of left context to predict sentence boundaries.

生成了一个有趣且可能有用的回答,尽管不是我所寻找的。所以我不会详细阐述。

以下问题

Which words tend to be the last word in a sentence.

的确导致了更有用的回答。

Words that tend to be the last word in a sentence in written English 
include:

questions (e.g. What?, How?)
exclamations (e.g. Wow!, Oh!)
commands (e.g. Stop!, Go!)
short, one-word sentences that express a complete thought (e.g. Yes, No, Okay)
proper nouns (e.g. London, Bob)
some adverbs (e.g. Today, Yesterday, Finally)

Note: These are general tendencies, and the last word in a sentence can 
vary depending on the context and the type of sentence.

很好的响应!从中我们可以推测,添加左上下文可能有帮助。这个响应还提供了以下实际的见解,推动我们前进。

替换

If the current character is boundary_marker and is preceded by 
left_context and is followed by right_context

通过较不严格的条件。例如left_context特征、right_context特征、boundary_marker。关键是 left_context 和 right_context 在条件中表示可能会导致组合爆炸。特征可以更柔和地结合以用于预测目的。

这可能很强大,特别是如果我们有办法利用监督学习,即反馈。支配特征集合如何用于预测句边界的参数就变得可学习,并可能减轻在遇到新场景时引入新特征的需要。

在这种设置中利用监督学习并不像人们想象的那么困难。我们不需要构建大量的句边界正例和句边界负例训练集。相反,我们可以按如下方式进行。

从初始条件开始,例如

if the current character is a boundary_marker and it is followed by 
white space

具有高召回率。额外增加 left_context 和 right_context 特征。

运行最初配置的句边界预测器,在一个或多个足够大的文档上,并以易于可视化和标记的格式输出相邻预测句对。即使是 CSV 文件也足够了。检查 CSV 文件并将发现的任何假阳性标记出来。即使时间限定这个工作也可能会提升句边界预测器的性能。

当然,投入更多时间进行标记数据可能会进一步提高预测准确性。一个丰富的标记数据集可能会产生更好的结果。不过,这项投资不需要提前完成。可以逐步进行。

接下来,我们讨论要包含的具体合理特征。这些特征受到我们迄今为止的调查的启发。首先,我们的示例中添加一些我们从[5]找到的特征,因为这些特征将有助于特征工程。

... Mr. Smith ...
... Johann S. Bach ...

1. 当前句子中在 boundary_marker 之前的单词数。

2. 如果 boundary_marker 是句号,则当前标记中在其之前的额外句号数量。

3. 上一个词的词性。

4. 边界标记符前面单词的身份。

5. 边界标记符前面单词的长度。

6. 如果 boundary_marker 是句号,那么它是否嵌入在一个命名实体中。

特征 1) 与特征 4) 结合显然对单词句子有用。

特征 2) 受到示例的启发。

The U.S. is a big country.

这可以帮助预测U.S.中的第二个句号不是句边界。

特征 3) 受到 ChatGPT 对专有名词的响应的启发。

特征 5) 与特征 2) 结合时可能对预测边界标记实际上是前一个单词的一部分有用,也就是说这是一个假阳性。

特征 6) 有助于这个示例。

... Johann S. Bach ...

监督统计算法

现在我们将描述我们认为是一种有效的统计性质的监督算法。针对这个问题进行了调优。相对容易实现。

算法将假设我们正处于一个boundary_marker字符上。它将进行一个概率推断,我们将其称为等式 (1)。

P(该边界真实|boundary_marker,左上下文特征,右上下文特征)

这里的boundary_marker指的是标记的实际值:句点问号,或感叹号

尽管我们不会详细开发这个算法,但我们希望指出一些我们心中所想的特征:

  1. 如果我们认为合适,可以预先指定某些特征的贡献,并且不允许其发生变化。例如,如果紧随boundary_marker之后的字符不是空格,而我们希望确保在这种情况下boundary_marker被预测为假阳性,我们可以强制使等式 (1) 始终评估为零。

  2. 我们不一定要假设等式 (1) 中的特征在结果—边界是否真实—给定的情况下是条件独立的。例如,这些特征可能通过一个合适的概率图模型连接起来,该模型的结构可能由模型设计者为特定用例设定。或者,它可以是一个专家组合模型,其中“专家”是特定的特征值或特征组合,它们强烈预测边界标记是假阳性。有关专家组合模型,请参见 [4]。

  3. 我们可以想象一个等式 (1) 的变体,它不是完全概率性的,同时保持相同的精神。具体来说,我们可以用两个评分函数替换等式 (1):score_true(boundary_marker左上下文特征,右上下文特征) 和 score_false(boundary_marker左上下文特征,右上下文特征)。这样做,我们可能将模型中可学习的参数数量翻倍,从而使模型更为丰富。

学习示例

现在让我们在具体的设置中说明学习过程。我们将采用类似专家组合的形式,因为它在我们的设置中直观有效,并且学习过程也易于解释。

假设我们想要学习如果边界标记是句点而前一个标记有额外的句点,则我们应该预测这个边界标记不真实。出于防御目的,我们还将前一个标记的长度作为额外的预测因子。

我们希望第二个‘.’在美国中被标记为假阳性boundary_marker

让我们设想将这个专家表述为

P(boundary_marker 真实| boundary_marker 是 ‘.’,前一个标记的长度前一个标记中的句点数量)。

这模拟了一个特定的专家,当boundary_marker是句点时会参与其中。在这种情况下,它使用前一个标记的长度和其中句点的数量来决定是否预测boundary_marker为真实。

假设我们最初设定先验为

P(boundary_marker 为真实|boundary_marker 为‘.’, l, np) 为 1

针对各种合理的 lnp 组合。假设初始算法将句号(例如 U.S. 中的第二个)预测为句子边界。假设我们提供反馈,指出这些是误报。如果我们看到足够多的这种例子,我们应该能够学习到

P(boundary_marker 为真实|boundary_marker 为‘.’, 3, 1) 接近零

我们选择 l=3 和 np=1 作为示例,因为它适用于 U.S.

接下来,设想一个不同的专家,建模为

P(boundary_marker 为真实|紧随 boundary_marker 之后的字符不是空格)

同样,我们可以将先验设定为接近 1,这样最初,例如,U.S. 中的第一个句号就会被预测为真实的句子边界。用户随后提供反馈,将其标记为假阳性。这些反馈可以用于推动

P(boundary_marker 为真实|紧随 boundary_marker 之后的字符不是空格)

逼近零。

总结

在这篇文章中,我们讨论了将文本分割成句子的问题。我们以如下方式处理这个问题。我们通过从最基本的规则开始,迭代地构建句子边界预测器,并通过问答来完善它。我们询问某个规则可能导致的假阳性或假阴性。答案帮助我们完善规则。在此过程中,我们按照之前描述的方式请求 ChatGPT 的帮助。

在某个阶段,我们意识到我们应该从硬规则转向基于特征的方法。我们在一个特定的概率模型下开发了这种方法,引入了我们早期调查中建议的特定特征,将其框定在专家产品的设置中,并讨论了在这种设置中从反馈中学习。

参考文献

  1. 使用 NLP 进行文本修正。检测和修正常见错误… | 作者 Arun Jagota | 2023 年 1 月 | Towards Data Science

  2. NLP 中的命名实体识别。真实世界的应用案例、模型、方法…

  3. ChatGPT

  4. Geoff Hinton — 专家产品

  5. nltk.tokenize.punkt

R 中的 SEIR 建模使用 deSolve — 鹿中的慢性消耗性疾病

原文:towardsdatascience.com/seir-modeling-in-r-using-desolve-chronic-wasting-disease-in-deer-d88b4de2d6c6?source=collection_archive---------12-----------------------#2023-03-01

使用小数据集获取健康政策洞察

Giovanni MalloyTowards Data Science Giovanni Malloy

·

关注 发布于 Towards Data Science ·14 分钟阅读·2023 年 3 月 1 日

--

图片由 Acton Crawford 提供,来自 Unsplash

收集大量数据是训练和测试监督学习模型以及进行预测的前提。然而,使用少量数据和简单但基础的数学模型,我们可以生成大量见解,从而为缓解传染病威胁提供政策参考。这些类型的模型在数据科学界经常被忽视,但它们同样是生成见解的重要工具。

目前,困扰鹿和麋鹿种群的最突出疾病之一是慢性消耗病(CWD)。CWD 是一种“传染性海绵状脑病或朊病毒病”,类似于疯牛病,主要“影响鹿、麋鹿、驯鹿、四不像和驼鹿”[1, 2]。这种疾病存在于自由放养的鹿种群中,并且可以通过“体液如粪便、唾液、血液或尿液的直接接触,或通过土壤、食物或水的环境污染进行间接传播”[2]。目前,它不会影响人类,但保持 CWD 感染动物不进入食品供应链是健康安全的重要目标。建模这种疾病的传播对于理解未来 CWD 发生率如何增加是很重要的。利用我在上一篇博客文章中描述的一些简单建模工具,我们可以使用在威斯康星州白尾鹿种群上收集的数据,在 R 中建立一个 CWD 的 SEIR 模型。

SEIR 模型

首先,我们将为鹿种群中的 CWD 开发简单的 SEIR 模型。SEIR 代表易感、暴露、传染、恢复/移除,这些疾病状态将构成我们分隔模型的基础。一般来说,易感状态包括所有尚未感染 CWD 但可能会被感染的鹿。通过模型中的出生、死亡以及感染,易感状态会有进出。模型中的暴露状态将包括感染了 CWD 但尚未传染的鹿种群。鹿在暴露于 CWD 后进入这个阶段,直到潜伏期(从感染到具有传染性的时间)过去。传染状态是吸收性的。也就是说,唯一的出路是通过死亡。现在,让我们在数学上描述 SEIR 模型:

作者提供的图像

参数

许多参数值可以通过文献综述或与主题专家咨询后的假设来实例化。理想情况下,我们希望最小化校准的参数数量。许多不同的参数集可以实现对数据的近似相同拟合,没有一个参数集能完美拟合。

有时,我们会在文献中找到可以调整以适应我们模型的数据。例如,一只鹿的预期寿命约为 4.5 年[3],CWD 的潜伏期约为 15 个月[1],CWD 将死亡率提高到健康鹿的约 4.5 倍[4]。此外,威斯康星州公布了猎物采集数据[5],这将影响我们的死亡率,并且公布了猎杀天数的数据[9]。我们将在模型中使用的 CWD 发生率数据(1999–2019 年)直接来自威斯康星州公布的检测结果,可以在此处找到:CWD 全州监测总结 [8]。在我们拥有数据的时期,大约 99%的报告 CWD 病例发生在威斯康星州南部农田区,因此我们将把模型人口限制在该区域。

R 编程

我们首先导入数据,并使用 deSolve 库定义我们的模型。在这里,我们将模型定义为一个函数。注意,我们明确考虑了自然死亡和因狩猎而死亡的情况,以及遗传(beta_CWD_Natural)和传染性(beta_CWD_Deer)CWD。

require('deSolve')
require('ggplot2')

#Original Data
df_1999_2019_SouthernFarmland <- read.csv('./CWD Year 1999-2019 Southern Farmland Population Only.csv')

##################
# Section 1\. Model
##################
CWD_mod <- function(Time, State, Pars)
{
  with(as.list(c(State,Pars)),{
    N_Deer <- S_Deer + E_Deer + I_Deer
    births_in <- death_rate_S*S_Deer + hunt_kill_rate*S_Deer*N_Hunters + death_rate_E*E_Deer + hunt_kill_rate*E_Deer*N_Hunters + death_rate_I*I_Deer + hunt_kill_rate*I_Deer*N_Hunters
    d_S_Deer <- births_in - death_rate_S*S_Deer - beta_CWD_Deer*S_Deer*(I_Deer)/N_Deer - beta_CWD_Natural*S_Deer/N_Deer - hunt_kill_rate*S_Deer*N_Hunters
    d_E_Deer <- beta_CWD_Deer*S_Deer*(I_Deer)/N_Deer + beta_CWD_Natural*S_Deer/N_Deer - epsilon_Deer*E_Deer - death_rate_E*E_Deer - hunt_kill_rate*E_Deer*N_Hunters
    d_I_Deer <- epsilon_Deer*E_Deer - death_rate_I*I_Deer - hunt_kill_rate*I_Deer*N_Hunters
    return(list(c(d_S_Deer, d_E_Deer, d_I_Deer)))
  })
}

现在,我们使用文献回顾和一些经过充分考量的假设来定义除 CWD 传播率之外的所有参数。为了方便起见,我在注释中包含了用于参数化模型的网站/参考资料。

##########################
# Section 2\. Parameters
##########################
S_Deer <- 1377100/4
E_Deer <- 0
I_Deer <- 3
births_in <- 100
life_expectancy_deer <- 4.5 # years https://www.jstor.org/stable/3803059?seq=7#metadata_info_tab_contents
life_expectancy_deer <- life_expectancy_deer * 12
death_rate_S <- 1/life_expectancy_deer
epsilon_Deer <- 1/15 #https://pdfs.semanticscholar.org/75eb/8b27d8cd507c23f2a74bfc7f4391505a7b4a.pdf
death_rate_E <- 1/(life_expectancy_deer)
risk_ratio_CWD_Deer <- 4.5 # https://www.dec.ny.gov/docs/wildlife_pdf/cwdfactsheet.pdf
death_rate_I <- death_rate_S * risk_ratio_CWD_Deer

hunter_harvest_2017 <- 322054 #https://dnr.wi.gov/wideermetrics/DeerStats.aspx
Avg_days_hunted_2017 <- 4.3 # https://dnr.wi.gov/topic/WildlifeHabitat/documents/reports/gundeer2.pdf
hunter_days_2017 <- 2695768 # https://dnr.wi.gov/topic/WildlifeHabitat/documents/reports/gundeer2.pdf
num_hunters_2017 <- 2695768/4.3
num_deer_2017 <- 1377100
num_deers_per_hunter_2017 <- hunter_harvest_2017/num_hunters_2017
hunter_kill_rate_year <- num_deers_per_hunter_2017/num_deer_2017
hunter_kill_rate_month <- hunter_kill_rate_year/12
beta_CWD_Natural <- .000001

现在,我们已经到了至关重要的校准步骤。正如我之前提到的,有许多不同的参数集可以生成适合的模型。我们现在的任务是估计 CWD 的传播率,该传播率最能符合威斯康星州的发病数据,同时保持所有其他参数不变。

校准模型的方法有很多种,因为参数值具有无限的可能范围。对于 CWD 的传播率,传播率的可能值可以从 0 到正无穷。在这种情况下,我进行了一些初步探索,认为 beta_CWD_Deer 的值会在 0.01 和 0.5 之间。在这种情况下,我将测试 0.01 到 0.5 之间的所有值,增量为 0.001。这是最简单的方法,但当我们关注政策或简单性时,它完全可以满足需求。我们将寻找使得每年记录的病例数与模型每年报告的病例数之间的均方误差最小化的 beta_CWD_Deer 值。

##########################
# Section 3\. Calibration
##########################
beta_CWD_Deer <- 0.01

# Define list of possible values for beta_CWD_Deer
list_beta_CWD_Deer <- seq(0.01,0.5,by = 0.001)
# This is a placeholder value that also serves as a sanity check
best_beta_CWD_Deer <- 100
# Set the best score to something absurdly high as a santiy check
best_beta_CWD_Deer_Score <- 10000000000000000000000000

#Track for plot
list_all_betas <- c()
list_all_mse_scores <- c()

for(i in 1:length(list_beta_CWD_Deer))
{
  # Test the possible value of beta_CWD_Deer
  beta_CWD_Deer <- list_beta_CWD_Deer[i]
  #Initialize the model population
  ini_cwd_mod <- c(S_Deer = S_Deer, E_Deer = E_Deer, I_Deer = I_Deer)
  #Run the model for 252 months
  times_cwd_mod <- seq(1,252,by = 1) #By Month
  #Initialize the parameters in the model
  pars_cwd_mod <- c(births_in = births_in, death_rate_S = death_rate_S, 
                    beta_CWD_Deer = beta_CWD_Deer, 
                    beta_CWD_Natural = beta_CWD_Natural, 
                    hunt_kill_rate = hunter_kill_rate_month,
                    epsilon_Deer = epsilon_Deer, 
                    death_rate_E = death_rate_E,
                    death_rate_I = death_rate_I, 
                    N_Hunters = num_hunters_2017)

  #Output using our previously defined function
  CWD_mod_Out <- ode(ini_cwd_mod, times_cwd_mod, CWD_mod, pars_cwd_mod)

  #create a data frame of the results
  yS_Deer <- CWD_mod_Out[,"S_Deer"]
  yE_Deer <- CWD_mod_Out[,"E_Deer"]
  yI_Deer <- CWD_mod_Out[,"I_Deer"]
  CWD_Results <- data.frame(yS_Deer, yE_Deer, yI_Deer)
  #Calculate the model predicted prevalence each year
  CWD_Results$Prevalence <- CWD_Results$yI_Deer / (CWD_Results$yS_Deer +
                                                     CWD_Results$yE_Deer +
                                                     CWD_Results$yI_Deer)
  CWD_Results_Year <- CWD_Results[c(12,24,36,48,60,72,84,96,108,120,
                                    132,144,156,168,180,192,204,
                                    216,228,240,252),]
  # Compare to the data
  CWD_Results_Year$Prevalence_Data <- df_1999_2019_SouthernFarmland$Incidence
  CWD_Results_Year$Year <- df_1999_2019_SouthernFarmland$Year

  CWD_Results_Year$Prevalence_Data <- as.numeric(as.character(CWD_Results_Year$Prevalence_Data))
  CWD_Results_Year$Prevalence <- as.numeric(as.character(CWD_Results_Year$Prevalence))
  # Calculate the mean squared error
  curr_mean_sq_err <- mean((CWD_Results_Year$Prevalence - CWD_Results_Year$Prevalence_Data)²)

  list_all_betas <- c(list_all_betas, beta_CWD_Deer)
  list_all_mse_scores <- c(list_all_mse_scores, curr_mean_sq_err)

  #If it is the lowest MSE, save the beta value and the score as the best
  if(curr_mean_sq_err < best_beta_CWD_Deer_Score)
  {
    best_beta_CWD_Deer <- beta_CWD_Deer
    best_beta_CWD_Deer_Score <- curr_mean_sq_err
    print(curr_mean_sq_err)
  }
  print(i)
}

我们可以使用 ggplot 绘制均方误差随时间的变化,以直观地查看模型在不同参数值下的改进情况。

plot_df <- as.data.frame(matrix(c(list_all_betas, list_all_mse_scores), ncol = 2, byrow = FALSE))
colnames(plot_df) <- c('beta_CWD_Deer', 'Model_MSE')

ggplot()+
  geom_line(data = plot_df, aes(x = beta_CWD_Deer, y = Model_MSE, color = 'Model MSE'), size = 3)+
  xlab('beta_CWD_Deer Value')+
  ylab('MSE')+
  ggtitle('Calibration Results')+
  theme_bw()+
  theme(axis.text.x = element_text(size = 24),
        axis.text.y = element_text(size = 24),
        axis.title = element_text(size = 28),
        axis.title.x = element_text(size = 28),
        axis.title.y = element_text(size = 28),
        title = element_text(size = 16),
        legend.text = element_text(size = 24))+
  scale_color_brewer(palette="Pastel1")+
  theme(panel.grid.major = element_blank(), panel.grid.minor = element_blank())+
  labs(color=' ', shape = '')

图片由作者提供

最小化模型误差的 beta_CWD_Deer 值为 0.359。我们可以从图中看到,当值低于大约 0.25 时,爆发不会发生,因此误差保持不变。

结果

现在,我们可以使用假设的和校准的参数集运行我们的完整模型。

###############################################
# Section 4\. Plot Results
###############################################

beta_CWD_Deer <- best_beta_CWD_Deer

ini_cwd_mod <- c(S_Deer = S_Deer, E_Deer = E_Deer, I_Deer = I_Deer)

times_cwd_mod <- seq(1,252,by = 1) #By Month

pars_cwd_mod <- c(births_in = births_in, death_rate_S = death_rate_S, 
                  beta_CWD_Deer = beta_CWD_Deer, 
                  beta_CWD_Natural = beta_CWD_Natural, 
                  hunt_kill_rate = hunter_kill_rate_month,
                  epsilon_Deer = epsilon_Deer, 
                  death_rate_E = death_rate_E,
                  death_rate_I = death_rate_I, 
                  N_Hunters = num_hunters_2017)

#Output
CWD_mod_Out <- ode(ini_cwd_mod, times_cwd_mod, CWD_mod, pars_cwd_mod)

yS_Deer <- CWD_mod_Out[,"S_Deer"]
yE_Deer <- CWD_mod_Out[,"E_Deer"]
yI_Deer <- CWD_mod_Out[,"I_Deer"]
CWD_Results <- data.frame(yS_Deer, yE_Deer, yI_Deer)
CWD_Results$Prevalence <- CWD_Results$yI_Deer / (CWD_Results$yS_Deer +
                                                  CWD_Results$yE_Deer +
                                                  CWD_Results$yI_Deer)
CWD_Results_Year <- CWD_Results[c(12,24,36,48,60,72,84,96,108,120,
                                  132,144,156,168,180,192,204,
                                  216,228,240,252),]
CWD_Results_Year$Prevalence_Data <- df_1999_2019_SouthernFarmland$Incidence
CWD_Results_Year$Year <- df_1999_2019_SouthernFarmland$Year

CWD_Results$Year <- seq((1998+(1/12)),2019,by=(1/12))

ggplot()+
  geom_point(data = df_1999_2019_SouthernFarmland, aes(x = Year, y = Incidence, color = 'Southern Farmland Zone Data'), size = 3)+
  geom_line(data = CWD_Results, aes(x = Year, y = Prevalence, color = 'Model'), size = 2)+
  xlab('Year')+
  ylab('Prevalence')+
  ggtitle('Chronic Wasting Disease, Whitetail Deer, Wisconsin')+
  theme_bw()+
  theme(axis.text.x = element_text(size = 24),
        axis.text.y = element_text(size = 24),
        axis.title = element_text(size = 28),
        axis.title.x = element_text(size = 28),
        axis.title.y = element_text(size = 28),
        title = element_text(size = 16),
        legend.text = element_text(size = 24),
        legend.position = c(0.3,0.9))+
  scale_y_continuous(breaks = c(0,0.05,0.1,0.15),
                     labels = c('0%','5%','10%','15%'))+
  scale_x_continuous(breaks = c(1999,2005,2011,2017),
                     labels = c('1999', '2005', '2011', '2017'))+
  scale_color_brewer(palette="Pastel1")+
  theme(panel.grid.major = element_blank(), panel.grid.minor = element_blank())+
  labs(color=' ', shape = '')

图片来源:作者

从图中我们可以看到,模型在 2016 年之前低估了 CWD 的流行率。之后,模型开始过度拟合流行率。SEIR 模型永远无法完美拟合数据,也不设计为这样做。它旨在作为对疾病传播的优雅近似。我们可以将模型运行到数据之外的时间段,以观察长期的情况:

###############################################
# Section 5\. Plot Predicted prevalence 10 years forward
###############################################

beta_CWD_Deer <- best_beta_CWD_Deer

ini_cwd_mod <- c(S_Deer = S_Deer, E_Deer = E_Deer, I_Deer = I_Deer)

times_cwd_mod <- seq(1,492,by = 1) #By Month

pars_cwd_mod <- c(births_in = births_in, death_rate_S = death_rate_S, 
                  beta_CWD_Deer = beta_CWD_Deer, 
                  beta_CWD_Natural = beta_CWD_Natural, 
                  hunt_kill_rate = hunter_kill_rate_month,
                  epsilon_Deer = epsilon_Deer, 
                  death_rate_E = death_rate_E,
                  death_rate_I = death_rate_I, 
                  N_Hunters = num_hunters_2017)

#Output
CWD_mod_Out <- ode(ini_cwd_mod, times_cwd_mod, CWD_mod, pars_cwd_mod)

yS_Deer <- CWD_mod_Out[,"S_Deer"]
yE_Deer <- CWD_mod_Out[,"E_Deer"]
yI_Deer <- CWD_mod_Out[,"I_Deer"]
CWD_Results <- data.frame(yS_Deer, yE_Deer, yI_Deer)
CWD_Results$Prevalence <- CWD_Results$yI_Deer / (CWD_Results$yS_Deer +
                                                   CWD_Results$yE_Deer +
                                                   CWD_Results$yI_Deer)
CWD_Results_Year <- CWD_Results[c(12,24,36,48,60,72,84,96,108,120,
                                  132,144,156,168,180,192,204,
                                  216,228,240,252,
                                  264,276,288,300,
                                  312,324,336,348,360,372,
                                  384,396,408,420,432,
                                  444,456,468,480,492),]

CWD_Results$Year <- seq((1998+(1/12)),2039,by=(1/12))

ggplot()+
  geom_point(data = df_1999_2019_SouthernFarmland, aes(x = Year, y = Incidence, color = 'Southern Farmland Zone Data'), size = 3)+
  geom_line(data = CWD_Results, aes(x = Year, y = Prevalence, color = 'Model'), size = 2)+
  xlab('Year')+
  ylab('Prevalence')+
  ggtitle('Chronic Wasting Disease, Whitetail Deer, Wisconsin')+
  theme_bw()+
  theme(axis.text.x = element_text(size = 24),
        axis.text.y = element_text(size = 24),
        axis.title = element_text(size = 28),
        axis.title.x = element_text(size = 28),
        axis.title.y = element_text(size = 28),
        title = element_text(size = 16),
        legend.text = element_text(size = 24),
        legend.position = c(0.3,0.9))+
  scale_y_continuous(breaks = c(0,0.05,0.1,0.15),
                     labels = c('0%','5%','10%','15%'))+
  scale_x_continuous(breaks = c(2000,2005,2010,2015,2020,2025,2030,2035),
                     labels = c('2000','2005','2010','2015','2020','2025','2030','2035'))+
  scale_color_brewer(palette="Pastel1")+
  theme(panel.grid.major = element_blank(), panel.grid.minor = element_blank())+
  labs(color=' ', shape = '')

图片来源:作者

模型预测,从 2023 年开始,群体中的 CWD 流行率将保持相对稳定的 22%。这应理解为,如果基础疾病和人口动态不发生变化,威斯康星州南部农田地区的白尾鹿的 CWD 流行率应为约 22%。

在 2022 年,威斯康星州自然资源部发现 CWD 的流行率略高于 18% [6]。

政策评估

现在我们已经有了模型和基线预测,我们可以探索一些假设的干预措施及其对 CWD 传播的影响。例如,目前的狩猎规定授权在威斯康星州每张许可证猎杀一只鹿 [7]。然而,州政府正在考虑允许无限制猎杀出现 CWD 症状的鹿。让我们设想几个可能的情景:

  1. 预计该法律将使 CWD 感染鹿的猎杀率增加 35%,而没有 CWD 的鹿猎杀率增加 15%(由于误诊)。

  2. 法律已制定,但朊病毒病已变异为传播性提高 30%。

  3. 法律已经制定,朊病毒病发生了突变,并开发了新的 AI 辅助技术,该技术利用计算机视觉来完美区分感染的和未感染的鹿。因此,预计该法律将使患有 CWD 的鹿的猎杀率增加 70%,而没有 CWD 的鹿猎杀率保持不变(0%)。

在模型中,我们将所有情景的起点设定为 2020 年(当时我们没有数据)。虽然我们会调整猎杀率和传播率,但我们将继续保持简化假设,即人口保持不变(出生人数 = 死亡人数)。以下是实现和绘制我们情景比较的 R 代码:

###################################
# Section 6\. Policy Evaluation
###################################

# Run period with data
beta_CWD_Deer <- best_beta_CWD_Deer

ini_cwd_mod <- c(S_Deer = S_Deer, E_Deer = E_Deer, I_Deer = I_Deer)

times_cwd_mod <- seq(1,252,by = 1) #By Month

pars_cwd_mod <- c(births_in = births_in, death_rate_S = death_rate_S, 
                  beta_CWD_Deer = beta_CWD_Deer, 
                  beta_CWD_Natural = beta_CWD_Natural, 
                  hunt_kill_rate = hunter_kill_rate_month,
                  epsilon_Deer = epsilon_Deer, 
                  death_rate_E = death_rate_E,
                  death_rate_I = death_rate_I, 
                  N_Hunters = num_hunters_2017)

#Output
CWD_mod_Out <- ode(ini_cwd_mod, times_cwd_mod, CWD_mod, pars_cwd_mod)

yS_Deer <- CWD_mod_Out[,"S_Deer"]
yE_Deer <- CWD_mod_Out[,"E_Deer"]
yI_Deer <- CWD_mod_Out[,"I_Deer"]
CWD_Results_data <- data.frame(yS_Deer, yE_Deer, yI_Deer)
CWD_Results_data$Prevalence <- CWD_Results_data$yI_Deer / (CWD_Results_data$yS_Deer +
                                                             CWD_Results_data$yE_Deer +
                                                             CWD_Results_data$yI_Deer)
CWD_Results_data$Year <- seq((1998+(1/12)),2019,by=(1/12))

# Baseline model continues 20 years
beta_CWD_Deer <- best_beta_CWD_Deer

ini_cwd_mod <- c(S_Deer = CWD_Results_data$yS_Deer[252], E_Deer = CWD_Results_data$yE_Deer[252], I_Deer = CWD_Results_data$yI_Deer[252])

times_cwd_mod <- seq(253,492,by = 1) #By Month

pars_cwd_mod <- c(births_in = births_in, death_rate_S = death_rate_S, 
                  beta_CWD_Deer = beta_CWD_Deer, 
                  beta_CWD_Natural = beta_CWD_Natural, 
                  hunt_kill_rate = hunter_kill_rate_month,
                  epsilon_Deer = epsilon_Deer, 
                  death_rate_E = death_rate_E,
                  death_rate_I = death_rate_I, 
                  N_Hunters = num_hunters_2017)

#Output
CWD_mod_Out <- ode(ini_cwd_mod, times_cwd_mod, CWD_mod, pars_cwd_mod)

yS_Deer <- CWD_mod_Out[,"S_Deer"]
yE_Deer <- CWD_mod_Out[,"E_Deer"]
yI_Deer <- CWD_mod_Out[,"I_Deer"]
CWD_Results_baseline_projection <- data.frame(yS_Deer, yE_Deer, yI_Deer)
CWD_Results_baseline_projection$Prevalence <- CWD_Results_baseline_projection$yI_Deer / (CWD_Results_baseline_projection$yS_Deer +
                                                                                           CWD_Results_baseline_projection$yE_Deer +
                                                                                           CWD_Results_baseline_projection$yI_Deer)
CWD_Results_baseline_projection$Year <- seq((2019+(1/12)),2039,by=(1/12))

# Scenario 1 model continues 20 years
CWD_mod_diff_hunt_rates <- function(Time, State, Pars)
{
  with(as.list(c(State,Pars)),{
    N_Deer <- S_Deer + E_Deer + I_Deer
    births_in <- death_rate_S*S_Deer + hunt_kill_rate_S*S_Deer*N_Hunters + death_rate_E*E_Deer + hunt_kill_rate_EI*E_Deer*N_Hunters + death_rate_I*I_Deer + hunt_kill_rate_EI*I_Deer*N_Hunters
    d_S_Deer <- births_in - death_rate_S*S_Deer - beta_CWD_Deer*S_Deer*(I_Deer)/N_Deer - beta_CWD_Natural*S_Deer/N_Deer - hunt_kill_rate_S*S_Deer*N_Hunters
    d_E_Deer <- beta_CWD_Deer*S_Deer*(I_Deer)/N_Deer + beta_CWD_Natural*S_Deer/N_Deer - epsilon_Deer*E_Deer - death_rate_E*E_Deer - hunt_kill_rate_EI*E_Deer*N_Hunters
    d_I_Deer <- epsilon_Deer*E_Deer - death_rate_I*I_Deer - hunt_kill_rate_EI*I_Deer*N_Hunters
    return(list(c(d_S_Deer, d_E_Deer, d_I_Deer)))
  })
}

beta_CWD_Deer <- best_beta_CWD_Deer

ini_cwd_mod <- c(S_Deer = CWD_Results_data$yS_Deer[252], E_Deer = CWD_Results_data$yE_Deer[252], I_Deer = CWD_Results_data$yI_Deer[252])

times_cwd_mod <- seq(253,492,by = 1) #By Month

pars_cwd_mod <- c(births_in = births_in, death_rate_S = death_rate_S, 
                  beta_CWD_Deer = beta_CWD_Deer, 
                  beta_CWD_Natural = beta_CWD_Natural, 
                  hunt_kill_rate_S = hunter_kill_rate_month*1.15,
                  hunt_kill_rate_EI = hunter_kill_rate_month*1.35,
                  epsilon_Deer = epsilon_Deer, 
                  death_rate_E = death_rate_E,
                  death_rate_I = death_rate_I, 
                  N_Hunters = num_hunters_2017)

#Output
CWD_mod_Out <- ode(ini_cwd_mod, times_cwd_mod, CWD_mod_diff_hunt_rates, pars_cwd_mod)

yS_Deer <- CWD_mod_Out[,"S_Deer"]
yE_Deer <- CWD_mod_Out[,"E_Deer"]
yI_Deer <- CWD_mod_Out[,"I_Deer"]
CWD_Results_scenario_1_projection <- data.frame(yS_Deer, yE_Deer, yI_Deer)
CWD_Results_scenario_1_projection$Prevalence <- CWD_Results_scenario_1_projection$yI_Deer / (CWD_Results_scenario_1_projection$yS_Deer +
                                                                                               CWD_Results_scenario_1_projection$yE_Deer +
                                                                                               CWD_Results_scenario_1_projection$yI_Deer)
CWD_Results_scenario_1_projection$Year <- seq((2019+(1/12)),2039,by=(1/12))

# Scenario 2 model continues 20 years
beta_CWD_Deer <- best_beta_CWD_Deer*1.3

ini_cwd_mod <- c(S_Deer = CWD_Results_data$yS_Deer[252], E_Deer = CWD_Results_data$yE_Deer[252], I_Deer = CWD_Results_data$yI_Deer[252])

times_cwd_mod <- seq(253,492,by = 1) #By Month

pars_cwd_mod <- c(births_in = births_in, death_rate_S = death_rate_S, 
                  beta_CWD_Deer = beta_CWD_Deer, 
                  beta_CWD_Natural = beta_CWD_Natural, 
                  hunt_kill_rate_S = hunter_kill_rate_month*1.15,
                  hunt_kill_rate_EI = hunter_kill_rate_month*1.35,
                  epsilon_Deer = epsilon_Deer, 
                  death_rate_E = death_rate_E,
                  death_rate_I = death_rate_I, 
                  N_Hunters = num_hunters_2017)

#Output
CWD_mod_Out <- ode(ini_cwd_mod, times_cwd_mod, CWD_mod_diff_hunt_rates, pars_cwd_mod)

yS_Deer <- CWD_mod_Out[,"S_Deer"]
yE_Deer <- CWD_mod_Out[,"E_Deer"]
yI_Deer <- CWD_mod_Out[,"I_Deer"]
CWD_Results_scenario_2_projection <- data.frame(yS_Deer, yE_Deer, yI_Deer)
CWD_Results_scenario_2_projection$Prevalence <- CWD_Results_scenario_2_projection$yI_Deer / (CWD_Results_scenario_2_projection$yS_Deer +
                                                                                               CWD_Results_scenario_2_projection$yE_Deer +
                                                                                               CWD_Results_scenario_2_projection$yI_Deer)
CWD_Results_scenario_2_projection$Year <- seq((2019+(1/12)),2039,by=(1/12))

# Scenario 3 model continues 20 years
beta_CWD_Deer <- best_beta_CWD_Deer*1.3

ini_cwd_mod <- c(S_Deer = CWD_Results_data$yS_Deer[252], E_Deer = CWD_Results_data$yE_Deer[252], I_Deer = CWD_Results_data$yI_Deer[252])

times_cwd_mod <- seq(253,492,by = 1) #By Month

pars_cwd_mod <- c(births_in = births_in, death_rate_S = death_rate_S, 
                  beta_CWD_Deer = beta_CWD_Deer, 
                  beta_CWD_Natural = beta_CWD_Natural, 
                  hunt_kill_rate_S = hunter_kill_rate_month,
                  hunt_kill_rate_EI = hunter_kill_rate_month*1.7,
                  epsilon_Deer = epsilon_Deer, 
                  death_rate_E = death_rate_E,
                  death_rate_I = death_rate_I, 
                  N_Hunters = num_hunters_2017)

#Output
CWD_mod_Out <- ode(ini_cwd_mod, times_cwd_mod, CWD_mod_diff_hunt_rates, pars_cwd_mod)

yS_Deer <- CWD_mod_Out[,"S_Deer"]
yE_Deer <- CWD_mod_Out[,"E_Deer"]
yI_Deer <- CWD_mod_Out[,"I_Deer"]
CWD_Results_scenario_3_projection <- data.frame(yS_Deer, yE_Deer, yI_Deer)
CWD_Results_scenario_3_projection$Prevalence <- CWD_Results_scenario_3_projection$yI_Deer / (CWD_Results_scenario_3_projection$yS_Deer +
                                                                                               CWD_Results_scenario_3_projection$yE_Deer +
                                                                                               CWD_Results_scenario_3_projection$yI_Deer)
CWD_Results_scenario_3_projection$Year <- seq((2019+(1/12)),2039,by=(1/12))

ggplot()+
  geom_point(data = df_1999_2019_SouthernFarmland, aes(x = Year, y = Incidence, color = 'Southern Farmland Zone Data'), size = 3)+
  geom_line(data = CWD_Results_data, aes(x = Year, y = Prevalence, color = 'Calibrated Model'), size = 2)+
  geom_line(data = CWD_Results_baseline_projection, aes(x = Year, y = Prevalence, color = 'Baseline Projection'), size = 2)+
  geom_line(data = CWD_Results_scenario_1_projection, aes(x = Year, y = Prevalence, color = 'Scenario 1 Projection'), size = 2)+
  geom_line(data = CWD_Results_scenario_2_projection, aes(x = Year, y = Prevalence, color = 'Scenario 2 Projection'), size = 2)+
  geom_line(data = CWD_Results_scenario_3_projection, aes(x = Year, y = Prevalence, color = 'Scenario 3 Projection'), size = 2)+
  xlab('Year')+
  ylab('Prevalence')+
  ggtitle('Chronic Wasting Disease, Whitetail Deer, Wisconsin')+
  theme_bw()+
  theme(axis.text.x = element_text(size = 24),
        axis.text.y = element_text(size = 24),
        axis.title = element_text(size = 28),
        axis.title.x = element_text(size = 28),
        axis.title.y = element_text(size = 28),
        title = element_text(size = 16),
        legend.text = element_text(size = 24),
        legend.position = c(0.2,0.8))+
  scale_y_continuous(breaks = c(0,0.05,0.1,0.15),
                     labels = c('0%','5%','10%','15%'))+
  scale_x_continuous(breaks = c(2000,2005,2010,2015,2020,2025,2030,2035),
                     labels = c('2000','2005','2010','2015','2020','2025','2030','2035'))+
  theme(panel.grid.major = element_blank(), panel.grid.minor = element_blank())+
  labs(color=' ', shape = '')

图片来源:作者

我们可以看到,增加对 CWD 感染鹿的猎杀对 CWD 的总体流行率几乎没有影响。情景 1 导致长期流行率约为 18.5%,情景 2 导致长期流行率约为 23%,情景 3 导致长期流行率约为 20.2%。实际上,还有其他环境和生态后果在这个方法中未考虑。

结论

仅凭 20 个数据点,我们能够开发一个简单的模型,用 SEIR 模型解释威斯康星州南部白尾鹿群体中 CWD 的传播。这种模型在结合了学科专业知识时尤其有助于指导政策。虽然不如 AI 方法准确或新颖,但数学模型在极少量数据下的表现非常出色。

除非另有说明,否则所有图片均由作者提供。

数据在威斯康星州自然资源部书面许可下使用。

参考文献

[1] Williams ES, Miller MW. 北美鹿和麋鹿的慢性消耗病。Rev Sci Tech. 2002 年 8 月;21(2):305–16. doi: 10.20506/rst.21.2.1340. PMID: 11974617.

[2] 疾病控制与预防中心。 慢性消耗病 (CWD) | 普里昂病 | CDC. 2023.

[3] Lopez, R. R., Mark E. P. Vieira, Silvy, N. J., Frank, P. A., Whisenant, S. W., & Jones, D. A. (2003). 佛罗里达群岛鹿的生存、死亡率和预期寿命。野生动物管理杂志, 67(1), 34–45. doi.org/10.2307/3803059

[4] 纽约州环境保护部。 慢性消耗病 — 纽约州环境保护部. 2016.

[5] 威斯康星州自然资源部。 鹿统计数据 (wi.gov). 2020.

[6] 威斯康星州自然资源部。 CWD 2022 年度报告 (wi.gov). 2023.

[7] 威斯康星州狩猎规定:2022 秋季-2023 春季. (2022). 威斯康星州自然资源部。 2022WI_HuntRegulations.pdf (widen.net).

[8] 威斯康星州自然资源部。 CWD 全州监测总结. 2023.

[9] 威斯康星州自然资源部。 鹿统计数据 (wi.gov). 2019.

对我的内容感兴趣? 请考虑在 Medium 上关注我。

所有代码和数据可以在 GitHub 上找到: gspmalloy/SEIR_CWD_Deer: 使用 deSolve 的 SEIR 模型 — 鹿的慢性消耗病 (github.com)

在 Twitter 上关注我: @malloy_giovanni

你在用有限数据开发模型方面的经验如何? 或者在开发模型以提供政策建议方面的经验如何? 请在下方评论中分享你的想法或经验!

在 SageMaker 中选择正确的 XGBoost 损失函数

原文:towardsdatascience.com/selecting-the-right-xgboost-loss-function-in-sagemaker-60e545a75c47?source=collection_archive---------4-----------------------#2023-02-18

何时以及为何应使用绝对误差或平方误差

Andrew CharabinTowards Data Science Andrew Charabin

·

关注 发表在 Towards Data Science ·7 min 阅读·2023 年 2 月 18 日

--

图片来源于 VectorStock,版权归 Andrew Charabin 所有

XGBoost 是一个开源软件库,用于将梯度提升框架应用于监督学习(SL)任务。由于其在解决各种 SL 问题中的有效性、速度和稳健性,目前的说法是,针对你的 SL 问题,最佳的模型类型要么是深度学习算法,要么是 XGBoost。由于其名副其实的流行,XGBoost 已作为内置模型在 Amazon SageMaker 这款云机器学习平台中提供。

凭借对 XGBoost 框架的基本了解,任何数据科学家都可以轻松地将数据集插入到 SageMaker 中并生成一个 XGBoost 模型。由于 SageMaker 提供了贝叶斯超参数调优,它将重新训练并选择不同超参数下的最佳模型,用户可以无需完全理解诸如 max_depth 和 eta 等关键输入。然而,在训练 XGBoost 时,通常很少考虑要使用的损失函数(目标),尽管它可能是过程中的最重要的人为决策。本文旨在提供更多关于如何在训练 XGBoost 或其他回归任务的 SL 模型时选择正确损失函数的见解。

我们从解释梯度提升算法开始。一个有用的起点是理解什么不是梯度提升——例如集成决策树算法随机森林。集成决策树算法的目标是生成多个“弱学习者”树,这些树在网络中组合起来时可以产生一个“强学习者”。假设我们有一个由 100 棵树组成的集成,随机森林从概念上讲是并行生成这些树,使用“自助采样”选择观察值和维度的子集来生成每棵树,然后对所有 100 棵树的预测结果进行加权平均,以获得回归问题的最终预测。

与 XGBoost 的关键区别在于,单棵树不能并行生成,必须按顺序生成。具体来说,XGBoost 模型中的每棵后续树用于预测前一棵树的误差。通过结合“自助采样”(从一个袋子中提取不同的观察值和维度)和“提升”(按顺序构建树)的集成方法,XGBoost 框架能够减少模型偏差而不出现过拟合。此外,XGBoost 和其他基于树的模型对高方差数据集和低重要性维度具有鲁棒性。

现在我们从高层次上理解了 XGBoost 的工作原理,让我们回到损失函数上。损失函数接受一组观察值的模型预测,并告诉你这些预测距离实际结果有多远。对于回归问题,本讨论的重点是人们脑海中最简单的损失函数是均绝对误差——我们从实际值中减去预测值(或反之),取绝对值,然后对所有观察值取平均。换句话说,错了 200 比错了 100 要糟糕两倍。

但大多数 XGBoost 用例使用的标准损失函数是‘reg:squarederror’,在 SageMaker 中以前称为‘reg:linear’。正如名字所示,平方误差和绝对误差的唯一区别在于,错误在被平均之前会被平方。由于平方效应在错误大小增加时具有累积效应,均方误差对大错误的惩罚程度比均绝对误差更大。错了 200 现在比错了 100 要糟糕 4 倍。

但为什么均方误差会成为训练 XGBoost 和其他 SL 模型时的标准、无可争议的损失函数,而它并不像均值绝对误差那样简单或直观?

首先要理解的区别是预测损失和估计损失。

(图表由作者提供)

预测损失是业务所承担的损失,而估计损失是模型与“真实情况”模型的距离。正确的估计损失函数并不总是最小化预测损失的函数。正如我稍后所展示的,你可以通过选择不同的估计损失函数来减少预测损失。

给定的均值平均误差在许多实际案例中很好地量化了遗漏的惩罚,为什么均方误差在估计最佳模型参数时常常胜出?

一切归结为模型误差被假设为正态分布的实际假设。

(图表由作者提供)

假设有两个组件需要完美地解释一个结果 Y。首先,是可以从模型维度中得出的信号。在这种情况下,假设一个具有单一维度和偏置项的线性模型是足够的(下面的 mX + b)。其次,是一个遵循独立方差正态分布的随机变量 Z。

从某种意义上说,最佳的估计损失函数是当应用于模型时最大化产生观察到的数据的可能性。在这种情况下,它被证明是均方误差。

平方项的出现可以追溯到正态分布的概率密度函数,这对产生钟形曲线是必不可少的。

但如果误差,特别是随机成分 Z,并不是正态分布呢?在回归问题中观察到的另一种常见预测误差分布是拉普拉斯分布。

(图表由作者提供)

当观察到拉普拉斯分布的概率密度函数时,绝对误差公式立即显现出来。缺乏平方导致分布在 X 轴上的 0 附近出现尖峰。

此外,事实证明,当模型误差遵循拉普拉斯分布时,均值绝对误差是最大化观察到数据的可能性的损失函数。

然而,正态分布和拉普拉斯分布只是回归任务中可以观察到的几种误差分布。对于基于计数的数据问题,误差可能遵循泊松分布。因此,在其他情况下,可能需要进一步研究以找到匹配的误差分布和相关的损失函数,以便获得最佳的参数估计。

在不进一步讨论的情况下,我们已经有了一种改进的算法来选择适用于 XGBoost 模型的正确损失函数。

  1. 生成一个任意模型。

  2. 观察误差,并检查它们是否与正态分布或拉普拉斯分布相似。

  3. 如果相似,请分别选择均方误差或平均绝对误差作为损失函数。

  4. 如果不相似,请研究并测试与观察到的误差匹配且对类似问题常见的其他分布。然后,找到能够导致最可能的参数估计的损失函数。

作为这个简单算法应用的商业案例,我将分享一个我遇到的真实 SL 回归案例。

  1. 使用均方误差损失训练的 XGBoost 模型产生了如下的预测误差分布。

(图表由作者提供)

2. 预测误差与拉普拉斯分布非常相似。

3. 应该使用平均绝对误差作为损失函数。

唯一的问题是,XGBoost 和深度学习算法,作为我们解决众多 SL 问题的首选,需要平滑的损失函数来进行梯度下降(或类似的优化算法)。因此,我们需要找到一个平滑的损失函数,其与平均绝对误差非常相似。在 SageMaker 中最好的选择是伪 Hubbard 分布;下面显示的 Hubbard 分布在 delta 为 1 时与绝对误差非常接近。

(图表由作者提供)

那么,对于我的用例来说,从均方误差损失函数切换到伪 Hubbard 和平均绝对误差对模型训练和超参数调优有什么影响?

(图表由作者提供)

在图表中,你可以看到,当伪 Hubbard 损失与 MAE 超参数调优配对时,平均绝对误差(MAE)下降了 30%。更有趣的是,即使对业务最重要的目标是均方根误差(RMSE),使用均值绝对误差作为估计损失函数仍然能带来改善的业务成果。

结束语

虽然 XGBoost 和类似的集成决策树模型使数据科学家能够轻松地在不同的数据集上进行插件式操作,并以最小的努力获得良好的模型性能,但这种表面上的理解是有代价的。选择正确的损失函数是一个简单的例子,我展示了它可以带来 30%以上的模型改进。

从更深层次来看,数据科学就是不断问“为什么”,追求重要的细节以揭示答案,并舍弃那些附属的细节。它需要直觉,追求好奇心,遇到惊讶时深入探究,并从大局到细节拆解问题。本文概述了我理解为什么均方误差被认为是 XGBoost 回归任务的标准损失函数的过程。结果是一个简单的系统,可以为各种情况推荐最佳损失函数。希望这篇文章对你在理解为什么以及使用你发现的东西来帮助解决重要的商业挑战的过程中有所帮助。

感谢阅读!如果你喜欢这篇文章,请关注我以获取我新帖子通知。同时,欢迎随时分享任何评论/建议。

自助数据分析的需求层次

原文:towardsdatascience.com/self-service-data-analytics-as-a-hierarchy-of-needs-19bb68551640?source=collection_archive---------0-----------------------#2023-11-22

从食品和住房到自我实现:如何使用科学的方法创建支持自助分析的基础

安德鲁·塔夫特Towards Data Science 安德鲁·塔夫特

·

关注 发表在 Towards Data Science · 15 分钟阅读 · 2023 年 11 月 22 日

--

自助需求层次(作者提供的图片)

我一直在回顾 90 年代,当时自助式商业智能(BI)工具如 Business Objects 和 Cognos 首次推出。像所有过于热情的软件工程师一样,我甚至在短暂的花旗集团工作期间帮助构建了一个。那时候,我的年轻自我做出了两个非常快速的预测:

  1. Excel 已死

  2. 自助数据将迅速取而代之

好吧,我并不是诺斯特拉达姆斯。在花旗银行之后,我发现自己在商业智能顾问的职业生涯中已经度过了十年——进行一些数据工程(当时是 ETL,而不是 ELT),在其上加装 BI 工具,培训业务用户,反复循环。我们建立了一些“伟大的东西”,但每一个项目之后都留下了令人不满的结果:

业务用户未按我们预期的速度采用软件进行自主服务。

一小部分“高级用户”(通常是技术人员)会拿起工具,创建各种不同水平的仪表板和报告,但在业务端并没有普遍采纳。而且依然严重依赖顾问。

BI 供应商的销售宣传:100%自助数据民主化

我的期望:60–80%的采纳率

现实情况是:乐观估计下的采纳率不到 20%。

过了一段时间,这些项目开始感觉像是一次很棒的学习机会,而不是一次绝对的失败。到底是什么原因导致这种情况?工具?用户?IT?顾问?我们来到了 2010 年左右,开始有大量关于失败 BI 项目的文档。并非“失败”指的是项目从未产生有意义的结果,而是很少能充分发挥其潜力。业务领域依然严重依赖 IT 来获取数据。清洁、可信赖的数据并未迅速可用。

在这一时期发生了一件有趣的事情:一个名为 Tableau 的数据可视化产品开始广泛被采用。它无处不在,是数据民主化的解决方案。接着,Power BI 作为一个集数据可视化和报告功能于一体的最佳工具进入竞争。然而,十年或更长时间以来,我们仍然看到这些新工具同样面临着 BI 工具自助采纳率低的问题。显然,我并不孤单。

全球各组织的商业智能(BI)采纳率为26%。(360Suite 2021)

[## 2021 年惊人的商业智能统计数据

随着 BI 市场的不断发展,这些统计数据显示了商业智能工具将继续至关重要的原因……

www.trustradius.com](https://www.trustradius.com/vendor-blog/business-intelligence-statistics-and-trends?source=post_page-----19bb68551640--------------------------------)

我不能坐视不管。自然而然地,我不得不创造世界一直需要的东西:解决自助服务的 BI 工具。是的,我告诉自己,我终于能做对了。于是我创建了FlexIt Analytics并设定了这个目标。好吧,还记得我之前的预测吗?是的,我再次错了。让我直奔主题:

并不存在,也永远不会有一种单一神奇的解决方案,能够以有意义的方式使数据分析对大众更加可接近。

没有 BI 工具可以解决自我服务的问题。然而,我们可以退一步,从“大图景”的非技术角度考虑这个问题,也许能获得一些有价值的见解和策略来前进。

马斯洛的需求层次

回到高中时代,试着回忆那堂令人振奋的关于人类动机的心理学讲座。如果你在学校里没有学习过这部分内容,或者记不起来了,这里有一个总结:

美国心理学家亚伯拉罕·马斯洛提出了一种人类动机理论,认为在一个人能够满足更高层次的需求之前,必须先满足基本需求。随着我们在层级上升,我们从如食物和水这样的低层次短期需求转向更高层次的需求,这些需求持续时间更长,更复杂,并且越来越难以满足。最高层次是自我实现和超越。

[## 马斯洛需求层次

马斯洛的需求层次是激励人的需求金字塔。个人最基本的需求,在底层…

自我服务需求层次

简而言之,你需要一个基本的基础才能进入下一个层次。任何在数据领域的人都会立刻认出这一点,并理解它直接应用于实现“数据的自我实现”,这显然是“自我服务”。来吧,它们都有“自我”,这绝不是巧合。让我们深入探讨一下。

自我服务需求层次

我们将从顶部展示相同的图像,因为它不仅是一个值得在 Instagram 上分享的美图,而且在我们即将进行的分析中也非常有帮助。像马斯洛的层次结构一样,自我服务数据分析需求层次展示了每个层次如何支持和使上层层次成为可能。此外,你会看到,随着你向上移动,更多的信任既是必要的也是交付的。

再来一次,DJ:

自我服务需求层次(作者图片)

收集

在底层,马斯洛的生理需求是显而易见的:食物、水、住所。同样,自我服务需求层次的底层也是显而易见的——数据收集。你需要先收集数据。再进一步说,你的基础需要从不同来源收集原始数据。在现代数据世界中,这就是 ELT(提取、加载、转换)的提取和加载部分,结果是我们称之为数据湖的东西。请注意与传统/旧的数据仓储概念 ETL(提取 -> 转换 -> 加载)之间的区别,后者没有数据湖,迫使我们在需要原始源数据时返回到不同的源数据库。

最后一点是,从这一层级产生的任何数据分析都需要由更高技能的分析师/数据科学家来完成,并且信任度较低,因为它没有经过层级的更高层级。类比可以是这样的:你能否直接跳到顶层的超越?也许可以,但在周末聚会结束时,你不太可能持续这种状态。

转型

马斯洛层级中的下一层是安全性,包括安全、社会稳定、可预测性和控制。在我们的自助服务层级中,我们通过在数据仓库中将数据清理和组织成业务模型来实现这种可预测性、稳定性和控制。这通常采用多维星型模式的形式。使用来自下层集合的原始源数据,分析师可能需要将大量不同的表连接在一起以获得客户数据。在这一层级中,这些不同的数据已经在一个共同的表中整合,称为客户维度。在这一过程中,数据会被清理(去重、同一客户的名称不匹配),并进行有用的计算(例如,首次订单日期),使得 SQL 变得更简单。

最终,我们建立了另一层的数据安全性和信任,同时也赋能了一个新的自助分析师群体,因为他们无需了解基础源数据的业务复杂性。同样值得注意的是,在这一层级,我们应该看到业务领域所有者的参与。转型过程旨在支持实际的业务需求,因此必须有业务所有者参与。在现代数据世界中,我们开始看到“分析工程师”作为支持这种混合需求的关键角色。

语义层

马斯洛的第三层级是通过关系和联系来获得爱与归属感。与我们的自助服务层级的相关性惊人,因为语义层实际上是你建立关系(表连接)的地方,并且是将一切结合在一起的部分。我可以继续深入探讨语义层,并在此处链接的帖子中详细说明:

## “语义自由”是商业智能的未来

dbt、度量、无头和通用语义层如何实现“语义自由”商业智能

towardsdatascience.com

我认为这一层级是实现真正自助服务的最重要层级,业务领域的拥有者需要深度参与。“通用语义层”可以提供一个单一的真相来源,通过数据素养、简洁性和信任来驱动自助分析。分析师可以依赖于友好的字段和实体名称、数据目录描述,最重要的是,他们不需要知道表如何连接(或者至少如何编写 SQL)。我们还可以访问诸如数据源追溯(追踪字段回到源表)、同义词(你称之为“销售”,我称之为“收入”)以及数据新鲜度(数据上次刷新时间)等关键内容。

这里有一件重要的事情需要注意,特别是对于那些可能会说“Business Objects 在 90 年代就有这个功能”的历史学家们。我们还没有达到“分析层”(BI 工具层级)。由于许多原因,这些原因在上面链接的文章中有详细阐述(“没有语义的未来是商业智能的未来”),你必须避免将业务逻辑语义层塞进 BI 工具中。在我们的自助服务层级体系中,“语义层”级别应该支持下一个层级,而不是代替它。

分析

在这一层,我们开始讨论 BI 工具、报告、仪表盘,以及当我们谈论自助分析时大多数人想到的东西。如果你对语义层与马斯洛需求层次理论的关联感到如我一样的惊讶,那么请准备好迎接马斯洛的自尊层级。在这里,他将需求分为“较低”版本的需求,如地位、认可、名声、威望和关注,以及“较高”版本的需求,如力量、能力、掌握、自信、独立和自由。你好,“数据英雄”、“禅宗大师”和大师们。

在我们的自助服务层级体系中,这一层级开始关注业务领域的拥有权和自助分析,重点是四种分析类型中的两种:

1. 描述性 — 显示发生了什么的报告和仪表盘

2. 诊断性 — 显示为何发生了这些事情的分析

你是从一个干净的数据仓库开始构建你的仪表盘,数据仓库上有一个良好建模的转换层和通用的语义层,对吗?

自相矛盾的是,可能正是那些我们认为能够实现自助服务的 BI 工具实际上做出了最大的负面贡献。我们知道 Tableau(一个极为出色的可视化工具,确实具有巨大的价值)最早通过绕过缓慢的 IT 部门直接向业务部门销售获得了早期的关注,并继续利用这一分歧。过多的实施涉及从手写 SQL 的源数据库或静态 BI 报告中导出数据,并将这些.CSV 文件导入 Tableau。虽然你可以在这个自助餐中选择健康饮食,但现实往往大相径庭。随之而来的混乱常常使得企业陷入困境,以至于他们永远无法达到下一个层级,因此他们继续只生成描述性仪表盘,展示过去发生的事情。

自我实现与超越

马斯洛需求层次的最高级别涉及自我实现、个人成长和发挥全部潜力。类似于生活,在数据世界中,没有一个可以达到的顶峰然后说“就这样,完成了。”这是一个持续的工作进程,非常难以实现,似乎可以永远进行下去。在这个层级,我们超越了基本的描述性和诊断性分析,建立了对数据和流程的高度信任。这使得接下来的两种分析成为可能:

3. 预测性——预测接下来会发生什么

4. 规范性——基于预测,推荐最佳前进路径

此时,我们在所有数据层面上都建立了坚实的基础,可以开始在利用人工智能、自动化业务流程和处理更高级的用例方面取得有意义的进展。

数据驱动组织的组成部分

好的,我们已经建立了一个改善“数据生命周期”的框架,目标是数据自我实现。现在,让我们深入探讨如何实现这个目标。首先,让我们看看需要关注的方面:人员、流程和工具。

数据驱动组织的组成部分(图由作者提供)

人员

我来自技术领域,所以希望创建技术解决方案来解决业务问题。当然,如果我获得一些业务需求,锁自己在一个黑暗的房间里编写代码,那么我可以创建一个满足业务需求的软件。我犯的错误,以及许多其他人犯的错误,是忽视了软性方面:人员。这听起来显而易见,但我认为我们这些技术人员需要承认,我们经常创建出色的软件产品,然后把它们交给业务用户,说“瞧,这就是了!”当这些软件没有按照我们预期的方式使用,或者他们“就是不理解”时,我们感到困惑。

技术的人性化可能令人困惑和神秘,但这并非必须如此。在这一核心问题上,我们需要通过专注于一些关键领域来建立信任和能力。首先,必须有采纳,否则反对你的力量会使即使是出色的技术解决方案也偏离轨道。除此之外,必须有严肃的协作,理念是我们在朝向“业务驱动”的数据解决方案而不是“数据驱动”的解决方案。我们所做的一切都必须考虑到业务需求。在我们构建时,我们需要考虑如何在我们交付的产品中实现能力。在数据世界中,我们如何促进“数据素养”?当然,企业应该了解他们的数据,但是当我们把他们的业务通过技术研磨后再呈现给他们时,并不总是那么明显。我们需要通过数据目录和语义层促进数据素养。最后,当我们推出我们的解决方案时,我们不能只做标准的推出会议和讲习班,这些会议和讲习班给人的感觉是演讲。我们需要专注于“及时”培训,专注于业务用户在需要解决实际数据问题时的真实数据需求。

流程

即使我们把人性化部分做得很好,我们仍然很容易偏离轨道。为了保持在正轨上,我们还需要做好过程部分。在过去几十年中,特别是在技术领域,最明显的问题之一是许多项目采用了瀑布方法,即在项目开始时就要确定最终结果。我们的第一步,特别是在数据世界中,建立我们的数据驱动型组织可能需要多年时间,是灵活并专注于不断变化的业务需求采取敏捷方法。

敏捷方法被开发为一种灵活的方法,欢迎在过程的后期甚至方向的改变,并在整个过程中考虑利益相关者的反馈。 — Forbes

[## 敏捷 vs. 瀑布:哪种项目管理方法适合您?

敏捷和瀑布是两种广为人知的项目管理方法。它们在软件领域都很流行...

Forbes

做“敏捷”的一个大错误是进行一系列不同的冲刺项目,但这些项目并未形成一个连贯的最终产品。我们必须有一个明确的最终目标,即使我们不是采取瀑布式的方法。必须有标准和数据治理机制,以确保我们始终关注这个最终目标。业务方也需要对他们的数据负责,而不是技术部门。他们需要深入参与这个过程。最后,流程需要专注于持续改进。什么有效?什么无效?为什么?然后,去修复这些问题,并继续交付。

工具

早期,我们依赖工具作为解决问题的魔法解决方案。正如我之前所阐述的,工具不是解决方案。它们甚至不到解决方案的三分之一。我认为这大致是 50% 人员,30% 过程,只有 20% 工具。作为一个 BI 工具提供商,这是一个粗略的看法。但这确实是事实。

话虽如此,工具可以做一些事情来支持整体的人员和流程组件。显然,它们需要直观,以便不需要深入了解如何使用它们,我认为许多现代 BI 工具做得很好。我认为它们不足的一个方面是“即插即用”。正如我之前提到的,我们在工具中放入了过多的业务逻辑,因此从一个工具切换到另一个工具是一项重大工作。更不用说许多组织拥有 3 个或更多的 BI 工具,通常访问相同的数据集。我们需要做的是将这些业务逻辑从 BI 工具中移除,并推送到一个所有 BI 工具都可以接入的集中语义层中。

此外,我们的工具需要与其他工具集成,而不是试图成为一个全能的单体工具。这是“现代数据堆栈”做对的一点,但重要的是我们不要走得太远,导致有数百种工具,造成混乱和杂乱的架构。归根结底,记住工具只是为了支持人员和流程。

创建数据驱动组织的步骤

现在我们已经建立了一个框架和数据驱动组织的整体组成部分,接下来让我们谈谈如何实现这一目标。

第 1 步:获得支持

首先,你需要确定关键利益相关者,并获得高层的支持。没有这些,你可能会面临缺乏“人力资源”来实施你的自助服务框架和组件的风险。获得广泛的支持可能非常困难,所以找出谁可以成为早期的倡导者。在这些步骤结束时,你将重新开始第 1 步,继续建立你的数据驱动组织,并在此过程中获得更多的支持。你是在追求滚雪球效应。

第 2 步:从小处开始

继续使用雪球类比,我们在打造一个雪人。自然地,我们从小做起,逐步扩大。我们将我们所构建的东西视为组件,采取敏捷的方法来满足实际的业务需求。我们希望在第一次迭代中取得一个“快速胜利”,以便在这些积极结果上积累更多,吸引更多的人参与其中。

第三步:建立流程

这些敏捷的“快速胜利”有可能导致混乱的架构。这就是为什么我们立即建立标准和数据治理,这为我们提供了基础,并保持我们专注于交付高质量、准确和可靠的数据产品。像 Github 这样的工具在支持我们的标准和数据治理方面发挥了重要作用。

第四步:民主化

数据治理将使我们能够更加安全地推出这些数据产品,增强信心并降低风险。在实现数据民主化时,我们需要:

  • 消除数据孤岛 —— 这些是由一个部门(通常是技术部门)控制的“黑箱”数据源,与整个组织隔离。

  • 提升数据素养 —— 我们不能期望业务用户立即理解 IT 提供的内容,即使这些数据是他们自己的。数据目录可以大大支持数据素养,但这可能很棘手。我们经常得到的是逐渐过时的电子表格数据字典,最终积灰尘。我们需要转向更动态、主动的数据目录,使业务用户能够对数据目录实体采取行动,并提供对定义等的反馈,以便持续改进。

  • 建立信任 —— 为了实现数据民主化,IT 部门需要信任业务部门会正确使用数据。业务部门需要信任 IT 部门会提供准确、可靠、及时的数据。每一步都需要建立和维护信任。

第五步:合作

现在我们已经采取措施来实现数据民主化,我们需要确保我们在合作和共同开发解决方案的同时,也提供关键反馈来改进工作。重要的是组成一个类似 DART(数据分析和报告团队)的跨部门小组,从技术到业务都有成员,并定期开会解决问题。

第六步:评估

最后,我们需要突出胜利,同时确保建设性地讨论未能奏效或需要改进的方面。在不采取过于教条主义的态度,也不编造 KPI 的情况下,我们需要找出一种衡量成功的方法。人们对第一次迭代的结果满意吗?我们是否创建了一个立刻有用的数据产品?然后,我们进行迭代,持续改进有效的和无效的部分。

然后,反复进行,从第 1 步开始,获得更多的支持和以结果为驱动的下一个项目。

结束语

总结一下,我们涵盖了三个关键领域,以专注于建立数据驱动型组织并实现自助服务。需要特别注意的是,我们不是从零开始进行自助服务,逐步实现完全的数据民主化。我们正在一点一点地推动进展,并不断改进,以便让组织中的更多人参与到数据中来。回顾一下,这里有三种方法可以集中精力:

  1. 框架 — 一种需求层级结构,可以指导我们需要构建什么,以实现数据驱动型组织。

  2. 组件 — 这个数据驱动型组织的组件,即人员、流程和工具。

  3. 创建步骤 — 一种六步法,专注于这些组件,以在框架内建立我们的数据驱动型组织。

祝你在自助服务方面好运!

请评论,我很想听听你的想法,或者联系Andrew Taft

自监督学习在计算机视觉中的应用

原文:towardsdatascience.com/self-supervised-learning-in-computer-vision-fd43719b1625

如何用仅有的几个标记示例来训练模型

Michał Oleszak数据科学前沿 Michał Oleszak

·发表于数据科学前沿·18 分钟阅读·2023 年 1 月 29 日

--

迄今为止,AI 所提供的大部分价值来自于在越来越大的数据集上训练的监督模型。许多这些数据集由人工标注,这是一项单调、耗时、容易出错且有时昂贵的工作。自监督学习(SSL)是一种不同的学习范式,允许机器从未标注的数据中学习。在本文中,我们将深入探讨 SSL 的工作原理以及如何将其应用于计算机视觉。我们将比较简单的方法和最先进的技术,并展示 SSL 在医学诊断中的实际应用,这个领域可以从中受益颇多,但同时也需要深入理解该方法以正确实施。

什么是自监督学习?

根据 Yann LeCun 的说法,Meta 的首席 AI 科学家,自监督学习是“构建背景知识并在 AI 系统中逼近一种常识形式的最有前途的方法之一”。自监督方法背后的理念是用没有注释的数据来训练模型。

自监督学习是构建背景知识并在 AI 系统中逼近一种常识形式的最有前途的方法之一。
Yann LeCun

考虑另外两种学习范式:监督学习和无监督学习。在监督学习中,我们向模型提供一些输入及相应的标签,模型的任务是找到一种映射关系,使其能够对新数据进行泛化。

另一方面,在无监督学习中,我们只有输入而没有标签,学习目标是探索输入数据中的模式,目的是对相似的示例进行聚类、减少数据维度或检测异常等。

学习范式。图片由作者提供。

自监督学习在某种程度上介于这两者之间。它类似于无监督学习,因为它从未标记的数据中学习,但同时也具有监督性质,因为模型在训练过程中创建自己的伪标签来进行学习。

这个想法并不完全新鲜。自监督学习在过去已经被广泛使用,最著名的是在 NLP 中用于训练大型语言模型,如 BERT 或 GPT。这些模型可能会被提供原始文本,并要求预测序列中的下一个标记。因此,对于每个训练样本,模型会将其伪标签设置为句子中的下一个词,例如。

以自监督方式训练的模型从未标记的数据中创建自己的伪标签。

但在过去的三年中,自监督学习在计算机视觉领域被重新发现,取得了突破性的进展,相关论文来自 GoogleDeepMindMeta。然而,原则仍然相同:模型创建自己的伪标签,例如通过遮挡图像的一部分并尝试预测它,或通过旋转图像并尝试预测旋转角度。我们稍后将讨论具体的技术。

现在我们对自监督学习有了基本的了解,让我们看看它在医疗应用中为何特别有用。

医疗数据中的注释稀缺。

医疗行业产生了大量的图像。根据 IBM 的数据,高达 90% 的医疗数据以图像形式存在。这些数据是通过进行各种检查获得的,例如 X 射线,根据 WHO 的数据,每年进行 36 亿次 X 射线检查。

这些大量的数据似乎为应用机器学习算法提供了极大的机会,这些算法依赖于数据,以帮助人类进行诊断和治疗。然而,有一个问题存在。

传统的监督学习模型,为了从数据中学习,除了训练样本,还需要注释或标签:当我们在训练过程中向监督模型展示一张 X 射线图像时,我们需要告诉它需要识别哪些医疗条件。

不幸的是,在医学领域,注释资源稀缺,获取这些注释是一项具有挑战性的任务。它们通常需要由专家医生提供,而专家医生的时间昂贵,且无疑更好地用于照顾他们的病人。这时自监督学习就派上用场了。

自监督学习解决了注释稀缺的问题。

在诸如识别 X 光图像中的医疗条件等标注稀少的环境中,我们通常发现自己处于下图左侧所示的情况:我们有大量数据,但只有一小部分是标注的。

如果我们采用传统的监督方法,我们只能使用少量的标记数据来训练模型。然而,得益于自监督学习,我们也可以从未标记的图像中学习。让我们看看怎么做。

自监督学习工作流程。图像来源:作者。

首先,我们让自监督模型从未标记的数据中生成其伪标签并进行训练。这被称为自监督预训练,在这个过程中模型解决一个称为预文本任务的问题。之前提到过,这可能是预测一个被遮蔽的图像片段或旋转角度,我们会在后面讨论如何选择预文本任务。

上述结果是一个预训练的模型,它已经学习了未标记数据中存在的模式。这个模型对特定的医疗条件一无所知(因为这种信息只有在标签中才会出现,而它未曾见过),但它可能已经学会了一些 X 光图像在一致的方式上有所不同。这就是 LeCun 所说的建立背景知识。

自监督学习工作流程包括两个步骤:在未标记数据上进行预训练以建立背景知识,并在标记数据上进行微调以学习解决下游任务。

第二步是以常规、监督的方式对这个预训练模型进行微调,在数据的标记部分进行。关键在于,现在模型已经有了一些关于数据集中模式的背景知识,提供给它仅仅几个带注释的例子就足以让它学习如何解决下游任务,在我们的例子中就是检测 X 光图像中的医疗条件。

预文本任务

现在让我们讨论模型在预训练步骤中解决的预文本任务。文献中提出了许多这种任务,可能性几乎是无限的。唯一的要求是我们必须能够从输入数据本身创建伪标签。让我们看几个最受欢迎的传统方法。

被遮蔽的预测

被遮蔽的预测简单地意味着遮蔽输入图像的一部分,并让模型尝试从剩余的图像中预测。

被遮蔽的预测。图像来源:作者。

转换预测

存在一个完整的家族方法,这些方法可以归纳在一个广泛的名称下——变换预测。在这些任务中,我们对图像应用一些变换,比如旋转它、改变颜色等。然后,模型的任务是预测变换的参数:角度比例、颜色变化量等。

变换预测。图像由作者提供。

拼图

另一种方法是让模型解决拼图。我们将输入图像切成若干块,随机重新排列,然后要求模型找出正确的原始排列。

拼图。图像由作者提供。

实例识别

其他一些方法将预训练任务集中在实例识别上。这要求你拥有同一对象的多个视图,例如,从不同角度或不同地点拍摄的同一只猫的照片。这种方法的变体会自动生成视图,例如,从 3D 点云中生成或使用生成模型。预训练任务就是识别两张图像是否表示完全相同的对象。

迄今为止,我们讨论的每一个预训练任务的目标都是迫使模型学习数据中的结构和模式。然而,最新的研究发现,有一种稍微不同的方法在实现这一目标时效果最佳。成功的方法基于对比学习

对比学习

对比学习的原则是将样本相互对比,以学习样本之间的共同模式以及区分它们的模式。

对比学习将样本相互对比,以学习它们之间的共同模式以及区分它们的模式。

监督式对比学习

这种方法不仅限于自监督学习。事实上,它最初是作为监督式少样本问题的解决方案出现的。想象一下,你负责办公楼的安全,并且你想安装一个只有经过验证的员工才能打开的自动门。你只有少量每个员工的照片来训练模型。

这里的一个解决方案可能是训练一个模型来识别给定的两张图像是否描绘了同一个人。系统可以将每个来访者的面部与员工照片数据库进行比对,并寻找匹配。这些模型通常以对比方式进行训练。在训练过程中,它们会被呈现三张照片以进行对比:两张是同一个人,一张是不同的。目标是学习前两张彼此相似,而与第三张不同。这种对比方法是监督的,因为你知道每张照片中的人是谁,并可以利用这些知识生成训练样本。

对比学习可以用于有监督和自监督学习任务。

一种略有不同的对比学习方法证明非常适合自监督问题。

自监督对比学习

在这种方法中,我们还向模型呈现三张图片。

  • 第一张是训练数据集中一张随机图片,被称为锚点图片

  • 第二张图片是相同的锚点图片,但以某种方式进行了变换,例如通过旋转或颜色偏移,被称为正样本

  • 第三张图片是训练数据中的另一张随机图片,与第一张不同,被称为负样本

对比自监督学习。图片来自作者。

学习的目标是教会模型前两张图片是相似的,我们希望它们的潜在表示彼此接近(毕竟,旋转的黑白猫仍然是一只猫),而最后一张图片与前两张图片不同,它的潜在表示或嵌入应该远离。

现在我们来更详细地讨论几种自监督对比架构。

三元组损失

可以想象的最简单的方法是基于三元组损失的模型。我们将锚点、正样本和负样本图像通过一个骨干模型,例如 ResNet 或视觉变换器,以获取它们的嵌入,然后将这些嵌入传递给三元组损失函数,其目标是教会模型将锚点和正样本图像在潜在空间中彼此靠近,而将锚点和负样本图像远离。

基于简单三元组损失的模型架构。图片来自作者。

基于三元组损失的模型是一个简单的模型。现在,让我们看一看一些最先进的方法。

SimCLR

对比学习的简单框架,简称 SimCLR,已在 2020 年由Google Research 的一篇论文中提出。

模型接收两张输入图片:锚点图片及其变换版本或正样本,然后将每张图片通过 ResNet 编码器、多层感知器和一个可学习的非线性层。噪声对比估计(NCE)损失旨在最大化两个嵌入之间的相似度,同时最小化与同一小批次中其他图片的嵌入的相似度。

SimCLR 架构。图片来源:LINK

SimCLR 在图像识别中取得了很好的结果。不幸的是,正如作者所示,它在批量大小达到 4096 时效果最佳,并且需要长时间训练。这使得它对于那些不愿在云计算上花费大量资金的个人和公司几乎不可用。

MoCo

Facebook AI 研究团队的 Momentum Contrast(简称 MoCo)缓解了 SimCLR 的一些缺点。他们通过一个巧妙的技巧成功将批量大小减少到 256。

MoCo 有两个编码器网络,它们的参数分别优化,一个用于锚点图像(在线编码器),另一个用于正例(动量编码器)。在线编码器通过基于梯度下降的算法进行优化,而动量编码器则基于在线编码器权重的指数移动平均进行更新。

最重要的是,MoCo 维护了一个动量编码器嵌入的记忆库,并从中抽取负例来计算 NCE 损失。这消除了对大批量大小的需求。

MoCo 架构。图片来源:LINK

在 MoCo 之前,记忆库已经被用于对比学习,但它们通常存储由在线编码器生成的表示。因此,这样的记忆库同时存储了在训练的不同阶段生成的图像,从而导致不一致。基于在线权重移动平均的动量更新的引入,使 MoCo 能够保持一致的记忆库,成为计算损失的良好负例来源。

在发布时,MoCo 在许多不同的计算机视觉任务中超越了顶级监督模型。

BYOL

Bootstrap Your Own Latent(简称 BYOL)是 DeepMind 的一项成果。它基于 MoCo, 同样利用了两个网络,其中一个网络的权重由另一个网络的移动平均更新。

然而,BYOL 并没有使用对比损失函数,而是学习将正例和归一化的锚点映射到嵌入空间中的同一位置。换句话说,在线网络被训练去预测另一网络的表示。这消除了对负例和记忆库的需求。

BYOL 架构。图片来源:LINK

尽管 BYOL 没有明确地将不同图像进行对比,但 一次全面调查 发现它实际上是在以对比的方式进行学习,尽管是间接的。

其他

目前存在许多其他现代自监督架构,几乎每个月都会出现更多新的架构,这些新架构的结果超越了它们的前任。目前的研究往往更多关注模型在许多不同下游任务中的迁移性。大多数这些研究来自 Meta 研究人员。一些显著的例子包括Barlow TwinsSwAVSimSiam,以及最新的TicoVICRegL

选择变换

我们已经讨论了自监督学习是如何工作的,以及它如何解决医学数据中普遍存在的稀缺标注问题。我们还检查了各种前置任务和最先进的对比架构,这些架构将锚点图像与其变换版本进行对比。我们缺少的最后一块拼图是如何选择要应用于锚点图像的变换。而这个选择被证明是成功将自监督学习应用于实际问题的关键步骤。

在对比学习中正确选择变换对于成功解决实际问题至关重要。

包括 SimCLR 和 MoCo 论文在内的最先进文献声称已识别出最佳的变换集合。他们建议使用随机裁剪、颜色抖动和模糊。作者已经证明这些变换在广泛的下游任务中效果最佳。

不幸的是,事情并不那么简单。不同的变换会向模型引入不同的不变性,这可能并不总是令人满意的。

不应对比的内容

Xiao 等人撰写的一篇优秀论文,What Should Not Be Contrastive in Contrastive Learning,很巧妙地展示了这一现象。

考虑一个包含三类图像的数据集:鸟类、花卉和大象,以及在对比预训练期间可以应用于锚点图像的三种可能变换:颜色偏移、旋转和纹理变化。根据你选择的变换,你将能够解决一些下游任务,但不能解决其他任务。

不同的变换会向模型引入不同的不变性,这可能并不总是令人满意的。

如果你将颜色偏移作为你的变换,你将向模型引入颜色不变性:在对比预训练步骤中,损失函数会迫使模型在嵌入空间中将语义相似但颜色不同的图像靠近彼此。

不同的变换会向模型引入不同的不变性。来源:LINK

然后你可以对模型进行微调,以执行粗粒度分类任务,例如区分鸟类和大象,因为它们在许多方面的差异远超过颜色。然而,细粒度分类任务,例如区分不同的鸟类或花卉物种,将会更困难。在这些情况下,类别通常主要通过颜色来区分,而因为模型在预训练期间被教导忽略颜色,它可能在这些下游任务中表现不佳。

选择转换应该根据我们想要解决的下游任务来指导。

我鼓励你花一点时间查看上面论文中的图,并思考旋转和纹理转换如何影响模型可能表现不佳的任务。

从以上示例中得到的启示是,转换的选择应该根据我们想要解决的下游任务的具体情况来指导。使用错误的转换进行预训练实际上可能会阻碍模型在后续任务中的表现。

胸部 X 光片的转换

现在让我们看看转换选择对 X 光图像的重要性。假设我们忽略将图像分类为不同医学状况的下游任务,而仅仅遵循 Google 和 Meta 的研究人员的建议,使用随机裁剪作为我们的转换。

选择适合 X 光图像的转换。图片来源:作者。

让橙色圆圈代表图像中的一部分,指示某种特定的情况,例如肺部的某种损伤。通过随机裁剪,我们可能会得到一个积极的例子,如上图所示:损伤区域被裁剪掉。

对比损失会教会模型损伤的肺和没有损伤的肺是相似的。这种预训练可能使微调模型以识别这种类型的肺损伤变得困难。

另外两个声称最佳的转换对于 X 光数据也不适用。对灰度图像应用颜色抖动或模糊可能会适得其反,因为灰度的阴影或局部模糊可能表明某种特定的医学状况。再次强调,转换必须始终根据特定的数据集和下游任务来选择。

理论方面的内容就到这里了;让我们看看自监督对比学习在 X 光分类中的实际应用!

自监督学习下的 X 光分类

我们与 Tooploox 的同事一起,开始探索自监督学习在医学诊断中的价值。

我们使用了CheXpert 数据集,该数据集包含约 220k 张标注有十个互斥类别的胸部 X 光图像,指示不同的医疗条件和病人是否使用了辅助设备。我们仅使用了超过 200k 张正面图像的子集。

我们选择了约 200k 张图像的随机子集进行自监督预训练。经过一系列实验,我们决定使用轻微的随机旋转、水平翻转和随机透视作为对锚图像应用的变换。所有 CheXpert 中的图像都有标签,但我们在预训练数据中忽略了这些标签。

来自 CheXpert 数据集的示例图像。来源:链接

在预训练之后,我们在不同大小的标注数据集上以监督方式微调了模型:从 250 张到 10k 张图像。目标是研究性能如何随着标注集大小的变化而变化。

最后,我们在 300 张手动标注的图像上测试了这些模型(微调数据的标签是由数据集作者通过自动解析病人的记录获得的,这可能引入了一些噪声;而测试标签则是由医生手动标注的,质量较高)。

性能评估

我们比较了三种模型架构:

  • 一种传统的迁移学习方法,使用 ResNet18。仅在标注微调集上以监督方式进行训练。这反映了我们不使用自监督学习的情况,因此不得不忽略那些未标注的数据。

  • 如前所述的简单三重损失模型,使用相同的 ResNet18 作为骨干,但使用三重损失和我们选择的变换以对比方式进行预训练。

  • Meta 的 MoCo,使用相同的 ResNet18 骨干和我们的变换集。

每个模型已被训练和测试十次,每次使用不同大小的标注微调集。我们通过 ROC 曲线下面积或 AUC 来比较它们。

结果

不同架构和标注集大小的 AUC 如下面的图所示。

模型比较。图片由作者提供。

自监督模型明显超越了监督基准。然而,从这些结果中还有其他有趣的结论:

  • 自监督预训练在标注集最小时提供了最大的提升:仅用 250 个标注样本就超越了监督基准 10 个百分点。

  • 自监督预训练即使在标注数据集较大时也能改进监督基准:即使有 10k 个标注样本,提升仍达到约 6 个百分点。

  • MoCo 相比 Triplet Loss 在基准上获得的增益更多,特别是当标注数据集较小时。

让我们更仔细地查看数据中的类别频率。下图左侧面板显示了 MoCo 相对于基线的优势(通过 AUC 差异来衡量),按每个类别单独计算。右侧面板显示了数据集中各类别的频率。

虽然自监督学习对每个类别都带来了一些收益,但对于相对稀有的类别收益似乎最大。这与之前的图表结果一致,后者显示了在标记集较小的情况下改进最为显著。

结论

自监督学习在计算机视觉领域在过去三年里取得了巨大的进展。由大型 AI 研究实验室发布的对比架构(以 Meta 为首)不断提高标准。这有两个主要的影响。

首先,高效利用未标记数据集的能力将在许多数据稀缺的行业中引发变革。

其次,训练这些所谓的基础模型,从未标记的数据中学习背景知识,并将其转移到多个不同的下游任务中,是推动 AI 泛化的重要一步。

致谢

这篇文章基于我在 2022 年 12 月 18 日在波兰华沙举办的数据科学峰会上的演讲。演示文稿幻灯片可以在这里查看。

将自监督学习应用于医疗应用的研究是我与 Tooploox 的同事们的共同努力。你可以在公司博客上阅读更多内容。

感谢阅读!

如果你喜欢这篇文章,为什么不订阅电子邮件更新以便获取我新文章的通知呢?通过成为 Medium 会员,你可以支持我的写作,并无限制访问其他作者以及我自己的所有故事。

想要始终保持对不断加速发展的机器学习和 AI 领域的关注吗?查看我的新通讯AI Pulse。需要咨询?你可以在这里问我任何问题或预约 1 对 1 咨询。

你还可以尝试我其他的文章。无法决定?选择以下其中之一:

## 蒙特卡洛 Dropout

用一个小技巧免费提高你的神经网络,同时获得模型不确定性的估计作为额外福利。

[towardsdatascience.com ## 生活中贝叶斯思维的重要性

这种简单的思维转变将帮助你更好地理解你周围不确定的世界。

[towardsdatascience.com ## 确立因果关系:第四部分

利用政策变动的差分法

[towardsdatascience.com

使用投影头的自监督学习

原文:towardsdatascience.com/self-supervised-learning-using-projection-heads-b77af3911d33

使用无标记数据提升性能

丹尼尔·沃菲尔德Towards Data Science 丹尼尔·沃菲尔德

·发布于 Towards Data Science ·阅读时间 13 分钟·2023 年 6 月 29 日

--

“自监督”由丹尼尔·沃菲尔德使用 p5.js 实现

在这篇文章中,你将学习自监督学习、如何利用它提升模型性能,以及投影头在自监督学习过程中的作用。我们将涵盖直觉、一些文献以及 PyTorch 中的计算机视觉示例。

适合谁? 任何拥有无标记且可增补数据的人。

这篇文章的难度如何? 本文的开头对初学者概念上是可及的,但例子更多地集中在中级和高级数据科学家上。

前提条件: 对卷积网络和密集网络有较高水平的理解。

代码: 完整代码请见这里

自监督与其他方法

通常,当人们想到模型时,他们会考虑两类:监督模型和无监督模型。

  • 监督学习 是基于标记信息训练模型的过程。例如,在训练一个预测图像是否包含猫或狗的模型时,首先收集一组标记为猫或狗的图像,然后训练模型(使用 梯度下降)来理解包含猫和狗图像之间的差异。

  • 无监督学习 是向模型提供某种无标记信息,并通过对数据进行某种转化来提取有用的推断。无监督学习的经典例子是聚类;在此过程中,从未分组的数据中基于局部位置提取信息组。

自监督学习介于两者之间。自监督使用 程序生成的标签 而非人工标签。 在某些方面,它是监督式学习,因为模型从标记的数据中学习,但在其他方面,它是无监督的,因为训练算法没有提供标签。因此称为自监督。

自监督学习(SSL)的目标是在没有任何人工标记数据注释的情况下生成有用的特征表示。— K Gupta 等

自监督概述

自监督使用对数据的变换,加上巧妙的损失函数,来教模型理解相似的数据。我们可能不知道图像包含什么(它是无人标记的),但我们知道稍微修改的图像仍然是那个东西的图像。因此,你可以将一张图像及其翻转图像标记为包含相同的东西。

即使我们不知道这张图像包含一只猫,我们也知道图像包含的内容不论如何操控图像都是相同的。

这个想法是,通过训练模型学习数据是否包含相似的东西,你是在教模型理解数据,而不论数据是如何呈现的。换句话说,你是在训练模型理解图像,通常情况下,不论类别。一旦自监督完成,模型可以在少量标记数据上进行精炼,以理解最终任务(是狗的图像还是猫的图像)。

自监督学习如何融入一般工作流程的基本概念

我在这个例子中使用的是图像,但自监督可以应用于任何数据,这些数据通过增强方法来改变数据,而不改变其从最终建模问题的角度来看本质。例如,音频数据的增强可以使用波表进行,具体说明见这篇文章

附言:另一种常见的概念化方式是风格不变性。换句话说,你是在训练一个模型来忽略图像中的风格差异。

投影头

随着机器学习作为一个学科的发展,一些架构选择已经证明是普遍有用的。例如,在卷积网络中,一些网络有主干,一些有颈部,还有一些有头部。头部,通常来说,是一个位于较大网络末端的密集网络,将特征转化为最终输出

第一篇 YOLO 论文是经典的卷积架构。它可以被认为有两个部分:一系列卷积将原始图像转换为关键特征(主干),以及一个密集网络将这些特征转化为最终结果(头部)。来源

这个头部的功能通常被描述为投影。在数学和许多其他学科中,投影是将某种信息从一个空间映射到另一个空间的概念,就像灯光如何将你的三维形态映射到墙上的二维阴影中一样。投影头是一个位于较大网络末端的密集网络,负责将一些信息转换为其他信息。在我们的玩具示例中,猫与狗,投影头将图像的通用理解作为特征映射到猫与狗的预测中。

为什么投影头在自监督学习中如此重要。

想象一下你正在玩大富翁游戏。有很多需要学习的东西;投资房地产可以带来收益,做出投资前需要考虑未来,经过“起点”并获得 200 美元,对鞋子和顶针之间没有根本区别等等。在大富翁游戏中,有两种类型的信息:一般适用的信息和任务特定的信息。你不应该在日常生活中每次看到“起点”这个词就兴奋:那是任务特定的。你应该仔细考虑你的投资:那是普遍有用的。

我们可以把自监督看作是一场“游戏”,在这场游戏中,模型学习识别相似或不相似的图像。通过玩这个游戏,模型学习到一般的图像理解,以及实现两张图像是否相同的具体规则。

在一个经典的卷积网络中,颈部和头部是常见的直觉,其中卷积提取特征、风格、纹理以及其他用于一般图像理解的必要信息。另一方面,密集头部则将这些发现的特征映射到任务特定的输出中(例如,识别两张图片是否相同,类似于自监督学习)。

一旦我们在相似数据上训练了一个自监督模型,并且现在想根据标记数据来细化这个模型,我们不再关心识别两张图像是否相同的任务特定逻辑。我们希望保留一般的图像理解,但用分类知识替换任务特定知识。为此,我们丢弃投影头,并用一个新的投影头替换它

模型在自监督学习(顶部)和监督学习(底部)中被丢弃的部分。卷积骨干网络被保留,而负责任务特定逻辑的投影头则被丢弃。

在自监督学习过程中使用投影头是当前的研究热点(这篇文章对这一主题进行了很好的阐述),但直观上是这样的:在自监督学习中你必须具备必要的逻辑才能在自监督任务中表现良好,以便你能学习到一般适用的特征表示。一旦你学会了这些特征,包含优化自监督逻辑的投影头可以被丢弃。

创建和使用投影头与传统建模有所不同。投影头的目标并不是必然要创建一个擅长自监督任务的模型,而是引导生成在后续任务中更有用的特征表示。

PyTorch 中的自监督

在这个例子中,我们将使用 MNIST 数据集的一个修改版,MNIST 是一个经典的数据集,包含手写数字的图像,并配有标记以表示图像所代表的数字。

MNIST 包含 60,000 张有标签的训练图像和 10,000 张有标签的测试图像。然而,在这个例子中,我们将丢弃除了 200 个训练标签之外的所有标签。这意味着我们将有 200 张有标签的图像用于训练,以及 59,800 张无标签的图像用于训练。这个修改反映了自监督最有用的应用类型:数据量大但标记成本高的数据集。

完整的代码可以在这里找到。

MNIST 数据集在GNU 通用公共许可证 v3.0下授权使用,而用于加载它的 torchvision 模块在BSD 3-Clause “New” or “Revised” License下授权使用,两者都允许商业使用。

1) 加载数据

加载数据集

"""
Downloading and rendering sample MNIST data
"""

#torch setup
import torch
import torchvision
import torchvision.datasets as datasets
device = 'cuda' if torch.cuda.is_available() else 'cpu'

#downloading mnist
mnist_trainset = datasets.MNIST(root='./data', train=True,
                                download=True, transform=None)
mnist_testset = datasets.MNIST(root='./data', train=False,
                               download=True, transform=None)

#printing lengths
print('length of the training set: {}'.format(len(mnist_trainset)))
print('length of the test set: {}'.format(len(mnist_testset)))

#rendering a few examples
for i in range(3):
  print('the number {}:'.format(mnist_trainset[i][1]))
  mnist_trainset[i][0].show()

下载的数据集,包含一些样本

2) 将数据分为有标签和无标签数据

在这个例子中,我们将人工忽略大部分训练集中的标签,以模拟一种容易收集大量数据但难以或资源密集型地标记所有数据的用例。这个代码块还执行了一些必要的数据操作,以便利用 PyTorch。

"""
Creating un-labled data, and handling necessary data preprocessing
"""

from tqdm import tqdm
import numpy as np
from sklearn.preprocessing import OneHotEncoder

# ========== Data Extraction ==========
# unlabeling some data, and one hot encoding the labels which remain
# =====================================

partition_index = 200

def one_hot(y):
  #For converting a numpy array of 0-9 into a one hot encoding of vectors of length 10
  b = np.zeros((y.size, y.max() + 1))
  b[np.arange(y.size), y] = 1
  return b

print('processing labeld training x and y')
train_x = np.asarray([np.asarray(mnist_trainset[i][0]) for i in tqdm(range(partition_index))])
train_y = one_hot(np.asarray([np.asarray(mnist_trainset[i][1]) for i in tqdm(range(partition_index))]))

print('processing unlabled training data')
train_unlabled = np.asarray([np.asarray(mnist_trainset[i][0]) for i in tqdm(range(partition_index,len(mnist_trainset)))])

print('processing labeld test x and y')
test_x = np.asarray([np.asarray(mnist_testset[i][0]) for i in tqdm(range(len(mnist_testset)))])
test_y = one_hot(np.asarray([np.asarray(mnist_testset[i][1]) for i in tqdm(range(len(mnist_testset)))]))

# ========== Data Reformatting ==========
# adding a channel dimension and converting to pytorch
# =====================================

#adding a dimension to all X values to put them in the proper shape
#(batch size, channels, x, y)
print('reformatting shape...')
train_x = np.expand_dims(train_x, 1)
train_unlabled = np.expand_dims(train_unlabled, 1)
test_x = np.expand_dims(test_x, 1)

#converting data to pytorch type
torch_train_x = torch.tensor(train_x.astype(np.float32), requires_grad=True).to(device)
torch_train_y = torch.tensor(train_y).to(device)
torch_test_x = torch.tensor(test_x.astype(np.float32), requires_grad=True).to(device)
torch_test_y = torch.tensor(test_y).to(device)
torch_train_unlabled = torch.tensor(train_unlabled.astype(np.float32), requires_grad=True).to(device)

print('done')

从格式化过程中的打印输出

3) 定义模型

为了加快训练,这个问题使用了一个超简单的卷积网络和最少的超参数探索。这个模型有两个主要部分:卷积主干和密集连接的头部。

"""
Using PyTorch to create a modified, smaller version of AlexNet
"""
import torch.nn.functional as F
import torch.nn as nn

#defining model backbone
class Backbone(nn.Module):
    def __init__(self):
        super(Backbone, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 3)
        self.conv2 = nn.Conv2d(16, 16, 3)
        self.conv3 = nn.Conv2d(16, 32, 3)

        if torch.cuda.is_available():
            self.cuda()

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = F.max_pool2d(F.relu(self.conv3(x)), 2)
        x = torch.flatten(x, 1)
        return x

#defining model head
class Head(nn.Module):
    def __init__(self, n_class=10):
        super(Head, self).__init__()
        self.fc1 = nn.Linear(32, 32)
        self.fc2 = nn.Linear(32, 16)
        self.fc3 = nn.Linear(16, n_class)

        if torch.cuda.is_available():
            self.cuda()

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.softmax(x,1)

#defining full model
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.backbone = Backbone()
        self.head = Head()

        if torch.cuda.is_available():
            self.cuda()

    def forward(self, x):
        x = self.backbone(x)
        x = self.head(x)
        return x

model_baseline = Model()
print(model_baseline(torch_train_x[:1]).shape)
model_baseline

输出维度和模型架构。

4) 使用仅有监督学习作为基准进行训练和测试

为了了解自监督如何提高性能,我们将在仅有的 200 个标记样本上训练我们的基线模型。

"""
Training model using only supervised learning, and rendering the results.
This supervised training function is reused in the future for fine tuning
"""

def supervised_train(model):

    #defining key hyperparamaters explicitly (instead of hyperparamater search)
    batch_size = 64
    lr = 0.001
    momentum = 0.9
    num_epochs = 20000

    #defining a stocastic gradient descent optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    #defining loss function
    loss_fn = torch.nn.CrossEntropyLoss()

    train_hist = []
    test_hist = []
    test_accuracy = []

    for epoch in tqdm(range(num_epochs)):

        #iterating over all batches
        for i in range(int(len(train_x)/batch_size)-1):

            #Put the model in training mode, so that things like dropout work
            model.train(True)

            # Zero gradients
            optimizer.zero_grad()

            #extracting X and y values from the batch
            X = torch_train_x[i*batch_size: (i+1)*batch_size]
            y = torch_train_y[i*batch_size: (i+1)*batch_size]

            # Make predictions for this batch
            y_pred = model(X)

            #compute gradients
            loss_fn(model(X), y).backward()

            # Adjust learning weights
            optimizer.step()

        with torch.no_grad():

            #Disable things like dropout, if they exist
            model.train(False)

            #calculating epoch training and test loss
            train_loss = loss_fn(model(torch_train_x), torch_train_y).cpu().numpy()
            y_pred_test = model(torch_test_x)
            test_loss = loss_fn(y_pred_test, torch_test_y).cpu().numpy()

            train_hist.append(train_loss)
            test_hist.append(test_loss)

            #computing test accuracy
            matches = np.equal(np.argmax(y_pred_test.cpu().numpy(), axis=1), np.argmax(torch_test_y.cpu().numpy(), axis=1))
            test_accuracy.append(matches.sum()/len(matches))

    import matplotlib.pyplot as plt
    plt.plot(train_hist, label = 'train loss')
    plt.plot(test_hist, label = 'test loss')
    plt.legend()
    plt.show()
    plt.plot(test_accuracy, label = 'test accuracy')
    plt.legend()
    plt.show()

    maxacc = max(test_accuracy)
    print('max accuracy: {}'.format(maxacc))

    return maxacc

supervised_maxacc = supervised_train(model_baseline)

监督模型在训练过程中的测试准确率。考虑到随机猜测的准确率为 10%,而这个模型仅接触了 200 个标记样本,我对其表现如此之好感到惊讶。尽管如此,通过结合自监督学习,我们仍然可以做得更好。

5) 定义数据增强

自监督学习需要数据增强。这个函数将一批图像增强两次,结果是得到一对随机增强的图像,用于对比学习。

import torch
import torchvision.transforms as T

class Augment:
   """
   A stochastic data augmentation module
   Transforms any given data example randomly
   resulting in two correlated views of the same example,
   denoted x ̃i and x ̃j, which we consider as a positive pair.
   """

   def __init__(self):

       blur = T.GaussianBlur((3, 3), (0.1, 2.0))

       self.train_transform = torch.nn.Sequential(
           T.RandomAffine(degrees = (-50,50), translate = (0.1,0.1), scale=(0.5,1.5), shear=0.2),
           T.RandomPerspective(0.4,0.5),
           T.RandomPerspective(0.2,0.5),
           T.RandomPerspective(0.2,0.5),
           T.RandomApply([blur], p=0.25),
           T.RandomApply([blur], p=0.25)
       )

   def __call__(self, x):
       return self.train_transform(x), self.train_transform(x)

"""
Generating Test Augmentation
"""
a = Augment()
aug = a(torch_train_unlabled[0:100])

i=1
f, axarr = plt.subplots(2,2)
#positive pair
axarr[0,0].imshow(aug[0].cpu().detach().numpy()[i,0])
axarr[0,1].imshow(aug[1].cpu().detach().numpy()[i,0])
#another positive pair
axarr[1,0].imshow(aug[0].cpu().detach().numpy()[i+1,0])
axarr[1,1].imshow(aug[1].cpu().detach().numpy()[i+1,0])
plt.show()

同一批次内的两个正样本对

6) 定义对比损失

对比损失是用于将正样本对在嵌入空间中紧密放置,而将负样本对放置得更远的损失函数。

class ContrastiveLoss(nn.Module):
   """
   Vanilla Contrastive loss, also called InfoNceLoss as in SimCLR paper
   """
   def __init__(self, batch_size, temperature=0.5):
       """
       Defining certain constants used between calculations. The mask is important
       in understanding which are positive and negative examples. For more
       information see https://theaisummer.com/simclr/
       """
       super().__init__()
       self.batch_size = batch_size
       self.temperature = temperature
       self.mask = (~torch.eye(batch_size * 2, batch_size * 2, dtype=bool)).float().to(device)

   def calc_similarity_batch(self, a, b):
       """
       Defines the cosin similarity between one example, and all other examples.
       For more information see https://theaisummer.com/simclr/
       """
       representations = torch.cat([a, b], dim=0)
       return F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=2)

   def forward(self, proj_1, proj_2):
       """
       The actual loss function, where proj_1 and proj_2 are embeddings from the
       projection head. This function calculates the cosin similarity between
       all vectors, and rewards closeness between examples which come from the
       same example, and farness for examples which do not. For more information
       see https://theaisummer.com/simclr/
       """
       batch_size = proj_1.shape[0]
       z_i = F.normalize(proj_1, p=2, dim=1)
       z_j = F.normalize(proj_2, p=2, dim=1)

       similarity_matrix = self.calc_similarity_batch(z_i, z_j)

       sim_ij = torch.diag(similarity_matrix, batch_size)
       sim_ji = torch.diag(similarity_matrix, -batch_size)

       positives = torch.cat([sim_ij, sim_ji], dim=0)

       nominator = torch.exp(positives / self.temperature)

       denominator = self.mask * torch.exp(similarity_matrix / self.temperature)

       all_losses = -torch.log(nominator / torch.sum(denominator, dim=1))
       loss = torch.sum(all_losses) / (2 * self.batch_size)
       return loss

"""
testing
"""
loss = ContrastiveLoss(torch_train_x.shape[0]).forward
fake_proj_0, fake_proj_1 = a(torch_train_x)
fake_proj_0 = fake_proj_0[:,0,:,0]
fake_proj_1 = fake_proj_1[:,0,:,0]
loss(fake_proj_0, fake_proj_1)

损失函数的输出。关键是存在grad_fn,意味着该函数是可微的,因此可以更新模型参数。

7) 自监督训练

训练模型通过自监督和对比损失理解图像的相似性和差异。由于这是一个中间步骤,很难创建清晰和直观的性能指标。因此,我选择花费一些额外的计算来深入理解损失,这对于调整参数以获得一致的模型改进非常有用。

from torch.optim.lr_scheduler import ExponentialLR

#degining a new model
model = Model()
model.train()

#defining key hyperparameters
batch_size = 512
epoch_size = round(torch_train_unlabled.shape[0]/batch_size)-1
num_epochs = 100
patience = 5
cutoff_ratio = 0.001

#defining key learning functions
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
num_examples = train_unlabled.shape[0]
lossfn = ContrastiveLoss(batch_size).forward
augmentfn = Augment() #augment function

#for book keeping
loss_hist = []
improvement_hist = []
schedule_hist = []

#for exponentially decreasing learning rate
scheduler = ExponentialLR(optimizer,
                          gamma = 0.95)

#for early stopping
patience_count=0

#Training Loop
avg_loss = 1e10
for i in range(num_epochs):

    print('epoch {}/{}'.format(i,num_epochs))

    total_loss = 0
    loss_change = 0

    for j in tqdm(range(epoch_size)):

        #getting random batch
        X = torch_train_unlabled[j*batch_size: (j+1)*batch_size]

        #creating pairs of augmented batches
        X_aug_i, X_aug_j = augmentfn(X)

        #ensuring gradients are zero
        optimizer.zero_grad()

        #passing through the model
        z_i = model(X_aug_i)
        z_j = model(X_aug_j)

        #calculating loss on the model embeddings, and computing gradients
        loss = lossfn(z_i, z_j)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        #checking to see if backpropegation resulted in a reduction of the loss function
        if True:
            #passing through the model, now that parameters have been updated
            z_i = model(X_aug_i)
            z_j = model(X_aug_j)

            #calculating new loss value
            new_loss = lossfn(z_i, z_j)

            loss_change += new_loss.cpu().detach().numpy() - loss.cpu().detach().numpy()

        total_loss += loss.cpu().detach().numpy()

        #step learning rate scheduler
        schedule_hist.append(scheduler.get_last_lr())

    scheduler.step()

    #calculating percentage loss reduction
    new_avg_loss = total_loss/epoch_size
    per_loss_reduction = (avg_loss-new_avg_loss)/avg_loss
    print('Percentage Loss Reduction: {}'.format(per_loss_reduction))

    #deciding to stop if loss is not decreasing fast enough
    if per_loss_reduction < cutoff_ratio:
        patience_count+=1
        print('patience counter: {}'.format(patience_count))
        if patience_count > patience:
            break
    else:
        patience_count = 0

    #setting new loss as previous loss
    avg_loss = new_avg_loss

    #book keeping
    avg_improvement = loss_change/epoch_size
    loss_hist.append(avg_loss)
    improvement_hist.append(avg_improvement)
    print('Average Loss: {}'.format(avg_loss))
    print('Average Loss change (if calculated): {}'.format(avg_im

训练的前几个时期输出,包含若干基于损失的性能指标,这些指标对于调整参数非常有用。

8) 自监督训练进展

这是自监督学习带来的损失改进。你可以看到指数递减的学习率与损失值之间的关系。

plt.plot(schedule_hist, label='learning rate')
plt.legend()
plt.show()
plt.plot(loss_hist, label = 'loss')
plt.legend()
plt.show()

学习率按样本绘制,而损失按时期绘制,但你可以理解。损失下降然后收敛,学习在损失减少变得微不足道时停止。

9) 用监督学习微调自监督模型

使用之前的监督函数在监督数据上训练自监督模型。这做了两次:一次使用原始自监督学习头部,一次使用新随机初始化的头部。

import copy

#creating duplicate models for finetuning
model_same_head = copy.deepcopy(model)
model_new_head = copy.deepcopy(model)

#replacing the projection head with a randomly initialized head
#for one of the models
model_new_head.head = Head()

#training models
same_head_maxacc = supervised_train(model_same_head)
new_head_maxacc = supervised_train(model_new_head)

使用原始头部(左)和随机初始化头部(右)的训练结果

10) 讨论

如图所示,纯监督学习表现最差,自监督学习与监督学习结合表现第二好,而自监督学习与新头部结合表现最佳。

这些结果仅用于演示;没有进行显著的超参数优化,这在实际应用中是必要的。然而,本笔记本确实支持自监督的理论效用,并且强调了投影头谨慎使用的重要性。

  • 仅监督学习:52.5%准确率

  • SSL 和在 SSL 头上的监督:59.7%准确率

  • SSL 和在新头上的监督:63.6%

仅考虑 200 张标记图像的情况下,63.6%的准确率非常令人印象深刻!

关注以获取更多信息!

在未来的帖子中,我还会描述 ML 领域的几篇重要论文,重点是实用和直观的解释。

归属: 本文档中的所有图像均由 Daniel Warfield 创建,除非另有来源说明。你可以在非商业目的下使用本文中的任何图像,只要你引用本文,danielwarfield.dev,或两者皆可。

使用 PostgreSQL 和 OpenAI 嵌入实现语义搜索

原文:towardsdatascience.com/semantic-search-with-postgresql-and-openai-embeddings-4d327236f41f?source=collection_archive---------3-----------------------#2023-11-21

Dima TimofeevTowards Data Science Dima Timofeev

·

查看 发表在 Towards Data Science ·4 分钟阅读·2023 年 11 月 21 日

--

图片由 Igor Omilaev 提供,来源于 Unsplash

在公司数据库中实现语义搜索可能是具有挑战性的,并且需要付出大量的努力。然而,真的非得如此吗?在这篇文章中,我将展示如何利用 PostgreSQL 和 OpenAI 嵌入技术来在你的数据上实现语义搜索。如果你不希望使用 OpenAI 嵌入 API,我还会为你提供免费的嵌入模型链接。

从很高的层面来看,具有 LLM 的向量数据库允许对可用数据(存储在数据库、文档等中)进行语义搜索。感谢 “Efficient Estimation of Word Representations in Vector Space” 论文(也称为 “Word2Vec 论文”),由传奇 Jeff Dean 共同作者,我们知道如何将单词表示为实值向量。词嵌入是单词在向量空间中的密集向量表示,其中意义相似的单词彼此接近。词嵌入捕捉了单词之间的语义关系,并且有多种技术来创建它们。

图片由作者提供

让我们实践并使用 OpenAI 的 text-embedding-ada 模型!距离函数的选择通常没那么重要。OpenAI 推荐使用余弦相似度。如果你不想使用 OpenAI 的嵌入,倾向于本地运行不同的模型而不是进行 API 调用,我建议考虑其中一种 SentenceTransformers 预训练模型。选择你的模型时要谨慎。

import os

import openai
from openai.embeddings_utils import cosine_similarity

openai.api_key = os.getenv("OPENAI_API_KEY")

def get_embedding(text: str) -> list:
 response = openai.Embedding.create(
     input=text,
     model="text-embedding-ada-002"
 )
 return response['data'][0]['embedding']

good_ride = "good ride"
good_ride_embedding = get_embedding(good_ride)
print(good_ride_embedding)
# [0.0010935445316135883, -0.01159335020929575, 0.014949149452149868, -0.029251709580421448, -0.022591838613152504, 0.006514389533549547, -0.014793967828154564, -0.048364896327257156, -0.006336577236652374, -0.027027441188693047, ...]
len(good_ride_embedding)
# 1536

既然我们已经理解了嵌入的概念,让我们利用它来排序一些评论。

good_ride_review_1 = "I really enjoyed the trip! The ride was incredibly smooth, the pick-up location was convenient, and the drop-off point was right in front of the coffee shop."
good_ride_review_1_embedding = get_embedding(good_ride_review_1)
cosine_similarity(good_ride_review_1_embedding, good_ride_embedding)
# 0.8300454513797334

good_ride_review_2 = "The drive was exceptionally comfortable. I felt secure throughout the journey and greatly appreciated the on-board entertainment, which allowed me to have some fun while the car was in motion."
good_ride_review_2_embedding = get_embedding(good_ride_review_2)
cosine_similarity(good_ride_review_2_embedding, good_ride_embedding)
# 0.821774476808789

bad_ride_review = "A sudden hard brake at the intersection really caught me off guard and stressed me out. I wasn't prepared for it. Additionally, I noticed some trash left in the cabin from a previous rider."
bad_ride_review_embedding = get_embedding(bad_ride_review)
cosine_similarity(bad_ride_review_embedding, good_ride_embedding)
# 0.7950041130579355

虽然绝对差异可能看起来很小,但考虑到数以千计的评论排序函数。在这种情况下,我们可以优先突出显示顶部的正面评论。

一旦一个单词或文档被转换为嵌入,它可以存储在数据库中。然而,这个动作并不会自动将数据库分类为向量数据库。只有当数据库开始支持对向量的快速操作时,我们才能称其为向量数据库。

目前有许多商业和开源的向量数据库,这使得它成为一个高度讨论的话题。我将通过使用 pgvector,一个开源的 PostgreSQL 扩展,来演示向量数据库的功能,它为 arguably 最受欢迎的数据库提供了向量相似性搜索功能。

让我们运行带有 pgvector 的 PostgreSQL 容器:

docker pull ankane/pgvector

docker run --env "POSTGRES_PASSWORD=postgres" --name "postgres-with-pgvector" --publish 5432:5432 --detach  ankane/pgvector

让我们启动 pgcli 以连接到数据库 (pgcli postgres://postgres:postgres@localhost:5432),创建一个表,插入我们计算的嵌入,然后选择相似的项目:

-- Enable pgvector extension.
CREATE EXTENSION vector;

-- Create a vector column with 1536 dimensions.
-- The `text-embedding-ada-002` model has 1536 dimensions.
CREATE TABLE reviews (text TEXT, embedding vector(1536));

-- Insert three reviews from the above. I omitted the input for your convinience.
INSERT INTO reviews (text, embedding) VALUES ('I really enjoyed the trip! The ride was incredibly smooth, the pick-up location was convenient, and the drop-off point was right in front of the coffee shop.', '[-0.00533589581027627, -0.01026702206581831, 0.021472081542015076, -0.04132508486509323, ...');
INSERT INTO reviews (text, embedding) VALUES ('The drive was exceptionally comfortable. I felt secure throughout the journey and greatly appreciated the on-board entertainment, which allowed me to have some fun while the car was in motion.', '[0.0001858668401837349, -0.004922827705740929, 0.012813017703592777, -0.041855424642562866, ...');
INSERT INTO reviews (text, embedding) VALUES ('A sudden hard brake at the intersection really caught me off guard and stressed me out. I was not prepared for it. Additionally, I noticed some trash left in the cabin from a previous rider.', '[0.00191772251855582, -0.004589076619595289, 0.004269456025213003, -0.0225954819470644, ...');

-- sanity check
select count(1) from reviews;
-- +-------+
-- | count |
-- |-------|
-- | 3     |
-- +-------+

我们现在已经准备好搜索相似的文档。我再次缩短了“good ride”的嵌入,因为打印 1536 维度是多余的。

--- The embedding we use here is for "good ride"
SELECT substring(text, 0, 80) FROM reviews ORDER BY embedding <-> '[0.0010935445316135883, -0.01159335020929575, 0.014949149452149868, -0.029251709580421448, ...';

-- +--------------------------------------------------------------------------+
-- | substring                                                                |
-- |--------------------------------------------------------------------------|
-- | I really enjoyed the trip! The ride was incredibly smooth, the pick-u... |
-- | The drive was exceptionally comfortable. I felt secure throughout the... |
-- | A sudden hard brake at the intersection really caught me off guard an... |
-- +--------------------------------------------------------------------------+
SELECT 3
Time: 0.024s

完成了!如你所见,我们已经为多个文档计算了嵌入向量,将其存储在数据库中,并进行了向量相似性搜索。这些应用的潜力巨大,涵盖了从企业搜索到医疗记录系统中的相似症状患者识别等众多领域。此外,这种方法不仅限于文本;也可以计算声音、视频和图像等其他类型数据的相似性。

享受吧!

使用 BERT 进行语义文本相似度分析

原文:towardsdatascience.com/semantic-textual-similarity-with-bert-fc800656e7a3

如何使用 BERT 计算两段文本之间的语义相似度

Ruben WinastwanTowards Data Science Ruben Winastwan

·发表于 Towards Data Science ·阅读时间 11 分钟·2023 年 2 月 15 日

--

图片来源于 Leeloo Thefirst: www.pexels.com/photo/brown-wooden-ruler-and-colored-pencils-on-papers-8970296/

自从 2017 年由 Google Brain 团队首次推出以来,Transformers 迅速成为计算机视觉和自然语言处理领域中各种应用的最先进模型。其卓越的性能促成了多个最先进模型的开发,如 BERT 及其变体 distilBERT 和 RoBERTa。

BERT 在各种自然语言处理任务中优于旧的递归模型,如文本分类、命名实体识别(NER)、问答系统,甚至是我们在本文中将重点讨论的任务——语义文本相似度(STS)。

因此,在本文中,我们将探讨如何利用 Sentence Transformers 库训练 BERT 模型以执行 STS 任务。接下来,我们将使用训练好的模型来预测未知数据。但作为开始,我们需要首先了解 STS 任务的实际内容以及我们将用于该任务的数据集。

语义文本相似度及数据集

语义文本相似度(STS)指的是一个任务,我们比较一段文本与另一段文本的相似性。

图片由作者提供

我们从模型得到的 STS 任务输出通常是一个浮点数,表示比较的两个文本之间的相似度。

有几种方法可以量化一对文本之间的相似度。让我们以本文将使用的数据集为例,即 STSB 数据集(使用 CC-Share Alike 4.0 许可)。

!pip install datasets

from datasets import load_dataset
dataset = load_dataset("stsb_multi_mt", name="en", split="train")

print(dataset[0])
>>> {'sentence1': 'A plane is taking off.',
 'An air plane is taking off.',
 'similarity_score': 5.0}

print(dataset[1])
>>> {'sentence1': 'A man is playing a large flute.',
 'sentence2': 'A man is playing a flute.',
 'similarity_score': 3.799999952316284}

一对文本之间的相似性被标记为从 1 到 5 的数字;如果一对文本完全不相似,则标记为 1;如果一对文本在语义意义上完全相似,则标记为 5。

然而,有一个问题。当我们想要使用 Sentence Transformers 库训练 BERT 模型时,我们需要将相似性分数标准化到 0 到 1 之间。这可以通过简单地将每个相似性分数除以 5 来实现。

similarity = [i['similarity_score'] for i in dataset]
normalized_similarity = [i/5.0 for i in similarity]

现在我们知道了要使用的数据集,接下来让我们继续讨论本文中要使用的模型。

如何基于 Transformers 的模型衡量一对文本之间的相似性

基于 Transformers 的模型,如 BERT、distilBERT 或 RoBERTa,期望输入为标记序列。因此,第一步应该是将输入文本转换为标记序列。这个过程称为标记化。

BERT 模型的标记化过程包括两个步骤。首先,将输入文本分割成几个小块,称为标记;一个标记可以是一个词或一个子词。其次,在标记序列的开头和结尾添加两个特殊标记。这两个特殊标记是:

  • [CLS]: 这是每个标记序列中的第一个标记

  • [SEP]: 这个标记对于给 BERT 提供哪个标记属于哪个序列的提示是很重要的。如果只有一个标记序列,那么这个标记将是序列中的最后一个标记。

根据你事先定义的标记器的最大序列长度,一些 [PAD] 标记也会在 [SEP] 标记之后被追加。

标记化的输入将被传递到模型中,作为输出,我们将获得每个标记的嵌入向量。每个嵌入向量具有 768 个维度。

如果我们使用 BERT 进行分类目的,那么通常我们会取 [CLS] 标记的嵌入向量,并将其传递给最终的 softmax 或 sigmoid 层,这一层将作为分类器。

图片由作者提供

如果我们使用 BERT 进行 STS 任务,工作流程将类似于以下步骤:

图片由作者提供

使用上述工作流程,BERT 在 STS 基准测试中达到了最先进的性能。然而,这种工作流程有一个主要的缺点:可扩展性因素。

假设我们有一段全新的文本。接下来,我们希望在包含 10 万个不同文本的数据库中查询与这段新文本最相似的条目。如果我们使用上面提到的 BERT 架构,那么我们需要将新文本与数据库中的每一个条目进行 10 万次比较。这意味着需要进行 10 万次的标记化过程和前向传递。

这个可扩展性因素的主要问题是 BERT 输出的是每个标记的嵌入向量,而不是每个文本/句子的嵌入向量。

图片由作者提供

如果 BERT 能够给我们提供有意义的句子级嵌入,那么我们可以将每个条目的嵌入保存到我们的数据库中。一旦我们有了新的文本,我们只需通过余弦相似度将新文本的句子嵌入与数据库中每个条目的句子嵌入进行比较,这是一种更快的方法。

这是 Sentence BERT (SBERT) 试图解决的问题。你可以将 SBERT 视为通过应用孪生网络模型架构来微调的 BERT,如下所示:

作者提供的图像

上述架构的问题在于它仍然生成 token 级嵌入。因此,SBERT 在 BERT 的基础上实现了一个额外的池化层。SBERT 实现了三种不同的池化策略:

  • 使用 [CLS] token 的嵌入

  • 使用所有 token 级嵌入向量的均值(这是默认实现)

  • 使用时间上的最大 token 级嵌入向量

作者提供的图像

上述插图是 SBERT 模型的最终架构。我们在池化层之后得到的嵌入是一个具有 768 维度的文本向量。然后,可以通过成对距离或余弦相似度对这些嵌入进行比较,这正是 STS 任务的核心。

要实现 SBERT,我们可以使用 sentence-transformers 库。如果你还没有安装,可以通过 pip 安装:

!pip install sentence-transformers

现在我们将实现基于 BERT 的 SBERT 模型,但你也可以使用如 distilBERT 或 RoBERTa 等 BERT 变体来实现 SBERT,或者加载一个在特定数据集上预训练的模型。你可以在这里找到所有可用的模型

from sentence_transformers import SentenceTransformer, models

word_embedding_model = models.Transformer('bert-base-uncased', max_seq_length=128)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
sts_bert_model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

从上面的代码片段中,我们首先加载 BERT 模型作为我们的词嵌入模型,然后在 BERT 模型上应用一个池化层,最终获得句子级嵌入。

假设我们有一对句子,我们想要获取每个句子的句子级嵌入。我们可以通过以下步骤做到这一点:

!pip install transformers

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

sentence_1 = [i['sentence1'] for i in dataset]
sentence_2 = [i['sentence2'] for i in dataset]
text_cat = [[str(x), str(y)] for x,y in zip(sentence_1, sentence_2)][0]

input_data = tokenizer(text_cat, padding='max_length', max_length = 128, truncation=True, return_tensors="pt")
output = sts_bert_model(input_data)

print(output['sentence_embedding'][0].size())
>>> torch.Size([768])

print(output['sentence_embedding'][1].size())
>>> torch.Size([768])

语义文本相似度实现

在本节中,我们将对在前一节中讨论的 STS 任务数据集进行 SBERT 模型训练。

模型架构定义

首先定义模型架构。

import torch

class STSBertModel(torch.nn.Module):

    def __init__(self):

        super(STSBertModel, self).__init__()

        word_embedding_model = models.Transformer('bert-base-uncased', max_seq_length=128)
        pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
        self.sts_model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

    def forward(self, input_data):

        output = self.sts_model(input_data)

        return output

上述模型架构类似于我们在前一节中看到的内容。我们使用 BERT 基础模型作为我们的词嵌入模型。该模型的输出仍然是 token 级嵌入。因此,我们需要在其上添加一个池化层。

从我们的 SBERT 模型中得到的最终输出是 768 维的句子级嵌入向量。由于模型的输入是一对文本,因此输出也将是一对 768 维的句子级嵌入向量。

数据加载器

数据加载器是创建数据集批次所必需的。这个过程很重要,因为我们不能在训练过程中一次性将整个数据集输入到模型中。

class DataSequence(torch.utils.data.Dataset):

    def __init__(self, dataset):

        similarity = [i['similarity_score'] for i in dataset]
        self.label = [i/5.0 for i in similarity]
        self.sentence_1 = [i['sentence1'] for i in dataset]
        self.sentence_2 = [i['sentence2'] for i in dataset]
        self.text_cat = [[str(x), str(y)] for x,y in zip(self.sentence_1, self.sentence_2)]

    def __len__(self):

        return len(self.text_cat)

    def get_batch_labels(self, idx):

        return torch.tensor(self.label[idx])

    def get_batch_texts(self, idx):

        return tokenizer(self.text_cat[idx], padding='max_length', max_length = 128, truncation=True, return_tensors="pt")

    def __getitem__(self, idx):

        batch_texts = self.get_batch_texts(idx)
        batch_y = self.get_batch_labels(idx)

        return batch_texts, batch_y

def collate_fn(texts):

  num_texts = len(texts['input_ids'])
  features = list()
  for i in range(num_texts):
      features.append({'input_ids':texts['input_ids'][i], 'attention_mask':texts['attention_mask'][i]})

  return features

我们已经在上面的部分中了解了我们的数据集是什么样的以及如何准备它,以便模型可以用于 STS 任务。上面的代码正是做了这些事情:

  • 每对文本之间的相似度得分被归一化,这将作为模型训练的真实标签。

  • 每对文本都使用完全相同的分词器和步骤进行分词,这与我们在前一节中看到的一样。分词后的文本对将作为我们训练模型的输入。

上面的 collate_fn 函数在分词处理后将每对文本分组在一起,以便于批处理,这是一个重要的功能。

损失函数

在 STS 任务中,我们的目标是训练一个模型,使其能够根据语义意义区分相似和不相似的文本对。这意味着我们希望模型将不相似的文本对的距离推得更远,同时将相似的文本对的距离保持得更近。

我们可以使用几个常见的损失函数来实现这个目标:余弦相似度损失、三元组损失和对比损失。

通常,我们可以使用对比损失来处理这种情况。然而,对比损失期望我们的标签是二进制的,即如果文本对在语义上相似,标签为 1,否则为 0。与此同时,我们在这个数据集中获得的标签是一个范围在 0 到 1 之间的浮动数值,因此余弦相似度损失将是一个更好的损失函数。

class CosineSimilarityLoss(torch.nn.Module):

    def __init__(self,  loss_fct = torch.nn.MSELoss(), cos_score_transformation=torch.nn.Identity()):

        super(CosineSimilarityLoss, self).__init__()
        self.loss_fct = loss_fct
        self.cos_score_transformation = cos_score_transformation
        self.cos = torch.nn.CosineSimilarity(dim=1)

    def forward(self, input, label):

        embedding_1 = torch.stack([inp[0] for inp in input])
        embedding_2 = torch.stack([inp[1] for inp in input])

        output = self.cos_score_transformation(self.cos(embedding_1, embedding_2))

        return self.loss_fct(output, label.squeeze())

这个损失函数接收每个文本的句子级嵌入,然后计算两个嵌入之间的余弦相似度。结果是,损失函数将在向量空间中将不相似的文本对推得更远,同时将相似的文本对保持得更近。

模型训练

现在我们已经设置好了模型的架构、数据加载器和损失函数,是时候训练模型了。代码只是一个标准的 Pytorch 训练脚本,如下所示:

from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm

def model_train(dataset, epochs, learning_rate, bs):

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    model = STSBertModel()

    criterion = CosineSimilarityLoss()
    optimizer = Adam(model.parameters(), lr=learning_rate)

    train_dataset = DataSequence(dataset)
    train_dataloader = DataLoader(train_dataset, num_workers=4, batch_size=bs, shuffle=True)

    if use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    best_acc = 0.0
    best_loss = 1000

    for i in range(epochs):

        total_acc_train = 0
        total_loss_train = 0.0

        for train_data, train_label in tqdm(train_dataloader):

            train_data['input_ids'] = train_data['input_ids'].to(device)
            train_data['attention_mask'] = train_data['attention_mask'].to(device)
            del train_data['token_type_ids']

            train_data = collate_fn(train_data)

            output = [model(feature)['sentence_embedding'] for feature in train_data]

            loss = criterion(output, train_label.to(device))
            total_loss_train += loss.item()

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        print(f'Epochs: {i + 1} | Loss: {total_loss_train / len(dataset): .3f}')
        model.train()

    return model

EPOCHS = 8
LEARNING_RATE = 1e-6
BATCH_SIZE = 8

# Train the model
trained_model = model_train(dataset, EPOCHS, LEARNING_RATE, BATCH_SIZE)

在上面的实现中,我们将模型训练 8 个周期,学习率设置为 10e-6,批量大小设置为 8。这些是你可以调整以适应自己需求的超参数。

如果你运行上面的model_train函数,你将看到类似这样的训练进度:

图片由作者提供

模型预测

在训练完模型后,我们现在可以使用它来预测未见的数据,即未见的文本对。然而,在将模型输入未见的文本对之前,让我们创建一个函数,以便从模型中获取相似度预测。

# Load test data
test_dataset = load_dataset("stsb_multi_mt", name="en", split="test")

# Prepare test data
sentence_1_test = [i['sentence1'] for i in test_dataset]
sentence_2_test = [i['sentence2'] for i in test_dataset]
text_cat_test = [[str(x), str(y)] for x,y in zip(sentence_1_test, sentence_2_test)]

# Function to predict test data
def predict_sts(texts):

  trained_model.to('cpu')
  trained_model.eval()

  test_input = tokenizer(texts, padding='max_length', max_length = 128, truncation=True, return_tensors="pt")
  test_input['input_ids'] = test_input['input_ids']
  test_input['attention_mask'] = test_input['attention_mask']
  del test_input['token_type_ids']

  test_output = trained_model(test_input)['sentence_embedding']
  sim = torch.nn.functional.cosine_similarity(test_output[0], test_output[1], dim=0).item()

  return sim

上面的代码实现包括了数据的所有预处理步骤以及获取模型预测的步骤。

假设我们有一对类似的文本,如下所示:

print(text_cat_test[420])
>>> ['four children are playing on a trampoline.',
 'Four kids are jumping on a trampoline.']

print(predict_sts(text_cat_test[420]))
>>> 0.8608950972557068

现在我们只需调用predict_sts函数,就可以得到由我们的模型推断出的两个文本之间的余弦相似度。在这个例子中,我们得到的相似度为 0.860。这意味着我们的文本对彼此非常相似。

进行对比,现在让我们用一对不同的文本来测试模型。

print(text_cat_test[245])
>>> ['A man spins on a surf board.', 
'A man is putting barbecue sauce on chicken.']

print(predict_sts(text_cat_test[245]))
>>> 0.05531075596809387

如上所示,当我们处理一对不同的文本时,相似度仅为 0.055,这意味着在向量空间中两个文本的嵌入彼此相距较远。这正是我们的模型经过训练的目的。

结论

在这篇文章中,我们实现了一个用于语义文本相似度任务的 BERT 模型。具体来说,我们使用了 Sentence-Transformers 库来微调 BERT 模型为 Siamese 架构,从而能够获取每个文本的句子级嵌入。然后,可以通过余弦相似度比较每个文本的句子级嵌入。

你可以在这个笔记本中找到本文实现的所有代码。

预测建模中的敏感性:用更少的流量购买付费客户的指南

原文:towardsdatascience.com/sensitivity-in-predictive-modeling-a-guide-to-buying-paying-customers-with-less-traffic-c2ab97f6d629?source=collection_archive---------14-----------------------#2023-02-20

通过定义和评估模型敏感性,发现一种经济高效的广告活动策略,并提供逐步指南和 Python 实现

Dina BavliTowards Data Science Dina Bavli

·

关注 发表在 Towards Data Science · 7 分钟阅读 · 2023 年 2 月 20 日

--

图片由 Joey Kyber 提供,来源于 Unsplash

本博客文章概述了一种利用付费流量的广告策略。目标是以最小的流量获取付费客户,同时最大化效率。预测建模用于评估和提升模型在实现这一目标上的效果。通过定义和分析模型敏感性,公司可以在节省成本的同时实现预期结果。本文提供了 Python 实现和详细的逐步指南。

我们将涵盖以下内容:

· 介绍

· 理解业务中的混淆矩阵

· 与我谈 Python

· 这是完整代码

· 总结

介绍

以较少流量购买付费客户是使用付费流量进行广告的公司面临的常见挑战。目标是通过减少流量同时尽可能获得更多购买客户,从而提高效率。实现这一目标的一种方法是使用预测建模来评估和优化模型的性能。

预测建模涉及使用统计技术根据历史数据预测未来事件或结果。在此背景下,目标是预测哪些客户可能会购买,以便公司可以将广告投放目标对准这些客户。

评估预测模型性能时,可以使用混淆矩阵。混淆矩阵是一种表格,用于定义分类算法的性能,特别是在评估二分类模型时非常有用,如我们讨论的模型。矩阵比较模型的预测结果与实际结果。

评估二分类模型性能的常用指标之一是召回率。召回率是模型预测为购买客户的次数与实际购买客户的数量之比。换句话说,它衡量模型识别正例的能力,在我们的例子中即购买客户。

另一个重要的指标是阈值。阈值是将预测结果视为正例的点。提高阈值会增加假阳性的数量,从而降低精确度。而降低阈值则会增加假阴性的数量,从而降低召回率。

精确度和召回率之间的平衡被称为权衡。找到最佳阈值以最大化召回率同时最小化精确度非常重要,从而在减少流量的情况下获得尽可能多的付费客户。

在本博客文章中,我们将讨论一种以较少流量购买付费客户的策略。通过定义模型敏感性来评估模型,公司可以在实现预期结果的同时节省资金。

理解业务中的混淆矩阵

在商业预测建模中,拥有一个能够准确识别购买客户的模型至关重要,因为购买客户通常是稀有而有价值的群体。衡量分类算法准确性的一种方法是使用混淆矩阵。

混淆矩阵是一个表格,通过比较预测值和实际值,总结分类模型的性能。

二分类混淆矩阵中的四个类别是:

  • 真正正例(TP):模型正确预测为正例的正例数量。

  • 假正例(FP):模型错误地将负例预测为正例的数量。

  • 真正负例(TN):模型正确预测为负例的负例数量。

  • 假负例(FN):模型错误地将正例预测为负例的数量。

由作者创建,使用 excalidraw.com/

真正正例(TP)、假正例(FP)、假负例(FN)和真正负例(TN)。TP 表示模型预测为购买客户且预测准确的次数,而 FN 表示模型漏掉的购买客户次数。FP 表示模型预测为非购买客户但预测错误的次数,而 TN 表示模型预测为非购买客户且预测正确的次数。

由作者创建,使用 editor.codecogs.com/

召回率指标,也称为敏感性或真正正例率,衡量模型正确识别的实际购买客户的比例。计算公式为 TP/(TP+FN),表示模型预测为购买客户且预测正确的次数,除以实际购买客户的总数。

由作者创建,使用 editor.codecogs.com/

除了测量模型对购买客户的敏感性外,混淆矩阵还可以提供有关从特定阈值可以预期的流量和购买客户的洞察。通过计算 (FN + TP)/(TN + FP + FN + TP),可以确定在特定阈值下,模型将正确识别的购买客户占所有客户的百分比。

然而,需要注意的是,提高阈值会增加假正例,从而降低精确度。平衡模型的敏感性和精确度的一种方法是设置所需的支付客户百分比,并根据特定模型计算达到该百分比的阈值。

理解混淆矩阵及其指标可以为商业中预测模型的表现提供宝贵的见解,特别是在识别稀有且有价值的客户群体时。通过分析混淆矩阵,企业可以优化其模型,并做出数据驱动的决策,从而获得更好的结果。

Talk Python To Me

机器学习模型通过各种指标(如准确率、精确度和召回率)进行评估。在某些情况下,实现特定的召回水平比最大化准确率更为重要。在这篇文章中,我们将通过 Python 代码展示如何基于期望的召回水平评估模型。

问题:假设我们有一个二分类问题,我们希望预测用户是否会购买产品。数据集包含 200,000 条记录,其中 30,630 条为正例,169,070 条为负例。我们的目标是训练一个能够以高召回率预测用户是否会购买产品的模型。

解决方案:我们可以使用以下 Python 函数来评估模型在期望召回水平下的性能:

  1. extract_threshold_given_recall(y_test, probabilities, given_recall) 该函数接受三个输入:
  • y_test:测试集的目标值

  • probabilities:测试集的预测概率

  • given_recall:期望的召回水平

该函数使用y_test和概率计算精度-召回曲线,并返回给定召回值的阈值。

  1. get_model_results_for_recall(model, X_test, y_test, X_train, y_train, given_recall, with_plots=True) 该函数接受六个输入:
  • model:训练好的机器学习模型

  • X_test:测试集的特征矩阵

  • y_test:测试集的目标值

  • X_train:训练集的特征矩阵

  • y_train:训练集的目标值

  • given_recall:期望的召回水平

函数首先使用模型计算测试集的预测概率。然后,它使用extract_threshold_given_recall函数计算 ROC 曲线和期望召回值的最佳阈值。最后,它计算混淆矩阵、分类报告、FPR、AUC、准确度分数、最佳阈值和购买流量。可选地,函数还可以绘制 ROC 曲线。

输出将如下所示 👇

作者截屏

下面是完整代码:

总结

在本文中,我们看到了如何评估机器学习模型的性能,以达到所需的召回率水平。通过通过定义模型的敏感度来评估模型,公司可以在节省金钱的同时仍实现他们购买付费客户的期望结果,而不需要增加流量。我们提供了一个 Python 实现,可以帮助通过找到最大化召回率的最佳阈值来实现这一过程。最大化召回率可以减少购买未付费客户,因为召回率是衡量实际正例(即付费客户)被预测模型正确识别为正例的比例的指标。通过优化模型以最大化召回率,模型更擅长识别付费客户,这意味着公司可以避免购买不太可能产生付费客户的流量。这可以降低客户获取成本,并提高公司广告预算的效率。

句子变换器:伪装中的意义

原文:towardsdatascience.com/sentence-transformers-meanings-in-disguise-323cf6ac1e52

NLP 语义搜索

现代语言模型如何捕捉意义

James BriggsTowards Data Science James Briggs

·发表于Towards Data Science ·阅读时间 12 分钟·2023 年 1 月 3 日

--

图片来源:Brian SuhUnsplash。最初发布在Pinecone 的 NLP 语义搜索电子书中(作者在此工作)。

变换器完全重塑了自然语言处理(NLP)的格局。在变换器出现之前,得益于递归神经网络(RNNs),我们有了还不错的翻译和语言分类——它们的语言理解能力有限,导致许多小错误,大块文本的一致性几乎是不可能的。

自 2017 年论文《Attention is all you need》 [1]首次引入变换器模型以来,NLP 从 RNNs 发展到 BERT 和 GPT 等模型。这些新模型可以回答问题、撰写文章(也许是 GPT-3 写的)、实现非常直观的语义搜索——还有更多。

有趣的是,对于许多任务,这些模型的后期部分与 RNN 中的部分相同——通常是几个前馈神经网络,用于输出模型预测结果。

改变的是这些层的输入。变换器模型创建的密集嵌入信息量要丰富得多,尽管使用相同的最终外层,我们仍然获得了巨大的性能提升。

这些越来越丰富的句子嵌入可以用来快速比较各种用例中的句子相似性。例如:

  • 语义文本相似度 (STS) — 比较句子对。我们可能想要在数据集中识别模式,但这通常用于基准测试。

  • 语义搜索——使用语义意义的信息检索(IR)。给定一组句子,我们可以使用‘查询’句子进行搜索,并识别最相似的记录。使搜索能够基于概念(而非特定词汇)进行。

  • 聚类——我们可以对句子进行聚类,这对主题建模很有用。

在本文中,我们将探讨这些嵌入是如何被调整和应用于各种语义相似性应用的,通过使用一种叫做‘句子变换器’的新型变换器。

一些“上下文”

在我们深入探讨句子变换器之前,理解变换器嵌入为何如此丰富可能会有所帮助——以及普通变换器句子变换器之间的差异所在。

变换器是之前 RNN 模型的间接后代。这些旧的递归模型通常由许多递归单元构成,如 LSTMs 或 GRUs。

机器翻译中,我们会找到编码器-解码器网络。第一个模型用于编码原始语言到上下文向量,第二个模型用于将其解码成目标语言。

具有单一上下文向量共享的编码器-解码器架构在两个模型之间,这作为一个信息瓶颈,因为所有信息必须通过这一点传递。

这里的问题是我们在两个模型之间创建了一个信息瓶颈。我们在多个时间步骤中创建了大量信息,并试图通过一个连接将所有信息挤压过来。这限制了编码器-解码器的性能,因为编码器产生的许多信息在到达解码器之前就已经丢失。

注意机制为瓶颈问题提供了一个解决方案。它提供了信息传递的另一条途径。然而,它并没有让过程变得复杂,因为它仅仅关注最相关的信息。

通过将每个时间步的上下文向量传递到注意机制中(生成注释向量),去除了信息瓶颈,并在较长序列中有更好的信息保留。

带有注意机制的编码器-解码器。注意机制考虑了所有编码器输出激活和解码器中每个时间步的激活,这些都会修改解码器输出。

在解码过程中,模型一次解码一个单词/时间步。每一步都会计算单词与所有编码器注释之间的对齐度(例如,相似度)。

更高的对齐度导致了对解码器步骤输出的编码器注释的更大加权。这意味着机制计算了需要关注哪些编码器单词。

英法编码器和解码器之间的注意力,来源 [2]。

所有表现最好的 RNN 编码器-解码器都使用了这种注意力机制。

Attention is All You Need

在 2017 年,一篇题为Attention Is All You Need的论文发表了。这标志着 NLP 的一个转折点。作者们展示了我们可以去除 RNN 网络,并仅使用注意力机制——经过一些修改,就能获得更好的性能。

这个基于注意力的新模型被称为‘transformer’。从那时起,由于其极其优越的性能和出色的泛化能力,NLP 生态系统完全从 RNN 转向了 transformers。

第一个 transformer 通过使用三个关键组件去除了对 RNN 的需求:

  • 位置编码

  • 自注意力

  • 多头注意力

位置编码取代了 RNN 在 NLP 中的关键优势——考虑序列顺序的能力(它们是递归的)。它通过根据位置向每个输入嵌入添加一组变化的正弦波激活来工作。

自注意力是指在一个词与其自身上下文(句子/段落)中的所有其他词之间应用注意力机制。这不同于普通的注意力,它专注于编码器和解码器之间的注意力。

多头注意力可以看作是几个并行的注意力机制共同工作。使用多个注意力允许表示多个关系集(而不是单一的关系集)。

预训练模型

新的 transformer 模型的泛化能力远远超过了以前的 RNN,这些 RNN 通常是为每个用例特别构建的。

使用 transformer 模型,可以使用相同的‘核心’模型,并仅替换最后几层以适应不同的用例(而无需重新训练核心)。

这种新特性导致了NLP预训练模型的兴起。预训练的 transformer 模型是在大量训练数据上训练的——通常由谷歌或 OpenAI 等公司高成本训练,然后免费提供给公众使用。

这些预训练模型中最广泛使用的之一是 BERT,或谷歌 AI 的Bidirectional Encoder Representations from Transformers。

BERT 产生了一系列进一步的模型和变种,如 distilBERT、RoBERTa 和 ALBERT,涵盖分类、问答、词性标注等任务。

BERT 用于句子相似度

到目前为止,一切都很好,但这些 transformer 模型在构建句子向量时存在一个问题:Transformers 使用的是词或token级别的嵌入,而不是句子级别的嵌入。

在句子 transformers 出现之前,使用 BERT 计算准确的句子相似度的方法是使用交叉编码器结构。这意味着我们将两个句子传递给 BERT,在 BERT 顶部添加一个分类头——并用它来输出相似度评分。

BERT 交叉编码器架构包括一个 BERT 模型,该模型处理句子 A 和 B。这两个句子在同一序列中处理,由 [SEP] 标记分隔。随后是一个前馈神经网络分类器,输出相似度评分。

交叉编码器网络确实生成非常准确的相似度评分(比 SBERT 更好),但它的可扩展性差。如果我们想在一个 100K 句子的数据库中进行相似度搜索,我们需要完成 100K 次交叉编码器推断计算。

要对句子进行聚类,我们需要比较我们 100K 数据集中的所有句子,结果将产生接近 5 亿次比较 —— 这显然是不现实的。

理想情况下,我们需要预先计算句子向量,这些向量可以被存储,然后在需要时使用。如果这些向量表示良好,我们只需计算每对句子之间的余弦相似度。

使用原始的 BERT(及其他变换器),我们可以通过对 BERT 输出的所有标记嵌入进行平均来构建句子嵌入(如果我们输入 512 个标记,则输出 512 个嵌入)。另一种方法是使用第一个 [CLS] 标记的输出(一个特定于 BERT 的标记,其输出嵌入用于分类任务)。

使用这两种方法中的一种可以得到我们的句子嵌入,这些嵌入可以被存储并且比较速度更快,将搜索时间从 65 小时缩短到大约 5 秒(见下文)。然而,准确性不佳,甚至比使用平均的 GloVe 嵌入(该方法开发于 2014 年)更差。

这个解决方案 旨在解决缺乏准确模型的合理延迟问题,由 Nils Reimers 和 Iryna Gurevych 在 2019 年设计,推出了 sentence-BERT(SBERT)和 sentence-transformers 库。

SBERT 在所有常见的语义文本相似性(STS)任务中表现优于之前的最先进(SOTA)模型 —— 更多关于这些任务的内容见下文 —— 唯一的例外是一个数据集(SICK-R)。

幸运的是,为了实现可扩展性,SBERT 生成句子嵌入 —— 因此我们需要对每个句子对比较执行整个推断计算。

Reimers 和 Gurevych 在 2019 年展示了显著的速度提升。从 10K 个句子中找到最相似的句子对,使用 BERT 需要 65 小时。使用 SBERT,嵌入的创建时间约为 5 秒,与余弦相似度的比较时间约为 0.01 秒。

自 SBERT 论文发布以来,已经构建了许多使用类似概念的句子变换器模型,这些概念用于训练原始的 SBERT。它们都在许多相似和不相似的句子对上进行了训练。

使用诸如 softmax 损失、多负样本排序损失或 MSE 边际损失等损失函数,这些模型被优化以生成相似句子的相似嵌入和不相似句子的不同嵌入。

现在你已经了解了一些关于句子变换器的背景知识,包括它们的来源及其必要性。让我们深入探讨它们是如何工作的。

*[3] SBERT 论文涵盖了本节中的许多陈述、技术和数据。

句子转换器

我们解释了使用 BERT 的跨编码器架构来衡量句子相似性。SBERT 类似,但去掉了最终的分类头,并且一次处理一个句子。SBERT 然后在最终输出层上使用均值池化来生成句子嵌入。

与 BERT 不同,SBERT 在句子对上使用siamese架构进行微调。我们可以将其视为两个并行的完全相同的 BERT,分享完全相同的网络权重。

SBERT 模型应用于句子对句子 A句子 B。注意,BERT 模型输出的是令牌嵌入(由 512 个 768 维向量组成)。我们随后使用池化函数将这些数据压缩成一个 768 维的句子向量。

实际上,我们使用的是单个 BERT 模型。然而,由于我们在训练期间将句子 A 和句子 B 作为进行处理,因此更容易将其视为具有相同权重的两个模型。

Siamese BERT 预训练

训练句子转换器有不同的方法。我们将描述最初的 SBERT 过程,该过程主要优化softmax-loss。请注意,这是一个高层次的解释,我们将把深入讲解留到另一篇文章中。

softmax-loss 方法使用‘siamese’架构,在斯坦福自然语言推理(SNLI)和多领域 NLI(MNLI)语料库上进行微调。

SNLI 包含 570K 句子对,MNLI 包含 430K。这两个语料库中的句子对都包含一个前提和一个假设。每对句子被分配一个三种标签之一:

  • 0蕴含,例如,前提暗示了假设

  • 1中立前提假设都可能为真,但它们不一定相关。

  • 2矛盾前提假设互相矛盾。

给定这些数据,我们将句子 A(假设为前提)输入到 siamese BERT A 中,将句子 B(假设)输入到 siamese BERT B 中。

siamese BERT 输出我们的池化句子嵌入。SBERT 论文中有三种不同的池化方法结果。这些方法是均值最大值[CLS]-池化。均值池化方法在 NLI 和 STSb 数据集中表现最佳。

现在有两个句子嵌入。我们将嵌入 A 称为u,嵌入 B 称为v。下一步是拼接uv。再次,测试了几种拼接方法,但表现最好的方法是(u, v, |u-v|)操作:

我们将嵌入uv|u — v|进行拼接。

|u-v|的计算结果给出两个向量的逐元素差异。除了原始的两个嵌入(uv),这些都被输入到一个具有三个输出的前馈神经网络(FFNN)中。

这三个输出与我们的 NLI 相似性标签012对齐。我们需要从我们的 FFNN 计算 softmax,这在交叉熵损失函数中完成。softmax 和标签用于优化这个‘softmax-loss’

这些操作是在训练过程中对两个句子嵌入uv进行的。注意,softmax-loss 指的是交叉熵损失(默认情况下包含一个 softmax 函数)。

这导致我们对于相似句子(标签0)的池化句子嵌入变得更加相似,而对于不相似句子(标签2)的嵌入变得不那么相似

请记住,我们使用的是siamese BERT,而不是dual BERT。这意味着我们不使用两个独立的 BERT 模型,而是使用一个处理句子 A 然后处理句子 B 的单个 BERT。

这意味着当我们优化模型权重时,它们会朝着一个方向推动,使模型在看到蕴含标签时输出更多相似的向量,而在看到矛盾标签时输出更多不相似的向量。

这种训练方法有效的事实并不是特别直观,实际上 Reimers 曾描述过它偶然产生了良好的句子嵌入[5]。

自原始论文以来,这个领域有了进一步的研究。已经建立了许多更多的模型,如最新的 MPNet 和 RoBERTa 模型(在超过 1B 样本上训练)(表现更佳)。我们将在未来的文章中探讨其中的一些模型及其使用的优越训练方法。

现在,让我们看看如何初始化和使用一些这些句子变换器模型。

开始使用句子变换器

开始使用句子变换器的最快和最简单的方法是通过 SBERT 创建的sentence-transformers库。我们可以通过pip安装它。

pip install sentence-transformers

我们将从原始的 SBERT 模型bert-base-nli-mean-tokens开始。首先,我们下载并初始化模型。

我们可以看到的输出是SentenceTransformer对象,它包含了三个组件:

  • 变换器本身,我们可以看到最大序列长度为128个标记,并且是否对任何输入进行小写处理(在这种情况下,模型进行小写处理)。我们还可以看到模型类BertModel

  • 池化操作,在这里我们可以看到我们正在生成一个768维的句子嵌入。我们使用的是均值池化方法。

一旦我们有了模型,使用encode方法可以迅速构建句子嵌入。

现在我们有了句子嵌入,可以用来快速比较句子相似性,用于文章开头介绍的用例:STS、语义搜索和聚类。

我们可以仅使用余弦相似度函数和 Numpy 快速编写一个 STS 示例。

热图显示了所有句子对之间的余弦相似度值。

在这里,我们计算了五个句子嵌入之间每种组合的余弦相似度。它们是:

我们可以看到右下角的最高相似度分数为 0.64。正如我们所希望的,这一结果是针对描述使用建筑材料进行不良牙科实践的句子 43 的。

其他句子转换器

尽管我们从 SBERT 模型中得到了良好的结果,但自那以后已经构建了许多其他句子转换器模型。我们可以在 sentence-transformers 库中找到许多这样的模型。

这些更新的模型在性能上可以显著超过原始的 SBERT。事实上,SBERT 不再列为 SBERT.net 模型页面上的可用模型。

在句子转换器模型页面上的一些顶级表现模型。

我们将在未来的文章中更详细地介绍一些这些较新的模型。现在,让我们比较一下表现最好的模型,并运行我们的 STS 任务。

这里我们有 SentenceTransformer 模型 all-mpnet-base-v2。这些组件与 bert-base-nli-mean-tokens 模型非常相似,只是有一些小差异:

  • max_seq_length128 增加到了 384。这意味着我们可以处理的序列长度是使用 SBERT 时的 三倍

  • 基础模型现在是 MPNetModel [4] 而不是 BertModel

  • 对句子嵌入应用了额外的归一化层。

让我们比较一下 all-mpnet-base-v2 和 SBERT 的 STS 结果。

SBERT 和 MPNet 句子转换器的热图。

后期模型的语义表示非常明显。虽然 SBERT 正确地识别 43 为最相似的对,但它也对其他句子对赋予了相当高的相似度。

另一方面,MPNet 模型在相似对和不相似对之间做出了 非常 清晰的区分,大多数对的得分低于 0.1,而 4-3 对的得分为 0.52

通过增加不相似对和相似对之间的分离,我们正在:

  1. 使得自动识别相关对变得更加容易。

  2. 将预测结果推向训练期间使用的 01 目标分数,使 不相似相似 对的分数更加接近。这是我们将在未来的文章中关于微调这些模型时看到的内容。

这就是本文介绍句子嵌入和当前 SOTA 句子转换器模型的全部内容,这些模型用于构建这些极其有用的嵌入。

句子嵌入,虽然最近才流行开来,但却是从一系列出色的创新中产生的。我们描述了一些应用于创建第一个句子转换器 SBERT 的机制。

我们还展示了尽管 SBERT 于 2019 年才刚刚推出,但其他句子变换器已经超越了该模型。幸运的是,通过sentence-transformers库,我们可以轻松地将 SBERT 替换为这些更新的模型。

参考文献

[1] A. Vashwani 等,注意力机制全在于此(2017 年),NeurIPS

[2] D. Bahdanau 等,通过共同学习对齐和翻译的神经机器翻译(2015 年),ICLR

[3] N. Reimers, I. Gurevych,Sentence-BERT:使用 Siamese BERT 网络的句子嵌入(2019 年),ACL

[4] MPNet 模型,Hugging Face 文档

[5] N. Reimers,自然语言推断,GitHub 上的 sentence-transformers

所有图片均为作者提供,除非另有说明

情感分析与时间序列文本数据中的结构性断裂

原文:towardsdatascience.com/sentiment-analysis-and-structural-breaks-in-time-series-text-data-8109c712ca2

Arabica 现在提供结构性断裂和情感分析模块,以丰富时间序列文本挖掘

Petr KorabTowards Data Science Petr Korab

·发表于 Towards Data Science ·阅读时间 7 分钟·2023 年 3 月 6 日

--

图片由 Adam Śmigielski 提供,来源于 Unsplash

介绍

文本数据包含大量定性信息,可以通过各种方法进行量化,包括情感分析技术。这些模型用于识别、提取和量化文本数据中的情感,并广泛应用于商业和学术研究。由于文本通常以时间序列的形式记录,因此文本数据集可能会显示出结构性断裂,因为定量信息可能因多种因素而变化。

作为业务分析师,衡量客户对特定品牌的感知变化可能是关键任务之一。在研究角色中,可能会关注弗拉基米尔·普京公开声明随时间的变化。Arabica 是一个专门设计用于处理类似问题的 Python 库。它包含以下时间序列文本数据集的探索性分析方法:

  • arabica_freq 用于描述性 n-gram 基础的探索性分析(EDA)

  • cappuccino 是一个可视化模块,包括 热力图词云折线图,用于 unigram、bigram 和 trigram 频率

  • coffee_break 实现情感和结构性断裂分析。

本文将介绍Coffee-break,情感与结构断点分析模块。阅读 文档 和这两个方法的教程:arabica_freq、cappuccino

编辑 2023 年 7 月:Arabica 已更新。请查看 文档 以获取完整的参数列表。

2. Coffee_break: 算法和结构

coffee_break 模块具有简单的后端架构。以下是它的工作原理示意图:

图 1. Coffee_break 架构。来源:draw.io

原始文本使用 cleantext 进行清理,包括标点符号和数字的清除。停用词(语言中最常见但意义不大的词)在预处理步骤中不会被移除,因为它们不会对情感分析产生负面影响。然而,使用 skip 参数,我们可以移除一系列附加的停用词或不需要的字符串(词或词组),这些不会影响情感分析。

情感分析 实现了 VADERValence Aware Dictionary and Sentiment Reasoner),一种通用的预训练情感分类器 [1]。它是在 Twitter 的社交媒体数据上训练的,但也非常适用于其他类型的数据集。我之前的文章 提供了关于模型和 Python 编码的更详细介绍。

Coffee_break 使用 VADER 的复合指标进行情感评估。总体情感的计算方法如下:

其中 t 是聚合周期。总指标范围为 [-1: 1],正情感接近 1,负情感接近 -1。

聚合的情感创建了一个时间序列,显示出一定程度的时间变化。结构性断点 在时间序列中通过 Fisher-Jenks 算法Jenks 优化方法 识别,最初由 George F. Jenks 提出 [2]。

这是一种基于聚类的方法,旨在找到将值最佳安排到不同类别(聚类)中的方法。jenks_breaks 函数通过 jenkspy 库实现,返回一个值列表,这些值对应于类别的边界。这些结构性断点在图中标记为垂直线,直观地指示了文本数据时间序列中的断点。

实施的库 包括 Matplotlib(可视化)、vaderSentiment(情感分析)和 jenkspy(结构性断裂)。PandasNumpy 进行处理操作。

3. 使用案例:Twitter 情感分析

让我们在辉瑞疫苗推文数据集上展示编码,该数据集使用 Twitter API 收集。数据包含 11,000 条关于辉瑞与 BioNTech 疫苗的推文,发布于 2006 年至 2021 年之间。该数据集根据CC0: 公共领域许可发布,符合Twitter 开发者政策

数据包含大量标点符号和数字,在进行任何进一步的步骤之前需要清理:

图 2. 辉瑞疫苗推文数据集

coffee_break 方法的参数是:

def coffee_break(text: str,                 # Text column
                 time: str,                 # Time column
                 date_format: str,          # Date format: 'eur' - European, 'us' - American
                 time_freq: int ='',        # Aggregation period: 'Y'/'M'
                 preprocess: bool = False,  # Clean data from numbers and punctuation
                 skip: [] ,                 # Remove additional stop words
                 n_breaks: int =''          # Number of breaks: min. 2
)

3. 1. 随时间变化的情感分析

我们的数据涵盖了 15 年的时间跨度,包括 Covid-19 危机。关于疫苗接种的公众情绪变化、关于疫苗的虚假新闻以及许多其他因素预计会导致情感随时间发生显著变化。

编码

首先,导入coffee_break

from arabica import coffee_break

Arabica 读取 美国风格(MM/DD/YYYY欧洲风格 (DD/MM/YYYY) 日期和时间格式。数据相当原始,涵盖了 15 年。因此,以月份显示情感并不是很有帮助。

让我们使用以下代码清理数据并按年份汇总情感:

coffee_break(text = data['text'],
             time = data['date'],
             date_format = 'eur',  # Read dates in European format
             time_freq = 'Y',      # Yearly aggregation
             preprocess = True,    # Clean data - punctuation + numbers
             skip = None ,         # No other stop words removed
             n_breaks = None)      # No structural break analysis

结果

Arabica 返回一张可以手动保存为 PNG 或 JPEG 的图片。

图 3. 情感分析 — 每年

与此同时,Arabica 返回一个包含相应数据的数据框。只需将函数分配给一个对象即可保存表格:

# generate a dataframe
df = coffee_break(text = data['text'],
                  time = data['date'],
                  date_format = 'eur',
                  time_freq = 'Y',
                  preprocess = True,
                  skip = None ,
                  n_breaks = None)

# save is as a csv
df.to_csv('sentiment_data.csv')

结果解读: 我们可以看到,情感在 2021 年辉瑞疫苗开始用于对抗 Covid 后显著下降(图 2)。原因可能是全球大流行以及这些年普遍的负面情绪。

3.2. 结构性断裂分析

接下来,让我们在统计上形式化情感中的结构性断裂。coffee_break 能够识别最少 2 个断点。以下代码返回一张带有 3 个断点的图,并附有相应时间序列的表格:

coffee_break(text = data['text'],
             time = data['date'],
             date_format = 'eur', # Read dates in European format
             time_freq = 'Y',     # Yearly aggregation
             preprocess = True,   # Clean data
             skip = None,         # No other stop words removed
             n_breaks = 3)        # 3 breaktpoints

图:

图 4. 结构性断裂分析 — 每年

将数据子集化到两个 Covid 年份(2020–2021),我们可以观察到公众情绪的月度变化,保持 n_breaks = 3 并设置 time_freq = 'M'

图 5. 结构性断裂分析 — 每月

该图表信息不够丰富。这个子集中有 1577 行数据和 24 次时间观测,清理原始数据后,时间序列非常波动。使用基于聚类的算法对如此有限的数据量做出结论并不是一个好主意。

结果解释: 年度频率的结构性断裂分析在统计上确认了我们从图 3 中的情感时间序列中看到的情况。Fisher-Jenks 算法识别出了三个结构性断裂点——分别在 2009 年、2017 年和 2021 年。我们只能猜测 2009 年以及 2016 年至 2018 年之间的下降原因。2021 年的下降可能是由于 Covid-19 危机造成的。

4. 结构性断裂分析的最佳实践

让我们总结一下coffee_break的最有效使用建议:

  • 如果对应的时间序列中存在 NAN 值,请不要使用结构性断裂分析。

  • 在较长时间序列中(至少 12 次观测),识别 3 个以上的断裂点是有意义的。

  • 断裂点识别在高度波动的数据集中可能效果不好。剧烈变化的原因可能不是情感的变化,而是数据的质量。

  • 分析的准确性仅与基础情感数据的质量有关。在实际使用之前,简要探索原始文本数据集,以检查 (1) 是否在每个时期的行数上不太不平衡,以及 (2) 是否包含足够的情感评估信息(文本不太短,且不包含大多数数字和特殊字符)。

作者提供的图片

结论

coffee_break 的一个缺点是目前它仅适用于英文文本。由于 Arabica 主要是一个基于 Pandas 的包(包括某些部分的 Numpy 向量化),coffee_break 在评估大型数据集时相对较慢。它在处理最多约 40 000 行的数据集时比较高效。

阅读这些教程,了解更多关于 n-gram 和情感分析以及时间序列文本数据可视化的信息:

Coffee_break 是与布尔诺理工大学的 Jitka Poměnková教授合作开发的。这个教程中的完整代码在我的GitHub上。

你喜欢这篇文章吗?你可以邀请我 喝咖啡 来支持我的写作。你也可以订阅我的 电子邮件列表 以便接收到我新文章的通知。谢谢!

照片由Content Pixie提供,来源于Unsplash

参考文献

[1] Hutto, C., Gilbert, E. (2014). VADER: 一种简约的基于规则的社交媒体文本情感分析模型国际 AAAI 网络与社交媒体会议论文集8(1),216–225。

[2] Jenks, G.F. (1977). 最优数据分类用于分层地图。堪萨斯大学地理气象系,临时论文第 2 号。

九月还是“Septemquake”?用 R 分析和可视化墨西哥的地震活动数据

原文:towardsdatascience.com/september-or-septemquake-analysis-and-visualization-of-seismic-activity-data-in-mexico-with-r-f051be5f1fb

如何使用 SSN(国家地震学服务)数据分析和可视化墨西哥的地震历史

Saúl BuentelloTowards Data Science Saúl Buentello

·发布于 Towards Data Science ·13 分钟阅读·2023 年 9 月 21 日

--

当我在去年 9 月 30 日开始撰写这篇文章时,又一个月结束了,但这不是一个普通的月份。对许多墨西哥人来说,这个特别的月份常常让人忧虑,因为这个月份经常见证了我们国家因地震而动摇的情景,通常这些地震的强度都很大。本文旨在通过数据分析和可视化,为读者提供有关墨西哥地震历史的有价值的见解。虽然它不做预测或制定政策,但它提供了对地震趋势和模式的深入了解。通过获得这些知识,读者可以更好地为地震事件做好准备,并为建筑和灾难准备方面的知情决策做出贡献。

一个特别的日期突显出来,并激励了这篇文章:9 月 19 日。在 1985 年这一天,墨西哥经历了有记录以来最具破坏性的地震,震中在里氏 8.1 级。大约有 40,000 人遇难,近 4,000 人从由于构造运动而坍塌的无数建筑物和房屋的废墟中被救出。

图片来源 Mr. DimentioWikimedia Commons

自那时起已经过去了 32 年,2017 年,墨西哥人再次在 9 月 19 日被 7.1 级里氏震级的地震所震惊,这让许多经历过 1985 年灾难的人再次感受到了旧伤。在 2017 年之前,我无法理解为什么我的父母在听到墨西哥城安装的地震警报声时会如此紧张,这种警报至少可以让我们提前几秒钟预知即将发生的事情。在 2017 年之前,我对警报声并不认真对待,但在经历了在九楼(我曾经工作的地方)的地震后,从那里你可以看到附近建筑物倒塌时的尘土云,听到恐惧的人们的统一尖叫,同时建筑物发出如同要断裂的声音,震动把你摔倒在地,这一切终身难忘。

## 2017 年 9 月 19 日在墨西哥城的地震 #地震 #earthquake #Sismo #temblor

编辑描述

www.youtube.com

当天在墨西哥录得近 400 人遇难,7000 多人受伤。在地震发生时,我的女友(现在是我们美丽小宝贝的母亲)和我紧紧相拥,忘记了那天早晨的争执。从那时起,每当听到地震警报声,都会让我浑身发抖,就像现在生活在墨西哥城的许多人一样。

不管怎样,我们来到了 2022 年,去年 9 月 19 日下午 1 点左右,我们再次被响遍全城的地震警报惊醒,预示着即将发生的是 7.7 级里氏震级的地震。尽管为了防范这样的事件建造了更坚固、更雄伟的房屋和建筑,但自然力量永远没有免疫保证。这次事件中,至少有两人报告遇难,3000 多处财产受到影响。这让许多人产生了疑问:这个日期有什么特别之处吗?为什么墨西哥在这个特定的日子会发生强震?这会不会是阿兹特克神灵要求牺牲的惩罚?

— 这是发生在九月。不是今年的九月,而是去年的。还是说是前年,梅利顿?

— 不,那是在去年。

— 是的,我现在记起来了。是在去年九月二十一号。嘿,梅利顿,九月二十一号不是正是地震那天吗?

— 那是稍早一点。我认为是在十八号。

— 你说得对。我那时在图斯卡库斯科(Tuzcacuexco)。我甚至看到房屋像太妃糖一样倒塌,扭曲并做出各种面部表情,然后整个墙壁轰然倒下。

… 摘自 1955 年墨西哥作家胡安·鲁尔福(Juan Rulfo)的故事《崩溃之日》(“El día del derrumbe”)。

本文旨在通过数据分析,为读者提供对墨西哥人特别是在这个月广泛讨论的问题的更全面的理解。需要指出的是,墨西哥的地震活动非常频繁,位于五个不同的构造板块交汇处,使其比其他国家更为脆弱。

我可以在哪里获得可靠的墨西哥地震活动数据?

有许多来源可以获得地震活动的历史数据;你使用的数据集已获得商业用途许可。幸运的是,我推荐一个非常可靠的来源,即国家地震服务中心(SSN)。SSN 维护了全国范围的地震监测网络,并提供适合商业用途的数据许可证。他们的数据集涵盖了从 1900 年至今的记录,并且几乎即时更新新的地震事件记录。

你可以在他们的网站上下载完整的历史记录 CSV 格式的数据,或者选择感兴趣的特定时间段:www2.ssn.unam.mx:8080/catalogo/

九月还是“震动九月”?使用 R 对墨西哥的地震活动数据进行分析和可视化 — SSN 网站截图(作者提供的图像)

数据加载和准备

在新的 R 脚本中,你将首先加载相关的库来进行分析。除了从国家地震服务中心网站提取的数据外,为了实用起见,你还需要从 INEGI 门户网站下载墨西哥共和国的政治分区地图,网址是 www.inegi.org.mx/app/mapas/ 以 SHP 格式提供。众所周知,你可以在 R 中处理这种格式。最后,你还将获得一个州名和标识符的列表,以便在后续与其他内容结合使用。

# LOADING LIBRARIES

library(data.table)
library(magrittr)
library(sf)
library(tmap)
library(sp)
library(ggpubr)
library(ggplot2)
library(lattice)

# PALETTE & FUNCTION 

myColors <- c('#d9ef8b','#91cf60','#fee08b','#fc8d59','#d73027','#1a9850')

decade <- function(date){
  year <- data.table::year(date)
  y <-year - year %% 10
  return(y)
}

# LOAD DATA AND SHAPES

dataSSN <- fread("SSNMX_catalogo_19000101_20221126.csv", header = T, skip = 4, sep=",", fill=T)

mxMap <- st_read("dest20gw/dest20gw.shp")

statesNames <- fread("nombres_estados.csv", encoding = "UTF-8",header = T)

skipLast <- grep(pattern = "Fecha y hora local",as.character(unlist(dataSSN[,1])))

dataSSN <- dataSSN[-skipLast:-dim(dataSSN)[1],]

names(dataSSN)<- tolower(names(dataSSN))

names(dataSSN)<- gsub("[[:space:]]", ".", names(dataSSN))

你还将按震级对地震数据进行分类,并将 CSV 数据格式化以提高可读性,因为它包含了如震级、日期、震中和其他详细信息等精确数据。

# ADJUSTMENTS TO DATA

dataSSN <- dataSSN %>% 
  .[, state := gsub(".*,\\s*", "\\1", referencia.de.localizacion)] %>%  
  .[, state := gsub("[[:space:]]","", state)] %>% 
  .[, state := fifelse(state=="N", "NL", state)] %>% 
  .[, date  := as.Date(fecha)] %>% 
  .[, intensity := suppressWarnings( as.numeric(magnitud))] %>% 
  .[, intensityRanges := fcase(
    intensity>=0 & intensity<= 2 , "0-2", intensity>=2 & intensity<= 4 , "2-4",
    intensity>=4 & intensity<= 6 , "4-6", intensity>=6 & intensity<= 8 , "6-8",
    intensity>=8 & intensity<= 10 , "8-10", intensity>10, "10 +",
    is.na(intensity), "Unmeasured magnitude")] %>% 
  .[, intensityRanges := factor(intensityRanges, levels = c("0-2", "2-4", "4-6", 
                                                            "6-8", "8-10","10+", 
                                                            "Unmeasured magnitude"))] %>% 
  .[, theDecade := decade(date)] %>% 
  .[, monthDate := as.Date(cut(date, "month"))] %>% 
  .[, weekDate := as.Date(cut(date, "week"))] %>% 
  .[, month := data.table::month(date)] %>% 
  .[, monthName := format(date, "%B")] %>% 
  .[, dayName := format(date, "%A")] %>%
  .[, monthDay := format(as.Date(date), "%m-%d")] %>% 
  .[, year := data.table::year(date)] %>% 
  .[, day := data.table::mday(date)] %>% 
  .[date >= "1900-01-01"]

dataSSN <- merge(dataSSN, statesNames, by.x="state", by.y="id.estado")

自 1900 年至今记录的地震活动

使用你的数据,你可以大致查看自 1900 年以来记录到的地震数量是否有所增加或减少。

# SEISMIC ACTIVITY SINCE 1900

dataSSN %>% 
  .[, .(seismicCount = .N), by=date] %>% 
  ggplot(aes(x= date, y= seismicCount)) +
  geom_line(color="darkcyan") +
  xlab("Year") +
  ylab("Earthquakes count") +
  ggtitle("Earthquakes per day reported by the SSN", "Since 1900 to 2022") 

请注意,虽然历史数据可以追溯到 1900 年,但直到 2010 年之后,国家才安装了更多的地震传感器。这清楚地显示了记录到的地震数量的增长,但这并不一定表明地震活动的增加。

九月还是“震动九月”?使用 R 对墨西哥的地震活动数据进行分析和可视化 — RStudio 绘图“自 1900 年至今记录的地震活动”(作者提供的图像)

哪种震级主导了墨西哥的地震活动?

正如你之前按震级对数据进行了分类,现在你可以可视化并更具体地了解记录到的地震震级随时间的变化。

# SEISMIC ACTIVITY SINCE 1900 BY INTENSITY

timeSeries1 <- dataSSN %>% 
  .[, .(seismicCount = .N), by=list(date, intensityRanges)] %>% 
  na.omit %>% 
  ggplot(aes(x= date, y= seismicCount, color=intensityRanges)) +
  geom_line(size=1)+
  scale_color_manual(values= myColors, name="Richter Scale") +
  theme(legend.position="top")+
  ggtitle("Earthquakes per day reported by the SSN", "Magnitude ranges since 1900 to 2022")+
  xlab("Year") +
  ylab("Earthquakes count") 

timeSeries2 <- dataSSN %>% 
  .[, .(seismicCount = .N), by=list(date, intensityRanges)] %>% 
  na.omit %>% 
  ggplot(aes(x= date, y= seismicCount, color=intensityRanges)) +
  geom_line(size=1)+
  scale_color_manual(values= myColors) +
  scale_x_date(breaks = "4 year")+
  theme(legend.position="top", axis.text.x = element_text(angle=90))+
  xlab("Year") +
  ylab("Earthquakes count") +
  facet_wrap(~intensityRanges, scales = "free_y") 

ggarrange(timeSeries1, timeSeries2,
          nrow = 2)

正如你所见,震中震级在里氏 2 到 4 级的地震占据主导地位。再提醒一下,在 2000 年之前,安装的传感器不多。

九月还是“九月震”?使用 R 进行的墨西哥地震活动数据分析和可视化——RStudio 图表“哪一月份地震活动最强?”(作者提供的图像)

墨西哥发生了多少次震级超过 6 的地震?

你还可以从总体上查看地震的发生情况,特别是震级超过 6 的地震(被视为高风险地震)。

# SEISMIC ACTIVITY SINCE 1900 BY INTENSITY (>6)

dataSSN %>% 
  .[intensity>6] %>% 
  .[, .(seismicCount = .N), by=list(date, intensityRanges)] %>%
  ggplot(aes(x= (date), y= seismicCount)) +
  geom_segment( aes(x=(date), xend=date, y=0, yend=seismicCount, color=intensityRanges)) +
  geom_point(size=3, alpha=0.6, aes(color= intensityRanges))+
  scale_color_manual(values= c("darkcyan", "darkred"), name="Range Magnitude Richter Scale") +
  scale_x_date(breaks = "5 years")+
  scale_y_continuous(breaks = seq(0,2,1))+
  theme(legend.position="top", axis.text.x = element_text(angle=45))+
  ggtitle("Earthquakes reported by the SSN", "Intensity > 6 (Since 1900 to 2022)")+
  xlab("Year") +
  ylab("Earthquakes count")

虽然在你获得的图像中可能不易察觉,但放大后会显示震级达到 8 的地震,这些地震确实给国家带来了震动。

九月还是“九月震”?使用 R 进行的墨西哥地震活动数据分析和可视化——RStudio 图表“墨西哥发生了多少次震级超过 6 的地震?”(作者提供的图像)

墨西哥哪个月份地震活动最强?

根据获得的数据,你可以回答这个问题,很多人在九月时曾关心过。值得注意的是,目前没有严格的科学依据解释为何特定月份的地震活动会增加。

# SEISMIC ACTIVITY REPORTED SINCE 1900 (PER MONTH)

dataSSN %>% 
  .[, .(seismicCount = .N), by=list(month, monthName)] %>% 
  .[order(month)] %>% 
  ggplot(aes(x= reorder(monthName, month), y= seismicCount)) +
  geom_bar(stat = "identity") +
  geom_col(aes(fill = seismicCount)) +
  theme(axis.text.x = element_text(angle=45)) +
  geom_label(aes(reorder(monthName, month), y= seismicCount,
                 label=seismicCount) ) +
  xlab("Month") +
  ylab("Earthquakes count")+
  ggtitle("Earthquakes reported by the SSN", "Earthquakes per month (Since 1900 to 2022)")

然而,数据本身说明了问题,显示出九月确实记录了比其他月份更多的地震。

九月还是“九月震”?使用 R 进行的墨西哥地震活动数据分析和可视化——RStudio 图表“哪个月份地震活动最强?”(作者提供的图像)

按强度分组的墨西哥地震活动最高的月份?

假设你想更深入地了解先前获得的结果。你可以按强度分组,以确定除了九月是地震活动最强的月份外,九月是否也与较高的震级相关。

# SEISMIC ACTIVITY REPORTED SINCE 1900 (PER MONTH AND INTENSITY)

rangesMonth <- dataSSN %>% 
  .[, .(seismicCount = .N), by=list(monthName, month, intensityRanges)] %>% 
  ggplot(aes(x= reorder(monthName, month), y= seismicCount, fill= intensityRanges)) +
  geom_bar(stat = "identity") +
  scale_fill_manual(name = "Response", values = myColors) +
  theme(axis.text.x = element_text(angle=45), legend.position="top") +
  xlab("Month") +
  ylab("Earthquakes count")+
  ggtitle("Earthquakes reported by the SSN", "Earthquakes per month (Grouped by intensity since 1900 to 2022)")

rangesMonth

下面的图表可能不会显示震级 8 的地震,但它们确实存在,6 月份有一次,9 月份有两次。如果你愿意,可以使用 plotly 进行更明显的可视化。

九月还是“九月震”?使用 R 进行的墨西哥地震活动数据分析和可视化——RStudio 图表“按强度分组的墨西哥地震活动最强的月份?”(作者提供的图像)

墨西哥哪个月份的地震活动最强,震级超过 6 度?

考虑到震中达到 6 度或更高的地震通常更为显著,假设你想更深入地研究之前获得的图表,以筛选出那些明显超过 6 度的地震。

# SEISMIC ACTIVITY REPORTED SINCE 1900 (GROUPED BY INTENSITY >6)

dataSSN %>% 
  .[intensity>6] %>% 
  .[, .(seismicCount = .N), by=list(monthName, month, intensityRanges)] %>% 
  ggplot(aes(x= reorder(monthName, month), y= seismicCount, fill= intensityRanges)) +
  geom_bar(stat = "identity")+
  scale_y_continuous(breaks = seq(0,100,1))+
  theme(axis.text.x = element_text(angle=45), legend.position="top")+
  scale_fill_manual(values= c("darkcyan", "darkred"), name="Range Magnitude Richter Scale")+  
  xlab("Month") +
  ylab("Earthquakes count")+
  geom_label(aes(reorder(monthName, month), y= seismicCount, label=seismicCount) )+
  ggtitle("Earthquakes reported by the SSN", "Earthquakes per month (Intensity > 6 since 1900 to 2022)")

这一次,情况变得更加清晰,尤其是震级超过 8 度的地震发生的月份。12 月似乎是震级超过 6 度的地震最多的月份。

9 月还是“九月地震”?使用 R 对墨西哥的地震活动数据进行分析和可视化 — RStudio 绘图“哪一个月份的地震活动最高,震级超过 6 度?”(图片由作者提供)

哪一天的“生日”地震数量最多,震级超过 6 度?

假设你还想了解在同一天和同一个月份但不同年份中发生了三次或更多次震级达到 6 度或以上的地震的日期。

# DAYS THAT AN EARTHQUAKE OF INTENSITY >6 HAS BEEN REPEATED AT LEAST 3 OR MORE TIMES

dataSSN %>% 
  .[intensity>=6] %>%
  .[, .(seismicCount = .N), by=list(month, monthDay)] %>% 
  .[order(month)] %>% 
  .[seismicCount>=3] %>%
  ggplot(aes(x= reorder(monthDay, month), y= seismicCount)) +
  geom_bar(stat = "identity", fill="darkcyan")+
  geom_col(aes(fill = seismicCount)) +
  theme(axis.text.x = element_text(angle=45)) +
  geom_label(aes(reorder(monthDay, month), y= seismicCount, label=seismicCount) ) +
  xlab("Date (month-day)") +
  ylab("Earthquakes count")+
  ggtitle("Earthquakes reported by the SSN", "Dates that have occurred three or more earthquakes on the same day 
and month but different year (Intensity >= 6 since 1900 to 2022)")

你会看到 6 月 7 日以七次符合条件的地震居于榜首。我们是否应该在这些日期佩戴头盔并避免高楼?

9 月还是“九月地震”?使用 R 对墨西哥的地震活动数据进行分析和可视化 — RStudio 绘图“哪一天的‘生日’地震数量最多,震级超过 6 度?”(图片由作者提供)

一周中的哪一天墨西哥的地震活动最高,震级超过 7 度?

这可能是另一个你可以通过建立震中等于或大于 7 度的地震发生情况来回答的问题。

# WEEKDAYS THAT AN EARTHQUAKE OF INTENSITY >7 HAS BEEN REPEATED MORE TIMES

dataSSN %>% 
  .[intensity>=7] %>%
  .[order(dayName)] %>% 
  .[, .(seismicCount = .N), by=list(dayName)] %>% 
  ggplot(aes(x= dayName, y= seismicCount)) +
  geom_bar(stat = "identity", fill="darkcyan")+
  geom_col(aes(fill = seismicCount)) +
  theme(axis.text.x = element_text(angle=45)) +
    xlab("Mes") +
  ylab("Conteo Sismos")+
  ggtitle("Earthquakes reported by the SSN", "Weekdays with greater occurrences of earthquakes (Intensity >= 7 since 1900 to 2022)")

你将得到如下图表,显示出星期五记录的地震活动最高,而星期天的记录较少。

9 月还是“九月地震”?使用 R 对墨西哥的地震活动数据进行分析和可视化 — RStudio 绘图“一周中的哪一天墨西哥的地震活动最高,震级超过 7 度?”(图片由作者提供)

墨西哥地震活动图

在这样的练习中,拥有地理参考总是很有用。你可以创建一张地图,以概括地可视化该国的地震活动。

# MAP SEISMIC ACTIVITY BY STATE 1 

dataSSN %>% 
  .[, .(seismicCount = .N), by=list(nombreEstado)] %>% 
  merge(., mxMap, by.x="nombreEstado", by.y="NOM_ENT") %>% 
  st_as_sf() %>% 
  ggplot()+
  geom_sf(aes(fill = seismicCount)) + 
  ggtitle("Earthquakes reported by the SSN", "Map of seismic activity in Mexico (Since 1900 to 2022)")

结果会突出显示瓦哈卡,该地区的地震数量最高。

9 月还是“九月地震”?使用 R 对墨西哥的地震活动数据进行分析和可视化 — RStudio 绘图“墨西哥地震活动图”(图片由作者提供)

如果你愿意,可以为你的地图提供更具美感的格式。

# MAP SEISMIC ACTIVITY BY STATE 2

tmap_style("classic")

statesMap <- dataSSN %>% 
  .[, .(seismicCount = .N), by=list(nombreEstado)] %>% 
  merge(., mxMap, by.x="nombreEstado", by.y="NOM_ENT") %>% 
  st_as_sf() %>% 
  tm_shape() +
  tm_polygons("seismicCount", 
              title = "Number of earthquakes range")+
  tm_layout("Number of earthquakes 
              by State (Since 1900 to 2022)", title.position = c('right', 'top'))

statesMap

你可以试验“tmap_style()”函数的可用样式,以选择你喜欢的样式。

9 月还是“九月地震”?使用 R 对墨西哥的地震活动数据进行分析和可视化 — RStudio 绘图“墨西哥地震活动图”(图片由作者提供)

哪个州的地震活动最强?

尽管您已经获得了一个地图来回答这个问题,但制作另一个图表来用更多细节加强结果也不会有坏处。

# SEISMIC ACTIVITY GROUPED BY STATE 

paretoData <- dataSSN %>% 
  .[, .(seismicCount = .N), by=list(nombreEstado)]  %>% 
  .[, total := sum(seismicCount)] %>% 
  .[order(seismicCount, decreasing = T)] %>% 
  .[, accumulatedSum := cumsum(seismicCount)] %>% 
  .[, percentage := seismicCount/total] %>% 
  .[, accumulatedPercentage := accumulatedSum/total]

listPercentage <- 0:100

statesBars <- ggplot(data= data.frame(paretoData), aes(x=nombreEstado)) +
  geom_bar(aes(x=reorder(nombreEstado, -seismicCount), y=seismicCount), fill='darkcyan', stat="identity") + 
  scale_y_continuous(limits = c(0, (max(paretoData$seismicCount))+15000 ))+
  theme(axis.text.x = element_blank(), axis.title.x = element_blank())+
  ylab("Earthquakes count")+
  geom_text(aes(reorder(nombreEstado, -seismicCount), y= seismicCount, label=seismicCount), vjust=-1, angle=45, hjust=0)+
  ggtitle("Earthquakes reported by the SSN", "Earthquakes grouped by State (Since 1900 to 2022)")

statesPercentage <- ggplot(data= data.frame(paretoData)) +
  geom_bar(aes(x=reorder(nombreEstado, -seismicCount), y=percentage), fill='darkcyan', stat="identity") + 
  theme(axis.text.x = element_text(angle=45, hjust=1))+
  scale_y_continuous(labels = scales::percent, breaks = seq(0,1,0.1), limits=c(0,0.7))+
  xlab("State") +
  ylab("Earthquakes count")+
  geom_text(aes(x= reorder(nombreEstado, -seismicCount), y = percentage,
                label=paste0(round(percentage*100, 3),"%"), vjust=0, angle=45, hjust=0))+
  ggtitle(" ", "Percentage of earthquakes by State (Since 1900 to 2022)")

ggarrange(statesBars, statesPercentage, nrow = 2)

如果您在寻找一个可以摆脱地震担忧的地方,尤卡坦、克雷塔罗、阿瓜斯卡连特斯或杜兰戈可能是不错的选择。

九月还是“九月地震”?使用 R 对墨西哥地震活动数据进行分析和可视化 — RStudio 图表“墨西哥哪个州的地震活动最强?”(图片由作者提供)

里氏震级大于 7.5 的地震地图

讨论震中强度等于或大于 7.5 的地震意味着讨论可能导致灾难性损失的震动。您可以查看这些被记录为历史上最强烈的地震的起源地图。

# MAP TOP EARTHQUAKES >= 7.5

topEarthquakes <- dataSSN %>%  
  .[intensity>=7.5]

ggplot(data = mxMap) +
  geom_sf()+
  geom_point(data= topEarthquakes, aes(x= longitud, y= latitud, size= intensity), color="darkcyan")+
  scale_color_manual( name="Magnitud escala de Richter")+
  theme(axis.text = element_blank(), axis.title = element_blank(), legend.position="top")+
  ggrepel::geom_text_repel(data= topEarthquakes, aes(x= longitud, y= latitud, label= intensity), color="darkred")+
  ggtitle("Earthquakes reported by the SSN", "Map of strongest earthquakes locations (Intensity >= 7.5 since 1900 to 2022)")

结果将显示,太平洋沿岸的州是墨西哥记录到的最强烈地震发生地。

九月还是“九月地震”?使用 R 对墨西哥地震活动数据进行分析和可视化 — RStudio 图表“里氏震级大于 7.5 的地震地图”(图片由作者提供)

总结来说,本文通过数据分析和可视化揭示了墨西哥的地震历史。虽然它没有提供预测或政策建议,但它为读者提供了对地震趋势的宝贵理解。这些知识可以对个人安全、城市规划和建筑实践产生重大影响。通过了解地震活动最可能发生的时间和地点,个人和社区可以做出明智的决策以减少风险。无论您是墨西哥城的居民还是设计抗震结构的工程师,从这次分析中获得的见解都可以产生真正的影响。保持安全,保持知情,并记住,知识在地震准备中是一个强大的工具。

感谢您的细心阅读。如果您目前居住在地震活动频繁的地区,请务必注意自身安全。和我的其他文章一样,我在这里分享了完整的代码:github.com/cosmoduende/r-earthquakes

祝您在分析中获得愉快的体验,把一切付诸实践,并对结果感到惊讶和娱乐!

从你的电脑上提供大语言模型服务,通过文本生成推理

原文:towardsdatascience.com/serve-large-language-models-from-your-computer-with-text-generation-inference-54f4dd8783a7

使用 Falcon-7B 的指令版本的示例

本杰明·玛丽Towards Data Science 本杰明·玛丽

·发表在Towards Data Science ·阅读时间 6 分钟·2023 年 7 月 13 日

--

图片由Nana Dua提供,来源于Unsplash

通过量化方法如 QLoRa 和GPTQ,现在可以在消费者硬件上本地运行非常大的语言模型(LLM)。

考虑到加载大语言模型的时间,我们可能还希望将 LLM 保持在内存中以便即时查询并获取结果。如果你使用标准的推理管道,你必须每次重新加载模型。如果模型非常大,你可能需要等待几分钟才能生成输出。

有各种框架可以在服务器(本地或远程)上托管大语言模型(LLMs)。在我的博客中,我已经介绍了NVIDIA 开发的非常优化的 Triton 推理服务器框架,它用于服务多个 LLMs,并在 GPU 之间平衡负载。但是,如果你只有一个 GPU,并且希望在你的电脑上托管模型,使用 Triton 推理可能会显得不太合适。

在这篇文章中,我介绍了一种替代方案,称为文本生成推理。这是一个更简单的框架,实现了运行和服务 LLMs 所需的所有基本功能,适用于消费者硬件。

阅读完本文后,你将在你的电脑上拥有一个本地部署并等待查询的聊天模型/LLM。

文本生成推理

文本生成推理(TGI)是一个用 Rust 和 Python 编写的框架,用于部署和服务 LLM。它由 Hugging Face 开发,并以 Apache 2.0 许可证 进行分发。Hugging Face 在生产中使用它来驱动他们的推理小部件。

尽管 TGI 已针对 A100 GPU 进行了优化,但由于对量化和 分页注意力的支持,我发现 TGI 非常适合自托管的 LLM,在如 RTX GPU 这样的消费级硬件上表现出色。然而,它需要特定的安装来支持 RTX GPU,这一点我将在本文后续部分详细说明。

最近,我还发现 Hugging Face 正在优化一些 LLM 架构,以便它们在 TGI 下运行更快。

这尤其适用于 Falcon 模型,这些模型在使用标准推理管道时运行较慢,但在使用 TGI 时运行更快。一位 Falcon 模型的作者在 Twitter 上告诉我,这是因为他们在多查询注意力的实现上匆忙,而 Hugging Face 则优化了它以便与 TGI 配合使用。

有几种 LLM 架构以这种方式进行了优化,以便在 TGI 下运行得更快:BLOOM、OPT、GPT-NeoX 等。完整列表可以在 TGI 的 GitHub 上找到,并定期更新。

设置文本生成推理

硬件和软件要求

我在 RTX 3060 12 GB 上进行了测试。它应该适用于所有 RTX 30x 和 40x,但请注意 TGI 特别优化了 A100 GPU。

要运行这些命令,你需要一个 UNIX 操作系统。我使用了通过 Windows WSL2 的 Ubuntu 20.04。

它在 Mac OS 上也应该可以正常工作。

TGI 需要 Python ≥ 3.9。

我将首先介绍如何从零开始安装 TGI,我认为这并不简单。如果你在安装过程中遇到问题,可能需要改用 Docker 镜像。我将讨论这两种情况。

设置

TGI 是用 Rust 编写的。你需要先安装它。如果你没有安装,运行以下命令:

curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh

这应该花费不到 2 分钟。我建议重启你的 shell,例如,打开一个新的终端,以确保所有环境变量正确更新。

然后,我们创建一个专用的 conda 环境。此步骤是可选的,但我更喜欢为每个项目创建一个独立的环境。

conda create -n text-generation-inference python=3.9 
conda activate text-generation-inference

我们还需要安装 Protoc。Hugging Face 目前推荐版本 21.12。你需要具有 sudo 权限。

PROTOC_ZIP=protoc-21.12-linux-x86_64.zip
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP
sudo unzip -o $PROTOC_ZIP -d /usr/local bin/protoc
sudo unzip -o $PROTOC_ZIP -d /usr/local 'include/*'
rm -f $PROTOC_ZIP

我们已经安装了所有要求。现在,我们可以安装 TGI。

首先,克隆 GitHub 仓库:

git clone https://github.com/huggingface/text-generation-inference.git

然后安装 TGI:

cd text-generation-inference/
BUILD_EXTENSIONS=False make install

注意:我将 BUILD_EXTENSIONS 设置为 False 以停用自定义 CUDA 内核,因为我没有 A100 GPU。

应该可以顺利安装……但在我的计算机上却没有。我不得不手动运行 server/Makefile 文件中的所有命令。我怀疑是由于“make”由于某种原因切换到不同的 shell,导致我的环境变量没有正确加载。你可能也需要这样做。

注意:如果安装失败,不用担心!Hugging Face 创建了一个 Docker 镜像,你可以启动它以启动服务器,我们将在下一部分中看到。

使用 TGI 启动模型

对于以下示例,我使用的是 Falcon-7B 模型 的 instruct 版本,它在 Apache 2.0 许可证下分发。如果你想了解更多关于 Falcon 模型的信息,我在上一篇文章中做了介绍:

## Open LLM Falcon-40B 介绍:性能、训练数据和架构

开始使用 Falcon-7B、Falcon-40B 及其 instruct 版本

towardsdatascience.com

不使用 Docker

安装创建了一个新的命令“text-generation-launcher”,它将启动 TGI 服务器。

text-generation-launcher --model-id tiiuae/falcon-7b-instruct --num-shard 1 --port 8080 --quantize bitsandbytes
  • model-id:模型名称在 Hugging Face Hub

  • num-shard:设置为你拥有的 GPU 数量,以及你希望利用的数量。

  • port:你希望服务器监听的端口。

  • quantize:如果你使用的 GPU 内存少于 24 GB,你需要对模型进行量化,以避免内存不足。我选择了“bitsandbytes”进行即时量化。GPTQ(“gptq”)也可用,但我对这个算法不太熟悉。

使用 Docker(如果手动安装失败)

注意:如果 Docker 守护进程未运行,并且你通过 WSL 运行 Ubuntu,请在另一个终端中使用“sudo dockerd”启动守护进程。

volume=$PWD/data
sudo docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:0.9 --model-id tiiuae/falcon-7b-instruct --num-shard 1  --quantize bitsandbytes

参数几乎与 text-generation-launcher 一样。如果你只有一个 GPU,你可以将“all”替换为“0”。

保持这个 Docker 镜像运行,只要你想使用服务器。

使用 TGI 查询模型

要用 Python 脚本查询 TGI 提供的模型,你需要安装以下库:

pip install text-generation

然后在 Python 脚本中,写类似这样的代码:

from text_generation import Client

client = Client("http://127.0.0.1:8080")
print(client.generate("Translate the following into French: 'What is Deep Learning?'", max_new_tokens=500).generated_text)

它应该打印:

Qu'est-ce que la profondeur de l'apprentissage ?

这是一种较差质量的翻译。这是对一个 70 亿参数模型的预期。它在编码任务上略好一些:

from text_generation import Client

client = Client("http://127.0.0.1:8080")
print(client.generate("Code in Javascript a function to remove all spaces in a string and then print the string twice.", max_new_tokens=500).generated_text)

它生成:

Here is an example code snippet in JavaScript to remove all spaces in a string and then print the string twice:

```javascript

function removeSpaces(str) {

return str.replace(/\s+/g, '');

}

console.log(removeSpaces('Hello World'));

console.log(removeSpaces('Hello World'));

```py

你也可以用 curl 进行查询,而不是 Python 脚本:

curl 127.0.0.1:8080/generate \
    -X POST \
    -d '{"inputs":"Code in Javascript a function to remove all spaces in a string and then print the string twice.","parameters":{"max_new_tokens":500}}' \
    -H 'Content-Type: application/json'

TGI 确实很快。使用 Falcon-7B 并将最大 token 数设置为 500,仅需几秒钟,我的 RTX 3060 GPU 即可完成。

使用标准推理管道,几乎需要 40 秒,还不包括加载模型所需的时间。

结论

自托管聊天模型(即,指导 LLM)有很多优势。主要的是你不会将数据发送到互联网。另一个是你完全控制操作成本,这只反映在你的电费账单上。

然而,如果你使用消费级 GPU,你将无法运行最先进的 LLM。即使是较小的 LLM,我们也必须将其量化,以便在配备少于 24 GB VRAM 的 GPU 上运行。量化还会降低 LLM 的准确性。

尽管如此,即使是小型量化 LLM 也仍然适合简单任务:简单的编码问题、二元分类……

现在,你可以通过查询你的自托管 LLM 在计算机上完成所有这些任务。

如果你喜欢这篇文章并对阅读接下来的文章感兴趣,支持我工作的最佳方式是通过这个链接成为 Medium 会员:

[## 通过我的推荐链接加入 Medium - Benjamin Marie]

作为 Medium 会员,你的一部分会员费会给你阅读的作者,并且你可以全面访问每一个故事……

medium.com](https://medium.com/@bnjmn_marie/membership?source=post_page-----54f4dd8783a7--------------------------------)

如果你已经是会员并且想支持这项工作, 只需关注我在 Medium 上的账户

使用 TorchServe 服务 ML 模型

原文:towardsdatascience.com/serving-ml-models-with-torchserve-1578eca5aa20?source=collection_archive---------9-----------------------#2023-03-29

一个完整的端到端图像分类任务 ML 模型服务示例

Andrey GolovinTowards Data Science Andrey Golovin

·

关注 发表在 Towards Data Science ·8 min read·2023 年 3 月 29 日

--

图片来源:作者

动机

本文将引导你通过使用 TorchServe 框架服务你的深度学习 Torch 模型的过程。

关于这个主题有很多文章。然而,通常这些文章要么专注于部署 TorchServe 本身,要么专注于编写自定义处理程序并获得最终结果。这是我写这篇文章的动机。本文涵盖了这两个部分,并提供了端到端的示例。

以图像分类挑战为例。最终你将能够部署 TorchServe 服务器,提供模型服务,发送任何随机的衣物图片,并最终获得预测的衣物类别标签。我相信这就是人们对作为 API 端点提供分类服务的 ML 模型的期望。

简介

比如,你的数据科学团队设计了一个很棒的深度学习模型。这无疑是一个伟大的成就。然而,为了让它发挥价值,模型需要以某种方式向外界开放(如果不是 Kaggle 竞赛的话)。这就是所谓的模型服务。在这篇文章中,我不会涉及批处理操作的服务模式以及基于流处理框架的流处理模式。我将专注于将模型作为 API 提供服务的一个选项(无论这个 API 是由流处理框架还是任何自定义服务调用)。更准确地说,这个选项是 TorchServe 框架。

因此,当你决定将模型作为 API 提供服务时,你至少有以下几种选择:

  • 网络框架如 Flask、Django、FastAPI 等等

  • 云服务如 AWS Sagemaker 端点

  • 专用服务框架如 Tensorflow Serving、Nvidia Triton 和 TorchServe

每种方法都有其优缺点,选择可能并不总是简单明了的。让我们实际探索一下 TorchServe 选项。

结构

第一部分将简要描述模型是如何训练的。这对 TorchServe 来说并不重要,但我认为这有助于跟踪端到端的过程。接下来,将解释自定义处理器。

第二部分将重点讨论 TorchServe 框架的部署。

本文的源代码位于这里:git repo

准备一个深度学习模型

对于这个示例,我选择了基于 FashionMNIST 数据集的图像分类任务。如果你不熟悉这个数据集,它包含 70k 张 28x28 像素的灰度图像,展示了不同的衣物。共有 10 个衣物类别。因此,一个深度学习分类模型将返回 10 个 logit 值。为了简化,模型基于 TinyVGG 架构(如果你想用 CNN 解释器可视化它):仅有少量卷积层和最大池化层,带有 RELU 激活。仓库中的笔记本 model_creation_notebook 展示了训练和保存模型的全部过程。

简而言之,笔记本只是下载数据,定义模型架构,训练模型并用 torch 保存状态字典。有两个与 TorchServe 相关的文档:一个包含模型架构定义的类和保存的模型(.pth 文件)。

为 TorchServe 准备文档以提供模型服务

需要准备两个模块:模型文件和自定义处理器。

模型文件

根据文档,“*模型文件应包含模型架构。对于急切模式的模型,这个文件是必须的。

该文件应包含一个继承自 torch.nn.Module 的单个类。

所以,让我们从模型训练笔记本中复制类定义,并将其保存为 model.py(你可以选择任何名称):

处理程序

TorchServe 提供了一些默认处理程序(例如 image_classifier),但我怀疑它是否可以直接用于实际案例。因此,你很可能需要为你的任务创建一个自定义处理程序。处理程序实际上定义了如何从 http 请求中预处理数据,如何将其输入模型,如何后处理模型的输出以及在响应中返回最终结果。

有两个选项——模块级入口点和类级入口点。请查看官方文档这里

我将实现类级选项。这基本上意味着我需要创建一个自定义 Python 类并定义两个强制函数:initializehandle

首先,为了简化操作,让我们继承 BaseHandler 类。initialize 函数定义了如何加载模型。由于我们这里没有任何具体要求,因此可以直接使用超类中的定义。

handle 函数基本上定义了如何处理数据。在最简单的情况下,流程是:预处理 >> 推理 >> 后处理。在实际应用中,你可能需要定义自定义的预处理和后处理函数。对于这个示例的推理函数,我将使用超类中的默认定义:

预处理函数

比如,你构建了一个图像分类应用。该应用将图像作为负载发送给 TorchServe。图像可能不会始终符合用于模型训练的图像格式。此外,你可能会在样本批次上训练模型,张量维度必须进行调整。因此,让我们创建一个简单的预处理函数:将图像调整为所需的形状,转换为灰度图像,转换为 Torch 张量并将其作为单样本批次。

后处理函数

多类别分类模型将返回一个 logit 或 softmax 概率列表。但在实际场景中,你更可能需要一个预测类别或带有概率值的预测类别,或者可能是前 N 个预测标签。当然,你可以在主应用程序/其他服务中完成这些操作,但这意味着将应用程序逻辑与 ML 训练过程绑定在一起。因此,让我们直接在响应中返回预测类别。

(为了简单起见,这里的标签列表是硬编码的。在 github 版本中,处理程序从配置中读取标签)

启动 Torch 服务器

好了,模型文件和处理程序都准备好了。现在让我们部署 TorchServe 服务器。上面的代码假设你已经安装了 pytorch。另一个前提条件是安装了 JDK 11(注意,仅 JRE 不够,你需要 JDK)。

对于 TorchServe,你需要安装两个包:torchservetorch-model-archiver

成功安装后,第一步是准备一个 .mar 文件——包含模型工件的档案。torch-model-archiver 的 CLI 接口旨在实现这一点。终端中输入:

torch-model-archiver --model-name fashion_mnist --version 1.0 --model-file path/model.py --serialized-file path/fashion_mnist_model.pth --handler path/handler.py

参数如下:

-模型名称:你想给模型起的名称

-版本:用于版本控制的语义版本

-模型文件:包含模型架构类定义的文件

-序列化文件:来自 torch.save() 的 .pth 文件

-处理器:包含处理器的 Python 模块

结果是一个名为模型名称的 .mar 文件(在这个例子中是 fashion_mnist.mar)将在执行 CLI 命令的目录中生成。因此,最好在调用命令之前 cd 到你的项目目录。

下一步是启动服务器。终端中输入:

torchserve --start --model-store path --models fmnist=/path/fashion_mnist.mar 

参数:

-模型存储:存放 mar 文件的目录

-模型:模型名称和对应 mar 文件的路径。

注意,归档器中的模型名称定义了你的 .mar 文件将如何命名。torchserve 中的模型名称定义了调用模型的 API 端点名称。所以,这些名称可以相同也可以不同,取决于你。

执行这两个命令后,服务器应该启动并运行。默认情况下,TorchServe 使用三个端口:8080、8081 和 8082 分别用于推理、管理和指标。打开浏览器/curl/Postman 发送请求到

http://localhost:8080/ping

如果 TorchServe 正常工作,你应该看到 {‘status’: ‘Healthy’}

图片来源于作者

一些可能问题的提示

1. 如果在 torchserve -start 命令后你在日志中看到提到“..no module named captum”的错误,那么请手动安装它。我在 torchserve 0.7.1 中遇到了这个错误

2. 可能会有某些端口已经被另一个进程占用。那你可能会看到 ‘Partially healthy’ 状态和日志中的一些错误。

要检查哪个进程使用了 Mac 上的端口(例如 8081),输入:

sudo lsof -i :8081

一种选择是终止进程以释放端口。但如果该进程很重要,这可能并不是一个好主意。

另外,可以在简单的配置文件中为 TorchServe 指定任何新的端口。假设你有一个已经在 8081 端口上运行的应用程序。通过创建包含一行的 torch_config 文件来更改默认的 TorchServe 管理 API 端口:

management_address=https://0.0.0.0:8443

(你可以选择任何空闲端口)

接下来我们需要让 TorchServe 知道配置。首先,通过以下命令停止不健康的服务器

torchserve --stop

然后重新启动它:

torchserve --start --model-store path --models fmnist=/path/fashion_mnist.mar --ts-config path/torch_config

请求模型的预测

在这一步假设服务器已正确启动。让我们将随机的衣物图像传递给推理 API 并获取预测标签。

推理的端点是

http://localhost:8080/predictions/model_name

在这个例子中是 http://localhost:8080/predictions/fmnist

让我们使用 curl 传递一个图像:

curl -X POST http://localhost:8080/predictions/fmnist -T /path_to_image/image_file

例如,使用 repo 中的示例图像:

curl -X POST http://localhost:8080/predictions/fmnist -T tshirt4.jpg

(X 标志用于指定方法 /POST/,-T 标志用于传输文件)

在响应中,我们应该看到预测的标签:

图片由作者提供

结论

好吧,通过跟随这篇博客文章,我们能够创建一个 REST API 端点,我们可以向其发送图像并获取图像的预测标签。通过在服务器上重复相同的过程,而不是本地机器,一个人可以利用它来为面向用户的应用程序、其他服务,或者例如流式机器学习应用程序创建一个端点(参见这篇有趣的论文了解为什么你可能不应该这样做:https://sites.bu.edu/casp/files/2022/05/Horchidan22Evaluating.pdf

敬请关注,在下一部分中我将扩展示例:让我们为业务逻辑创建一个 Flask 应用程序的模拟,并通过 TorchServe 调用一个机器学习模型(并使用 Kubernetes 部署所有内容)。

一个简单的用例:面向用户的应用程序,具有大量业务逻辑和许多不同的功能。比如,一个功能是上传图像,将所需的样式应用于图像,使用样式迁移机器学习模型。机器学习模型可以通过 TorchServe 提供,因此机器学习部分将完全与主应用程序中的业务逻辑和其他功能解耦。

为 2024 年数据科学家的更高质量工作与生活平衡设定这些界限

原文:towardsdatascience.com/set-these-boundaries-for-a-better-quality-work-life-balance-as-a-data-scientist-in-2024-e3af4a256a23?source=collection_archive---------14-----------------------#2023-11-04

这 5 条不可妥协的原则将帮助你让 2024 年成为你最平衡的一年

Madison HunterTowards Data Science Madison Hunter

·

关注 发表在Towards Data Science ·6 分钟阅读·2023 年 11 月 4 日

--

图片由Leonardo Iheme拍摄,来源于Unsplash

工作与生活的平衡是每个人都渴望的,但只有少数人有勇气去实现。

在 Google 上“工作与生活平衡”有 29 亿个搜索结果,这很明显它是我们都在追求的东西。它不仅成为了我们搜索的重点,而且在过去三年里,它似乎逐渐进入了我们的日常对话中。

到 2020 年,数据科学被视为一种能够提供大家谈论的神秘工作与生活平衡的职业。然而,许多人似乎意识到,从事某种数据科学工作可能和其他任何工作一样,甚至因为我们更普遍的“灵活”工作安排(老板似乎觉得需要更加紧握我们的手,以缓解他们对微管理的不安)而更加消耗生活。

不幸的是,工作与生活平衡并不是总能得到的。有时候,它需要通过设定界限和非谈判项来争取。随着 2024 年仅剩两个月的时间,现在是准备如何恢复你的工作与生活平衡的完美时机,以使即将到来的一年成为你最平衡的一年。以下是你需要设定的五个界限,以获得 2024 年更高质量的工作与生活平衡。

1. 准备一个文档系统

不可避免地,项目里程碑将会被超越,预算会紧张,时间表会混乱。当这种情况发生时,你可能会成为承担你团队领导、客户,或者更糟糕的,你的老板全部怒火的“出气筒”。

然而,你也很有可能是唯一一个点了 i 的点和画了 t 的那个人。因此,为了你的工作与生活平衡,创建一个支持你在事情出错时能够证明你是在做自己工作的文档系统是很重要的。

我最喜欢记录这些互动方式之一是使用电子表格。在那里,你可以创建一个简单的文档,每一行都是一个独特的事件/邮件/对话(按需删除),列出如事件 ID(因为我们是数据科学家,不是吗?没有这个就会一片混乱)、日期、你互动的人的名字、问题、他们的回应、你的回应、是否跟进、解决方案是什么、是否需要升级措施等信息。我保持这种文档全天候打开,以确保我遇到的一切都被记录下来。

尽管这可能看起来繁琐或过度,但我从技术工作中学到的唯一一件事是:记录、记录、记录,这将始终避免你做更多不必要的工作。

2. 将项目时间表设定为所需时间的两倍

任何在技术行业工作超过两分钟的人都知道,项目总是会比预期花费更长的时间。因此,在你将重新获得工作与生活平衡的那一年,你需要给出项目所需时间的现实估计——换句话说,总是说项目会花费至少两倍的时间,因为它们总是会这样。

数据可能意外不可用,你的团队可能生病,你的代码可能会让软件开发人员难以将其制作成生产级的产品,你的客户可能会在你准备发布的前一周完全改变他们的要求,等等。

为了你的合同工时(顺便提一下,你绝对不应超时工作,见下文),你必须为项目设定适当的时间表,使你能够在经历所有脱轨情况的情况下,仍能生产高质量的工作,而不牺牲你的工作与生活平衡。没有什么比压力更能迅速降低项目质量,这就是为什么适当的数据清理、分析和可视化的工作最好在冷静的头脑和看似遥远的截止日期下进行。更好的是,如果你能提前完成项目,你就会低估承诺、高估交付,这应该真正成为你作为数据科学家的座右铭。

3. 其他人的糟糕计划不等于你的紧急情况

有时,不是你设定适当的项目截止日期,而是由其他人为你设定。这些截止日期不切实际,并且会影响你的工作与生活平衡。

解决方案?

告诉人们你不会满足他们的不切实际的截止日期,他们应该首先与你商量,以防止类似问题再次发生。

哎呀。我能理解这对刚入行的数据科学家来说听起来很可怕。然而,我也从经验中知道,如果你从一开始就不为自己站出来,后面将会困难得多,以至于最终可能更容易离开工作,而不是试图收紧你不断扩大的责任、项目和不切实际的截止日期。

数据分析最好不要急于求成,即使是看似微不足道的任务,如改变矩阵散点图上的颜色。虽然你会随着时间的推移在某些任务上变得更快,自动化也可能承担一些繁重的工作,但对依赖你分析结果的客户来说,提供半吊子的结果是毫无意义的。你呈现的结论和策略可能对客户(不是对个人,而是对公司)具有改变人生的影响,这意味着你需要对你的发现非常确定。因此,你不应该被催促。最终大家都会受益,你可能需要提醒他们这一点。

4. 永远不要为了虚假的截止日期加班

通过开始坚持对上述不合理计划的截止日期,你将已经在实现下一个不可谈判目标上取得了良好的进展。下一步是你永远不要为了虚假的截止日期加班。

人为的截止日期是一种让你突然在晚餐时间开始检查电子邮件、在周末推送代码,以及让团队聊天演变成另一个完全“快速”对接编程会话的快速简便方法。

2024 年是你拒绝为了人为截止日期而加班的时机,而是将这种有时必要的邪恶仅限于最极端和特殊的情况。但说实话,数据科学家何时真的需要加班呢?这种情况非常少见,虽然这些时候似乎出现的频率很高,尤其是当数据科学家不仅是数据科学家,还兼任系统分析师、他们工作区的本地 IT 帮助台,甚至可能还是让一切具备生产就绪状态的软件开发者。

关键在于,为了保持工作与生活的平衡,你需要对加班的紧急性诚实。因为只要你在公司内的工作情况健康,就没有真正必要的“数据紧急情况”加上人为截止日期。

5. 规定质量是你唯一的操作方式

客户希望得到项目三角形的所有三个方面:速度、成本和质量。这是生活的一个事实。

作为公司为数不多的数据科学家之一,某种程度上这自动成为了你的职责,即向客户阐明数据分析项目的运作方式,尤其是为什么速度从来不是解决问题的答案。虽然软件部门可以在几个小时内快速组装一个产品管理系统,但对于预测未来一年商业趋势的深入数据分析则需要一些技巧。

换句话说,当你的结果可能决定客户未来业务的走向时,质量永远不会因为速度而被舍弃。最终总是会以不好的结局告终。客户会从你多花几个小时来优化预测模型中受益,特别是当复杂变量起作用时,更详细的见解将从一个没有被为了几天速度而匆忙完成的分析中获得。

订阅以将我的故事直接发送到你的收件箱:Story Subscription

请成为会员,以通过我的推荐链接无限访问 Medium(这不会额外增加你的费用,我会获得少量佣金):Medium Membership

订阅我的通讯,以获得更多带有环保主义色彩的独家数据驱动内容:DataDrivenEnvironmentalist

通过捐赠支持我的写作,以资助创作更多类似的故事:Donate

为数据科学设置 Flask 应用

原文:towardsdatascience.com/setting-up-a-flask-application-for-data-science-7522fc9f771e

构建一个 Flask 应用的基本结构,以支持模块化开发

Philip Wilkinson, Ph.D.Towards Data Science Philip Wilkinson, Ph.D.

·发布于 Towards Data Science ·阅读时间 9 分钟·2023 年 3 月 10 日

--

照片由 KOBU Agency 提供,来源于 Unsplash

数据科学工作流程通常涉及笔记本和 Python 脚本的使用。这些都是很棒的工具,但通常意味着你的输出可能会一直保留在这些文件中,无法展示。然而,改变这种情况的一个好方法是创建一个网站来展示和讨论你的发现,或者创建一个 API 将你的模型提供给全世界。Flask 是一个可以帮助实现这一点的框架。

Flask 允许你构建网站和 API,使你可以更广泛地分享你的成果。无论是通过一个展示你的工作和结果的界面,还是通过一个其他人可以调用以获取模型预测的 API,Flask 都能实现。Flask 是一个轻量级框架,易于学习和使用,非常适合希望专注于构建模型和分析数据的 Data Scientist,而不是学习复杂的 Web 开发框架。它也使用 Python,因此数据科学工作流程中的许多步骤可以很容易地转移到 Flask 框架中。

在本文中,我将向你展示如何设置一个 Flask 应用的基本框架,之后你可以在其基础上进行扩展。这将包括解释应用的基本结构以及你希望放入每个文件的内容。

Flask 项目结构

我们可以从基本的 Flask 结构概述开始。它的形式如下:

application/
|-static/
|--css/
|--images/
|--scipts/
|-templates/
|--includes/
|--index.html
|-__init__.py
|-routes.py
tests/
flask_env/
.env
.env.template
.gitignore
flask_config.py
app.py
requirements.txt

其中:

  • application/ 目录包含主要的 Flask 应用代码

  • static/ 目录包含网站中使用的静态文件,如 CSS 样式表、JavaScript 脚本和图片

  • templates/ 目录包含 Flask 应用使用的 HTML 模板

  • __init__.py 初始化 Flask 应用程序并设置配置和数据库连接。

  • routes.py 定义了 Flask 应用程序的路由。

  • tests/ 目录包含 Flask 应用程序的测试文件。

  • flask_env 包含 Flask 应用程序的虚拟环境。

  • .env 文件包含 Flask 应用程序使用的环境变量。

  • flask_config.py 包含 Flask 应用程序的配置设置。

  • app.py 是 Flask 应用程序的入口点,并运行 Flask 开发服务器。

那么,我们如何开始呢?

设置环境。

要开始构建这个 Flask 应用程序,你需要创建一个 Python 虚拟环境。这是一个好习惯,有助于将应用程序的依赖项与机器上的全局 Python 环境隔离开来。这使得管理依赖项变得更容易,并确保应用程序在不同机器上顺利运行。

为此,你需要首先打开一个终端窗口,并导航到你想要创建虚拟环境的目录。然后运行以下命令:

python -m venv <venv_name>

在这种情况下,我将其命名为 flask_application,但也可以简单地命名为 env

然后你可以使用以下命令激活环境:

#on windows
<venv_name>\Scripts\activate

#on Max or Linux
source <venv_name>/bin/activate

并安装常用库。在这种情况下,我们将使用的基本库是 flask(当然)、python-dotenvpytest。要安装这些库,你可以运行以下命令:

pip install flask python-dotenv pytest

这些库应该安装在你的虚拟环境中。请注意,有时这可能需要一些时间,所以不要太担心!

当对虚拟环境进行更改时,例如安装新库时,最好冻结要求,以便其他人知道你使用了哪些库及其版本。这可以通过以下命令完成:

pip freeze > requirements.txt

这将创建 requirements.txt 文件。这将使其他人更容易在自己的机器上复制你的环境。

基础文件。

一旦你设置好虚拟环境并创建了 requirements.txt 文件,你就可以开始构建你的应用程序。因此,我们可以从目录结构底部的文件开始:

.env
.env.template
.gitignore
flask_config.py
app.py

你需要首先从 .env 文件开始,该文件用于设置应用程序的环境变量。刚开始时,它看起来像这样:

# Flask server configurations
FLASK_APP=app
FLASK_ENV=development
FLASK_DEBUG=1

# Secret key
SECRET_KEY = b'nnz\x89\x0f\xde\x8a\xc4\x13\xe0\xf0\xca>=\xe0\xfe'

第一行告诉 Flask 应用程序应用程序文件叫做 app.py(尚未创建)。使用 flask run 命令时,这将告诉应用程序运行 app.py 文件以启动应用程序。

我们还从开发环境开始,因为这只是目前在本地机器上。为此,我们设置FLASK_ENV = developmentFLASK_DEBUG=1,这表明我们有一个开发环境,并在环境启动时运行 Flask 调试器。这允许热重载,因此当您对代码进行任何更改时,不必重启实例即可在浏览器中查看这些更改!

此文件中的最后一行是秘密密钥。这是一个加密密钥,用于加密会话数据和客户端与服务器之间传输的其他敏感数据,是重要的安全措施。在这种情况下,我们使用一个 16 位密钥,该密钥通过以下命令生成:

python3 -c "import os; print(os.urandom(16))"

如果您有其他信息或密钥需要应用程序使用,例如 API 密钥,可以将这些信息添加到此文件中。由于这些是环境变量,通常使用蛇形命名法(空格替换为_),所有字母都大写。

当然,除非您与他人合作并希望安全地共享,否则您实际上不希望与任何人分享您的秘密或 API 密钥。这就是为什么您需要复制.env文件来创建.env.template文件,并删除您希望保密的任何信息。其他开发人员可以使用.env.template文件创建自己的.env文件,填入自己的密钥和信息,或在必要时安全地共享自己的密钥。

这意味着在您的.gitignore文件中,您只需写入:

.env

这将告诉 git 忽略对.env文件的任何更改,以确保您的秘密不会被共享。

下一阶段是创建flask_config.py文件,用于为您的应用程序创建一个配置对象。在我们的案例中,这个对象的形式为:

import os 
from dotenv import load_dotenv

load_dotenv()

class Config:
    def __init__(self):
        """Base configuration variables"""
        self.SECRET_KEY = os.environ.get("SECRET_KEY")
        if not self.SECRET_KEY:
            raise("Secret key is missing. Something is wrong")

该密钥使用了在.env文件中定义的SECRET_KEY变量。如果没有找到秘密密钥,将会引发错误,并阻止应用程序运行。这是为了确保应用程序按预期运行,并且是安全的,您可以根据需要添加其他配置参数。

最后,我们有app.py文件,它是 Flask 应用程序的入口点,用于运行 Flask 开发服务器。目前,这仅包含:

from application import app

if __name__ == "__main__":
    app.run()

该结构导入application文件夹中的app并运行它。建立了基础结构后,我们可以继续创建应用程序结构。

应用程序结构

我们在仓库中创建一个application文件夹,以结构化和模块化的方式组织应用程序代码和资源。具体而言,这种结构使得应用程序可以在需要时扩展为更全面的应用程序,通过创建多个路由、模型、视图模型和模板,从而提高代码的可维护性和可扩展性。然而,目前我们只需要一个简单的结构。

我们可以通过创建 __init__.py 文件开始,该文件用于初始化从基础文件夹中的 app.py 文件调用的应用程序。它包含:

from flask import Flask
from flask_config import Config

app = Flask(__name__)
app.config.from_object(Config)

from application import routes

它导入 flask 库和 Config 对象,初始化应用程序,然后导入路由。

重要的是,我们可以看到这个文件中没有定义任何路由或视图。为了保持模块化结构,这些决定应用程序行为的路由是在 routes.py 文件中定义的,与初始化分开。

因此,我们可以将 routes.py 文件定义如下:

from flask import render_template
from application import app

@app.route("/")
def index():
    return render_template("index.html")

这里有两个主要要点需要注意:

  1. 奇怪的装饰器语法 @app.route("/")。这个装饰器用于定义用户可以访问的路由,接受一个 URL 模式参数并将其与视图函数关联。在这个示例中,我们为根 URL (/) 定义了一个路由,并将其与 index() 函数关联。

  2. index() 函数不是返回一个值,而是返回函数 render_template("index.html")。这个函数由 flask 提供,用于允许函数返回预定义的 HTML 模板。

那么这些 HTML 模板是什么呢?

好吧,在 index() 函数中我们返回了 index.html 模板。从文件结构来看,它位于 templates 文件夹中,并包含:

<!doctype html>
<html lang="en">
  <head>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">

    <link rel="stylesheet" type="text/css" href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css" integrity="sha384-ggOyR0iXCbMQv3Xipma34MD+dH/1fQ784/j6cY/iJTQUOhcWr7x9JvoRxT2MZw1T" crossorigin="anonymous">
    <link rel="stylesheet" href="{{url_for('static', filename='css/style.css')}}">
    <title>{% block title %}Hello World{% endblock %}</title>
  </head>
  <body>

    <div class = "container body-content">
        <h1>Hello World!</h1>
    </div>

    <script src="https://code.jquery.com/jquery-3.3.1.slim.min.js" integrity="sha384-q8i/X+965DzO0rT7abK41JStQIAqVgRVzpbzo5smXKp4YfRvH+8abtTE1Pi6jizo" crossorigin="anonymous"></script>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.14.7/umd/popper.min.js" integrity="sha384-UO2eT0CpHqdSJQ6hJty5KVphtPhzWj9WO1clHTMGa3JDZwrnQq4sF86dIHNDz0W1" crossorigin="anonymous"></script>
    <script src="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/js/bootstrap.min.js" integrity="sha384-JjSmVgyd0p3pXB1rRibZUAYoIIy6OrQ6VrjIEaFf/nJGzIxFDsf4x0xIM+B07jRM" crossorigin="anonymous"></script>
  </body>
</html>

这本质上是一个 HTML 文件,在我们的例子中简单地显示“Hello World”,可以在 container 类中的 <h1></h1> 标签之间看到。

我们还在头部和主体中添加了一些额外的内容,主要用于导入 bootstrap,一个使开发变得更容易的 CSS 框架,以及 JQuery,使应用程序的交互更容易。但如果你愿意,可以忽略这些,你可以简单地在主体中定义 <h1>Hello World!</h1>

其他文件和文件夹

我们已经覆盖了应用程序的基本结构及其内容,但还有一些其他文件和文件夹没有涉及:

  • application/templates — 这个文件夹将包含应用程序使用的所有 HTML 模板。

  • application/templates/includes — 这个文件夹可以包含你可能想要在布局或其他文件中包含的辅助 HTML 结构,例如导航栏、页脚或大屏幕。

  • tests — 这个文件夹将包含所有对 flask 应用程序的测试,进一步分为 functionalunit 测试两个文件夹。

  • static — 这个文件夹包含了应用程序中使用的所有静态文件,包括 cssimages 和可能使用的 scripts

运行应用程序

现在我们已经完成了所有结构设置和环境配置,只需运行命令 flask run。你应该会看到类似如下的内容:

图片由作者提供

这告诉你应用程序正在运行,并且运行在 http://127.0.0.1:5000

这意味着你应该能够在浏览器中导航到这个 URL 或 localhost:5000 并查看输出:

图片由作者提供

这表明应用程序正在运行!恭喜你,成功运行了第一个 Flask 应用!

扩展功能

超越这里展示的基本结构,Flask 具有相当大的灵活性,允许各种不同的扩展:

  • Flask-wtf:一个可以以安全和可靠的方式处理用户输入的库

  • Flask-SQLAlchemy:一个允许你与 SQLAlchemy ORM 及数据库集成的库(还有其他库可以与如 MongoDB 等常见数据库进行交互)

  • Flask-Login:一个提供用户认证功能的库,如果你想根据用户角色管理对应用程序某些部分的访问权限。

在许多更多的库和用例中!这种结构通过允许你以模块化的方式构建应用程序,从而促进了这些扩展,提升了代码的可维护性、可扩展性和模块化。不论你是想通过路由服务你的机器学习算法的 API,还是想在网站上展示你的故事!

本文的代码可以在:github.com/PhilipDW183/flask_structure 获取

设置 Python 项目:第 V 部分

原文:towardsdatascience.com/setting-up-python-projects-part-v-206df3c1e3d3

掌握 Python 项目设置的艺术:逐步指南

Johannes SchmidtTowards Data Science Johannes Schmidt

·发表于Towards Data Science ·阅读时间 20 分钟·2023 年 1 月 14 日

--

照片由Zoya Loonohod拍摄,来自Unsplash

无论你是经验丰富的开发者还是刚刚开始接触🐍 Python,了解如何构建稳健且易于维护的项目都很重要。本教程将指导你使用一些行业内最流行且有效的工具来设置 Python 项目。你将学习如何使用GitHubGitHub Actions进行版本控制和持续集成,以及其他工具进行测试、文档编写、打包和分发。本教程的灵感来源于Hypermodern Python新 Python 项目的最佳实践。然而,这并不是唯一的方法,你可能有不同的偏好或观点。教程旨在对初学者友好,同时涵盖一些高级主题。在每个部分,你将自动化一些任务,并为你的项目添加徽章,以展示你的进展和成就。

该系列的代码库可以在github.com/johschmidt42/python-project-johannes找到

这一部分的灵感来自于这篇博客文章:

[Python、Poetry 与 GitHub Actions 的语义化发布 🚀

由于我的同事们的兴趣,我计划向 Dr. Sven 添加一些功能。在此之前,我需要…](https://mestrak.com/blog/semantic-release-with-python-poetry-github-actions-20nn)

要求

  • 操作系统: Linux、Unix、macOS、Windows(WSL2,例如 Ubuntu 20.04 LTS)

  • 工具:python3.10, bash, git, tree

  • 版本控制系统(VCS)主机GitHub

  • 持续集成(CI)工具GitHub Actions

预计你对版本控制系统(VCS)git有所了解。如果不了解,以下是一个复习: Git 介绍

提交将基于 最佳 git 提交实践传统提交。你可以使用 PyCharm 的传统提交插件VSCode 扩展 来帮助你以这种格式编写提交。

概述

结构

  • Git 分支策略 (GitHub 流程)

  • 什么是发布? (zip, tar.gz)

  • 语义版本控制 (v0.1.0)

  • 手动创建发布 (git tag, GitHub)

  • 自动创建发布 (传统提交,语义发布)

  • CI/CD (release.yml)

  • 创建个人访问令牌(PAT)

  • GitHub Actions 流程 (编排工作流)

  • 徽章 (发布)

  • 奖励 (强制执行传统提交)

发布软件是软件开发过程中的重要步骤,因为它使新功能和修复程序可供用户使用。发布软件的一个关键方面是版本控制,它有助于跟踪和传达每个发布中的变化。语义版本控制是一种广泛使用的软件版本控制标准,它使用格式为 Major.Minor.Patch(例如 1.2.3)的版本号来指示发布中所做更改的级别。

传统提交是一种为提交消息添加人类和机器可读意义的规范。它是一种以一致的方式格式化提交消息的方法,这使得确定所做更改的类型变得简单。传统提交通常与语义版本控制结合使用,因为提交消息可以用来自动确定发布的版本号。语义版本控制和传统提交一起提供了一种清晰且一致的方法来跟踪和传达每个软件项目发布中的更改。

Git 分支策略

git 有许多不同的分支策略。很多人倾向于使用GitFlow(或变种)、Three FlowTrunk based Flows。一些人使用这些策略中的混合策略,例如这个策略。我使用非常简单的GitHub flow分支策略,其中所有的 bug 修复和功能都有各自的独立分支,完成后,每个分支都会合并到主分支并进行部署。简单、好用且易于操作。

GitHub Flow 分支策略

无论你的策略是什么,最终你都会合并一个拉取请求,并(可能)创建一个版本发布。

什么是版本发布?

简而言之,发布就是将一个版本的代码打包(例如压缩文件),并推送到生产环境(这对你来说可能是任何东西)。

版本管理可能会很混乱。因此,需要有一个简明的方法(以及其他人也跟随的方法),定义什么是版本发布,以及一个版本与下一个版本之间的变化。如果你不跟踪版本之间的变化,你可能不会理解每个版本中发生了什么变化,也无法识别新代码中可能引入的任何问题。没有变更日志,很难理解软件如何随着时间的推移而发展。它也可能使回滚更改变得困难(如果必要的话)。

语义化版本控制

语义化版本控制只是一个编号方案和业界的标准实践。它指示了该版本与前一个版本之间的变更程度。一个语义版本号有三个部分,例如1.8.42,遵循以下模式:

  • MAJOR.MINOR.PATCH

每个部分代表了不同程度的变化。PATCH 版本发布表示错误修复或微小更改(例如从 1.0.0 到 1.0.1)。MINOR 版本发布表示添加/删除功能或向后兼容的功能更改(例如从 1.0.0 到 1.1.0)。MAJOR 版本发布表示添加/删除功能以及可能的向后不兼容的更改,例如破坏性更改(例如从 1.0.0 到 2.0.0)。

我推荐迈克·迈尔斯的一个讲座,如果你想要一个关于语义版本发布的视觉介绍。它总结了什么是发布,以及如何利用git 标签来创建版本发布。

关于git 标签:git 中有轻量级标签和注释标签。一个轻量级标签只是指向特定提交的指针,而注释标签则是 git 中的一个完整对象。

手动创建版本发布

让我们先手动创建一个版本发布,然后再进行自动化处理。

如果你记得,我们的 example_app 的 __init__.py 文件包含了版本信息。

# src/example_app/__init__.py

__version__ = "0.1.0"

以及 pyproject.toml 文件

# pyproject.toml

[tool.poetry]
name = "example_app"
version = "0.1.0"
...

所以我们首先必须做的是创建一个注释的 git 标签 v0.1.0 并将其添加到主分支的最新提交中:

> git tag -a v0.1.0 -m "version v0.1.0"

请注意,如果在命令末尾没有指定提交哈希,则 git 会使用你当前所在的提交。

我们可以通过以下命令获取标签列表:

> git tag

v0.1.0

如果我们想要再次删除它:

> git tag -d v0.1.0

Deleted tag 'v0.1.0'

并通过以下命令获取有关该标签的更多信息:

> git show v0.1.0

tag v0.1.0

Tagger: Johannes Schmidt <johannes.schmidt.vik@gmail.com>
Date:   Sat Jan 7 12:55:15 2023 +0100
version v0.1.0
commit efc9a445cd42ce2f7ddfbe75ffaed1a5bc8e0f11 (HEAD -> main, tag: v0.1.0, origin/main, origin/HEAD)
Author: Johannes Schmidt <74831750+johschmidt42@users.noreply.github.com>
Date:   Mon Jan 2 11:20:25 2023 +0100
...

我们可以通过以下命令将新创建的标签推送到 origin:

> git push origin v0.1.0

Enumerating objects: 1, done.
Counting objects: 100% (1/1), done.
Writing objects: 100% (1/1), 171 bytes | 171.00 KiB/s, done.
Total 1 (delta 0), reused 0 (delta 0), pack-reused 0
To github.com:johschmidt42/python-project-johannes.git
 * [new tag]         v0.1.0 -> v0.1.0

使得这个 git 标签现在可以在 GitHub 上使用:

让我们手动在 GitHub 上创建一个新的版本发布,并使用这个 git 标签:

我们点击 Create a new release,选择我们现有的标签(已经绑定到提交),然后通过点击 Generate release notes 按钮自动生成发布说明,最后用 Publish release 按钮发布该版本。

GitHub 将自动为源代码创建 tarzip(资产),但不会构建应用程序!结果将如下所示:

总结一下,发布的步骤是:

  • 从你的默认分支创建一个新分支(例如功能或修复分支)

  • 进行更改并增加版本(例如 pyproject.tomlinit.py

  • 将功能/错误修复提交到默认分支(可能通过 Pull Request)

  • 添加一个 注释的 git 标签(语义版本)到提交中

  • 在 GitHub 上发布版本,并附加一些额外信息

自动创建发布

作为程序员,我们不喜欢重复自己。因此,有很多工具可以让这些步骤变得非常简单。在这里,我将介绍Semantic Releases,一个专门为 Python 项目设计的工具。

这是一个自动在你的仓库中设置版本号、用版本号标记代码并创建发布的工具!这一切都是基于 约定式提交 风格消息的内容完成的。

约定式提交

语义版本控制和 conventional-commits 之间有什么联系?

某些提交类型可以用于自动确定语义版本的提升!

  • 一个 fix 提交是 PATCH。

  • 一个 feat 提交是 MINOR。

  • 一个带有 BREAKING CHANGE! 的提交是 MAJOR。

其他类型的提交,例如 buildchorecidocsstylerefactorperftest 通常不会增加版本。

查看最后的附加部分,了解如何在你的项目中强制执行约定式提交!

自动语义版本发布(本地)

我们可以通过以下命令添加库:

> poetry add --group semver python-semantic-release

让我们深入配置设置,以便自动生成变更日志和发布。在 pyproject.toml 中,我们可以将 semantic_release 作为工具添加:

# pyproject.toml

...
[tool.semantic_release]
branch = "main"
version_variable = "src/example_app/__init__.py:__version__"
version_toml = "pyproject.toml:tool.poetry.version"
version_source = "tag"
commit_version_number = true # required for version_source = "tag"
tag_commit = true
upload_to_pypi = false
upload_to_release = false
hvcs = "github" # gitlab is also supported
  • branch:指定发布应基于的分支,在这种情况下是 "main" 分支。

  • version_variable:指定源代码中版本号的文件路径和变量名称。在这种情况下,版本号存储在文件 src/example_app/__init__.py 中的 __version__ 变量中。

  • version_toml:指定pyproject.toml文件中版本号的文件路径和变量名称。在这种情况下,版本号存储在 pyproject.toml 文件的 tool.poetry.version 变量中。

  • version_source:指定版本号的来源。在这种情况下,版本号来自标签(而不是提交)。

  • commit_version_number:当version_source = "tag"时,此参数是必需的。它指定是否将版本号提交到仓库。在这种情况下,它设置为 true,这意味着版本号将被提交。

  • tag_commit:指定是否为发布提交创建新的标签。在这种情况下,它设置为 true,这意味着将创建一个新的标签。

  • upload_to_pypi:指定是否将软件包上传到 PyPI 包仓库。在这种情况下,它设置为 false,这意味着软件包不会上传到 PyPI。

  • upload_to_release:指定是否将软件包上传到 GitHub 发布页面。在这种情况下,它设置为 false,这意味着软件包不会上传到 GitHub 发布页面。

  • hvcs:指定项目的托管版本控制系统。在这种情况下,它设置为 "github",这意味着项目托管在 GitHub 上。"gitlab" 也是支持的。

我们可以更新定义项目/模块版本的文件。为此,我们使用变量version_variable用于普通文件,version_toml用于.toml 文件。version_source定义了版本的真实性来源。由于这两个文件中的版本与 git 注释标签紧密耦合,例如我们每次发布时自动创建 git 标签(标志tag_commit设置为 true),我们可以使用源tag,而不是默认值commit,后者在提交信息中查找最后一个版本。为了能够更新文件并提交更改,我们需要设置 [commit_version_number](https://github.com/relekang/python-semantic-release/issues/104) 标志为 true。因为我们不想将任何东西上传到 Python 索引 PyPi,所以标志upload_to_pypi设置为 false。现在我们也不想将任何东西上传到我们的发布页面。hvcs设置为github(默认),其他值可以是:gitlab

我们可以通过运行几个命令在本地测试这一点,我将直接将这些命令添加到我们的 Makefile 中:

# Makefile

...

##@ Releases

current-version: ## returns the current version
 @semantic-release print-version --current

next-version: ## returns the next version
 @semantic-release print-version --next

current-changelog: ## returns the current changelog
 @semantic-release changelog --released

next-changelog: ## returns the next changelog
 @semantic-release changelog --unreleased

publish-noop: ## publish command (no-operation mode)
 @semantic-release publish --noop

使用命令current-version我们可以从 git 树中的最后一个 git 标签获取版本:

> make current-version

0.1.0

如果我们以传统提交风格添加一些提交,例如 feat: new cool featurefix: nasty bug,那么命令 next-version 将计算版本号的增量:

> make next-version

0.2.0

目前,我们的项目中没有 CHANGELOG 文件,因此当我们运行:

> make current-changelog 

输出将是空的。但根据提交记录,我们可以使用以下方法创建即将发布的变更日志:

> make next-changelog### Feature
* Add releases (#8)) (`5343f46`))
* Docstrings (#5)) (`fb2fa04`))
* Add application in app.py (`3f07683`))### Documentation
* Add search bar & github url (#6)) (`3df7c48`))
* Add badge pages.yml to README.py (`b76651c`))
* Add documentation to Makefile (#3)) (`2294ee1`))

如果我们推送新的提交(直接到主分支或通过 PR),我们现在可以发布一个新版本:

> semantic-release publish

发布命令将执行一系列操作:

  1. 更新或创建变更日志文件。

  2. 运行 semantic-release version

  3. 将更改推送到 git。

  4. 运行 build_command 并将分发文件上传到你的仓库。

  5. 运行 semantic-release changelog 并发布到你的 VCS 提供者。

  6. 将由 build_command 创建的文件附加到 GitHub 发布中。

每一步当然都可以配置或禁用!

CI/CD

让我们使用 GitHub Actions 构建一个 CI 流水线,每次提交到主分支时运行 semantic-release 的发布命令。

虽然整体结构与 lint.ymltest.ymlpages.yml 相同,但有一些变化需要说明。在步骤 Checkout repository 中,我们添加了一个新的 token,用于检出分支。这是因为默认值 GITHUB_TOKEN 没有操作受保护分支所需的权限。因此,我们必须使用一个包含 个人访问令牌 权限的秘密 (GH_TOKEN) 。稍后我会展示如何生成个人访问令牌。我们还定义了 fetch-depth: 0 以提取所有分支和标签的全部历史记录。

with:
  ref: ${{ github.head_ref }}
  token: ${{ secrets.GH_TOKEN }}
  fetch-depth: 0

我们仅安装 semantic-release 工具所需的依赖项:

- name: Install requirements
  run: poetry install --only semver

在最后一步,我们更改一些 git 配置并运行 semantic-release 的发布命令:

- name: Python Semantic Release
  env:
    GH_TOKEN: ${{ secrets.GH_TOKEN }}
  run: |
    set -o pipefail
    # Set git details
    git config --global user.name "github-actions"
    git config --global user.email "github-actions@github.com"
    # run semantic-release
    poetry run semantic-release publish -v DEBUG -D commit_author="github-actions <action@github.com>"

通过更改 git 配置,提交的用户将会是“github-actions”。我们以 DEBUG 日志(stdout)运行发布命令,并显式将 commit_author 设置为“github-actions”。除了这个命令,我们还可以直接使用 semantic-release 的 GitHub action,但设置步骤 非常少,且该 action 每次都需要拉取 docker 容器。因此,我更倾向于采用简单的运行步骤。

因为发布命令会生成提交,您可能会担心我们会陷入触发工作流的无限循环。但请放心,生成的提交不会触发另一个 GitHub Actions 工作流运行。这是由于 GitHub 设定的限制

创建个人访问令牌(PAT)

个人访问令牌是使用密码进行 GitHub Enterprise Server 身份验证的替代方案,当使用 GitHub API命令行 时。个人访问令牌旨在代表您访问 GitHub 资源。要代表组织访问资源或用于长期集成,您应该使用 GitHub 应用。有关更多信息,请参见“关于应用”。

换句话说:我们可以创建一个 Personal Access Token,并让 GitHub Actions 存储并使用该 secret 代表我们执行某些操作。请记住,如果 PAT 被泄露,可能会被用于对您的 GitHub 仓库执行恶意操作。因此,建议在组织中使用 GitHub OAuth 应用和 GitHub 应用。为了本教程的目的,我们将使用 PAT 允许 GitHub Actions 流水线代表我们操作。

我们可以通过导航到 GitHub 用户的 Settings 部分并按照 创建个人访问令牌 中总结的说明来创建新的访问令牌。这将给我们一个看起来像这样的窗口:

具有推送访问权限的管理员帐户的个人访问令牌。

通过选择作用域,我们定义令牌将具有的权限。对于我们的用例,我们需要 push access 权限,因此新的 PAT GH_TOKEN 应该具有 repo 权限作用域。该作用域将授权对受保护分支的推送,前提是您没有在受保护分支的设置中启用 包括管理员

回到代码库概览,在 设置 菜单中,我们可以在 密钥 部分添加环境设置或库设置:

仓库密钥特定于单个仓库(及其中使用的所有环境),而环境密钥特定于环境。GitHub 运行器可以配置为在特定环境中运行,这允许它访问该环境的密钥。这在考虑不同阶段(例如 DEV 与 PROD)时是有意义的,但对于本教程,我对仓库密钥感到满意。

GitHub Actions 流程

现在我们有了几个管道(linting、testing、releasing、documentation),我们应该考虑主分支提交的动作流程!有一些我们需要注意的事项,其中一些是特定于 GitHub 的。

理想情况下,我们希望主分支的提交创建一个推送事件,从而触发测试和 linting 工作流。如果这些工作流成功,我们将运行发布工作流,该工作流负责基于传统提交检测是否需要版本提升。如果是这样,发布工作流将直接推送到主分支,提升版本,添加 git 标签并创建发布。发布的版本应当例如通过运行文档工作流来更新文档。

预期的动作流程

问题与考虑

  1. 如果你仔细阅读上一段或查看上面的流程图,你可能会注意到有两个主分支的提交。一个是初始的(即来自 PR),另一个是用于发布的。由于我们的lint.ymltest.yml在主分支的推送事件下会触发,因此它们会运行两次!为了节省资源,我们应该避免两次运行。为此,我们可以在版本提交消息中添加[skip ci]字符串。可以在pyproject.toml文件中为工具semantic_release定义自定义提交消息。
# pyproject.toml

...

[tool.semantic_release]
...
commit_message = "{version} [skip ci]" # skip triggering ci pipelines for version commits
...

2. 工作流pages.yml当前在推送到主分支时运行。更新文档可能只是我们希望在有新版本发布时做的事情(我们可能会在文档中引用版本)。我们可以相应地更改pages.yml文件中的触发器:

# pages.yml

name: Documentation

on:
  release:
    types: [published]

现在,构建文档将需要已发布的版本

3. 发布工作流应该依赖于 linting 和 testing 工作流的成功。目前,我们在工作流文件中没有定义依赖关系。我们可以让这些工作流依赖于特定分支上定义的工作流运行的完成,使用[workflow_run](https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#workflow_run)事件。然而,如果我们为workflow_run事件指定多个workflows

on:
  workflow_run:
    workflows: [Testing, Linting]
    types:
    - completed
    branches:
    - main

仅一个工作流需要完成!这不是我们所期望的。我们期望所有工作流都必须完成(并成功)。只有在这种情况下,发布工作流才应运行。这与在单个工作流中定义作业之间的依赖关系时得到的结果相反。更多关于这种不一致和不足的内容,请阅读这里

作为替代方案,我们可以使用流水线的顺序执行:

这种想法的一个大缺点是它 a) 不允许并行执行和 b) 我们将无法在 GitHub 中查看依赖图。

解决方案

目前,我认为解决上述问题的唯一方法是将工作流在一个协调器工作流中进行协调。

让我们创建这个工作流文件:

当我们推送到 main 分支时,协调器被触发。

只有在两个工作流:Testing 和 Linting 都成功时,才会调用发布工作流。这在 needs 关键字中定义。如果我们希望对作业执行(工作流)有更细致的控制,也可以考虑使用 if 关键字。但要注意,如 这篇文章 所述的令人困惑行为。

为了使我们的工作流 lint.ymltest.ymlrelease.yml 可以被另一个工作流调用,我们需要更新触发器:

# lint.yml

---
name: Linting

on:
  pull_request:
    branches:
      - main
  workflow_call:

jobs:
...
# test.yml

---
name: Testing

on:
  pull_request:
    branches:
      - main
  workflow_call:

jobs:
...
# release.yml

---
name: Release

on:
  workflow_call:

jobs:
...

现在,新工作流(Release)应该仅在质量检查工作流成功的情况下运行,这里指的是 linting 和 testing。

徽章

要创建徽章,这次我将使用平台 shields.io

这是一个为项目生成徽章的网站,徽章显示诸如版本、构建状态和代码覆盖率等信息。它提供了广泛的模板,并允许定制外观和创建自定义徽章。徽章会自动更新,提供项目的实时信息。

对于发布徽章,我选择了 GitHub release (latest SemVer)

徽章的 Markdown 可以复制并添加到 README.md 中:

我们的 GitHub 登录页面现在看起来是这样的 ❤(我稍微整理了一下,并提供了描述):

恭喜!你已经完成了本教程的主要部分!你已经学习了管理 软件发布 的基本步骤。我们首先手动创建了一个发布,然后利用 Conventional Commits 的力量,通过 CI pipeline 自动化我们的发布过程,处理版本控制。最后,我们在 README.md 文件中添加了 徽章,为用户提供了项目最新版本的清晰而简洁的显示。掌握了这些技巧,你将能够高效而有效地管理你的软件发布。

下一部分将是最后一部分,涵盖:容器化

[## 通过我的推荐链接加入 Medium - Johannes Schmidt

阅读 Johannes Schmidt 的每一篇故事(以及 Medium 上其他数千名作者的故事)。您的会员费直接…

johschmidt42.medium.com

奖励

确保使用规范提交

我们已经看到,按照定义格式的提交可以帮助我们进行版本控制。在一个协作项目中,我们可能希望对所有提交到默认分支的提交强制执行这种格式。两个流行的工具可以帮助开发者遵循规范提交格式:

然而,一些开发者觉得这些工具有点限制性,因此避免使用它们*。所以仅仅希望总是有规范提交可能不是一个好主意。因此,在服务器端强制执行规则,如规范提交格式,才是明智的!

*同样适用于 pre-commit 钩子,这也是我在这一系列中排除它们的原因。

不幸的是,目前(2023 年 5 月)在 GitHub 上基于规则阻止提交仍然不可行,因为 该功能仍在开发中。但我们可以通过分支保护规则CI 工作流尽可能接近这个目标。以下是我们在仓库中需要的策略:

  • 对受保护的默认分支(例如 main)的提交应该限制为拉取请求(PR)提交。

  • 只有 压缩提交 应该被允许*。

  • 合并拉取请求时展示的默认提交信息应该是拉取请求 标题

如果对受保护的默认分支(例如 main)的唯一提交方式是通过拉取请求(压缩提交 仅限),我们可以使用 GitHub Action,如 amannn/action-semantic-pull-request,确保拉取请求的标题符合 规范提交规范。这样,当我们 squash and merge PR 分支(假设所有必需的流水线成功)时,建议的提交信息就是 PR 标题,该标题之前由 GitHub action 运行检查过。

*Squash and merge 策略是一种流行的代码合并方法,它将功能分支中的多个提交合并为一个提交。这种方法创建了一个线性的、一致的 git 历史记录,其中每个提交代表一个特定的更改。然而,这种方法也有其缺点,因为它丢弃了详细的提交历史记录,这对于理解开发过程是有价值的。虽然可以使用 rebase 合并来保留这些信息,但这可能会给工作流带来复杂性。从这个角度来看,squash and merge 策略因其简单性而受到青睐。

工作流

让我们为这个策略创建 GitHub Actions 工作流:

触发事件 pull_request_target 的解释见 这里。我使用了建议的类型 openededitedsynchronizeGITHUB_TOKEN 被作为 env 传递给 action。因此,每当 PR 的标题发生变化时,管道就会触发。只有当 PR 的标题符合约定的提交格式时,管道才会成功。

请注意

你需要在主分支中拥有此配置,以便 action 能够运行(例如,它不会在初次添加 action 的 PR 中运行)。此外,如果你在 PR 中更改配置,当前 PR 中的更改将不会被反映 —— 只有在更改被合并到主分支后,随后的 PR 才会反映这些更改。

所以我们必须首先在默认分支 main 中拥有这个工作流,只有这样我们才能看到它的实际效果。

分支保护规则

接下来,在 GitHub 仓库的 设置 部分,我们可以为 main 分支创建一个 分支保护规则

主要分支的分支保护规则 — 作者提供的图片

现在一个提交需要一个通过状态检查(必需工作流)的 PR 才能合并*。

一个 必需的工作流 由拉取请求事件触发,并作为必需的状态检查出现,这会阻止合并拉取请求,直到必需的工作流成功。

所需工作流 — 作者提供的图片

组织所有者有能力在其组织内强制执行特定的工作流,例如要求对拉取请求进行状态检查。不幸的是,这个功能仅对组织可用,个人账户无法激活,因此无法阻止合并。

*请注意,规则不会在私有仓库中生效,直到它被 迁移到 GitHub Team 或 Enterprise 组织账户

Squash & merge 策略

最后,我们可以配置 PR 选项,以便在选择 squash and merge 按钮时使用 PR 的标题作为默认提交消息:

默认的提交消息在“压缩和合并”时 — 图片由作者提供

这样,我们会看到一个类似这样的窗口:

在 PR 中的“压缩和合并”对话框 — 图片由作者提供

请注意,开发者可能会在合并过程中更改标题名称,这将绕过策略!

尽管我们还不能完全确保在 GitHub 上使用传统的提交方式,但我们应尽量做到尽可能接近。

设置 Python 项目: 第六部分

原文:towardsdatascience.com/setting-up-python-projects-part-vi-cbdbf28eff53

掌握 Python 项目设置的艺术: 一步一步的指南

Johannes SchmidtTowards Data Science Johannes Schmidt

·发布于 Towards Data Science ·阅读时间 26 分钟·2023 年 4 月 10 日

--

图片由 Amira El Fohail 提供,来源于 Unsplash

无论你是经验丰富的开发者还是刚刚开始接触🐍 Python,了解如何构建稳健且可维护的项目都是很重要的。本文教程将引导你完成使用一些行业内最受欢迎和有效的工具来设置 Python 项目的过程。你将学习如何使用 GitHubGitHub Actions 进行版本控制和持续集成,以及其他用于测试、文档编写、打包和分发的工具。该教程灵感来源于 Hypermodern Python新 Python 项目的最佳实践。然而,这并不是唯一的方法,你可能有不同的偏好或意见。教程旨在对初学者友好,同时也涵盖一些高级主题。在每个部分,你将自动化一些任务,并为你的项目添加徽章以展示你的进展和成就。

本系列的代码库可以在 github.com/johschmidt42/python-project-johannes 找到

要求

  • 操作系统: Linux, Unix, macOS, Windows (WSL2,例如 Ubuntu 20.04 LTS)

  • 工具: python3.10, bash, git, tree

  • 版本控制系统 (VCS) 主机: GitHub

  • 持续集成 (CI) 工具: GitHub Actions

预计你对版本控制系统 (VCS) git 已经熟悉。如果不熟悉,这里有一个复习资料:Git 入门介绍

提交将基于 最佳 git 提交实践约定式提交。对于 PyCharm,有 约定式提交插件 或者 VSCode 扩展 可以帮助你按照这种格式撰写提交。

概述

结构

  • 容器化

  • Docker

  • Dockerfile

  • Docker 镜像

  • Docker 容器

  • Docker 阶段 (基础、构建、生产)

  • 容器注册中心 (ghcr.io)

  • Docker 推送

  • CI (build.yml & build_and_push.yml)

  • 徽章 (构建)

  • 奖励 (trivy)

在这篇文章中,我们将深入探讨容器化的概念及其好处,以及如何与Docker结合使用,以创建和管理容器化应用。我们将使用GitHub Actions来持续构建 Docker 镜像并在发布新版本时将其上传到我们的仓库。

容器化

容器化是一项现代技术,它彻底改变了软件应用的开发、部署和管理方式。近年来,由于它能够解决软件开发和部署中的一些重大挑战,已获得广泛采用。

简单来说,容器化是将应用程序及其所有依赖项打包成一个单一的容器的过程。这个容器是一个轻量、可移植、自给自足的单元,可以在不同的计算环境中一致地运行。它为应用程序提供了一个隔离的环境,确保它在任何底层基础设施下都能一致运行。它使开发者能够创建可扩展、可移植且易于管理的应用程序。此外,容器通过将应用程序与宿主系统隔离,提供了额外的安全层。如果你听到有人说“它在我的电脑上能工作”,这已经不再有效,因为你可以并且应该在 Docker 容器中测试你的应用。这确保了它在不同环境中的一致性。

总之,容器化是一项强大的技术,它允许开发者创建可靠、高效且易于管理的容器化应用,使他们能够专注于开发优秀的软件。

Docker

Docker 是一个流行的容器化平台,允许开发人员创建、部署和运行容器化应用程序。它提供了一系列工具和服务,使得将应用程序打包和部署为容器化格式变得简单。使用 Docker,开发人员可以在几分钟内创建、测试和部署应用程序,而不是几天或几周。

要使用 docker 创建这样的容器化应用程序,我们需要

  1. Dockerfile构建一个Docker 镜像

  2. 从 Docker 镜像创建一个容器

为此,我们将使用 docker CLI。

Dockerfile

Dockerfile 是一个包含所有构建给定镜像所需命令的文本文件。它遵循特定的格式和指令集,你可以在这里找到相关信息。

本节的目标是创建一个构建我们 Python 包的 wheel 文件的 Dockerfile:

**FROM python:3.10-slim** 
**WORKDIR /app** # install poetry **ENV POETRY_VERSION=1.2.0
RUN pip install "poetry==$POETRY_VERSION"** # copy application **COPY ["pyproject.toml", "poetry.lock", "README.md", "./"]
COPY ["src/", "src/"]** # build wheel **RUN poetry build --format wheel** # install package **RUN pip install dist/*.whl**

这个 Dockerfile 本质上是一组指令,告诉 Docker 如何为 Python 应用程序构建一个容器。它以python:3.10-slim作为基础镜像开始,这是一种已经预装了一些基本库和依赖项的 Python 3.10 精简版镜像。

第一个指令WORKDIR /app将工作目录设置为容器内的/app,应用程序将被放置在此目录中。

下一个指令ENV POETRY_VERSION=1.2.0设置一个名为POETRY_VERSION的环境变量为1.2.0,此变量将在下一条命令中用于安装 Poetry 包管理器。

RUN pip install "poetry==$POETRY_VERSION"命令在容器内安装 Poetry 包管理器,用于管理 Python 应用程序的依赖项。

下一个指令COPY ["pyproject.toml", "poetry.lock", "README.md", "./"]将项目文件(包括pyproject.tomlpoetry.lockREADME.md)复制到容器中。

README.md文件是必需的,因为在 pyproject.toml 中有引用。没有它,我们将无法构建 wheel。

指令COPY ["src/", "src/"]将应用程序的源代码复制到容器中。

RUN poetry build --format wheel命令使用poetry.lock文件和应用程序的源代码为 Python 应用程序构建一个Python wheel包。

最后,最后一条指令RUN pip install dist/*.whl使用pip安装包,并安装位于dist目录中的生成的.whl包文件。

总之,这个 Dockerfile 设置了一个包含 Python 3.10 和已安装 Poetry 的容器,复制了应用程序源代码和依赖项,构建了一个包 wheel 并安装它。

这还不会运行应用程序。但不用担心,我们将在接下来的章节中更新它。我们必须首先了解使用 Docker 的流程。

Docker 镜像

我们已经创建了一个 Dockerfile,其中包含构建 Docker 镜像的指令。为什么我们需要 Docker 镜像?因为它允许我们构建 Docker 容器

让我们运行 docker build 命令来创建我们的镜像:

**> docker build --file Dockerfile --tag project:latest .**

...
 => [7/7] RUN pip install dist/*.whl                                                                                                                                                                                                                                                                              30.7s
 => exporting to image                                                                                                                                                                                                                                                                                             0.5s 
 => => exporting layers                                                                                                                                                                                                                                                                                            0.5s 
 => => writing image sha256:bb2acf440f4cf24ac00f051b1deaaefaf4e41b87aa26c34342cbb6faf6b55591                                                                                                                                                                                                                       0.0s 
 => => naming to docker.io/library/project:latest

此命令用于从 Dockerfile 构建 Docker 镜像,并使用指定的名称和版本标记它。让我们来解析一下命令:

  • docker build:这是用于构建 Docker 镜像的命令。

  • --file Dockerfile:此选项指定用于构建镜像的 Dockerfile 的路径和名称。在这种情况下,它被简单地命名为 Dockerfile,所以它使用了默认名称。

  • --tag project:latest:此选项指定要创建的镜像的名称和版本。在这种情况下,镜像名称为 project,版本为 latestproject 是给镜像的名称,而 latest 是版本号。你可以用你选择的名称和版本替换 projectlatest

  • .:此选项指定了构建上下文,即用于构建镜像的文件位置。在这种情况下,. 指当前执行命令的目录。

因此,当执行此命令时,Docker 会读取当前目录中的 Dockerfile,并使用它来构建一个名为 project:latest 的新镜像。我们可以通过运行以下命令找到有关结果镜像(及其他镜像)的更多信息:

**> docker images**

REPOSITORY   TAG       IMAGE ID       CREATED         SIZE
project      latest    bb2acf440f4c   2 minutes ago   271MB

我们的镜像大小为 271 mb。大小将在后续减少。

Docker 容器

我们可以使用 docker run 命令从 Docker 镜像创建/运行一个 Docker 容器。该命令需要一个参数,即镜像的名称。例如,如果你的镜像名为 myimage,你可以使用以下 命令 运行它:docker run myimage

如果我们像这样运行我们的应用:

**> docker run -it --rm project:latest**

它将打开一个 Python 终端(你可以使用 CTRL + DCMD + D 关闭会话;-it 选项用于以交互模式运行容器,并提供伪终端(终端仿真)。这允许你与容器的 shell 进行交互,并实时查看其输出。-rm 选项用于在容器退出时自动删除容器。)

Python 3.10.10 (main, Mar 23 2023, 03:59:34) [GCC 10.2.1 20210110] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>>

为什么它会打开一个 Python 会话?这是因为 Docker 镜像的 entrypoint 默认指向标准 python:3.10-slim 镜像中的 Python 解释器。如果我们想查看容器内部,我们必须覆盖 entrypoint。因为 bash 默认安装在此构建中,我们可以通过以下命令运行 Docker 容器并进入其中:

**> docker run -it --rm project:latest /bin/bash**

root@76eb4cb2d8fb:/app#

所以我们用 /bin/bash 覆盖了 entrypoint。

现在我们可以检查容器内部的内容:

app
├── README.md
├── dist
│   └── example_app-0.3.0-py3-none-any.whl
├── poetry.lock
├── pyproject.toml
└── src
    └── example_app

我们可以使用以下命令检查已安装的包:

**> pip freeze**

...
dulwich==0.20.50
**example-app** @ file:///app/dist/example_app-0.3.0-py3-none-any.whl
fastapi==0.85.2
...

太好了,我们可以进入容器,这对于故障排除非常有用。但我们如何让它运行我们的应用程序?我们的应用程序安装在哪里?默认情况下,包可以在 Python 安装的site-packages 目录中找到。要获取这些信息,我们可以使用pip show 命令:

**> pip show example-app**

Name: example-app
Version: 0.3.0
Summary: 
Home-page: https://github.com/johschmidt42/python-project-johannes
Author: Johannes Schmidt
Author-email: johannes.schmidt.vik@gmail.com
License: MIT
Location: **/usr/local/lib/python3.10/site-packages**
Requires: fastapi, httpx, uvicorn
Required-by:

由于uvicorn(我们的 ASGI 服务器实现)默认安装,我们可以cd 进入 /usr/local/lib/python3.10/site-packages/example_app

并使用uvicorn 命令运行应用程序:

**> uvicorn app:app --host 0.0.0.0 --port 80 --workers 1**

INFO:     Started server process [17]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:80 (Press CTRL+C to quit)

其中 app:app 遵循 <file_name>:<variable_name> 模式。

应用程序在 docker 容器中运行在端口 80,并使用1 个工作进程。为了在主机(你的机器)上访问,我们需要暴露容器端口并将其发布到主机。这可以通过在 docker run 命令中添加 --expose--publish 标志来完成。或者,我们可以通过在Dockerfile中定义某个端口来让容器暴露该端口。我们稍后会做这个。之前,我们要做的是:

我们的应用程序可以在 site-packages 目录中找到。这要求我们在运行 uvicorn app:app 命令之前更改目录。如果我们想避免更改目录,我们可以创建一个文件来为我们导入应用程序。以下是一个示例:

添加一个 main.py

# main.py

from example_app.app import app

if __name__ == '__main__':
    print(app.title)

main.py 中导入应用程序,以便 uvicorn 可以使用它。如果我们现在将这个文件复制到我们的 /app 目录:

# Dockerfile
...
**COPY ["main.py", "./"]**
...

我们可以运行应用程序

**> uvicorn main:app --host 0.0.0.0 --port 80 --workers 1**

INFO:     Started server process [8]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:80 (Press CTRL+C to quit)

太好了。现在让我们将这个命令设置为启动容器时的入口点

FROM python:3.10-slim

WORKDIR /app

# install poetry
ENV POETRY_VERSION=1.2.0
RUN pip install "poetry==$POETRY_VERSION"

# copy application
COPY ["pyproject.toml", "poetry.lock", "README.md", "**main.py**", "./"]
COPY ["src/", "src/"]

# build wheel
RUN poetry build --format wheel

# install package
RUN pip install dist/*.whl

# expose port
**EXPOSE 80**

# command to run
**CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80", "--workers", "1"]**

现在我们将 main.py 文件复制到 /app 目录。EXPOSE 指令通知 Docker 容器在运行时监听指定的网络端口。在这种情况下,它暴露了端口 80

CMD 指令指定了在容器内运行的命令。在这里,它运行命令 uvicorn main:app --host 0.0.0.0 --port 80 --workers 1。这个命令启动了一个 uvicorn 服务器,运行 main:app 应用程序,监听主机 0.0.0.0 和端口 80,使用 1 个工作进程。

然后我们可以使用docker run 命令运行容器:

> **docker run -p 9000:80 -it --rm project:latest**

[2023-01-30 21:04:33 +0000] [1] [INFO] Starting gunicorn 20.1.0
[2023-01-30 21:04:33 +0000] [1] [INFO] Listening at: http://0.0.0.0:80 (1)
[2023-01-30 21:04:33 +0000] [1] [INFO] Using worker: uvicorn.workers.UvicornWorker
[2023-01-30 21:04:33 +0000] [7] [INFO] Booting worker with pid: 7
[2023-01-30 21:04:34 +0000] [7] [INFO] Started server process [7]
[2023-01-30 21:04:34 +0000] [7] [INFO] Waiting for application startup.
[2023-01-30 21:04:34 +0000] [7] [INFO] Application startup complete.

docker run 命令中的 -p 标志用于将容器的端口发布到主机。在这种情况下,它将主机上的端口 9000 映射到容器上的端口 80。这意味着任何发送到主机端口 9000 的流量都会被转发到容器端口 80

我们可以看到在容器中运行的应用程序可以被访问:

运行在 Docker 容器中的 fastAPI 应用程序 — 作者图片

重要备注:我建议在生产构建中使用 gunicorn 代替 uvicorn!为了完整性,这里是 Dockerfile 的替代版本:

FROM python:3.10-slim

WORKDIR /app

# install poetry
ENV POETRY_VERSION=1.2.0
RUN pip install "poetry==$POETRY_VERSION"

# install gunicorn (ASGI web implementation)
**RUN pip install gunicorn==20.1.0**

# copy application
COPY ["pyproject.toml", "poetry.lock", "README.md", "./"]
COPY ["src/", "src/"]

# build wheel
RUN poetry build --format wheel

# install package
RUN pip install dist/*.whl

# expose port
EXPOSE 80

# command to run
CMD ["**gunicorn**", "main:app", "**--bind**", "**0.0.0.0:80**", "--workers", "1", "**--worker-class", "uvicorn.workers.UvicornWorker**"]

这两个有什么区别?

Uvicorn 是一个支持 ASGI 协议的 ASGI 服务器。它建立在 uvloophttptools 之上,并以其性能优势而闻名。然而,它作为进程管理器的能力还有待提高

Gunicorn 另一方面是一个成熟且功能全面的服务器和进程管理器。它是从 Ruby 的 Unicorn 项目移植而来的预叉工人模型,并与各种 web 框架广泛兼容

Docker 阶段

Docker 阶段是一项功能,允许你在 Dockerfile 中创建多个阶段。每个阶段可以有自己的 基础镜像 和指令集。你可以选择性地将工件从一个阶段复制到另一个阶段,留下你不想要的内容。此功能很有用,因为它可以通过减少 Docker 镜像的大小和复杂性来优化 Docker 镜像

使用 Docker 阶段,我们可以(并且应该!)优化我们的 Docker 镜像。所以我们想要实现的是:

  • poetry 不应该出现在生产构建中

  • 生产构建应该仅包含运行应用程序所需的最少内容

这就是我们要做到的方法:我们创建一个干净的 基础 阶段。从基础阶段,我们有一个 构建 阶段,安装 poetry 并构建 wheel。另一个阶段,生产, 可以从构建阶段复制这个工件(.whl 文件)并使用它。这样我们可以避免在生产构建中安装 poetry,同时也将其限制为仅包含必要内容,从而减少最终镜像的大小。

关于 Docker 中的 poetry

我见过不同的策略将 poetry 与 Docker 结合使用。

  • 创建一个 虚拟环境,然后将整个 venv 从一个阶段复制到另一个阶段。

  • poetry.lock 文件创建 requirements.txt 文件,并使用这些文件通过 pip 安装要求。

在第一种情况下,Poetry 在构建镜像时安装。在第二种情况下,Poetry 不在 Docker 构建中安装,但需要使用 Poetry 来创建 requirements.txt 文件。

在这两种情况下,我们需要以某种方式安装 Poetry——无论是在 Docker 镜像中还是在运行 Docker 构建命令的主机上。

将 Poetry 放入 Docker 中会稍微增加构建时间,而将其放在 Docker 之外则需要你在主机上安装 Poetry,并为构建过程添加额外步骤(从 poetry.lock 创建 requirements.txt 文件)。在 Docker 构建 CI 流水线 的上下文中,主机上的 Poetry 安装可以被缓存,构建过程通常会更快。这两种方法都有其优缺点,最佳方法将取决于你的具体需求和偏好。

为了本教程的目的,我将保持简单,使用上面描述的 venv 策略。所以这是新的 Dockfile,其中包含阶段(为了识别不同的阶段,通过 FROM 语句分隔,我将这些行用 粗体 高亮):

**FROM python:3.10-slim as base**

WORKDIR /app

# ignore 'Running pip as the root user...' warning
ENV PIP_ROOT_USER_ACTION=ignore

# update pip
RUN pip install --upgrade pip

**FROM base as builder**

# install poetry
ENV POETRY_VERSION=1.3.1
RUN pip install "poetry==$POETRY_VERSION"

# copy application
COPY ["pyproject.toml", "poetry.lock", "README.md", "./"]
COPY ["src/", "src/"]

# build wheel
RUN poetry build --format wheel

**FROM base as production**

# expose port
EXPOSE 80

# copy the wheel from the build stage
COPY --from=builder /app/dist/*.whl /app/

# install package
RUN pip install /app/*.whl

# copy entrypoint of the app
COPY ["main.py", "./"]

# command to run
CMD ["uvicorn", "main:app","--host", "0.0.0.0", "--port", "80", "--workers", "1"] 

这个 Dockerfile 定义了一个包含三个阶段的多阶段构建:basebuilderproduction

  1. base 阶段从 Python 3.10-slim 镜像开始,将工作目录设置为 /app。它还设置了一个环境变量以忽略关于以 root 用户身份运行 pip 的警告,并将 pip 更新到最新版本。

  2. builder 阶段从 base 阶段开始,并使用 pip 安装 Poetry。然后,它复制应用程序文件,并使用 Poetry 为应用程序构建一个 wheel 文件。

  3. production 阶段再次从 base 阶段开始,并暴露端口 80。它复制在 builder 阶段构建的 wheel 文件,并使用 pip 安装。它还复制应用程序的入口点,并将命令设置为使用 uvicorn 运行应用程序。

我们现在可以使用以下命令重新构建我们的 Docker 镜像:

**> docker build --file Dockerfile --tag project:latest --target production .**

我们可以使用 --target 标志指定我们希望构建的阶段。

文件大小现在减少了 ~70 Mb,总共为 197MB

**> docker images**

REPOSITORY   TAG       IMAGE ID       CREATED          SIZE
project      latest    f1be09c32a55   14 minutes ago   **197MB**

我们可以使用以下命令运行它:

**> docker run -p 9000:80 -it --rm project:latest**

API 将在浏览器中通过 localhost:9000 提供。

fastAPI 应用程序在 Docker 容器中运行 — 作者图片

容器注册表

容器注册表是用于存储和访问容器镜像的仓库或仓库集合。容器注册表可以支持基于容器的应用程序开发,通常作为 DevOps 过程的一部分。它们可以直接连接到像 Docker 和 Kubernetes 这样的容器编排平台。

最受欢迎的容器注册表是 Docker Hub。每个云服务提供商都有自己的注册表。Azure 的 ACR、AWS 的 ECR 以及更多。GitHub 有一个名为 GitHub Packages 的包注册解决方案。

由于我们到目前为止基本上都在 GitHub 上完成了所有操作,所以在本教程中我们将使用 GitHub Packages

GitHub Packages — 作者图片

GitHub 为普通用户提供了免费的层级。这允许我们为我们的容器使用最多 500 MB 的存储空间。这对我们的应用程序来说足够了。

GitHub Packages 定价和免费层 — 作者图片

Docker push

docker push 命令用于 上传 Docker 镜像到容器注册表。这允许你与其他人分享你的镜像或将其部署到不同的环境中。该命令需要你指定要推送的镜像名称和要推送到的注册表名称作为参数。在推送镜像之前,你需要登录到注册表。

这是将 Docker 镜像推送到容器注册表的步骤:

  1. 标记(重命名)你的镜像,使用注册表名称:docker tag project:latest <registry-name>/<project>:latest

  2. 登录到容器注册表:docker login <registry-url>

  3. 推送你的镜像到注册表:docker push <registry-name>/<project>:latest

我们将把镜像推送到GitHub Packages

GitHub Packages

GitHub Packages 仅支持使用个人访问令牌进行身份验证(2023 年 2 月)。但我们在第五部分创建了一个个人访问令牌(PAT),所以我们也可以在这里使用它。

我们需要登录到容器注册表

> CR_PAT="XYZ"
> echo $CR_PAT | docker login ghcr.io -u johschmidt42 --password-stdin

Login Succeeded

这是一个 shell 命令,使用管道将两个命令连接起来。管道是一个符号(|),它将一个命令的输出重定向到另一个命令的输入。在这种情况下,第一个命令是 echo $(CR_PAT),它将 CR_PAT 变量的值打印到标准输出。第二个命令是 docker login ghcr.io -u johschmidt42 --password-stdin,它使用 johschmidt42 作为用户名,并从标准输入读取密码来登录 ghcr.io。通过使用管道,echo 命令的输出成为 docker login 命令的输入,这意味着 CR_PAT 变量的值被用作登录密码。

让我们把它添加到我们的 Makefile 中

# Makefile

...

login: ## login to ghcr.io using a personal access token (PAT)
 @if [ -z "$(CR_PAT)" ]; then\
  echo "CR_PAT is not set";\
 else\
  echo $(CR_PAT) | docker login ghcr.io -u johschmidt42 --password-stdin;\
 fi

...

我们需要在bash中写一个小的 if-else 语句,以便这个目标登录需要我们首先设置CR_PAT

这使我们现在可以像这样登录:

> **make login CR_PAT="XYZ"**

对于任何对 bash 命令感到困惑的人,这里有一个解释:

这个 shell 命令使用 if-else 语句来检查一个条件,并根据不同的条件执行不同的操作。条件是 [ -z "$(CR_PAT)" ],这意味着“CR_PAT 变量为空吗?” -z 标志测试零长度。$(CR_PAT) 部分在括号内展开 CR_PAT 变量的值。如果条件为真,那么 then 后的操作会被执行,即 echo "CR_PAT is not set"。这将一条消息打印到标准输出。如果条件为假,则执行 else 后的操作,即 echo $(CR_PAT) | docker login ghcr.io -u johschmidt42 --password-stdin。每行末尾的 \ 意味着命令继续在下一行。末尾的 fi 标志着 if-else 语句的结束。

现在我们已登录,我们需要重命名 docker 文件,以便使用 docker tag 命令将其推送到远程注册表:

> **docker tag project:latest ghcr.io/johschmidt42/project:latest**
# Makefile

...

tag: ## tag docker image to ghcr.io/johschmidt42/project:latest
 @docker tag project:latest ghcr.io/johschmidt42/project:latest

...

我们可以通过以下命令查看有关我们的 docker 镜像的信息:

**> docker images**

REPOSITORY                     TAG       IMAGE ID       CREATED             SIZE
project                        latest    f1be09c32a55   About an hour ago   197MB
ghcr.io/johschmidt42/project   latest    f1be09c32a55   About an hour ago   197MB

如果我们现在尝试将镜像推送到注册表,它会失败

> **docker push ghcr.io/johschmidt42/project:latest** 
denied: permission_denied: The token provided does not match expected scopes.
# Makefile

...

push: tag ## docker push to container registry (ghcr.io)
 @docker push ghcr.io/johschmidt42/project:latest

...

这是因为我们的令牌没有预期的范围。消息没有告诉我们需要哪些范围(权限),但我们可以在文档中找到这些信息。

所以我们需要添加这些范围:

  • read:packages

  • delete:packages

GH_TOKEN — 作者提供的图片

现在我们看到它被推送到容器注册表:

> **make push**

1a3ba1c1448c: Pushed 
0ad139eaf32a: Pushing [========================================>          ]   43.3MB/54.08MB
0e0b5d4aea1e: Pushed 
a179cef7de6a: Pushing [==================================================>]  18.15MB
22f1e17dcfe4: Pushed 
805fe34ec92b: Pushing [==================================================>]  12.76MB
fa04dee82d1b: Pushed 
42d55226bf51: Pushing [==================================================>]  30.83MB
7d13900c8624: Pushed 
650abce4b096: Pushing [==============>                                    ]  22.72MB/80.51MB
latest: digest: sha256:57d409bb564f465541c2529e77ad05a02f09e2cc22b3c38a93967ce1b277f58a size: 2414 

在 GitHub 中,profile下的packages标签中现在有一个 docker 镜像:

你的个人资料 — 作者提供的图片

GitHub packages — 作者提供的图片

点击它,可以将包连接到我们的仓库

GitHub packages: 连接仓库 — 作者提供的图片

现在,这个 docker 镜像可以在仓库的首页找到 github.com/johschmidt42/python-project-johannes

GitHub Packages 首页 — 作者提供的图片

很好。我们已经创建了一个 Docker 镜像,将其推送到远程仓库,链接到我们当前的版本,现在每个人都可以通过运行 docker pull 命令来测试我们的应用:

**> docker pull ghcr.io/johschmidt42/python-project-johannes:v0.4.1**

CI/CD:

CI/CD 代表持续集成和持续部署。通过 Docker 镜像,CI/CD 可以自动化构建、测试和部署镜像的过程。在本教程中,我们将重点关注持续构建我们的 Docker 镜像并在有新版本时将其推送到远程容器注册表(CI)。然而,在本教程中,我们不会部署镜像(CD)(请关注未来的博客文章)。我们的 Docker 容器将在以下情况构建:

  • 提交到一个有打开的 PR 的分支

  • 提交到默认分支(main)

  • 创建一个新版本(这会将镜像推送到容器注册表)

第一个动作帮助我们及早发现错误。第二个动作使我们能够在 README.md 文件中创建并使用徽章。最后一个动作创建 Docker 镜像的新版本并将其推送到容器注册表。整体动作流程总结如下:

GitHub Actions 流程 — 作者提供的图片

让我们创建build管道:

这个 GitHub Actions 工作流构建了一个 Docker 镜像。当有pushpull requestmain分支时,或者当工作流被调用时,它会被触发。这个工作是命名为“Build”,包含两个步骤。第一步使用actions/checkout动作检出仓库。第二步通过运行make build命令来构建 Docker 镜像。就是这样。

工作流运行 — 作者提供的图片

我们还需要相应地更新 orchestrator.yml

当我们推送到 main 分支时,会触发 orchestrator。

orchestrator.yml — 图片由作者提供

为了在我们的 GitHub 存储库中发布每个 新版本 时构建新的 Docker 镜像,我们需要创建一个新的 GitHub Actions 工作流:

这是一个 GitHub Actions 工作流,当发布版本时,它会构建并推送 Docker 镜像到 GitHub Container Registry (ghcr.io)。名为“build_and_push”的作业有三个步骤。第一步使用 actions/checkout 操作检出存储库。第二步使用 docker/login-action 登录到 GitHub Container Registry。第三步使用 docker/build-push-action 构建并推送 Docker 镜像。

build_and_push — 图片由作者提供

请注意,为了使用 docker/login-action@v2 登录到 GitHub Container Registry,我们需要提供 GH_TOKEN 这个 PAT,正如我们在 Part V 中定义的那样。

以下是最后一步中使用的参数的简要说明 docker/build-push-action@4

  • context: . 指定构建上下文为当前目录。

  • push: true 指定在构建后将图像推送到注册表。

  • tags: ghcr.io/${{ github.repository }}:${{ github.ref_name }} 指定图像的标签。在这种情况下,它使用存储库的名称和触发工作流的分支或标签名称进行标记。

  • labels: 指定图像的标签。在这种情况下,它设置了图像的源、标题和版本标签。

  • target: production 指定在多阶段 Dockerfile 中构建的目标阶段。

  • github-token: ${{ secrets.GH_TOKEN }} 指定用于认证的 GitHub 令牌。

我们可以在 GitHub 上看到我们的新 Docker 镜像:

GitHub 上的图像 — 图片由作者提供

徽章:

对于这部分,我们将像以前一样向我们的 repo 添加一个徽章。这一次是针对 构建 管道的。当我们点击 build.yml 工作流运行时,可以检索徽章:

创建状态徽章 — 图片由作者提供

从 GitHub 上的工作流文件创建状态徽章

复制状态徽章 Markdown — 图片由作者提供

并选择主分支。徽章的 Markdown 可以复制并添加到 README.md 中:

我们的 GitHub 登陆页面现在看起来是这样的 ❤:

README.md 中的第五个徽章:构建 — 图片由作者提供

如果你想了解如何神奇地显示 main 中最后一次管道运行的当前状态,请查看 GitHub 上的提交 statuses API

这就结束了教程的核心部分!我们成功创建了一个Dockerfile,并使用它构建了一个Docker 镜像,使我们能够在Docker 容器中运行我们的应用程序。此外,我们实施了一个CI/CD 管道,自动构建我们的 Docker 镜像并将其推送到容器注册表。最后,我们在 README.md 文件中添加了一个徽章,向世界展示我们功能齐全的构建管道!

这就是最后一部分!这个教程是否帮助你在 GitHub 上构建了一个 Python 项目?有任何改进建议吗?让我知道你的想法!

[## 通过我的推荐链接加入 Medium - Johannes Schmidt

阅读 Johannes Schmidt 的每个故事(以及 Medium 上的其他成千上万名作者的故事)。你的会员费直接…

johschmidt42.medium.com](https://johschmidt42.medium.com/membership?source=post_page-----cbdbf28eff53--------------------------------)

奖励

清理:

以下是使用 Docker CLI 时的一些有用命令:

要停止所有容器并删除它们:

> docker stop $(docker ps -a -q) && docker rm $(docker ps -a -q)

要删除所有未使用的 Docker 镜像:

> docker rmi $(docker images --filter "dangling=true" -q --no-trunc)

Docker 镜像中的漏洞扫描

漏洞扫描是确保 Docker 镜像安全性的关键步骤。它帮助你识别并修复可能会危害应用程序或数据的潜在弱点或风险。其中一个可以帮助你的工具是 trivy

这个开源工具是一个简单快速的Docker 镜像漏洞扫描器,支持多种格式和来源。我将演示如何在本地使用它。理想情况下,你应该考虑创建一个 GitHub Actions 工作流,每当你构建 Docker 镜像时都运行!

我们首先应根据 文档 安装 trivy。在构建生产 Docker 镜像后

> docker build --file Dockerfile --tag project:latest --target production .

我们可以用以下命令扫描构建的镜像

> **trivy image project:latest --scanners vuln --format table --severity  CRITICAL,HIGH**

这将从数据库下载已知的最新漏洞并扫描镜像。输出将以表格形式--format table显示,仅包含严重性为 CRITICAL 或 HIGH 的发现--severity CRITICAL,HIGH

project:latest (debian 12.0)

Total: 27 (HIGH: 27, CRITICAL: 0)

┌────────────────┬────────────────┬──────────┬───────────────────┬───────────────┬──────────────────────────────────────────────────────────────┐
│    Library     │ Vulnerability  │ Severity │ Installed Version │ Fixed Version │                            Title                             │
├────────────────┼────────────────┼──────────┼───────────────────┼───────────────┼──────────────────────────────────────────────────────────────┤
│ linux-libc-dev │ CVE-2013-7445  │ **HIGH**     │ 6.1.27-1          │               │ kernel: memory exhaustion via crafted Graphics Execution     │
│                │                │          │                   │               │ Manager (GEM) objects                                        │
│                │                │          │                   │               │ https://avd.aquasec.com/nvd/cve-2013-7445                    │
│                ├────────────────┤          │                   ├───────────────┼──────────────────────────────────────────────────────────────┤
│                │ CVE-2019-19449 │          │                   │               │ kernel: mounting a crafted f2fs filesystem image can lead to │
│                │                │          │                   │               │ slab-out-of-bounds read...                                   │
│                │                │          │                   │               │ https://avd.aquasec.com/nvd/cve-2019-19449                   │
│                ├────────────────┤          │                   ├───────────────┼──────────────────────────────────────────────────────────────┤
│                │ CVE-2019-19814 │          │                   │               │ kernel: out-of-bounds write in __remove_dirty_segment in     │
│                │                │          │                   │               │ fs/f2fs/segment.c                                            │
│                │                │          │                   │               │ https://avd.aquasec.com/nvd/cve-2019-19814                   │
│                ├────────────────┤          │                   ├───────────────┼──────────────────────────────────────────────────────────────┤
│                │ CVE-2021-3847  │          │                   │               │ low-privileged user privileges escalation                    │
│                │                │          │                   │               │ https://avd.aquasec.com/nvd/cve-2021-3847                    │
│                ├────────────────┤          │                   ├───────────────┼──────────────────────────────────────────────────────────────┤
│                │ CVE-2021-3864  │          │                   │               │ descendant's dumpable setting with certain SUID binaries     │
│                │                │          │                   │               │ https://avd.aquasec.com/nvd/cve-2021-3864                    │
│                ├────────────────┤          │                   ├───────────────┼──────────────────────────────────────────────────────────────┤
│                │ CVE-2023-1194  │          │                   │               │ use-after-free in parse_lease_state()                        │
│                │                │          │                   │               │ https://avd.aquasec.com/nvd/cve-2023-1194                    │
│                ├────────────────┤          │                   ├───────────────┼──────────────────────────────────────────────────────────────┤
│                │ CVE-2023-2124  │          │                   │ 6.1.37-1      │ OOB access in the Linux kernel's XFS subsystem               │
│                │                │          │                   │               │ https://avd.aquasec.com/nvd/cve-2023-2124                    │
│                ├────────────────┤          │                   │               ├──────────────────────────────────────────────────────────────┤
│                │ CVE-2023-2156  │          │                   │               │ IPv6 RPL protocol reachable assertion leads to DoS           │
│                │                │          │                   │               │ https://avd.aquasec.com/nvd/cve-2023-2156                    │
│                ├────────────────┤          │                   ├───────────────┼──────────────────────────────────────────────────────────────┤
│                │ CVE-2023-2176  │          │                   │               │ Slab-out-of-bound read in compare_netdev_and_ip              │
│                │                │          │                   │               │ https://avd.aquasec.com/nvd/cve-2023-2176                    │
│                ├────────────────┤          │                   ├───────────────┼──────────────────────────────────────────────────────────────┤
│                │ CVE-2023-3090  │          │                   │ 6.1.37-1      │ out-of-bounds write caused by unclear skb->cb                │
│                │                │          │                   │               │ https://avd.aquasec.com/nvd/cve-2023-3090                    │
│                ├────────────────┤          │                   ├───────────────┼──────────────────────────────────────────────────────────────┤
│                │ CVE-2023-31248 │          │                   │               │ use-after-free in nft_chain_lookup_byid()                    │
│                │                │          │                   │               │ https://avd.aquasec.com/nvd/cve-2023-31248                   │
│                ├────────────────┤          │                   ├───────────────┼──────────────────────────────────────────────────────────────┤
│                │ CVE-2023-32247 │          │                   │ 6.1.37-1      │ session setup memory exhaustion denial-of-service            │
│                │                │          │                   │               │ vulnerability                                                │
│                │                │          │                   │               │ https://avd.aquasec.com/nvd/cve-2023-32247                   │
│                ├────────────────┤          │                   │               ├──────────────────────────────────────────────────────────────┤
│                │ CVE-2023-32248 │          │                   │               │ tree connection NULL pointer dereference denial-of-service   │
│                │                │          │                   │               │ vulnerability                                                │
│                │                │          │                   │               │ https://avd.aquasec.com/nvd/cve-2023-32248                   │
│                ├────────────────┤          │                   │               ├──────────────────────────────────────────────────────────────┤
│                │ CVE-2023-32250 │          │                   │               │ session race condition remote code execution vulnerability   │
│                │                │          │                   │               │ https://avd.aquasec.com/nvd/cve-2023-32250                   │
│                ├────────────────┤          │                   │               ├──────────────────────────────────────────────────────────────┤
│                │ CVE-2023-32252 │          │                   │               │ session NULL pointer dereference denial-of-service           │
│                │                │          │                   │               │ vulnerability                                                │
│                │                │          │                   │               │ https://avd.aquasec.com/nvd/cve-2023-32252                   │
│                ├────────────────┤          │                   │               ├──────────────────────────────────────────────────────────────┤
│                │ CVE-2023-32254 │          │                   │               │ tree connection race condition remote code execution         │
│                │                │          │                   │               │ vulnerability                                                │
│                │                │          │                   │               │ https://avd.aquasec.com/nvd/cve-2023-32254                   │
│                ├────────────────┤          │                   │               ├──────────────────────────────────────────────────────────────┤
│                │ CVE-2023-32257 │          │                   │               │ session race condition remote code execution vulnerability   │
│                │                │          │                   │               │ https://avd.aquasec.com/nvd/cve-2023-32257                   │
│                ├────────────────┤          │                   │               ├──────────────────────────────────────────────────────────────┤
│                │ CVE-2023-32258 │          │                   │               │ session race condition remote code execution vulnerability   │
│                │                │          │                   │               │ https://avd.aquasec.com/nvd/cve-2023-32258                   │
│                ├────────────────┤          │                   │               ├──────────────────────────────────────────────────────────────┤
│                │ CVE-2023-3268  │          │                   │               │ out-of-bounds access in relay_file_read                      │
│                │                │          │                   │               │ https://avd.aquasec.com/nvd/cve-2023-3268                    │
│                ├────────────────┤          │                   │               ├──────────────────────────────────────────────────────────────┤
│                │ CVE-2023-3269  │          │                   │               │ distros-[DirtyVMA] Privilege escalation via                  │
│                │                │          │                   │               │ non-RCU-protected VMA traversal                              │
│                │                │          │                   │               │ https://avd.aquasec.com/nvd/cve-2023-3269                    │
│                ├────────────────┤          │                   │               ├──────────────────────────────────────────────────────────────┤
│                │ CVE-2023-3390  │          │                   │               │ UAF in nftables when nft_set_lookup_global triggered after   │
│                │                │          │                   │               │ handling named and anonymous sets...                         │
│                │                │          │                   │               │ https://avd.aquasec.com/nvd/cve-2023-3390                    │
│                ├────────────────┤          │                   ├───────────────┼──────────────────────────────────────────────────────────────┤
│                │ CVE-2023-3397  │          │                   │               │ slab-use-after-free Write in txEnd due to race condition     │
│                │                │          │                   │               │ https://avd.aquasec.com/nvd/cve-2023-3397                    │
│                ├────────────────┤          │                   ├───────────────┼──────────────────────────────────────────────────────────────┤
│                │ CVE-2023-35001 │          │                   │               │ stack-out-of-bounds-read in nft_byteorder_eval()             │
│                │                │          │                   │               │ https://avd.aquasec.com/nvd/cve-2023-35001                   │
│                ├────────────────┤          │                   ├───────────────┼──────────────────────────────────────────────────────────────┤
│                │ CVE-2023-35788 │          │                   │ 6.1.37-1      │ out-of-bounds write in fl_set_geneve_opt()                   │
│                │                │          │                   │               │ https://avd.aquasec.com/nvd/cve-2023-35788                   │
│                ├────────────────┤          │                   ├───────────────┼──────────────────────────────────────────────────────────────┤
│                │ CVE-2023-35827 │          │                   │               │ race condition leading to use-after-free in ravb_remove()    │
│                │                │          │                   │               │ https://avd.aquasec.com/nvd/cve-2023-35827                   │
│                ├────────────────┤          │                   ├───────────────┼──────────────────────────────────────────────────────────────┤
│                │ CVE-2023-3640  │          │                   │               │ a per-cpu entry area leak was identified through the         │
│                │                │          │                   │               │ init_cea_offsets function when...                            │
│                │                │          │                   │               │ https://avd.aquasec.com/nvd/cve-2023-3640                    │
├────────────────┼────────────────┤          ├───────────────────┼───────────────┼──────────────────────────────────────────────────────────────┤
│ perl-base      │ CVE-2023-31484 │          │ 5.36.0-7          │               │ CPAN.pm before 2.35 does not verify TLS certificates when    │
│                │                │          │                   │               │ downloading distributions over...                            │
│                │                │          │                   │               │ https://avd.aquasec.com/nvd/cve-2023-31484                   │
└────────────────┴────────────────┴──────────┴───────────────────┴───────────────┴──────────────────────────────────────────────────────────────┘

存在2 个操作系统库,其严重性为 HIGH。这两个库都没有提供可以升级到的版本(参见 Fixed Version 列),以修复我们 Docker 镜像中的漏洞。因此,我们将按以下方式处理它们:

linux-libc-dev

这是一个运行应用程序时不需要的包。因此,最好还是卸载它!

perl-base

这个操作系统包提供了 Perl 解释器,并且是我们应用程序使用的其他库所必需的。这意味着我们不能卸载它,也不能修复它。因此,我们必须接受风险。接受已知漏洞应由管理层确认和批准。然后,我们可以将漏洞,例如 CVE-2023–31484,添加到 .trivyignore 文件中,再次运行扫描程序。

这里是变更内容:

# Dockerfile
...

FROM base as production

# expose port
EXPOSE 80

# copy the wheel from the build stage
COPY --from=builder /app/dist/*.whl /app/

# install package
RUN pip install /app/*.whl

# copy entrypoint of the app
COPY ["main.py", "./"]

**# Remove linux-libc-dev (CVE-2023-31484)
RUN apt-get remove -y --allow-remove-essential linux-libc-dev**

# command to run
CMD ["uvicorn", "main:app","--host", "0.0.0.0", "--port", "80", "--workers", "1"]
# .trivyignore

# vulnerabilities to be ignored by trivy are added here
CVE-2023-31484

当我们再次运行命令(这次包含 .trivyignore 文件)时:

> trivy image project:latest --scanners vuln --format table --severity  CRITICAL,HIGH **--ignorefile .trivyignore**

不再报告严重性为 HIGH 或 CRITICAL 的漏洞:

project:latest (debian 12.0)

Total: 0 (HIGH: 0, CRITICAL: 0)

干杯!

使用 Scikit-Learn 的 SGDRegressor:你需要知道的未授课程

原文:towardsdatascience.com/sgdregressor-with-scikit-learn-untaught-lessons-you-need-to-know-cf2430439689

通过令人困惑的名称揭示隐藏的算法关系

Angela and Kezhan ShiTowards Data Science Angela and Kezhan Shi

·发表于 Towards Data Science ·阅读时间 13 分钟·2023 年 3 月 8 日

--

在机器学习领域,线性模型是一种基本技术,广泛用于根据输入数据预测数值。Scikit-learn 中的 SGDRegressor 估算器是一个强大的工具,可以让机器学习从业者快速高效地进行线性回归。

然而,SGDRegressor 的名称对于初学者来说可能有些混淆。本文将解释它的工作原理,并探讨为何这个名称对于刚开始学习机器学习的初学者来说可能会产生误导。

此外,我们将深入探讨 SGDRegressor 中实际上隐藏了多个模型的观点,每个模型都有其特定的参数和超参数。这将引发关于机器学习中模型定义的有趣问题。例如,岭回归或 SVR 是一个独立的模型,还是一个调优过的线性模型?

在你读完这篇文章时,你将更好地理解 SGDRegressor 的内部工作原理,并对线性模型在机器学习中的复杂性以及其实的简单性有新的认识。

1. SGDRegressor 的通常教学内容

SGDRegressor 是 Scikit-Learn 中的一种机器学习算法,它实现了随机梯度下降(SGD)来解决回归问题。由于其处理高维数据集的能力和快速的训练时间,它是大规模回归任务的热门选择。

SGDRegressor 通过使用训练数据的小随机子集而不是整个数据集来迭代地更新模型权重,这使得它在处理大数据集时计算上更高效。它还包括几个可以调整的超参数,以优化性能,包括学习率、惩罚或正则化项和迭代次数。

1.1 线性回归

SGDRegressor 是一个线性模型,使用线性函数来预测目标变量。线性函数的形式为:

y = w[0] + w[1] * x[1] + … + w[p] * x[p]

其中 x[1] 到 x[p] 是输入特征,w[1] 到 w[p] 是线性模型的系数,w[0] 或 b 是截距项。SGDRegressor 算法的目标是找到 w 和 b 的值,使得在预测值和目标变量的实际值之间定义的损失函数最小化。

1.2 随机梯度下降 vs. 梯度下降

SGDRegressor 算法使用随机梯度下降进行优化。随机梯度下降是一种迭代优化算法,它在小批量数据中更新模型参数。该算法使用相对于参数的代价函数的梯度来更新参数。

虽然 SGD 和 GD 都是机器学习中广泛使用的优化算法,但它们的效果取决于具体的问题。SGD 通常对大数据集和非凸问题更快且效果更好,而 GD 对小数据集和凸问题更可靠。

1.3 在 scikit-learn 中使用 SGDRegressor

使用 SGDRegressor 在 scikit-learn 中非常简单。首先,我们需要从 scikit-learn 的 linear_model 模块中导入 SGDRegressor 类。然后,我们可以创建一个 SGDRegressor 类的实例,并将模型拟合到我们的训练数据上。

以下是如何在 scikit-learn 中使用 SGDRegressor 的示例:

from sklearn.linear_model import SGDRegressor
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

# Load the Boston housing dataset
X, y = load_boston(return_X_y=True)
# Split the dataset into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Create an instance of SGDRegressor
sgd_reg = SGDRegressor(max_iter=1000, tol=1e-3, penalty=None, eta0=0.1)
# Fit the model to the training data
sgd_reg.fit(X_train, y_train)
# Make predictions on the test data
y_pred = sgd_reg.predict(X_test)
# Calculate the mean squared error of the predictions
mse = mean_squared_error(y_test, y_pred)
print("Mean squared error:", mse)

在这个示例中,我们加载了波士顿房价数据集,并使用 train_test_split 函数将其拆分为训练集和测试集。然后,我们创建了一个 SGDRegressor 实例,并使用 fit 方法将模型拟合到训练数据上。最后,我们使用 predict 方法对测试数据进行预测,并使用 mean_squared_error 函数计算预测的均方误差。

2. 解开 SGDRegressor 对初学者的困惑名字

在本节中,我们将讨论为什么 SGDRegressor 这个名字对机器学习初学者来说可能会令人困惑。尽管领域内的专家可能熟悉这个名字及其相关算法,但对于新手来说,可能无法立即理解它的功能或工作原理。

然而,我们还将解释,对于专家来说,这个名字不一定是困惑的来源。这是因为他们已经熟悉了该算法及其目的,并且能够很容易地将其与其他类似的算法区分开来。

2.1 为什么它会令人困惑

在我的文章《三步学会机器学习:如何高效学习》中,我解释了如何通过将学习过程分解为三个不同的步骤:模型算法、拟合算法和调优算法,来提高你对机器学习算法的学习。尽管这种方法简单,但在实践中可能难以应用。SGDRegressor 算法就是一个典型的例子。

我认为,选择一个准确反映机器学习算法特征的合适名称对于理解和有效使用该算法至关重要。然而,这并不总是如此,例如混淆的“SGDRegressor”算法。

在 scikit-learn 中,根据机器学习模型是用于回归还是分类任务,有一个惯例是将“regressor”或“classifier”添加到模型名称中。这一惯例在许多示例中都很明显,例如 DecisionTreeClassifier 与 DecisionTreeRegressor,KNeighborsClassifier 与 KNeighborsRegressor 等。

然而,“SGDRegressor”这个名称的问题在于,如果我们将这一命名惯例应用于它,它会暗示“SGD”是一个机器学习模型,而实际上它是一个拟合算法。这对于可能对机器学习算法的不同组件没有清晰理解的初学者来说尤其困惑。

2.2 如何解释这一命名

对于专家来说,“SGDRegressor”的命名惯例可能看起来可以接受。使用“SGD”暗示该模型基于数学函数或参数模型,而非距离或树状模型。因此,“SGD”暗示了尽管 SGD 本身是一个拟合算法,但所使用的隐藏模型。

虽然理论上 SGD 可以用于线性和非线性模型,例如神经网络,但实际上,这个估计器仅实现了线性模型。因此,你会说这个估计器的一个更精确和简洁的名称可能是“LinearSGDRegressor”!是的,没错!但你是否也注意到“SGDRegressor”位于“linear_model”模块中,该模块定义上只实现了线性模型?!

最终,虽然 SGD 指的是具体的拟合算法,但由于此算法仅用于基于数学函数的模型,因此 SGDRegressor 这个名称似乎是合适的。此外,鉴于 SGDRegressor 位于“linear_model”模块中,这表明所使用的模型是线性模型。

3. SGDRegressor 中的隐藏模型

SGDRegressor 中可用的参数让我们能够选择不同的损失函数(squared_error、huber、epsilon_insensitive 或 squared_epsilon_insensitive)和惩罚函数(l1、l2、elasticnet 或 none)的组合。这些组合中的一些对应于传统的统计模型。

有趣的是,历史上,统计学家根据不同的假设和目标开发和构建了不同的模型。相比之下,机器学习提供了一个更统一的框架,其中线性模型保持不变,但损失函数和惩罚可以更改,以实现不同的目标。

3.1 损失函数

损失函数 squared_error、huber、epsilon_insensitive 和 squared_epsilon_insensitive 在对模型及其性能的影响上有所不同。

  • squared_error 损失函数,也称为均方误差(MSE),对较大的错误的惩罚程度比对较小的错误更重,使其对异常值敏感。这个损失函数在线性回归中常用。

  • huber 损失函数比 squared_error 损失函数对异常值的敏感性更低,因为它在较大的残差下从二次误差过渡到线性误差。这使得它在处理有些异常值的数据集时是一个不错的选择。

  • epsilon_insensitive 损失函数用于目标值预计在一定范围内的回归问题。它忽略小于某一阈值(称为 epsilon)的错误,并对较大的错误进行线性惩罚。这个损失函数在支持向量回归中常用。

  • squared_epsilon_insensitive 损失函数类似于 epsilon_insensitive 损失函数,但它通过对较大的错误进行平方惩罚来加重对较大错误的惩罚。当较大的错误需要比线性更多地惩罚时,这个损失函数可以很有用,但较小的错误可以被忽略。

总的来说,损失函数的选择会影响模型处理异常值的能力以及对错误的敏感性。为确保最佳性能,选择适合特定问题的损失函数是非常重要的。

这里是一张总结所有损失函数及其数学表达式的图片。也提供了 Python 实现,允许我们可视化和比较它们的行为。如果你想访问 Python 代码,你可以通过以下链接在 Ko-fi 上支持我:ko-fi.com/s/4cc6555852

SGDRegressor 中的损失函数 — 图片由作者提供

需要注意的是,尽管 SGDRegressor 并未明确提供 MAE 损失,但可以通过将 epsilon 不敏感损失函数的超参数 epsilon 设置为零来实现 MAE 损失。这是因为,当 epsilon 为零时,epsilon 不敏感损失的数学表达式等同于 MAE 损失。

最终,尽管损失函数的种类繁多,我们可以观察到它们的构造基于两个基本概念:基本损失是二次的或绝对的,以及引入 epsilon-tube 概念,它创建了两个不同的区域——中央区域和边缘区域——具有不同类型的损失函数。

通过采用这种方法,您可以潜在地创建自己定制的损失函数!

3.2 惩罚项

使用 SGDRegressor 时,我们还可以指定除了所选择的损失函数外还使用的惩罚项。可用的惩罚项包括 L1、L2、Elastic Net 和 None。

L1 惩罚将系数的绝对值添加到损失函数中,从而导致一些系数被设置为零的稀疏模型。该惩罚对特征选择和减少过拟合是有用的。

L2 惩罚将系数的平方添加到损失函数中,从而导致较小但非零的系数。该惩罚还可以帮助减少过拟合并提高模型的泛化能力。

Elastic Net 惩罚结合了 L1 和 L2 惩罚,允许既有稀疏性又有非零系数。它有两个超参数:alpha 控制 L1 和 L2 惩罚之间的权重,l1_ratio 控制 L1 和 L2 惩罚之间的平衡。

最后,None 意味着不使用惩罚,模型仅用所选择的损失函数进行拟合。

选择合适的惩罚项取决于具体问题和数据的性质。一般来说,L1 和 Elastic Net 惩罚有助于特征选择和稀疏模型,而 L2 惩罚有助于泛化和避免过拟合。

这里是来自scikit learn 文档的有趣可视化

来自 scikit learn 的惩罚项

3.3 隐藏模型的展示

SGDRegressor 提供了在指定不同组合的损失和惩罚参数方面的灵活性。在 SGDRegressor 中,某些损失和惩罚参数的组合对应于具有特定名称的知名模型。以下图像提供了这些选项的概述,包括相应的模型名称,我们将详细探讨其中的一些。

SGDRegressor 中的损失函数和惩罚项及其对应模型名称 — 图像由作者提供

平方误差

在 SGDRegressor 中,最简单的组合是 squared_error 无惩罚,这对应于 Scikit-learn 的 LinearRegression 估计器中的经典线性回归模型。然而,这种命名方式可能会产生误导。虽然所有这些模型都是回归模型且为线性模型,但将它们称为“线性回归”可能会造成混淆。为了避免这种情况,最好使用“线性模型”一词,而将“线性回归”一词专门保留用于 OLS 回归。

Lasso 是一种具有 L1 正则化的线性 (OLS) 回归模型,鼓励系数估计的稀疏性。

另一方面,岭回归在线性 (OLS) 回归模型中添加了 L2 惩罚,以帮助减轻多重共线性的影响。

ElasticNet 结合了 L1 和 L2 正则化的惩罚,以在 Lasso 的稀疏性和 Ridge 的稳定性之间取得平衡。

使用平方误差时,我们可以轻松识别每个惩罚对应的具体模型名称。然而,对于其他损失函数,并不总是有特定的名称与每个惩罚关联。在我看来,研究人员历史上专注于寻找正则化或线性回归系数的惩罚效果,导致每种惩罚类型都有特定名称。对于其他损失函数,添加惩罚项似乎不再新颖。

Huber 损失

Huber 损失函数是一种对离群值不那么敏感的替代损失函数,适用于具有显著离群值的数据集。Huber 损失函数是平方损失函数和绝对损失函数的组合。对于小误差,它的行为像平方损失函数,而对于大误差,它的行为像绝对损失函数。这使得它比平方损失函数对离群值更具鲁棒性,同时对小误差仍提供良好的性能。

Epsilon 不敏感损失

Epsilon-不敏感损失是另一种常用于线性模型的损失函数。epsilon 参数用于定义在此范围内的误差被视为零。此损失函数对具有噪声输出变量的数据集非常有用,因为它可以帮助减少输出变量中小波动的影响。

实际上,Epsilon 不敏感损失和 L2 正则化的组合也被称为 SVR(支持向量回归)。使用 epsilon 不敏感管道概念使得添加惩罚成为强制性,因为没有它可能会有无限多的解决方案。术语“支持向量”之所以被使用,是因为正如 L1 正则化项对系数会导致某些系数变为零一样,应用于数据集的 L1 损失(绝对损失)将导致某些数据点不用于计算系数,只留下其余的数据点,这些数据点被称为支持向量。

值得注意的是,Huber 回归通常被描绘为不那么敏感,但它与 SVR 共享这一特性,因为两者都对较大的值使用绝对值。虽然“epsilon 不敏感”一词强调误差为零的中心区域,但较大值的绝对误差对最终模型也可以有显著影响。

要了解有关 SVR 和 Epsilon 不敏感损失及其在 Scikit-learn 中的应用,可以阅读这篇文章:使用 Scikit-learn 理解 SVR 和 Epsilon 不敏感损失

平方 epsilon 不敏感损失

我们可能以前从未听说过这个:平方 epsilon 不敏感损失! 正如其名,它基于epsilon 不敏感损失,但使用的是平方误差而非绝对误差。问题是,为什么使用这个特定的损失函数?嗯,为什么不呢。

答案是,在机器学习中,“没有免费的午餐”理论表明,一个模型不能在所有数据集上表现良好,因此需要测试不同的损失函数。在某些情况下,平方 epsilon 不敏感损失可能是最佳选择。

3.4 一个可调模型还是不同模型?

为了理解这些模型的不同名称,我们可以从两个不同的角度来分析:统计学角度和机器学习角度。

在统计学领域,多年来已经开发出各种模型来解决不同类型的问题。因此,存在许多具有不同假设、约束和特性的模型。另一方面,机器学习框架相对简单,因为它围绕线性模型展开,可以通过更改损失函数和应用惩罚来避免过拟合,从而轻松修改。

这种简单性和灵活性使得机器学习从业者更容易进行实验并适应不同的问题设置,从而开发出适用于各种应用的新型有效模型。

在我之前的文章“三步学习机器学习:如何高效学习它”中,我强调了区分算法的三个部分——模型算法、模型拟合算法和模型调优算法的重要性。这种方法有助于简化对机器学习算法的理解。

至于 SGDRegressor,以下是三个步骤:

  • 模型:我们可以将 LASSO、Ridge、弹性网、SVM 和 Huber 回归视为一个整体模型,即线性模型表示为 y = wX + b。

  • 拟合:使用的拟合算法是随机梯度下降(SGD)。

  • 调优:可以调节的超参数包括损失和惩罚等。

尽管 sci-kit learn 在同一个 linear_model 模块中有多个独立的模型如 LinearRegression、LASSO 和 Ridge,但是否这些模型实际上是同一个模型并不重要。重点应放在理解它们的内部功能上,因为名字可能会产生误导。

在总结本节之前,我想到一个问题:估计器 LinearRegression 是否真的属于机器学习模型,因为它不可调且没有任何需要调整的超参数?

结论和主要要点

总之,scikit-learn 中的 SGDRegressor 提供了一个灵活且强大的线性回归工具。它的多种损失函数和惩罚选项为用户提供了许多自定义模型以满足特定需求的选择。此外,使用 SGD 拟合非凸函数的能力相较于标准梯度下降是一个显著的优势。需要注意的是,损失和惩罚参数应被视为需要调整的超参数。通过应用模型、拟合和调整的学习框架,数据科学家可以利用 SGDRegressor 在他们的线性回归任务中实现最佳结果。

SHAP:在 Python 中解释任何机器学习模型

原文:towardsdatascience.com/shap-explain-any-machine-learning-model-in-python-72f0bea35f7c

照片由 Priscilla Du Preez 提供,来源于 Unsplash

您的 SHAP、TreeSHAP 和 DeepSHAP 综合指南

Louis ChanTowards Data Science Louis Chan

·发表于 Towards Data Science ·阅读时间 13 分钟·2023 年 1 月 11 日

--

动机

故事时间!

想象一下你训练了一个机器学习模型来预测抵押贷款申请者的违约风险。一切都很好,性能也很出色。但模型是如何工作的?模型是如何得出预测值的?

我们站在那里说模型考虑了几个变量,而这些多维关系和模式复杂到用简单的语言无法解释。

这就是模型可解释性可以拯救局面的地方。在可以剖析机器学习模型的算法中,SHAP 是该领域中较为中立的算法之一。在这篇博客中,我们将深入探讨以下内容:

  • 什么是 Shapley 值?

  • 如何计算 Shapley 值?

  • 如何在 Python 中使用它?

  • SHAP 如何支持局部和全局可解释性?

  • SHAP 库中有哪些可用的可视化?

  • SHAP 的常见变体如何工作?— TreeSHAP 和 DeepSHAP

  • LIME 与 SHAP 相比如何?

Shapley 值

让我们玩个游戏

当一支由十一名球员组成的球队赢得世界杯时,谁是最有价值的球员?Shapley 值是一种分解算法,客观地将最终结果分配给一组因素。在解释机器学习模型时,Shapley 值可以理解为单个输入特征对模型预测值的贡献程度。

快速示例 — Shapley 值是如何工作的?

为了简单起见,假设我们有三名进攻球员,每名球员有不同的预期进球数。我们还知道这三名球员并不总是相互配合良好,这意味着根据这三名球员的组合,预期进球数可能会有所不同:

作者提供的图片

作为基准,我们不使用这三名球员,即特征数 f = 0,团队的预期进球数将是 0.5。每一个箭头向下的矩阵表示包含一个新特征(或在我们情况下是一个新球员)时可能的逐步增量。

遵循逐步扩展玩家集的思路,这意味着我们可以计算每一个箭头的边际变化。例如,当我们从不使用任何玩家(用空集符号 ∅ 表示)移动到仅使用玩家 1 时,边际变化是:

作者提供的图片

要获得玩家 1 在所有三名玩家中的总体贡献,我们需要对每一个可能出现玩家 1 边际贡献的情景重复相同的计算:

作者提供的图片

通过所有边际变化,我们可以使用以下公式计算它们的权重:

作者提供的图片

或者,简单来说:这只是指向同一行的所有边的数量的倒数。这意味着:

作者提供的图片

有了这些,我们现在可以计算玩家 1 的 SHAP 值,以获得预期进球数:

作者提供的图片

对另外两名玩家进行相同的操作,我们将得到:

  • 玩家 1 的 SHAP = -0.1133

  • 玩家 2 的 SHAP = -0.0233

  • 玩家 3 的 SHAP = +0.4666

如果我是主教练,我在这种情况下只会使用玩家 3。

这与另一种操作符 Choquet Integral 非常相似,对于那些数学更精通的朋友。

计算复杂度

以上述仅有 3 个特征的例子为例,我们需要考虑 8 个不同的模型,每个模型有不同的输入特征集,以全面解释所有特征。事实上,对于一个完整的N特征集,总子集的数量将是2^N。因此,在使用 SHAP 解释训练有大量且更重要的是宽数据集的机器学习模型时,我们需要注意预期的运行时间。

在接下来的章节中,我们将首先深入探讨如何在 Python 中使用 SHAP,然后将大部分注意力转向 SHAP 的不同变体,这些变体旨在通过近似技术或针对模型拓扑特定的技术来应对 SHAP 的复杂性。

Pascal 三角形 — 图片来源于 维基百科

Python 中的 SHAP

接下来,让我们探讨如何在 Python 中使用 SHAP。

SHAP (SHapley Additive exPlanations) 是一个兼容大多数机器学习模型拓扑的 Python 库。安装非常简单,只需 pip install shap

SHAP 提供了两种解释机器学习模型的方法——全局解释和本地解释。

使用 SHAP 进行本地可解释性

本地可解释性试图解释特定预测背后的驱动因素。在 SHAP 中,个体 Shapley 值就是用来做这个的,如早期部分的快速示例所示。

在 SHAP 的工具集中,有两种可视化方法用于解释个体预测:瀑布图和力图。瀑布图让你更好地理解逐步推导预测结果的过程,而力图旨在提供特征对预测结果偏差的相对贡献强度。

注意: 两种可视化都包括了一个整体期望预测值(或基准值)。这可以理解为训练集上模型输出的平均值。

瀑布图

# Code snippet from SHAP github page
import xgboost
import shap

# train an XGBoost model
X, y = shap.datasets.boston()
model = xgboost.XGBRegressor().fit(X, y)

# explain the model's predictions using SHAP
# (same syntax works for LightGBM, CatBoost, scikit-learn, transformers, Spark, etc.)
explainer = shap.Explainer(model)
shap_values = explainer(X)

# visualize the first prediction's explanation
shap.plots.waterfall(shap_values[0])

图片来自 SHAP GitHub 页面(MIT 许可证)

  • 在 y 轴上,你可以找到特征的名称和值。

  • 在 x 轴上,你可以找到基准值 E[f(X)] = 22.533,这表示训练集上的平均预测值。

  • 图中红色条形表示特征对预测值的正贡献。

  • 图中蓝色条形表示特征对预测值的负贡献。

  • 条形上的标签表示归因于参数的模型基准预测值的偏差。例如,AGE = 65.2 对预测值的偏差从基准值 22.533 上贡献了 +0.19。

  • 条形按其对预测值的绝对重要性降序排列。

力图

# Code snippet from SHAP github page
# visualize the first prediction's explanation with a force plot
shap.plots.force(shap_values[0])

图片来自 SHAP GitHub 页面(MIT 许可证)

  • 在 x 轴上,你可以找到基准值。这表示训练集上平均预测值的大致位置。

  • 在 x 轴上,你还可以找到用粗体数字标记的模型输出。这表示该记录的预测值。

  • 在图表底部,你可以找到特征的名称和值,标记为红色或蓝色。

  • 所有在模型输出左侧的红色条形是对预测偏离基准值有正面贡献的特征。特征的名称在条形的底部。条形的长度表示特征的贡献。

  • 模型输出右侧的所有蓝色条形图表示对预测偏离基准值产生负面贡献的特征。特征的名称位于条形图底部。条形图的长度表示特征的贡献。

SHAP 的全球可解释性

全球可解释性可以理解为在整个数据集中理解每个特征的整体重要性,并提供对数据和潜在模式的一般了解。由于分解个体预测贡献和在数据中聚合的模糊性,尝试全球可解释性的方法不止一种。示例包括信息增益、汇总权重、基于置换的特征重要性和 Shapley 值。SHAP 当然专注于最后一个。

SHAP 提供了一种可视化方法,我们可以查看特征在数据集中的平均 Shapley 值。与其他使用统计上更复杂解释来提供重要性度量的机制不同,SHAP 的全球可解释性通过让你能够说,平均而言,特征关系使得“Class 1”数据记录的预测值比“Class 0”数据记录高约 1.0,从而提供了一个立即可理解的影响。

图像来自SHAP GitHub 页面(MIT 许可证)

SHAP 的全球可解释性功能允许我们排查或调查模型偏差。以上面的图像为例,年龄通常是一个非常重要的特征。这是否可能表明模型对特定年龄组存在不必要的偏见?此外,一个非常重要的特征是否可能是潜在的数据泄露?所有这些问题都使我们在部署更负责任且强健的机器学习模型之前,能够改进模型。

注意: 如果你有兴趣了解更多关于负责任的人工智能,我还写了一篇关于如何通过 5 个简单步骤来实现这一目标的文章。

## 解锁负责任的人工智能:确保伦理系统的 5 个步骤

负责任的人工智能系统的 5 个步骤

pub.towardsai.net

SHAP 支持的另一种可视化是局部可解释性部分的力图堆叠版本。通过堆叠力图,我们可以可视化模型与不同输入值的特征之间的交互。这为我们提供了基于 Shapley 值的聚类视图,并提供了模型如何看待数据的视角。这对修正和验证假设以及基础业务逻辑非常有用。在分析所有 Shapley 值后,你可能还会发现数据分割的新方法!

# Code snippet from SHAP github page
# visualize all the training set predictions
shap.plots.force(shap_values)

图片来源于SHAP GitHub 页面(MIT 许可证)

SHAP 的变体

TreeSHAP

  • 优点: 高效且准确的算法,用于计算基于树模型的 Shapley 值。

  • 缺点: 仅适用于基于树的模型。

与原始 SHAP 不同,TreeSHAP 是特定于基于树的机器学习模型的。这意味着 TreeSHAP 仅适用于决策树、随机森林、梯度提升机等模型。

TreeSHAP 特定于树模型,因为它利用树结构来更高效地计算准确的 Shapley 值。由于这些结构在其他模型拓扑中不存在,因此 TreeSHAP 仅限于基于树的模型。

TreeSHAP 可以通过干预和树路径依赖的方法计算 Shapley 值。这可以在feature_perturbation参数中指定。树路径依赖方法递归地计算条件期望的变化。我们以一个接受 2 个特征(x, y)的简单决策树为例:

示例决策树 — 作者提供的图片

在上面的示例中,我们有一个包含 7 个节点的决策树,接受两个特征(x, y)来预测z,并且已经用8个训练样本进行了训练。为了计算在联盟(x=10, y=5)yz预测的局部贡献,我们需要考虑以下因素:

  1. 对于(x=10, y=5),模型将从节点 1 移动到节点 3 并到达节点 6。由于节点 6 是叶节点,模型确定预测为z=4

  2. 对于(x=10),模型将从节点 1 移动到节点 3。然而,由于节点 3 不是叶节点,预测值可以推断为节点 3 所有叶节点的加权和。在通过节点 3 的 5 个训练样本中,有两个预测为z=4,而其他的预测为z=24。加权和为4(2/5) + 24(3/5)=1.6 + 14.4 = 16

  3. 在联盟(x=10, y=5)中,yz预测的边际贡献可以计算为Prediction(x=10, y=5) — Prediction(x=10) = 4–16= -12

注意: 这里的负贡献并不意味着特征y不重要,而是特征y将预测值推高了-12

通过对所有特征继续这一过程,TreeSHAP 将获得所有 Shapley 值,并提供局部可解释性(使用上述方法)和全局可解释性(对训练集中的所有局部可解释性结果进行平均)

顾名思义,干预方法通过人为调整感兴趣特征的值来计算 Shapley 值。在我们上述的例子中,这可能是将y从 5 改为 4。为了估计敏感性,TreeSHAP 需要反复使用背景集/训练集作为参考点(当我们在最后一节讨论 LIME 时会再次提到),其线性运行时间复杂度。因此,在使用干预方法时,我们应更加关注 TreeSHAP 的可扩展性。

import shap

# Load the data
X_train, y_train, X_test, y_test = load_data()

# Train the model
model = MyModel.train(X_train, y_train)

# Explain the model's predictions using the tree path dependent approach
explainer = shap.TreeExplainer(
    model,
    X_train,
    feature_perturbation='tree_path_dependent')
shap_values_path = explainer.shap_values(X_test)

# Display the explanations
shap.summary_plot(shap_values_path, X_test)

# Explain the model's predictions using the interventional approach
explainer = shap.TreeExplainer(
    model,
    X_train,
    feature_perturbation='interventional')
shap_values_interv = explainer.shap_values(X_test)

# Display the explanations
shap.summary_plot(shap_values_interv, X_test)

DeepSHAP

  • 优点: 高效的算法,用于近似深度学习或基于神经网络的模型的 Shapley 值。兼容 Tensorflow 和 PyTorch

  • 缺点: 仅适用于深度学习或基于神经网络的模型。由于算法的近似特性,比 SHAP 的准确性低。

讨论可解释性时,我们不能忽视神经网络。DeepSHAP 是 SHAP 和 DeepLIFT 的结合,旨在揭示深度学习模型背后的哲学。它专为深度学习模型设计,这使得 DeepSHAP 仅适用于基于神经网络的模型。

DeepSHAP 试图近似 Shapley 值。解释 DeepSHAP 的一种相对原始的方法是,它试图通过使用梯度或偏导数来分配特征x的局部边际贡献,前提是使用一个有意义的背景/参考点(例如,图像识别模型的全黑背景,预测暴富机会的 0%)。

注意:有进一步研究发布了 DeepSHAP 的通用版本——G-DeepSHAP。你可以在arxiv阅读。

import shap

# Load the data
X_train, y_train, X_test, y_test = load_data()

# Train the model
model = MyModel.train(X_train, y_train)

# Explain the model's predictions using TreeSHAP
explainer = shap.DeepExplainer(model, X_train)
shap_values = explainer.shap_values(X_test)

# Display the explanations
shap.summary_plot(shap_values, X_test)

LIME — SHAP 的替代方法

LIME(局部可解释模型无关解释)是解释预测的 SHAP 的替代方法。它是一种模型无关的方法,默认假设内核大小(解释个体预测时考虑的局部邻域的大小)来近似特征对局部实例的贡献。一般来说,选择较小的内核大小时,LIME 提供的结果将更倾向于局部解释特征值对预测的贡献。(即,较大的内核大小往往提供更全局的视角)

然而,内核大小的选择应根据数据和模式仔细决定。因此,在使用 LIME 时,我们应考虑相应地调整内核大小,以获得对机器学习模型的合理解释。

要尝试一下,我们可以安装并使用该软件包:

pip install lime
import lime
import lime.lime_tabular

# Load the data
X_train, y_train, X_test, y_test = load_data()
feature_names = X_train.columns

# Train the model
model = MyModel.train(X_train, y_train)

# Explain the model's predictions using LIME
explainer = lime.lime_tabular.LimeTabularExplainer(
    X_train, feature_names=feature_names)

# Choose a kernel size for the local neighborhood
kernel_size = 10

# Explain the model's prediction for a single instance
instance = X_test[0]
exp = explainer.explain_instance(
    instance,
    model.predict,
    num_features=10,
    kernel_size=kernel_size)

# Display the explanations
exp.show_in_notebook(show_all=False)

结论

最后总结一下,这里是对本文讨论内容的简要总结:

  • SHAP 是一种基于博弈论的方法,用于解释机器学习模型。

  • SHAP 考虑所有可能的特征组合以评估每个特征的影响。

  • 特征 f 对于本地预测实例的 SHAP 值是由于特征的引入在包含 f 的所有可能特征组合中的边际变化的加权总和。

  • 边际变化的权重根据 f × C(F, f) 的倒数进行,其中 F 是实际模型考虑的特征数量,而 f 是计算边际变化时考虑的特征数量。

  • 由于 SHAP 考虑了所有可能的特征组合,因此算法不会线性扩展,会受到维度灾难的影响。

  • 为了应对 SHAP 的计算复杂性,已经常用几种 SHAP 的变体:

图片来源于作者

  • 我们应该考虑对基于树的模型使用 TreeSHAP,对基于深度学习的模型使用 DeepSHAP。

  • LIME 是一种替代 SHAP 的模型无关方法,用于近似特征的贡献。

  • LIME 的解释可以根据内核大小的选择显著不同。

关于 SHAP 的这次全面介绍就是这些了。我希望你发现这些内容对提升你的写作水平或开始写作有所帮助。如果你喜欢这篇文章,你也可以通过下面的我的附属链接订阅 Medium 来支持我。这是一个我发现了很多有趣读物的平台。即使你完全不打算订阅,你也可以通过点“赞”来支持我和我的创作。

[## 通过我的推荐链接加入 Medium — Louis Chan

阅读 Louis Chan 的每一个故事(以及 Medium 上成千上万的其他作家的故事)。你的会员费直接支持…

louis-chan.medium.com](https://louis-chan.medium.com/membership?source=post_page-----72f0bea35f7c--------------------------------)

最后但绝对不是最不重要的,如果我遗漏或误解了任何关键内容,请随时在评论中指出或通过 LinkedIn 给我发消息。让我们一起保持知识的流动,共同在这个领域中进步!

[## Louis Chan — 主任级 GCP 数据与 ML 工程师 — 副总监 — KPMG 英国 | LinkedIn

有抱负、好奇且富有创意的个人,坚信知识领域之间的相互联系。

www.linkedin.com](https://www.linkedin.com/in/louis-chan-b55b9287?source=post_page-----72f0bea35f7c--------------------------------)

参考文献

  1. Lundberg, Scott M., 和 Su-In Lee. “统一的模型预测解释方法。” 神经信息处理系统进展,2017。

  2. Lundberg, Scott, 和 Su-In Lee. “一致的个性化特征归因用于树集成。” arXiv 预印本 arXiv:1802.03888, 2018.

  3. Ribeiro, Marco Tulio, Sameer Singh, 和 Carlos Guestrin. “我为什么应该相信你?解释任何分类器的预测。” 第 22 届 ACM SIGKDD 国际知识发现与数据挖掘大会论文集, 2016.

  4. Ribeiro, Marco Tulio, Sameer Singh, 和 Carlos Guestrin. “Anchors: 高精度模型无关解释。” arXiv 预印本 arXiv:1802.07814, 2018.

二元和多类目标变量的 SHAP

原文:towardsdatascience.com/shap-for-binary-and-multiclass-target-variables-ff2f43de0cf4

提供一个指南,讲解当模型预测分类目标变量时如何编写代码和解读 SHAP 图

Conor O'SullivanTowards Data Science Conor O'Sullivan

·发表于 Towards Data Science ·9 分钟阅读·2023 年 9 月 4 日

--

照片由 Nika Benedictova 提供,来源于 Unsplash

SHAP 值展示了模型特征对预测的贡献。这在我们使用 SHAP 进行分类时也同样适用。不同的是,对于二元目标变量,我们用对数几率来解释这些值。对于多类目标,我们使用softmax。我们将:

  • 更深入地讨论这些解释

  • 提供用于显示分类问题的 SHAP 图的代码

  • 探索聚合 SHAP 值的新方法以适应多类目标

你还可以观看关于该主题的视频:

之前的 SHAP 教程

我们继续之前的 SHAP 教程。它深入探讨了连续目标变量的 SHAP 图。你将发现这些图及其见解对于分类目标变量也是类似的。你还可以在 GitHub 上找到完整的项目。

## 使用 Python 进行 SHAP 介绍

如何创建和解读 SHAP 图:瀑布图、力图、均值 SHAP、蜜蜂散点图和依赖图

towardsdatascience.com

总结一下,我们使用 SHAP 解释了基于海螺数据集构建的模型。该数据集包含4,177个实例,下面可以看到特征的示例。我们使用8个特征来预测 y——海螺壳上的环数。环数与海螺的年龄有关。在本教程中,我们将 y 分成不同的组,以创建二元和多类目标变量。

X 特征矩阵(来源:UCI 机器学习库)(许可证:CC0:公共领域)

二元目标变量

对于连续目标变量,我们发现每个实例都有 8 个 SHAP 值——每个模型特征一个。如图 1所示,如果我们将这些值与平均预测值E[f(x)]相加,我们就得到了该实例的预测值f(x)。对于二元目标变量,我们也有相同的性质。区别在于我们将值解释为预测的对数几率。

图 1:根据对数几率解释 SHAP 值(来源:作者)

为了理解这一点,让我们深入研究 SHAP 图。我们从创建一个二元目标变量(第 2 行)开始。我们基于 y 创建了两个组:

  • 1 如果海螺的环数高于平均水平

  • 0 否则

#Binary target varibale
y_bin = [1 if y_>10 else 0 for y_ in y]

我们使用这个目标变量和 8 个特征来训练一个XGBoost 分类器(第 2-3 行)。该模型的准确率为96.6%。

#Train model 
model_bin = xgb.XGBClassifier(objective="binary:logistic")
model_bin.fit(X, y_bin)

我们现在计算 SHAP 值(第 2-3 行)。我们输出这个对象的形状(第 5 行),得到(4177, 8)。因此,与连续目标变量一样,我们每个预测和特征都有一个 SHAP 值。稍后,我们将看到这对于多类目标变量是如何不同的。

#Get shap values
explainer = shap.Explainer(model_bin)
shap_values_bin = explainer(X)

print(shap_values_bin.shape) #output: (4177, 8)

我们为第一个实例显示了一个瀑布图(第 6 行)。我们可以在图 2中看到结果。注意,代码与连续变量的代码相同。除了数字外,瀑布图也看起来类似。

# waterfall plot for first instance
shap.plots.waterfall(shap_values_bin[0])

现在E[f(x)] = -0.789 给出了所有 4,177 个海螺的平均预测对数几率。这是正预测(1)的对数几率。对于这个特定的海螺,模型预测其有0.3958的概率具有高于平均水平的环数(即P = 0.3958)。这给出了预测的对数几率f(x) = ln(0.3958/(1–0.3958)) = -0.423

图 2:带有二元目标变量的瀑布图(来源:作者)

因此,SHAP 值表示预测对数几率与平均预测对数几率之间的差异。正的 SHAP 值增加对数几率。例如,剥壳重量增加了1.32的对数几率。换句话说,这个特征增加了模型预测环数高于平均水平的概率。类似地,负值则减少对数几率。

我们也可以以之前相同的方式聚合这些值。好消息是,像蜜蜂散点图或均值 SHAP 这样的图形解释将保持不变。只需记住,我们处理的是对数赔率。现在让我们看看这种解释如何在多类别目标变量中发生变化。

多类别目标变量

我们通过创建一个新的目标变量 (y_cat) 来开始,该变量有 3 个类别——年轻(0)、中等(1)和老(2)。如前所述,我们训练了一个 XGBoost 分类器来预测这个目标变量(第 5–6 行)。

#Categorical target varibale
y_cat = [2 if y_>12 else 1 if y_>8 else 0 for y_ in y]

#Train model 
model_cat = xgb.XGBClassifier(objective="binary:logistic")
model_cat.fit(X, y_cat)

对于这个模型,我们不能再谈论“正预测”。如果我们输出第一个实例的预测概率(第 2 行),我们可以看到这一点。这给我们 [0.2562, 0.1571, 0.5866]。在这种情况下,第三个概率最高,因此海洋蜗牛被预测为老(2)。这对 SHAP 意味着我们不能再只考虑正类的值。

# get probability predictions
model_cat.predict_proba(X)[0]

当我们计算 SHAP 值时可以看到这一点(第 2–3 行)。代码与二分类模型相同。然而,当我们输出形状(第 5 行)时,我们得到 (4177, 8, 3)。现在我们为每个实例、特征和类别都有一个 SHAP 值。

#Get shap values
explainer = shap.Explainer(model_cat)
shap_values_cat= explainer(X)

print(np.shape(shap_values_cat))

因此,我们必须在单独的瀑布图中显示每个类别的 SHAP 值。我们在下面的代码中为第一个实例执行此操作。

# waterfall plot for class 0
shap.plots.waterfall(shap_values_cat[0,:,0])

# waterfall plot for class 1
shap.plots.waterfall(shap_values_cat[0,:,1])

# waterfall plot for class 2
shap.plots.waterfall(shap_values_cat[0,:,2])

图 3 给出了类别 0 的瀑布图。该图解释了每个特征如何对模型预测为此类别做出贡献。也就是说,与该类别的平均预测相比。我们看到该类别的概率相对较低(即 0.2562)。我们可以看到,去壳重量特征对这一低概率做出了最显著的贡献。

图 3:类别 0 的瀑布图(来源:作者)

图 4 给出了其他类别的输出。你会注意到 f(x) = 1.211 在类别 2 中是最大的。这是有道理的,因为我们看到这个类别的概率也是最大的(0.5866)。在分析该实例的 SHAP 值时,可能要重点关注这个瀑布图。这是该海洋蜗牛的类别预测。

图 4:类别 1 和 2 的瀑布图(来源:作者)

使用 Softmax 解释值

由于我们现在处理的是多个类别,f(x) 是以 softmax 形式给出的。我们可以使用下面的函数将 softmax 值转换为概率。fx 给出了上述瀑布图中的三个 f(x) 值。结果是 [0.2562, 0.1571, 0.5866]。这就是我们看到的实例 0 的预测概率!

def softmax(x):
    """Compute softmax values for each sets of scores in x"""
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=0)

# convert softmax to probability
fx = [0.383,-0.106,1.211]
softmax(fx)

聚合多类别 SHAP 值

这些 SHAP 值可以使用任何 SHAP 图进行汇总。然而,像瀑布图一样,每个类别都会有单独的图。分析这些可能会很繁琐,尤其是当目标变量中有很多类别时。因此,我们将讨论一些其他的汇总方法。

首先是平均 SHAP 图的一个版本。我们分别计算每个类别的 SHAP 值的绝对平均值(第 2–4 行)。然后我们创建一个条形图,每个类别和特征都有一个条形。

# calculate mean SHAP values for each class
mean_0 = np.mean(np.abs(shap_values_cat.values[:,:,0]),axis=0)
mean_1 = np.mean(np.abs(shap_values_cat.values[:,:,1]),axis=0)
mean_2 = np.mean(np.abs(shap_values_cat.values[:,:,2]),axis=0)

df = pd.DataFrame({'small':mean_0,'medium':mean_1,'large':mean_2})

# plot mean SHAP values
fig,ax = plt.subplots(1,1,figsize=(20,10))
df.plot.bar(ax=ax)

ax.set_ylabel('Mean SHAP',size = 30)
ax.set_xticklabels(X.columns,rotation=45,size=20)
ax.legend(fontsize=30)

我们可以在图 5中看到输出。有一点需要提到的是,每个条形图显示的是所有预测的平均值。然而,实际的预测类别在每种情况下都会有所不同。因此,您可能会因为 SHAP 值不能解释预测类别而使均值产生偏差。这可能是我们看到中等类别均值较小的原因。

图 5:多分类目标变量中每个类别的平均 SHAP 值(来源:作者)

为了避免这个问题,我们可以集中在预测类别的 SHAP 值上。我们首先获取每个实例的预测类别(第 2 行)。我们创建一组新的 SHAP 值(new_shap_values)。这是通过遍历原始值并仅选择与该实例预测相对应的那一组值来完成的(第 5–7 行)。

# get model predictions
preds = model_cat.predict(X)

new_shap_values = []
for i, pred in enumerate(preds):
    # get shap values for predicted class
    new_shap_values.append(shap_values_cat.values[i][:,pred])

然后我们将原始对象中的 SHAP 值替换掉(第 2 行)。现在,如果我们输出形状,得到的是(4177, 8)。换句话说,我们回到了每个实例一组 SHAP 值的情况。

# replace shap values
shap_values_cat.values = np.array(new_shap_values)
print(shap_values_cat.shape)

这种方法的一个好处是可以轻松使用内置的 SHAP 图。例如,图 6中的平均 SHAP 图。我们可以将这些值解读为特征对预测类别的平均贡献。

shap.plots.bar(shap_values_cat)

图 6:多分类目标变量中预测类别的平均 SHAP 值(来源:作者)

我们也可以使用 beeswarm 图。然而,注意到我们没有看到 SHAP 值与特征值之间的明确关系。这是因为特征的关系会根据预测类别而有所不同。年长的鲍鱼体型会更大。例如,大的壳重会导致老年(2)预测的概率更高。年轻(0)预测则相反。

shap.plots.beeswarm(shap_values_cat)

图 6:多分类目标变量的 beeswarm 图(来源:作者)

希望现在清楚如何解读二分类和多分类目标变量的 SHAP 值。然而,您可能会想知道为什么它们以对数几率和 softmax 的形式给出。将它们解释为概率可能更有意义。

这源于 SHAP 值的计算方式。即同时通过线性模型进行计算。如果我们需要用线性模型预测一个二元或多类变量,我们会分别使用逻辑回归或 Softmax 回归。这些连接函数是可微的,并允许我们将模型预测公式化为参数和特征的线性方程。同样,这些特性用于高效地估计 SHAP 值。

了解更多关于 SHAP 的信息:

## 新的 SHAP 图:小提琴图和热图

SHAP 版本 0.42.1 中的图可以告诉你关于你的模型的哪些信息

[towardsdatascience.com ## SHAP 的局限性

SHAP 如何受到特征依赖性、因果推断和人为偏差的影响

[towardsdatascience.com ## 使用 SHAP 调试 PyTorch 图像回归模型

使用 DeepShap 来理解和改进支持自动驾驶汽车的模型

[towardsdatascience.com

我希望你喜欢这篇文章!你可以通过成为我的 推荐会员 来支持我 😃

[## 通过我的推荐链接加入 Medium — Conor O’Sullivan

作为 Medium 会员,你的会员费用的一部分将分配给你阅读的作者,并且你可以完全访问每个故事……

conorosullyds.medium.com

| Twitter | YouTube | Newsletter — 免费注册获取 Python SHAP 课程

参考文献

Stackoverflow 如何解释多类分类问题的 base_value,当使用 SHAP 时?stackoverflow.com/questions/65029216/how-to-interpret-base-value-of-multi-class-classification-problem-when-using-sha/65034362#65034362

SHAP 用于时间序列事件检测

原文:towardsdatascience.com/shap-for-time-series-event-detection-5b4d9d0f96f4?source=collection_archive---------2-----------------------#2023-02-01

图片由 Luke Chesser 提供,发布在 Unsplash

使用修改版的 KernelSHAP 进行时间序列事件检测

Nakul UpadhyaTowards Data Science Nakul Upadhya

·

关注 发表在 Towards Data Science ·8 min 阅读 ·2023 年 2 月 1 日

--

特征重要性是一种广泛使用的技术,用于解释机器学习模型如何做出预测。该技术为每个特征分配一个分数或权重,指示该特征对预测的贡献程度。这些分数可以用来识别最重要的特征,并了解模型如何做出预测。其中一种常用的版本是 Shapley 值,这是一种基于博弈论的模型无关度量,将“支付”(预测)公平地分配给特征[1]。Shapley 值的一种扩展是 KernelSHAP,它使用核技巧以及局部替代模型来近似 Shapley 值,这使得它能够计算更复杂模型(如神经网络)的特征重要性值[2]。

KernelSHAP 通常用于解释时间序列预测,但在这个领域确实存在一些显著的约束和缺陷:

  1. 时间序列预测通常涉及大量过去数据,这可能会导致在应用 KernelSHAP 时出现数值下溢错误,尤其是在多变量时间序列预测中[3]。

  2. KernelSHAP 假设特征独立性。这在表格数据情况下通常有效,但在时间序列中,特征和时间步的独立性往往是一种例外,而非常态[3]。

  3. KernelSHAP 使用已拟合于数据扰动的线性模型的系数。然而,在时间序列的情况下,向量自回归模型(VAR)通常比仅使用线性模型[3]更适合建模过程。

为了解决这些问题,J.P. 摩根的 AI 研究部门(Villani 等人)在他们的 2022 年 10 月论文 [3]中提出了更适合时间序列数据的 KernelSHAP 变体:

  1. 研究人员首先创建了 VARSHAP,这是一种使用 VAR 模型而非线性模型的 KernelSHAP 修改版。这种修改使得研究人员还计算了一种封闭形式的方法来计算 AR、MA、ARMA 和 VARMAX 模型的 SHAP 值。

  2. 基于 VARSHAP,研究人员提出了时间一致性 SHAP,利用问题的时间成分来减少 SHAP 值的计算量。

使用时间一致性 SHAP 度量,研究人员展示了一种通过捕捉特征重要性的激增来进行事件检测的有前途的方法。

在这篇文章中,我将首先解释 KernelSHAP 的计算方法以及如何将其修改为 VARSHAP。然后还将解释如何获取时间一致性 SHAP(TC-SHAP)以及如何使用 TC-SHAP 来识别时间序列分析中的事件。

KernelSHAP 和 VARMAX

SHAP 值的公式如[2]所提供的:

方程 1:SHAP 方程

上述方程中的 phi 是特征 i 的 SHAP 值,给定值函数 v(值函数通常是模型预测)。C* 是所有特征的集合,NC 的大小或特征数量。P(C) 是所有不包含特征 i 的特征的幂集。Delta(S, i)* 是将特征 i 添加到特征联盟 S(这是幂集 C 中的一个集合)时导致的预测变化。

该方程总结为“将特征 i 的加权边际贡献添加到不包括 i 的每个可能特征联盟中”。

KernelSHAP 处理的问题是,随着特征数量的增加,幂集大小呈指数级增长,这使得计算变得极其庞大。KernelSHAP 是通过解决以下问题来计算的:

方程 2:KernelSHAP 方程 [3]

其中 h_x 是应用于 z 的掩码函数,z 是从集合 Z 中采样的二进制向量,代表所有可能的特征联盟的集合。这个函数将由 z 表示的联盟映射到掩码数据点,然后将其输入到我们的模型中(f_theta)。目标是找到最佳的线性模型 (g),该模型在所有掩码下估计模型性能。线性模型中的权重是 KernelSHAP 值。这一切都可以通过以下定义的组合核实现:

方程 3:组合核 [3]

要计算 VARSHAP,只需将 g 的线性表示替换为 VAR 模型。根据作者的说法,由于线性模型和 VAR 模型的系数都是通过普通最小二乘法估计的,因此所有 KernelSHAP 的数学原理仍然适用,并且对时间序列更具代表性 [3]。

时间一致性 SHAP

如前所述,SHAP 是一种将模型的特征解释为游戏中的玩家,并使用 Shapley 值来寻找公平奖励分配的方法。然而,对于随时间发展的游戏,这些分配可能不足以激励所有方追求最初目标。为了避免这种情况,博弈论者使用分配计划和时间一致性的概念来管理跨时间的激励 [3]。

这种通过时间的利益竞争思想也扩展到特征上,因为传统的 SHAP 方法将不同时间步的相同特征视为游戏中的不同玩家。根据作者的说法,我们可以通过添加时间一致性来弥合这一差距 [3]。

SHAP 值的时间一致性可以表示如下:

方程 4:Shapley 值的时间一致性 [3]

在这个方程中,beta 代表在 t 时间步中对玩家(特征)i 进行的支付分配计划,而 phi(0,i) 是玩家对游戏(预测)贡献的总值。可以将其视为类似于商业伙伴关系。

每个个体(即特征)向启动基金(在时间步 0 时为 phi)支付初始金额。然后,在未来的时间步中,个体会定期获得回报,因为他们对业务结果(即最终预测)作出更多贡献。这些支付也会使任何个体不愿意采取不利于业务利益的行动。通过这种方式表述问题,TC-SHAP 在时间序列背景下表现得更好,因为现在 特征的不同时间步被建模为一个整体,而不是作为单独的参与者

在实际操作中,可以采取以下步骤:

  1. 通过将特征替换为零或特征均值来计算每个特征的总 SHAP 贡献(在时间步 0 时为 phi)。对所有特征重复此操作。

  2. 然后我们需要计算窗口内每个时间步的“子游戏 SHAP”。这通过修改方程 2 中的掩码机制来完成,即不再只掩码时间步 (t-w),而是掩码 t-wt 之间的所有时间步(再次使用零或均值)。

  3. 然后我们简单地使用方程 4 和 5 计算填充值表。

方程 5: 填充值表 [3]

第 1 步计算“初始投资”。第 2 步则实施了一个观点,即我们有 N 个特征跨越多个时间步 (W),而不是 NW* 个特征。第 3 步通过提供填充值表(或每个“投资者”的周期性回报)将所有内容汇总在一起。

这个过程还有一个好处,即将计算次数从 2^(NW)* 减少到 W**2^N*,其中 W 是用于预测的时间步数,而 N 是特征的数量 [3]。

一旦计算完成,我们可以将 TC SHAP 值解释为“在给定时间步中,特征的演变如何影响其他特征轨迹的联盟。” 换句话说,TC SHAP 代表了给定时间步的特征如何改变未来时间步中其他特征的共同贡献。对未来协作产生重大影响的特征-时间步点将根据定义对最终预测产生重大影响。

使用 TC SHAP 进行事件检测

虽然了解单次预测中某个时间步的重要性是有用的,但时间序列分析通常涉及分析多个预测和模式,我们可能想要了解模型预测中一些重要时间步的总体情况 [3]。

根据作者的说法,我们可以通过将给定预测集(或如果我们想要全球事件检测机制则是所有预测)的 TC SHAP 值相加来找到重要的时间步。通过绘制这些值,我们可以轻松查看哪些时间步很重要以及一些重要事件可能发生在哪里 [3]。

作者通过 Individual Household Electricity dataset 演示了这种方法的有效性。作者训练了一个 LSTM 网络,后接一个密集层来预测功耗。然后,他们计算了 TC-SHAP 值并将其汇总以获得事件检测卷积。接着,他们将卷积叠加到目标时间序列上。

图 1:事件检测卷积(蓝色)用于子计量 2 和 3 与目标的比较(图源自 [3])

如图 1 所示,目标变量的大幅变化可以通过事件卷积中的大幅峰值来解释。例如,在时间步 25 之后,子计量 2 的事件卷积出现了一个大峰值,随后目标也出现了大幅下降。类似地,在时间步 75 左右,子计量 3 的事件卷积大幅下降后,目标也出现了大幅下降。大多数的大幅变化可以通过一些子计量的变化来解释。

结论

对 KernelSHAP 的修改填补了当前工作中的一个巨大空白。除此之外,针对时间序列特征重要性进行开发的事后可解释性方法并不多。TC-SHAP 有助于解决这个问题,确实是迫切需要的。

然而,这种新方法仍然存在一些关注点和进一步的工作需要解决。其中一个关注点(作者也提到)是 VARSHAP 和 TC-SHAP 解释之间的显著差异,这表明需要更多的工作来检验这些值的确切解释。此外,尽管理论上 TC-SHAP 克服了独立性问题,但仍需更多实验来完全确认这一说法。

此外,一般来说,模型无关的方法可能具有误导性,因为它们只能提供重要性的估计,而非真实的重要性。然而,对于大多数使用情况,这种粗略的评估已经足够,并且拥有一个处理时间依赖性的方法非常有用。

资源和参考文献

  1. Python 的 SHAP 包:shap.readthedocs.io/en/latest/index.html

  2. 对 Kernel SHAP 的更深入解释:christophm.github.io/interpretable-ml-book/shap.html#kernelshap

参考文献

[1] L. Shapley. n 人博弈的价值。(1953)。博弈理论贡献 2.28。

[2] S.M. Lundberg, S-I. Lee. 一种统一的模型预测解释方法。(2017)。神经信息处理系统进展,30。

[3] M. Villani, J. Lockhart, D. Magazzeni. 时间序列数据的特征重要性:改进 KernelSHAP (2022)。ICAIF 解释性人工智能在金融中的研讨会

SHAP 与 ALE 在特征交互上的对比:理解冲突的结果

原文:towardsdatascience.com/shap-vs-ale-for-feature-interactions-understanding-conflicting-results-ac506149f678

模型解释工具需要深思熟虑的解读

Valerie CareyTowards Data Science Valerie Carey

·发布于Towards Data Science ·10 分钟阅读·2023 年 10 月 2 日

--

图片由Diogo Nunes拍摄,发布在Unsplash

在这篇文章中,我比较了特征交互的模型解释技术。令人惊讶的是,两个常用的工具,SHAP 和 ALE,产生了相反的结果。

我可能不应该感到惊讶。毕竟,解释工具以不同的方式测量特定的响应。解释需要理解测试方法、数据特征和问题背景。仅仅因为某物被称为解释器并不意味着它生成了解释,如果你把解释定义为一个人理解模型的工作方式。

本文重点关注特征交互的解释技术。我使用了一个来源于真实贷款的常见项目数据集[1],以及一种典型的模型类型(一个提升树模型)。即使在这种日常情况下,解释也需要深思熟虑的解读。

如果忽略了方法论细节,解释工具可能会妨碍理解,甚至破坏确保模型公平性的努力。

在下文中,我展示了不同的 SHAP 和 ALE 曲线,并证明这些技术之间的分歧来源于测量响应和测试执行的特征扰动的差异。但首先,我会介绍一些概念。

特征交互

特征交互发生在两个变量共同作用时,导致的效果与它们各自贡献的总和不同。 例如,一夜的睡眠质量差对第二天的测试成绩的影响会大于一周后的影响。在这种情况下,代表时间的特征将与睡眠质量特征相互作用或修改。

在线性模型中,交互表示为两个特征的乘积。非线性机器学习模型通常包含众多交互。事实上,交互是高级机器学习模型逻辑的基础, 然而许多常见的可解释性技术侧重于孤立特征的贡献。检查交互的方法包括 2-way ALE 图、Friedman 的 H、部分依赖图和 SHAP 交互值 [2]。本博客探讨了其中的两个:ALE 和 SHAP。

ALE 图

累计局部效应(ALE)是一种测量特征效应的技术,不会受到相关特征或不太可能的特征组合造成的失真的影响。 特征交互可以通过 2-way ALE 图进行可视化。2-way ALE 图首先通过扰动一个特征(i),在固定的第二个特征(j)值下测量模型输出的变化。然后,(在稍微不同的 j 值下进行类似的测量。这两个测量值的差异揭示了扰动 j 如何影响模型对 i 变化的响应。为了减少不太可能的特征组合的影响,测量仅使用在选择的 ij 值附近的小窗口中的观察值。

SHAP 交互值

Shapley 值表示每个特征对模型输出的信用或责任的量。 “SHAP”指的是一组用于计算机器学习模型的 Shapley 值的方法。SHAP 计算测量特征设置为其原始值与参考值相比时模型响应的变化。特征的边际贡献是通过对其他特征的各种组合或“联盟”进行平均来计算的。Shapley 联盟是通过将一些特征值替换为从参考数据集中(通常是训练数据)随机抽取的值来形成的。与 ALE 不同,Shapley 涉及许多特征的扰动,而不仅仅是感兴趣的特征对,并且为每个观察值计算值。

SHAP 交互值将模型分数分配到所有特征主效应 成对交互 [3]。对于一个观察值,特征对(ij)的交互是通过测量在特征 i 具有其原始值 j 时的 Shapley 值来计算的。然后,将 j 替换为从参考中随机抽取的值,并计算特征 i 的新 Shapley 值。这两个测量值之间的差异量化了特征 j 如何修改 i 的 Shapley 值。

数据和模型

我使用了通过 Kaggle [1] 获得的 Lending Club 贷款数据集。该模型根据利率、贷款期限、借款人收入、借款人住房拥有状态以及个人或联合借款人的信用评分等特征预测贷款违约。利用其他人 [1, 4–5] 做的分析,我选择了 18 个预测特征,响应变量是贷款违约的二元指标。训练了一个提升树模型,使用 Scikit-learn 的 GradientBoostingClassifier,该模型与 ALE 图 (PyALE)、SHAP 值 (SHAP) 和 Friedman 的 H (sklearn_gbmi) 兼容。代码可在 GitHub [6] 上获得。

SHAP 和 ALE 对有影响的特征意见不一致

为了检查模型中的交互作用,我生成了 SHAP 依赖图和二维 ALE 图用于特征对。对于大多数特征对,ALE 和 SHAP 图至少有一些相似之处。但对于一个关键交互作用,即利率和期限,结果却出现了冲突:

图 1. 利率和期限的交互作用度量。正值表示预测的违约风险更高。x 轴显示利率,颜色对应于期限值;红色表示 60 个月,蓝色表示 36 个月。A. 利率:期限交互作用的 ALE 值折线图。B. 显示利率:期限的 SHAP 交互作用值的散点图。图片来源于作者

二维 ALE 表明,与高利率结合的较长期限增加了风险。但 SHAP 讲述了相反的故事;较长期限在高利率下对违约有保护作用!

在这个数据集中,贷款期限(term)是分类变量,只有两个值,36 个月和 60 个月,而利率(int_rate)是连续变量。图 1A 和 B 分别显示了 ALE 和 SHAP 值,这些值绘制在相同的尺度上,正值表示由于交互作用,模型默认风险增加。尽管热图通常用于二维 ALE 图,但我更喜欢折线图;这些图也更容易与 SHAP 图进行比较。

图 1 中的数据矛盾尤其令我担忧,因为利率和期限是模型中两个最重要的特征,通过多种衡量标准(聚合 Shapley 值、不纯度和置换重要性;见[6])。此外,根据 SHAP、ALE 和 Friedman 的 H,期限:利率交互作用也很大。

因此,我有两个具有重要交互作用的有影响力特征,但 SHAP 和 ALE 显示的效果方向不同。常识能帮助解决冲突吗?以下是曲线的一些可能解释:

(ALE) 长期高利率贷款特别具有风险。

(SHAP) 高利率对违约的预测力很强;在高利率时,期限并不那么重要。

(SHAP 的故事也来源于单向 ALE 响应[6]。负交互作用取消了单向项的响应,因此这种交互作用可以被解释为该项特征影响力的丧失。)

作为一个非借贷领域的专家,这两个账户对我来说似乎都很合理。理解这些图形为何有所不同意味着深入理解这些技术,同时也要考察简化模型。

Mike Houser的照片,来源于Unsplash

SHAP 与 ALE — 哪些差异是重要的?

SHAP 交互作用和双向 ALE 值都测量当特征j被修改时,模型响应的差异,针对特征i具有相似值的数据点子集。

从上述陈述开始,让我们列出 SHAP 和 ALE 可能存在差异的一些方式:

1. 选择用于测量的数据点。

2. 测试所测量的响应。

3. 特征值如何被修改。

第 1 项似乎不太可能是罪魁祸首。对于 Shapley 来说,是对每个数据点进行测量,我们使用原始特征值。ALE 考虑的是一个值周围的窗口。窗口大小是基于数据密度的,因此较高的利率点反映了一个相对较大的值范围,但在图 1 的“高利率”部分,我们可能有足够相似的观测值。

第 2 项和第 3 项的差异可能对解释图形中的不一致很重要。对于第 2 项,SHAP 和 ALE 测试不同的模型响应。ALE 使用原始模型输出,而 SHAP 则将模型预测分布到多个特征上,并检查归因于 i部分

对于第 3 项,ALE 扰动仅涉及特征j的值替代。而 SHAP 则汇总了许多模型特征的响应;所有变量都会被扰动。替代值是从训练数据中随机抽取的,通常反映了更典型的值,这可能与初始观测的特征差异很大。

稀有情况生成 ALE 信号

在进行模型简化和其他分析后(见[6]中的代码),我意识到 ALE 测试是在响应多个风险因素的情况下模型的预测。下面,我重新计算了仅针对年收入超过 45,000 美元且不是租房者的客户的 ALE 和 SHAP 图。

图 2. 利率和期限的交互测量,仅对高收入非租户客户计算。P. 正值表示预测的违约风险较高。x 轴显示利率,颜色对应期限值;红色表示 60 个月,蓝色表示 36 个月。A. 利率:期限交互的 ALE 值折线图。B. 显示利率:期限的 SHAP 交互值的散点图。图片由作者提供。

当排除低收入和租户案例(约占总数的 50%)时,ALE 信号几乎完全消失,而 SHAP 曲线在质量上保持不变。

原始的 ALE 曲线(图 1A)可以通过一个简化的单树模型进行再现,该模型只涉及三个特征(利率、期限和年收入),如下所示:

图 3. 简单模型以重现 ALE 检测到的特征交互。A. 单个决策树模型的图示。树的遍历从节点 0 开始,当框中显示的条件为真时向左移动,否则向右移动。遍历在叶节点结束;响应是该节点中的值。根据模型响应值对框进行着色(对于非叶节点,值反映了子叶节点的平均值)。在框中注明了达到每个节点的训练数据样本数量。B 关于利率:期限的 ALE 值的折线图,展示了 A 中所示的树。C. 显示利率:期限的 SHAP 交互值的散点图,针对 A 中的树。由于文本中讨论的原因,一些异常点在 SHAP 曲线中被裁剪。图片由作者提供。

图 3A 包含具有非常高或低值的低人口节点(例如,节点 2 和 7)。节点 7 由稀有客户访问,这些客户具有低收入、高利率和长期;这些客户的违约风险非常高。

ALE 图受到稀有特征组合影响的主导作用。 节点 7 代表了少量的贷款,但在 ALE 计算过程中,当更改期限时,模型响应发生了剧烈变化。60 个月的客户离开此节点,降低了风险,而(数量更多的)36 个月客户进入此节点,导致了一个显著的信号。

SHAP 检测复杂模型中的系统性效应

图 1B 中的 SHAP 信号在图 3B 中消失。模型复杂性是 SHAP 结果的关键。为了可靠地再现原始的 SHAP 曲线,我发现需要≥4 个特征、20 棵树和深度>5(见[6]中的代码)。

图 3B 中的 SHAP 曲线包含了超过 30% 利率的异常值(其中一些在图中被裁剪;异常值高达 ~0.4)。如果对异常值取平均,36 个月和 60 个月的值非常相似,接近零 (~0.001)。这些异常值是由于具有访问极端节点 4 和 7 的联合体的情况。模型复杂性减少了异常值。随着特征数量的增加,从参考数据中抽取多个不寻常的值变得不太可能。此外,在计算中平均更多的联合体会稀释信号。

SHAP 测量会降低对稀有特征组合的重视。 SHAP 联合体可能涉及与原始特征值非常不同的值,而 ALE 计算通常涉及在更受限范围内的扰动。SHAP 联合体提供了对模型的更广泛覆盖,反映了由更多节点生成的值,特别是人口较多的节点。

SHAP 计算中特征变化的范围取决于观察值是否相对于参考数据异常。图 1B 和 2B 中的 36 个月期限曲线的平坦性反映了大多数客户(75%)拥有 36 个月的贷款。因此,为期限生成 SHAP 联合体时,随机抽取的值可能会使期限保持不变。减去两个相似曲线的结果会得到一个较小的 SHAP 交互值。

相比之下,60 个月的期限曲线与典型情况更远,因此生成 SHAP 信号。高利率和 60 个月期限下的负值表明,利率特征对 36 个月较低期限值的影响更大。更多的贷款是 36 个月的,并且大多数贷款风险适中,因此在这种情况下高利率更令人惊讶。对于 60 个月期限,高利率则不那么令人惊讶(利率和期限的 Pearson 相关系数约为 0.4),因此可能预期 SHAP 对长期贷款的利率特征赋予较少的权重。

那么,哪个是正确的?

之前,我描述了图 1 中曲线所暗示的两个不同故事:

(ALE) 长期高利率贷款特别具有风险。

(SHAP) 高利率对违约的预测能力很强;当利率较高时,期限并不是很重要。

这两个故事似乎都是真的,但针对不同的客户。对于具有多个风险因素的少见情况,第一个解释是正确的;利率和期限的组合会产生非常大的 ALE 响应。但对于更典型的高利率客户,利率捕捉了大部分风险。因此,SHAP 和 ALE 测试关注的是不同的客户。

为什么这很重要?

在应用可解释性工具后,我们期望能增加对模型工作原理的理解。我们相信,我们将对模型的决策过程有一个总体的了解,甚至可能揭示数据中的一些模式。这些测试用于质量控制和信任建立。当解释与期望一致时,利益相关者会感到安心。

可解释性工具可以提供许多好处,但它们也可能误导或提供虚假的安慰。

可解释性工具在模型公平性测试中尤其重要,以避免偏见和歧视。当特征偏见存在或怀疑存在时,交互度量是至关重要的[7]。SHAP 对稀有特征组合的缺乏响应可能是一个问题,因为性别、种族或年龄等特征组合可能与不良结果有关。相反,ALE 可能会忽略系统性效应,因为它扰动了较少数量的特征,范围更有限。

结论

模型可解释性包通常被描述为“解释器”,其输出为“解释”。我认为使用测试测量这样的词更为有用。例如,“SHAP 值”比“SHAP 解释”更好,因为包的输出与对复杂模型的实际理解之间存在一些距离。我正在尝试改变我对这些术语的使用,以提醒自己这一点!

在医学中,诊断测试在特定情况下进行,结果由专家解读。通常,为了建立诊断,使用不止一种测试。同样,需要对模型可解释性工具有更深入的理解才能得出有意义的结论。

参考文献

[1] N. George, All Lending Club loan data, www.kaggle.com/datasets/wordsforthewise/lending-club

[2] C. Molnar, 可解释的机器学习:使黑箱模型可解释的指南(2023)。

[3] S. M. Lundberg, G. G. Erion 和 S-I. Lee, 树集成的个体化特征归因(2019),arXiv:1802.03888v3 [cs.LG]。

[4] M. Gusarova, 特征选择技术(2023),Kaggle。

[5] N. George, 使用 Python 进行探索性数据分析(2019),Kaggle。

[6] V Carey, GitHub 代码库, github.com/vla6/Blog_interactions

[7] V Carey, 特征偏见的无免费午餐(2021),Towards Data Science。

使用 ONets 进行形状重建

原文:towardsdatascience.com/shape-reconstruction-with-onets-1c1afe89c50

使用可学习函数表示 3D 空间

Cameron R. Wolfe, Ph.D.Towards Data Science Cameron R. Wolfe, Ph.D.

·发布于 Towards Data Science ·11 分钟阅读·2023 年 2 月 7 日

--

(照片由 Tareq Ajalyakin 拍摄,来源于 Unsplash

3D 重建的问题旨在根据物体的噪声视图(例如,部分点云、2D 图像等)生成高分辨率的物体表示。最近,深度神经网络成为 3D 重建的热门方法,因为它们可以编码有助于处理模糊的信息。简单来说,这意味着如果从输入中不清楚如何正确重建给定的物体,神经网络可以借鉴它在训练过程中遇到的其他数据点的经验,仍然生成合理的输出

大多数 3D 重建方法最初在表示高分辨率物体的能力上存在限制。体素、点云和网格在以内存高效的方式建模高分辨率物体方面都存在不足。与如 GANs [2] 等可以生成高分辨率、逼真图像的模型相比,用于生成 3D 几何的可比方法还处于初级阶段。

(来自 [1])

为了解决这个问题,文献[1]中作者提出了一种 3D 重建方法,该方法使用神经网络对占据函数进行建模。更具体地说,我们训练一个神经网络来预测空间中给定点是否被物体占据(即,占据函数!)。然后,底层物体通过该神经网络的决策边界(即,预测从占据到未占据的切换位置)来表示;见上文。占据网络(ONets)可以以任意精度和合理的内存要求表示和重建 3D 形状。

(来自 [1])

背景

在我们之前对DeepSDF的 3D 形状生成的回顾中,我们涵盖了一些相关的背景概念,包括:

  • 点云、网格和体素表示 [link]

  • 前馈神经网络 [link]

[1]中的大多数神经网络使用了简单的前馈架构,这种架构作为一种替代方案,用于存储 3D 形状的标准网格、体素或点云格式。为了理解为什么这种方法更可取,我们需要更详细地了解这些表示的限制。

3D 形状表示的缺点

(来自 [6])

3D 几何形状通常以网格、体素或点云的形式存储或表示;见上文。鉴于这些表示已经存在,我们为什么还要使用神经网络来表示形状呢? 简单的答案是 (i) 其他表示方法存在一些显著的限制,以及 (ii) 我们可以通过这种方式节省大量内存。

体素在内存使用上不高效。 在深度学习应用中,体素是用于 3D 形状的最广泛使用的表示形式,因为它们简单——它们是像素在 3D 空间中的直接推广。然而,如果我们使用体素来编码一个 3D 形状,这种编码的内存占用会随着分辨率的提高而立方增长。如果我们想要更精确的体素表示,我们需要使用更多的内存。

点云是断开的。 点云的格式类似于我们通常从传感器(如 LiDAR)获得的数据,但结果几何形状是断开的——它只是 3D 空间中的一堆点。因此,从这些数据中提取实际的 3D 形状需要昂贵的后处理程序。

网格并不是解决方案。 如果点云需要后处理而体素在内存使用上不高效,我们应该使用网格,对吗? 不幸的是,大多数网格表示实际上是基于变形的“模板”网格。实际上,这意味着网格无法编码任意的拓扑结构,这使得它们在准确表示某些几何形状时受到限制。

那么我们该怎么办? 考虑到这些限制,[1]中提出的方法开始变得更加合理。我们可以训练一个神经网络来生成可以恢复形状的输出,而不是直接存储 3D 形状的网格、体素或点云表示。通过这种方式,我们可以在一个具有固定内存成本的神经网络参数中存储大量不同的几何形状!

占用函数

[1]中的工作将占据函数表示为神经网络,但占据函数是什么?简单来说,这只是一个将空间中的点(例如,[x, y, z] 坐标)作为输入,并返回一个二进制输出的函数,表示该位置是否被目标对象“占据”。

占据函数的特征

这样的函数可以通过一个神经网络进行逼近,该网络被训练为在给定 [x, y, z] 坐标作为输入时输出零到一之间的概率。

提取等值面。 要从占据函数中提取 3D 几何体,我们必须找到一个等值面。为此,我们只需在 3D 空间中找到占据函数等于某个阈值 t 的点。

等值面

在这里,t 被设置为零到一之间的某个值。因此,等值面表示占据函数的边界,即输出从零切换到一(或反之)的地方——这对应于基础对象的表面!

评估指标

用于判断 3D 形状质量的指标与我们在普通计算机视觉中看到的指标非常相似。[1]中使用的主要指标如下所述。

体积 IoU。 两个形状交集的体积除以它们并集的体积。该指标与法线 IoU相同,但它已被推广到三维。

Chamfer-L1。 精度和完整性的均值。精度是输出网格上点到地面真实网格最近点的平均距离。完整性是相同的,但方向相反。

法线一致性。 我们取预测网格的面法线(即一个垂直于网格某一面平面的向量),找到另一个网格中对应最近邻的面法线,然后取这些向量的点积的绝对值。通过对预测网格中的所有法线重复这一过程并取平均值,我们得到法线一致性。这一指标稍显复杂,但它对于确定预测形状是否捕捉到高阶信息(即两个网格的面是否趋向于指向相同方向)非常有用。

占据网络 [1]

在了解 DeepSDF 的基本概念后,我们可能会问占据网络的第一个问题是:为什么要建模占据函数,而不是像 有符号距离函数(SDFs)这样的替代方法? 基本的答案是,占据函数更容易学习。

“[SDFs 通常比占据表示更难学习,因为网络必须在 3D 空间中推断距离函数,而不仅仅是将体素分类为占据或未占据。]” — 摘自 [1]

网络。 为了逼近占据函数,我们使用一个前馈神经网络,该网络将一个介于零和一之间的概率分配给任何 3D 坐标。网络的输入包括:

  1. 一个(单一的)3D 坐标。

  2. 对底层物体的噪声观察。

我们称之为 ONet 的网络输出一个标量概率。作为模型输入的底层物体的噪声观察可能是诸如不完整的点云或粗略的体素网格之类的东西。我们将神经网络以这些噪声数据为条件,然后使用它生成物体的更精确表示;见下文。

对每个空间位置建模占据函数

我们不能直接将观察到的噪声数据作为输入传递给前馈神经网络,因为它们可能有许多不同的格式。相反,[1]中的作者使用不同的基于神经网络的编码器(例如,图像的 ResNets [4] 或点云的 PointNet [5];这些只是将图像和点云转换为向量的常见网络架构)将这些数据(即,将其转换为向量)嵌入。

然后,前馈神经网络的第一层使用 条件批量归一化 — 一种条件 批量归一化 变体 — 根据物体嵌入调整网络的输入。这样,我们确保 ONet 的输入以我们尝试重建的 3D 几何数据为条件。

训练。 为了训练 ONet,我们 (i) 考虑一个围绕训练物体的填充空间体积,以及 (ii) 在这个空间中均匀采样 K 个占据函数值。通过将这种采样过程应用于多个训练物体来形成一个小批量。我们使用带有每个小批量中占据函数值的 交叉熵损失 的小批量梯度下降法正常训练网络。

生成对象。 一旦 ONet 训练完成,我们可以在任意分辨率下输出给定空间位置的占据函数值。但是,我们如何使用这些值来创建实际的 3D 对象(例如,以网格格式)? 为此,[1]中的作者提出了一种多分辨率等值面提取(MISE)过程。该过程受到 八叉树 的启发,八叉树是一种递归表示 3D 空间的树形数据结构。每个八叉树节点有八个子节点,我们可以递归地向每个节点添加子节点,以表示更高分辨率的体积。

MISE 的基本过程如下:

  1. 在初始分辨率下离散化空间

  2. 在该空间的每个离散位置评估 ONet

  3. 标记所有具有至少两个相邻体素且占用情况不同的体素

  4. 将标记的体素划分为八个子体素,并重复直到达到所需的分辨率

  5. 应用 Marching Cubes 以获取网格

因此,MISE 通过在需要的更高分辨率下递归评估 ONet 来获得网格(即,接近对象边界)。尽管我们必须应用一些额外的步骤来精细化此网格,但整体过程相当直观;见下文。

(来自 [1])

实验结果。 ONets 主要在合成 ShapeNet 数据集上进行评估,基于其表示复杂 3D 形状和/或从图像、噪声点云和低分辨率体素网格中重建它们的能力。在这些实验中,我们看到 ONets 可以准确地表示 ShapeNet 的“椅子”部分。使用 ONet,我们可以使用仅 6M 参数独立编码近 5K 对象。相比之下,体素表示不够准确,并且其内存要求随着所需分辨率的增加而增加;见下文。

(来自 [1])

在重建实验中,我们继续看到 ONets 工作得相当好。它们能够恢复复杂的形状,并且往往产生最接近真实几何的结果。大多数基准方法通常存在限制,例如产生粗糙的表示、需要后处理的断开对象或缺乏相关细节的对象。模型输出的定性示例如下所示。

(来自 [1])

当尝试从噪声点云和粗糙体素网格(即,而不是像上面那样使用图像)重建几何时,我们看到 ONet 继续表现良好;见下文。

(来自 [1])

其他内容。 ONets 还通过从 KITTIOnline Products 数据集中获取图像应用于现实世界数据。尽管仅在合成数据上训练,ONet 似乎对这种类型的数据泛化效果很好。然而,值得注意的是,作者确实使用了 KITTI 提供的分割掩码来提取与所需对象相关的像素。以下展示了从这些数据集中生成的重建示例。

(来自 [1])

超越[1]中的初步提议,作者还提出了一种生成版本的 ONet,该版本通过对 ShapeNet 进行无监督训练,并形成一个类似于变分自编码器(VAE)的 3D 几何潜在空间。简单来说,作者发现可以创建生成 ONets,这些 ONets 能够生成新的网格并在网格之间进行插值。如果我们希望关注生成应用而不仅仅是 3D 重建,这将非常有用。

要点

在 ONets 之前,现有的 3D 重建方法在保持合理内存占用的同时,难以对高分辨率物体进行建模。在[1]中,我们了解到,更智能的 3D 几何表示可以带来显著的好处。ONets 提案中的主要要点如下。

占据函数非常棒。 常见的 3D 几何表示(如网格、点云、体素)在表示高分辨率物体时往往占用过多内存。占据函数是一种更简洁的表示方法,它通过编码单一函数来实现对 3D 物体的任意精度建模。而且,与像 SDF 这样的替代方法相比,占据函数更容易学习或建模。

可学习的重建。 当然,占据函数很棒,但这项工作的真正价值在于我们如何表示这些函数。也就是说,我们可以使用神经网络来学习和存储各种形状的占据函数。因此,我们可以通过(i)训练一个 ONet 和(ii)存储模型的参数,以任意精度表示许多不同的形状。这种方法使用有限且固定的内存量,并通过利用先验信息来提高重建质量!

表示 3D 空间。 在解释使用 ONets 生成网格的 MISE 方法时,我们很快遇到了八叉树的概念。这是 3D 建模中的一个重要数据结构,它允许我们以不同的精度递归生成形状。如果我们想获得更精确的表示,只需继续细分体素。但我们应仅在有意义的地方(即当附近的体素具有不同的占据情况时)进行此操作,以避免不必要的计算。

结束语

非常感谢阅读本文。我是Cameron R. Wolfe,一名在Alegion工作的研究科学家,同时也是莱斯大学的博士生,研究深度学习的经验和理论基础。你也可以查看我在 medium 上的其他写作!如果你喜欢这篇文章,请在twitter上关注我或订阅我的Deep (Learning) Focus 通讯,我在其中撰写了有关深度学习重要主题的易懂概述系列。

参考文献

[1] Mescheder, Lars, 等. “占据网络:在函数空间中学习 3D 重建。” IEEE/CVF 计算机视觉与模式识别会议论文集。2019 年。

[2] Goodfellow, Ian, 等. “生成对抗网络。” ACM 通讯 63.11 (2020):139–144。

[3] Mildenhall, Ben, 等. “NeRF:将场景表示为神经辐射场以进行视图合成。” ACM 通讯 65.1 (2021):99–106。

[4] He, Kaiming, 等. “图像识别的深度残差学习。” IEEE 计算机视觉与模式识别会议论文集。2016 年。

[5] Qi, Charles R., 等. “Pointnet:针对 3D 分类和分割的点集深度学习。” IEEE 计算机视觉与模式识别会议论文集。2017 年。

[6] Hoang, Long, 等. “一种使用波形核签名和 3D 三角网中心点的 3D 对象分类深度学习方法。” 电子学 8.10 (2019):1196。

用 SQL 进行数据塑形

原文:towardsdatascience.com/shaping-your-data-with-sql-71822f2fc2f4

使用不同的技术改进和优化你的数据分析过程

Chi NguyenTowards Data Science Chi Nguyen

·发表于Towards Data Science ·9 分钟阅读·2023 年 4 月 18 日

--

OB OA拍摄,图片来自Unsplash

什么是数据塑形?

没有一种放之四海而皆准的数据。为了不同的目的和使用案例,数据会根据需要进行定制。你对数据未来使用目的的了解越多,你就越能正确地将数据呈现给最终用户。

例如,用于进行深入分析的数据与提供给高层管理的汇总数据有所不同。

另一个例子是,业务发展团队更关心每个地区的总体成本以吸引新用户,而市场营销经理则更关注与区域有关的附属营销成本。

也就是说,将现有数据结构转换为任何替代的透视或非透视结构,是数据操作和分析过程中不可或缺的一步。

在这篇文章中,我将介绍一些在特定情况下对数据进行塑形和切片的技术。通常,我会使用 PostgreSQL 来展示我的例子。

现在,让我们开始看看我们得到的结果吧!

数据

在这篇文章中,我将使用《2015–2021 年世界幸福报告》的数据。该数据集提供了基于不同指标的全球各国幸福水平:经济增长、社会支持、出生时的预期寿命等。数据可在Kaggle上获得,并具有CC0: 公共领域许可。如下面的图像所示,我将仅利用一些字段:

  • 国家名称

  • 年份:报告年份(2005–2021)

  • 生活梯度:每个受访者认为的最佳生活由梯度上的 10 表示,而最差的生活由 0 表示。然后要求每个参与者在梯度上对其当前生活进行排名。

  • 人均 GDP 对数:以购买力平价(PPP)调整的美元表示的人均 GDP 的对数。

  • 社会支持:国家中对社会支持(能够依靠他人)的感知。

  • 健康预期寿命:指一个国家公民在特定时期的平均寿命。

作者提供的图片

使用窗口函数

使用 PRECEDING AND CURRENT ROW 进行的滚动计算

示例: 显示每个国家的滚动三年平均生活梯度指数

什么是三年滚动平均?简单来说,它计算的是过去两年的平均生活梯度分数加上当前年。如果当前年份是 2010 年,例如,每个国家的三年滚动平均生活梯度分数将等于该国家在 2008 年、2009 年和 2010 年的分数的平均值。如下面的图片所示,2010 年阿富汗的‘rolling_average’为 4.29,这个值是 3.72、4.40 和 4.76 的平均值。

作者提供的图片

为了指定计算中要考虑的行范围,SQL 中的窗口函数与PRECEDINGCURRENT ROW一起使用。具体来说,PRECEDING确定在CURRENT ROW之前的行数。因此,与PARTITION BY国家和ORDER BY年份结合使用时,下面的 SQL 命令将返回每个国家在‘rolling_average’列中的滚动三年平均生活梯度分数。

SELECT year, country_name, life_ladder,
  AVG(life_ladder) OVER (
    PARTITION BY country_name
    ORDER BY year
    ROWS BETWEEN 2 PRECEDING AND CURRENT ROW
  ) AS rolling_average
FROM public.happiness_index; 

使用 UNBOUNDED PRECEDING AND CURRENT ROW 进行的周期计算

示例: 显示每个国家的期间平均生活梯度分数

为了进行此计算,我们使用UNBOUNDED PRECEDING AND CURRENT ROW。计算窗口将包括当前值和所有行直到当前行。

例如,如果当前年份是 2009 年,那么一个国家在这段时间内的平均分数将等于其 2008 年和 2009 年的分数的平均值。类似地,在 2010 年,平均生活梯度分数将通过将 2008 年、2009 年和 2010 年的分数除以 3 来确定。你可以参考下面的命令获取更多信息。

SELECT year, country_name, life_ladder,
AVG(life_ladder) OVER (
    PARTITION BY country_name
    ORDER BY year
    ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
  ) AS rolling_average
FROM public.happiness_index

作者提供的图片

使用 UNBOUNDED PRECEDING AND CURRENT ROW 进行的百分比计算

示例: 计算生活梯度点与累积平均值的百分比变化。

类似于之前的示例,使用了UNBOUNDED PRECEDING AND CURRENT ROW,但这个示例并不是仅计算跨时间的平均值,而是关注当前值与期平均值之间的百分比变化。在这种情况下,结果存储在第四列,如下图所示。你可以很容易地观察到哪一年相比滚动平均值实际发生了正/负变化。此外,这个指标还告诉我们目标值变化的幅度。

图片由作者提供

SELECT 
  year, 
  country_name, 
  life_ladder, 
  100 * (life_ladder - AVG(life_ladder) OVER (PARTITION BY country_name ORDER BY year ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)) / AVG(life_ladder) OVER (PARTITION BY country_name ORDER BY year ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS percentage_change
FROM 
  public.happiness_index

实际上,在时间序列分析中,这种方法是识别数据趋势或模式的重要方法之一。例如,它在销售分析中使用最为频繁,主要关注销售增长或市场份额。

使用 PERCENTILE_RANK() 进行百分位计算

示例: 识别区域内每个国家的人均 GDP 的百分位

在这种情况下,我想知道哪些国家在 2008 年各自区域中拥有更高的人均 GDP。这个任务只需借助PERCENTILE_RANK()函数即可完成。

SELECT country_name, regional_indicator, log_gdppercapita,
  round((PERCENT_RANK() OVER (
    PARTITION BY regional_indicator
    ORDER BY log_gdppercapita
  ))::numeric,2) AS percentile_rank
FROM public.happiness_index
where year = 2008;

图片由作者提供

因此,正如你所观察到的,通过PARTITION BY区域和ORDER BY人均 GDP,命令将国家根据其人均 GDP 排名分为不同的百分位类别。例如,根据数据,拉脱维亚的人均 GDP 高于中东欧地区 33% 的国家。

使用 CASE WHEN 结合聚合函数进行数据转换

示例: 显示每年每大洲的平均生活阶梯得分

在这种情况下,我将根据‘regional_indicator’字段将国家划分为亚洲、欧洲、非洲和美洲四个大区域。我们可以通过首先使用CASE WHEN识别每个区域,然后得到每个位置对应的生活阶梯平均值。

SELECT year, 
ROUND(AVG(CASE WHEN regional_indicator LIKE '%Asia%' THEN life_ladder 
  ELSE null END)::numeric,2) AS Asia, 
ROUND(AVG(CASE WHEN regional_indicator LIKE '%Europe%' THEN life_ladder 
  ELSE null END)::numeric,2) AS Europe, 
round(AVG(CASE WHEN regional_indicator LIKE '%Africa%' THEN life_ladder 
  ELSE null END)::numeric,2) AS Africa,
round(AVG(CASE WHEN regional_indicator LIKE '%America%' THEN life_ladder 
  ELSE null END)::numeric,2) AS America
FROM public.happiness_index
GROUP BY 1
ORDER BY 1

图片由作者提供

你可以看到保存于行中的 4 个不同区域的生活阶梯得分已经被转换成 4 列。这种数据转换使得分析师可以轻松监控不同地点在特定年份的值变化。它还对跟踪某地区在一定时期内的数据变化非常有用。

使用 UNION ALL 进行数据反透视

示例: 反透视之前数据集中的表

我们如何将上一个示例中的表转换成下表?

图片由作者提供

分析需要灵活的数据转换技术,因为它使你能够从任何维度查看数据并获得更有洞察力的信息。正如你所看到的,前面的示例展示了结果数据表在垂直查看和水平查看数据时如何提供洞察。

在这种情况下,我将展示如何使用 UNION ALL 将数据反透视到其原始状态的简单转换。当使用 UNION ALL 时,需要注意的是,所有用于联合的组件的列数和数据类型必须兼容。

WITH tbl AS 
(SELECT year, 
ROUND(AVG(CASE WHEN regional_indicator LIKE '%Asia%' THEN life_ladder 
  ELSE null END)::numeric,2) AS Asia, 
ROUND(AVG(CASE WHEN regional_indicator LIKE '%Europe%' THEN life_ladder 
  ELSE null END)::numeric,2) AS Europe, 
round(AVG(CASE WHEN regional_indicator LIKE '%Africa%' THEN life_ladder 
  ELSE null END)::numeric,2) AS Africa,
round(AVG(CASE WHEN regional_indicator LIKE '%America%' THEN life_ladder 
  ELSE null END)::numeric,2) AS America
FROM public.happiness_index
GROUP BY 1
ORDER BY 1)

SELECT year, 
'Asia' AS region
, asia AS avg_life_ladder
FROM tbl
    UNION ALL
SELECT year, 
'Europe' AS region
, europe AS avg_life_ladder
FROM tbl
    UNION ALL
SELECT year, 
'Africa' AS region
,africa AS avg_life_ladder
FROM tbl
    UNION ALL
SELECT year, 
'America' AS region
, america AS avg_life_ladder
FROM tbl
;

使用 UNPIVOTING 和 PIVOTING 函数进行数据透视

UNNEST 函数用于数据反透视

示例: 与前一个案例的要求相同

了解多种处理数据以产生相同结果的方法是非常重要的,因为这可以对数据进行更大的主动控制。也就是说,除了 UNION ALL 外,UNNEST() 也是另一种用于数据反透视的函数。使用 UNNEST() 可以将数组列转换为不同的行。

WITH tbl AS 
(SELECT year, 
ROUND(AVG(CASE WHEN regional_indicator LIKE '%Asia%' THEN life_ladder 
  ELSE null END)::numeric,2) AS Asia, 
ROUND(AVG(CASE WHEN regional_indicator LIKE '%Europe%' THEN life_ladder 
  ELSE null END)::numeric,2) AS Europe, 
round(AVG(CASE WHEN regional_indicator LIKE '%Africa%' THEN life_ladder 
  ELSE null END)::numeric,2) AS Africa,
round(AVG(CASE WHEN regional_indicator LIKE '%America%' THEN life_ladder 
  ELSE null END)::numeric,2) AS America
FROM public.happiness_index
GROUP BY 1
ORDER BY 1)

SELECT 
year, 
UNNEST(ARRAY['Asia', 'Europe', 'Africa', 'Ameria']) AS region,
UNNEST(ARRAY[asia, europe, africa, america]) AS life_ladder
FROM tbl
;

图片来源:作者

在我们的示例中,数组列 [‘Asia’, ‘Europe’, ‘Africa’, ‘Ameria’] 在反透视后被转换回行值。

CROSSTAB 函数用于数据透视

示例:与 CASE WHEN 示例相同的要求:显示每年每个大洲的平均生活等级得分

COSSTAB() 是一种智能的数据透视、转换和汇总方法,以矩阵格式显示数据。在这种情况下,我将使用这个函数将亚洲、欧洲、非洲、美洲和独立国家联合体的行值转换为不同的列。

SELECT * FROM crosstab(
  'SELECT year, region, round(AVG(life_ladder)::NUMERIC,2)::FLOAT as life_ladder
   FROM (SELECT *, 
CASE WHEN regional_indicator LIKE ''%Asia%'' THEN ''Asia''  
     WHEN regional_indicator LIKE ''%Europe%'' THEN ''Europe''  
     WHEN regional_indicator LIKE ''%Africa%'' THEN ''Africa'' 
     WHEN regional_indicator LIKE ''%America%'' THEN ''America'' 
  ELSE ''Commonwealth_of_Independent_States'' END AS region
FROM public.happiness_index) a
  GROUP BY 1,2
   ORDER BY 1, 2') AS region_life_ladder (year int, Asia FLOAT, Europe FLOAT, Africa FLOAT, America FLOAT, Commonwealth_of_Independent_States FLOAT)

图片来源:作者

结论

数据下隐藏着洞察力,我们的使命是以任何可能的方式处理数据,以从数字和事实中获取最大价值。

以上是我在数据整理和处理中的一些技巧,希望它们对你有所帮助。

感谢你读到最后。要获取有关我即将发布的文章的更新,请通过提供的 Medium 链接 订阅成为会员。

你可以在以下网址阅读我的其他 SQL 文章:

照亮您数据科学之旅的可转移技能

原文:towardsdatascience.com/shining-light-on-transferrable-skills-for-your-data-science-journey-a4c67c3d0de8

我对那些从学术界转向商业数据科学的关键可转移技能的看法

Kirill LepchenkovTowards Data Science Kirill Lepchenkov

·发布于Towards Data Science ·阅读时间 9 分钟·2023 年 4 月 7 日

--

光束形状图像(作者拍摄)

前言

我在激光物理、非线性光学和固态激光工程领域担任研究员已有 5 年。虽然我完全沉浸于这个领域,并对自己所做的工作感到兴奋,但在某个时刻,我过渡到了商业数据科学行业。

在数据科学领域工作了额外 6 年后,我感到我在应用物理领域发展起来的技能在与激光物理完全无关的商业项目中得到了完美应用。

关于学术经验可能多么有用已经有很多讨论,但我决定表达我个人对这一主题的看法。

为了阐明我的观点,我决定根据每个技能组的实用性及其原因进行评级。

这篇文章适合谁?

我认为我写这篇文章主要是为了那些考虑从学术环境转向商业领域的人,但也是为了我自己,反思两者之间工具、技能和思维方式的交集。

文献综述经验 → 7/10

为什么文献综述在商业数据科学中是如此重要且可转移的技能(习惯)?

在我物理学时期的文献综述(作者的桌面)

在我看来,文献综述在商业数据科学中有点被忽视和误解。我并不是说我们对全新模型架构和框架设计的阅读不够(这部分做得非常好)。

但当涉及到快速有效地获取关于项目主题的更结构化和有价值的信息时——在我看来,这是数据科学领域存在的最大空白。

文献综述 可能不是这里最好的术语。我也可以称其为背景研究,或最先进的分析

在处理商业问题时,我认为对问题主题有一定的理论基础是至关重要的。文献综述的作用:

  • 为数据战略的可靠决策奠定基础。 了解领域内现有的技术和方法。

  • 加快入职过程。 如果你对自己正在从事的领域不熟悉,尽快获取相关知识是实现价值生成的第一步。

  • 提高与领域专家的沟通质量。 领域专家,也称为主题专家,对于解决数据问题至关重要。但他们通常不编程,而且非常忙。因此,数据科学家必须掌握一些领域特定的术语和概念,以便与这些专家有效沟通和顺畅合作。

  • 大幅提升洞察力的质量。 根据我的经验,文献综述为数据收集、预处理、建模和评估提供了决策基础,最终提高了你提供的洞察力的质量。在我的经验中,它有效,但并非总是如此。

关注文献综述,并投入时间和精力,体现了一种特定的心态——开放、谦虚和好奇。 文献综述有助于避免重新发明轮子或陷入确认偏差的陷阱。

我相信,随着大语言模型和基于这些模型的服务的扩展,文献综述的过程会发生变化,但我们还未到达那一步。

记录→ 9/10

将学术界的记录实践转移到商业数据科学中,对我来说非常有价值。除了多个实际好处,它在经历研究人员工作生活中的起伏时,给你一种无价的连续感。在我看来,通过采用保持实验室笔记本这一关键习惯,数据科学家可以轻松跟踪实验、记录想法和观察,监控个人和职业成长。我写了一整篇文章来阐述这样做的好处,欢迎查阅!

实验室笔记本作为数据科学从业者的选择武器

我的一套有效笔记记录原则,以实验室笔记本的形式呈现

towardsdatascience.com

编程知识 → 6/10

在我的科学历程中,我每天都在处理实验数据、进行数值模拟和统计学习。编程对于开发和测试新的激光设计(数值模拟)也是必不可少的。

我一直在不断使用它来处理典型的数据科学任务:

  • 实验数据处理(Python,Wolfram

  • 数值模拟(Wolfram, Matlab, Python)

  • 统计学习(Wolfram, Matlab, Python)

  • 数据可视化(Origin Pro, Python, R)

我的“数据工作”科学工具栈

Wolfram(更具体地说是 Wolfram Mathematica)是使用最频繁的工具,因为我们在实验室里有它的许可证。它有很棒的工具集用于求解非线性微分方程,我们广泛使用它进行数值模拟。

Python 是我处理实验生成的数据(光束形状、振荡图)的首选工具。

说到数据可视化,Origin 是主要工具,因为它允许将视觉元素嵌入到文本文档中,同时保持可编辑性。折线图、直方图(包括核密度估计器)、回归分析——Origin 是一个很棒的工具。Origin 有一个图形用户界面,所以这不仅仅是编程的问题,我提到它是为了确保 Python 和 R 不会独占所有数据可视化的功劳。

总的来说,我对上述提到的每一个工具都有扎实的使用经验:我了解语法,并且能够以相当高的效率解决问题。那么为什么只有 6/10 呢?为什么在学术界获得的编程技能在商业数据科学中相对难以转移?这确实是一个相当强的声明,但我认为学术经验的缺点可能会超过其优点。主要是因为许多科学环境中完全忽视了良好的软件实践

警告:这一说法基于我在应用物理领域的个人经验,并且绝对不适用于所有在学术界工作的人。对这一部分的内容要持保留态度!

一方面,忽视良好的软件原则是研究人员优化研究速度和发表数量,而非代码质量和可维护性的自然结果。另一方面,几乎没有人从正统的软件开发转到学术界(出于经济原因),因此根本没有真正的生产专业知识。我还应该提到,设计实验、进行文献综述、收集测量数据、编写处理代码以及获得有价值的见解——所有这些同时进行是非常耗费精力的。因此,你根本没有足够的资源去学习软件开发。

测量能力→ 9/10

这一点难以解释,所以请耐心听我说。应用激光物理中的测量工作本身就是一个独立的学科。提供有价值的测量是一项需要多年训练的技能!原因有很多:你必须理解过程的物理学,遵循测量协议,并且具备操作复杂且昂贵仪器的专业知识和训练。

例如,我一直在使用二极管泵浦脉冲固态激光器,测量激光束的多个参数:脉冲持续时间、脉冲能量、重复频率、束形、发散度、偏振、光谱内容、时间特性和光束腰部。进行这些测量中的任何一项都非常困难。比如说,你想测量束形(见下图)。

beam profiles 3d(作者拍摄)

光束形状指的是激光束在其截面或横截面上的强度空间分布。

理论上,你只需将激光光束对准 CCD 相机,几秒钟内就能获得光束形状。但实际操作起来则大相径庭。如果你正在使用脉冲固态激光且有相当的脉冲能量,并且你知道自己在做什么,你会将激光光束导向高质量的光学楔子,将大部分脉冲能量集中到一个陷阱中,并使用仅有原始光束一部分能量的反射光束进行工作。这样做是为了保护 CCD 相机免受灾难。但使用楔子还不够。你还需安装一个可调的光束衰减器,将其锁定到最暗模式,然后逐渐降低吸收率,直到在 CCD 相机上获得正确的曝光。

如果你正在使用对人眼不可见的红外激光,你会面临一个问题:你必须在看不到实际光束的情况下通过小孔引导光束。这项技能只能通过训练和实践获得。顺便说一下,每一步光束操作都必须极其小心,以遵守安全规定:你必须佩戴适当的防护眼镜,使用保护屏幕等。

好的,继续,现在你的光束被衰减并完美地对准了 CCD 相机。但你还有很多工作要做:将 CCD 相机连接到激光电源单元以实现同步并产生稳定的图像。如果你做对了所有步骤——你就能获得图像。等等,图像?

光束轮廓 2D(由作者捕获)

然后你意识到,如果你的激光以 50 Hz 的脉冲重复频率运行,这意味着它每秒产生 50 个脉冲。每个脉冲可能有略微不同的光束轮廓。你该如何获得结果?你应该随便选一个脉冲并捕捉图像?还是应该使用一定数量的脉冲生成平均图像?哦,管理 CCD 相机的软件默认启用了平均功能?

让我们结束这个“测量光束形状”的废话吧。根据我一生中的所有测量经验,我有两个关键的可转移品质:警惕(永远不要仅仅相信表面)和对元数据的细致关注(数据是如何测量或记录的,使用了哪些工具,甚至最初发生的原因)。这两者在处理实际数据时都是金标准。因为它让你在产生实际影响时更加高效,而不会陷入麻烦之中。这在学术界和商业数据科学中都很受重视。

数据通信熟练度 → 10/10

当我在学术界时,我并没有认为数据沟通是一个特别值得关注或有价值的写作主题。处理数据可视化、讨论数据和理论,以及撰写科学论文只是工作的组成部分。但经过多年的研究,你在不同层次(正式和非正式)上获得了扎实的数据沟通技能。

写作科学论文是正式数据沟通类型中最具挑战性的技能之一。要能写出一个结构合理(摘要 → 引言 → 文献综述 → 方法论 → 结果 → 讨论 → 结论 → 致谢)的引人入胜的文章,需要大量的练习。文章的结构本身假设你有一个故事要写。而且这不仅仅是写作:你还必须知道如何制作引人注目且有目的的数据可视化。这一切都是为了将你的信息传达给观众。

我将这一技能的可转移性评分为 10 分(满分 10 分),因为商业数据科学毫不意外地依赖于人与人之间的互动、传达你的思想和结果。

结论

总体而言,我相信拥有科学背景的人可以为数据科学领域带来独特的视角和宝贵的技能。对于那些认为转向商业数据科学意味着放弃所有辛勤工作和专业知识的学术界人士,我提供一个不同的观点:你有大量的价值可以带到桌面上。在我看来,最佳的行动方案是利用你现有的技能,同时掌握你转型领域的新技术和最佳实践(我们都知道这是一个终身的旅程)。

最短路径(Dijkstra)算法:一步步的 Python 指南

原文:towardsdatascience.com/shortest-path-dijkstras-algorithm-step-by-step-python-guide-896769522752

使用 OSMNX 1.6 和长距离路径的更新

Bryan R. VallejoTowards Data Science Bryan R. Vallejo

·发布于 Towards Data Science ·6 分钟阅读·2023 年 10 月 4 日

--

图片由作者提供。摩洛哥的最短路径(约 350 公里)

这个著名的算法在 Python 库 OSMNX 中实现,可以用来寻找两个位置之间按距离或时间加权的最短路径。该算法使用 OpenStreetMap(OSM)网络,通过 Python 库 NETWORKX 在后台寻找驾驶、步行或骑车的路线。

我写这个更新是因为函数的参数稍有变化,且有人询问为什么我的代码在旧博客文章中无法工作,这只是因为代码是用旧版本的 osmnx 编写的。

旧教程包含了相当有价值的过程,但我决定做一个一步步的指南,这样获取最短路径的过程会更准确,使用这个指南的分析师可以真正理解整个过程。

这里是旧教程,如果你想查看一下。

在芬兰赫尔辛基,使用不同的网络

## OSM 街道网络中使用的最短路径算法

车辆、自行车和行人最短路径分析的 GIS 自动化技巧

towardsdatascience.com

在爱沙尼亚塔尔图,使用步行网络

## 使用 OSM 步行网络的最短路径算法

使用 OSM 数据在爱沙尼亚塔尔图寻找最短步行路径

towardsdatascience.com

OSM 数据许可

介绍

在这个实践中,我将使用摩洛哥的两个位置。这个实践由我的一位读者 Hanae 提出,她提供了原点和目的地。

作者提供的图像。原点和目的地位置。

编码实践

正如我提到的,我将做一个逐步指南,所以让我们开始。在此之前,让我们导入所需的库。

import osmnx as ox
import geopandas as gpd
from shapely.geometry import Point, LineString
import pandas as pd
import matplotlib.pyplot as plt

1. 定义原点和目的地

简单地,我们将创建几何对象作为点:

# origin and destination geom

origin_geom = Point(-5.6613932957355715, 32.93210288339607)

destination_geom = Point(-3.3500597061072726, 34.23038027794419)

2. 提取 OSM 图对象

然后,我们将提取图形,用于生成最短路径。我们逐步来看。

  • 从原点和目的地创建 GeoDataFrames
# create origin dataframe
origin =  gpd.GeoDataFrame(columns = ['name', 'geometry'], crs = 4326, geometry = 'geometry')
origin.at[0, 'name'] = 'origin'
origin.at[0, 'geometry'] =origin_geom

# create destination dataframe
destination =  gpd.GeoDataFrame(columns = ['name', 'geometry'], crs = 4326, geometry = 'geometry')
destination.at[0, 'name'] = 'destination'
destination.at[0, 'geometry'] = destination_geom
  • 获取包含原点和目的地的图

我们将使用 Geopandas 的 envelope 函数来将多边形用作掩膜以提取图形。

首先一个简单的函数。

def get_graph_from_locations(origin, destination, network='drive'):
    '''
    network_type as drive, walk, bike
    origin gdf 4326
    destination gdf 4326
    '''
    # combine and area buffer
    combined = pd.concat([origin, destination])

    convex = combined.unary_union.envelope # using envelope instead of convex, otherwise it breaks the unary_union

    graph_extent = convex.buffer(0.02)

    graph = ox.graph_from_polygon(graph_extent, network_type= network)

    return graph

然后,使用它并绘制结果。

graph = get_graph_from_locations(origin, destination)
fig, ax = ox.plot_graph(graph, node_size=0, edge_linewidth=0.2)

作者提供的图像。图包含原点和目的地

3. 找到原点和目的地的最近节点

获取使用原点和目的地位置的网络中最接近的节点。节点代码可以使用 osmnx 函数获得。

# ------------- get closest nodes

# origin
closest_origin_node = ox.nearest_nodes(G=graph, 
                                       X=origin_geom.x, 
                                       Y=origin_geom.y)

# destination
closest_destination_node = ox.nearest_nodes(G=graph, 
                                           X=destination_geom.x, 
                                           Y=destination_geom.y)

你可以检查并注意到我们目前只有代码。

4. 找到最短路径

然后,使用最短路径函数来获取路线。

# run
route = ox.shortest_path(graph, 
                         orig = closest_origin_node, 
                         dest = closest_destination_node, 
                         weight = 'length')

这将返回一堆路径中节点的代码。

作者提供的图像节点代码

5. 从节点创建 Line Geometry

我们将从图中提取节点的几何形状,并创建一个表示最短路径的 LineString 几何体。

首先为此创建一个函数。

def nodes_to_route(graph_nodes, path_nodes):

    # Extract the route nodes of the graph
    route_nodes = graph_nodes.loc[path_nodes]

    # ---> note! If you have more routes, check for each one, to be removed in length is 1\.  A path can not be built with only 1 node.

    # Create a LineString out of the route
    list_geom = route_nodes.geometry.to_list()
    path = LineString(list_geom)

    # Append the result into the GeoDataFrame
    route_df = gpd.GeoDataFrame( [[path]] )

    # Add a column name
    route_df.columns = ['geometry'] 

    # Set geometry
    route_df = route_df.set_geometry('geometry')

    # Set coordinate reference system
    route_df.crs = graph_nodes.crs

    # remove nans
    route_df = route_df.dropna(subset=['geometry'])

    return route_df

获取节点,并在函数中使用它们。

# get all network nodes
graph_nodes = ox.graph_to_gdfs(graph, edges=False)

# get the line geometries from osm nodes
route_gdf = nodes_to_route(graph_nodes, route)

6. 计算距离

我们将使用墨卡托投影来测量路线的米数。如果你想要更准确的结果,可以使用位置投影。

首先,为此创建一个函数。

def compute_distance(shortest_path_gdf):
    '''
    Compute distance in EPSG:3387

    '''

    # project WGS84 to EPSG3387
    distances = shortest_path_gdf.to_crs("EPSG:3387").geometry.length

    # add
    shortest_path_gdf['distance'] = distances

    return shortest_path_gdf

然后,使用它:

# calculate distance m
route_distance_gdf = compute_distance(route_gdf)

它将测量约 351.243 米的路线。

7. 保存网络和路径

将网络和路径保存到本地磁盘上用于地图。

提取网络并定义 GeoDataFrame:

# fetch network
network = ox.graph_to_gdfs(graph, nodes=False)

# get only needed columns
network_gdf = network.reset_index(drop=True)[['geometry']]

然后存储:

network_gdf.to_file(r'osm_network.gpkg')
route_distance_gdf.to_file(r'osm_shortest_path.gpkg')

你可以使用这些数据来创建自己的地图。例如,这个在 QGIS 中:

作者提供的图像。QGIS 中的最短路径和网络

8. 绘制结果

我们将通过绘制所有元素来检查我们的工作是否正确。

# plot network
ax = network_gdf.plot(figsize=(12, 10), linewidth = 0.2, color='grey', zorder=0);

# origin and destination
origin.plot(ax=ax, markersize=46, alpha=0.8, color='blue', zorder=1)
destination.plot(ax=ax, markersize=46, alpha=0.8, color='green', zorder=2)

# route
route_distance_gdf.plot(ax=ax, linewidth = 3, color='red', alpha=0.4, zorder=3)

plt.axis(False);

结果将会是这样的。

图片由作者提供。最短路径、网络、起点和终点在 Matplotlib 中

已知限制

最短路径是通过节点网络的联合生成的,线条并不完全匹配道路。这完全没问题,因为我们要的只是一个近似值。如果你需要导航,应该使用 Google API 进行路由,或其他提供商。

图片由作者提供。线条是通过节点创建的。

结论

使用 OSMNX 的最短路径算法提供了路线的近似值,并且可以广泛用于城市或区域规模的可达性研究。这个 Python 库不断更新,函数或参数可能会有所变化,因此建议在我们的工作流程中持续更新库版本。

如果你有问题或需要定制分析,欢迎联系我:

Bryan R. LinkedIn

深度伪造技术是否应该开源?

原文:towardsdatascience.com/should-deepfakes-be-open-sourced-87d7644a0765?source=collection_archive---------9-----------------------#2023-05-25

意见

讨论了开放深度伪造技术的利弊

Jack SaundersTowards Data Science Jack Saunders

·

关注 发布于Towards Data Science ·5 分钟阅读·2023 年 5 月 25 日

--

图片生成使用了DreamStudio

我是一名博士研究人员,创建可以被认为是深度伪造的技术用于我的研究。我对创建逼真的数字双胞胎和提升娱乐水平的能力感到着迷。在开始我的研究之前,我曾认为这些模型可能会造成太多伤害,不适合向公众发布。在过去几个月里,我注意到越来越多的顶尖声音主张将开源软件作为人工智能领域的核心原则。虽然这次讨论几乎完全集中在大规模语言模型(LLMs)上,但我认为这一理念在整个领域中普遍存在。我完全支持几乎所有人工智能模型的开源,但对于我自己的研究领域,我不太确定。

几乎没有哪个领域的误用潜力像深度伪造技术那样高。

到目前为止,我的方法是走中间道路。尽量以一种不需要博士学位才能理解的水平来传达深度伪造技术的工作原理。然而,我经常怀疑这是否是正确的方法。本文的目的是尝试启动关于我们深度伪造研究人员应采取方向的讨论。 有鉴于此,本文讨论了开源的一些优缺点。

开源的优点

我们常常听到关于深度伪造技术的负面信息。有人可能会问,像我这样的研究人员为什么还要考虑创建开源模型。然而,实际上有很多合理的理由这样做:

  • 透明性:对我来说,这是最重要的一点。如果大多数主要的深度伪造模型完全开源,那么它们就会变得透明。这样,监管者就可以理解他们正在处理的内容,其他研究人员也可以开发更好的检测算法。对于深度伪造技术来说,那些希望利用它们进行恶意行为的“坏演员”和那些试图防止这种伤害的“好演员”之间将会有一场军备竞赛。你可以确定“坏演员”无论我们是否发布我们的模型都会开发他们的深度伪造技术。在开源中,我们可以给“好演员”提供更多的数据,以构建他们的防害模型。

  • 公平性:如果我们选择不进行开源,只有那些拥有人才和计算资源的机构才能创建深度伪造技术。根据经验,开发这些模型需要很长时间,没有开源软件的话,很少有人能做到这一点。这可能进一步将权力集中在已经强大的手中。深度伪造技术可以在多个市场中使用,并且可能具有数十亿美元的潜在价值。例如,仅配音市场就预计超过 35 亿美元。如果只有像谷歌这样的公司能够创建深度伪造技术,那么只有谷歌这样的公司才能从中获益。

  • 意识: 深度伪造技术正在迅速发展。我们很可能很快会达到这样一个阶段,即你无法相信在线上看到的任何视频,除非它以其他方式经过验证。虽然很多人对此有模糊的认识,但我认为很少有人真正理解其含义。作为深度伪造研究人员,我们有责任帮助教育公众。我们需要真正鼓励每个人保持良好的数字怀疑态度,检查他们在线上看到的任何媒体的来源,并质疑其真实性。开源软件有所帮助。当模型可以在网上自由获取时,教育变得更加容易。如果你能自己创建深度伪造,你自然会留意其他人的伪造。

缺点

当然,将深度伪造模型开源存在许多潜在的缺点,从明显的到更微妙的都有。

  • 人们将滥用它们: 无论我们如何监管或检测模型有多么有效,总会有一小部分人会出于最糟糕的理由使用深度伪造。从报复色情虚假信息,这项技术有一些非常恶劣的应用。如果我们开源模型,就会使所有人更容易访问它们,这无疑会造成伤害。确实,一些坏人无论如何都会做到这一点,尤其是大型犯罪或国家组织,但大多数寻求伤害的人可能本来无法做到,如果没有开源模型的话。

  • 保护措施可能被移除: 保护深度伪造技术不被滥用的较好方法之一是引入保护措施。特别是,大多数创建深度伪造的团队都使用水印技术。水印涉及将数据添加到创建的视频中,以一种对人类和大多数软件都不可见但可以被拥有“密钥”的人轻松识别的方式。这意味着,例如,YouTube 或 Twitter 可以快速检测到视频是否由深度伪造平台创建,并将其移除。由于水印只能被那些获得了这个秘密密钥的人看到,坏人无法移除它们。如果我们开源深度伪造生成,那么坏人将可以简单地跳过添加水印。 这使得深度伪造变得不可检测。

  • 单向共享: 如果我们再次考虑所谓的好人和坏人之间的军备竞赛,那么我们可以看到开源的另一个缺点。如果我们这些好人开源了我们所有的软件,那么坏人可以在此基础上进行构建。另一方面,坏人不会开源他们的模型,这意味着信息只是在一个方向上共享。这给坏人带来了显著的优势。

总结

正如所见,这不是一个容易回答的问题。利弊权衡很多,无论哪种情况,潜在的危害都很大。在写这篇文章的过程中,我进行了许多对话。让我感到惊讶的一点是开源绝对主义者的数量。我经常听到的一个论点是,深伪技术已经存在,不可能把魔鬼收回瓶子。许多人,包括我自己在内,都认为,我们将迎来一个深伪技术与现实难以区分的时代。如果到那时我们都没有足够的意识去质疑这些技术,我们可能会面临很大麻烦,因为坏分子可能在未被察觉的情况下行动。这是一个我认为最近对 AI 发展暂停的呼吁忽视的点。然而,虽然开源可能会减少长期的危害,但它也为那些可能想要滥用深伪技术但没有技术能力的人打开了短期危害的大门。

尽管我对开源问题仍然未作决定,但我比以往任何时候都更自信地认为,这个讨论是必须进行的,至少深伪技术的研究需要在公开的环境下进行,并且要向公众传达。我强烈鼓励每个人发表意见。如果你有我没有涉及到的看法或任何问题,请留下评论或直接联系我,我真的希望听到尽可能多的人反馈。

我真的应该吃这个蘑菇吗?

原文:towardsdatascience.com/should-i-really-eat-that-mushroom-9edeaa69d934

使用 CatBoost 梯度提升决策树对可食用和有毒蘑菇进行分类

Caroline ArnoldTowards Data Science Caroline Arnold

·发表于Towards Data Science ·阅读时间 6 分钟·2023 年 8 月 17 日

--

大多数教育和现实世界的数据集包含分类特征。今天我们将讨论来自CatBoost库的梯度提升决策树,该库原生支持分类数据。我们将使用一个蘑菇数据集,这些蘑菇要么是可食用的,要么是有毒的。蘑菇通过分类特征如颜色、气味和形状进行描述,我们想要回答的问题是:

基于其分类特征,这种蘑菇是否安全食用?

如你所见,风险很高。我们希望确保机器学习模型的准确性,以免我们的蘑菇煎蛋卷以灾难收场。作为额外奖励,最后我们将提供一个特征重要性排名,告诉你哪个分类特征是蘑菇安全性的最强预测因素。

图片由Andrew Ridley提供,来自Unsplash

介绍蘑菇数据集

蘑菇数据集可以在这里找到:archive.ics.uci.edu/dataset/73/mushroom [1]。为了清晰展示,我们将从原始的简短变量中创建一个pandas DataFrame,并用适当的列名和长格式变量进行注释。我们使用 pandas 的replace函数,长格式变量来自数据集描述。目标变量只能取TrueFalse值——数据集创建者采取了保守的方式,将可疑的蘑菇归类为不可食用。

在检查数据集缺失值后,我们发现只有一列——stalk_root——受到影响。我们删除了这一列。

数据集的探索揭示数据相当平衡:在 8124 个蘑菇中,4208 个是可食用的,3916 个是有毒的。我们将数据框分为目标变量is_edible和其余的蘑菇特征。然后,我们通过对目标变量进行分层,将数据集分为训练数据和测试数据。这确保了两个拆分中的类别分布是可比较的。

CatBoost 库

CatBoost 是一个开源的机器学习包,用于梯度提升决策树。可以通过按照安装说明来获取 CatBoost Python 包。对我们来说最重要的组件是catboost.Pool,它组织数据集并指定分类特征和数值特征,以及我们的模型catboost.CatBoostClassifier。分类特征在机器学习算法中可能难以处理,它们必须被编码成数值才能用于训练。每个分类值都与一个数字相关联,例如蘑菇颜色的brown->0, black->1, yellow->2, ...CatBoost 可以自动处理分类输入变量,这样我们就不用再添加 独热编码 到流程中。这不仅方便,而且 CatBoost 算法也经过优化,以便更快地训练分类变量。

Haithem FerdiUnsplash 拍摄

梯度提升决策树

决策树是成熟的机器学习算法,根据特征值将样本分类为不同的类别。单棵决策树容易过拟合。 因此,通常使用决策树的集合来实现更好的性能。在梯度提升决策树中,树的集合通过迭代更新树来构建。每次迭代的树通过在应用前一个树后留下的残差上进行训练,提供了比前一次迭代更小的改进。该过程在损失收敛时停止,即当添加更多树没有增值时,或达到固定的总树数时。有关梯度提升决策树的更详细介绍,请参见页面底部推荐的博客文章。

蘑菇分类

在蘑菇数据集中,所有特征都是分类的,并在Pool中相应指定。我们为训练和测试分别构建一个Pool。目标变量被转换为数值,因为这与CatBoostClassifier的损失函数更好地集成。分类器本身的格式类似于 scikit-learn。可以调整许多属性,包括学习率、树的总数和树的正则化。损失函数是log-loss,因为我们处理的是二分类问题。

对于二分类目标类的预测,使用的是对数损失或交叉熵函数。实际值 y 与模型提供的概率 p 进行比较。

我们在下面的代码框中定义数据集和模型。为了比较,我们训练了一个单一的决策树和一个完整的梯度提升决策树。

评估

现在我们准备评估分类器在测试数据上的表现。食用毒蘑菇可能导致严重的健康问题,因此我们关注减少假阳性。我们计算精确度指标,即实际可食用的蘑菇数量与预测为可食用的蘑菇数量的比例。

单一决策树的精确度为 97%,对于一个分类算法来说相当不错。但是通过梯度提升树,我们可以将精确度提高到 100%,测试数据集中没有毒蘑菇被误标为可食用。混淆矩阵显示,梯度提升决策树在测试集上提供了最佳表现。

单一决策树(左)和梯度提升决策树(右)的混淆矩阵。

特征重要性

这很好,但我们可能没有整天的时间来确定每种我们想吃的蘑菇的 22 个特征。那么,确定蘑菇是否可食用的最重要特征是什么呢?

为了回答这个问题,我们使用内置的模型属性feature_importances_来推导梯度提升树分类器的特征重要性排名。结果显示,气味在特征重要性排名中占据主导地位,其次是孢子 印刷颜色数量

从训练好的 CatBoostClassifier 中获得的蘑菇数据集的特征重要性排名。

进一步观察可能的气味值可以发现,这个特征本身已经是一个很好的预测因素,能够判断一只蘑菇是会成为你餐点的美味补充,还是会让你一天结束在医院。数据集中所有散发茴香杏仁气味的蘑菇都是可食用的。没有气味的蘑菇大多也是可食用的。你应该远离腥臭、辛辣、刺鼻、腐臭、木焦油霉味的蘑菇——老实说,这些蘑菇听起来一开始就不怎么美味。

对蘑菇数据集中的气味特征进行详细分析。

摘要

我们介绍了蘑菇数据集,其中包含仅由分类变量描述的可食用和有毒蘑菇样本。我们引入了 catboost 包,它对分类数据效果很好,并提供了梯度提升决策树。训练了一个模型来对蘑菇进行分类,并取得了令人满意的表现。气味是预测蘑菇安全性的最强指标。希望你喜欢这篇博客文章,并对模型在真实蘑菇上的应用不负责任 😃。

照片由 Zhen H提供,来自Unsplash

深入阅读

数据集参考

[1] 蘑菇。UCI 机器学习库(1987 年)。doi.org/10.24432/C5959T. 本数据集采用创意共享署名 4.0 国际(CC BY 4.0)许可。

我们是否应该更依赖数据?有时候。

原文:towardsdatascience.com/should-we-be-more-data-driven-sometimes-3dcf5e2753ae?source=collection_archive---------3-----------------------#2023-08-17

何时应该依赖数据,何时数据依赖反而会成为障碍。

Robert YiTowards Data Science Robert Yi

·

关注 发表在 Towards Data Science ·6 分钟阅读·2023 年 8 月 17 日

--

我在 Airbnb 担任数据科学家时,Covid-19 爆发了。正如你所料,Covid-19 对于一个依赖良好人际互动的业务来说特别残酷。当世界正在形成孤立的社交圈时,想要让人们住在陌生人的家中是非常困难的。因此,正如你所预期,我们的指标急剧下滑——我们的核心指标下降到了个位数的同比值。没有人再预订 Airbnb 了,可以肯定的是,也没有人愿意开设新的 Airbnb。

当我们面临那突如其来的指标悬崖时,我们的首席执行官布赖恩迅速作出了令人钦佩的回应。尽管我们都在设置家庭办公室,并从好市多囤积卫生纸和罐头食品,布赖恩却召开了紧急全员大会。他明确告诉我们:“我们所知的旅行已经结束。”他没有明确的下一步计划,但在风暴中却有一个灯塔般的指示:停止一切非关键工作,弄清楚如何在疫情中生存下来。

随后发生的事情令人印象深刻。公司有效地转变了方向,在如此大规模的公司中参与其中是非常激动人心的。我们在创纪录的时间内推出了 Airbnb 在线体验。我们以“近在咫尺即为远方”为新的口号,策划并推动人们前往那些在疫情期间适合作为避难所的地点。明显不符合未来方向的新举措被关闭(我曾参与一个名为“社交住宿”的团队,尽管投入巨大,我们还是迅速终结了这一项目)。我们进行了新的融资,重组了公司。公司每天做出数百甚至上千个决定,因此,成功地在疫情最严重时期游刃有余,表现出尽可能好的灵活性。

话虽如此,尽管业务变动颇具趣味,我实际上更想在这篇文章中讨论这一时期数据的作用以及我们可以从中获得的经验教训。我最令人震惊的认识是:数据,曾经在几乎所有战略对话中扮演关键角色,却在一夜之间变成了次要因素。那时,为了争取“数据驱动决策”而奋斗将会是可笑的——不是因为数据在这一过渡期没有用,而是因为在危机中数据不应成为主导。接下来,我将讨论这种思维方式转变的根本原因:紧迫性。让我们考虑不同的决策情况,然后讨论我们应该如何利用数据。是时候真正谈谈“数据驱动”应该意味着什么了。

决策的划分

你可以通过两个轴来清晰地划分决策过程:决策的紧迫性决策的重要性。根据你的决策在帕内特方格中的位置,分析的参与程度可以并且应该有所不同。

图片由作者提供。

低紧迫性,高重要性

一方面,当一个决定极其重要但并不特别紧急时,我们可以按照理想的方式进行分析——与利益相关者紧密迭代,以更好地导航可能的行动空间。例如,假设你公司的高管想要彻底修改你的登陆页面,但他们希望你支持决定应该放置什么内容。你团队中的机器学习软件工程师跳转到卡片分类解决方案,但你和你的利益相关者知道,更关键的决定是是否首先要应用这种分类解决方案。

作者提供的图片。

当前的主页运行良好,因此所需的更改并不紧急,但决定的影响很大——你的更改将影响每一个访客的体验。因此,应该利用分析来更好地导航决策空间:你可以筛选过去的实验,汇总可能有助于当前决策的学习;你可以进行小规模的机会大小检查,以查看任何更改的范围;你可以提供人口统计/渠道/其他分布数据,以更好地了解你可能需要重点关注的内容。

利益相关者必须处理大量的选项,而你可以帮助他们以一种有度量、以假设驱动的方式进行。这就像你在买车一样。花时间进行市场调研是一个好的投资。

高紧迫性,高重要性

另一方面,让我们重新考虑上面的 Covid-19 Airbnb 情况。公司正处于危机状态,领导层已经确定了前进的最佳行动方案:我们需要确定一些市场,推向那些对 Covid 隔离所具有吸引力的市场。你可以像之前的例子那样采取相同的方法——仔细分析细分市场,筛选过去的实验结果等。但每推迟一天做出选择,你将失去两样东西:

  1. 有机会利用新市场。

  2. 有机会进行测试并学到东西。

因此,你提出了一个简单的假设:如果你选择一些与主要城市相对接近的地点,那么你将最大化预订量,因为客人将(a)感到足够隔离于 Covid,但也(b)足够接近以便在紧急情况下能回到家中与朋友和家人团聚。你在几小时内回到高管那里,他们发起了一个推进这些地点的倡议,你发现一些地点效果更佳,从而为你的第二批选择提供了参考。

作者提供的图片。

在这里分析的最佳参与程度与低紧急情况有所不同——你仍在帮助你的利益相关者在思想迷宫中导航,但所做的决策大多是由直觉驱动的,因此你的参与自然较为浅薄。这并不是说你应该盲目顺从,强化反应性先例——仍然要理解原因,但接受你的参与将会是较少结构化和严谨的。尽管你可以在足够的时间里帮助利益相关者做出更好的决策,但你没有足够的时间,现在做出 80%正确的决策远比明天做出 90%正确的决策要有价值得多。

你遇到车祸了。获取一些数据来评估你和对方司机的健康状况,以及到最近医院的最佳路线是有用的,但你可能不应该花几个小时阅读医院评价。

低重要性

最后,有时决策实际上并没有那么重要。你在用户支持页面上移动一个按钮,实验没有收敛,但你的利益相关者想知道结果的真相。这时候你应该反驳——分析确实可以提供答案,但结果会改变什么行动?你会学到什么?利益相关者已经知道这是一个更好的体验,他们询问是为了确认,但你知道在这种实验曝光水平下,确定性是不可能的。

如果我们的决策没有因为我们的数据工作而发生变化,或至少我们没有从探索数据中学到些什么,我们可能根本不应该做这项工作。学会预测你的工作的影响——如果你帮助做出这个决策,潜在的提升是什么?——然后据此行动。

最终评论

为了明确,我并不主张在这里做一个严格的截断,但在选择适合任务的分析时,速度重要性应当被考虑。当决策非常紧急时,数据几乎总是应该退居二线,依赖直觉。当决策极为重要时,数据应当被更仔细地使用来验证假设,并对直觉进行监督。当决策不重要时,你不应该花很多时间担心这个决策,因此任何分析工作都应该在完成前重新考虑。

👋 你好!我是罗伯特, Hyperquery 的首席产品官,曾是数据科学家和分析师。此帖最初发布在 Win With Data,我们每周讨论如何最大化数据的影响。如果你想了解更多关于 Hyperquery 如何帮助你最大化影响的信息,请随时联系我。你可以在 LinkedIn Twitter上找到我。

我们是否应该虚拟化我们的数据科学系统——还是不虚拟化?

原文:towardsdatascience.com/should-we-be-virtualizing-our-data-science-systems-and-or-not-6cb69b4850f3

作者当前的家庭实验室设置

导航虚拟化数据科学过程的优缺点可能很困难,但有些性能趋势无法忽视

Will KeefeTowards Data Science Will Keefe

·发表于 Towards Data Science ·阅读时间 12 分钟·2023 年 9 月 12 日

--

随着“巨量数据”的利用在各个行业解决问题变得越来越相关,家庭实验室和数据湖规模的数据存储库需要比以往更多的并行计算能力来提取、转换、加载和分析数据。在创建自己的家庭实验室时,是否在虚拟机上还是在硬件上原生创建并行化设置让我感到困惑,我难以找到性能比较。在本文中,我们将探讨每种设置的优缺点,并对每种方法的虚拟和原生性能及基准测试进行逐一对比。

介绍

许多并行计算集群包括多个节点,即指定处理集群中分配任务的计算机。管理这些节点可能是一个大麻烦,这也是为什么数据工程如此有利可图相比于它们的分析对手。通常,公司会管理整个集群的队列,这使得几乎不可能对单独的节点给予个别关注,因此,“高可用性”设置,如 Proxmox、Kubernetes 和 Docker Swarm,是现代企业的必备工具。你可能已经与这些集群互动过,甚至没有意识到这一点——我今天午餐吃的 Chick-fil-A 鸡肉三明治就是通过一个边缘计算 Kubernetes 集群与他们的销售点系统完成的。

在虚拟化机器上计算有许多好处,包括:

  • 整个操作系统可以从企业服务器快速部署到现场,几乎是瞬时的

  • 图像可以实时备份

  • 部署可以容器化以限制范围并增加安全性

  • 在硬件故障的情况下,系统可以在最小的停机时间内迁移

这些并不是新概念,但随着每个组织对数据分析需求的增加,访问并行化部署的方式可以并且应该有所不同,因为虚拟化的缺点通常是你离裸金属越远,你的系统性能受到的影响就越大。虽然一个开发者在处理一个 Excel 文件时可能不会受到影响,但在处理几 GB 甚至 TB 的数据时,需要仔细考虑如何以及何时使用虚拟工具,并建立考虑处理能力的设置。

设置我们的比较

为了验证这一点,我们可以比较使用 readily available 企业硬件的小型到中型组织的设置(我负担不起那些高级设备)。在我的家庭实验室中,我有一个由多个翻新的企业单元构建的计算集群。我在下面的一些文章中链接了如何构建此设置以及我的用途,但现在让我们比较虚拟系统和裸金属系统之间的性能,并特别测量虚拟化的影响。

关于启动自己的数据分析家庭实验室的完整指南

现在是启动你的数据科学家庭实验室以分析对你有用的数据的最佳时机,存储……

towardsdatascience.com [## 在家庭实验室集群上使用 Python 构建分布式机器学习模型

使用我们自己设置的经济实惠的家庭实验室设备,设置并行和分布式机器非常简单……

betterprogramming.pub

自从写了上述文章后,我稍微升级了我的设置,增加了六台配备 Intel Core i7–7700 处理器、32 GB DDR4–2400 RAM 和 256–512 GB SATA III SSD 的 HP EliteDesk 800 G3 Mini。我在一个拍卖网站上以约 80 美元一台的便宜价格购买了这些设备,并额外支付了大约 30–40 美元以将它们升级为新 RAM 和硬盘。处理器都是 65W 型号,配有 90W 电源。即使按照今天的标准,处理器也不容小觑,超频达到 4Ghz 和 4 核心,8 线程。

今天的比较中,我有两个节点并排放置。一个节点运行 Proxmox,这是一个优化用于虚拟化和部署的 Linux 操作系统,运行一个 Windows 10 Pro 虚拟机,另一个节点在裸机上运行 Windows 10 Pro。没有“正确”的操作系统可用,因为这严重依赖于个人喜欢使用的工具,但每个操作系统都有其优缺点。

Proxmox

Proxmox 的一个优点是,它声称对基线处理器的影响很小。静止时,我们可以看到我在节点上部署的虚拟机的资源使用非常低。下面的截图捕捉了仅一个节点的性能摘要。我们可以看到在空闲时,CPU 的使用率仅为极小的百分比,这也与非常有限的功耗(和电费)相关。

针对特定节点的仪表板 — 作者截图

一旦部署了来宾操作系统,情况就完全不同了。此时的资源利用几乎完全由虚拟机配置决定。

我从玩弄 Proxmox 中学到的一些笔记包括:

  • 学习曲线相当陡峭。这是一个企业工具,虽然使用 Proxmox 的文档非常丰富,但你需要花费大量时间阅读文档、论坛,甚至 Reddit 线程。

  • 从另一个积极的角度来看,解决问题时有大量的文档可供参考,而且围绕该平台建立的社区非常强大。

  • 尽管 Proxmox 具有非常直观的 GUI,但解决问题仍需要一些“跳出框框”的思维。例如,一旦我创建了一个 Windows 虚拟机,并将其调整到我想要的标准,我不能像第一次启动镜像时那样轻松地将其“拖放”到另一个节点。我不得不通过将外部硬盘添加到我运行 Windows 的破旧笔记本电脑中来创建网络附加存储(NAS)(有些人可能还记得我在第一篇文章中提到的那台发光的笔记本电脑)。这个存储充当了中介和备份库,用于克隆和迁移我的虚拟机。

  • 我不是一个很精通 Linux 的人。我知道,我知道,我确实应该深入研究一下,但多年来 Mac 和 Windows PC 的便利性使得我在尝试用 CLI 完成操作时会感到挣扎,而这些操作我通常可以通过点击来完成。

  • Proxmox 非常容易扩展。将第一个节点添加到集群或所谓的“数据中心”花了一些时间才弄明白,但添加其他节点则没有花费任何时间。一开始我能够通过完成分配静态 IP 地址等管理员任务,按照我的要求自定义每个节点。一旦掌握了技巧,部署虚拟机也只需几分钟。

  • 这非常酷,我非常喜欢那个展示所有操作统计数据的仪表板,我在操作过程中会密切监控。下图中,我们可以看到仪表板不仅监控了一个节点的使用情况,还记录了其他节点的情况。能够在数据中心的节点之间灵活切换是巨大的,当在裸机上监控可能需要在实例之间远程操作时。

Proxmox 图形界面在家庭实验室 — 作者截屏

最终,回顾起来我想到的一个问题是,不管我花了多少时间来配置 Windows 虚拟机以使其完全按照我希望的方式运行(本周早些时候,我花了一整晚来配置嵌套虚拟化以使 Docker 能够运行),总是会有一个额外的障碍或瓶颈。

容器化

我甚至不会对 Docker 进行比较,因为我在虚拟机中尝试启动的容器(一个为每半年一次的大学朋友 Minecraft 夜晚准备的 Minecraft 服务器)甚至无法达到令人满意的性能水平(服务器无法跟上,且无法进行游戏)。虽然如果我打算使用嵌套虚拟化,我的周末计划略有受阻,但也有实际应用可能会受到影响。

我经常用于工作和娱乐的一个工具是 PyCaret,一个专注于集成机器学习模型的 Python 库。机器学习模型常常有处理器或架构特定的注意事项,比如 PyCaret 不适用于 M1 Macs,因为 ARM 架构的原因,Tensorflow 没有针对我使用的 Radeon 显卡进行优化,而 Autogluon 在我的 i7 Mac 上无法构建(我甚至不知道为什么)。因此,这些包通常被容器化成 Docker 应用程序以实现便携性。我还在研究本地化到 Docker 的 DynamoDB,以利用强大的 AWS NoSQL 架构,而无需支付云端相关的高额费用。时间和速度是这些工具的卖点,而嵌套虚拟化对 Docker 的影响是巨大的(至少在这些 PC 上)。实际上,性能下降在每一级虚拟化中都是递增的,每一级的性能下降超过 10%

对于那些可能会指出这一点的人来说,一个额外的反思是,能够运行 Docker 的 LXCs(Linux 容器)是在主机操作系统上运行的,因此像 ML 模型这样的大型程序可能会导致内核崩溃,如内存交换失败,不仅会杀死容器,还会影响操作系统(而不仅仅是客操作系统)。因此,我甚至没有考虑在这里使用它们,尽管它们无疑是轻量级应用程序的有用工具。

即使没有测量,我们也可以看到尽可能接近裸金属的方式能提升工具的性能。然而,一些人能够在虚拟化环境中仍然实现出色的性能。例如,AWS Nitro就是该领域中的一个真正的差异化因素,它为亚马逊的大规模计算和数据仓储成功做出了贡献,尽管这需要巨大的成本,使得一些数据科学工具如 Sagemaker 的费用相当于我为每台桌面电脑在一个月内支付的费用来租用一台笔记本电脑。我们可以看到下面一个标准的 Sagemaker 工作室笔记本实例,每天使用八小时,一周五天,规格与我们的机器相似(甚至时钟速度有限),大约每月需 $75。总体来看,每个单元的成本可能在 $100–120 之间,升级后,功耗在待机状态下约为 10–15W,峰值为 65W。这大致相当于每月 $2–3 的电费。这与整年相比节省了近一个数量级。

Sagemaker 计算成本通过 AWS 的公共成本估算器 — 作者截图

话虽如此,我相信随着时间的推移以及更好、更快的硬件在消费者和二手市场上的出现,虚拟化性能与实际性能之间的差距会缩小。如果英特尔想送我一台 i9–13900K 或NVIDIA AI送我一台 RTX 4090,我会很乐意进行测试并向大家汇报。与此同时,我将满足于我的 HP mini 电脑和 AWS 免费套餐来满足我的数据分析和虚拟化需求。

比较

为了实际比较虚拟性能与物理性能,我们将对 Windows 10 虚拟机和物理系统分别进行一般化的基准测试,然后进行 Python 性能测试。在这里我要说明的是,我为虚拟机和 PC 分配了相同数量的核心和 RAM(虽然这让我在虚拟机方面有点冒险,因为我曾经在分配“所有”核心时遇到过问题,因为这影响了宿主虚拟化程序的性能,导致系统故障)。

毫不拖延,以下是基于 userbenchmark.com 为虚拟机构建的 基准测试 快照,运行于本地。以下我们可以看到 VCPU 的性能远低于基线和平均水平,我们的 RAM 也仅稍微低于基线。这表明要么我们的 CPU 没有得到充分利用,要么在测试期间这些数学密集型操作中,托管虚拟机的 CPU 存在大量开销。具体的整数计算性能等指标见截图。

作者截图

整体表现不算特别好,尽管单位本身目前输出了一些 BTUs。

让我们通过运行一个简单的 Python 脚本作为基准来评估性能(请注意,由于 GIL,该脚本是单线程运行的)。下面的非科学 Python 脚本是我编写的,用于创建一个粗略的“速度计算”以比较相对性能。

在两个单元格中,我们首先进行:

  1. 计算一个任意大的数字并将每个数字添加到列表中

  2. 在循环中乘以越来越大的数字

每个测试都有时间限制,并且会重复执行,以建立相对性能的基线,用于比较处理速度和内存 IO。以下是代码片段,供你感兴趣时自行运行以进一步比较。

def test1(n):
    l = []
    for i in range(n):
        l.append(i)
n = 100000
%timeit -r 5 -n 1000 test1(n)

def test2(n):
    for i in range(n):
        i * (i-1)
n = 100000
%timeit -r 5 -n 1000 test2(n)

在虚拟机的 Jupyter Notebook 实例中运行 Python — 作者截图

第一个脚本平均完成时间为 8.84ms,第二个脚本平均完成时间为 11.5ms。我们将很快在虚拟机和裸金属之间比较这些数据。

在运行我们的脚本后,我们可以看到相当一部分 RAM 在使用中,然而,CPU 的利用率几乎没有增加,不过如果尝试在多个线程中分配此任务,我会担心弹性带来的问题。12.5GB 的空闲内存是一个显著的开销,应通过进一步的研究进行优化。

作者的任务管理器截图

现在谈谈裸金属……

在实际硬件上原生运行 Windows 10 Pro,我们可以看到使用相同测试套件的性能基准显著提高,这与虚拟机上使用的测试套件相同。

作者截图

我们的处理器在 Windows 原生模式下的整数比较性能几乎是虚拟化模式的两倍,这使得虚拟机的表现大打折扣。我们现在更接近基准线,一般而言,处理过程的平均水平也有所提高。至于我们的 RAM,读写速度显著提升,当在单核上运行时,吞吐量几乎提高了 3 倍。直接在裸金属上运行对 IO 和处理速度确实产生了巨大的影响。

在运行我们的 Python 测试时,我们注意到性能有类似的跳跃。我们的测试脚本运行速度是虚拟化模式的两倍,测试 1 为 4.06 毫秒,测试 2 为 6.04 毫秒。这是虚拟机上原始测试速度的一半。

作者的裸金属 Jupyter Notebook 截图

在未虚拟化的情况下,我们还可以看到使用的 RAM 是虚拟机空闲时的一半。总的来说,这表明与运行虚拟机的相同硬件相比,裸金属运行可以显著改善处理和内存性能。

作者的任务管理器截图

每个团队独特的数据科学需求没有一刀切的解决方案。对于企业而言,花更多的钱用于虚拟分析工具可能更为合理,因为这些工具的安全性和性能可以得到严格监控。较小的公司也可以利用云工具,具体取决于它们的预算。然而,对于个人和小型研究团队来说,基于裸金属的构建可能是实现最佳性能的必要条件。

我用于构建项目和管道的策略不是专注于管理特定的主机和节点,而是保持一个新安装的 Windows(去除多余软件)的备份,其中预安装了我所需的所有内容——某些 Python 包、代码重新分发包、服务器连接等。项目的其余部分集中在一个代码库中,我可以在运行时将其复制并部署到节点进行处理。从 2010 年代中期开始,大多数计算机都具备的千兆连接速度足够快,可以传输数据科学库和包。因此,对高可用计算和操作系统正常运行时间的需求减少了,因为我大规模管理硬盘,不会在本地进行大规模更改。一些服务仍然需要主动管理,例如在 Docker 中运行的容器,但这些服务无论如何都需要相当活跃的管理,将其放在我的本地 Windows 10 专业版安装中,更符合我如何支配时间的方式,因为无论如何都会发生故障。

你怎么看?你在哪里运行你的代码?你倾向于使用哪些工具和平台来托管你的数据科学工作流?请在下面告诉我,或者随时在 LinkedIn 上与我建立联系!

对我使用的硬件感兴趣?查看我在 www.willkeefe.com 上的评论。

你应该使用 slots 吗?Slots 如何影响你的类,何时以及如何使用它们

原文:towardsdatascience.com/should-you-use-slots-how-slots-affect-your-class-when-and-how-to-use-ab3f118abc71

一行代码能带来 20%的性能提升?

Mike HulsTowards Data Science Mike Huls

·发表于 Towards Data Science ·6 分钟阅读·2023 年 8 月 12 日

--

(图片来源:Sébastien GoldbergUnsplash)

Slots 是一种机制,它允许你声明类属性并限制其他属性的创建。你可以确定你的类有哪些属性,从而防止开发者动态添加新属性。这通常会导致20%的速度提升

Slots 在有大量类实例且属性集已知的程序中特别有用。比如视频游戏或物理模拟;在这些情况下,你跟踪大量的实体。

你可以通过添加一行代码将 slots 添加到你的类中,但这总是一个好主意吗?在本文中,我们将探讨为什么如何使用slots使你的类更快以及何时使用它们。总体目标是更好地理解 Python 的类内部工作原理。开始编码吧!

Slots 让 Python 类更快

通过使用slots,你可以提高类的内存使用效率和性能。一个具有 slots 的类占用更少的内存,执行速度更快。

如何让我的类使用 slots?

告诉 Python 让一个类使用 slots 非常简单。你只需添加一个特殊的属性__slots__,它指定所有其他属性的名称:

class Person:
  first_name:str
  last_name:str
  age:int

  __slots__ = ['first_name', 'last_name', 'age']    # <-- this adds slots

  def __init__(self, first_name:str, last_name:str, age:int):
    self.first_name = first_name
    self.last_name = last_name
    self.age = age

在上述类中,我们看到Person有三个属性:first_namelast_nameage。我们可以告诉 Python 我们希望Person类使用 slots,通过添加__slots__属性来实现。这个属性必须指定所有其他属性的名称。

## 参数与关键字参数:哪种方式在 Python 中调用函数最快?

timeit模块的清晰演示

towardsdatascience.com

slotted 类的速度提升了多少?

我们上面使用的 Person 类使用 slots 后几乎 小了 60%(从 488 字节减少到 206 字节)。

关于速度,我已经对实例化、访问和赋值进行了基准测试。我发现 速度提高了多达 20%!你需要对这些结果持保留态度;虽然这些百分比看起来相当令人印象深刻,但这 20% 仅代表 10 万次实例化类的 0.44 秒。这相当于每个实例 可忽略的 44 纳秒(大约比一秒小 3030 万倍)。

查看用于基准测试的 内存速度代码;

## 为什么 Python 很慢以及如何加速

看看底层,了解 Python 的瓶颈所在

towardsdatascience.com

为什么 slotted 类更小且更快?

这与 Python 类的 动态字典 有关。这个字典让你可以为 Python 类分配属性:

class Person:
  pass

mike = Person()

mike.age = 33  # <-- create a new attribute

在上面的例子中,我们定义了一个没有任何属性的类,创建了一个实例,然后动态地创建 age 属性并赋值。

在底层,Python 将所有属性信息存储在一个字典中。通过调用类上的 __dict__ 魔法方法可以访问这个字典:

# 1\. Define class
class Person:
  name:str

  def __init__(self, name:str):
      self.name = name

# 2\. Create instance
mike = Person(name='mike')
# 3\. Create a new variable
mike.age = 33
# 4\. Create new attribute throught the __dict__
mike.__dict__['website'] = 'mikehuls.com'
# 5\. Print out the dynamic dictionary
print(mike.__dict__)  
# -> {'name': 'mike', 'age': 33, 'website': 'mikehuls.com'}

动态字典使得 Python 类非常灵活,但它有一个缺点:使用属性时 Python 会在这个字典中进行搜索,这相对较慢。

## 用两行代码线程化你的 Python 程序

通过同时做多件事来加速你的程序

towardsdatascience.com

slots如何影响动态字典?

当你告诉 Python 为你的类使用 slots 时,不会创建动态字典。相反,Python 创建了一个 固定大小的数组,其中包含对变量的引用。这就是你必须将属性名称传递给 __slots__ 属性的原因。

访问这个数组不仅速度更快,而且占用的内存空间也更少。较小的内存占用对内存分配和垃圾回收也有积极的影响。

插槽有什么副作用?

插槽改变了你的类;它变得有点不灵活,因为你的类变得更静态。这意味着你不能在运行时添加属性;你必须事先指定你的属性:

# 1\. Define class
class Person:
  name:str

  def __init__(self, name:str):
      self.name = name

# 2\. Create instance
mike = Person(name='mike')

# 3\. Add a new attribute?
mike.website = 'mikehuls.com'     # this will not work!
# ERROR: AttributeError: 'Person' object has no attribute 'website'

# 4\. Print out dynamic dict
print(mike.__dict__)              # this will not work
# ERROR: AttributeError: 'Person' object has no attribute '__dict__'

有一种(虽然有点乱的)变通方法:通过将 "__dict__" 的值添加到你的 __slots__ 数组中:

# 1\. Define class
class Person:
  name: str

  __slots__ = ["name", "__dict__"] # <- We've added __dict__

  def __init__(self, name: str):
    self.name = name

# 2\. Create instance
mike = Person(name='mike')

# 3\. Add a new attribute
mike.website = 'mikehuls.com'     # no error this time!

最后一个需要注意的事项是,有些包可能期望使用“普通”的 Python 类,而不是使用插槽类。

towardsdatascience.com ## 6 步骤让 Pandas DataFrame 操作快 100 倍

Cython 用于数据科学:将 Pandas 与 Cython 结合,以实现令人难以置信的速度提升

[towardsdatascience.com

这在数据类中也适用吗?

是的!从 Python 3.10 开始,你还可以添加插槽数据类。使用数据类更简单,只需向 @dataclass 装饰器添加一个参数即可。只需像下面这样定义你的数据类:

@dataclasses.dataclass(slots=True)
class Person:
    name: str

使用插槽有什么好处?

显然,速度内存效率,但也许还有安全性:如果我想覆盖类中的 age 属性但打错字,例如输入 mike.aage = 34,那么未使用插槽的类将创建一个新属性,而保持 age 属性不变。当你使用插槽时,Python 会抛出一个错误,因为它不知道类中有 aage 属性。

何时使用插槽?

速度:尽管插槽从百分比上加速了你的类,但每次操作的绝对时间增加是相当微不足道的。因此,如果你需要创建大量实例,或者需要多次重写或访问属性,插槽的使用会更具吸引力。

内存:如果你内存不足且希望节省每一个字节,使用插槽可能会有好处,因为它们显著减少了内存使用量。我们的简单类减少了 60% 的内存占用。

安全性:插槽防止你使用错误的属性和动态创建新属性。如果你尝试修改一个未知的属性,插槽类会抛出错误。

towardsdatascience.com ## 绝对初学者的 Cython:两步实现 30 倍更快的代码

轻松编译 Python 代码,实现极快的应用程序

[towardsdatascience.com

结论

正如我们在这篇文章中看到的,slots 以三种方式影响你的类:

  • 大小:slots 消除了 Python 创建动态字典的需要,而是依赖于 更小 的固定大小数组,这间接通过减少对垃圾回收的需求来加速你的应用。

  • 速度:slots 允许直接访问内存,绕过搜索字典的需要,这样会更快。速度提升在绝对意义上是相当微小的;节省了几纳秒。

  • 灵活性:slots 防止在运行时添加属性,因此你的类变得有点不那么灵活。这也可能是件好事,因为当你使用动态属性创建时,你的代码可能会变得杂乱无章。

在我看来,减少的灵活性是我不常遇到的缺点:我从不动态创建属性,我喜欢 slots 保持属性静态。因此,我会尽可能使用 slots。在最坏的情况下,依赖关系可能会出问题,但在这种情况下,很容易再次移除 slots。

我希望这篇文章能像我希望的那样清晰,但如果不是这种情况,请告诉我我可以做些什么来进一步澄清。与此同时,查看我的其他文章,涵盖了各种编程相关的话题,例如:

编程愉快!

— Mike

附言:喜欢我做的事吗? 关注我!

## 使用我的推荐链接加入 Medium — Mike Huls

阅读 Mike Huls 的每一篇故事(以及 Medium 上其他数千名作家的故事)。你的会员费用直接支持 Mike...

mikehuls.medium.com

在你的 Medium 博客中展示 Streamlit 应用

原文:towardsdatascience.com/show-streamlit-apps-in-your-medium-blog-520e98c7d51d

跟随这个教程将你的应用在线发布到 Medium 帖子中。

Gustavo SantosTowards Data Science Gustavo Santos

·发表于 Towards Data Science ·阅读时间 5 分钟·2023 年 5 月 30 日

--

图片由 Lloyd Dirks 提供,来源于 Unsplash

介绍

我喜欢写关于Streamlit的文章。我会说它是我最喜欢的 Python 包之一。

Streamlit 是一个非常容易学习的包,它使人们能够快速创建仪表板、Web 应用,甚至将模型部署到生产环境中。我已经写过几次关于它的文章,甚至教你如何用这个优秀的包来构建和部署你的第一个应用(见下方链接)。

2022 年,Snowflake 收购了 Streamlit,从那时起,他们不断为这个包添加更多功能,使其变得更好。最近,他们推出了一个非常酷的功能,我想在这个简短的教程中展示:如何将你的应用嵌入到像这样的 Medium 博客中

[## Streamlit 基础:构建你的第一个 Web 应用

学习创建每个基本功能的代码,并在几分钟内部署你的 Web 应用。

medium.com [## Streamlit 基础:部署你的第一个 Web 应用

学习如何使用 Streamlit 的分享功能部署一个 Web 应用。

medium.com

使用案例

在我看来,将应用程序嵌入 Medium 博客中的好处有很多。

首先,我认为这是任何数据科学家建立优秀作品集的好方法。如果你在科技行业工作,你可能知道拥有一个有趣的作品集来展示你的技能,是吸引注意力的好方法,同时也能保持你的专业形象光鲜亮丽。

另一个好处是将研究成果展示给世界或甚至是客户。你可以写一篇博客文章,并在良好的市场分析或书面报告之后,将你的应用程序添加进去,作为你工作的一个很好的视觉补充。

我甚至看到有人将餐厅菜单做成 Streamlit 应用程序,连接客户和厨房,这真的很棒。所以,看看我们可以想到多少种选项。

创建一个简单的应用程序

好的,为了嵌入一个应用程序,我们需要先创建一个。所以,让我们创建一个简单的应用程序,只需几行代码。

我决定使用meteostat包创建一个应用程序,该包允许你根据经纬度坐标从几乎任何位置检索温度信息。

下面是我将在这个快速项目中使用的包:

# Imports
import pandas as pd
import streamlit as st
from datetime import datetime
from meteostat import Point, Monthly

接下来,让我们将 streamlit 页面设置为宽布局,这样我们的应用程序可以占据页面的更多部分,而不是将所有功能居中显示。作为输入数据,我使用了这个包含许多世界城市及其经纬度的 csv 文件。

# Set Page Layout
st.set_page_config(layout='wide')

# Load the Dataset
cities = pd.read_csv('world_cities.csv')

很好。下一步是放置两个下拉框,让用户选择他们想要检索信息的城市和年份。

# Select box
st.subheader('Temperature History App')

col1, col2 = st.columns(2, gap='medium')
# column 1 - Table weather history
with col1:
    # Title of the select box
    selected_city = st.selectbox(label='Select a city for weather information',
                                 options=cities['city'].sort_values().unique())

with col2:
    selected_year = st.selectbox(label= 'Select an year',
                                 options= range(2022,1999,-1) )

下一个代码片段展示了从meteostat包中实际获取的信息。我们首先设置一个时间框架,使用用户选择的年份。然后我们使用选择的城市从 CSV 中获取经纬度信息,使用简单的 Pandas 查询,将结果数据转换为浮点数。接下来,我们使用meteostat中的Point()函数创建一个对象,该对象将放入Monthly()函数中,后者接收时间框架和位置,创建一个最终对象data,用于提取我们应用程序所需的数据。

# Collect the Weather Information
# Set time period
start = datetime(selected_year, 1, 1)
end = datetime(selected_year, 12, 31)

# Create Point
lat = cities.query('city == @selected_city')['latitude'].astype('float').tolist()[0]
long = cities.query('city == @selected_city')['longitude'].astype('float').tolist()[0]
city_loc = Point(lat, long)

# Get daily data for 2018
data = Monthly(city_loc, start, end)
data = data.fetch()
#data['mth'] = data.index.month
data = data[ ['tavg', 'tmax', 'tmin'] ]

收集了所需的数据框后,我们将创建两列来绘制选定位置的月度温度线图,并显示包含检索信息的表格。

col1, col2 = st.columns(2, gap='large')
# column 1 - Table weather history
with col1:
    st.text('| TEMPERATURES IN °C')
    st.line_chart(data=data)

# column 2 - Graphics
with col2:
    # WEATHER INFORMATION TABLE
    st.text('| WEATHER HISTORY')
    st.write(data)

最后,我们将添加另一个部分,展示一个地图,显示所选位置。

# Division
st.markdown('---')
# Map
st.subheader('| WHERE THIS CITY IS')
df_map = cities.query('city == @selected_city')
st.map(df_map, zoom=5)

生成的应用程序被保存为.py文件,与 requirement.txt 和城市 csv 一起放在一个GitHub 存储库中,链接在此,这是部署前所需的步骤。

然后我们可以直接去share.streamlit.io/将应用程序部署到网络上。一旦完成,你将获得应用程序的链接,并将其粘贴到你的 Medium 博客文章中。

在下面的序列中,你可以看到我们在最后几段中刚刚构建的嵌入应用程序。它在你的 Medium 帖子中直接功能齐全(请耐心等待,可能需要几秒钟才能完全加载)

在本教程中创建的天气应用。图片由作者提供。

在你离开之前

哇!这真是太棒了。当我看到这个新功能时,我迫不及待地想要测试它并与你分享。希望你也喜欢,并找到很好的方法与大家分享和展示你的工作。

Streamlit 使用起来真的很简单,你可以通过他们的文档或互联网上的许多教程了解更多。我相信你会喜欢用它编写应用程序的简单性。

如果你喜欢这些内容,记得关注我获取更多信息。

[## Gustavo Santos - Medium

阅读 Gustavo Santos 在 Medium 上的文章。他是一名数据科学家,从数据中提取洞察,帮助个人和公司……

gustavorsantos.medium.com](https://gustavorsantos.medium.com/?source=post_page-----520e98c7d51d--------------------------------)

LinkedIn 上找到我,或者通过 TopMate.io 预约时间与我讨论数据科学

参考文献

[## 嵌入你的应用 - Streamlit 文档

嵌入 Streamlit Community Cloud 应用可以通过集成交互式、数据驱动的应用程序来丰富你的内容……

docs.streamlit.io](https://docs.streamlit.io/streamlit-community-cloud/get-started/embed-your-app?utm_medium=email&_hsmi=259535966&_hsenc=p2ANqtz-9deEEF-Z6E5LeUsWM_TiXef4GoXNX6wpR27Fz5CYkwa9nRbwFaYVnGkLwIy9hmvE_gN6GwZsaFmkGDq8iQFCS3wfRp3g&utm_content=259535966&utm_source=hs_email&source=post_page-----520e98c7d51d--------------------------------#embedding-with-oembed) [## Streamlit 文档

Streamlit 不仅仅是创建数据应用的一种方式,它还是一个创作者社区,分享他们的应用和想法……

docs.streamlit.io](https://docs.streamlit.io/?source=post_page-----520e98c7d51d--------------------------------) [## Snowflake 以 8 亿美元收购 Streamlit,帮助客户构建基于数据的应用

Snowflake 帮助客户在云中存储和管理大量数据,而不受云供应商的锁定。Streamlit 是一个……

techcrunch.com [## Streamlit 基础知识:构建你的第一个 Web 应用

学习如何编写代码以创建每一个基本功能,并在几分钟内部署你的网页应用。

medium.com [## Streamlit 基础知识:部署你的第一个 Web 应用

学习如何使用 Streamlit 的分享功能部署 Web 应用。

medium.com

Siamese 神经网络与三重损失和余弦距离

原文:towardsdatascience.com/siamese-neural-networks-with-tensorflow-functional-api-6aef1002c4e

理论与代码实践:使用三重损失和余弦距离进行 Siamese 网络在 CIFAR-10 数据集上的训练

Tan Pengshi AlvinTowards Data Science Tan Pengshi Alvin

·发表于 Towards Data Science ·阅读时长 11 分钟·2023 年 5 月 12 日

--

图片由 Alex Meier 提供,来源于 Unsplash

如果我们可以将每个对象图像(如人脸等)编码成一个模板——一个数字向量呢?之后,我们可以通过对比它们的模板——计算距离——来客观地确定对象之间的相似性。在深度学习中,这正是 Siamese 神经网络希望实现的目标。

Siamese 神经网络基本上是经过训练后为每个输入对象生成独特特征向量(模板)的模型。尽管这些模型通常用于对象图像的模板(计算机视觉),但它们也可以用于文本和声音数据。

除了安全认证,如人脸识别和签名比对,Siamese 神经网络还常用于电子商务平台中测量产品相似性。例如,一些电子商务平台允许你通过上传你想寻找的对象的图像来搜索类似产品。在 Kaggle 上,甚至有一个由东南亚领先电子商务公司 Shopee 举办的 产品匹配竞赛

在这篇文章中,我们将探索一个在 Tensorflow 中常见的数据集——CIFAR-10——该数据集与产品相似性搜索问题有些相似,只不过兴趣对象是汽车——如汽车、飞机、卡车、船只等——以及动物(或者说宠物也行!)——如猫、狗、马、鸟、鹿等。

在开始之前,我们首先需要理解 Siamese 神经网络背后的理论。之后,我们将探索在 CIFAR-10 数据集上训练和评估简单 Siamese 神经网络的代码。

准备好了吗?开始吧!

1. 孪生网络理论

我不得不承认本文中的封面图片有点误导——‘Siamese’一词实际上并不是源于‘暹罗猫’。而是来源于‘暹罗双胞胎’,即身体某部分连在一起的双胞胎。

因此,孪生神经网络基本上指的是双胞神经网络,这些网络通常在最后——Lambda 层,如我们将看到的——连接在一起,然后将模型输出输入损失函数。在训练这些双胞神经网络的过程中,它们的权重在初始化、前向传播和反向传播过程中完全相同。

由于我们通常处理的是图像,每对孪生神经网络通常是卷积神经网络(CNN)。如果你对 CNN 不熟悉或需要刷新记忆,我这里有一篇关于 CNN 的优秀文章:

[## 迁移学习与卷积神经网络(CNN)

从 CNN 到迁移学习的完整指南,适用于 Kaggle 的猫狗数据集

medium.com](https://medium.com/mlearning-ai/transfer-learning-and-convolutional-neural-networks-cnn-e68db4c48cca?source=post_page-----6aef1002c4e--------------------------------)

牢记这一点,我们将介绍两种常见的孪生神经网络:

1.1 对比损失孪生网络

第一种类型是基于计算双胞 CNN 的嵌入层(特征向量)之间的欧几里得/余弦距离,然后与真实值(1:匹配,0:不匹配)比较来确定对比损失的孪生神经网络。

以下是这种模型的示意图:

对比损失的孪生神经网络示例。图片改编自SigNet 论文

对比损失公式与欧几里得距离,其中 Y 为真实值。图片作者提供。

1.2 三重损失孪生网络

第二种类型的孪生神经网络基于计算三重 CNN 的嵌入层(特征向量)之间的两个欧几里得/余弦距离——即锚点和正样本之间,锚点和负样本之间——然后在 Lambda 层中完全计算三重损失,而不与任何真实值进行比较。

因为研究表明这种三重损失模型通常比对比损失模型更鲁棒,所以我们将在本文中重点讨论这种类型的孪生网络。

以下是这种模型的示意图:

三重损失孪生神经网络的示例。图片改编自SigNet 论文

使用欧几里得距离的三元组损失公式,其中 A 是锚点图像输入,P 是正样本图像输入,N 是负样本图像输入。图片由作者提供。

1.3 孪生网络的目标

现在,我们已经看到孪生神经网络的大致架构。但是在训练网络后,我们打算达到什么目标?让我们看看下面的插图:

孪生网络的训练减少了相似图像之间的距离,同时增加了不相似图像之间的距离。图片来源于FaceNet 论文

我们看到孪生网络正在学习在同一类别图像之间重建相似的特征向量。因此,训练后,相似图像模板之间的距离将减少,而不相似图像模板之间的距离将增加。

话虽如此,在训练过程中覆盖尽可能多的图像类别是很重要的,以便模型也能推广到未见过的类别(签名、面孔等)。

最后,在模型评估期间,我们主要关注生成输入图像数据的模板。因此,在进行模板推理时,仅提取单个 CNN 网络或双胞胎/三胞胎网络的主体,而不包括 Lambda 层。

1.4 欧几里得距离与余弦距离

在我们开始编码之前,让我们首先区分两个常见的向量距离度量——欧几里得距离和余弦距离。到目前为止,在上述插图中,我们展示了欧几里得距离,因为它更直观易懂,但在构建更好的模型时并不一定优于余弦距离。下面我们来说明一下:

2D 空间中两个向量的欧几里得距离和余弦距离的插图。图片由作者提供。

从上述内容来看,欧几里得距离只是两个特征向量之间的“坐标距离”,而余弦距离是它们之间“角度距离”的一种度量。因此,当两个特征向量远离时,我们可以看到欧几里得距离和余弦距离都很大。但它们之间存在微妙的差别,下面我们来看一下:

欧几里得距离和余弦距离在小角度下但向量长度不同的比较。图片由作者提供。

虽然余弦距离仅测量特征向量之间的角度差异,但欧几里得距离测量第二维度——长度差异。因此,虽然更直观,但欧几里得距离本质上比余弦距离更复杂。

一般来说,欧几里得距离和余弦距离都被广泛使用,选择取决于经验探索。然而,对于较小的数据集和特定数量的类别,采用余弦距离作为损失函数可能是一个更好的选择,这也是我们为 CIFAR-10 数据集所做的。

2. 孪生网络代码练习

接下来,让我们开始编码吧。我们将基于 TensorFlow CIFAR-10 数据集构建三元组损失孪生网络。我们将基于余弦距离来构建三元组损失,然后在测试集评估时,通过角度相似度来比较测试图像。

注意:使用的是角度相似度,因为它基于余弦距离,但值范围缩放在 0%到 100%之间。

角度相似度的公式。图片由作者提供。

还需要注意的是,在模型初始化过程中,我们将采用 TensorFlow 的功能性 API(对比之前在迁移学习和 CNN 文章中使用的顺序 API),以及自定义 Lambda 层和自定义损失函数。

毫不犹豫地,让我们开始编码吧!

2.1 探索 CIFAR-10 数据集

# import necessary libraries
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

# set random seed
np.random.seed(42)

# load CIFAR-10 data
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()

# check data size
assert X_train.shape == (50000, 32, 32, 3)
assert X_test.shape == (10000, 32, 32, 3)
assert y_train.shape == (50000, 1)
assert y_test.shape == (10000, 1)

# combine data first - we will generate test set later.
X = np.concatenate([X_train,X_test],axis=0)
y = np.concatenate([y_train,y_test],axis=0)
y = np.squeeze(y)

assert X.shape == (60000, 32, 32, 3)
assert y.shape == (60000,)

# check number of data in each class
unique, counts = np.unique(y,return_counts=True)
np.asarray([unique,counts]).T

# Plot Class N (0-9)

TARGET = # Class index here
NUM_ARRAYS = 10

arrays = X[np.where(y==TARGET)]
random_arrays_indices = np.random.choice(len(arrays),NUM_ARRAYS)
random_arrays = arrays[random_arrays_indices]

fig = plt.figure(figsize=[NUM_ARRAYS,4])
plt.title('Class 0: Plane',fontsize = 15)
plt.axis('off')

for index in range(NUM_ARRAYS):
     fig.add_subplot(2, int(NUM_ARRAYS/2), index+1)
     plt.imshow(random_arrays[index])

2.2 生成三元组

# initialize triplets array
triplets = np.empty((0,3,32,32,3),dtype=np.uint8)

# get triplets for each class
for target in range(10):

    locals()['arrays_'+str(target)] = X[np.where(y==target)].reshape(3000,2,32,32,3)
    locals()['arrays_not_'+str(target)] = X[np.where(y!=target)]

    random_indices = np.random.choice(len(locals()['arrays_not_'+str(target)]),3000)
    locals()['arrays_not_'+str(target)] = locals()['arrays_not_'+str(target)][random_indices]

    locals()['arrays_'+str(target)] = np.concatenate(
        [
            locals()['arrays_'+str(target)],
            locals()['arrays_not_'+str(target)].reshape(3000,1,32,32,3)
        ],
        axis = 1
    )

    triplets = np.concatenate([triplets,locals()['arrays_'+str(target)]],axis=0)

# check triplets size
assert triplets.shape == (30000,3,32,32,3)

# plot triplets array to visualize
TEST_SIZE = 5
random_indices = np.random.choice(len(triplets),TEST_SIZE)

fig = plt.figure(figsize=[5,2*TEST_SIZE])
plt.title('ANCHOR | POSITIVE | NEGATIVE',fontsize = 15)
plt.axis('off')

for row,i in enumerate(range(0,TEST_SIZE*3,3)):
    for j in range(1,4):
        fig.add_subplot(TEST_SIZE, 3, i+j)
        random_index = random_indices[row]
        plt.imshow(triplets[random_index,j-1])

# save triplet array
np.save('triplets_array.npy',triplets)

2.3 准备模型训练/评估

# Import all libraries

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras import Input, optimizers, Model
from tensorflow.keras.layers import Layer, Lambda
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.utils import plot_model

from sklearn.metrics import precision_recall_curve, roc_curve, roc_auc_score
from sklearn.model_selection import train_test_split

from scipy import spatial
triplets = np.load('triplets_array.npy')

triplets = triplets/255 #normalize by 255
labels = np.ones(len(triplets)) #create a fixed label

assert triplets.shape == (30000,3,32,32,3)
# Split data into our train and test set

X_train, X_test, y_train, y_test = train_test_split(
    triplets,
    labels,
    test_size=0.05,
    random_state=42
)
# Load pretrained model for transfer learning

pretrained_model = MobileNetV2(
    weights='imagenet', 
    include_top=False, 
    input_shape=(32,32,3)
)

for layer in pretrained_model.layers:
    layer.trainable = True

2.4 模型训练

# Initialize functions for Lambda Layer

def cosine_distance(x,y):
    x = K.l2_normalize(x, axis=-1)
    y = K.l2_normalize(y, axis=-1)
    distance = 1 - K.batch_dot(x, y, axes=-1)
    return distance

def triplet_loss(templates, margin=0.4):

    anchor,positive,negative = templates

    positive_distance = cosine_distance(anchor,positive)
    negative_distance = cosine_distance(anchor,negative)

    basic_loss = positive_distance-negative_distance+margin
    loss = K.maximum(basic_loss,0.0)

    return loss
# Adopting the TensorFlow Functional API

anchor = Input(shape=(32, 32,3), name='anchor_input')
A = pretrained_model(anchor)

positive = Input(shape=(32, 32,3), name='positive_input')
P = pretrained_model(positive)

negative = Input(shape=(32, 32,3), name='negative_input')
N = pretrained_model(negative)

loss = Lambda(triplet_loss)([A, P, N])

model = Model(inputs=[anchor,positive,negative],outputs=loss)
# Create a custom loss function since there are no ground truths label

def identity_loss(y_true, y_pred):
    return K.mean(y_pred)

model.compile(loss=identity_loss, optimizer=Adam(learning_rate=1e-4))

callbacks=[EarlyStopping(
    patience=2, 
    verbose=1, 
    restore_best_weights=True,
    monitor='val_loss'
    )]

# view model
plot_model(model, show_shapes=True, show_layer_names=True, to_file='siamese_triplet_loss_model.png')

# Start training - y_train and y_test are dummy

model.fit(
    [X_train[:,0],X_train[:,1],X_train[:,2]],
    y_train,
    epochs=50, 
    batch_size=64,
    validation_data=([X_test[:,0],X_test[:,1],X_test[:,2]],y_test),
    callbacks=callbacks
)

2.5 模型评估

X_test_anchor = X_test[:,0]
X_test_positive = X_test[:,1]
X_test_negative = X_test[:,2]

# extract the CNN model for inference
siamese_model = model.layers[3]

X_test_anchor_template = np.squeeze(siamese_model.predict(X_test_anchor))
X_test_positive_template = np.squeeze(siamese_model.predict(X_test_positive))
X_test_negative_template = np.squeeze(siamese_model.predict(X_test_negative))

y_test_targets = np.concatenate([np.ones((len(X_test),)),np.zeros((len(X_test),))])
# Get predictions in angular similarity scores

def angular_similarity(template1,template2):

    score = np.float32(1-np.arccos(1-spatial.distance.cosine(template1,template2))/np.pi)

    return score

y_predict_targets = []

for index in range(len(X_test)):
    similarity = angular_similarity(X_test_anchor_template[index],X_test_positive_template[index])
    y_predict_targets.append(similarity)

for index in range(len(X_test)):
    similarity = angular_similarity(X_test_anchor_template[index],X_test_negative_template[index])
    y_predict_targets.append(similarity)
# Get prediction results with ROC Curve and AUC scores

fpr, tpr, thresholds = roc_curve(y_test_targets, y_predict_targets)

fig = plt.figure(figsize=[10,7])
plt.plot(fpr, tpr,lw=2,label='UnoFace_v2 (AUC={:.3f})'.format(roc_auc_score(y_test_targets, y_predict_targets)))
plt.plot([0,1],[0,1],c='violet',ls='--')
plt.xlim([-0.05,1.05])
plt.ylim([-0.05,1.05])
plt.legend(loc="lower right",fontsize=15)

plt.xlabel('False positive rate')
plt.ylabel('True positive rate')
plt.title('Receiver Operating Characteristic (ROC) Curve',weight='bold',fontsize=15)

# Getting Test Pairs and their Corresponding Predictions

positive_comparisons = X_test[:,[0,1]]
negative_comparisons = X_test[:,[0,2]]

positive_predict_targets = np.array(y_predict_targets)[:1500]
negative_predict_targets = np.array(y_predict_targets)[1500:]

assert positive_comparisons.shape == (1500,2,32,32,3)
assert negative_comparisons.shape == (1500,2,32,32,3)

assert positive_predict_targets.shape == (1500,)
assert negative_predict_targets.shape == (1500,)

np.random.seed(21)
NUM_EXAMPLES = 5
random_index = np.random.choice(range(len(positive_comparisons)),NUM_EXAMPLES)
# Plotting Similarity Scores for Positive Comparisons 
# (Switch values and input to plot for Negative Comparisons)

plt.figure(figsize=(10,4))
plt.title('Positive Comparisons and Their Similarity Scores')
plt.ylabel('Anchors')
plt.yticks([])
plt.xticks([32*x+16 for x in range(NUM_EXAMPLES)], ['.' for x in range(NUM_EXAMPLES)])
for i,t in enumerate(plt.gca().xaxis.get_ticklabels()):
    t.set_color('green') 
plt.grid(None)
anchor = np.swapaxes(positive_comparisons[:,0][random_index],0,1)
anchor = np.reshape(anchor,[32,NUM_EXAMPLES*32,3])
plt.imshow(anchor)

plt.figure(figsize=(10,4))
plt.ylabel('Positives')
plt.yticks([])
plt.xticks([32*x+16 for x in range(NUM_EXAMPLES)], positive_predict_targets[random_index])
for i,t in enumerate(plt.gca().xaxis.get_ticklabels()):
    t.set_color('green') 
plt.grid(None)
positive = np.swapaxes(positive_comparisons[:,1][random_index],0,1)
positive = np.reshape(positive,[32,NUM_EXAMPLES*32,3])
plt.imshow(positive)

3. 结论

恭喜你完成理论和代码练习!希望这个教程为你提供了关于孪生网络及其在对象相似度应用方面的全面介绍。

在结束之前,我还要补充的是,如何处理对象相似度分数取决于问题陈述。

如果我们在生产中进行 1:1 对象比较(即两个对象是否相似或不同),通常需要基于测试时的假匹配率(FMR)设置一个相似度阈值。另一方面,如果进行 1:N 对象匹配,通常会返回相似度得分最高的对象,并进行排序。

注:有关完整的代码,请查看我的 GitHub

感谢您的时间,希望您喜欢本教程。我还想介绍一个在这篇文章中详细阐述的极其重要的话题——以数据为中心的机器学习

[## 以数据为中心的 AI — 数据收集和增强策略]

关于以数据为中心的机器学习项目的数据生成策略的综合指南

pub.towardsai.net

感谢阅读!如果您喜欢我的内容,可以浏览我在 Medium 上的其他文章,并在 LinkedIn 上关注我。

支持我! — 如果您没有订阅 Medium,并且喜欢我的内容,请考虑通过我的推荐链接来支持我。

[## 通过我的推荐链接加入 Medium - Tan Pengshi Alvin]

阅读 Tan Pengshi Alvin 的每一个故事(以及 Medium 上成千上万的其他作家)。您的会员费用直接…

tanpengshi.medium.com

相似性搜索,第三部分:结合倒排文件索引和产品量化

原文:towardsdatascience.com/similarity-search-blending-inverted-file-index-and-product-quantization-a8e508c765fa?source=collection_archive---------1-----------------------#2023-05-19

了解如何结合两种基本的相似性搜索索引,以获得两者的优势

Vyacheslav EfimovTowards Data Science Vyacheslav Efimov

·

关注 发表在 Towards Data Science ·8 分钟阅读·2023 年 5 月 19 日

--

相似性搜索 是一个问题,其中给定一个查询的目标是找到数据库中与之最相似的文档。

引言

在数据科学中,相似性搜索常见于自然语言处理(NLP)领域、搜索引擎或推荐系统中,这些系统需要为查询检索出最相关的文档或项目。在海量数据中提升搜索性能的方法有很多种。

在本系列的前两部分中,我们讨论了信息检索中的两种基本算法:倒排文件索引产品量化。这两者都优化了搜索性能,但关注的方面不同:前者加快了搜索速度,而后者则将向量压缩为更小、更节省内存的表示。

## 相似性搜索,第一部分:kNN 与倒排文件索引

相似性搜索是一个热门问题,其中给定一个查询 Q,我们需要在所有文档中找到最相似的文档。

## 相似性搜索,第二部分:产品量化

在本系列文章的第一部分,我们查看了用于执行相似性搜索的 kNN 和倒排文件索引结构。

medium.com

由于这两种算法侧重于不同方面,自然会产生一个问题,即是否可以将这两种算法合并成一种新算法。

在本文中,我们将结合这两种方法的优点,以产生一种快速且节省内存的算法。供参考,大多数讨论的想法将基于这篇论文

在深入细节之前,有必要了解残差向量是什么,并对其有用的属性有一个简单的直观认识。我们将在设计算法时使用它们。

残差向量

想象一下执行了一个聚类算法,并产生了几个簇。每个簇都有一个质心和与之相关的点。残差是一个点(向量)与其质心之间的偏移。基本上,要找出特定向量的残差,需要从其质心中减去该向量。

如果簇是由 k-means 算法构建的,那么簇的质心是所有属于该簇的点的均值。因此,从任何点找出残差将等同于从中减去簇的均值。通过从属于特定簇的所有点中减去均值,这些点将围绕 0 中心对齐。

原始的点簇显示在左侧。然后从所有簇点中减去簇质心。结果的残差向量显示在右侧。

我们可以观察到一个有用的事实:

用残差替换原始向量不会改变它们之间的相对位置。

也就是说,向量之间的距离始终保持不变。我们可以简单地查看下面的两个方程。

减去均值不会改变相对距离

第一个方程是两个向量之间的欧几里得距离公式。在第二个方程中,从两个向量中减去簇的均值。我们可以看到,均值项会被消去——整个表达式变得与第一个方程中的欧几里得距离完全相同!

我们通过使用 L2 度量(欧几里得距离)的公式证明了这一声明。重要的是要记住,这个规则可能不适用于其他度量。

因此,如果对于给定的查询,目标是找到最近的邻居,可以仅从查询中减去簇均值,然后在残差中进行正常的搜索过程。

从查询中减去均值不会改变其相对位置。

现在让我们看看下图中的另一个例子,其中两个簇的向量残差分别计算。

从每个簇的对应质心中减去均值将使所有数据集向量围绕 0 中心

这是一个有用的观察,将在未来使用。此外,对于给定的查询,我们可以计算到所有簇的查询残差。查询残差使我们能够计算到簇的原始残差的距离。

从每个簇中减去均值后,所有点都围绕 0 中心。查询和查询残差与相应簇中其他点的相对位置保持不变。

训练

考虑到上一节中的有用观察,我们可以开始设计算法。

给定一个向量数据库,构建一个倒排文件索引,将向量集划分为n个 Voronoi 分区,从而减少推理过程中的搜索范围。

在每个 Voronoi 分区内,从每个向量中减去质心的坐标。结果是,所有分区中的向量变得彼此更接近,并且围绕 0 中心。此时,无需原始向量,因为我们存储它们的残差。

之后,对所有分区中的向量运行产品量化算法。

重要方面:产品量化不会对每个分区单独执行——那样会很低效,因为分区的数量通常很高,这将需要大量的内存来存储所有的码本。相反,算法会对所有分区的残差同时执行

实际上,现在每个子空间包含来自不同 Voronoi 分区的子向量。然后,对于每个子空间,执行一个聚类算法,构建出如常规的 k 个簇及其中心点。

训练过程

替换向量为其残差的重要性是什么? 如果向量没有被其残差替换,那么每个子空间将包含更多的各种子向量(因为子空间将存储来自不同不相交的 Voronoi 分区的子向量,而这些子向量可能在空间中相距很远)。现在来自不同分区的向量(残差)彼此重叠。由于现在每个子空间的变化更小,因此表示向量所需的重现值也更少。换句话说:

使用之前相同长度的 PQ 代码,向量可以更准确地表示,因为它们的方差更小。

推断

对于给定的查询,找到 Voronoi 分区的 k 个最近中心点。所有这些区域内的点都被视为候选点。由于原始向量在每个 Voronoi 区域中最初被其残差所替代,查询向量的残差也需要被计算。在这种情况下,查询残差需要为每个 Voronoi 分区单独计算(因为每个区域有不同的中心点)。只有来自所选 Voronoi 分区的残差才会成为候选点。

查询残差随后被拆分为子向量。与原始的产品量化算法相同,对于每个子空间,计算包含从子空间中心点到查询子向量的距离的距离矩阵 d。必须记住,查询残差在每个 Voronoi 分区中都是不同的。这基本上意味着距离矩阵 d 需要为每个查询残差单独计算。这是所需优化的计算代价。

最后,部分距离被汇总,就像在产品量化算法中之前所做的那样。

排序结果

在计算所有距离后,需要选择 k 个最近邻点。为了高效完成这一过程,作者建议使用一个 MaxHeap 数据结构。它的容量有限为 k,并在每一步中存储 k 个当前最小的距离。每当计算出一个新距离时,只有当计算出的值小于 MaxHeap 中的最大值时,该值才会被添加到 MaxHeap 中。计算完所有距离后,查询的答案已经存储在 MaxHeap 中。使用 MaxHeap 的优点是其构建时间很快,为 O(n)

推断过程

性能

该算法利用了倒排文件索引和产品量化。根据推理过程中 Voronoi 分区的数量,相同数量的子向量到质心矩阵 d 需要计算并存储在内存中。这可能看起来像是一个缺点,但与整体优势相比,这是一个相当好的折衷。

该算法从倒排文件索引继承了良好的搜索速度,从产品量化继承了压缩效率。

Faiss 实现

Faiss(Facebook AI 搜索相似性)是一个用 C++ 编写的 Python 库,用于优化相似性搜索。该库提供了不同类型的索引,这些数据结构用于高效地存储数据和执行查询。

根据 Faiss 文档 的信息,我们将了解如何将倒排文件和产品量化索引组合在一起形成新的索引。

Faiss 在 IndexIVFPQ 类中实现了上述算法,该类接受以下参数:

  • quantizer:指定计算向量之间距离的方式。

  • d:数据维度。

  • nlist:Voronoi 分区的数量。

  • M:子空间的数量。

  • nbits:编码单个簇 ID 所需的位数。这意味着单个子空间中的簇总数将等于 k = 2^nbits

此外,可以调整 nprobe 属性,该属性指定在推理过程中用于搜索候选项的 Voronoi 分区数量。更改此参数无需重新训练。

Faiss 的 IndexIVFPQ 实现

存储单个向量所需的内存与原始产品量化方法相同,只是现在我们增加了 8 个字节,用于在倒排文件索引中存储关于向量的信息。

结论

利用之前文章部分的知识,我们探讨了一个先进算法的实现,该算法实现了高效的内存压缩和加速的搜索速度。该算法在处理大量数据时广泛用于信息检索系统。

资源

除非另有说明,所有图片均由作者提供。

相似性搜索,第一部分:kNN 与倒排文件索引

原文:towardsdatascience.com/similarity-search-knn-inverted-file-index-7cab80cc0e79?source=collection_archive---------1-----------------------#2023-04-28

介绍 kNN 的相似性搜索及其通过倒排文件的加速。

Vyacheslav EfimovTowards Data Science Vyacheslav Efimov

·

关注 发布于 Towards Data Science ·9 分钟阅读·2023 年 4 月 28 日

--

相似性搜索是一个问题,给定一个查询,目标是找到数据库中与之最相似的文档。

介绍

在数据科学中,相似性搜索常出现在自然语言处理领域、搜索引擎或推荐系统中,需要为查询检索最相关的文档或项目。通常,文档或项目以文本或图像的形式表示。然而,机器学习算法不能直接处理原始文本或图像,这就是为什么文档和项目通常被预处理并存储为向量的原因。

有时,向量的每个组件可以存储语义信息。在这种情况下,这些表示也称为嵌入。这些嵌入可以有数百维,并且其数量可以达到数百万!由于这些巨大的数字,任何信息检索系统必须能够迅速检测到相关文档。

在机器学习中,向量也被称为对象

索引

为了加速搜索性能,数据集嵌入之上建立了一个特殊的数据结构。这种数据结构称为索引。在这一领域已有大量研究,并且发展出了许多类型的索引。在选择适用于特定任务的索引之前,有必要了解其内部操作原理,因为每种索引有不同的用途,并且各自有优缺点。

在本文中,我们将看看最简单的方法——kNN。基于 kNN,我们将转到倒排文件——一种用于更可扩展搜索的索引,可以加速搜索过程数倍。

kNN

kNN是相似性搜索中最简单和最原始的算法。考虑一个向量数据集和一个新的查询向量 Q。我们希望找到与 Q 最相似的前 k 个数据集向量。首先要考虑的是如何测量两个向量之间的相似性(距离)。实际上,有几种相似性度量可以实现这一点。下面的图中展示了其中一些。

相似性度量

训练

kNN 是机器学习中为数不多的无需训练阶段的算法之一。选择合适的度量后,我们可以直接进行预测。

推理

对于一个新的对象,该算法会穷尽地计算与所有其他对象的距离。之后,它会找到距离最小的 k 个对象,并将其作为响应返回。

kNN 工作流程

显然,通过检查与所有数据集向量的距离,kNN 可以保证 100%的准确结果。然而,这种蛮力方法在时间性能上非常低效。如果一个数据集由 n 个具有 m 维度的向量组成,那么对于每个 n 向量,需要 O(m) 时间来计算与查询 Q 的距离,总时间复杂度为 O(mn)。正如我们稍后将看到的,存在更高效的方法。

此外,原始向量没有压缩机制。想象一下一个包含数十亿个对象的数据集。将所有这些对象存储在 RAM 中几乎是不可能的!

kNN 性能。具有 100%的准确率和没有训练阶段会导致在推理过程中进行穷举搜索以及向量的无内存压缩。注意:这种图示显示了不同算法的相对比较。根据情况和选择的超参数,性能可能会有所不同。

应用

kNN 的应用范围有限,应该仅在以下情况之一中使用:

  • 数据集的大小或嵌入维度相对较小。这一方面确保了算法仍然能够快速执行。

  • 算法的要求准确度必须达到 100%。在准确度方面,没有其他最近邻算法能够超越 kNN。

基于指纹检测一个人的例子是需要 100%准确度的问题。如果一个人犯了罪并留下了指纹,检索到的结果必须完全正确。否则,如果系统不是 100%可靠,则可能会错误地将另一人定罪,这是一种非常严重的错误。

基本上,改进 kNN 有两种主要方法(稍后我们将讨论):

  • 缩小搜索范围。

  • 降低向量的维度。

使用这两种方法之一时,我们将不会再次进行穷举搜索。这些算法被称为近似最近邻(ANN),因为它们不保证 100%的准确结果。

倒排文件索引

“倒排索引(也称为文档列表文档文件倒排文件)是一个数据库索引,存储内容的映射,例如单词或数字,及其在表格、文档或文档集合中的位置” — 维基百科

在执行查询时,计算查询的哈希函数,并从哈希表中获取映射值。这些映射值中的每一个都包含自己的一组潜在候选者,然后在条件下完全检查是否为查询的最近邻。通过这样做,所有数据库向量的搜索范围被缩小。

倒排文件索引工作流程

这种索引有不同的实现方式,具体取决于哈希函数的计算方式。我们将要研究的实现是使用Voronoi 图(或Dirichlet 镶嵌)的方法。

训练

该算法的思想是创建几个不相交的区域,每个数据集点将属于其中一个区域。每个区域都有自己的质心,指向该区域的中心。

有时Voronoi 区域也被称为单元分区

Voronoi 图示例。白点是各自分区的中心,分区内包含一组候选者。

Voronoi 图的主要特性是,质心到其区域内任意点的距离小于该点到另一质心的距离。

推理

当给定一个新对象时,计算所有 Voronoi 分区质心的距离。然后选择距离最小的质心,并将该分区中的向量作为候选项。

通过给定查询,我们搜索最近的质心(位于绿色区域)

最终,通过计算与候选项的距离并选择前k个最近的候选项,返回最终答案。

在选定区域内查找最近邻居

如你所见,这种方法比之前的要快得多,因为我们不需要遍历所有数据集向量。

边缘问题

随着搜索速度的提高,倒排文件也有一个缺点:它不能保证找到的对象始终是最近的。

在下图中,我们可以看到这样的场景:实际的最近邻居位于红色区域,但我们仅从绿色区域选择候选项。这种情况被称为边缘问题

边缘问题

这种情况通常发生在查询对象靠近另一区域边界时。为了减少这种情况下的错误数量,我们可以扩大搜索范围,并基于与对象最接近的前m个质心选择多个区域来搜索候选项。

在几个区域内查找最近邻居(m = 3)

探索的区域越多,结果越准确,但计算所需的时间也越长。

应用

尽管存在边缘问题,倒排文件在实践中表现出色。它在需要在准确度和速度提升之间进行权衡时非常适合使用。

一个使用案例示例是基于内容的推荐系统。假设它根据用户过去观看的其他电影向用户推荐一部电影。数据库包含一百万部电影供选择。

  • 使用 kNN 时,系统确实会选择对用户最相关的电影并进行推荐。然而,执行查询所需的时间非常长。

  • 假设使用倒排文件索引,系统推荐第 5 个最相关的电影,这在现实生活中可能是这种情况。搜索时间比 kNN 快 20 倍。

从用户体验的角度来看,很难区分这两种推荐结果的质量:第 1 个和第 5 个最相关的结果都是来自百万个可能电影中的良好推荐。用户可能对这些推荐中的任何一个都感到满意。从时间的角度来看,倒排文件显然是赢家。这就是为什么在这种情况下,最好使用后者的方法。

反向文件索引性能。在这里,我们稍微降低了准确度以在推理过程中获得更高的速度。

Faiss 实现

Faiss(Facebook AI 搜索相似性)是一个用 C++编写的 Python 库,用于优化的相似性搜索。该库展示了不同类型的索引,这些索引是用来高效存储数据和执行查询的数据结构。

根据Faiss 文档中的信息,我们将了解索引的创建和参数化过程。

kNN

实现 kNN 方法的索引在 Faiss 中被称为flat,因为它们不压缩任何信息。它们是唯一保证正确搜索结果的索引。实际上,Faiss 中存在两种类型的 flat 索引:

  • IndexFlatL2。相似度计算为欧几里得距离。

  • IndexFlatIP。相似度计算为内积。

这两种索引在构造函数中都需要一个单一的参数d:数据维度。这些索引没有任何可调参数。

Faiss 对 IndexFlatL2 和 IndexFlatIP 的实现

存储一个向量的单一分量需要 4 字节。因此,存储一个维度为d的向量需要4 * d字节。

反向文件索引

对于描述的反向文件,Faiss 实现了IndexIVFFlat类。与 kNN 的情况一样," Flat "一词表示没有对原始向量进行解压,它们被完全存储。

为了创建这个索引,我们首先需要传递一个量化器——一个决定数据库向量如何存储和比较的对象。

IndexIVFFlat有两个重要参数:

  • nlist:定义在训练过程中创建的区域(Voronoi 单元)的数量。

  • nprobe:决定搜索候选区域的数量。更改 nprobe 参数不需要重新训练。

Faiss 对 IndexIVFFlat 的实现

与之前的情况一样,我们需要4 * d字节来存储一个向量。但现在我们还需要存储有关 Voronoi 区域的信息,这些区域是数据集向量所属的。在 Faiss 实现中,这些信息每个向量占用 8 字节。因此,存储单个向量所需的内存为:

结论

我们已经探讨了相似性搜索中的两种基础算法。实际上,朴素的 kNN 几乎不应该用于机器学习应用,因为它在扩展性方面表现不佳,除非在特定情况下。另一方面,倒排文件提供了加速搜索的良好启发式方法,其质量可以通过调整超参数来提高。从不同的角度仍然可以提升搜索性能。在本系列文章的下一部分,我们将深入探讨一种旨在压缩数据集向量的方法。

## 相似性搜索,第二部分:产品量化

学习一种有效压缩大数据的强大技术

towardsdatascience.com

资源

除非另有说明,否则所有图片均由作者提供。

相似性搜索,第四部分:分层可导航的小世界(HNSW)

原文:towardsdatascience.com/similarity-search-part-4-hierarchical-navigable-small-world-hnsw-2aad4fe87d37?source=collection_archive---------0-----------------------#2023-06-16

发现如何构建高效的多层图以提升在海量数据中的搜索速度

Vyacheslav Efimov数据科学进展 Vyacheslav Efimov

·

关注 发表在 数据科学进展 · 13 分钟阅读 · 2023 年 6 月 16 日

--

相似性搜索 是一个问题,其中给定一个查询,目标是找到与之最相似的文档,这些文档位于所有数据库文档中。

介绍

在数据科学中,相似性搜索通常出现在自然语言处理领域、搜索引擎或推荐系统中,这些系统需要为查询检索最相关的文档或项目。在海量数据中,有各种不同的方法可以提高搜索性能。

分层可导航小世界(HNSW)是一种用于近似邻居搜索的最先进算法。在背后,HNSW 构建了优化的图结构,使其与本系列文章前面讨论的其他方法大相径庭。

HNSW 的主要思想是构建一个图,使得任意一对顶点之间的路径可以在少量步骤内遍历。

一个著名的类比是著名的 六度分隔理论,与这种方法相关:

所有人彼此之间的社交联系最多为六层。

在深入探讨 HNSW 的内部工作之前,我们先讨论跳表和可导航小世界——HNSW 实现中使用的关键数据结构。

跳表

跳表 是一种概率数据结构,允许在排序列表中以 O(logn) 的平均时间复杂度插入和搜索元素。跳表由多个层次的链表构成。最低层包含所有元素的原始链表。当移动到更高的层级时,被跳过的元素数量增加,从而减少了连接数。

在跳表中找到元素 20

对于某个值的搜索程序从最高层开始,将其下一个元素与该值进行比较。如果值小于或等于该元素,则算法继续到下一个元素。否则,搜索程序降到连接更多的较低层,并重复相同的过程。最后,算法降到最低层并找到所需的节点。

根据 维基百科 的信息,跳表有一个主要参数 p,它定义了一个元素出现在多个列表中的概率。如果一个元素出现在层 i 中,则它出现在层 i + 1 的概率等于 pp 通常设置为 0.5 或 0.25)。平均而言,每个元素会出现在 1 / (1 — p) 个列表中。

正如我们所见,这个过程比普通的链表线性搜索要快得多。实际上,HNSW 继承了相同的思想,但它使用的是图而不是链表。

可导航的小世界

可导航的小世界 是一个具有多对数 T = O(logᵏn) 搜索复杂度的图,它使用贪心路由。路由 指的是从低度顶点开始搜索过程,并以高维度顶点结束。由于低度顶点的连接非常少,算法可以在它们之间迅速移动,从而高效地导航到可能存在最近邻的区域。然后,算法逐渐放大并切换到高维度顶点,以在该区域的顶点中找到最近邻。

顶点有时也被称为节点

搜索

首先,通过选择一个入口点进行搜索。为了确定算法下一步移动的顶点(或顶点),它计算查询向量到当前顶点邻居的距离,并移动到最近的一个。某些时候,当算法找不到比当前节点更靠近查询的邻居节点时,它会终止搜索过程。这个节点被返回作为查询的响应。

可导航的小世界中的贪心搜索过程。节点 A 被用作入口点。它有两个邻居 B 和 D。节点 D 比 B 更接近查询。因此,我们移动到 D。节点 D 有三个邻居 C、E 和 F。E 是距离查询最近的邻居,所以我们移动到 E。最终,搜索过程将导致节点 L。由于 L 的所有邻居都比 L 本身离查询更远,我们停止算法,并将 L 作为查询的答案返回。

这种贪心策略不能保证找到确切的最近邻,因为该方法仅使用当前步骤的局部信息来做出决策。早期停止 是该算法的问题之一。特别是在搜索过程的开始阶段,当没有比当前节点更好的邻居节点时,早期停止现象尤为明显。在大多数情况下,这可能发生在起始区域有太多低度顶点时。

早期停止。当前节点的两个邻居都比查询更远。因此,算法返回当前节点作为响应,尽管存在距离查询更近的节点。

可以通过使用多个入口点来提高搜索精度。

构建

NSW 图是通过打乱数据集点并逐个将它们插入当前图中来构建的。当插入一个新节点时,它会通过边连接到其M个最近的顶点。

节点的顺序插入(从左到右),M = 2。在每次迭代中,向图中添加一个新顶点,并将其链接到其 M = 2 个最近邻居。蓝线表示连接到新插入节点的边。

在大多数情况下,长距离边缘可能会在图构建的初始阶段创建。它们在图导航中扮演着重要角色。

在构建开始时插入的元素的最近邻链接随后变成了网络中心之间的桥梁,这些桥梁保持了整个图的连通性,并允许在贪婪路由过程中对跳数进行对数缩放。 — Yu. A. Malkov, D. A. Yashunin

从上图中的示例可以看出,在开始时添加的长距离边缘AB的重要性。设想一个查询需要遍历从相对远离的节点AI的路径。拥有边缘AB允许通过直接从图的一侧导航到另一侧来快速完成这个过程。

随着图中顶点数量的增加,新连接到新节点的边的长度变短的概率也增加。

HNSW

HNSW 基于与跳表和可导航小世界相同的原理。它的结构表现为一个多层次的图,其中顶部层次的连接较少,而底部层次的区域则更为密集。

搜索

搜索从最高层开始,每次在层节点中贪婪地找到局部最近邻,然后逐层向下。最终,找到的最低层上的最近邻即为查询的答案。

HNSW 中的搜索

类似于 NSW,通过使用多个入口点可以提高 HNSW 的搜索质量。与其在每层上仅找到一个最近邻,不如使用efSearch(一个超参数)找到与查询向量最接近的最近邻,并将每个邻居作为下一层的入口点。

复杂度

原始论文的作者声称,在任何层上查找最近邻所需的操作数都由一个常数限制。考虑到图中的所有层数是对数级的,我们得到了总的搜索复杂度,即O(logn)

构建

选择最大层

节点在 HNSW 中是一个接一个地顺序插入的。每个节点会随机分配一个整数l,表示该节点可以出现在图中的最大层。例如,如果l = 1,则该节点只能在第 0 层和第 1 层找到。作者为每个节点随机选择l,其指数衰减概率分布由非零乘数mL(mL = 0 结果是 HNSW 中的单层和非优化的搜索复杂度)进行归一化。通常,大多数l值应该等于 0,因此大多数节点仅存在于最低层。较大的mL值增加了节点出现在更高层的概率。

每个节点的层数 l 是根据指数衰减概率分布随机选择的。

基于标准化因子mL的层数分布。横轴表示均匀分布(0, 1)的值。

为了实现可控层次结构的最佳性能优势,不同层之间的邻居重叠(即也属于其他层的元素邻居的百分比)必须很小。 — Yu. A. Malkov, D. A. Yashunin。

减少重叠的一个方法是减小mL。但重要的是要记住,减少mL通常会导致在每层贪婪搜索过程中需要更多的遍历。因此,选择一个能够平衡重叠和遍历次数的mL值至关重要。

论文的作者建议选择mL的最佳值,即1 / ln(M)。该值对应于跳表的参数p = 1 / M,它是层间的平均单元素重叠。

插入

节点被分配l值后,有两个插入阶段:

  1. 算法从上层开始,贪婪地找到最近的节点。找到的节点随后被用作下一层的入口点,搜索过程继续。一旦达到层l,插入过程就进入第二步。

  2. 从层l开始,算法在当前层插入新节点。然后,它像之前一样执行第 1 步,但不是仅找到一个最近邻,而是贪婪地搜索efConstruction(超参数)个最近邻。然后从efConstruction个邻居中选择M个,并建立从插入节点到它们的边。之后,算法下降到下一层,每个找到的efConstruction节点作为入口点。算法在新节点及其边被插入到最低层 0 后终止。

在 HNSW 中插入一个节点(蓝色)。新节点的最大层随机选择为 l = 2。因此,节点将被插入到层 2、1 和 0。在每一层,节点将连接到其 M = 2 个最近邻。

选择构造参数的值

原始论文提供了如何选择超参数的几个有用见解:

  • 根据模拟,M的良好值在 5 到 48 之间。较小的M值适合较低的召回率或低维数据,而较大的 M 值则更适合较高的召回率或高维数据。

  • 更高的efConstruction值意味着更深层次的搜索,因为会探索更多的候选项。然而,这也需要更多的计算。作者建议选择一个efConstruction值,以便在训练过程中回忆接近0.95–1

  • 另外,还有一个重要的参数 Mₘₐₓ — 一个顶点可以拥有的最大边数。除此之外,还存在一个相同的参数 Mₘₐₓ₀,但仅针对最低层。建议选择一个接近 2 * MMₘₐₓ 值。大于 2 * M 的值可能会导致性能下降和过度的内存使用。同时,Mₘₐₓ = M 会导致高召回率下性能差。

候选选择启发式

上面提到,在节点插入过程中,从 efConstruction 候选节点中选择 M 个来建立边。让我们讨论选择这些 M 个节点的可能方法。

天真的方法选取 M 个最近的候选节点。然而,这并不总是最优选择。下面是一个演示这个问题的例子。

想象一个如下面图所示的图结构。正如你所见,图中有三个区域,其中两个区域彼此没有连接(在左侧和顶部)。因此,例如,从点 AB 需要通过另一个区域经过很长的路径。为了更好的导航,将这两个区域以某种方式连接起来是合乎逻辑的。

节点 X 被插入到图中。目标是将其最优地连接到其他 M = 2 个点。

然后一个节点 X 被插入到图中,并且需要连接到 M = 2 个其他顶点。

在这种情况下,天真的方法直接选择 M = 2 个最近的邻居(BC),并将 X 连接到它们。尽管 X 已经连接到其真实的最近邻居,但这并没有解决问题。让我们来看一下作者们发明的启发式方法。

启发式算法不仅考虑节点之间的最近距离,还考虑图中不同区域的连通性。

启发式算法选择第一个最近的邻居(在我们的例子中是 B)并将插入的节点 (X) 连接到它。然后算法按照排序的顺序逐个选择下一个最接近的邻居 (C),并仅当该邻居到新节点 (X) 的距离小于该邻居到所有已经连接的顶点 (B) 到新节点 (X) 的距离时,才建立一条边。之后,算法继续处理下一个最近的邻居,直到建立 M 条边。

回到例子,启发式过程如下面的图所示。启发式算法选择 B 作为 X 的最近邻居,并建立了边 BX。然后算法选择 C 作为下一个最近邻居。然而,这次 BC < CX。这表明将边 CX 添加到图中并不是最优的,因为已经存在边 BX,且节点 BC 非常接近。相同的类比适用于节点 DE。之后,算法检查节点 A。这一次,它满足条件,因为 BA > AX。因此,新边 AX 和两个初始区域变得互相连接。

左侧的示例使用了简单的方法。右侧的示例使用了选择启发式,使两个初始不相交的区域相互连接。

复杂度

插入过程与搜索过程非常相似,没有显著的差异需要非恒定数量的操作。因此,单个顶点的插入需要 O(logn) 的时间。要估计总复杂度,应该考虑给定数据集中的所有插入节点 n。最终,HNSW 构建需要 O(n * logn) 时间。

将 HNSW 与其他方法结合使用

HNSW 可以与其他相似性搜索方法结合使用,以提供更好的性能。最常见的方法之一是将其与倒排文件索引和产品量化(IndexIVFPQ)结合使用,这在本系列文章的其他部分中已有描述。

[## 相似性搜索,第三部分:融合倒排文件索引和产品量化

在本系列的前两部分中,我们讨论了信息检索中的两个基本算法:倒排……

medium.com](https://medium.com/@slavahead/similarity-search-blending-inverted-file-index-and-product-quantization-a8e508c765fa?source=post_page-----2aad4fe87d37--------------------------------)

在这个范式中,HNSW 充当粗量化器的角色,负责找到最近的 Voronoi 划分,从而可以缩小搜索范围。为此,必须在所有 Voronoi 质心上构建 HNSW 索引。给定查询时,使用 HNSW 找到最近的 Voronoi 质心(而不是之前通过比较每个质心的距离进行的暴力搜索)。之后,查询向量在相应的 Voronoi 划分中被量化,并通过 PQ 代码计算距离。

通过在 Voronoi 质心上建立的 HNSW 中找到最近邻,选择最接近的 Voronoi 质心。

当仅使用倒排文件索引时,最好将 Voronoi 划分的数量设置得不太大(例如 256 或 1024),因为会执行暴力搜索以找到最近的质心。通过选择较少的 Voronoi 划分,划分内的候选项数量变得相对较大。因此,算法迅速识别查询的最近质心,并且大部分运行时间集中在 Voronoi 划分内找到最近邻上。

然而,将 HNSW 引入工作流需要调整。考虑仅在少量质心(256 或 1024)上运行 HNSW:由于质心数量较少,HNSW 在执行时间上与简单的暴力搜索相对相同,因此不会带来显著的好处。此外,HNSW 需要更多的内存来存储图结构。

这就是为什么在合并 HNSW 和倒排文件索引时,建议将 Voronoi 质心的数量设置得比平时大得多。这样,每个 Voronoi 分区内的候选者数量会大大减少。

这种范式的转变导致了以下设置:

  • HNSW 以对数时间快速识别最近的 Voronoi 质心。

  • 之后,执行各自 Voronoi 分区内的穷举搜索。因为潜在候选者的数量较少,所以不应成为问题。

Faiss 实现

Faiss(Facebook AI 搜索相似性)是一个用 C++编写的 Python 库,用于优化相似性搜索。该库提供了不同类型的索引,这些索引是用于高效存储数据和执行查询的数据结构。

根据Faiss 文档的信息,我们将探讨如何将 HNSW 与倒排文件索引和乘积量化结合使用。

IndexHNSWFlat

FAISS 有一个类IndexHNSWFlat实现了 HNSW 结构。通常,“Flat”后缀表示数据集向量完全存储在索引中。构造函数接受 2 个参数:

  • d:数据维度。

  • M:在插入过程中需要添加到每个新节点的边的数量。

此外,通过hnsw字段,IndexHNSWFlat 提供了几个有用的属性(可以修改)和方法:

  • hnsw.efConstruction:构造时要探索的最近邻数量。

  • hnsw.efSearch:搜索时要探索的最近邻数量。

  • hnsw.max_level:返回最大层级。

  • hnsw.entry_point:返回入口点。

  • faiss.vector_to_array(index.hnsw.levels):返回每个向量的最大层级列表。

  • hnsw.set_default_probas(M: int, level_mult: float):允许分别设置MmL值。默认情况下,level_mult 设置为 1 / ln(M)

Faiss 实现的 IndexHNSWFlat

IndexHNSWFlatMₘₐₓ = MMₘₐₓ₀ = 2 * M 设置值。

IndexHNSWFlat + IndexIVFPQ

IndexHNSWFlat 也可以与其他索引结合使用。一个例子是前面部分描述的IndexIVFPQ。创建这个复合索引分两个步骤进行:

  1. IndexHNSWFlat 被初始化为粗量化器。

  2. 量化器作为参数传递给IndexIVFPQ的构造函数。

训练和添加可以使用不同或相同的数据完成。

FAISS 实现的 IndexHNSWFlat + IndexIVFPQ

结论

在这篇文章中,我们研究了一种强大的算法,该算法在处理大型数据集向量时表现尤为出色。通过使用多层图表示和候选选择启发式方法,其搜索速度在保持合理的预测准确性的同时得以高效扩展。值得注意的是,HNSW 还可以与其他相似性搜索算法结合使用,使其非常灵活。

资源

除非另有说明,否则所有图像均由作者提供。

相似性搜索,第五部分:局部敏感哈希(LSH)

原文:towardsdatascience.com/similarity-search-part-5-locality-sensitive-hashing-lsh-76ae4b388203?source=collection_archive---------0-----------------------#2023-06-24

探索如何将相似性信息融入哈希函数

Vyacheslav EfimovTowards Data Science Vyacheslav Efimov

·

关注 发表在 Towards Data Science ·10 分钟阅读·2023 年 6 月 24 日

--

相似性搜索 是一个问题,给定一个查询的目标是在所有数据库文档中找到与其最相似的文档。

介绍

在数据科学中,相似性搜索通常出现在自然语言处理(NLP)领域、搜索引擎或推荐系统中,其中需要为一个查询检索到最相关的文档或项目。存在多种不同的方法来提升在海量数据中的搜索性能。

在本系列文章的前几部分中,我们讨论了倒排文件索引、产品量化和 HNSW 以及它们如何结合使用以提高搜索质量。在本章中,我们将探讨一种主要不同的方法,这种方法既能保持高搜索速度,又能保证高质量。

## 相似性搜索,第三部分:融合倒排文件索引和产品量化

在本系列的前两部分中,我们讨论了信息检索中的两个基本算法:倒排…

towardsdatascience.com ## 相似性搜索,第四部分:分层可导航小世界(HNSW)

分层可导航小世界(HNSW)是一种最先进的算法,用于近似搜索最近的…

towardsdatascience.com

局部敏感哈希(LSH)是一组方法,用于通过将数据向量转换为哈希值来缩小搜索范围,同时保留有关其相似性的信息。

我们将讨论传统方法,该方法包括三个步骤:

  1. 切片:将原始文本编码成向量。

  2. MinHashing:将向量转换为一种称为 签名 的特殊表示形式,这种表示形式可以用于比较它们之间的相似性。

  3. LSH 函数:将签名块哈希到不同的桶中。如果一对向量的签名至少有一次落入同一个桶中,则它们被视为候选。

我们将逐步深入探讨这些步骤的细节。

切片

切片是从给定文本中收集 k-grams 的过程。k-gram 是一组 k 个顺序排列的标记。根据上下文,标记可以是单词或符号。切片的最终目标是通过使用收集到的 k-grams 来编码每个文档。我们将使用独热编码来完成这一点。然而,也可以应用其他编码方法。

收集句子“学习数据科学很有趣”的长度为 k = 3 的唯一切片

首先,收集每个文档的独特k-gram。其次,为了对每个文档进行编码,需要一个词汇表,它代表了所有文档中独特k-gram 的集合。然后,为每个文档创建一个长度等于词汇表大小的零向量。对于文档中出现的每个k-gram,确定其在词汇表中的位置,并在文档向量的相应位置放置一个“1”。即使相同的k-gram 在文档中出现多次也没关系:向量中的值始终为 1。

一热编码

MinHashing

在这个阶段,初始文本已经被向量化。可以通过Jaccard 指数比较向量的相似性。记住,Jaccard 指数定义为两个集合中共同元素的数量除以所有元素的总长度。

Jaccard 指数定义为两个集合的交集与并集之比

如果取一对编码向量,则 Jaccard 指数公式中的交集是两个都包含 1 的行数(即k-gram 在两个向量中都出现),并且并集是至少包含一个 1 的行数(k-gram 至少在一个向量中出现)。

Jaccard 指数公式

使用上述公式计算两个向量的 Jaccard 指数的示例

当前的问题是编码向量的稀疏性。计算两个一热编码向量之间的相似性得分将耗费大量时间。将它们转换为稠密格式可以使后续操作更高效。最终目标是设计一个将这些向量转换为较小维度的函数,同时保留它们的相似性信息。构建这样的函数的方法称为 MinHashing。

MinHashing 是一种哈希函数,它对输入向量的组件进行排列,然后返回排列向量组件等于 1 的第一个索引。

计算给定向量和排列的 minhash 值的示例

为了获得由n个数字组成的向量的稠密表示,可以使用n个 minhash 函数来获得n个 minhash 值,这些值构成一个签名

一开始可能不太明显,但可以使用多个 minhash 值来近似向量之间的 Jaccard 相似性。实际上,使用的 minhash 值越多,近似值就越准确。

计算签名矩阵及其如何用于计算向量之间的相似性。使用 Jaccard 相似性和签名计算的相似性应该通常大致相等。

这只是一个有用的观察。事实证明,背后有一个完整的定理。让我们来了解为什么 Jaccard 指数可以通过使用签名来计算。

陈述证明

假设给定的一对向量仅包含011011类型的行。然后对这些向量进行随机排列。由于所有行中至少存在一个 1,因此在计算两个哈希值时,这两个哈希值计算过程中的至少一个会在具有对应哈希值为 1 的向量的第一行停止。

第二个哈希值等于第一个的概率是多少?显然,只有当第二个哈希值也等于 1 时才会发生。这意味着第一行必须是11类型。由于排列是随机的,这种事件的概率等于P = count(11) / (count(01) + count(10) + count(11))。这个表达式与 Jaccard 指数公式完全相同。因此:

基于随机行排列,两个二进制向量获得相同哈希值的概率等于 Jaccard 指数

然而,通过证明上述陈述,我们假设初始向量不包含00类型的行。显然,00类型的行不会改变 Jaccard 指数的值。同样,包含00类型行时获得相同哈希值的概率不会影响它。例如,如果第一个排列行是 00,则 minhash 算法只是忽略它,转到下一行,直到找到至少一个 1。当然,00类型的行可能导致不同的哈希值,但获得相同哈希值的概率保持不变

我们已经证明了一个重要的陈述。但是,如何估计获得相同的 minhash 值的概率呢?当然,可以生成所有可能的向量排列,然后计算所有的 minhash 值以找到所需的概率。出于显而易见的原因,这种方法效率不高,因为一个大小为n的向量的可能排列数等于n!。不过,概率可以大致评估:我们可以使用很多哈希函数来生成大量的哈希值。

两个二进制向量的 Jaccard 指数大致等于它们签名中对应值的数量。

数学符号

很容易注意到,采用更长的签名会导致更准确的计算。

LSH 函数

目前,我们可以将原始文本转换为长度相等的密集签名,从而保留关于相似性的的信息。然而,在实践中,这些密集签名通常仍具有高维度,直接比较它们效率不高。

考虑到n = 10⁶ 个文档,每个文档的签名长度为 100. 假设一个签名的单个数字需要 4 字节来存储,那么整个签名将需要 400 字节。存储n = 10⁶ 个文档需要 400 MB 的空间,这在现实中是可行的。但以蛮力方式比较每个文档需要大约 5 * 10¹¹次比较,这太多了,尤其是当n 更大的时候。

为了避免这个问题,可以建立一个哈希表来加速搜索性能,但即使两个签名非常相似,仅在 1 个位置上有所不同,它们仍可能具有不同的哈希值(因为向量的余数可能不同)。然而,我们通常希望它们落入同一个桶中。这就是 LSH 派上用场的地方。

LSH机制构建一个哈希表,该表由几个部分组成,如果一对签名有至少一个对应的部分,它们就会被放入同一个桶中。

LSH 将签名矩阵水平分成相等的b部分,称为,每部分包含r 。而不是将整个签名插入到一个哈希函数中,签名被分成b部分,每个子签名由一个哈希函数独立处理。因此,每个子签名落入不同的桶中。

LSH 的示例。两个长度为 9 的签名被分成 b = 3 个带,每个带包含 r = 3 行。每个子向量被哈希到 k 个可能的桶之一。由于第二个带中存在匹配(两个子向量具有相同的哈希值),我们将这两个签名对视为最近邻候选。

如果两个不同签名的对应子向量之间至少有一个碰撞,那么这些签名被视为候选。如我们所见,这个条件更灵活,因为考虑向量作为候选者时,它们不需要完全相等。然而,这增加了假阳性的数量:一对不同的签名可能只有一个对应的部分,但总体上完全不同。根据问题的不同,优化参数brk 总是更好的。

错误率

使用 LSH,可以估计两个具有相似度s的签名被视为候选的概率,给定带数b和每个带中的行数r。让我们分几个步骤找到它的公式。

两个签名的任意一行相等的概率

一随机带有 r 行的概率相等

一随机带有 r 行的概率不同

表中所有 b 个带不同的概率

至少有一个 b 带相等的概率,即两个签名是候选的

注意,公式没有考虑当不同的子向量意外地哈希到同一个桶中时的碰撞。因此,签名成为候选的真实概率可能会略有不同。

示例

为了更好地理解我们刚刚得到的公式,我们考虑一个简单的例子。考虑两个长度为 35 符号的签名,它们被平均分成 5 个带,每个带有 7 行。以下表格表示了基于 Jaccard 相似度至少有一个相等带的概率:

基于相似度 s,至少获得一对签名具有对应带的概率 P

我们注意到,如果两个相似的签名具有 80% 的 Jaccard 相似度,那么在 93.8% 的情况下它们有一个对应带(true positives)。在剩余的 6.2% 情况下,这样的一对签名是 false negative

现在让我们考虑两个不同的签名。例如,它们的相似度只有 20%。因此,在 0.224% 的情况下,它们是 false positive 候选。在其他 99.776% 的情况下,它们没有相似的带,所以它们是 true negatives

可视化

现在让我们可视化相似度 s 和两个签名成为候选的概率 P 之间的关系。通常,随着签名相似度 s 的提高,签名成为候选的概率应当更高。理想情况下,情况如下:

理想的场景。只有当签名的相似度大于某个阈值 t 时,才认为一对签名是候选的

基于上述获得的概率公式,典型的线如下图所示:

一条典型的线在开始和结束时缓慢上升,并在图中所示的近似概率公式的阈值 t 处有一个陡峭的斜率

可以通过改变带的数量b,将图中的线向左或向右移动。增加 b 将线向左移动,并导致更多的 FP,减少则将其向右移动,导致更多的 FN。根据问题找到一个好的平衡点是很重要的。

带的数量增加,线会向左移动;减少则向右移动

将阈值向左移动会增加 FP,而向右移动则增加 FN

采用不同数量的带和行进行实验

以下为不同值的br构建的几条线图。根据具体任务调整这些参数通常更为有效,以成功检索所有相似文档对,并忽略那些具有不同签名的文档。

调整带的数量

调整行数

结论

我们已经讲解了 LSH 方法的经典实现。LSH 通过使用低维签名表示和快速哈希机制来优化搜索速度,从而减少候选项的搜索范围。同时,这也会影响搜索的准确性,但在实践中,差异通常微不足道。

然而,LSH 对高维数据比较敏感:更多维度需要更长的签名长度和更多计算来保持良好的搜索质量。在这种情况下,建议使用其他索引。

实际上,存在不同的 LSH 实现,但所有这些实现都基于将输入向量转换为哈希值的相同范式,同时保留关于它们相似性的信息。基本上,其他算法只是定义了获得这些哈希值的不同方式。

随机投影是另一种 LSH 方法,将在下一章中介绍,并且在Faiss库中实现为 LSH 索引,用于相似性搜索。

资源

所有图像除非另有说明,均由作者提供。

相似度搜索,第六部分:使用 LSH 森林的随机投影

原文:towardsdatascience.com/similarity-search-part-6-random-projections-with-lsh-forest-f2e9b31dcc47?source=collection_archive---------8-----------------------#2023-07-21

了解如何通过构造随机超平面对数据进行哈希,并反映其相似性

Vyacheslav EfimovTowards Data Science Vyacheslav Efimov

·

关注 发表在 Towards Data Science · 12 分钟阅读·2023 年 7 月 21 日

--

相似度搜索 是一个问题,其中给定一个查询的目标是从所有数据库文档中找到与其最相似的文档。

简介

在数据科学中,相似度搜索通常出现在 NLP 领域、搜索引擎或推荐系统中,在这些领域,需要为查询检索最相关的文档或项目。在处理大规模数据时,存在多种方法来提高搜索性能。

上一部分,我们探讨了局部敏感哈希(LSH)的主要范式,即将输入向量转换为低维哈希值,同时保留它们之间相似性的信息。为了获得哈希值(签名),使用了 minhash 函数。在本文中,我们将随机投影输入数据,以获取类似的二进制向量。

## 相似性搜索,第五部分:局部敏感哈希(LSH)

探索如何将相似性信息融入到哈希函数中

towardsdatascience.com

思路

考虑在高维空间中的一组点。可以构造一个随机超平面作为墙,将每个点分为两组:正组和负组。将“1”赋给正组中的每个点,将“0”赋给负组中的每个点。

超平面分隔两个 3D 空间点的示例

如何确定某个向量在超平面的一侧?通过使用内积!深入到线性代数的本质中,给定向量与超平面法向量之间点积的符号决定了该向量位于超平面的哪一侧。这样,每个数据集向量都可以被分隔到两个侧面之一。

通过计算向量与超平面法向量的内积,并与 0 进行比较,可以判断该向量相对于超平面的位置。

显然,使用一个二进制值对每个数据集向量进行编码是不够的。因此,应该构造多个随机超平面,以便每个向量可以根据其相对位置用多个 0 和 1 进行编码。如果两个向量的二进制代码完全相同,则表示没有任何构造的超平面能够将它们分离到不同的区域。因此,它们在现实中很可能非常接近。

要为给定的查询找到最近的邻居,只需通过检查其相对位置到所有超平面,使用 0 和 1 对查询进行编码即可。可以将查询找到的二进制向量与数据集中所有其他二进制向量进行比较。这可以通过使用汉明距离线性完成。

汉明距离指的是两个向量在其值不同的位置的数量。

计算汉明距离的示例。左侧的一对向量因为其汉明距离较小而彼此更为相似。

与查询的汉明距离最小的二进制向量被作为候选项,然后与初始查询进行全面比较。

为什么超平面是随机构建的?

在当前阶段,似乎有必要探讨为什么超平面是以随机方式构建而不是确定性方式,因而可以定义自定义规则来分隔数据集点。主要有两个原因:

  • 首先,确定性方法无法对算法进行泛化,可能导致过拟合。

  • 其次,随机性允许对算法性能做出概率性陈述,这不依赖于输入数据。对于确定性方法而言,这种方法无法实现,因为它在某些数据上表现良好,而在另一些数据上表现不佳。一个好的类比是确定性 快速排序 算法,它平均时间复杂度为 O(n * log n)。然而,在已排序的数组上,它的最坏情况时间复杂度为 O(n²)。如果有人了解算法的工作流程,那么可以利用这一信息有针对性地降低系统的效率,通过总是提供最坏的数据。这就是为什么随机化的快速排序更受欢迎。随机超平面也有类似的情况。

为什么 LSH 随机投影也被称为“树”?

随机投影方法有时被称为 LSH 树。这是因为哈希码分配的过程可以表示为决策树的形式,每个节点包含一个条件,判断向量是否位于当前超平面的负侧或正侧。

第一个节点检查向量相对于红色超平面的位置。第二层节点检查相对于绿色超平面的位置。最后,第三层检查相对于蓝色超平面的位置。根据这三个条件,向量被分配一个 3 位哈希值。

超平面森林

超平面是随机构建的。这可能导致它们无法有效分隔数据集点,如下图所示。

构建了 4 个超平面来将数据集点表示为 4 长度的二进制向量。即使点 D 和 E 具有相同的哈希码,它们之间的距离仍然相对较远(FP)。相反,点 E 和 F 的情况则是它们位于不同的区域,但彼此非常接近(FN)。考虑到汉明距离,算法通常预测点 D 更接近点 E 而不是点 F。

从技术上讲,当两个点具有相同的哈希代码但彼此距离较远时,这并不是什么大问题。在算法的下一步中,这些点将作为候选项进行完全比较 — 这样算法可以消除 假阳性 情况。假阴性 更复杂:当两个点具有不同的哈希代码但实际上彼此接近时。

为什么不使用经典机器学习中的决策树方法,这些决策树被组合成随机森林以提高整体预测质量?如果一个估计器出现错误,其他估计器可以产生更好的预测,减轻最终预测误差。利用这个想法,构建随机超平面的过程可以独立重复。得到的哈希值可以像上一章中 minhash 值一样,按对向量的方式进行聚合:

如果查询与另一个向量至少有一次相同的哈希代码,则它们被视为候选项

使用这种机制可以减少 假阴性 的数量。

质量与速度的权衡

选择适当数量的超平面以对数据集进行划分非常重要。选择的超平面越多,数据点之间的碰撞越少,但计算哈希代码所需的时间越长,存储它们所需的内存也越多。具体而言,如果数据集由 n 个向量组成,我们用 k 个超平面进行划分,则平均每个可能的哈希代码将被分配给 n / 2ᵏ 个向量。

k = 3 结果是 2³ = 8 个桶

复杂度

训练

LSH Forest 训练阶段分为两个部分:

  1. 生成 k 个超平面。这是一个相对快速的过程,因为在 d 维空间中生成一个超平面所需的时间为 O(d)

  2. 为所有数据集向量分配哈希代码。此步骤可能需要时间,尤其是对于大型数据集。获得单个哈希代码需要 O(dk) 的时间。如果数据集由 n 个向量组成,则总复杂度变为 O(ndk)

上述过程对森林中的每棵树重复多次。

训练复杂度

推断

LSH forest 的一个优点是其快速推断,包括两个步骤:

  1. 获取查询的哈希代码。这相当于计算 k 个标量乘积,复杂度为 O(dk) (d — 维度)。

  2. 查找最近邻,在同一桶内(具有相同哈希代码的向量)通过计算与候选项的精确距离。距离计算线性进行,复杂度为 O(d)。每个桶平均包含 n / 2ᵏ 个向量。因此,计算所有潜在候选项的距离需要 O(dn / 2ᵏ) 的时间。

总复杂度为 O(dk + dn / 2ᵏ)

和往常一样,上述过程对森林中的每棵树重复多次。

推理复杂度

当超平面数量 k 选择为 n ~ 2ᵏ(在大多数情况下是可能的),则总的推理复杂度为 O(ldk)(l 是树的数量)*。基本上 这意味着 计算时间不依赖于数据集的大小! 这种微妙之处使得对数百万甚至数十亿个向量的相似性搜索具有高效的可扩展性。

错误率

在前面的 LSH 文章部分中,我们讨论了如何根据签名相似性找到两个向量被选为候选的概率。在这里,我们将使用几乎相同的逻辑来寻找 LSH 森林的公式。

设 s 为两个向量的哈希值在相同位置上具有相同比特位的概率(s 将在后面估计)

两个向量的长度为 k 的哈希码相等的概率

两个向量的长度为 k 的哈希码不同(或至少有一个比特位不同)的概率

两个向量的所有 l 个哈希码(用于 l 个超平面)不同的概率

至少有一个 l 个哈希码相等的概率,这样向量将成为候选

到目前为止,我们几乎获得了估计两个向量成为候选的概率的公式。剩下的唯一任务是估计方程中变量 s 的值。在经典的 LSH 算法中,s 等于两个向量的 Jaccard 指数或签名相似性。另一方面,为了估计 LSH 森林中的 s,将使用线性代数理论。

说实话,s 是两个向量 ab 具有相同比特位的概率。这个概率等同于一个随机超平面将这些向量分到同一侧的概率。让我们可视化一下:

向量 a 和 b 被蓝色超平面分开。绿色超平面没有将它们分开。

从图中可以看出,只有当超平面穿过它们之间时,才会将向量 ab 分到两个不同的侧面。这样的概率 q 与向量之间的角度成正比,可以很容易地计算:

随机超平面将两个向量分开的概率(即,它们具有不同的比特位)

随机超平面不将两个向量分开的概率(即,它们具有相同的比特位)

将此方程代入之前获得的方程中,可以得到最终公式:

基于超平面数量 k 和 LSH 树的数量 l,两个向量至少有一个对应的哈希值(即成为候选者)的概率

可视化

注意。余弦相似度在正式定义范围[-1, 1]内。为了简便起见,我们将这个区间映射到[0, 1],其中 0 和 1 分别表示最低和最高的相似度。

使用最后得到的公式,让我们可视化不同超平面 k 和树 l 数量下的两个向量成为候选者的概率。

调整树的数量 l

调整超平面数量 k

根据图表,可以得出几个有用的观察结果:

  • 余弦相似度为 1 的一对向量总是成为候选者。

  • 余弦相似度为 0 的一对向量从不会成为候选者。

  • 当超平面数量 k 减少或 LSH 树的数量 l 增加时,两向量成为候选者的概率 P 会增加(即更多的 假阳性)。逆命题也成立。

总结一下,LSH 是一种非常灵活的算法:可以根据给定问题调整不同的 kl 值,以获得符合问题要求的概率曲线。

示例

我们来看以下例子。假设构建了 l = 5 棵树,并且使用了 k = 10 个超平面。此外,还有两个余弦相似度为 0.8 的向量。在大多数系统中,这种余弦相似度表明向量彼此非常接近。然而,根据前面的结果,这个概率仅为 2.5%! 显然,对于如此高的余弦相似度,这个结果非常低。使用这些参数 l = 5k = 10 会产生大量的假阴性! 下面的绿色线条表示这种情况下的概率。

基于两个向量的余弦相似度的概率曲线

这个问题可以通过调整更好的 kl 的值来解决,使曲线向左移动。

例如,如果 k 减少到 3(红线),那么相同的余弦相似度为 0.8 的概率将达到 68%,这比之前要好。乍一看,红线似乎比绿线更合适,但需要注意的是,使用小的 k 值(如红线情况)会导致大量的碰撞。因此,有时调整第二个参数,即树的数量 l 更为可取。

k不同,通常需要非常多的树l才能获得类似的曲线形状。在图中,蓝线是通过将l的值从 10 更改为 500 得到的。蓝线明显比绿线更适合,但仍然远未完美:因为在 0.6 到 0.8 的余弦相似度值之间斜率很高,所以在 0.3-0.5 的余弦相似度附近概率几乎为 0,这不利于效果。实际中,0.3–0.5 的文档相似度的小概率通常应该更高。

根据最后一个例子,很明显,即使是非常多的树(需要大量计算)仍然会产生许多假阴性! 这就是随机投影方法的主要缺点:

尽管理论上可以得到完美的概率曲线,但这要么需要大量计算,要么会导致许多碰撞。否则,它会导致较高的假阴性率。

Faiss 实现

Faiss(Facebook AI 搜索相似性)是一个用 C++编写的 Python 库,用于优化相似性搜索。该库提供了不同类型的索引,这些索引是用于高效存储数据和执行查询的数据结构。

根据Faiss 文档的信息,我们将了解如何构建 LSH 索引。

随机投影算法在 Faiss 中通过IndexLSH类实现。尽管 Faiss 作者使用了一种略有不同的技术称为“随机旋转”,但它与本文所描述的仍有相似之处。该类只实现了一个 LSH 树。如果我们想使用 LSH 森林,只需创建几个 LSH 树并汇总它们的结果即可。

IndexLSH类的构造函数接受两个参数:

  • d: 维度的数量

  • nbits: 编码单个向量所需的位数(可能的桶数等于2ⁿᵇᶦᵗˢ

search()方法返回的距离是查询向量的汉明距离。

Faiss 对 IndexLSH 的实现

此外,Faiss 允许通过调用faiss.vector_to_array(index.codes)方法检查每个数据集向量的编码哈希值。

由于每个数据集向量由nbits个二进制值编码,存储单个向量所需的字节数等于:

Johnson-Lindenstrauss 引理

Johnson-Lindenstrauss 引理是一个与降维相关的精彩引理。虽然可能很难完全理解其原始陈述,但可以用简单的话来表述:

选择一个随机子集并将原始数据投影到该子集上,可以保持点之间的对应对距离。

更准确地说,拥有一个 n 个点的数据集,可以在一个 O(logn) 维的新空间中表示这些点,从而几乎保持点之间的相对距离。如果一个向量在 LSH 方法中由 ~logn 个二进制值编码,则可以应用该引理。此外,LSH 随机创建超平面,正如引理所要求的那样。

约翰逊-林登斯特劳斯引理的另一个令人惊叹的地方是 新数据集的维数不依赖于原始数据集的维数!实际上,这个引理在非常小的维度下效果不佳。

结论

我们已经介绍了一种用于相似性搜索的强大算法。该算法基于通过随机超平面分隔点的简单思路,通常在大数据集上表现良好并且具有很好的可扩展性。此外,它通过允许选择适当数量的超平面和树,提供了良好的灵活性。

约翰逊-林登斯特劳斯引理的理论结果加强了随机投影方法的使用。

资源

除非另有说明,否则所有图像均为作者提供。

相似性搜索,第七部分:LSH 组合

原文:towardsdatascience.com/similarity-search-part-7-lsh-compositions-1b2ae8239aca?source=collection_archive---------6-----------------------#2023-07-24

探索 LSH 函数组合以确保更可靠的搜索

Vyacheslav EfimovTowards Data Science Vyacheslav Efimov

·

关注 发表在 Towards Data Science · 11 分钟阅读 · 2023 年 7 月 24 日

--

相似性搜索 是一个问题,目标是在所有数据库文档中找到与给定查询最相似的文档。

介绍

在数据科学中,相似性搜索通常出现在自然语言处理领域、搜索引擎或推荐系统中,在这些场景下需要检索与查询最相关的文档或项目。在海量数据中提升搜索性能有许多不同的方法。

在本系列文章的最后两部分,我们深入探讨了 LSH——一种将输入向量转换为低维哈希值的算法,同时保持其相似性信息。特别是,我们已经研究了两种适用于不同距离度量的算法:

相似性搜索,第五部分:局部敏感哈希(LSH)

探索如何将相似性信息纳入哈希函数

towardsdatascience.com](/similarity-search-part-5-locality-sensitive-hashing-lsh-76ae4b388203?source=post_page-----1b2ae8239aca--------------------------------)

经典的 LSH 算法构建的签名反映了Jaccard 指数的信息。

相似性搜索,第六部分:使用 LSH 森林的随机投影

理解如何通过构建随机超平面对数据进行哈希并反映其相似性

towardsdatascience.com](/similarity-search-part-6-random-projections-with-lsh-forest-f2e9b31dcc47?source=post_page-----1b2ae8239aca--------------------------------)

随机投影方法构建了一个保持余弦相似性的超平面森林。

实际上,LSH 算法也存在于其他距离度量中。虽然每种方法都有其独特之处,但它们之间有许多共同的概念和公式。为了方便将来学习新方法,我们将更多地关注理论,并提供一些在高级 LSH 文献中经常出现的基本定义和定理。到文章末尾,我们将能够通过简单地将基本方法结合起来,像搭积木一样构建更复杂的 LSH 方案。

另外,我们将在最后了解如何将欧几里得距离纳入局部敏感哈希(LSH)中。

注意。作为主要前提,预计你已经熟悉本系列文章的第 5 和第六部分。如果没有,强烈建议你先阅读它们。

注意余弦距离 在 [0, 2] 范围内正式定义。为了简化起见,我们将其映射到 [0, 1] 区间,其中 0 和 1 分别表示最低和最高可能的相似性。

正式 LSH 定义

给定距离度量 d,H 被称为 (d₁, d₂, p₁, p₂)-敏感的 LSH 函数,如果对于随机选择的对象 x 和 y,满足以下条件:

  • 如果 d(x, y) ≤ d₁,则 p(H(x) = H(y)) ≥ p₁,即 H(x) = H(y) 的概率至少为 p₁。

  • 如果 d(x, y) ≥ d₂,则 p(H(x) = H(y)) ≤ p₂,即 H(x) = H(y) 的概率至多为 p₂。

让我们了解这些陈述的含义。当两个向量相似时,它们之间的距离很小。基本上,第一个陈述确保将它们哈希到同一个桶中的概率高于某个阈值。这样,一些 假阴性 被消除:如果两个向量之间的距离大于 d₁,那么它们被哈希到同一个桶中的概率总是小于 p₁。相反,第二个陈述控制 假阳性:如果两个向量不相似且它们之间的距离大于 d₂,则它们出现在同一个桶中的上限概率为 p₂

根据上述陈述,我们通常希望系统中的以下陈述得到满足:

  • p₁ 应尽可能接近 1,以减少 假阴性 的数量。

  • p₂ 应尽可能接近 0,以减少 假阳性 的数量。

  • d₁d₂ 之间的间隙应尽可能小,以减少无法对数据进行概率估计的区间。

左侧的图示展示了 LSH 参数(d₁, d₂, p₁, p₂)记号的典型曲线。右侧的曲线展示了一个理想的情况,其中阈值 d₁d₂ 之间没有间隙。

有时,上述陈述使用相似度 s 而不是距离 d 来引入:

给定一个相似度度量 s,H 被称为 (s₁, s₂, p₁, p₂)-敏感 LSH 函数,如果对于随机选择的对象 x 和 y,满足以下条件:

  • 如果 s(x, y) ≥ s₁,则 p(H(x) = H(y)) ≥ p₁,即 H(x) = H(y) 的概率至少为 p₁。

  • 如果 s(x, y) ≤ s₂,则 p(H(x) = H(y)) ≤ p₂,即 H(x) = H(y) 的概率至多为 p₂。

左侧的图示展示了 LSH 参数(s₁, s₂, p₁, p₂)的关系的典型曲线。右侧的曲线展示了一个理想的情况,其中阈值 s₁s₂ 之间没有间隙。

注意:在本文中,将使用两种记号((d₁, d₂, p₁, p₂)(s₁, s₂, p₁, p₂))。根据文本中使用的记号字母,应该可以清楚地知道是隐含距离 d 还是相似度 s。默认情况下,使用记号 (d₁, d₂, p₁, p₂)

LSH 示例

为了使问题更清楚,我们证明以下陈述:

如果距离度量 s 是 Jaccard 指数,那么 H 是一个 (0.6, 0.6, 0.4, 0.4)-敏感 LSH 函数。基本上,需要证明等效陈述:

  • 如果 d(x, y) ≤ 0.6,则 p(H(x) = H(y)) ≥ 0.4

  • 如果 d(x, y) ≥ 0.6,则 p(H(x) = H(y)) ≤ 0.4

从这篇文章系列的第五部分我们知道,两个二进制向量获得相等哈希值的概率等于 Jaccard 相似度。因此,如果两个向量至少相似 40%,那么获得相等哈希值的概率也至少为 40%。与此同时,至少 40%的 Jaccard 相似度等同于最多 60%的 Jaccard 指数。因此,第一个陈述得到了证明。第二个陈述可以做类似的推理。

这个例子可以概括为定理:

定理。如果 d 是 Jaccard 指数,则 H 是一个(d₁, d₂, 1 — d₁, 1 — d₂)的 LSH 函数族。

类似地,基于第六部分获得的结果,可以证明另一个定理:

定理。如果 s 是余弦相似度(介于-1 和 1 之间),则 H 是一个(s₁, s₂, 1 — arccos(s₁) / 180, 1 — arccos(d₂) / 180)的 LSH 函数族。

结合 LSH 函数

让我们参考在之前的 LSH 部分中学到的有用概念:

  • 回到第五部分的 minhashing,每个向量被分成若干个带,每个带包含一组行。为了将一对向量视为候选对,必须存在至少一个带,其中所有的向量行都是相等的。

  • 关于第六部分的随机投影,只有在存在至少一个树,其中所有的随机投影未能分开初始向量时,两个向量才被视为候选。

我们可以注意到,这两种方法在底层有相似的范式。只有当至少有一次n配置中向量的哈希值全部相同k次时,它们才会被视为候选对。用布尔代数表示,可以写成这样:

基于这个例子,我们引入逻辑运算符ORAND,它们允许聚合一组哈希函数。然后我们将估计它们如何影响两个向量成为候选的输出概率以及假阴性假阳性错误的率。

AND 运算符

给定 n 个独立的 LSH 函数 H₁, H₂, … Hₙ,AND运算符仅当两个向量的所有 n 个对应哈希值相等时,才会将它们视为候选对。否则,向量不会被视为候选。

如果两个高度不同的向量的哈希值通过AND运算符进行聚合,那么它们成为候选的概率会随着使用的哈希函数数量的增加而减少。因此,假阳性的数量也会减少。

同时,两类似的向量可能由于偶然原因而产生一对不同的哈希值。因此,算法不会将这些向量视为相似。这个方面导致了较高的假阴性率。

定理。考虑 r 个独立的(s₁, s₂, p₁, p₂)-敏感 LSH 函数。将这些 r 个 LSH 函数与 AND 操作符结合会得到一个具有以下参数的新 LSH 函数

通过使用多个独立事件的概率公式来证明这一点是容易的,该公式将所有事件的概率相乘以估计所有事件发生的概率。

OR 操作符

给定 n 个独立的 LSH 函数 H₁、H₂、… Hₙ,OR操作符仅在至少一个对应哈希值相等时将两个向量视为候选对。否则,这些向量不被视为候选。

AND操作符相反,OR操作符增加了任何两个向量成为候选的概率。对于任何向量对,只需对应哈希值中的一个相等即可。因此,OR 聚合减少了假阴性的数量,增加了假阳性

定理。考虑b个独立的(d₁, d₂, p₁, p₂)-族 LSH 函数。将这些b个 LSH 函数与 AND 操作符结合会得到一个具有以下参数的新 LSH 函数

我们不会证明这个定理,因为类似的概率公式已经在本文系列的第五部分中获得和解释。

组合

通过ANDOR操作,可以以各种方式将它们结合在一起,以更好地控制假阳性假阴性率。假设有r个 LSH 函数由AND组合器使用,b个 LSH 函数由OR组合器使用。可以使用这些基本组合器构建两种不同的组合:

AND-OR 和 OR-AND 是可以通过使用 AND 和 OR 操作符构建的两种类型的组合。

前两篇文章中描述的算法使用了AND-OR组合。实际上,构建基于ANDOR操作的更复杂的组合并没有什么阻碍。

组合示例

让我们研究一个示例,以了解ANDOR的组合如何显著提高性能。假设一个OR-AND组合,其参数为b = 4r = 8。根据上述公式,我们可以估计两个向量成为候选的初始概率在组合后的变化:

应用参数为 b = 4 和 r = 8 的 OR-AND 组合后的概率变化。第一行显示初始概率,第二行显示转换后的概率。

例如,如果对于两个向量之间的某个相似度值,单个 LSH 函数在 40%的情况下将它们哈希到相同的桶中,那么在OR-AND组合后,它们将在 32.9%的情况下被哈希。

为了理解组合的特殊性,考虑一个(0.4, 1.7, 0.8, 0.2)敏感的 LSH 函数。经过OR-AND变换后,LSH 函数转变为(0.4, 1.7, 0.0148, 0.987)敏感格式。

从本质上讲,如果最初两个向量非常相似且距离小于 0.4,那么它们将被认为是 80%情况下的候选。然而,应用组合后,它们现在在 98.7%的情况下被认为是候选,从而大大减少了假阴性错误!

类似地,如果两个向量彼此差异很大且距离大于 1.7,那么它们现在只会在 1.48%的情况下被认为是候选(相较于之前的 20%)。这样,假阳性错误的频率降低了 13.5 倍!这是一项巨大的改进!

曲线显示了在不同组合后初始概率的转变

通常,通过具有(d₁, d₂, p₁, p₂)敏感性的 LSH 函数,可以将其转换为(d₁, d₂, p’₁, p’₂)格式,其中p’₁接近 1,而p’₂接近 0。要使p’₁p’₂更接近 1 和 0,通常需要使用更多的组合。

用于其他距离度量的 LSH

我们已经深入研究了用于保留 Jaccard 指数和余弦距离信息的 LSH 方案。自然会产生一个问题,那就是是否可以使用 LSH 来处理其他距离度量。不幸的是,对于大多数度量,没有相应的 LSH 算法。

尽管如此,欧几里得距离的 LSH 方案确实存在——这是机器学习中最常用的度量之一。由于它被广泛使用,我们将研究如何获取欧几里得距离的哈希值。通过上述引入的理论符号,我们将证明这一度量的一个重要 LSH 属性。

用于欧几里得距离的局部敏感哈希(LSH)

欧几里得空间中点的哈希机制包括将它们投影到随机线上。算法假设

  • 如果两个点相对接近,那么它们的投影也应该相对接近。

  • 如果两个点彼此相距较远,那么它们的投影也应该相距较远。

为了测量两个投影的接近程度,可以将一条线分成若干个大小为a的相等段(桶)。每个线段对应一个特定的哈希值。如果两个点投影到相同的线段上,那么它们具有相同的哈希值。否则,哈希值不同。

在随机线上投影点

尽管这种方法起初可能看起来很可靠,但它仍然可能将相隔较远的点投影到相同的段中。特别是在连接两个点的线几乎与初始投影线垂直时,这种情况尤其明显。

尽管两个点相对较远,但它们仍有可能被哈希到同一个桶中。

为了降低错误率,强烈建议使用随机投影线的组合,如上所述。

从几何上讲,可以证明如果 a 是欧几里得空间中单个线段的长度,则 H(a / 2, 2a, ½, ⅓)-敏感 LSH 函数。

结论

在本章中,我们积累了关于一般 LSH 符号的知识,这帮助我们正式引入了组合操作,使我们显著降低了错误率。值得注意的是,LSH 只存在于机器学习指标的少数部分,但至少对于最流行的指标,如欧几里得距离、余弦距离和 Jaccard 指数,LSH 是存在的。在处理其他度量向量相似性时,建议选择另一种相似性搜索方法。

作为参考,本文中介绍的陈述的正式证明可以在这些讲义中找到。

资源

除非另有说明,所有图像均由作者提供。

相似性搜索,第二部分:产品量化

原文:towardsdatascience.com/similarity-search-product-quantization-b2a1a6397701?source=collection_archive---------1-----------------------#2023-05-10

学习一种强大的技术来有效地压缩大量数据

Vyacheslav EfimovTowards Data Science Vyacheslav Efimov

·

关注 发布于 Towards Data Science ·9 分钟阅读·2023 年 5 月 10 日

--

相似性搜索是一个问题,其中给定一个查询,目标是找到与其最相似的数据库文档。

介绍

在数据科学中,相似性搜索常出现在自然语言处理(NLP)领域、搜索引擎或推荐系统中,需要为查询检索最相关的文档或项。存在多种方法来提升在大量数据中的搜索性能。

在这篇文章系列的第一部分中,我们研究了 kNN 和倒排文件索引结构来执行相似性搜索。正如我们所了解的,kNN 是最直接的方法,而倒排文件索引则在其之上运行,建议在速度加速和准确性之间进行权衡。然而,这两种方法都不使用数据压缩技术,这可能会导致内存问题,特别是在大数据集和有限 RAM 的情况下。在本文中,我们将尝试通过另一种方法来解决这个问题,这种方法被称为 产品量化

## 相似性搜索,第一部分:kNN 和倒排文件索引

相似性搜索是一个流行的问题,其中给定查询 Q,我们需要在所有文档中找到最相似的文档…

[towardsdatascience.com

定义

产品量化 是一个过程,其中每个数据集向量被转换成一种简短的内存高效表示(称为 PQ 代码)。与其完全保存所有向量,不如存储它们的简短表示。同时,产品量化是一种有损压缩方法,会导致预测准确性降低,但在实际应用中,该算法效果很好。

一般来说,量化是将无限值映射到离散值的过程。

训练

首先,算法将每个向量分成几个相等的部分 —— 子向量。所有数据集向量的各个部分形成独立的 子空间 并分别处理。然后,对每个子空间的向量执行聚类算法。这样,每个子空间中会生成多个质心。每个子向量使用它所属质心的 ID 进行编码。此外,所有质心的坐标也会被存储以供后续使用。

子空间质心也称为 量化向量

在产品量化中,集群 ID 通常被称为 重现值

注意。 在下面的图示中,矩形代表包含多个值的向量,而正方形表示单个数字。

使用量化进行编码

因此,如果原始向量被分成 n 部分,那么它可以由 n 个数字 —— 每个子向量的相应质心的 ID 来编码。通常,创建的质心数量 k 通常选择为 2 的幂,以便更有效地使用内存。这样,存储一个编码向量所需的内存是 n * log(k) 位。

每个子空间内所有质心的集合称为 代码本。对所有子空间运行 n 个聚类算法会生成 n 个独立的代码本。

压缩示例

想象一个原始向量大小为 1024,存储浮点数(32 位),被划分为n = 8个子向量,每个子向量由k = 256个聚类中的一个进行编码。因此,编码单个聚类的 ID 需要log(256) = 8位。让我们比较这两种情况下向量表示的内存大小:

  • 原始向量:1024 * 32 位 = 4096 字节。

  • 编码向量:8 * 8 位 = 8 字节。

最终压缩比为 512 倍!这就是产品量化的真正力量。

量化示例。向量中的数字显示了它存储了多少数字。

以下是一些重要的备注:

  • 该算法可以在一个向量子集上进行训练(例如,创建聚类),并用于另一个子集:一旦算法被训练,另一个向量数据集会传递过来,其中新向量使用已经构建的每个子空间的质心进行编码。

  • 通常,k-means 被选择作为聚类算法。它的一个优点是,聚类数k是一个超参数,可以根据内存使用要求手动定义。

推断

为了更好地理解,让我们首先看几个天真的方法并找出它们的缺点。这也将帮助我们理解为什么它们通常不应该被使用。

天真的方法

第一种天真的方法包括通过连接每个向量的相应质心来解压所有向量。之后,可以从查询向量到所有数据集向量计算L2距离(或其他度量)。显然,这种方法是可行的,但非常耗时,因为进行的是暴力搜索,并且距离计算是在高维解压向量上进行的。

另一种可能的方法是将查询向量拆分为子向量,并计算每个查询子向量与基于其 PQ 代码的数据库向量的相应量化向量的距离之和。因此,暴力搜索技术再次被使用,并且这里的距离计算仍然需要原始向量维度的线性时间,与前一种情况相同。

使用天真的方法计算近似距离。示例显示了作为度量的欧几里得距离。

另一种可能的方法是将查询向量编码为 PQ 代码。然后直接利用这个 PQ 代码计算与所有其他 PQ 代码的距离。具有最短距离的对应 PQ 代码的数据集向量被认为是查询的最近邻。这种方法比前两种方法更快,因为距离总是在低维 PQ 代码之间计算。然而,PQ 代码由聚类 ID 组成,这些 ID 没有太多语义意义,可以被明确视为实际变量的类别变量。显然,这是一种不好的做法,这种方法可能导致预测质量差。

优化方法

查询向量被分为子向量。对于每个子向量,计算其到相应子空间中所有质心的距离。最终,这些信息存储在表格 d 中。

获取一个表格 d 存储部分查询子向量到质心的距离

计算出的子向量到质心的距离通常被称为 部分距离

通过使用这个子向量到质心的距离表 d,可以通过其 PQ 代码轻松获得查询到任何数据库向量的近似距离:

  1. 对于数据库向量的每个子向量,找到最接近的质心 j(通过使用 PQ 代码中的映射值),并从质心到查询子向量 i 的部分距离 d[i][j](通过使用计算矩阵 d)被取用。

  2. 所有的部分距离被平方并求和。通过对该值开方,可以获得近似的欧几里得距离。如果你想了解如何获得其他度量的近似结果,请导航到 “其他距离度量的近似” 部分。

通过使用 PQ 代码和距离表计算查询向量到数据库向量的距离

使用这种方法计算近似距离假设部分距离 d 非常接近查询与数据库子向量之间的实际距离 a

然而,这种条件可能无法满足,特别是当数据库子向量与其质心之间的距离 c 较大时。在这种情况下,计算结果的准确性较低。

左侧的示例展示了一个良好的近似情况,当实际距离非常接近部分距离(c 较小)。右侧则展示了一个不良情况,因为部分距离远大于实际距离(c 较大)。

在获取所有数据库行的近似距离后,我们搜索具有最小值的向量。这些向量将是查询的最近邻。

其他距离度量的近似

到目前为止,我们已经了解了如何通过部分距离来近似欧几里得距离。让我们将规则推广到其他度量。

假设我们想要计算一对向量之间的距离度量。如果我们知道度量的公式,我们可以直接应用它来得到结果。但有时我们可以按以下方式分步进行:

  • 两个向量都被分成 n 个子向量。

  • 对于每一对对应的子向量,计算距离度量。

  • 计算出的 n 个度量然后被组合以生成原始向量之间的实际距离。

图中显示了计算度量的两种方法。左侧,度量公式直接应用于两个向量。右侧,为每对对应的子向量计算部分距离,然后通过使用聚合函数 h、g 和 f 进行组合。

欧几里得距离是可以按部分计算的度量的一个例子。根据上图,我们可以选择聚合函数为 h(z) = z²g(z₀, z₁, …, zₙ) = sum(z₀, z₁, …, zₙ)f(z) = √z

欧几里得距离可以按部分计算

内积是另一种度量的例子,具有聚合函数 h(z) = z, g(z₀, z₁, …, zₙ) = sum(z₀, z₁, …, zₙ) 和 f(z) = z

在产品量化的背景下,这是一个非常重要的属性,因为在推断过程中,算法按部分计算距离。这意味着使用没有此属性的度量来进行产品量化会更加困难。余弦距离就是这种度量的一个例子。

如果仍然需要使用没有此属性的度量,则需要应用额外的启发式方法来聚合带有一定误差的部分距离。

性能

产品量化的主要优势是数据库向量的巨大压缩,这些向量作为简短的 PQ 代码存储。对于某些应用,这种压缩率甚至可能超过 95%! 然而,除了 PQ 代码外,还需要存储大小为 k x n 的矩阵 d,其中包含每个子空间的量化向量。

产品量化是一种有损压缩方法,因此压缩越高,预测准确性下降的可能性就越大。

建立一个高效表示的系统需要训练多个聚类算法。除此之外,在推断过程中,k * n 部分距离需要以暴力方式计算并为每个数据库向量求和,这可能需要一些时间。

产品量化性能

Faiss 实现

Faiss(Facebook AI 搜索相似性)是一个用 C++编写的 Python 库,用于优化相似性搜索。该库展示了不同类型的索引,这些索引是用于高效存储数据和执行查询的数据结构。

根据Faiss 文档的信息,我们将看到产品量化是如何被利用的。

产品量化在IndexPQ类中实现。初始化时,我们需要提供 3 个参数:

  • d:数据中的维度数量。

  • M:每个向量的拆分数量(与上文中使用的n相同的参数)。

  • nbits:编码单个聚类 ID 所需的位数。这意味着单个子空间中的总聚类数将等于 k = 2^nbits

对于相等子空间维度的拆分,参数 dim 必须可以被 M 整除。

存储单个向量所需的总字节数等于:

如上公式所示,为了更有效地利用内存,M * nbits 的值应能被 8 整除。

Faiss 对 IndexPQ 的实现

结论

我们已经探讨了信息检索系统中一种非常流行的算法,该算法能有效地压缩大量数据。其主要缺点是推理速度较慢。尽管如此,该算法在现代大数据应用中被广泛使用,特别是与其他相似性搜索技术结合时。

在文章系列的第一部分中,我们描述了倒排文件索引的工作流程。实际上,我们可以将这两种算法合并成一个更高效的算法,兼具两者的优势!这正是我们将在本系列的下一部分中要做的。

## 相似性搜索,第三部分:融合倒排文件索引和产品量化

在本系列的前两部分中,我们讨论了信息检索中的两个基本算法:倒排索引……

medium.com

资源

除非另有说明,否则所有图片均由作者提供。

基本统计概念的简单解释(第二部分)

原文:towardsdatascience.com/simple-explanations-of-basic-statistics-concepts-part-2-baa49db597ba

不同统计概念的简单解释

Chi NguyenTowards Data Science Chi Nguyen

·发布于 Towards Data Science ·7 分钟阅读·2023 年 3 月 22 日

--

图片来源于 Icons8 TeamUnsplash

介绍

第一部分:基本统计概念的简单解释中,我解释了有关统计概念的一些基本思想,包括与总体和样本相关的不同定义、抽样方法和置信区间。今天,我将为你提供一些经常遇到的统计主题的额外解释。希望即使是对统计学不熟悉的人也能简单易懂。

现在,让我们深入了解吧!

变异性

什么是变异性 & 为什么它很重要?

谈到变异性时,你是在讨论数据的分散程度。中位数和均值并不适用于此,因为它们仅显示了大多数数据值的范围。

你还记得第一部分中提到的我家柠檬农场的例子吗 ^^?在本节中,我将再次提到我的柠檬农场。每年,我家需要在 1 月和 9 月进行两次柠檬收获。下面展示了每个收获季节柠檬重量的分布。乍一看,我们可以看到两个收获季节的柠檬平均重量大致相同。然而,似乎 1 月份柠檬的重量分布比 9 月份的柠檬更为分散。换句话说,尽管两个月收获的柠檬重量相似,但 1 月份收获的柠檬变异性更大。因此,相同的集中趋势并不意味着相似的变异度,反之亦然。

图 1 — 作者提供的图片

显然,了解变异性同样重要,因为它帮助我的家庭评估两个收获季节之间柠檬的质量,并在 1 月份调整种植方法,以产出更可比的产品。

总体而言,低变异性更可取,因为它使用样本数据提供了更准确的总体信息预测。

那么,我们如何在统计学中描述变异性(或差异)呢?我们来看 4 个指标:范围、标准差、方差和四分位差。

范围

这是变异性的最简单测量方法,计算方法是最小值与最大值之间的差异。

例如,在 1 月份的收获中,最重的柠檬重 13 克,而最小的柠檬仅重 2 克。这意味着重量从 2 克到 13 克变化,柠檬的重量范围是 11 克。

尽管其简单性,范围很少作为唯一的变异性测量方法。原因是范围不能考虑所有数据点。看看下面的图 2,你会看到两种情况的范围都是 13–5=8 克。然而,那两者之间的体重分布完全不同。这就是为什么仅了解范围并不能告诉你数据如何分散。

图 2:作者拍摄

要深入了解,方差和标准差是你可能需要的。

方差与标准差

这两个指标描述了值的分布情况。

在比较两个大致相同平均值的不同数据集的离散度时,标准差很有帮助,因为它告诉我们每个数据点与均值的平均距离。标准差较小的数据集在均值周围更加集中。

但在使用标准差时,有一些注意事项。

首先,标准差需要参考均值来评估。例如,在比较猪的体重时,500 克的差异并不大。然而,对柠檬来说情况却不同。虽然柠檬的平均重量仅为 10 克,但增加 500 克的重量会产生巨大的差异。

图 3:作者拍摄

其次,极端值可能会影响对标准差的解释。 一些异常值可能会增加标准差,使离散度看起来比正常情况更大。这引出了我的第三点,即当数据呈正态分布时,标准差更受青睐。

方差是通过平方标准差计算得出的。更高的方差意味着数据更分散。方差可能更难直观理解,但毫无疑问,它是用于统计检验的一个重要指标,例如 ANOVA,用于测试数据集之间的差异。

四分位差

当数据具有不对称分布、包含极端值或在有序水平上测量时,使用四分位距会更合适。

当数据值按升序排序时,Q1(即第一个四分位数)是 25%的数据值小于或等于的那个值。类似地,75%的数据值低于的那个值被称为第三四分位数(Q3)。IQR 即 Q3 与 Q1 之间的差值。

图 4: IQR — 作者

假设我农场生产的 9 月柠檬和 1 月柠檬的重量中位数相同。然而,如图 5 所示,1 月收获的重量 IQR 大于 9 月的。因此,这表明 1 月采摘的柠檬重量差异显著,而 9 月收获的柠檬重量更为接近。

图 5 — 作者

一般来说,…

  • 如果你的数据是有序的,使用范围和四分位距来评估数据的变异性。

  • 如果你的数据呈正态分布,应考虑标准差和方差,但要注意异常值。

  • 如果你的数据不对称或有异常值,四分位距是合适的方法。

标准差与标准误差

标准误差定义为均值的标准差。这是什么意思呢?

假设我家想检查 9 月柠檬的重量。然而,由于称量所有柠檬会浪费太多时间,我们决定随机选择 4 个样本来评估重量。对于每一批次,我们仔细称量每个柠檬,并得出每批柠檬的平均重量。结果,我们得到 4 个均值,分别对应 4 个批次。然后,我们计算这 4 个均值的平均值。均重值的标准差被定义为标准误差。换句话说,标准误差提示了如果我们随机选择 4 个柠檬样本进行评估时,均值的变化情况

如你从我之前的帖子中了解到的,抽样误差始终存在。因此,通过了解标准误差,你可以评估你的样本代表了多大程度的人群,从而得出可靠的见解。低标准误差表明样本均值接近于总体均值,意味着样本代表性较好,反之亦然。

图 6: 标准误差示例 — 作者

标准误差通常小于标准差,因为均值之间的差异不显著。

标准差与变异系数

变异系数(CV)可以简单理解为相对变异性的度量。它是通过将标准差与数据的规模结合考虑得出的。我为什么这么说呢?我们来举一个非常简单的例子。

假设我有两个数据集 A 和 B,如下所示。对于每个数据集,我可以很容易地计算均值和标准差。看来这两个数据集的标准差相同。然而,这是否意味着这两个数据集的数据分布应相似?答案是否定的。

看列表 A,你会发现最大值是最小值的 10 倍。而在列表 B 中,最大值仅比最小值大 1.009 倍。因此,很明显,列表 B 中的数据点彼此更接近。在这种情况下,标准差不再有用,因为它是以绝对值计算的。我们需要一种能考虑数据集规模的度量,那就是 CV。

CV 的计算方法是将标准差除以均值。A 的 CV 为 0.94,而 B 的 CV 仅为 0.004。B 的 CV 远小于 A 的 CV。显然,这个结果给我们展示了一个与标准差完全不同的情况。

图 6:CV — 作者

通常,当我们想要比较两个不同单位的数据集(如津巴布韦元与美元)时,CV 比标准差更有用。

结论

统计学还有很多内容需要覆盖。然而,我会很快回来发布下一篇文章。希望我能让这些概念稍微清晰一点。感谢你一直耐心阅读到最后。

为了接收有关我即将发布的文章的更新,请使用提供的Medium 链接订阅成为会员。

干杯。

参考资料

[## 4.5.3 计算方差和标准差

Statistics: Power from Data! 是一个创建于 2001 年的网络资源,旨在帮助中学学生和教师…

www150.statcan.gc.ca

www.scribbr.com/statistics/variability/#:~:text=questions%20about%20variability-,Why%20does%20variability%20matter%3F,the%20sample%20to%20your%20population

使用 Streamlit 进行简单调查

原文:towardsdatascience.com/simple-surveys-with-streamlit-and-databutton-d027586f1c71

Streamlit 的用户界面组件使得构建简单调查变得容易

Alan JonesTowards Data Science Alan Jones

·发布于 Towards Data Science ·阅读时间 10 分钟·2023 年 6 月 19 日

--

照片由 Nguyen Dang Hoang Nhu 提供,来源于 Unsplash

  • 你对人工智能的未来有什么看法?应该对其进行监管吗?它会创造新工作还是会毁灭现有工作?

  • 你认为气候变化将如何影响你的生活方式?

  • 你相信宇宙中存在外星生命吗?

  • 你最喜欢的数据科学编程语言是什么?

有时我们使用他人的数据来创建故事——而有时我们需要创建自己的数据,因此我们必须进行收集。这可能是调查或实验结果的日志,但我们需要提出问题并记录结果数据。

当然,也有一些服务可以为你完成这项工作(有时需要付费,但也经常有免费的选项)。或者你可以坚持使用经过验证的剪贴板和铅笔方法。

但如果你是 Streamlit 用户,创建一个简单的调查还是相当容易的。

存储数据

不过有一点问题。虽然 Streamlit 的用户界面组件非常出色且易于使用,但没有内置的数据存储方法。你可以简单地将数据存储在文本文件或 SQLite 数据库中,这对于本地应用来说效果很好。

如果你尝试在 Streamlit Cloud 部署该应用,你会发现你创建的任何数据都会消失。

一想就明白了。

当你启动一个 Streamlit Cloud 应用时,它会从 Github 复制源文件,包括任何数据文件或数据库,但当你离开应用时,数据不会被写回。因此,当你再次启动应用时,你将从头开始。你收集并存储的任何数据只能在应用运行时存在,当你离开应用时,这些数据就会丢失。

对于调查应用来说,这种行为并不好。

Streamlit 团队当然考虑到了这一点,并在他们的文档中提供了建议的解决方案(参见“知识库”中的教程部分)。这些主要涉及连接运行各种数据库(如 MySQL、Microsoft SQL Server 等)的数据库服务器,但也展示了如何将 Streamlit 与基于云的服务(如 Amazon S3、MongoDB 和 Google Cloud Storage)一起使用。

还有一个Databutton,这是一个综合性的在线开发环境,专为 Streamlit 设计,具有一键部署、AI 支持编码等功能,并且将数据存储方便地集成在开发和部署环境中。文章末尾有一节关于迁移到 Databutton 的内容。

(你们中的一些人可能已经看到这篇文章的标题上带有 Databutton——这是一个错误。我已经包含了一个 Databutton 部分,但我会在不久的将来写一篇更全面的关于 Databutton 的文章。)

现在,我们将集中精力处理调查部分,并单独处理存储。在这个应用程序中,我们将使用一个本地文件来存储数据,但为了将来更方便,我们将把所有文件操作放入一个库中。这样,如果我们想迁移到另一个平台,我们只需重写库即可。因此,请记住,我们的初始应用程序并不是为了部署而设计的,而是为了在本地机器上运行。

在 Streamlit 中创建调查

Streamlit 提供了良好的用户界面组件选择,可用于创建、展示和分析调查数据。特别是,我们将利用一组单选按钮来实现多项选择题,并使用可编辑的数据框来展示和编辑问卷本身。

使用 Streamlit 单选按钮的多项选择题——截图由作者提供

我们可以稍后考虑更复杂的展示方式或不同的问题类型——目前我们将保持简单。

该应用程序有三个组件:问卷编辑器;调查展示;和结果分析器/可视化器。我已经将它们实现为多页面应用程序中的页面。(这仅意味着它们位于一个名为 pages 的文件夹中。)

编辑器

我们将主要使用 Python 字典来表示我们的数据——包括问卷和结果——在本地应用程序版本中,我们将其存储为 JSON 文件。

问题将存储在两个字段中,text,一个包含问题文本的字符串,以及 responses,一个由逗号分隔的多项选择答案的字符串。

你可以在下面的截图中看到以 Streamlit data_editor 组件显示的问题数据。使用这个组件,你可以直接编辑问卷(如果你愿意的话)。

在可编辑的数据框上方有几个字段:第一个是问题,第二个是可能的回答列表。填写这些内容并点击 将问题添加到问卷 按钮,你会看到新问题出现在数据框中。

如我所说,你也可以直接编辑数据框:点击相应的字段以更改现有数据;点击行左侧以选择该行并使用删除键删除它;或点击最后一行下方左侧以添加新行。

无论哪种情况,你都需要点击 保存更改 以存储数据。

问卷编辑器 — 作者截图

你可以看到下面的实现。

Streamlit 程序在每次用户交互时都会重新运行,因此我们使用 Streamlit 会话功能来存储问卷,以便其值得到适当维护。除此之外,这是一个非常直接的 Streamlit 程序;它展示了两个 st.text_input() 组件(在第二个组件中添加了默认响应字符串),接着是一个 st.data_editor(),它不仅显示问卷,还允许修改。

程序的最后一部分是数据存储。这使用了我在 DButils 库中编写的例程。这些本质上是对基本文件存储函数的封装 —— 正如我之前所说,我实现了类似的存储方式,以便程序可以在不同平台上与其他存储选项一起使用。

DButils.get_survey() 用于检索存储的问卷,DButils.save_survey() 用于将整个数据框保存到文件中。

import streamlit as st
import DButils

st.set_page_config(layout="wide")

if 'survey' not in st.session_state:
    st.session_state['survey'] = DButils.get_survey()

st.title("Questionaire editor")
st.write("""Type in the question text in the field below and then add
            a list of possible responses (or you can leave, or edit, 
            the default responses)."""
)

# Set a default response
default_response = (
    "Strongly agree,Agree,Neither agree not disagree,Disagree,Strongly disagree"
)

st.header("Question")
q_text = st.text_input("Question text")
q_responses = st.text_input(
    "A comma separated list of responses", value=default_response
)

submitted = st.button("Add question to survey")

if submitted:
    st.session_state['survey'].append(
        {
            "text": q_text,
            "responses": q_responses,
        }
    )

st.write("You can also edit the questions and response directly in the table.")

edited_df = st.data_editor(st.session_state['survey'], num_rows="dynamic")

save = st.button("Save changes")
if save:
    DButils.save_survey(edited_df)
    st.success(f"Changes saved")

展示问卷

每个问题都以一组单选按钮的形式展示。

展示问卷 — 作者截图

如下所示,我们遍历问卷,提取 text 字段作为提示,并将 responses 字段拆分为单独的答案,以便显示按钮组。

import pandas as pd
import streamlit as st

import DButils

st.info("## Select the answer to each question and then click on 'Submit'")
questions = DButils.get_survey()

responses = {}

for q in questions:
    response = st.radio(label=q['text'], options=q['responses'].split(","))
    responses[q['text']] = response.strip()

if st.button("Submit"):
    entry = responses
    DButils.append_results(entry)
    st.write("Updated")

然后,将完整的记录数据添加到存储的响应中,使用 DButils.update()

展示结果

结果页面分为 3 部分:第一部分以数据表格的形式展示结果,可以下载为 CSV 文件。

展示结果 1 — 作者截图

第二部分是对整个调查的图形概述。条形图是使用 Plotly Express 创建的。

展示结果 2 — 作者截图

最后一部分允许用户选择每个问题的结果,这些结果以条形图的形式显示(也是 Plotly)。

展示结果 3 — 作者截图

下面是这部分的代码。我们使用 DButils.get_results() 来加载结果数据框,然后将其显示为 st.dataframe()(这次不可编辑,当然!),并添加一个下载按钮,将数据保存到你的本地机器上作为 CSV 文件。

接下来是整个响应数据的条形图(每个问题有不同的颜色)。由于这不一定是最容易阅读的,因此紧接着会出现一个单选按钮组,让你选择一个特定的问题进行关注。每个问题的条形图在之前的循环中绘制,并根据选择的单选按钮显示相应的条形图。

import streamlit as st
import plotly.express as px
import pandas as pd
import DButils

st.set_page_config(layout="wide")

st.info("## Here are the results:")

st.write("The results are presented as a dataframe.")

# Read data from Databutton's datastore
results = DButils.get_results()
st.dataframe(results, use_container_width=True)

df = pd.DataFrame(results)

st.download_button(
    label="Download data as CSV",
    data=df.to_csv().encode("utf-8"),
    file_name="survey_results.csv",
    mime="text/csv",
)

# Plot a summary bar graph

fig = px.bar(results, title="Survey responses - overview")
fig.update_xaxes(title_text="Response")
fig.update_yaxes(title_text="Count")
st.plotly_chart(fig)

# Create an array of bar graph figures
# one for each question 

figures = []

for q in df.columns:
    fig = px.bar(df[q], title=q)
    fig.update_layout(showlegend=False)
    fig.update_xaxes(title_text="Response")
    fig.update_yaxes(title_text="Count")
    figures.append(fig)

# Choose which graph to display with a set of radio buttons

st.info("### Choose the graph for a specific question")
f = st.radio("Choose a graph", options=df.columns)
column_index = df.columns.get_loc(f)
st.plotly_chart(figures[column_index])

DButils

正如你在下面看到的,DButils 库有许多用于读取、写入和更新 CSV 文件的函数。它还定义了我们上面使用的两个文件的常量。

这个库是专门为本地应用编写的,使用 JSON 文件来存储数据,但如果你想要移植到另一个平台,只需重写四个简单的函数并定义两个常量即可。

import os
import json

SURVEY_KEY = "survey.json"
RESULTS_KEY = "results.json"

# Save data
def save_dict(value, key=SURVEY_KEY):
    print(f"Saving: {value}")
    #return None
    out_file = open(key, "w")
    json.dump(value,out_file)
    out_file.close()

def save_results(value):
    save_dict(value,RESULTS_KEY)

def save_survey(value):
    save_dict(value, SURVEY_KEY)

# Retrieve data
def retrieve(key):
    # file exists read it and return dict array
    if os.path.isfile(key):
        in_file = open(key, "r")
        result = json.load(in_file)
        in_file.close()
        return result
    else:
        # File does not exist return an empty dict array
        return []

def get_survey(key=SURVEY_KEY):
    return retrieve(key)

def get_results(key=RESULTS_KEY):
    return retrieve(key)

# Update results
# This may not be efficient but it is simple
def append_results(value):
    results = get_results()
    results.append(value)
    save_results(results)

Databutton

为了演示将其移植到另一个平台是多么简单,特别是将其移植到 Databutton 是多么容易,这里有一个新的 DButils 库版本。

import databutton as db

def get_survey():
    survey = db.storage.json.get("survey", default=[])
    return survey

def save_survey(survey):
    db.storage.json.put("survey", survey)

def append_results(entry):
    # Retrieve the existing survey results from the JSON file in Databutton
    survey_results = db.storage.json.get("survey_results", default=[])

    # Append the new entry to the survey results
    survey_results.append(entry)

    # Save the updated survey results back to the JSON file in Databutton
    db.storage.json.put("survey_results", survey_results)

def get_results():
    # Retrieve the results data from Databutton's datastore
    results = db.storage.json.get("survey_results", default={})
    return results

要将整个内容移植到 Databutton,只需将上述页面复制到 Databutton 页面中,并将上述代码复制到 Databutton 库中。

这之所以如此简单,是因为我几乎没有编写任何代码——Databutler 为我完成了这项工作。我只是让 AI 助手为每个页面生成库代码,然后将其粘贴到库文件中。

然后它就正常工作了?

并不是完全正确。Databutler 为单个页面生成的代码使用了略微不同的名称来存储数据,例如,一个页面用 survey_results,另一个页面用 survey。这在几秒钟内得到修复。然后它就正常工作了!

经过反思,我本可以在提示中更精确一些,告诉 Databutler 应使用的名称。

在现实世界中

我希望你能同意这些简单的例程创建了一个相当吸引人的应用程序,并向你展示了如何使用 Streamlit 创建简单的调查。

但如果你想要在现实世界中部署这样的东西,你需要考虑一些事项。

这里有一些你可能需要考虑的事项:

  • 尽管调查是匿名进行的,你可能还是希望能够识别受访者,以避免重复条目。

  • 你可能想要包含不同类型的问题或以不同的方式呈现它们(例如 st.select_slider())。

  • 随机化响应的呈现方式有时可以避免引导受访者到特定的答案。

  • 你几乎肯定会想要向调查中添加人口统计问题。这些问题也可以作为多项选择题来实现,但在分析时需要以不同的方式处理结果。

但这不是一个关于如何设计调查的教程,所以我就不再详细讲解了。

感谢阅读,我希望这对你设计 Streamlit 中的调查问卷有所帮助。这个应用程序故意保持非常简单,数据存储仅在应用程序本地部署时有效——我希望在后续文章中解决这些问题。

如果你想查看更多我的作品,请查看我的网页

促进员工之间联系的简单工具

原文:towardsdatascience.com/simple-tool-to-foster-connection-among-employees-82ef5c1353f5

办公时间

利用 Python 构建一个快乐而紧密的团队

Zolzaya LuvsandorjTowards Data Science Zolzaya Luvsandorj

·发表于 Towards Data Science ·6 分钟阅读·2023 年 2 月 1 日

--

COVID-19 促进了一个积极的变化,那就是推动更多公司采用灵活的工作安排。这种采用意味着我们中的更多人即使在封锁结束后也能继续在家工作。虽然这种灵活性在许多方面都很棒,但一个潜在的缺点是,你不能像在办公室那样随意碰到同事,并进行那些随意的自发对话,这有助于建立更好的同事关系并让你感受到团队的一部分。在这篇文章中,我分享了一个利用一点编程技能来激发这些对话的简单想法,当不是所有人都在办公室时可以用得上。

图片来源:Toa HeftibaUnsplash

💡 想法

所以这个想法很简单:我们定期随机匹配同事组,并鼓励他们在小组内进行随意对话。例如,我们每周随机将所有同事分成 3 人一组,指定其中一人作为会议组织者,在他们的日历中预定一个 25 分钟的会话,并鼓励他们参加这个会话以便彼此沟通。或者可以是每两周进行一次的步行交流,以此作为关心健康的举措。这个基本想法可以进一步调整和自由定制,以适应公司文化。

我所在的公司每周组织一次,我们每周都会与另一位同事进行虚拟或面对面的咖啡交流。我还听说过其他公司在 COVID-19 之前就已经组织了较大的小组在附近的咖啡馆一起喝咖啡。因此,即使每个人都在办公室,这个想法也可以帮助促进联系!

在这篇文章的剩余部分,我们将查看如何在简单约束下创建随机匹配的 Python 初始代码。我们将从对开始,然后扩展代码以适应不同的组大小。虽然示例是用 Python 编写的,但这个想法也可以转换到其他语言中。

📦 数据集

在这个示例中,我们假设我们想为整个公司(100 名员工)组织这项团队建设活动。根据公司的规模和结构,这项活动可以涉及整个组织中的所有同事,或特定部门的同事。

我们将从导入库和创建一个包含 100 名员工的合成数据集开始。与经理或直接下属进行定期沟通是很常见的,因此我们将添加一个简单的约束,避免将经理和他们的直接下属配对,因为他们之间的互动已经足够多。因此,我们将经理分配给员工。

太棒了,我们现在将初始化一个数据框来保存历史记录。

跟踪匹配历史将有助于确保相同的配对不会过快地再次匹配。这样,每个人都能有机会与尽可能多的不同人互动。为此,我们将在第一次匹配后使用匹配历史作为额外约束条件。

🌸 对

我们将从最简单的形式开始,一个对:两个组成的组。让我们构建一个Matcher对象,寻找在以下约束条件下的两人匹配:

◼️ 排除同事的直接下属或经理

◼️ 排除过去 10 次匹配中与同事匹配的人员

我们刚刚为 50 对创建了第一次匹配。让我们选择一个示例员工:‘诺亚·罗德斯’ 并检查他的约束条件:

由于我们还没有进行任何匹配,因此约束条件应仅基于他的经理和直接下属。让我们在员工数据集中检查这些信息:

太棒了,约束条件很合理。让我们看看诺亚会匹配到谁:

这个匹配符合约束条件。现在,我们将进行另外 5 次匹配,持续 5 周,并检查约束条件是否按预期工作:

让我们再次检查诺亚的约束条件:

我们有几个新的约束条件。这些新的约束条件应该反映历史上的匹配。让我们检查一下:

很棒,所以约束条件按预期工作。‘布里特妮·菲利普斯’ 不在约束条件中是有道理的,因为这是最近的一次匹配。

经过一些检查后,我们将再次初始化Matcher对象,并运行半年的时间:

我们可以看到在一些情况下,找到匹配项时遇到困难,因此不得不重新开始随机匹配。目前,Matcher的设置是,如果我们有一个奇数个员工,一个人将被分配一个默认值:‘去散步’,这样那个没有找到匹配的人可以进行一些运动。这个默认值可以是任何东西!

在做了一个简单版本之后,让我们调整Matcher对象,以便它可以处理不同的组大小。

🍀 组

以下展示了一种我们可以扩展代码以适应不同组大小的方式。虽然处理员工人数不能被组大小整除的情况有不同的方法,但我们选择了一种更简单的选项,通过填充默认值直到记录数量变为可整除的数字。例如,当我们尝试将 100 名员工分成 3 人一组时,会出现一个多余的人。在这种情况下,我们将用默认值:‘去散步’ 填充两次,将记录数增加到 102,一个可以被 3 整除的数字。为了防止一个人被匹配到两个默认值,我们可以确保任何组只有一个默认值使用约束。因此,这两个较小的组(即对)可以进行步行会议或视频通话。

尽管在这个示例中我们选择了一个 3 人的组,但该对象也可以处理更大的数字。

好了,这就是本帖的全部内容!希望这段起始代码能为你提供帮助,节省时间,如果你想在你的组织中提出这个想法的版本来实现。

图片由 Shaurya Sagar 提供,来源于 Unsplash

感谢阅读本文。如果你感兴趣,这里有一些我其他帖子的链接:

◼️️ 从 ML 模型到 ML 管道

◼️️ 用 SHAP 解释 Scikit-learn 模型

◼️️ Python 中绘制多个图表的 4 个简单技巧

◼️ 美化 pandas DataFrames

◼ 你会发现有用的 Python 简单数据可视化️

◼️ 6 个让 Seaborn(Python)绘图更美观和个性化的简单技巧

再见了 🏃 💨

提高零-shot CLIP 性能的简单方法

原文:towardsdatascience.com/simple-way-of-improving-zero-shot-clip-performance-4eae474cb447

第一部分 — 通过语言模型定制提示(CuPL)

Alexey KravetsTowards Data Science Alexey Kravets

·发表于 Towards Data Science ·12 分钟阅读·2023 年 11 月 3 日

--

单模态模型旨在处理来自单一模态的数据,这可以是文本或图像。这些模型专注于理解和生成特定于其选择模态的内容。例如,GPT 在生成类似人类的文本方面表现出色。它们已被用于语言翻译、文本生成和问题回答等任务。卷积神经网络(CNNs)是图像模型的例子,擅长于图像分类、对象检测和图像生成等任务。目前,许多有趣的任务,如视觉问答(VQA)和图像-文本检索等,需要多模态能力。是否可以结合文本和图像处理?可以!CLIP 是最初几个高度成功的图像-文本模型之一,展示了在图像识别和文本理解方面的能力。

我们将把这篇文章分为以下几个部分:

  1. 引言

  2. 架构

  3. 训练过程和对比损失

  4. 零-shot 能力

  5. CuPL

  6. 结论

引言

CLIP 模型是一种令人印象深刻的零-shot 预测器,使其能够在没有明确训练的任务上进行预测。正如我们将在接下来的部分中详细了解的,通过使用自然语言提示来查询图像,CLIP 可以在没有任务特定训练数据的情况下执行图像分类。然而,通过一些技巧,可以显著提高其性能。在这一系列文章中,我们将探讨利用大型语言模型(LLM)生成的额外提示或少量示例的几种方法,而无需涉及任何参数训练。这些方法具有显著优势,因为它们计算负担较轻,不需要微调额外的参数。

架构

CLIP 是一个双编码器模型,具有两个独立的编码器用于视觉和文本模态,分别独立编码图像和文本。这种架构不同于融合编码器,它通过交叉注意力实现视觉和文本模态之间的交互,涉及学习注意力权重,帮助模型在处理两种模态时关注图像的特定区域和文本的相应部分。这个想法类似于自注意力,它允许每个标记关注同一模态中的其他标记。交叉注意力扩展了这个概念,允许一种模态中的标记(例如,代表图像特征的标记或补丁)关注另一种模态中的标记(例如,代表文本描述的标记)。双编码器和融合编码器的概念可以总结如下:

作者插图 — 双编码器与融合编码器

编码器

文本编码器: 负责处理输入文本,文本编码器将其转换为向量表示。在 CLIP 中,模型使用一种标准,我们在这篇文章中进行了详细探讨。文本编码器为提供的文本生成嵌入,封装与输入相关的语义信息。

图像编码器: 图像编码器处理图像以获取其向量表示。视觉编码器可以是像 ResNet 模型这样的卷积神经网络,或者是 ViT Transformer(请参见这里以刷新你的知识),用于生成图像向量表示。

这两个向量具有相同的维度,从而能够计算给定文本和图像之间的相似性。如果你一直只使用一种模态,你可能会想,图像和文本嵌入如何能够进行比较呢?关键在于训练过程和损失函数,这使得 CLIP 能够学习一个统一的图像-文本空间,促进来自不同模态的向量比较。

训练过程与对比损失

CLIP 通过在一个包含图像-文本对的大规模数据集上进行多模态目标训练。当我说大规模时,意味着数据量非常庞大——约为 4 亿对图像-文本。这些数据来自互联网上公开的来源,并经过自动筛选以确保高质量。

一旦图像-文本对被收集,模型通过对比损失进行训练。对比损失使模型能够通过对齐图像和文本的表示来学习共享的图像-文本空间,最大化匹配对嵌入之间的相似性,同时最小化不匹配对嵌入之间的相似性。该过程如下图所示:

图片来自CLIP论文 — 对比损失

图像嵌入 I_i 对应于文本嵌入 T_i(即在对角线上)形成匹配对,而所有其他文本 T_j(j ≠ i)(非对角线)被视为不匹配对。同样地,对于 T_i,只有 I_i 被认为是匹配的图像,而所有其他图像 I_j(j ≠ i)则不被视为 T_i 的描述。然而,这种假设可能是有限的,因为可能存在其他文本对有效地描述一个图像,反之亦然。挖掘困难负例是解决这一挑战的潜在方案。尽管如此,CLIP 凭借其 32,768 的巨大批量大小成功克服了这一限制。

在这个多样化的数据集上预训练之后,CLIP 学习到的嵌入可以用于许多下游应用 — 其中一个真正令人印象深刻的是零-shot 图像分类。

零-shot 能力

那么零-shot 到底是什么意思呢?如介绍中提到的,零-shot 分类指的是模型在没有特定类别示例或训练数据的情况下,正确分类未见过的类别的能力。CLIP 在一个大型数据集上进行了训练,它学会了在广泛的概念中进行泛化,使其能够根据语义关系识别和分类类别。让我们看看这在实践中是如何实现的:

假设我们只知道特定数据集的类别名,比如

[“狗”、 “猫”、 “马”]。由于 CLIP 被训练来匹配图像和文本,我们可以计算给定测试图像与提示“一个{类别名}的图片”之间的余弦相似度,这在我们的情况下变成:

“一个{狗}的图片”、 “一个{猫}的图片”、 “一个{马}的图片”

具有最高余弦相似度的提示代表预测的类别。

通过定制提示改进 Zero-Shot CLIP(CuPL)

现在,Zero-Shot CLIP 已经取得了相当令人印象深刻的表现,不过,我们仍然可以通过一些简单的技巧进一步挖掘其潜力。CLIP 的零-shot 性能对其输入的文本提示非常敏感。这就是为什么对于不同的数据集如 ImageNet,人们提出了不同的文本提示,如“一个折纸{类别名}”“一个视频游戏中的{类别名}”等。这些手动设计的提示比简单的“一个{类别名}的图片”要好,但它们仍然存在一些主要的限制:

  1. 手动编写的提示需要大量的人力

  2. 手动编写的提示必须是通用的 — 我们不能使用像“一个{鸭嘴兽}的照片,一种水生哺乳动物”这样的模板,因为这只适用于水生哺乳动物,而不适用于其他类别。这是一个限制,因为描述性细节对细粒度分类是有用的。

  3. 编写高效的提示模板需要对数据集内容的事先了解。因此,对于 ImageNet,我们必须提前知道感兴趣的数据集包含折纸、视频游戏图像等。

那我们能做什么呢?我们只需请求一个大型语言模型 (LLM) 为我们生成这样的提示,这些提示可以轻松扩展到任何数量的类别和数据集!我们可以向 LLM 提出以下问题:

  • 描述一个/该类别名称的外观:

  • 描述一个/该类别名称:

  • 一个/该类别名称的识别特征是什么?

为什么这应该比简单提示更好?假设是 LLM 生成的提示将包含非常详细的类别描述,并使 CLIP 更加重视对正确分类最相关的图像区域。

现在让我们跳入编码中看看我们能得到什么。我们将使用 Hugging Face 的 Transformers 库中的 CLIP。因此,让我们导入模型——我们将使用 ViT,patch 大小为 32,以及处理文本和预处理图像的管道:

from transformers import CLIPProcessor, CLIPModel
import torch
import requests
from PIL import Image

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

接下来,我们将从 freeimages 网站下载一张 “树蛙” 的图片,该网站对下方的图片有开放许可。然后,我们将使用 CLIP 和简单提示 “{类别} 的照片” 来预测它是 “树蛙” 还是 “有尾蛙”(它们在视觉上相似,主要在于大眼睛的大小):

来自 FreeImages 的树蛙图像(许可: www.freeimages.com/license

url = "https://images.freeimages.com/images/large-previews/342/green-tree-frog2-1616738.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(text=["a photo of a tree frog", "a photo of a tailed frog"], images=image, return_tensors="pt", padding=True)

outputs = model(**inputs)
logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
print(probs)

"""
Output:
tensor([[0.3164, 0.6836]], grad_fn=<SoftmaxBackward0>)
"""

模型通过选择 “有尾蛙” 以 0.68 的概率做出了错误预测。

现在让我们请求一个 LLM(例如 ChatGPT)为我们生成提示:

prompts = {"tree frog": [
        "A tree frog is a small frog that typically has greenish coloration.",
        "A tree frog is a small frog that typically has bright colors, long toes that help it climb, and suction cups on its feet.",
        "A tree frog is small, typically green, frog that lives in trees.",
        "A tree frog looks like a frog with special adaptations for living in trees.",
        "A tree frog is a small, typically green frog with large adhesive pads on its feet that allow it to climb smooth surfaces like glass and plastic.",
        "Tree frogs are small amphibians with big toes that help them climb.",
        " Most tree frogs have bright colors.",
        "Tree frogs are small frogs that live in trees.",
        "A tree frog is a small frog that has large toe pads that help it climb trees.",
        "A tree frog typically has green skin, although some species can be brown, gray, or yellow.",
        "A tree frog is a small frog that lives in trees.",
        "A tree frog is a small, tailless amphibian with large, powerful hind legs and webbed feet.",
        "A tree frog is a amphibian that has well-developed hind legs which enable it to climb trees and other structures.",
        "Avatar of the forest, the tree frog is a small amphibian with big eyes bulging out of its head.",
        "A tree frog is a small amphibian that typically has bright green skin and lives in trees.",
        "A tree frog is a small, slim frog that typically has brightly colored skin.",
        "A tree frog is a small frog that typically has a bright green body and lives in trees.",
        "A tree frog is a small, tailless amphibian that typically has bright green or yellowish skin and lives in trees or near bodies of water.",
        "A tree frog is a small amphibian that typically has a green body and large eyes.",
        "A tree frog is a small frog that typically has bright colors.",
        "The identifying characteristics of a tree frog varies depending on the species, but some common features include large adhesive toes, protruding eyes, and bright colours.",
        "Some identifying characteristics of a tree frog are that they have large toe pads, which help them grip onto tree branches, and their bodies are slim so that they can fit into small spaces.",
        "Tree frogs are small frogs that live in trees and other high places.",
        "The identifying characteristics of a tree frog are its long hind legs, which it uses to jump, and its adhesive pads, which it uses to stick to surfaces.",
        "The identifying characteristics of a tree frog are that they have long, sticky toes that help them climb trees, and they have wrinkled skin that helps them absorb water.",
        "Tree frogs have long, sticky toes that help them climb trees.",
        "Tree frogs are small frogs that can climb trees.",
        "Tree frogs have long hind legs that they use to jump.",
        "There are over 6,300 species of tree frogs, so it is difficult to give one answer to this question.",
        "The identifying characteristics of a tree frog are its long, sticky toes that help it climb trees, and its dark green or brown coloration that helps it blend in with leaves."
    ],
    "tailed frog": [
        " short, stout body; webbed hind feet with large, adhesive discs on the toes; long, muscular tail; small eyes located on top of the head; smooth or warty skin; and a small mouth.",
        "A tailed frog has a long, skinny body and a long tail.",
        "A tailed frog has a long, slender body with a tail that is about as long as its body.",
        "A tailed frog is a frog with a long tail.",
        "A tailed frog has a long tail and four legs.",
        "A tailed frog is a small frog that has a long tail.",
        "A tailed frog has a long tail that is often as long as its body.",
        "A tailed frog has a long tail and webbed feet.",
        "A tailed frog is a frog with a long tail.",
        "A tailed frog is a small frog with a long tail.",
        "A tailed frog (Asteriscus species) is a species of frog in the Asteriscidae family.",
        "A tailed frog is a type of frog that has a long tail.",
        "A tailed frog is a frog with a long tail, typically over 10 cm in length.",
        "A tailed frog is a species of frog that has a long tail.",
        "Tailed frogs are a type of frog that have a long tail.",
        "A tailed frog is a type of frog that has a long, tail-like structure protruding from its back.",
        "A tailed frog is a frog that has a long tail.",
        "A tailed frog is a frog that has a long tail.",
        "A tailed frog has a long, thin body with short legs.",
        "A tailed frog is a small amphibian that has a long tail.",
        "Some identifying characteristics of a tailed frog are that they have a long tail, they are good swimmers, and they live near water.",
        "There are over 60 species of tailed frogs, so it is difficult to give a definitive answer.",
        "There are over 100 species of tailed frogs, so it is difficult to give a general answer to this question.",
        "Some identifying characteristics of a tailed frog are that they have a long tail, they are small, and they have webbed feet.",
        "There are over 60 species of tailed frog, so identifying characteristics can vary.",
        "Tailed frogs are a species of frog that are native to the western United States and northern Mexico.",
        "The identifying characteristics of a tailed frog are its tail, which is used for swimming, and its webbed feet.",
        "Tailed frogs are small, dark-colored frogs with long, slender hind legs and a long, thin tail.",
        "Some tailed frogs have a tail that is about one-third the length of their body.",
        "The identifying characteristics of a tailed frog are its long tail and its smooth, moist skin."
    ]}

这些提示对类别的描述更加详细,这应该能指导模型识别正确的类别。例如,许多“树蛙”的提示强调它有“很大的眼睛”,这是像“树蛙的图片”这样的简单提示所无法捕捉的。

使用上述提示并经过一些操作,我们为一个类别形成最终的提示,带有平均向量嵌入:

"""
First of all, we can verify that 
outputs = model(**inputs)
logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
probs

Is the same as:
image_features = model.visual_projection(model.vision_model(inputs['pixel_values']).pooler_output)
text_features = model.text_projection(model.text_model(inputs['input_ids']).pooler_output)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

# cosine similarity as logits
logit_scale = model.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
probs = logits_per_image.softmax(dim=1)
probs
"""

image_features = model.visual_projection(model.vision_model(inputs['pixel_values']).pooler_output)

tree_frog_vector = model.text_model(processor(prompts['tree frog'], return_tensors="pt", padding=True)['input_ids']).pooler_output
# take the mean prompt embedding
tree_frog_vector = tree_frog_vector.mean(dim=0, keepdims=True)
# final projection 
tree_frog_vector = model.text_projection(tree_frog_vector)

tailed_frog_vector = model.text_model(processor(prompts['tailed frog'], return_tensors="pt", padding=True)['input_ids']).pooler_output
# take the mean prompt embedding
tailed_frog_vector = tailed_frog_vector.mean(dim=0, keepdims=True)
# final projection
tailed_frog_vector = model.text_projection(tailed_frog_vector)

# concatenate 
text_features = torch.cat([tree_frog_vector, tailed_frog_vector], dim=0)

# normalize features
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

# cosine similarity as logits
logit_scale = model.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
probs = logits_per_image.softmax(dim=1)
print(probs)

"""
Output:
tensor([[0.6512, 0.3488]], grad_fn=<SoftmaxBackward0>)
"""

使用 LLM 提供的提示可以给出正确的分类答案——“树蛙”

结论

在这篇文章中,我们已经看到如何通过使用大型语言模型轻松提高 CLIP 的零样本预测。这个解决方案的优点不仅在于更高的准确性,还在于其可扩展性,因为我们不需要任何人工努力来生成提示。在接下来的文章中,我们将探索其他方法来改进 CLIP 的零样本学习以及无需训练的少样本学习方法。

参考文献

[1] [2103.00020] 从自然语言监督中学习可迁移的视觉模型 (arxiv.org)

[2] [2209.03320] 杜鹃鸟的样子是什么?生成定制化提示用于零样本图像分类 (arxiv.org)

[3] CLIP (huggingface.co)

在 Amazon ECS 上将机器学习模型作为 Flask API 部署的简单方法

原文:towardsdatascience.com/simple-way-to-deploy-ml-models-as-flask-apis-on-amazon-ecs-7be11f9dc4d9

在 4 分钟内将 Flask API 部署到 Amazon ECS

Nikola KuzmicTowards Data Science Nikola Kuzmic

·发表于 Towards Data Science ·6 分钟阅读·2023 年 3 月 10 日

--

图片由 Arjan van den Berg 提供,来源于 Unsplash

在这篇文章中,我们将介绍如何部署一个线性回归 XGBoost 模型,该模型根据开发者的工作经验年限来预测他们的薪资。

👉 游戏计划

  1. 训练一个 XGBoost 模型

  2. 构建一个简单的 Flask API 来提供模型预测

  3. 为 Flask API 构建 Docker 镜像

  4. 在 Amazon ECS 上部署 Docker 容器

完整源代码 GitHub 仓库:link🧑‍💻

flask-on-ecs - repo structure
.
├── Dockerfile
├── README.md
├── myapp.py
├── requirements.txt
└── train_xgb.ipynb

为什么我们需要 API 来部署机器学习模型

如果你正在阅读这篇文章,说明你已经到了数据科学项目的阶段,希望将你出色的机器学习模型在互联网上提供给所有人。人们称这一步骤为将模型部署到生产环境中。

在这里,我们不会使事情变得过于复杂,也不会详细审查生产级部署的样子,而是简单地利用默认的 Flask 开发服务器演示从训练/保存的 XGBoost 机器学习模型,到 Docker 化,再到将其作为实时 API 部署到 Amazon ECS 的全过程。

👉 步骤 1:训练一个 XGBoost 模型

训练一个 XGBoost 模型以预测开发者的薪资,基于他们的工作经验年限,并将模型保存为 pickle 文件。

为了在 VS Code 内运行,让我们创建一个单独的 Python 3.8 环境:

conda create --name py38demo python=3.8 
conda activate py38demo
pip install ipykernel pandas flask numpy xgboost scikit-learn

然后重启 VS Code,并在 Jupyter Notebook 中 -> 选择‘py38demo’作为内核。

训练并 pickle XGBoost 模型:

现在是创建一个可以提供这些推荐的 API 的时候了!

👉 步骤 2:Flask API

我们的 API 将加载 XGBoost 模型,接受 POST 请求,并生成响应。

首先让我们在本地运行 API。然后在一个单独的终端中,我们可以通过发送负载 POST 请求来测试,以查看一个拥有 2.5 年经验的开发者会做什么:

curl -X POST http://0.0.0.0:80/recms -H 'Content-Type: application/json' -d '{"years":"2.5"}'

2.5 年后 $260k,12.5 年后 $750k。不错!🤑

👉 步骤 3:Docker 镜像

要在 Docker 容器中运行我们的应用程序,我们需要一个“蓝图”,该蓝图包含有关使用什么环境、要复制哪些本地文件以及如何运行应用程序的指令。所有这些都被称为 Docker 镜像,通常在 Dockerfile 中指定。

现在我们可以在 Docker 容器中运行我们的 API 并进行本地测试。

注意:由于我在 Mac 上构建镜像,我需要指定

    • platform linux/amd64

使其与 ECS Fargate Linux 环境兼容。

这是我们如何构建和运行镜像的步骤。

注意:我们将主机(即笔记本电脑)的端口 80 绑定到 Docker 容器的端口 80:

docker build --platform linux/amd64 -t simpleflask .
docker run -dp 80:80 simpleflask

让我们测试一下现在在 Docker 容器中运行的 API!📦

curl -X POST http://0.0.0.0:80/recms -H 'Content-Type: application/json' -d '{"years":"12.5"}'

是时候在 AWS 上部署了!🚀

👉 步骤 4:在 Amazon ECS 上运行容器

这个部分可能一开始看起来很复杂,但实际上如果我们将过程分解成 6 个简单步骤,就会变得很简单。

Chrome 徽标:[1],Flask 徽标:[2]

i) 将 Docker 镜像推送到 ECR

让我们创建一个名为 demo 的 ECR 仓库,我们可以将 Docker 镜像推送到其中。

然后我们可以使用 ECR 提供的推送命令:

# autheticate
aws ecr get-login-password --region us-east-1 | docker login --username AWS --password-stdin <Your-aws-acc-no>.dkr.ecr.us-east-1.amazonaws.com

#tag the image
docker tag <Your-local-docker-image-name>:latest <Your-aws-acc-no>.dkr.ecr.us-east-1.amazonaws.com/<Your-ECR-repo-name>:latest

#push the image to ECR
docker push <Your-aws-acc-no>.dkr.ecr.us-east-1.amazonaws.com/<Your-ECR-repo-name>:latest

假设:您已在本地机器上配置了 AWS CLI,并设置了具有与 ECR 交互所需权限的 IAM 用户。您可以在此 链接 中找到更多信息。

运行上述 3 个命令后,我们可以看到我们的镜像已经在 ECR 上了!🎉

复制并粘贴镜像 URI 到某个地方,因为我们将在接下来的几个步骤中需要它。

ii) 创建 IAM 执行角色

我们需要创建一个执行角色,以便运行容器的 ECS 任务可以访问从 ECR 拉取镜像。我们将其命名为:simpleRole

iii) 创建安全组

需要安全组以允许互联网上的任何人向我们的应用程序发送请求。在现实世界中,您可能希望将其限制为特定的 IP 集合,但在这里我们将其对所有人开放,并命名为:

simpleSG

iv) 创建 ECS 集群

这个步骤很简单,只需几秒钟。我们将其命名为:flaskCluster

在我们的集群被配置时,让我们创建一个任务定义。

v) 创建任务定义

任务定义,顾名思义,是一组与运行哪个镜像、开放哪个端口以及分配多少虚拟 CPU 和内存相关的指令。我们将其称为:demoTask

vi) 运行任务

让我们在我们的flaskCluster上运行demoTask,使用我们在步骤 iii) 中创建的simpleSG

现在可以通过公共 IP地址测试已部署的 API 了! 🥁

curl -X POST http://<PUBLIC-IP>:80/recms -H 'Content-Type: application/json' -d '{"years":"2.5"}'

它正在运行! 🥳

正如你所见,我们可以通过向 ECS 提供的公共 IP发送 POST 请求来获取薪资预测。 🔥

最后

这只是一个简单的演示,展示了我们如何将 XGBoost 模型 Docker 化并在 Amazon ECS 上进行实时推理。然而,我们使用了 Flask 提供的默认开发服务器,实际上应该使用像 Gunicorn 这样的生产级应用服务器,下一篇文章中我们会介绍它。

感谢阅读,希望你觉得这些内容对开始使用 Flask、Docker 和 Amazon ECS 有帮助!

想要更多关于 ML 工程的有用文章?

免费订阅 以便在我发布新故事时收到通知。

成为 Medium 会员,阅读更多来自我和其他数千名作者的故事。你可以通过使用我的 推荐链接 来支持我。当你注册时,我将获得佣金,而你无需额外支付费用。

参考资料

[1] Chrome 标志:链接

[2] Flask 标志:链接

在 Python 中创建合成数据集的简单方法

原文:towardsdatascience.com/simple-ways-to-create-synthetic-dataset-in-python-76a8e9a2f35c

数据科学基础

创建模拟表格数据的初学者指南

Zolzaya LuvsandorjTowards Data Science Zolzaya Luvsandorj

·发表于Towards Data Science ·阅读时间 7 分钟·2023 年 1 月 12 日

--

在开发代码时,有时我们需要一个虚拟数据集。例如,我们想分享代码和底层数据,但实际数据集是机密的,因此不适合分享。一个选项是找到并使用合适的玩具数据集或公开可用的数据集。另一个选项是创建一个足够满足使用案例的合成数据集。在本文中,我们将探讨一些在 Python 中创建合成数据集的简单方法。

图片来自Jackie TsangUnsplash

🔧 设置

我们将从加载必要的库开始:

import numpy as np
import pandas as pd
from faker import Faker
from scipy.stats import skewnorm
from datetime import datetime
from sklearn.datasets import (make_regression, make_classification, 
                              make_multilabel_classification, 
                              make_blobs)
from sklearn.model_selection import train_test_split
from sklearn.ensemble import (RandomForestClassifier,
                              RandomForestRegressor)
from sklearn.multioutput import MultiOutputClassifier
from sklearn.cluster import KMeans
from sklearn.metrics import mean_squared_error, roc_auc_score
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style='darkgrid', context='talk')

一切准备就绪,让我们深入探讨

📍 Scikit-learn

Scikit-learn 提供了许多有用的函数来创建合成数值数据集。在这一部分,我们将熟悉其中的一些。

📌 回归

我们可以使用[make_regression](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_regression.html)函数创建具有数值特征和连续目标的数据集。我们来创建一个包含 5 个特征和 1000 条记录的连续目标数据集:

n = 1000
n_features = 5
seed = 123X, y = make_regression(n_samples=n, n_features=n_features, 
                       random_state=seed)
columns = [f"feature{i+1}" for i in range(n_features)]
df = pd.concat([pd.DataFrame(X, columns=columns), 
                pd.Series(y, name='target')], axis=1)
print(df.shape)
df.head()

图片由作者提供

这非常简单明了!我们将记录数指定为n_samples参数,特征数指定为n_features参数。我们设置了随机种子,以便合成数据集可以被再现。如果需要,我们可以使用这个数据集来构建回归模型:

X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=seed
)
model = RandomForestRegressor(random_state=seed)
model.fit(X_train, y_train)
print(f"Train | MSE: {mean_squared_error(y_train, model.predict(X_train)):.4f}")
print(f"Test | MSE: {mean_squared_error(y_test, model.predict(X_test)):.4f}")

图片由作者提供

在这种情况下,模型似乎在训练数据上过拟合得很严重。

📌 分类

类似于上面,我们可以使用[make_classification](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html)创建具有所需数量数值特征和离散目标的数据集。我们现在将练习创建一个包含 5 个特征和一个二进制目标的 1000 条记录的数据集:

n_classes = 2
X, y = make_classification(n_samples=n, n_features=n_features, 
                           n_classes=n_classes, random_state=seed)
columns = [f"feature{i+1}" for i in range(n_features)]
df = pd.concat([pd.DataFrame(X, columns=columns), 
                pd.Series(y, name='target')], axis=1)
print(df.shape)
df.head()

图片由作者提供

除了在回归部分使用的参数外,我们还将类别数量指定为n_classes参数。我们接下来将构建一个分类模型:

X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=seed
)
model = RandomForestClassifier(random_state=seed)
model.fit(X_train, y_train)
print(f"Train | ROC-AUC: {roc_auc_score(y_train, model.predict_proba(X_train)[:,1]):.4f}")
print(f"Test | ROC-AUC: {roc_auc_score(y_test, model.predict_proba(X_test)[:,1]):.4f}")

图片由作者提供

从初步观察来看,该模型表现不错。我们还可以使用[make_multilabel_classification](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_multilabel_classification.html#sklearn.datasets.make_multilabel_classification)创建多标签分类问题的数据集:

X, Y = make_multilabel_classification(n_samples=n, 
                                      n_features=n_features, 
                                      n_classes=n_classes, 
                                      random_state=seed)
x_columns = [f"feature{i+1}" for i in range(n_features)]
y_columns = [f"target{i+1}" for i in range(n_classes)]
df = pd.concat([pd.DataFrame(X, columns=x_columns), 
                pd.DataFrame(Y, columns=y_columns)], axis=1)
print(df.shape)
df.head()

图片由作者提供

在这种情况下,我们有两个目标标签。使用虚拟数据集,我们可以构建一个多标签分类模型:

X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, random_state=seed
)
model = MultiOutputClassifier(
    RandomForestClassifier(random_state=seed)
)
model.fit(X_train, Y_train)
print(f"Train | Accuracy by class: {np.round(np.mean(Y_train==model.predict(X_train), axis=0),4)}")
print(f"Test | Accuracy by class: {np.round(np.mean(Y_test==model.predict(X_test), axis=0),4)}")

图片由作者提供

📌 聚类

另一个有用的函数是[make_blobs](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_blobs.html),它用于创建聚类数据:

X, y = make_blobs(n_samples=n, n_features=n_features, 
                  centers=4, random_state=seed)
columns = [f"feature{i+1}" for i in range(n_features)]
df = pd.concat([pd.DataFrame(X, columns=columns), 
                pd.Series(y, name='target')], axis=1)
print(df.shape)
df.head()

图片由作者提供

在这种情况下,我们选择了 4 个簇中心作为centers参数。尽管聚类是无监督的,即我们没有目标变量,但我们在合成数据集中得到了簇作为目标变量。让我们可视化不同k值下的平方距离和:

ks = np.arange(2, 11)
sum_squared_distances = []
for k in ks:
    model = KMeans(k, random_state=seed)
    model.fit(X)
    sum_squared_distances.append(model.inertia_)plt.figure(figsize=(6,4))
sns.lineplot(x=ks, y=sum_squared_distances)
plt.xlabel('k')
plt.ylabel('Sum of squared distances');

图片由作者提供

看起来当 k=4 时,平方距离和趋于平稳。还是我们受到了确认偏误的影响?

如果你想了解更多,这四个函数还有其他有用的参数,可以进一步定制和控制如何创建合成数据集。你可以从这里了解更多关于 Scikit-learn 生成合成数据集的函数。

📍 NumPy & pandas

我们还可以使用numpypandas,这两个常用的数据处理库,来创建虚拟数据集:

np.random.seed(123)
df = pd.DataFrame()
df['id'] = np.random.choice(np.arange(10**5, 10**6), n, 
                            replace=False)
df['gender'] = np.random.choice(['female', 'male'], n, 
                                p=[0.6, 0.4])
df['age'] = np.random.randint(18, 80, size=n)
df['spend'] = skewnorm.rvs(100, loc=1000, scale=500, size=n)
df['points'] = np.random.normal(loc=50, scale=10, size=n)start_date = pd.Timestamp("2013-01-01")
end_date = pd.Timestamp("2023-02-01")
delta = (end_date-start_date).days
df['date_joined'] = start_date + pd.to_timedelta(np.random.randint(delta, size=n), 'day')
print(df.shape)
df.head()

图片由作者提供

📌 分类变量

我们创建了一个分类变量:gender作为示例。使用p参数,我们指定了类别的期望概率。我们可以检查生成的数据是否反映了这一点:

pd.concat([df['gender'].value_counts(normalize=True),
           df['gender'].value_counts()], axis=1)

图片由作者提供

很好,大致是 60:40。如果我们不指定p参数,类别将会均匀分配。

📌 数值变量

我们创建了一些数值变量:idagespendpoints

◼️ id:我们通过指定replace=False确保 5 位数的 ID 是唯一的。

◼️ age:使用了np.random.randint()函数在一个范围内生成随机整数。

◼️ spend:使用了scipy.stats.skewnorm.rvs()函数创建了一个偏斜的随机数值变量。

◼️ points:使用了np.random.normal()函数创建了一个正态分布的随机数值变量。

让我们比较一下spendpoints的分布,如下所示:

fig, ax = plt.subplots(2, 1, figsize=(6, 7))
sns.histplot(data=df, x='spend', ax=ax[0])
sns.histplot(data=df, x='points', ax=ax[1])
fig.tight_layout();

图片由作者提供

我们可以看到,spend的分布偏斜且有一个长的右尾,而points大致呈正态分布。

📌 日期变量

最后,我们使用pandas创建了一个日期变量。我们定义了start_dateend_date,并在这个范围内找到了随机日期。我们可以检查随机抽样日期的分布:

plt.figure(figsize=(6,4))
sns.histplot(data=df, x='date_joined');

图片由作者提供

现在,让我们熟悉一个有趣的库。

📍 Faker

Faker是一个用于创建假数据集的库。我们使用这个库的方法非常简单,我们首先初始化一个Faker对象:fake = Faker()。然后我们可以通过fake.<method_name()>访问它提供的所有方法。例如,查看fake.name()。以下是我们可以使用该库创建的样本数据集:

df = pd.DataFrame()
fake = Faker()
fake.seed_instance(seed)
np.random.seed(seed)
start_date = datetime(1940, 1, 1)
end_date = datetime(2005, 2, 1)
for i in range(n):
    df.loc[i, 'birthday'] = fake.date_between(start_date, end_date).strftime('%Y-%m-%d')
    df.loc[i, 'first_name'] = fake.first_name()
    df.loc[i, 'last_name'] = fake.last_name()
    df.loc[i, 'email'] = f"{df.loc[i, 'first_name'].lower()}@{fake.domain_name()}"
    df.loc[i, 'phone_number'] = fake.phone_number()
print(df.shape)
df.head()

图片由作者提供 | 假人员信息因安全原因已被模糊处理

除了使用pandas,我们还可以使用Faker添加日期,如birthday列所示。我们还生成了一些自由文本列。当你需要虚拟数据时,这个库非常有用,不是吗?如果你想创建虚拟文本,这里有一个示例语法:

corpus = [fake.sentence() for i in range(n)]
corpus[:5]

图片由作者提供

根据我们的需求,这里还有略有不同的版本:fake.sentences()fake.paragraph()fake.paragraphs()

在这篇文章中,我们只看了一些它的方法。如果你想了解更多关于这个库的信息,请访问这里的 GitHub 文档。

就这样!希望这些生成合成数据集的简单方法对你的 Python 代码开发有所帮助。

图片由Edgar Chaparro提供,来源于Unsplash

感谢阅读这篇文章。如果你感兴趣,这里有我其他一些文章的链接:

◼️️ 从 ML 模型到 ML 管道

◼️️ 使用 SHAP 解释 Scikit-learn 模型

◼️️ Python 中绘制多个图形的 4 个简单技巧

◼️ 美化 pandas DataFrames

◼ 在 Python 中简单的数据可视化,你会发现很有用️

◼️ 在 Seaborn(Python)中绘制更美观和定制化图表的 6 个简单技巧

再见了 🏃 💨

用 Hamilton 在 8 分钟内简化 Airflow DAG 的创建和维护

原文:towardsdatascience.com/simplify-airflow-dag-creation-and-maintenance-with-hamilton-in-8-minutes-e6e48c9c2cb0?source=collection_archive---------6-----------------------#2023-07-05

如何利用 Hamilton 编写更易维护的 Airflow DAG

Stefan KrawczykTowards Data Science Stefan Krawczyk

·

关注 发表在 Towards Data Science ·8 分钟阅读·2023 年 7 月 5 日

--

一个抽象的表示,展示了 Airflow 和 Hamilton 之间的关系。Airflow 帮助将一切整合起来,而 Hamilton 则帮助管理内部细节。图片来自 Pixabay

本文与 Thierry Jean 合作撰写,并最初发布于 此处

本文介绍了两个开源项目,HamiltonAirflow,以及它们的有向无环图(DAGs)如何协同工作。在高层次上,Airflow 负责编排(宏观),而 Hamilton 帮助编写干净和可维护的数据转换(微观)。

对于不熟悉 Hamilton 的人,我们推荐你查看tryhamilton.dev上的互动概述,或我们的其他帖子,例如这篇文章。否则,我们将高层次地讨论 Hamilton,并指向参考文档以获取更多细节。作为参考,我是 Hamilton 的共同创作者之一。

对于仍在尝试理解这两者如何协同运行的人来说,你可以与 Airflow 一起运行 Hamilton 的原因是 Hamilton 只是一个依赖性小的库,因此可以迅速将 Hamilton 添加到你的 Airflow 设置中!

仅作总结,Airflow 是编排数据管道的行业标准。它支持各种数据计划,包括 ETL、机器学习管道和商业智能。自 2014 年首次推出以来,Airflow 用户在编写和维护数据管道方面面临一些挑战:

  1. 可维护地管理工作流的演变;起初简单的工作流往往会变得复杂。

  2. 编写模块化、可重用和可测试的代码,以便在 Airflow 任务中运行。

  3. 跟踪 Airflow DAG 生成的代码和数据工件的血统。

我们相信Hamilton可以提供帮助!Hamilton是一个用于编写数据转换的 Python 微框架。简而言之,用户以“声明性”风格编写 python 函数,Hamilton 会根据函数的名称、参数和类型注释解析这些函数,并将它们连接成一个图。可以请求特定的输出,Hamilton 将执行所需的函数路径以生成这些输出。由于它不提供宏观编排能力,因此与 Airflow 很好地配合,帮助数据专业人员编写更干净、更可重用的 Airflow DAG 代码。

Hamilton 范式的示意图。此示例展示了如何将过程性 pandas 代码映射到定义 DAG 的 Hamilton 函数。注意:Hamilton 可以用于任何 Python 对象类型,而不仅仅是 Pandas。图片由作者提供。

编写可维护的 Airflow DAG

Airflow 的一个常见用途是帮助进行机器学习/数据科学。生产中运行这种工作负载通常需要复杂的工作流。使用 Airflow 的一个必要设计决策是确定如何将工作流拆分为 Airflow 任务。创建太多任务会增加调度和执行开销(例如,移动大量数据),创建太少任务则会有大型任务可能需要较长时间运行,但可能更高效。这里的权衡是 Airflow DAG 的复杂性与每个任务中的代码复杂性。这使得调试和理解工作流变得更加困难,尤其是当你没有编写初始 Airflow DAG 时。往往,Airflow DAG 的初始任务结构变得固定,因为重构任务代码变得非常困难!

尽管像 A->B->C 这样的简单 DAG 是理想的,但结构的简单性和每个任务的代码量之间存在固有的紧张关系。每个任务的代码量越多,识别故障点就越困难,这可能会影响计算效率,但在故障的情况下,重试的成本随着任务的“大小”增长。

Airflow DAG 结构选择:任务数量?每个任务的代码量?图片作者。

相反,如果你可以同时处理 Airflow 任务中的复杂性,无论其中的代码量多大,并且能够以最小的努力轻松更改 Airflow DAG 的形状呢?这就是 Hamilton 的作用。

使用 Hamilton,你可以用 Hamilton DAG 替换每个 Airflow 任务中的代码,Hamilton 处理任务内代码的“微观”编排。注意:Hamilton 实际上使你能够逻辑上建模你希望 Airflow DAG 做的所有事情。更多内容见下文。

要使用 Hamilton,你需要加载一个包含 Hamilton 函数的 Python 模块,实例化一个 Hamilton Driver,然后在 Airflow 任务中执行一个 Hamilton DAG,只需几行代码。通过使用 Hamilton,你可以以任意粒度编写数据转换,允许你更详细地检查每个 Airflow 任务的操作。

具体的代码机制是:

  1. 导入你的函数模块

  2. 将它们传递给 Hamilton 驱动程序以构建 DAG。

  3. 然后,调用 Driver.execute(),传入你希望从你定义的 DAG 中执行的输出。

让我们看一些代码,这些代码创建了一个单节点的 Airflow DAG,但使用 Hamilton 来训练和评估 ML 模型:

现在,我们这里没有展示 Hamilton 代码,但这种方法的好处是:

  1. 单元和集成测试。 Hamilton 通过其命名和类型注解要求,推动开发人员编写模块化的 Python 代码。这使得 Python 模块非常适合进行单元测试。一旦你的 Python 代码经过单元测试,你可以开发集成测试以确保它在 Airflow 任务中正常工作。相比之下,测试包含在 Airflow 任务中的代码并不简单,特别是在 CI/CD 环境中,因为这需要访问 Airflow 环境。

  2. 重用数据转换。 这种方法将数据转换代码保留在 Python 模块中,与 Airflow DAG 文件分开。这意味着这些代码也可以在 Airflow 外部 运行!如果你来自分析领域,这种方法应该类似于在外部 .sql 文件中开发和测试 SQL 查询,然后将其加载到 Airflow Postgres 操作员中。

  3. 轻松重新组织你的 Airflow DAG。 更改 Airflow DAG 的难度现在大大降低。如果你在 Hamilton 中逻辑建模一切,例如端到端的机器学习管道,只需确定这个 Hamilton DAG 中有多少内容需要在每个 Airflow 任务中计算。例如,你可以将任务数量从一个庞大的 Airflow 任务更改为几个或许多——只需调整你从 Hamilton 请求的内容,因为你的 Hamilton DAG 完全不需要更改!

使用 Hamilton 和 Airflow 的迭代开发

在大多数数据科学项目中,从第一天起就编写最终系统的 DAG 几乎是不可能的,因为需求会发生变化。例如,数据科学团队可能希望尝试不同的特征集。在特征集确定和最终确定之前,将其包含在源代码中并进行版本控制可能并不理想;配置文件会更好。

Airflow 支持默认和运行时的 DAG 配置,并会记录这些设置以确保每次 DAG 运行都是可重复的。然而,添加可配置的行为需要在你的 Airflow 任务代码中添加条件语句和复杂性。这段代码可能在项目过程中变得过时,或仅在特定场景下有用,最终降低你的 DAG 可读性。

相比之下,Hamilton 可以使用 Airflow 的运行时配置动态执行函数图中的不同数据转换。这种分层方法可以大大增加 Airflow DAG 的表达能力,同时保持结构上的简单性。或者,Airflow 可以从配置中 动态生成新的 DAG,但这可能会降低可观察性,而且这些功能仍然是实验性的。

Airflow UI 的 DAG 运行配置。图像由作者提供。

如果你在一个交接模型中工作,这种方法促进了数据工程师(负责 Airflow 生产系统)与数据科学家(负责编写 Hamilton 代码以开发业务解决方案)之间的关注点分离。这样的分离还可以提高数据一致性,减少代码重复。例如,可以用不同的 Hamilton 模块重用单个 Airflow DAG 以创建不同的模型。类似地,相同的 Hamilton 数据转换可以在不同的 Airflow DAG 中重用,以支持仪表板、API、应用程序等。

以下是两张图片。第一张展示了包含两个节点的高层次 Airflow DAG。第二张展示了在 Airflow 任务 train_and_evaluate_model 中导入的 Python 模块 evaluate_model 的低层次 Hamilton DAG。

1. Airflow UI:缺勤 Airflow DAG

2. Hamilton 驱动器可视化:evaluate_model.py 的函数图

处理数据工件

数据科学项目会产生大量的数据工件,包括数据集、性能评估、图表、训练模型等。在项目生命周期(数据探索、模型优化、生产调试等)中所需的工件会有所变化。这对于 Airflow 来说是个问题,因为从 DAG 中删除任务会删除其元数据历史记录,并破坏工件的传承。在某些情况下,生成不必要或冗余的数据工件可能会产生显著的计算和存储成本。

Hamilton 可以通过其 data saver API 提供生成数据工件所需的灵活性。被 @save_to.* 装饰的函数增加了存储其输出的可能性,只需通过 Driver.execute() 请求此功能。下面的代码中,调用 validation_predictions_table 将返回表格,而调用 save_validation_predictionsoutput_name_ 值将返回表格并将其保存为 .csv

这种附加的灵活性使用户可以轻松切换生成的工件,并且可以直接通过 Airflow 运行时配置完成,无需编辑 Airflow DAG 或 Hamilton 模块。

此外,细粒度的 Hamilton 函数图允许精确的数据传承与来源追踪。工具函数 what_is_downstream_of()what_is_upstream_of() 有助于可视化和编程探索数据依赖关系。我们指引感兴趣的读者了解更多详情 这里

完成与示例入门

希望到现在为止我们已经让你印象深刻,结合 Hamilton 和 Airflow 将帮助你解决 Airflow 的 DAG 创建与维护挑战。由于这是篇短文,最后,让我们看看在Hamilton 仓库中的代码。

为了帮助你快速上手,我们提供了一个示例,展示如何将 Hamilton 与 Airflow 一起使用。它涵盖了你需要了解的所有基础知识。README 包括如何使用 Docker 设置 Airflow,这样你就不需要担心仅仅为了玩这个示例而安装依赖项。

关于示例中的代码,它包含两个 Airflow DAG,一个展示了一个基本 Hamilton “使用方法”用于创建用于训练模型的“特征”,另一个是一个更完整的机器学习项目示例,完成了创建特征、拟合和评估模型的完整端到端流程。对于这两个示例,你可以在插件文件夹下找到 Hamilton 代码。

你应该在 Airflow 示例中看到的内容。图片由作者提供。

如果你有问题或需要帮助——请加入我们的Slack。否则,若要了解更多关于 Hamilton 的功能和特点,请参考 Hamilton 的文档

参考资料与进一步阅读

感谢你查看这篇文章。如果你想深入了解,或想了解更多关于 Hamilton 的信息,我们提供了以下链接供你浏览!

使用 BigQuery SQL 用户定义函数简化数据清洗

原文:towardsdatascience.com/simplify-data-cleaning-with-bigquery-sql-user-defined-functions-41c0560ea6

介绍及用例

Vicky YuTowards Data Science Vicky Yu

·发表于 Towards Data Science ·阅读时间 5 分钟·2023 年 4 月 20 日

--

图片由 Brooke Cagle 提供,来源于 Unsplash

数据清洗是任何与数据相关的工作中占比很大的部分,但编写 SQL 语句往往是乏味的,尤其是在表中的多个列上编写相同的 SQL 逻辑。直到我发现可以创建 BigQuery 中的用户定义函数 (UDFs) 来满足我特定的数据清洗用例。今天,我想分享一些数据清洗用例,在这些用例中,你可以应用 UDF 来简化你的 SQL 查询。

介绍

由于不同公司对数据库权限的管理有所不同,我将讨论使用临时 UDF 的数据清洗示例,因为永久性 UDF 可能需要数据库管理员不允许的额外访问权限。临时 UDF 在 SQL 查询完成时会过期,而持久性 UDF 会保存在数据库中,可以在多个 SQL 查询中使用。

我将使用我创建并上传到 BigQuery 沙箱 的虚拟电影数据,该沙箱对任何拥有 Google 账户的人都是免费的。我几年前在一次数据分析师面试的家庭作业中收到了类似的数据,并将使用在作业中执行的数据清洗示例,但这次使用 UDFs。

用例 1:用于报告的值分组

我开始时是按年份统计电影数量,但这并没有什么用,因为在 1980 年之前有许多电影的数量不到 5 部。我决定改为按十年分组电影,以更好地了解电影的频率分布。

在下面的临时 UDF ReleaseYearCategory 中(第 1 到 10 行),第 3 到 8 行的 CASE 语句根据 release_year 字段将电影分为 5 类。注意我在第 3 到 7 行的 CASE 语句前缀,即 1. < 1980。数字前缀将强制 release_year_category 从较早的年代到最近的年代排序。

虽然这是一次性的任务,但使用 UDF 仍然有许多优点。

  1. release_year 字段是字符串,但需要是数字以进行日期范围检查。与其在每次引用 release_year 字段时进行类型转换,不如一次性将 cast(release_year as int) 传递给 UDF,然后 field 变量将被替换为 cast(release_year as int).

  2. 如果 release_year 字段被更改为整数类型,我只需在调用 ReleaseYearCategory UDF 时去掉 cast 语句。

  3. UDF 是可重用的(假设它被保存为持久性 UDF)。如果我想将相同的年份分组逻辑应用于另一张表,只需将不同的字段名称传递给 UDF。

  4. 如果我想按 5 年的增量分组,而不是按十年分组,我只需修改一个 UDF,而不是更改多个 SQL 语句。

作者创建的临时 UDF ReleaseYearCategory 示例截图

将电影分成 5 个类别显示大多数电影是在 2000 年后发布的。如果在按字段分组时数据行过多,考虑像我上面做的那样将数据折叠到更少的行中。一个常见的例子是按周或月而不是按日分组数据。

用例 2:将字符串值转换为数字

我想查看按每部电影分配的类别数量来统计电影数量。一种简单的方法是对类别字段求和并以此作为分组字段,但类别字段必须从 TRUE 转换为 1,从 FALSE 转换为 0。例如,在下面的数据示例中,第 1 行的 Movie Title 355 将累加为 3,因为动作、冒险和科幻类别字段的 TRUE 值将被转换为 1。

作者创建的具有 TRUE 和 FALSE 值的电影类别字段的截图示例

UDF 使得编码转换变得更容易,因为我不需要为每个类别字段输入 CASE 语句。我只需将字段名称传递给 UDF。在下面的 ConvertTrueFalse UDF 中(第 1 到 8 行),我在第 6 行有一个 ELSE -1 语句,以捕捉任何不匹配 TRUEFALSE 预期值的值。由于我之前确认了类别字段中只有 TRUE 或 FALSE 值,因此 ELSE 不是必需的,但作为最佳实践,您可以添加一个 ELSE 语句以防映射到意外值。例如,如果类别字段为 NULL,我将会在没有 ELSE 语句的情况下将其设置为 0。

我还在第 4 和第 5 行中添加了 UPPER 函数,以防TRUEFALSE被拼写为混合大小写,如Truefalse。在处理混合大小写字符串时,添加 UPPER 函数是一种良好的做法。如果字段的值为True,如果没有 UPPER 函数,我会将其设置为 0,这会导致分析错误。尽管这是一次性的任务,但可以看到第 11 到第 15 行中的 UDF 调用减少了 SQL 代码,使其更易于阅读。

作者创建的临时 UDF ConvertTrueFalse 示例截图

用例 3:计算收益

我想通过类型来分析每部电影的收益,以了解哪些类型的电影更有利可图。为了计算收益,我需要使用类型为字符串的movie_grossmovie_budget字段。为了避免多次编写 cast 语句,我只需将 cast 语句传递给第 10 行中显示的CalcReturn UDF。

虽然这是一次性计算,但如果你经常需要计算收益或其他常见计算,考虑使用 UDF 来简化编码。

作者创建的临时 UDF CalcReturn 示例截图

最后想法

虽然数据清理可能不是你作为数据专业人士最喜欢的活动,但我希望你能看到创建 UDF(用户自定义函数)如何帮助简化你的 SQL 编码。

虽然我只讨论了临时 UDF,但将经常使用的 SQL 逻辑保存为持久 UDF 可以帮助集中代码,并允许在 SQL 查询中重复使用。这可能需要与你的数据库管理员讨论,以了解关于 UDF 创建和 SQL 用户使用的数据库权限。查看文档也会有帮助,以查看 UDF 代码和使用说明。

我已简要介绍了 UDF,并强烈建议你查看 文档 以了解更多信息。

注意:上述所有查询均在 BigQuery sandbox 上运行,任何拥有 Google 帐号的人都可以免费使用。

你可能还会喜欢…

## BigQuery SQL 数据清理函数

使用案例和应用函数

[towardsdatascience.com ## BigQuery SQL 过程语言简化数据工程

简介

[towardsdatascience.com ## 每个用户都应该知道的 6 个 BigQuery SQL 函数

检查你的数据库是否也包含这些函数

[towardsdatascience.com

使用 Fugue 和 Python 简化 BigQuery 上的数据科学工作流

原文:towardsdatascience.com/simplify-data-science-workflows-on-bigquery-with-fugue-and-python-5215a1b65e43

加快迭代速度并降低计算成本

Khuyen TranTowards Data Science Khuyen Tran

·发表在 Towards Data Science ·阅读时间 6 分钟·2023 年 4 月 13 日

--

动机

许多数据团队开始时会在数据仓库如 BigQuery 上建立分析实践。然而,仅依赖 BigQuery 处理数据科学工作负载可能不是最佳方法,原因有很多:

  • SQL 之外的高级需求:如数据验证、可视化和机器学习预测等用例可能需要超出 SQL 语法限制的更高级功能。

  • 探索成本高:由于 BigQuery 的迭代特性,它可能不是数据科学任务中最具成本效益的解决方案,这涉及到大量的特征工程和算法实验。

对于在 BigQuery 上处理数据的数据科学家而言,一个理想的解决方案应能够让他们:

  • 使用 SQL 和 Python 查询 BigQuery 中的数据。

  • 互动地在本地测试各种 SQL 查询

  • 在彻底测试后轻松切换回 BigQuery。

作者提供的图片

FugueSQL BigQuery 集成允许你做到这一点。

Fugue 是什么?

Fugue 是一个开源项目,通过将 Python、Pandas 和 SQL 移植到如 Spark、Dask 和 Ray 等分布式计算引擎上,从而简化大数据开发。

## 介绍 FugueSQL — Pandas、Spark 和 Dask 数据框的 SQL

面向数据科学和分析的端到端 SQL 界面

towardsdatascience.com

在本文中,我们将使用 FugueSQL 来简化在 BigQuery 上的开发。

欢迎在此处自由试验和分叉本文的源代码:

[## Data-science/fugue_bigquery.ipynb at master · khuyentran1401/Data-science

目前无法执行该操作。你在另一个标签或窗口中登录。你在另一个标签或窗口中登出……

github.com](https://github.com/khuyentran1401/Data-science/blob/master/data_science_tools/fugue_bigquery.ipynb?source=post_page-----5215a1b65e43--------------------------------)

安装和设置

安装 Fugue BigQuery

要安装 Fugue BigQuery 集成,请输入:

pip install fugue-warehouses[bigquery]

对 Google BigQuery 进行身份验证

要对 Google BigQuery 进行身份验证,标准方法是使用 GOOGLE_APPLICATION_CREDENTIALS 环境变量指定凭证 JSON 文件的位置。

import os  

os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/your/credential.json'

如何获取凭证 JSON 文件。

Fugue BigQuery 使用

查询表格

开始时,设置 FugueSQL 以便在 Jupyter Notebook 或 Lab 中使用:

from fugue_jupyter import setup
setup(run_js=True)

要使用 FugueSQL BigQuery 集成运行 SQL 查询,只需在单元格的开头添加 %%fsql bq

作者提供的图片

参数化

FugueSQL 允许你使用 Jinja 模板对 SQL 查询进行参数化。

以下查询对 table 变量进行参数化,这在从开发表过渡到生产表时特别有用。

作者提供的图片

拆分查询

Fugue 包含了多个改进的标准 SQL 功能,这些功能有助于查询拆分。以下查询展示了这些增强功能:

  • 等号将查询的输出分配给 df 变量,该变量随后在另一个操作中被重用。

  • TAKE 返回整行数据。PREPARTION BY 按性别对数据进行分区。PRESORTct 列以降序排序数据。

  • YIELD 使 DataFrame 可用于后续的 Jupyter Notebook 单元格。

作者提供的图片

我们现在可以访问 names 变量,并在另一个 Jupyter Notebook 单元格中执行额外的查询。

作者提供的图片

FugueSQL Python 扩展

TRANSFORM

FugueSQL 允许你通过使用 TRANSFORM 关键字 将 Python 函数集成到 SQL 查询中。

以下查询演示了如何使用 TRANSFORMget_decade 函数应用于数据,以生成名为 decade 的新列。

作者提供的图片

使用 TRANSFORM,数据科学家现在可以对数据进行特征工程,并在 BigQuery 上利用新的机器学习模型。

OUTPUT

FugueSQL 还提供了内置扩展,方便与其他绘图库(如 seaborn)集成。

在下面的代码中,OUPUT USING sns:lineplot() 将查询结果导入 pandas,然后使用 seaborn 生成线图。

作者提供的图片

生产化 SQL 查询

为了将笔记本过渡到生产环境,我们可以消除中间的 YIELD 语句,并将查询作为字符串传递给 fugue_sql() 函数。

import fugue.api as fa
res = fa.fugue_sql("""
SELECT name, gender, SUM(number) AS ct
  FROM `{{table}}`
 GROUP BY name, gender

names = TAKE {{n}} ROWS PREPARTITION BY gender PRESORT ct DESC

SELECT name, year, SUM(number) AS ct
  FROM `{{table}}`
 WHERE name IN (SELECT name FROM names)
 GROUP BY name, year
 ORDER BY year
""", engine="bq", table=table, n=n)

然后我们可以使用 as_pandas() 函数将输出转换为 pandas 以进行进一步分析。

作者提供的图片

迭代大数据

FugueSQL 提供了使用 SAMPLE 关键字将 BigQuery 表采样到较小的数据框的功能。这加快了迭代过程,避免了每次都需要处理完整数据集。

在这里,YIELD 关键字再次被使用,以便使 test 数据框可用。

作者提供的图片

在下一个单元格中,我们可以使用 FugueSQL 和 DuckDB 后端在 test 数据框上测试查询,以加速代码。

[## Fugue 和 DuckDB:Python 中的快速 SQL 代码]

使用 Python 和 DuckDB 优化 SQL 代码

towardsdatascience.com](/fugue-and-duckdb-fast-sql-code-in-python-e2e2dfc0f8eb?source=post_page-----5215a1b65e43--------------------------------)

作者提供的图片

一旦查询经过彻底测试,切换引擎到 bq 是一项简单的任务。

作者提供的图片

将 BigQuery 与 Spark、Dask 或 Ray 结合使用

如果你在 BigQuery 中处理大量数据,单台机器可能处理速度过慢。Fugue 提供了将 BigQuery 与分布式计算框架(如 Spark、Dask 和 Ray)集成的便捷方式。

在下面的查询中,transform() 函数以分布式方式将 median 函数应用到 Dask 数据框的每个分区。

import pandas as pd
from typing import List, Any

# schema: *
def median(df:pd.DataFrame) -> List[List[Any]]:
    return [[df.state.iloc[0], df.number.median()]]

fa.transform(
    ("bq", """SELECT state, number
    FROM `bigquery-public-data.usa_names.usa_1910_2013` TABLESAMPLE SYSTEM (1 PERCENT)"""),
    median,
    partition="state",
    engine="dask"
).compute().head()

作者提供的图片

当查询执行时,数据会自动持久化到临时数据集。默认情况下,数据集名为 FUGUE_TEMP_DATASET,必须在 BigQuery 工作区中创建,如下所示。

作者提供的图片

结论

恭喜你!你刚刚学会了如何使用 FugueSQL 快速迭代 BigQuery 中的数据。Fugue 的 Python 和 SQL 互操作性提供了无摩擦的开发者体验,且代码冗余最小。

我喜欢撰写数据科学概念和玩各种数据科学工具。你可以在LinkedInTwitter上与我联系。

如果你想查看我写的文章的代码,可以给这个仓库加星。关注我在 Medium 上,以便及时了解我最新的数据科学文章:

[## 使用 Pandera 验证你的 pandas DataFrame

确保你的数据符合预期

[## 使用管道编写清晰的 Python 代码

处理可迭代对象的简短而清晰的方法

[## 当 Python 文件发生更改时自动重新加载的 2 个工具

[## 使用 Hex 创建可观察和可重复的笔记本

如何将笔记本集成到你的数据管道中

[## 创建可观察和可重复的 Hex 笔记本

简化文件共享

原文:towardsdatascience.com/simplify-file-sharing-44bde79a8a18?source=collection_archive---------4-----------------------#2023-09-12

使用 Python 操作 Google Drive 的共享文件夹的编码示例

Gijs van den DoolTowards Data Science Gijs van den Dool

·

关注 发表在 Towards Data Science ·13 分钟阅读·2023 年 9 月 12 日

--

最近,又出现了数据共享问题,我认为这是设计一个处理共享文件夹的方法的好时机。我作为一名独立的地理信息科学专家,常常与多个组织同时合作。在我的项目中,我注意到每个组织在处理数据时都有其独特的方法,这些方法受到其特定文化和工作伦理的影响,导致出现了多种多样的方法论。幸运的是,它们之间有一些共同的做法,其中之一就是使用基于云的数据管理系统,通常是 Google,但也可以是 Microsoft 的 One-Drive 或 Dropbox。

在这篇文章中,我将解释如何在 Google 生态系统中使用 Python 操作共享文件夹。

Annie Spratt 提供的照片,来源于 Unsplash

用例

在本地计算机上管理文件是非常个性化的,在组织中工作时(希望)能够标准化,或者至少有一些标准化。系统间共享文件可能会很复杂,但当你没有直接访问生产文件夹的权限时,使用共享文件夹是一个选项,组织可以与你共享一个专门指定的工作文件夹来交换文件。在这个例子中,组织已经授权访问他们 Google Drive 仓库中的一个名为 DATA 的文件夹,并且已经同意我们可以使用这个文件夹来交换文件。

本地文件管理

简单解释一下,对于不熟悉 Google Drive 文件共享的人,这个过程从收到一封邀请你贡献特定文件夹的电子邮件开始;请参见下面的邀请(左侧)。在邀请中有一个按钮,点击后会打开一个带有 Google Drive 界面的网页浏览器(右侧),该界面与接收邀请的 Google 邮箱相关联。

图 1:创建共享文件夹(作者提供的图片)

界面中隐藏了一些重要信息,尽早了解这些信息将有助于完成后续过程。

  • 在 URL(屏幕顶部)中,有一个隐藏的 ID,这就是 Google 用来跟踪对该文件夹的所有操作的 ID,这也是我们稍后在 Python 代码中获取的 ID。

  • 然后它会显示:“与我共享”以及共享文件夹的名称;这点很重要,因为当我们将 Google Drive 挂载到 CoLab 笔记本时,我们会发现这个类别不可用。

  • 最后,我们看到 Data 下的文件和文件夹;这意味着我们可以访问所需的信息,并向文件夹中添加新文件。不过,文件夹的安全设置可能存在问题,因此在此阶段的一个好测试是创建一个小文本文件,并拖放到“ExternalData”文件夹中,以验证你是否拥有完全访问权限。

为了使“与我共享”的文件夹可访问,我们需要将该文件夹链接到本地/个人驱动器。我们可以通过创建快捷方式来实现,但这是一个手动步骤,每个人都会有所不同。要在 Google Colab 中访问与你共享的文件夹或文件,你需要:

  1. 转到 Google Drive 中的“与我共享”。

  2. 选择你想访问的文件夹或文件。

  3. 右键点击它并选择“将快捷方式添加到驱动器”,会出现一个弹出窗口,选择“MyDrive”,然后点击“添加快捷方式”。

  4. 将快捷方式放在驱动器上一个容易找到的位置;在我使用的设置中,快捷方式的位置是“__Shared”,确保包含快捷方式的文件夹在“MyDrive”下的文件夹列表顶部,然后是组织的子目录。

  5. 将快捷方式重命名为有意义的名称;在此示例中,我使用了“DataDevelopement”。文件位置和名称约定非常个人化,程序不关心文件存储的位置或名称,但有一定结构可以避免以后的一些麻烦。

图 2:创建快捷方式(图片由作者提供)

在本地文件系统组织好,并配置好个人 Google Drive 后,我们可以尝试在 Python 笔记本中使用这个共享文件夹,并自动化项目中的文件共享。

安装

本项目基于 Google Colab 或“Collaboratory”笔记本,我将在本文底部分享。使用这个环境的优点是它允许你在浏览器中编写和执行 Python,且

  • 不需要配置

  • 免费访问 GPU

  • 简单共享

在与拥有内部程序的组织合作时,这些是非常重要的点,因为作为外部协作者,你通常无法直接访问代码库(这可能有许多不同的原因,从安全问题到项目管理限制)。Colab 笔记本是 Google 生态系统的一部分,并且(作为附加优势)创建了一个运行时环境,提供了挂载个人 Google 驱动器(用于文件共享)的选项。

导入模块和包

在此示例中,只加载了运行时所需的必要包,我们需要一些特定的库来处理共享驱动器。

Google 授权

from oauth2client.client import GoogleCredentials

from google.colab import auth as google_auth
google_auth.authenticate_user()

from google.colab import drive
drive.mount('/content/gdrive')

使用 oauth2client 和 Google Credentials 会使文件操作更加轻松。有其他替代方案,如下载带有凭据的 JSON 文件,在某些情况下,使用 JSON 文件可能会优于使用 Google Credentials,但由于这是一个不涉及敏感数据的项目,使用 oauth2client 库已经提供了足够的保护。

pydrive

from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive

gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

pydrive 是 google-api-python-client 的一个封装库,它简化了许多常见的 Google Drive API 任务,其中之一是处理查询 Google Drive 文件系统时的响应。Google Drive 通过 ID 存储所有对象,这些 ID 通过对象中的关系信息进行链接。通过 API 可以访问这些信息(见下一个代码块),但封装库在我们使用 Files.list() 的参数作为 dict 创建 GoogleDriveFileList 实例时,处理了所有繁重的工作。调用 GetList() 会以 GoogleDriveFile 的列表形式获取所有匹配查询的文件。

Google API 客户端

# Google API client:
from googleapiclient.discovery import build

# Initialize the Google Drive API client
drive_service = build('drive', 'v3')

Google API 客户端是一个大型库,功能众多,但在这个项目中,我们只需要一个模块:build。build 模块构造一个资源对象用于与 API 交互,并返回与服务交互的方法。pydrive 库能够很好地处理基本功能,如创建、更新和删除文件,但在这个项目中,有些时候我们需要更高级的功能,访问“服务”使我们能够提取 pydrive 方法未捕获的信息。

这就完成了笔记本的配置。在这个例子中,我们不需要比加载的库更多的库进行文件管理,加载库后,我们可以查看它们的作用。

笔记本中的文件管理

到目前为止,发生了一些事情:

  • Google 授权已设置,

  • 我们创建了对驱动器的访问(用于读/写访问),并且

  • Pydrive 包可用于在驱动器上进行导航

希望在你跟随并运行代码时,你会在刷新面板后看到右侧的图片。你可以在图像中看到“__Shared”下的快捷方式作为一个文件夹,我们没有看到“Shared with me”部分,但由于我们有快捷方式,所以不需要查看“Shared with me”文件。

图 3:Google Colab 网络界面中运行时环境的未挂载状态与挂载状态(图片由作者提供)

Google Drive 的工作方式与本地操作系统中的文件管理不同,文件的物理位置不重要,因为这些对象由 ID 管理,存储在非结构化的数据湖中,我们可以通过 ID 访问文件和文件夹。

不幸的是,虽然 os.path(在 Python 中)有用于遍历文件系统的 walk 函数,但 Google Drive 中没有类似的方法(或者我对此方法不知情)。不过,我们可以使用 pydrive 库,手动遍历目录树中的文件夹,幸运的是,我们知道从文件夹的路径要去哪里。因此,我们不需要遍历整个结构,而是可以使用数据路径中的文件夹名称深入文件夹树。

因此,我们遍历小列表(在这个例子中,有三个项目)来找到 ID 并使用这个 ID 进入下一个级别。请注意第四级被注释掉了;我们将在笔记本的文件处理部分的第二部分中到达这个级别。

# File handling testing:
# There are in this example three folder levels:
# /content/gdrive/MyDrive/__Shared/<your Project>/DataDevelopment

# Update these to your structure:
folderList1 = ["__Shared", your_Project ,"DataDevelopment"] #, "ExternalData"]

下面的代码块中的循环从根目录开始,当它在列表中找到一个项目时,循环将使用对象的 ID 进入列表的下一个级别,如果没有找到项目,代码将提示找不到文件夹,并且不会在结构中更深的地方查找任何文件夹。循环以快捷方式文件夹的 ID 或提示找不到文件夹结束。

# Trying to copy the created dummy file:
boo_foundFolder = False
fileID = "root"
level = 0

# View all folders and file in your Google Drive
# First loop over the list:
print("File and Folder structure - check with IDs")
for folderName in folderList1:
  print(f"Checking: {folderName}")

  if boo_foundFolder or fileID == "root": #first run
    boo_foundFolder = False

    fileList = drive.ListFile({'q': f"'{fileID}' in parents and trashed=false"}).GetList()
    for file in fileList:
      # Testing the name:      
      if(file['title'] == folderName):
        fileID = file['id']
        boo_foundFolder = True

        level += 1 
      # end if
    # end for

    if boo_foundFolder == False:
      print(f"folder not found")
      break
    # end if    
  # end if
# end for

print(f"Did we find the folder: {boo_foundFolder}")
if boo_foundFolder:
  print(fileID)
  ShortCutID = fileID
else:
  ShortCutID = 0

目前,我们已经获得了工作文件夹的本地文件 ID,但在我们能够在这个位置查找文件之前,我们需要将这个本地 ID 与共享文件夹的目标 ID 匹配。为了找到这些信息,我们必须深入 Google 基础设施,为此我们需要一个助手:drive_service。我们在加载项目时激活了这个助手,并且没有收到警告,这意味着我们可以通过 API 访问服务,并通过 ID 请求信息。

我们需要的详细信息最好通过一个简单的函数来收集,例如下一段代码块中的 findTargetID 函数。在这个函数中,fileID 是我们通过遍历文件夹中的名称找到的快捷方式 ID,通过调用 drive_service.files().get 并指定字段,我们可以获得文件夹的目标 ID(这将与 Google Drive Web 界面 URL 中的 ID 相同(见图 1))。

def findTargetID(fileID, drive_service):
  # The ID of the shared file you want to get ShortcutDetails from
  file_id = fileID

  try:
      # Get the file details
      file = drive_service.files().get(fileId=file_id, 
                                      fields="id, shortcutDetails").execute()

      # Check if the file is a shortcut
      if 'shortcutDetails' in file:
          shortcut_details = file['shortcutDetails']
          print("Shortcut Details:")
          print(f"Target ID: {shortcut_details['targetId']}")
          print(f"Target MIME Type: {shortcut_details['targetMimeType']}")
      else:
          print("The file is not a shortcut.")
      # end if

  except Exception as e:
      print(f"An error occurred: {e}")

  return shortcut_details['targetId']

if boo_foundFolder:
  targetID = findTargetID(fileID, drive_service)
  print(targetID)
else:
  print("Folder not found")
# end if 

有了这个目标 ID,我们可以访问 Google 数据服务器上的实际共享文件夹,我们不再在快捷方式文件夹中工作。

总结一下,我们创建快捷方式文件夹的原因是为了能够在挂载的文件夹列表中查看文件夹。类别“与我共享”没有挂载,但快捷方式有。因此,我们可以使用这个新的 ID 查找文件。

查找文件

我们现在得到了所需的目标 ID,即在流程开始时与我们共享的文件夹的目标 ID,凭借这个 ID,所有正常的文件操作对我们都是可用的。

我们可以通过在运行时环境中首先创建一个小的文本文件来验证我们是否对共享文件夹拥有足够的权限;创建这个文件也确认了我们可以访问运行时环境,因为当文件正确创建时,它将出现在 CoLab 笔记本的 Web 界面的左侧面板中。

# Create a test file:
with open('example.txt', 'w') as f:
  f.write('This is an example file, to test CoLab file sharing')
# this file is now sitting in the runtime space of the notebook 
# (see left plane, under files)

现在的想法是将这个文件移动到“与我共享”的文件夹“Data”,我们在快捷方式中将其重命名为“DataDevelopment”,但前面的函数提供了 ,我们现在可以使用这个 ID 来检查我们刚刚在运行时环境中创建的文件是否在共享驱动器上可用。

if boo_foundFolder:
  print("folder found")
  folderID = targetID

  file_on_drive = False
  file_id = 0

  # check if the file is on the drive:
  fileList = drive.ListFile({'q': f"'{folderID}' in parents and trashed=false"}).GetList()
  for file in fileList:
    if(file['title'] == "example.txt"):
      file_on_drive = True
      fileID = file['id']
    # end if
  # end for

  if file_on_drive:
  #Overwrites the existing Google drive file."""
    file1 = drive.CreateFile({'id': fileID})
    strFileHandling = "Updated"

  else:
    file1 = drive.CreateFile({"mimeType": "text/csv",
                             "parents": [{"kind": "drive#fileLink", 
                                        "id": folderID}]})
    strFileHandling = "Created"
  # end if

  # creating the binding to the file in the Run-Time environment:
  file1.SetContentFile("example.txt")

  # copying the file to the Google Drive:
  file1.Upload()

  print(f'{strFileHandling} file %s with mimeType %s' % (file1['title'], file1['mimeType']))

else:
  print("folder not found")
# end if

运行上述代码将创建一个新的共享文件夹中的文件,或者在找到文件时更新(覆盖)文件。

创建工作区

使用快捷方式 ID 查找目标 ID 的第二个原因是查找共享文件夹下的项目。如前所述,Google Drive 通过 ID 管理一切,快捷方式 ID 没有任何子项,因此使用此 ID 查找新项目将导致空列表。这可以通过在第一个文件夹列表中包含“ExternalData”文件夹名称来测试;第一个列表将找不到此文件夹。但是,重新启动以目标 ID 为起点的循环将找到此文件夹。

在下面的代码片段中,创建了一个新文件夹列表,使用“共享给我”文件夹名称下方的文件夹名称。虽然“ExternalData”文件夹已存在(见图 1),但“NewDataFolder”尚未创建。

# Update these to your structure:
# ... DataDevelopment/ExternalData/__CoLab_Notebook 

# Setting the Working Folder:
folderList2 = ["ExternalData", "NewDataFolder"]

我们可以使用与之前相同的循环结构,但现在不是从 ROOT 开始,而是从目标 ID 开始,循环将找到“ExternalData”文件夹,但找不到新数据文件夹。在共享为 gist 的笔记本中,此测试的代码以以下内容开始:

print("File and Folder structure - check with TARGET IDs")
boo_foundFolder = False
fileID = targetID

for folderName in folderList2:
  print(f"Checking: {folderName}")

使用第二个文件夹列表和 targetID 开始检查时,循环将报告没有“NewDataFolder”。

由于工作文件夹尚不存在,我们可以使用drive_service.files来创建这个新文件夹,并用同样的方法将所有需要从运行时环境转移到“共享给我”文件夹的文件也转移过去。

def create_folder_in_folder(folder_name,parent_folder_id, drive_service):

    file_metadata = {
    'name' : folder_name,
    'parents' : [parent_folder_id],
    'mimeType' : 'application/vnd.google-apps.folder'
    }

    file = drive_service.files().create(body=file_metadata, supportsAllDrives=True, 
                                  fields='id').execute()

    print ('Folder ID: %s' % file.get('id')) 
if WorkingFolderID == 0: 
  # fileID is the parent ID from the previous search
  create_folder_in_folder("NewDataFolder", fileID, drive_service)

主要收获: Google Drive 文件系统是基于 ID 的,所有对象都有 ID。“共享给我”的对象在 Google Colab 中不可用,但通过“快捷方式”可以访问它们,通过查找相关的目标 ID,我们可以直接在“共享给我”文件夹中工作,包括最初与我们共享的文件夹下的对象。

结论

在本文中,我们涵盖了与共享文件夹相关的一些基本方面,包括:

  1. 设置本地文件管理: 我们从接收贡献指定 Google Drive 目录的邀请开始,展示了如何构建本地文件系统以提高协作效率。

  2. 配置 Google Colab 以进行协作: 我们讨论了使用 Google Colab(一个协作 Python 环境)的优势,以及如何为项目协作进行设置。

  3. 导入必要的模块和包: 我们提供了导入基本模块和包的代码示例,包括 Google 授权、简化 Google Drive API 任务的 pydrive 和用于高级功能的 Google API 客户端。

  4. 笔记本中的文件管理: 你看到如何在 Google Colab 环境中管理文件,包括创建和移动文件,在本地环境和共享文件夹之间使用共享 ID 和目标 ID。

  5. 查找文件和创建工作区: 我们深入探讨了使用目标 ID 在共享文件夹中查找文件的过程,以及为你的项目创建新文件夹和工作区的方法。

我希望这篇关于组织间共享文件夹和文件的操作指南对你有所帮助,并且提供了一些关于如何在共享文件夹中操作文件和文件夹的见解。

感谢阅读,我希望这篇文章能帮助你解决问题或给你下一项目的灵感。

Google CoLab NoteBook 链接: gist

免责声明: 这个示例中使用的代码并未优化,而是为了说明过程(对改进代码的任何建议欢迎在托管此笔记本的 gitHub 页面提出)。

用这四个鲜为人知的 Scikit-Learn 类简化你的数据准备

原文:towardsdatascience.com/simplify-your-data-preparation-with-these-4-lesser-known-scikit-learn-classes-70270c94569f

忘掉 train_test_split:Pipeline、ColumnTransformer、FeatureUnion 和 FunctionTransformer 即使在使用 XGBoost 或 LGBM 时也是不可或缺的

Matt ChapmanTowards Data Science Matt Chapman

·发布于Towards Data Science ·10 分钟阅读·2023 年 6 月 1 日

--

图片由Victor提供,来源于Unsplash

数据准备被公认为数据科学中最不受欢迎的部分。然而,如果做得正确,它其实不必那么令人头疼。

尽管近年来由于 PyTorch、LightGBM 和 XGBoost 的迅猛崛起,scikit-learn 作为建模库的流行度有所下降,但它仍然是最佳的数据准备库之一。

我不仅仅是指那老掉牙的train_test_split。如果你愿意深入挖掘,你会发现一系列对高级数据准备技术非常有用的工具,这些工具与其他库如lightgbmxgboostcatboost完美兼容,用于后续建模。

在本文中,我将介绍四个 scikit-learn 类,这些类显著加快了我作为数据科学家日常工作中的数据准备流程。

1. Pipeline:无缝结合预处理步骤

Scikit-learn 的Pipeline类使你能够将不同的预处理器或模型组合成一个可调用的代码块:

作者提供的图片

管道可以由两种不同的东西组成:

  • 变换器:任何具有fit()transform()方法的对象。你可以把变换器看作是用于处理数据的对象,通常在数据准备工作流中会有多个变换器。例如,你可以使用一个变换器来填补缺失值,另一个来缩放特征或对分类变量进行独热编码。MinMaxScaler()SimpleImputer()OneHotEncoder()都是变换器的例子。

  • 估计器:在 scikit-learn 术语中,“估计器”通常指机器学习模型;即具有fit()predict()方法的对象。LinearRegression()RandomForestClassifier()是估计器的例子。

在一个管道中,你可以链式连接任意数量的变换器,使你能够顺序地应用不同的数据预处理步骤。如果愿意,你还可以在末尾添加一个估计器(ML 模型),以便利用新变换的数据进行预测,但这不是强制性的。

例如,你可以构建一个管道,首先用零填补缺失值,然后对变量进行独热编码:

作者提供的图像

或者,如果你想直接在管道中包含建模步骤,你可以构建一个管道,该管道用均值填补缺失值、缩放特征,然后使用RandomForestRegressor()进行预测:

作者提供的图像

使用 scikit-learn 构建管道是非常简单的。

为了说明这一点,我将首先加载一些数据并将其分为训练集和测试集。在这个例子中,我将使用糖尿病数据集,该数据集由 scikit-learn 提供,包含 442 名糖尿病患者的十个预测变量(年龄、性别、体质指数、平均血压以及六个血清测量值)和一个响应变量,表示这些预测变量记录一年后每位患者糖尿病的进展情况。

import pandas as pd
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split

# Load diabetes dataset into pandas DataFrames
X, y = load_diabetes(scaled=False, return_X_y=True, as_frame=True)

# Split into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
display(X_train.head())
display(y_train.head())

作者提供的图像。糖尿病数据集和 scikit-learn 在 BSD-3 许可证下发布

接下来,我们定义我们的Pipeline。目前,我将定义一个简单的预处理Pipeline,包括两个步骤——用均值填补缺失值,并重新缩放所有特征——且不会包含估计器/模型。然而,无论是否包含估计器,原理都是一样的。

from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import MinMaxScaler
from sklearn import set_config

# Return pandas DataFrames instead of numpy arrays
set_config(transform_output="pandas")

# Build pipeline
pipe = Pipeline(steps=[
    ('impute_mean', SimpleImputer(strategy='mean')),
    ('rescale', MinMaxScaler())
])

一旦我们定义了Pipeline,我们就“拟合”它到训练数据集,并用它来转换训练集和测试集:

# Fit the pipeline to the training data
pipe.fit(X_train)

# Transform data using the fitted pipeline
X_train_transformed = pipe.transform(X_train)
X_test_transformed = pipe.transform(X_test)

这将给我们两个预处理的数据集(X_train_transformedX_test_transformed),准备好进行建模或特征选择等后续步骤。

使用Pipeline来处理这些预处理步骤的优势有两个方面:

  1. 防止信息泄露:由于预处理器是拟合于训练数据集X_train的,因此在填补缺失值或创建独热编码特征时,测试集的信息不会“泄露”。

  2. 避免重复:如果我们不使用Pipeline来处理这些预处理步骤,我们最终会多次转换X_test数据集(每次我们想应用一个预处理步骤时)。在这种小规模下,重复可能看起来不是太糟糕。但在复杂的机器学习工作流中,你很容易增加到 5、10,甚至 20 个预处理步骤。使用Pipeline可以使这一过程变得简单,因为我们可以添加任意多的步骤,仍然只需对X_trainX_test进行一次转换:

preprocessor = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='mean')),
    ('scaler', MinMaxScaler()),
    ('step_3', ...),
    ('step_4', ...),
    ...,
    ('step_k', ...)
])

preprocessor.fit(X_train)

X_train_transformed = pipe.transform(X_train)
X_test_transformed = pipe.transform(X_test)

2. ColumnTransformer:对不同特征子集应用不同的转换器

在之前的示例中,我们对所有特征应用了相同的预处理步骤。但如果我们有异质的数据类型,并且希望对不同的特征应用不同的预处理器呢?例如,如果我们只想对数值特征进行缩放,或者如果我们想对分类特征进行独热编码?

这时ColumnTransformer派上用场了。ColumnTransformer允许你将不同的转换器应用于数组或 pandas DataFrame 的不同列。

在下面的代码中,我们首先定义了不同的列组,并且对于每个组,我们使用Pipeline来构建一个将作用于该特定组的预处理器。最后,我们将所有转换器链在一个ColumnTransformer中。

# This code will only work if you've already run the code in the previous sections

from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, FunctionTransformer, MinMaxScaler
from sklearn.impute import SimpleImputer

# Categorical columns transformer - (a) impute NAs with the mode, and (b) one-hot encode
categorical_features = ['sex']
categorical_transformer = Pipeline(steps=[
    ('impute_mode', SimpleImputer(strategy='mode')),
    ('ohe', OneHotEncoder(handle_unknown='ignore', sparse=False, drop='first')) # handle_unknown='ignore' ensures that any values not encountered in the training dataset are ignored (i.e. all ohe columns will be set to zero)
])

# Numerical columns transformer - (a) impute NAs with the mean, and (b) rescale
numerical_features = ['bp', 'bmi', 's1', 's2', 's3', 's4', 's5', 's6'] # All except 'age' and 'sex'
numerical_transformer = Pipeline(steps=[
    ('impute_mean', SimpleImputer(strategy='mean')),
    ('rescale', MinMaxScaler())
])

# Combine the individual transformers into a single ColumnTransformer
preprocessor = ColumnTransformer(

    # Chain together the individual transformers
    transformers = [
        ('categorical_transformer', categorical_transformer, categorical_features),
        ('numerical_transformer', numerical_transformer, numerical_features),
    ],

    # By default, columns which are not transformed by the ColumnTransformer 
    # will be dropped. By setting remainder='passthrough', we ensure that
    # these columns are retained, in their original form.
    remainder='passthrough',

    # Prefix feature names with the name of the transformer that generated them (optional)
    verbose_feature_names_out=True
)
# Get visual representation of the preprocessing/feature engineering pipeline
preprocessor

图片作者

要将ColumnTransformer应用于数据,我们使用与应用第一个Pipeline时相同的代码:

# Fit the preprocessor to the training data
preprocessor.fit(X_train)

# Transform data using the fitted preprocessor
X_train_transformed = preprocessor.transform(X_train)
X_test_transformed = preprocessor.transform(X_test)

3. FeatureUnion:并行应用多个转换器

PipelineColumnTransformer是非常棒的工具,但它们有一个显著的限制。你发现了吗?

它们只能顺序应用转换器。

换句话说,当你使用Pipeline/ColumnTransformer转换特征Column1时,scikit-learn 首先会将transformer_1应用于Column1,然后将transformer_2应用于Column1的转换版本,以此类推。这对于按顺序预处理数据是可以的(例如“首先填补缺失值,然后进行独热编码”),但在我们希望并行应用不同的预处理步骤时就不理想了(例如“同时从相同的基础列创建两个新特征”)。在这些情况下,使用标准的PipelineColumnTransformer就不够用了,因为一旦序列中的第一个转换器应用,Column1的原始“原始”值就会丢失。

如果我们想对相同的基础特征应用多个变换并行,我们需要使用另一个工具:FeatureUnion

我们可以将FeatureUnion视为一个工具,它创建了你底层数据的“副本”,并并行地对这些副本应用转换器,然后将结果拼接在一起。每个转换器都接收原始的底层数据,因此我们不会遇到顺序转换的问题。

要使用FeatureUnion,我们只需要添加几行代码:

# This code will only work if you've already run the code in the previous sections

from sklearn.pipeline import FeatureUnion
from sklearn.decomposition import PCA, TruncatedSVD

# Define a feature_union object which will create reduced-dimensionality features
union = FeatureUnion(transformer_list=[
    ("pca", PCA(n_components=1)),
    ("svd", TruncatedSVD(n_components=2))
])

# Adapt the numerical transformer so that it includes the FeatureUnion
numerical_features = ['bp', 'bmi', 's1', 's2', 's3', 's4', 's5', 's6'] # All except 'age' and 'sex'
numerical_transformer = Pipeline(steps=[
    ('impute_mean', SimpleImputer(strategy='mean')),
    ('rescale', MinMaxScaler()),
    ('reduce_dimensionality', union)
])

# Categorical columns transformer - same as above
categorical_features = ['sex']
categorical_transformer = Pipeline(steps=[
    ('impute_mode', SimpleImputer(strategy='mode')),
    ('ohe', OneHotEncoder(handle_unknown='ignore', sparse=False, drop='first')) # handle_unknown='ignore' ensures that any values not encountered in the training dataset are ignored (i.e. all ohe columns will be set to zero)
])

# Build the ColumnTransformer
preprocessor = ColumnTransformer(
    transformers = [
        ('categorical_transformer', categorical_transformer, categorical_features),
        ('numerical_transformer', numerical_transformer, numerical_features),
    ],
    remainder='passthrough',
    verbose_feature_names_out=True
)
preprocessor

作者提供的图片

在这个图示中,我们可以看到FeatureUnion步骤是并行应用的,而不是顺序应用的。就像之前一样,我们将preprocessor拟合到训练数据上,然后使用它来转换我们想用于建模/预测的任何数据集。

# Fit the preprocessor to the training data
preprocessor.fit(X_train)

# Transform data using the fitted preprocessor
X_train_transformed = preprocessor.transform(X_train)
X_test_transformed = preprocessor.transform(X_test)

4. FunctionTransformer:无缝集成特征工程

上面讨论的所有转换器和工具都使用 scikit-learn 中预构建的类来对数据进行标准转换(例如,缩放、独热编码、插补等)。

如果你想应用自定义函数——例如在特征工程期间——那么你需要使用FunctionTransformer。就个人而言,我非常喜欢这个类——它使得将自定义函数集成到Pipeline中变得非常简单,而无需从头编写新的转换器类。

创建FunctionTransformer非常简单。你可以按照标准的 Python 风格定义你的函数,然后创建一个管道。在这里,我定义了两个简单的函数:一个将两列相加,另一个将两列相减。

from sklearn.preprocessing import FunctionTransformer

def add_features(X):
    X['feature_1_2'] = X['feature_1'] + X['feature_2']
    return X

def subtract_features(X):
    X['feature_3_4'] = X['feature_3'] - X['feature_4']
    return X

# Put into a pipeline
feature_engineering = Pipeline(steps=[
    ('add_features', FunctionTransformer(add_features)),
    ('subtract_features', FunctionTransformer(subtract_features))
])

为了进一步简化,你可以在同一个函数中包含多个转换:

def add_subtract_features(X):
    X['feature_1_2'] = X['feature_1'] + X['feature_2'] # Add features
    X['feature_3_4'] = X['feature_3'] - X['feature_4'] # Subtract features
    return X

# Put into a pipeline
feature_engineering = Pipeline(steps=[
    ('add_subtract_features', FunctionTransformer(add_subtract_features)),
])

最后,将feature_engineering管道添加到我们之前定义的preprocessing管道中:

# Combine preprocessing and feature engineering in a single pipeline
pipe = Pipeline([
    ('preprocessing', preprocessor),
    ('feature_engineering', feature_engineering),
])

pipe

作者提供的图片

并使用这个新的管道对所有数据集应用相同的预处理/特征工程步骤:

# Fit the preprocessor to the training data
pipe.fit(X_train)

# Transform data using the fitted preprocessor
X_train_transformed = pipe.transform(X_train)
X_test_transformed = pipe.transform(X_test)

附加提示:保存你的管道以实现真正可重复的工作流程

在机器学习的企业应用中,单次使用模型或预处理工作流的情况非常少见。更常见的是,你需要定期每周/月重新运行模型,并为新数据生成新的预测。

在这些情况下,与其每次从头编写新的预处理管道,不如每次使用相同的管道。为此,在开发好管道后,使用joblib库保存管道,以便将来能够使用相同的转换来处理数据集:

import joblib

# Save pipeline
joblib.dump(pipe, "pipe.pkl")

# Assume that the below steps are applied in another notebook/script

# Load pipeline
pretrained_pipe = joblib.load("pipe.pkl")

# Apply pipeline to a new dataset, X_test_new
X_test_new_transformed = pretrained_pipe.transform(X_test_new)

结论

总结:

  • Pipeline提供了一种快速的方法来依次将不同的预处理转换器应用于你的数据

  • 使用ColumnTransformer是一种绝佳的方法,可以将不同的预处理步骤依次应用于不同的特征子集

  • FeatureUnion使你能够并行应用不同的预处理转换

  • FunctionTransformer提供了一种超级简单的方法来编写自定义特征工程函数,并将它们集成到你的管道中

如果你使用这些工具,我向你承诺它们会帮助你写出更优雅、可重复且符合 Python 风格的代码。你的机器学习工程师们会喜欢你的!

如果你喜欢这篇文章,关注我将对我意义重大。如果你想无限制地访问我所有的故事(以及 Medium.com 的其他内容),你可以通过我的推荐链接每月支付 $5 注册。与通过普通注册页面注册相比,这不会增加你的额外费用,并且帮助支持我的写作,因为我会获得一小笔佣金。

感谢阅读!

简化你的机器学习项目

原文:towardsdatascience.com/simplify-your-machine-learning-projects-ab171d19c9ef

图片由作者使用 Midjourney 创建。

为什么花费大量时间和精力在复杂模型上是个坏主意,以及应该采取什么替代措施

Hennie de HarderTowards Data Science Hennie de Harder

·发表于 Towards Data Science ·8 分钟阅读·2023 年 5 月 10 日

--

许多企业急于采用机器学习来改进他们的产品和服务。然而,许多数据科学家过于专注于创建完美的模型和使用最先进的技术。这样做的结果是,他们忘记了最重要的事情:交付一个功能齐全的最小可行产品(MVP)。在这篇文章中,我将讨论为什么在花费过多时间创建复杂模型之前,专注于获得一个有效的 MVP 更好。最后,我将提供三条创建 MVP 的建议。

当我大约六年前开始做数据科学家时,我对朴素贝叶斯、线性回归和统计学等话题并不感兴趣。也许是因为我的数学背景,我在学习期间已经掌握了这些话题。相反,我对神经网络、语言模型、计算机视觉和强化学习更感兴趣。这些话题吸引了我,我参加了课程以尽快学习它们。当我在公司处理实际业务问题时,我总是尝试一些复杂的模型和解决方案,通常涉及深度学习、从网络抓取的数据集和复杂的架构。不幸的是,我的代码杂乱且难以阅读。

我记得有一个项目我花了几个月的时间去做。我每周都与业务方开会,但我总是主讲,最终结果过于复杂且几乎没有被使用。那段时间的重要经验是不要使机器学习解决方案过于复杂,以及要少说话。通过这篇文章,我希望能防止你犯我曾经犯的错误,并解释应该采取什么措施。

不要专注于复杂的技术

为什么项目的重点不应放在使用复杂技术上?在我看来,有三个主要原因,我将在这里解释。

原因 1. 业务并不在意

第一个也是最重要的原因是业务并不在意!你的利益相关者对模型的技术细节并不感兴趣。无论你使用的是增强树还是神经网络,对他们而言都没有区别。他们想知道的是你的模型如何帮助他们实现业务目标。如果模型需要频繁重新训练,你可以用简单模型如逻辑回归而非神经网络来证明你的决策,因为逻辑回归的训练速度非常快。

机器学习模型的主要目标通常不是达到 100%的准确率。相反,机器学习模型是为了辅助业务流程。花费过多时间优化模型会延迟将可用产品推向市场的时间。更好的做法是创建一个 MVP,确保它符合业务需求,并投入生产。除了性能,还必须考虑解释性、计算速度、开发成本、鲁棒性和训练时间。这些因素同样重要,对业务人员来说可能和性能一样相关。

除了你自己,还有其他人关心复杂模型和最先进的方法。这些人通常是研究人员或数据科学同事。如果你过于与他们合作而不是与业务合作,你可能会认为建模是主要目标。为了克服这一点,尽量与业务人员更紧密地合作。在每次实施新功能后演示你的产品,并询问业务人员你的假设是否正确。看似微小的决策对业务人员可能非常重要。

原因 2. 复杂模型带来的价值不如一个可用的 MVP

你在模型上的时间花费越多,你在良好工程原则上的时间就越少,比如编写模块化代码、测试、架构、日志记录和监控。一开始以良好的方式设置这些内容可以节省大量后续时间。你可以轻松地在坚实的代码库中添加新功能。这比在 Jupyter Notebook 中拥有一个稍微表现更好的复杂模型却无法投入生产要有价值得多。简单模型的另一个好处是可解释性,这有助于说服利益相关者,因为他们可以看到预测的结果是有道理的。

尤其在初期,专注于创建一个有效的产品,并确保代码鲁棒和 CI/CD 管道的精心设计。这使得后续改进解决方案变得更加容易。如果业务没有感受到改进当前解决方案的紧迫性,你可以转到另一个项目。你并没有浪费时间去创建一个‘完美’的模型。

与此相关的是帕累托原则。这是一条规则,指出 80%的结果可以通过 20%的努力来实现(也称为 80/20 规则)。通常,创建一个比简单模型稍微好一点的复杂模型并不属于 80%的结果,而是一项困难且耗时的任务。复杂模型是那最后那 20%难以达到的部分,需要 80%的努力。在开始之前,让自己相信这值得付出。

帕累托原则。20%的努力带来 80%的结果。剩余的 20%结果需要 80%的努力。通过正确地优先排序,你可以将精力集中在用 20%的努力获得 80%的结果上。图片由作者提供。

原因 3. 复杂项目需要更多的维护

项目越复杂,维护所需的资源和时间就越多。这意味着你会花更多的时间修复 bug、优化模型、保持数据的最新状态,而花在添加新功能或改进产品上的时间会减少。另一方面,简单的项目需要的维护较少,这意味着你可以花更多时间迭代 MVP 并添加新功能以改进产品。

需要记住的一个重要想法是,最佳解决方案通常是最简单的符合要求的解决方案。这可以帮助你判断那种深度学习的前沿模型是否真的值得付出额外的工作!如果有两个模型性能相同,其中一个简单而另一个复杂,选择简单的那个。

我在公司工作的一个例子:我尝试用强化学习来解决一个调度问题。这个问题相当复杂,我们进展缓慢。由于无法展示良好的结果,业务方面有些恼火和失望。当我们将解决方案的方法切换为(传统的)数学优化时,进展就快多了!虽然不那么有趣,但我们赢得了业务的信任,并能够轻松地实现新功能和约束。

创建最简可行产品(MVP)

确定 MVP 应该涵盖哪些内容可能很困难。以下是三个提示,帮助你确定(和保护)范围。

提示 1. 在编写代码之前需要回答的问题

在编写代码之前,首先要回答以下问题。确保从业务专家那里获得答案,并且所有相关人员都参加了会议。

  • 为什么要构建这个产品?

    这个问题显示了产品的重要性。通过这个问题,你可以开始思考用于计算产品价值的指标。

  • 这个产品应该做什么?

    与其专注于“如何”,不如专注于“什么”。新产品的“如何”可以在数据科学团队中稍后解决。现在,你应该关注业务认为重要的内容以及他们认为输出应该是什么。别忘了询问他们认为完美的解决方案是什么样的。这可以帮助确定必需的功能和附加功能。对于 MVP,你需要专注于必需的功能。

  • 现在是怎么做的?

    商业专家可以告诉你所有相关信息!你可以利用这些信息确定产品的关键特性。

  • 应该联系谁作为商业专家? 除了工程师和数据科学家,至少应有一位商业专家参与。如果可能的话,还应该有一个最终用户,他知道产品最终应该做什么。你需要一个在开发过程中随时可以向他请教的人,以帮助你做出正确的决策。

这些问题是最基本的,但往往被忽视。直接跳到“如何做”并专注于 ML 技术,而不倾听利益相关者的需求,往往是项目失败的原因。确保在会议中不打断地倾听,并且不要标记产品功能为不可能实现。

提示 2. 创建一个包含功能和创意的活文档

包含所有创意和一个对所有人都能理解的功能图的文档,是与不同利益相关者沟通的好方法。它可以随着项目进展而更新,你可以跟踪功能并添加新的。图示不应过于技术性或架构性,而应以简单的英语(或任何其他语言)展示所有业务功能。下面是一个聊天机器人的简单示例图:

一个简单的聊天机器人的示例,其功能对所有人都能理解。图片来源:作者。

提示 3. 在 MVP 交付之前不要更改范围

如果你与富有创造力和发明精神的人合作,可能会发现很难坚持最初的计划。你展示你的产品,人们可能会回应:“在这里加一个计算机视觉功能不是很酷吗?”你会想,“当然了,这听起来很有趣!”但这正是问题开始的时刻,因为不久之后,你会开始构建一个复杂的系统,按时交付变得非常困难。更好的处理方式是表示感谢并记录下来。在 MVP 交付后,你可以重新审视这些想法,决定是否要实现。你还可以决定对实验阶段设定时间限制。这可以帮助你专注于核心问题,并尝试有限的策略。实验阶段之后,你可以继续推进最成功的解决方案。

结论

尽管创建一个复杂的机器学习模型似乎是交付高性能产品的最佳方式,但更好的做法是首先专注于交付一个可用的 MVP。一个可用的 MVP 可以向利益相关者展示你的产品价值,而且在简单模型上进行迭代更容易。请记住,企业更关心的是你的产品如何帮助他们实现目标,而不是你模型的技术细节。因此,专注于交付一个符合要求的可用 MVP,少担心创建完美的机器学习模型。

为了稍微补充一下这篇文章:当然也可能存在例外情况,例如当性能是主要的业务目标,或者当拥有相关问题经验的 ML 专家一致认为应该创建新解决方案时。在这种情况下,祝你好运!你可以尝试所有那些有趣的新东西。

相关

## 创建一个优秀数据科学产品的步骤

从问题到生产。

towardsdatascience.com ## 实用数据科学家的三个必备软技能

为什么技术知识不是你唯一需要关注的东西。

towardsdatascience.com ## 如何在几分钟内创建一个可以进行预测并解释其结果的网页应用

数据科学家的简短指南。

towardsdatascience.com

简化 Matplotlib 中子图的创建

原文:towardsdatascience.com/simplifying-subplots-creation-in-matplotlib-3f6efce356b9

将马赛克魔法融入你的图表

Parul PandeyTowards Data Science Parul Pandey

·发布于 Towards Data Science ·阅读时间 5 分钟·2023 年 5 月 23 日

--

图片由 charlesdeluvioUnsplash 提供

最近,我在一个项目中需要使用 Matplotlib 库创建子图。如果你曾经使用过 Matplotlib 库,你很可能也使用过它的子图功能。子图是同时生成多个图表的有效工具,这在比较结果或多个图表共享相同坐标轴时非常有用。然而,有时 Matplotlib 中的子图语法对我们很多人来说可能非常复杂,包括我自己。实现所需的子图布局似乎像是在进行试错游戏,这使我们从实际项目中分心。

藏在显眼处,真的!!

藏在显眼处 | 图片来源于 Pixabay

我知道 R 中的 patchwork 库 擅长处理子图的创建。然而,我惊讶地发现 Matplotlib 一直都有这个功能,这让我意识到应该彻底阅读文档。出于好奇,我决定深入了解这个功能,并通过博客文章将我的经历和见解分享给他人。

Matplotlib 中的子图

Matplotlib,一个广泛使用的绘图库,提供了两种创建子图的方法:**Figure.subplots()****Figure.subplot_mosaic()**。虽然这两种方法都能实现相同的目的,但使用后者方法具有一些固有的优势。让我们探讨它们的区别,强调Figure.subplot_mosaic()相较于Figure.subplots()提供的简便性和灵活性。

理解 Figure.subplots 方法

Matplotlib 中的 subplots 方法允许我们在网格状结构中创建子图。它接受指定子图网格中的行数和列数的参数,并返回一个Figure对象和一个表示各个子图的Axes对象数组。

让我们考虑一个示例,我们希望创建一个包含四个子图的 2X2 图形。我们可以使用[**Figure.subplots()**](http://Figure.subplots)方法来实现这个任务。

图(1):使用[**Figure.subplots**](https://matplotlib.org/stable/api/figure_api.html#matplotlib.figure.Figure.subplots)方法在 Matplotlib 中创建子图 | 作者图片

图(1)所示,尽管**Figure.subplots**方法通过指定行数和列数提供了一种直接创建子图的方法,但在以下情况下会显得不足:

  • 手动索引错误

用户必须手动指定每个子图的索引。这个过程可能容易出错,特别是在处理复杂的子图排列时,或者当索引错误导致子图位置错误或遗漏时。

  • 有限的布局灵活性

该方法依赖于固定的网格结构,导致创建不规则或自定义布局变得困难。如果所需的排列不符合指定的网格,可能会导致视觉不一致或图形扭曲。

  • 调整挑战

对子图排列进行更改或添加/删除子图可能会很麻烦。调整索引或调整网格大小需要仔细的手动调整,增加了引入错误的风险。

更好的替代方案 — Figure.subplot_mosaic 方法

[**Figure.subplot_mosaic**](https://matplotlib.org/stable/api/figure_api.html#matplotlib.figure.Figure.subplot_mosaic)方法是 Matplotlib 3.4.0 版本中引入的一种强大替代方案,旨在简化子图的创建和排列。它提供了一种更直观的方法,通过类似字典的结构定义子图,其中键表示子图标签,值定义它们在网格中的位置。使用这种方法,你可以通过指定每个子图的相对位置轻松创建复杂的子图布局。

让我们将之前的方法与Figure.subplot_mosaic()方法进行比较。

图 (2):使用 [**Figure.subplot_mosaic**](https://matplotlib.org/stable/api/figure_api.html#matplotlib.figure.Figure.subplot_mosaic) 方法在 Matplotlib 中创建子图 | 图片由作者提供

图(2)可以看出,虽然我们获得了相同的结果,但过程要简单得多且更直观。此外,我们还拥有以下优点:

1. 直观的语法

Figure.subplot_mosaic() 使用的字典-like 结构提供了一种清晰简洁的方式来指定子图的排列。这种方法消除了在 fig.subplots() 中手动计算和索引的需要。我们已经在上面的代码中看到如何指定子图的顺序。然而,我们可以通过 限制坐标轴标签 为单个字符来进一步简化过程,如下所示:

ax_dict

--------------------------------

{'a': <AxesSubplot: label='a'>,
 'b': <AxesSubplot: label='b'>,
 'd': <AxesSubplot: label='d'>,
 'c': <AxesSubplot: label='c'>}

2. 排列子图的灵活性:

Figure.subplot_mosaic() 使我们能够轻松定义复杂的子图布局,包括跨越多行或多列的不规则网格。这种灵活性在处理需要并排显示的不同可视化数据集时特别有用。我们来看两个示例,展示我们希望的子图排列:

  • 首先,我们希望有一个跨越多行或多列的坐标轴。更准确地说,我们寻找的是类似于以下示例的东西:

图片由作者提供

尽管使用 [Figure.subplots](https://matplotlib.org/stable/api/figure_api.html#matplotlib.figure.Figure.subplots) 方法排列子图可能让人感到困难,但 Figure.subplot_mosaic() 方法将过程简化为仅重新排列坐标轴标签。

使用 Figure.subplot_mosaic() 进行跨越多行/列的坐标轴 | 图片由作者提供

  • 另一个需要考虑的场景是并非将坐标轴填满整个图形,而是留下一些网格空间为空,如下所示:

图片由作者提供

使用 Figure.subplot_mosaic() 指定网格中的一些空间为空白

除了上述优点之外,我们还可以 控制马赛克创建管理子图创建以及分别调整 每个子图的参数 等功能。文档提供了详细的示例,是实验和探索这些功能的极好资源。

[## 复杂和语义化的图形组合(subplot_mosaic) - Matplotlib 3.7.1 文档

在图形中以非均匀网格布局坐标轴可能既繁琐又冗长。

matplotlib.org](https://matplotlib.org/stable/gallery/subplots_axes_and_figures/mosaic.html?source=post_page-----3f6efce356b9--------------------------------#string-short-hand)

结论

在这篇博客文章中,我们探讨了在 Matplotlib 中使用Figure.subplot_mosaic()相较于fig.subplots()的优势。前者在组织子图方面提供了更高的灵活性,其定位语法特别直观,从而使数据科学家和可视化爱好者能够轻松创建复杂和定制化的布局。这一功能让我想起了我最初接触 Python 的 f-strings。在 f-strings 推出之前,我主要使用%-formattingstr.format方法进行字符串格式化。虽然这些方法是有效的,但并不特别直观。然而,自从引入f-strings以来,我的编码体验显著改善——就像我使用Figure.subplot_mosaic()的经历一样👩‍💻😃

简化 Transformers:用你理解的词汇解析最前沿的 NLP —— 第四部分 —— 前馈层

原文:towardsdatascience.com/simplifying-transformers-state-of-the-art-nlp-using-words-you-understand-part-4-feed-foward-264bfee06d9?source=collection_archive---------4-----------------------#2023-10-04

传统的前馈层及其在 Transformers 中的作用。

Chen MargalitTowards Data Science Chen Margalit

·

关注 发表在 Towards Data Science ·9 分钟阅读·2023 年 10 月 4 日

--

由于这是一个持续更新的系列,如果你还没有开始,可能需要考虑从之前的部分开始:第一部分、第二部分 和 第三部分。

本节将介绍基本的前馈层,这是大多数深度学习架构中的一个基本元素。在讨论深度学习中的重要主题时,我们将强调它们在塑造 Transformers 架构中的重要作用。

原始论文中的图片

前馈线性层基本上是一堆神经元,每个神经元都与其他一堆神经元相连。请看下面的图片。A、b、c 和 d 是神经元。它们持有一些输入,即表示我们想要理解的数据的数字(像素、词嵌入等)。它们与神经元 1 相连。每个神经元有不同的连接强度。a-1 是 0.12,b-1 是 -0.3,等等。实际上,左列中的所有神经元都与右列中的所有神经元相连。这样做会使图像变得不清晰,因此我没有这样做,但这点很重要。正如 a-1 存在一样,我们也有 a-2、b-2、c-2、d-3 等。每两个神经元之间的连接都有不同的“连接强度”。

作者提供的图片

在这个架构中,有两点需要注意:

  1. 如前所述,每个节点(神经元)都与每个其他节点相连。四个神经元 a、b、c 和 d 都与每个其他神经元(1、2、3)相连。将这张图片视为一个指挥链。1、2、3 是指挥官。他们从士兵 a、b、c 和 d 那里收到报告。A 知道一些事情,但视野不广。1 知道得更多,因为它从 a、b、c 和 d 那里收到报告。2 和 3 也是指挥官,它们的情况也是如此。这些指挥官(1、2、3)也将报告传递给更高层的指挥官。那些更高层的指挥官会同时收到来自 a、b、c、d 和 1、2、3 的报告,因为下一层(每一列神经元是一个层)也是以完全相同的方式连接的。因此,第一个重要的理解点是 1 的视野比 a 广,而下一层的指挥官将比 1 的视野更广。当你有更多的点时,你可以建立更多有趣的连接。

2. 每个节点与下一层的每个其他节点有不同的连接强度。a-1 是 0.12,b-1 是-0.3。我这里给出的数字显然是虚构的,但它们的规模是合理的,并且它们是学习到的参数(例如,在训练过程中会发生变化)。把这些数字看作是 1 在 a、b 等上的权重。从指挥官 1 的角度来看,a 稍微可信一点。你不应该完全相信他说的每句话,但你可以相信他的话中的一部分。B 则非常不同。这个节点通常会减轻它收到输入的重视程度。就像一个随和的人。这是一个对发生事情的过度简化,但重要的是要注意:每个神经元持有一些输入,无论是原始输入还是处理过的输入,并以自己的处理方式传递它。

你知道“传话游戏”吗?你和 10 个人坐成一排,你对下一个人耳语一个词,比如“Pizza”。第 2 个人听到的类似“Pazza”,于是他们把“Pazza”传给第 3 个人。第 3 个人听到的是“Lassa”(毕竟是耳语),于是他传递“Lassa”。第 4 个人听到的是“Batata”,所以他传递“Batata”,依此类推。当你问第 10 个人你听到的是什么,他会说:“Shambala!”我们怎么从 Pizza 变成 Shambala 的?事有凑巧。这种游戏和神经网络的区别在于,每个人都会添加自己的有用处理。第 2 个人不会说“Pazza”,他会说:“Pazza 是意大利菜,非常棒”。第 3 个人会说:“Lassa 是意大利菜,在世界各地都很常见”,等等。每个人(层)都添加了一些希望有用的东西。

基本上就是这样发生的。每个神经元接收一个输入,处理它,然后传递出去。为了匹配完全连接层,我建议对“传话游戏”进行升级:从现在起,你与多个排进行游戏,每个人对其他行中的任何其他人耳语。第 2 位置及其后的人从许多人那里收到耳语,并需要理解他们给每个人多少“权重”(重要性)。这就是前馈层。

我们为什么使用这些层?因为它们允许我们进行有用的计算。把它想象成群体智慧的一个例子。你知道猜测牛的体重的故事吗?1906 年,在英国某地,有人把一头牛带到展览上。主持人让 787 个随机人猜它的体重。你会怎么说?这头牛重多少?

所有猜测的平均值是 1197 磅(542 千克)。这些是随机人的猜测。它们的误差有多大?1 磅,450 克。Steer 的重量是 1198。这个故事来源于 这里,我不知道细节是否正确,但回到我们的主题,你可以把线性层想象成类似的工作方式。你增加更多的参数,更多的计算(更多的猜测),你就能得到更好的结果。

图片由 Larry Costales 提供,来源于 Unsplash

让我们试着想象一个真实场景。我们给网络一个图像,并且我们想决定它是一个苹果还是一个橘子。这个架构基于 CNN 层,我不会深入讨论,因为它超出了本系列的范围,但基本上,它是一个能够识别图像中特定模式的计算层。每一层都能识别逐渐复杂的模式。例如,第一层几乎不能察觉任何东西,它只是传递原始像素,第二层识别垂直线,下一层得知有垂直线,并且从其他神经元那里知道垂直线非常接近。它进行 1+1 运算并认为:不错!这是一个角落。这就是从多个来源获取输入的好处。

我们做的计算越多,我们可以想象得到的结果就会更好。实际上,它并不完全是这样,但确实有 一些 真理。如果我进行更多的计算并咨询更多的人(神经元),我通常可以达到更好的结果。

激活函数 我们将堆叠另一个深度学习中基本且非常重要的概念的关键构建块,然后我们将连接这些点以理解它与 Transformers 的关系。

全连接层虽然很棒,但有一个大缺点。它们是线性层,只进行线性变换和线性计算。它们只能进行加法和乘法,但不能以“创造性”的方式转换输入。有时候,增加更多的能力并不足够,你需要从完全不同的角度思考问题。

如果我赚取 $10,每天工作 10 小时,并且我想更快地存到 $10k,我可以每周工作更多天数或每天工作额外的小时。但还有其他解决方案,不是吗?那么多的银行可以抢劫,其他人不需要他们的钱(我可以更好地使用),找到更高薪的工作等等。解决方案不总是更多的同样方法。

激活函数来救场。激活函数允许我们进行非线性变换。例如,将一组数字 [1, 4, -3, 5.6] 转换成概率。这正是 Softmax 激活函数所做的。它将这些数字转换为 [8.29268754e-03, 1.66563082e-01, 1.51885870e-04, 8.24992345e-01]。这五个数字的和为 1。虽然写法有些繁琐,但每个 e-03 表示第一个数字(8)后面跟着 3 个零(例如 0.00 然后是 82926。实际数字是 0.00829268754)。这个 Softmax 激活函数将整数转变为 0 到 1 之间的浮点数,同时保持它们之间的间距。当需要对这些值使用统计方法时,你可以想象这有多么有用。

还有其他类型的激活函数,其中最常用的一种是 ReLU(修正线性单元)。它是一个极其简单(但极其有用)的激活函数,将任何负数变为 0,将任何非负数保持不变。非常简单,非常有用。如果我将列表 [1, -3, 2] 传递给 ReLU,我得到 [1, 0, 2]。

在用 Softmax 吓到你之后,你可能会期待更复杂的东西,但正如有人曾告诉我的,“运气是有用的”。有了这个激活函数,我们很幸运。

我们需要这些激活函数的原因是非线性关系无法通过线性计算(全连接层)来表示。如果每小时我赚 $10,那么我得到的金额是线性的。如果每工作 5 小时,我在接下来的 5 小时中获得 10% 的增加,关系就不再是线性的了。我的工资不会是我工作的小时数 * 固定小时工资。我们之所以在更复杂的任务中,如计算机识别和文本生成中承担深度学习的负担,是因为我们寻找的关系是高度非线性的。“我爱”之后出现的词并不明显,也不是恒定的。

ReLU 的一个重要好处,也可能是它如此常用的原因,是它在计算大量数字时非常便宜。当你有少量神经元(假设是几万个)时,计算并不是特别关键。但当你使用数百亿个神经元,如大型语言模型所做的那样,计算上的效率差异可能会产生显著影响。

正则化 在解释它在 Transformers 中(非常简单)实现的方式之前,我们将介绍最后一个概念,即 dropout。Dropout 是一种正则化技术。正则化?正如我们所教的那样,学习复杂的逻辑并不总是有用,有时我们可以记住我们所见到的东西,或记住接近它的东西。第二次世界大战是什么时候?嗯……它受到第一次世界大战、经济危机、愤怒的人们等的影响……大约在 1917 年,所以我们说 1928 年。也许记住实际日期会更好。

正如你想象的,这对机器学习来说不好。如果我们需要回答我们已经知道的答案,我们就不需要这个复杂的领域了。我们需要一个聪明的算法,因为我们不能记住所有的东西。我们需要它对实时推断进行考虑,我们需要它有点像思考。用于让算法学习而不是记忆的技术的总称是正则化。在这些正则化技术中,一个常用的是 dropout。

Dropout 什么是 dropout?一种相当简单的(幸运的是)技术。记住我们说全连接层是完全连接的?好吧,dropout 打破了这个逻辑。Dropout 技术意味着将“连接强度”设置为 0,这意味着它不会产生任何效果。焊接的“a”对指挥官 1 完全无用,因为其输入被设置为 0。没有答案,没有正面,没有负面。在每一层中我们添加 dropout 时,我们随机选择一定数量的神经元(由开发者配置),并将它们与其他神经元的连接设置为 0。每次指挥官都被迫忽略不同的士兵,因此无法记住他们,因为下次可能不会遇到他们。

回到 Transformers! 我们现在拥有了理解 Feed Forward 层中具体发生什么的所有构建块。这将变得非常简单。

原始论文中的图像

这个层简单地做三件事:

1. 基于位置的线性计算——文本中的每个位置(以向量表示)都通过一个线性层。

2. 在线性计算的输出上进行 ReLU 计算。

3. 在 ReLU 输出上进行另一个线性计算。

4. 最后,我们将添加到第 3 层的输出中。

就是这样。如果你对深度学习有经验,这部分可能对你来说非常简单。如果没有,你可能会遇到一些困难,但你已经理解了深度学习中的一个极其重要的移动部分。

在接下来的部分,我们将讨论解码器!你可以在这里找到它。

简化变换器:使用你理解的词汇的最先进的 NLP — 第五部分 — 解码器与最终输出

原文:towardsdatascience.com/simplifying-transformers-state-of-the-art-nlp-using-words-you-understand-part-5-decoder-and-cd2810c6ad40?source=collection_archive---------7-----------------------#2023-10-05

Transformer 系列的最后一部分

Chen MargalitTowards Data Science Chen Margalit

·

关注 发表在 Towards Data Science · 6 分钟阅读 · 2023 年 10 月 5 日

--

由于这是一个持续更新的系列,如果你还没有开始,可能需要考虑从之前的部分开始:第一部分,第二部分,第三部分,以及第四部分。

来自原始论文的图片。

本系列的第四部分将 heavily 基于 第二部分第三部分第四部分,因此如果你还没有阅读这些内容,并且不确定上面图像中架构的左侧如何工作,我建议你从这些部分开始。你发现的不清楚或未解释的术语表可能在之前的部分中已经讲解过。

解码器

Transformers 是一种编码器-解码器架构。输入进入后,被编码(转换)成数学表示(以某种形式的数字,通常是向量)。然后,它被传送到另一个处理单元,称为解码器,在那里它从数字转换回请求的输出。在语言模型的情况下,这将是一个单词。

解码器的第一个重要任务是从目标序列(我们想要给出的答案)中创建值矩阵。它接收整个目标序列,将其转换为嵌入(Embeddings),并以与编码器相同的方式添加位置编码。然后,它会将嵌入传递给一个掩码多头注意力层(Masked Multi-Head Attention layer),该层将创建值矩阵。这个矩阵将帮助模型决定用户提示和预期目标如何配合。

在解释 Transformers 如何工作的过程中存在一种悖论,因为你需要了解这些组成部分才能理解最终结果的发生方式,但你需要理解最终结果的发生方式才能理解这些组成部分。考虑到这一悖论,我们将短暂展望未来,并解释 Transformer 训练的两个方面:

首先,在编码器中,用户提示被转换为嵌入,然后我们添加位置编码。编码器组(原始论文中为 6 个)处理数据并生成文本的数值表示。接下来,在解码器中,真实值(我们希望模型回应的内容)会被前置一些标记,表示这是句子的第一个标记。类似于(句子开始),但它可以是模型训练的任何其他符号。这个输入被转换为嵌入,并添加位置编码。解码器组(最初也是 6 个)将这些向量与编码器的输出一起生成新的词表示。该词表示被转换为概率分布,从模型的整个数据集中选择概率最高的词。最后,根据模型选择的词与模型的真实值之间的差距计算损失函数。该损失用于生成对反向传播(计算权重如何根据各自对总体误差的贡献进行调整的算法)重要的梯度。

现在我们了解了整体流程,让我们看看一个小但重要的细节:一种叫做教师强迫的技术。

教师强迫 想象一个数学测试,你有 3 个作业:

1. 取数字 4,加上 5,并保持得分。

2. 从练习 1 中得到结果,乘以 2,并保持得分。

3. 从练习 2 中得到结果,除以 2。

你将根据每个练习的结果单独评分,无论对错。你看到了问题吗?如果你在练习 1 中犯了错误,那么练习 2 和 3 也会有错误。教师强迫技术可以处理这个问题。由于语言模型也是基于序列的(例如,预测第二个词依赖于第一个词),为了在未来做出正确的预测,你必须在之前做出正确的预测。在我们的例子中,教师强迫是给你练习 2 中的正确答案,然后在练习 3 中也给你正确答案,这样你实际测试的是乘法/除法而不是加法。你仍然会在每个练习上单独评分,你只是不必因为练习 1 中的错误而在练习 2 中受苦。

具体来说,在我们做的事情中,教师强迫帮助我们更快地训练。我们在每一步都给模型真实标签,确保第 x 个词的预测。例如,第 4 个词将基于 3 个真实(正确)标签,而不是基于可能是错误的预测标签,这样模型就可以正确地继续进行,而不是因为之前的错误。

好的,我们现在掌握了重要部分。回到正题。在转换为嵌入并添加位置编码后,我们将输入通过掩蔽多头注意力

掩蔽多头注意力 我们之前已经了解了注意力层的功能及其存在的原因。掩蔽注意力层与之前的注意力层非常相似,用于相同的原因,但有一个重要的区别。由于解码器处理的是完整的期望输出,因此模型在构建注意力分数时可以使用整个序列。由于这些注意力分数是关键部分,我们必须确保它们的准确性。假设我需要生成句子:“惊人的技术,我喜欢它”。在当前阶段,我们试图预测单词 technology。单词 technology 在不同上下文中可能有不同的表示,但在推断时(即在实际使用模型时),模型不会拥有整个句子,只会有之前的单词。因此,我们需要确保训练时也仅能访问之前预测的单词,而不能访问未来的单词。为此,我们通过掩蔽(隐藏)未来单词来实现这一点。

正如你可能猜到的,由于机器学习深深根植于数学中,我们不会简单地删除单词,而是有一种更高级的做法。我们将数学转换为被动攻击模式。我们忽略它。具体来说,在计算注意力分数时,我们添加负无穷大(一个非常小的负数,例如 -86131676513135468),这会使下一阶段的 softmax 将这些负数转换为 0。这一重要技术确保模型在没有访问下一个单词时无法使用它。

图片由作者提供

计算掩蔽注意力分数后,输入通过一个添加和归一化层,方式和原因与我们之前解释的相同。它还接收来自注意力计算前一层的跳跃连接。之后,值矩阵将继续进入下一阶段。我们现在从编码器中获取 Q(查询)和 K(键)矩阵,这些矩阵表示用户提示和对该查询的可能建议。解码器带来了自己的值矩阵,以决定关注编码器输入的哪个部分。结合这三个矩阵(两个来自编码器,一个来自解码器),我们计算“常规”注意力分数。

接下来,我们有另一个前馈 + 添加和归一化层,它接收另一个跳跃连接,就像我们之前解释过的那样,然后……我们完成了解码器!

我们现在来到了最后一步。堆叠解码器中的最后(第 6 个)解码器将其输出通过一个线性层。这个线性层使我们能够生成任意数量的数值表示。在语言模型中,我们希望表示的数量与词汇表的大小匹配。如果模型的整个词汇表(它所见过的所有唯一单词)有 1000 个,那么我们希望有 1000 个数字来表示词汇表中的每一个可能的单词。我们对每个位置的每个单词都这样做。如果最终输出包含 10 个单词,我们将为这 10 个单词中的每一个计算 1000 个数字。然后,我们将其传递给 Softmax 层,Softmax 层为每个单词提供一个概率,概率最高的单词就是我们将使用的那个。Softmax 返回一个索引,例如 3。模型输出词汇表中的索引 3,并得到新的预测单词。如果我们的词汇表是[‘a’, ‘man’, ‘works’, ‘nice’, ‘go’],那么选择的单词将是‘nice’。

就是这样……完成了!你现在对变压器架构有了全面的了解。这是一次相当长的旅程,你勇敢地完成了它。希望这个系列能帮助你理解这个复杂但又易于理解且重要的主题。

祝你度过愉快的时光。

带重启策略的模拟退火

原文:towardsdatascience.com/simulated-annealing-with-restart-a19a53d914c8

对经典的模拟退火优化算法及其在旅行商问题中的应用进行变种处理

Egor HowellTowards Data Science Egor Howell

·发表于 Towards Data Science ·5 分钟阅读·2023 年 2 月 13 日

--

图片由 Jonathan Greenaway 提供,来源于 Unsplash

背景

在我之前的文章中,我们讨论了如何使用 模拟退火元启发式 优化算法来解决 旅行商问题 (TSP)。你可以在这里查看那篇文章:

## 如何使用模拟退火解决旅行商问题

使用模拟退火优化算法来获得旅行商问题的最优解

towardsdatascience.com

TSP 是一个著名的 组合优化运筹学 问题。它的目标是找到销售员访问 n 个城市的最短路径,要求每个城市仅访问一次,并以原始/起始城市结束。

问题听起来很简单,但随着城市数量的增加,可能的路径数量会出现组合爆炸。例如,4个城市的可能路径数量为36个城市的可能路径数量为60,而20个城市的可能路径数量则是巨大的60,822,550,200,000,000! 实际上,对于20个城市,尝试每一条路径的时间大约是~2000,这需要暴力破解

TSP 的可能解的数量按(n-1)!/2的比例增长,其中n是城市的数量。

这是启发式和元启发式方法,如模拟退火,发挥作用的地方,它们在合理的计算时间内提供足够好的解决方案。

在本文中,我们将回顾模拟退火的过程,并解释其原始算法的一个小变化,这种变化可能会改善性能。然后,我们将实现这种变化以在 Python 中解决 TSP。

带重启的模拟退火

概述

模拟退火是一种随机的(随机)全局搜索优化算法。它的名字来源于退火过程,这一过程在冶金学中通过温度变化改变金属的物理性质。

模拟退火利用温度的概念来计算转移到较差解的概率,以更好地探索状态空间,从而更有可能达到全局最优。这是为了避免陷入局部最优,这是贪婪算法,如最近邻,经常发生的情况。

理论

模拟退火的一般数学框架是:

由作者使用 LaTeX 生成的方程。

其中:

这个转移概率源自于玻尔兹曼分布热力学

由作者使用 LaTeX 生成的方程。

这里x是当前解,x'是新解,Δy是两个解之间的性能差异,P(x → x')是过渡到新解的概率,T是此时过程的温度。

如果新解优于当前解,则我们总是过渡到这个新解,因为上述公式中的概率为1。此外,当新解较差但温度很高时,我们很可能会过渡到新解,尽管它的表现较差。然而,随着温度的下降,我们过渡到更差的新解的可能性会减小。因此,过程开始收敛,并且正在利用搜索空间。

温度通常以几何方式降温:

由作者在 LaTeX 中生成的方程。

其中γ是取值范围为0 ≤ γ ≤ 1的调用因子,t是迭代次数。

另一个常见的问题是如何计算初始温度?这是一个复杂的主题,好的研究这里可以帮助回答这个问题。一般来说,这主要是一个试错过程。

变体

这篇研究论文的作者那里获得灵感,我们可以稍微修改这个原始实现,以帮助更广泛地探索搜索空间。这是通过每次找到新的最佳解时重置温度到初始温度来完成的。这一过程可以描述为重启。这实际上是我们执行多个模拟退火过程,并选择找到的最佳解。

修改版 TSP 算法

实现修改版模拟退火算法以解决 TSP 问题的步骤:

  • 获取一个初始解,这可以是任何有效的路径。

  • 随机选择两个城市并交换它们以生成新路径。

  • 使用模拟退火来计算接受这个新解的概率。

  • 持续进行这一过程,迭代设置次数,并在每次迭代中降低温度。

  • 如果新解是迄今为止我们见过的最佳解,则将温度重置为初始温度。

  • 始终记录最优的总体解。

Python 实现

我们现在将实现这个新修改版的模拟退火算法来解决 TSP 问题。让我们从生成一些城市并绘制初始解开始:

作者的 GitHub Gist。

由作者在 Python 中生成的图表。

现在让我们为修改版的模拟退火算法构建一个 Python 类,以解决 TSP 问题:

我不是最好的编码者,因此以下代码片段可能不是最优或最佳实践实现!

作者的 GitHub Gist。

运行算法并记录输出:

作者的 GitHub Gist。

由作者在 Python 中生成的图表。

由作者在 Python 中生成的图表。

从上述图表中,我们可以看到温度在过程开始时频繁重启,但随着迭代次数的增加逐渐减少。找到的最佳路线看起来合理,但仍有一些路径交叉,可能意味着我们尚未找到全局最优解。但这正是元启发式算法的要点,解决方案是足够好的

总结与进一步思考

在本文中,我们解释了模拟退火算法的修改版本。在此版本中,每次找到新的最佳解时,我们将温度重置为初始温度,这个过程称为重启。这种方法为我们在 Python 中实现的旅行商问题提供了一个良好的解决方案。

本文中使用的完整代码可以在我的 GitHub 上找到:

## Medium-Articles/sa_with_restart.py at main · egorhowell/Medium-Articles

我在 Medium 博客/文章中使用的代码。通过创建一个账户来贡献开发 egorhowell/Medium-Articles:

github.com

另一件事!

我有一个免费的新闻简报,Dishing the Data,每周分享成为更好数据科学家的技巧。没有“虚假”或“诱饵”,只有来自实践中的数据科学家的纯粹可操作的见解。

## Dishing The Data | Egor Howell | Substack

如何成为更好的数据科学家。点击阅读 Dishing The Data,作者 Egor Howell 的 Substack 出版物...

newsletter.egorhowell.com

联系我!

参考文献与进一步阅读

模拟主题公园:用 R 理解队列时间

原文:towardsdatascience.com/simulating-a-theme-park-understanding-queue-times-with-r-100b12d97cd3?source=collection_archive---------4-----------------------#2023-08-21

模拟主题公园以理解队列时间,并学习如何用 R 优化业务流程。

Joseph George LewisTowards Data Science Joseph George Lewis

·

关注 发表在 Towards Data Science ·11 分钟阅读·2023 年 8 月 21 日

--

图片由 Thomas Kelley 提供,来源于 Unsplash

长时间排队总是让人不快,特别是当你在等待飞向太空或航行大堡礁时。随着暑假继续,我相信几乎每个人都会排队等候,希望你能幸运地直接前往魔法王国。也许你在阅读这个博客时正处于某个队列中!

一些代码用于支持示例,但完整的代码可以在我的 GitHub 上找到,链接在文章末尾。该项目使用 R 语言和simmer包进行离散事件模拟。请享受!

概念回顾——离散事件模拟

那么在我的笔记本电脑上模拟一个主题公园需要什么?它会像《无敌破坏王》中的游戏中心站那样吗?

恐怕不会……用 R 语言编写的代码将使用离散事件模拟(DES),它实际上展示了一个过程随时间的发展情况。DES 的主要用途是优化过程,这也是它在运筹学中常被使用的原因。模拟允许决策者在多次迭代后查看典型过程,并观察如何改进。例如,增加额外的机器是否会减少生产产品的瓶颈?

本文将离散事件模拟应用于一个假设的迪士尼世界。这个版本的公园会简单一些,并且有一些额外的假设以简化建模。

由于这个博客更侧重于应用,因此将有很多代码示例,而非理论,但对所涉及组件的概念回顾应该有助于我们跟上进度。

组件

离散事件模拟需要一些基本组件。每个组件都是我们将使用simmer创建的,但在编码开始之前,让我们在这里回顾一下:

  • 轨迹:轨迹是来宾在模拟中将要走的路径。例如,当来宾到达时,他们可能会排队等候游乐设施,乘坐游乐设施后离开。这将是他们的轨迹。

  • 资源:资源是在轨迹中使用的事物。在上面的示例中,游乐设施就是一个资源。

  • 生成器:生成器是我们用来生成来宾的工具,生成器通常会在模拟过程中生成许多来宾。生成器还需要一个轨迹来知道生成的来宾应该去哪里。

  • 环境:环境是整个模拟运行的地方,在我们的案例中就是主题公园本身。它封装了所有资源、生成器和轨迹。使用simmer,环境还有一个额外的优点,它将跟踪和报告我们的资源使用情况,使得模拟分析更为简单。

在完成审查后,来动手实践吧!下面的模拟是逐步构建的,复杂性逐渐增加。为了说明问题,代码在最后的模拟中才进行重构,利用了一些额外的 R 函数和特性来清理模拟。

模拟一:遇见莉萝和史迪奇

第一个模拟将作为引言。在这个模拟中,客人到达并开始排队等待见到莉萝和史迪奇。在编码之前,回顾一下组件将会是什么是很有帮助的。

  • 轨迹:到达公园,排队,使用角色资源(莉萝或史迪奇),释放角色资源并离开公园。

  • 资源:客人可以遇到的角色。可以是莉萝或史迪奇,因此,这个资源的容量设置为 2,这意味着它可以被两个客人同时占用。

  • 生成器:将定期生成到达公园的客人。

  • 环境:公园本身。

现在让我们考虑一下每个元素的代码。R 语言允许我们利用其他包,如 dplyr,来组织我们编写的代码。这使得轨迹构建变得更加简单:

首先,客人占用一个角色,锁定该角色以防其他客人使用。然后,他们在这个资源上超时。超时仅意味着他们不能做其他任何事情。在大多数模拟中,超时时间长度将基于真实过程。然而,在这个虚构的场景中,超时时间长度是从正态分布中取样的。最后,客人释放资源,其他排队的客人可以再次使用它。

轨迹是最复杂的概念。环境、资源和生成器都在一个代码块中一起定义:

这一组组件仍然非常基础。重要的注意事项是,资源的容量为 2,符合规格。客人根据指数函数到达,并获得上述定义的轨迹。客人的到达通常也会模拟真实事件,但为了示例的目的,指数函数就足够了!

最后的步骤是运行模拟并进行一些分析。运行模拟非常简单,只需输入 run 并提供要运行的时间步数。在这种情况下,公园将开放 15 小时。对于分析,我们将使用 simmer.plot,这是 simmer 库的一个扩展,用于构建一些简单的可视化图表:

下面的第一个图显示了资源利用率。在此模拟中,我们的利用率仅为 20%。这个百分比表示资源在模拟时间中被使用的比例。

图片由作者提供

20%的利用率相当低,理想情况下,我们希望有更高的利用率,因为这意味着更多快乐的客人!下一个可视化图表示等待时间,没有客人等待超过 5 个时间步。这是有道理的,因为我们的资源超时时间基于从正常分布中取样的 5 步左右。然而,平均等待时间要低得多,由蓝线表示。

图片由作者提供

模拟二:使用快速通行证遇见莉萝和史迪奇

现在让我们逐步增加一些复杂性。模拟将保持与上述相同。然而,我们将添加优先级顾客。这些顾客将模拟迪士尼乐园的“快速通行”系统。实际上,我们只需生成一个具有优先级的新顾客:

请注意,这位顾客的优先级为 1,而默认顾客的优先级为 0。较高的优先级值意味着这位顾客将跳过队伍,并在所有低优先级顾客之前乘坐。另一个补充是这些顾客每 50 个时间步骤到达一次。在现实中,他们会为某个均匀间隔的时间预定一个通行证。

由作者提供的图片

请注意资源利用率已经大幅增加!这是因为现在有一个保证到达的顾客,他们会在固定间隔内优先于所有其他顾客。相反,等待时间也增加了,这也是合乎逻辑的,因为不仅有更多的客人,还有更多的快速通行证顾客,这增加了原始顾客组的等待时间。

模拟三:带有违约的会面莉萝与史迪奇

排队等候可能会令人沮丧,人们经常选择放弃。这可以通过违约的概念用数学方法来解释。如果一个顾客在排队等候很长时间,我们可以模拟他们简单放弃的概率。这是通过在轨迹中添加来实现的:

顾客违约所需的时间再次从正态分布中抽样,这次的值约为 2(通常情况下,这将被估计为更高,但在这个例子中,我们有非常没有耐心的顾客来帮助理解违约)。重要的是一旦顾客获得了资源,我们必须中止违约。顾客不会在等待期结束时立即离开!

由作者提供的图片

这种违约行为最终减少了资源的利用率,因为更多的客人离开了,因此没有使用这些资源。然而,它也减少了等待时间,因为许多客人会放弃,所以他们的等待时间不被计入。客人离开的另一个效果是其他客人经历了更短的排队,这也减少了等待时间。在剩下的模拟中,我们将增加客人违约所需的时间。

模拟四:使用批处理的太空山一车游乐设施

现在让我们进入有趣的部分:游乐设施。主题公园中的游乐设施将利用simmer中的批处理功能。例如,批处理允许我们将客人分组到一个太空山车厢中:

参数n表示可以将多少客人分批到这个组中,所以在这种情况下,每个车厢可以容纳 6 个客人。timeout参数表示一个批次在转到轨迹的下一部分之前应该等待多长时间才能填满。所以,批次可以达到其 6 人的容量,或者在该批次占用车厢资源之前,可能会经过 10 个时间步。

图片由作者提供

由于这个例子增加了复杂性,并且需要将客人分批处理,或者至少在没有完整批次之前等待适当的时间,资源利用率下降。这些条件还导致了等待时间的问题,这与角色等待时间相比有很大的变化。一个客人可能会立即到达并离开,或者需要等待整整 10 个时间步才能等到一个新批次离开。为了解决这个问题,在实际应用中会增加车厢的大小或数量。

模拟五:具有分支的太空山或飞溅山

让我们让事情变得更加现实;没有任何一个游乐园只靠一个游乐设施是不够的。通过添加像飞溅山这样的新游乐设施,我们可以模拟客人的偏好,并使用分支轨迹来查看他们将选择哪个游乐设施。这有点复杂,在完整脚本中,可以看到轨迹的游乐设施元素已被添加到各自的函数中。为了简化,这里展示了一个更简洁的版本,仅用来演示分支逻辑:

在上述情况下,客人根据掷硬币来选择,因此客人选择太空山或飞溅山的可能性相等。这可以从实际偏好中建模,在更实际的应用中,模拟也会有所变化:

在这个例子中,有两个不同的资源。这与上面提到的《莉萝与史迪奇》的例子不同,因为莉萝和史迪奇都被建模为具有两个容量的“角色”资源。然而,这里飞溅山与太空山相比,是一个独立的资源。

图片由作者提供

这两个游乐设施的利用率差异很大,主要是由于两条轨迹之间的游乐时间差异。乘坐太空山所需的时间更长,客人更有可能取消。增加的复杂性也导致了更为多样的等待时间,但从优化的角度来看,我们成功地通过减少新开设的更快游乐设施的平均等待时间来改善了客人的体验。

模拟六:通过时间表和排队偏好来开放公园

最后的模拟是将我们创建的一切整合起来,并为公园添加一个时间表。到目前为止,我们一直使用固定值来定义资源容量。还有一种替代方法:使用时间表。时间表允许我们根据时间间隔来控制资源的容量。

公园时间表将由一个门控控制,门在 50 时间步后打开,并在模拟结束前 50 时间步关闭。下面的代码使用了很多重构的函数,这些函数再次链接在下面的 GitHub 仓库中。门控时间表是这里最重要的审查内容:

这一次我们的轨迹结合了游乐设施和角色。添加了一些代码,让客人在回合轮换政策中选择他们访问的角色。现在也生成了更多的客人,因为事情变得更多。

作者提供的图片

这个模拟看到所有资源的利用率更为现实。顾客有更多的事情可做,因此每个资源的使用时间减少。还有一个额外的规定是顾客必须等待才能进入公园。注意,门的存在导致了大量的利用率减少,但实际使用率很低,因为顾客会立即开启和关闭它。等待时间在这里也达到了峰值,客人们在等待公园开门以及在公园内等候游乐设施时,但这是一个更为现实的模拟。第一批客人一开门就被释放到公园,选择他们想要乘坐的设施和遇见的人。

结论和最终想法

离散事件模拟给分析师和决策者提供了探索业务流程的机会。在上面的示例中,复杂性可以从一个客人到达、进行活动并离开,扩展到考虑偏好和人类行为(如失信行为)。

通过这个过程做出的一些决策,比如增加更多的游乐设施,展示了如何利用 DES 来优化流程。另一方面,增加更复杂的行为,如失信行为,展示了如果资源没有正确设置,可能会被低效利用。

希望你喜欢这篇文章!如果喜欢,请考虑关注我的页面,以获取更多数据科学内容。

资源

代码:

## GitHub - josephlewisjgl/DESR: 存放博客文章《模拟主题公园…》的代码库

存放博客文章《模拟主题公园:用 R 理解排队时间》的代码库 - GitHub …

github.com

Simmer 文档:

## 6.1 甜甜圈店 | 模拟与建模以理解变化

这些是人类学部提供的《模拟与建模以理解变化》模块的讲义……

bookdown.org

使用 Python 模拟系外行星发现

原文:towardsdatascience.com/simulating-exoplanet-discoveries-with-python-a2d460a4889b

快速成功数据科学

模型的强大威力!

Lee VaughanTowards Data Science Lee Vaughan

·发表于 Towards Data Science ·15 分钟阅读·2023 年 12 月 18 日

--

2012 年 6 月金星凌日(Evan Clark via 现实世界的 Python

在我飞往爱达荷州拍摄2017 年大美洲日全食之前,我做了充足的准备。全食事件,即月球完全遮住太阳的时刻,仅持续了 2 分钟 10 秒。这没有时间进行实验、测试或现场解决问题。

为了成功捕捉到半影、全影、太阳耀斑和钻石环效应的影像,我必须准确了解需要带什么设备、使用什么相机设置以及这些事件何时发生。在互联网的帮助下,我能够理清这些细节并为我的位置准备了一个精确的时间表。

2017 年全日食中的钻石环效应(作者提供)

同样,计算机模拟也帮助科学家为观察自然世界做好准备。它们帮助科学家理解期望什么、何时期望以及如何校准仪器和设计实验。

本文的目标是展示使用系外行星凌日事件实际应用。系外行星是指绕我们太阳系之外的恒星运行的天体。

天文学家们通过一种叫做凌日光度测量的技术发现了数千颗系外行星,该技术记录了当系外行星在恒星与地球之间经过时,恒星光线的微弱暗淡。我们可以使用凌日模拟器来理解行星大小以及太阳黑子、小行星带、卫星甚至外星巨构等因素的影响。

要构建模拟器,我们将使用 OpenCV,这是 Python 处理图像和视频的首选开源库,以及 Tkinter,这是 Python 内置的图形用户界面(GUI)工具。我们将使用后者来制作仪表板。这里是一个预览:

运行中的系外行星过境仪表板(由作者提供)

过境光度测量

在天文学中,过境发生在一个相对较小的天体直接穿过一个较大天体的盘面和观察者之间。当小天体横穿大天体的表面时,大天体会稍微变暗。最著名的过境是水星和金星经过我们的太阳。

借助今天的技术,天文学家可以检测到远离的恒星在过境事件中的微弱暗淡。这种技术称为过境光度测量,它输出恒星亮度随时间变化的图表。

过境光度测量技术用于探测系外行星(来自 Real-world Python

在前面的图中,光曲线图上的蓝色点表示测量到的恒星发出的光。当一个行星没有位于恒星上方(图中的位置 1)时,测量到的亮度达到最大值。(我们将忽略行星在其相位过程中反射的光,这会略微增加恒星的表观亮度)。

当行星的前缘移动到盘面上(位置 2),发出的光逐渐变暗,形成光曲线中的一个斜坡。当整个行星在盘面上可见时(位置 3),光曲线变平,保持平坦,直到行星开始从盘面的远端离开。这会形成另一个斜坡(位置 4),直到行星完全离开盘面(位置 5)。此时,光曲线在其最大值处变平,因为恒星不再被遮挡。

因为过境期间阻挡的光量与行星圆盘的大小成正比,所以你可以使用以下公式计算行星的半径:

其中 Rp 是行星的半径,Rs 是恒星的半径。天文学家通过恒星的距离、亮度和颜色(与其温度相关)来确定恒星的半径。深度 指的是过境期间亮度的总变化,如下图所示。

“深度”是光曲线中观察到的亮度总变化(来自 Real-world Python

系外行星越大,光曲线的深度就越大。

当然,这些计算假设了整个系外行星,而不仅仅是部分,都经过了星体的表面。如果系外行星只是从我们的视角上擦过星体的顶部或底部,结果将是一个“未完成的”和“V 形的”光曲线,实用性有限。

部分过境(红箭头)产生了一个“V 形”的光曲线(作者提供)

通过测量光曲线观察系外行星不仅仅是专业天文学家的工作。根据 Sky & Telescope 杂志,即使是小型 6" 望远镜也可以记录有用的光曲线。NASA 甚至启动了一个 公民科学家计划,让后院天文学家帮助专业人员寻找类木星大小的系外行星。

代码

以下 Python 程序使用 OpenCV 生成系外行星过境星体的可视化模拟,使用 Matplotlib 绘制结果光曲线,并在仪表板中同时显示这两者。

为了生成光曲线,我们需要能够测量亮度的变化。我们可以通过对像素进行数学操作来使用 OpenCV 实现这一点。

OpenCV 最好通过 pip 安装,因此如果你是 Anaconda 用户,你应该将其作为你添加到 conda 环境中的最后一个包。以下是安装命令:

pip install opencv-python

你还需要一张我们太阳的图像,你可以从这个 GitHub 仓库下载。只需点击链接,然后按图像右上角的下载图标即可。将其保存在与你的 Python 脚本相同的文件夹中。

点击红色圆圈中的图标从 GitHub 下载图像 (NSO/AURA/NSF)

导入库和分配常量

以下代码导入了 tkinter,用于创建仪表板;matplotlib,用于绘制光曲线;matplotlib 的 backend_tkagg 模块和 FigureCanvasTkAgg 类,用于在 tkinter 和 matplotlib 之间集成;以及 OpenCV(cv2),用于显示星体图像和计算光曲线的相对亮度。

import tkinter as tk
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import cv2 as cv

IMG_HT, IMG_WIDTH = 400, 500
BLACK_IMG = cv.imread('limb_darkening.png', cv.IMREAD_GRAYSCALE)
EXO_RADIUS = 7
EXO_DX = 3
EXO_START_X = 40
EXO_START_Y = 230
NUM_FRAMES = 145

BLACK_IMG 变量包含了我们太阳的图像,将作为外星星体的替代图像。注意,我们将其加载为灰度图像,以便可以直接从像素中测量强度(亮度)。

使用图像可以捕捉到光球亮度的现实变化。光球是恒星发光和辐射热量的外层。由于光球的温度随着距离恒星中心的增加而降低,恒星光盘的边缘比中心部分更凉,因此看起来更暗。这种效应被称为 边缘变暗,它对光曲线有明显影响。

EXO_ 开头的常量代表与外星行星相关的参数,包括其半径、速度(DX)和起始坐标(以像素为单位)。NUM_FRAMES 常量决定了模拟的运行时长。

定义一个创建仪表板的函数

我们的仪表板将包括图像和图表。OpenCV 处理图像,matplotlib 处理图表,Tkinter 将它们组合成一个单一的显示。Tkinter 使用一个名为 canvas 的小部件来实现这一点,该小部件提供了一个绘图区域,用于显示图像、绘制形状和创建交互元素。

def create_dashboard(root):
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6))
    canvas = FigureCanvasTkAgg(fig, master=root)
    canvas_widget = canvas.get_tk_widget()
    canvas_widget.pack(side=tk.TOP, fill=tk.BOTH, expand=1)

    intensity_samples = []
    exo_start_x = EXO_START_X

    for _ in range(NUM_FRAMES):
        temp_img = BLACK_IMG.copy()
        cv.circle(temp_img, (exo_start_x, EXO_START_Y), EXO_RADIUS, 0, -1)
        intensity = temp_img.mean()
        intensity_samples.append(intensity)
        relative_brightness = calc_rel_brightness(intensity_samples)

        update_image(ax1, temp_img)
        update_light_curve(ax2, relative_brightness, 'k')

        canvas.draw()
        root.update()
        root.after(3)

        exo_start_x += EXO_DX

代码开始时使用 matplotlib 创建两个垂直堆叠的子图。接着,我们使用 FigureCanvasTkAgg 类创建画布,将其嵌入 Tkinter 窗口中。我们将在此画布上显示我们的图形(fig)。作为 backend_tkagg 模块的一部分,FigureCanvasTkAgg 促进了 Matplotlib 图形在 Tkinter 应用程序中的正确渲染。

此类的 master 参数指定了将包含画布的 主小部件(或窗口)。在这种情况下,它是我们将在程序结束时创建的 Tkinter 根窗口root)(使用 root = tk.Tk())。

接下来,我们创建一个空列表来保存强度(相对亮度)测量值,并将 x 轴起始点常量重新分配给一个新变量 exo_start_x。程序运行时,我们将通过 EXO_DX 常量增加该变量,以使外星行星向前移动。EXO_DX 的值越大,移动速度越快。

模拟是由 NUM_FRAMES 常量控制的 for 循环。为了避免降低输入图像的质量,我们在每次循环开始时将其复制到临时图像(temp_img)。接着,我们使用 [cv.circle()](https://docs.opencv.org/4.x/d6/d6e/group__imgproc__draw.html#gaf10604b069374903dbd0f0488cb43670) 方法绘制一个黑色圆圈,表示外星行星。

为了测量强度,我们取图像的平均值并将其附加到 intensity_samples 列表中。

接下来的几行代码绘制仪表板并通过调用接下来将定义的函数来更新其组件。root.after(3) 行在每次迭代后引入了 3 毫秒的短暂延迟。函数通过增加外星行星的 x 坐标来结束。

定义一个计算相对亮度的函数

接下来,我们定义一个辅助函数来计算从强度样本列表中的相对亮度。第一步是找到列表中的最大值。下一步返回一个新列表,其中每个强度项都被最大值除以,将结果归一化到 0 到 1 之间。

def calc_rel_brightness(intensity_samples):
    max_brightness = max(intensity_samples)
    return [intensity / max_brightness for intensity in intensity_samples]

定义更新仪表板的函数

现在我们定义函数以在每次循环迭代中更新仪表板的两个组件。第一个更新灰度图像。第二个用更新后的intensity_samples列表重新绘制光曲线。

def update_image(ax, img):
    ax.clear()
    ax.imshow(img, cmap='gray')
    ax.axis('off')

def update_light_curve(ax, data, color):
    ax.clear()
    ax.plot(data, 
            color=color, 
            linestyle='dashed', 
            linewidth=2, 
            label='Relative Brightness')
    ax.legend(loc='upper center')
    ax.set_title('Relative Brightness vs. Time')

运行模拟

最后部分的代码调用了 Tkinter 根窗口、创建仪表板的函数以及 Tkinter 的 mainloop() 函数。后者是运行模拟的 Tkinter 事件循环

if __name__ == "__main__":
    root = tk.Tk()
    root.title("Exoplanet Transit Dashboard")
    create_dashboard(root)
    root.mainloop()

这是完成模拟的一个示例:

模拟结束时的仪表板(作者提供)

尽管系外行星过境对其恒星光曲线的影响看起来很戏剧性,但你看到的只是总亮度的最顶层部分。如果你用y 值的完整范围重新绘制曲线,行星的影响几乎无法察觉。

完整 y 轴绘制的光曲线(作者提供)

接下来,我们将处理单次过境,但在现实生活中,天文学家会尽可能捕捉多个过境。光曲线中包含大量信息,通过记录多个过境事件,天文学家可以确定系外行星的轨道参数,比如行星与恒星之间的距离。他们可以利用光曲线中的细微变化来推测行星完全覆盖恒星表面的时间。他们可以估计理论上的光边暗化量,并使用建模——如你在这里所做的——将所有信息结合起来,并将其假设与实际观察结果进行比较。

过境光度测量实验

现在我们有了一个工作中的模拟器,我们可以用它来建模可能的过境行为,以便将来更好地分析现实中的观察结果。一个方法是运行大量可能的情况,并生成一个“图谱”来预期系外行星的反应。研究人员可以使用这个图谱来帮助他们解释实际的光曲线。

星黑子

太阳黑子——在外星太阳上称为星黑子——是由于恒星磁场的变化而导致的表面温度降低的区域。星黑子可以使恒星的表面变暗,并对光曲线产生有趣的影响。

要查看示例,请编辑之前的脚本,使得一个与星黑子大小大致相同的系外行星在过境期间经过几个星黑子。根据以下指示更改常量:

EXO_RADIUS = 4

EXO_START_Y = 205

这是结果:

过境路径中的星黑子导致“崎岖不平”的光曲线(作者提供)

当系外行星遮挡(覆盖)一个星斑时,整体效果是让图像变亮,因为两个暗点变成了一个。这反过来会导致光曲线中短暂的“波动”。

小行星带

不对称的光曲线也可能由小行星带产生。这些碎片带通常源于行星碰撞或太阳系的形成,例如木星轨道上的特洛伊小行星

特洛伊小行星和木星(感谢NASA

以下代码使用面向对象编程(OOP)创建随机小行星。如果你需要可重复的小行星对象,请确保random.seed(15)这一行未被注释。更改种子编号(15)将会改变小行星的大小及其分布。

"""Simulate transit of asteroids and plot light curve."""
import random
import tkinter as tk
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import cv2 as cv

IMG_HT, IMG_WIDTH = 400, 500
BLACK_IMG = cv.imread('limb_darkening.png', cv.IMREAD_GRAYSCALE)
NUM_ASTEROIDS = 15
NUM_LOOPS = 170

random.seed(15) # Uncomment to permit reproducible asteroids.

class Asteroid():
    """Draws a circle on an image that represents an asteroid."""    
    def __init__(self, number):
        self.radius = random.choice((1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 3))
        self.x = random.randint(-30, 60)
        self.y = random.randint(220, 230)
        self.dx = 3  

    def move_asteroid(self, image):
        """Draw and move an asteroid object."""
        cv.circle(image, (self.x, self.y), self.radius, 0, -1)
        self.x += self.dx

def create_dashboard(root):
    asteroid_list = []

    for i in range(NUM_ASTEROIDS):
        asteroid_list.append(Asteroid(i))

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6))
    canvas = FigureCanvasTkAgg(fig, master=root)
    canvas_widget = canvas.get_tk_widget()
    canvas_widget.pack(side=tk.TOP, fill=tk.BOTH, expand=1)    
    intensity_samples = []

    for _ in range(NUM_LOOPS):
        temp_img = BLACK_IMG.copy()        
        for ast in asteroid_list:
            ast.move_asteroid(temp_img)
        intensity = temp_img.mean()
        intensity_samples.append(intensity)
        relative_brightness = calc_rel_brightness(intensity_samples)        
        update_image(ax1, temp_img)
        update_light_curve(ax2, relative_brightness, 'k')
        canvas.draw()
        root.update()

def calc_rel_brightness(intensity_samples):
    max_brightness = max(intensity_samples)
    return [intensity / max_brightness for intensity in intensity_samples]

def update_image(ax, img):
    ax.clear()
    ax.imshow(img, cmap='gray')
    ax.axis('off')

def update_light_curve(ax, data, color):
    ax.clear()
    ax.plot(data, color=color, linestyle='dashed', linewidth=2, 
            label='Relative Brightness')
    ax.legend(loc='upper center')
    ax.set_title('Relative Brightness vs. Time')

if __name__ == "__main__":
    root = tk.Tk()
    root.title("Exoplanet Transit Dashboard")
    create_dashboard(root)
    root.mainloop()

这是输出:

小行星带的模拟过境(作者提供)

过境小行星产生不规则且不对称的光曲线。内部的“台阶”代表小行星在星体表面上移动和离开的情况。

趣味模拟

现在,让我们对边缘情况和“尖锐”情况进行一些不寻常的实验。我们将进行这些实验而不显示太阳图像,以便可以完全专注于模拟的特征。为了简洁,我将提供每个模拟的代码链接,而不是直接包含代码。

模拟一个带有外卫星的系外行星

如果系外行星有一个卫星会发生什么?让我们来看看(你可以在这个GitHub 仓库中找到这次模拟的代码):

一个具有卫星的系外行星(作者及埃里克·莫滕森博士提供)

具有卫星的系外行星的光曲线(作者及埃里克·莫滕森博士提供)

一个月球在与系外行星轨道相同平面上且与地球轨道平行的轨道上,每当它被系外行星掩蔽时都会在光曲线中产生一个小的波动。NASA 可能已经观察到这种现象。你可以在这里观看该事件的视频。

开普勒-1625 的光曲线暗示了一个具有卫星的系外行星的存在(NASA

注意前一图中的开普勒-1625 光曲线与我们模拟的月球之间的相似性。得益于建模,开普勒-1625 的结果并不令人意外。

探测外星人巨构

在 2015 年,分析开普勒太空望远镜数据的公民科学家注意到,位于天鹅座的 Tabby 的星星有些异常。2013 年记录的这颗星星的光变曲线表现出亮度的非规则变化,这些变化远远超出了行星所能引起的范围。

由开普勒太空观测台测量的 Tabby 的星星光变曲线(作者来自Wikipedia

除了亮度的剧烈下降外,光变曲线还不对称,并包含一些在典型行星凌日中未见的奇怪波动。提出的解释包括:光变曲线是由星星吞噬行星、云团过境的解体彗星、大型环形行星后面跟随着小行星群,或者是外星人建造的巨大结构引起的。

科学家推测,这种规模的人工结构最可能代表了外星文明试图从其太阳中收集能量。科学文献和科幻小说都描述了这些令人惊叹的巨大太阳能板项目。例子包括戴森群、戴森球、环世界和 Pokrovsky 壳。

设计用来拦截恒星辐射的 Pokrovsky 壳层系统(Wikimedia Commons

为了模拟一个巨大结构,我们将用其他简单几何形状替换程序中使用的圆形外行星。我们不需要完全匹配曲线;我们只需要捕捉到关键特征,如不对称性、2 月 28 日左右的“突起”以及(非常)大的亮度下降。

这是我使用两个巨大的但不对称的太阳能板的尝试:

为 Tabby 的星星模拟的巨大结构(作者提供)

为 Tabby 的星星模拟的巨大结构光变曲线(作者提供)

这个曲线与 Tabby 的星星的曲线非常相似。信不信由你,我第一次尝试就制作出了这个!你可以在这里找到代码。

这很有趣,但我们现在知道,无论是什么在 Tabby 的星星周围,都允许一些光波长通过,因此它不可能是一个固体物体。基于这种行为及其吸收的波长,科学家们认为尘埃导致了这颗星星光变曲线的奇怪形状。然而,其他星星,如天秤座中的 HD 139139,具有奇异的光变曲线,仍然没有解释。

侦测外星舰队

既然我们已经很开心了,就不要拘谨。

系外行星 BR549 的超进化海狸们忙得不可开交。他们聚集了一支巨大的殖民船队,这些船只现在已经装载完毕,准备离开轨道。由于他们自己对系外行星的检测,他们决定放弃被啃噬的故乡,前往地球的郁郁葱葱的森林!

我们能否通过光变曲线检测到这支舰队?让我们找出答案。你可以在这里找到这个模拟的代码。

外星舰队穿越外星恒星(作者)

外星舰队模拟的光变曲线(作者)

环绕的宇宙飞船产生了不对称且不规则的光变曲线。根据 Tabby 星的例子,这肯定会引起兴趣,但我敢打赌没有天文学家(也许只有Avi Loeb)会有勇气提出其真实来源!无论如何,你不应该在进行详尽的模拟,包括小行星、多颗系外行星、彗星群、尘埃云和其他自然现象后得出这个结论。

摘要

希望你喜欢这个小项目,并对计算机模拟有了更深的了解。作为科学研究的多功能和强大工具,它们可以帮助科学家和工程师理解复杂现象并设计高效的实验。它们的优势包括:

  • 模拟复杂系统的能力,这些系统可能很难或不可能直接研究。这包括避免与实地研究相关的安全隐患。

  • 设计高效且成本效益高的实验,这些实验在现实世界中是无法实现的。正如我们无需观测站即可模拟系外行星的掠过,模拟可以消除对昂贵设备和资源的需求,并缩短完成研究的时间。

  • 使用灵敏度分析预测和解决参数变化的影响。这有助于识别关键因素并消除昂贵的试错需求。

  • 发现涌现现象,这些现象可能在现实世界观察中不立即显现。这可能会导致在运行模拟之前未曾梦想到的其他发现。

  • 设计和优化最有效或最高效的解决方案。

  • 为教室里的学生或会议室里的管理者生成引人入胜的教育工具。如果一张图片价值千言万语,那么一个好的模拟就值百万。

  • 预测建模的贡献,使研究人员能够预测未来趋势、行为或事件。这在经济学、气候学和流行病学等领域尤为重要。

虽然模拟不能完全替代现实世界的实验,但它们在成本、时间、安全性以及探索和理解复杂系统的能力上提供了诸多优势。

谢谢

感谢阅读,请关注我以获取更多快速成功的数据科学项目。有关发现外行星的更多内容,请查看我的书第八章,现实世界中的 Python

[## 现实世界中的 Python:黑客解决问题的指南

现实世界中的 Python:黑客解决问题的指南 [Vaughan, Lee] 在 Amazon.com 上。免费送货。

a.co](https://a.co/d/4GHvthg?source=post_page-----a2d460a4889b--------------------------------)

使用 Python 模拟物理系统

原文:towardsdatascience.com/simulating-physical-systems-with-python-dd5751e80b5c

任何工程师或科学家的必备技能

Nick HemenwayTowards Data Science Nick Hemenway

·发表于Towards Data Science ·20 分钟阅读·2023 年 3 月 7 日

--

照片由NASA拍摄,发布在Unsplash

模拟物理系统的行为在几乎所有科学和工程领域都有不可思议的实用价值。模拟允许我们理解系统的时间演变,而且这种方式通常无法与物理测试相媲美。虽然物理测试可能耗时、昂贵且可能存在危险,但模拟则快速、廉价,并且不会对设备造成损坏或对人身安全构成风险。实验也有其局限性,因为它们只能提供你实际可以测量的数据——未测量的系统状态要么无法获得,要么需要在后处理过程中估计。另一方面,模拟如同窗口,允许我们窥视系统的整个内部状态。这种基本上能查看系统内部运作的能力,可以为工程应用提供宝贵的设计洞察。

本文的目的是提供一个快速入门指南,介绍如何在 Python 编程语言中开始模拟物理系统。为此,我们将通过一个全面的示例来模拟一个弹跳的球。我特别选择了这个系统,因为它直观且不太难,但仍有一些细微之处,使得模拟它比你在网上找到的其他简单示例更有趣(控制球运动的方程会根据球是否在空中或与地面接触而变化)。在这个过程中,我们将查看如何将数学模型转换为模拟所需的正确格式,展示如何利用现有的代码库(SciPy 中的solve_ivp)来简化我们的工作,并解释为什么我们以这种方式组织代码。到本文结束时,我们不仅会拥有一个可以运行的弹跳球模拟(如下面的动画所示),还将具备可以扩展到模拟任何物理系统的基础知识。

作者提供的动画

数学部分

物理世界中的几乎所有系统都可以通过微分方程来建模。这些方程自然地来源于支配我们周围世界的物理定律。以牛顿第二定律为例:

牛顿第二定律告诉我们,质量的位置的二阶导数与施加在该质量上的净力成正比。于是我们得到一个微分方程——一个解出后会给我们一个函数的方程。对于我们的问题,给定一个强迫函数 f(t),解决上述微分方程可以得到质量的位置作为时间的函数 x(t)。

那么,我们如何解决上述微分方程呢?不幸的是,我们不能使用通用的解析方法,因为解完全依赖于强迫函数 f(t)。根据强迫函数的形式,解决上述微分方程可能既非常简单,也可能极其困难。这就是模拟发挥作用的地方。模拟为我们提供了一种通用的方法来解决任意困难的微分方程!

那么模拟是如何工作的呢?当我说模拟时,我基本上只是指微分方程的数值积分。存在大量用于此目的的算法,例如非常简单但有时不准确的欧拉方法,或者龙格-库塔方法,这些方法非常流行且更准确。我们不深入探讨这些算法的具体细节,而是将它们视为黑箱,并利用已经实现这些算法的现有代码库。然而,为了做到这一点,我们需要将问题公式化,以符合这些算法所需的格式。

数值积分算法通常要求我们将方程输入为一阶微分方程系统。一般格式如下,其中 t 是时间,x 是系统变量/状态的向量。

看着上面的方程,我们可以看到系统中每个状态的变化率受到系统当前状态和瞬时时间的影响。然而,我们的系统的微分方程往往不会自然地符合这种形式(例如,牛顿第二定律不是一阶微分方程)。相反,我们通常需要通过一些代数操作将系统方程转化为这种格式。目前这一切都比较抽象,所以在接下来的部分,我们将通过将方法应用于弹跳球问题来固化这个概念。

弹跳球问题

对于弹跳球问题,我们需要考虑两种情况:1) 当球在空中时,和 2) 当球接触地面时。在这两种情况下,我们都会使用牛顿第二定律,但方程的形式会有所变化,因为施加在球上的外力发生了变化。为了公式化运动方程,我们将使用如下坐标系统,其中 x 代表球中心距离地面的高度。

图片作者

第一步:推导系统方程

情况 1) 球在空中时: 当球在空中时,唯一作用在球上的力是重力,因此我们得到如下公式:

图片作者

情况 2) 球接触地面时: 许多不同的接触模型存在于碰撞物体中,但在这个例子中,我们将假设球像一个质量-弹簧-阻尼系统。也就是说,我们将假设球的全部质量集中在球的中心,球在碰撞时会以某种刚度 (k) 和阻尼 (b) 压缩。如下所示:

图片作者

绘制接触地面的球体的自由体图,我们得到以下图示:

图片由作者提供

对球体作用的三种力是重力、阻尼力和球内产生的弹簧力。注意,弹簧(球)的压缩量等于球的半径(R)减去球的位置(x)。稍微调整一下方程,我们得到了著名的质量-弹簧-阻尼微分方程及其强迫项。

第 2 步:将方程转换为适合仿真的格式

我们推导出了一组控制球体在飞行和接触地面时运动的微分方程。不幸的是,它们在当前形式下无法用于仿真——我们必须将它们转换为一组一阶方程。为此,我们将引入两个新的变量 x1 和 x2,如下所示:

变量 x1 和 x2 被称为状态变量,包含它们的向量称为状态向量。根据定义,状态向量包含了完全定义系统状态所需的所有信息(在我们的例子中,即球的位移和速度)。一般来说,如果我们有一个 n 阶微分方程,我们需要引入 n 个新的状态变量(一个用于 x 及其每一个导数,直到但不包括最高阶导数)。在我们的例子中,我们引入了两个状态变量,因为我们有一个二阶微分方程。

对状态向量进行微分得到:

注意,x1 的导数根据定义是 x2。我故意保留了 x 双点项,因为其表达式依赖于球体是在空中还是接触地面。

对于球体在飞行中的情况,我们有:

对于球体接触地面的情况,我们有:

将两个 x 双点表达式代入上述状态导数方程,得到以下两个向量值的状态导数方程,分别适用于球体在空中和接触时的情况。

此时我们已成功将高阶微分方程转换为一阶方程系统,并准备开始编写仿真代码!

第 3 步:代码

首先,我们将导入必要的模块,并根据个人喜好设置一些绘图设置。下面的关键代码行是from scipy.integrate import solve_ivp,它导入了我们将用于仿真的求解器(请注意,ivp代表“初值问题”,即具有给定初值的微分方程)。

import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
import pandas as pd

#set default plotting settings to personal preference
font_settings = {'family':'Times New Roman', 'size':12}
line_settings = {'lw':2}
plt.rc('font', **font_settings)
plt.rc('lines', **line_settings)

我们将通过创建一个BouncingBall类来组织我们的代码——下方显示了一个模板,其中包括每个方法需要实现的注释。我们将逐个填充这些方法。

class BouncingBall():

    def __init__(self, m, k, b, ball_radius_cm):
        #TODO: initialize the physical properties of the ball
        pass

    def in_air(self,t,x):
        #TODO: compute the in-air physics of the ball (return the 
        # rate of change of the ball's current state)
        pass

    def in_contact(self,t,x):
        #TODO: compute the in-contact physics of the ball (return the
        # rate of change of the ball's current state)
        pass

    def simulate(self, t_span, x0, max_step=0.01):
        #TODO: use the two physics models (`in_air` and `in_contact`) to 
        # actually simulate the trajectory of the ball given a desired time
        # span and initial conditions. Must switch between the two physics 
        # models at the appropriate times
        pass

首先,我们将查看__init__方法。如下代码所示,__init__方法仅接收并存储球体的各种物理参数。如果没有提供重力值,默认分配标准值 9.81 m/s²。

 def __init__(self, m, k, b, ball_radius_cm, gravity=None):
        """
        Parameters
        ----------
        m : float
            ball mass in kg
        k : float
            ball stiffness in N/m
        c : float
            ball viscous damping coef. in N/(m/s)
        ball_radius_cm : float
            ball radius in centimeters
        gravity : float, optional
            acceleration of gravity in m/s²\. Defaults to 9.81 m/s² 
            if None given.
        """

        self.r = ball_radius_cm/100 #radius in meters
        self.m = m
        self.k = k
        self.b = b

        if gravity is None:
            self.g = 9.81 #m/s²
        else:
            self.g = gravity #m/s²

接下来,我们将填充两个物理方法(in_airin_contact)。这两个方法(如下)接收一个数组x和当前时间t,将x数组解包为x1x2,然后根据我们在第 2 步中推导的方程计算并返回状态导数向量。

 def in_air(self,t,x):
        """computes the ball's state derivatives while in air

        Parameters
        ----------
        t : float
            simulation time
        x : array-like
            vector containing the ball's state variables (height and velocity)

        Returns
        -------
        list
            vector containing the ball's state derivatives
        """

        x1, x2 = x #unpack the current state of the ball
        return [x2, -self.g]

    def in_contact(self,t,x):
        """computes the ball's state derivatives while in contact with the ground

        Parameters
        ----------
        t : float
            simulation time
        x : array-like
            vector containing the ball's state variables (height and velocity)

        Returns
        -------
        list
            vector containing the ball's state derivatives
        """

        x1, x2 = x #unpack the current state of the ball
        x1_dot = x2
        x2_dot = -(self.b/self.m)*x2 + (self.k/self.m)*(self.r - x1) - self.g
        return [x1_dot, x2_dot]

最后,我们可以填充弹跳球类的simulate方法,这将实际执行所有有趣的工作。该方法需要将球体的物理过程向前积分,确保在适当的时候在空中和接触物理模型之间切换。请注意,根据我们定义的坐标系,当球体的位置(x)小于球体的半径(R)时,球体将与地面接触。

切换两个物理模型的最简单方法是定义一个中间方法,根据球体的位置在两个先前定义的方法之间切换。它可能类似于以下内容:

def ball_physics(self, t, x):

    #unpack ball's height and velocity
    x1, x2 = x

    #check if ball's center height is less than ball's radius
    if x1 <= self.R:
        #compute in-contact physics (state derivative)
        return self.in_contact(t,x)
    else:
        #compute in-air physics (state derivative)
        return self.in_air(t,x)

然后我们将这个单一的函数传递给数值积分器solve_ivp,并将系统在时间上向前推进。不过,这种方法有一个主要问题,就是它无法准确捕捉到物理模型需要变化的接触时刻。下面的图示将加以说明。

作者提供的图像

在时间步 tn,球体刚好在地面上方并向下移动。数值积分器将把球体的轨迹向前积分到时间步 tn+1。请注意,tn 和 tn+1 之间的时间差取决于所使用的求解器类型(固定步长或可变步长)和相应的误差容忍设置。无论如何,在时间 tn+1 时,物理模型将切换到接触模型,但存在一个问题——球体在模型变化之前已经与地面接触。我们不希望每当球体的位置(x)小于球体的半径(R)时,物理模型就改变,我们希望物理模型在球体接触的确切瞬间发生变化,即当球体的位置等于球体的半径(x = R)时。为此,我们需要另一种方法,我们必须使用“事件”。

直观地说,“事件”正如其名——在特定时间点发生的“某事”。在我们的案例中,我们将尝试捕捉两个不同的事件:1)球与地面接触的瞬间,以及 2)球离开地面的瞬间。我们通过定义一个在事件发生时等于零的函数来数学描述一个事件(对熟悉 Simulink 的人来说,这与“零交叉”概念相同)。对于我们弹跳球的问题,这只是球的位置与半径之间的差值,如下所示:

当上述函数为零时,球的高度等于球的半径,因此我们有一个接触事件。此外,从下面的图中可以看到,当事件函数在负方向上交叉零(从正向负),这表明球正在与地面接触。同样,当事件函数在正方向上交叉零(从负向正),这表明球离开地面接触。

作者提供的图片

很好,但我们该如何使用这个呢?幸运的是,我们想用的求解器(solve_ivp)中已经实现了事件的概念——我们只需正确使用它即可!查看 SciPy 的文档我们看到,solve_ivp有一个可选参数events,它接受一个任意对象。这个对象必须是可调用的,这意味着我们可以像调用函数一样调用它(这正是我们上面定义的事件函数)。该对象还有附加属性terminaldirection。如果terminal属性设置为True,当事件发生时,模拟将终止。direction属性是一个附加标志,可以设置以确定关注哪个零交叉方向。例如,如果direction=-1,只有负方向的交叉(例如,球与地面接触)会被视为事件,正方向的交叉将被忽略。

我们可以构建一个名为ContactEvent的类,这个类正好符合文档要求。查看下面的代码,我们看到这个类接收球的半径,并在__call__方法中定义事件函数(我们必须在__call__方法中实现它,因为solve_ivp要求事件对象是可调用的)。ContactEvent类还包含属性directionterminal。注意,terminal设置为True,这意味着当事件发生时,模拟将停止。direction属性在实例化时设置,因为它取决于我们尝试捕捉的事件类型。

class ContactEvent():
    """Callable class that returns zero when the ball engages/disengages
    contact with the ground.
    """

    def __init__(self, r, direction=0):
        """
        Parameters
        ----------
        r : float
            Radius of the ball in meters
        direction : int, optional
            Direction of a zero crossing to trigger the contact event.
            Negative for the ball coming into contact with the ground.
            Positive for the ball leaving contact with the ground, by default 0
        """

        self.r = r
        self.direction = direction
        self.terminal = True #terminal is True so that simulation will end on contact event

    def __call__(self,t,x):
        """Computes the height of the ball above being in contact

        Notes
        -----
        The ball will engage/disengage contact when the height of the center
        of the ball equals the radius of the ball.

        Parameters
        ----------
        t : float
            time in the simulation
        x : array-like
            vector of ball's state variables (height and velocity)

        Returns
        -------
        float
            height above being in contact
        """
        #unpack height and velocity of ball
        x1, x2 = x
        return x1 - self.r

要使用ContactEvent类,我们将创建两个ContactEvent对象(一个用于球体接触,另一个用于球体离开接触),并将它们存储在BouncingBall对象中。为此,我们需要在__init__方法中添加两行代码,如下所示:

def __init__(self, m, k, b, ball_radius_cm, gravity=None):
        """
        Parameters
        ----------
        m : float
            ball mass in kg
        k : float
            ball stiffness in N/m
        c : float
            ball viscous damping coef. in N/(m/s)
        ball_radius_cm : float
            ball radius in centimeters
        gravity : float, optional
            acceleration of gravity in m/s²\. Defaults to 9.81 m/s² 
            if None given.
        """

        self.r = ball_radius_cm/100 #radius in meters
        self.m = m
        self.k = k
        self.b = b

        if gravity is None:
            self.g = 9.81 #m/s²
        else:
            self.g = gravity #m/s²

        # create event functions for the ball engaging and disengaging contact
        # note that when coming into contact with the ground, the direction of the
        # zero crossing will be negative (height abve the ground is transitioning 
        # to negative) whereas when the ball leaves the ground the zero crossing 
        # will be positive (height above the ground is transitioning from negative
        # to positive)
        self.hitting_ground = ContactEvent(self.r, direction=-1)
        self.leaving_ground = ContactEvent(self.r, direction=1)

随着我们定义的接触事件函数,我们终于可以查看simulate方法的代码(如下)。

def simulate(self, t_span, x0, max_step=0.01):
        """Simulates the time evolution of the ball bouncing

        Parameters
        ----------
        t_span : two element tuple/list
            starting and stopping time of the simulation, e.g. (0,10)
        x0 : array-like
            initial conditions of ball (height and velocity in m and m/s)
        max_step : float, optional
            max step size in simulation, by default 0.01

        Returns
        -------
        tuple
            tuple containing the time vector and height of the ball 
        """

        #check the initial height of the ball to determine if it's starting in the air
        in_air = x0[0] > self.r 
        #extract out initial starting and stopping times
        t_start, t_stop = t_span

        #create lists with initial conditions that we can append the piecewise solutions to
        t_lst = [t_start]
        x_lst = [x0[0]]

        #loop until we reach the desired stopping time
        while t_start < t_stop:

            """
            Here we simulate the ball forward in time using either the air 
            or contact model. Each of these subroutines will terminate when 
            either 1) the final desired simulation stop time is reached,
            or 2) when a contact event is triggered. At a high level, we are
            simulating our system forward in time using the relevant physics 
            model (that depends on the state of the system). The simulation 
            will alternate between the "in_air" model and "in_contact" model 
            switching between the two each time a contact event is triggered
            """

            if in_air:
                sol = solve_ivp(self.in_air, [t_start, t_stop], x0, 
                                events=[self.hitting_ground], max_step=max_step)

            else:
                sol = solve_ivp(self.in_contact, [t_start, t_stop], x0, 
                                events=[self.leaving_ground], max_step=max_step) 

            # append solution and time array to list of solutions. 
            # Note that the starting time of each solution
            # is the stopping time of the previous solution. 
            # To avoid having duplicate time points, we will not include the first
            #data point of each simulation. This is also why we created our 
            # solution lists above with the initial conditions already in them
            t_lst.append(sol.t[1::])
            x_lst.append(sol.y[0,1::])

            #set the starting time and initial conditions to the stopping 
            #time and end condtions of the previous loop
            t_start = sol.t[-1]
            x0 = sol.y[:,-1].flatten()

            #if we haven't reached the stopping time yet in the current 
            #loop, we must be switching between being in the air and
            #being in contact
            if t_start < t_stop:
                in_air = not in_air

        #concatenate all of the solutions into a single numpy array
        t = np.hstack(t_lst)
        x = np.hstack(x_lst)

        return t,x

从上面的代码可以看出,simulate方法接收初始条件和我们希望求解的时间跨度。它首先检查球体是开始在空中还是在地面接触,以确定最初使用哪种物理模型。接下来的代码如下:

  1. 创建两个列表(t_lstx_lst),我们可以将分段模拟解决方案不断添加到这些列表中。

  2. 将球体向前模拟,通过solve_ivp传递适当的物理模型和事件对象(注意solve_ivp接受一个返回状态导数向量的函数(如第 2 步讨论),初始条件和要解决的时间跨度)。在发生接触事件之前向前模拟,届时模拟停止。将模拟时间和解决方案分别附加到t_lstx_lst

  3. 切换物理模型(每个事件都表示物理模型的切换),并将初始条件和起始时间重置为前一次求解的最终条件和时间。

  4. 重复第 2 步和第 3 步,直到达到期望的最终模拟时间。

  5. 将所有时间和解向量连接成单一数组并返回解决方案。

完整的BouncingBall类如下所示。

class BouncingBall():
    """Class to simulate a bouncing ball

    Methods
    -------
    in_air(t, x) : Computes the ball's state derivatives while in air

    in_contact(t, x) : Computes the ball's state derivatives while in contact

    simulate(t_span, x0, max_step) : Simulates the ball bouncing
    """

    def __init__(self, m, k, b, ball_radius_cm, gravity=None):
        """
        Parameters
        ----------
        m : float
            ball mass in kg
        k : float
            ball stiffness in N/m
        c : float
            ball viscous damping coef. in N/(m/s)
        ball_radius_cm : float
            ball radius in centimeters
        gravity : float, optional
            acceleration of gravity in m/s²\. Defaults to 9.81 m/s² 
            if None given.
        """

        self.r = ball_radius_cm/100 #radius in meters
        self.m = m
        self.k = k
        self.b = b

        if gravity is None:
            self.g = 9.81 #m/s²
        else:
            self.g = gravity #m/s²

        # create event functions for the ball engaging and disengaging contact
        # note that when coming into contact with the ground, the direction of the
        # zero crossing will be negative (height abve the ground is transitioning 
        # to negative) whereas when the ball leaves the ground the zero crossing 
        # will be positive (height above the ground is transitioning from negative
        # to positive)
        self.hitting_ground = ContactEvent(self.r, direction=-1)
        self.leaving_ground = ContactEvent(self.r, direction=1)

    def in_air(self,t,x):
        """computes the ball's state derivatives while in air

        Parameters
        ----------
        t : float
            simulation time
        x : array-like
            vector containing the ball's state variables (height and velocity)

        Returns
        -------
        list
            vector containing the ball's state derivatives
        """

        x1, x2 = x #unpack the current state of the ball
        return [x2, -self.g]

    def in_contact(self,t,x):
        """computes the ball's state derivatives while in contact with the ground

        Parameters
        ----------
        t : float
            simulation time
        x : array-like
            vector containing the ball's state variables (height and velocity)

        Returns
        -------
        list
            vector containing the ball's state derivatives
        """

        x1, x2 = x #unpack the current state of the ball
        x1_dot = x2
        x2_dot = -(self.b/self.m)*x2 + (self.k/self.m)*(self.r - x1) - self.g
        return [x1_dot, x2_dot]

    def simulate(self, t_span, x0, max_step=0.01):
        """Simulates the time evolution of the ball bouncing

        Parameters
        ----------
        t_span : two element tuple/list
            starting and stopping time of the simulation, e.g. (0,10)
        x0 : array-like
            initial conditions of ball (height and velocity in m and m/s)
        max_step : float, optional
            max step size in simulation, by default 0.01

        Returns
        -------
        tuple
            tuple containing the time vector and height of the ball 
        """

        #check the initial height of the ball to determine if it's starting in the air
        in_air = x0[0] > self.r 
        #extract out initial starting and stopping times
        t_start, t_stop = t_span

        #create lists with initial conditions that we can append the piecewise solutions to
        t_lst = [t_start]
        x_lst = [x0[0]]

        #loop until we reach the desired stopping time
        while t_start < t_stop:

            """
            Here we simulate the ball forward in time using either the air 
            or contact model. Each of these subroutines will terminate when 
            either 1) the final desired simulation stop time is reached,
            or 2) when a contact event is triggered. At a high level, we are
            simulating our system forward in time using the relevant physics 
            model (that depends on the state of the system). The simulation 
            will alternate between the "in_air" model and "in_contact" model 
            switching between the two each time a contact event is triggered
            """

            if in_air:
                sol = solve_ivp(self.in_air, [t_start, t_stop], x0, 
                                events=[self.hitting_ground], max_step=max_step)

            else:
                sol = solve_ivp(self.in_contact, [t_start, t_stop], x0, 
                                events=[self.leaving_ground], max_step=max_step) 

            # append solution and time array to list of solutions. 
            # Note that the starting time of each solution
            # is the stopping time of the previous solution. 
            # To avoid having duplicate time points, we will not include the first
            #data point of each simulation. This is also why we created our 
            # solution lists above with the initial conditions already in them
            t_lst.append(sol.t[1::])
            x_lst.append(sol.y[0,1::])

            #set the starting time and initial conditions to the stopping 
            #time and end condtions of the previous loop
            t_start = sol.t[-1]
            x0 = sol.y[:,-1].flatten()

            #if we haven't reached the stopping time yet in the current 
            #loop, we must be switching between being in the air and
            #being in contact
            if t_start < t_stop:
                in_air = not in_air

        #concatenate all of the solutions into a single numpy array
        t = np.hstack(t_lst)
        x = np.hstack(x_lst)

        return t,x

到此为止,所有的艰苦工作都完成了,我们准备好运行模拟了!我们可以使用以下几行代码,创建一个BouncingBall对象,模拟从 2 米高度下落后球体弹跳八秒,然后绘制轨迹。

b = BouncingBall(m=1, k=10e3, c=10, ball_radius_cm=6)

t,x = b.simulate([0,8], [2, 0])
df = pd.DataFrame({'time':t, 'x':x}).set_index('time')
df.to_csv('sim_data.csv')

fig, ax = plt.subplots()
ax.plot(t,x)
ax.set_xlabel('Time [s]')
ax.set_ylabel('Height [m]')

fig.savefig('bouncing_ball.png')

作者提供的图片

相当令人满意!如果我们愿意,甚至可以对数据进行后处理,创建一个球体弹跳的动画(如帖子开头和这里重复的动画)。

作者提供的动画

matplotlib中创建动画超出了本文的范围,但对感兴趣的人来说,创建动画的代码包含在本文附带的 Github 仓库中。

结论

模拟物理系统的能力在几乎所有的科学和工程领域中都能提供宝贵的帮助。虽然所突出的例子相对具体,但它涵盖了一个更为通用的工作流程,可以扩展到模拟任何类型的物理系统。无论系统是什么,你总是需要确定支配方程,将其转换为一阶微分方程系统,然后将方程传递给某种数值积分例程。当然,这里有些细节在本文中无法讨论——有时系统包含约束方程,我们会得到一组微分代数方程(DAEs),而不是常微分方程(ODEs)。或者你可能无法明确地求解状态向量的导数与系统状态的关系。无论如何,本文中涵盖的内容仍然是必须掌握的基础,才能继续进行更深入的研究。

欢迎随时留言或提问,或者通过 Linkedin 联系我 —— 我会很乐意澄清任何不确定的地方。最后,我鼓励你自己动手尝试代码(或将其作为你自己工作流程的起始模板)—— 本文的所有代码可以在我的 Github 上找到。

Nicholas Hemenway

  • 如果你喜欢这个,请关注我在 Medium 上的文章

  • 考虑订阅 电子邮件更新

  • 有兴趣合作吗?让我们 在 LinkedIn 上联系

模拟扑克牌游戏‘战争’

原文:towardsdatascience.com/simulating-the-card-game-war-ebafb4462a6a

一个关于具有无限变化的简单游戏的编程故事

Jake MitchellTowards Data Science Jake Mitchell

·发表于 Towards Data Science ·阅读时间 6 分钟·2023 年 1 月 24 日

--

我模拟了很多游戏——需要技巧、欺骗或策略的游戏。

这不是其中之一。

图片由 Ivan Slade 提供,来源于 Unsplash

引言:

‘战争’游戏非常简单。两个玩家分别抽取一副标准 52 张扑克牌的一半。然后两人对峙,翻开他们的顶牌。谁的牌面值高,谁就拿走这两张牌。如果牌面值相同,玩家需要冒险出自己牌堆中的接下来的 3 张牌,并用第 4 张牌进行对战。谁打出的第 4 张牌最大,谁就赢得所有牌(包括自己出的 3 张牌和对手出的 3 张牌)。当一位玩家没有任何牌时,游戏结束。

获胜没有秘密。你只能希望你的牌面值比对手的牌高。你完全依赖运气。

当我开始编写这个游戏的代码时,我对卡片的初始值是否能在游戏开始之前预测赢家感兴趣。令我没有预料到的是,在适当的情况下,这个游戏实际上可以永远进行下去。

游戏编程:

我首先对一副扑克牌进行‘洗牌’:

deal = randperm(52,52);
  for i = 1:13
    for j = 1:52
      if deal(j) > (i - 1)*4 && deal(j) <= (i - 1)*4 + 4
        deal(j) = i;
      end
    end
  end

这段代码创建了一个包含 1 到 52 的 52 空间数组,数字随机排列。然后,循环会将 1 到 13 的值分配给 4 张卡,其中 1 代表 A,13 代表 K。剩下的就是相当于洗过牌的扑克牌。

然后我决定每个对手的下一张卡是否不相等。如果是,代码就非常简单:

if next_up(1,1) ~= next_up(1,2)
  [best,index] = max(next_up(1,:));
  player_cards(index,size(player_cards,2) + 1:size(player_cards,2) + 2) = next_up(1,:);
  player_cards(1,next_up(2,1)) = 0;
  player_cards(2,next_up(2,2)) = 0;
end

这确定了哪张牌更高,前提是两张牌不相等,并将这两张牌给予赢家,同时从输家那里拿走那张牌。

如果打出的卡牌相同,那么玩家必须进入战争。这段代码与之前展示的代码相同,不同之处在于从每个玩家处额外添加了 3 张卡牌到共享堆中。

这些步骤会重复进行,直到决定出胜者…… 或者我以为如此

结果:

当一个模拟无法结束时,结果很难得出。然后我查看了每个玩家的牌组。

玩家们已经进入了一个完美平衡的“战争”游戏。

然后我绘制了每回合每手牌的数量,以可视化发生了什么:

作者提供的图像。

这是一个显示每回合游戏状态的图表。由于游戏的零和性质,它在 26 张卡牌线处完美对称。这张图显示在这个无限游戏中(我在 2000 回合时停止了它以生成这张图片),玩家们不小心发现了一种导致卡牌流动完美平衡的模式。

红方拿到的每张卡牌,蓝方在下一回合会拿回。如此反复——永远。我放大了之前的图表来说明这一点:

作者提供的图像。

不过这并不是每个游戏都会发生的情况。很多游戏的进展完全正常。我积累了更多这些图表,以展示“战争”游戏可以产生的不同卡牌流动。

作者提供的图像。

视觉化这些游戏的样子是非常迷人的。一方面(左侧),你会看到一个包含许多不同枢纽点的游戏。玩家来回交换领先位置,直到有一个无法继续竞争。而在另一端(右侧),你会看到一个从一开始就严重偏向某个玩家的游戏。蓝方可能依靠几张高价值的卡牌勉强维持,但最终还是运气耗尽了。

我仍然对找出有多少比例的游戏会跃入无限感兴趣。我模拟了数百次“战争”游戏,并跟踪了达到 2000 回合的游戏数量(这是我对无限的定义,因为如果我真的让它无限运行下去,我将无法完成模拟)。我发现 16.67%的游戏结果是无限游戏。

我还发现平均非无限游戏持续 513 回合。 根据下面的直方图,大多数游戏持续不到 500 步。虽然这看起来很多,但平均每步可能持续 2-3 秒,意味着平均游戏时间是 15-20 分钟。

当我们考虑到 16.67%的游戏会进入无限回合时,我们可以理论上说平均每游戏回合数……是无限—— 但这个想法让我脑袋有点疼。

作者提供的图像。

到这个时候,我发现也许我对规则的理解导致了这个无限游戏。我编写代码使得当一个玩家获胜时,他们将那些卡片放回到手牌的底部——没有洗牌。

我模拟了相同的代码,这次加入了一个洗牌块,生成的结果与我玩游戏的经验更为一致。

我发现没有一个模拟游戏达到无限。 我还发现每局游戏的平均回合数缩短到了大约 250 回合——是未洗牌版本的一半。

图片由作者提供。

直方图本身说明了一切。很少有游戏接近无限,许多游戏只持续了 5 到 10 分钟。

只需简单地洗牌就可以产生足够的变化,防止游戏持续无限时间,这完全是合理的。对于不感兴趣于实验卡片总平衡性的普通玩家来说,洗牌你的胜利卡片就足以使游戏时间保持在合理范围内。

结论:

“战争”游戏在未洗牌的版本中,有很大的可能性会持续无限时间。每 6 场游戏中可能有 1 场是无限的。如果你的目标是找到卡片之间的完美平衡,那么我建议你玩这个版本。

我很少有项目的结果超出我的预期。在Catan中,我一般知道什么是好的设置,只是在优化它。在Monopoly中,我期望找到投资回报最好的物业是理想的。在这个项目中,我盲目地进行实验——发现我最简单的项目产生了最有趣的结果。

感谢你花时间阅读我的文章!如果你读到这里,为什么不继续阅读有关棋盘游戏中的数据科学呢?看看我其他的一些项目:

模拟 101:导热传输

原文:towardsdatascience.com/simulation-101-conductive-heat-transfer-a4f09b3e16b4

计算物理的温馨介绍

Le NguyenTowards Data Science Le Nguyen

·发表于Towards Data Science ·11 分钟阅读·2023 年 7 月 25 日

--

传导或物体之间的热传递是我们每天都会经历的现象。将锅放在炉子上或坐在热公园长椅上让我们对导热传递有直观的感觉,但在这里我们将正式化这一过程并建立一个基本的计算框架来模拟它。传导是解决的第一个模拟问题,因为它使用了许多计算物理问题中的基本工具。

在本文中,我们将:

  • 创建网格网格以表示材料

  • 学习基本的热传递方程及其计算等效物

  • 基于基础物理更新网格网格中的值

  • 模拟导热传输

创建网格网格

网格网格是用于离散化连续空间的计算工具。也就是说,我们无法在问题的所有时间和空间上进行计算,因此我们选择一个代表性子集的点,通常间隔规则地查看。

在下图 1 中,我们可以看到一个网格网格的示例。这里一个空间被细分为均匀间隔的单元,这在物理模拟中是常见做法。我们现在可以只处理网格点,这使得问题更具可行性。

图 1:网格网格示例。在模拟中,我们将空间划分成这样的网格,并在每个虚线网格点计算值。

上面的网格是使用 Python 的 numpy meshgrid 函数创建的,该函数可以接受一组一维数组并为我们创建一个网格。对于我们的模拟,我们希望建模一个二维表面,因此我们将生成 2 个数组,填充我们想要的起始值,长度为我们希望在其上评估模拟的间隔数。请参见下面的代码片段,我们创建了一个 100x100 的零网格作为模拟的基础。

import numpy as np

#Define how many intervals we want per axis
resolution = 100

#Create x and Y arrays of zeros of length 100
x = np.zeros(resolution)
y = np.zeros(resolution)

#Create a mesh grid from above arrays. Outputs are all x and Y values in grid
gridX, gridY = np.meshgrid(x,y)

在讨论网格之前,需要注意的是,我们的网格上的每个点通常不是单一值。通常在每个网格点上,我们希望有一个值的数组来表示我们将要处理的属性。在我们的热传导模拟中,我们需要知道每个网格点的材料属性以及温度。因此,网格上的所有点将是一个包含如温度、材料密度、材料导热性以及我们需要了解的其他信息的数组。

制作网格的简单方法是创建多个在每个点上包含单一值的网格,然后将它们叠加在一起,如图 2 所示。

图 2: 叠加 2 个网格以创建一个新的网格,其中每个点包含多个值。

Numpy 在这里再次发挥作用,通过 dstack 函数可以按元素叠加两个数组。下面的代码片段将创建 2 个网格并将它们叠加在一起。

import numpy as np

#Define how many intervals we want per axis
resolution = 100

#Create x and Y arrays of zeros of length 100
x1 = np.zeros(resolution)
y1 = np.zeros(resolution)

x2 = np.full(resolution, 1)
y2 = np.full(resolution, 1)

#Create a mesh grids from above arrays. Outputs are all x and Y values in grid
gridX, gridY = np.meshgrid(x1,y1)
gridX2, gridY2 = np.meshgrid(x2,y2)

#Stack both of the mesh grids we have created so every element is [0,1]
fullGridX = np.dstack([gridX,gridX2])
fullGridY = np.dstack([gridY,gridY2])

有了创建我们自己的网格以表示模拟环境的工具后,我们可以继续研究物理学。

热传导基础

我们将首先从一维时间相关热传导方程开始。

方程 1: 一维时间相关热传导方程

该方程表明温度随时间变化的原因是材料的热属性与材料内的温度变化的比例(严格来说是温度的二阶导数)。我们不需要担心材料的热属性,因为它们可以在 查找表 中找到。我们需要做的是将方程中的导数转换为我们可以计算的形式。幸运的是,这可以通过 有限差分法 完成,将方程 1 转换为下面的方程 2。

方程 2: 计算一维时间依赖热传导方程

方程 2 是时间依赖热传导方程的计算等效形式,我们可以用它来更新模拟中的温度。遍历网格图中的每个单元,我们可以通过该单元的当前温度加上该单元与其邻近单元之间的温差(这里 i 代表网格索引),乘以材料的热属性和我们选择的时间步长来更新该单元的温度。

方程 1 和 2 是一维热方程,但我们希望在二维中运行我们的模拟。添加第二维是直接的,特别是当我们有热方程的计算版本时,我们只需将所有邻近单元添加到方程 3 中。

方程 3: 计算二维时间依赖热传导方程

现在我们已经将时间依赖热传导方程转换为计算表达式,我们可以用它来更新我们的网格图。

更新我们的网格图

首先,我们使用迄今为止开发的工具初始化一个网格图。我们知道我们需要至少 2 个特征来插入我们的温度方程,即当前温度和我们选择的材料的热扩散常数。我们将从室温(20°C)开始,材料为铜,其热扩散率为 1.11x10^-2 cm²/s。我们用这些值填充 2 个网格图并将它们叠加以形成我们的整体网格。

import numpy as np

#Define how many intervals we want per axis
resolution = 100
startingTemperature = 20
thermalDiffusivity = 1.11*10**-2

#Create x and Y arrays of zeros of length 100
x1 = np.full(resolution, 20)
y1 = np.full(resolution, 20)

x2 = np.full(resolution, 1.11*10**-2)
y2 = np.full(resolution, 1.11*10**-2)

#Create a mesh grids from above arrays. Outputs are all x and Y values in grid
gridX, gridY = np.meshgrid(x1,y1)
gridX2, gridY2 = np.meshgrid(x2,y2)

#Stack both of the mesh grids we have created so every element is
#[20,1.11*10**-2]
fullGridX = np.dstack([gridX,gridX2])
fullGridY = np.dstack([gridY,gridY2])

在我们的具体情况下,我们的网格图是对称的(x 和 y 方向上的特征值相同),所以单独定义 x 和 y 网格不是必要的,但这仍然是一个好习惯,因为在模拟时我们不会总是有这样的对称性。现在我们已经用网格图表示了材料,我们需要给它添加热量(正如我们在物理方程中看到的,热量传递必须有温差)。为了给材料添加热量,我们需要选择网格图中要加热的单元,并增加它们的温度索引。我们将继续保持简单,假设在 X 和 Y 方向的 35–65 索引范围内添加 1000°C 的热量。下面可以看到执行此操作的代码和结果图。

import matplotlib.pyplot as plt

#fullGridX[:,:,0] gets the first index of every element in our mesh grid
#fullGridX[:,:,0][35:65,35:65] selects a box of values
fullGridX[:,:,0][35:65,35:65] += 1000

plt.imshow(fullGridX[:,:,0])

图 3: 我们的网格图上增加热量的可视化。

我们在网格中添加了一片热区,这只是一个好的开始,但现在我们需要随着时间推移来演变我们的系统,以观察热量的传递。我们还需要开发几个工具来完成这项工作。第一个工具是查找邻居温度的函数,该函数可以获取所有邻近单元格的温度。下面给出的函数循环遍历了给定单元格的所有 8 个边界单元格,但忽略了角落,仅检索上、下、左、右的单元格。我们在查看一个单元格的所有邻居时,还需要考虑边界条件,即当我们到达网格的边缘时我们想要做什么。如果我们处于已经在网格边缘的网格单元格中,寻找所有邻居将抛出一个错误,因为它们并不都存在。我们可以在循环中使用 try except 语句来处理这一点,当我们尝试查找不存在的单元格时,它将给出我们的边界条件。我们的边界条件可以填充我们想要假设在网格外部的温度。在我们的模拟中,我们将假设这是室温,因此当我们计算网格边缘的热传递时,室温将始终作为邻近值提供给边缘单元格。

def getNeighborsTemperature(grid, point, boundaryTemp):
    #List to collect all neighboring temperatures
    neighbors = []
    #Loop over all neighboring cells
    for i in range(-1,2):
        for j in range(-1,2):
            try:
                #Ignore corner cells
                if abs(i) != abs(j):
                    neighbors.append(grid[point[0] + i][point[1]+ j])
            except:
                #Apply boundary condition
                neighbors.append(boundaryTemp)

    return neighbors

接下来我们需要开发的函数将实现上一节中找到的二维时间相关热传导方程。我们已经解决了这个方程的计算等效,因此实现是直接的,下面给出。该函数通过获取给定单元格的温度、周围单元格的温度、我们选择的时间步长以及单元材料的热扩散率来完成热传导方程。

def calculateHeat(cellTemp, neighborTemps, timeStep, thermalDiffusivity):
    #Converting equation 3 into code
    cellTemp = cellTemp + timeStep*thermalDiffusivity*((neighborTemps[0] -2*cellTemp + neighborTemps[-1]) + 
                                                       (neighborTemps[1] -2*cellTemp + neighborTemps[-2]))
    return cellTemp

我们需要的最后一个函数将结合之前的两个函数,循环遍历网格中的每个单元格,并更新其温度。对于网格中的每个单元格,我们首先运行“getNeighbors”来获取所有邻近的温度,然后将邻近的温度加上当前单元格的温度以及其他参数传递给更新单元格的温度。

def heatTransfer(grid, timeStep, boundaryTemp):
    #Loop over all grid cells 
    for i in range(0,len(grid)):
        for j in range(0,len(grid)):
            #Get neighboring cell temperatures
            neighbors = getNeighborsTemperature(grid[:,:,0], (i,j), boundaryTemp)
            #Update current cell temperature
            grid[:,:,0][i][j] = calculateHeat(grid[:,:,0][i][j], neighbors, timeStep, grid[:,:,1][i][j])

    return grid

最后,我们可以使用我们的热传递方程来更新原始网格。最终的代码片段将用 30 秒的时间步长更新我们的网格,并比较前后的结果。

#Applying our heat transfer equation on current grid with a 30 secound time step
nextGrid = heatTransfer(fullGridX.copy(),30, 20)

#Matplotlib subplots to show the original and updated mesh grid
plt.figure(figsize = (12,6))

plt.subplot(1,2,1)
plt.imshow(fullGridX[:,:,0])
plt.title("t = 0")

plt.subplot(1,2,2)
plt.imshow(nextGrid[:,:,0])
plt.title("t = 30")

图 4:用 30 秒的时间步长更新我们的网格热传递。

我们现在拥有了进行导热传递模拟所需的所有工具,接下来我们将在下一节中运行几个模拟。

模拟

创建完所有模拟工具后,我们只需将热传递方程放入循环中,并根据我们想要的时间步长来演变系统。给出的代码以及一个将我们的模拟可视化为 gif 的函数。

import imageio
"""
This make gif function takes in a set of images and turns them into a gif!
Give the frames as an array of mesh grids (only the temperature)and a name 
for the gif as well as the temperature bounds for the heat map.
"""
def makeGif(frames,name,minTemp,maxTemp):
    !mkdir frames

    counter=0
    images = []
    for i in range(0,len(frames)):
        plt.figure()
        plt.imshow(frames[i], cmap = "inferno", vmin = minTemp, vmax = maxTemp)
        plt.savefig("frames/" + str(counter)+ ".png")
        images.append(imageio.imread("frames/" + str(counter)+ ".png"))
        counter += 1
        plt.close()

    imageio.mimsave(name, images)

    !rm -r frames
#RUNNING SIMULATION
#Keep track of our mesh grid frames to make into gif
frames = [fullGridX[:,:,0].copy()]

timeStep = 30
boundaryTemp = 20

#Run for 500 time steps
for t in range(0,500):
    fullGridX = heatTransfer(fullGridX.copy(),timeStep,boundaryTemp)
    frames.append(fullGridX[:,:,0].copy())

makeGif(frames,"simulation.gif",20,500)

图 5:我们的导热传递模拟的 gif

就这样。我们已经运行了导热传递的模拟并进行了可视化(使用酷炫的“火焰”色彩图)。凭借我们开发的一套通用工具,我们可以运行许多不同的模拟,包括不同的时间步长、起始温度、材料和几何形状。下一节也是最后一节将展示使用我们的工具制作的几个更多场景。

示例场景

焊接

以下是黄铜和钢板焊接的模拟。值得注意的是,黄铜(左侧)的热导率约为钢材(右侧)的 3 倍,我们可以看到热量散失过程中材料属性的差异。

注意,我们还可以通过在模拟循环内向网格温度层中添加热量来迭代地增加热量。

图 6: 黄铜与钢材焊接的模拟。

布鲁纳

在这里,通过向网格中添加一个圆形热斑来模拟炉灶顶部的加热。热量逐渐添加(每秒 3°C),直到炉灶关闭并允许其冷却。

图 7: 模拟炉灶顶部加热并关闭的过程。

完整代码

import numpy as np
import matplotlib.pyplot as plt
import imageio

def getNeighborsTemperature(grid, point, boundaryTemp):
    neighbors = []
    for i in range(-1,2):
        for j in range(-1,2):
            try:
                if abs(i) != abs(j):
                    neighbors.append(grid[point[0] + i][point[1]+ j])
            except:
                neighbors.append(boundaryTemp)

    return neighbors

def calculateHeat(cellTemp, neighborTemps, timeStep, thermalDiffusivity):
    cellTemp = cellTemp + timeStep*thermalDiffusivity*((neighborTemps[0] -2*cellTemp + neighborTemps[-1]) + 
                                                       (neighborTemps[1] -2*cellTemp + neighborTemps[-2]))
    return cellTemp

def heatTransfer(grid, timeStep, boundaryTemp):

    for i in range(0,len(grid)):
        for j in range(0,len(grid)):
            neighbors = getNeighborsTemperature(grid[:,:,0], (i,j), boundaryTemp)
            grid[:,:,0][i][j] = calculateHeat(grid[:,:,0][i][j], neighbors, timeStep, grid[:,:,1][i][j])

    return grid

def makeGif(frames,name,minTemp,maxTemp):
    !mkdir frames

    counter=0
    images = []
    for i in range(0,len(frames)):
        plt.figure()
        plt.imshow(frames[i], cmap = "inferno", vmin = minTemp, vmax = maxTemp)
        plt.savefig("frames/" + str(counter)+ ".png")
        images.append(imageio.imread("frames/" + str(counter)+ ".png"))
        counter += 1
        plt.close()

    imageio.mimsave(name, images)

    !rm -r frames

#Make mesh grid
resolution = 100
startingTemperature = 20
thermalDiffusivity = 1.11*10**-2

x1 = np.full(resolution, 20)
y1 = np.full(resolution, 20)

x2 = np.full(resolution, 1.11*10**-2)
y2 = np.full(resolution, 1.11*10**-2)

gridX, gridY = np.meshgrid(x1,y1)
gridX2, gridY2 = np.meshgrid(x2,y2)

fullGridX = np.dstack([gridX,gridX2])
fullGridY = np.dstack([gridY,gridY2])

#Add heat
fullGridX[:,:,0][35:65,35:65] += 1000

#Run simulation
frames = [fullGridX[:,:,0].copy()]

timeStep = 30
boundaryTemp = 20

for t in range(0,500):
    fullGridX = heatTransfer(fullGridX.copy(),timeStep,boundaryTemp)
    frames.append(fullGridX[:,:,0].copy())

#Make Gif, saves as "simulation.gif"
makeGif(frames,"simulation.gif",20,500)

参考文献

[1] 有限差分示例:1D 显式热方程 geodynamics.usc.edu/~becker/teaching/557/problem_sets/problem_set_fd_explicit.pdf

[2] 使用 2D 有限差分法进行热传递分析 resources.system-analysis.cadence.com/blog/msa2022-using-the-2d-finite-difference-method-for-heat-transfer-analysis

[3] 热传导方程 cecs.wright.edu/~sthomas/htchapter02.pdf

[4] 除非另有引用,本文中的所有图形均由作者创建。

仿真 104:使用向量场的电磁映射

原文:towardsdatascience.com/simulation-104-electromagnetic-mapping-with-vector-fields-96ab3d5e7637

建模电场和磁场

Le NguyenTowards Data Science Le Nguyen

·发布于 Towards Data Science ·阅读时间 12 分钟·2023 年 8 月 7 日

--

水、火、空气和泥土,磁铁是如何工作的?这不是奇迹,而是科学!我们都曾玩过磁铁,无论是在冰箱上还是在科学课堂上,但我们可能不完全理解磁铁的真正含义或功能。本文将学习电场和磁场背后的基本理论,并学习如何构建一个计算框架来对其进行建模。

图 1:电场示例

在本文中我们将:

  • 学习基本的电磁(EM)理论

  • 创建向量场

  • 使用向量场绘制电磁场

电场和磁场

电磁学是宇宙的四种基本力之一。它是支配带电粒子行为的力量,电场和磁场是这种力量表现出来的方式。在本节中,我们将深入探讨这些场背后的理论。

电场

电场是带电粒子固有的。这就是为什么带电粒子可以相互排斥和吸引的原因。按照惯例,我们说正电荷粒子具有指向外部的电场,而负电荷粒子具有指向内部的电场,如图 2 所示。在吸引的情况下,当一个负电荷和一个正电荷粒子靠近时,场线从正电荷开始并终止于负电荷,如图 1 所示。

图 2:电场线对于正、负和中性电荷粒子。

布朗定律所描述的电荷粒子之间的相互作用力由库仑定律给出。库仑定律指出,电荷之间的力与每个粒子的电荷成正比,并且与它们之间的距离成反比。具体的方程见下方的方程 1。

方程 1:库仑定律

我们现在了解了电场的基本知识,可以继续研究磁场。

磁场

磁场 比电场更复杂一些,因为它们是由移动的电荷产生的。这可以是由于电磁铁中的电流,或者由于铁磁材料中的电子自旋产生的固有磁偶极子。在任何情况下,磁场是施加在磁铁(和移动电荷)之间的力。磁场更复杂,因为与带电粒子不同,没有磁单极子。也就是说,没有正负磁铁,每个磁铁都有一个北极和一个南极,这两个极分别吸引其他磁铁的相反极。这意味着在确定磁场时需要考虑磁铁的方向和几何形状。

图 3:圆柱形磁铁的磁场线

由于磁场的复杂性,我们将考虑最简单的情况:磁铁足够小或相距足够远,以至于它们的几何形状不重要。在这种情况下,磁铁可以被看作是电荷粒子,或者更确切地说,就像图 1 所示的正负电荷紧挨在一起。这种简化的设置被称为磁偶极子-偶极子相互作用,如果我们假设磁铁沿同一方向对齐,它将进一步简化。

在所有这些简化假设下,我们得到了方程 2 描述的磁场,并且磁场施加的力由方程 3 描述。

方程 2:磁偶极子-偶极子场

方程 3:磁偶极子-偶极子力

理解电场和磁场后,我们将继续研究矢量场;这是一种计算方案,将帮助我们建模这些场。

注意:为了简洁和简单性,本介绍文章对电磁场理论进行了许多简化。更高级的解释将在未来的文章中给出。

矢量场

向量是数量的数学表示,这些数量还具有相关的方向。例如,我可以告诉你外面的风速是 25 英里每小时,这是与风速相关的数量或标量,但如果我告诉你风速是 25 英里每小时向东北方向,你现在就有了风速向量,因为你知道风的强度或大小以及风吹的方向。

继续使用这个风速类比,向量只能表示空间中的一个点的风,但假设我们想知道风在整个区域的表现。为此,我们需要一个向量场,这类似于我们在上一篇文章中讨论的网格,只不过我们现在在网格中查看的是向量,而不是点。下面的图 4 展示了一个向量场的例子。如我们所见,存在许多均匀间隔的向量,每个向量都有自己的方向和大小。视觉上,向量的大小或强度由其尺寸表示,因此较大的向量表示其给定方向上的风速更高。

图 4: 向量场的例子

向量场是建模电磁场的完美计算工具,因为正如我们在前一部分看到的那样,它们具有强度和方向。我们现在将深入探讨如何使用向量场来建模电场和磁场。

电磁场映射

为了创建我们的向量场,我们将使用matplotlib 的箭头函数。这个函数将在给定的 x,y 位置绘制一个向量,以及一个 x,y 大小。结合numpy 的网格函数,我们可以通过制作向量网格来创建一个向量场。现在,让我们开始吧。

电场

首先,让我们创建一个来表示带电粒子,我们可以调用它来生成任意数量的粒子。该类将接收一个位置和一个给定的电荷量来初始化一个粒子。

class charge():

    def init(self, position, charge):
        self.position = np.array(position)
        self.charge = charge

现在让我们编写一个库仑定律的函数,该函数将接收库仑常数、粒子的电荷量以及到某个观察点的距离。请注意,此时我们使用库仑定律来查看由于一个粒子而在某个观察点的场强度,因此第二个粒子的电荷量(尚未)不需要。还要注意,我们用* r 除以 r *的大小来表示我们的方向单位向量。

def coulombs_law(k,q,r):
    F = k*q*r/(np.linalg.norm(r))**3
    return F

让我们映射单个带电粒子周围的电场。我们将围绕粒子设置一个观察点的环,并查看场强(大小)和方向。我们的结果将在图 5 中显示。

#Define coulomb's constant and charge of our particle
k = 8.9e9                           
q = 0.1e-6    

#Make our charged particle
particle = charge()
particle.init([0,0],q)

#Define our source and observation points
source = np.array(particle.position)
observations = []
r = []
field = []
#Need to scale our vectors to always have a visible magnitude
visual_scale = []

#Loop over points in a circle around our charge
for i in np.linspace(0,2*np.pi,20):
    observations.append(np.array([0.1*np.cos(i),0.1*np.sin(i)]))          
    r.append(observations[-1] - source)                
    field.append(coulombs_law(k,q,r[-1]))  
    visual_scale.append(np.linalg.norm(r[-1])/np.linalg.norm(field[-1])/2)  

#Plot our particle and the vector field around it
fig = plt.figure(figsize=(5,5))
plt.plot(particle.position[0],particle.position[1],'ro')

for j in range(0, len(observations)):
    plt.arrow(observations[j][0],observations[j][1],visual_scale[j]*field[j][0],visual_scale[j]*field[j][1])  

plt.title("Positively Charged Particle"); 

图 5:模拟的正电荷粒子的电场

我们已经成功模拟了单个带电粒子周围的电场,这很酷,但让我们模拟一些更有趣的东西。比如在图 1 中,一个负电荷旁边有一个正电荷?现在我们将有 2 个带电粒子,它们的电荷大小相等但方向相反,因此我们必须跟踪两个粒子的电场,并查看它们如何叠加到整体场中。这将要求我们找出每个粒子到给定观察点的 2 个距离r,然后找到每个粒子的场的向量贡献,并将它们相加,以获得该观察点的整体场。

在这里,我们将使用观察点的网格来创建一个向量场。有关电荷偶极子的结果,请参见图 6。

#Define coulomb's constant and charge of our particle
k = 8.9e9                           
q = 0.1e-6  

#Make our charges
q1 = charge()
q2 = charge()
q1.init([0.4,0],q)
q2.init([-0.4,0],-q)

#Make a meshgrid of observation points
x = np.linspace(-.8,.8,20)
X,Y = np.meshgrid(x,x)
grid =[]
for i in range(0, len(X)):
    for j in range(0, len(Y)):
        grid.append([X[i][j],Y[i][j]])

#Define our needed quantities
observations = np.array(grid)  
r1 = []
r2 = []
field1 = []
field2 = []
field_total = []

#Loop through observation points
for n in range(0,len(observations)):

    r1.append(observations[n] - q1.position)               
    r2.append(observations[n] - q2.position)

    field1.append(coulombs_law(k,q1.charge,r1[-1]))  
    field2.append(coulombs_law(k,q2.charge,r2[-1]))
    field_total.append(field1[-1]+field2[-1])

#Scale x and y
field_total = np.array(field_total)
scale_x = .3/max(field_total[:,0])
scale_y = .3/max(field_total[:,1])

#Plot vector field
fig = plt.figure(figsize=(5,5))

plt.plot(q1.position[0],q1.position[1],'ro', label = "Positive")
plt.plot(q2.position[0],q2.position[1],'bo', label = "Negitive")

for j in range(0, len(observations)):
    plt.arrow(observations[j][0],observations[j][1],scale_x*field_total[j][0],scale_y*field_total[j][1], head_width = .025)  

plt.title("Electric Field of Positive and Negative Charge")
plt.show()

图 6:模拟的电荷偶极子的电场

我们现在已经模拟了如图 1 所示的正负电荷的电场。如我们所见,所有向量的箭头都指向正确的方向;从正电荷(红色)到负电荷(蓝色)。由于观察点靠近带电粒子,存在一些视觉噪声。由于库仑定律的制定方式,当距离粒子接近 0 时,场强会趋于无穷大;因此,我们会在靠近电荷的点看到大的向量。我们可以选择对它们进行归一化、去除,或仅保留作为合理性检查。

现在我们已经模拟了电场,是时候转向磁场了。

磁场

当我们将磁场情景简化为类似电场时,代码看起来会非常相似。我们将定义一个磁铁类和一个偶极-偶极相互作用函数。注意,与方程 2 不同,我使用了 x 和 y 坐标而不是距离r。这样可以使计算更清晰,因为与粒子电荷不同,偶极子(m)本身是一个具有 x,y 分量的向量。还需注意的是,我们这里将映射磁场而不是磁力。磁场施加的磁力取决于带电粒子/其他磁体如何通过它,这超出了本文的范围。

class magnet():

    def init(self, position, dipole):
        self.position = np.array(position)
        self.dipole = dipole

def dipole_dipole(u_0,m,x,y):

    r = np.sqrt(x**2 + y**2)
    r_hat_x = x / r
    r_hat_y = y / r

    factor = (u_0 / (4 * np.pi)) * (3 * (m[0] * r_hat_x + m[1] * r_hat_y))

    Bx = factor * r_hat_x - m[0]
    By = factor * r_hat_y - m[1]

    Bx /= r**3
    By /= r**3

    return Bx, By

现在让我们模拟磁场。在这里,我们将模拟电子的磁场并使用其偶极子。

#Define magnetic constant and dipole of our magnet
u_0 = 4*np.pi*10**-7
#dipole on an electron
m = (-9.28*10**-24,0)   

#Make our magnets
m1 = magnet()
m1.init([.1,0],m)

#Make a meshgrid of observation points
x = np.linspace(-.5,.5,20)
X,Y = np.meshgrid(x,x)
grid =[]
for i in range(0, len(X)):
    for j in range(0, len(Y)):
        grid.append([X[i][j],Y[i][j]])

#Define our needed quantities
observations = np.array(grid)  
field1 = []

#Loop through observation points
for n in range(0,len(observations)):

    x = observations[n][0] - m1.position[0]               
    y = observations[n][0] - m1.position[1]

    field1.append(dipole_dipole(u_0,m1.dipole,x,y))  

#Scale x and y
field_total = np.array(field1)
scale_x = .1/max(field_total[:,0])
scale_y = .1/max(field_total[:,1])

#Plot vector field
fig = plt.figure(figsize=(5,5))

plt.plot(m1.position[0],m1.position[1],'rs')

for j in range(0, len(observations)):
    plt.arrow(observations[j][0],observations[j][1],scale_x*field1[j][0],scale_y*field1[j][1], head_width = .02)  

plt.title("Magnetic Field of an Electron")

图 7:模拟的电子的磁场

在这里我们看到的是我们模拟的电子磁场。它看起来有点奇怪,因为我们只在二维空间中工作,所以我们实际上只看到磁场的一个垂直切片。在三维中,我们会看到矢量向我们弯曲并绕着电子缠绕

我们也会看到两个电子相同的效果。

#Define coulomb's constant and charge of our particle
u_0 = 4*np.pi*10**-7                          
m = (-9.28*10**-24,0)   

#Make our charges
m1 = magnet()
m2 = magnet()
m1.init([0,.2],m)
m2.init([0,0],m)

#Make a meshgrid of observation points
x = np.linspace(-1,1,30)
X,Y = np.meshgrid(x,x)
grid =[]
for i in range(0, len(X)):
    for j in range(0, len(Y)):
        grid.append([X[i][j],Y[i][j]])

#Define our needed quantities
observations = np.array(grid)  
field1 = []
field2 = []
field_total = []

#Loop through observation points
for n in range(0,len(observations)):

    x1 = observations[n][0] - m1.position[0]               
    y1 = observations[n][0] - m1.position[1]

    x2 = observations[n][0] - m2.position[0]               
    y2 = observations[n][0] - m2.position[1]

    field1.append(dipole_dipole(u_0,m1.dipole,x1,y1))  
    field2.append(dipole_dipole(u_0,m2.dipole,x2,y2))
    field_total.append(field1[-1]+field2[-1])

#Scale x and y
field_total = np.array(field_total)
scale_x = .1/max(field_total[:,0])
scale_y = .1/max(field_total[:,1])

#Plot vector field
fig = plt.figure(figsize=(5,5))

plt.plot(m1.position[0],m1.position[1],'ro', label = "Positive")
plt.plot(m2.position[0],m2.position[1],'bo', label = "Negitive")

for j in range(0, len(observations)):
    plt.arrow(observations[j][0],observations[j][1],scale_x*field_total[j][0],scale_y*field_total[j][1], head_width = .025)  

plt.title("Magnetic Field of 2 Electrons")

图 8:两个电子的磁场(它们通过不同的颜色表示,但具有相同的偶极子)

这可能不是最直观的结果,但在未来的文章中,我们将进行三维电磁场映射,并结合几何学来更好地定义场。

完整代码

电场

import numpy as np
import matplotlib.pyplot as plt

class charge():

    def init(self, position, charge):
        self.position = np.array(position)
        self.charge = charge

def coulombs_law(k,q,r):
    F = k*q*r/(np.linalg.norm(r))**3
    return F

#Define coulomb's constant and charge of our particle
k = 8.9e9                           
q = 0.1e-6  

#Make our charged particle
particle = charge()
particle.init([0,0],q)

#Define our source and observation points
source = np.array(particle.position)
observations = []
r = []
field = []
#Need to scale our vectors to always have a visible magnitude
visual_scale = []

#Loop over points in a circle around our charge
for i in np.linspace(0,2*np.pi,20):
    observations.append(np.array([0.1*np.cos(i),0.1*np.sin(i)]))          
    r.append(observations[-1] - source)                
    field.append(coulombs_law(k,q,r[-1]))  
    visual_scale.append(np.linalg.norm(r[-1])/np.linalg.norm(field[-1])/2)  

#Plot our particle and the vector field around it
fig = plt.figure(figsize=(5,5))
plt.plot(particle.position[0],particle.position[1],'ro')

for j in range(0, len(observations)):
    plt.arrow(observations[j][0],observations[j][1],visual_scale[j]*field[j][0],visual_scale[j]*field[j][1])  

plt.title("Positively Charged Particle");
#Make our charges
q1 = charge()
q2 = charge()
q1.init([0.4,0],q)
q2.init([-0.4,0],-q)

#Make a meshgrid of observation points
x = np.linspace(-.8,.8,20)
X,Y = np.meshgrid(x,x)
grid =[]
for i in range(0, len(X)):
    for j in range(0, len(Y)):
        grid.append([X[i][j],Y[i][j]])

#Define our needed quantities
observations = np.array(grid)  
r1 = []
r2 = []
field1 = []
field2 = []
field_total = []

#Loop through observation points
for n in range(0,len(observations)):

    r1.append(observations[n] - q1.position)               
    r2.append(observations[n] - q2.position)

    field1.append(coulombs_law(k,q1.charge,r1[-1]))  
    field2.append(coulombs_law(k,q2.charge,r2[-1]))
    field_total.append(field1[-1]+field2[-1])

#Scale x and y
field_total = np.array(field_total)
scale_x = .3/max(field_total[:,0])
scale_y = .3/max(field_total[:,1])

#Plot vector field
fig = plt.figure(figsize=(5,5))

plt.plot(q1.position[0],q1.position[1],'ro', label = "Positive")
plt.plot(q2.position[0],q2.position[1],'bo', label = "Negitive")

for j in range(0, len(observations)):
    plt.arrow(observations[j][0],observations[j][1],scale_x*field_total[j][0],scale_y*field_total[j][1], head_width = .025)  

plt.title("Electric Field of Positive and Negative Charge") 

磁场

import numpy as np
import matplotlib.pyplot as plt
class magnet():

    def init(self, position, dipole):
        self.position = np.array(position)
        self.dipole = dipole

def dipole_dipole(u_0,m,x,y):

    r = np.sqrt(x**2 + y**2)
    r_hat_x = x / r
    r_hat_y = y / r

    factor = (u_0 / (4 * np.pi)) * (3 * (m[0] * r_hat_x + m[1] * r_hat_y))

    Bx = factor * r_hat_x - m[0]
    By = factor * r_hat_y - m[1]

    Bx /= r**3
    By /= r**3

    return Bx, By

#Define magnetic constant and dipole of our magnet
u_0 = 4*np.pi*10**-7
#dipole on an electron
m = (-9.28*10**-24,0)   

#Make our magnets
m1 = magnet()
m1.init([.1,0],m)

#Make a meshgrid of observation points
x = np.linspace(-.5,.5,20)
X,Y = np.meshgrid(x,x)
grid =[]
for i in range(0, len(X)):
    for j in range(0, len(Y)):
        grid.append([X[i][j],Y[i][j]])

#Define our needed quantities
observations = np.array(grid)  
field1 = []

#Loop through observation points
for n in range(0,len(observations)):

    x = observations[n][0] - m1.position[0]               
    y = observations[n][0] - m1.position[1]

    field1.append(dipole_dipole(u_0,m1.dipole,x,y))  

#Scale x and y
field_total = np.array(field1)
scale_x = .1/max(field_total[:,0])
scale_y = .1/max(field_total[:,1])

#Plot vector field
fig = plt.figure(figsize=(5,5))

plt.plot(m1.position[0],m1.position[1],'rs')

for j in range(0, len(observations)):
    plt.arrow(observations[j][0],observations[j][1],scale_x*field1[j][0],scale_y*field1[j][1], head_width = .02)  

plt.title("Magnetic Field of an Electron")
#Make our charges
m1 = magnet()
m2 = magnet()
m1.init([0,.2],m)
m2.init([0,0],m)

#Make a meshgrid of observation points
x = np.linspace(-1,1,30)
X,Y = np.meshgrid(x,x)
grid =[]
for i in range(0, len(X)):
    for j in range(0, len(Y)):
        grid.append([X[i][j],Y[i][j]])

#Define our needed quantities
observations = np.array(grid)  
field1 = []
field2 = []
field_total = []

#Loop through observation points
for n in range(0,len(observations)):

    x1 = observations[n][0] - m1.position[0]               
    y1 = observations[n][0] - m1.position[1]

    x2 = observations[n][0] - m2.position[0]               
    y2 = observations[n][0] - m2.position[1]

    field1.append(dipole_dipole(u_0,m1.dipole,x1,y1))  
    field2.append(dipole_dipole(u_0,m2.dipole,x2,y2))
    field_total.append(field1[-1]+field2[-1])

#Scale x and y
field_total = np.array(field_total)
scale_x = .1/max(field_total[:,0])
scale_y = .1/max(field_total[:,1])

#Plot vector field
fig = plt.figure(figsize=(5,5))

plt.plot(m1.position[0],m1.position[1],'ro', label = "Positive")
plt.plot(m2.position[0],m2.position[1],'bo', label = "Negitive")

for j in range(0, len(observations)):
    plt.arrow(observations[j][0],observations[j][1],scale_x*field_total[j][0],scale_y*field_total[j][1], head_width = .025)  

plt.title("Magnetic Field of 2 Electrons")

参考文献

  1. 本文中使用的所有图形要么由作者生成,要么符合创意共享许可证 CC BY-SA,正如原始图像创作者所声明的那样。

  2. 《电动力学导论》第四版 hansandcassady.org/David%20J.%20Griffiths-Introduction%20to%20Electrodynamics-Addison-Wesley%20(2012).pdf

posted @ 2024-10-12 19:54  绝不原创的飞龙  阅读(152)  评论(0)    收藏  举报