docs-merge-13

TowardsDataScience 2024 中文翻译(十四)

原文:TowardsDataScience

协议:CC BY-NC-SA 4.0

精通样本大小计算

原文:towardsdatascience.com/mastering-sample-size-calculations-75afcddd2ff3?source=collection_archive---------6-----------------------#2024-10-09

A/B 测试、拒绝推断以及如何为实验获取正确的样本大小

Lucas BragaTowards Data Science Lucas Braga

·发表于 Towards Data Science ·阅读时间 17 分钟·2024 年 10 月 9 日

--

图片由作者制作

针对不同的场景,有不同的统计公式。首先要问的问题是:你是在比较两组数据,比如在A/B 测试中,还是在从一个足够大的群体中选择样本,以使其能够代表该群体?

后者通常用于像交易中的保留组这样的场景。这些保留组对于评估欺诈预防规则的性能或拒绝推断非常关键,在这种情况下,欺诈检测的机器学习模型会重新训练。保留组的优势在于,它包含了没有被任何规则或模型阻止的交易,从而提供了一个不偏不倚的性能视角。然而,为了确保保留组具有代表性,你需要选择一个能准确反映群体的样本大小,这一点将在本文中与 A/B 测试的样本大小一起探讨。

在确定是否比较两组(如 A/B 测试)或选择一个代表性样本(如拒绝推断)之后,下一步是定义你的成功指标。它是一个比例还是一个绝对数值?例如,比较两个比例可能涉及转化率或违约率,其中违约交易的数量除以交易总数。另一方面,比较两个均值适用于处理绝对值的情况,比如总收入或 GMV(商品交易总值)。在这种情况下,你会比较每个客户的平均收入,假设你的实验中存在客户级别的随机化。

1. 比较两组(例如 A/B 测试)——样本大小

第 1.1 节是关于比较两个均值的,但这里提出的大多数原理在第 1.2 节中也适用。

1.1. 比较两个均值(度量绝对数值的平均值)

在这种情况下,我们比较两个组:一个控制组和一个处理组。控制组由通过借贷计划获得 100 欧元信用额度的客户组成,而处理组由通过同一计划获得 200 欧元信用额度的客户组成。

实验的目标是确定是否增加信用额度会导致客户支出增加。

我们的成功度量定义为每位客户每周的平均支出金额,以欧元为单位。

在确立目标和成功度量后,在典型的 A/B 测试中,我们还会定义假设、随机化单元(在本例中为客户)以及目标人群(获得信用的新客户)。然而,由于本文的重点是样本量,我们将在这里不讨论这些细节。

我们将比较控制组和处理组之间的每位客户每周的平均支出。接下来,我们将使用以下脚本来计算这个度量:

脚本 1:计算成功度量,分支:德国,周期:2024 年 5 月 1 日至 2024 年 7 月 31 日。

WITH customer_spending AS (
SELECT
 branch_id,
 FORMAT_DATE('%G-%V', DATE(transaction_timestamp)) AS week_of_year,
 customer_id,
 SUM(transaction_value) AS total_amount_spent_eur
FROM `project.dataset.credit_transactions`
WHERE 1=1
 AND transaction_date BETWEEN '2024-05-01' AND '2024-07-31'
 AND branch_id LIKE 'Germany'
GROUP BY branch_id, week_of_year, customer_id
)
, agg_per_week AS (
SELECT
 branch_id,
 week_of_year,
 ROUND(AVG(total_amount_spent_eur), 1) AS avg_amount_spent_eur_per_customer,
FROM customer_spending
GROUP BY branch_id, week_of_year
)
SELECT *
FROM agg_per_week
ORDER BY 1,2;

在结果中,我们观察到每周的度量avg_amount_spent_eur_per_customer。在过去的四周里,数值保持相对稳定,介于 35 欧元和 54 欧元之间。然而,考虑到过去两个月的所有周,方差较大。(请参见图像 1 作为参考。)

图像 1:脚本 1 的结果。

接下来,我们计算成功度量的方差。为此,我们将使用脚本 2来计算所有周的方差和平均值。

脚本 2:查询计算成功度量的方差,并计算所有周的平均值。

WITH customer_spending AS (
SELECT
 branch_id,
 FORMAT_DATE('%G-%V', DATE(transaction_timestamp)) AS week_of_year,
 customer_id,
 SUM(transaction_value) AS total_amount_spent_eur
FROM `project.dataset.credit_transactions`
WHERE 1=1
 AND transaction_date BETWEEN '2024-05-01' AND '2024-07-31'
 AND branch_id LIKE 'Germany'
GROUP BY branch_id, week_of_year, customer_id
)
, agg_per_week AS (
SELECT
 branch_id,
 week_of_year,
 ROUND(AVG(total_amount_spent_eur), 1) AS avg_amount_spent_eur_per_customer,
FROM customer_spending
GROUP BY branch_id, week_of_year
)
SELECT
 ROUND(AVG(avg_amount_spent_eur_per_customer),1) AS avg_amount_spent_eur_per_customer_per_week,
 ROUND(VAR_POP(avg_amount_spent_eur_per_customer),1) AS variance_avg_amount_spent_eur_per_customer
FROM agg_per_week
ORDER BY 1,2;

脚本 2的结果显示方差约为 145.8(请参见图像 2)。此外,考虑到过去两个月的所有周,每位用户的平均消费金额49.5 欧元

图像 2:脚本 2 的结果。

现在我们已经计算出度量并发现每位客户的平均每周支出约为49.5 欧元,我们可以定义最小可检测效应(MDE)。鉴于信用从 100 欧元增加到 200 欧元,我们旨在检测10%的支出增加,这相当于每位客户每周的新的平均值为54.5 欧元

计算方差(145.8)并确定 MDE 后,我们可以将这些值代入公式中,计算所需的样本量。我们将使用默认值alpha(5%)beta(20%)

  • 显著性水平(Alpha 的默认值为α = 5%):Alpha 是一个预定的阈值,用作拒绝零假设的标准。Alpha 是 I 类错误(假阳性),p 值需要低于 Alpha,才能拒绝零假设。

  • 统计功效(Beta 的默认值为β = 20%):这是测试在替代假设为真时正确拒绝零假设的概率,即在效应存在时检测到效应。统计功效 = 1 — β,β是 II 类错误(假阴性)。

这是计算在典型的 A/B 测试场景中每组(对照组和治疗组)所需样本量的公式:

图像 3:计算比较两个均值时样本量的公式。

  • n是每组的样本量。

  • σ²是正在测试的度量的方差(在本例中为145.8)。因子 2σ²被使用,因为我们计算的是合并方差,这样在比较两个样本时不会产生偏差。

  • δ(德尔塔)表示最小可检测的均值差异(效应大小),即我们希望检测的变化。其计算公式为:δ² = (μ₁ — μ₂)²,其中μ₁是对照组的均值,μ₂是治疗组的均值。

  • Zα/2是与相应置信水平对应的z 值(例如,1.96对应于95%置信水平)。

  • 是与所需测试功效相关的z 值(例如,0.84对应于80%功效)。

n = (2 * 145.8 * (1.96+0.84)²) / (54.5-49.5)²
-> n = 291.6 * 7.84 / 25
-> n = 2286.1 / 25
-> n =~ 92

在我的网页应用程序计算器中尝试,样本量计算器,如应用截图 1所示:

  • 置信水平:95%

  • 统计功效:80%

  • 方差:145.8

  • 需要检测的差异(德尔塔):5(因为预期的变化是从€49.50 到€54.50)

应用截图 1:计算比较两个均值的样本量。

根据之前的计算,我们需要92 个用户在对照组中,92 个用户在治疗组中,总共需要184 个样本

现在,让我们探讨一下改变最小可检测效应(MDE)如何影响样本量。较小的 MDE 需要更大的样本量。例如,如果我们希望检测每个用户平均€1 的增加,而不是之前使用的€5 的增加(10%),所需的样本量将显著增加。

MDE 越小,测试需要越敏感,这意味着我们需要更大的样本量才能可靠地检测到如此微小的效应。

n = (2 * 145.8 * (1.96+0.84)²) / (50.5-49.5)²
-> n = 291.6 * 7.84 / 1
-> n = 2286.1 / 1
-> n =~ 2287

我们将以下参数输入到样本量计算器的网页应用程序计算器中,如应用截图 2所示:

  • 置信水平:95%

  • 统计功效:80%

  • 方差:145.8

  • 需要检测的差异(德尔塔):1(因为预期的变化是从€49.50 到€50.50)

应用截图 2:计算 Delta = 1 时比较两个均值的样本量。

为了检测更小的效应,例如每个用户增加 1 欧元,我们需要2,287 个用户在对照组,2,287 个用户在处理组,总共需要4,574 个样本

接下来,我们将调整统计功效显著性水平,重新计算所需的样本量。但首先,让我们看一下z 值表,以了解如何推导出z 值

我们已经设置了β = 0.2,这意味着当前的统计功效为80%。参考 z 值表(见图像 4),这对应于z 值为 0.84,这是我们之前公式中使用的值。

图像 4:在 z 值表中查找统计功效为 80% 时的 z 值。

如果我们现在将β值调整为 10%,即统计功效为 90%,我们将得到z 值为 1.28。该值可以在 z 值表中找到(见图像 5)。

n = (2 * 145.8 * (1.96+1.28)²) / (50.5-49.5)²
-> n = 291.6 * 10.49 / 1
-> n = 3061.1 / 1
-> n =~ 3062

通过调整β值为 10%(统计功效为 90%)并使用z 值为 1.28,我们现在需要3,062 个用户在对照组和处理组中,每组总共6,124 个样本

图像 5:在 z 值表中查找统计功效为 90% 时的 z 值。

现在,让我们来确定6,124 个样本所代表的流量。我们可以通过计算每周不同客户的平均交易量来得出这个数据。脚本 3将帮助我们获取2024 年 5 月 1 日到 2024 年 7 月 31 日期间的信息。

脚本 3:计算不同客户每周平均交易量的查询。

WITH customer_volume AS (
SELECT
 branch_id,
 FORMAT_DATE('%G-%V', DATE(transaction_timestamp)) AS week_of_year,
 COUNT(DISTINCT customer_id) AS cntd_customers
FROM `project.dataset.credit_transactions`
WHERE 1=1
 AND transaction_date BETWEEN '2024-05-01' AND '2024-07-31'
 AND branch_id LIKE 'Germany'
GROUP BY branch_id, week_of_year
)
SELECT
 ROUND(AVG(cntd_customers),1) AS avg_cntd_customers
FROM customer_volume;

脚本 3的结果显示,平均每周有185,443 个不同客户(见图像 5)。因此,6,124 个样本大约占总每周客户基数的3.35%

图像 5:脚本 3 的结果。

1.2. 比较两个比例(例如转化率、违约率)

虽然前一节讨论的大多数原则保持不变,但比较两种比例的公式有所不同。这是因为,我们不再预先计算指标的方差,而是将重点放在每个组的预期成功比例上(见图像 6)。

图像 6:计算比较两种比例的样本量的公式。

让我们回到相同的场景:我们正在比较两个组。对照组由可以获得100 欧元信贷的客户组成,信贷借款计划中的处理组由可以获得200 欧元信贷的客户组成。

这次,我们关注的成功指标是违约率。这可能是1.1 节中讨论的同一个实验的一部分,其中违约率作为护栏指标,也可能是完全独立的实验。无论是哪种情况,假设是给予客户更多信用可能会导致更高的违约率。

本次实验的目标是确定信用额度的增加是否会导致更高的违约率

我们将成功指标定义为实验周内所有客户的平均违约率。理想情况下,实验应持续更长时间以收集更多数据,但如果无法实现这一点,选择一个没有偏差的星期就显得尤为重要。你可以通过分析过去12 到 16 周的违约率来验证这一点,从中找出与某些月份的特定星期相关的模式。

让我们来查看数据。脚本 4将显示每周违约率,结果可以在图像 7中看到。

脚本 4:查询以获取每周违约率。

SELECT
 branch_id,
 date_trunc(transaction_date, week) AS week_of_order,
 SUM(transaction_value) AS sum_disbursed_gmv,
 SUM(CASE WHEN is_completed THEN transaction_value ELSE 0 END) AS sum_collected_gmv,
 1-(SUM(CASE WHEN is_completed THEN transaction_value ELSE 0 END)/SUM(transaction_value)) AS default_rate,
FROM `project.dataset.credit_transactions`
WHERE transaction_date BETWEEN '2024-02-01' AND '2024-04-30'
  AND branch_id = 'Germany'
GROUP BY 1,2
ORDER BY 1,2;

从违约率指标来看,我们注意到一些波动,尤其是在较早的几周,但过去 5 周的违约率相对稳定。过去 5 周的平均违约率为0.070

图像 7:每周违约率的结果。

现在,假设这个违约率代表了对照组。那么下一个问题是:在处理组中,什么样的违约率会被认为是不可接受的?我们可以设定一个阈值:如果处理组的违约率增加到0.075,就会被认为过高。然而,任何低于0.0749的违约率仍然是可以接受的。

违约率为0.075意味着相较于对照组的0.070,增加了大约7.2%。这个差异——7.2%——是我们的最小可检测效应 (MDE)

有了这些数据点,我们现在可以计算所需的样本大小

n = ( ((1.96+0.84)²) * ((0.070*(1-0.070) + 0.075*(1-0.075)) ) / ( (0.070-0.075)² )
-> n = 7.84 * 0.134475 / 0.000025
-> n = 1.054284 / 0.000025
-> n =~ 42,171

我们将以下参数输入到样本大小计算器的网页应用程序中,如应用截图 3所示:

  • 置信水平:95%

  • 统计功效:80%

  • 第一次比例 (p1):0.070

  • 第二次比例 (p2):0.075

应用截图 3:计算比较两种比例的样本大小。

为了检测违约率的7.2%增长(从0.0700.075),我们需要在对照组和处理组中各有42,171 个用户,因此总样本数为84,343 个

样本大小为84,343相当大!我们可能甚至没有足够的客户来进行这个分析。但让我们探讨一下为什么会这样。我们没有改变alphabeta的默认参数,这意味着我们将显著性水平保持在默认的5%,并将统计功效保持在默认的80%。正如我们之前讨论过的,我们本可以通过选择较低的显著性水平来减少假阳性的机会,或者我们可以提高统计功效来最小化假阴性的风险。

那么,是什么导致了如此大的样本量?是7.2%的 MDE吗?简短的回答是:不完全是

考虑这种替代情境:我们保持相同的显著性水平(5%)统计功效(80%)MDE(7.2%),但假设默认比例(p₁)0.23(23%)而不是0.070(7.0%)。在7.2%的 MDE下,治疗组的新的默认比例(p₂)将是0.2466(24.66%)。注意,这仍然是7.2%的 MDE,但比例明显高于0.070(7.0%)0.075(7.5%)

现在,当我们使用这些新的值p₁ = 0.23p₂ = 0.2466进行样本量计算时,结果会有所不同。接下来,我们来计算一下。

n = ( ((1.96+0.84)²) * ((0.23*(1-0.23) + 0.2466*(1-0.2466)) ) / ( (0.2466-0.23)² )
-> n = 7.84 * 0.3628 / 0.00027556
-> n = 2.8450 / 0.00027556
-> n =~ 10,325

使用新的默认比例(p₁ = 0.23p₂ = 0.2466),我们在对照组和治疗组中各需要10,325 个用户,总共需要20,649 个样本。与之前的 84,343 样本相比,这样的样本量更加可管理。然而,值得注意的是,这种情况下的默认比例处于完全不同的范围。

关键的结论是,较低的成功率(如接近7%的默认比例)需要更大的样本量。当比例较小的时候,检测到即使是适度的差异(比如 7.2%的增长)也变得更加困难,因此需要更多的数据才能实现相同的统计功效和显著性水平。

2. 抽样一个群体

这个案例与 A/B 测试场景有所不同,因为我们现在专注于从单一群体中确定样本量。目标是选取一个能够准确代表群体的样本,从而进行分析并推断结果,进而估算整个群体可能发生的情况。

即使我们不在比较两个群体,从群体中抽样(一个群体)仍然需要决定你是在估算均值还是比例。这两种情况的公式与 A/B 测试中使用的公式非常相似。

看一下图像 8图像 9。当你将图像 8图像 3(比较两个均值的样本量公式)进行对比,或者将图像 9图像 6(比较两个比例的样本量公式)进行对比时,你是否注意到了相似之处?它们确实非常相似。

图像 8:用于估算总体均值的样本量公式。

在估算均值的情况下:

  • 从图像 8 中,抽样自单一组的公式则使用了E,表示误差

  • 从图像 3 中,比较两组的公式使用了delta (δ) 来比较两者均值之间的差异。

图像 9:用于估算总体比例的样本量公式。

在估算比例的情况下:

  • 从图像 9 可以看出,单组抽样的比例公式也使用了E,代表误差

  • 从图像 6 中,比较两组的公式使用了MDE(最小可检测效应),类似于差值,用于比较两种比例之间的差异。

那么,我们什么时候使用这些公式呢?让我们探讨两个实际例子——一个用于估算均值,另一个用于估算比例

2.1. 抽样总体 — 估算均值

假设你希望更好地评估欺诈风险,为此,你打算估算每个国家每周的欺诈交易平均订单值。这可能相当具有挑战性,因为理想情况下,大多数欺诈交易已经被阻止。为了获得更清晰的图像,你将采用一个保留组,它不受任何规则和模型的影响,作为计算欺诈交易的真实平均订单值的参考。

假设你选择了一个特定的国家,并在回顾历史数据后发现:

  • 该指标的方差为€905

  • 欺诈交易的平均订单值为€100

    (你可以参考脚本 1 和 2来计算成功指标和方差。)

由于方差为€905标准差(方差的平方根)约为€30。现在,使用5%的显著性水平,对应的z-score 为 1.96,并假设你能接受10%的误差范围(代表€10的误差,或是€100 的 10%),那么在 95%的置信区间下,意味着在正确的样本量下,你可以95%的置信度断言平均值落在€90 到€110之间。

现在,将这些输入值代入样本量公式:

n = ( (1.96 * 30) / 10 )²
-> n = (58.8/10)²
-> n = 35

我们将以下参数输入到网页应用计算器中,在样本量计算器中,如应用截图 4所示:

  • 置信水平:95%

  • 方差:905

  • 误差:10

应用截图 4:在抽样一个总体时,计算估算均值的样本量。

结果是,你需要35 个样本来估算每个国家每周的欺诈交易平均订单值。然而,这还不是最终的样本量。

由于欺诈交易相对较少,你需要调整欺诈交易的比例。如果欺诈交易的比例为1%,那么你需要收集的实际样本数量为:

n = 35/0.01
-> n = 3500

因此,你需要 3,500 个样本,以确保欺诈交易得到适当的代表。

2.2. 抽样总体——估计一个比例

在这种情况下,我们的欺诈规则和模型已经阻止了大量交易。为了评估我们的规则和模型的效果,我们需要让一部分流量绕过规则和模型,以便评估实际的假阳性率。这部分未经过滤的交易被称为留存组。这是欺诈数据科学团队中的常见做法,因为它既能评估规则和模型的表现,又能将留存组用于拒绝推断

尽管我们在这里不会详细讨论拒绝推断,但值得简要总结一下。拒绝推断涉及使用未被阻止的交易的留存组数据,学习有助于改进交易阻止决策的模式。对此有多种方法,其中模糊增强是一种流行的方法。其基本思路是使用留存组的数据重新标记之前被拒绝的交易,并用这些数据训练新的模型。在欺诈建模中,这尤其重要,因为欺诈率通常较低(通常低于 1%,有时甚至低于 0.1%)。增加标记数据可以显著提升模型表现。

现在我们明白了估计比例的需求,接下来我们将深入探讨一个实际应用案例,看看需要多少样本。

对于某一分支,你分析历史数据发现该分支每月处理50,000,000 个订单,其中50,000 个订单是欺诈的,导致0.1%的欺诈率。使用5%的显著性水平(alpha)25%的误差范围,我们旨在估计真实的欺诈比例,置信区间为95%。这意味着,如果真实的欺诈率为0.001(0.1%),我们估计的范围应在0.000750.00125之间,错误为0.00025

请注意,误差范围(margin of error)和错误(Error)是两个不同的概念,误差范围是一个百分比值,而错误是一个绝对值。在欺诈率为 0.1%的情况下,如果我们的误差范围是 25%,那么它代表的错误为 0.00025。

让我们应用公式:

  • Zα/2 = 1.96(95%置信水平的 z 值)

  • E = 0.00025(错误)

  • p = 0.001(欺诈率)

Zalpha/2= 1.96 
-> (Zalpha/2)²= 3.8416
E = 0.00025
-> E² = 0.0000000625
p = 0.001

n =( 3.8416 * 0.001 * (1 - 0.001) ) / 0.0000000625
-> n = 0.0038377584 / 0.0000000625
-> n = 61,404

我们在样本大小计算器的网页应用程序中输入以下参数,如应用截图 5所示:

  • Confidence Level: 95%

  • Proportion: 0.001

  • Error: 0.00025

应用截图 5:计算估算比例的样本大小,当从一个总体中抽样时。

因此,共需 61,404 个样本。假设一个月内有50,000,000 笔交易,如果保留组代表100%的流量,收集这些样本将花费不到 1 小时。然而,这对于一个可靠的实验来说并不现实。

相反,你应该将流量分布到几天,以避免季节性问题。理想情况下,你应收集至少一周的数据,确保覆盖所有工作日,并避免假期或高峰季节。如果你需要在一周内收集61,404 个样本,那么你需要每天收集8,772 个样本。由于每日流量约为1,666,666 个订单,保留组每天需要代表0.53%的总交易量,并持续一周。

最后的备注

如果你想在 Python 中执行这些计算,以下是相关的函数:

import math

def sample_size_comparing_two_means(variance, z_alpha, z_beta, delta):
   return math.ceil((2 * variance * (z_alpha + z_beta) ** 2) / (delta ** 2))

def sample_size_comparing_two_proportions(p1, p2, z_alpha, z_beta):
   numerator = (z_alpha + z_beta) ** 2 * ((p1 * (1 - p1)) + (p2 * (1 - p2)))
   denominator = (p1 - p2) ** 2
   return math.ceil(numerator / denominator)

def sample_size_estimating_mean(variance, z_alpha, margin_of_error):
   sigma = variance ** 0.5
   return math.ceil((z_alpha * sigma / margin_of_error) ** 2)

def sample_size_estimating_proportion(p, z_alpha, margin_of_error):
   return math.ceil((z_alpha ** 2 * p * (1 - p)) / (margin_of_error ** 2))

下面是如何计算与第 1.1 节中应用截图 1 类似的两组均值比较的样本量:

variance = 145.8
z_alpha = 1.96
z_beta = 0.84
delta = 5

sample_size_comparing_two_means(
    variance=variance, 
    z_alpha=z_alpha, 
    z_beta=z_beta, 
    delta=delta
)
# OUTPUT: 92

这些功能也可以在 GitHub 仓库中找到:GitHub 样本大小计算器,你还可以在那里找到交互式样本大小计算器的链接。

免责声明:与 Google BigQuery 作业结果相似的图像由作者创建。所示数字并非基于任何商业数据,而是为了说明目的手动生成的。SQL 脚本也是如此——它们并非来自任何企业,也都是手动生成的。然而,它们旨在与使用Google BigQuery作为框架的公司可能遇到的情况高度相似。

  • 该计算器使用 Python 编写,并在 Google Cloud Run(无服务器环境)中通过 Docker 容器和 Streamlit 部署,参考请见GitHub 上的代码

🪜 精通辛普森悖论——我的因果关系入门

原文:towardsdatascience.com/mastering-simpsons-paradox-my-gateway-drug-to-causality-87e10b613a80?source=collection_archive---------3-----------------------#2024-10-07

警告:你将再也不会以同样的方式看待数据

Eyal KazinTowards Data Science Eyal Kazin

·发表于Towards Data Science ·阅读时间:18 分钟·2024 年 10 月 7 日

--

图片由Jake Hills提供,来源:Unsplash

在认识到相关性不意味着因果关系后,精通辛普森悖论是人们走向因果思维的重要一步 👣。

大多数研究人员和从业者很可能在某个阶段或另一阶段遇到过这个现象。我可以证明,在 2020 年,即使我拥有 15 年的数据分析和解读经验,我仍然没有意识到这是一个我应该关注的已知问题。

如果你还没有接触过辛普森悖论——准备好被震撼吧 🤯。如果你已经接触过,我希望能揭示一些你可能没有考虑过的方面 💡。

这是一本关于识别和解决辛普森悖论的初学者指南,帮助你避免误解数据。

为了理解本文中这个反直觉的主题,你将会:

  • 深入探讨一个定量例子

  • 考察一种定性方法——学习它在图形模型中的表现。

定量分析的重要性在于,你在分析数据时会识别出悖论并为之解决。定性角度的价值则在于……

掌握 SQL 优化:从功能性到高效的查询

原文:towardsdatascience.com/mastering-sql-optimization-from-functional-to-efficient-queries-74d8692f10be?source=collection_archive---------0-----------------------#2024-07-01

帮助我每天减少 50 小时 Snowflake 查询时间的六个简单而有效的 SQL 技巧

Yu DongTowards Data Science 于东

·发表在Towards Data Science ·阅读时长 9 分钟·2024 年 7 月 1 日

--

SQL 可能是每个数据分析师和数据科学家都应该掌握的最基本的技术技能。它通常是面试过程的一部分,我们在工作中花费大量时间编写 SQL 来收集数据。没有它,就没有分析或机器学习模型。然而,编写一个功能性查询与编写一个好的查询是不同的。

在我最近的工作中,我们的数据科学家每天在 Airflow 上安排了超过 1,000 个查询。我注意到我们的 Snowflake 积分在短短一年内增加了 50%。当我自愿检查前 10 个积分使用查询时,很快意识到存在许多优化机会,有些非常简单。本文重点介绍了六个简单而有效的技巧,帮助我每天将 SQL 查询运行时间减少了 50 小时。其中一些非常简单易行,但请记住,SQL 优化中的 20%努力会带来 80%的性能改进。

注意:

  1. 本文是为那些已经熟悉基本 SQL 语法的人编写的。如果你是一位每天写 SQL 但想改进查询(并让你的数据工程师朋友高兴)的数据分析师或数据科学家,那么这篇文章适合你!

  2. 本文主要基于我在 Snowflake SQL 方面的经验。我专注于…

精通统计检验(第二部分)

原文:towardsdatascience.com/mastering-statistical-tests-part-ii-your-guide-to-choosing-the-right-test-for-your-data-2d33deda59ad?source=collection_archive---------7-----------------------#2024-05-27

您选择适合数据的检验方法的指南

Sheref Nasereldin, Ph.D.Towards Data Science Sheref Nasereldin, Ph.D.

·发表于 Towards Data Science ·阅读时长:12 分钟·2024 年 5 月 27 日

--

图片来源:Annie Spratt 来自 Unsplash

第一部分中,我探讨了根据一些标准对统计检验进行分类的方法,例如独立变量的数量(我只关注了只有一个独立变量的情境)和因变量(同样,我只考虑了一个因变量),以及独立变量中的各个水平和这些水平的独立性。在此基础上,第二部分将分析扩展到涉及独立变量中两个以上组别(水平)的情境,甚至增加独立变量的数量。在这里,我将深入探讨专门为这些复杂配置量身定制的检验方法,并提供其有效应用的见解。

## 精通统计检验

选择适合您数据的检验方法指南(第一部分)

towardsdatascience.com

11- 单因素方差分析(ANOVA)

该测试用于确定三个或更多独立组之间的均值是否存在显著差异。它有助于确认是否至少有一个组的均值与…不同。

精通 t-SNE:理解和在 Python 中实现的综合指南

原文:towardsdatascience.com/mastering-t-sne-a-comprehensive-guide-to-understanding-and-implementation-in-python-480929bfe6f4?source=collection_archive---------5-----------------------#2024-09-20

解锁 t-SNE 在可视化高维数据中的强大功能,通过一步步的 Python 实现和深入的讲解。

Niklas LangTowards Data Science Niklas Lang

·发表于Towards Data Science ·21 分钟阅读·2024 年 9 月 20 日

--

图片来自Martin DörschUnsplash

如果要训练强大的机器学习模型,就需要具有多个维度的大型数据集,以便识别足够的结构并提供最佳的预测。然而,这类高维数据难以可视化和理解。这就是为什么需要降维方法来可视化复杂的数据结构并进行分析。

t-分布随机邻域嵌入(t-SNE/tSNE)是一种降维方法,基于数据点之间的距离,并试图在较低的维度中保持这些距离。它是一种无监督学习方法,也能够分离非线性数据,即不能通过一条线划分的数据。

为什么需要降维?

许多算法,例如线性回归,如果数据集包含相关变量(即相互依赖的变量),会遇到问题。为避免这个问题,删除数据集中相关的变量可能是有意义的…

掌握 Python 丰富图表库的多样性与深度(附代码)

原文:towardsdatascience.com/mastering-the-versatility-and-depth-of-pythons-rich-plot-collection-with-code-b136b584d143?source=collection_archive---------2-----------------------#2024-03-30

Sankey 图蜘蛛图山脊图,本文涵盖了各种可视化,每种都有其独特的应用场景和洞见。

Dr. Theophano MitsaTowards Data Science Dr. Theophano Mitsa

·发布于数据科学之路 ·11 分钟阅读·2024 年 3 月 30 日

--

图片由作者在 DALL-E 创建

让数据集改变你的思维方式。 —— 汉斯·罗斯林”

每个数据科学家都知道图表在其数据故事中至关重要。Python 开发者幸运地能使用一门提供丰富图表库的语言。本文将通过讨论一些较少见的可视化应用场景,如Sankey 图山脊图嵌套图、蜘蛛图词云图,来展示这种丰富性。我们还将讨论一些更常见的图形表示方法,如散点图条形图。大多数图表将使用MatplotlibSeabornPlotly Python 库。

我们将使用诸如形状、大小、颜色、方向、面积大小和标记符号区域等属性,创建十个不同应用场景的图表。在每个应用场景中,我们的目标是创建有效、高效且具有美感的可视化。让我们描述在图表上下文中这些词语的含义:(a) 有效: 所有需要传达的信息都包含在图表中 (b) 高效:图表中没有冗余数据 (c) 美学…

精通 CatBoost 的不确定性

原文:towardsdatascience.com/mastering-uncertainty-with-catboost-cdb330bc00cf?source=collection_archive---------5-----------------------#2024-03-29

揭示预测区间的力量

Ilia Teimouri 博士Towards Data Science Ilia Teimouri 博士

·发布于Towards Data Science ·6 分钟阅读·2024 年 3 月 29 日

--

照片由Ian Taylor提供,来源于Unsplash

预测区间在回归分析中起着至关重要的作用,尤其是在目标不仅仅是点预测,而是评估预测的不确定性或变动性时。与提供每个输入单一估计值的点预测不同,预测区间提供了一个范围,真实值预计会以一定的概率落在该范围内。这一点尤为重要,因为它考虑到了任何预测模型中的固有不确定性。通过量化这种不确定性,预测区间提供了对可能结果的更全面理解。例如,在金融预测中,了解未来回报可能波动的范围对于风险管理和投资策略至关重要。

此外,在回归模型中,追求创建最窄或最“高效”的预测区间可以提高模型输出的精度和可靠性。较窄的区间表示预测的确定性更高,前提是这些区间准确并持续捕捉真实值。

通常,预测区间的写法为:

[𝜇-𝘻𝜎, 𝜇+𝘻𝜎]

数据科学家面试中我关注的数学知识

原文:towardsdatascience.com/mathematics-i-look-for-in-data-scientist-interviews-7c7cb1aaebe5?source=collection_archive---------2-----------------------#2024-10-09

让我们重建我们的数据科学基础。

Farzad NobarTowards Data Science Farzad Nobar

·发表于Towards Data Science ·阅读时间 16 分钟·2024 年 10 月 9 日

--

图片由Deniz Altindas提供,来源于Unsplash

作为一名在过去几年里一直参与亚马逊数据科学家和应用科学家招聘的人,不论是作为招聘经理、面试官还是面试者,我逐渐意识到,虽然大多数候选人熟悉最新的机器学习模型,但他们往往缺乏扎实的数学、统计学和概率论基础知识。这种基础科学知识的缺口可能限制他们有效地定义正确的假设来解决问题、收集合适类型的样本或解释其科学工作的结果。因此,我决定创建这篇文章,解释数据科学家最常见的基础主题。好消息是,即使对这些主题有一个整体的了解,也能让我们与其他候选人区别开来。这些数学技能是理解我们作为数据科学家日常遇到的问题的强大工具。通过不断培养这些技能并进行实践,我们将在职业生涯中长期受益。

我将这些基础数学主题分为以下四个类别(它们并不总是相互排斥的):

  1. 统计学

  2. 微积分

爱的数学:使用 Python 优化婚礼宴会厅座位安排

原文:towardsdatascience.com/mathematics-of-love-optimizing-a-dining-room-seating-arrangement-for-weddings-with-python-f9c57cc5c2ce?source=collection_archive---------3-----------------------#2024-09-02

使用数学编程和 Python 解决限制性二次多重背包问题(RQMKP)

Luis Fernando PÉREZ ARMAS, Ph.D.Towards Data Science Luis Fernando PÉREZ ARMAS, Ph.D.

·发布于 Towards Data Science ·16 分钟阅读·2024 年 9 月 2 日

--

请跟随我来到屋顶接待会……(图像由 DALLE-3 生成)

Hannah Fry博士的书,《爱的数学》 [1]*,是一本难得的好书,既聪明又幽默,而且非常容易阅读。我非常喜欢,以至于把书送给了我的一些亲密朋友。他们也很喜欢,但其中有几个在第八章遇到了难题,而讽刺的是,那一章正是他们最期待的章节。

本章深入探讨了组织婚礼座位安排的棘手问题,使用数学编程和优化方法来确定如何安排座位,使每位宾客都能度过一个尽可能愉快的时光。听起来很酷,对吧?但这里有个问题,本书实际上并没有教你如何设置或解决这个问题,这让我的朋友们感到有些迷茫。

现在,我知道你们在想什么:婚礼座位安排?真的吗?但别被愚弄了,这其实是一个非常难以解决的问题,而且它的解决方案不仅限于婚礼(尽管这个问题已经很重要)。想象一下,在游轮上安排餐桌[3]、组建运动队或工作小组、优化投资组合,甚至在我作为商学院教授的世界里……

Matplotlib:通过 rcParams 让你的绘图生活更轻松

原文:towardsdatascience.com/matplotlib-make-your-plotting-life-easier-with-rcparams-d88f202fa83c?source=collection_archive---------3-----------------------#2024-02-21

可视化

使用更少的代码和最小的努力,创建始终美观的图表

Mike ClaytonTowards Data Science Mike Clayton

·发表于 Towards Data Science ·阅读时间 18 分钟 ·2024 年 2 月 21 日

--

使用 rcParams 进行样式设置的箱型图 — 图片来源:作者

Matplotlib 无疑是一个强大且可扩展的绘图库,但它也可能令人沮丧,因为它复杂且,特别是,非常冗长。

rcParams(运行时配置参数)提供了一种将一些冗余代码提取到配置文件中的方式。这样可以显著减少重复代码,并且使得绘图始终保持美观且布局良好。它还使得你的设置和样式可以复制,并且在项目之间轻松传递。

介绍

正如我在上一篇文章中提到的,文章比较了 Matplotlib 和一些其他流行的绘图库(如 Seaborn 或 plotnine):

它[Matplotlib] 在数据科学领域相当普及,可以公平地说,很多人对它有一种爱恨交织的关系。它非常灵活,但也可能变得相当复杂。

-我在上一篇文章中

“相当复杂”可以指很多事情。特别是,变得非常明显的是,所有其他库都在一个方面做得比 Matplotlib 更好。它们都能生成相当…

最大化节省未使用的 Fabric 容量

原文:towardsdatascience.com/maximize-savings-on-your-unused-fabric-capacity-4c74b12505fd?source=collection_archive---------16-----------------------#2024-07-29

使用 Azure Logic Apps 自动化您的 Microsoft Fabric 容量状态

Christabelle SantosTowards Data Science Christabelle Santos

·发表于 Towards Data Science ·阅读时间 9 分钟·2024 年 7 月 29 日

--

免责声明:本文不涉及 Microsoft Fabric 是什么或如何使用它。

Fabric 是一个统一的数据平台,提供共享的体验、架构、治理、合规性和计费。容量提供了驱动所有这些体验的计算能力。您可以在这里阅读更多关于 Fabric 容量的信息。

想象一下:您的 Fabric 试用期快要结束,而您仍然处于概念验证阶段,正在确定 Fabric 是否是适合您的选项。有效管理您的容量至关重要,因为您不希望在这些测试或评估期间为未使用的容量承担不必要的成本。

图片来源:Carlos MuzaUnsplash

如果你已经在生产工作负载中使用 Fabric,本篇文章讨论的策略对你不适用,因为你可能需要始终保持一些可用容量来支持这些工作负载。你可能更感兴趣的是了解Fabric 中的平滑和限流过程是如何工作的,以便估算成本,以及如何使用容量度量应用程序进行监控

对于仍在非生产环境中工作的你们,我将展示如何在不需要时自动暂停 Fabric 容量。如果你正在测试功能,但不会全天候进行测试,这特别有用。那么,如何确保你不会忘记暂停容量呢?

你可能会想,暂停 Fabric 容量是什么意思?这甚至可能吗?答案是肯定的,我会向你展示如何操作。

当我们在Azure 门户中创建 Fabric 容量时,我们会看到类似的界面,其中你可以选择暂停容量,这样就不会因为不使用而被收费。

但是,如果你忘记暂停它几天怎么办?

这可能会导致 Azure 账单显著增加。为了防止这种情况发生,我们将创建一种自动化的方式来暂停你的 Fabric 容量。

Azure Logic Apps 来帮忙

为了实现这一目标,我们将首先在 Azure 中创建一个逻辑应用程序,最好是在与你的 Fabric 容量相同的资源组中创建。

你可以参考这里的说明来创建你的逻辑应用程序。务必选择**Consumption**计划类型,而不是**Standard**计划类型,因为对于初学者来说,使用**Consumption**计划更为简单,而且你只需为实际使用的部分付费。

在开始设计逻辑应用程序之前,我们需要授予逻辑应用程序修改 Fabric 容量状态的权限,因为安全性至关重要。为此,进入逻辑应用程序的**设置 > 身份**选项,并启用系统分配的托管身份。

托管身份提供了一种自动管理的身份,在 Entra ID 中供应用程序在连接支持 Entra ID 认证的资源时使用,无需在代码中存储凭据。这将使新创建的逻辑应用程序能够安全地与您的 Fabric 容量交互,而无需显式凭据。

启用托管身份后,我们需要为该身份授予对 Fabric 容量的读写权限。为此,我们将返回到我们的 Fabric 容量资源。

**Add Role Assignment**页面,选择**Contributor**角色,位于**Privileged administrator roles**下,因为此角色提供对 Fabric 容量的读写访问权限。然后,点击**Next**

在下一页,选择您的 Logic App 的托管身份。确保它出现在**Selected members**下,然后点击**Select**,接着点击**Next**。在最后一页,简单地点击**Review + assign**

这样,我们就成功地为 Logic App 授予了执行任务所需的权限!现在,开始吧。我们有五个简单的步骤来构建这个工作流,最好按顺序执行。

设计 Logic App

首先,导航到**Logic app designer**,它位于您的 Logic App 资源的**Development Tools**设置下。

创建参数

在我们开始添加触发器和操作之前,我们需要定义一些参数,这些参数将在稍后的 HTTP 请求操作中使用。

我们需要按照下面提到的方式构建我们的参数,将your-subscription-idyour-resource-group-nameyour-fabric-capacity-name替换为对应的订阅 ID、资源组名称和 fabric 容量名称值。第一个参数是获取您 Fabric 容量状态的:

Name          : $getfabricstatus
Type          : String 
Default Value : https://management.azure.com/subscriptions/your-subscription-id/resourceGroups/your-resource-group-name/providers/Microsoft.Fabric/capacities/your-fabric-capacity-name?api-version=2023-11-01

第二个参数是暂停您的 Fabric 容量:

Name          : $pausefabriccapacity
Type          : String
Default Value : https://management.azure.com/subscriptions/your-subscription-id/resourceGroups/your-resource-group-name/providers/Microsoft.Fabric/capacities/your-fabric-capacity-name/suspend?api-version=2023-11-01

确保在此阶段保存 Logic App 设计,以确保您的参数已保存。

创建定期调度触发器

一旦我们创建了参数,就可以开始设计我们的 Logic App。当我们创建一个 Logic App 时,我们通常从添加触发器开始。触发器定义了 Logic App 运行的频率。您可以将其设置为每天、每周或任何适合您的需求的间隔。这确保了 Logic App 会在定期间隔检查您的 Fabric 容量状态。

为此,点击画布中心的**Add Trigger**按钮,如下图所示。

**Add a trigger**面板将打开在页面的右侧。在这里,您可以搜索“Schedule”触发器,并从两个可选项中选择**Recurrence**

然后,您可以根据自己的偏好填写所请求的信息。在下面的图像中,我已设置一个定期触发器,使其在工作日(周一至周五)每天中午 12:00 和下午 6:00(根据我的时区)激活。

添加 HTTP 操作

接下来,我们将添加一个 HTTP 操作。该操作将向 Fabric API 发出 HTTP 请求,以检查当前的容量状态。配置 HTTP 操作时,选择适当的方法(GET、POST 等)、URL 以及身份验证详细信息,以安全地访问 Fabric API。

记得我们创建托管身份并将其分配为贡献者角色吗?在这里,我们将利用贡献者角色提供的读取权限。

点击下面指示的**添加操作**,找到我们将在此步骤中使用的 HTTP 操作。

在搜索栏中搜索 HTTP 操作,然后从**添加操作**面板下三个 HTTP 操作中选择第一个选项。

点击闪电图标打开先前创建的参数下拉列表,并选择第一个参数$getfabricstatus作为 URI 字段的值。在方法字段中,选择GET

然后滚动到面板底部,在**高级参数**部分添加身份验证选项。这将在下面添加一个身份验证部分,并选择以下选项:

Authentication Type : Managed Identity
Managed Identity    : System-assigned managed identity

这样,我们就可以使用之前创建的托管身份安全地查询织物容量的状态。

添加解析 JSON 操作

一旦收到来自 Fabric API 的响应,使用解析 JSON 操作处理数据。上一步的 API 请求将返回大量信息,我们需要选择要使用的数据点。解析 JSON 操作帮助我们从 JSON 响应中提取特定细节,例如 Fabric 容量的当前状态。

要正确配置此操作,请在**架构**部分提供示例 JSON 有效负载或架构,以下内容可供复制。在**内容**字段中,再次点击闪电图标,从上一步的输出中选择Body

Content: Body
Schema: 
{
    "properties": {
        "id": {
            "type": "string"
        },
        "location": {
            "type": "string"
        },
        "name": {
            "type": "string"
        },
        "properties": {
            "properties": {
                "administration": {
                    "properties": {
                        "members": {
                            "items": {
                                "type": "string"
                            },
                            "type": "array"
                        }
                    },
                    "type": "object"
                },
                "mode": {
                    "type": "string"
                },
                "provisioningState": {
                    "type": "string"
                },
                "state": {
                    "type": "string"
                }
            },
            "type": "object"
        },
        "sku": {
            "properties": {
                "name": {
                    "type": "string"
                },
                "tier": {
                    "type": "string"
                }
            },
            "type": "object"
        },
        "tags": {
            "properties": {},
            "type": "object"
        },
        "type": {
            "type": "string"
        }
    },
    "type": "object"
}

添加条件操作

接下来是倒数第二个操作,我们将添加一个条件操作来评估从解析的 JSON 数据中得到的状态。这个条件将决定是否需要暂停 Fabric 容量。例如,我们将设置条件检查容量状态是否不是“暂停”。

要配置条件操作,我们将再次使用闪电图标,从上一步的输出中找到Body state选项。

然后我们需要将is equal to更改为is not equal to并添加Paused值。最终结果应该如下所示:

如果条件为真(即容量未暂停),你可以在 True 部分添加操作来暂停容量。如果条件为假(即容量已暂停),你可以将 False 部分留空,或者添加通知操作来记录不需要任何操作。

在我们的案例中,我们将在 True 部分添加最终的操作,这将是另一个 HTTP 操作。此操作将使用我们定义的第二个参数,$pausefabriccapacity

遵循与之前的 HTTP 操作相同的步骤,但这次使用$pausefabriccapacity参数,并选择POST方法而不是GET方法。

重要!不要忘记像上次 HTTP 操作那样更改身份验证设置,否则你可能会遇到身份验证错误。

就这样,我们完成了逻辑应用程序的设计,并在对步骤名称进行一些编辑后,我们应该拥有一个如下所示的逻辑应用程序配置。

完成设置后,测试你的逻辑应用程序非常重要,以确保它按预期功能运行。进行一次测试,一旦成功执行,验证你的 Fabric 容量确实已暂停。这个确认步骤很重要,可以确保自动化正常工作,并帮助防止不必要的 Azure 费用。

通过这个设置,你现在知道如何自动化暂停你的 Fabric 容量,带给你安心并更好地控制你的容量。

祝你在使用 Fabric 进行机器学习和数据分析时一切顺利!

注意:除非明确提到,否则所有图像均由作者拍摄。

通过缓存最大化生产中的 AI 效率:一种成本高效的性能提升方案

原文:towardsdatascience.com/maximizing-ai-efficiency-in-production-with-caching-a-cost-efficient-performance-booster-9b8afd200efd?source=collection_archive---------6-----------------------#2024-03-19

解锁缓存的力量,以利用 LangChain 缓存扩展 AI 解决方案——全面概述

Han HELOIR, Ph.D. ☕️Towards Data Science Han HELOIR, Ph.D. ☕️

·发表于Towards Data Science ·14 分钟阅读·2024 年 3 月 19 日

--

免费朋友链接 — 请帮忙点赞这篇 LinkedIn 文章

引言

尽管人工智能应用具有变革性的潜力,但大约 70%的应用未能投入生产。挑战是什么?成本、性能、安全性、灵活性和可维护性。本文我们将讨论两个关键挑战:不断上升的成本和对高性能的需求——并揭示如何通过缓存策略解决这些问题,这就是解决方案。

图片由Possessed Photography提供,来源于Unsplash

成本挑战:规模与开销的对接

运行 AI 模型,特别是在大规模时,可能会变得极其昂贵。以 GPT-4 模型为例,处理 100 万输入 tokens 的费用为 30 美元,处理 100 万输出 tokens 的费用为 60 美元。这些费用很容易迅速累积,使得许多项目在财务上难以承担广泛的采用。

为了更好地理解这一点,我们可以考虑一个客服聊天机器人,每天处理大约50,000 个用户查询。每个查询和回复的平均长度可能是 50 个标记。一天之内,这就相当于 2,500,000 个标记,继续…

最大化 Python 代码效率:克服常见性能障碍的策略

原文:towardsdatascience.com/maximizing-python-code-efficiency-strategies-to-overcome-common-performance-hurdles-c6292610d785?source=collection_archive---------4-----------------------#2024-01-11

使用 Python 导航嵌套循环和内存挑战,以实现无缝性能

Subha GanapathiTowards Data Science Subha Ganapathi

·发表于 Towards Data Science ·阅读时长:10 分钟·2024 年 1 月 11 日

--

照片由 Kevin Canlas 提供,来源于 Unsplash

概述

在这篇关于 Python 代码优化的探索中,我们将研究一些常见问题,这些问题会导致性能瓶颈并增加开销。我们将在这里分析两个问题——一个与嵌套循环相关,另一个与读取大量数据集时产生的内存/分配问题相关。

在嵌套循环问题中,我们将通过一个示例用例来理解嵌套循环困境,然后提出一个解决方案,作为避免嵌套循环带来的性能问题的替代方法。

针对大数据集遇到的内存/分配问题,我们探讨了多种数据读取策略,并对每种策略的性能进行了比较。让我们进一步探讨。

性能问题一:嵌套循环困境

嵌套循环困境 — 照片由 愚木混株 cdd20 提供,来源于 Unsplash

虽然嵌套循环是常见的编程结构,但它们的低效实现可能导致性能不佳。一个显著的…

最大化稀缺 AI 资源的利用:一种 Kubernetes 方法

原文:towardsdatascience.com/maximizing-the-utility-of-scarce-ai-resources-a-kubernetes-approach-0230ba53965b?source=collection_archive---------12-----------------------#2024-02-13

优化有限 AI 训练加速器的使用

Chaim RandTowards Data Science Chaim Rand

·发表于Towards Data Science ·13 分钟阅读·2024 年 2 月 13 日

--

照片由Roman Derrick Okello提供,来源于Unsplash

在 AI 开发不断变化的环境中,没有什么比古老的格言(归功于赫拉克利特)更能体现现实:“生活中唯一不变的就是变化”。对于 AI 而言,似乎变化确实是常态,而变化的速度却在不断加快。在这个独特而激动人心的时代,要保持相关性就意味着 AI 团队必须前所未有地考验其不断适应和调整开发流程的能力。那些未能适应或适应迟缓的 AI 开发团队,很可能会迅速变得过时。

近年来,AI 开发面临的最具挑战性的进展之一是获取训练 AI 模型所需硬件的难度不断增加。无论是由于全球供应链持续危机,还是 AI 芯片需求的显著增加,获得所需的 GPU(或替代训练加速器)变得更加困难。新 GPU 订单的巨大等待时间以及曾经提供几乎无限 GPU 计算能力的云服务提供商(CSP)现在也难以满足需求,均证明了这一点。

时代的变化迫使那些曾依赖无限容量 AI 加速器的 AI 开发团队,适应一个减少可访问性、在某些情况下成本更高的世界。曾经理所当然的能够随意启动新的 GPU 机器的开发过程,现在必须做出调整,以应对一个 AI 资源稀缺、且通常由多个项目和/或团队共享的世界。那些未能适应的人,将面临被淘汰的风险。

在这篇文章中,我们将展示如何在 AI 资源稀缺的世界中,使用Kubernetes来协调 AI 模型训练工作负载。我们将从指定我们希望实现的目标开始,然后描述为什么 Kubernetes 是解决这个挑战的合适工具。最后,我们将提供一个简单的演示,展示如何使用 Kubernetes 最大化稀缺的 AI 计算资源的利用。在随后的文章中,我们计划进一步完善基于 Kubernetes 的解决方案,并展示如何将其应用于基于云的训练环境。

免责声明

虽然本文并不假设读者具有 Kubernetes 的先前经验,但一些基本的了解肯定会有所帮助。本文不应被视为 Kubernetes 教程。要学习 Kubernetes,我们建议读者参考许多关于此主题的优秀在线资源。这里,我们仅讨论 Kubernetes 在最大化和优先考虑资源利用方面的一些特性。

对于我们在此提出的方法,有许多替代工具和技术,每种工具和技术都有其优缺点。我们在本文中的意图纯粹是为了教育目的;请不要将我们所做的任何选择视为推荐。

最后,Kubernetes 平台以及 AI 开发领域中的许多框架和工具仍在不断发展。请注意,在您阅读本文时,某些声明、示例和/或外部链接可能已经过时,因此在做出自己的设计决策之前,请务必参考最新的解决方案。

适应稀缺的 AI 计算资源

为了简化我们的讨论,假设我们只有一个单一的工作节点用于训练模型。这个节点可以是配备 GPU 的本地机器,或者是云端的预留实例,例如 AWS 中的p5.48xlarge实例,或者 GCP 中的TPU 节点。在下面的示例中,我们将称这个节点为“我的宝贝”。通常,我们会为这台机器花费大量资金。我们还假设,我们有多个训练任务在争夺这个唯一的计算资源,每个任务可能需要从几分钟到几天的时间。自然,我们希望最大化计算资源的利用率,确保它始终处于工作状态,并优先处理最重要的任务。我们需要的是某种形式的优先级队列以及相关的基于优先级的调度算法。让我们更具体地讨论一下我们希望达到的行为。

调度需求

  1. 最大化利用率: 我们希望我们的资源能够持续使用。具体来说,当一个任务完成后,资源会立即(并自动)开始处理下一个任务。

  2. 队列中的待处理任务: 我们需要存在一个队列,用于存放等待我们唯一资源处理的训练任务。我们还需要相关的 API 来创建和提交新的任务到队列,以及监控和管理队列的状态。

  3. 支持优先级: 我们希望每个训练任务都具有关联的优先级,使得优先级高的任务会在优先级低的任务之前运行。

  4. 抢占: 此外,如果在资源正在处理低优先级任务时,队列中有一个紧急任务提交,我们希望能够将当前正在运行的任务抢占,并由紧急任务替代。被抢占的任务应该返回队列中。

一种满足这些需求的解决方案可能是,使用现有的提交任务到训练资源的 API,并用一个定制的优先级队列实现来包装它,确保具备所需的属性。最基本的,这种方法需要一个数据结构来存储待处理任务的列表,一个专门的进程来从队列中选择并提交任务到训练资源,以及某种机制来识别任务何时完成并且资源变得可用。

一种替代的方法,也是我们在本文中采用的方法,是利用现有的基于优先级的 调度 解决方案,该解决方案能够满足我们的需求,并将我们的训练开发工作流程与其使用对齐。Kubernetes 默认的 调度器 就是这样一个解决方案的例子。在接下来的部分,我们将演示如何使用它来解决优化稀缺 AI 训练资源使用的问题。

使用 Kubernetes 进行 ML 编排

在这一部分,我们将对 Kubernetes 在 ML 训练工作负载编排中的应用进行一些哲学性的探讨。如果你对这些讨论没有耐心(完全可以理解),并且希望直接查看实际示例,请随时跳过到下一部分。

Kubernetes 是(另一个)那种在许多开发者中引发强烈反应的软件/技术解决方案。有人对它推崇备至,并广泛使用;也有人认为它压迫性强、笨重且不必要(例如,参见 这里 一些支持和反对使用 Kubernetes 的论点)。正如许多其他激烈争论一样,作者认为真相介于两者之间——在一些情况下,Kubernetes 提供了一个理想的框架,能够显著提高生产力,而在其他情况下,它的使用几乎是对软件开发职业的侮辱。一个大问题是,ML 开发处于这个光谱的哪个位置?Kubernetes 是否是训练 ML 模型的合适框架?尽管快速的在线搜索可能给人一种普遍共识是明确的“是”,但我们将提出一些理由,说明为什么情况可能并非如此。但首先,我们需要澄清一下“使用 Kubernetes 进行 ML 训练编排”到底是什么意思。

虽然有许多在线资源讨论使用 Kubernetes 进行机器学习(ML)的话题,但需要注意的是,它们并不总是指同一种使用方式。一些资源(例如,这里)仅使用 Kubernetes 来部署集群;一旦集群启动并运行,它们便在 Kubernetes 之外启动训练任务。另一些资源(例如,这里)使用 Kubernetes 定义一个流水线,其中一个专用模块启动一个训练任务(及其相关资源),该模块使用完全不同的系统。与这两个例子不同,许多其他资源将训练工作负载定义为Kubernetes Job工件,该工件在Kubernetes Node上运行。然而,它们在专注的具体属性上也存在很大差异。一些资源(例如,这里)强调自动扩展特性,而其他资源(例如,这里)则强调多实例 GPU(MIG)支持。它们在实现的细节上也差异很大,例如表示训练任务的精确工件(例如,Job扩展),例如ElasticJobTrainingWorkloadJobSetVolcanoJob等。在本文的语境下,我们也将假设训练工作负载被定义为Kubernetes Job。然而,为了简化讨论,我们将专注于核心 Kubernetes 对象,关于针对 ML 的 Kubernetes 扩展的讨论将留待以后文章中详细展开。

反对 Kubernetes 用于机器学习的论点

下面是一些可能反对使用 Kubernetes 进行机器学习模型训练的论点。

  1. 复杂性: 即使是最强烈的支持者也不得不承认 Kubernetes 可能很难使用。有效使用 Kubernetes 需要较高的专业知识,学习曲线陡峭,且从现实角度来看,通常需要专门的 DevOps 团队。基于 Kubernetes 设计训练解决方案增加了对专门专家的依赖,并进一步增加了出现问题的风险,且可能导致开发延迟。许多替代的机器学习训练解决方案使开发者拥有更高的独立性和自由度,并降低了开发过程中出现漏洞的风险。

  2. 固定资源需求: Kubernetes 最为人称道的特性之一是其可扩展性——它能够根据作业数量、客户端数量(对于服务应用程序的情况)、资源容量等,自动、无缝地扩展和缩减其计算资源池。然而,可以说,在机器学习训练工作负载的情况下,由于所需资源的数量(通常)在训练过程中是固定的,自动扩展显得不必要。

  3. 固定实例类型: 由于 Kubernetes 管理容器化应用程序,Kubernetes 在其节点池中对于机器类型提供了很大的灵活性。然而,在机器学习中,我们通常需要非常特定的硬件,配备专用加速器(如 GPU)。此外,我们的工作负载通常会调整为在某一特定实例类型上运行得最优。

  4. 单体应用架构: 在现代应用开发中,常见的做法是将应用分解为称为微服务的小元素。Kubernetes 通常被视为这种设计中的关键组件。机器学习训练应用通常在设计上是相当单体化的,可以说,它们并不自然地适合微服务架构。

  5. 资源开销: 运行 Kubernetes 所需的专用进程需要在其池中的每个节点上占用一些系统资源。因此,这可能会对我们的训练任务产生一定的性能损失。考虑到训练所需资源的成本,我们可能更倾向于避免这种情况。

诚然,我们在 Kubernetes 与 ML 争论中采取了非常片面的观点。仅根据上述论点,你可能会得出结论,认为我们需要一个非常充分的理由才能选择 Kubernetes 作为 ML 训练的框架。我们认为,本文提出的挑战,即最大化稀缺的 AI 计算资源的效用,正是支持使用 Kubernetes 的理由,尽管上述论点提出了反对意见。正如我们将展示的那样,内置的 调度器 和对 优先级和抢占 的支持,使其成为满足上述需求的有力候选者。

示例

在本节中,我们将分享一个简短的示例,展示 Kubernetes 内建的优先级调度支持。为了演示的目的,我们将使用 Minikube(版本 v1.32.0)。Minikube 是一个工具,可以让你在本地环境中运行 Kubernetes 集群,非常适合用来实验 Kubernetes。请参阅官方文档了解 Minikube 的安装与入门

集群创建

让我们首先使用 Minikube start 命令创建一个包含两个节点的集群:

minikube start --nodes 2

结果是一个本地 Kubernetes 集群,包含一个名为 minikube 的主节点(“控制平面”)和一个名为 minikube-m02 的工作节点,它将模拟我们的单一 AI 资源。让我们应用 标签 my-precious 来标识它为一个独特的资源类型:

kubectl label nodes minikube-m02 node-type=my-precious

我们可以使用 Minikube 仪表板 来可视化结果。在另一个 shell 中运行下面的命令,并打开生成的浏览器链接。

minikube dashboard

如果你点击左侧面板上的 Nodes 标签,你应该能够看到集群节点的总结:

Minikube 仪表板中的节点列表(由作者捕获)

PriorityClass 定义

接下来,我们定义两个 PriorityClasses低优先级高优先级,如下面显示的 priorities.yaml 文件所示。新作业默认将分配为 低优先级

apiVersion: scheduling.k8s.io/v1
kind: PriorityClass
metadata:
  name: low-priority
value: 0
globalDefault: true

---
apiVersion: scheduling.k8s.io/v1
kind: PriorityClass
metadata:
  name: high-priority
value: 1000000
globalDefault: false

为了将我们的新类应用到集群中,我们运行:

kubectl apply -f priorities.yaml

创建作业

我们使用下面的代码块定义一个简单的作业,使用job.yaml文件。为了演示的目的,我们定义了一个Kubernetes 作业,该作业只是简单地睡眠 100 秒。我们使用busybox作为它的 Docker 镜像。在实际应用中,这将被替换为一个训练脚本和合适的机器学习 Docker 镜像。我们通过nodeSelector字段指定作业在我们特殊的实例my-precious上运行,并指定资源需求,确保一次只能在该实例上运行一个作业。作业的优先级默认为上面定义的低优先级

apiVersion: batch/v1
kind: Job
metadata:
  name: test
spec:
  template:
    spec:
      containers:
        - name: test
          image: busybox
          command: # simple sleep command
            - sleep
            - '100'
          resources: # require all available resources
            limits:
              cpu: "2"
            requests:
              cpu: "2"
      nodeSelector: # specify our unique resource
          node-type: my-precious
      restartPolicy: Never

我们使用以下命令提交作业:

kubectl apply -f job.yaml

创建作业队列

为了展示 Kubernetes 如何排队处理作业,我们创建了三个与上面定义的作业相同的副本,分别命名为test1test2test3。我们将这三个作业组合成一个文件jobs.yaml,并提交它们进行处理:

kubectl apply -f jobs.yaml

下图显示的是我们集群的工作负载状态,该状态是在提交作业后不久,通过Minikube 仪表盘捕捉到的。你可以看到,my-precious已经开始处理test1,而其他作业则处于待处理状态,等待轮到它们。

集群工作负载状态(由作者捕获)

一旦test1完成,test2的处理将开始:

集群工作负载状态 — 自动调度(由作者捕获)

只要没有其他更高优先级的作业提交,我们的作业将继续按顺序处理,直到所有作业完成。

作业抢占

我们通过展示提交第四个作业(这次设置为高优先级)时发生的情况,来演示 Kubernetes 内置对作业抢占的支持:

apiVersion: batch/v1
kind: Job
metadata:
  name: test-p1
spec:
  template:
    spec:
      containers:
        - name: test-p1
          image: busybox
          command:
            - sleep
            - '100'
          resources:
            limits:
              cpu: "2"
            requests:
              cpu: "2"
      restartPolicy: Never
      priorityClassName: high-priority # high priority job
      nodeSelector:
          node-type: my-precious

工作负载状态的影响如下面的图所示:

集群工作负载状态 — 抢占(由作者捕获)

test2作业已被抢占——它的处理被停止,并且返回到待处理状态。取而代之的是,my-precious开始处理优先级更高的test-p1作业。只有当test-p1完成后,低优先级作业的处理才会恢复。(如果被抢占的作业是一个机器学习训练工作负载,我们会编程让它从最近保存的模型模型检查点继续恢复。)

下图显示了所有作业完成后的工作负载状态

集群工作负载状态 — 完成(由作者捕获)

Kubernetes 扩展

我们展示的基于优先级调度和抢占的解决方案仅依赖于 Kubernetes 的核心组件。在实际应用中,你可以选择利用如Kueue等扩展引入的基本功能增强,和/或在 Kubernetes 之上构建的平台提供的专用 ML 特性,例如Run:AIVolcano。但请记住,为了满足最大化稀缺 AI 计算资源效用的基本要求,我们所需要的只是 Kubernetes 的核心功能。

总结

专用 AI 硅芯片的可用性减少迫使 ML 团队调整他们的开发流程。与过去不同,开发人员可以随意创建新的 AI 资源,现在他们面临 AI 计算能力的限制。这就需要通过购买专用单元和/或预定云实例等方式来采购 AI 实例。此外,开发人员还必须接受与其他用户和项目共享这些资源的可能性。为了确保稀缺的 AI 计算能力得到最大效用的分配,必须定义专用调度算法,减少空闲时间并优先处理关键工作负载。在本文中,我们展示了如何利用Kubernetes 调度器实现这些目标。如上所述,这只是解决最大化稀缺 AI 资源效用的众多方法之一。自然,你选择的方案以及实现的细节将取决于你 AI 开发的具体需求。

衡量 AI 的创造力:视觉字谜

原文:towardsdatascience.com/measuring-ais-creativity-with-visual-word-puzzles-cb1bd2f3f4bb?source=collection_archive---------14-----------------------#2024-02-13

AI 模型能多好地解答(并创造)字谜?

Yennie JunTowards Data Science Yennie Jun

·发布在 Towards Data Science ·11 分钟阅读 ·2024 年 2 月 13 日

--

GPT-4:当提示为:为“视觉字谜”创建一个字谜

这篇文章最初发布在 Art Fish Intelligence

引言

对 AI 来说,创造力意味着什么?

去年,我写过一篇关于如何使用几种基于文字的创造力测试来衡量大型语言模型(LLMs)创造力的文章。

## 探索大型语言模型中的创造力:从 GPT-2 到 GPT-4

通过创造力测试分析大型语言模型创作过程的演变

[towardsdatascience.com

从那时起,人工智能发展迅速,能够处理和创造文本图像。这些模型,有时被称为“多模态大型语言模型”(MLLMs),非常强大,能够理解复杂的文本和视觉输入。

在本文中,我探讨了如何衡量两种流行的多模态大型语言模型(MLLMs)中的创造力:OpenAI 的 GPT-4 VisionGoogle 的 Gemini Pro Vision。我使用了字谜,这是一种需要结合视觉和语言线索来解答的文字谜题。

创造力是极其多面的,很难将其定义为单一的特征……

使用 dbt_set_similarity 测量跨产品采纳情况

原文:towardsdatascience.com/measuring-cross-product-adoption-using-dbt-set-similarity-fdf7c1f88bc2?source=collection_archive---------1-----------------------#2024-12-28

在 dbt 工作流中增强跨产品洞察

Matthew SenickTowards Data Science Matthew Senick

·发布于 Towards Data Science ·5 分钟阅读·2024 年 12 月 28 日

--

介绍

对于多产品公司,一个关键的指标通常被称为“跨产品采纳”(即了解用户如何与产品组合中的多个产品互动)。

在流行书籍 Hacking Growth [1] 中,建议使用一种衡量跨产品或跨特征使用情况的指标,即 杰卡德指数。杰卡德指数传统上用于衡量两个集合之间的相似度,也可以作为评估产品采纳模式的强大工具。它通过量化产品之间用户的重叠,帮助识别跨产品协同效应和增长机会。

dbt 包 dbt_set_similarity 旨在简化在分析工作流中直接计算集合相似度指标的过程。该包提供了一种方法,可以在 SQL 转换层中计算杰卡德指数。

要将此包导入到您的 dbt 项目中,请将以下内容添加到 packages.yml 文件中。为了本文的示例,我们还需要 dbt_utils。请在项目中运行 dbt deps 命令来安装该包。

packages:
  - package: Matts52/dbt_set_similarity
    version: 0.1.1
  - package: dbt-labs/dbt_utils
    version: 1.3.0

杰卡德指数

杰卡德指数(Jaccard Index),也称为杰卡德相似度系数,是一种用于衡量两个集合相似度的指标。它的定义是两个集合交集的大小除以它们并集的大小。

从数学角度来看,它可以表示为:

杰卡德指数表示两个集合的“交集”与它们的“并集”的比值(图由作者提供)

其中:

  • AB 是两个集合(例如,产品 A 和产品 B 的用户)

  • 分子表示两个集合中元素的数量

  • 分母表示两个集合中所有不同元素的总数

(图像由作者提供)

Jaccard 指数在交叉产品采用的背景下特别有用,因为:

  • 它关注的是两个集合之间的重叠部分,非常适合用于理解共享的用户基础

  • 它考虑了集合大小的差异,确保结果是成比例的,而不会被异常值扭曲

例如:

  • 如果 100 个用户采用了产品 A,50 个用户采用了产品 B,且 25 个用户同时采用了这两个产品,那么 Jaccard 指数为 25 / (100 + 50 - 25) = 0.2,表示这两个用户群体之间有 20% 的重叠。

示例数据

我们将使用的示例数据集是一个虚构的 SaaS 公司,提供存储空间作为产品给消费者。该公司提供两种不同的存储产品:文档存储(doc_storage)和照片存储(photo_storage)。这些值要么为真,表示产品已被采用,要么为假,表示产品未被采用。

此外,该公司服务的用户群体(user_category)要么是科技爱好者,要么是房主。

为了这个示例,我们将在 dbt 项目中将此 CSV 文件作为名为 seed_example 的“种子”模型读取。

简单的交叉产品采用

现在,假设我们想要计算文档存储和照片存储产品之间的 Jaccard 指数(交叉采用)。首先,我们需要创建一个包含使用文档存储产品的用户的数组(列表),以及一个包含使用照片存储产品的用户的数组。在第二个 CTE 中,我们应用 dbt_set_similarity 包中的 jaccard_coef 函数,帮助我们轻松计算这两个用户 ID 数组之间的 Jaccard 系数。

with product_users as (
    select
        array_agg(user_id) filter (where doc_storage = true)
            as doc_storage_users,
        array_agg(user_id) filter (where photo_storage = true)
            as photo_storage_users
    from {{ ref('seed_example') }}
)

select
    doc_storage_users,
    photo_storage_users,
    {{
        dbt_set_similarity.jaccard_coef(
            'doc_storage_users',
            'photo_storage_users'
        )
    }} as cross_product_jaccard_coef
from product_users

来自上述 dbt 模型的输出(图像由作者提供)

如我们所见,似乎超过一半(60%)的采用了任何一个产品的用户,也采用了两个产品。 我们可以通过将用户 ID 集合放入维恩图中来图形化验证我们的结果,在那里我们看到三位用户同时采用了这两个产品,在五位用户中占比为 3/5 = 0.6。

用户 ID 和产品采用的集合是什么样的,验证我们的结果(图像由作者提供)

分段交叉产品采用

使用dbt_set_similarity包,创建不同用户类别的分段 Jaccard 指数应该是非常自然的。我们将遵循之前的模式,然而,我们将仅仅在用户类别上对我们的聚合进行分组。

with product_users as (
    select
        user_category,
        array_agg(user_id) filter (where doc_storage = true)
            as doc_storage_users,
        array_agg(user_id) filter (where photo_storage = true)
            as photo_storage_users
    from {{ ref('seed_example') }}
    group by user_category
)

select
    user_category,
    doc_storage_users,
    photo_storage_users,
    {{
        dbt_set_similarity.jaccard_coef(
            'doc_storage_users',
            'photo_storage_users'
        )
    }} as cross_product_jaccard_coef
from product_users

来自上述 dbt 模型的输出(图像由作者提供)

从输出结果来看,考虑到杰卡德指数,房主的跨产品采用率较高。如上所示,所有已采用其中一种产品的房主都已经采用了两种产品。与此同时,只有三分之一的科技爱好者在采用一种产品后也采用了两种产品。因此,在我们这个非常小的数据集中,房主的跨产品采用率高于科技爱好者。

我们可以通过再次绘制韦恩图来图形化验证输出结果:

按照两个细分市场划分的韦恩图(图片由作者提供)

结论

dbt_set_similarity 提供了一种直接在 dbt 工作流中计算跨产品采用指标(例如杰卡德指数)的简单有效方法。通过应用此方法,多产品公司可以深入了解用户行为和产品组合中的采用模式。在我们的示例中,我们演示了整体跨产品采用率的计算,以及对不同用户类别的细分采用情况。

使用该包进行跨产品采用分析只是其中一种直接的应用。实际上,这种技术有无数其他潜在应用,例如:

  • 功能使用分析

  • 营销活动影响分析

  • 支持分析

此外,这种分析方式显然不仅限于 SaaS,几乎可以应用于任何行业。祝你使用杰卡德指数愉快!

参考文献

[1] Sean Ellis 和 Morgan Brown, Hacking Growth(2017 年)

资源

dbt 包中心

衡量生产问题对开发团队的成本

原文:towardsdatascience.com/measuring-the-cost-of-production-issues-on-development-teams-5efcd13bc9c7?source=collection_archive---------8-----------------------#2024-12-11

降低对质量的优先级会牺牲软件的稳定性和速度,从而导致昂贵的问题。而投资质量则能提升速度和成果。

David TranTowards Data Science David Tran

·发布于 Towards Data Science ·8 分钟阅读·2024 年 12 月 11 日

--

图像由作者提供。(AI 生成,使用 Midjourney)

投资软件质量往往说起来容易做起来难。尽管许多工程经理表示他们致力于高质量的软件,但他们通常对将大量资源投入到以质量为重点的计划中持谨慎态度。在紧迫的截止日期和相互竞争的优先事项面前,领导者经常面临如何分配团队时间和精力的艰难抉择。因此,质量方面的投资往往是第一个被削减的项目。

在任何工程组织中,投资质量与优先考虑速度之间的紧张关系都至关重要,尤其是在数据科学和机器学习项目中,其中交付结果是重中之重。与传统的软件开发不同,机器学习系统通常需要持续更新,以保持模型性能、适应不断变化的数据分布并集成新功能。机器学习流水线中的生产问题——如数据质量问题、模型漂移或部署失败——可能会扰乱这些工作流程,并对业务成果产生连锁反应。在机器学习团队中,平衡实验和部署的速度与严格的质量保证至关重要,只有这样才能交付可靠、高效的模型。通过应用一种结构化的、科学的方法来量化生产问题的成本,正如本文中所述,机器学习团队可以在质量改进的投资与优化开发速度之间做出明智的决策。

质量往往面临一个强大的对手:速度。随着业务目标和关键功能交付压力的增加,证明任何非直接提升速度的做法变得越来越具有挑战性。

推动产出。许多团队将非编码活动减少到最低限度,专注于单元测试,同时降低集成测试的优先级,推迟技术改进,并依赖可观察性工具来捕捉生产问题——希望只在问题出现时再进行处理。

平衡速度和质量很少是一个简单的选择,这篇文章也没有打算简化这一点。然而,领导者常常忽视的是,速度和质量是密切相关的。通过降低改善软件质量的优先级,团队最终可能会导致发布的内容既充满缺陷又很慢。快速推出更多功能所带来的任何收益

生产力可能迅速衰退,因为维护问题和持续出现的故障最终会削弱团队的速度。

只有通过了解质量对速度的全面影响,以及质量措施的预期投资回报率,领导者才能做出明智的决策,以平衡团队的工作负载。

在这篇文章中,我们将尝试提供一个模型,用来衡量在提高发布质量的两个方面上的投资回报率(ROI):减少生产问题的数量,以及减少团队在问题发生时花费的时间。

避免缺陷,指的是那些进入生产环境的漏洞。

防止回归可能是减少生产问题对团队负担的最直接、最上游的措施。那些从未发生的问题不会拖累团队,造成中断,或威胁到业务的持续性。

尽管好处可能非常吸引人,但在某个拐点之后,防止代码出现问题的做法可能会导致发布速度的急剧放缓。理论上,团队可以将所需的代码审查次数增加三倍,将测试的投资增加三倍,并建立一个严格的负载测试装置。这样做可以防止更多的问题,但也会让发布新内容变得极为缓慢。

因此,为了证明投资于防止回归的任何努力是值得的,我们需要更好地理解其投资回报率。我们可以尝试估算回归减少 1%的成本节省对整个团队表现的影响,从而开始建立一个框架,帮助我们平衡质量投资。

作者提供的图片。

防止问题的直接收益首先体现在团队处理这些问题所花费的时间上。研究表明,团队目前花费的时间中,有 20%到 40%用于处理生产问题——这对生产力是一个巨大的拖累。

投资于防止问题的好处是什么?通过简单的数学计算,我们可以开始估算在开发过程的早期阶段防止每个问题所带来的生产力提升:

作者提供的图片。

其中:

  • Tsaved​ 是通过问题预防所节省的时间。

  • Tissues 是目前用于处理生产问题的时间。

  • P 是可以预防的生产问题的百分比。

该框架有助于评估工程投资的成本与价值。例如,一位经理指派两名开发人员花费一周时间使用可观测性数据分析性能问题。他们的努力将生产问题减少了 10%。

在一个 100 名开发人员的团队中,如果 40%的时间用于问题解决,这将转化为 4%的容量增益,再加上通过减少上下文切换带来的 1.6%。通过回收 5.6%的容量,这项投资证明是值得的,展示了这种方法如何指导实际决策。

很容易看到,预防每一个 1%的生产回归问题对团队工作速度的直接影响。这代表了团队不需要执行的生产回归工作。下面的表格可以通过插入一些值来提供一些背景信息:

根据这些数据,举个例子,对于一个将25%的时间花费在处理生产问题的团队,每1%的改善带来的直接团队资源增益将是0.25%。如果团队能够预防 20%的生产问题,那么这意味着会有5%的资源回到工程团队。虽然这可能听起来不是一个足够大的数值,但还有其他与问题相关的成本可以优化,从而带来更大的影响。

平均修复时间(MTTR):减少因问题解决而浪费的时间。

在之前的示例中,我们讨论了通过预防问题所获得的生产力提升。那么,对于那些无法避免的问题呢?虽然一些漏洞是不可避免的,但我们仍然可以通过减少解决问题所需的时间来最小化它们对团队生产力的影响——这被称为平均修复时间(MTTR)。

通常,解决一个 bug 需要几个阶段:

  1. 问题分类/评估:团队集合相关的主题专家来确定问题的严重性和紧急性。

  2. 调查/根本原因分析(RCA):开发人员深入研究问题,以识别根本原因,通常是最耗时的阶段。

  3. 修复/解决:团队实施修复。

图片来自作者。

在这些阶段中,调查阶段通常代表了节省时间的最大机会。通过采用更高效的追踪、调试和缺陷分析工具,团队可以简化他们的根本原因分析(RCA)工作,显著减少 MTTR,从而提升生产力。

在问题分类阶段,团队可能会邀请主题专家来评估一个问题是否应该进入待办事项列表,并确定其紧急性。接下来是调查和根本原因分析(RCA)阶段,开发人员会深入研究问题。最后,修复阶段涉及编写代码来解决问题。

有趣的是,前两个阶段,尤其是调查和根本原因分析(RCA),通常占总解决时间的 30%到 50%。这一阶段具有最大的优化潜力,因为关键在于改进现有信息的分析方式。

为了衡量改善调查时间对团队速度的影响,我们可以计算团队在每个问题上所花费的时间百分比,并减少调查阶段的比例成本。这通常可以通过采用更好的跟踪、调试和缺陷分析工具来实现。我们将类似的逻辑应用于问题预防的评估,以便了解每减少一个百分比的调查时间,团队可以获得多少生产力提升。

图片来自作者。

  1. Tsaved:节省的团队时间百分比

  2. R:调查时间的减少

  3. T_investigation:每个问题上用于调查的时间

  4. T_issues:处理生产问题所花费的时间百分比

我们可以测试相对于T_investigationT_issues变量,性能提升的表现。我们将计算每减少 1%调查时间R的边际收益。

随着这些数字的积累,团队可以获得显著的提升。如果我们能够将调查时间减少 40%,例如在一个团队中,该团队的 25%时间用于处理生产问题,那么我们将重新夺回该团队生产力的 4%。

结合这两项好处

考虑到这两个优化领域,我们可以创建一个统一的公式来衡量同时优化问题预防和团队在无法预防问题时所花费的时间的综合效果。

图片来自作者。

回到我们举的例子,假设该团队将 25%的时间用于处理生产问题,并且每个问题的解决时间中有 40%用于调查。如果能将调查时间减少 40%并且预防 20%的问题,那么团队的生产力将提升 8.1%。然而,我们的工作远未完成。

考虑到上下文切换的隐性成本

以上每一种简单的计算都没有考虑到一个主要的惩罚因素——由于未经计划的生产问题导致工作的中断——即上下文切换(CS)。有大量研究反复表明,上下文切换是昂贵的。多昂贵呢?由于中断和在多个任务之间切换,额外工作量的惩罚在 20%到 70%之间。在减少中断工作时间的同时,我们也能减少上下文切换的惩罚。

我们最初的公式没有考虑到这一重要变量。一个简单但天真的方法是假设任何未计划的工作处理生产问题都会对已经分配给团队的待办事项造成相等的上下文切换惩罚。如果我们能够节省 8%的团队效率,那应该会导致上下文切换的惩罚也相应减少,尤其是对于原定任务的完成。在减少 8%的未计划工作时,我们也就减少了相应的 8%计划工作上的上下文切换惩罚。

让我们将这一点加到我们的方程式中:

图片由作者提供。

继续我们的例子,假设的组织会发现他们的改进实际影响现在略超过 11%。对于一个 80 人规模的开发团队来说,这相当于有 8 名开发人员可以腾出时间做其他事情,为积压工作做出贡献。

使用 ROI 计算器

为了简化操作,我已将上述所有公式上传为一个简单的 HTML 计算器,你可以在这里访问:

ROI 计算器

衡量 ROI 是关键

生产问题很昂贵,但一个清晰的 ROI 框架有助于量化质量改进的影响。通过优化分诊和调查来减少平均修复时间(MTTR)可以提高团队生产力。例如,调查时间减少 40%

恢复了 4%的生产力,并降低了上下文切换的隐性成本。

使用 ROI 计算器评估优质投资,并做出数据驱动的决策。点击这里查看如何通过有针对性的改进提高效率。

参考文献:

1. 开发人员实际编写代码的时间有多少?

2. 如何更快编写优质软件(我们花费 90%的时间进行调试)

3. 调查:修复漏洞挤占开发时间

4. 上下文切换的真实成本

衡量你市场营销活动的内在因果影响

原文:towardsdatascience.com/measuring-the-intrinsic-causal-influence-of-your-marketing-campaigns-aa8354c26b7b?source=collection_archive---------1-----------------------#2024-06-02

因果 AI,探索因果推理与机器学习的结合

Ryan O'SullivanTowards Data Science Ryan O'Sullivan

·发表于 Towards Data Science ·阅读时间:10 分钟·2024 年 6 月 2 日

--

图片由 Melanie Deziel 提供,来源于 Unsplash

这篇系列文章讲的是什么?

欢迎来到我的因果 AI 系列文章,我们将在其中探索因果推理如何融入机器学习模型。你将了解在不同商业环境中的一些实际应用。

在上一篇文章中,我们探讨了定价和促销中的非线性处理效果优化。这次我们将讨论衡量你的市场营销活动的内在因果影响

如果你错过了上一篇关于定价和促销中非线性处理效果的文章,可以点击这里查看:

## 定价和促销中的非线性处理效果优化

因果 AI,探索因果推理与机器学习的结合

towardsdatascience.com

介绍

在这篇文章中,我将帮助你理解如何衡量你的市场营销活动的内在因果影响。

以下几个方面将会被涵盖:

  • 在市场营销衡量方面存在哪些挑战?

  • 什么是内在因果影响?它是如何工作的?

  • 一个在 Python 中的案例研究,展示了我们如何利用内在的因果影响来为营销活动提供应有的归因。

完整的笔记本可以在这里找到:

[## causal_ai/notebooks/understanding the intrinsic causal influence of your marketing campaigns.ipynb…

这个项目介绍了因果 AI 以及它如何推动商业价值。 - causal_ai/notebooks/understanding the…

github.com](https://github.com/raz1470/causal_ai/blob/main/notebooks/understanding the intrinsic causal influence of your marketing campaigns.ipynb?source=post_page-----aa8354c26b7b--------------------------------)

在营销衡量中有哪些挑战?

不同类型的营销活动有哪些?

组织通过营销来发展业务,获取新客户并保持现有客户。营销活动通常分为三大类:

  • 品牌

  • 性能

  • 留存

每种活动在衡量时都有自己独特的挑战——理解这些挑战至关重要。

品牌活动

品牌活动的目的是在新受众中提高品牌的认知度。它们通常通过电视和社交媒体进行推广,后者通常以视频格式呈现。它们通常没有直接的行动号召,例如“我们的产品可以使用一辈子”。

测量电视广告的挑战显而易见——我们无法追踪谁看过电视广告!但当涉及到社交媒体时,我们也面临类似的挑战——如果我在 Facebook 上观看一个视频,然后在第二天自发访问网站并购买了产品,我们很可能无法将这两个事件关联起来。

还有一个次要的挑战是延迟效应。当在新受众中提高品牌认知度时,可能需要几天、几周甚至几个月才能让他们考虑购买你的产品。

有人争论品牌活动完成了所有艰难的工作——然而,在营销衡量中,它们常常被低估,因为我们上面提到的一些挑战。

绩效活动

一般来说,绩效活动是面向已经在市场上寻找你产品的顾客的。它们通过付费搜索、社交和联盟渠道进行推广。它们通常会有一个行动号召,例如“点击立即享受首次购买 5% 的折扣”。

当谈到绩效活动时,为什么它们很难衡量并不立即显现出来。很可能我们能将顾客点击绩效活动的事件与当天购买该产品的顾客联系起来。

但如果他们之前对这个品牌并不熟悉,他们还会点击吗?他们是怎么熟悉这个品牌的?如果我们没有展示这个广告,他们会不会还是自然而然地购买?从数据科学的角度看,这些问题很难回答!

留存营销活动

另一类营销活动是留存。这是旨在留住现有客户的营销活动。我们通常可以通过 A/B 测试来衡量这些活动。

获取营销图

通常将品牌营销和效果营销称为获取营销。如前所述,衡量品牌营销和效果营销非常具有挑战性——我们经常低估品牌营销的价值,而高估效果营销的价值。

下图是一个激励性的(但简化的)示例,展示了获取营销的工作原理:

用户生成的图像

我们如何(公平地)估算每个节点对收入的贡献?这就是内在因果影响发挥作用的地方——让我们在接下来的章节深入了解它是什么!

什么是内在因果影响,它是如何工作的?

这个概念来源于哪里?

这个概念最初在 2020 年的一篇论文中提出,并在今年更新:

[## 通过结构保持干预量化内在因果贡献

我们提出了一种因果影响的概念,描述了节点对目标的“内在”贡献部分……

proceedings.mlr.press

它在 Python 包 DoWhy 的 GCM 模块中实现:

[## 量化内在因果影响 - DoWhy 文档

通过量化内在因果影响,我们回答了这个问题:上游的因果影响有多强?

www.pywhy.org

我个人最初发现这个概念很难掌握,所以在接下来的章节中,让我们一步步地分解它。

因果图回顾

在我们试图理解内在因果影响之前,了解因果图、结构因果模型(SCM)和加性噪声模型(ANM)非常重要。我在系列文章中的一篇应该能帮助你快速理解:

## 使用因果图来回答因果问题

因果 AI,探索因果推理与机器学习的整合

[towardsdatascience.com

提醒一下,因果图中的每个节点可以视为一个模型中的目标,其中它的直接父节点作为特征。通常会为每个非根节点使用加性噪声模型:

用户生成的图像

什么才是真正的内在因果影响?

现在我们已经回顾了因果图,让我们开始理解什么是真正的内在因果影响……

“内在”的字典定义是“自然归属”。在我脑海里,我把它想象成一个漏斗,漏斗顶端的东西做着繁重的工作——我们希望赋予它们应得的因果影响。

让我们看下面的示例图,以帮助我们进一步解开内在因果影响:

用户生成的图像

  • A、B 和 C 是根节点。

  • D 是一个非根节点,我们可以通过它的直接父节点(A、B、C)和一个噪声项来建模。

  • E 是一个非根节点,类似于 D,我们可以通过它的直接父节点(A、B、C)和一个噪声项来建模。

  • F 是我们的目标节点,我们可以通过它的直接父节点(D、E)和一个噪声项来建模。

让我们聚焦于节点 D。它从节点 A、B 和 C 继承了一部分对节点 F 的影响。它对节点 F 的内在影响,部分并非仅仅从 A、B 和 C“继承”而来,而是来自噪声项。因此,我们可以说每个节点的噪声项可以用来估算对目标节点的内在因果影响。值得注意的是,根节点仅由噪声组成。

在案例研究中,我们将深入探讨如何精确计算内在因果影响。

它如何帮助我们衡量我们的营销活动?

希望你已经能看到营销获取示例和内在因果影响之间的联系——内在因果影响能否帮助我们停止低估品牌活动,停止高估表现型活动?让我们在案例研究中找出答案!

案例研究

背景

年底将至,营销总监正面临财务团队的压力,要求她为明年计划投入如此多的营销费用做出解释。财务团队使用一个最后点击模型,将收入归因于客户点击的最后一个东西。他们质疑为什么即便每个人都是通过有机或社交渠道进入的,依然需要花钱做电视广告!

数据科学团队的任务是估算每个营销渠道的内在因果影响。

设置图(DAG)

我们从使用专家领域知识来设置一个 DAG 开始,重复使用之前的营销获取示例:

# Create node lookup for channels
node_lookup = {0: 'Demand',
               1: 'TV spend',
               2: 'Social spend', 
               3: 'Organic clicks',                   
               4: 'Social clicks',
               5: 'Revenue'                                                                             
 }

total_nodes = len(node_lookup)

# Create adjacency matrix - this is the base for our graph
graph_actual = np.zeros((total_nodes, total_nodes))

# Create graph using expert domain knowledge
graph_actual[0, 3] = 1.0 # Demand -> Organic clicks
graph_actual[0, 4] = 1.0 # Demand -> Social clicks
graph_actual[1, 3] = 1.0 # Brand spend -> Organic clicks
graph_actual[2, 3] = 1.0 # Social spend -> Organic clicks
graph_actual[1, 4] = 1.0 # Brand spend -> Social clicks
graph_actual[2, 4] = 1.0 # Social spend -> Social clicks
graph_actual[3, 5] = 1.0 # Organic clicks -> Revenue
graph_actual[4, 5] = 1.0 # Social clicks -> Revenue

本质上,财务团队使用的最后点击模型只使用收入的直接父节点来衡量营销效果。

用户生成的图像

数据生成过程

我们根据 DAG 的数据生成过程创建一些数据样本:

  • 3 个根节点,由噪声项组成;需求、品牌支出和社交支出。

  • 2 个非根节点,均继承自 3 个根节点的影响力,并加上一些噪声项;有机点击,社交点击。

  • 1 个目标节点,继承自 2 个非根节点的影响力,并加上一个噪声项;收入

# Create dataframe with 1 column per code
df = pd.DataFrame(columns=node_lookup.values())

# Setup data generating process
df[node_lookup[0]] = np.random.normal(100000, 25000, size=(20000)) # Demand
df[node_lookup[1]] = np.random.normal(100000, 20000, size=(20000)) # Brand spend
df[node_lookup[2]] = np.random.normal(100000, 25000, size=(20000)) # Social spend
df[node_lookup[3]] = 0.75 * df[node_lookup[0]] + 0.50 * df[node_lookup[1]] + 0.25 * df[node_lookup[2]] + np.random.normal(loc=0, scale=2000, size=20000) # Organic clicks
df[node_lookup[4]] = 0.30 * df[node_lookup[0]] + 0.50 * df[node_lookup[1]] + 0.70 * df[node_lookup[2]] + np.random.normal(100000, 25000, size=(20000)) # Social clicks
df[node_lookup[5]] = df[node_lookup[3]] + df[node_lookup[4]] + np.random.normal(loc=0, scale=2000, size=20000) # Revenue

训练 SCM

现在我们可以使用 python 包 DoWhy 中的 GCM 模块来训练 SCM。我们设置数据生成过程时采用线性关系,因此可以使用岭回归作为每个非根节点的因果机制:

# Setup graph
graph = nx.from_numpy_array(graph_actual, create_using=nx.DiGraph)
graph = nx.relabel_nodes(graph, node_lookup)

# Create SCM
causal_model = gcm.InvertibleStructuralCausalModel(graph)

causal_model.set_causal_mechanism('Demand', gcm.EmpiricalDistribution()) # Deamnd
causal_model.set_causal_mechanism('TV spend', gcm.EmpiricalDistribution()) # Brand spend
causal_model.set_causal_mechanism('Social spend', gcm.EmpiricalDistribution()) # Social spend

causal_model.set_causal_mechanism('Organic clicks', gcm.AdditiveNoiseModel(gcm.ml.create_ridge_regressor())) # Organic clicks
causal_model.set_causal_mechanism('Social clicks', gcm.AdditiveNoiseModel(gcm.ml.create_ridge_regressor())) # Social clicks
causal_model.set_causal_mechanism('Revenue', gcm.AdditiveNoiseModel(gcm.ml.create_ridge_regressor())) # Revenue

gcm.fit(causal_model, df)

或者,我们可以使用 auto_assignment 函数来自动分配模型(而不是手动定义它们)。

我将跳过模型评估部分,因为这不是我文章的重点。然而,我们也可以使用 GCM 评估因果模型:

[## 评估 GCM - DoWhy 文档

建立图形因果模型(GCM)需要做出各种假设和模型选择,这些都可能影响……

www.pywhy.org

内在因果影响

我们可以轻松地使用 GCM 模块计算内在因果影响。我们这样做并将贡献转换为百分比:

# calculate intrinsic causal influence
ici = gcm.intrinsic_causal_influence(causal_model, target_node='Revenue')

def convert_to_percentage(value_dictionary):
    total_absolute_sum = np.sum([abs(v) for v in value_dictionary.values()])
    return {k: round(abs(v) / total_absolute_sum * 100, 1) for k, v in value_dictionary.items()}

convert_to_percentage(ici)

用户生成的图片

让我们在条形图上展示这些:

# Convert dictionary to DataFrame
df = pd.DataFrame(list(ici.items()), columns=['Node', 'Intrinsic Causal Influence'])

# Create a bar plot
plt.figure(figsize=(10, 6))
sns.barplot(x='Node', y='Intrinsic Causal Influence', data=df)

# Rotate x labels for better readability
plt.xticks(rotation=45)
plt.title('Bar Plot from Dictionary Data')
plt.show()

用户生成的图片

我们的结果是否直观?如果你回头查看数据生成过程代码,你会发现它们是直观的!仔细注意每个非根节点继承了什么,并且添加了哪些额外的噪声。

内在因果影响模块非常易于使用,但它并不能帮助我们理解其背后的方法——最后,让我们探索内在因果影响的工作原理!

内在因果影响——它是如何工作的?

我们希望估算每个节点的噪声项对目标节点的贡献:

  • 值得记住的是,根节点仅由噪声项组成。

  • 在非根节点中,我们将噪声项与从父节点继承的部分分开。

  • 我们还包括了目标节点的噪声项——这可以解释为未观察到的混杂因素或影响收入但不影响上游节点的未见因素(尽管它也可能是由于模型误设导致的)。

  • 噪声项随后用于解释目标节点的方差——这可以看作是一个以噪声项为特征、目标节点为结果的模型。

  • 该模型用于估算目标节点在给定噪声变量子集下的条件分布。

  • 然后使用 Shapley 值来估计每个噪声项的贡献——如果改变噪声项对目标的影响很小,那么内在因果影响将非常小。

总结性思考

今天我们讲解了如何估计你的营销活动的内在因果影响。以下是一些总结性思考:

  • 内在因果影响是一个强大的概念,可以应用于不同的使用场景,而不仅仅是营销领域。

  • 理解其内在机制将帮助你更有效地应用这一概念。

  • 准确识别 DAG 并正确估计图形是获得合理内在因果影响估计的关键。

  • 在营销获取示例中,你可能需要考虑为品牌营销添加滞后效应。

如果你想继续探索因果人工智能的旅程,欢迎关注我——在下一篇文章中,我们将探讨如何利用合成控制方法来验证和校准我们的因果模型。

参考文献

Dominik Janzing, Patrick Blöbaum, Atalanti A Mastakouri, Philipp M Faller, Lenon Minorics, Kailash Budhathoki. “通过结构保持干预量化内在因果贡献”; 第 27 届国际人工智能与统计学会议论文集, PMLR 238:2188–2196: proceedings.mlr.press/v238/janzing24a.html

Patrick Blöbaum, Peter Götz, Kailash Budhathoki, Atalanti A. Mastakouri, Dominik Janzing. “DoWhy-GCM: DoWhy 在图形因果模型中因果推断的扩展”; 机器学习研究杂志, MLOSS 25(147):1−7, 2024: jmlr.org/papers/v25/22-1258.html

认识 Git Stash:你未完成代码的秘密宝箱

原文:towardsdatascience.com/meet-git-stash-your-secret-chest-of-unfinished-code-6da48da15413?source=collection_archive---------5-----------------------#2024-10-26

图片由Stefan Steinbauer拍摄,来源于Unsplash

精通 Git

Git 的一个强大功能,用于临时保存正在进行的代码

Zolzaya LuvsandorjTowards Data Science Zolzaya Luvsandorj

·发表于Towards Data Science ·阅读时间:6 分钟·2024 年 10 月 26 日

--

假设你在进行代码更改时发现了一个需要立即修复的严重 bug。你试图切换分支,但由于当前功能分支有未提交的更改,切换失败。这些更改还没有准备好提交,但又太重要,不能丢弃。在这种需要快速切换上下文的情况下,git stash提供了一种优雅的解决方案,可以在不提交的情况下临时安全地存储未完成的代码。本文将探讨如何有效使用git stash

📍 1. 暂存

假设我们正在用笔和纸写一封信,但突然需要写另一封更紧急的信并寄出。我们的桌子上只能放一封信。既然已经花了一些时间写完目前的部分,把未完成的信丢掉太浪费了。与其扔掉它,我们可以把它放进一个安全的宝箱里,这样等我们完成更紧急的信后,可以拿出来继续写。这个方法让我们可以迅速开始写更紧急的信并快速寄出,同时又能保存我们写的其他信件内容。在这个类比中,未完成的信件就像是未提交的…

认识 NiceGUI:你即将成为最爱的 Python UI 库

原文:towardsdatascience.com/meet-the-nicegui-your-soon-to-be-favorite-python-ui-library-fb69f14bb0ac?source=collection_archive---------0-----------------------#2024-04-16

轻松快速地构建自定义网页应用

Youness MansarTowards Data Science Youness Mansar

·发布于 Towards Data Science ·阅读时间 8 分钟·2024 年 4 月 16 日

--

图片来源:Alexander SchimmeckUnsplash

认识 NiceGUI,一个简单的基于 Python 的 UI 框架,能够与您的网页浏览器或桌面应用无缝协作。无论你是在开发小型网页应用、仪表板,还是在玩机器人项目,NiceGUI 都通过其简单的界面和众多功能使之变得简单。

本文的目标是通过列出这个库的优缺点,并展示如何构建和部署一个 NiceGUI 应用,来说服你尝试它。(这不是一篇赞助文章,我只是喜欢这个库🙃)

Streamlit 与 NiceGUI:为什么要切换?

虽然 Streamlit 非常适合创建交互式应用,但处理事件和状态时可能会很棘手,特别是在较大的项目中。NiceGUI 则不同,它让你直接控制状态和交互,无需额外步骤或复杂的解决方法。

简单的状态管理

NiceGUI 使得状态管理变得简单。与 Streamlit 可能会意外重置状态不同,NiceGUI 保持状态稳定,无论是初始状态还是用户所做的更改。你可以使用回调函数以事件驱动的方式处理用户交互,而无需因页面完全刷新而感到烦恼,也不会丢失状态数据。

内存高效嵌入

原文:towardsdatascience.com/memory-efficient-embeddings-d637cba7f006?source=collection_archive---------2-----------------------#2024-01-01

使用一种新的嵌入层创建更小的模型

Dr. Robert KüblerTowards Data Science Dr. Robert Kübler

·发表于Towards Data Science ·13 分钟阅读·2024 年 1 月 1 日

--

图片来源:Kostiantyn Vierkieiev来自Unsplash

每当处理分类数据时,初学者通常会使用独热编码。这种方法通常可以,但如果你面对的是成千上万甚至数百万个类别,这种方法就变得不可行。其原因如下:

  1. 维度增加: 对于每个类别,你会得到一个额外的特征。这可能导致维度灾难。数据变得更加稀疏,模型可能会面临计算复杂度增加和泛化性能下降的问题。

  2. 语义丧失: 独热编码将每个类别视为独立的特征,忽略了类别之间潜在的语义关系。我们失去了原始类别变量中存在的有意义的关系。

这些问题出现在自然语言处理领域(我们有一堆词语)或推荐系统(我们有大量的客户和/或文章),可以通过嵌入技术来克服。然而,如果你有大量的嵌入,那么模型的内存需求可能会飙升,达到几个 GB。

在这篇文章中,我想向你展示几种减少内存占用的方法。其中一种方法来自一篇有趣的论文组合嵌入

使用 mergekit 合并大语言模型

原文:towardsdatascience.com/merge-large-language-models-with-mergekit-2118fb392b54?source=collection_archive---------0-----------------------#2024-01-08

轻松创建你自己的模型,无需 GPU!

Maxime LabonneTowards Data Science Maxime Labonne

·发表于 Towards Data Science ·阅读时间:11 分钟·2024 年 1 月 8 日

--

图片由作者提供

模型合并是一种将两个或更多 LLM 模型合并为单个模型的技术。这是一种相对较新且实验性的方法,用于以低成本(无需 GPU)创建新模型。模型合并效果出乎意料地好,并且产生了许多在Open LLM 排行榜上排名靠前的顶尖模型。

在本教程中,我们将使用mergekit库来实现。更具体地,我们将回顾四种合并方法,并提供配置示例。然后,我们将使用 mergekit 创建我们自己的模型,Marcoro14–7B-slerp,该模型在 Open LLM 排行榜(2024 年 1 月 2 日)上表现最佳。

代码可在GitHubGoogle Colab上找到。我推荐使用我的自动化笔记本来轻松运行 mergekit:🥱 LazyMergekit

特别感谢 Charles Goddard,mergekit 库的作者,感谢他审阅本文。

图片由作者提供

🤝 合并算法

在本节中,我们将重点介绍mergekit中目前实现的四种方法。请注意,还有其他方法,如线性插值任务算术。如果你对模型合并的论文感兴趣,推荐查看这个优秀的合集

1. SLERP

球面线性插值(SLERP)是一种在两个向量之间平滑插值的方法。它保持恒定的变化速率,并保留向量所处球面空间的几何属性。

有几个理由更倾向于使用 SLERP 而不是传统的线性插值。例如,在高维空间中,线性插值可能会导致插值向量的大小减小(即减少权重的尺度)。此外,权重的方向变化通常代表了更有意义的信息(如特征学习和表示),而不是变化的大小。

SLERP 的实现步骤如下:

  1. 将输入向量标准化为单位长度,确保它们表示的是方向而非大小。

  2. 使用它们的点积计算这些向量之间的角度。

  3. 如果向量几乎共线,则默认使用线性插值以提高效率。否则,SLERP 会根据插值因子tt=0 = 100%的第一个向量,t=1 = 100%的第二个模型)和向量之间的角度计算缩放因子。

  4. 这些因素用于加权原始向量,然后将其求和以得到插值向量。

SLERP 目前是最流行的合并方法,但它仅限于一次合并两个模型。仍然可以层次化地合并多个模型,如在Mistral-7B-Merge-14-v0.1中所示。

配置示例:

slices:
  - sources:
      - model: OpenPipe/mistral-ft-optimized-1218
        layer_range: [0, 32]
      - model: mlabonne/NeuralHermes-2.5-Mistral-7B
        layer_range: [0, 32]
merge_method: slerp
base_model: OpenPipe/mistral-ft-optimized-1218
parameters:
  t:
    - filter: self_attn
      value: [0, 0.5, 0.3, 0.7, 1]
    - filter: mlp
      value: [1, 0.5, 0.7, 0.3, 0]
    - value: 0.5
dtype: bfloat16

这是一个经典的 SLERP 配置,应用于两个模型的每一层。请注意,我们输入一个值梯度作为插值因子t。自注意力层和 MLP 层的参数将使用OpenPipe/mistral-ft-optimized-1218mlabonne/NeuralHermes-2.5-Mistral-7B的不同组合。其他层则是两个模型的 50/50 混合。

你可以在 Hugging Face Hub 上找到最终模型:mlabonne/NeuralPipe-7B-slerp

2. TIES

Yadav 等人论文中提出的TIES-Merging旨在高效地将多个任务特定的模型合并为一个多任务模型。它解决了模型合并中的两个主要挑战:

  • 模型参数中的冗余:它识别并消除任务特定模型中的冗余参数。通过关注微调过程中所做的变化,识别最重要的前 k%变化,并丢弃其余部分。

  • 参数符号之间的分歧:当不同模型对同一参数提出相反调整时,会产生冲突。TIES-Merging 通过创建一个统一的符号向量来解决这些冲突,表示所有模型中最主导的变化方向。

TIES-Merging 分为以下三个步骤:

  1. 修剪:通过仅保留最重要参数(密度参数)的部分,重置其余部分为零,从而减少任务特定模型中的冗余。

  2. 选择符号:通过基于累积大小的最主导方向(正或负)创建统一符号向量,解决不同模型之间的符号冲突。

  3. 不相交合并:平均与统一符号向量对齐的参数值,排除零值。

与 SLERP 不同,TIES 可以一次合并多个模型。

配置示例:

models:
  - model: mistralai/Mistral-7B-v0.1
    # no parameters necessary for base model
  - model: OpenPipe/mistral-ft-optimized-1218
    parameters:
      density: 0.5
      weight: 0.5
  - model: mlabonne/NeuralHermes-2.5-Mistral-7B
    parameters:
      density: 0.5
      weight: 0.3
merge_method: ties
base_model: mistralai/Mistral-7B-v0.1
parameters:
  normalize: true
dtype: float16

使用此配置,我们使用 Mistral-7B 作为基础模型来计算增量权重。我们合并相同的两个模型:mistral-ft-optimized-1218(50%)和NeuralHermes-2.5-Mistral-7B(30%),并进行归一化。这里的密度表示我们只保留每个模型 50%的参数(另一半来自基础模型)。

请注意,配置中的权重总和不等于 1,但normalize: true参数会自动在内部对其进行归一化。该配置灵感来自于OpenHermes-2.5-neural-chat-7b-v3–1–7B作者提供的参数。

您可以在 Hugging Face Hub 上找到最终模型:mlabonne/NeuralPipe-7B-ties

3. DARE

由 Yu 等人(2023 年)提出的DARE采用了类似于 TIES 的方法,主要有两个不同之处:

  • 剪枝:DARE 将微调后的权重随机重置为其原始值(即基础模型的值)。

  • 重新缩放:DARE 通过重新缩放权重来保持模型输出的期望值大致不变。它将两个(或多个)模型的重新缩放权重与基础模型的权重通过一个缩放因子相加。

Mergekit 对此方法的实现有两种形式:一种是带有 TIES 符号选择步骤(dare_ties),另一种是没有的(dare_linear)。

配置示例:

models:
  - model: mistralai/Mistral-7B-v0.1
    # No parameters necessary for base model
  - model: samir-fama/SamirGPT-v1
    parameters:
      density: 0.53
      weight: 0.4
  - model: abacusai/Slerp-CM-mist-dpo
    parameters:
      density: 0.53
      weight: 0.3
  - model: EmbeddedLLM/Mistral-7B-Merge-14-v0.2
    parameters:
      density: 0.53
      weight: 0.3
merge_method: dare_ties
base_model: mistralai/Mistral-7B-v0.1
parameters:
  int8_mask: true
dtype: bfloat16

在此配置中,我们基于 Mistral-7B 合并了三种不同的模型,使用了dare_ties。这次,我选择了权重之和为 1 的组合(权重和应该在 0.9 到 1.1 之间)。密度参数比论文中推荐的值(<0.5)稍高,但看起来它 consistently 给出了更好的结果(参见这个讨论)。

你可以在 Hugging Face Hub 上找到它,地址是mlabonne/Daredevil-7B。它也是本文中表现最好的合并模型,甚至超过了 Marcoro14–7B-slerp。

4. 透传

透传方法与之前的几种方法有显著不同。通过将来自不同大语言模型(LLM)的层级进行拼接,它可以生成具有独特参数数量的模型(例如,两个 7B 参数模型合并生成 9B 模型)。这些模型通常被社区称为“弗兰肯合并”或“弗兰肯斯坦模型”。

这个技术非常实验性,但它成功地创造了令人印象深刻的模型,比如使用两个 Llama 2 70B 模型合并的goliath-120b。最近发布的SOLAR-10.7B-v1.0也采用了相同的理念,称为深度向上扩展,详见他们的论文这里

配置示例:

slices:
  - sources:
    - model: OpenPipe/mistral-ft-optimized-1218
      layer_range: [0, 32]
  - sources:
    - model: mlabonne/NeuralHermes-2.5-Mistral-7B
      layer_range: [24, 32]
merge_method: passthrough
dtype: bfloat16

结果的弗兰肯合并将拥有来自第一个模型的所有 32 层,以及来自第二个模型的 8 层。这创建了一个总共 40 层和 8.99B 参数的弗兰肯合并模型。此配置灵感来自于GML-Mistral-merged-v1

你可以在 Hugging Face Hub 上找到最终模型,地址是mlabonne/NeuralPipe-9B-merged

💻 合并你自己的模型

在这一部分中,我们将使用 mergekit 加载合并配置,运行它,并将合并后的模型上传到 Hugging Face Hub。

首先,我们直接从源代码安装 mergekit,如下所示:

!git clone https://github.com/cg123/mergekit.git
!cd mergekit && pip install -q -e .

在接下来的代码块中,我们加载合并配置文件(YAML 格式)。我们还指定了合并后模型的名称,以供以后使用。你可以将上一节中的任何配置复制/粘贴到这里。

这次,我们将使用两个不同的模型:Marcoroni-7B-v3Mistral-7B-Merge-14-v0.1,并使用 SLERP 方法将它们合并。我们将配置保存为 yaml 文件,以便作为输入使用在合并命令中。

import yaml

MODEL_NAME = "Marcoro14-7B-slerp"
yaml_config = """
slices:
  - sources:
      - model: AIDC-ai-business/Marcoroni-7B-v3
        layer_range: [0, 32]
      - model: EmbeddedLLM/Mistral-7B-Merge-14-v0.1
        layer_range: [0, 32]
merge_method: slerp
base_model: AIDC-ai-business/Marcoroni-7B-v3
parameters:
  t:
    - filter: self_attn
      value: [0, 0.5, 0.3, 0.7, 1]
    - filter: mlp
      value: [1, 0.5, 0.7, 0.3, 0]
    - value: 0.5
dtype: bfloat16

"""

# Save config as yaml file
with open('config.yaml', 'w', encoding="utf-8") as f:
    f.write(yaml_config)

我们使用以下参数运行合并命令:

  • --copy-tokenizer 用于从基础模型复制分词器

  • --allow-crimes--out-shard-size 用于将模型分割成较小的分片,以便在低内存的 CPU 上计算。

  • --lazy-unpickle 以启用实验性的懒加载解包器,从而减少内存使用。

此外,一些模型可能需要--trust_remote_code标志(Mistral-7B 不需要此标志)。

这个命令将下载合并配置中列出的所有模型的权重,并运行选定的合并方法(大约需要 10 分钟)。

# Merge models
!mergekit-yaml config.yaml merge --copy-tokenizer --allow-crimes --out-shard-size 1B --lazy-unpickl

模型现在已经合并并保存在merge目录中。在上传之前,我们可以创建一个包含所有可复现性所需信息的 README 文件。以下代码块定义了一个 Jinja 模板,并自动填充来自合并配置的数据。

!pip install -qU huggingface_hub

from huggingface_hub import ModelCard, ModelCardData
from jinja2 import Template

username = "mlabonne"

template_text = """
---
license: apache-2.0
tags:
- merge
- mergekit
- lazymergekit
{%- for model in models %}
- {{ model }}
{%- endfor %}
---

# {{ model_name }}

{{ model_name }} is a merge of the following models using [mergekit](https://github.com/cg123/mergekit):

{%- for model in models %}
* [{{ model }}](https://huggingface.co/{{ model }})
{%- endfor %}

## 🧩 Configuration

```yaml

{{- yaml_config -}}

```py
"""

# Create a Jinja template object
jinja_template = Template(template_text.strip())

# Get list of models from config
data = yaml.safe_load(yaml_config)
if "models" in data:
    models = [data["models"][i]["model"] for i in range(len(data["models"])) if "parameters" in data["models"][i]]
elif "parameters" in data:
    models = [data["slices"][0]["sources"][i]["model"] for i in range(len(data["slices"][0]["sources"]))]
elif "slices" in data:
    models = [data["slices"][i]["sources"][0]["model"] for i in range(len(data["slices"]))]
else:
    raise Exception("No models or slices found in yaml config")

# Fill the template
content = jinja_template.render(
    model_name=MODEL_NAME,
    models=models,
    yaml_config=yaml_config,
    username=username,
)

# Save the model card
card = ModelCard(content)
card.save('merge/README.md')

现在我们有了模型卡片,可以将整个文件夹推送到 Hub。

from google.colab import userdata
from huggingface_hub import HfApi

username = "mlabonne"

# Defined in the secrets tab in Google Colab
api = HfApi(token=userdata.get("HF_TOKEN"))

api.create_repo(
    repo_id=f"{username}/{MODEL_NAME}",
    repo_type="model"
)
api.upload_folder(
    repo_id=f"{username}/{MODEL_NAME}",
    folder_path="merge",
)

该模型现在可以在 Hugging Face Hub 上获取,链接为mlabonne/Marcoro14–7B-slerp。在另一个笔记本中,我们可以使用以下代码在免费的 T4 GPU 上尝试该模型:

!pip install -qU transformers accelerate

from transformers import AutoTokenizer
import transformers
import torch

model = "mlabonne/Marcoro14-7B-slerp"
messages = [{"role": "user", "content": "What is a large language model?"}]

tokenizer = AutoTokenizer.from_pretrained(model)
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    torch_dtype=torch.float16,
    device_map="auto",
)

outputs = pipeline(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)

我们提出了“什么是大语言模型?”这个问题,并获得了以下输出:

大语言模型是一种人工智能(AI)系统,它通过大量的文本数据进行训练。其设计目的是理解和生成类似人类的语言,预测句子或文档中接下来可能出现的单词或短语。这些模型使用复杂的算法和神经网络架构,从数据中学习,并随着时间的推移提高其表现。一些著名的大语言模型包括 OpenAI 的 GPT-3 和 Google 的 BERT。

一切看起来不错,但我们需要更全面的评估。对于这种通用模型,有一些有趣的基准测试:

  • Chatbot Arena,它根据人类投票编制了一个基于 Elo 的 LLM 排行榜。

  • MT-bench(同一个链接),它使用 GPT-4 作为裁判,基于一组多轮问题对模型的回答进行评分。

  • NousResearch 基准套件,它汇集了四个基准测试:AGIEval、GPT4ALL、TruthfulQA 和 Bigbench。GPT4ALL 本身包括 HellaSwag、OpenBookQA、Winogrande、ARC-Easy、ARC-Challenge、BoolQ 和 PIQA。

  • Open LLM 排行榜,它汇集了六个基准测试:ARC、HellaSwag、MMLU、Winogrande、GSM8K 和 TruthfulQA。

不幸的是,我们无法将模型提交到 Chatbot Arena。相反,我选择使用 Open LLM 排行榜和 NousResearch 基准测试来评估它。

我将我们的模型提交到了Open LLM 排行榜(“🚀 Submit here!”标签)。如介绍所示,它在排行榜中排名为最佳 7B 参数模型。以下是完整的结果:

作者提供的图像

Open LLM Leaderboard 的问题在于这些基准是公开的。这意味着人们可以在测试数据上训练 LLMs 以获得更好的结果。通过合并最佳模型,我们也污染了自己的结果。可以安全地假设Marcoro14–7B-slerp 已被污染,并且本次合并中使用的某些模型可能已在测试集上进行过训练。如果您想创建最好的模型,而不是操控排行榜,我建议仅使用非合并模型来创建您自己的合并模型。

这就是为什么我们不想仅依赖 OpenLLM Leaderboard。对于 NousResearch 基准测试套件,我使用了🧐 LLM AutoEval通过简单的 Colab 笔记本自动计算分数。以下是与优秀的OpenHermes-2.5-Mistral-7B相比的结果:

图片由作者提供

我们在每个基准测试上都取得了显著的提升。请注意,NousResearch 基准测试套件与 Open LLM Leaderboard 共享一些任务:ARC-Challenge、TruthfulQA、HellaSwag 和 Winogrande。据我所知,Bigbench 是唯一一个 100% 不同的基准(如果不是,请随时与我联系)。然而,我们在此合并中使用的某些模型仍然可能在 Bigbench 上进行了训练。

结论

在本文中,我们介绍了将 LLMs 合并的四种不同方法。我们详细讲解了 SLERP、TIES、DARE 和 passthrough 的工作原理,并提供了配置示例。最后,我们使用 mergekit 运行 SLERP,创建了Marcoro14–7B-slerp,并将其上传至 Hugging Face Hub。我们在两个基准测试套件上取得了优秀的表现:Open LLM Leaderboard(表现最好的 7B 模型)和 NousResearch。如果您想创建自己的合并模型,我推荐使用我的自动化笔记本🥱 LazyMergekit

另一种合并多个模型的方法是将它们合并在 Mixture of Experts (MoE) 架构中。在下一篇文章中,我们将详细讨论如何执行此操作,并创建我们自己的类似 Mixtral 的模型。如果您喜欢本文,请在 Medium 和 Twitter 上关注我@maximelabonne

通过点击一次,了解更多关于机器学习的信息并支持我的工作——在这里成为 Medium 会员:

[## 通过我的推荐链接加入 Medium — Maxime Labonne

作为 Medium 会员,您的一部分会员费用将用于支持您阅读的作家,同时您可以完全访问每个故事…

medium.com](https://medium.com/@mlabonne/membership?source=post_page-----2118fb392b54--------------------------------)

使用 SLERP 合并标记以加速 LLM 推理

原文:towardsdatascience.com/merging-tokens-to-accelerate-llm-inference-with-slerp-38a32bf7f194?source=collection_archive---------9-----------------------#2024-04-19

我们可以通过使用 SLERP 合并连续的标记对,显著加速 LLM 下一标记的生成,从而减少进行完整预测所需的计算能力。

Samuel ChaineauTowards Data Science Samuel Chaineau

·发表于 Towards Data Science ·6 分钟阅读·2024 年 4 月 19 日

--

图片由 Martin Martz 提供,来源于 Unsplash

TL;DR:

本文介绍了一种通过使用球面线性插值(SLERP)合并标记来加速大型语言模型(LLM)推理的新方法。通过减少序列长度而保持质量,这项技术在 LLM 推理中提供了显著的速度提升,解决了较长序列带来的计算挑战。该方法仍然处于初步阶段,但揭示了 LLM 的双重世界:一个用于训练,另一个用于预测。

背景:

LLM 已经通过展现出卓越的生成能力,彻底改变了自然语言处理任务。然而,它们的有效性是有代价的——计算资源。随着 LLM 处理更长的序列,变换器计算的二次扩展变得越来越难以承受。传统的缓解方法,如缓存和量化,存在一定的局限性。因此,迫切需要创新的方法来加速 LLM 推理,而不至于过多损害质量。

当前在推理过程中生成 token 的方法是一种蛮力法,本质上是训练方法的转置。虽然这种方法在训练中已被证明有效,但在推理任务中可能并不是最有效的。因此,有必要开发一种专门用于推理过程中生成 token 的新方法,从而优化过程并进一步提高大语言模型(LLM)的效率。这突显了探索替代技术以应对 LLM 推理所面临的计算挑战的重要性。

最近,mergekit 库提出了使用 SLERP 方法合并网络权重,这通常能得到更好的结果。受到这一工作的启发,我决定尝试是否可以在一个序列中合并 token,从而生成一个更小的序列,以便在预测下一个 token 时处理。

普通生成与合并生成的对比

使用 SLERP 合并序列:

提出的这种方法涉及修改 LLM 的前向传播过程,通过球面线性插值(SLERP)合并 token,这一技术借鉴自计算机图形学和动画制作。与简单的平均化技术不同,SLERP 保留了 token 维度的球面特性,提供了更为精细的插值。合并过程包括几个步骤,用于高效地压缩输入序列:

序列长度调整

最初,输入序列会根据其长度进行调整:

  • 长度小于 3 的序列保持不变。

  • 对于奇数长度的序列,添加两个空 token,一个放在开头,一个放在结尾

  • 偶数长度的序列会在倒数第二个位置添加一个空 token。

这样做的目的是确保上下文中的第一个和最后一个 token 能够得到保留。

对形成:

调整后的序列将被格式化为一对对连续的 token。这一过程为聚合做好准备。

使用 SLERP 进行聚合:

每对 token 都会通过 SLERP 进行聚合,实际上将序列长度减半(并不是完全减半,因为我们会添加并保留一些额外的 token)。SLERP 对代表连续 token 的两个向量进行插值,从而创建一个新的向量。

为了高效地完成这一任务,我重新创建了所有 SLERP 函数,并使用原生 PyTorch 实现。然而,代码可能尚未经过优化。

层次切割和提示保留:

合并过程可以在模型架构的不同层次上进行,这被称为“层次切割”。此外,为了保留提示的完整性,可以指定序列的开始部分和/或结束部分保持不变。这在基于指令的模型中尤为有用,因为提示的开始部分应该始终被记住。

这种创新方法为解决 LLM 推理中相关的计算挑战提供了一种微妙的解决方案,承诺在不牺牲质量或准确性的情况下大幅提高速度。

合并序列的简单示例

这意味着什么?

具体来说,在 LLM 中,前向调用的输入是一个形状为(batch_size,序列长度)的令牌序列。嵌入层生成一个形状为(batch size,序列长度,维度)的序列。每个注意力模块将这个序列作为输入。在给定的注意力层中,你可以合并令牌,生成一个形状为(batch size,k,维度)的序列,其中 k 是压缩后的序列长度。选择在哪一层应用此操作就是“层级截止”。

接下来的注意力模块将不再需要计算一个(序列长度,序列长度)的注意力矩阵,而是一个更小的矩阵,因为 k 严格小于原始序列长度。

因此,合并可以发生在模型架构的不同层级。这个参数被称为“层级截止”。此外,为了确保提示语不会完全合并,你可以定义序列的部分部分在开始和/或结束时保持不变。对于基于 Instruct 的模型,这种方式更加高效,因为提示的起始部分应该始终被记住。

这种方法论的一个缺点是它严重依赖于所使用模型的前向传播过程,这要求你根据所选择的模型仔细重写“合并”过程。另一个缺点是每一步都需要重新计算注意力掩码,并可能需要重新计算位置嵌入。

结果:

对 Mistral 7B Instruct V0.2 模型进行的实验展示了有前景的结果。通过比较基础模型与不同层级截止的各种合并模型之间的预测,可以观察到,合并令牌对预测质量的影响不大。此外,合并模型在推理时间上表现出了显著的加速,特别是在较浅的层级上。该技术还展示了其在处理更长序列时的有效性,使其适用于各种应用场景。

在不同层级截止和序列长度下,合并推理模型的预测令牌与基础模型的准确性对比

对于不同的序列长度和层级截止,合并推理模型相比基础模型的加速倍数

一个缺点是我未能成功使前向调用达到最优。因此,可能通过重新思考过程,能找到许多优化的空间。

我还在 AlpacaEval 数据集上测试了 Mistral Instruct v0.2 的合并版本。我在第 20 个注意力模块上进行合并。结果非常令人鼓舞,因为该模型超过了 Falcon 7B、Gemma 7B 和 nous-hermes-13b。它表明,未重新思考位置编码的合并模型生成的平均令牌数增加了 600 个,说明该模型的表现更好,更具表现力。我尝试重新实现位置编码过程,但没有成功。

从排行榜中提取

总结来说,使用 SLERP 合并 tokens 是应对 LLM 推理计算挑战的一个强有力的解决方案。通过在速度和质量之间找到平衡,这种方法几乎只是重写了前向循环。

使用方法:

我准备了一个仓库,包含一个简单的笔记本,供大家体验:github.com/samchaineau/llm_slerp_generation

使用一个新类,适配了前向调用,你可以轻松地将 LLM 传递到生成管道,并在你的数据集上使用它。到目前为止,我的实验仅限于 Mistral 7B 模型,但我希望将其扩展到其他架构,看看性能是否能保持。

所有资源都已经上传,如果你希望在其他 LLM 上进行测试,可以随时联系我。

结论:

使用 SLERP 技术合并 tokens 应该被探索用于加速 LLM 推理。随着进一步优化和探索,它有潜力提高自然语言处理任务的效率和可扩展性。

如果你从事 AI 领域工作并希望将其提升到下一个水平:欢迎联系我!

Github 链接:github.com/samchaineau/llm_slerp_generation

HuggingFace 个人主页:huggingface.co/samchain

相关且启发性的作品:

除非另有说明,所有图片均由作者提供。

Meta Llama 3 优化的 CPU 推理与 Hugging Face 和 PyTorch

原文:towardsdatascience.com/meta-llama-3-optimized-cpu-inference-with-hugging-face-and-pytorch-9dde2926be5c?source=collection_archive---------2-----------------------#2024-04-19

使用 Nightcafe 创建 — 图片由作者提供

了解如何在部署 Meta* Llama 3 到 CPU 时减少模型延迟

Eduardo AlvarezTowards Data Science Eduardo Alvarez

·发表于Towards Data Science ·阅读时长 7 分钟·2024 年 4 月 19 日

--

万众期待的Meta 第三代 Llama 发布终于来了,我希望确保你知道如何优化地部署这个最先进(SoTA)的 LLM。在本教程中,我们将重点介绍执行仅权重量化(WOQ),以压缩 8B 参数模型并提高推理延迟,但首先,让我们讨论一下 Meta Llama 3。

Llama 3

到目前为止,Llama 3 家族包括从 8B 到 70B 参数不等的模型,未来还会推出更多版本。模型配有宽松的 Meta Llama 3许可协议,鼓励在接受使用条款之前仔细审阅。这标志着 Llama 模型家族和开源 AI 的一个激动人心的篇章。

架构

Llama 3 是一个基于解码器单一变压器的自回归 LLM。与 Llama 2 相比,Meta 团队做出了以下显著改进:

  • 采用分组查询注意力(GQA),提高推理效率。

  • 优化的分词器,词汇表包含 128K 个标记,旨在更高效地编码语言。

  • 在一个 15 万亿标记的数据集上训练,这个数据集比 Llama 2 的训练数据集大了 7 倍,并且包含了 4 倍更多的代码。

以下图(图 1)为 print(model) 的结果,其中 modelmeta-llama/Meta-Llama-3–8B-Instruct。在该图中,我们可以看到模型由 32 个 LlamaDecoderLayer 组成,这些层包含了 Llama Attention 自注意力组件。此外,模型还包括 LlamaMLP、LlamaRMSNorm 和一个线性头。我们希望在 Llama 3 研究论文发布后能够了解更多信息。

图 1. print(model) 的输出,展示了 llama-3–8B-instruct 架构中各层的分布 — 图片来源:[作者]

语言建模性能

该模型已在多个行业标准的语言建模基准测试上进行评估,如 MMLU、GPQA、HumanEval、GSM-8K、MATH 等。在本教程中,我们将回顾“指令调优模型”的表现(图 2)。这些数据的最显著特点是,Llama 3 8B 参数模型在所报告的基准测试中超越了 Llama 2 70B 达 62% 到 143%,同时该模型的规模却比 Llama 2 小了 88%!

图 2. Llama 3 指令模型在 MMLU、GPQA、HumanEval、GSM-8K 和 MATH LLM 基准测试中的性能指标汇总。— 图片来源:作者

增强的语言建模性能、宽松的许可协议以及此最新 Llama 代代的架构效率标志着生成式 AI 领域一个令人激动的全新篇章的开始。让我们一起探索如何优化 Llama 3 的推理过程,以便在 CPU 上实现可扩展、低延迟的部署。

使用 PyTorch 优化 Llama 3 推理

上一篇文章中,我讨论了模型压缩和总体推理优化在开发基于 LLM 的应用中的重要性。在本教程中,我们将重点介绍如何对 meta-llama/Meta-Llama-3–8B-Instruct 应用仅权重量化(WOQ)。WOQ 提供了性能、延迟和准确性之间的平衡,并支持量化为 int4 或 int8。WOQ 的一个关键步骤是解量化,将 int4/int8 权重转换回 bf16,再进行计算。

图 3. 简单示意图展示了仅权重量化,其中预量化的权重为橙色,量化后的权重为绿色。请注意,这表示初步量化为 int4/int8 以及计算步骤中的解量化为 fp16/bf16。— 图片来源:作者

环境设置

你将需要大约 60GB 的 RAM 来执行 Llama-3-8B-Instruct 的 WOQ。这包括约 30GB 用于加载完整模型,以及约 30GB 用于量化过程中的峰值内存。WOQ Llama 3 只会消耗大约 10GB 的 RAM,这意味着我们可以通过将完整模型从内存中释放来释放大约 50GB 的 RAM。

你可以在 Intel® Tiber® 开发者云 免费的 JupyterLab环境中运行本教程。该环境提供了一款第四代 Intel*® Xeon® CPU,拥有 224 个线程和 504GB 内存,足以运行此代码。

如果在你自己的 IDE 中运行,你可能需要处理额外的依赖项,比如安装 Jupyter 和/或配置 conda/python 环境。在开始之前,请确保你已经安装了以下依赖项。

intel-extension-for-pytorch==2.2
transformers==4.35.2
torch==2.2.0
huggingface_hub

访问和配置 Llama 3

你需要一个 Hugging Face*帐户来访问 Llama 3 的模型和分词器。

为此,从设置菜单中选择“Access Tokens”(图 4),并创建一个令牌。

图 4. Hugging Face 令牌配置控制台快照 — 图片由作者提供

复制你的访问令牌并将其粘贴到在运行以下代码后在 Jupyter 单元格中生成的“Token”字段中。

from huggingface_hub import notebook_login, Repository

# Login to Hugging Face
notebook_login()

访问meta-llama/Meta-Llama-3–8B-Instruct,并在提供你的信息并提交 Llama 3 访问请求之前仔细评估条款和许可。接受模型条款并提供你的信息是你自己的责任。

使用 WOQ 对 Llama-3–8B-Instruct 进行量化

我们将利用Intel® Extension for PyTorch*来应用 WOQ 于 Llama 3。这一扩展包含了针对 Intel 硬件的最新 PyTorch 优化。按照这些步骤对 Llama 3 模型进行量化并执行推理:

  1. Llama 3 模型和分词器: 导入所需的包,并使用AutoModelForCausalLM.from_pretrained()AutoTokenizer.from_pretrained()方法加载 Llama-3–8B-Instruct 特定的权重和分词器。
import torch
import intel_extension_for_pytorch as ipex
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer

Model = 'meta-llama/Meta-Llama-3-8B-Instruct'

model = AutoModelForCausalLM.from_pretrained(Model)
tokenizer = AutoTokenizer.from_pretrained(Model)

2. 量化配置食谱: 配置 WOQ 量化食谱。我们可以将weight_dtype变量设置为所需的内存中数据类型,分别选择torch.quint4x2torch.qint8来表示 int4 和 in8。此外,我们可以使用lowp_model来定义去量化精度。现在,我们将其保持为ipex.quantization.WoqLowpMode.None,以保持默认的 bf16 计算精度。

qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
  weight_dtype=torch.quint4x2, # or torch.qint8
  lowp_mode=ipex.quantization.WoqLowpMode.NONE, # or FP16, BF16, INT8
)
checkpoint = None # optionally load int4 or int8 checkpoint

# PART 3: Model optimization and quantization
model_ipex = ipex.llm.optimize(model, quantization_config=qconfig, low_precision_checkpoint=checkpoint)

del model 

我们使用ipex.llm.optimize()来应用 WOQ,然后使用del model删除完整模型并释放大约 30GB 的 RAM。

  1. 提示 Llama 3: Llama 3 与 Llama 2 一样,拥有预定义的提示模板,用于其经过指令调优的模型。使用这个模板,开发者可以定义具体的模型行为指令,并提供用户提示和对话历史。
system= """\n\n You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. If you don't know the answer to a question, please don't share false information."""
user= "\n\n You are an expert in astronomy. Can you tell me 5 fun facts about the universe?"
model_answer_1 = 'None'

llama_prompt_tempate = f"""
<|begin_of_text|>\n<|start_header_id|>system<|end_header_id|>{system}
<|eot_id|>\n<|start_header_id|>user<|end_header_id|>{user}
<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>{model_answer_1}<|eot_id|>
"""

inputs = tokenizer(llama_prompt_tempate, return_tensors="pt").input_ids

我们提供所需字段,然后使用分词器将整个模板转换为模型的标记。

  1. Llama 3 推理: 在文本生成方面,我们利用TextStreamer生成实时推理流,而不是一次性打印整个输出。这使得读者能够获得更加自然的文本生成体验。我们将配置好的流传递给model_ipex.generate()以及其他文本生成参数。
streamer = TextStreamer(tokenizer,skip_prompt=True)

with torch.inference_mode():
    tokens = model_ipex.generate(
        inputs,
        streamer=streamer,
        pad_token_id=128001,
        eos_token_id=128001,
        max_new_tokens=300,
        repetition_penalty=1.5,
)

运行此代码后,模型将开始生成输出。请记住,这些是未经过滤和未加保护的输出。对于实际应用场景,你需要进行额外的后处理考虑。

图 5. 在 Intel Tiber Developer Cloud 的 JupyterLab 环境中运行 int4 模式下压缩的 Llama-3–8B-Instruct 流式推理——Gif 由作者提供

就这样。不到 20 行代码,你就能在这个生态系统中获得最新 SoTA LLM 的低延迟、CPU 优化版本。

部署时的注意事项

根据你的推理服务部署策略,有一些事项是你需要考虑的:

  • 如果在容器中部署 Llama 3 实例,WOQ 将提供更小的内存占用,并允许你在单一硬件节点上提供多个推理服务实例。

  • 在部署多个推理服务时,你应该优化每个服务实例分配的线程和内存。为后台进程留出足够的额外内存(约 4GB)和线程(约 4 个线程)。

  • 考虑将 WOQ 版本的模型保存并存储在模型注册表中,以避免在每个实例部署时重新量化模型。

结论与讨论

Meta 的 Llama 3 LLM 家族相比前几代有了显著的改进,并且提供了多种配置(更多配置即将推出)。在本教程中,我们探讨了通过仅权重量化(WOQ)来增强 CPU 推理,这是一种在最小化准确性影响的同时减少延迟的技术。

通过将以性能为导向的新一代 Llama 3 LLM 与像 WOQ 这样的优化技术结合,开发者可以为GenAI应用解锁新的可能性。这个组合简化了硬件需求,使 LLM 能够在新旧系统中集成并获得高保真、低延迟的结果。

接下来可以尝试的一些令人兴奋的事情是:

  1. 实验不同的量化级别: 你应该测试 int4 和 int8 量化,以确定在性能和准确性之间的最佳平衡,适应你的特定应用。

  2. 性能监控: 持续评估 Llama 3 模型在不同真实场景中的性能和准确性至关重要,以确保量化保持预期的有效性。

  3. 测试更多的 Llama 模型: 探索整个 Llama 3 系列,并评估 WOQ 和其他 PyTorch 的量化方法的影响。

感谢阅读!别忘了关注 我的个人资料以获取更多类似文章

其他名称和品牌可能被声明为他人的财产。

在 PySpark 中生成合成描述性数据

原文:towardsdatascience.com/methods-for-generating-synthetic-descriptive-data-c6678cc10aff?source=collection_archive---------16-----------------------#2024-01-04

使用各种数据源类型快速生成人工数据集中的文本数据。

Matt CollinsTowards Data Science Matt Collins

·发布于Towards Data Science ·7 分钟阅读·2024 年 1 月 4 日

--

使用 DALL-E 3 生成的图像

上一篇文章中,我们探讨了如何在合成的 PySpark DataFrame 中创建列之间的多对一关系。这个 DataFrame 只包含外键信息,并未生成可能在示范数据集中有用的文本信息。

对于任何希望填充人工数据集的人来说,您很可能想要生成描述性数据,如产品信息、位置详情、客户人口统计等。

在这篇文章中,我们将深入探讨几种能够以低成本和低努力生成合成文本数据的来源,并使用这些技术生成包含客户信息的 DataFrame。

为什么要创建合成数据集?

合成数据集是展示你的数据产品(如网站或分析平台)的一种绝佳方式。它允许用户和利益相关者与示例数据互动,进行有意义的分析,而不会侵犯任何关于敏感数据的隐私问题。

它对于探索机器学习算法也非常有用,允许数据科学家在真实数据有限的情况下训练模型。

性能测试数据工程管道活动是合成数据的另一个重要应用场景,它为团队提供了通过基础设施推送大规模数据的能力,帮助识别设计中的弱点,以及基准测试运行时间。

就我而言,我目前正在创建一个示例数据集,用于在高负载下对一些 Power BI 功能进行性能测试,稍后我将撰写相关文章。

数据集将包含销售数据,包括交易金额和其他描述性特征,如商店位置、员工姓名和客户电子邮件地址。

随机字符

从简单的开始,我们可以使用一些内建功能来生成随机文本数据。导入randomstring Python 模块后,我们可以使用以下简单的函数创建所需长度的随机字符串。

作者提供的图片:几个结果示例的截图

我们可以多次运行此函数,生成足够的数据用于我们的列,并将信息追加到列表中。

我们将在本文稍后部分审视如何将数据添加到 DataFrame 中。

优势与限制

这种数据生成方式非常通用,在示例数据集中的应用有限。尽管如此,它可以与其他字符串生成技术(如串联)结合使用,在付出很少努力的情况下提供更多价值。

以下是随机用户名的示例,用户名由名字作为基础,后接下划线和随机字符:

作者提供的图片:显示用随机字符串附加生成的多个用户名的截图。

这种数据可能有用的示例数据点包括:

  • 电子邮件地址

  • 密码

  • 产品代码

  • 用户名

APIs

APIs 是一个很好的信息来源,在构建真实的分析平台时,可能会作为数据源使用!

然而,也有许多 API 端点值得查询,用于通用的占位符数据,这些数据可以通过代表性的不同主题(如汇率)的数据,为演示仪表盘提供更多意义。

如果我们想要获取用户所在国家的地理数据,可以通过rest countries API 接口生成。这个 API 是免费的,无需注册即可开始使用。利用requests Python 模块,我们可以非常快速地获得国家列表。

作者提供的图片:print(sample_countries)语句,展示了打乱结果的示例。

注意:你应始终检查输出结果——这个 API 请求返回 250 条结果,超过了现实中国家的数量!

作者提供的图片:pprint(sorted(countries_list))语句的片段,展示了与美国相关的多个“国家”。

优势与限制

从 API 获取数据的复杂性和安全要求可能有所不同,这可能会使人望而却步。数据请求的方式和返回的数据格式也可能不同。文档由提供者自身生成,这也可能成为一个值得考虑的瓶颈。

此时值得一提的是,开发者已经创建了各种包,以简化与 API 的交互——我们将在下一节讨论第三方 Python 库!

说到这些,有大量丰富的数据可供你提取,可以用于面向客户的数据产品,也可用于演示目的。

第三方包

也有一些非常棒的现成包供我们使用。这些包减少了我们寻找数据源、处理和格式化数据以便使用的繁重工作。

Faker就是这样一个例子,它能够生成名字、地址等。

按照你的方便安装、导入并使用这样的包!

作者提供的图片:Faker 生成的多个名字的截图

好处与局限性

由于其他开发者已经为你完成了大量的前期工作,像Faker这样的包可以非常高效且低成本(无论是价格还是时间)。

仅凭Faker,我们已经能够直观地生成通用且有用的用户数据。还有其他库可供使用(通常与常见 API 交互),帮助开发者轻松地将有用数据引入到他们的应用程序中。

屏障可能存在于你所在组织关于使用第三方包的规定、包的许可以及包本身的可维护性,因此请注意可能的限制。

ChatGPT

不得不提的是,忽视大型语言模型(LLMs)是不合适的,因为它们是数据生成的宝贵资产。

一种简单的方法是让 ChatGPT(或同类工具)为你生成一份数据点列表,比如可能的客户姓名。

作者提供的图片:向 ChatGPT 请求一个名字列表。

我们可以更进一步,请求 LLM 帮助构建随机变量函数本身。你可以包括数据源的详细信息,甚至请求 ChatGPT 编写一个与您找到的 API 交互的函数。

这里的一个例子是请求帮助使用bored API,提取活动字段并返回供你使用:

作者提供的图片:向 ChatGPT 请求一个随机活动。

作者提供的图片:来自 bored API 的 ChatGPT 生成的函数示例结果。

注意:始终检查结果输出,确保 LLM 生成的功能按预期工作。

好处与局限性

生成性人工智能在快速数据生成方面展现了其优势(尤其是在内容创作方面,正如这份 Gartner 报告所示),我预计明年在填充示例数据集时将看到更多务实的方法。因此,它使得大规模创建各种数据点对不同背景和技术水平的用户都变得可行。

话虽如此,LLM(大语言模型)的可靠性和一致性也在同时受到广泛研究。根据您需要填充的数据类型,这可能会引入严格的数据质量和匿名性检查,从而决定这是否是特定用例的正确工具。

将其放入 DataFrame 中

我们在前几节中采用的方法是生成数据,准备好可以处理到感兴趣的 DataFrame 列中。

在这个阶段,有几种方式可以实现合成数据。UDF(用户定义函数)是一种简洁的方法,可以对每一行的列应用一个函数,并有助于定义唯一性。

Databricks Labs Data Generator是另一个优秀的库,它通过分布式计算加速了数据生成过程。

通过操控生成要填充的值列表的方式,我们可以使用withColumn PySpark 函数,并将其与我们的列表数据结合,以极快的速度生成文本列。

对于我们的测试案例,我们将结合使用这两种方法。我们可以使用 Databricks Labs Data Generator 创建 DataFrame 框架,填充 100 行数据,其中包含一个唯一的Id列,以及我们可能不要求唯一的值,例如国家

对于我们希望保持唯一性的其他列,例如全名用户名,我们可以利用 UDF 函数,使用 id 列从我们创建的列表中访问唯一名称。

综合来看,我们可以创建我们的 DataFrame:

作者提供的图片:DataFrame

快速的数据概览显示,我们的全名用户名列具有完全唯一的值,符合预期!

作者提供的图片:显示“全名”和“用户名”列中 100 个唯一值的分类列的数据概览

结论

我们概述了多种快速生成文本合成数据的方法,使我们能够加速示例数据集的创建。

上述所有示例都可以扩展、优化并根据您的特定用例进行定制。

有没有我遗漏的技巧?你是用合成数据集做什么的?在评论中告诉我吧!

感谢阅读,一如既往,您可以在这里找到供您使用的代码。

如何评估搜索相关性和排名

原文:towardsdatascience.com/metrics-that-matter-a-simple-guide-to-search-ranking-evaluation-4030084c35b4?source=collection_archive---------8-----------------------#2024-05-30

优化搜索引擎的关键指标

Akchay SrivastavaTowards Data Science Akchay Srivastava

·发表于 Towards Data Science ·阅读时间:6 分钟·2024 年 5 月 30 日

--

图片来自 Markus WinklerUnsplash

目录

  1. 引言

  2. 精度@K

  3. 平均平均精度(MAP)

  4. 平均倒数排名(MRR)

  5. 标准化折扣累积增益(NDCG)

  6. 比较分析

  7. 摘要

  8. 参考文献

免责声明:本文表达的观点仅代表个人观点,并不一定反映我雇主或任何其他组织的立场。所有图片均由作者提供,除非另有说明。

1. 引言

确保用户能够快速高效地找到所需信息,对于成功的搜索体验至关重要。当用户能够快速、轻松地找到他们需要的内容时,这就转化为一种积极的体验。

此外,相关结果的排名位置也起着至关重要的作用——它们出现得越靠前,对用户的价值就越大。这转化为更高的用户参与度、转化率和整体网站满意度。

本文探讨了用于评估搜索相关性和排名的关键指标,帮助您优化搜索引擎并提供卓越的用户体验。

为了以实际方式展示搜索相关性的概念,假设用户在搜索引擎中搜索“意大利面菜肴”。为简化分析,我们将分析引擎返回的前五个结果。相关的结果用绿色表示,而不相关的结果用红色标出(参见图 1)。我们将使用 Rn 表示第 n 个结果。

图 1:一个有序的搜索结果列表

2. Precision@K

Precision@K 衡量前 K 个位置中有多少结果是相关的。我们为不同的 K 值计算 Precision,如图 2 所示。

Precision@K = 前 K 个位置中的相关结果数 / K

Precision@1 = 1/1

Precision@3 = 1/3

Precision@5 = 2/5

图 2:Precision@K

3. 平均平均精度(MAP)

MAP 考虑相关结果的排名顺序。

首先,计算这些相关结果位置的 Precision@K。然后,通过将这些相关结果位置的 Precision@K 相加并除以前 K 个结果中的相关项总数,得到平均 Precision@K。为了简便起见,在讨论中我们将偶尔使用 AP 来表示平均精度。

为了更深入地了解 MAP 如何评估排名效果,我们将通过三个不同的搜索查询进行说明。这些示例将突出展示结果的呈现顺序如何影响 MAP 得分。

图 3:查询 1 中每个相关结果的 Precision@K

AP@5_Query_1 = (Precision@1 + Precision@3 + Precision@5) / 3

AP@5_Query_1 = (1 + 0.67 + 0.6) / 3 = 0.76

图 4:查询 2 中每个相关结果的 Precision@K

AP@5_Query_2 = (Precision@1 + Precision@2 + Precision@5) / 3

AP@5_Query_2 = (1 + 1 + 0.6) / 3 = 0.87

图 5:查询 3 中每个相关结果的 Precision@K

AP@5_Query_3 = (Precision@3 + Precision@4 + Precision@5) / 3

AP@5_Query_3 = (0.33 + 0.5 + 0.6) / 3 = 0.47

查询 2 的结果表现出最高的 Average Precision@5,表明最相关的项目位于排名列表的前面。

MAP = 数据集中所有查询的平均精度。

MAP@5 = (AP@5_Query_1 + AP@5_Query_2 + AP@5_Query_3) / 查询总数

数据集的 MAP@5 = (0.76 + 0.87 + 0.47) / 3 = 0.7

此计算将所有查询视为同等重要。然而,如果某些查询更为关键,可以在 MAP 过程中使用不同的加权方法来优先考虑它们。

4. 平均倒数排名(MRR)

MRR 只考虑列表中第一个相关结果的排名。

K = 第一个相关结果的排名

倒数得分 = 1 / K

MRR 是多个查询的平均倒数得分。如果没有相关结果,则第一个相关结果的排名视为无限大,因此倒数得分为 0。

图 6:每个查询的倒数得分(蓝色)

相关结果的倒数得分是其排名的反函数。

数据集的 MRR = (0.5 + 1 + 0.33) / 3 = 0.61

5. 归一化折扣累计增益(NDCG)

NDCG 考虑了结果的分级相关性。每个结果的相关性由一个分数(也称为“等级”)表示。NDCG 的值通过比较搜索引擎返回的结果与假设的“理想”搜索引擎所返回结果的相关性来确定。

假设我们有一个 1 到 5 的相关性/评分尺度,其中 5 是最高分,1 是最低分。我们搜索“意大利面菜肴”并手动对搜索结果进行评分,如图 7 所示。在我们的示例中,R3 是最相关的结果,得分为 5。

图 7:按相关性分数排序的搜索结果列表

累积增益@5 = 4 + 1 + 5 + 1 + 3 = 14

累积增益不考虑排名顺序。

折扣累计增益@K = 施加了一个对数折扣,当相关项出现在排名列表的较低位置时,帮助赋予较低的增益,如图 8 所示。

图 8:DCG@K 公式

其中,rel(i)是第 i 个位置上结果的相关性分数。

DCG@K = 4/1 + 1/1.585 + 5/2 + 1/2.322 + 3/2.585 = 8.72

DCG 的绝对值取决于列表中的结果数量以及分配的相关性分数。为了解决这个问题,可以对 DCG 进行归一化处理。为了得到归一化的 DCG(NDCG),我们将 DCG 除以给定结果集的理想 DCG(IDCG),如图 9 所示。IDCG 考虑了相同的相关性分数,但假设按照这些结果的最佳排名顺序来计算 DCG。上述示例的最佳排名顺序是:R3 → R1 → R5 → R2 → R4。

IDCG@K = 5/1 + 4/1.585 + 3/2 + 1/2.322 + 1/2.585 = 9.83

图 9:NDCG@K 公式

NDCG@K = 8.72/9.83 = 0.88

NDCG 考虑了结果的分级相关性,提供了对搜索排名质量更为细致的理解。

6. 比较分析

除了上述指标之外,斯皮尔曼相关系数肯德尔τ距离也可以用来评估排名列表的相似性。为了衡量用户参与度,点击率(CTR)是一个关键指标,反映了在结果展示后点击该结果的用户百分比。有关这些指标的更多信息,请参考参考文献中的维基百科资源。

7. 总结

照片来源:Alexander SchimmeckUnsplash

在我们探索了四种不同的搜索质量评估指标之后,我们进行了比较分析,旨在了解每种方法的优缺点。这自然引出了一个关键问题:哪个指标最适合评估搜索引擎结果的相关性和排名?最佳的指标选择取决于你的具体需求。

要全面了解搜索引擎的质量,通常建议考虑这些指标的组合,而不是仅依赖单一的衡量标准。

如果你觉得这篇文章有用,请引用以下文献:

Srivastava, Akchay. (2024 年 5 月). 如何评估搜索相关性和排名。 https://towardsdatascience.com/metrics-that-matter-a-simple-guide-to-search-ranking-evaluation-4030084c35b4

8. 参考文献:

  1. en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)

  2. en.wikipedia.org/wiki/Mean_reciprocal_rank

  3. en.wikipedia.org/wiki/Kendall_tau_distance

  4. en.wikipedia.org/wiki/Discounted_cumulative_gain

  5. en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient

  6. web.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf

  7. www.coursera.org/lecture/recommender-metrics/rank-aware-top-n-metrics-Wk98r

  8. www.evidentlyai.com/ranking-metrics/ndcg-metric

  9. en.wikipedia.org/wiki/Inter-rater_reliability

  10. en.wikipedia.org/wiki/Click-through_rate

评估分类机器学习模型的指标

原文:towardsdatascience.com/metrics-to-evaluate-a-classification-machine-learning-model-f05f1facd569?source=collection_archive---------7-----------------------#2024-07-31

信用卡欺诈的案例研究

Lucas BragaTowards Data Science Lucas Braga

·发表于 Towards Data Science ·7 分钟阅读·2024 年 7 月 31 日

--

1\. Introduction
2\. How does a model make predictions
3\. Confusion Matrix
4\. Metrics to Evaluate Model Performance
5\. When to use what metrics

1. 引言

一旦我们训练了一个监督式机器学习模型来解决分类问题,我们会希望如果这就是我们的工作结束,我们可以直接丢给它新的数据。我们希望它能正确地对所有数据进行分类。然而,实际上,模型做出的并非所有预测都是正确的。在数据科学中有一句名言,由英国统计学家提出:“

“所有模型都是错误的;有些模型是有用的。”——CLEAR, James, 1976

那么,我们如何知道我们所拥有的模型有多好呢?简短的回答是,我们通过评估模型预测的准确性来判断。为此,我们可以使用多种指标。

2. 模型是如何进行预测的?即,模型是如何对数据进行分类的?

图片 1:模型进行预测

假设我们已经训练了一个机器学习模型,用来对信用卡交易进行分类,并决定该交易是否为欺诈交易。该模型将处理交易数据,并返回一个分数,这个分数可能是 0 到 1 之间的任意数字,例如:0.05、0.24、0.56、0.9875。对于本文,我们将默认阈值设为 0.5,这意味着如果模型给出的分数低于 0.5,则模型将该交易分类为非欺诈(这是模型的预测!)。如果模型给出的分数大于或等于 0.5,则模型将该交易分类为欺诈(这也是模型的预测!)。

实际操作中,我们并不会使用默认的 0.5 阈值。我们会尝试不同的阈值,以找出最适合优化模型性能的值,但这个讨论留到以后再进行。

3. 混淆矩阵

混淆矩阵是可视化分类模型性能的基本工具。它有助于理解预测的各种结果,包括:

  • 真正阳性 (TP)

  • 假阳性 (FP)

  • 假阴性 (FN)

  • 真正阴性 (TN)

让我们分解一下!

为了评估一个模型的有效性,我们需要将其预测与实际结果进行比较。实际结果也被称为“现实”。因此,模型可能将一笔交易分类为欺诈,而实际上,客户在同一笔交易中要求退款,声称他的信用卡被盗。

在这种情况下,模型正确地将交易预测为欺诈,属于真正阳性 (TP)

在欺诈检测的背景下,“阳性”类被标记为欺诈,而“阴性”类被标记为非欺诈。

另一方面,假阳性 (FP) 发生在模型也将一笔交易分类为欺诈,但在这种情况下,客户并未报告其信用卡有任何欺诈活动。因此,在这笔交易中,机器学习模型犯了一个错误。

真正阴性 (TN) 是当模型将交易分类为非欺诈,而实际上它确实不是欺诈。因此,模型做出了正确的分类。

假阴性 (FN) 是当模型将交易分类为非欺诈时,但实际上它是欺诈(客户报告了与该交易相关的信用卡欺诈活动)。在这种情况下,机器学习模型也犯了一个错误,但这是与假阳性不同类型的错误。

让我们看一下图 2

图 2:为欺诈分类的机器学习模型的混淆矩阵

让我们看一个不同的案例,可能更容易理解。设计了一项测试来判断病人是否感染了 COVID。见图 3。

图 3:COVID 测试的混淆矩阵

因此,对于每一笔交易,你可以检查它是 TP、FP、TN 还是 FN。你可以对数百万笔交易进行这种检查,并将 TP、FP、TN 和 FN 的计数写在一个 2x2 的表格中。这个表格也被称为混淆矩阵

假设你将模型对 100,000 笔交易的预测与其实际结果进行了比较,并得出了以下混淆矩阵(见图 4)。

图 4:混淆矩阵

4. 评估模型性能的指标

了解混淆矩阵后,我们就可以开始探索用于评估分类模型性能的指标。

精确度 = TP / (TP + FP)

它回答了这个问题:在所有预测中,正确预测的比例是多少?它反映了模型预测的欺诈案例中实际为欺诈的比例。

用简单的话来说:模型将其判定为欺诈时,实际是欺诈的比例是多少?

查看图像 4 中的混淆矩阵,我们计算精确度 = 76.09%,因为精确度 = 350 / (350 + 110)。

召回率 = TP / (TP + FN)

召回率也称为真正率(TPR)。它回答了这个问题:在所有实际的正类结果中,正确预测的比例是多少?

简单来说,模型在所有实际欺诈案件中,正确识别欺诈者的比例是多少?

使用图像 4 中的混淆矩阵,召回率 = 74.47%,因为召回率 = 350 / (350 + 120)。

警报率 = (TP + FP) / (TP + FP + TN + FN)

也称为阻塞率,该指标帮助回答问题:在所有预测中,正类预测占比是多少?

简单来说:模型预测为欺诈的情况占多少比例?

使用图像 4 中的混淆矩阵,警报率 = 0.46%,因为警报率 = (350 + 110) / (350 + 110 + 120 + 99420)。

F1 分数 = 2 x (精确度 x 召回率) / (精确度 + 召回率)

F1 分数是精确度和召回率的调和平均值。它是精确度和召回率之间的平衡度量,提供了一个综合评分来评估模型。

使用图像 4 中的混淆矩阵,F1 分数 = 75.27%,因为 F1 分数 = 2 * (76.09% * 74.47%) / (76.09% + 74.47%)。

准确率 = TP + TN / (TP + TN + FP + FN)

准确率帮助回答这个问题:在所有交易中,正确分类的交易占比是多少?

使用图像 4 中的混淆矩阵,准确率 = 99.77%,因为准确率 = (350 + 120) / (350 + 110 + 120 + 99420)。

图像 5:包含评估指标的混淆矩阵

5. 何时使用哪些指标

准确率是评估许多分类机器学习模型的常用指标。然而,准确率在目标变量不平衡的情况下效果不好。在欺诈检测的情况下,通常只有极小比例的数据是欺诈的;例如,在信用卡欺诈中,通常不到 1%的交易是欺诈交易。因此,即使模型说所有交易都是欺诈的(这将是非常错误的),或者说所有交易都不是欺诈的(这同样也是非常错误的),模型的准确率仍然可能超过 99%。

那么在这些情况下该怎么做呢?精确度、召回率和警报率。这些通常是能提供模型性能良好视角的指标,即使数据不平衡。具体使用哪个指标可能取决于你的利益相关者。我曾与一些利益相关者合作,他们表示,无论做什么,请保持至少 80%的精确度。因此,在这种情况下,利益相关者非常关注用户体验,因为如果精确度很低,意味着会有大量的假阳性,即模型会错误地屏蔽真正的客户,认为他们进行的是欺诈性的信用卡交易。

另一方面,精确率和召回率之间存在一个权衡:精确率越高,召回率越低。因此,如果模型的精确率非常高,它在发现所有欺诈案件方面就不够出色。从某种意义上讲,这也取决于欺诈案件对业务的成本(财务损失、合规问题、罚款等)与虚假正例对业务的成本(客户生命周期,进而影响业务盈利能力)。

因此,在精确率与召回率之间的财务决策不明确时,一个好的指标是 F1-Score,它有助于在精确率和召回率之间提供平衡,并优化两者。

最后但同样重要的是,警报率也是一个关键的指标,因为它能直观地反映机器学习模型计划阻止的交易数量。如果警报率非常高,比如 15%,那么意味着在客户下的所有订单中,15%会被阻止,只有 85%会被接受。所以,如果你的业务每天有 1,000,000 个订单,机器学习模型会认为其中的 150,000 个是欺诈交易并加以阻止。这是一个巨大的订单量被阻止,因此了解欺诈案件的比例非常重要。如果欺诈案件的比例约为 1%或更低,那么一个阻止 15%的模型不仅会犯很多错误,还会阻碍大部分的业务收入。

6. 结论

理解这些指标可以帮助数据科学家和分析师更好地解读分类模型的结果,并提升模型的表现。精确率和召回率比单纯的准确率提供了更多关于模型有效性的洞见,尤其是在像欺诈检测这样的领域,类分布严重偏斜的情况下。

图片:除非另有说明,所有图片均由作者提供。图像 1 中的机器人面孔是由 DALL-E 创建的,并且是公开使用的。

数据中的微服务与单体架构

原文:towardsdatascience.com/microservices-vs-monolithic-approaches-in-data-8d9d9a064d06?source=collection_archive---------1-----------------------#2024-02-16

这些令人惊叹的雕像中,唯一比它们更庞大的就是 Airflow。照片由Stephanie Morcinek拍摄,来源于Unsplash

微服务与单体架构的争论在软件领域激烈进行,但在数据领域却逐渐平息。

Hugo LuTowards Data Science Hugo Lu

·发表于Towards Data Science ·阅读时间:8 分钟·2024 年 2 月 16 日

--

引言

我不需要花费太多的文字来说服你,选择数据领域中的工具是多么困难。这里有成百上千种方法来解决这个问题

人们忽视的一点是,架构如何影响这些决策。

大约 20 年前,应用程序运行在由需要它们的公司拥有的计算机上——这就是所谓的“本地部署”。拥有这些计算机是一项架构决策。因此,云软件供应商并不存在,因为由于与当时架构的不兼容,云软件没有市场需求。

快进到 2024 年,情况正好相反——大多数人已经完全转向云计算。然而,我们中的一些人仍然在运营自己的服务器,另一些则采用混合模型。这意味着,理解架构对你选择的解决方案的影响比以往任何时候都更为重要,在本文中,我们将深入探讨微服务与单体架构在数据架构中对所购买工具的影响。

微服务与单体架构

MIDI 文件作为训练数据

原文:towardsdatascience.com/midi-files-as-training-data-b67852c8b291?source=collection_archive---------3-----------------------#2024-09-13

一个根本的区别:MIDI 乐谱与 MIDI 表演

Francesco FoscarinTowards Data Science Francesco Foscarin

·发表于 Towards Data Science ·10 分钟阅读·2024 年 9 月 13 日

--

在开始任何 MIDI 文件的深度学习项目之前,确保你知道MIDI 乐谱与 MIDI 表演之间的区别

本文适用于计划或刚开始使用 MIDI 文件的人。这个格式在音乐界被广泛使用,并且由于数据集的可用性,引起了计算机音乐研究人员的关注。

然而,不同类型的信息可以编码在 MIDI 文件中。特别是,MIDI 乐谱和 MIDI 表演之间有很大的区别。没有意识到这一点会导致浪费时间在无用的任务上或选择错误的训练数据和方法

我将提供这两种格式的基本介绍,并给出如何在 Python 中开始使用它们的实用示例。

什么是 MIDI?

MIDI 最初作为合成器之间的实时通信协议引入。其主要思想是每当在 MIDI 键盘上按下音符(note on)时发送一条消息,而在释放音符(note off)时发送另一条消息。然后接收端的合成器将知道生成何种声音。

欢迎来到 MIDI 文件!

如果我们收集并保存所有这些消息(确保添加它们的时间位置),那么我们就得到了一个 MIDI 文件,可以用来重现一段音乐。除了 note-on 和 note-off,还有许多其他类型的消息,例如指定踏板信息或其他控制器。

你可以考虑用pianoroll来绘制这些信息。

请注意,这不是 MIDI 文件,而只是其内容的一种可能表示!一些软件(例如Reaper)在钢琴滚动视图旁边添加了一个小型钢琴键盘,以便更容易地进行视觉解读。

MIDI 文件是如何创建的?

MIDI 文件可以主要通过两种方式创建:1)通过 MIDI 乐器演奏,2)通过手动在音序器(如 Reaper、Cubase、GarageBand、Logic)或乐谱编辑器(例如 MuseScore)中编写。

每种生成 MIDI 文件的方式也会产生不同类型的文件:

  1. 在 MIDI 乐器上演奏→ MIDI 表演

  2. 手动编写音符(音序器或乐谱)→ MIDI 乐谱

现在我们将深入探讨每种类型,然后总结它们之间的区别。

在开始之前,免责声明:我不会专门关注信息如何编码,而是关注可以从文件中提取哪些信息。例如,当我说“时间以秒为单位表示”时,这意味着我们可以获得秒数,尽管编码本身更加复杂。

MIDI 表演

我们可以在 MIDI 表演中找到四种信息:

  • 当音符开始时:音符起始

  • 当音符结束时:音符偏移(或音符持续时间计算为偏移 - 开始)

  • 演奏了哪个音符:音符音高

  • 键按下的“强度”如何:音符速度

音符起始和结束(以及持续时间)以秒为单位表示,对应于演奏者按下和释放音符的秒数。

音符音高通过一个整数编码,范围从 0(最低)到 127(最高);请注意,比钢琴能演奏的音符范围更广;钢琴的音域对应 21–108\。

音符速度也通过一个整数进行编码,范围从 0(静音)到 127(最大强度)。

绝大多数 MIDI 表演都是钢琴表演,因为大多数 MIDI 乐器是 MIDI 键盘。其他 MIDI 乐器(例如 MIDI 萨克斯风、MIDI 鼓和 MIDI 吉他传感器)也存在,但并不那么常见。

最大的人类 MIDI 表演数据集(古典钢琴音乐)是由 Google Magenta 提供的Maestro 数据集

MIDI 表演的主要特性

MIDI 表演的一个基本特性是永远不会有完全相同起始或持续时间的音符(理论上这是可能的,但在实践中极不可能)。

事实上,即使他们真的努力,演奏者也无法准确地同时按下两个(或更多)音符,因为人类的精确度是有限的。音符持续时间也同样如此。此外,这对大多数音乐家来说并不是优先考虑的,因为时间偏差有助于产生更具表现力或更具律动感的感觉。最后,连续的音符之间会有一些静默,或者部分重叠。

因此,MIDI 演奏有时也被称为 未量化 MIDI。时间位置分布在一个连续的时间尺度上,而不是量化到离散的位置(由于数字编码的原因,技术上它是离散的尺度,但非常细腻,因此我们可以认为它是连续的)。

实践示例

让我们来看一个 MIDI 演奏。我们将使用 ASAP 数据集,它托管在 GitHub 上。

在你喜欢的终端(我在 Windows 上使用 PowerShell)中,去到一个方便的位置并克隆仓库。

git clone https://github.com/fosfrancesco/asap-dataset

我们还将使用 Python 库 Partitura 来打开 MIDI 文件,因此你可以在你的 Python 环境中安装它。

pip install partitura

现在一切已设置好,让我们打开 MIDI 文件,并打印前 10 个音符。由于这是一个 MIDI 演奏,我们将使用 load_midi_performance 函数。

from pathlib import Path
import partitura as pt

# set the path to the asap dataset (change it to your local path!)
asap_basepath = Path('../asap-dataset/')
# select a performance, here we use Bach Prelude BWV 848 in C#
performance_path = Path("Bach/Prelude/bwv_848/Denisova06M.mid")
print("Loading midi file: ", asap_basepath/performance_path)
# load the performance
performance = pt.load_performance_midi(asap_basepath/performance_path)
# extract the note array
note_array = performance.note_array()
# print the dtype of the note array (helpful to know how to interpret it)
print("Numpy dtype:")
print(note_array.dtype)
# print the first 10 notes in the note array
print("First 10 notes:")
print(performance.note_array()[:10])

这个 Python 程序的输出应该是这样的:

Numpy dtype:
[('onset_sec', '<f4'), ('duration_sec', '<f4'), ('onset_tick', '<i4'), ('duration_tick', '<i4'), ('pitch', '<i4'), ('velocity', '<i4'), ('track', '<i4'), ('channel', '<i4'), ('id', '<U256')]
First 10 notes:
[(1.0286459, 0.21354167,  790, 164, 49, 53, 0, 0, 'n0')
 (1.03125  , 0.09765625,  792,  75, 77, 69, 0, 0, 'n1')
 (1.1302084, 0.046875  ,  868,  36, 73, 64, 0, 0, 'n2')
 (1.21875  , 0.07942709,  936,  61, 68, 66, 0, 0, 'n3')
 (1.3541666, 0.04166667, 1040,  32, 73, 34, 0, 0, 'n4')
 (1.4361979, 0.0390625 , 1103,  30, 61, 62, 0, 0, 'n5')
 (1.4361979, 0.04296875, 1103,  33, 77, 48, 0, 0, 'n6')
 (1.5143229, 0.07421875, 1163,  57, 73, 69, 0, 0, 'n7')
 (1.6380209, 0.06380209, 1258,  49, 78, 75, 0, 0, 'n8')
 (1.6393229, 0.21484375, 1259, 165, 51, 54, 0, 0, 'n9')]

你可以看到,我们有音符的起始时间和时长(秒)、音高和力度。其他字段对于 MIDI 演奏来说不那么重要。

起始时间和时长也以 ticks(时隙)表示。这更接近 MIDI 文件中实际编码这种信息的方式:选择一个非常短的时间持续单位(= 1 tick),然后所有时间信息都作为该单位的倍数进行编码。当你处理音乐演奏时,通常可以忽略这些信息,直接使用秒数信息。

你可以验证,永远不会有两个音符的起始时间或时长完全相同!

MIDI 乐谱

MIDI 乐谱使用 更丰富的 MIDI 消息集 来编码信息,如时间签名、调性签名、小节和拍号位置。

因此,它们类似于音乐乐谱(乐谱纸),尽管 它们仍然缺少一些重要信息,例如音高拼写、连音符、附点、休止符、连线等……

时间信息不是以秒为单位编码的,而是以更具音乐抽象性的单位编码,如四分音符。

MIDI 乐谱的主要特性

MIDI 乐谱的一个基本特征是,所有音符的起始时间都对齐到一个量化网格,该网格首先由小节位置定义,然后通过递归的整数分割(主要是 2 和 3,但也有其他分割,如 5、7、11 等)来定义小节内的附点音符。

实践示例

我们现在将查看巴赫前奏曲 BWV 848 C# 大调的乐谱,这是我们之前加载的演奏的乐谱。Partitura 有一个专门的 load_score_midi 函数。

from pathlib import Path
import partitura as pt

# set the path to the asap dataset (change it to your local path!)
asap_basepath = Path('../asap-dataset/')
# select a score, here we use Bach Prelude BWV 848 in C#
score_path = Path("Bach/Prelude/bwv_848/midi_score.mid")
print("Loading midi file: ", asap_basepath/score_path)
# load the score
score = pt.load_score_midi(asap_basepath/score_path)
# extract the note array
note_array = score.note_array()
# print the dtype of the note array (helpful to know how to interpret it)
print("Numpy dtype:")
print(note_array.dtype)
# print the first 10 notes in the note array
print("First 10 notes:")
print(score.note_array()[:10])

这个 Python 程序的输出应该是这样的:

Numpy dtype:
[('onset_beat', '<f4'), ('duration_beat', '<f4'), ('onset_quarter', '<f4'), ('duration_quarter', '<f4'), ('onset_div', '<i4'), ('duration_div', '<i4'), ('pitch', '<i4'), ('voice', '<i4'), ('id', '<U256'), ('divs_pq', '<i4')]
First 10 notes:
[(0\. , 1.9958333 , 0\.  , 0.99791664,   0, 479, 49, 1, 'P01_n425', 480)
 (0\. , 0.49583334, 0\.  , 0.24791667,   0, 119, 77, 1, 'P00_n0', 480)
 (0.5, 0.49583334, 0.25, 0.24791667, 120, 119, 73, 1, 'P00_n1', 480)
 (1\. , 0.49583334, 0.5 , 0.24791667, 240, 119, 68, 1, 'P00_n2', 480)
 (1.5, 0.49583334, 0.75, 0.24791667, 360, 119, 73, 1, 'P00_n3', 480)
 (2\. , 0.99583334, 1\.  , 0.49791667, 480, 239, 61, 1, 'P01_n426', 480)
 (2\. , 0.49583334, 1\.  , 0.24791667, 480, 119, 77, 1, 'P00_n4', 480)
 (2.5, 0.49583334, 1.25, 0.24791667, 600, 119, 73, 1, 'P00_n5', 480)
 (3\. , 1.9958333 , 1.5 , 0.99791664, 720, 479, 51, 1, 'P01_n427', 480)
 (3\. , 0.49583334, 1.5 , 0.24791667, 720, 119, 78, 1, 'P00_n6', 480)]

你可以看到音符的起始时间完全落在一个网格上。如果我们考虑 onset_quarter(第三列),我们可以看到 16 分音符每 0.25 四分之一音符出现一次,正如预期的那样。

时值稍微有些问题。例如,在这段乐谱中,一个十六分音符的quarter_duration应该是 0.25。但从 Python 输出中我们可以看到,时值实际上是 0.24791667。发生了什么?使用来生成这个 MIDI 文件的 MuseScore 将每个音符的时值稍微缩短了一些。为什么?仅仅是为了让这个 MIDI 文件的音频演绎听起来更好。确实如此,代价是给那些用这些文件做计算机音乐研究的人带来了许多问题。类似的问题也存在于广泛使用的数据集中,比如 Lakh MIDI 数据集。

MIDI 乐谱与 MIDI 表演

鉴于我们所看到的 MIDI 乐谱和 MIDI 表演之间的差异,下面我将给出一些通用的指导方针,以帮助正确设置您的深度学习系统。

对于音乐生成系统,更倾向使用 MIDI 乐谱,因为量化的音符位置可以用相当小的词汇表示,且其他简化也可能实现,比如只考虑单音旋律。

对于那些旨在模拟人类演奏和感知音乐的系统,使用 MIDI 表演数据,例如节拍跟踪系统、节奏估计系统和情感识别系统(专注于表现性演奏)。

使用这两种数据类型来完成诸如音符跟踪(输入:表演,输出:乐谱)和表现性表演生成(输入:乐谱,输出:表演)等任务。

额外的问题

我已经介绍了 MIDI 乐谱与 MIDI 表演之间的主要差异。然而,正如常常发生的那样,事情可能会更复杂

例如,某些数据集(如 AMAPS 数据集)最初是 MIDI 乐谱,但作者在每个音符上引入了时间变化,以模拟真实人类演奏者的时间偏差(注意,这仅发生在不同时间位置的音符之间;和弦中的所有音符仍然是完全同时发生的)。

此外,某些 MIDI 导出文件,如来自 MuseScore 的文件,也会通过更改节奏标记(如果作品改变了节奏)或在连续音符之间插入非常短的小停顿(我们在前面的例子中看到过)来尝试让 MIDI 乐谱更接近 MIDI 表演,并通过在参考音符起始之前稍微演奏一个非常短的音符来表现装饰音。

事实上,装饰音在 MIDI 乐谱中是一个非常令人头痛的问题。它们的时值在音乐术语中没有明确规定,我们只知道它们应该是“短”的。而它们的起始点与参考音符的起始点相同,但如果我们听 MIDI 文件的音频演绎时,会觉得非常奇怪。那么我们是应该缩短前一个音符,还是下一个音符,为装饰音腾出空间呢?

其他装饰音也存在问题,因为关于如何演奏它们没有统一的规则。例如,颤音应该包含多少个音符?应该从实际音符还是上行音符开始演奏一个刺音?

结论

MIDI 文件非常棒,因为它们明确提供了每个音符的音高、起始时间和持续时间等信息。这意味着,例如,与音频文件相比,针对 MIDI 数据的模型可以更小,并且可以使用更小的数据集进行训练。

这也带来了一些代价:MIDI 文件,以及一般意义上符号编码的音乐格式,是非常复杂的,因为它们以多种不同的方式编码了许多种类的信息。

为了正确使用 MIDI 数据作为训练数据,了解编码的数据类型非常重要。希望这篇文章能为你提供一个很好的起点,让你进一步了解这个话题!

[所有图表均来自作者。]

生成对抗网络(GAN)的极小极大优化设计

原文:towardsdatascience.com/mini-max-optimization-design-of-generative-adversarial-networks-gan-dc1b9ea44a02?source=collection_archive---------8-----------------------#2024-01-12

嵌套双层优化与平衡寻求目标

Michio SuginooTowards Data Science Michio Suginoo

·发布于 Towards Data Science ·阅读时间 8 分钟·2024 年 1 月 12 日

--

介绍

生成对抗网络(GAN)在生成逼真的合成数据方面表现出色,生成的数据与真实数据几乎无法区分。不幸的是,GAN 因其非法应用而引起了公众的关注,尤其是深度伪造。(Knight,2018)

顾名思义,生成对抗网络(GAN)由两个网络组成:生成网络(生成器)和对抗网络(判别器)。将对抗机制纳入其架构使得 GAN 成为一种特殊类型的生成网络。

重要的是,GAN 是非参数化的,因此不会强加诸如马尔可夫链之类的正式统计要求。与其强加统计假设,生成网络在对抗网络的帮助下,通过深度神经网络的反向传播学习真实数据的概率分布。

为了生成逼真的合成数据,生成对抗网络(GAN)在其架构中让这两个代理彼此对抗。在这个博弈中,生成器试图模拟模仿真实样本的合成样本,而判别器试图区分真实样本和合成样本。换句话说,当生成器G通过伪造来欺骗判别器时,判别器D则充当警察的角色,检测合成(伪造)数据。(Goodfellow 等人,2014,第 1 页)从某种意义上说,这两个代理试图实现截然相反的目标。

随着它们技能的提升,合成数据变得与真实数据无法区分。得益于它的对手(判别器),生成器学会了如何更好地模仿给定真实数据的概率分布。

由于在其架构中,GAN 需要训练两个学习者,通过相互作用来实现相反的目标,因此它具有独特的优化设计(双层训练机制和平衡寻求目标)。在这个背景下,我花了一些时间来消化 GAN 的架构设计。在此背景下,我决定写这篇文章与那些对 GAN 不熟悉的人分享我的学习过程,以便他们能更顺利地理解 GAN 的架构特点。我希望读者能把这篇文章作为补充资料,发现它对自己有所帮助。

作为一种预防措施,GAN 是启发式的。目前,GAN 有许多不同的变体应用。本文仅讨论原始 GAN 的架构设计。

原始 GAN 设计

原始 GAN 的基础架构首次在一篇开创性论文中提出:“生成对抗网络”(Goodfellow 等,2014 年)。在这篇原始的 GAN 论文中,为了训练这两个追求截然相反目标的智能体,合著者们设计了一个“双层优化(训练)”架构,其中一个内部训练模块(判别器的训练)被嵌套在另一个高层训练模块(生成器的训练)中。GAN 在这个双层训练框架中交替训练这两个智能体。

图片来源:作者

判别器与生成器

现在,让我们看看这两个智能体在学习过程中做了什么。

很明显,判别器本质上是一个二分类器。给定来自真实数据和合成数据的混合样本,它会将每个样本分类为真实(标签=1)或伪造/合成(标签=0)。

另一方面,生成器本质上是一个噪声分布。它被训练去模仿真实数据集的概率分布,通过一个迭代过程。在每一步的训练迭代中,学习到的生成模型(更新后的生成器)会被复制并用作新的噪声分布。此后,新的噪声分布将用于训练判别器。(Goodfellow I.,2015,第 2 页)

让我们设定以下内容:

图片来源:作者

我们输入噪声 z,并计算其先验分布 G(z),以定义生成器。

在这种设置下,生成器的最终目标是通过将自己的分布转变为尽可能接近真实数据集的分布,从而欺骗判别器。

图片来源:作者

两个目标函数:最小-最大博弈

GAN 的架构中反复进行的是两个代理的训练,而这两个代理具有对立的目标。因此,GAN 有两个目标函数:一个用于判别器,另一个用于生成器。

一方面,作为二分类器的判别器D需要最大化正确分配标签的概率,既包括真实数据(标签=1),也包括合成数据(标签=0)。

另一方面,生成器的最终目标是通过创建与真实数据难以区分的合成数据来欺骗分类器。因此,生成器尝试欺骗分类器,使得判别器错误地将合成数据分类为标签 1。换句话说,生成器的目标是“最大化 D 犯错的概率”。(Goodfellow 等人,2014,第 1 页)

在概念层面,为了实现这两个目标相反的目标,这两个代理可以参考以下通用的对数似然公式V,通常用于二分类问题。

作者提供的图像

在 GAN 训练过程中,判别器的目标是最大化目标函数,而生成器的目标是最小化目标函数的第二项。从这个意义上讲,联合作者将整体目标称为“极小极大博弈”。(Goodfellow 等人,2014,第 3 页)

非饱和修改:

在实现过程中,联合作者在训练生成器的初期遇到了饱和问题。

“在学习的早期,当生成器 G 较差时,判别器 D 可以高信心地拒绝样本,因为它们显然与训练数据不同。在这种情况下,log(1 — D(G(z)))会发生饱和。”

为了解决饱和问题,他们将原始对数似然目标函数的第二项转换如下,并建议生成器最大化这一项:

作者提供的图像

这个公式反映了生成器的目标“最大化 D 犯错的概率”。(Goodfellow 等人,2014,第 1 页)

评估

在训练过程中,生成器不断创造更好的合成数据以欺骗判别器,而判别器则提高其检测能力。从这个角度看,GAN 整体优化的最终目标并不是寻找这两个目标函数的全局最大值,而是寻求一个平衡点,在这个平衡点上,两个代理都无法进一步提升性能。从某种意义上讲,在平衡点,判别器无法区分真实数据和合成数据,因为生成器能够创造尽可能真实的合成数据。

这一目标函数的设定对 GAN 来说非常独特。其中一位联合作者 Ian Goodfellow 描述了平衡点如下:

“它对应于一个鞍点,即分类器的局部最大值和生成器的局部最小值”(Goodfellow I. ,2015,p.2)。

此外,平衡点在概念上可以通过随机猜测的概率 0.5(50%)来表示。

作者提供的图片

交替训练过程:嵌套双层优化

为了实现这一最终目标,GAN 设计了一个交替学习过程,在“二级优化”框架中,其中鉴别器的训练循环嵌套在生成器的另一个更高层次的训练循环中。这个二级优化框架使得 GAN 能够在这两个代理之间交替训练过程:k 步的 D 训练和一步 G 训练(Goodfellow 等,2014,p.3)。在这两个模型交替的过程中,重要的是在训练一个模型时冻结另一个模型的学习过程;“在保持鉴别器固定的情况下更新生成器,反之亦然”(Goodfellow I. ,2015,p.3)。

以下算法修订了原始 GAN 论文中提出的原始算法,以充分反映生成器对数似然转换的推荐。

作者提供的图片

正如你在算法中看到的,GAN 在前向传播过程中同时从生成模型(生成器)和真实数据中采样,而在反向传播过程中同时训练这两个代理。(Goodfellow 等,2014,p.2)它遵循了深度神经网络的常规。

GAN 首先在嵌套块中训练鉴别器,然后在每次迭代时训练生成器以欺骗训练过的鉴别器,然后继续迭代这个二级训练,直到达到前面讨论的平衡点。

总体而言,技术上,GAN 通过生成器学习真实数据的概率分布;鉴别器只是嵌套在生成器学习机制中的一个内部组件。生成器的目标函数在其公式的上层优化过程中,反映了训练过的鉴别器模型的逐步内容。换句话说,每次迭代时,一旦鉴别器在嵌套的优化模块中被训练,生成器就会不断更新其目标函数。

这几乎描绘了 GAN 模型优化的算法设计。

总结

为了交替训练两个代理——鉴别器和生成器——GAN 采用了一个双层优化框架,其中鉴别器在嵌套在生成器训练块内的内部模块中进行训练。

由于这两个代理有截然相反的目标(因为判别器旨在最大化其二元分类器的目标函数,而生成器则旨在最小化该函数),合著者称整体目标为“极小极大博弈”。(Goodfellow 等人,2014,第 3 页)总的来说,GAN 通过寻求一个平衡点来实现其极小极大优化(训练)目标,在这个平衡点上,判别器无法再区分真实数据和合成数据,因为现在生成器生成的合成数据与真实数据无法区分。

它的嵌套双层训练框架及其寻求平衡的目标设置(与最大化目标相对)构成了 GAN 的极小极大优化框架。

最后,重要的是要指出,主要作者 Ian Goodfellow 表示,原始的 GAN 是启发式的,并且具有理论上的局限性。例如,当目标函数不是凸函数时,无法保证收敛。在这种背景下,他阐述了 GAN 仍然有进一步创新改进的空间。事实上,针对 GAN 应用的多种变体,已经探索了广泛的评估度量(Borji,2018)。因此,我想强调的是,本文所述的架构设计仅描述了最初的 GAN 论文中提出的 GAN 原型。因此,本文介绍的架构设计并不是对其他类型 GAN 应用的全面或普遍适用的设计。

鉴于这一预警说明已向读者充分传达,我希望本文能对那些刚接触 GAN 的读者有所帮助,帮助他们开启自己的 GAN 之旅。

参考文献

最小可行机器学习模型(MLE)

原文:towardsdatascience.com/minimum-viable-mle-306877dd6030?source=collection_archive---------9-----------------------#2024-10-31

构建一个最小化的生产就绪情感分析模型

Lenix CarterTowards Data Science Lenix Carter

·发表于Towards Data Science ·阅读时长 7 分钟·2024 年 10 月 31 日

--

图片由Stephen Dawson提供,来自Unsplash

什么是生产就绪模型?

我们经常听到“生产化的机器学习”,但要让模型在实际应用中茁壮成长究竟意味着什么呢?在生产中,许多因素都影响并有助于机器学习模型的效能。为了本文的目的,我们将重点关注其中的五个因素。

  • 可复现性

  • 监控

  • 测试

  • 自动化

  • 版本控制

提供推断

构建生产就绪机器学习模型最重要的部分是能够访问它。

为此,我们构建了一个 FastAPI 客户端,用于提供情感分析的响应。我们利用 Pydantic 来确保输入和输出的结构化。我们使用的模型是来自 Huggingface 的 Transformers 库中的基本情感分析管道,这使我们能够使用预训练模型开始进行测试。

# Filename: main.py
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline

app = FastAPI()
classifier = pipeline("sentiment-analysis")

class TextInput(BaseModel):
    text: str

class SentimentOutput(BaseModel):
    text: str
    sentiment: str
    score: float

@app.post("/predict", response_model=SentimentOutput)
async def predict_sentiment(input_data: TextInput):
    result = classifier(input_data.text)[0]
    return SentimentOutput(
        text=input_data.text,
        sentiment=result["label"],
        score=result["score"]
    )

为了确保我们的工作是可复现的,我们可以使用 requirements.txt 文件和 pip。

# Filename: requirements.txt
# Note: This has all required packages for the final result. 

fastapi==0.68.1
uvicorn==0.15.0
transformers==4.30.0
torch==2.0.0
pydantic==1.10.0
numpy==1.24.3
sentencepiece==0.1.99
protobuf==3.20.3
prometheus-client==0.17.1

要安装此项,请初始化您的文件中的 venv并运行:pip install -r requirements.txt

要托管此 API,只需运行:uvicorn main:app --reload

现在您拥有一个可以通过以下方式查询的 API:

curl -X POST "http://localhost:8000/predict" \
  -H "Content-Type: application/json" \
  -d '{"text": "I love using FastAPI!"}'

或者您想要的任何 API 工具(即Postman)。您应该能收到一个返回结果,其中包括文本查询、预测的情感以及预测的置信度。

我们稍后将使用 GitHub 进行 CI/CD,因此我建议在此目录中初始化并使用 git

我们现在有一个本地托管的机器学习推理 API。

进一步提高可复现性

为了让我们的代码能够更一致地执行,我们将使用 Docker。Docker 模拟了一个轻量级的环境,允许应用程序在隔离的容器中运行,类似于虚拟机。这种隔离确保应用程序能够在任何安装了 Docker 的计算机上稳定执行,而不受底层系统的影响。

首先,为你的操作系统设置 Docker

# Filename: Dockerfile

# Use the official Python 3.9 slim image as the base
FROM python:3.9-slim

# Set the working directory inside the container to /app
WORKDIR /app

# Copy the requirements.txt file to the working directory
COPY requirements.txt .

# Install the Python dependencies listed in requirements.txt
RUN pip install -r requirements.txt

# Copy the main application file (main.py) to the working directory
COPY main.py .

# Define the command to run the FastAPI application with Uvicorn
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]

此时,你应该拥有如下的目录结构。

your-project/
├── Dockerfile
├── requirements.txt
└── main.py

现在,你可以构建镜像并使用以下命令运行这个 API:

# Build the Docker image
docker build -t sentiment-api .

# Run the container
docker run -p 8000:8000 sentiment-api

你现在应该能够像之前一样进行查询。

curl -X POST "http://localhost:8000/predict" \
  -H "Content-Type: application/json" \
  -d '{"text": "I love using FastAPI!"}'

我们现在拥有一个容器化的、本地托管的机器学习推理 API。

添加基础监控

在机器学习应用中,监控对于理解模型性能以及确保其达到预期的准确性和效率至关重要。像Prometheus这样的工具帮助跟踪诸如预测延迟、请求计数和模型输出分布等指标,使你能够识别诸如模型漂移或资源瓶颈等问题。这种主动的方法确保你的机器学习模型随着时间的推移保持有效,并能够适应不断变化的数据或使用模式。在我们的案例中,我们专注于预测时间、请求和收集查询信息。

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline
from prometheus_client import Counter, Histogram, start_http_server
import time

# Start prometheus metrics server on port 8001
start_http_server(8001)

app = FastAPI()

# Metrics
PREDICTION_TIME = Histogram('prediction_duration_seconds', 'Time spent processing prediction')
REQUESTS = Counter('prediction_requests_total', 'Total requests')
SENTIMENT_SCORE = Histogram('sentiment_score', 'Histogram of sentiment scores', buckets=[0.0, 0.25, 0.5, 0.75, 1.0])

class TextInput(BaseModel):
    text: str

class SentimentOutput(BaseModel):
    text: str
    sentiment: str
    score: float

@app.post("/predict", response_model=SentimentOutput)
async def predict_sentiment(input_data: TextInput):
    REQUESTS.inc()
    start_time = time.time()

    result = classifier(input_data.text)[0]

    score = result["score"]
    SENTIMENT_SCORE.observe(score)  # Record the sentiment score

    PREDICTION_TIME.observe(time.time() - start_time)

    return SentimentOutput(
        text=input_data.text,
        sentiment=result["label"],
        score=score
    )

使用自定义模型

尽管构建和微调模型的过程并不是本项目的目的,但理解如何将模型添加到此过程中非常重要。

# Filename: train.py

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
from torch.utils.data import DataLoader

def train_model():
    # Load dataset
    full_dataset = load_dataset("stanfordnlp/imdb", split="train")
    dataset = full_dataset.shuffle(seed=42).select(range(10000))

    model_name = "distilbert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

    # Use GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    model.train()

    # Create a DataLoader for batching
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

    # Training loop
    num_epochs = 3  # Set the number of epochs
    for epoch in range(num_epochs):
        total_loss = 0
        for batch in dataloader:
            inputs = tokenizer(batch["text"], truncation=True, padding=True, return_tensors="pt", max_length=512).to(device)
            labels = torch.tensor(batch["label"]).to(device)

            optimizer.zero_grad()
            outputs = model(**inputs, labels=labels)
            loss = outputs.loss

            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}")

    # Save the model
    model.save_pretrained("./model/")
    tokenizer.save_pretrained("./model/")

    # Test the model with sample sentences
    test_sentences = [
        "This movie was fantastic!",
        "I absolutely hated this film.",
        "It was just okay, not great.",
        "An absolute masterpiece!",
        "Waste of time!",
        "A beautiful story and well acted.",
        "Not my type of movie.",
        "It could have been better.",
        "A thrilling adventure from start to finish!",
        "Very disappointing."
    ]

    # Switch model to evaluation mode
    model.eval()

    # Prepare tokenizer for test inputs
    inputs = tokenizer(test_sentences, truncation=True, padding=True, return_tensors="pt", max_length=512).to(device)

    with torch.no_grad():
        outputs = model(**inputs)
        predictions = torch.argmax(outputs.logits, dim=1)

    # Print predictions
    for sentence, prediction in zip(test_sentences, predictions):
        sentiment = "positive" if prediction.item() == 1 else "negative"
        print(f"Input: \"{sentence}\" -> Predicted sentiment: {sentiment}")

# Call the function to train the model and test it
train_model()

为了确保我们能够查询我们训练的新模型,我们需要更新一些现有的文件。例如,在main.py中,我们现在使用来自./model的模型,并将其加载为预训练模型。此外,为了对比,我们现在有两个可用的端点,/predict/naivepredict/trained

# Filename: main.py

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import pipeline
from prometheus_client import Counter, Histogram, start_http_server
import time

# Start prometheus metrics server on port 8001
start_http_server(8001)

app = FastAPI()

# Load the trained model and tokenizer from the local directory
model_path = "./model"  # Path to your saved model
tokenizer = AutoTokenizer.from_pretrained(model_path)
trained_model = AutoModelForSequenceClassification.from_pretrained(model_path)

# Create pipelines
naive_classifier = pipeline("sentiment-analysis", device=-1)
trained_classifier = pipeline("sentiment-analysis", model=trained_model, tokenizer=tokenizer, device=-1)

# Metrics
PREDICTION_TIME = Histogram('prediction_duration_seconds', 'Time spent processing prediction')
REQUESTS = Counter('prediction_requests_total', 'Total requests')
SENTIMENT_SCORE = Histogram('sentiment_score', 'Histogram of sentiment scores', buckets=[0.0, 0.25, 0.5, 0.75, 1.0])

class TextInput(BaseModel):
    text: str

class SentimentOutput(BaseModel):
    text: str
    sentiment: str
    score: float

@app.post("/predict/naive", response_model=SentimentOutput)
async def predict_naive_sentiment(input_data: TextInput):
    REQUESTS.inc()
    start_time = time.time()

    result = naive_classifier(input_data.text)[0]

    score = result["score"]
    SENTIMENT_SCORE.observe(score)  # Record the sentiment score

    PREDICTION_TIME.observe(time.time() - start_time)

    return SentimentOutput(
        text=input_data.text,
        sentiment=result["label"],
        score=score
    )

@app.post("/predict/trained", response_model=SentimentOutput)
async def predict_trained_sentiment(input_data: TextInput):
    REQUESTS.inc()
    start_time = time.time()

    result = trained_classifier(input_data.text)[0]

    score = result["score"]
    SENTIMENT_SCORE.observe(score)  # Record the sentiment score

我们还必须更新我们的 Dockerfile,以包含我们的模型文件。

# Filename: Dockerfile
FROM python:3.9-slim

WORKDIR /app

COPY requirements.txt .
RUN pip install -r requirements.txt

COPY main.py .
COPY ./model ./model

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]

重要的是,如果你使用 git,确保将pytorch_model.bin文件添加到git lfs,这样你就可以推送到 GitHub。git lfs 允许你对非常大的文件使用版本控制。

添加测试和 CI/CD

CI/CD 和测试通过确保代码更改自动集成、测试和部署,从而简化了机器学习模型的部署,这减少了错误的风险并提高了模型的可靠性。这个过程促进了持续改进和更快的迭代周期,使团队能够更高效地交付高质量、准备投入生产的模型。首先,我们创建了两个非常基础的测试,确保我们的模型表现得足够好。

# Filename: test_model.py

import pytest
from fastapi.testclient import TestClient
from main import app

client = TestClient(app)

def test_positive_sentiment():
    response = client.post(
        "/predict/trained",
        json={"text": "This is amazing!"}
    )
    assert response.status_code == 200
    data = response.json()
    assert data["sentiment"] == "LABEL_1"
    assert data["score"] > 0.5

def test_negative_sentiment():
    response = client.post(
        "/predict/trained",
        json={"text": "This is terrible!"}
    )
    assert response.status_code == 200
    data = response.json()
    assert data["sentiment"] == "LABEL_0"
    assert data["score"] < 0.5

要测试你的代码,你可以在端点运行时简单地执行pytestpython -m pytest

然而,当推送到 GitHub 时,我们将添加自动化测试 CI/CD(持续集成与持续交付)。

# Filename: .github/workflows/ci_cd.yml

name: CI/CD

on: [push]

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v2
        with:
          lfs: true

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.9'

      - name: Install dependencies
        run: |
          pip install -r requirements.txt
          pip install pytest httpx

      - name: Run tests
        run: pytest

我们最终的项目结构应如下所示。

sentiment-analysis-project/
├── .github/
│   └── workflows/
│       └── ci_cd.yml
├── test_model.py
├── main.py
├── Dockerfile
├── requirements.txt
└── train.py

现在,每当我们推送到 GitHub 时,它将运行一个自动化过程,检出代码,设置 Python 3.9 环境,安装依赖,并使用 pytest 运行我们的测试。

结论

在这个项目中,我们开发了一个生产就绪的情感分析 API,突出了部署机器学习模型的关键方面。虽然它并不涵盖该领域的每个方面,但它提供了涉及这一过程的基本任务的代表性样本。通过检查这些组件,我希望能澄清你可能遇到过的概念,但不确定它们如何在实际环境中配合工作。

时间序列中的缺失数据:机器学习技术

原文:towardsdatascience.com/missing-data-in-time-series-machine-learning-techniques-6b2273ff8b45?source=collection_archive---------4-----------------------#2024-12-10

第一部分:利用线性回归和决策树填补时间序列中的缺失数据。

Sara NóbregaTowards Data Science Sara Nóbrega

·发表于 Towards Data Science ·13 分钟阅读·2024 年 12 月 10 日

--

来源:DALL-E。

(如果你没有订阅,请点击此处阅读此故事 这里****)。

时间序列分析中的缺失数据 — 听起来很熟悉吗?

数据集中的缺失数据是不是由于传感器故障、数据传输问题或任何类型的维护而导致的,听起来很熟悉?

嗯,缺失值会破坏你的预测并扭曲你的分析。

那么,你如何修复它们呢?

传统方法 可能看起来是解决方案——前向填充或插值——但这足够好吗

当你的数据具有复杂的模式、非线性趋势或高度变化时,会发生什么? 简单的技术将失败,并导致不稳定的结果。

如果有更明智的方法来应对这个挑战呢?

机器学习正是如此: 从回归分析到 K 最近邻再到神经网络,它们不假设任何内容,而是适应并精确地填补空白。

好奇吗? 让我们更深入地了解这些先进的方法如何改变你的时间序列分析。

我们将使用…填补缺失数据

缺失值填充,解释:适合初学者的视觉指南与代码示例

原文:towardsdatascience.com/missing-value-imputation-explained-a-visual-guide-with-code-examples-for-beginners-93e0726284eb?source=collection_archive---------0-----------------------#2024-08-27

数据预处理

一组(微小的)数据集,六种填充方法?

Samy BaladramTowards Data Science Samy Baladram

·发布于 Towards Data Science ·阅读时间 13 分钟·2024 年 8 月 27 日

--

⛳️ 更多 [数据预处理](https://medium.com/@samybaladram/list/data-preprocessing-17a2c49b44e4) 解释: ▶ 缺失值填充 · 类别编码 · 数据缩放 · 离散化 · 过采样与欠采样 · [数据泄露与预处理](https://towardsdatascience.com/data-leakage-in-preprocessing-explained-a-visual-guide-with-code-examples-33cbf07507b7)

让我们来谈谈每个数据科学家、分析师或好奇的数字分析师最终都需要面对的问题:缺失值。现在,我知道你在想什么——“哦,太好了,又一个关于缺失值的指南。”但是请听我说完。我将展示如何使用六种不同的填充方法来解决这个问题,且所有方法都应用于同一个数据集(并且附有有用的可视化效果!)。通过这篇文章,你将明白为什么领域知识的价值不可估量(甚至是我们 AI 朋友可能也难以复制的)。

所有可视化图像:作者使用 Canva Pro 创建。优化适配移动设备,可能在桌面端显示过大。

什么是缺失值以及为什么会出现?

在深入讨论我们的数据集和填补方法之前,让我们先花一点时间理解缺失值是什么以及它们为什么在数据科学中如此常见。

什么是缺失值?

缺失值,通常在 pandas 中表示为 NaN(非数字),在数据库中表示为 NULL,实质上是数据集中的空洞。它们是你电子表格中的空白单元格、调查问卷中的空白回答、失去的数据点。在数据的世界里,并非所有缺失都相同,理解缺失值的性质对于决定如何处理它们至关重要。

图片由作者提供。

为什么会出现缺失值?

缺失值可能因为多种原因悄悄进入你的数据。以下是一些常见的原因:

  1. 数据输入错误:有时,纯粹是人为错误。有人可能忘记输入一个值或不小心删除了某个值。

  2. 传感器故障:在物联网或科学实验中,故障的传感器可能会在某些时刻无法记录数据。

  3. 调查非响应:在调查中,受访者可能会跳过他们不愿回答或不理解的问题。

  4. 合并数据集:当从多个来源合并数据时,某些条目可能在所有数据集中没有对应的值。

  5. 数据损坏:在数据传输或存储过程中,某些值可能会被损坏并变得无法读取。

  6. 故意省略:由于隐私问题或不相关,某些数据可能被故意省略。

  7. 采样问题:数据收集方法可能系统性地遗漏某些类型的数据。

  8. 时效性数据:在时间序列数据中,可能会在数据未收集的期间(如周末、假期)缺失值。

缺失数据的类型

理解你所处理的缺失数据类型可以帮助你选择最合适的填补方法。统计学家通常将缺失数据分为三种类型:

  1. 完全随机缺失(MCAR):缺失情况是完全随机的,不依赖于任何其他变量。例如,如果一个实验室样本不小心掉落。

  2. 随机缺失(MAR):缺失数据的概率取决于其他观测到的变量,而不是缺失数据本身。例如,男性可能不太愿意在调查中回答有关情感的问题。

  3. 非随机缺失(MNAR):缺失情况取决于缺失数据本身的值。例如,高收入的人可能不太愿意在调查中报告他们的收入。

为什么要关心缺失值?

缺失值可能会显著影响你的分析:

  1. 如果处理不当,它们可能会引入偏差。

  2. 许多机器学习算法无法直接处理缺失值。

  3. 如果简单地丢弃带有缺失值的实例,可能会导致重要信息的丧失。

  4. 不当处理缺失值可能导致错误的结论或预测。

这就是为什么拥有处理缺失值的可靠策略至关重要的原因。而这正是我们将在本文中探讨的内容!

数据集

首先,让我们介绍一下数据集。我们将使用一个高尔夫球场的数据集,跟踪影响球场拥挤度的各种因素。这个数据集包含了几乎所有内容——数值型数据、分类数据,当然还有大量缺失值。

该数据集是由作者人工制作的(灵感来源于[1]),旨在促进学习。

import pandas as pd
import numpy as np

# Create the dataset as a dictionary
data = {
    'Date': ['08-01', '08-02', '08-03', '08-04', '08-05', '08-06', '08-07', '08-08', '08-09', '08-10',
             '08-11', '08-12', '08-13', '08-14', '08-15', '08-16', '08-17', '08-18', '08-19', '08-20'],
    'Weekday': [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5],
    'Holiday': [0.0, 0.0, 0.0, 0.0, np.nan, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, np.nan, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    'Temp': [25.1, 26.4, np.nan, 24.1, 24.7, 26.5, 27.6, 28.2, 27.1, 26.7, np.nan, 24.3, 23.1, 22.4, np.nan, 26.5, 28.6, np.nan, 27.0, 26.9],
    'Humidity': [99.0, np.nan, 96.0, 68.0, 98.0, 98.0, 78.0, np.nan, 70.0, 75.0, np.nan, 77.0, 77.0, 89.0, 80.0, 88.0, 76.0, np.nan, 73.0, 73.0],
    'Wind': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, np.nan, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, np.nan, 1.0, 0.0],
    'Outlook': ['rainy', 'sunny', 'rainy', 'overcast', 'rainy', np.nan, 'rainy', 'rainy', 'overcast', 'sunny', np.nan, 'overcast', 'sunny', 'rainy', 'sunny', 'rainy', np.nan, 'rainy', 'overcast', 'sunny'],
    'Crowdedness': [0.14, np.nan, 0.21, 0.68, 0.20, 0.32, 0.72, 0.61, np.nan, 0.54, np.nan, 0.67, 0.66, 0.38, 0.46, np.nan, 0.52, np.nan, 0.62, 0.81]
}

# Create a DataFrame from the dictionary
df = pd.DataFrame(data)

# Display basic information about the dataset
print(df.info())

# Display the first few rows of the dataset
print(df.head())

# Display the count of missing values in each column
print(df.isnull().sum())

输出:

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 20 entries, 0 to 19
Data columns (total 8 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   Date         20 non-null     object 
 1   Weekday      20 non-null     int64  
 2   Holiday      19 non-null     float64
 3   Temp         16 non-null     float64
 4   Humidity     17 non-null     float64
 5   Wind         19 non-null     float64
 6   Outlook      17 non-null     object 
 7   Crowdedness  15 non-null     float64
dtypes: float64(5), int64(1), object(2)
memory usage: 1.3+ KB

     Date  Weekday  Holiday  Temp  Humidity  Wind Outlook  Crowdedness
0  08-01        0      0.0  25.1      99.0   0.0   rainy         0.14
1  08-02        1      0.0  26.4       NaN   0.0   sunny          NaN
2  08-03        2      0.0   NaN      96.0   0.0   rainy         0.21
3  08-04        3      0.0  24.1      68.0   0.0   overcast      0.68
4  08-05        4      NaN  24.7      98.0   0.0   rainy         0.20

Date           0
Weekday        0
Holiday        1
Temp           4
Humidity       3
Wind           1
Outlook        3
Crowdedness    5
dtype: int64

正如我们所见,数据集中有 20 行和 8 列:

  • 日期:观察日期

  • 星期几:一周中的天数(0–6,其中 0 表示星期一)

  • 假期:布尔值,表示是否为假期(0 或 1)

  • 温度:摄氏温度

  • 湿度:湿度百分比

  • 风速:风速情况(0 或 1,可能表示平静或有风)

  • 天气预报:天气预报(晴天、阴天或雨天)

  • 拥挤度:课程占用百分比

看看这个!除了日期和星期几,其他每一列都有缺失值。非常适合我们的插补工作。

现在我们已经加载了数据集,让我们通过六种不同的插补方法来处理这些缺失值。我们会针对每种数据类型使用不同的策略。

方法 1:逐行删除

逐行删除,也称为完全案例分析,涉及删除包含任何缺失值的整行数据。此方法简单,并能保持数据分布,但如果许多行包含缺失值,它可能会导致信息的显著丢失。

👍 常见使用:逐行删除通常用于缺失值较少且数据完全随机缺失(MCAR)的情况。它也适用于需要完整数据集的分析,而这些分析不能处理缺失值。

在我们案例中:我们使用逐行删除处理至少有 4 个缺失值的行。这些行可能没有足够可靠的信息,删除它们可以帮助我们集中在更完整的数据点上。不过,我们会小心操作,只删除缺失数据较多的行,以尽量保留更多信息。

# Count missing values in each row
missing_count = df.isnull().sum(axis=1)

# Keep only rows with less than 4 missing values
df_clean = df[missing_count < 4].copy()

我们已经删除了 2 行缺失值过多的数据。现在让我们继续处理剩余的缺失数据。

方法 2:简单插补 — 均值和众数

简单插补是指用观测值的汇总统计量来替代缺失值。常见的做法包括使用列中非缺失值的均值、中位数或众数。

👍 常见使用:均值插补通常用于连续变量,当数据随机缺失且分布大致对称时。众数插补通常用于分类变量。

在我们的案例中:我们对 Humidity 使用了均值填充,对 Holiday 使用了众数填充。对于 Humidity,假设缺失值是随机的,均值提供了一个合理的典型湿度估算。对于 Holiday,由于它是一个二元变量(假期与否),众数给出了最常见的状态,这是一个合理的缺失值猜测。

# Mean imputation for Humidity
df_clean['Humidity'] = df_clean['Humidity'].fillna(df_clean['Humidity'].mean())

# Mode imputation for Holiday
df_clean['Holiday'] = df_clean['Holiday'].fillna(df_clean['Holiday'].mode()[0])

方法 3:线性插值

线性插值通过假设已知数据点之间存在线性关系来估算缺失值。它特别适用于时间序列数据或具有自然顺序的数据。

👍 常见用法:线性插值通常用于时间序列数据,在这种情况下,缺失值可以根据前后数据点的值进行估算。它也适用于任何预期相邻数据点之间存在大致线性关系的数据。

在我们的案例中:我们对 Temperature 使用了线性插值。由于温度通常随时间逐渐变化,并且我们的数据按日期排序,线性插值可以根据附近几天的温度值来合理估算缺失的温度值。

df_clean['Temp'] = df_clean['Temp'].interpolate(method='linear')

方法 4:前向/后向填充

前向填充(或“上次观察值前移”)将最后一个已知值前推填补空缺,而后向填充则执行相反的操作。这种方法假设缺失值可能与最近已知的值相似。

👍 常见用法:前向/后向填充常用于时间序列数据,尤其是当数据值在改变之前可能保持不变时(例如金融数据),或者当最新的已知值是当前状态的最佳猜测时。

在我们的案例中:我们对 Outlook 使用了前向和后向填充的组合。天气状况通常会持续好几天,因此可以合理假设缺失的 Outlook 值可能与前一天或后一天的 Outlook 值相似。

df_clean['Outlook'] = df_clean['Outlook'].fillna(method='ffill').fillna(method='bfill')

方法 5:常数值填充

这种方法涉及用特定的常数值替换变量中的所有缺失值。这个常数值可以根据领域知识或一个安全的默认值来选择。

👍 常见用法:常数值填充通常在缺失数据有逻辑默认值的情况下使用,或者当你想明确标记某个值缺失时(通过使用一个超出数据正常范围的值)。

在我们的案例中:我们对 Wind 列使用了常数值填充,将缺失值替换为-1\。这种方法明确标记了填充的值(因为-1 超出了 Wind 列的正常 0-1 范围),同时保留了这些值最初缺失的信息。

df_clean['Wind'] = df_clean['Wind'].fillna(-1)

方法 6:KNN 填充

K 近邻(KNN)填补通过找到数据集中最相似的 K 个样本来估算缺失值(就像 KNN 作为分类算法一样),并使用它们的值来填补缺失数据。这种方法可以捕捉变量之间的复杂关系。

👍 常见用途:KNN 填补方法非常灵活,可以用于连续变量和分类变量。特别是在变量之间可能存在复杂关系,而简单方法可能忽略这些关系时,它尤其有用。

在我们的案例中:我们使用 KNN 填补来估算拥挤度。拥挤度可能依赖于多个因素(如温度、假期状态等),而 KNN 能够捕捉这些复杂的关系,提供更准确的缺失拥挤度值估算。

from sklearn.impute import KNNImputer

# One-hot encode the 'Outlook' column
outlook_encoded = pd.get_dummies(df_clean['Outlook'], prefix='Outlook')

# Prepare features for KNN imputation
features_for_knn = ['Weekday', 'Holiday', 'Temp', 'Humidity', 'Wind']
knn_features = pd.concat([df_clean[features_for_knn], outlook_encoded], axis=1)

# Apply KNN imputation
knn_imputer = KNNImputer(n_neighbors=3)
df_imputed = pd.DataFrame(knn_imputer.fit_transform(pd.concat([knn_features, df_clean[['Crowdedness']]], axis=1)),
                          columns=list(knn_features.columns) + ['Crowdedness'])

# Update the original dataframe with the imputed Crowdedness values
df_clean['Crowdedness'] = df_imputed['Crowdedness']

结论:选择的力量(与知识)

所以,结果就是这样!六种不同的缺失值处理方法,都应用于我们的高尔夫球场数据集。

让我们回顾一下每种方法是如何处理我们的数据的:

  1. 列表删除法:通过删除包含大量缺失值的行,帮助我们集中处理更完整的数据点。

  2. 简单填补:用平均值填补湿度,且用最常见的值填补假期。

  3. 线性插值:根据周围天数的趋势估算缺失的温度值。

  4. 前向/后向填补:根据相邻天的数据猜测缺失的天气预报值,反映天气模式的持续性。

  5. 常数值填补:用-1 标记缺失的风速数据,保留这些值原本是未知的事实。

  6. KNN 填补:根据相似的日子估算拥挤度,捕捉变量之间的复杂关系。

每种方法讲述了关于我们缺失数据的不同故事,而“正确”的选择取决于我们对高尔夫球场运营的了解以及我们想要回答的问题。

关键要点?不要盲目地应用填补方法。理解你的数据,考虑背景,并选择最适合你特定情况的方法。

⚠️ 警告:缺失值填补的目的与局限性

尽管我们已经探讨了各种填补技术,但我们需要理解它们的目的和局限性:

  1. 不是万能解决方案:填补方法并非解决缺失数据的灵丹妙药。它是让你的数据可用的工具,而不是用来创造完美数据的工具

  2. 偏差的潜在可能性:填补的值是经过推测的估计。如果不小心进行,尤其是在数据是非随机缺失(NMAR)的情况下,可能会引入偏差。

  3. 不确定性的丧失:大多数简单的填补方法没有考虑缺失值的不确定性,这可能导致过于自信的模型。

  4. 数据扭曲:激进的插补可能会扭曲数据中的关系。始终检查插补是否显著改变了数据的分布或相关性。

  5. 记录您的过程:始终清晰地记录您的插补方法。这种透明度对于可重复性至关重要,也有助于他人理解您结果中潜在的偏差。

再次强调,插补的目标是使您的数据可用,同时最小化偏差和信息丢失。这不是为了创建完美的数据,而是为了最大限度地利用您已有的信息。始终以谨慎和批判性思维来处理插补。

🌟 缺失值插补概述

import pandas as pd
import numpy as np
from sklearn.impute import KNNImputer

# Create the dataset as a dictionary
data = {
    'Date': ['08-01', '08-02', '08-03', '08-04', '08-05', '08-06', '08-07', '08-08', '08-09', '08-10',
             '08-11', '08-12', '08-13', '08-14', '08-15', '08-16', '08-17', '08-18', '08-19', '08-20'],
    'Weekday': [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5],
    'Holiday': [0.0, 0.0, 0.0, 0.0, np.nan, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, np.nan, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    'Temp': [25.1, 26.4, np.nan, 24.1, 24.7, 26.5, 27.6, 28.2, 27.1, 26.7, np.nan, 24.3, 23.1, 22.4, np.nan, 26.5, 28.6, np.nan, 27.0, 26.9],
    'Humidity': [99.0, np.nan, 96.0, 68.0, 98.0, 98.0, 78.0, np.nan, 70.0, 75.0, np.nan, 77.0, 77.0, 89.0, 80.0, 88.0, 76.0, np.nan, 73.0, 73.0],
    'Wind': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, np.nan, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, np.nan, 1.0, 0.0],
    'Outlook': ['rainy', 'sunny', 'rainy', 'overcast', 'rainy', np.nan, 'rainy', 'rainy', 'overcast', 'sunny', np.nan, 'overcast', 'sunny', 'rainy', 'sunny', 'rainy', np.nan, 'rainy', 'overcast', 'sunny'],
    'Crowdedness': [0.14, np.nan, 0.21, 0.68, 0.20, 0.32, 0.72, 0.61, np.nan, 0.54, np.nan, 0.67, 0.66, 0.38, 0.46, np.nan, 0.52, np.nan, 0.62, 0.81]
}

# Create a DataFrame from the dictionary
df = pd.DataFrame(data)

# 1\. Listwise Deletion
df_clean = df[df.isnull().sum(axis=1) < 4].reset_index(drop=True).copy()

# 2\. Simple Imputation
df_clean['Humidity'] = df_clean['Humidity'].fillna(df_clean['Humidity'].mean())
df_clean['Holiday'] = df_clean['Holiday'].fillna(df_clean['Holiday'].mode()[0])

# 3\. Linear Interpolation
df_clean['Temp'] = df_clean['Temp'].interpolate(method='linear')

# 4\. Forward/Backward Fill
df_clean['Outlook'] = df_clean['Outlook'].ffill()

# 5\. Constant Value Imputation
df_clean['Wind'] = df_clean['Wind'].fillna(-1)

# 6\. KNN Imputation
# One-hot encode the 'Outlook' column
outlook_encoded = pd.get_dummies(df_clean['Outlook'], prefix='Outlook')

# Prepare features for KNN imputation
features_for_knn = ['Weekday', 'Holiday', 'Temp', 'Humidity', 'Wind']
knn_features = pd.concat([df_clean[features_for_knn], outlook_encoded], axis=1)

# Apply KNN imputation
knn_imputer = KNNImputer(n_neighbors=3)
df_imputed = pd.DataFrame(knn_imputer.fit_transform(pd.concat([knn_features, df_clean[['Crowdedness']]], axis=1)),
                          columns=list(knn_features.columns) + ['Crowdedness'])

# Update the original dataframe with the imputed Crowdedness values
df_clean['Crowdedness'] = df_imputed['Crowdedness'].round(2)

print("Before:")
print(df)

print("\n\nAfter:")
print(df_clean)

深入阅读

有关KNNImputer及其在 scikit-learn 中实现的详细解释,读者可以参考官方文档,其中提供了关于其使用和参数的全面信息。

技术环境

本文使用 Python 3.7 和 scikit-learn 1.5。尽管所讨论的概念具有普遍适用性,但不同版本之间的具体代码实现可能会略有不同。

关于插图

除非另有说明,否则所有图片均由作者创作,并结合了 Canva Pro 的授权设计元素。

若想了解缺失值插补的简明视觉总结,请查看相关 Instagram 帖子。

参考文献

[1] T. M. Mitchell, 机器学习(1997),McGraw-Hill 科学/工程/数学, 第 59 页

𝙎𝙚𝙚 𝙢𝙤𝙧𝙚 𝘿𝙖𝙩𝙖 𝙋𝙧𝙚𝙥𝙧𝙤𝙘𝙚𝙨𝙨𝙞𝙣𝙜 𝙢𝙚𝙩𝙝𝙤𝙙𝙨 𝙝𝙚𝙧𝙚:

Samy Baladram

Samy Baladram

数据预处理

查看列表6 个故事两个人物相互拥抱的卡通插图,字母‘A’、‘B’、‘C’和数字‘1’、‘2’、‘3’在他们周围漂浮。上方有一个粉色的心形,象征着情感。背景是蓝色和绿色方块组成的像素化图案,代表数据或编码。这幅图形象地描绘了类别数据编码的概念,其中类别(ABC)被转化为数字表示(123)。代表机器学习中数据缩放的卡通插图。一位高大的女性(代表具有大范围的数值特征)正在变小成一个儿童(代表该特征经过缩放后的较小范围)。红色箭头指示缩小过程,黄色的闪光表示缩放的积极影响。

𝙔𝙤𝙪 𝙢𝙞𝙜𝙝𝙩 𝙖𝙡𝙨𝙤 𝙡𝙞𝙠𝙚:

Samy Baladram

Samy Baladram

分类算法

查看列表8 个故事Samy Baladram

Samy Baladram

回归算法

查看列表5 个故事一个带有辫子和粉色帽子的卡通娃娃。这个“假人”娃娃,凭借其简单的设计和心形装饰的衬衫,形象地代表了机器中的假回归器概念。就像这个玩具一样的形象是一个简化的、静态的人物表现,假回归器是一种基础模型,用作更复杂分析的基准。

Mistral 7B 解析:迈向更高效的语言模型

原文:towardsdatascience.com/mistral-7b-explained-towards-more-efficient-language-models-7f9c6e6b7251?source=collection_archive---------1-----------------------#2024-11-26

RMS Norm、RoPE、GQA、SWA、KV 缓存等!

Bradney SmithTowards Data Science Bradney Smith

·发表于Towards Data Science ·42 分钟阅读·2024 年 11 月 26 日

--

“从零开始的 LLM 系列”第五部分——完整指南,帮助理解和构建大型语言模型。如果你有兴趣了解这些模型是如何工作的,我鼓励你阅读:

图片由作者提供,使用 Freepik AI 创作。

引言

Mistral 7B 于 2023 年 9 月发布,代表了向更小、更高效的大型语言模型(LLM)发展的一个重要里程碑。在过去几年中,LLM 性能的主要提升机制是模型规模,即增加模型中可学习的参数数量。近年来,这导致了参数数量达到数百亿的模型,但这些模型的训练和服务成本较高,并且推理时间较长。然而,通过精心的架构设计和注意力机制的进步,Mistral AI 在 LLM 开发方面取得了突破,利用少量的参数实现了与更大模型相当甚至更高的性能。本文将全面介绍 Mistral 7B 中的各个组件,这些组件使得这些效率提升成为可能。

注意: 在下一篇文章中,我们将探讨 QLORA,一种参数高效的微调技术,并展示如何针对任何下游任务微调 Mistral 7B 和增强版 NeMo 12B 模型。

目录

1 — Mistral 7B 概述

2 — 均方根归一化(RMS Norm)

3 — 旋转位置嵌入(RoPE)

4 — 分组查询注意力(GQA)

5 — 滑动窗口注意力(SWA)

6 — 滚动缓冲 KV 缓存

7 — SwiGLU 激活函数

8 — 结论

9 — 进一步阅读

1 — Mistral 7B 概述

1.1 — 介绍 Mistral AI

自 2022 年 11 月大型语言模型(LLM)热潮以来,许多竞争者涌现,挑战 OpenAI 的主导地位。ChatGPT 的发布使得生成式语言模型的兴趣激增,因此,更多的公司涌现出来推动这一研究向前发展也就不足为奇。

在这些新兴公司中,有一家名为 Mistral AI 的巴黎初创公司,由前 Meta 和 Google DeepMind 员工于 2023 年 4 月成立。其目标是创建专注于效率的强大大型语言模型,这一理念体现在他们的首个模型——Mistral 7B [1] 中。该模型具有四个主要特点:

  • 解码器单一架构: 基于原始 Transformer 中解码器模块的架构

  • 高效设计: 一个拥有较少参数的强大大型语言模型

  • 两种可用模型类型: 提供基础模型和指令模型

  • 强大性能: 在所有基准测试中表现优异,甚至与其更大同类模型相比也不遑多让。

1.2 — 解码器单一架构

上一篇文章中,我们讨论了谷歌的 BERT 模型,它基于原始 Transformer 架构中的编码器模块。仅编码器模型在 BERT 系列之外相对较少见,大多数在 2021 年后发布的 LLM 要么采用原始 Transformer 的旧版编码器-解码器设计,要么更常见地采用原始 GPT 推广的解码器单一架构。仅编码器设计使得 BERT 能够利用双向上下文,并在分类等任务中表现出色。然而,这一设计也限制了 BERT 在生成型应用中的能力,如聊天机器人任务(这可能是仅编码器模型逐渐减少的原因)。

相比之下,解码器单一模型使用单向上下文来预测序列中的下一个词元,这一过程被称为自然语言生成(NLG)。这些模型被应用于聊天机器人等应用中,例如虚拟助手、ChatGPT 等,用户输入提示,模型一次生成一个词元的适当响应。作为在 BERT 时代之后发布的模型,Mistral 也使用了解码器单一架构,主要用于 NLG 任务。

下面的对比展示了 BERT 在自然语言理解(NLU)上的专注与 Mistral 7B 在自然语言生成(NLG)上的专注。图片来源:作者。

1.3 — Mistral 7B 作为高效 LLM

趋向更大模型的趋势:

如前所述,LLM 的发展趋势是通过增加模型的大小来提升性能。一般的想法是,较大的模型(具有更多参数)能够更好地捕捉训练数据中的关系和细微差别,从而在推理过程中产生更好的输出。这种方法已被证明非常有效,导致模型在所有常见的性能基准测试中表现出色。这些大模型的例子包括 xAI 的 Grok-1(3140 亿参数)、谷歌的 PaLM 2(3400 亿参数),以及 OpenAI 的 GPT-4,尽管其参数数量未公开披露,但据信达到了万亿参数的范围。

大模型的缺点:

尽管这些更大的模型展示了高水平的性能,但它们也有一些显著的缺点。训练这些模型既费时又昂贵。大量的参数意味着每个优化步骤中需要更新许多权重和偏差,从而需要巨大的计算资源。这个问题在推理时依然存在,在没有足够强大硬件的情况下,提示这些模型可能会导致响应时间缓慢。其他缺点还包括由于更高的能源需求,环境和可持续性问题,这使得它们的碳足迹相比较小的模型更大。

Mistral 7B 作为更小、更高效的模型:

Mistral 7B 因其在变压器架构方面的进步而广为人知,这些进展使得该模型能够在减少参数数量的同时保持高性能。因此,Mistral AI 通过将重点从当前范式转移,推广更小的模型,领导了高效 LLM 的发展。这种方法具有几个优势,比如减少训练时间和成本,并且应对上述提到的可持续性问题。在接下来的部分,我们将探讨这些架构变化是什么,以及它们如何在更小的尺寸下实现更高性能的模型。

1.4 — 基础模型、对话模型和指令模型概述

不同的模型类型:

如果你在网上阅读过有关不同 LLM 的资料,可能会遇到“base”、“chat”和“instruct”这几个术语。Base指的是可以在下游任务中进行微调的模型标准版本,而chatinstruct则指的是在 base 模型基础上专门为聊天和指令任务微调的版本。聊天模型经过对话数据的微调,旨在用于聊天机器人应用,如虚拟助手和 ChatGPT 风格的用例。指令模型则旨在接收指令并作出回应。尽管这两者在微调上有些许差异(将在下文描述),但重要的是要认识到,两者的预训练是相同的。因此,尽管每个模型在各自领域的表现更好,但也可以将任意一个模型用于两种任务。

聊天与指令:

聊天模型是为对话交互设计的,旨在模拟类人对话。例如,聊天模型通常用于客户支持环境中的虚拟助手,其中输入格式更为非正式和灵活。相反,指令模型旨在根据指令执行特定任务。这里的示例包括代码生成和数据总结等任务。这些类型模型的输入格式更加结构化,需要更正式的提示。

Mistral 7B 中的模型类型:

Mistral 7B 既有 base 版本,也有 instruct 版本,尽管目前没有为 chat 特别微调的版本。然而,base 版本与上面描述的 chat 变体非常相似,可以以非结构化、非正式的方式进行交互。如需查看 Mistral AI 模型的完整列表,可以访问 Hugging Face 模型库中的 Mistral AI 页面。[2]

1.5 — LLM 基准测试的表现

Mistral 7B 还可以通过与更大、同类的模型相比的强劲表现来进行特征描述。在最初的宣传资料中,Mistral AI 将其新的 LLM 与 Meta 的 Llama 系列模型进行了比较:Llama 和 Llama 2(当时 Llama 3 尚未发布)。这些性能比较的图表如下所示,并已摘自 Mistral 7B 的论文[1]。

这些基准测试中的一些利用了零-shot 学习、少-shot 学习或两者的混合。零-shot 学习是指在没有明确遇到过的预训练数据的情况下,要求模型执行任务或回答问题。这要求模型从现有知识中进行泛化以提供答案。少-shot 学习则是指在提示中提供少量示例,帮助模型理解预期的答案格式或类型。

Mistral 7B 在一系列基准测试中与 Llama 和 Llama 2 的表现比较[1]。

上述比较的表格视图及每个基准的分数[1]。

整体趋势表明,Mistral 7B 在所有评估的指标上都优于 Llama 2 13B,且差距通常相当明显。更令人印象深刻的是,Mistral 7B 在大多数基准测试中也能够匹配或超越 Llama 1 34B 的表现。

为了便于可视化,作者将一些相似的基准测试分组为类别,如“知识”和“推理”。下面给出了这些类别的详细信息。

  • MMLU: 大规模多任务语言理解(MMLU)不是一个类别,而是一个单一的基准测试。这个评估旨在衡量一个模型在预训练阶段捕获知识的程度,采用零-shot 和少-shot 学习。这些问题涉及 57 个学科,包括科学、技术、工程、数学(STEM),以及人文学科、社会科学、法律、伦理等。MMLU 由 Hendrycks 等人在 2021 年提出,已被 NLP 社区作为评估 LLM 表现的事实标准[3]。

  • 知识: 知识类别的结果是通过 NaturalQuestions 和 TriviaQA 基准测试的平均值得出的,这些基准使用 5-shot 学习。这些数据集包含一系列问题,用以检验模型从训练数据中获得的一般知识。

  • 推理: 推理类别的结果是通过 HellaSwag、Winogrande、PIQA、SIQA、OpenbookQA、ARC-Easy、ARC-Challenge 和 CommonsenseQA 基准测试的平均值得出的,这些基准使用零-shot 学习。这些数据集测试模型推理现实世界问题的能力。

  • 理解: 理解类别的结果是通过 BoolQ 和 QuAC 基准测试的平均值得出的,这些基准使用零-shot 学习。这些数据集侧重于基于上下文中的文本段落向模型提问,评估模型在对话中理解信息的能力。

  • AGIEval: 与 MMLU 类似,AGIEval 是一个单一的基准测试,而不是多个基准的类别。AGIEval,全称人工通用智能评估,旨在“专门评估基础模型在以人为中心的标准化考试中的表现,如大学入学考试、法学院入学考试、数学竞赛和律师资格考试”。作者认为,之前的基准测试偏向于适合机器和人工数据集的任务,而 AGIEval 则考察更接近人类水平的能力。AGIEval 于 2023 年由 Zhong 等人发布[4]。

  • 数学: 数学类别的结果是通过 GSM8K 和 MATH 基准测试的平均值得出的,这两个基准分别使用 8-shot 和 4-shot 学习。这些数据集包含基本运算(加法、减法、乘法和除法)的数学问题,并可能需要多步才能解决。

  • BBH: 与 MMLU 和 AGIEval 类似,BIG-Bench Hard(简称 BBH)是一个单一基准。BBH 包含来自更大 BIG-Bench 数据集的 23 个特别具有挑战性的任务,专注于评估模型未能超越平均人类评分者的任务。该基准由 Suzgun 等人于 2022 年提出[5]。

  • 代码: 代码类别的平均结果为零次尝试的 Humaneval 和三次尝试的 MBPP。这些基准测试评估模型从文本提示生成代码的能力。

1.6 — Mistral 7B 架构概述

自从 Transformer 首次问世以来,LLM 组件已经取得了长足的进步,因此现代 LLM 通常在原始设计基础上进行了许多改进。关于注意力机制和位置编码器的改进建议已经相对频繁地发布,研究人员正在竞相发现下一项技术,以推动该领域的进一步发展。

与其使命一致,Mistral AI 利用了这些进展来提高 Mistral 7B 的效率,成功构建了一个在参数量上只有原先一小部分的高性能模型。在接下来的章节中,我们将探讨这些进展,包括:

  • RMS 归一化 — 替代层归一化

  • 旋转位置嵌入(RoPE) — 替代绝对位置编码

  • 分组查询注意力(GQA) — 替代多头注意力机制

  • 滑动窗口注意力(SWA) — 提高训练和推理速度,尤其适用于长序列

  • 滚动缓冲区 KV 缓存 — 与 SWA 配合使用,提高训练和推理速度

  • SwiGLU 激活函数 — 替代前馈子层中的 ReLU

对 Mistral 7B 与原始 Transformer 架构的比较。图片来源作者,包含来自[13]的 Transformer 图示。

1.7 — BERT 参数比较

自 2018 年 GPT 和 BERT 发布以来,模型规模持续快速增长,出现了拥有数百亿参数的模型并不罕见。与同类模型相比,Mistral 7B 被认为是一个相对较小的模型。为了说明这一点,当时 BERT Large 在发布时被认为是非常庞大的,然而它仅包含 3.4 亿个参数,这也显示出这个领域在短短几年内取得的进展。对于跟随本系列的人来说,你可能还记得在第四部分中有一个表格总结了 BERT Base 和 BERT Large 的模型参数。下面已更新此表,包含与 Mistral 7B 的对比。

阅读此表格时需要注意几点:

  • 词汇表大小:尽管模型复杂性有所增加,Mistral 7B 的词汇表大小与 BERT 几乎相同。

  • 上下文长度:Mistral 7B 支持比 BERT 大 16 倍的上下文长度,允许分析更长的文档。这是大规模语言模型(LLM)中的一个趋势,带来了许多好处,比如在聊天机器人应用中可以处理更长的对话历史,允许在提示中使用诸如书籍等较长文本中的知识,等等。

  • 注意力头:Mistral 7B 将其查询矩阵分成 8 组,每组包含 4 个矩阵,且每组共享一个键值矩阵。这是由于分组查询注意力(GQA),我们将在本文后续部分讨论这一点。

BERT Base、BERT Large 和 Mistral 7B 关键参数的对比。图像由作者提供。

注意: 仅编码器和仅解码器模型在架构上基本相似,这可以通过比较原始 Transformer 中的编码器和解码器块来看到。除了额外的“多头注意力”和“加法与归一化”步骤,这些块之间的主要区别在于是否存在最终的“线性”层及其对应的 softmax 函数。这些附加组件使得解码器块(因此也使得编码器-解码器和仅解码器模型)能够执行下一个标记预测(NTP)。

2 — 均方根归一化(RMS Norm)

2.1 — 归一化和前馈子层简介

如果你跟随系列文章阅读,你可能已经注意到我们尚未涵盖 Transformer 架构中的“归一化”或“前馈”步骤。这两个组件(通常称为子层)在 Mistral 7B 中得到了改进,因此理解它们的功能以及为什么需要它们将非常有用。让我们现在来探讨这个问题。

归一化子层:

由于一种被称为协变量偏移的问题,基于 Transformer 的模型需要进行归一化。协变量偏移描述的是在模型中某些权重更新显著,而其他权重则没有发生更新的现象。权重分布的这种变化会对网络中的下一层产生连锁反应,在反向传播过程中导致进一步的不稳定更新,从而影响性能。归一化通过确保输入向量的均值和方差一致,从而标准化每一层的输入,这反过来帮助稳定学习过程。

前馈子层:

前馈步骤引入了非线性变换和额外的学习能力。简单来说,这些组件使得模型能够通过从训练数据中学习,决定如何最好地改进其对文本的内部表示。前馈块是由以下部分组成的浅层神经网络:输入层、一个隐藏层和输出层。在 Transformer 中,前馈网络的输入是来自标准化步骤的输出(稍后我们会看到,这对于 Mistral 7B 有所不同)。前馈网络接受这些输入序列的数值表示,并以一种有助于模型生成良好输出序列的方式更新它们。通过使用神经网络方法,我们消除了对模型如何增强这些表示施加严格规则的需要,而是允许模型通过反向传播学习如何最好地改变它们。

示例:

举个更具体的例子,考虑原始 Transformer 如何处理输入序列:“写一首关于一个人在河岸上钓鱼的诗”。

1. 标记化:将输入序列划分为标记writeapoemaboutamanfishingonariverbank。有关标记化的更多信息,请参见本系列第一部分

2. 嵌入:将每个标记映射到其对应的学习嵌入。这些是标记的向量表示,编码了它们的整体含义。有关嵌入的更多信息,请参见本系列第二部分

3. 多头注意力:将嵌入传入注意力块,以更新每个单词的向量表示,结合上下文信息。这确保了像bank这样的词根据其使用情况(例如河岸、银行等)获得更合适的向量表示。有关注意力块的更多信息,请参见本系列第三部分

4. 标准化:将来自注意力块的上下文嵌入传递到标准化块。在这里,输入的向量被标准化,以确保均值和方差的一致性,从而缓解协变量偏移的问题。

5. 前馈:将来自标准化步骤的输出传递到前馈子层。该步骤以有助于模型在后续过程中生成优美诗歌的方式更新每个标记的向量表示。更新向量表示的具体步骤不是硬编码的,而是通过反向传播由模型学习得到的。

6. 标准化:将前馈步骤的输出传递到另一个标准化块。步骤 3 到 6 会重复N次(其中N是编码器块的数量),然后向量表示会被传递到解码器块。

2.2— 层归一化概述(LayerNorm)

Transformer 使用一种名为 LayerNorm 的归一化方法,该方法于 2016 年发布,作为对当时神经网络中使用的旧 BatchNorm 方法的改进 [6]。LayerNorm 的目标是通过修改输入到某一层的分布,使其符合高斯(正态)分布,从而防止协变量偏移,因此得名“归一化”。

归一化子层的输入:

在 Transformer 中,归一化过程发生在每个 Attention 块和每个 Feed Forward 块之后。因此,归一化步骤的输入在每个位置都会有所不同:

  • 在多头自注意力之后:自注意力输入 + 自注意力输出

  • 在前馈网络之后:前馈网络输入 + 前馈网络输出

初看之下,可能会觉得将归一化块同时传递给 Attention/Feed Forward 块的输入和输出有些奇怪。然而,包含这两个组件对于实现强大的模型性能至关重要。

残差连接的必要性:

下图架构图显示了 Attention 和 Feed Forward 子层的输入通过 残差连接(红色高亮)传递给归一化子层。这些输入在归一化之前分别与 Attention 和 Feed Forward 的输出相加,因此“Add & Norm”标签中的“Add”即指这一过程。残差连接有助于解决一个被称为 梯度消失问题 的难题,这也是训练深度神经网络时常见的挑战。在反向传播过程中,梯度(损失函数对每个权重的偏导数)决定了权重更新的方向和大小。然而,这些梯度有时会在经过许多层后变得极其微小,导致某些权重的更新几乎可以忽略不计。这会导致网络中的早期层学习变得非常缓慢,因为它们的梯度接近于零。残差连接通过允许梯度更加直接地流向早期层,从而绕过一些中间层,缓解了这个问题。这条额外的路径有助于保持梯度的强度,确保稳定的更新,并防止模型“遗忘”在早期层学到的知识。简而言之,在每个归一化阶段包含残差连接为反向传播的梯度提供了额外的路径,防止模型在早期层学习缓慢。

Transformer 架构图的特写,展示了 Add & Norm 块的残差连接,红色高亮标出。图像由作者标注。

2.3 — 可视化 LayerNorm

LayerNorm 将输入到网络的分布转换,使其值遵循高斯分布。考虑下面图像中展示的示例,它关注于在 Attention 步骤后进行标准化。在这里,LayerNorm 的输入将是 Attention 输入和 Attention 输出的和,其结果是一个表示输入序列中每个标记的上下文标记嵌入矩阵(在本例中为“Write a poem about a man fishing on a river bank”)。该矩阵的维度为 L_max x d_model,其中 L_max 是输入序列的长度,d_model 是嵌入维度的数量。该矩阵的列存储了输入序列中每个标记的嵌入。例如,第一列存储了“write”的上下文嵌入,第二列存储了“a”的嵌入,以此类推。

可以通过直方图绘制频率图来逼近每个标记嵌入的值分布。下图展示了“bank”这个标记的嵌入的示例。在标准化之前,“bank”嵌入向量中的值的均值为 18.5,而标准化后,均值降至 0。标准化过程会分别应用于矩阵的每一列,每一列都根据其自身的均值和方差进行标准化。

使用 LayerNorm 进行标准化的概览,以输入序列“Write a poem about a man fishing on a river bank”为例。图片由作者提供。

2.4 — LayerNorm 公式

为了标准化标记嵌入,我们首先为每一列计算两个关键统计值:均值方差。这些值分别描述了数据的中心和离散程度。一旦这些值确定,就可以根据标准化公式调整输入向量中的每个值。让我们简要分解这些公式:

  • 均值:均值(平均值)描述了分布的中心,通过将一列中的所有值相加,并除以该列中值的数量(维度)来计算。

  • 方差:方差描述了分布中数据点的离散程度(变化量),通过计算每个数据点与均值之间的平均平方距离来获得。较高的方差表明数据点分布更广,而较低的方差则表明值集中在均值附近。使用平方差而非绝对差值,部分是出于历史原因,但也因为它提供了一个可微分的离散度衡量。这一性质在高级统计学中非常有用,因此方差成为该领域的标准度量。

  • 归一化:归一化过程涉及两个主要公式。第一个(见下图左侧的两个公式中的第一个)将列的当前分布转换为正态分布。通过从每个值中减去均值,使分布集中在 0,然后除以方差的平方根(即标准差)。这一除法确保结果分布的标准差为 1,这是正态分布的要求。为了防止数据没有分布时出现除以 0 的情况,增加了一个额外的项 ϵ。第二个公式通过使用两个参数:缩放因子 γ 和偏移量 β,对这些归一化后的值进行可学习的调整。这些参数通过反向传播在训练过程中被模型学习。γ 和 β 的值是针对每个特征(矩阵中的行)而非每个嵌入(矩阵中的列)进行的。因此,嵌入的每个维度都会使用不同的 γ 和 β 值进行变换。这使得模型能够在嵌入空间内学习灵活的变换,从而提高其表示数据中复杂模式的能力。

LayerNorm 过程中的四个关键公式。图像来源:作者。

2.5 — RMS 归一化简介

Mistral 7B 使用了一种对 LayerNorm 的改进,叫做均方根归一化(Root Mean Square Normalization),或RMS Norm,该方法由 Zhang 和 Sennrich 于 2019 年提出 [7]。作者假设,LayerNorm 的有效性是由于对值进行重新缩放(通过方差除法),而不是重新中心化(减去均值)。

因此,如果能够省略均值的计算,模型在训练阶段将显著加速。然而,问题在于,方差的计算本身也需要均值已知。因此,作者着手寻找一种新的重新缩放方法,最终形成了 RMS 归一化。

2.6 — RMS 统计量

用于重新缩放值的 RMS 统计量有一个简单的公式,如下所示。实质上,输入矩阵(嵌入)中每一列的值都会除以该列的平均平方值的平方根(因此称为“均方根”)。与 LayerNorm 类似,归一化的结果会通过一个可学习的参数 γ 进行缩放(注意这里不需要 β,因为作者认为重新中心化并不必要)。尽管这个变化较小,但将 LayerNorm 替换为 RMS Norm 会显著加速神经网络模型的训练,成为自 Transformer 发布以来 LLM 架构中的许多进展之一。

Mistral 7B 使用的归一化技术 RMS Norm 的公式。图像来源:作者。

3 — 旋转位置嵌入(RoPE)

3.1 — 位置编码器概述

与旧有架构(如循环神经网络)不同,基于 Transformer 的模型并不是按顺序处理输入的每个 token,而是并行处理所有输入 token。虽然这种并行处理提高了速度,但也导致了位置信息的丢失,因为 token 并不是按顺序处理的。因此,需要某种形式的位置信息编码将此信息注入到嵌入向量中,并且可以通过多种方式实现这一点。

绝对位置编码:

原始 Transformer 中引入的正弦位置编码技术使用正弦和余弦函数为输入序列中的每个 token 创建一个位置编码向量。这些向量通过向量加法被添加到学习到的嵌入中。位置编码仅依赖于 token 在序列中的绝对位置,而不依赖于输入序列本身。因此,位置为 0 的 token 将始终具有相同的位置信息编码,无论序列如何。因此,这种方法被称为绝对位置编码

这种方法的一个局限性在于,它只表示 token 的绝对位置,而不是它们之间的相对距离。例如,位置 3 和 5 之间的 token 距离与位置 103 和 105 之间的距离是相同的,但使用绝对位置编码时并不会捕捉到这一信息。从直观上讲,距离较近的 token 可能比距离较远的 token 更为相关,而编码这种相对位置的信息可以显著提高模型性能。

相对位置编码:

2018 年 4 月,谷歌的研究人员(包括原始 Transformer 论文的两位作者)发表了“Self-Attention with Relative Position Representations”一文,提出了一种新的位置编码范式[8]。作者探讨了相对位置编码的使用,这种编码不仅捕捉了 token 之间的相对距离,还包括它们的绝对位置。例如,在句子“Write a poem about a man fishing on a river bank”中,“poem”和“man”之间相隔三个词,正如“on”和“bank”之间相隔三个词一样。这种类型的位置信息编码已经在一些著名模型中得到应用,比如 Dai 等人的 Transformer-XL(2019)[9] 和谷歌的 T5(2020)[10]。

尽管相对位置编码提高了模型捕捉 token 之间关系的能力,但它显著增加了训练时间。随着模型的规模扩大,增加训练时间的组件变得不那么实际。此外,像 KV 缓存集成(我们将在本文后面讨论)这样的挑战也使得许多研究者放弃了这一技术。我们不会详细介绍原始的相对位置编码技术,但如果你有兴趣,强烈建议你阅读相关论文。

旋转位置嵌入(RoPE):

Rotary 嵌入由 Su 等人在 2020 年的论文 “RoFormer: Enhanced Transformer with Rotary Position Embedding” 中提出,提供了一种独特的编码位置信息的方法[11]。与正弦波编码直接将位置信息添加到 token 嵌入中不同,rotary 嵌入则对每个 token 的 查询向量和键向量 应用 旋转。每个 token 的旋转角度基于其在序列中的绝对位置。例如,在输入“write a poem about a man fishing on a river bank”中,poem(位于位置 2)的查询和键向量旋转了 2θ,而 man(位于位置 5)的查询和键向量旋转了 5θ,依此类推。注意,token 位置是零索引的,这意味着我们从 0 开始计数而不是 1(因此 write 被认为处于位置 0,且其查询和键向量没有旋转)。这种方法不仅捕捉了 token 的绝对位置,还捕捉了相对位置,因为 manpoem 之间相隔 3θ,表示它们之间有 3 个 token 的距离。

使用角位移编码位置信息还具有一些与现有的 Transformer 组件配合良好的优点。例如,自注意力机制在很大程度上依赖于点积操作,其公式已经考虑了查询和键之间的角距离。此外,如果在两个 token 前后添加更多的 token,它们之间的角距离保持不变。这使得输入序列可以进行修改,而不会显著改变位置信息,这与绝对位置编码方法不同。

3.2 — 实现 RoPE

上述大纲简要概述了 RoPE,旨在说明其核心概念,但技术实现包含了两个重要细节:

  1. 成对特征旋转:每个查询/键向量的特征在嵌入空间内成对旋转。

2. 多频率位置编码:查询/键向量中的每一对特征都会以稍微不同的角度旋转。

让我们来看看 RoPE 如何集成到基于 Transformer 的架构中,了解其实现背后的数学原理,并理解上述两个细节的含义,以及为什么它们对于 RoPE 的有效运行是必要的。

3.3 — 将 RoPE 集成到 Transformer 中:

使用 RoPE 的 Transformer 处理文本的步骤如下:

1. 分词与嵌入:和往常一样,过程从模型接收到输入序列开始,该序列被分词以生成 token ID 列表。这些 token ID 随后被转换为 token 嵌入,形成一个矩阵,其中每一列对应一个 token 的嵌入向量。

2. 归一化:在原始的 Transformer 模型中,位置信息直接添加到原始的令牌嵌入中。然而,在使用 RoPE 的模型中,令牌嵌入首先会进行归一化。此步骤通过防止协变量偏移来稳定训练,如前文所述(请参见第 2.1 节中的架构图)。

3. 计算查询、键和值矩阵:模型接着计算出注意力机制所需的查询、键和值矩阵(QKV)。这一过程通过将归一化的嵌入矩阵与相应的权重矩阵 W_QW_KW_V 相乘来实现。这里,结果矩阵的列分别表示每个令牌的查询、键和值向量。查询和键矩阵用于计算注意力得分,然后这些得分对值矩阵中的值进行加权,从而在注意力模块中生成上下文感知的输出(有关更详细的解释,请参见第三部分)。

4. 旋转查询和键矩阵:查询和键矩阵被旋转以包含位置信息。由于在计算注意力得分时仅涉及查询和键矩阵,因此位置信息仅添加到这些矩阵中。因此,值矩阵不进行旋转。在计算出注意力得分后,值矩阵仅提供将根据得分更新的嵌入。这就是为什么在架构图中值矩阵中省略了位置编码符号的原因。

3.4 — 特征对的旋转

RoFormer 论文首先考虑了一个简单的情况,其中每个令牌嵌入只有两个维度(d=2)。在这个示例中,应用标准的二维旋转矩阵对令牌的查询和键向量(分别表示为 qk)进行旋转非常简单。下面的方程展示了旋转后的查询向量 q_rot 和键向量 k_rot,它们来自归一化的令牌嵌入。旋转矩阵 R 是一个尺寸为 d x d 的方阵:在此情况下,R 是 2x2。旋转矩阵还依赖于角度 θ(我们稍后会讨论)和乘数 mm 由令牌在序列中的绝对位置决定。也就是说,对于第一个令牌,m = 0;对于第二个令牌,m = 1,以此类推。

注意:以下方程展示了单一查询和键向量的简化示例,而非整个查询和键矩阵。实际上,这一操作将在矩阵级别而非向量级别进行,以并行化处理并显著提高效率。然而,底层概念保持不变。

旋转后的查询(上)和键(下)向量的方程,其中包含通过 RoPE 过程编码的位置信息。图片来源:作者。

这些方程展示了简单二维情况的过程。在实际应用中,大多数模型使用数百甚至数千维的嵌入。旋转这么多维度的向量变得非常复杂,导致一次性旋转整个向量变得不实际。为了解决这个问题,作者提出了通过对每一对特征应用二维旋转矩阵来每次旋转两个元素。这样做的好处是速度更快,操作更简单,但限制了模型只能使用偶数维度的嵌入(尽管通常情况下就是如此)。

下面的公式展示了 d 维嵌入向量的旋转矩阵形式。你会看到沿对角线有重复的 2D 旋转矩阵,剩余的元素填充为零。由于嵌入向量有 d 个维度,所以有 d/2 对特征,因此对角线上有 d/2 个旋转矩阵。

RoPE 中使用的旋转矩阵 R 的一般形式。图片由作者提供。

3.5 — 多频率位置编码

在上面的公式中,你可能会注意到每一对特征都有自己独特的下标θ,表示每一对特征是按稍微不同的角度旋转的。你可能会想为什么每一对特征不按相同的角度旋转。简短的回答是,使用常数θ是可行的,但为每一对特征调整θ可以提高模型性能。θ的变化使得模型能够以更细粒度的方式捕捉嵌入信息,即在特征对的层级上,而不仅仅是在嵌入层级上。这被称为多频率位置编码,这种技术使得模型能够学习嵌入空间的信息,并在注意力机制中创建更丰富的数据表示。

确定旋转角度 θ:

这个难题的最后一部分是建立一个关于θ的公式。作者提出了下面左侧的方程,它将旋转角度作为令牌嵌入维度 d 和特征对索引 i 的函数进行计算。这个方程的形式直接受原始 Transformer 的正弦编码(右侧)的启发,作者特别指出这个选择是为了确保“长期衰减特性”[11]。这描述了一个特性,即远距离的令牌之间的连接比近距离的令牌之间的连接要少,这一点在原始 Transformer 中表现良好。

注意: 如果你以前见过正弦编码的公式,你可能会记得分子通常用“pos”而不是“m”表示。“pos”和“m”都表示令牌在输入序列中的绝对位置,因此我们在这里使用相同的符号写出两个方程,以帮助更容易地进行视觉比较。

RoPE(左)和正弦编码(右)的位置编码方程对比。图片由作者提供。

3.6 — 进一步提高计算效率

简而言之,RoPE 通过将d维查询和键向量通过d x d旋转矩阵进行旋转来引入位置信息,如下所示。这里,x用作通用符号,表示查询向量(q)或键向量(k):

在 d 维空间中,RoPE 的一般形式,其中通用地表示被旋转的查询或键向量。图片由作者提供。

实际上,由于矩阵乘法的性质,这种方法仍然相当缓慢。幸运的是,我们可以通过一个技巧来加速这一过程。旋转矩阵包含许多零元素,因此被称为稀疏矩阵。由于这种稀疏性,我们可以重新构造方程的形式,仅使用逐元素乘法和向量加法——这两种操作要快得多。下面的方程展示了在实际模型中使用的 RoPE 的高效实现,其中⊙表示逐元素乘法。

表示逐元素向量乘法和加法的 RoPE 方程的扩展形式。图片由作者提供。

你可以在 HuggingFace 的 Llama 仓库中查看 PyTorch 实现的 RoPE 公式[12]。下面是对方程进行改写的版本,旨在帮助理解代码:

上述方程的改写形式,更加贴合 Hugging Face GitHub 仓库中 Llama 模型使用的 PyTorch 实现的 RoPE。图片由作者提供。

def rotate_half(x):
    """Rotates half the hidden dims of the input."""

    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors."""
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

这 10 行代码使得像 Llama 和 Mistral 7B 这样的模型能够实现丰富的位置信息编码,同时保持快速的训练和推理速度。RoPE 的优势可以总结为:

  • 高效实现了编码令牌之间的相对位置

  • 通过更好地学习短期和长期依赖关系,改进了模型在长序列上的表现。

  • 容易与现有的点积自注意力机制兼容

4 — 分组查询注意力(GQA)

第三部分中,我们详细介绍了自注意力机制,并简要介绍了多头注意力(MHA),这是自注意力在原始 Transformer 架构中的一种具体实现。从那时起,更新的模型采用了改进的注意力机制,优化了训练和推理的效率。Mistral 7B 使用了分组查询注意力(GQA),该机制本身是建立在多查询注意力(MQA)基础上的。在本节中,我们将按时间顺序探讨这些技术,以了解 Mistral 7B 如何执行自注意力。

4.1 — 多头注意力(MHA)概述

多头注意力(MHA)是在 2017 年的论文“注意力就是一切” [13] 中引入的,通过将注意力机制分成多个头部来扩展标准的自注意力。在标准的自注意力中,模型学习一个单独的权重矩阵集合(W_QW_KW_V),将标记嵌入矩阵 X 转换为查询、键和值矩阵(QKV)。然后使用这些矩阵计算注意力分数,并使用上下文信息更新 X

相比之下,MHA 将注意力机制分割为H个独立头部,每个头部学习自己的较小权重矩阵集合。这些权重用于计算一组较小的、头部特定的查询、键和值矩阵(表示为Qh*、*KhV^h)。每个头部独立处理输入序列,生成不同的注意力输出。然后将这些输出连接在一起(叠加在一起),并通过最终的线性层传递,产生更新的X矩阵,如下图中显示的Y,具有丰富的上下文信息。

通过引入多个头部,MHA 增加了注意力过程中可学习参数的数量,使模型能够捕捉数据中更复杂的关系。每个头部学习自己的权重矩阵,使它们能够关注输入的不同方面,如长距离依赖关系(远距离单词之间的关系)、短距离依赖关系(附近单词之间的关系)、语法结构等。总体效果产生了对输入序列更细致理解的模型。

4.2 — 多头注意力逐步

让我们逐步走过这个过程,展示每个阶段使用的方程式及其维度。这些步骤的摘要在最后的单个图表中给出。

1. 生成一个标记嵌入矩阵,X

首先,输入序列被标记化,标记 ID 被映射到它们学习的嵌入中,并添加位置信息。这产生了一个大小为L_max x d 的矩阵,其中L_max是输入序列的最大长度,d是每个标记的嵌入维度。这给出了标记嵌入矩阵X,它沿着列存储标记嵌入向量。

标记嵌入矩阵 X,它构成了多头注意力过程的输入,以及其维度。L_max 由最大序列长度给出,d 表示嵌入维度的数量。作者提供的图片。

2. 为每个头部计算查询、键和值矩阵:

接下来,矩阵X被传递给每个头进行处理。每个头都有一组自己的查询、键和值的权重矩阵(分别表示为W_Qh*、*W_KhW_V^h),其维度为d x d_H,其中d_Hd/H给出。这些权重矩阵通过与 X 的预乘得到该头的查询、键和值矩阵(Qh*、*KhV^h),其维度为L_max x d_H

每个头的查询、键和值矩阵的方程式。d_H表示矩阵中的列数,由嵌入维度(d)除以头的数量(H)得到。图片由作者提供。

注意: 在本解释中,我们假设 W_Qh*、*W_KhW_V^h* 都具有相同的维度d* x d_H。这不是一个严格的要求。在某些实现中,查询、键和值的权重矩阵可能具有不同数量的列,分别表示为d_Qd_Kd_V。然而,实际上,通常会看到d_Q = d_K = d_V = d_H,正如我们在这里所看到的那样。值得注意的是,由于这个原因,你还会看到有些人将d_H 简单地表示为d_K(就像我们在第三部分中做的那样),因为它们是等效的。*

3. 计算每个头的注意力权重:

对于每个头,使用查询和键矩阵按下式计算注意力权重,生成一个维度为L_max x L_max的矩阵。为每个头使用不同的权重矩阵可以捕捉到序列中的不同关系,例如句法或语义模式,从而提高模型学习和生成文本的能力。

用于计算每个头的注意力权重的方程式,它是头特定的查询和键矩阵的函数。图片由作者提供。

4. 计算每个头的注意力输出:

在每个头中,注意力权重被用来预乘相应的值矩阵,从而得到注意力输出矩阵,其维度为L_max x d_H

用于计算每个头的注意力输出的方程式,它是头特定的查询(Query)、键(Key)和值(Value)矩阵的函数。图片由作者提供。

5. 拼接注意力输出:

每个头的注意力输出随后通过拼接的方式进行组合。也就是说,构建一个新的矩阵,其元素就是注意力输出的元素按顺序堆叠在一起。矩阵的顶部由第一个头的输出填充,然后是第二个头的输出,依此类推。由于这个矩阵由H个较小的矩阵组成,每个矩阵的维度为L_max x d_H,所以大矩阵的维度为L_max x d(请记住,d = H x d_h)。

用于将每个头的注意力输出拼接成单一矩阵的方程式。图片由作者提供。

6. 应用最终的线性变换:

最后,连接后的矩阵通过一个线性层处理,可以通过以下矩阵乘法来数学表示。这个层的权重 W_O 在训练过程中学习,并将连接后的输出转换为输出矩阵 Y。这个输出通过改善嵌入中存储的上下文信息,来提升输入序列 X 的表示能力。

多头注意力步骤的输出方程式 Y 是通过将每个头的连接输出与矩阵 W_O 相乘得到的,矩阵 W_O 的值是通过反向传播使用线性层学习得到的。图像来自作者。

多头注意力总结:

以下图片展示了 MHA 过程的摘要:

多头注意力过程概述。图像来自作者。

4.3 — 多查询注意力(MQA)

多头注意力自 2017 年推出以来,被证明非常有效,产生了许多最先进的模型。然而,MHA 存在一个主要的缺点:该技术对内存的消耗极大。原因是每个注意力头都需要在内存中存储较大的键(Key)和值(Value)矩阵,这会造成瓶颈,限制了在给定硬件配置下可使用的模型的整体大小。为了应对这个问题,2019 年提出了多查询注意力(MQA),并首次出现在 Noam Shazeer(原始 Transformer 的作者之一)在论文《Fast Transformer Decoding: One Write-Head is All You Need》[14] 中。 在 MQA 中,相同的键和值矩阵在所有头之间共享,只有查询矩阵(Query)是特定于每个头的。这种方法显著减少了内存的使用,代价是性能略有下降。以下图展示了 MHA 和 MQA 过程之间的区别。

多查询注意力过程概述。图像来自作者。

4.4 — 增量推理

论文还描述了一种重要的优化技术,称为 增量推理,这对于随着大规模语言模型(LLM)尺寸的增大而提高效率是必需的。在这种方法中,模型在预测新标记时不会为每个时间步重新计算查询(Query)、键(Key)和值(Value)矩阵。相反,模型会利用来自前一个时间步的缓存值。以下是这个过程的概述:

1. 计算 Q_hKV

模型为每个注意力头计算一个查询矩阵(Q_h)以及所有头共享的 Key(K)和 Value(V)矩阵,基于输入序列计算得出。KV矩阵中的值会存储在KV 缓存中,以供后续注意力计算使用(我们将在第六部分中详细讨论)。Q_h矩阵中的值不被缓存,因为只有新 token 的查询向量会在下一时间步使用(有关查询、key 和 value 之间的区别,请参见第三部分中的数据库类比)。

2. 预测 x_new

然后,使用 Q_hKV 矩阵计算注意力输出,并将这些输出合并以生成输入序列的上下文嵌入。这些嵌入用于预测输出序列的第一个 token,x_new

3. 计算 q_(new,h)

新的 token 会被附加到输入序列中,并为每个头计算相应的查询向量,q_(new,h),其计算公式如下:

q_(new, h)的计算公式,表示为生成的最新 token 的查询向量,该向量将用于后续的注意力计算。图片来源:作者。

4. 注意力步骤:

查询向量 q_(new,h) 将与缓存的KV矩阵结合,使用以下公式生成注意力输出:

使用最新生成的 token 查询向量 q_new 的注意力步骤公式。图片来源:作者。

5. 更新 KV 缓存:

新 token 的 key 和 value 向量(k_newv_new)通过以下公式计算:

这些向量将被附加到缓存的KV矩阵中。

6. 重复过程:

该过程会重复进行,模型每次预测一个 token,直到生成序列结束(EOS)token。

4.5 — 分组查询注意力(GQA)

分组查询注意力(GQA)由谷歌研究人员在 2023 年提出,相关论文为“GQA: 从多头检查点训练通用多查询转换器模型” [15],可视为 MHA 和 MQA 的广义形式。在 GQA 中,Key 和 Value 矩阵在G个头的组之间共享,其中组的大小由用户决定。

如果所有组都只包含一个头(G=1),每个头都有自己独特的 Key 和 Value 矩阵,这相当于 MHA。另一方面,如果每个头属于一个单独的组(G=H),所有头共享相同的 Key 和 Value 矩阵,这相当于 MQA。GQA 的优势在于可以选择一个组大小,使得性能损失最小,同时显著提高内存效率。下图展示了 MHA、MQA 和 GQA 的对比,取自 GQA 论文。

多头、分组查询和多查询注意力的对比。图像来源于[15]。

4.6 — 分组查询注意力的好处

GQA 的好处最好通过以下图表来总结,这些图表来自原始论文。它们比较了使用 MHA、MQA 和 GQA 的 T5 Large 和 T5 XXL 模型的性能和处理时间,其中 T5 指的是 Google 在 2019 年发布的一系列编码器-解码器 Transformer 模型(H=64)[16]。左侧图表显示,虽然 MHA 提供了最佳的性能,但它也是最慢的。相比之下,MQA 实现了最快的运行时间,但牺牲了性能。GQA 则达到了平衡,提供了高性能且显著减少了运行时间。右侧图表显示了组数与运行时间的关系。注意,在这里使用两个组,每个组有 32 个头(G=32),在保持强劲性能的同时,显著提高了运行时间。因此,许多开发者现在选择使用 GQA,接受性能略微下降,以换取训练和推理的巨大效率提升。

多头、多查询和分组查询注意力的性能对比。左侧图表显示了性能与运行时间的关系,表明 GQA 在保持与 MQA 相似的运行时间的同时,达到了与 MHA 相似的性能。右侧图表显示了 GQA 中组数(G)与运行时间的关系,其中 G=32 时提供了强劲的性能和较低的运行时间。图像来源于[15]。

5 — 滑动窗口注意力(SWA)

5.1 — 因果掩码概述

Mistral 7B 支持比像 BERT 这样的模型更长的上下文长度,这得益于架构选择,例如使用滑动窗口注意力(SWA)。要理解 SWA,我们首先需要探索掩码自注意力,这是 Transformer 架构中的一个关键组件。如果你查看原始的 Transformer 架构图,你会看到其中一个解码器的注意力模块被标记为“掩码多头注意力”而不是“多头注意力”。这个区别看似微小,但对于训练这类模型至关重要。

当一个 Transformer 处理输入序列时,编码器通过分词、嵌入、位置编码和自注意力创建一个内部的数值表示。在编码器中,自注意力利用完整的双向上下文,使得每个标记都能够关注序列中所有其他标记。然后,解码器在自回归过程中迭代生成序列,其中每个新标记都是基于先前生成的标记进行预测的。在这种设置下,标记只能关注序列中先前的标记,因为未来的标记尚未生成。这就是之前提到的单向上下文。

为了在训练期间复制这种行为,在注意力机制中应用了因果掩码。这个掩码确保 token 不能“看到”(即关注)未来的 token,通过将它们屏蔽掉,因此在“掩码多头注意力”中的“掩码”一词。在训练过程中,模型生成 token 并将其预测与期望的输出进行比较,通过反向传播更新其权重。尽管训练期间已知完整的输出序列,但因果掩码阻止了模型利用这一知识,确保训练过程模拟了模型在推理时的行为。

5.2 — 从掩码到滑动窗口

滑动窗口注意力首次由 Beltagy 等人在 2020 年的论文“Longformer: The Long-Document Transformer” [17]中提出,并将掩码的概念扩展到模型的所有注意力模块,包括编码器和解码器。其思想是将注意力限制在一个局部窗口内,窗口大小为w,即当前 token 前后可关注的 token 数量。这减少了每个 token 关注的 token 数量,从而将注意力步骤的时间复杂度从 O(L_max²)降低到 O(w x L_max)。在编码器中,token 仍然可以在定义的窗口范围内关注其他前后 token,而在解码器中,token 只会关注先前生成的 token,从而保持自回归特性。然而,注意力的范围进一步限制在滑动窗口内。SWA 引入的主要变化是,注意力的范围被限制为窗口的大小,从而减少了计算开销,而不会牺牲模型处理局部上下文的能力。

5.3 — 实现滑动窗口注意力

因果掩码和 SWA 在注意力机制中的应用位置相同:即在 softmax 函数之前。超出允许范围的 token(由于因果约束或滑动窗口限制)其注意力得分将被替换为负无穷大。当应用 softmax 函数时,这些掩码得分会消失(因为 e^-∞=0)。这确保了只有未被掩盖的 token 才会贡献正常化后的注意力权重,且有效 token 的注意力权重总和为 1,而被掩码的 token 对输出没有影响。下图展示了原始注意力、带因果掩码的注意力和滑动窗口注意力的比较。

在转换为注意力权重之前,原始注意力、带因果掩码的注意力和滑动窗口注意力的注意力得分比较。图片由作者提供。

6 — 滚动缓冲区 KV 缓存

6.1 — 滚动缓冲区 KV 缓存概述

在 4.4 节中,我们讨论了增量推理作为一种优化技术,它利用标准的 KV 缓存。其工作原理是对输入序列计算查询、键和值矩阵一次,利用它们生成输出序列的第一个标记。之后,键和值矩阵被缓存。当生成后续标记时,最近生成的标记被用来计算查询向量(而不是矩阵)和相应的键值向量。这些新的键值向量会被附加到缓存的键值矩阵中。这种方法使得模型能够高效地生成新标记,因为它只需要计算一个查询向量和对缓存的键值矩阵进行小范围的更新,而不是每个时间步都重新计算完整的查询、键和值矩阵。

滚动缓存 KV 缓存进一步通过利用滑动窗口注意力中的滑动窗口来扩展这一方法。“滚动缓存”指的是缓存中的键值矩阵仅存储当前注意力窗口内的标记信息。因此,缓存可以“遗忘”窗口外的标记,大大减少内存使用,同时保持生成准确标记所需的关键信息。这些创新使得模型能够高效处理长输入,在不引发过度内存使用的情况下使得 32,000 标记的上下文长度成为可能。

6.2 — 实现滚动缓存

与标准的 KV 缓存不同,标准的 KV 缓存随着每个标记的预测而矩阵大小增大,而滚动缓存在推理过程中始终保持固定大小,这一大小由注意力窗口决定。随着窗口向前滑动,缓存通过用新进入窗口的标记的键值向量替换当前窗口外的标记的键值向量来更新。这确保了缓存只存储与当前活动上下文相关的信息,从而减少了内存使用。

以下图片来自 Mistral 7B 论文,展示了针对三个示例句子的滚动缓存概念。对于句子“This is an example of…”,缓存的窗口大小为 4 个标记。最初,标记按顺序添加:Thisisanexample。当第五个标记of添加时,第一个标记This被移除,以保持窗口大小。缓存继续进行这种滚动过程,确保在任何给定时间只存储最新的 4 个标记。

针对窗口大小为 4 的滚动缓存 KV 缓存的概述。图片摘自[1]。

6.3 — 预填充和分块

Mistral 7B 论文还引入了预填充分块的概念,这为推理过程中减少时间和内存使用提供了更多方法。

预填充指在增量推理之前,将输入序列中的所有标记的键和值向量填充到 KV 缓存中。此过程确保输入序列的静态部分(例如提示)在提前完全处理,从而减少生成新标记时的冗余计算。

块化(Chunking)通过将输入划分为固定长度的部分,称为块(chunks),来解决处理长序列长度的问题,每个块的长度等于注意力机制的窗口大小。为了防止内存过载,每个块的键和值矩阵会单独计算,并迭代地添加到缓存中。块化可以在推理过程中继续使用,因为更多的标记会被生成。最新块中的标记只能关注自身以及存储在前一个缓存块中的标记(只要它们在上下文窗口内)。下图展示了这一过程,图片来源于 Mistral 7B 论文。

KV 缓存的概述,其中输入序列已经被预填充到三个块中。最后一个块中的标记只能关注自身和前一个块,只要这些标记在本地上下文窗口内。图片来源于[1]。

7 — SwiGLU 激活函数

7.1 — 激活函数回顾

激活函数是神经网络中至关重要的组成部分,广泛存在于 Transformer 模型中,允许学习输入数据中的复杂模式。当来自前一层神经元的激活值传递到下一层时,它们会与权重相乘并求和,产生加权和(表示为 z)。由于加权和是通过简单的乘法和加法操作形成的,因此修改输入激活值的过程被描述为线性变换。为了捕捉更复杂的关系,使用非线性的“激活”函数将 z 值映射到 0 到 1(或-1 到 1,具体取决于函数)的范围内。

最早广泛使用的激活函数之一是Sigmoid 函数,它平滑地将大负和映射为 0,将大正和映射为 1。它的主要特点是输入在中点(接近 0)附近的微小变化会导致输出的小而平滑的变化,从而帮助稳定学习过程。

sigmoid 激活函数的图示及其方程,用于将权重和的线性组合映射到非线性输出。图片由作者提供。

7.2 — 修正线性单元(ReLU)

尽管 Sigmoid 激活函数最初很受欢迎,但它也存在一些问题,其中最主要的是我们在第 2.2 节讨论的梯度消失问题。为了解决这些局限性,Rectified Linear Unit(ReLU)激活函数在 1975 年由 Kunihiko Fukushima 在论文“Cognitron: A Self-Organizing Multilayered Neural Network”中提出[18]。

ReLU 激活函数通过将负输入值(z<0)的输出设为零,并线性映射正输入值(z>0)来简化计算。与 Sigmoid 不同,ReLU 避免了对高度正输入的饱和,保持了对变化的敏感性,从而在深度网络中实现更高效的学习。

注意: 饱和描述了一种激活函数,其输出几乎不受输入变化的影响,导致梯度减小并阻碍有效的权重更新。ReLU 对正值的线性行为避免了这个问题。

修正线性单元(ReLU)激活函数及其方程的图示。图片由作者提供。

7.3 — 门控线性单元(GLU)

门控线性单元(GLU)由 Dauphin 等人在 2017 年提出,首次出现在论文《使用门控卷积网络进行语言建模》中[19]。虽然 ReLU 激活函数在现代神经网络架构中仍然广泛使用,但由于 GLU 能更好地捕捉复杂的语言模式和关系,它在语言建模任务中越来越受到欢迎。

GLU 的一个关键特性是每个单元内部的门控机制,该机制动态调整激活输出。该机制包含一个额外的学习门,数学表达式为 z1σ(z2),其中 z1 为主要输入,z2 作为门。第二个输入 z2 通过 Sigmoid 激活函数 σ(z2) 传递,控制信息流动,提供选择性激活的机制。这个双输入设计将 GLU 与 ReLU 区分开来,提供了一种更为细致的激活函数,有助于减轻神经元永远不活动的风险(这是 ReLU 常见的问题)。我们在这里不深入探讨这些复杂的细节,但如果你对 GLU 感兴趣,建议阅读原始论文。

门控线性单元(GLU)激活函数及其方程的图示。图片由作者提供。

7.4 — Swish 门控线性单元(SwiGLU)

Swish 门控线性单元(SwiGLU)作为对常规门控线性单元(GLU)的改进而提出,并在谷歌研究的 2022 年论文《PaLM: 使用路径模型扩展语言建模》中首次亮相,配合 PaLM 模型一起使用[20]。通过将 Swish 激活函数(表达式为 zσ(z)) 与 GLU 的门控机制结合,SwiGLU 提供了更强的表达能力,更好地模拟数据中的复杂关系,特别是在语言建模任务中表现突出。注意 Swish 和 GLU 函数的区别:Swish 是单输入函数,而不像 GLU 那样是双输入函数。

Mistral 7B 在其前馈子层中使用了 SwiGLU 激活函数,增强了从训练数据中提取有意义模式的能力,并提高了推理过程中的性能。这一改进有助于 Mistral 7B 在处理复杂语言结构和大上下文窗口时的有效性。

一个 Swish Gated Linear Unit(SwiGLU)激活函数的图像及其方程。图片由作者提供。

8 — 结论

随着 Mistral 7B 的发布,Mistral AI 在模型大小成为推动性能的主要因素时进入了 LLM 领域。Mistral AI 并没有追随日益增大的模型趋势,而是通过强调创新的、内存高效的设计,在参数量的极小部分上取得了令人印象深刻的结果。Mistral 7B 的成功证明了强大的性能并不总是需要庞大的模型,而战略性的设计选择可以使更小的模型与其更大的对手相当,甚至超越它们。

在这一方法的基础上,Mistral 继续推动效率和性能的边界,拓展到如 Mixture of Experts(Mixtral 8x7B)、语言-视觉模型(Pixtral)以及移动领域(Mistral 3B)等多个领域。随着公司的发展,值得关注的是,它们如何继续推动更小模型的艺术前进。

9 — 进一步阅读

[1] Jiang, Albert Q. 等,Mistral 7B(2023),arXiv 预印本 arXiv:2310.06825。

[2] Hugging Face,Mistral AI(2024),HuggingFace.co。

[3] Hendrycks, D. 等,Measuring massive multitask language understanding(2020),arXiv 预印本 arXiv:2009.03300。

[4] Zhong, W. 等,AGIEval: A human-centric benchmark for evaluating foundation models(2023),arXiv 预印本 arXiv:2304.06364。

[5] Suzgun, M. 等,Challenging big-bench tasks and whether chain-of-thought can solve them(2022)arXiv 预印本 arXiv:2210.09261。

[6] Ba, J. 等,Layer Normalization(2016)arXiv 预印本 arXiv:1607.06450。

[7] Zhang, B. 和 Sennrich, R.,RMS Normalization(2019)预印本 arXiv:1910.07467。

[8] Shaw, P. 等,Self-Attention with Relative Position Representations(2018)arXiv:1803.02155。

[9] Dai, Z. 等,Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context(2019)arXiv:1901.02860。

[10] Raffel, C. 等,Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer(2019)arXiv:1910.10683。

[11] Su, J. 等,ROFORMER: ENHANCED TRANSFORMER WITH ROTARY POSITION EMBEDDING(2023)arXiv:2104.09864。

[12] Hugging Face, Llama 模型建构 (2024). GitHub

[13] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, 和 I. Polosukhin, Attention is All You Need (2017), Advances in Neural Information Processing Systems 30 (NIPS 2017)

[14] Shazeer, N., 快速转换器解码:只需一个写头 (2019) arXiv:1911.02150

[15] Ainslie, J., 等人, GQA:从多头检查点训练广义多查询转换器模型 (2023) arXiv:2305.13245

[16] Raffel, C., 等人, 探索统一文本到文本转换器在迁移学习中的极限 (2023) arXiv:1910.10683

[17] Beltagy, I., 等人, Longformer:长文档转换器 (2020) arXiv:2004.05150

[18] link.springer.com/article/10.1007/BF00342633

[19] Dauphin, Y. N., 等人, 使用门控卷积网络的语言建模 (2017) arXiv:1612.08083

[20] Chowdhery, A., 等人, PaLM:利用路径扩展语言建模 (2022) arXiv:2204.02311

Mistral AI 与 Meta:对比顶级开源 LLM

原文:towardsdatascience.com/mistral-ai-vs-meta-comparing-top-open-source-llms-565c1bc1516e?source=collection_archive---------1-----------------------#2024-01-23

这是对 Mistral 7B 与 Llama 2 7B、Mixtral 8x7B 与 Llama 2 70B 的比较。

Luís RoqueTowards Data Science Luís Roque

·发布于Towards Data Science ·阅读时间:16 分钟·2024 年 1 月 23 日

--

本文由 Rafael Guedes 和我共同撰写。

介绍

在自然语言处理(NLP)的最新发展中,特别是在大型语言模型(LLMs)方面,重点是提升模型性能,这通常会导致模型规模的增大。可以预见的是,模型规模的扩大也增加了计算成本和推理延迟,带来了在实际应用场景中部署和使用 LLM 时的障碍。

Mistral AI 是一家总部位于巴黎的欧洲公司,他们一直在研究如何提高模型性能,同时减少部署 LLM 所需的计算资源,以便能在实际用例中应用。Mistral 7B 是他们开发的最小型 LLM,它将两个创新概念引入到传统的 Transformer 架构中,分别是 Group-Query Attention(GQA)和 Sliding Window Attention(SWA)。这些组件加速了推理速度,并减少了在解码过程中对内存的需求,从而实现更高的吞吐量,并能够处理更长的令牌序列,同时不牺牲生成的响应质量,相比于在基准数据集上的 Llama 2 7B 表现更佳。

Mistral 7B 并不是他们唯一开发的模型,他们还创建了 Mixtral 8x7B 来与更大规模的 LLM(如 Llama 2 70B)竞争。除了使用 GQA 和 SWA 外,这个版本还增加了第三个…

Mistral-NeMo: 通过量化 Minitron 减少 4.1 倍大小

原文:towardsdatascience.com/mistral-nemo-4-1x-smaller-with-quantized-minitron-9d6ad7b70981?source=collection_archive---------7-----------------------#2024-08-29

剪枝、知识蒸馏和 4 位量化如何使先进的 AI 模型变得更加易于访问和具备成本效益

本杰明·马里Towards Data Science 本杰明·马里

·发表于Towards Data Science ·9 分钟阅读·2024 年 8 月 29 日

--

作者提供的图片 — 制作自Pixabay的插图

NVIDIA 的 Minitron 通过剪枝最不重要的权重来压缩大型语言模型(LLMs),然后通过知识蒸馏进行重新训练。这种方法显著减少了模型的大小,同时保持了其准确性。

NVIDIA 发布了 Llama 3.1 和 Mistral-NeMo 的 Minitron 版本,分别将其参数数量从 8B 减少到 4B,以及从 12B 减少到 8B。

为什么这很重要?

虽然 Mistral-NeMo 无法在消费级 GPU 上运行,但其 Minitron 版本可以。一个 24 GB 的 GPU 就足够了。然而,这也可以通过对 Mistral-NeMo 进行量化来实现。4 位量化方法现在已经足够准确。

但是,如果我们也能对 Minitron 模型进行量化呢?对于一个经过 Minitron 剪枝的模型,量化是否仍然足够准确?

例如,Mistral-NeMo-Minitron 的 4 位版本可以在 8 GB 的 GPU 上运行,显著降低推理成本。

在这篇文章中,我回顾了 Minitron 方法,探讨了如何通过剪枝和知识蒸馏来压缩大型语言模型(LLMs)。我们将…

Mistral 与 Mixtral:比较 7B、8x7B 和 8x22B 大语言模型

原文:towardsdatascience.com/mistral-vs-mixtral-comparing-the-7b-8x7b-and-8x22b-large-language-models-58ab5b2cc8ee?source=collection_archive---------2-----------------------#2024-04-20

在 Google Colab 中运行 7B 和 22B 模型

Dmitrii EliuseevTowards Data Science Dmitrii Eliuseev

·发表于 Towards Data Science ·阅读时间 10 分钟·2024 年 4 月 20 日

--

图片由 Img IX 提供,Unsplash

不久前,所有的 IT 新闻频道都报道了新发布的开源 Mixtral 8x22B 模型,它在 MMLU(大规模多任务语言理解)或 WinoGrande(常识推理)等基准测试中超越了 ChatGPT 3.5。这是开源模型领域的一大成就。学术基准测试自然很有趣,但这个模型在实际操作中表现如何?它的系统要求是什么?与之前的语言模型相比,它真的更好吗?在本文中,我将测试四种不同的模型(7B、8x7B、22B 和 8x22B,带有和不带有“专家混合”架构),我们将一起看到结果。

让我们开始吧!

顺便提一下,我与 Mistral AI 没有任何商业关系,所有的测试都是我个人独立完成的。

稀疏专家混合(SMoE)

在大语言模型(LLM)时代刚开始时,人们就已经知道,通常来说,模型越大,智能越强,知识越丰富,结果也越好。但更大的模型也意味着更高的计算成本。如果一个聊天机器人需要 5 分钟才能回应,没人愿意等下去。“专家混合”(mixture of experts)的直观想法是……

高性能时间序列预测的 KAN 专家混合模型

原文:towardsdatascience.com/mixture-of-kan-experts-for-high-performance-time-series-forecasting-5227e1d2aba2?source=collection_archive---------0-----------------------#2024-09-11

探索 RMoK 模型及其架构,并使用 Python 进行小规模实验。

Marco PeixeiroTowards Data Science Marco Peixeiro

·发布于Towards Data Science ·10 分钟阅读·2024 年 9 月 11 日

--

图片来自Kyaw TunUnsplash

Kolmogorov-Arnold 网络(KAN)的引入为深度学习领域做出了重要贡献,因为它代表了多层感知器(MLP)的替代方案。

MLP 当然是许多深度学习模型的构建模块,包括像 N-BEATS、NHiTS 和 TSMixer 这样最先进的预测方法。

然而,在一个使用 KAN、MLP、NHiTS 和 NBEATS 的预测基准测试中,我们发现 KAN 通常非常慢,并且在各种预测任务中表现持续较差。需要注意的是,该基准测试是在 M3 和 M4 数据集上进行的,这些数据集包含超过 99,000 个独特的时间序列,频率范围从每小时到每年。

最终,当时应用 KAN 进行时间序列预测的结果令人失望,并不是一种推荐的做法。

现在,随着可逆 KAN 混合模型(RMoK)的提出,情况发生了变化,相关内容可以在论文中找到:KAN4TSF: KAN 及基于 KAN 的模型对时间序列预测有效吗?

在本文中,我们首先探讨可逆混合 KAN 模型的架构和内部工作原理…

机器学习初学者应该阅读论文

原文:towardsdatascience.com/ml-beginners-should-read-papers-506a074ffc10?source=collection_archive---------2-----------------------#2024-12-10

这是为什么以及如何做到的

Pascal JanetzkyTowards Data Science Pascal Janetzky

·发表于Towards Data Science ·4 分钟阅读·2024 年 12 月 10 日

--

每天,超过 100 篇新的计算机科学和机器学习论文会被列在arXiv上。尽管这些作品在列出之前不一定经过同行评审,但它仍然是一个巨大的信息宝库。为了了解更多情况,请查看下面的图表,该图表显示了自 2009 年以来每月提交的增长情况,数据来自 arXiv:

每月提交到 arXiv 的计算机科学论文数量。图片由作者提供,数据可以从arXiv公开获取。

做个简单的计算,假设平均每篇论文需要 3 小时才能从头到尾阅读完。按照上述数字,需要 300 小时(或 12 天!)才能读完所有这些论文。而这只是读完一天的论文——第二天,我们又得重新开始,继续阅读类似数量的出版物。显然,这对于专家和初学者来说都是不可行的。

通常,作为机器学习的初学者,你可能会问:我需要阅读论文吗?而且,考虑到有这么多论文,我到底该怎么做呢?这就是原因和方法!

为什么机器学习初学者应该阅读论文

论文就是一场讲座:为了在顶级机器学习会议上被接受,论文的写作必须简洁明了。论文通常包括对主题的介绍、方法部分、结果和总结。总体而言,论文的内容就是对一个非常狭窄的主题进行(浓缩的)讲解。原因如下:

机器学习工程 101:对错误“DataLoader worker (pid(s) xxx) exited unexpectedly”的全面解释

原文:towardsdatascience.com/ml-engineering-101-a-thorough-explanation-of-the-error-dataloader-worker-pid-s-xxx-exited-f3a6a983911e?source=collection_archive---------6-----------------------#2024-06-03

深入探讨 PyTorch DataLoader 与多进程

赵梦柳Towards Data Science 赵梦柳

·发表于 Towards Data Science ·阅读时长 6 分钟·2024 年 6 月 3 日

--

作为日常使用 PyTorch 库的众多用户之一,我相信许多机器学习工程师在训练过程中迟早会遇到“DataLoader worker (pid(s) xxx) exited unexpectedly”这个问题。

这令人沮丧。

当使用 num_workers > 0 调用 DataLoader 时,通常会触发这个错误。许多在线帖子提供了简单的解决方案,比如将 num_workers 设置为 0,这样当前的问题就会消失,但实际上会引发新的问题。

本文将向你展示一些可能有助于解决这个问题的技巧。我将深入探讨 Torch.multiprocessing 模块,并展示一些有用的虚拟内存监控和泄漏防止技术。在极少数情况下,即使没有内存泄漏,torch.multiprocessing 工作进程的异步内存占用和释放仍然可能触发该问题。最终解决方案是优化虚拟内存的使用,理解 torch.multiprocessing 的行为,并在 getitem 方法中进行垃圾回收。

注意:我使用的平台是 Ubuntu 20.04。为了适应其他平台,许多终端命令需要做相应的调整。

图片来源:pxhere.com/en/photo/1379760#google_vignette

暴力解决方案及其缺点

如果你在网上搜索,大多数遇到相同问题的人都会告诉你暴力解决方案;只需将 DataLoader 中的 num_workers 设置为 0,问题就会消失。

如果你的数据集较小,并且能容忍较长的训练时间,这将是最简单的解决方案。然而,根本问题依然存在,如果你有一个非常大的数据集,设置 num_workers=0 将导致非常慢的性能,有时会慢 10 倍。因此,我们必须进一步研究问题并寻找替代解决方案。

监控你的虚拟内存使用情况

当 DataLoader 工作进程退出时,究竟发生了什么?

要捕捉系统中的最后一个错误日志,请在终端中运行以下命令,它将为你提供更详细的错误信息。

dmesg -T

通常,你会看到真实的原因是“内存不足”。但是为什么会出现内存不足的问题呢?具体是什么导致了额外的内存消耗?

当我们在 DataLoader 中设置 num_workers=0 时,只有一个主进程运行训练脚本。只要数据批次能够装入内存,它就会正常运行。

然而,当设置 num_workers > 0 时,情况就不一样了。DataLoader 会启动子进程,并将 prefetch_factornum_workers* 的数据预加载到内存中以加速训练。默认情况下,prefetch_factor = 2。预加载的数据将消耗机器的虚拟内存(但好消息是它不会占用 GPU,因此你不需要缩小批次大小)。所以,我们需要做的第一件事是监控系统的虚拟内存使用情况。

监控虚拟内存使用情况的最简单方法之一是使用 psutil 包,它将监控虚拟内存的使用百分比

import psutil
print(psutil.virtual_memory().percent)

你还可以使用 tracemalloc 包,它将为你提供更详细的信息:

snapshot = tracemalloc.take_snapshot()
top_stats = snapshot.statistics('lineno')
for stat in top_stats[:10]:
    print(stat)

当实际的 RAM 满时,空闲的数据将流入交换空间(因此它是你虚拟内存的一部分)。要检查交换空间,使用以下命令:

free -m

在训练期间临时更改交换空间大小(例如,增加到 16G),在终端中执行以下命令:

swapoff -a
fallocate -l 16G /swapfile
chmod 600 /swapfile
mkswap /swapfile
swapon /swapfile

/dev/shm(或者在某些情况下,/run/shm)是用于存储临时文件的另一种文件系统,应该进行监控。只需运行以下命令,你将看到文件系统中的驱动器列表:

df -h

要临时调整它的大小(例如,增加到 16GB),只需运行:

sudo mount -o remount,size=16G /dev/shm

Torch.multiprocessing 最佳实践

然而,虚拟内存只是问题的一部分。如果在调整交换磁盘后问题仍然存在怎么办?

问题的另一面是 torch.multiprocessing 模块的底层问题。官方网页上有许多最佳实践建议:

## 多进程最佳实践 - PyTorch 2.3 文档

torch.multiprocessing 是 Python 模块的直接替代品。它支持完全相同的操作,但进行了扩展……

pytorch.org

除了这些,还应该考虑另外三种方法,特别是关于内存使用的方面。

第一个问题是共享内存泄漏。泄漏意味着每次子进程运行后,内存没有正确释放,你可以通过监控运行时的虚拟内存使用情况来观察这个现象。内存消耗会不断增加,直到达到“内存不足”的程度。这是典型的内存泄漏现象。

那么,是什么导致了内存泄漏呢?

让我们来看看 DataLoader 类本身:

github.com/pytorch/pytorch/blob/main/torch/utils/data/dataloader.py

查看 DataLoader 的内部结构时,我们会看到,当nums_worker > 0时,调用的是_MultiProcessingDataLoaderIter。在_MultiProcessingDataLoaderIter内部,Torch.multiprocessing 创建了工作队列。Torch.multiprocessing 使用两种不同的内存共享和缓存策略:file_descriptorfile_system。虽然file_system不需要文件描述符缓存,但它容易导致共享内存泄漏。

要检查你的机器使用的是哪种共享策略,只需在脚本中添加以下内容:

torch.multiprocessing.get_sharing_strategy()

要获取系统文件描述符限制(Linux),请在终端运行以下命令:

ulimit -n

要将共享策略切换为file_descriptor

torch.multiprocessing.set_sharing_strategy(‘file_descriptor’)

要统计已打开的文件描述符数量,请运行以下命令:

ls /proc/self/fd | wc -l

只要系统允许,建议使用file_descriptor策略。

第二个问题是多进程工作者启动方法。简而言之,这是关于是否使用 fork 或 spawn 作为工作者启动方法的争论。Fork 是 Linux 中启动多进程的默认方式,可以避免某些文件复制,因此速度更快,但在处理 CUDA 张量和第三方库(如 OpenCV)时,可能会遇到问题。

要使用 spawn 方法,你只需将参数*multiprocessing_context=* “spawn”传递给 DataLoader。

第三,确保数据集对象可序列化/可拾取

有一篇非常不错的文章进一步讨论了进程折叠中的“按需复制”效应:ppwwyyxx.com/blog/2022/Demystify-RAM-Usage-in-Multiprocess-DataLoader/

简而言之,不再推荐__getitem__方法中创建文件名列表并加载它们。可以创建一个 numpy 数组或 pandas 数据框来存储文件名列表,以便进行序列化。如果你熟悉 HuggingFace,使用 CSV/数据框是加载本地数据集的推荐方法:huggingface.co/docs/datasets/v2.19.0/en/package_reference/loading_methods#datasets.load_dataset.example-2

如果你的数据加载器非常慢怎么办?

好的,现在我们对多进程模块有了更好的理解。但这就是故事的结局吗?

听起来真的很疯狂。如果你有一个大且重的数据集(例如,每个数据点 > 5 MB),就有可能遇到上述问题,我将告诉你为什么。秘密就在于多进程工作进程的异步内存释放。

诀窍很简单:黑进 torch 库,在 _MultiProcessingDataLoaderIter 类的数据队列前后添加一行 psutil.virtual_memory().percent

[## pytorch/torch/utils/data/dataloader.py 在 70d8bc2da1da34915ce504614495c8cf19c85df2 ·…

使用强大的 GPU 加速的 Python 中的张量和动态神经网络 - pytorch/torch/utils/data/dataloader.py 在…

github.com](https://github.com/pytorch/pytorch/blob/70d8bc2da1da34915ce504614495c8cf19c85df2/torch/utils/data/dataloader.py?source=post_page-----f3a6a983911e--------------------------------#L1130)

类似这样的:

print(“before clearing”, psutil.virtual_memory().percent)
data = self._data_queue.get(timeout=timeout)
print("after", psutil.virtual_memory().percent)

在我的情况下,我将 DataLoader 的 num_workers 设置为 8,并观察到如下情况:

所以内存不断上涨——但这算是内存泄漏吗?其实不算。问题的根本原因是 dataloader 的工作进程加载速度快于释放速度,创建了 8 个任务,却只释放了 2 个。这就是内存溢出的根本原因。解决方案很简单:只需在你的 getitem 方法的开头添加一个垃圾回收器即可:

import gc
def __getitem__(self, idx):
    gc.collect()

现在你已经做好了准备!

参考文献

ML 变形:通过串联 ML 模型实现优化结果

原文:towardsdatascience.com/ml-metamorphosis-chaining-ml-models-for-optimized-results-d89d952627a9?source=collection_archive---------2-----------------------#2024-10-23

知识蒸馏、模型压缩和规则提取的普遍原理

Vadim ArzamasovTowards Data Science Vadim Arzamasov

·发表于 Towards Data Science ·阅读时长 7 分钟·2024 年 10 月 23 日

--

图 1。此图及其他图像由作者在 recraft.ai 的帮助下创建

机器学习(ML)模型训练通常遵循一个熟悉的流程:首先收集数据,清理并准备数据,然后进行模型拟合。但如果我们能将这一过程进一步推进呢?就像一些昆虫在达到成熟之前会经历剧烈的变化一样,ML 模型也可以以类似的方式进化(参见 Hinton 等人[1])——我将其称为ML 变形。这个过程涉及将不同的模型串联在一起,最终生成的模型比直接从头开始训练的模型质量要好得多。

其工作原理如下:

  • 从一些初步知识开始,数据 1

  • 在这些数据上训练一个机器学习模型,模型 A(例如神经网络)。

  • 使用模型 A生成新的数据,数据 2

  • 最后,使用数据 2 来拟合你的目标模型,模型 B

图 2。ML 变形的示意图

你可能已经熟悉这一概念,知识蒸馏就是用一个较小的神经网络替换较大的神经网络。但 ML 变形更进一步,初始模型(模型 A)和最终模型(模型 B)不必是神经网络。

示例:MNIST 数据集上的 ML 变形

想象一下,你的任务是使用 MNIST 手写数字图像数据集训练一个多类决策树,但只有 1,000 张图像有标签。你可以直接在这个有限的数据上训练决策树,但准确度大约只能达到 0.67。这并不好,对吧?或者,你可以使用机器学习变换方法来提高结果。

但在深入探讨解决方案之前,我们先快速回顾一下支持这种方法的技术和研究。

1. 知识蒸馏(2015)

即使你没有使用知识蒸馏,你也可能见过它的应用。例如,Meta 建议蒸馏其 Llama 3.2 模型,以便将其适配到特定任务[2]。或者看看 DistilBERT——一个蒸馏版的 BERT[3],或者 DMD 框架,它通过蒸馏 Stable Diffusion 来加速图像生成速度,提升了 30 倍[4]。

知识蒸馏的核心是将知识从一个大型、复杂的模型(教师)传递给一个较小、效率更高的模型(学生)。该过程包括创建一个转移集,该集包含原始训练数据和由教师模型伪标注的额外数据(无论是原始数据还是合成数据)。这些伪标签被称为软标签——它们源自教师模型在多个类别上的预测概率。这些软标签提供了比硬标签(简单的类别指示符)更丰富的信息,因为它们反映了教师的信心,并捕捉到类别之间的微妙相似性。例如,它们可能表明一个特定的“1”比“5”更像“7”。

通过在这个丰富的转移集上进行训练,学生模型可以有效地模仿教师模型的表现,同时更加轻便、快速且易于使用。

以这种方式获得的学生模型比仅在原始训练集上训练得到的模型更准确。

2. 模型压缩(2007)

模型压缩[5]通常被视为知识蒸馏的前奏,但两者之间存在重要差异。与知识蒸馏不同,模型压缩似乎并未使用软标签,尽管文献中有一些相关说法[1,6]。我没有找到任何证据表明软标签是过程的一部分。事实上,原始论文中的方法甚至不依赖于人工神经网络(ANNs)作为模型 A。相反,它使用了一组模型——如支持向量机(SVMs)、决策树、随机森林等。

模型压缩通过逼近特征分布p(x)来创建转移集。然后,这个集由模型 A标注,提供条件分布p(y|x)。原始工作中的关键创新是一种名为 MUNGE 的技术,用于逼近p(x)。与知识蒸馏一样,目标是训练一个较小、更高效的模型 B,并保留较大模型 A的性能。

正如知识蒸馏中所示,这种训练方式得到的压缩模型,通常能够超过直接在原始数据上训练的类似模型,因为传递集中嵌入了丰富的信息[5]。

通常,“模型压缩”一词被更广泛地用来指代任何减少模型 A大小的技术[7,8]。这包括像知识蒸馏这样的技术,也包括不依赖传递集的技术,如修剪、量化或神经网络的低秩近似。

3. 规则提取(1995)

当问题不在于计算复杂度或内存,而在于模型决策过程的不透明性时,教育性规则提取提供了一种解决方案[9]。在这种方法中,训练一个更简单、更易解释的模型(模型 B)来复制不透明教师模型(模型 A)的行为,目的是推导出一组人类可读的规则。这个过程通常从将未标记的示例(通常是随机生成的)输入到模型 A开始,模型 A对这些示例进行标记,生成一个传递集。然后,使用这个传递集来训练透明的学生模型。例如,在分类任务中,学生模型可能是一个决策树,它输出如下规则:“如果特征 X1 大于阈值 T1 且特征 X2 小于阈值 T2,则分类为正类”。

教育性规则提取的主要目标是紧密模仿教师模型的行为,以保真度为衡量标准——即学生模型与教师模型之间的准确性,作为主要的质量衡量标准。

有趣的是,研究表明,通过这种方法创建的透明模型,有时能够比直接在用于构建模型 A的原始数据上训练的类似模型达到更高的准确率[10,11]。

教育性规则提取属于被称为“全局”模型解释方法的更广泛技术家族,这其中还包括分解性和折衷规则提取。更多详情请参见[12]。

4. 仿真作为模型 A

模型 A不一定是一个机器学习模型——它也可以是一个经济或物理过程的计算机仿真,例如模拟飞机机翼周围的气流。在这种情况下,数据 1由定义该过程的微分方程或差分方程组成。对于任何给定的输入,仿真通过数值解这些方程来做出预测。然而,当这些仿真变得计算开销较大时,就需要一种更快速的替代方案:一个代理模型(模型 B),它可以加速诸如优化之类的任务[13]。当目标是识别输入空间中的重要区域,例如系统稳定性的区域时,开发一个可解释的模型 B,这一过程称为情境发现[14]。为了生成用于代理建模和情境发现的传递集(数据 2),模型 A会在一个多样化的输入集上运行。

回到我们的 MNIST 示例

在 TDS 上的一篇深刻的文章中,Niklas von Moers展示了半监督学习如何提高卷积神经网络(CNN)在相同输入数据上的表现。这个结果适用于 ML 变换流水线的第一阶段,在该阶段,Model A是一个经过训练的 CNN 分类器。转移集Data 2包含原本标注的 1,000 个训练样本以及约 55,000 个由Model A高置信度预测的伪标注样本。接下来,我在Data 2上训练我们的目标Model B,一个决策树分类器,并达到了 0.86 的准确率——远高于仅在Data 1的标注部分训练时的 0.67。这意味着,将决策树与 CNN 解决方案链式连接,将决策树的错误率从 0.33 降低到了 0.14。相当大的提升,是吧?

查看完整的实验代码,请访问 GitHub 仓库。

结论

总结来说,ML 变换并非总是必要的——特别是当你的唯一关注点是准确性,并且不需要可解释性、更快的推理或减少存储需求时。但在其他情况下,链式模型可能会比直接在原始数据上训练目标模型产生显著更好的结果。

图 2:为方便参考,图示再次呈现

对于分类任务,过程包括:

  • Data 1:原始的,完全或部分标注的数据。

  • Model A:在Data 1上训练的模型。

  • Data 2:包括伪标注数据的转移集。

  • Model B:最终模型,旨在满足额外要求,如可解释性或效率。

那么,为什么我们不总是使用 ML 变换呢?挑战通常在于找到合适的转移集,Data 2 [9]。但这是另一个话题。

参考文献

[1] Hinton, Geoffrey. “蒸馏神经网络中的知识。” arXiv 预印本 arXiv:1503.02531 (2015 年)。

[2] 介绍 Llama 3.2

[3] Sanh, Victor 等人. “DistilBERT:BERT 的蒸馏版:更小、更快、更便宜、更轻量。” arXiv 预印本 arXiv:1910.01108 (2019 年)。

[4] Yin, Tianwei 等人. “一步扩散与分布匹配蒸馏。” IEEE/CVF 计算机视觉与模式识别会议论文集。2024 年。

[5] Buciluǎ, Cristian, Rich Caruana, 和 Alexandru Niculescu-Mizil. “模型压缩第 12 届 ACM SIGKDD 国际会议《知识发现与数据挖掘》论文集,2006 年。

[6] 知识蒸馏,维基百科

[7] 太空深度学习模型压缩技术概述,发表于 Medium

[8] 使用未标注的问答数据集蒸馏 BERT,发表于 Towards Data Science

[9] Arzamasov, Vadim, Benjamin Jochum, 和 Klemens Böhm. “教育规则提取以学习可解释模型 — 一项实证研究.” arXiv 预印本 arXiv:2112.13285 (2021).

[10] Domingos, Pedro. “通过多模型从示例中获取知识MACHINE LEARNING-INTERNATIONAL WORKSHOP THEN CONFERENCE-. MORGAN KAUFMANN PUBLISHERS, INC., 1997.

[11] De Fortuny, Enric Junque, 和 David Martens. “基于主动学习的教学规则提取.” IEEE 神经网络与学习系统交易 26.11 (2015): 2664–2677.

[12] Guidotti, Riccardo, 等. “解释黑盒模型的方法调查.” ACM 计算机调查 (CSUR) 51.5 (2018): 1–42.

[13] 代理模型,维基百科

[14] Python 中的情景发现Water Programming上的博客文章

[15] 让模型从自身学习,发表于 Towards Data Science

将简单线性回归剖析到最基础的层面

原文:towardsdatascience.com/mlbasics-simple-linear-regression-machine-learning-algorithm-predict-programming-8d83cac9873a?source=collection_archive---------2-----------------------#2024-01-14

MLBasics #1:用简单线性回归揭开机器学习算法的神秘面纱

Josep FerrerTowards Data Science Josep Ferrer

·发布于 Towards Data Science ·阅读时长 8 分钟·2024 年 1 月 14 日

--

图片由作者提供。ML Basics。简单线性回归。

在数据和计算机程序的世界里,机器学习的概念可能听起来像一个难以破解的难题,充满了复杂的数学和复杂的思想。

这就是为什么今天我想放慢脚步,看看使这一切运作的基本内容。我将开始发布一系列新的文章,名为 MLBasics。

我们将回顾那些简单却至关重要的模型,它们是机器学习的基本组成部分。可以把它看作是从一个大拼图中开始,从最简单的部分入手。我们回归到简单的内容,在这里很容易理解发生了什么。

所以,跟着我们一起走,看看我们如何将其拆解,并使一切变得清晰明了。

让我们一步一步地一起深入了解简单线性回归吧!👇🏻🤓

#1. 简单线性回归简介

预测分析的领域广阔,但在其核心是线性回归——最简单的方法,用来理解数据趋势。

虽然它扩展到多个变量时可能会让人觉得有些艰难,但今天我们的重点将专注于简单线性回归。

MLOps — MLflow Pipelines 的温和入门

原文:towardsdatascience.com/mlops-a-gentle-introduction-to-mlflow-pipelines-c7bcec88a6ec?source=collection_archive---------0-----------------------#2024-03-13

图片由Sean Robertson提供,来自Unsplash

使用 MLflow 管理您的端到端机器学习生命周期

Marcello PolitiTowards Data Science Marcello Politi

·发表于Towards Data Science ·8 分钟阅读·2024 年 3 月 13 日

--

介绍

各种统计数据显示,50% 到 90% 的模型未能成功投入生产。这通常是由于未能有效地组织工作流程。学术界(或 Kaggle 上)获得的技能,往往不足以支撑一个可以被成千上万用户使用的机器学习系统。

在寻找机器学习行业工作时,最为抢手的技能之一就是能够使用能协调复杂流程的工具,如 MLflow。

在本文中,我们将了解如何将一个项目结构化为多个步骤,并以有序的方式管理所有步骤。

我在Deepnote上运行本文的脚本:一个基于云的笔记本,非常适合协作数据科学项目和原型设计。

什么是 MLflow?

MLflow 是一个由Databricks开发的开源平台,用于机器学习的端到端生命周期管理。

MLflow 提供多种功能,如监控训练中的模型,使用工件存储等……

MLOps — 使用 PyTest 进行数据验证

原文:towardsdatascience.com/mlops-data-validation-with-pytest-749641874871?source=collection_archive---------5-----------------------#2024-06-11

图片由Michael Dziedzic提供,来自Unsplash

运行确定性和非确定性测试以验证你的数据集

Marcello PolitiTowards Data Science Marcello Politi

·发表于Towards Data Science ·9 分钟阅读·2024 年 6 月 11 日

--

介绍

在 MLOps 管道中,我们尽量自动化尽可能多的步骤,目标是最小化程序员直接干预可能导致的错误数量,同时也要关注数据集验证。我相信大家都熟悉机器学习的第 1 条规则:垃圾进,垃圾出。无论我们开发的模型有多么复杂,如果数据集没有得到妥善处理,我们很有可能会得到糟糕的结果。

在本文中,我们将看到如何使用PyTest对数据集进行自动化验证。

我使用Deepnote运行本文中的脚本:这是一个基于云的笔记本,非常适合协作的数据科学项目和原型开发。

关于 ETL

初次接触机器学习的人通常需要解决一些像Kaggle上找到的挑战。在这些挑战中,我们几乎总是有一个静态数据集,它不会随时间变化。然而,在现实世界中,情况并非完全如此。

在处理实际的机器学习产品时,数据可能会不断变化。由此产生的...

MMM:用于市场营销组合建模和广告支出回报率(ROAS)的贝叶斯框架

原文:towardsdatascience.com/mmm-bayesian-framework-for-marketing-mix-modeling-and-roas-ccade4005bd5?source=collection_archive---------1-----------------------#2024-06-06

使用 PyMC 的贝叶斯框架来建模媒体渠道表现、广告支出回报率(ROAS)和预算分配

Luís RoqueTowards Data Science Luís Roque

·发布于 Towards Data Science ·18 分钟阅读·2024 年 6 月 6 日

--

这篇文章由 Rafael Guedes 共同撰写。

介绍

可扩展的互联网企业依赖营销来推动增长。当然,不仅如此,在一定规模下,极少数公司能够承受不在客户获取方面做到极其高效的成本。两大热门话题,企业正在大量投资以将人工智能(AI)能力引入营销领域,分别是媒体组合建模(MMM)和客户生命周期价值(LTV)预测。两者的目标都是提高企业在营销上的投资回报。本文将介绍 MMM 是什么以及应用 MMM 的最佳实践。

MMM 是一种技术,允许营销团队衡量他们的投资影响及其如何促进转化。随着过去几年可用的广告平台激增,这项任务的复杂性迅速增加。这一现象将潜在客户分散到了不同的媒体渠道,这些渠道可以分为离线或在线两类。传统的离线渠道与数字支持脱节,可能包括报纸、广播、电视广告、优惠券以及展会上的摊位。在线渠道爆炸性增长,企业将它们结合使用…

使用 FastAPI、Azure 和 Docker 进行模型部署

原文:towardsdatascience.com/model-deployment-with-fastapi-azure-and-docker-10e5cfbc1f4f?source=collection_archive---------2-----------------------#2024-09-28

使用 FastAPI 服务机器学习模型的完整指南

Sabrine BendimeradTowards Data Science Sabrine Bendimerad

·发表于Towards Data Science ·10 分钟阅读·2024 年 9 月 28 日

--

pixabay.com

欢迎来到我MLOps 系列的第三篇文章。在第一篇文章中,我们探讨了 Docker 及其如何简化应用程序打包。在第二篇文章中,我们使用MLflowAzureDocker来管理机器学习模型。现在,在这一第三部分,我们将通过构建一个FastAPI应用程序,将我们之前存储的模型部署到 Azure 上,从而将所有内容整合起来。这将允许我们创建一个全球可访问的预测服务!

什么是 API?

API就像一座桥梁。当你与 Python 中的库进行交互时,你就是在使用它的 API。它是一个应用程序的公开部分,你可以与之交互,而其背后的所有内容则是隐藏的。

API 通常用于与 Web 应用程序进行通信,它们提供一组返回数据的 URL(你发送带有一些参数的请求,并收到响应)。通常,数据以像 JSON 或 XML 这样的格式返回,这些格式易于解析。这与返回 HTML 的网页不同,HTML 包括渲染页面所需的信息。通过 API,你只会得到原始数据。

有些 API 是公开的,而其他的是私有的。在构建 API 时,你决定分享哪些数据,以及如何分享……

模型漂移介绍与概念

原文:towardsdatascience.com/model-drift-introduction-and-concepts-e32c5305da2a?source=collection_archive---------3-----------------------#2024-06-22

了解机器学习模型漂移背后的一些概念,并理解为什么 MLOps 在今天的世界中如此重要。

Ivo BernardoTowards Data Science Ivo Bernardo

·发表于 Towards Data Science ·6 分钟阅读·2024 年 6 月 22 日

--

模型会发生变化,因为世界在变化——图片来源:arptrastogi 通过 Unsplash.com

税收、死亡和模型漂移是生活中唯一的三大确定性。好吧,我可能在这句格言中加入了最后一个,但事实是所有模型都会遭遇衰退。

开发机器学习模型后,你总会看到相同的模式发展:

  • 在开发过程中,模型在测试集上的表现是预期的。

  • 模型在投入生产后表现不同(通常,表现稍差)。

  • 模型的性能随着时间的推移而下降。

几年后,你的模型表现很可能比最初开发时差得多。这可能由多种原因引起,但根本原因是世界在变化。

当世界变化时,我们用来表示现实信息的数据也会发生变化。 潜在的数据分布发生偏移,这必然会影响我们的机器学习模型如何学习和表现。

在这篇博客文章中,我们将探讨一些情况示例,其中世界的潜在变化如何影响你的模型……

模型评估与任务评估

原文:towardsdatascience.com/model-evaluations-versus-task-evaluations-5bc742054957?source=collection_archive---------7-----------------------#2024-03-26

由作者使用 Dall-E 3 创建的图像

理解 LLM 应用中的差异

Aparna DhinakaranTowards Data Science Aparna Dhinakaran

·发表于Towards Data Science ·9 分钟阅读·2024 年 3 月 26 日

--

想象一下飞机。你首先想到的是什么?现在再想象一架波音 737 和一架V-22 鱼鹰。这两者都是旨在运输货物和人员的飞机,但它们服务的目的不同——一种更为通用(商业航班和货运),另一种非常具体(为特种作战部队执行渗透、撤离和补给任务)。它们看起来完全不同,因为它们是为不同的活动而设计的。

随着 LLM 的兴起,我们见证了第一批真正的通用机器学习模型。它们的通用性在许多方面帮助了我们:

  • 同一个工程团队现在可以进行情感分析和结构化数据提取

  • 许多领域的从业者可以共享知识,从而使整个行业能够相互受益于彼此的经验

  • 有许多行业和工作领域,其中相同的经验是有用的

但正如我们在飞机中看到的,通用性需要与在特定任务上出色表现截然不同的评估方法,归根结底,商业价值通常来自于解决特定的问题。

这是模型评估与任务评估差异的一个很好的类比。模型评估侧重于总体的综合评估,而任务评估则侧重于评估特定任务的表现。

不止一个 LLM 评估

LLM 评估 这个术语常常被广泛使用。OpenAI 早期发布了一些工具来进行 LLM 评估,例如。大多数从业者更关注 LLM 任务评估,但这一区分并不总是很清晰。

有什么区别?

模型评估关注的是模型的“整体健身情况”。它在各种任务上的表现如何?

任务评估则专门设计用来检查模型是否适合你的特定应用。

一般锻炼并且身体素质较好的人,在真实比赛中可能会在职业相扑选手面前表现不佳,而模型评估无法与任务评估在评估你特定需求的能力上相提并论。

模型评估

模型评估专门用于构建和微调通用模型。它们基于你给模型提出的一组问题以及你用来评分的地面真实答案。可以将其想象成参加 SAT 考试。

虽然模型评估中的每个问题都不同,但通常有一个共同的测试领域。每个指标都有一个特定的目标主题或技能。例如,HellaSwag 的表现已经成为衡量 LLM 质量的流行方式。

HellaSwag 数据集包含了一系列上下文和多项选择题,每个问题都有多个可能的完成选项。只有一个选项是合乎逻辑且合理的,其他选项虽然看似合理,但其实是错误的。这些完成项旨在对 AI 模型提出挑战,不仅要求语言理解能力,还需要常识推理才能选择正确的选项。

这是一个例子:

一盘土豆被放入烤箱并取出。一大盘蛋糕被翻过来并放到柜台上。一大盘肉

A. 被放到烤土豆上

B. 土豆和腌菜被放入烤箱中

C. 被准备好后,由助手在完成时从烤箱中取出。

另一个例子是 MMLU。MMLU 涵盖了多个学科的任务,包括科学、文学、历史、社会科学、数学以及法律和医学等专业领域。这些学科的多样性旨在模拟人类学习者所需的知识和理解的广度,使其成为测试模型处理多面语言理解挑战能力的好方法。

这里有一些例子——你能解答它们吗?

在以下哪些热力学过程中,理想气体的内能增加等于加热到气体的热量?

A. 恒温

B. 恒体积

C. 恒压

D. 绝热过程

图片来源:作者

Hugging Face 排行榜可能是获取此类模型评估的最知名平台。排行榜跟踪开源的大型语言模型,并记录许多模型评估指标。这通常是一个很好的起点,用来理解开源 LLM 在不同任务表现上的差异。

多模态模型需要更多的评估。Gemini 论文展示了多模态引入了许多其他基准,比如 VQAv2,它测试理解和整合视觉信息的能力。这些信息不仅仅是简单的物体识别,而是对行动和物体之间关系的解读。

同样,针对音频和视频信息以及如何跨模态整合的指标也存在。

这些测试的目标是区分两个模型或同一个模型的两个不同快照。选择一个适合你应用的模型很重要,但这是你做的事,一般情况下只是一次性操作或非常不频繁的操作。

作者提供的图片

任务评估

更常见的问题是通过任务评估来解决的。基于任务的评估目标是分析使用 LLM 作为评判者的模型表现。

  • 你的检索系统是否获取了正确的数据?

  • 你的回答中有幻觉吗?

  • 系统是否用相关的答案回答了重要问题?

有些人可能对 LLM 评估其他 LLM 感到有些不确定,但我们每天都有人工评估其他人。

模型评估和任务评估的真正区别在于:在模型评估中,我们会提出许多不同的问题,而在任务评估中,问题保持不变,只有数据会变化。例如,假设你在操作一个聊天机器人。你可以在数百次客户互动中使用你的任务评估,并问它:“这里有幻觉吗?”这个问题在所有对话中始终相同。

作者提供的图片

有几个库旨在帮助实践者构建这些评估:RagasPhoenix(完全披露:作者领导了开发 Phoenix 的团队),OpenAILlamaIndex

它们是如何工作的?

任务评估整体上评估应用程序的每个输出的表现。我们来看看构建一个评估任务需要哪些内容。

建立基准

基础在于建立一个稳健的基准。这从创建一个能够准确反映 LLM 将遇到的场景的黄金数据集开始。该数据集应包含地面真实标签——通常来源于细致的人工审核——作为对比标准。别担心,通常你可以用几十到几百个示例来完成。选择合适的 LLM 进行评估也至关重要。虽然它可能与应用程序的主要 LLM 不同,但应该与成本效益和准确性目标保持一致。

制定评估模板

任务评估过程的核心是评估模板。该模板应清晰定义输入(例如,用户查询和文档)、评估问题(例如,文档与查询的相关性)和预期的输出格式(二元或多类别相关性)。根据应用程序的具体需求,可能需要调整模板,以确保它能够准确评估 LLM 在黄金数据集上的表现。

这里是一个用来评估问答任务的模板示例。

You are given a question, an answer and reference text. You must determine whether the given answer correctly answers the question based on the reference text. Here is the data:
  [BEGIN DATA]
  ************
  [QUESTION]: {input}
  ************
  [REFERENCE]: {reference}
  ************
  [ANSWER]: {output}
  [END DATA]
Your response should be a single word, either "correct" or "incorrect", and should not contain any text or characters aside from that word.
"correct" means that the question is correctly and fully answered by the answer. 
"incorrect" means that the question is not correctly or only partially answered by the answer. 

度量标准与迭代

在你的黄金数据集上运行评估,可以生成关键度量指标,如准确率、精确度、召回率和 F1 分数。这些指标为评估模板的有效性提供了洞察,并突出改进的领域。迭代至关重要;根据这些度量标准精炼模板,确保评估过程与应用目标保持一致,同时避免对黄金数据集的过拟合。

在任务评估中,仅依赖总体准确率是不够的,因为我们总是会遇到显著的类别不平衡。精确度和召回率提供了更全面的视角,强调准确识别相关和不相关结果的重要性。平衡的度量方法确保评估对提升 LLM 应用有实际贡献。

LLM 评估的应用

一旦评估框架就绪,下一步是将这些评估直接应用到你的 LLM 应用中。这涉及将评估过程集成到应用程序的工作流程中,以便实时评估 LLM 对用户输入的响应。这个持续的反馈循环对于保持和提高应用的相关性和准确性至关重要。

系统生命周期中的评估

有效的任务评估不仅限于单一阶段,而是贯穿整个大语言模型(LLM)系统的生命周期。从生产前的基准测试和评估,到生产中的持续性能评估, LLM 评估 确保系统始终能够响应用户需求。

示例:模型是否出现幻觉?

让我们更详细地看一个幻觉示例。

作者示例

由于幻觉是大多数从业者面临的普遍问题,现有一些基准数据集可供使用。这是一个很好的起步,但你通常需要在公司内部拥有一个定制的数据集。

接下来的重要步骤是开发提示模板。这里同样一个好的库可以帮助你入门。我们之前看到了一个示例提示模板,这里我们看到另一个专门针对幻觉的模板。你可能需要根据你的需求进行调整。

In this task, you will be presented with a query, a reference text and an answer. The answer is
generated to the question based on the reference text. The answer may contain false information, you
must use the reference text to determine if the answer to the question contains false information,
if the answer is a hallucination of facts. Your objective is to determine whether the reference text
contains factual information and is not a hallucination. A 'hallucination' in this context refers to
an answer that is not based on the reference text or assumes information that is not available in
the reference text. Your response should be a single word: either "factual" or "hallucinated", and
it should not include any other text or characters. "hallucinated" indicates that the answer
provides factually inaccurate information to the query based on the reference text. "factual"
indicates that the answer to the question is correct relative to the reference text, and does not
contain made up information. Please read the query and reference text carefully before determining
your response.

    [BEGIN DATA]
    ************
    [Query]: {input}
    ************
    [Reference text]: {reference}
    ************
    [Answer]: {output}
    ************
    [END DATA]

    Is the answer above factual or hallucinated based on the query and reference text?

Your response should be a single word: either "factual" or "hallucinated", and it should not include any other text or characters. 
"hallucinated" indicates that the answer provides factually inaccurate information to the query based on the reference text.
"factual" indicates that the answer to the question is correct relative to the reference text, and does not contain made up information.
Please read the query and reference text carefully before determining your response. 

现在你已经准备好将黄金数据集中的查询传递给你的评估 LLM,并让它标记出幻觉。当你查看结果时,记住应该存在类别不平衡。你需要跟踪精确度和召回率,而不是整体准确率。

构建混淆矩阵并将其可视化非常有用。当你有了这样的图表时,你可以对你的 LLM 性能感到放心。如果性能不尽如人意,你可以随时优化提示模板。

评估任务评估性能的示例,以便用户能够建立对其评估的信心

在评估构建完成后,你现在拥有了一个强大的工具,可以用已知的精确度和召回率标记所有数据。你可以使用它来跟踪你系统中的幻觉,无论是在开发阶段还是生产阶段。

区别总结

让我们总结一下任务评估和模型评估之间的区别。

表格由作者提供

要点总结

最终,模型评估和任务评估在构建功能性 LLM 系统时都非常重要。理解何时以及如何应用每种评估方法是很关键的。对于大多数从业者来说,大部分时间将花费在任务评估上,这为系统在特定任务上的性能提供了衡量标准。

使用信用卡欺诈数据的模型可解释性

原文:towardsdatascience.com/model-interpretability-using-credit-card-fraud-data-f219ff7ec89d?source=collection_archive---------4-----------------------#2024-06-12

为什么模型可解释性很重要

Danila MorozovskiiTowards Data Science Danila Morozovskii

·发布于 Towards Data Science ·17 分钟阅读·2024 年 6 月 12 日

--

最近,我偶然发现了一本在线书籍,书中描述了可用于机器学习模型可解释性的不同工具(christophm.github.io/interpretable-ml-book/)。机器学习模型不应该是黑箱,并且可以解释的这一想法让我非常着迷,我决定深入研究这个话题。之前,当我开始进行一个新的机器学习项目时,我会遵循相同的程序:识别问题、熟悉数据集、特征工程、选择模型、训练/测试和超参数调优,以及结果分析。然而,我没有意识到自己漏掉了最关键的一步:模型可解释性。

什么是模型可解释性?

模型可解释性是解释黑箱(模型)如何工作以及如何做出预测的过程。让我们假设一个情况,一个人申请了信用贷款并被拒绝,因为模型做出了负面预测。任何人都会想知道为什么被拒绝,以及他们可能需要改变什么,才能使决定变得积极,而银行员工只能指着机器学习模型说“它是这么说的!”。这种情况并不好,并且会损害银行的声誉,因为看起来银行对其产品没有控制权。如果银行员工能够向客户解释哪些特定特征在做出预测时发挥了作用,那就会好得多。

使用 MLflow、Azure 和 Docker 进行模型管理

原文:towardsdatascience.com/model-management-with-mlflow-azure-and-docker-2920b51a5bdd?source=collection_archive---------2-----------------------#2024-09-17

跟踪实验和管理模型的指南

Sabrine BendimeradTowards Data Science Sabrine Bendimerad

·发表于 Towards Data Science ·阅读时长 10 分钟·2024 年 9 月 17 日

--

pixabay.com

在第一篇文章中,我们探索了 Docker 将应用程序及其依赖项打包成可移植容器的强大能力,确保在各种环境中保持一致性。

在此基础上,本文介绍了MLflow,这是一个在机器学习工作流中用于实验追踪和模型管理的重要工具。我们将演示如何在 Docker 容器内部署和使用 MLflow,以确保可移植性并避免与依赖项相关的问题。容器化的 MLflow 服务器将部署在Azure上,以提供更好的可扩展性、远程访问,且更重要的是团队协作。

什么是 MLflow

MLflow 是一个开源平台,简化了机器学习生命周期的管理,从实验追踪到模型部署。它提供了一个稳定的框架,用于记录实验、管理代码和追踪模型版本,确保你的工作流程在团队间是可复现的且井然有序。

MLflow 可以集成到你机器学习管道的各个阶段。它提供了四个主要组件:

  • MLflow Tracking:这是最广泛使用的功能,允许你记录和查询实验。它追踪有用的细节……

模型选择:类别平衡 第一部分

原文:towardsdatascience.com/model-selection-a-guide-to-class-balancing-part-i-14b17003186f?source=collection_archive---------7-----------------------#2024-10-10

关于匿名化数据类别平衡的教程。

Leo Anello 💡Towards Data Science Leo Anello 💡

·发表于数据科学前沿 ·阅读时间 22 分钟·2024 年 10 月 10 日

--

图片由Dave Lowe提供,来自Unsplash

我将为你带来一个机器学习模型选择项目,涉及匿名化数据的多变量分析

这是一个全面的项目,我们将从头到尾进行讲解——从定义业务问题到模型部署(尽管部署部分我们会留到其他时间)。

这个项目将有两部分完整的教程,我希望带你了解一系列技术,其中包括处理匿名化数据的复杂性——这是由于数据隐私问题,越来越多地出现在职场中的一种情况。

[## GitHub - Anello92/Model-Selection-and-Class-Balancing

本仓库包含一个关于机器学习模型选择的全面教程,专注于多变量分析...

github.com](https://github.com/Anello92/Machine-Learning-Model-Selection-and-Class-Balancing?source=post_page-----14b17003186f--------------------------------)

那么,处理这种类型的数据的最大挑战是什么呢?就是你没有任何关于每个变量所代表的内容的信息。

现在,这可有点棘手,不是吗?你会收到数据,在不知道每个变量代表什么的情况下,你需要基于这些数据开发一个机器学习模型

我们还将借此机会深入探讨模型选择。哪种机器...

构建 5 个机器学习模型 第二部分

原文:towardsdatascience.com/model-selection-building-5-machine-learning-models-part-ii-3be49fb0dc61?source=collection_archive---------8-----------------------#2024-10-10

构建、比较和优化机器学习模型。

Leo Anello 💡Towards Data Science Leo Anello 💡

·发表于Towards Data Science ·阅读时间 38 分钟·2024 年 10 月 10 日

--

模型选择

现在我们进入项目的第二部分——机器学习模型选择与多变量分析中的匿名化数据

这第二部分是魅力所在——预测建模机器学习。每个人都迫不及待地想直接进入构建机器学习模型的阶段。我理解这一点,我也感到同样的兴奋,因为我热爱这个阶段。

但在我们深入之前,我们必须经过数据处理——这正是我们在上一篇教程中涵盖的内容。

## 模型选择:类平衡指南(第一部分)

关于机器学习模型选择中的匿名数据类平衡的综合教程

towardsdatascience.com

安装和加载软件包

我们首先安装XGBoost包,这是参与机器学习竞赛的 Kaggle 平台用户的最爱之一。

# This package does not come with Anaconda and needs to be installed
!pip install -q xgboost

这个软件包不包括Anaconda,因此你需要单独安装它。要安装它,我们使用命令……

模型验证技术解析:带有代码示例的可视化指南

原文:towardsdatascience.com/model-validation-techniques-explained-a-visual-guide-with-code-examples-eb13bbdc8f88?source=collection_archive---------1-----------------------#2024-11-30

模型评估与优化

12 种必须了解的机器学习验证方法

Samy BaladramTowards Data Science Samy Baladram

·发表于Towards Data Science ·阅读时长 26 分钟·2024 年 11 月 30 日

--

每天,机器都会做出数百万个预测——从检测照片中的物体到帮助医生发现疾病。但在相信这些预测之前,我们需要知道它们是否准确。毕竟,没有人愿意使用一个大多数时候都错误的机器!

这时,验证就显得尤为重要。验证方法测试机器的预测结果,以衡量其可靠性。虽然这听起来很简单,但实际上存在多种验证方法,每种方法都是为了应对机器学习中的特定挑战而设计的。

在这里,我将这些验证技术——全部 12 种——以树状结构组织,展示它们如何从基本概念发展成更为专业的技术。当然,我们将使用清晰的可视化图像和一致的数据集,展示每种方法的不同之处以及为什么选择方法至关重要。

所有可视化图像:作者使用 Canva Pro 创建。已优化为移动端显示;在桌面端可能会显得过大。

什么是模型验证?

模型验证是测试机器学习模型在未见过或未在训练中使用过的数据上表现如何的过程。基本上,我们使用现有数据来检查模型的表现,而不是使用新的数据。这帮助我们在实际使用模型之前识别问题。

有多种验证方法,每种方法都有其特定的优势,并且解决不同的验证挑战:

  1. 不同的验证方法可能会产生不同的结果,因此选择正确的方法很重要。

  2. 一些验证技术在特定类型的数据和模型中效果更佳。

  3. 使用不正确的验证方法可能会导致关于模型真实表现的误导性结果。

这里有一棵树形图,展示了这些验证方法之间的关系:

这棵树形图展示了哪些验证方法相互关联。

接下来,我们将更仔细地研究每种验证方法,展示它们是如何工作的。为了更容易理解,我们将通过清晰的示例,展示这些方法如何在实际数据中运作。

📊 📈 我们的运行示例

我们将始终使用相同的示例,帮助你理解每种验证方法。虽然这个数据集可能不适合某些验证方法,但为了教学目的,使用这个示例使得比较不同方法并观察每种方法如何工作的过程更加容易。

📊 高尔夫游戏数据集

我们将使用这个数据集,它根据天气条件预测某人是否会打高尔夫。

列:‘Overcast(独热编码为 3 列)’,‘Temperature’(以华氏度表示),‘Humidity’(百分比),‘Windy’(是/否)和‘Play’(是/否,目标特征)

import pandas as pd
import numpy as np

# Load the dataset
dataset_dict = {
    'Outlook': ['sunny', 'sunny', 'overcast', 'rainy', 'rainy', 'rainy', 'overcast', 
                'sunny', 'sunny', 'rainy', 'sunny', 'overcast', 'overcast', 'rainy',
                'sunny', 'overcast', 'rainy', 'sunny', 'sunny', 'rainy', 'overcast',
                'rainy', 'sunny', 'overcast', 'sunny', 'overcast', 'rainy', 'overcast'],
    'Temperature': [85.0, 80.0, 83.0, 70.0, 68.0, 65.0, 64.0, 72.0, 69.0, 75.0, 75.0,
                   72.0, 81.0, 71.0, 81.0, 74.0, 76.0, 78.0, 82.0, 67.0, 85.0, 73.0,
                   88.0, 77.0, 79.0, 80.0, 66.0, 84.0],
    'Humidity': [85.0, 90.0, 78.0, 96.0, 80.0, 70.0, 65.0, 95.0, 70.0, 80.0, 70.0,
                 90.0, 75.0, 80.0, 88.0, 92.0, 85.0, 75.0, 92.0, 90.0, 85.0, 88.0,
                 65.0, 70.0, 60.0, 95.0, 70.0, 78.0],
    'Wind': [False, True, False, False, False, True, True, False, False, False, True,
             True, False, True, True, False, False, True, False, True, True, False,
             True, False, False, True, False, False],
    'Play': ['No', 'No', 'Yes', 'Yes', 'Yes', 'No', 'Yes', 'No', 'Yes', 'Yes', 'Yes',
             'Yes', 'Yes', 'No', 'No', 'Yes', 'Yes', 'No', 'No', 'No', 'Yes', 'Yes',
             'Yes', 'Yes', 'Yes', 'Yes', 'No', 'Yes']
}

df = pd.DataFrame(dataset_dict)

# Data preprocessing
df = pd.DataFrame(dataset_dict)
df = pd.get_dummies(df, columns=['Outlook'], prefix='', prefix_sep='', dtype=int)
df['Wind'] = df['Wind'].astype(int)

# Set the label
X, y = df.drop('Play', axis=1), df['Play']

📈 我们的模型选择

我们将在所有测试中使用决策树分类器。如果你不熟悉它,可以参考以下文章:

## 决策树分类器解析:附带代码示例的视觉指南(面向初学者)

对我们最喜欢的倒立树的全新看法

[towardsdatascience.com

我们选择这个模型是因为我们可以很容易地将结果模型绘制为树形结构,每个分支显示不同的决策。为了简化操作并专注于如何测试模型,我们将使用默认的scikit-learn参数,并设置固定的random_state

让我们明确这两个术语:决策树分类器是我们的学习算法——它是找到数据中模式的方法。当我们将数据输入该算法时,它会创建一个模型(在这种情况下,是一棵显示不同决策的树)。这个模型就是我们实际用来进行预测的模型。

from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

dt = DecisionTreeClassifier(random_state=42)

每次我们以不同的方式拆分数据进行验证时,都会得到不同的模型和不同的决策规则。一旦我们的验证表明算法可靠地工作,我们将使用所有数据创建一个最终模型。这个最终模型就是我们实际用来预测某人是否会打高尔夫的模型。

设置好这一切后,我们现在可以集中精力了解每种验证方法的工作原理,以及它如何帮助我们根据天气条件做出更好的高尔夫球预测。我们将逐一检查每种验证方法。

保留法

保留法是检验我们模型效果的最基础方法。在这些方法中,我们基本上将一部分数据专门用于测试。

训练-测试拆分

这种方法很简单:我们将数据分成两部分。我们使用一部分来训练模型,另一部分来测试模型。在分割数据之前,我们会随机打乱数据顺序,以确保原始数据的顺序不会影响结果。

训练集和测试集的大小取决于我们的总数据集大小,通常用它们的比例来表示。为了确定它们的大小,您可以遵循以下指导原则:

  • 对于小型数据集(大约 1,000–10,000 个样本),使用 80:20 的比例。

  • 对于中等规模的数据集(大约 10,000–100,000 个样本),使用 70:30 的比例。

  • 大型数据集(超过 100,000 个样本),使用 90:10 的比例。

from sklearn.model_selection import train_test_split

### Simple Train-Test Split ###
# Split data
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Train and evaluate
dt.fit(X_train, y_train)
test_accuracy = dt.score(X_test, y_test)

# Plot
plt.figure(figsize=(5, 5), dpi=300)
plot_tree(dt, feature_names=X.columns, filled=True, rounded=True)
plt.title(f'Train-Test Split (Test Accuracy: {test_accuracy:.3f})')
plt.tight_layout()

这种方法很容易使用,但也有一些局限性 —— 结果可能会因为我们如何随机分割数据而有很大变化。这就是为什么我们总是需要尝试不同的random_state来确保结果的一致性。此外,如果我们起初的数据不多,可能没有足够的数据来充分训练或测试我们的模型。

训练-验证-测试拆分

这种方法将数据分为三部分。中间部分,称为验证数据,用来调整模型的参数,我们的目标是尽量减少该部分的误差。

由于在调整过程中会多次考虑验证结果,我们的模型可能会在验证数据上表现得太好(这正是我们想要的)。这就是我们为什么要设立单独的测试集的原因。我们只在最后一次测试它 —— 它能真实地反映出我们的模型效果如何。

以下是常见的数据拆分方式:

  • 对于较小的数据集(1,000–10,000 个样本),使用 60:20:20 的比例。

  • 对于中等规模的数据集(10,000–100,000 个样本),使用 70:15:15 的比例。

  • 大型数据集(>100,000 个样本),使用 80:10:10 的比例。

### Train-Validation-Test Split ###
# First split: separate test set
X_temp, X_test, y_temp, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Second split: separate validation set
X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp, test_size=0.25, random_state=42
)

# Train and evaluate
dt.fit(X_train, y_train)
val_accuracy = dt.score(X_val, y_val)
test_accuracy = dt.score(X_test, y_test)

# Plot
plt.figure(figsize=(5, 5), dpi=300)
plot_tree(dt, feature_names=X.columns, filled=True, rounded=True)
plt.title(f'Train-Val-Test Split\nValidation Accuracy: {val_accuracy:.3f}'
          f'\nTest Accuracy: {test_accuracy:.3f}')
plt.tight_layout()

保留法根据数据量的不同会有不同的表现。当你有大量数据(>100,000 个样本)时,它效果很好。但当你数据较少(<1,000 个样本)时,这种方法可能不是最理想的。在数据较少的情况下,你可能需要使用更高级的验证方法,以便更好地了解你的模型到底有多有效。

📊 转向交叉验证

我们刚刚了解到,留出法可能在小数据集上效果不佳。这正是我们目前面临的挑战——我们只有 28 天的数据。按照留出法原则,我们将保留 14 天的数据作为最终测试数据。这样,我们剩下 14 天的数据可以用于尝试其他验证方法。

# Initial train-test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, shuffle=False)

在接下来的部分,我们将看到交叉验证方法如何将这 14 天的数据多次划分,并以不同的方式进行测试。这让我们即使在数据有限的情况下,也能更好地了解模型的实际效果。

交叉验证

交叉验证改变了我们测试模型的方式。我们不再仅仅用一种数据划分方式测试一次模型,而是通过多次使用相同数据的不同划分来进行测试。这有助于我们更好地理解模型的实际表现。

交叉验证的主要思想是多次测试我们的模型,每次的训练集和测试集都来自我们数据的不同部分。这有助于避免由于数据划分极端(如特别好或特别差)而带来的偏差。

这为什么很重要呢?假设我们的模型在某次测试中得到 95%的准确率,而在另一种测试方法下只得到 75%的准确率,哪一个结果才是真正反映模型表现的呢?交叉验证通过提供多个测试结果,而不仅仅是一个,帮助我们回答这个问题。这让我们更清楚地了解模型的实际表现。

K 折法

基础 K 折交叉验证 K折交叉验证解决了基本数据划分方法的一个大问题:过于依赖单一的数据划分方式。与其只进行一次数据划分,K折将数据划分成K个相等的部分。然后,它多次测试模型,每次使用不同的部分进行测试,而其他部分则用于训练。

我们选择的K数值会影响我们如何测试模型。大多数人使用 5 或 10 作为K,但这个数值也可以根据我们拥有的数据量和项目需求来调整。假设我们使用K = 3。这意味着我们将数据分成三等份。然后我们将模型训练和测试三次。每次,2/3 的数据用于训练,1/3 的数据用于测试,但每次测试时,所用的测试部分都会不同。这样,每个数据片段都会同时用于训练和测试。

from sklearn.model_selection import KFold, cross_val_score

# Cross-validation strategy
cv = KFold(n_splits=3, shuffle=True, random_state=42)

# Calculate cross-validation scores
scores = cross_val_score(dt, X_train, y_train, cv=cv)
print(f"Validation accuracy: {scores.mean():.3f} ± {scores.std():.3f}")

# Plot trees for each split
plt.figure(figsize=(4, 3.5*cv.get_n_splits(X_train)))
for i, (train_idx, val_idx) in enumerate(cv.split(X_train, y_train)):
    # Train and visualize the tree for this split
    dt.fit(X_train.iloc[train_idx], y_train.iloc[train_idx])
    plt.subplot(cv.get_n_splits(X_train), 1, i+1)
    plot_tree(dt, feature_names=X_train.columns, impurity=False, filled=True, rounded=True)
    plt.title(f'Split {i+1} (Validation Accuracy: {scores[i]:.3f})\nTrain indices: {train_idx}\nValidation indices: {val_idx}')

plt.tight_layout()

验证准确率: 0.433 ± 0.047

当我们完成所有轮次后,我们会计算所有K测试的平均表现。这个平均值为我们提供了一个更可靠的衡量标准,来评估我们的模型表现如何。我们还可以通过观察不同测试轮次之间结果的变化,来了解我们的模型有多稳定。

分层 K 折 基本的 K 折交叉验证通常效果不错,但当我们的数据不平衡时——即某些类型的数据比其他类型多得多——它可能会遇到问题。例如,如果我们有 100 个数据点,其中 90 个是 A 类型,而只有 10 个是 B 类型,随机划分这些数据可能会导致某些划分中没有足够的 B 类型数据来进行合理的测试。

分层 K 折交叉验证通过确保每个数据划分与原始数据的分布相同来解决这个问题。如果我们的完整数据集中有 10% 是 B 类型,那么每个划分也将包含大约 10% 的 B 类型数据。这使得我们的测试更加可靠,特别是在某些数据类型比其他类型稀少时。

from sklearn.model_selection import StratifiedKFold, cross_val_score

# Cross-validation strategy
cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)

# Calculate cross-validation scores
scores = cross_val_score(dt, X_train, y_train, cv=cv)
print(f"Validation accuracy: {scores.mean():.3f} ± {scores.std():.3f}")

# Plot trees for each split
plt.figure(figsize=(5, 4*cv.get_n_splits(X_train)))
for i, (train_idx, val_idx) in enumerate(cv.split(X_train, y_train)):
    # Train and visualize the tree for this split
    dt.fit(X_train.iloc[train_idx], y_train.iloc[train_idx])
    plt.subplot(cv.get_n_splits(X_train), 1, i+1)
    plot_tree(dt, feature_names=X_train.columns, impurity=False, filled=True, rounded=True)
    plt.title(f'Split {i+1} (Validation Accuracy: {scores[i]:.3f})\nTrain indices: {train_idx}\nValidation indices: {val_idx}')

plt.tight_layout()

验证准确率:0.650 ± 0.071

保持这种平衡有两个好处。首先,它确保每个划分能够恰当地代表我们数据的分布。其次,它使得我们的测试结果更加一致。这意味着,如果我们多次测试模型,我们很可能每次都会得到类似的结果。

重复 K 折 有时,即使我们使用了 K 折验证,测试结果在不同的随机划分之间也可能发生较大的变化。重复 K 折通过多次运行整个 K 折过程来解决这个问题,每次使用不同的随机划分。

例如,假设我们运行 5 折交叉验证三次。这意味着我们的模型总共会进行 15 次训练和测试。通过如此多次的测试,我们可以更好地判断结果中的差异是来自随机因素,还是能真正反映出模型的性能。缺点是,所有这些额外的测试需要更多的时间来完成。

from sklearn.model_selection import RepeatedKFold

# Cross-validation strategy
n_splits = 3
cv = RepeatedKFold(n_splits=n_splits, n_repeats=2, random_state=42)

# Calculate cross-validation scores
scores = cross_val_score(dt, X_train, y_train, cv=cv)
print(f"Validation accuracy: {scores.mean():.3f} ± {scores.std():.3f}")

# Plot trees for each split
total_splits = cv.get_n_splits(X_train)  # Will be 6 (3 folds × 2 repetitions)
plt.figure(figsize=(5, 4*total_splits))
for i, (train_idx, val_idx) in enumerate(cv.split(X_train, y_train)):
   # Train and visualize the tree for this split
   dt.fit(X_train.iloc[train_idx], y_train.iloc[train_idx])

   # Calculate repetition and fold numbers
   repetition, fold = i // n_splits + 1, i % n_splits + 1

   plt.subplot(total_splits, 1, i+1)
   plot_tree(dt, feature_names=X_train.columns, impurity=False, filled=True, rounded=True)
   plt.title(f'Split {repetition}.{fold} (Validation Accuracy: {scores[i]:.3f})\n'
            f'Train indices: {list(train_idx)}\n'
            f'Validation indices: {list(val_idx)}')

plt.tight_layout()

验证准确率:0.425 ± 0.107

当我们查看重复 K 折结果时,由于我们有很多组测试结果,我们可以做的不仅仅是计算平均值——我们还可以了解我们对结果的信心。这使我们更好地理解模型的可靠性。

重复分层 K 折 这种方法结合了我们刚刚学习的两件事:保持类别平衡(分层)和进行多轮测试(重复)。它在测试多次的同时保持了不同类型数据的正确比例。这在我们的数据集较小且不平衡时尤其有效——例如,当我们有大量一种类型的数据,而其他类型的数据较少时。

from sklearn.model_selection import RepeatedStratifiedKFold

# Cross-validation strategy
n_splits = 3
cv = RepeatedStratifiedKFold(n_splits=n_splits, n_repeats=2, random_state=42)

# Calculate cross-validation scores
scores = cross_val_score(dt, X_train, y_train, cv=cv)
print(f"Validation accuracy: {scores.mean():.3f} ± {scores.std():.3f}")

# Plot trees for each split
total_splits = cv.get_n_splits(X_train)  # Will be 6 (3 folds × 2 repetitions)
plt.figure(figsize=(5, 4*total_splits))
for i, (train_idx, val_idx) in enumerate(cv.split(X_train, y_train)):
   # Train and visualize the tree for this split
   dt.fit(X_train.iloc[train_idx], y_train.iloc[train_idx])

   # Calculate repetition and fold numbers
   repetition, fold = i // n_splits + 1, i % n_splits + 1

   plt.subplot(total_splits, 1, i+1)
   plot_tree(dt, feature_names=X_train.columns, impurity=False, filled=True, rounded=True)
   plt.title(f'Split {repetition}.{fold} (Validation Accuracy: {scores[i]:.3f})\n'
            f'Train indices: {list(train_idx)}\n'
            f'Validation indices: {list(val_idx)}')

plt.tight_layout()

验证准确率:0.542 ± 0.167

然而,这种方法有一个权衡:它需要更多的时间来运行。每次我们重复整个过程时,训练模型所需的时间会成倍增加。在决定是否使用这种方法时,我们需要考虑,是否值得花费额外的时间来获得更可靠的结果。

分组 K 折交叉验证 有时,我们的数据自然分为一些应该保持在一起的组。例如,高尔夫数据中,我们可能有来自同一个高尔夫球场的多次测量数据。如果我们将来自一个高尔夫球场的部分测量数据放入训练数据,而其他的放入测试数据,就会出现问题:我们的模型可能会在训练过程中间接了解测试数据,因为它看到了来自同一球场的其他测量数据。

分组 K 折交叉验证通过保持来自同一组的数据(例如来自同一高尔夫球场的所有测量数据)一起划分,来解决这一问题。这可以防止我们的模型在训练过程中无意中看到不应该看到的信息,从而让我们误以为它表现得比实际情况更好。

# Create groups 
groups = ['Group 1', 'Group 4', 'Group 5', 'Group 3', 'Group 1', 'Group 2', 'Group 4', 
          'Group 2', 'Group 6', 'Group 3', 'Group 6', 'Group 5', 'Group 1', 'Group 4', 
          'Group 4', 'Group 3', 'Group 1', 'Group 5', 'Group 6', 'Group 2', 'Group 4', 
          'Group 5', 'Group 1', 'Group 4', 'Group 5', 'Group 5', 'Group 2', 'Group 6']

# Simple Train-Test Split
X_train, X_test, y_train, y_test, groups_train, groups_test = train_test_split(
    X, y, groups, test_size=0.5, shuffle=False
)

# Cross-validation strategy
cv = GroupKFold(n_splits=3)

# Calculate cross-validation scores
scores = cross_val_score(dt, X_train, y_train, cv=cv.split(X_train, y_train, groups=groups_train))
print(f"Validation accuracy: {scores.mean():.3f} ± {scores.std():.3f}")

# Plot trees for each split
plt.figure(figsize=(4, 3.5*cv.get_n_splits(X_train)))
for i, (train_idx, val_idx) in enumerate(cv.split(X_train, y_train, groups=groups_train)):
   # Get the groups for this split
   train_groups = sorted(set(np.array(groups_train)[train_idx]))
   val_groups = sorted(set(np.array(groups_train)[val_idx]))

   # Train and visualize the tree for this split
   dt.fit(X_train.iloc[train_idx], y_train.iloc[train_idx])
   plt.subplot(cv.get_n_splits(X_train), 1, i+1)
   plot_tree(dt, feature_names=X_train.columns, impurity=False, filled=True, rounded=True)
   plt.title(f'Split {i+1} (Validation Accuracy: {scores[i]:.3f})\n'
            f'Train indices: {train_idx} ({", ".join(train_groups)})\n'
            f'Validation indices: {val_idx} ({", ".join(val_groups)})')

plt.tight_layout()

验证准确度: 0.417 ± 0.143

当我们处理自然分组的数据时,这种方法尤其重要,比如来自同一个高尔夫球场的多次天气数据,或者同一地点在不同时间收集的数据。

时间序列划分 当我们在常规的 K 折交叉验证中随机划分数据时,我们假设每个数据点不会影响其他数据点。但这对于随时间变化的数据并不适用,因为过去发生的事情会影响未来的结果。时间序列划分通过调整 K 折交叉验证,更好地处理这种时间顺序数据。

时间序列划分并非随机分割数据,而是按顺序使用数据,从过去到未来。训练数据仅包括测试数据之前的时间段的信息。这与我们在现实生活中使用模型的方式一致,即我们利用过去的数据来预测未来的事件。

from sklearn.model_selection import TimeSeriesSplit, cross_val_score

# Cross-validation strategy
cv = TimeSeriesSplit(n_splits=3)

# Calculate cross-validation scores
scores = cross_val_score(dt, X_train, y_train, cv=cv)
print(f"Validation accuracy: {scores.mean():.3f} ± {scores.std():.3f}")

# Plot trees for each split
plt.figure(figsize=(4, 3.5*cv.get_n_splits(X_train)))
for i, (train_idx, val_idx) in enumerate(cv.split(X_train, y_train)):
   # Train and visualize the tree for this split
   dt.fit(X_train.iloc[train_idx], y_train.iloc[train_idx])
   plt.subplot(cv.get_n_splits(X_train), 1, i+1)
   plot_tree(dt, feature_names=X_train.columns, impurity=False, filled=True, rounded=True)
   plt.title(f'Split {i+1} (Validation Accuracy: {scores[i]:.3f})\n'
            f'Train indices: {train_idx}\n'
            f'Validation indices: {val_idx}')

plt.tight_layout()

验证准确度: 0.556 ± 0.157

例如,假设K=3,并且我们有高尔夫数据。我们可以使用一月和二月的天气数据训练,来预测三月的高尔夫打球模式。接着,我们使用一月到三月的数据来预测四月,依此类推。通过只向前推进时间,这种方法能更真实地反映我们的模型在预测基于天气的未来高尔夫打球模式时的表现。

留出法

留一交叉验证 (LOOCV) 留一交叉验证 (LOOCV) 是最彻底的验证方法。它仅使用一个样本进行测试,其他所有样本用于训练。验证会重复进行,直到每一条数据都被用作测试。

假设我们有 100 天的高尔夫天气数据。LOOCV 会训练并测试模型 100 次。每次,它使用 99 天的数据进行训练,1 天的数据进行测试。这种方法消除了测试中的任何随机性——如果你多次在相同的数据上运行 LOOCV,你将始终得到相同的结果。

然而,LOOCV 需要很长的计算时间。如果你有N个数据点,你需要训练模型N次。对于大型数据集或复杂模型,这可能需要的时间太长,无法实际使用。一些简单的模型,如线性模型,有一些捷径使得 LOOCV 变得更快,但并不是所有模型都适用。

from sklearn.model_selection import LeaveOneOut

# Cross-validation strategy
cv = LeaveOneOut()

# Calculate cross-validation scores
scores = cross_val_score(dt, X_train, y_train, cv=cv)
print(f"Validation accuracy: {scores.mean():.3f} ± {scores.std():.3f}")

# Plot trees for each split
plt.figure(figsize=(4, 3.5*cv.get_n_splits(X_train)))
for i, (train_idx, val_idx) in enumerate(cv.split(X_train, y_train)):
   # Train and visualize the tree for this split
   dt.fit(X_train.iloc[train_idx], y_train.iloc[train_idx])
   plt.subplot(cv.get_n_splits(X_train), 1, i+1)
   plot_tree(dt, feature_names=X_train.columns, impurity=False, filled=True, rounded=True)
   plt.title(f'Split {i+1} (Validation Accuracy: {scores[i]:.3f})\n'
            f'Train indices: {train_idx}\n'
            f'Validation indices: {val_idx}')

plt.tight_layout()

验证准确率:0.429 ± 0.495

LOOCV 在数据量不多,需要最大限度利用每一份数据时表现得非常好。由于结果依赖于每一条数据,如果数据中有噪声或异常值,结果可能会有很大变化。

Leave-P-Out 交叉验证 Leave-P-Out 基于 Leave-One-Out 的思想,但它每次测试时使用 P 个数据点,而不是仅测试一个数据点。这在 Leave-One-Out 和 K-fold 验证之间创造了平衡。我们选择的 P 值会改变模型的测试方式以及所需的时间。

Leave-P-Out 的主要问题是可能的测试组合数量增长得非常快。例如,如果我们有 100 天的高尔夫天气数据,并且每次测试 5 天(P=5),那么选择这 5 天的方式有数百万种不同的组合。当数据量很大或 P 值较大时,测试所有这些组合会耗费大量时间。

from sklearn.model_selection import LeavePOut, cross_val_score

# Cross-validation strategy
cv = LeavePOut(p=3)

# Calculate cross-validation scores (using all splits for accuracy)
scores = cross_val_score(dt, X_train, y_train, cv=cv)
print(f"Validation accuracy: {scores.mean():.3f} ± {scores.std():.3f}")

# Plot first 15 trees
n_trees = 15
plt.figure(figsize=(4, 3.5*n_trees))
for i, (train_idx, val_idx) in enumerate(cv.split(X_train, y_train)):
   if i >= n_trees:
       break

   # Train and visualize the tree for this split
   dt.fit(X_train.iloc[train_idx], y_train.iloc[train_idx])
   plt.subplot(n_trees, 1, i+1)
   plot_tree(dt, feature_names=X_train.columns, impurity=False, filled=True, rounded=True)
   plt.title(f'Split {i+1} (Validation Accuracy: {scores[i]:.3f})\n'
            f'Train indices: {train_idx}\n'
            f'Validation indices: {val_idx}')

plt.tight_layout()

验证准确率:0.441 ± 0.254

由于这些实际限制,Leave-P-Out 通常用于需要非常彻底测试且数据集足够小以使其可行的特殊情况。它在研究项目中尤其有用,在这些项目中,获取最准确的测试结果比测试所需的时间更为重要。

随机方法

ShuffleSplit 交叉验证 ShuffleSplit 与其他验证方法不同,它采用完全随机的分割方式。与 K-fold 按有序方式划分数据,或像 Leave-P-Out 那样测试所有可能的组合不同,ShuffleSplit 每次都会创建随机的训练和测试分割。

ShuffleSplit 与 K-fold 的不同之处在于,分割不遵循任何固定模式。在 K-fold 中,每条数据都恰好用于一次测试。但在 ShuffleSplit 中,一天的高尔夫天气数据可能被用于多次测试,也可能根本不被用于测试。这种随机性为我们提供了一种不同的方式来理解模型的表现。

ShuffleSplit 在大数据集上特别有效,而 K-折交叉验证可能需要花费过多时间来运行。我们可以选择测试多少次,无论数据量多大。同时,我们还可以控制每次划分的大小。这让我们能够在全面测试和运行时间之间找到一个良好的平衡。

from sklearn.model_selection import ShuffleSplit, train_test_split

# Cross-validation strategy
cv = ShuffleSplit(n_splits=3, test_size=0.2, random_state=41)

# Calculate cross-validation scores
scores = cross_val_score(dt, X_train, y_train, cv=cv)
print(f"Validation accuracy: {scores.mean():.3f} ± {scores.std():.3f}")

# Plot trees for each split
plt.figure(figsize=(4, 3.5*cv.get_n_splits(X_train)))
for i, (train_idx, val_idx) in enumerate(cv.split(X_train, y_train)):
   # Train and visualize the tree for this split
   dt.fit(X_train.iloc[train_idx], y_train.iloc[train_idx])
   plt.subplot(cv.get_n_splits(X_train), 1, i+1)
   plot_tree(dt, feature_names=X_train.columns, impurity=False, filled=True, rounded=True)
   plt.title(f'Split {i+1} (Validation Accuracy: {scores[i]:.3f})\n'
            f'Train indices: {train_idx}\n'
            f'Validation indices: {val_idx}')

plt.tight_layout()

验证准确率:0.333 ± 0.272

由于 ShuffleSplit 可以创建任意数量的随机划分,它在我们希望查看模型性能如何随不同的随机划分而变化,或在我们需要更多的测试以确保结果的可靠性时非常有用。

分层 ShuffleSplit 分层 ShuffleSplit 结合了随机划分和保持不同类型数据的正确混合。像分层 K-折交叉验证一样,它确保每个划分的每种类型的数据占比与整个数据集相同。

该方法为我们提供了双赢的局面:既有随机划分的自由,又有保持数据平衡的公平性。例如,如果我们的高尔夫数据集有 70% 的“是”天和 30% 的“否”天,每个随机划分都会尽量保持这一 70-30 的比例。这在数据不均衡时尤其有用,因为随机划分可能会无意中创建不代表我们数据的测试集。

from sklearn.model_selection import StratifiedShuffleSplit, train_test_split

# Cross-validation strategy
cv = StratifiedShuffleSplit(n_splits=3, test_size=0.2, random_state=41)

# Calculate cross-validation scores
scores = cross_val_score(dt, X_train, y_train, cv=cv)
print(f"Validation accuracy: {scores.mean():.3f} ± {scores.std():.3f}")

# Plot trees for each split
plt.figure(figsize=(4, 3.5*cv.get_n_splits(X_train)))
for i, (train_idx, val_idx) in enumerate(cv.split(X_train, y_train)):
   # Train and visualize the tree for this split
   dt.fit(X_train.iloc[train_idx], y_train.iloc[train_idx])
   plt.subplot(cv.get_n_splits(X_train), 1, i+1)
   plot_tree(dt, feature_names=X_train.columns, impurity=False, filled=True, rounded=True)
   plt.title(f'Split {i+1} (Validation Accuracy: {scores[i]:.3f})\n'
            f'Train indices: {train_idx}\n'
            f'Validation indices: {val_idx}')

plt.tight_layout()

验证准确率:0.556 ± 0.157

然而,保持划分的随机性以及数据类型的正确混合可能会很棘手。该方法有时需要在完全随机和保持完美比例之间做出一些小的妥协。在实际使用中,这些小的折衷很少会引起问题,且通常保持测试集的平衡比拥有完全随机的划分更为重要。

🌟 验证技术总结与代码总结

总结来说,模型验证方法分为两大类:留出法和交叉验证法:

留出法 · 训练-测试分割:最简单的方法,将数据分成两部分

· 训练-验证-测试分割:一种三分法用于更复杂的模型开发

交叉验证法 交叉验证法通过多轮验证更好地利用可用数据:

K-折交叉验证法 这些方法将数据分为 K 个部分,而不是一个单独的划分:

· 基本 K-折交叉验证:轮流使用不同的测试集

· 分层 K-折交叉验证:保持各个划分中的类别平衡

· 分组 K-折交叉验证:保留数据分组

· 时间序列分割:尊重时间顺序

· 重复 K-折交叉验证

· 重复分层 K-折交叉验证

留出法 这些方法将验证推向极限:

· 留 P 法:一次对 P 个数据点进行测试

· 留一法:对单个数据点进行测试

随机方法 这些方法引入了受控的随机性:

· ShuffleSplit:重复创建随机划分

· 分层 ShuffleSplit:随机划分且保持类别平衡

import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import (
    # Hold-out methods
    train_test_split,
    # K-Fold methods 
    KFold,                   # Basic k-fold
    StratifiedKFold,         # Maintains class balance
    GroupKFold,              # For grouped data
    TimeSeriesSplit,         # Temporal data
    RepeatedKFold,           # Multiple runs
    RepeatedStratifiedKFold, # Multiple runs with class balance
    # Leave-out methods
    LeaveOneOut,             # Single test point
    LeavePOut,               # P test points
    # Random methods
    ShuffleSplit,           # Random train-test splits
    StratifiedShuffleSplit, # Random splits with class balance
    cross_val_score         # Calculate validation score
)

# Load the dataset
dataset_dict = {
    'Outlook': ['sunny', 'sunny', 'overcast', 'rainy', 'rainy', 'rainy', 'overcast', 
                'sunny', 'sunny', 'rainy', 'sunny', 'overcast', 'overcast', 'rainy',
                'sunny', 'overcast', 'rainy', 'sunny', 'sunny', 'rainy', 'overcast',
                'rainy', 'sunny', 'overcast', 'sunny', 'overcast', 'rainy', 'overcast'],
    'Temperature': [85.0, 80.0, 83.0, 70.0, 68.0, 65.0, 64.0, 72.0, 69.0, 75.0, 75.0,
                   72.0, 81.0, 71.0, 81.0, 74.0, 76.0, 78.0, 82.0, 67.0, 85.0, 73.0,
                   88.0, 77.0, 79.0, 80.0, 66.0, 84.0],
    'Humidity': [85.0, 90.0, 78.0, 96.0, 80.0, 70.0, 65.0, 95.0, 70.0, 80.0, 70.0,
                 90.0, 75.0, 80.0, 88.0, 92.0, 85.0, 75.0, 92.0, 90.0, 85.0, 88.0,
                 65.0, 70.0, 60.0, 95.0, 70.0, 78.0],
    'Wind': [False, True, False, False, False, True, True, False, False, False, True,
             True, False, True, True, False, False, True, False, True, True, False,
             True, False, False, True, False, False],
    'Play': ['No', 'No', 'Yes', 'Yes', 'Yes', 'No', 'Yes', 'No', 'Yes', 'Yes', 'Yes',
             'Yes', 'Yes', 'No', 'No', 'Yes', 'Yes', 'No', 'No', 'No', 'Yes', 'Yes',
             'Yes', 'Yes', 'Yes', 'Yes', 'No', 'Yes']
}

df = pd.DataFrame(dataset_dict)

# Data preprocessing
df = pd.DataFrame(dataset_dict)
df = pd.get_dummies(df, columns=['Outlook'], prefix='', prefix_sep='', dtype=int)
df['Wind'] = df['Wind'].astype(int)

# Set the label
X, y = df.drop('Play', axis=1), df['Play']

## Simple Train-Test Split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, shuffle=False,
)

## Train-Test-Validation Split
# First split: separate test set
# X_temp, X_test, y_temp, y_test = train_test_split(
#    X, y, test_size=0.2, random_state=42
# )
# Second split: separate validation set
# X_train, X_val, y_train, y_val = train_test_split(
#    X_temp, y_temp, test_size=0.25, random_state=42
# )

# Create model
dt = DecisionTreeClassifier(random_state=42)

# Select validation method
#cv = KFold(n_splits=3, shuffle=True, random_state=42)
#cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)
#cv = GroupKFold(n_splits=3) # Requires groups parameter
#cv = TimeSeriesSplit(n_splits=3)
#cv = RepeatedKFold(n_splits=3, n_repeats=2, random_state=42)
#cv = RepeatedStratifiedKFold(n_splits=3, n_repeats=2, random_state=42)
cv = LeaveOneOut()
#cv = LeavePOut(p=3)
#cv = ShuffleSplit(n_splits=3, test_size=0.2, random_state=42)
#cv = StratifiedShuffleSplit(n_splits=3, test_size=0.3, random_state=42)

# Calculate and print scores
scores = cross_val_score(dt, X_train, y_train, cv=cv)
print(f"Validation accuracy: {scores.mean():.3f} ± {scores.std():.3f}")

# Final Fit & Test
dt.fit(X_train, y_train)
test_accuracy = dt.score(X_test, y_test)
print(f"Test accuracy: {test_accuracy:.3f}")

验证准确率: 0.429 ± 0.495

测试准确率: 0.714

对上述结果的评论: 验证准确率和测试准确率之间的巨大差距,以及验证分数中非常高的标准差,表明我们的模型表现不稳定。这种不一致性很可能来源于在我们的小型天气数据集上使用 LeaveOneOut 验证——在单个数据点上进行测试导致性能剧烈波动。使用较大的验证集的不同验证方法可能会给我们带来更可靠的结果。

选择正确的验证方法

选择如何验证模型并不简单——不同的情况需要不同的方法。理解使用哪种方法可能意味着获得可靠结果或误导性结果之间的差异。以下是选择验证方法时应该考虑的一些方面:

1. 数据集大小

数据集的大小对选择哪种验证方法最有效有很大的影响。让我们来看一下不同大小的数据集:

大数据集(超过 100,000 个样本)

当你有大量数据集时,测试所需的时间成为主要考虑因素之一。简单的保留验证(将数据一次性分为训练集和测试集)通常效果不错,因为你有足够的数据进行可靠的测试。如果需要使用交叉验证,使用 3 折或使用 ShuffleSplit 进行较少轮次的验证可以在不花费太多时间的情况下获得良好的结果。

中等数据集(1,000 到 100,000 个样本)

对于中等大小的数据集,常规的 K 折交叉验证效果最佳。使用 5 折或 10 折可以在可靠结果和合理的计算时间之间取得良好的平衡。这种数据量通常足以创建具有代表性的划分,而不会使得测试时间过长。

小型数据集(少于 1,000 个样本)

小型数据集,例如我们 28 天高尔夫记录的例子,需要更仔细的测试。在这种情况下,Leave-One-Out 交叉验证或重复 K 折交叉验证(使用更多折数)实际上可以很好地工作。尽管这些方法的运行时间较长,但在数据量不大的情况下,它们帮助我们获得最可靠的结果。

2. 计算资源

在选择验证方法时,我们需要考虑计算资源。在数据集大小、模型复杂度和所使用的验证方法之间存在三方面的平衡:

快速训练模型

像决策树、逻辑回归和线性 SVM 这样的简单模型可以使用更彻底的验证方法,如 Leave-One-Out 交叉验证或重复分层 K 折交叉验证,因为它们训练速度较快。由于每轮训练只需几秒钟或几分钟,我们可以承受多次验证迭代。即使是使用 N 轮训练的 LOOCV,也可能对这些算法来说是可行的。

资源密集型模型

深度神经网络、拥有大量树的随机森林或梯度提升模型的训练时间较长。在使用这些模型时,更加密集的验证方法,如重复 K 折交叉验证或 Leave-P-Out,可能不太实际。我们可能需要选择更简单的方法,如基本的 K 折交叉验证或 ShuffleSplit,以保持合理的测试时间。

内存考虑因素

一些方法,如 K 折交叉验证,需要同时跟踪多个数据划分。ShuffleSplit 可以帮助解决内存限制问题,因为它一次只处理一个随机划分。对于具有复杂模型(如需要大量内存的深度神经网络)的大规模数据集,可能需要使用更简单的保留方法。如果我们在内存有限的情况下仍需要彻底的验证,可以使用时间序列划分,因为它自然地按顺序处理数据,而不需要一次性将所有划分存储在内存中。

当资源有限时,使用一个我们可以顺利运行的更简单的验证方法(例如基本的 K 折交叉验证)比尝试运行一个我们无法完成的更复杂方法(例如 Leave-P-Out)要好。

3. 类别分布

类别不平衡会强烈影响我们应该如何验证模型。对于不平衡数据,分层验证方法变得至关重要。像分层 K 折交叉验证和分层 ShuffleSplit 这样的方式确保每个测试划分与完整数据集的类别分布大致相同。如果不使用这些分层方法,一些测试集可能完全没有某个类别,这样就无法正确测试模型的预测效果。

4. 时间序列

当处理随时间变化的数据时,我们需要特殊的验证方法。常规的随机划分方法效果不佳,因为时间顺序很重要。对于时间序列数据,我们必须使用像时间序列划分(Time Series Split)这样的方式,尊重时间顺序。

5. 群组依赖

许多数据集包含自然的相关数据组。在验证模型时,这些数据中的连接需要特殊处理。当数据点相关时,我们需要使用像 Group K-fold 这样的方式,以防止我们的模型错误地学习到不该学习的东西。

实用指南

这张流程图将帮助你为你的数据选择最合适的验证方法。下面的步骤概述了一个清晰的选择最佳验证方法的过程,前提是你有足够的计算资源。

最后的备注

模型验证对于构建可靠的机器学习模型至关重要。在探索了许多验证方法,从简单的训练-测试划分到复杂的交叉验证方法后,我们发现,总有一种适合你的数据的验证方法。

虽然机器学习在不断变化,出现了新的方法和工具,但这些基本的验证规则始终不变。当你很好地理解这些原则时,我相信你会建立起人们可以信任和依赖的模型。

深入阅读

关于验证方法的详细解释,读者可以参考官方文档,里面提供了全面的使用和参数说明。

技术环境

本文使用的是 Python 3.7 和 scikit-learn 1.5。尽管所讨论的概念通常适用,但具体的代码实现可能会因版本不同而有所变化。

关于插图

除非另有说明,否则所有图片均由作者创作,并结合了 Canva Pro 的授权设计元素。

𝙎𝙚𝙚 𝙢𝙤𝙧𝙚 𝙈𝙤𝙙𝙚𝙡 𝙀𝙫𝙖𝙡𝙪𝙖𝙩𝙞𝙤𝙣 & 𝙊𝙥𝙩𝙞𝙢𝙞𝙯𝙖𝙩𝙞𝙤𝙣 𝙢𝙚𝙩𝙝𝙤𝙙𝙨 𝙝𝙚𝙧𝙚:

Samy Baladram

Samy Baladram

模型评估与优化

查看列表3 个故事

𝙔𝙤𝙪 𝙢𝙞𝙜𝙝𝙩 𝙖𝙡𝙨𝙤 𝙡𝙞𝙠𝙚:

Samy Baladram

Samy Baladram

分类算法

查看列表8 个故事Samy Baladram

Samy Baladram

集成学习

查看列表4 个故事

使用马尔可夫链建模 DAU

原文:towardsdatascience.com/modeling-dau-with-markov-chain-640ea4fddeb4?source=collection_archive---------1-----------------------#2024-12-02

如何使用 Duolingo 的增长模型预测 DAU 并控制预测结果

Vladimir KukushkinTowards Data Science Vladimir Kukushkin

·发表于 Towards Data Science ·阅读时间:21 分钟·2024 年 12 月 2 日

--

1. 引言

毫无疑问,DAU、WAU 和 MAU——每日、每周和每月的 活跃用户——是关键的业务指标。由 Duolingo 前 CPO Jorge Mazal 撰写的文章 “How Duolingo reignited user growth” 位列 Lenny’s Newsletter 博客增长板块的第一篇文章。在这篇文章中,Jorge 特别关注了 Duolingo 如何建模 DAU 指标的方法(参见另一篇文章 “Meaningful metrics: how data sharpened the focus of product teams”Erin Gustafson 撰写)。这种方法有多个优点,但我想重点介绍如何使用这种方法进行 DAU 预测。

新的一年即将来临,许多公司目前正在为明年的预算进行规划。成本估算通常需要 DAU 预测。在本文中,我将展示如何使用 Duolingo 的增长模型进行预测。我还将解释为什么这种方法相比于标准的时间序列预测方法更为优越,以及如何根据团队的计划(例如市场营销、激活、产品团队)调整预测结果。

本文内容与代码相结合,并附有模拟数据集,研究结果完全可复现。Jupyter notebook 版本可以在 这里 查看。最后,我将分享一个以 Google Spreadsheet 格式设计的 DAU “计算器”。

我将以“我们”的身份叙述,就像我们在一起讨论一样。

2. 方法论

简要回顾一下Duolingo 的增长模型是如何工作的。在用户生命周期的第 d 天(d = 1, 2, …)中,用户可以处于以下 7 种(互斥)状态之一:newcurrentreactivatedresurrectedat_risk_wauat_risk_maudormant。这些状态是根据用户今天、过去 7 天或过去 30 天是否活跃来定义的。定义摘要如下表所示:

在定义了这些状态(作为集合 S)后,我们可以将用户行为视为一个马尔科夫链。以下是用户轨迹的一个例子:newcurrentcurrentat_risk_wau→...→ at_risk_mau→...→ dormant。设 M 为与此马尔科夫过程相关的转移矩阵:m_{i, j} = P(s_j | s_i) 表示用户从状态 s_i 转移到状态 s_j 的概率,其中 s_i, s_j ∈ S。这样一个矩阵是从历史数据中推断出来的。

如果我们假设用户行为是平稳的(与时间无关),那么矩阵 M 完全描述了未来所有用户的状态。假设长度为 7 的向量 u_0 包含给定一天(记为第 0 天)在各状态下的用户计数。根据马尔科夫模型,在第二天第 1 天,我们预计会有以下数量的用户状态 u_1:

通过递归应用此公式,我们可以推导出未来任意天 t > 0 时,处于某些状态的用户数量。

除了初始分布 u_0 外,我们还需要提供未来每天将会出现的新用户数量。我们将把这个问题作为一般的时间序列预测问题来处理。

现在,计算出 u_t 后,我们可以确定第 t 天的 DAU 值:

DAU_t = #New_t + #Current_t + #Reactivated_t + #Resurrected_t

此外,我们可以轻松计算 WAU 和 MAU 指标:

WAU_t = DAU_t + #AtRiskWau_t,

MAU_t = DAU_t + #AtRiskWau_t + #AtRiskMau_t。

最后,这里是算法大纲:

  1. 对于每个预测日 t = 1, …, T,计算预期的新用户数量 #New_1, …, #New_T。

  2. 对于每个用户的每个生命周期天,分配其中一个 7 个状态。

  3. 从历史数据中计算转移矩阵 M。

  4. 计算初始状态计数 u_0 对应于 t=0 的第 t 天。

  5. 递归地计算 u_{t+1} = M^T * u_t。

  6. 计算每个预测日 t = 1, …, T 的 DAU、WAU 和 MAU。

3. 实现

本节内容专注于实现的技术方面。如果你更感兴趣的是研究模型的属性,而非代码,可以跳过本节,直接查看第四部分。

3.1 数据集

我们使用基于 SaaS 应用历史数据的模拟数据集。数据存储在dau_data.csv.gz文件中,包含三列:user_iddateregistration_date。每条记录表示一个用户活跃的日期。数据集包括来自2020-11-012023-10-31的 51480 个用户的活动指标。此外,还包括 2020 年 10 月的数据,以正确计算用户状态,因为at_risk_maudormant状态需要前一个月的数据。

import pandas as pd

df = pd.read_csv('dau_data.csv.gz', compression='gzip')
df['date'] = pd.to_datetime(df['date'])
df['registration_date'] = pd.to_datetime(df['registration_date'])

print(f'Shape: {df.shape}')
print(f'Total users: {df['user_id'].nunique()}')
print(f'Data range: [{df['date'].min()}, {df['date'].max()}]')
df.head()
Shape: (667236, 3)
Total users: 51480
Data range: [2020-10-01 00:00:00, 2023-10-31 00:00:00]

这就是 DAU 时间序列的样子。

df.groupby('date').size()\
    .plot(title='DAU, historical')

假设今天是 2023 年 10 月 31 日,我们想预测 2024 年接下来的 DAU 指标。我们定义了两个全局常量PREDICTION_STARTPREDICTION_END,它们包含了预测期间。

PREDICTION_START = '2023-11-01'
PREDICTION_END = '2024-12-31'

3.2 预测新用户数量

从新用户预测开始。我们使用prophet库作为预测时间序列数据的最简单方法之一。new_users系列包含这样的数据。我们从原始的df数据集中提取它,选择registration date等于date的行。

new_users = df[df['date'] == df['registration_date']]\
    .groupby('date').size()
new_users.head()
date
2020-10-01    4
2020-10-02    4
2020-10-03    3
2020-10-04    4
2020-10-05    8
dtype: int64

prophet要求时间序列以包含dsy两列的 DataFrame 格式提供,因此我们将new_users系列重新格式化为new_users_prophet DataFrame。我们还需要准备的是创建future变量,包含用于预测的某些日期:从prediction_startprediction_end。这个逻辑在predict_new_users函数中实现。下面的图表展示了过去和未来时期的预测。

import logging
import matplotlib.pyplot as plt
from prophet import Prophet

# suppress prophet logs
logging.getLogger('prophet').setLevel(logging.WARNING)
logging.getLogger('cmdstanpy').disabled=True

def predict_new_users(prediction_start, prediction_end, new_users_train, show_plot=True):
    """
    Forecasts a time-seires for new users

    Parameters
    ----------
    prediction_start : str
        Date in YYYY-MM-DD format.
    prediction_end : str
        Date in YYYY-MM-DD format.
    new_users_train : pandas.Series
        Historical data for the time-series preceding the prediction period.
    show_plot : boolean, default=True
        If True, a chart with the train and predicted time-series values is displayed.
    Returns
    -------
    pandas.Series
        Series containing the predicted values.
    """
    m = Prophet()

    new_users_train = new_users_train\
        .loc[new_users_train.index < prediction_start]
    new_users_prophet = pd.DataFrame({
        'ds': new_users_train.index,
        'y': new_users_train.values
    })

    m.fit(new_users_prophet)

    periods = len(pd.date_range(prediction_start, prediction_end))
    future = m.make_future_dataframe(periods=periods)
    new_users_pred = m.predict(future)
    if show_plot:
        m.plot(new_users_pred)
        plt.title('New users prediction');

    new_users_pred = new_users_pred\
        .assign(yhat=lambda _df: _df['yhat'].astype(int))\
        .rename(columns={'ds': 'date', 'yhat': 'count'})\
        .set_index('date')\
        .clip(lower=0)\
        ['count']

    return new_users_pred
new_users_pred = predict_new_users(PREDICTION_START, PREDICTION_END, new_users)

new_users_pred系列存储预测的新用户数量。

new_users_pred.tail(5)
date
2024-12-27    52
2024-12-28    56
2024-12-29    71
2024-12-30    79
2024-12-31    74
Name: count, dtype: int64

3.3 获取状态

在实际操作中,最合理的计算方法是通过 SQL 查询在存储数据的数据库中执行。接下来,我们将使用duckdb库模拟这种查询。

我们希望为每个用户在应用中的生命周期中的每一天分配 7 种状态之一。根据定义,对于每一天,我们需要考虑至少过去 30 天的数据。这时,SQL 窗口函数发挥了作用。然而,由于df数据只包含活跃的日期记录,我们需要明确扩展它们,包含用户未活跃的日期。换句话说,而不是像这样一组记录:

user_id    date          registration_date
1234567    2023-01-01    2023-01-01
1234567    2023-01-03    2023-01-01

我们想得到这样的一个列表:

user_id    date          is_active    registration_date
1234567    2023-01-01    TRUE         2023-01-01
1234567    2023-01-02    FALSE        2023-01-01
1234567    2023-01-03    TRUE         2023-01-01
1234567    2023-01-04    FALSE        2023-01-01
1234567    2023-01-05    FALSE        2023-01-01
...        ...           ...          ...
1234567    2023-10-31    FALSE        2023-01-01

为了提高可读性,我们将以下 SQL 查询拆分成多个子查询。

  • full_range:为每个用户创建一个完整的日期序列。

  • dau_full:获取包含活跃和非活跃记录的完整列表。

  • states:为每个用户生命周期的每一天分配 7 种状态之一。

import duckdb

DATASET_START = '2020-11-01'
DATASET_END = '2023-10-31'
OBSERVATION_START = '2020-10-01'

query = f"""
WITH
full_range AS (
    SELECT
        user_id, UNNEST(generate_series(greatest(registration_date, '{OBSERVATION_START}'), date '{DATASET_END}', INTERVAL 1 DAY))::date AS date
    FROM (
        SELECT DISTINCT user_id, registration_date FROM df
    )
),
dau_full AS (
    SELECT
        fr.user_id,
        fr.date,
        df.date IS NOT NULL AS is_active,
        registration_date
    FROM full_range AS fr
    LEFT JOIN df USING(user_id, date)
),
states AS (
    SELECT
        user_id,
        date,
        is_active,
        first_value(registration_date IGNORE NULLS) OVER (PARTITION BY user_id ORDER BY date) AS registration_date,
        SUM(is_active::int) OVER (PARTITION BY user_id ORDER BY date ROWS BETWEEN 6 PRECEDING and 1 PRECEDING) AS active_days_back_6d,
        SUM(is_active::int) OVER (PARTITION BY user_id ORDER BY date ROWS BETWEEN 29 PRECEDING and 1 PRECEDING) AS active_days_back_29d,
        CASE
            WHEN date = registration_date THEN 'new'
            WHEN is_active = TRUE AND active_days_back_6d BETWEEN 1 and 6 THEN 'current'
            WHEN is_active = TRUE AND active_days_back_6d = 0 AND IFNULL(active_days_back_29d, 0) > 0 THEN 'reactivated'
            WHEN is_active = TRUE AND active_days_back_6d = 0 AND IFNULL(active_days_back_29d, 0) = 0 THEN 'resurrected'
            WHEN is_active = FALSE AND active_days_back_6d > 0 THEN 'at_risk_wau'
            WHEN is_active = FALSE AND active_days_back_6d = 0 AND ifnull(active_days_back_29d, 0) > 0 THEN 'at_risk_mau'
            ELSE 'dormant'
        END AS state
    FROM dau_full
)
SELECT user_id, date, state FROM states
WHERE date BETWEEN '{DATASET_START}' AND '{DATASET_END}'
ORDER BY user_id, date
"""
states = duckdb.sql(query).df()

查询结果保存在states DataFrame 中:

3.4 计算转移矩阵

获得这些状态后,我们可以计算状态转移频率。在第 4.3 节中,我们将研究预测如何依赖于考虑转移的周期,因此预先按日聚合这些数据是合理的。结果生成的transitions数据框包含datestate_fromstate_tocnt列。

现在,我们可以计算转移矩阵 M。我们实现了get_transition_matrix函数,它接受transitions数据框和一对日期,这些日期涵盖了要考虑的转移期。

作为基准,让我们计算从2022-11-012023-10-31的全年转移矩阵。

M = get_transition_matrix(transitions, '2022-11-01', '2023-10-31')
M

任何转移矩阵的每一行的和都等于 1,因为它表示从一个状态转移到任何其他状态的概率。

3.5 获取初始状态计数

初始状态是通过get_state0函数和相应的 SQL 查询从states数据框中检索的。该函数的唯一参数是我们想要获取初始状态的日期。我们将结果分配给state0变量。

def get_state0(date):
    query = f"""
    SELECT state, count(*) AS cnt
    FROM states
    WHERE date = '{date}'
    GROUP BY state
    """

    state0 = duckdb.sql(query).df()
    state0 = state0.set_index('state').reindex(states_order)['cnt']

    return state0
state0 = get_state0(DATASET_END)
state0
state
new               20
current          475
reactivated       15
resurrected       19
at_risk_wau      404
at_risk_mau     1024
dormant        49523
Name: cnt, dtype: int64

3.6 预测 DAU

下面的predict_dau函数接受所有之前预测 DAU 所需的变量,并对由start_dateend_date参数定义的日期范围进行预测。

def predict_dau(M, state0, start_date, end_date, new_users):
    """
    Predicts DAU over a given date range.

    Parameters
    ----------
    M : pandas.DataFrame
        Transition matrix representing user state changes.
    state0 : pandas.Series
        counts of initial state of users.
    start_date : str
        Start date of the prediction period in 'YYYY-MM-DD' format.
    end_date : str
        End date of the prediction period in 'YYYY-MM-DD' format.
    new_users : int or pandas.Series
        The expected amount of new users for each day between `start_date` and `end_date`.
        If a Series, it should have dates as the index.
        If an int, the same number is used for each day.

    Returns
    -------
    pandas.DataFrame
        DataFrame containing the predicted DAU, WAU, and MAU for each day in the date range,
        with columns for different user states and tot.
    """

    dates = pd.date_range(start_date, end_date)
    dates.name = 'date'
    dau_pred = []
    new_dau = state0.copy()
    for date in dates:
        new_dau = (M.transpose() @ new_dau).astype(int)
        if isinstance(new_users, int):
            new_users_today = new_users
        else:
            new_users_today = new_users.astype(int).loc[date] 
        new_dau.loc['new'] = new_users_today
        dau_pred.append(new_dau.tolist())

    dau_pred = pd.DataFrame(dau_pred, index=dates, columns=states_order)
    dau_pred['dau'] = dau_pred['new'] + dau_pred['current'] + dau_pred['reactivated'] + dau_pred['resurrected']
    dau_pred['wau'] = dau_pred['dau'] + dau_pred['at_risk_wau']
    dau_pred['mau'] = dau_pred['dau'] + dau_pred['at_risk_wau'] + dau_pred['at_risk_mau']

    return dau_pred
dau_pred = predict_dau(M, state0, PREDICTION_START, PREDICTION_END, new_users_pred)
dau_pred

这就是PREDICTION_START - PREDICTION_END期间 DAU 预测dau_pred的样子。除了预期的dauwaumau列,输出还包含每个预测日期每个状态下的用户数量。

最后,我们计算 DAU、WAU 和 MAU 的真实值(以及用户状态计数),将它们保存在dau_true数据框中,并将预测值和真实值一起绘制。

query = f"""
SELECT date, state, COUNT(*) AS cnt
FROM states
GROUP BY date, state
ORDER BY date, state;
"""

dau_true = duckdb.sql(query).df()
dau_true['date'] = pd.to_datetime(dau_true['date'])
dau_true = dau_true.pivot(index='date', columns='state', values='cnt')
dau_true['dau'] = dau_true['new'] + dau_true['current'] + dau_true['reactivated'] + dau_true['resurrected']
dau_true['wau'] = dau_true['dau'] + dau_true['at_risk_wau']
dau_true['mau'] = dau_true['dau'] + dau_true['at_risk_wau'] + dau_true['at_risk_mau']
dau_true.head()

pd.concat([dau_true['dau'], dau_pred['dau']])\
    .plot(title='DAU, historical & predicted');
plt.axvline(PREDICTION_START, color='k', linestyle='--');

我们已经获得了预测结果,但到目前为止,尚不清楚它是否公正。在下一节中,我们将评估该模型。

4. 模型评估

4.1 基准模型

首先,让我们检查是否真的需要建立一个复杂的模型来预测 DAU。难道不应该通过提到的prophet库将 DAU 作为一般时间序列来进行预测吗?下面的predict_dau_prophet函数实现了这一点。我们尝试使用库中可用的一些调整,以使预测更准确。特别是:

  • 我们使用逻辑回归模型而不是线性回归,以避免负值;

  • 我们明确添加了月度和年度季节性;

  • 我们去除离群值;

  • 我们明确将 1 月和 2 月的高峰期定义为“假期”。

def predict_dau_prophet(prediction_start, prediction_end, dau_true, show_plot=True):
    # assigning peak days for the new year
    holidays = pd.DataFrame({
        'holiday': 'january_spike',
        'ds': pd.date_range('2022-01-01', '2022-01-31', freq='D').tolist() + \
              pd.date_range('2023-01-01', '2023-01-31', freq='D').tolist(),
        'lower_window': 0,
        'upper_window': 40
    })

    m = Prophet(growth='logistic', holidays=holidays)
    m.add_seasonality(name='monthly', period=30.5, fourier_order=3)
    m.add_seasonality(name='yearly', period=365, fourier_order=3)

    train = dau_true.loc[(dau_true.index < prediction_start) & (dau_true.index >= '2021-08-01')]
    train_prophet = pd.DataFrame({'ds': train.index, 'y': train.values})
    # removining outliers
    train_prophet.loc[train_prophet['ds'].between('2022-06-07', '2022-06-09'), 'y'] = None
    train_prophet['new_year_peak'] = (train_prophet['ds'] >= '2022-01-01') &\
                                     (train_prophet['ds'] <= '2022-02-14')
    m.add_regressor('new_year_peak')
    # setting logistic upper and lower bounds
    train_prophet['cap'] = dau_true.max() * 1.1
    train_prophet['floor'] = 0

    m.fit(train_prophet)

    periods = len(pd.date_range(prediction_start, prediction_end))
    future = m.make_future_dataframe(periods=periods)
    future['new_year_peak'] = (future['ds'] >= '2022-01-01') & (future['ds'] <= '2022-02-14')
    future['cap'] = dau_true.max() * 1.1
    future['floor'] = 0
    pred = m.predict(future)

    if show_plot:
        m.plot(pred);

    # converting the predictions to an appropriate format
    pred = pred\
        .assign(yhat=lambda _df: _df['yhat'].astype(int))\
        .rename(columns={'ds': 'date', 'yhat': 'count'})\
        .set_index('date')\
        .clip(lower=0)\
        ['count']\
        .loc[lambda s: (s.index >= prediction_start) & (s.index <= prediction_end)]

    return pred

代码最终变得相当复杂,说明不能简单地将prophet应用于 DAU 时间序列。

接下来,我们测试多个预测范围的预测结果:3 个月、6 个月和 12 个月。因此,我们得到 3 个测试集:

  • 3 个月预测周期:2023-08-01 - 2023-10-31

  • 6 个月预测周期:2023-05-01 - 2023-10-31

  • 1 年预测周期:2022-11-01 - 2023-10-31

对于每个测试集,我们计算了MAPE损失函数。

from sklearn.metrics import mean_absolute_percentage_error

mapes = []
prediction_end = '2023-10-31'
prediction_horizon = [3, 6, 12]

for offset in prediction_horizon:
    prediction_start = pd.to_datetime(prediction_end) - pd.DateOffset(months=offset - 1)
    prediction_start = prediction_start.replace(day=1)
    prediction_end = '2023-10-31'
    pred = predict_dau_prophet(prediction_start, prediction_end, dau_true['dau'], show_plot=False)
    mape = mean_absolute_percentage_error(dau_true['dau'].reindex(pred.index), pred)
    mapes.append(mape)

mapes = pd.DataFrame({'horizon': prediction_horizon, 'MAPE': mapes})
mapes

MAPE 误差较高:18% — 35%。最短的预测周期有最高的误差,意味着该模型主要是针对长期预测进行调优的。这是这种方法的另一个不便之处:我们必须针对每个预测周期来调整模型。不管怎样,这就是我们的基准。在下一部分,我们将与更先进的模型进行比较。

4.2 一般评估

在这一部分,我们评估了第 3.6 节中实现的模型。目前我们将过渡期设置为预测开始前的一年。我们将在第 4.3 节中研究预测如何依赖于过渡期。至于新用户,我们使用两种选项运行模型:实际值和预测值。同样地,我们固定了 3 个预测周期,并在这些周期上测试模型。

以下的make_predicion辅助函数实现了所描述的选项。它接受prediction_startprediction_end参数,定义给定预测周期的开始和结束时间,new_users_mode可以是truepredict,以及transition_period。后者参数的选项将在后文进一步解释。

import re

def make_prediction(prediction_start, prediction_end, new_users_mode='predict', transition_period='last_30d'):
    prediction_start_minus_1d = pd.to_datetime(prediction_start) - pd.Timedelta('1d')
    state0 = get_state0(prediction_start_minus_1d)

    if new_users_mode == 'predict':
        new_users_pred = predict_new_users(prediction_start, prediction_end, new_users, show_plot=False)
    elif new_users_mode == 'true':
        new_users_pred = new_users.copy()

    if transition_period.startswith('last_'):
        shift = int(re.search(r'last_(\d+)d', transition_period).group(1))
        transitions_start = pd.to_datetime(prediction_start) - pd.Timedelta(shift, 'd')
        M = get_transition_matrix(transitions, transitions_start, prediction_start_minus_1d)
        dau_pred = predict_dau(M, state0, prediction_start, prediction_end, new_users_pred)
    else:
        transitions_start = pd.to_datetime(prediction_start) - pd.Timedelta(240, 'd')
        M_base = get_transition_matrix(transitions, transitions_start, prediction_start_minus_1d)
        dau_pred = pd.DataFrame()

        month_starts = pd.date_range(prediction_start, prediction_end, freq='1MS')
        N = len(month_starts)

        for i, prediction_month_start in enumerate(month_starts):
            prediction_month_end = pd.offsets.MonthEnd().rollforward(prediction_month_start)
            transitions_month_start = prediction_month_start - pd.Timedelta('365D')
            transitions_month_end = prediction_month_end - pd.Timedelta('365D')

            M_seasonal = get_transition_matrix(transitions, transitions_month_start, transitions_month_end)
            if transition_period == 'smoothing':
                i = min(i, 12)
                M = M_seasonal * i / (N - 1)  + (1 - i / (N - 1)) * M_base
            elif transition_period.startswith('seasonal_'):
                seasonal_coef = float(re.search(r'seasonal_(0\.\d+)', transition_period).group(1))
                M = seasonal_coef * M_seasonal + (1 - seasonal_coef) * M_base

            dau_tmp = predict_dau(M, state0, prediction_month_start, prediction_month_end, new_users_pred)
            dau_pred = pd.concat([dau_pred, dau_tmp])

            state0 = dau_tmp.loc[prediction_month_end][states_order]

    return dau_pred

def prediction_details(dau_true, dau_pred, show_plot=True, ax=None):
    y_true = dau_true.reindex(dau_pred.index)['dau']
    y_pred = dau_pred['dau']
    mape = mean_absolute_percentage_error(y_true, y_pred) 

    if show_plot:
        prediction_start = str(y_true.index.min().date())
        prediction_end = str(y_true.index.max().date())
        if ax is None:
            y_true.plot(label='DAU true')
            y_pred.plot(label='DAU pred')
            plt.title(f'DAU prediction, {prediction_start} - {prediction_end}')
            plt.legend()
        else:
            y_true.plot(label='DAU true', ax=ax)
            y_pred.plot(label='DAU pred', ax=ax)
            ax.set_title(f'DAU prediction, {prediction_start} - {prediction_end}')
            ax.legend()
    return mape

总体来说,我们有 6 个预测场景:2 个新用户选项和 3 个预测周期。下图展示了结果。左侧的图表与new_users_mode = 'predict'选项相关,而右侧的图表与new_users_mode = 'true'选项相关。

fig, axs = plt.subplots(3, 2, figsize=(15, 6))
mapes = []
prediction_end = '2023-10-31'
prediction_horizon = [3, 6, 12]

for i, offset in enumerate(prediction_horizon):
    prediction_start = pd.to_datetime(prediction_end) - pd.DateOffset(months=offset - 1)
    prediction_start = prediction_start.replace(day=1)
    args = {
        'prediction_start': prediction_start,
        'prediction_end': prediction_end,
        'transition_period': 'last_365d'
    }
    for j, new_users_mode in enumerate(['predict', 'true']):
        args['new_users_mode'] = new_users_mode
        dau_pred = make_prediction(**args)
        mape = prediction_details(dau_true, dau_pred, ax=axs[i, j])
        mapes.append([offset, new_users_mode, mape])

mapes = pd.DataFrame(mapes, columns=['horizon', 'new_users', 'MAPE'])
plt.tight_layout()

这里是总结预测质量的 MAPE 值:

mapes.pivot(index='horizon', columns='new_users', values='MAPE')

我们注意到了多个方面。

  • 一般来说,模型表现出的结果比基准要好得多。实际上,基准模型仅基于历史的 DAU 数据,而模型则利用了用户状态信息。

  • 然而,对于 1 年预测周期和new_users_mode='predict',MAPE 误差非常大:65%。这个值是相应基准误差(21%)的 3 倍。另一方面,new_users_mode='true'选项提供了一个更好的结果:8%。这意味着新用户预测对模型有着巨大的影响,特别是在长期预测中。对于较短的周期,这个差异则不那么明显。造成这种差异的主要原因是 1 年周期包括了圣诞节这一极端值。因此,i) 很难预测如此高的新用户值,ii) 这个周期对用户行为、过渡矩阵以及 DAU 值有着重大影响。因此,我们强烈建议谨慎地实现新用户预测。基准模型专门为圣诞节期间进行了调整,因此它超越了马尔可夫模型也就不足为奇了。

  • 当新用户预测准确时,模型能够很好地捕捉到趋势。这意味着使用过去 365 天的数据来计算转移矩阵是一个合理的选择。

  • 有趣的是,真实的新用户数据对于 3 个月的预测结果更差。这纯粹是巧合。2023 年 10 月错误的新用户预测逆转了预测的 DAU 趋势,并使得 MAPE 略有下降。

现在,让我们分解预测误差,看看哪些状态贡献最大。这里的误差指的是dau_pred - dau_true的值,相对误差指的是(dau_pred - dau_true)/ dau_true,对应的左侧和右侧图表如下所示。为了聚焦于这一方面,我们将配置范围缩小到三个月的预测期以及new_users_mode='true'选项。

dau_component_cols = ['new', 'current', 'reactivated', 'resurrected']

dau_pred = make_prediction('2023-08-01', '2023-10-31', new_users_mode='true', transition_period='last_365d')
figure, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

dau_pred[dau_component_cols]\
    .subtract(dau_true[dau_component_cols])\
    .reindex(dau_pred.index)\
    .plot(title='Prediction error by state', ax=ax1)

dau_pred[['current']]\
    .subtract(dau_true[['current']])\
    .div(dau_true[['current']])\
    .reindex(dau_pred.index)\
    .plot(title='Relative prediction error (current state)', ax=ax2);

从左侧图表中我们可以看到,误差主要由current状态贡献。这并不令人惊讶,因为这个状态对 DAU 的贡献最大。reactivatedresurrected状态的误差较低。另一个有趣的现象是,current状态的误差大多为负值,而resurrected状态的误差大多为正值。前者可能是由于在预测期内出现的新用户比过去的用户更加活跃。后者则表明,resurrected用户实际上对 DAU 的贡献低于转移矩阵的预期,因此dormantresurrected的转化率被高估了。

至于相对误差,分析current状态的误差是有意义的。因为reactivatedresurrected状态的日活跃用户数量较少,所以相对误差较大且噪声较多。current状态的相对误差介于-25%到 4%之间,这个范围相当大。由于我们已经固定了新用户的预测,这个误差只能用转移矩阵的不准确性来解释。特别是,currentcurrent的转化率大致为 0.8,这个值较高,因此它对误差贡献很大。所以如果我们想改善预测,首先需要考虑调整这一转化率。

4.3 转移期的影响

在上一节中,我们固定了转移期:预测开始前 1 年。现在我们将研究为了获得更准确的预测,这一时期应该多长。我们考虑相同的预测期:3 个月、6 个月和 12 个月。为了减少新用户预测的噪声,我们使用新用户数量的实际值:new_users_mode='true'

这里介绍了transition_period参数的变化。其值通过last_<N>d模式进行遮掩,其中N表示转移期的天数。对于每个预测期,我们计算了 12 个不同的转移期,分别为 1 个月、2 个月、...、12 个月。然后我们为每个选项计算 MAPE 误差并绘制结果。

result = []

for prediction_offset in prediction_horizon:
    prediction_start = pd.to_datetime(prediction_end) - pd.DateOffset(months=prediction_offset - 1)
    prediction_start = prediction_start.replace(day=1)

    for transition_offset in range(1, 13):
        dau_pred = make_prediction(
            prediction_start, prediction_end, new_users_mode='true',
            transition_period=f'last_{transition_offset*30}d'
        )
        mape = prediction_details(dau_true, dau_pred, show_plot=False)
        result.append([prediction_offset, transition_offset, mape])
result = pd.DataFrame(result, columns=['prediction_period', 'transition_period', 'mape'])

result.pivot(index='transition_period', columns='prediction_period', values='mape')\
    .plot(title='MAPE by prediction and transition period');

结果表明,最佳的转移期长度取决于预测范围。较短的预测期需要较短的转移期:对于 3、6 和 12 个月的预测,最小的 MAPE 误差分别出现在 1、4 和 8 个转移期。这显然是因为较长的预测期包含了一些只有通过较长转移期才能捕捉到的季节性效应。此外,似乎对于较长的预测期,MAPE 曲线呈 U 形,这意味着过长或过短的转移期都不利于预测。我们将在下一节展开讨论这个想法。

4.4 过时性与季节性

然而,固定一个单一的转移矩阵来预测整个未来一年的情况似乎并不是一个好主意:这样的模型过于僵化。通常,用户行为会根据季节变化。例如,在圣诞节后出现的用户可能会有行为上的变化。另一个典型的情况是用户在夏季改变他们的行为。在本节中,我们将尝试考虑这些季节性效应。

因此,我们希望预测从 2022 年 11 月开始的未来 1 年内的日活跃用户数(DAU)。我们不再仅使用一个计算自预测开始前 8 个月的单一转移矩阵M_base,而是根据上一小节的结果(并在下面标记为last_240d选项),将这个矩阵与一个季节性矩阵M_seasonal结合使用。后者是按月计算的,滞后 1 年。例如,为了预测 2022 年 11 月的 DAU,我们将M_seasonal定义为 2021 年 11 月的转移矩阵。然后,我们将预测的时间范围转移到 2022 年 12 月,并计算 2021 年 12 月的M_seasonal,以此类推。

为了混合M_baseM_seasonal,我们定义了以下两个选项。

  • seasonal_0.3:M = 0.3 * M_seasonal + 0.7 * M_base。0.3 是经过一些实验后选择的局部最小值的权重。

  • smoothing:M = i/(N-1) * M_seasonal + (1 - i/(N - 1)) * M_base,其中 N 是预测期内的月份数,i = 0, …, N - 1 是月份索引。这个配置的想法是随着预测月份的推进,从最新的转移矩阵M_base逐渐切换到季节性矩阵。

result = pd.DataFrame()
for transition_period in ['last_240d', 'seasonal_0.3', 'smoothing']:
    result[transition_period] = make_prediction(
        '2022-11-01', '2023-10-31',
        'true',
        transition_period
    )['dau']
result['true'] = dau_true['dau']
result['true'] = result['true'].astype(int)
result.plot(title='DAU prediction by different transition matrices');

mape = pd.DataFrame()
for col in result.columns:
    if col != 'true':
        mape.loc[col, 'mape'] = mean_absolute_percentage_error(result['true'], result[col])
mape

根据 MAPE 错误,seasonal_0.3 配置提供了最佳结果。有趣的是,smoothing 方法比 last_240d 更差。从上面的图表可以看出,所有三个模型从 2023 年 7 月开始都低估了 DAU 的值,尤其是 smoothing 模型。看起来从 2023 年 7 月开始出现的新用户比 2022 年的用户更活跃。可能是应用程序得到了足够的改善,或者市场团队做得很出色。结果,smoothing 模型过度依赖 2022 年 7 月至 10 月的过时过渡数据,表现得比其他模型更差。

4.5 最终解决方案

总结一下,让我们对 2024 年做一个最终的预测。我们使用 seasonal_0.3 配置和新用户的预测值。

dau_pred = make_prediction(
    PREDICTION_START, PREDICTION_END,
    new_users_mode='predict',
    transition_period='seasonal_0.3'
)
dau_true['dau'].plot(label='true')
dau_pred['dau'].plot(label='seasonal_0.3')
plt.title('DAU, historical & predicted')
plt.axvline(PREDICTION_START, color='k', linestyle='--')
plt.legend();

5. 讨论

在第四部分中,我们从预测准确性角度研究了模型的表现。现在让我们从实际角度讨论该模型。

除了准确性差之外,将 DAU 作为时间序列进行预测(参见第 4.1 节)使得这种方法非常僵化。本质上,它以一种方式做出预测,使得它最能拟合历史数据。实际上,在为来年做规划时,我们通常对未来有一些明确的预期。例如,

  • 市场团队将推出一些更有效的新营销活动,

  • 激活团队计划改进用户引导流程,

  • 产品团队将发布一些新功能,以更好地吸引和留住用户。

我们的模型可以考虑到这些预期。对于上述示例,我们可以分别调整新用户的预测、newcurrentcurrentcurrent 的转化率。结果,我们可以得到一个与历史数据不完全匹配但更为现实的预测。该模型的特点不仅仅是灵活——它是可解释的。你可以轻松地与利益相关者讨论这些调整,他们也能理解预测是如何工作的。

该模型的另一个优点是它不需要预测某个用户在某一天是否会活跃。有时,二分类器会被用于这个目的。这种方法的缺点是我们需要对每个用户(包括所有休眠用户)以及预测时间范围内的每一天应用这样的分类器。这是一个巨大的计算开销。相比之下,马尔可夫模型只需要初始状态量(state0)。此外,这类分类器通常是黑箱模型:它们难以解释且难以调整。

马尔可夫模型也有一些局限性。正如我们已经看到的,它对新用户的预测非常敏感。一个错误的新用户数量可能会完全破坏预测结果。另一个问题是,马尔可夫模型是“无记忆”的,意味着它没有考虑用户的历史。例如,它不能区分一个当前用户是新手、老手,还是重新激活/复活的用户。这些用户类型的留存率显然应该不同。此外,正如我们之前讨论的,用户行为可能会因季节、营销来源、国家等因素而有所不同。到目前为止,我们的模型还无法捕捉到这些差异。然而,这可能是进一步研究的一个课题:我们可以通过为不同的用户群体拟合更多的转移矩阵来扩展该模型。

最后,正如我们在导言中承诺的,我们提供了一个DAU 电子表格计算器。在Prediction工作表中,你需要填写初始状态分布行(标记为蓝色)和新用户预测列(标记为紫色)。在Conversions工作表中,你可以调整转移矩阵的值。请记住,矩阵的每一行的总和应等于 1。

暂时就这些。我希望这篇文章对你有所帮助。如果你有任何问题或建议,欢迎在下面的评论区提问,或者通过LinkedIn直接联系我。

本文中的所有图片均由作者生成。

使用马尔可夫链建模依赖随机变量

原文:towardsdatascience.com/modeling-dependent-random-variables-using-markov-chains-f363a3be1f9a?source=collection_archive---------2-----------------------#2024-02-23

纠正独立性违背的依赖模型

Roman PaolucciTowards Data Science Roman Paolucci

·发表于Towards Data Science ·阅读时长:15 分钟·2024 年 2 月 23 日

--

摄影:Nick Collins:www.pexels.com/photo/blue-abstract-painting-1292996/

在我之前讨论的随机变量参数的最大似然估计一文中,我们扮演了医院风险经理、高级医生统计学家、数据科学护士(我依然一点也不清楚谁会负责这个)等角色,并开发了一个简单的概率模型来估计在没有足够床位容纳新患者时的风险。为了实现这一目标,我们做出了以下假设:

  • 假设所有入住医院的患者当天会出院

  • 假设每天入住的患者是相互独立的

虽然这些假设不太现实,但它们使我们能够将某一天的患者数建模为一个泊松随机变量(见常见随机变量),该变量具有明确的分布函数,我们可以用它来估计无法容纳新患者的概率。

稍微偏离一下话题——当然,泊松随机变量由lambda参数化,它模型化了某一天患者的期望值和方差。在上一篇文章中,我们花了大部分时间讨论了如何利用最大似然法(maximum likelihood)根据一组观察数据来估计这个最佳统计量。

用神经 ODE 建模动态系统:实用指南

原文:towardsdatascience.com/modeling-dynamical-systems-with-neural-ode-a-hands-on-guide-71c4cfdb84dc?source=collection_archive---------2-----------------------#2024-01-12

概念,案例研究,逐步实施

Shuai GuoTowards Data Science Shuai Guo

·发表在Towards Data Science ·22 分钟阅读·2024 年 1 月 12 日

--

由 DALL-E 生成的图片。

对于许多物理系统,正确建模它们的动态行为为系统的理解、预测和控制奠定了基础。从数学上讲,描述系统状态随时间变化速率的常微分方程(ODEs)被广泛用于表征动态系统。相应地,许多分析和数值求解器已经被开发出来,可以有效地解决 ODEs,从而使我们能够描绘系统状态轨迹。

然而,使用 ODE 求解器有一个条件:我们需要完全了解 ODEs。对于许多真实世界的复杂系统,不幸的是,这个要求很少被满足:对于一些系统,我们可能知道 ODEs 的函数形式,但不知道相关的模型参数;对于其他系统,甚至它们的函数形式对我们来说是隐藏的。没有明确的方程,使用 ODE 求解器进行动态系统建模是不可行的。

图 1。在实践中,许多真实动态系统的主导 ODEs 只有部分已知。(作者提供的图片)

那么如何解决这个问题?

加泰罗尼亚语言灭绝建模

原文:towardsdatascience.com/modeling-the-extinction-of-the-catalan-language-01eda1527f5d?source=collection_archive---------9-----------------------#2024-06-28

将现有文献应用于实际案例

Pol MarinTowards Data Science Pol Marin

·发表于Towards Data Science ·9 分钟阅读·2024 年 6 月 28 日

--

图片来自Brett Jordan,来源于Unsplash

我们能预测一种语言的灭绝吗?这听起来并不简单,实际上也不应该简单,但这并不应该阻止我们尝试建立模型。

最近我对这个话题产生了兴趣,开始回顾一些现有的文献。我遇到了一篇文章[1],我很喜欢,决定与大家分享。

因此,在这篇文章中,我将分享那篇论文的见解,并将其翻译成(希望)易于理解的内容,并应用到一个实际案例中,以便我们能看到数据科学和数学建模的实际应用。

介绍

我是加泰罗尼亚人,对于那些不了解的人,加泰罗尼亚语是加泰罗尼亚、瓦伦西亚自治区和巴利阿里群岛(西班牙)与西班牙语共同的官方语言。它还是安道尔的官方语言,并在法国南部和意大利的阿尔盖罗(Alghero)地区存在。

我们常在当地的电视或媒体上看到,加泰罗尼亚语面临灭绝的风险。仅仅聚焦于加泰罗尼亚,我们就能深入了解这个情况,因为政府通过所谓的“语言使用调查”(Enquesta d’usos lingüistics de la població)来研究语言的使用情况[2]。

模型、MLFlow 和 Microsoft Fabric

原文:towardsdatascience.com/models-mlflow-and-microsoft-fabric-8faacaa90814?source=collection_archive---------7-----------------------#2024-04-29

Fabric 疯狂系列 第五部分

Roger NobleTowards Data Science Roger Noble

·发表于 Towards Data Science ·6 分钟阅读·2024 年 4 月 29 日

--

作者与 ChatGPT 提供的插图。 “设计一幅插图,展示多个机器学习模型,聚焦于篮球数据”提示。ChatGPT,4,OpenAI,2024 年 4 月 25 日。chat.openai.com.

非常感谢 Martim Chaves 他与我共同编写了这篇文章并开发了示例脚本。

到目前为止,在本系列中,我们已经探讨了如何使用 Fabric 来收集数据、进行特征工程和训练模型。

但现在我们已经有了这些崭新的模型,我们该怎么做呢?如何追踪它们,如何利用它们进行预测呢?这就是 MLFlow 的模型注册表发挥作用的地方,或者说是 Fabric 所称的ML 模型

模型注册表允许我们追踪模型的不同版本及其相应的表现。在生产环境中尤其有用,我们需要部署特定版本的模型来进行推理。

模型注册表可以看作是机器学习模型的源代码管理。基本上,每个版本代表一组独特的模型文件。这些文件包含了模型的架构、训练好的权重,以及任何其他必要的文件,用于加载和使用该模型。

在这篇文章中,我们将讨论如何记录模型,并且如何使用模型注册表来追踪模型的不同版本。我们还将讨论如何从注册表中加载模型并用其进行预测。

注册模型

在 Fabric 中注册模型有两种方式:通过代码或通过 UI。让我们来看看这两种方式。

使用代码注册模型

上一篇文章中,我们讨论了如何创建实验并记录具有不同配置的运行。记录或注册模型可以通过在运行中使用代码来完成。为此,我们只需添加几行代码。

# Start the training job with `start_run()`
with mlflow.start_run(run_name="logging_a_model") as run:
  # Previous code...
  # Train model
  # Log metrics

  # Calculate predictions for training set
  predictions = model.predict(X_train_scaled_df)

  # Create Signature
  # Signature required for model loading later on
  signature = infer_signature(np.array(X_train_scaled_df), predictions)

  # Model File Name
  model_file_name = model_name + "_file"

  # Log model
  mlflow.tensorflow.log_model(best_model, model_file_name, signature=signature)

  # Get model URI
  model_uri = f"runs:/{run.info.run_id}/{model_file_name}"

  # Register Model
  result = mlflow.register_model(model_uri, model_name)

在这个代码片段中,我们首先计算训练集的预测结果。然后创建一个签名,这本质上是模型的输入和输出形状。这是为了确保模型以后可以加载。

MLFlow 具有用于记录不同常用包(如TensorFlowPyTorchscikit-learn)制作的模型的功能。当使用 mlflow.tensorflow.log_model 时,一个文件夹作为工件保存,并附加到运行中,其中包含加载和运行模型所需的文件。在这些文件中,包含了模型的架构、训练好的权重以及重建所需的任何其他配置。这使得以后可以加载模型进行推理、微调或任何其他常规模型操作,而无需重新运行最初创建它的代码。

模型的 URI 用作指向模型文件的“路径”,由运行 ID 和用于模型的文件名组成。一旦我们获得了模型的 URI,就可以使用该 URI 注册一个 ML 模型。

有趣的是,如果已经存在同名的模型,将会添加一个新版本。这样,我们可以跟踪同一模型的不同版本,并查看它们的表现,而无需编写过于复杂的代码来管理这些版本。

在我们的上一篇文章中,我们进行了三个实验,每个实验测试一种模型架构,并使用三种不同的学习率。对于每种模型架构,都会创建一个 ML 模型,对于每种学习率,也会保存一个版本。总共有 9 个版本可供选择,每个版本具有不同的架构和学习率。

使用 UI 注册模型

ML 模型也可以通过 Fabric 的 UI 进行注册。可以从已创建的实验中导入模型版本。

图 1 — 使用 UI 创建 ML 模型。图片来自作者。

在创建 ML 模型后,我们可以从现有实验中导入模型。为此,在一次运行中,我们需要在 Save run as an ML Model 部分选择 Save

图 2 — 从一次运行中创建已创建的 ML 模型的新版本。图片来自作者。

选择最佳模型

现在我们已经注册了所有模型,可以选择最佳的一个。可以通过 UI 界面或代码实现。具体方法是打开每个实验,选择列表视图,然后选择所有可用的运行记录。找到最佳运行记录后,我们需要确认是哪一个模型和版本。

图 3 — 检查实验。图像由作者提供。

或者,这也可以通过代码实现,获取所有机器学习模型的所有版本的性能,并选择表现最佳的版本。

from mlflow.tracking import MlflowClient

client = MlflowClient()

mlmodel_names = list(model_dict.keys())
best_score = 2
metric_name = "brier"
best_model = {"model_name": "", "model_version": -1}

for mlmodel in mlmodel_names:

 model_versions = client.search_model_versions(filter_string=f"name = '{mlmodel}'")

 for version in model_versions:

  # Get metric history for Brier score and run ID
  metric_history = client.get_metric_history(run_id=version.run_id,
                                             key=metric_name)

  # If score better than best score, save model name and version
  if metric_history:
   last_value = metric_history[-1].value
   if last_value < best_score:
    best_model["model_name"] = mlmodel
    best_model["model_version"] = version.version
    best_score = last_value
  else:
   continue

在这段代码中,我们获取所有可用的机器学习模型列表。然后,我们遍历这个列表,获取每个机器学习模型的所有可用版本。

获取机器学习模型的版本列表可以使用以下代码行:

model_versions = client.search_model_versions(filter_string=f"name = '{mlmodel}'")

然后,对于每个版本,我们只需要获取其指标历史记录。这可以通过以下代码行实现:

metric_history = client.get_metric_history(run_id=version.run_id,
                                         key=metric_name)

然后,我们只需要追踪表现最好的版本。最终,我们找到了表现最佳的模型,不论其架构和超参数如何。

加载最佳模型

在找到最佳模型后,使用该模型进行最终预测可以通过以下代码片段实现:

# Load the best model
loaded_best_model = mlflow.pyfunc.load_model(f"models:/{best_model['model_name']}/{best_model['model_version'].version}")

# Evaluate the best model
final_brier_score = evaluate_model(loaded_best_model, X_test_scaled_df, y_test)
print(f"Best final Brier score: {final_brier_score}")

加载模型可以使用mlflow.pyfunc.load_model(),唯一需要的参数是模型的路径。模型的路径由模型名称和版本组成,格式为models:/[model name]/[version]。之后,我们只需确保输入的形状与训练时相同,并且特征的顺序也一致——就这样!

使用测试集,我们计算了最终的 Brier 分数,0.20

结论

在本文中,我们讨论了模型注册表背后的理念,以及为什么使用模型注册表是有益的。我们展示了如何通过机器学习模型工具使用 Fabric 的模型注册表,既可以通过 UI 界面,也可以通过代码实现。最后,我们介绍了如何从注册表加载模型并进行推理。

本系列教程到此为止。我们希望你喜欢这篇文章,并且学到了新知识。如果你有任何问题或评论,欢迎与我们联系。我们期待你的反馈!👋

原文发布于 https://nobledynamic.com 2024 年 4 月 29 日。

现代企业数据建模

原文:towardsdatascience.com/modern-enterprise-data-modeling-a3d61f7c0c25?source=collection_archive---------3-----------------------#2024-07-20

如何解决浅薄、过时的模型的不足,并使您的建模策略具有未来适应性

Bernd WesselyTowards Data Science Bernd Wessely

·发表于 Towards Data Science ·阅读时长 11 分钟·2024 年 7 月 20 日

--

图片来自 Justin Chrn ,来源:Unsplash

我从事数据建模已经超过 30 年,创建了多种数据模型(3NF、维度模型、集成模型(锚模型、数据金库)、图形模型等),主要用于分析系统。然而,这些模型中的许多也逐渐变得过时或过时了。有时候感觉就像是那个不幸的西西弗斯,他不断地把大石头推上山顶,却在某个时刻意识到这又是徒劳的。

很长一段时间,我一直相信,必须有可能为一家公司集中建模一个共同而完整的业务视图。毕竟,长期从事建模过程的业务人员,应该知道公司内部发生了什么吧?好吧,公司越小,我离目标就越近。但老实说,最终每个模型都只是一个近似值——一个试图反映不断变化现实的静态视图。

但是,即使创建这样的模型非常费力,我们绝对无法在没有它的情况下取得成功。现代数据驱动的企业基于从数据中提取价值的核心理念。然而,事实是数据本身没有价值。我们需要使用…

大型语言模型的 MOE 和 MOA

原文:towardsdatascience.com/moe-moa-for-large-language-models-c1cafeffd6a5?source=collection_archive---------8-----------------------#2024-07-16

寻求专家小组的建议

Rahul VirTowards Data Science Rahul Vir

·发表于 Towards Data Science ·阅读时间:4 分钟·2024 年 7 月 16 日

--

图片由作者提供(AI 生成 leonardo.ai)

大型语言模型(LLMs)无疑已经掀起了科技行业的风暴。它们的迅猛崛起得益于来自维基百科、网页、书籍、大量研究论文,以及我们喜爱的社交媒体平台的用户内容的大型数据集。这些数据和计算资源密集型的模型一直在热切地整合来自音频和视频库的多模态数据,并且已经使用了成千上万的英伟达 GPU,进行数月的训练,以打造最先进的(SOTA)模型。这一切让我们不禁思考,这种指数级增长是否能够持续下去。

这些大型语言模型面临的挑战是多方面的,但让我们在此探讨其中的一些。

  • 成本与可扩展性:更大的模型在训练和部署过程中可能需要数千万美元,这成为了日常应用中广泛采用的障碍。(参见训练 GPT-4 的成本

  • 训练数据饱和:公开可用的数据集很快就会耗尽,可能需要依赖缓慢生成的用户内容。只有那些拥有稳定新内容来源的公司和机构,才能够生成进一步的改进。

  • 幻觉:生成虚假且无根据信息的模型将成为一种障碍,因为用户在使用这些信息进行敏感应用之前,期望从权威来源获得验证。

  • 探索未知:LLMs 现在被用于超出其原始目的的应用。例如,LLMs 在游戏玩法、科学发现和气候建模方面展现了极大的能力。我们将需要新的方法来解决这些复杂情况。

在我们开始过于担心未来之前,先来看一下 AI 研究人员如何不懈努力以确保持续进展。Mixture-of-Experts(MoE)和 Mixture-of-Agents(MoA)的创新表明,希望就在眼前。

首次引入于 2017 年,Mixture-of-Experts技术表明,多个专家和一个能够选择稀疏专家的门控网络能够以更低的计算成本产生显著改善的结果。门控决策使得可以关闭网络的大部分部分,从而启用条件计算,并且专门化提高了语言建模和机器翻译任务的性能。

来源:来自极大规模神经网络的 MoE 层

上图展示了 Mixture-of-Experts 层如何被集成到一个递归神经网络中。门控层仅激活两个专家来处理任务,并随后将它们的输出合并。

尽管这一点已经在一些选定的基准上得到了验证,但条件计算为我们提供了一个途径,可以在不依赖日益增大的模型规模的情况下看到持续的改进。

受到 MOE 的启发,Mixture-of-Agents技术利用多个 LLM 来改善结果。问题通过多个 LLMs(即代理)进行处理,这些代理在每个阶段增强结果,作者已经证明,与更大的 SOTA 模型相比,这种方法能够用更小的模型产生更好的结果。

来源:Mixture-of-Agents 增强了大语言模型的能力 | 许可

该图展示了 4 个 Mixture-of-Agents 层,每个层中有 3 个代理。为每一层选择合适的 LLMs 非常重要,以确保适当的协作并产生高质量的响应。(来源)

MOA 依赖于这样一个事实:LLMs 之间的协作能产生更好的输出,因为它们能够结合其他模型的响应。LLMs 的角色分为生成多样输出的提议者和可以将其结合起来生成高质量响应的聚合器。多阶段方法可能会增加“首次令牌时间”(TTFT),因此需要开发缓解方法使其适用于广泛的应用。

MOE 和 MOA 具有相似的基础元素,但表现不同。MOE 基于选择一组专家来完成任务的概念,门控网络的任务是选择合适的专家集。MOA 则依赖于团队在前一团队工作的基础上进行建设,并在每个阶段改进结果。

MOE 和 MOA 的创新开辟了一条创新路径,在这条路径上,专门化的组件或模型的组合通过合作和信息交流,能够持续提供更好的结果,即使在模型参数和训练数据集的线性扩展不再简单时。

尽管我们只有回顾时才能知道大语言模型(LLM)创新是否能持久,但我一直在关注该领域的研究以获取见解。从各大高校和研究机构的成果来看,我对未来的进展非常看好。我确实感到,我们正处于新能力和新应用的“热身”阶段,这些将彻底改变我们的生活。我们不知道它们是什么,但可以相当确定,未来的日子不会让我们失望。

“我们往往高估技术在短期内的影响,而低估它在长期内的影响。” - 阿马拉定律

参考文献

[1] Wang, J., Wang, J., Athiwaratkun, B., Zhang, C., & Zou, J. (2024). 代理混合体增强大语言模型能力。arXiv [预印本]。arxiv.org/abs/2406.04692

[2] Shazeer, N., Mirhoseini, A., Maziarz, K., Davis, A., Le, Q., Hinton, G., & Dean, J. (2017). 极其庞大的神经网络:稀疏门控专家混合层。arXiv 预印本 arXiv:1701.06538.

MOIRAI-MOE:通过专家混合技术升级 MOIRAI 以增强预测能力

原文:towardsdatascience.com/moirai-moe-upgrading-moirai-with-mixture-of-experts-for-enhanced-forecasting-26a38017734f?source=collection_archive---------3-----------------------#2024-11-02

这款流行的基础时间序列模型刚刚获得了更新。

Nikos KafritsasTowards Data Science Nikos Kafritsas

·发表于Towards Data Science ·阅读时间 9 分钟·2024 年 11 月 2 日

--

图片来源

争夺打造顶级基础预测模型的竞赛已经开始!

Salesforce 的MOIRAI,作为早期的基础模型之一,达到了高基准结果,并与其预训练数据集 LOTSA 一起开源。

我们在这里深入分析了 MOIRAI 的工作原理——并构建了一个端到端的项目,将 MOIRAI 与流行的统计模型进行比较。

Salesforce 现在发布了升级版——MOIRAI-MOE——具有显著的改进,特别是增加了专家混合(MOE)。我们曾在讨论另一个模型时简要提到过 MOE,Time-MOE也使用了多个专家。

在本文中,我们将讨论:

  • MOIRAI-MOE 如何工作,以及为什么它是一个强大的模型。

  • MOIRAI 和 MOIRAI-MOE 之间的主要区别。

  • MOIRAI-MOE 如何利用专家混合提高预测准确性。

  • 专家混合(Mixture-of-Experts)如何通常解决基础时间序列模型中的频率变化问题。

让我们开始吧。

✅ 我已推出AI Horizon Forecast这是一个专注于时间序列和创新的新闻通讯……

MOIRAI:Salesforce 的时间序列预测基础模型

原文:towardsdatascience.com/moirai-salesforces-foundation-model-for-time-series-forecasting-4eff6c34093d?source=collection_archive---------2-----------------------#2024-03-15

代码、模型权重和数据将在不久后发布

Nikos KafritsasTowards Data Science Nikos Kafritsas

·发表于数据科学前沿 ·10 分钟阅读·2024 年 3 月 15 日

--

图片来源:[1]

时间序列基础模型终于迎来了突破!

前面的文章探讨了 2 个有前景的基础预测模型,TimeGPT 和TimesFM

本文将探讨MOIRAI [1],这是 Salesforce 推出的开创性时间序列基础模型。MOIRAI在性能上表现出色——但更重要的是,作者们承诺将开源该模型及其训练数据集!

这点在 Caiming Xiong(Salesforce AI 副总裁,论文作者之一)的推文中有提到这里

本文的主要贡献如下:

  • MOIRAI:一种新型的变压器编码器架构,作为一个通用的时间序列预测模型。

  • LOTSA大型开放时间序列档案):全球最大的开放时间序列数据集,涵盖 9 个领域,共计 270 亿条观测数据。

  • UNITS:一个用于训练通用时间序列模型的开源库。

此外,本文还讨论了:

  1. MOIRAI的工作原理及其强大之处。

  2. MOIRAI与谷歌的 TimesFM 相比的表现

Moirai:面向通用预测的时间序列基础模型

原文:towardsdatascience.com/moirai-time-series-foundation-models-for-universal-forecasting-dc93f74b330f?source=collection_archive---------0-----------------------#2024-04-11

预测分析的未来:探索 Moirai,Salesforce 推出的用于先进时间序列预测的新基础模型

Luís RoqueTowards Data Science Luís Roque

·发表于Towards Data Science ·15 分钟阅读·2024 年 4 月 11 日

--

这篇文章由 Rafael Guedes 和我共同撰写。

引言

时间序列基础模型的开发在过去两个季度加速推进,我们几乎见证了每个月都有新模型的发布。从 2023 年最后一个季度的 TimeGPT [1]开始,到 Lag-Llama [2]的发布,再到 Google 推出 TimesFM [3]、Amazon 推出 Chronos [4],以及 Salesforce 发布 Moirai [5],这一过程正在加速。

为了理解对基础模型日益增长的兴趣,我们应该定义它们的核心能力:零-shot 推理。零-shot 推理指的是这些模型能够在训练过程中从未接触过的数据上准确执行任务或做出预测的能力。这种能力已被探索应用于各个领域的模型,例如自然语言处理(NLP)、计算机视觉以及多模态任务(结合文本、图像等)。术语“零-shot”来自于这样一个概念:模型在训练期间未曾见过来自某个特定任务或数据领域的任何示例,但它能够有效地“射击”或瞄准该领域的任务。这个术语最早出现在 Hinton 等人发表的论文《Zero-Shot Learning with Semantic Output Codes》中,并在 NIPS 会议上进行了介绍…

塑造想象力:使用 AI 创建新的 3D 可打印物体

原文:towardsdatascience.com/molding-the-imagination-using-ai-to-create-new-3d-printable-objects-cf3682f8563b?source=collection_archive---------8-----------------------#2024-02-13

将你的创意转化为具体的物品,使用 Midjourney 和开源项目:Shap-E、MVDream 和 threestudio

Robert A. GonsalvesTowards Data Science Robert A. Gonsalves

·发表于 Towards Data Science ·21 分钟阅读·2024 年 2 月 13 日

--

AI 系统根据文本提示生成的 3D 打印物体,图片由作者提供

如果你一直在阅读我在 Medium 上的文章,你就会知道我喜欢通过使用 AI 进行创作实验并撰写我的经验。我已经涵盖了图像生成、创意写作和音乐创作等领域。第一次,我将我的研究拓展到了第三维度。我研究了使用商业和开源 AI 工具来创建新的物理对象,并使用 3D 打印机将它们打印出来。在本文中,我将展示我使用各种商业和开源工具设计和打印四种不同 3D 网格的步骤。你将在附录的 3D 画廊中看到这四个物体。

概述

在接下来的章节中,我将带领你了解我使用不同工具进行 3D 对象生成的四个实验,并展示结果。第一个实验使用了商业工具:Midjourney 来创建 2D 图像,以及一个名为 3dMaker.ai 的网站来提取 3D 网格,这是一种包含几何形状的文件,用于在 3D 打印和显示中使用。第二个实验使用了 OpenAI 的开源 AI 模型 Shape-E [1]。第三个实验使用了…

MOMENT:一个基础时间序列预测、分类、异常检测模型

原文:towardsdatascience.com/moment-a-foundation-model-for-time-series-forecasting-classification-anomaly-detection-1e35f5b6ca76?source=collection_archive---------0-----------------------#2024-04-27

一个统一的模型,涵盖多个时间序列任务

Nikos KafritsasTowards Data Science Nikos Kafritsas

·发表于Towards Data Science ·阅读时间 11 分钟·2024 年 4 月 27 日

--

由作者使用 DALLE*3 创建

基础模型已点燃 LLMs 在时间序列中的应用。

在过去的几个月里,我们看到了新的预测模型的发布,如TimesFMTimeGPT以及当然是 Salesforce 的MOIRAI

基础时间序列模型将对实际应用产生重大影响。时间序列无处不在,广泛应用于零售、能源需求、经济学、医疗保健等多个领域。一个基础时间序列模型可以像 GPT-4 处理文本一样,迅速应用于任何时间序列任务并取得很高的准确性。

本文探讨了MOMENT[1],最新的基础时间序列模型。

MOMENT 与前述模型的区别在于其通用性——它可以处理预测分类异常检测填补缺失值任务。此外,它是开源的!

本文将描述MOMENT的工作原理、架构,以及与其他 SOTA 时间序列模型的性能比较。

让我们开始吧。

我已发布 AI Horizon Forecast, 一份专注于时间序列和创新 AI 研究的通讯。订阅 这里 以拓宽你的

使用 Snowflake 的数据指标函数监控数据管道

原文:towardsdatascience.com/monitor-data-pipelines-using-snowflakes-data-metric-functions-0df71c46f04a?source=collection_archive---------8-----------------------#2024-04-15

使用谷歌 SRE 原则构建可信的数据平台

Jess.ZTowards Data Science Jess.Z

·发表于 Towards Data Science ·阅读时间 6 分钟·2024 年 4 月 15 日

--

由 Dall-E 生成的图像

您是否遇到过客户首先因为数据事件而找上门?您的客户是否因为数据不可信而自己构建数据解决方案?您的数据团队是否在修复未检测到的数据质量问题上花费了不必要的长时间,而不是优先处理战略性工作?

数据团队需要能够全面了解其数据系统的健康状况,以便赢得利益相关者的信任,并与整个业务进行更好的沟通。

我们可以将数据质量维度与谷歌的站点可靠性工程(SRE)原则结合起来,衡量我们数据系统的健康状况。为此,评估一些对您的数据管道有意义的数据质量维度,并制定服务水平目标(SLOs)

什么是服务水平目标?

本文中我们将使用的服务水平术语是服务水平指标服务水平目标。这两个概念是借鉴自 谷歌 SRE 书籍的原则。

服务水平 指标 — 精确定义的、定量衡量某个服务层面水平的指标。

我们在软件领域常见的指标有吞吐量、延迟和正常运行时间(可用性)。这些用于衡量应用程序或网站的可靠性。

典型事件

然后,这些指标被转化为受阈值限制的目标。软件应用的健康状况现在是“可度量的”,我们可以与客户沟通应用的状态。

服务水平目标:由 SLI 衡量的服务水平的目标值或范围。

我们直观地理解这些定量衡量标准和指标在典型用户应用中的必要性,以减少摩擦并建立与客户的信任。在构建数据管道时,我们需要采用类似的思维方式。

数据质量维度转化为服务水平术语

数据系统故障

假设用户与我们的应用程序互动并每小时生成 X 量的数据进入我们的数据仓库,如果进入仓库的行数突然大幅下降,我们可以将其标记为问题。然后,我们可以追踪管道中的时间戳来诊断并解决问题。

我们希望捕获进入我们系统的数据的足够信息,以便在发生异常时能够检测到。大多数数据团队倾向于从数据及时性开始。预期的数据是否在正确的时间到达?

这可以分解为以下指标:

  • 数据可用性——预期的数据是否已经到达/可用?

  • 数据新鲜度——新的数据是否按预期时间到达?

数据质量维度转化为 SLIs 和 SLOs

一旦系统稳定,保持与客户的良好关系就变得重要,以便设定对利益相关者有价值的正确目标。

阈值的概念…

我们如何实际确定期望多少数据以及何时到达?对于我们所有不同的数据集,正确的数据量是多少?这时我们需要关注阈值概念,因为它确实比较复杂。

假设我们有一个应用程序,用户主要在工作时间登录系统。我们预计每天 9 点到下午 5 点之间大约会有 2000 次 USER_LOGIN 事件,而在其他时间则有 100 次。如果我们为一天使用一个单一的阈值,它会得出错误的结论。在晚上 8 点接收 120 个事件是完全合理的,但如果我们在下午 2 点只接收了 120 个事件,那就值得关注,并且应该进一步调查。

显示阈值线的图表,绿色

因此,我们需要为一天中的每个小时以及每个不同的数据集计算不同的期望值——这就是阈值。需要定义一个元数据表,动态地获取每小时到达的行数,以便得出对每个数据源有意义的阈值。

有些阈值可以使用时间戳作为代理提取,正如上面所解释的那样。这可以通过统计度量,如平均值、标准差或百分位数,来迭代你的元数据表。

根据你的创造性程度,你甚至可以在这个过程的某一部分引入机器学习,帮助你设定阈值。其他阈值或预期需要与利益相关者讨论,因为这将依赖于对业务的具体了解,知道该预期什么。

在 Snowflake 中的技术实现

开始的第一步是选择几个对业务至关重要的数据集,在实现大规模的数据运维解决方案之前在其上构建。这是获取动力并感受数据可观察性工作影响的最简单方式。

许多分析型数据仓库已经内建了相关功能。例如,Snowflake 最近为企业账户推出了数据度量函数预览版,帮助数据团队快速入门。

数据度量函数是我们可能写的一些查询的包装器,用来洞察我们的数据系统。我们可以从系统 DMF 开始。

Snowflake 系统 DMF

我们首先需要整理一些权限…

DMF 访问控制文档

USE ROLE ACCOUNTADMIN;

GRANT database role DATA_METRIC_USER TO role jess_zhang;

GRANT EXECUTE data metric FUNCTION ON account TO role jess_zhang;

## Useful queries once the above succeeds
SHOW DATA METRIC FUNCTIONS IN ACCOUNT;
DESC FUNCTION snowflake.core.NULL_COUNT(TABLE(VARCHAR));

DATA_METRIC_USER 是一个数据库角色,这可能会让一些人感到困惑。如果你遇到问题,重新查看文档非常重要。最可能的原因是由于权限问题。

然后,简单地选择一个 DMF …

-- Uniqueness
SELECT SNOWFLAKE.CORE.NULL_COUNT(
  SELECT customer_id
  FROM jzhang_test.product.fct_subscriptions
);
-- Freshness
SELECT SNOWFLAKE.CORE.FRESHNESS(
  SELECT
    _loaded_at_utc
  FROM jzhang_test.product.fct_subscriptions
) < 60;
-- replace 60 with your calculated threshold value

你可以使用数据度量调度 —— 一个对象参数或者你常用的编排工具来安排你的 DMF 运行。仍然需要做大量工作来确定自己的阈值,以便为管道设置正确的 SLO。

总结…

数据团队需要与利益相关者合作,通过使用服务水平指标和目标来设定关于数据的更合理预期。引入这些指标将帮助数据团队从被动的应急响应转变为更主动的方式,防止数据事件的发生。这将使精力重新集中于交付业务价值,并构建一个可信的数据平台。

除非另有说明,所有图片均由作者提供。

监控 Amazon EventBridge 规则

原文:towardsdatascience.com/monitoring-amazon-eventbridge-rules-127434c58984?source=collection_archive---------11-----------------------#2024-08-30

探索指标产品并提出改进建议

Ross RhodesTowards Data Science Ross Rhodes

·发布于 Towards Data Science ·6 分钟阅读·2024 年 8 月 30 日

--

来源于CloudWatch Events,Amazon Web Services (AWS)于 2019 年 7 月推出 EventBridge: 一款无服务器产品,适用于事件驱动架构,现在包括多个组件:

  • Buses(代理)和Rules,用于将消费者与事件生产者集成。

  • Scheduling 用于完全托管的一次性或定期任务调用。

  • Pipes 提供生产者与消费者之间的托管集成。

网上有大量资料解释 EventBridge 的工作原理。产品文档是一个宝贵的起点。在这篇文章中,我们将深入探讨 EventBridge 规则的监控。

图片来源:berko via Unsplash

我们将从指标工作原理概述开始,然后探讨当前的规则指标产品并识别其局限性。在涵盖了广泛的规则指标后,我们将提出改进建议,以便为客户提供额外的价值。

规则指标的工作原理

与其他 AWS 服务一致,EventBridge 规则提供CloudWatch 指标来观察性能。一些指标可能会通过一个或多个维度来细分,以便查看行为…

监控 Amazon Kinesis 数据流

原文:towardsdatascience.com/monitoring-amazon-kinesis-data-streams-825ae21dfa18?source=collection_archive---------11-----------------------#2024-09-04

探索可观察性功能并提出改进建议

Ross RhodesTowards Data Science Ross Rhodes

·发表于 Towards Data Science ·阅读时长 8 分钟·2024 年 9 月 4 日

--

于 2013 年 11 月发布,Kinesis 是一个由亚马逊网络服务(AWS)提供的托管解决方案,用于实时处理和分析流数据。Kinesis 提供了几种不同的流媒体服务:

  • 数据流 用于处理数据记录,无需管理服务器。

  • 视频流 用于捕获和存储媒体流,例如用于回放。

目前已有大量在线文档深入探讨 Kinesis 的工作原理。亚马逊官方文档 是一个很好的参考起点。在这篇博客文章中,我们将重点介绍 Kinesis 数据流监控,格式与 我之前关于 EventBridge 规则的文章类似。

图片来源:Rob Laughter via Unsplash

我们将从 Kinesis 指标的工作原理概述开始,然后介绍当前的、特别有意义的数据流指标。在适用的情况下,我们将指出它们的局限性。接着,我们将给出可以为数据流可观察性增值的建议步骤。

数据流指标如何工作

与其他 AWS 服务一样,Kinesis 提供 CloudWatch 指标 来监控数据流的性能。提供两个级别的指标: 级别和 分片 级别。对于……

单声道到立体声:AI 如何为音乐注入新生命

原文:towardsdatascience.com/mono-to-stereo-how-ai-is-breathing-new-life-into-music-4180f1357db4?source=collection_archive---------4-----------------------#2024-12-24

AI 单声道到立体声升混的应用与技术

Max HilsdorfTowards Data Science Max Hilsdorf

·发表于Towards Data Science ·10 分钟阅读·2024 年 12 月 24 日

--

图片由 DALL-E 3 生成。

单声道录音是历史的一个快照,但它们缺乏使音乐真正充满生命力的空间感。通过人工智能,我们可以将单声道录音人工转换为立体声,甚至重新混音现有的立体声录音。本文将探讨单声道到立体声的升混的实际应用场景和方法。

物理和数字世界中的单声道与立体声

图片由J提供,来源于Unsplash

当一个乐团现场演奏时,声波从不同的乐器传播到房间中的各个地方,最终传到你的耳朵。这会导致时间差异(声音到达耳朵的时间)和音量差异(每只耳朵听到的声音大小)。通过这个过程,音乐表演不仅仅是和声、音色和节奏。每个乐器都传递着空间信息,让听众沉浸在一种“此时此刻”的体验中,吸引他们的注意力并触动他们的情感。

听听第一个片段(没有空间信息)和第二个片段(左右耳之间明显的差异)之间的区别:

文章全程强烈推荐使用耳机,但并非绝对必要。

示例:单声道

歌曲原由Lexin Music。适用 Pixabay 的内容许可

示例:立体声

这首歌最初由Lexin Music创作。Pixabay 的内容许可适用。

正如你所听到的,通过录音传递的空间信息对我们作为听众感知的生动性和兴奋感有着强烈的影响。

在数字音频中,最常见的格式是单声道立体声。单声道录音仅由一个音频信号组成,在你的耳机两侧听起来完全相同(我们称其为通道)。而立体声录音则由两个独立的信号组成,分别完全向左和向右通道分配。

一个由两个通道组成的立体波形示例。图片由作者提供。

现在我们已经体验到立体声如何使听觉体验更加生动和引人入胜,同时也理解了关键术语,我们可以更深入地探讨我们所关注的问题:AI 在单声道到立体声转换中的作用,也被称为单声道到立体声升混

单声道到立体声升混的应用案例

AI 并不是最终目的。为了证明这种先进技术的开发和应用是合理的,我们需要实际的应用案例。单声道到立体声升混的两个主要应用案例是:

1. 将现有的单声道音乐丰富为立体声体验。

尽管立体声录音技术在 1930 年代初期就已经发明,但直到 1960 年代它才成为录音室的事实标准,并且更长时间才在普通家庭中普及。在 1950 年代末期,新的电影发行仍然同时配有立体声轨道和额外的单声道轨道,以便为那些尚未过渡到立体声系统的电影院做准备。简而言之,有许多流行歌曲是以单声道录制的。以下是一些示例:

  • 埃尔维斯·普雷斯利:That’s All Right

  • 查克·贝瑞:Johnny Be Goode

  • 杜克·埃灵顿:Take the “A” Train

“Elvis Presley: That’s All Right”的官方音频,这首歌于 1954 年以单声道录音发布。

即便在今天,业余音乐人可能也会发布他们的单声道录音,要么是因为缺乏技术能力,要么仅仅是因为他们不想花时间去制作立体声混音。

单声道到立体声的转换让我们可以以全新的方式体验我们喜爱的旧录音,还能将业余录音或示范曲目带入生命。

2. 改进或现代化现有的立体声混音,特别是那些显得杂乱无章或已经过时的混音。

即使有立体声录音,我们仍然可能想要改善它。例如,许多 60 年代和 70 年代的老录音是立体声录制的,但每个乐器完全偏向一侧。听一下 The Doors 的《Soul Kitchen》,注意到低音和鼓完全偏向左边,键盘和吉他偏向右边,主唱在中央。歌曲很棒,并且有独特的美学,但这个立体声混音可能不太符合现代听众的喜好。

技术限制曾经影响了立体声的表现。此外,立体声混音不仅仅是一种工艺,它是艺术创作的一部分。立体声混音可以在客观上是可以接受的,但在风格上仍然可能不合时宜。一种立体声转换工具可以用来创建一个更符合特定风格偏好的替代立体声版本。

单声道到立体声 AI 如何工作

现在我们讨论了单声道到立体声技术的重要性,你可能会想知道它在幕后是如何工作的。事实证明,有多种方法可以使用 AI 来解决这个问题。接下来,我将展示四种不同的方法,从传统信号处理到生成 AI。这并不是一个完整的方法列表,而是过去 20 年中如何解决这个任务的一些启示。

传统信号处理:声音源形成

在机器学习像今天这样流行之前,音乐信息检索(MIR)领域曾由智能、手工制作的算法主导。难怪这种方法在单声道到立体声的上混音中也有所应用。

2007 年一篇论文(Lagrange, Martins, Tzanetakis, [1])的基本思想很简单:

如果我们能够找到录音中的不同声音源并从信号中提取它们,我们就可以将它们重新混合在一起,以实现逼真的立体声体验。

听起来很简单,但我们如何分辨信号中的声音源呢?我们如何清晰地定义它们,使得算法能够从信号中提取它们?这些问题非常难以解决,这篇论文使用了多种先进的方法来实现这一点。归根结底,这就是他们提出的算法:

  1. 将录音分成短片段,并识别每个片段中的峰值频率(主导音符)

  2. 识别哪些峰值属于同一组(即一个声音源),使用聚类算法

  3. 决定每个声音源应该在立体声混音中位置(手动步骤)

  4. 对于每个声音源,提取其分配的频率从信号中

  5. 将所有提取的源混合在一起,形成最终的立体声混音。

这是为研究构建的用户界面示例。用户浏览所有提取的源,并手动将它们放入立体声混音中,然后重新合成整个信号。图片来源于[1]

尽管细节上相当复杂,但直觉是非常清晰的:找到来源,提取它们,再将它们混合在一起。

快速解决方案:源分离 / 音轨分离

自从拉格朗日 2007 年的论文以来,发生了很多变化。自 2019 年 Deezer 发布了他们的音轨分离工具Spleeter以来,基于 AI 的源分离系统变得非常实用。领先的玩家如Lalal.aiAudioshake使得快速解决方案成为可能:

  1. 使用免费的或商业的音轨分离工具将单声道录音分离为各个乐器音轨

  2. 将音轨加载到数字音频工作站(DAW)中,根据自己的喜好将它们混合在一起

这一技术在 2011 年的一篇研究论文中已有应用(见[2]),但由于最近在音轨分离工具的改进,它变得更加可行。

源分离方法的缺点是它们会产生明显的声音伪影,因为源分离本身仍然不是没有缺陷的。此外,这些方法仍然需要人工混音,使得它们只是半自动化的。

要实现完全自动化的单声道到立体声升混,需要使用机器学习。通过从真实的立体声混音中学习,机器学习系统可以适应真实人类制作人的混音风格。

基于参数化立体声的机器学习

摄影:由Zarak Khan拍摄,来自Unsplash

一种非常创新和高效的利用机器学习进行单声道到立体声升混的方法,由 Serrà及其同事在 2023 年 ISMIR 会议上提出[3]。这项工作基于一种名为参数化立体声的音乐压缩技术。立体声混音由两个音频通道组成,这使得在低带宽环境下(如音乐流媒体、广播或电话连接)集成变得困难。

参数化立体声是一种通过关注我们大脑用来判断声音来源的空间线索,将单声道信号转换为立体声的方法。这些线索包括:

  1. 左右耳之间的声音强度差异(声道间强度差,IID)

  2. 它在时间或相位上左右声道的同步程度(声道间时间或相位差)

  3. 每只耳朵中信号的相似性或差异性(声道间相关性,IC)

利用这些参数,可以仅通过单声道信号创建类似立体声的体验。

这是研究人员开发他们的单声道到立体声升混模型时采用的方法:

  1. 收集大量的立体声音乐轨道数据集

  2. 将立体声轨道转换为参数化立体声(单声道 + 空间参数)

  3. 训练一个神经网络,根据单声道录音预测空间参数

  4. 为了将一个新的单声道信号转换为立体声,可以使用训练好的模型从单声道信号中推断空间参数,然后将这两个参数结合成一个参数化的立体声体验。

目前,似乎没有任何代码或试听演示可以用于这篇论文。作者们自己也承认,“专业立体声混音与提出的方法之间仍然存在差距”(第 6 页)。尽管如此,论文概述了一种创造性且高效的方式,使用机器学习完成完全自动化的单声道到立体声的上混合。

生成式 AI:基于 Transformer 的合成

Meta 的文本到音乐模型 MusicGen 中的立体声生成。图片来自作者的另一篇文章

现在,我们将讨论看似最直接的单声道生成立体声的方法。训练一个生成模型,接受单声道输入,并直接合成两个立体声输出通道。尽管概念上简单,从技术角度来看,这是迄今为止最具挑战性的方法。1 秒钟的高分辨率音频有 44.1k 个数据点。因此,生成一首三分钟的立体声音频意味着生成超过 1500 万个数据点

现在,借助卷积神经网络、变换器和神经音频编解码器等技术,任务的复杂性开始变得可控。有些论文选择通过直接的神经合成生成立体声信号(参见[4][5][6])。然而,只有[5]训练了一个可以直接解决单声道到立体声生成的模型。我的直觉是,应该有一篇论文专门解决“简单”的单声道到立体声生成任务,并专注于 100%解决这个目标。有没有人对这个博士课题感兴趣?

接下来需要发生什么?

Samuel Spagl拍摄,图片来自Unsplash

在总结这篇文章时,我想讨论单声道到立体声上混合的研究领域可能的发展方向。最重要的是,我注意到这一领域的研究非常稀少,相比之下,像文本到音乐生成这样的热门话题则备受关注。我认为,研究社区应该集中精力,将单声道到立体声上混合的研究提升到一个新的水平:

1. 开放可用的演示和代码

目前该研究领域发布的论文不多。这使得许多论文没有与社区分享代码或工作成果变得更加令人沮丧。我曾多次阅读一篇引人入胜的论文,却发现唯一能够测试该方法输出质量的方式,是完全理解论文中的每一个公式,并从头开始实现算法。

分享代码和创建公共演示从未像今天这样容易。研究人员应该将此作为优先事项,以使更广泛的音频社区能够理解、评估并欣赏他们的工作。

2. 全力投入生成式 AI

传统信号处理和机器学习很有趣,但当谈到输出质量时,再也无法绕过生成式 AI 了。文本到音乐的模型已经能生成听起来很棒的立体声混音。为什么没有易于使用的、最先进的单声道到立体声升混库呢?

从我的研究中得知,构建一个高效且有效的模型可以通过合理的数据集大小以及对现有模型架构和训练方法进行最小至适度的修改来实现。我的印象是,这是一项触手可及的成果,是一个“就做吧!”的情况。

3. 使升混自动化,但可控

一旦我们拥有了一个出色的开源升混模型,接下来我们需要的是可控性。我们不应该在黑盒的“要么接受,要么放弃”的神经网络生成与基于源分离的老派手动混音之间做选择。我认为我们可以兼得这两者。

一个神经网络单声道到立体声升混模型可以在一个庞大的数据集上进行训练,然后进行微调,以根据用户提示调整其立体声混音。这样,音乐家们可以根据个人偏好定制生成立体声的风格。

结论

有效且公开可访问的单声道到立体声的升混技术具有将旧录音或业余制作焕发生机的潜力,同时也让我们能够创造我们最喜爱的歌曲的替代立体声混音。

尽管已经有若干次尝试来解决这个问题,但尚未建立标准方法。通过拥抱最近的生成式 AI 发展,可以创建一代新的单声道到立体声升混模型,使得技术变得更加有效且在社区中更广泛可用。

关于我

我是一名音乐学家和数据科学家,分享我对 AI 和音乐当前话题的思考。以下是我与本文相关的一些前期工作:

可以在 MediumLinkedin 上找到我!

参考文献

[1] M. Lagrange, L. G. Martins 和 G. Tzanetakis (2007):“使用声音源形成的半自动单声道到立体声升混”,见《音频工程学会会议》第 122 届。音频工程学会,2007 年。

[2] D. Fitzgerald (2011): “从单声道到源分离的方法”,发表于 2011 年第 17 届数字信号处理国际会议(DSP)。IEEE,2011 年,第 1-7 页。

[3] J. Serrà, D. Scaini, S. Pascual, 等人 (2023): “通过参数化立体声生成从单声道到立体声的转换”: arxiv.org/abs/2306.14647

[4] J. Copet, F. Kreuk, I. Gat 等人 (2023): “简单且可控的音乐生成”(2024 年 1 月 30 日修订版)。 arxiv.org/abs/2306.05284

[5] Y. Zang, Y. Wang & M. Lee (2024): “Ambisonizer: 作为球面调和函数生成的神经网络升混合”。 arxiv.org/pdf/2405.13428

[6] K.K. Parida, S. Srivastava & G. Sharma (2022): “超越单声道到双耳声道:利用深度和跨模态注意力从单声道音频生成双耳音频”,发表于 IEEE/CVF 冬季计算机视觉应用会议(WACV),2022 年,第 3347–3356 页。 Link

使用 Depth Anything V2 进行单目深度估计

原文:towardsdatascience.com/monocular-depth-estimation-with-depth-anything-v2-54b6775abc9f?source=collection_archive---------4-----------------------#2024-07-24

神经网络是如何从二维图像中学习估计深度的?

Avishek BiswasTowards Data Science Avishek Biswas

·发表于 Towards Data Science ·阅读时间:10 分钟·2024 年 7 月 24 日

--

什么是单目深度估计?

Depth Anything V2 算法(作者插图)

单目深度估计(MDE)是训练神经网络从单幅图像中提取深度信息的任务。这是一个激动人心且充满挑战的机器学习与计算机视觉领域,因为预测深度图要求神经网络仅凭二维图像来形成三维的理解。

在本文中,我们将讨论一个新的模型叫做Depth Anything V2及其前身Depth Anything V1。Depth Anything V2 在深度估计领域超越了几乎所有其他模型,在处理复杂图像时展现了令人印象深刻的结果。

Depth Anything V2 演示(来源:作者从 Depth Anything V2 演示页面录制的屏幕录像)

本文基于我制作的同主题视频。 这里是视频链接,适合喜欢视觉媒介的学习者。 对于喜欢阅读的朋友,继续往下看!

为什么我们要关心 MDE 模型?

良好的 MDE 模型有许多实际应用,如帮助机器人导航和避障、无人机以及自动驾驶汽车。它们还可以用于视频和图像编辑、背景替换、物体移除和创建三维效果。此外,它们在增强现实(AR)和虚拟现实(VR)头戴设备中也非常有用,可以为用户创造交互式三维空间

进行单目深度估计(MDE)有两种主要方法(本文仅介绍其中一种)

训练 MDE 模型已经出现了两种主要的方法:一种是判别方法,网络试图将深度作为监督学习目标进行预测;另一种是生成方法,如条件扩散,其中深度预测是一个迭代的图像生成任务。Depth Anything 属于第一类判别方法,今天我们将讨论的是这一部分。 欢迎来到 Neural Breakdown,让我们深入探讨深度估计!

传统数据集与 MiDAS 论文

为了全面理解 Depth Anything,让我们首先回顾一下 2019 年的 MiDAS 论文,它为 Depth Anything 算法提供了前身。

来源:MIDAS论文的截图(许可证:免费)

MiDAS 通过结合不同的数据集来训练 MDE 模型,这些数据集包含标注的深度信息。例如,KITTI数据集用于自动驾驶,提供户外图像,而NYU-Depth V2数据集则提供室内场景。了解这些数据集是如何收集的非常重要,因为像 Depth Anything 和 Depth Anything V2 这样的新模型解决了数据收集过程中固有的多个问题。

如何收集现实世界中的深度数据集

这些数据集通常是通过立体摄像头收集的,两个或更多的摄像头以固定距离放置,从略微不同的视角同时拍摄图像,从而实现深度信息的提取。NYU-Depth V2 数据集使用 RGB-D 摄像头,它不仅捕捉像素颜色,还记录深度值。一些数据集利用激光雷达(LiDAR),通过投射激光束来捕捉场景的三维信息。

然而,这些方法存在一些问题。 由于获取这些数据集的高运营成本,标注数据的数量有限。此外,标注可能会有噪音且分辨率较低。立体摄像头在各种光照条件下表现不佳,无法可靠地识别透明或高度反射的表面。激光雷达昂贵,并且激光雷达与 RGB-D 摄像头的范围有限,生成的深度图分辨率低且稀疏。

我们能否使用未标注图像来学习深度估计?

利用未标注的图像来训练深度估计模型将是非常有益的,因为网上有大量这样的图像。2023 年原始 Depth Anything 论文提出的主要创新是将这些未标注数据集整合进训练流程。接下来的部分,我们将探讨这一点是如何实现的。

Depth Anything 架构

原始的Depth Anything (V1)模型,发布于 2023 年采用了三步训练过程。在深入每个部分之前,让我们先对该算法进行一个高层次的概览。

Depth Anything V1 算法(作者插图)

步骤 1:教师模型训练

首先,名为 TEACHER 的神经网络模型被训练用于监督式深度估计,使用了五个不同的公开数据集。

从深度到视差空间的转换

TEACHER 模型初始化时使用了预训练的 Dino-V2 编码器,然后在合并后的标注数据集上进行训练。训练多个数据集时的一个主要挑战是绝对深度的差异。为了解决这个问题,深度值被转化为视差空间(d = 1 / t),并在每个深度图中进行归一化处理——离最近像素为 1,离最远像素为 0。通过这种方式,所有数据集共享相同的输出空间,从而使得模型可以预测视差。

不同的深度估计数据集提供的深度值具有不同的尺度。我们需要将它们对齐以使输出空间相同。视差可以让我们将所有深度值归一化到 0 和 1 之间(作者插图)

这两个损失函数用于训练这些模型:尺度位移不变损失和梯度匹配损失,这两个损失函数也在 2019 年的 MiDAS 论文中得到了应用。

  1. 尺度位移不变损失

使用简单的均方误差损失来比较预测图像与真实图像之间的差异时存在一个问题。假设图像中三个像素的真实深度值分别是1、0.5 和 0.1,而我们的网络预测的是0.9、0.6 和 0.3。尽管预测值不完全相同,但预测深度与真实深度之间的关系是相似的,只有一个乘法和加法因子的差异。我们不希望这种尺度和位移影响我们的损失函数——在应用均方误差损失之前,我们需要对两个图像进行对齐。

尺度位移不变损失(作者插图)

MiDaS 论文建议将真实深度和预测深度进行归一化,以确保零平移和单位尺度。计算中位数和标准差,然后相应地调整深度图的尺度和位移。对齐后,应用均方误差损失。

SSI 损失(来源:MiDAS 论文)(许可证:免费)

2. 梯度匹配损失

如果没有梯度匹配损失,深度图可能会变得过于模糊,失去锐利感(作者插图)

仅使用 SSI 损失可能会导致平滑的深度图,无法捕捉相邻像素之间的明显区别。梯度匹配损失通过将预测深度图的梯度与真实深度图的梯度对齐,帮助保留这些细节。

首先,我们计算预测的和地面真实深度图在 x 轴和 y 轴上的梯度,然后在梯度层面应用损失。MiDaS 还使用了一个具有四个尺度级别的多尺度梯度匹配损失。预测和地面真实深度图被下采样四次,并在每个分辨率上应用损失。

梯度匹配损失。该损失应用于多个下采样的深度图(未显示)。(插图作者提供)

最终损失是尺度不变损失和多尺度梯度匹配损失的加权和。虽然 SSI 损失鼓励模型学习一般的相对深度关系,但梯度匹配损失有助于保持场景中的锐利边缘和精细信息。

在 MIDAS 和 Depth Anything V1 中训练深度估计模型所使用的损失函数(插图作者提供)

步骤 2 — 伪标签无标签数据集

使用我们训练好的教师模型,我们现在可以为数百万张无标签图像添加注释,创建一个庞大的伪深度标签数据集。这些标签被称为伪标签,因为它们是 AI 生成的,可能并不代表实际的地面真实深度。我们现在拥有大量(伪)标签图像来训练一个新的网络。

伪标签图像(请注意,这张屏幕截图实际上来自 Depth Anything V2 论文,而非 V1)来源:Depth Anything V2 Paper(许可证:免费)

步骤 3 — 训练学生网络

回顾 Depth Anything V1 算法。我们现在处于步骤 3。(插图由作者制作)

我们将在标签数据集和伪标签数据集的结合上训练一个新的神经网络(学生网络)。然而,仅仅在教师网络提供的注释上训练网络并不会使模型超越基础教师模型的能力。为了使学生网络更强大,采用了两种策略:使用图像增强进行重度扰动,并引入辅助语义保持损失。

重度扰动

使用的一个有趣的扰动是 Cut Mix 操作。它涉及通过二进制掩膜将一对随机的无标签图像结合起来,用图像 B 替换图像 A 中的一个矩形区域。最终的损失是来自两个地面真实深度图的两部分的合成 SSI 和梯度匹配损失。这些空间失真还与颜色失真结合,帮助学生网络应对开放世界图像的多样性。

Cut Mix 操作(插图作者提供)

辅助语义保持损失

网络还通过一个辅助任务进行训练,称为语义辅助感知。使用像 Dino-V2 这样的强大预训练计算机视觉模型,它已经通过自监督方式在数百万张图像上进行了训练。给定一张图像,我们的目标是减少由新学生模型生成的嵌入与预训练 Dino-V2 编码器之间的余弦距离。这使得我们的学生模型能够捕捉到更大、更通用的 Dino-V2 模型的一些语义感知能力,并利用这些能力来预测深度图。

语义辅助感知(作者插图)

通过结合空间畸变、语义辅助感知以及标注和未标注数据集的力量,学生网络能够更好地泛化,并在深度估计上超过原始的教师网络!以下是来自 Depth Anything V1 模型的一些令人难以置信的结果!

Depth Anything V2

尽管 Depth Anything V1 的结果令人印象深刻,但它在处理透明物体和捕捉细粒度细节方面仍存在困难。 Depth Anything V2 的作者认为,模型性能的最大瓶颈不是架构本身,而是数据的质量。 大多数使用传感器捕获的标注数据集可能会有噪声,忽略细粒度细节,生成低分辨率的深度图,并且在光照条件以及反射/透明物体的处理上存在困难。

真实世界传感器数据集的问题(作者插图)

Depth Anything V2 舍弃了来自真实世界传感器(如立体相机、激光雷达和 RGB-D 相机)的标注数据集,而仅使用合成数据集。合成数据集是通过图形引擎生成的,而非通过设备捕获的。例如,使用 Unity 游戏引擎创建渲染图像和深度图的虚拟 KITTI 数据集,用于自动驾驶。还有一些室内数据集,如 IRS 和 Hyper-sim。Depth Anything V2 使用了五个包含接近 595K 张逼真图像的合成数据集。

合成数据集 vs 真实世界传感器数据集

合成图像确实有其优缺点。它们非常准确,具有高分辨率输出,能够捕捉最细微的细节,并且透明和反射表面的深度可以轻松获得。合成数据集直接获取所有所需的 3D 信息,因为图形引擎本身创建了场景。

从缺点方面来看,这些图像可能无法完全捕捉我们在现实场景中会遇到的图像。这些数据集的场景覆盖面也不够多样,且仅为现实世界图像的一个较小子集。Depth Anything 2 结合了合成图像和数百万未标注图像的优势,训练出了一个 MDE 模型,其性能超越了我们迄今为止见过的大多数模型。

合成或计算机生成数据集的优缺点(图示:作者提供)

和 V1 类似,V2 中的教师模型首先在标注数据集上进行训练。然而,在 V2 中,它仅在合成数据集上进行训练。在第二步,教师模型为所有未标注图像分配伪深度标签。最后,在第三步,学生模型仅在伪标注图像上进行训练——不使用任何真实标注数据集,也不使用合成数据集。在这一阶段,由于之前提到的分布偏移,合成数据集没有被使用。学生网络在由教师模型标注的真实世界图像上进行训练。就像 V1 中一样,辅助的语义保留损失与尺度-平移不变性损失和梯度匹配损失一同使用。

Depth Anything V2 架构(图示:作者提供)

视频链接,形象地解释这些概念

这里有一个视频,逐步解释了本视频中讨论的所有概念。

你也可以以视频形式学习本文中涉及的主题

Depth Anything V1 vs Depth Anything V2

原版的 Depth Anything 强调了在 MDE 训练流程中使用未标注图像的重要性。它引入了知识蒸馏流程,包括教师模型训练、伪标签化未标注图像,然后在标注和未标注图像的结合上训练学生网络。强空间和颜色扭曲的使用,以及语义辅助感知损失,有助于创建更通用和鲁棒的嵌入。这导致了为复杂场景生成高效且高质量的深度图。然而,Depth Anything V1 仍然在处理反射表面和细节时存在问题,因为来自现实传感器的深度标签噪声大且分辨率低。

Depth Anything V2 通过忽略真实世界传感器数据集,专门使用通过图形引擎生成的合成图像来训练教师网络,从而提升了性能。然后,教师网络为数百万未标注图像进行标注,学生网络仅在这些伪标注数据集上进行训练,并使用真实世界图像。通过这些技术,Depth Anything V2 现在能够预测精细级别的深度图,并更有效地处理透明和反射表面。

相关链接

MiDAS: arxiv.org/abs/1907.01341

Depth Anything: depth-anything.github.io/

Depth Anything V2: depth-anything-v2.github.io/

KITTI 数据集: www.cvlibs.net/datasets/kitti/

NYU V2: cs.nyu.edu/~fergus/datasets/nyu_depth_v2.html

虚拟 KITTI: datasetninja.com/virtual-kitti

Youtube 视频: youtu.be/sz30TDttIBA

蒙特卡洛方法解析

原文:towardsdatascience.com/monte-carlo-methods-decoded-d63301bde7ce?source=collection_archive---------2-----------------------#2024-02-16

用模拟解决复杂问题

Hennie de HarderTowards Data Science Hennie de Harder

·发表于 Towards Data Science ·阅读时间 16 分钟·2024 年 2 月 16 日

--

“蒙特卡洛”这个名字来源于位于摩纳哥的著名蒙特卡洛赌场。这个名称的选择并不是因为与这些方法背后的数学原理有直接关联,而是因为它在隐喻上与随机性和机会相联系,这两个元素在赌博和蒙特卡洛模拟中都占据着核心地位。在这篇文章中,我们将讨论这一技术,并展示与项目管理、不规则区域近似以及游戏相关的代码示例。

现实世界中的系统和过程通常涉及不确定的参数和变量。使用蒙特卡洛方法,你可以明确地对这些不确定性进行建模。企业可以通过理解不同风险的概率和影响,做出更明智的决策。除了决策支持外,你还可以将其用于增强预测模型和/或沟通。

装有玻璃球的瓶子。图像由作者使用 Dall·E 创建。

基础知识

想象你有一个装满各种颜色玻璃球的大瓶子。问题是,你看不见瓶子里有多少种颜色的球。你想知道如果你不看瓶子,伸手进去,最有可能拿到哪种颜色的球。

蒙特卡罗方法解决强化学习问题

原文:towardsdatascience.com/monte-carlo-methods-for-solving-reinforcement-learning-problems-ff8389d46a3e?source=collection_archive---------7-----------------------#2024-09-04

剖析理查德·S·萨顿的《强化学习》并结合定制的 Python 实现,第 III 集

Oliver STowards Data Science Oliver S

·发布于 Towards Data Science ·18 分钟阅读·2024 年 9 月 4 日

--

我们继续深入探索 Sutton 关于强化学习的伟大著作 [1],并在此重点讨论蒙特卡罗(MC)方法。这些方法仅通过经验学习,即不需要任何环境模型,例如我们在上一篇文章中介绍的动态规划(DP)方法。

这非常有吸引力——因为通常模型是未知的,或者很难建模转移概率。考虑一下 二十一点 游戏:尽管我们完全理解游戏规则,通过 DP 方法解决它将非常繁琐——我们需要计算各种概率,例如,给定当前打出的牌,“黑杰克”出现的可能性有多大,是否有可能再发一张七点牌……通过 MC 方法,我们无需处理这些问题,只需玩游戏并从经验中学习。

图片由 Jannis Lucas 提供,来源于 Unsplash

由于不使用模型,蒙特卡罗方法(MC)是无偏的。它们在概念上简单且易于理解,但表现出较高的方差,且无法通过迭代方式(如自助法)解决。

如前所述,在这里我们将介绍这些方法,内容基于 Sutton 书中的第五章…

使用统计测试提升多变量 EDA 的稳健性

原文:towardsdatascience.com/more-robust-multivariate-eda-with-statistical-testing-d221fc145870?source=collection_archive---------5-----------------------#2024-04-16

图片由Jessica Ruscello提供,来源于Unsplash

通过结合相关的假设检验方法,提升在识别变量之间关系时的稳健性

Pararawendy IndarjoTowards Data Science Pararawendy Indarjo

·发表于Towards Data Science ·8 分钟阅读·2024 年 4 月 16 日

--

探索性数据分析(EDA)是数据科学家的基本技能。为了强调它的重要性,我认为 EDA 比机器学习建模技能更为重要。为什么?因为 EDA 在比机器学习建模更广泛的背景下都很有用。

遇到需要熟悉的新数据?做 EDA。想从数据中获得见解?做 EDA。即使你想创建一个机器学习模型?你仍然需要进行 EDA,以准备建模数据集。

如你所知,EDA 的一个步骤是多变量分析。此练习旨在识别数据集中变量之间是否存在关系。了解这些关系(如果存在)可能让我们根据数据的高层次背景采取适当的行动。

话虽如此,我们应该意识到由噪音引起的关系。在双变量可视化中观察到的两变量之间看似有趣的关系,可能是由于噪音而非真实的有意义关系。

幸运的是,统计学为我们提供了一个工具,帮助我们确定观察到的关系是否真正有意义,还是可能仅仅由于噪音造成的。是的,我们在谈论的是统计假设检验方法。

在这篇博文中,我将讨论如何使用统计测试方法,除了标准的双变量可视化之外,来提高多变量 EDA 练习的稳健性。

本文的其余部分将按以下方式组织。

  1. 数据集准备。

  2. 确定数值对数值变量关系:散点图 + 相关性检验。

  3. 确定数值对分类变量关系:KDE 图 + 单因素方差分析(ANOVA)。

  4. 确定分类对分类变量关系:计数图 + 卡方检验。

数据集准备

我们将使用流行的 MPG 数据集。它是一个公开可用的数据集,具有创作共用 4.0 许可证(允许为任何目的共享和改编数据集)。

数据集关于汽车燃油效率的细节(每加仑多少英里/MPG)及其他车辆属性。数据字典信息可以在其UCI 机器学习库页面找到。

我们加载数据集并去除缺失值(该数据集中缺失值很少,因此去除是可以接受的)。为了演示方法,我们还从mpg列创建了一个新的分类列,命名为efficiency。当 mpg 至少为 25 时,它的值为“yes”,否则为“no”。

# import libraries
import seaborn as sns
import pandas as pd

# load dataset via seaborn lib
df = sns.load_dataset("mpg")

# inspect missing values
df.isna().sum() #result: minimal, hence OK to remove

# remove missing values
df.dropna(inplace=True)

# create a new categorical column based on mpg column
df["efficiency"] = df["mpg"].apply(lambda x: "yes" if x >= 25 else "no")

# dataframe info
df.info()

# dataframe head
df.head()

df.info() 输出(图片来源:作者)

df.head() 输出(图片来源:作者)

关于在此数据集上进行多变量 EDA 的目标,我们自然希望了解哪些因素会影响汽车的燃油效率。为此,我们将回答以下问题:

  1. 哪些数值特征会影响 mpg 表现?

  2. 不同的原产地是否会导致 mpg 表现的差异?

  3. 不同的原产地是否会导致汽车效率的不同表现?

数值对数值关系

对于多变量 EDA 的第一个案例,让我们讨论如何识别两个数值变量之间的关系。在这种情况下,众所周知,我们可以使用散点图来直观检查变量之间是否存在关系。

如前所述,并非所有观察到的模式都是有意义的。在数值对数值的情况下,我们可以通过皮尔逊相关性检验来补充散点图。首先,我们计算绘制变量的皮尔逊相关系数。其次,通过计算其 p 值来判断获得的系数是否具有显著性。

后续步骤非常重要,它们是检验某个相关系数值是否足够大以被认为是有意义的理性检查(即,绘制的变量之间存在线性关系)。在小数据量的情况下尤其如此。例如,如果我们只有 10 个数据点,相关系数必须至少为 0.64 才能被认为是显著的(参考)!

在 Python 中,我们可以使用scipy库中的pearsonr函数来进行上述的相关性检验。

在以下代码中,我们绘制了每一对数值特征-mpg 列的散点图。作为标题,我们打印相关系数,并且如果系数显著(p 值 < 0.05),则加上条件双星号。

import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import pearsonr

# prepare variables to inspect
numeric_features = ['cylinders','displacement','horsepower',
                    'weight','acceleration','model_year']
target = 'mpg'

# Create a figure and axis
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(12, 6))

# Loop through the numerical columns and plot each scatter plot
for i, col in enumerate(numeric_features):
    # Calculate Pearson correlation coefficient
    corr_coeff, p_val = pearsonr(df[col],df[target])

    # Scatter plot using seaborn
    sns.scatterplot(data=df, x=col, y=target, ax=axes[i//3, i%3])

    # Set title with Pearson correlation coefficient
    # Print ** after the correlation if the correlation coefficient is significant
    axes[i//3, i%3].set_title(f'{col} vs {target} (Corr: {corr_coeff:.2f} {"**" if p_val < 0.05 else ""})')

plt.tight_layout()
plt.show()

数值特征与 mpg(图片来源:作者)

注意到所有的图表标题中都包含了双星号,表示相关性是显著的。因此,我们可以得出以下结论:

  1. 气缸数、排量、马力和重量与 mpg 有强烈的负相关。这意味着这些变量的较高值对应较低的燃油效率。

  2. 加速和车型年份与 mpg(每加仑英里数)有中等程度的正相关。这意味着较长的加速时间(较慢的汽车)和更新生产的汽车与更高的燃油效率相关。

数值与类别关系

接下来,我们将研究 mpg 配置文件是否因原产地而异。请注意,原产地是一个类别变量。因此,我们正在考虑数值与类别之间的关系。

KDE(核密度估计)图,也称为平滑版的直方图,可以用来可视化 mpg 分布,并按原产地值进行分类。

就统计检验而言,我们可以使用单因素方差分析(ANOVA)。我们要检验的假设是,不同汽车原产地之间的 mpg 平均值是否存在显著差异。

在 Python 中,我们可以使用 scipy 库中的 f_oneway 函数来执行单因素方差分析。

在以下代码中,我们创建了一个按不同原产地值分类的 mpg KDE 图。接下来,我们运行单因素方差分析(ANOVA)并在标题中显示 p 值。

import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import f_oneway

# Create a KDE plot with hue
sns.set(style="whitegrid")
ax = sns.kdeplot(data=df, x="mpg", hue="origin", fill=True)

# Calculate one-way ANOVA p-value
p_value = f_oneway(*[df[df['origin'] == cat]['mpg'] for cat in df['origin'].unique()])[1]

# Set title with one-way ANOVA p-value
ax.set_title(f'KDE Plot mpg by origin (One-way ANOVA p-value: {p_value:.4f})')

plt.show()

按原产地分类的 MPG 的 KDE 图(图片来源:作者)

上述图中的 p 值小于 0.05,表示具有显著性。从高层次来看,我们可以这样解释图表:通常情况下,美国制造的汽车比其他地方制造的汽车燃油效率低(这是因为与其他原产地相比,美国 mpg 分布的峰值位于左侧)。

类别与类别关系

最后,我们将评估一种情况,即我们有两个类别变量。考虑到我们的数据集,我们将看看不同的原产地是否会产生不同的汽车效率模型。

在这种情况下,带分类的计数图是适当的双变量可视化方法。我们将展示按效率标记(是/否)分类的每个原产地的汽车频率。

就统计检验方法而言,卡方检验是适合的选择。使用此检验,我们希望验证不同汽车原产地是否具有不同的高效与低效汽车分布。

在 Python 中,我们可以使用 scipy 库中的 chisquare 函数。然而,与之前的情况不同,我们必须首先准备数据。具体来说,我们需要计算每个原始-效率值组合的“预期频率”。

对于那些希望更深入了解期望频率概念和卡方检验整体机制的读者,我推荐阅读我关于这个主题的博客,博客链接如下。

[## 了解 A/B 测试引擎:卡方检验

理解这个概念并从零开始进行操作

medium.com](https://medium.com/bukalapak-data/meet-the-engine-of-a-b-testing-chi-square-test-30e8a8ab44c5?source=post_page-----d221fc145870--------------------------------)

执行上述数据准备的代码如下所示。

# create frequency table of each origin-efficient pair
chi_df = (
    df[['origin','efficiency']]
    .value_counts()
    .reset_index()
    .sort_values(['origin','efficiency'], ignore_index=True)
)

# calculate expected frequency for each pair
n = chi_df['count'].sum()

exp = []
for i in range(len(chi_df)):
  sum_row = chi_df.loc[chi_df['origin']==chi_df['origin'][i],'count'].sum()
  sum_col = chi_df.loc[chi_df['efficiency']==chi_df['efficiency'][i],'count'].sum()
  e = sum_row * sum_col / n
  exp.append(e)

chi_df['exp'] = exp
chi_df

chi_df 结果(图由作者提供)

最后,我们可以执行下面的代码来绘制按效率标志划分的汽车来源计数图。此外,我们使用chi_df来执行卡方检验并获取 p 值。

import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import chisquare

# Create a count plot with hue
sns.set(style="whitegrid")
ax = sns.countplot(data=df, x="origin", hue="efficiency", fill=True)

# Calculate chi-square p-value
p_value = chisquare(chi_df['count'], chi_df['exp'])[1]

# Set title with chi-square p-value
ax.set_title(f'Count Plot efficiency vs origin (chi2 p-value: {p_value:.4f})')

plt.show()

效率与来源的计数图(图由作者提供)

该图表显示了不同来源的高效汽车在分布上存在差异(p 值 < 0.05)。我们可以看到,美国汽车大多效率低下,而日本和欧洲汽车则呈现相反的模式

摘要

在这篇博客文章中,我们学习了如何通过适当的统计检验方法改进双变量可视化。这将通过筛选掉基于仅凭视觉检查双变量图表所能看到的噪声引起的关系,从而提高我们多变量 EDA 的稳健性。

我希望这篇文章能在你下一次 EDA 练习中有所帮助!总之,感谢阅读,欢迎在LinkedIn与我联系!👋

大多数数据质量项目在启动之前就已经失败。原因如下。

原文:towardsdatascience.com/most-data-quality-initiatives-fail-before-they-start-heres-why-e66123b7bb3b?source=collection_archive---------6-----------------------#2024-07-23

展示你的数据质量评分卡,我可以告诉你一年前是否能成功。

Barr MosesTowards Data Science Barr Moses

·发布于 Towards Data Science ·阅读时间 7 分钟·2024 年 7 月 23 日

--

图片由 Braden Collum 提供,来源于 Unsplash

每天我都在与那些准备投入大量时间和资源,但注定失败的数据质量项目的组织进行对话。

提供激励措施和关键绩效指标(KPIs)能够驱动良好的行为这一点并不新鲜。销售补偿计划的审查如此严格,以至于它们经常成为董事会会议的议题。如果我们也能对数据质量评分卡给予同样的关注,结果会怎样呢?

即使是在全盛时期,传统的数据质量评分卡——比如 Hadoop 时代的评分卡——也很少获得完全成功。我知道这一点,因为在创办 Monte Carlo 之前,我曾在运营副总裁的岗位上工作多年,致力于创建能推动信任和采纳的数据质量标准。

在过去的几年里,云计算和元数据管理的进展使得组织海量数据变得可能。

数据工程过程正开始向更加成熟、严格的工程学科标准发展。当然,人工智能有潜力简化这一切。

尽管这个问题可能永远无法完全解决,但我看到一些组织采取了最佳实践,这些实践成为了推动项目成功与另一轮启动会议之间的分水岭。

这是构建数据质量评分卡的 4 个关键教训:

  • 知道哪些数据重要

  • 衡量机器

  • 正确地运用奖励与惩罚

  • 自动化评估与发现

知道哪些数据是重要的

任何与数据相关的计划最容易失败的方式就是假设所有数据的价值是相等的。而确定哪些数据重要的最佳方法就是与业务方沟通。

Red Ventures 的 Brendon Beidel 阐明了一个好的起点

“我会问:

  • 你如何使用这个表格?

  • 你什么时候查看这些数据?你什么时候报告这些数据?这些数据是否需要实时更新,按小时、每日更新?

  • 这有什么意义?

  • 如果这些数据延迟了,谁需要被通知?”

如果你在一个庞大的组织中工作,且员工遍布全球,可能这比说起来容易做起来难。

在这些情况下,我的建议是从你最关键的业务数据单元开始(如果你不知道这些,我帮不上忙!)。开始讨论需求和优先事项。

记住:先验证概念,再扩展。你会惊讶于有多少人反其道而行。

衡量机器

这种努力的一大挑战,简而言之,就是数据质量抵制标准化。质量是,并且应该是,依赖于使用场景的。

数据质量的六个维度是任何数据质量评分卡中的关键部分,也是一个重要的起点,但对许多团队而言,这只是开始——每个数据产品都是不同的。

例如,财务报告可能需要非常准确,并在时效性上有一定的宽容度,而机器学习模型则可能完全相反。

从实施角度来看,这意味着衡量数据质量通常是极度分散的。数据质量是按表格逐一衡量的,不同的分析师或数据管理员使用完全不同的数据质量规则,并赋予不同的权重。

这在某种程度上是有道理的,但很多东西在翻译过程中丢失了。

数据是多用途的,并在不同的使用场景中共享。不仅一个人的“黄色”质量评分是另一个人的“绿色”,而且数据消费者通常也很难理解“黄色”评分意味着什么,或者它是如何被评定的。他们也经常忽略由红色数据表提供数据给绿色数据表的影响(你懂的,垃圾进,垃圾出……)。

“黄色”评分卡的意义是什么?照片由 Keiron Crasktellanos 提供,来自 Unsplash

显示违反规则的数量当然很重要,但你还需要:

  • 尽可能将其放在上下文中,

  • 拥有一个聚合的端到端数据产品视图,

  • 投资一些强大的无代码数据分析工具,并

  • 意识到这还不够。

那么,你还需要什么?你需要衡量机器。

换句话说,数据生产和交付中通常导致高质量的组成部分。这些部分更容易标准化,也更容易跨业务单元和团队理解。

Airbnb Midas是一个知名的内部数据质量评分和认证项目,这也理所当然。他们在这个概念上投入了大量精力。他们衡量数据准确性,但可靠性、管理和可用性实际上占总评分的 60%。

许多数据团队仍在制定自己标准的过程中,但我们发现与数据健康高度相关的组成部分包括:

  • 前面提到的六个数据质量维度(有效性、完整性、一致性、及时性、唯一性、准确性)。

可用性与管理

  • 文档:对于数据资产、其用途以及过去的事件,具有一定程度的语义含义。一家在线旅游搜索公司会根据数据资产如何及在哪里被编目,以及其元数据的完整性,在其六个类别中的两个类别上进行评分。

  • 血统:在系统间能够追踪到数据在字段级别的来源。

  • 使用情况:表格接收到的查询次数以及具有下游依赖关系的数据产品数量。这可以是一个“关键资产评分”,并具有飞轮效应。你将可靠性工作重点放在最常用的部分,而人们会信任那些流行的部分。

系统可靠性

  • 监控:通常,如果一个数据产品不仅在最后一公里的表格上有强大的覆盖范围,而且在上游也有覆盖,这表明它是一个经过精心策划的资产。

  • 新鲜度:数据的新鲜度要求会根据数据产品类型有所不同,但这是一个表格级别的指标,可以识别并揭示偏离常规的情况。许多组织,如 罗氏诊断,会为其数据产品设定特定的新鲜度服务水平协议(SLA)并衡量遵守的程度。

  • 数据量:表格接收到的相对稳定的行数通常是数据管道和数据交付系统运作良好的标志。

  • 架构:至少你希望消费者能够看到架构变更。对于最关键的管道,你理想的做法是有一定程度的架构执行或数据契约,这样你就能知道当源数据发生变化时,是否会破坏下游的资产。

操作响应:

  • 所有权:一个资产是否有所有者?如果有技术和业务所有者,则更为加分。

  • 通知渠道与沟通: 数据交付是一个复杂的过程,涉及从数据采集到汇总再到消费的多个交接过程。在此基础上,理想情况下,你会有多个团队使用同一数据资产(否则你的数据网格更像是一个孤岛)。在这种环境中,拥有一个中心化的沟通渠道来突出和讨论变更和事件,是确保可靠数据产品的唯一方式。

  • 平均修复时间: 可以说,衡量数据集可信度最重要的指标是支持团队回应并修复出现的事件的速度。坏数据是不可避免的,优秀的事件响应是有意为之的。

把你的胡萝卜和棒子调整好

激励生产者和消费者提供高质量数据。照片由Jonathan Pielmayer拍摄,图片来源于Unsplash

“耶,又一套我们必须遵守的流程!”… 这句话从来没有人说过。

记住,衡量数据健康的目的是为了衡量数据健康,而不是单纯地为了数据健康本身。如Clark 在 Airbnb 所说,目的是“推动对生产和使用高质量数据的偏好”。

我看到的最佳实践是在数据入驻平台时,设置一套最低要求(棒)和在每个级别认证时设置更严格的要求(胡萝卜)。

认证起到激励作用,因为生产者确实希望消费者使用他们的数据,而消费者会迅速识别并对高可靠性的数据产生偏好。

自动化评估与发现

数据管理中几乎没有什么能在没有一定程度的自动化和自助服务能力的情况下成功。Airbnb 丢弃了任何 1) 无法立即理解的评分标准,和 2) 无法自动衡量的评分标准。

你的组织也必须做到这一点。即使这是史上最好的评分标准,如果你没有一套能够自动收集和展示数据的解决方案,那它就必须被丢进垃圾桶。

图片由作者提供。

我看到的最常见方式是通过data observability and quality solutions,和数据目录。例如,罗氏就是这样做的,并且在创建、发布和治理可信数据产品时,增加了访问管理。

来源

当然,也可以通过手动将多个数据系统中的元数据拼接成一个自制的发现门户来完成这项工作,但一定要注意维护开销。

量化的事物得以管理

数据团队已在现代数据和人工智能平台上进行了大量投资。但为了最大化这一投资,组织——无论是数据生产者还是消费者——必须完全采纳并信任所提供的数据。

归根结底,什么被衡量,什么就能被管理。这不正是最重要的吗?

我测试了前沿的多模态 LLM 在图表解读技能上的表现

原文:towardsdatascience.com/mulitmodal-llms-interpreting-charts-b212f5c0aa1f?source=collection_archive---------10-----------------------#2024-11-05

多模态 LLM 能准确推断基础图表吗?

Thuwarakesh MurallieTowards Data Science Thuwarakesh Murallie

·发表于Towards Data Science ·31 分钟阅读·2024 年 11 月 5 日

--

图片由作者使用 Flux 1.1 [Pro]制作

多模态 LLM(MLLM)承诺它们能够解读图像上的任何内容。对于大多数情况来说,这是真的,例如图像描述和物体检测。

但是它能合理且准确地理解图表上的数据吗?

如果你真的想开发一个应用程序,让它在你把相机对准汽车仪表盘时告诉你该做什么,那么 LLM 在图表解读技能上的表现应该是出色的。

当然,多模态 LLM 能够叙述图表上的内容,但理解数据并回答复杂的用户问题仍然具有挑战性。

我想知道这有多难。

我为 LLM 设置了八个挑战任务。每个任务都有一个基础图表以及一个问题,LLM 需要基于图表给出答案。我们知道正确答案,因为是我们创建了这些数据,但 LLM 只需利用给定的可视化信息推理出答案。

截至本文写作时,根据我的理解,市场上有五家主要的多模态大语言模型(LLM)提供商:OpenAI(GPT4o),Meta 的 Llama 3.2(11B 和 90B 模型),Mistral 及其全新的 Pixtral 12B,Cloude 3.5 Sonnet,以及 Google 的 Gemini 1.5。

多代理即服务——一位资深工程师的概述

原文:towardsdatascience.com/multi-agent-as-a-service-a-senior-engineers-overview-fc759f5bbcfa?source=collection_archive---------1-----------------------#2024-08-14

Saman (Sam) RajaeiTowards Data Science Saman (Sam) Rajaei

·发布于 Towards Data Science ·8 分钟阅读·2024 年 8 月 14 日

--

有关 AI 代理 的讨论很多——这些关键的自包含单元能够根据特定指令和上下文理解自主执行任务。事实上,这个话题的讨论几乎已经和 LLM 一样广泛。在本文中,我将探讨 AI 代理,尤其是从负责在生产系统中处理 AI 代理的首席工程师、架构师和站点可靠性工程师(SRE)的角度,探讨 多代理即服务 的概念。

上下文:AI 代理能解决哪些问题?

AI 代理擅长处理那些有利于人性化交互的任务:

  1. 电子商务: 基于 LLM 的 RAG 或 Text-to-SQL 等技术驱动的代理,根据公司政策回应用户查询,提供准确的答案,从而创造出更加个性化的购物体验和客户旅程,革命性地改变电子商务。

  2. 客户服务: 这是另一个理想的应用场景。我们中的许多人都经历过为了简单的查询(如订单状态更新)而长时间等待与客服代表沟通的情况。一些初创公司——例如 Decagon——正在通过 AI 代理在解决这些低效问题方面取得进展。

  3. 个性化产品和内容创作: 这一点的一个典型例子是 Wix——对于低代码或无代码的网站建设,Wix 开发了一款聊天机器人,通过互动问答环节,根据客户的描述和需求为其创建初步网站

“人类设定目标,但 AI 代理独立选择执行这些目标所需的最佳行动。”

总体来说,基于 LLM 的代理在模拟自然人类对话和简单业务工作流方面表现出色,通常能够产生既有效又令人印象深刻的结果。

工程师的视角:AI 代理与企业生产环境

考虑到上述提到的好处,你是否曾经想过 AI 代理在企业生产环境中的运作方式?哪些架构模式和基础设施组件最能支持它们?当问题不可避免地发生,代理出现幻觉、崩溃或(或许更糟糕的是)在执行关键任务时进行错误推理/规划时,我们该怎么办?

作为高级工程师,我们需要仔细考虑上述内容。此外,我们还必须问一个更重要的问题:我们如何定义一个多代理平台成功部署的标准是什么?

为了回答这个问题,让我们借用另一个软件工程领域的概念:来自可靠性工程的服务级目标(SLOs)。SLOs 是衡量服务性能和可靠性的关键组成部分。简单来说,SLOs 定义了“成功”测量与“所有”测量之间的可接受比例,并评估它们对用户旅程的影响。这些目标帮助我们确定代理及其所支持的更广泛工作流所需和预期的服务水平。

那么,SLOs 与我们的 AI 代理讨论有何关系?

以简化的视角来看,让我们考虑两个重要的目标——“可用性”和“准确性”——针对代理,并识别出一些更细化的 SLOs,它们有助于实现这些目标:

  1. 可用性: 这指的是代理或平台接收到的请求中,部分成功响应(例如 HTTP 200 状态码)的比例。从历史上看,底层服务器的正常运行时间和 ping 成功率(即时间性度量)是可用性的关键相关指标。但随着微服务的兴起,名义上的正常运行时间变得不那么重要。现代系统更关注用户请求的成功响应与失败响应的数量,以此作为更准确的可用性代理。其他相关指标可以被认为是延迟吞吐量

  2. 准确性: 另一方面,准确性不仅仅是代理返回客户端响应的速度和一致性,更重要的是它们在没有人工验证的情况下,能从业务角度上多么正确地执行任务并返回数据。传统系统也会跟踪类似的 SLOs,如数据的正确性和质量。

上述两个目标的衡量通常通过在运行时提交内部应用程序指标来进行,这些指标可以是在设定的时间间隔内提交(例如每 10 分钟),或响应事件(用户请求、上游调用等)。例如,合成探测可以用于模拟用户请求、触发相关事件并监控相关数据。这里需要探索的关键 思想是:传统系统在很大程度上是确定性的,因此通常更容易对其进行仪表化、探测和评估。另一方面,在我们这个美丽而非确定性的生成 AI 代理世界中,情况未必如此。

注意:本文的重点更多放在我们两个目标中的前者 —— 可用性。这包括确定接受标准,以设置基准的云/环境稳定性,帮助代理响应用户查询。要深入了解* 准确性(即为代理定义合理的任务范围,优化少量样本方法的性能和评估框架),* 这篇博客文章 是一个很好的入门文章。

[## 规划代理

我们的“循环系列”中的第四期,讨论了规划对于一个代理意味着什么,以及如何……

blog.langchain.dev](https://blog.langchain.dev/planning-for-agents/?source=post_page-----fc759f5bbcfa--------------------------------)

现在,回到工程师在部署代理时需要做对的事情,以确保基础设施的准备情况。为了实现我们的目标服务水平目标(SLO)并提供可靠且安全的平台,资深工程师始终考虑以下几个因素:

  1. 可扩展性:当请求数量增加(有时是突然增加时),系统能否有效处理?

  2. 成本效益:大型语言模型(LLM)的使用成本较高,那么我们如何监控和控制成本呢?

  3. 高可用性:我们如何保持系统始终可用并响应客户?代理是否能够自我修复并从错误/崩溃中恢复?

  4. 安全性:我们如何确保数据在存储和传输过程中是安全的,如何进行安全审计、漏洞评估等?

  5. 合规与监管:这是 AI 的一个重要话题,我们必须遵守哪些相关的数据隐私法规和其他行业特定标准?

  6. 可观测性:我们如何获得 AI 代理活动、健康状况和资源利用率的实时可见性,以便在问题影响用户体验之前识别和解决它们?

听起来很熟悉吗?这些与现代 web 应用程序、微服务模式和云基础设施所要解决的挑战类似。

那么,接下来呢?我们提出一个 AI 智能体开发和维护框架,遵循多年来在多个工程和软件领域发展出的最佳实践。

多智能体即服务(MAaaS)

这一次,让我们借鉴一些 基于云的应用最佳实践,重新定义智能体在生产系统中的设计方式:

  • 清晰的边界上下文: 每个智能体应具有明确定义的小范围职责,并明确功能边界。这种模块化方法确保了智能体更加精准,便于管理,并能够独立扩展。

  • RESTful 和异步服务间通信: 使用 RESTful API 在用户和智能体之间进行通信,并利用消息代理进行异步通信。这种方式解耦了智能体,提升了可扩展性和容错性。

  • 每个智能体的独立数据存储: 每个智能体应拥有自己的数据存储,以确保数据封装并减少依赖。在必要时,利用分布式数据存储解决方案来支持可扩展性。

  • 容器化与编排: 使用容器(例如 Docker)来包装和部署智能体,使其在不同环境中保持一致,简化部署和扩展。采用容器编排平台(如 Kubernetes)来管理智能体服务的部署、扩展和运营生命周期。

  • 测试和 CI/CD: 实施自动化测试(单元测试、集成测试、契约测试和端到端测试),确保智能体的可靠变更管理。使用 CI 工具,在代码更改提交时自动构建和测试智能体。建立 CD 流水线,实现无缝的生产环境变更部署,减少停机时间并确保快速迭代周期。

  • 可观察性: 实施强大的可观察性工具,如度量、追踪和日志记录,监控智能体及其支持基础设施,以实时展示平台的可靠性(追踪尤其重要,如果某个用户请求经过多个智能体)。计算并跟踪智能体的服务水平目标(SLO)和错误预算,以及聚合的请求流。通过合成探测和高效的警报机制,在智能体健康问题对最终用户造成广泛影响之前,及时发现警告和故障。

通过应用这些原则,我们可以为 AI 智能体创建一个稳健的框架,将这一概念转变为“多智能体即服务”(MAaaS)。这种方法借鉴了基于云的应用最佳实践,重新定义了智能体的设计、部署和管理方式。

图片来源:作者

智能体在业务操作中扮演着至关重要的角色;然而,它并非孤立运行。强大的基础设施支持着它,确保其满足生产期望,关键组件包括:

  • 面向服务的架构: 将代理设计为可以轻松集成到现有系统中的服务。

  • API 网关: 使用 API 网关管理和保护客户端与代理之间的流量。

  • 弹性基础设施: 利用能够根据需求弹性扩展或缩减资源的云基础设施。

  • 托管服务: 利用托管服务来管理数据库、向量存储、消息传递和机器学习,减少运营负担。

  • 集中监控: 使用集中监控解决方案(例如 CloudWatch、Prometheus、Grafana)来跟踪代理的健康状况和性能。

为了突出这一点,我们将演示一个简单的多代理系统:一个辩论平台。

示例:多代理辩论系统

我们设计了一个多代理辩论系统来展示 MAaaS 的实际应用。辩论的主题是AI 对就业市场的影响。设置中有三个代理:

  • 队伍 A,支持 AI 对工作的益处

  • 队伍 B,持反对意见

  • 主持人管理辩论,在八轮之后或当讨论变得冗余时结束辩论。

专注于系统架构,我们使用PhiData创建代理,并通过 AWS 弹性 Kubernetes 服务(EKS)进行部署,以确保高可用性。代理活动通过 AWS CloudWatch 进行监控,EKS 的服务发现确保代理之间无缝通信。重要的是,对话历史存储在数据库中,这样任何备份代理在发生故障时都可以不中断地继续对话。这种弹性通过消息队列得到增强,消息队列确保在完全消费消息之前不会处理它们,从而确保数据完整性。为了保持对话流畅,每个代理目前仅限于一个副本,尽管 Kubernetes 会确保在 Pod 宕机时始终保持所需的状态。

作者提供的图片

为了让用户能够在本地尝试系统,我们创建了一个 MiniKube 部署 YAML 文件。在这个简化版本中,我们去掉了 postgres 数据库。相反,每个代理将临时将其对话历史存储在内存中。这个调整使系统更加轻量化,并且便于本地部署,同时仍展示了核心功能。你需要先在系统上安装MiniKubeOllamakubectl

将上述内容保存为名为deploy.yml的文件,然后运行:

$ minikube start
$ ollama run llama3
$ kubectl apply -f deploy.yml

启动辩论(MiniKube 在基于 Linux 的系统与 Windows 系统上表现略有不同):

$ kubectl get pods
$ kubectl exec <host-pod-name> -- curl -X GET 'http://localhost:8080/agent/start_debate'

获取辩论历史:

$ kubectl exec <host-pod-name> -- curl -X GET 'http://localhost:8080/agent/chat-history'

拆除资源:

$ kubectl delete -f .\minikube-deploy.yml

代理们展开了精彩的辩论(请参见下面附录中的辩论输出)。

结论

对多智能体系统的兴趣为创新和效率开辟了诸多可能性。通过利用云原生原则和最佳实践,我们可以创建可扩展、具有成本效益、安全且高度可用的多智能体系统。MAaaS 范式不仅与现代软件工程原则一致,而且为更复杂且适合生产的 AI 应用铺平了道路。在我们继续探索和完善这些概念的过程中,多智能体系统在各个行业中变革的潜力变得越来越有前景。

注意:本文由 Sam Rajaei Guanyi Li* 合作撰写。*

附录:辩论输出

感谢您的关注,下次再见!

使用 Hugging Face 代码代理的多代理 RAG 系统

原文:towardsdatascience.com/multi-agentic-rag-with-hugging-face-code-agents-005822122930?source=collection_archive---------0-----------------------#2024-12-31

使用 Qwen2.5–7B-Instruct 驱动的代码代理创建一个本地的、开源的多代理 RAG 系统

Gabriele Sgroi, 博士Towards Data Science Gabriele Sgroi, 博士

·发表于 Towards Data Science ·阅读时长:61 分钟 ·4 天前

--

图片来自 Jaredd CraigUnsplash

大语言模型(LLM)展示了令人印象深刻的能力,并且随着每一代新模型的发布,它们仍在持续改进。像聊天机器人和摘要生成等应用可以直接利用 LLM 的语言能力,因为它们只需要生成文本输出,这是它们的自然环境。大语言模型还展示了理解和解决复杂任务的强大能力,但只要它们的解决方案停留在“纸面上”,即仅仅是纯文本形式,它们就需要外部用户代表它们采取行动,并反馈建议动作的结果。代理系统通过让模型在其环境中进行操作来解决这个问题,通常通过一组可以执行特定操作的工具。通过这种方式,LLM 可以通过反复试验与环境互动,逐步找到解决方案。

一种有趣的情况是,当一个大语言模型(LLM)代理可以访问的工具本身就是代理时:这就是多代理系统的核心概念。多代理系统通过将任务分配并委托给专门的模型,并像拼图一样将它们的输出组合在一起,从而解决任务。实现这种系统的常见方法是使用管理代理来协调和组织其他代理的工作流程。

代理系统,特别是多代理系统,需要一个强大的 LLM 作为支撑才能正常运作,因为底层模型需要能够理解各种工具的目的和适用性,并将原始问题分解为每个工具可以处理的子问题。因此,像 ChatGPT 或 Anthropic 的 Claude 这样的专有模型通常是代理系统的默认解决方案。幸运的是,开源的 LLM 在性能上持续取得了巨大进步,甚至在某些情况下,它们与专有模型的表现相媲美。更有趣的是,现在适度规模的开源 LLM 已经能够执行一些几年前还难以想象的复杂任务。

在这篇博客文章中,我将展示如何让一个可以运行在消费级硬件上的“小型”大语言模型(LLM)足够强大,以支持一个多代理系统并取得良好的结果。特别地,我将提供一个教程,教你如何使用Qwen2.5–7B-Instruct来创建一个多代理的 RAG 系统。你可以在以下的GitHub 仓库找到代码实现,以及一个说明性的Colab 笔记本

在深入系统架构的细节之前,我将回顾一些关于 LLM 代理的基本概念,这些概念有助于更好地理解这个框架。

ReAct

ReAct,提出于ReAct: Synergizing Reasoning and Acting in Language Models,是一个用于构建 LLM 代理的流行框架。该方法的主要思想是将“思维链”提示的有效性融入到代理框架中。ReACT 包含交替进行的推理和行动步骤:大语言模型在执行动作之前,首先被提示提供一系列思维过程。通过这种方式,模型可以创建动态的推理痕迹来引导行动,并在与环境互动的过程中更新高层次的计划。这使得解决给定任务的过程能够采取迭代和增量的方式。在实践中,ReAct 代理的工作流程由思维、行动和观察三个步骤组成:模型在思维步骤中产生关于一般计划和具体工具使用的推理,然后在行动步骤中调用相关工具,最后在观察步骤中接收来自环境的反馈。

以下是 ReACT 框架的示例。

ReACT、思维链和仅行动框架在问答任务中的比较。图片来源于ReAct: Synergizing Reasoning and Acting in Language Models

代码代理

代码代理是一种特定类型的 LLM 代理,使用可执行的 Python 代码与环境进行交互。它们基于论文可执行代码操作能激发更好的 LLM 代理中提出的 CodeAct 框架。CodeAct 与 ReAct 框架非常相似,不同之处在于每个操作都由任意可执行代码组成,可以执行多个操作。手工制作的工具作为常规 Python 函数提供给代理,代理可以在代码中调用这些函数。

代码代理相较于使用 JSON 或其他文本格式执行操作的传统代理具有一系列独特的优势:

  • 它们可以结合现有的软件包和手工制作的特定任务工具来使用。

  • 它们可以通过使用在发生错误后返回的错误信息来进行自我调试。

  • LLMs 熟悉编写代码,因为代码通常广泛出现在它们的预训练数据中,使其成为编写操作的更自然格式。

  • 代码自然允许存储中间结果并在一次操作中组合多个操作,而 JSON 或其他文本格式可能需要多次操作才能完成相同的任务。

基于这些原因,代码代理能够比使用 JSON 或其他文本格式执行操作的代理提供更好的性能和更快的执行速度。

代码代理与使用 JSON 或文本作为操作的代理对比。图片来源:可执行代码操作能激发更好的 LLM 代理

以下是原始论文中的一个具体例子,展示了代码代理如何在解决某些任务时需要更少的操作。

代码代理与使用 JSON/文本操作格式的代理对比。代码代理可以在一次操作中执行多个操作。图片来源:可执行代码操作能激发更好的 LLM 代理。[RIVEDERE]

Hugging Face 的 transformers 库提供了构建代理,特别是代码代理的有用模块。Hugging Face 的 transformer 代理框架将清晰性和模块化作为核心设计原则。这些原则在构建代理系统时尤为重要:由于工作流的复杂性,必须对架构中的所有相互关联的部分进行控制。这些设计选择使得 Hugging Face 代理成为构建自定义和灵活代理系统的优秀工具。当使用开源模型为代理引擎提供支持时,Hugging Face 代理框架还有一个额外的优势,即可以轻松访问 Hugging Face 生态系统中现有的模型和工具。

Hugging Face 代码代理还解决了不安全代码执行的问题。事实上,让 LLM 不加限制地生成代码可能会带来严重风险,因为它可能执行不希望发生的操作。例如,一个幻觉可能导致代理删除重要文件。为了减轻这种风险,Hugging Face 代码代理实现采用了从头开始的安全代码执行方法:代码解释器只能执行明确授权的操作。这与通常的自上而下的范式不同,后者从一个完全功能的 Python 解释器开始,然后禁止可能危险的操作。Hugging Face 的实现包括一个可以执行的安全授权函数列表,并提供了一个可以导入的安全模块列表。除非已被用户预先授权,否则其他任何操作都不可执行。你可以在他们的博客文章中阅读更多关于 Hugging Face(代码)代理的内容:

代理 RAG

检索增强生成(RAG)已成为涉及大型语言模型(LLM)信息检索任务的事实标准。它可以帮助保持 LLM 的信息更新,提供特定信息的访问权限,并减少幻觉(hallucinations)。它还可以通过返回模型用来生成答案的来源,增强人类可解释性和监督性。传统的 RAG 工作流程包括基于与用户查询语义相似度的检索过程以及通过检索到的信息增强模型上下文,但这种方法无法解决一些特定任务。一些不适合传统 RAG 的情况包括需要与信息来源互动的任务、需要多个信息片段才能回答的查询以及需要复杂操作才能与来源中的实际信息连接的复杂查询。

对于传统 RAG 系统的一个具体挑战性示例是多跳问答(MHQA)。它涉及提取和组合多个信息片段,可能需要多次迭代推理过程,以处理提取的信息和仍然缺失的部分。例如,如果模型被问到“桦木胶合板在乙醇中会浮吗?”,即使用于 RAG 的来源包含关于这两种材料密度的信息,如果这两条信息没有直接关联,标准的 RAG 框架也可能失败。

一种常见的方式来增强 RAG 以避免上述不足,是使用代理系统。一个 LLM 代理可以将原始查询拆解成一系列子查询,然后使用语义搜索作为工具来检索这些生成的子查询的段落,随着更多信息的收集,它可以改变和调整计划。它可以自主决定是否已经收集到足够的信息来回答每个查询,或者是否应该继续搜索。代理 RAG 框架还可以通过扩展为多代理系统进一步增强,在该系统中,每个代理都有自己的任务和职责。这使得,例如,可以将高层任务规划与与文档源的交互分开。在下一节中,我将描述这种系统的实际实现。

带有代码代理的多代理 RAG

在这一节中,我将讨论我用来基于代码代理实现多代理 RAG 系统的一般架构选择,该系统遵循 ReAct 框架。你可以在以下GitHub 仓库中找到完整代码实现的剩余细节。

多代理系统的目标是通过在维基百科上搜索必要的信息来回答一个问题。它由 3 个代理组成:

  • 一个管理代理,其工作是将任务拆解成子任务,并利用它们的输出提供最终答案。

  • 一个维基百科搜索代理,用于在维基百科上找到相关页面并将从中提取的信息结合起来。

  • 一个页面搜索代理,用于从提供的维基百科页面中检索并总结与给定查询相关的信息。

这三个代理以分层方式组织:每个代理可以将位于层次结构下方的代理作为工具使用。特别是,管理代理可以调用维基百科搜索代理来查找查询信息,而后者又可以使用页面搜索代理从维基百科页面中提取特定信息。

以下是架构图,指定了每个代理可以调用的手工工具(包括包装其他代理的工具)。请注意,由于代码代理通过代码执行来操作,这些并不是它们唯一可以使用的工具,因为任何本地 Python 操作和函数(只要获得授权)也可以使用。

显示代理和手工工具的架构图。图片由作者提供。

让我们深入探讨架构中涉及的代理的工作细节。

管理代理

这是顶层代理,它接收用户的提问并负责返回答案。它可以通过向维基百科搜索代理发送查询并接收搜索的最终结果来使用该工具。它的目的是通过将用户问题拆解成一系列子查询并将搜索结果汇总,收集来自维基百科的必要信息。

以下是为此代理使用的系统提示符。它建立在默认的 Hugging Face 提示符模板之上。请注意,提示符中提供的示例遵循了支持此代理的模型的聊天模板,在这种情况下是Qwen2.5–7B-Instruct

You are an expert assistant who can find answer on the internet using code blobs and tools. To do so, you have been given access to a list of tools: these tools are basically Python functions which you can call with code.
You will be given the task of answering a user question and you should answer it by retrieving the necessary information from Wikipedia. Use and trust only the information you retrieved, don't make up false facts.
To help you, you have been given access to a search agent you can use as a tool. You can use the search agent to find information on Wikipedia. Break down the task into smaller sub-tasks and use the search agent to find the necessary information for each sub-task.
To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences.
At each step, in the 'Thought:' sequence, you should first explain your reasoning towards solving the task and the tools that you want to use.
Then in the 'Code:' sequence, you should write the code in simple Python. The code sequence must end with '<end_action>' sequence.
During each intermediate step, you can use 'print()' to save whatever important information you will then need. These print outputs will be provided back to you by the user in the 'Observation:' field, which will be available as input for the next steps. Always print the output of tools, don't process it or try to extract information before inspecting it.
If an error rise while executing the code, it will be shown in the 'Observation:' field. In that case, fix the code and try again.

In the end you have to return a final answer using the `final_answer` tool.

Here are a few notional examples:
---
<|im_start|>user
Task: When was the capital of Italy founded?<|im_end|>
<|im_start|>assistant
Thought: Let's break up the task: I first need to find the capital of Italy and then look at its foundation date. I will use the tool `wikipedia_search_agent` to get the capital of Italy. Code:
```py

result = wikipedia_search_agent("意大利首都")

print("意大利的首都:", result)

```py<end_action><|im_end|>
<|im_start|>user
[OUTPUT OF STEP 0] -> Observation:
Capital of Italy:According to the information extracted from the Wikipedia page 'Rome', the capital of Italy is Rome.<|im_end|>
<|im_start|>assistant
Thought: Now that I know that the capital of Italy is Rome, I can use the `wikipedia_search_agent` tool to look for its foundation date.
Code:
```py

result = wikipedia_search_agent("罗马创立日期")

print("罗马创立时间:", result)

```py<end_action><|im_end|>
<|im_start|>user
[OUTPUT OF STEP 1] -> Observation:
Rome foundation: According to the information from the Wikipedia page 'Natale di Roma', the traditional foundation date of Rome is April 21, 753 BC.<|im_end|>
<|im_start|>assistant
Thought: Now that I have retrieved the relevant information, I can use the `final_answer` tool to return the answer.
Code:
```py

final_answer("根据传说,罗马是在公元前 753 年 4 月 21 日建立的,但考古证据表明其发展可以追溯到青铜时代。")

```py<end_action><|im_end|>
---
<|im_start|>user
Task: "What's the difference in population between Shanghai and New York?"<|im_end|>
<|im_start|>assistant
Thought: I need to get the populations for both cities and compare them: I will use the tool `search_agent` to get the population of both cities.
Code:
```py

population_guangzhou_info = wikipedia_search_agent("纽约市人口")

population_shanghai_info = wikipedia_search_agent("上海人口")

print("广州人口:", population_guangzhou)

print("上海人口:", population_shanghai)

```py<end_action><|im_end|>
<|im_start|>user
[OUTPUT OF STEP 0] -> Observation:
Population Guangzhou: The population of New York City is approximately 8,258,035 as of 2023.
Population Shanghai: According to the information extracted from the Wikipedia page 'Shanghai', the population of the city proper is around 24.87 million inhabitants in 2023.<|im_end|>
<|im_start|>assistant
Thought: Now I know both the population of Shanghai (24.87 million) and of New York City (8.25 million), I will calculate the difference and return the result.
Code:
```py

population_difference = 24.87*1e6 - 8.25*1e6

answer=f"上海和纽约之间的人口差异是{population_difference}人。"

final_answer(answer)

```py<end_action><|im_end|>
---

On top of performing computations in the Python code snippets that you create, you have access to those tools (and no other tool):

<<tool_descriptions>>

<<managed_agents_descriptions>>

You can use imports in your code, but exclusively from the following list of modules: <<authorized_imports>>.  Do not try to import other modules or else you will get an error.
Now start and solve the task!

维基百科搜索代理

这个代理向经理代理报告,它从经理代理接收查询任务,并负责返回从维基百科检索到的信息。它可以访问两个工具:

  • 一个维基百科搜索工具,使用wikipedia package中的内置搜索功能。它接收一个查询并返回维基百科页面及其摘要的列表。

  • 一个页面搜索代理,从特定的维基百科页面检索查询信息。

这个代理收集信息以回答查询,将其分解为更小的子查询,并在需要时结合多个页面的信息。这是通过使用维基百科包的搜索工具来识别可能包含必要信息的页面来实现的:代理可以使用报告的页面摘要或调用页面搜索代理,从特定页面提取更多信息。收集到足够的数据后,它会返回答案给经理代理。

系统提示符再次是对 Hugging Face 默认提示符的轻微修改,后面跟随了一些具体示例,遵循模型的聊天模板。

You are an expert assistant that retrieves information from Wikipedia using code blobs and tools. To do so, you have been given access to a list of tools: these tools are basically Python functions which you can call with code.
You will be given a general query, your task will be of retrieving and summarising information that is relevant to the query from multiple passages retrieved from the given Wikipedia page. Use and trust only the information you retrieved, don't make up false facts. Try to summarize the information in a few sentences.
To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences.
At each step, in the 'Thought:' sequence, you should first explain your reasoning towards solving the task and the tools that you want to use.
Then in the 'Code:' sequence, you should write the code in simple Python. The code sequence must end with '<end_action>' sequence.
During each intermediate step, you can use 'print()' to save whatever important information you will then need. These print outputs will be provided back to you by the user in the 'Observation:' field, which will be available as input for the next steps. Always print the output of tools, don't process it or try to extract information before inspecting it.
If an error rise while executing the code, it will be shown in the 'Observation:' field. In that case, fix the code and try again.

In the end you have to return a final answer using the `final_answer` tool.

Here are a few notional examples:
---
<|im_start|>user
Task: Retrieve information about the query:"What's the capital of France?" from the Wikipedia page "France".<|im_end|>
<|im_start|>assistant
Thought: I need to find the capital of France. I will use the tool `retrieve_passages` to get the capital of France from the Wikipedia page.
Code:
```py

result = retrieve_passages("法国首都")

print("法国的首都:", result)

```py<end_action><|im_end|>
<|im_start|>user
[OUTPUT OF STEP 0] -> Observation:
Retrieved passages for query "France capital":
Passage 0: ... population of nearly 68.4 million as of January 2024\. France is a semi-presidential republic with its capital in Paris, the ...
Passage 1: ... France, officially the French Republic, is a country located primarily in Western Europe. Its overseas regions and territories ...
Passage 2: ... The vast majority of France's territory and population is situated in Western Europe and is called Metropolitan France. It is ...
Passage 3: ... France is a highly urbanised country, with its largest cities (in terms of metropolitan area population in 2021) being Paris ...
Passage 4: ... === Government ===\nFrance.fr – official French tourism site (in English)...<|im_end|>
<|im_start|>assistant
Thought: Now that I know that the capital of France is Paris, I can use the `final_answer` tool to return the answer.
Code:
```py

final_answer("法国的首都是巴黎。")

```py<end_action><|im_end|>
---
<|im_start|>user
Task: Retrieve information about the query:"Tallest mountain in the World" from the Wikipedia page "List of highest mountains on Earth"<|im_end|>
<|im_start|>assistant
Thought: I need to find the tallest mountain in the world. I will use the tool `retrieve_passages` to look for data on the Wikipedia page.
Code:
```py

result = retrieve_passages("最高山脉")

print(result)

```py<end_action><|im_end|>
<|im_start|>user
[OUTPUT OF STEP 1] -> Observation:
Retrieved passages for query "highest mountain":
Passage 0: ... above sea level) is the world's tallest mountain and volcano, rising about 10,203 m (33,474 ft) from the Pacific Ocean floor. ...
Passage 1: ... As of December 2018, the highest peaks on four of the mountains—Gangkhar Puensum, Labuche Kang III, Karjiang, and Tongshanjiabu, all located in Bhutan or China—have not been ascended. ...
Passage 2: ... The highest mountains above sea level are generally not the highest above the surrounding terrain. ...
Passage 3: ... The highest mountain outside of Asia is Aconcagua (6,961 m or 22,838 ft), the 189th highest in the world. ...
Passage 4: ... the southern summit of Peru's tallest mountain, Huascarán, is another contender. Both have elevations above sea level more than 2 km (1.2 mi) less than that of Everest....
<|im_end|>
<|im_start|>assistant
Thought: The results don't clearly specify a clear result for the world's tallest mountain, I will use the tool `web_results` with a different query.
Code:
```py

result = retrieve_passages("世界上最高的山")

print(result)

```py<end_action><|im_end|>
<|im_start|>user
Passages retrieved from page List of highest mountains on Earth:
Passage 0: ... The highest mountain outside of Asia is Aconcagua (6,961 m or 22,838 ft), the 189th highest in the world....
Passage 1: ... above sea level) is the world's tallest mountain and volcano, rising about 10,203 m (33,474 ft) from the Pacific Ocean floor. ...
Passage 2: ... The bases of mountain islands are below sea level, and given this consideration Mauna Kea (4,207 m (13,802 ft) above sea level) is the world's tallest mountain and volcano, rising about 10,203 m (33,474 ft) from the Pacific Ocean floor. ...
Passage 3: ... the southern summit of Peru's tallest mountain, Huascarán, is another contender. Both have elevations above sea level more than 2 km (1.2 mi) less than that of Everest. ...
Passage 4: ... The highest mountains are also not generally the most voluminous. Mauna Loa (4,169 m or 13,678 ft) is the largest mountain on Earth in terms of base area (about 5,200 km2 or 2,000 sq mi) and volume (about 42,000 km3 or 10,000 cu mi)...<|im_end|>
<|im_start|>assistant
Thought: I have found that Mauna Kea is the world's tallest mountain rising about 10,203 m (33,474 ft) from the Pacific Ocean floor. I can use the `final_answer` tool to return the relevant information.
Code:
```py

final_answer("茂纳基阿是世界上最高的山脉,从太平洋海底升起约 10,203 米(33,474 英尺)。")

```py<end_action><|im_end|>
___
On top of performing computations in the Python code snippets that you create, you have access to those tools (and no other tool):

<<tool_descriptions>>

<<managed_agents_descriptions>>

You can use imports in your code, but only from the following list of modules: <<authorized_imports>>. Do not try to import other modules or else you will get an error.
Now start and solve the task!

页面搜索代理

这个代理报告给维基百科搜索代理,后者为其提供查询和维基百科页面的标题,任务是从该页面中检索相关信息以回答查询。 本质上,这是一个单代理的 RAG 系统。 为了执行该任务,这个代理生成自定义查询,并使用语义搜索工具来检索与查询更为相似的段落。 语义搜索工具的实现非常简单,将页面内容拆分成块,并使用 LangChain 提供的 FAISS 向量数据库进行嵌入。

以下是系统提示,仍然基于 Hugging Face 默认提供的提示构建

You are an expert assistant that finds answers to questions by consulting Wikipedia, using code blobs and tools. To do so, you have been given access to a list of tools: these tools are basically Python functions which you can call with code.
You will be given a general query, your task will be of finding an answer to the query using the information you retrieve from Wikipedia. Use and trust only the information you retrieved, don't make up false facts. Cite the page where you found the information.
You can search for pages and their summaries from Wikipedia using the `search_wikipedia` tool and look for specific passages from a page using the `search_info` tool. You should decide how to use these tools to find an appropriate answer:some queries can be answered by looking at one page summary, others can require looking at specific passages from multiple pages.
To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences.
At each step, in the 'Thought:' sequence, you should first explain your reasoning towards solving the task and the tools that you want to use.
Then in the 'Code:' sequence, you should write the code in simple Python. The code sequence must end with '<end_action>' sequence.
During each intermediate step, you can use 'print()' to save whatever important information you will then need. These print outputs will be provided back to you by the user in the 'Observation:' field, which will be available as input for the next steps. Always print the output of tools, don't process it or try to extract information before inspecting it.
If an error rise while executing the code, it will be shown in the 'Observation:' field. In that case, fix the code and try again.

In the end you have to return a final answer using the `final_answer` tool.

Here are a few notional examples:
---
<|im_start|>user
Task: When was the ancient philosopher Seneca born?<|im_end|>
<|im_start|>assistant
Thought: I will use the tool `search_wikipedia` to search for Seneca's birth on Wikipedia. I will specify I am looking for the philosopher for disambiguation.
Code:
```py

result = search_wikipedia("塞内卡 哲学家 出生")

print("result")

```py<end_action><|im_end|>
<|im_start|>user
[OUTPUT OF STEP 0] -> Observation:
Pages found for query 'Seneca philosopher birth':
Page: Seneca the Younger
Summary: Lucius Annaeus Seneca the Younger ( SEN-ik-ə; c.4 BC – AD 65), usually known mononymously as Seneca, was a Stoic philosopher of Ancient Rome, a statesman, dramatist, and in one work, satirist, from the post-Augustan age of Latin literature.
Seneca was born in Colonia Patricia Corduba in Hispania, a
Page: Phaedra (Seneca)
Summary: Phaedra is a Roman tragedy written by philosopher and dramatist Lucius Annaeus Seneca before 54 A.D. Its 1,280 lines of verse tell the story of Phaedra, wife of King Theseus of Athens and her consuming lust for her stepson Hippolytus. Based on Greek mythology and the tragedy Hippolytus by Euripides,
Page: Seneca the Elder
Summary: Lucius Annaeus Seneca the Elder ( SEN-ik-ə; c.54 BC – c. AD 39), also known as Seneca the Rhetorician, was a Roman writer, born of a wealthy equestrian family of Corduba, Hispania. He wrote a collection of reminiscences about the Roman schools of rhetoric, six books of which are extant in a more or
Page: AD 1
Summary: AD 1 (I) or 1 CE was a common year starting on Saturday or Sunday, a common year starting on Saturday by the proleptic Julian calendar, and a common year starting on Monday by the proleptic Gregorian calendar. It is the epoch year for the Anno Domini (AD) Christian calendar era, and the 1st year of
Page: Seneca Falls Convention
Summary: The Seneca Falls Convention was the first women's rights convention. It advertised itself as "a convention to discuss the social, civil, and religious condition and rights of woman". Held in the Wesleyan Chapel of the town of Seneca Falls, New York, it spanned two days over July 19–20, 1848\.  Attrac
<|im_start|>assistant
Thought: From the summary of the page "", I can see that Seneca was born in . I can use the `final_answer` tool to return the answer.
Code:
```py

final_answer("根据维基百科页面《年轻塞内卡》中的内容,塞内卡生于公元前 4 年。")

```py<end_action><|im_end|>
---
<|im_start|>user
Task: Who was Charlemagne predecessor?<|im_end|>
<|im_start|>assistant
Thought: I will use the tool `search_wikipedia` to search for Charlemagne reign duration.
Code:
```py

result = search_wikipedia("查理曼的前任")

print(result)

```py<end_action><|im_end|>
<|im_start|>user
[OUTPUT OF STEP 0] -> Observation:
Pages found for query 'Charlemagne predecessor':
Page: Charlemagne
Summary: Charlemagne ( SHAR-lə-mayn; 2 April 748 – 28 January 814) was King of the Franks from 768, King of the Lombards from 774, and Emperor of what is now known as the Carolingian Empire from 800, holding these titles until his death in 814\. He united most of Western and Central Europe, and was the first
Page: Pope Leo III
Summary: Pope Leo III (Latin: Leo III; died 12 June 816) was bishop of Rome and ruler of the Papal States from 26 December 795 to his death. Protected by Charlemagne from the supporters of his predecessor, Adrian I, Leo subsequently strengthened Charlemagne's position by crowning him emperor. The coronation
Page: Throne of Charlemagne
Summary: The Throne of Charlemagne (German: Karlsthron or Aachener Königsthron, "Royal Throne of Aachen") is a throne erected in the 790s by Charlemagne, as one of the fittings of his palatine chapel in Aachen (today's Aachen Cathedral) and placed in the Octagon of the church. Until 1531, it served as the co
Page: Louis the Pious
Summary: Louis the Pious (Latin: Hludowicus Pius; French: Louis le Pieux; German: Ludwig der Fromme; 16 April 778 – 20 June 840), also called the Fair and the Debonaire, was King of the Franks and co-emperor with his father, Charlemagne, from 813\. He was also King of Aquitaine from 781\. As the only surviving
Page: Holy Roman Emperor
Summary: The Holy Roman Emperor, originally and officially the Emperor of the Romans (Latin: Imperator Romanorum; German: Kaiser der Römer) during the Middle Ages, and also known as the Romano-German Emperor since the early modern period (Latin: Imperator Germanorum; German: Römisch-deutscher Kaiser, lit. 'R
<|im_end|>
<|im_start|>assistant
Thought: The results don't contain explicit information about Charlemagne predecessor, I will search for more information on the page 'Charlemagne' using the 'search_info' tool.
Code:
```py

result = search_info("查理曼的前任", "查理曼")

print(result)

```py<end_action><|im_end|>
<|im_start|>user
[OUTPUT OF STEP 1] -> Observation:
Information retrieved from the page 'Charlemagne' for the query 'Charlemagne predecessor':
Charlemagne's predecessor was Pepin the Short.
<|im_end|>
<|im_start|>assistant
Thought: I have found that, according to the Wikipedia page 'Charlemagne', Pepin the Short was Charlemagne predecessor. I will return the results using the `final_answer` tool.
Code:
```py

final_answer("根据从维基百科页面《查理曼》提取的信息,他的前任是短小的佩平。")

```py<end_action><|im_end|>
___
On top of performing computations in the Python code snippets that you create, you have access to those tools (and no other tool):

<<tool_descriptions>>

<<managed_agents_descriptions>>

You can use imports in your code, but only from the following list of modules: <<authorized_imports>>. Do not try to import other modules or else you will get an error.
Now start and solve the task!

实现选择

在本小节中,我将概述与使用 Hugging Face 代理的架构进行直观实现时的主要区别。这些是有限试错后得到的结果,解决方案效果尚可,但我尚未进行大量测试和消融,因此这些选择可能不是最优的。

  • 提示: 如前所述,每个代理都有自己的专用系统提示,与 Hugging Face 代码代理提供的默认提示不同。我观察到,可能由于所使用的模型大小有限,一般的标准系统提示并没有给出好的结果。该模型似乎在系统提示 closely 反映任务时效果最佳,包括量身定制的重大用例示例。由于我使用了一个聊天模型,旨在改善指令跟随行为,提供的示例遵循模型的聊天模板,尽可能接近运行时遇到的格式。

  • 历史总结: 长时间的执行历史对执行速度和任务表现都有不利影响。后者可能是由于模型从长上下文中提取必要信息的能力有限。此外,极长的执行历史可能超出引擎模型的最大上下文长度。为了缓解这些问题并加速执行,我选择不显示之前思考-行动-观察步骤的所有细节,而是仅收集之前的观察记录。更具体地说,在每一步中,模型只接收以下聊天历史:系统消息、包含任务的第一条消息、其最后一项行动和所有之前观察的历史。此外,执行错误仅在最后一步发生时出现在观察历史中,之前已经解决的错误会被丢弃。

  • 工具与托管代理: Hugging Face 代理实现本身支持托管代理,但将它们包装为工具可以更好地控制提示并实现更加简洁的实现。特别是,Hugging Face 实现向托管代理及其管理者添加了特定的提示。虽然我没有看到在解决给定任务的能力上有实质性的区别,但我更倾向于第二种方法,因为它对展示的架构更加灵活和简洁,并且便于更好地控制代理的行为。这也有助于减少提示的长度,从而加速计算。

  • 限制页面搜索代理的最大试验次数: 有时页面搜索代理会不断在一个没有信息的页面上寻找信息。减少最大试验次数缓解了这个问题:达到该次数后,代理执行停止,工具返回代码执行的最后一次观察。

  • 改变工具对用户消息的响应: 这是一个更具技术性、与实现相关的点。由于Qwen2.5–7B-Instruct的聊天模板仅支持系统、用户和助手角色,因此观察结果作为用户消息返回。

示例

在这一部分,我将收集一些使用所述架构生成的示例。我还将提供所有参与任务的代理的详细执行过程,以展示它们在一些实际案例中的互动。你可以使用插图Colab 笔记本探索更多示例。所有提取的段落都来自执行中指定的维基百科页面。

让我们开始探索一个简单的示例,它需要结合来自多个来源的信息并进行一些基本的算术理解。

任务:埃菲尔铁塔的建造和帝国大厦的建造之间相隔多少年?

以下是管理员代理执行的日志。它能够正确地将任务分解为找到两个建筑物的建设周期,并计算出年份差异。

Agent:  Thought: First, I need to find out when the Eiffel Tower and the Empire State Building were constructed. I will use the `wikipedia_search_agent` tool to find this information.
Code:
```py

wikipedia_search_agent(query="埃菲尔铁塔建造年份")

```py<end_action>
Observation:  Print outputs:
Last output from code snippet:
According to the information extracted from the Wikipedia page 'Eiffel Tower', the construction of the Eiffel Tower took place from 1887 to 1889.
Agent:  Thought: Now I have the construction period of the Eiffel Tower. Next, I need to find the construction period of the Empire State Building. I will use the `wikipedia_search_agent` tool again.
Code:
```py

wikipedia_search_agent(query="帝国大厦建造年份")

```py<end_action>
Observation:  Print outputs:
Last output from code snippet:
According to the information extracted from the Wikipedia page 'Empire State Building', the construction of the Empire State Building started in 1930 and ended in 1931.
Agent:  Thought: Now I have the construction periods for both the Eiffel Tower and the Empire State Building. I can now calculate the number of years between their constructions. 
Code:
```py

construction_eiffel_tower_start = 1887

construction_eiffel_tower_end = 1889

construction_empire_state_building_start = 1930

construction_empire_state_building_end = 1931

years_passed = construction_empire_state_building_start - construction_eiffel_tower_end

answer = f"从埃菲尔铁塔的建设到帝国大厦建设之间经过了 {years_passed} 年。"

final_answer(answer)

```py<end_action>
Observation:  Print outputs:
Last output from code snippet:
The number of years passed between the construction of the Eiffel Tower and the Empire State Building is 41 years.

现在让我们可视化维基百科搜索代理的执行过程。它能够正确搜索并报告在页面摘要中找到的建设期间,而无需检查单独的页面。有趣的是,尽管被泛泛地询问“建造年份”,它报告了整个建设期间,因为不清楚问题中的年份指的是建设的开始还是结束。

TASK:  Eiffel Tower construction year
AGENT:  Thought: I will use the `search_wikipedia` tool to find information about the Eiffel Tower construction year.
Code:
```py

search_wikipedia('埃菲尔铁塔建造年份')

```py<end_action>
OBSERVATIONS:  Print outputs:
Last output from code snippet:
Pages found for query 'Eiffel Tower construction year':
Page: Eiffel Tower
Summary: The Eiffel Tower (  EYE-fəl; French: Tour Eiffel [tuʁ ɛfɛl] ) is a wrought-iron lattice tower on the Champ de Mars in Paris, France. It is named after the engineer Gustave Eiffel, whose company designed and built the tower from 1887 to 1889.
Locally nicknamed "La dame de fer" (French for "Iron Lady"), it was constructed as the centerpiece of the 1889 World's Fair, and to crown the centennial anniversary of the French Revolution. Although initially criticised by some of France's leading artists and intellectuals for its design, it has since become a global cultural icon of France and one of the most recognisable structures in the world. The tower received 5,889,000 visitors in 2022\. The Eiffel Tower is the most visited monument with an entrance fee in the world: 6.91 million people ascended it in 2015\. It was designated a monument historique in 1964, and was named part of a UNESCO World Heritage Site ("Paris, Banks of the Seine") in 1991.
The tower is 330 metres (1,083 ft) tall, about t
Page: Eiffel Tower (Paris, Texas)
Summary: Texas's Eiffel Tower is a landmark in the city of Paris, Texas. The tower was constructed in 1993\. It is a scale model of the Eiffel Tower in Paris, France; at 65 feet in height, it is roughly one-sixteenth of the height of the original. 

Page: Gustave Eiffel
Summary: Alexandre Gustave Eiffel (  EYE-fəl, French: [alɛksɑ̃dʁ ɡystav ɛfɛl]; né Bonickhausen dit Eiffel; 15 December 1832 – 27 December 1923) was a French civil engineer. A graduate of École Centrale des Arts et Manufactures, he made his name with various bridges for the French railway network, most famously the Garabit Viaduct. He is best known for the world-famous Eiffel Tower, designed by his company and built for the 1889 Universal Exposition in Paris, and his contribution to building the Statue of Liberty in New York. After his retirement from engineering, Eiffel focused on research into meteorology and aerodynamics, making significant contributions in both fields.
Page: Watkin's Tower
Summary: Watkin's Tower was a partially completed iron lattice tower in Wembley Park, London, England. Its construction was an ambitious project to create a 358-metre (1,175 ft)-high visitor attraction in Wembley Park to the north of the city, led by the railway entrepreneur Sir Edward Watkin. Marketed as the "Great Tower of London", it was designed to surpass the height of the Eiffel Tower in Paris, and it was part of Wembley Park's emergence as a recreational place. The tower was never completed and it was demolished in 1907\. The site of the tower is now occupied by the English national football ground, Wembley Stadium.
Page: Eiffel Tower (Paris, Tennessee)
Summary: The Eiffel Tower is a landmark in the city of Paris, Tennessee. It is a 1:20 scale replica of the original located in Paris, France.

AGENT:  Thought: From the summary of the page 'Eiffel Tower', I can see that the construction of the Eiffel Tower took place from 1887 to 1889\. I will use the `final_answer` tool to return the answer.
Code:
```py

final_answer("根据从维基百科页面 '埃菲尔铁塔' 提取的信息,埃菲尔铁塔的建设发生在 1887 年至 1889 年之间。")

```py<end_action>
OBSERVATIONS:  Print outputs:
Last output from code snippet:
According to the information extracted from the Wikipedia page 'Eiffel Tower', the construction of the Eiffel Tower took place from 1887 to 1889.
------------------------------
TASK:  Empire State Building construction year
AGENT:  Thought: I will use the `search_wikipedia` tool to find information about the construction year of the Empire State Building.
Code:
```py

search_wikipedia('帝国大厦建造年份')

```py<end_action>
OBSERVATIONS:  Print outputs:
Last output from code snippet:
Pages found for query 'Empire State Building construction year':
Page: Empire State Building
Summary: The Empire State Building is a 102-story Art Deco skyscraper in the Midtown South neighborhood of Manhattan in New York City. The building was designed by Shreve, Lamb & Harmon and built from 1930 to 1931\. Its name is derived from "Empire State", the nickname of the state of New York. The building has a roof height of 1,250 feet (380 m) and stands a total of 1,454 feet (443.2 m) tall, including its antenna. The Empire State Building was the world's tallest building until the first tower of the World Trade Center was topped out in 1970; following the September 11 attacks in 2001, the Empire State Building was New York City's tallest building until it was surpassed in 2012 by One World Trade Center. As of 2024, the building is the seventh-tallest building in New York City, the ninth-tallest completed skyscraper in the United States, and the 57th-tallest completed skyscraper in the world.
The site of the Empire State Building, on the west side of Fifth Avenue between West 33rd and 34th St
Page: British Empire Building
Summary: The British Empire Building, also known by its address 620 Fifth Avenue, is a commercial building at Rockefeller Center in the Midtown Manhattan neighborhood of New York City. Completed in 1933, the six-story structure was designed in the Art Deco style by Raymond Hood, Rockefeller Center's lead architect. The British Empire Building, along with the nearly identical La Maison Francaise to the south and the high-rise International Building to the north, comprise a group of retail-and-office structures known as the International Complex. La Maison Francaise and the British Empire Building are separated by Channel Gardens, a planted pedestrian esplanade running west to the complex's Lower Plaza.
The facade is made of limestone, with a main entrance along Fifth Avenue and secondary entrances on 50th Street and Channel Gardens. The top of the British Empire Building contains setbacks, a rooftop garden, and a partial seventh-story penthouse. The building's entrances contain ornate decoration
Page: 2012 Empire State Building shooting
Summary: On August 24, 2012, a gunman shot and killed a former co-worker outside the Empire State Building in New York City. Following the initial shooting, the gunman, 58-year-old Jeffrey T. Johnson, was fatally shot by police officers after raising his weapon at them. Nine bystanders were wounded by stray bullets fired by the officers and ricocheting debris, but none suffered life-threatening injuries.
Page: Empire State Building in popular culture
Summary: The Empire State Building is a 102-story skyscraper on Fifth Avenue between West 33rd and 34th Streets in Midtown, Manhattan, New York City. It has a roof height of 1,250 feet (381 m), and with its antenna included, it stands a total of 1,454 feet (443.2 m) tall. The Empire State Building stood as the world's tallest building for nearly 40 years, from its completion in early 1931 until the topping out of the original World Trade Center's North Tower in late 1970\. It has been featured in numerous films, TV shows, songs, video games, books, photographs, and artwork.
Page: Empire State Plaza
Summary: The Governor Nelson A. Rockefeller Empire State Plaza (known commonly as the Empire State Plaza, and also as the South Mall) is a complex of several state government buildings in downtown Albany, New York.
The complex was built between 1965 and 1976 at an estimated total cost of $2 billion. It houses several departments of the New York State administration and is integrated with the New York State Capitol, completed in 1899, which houses the state legislature. Among the offices at the plaza are the Department of Health and the Biggs Laboratory of the Wadsworth Center. The Empire State Art Collection, a major public collection of 1960s and 1970s monumental abstract artworks, is on permanent display throughout the site. The New York State Office of General Services maintains the plaza. The Nelson A. Rockefeller Empire State Plaza Performing Arts Center Corporation is a New York state public-benefit corporation that was created in 1979 to manage the performing arts facility in the plaza.

AGENT:  Thought: From the summary of the page "Empire State Building", I can see that the building was designed and built from 1930 to 1931\. I will use the `final_answer` tool to return the answer.
Code:
```py

final_answer("根据从维基百科页面 '帝国大厦' 提取的信息,帝国大厦的建设始于 1930 年,结束于 1931 年。")

```py<end_action>
OBSERVATIONS:  Print outputs:
Last output from code snippet:
According to the information extracted from the Wikipedia page 'Empire State Building', the construction of the Empire State Building started in 1930 and ended in 1931.

由于所有信息都已经在页面摘要中找到,因此没有使用段落检索代理。

现在让我们尝试一个更复杂的例子,询问桦木胶合板是否会在酒精中浮起来。这个问题需要基本的物理知识和算术能力,并结合非平凡的检索任务。

任务:桦木胶合板在乙醇中会浮起来吗?

首先,让我们可视化各个代理的日志。

管理员代理日志

AGENT:  Thought: To determine if birch plywood floats in ethanol, we need to know the density of both birch plywood and ethanol. If the density of birch plywood is less than the density of ethanol, it will float. I will use the `wikipedia_search_agent` tool to find the densities of birch plywood and ethanol.
Code:
```py

birch_plywood_density_info = wikipedia_search_agent("桦木胶合板的密度")

ethanol_density_info = wikipedia_search_agent("乙醇的密度")

print(birch_plywood_density_info)

print(ethanol_density_info)

```py<end_action>
OBSERVATION:  Print outputs:
The density of birch plywood is approximately 680 kg/m³.
The density of ethanol at 20°C is approximately 0.7894 g/cm³.

AGENT:  Thought: Now that I have the densities, I can compare them. Birch plywood has a density of 680 kg/m³ and ethanol has a density of 0.7894 g/cm³. Since 1 g/cm³ is equal to 1000 kg/m³, the density of ethanol is 789.4 kg/m³. Since the density of birch plywood (680 kg/m³) is less than the density of ethanol (789.4 kg/m³), birch plywood will float in ethanol. I will now provide this conclusion using the `final_answer` function.
Code:
```py

final_answer("桦木胶合板在乙醇中会浮起来,因为它的密度(680 kg/m³)小于乙醇的密度(789.4 kg/m³)。")

```py<end_action>
OBSERVATION:  Print outputs:
Last output from code snippet:
Birch plywood will float in ethanol because its density (680 kg/m³) is less than the density of ethanol (789.4 kg/m³).

Wikipedia 搜索代理日志

TASK:  density of birch plywood
AGENT:  Thought: I will use the `search_wikipedia` tool to find information about the density of birch plywood. 
Code:
```py

search_wikipedia('桦木胶合板')

```py<end_action>
OBSERVATION:  Print outputs:
Last output from code snippet:
Pages found for query 'birch plywood':
Page: Plywood
Summary: Plywood is a composite material manufactured from thin layers, or "plies", of wood veneer that have been stacked and glued together. It is an engineered wood from the family of manufactured boards, which include plywood, medium-density fibreboard (MDF), oriented strand board (OSB), and particle board (or chipboard).
All plywoods bind resin and wood fibre sheets (cellulose cells are long, strong and thin) to form a composite material. The sheets of wood are stacked such that each layer has its grain set typically (see below) perpendicular to its adjacent layers. This alternation of the grain is called cross-graining and has several important benefits: it reduces the tendency of wood to split when nailed at the edges; it reduces thickness swelling and shrinkage, providing improved dimensional stability; and it makes the strength of the panel consistent across all directions.  There is usually an odd number of plies, so that the sheet is balanced, that is, the surface layers have their gr
Page: Birch
Summary: A birch is a thin-leaved deciduous hardwood tree of the genus Betula (), in the family Betulaceae, which also includes alders, hazels, and hornbeams. It is closely related to the beech-oak family Fagaceae. The genus Betula contains 30 to 60 known taxa of which 11 are on the IUCN 2011 Red List of Threatened Species. They are typically short-lived pioneer species and are widespread in the Northern Hemisphere, particularly in northern areas of temperate climates and in boreal climates. Birch wood is used for a wide range of purposes.
Page: Birch wood
Summary: Birch wood is a type of wood of the birch. Birch wood is pale yellow-brown wood having a close, straight grain and uniform texture that finishes to a smooth surface. Sometimes it is dyed to imitate mahogany. This type of wood is used for among others firewood, turnery, furniture, cabinetry, tools handles, hoops, plywood, flooring and shoe heels.
Page: Falco F8L
Summary: The Falco F8L is an Italian-designed lightweight 2-seater aerobatic aircraft.
Page: Crusader (speedboat)
Summary: Crusader was a jet-powered speed boat piloted by John Cobb.
The combination of an aerodynamically stable hull form and turbojet propulsion was proposed by Reid Railton, Cobb's adviser. A rocket-powered scale model was tested at Haslar. The full size design was by Peter du Cane and built by Vospers of Portsmouth. Technical assistance came from Saunders-Roe and Vickers-Supermarine. It cost £15,000 in 1949.
It was silver and scarlet in colour and 10 m long.  
The engine was a de Havilland Ghost Mk 48 centrifugal turbojet provided as a loan by the Ministry of Supply at the request of Major Frank Halford, the engine designer. The engine was rated at 5,000 lb thrust fed by two scoop inlets forward of the cockpit.
The hull was of trimaran form, a main hull with a planing step, and two smaller rear-mounted outriggers. Construction was of birch plywood frames and stringers. The hull was skinned in birch ply covered in doped fabric with metal skin reinforcement for planing surfaces. Aircraft-sty

AGENT:  Thought: The pages returned from the search did not directly mention the density of birch plywood. However, the page on Birch Wood mentions that birch wood is used for plywood. I will use the `search_info` tool to find more specific information about the density of birch plywood from the page 'Plywood'.
Code:
```py

search_info('桦木胶合板的密度', '胶合板')

```py<end_action>
OBSERVATION:  Print outputs:
Last output from code snippet:
Information retrieved from the page 'Plywood' for the query 'density of birch plywood':
The density of birch plywood is approximately 680 kg/m³.
AGENT:  Thought: I have found the density of birch plywood from the Wikipedia page 'Plywood'. The density is approximately 680 kg/m³.
Code:
```py

final_answer("桦木胶合板的密度大约是 680 kg/m³。")

```py<end_action>
OBSERVATION:  Print outputs:
Last output from code snippet:
The density of birch plywood is approximately 680 kg/m³.
------------------------------
TASK:  density of ethanol
AGENT:  Thought: I will use the `search_wikipedia` tool to find the density of ethanol from Wikipedia. 
Code:
```py

search_wikipedia('乙醇的密度')

```py<end_action>
OBSERVATION:  Print outputs:
Last output from code snippet:
Pages found for query 'density of ethanol':
Page: Ethanol (data page)
Summary: This page provides supplementary chemical data on ethanol.

Page: Ethanol
Summary: Ethanol (also called ethyl alcohol, grain alcohol, drinking alcohol, or simply alcohol) is an organic compound with the chemical formula CH3CH2OH. It is an alcohol, with its formula also written as C2H5OH, C2H6O or EtOH, where Et stands for ethyl. Ethanol is a volatile, flammable, colorless liquid with a characteristic wine-like odor and pungent taste. In nature, grape-sugar breaks up by the action of fermentation into alcohol or carbonic acid, without anything being added. As a psychoactive depressant, it is the active ingredient in alcoholic beverages, and the second most consumed drug globally behind caffeine.
Ethanol is naturally produced by the fermentation process of sugars by yeasts or via petrochemical processes such as ethylene hydration. Historically it was used as a general anesthetic, and has modern medical applications as an antiseptic, disinfectant, solvent for some medications, and antidote for methanol poisoning and ethylene glycol poisoning. It is used as a chemical so
Page: Alcohol by volume
Summary: Alcohol by volume (abbreviated as alc/vol or ABV) is a standard measure of the volume of alcohol contained in a given volume of an alcoholic beverage, expressed as a volume percent. It is defined as the number of millilitres (mL) of pure ethanol present in 100 mL (3.5 imp fl oz; 3.4 US fl oz) of solution at 20 °C (68 °F). The number of millilitres of pure ethanol is the mass of the ethanol divided by its density at 20 °C (68 °F), which is 0.78945 g/mL (0.82353 oz/US fl oz; 0.79122 oz/imp fl oz; 0.45633 oz/cu in). The alc/vol standard is used worldwide. The International Organization of Legal Metrology has tables of density of water–ethanol mixtures at different concentrations and temperatures.
In some countries, e.g. France, alcohol by volume is often referred to as degrees Gay-Lussac (after the French chemist Joseph Louis Gay-Lussac), although there is a slight difference since the Gay-Lussac convention uses the International Standard Atmosphere value for temperature, 15 °C (59 °F).

Page: Alcohol fuel
Summary: Various alcohols are used as fuel for internal combustion engines.  The first four aliphatic alcohols (methanol, ethanol, propanol, and butanol)
are of interest as fuels because they can be synthesized chemically or biologically, and they have characteristics which allow them to be used in internal combustion engines. The general chemical formula for alcohol fuel is CnH2n+1OH.
Most methanol is produced from natural gas, although it can be produced from biomass using very similar chemical processes. Ethanol is commonly produced from biological material through fermentation processes. Biobutanol has the advantage in combustion engines in that its energy density is closer to gasoline than the simpler alcohols (while still retaining over 25% higher octane rating); however, biobutanol is currently more difficult to produce than ethanol or methanol. When obtained from biological materials and/or biological processes, they are known as bioalcohols (e.g. "bioethanol"). There is no chemical dif
Page: Cellulosic ethanol
Summary: Cellulosic ethanol is ethanol (ethyl alcohol) produced from cellulose (the stringy fiber of a plant) rather than from the plant's seeds or fruit. It can be produced from grasses, wood, algae, or other plants. It is generally discussed for use as a biofuel. The carbon dioxide that plants absorb as they grow offsets some of the carbon dioxide emitted when ethanol made from them is burned, so cellulosic ethanol fuel has the potential to have a lower carbon footprint than fossil fuels.
Interest in cellulosic ethanol is driven by its potential to replace ethanol made from corn or sugarcane. Since these plants are also used for food products, diverting them for ethanol production can cause food prices to rise; cellulose-based sources, on the other hand, generally do not compete with food, since the fibrous parts of plants are mostly inedible to humans. Another potential advantage is the high diversity and abundance of cellulose sources; grasses, trees and algae are found in almost every envi

AGENT:  Thought: The summary of the page 'Ethanol' mentions that ethanol is an alcohol with a characteristic wine-like odor and pungent taste. However, it does not provide the density of ethanol. I will use the `search_info` tool to find the density of ethanol from the 'Ethanol (data page)'.
Code:
```py

search_info('乙醇的密度', '乙醇(数据页面)')

```py<end_action>
OBSERVATION:  Print outputs:
Last output from code snippet:
Information retrieved from the page 'Ethanol (data page)' for the query 'density of ethanol':
The density of ethanol at 20°C is approximately 0.7894 g/cm³.
AGENT:  Thought: I have found the density of ethanol from the Wikipedia page 'Ethanol (data page)'. I will return the result using the `final_answer` tool.
Code:
```py

final_answer("乙醇在 20°C 时的密度大约是 0.7894 g/cm³。")

```py<end_action>
OBSERVATION:  Print outputs:
Last output from code snippet:
The density of ethanol at 20°C is approximately 0.7894 g/cm³.

页面搜索代理日志

TASK:  Retrieve information about the query:"density of birch plywood" from the Wikipedia page "Plywood".
AGENT:  Thought: I need to find the density of birch plywood. I will use the tool `retrieve_passages` to look for data on the Wikipedia page.
Code:
```py

result = retrieve_passages("桦木胶合板的密度")

print(result)

```py<end_action>
OBSERVATION:  Print outputs:
Retrieved passages for query "density of birch plywood":
Passage 0: ... . Shear stress, by contrast, is higher in the center of the panel, and at the outer fibres. Within Europe basic plywood can be divided into three main categories: birch plywood (density approx. 680 kg/m3), mixed plywood (density approx. 620 kg/m3) and conifer plywoods (density 460–520 kg/m3). ...
Passage 1: ... . Aircraft grade plywood is made from three or more plies of birch, as thin as 0.40 millimetres (1⁄64 in) thick in total, and is extremely strong and light. ...
Passage 2: ... Baltic Birch plywood is a product of an area around the Baltic Sea. Originally manufactured for European cabinet makers but now popular in the United States as well. It is very stable composed of an inner void-free core of cross-banded birch plys with an exterior grade adhesive. The face veneers are thicker than traditional cabinet grade ...
Passage 3: ... . As a result, thicker panels can span greater distances under the same loads. In bending, the maximum stress occurs in the outermost layers, one in tension, the other in compression. Bending stress decreases from the maximum at the face layers to nearly zero at the central layer. Shear stress, by contrast, is higher in the center of the panel, and at the outer fibres. Within Europe basic plywood can be divided into three main categories: birch plywood (density approx ...
Passage 4: ... Wind turbine blades
Insulation boxes for liquefied natural gas (LNG) carriers
Smooth surface and accurate thickness combined with the durability of the material makes birch plywood a favorable material for many special end uses e.g.:

High-end loud speakers
Die-cutting boards
Supporting structure for parquet
Playground equipment
Furniture
Signs and fences for demanding outdoor advertising
Musical instruments
Sports equipment...

AGENT:  Thought: From the passages, I can see that the density of birch plywood is approximately 680 kg/m³. I can use the `final_answer` tool to provide the answer.
Code:
```py

final_answer("桦木胶合板的密度约为 680 kg/m³。")

```py<end_action>
OBSERVATION:  Print outputs:
Last output from code snippet:
The density of birch plywood is approximately 680 kg/m³.
------------------------------
TASK:  Retrieve information about the query:"density of ethanol" from the Wikipedia page "Ethanol (data page)".
AGENT:  Thought: I need to find the density of ethanol. I will use the tool `retrieve_passages` to get the information from the Wikipedia page.
Code:
```py

result = retrieve_passages("乙醇的密度")

print(result)

```py<end_action>
OBSERVATION:  Print outputs:
Retrieved passages for query "density of ethanol":
Passage 0: ... == Density of ethanol at various temperatures ==
Data obtained from Lange 1967

These data correlate as ρ [g/cm3] = −8.461834×10−4 T [°C] + 0.8063372 with an R2 = 0.99999.

== Properties of aqueous ethanol solutions ==
Data obtained from Lange 1967

== Boiling points of aqueous solutions ==
Data obtained from CRC Handbook of Chemistry (Page 2117): 2391 

‡Azeotropic mixture

== Charts ==

== References == ...
Passage 1: ... This page provides supplementary chemical data on ethanol.

== Material Safety Data Sheet ==
External MSDS

== Structure and properties ==

== Thermodynamic properties ==

== Spectral data ==

== Vapor pressure of liquid ==

== Density of ethanol at various temperatures ==
Data obtained from Lange 1967

These data correlate as ρ [g/cm3] = −8.461834×10−4 T [°C] + 0.8063372 with an R2 = 0.99999.

== Properties of aqueous ethanol solutions ==
Data obtained from Lange 1967...

AGENT:  Thought: From the passages, I can see that the density of ethanol varies with temperature. I can use the provided linear equation to calculate the density at a specific temperature. I will use the `math` module to perform the calculation.
Code:
```py

# 从方程中提取系数

a = -8.461834e-4

b = 0.8063372

# 计算特定温度下的密度,我们以 20°C 为例

temperature = 20

density = a * temperature + b

print(density)

```py<end_action>
OBSERVATION:  Print outputs:
0.789413532

AGENT:  Thought: The passage provides a linear equation to calculate the density of ethanol at different temperatures. I can now use this equation to provide the density of ethanol at a specific temperature. I will use the `final_answer` tool to return the relevant information.
Code:
```py

final_answer(f"20°C 下乙醇的密度约为{round(0.789413532, 4)} g/cm³。")

```py<end_action>
OBSERVATION:  Print outputs:
Last output from code snippet:
The density of ethanol at 20°C is approximately 0.7894 g/cm³.

该模型正确地识别出密度差异是浮沉的原因,将任务分解为查找两种物质的密度,并能够从检索到的数据中得出正确的结论,尽管这些数据的单位不同。

限制

讨论的系统具有令人印象深刻的能力,特别是考虑到使用的引擎模型的有限规模,但当然,它也有自己的局限性。最大的局限似乎与引擎模型本身的能力有关:虽然它似乎能够理解并执行复杂且冗长的指令,以及正确使用提供的工具,但有时它未能遵循简单的指南,并且往往会重复无效的操作而没有变化。当代理寻找的信息不存在(或无法通过语义搜索工具轻松提取)时,这种情况尤为突出。

另一个缺点是,代理系统本身固有的,尤其是在多代理系统中更为突出的是,各个代理执行所需的计算时间。在一些简单的案例中,标准的 RAG 工作流足以回答问题,但这可能会产生不必要的计算开销。

最后,多个代理之间的相互作用可能会放大幻觉的风险,因为一个代理产生的幻觉可能会在执行链中传播,使得追踪变得非常困难。

让我们通过一个示例来可视化,系统虽然仍然提供了正确答案,但在没有必要的情况下过度使用了工具调用(并重复某些操作),不必要地延长了执行时间。

任务:君士坦丁堡陷落时,中国正由哪个朝代统治?

该任务要求较高的抽象能力,因为这两件事没有直接关系,且君士坦丁堡陷落的确切年份可能没有出现在相应朝代的页面上。

管理员代理的执行仍然简洁且精准,通过首先检索君士坦丁堡陷落的日期,再查找该年份的中国朝代,正确解决了任务。然而,我们已经可以看到,尽管搜索结果仍包含所需的信息,但看起来有些奇怪。

Agent:  Thought: To answer this question, I need to find out when Constantinople fell and which Chinese dynasty was ruling at that time. I'll start by finding the fall date of Constantinople.
Code:
```py

wikipedia_search_agent(query="君士坦丁堡陷落")

```py<end_action>
Observation:  Print outputs:
Last output from code snippet:
The leader of the Ottoman Empire in 1453, during the fall of Constantinople, was Mehmed II, also known as Mehmed the Conqueror.
Agent:  Thought: From the previous search, we know that Constantinople fell in 1453, and it was conquered by Mehmed II. Now I need to find out which Chinese dynasty was ruling in 1453\. I will search for the Chinese dynasties around that time.
Code:
```py

wikipedia_search_agent(query="1453 年中国朝代")

```py<end_action>
Observation:  Print outputs:
Last output from code snippet:
In 1453, the Ming dynasty was in full control of China. No specific events related to 1453 are mentioned in the information provided.
Agent:  Thought: Based on the information provided, the Ming dynasty was in full control of China in 1453\. Therefore, the dynasty ruling China when Constantinople fell would be the Ming dynasty.
Code:
```py

final_answer("1453 年君士坦丁堡陷落时,中国的统治朝代是明朝。")

```py<end_action>
Observation:  Print outputs:
Last output from code snippet:
The dynasty ruling China when Constantinople fell in 1453 was the Ming dynasty.

因此,让我们检查相应的维基百科搜索代理任务及其执行情况。我们可以看到,可能由于“君士坦丁堡的陷落”这一广泛的任务,代理在找到正确日期后,查看了其他具体信息。检索关于中国朝代的信息时也发生了类似情况,此外还出现了某些工具调用被重复执行的问题。

TASK:  Fall of Constantinople
AGENT:  Thought: I will use the `search_wikipedia` tool to find information about the fall of Constantinople. This will likely provide me with a summary of the event and possibly links to related pages.
Code:
```py

search_wikipedia("君士坦丁堡的陷落")

```py<end_action>
OBSERVATIONS:  Print outputs:
Last output from code snippet:
Pages found for query 'Fall of Constantinople':
Page: Fall of Constantinople
Summary: The fall of Constantinople, also known as the conquest of Constantinople, was the capture of the capital of the Byzantine Empire by the Ottoman Empire. The city was captured on 29 May 1453 as part of the culmination of a 55-day siege which had begun on 6 April.
The attacking Ottoman Army, which significantly outnumbered Constantinople's defenders, was commanded by the 21-year-old Sultan Mehmed II (later nicknamed "the Conqueror"), while the Byzantine army was led by Emperor Constantine XI Palaiologos. After conquering the city, Mehmed II made Constantinople the new Ottoman capital, replacing Adrianople.
The fall of Constantinople and of the Byzantine Empire was a watershed of the Late Middle Ages, marking the effective end of the Roman Empire, a state which began in roughly 27 BC and had lasted nearly 1500 years. For many modern historians, the fall of Constantinople marks the end of the medieval period and the beginning of the early modern period. The city's fall also stood as a turni
Page: Sack of Constantinople
Summary: The sack of Constantinople occurred in April 1204 and marked the culmination of the Fourth Crusade. Crusaders sacked and destroyed most of Constantinople, the capital of the Byzantine Empire. After the capture of the city, the Latin Empire (known to the Byzantines as the Frankokratia, or the Latin occupation) was established and Baldwin of Flanders crowned as Emperor Baldwin I of Constantinople in Hagia Sophia.
After the city's sacking, most of the Byzantine Empire's territories were divided up among the Crusaders. Byzantine aristocrats also established a number of small independent splinter states—one of them being the Empire of Nicaea, which would eventually recapture Constantinople in 1261 and proclaim the reinstatement of the Empire. However, the restored Empire never managed to reclaim all its former territory or attain its earlier economic strength, and it gradually succumbed to the rising Ottoman Empire over the following two centuries.
The Byzantine Empire was left poorer, smal
Page: Constantinople
Summary: Constantinople (see other names) became the capital of the Roman Empire during the reign of Constantine the Great in 330\. Following the collapse of the Western Roman Empire in the late 5th century, Constantinople remained the capital of the Eastern Roman Empire (also known as the Byzantine Empire; 330–1204 and 1261–1453), the Latin Empire (1204–1261), and the Ottoman Empire (1453–1922). Following the Turkish War of Independence, the Turkish capital then moved to Ankara. Officially renamed Istanbul in 1930, the city is today the largest city in Europe, straddling the Bosporus strait and lying in both Europe and Asia, and the financial center of Turkey.
In 324, following the reunification of the Eastern and Western Roman Empires, the ancient city of Byzantium was selected to serve as the new capital of the Roman Empire, and the city was renamed Nova Roma, or 'New Rome', by Emperor Constantine the Great. On 11 May 330, it was renamed Constantinople and dedicated to Constantine. Constantin
Page: Moscow, third Rome
Summary: Moscow, third Rome (Russian: Москва — третий Рим; Moskva, tretiĭ Rim) is a theological and political concept asserting Moscow as the successor to ancient Rome, with the Russian world carrying forward the legacy of the Roman Empire. The term "third Rome" refers to a historical topic of debate in European culture: the question of the successor city to the "first Rome" (Rome, within the Western Roman Empire) and the "second Rome" (Constantinople, within the Eastern Roman Empire).
Page: Byzantine Empire
Summary: The Byzantine Empire, also referred to as the Eastern Roman Empire, was the continuation of the Roman Empire centred in Constantinople during late antiquity and the Middle Ages. The eastern half of the Empire survived the conditions that caused the fall of the West in the 5th century AD, and continued to exist until the fall of Constantinople to the Ottoman Empire in 1453\. During most of its existence, the empire remained the most powerful economic, cultural, and military force in the Mediterranean world. The term "Byzantine Empire" was only coined following the empire's demise; its citizens referred to the polity as the "Roman Empire" and to themselves as "Romans". Due to the imperial seat's move from Rome to Byzantium, the adoption of state Christianity, and the predominance of Greek instead of Latin, modern historians continue to make a distinction between the earlier Roman Empire and the later Byzantine Empire.
During the earlier Pax Romana period, the western parts of the empire b

AGENT:  Thought: From the summary of the page "Fall of Constantinople", I can see that the event took place in 1453\. I will use the `search_info` tool to find more details about the event, specifically who was the leader of the Ottoman Empire at that time.
Code:
```py

search_info("奥斯曼帝国领导人 1453", "君士坦丁堡的陷落")

```py<end_action>
OBSERVATIONS:  Print outputs:
Last output from code snippet:
Information retrieved from the page 'Fall of Constantinople' for the query 'leader of the Ottoman Empire in 1453':
The leader of the Ottoman Empire in 1453 was Mehmed II, also known as Mehmed the Conqueror.
AGENT:  Thought: I have found that the leader of the Ottoman Empire in 1453 was Mehmed II, also known as Mehmed the Conqueror. I will use the `final_answer` tool to provide the answer.
Code:
```py

final_answer("1453 年,君士坦丁堡陷落时,奥斯曼帝国的领导人是穆罕默德二世,也称为征服者穆罕默德。")

```py<end_action>
OBSERVATIONS:  Print outputs:
Last output from code snippet:
The leader of the Ottoman Empire in 1453, during the fall of Constantinople, was Mehmed II, also known as Mehmed the Conqueror.
------------------------------
TASK:  Chinese dynasties 1453
AGENT:  Thought: I will use the `search_wikipedia` tool to find information about Chinese dynasties around the year 1453\. I will provide a general query related to Chinese dynasties in that time period.
Code:
```py

search_wikipedia('中国朝代 1453')

```py<end_action>
OBSERVATIONS:  Print outputs:
Last output from code snippet:
Pages found for query 'Chinese dynasties 1453':
Page: Chinese Empire
Summary: Chinese Empire (or Empire of China) is a term referring to the realm ruled by the Emperor of China during the era known as Imperial China. It was coined by western scholars and used to describe the Ming and Qing dynasties (or imperial Chinese dynasties in general). Another term was "Celestial Empire", in reference to the status of the emperor as the Son of Heaven. In 221 BC, China was unified under an emperor for the first time, and various dynasties or empires founded by hereditary monarchs ruled China for a total of two millennia since then, including the Qin, Han, Jin, Sui, Tang, Song, Yuan, Ming, and Qing.

Page: Ming dynasty
Summary: The Ming dynasty, officially the Great Ming, was an imperial dynasty of China, ruling from 1368 to 1644 following the collapse of the Mongol-led Yuan dynasty. The Ming dynasty was the last imperial dynasty of China ruled by the Han people, the majority ethnic group in China. Although the primary capital of Beijing fell in 1644 to a rebellion led by Li Zicheng (who established the short-lived Shun dynasty), numerous rump regimes ruled by remnants of the Ming imperial family—collectively called the Southern Ming—survived until 1662.
The Ming dynasty's founder, the Hongwu Emperor (r. 1368–1398), attempted to create a society of self-sufficient rural communities ordered in a rigid, immobile system that would guarantee and support a permanent class of soldiers for his dynasty: the empire's standing army exceeded one million troops and the navy's dockyards in Nanjing were the largest in the world. He also took great care breaking the power of the court eunuchs and unrelated magnates, enfeoff
Page: List of time periods
Summary: The categorisation of the past into discrete, quantified named blocks of time is called periodization. This is a list of such named time periods as defined in various fields of study.
These can be divided broadly into prehistorical periods and historical periods
(when written records began to be kept). 
In archaeology and anthropology, prehistory is subdivided into the three-age system, this list includes the use of the three-age system as well as a number of various designation used in reference to sub-ages within the traditional three.
The dates for each age can vary by region. On the geologic time scale, the Holocene epoch starts at the end of the last glacial period of the current ice age (c. 10,000 BC) and continues to the present. The beginning of the Mesolithic is usually considered to correspond to the beginning of the Holocene epoch.

Page: Celestial Empire
Summary: Celestial Empire (Chinese: 天朝; pinyin: Tiāncháo; lit. 'heavenly dynasty') is an old name used to refer to China or the Chinese Empire, from a literary and poetic translation of the Chinese term, one of many names for China. The name was used in reference to the status of the Emperor of China as the Son of Heaven in the Sinosphere.
Accordingly, in the 19th century, the name "Celestial" was used to refer to Chinese people. Both terms were widely used in the English-language popular mass media of the day, but fell into disuse later on. Its usage has become popular again in the present day (2015), particularly among Chinese Internet users. It is used to refer to the current Communist regime, to imply either disapproval for its political suppression and arrogance or national pride in the country's emergence as a superpower in the 21st century, depending on the context.
In modern times it has mostly fallen into disuse, but the name "Celestial Empire" still appears in various media such as ar
Page: List of dynasties
Summary: This is a list of monarchical dynasties and dynastic regimes organized by geographic region. Extant dynasties are rendered in bold and highlighted.

AGENT:  Thought: From the summaries, I can see that the Ming dynasty ruled from 1368 to 1644, which covers the year 1453\. I will use the `search_info` tool to find more detailed information about the Ming dynasty during the year 1453.
Code:
```py

search_info('1453', '明朝')

```py<end_action>
OBSERVATIONS:  Print outputs:
Last output from code snippet:
Information retrieved from the page 'Ming dynasty' for the query '1453':
The year 1453 falls within the period of the Ming dynasty, which ruled from 1368 to 1644\. During this time, the Ming dynasty was in full control of China and experienced significant developments such as the construction of the Forbidden City and the expansion of the Great Wall of China.
AGENT:  Thought: Now that I know the Ming dynasty ruled during 1453, I can provide a more specific answer about the events that happened in China during that year. However, the information provided does not give us specific details about the events in 1453\. I will use the `search_info` tool to find more detailed information about the Ming dynasty during the year 1453.
Code:
```py

search_info('1453', '明朝')

```py<end_action>
OBSERVATIONS:  Print outputs:
Last output from code snippet:
Information retrieved from the page 'Ming dynasty' for the query '1453':
The year 1453 falls within the period of the Ming dynasty, which ruled from 1368 to 1644\. During this time, the Ming dynasty was in full control of China and experienced significant developments such as the construction of the Forbidden City and the expansion of the Great Wall of China.
AGENT:  Thought: The information provided does not specifically mention the events of 1453\. However, we can infer that during the Ming dynasty, significant developments occurred. I will use the `search_info` tool to find more specific information about the events of 1453 in the Ming dynasty.
Code:
```py

search_info('1453', '明朝')

```py<end_action>
OBSERVATIONS:  Print outputs:
Last output from code snippet:
Information retrieved from the page 'Ming dynasty' for the query '1453':
The year 1453 falls within the period of the Ming dynasty, which ruled from 1368 to 1644\. During this time, the Ming dynasty was in full control of China and experienced significant developments such as the construction of the Forbidden City and the expansion of the Great Wall of China.
AGENT:  Thought: The information provided indicates that the Ming dynasty was in full control of China during the year 1453\. However, there is no specific mention of events that happened in 1453\. I will use the `search_info` tool to find more specific information about the events of 1453 in the Ming dynasty.
Code:
```py

search_info('1453', '明朝')

```py<end_action>
OBSERVATIONS:  Print outputs:
Last output from code snippet:
Information retrieved from the page 'Ming dynasty' for the query '1453':
The year 1453 falls within the period of the Ming dynasty, which ruled from 1368 to 1644\. During this time, the Ming dynasty was in full control of China and experienced significant developments such as the construction of the Forbidden City and the expansion of the Great Wall of China.
AGENT:  Thought: The information provided does not specifically mention any events that occurred in 1453\. However, we know that the Ming dynasty was in full control of China during this time. I will use the `final_answer` tool to provide the answer based on the information we have.
Code:
```py

final_answer("1453 年,明朝完全控制了中国。提供的信息中没有提及与 1453 年相关的具体事件。")

```py<end_action>
OBSERVATIONS:  Print outputs:
Last output from code snippet:
In 1453, the Ming dynasty was in full control of China. No specific events related to 1453 are mentioned in the information provided.

深入分析页面搜索代理执行情况,我们看到这也重复了多次使用相同参数的工具调用。

TASK:  Retrieve information about the query:"leader of the Ottoman Empire in 1453" from the Wikipedia page "Fall of Constantinople".
AGENT:  Thought: I need to find out who was the leader of the Ottoman Empire in 1453\. I will use the `retrieve_passages` tool to search for this information on the Wikipedia page "Fall of Constantinople".
Code:
```py

result = retrieve_passages("奥斯曼帝国领导人 1453")

print(result)

```py<end_action>
OBSERVATIONS:  Print outputs:
Retrieved passages for query "leader of the Ottoman Empire in 1453":
Passage 0: ... When Mehmed II succeeded his father in 1451, he was 19 years old. Many European courts assumed that the young Ottoman ruler would not seriously challenge Christian hegemony in the Balkans and the Aegean. In fact, Europe celebrated Mehmed coming to the throne and hoped his inexperience would lead the Ottomans astray. This calculation was boosted by Mehmed's friendly overtures to the European envoys at his new court. But Mehmed's mild words were not matched by his actions ...
Passage 1: ... The attacking Ottoman Army, which significantly outnumbered Constantinople's defenders, was commanded by the 21-year-old Sultan Mehmed II (later nicknamed "the Conqueror"), while the Byzantine army was led by Emperor Constantine XI Palaiologos. After conquering the city, Mehmed II made Constantinople the new Ottoman capital, replacing Adrianople. ...
Passage 2: ... . The oldest boy, renamed Murad, became a personal favourite of Mehmed and served as Beylerbey (Governor-General) of Rumeli (the Balkans). The younger son, renamed Mesih Pasha, became Admiral of the Ottoman fleet and Sancak Beg (Governor) of the province of Gallipoli. He eventually served twice as Grand Vizier under Mehmed's son, Bayezid II. ...
Passage 3: ... == External links ==

 Media related to Fall of Constantinople (1453) at Wikimedia Commons
The Siege of Constantinople As The Islamic World Sees it Archived 22 April 2021 at the Wayback Machine
World History Encyclopedia – 1453: The Fall of Constantinople
Constantinople Siege & Fall, BBC Radio 4 discussion with Roger Crowley, Judith Herrin & Colin Imber (In Our Time, 28 December 2006). ...
Passage 4: ... In the summer of 1452, when Rumeli Hisarı was completed and the threat of the Ottomans had become imminent, Constantine wrote to the Pope, promising to implement the union, which was declared valid by a half-hearted imperial court on 12 December 1452\. Although he was eager for an advantage, Pope Nicholas V did not have the influence the Byzantines thought he had over the Western kings and princes, some of whom were wary of increasing papal control...

AGENT:  Thought: From the passages, it seems that the leader of the Ottoman Empire in 1453 was Mehmed II, also known as Mehmed the Conqueror. I will use the `final_answer` tool to provide the answer.
Code:
```py

final_answer("1453 年,奥斯曼帝国的领导人是穆罕默德二世,也称为征服者穆罕默德。")

```py<end_action>
OBSERVATIONS:  Print outputs:
Last output from code snippet:
The leader of the Ottoman Empire in 1453 was Mehmed II, also known as Mehmed the Conqueror.
------------------------------
TASK:  Retrieve information about the query:"1453" from the Wikipedia page "Ming dynasty".
AGENT:  Thought: I need to find information related to the year 1453 from the Wikipedia page "Ming dynasty". I will use the `retrieve_passages` tool to get relevant passages.
Code:
```py

result = retrieve_passages("1453 明朝")

print(result)

```py<end_action>
OBSERVATIONS:  Print outputs:
Retrieved passages for query "1453 Ming dynasty":
Passage 0: ... The Ming dynasty, officially the Great Ming, was an imperial dynasty of China, ruling from 1368 to 1644 following the collapse of the Mongol-led Yuan dynasty. The Ming dynasty was the last imperial dynasty of China ruled by the Han people, the majority ethnic group in China ...
Passage 1: ... === Reference works and primary sources ===
Farmer, Edward L. ed. Ming History: An Introductory Guide to Research (1994).
Goodrich, Luther Carrington (1976). Dictionary of Ming Biography, 1368–1644\. New York: Columbia University Press. ISBN 978-0-231-03833-1.
The Ming History English Translation Project, A collaborative project that makes available translations (from Chinese to English) of portions of the 明史 Mingshi (Official History of the Ming Dynasty). ...
Passage 2: ... ==== Reign of the Hongwu Emperor ==== ...
Passage 3: ... . The Ming dynasty was the last imperial dynasty of China ruled by the Han people, the majority ethnic group in China. Although the primary capital of Beijing fell in 1644 to a rebellion led by Li Zicheng (who established the short-lived Shun dynasty), numerous rump regimes ruled by remnants of the Ming imperial family—collectively called the Southern Ming—survived until 1662\. ...
Passage 4: ... . Zhu Di assumed the throne as the Yongle Emperor (r. 1402–24); his reign is universally viewed by scholars as a "second founding" of the Ming dynasty since he reversed many of his father's policies....

AGENT:  Thought: The passages do not directly mention the year 1453\. However, I can infer that the Ming dynasty lasted from 1368 to 1644, and 1453 falls within this period. I will use the `final_answer` tool to provide a summary based on the information available.
Code:
```py

final_answer("1453 年属于明朝时期,明朝统治时间为 1368 年到 1644 年。在这段时间里,明朝完全控制了中国,并经历了如紫禁城的建造和长城扩建等重要发展。")

```py<end_action>
OBSERVATIONS:  Print outputs:
Last output from code snippet:
The year 1453 falls within the period of the Ming dynasty, which ruled from 1368 to 1644\. During this time, the Ming dynasty was in full control of China and experienced significant developments such as the construction of the Forbidden City and the expansion of the Great Wall of China.
------------------------------
TASK:  Retrieve information about the query:"1453" from the Wikipedia page "Ming dynasty".
AGENT:  Thought: I need to find information related to the year 1453 from the Wikipedia page "Ming dynasty". I will use the `retrieve_passages` tool to get relevant passages.
Code:
```py

result = retrieve_passages("1453 明朝")

print(result)

```py<end_action>
OBSERVATIONS:  Print outputs:
Retrieved passages for query "1453 Ming dynasty":
Passage 0: ... The Ming dynasty, officially the Great Ming, was an imperial dynasty of China, ruling from 1368 to 1644 following the collapse of the Mongol-led Yuan dynasty. The Ming dynasty was the last imperial dynasty of China ruled by the Han people, the majority ethnic group in China ...
Passage 1: ... === Reference works and primary sources ===
Farmer, Edward L. ed. Ming History: An Introductory Guide to Research (1994).
Goodrich, Luther Carrington (1976). Dictionary of Ming Biography, 1368–1644\. New York: Columbia University Press. ISBN 978-0-231-03833-1.
The Ming History English Translation Project, A collaborative project that makes available translations (from Chinese to English) of portions of the 明史 Mingshi (Official History of the Ming Dynasty). ...
Passage 2: ... ==== Reign of the Hongwu Emperor ==== ...
Passage 3: ... . The Ming dynasty was the last imperial dynasty of China ruled by the Han people, the majority ethnic group in China. Although the primary capital of Beijing fell in 1644 to a rebellion led by Li Zicheng (who established the short-lived Shun dynasty), numerous rump regimes ruled by remnants of the Ming imperial family—collectively called the Southern Ming—survived until 1662\. ...
Passage 4: ... . Zhu Di assumed the throne as the Yongle Emperor (r. 1402–24); his reign is universally viewed by scholars as a "second founding" of the Ming dynasty since he reversed many of his father's policies....

AGENT:  Thought: The passages do not directly mention the year 1453\. However, I can infer that the Ming dynasty lasted from 1368 to 1644, and 1453 falls within this period. I will use the `final_answer` tool to provide a summary based on the information available.
Code:
```py

final_answer("1453 年属于明朝时期,明朝统治时间为 1368 年到 1644 年。在这段时间里,明朝完全控制了中国,并经历了如紫禁城的建造和长城扩建等重要发展。")

```py<end_action>
OBSERVATIONS:  Print outputs:
Last output from code snippet:
The year 1453 falls within the period of the Ming dynasty, which ruled from 1368 to 1644\. During this time, the Ming dynasty was in full control of China and experienced significant developments such as the construction of the Forbidden City and the expansion of the Great Wall of China.
------------------------------
TASK:  Retrieve information about the query:"1453" from the Wikipedia page "Ming dynasty".
AGENT:  Thought: I need to find information related to the year 1453 from the Wikipedia page "Ming dynasty". I will use the `retrieve_passages` tool to get relevant passages.
Code:
```py

result = retrieve_passages("1453 明朝")

print(result)

```py<end_action>
OBSERVATIONS:  Print outputs:
Retrieved passages for query "1453 Ming dynasty":
Passage 0: ... The Ming dynasty, officially the Great Ming, was an imperial dynasty of China, ruling from 1368 to 1644 following the collapse of the Mongol-led Yuan dynasty. The Ming dynasty was the last imperial dynasty of China ruled by the Han people, the majority ethnic group in China ...
Passage 1: ... === Reference works and primary sources ===
Farmer, Edward L. ed. Ming History: An Introductory Guide to Research (1994).
Goodrich, Luther Carrington (1976). Dictionary of Ming Biography, 1368–1644\. New York: Columbia University Press. ISBN 978-0-231-03833-1.
The Ming History English Translation Project, A collaborative project that makes available translations (from Chinese to English) of portions of the 明史 Mingshi (Official History of the Ming Dynasty). ...
Passage 2: ... ==== Reign of the Hongwu Emperor ==== ...
Passage 3: ... . The Ming dynasty was the last imperial dynasty of China ruled by the Han people, the majority ethnic group in China. Although the primary capital of Beijing fell in 1644 to a rebellion led by Li Zicheng (who established the short-lived Shun dynasty), numerous rump regimes ruled by remnants of the Ming imperial family—collectively called the Southern Ming—survived until 1662\. ...
Passage 4: ... . Zhu Di assumed the throne as the Yongle Emperor (r. 1402–24); his reign is universally viewed by scholars as a "second founding" of the Ming dynasty since he reversed many of his father's policies....

AGENT:  Thought: The passages do not directly mention the year 1453\. However, I can infer that the Ming dynasty lasted from 1368 to 1644, and 1453 falls within this period. I will use the `final_answer` tool to provide a summary based on the information available.
Code:
```py

final_answer("1453 年属于明朝时期,明朝统治时间为 1368 年到 1644 年。在这段时间里,明朝完全控制了中国,并经历了如紫禁城的建造和长城扩建等重要发展。")

```py<end_action>
OBSERVATIONS:  Print outputs:
Last output from code snippet:
The year 1453 falls within the period of the Ming dynasty, which ruled from 1368 to 1644\. During this time, the Ming dynasty was in full control of China and experienced significant developments such as the construction of the Forbidden City and the expansion of the Great Wall of China.
------------------------------
TASK:  Retrieve information about the query:"1453" from the Wikipedia page "Ming dynasty".
AGENT:  Thought: I need to find information related to the year 1453 from the Wikipedia page "Ming dynasty". I will use the `retrieve_passages` tool to get relevant passages.
Code:
```py

result = retrieve_passages("1453 明朝")

print(result)

```py<end_action>
OBSERVATIONS:  Print outputs:
Retrieved passages for query "1453 Ming dynasty":
Passage 0: ... The Ming dynasty, officially the Great Ming, was an imperial dynasty of China, ruling from 1368 to 1644 following the collapse of the Mongol-led Yuan dynasty. The Ming dynasty was the last imperial dynasty of China ruled by the Han people, the majority ethnic group in China ...
Passage 1: ... === Reference works and primary sources ===
Farmer, Edward L. ed. Ming History: An Introductory Guide to Research (1994).
Goodrich, Luther Carrington (1976). Dictionary of Ming Biography, 1368–1644\. New York: Columbia University Press. ISBN 978-0-231-03833-1.
The Ming History English Translation Project, A collaborative project that makes available translations (from Chinese to English) of portions of the 明史 Mingshi (Official History of the Ming Dynasty). ...
Passage 2: ... ==== Reign of the Hongwu Emperor ==== ...
Passage 3: ... . The Ming dynasty was the last imperial dynasty of China ruled by the Han people, the majority ethnic group in China. Although the primary capital of Beijing fell in 1644 to a rebellion led by Li Zicheng (who established the short-lived Shun dynasty), numerous rump regimes ruled by remnants of the Ming imperial family—collectively called the Southern Ming—survived until 1662\. ...
Passage 4: ... . Zhu Di assumed the throne as the Yongle Emperor (r. 1402–24); his reign is universally viewed by scholars as a "second founding" of the Ming dynasty since he reversed many of his father's policies....

AGENT:  Thought: The passages do not directly mention the year 1453\. However, I can infer that the Ming dynasty lasted from 1368 to 1644, and 1453 falls within this period. I will use the `final_answer` tool to provide a summary based on the information available.
Code:
```py

final_answer("1453 年属于明朝时期,明朝自 1368 年到 1644 年统治中国。在此期间,明朝完全控制了中国,并经历了重要的历史发展,例如紫禁城的建设和长城的扩展。")

```py<end_action>
OBSERVATIONS:  Print outputs:
Last output from code snippet:
The year 1453 falls within the period of the Ming dynasty, which ruled from 1368 to 1644\. During this time, the Ming dynasty was in full control of China and experienced significant developments such as the construction of the Forbidden City and the expansion of the Great Wall of China.

结论

在这篇博客中,我解释了如何使用代码智能体和一个“小型”开源 LLM(如Qwen2.5–7B-Instruct)创建一个多智能体 RAG 系统。我讨论了主要的架构特征以及我在实现 Hugging Face 代码智能体时做出的一些具体选择,以提高结果。完整的代码细节可在以下GitHub 仓库中找到。

尽管该多智能体系统由一个运行在消费级硬件上的小型模型驱动,但它能够解决与复杂查询相关的多跳问答任务。特别是:

  • 它可以将查询分解为可管理的子任务;

  • 它可以识别包含必要信息的维基百科页面;

  • 它可以将来自多个页面的信息进行整合;

  • 它可以在维基百科页面上搜索详细信息;

  • 它可以判断是否需要更多的信息并尝试去寻找;

  • 它可以成功修复它生成的代码中的小错误,并处理工具错误(如维基百科的歧义错误)。

我还概述了该系统的一些局限性,例如增加的计算时间、重复的操作以及潜在的幻觉传播。后者可以通过在系统中加入一个“校对员”智能体来缓解,确保报告的信息与检索到的来源一致。

还值得注意的是,由于该智能体系统的核心采用了标准的 RAG 方法,因此可以在框架中实现所有通常用于提高 RAG 效率和准确性的技术。

另一个可能的改进是使用技术来增加测试时间计算,给予模型更多的“思考时间”,类似于 OpenAI 的 o1/o3 模型。然而,值得注意的是,这一修改将进一步增加执行时间。

最后,由于多智能体系统由专注于单一任务的智能体组成,因此为每个智能体使用不同的模型引擎可能会提升性能。特别是,可以针对系统中的每个任务对不同的模型进行微调,从而进一步提高性能。这对于小型模型尤其有利。值得一提的是,微调数据可以通过在一组预定任务上运行系统并保存智能体在系统给出正确答案时的输出,从而收集,这样可以避免昂贵的人工数据标注。

我希望你觉得这个教程有用,你可以在GitHub 仓库中找到完整的代码实现,并在Colab 笔记本中亲自尝试。

多智能体系统 101

原文:towardsdatascience.com/multi-ai-agent-systems-101-bac58e3bcc47?source=collection_archive---------0-----------------------#2024-06-16

使用 CrewAI 自动化数据源管理中的常规任务

Mariya MansurovaTowards Data Science Mariya Mansurova

·发布于 Towards Data Science ·阅读时间:26 分钟·2024 年 6 月 16 日

--

图片来自 DALL-E 3

最初,当 ChatGPT 刚刚出现时,我们使用简单的提示来获得我们问题的答案。随后,我们遇到了幻觉问题,开始使用 RAG(检索增强生成)来为 LLM 提供更多上下文。之后,我们开始尝试 AI 代理,其中 LLM 作为推理引擎,决定接下来做什么、使用什么工具以及何时返回最终答案。

下一步的进化是创建这样的代理团队,使其能够相互协作。这种方法是合乎逻辑的,因为它反映了人类的互动方式。我们在团队中工作,每个成员都有一个特定的角色:

  • 产品经理提出下一个要处理的项目。

  • 设计师创造其外观和感觉。

  • 软件工程师开发解决方案。

  • 分析师检查数据,以确保其按预期执行,并找出改进产品以满足客户需求的方法。

类似地,我们可以创建一个 AI 代理团队,每个代理专注于一个领域。它们可以协作并共同达成最终结论。正如专业化提升现实生活中的表现一样,它也能提高 AI 代理的表现。

这种方法的另一个优势是更高的灵活性。每个代理可以使用自己独立的提示、工具集甚至是 LLM。例如,我们可以为系统的不同部分使用不同的模型。你可以为需要更多推理的代理使用 GPT-4,而为只做简单提取的代理使用 GPT-3.5。我们甚至可以为小型特定任务对模型进行微调,并将其用于我们的代理团队中。

这种方法的潜在缺点是时间和成本。多个智能体之间的交互和知识共享需要更多的 LLM 调用,并消耗额外的令牌。这可能导致更长的等待时间和增加的费用。

今天有几个框架可以用于多智能体系统。

下面是一些最流行的框架:

  • AutoGen:由微软开发,AutoGen 使用对话式的方法,是最早的多智能体系统框架之一。

  • LangGraph:虽然 LangGraph 并不是严格意义上的多智能体框架,但它允许使用图结构定义演员之间的复杂交互。因此,它也可以被改编为创建多智能体系统。

  • CrewAI:CrewAI 作为一个高级框架,便于创建由角色扮演智能体组成的“团队”,这些智能体能够以多种方式协作。

我决定从 CrewAI 开始尝试多智能体框架,因为它非常流行且易于使用。因此,它看起来是一个很好的入门选择。

在本文中,我将向你展示如何使用 CrewAI。作为分析师,我们是负责记录各种数据源并回答相关问题的领域专家。我们将探讨如何使用多智能体框架来自动化这些任务。

设置环境

让我们从设置环境开始。首先,我们需要安装 CrewAI 主包以及一个扩展来与工具一起使用。

pip install crewai
pip install 'crewai[tools]'

CrewAI 是为了主要与 OpenAI API 配合使用而开发的,但我也想尝试在本地模型上使用它。根据 ChatBot Arena 排行榜,你可以在笔记本电脑上运行的最佳模型是 Llama 3(8b 参数)。这将是我们使用案例中最可行的选择。

我们可以通过 Ollama 访问 Llama 模型。安装过程非常简单。你需要从 官方网站 下载 Ollama,然后按照安装过程进行操作。就是这样。

现在,你可以通过运行以下命令在 CLI 中测试模型。

ollama run llama3

例如,你可以提问类似这样的问题。

让我们创建一个自定义的 Ollama 模型,以便在 CrewAI 中稍后使用。

我们将从 ModelFile 开始(文档)。我只指定了基础模型(llama3)、温度和停止序列。不过,你可以添加更多的功能。例如,你可以使用 SYSTEM 关键字来确定系统消息。

FROM llama3

# set parameters
PARAMETER temperature 0.5
PARAMETER stop Result

我已经将它保存在 Llama3ModelFile 文件中。

让我们创建一个 bash 脚本来加载 Ollama 的基础模型,并创建我们在 ModelFile 中定义的自定义模型。

#!/bin/zsh

# define variables
model_name="llama3"
custom_model_name="crewai-llama3"

# load the base model
ollama pull $model_name

# create the model file
ollama create $custom_model_name -f ./Llama3ModelFile

让我们执行这个文件。

chmod +x ./llama3_setup.sh
./llama3_setup.sh

您可以在 GitHub 上找到这两个文件:Llama3ModelFilellama3_setup.sh

我们需要初始化以下环境变量以使用本地的 Llama 模型与 CrewAI。

os.environ["OPENAI_API_BASE"]='http://localhost:11434/v1'

os.environ["OPENAI_MODEL_NAME"]='crewai-llama3' 
# custom_model_name from the bash script

os.environ["OPENAI_API_KEY"] = "NA"

我们已经完成了设置,并准备继续我们的旅程。

用例:与文档一起工作

作为分析师,我们经常扮演数据和一些与数据相关工具的专家角色。在我之前的团队中,我们曾经有一个拥有近 1K 参与者的频道,我们在那里回答了很多关于我们的数据和我们用作存储的 ClickHouse 数据库的问题。我们花了相当多的时间来管理这个频道。看看这些任务是否可以通过 LLM 自动化执行将是很有趣的。

对于这个示例,我将使用 ClickHouse 数据库。如果您感兴趣,您可以在我的先前文章中了解更多关于 ClickHouse 以及如何在本地设置它的信息。但是,我们不会使用任何特定于 ClickHouse 的功能,因此请随意使用您了解的数据库。

我创建了一个非常简单的数据模型来使用。在我们的 DWH(数据仓库)中只有两个表:ecommerce_db.usersecommerce_db.sessions。正如你可能猜到的那样,第一个表包含有关我们服务的用户的信息。

ecommerce_db.sessions表存储有关用户会话的信息。

关于数据源管理,分析师通常处理编写和更新文档以及回答有关这些数据的问题等任务。因此,我们将使用 LLM 为数据库中的表编写文档,并教它回答有关数据或 ClickHouse 的问题。

但在继续实施之前,让我们更多地了解 CrewAI 框架及其核心概念。

CrewAI 基本概念

多代理框架的基石是代理概念。在 CrewAI 中,代理由角色扮演驱动。角色扮演是一种策略,当你要求一个代理扮演一个角色并像顶尖后端工程师或乐于助人的客服代理一样行事时。因此,在创建 CrewAI 代理时,您需要指定每个代理的角色、目标和背景故事,以便 LLM 了解足够的信息来扮演这个角色。

代理的能力在没有工具(代理可以执行并获得结果的功能)的情况下是有限的。使用 CrewAI,您可以使用预定义工具之一(例如,搜索互联网、解析网站或在文档上执行 RAG),自己创建自定义工具或使用 LangChain 工具。因此,创建强大的代理非常容易。

让我们从智能体转到它们正在做的工作。智能体正在处理任务(具体的任务)。对于每个任务,我们需要定义一个描述、期望的输出(完成的定义)、可用工具的集合和分配的智能体。我非常喜欢这些框架遵循管理最佳实践,例如对任务的清晰完成定义。

下一个问题是如何定义任务的执行顺序:首先做哪个任务,哪些任务可以并行执行,等等。CrewAI 实现了流程来协调任务。它提供了几个选项:

  • 顺序性 — 当任务一个接一个地被调用时,这是一种最直接的方式。

  • 层级化 — 当有一个经理(指定为 LLM 模型)创建并委派任务给智能体时。

此外,CrewAI 正在处理一个共识过程。在这样的过程中,智能体将能够通过民主的方式协作做出决策。

还有其他一些手段,你可以用来调整任务执行的过程:

  • 你可以将任务标记为“异步”,然后它们将并行执行,这样你将能更快得到答案。

  • 你可以在任务上使用“人工输入”标志,然后智能体将在最终确定该任务输出之前请求人类批准。这可以让你在过程中加入监督。

我们已经定义了所有的基本构件,并可以讨论 CrewAI 的圣杯——小组概念。小组代表智能体团队和他们将要处理的任务集合。协作的方式(我们之前讨论的流程)也可以在小组级别进行定义。

此外,我们可以为一个小组设置记忆。记忆对于智能体之间的高效协作至关重要。CrewAI 支持三种记忆级别:

  • 短期记忆存储与当前执行相关的信息。它帮助智能体在当前任务中协作。

  • 长期记忆是存储在本地数据库中的关于先前执行的数据。这种类型的记忆使得智能体能够从早期的迭代中学习并随时间改进。

  • 实体记忆捕捉并结构化关于实体的信息(如角色、城市等)。

目前,你只能为小组开启所有类型的记忆,而无法进行进一步的自定义。不过,这在 Llama 模型上不起作用。

我们已经学到了足够多的 CrewAI 框架内容,现在是时候将这些知识付诸实践了。

用例:编写文档

让我们从一个简单的任务开始:为我们的 DWH 编写文档。正如我们之前讨论过的,DWH 中有两个表,我希望使用 LLM 创建它们的详细描述。

第一种方法

一开始,我们需要考虑团队结构。这就像一个典型的管理任务。你会为这样的工作雇佣谁?

我将把这个任务分为两部分:从数据库中检索数据和编写文档。因此,我们需要一位数据库专家和一位技术写作者。数据库专家需要访问数据库,而写作者则不需要任何特殊工具。

现在,我们有了一个高层次的计划。让我们创建这些代理。

对于每个代理,我已经指定了角色、目标和背景故事。我尽力为代理提供了所有需要的上下文。

database_specialist_agent = Agent(
  role = "Database specialist",
  goal = "Provide data to answer business questions using SQL",
  backstory = '''You are an expert in SQL, so you can help the team 
  to gather needed data to power their decisions. 
  You are very accurate and take into account all the nuances in data.''',
  allow_delegation = False,
  verbose = True
)

tech_writer_agent = Agent(
  role = "Technical writer",
  goal = '''Write engaging and factually accurate technical documentation 
    for data sources or tools''',
  backstory = ''' 
  You are an expert in both technology and communications, so you can easily explain even sophisticated concepts.
  You base your work on the factual information provided by your colleagues.
  Your texts are concise and can be easily understood by a wide audience. 
  You use professional but rather an informal style in your communication.
  ''',
  allow_delegation = False,
  verbose = True
)

我们将使用一个简单的顺序过程,因此代理之间无需相互委派任务。这就是为什么我指定了allow_delegation = False

下一步是为代理设置任务。但在进行这些任务之前,我们需要创建一个自定义工具来连接数据库。

首先,我编写了一个函数,用于通过 HTTP API 执行 ClickHouse 查询。

CH_HOST = 'http://localhost:8123' # default address 

def get_clickhouse_data(query, host = CH_HOST, connection_timeout = 1500):
  r = requests.post(host, params = {'query': query}, 
    timeout = connection_timeout)
  if r.status_code == 200:
      return r.text
  else: 
      return 'Database returned the following error:\n' + r.text 

在使用 LLM 代理时,确保工具具有容错性非常重要。例如,如果数据库返回错误(status_code != 200),我的代码不会抛出异常。相反,它会将错误描述返回给 LLM,以便 LLM 尝试解决问题。

要创建 CrewAI 自定义工具,我们需要从crewai_tools.BaseTool派生类,实施_run方法,然后创建该类的实例。

from crewai_tools import BaseTool

class DatabaseQuery(BaseTool):
  name: str = "Database Query"
  description: str = "Returns the result of SQL query execution"

  def _run(self, sql_query: str) -> str:
      # Implementation goes here
      return get_clickhouse_data(sql_query)

database_query_tool = DatabaseQuery()

现在,我们可以为代理设置任务。同样,向 LLM 提供清晰的指示和所有上下文是至关重要的。

table_description_task = Task(
  description = '''Provide the comprehensive overview for the data 
  in table {table}, so that it's easy to understand the structure 
  of the data. This task is crucial to put together the documentation 
  for our database''',
  expected_output = '''The comprehensive overview of {table} in the md format. 
  Include 2 sections: columns (list of columns with their types) 
  and examples (the first 30 rows from table).''',
  tools = [database_query_tool],
  agent = database_specialist_agent
)

table_documentation_task = Task(
  description = '''Using provided information about the table, 
  put together the detailed documentation for this table so that 
  people can use it in practice''',
  expected_output = '''Well-written detailed documentation describing 
  the data scheme for the table {table} in markdown format, 
  that gives the table overview in 1-2 sentences then then 
  describes each columm. Structure the columns description 
  as a markdown table with column name, type and description.''',
  tools = [],
  output_file="table_documentation.md",
  agent = tech_writer_agent
)

你可能已经注意到,我在任务描述中使用了{table}占位符。我们将在执行团队任务时,将table作为输入变量,这个值将被插入到所有占位符中。

此外,我已经为表格文档任务指定了输出文件,以便将最终结果保存在本地。

我们拥有所需的一切。现在,是时候创建一个团队并执行流程,指定我们感兴趣的表格了。让我们试试“users”表。

crew = Crew(
  agents = [database_specialist_agent, tech_writer_agent],
  tasks = [table_description_task,  table_documentation_task],
  verbose = 2
)

result = crew.kickoff({'table': 'ecommerce_db.users'})

这是一个令人兴奋的时刻,我真的很期待看到结果。如果执行过程需要一些时间,请不要担心。代理会多次调用 LLM,因此花费几分钟是完全正常的。我在我的笔记本上花了 2.5 分钟。

我们要求 LLM 以 Markdown 格式返回文档。我们可以使用以下代码在 Jupyter Notebook 中查看格式化后的结果。

from IPython.display import Markdown
Markdown(result)

第一眼看起来不错。我们得到了一个有效的 Markdown 文件,描述了“users”表。

等等,这不对。让我们看看表格中有哪些数据。

文档中列出的列与我们数据库中的完全不同。这是 LLM 幻觉的一个例子。

我们已将verbose = 2设置为从 CrewAI 获取详细的日志。让我们浏览执行日志,以找出问题的根本原因。

首先,数据库专家由于引用符号的问题,无法查询数据库。

专家未能解决这个问题。最终,这个链条已被 CrewAI 终止,输出为:Agent stopped due to iteration limit or time limit

这意味着技术写作者没有收到任何关于数据的事实性信息。然而,代理继续执行并产生了完全虚假的结果。最终我们得到了错误的文档。

修复问题

尽管我们的第一次迭代没有成功,但我们学到了很多。我们至少有两个可以改进的方面:

  • 我们的数据库工具对模型来说太复杂,代理难以使用它。我们可以通过去掉查询开始和结束的引号来使工具更具容错性。这个解决方案并不理想,因为有效的 SQL 可能以引号结尾,但我们还是试试看。

  • 我们的技术写作者没有基于数据库专家的输入来生成输出。我们需要调整提示,强调只提供事实性信息的重要性。

那么,让我们尝试修复这些问题。首先,我们将修复工具——我们可以利用strip来去除引号。

CH_HOST = 'http://localhost:8123' # default address 

def get_clickhouse_data(query, host = CH_HOST, connection_timeout = 1500):
  r = requests.post(host, params = {'query': query.strip('"').strip("'")}, 
    timeout = connection_timeout)
  if r.status_code == 200:
    return r.text
  else: 
    return 'Database returned the following error:\n' + r.text

接下来,到了更新提示的时候。我已经在代理和任务定义中加入了强调坚持事实重要性的语句。

 tech_writer_agent = Agent(
  role = "Technical writer",
  goal = '''Write engaging and factually accurate technical documentation 
  for data sources or tools''',
  backstory = ''' 
  You are an expert in both technology and communications, so you 
  can easily explain even sophisticated concepts.
  Your texts are concise and can be easily understood by wide audience. 
  You use professional but rather informal style in your communication.
  You base your work on the factual information provided by your colleagues. 
  You stick to the facts in the documentation and use ONLY 
  information provided by the colleagues not adding anything.''',
  allow_delegation = False,
  verbose = True
)

table_documentation_task = Task(
  description = '''Using provided information about the table, 
  put together the detailed documentation for this table so that 
  people can use it in practice''',
  expected_output = '''Well-written detailed documentation describing 
  the data scheme for the table {table} in markdown format, 
  that gives the table overview in 1-2 sentences then then 
  describes each columm. Structure the columns description 
  as a markdown table with column name, type and description.
  The documentation is based ONLY on the information provided 
  by the database specialist without any additions.''',
  tools = [],
  output_file = "table_documentation.md",
  agent = tech_writer_agent
)

让我们再次执行我们的团队任务,看看结果。

我们取得了一些更好的结果。我们的数据库专家能够执行查询并查看数据,这是我们的一项重大胜利。此外,我们可以在结果表中看到所有相关的字段,尽管还有许多其他字段。所以,它仍然不是完全正确的。

我再次查看了 CrewAI 执行日志,试图弄清楚问题出在哪里。问题出在获取列列表时。没有按数据库进行过滤,因此它返回了一些与结果无关的列。

SELECT column_name 
FROM information_schema.columns 
WHERE table_name = 'users'

此外,在查看了多次尝试后,我注意到数据库专家偶尔执行select * from <table>查询。它可能在生产环境中引发一些问题,因为它可能会生成大量数据并将其发送到 LLM。

更专业的工具

我们可以为我们的代理提供更多专业的工具来改进我们的解决方案。目前,代理有一个可以执行任何 SQL 查询的工具,这个工具灵活且强大,但容易出错。我们可以创建更专注的工具,例如获取表结构和从表中提取前 N 行数据。希望这样可以减少错误的发生。

class TableStructure(BaseTool):
  name: str = "Table structure"
  description: str = "Returns the list of columns and their types"

  def _run(self, table: str) -> str:
    table = table.strip('"').strip("'")
    return get_clickhouse_data(
      'describe {table} format TabSeparatedWithNames'\
        .format(table = table)
    )

class TableExamples(BaseTool):
  name: str = "Table examples"
  description: str = "Returns the first N rows from the table"

  def _run(self, table: str, n: int = 30) -> str:
    table = table.strip('"').strip("'")
    return get_clickhouse_data(
      'select * from {table} limit {n} format TabSeparatedWithNames'
        \.format(table = table, n = n)
    )

table_structure_tool = TableStructure()
table_examples_tool = TableExamples()

现在,我们需要在任务中指定这些工具并重新运行我们的脚本。在第一次尝试后,我从技术写作者那里得到了以下输出。

Task output: This final answer provides a detailed and factual description 
of the ecommerce_db.users table structure, including column names, types, 
and descriptions. The documentation adheres to the provided information 
from the database specialist without any additions or modifications.

更加专注的工具帮助数据库专家检索了正确的表信息。然而,尽管写作者拥有所有必要的信息,我们仍未得到预期的结果。

如我们所知,LLM 是概率性的,所以我又尝试了一次。太好了,这次结果非常好。

它还不完美,因为仍然包括一些无关的评论,并且缺少表格的整体描述。然而,提供更多专业工具显然是值得的。这也有助于防止在代理尝试加载表格中的所有数据时出现问题。

质量保证专家

我们已经取得了相当不错的结果,但让我们看看是否能进一步改善它们。多代理设置中的常见做法是质量保证,它在最终确定结果之前添加了最后的审核阶段。

让我们创建一个新的代理——质量保证专家,负责审核工作。

qa_specialist_agent = Agent(
  role = "Quality Assurance specialist",
  goal = """Ensure the highest quality of the documentation we provide 
  (that it's correct and easy to understand)""",
  backstory = '''
  You work as a Quality Assurance specialist, checking the work 
  from the technical writer and ensuring that it's inline 
  with our highest standards.
  You need to check that the technical writer provides the full complete 
  answers and make no assumptions. 
  Also, you need to make sure that the documentation addresses 
  all the questions and is easy to understand.
  ''',
  allow_delegation = False,
  verbose = True
)

现在,到了描述审核任务的时候了。我使用了context参数,指定此任务需要table_description_tasktable_documentation_task的输出。

qa_review_task = Task(
  description = '''
  Review the draft documentation provided by the technical writer.
  Ensure that the documentation fully answers all the questions: 
  the purpose of the table and its structure in the form of table. 
  Make sure that the documentation is consistent with the information 
  provided by the database specialist. 
  Double check that there are no irrelevant comments in the final version 
  of documentation.
  ''',
  expected_output = '''
  The final version of the documentation in markdown format 
  that can be published. 
  The documentation should fully address all the questions, be consistent 
  and follow our professional but informal tone of voice.
  ''',
  tools = [],
  context = [table_description_task, table_documentation_task],
  output_file="checked_table_documentation.md",
  agent = qa_specialist_agent
)

让我们更新我们的团队并运行它。

full_crew = Crew(
  agents=[database_specialist_agent, tech_writer_agent, qa_specialist_agent],
  tasks=[table_description_task,  table_documentation_task, qa_review_task],
  verbose = 2,
  memory = False # don't work with Llama
)

full_result = full_crew.kickoff({'table': 'ecommerce_db.users'})

由于加入了 QA 阶段,我们现在有了更结构化和详细的文档。

委派

随着质量保证(QA)专家的加入,测试委派机制会变得非常有趣。QA 专家代理可能会有一些问题或请求,可以将其委派给其他代理。

我尝试使用 Llama 3 进行委派,但结果不太理想。Llama 3 在正确调用同事工具时遇到了困难,无法指定正确的同事姓名。

我们使用本地模型在任何笔记本电脑上都能取得相当不错的结果,但现在是时候换个思路,使用更强大的模型——GPT-4o 了。

为了实现这一点,我们只需更新以下环境变量。

os.environ["OPENAI_MODEL_NAME"] = 'gpt-4o'  
os.environ["OPENAI_API_KEY"] = config['OPENAI_API_KEY'] # your OpenAI key

为了开启委派功能,我们应该为 QA 专家代理指定allow_delegation = True

此外,我们还可以使用 OpenAI 模型的便捷记忆功能,如上所述。记忆功能将允许我们的代理在执行过程中互相分享知识,并利用长期记忆从以前的执行中获取信息。

full_crew = Crew(
  agents = [database_specialist_agent, tech_writer_agent, qa_specialist_agent],
  tasks = [table_description_task,  table_documentation_task, qa_review_task],
  verbose = 2,
  memory = True 
)

full_result = full_crew.kickoff({'table': 'ecommerce_db.users'})

让我们看看 CrewAI 的执行日志,了解委派是如何工作的。这里是 QA 专家的所有日志。我们可以看到它联系了数据库专家以核实信息。

Thought: I need to review the provided draft documentation to ensure it 
fully answers all the questions and is complete, correct, and easy 
to understand. Specifically, I need to ensure that it explains the purpose 
of the table, its structure, and includes relevant examples without 
any irrelevant comments.

First, I will compare the two drafts to identify any discrepancies and 
missing information. Then, I will verify the accuracy and completeness of 
the information provided.

**Comparison:**
1\. **Introduction:**
   - First draft: Includes an introduction and examples.
   - Second draft: Adds a more detailed introduction about the purpose and importance of the table.

2\. **Columns:**
   - Both drafts have the same column structure and descriptions.

3\. **Additional Information:**
   - Both drafts include similar additional information, but the second draft is more concise.

4\. **Examples:**
   - The first draft includes examples, which are missing in the second draft.

5\. **Final Checks:**
   - Ensure there are no irrelevant comments.
   - Verify consistency with information provided by the database specialist.

Action: Ask question to co-worker
Action Input: {
"question": "Can you confirm the accuracy of the provided column descriptions 
  and examples for the `ecommerce_db.users` table?", 
"context": "We are finalizing the documentation for the `ecommerce_db.users`
 table. The table includes columns `user_id`, `country`, `is_active`, and 
  `age`. Here are the examples provided:

  | user_id | country        | is_active | age |
  |---------|----------------|-----------|-----|
  | 1000001 | United Kingdom | 0         | 70  |
  | 1000002 | France         | 1         | 87  |
  | 1000003 | France         | 1         | 88  |
  | 1000004 | Germany        | 1         | 25  |
  | 1000005 | Germany        | 1         | 48  |
  | 1000006 | France         | 1         | 78  |
  | 1000007 | Netherlands    | 1         | 65  |
  | 1000008 | Germany        | 1         | 31  |
  | 1000009 | Germany        | 1         | 66  |
  | 1000010 | Germany        | 1         | 73  |", 
"coworker": "database specialist"}

<-- omitted logs from the database specialist -->

Based on the provided context and examples for the `ecommerce_db.users` table, 
the column descriptions and examples appear accurate. Here are the details:
- `user_id`: This column contains a unique identifier for each user. 
  The provided examples (e.g., 1000001, 1000002) align with this description 
  as they are unique numeric identifiers.

- `country`: This column specifies the country of the user. The provided 
  examples (e.g., United Kingdom, France, Germany, Netherlands) are 
  all valid country names.

- `is_active`: This column indicates whether the user is active (1) or not (0). The provided examples (e.g., 0, 1) correctly represent this binary status.

- `age`: This column shows the age of the user. The provided examples (e.g., 70, 
  87, 88, 25, 48, 78, 65, 31, 66, 73) are all valid ages.

Therefore, the column descriptions and examples provided for the 
 `ecommerce_db.users` table are accurate and meet the expected criteria.

Thought: Based on the observation from the database specialist, 
the column descriptions and examples provided for the `ecommerce_db.users` 
table are accurate. Now, I will consolidate the best elements from both drafts,
 ensuring the documentation is complete, clear, and free of irrelevant comments.

Action: Finalize the documentation in markdown format, incorporating 
the detailed introduction, examples, and ensuring the tone is professional 
but informal.

当我第一次尝试委派时,没有启用记忆功能,这导致了不正确的结果。数据专家和技术写作者最初返回了正确的信息。然而,当 QA 专家返回后续问题时,他们开始出现幻觉。因此,看起来当启用记忆功能时,委派效果会更好。

这是 GPT-4o 的最终输出。现在结果看起来非常不错。我们完全可以使用 LLM 来自动化文档编写。

所以,第一个任务已经解决!

我使用相同的脚本为ecommerce_db.sessions表生成文档。这对我们的下一个任务很有用。所以,让我们不要浪费时间,继续前进。

用例:回答问题

我们的下一个任务是根据文档回答问题,因为这对许多数据分析师(以及其他专家)来说是常见的。

我们将从简单开始,只创建两个代理人:

  • 文档支持专家将根据文档回答问题,

  • 支持 QA 代理人将在与客户分享答案之前对其进行审查。

我们需要为文档专家提供一些工具,使他们能够查看存储在目录中的所有文件并阅读这些文件。这很简单,因为 CrewAI 已经实现了这样的工具。

from crewai_tools import DirectoryReadTool, FileReadTool

documentation_directory_tool = DirectoryReadTool(
    directory = '~/crewai_project/ecommerce_documentation')

base_file_read_tool = FileReadTool()

然而,由于 Llama 3 在调用工具时仍然在处理引号的问题,我不得不在FileReaderTool之上创建一个自定义工具来解决这个问题。

from crewai_tools import BaseTool

class FileReadToolUPD(BaseTool):
    name: str = "Read a file's content"
    description: str = "A tool that can be used to read a file's content."

    def _run(self, file_path: str) -> str:
        # Implementation goes here
        return base_file_read_tool._run(file_path = file_path.strip('"').strip("'"))

file_read_tool = FileReadToolUPD()

接下来,我们需要像之前一样创建代理人、任务和团队。

data_support_agent = Agent(
  role = "Senior Data Support Agent",
  goal = "Be the most helpful support for you colleagues",
  backstory = '''You work as a support for data-related questions 
  in the company. 
  Even though you're a big expert in our data warehouse, you double check 
  all the facts in documentation. 
  Our documentation is absolutely up-to-date, so you can fully rely on it 
  when answering questions (you don't need to check the actual data 
  in database).
  Your work is very important for the team success. However, remember 
  that examples of table rows don't show all the possible values. 
  You need to ensure that you provide the best possible support: answering 
  all the questions, making no assumptions and sharing only the factual data.
  Be creative try your best to solve the customer problem. 
  ''',
  allow_delegation = False,
  verbose = True
)

qa_support_agent = Agent(
  role = "Support Quality Assurance Agent",
  goal = """Ensure the highest quality of the answers we provide 
  to the customers""",
  backstory = '''You work as a Quality Assurance specialist, checking the work 
  from support agents and ensuring that it's inline with our highest standards.
  You need to check that the agent provides the full complete answers 
  and make no assumptions. 
  Also, you need to make sure that the documentation addresses all 
  the questions and is easy to understand.
  ''',
  allow_delegation = False,
  verbose = True
)

draft_data_answer = Task(
  description = '''Very important customer {customer} reached out to you 
  with the following question:

{问题}


  Your task is to provide the best answer to all the points in the question 
  using all available information and not making any assumprions. 
  If you don't have enough information to answer the question, just say 
  that you don't know.''',
  expected_output = '''The detailed informative answer to the customer's 
  question that addresses all the point mentioned. 
  Make sure that answer is complete and stict to facts 
  (without any additional information not based on the factual data)''',
  tools = [documentation_directory_tool, file_read_tool], 
  agent = data_support_agent
)

answer_review = Task(
  description = '''
  Review the draft answer provided by the support agent.
  Ensure that the it fully answers all the questions mentioned 
  in the initial inquiry. 
  Make sure that the answer is consistent and doesn't include any assumptions.
  ''',
  expected_output = '''
  The final version of the answer in markdown format that can be shared 
  with the customer. 
  The answer should fully address all the questions, be consistent 
  and follow our professional but informal tone of voice. 
  We are very chill and friendly company, so don't forget to include 
  all the polite phrases.
  ''',
  tools = [], 
  agent = qa_support_agent
)

qna_crew = Crew(
  agents = [data_support_agent, qa_support_agent],
  tasks = [draft_data_answer,  answer_review],
  verbose = 2,
  memory = False # don't work with Llama
)

让我们看看它在实践中是如何运作的。

result = qna_crew.kickoff(
  {'customer': "Max", 
   'question': """Hey team, I hope you're doing well. I need to find 
    the numbers before our CEO presentation tomorrow, so I will really 
    appreciate your help.
    I need to calculate the number of sessions from our Windows users in 2023\. I've tried to find the table with such data in our data warehouse, but wasn't able to. 
    Do you have any ideas whether we store the needed data somewhere, 
    so that I can query it? """
  }
)

我们得到了一份有礼貌、实用和有帮助的答复。这真的很棒。

**Hello Max,**

Thank you for reaching out with your question! I'm happy to help you 
find the number of sessions from Windows users in 2023\. 
After reviewing our documentation, I found that we do store data 
related to sessions and users in our ecommerce database, specifically in 
the `ecommerce_db.sessions` table.

To answer your question, I can provide you with a step-by-step guide 
on how to query this table using SQL. First, you can use the `session_id` 
column along with the `os` column filtering for "Windows" and 
the `action_date` column filtering for dates in 2023\. 
Then, you can group the results by `os` using the `GROUP BY` clause 
to count the number of sessions that meet these conditions.

Here's a sample SQL query that should give you the desired output:

```sql

SELECT COUNT(*)

FROM ecommerce_db.sessions

WHERE os = 'Windows'

AND action_date BETWEEN '2023-01-01' AND '2023-12-31'

GROUP BY os;

```py

This query will return the total number of sessions from Windows 
users in 2023\. I hope this helps! If you have any further questions or 
need more assistance, please don't hesitate to ask.

让我们稍微复杂化任务。假设我们不仅可以得到关于我们的数据的问题,还可以得到关于我们的工具(ClickHouse)的问题。因此,我们的团队中将有另一个代理人——ClickHouse 专家。为了给我们的 CH 代理人一些知识,我将与它分享一个文档网站。

from crewai_tools import ScrapeWebsiteTool, WebsiteSearchTool
ch_documenation_tool = ScrapeWebsiteTool(
  'https://clickhouse.com/docs/en/guides/creating-tables')

如果您需要处理长篇文档,可以尝试使用 RAG(检索增强生成)—WebsiteSearchTool。它将计算嵌入并将其存储在 ChromaDB 中。在我们的情况下,我们将坚持使用简单的网站爬虫工具。

现在我们有了两位主题专家,我们需要决定谁将负责回答问题。因此,现在是使用分层过程并添加经理来协调所有任务的时候了。

CrewAI 提供了经理实现,因此我们只需要指定 LLM 模型。我选择了 GPT-4o。

from langchain_openai import ChatOpenAI
from crewai import Process

complext_qna_crew = Crew(
  agents = [ch_support_agent, data_support_agent, qa_support_agent],
  tasks = [draft_ch_answer, draft_data_answer, answer_review],
  verbose = 2,
  manager_llm = ChatOpenAI(model='gpt-4o', temperature=0),  
  process = Process.hierarchical,  
  memory = False 
)

在这一点上,我不得不再次从 Llama 3 切换到 OpenAI 模型以运行分层过程,因为使用 Llama 对我来说没有效果(类似于此问题)。

现在,我们可以尝试使用不同类型的问题来测试我们的新团队(无论是与我们的数据还是 ClickHouse 数据库相关的问题)。

ch_result = complext_qna_crew.kickoff(
  {'customer': "Maria", 
   'question': """Good morning, team. I'm using ClickHouse to calculate 
   the number of customers. 
   Could you please remind whether there's an option to add totals 
   in ClickHouse?"""
  }
)

doc_result = complext_qna_crew.kickoff(
  {'customer': "Max", 
   'question': """Hey team, I hope you're doing well. I need to find 
    the numbers before our CEO presentation tomorrow, so I will really 
    appreciate your help.
    I need to calculate the number of sessions from our Windows users 
    in 2023\. I've tried to find the table with such data 
    in our data warehouse, but wasn't able to. 
    Do you have any ideas whether we store the needed data somewhere, 
    so that I can query it. """
  }
)

如果我们查看最终的答案和日志(我在这里省略了它们,因为它们相当冗长,但 你可以在 GitHub 找到它们和完整日志),我们会看到管理器能够正确地协调并将任务委派给拥有相关知识的同事,以解答客户的问题。对于第一个(与 ClickHouse 相关)问题,我们得到了一个详细的答案,包含了示例以及使用WITH TOTALS功能的潜在影响。对于数据相关的问题,模型返回的信息大致与我们上面看到的相同。

因此,我们构建了一个可以根据文档回答各种问题的团队,无论是来自本地文件还是网站。我认为这是一个非常棒的结果。

你可以在GitHub上找到所有的代码。

总结

在本文中,我们探讨了如何使用 CrewAI 多智能体框架创建一个基于表格撰写文档并回答相关问题的解决方案。

鉴于我们所使用的广泛功能,现在是总结这个框架的优缺点的时候了。

总体来说,我认为CrewAI是一个非常有用的多智能体系统框架:

  • 它简单明了,你可以快速构建你的第一个原型。

  • 它的灵活性使得可以解决相当复杂的商业问题。

  • 它鼓励像角色扮演这样的良好实践。

  • 它提供了许多开箱即用的实用工具,如 RAG 和网站解析器。

  • 支持不同类型的记忆增强了智能体之间的协作。

  • 内置的防护措施有助于防止智能体陷入重复的循环中。

然而,仍有一些可以改进的地方:

  • 尽管该框架简单且易于使用,但它的可定制性不强。例如,目前你无法创建自己的 LLM 管理器来协调这些过程。

  • 有时,从文档中获取完整详细的信息是相当具有挑战性的。例如,显然 CrewAI 实现了一些防护措施,以防止重复的函数调用,但文档并没有完全解释它是如何工作的。

  • 另一个改进的方向是透明度。我喜欢了解框架是如何在背后工作的。例如,在 Langchain 中,你可以使用langchain.debug = True来查看所有 LLM 调用。然而,我还没有弄清楚如何在 CrewAI 中获得相同级别的详细信息。

  • 对本地模型的全面支持将是一个很大的补充,因为当前的实现要么缺少某些功能,要么难以正确运行。

LLM 领域和工具正在迅速发展,因此我希望在不久的将来会看到许多进展。

非常感谢你阅读这篇文章。希望这篇文章对你有所启发。如果你有任何后续问题或评论,请在评论区留下。

参考资料

本文灵感来源于 DeepLearning.AI 的“带 CrewAI 的多 AI 智能体系统”短期课程。

使用 Keras 3 进行多框架 AI/ML 开发

原文:towardsdatascience.com/multi-framework-ai-ml-development-with-keras-3-cf7be29eb23d?source=collection_archive---------3-----------------------#2024-06-16

欢迎 Keras 的回归

Chaim RandTowards Data Science Chaim Rand

·发布于 Towards Data Science ·14 分钟阅读·2024 年 6 月 16 日

--

照片由 Jose Rueda 提供,来自 Unsplash

作者:

Keras 回来了!!Keras 最初于 2015 年发布,作为一个高阶 Python 库,用于训练机器学习模型。由于其简洁清晰的 API,Keras 迅速获得了人气。与当时那些笨重且不灵活的机器学习框架相比,Keras 降低了许多现有机器学习开发者(包括作者本人)的入门门槛。然而,在某个时刻,Keras 的使用几乎与 TensorFlow 开发同义。因此,当开发者开始转向其他框架时,Keras 的相对受欢迎程度开始下降。但现在,经过“完全重写”后,Keras 已经回归。凭借其崭新的引擎和对多后端支持的重新承诺,它有望恢复昔日的辉煌。

在这篇文章中,我们将重新审视 Keras,并评估其在当前 AI/ML 开发时代的价值。我们将通过示例展示其易用性,并指出其不足之处。重要的是,这篇文章并非旨在支持或反对采用 Keras(或任何其他框架、库、服务等)。像往常一样,最适合你项目开发的决策将取决于许多细节,其中很多超出了这篇文章的范围。

Google 最近发布的开源 NLP 模型家族 Gemma,以及 Keras 3 作为其 API 核心组件的加入,为我们提供了评估 Keras 优点的机会,也可能成为其复兴的一个绝佳机会。

为什么使用 Keras 3?

在我们看来,Keras 3 最有价值的特点是其多框架支持。这可能会让一些读者感到惊讶,他们可能会记得 Keras 的独特性在于其用户体验。Keras 3 自我宣传为“简单”、“灵活”,并且是“为人类设计,而不是为机器设计”。的确,Keras 的早期成功和急剧的流行上升,正是得益于其优秀的用户体验。但现在是 2024 年,许多高级深度学习 API 都提供了“降低认知负担”的功能。在我们看来,虽然用户体验非常好,但它已经不足以成为考虑选择 Keras 而不是其他替代框架的充足动因。而它的多框架支持则是。

多框架支持的优点

Keras 3 支持多个后端来训练和运行其模型。在本文撰写时,这些后端包括JAXTensorFlowPyTorchKeras 3 公告很好地解释了这一特性的优势。我们将在此基础上扩展并加入一些自己的观点。

避免选择 AI/ML 框架的难题: 选择一个 AI/ML 框架可能是作为 ML 开发者你需要做出的最重要的决定之一,同时也是最困难的决定之一。这个决策需要考虑的因素有很多,包括用户体验、API 覆盖范围、可编程性、可调试性、支持的输入数据的格式和类型、与开发流程中其他组件的兼容性(例如,模型部署阶段可能施加的限制),以及最重要的——运行时性能。正如我们在之前的多篇文章中讨论的那样(例如,这里),AI/ML 模型开发可能非常昂贵,选择框架所带来的哪怕是最小的性能提升,也会对整体成本产生重大影响。事实上,在很多情况下,这可能值得你为将模型和代码迁移到不同的框架,甚至支持多个框架而付出额外的开销。

问题在于,在你开始开发之前,几乎不可能知道哪个框架对你的模型最为优化。而且,即使你已经选择了一个框架,你也会希望紧跟所有框架的演变和发展,并持续评估改进模型和/或降低开发成本的潜在机会。AI/ML 开发的环境非常动态,优化和增强功能正在持续不断地设计和开发中。你不希望落后于人。

Keras 3 通过允许你在不承诺底层后端的情况下开发模型,解决了框架选择问题。通过在多个框架后端之间切换的选项,你可以专注于模型定义,完成后再选择最适合你需求的后端。即使随着机器学习项目的特性变化或支持的框架发展,Keras 3 也能让你轻松评估更换后端的影响。

通俗地说,你可以说 Keras 3 帮助人类避免做他们最讨厌做的事情之一——做出决定并坚持下去。但言归正传,使用 Keras 3 进行 AI/ML 模型开发,确实能够帮助你避免选择并被迫困在一个次优的框架中。

享受各个世界的最佳优势: PyTorch、TensorFlow 和 JAX 各自具有独特的优势和差异化特性。例如,JAX 支持即时编译(JIT),在这种模式下,模型操作符会被转换成中间计算图,然后编译成专门针对底层硬件的机器码。对于许多模型来说,这通常会显著提升运行时性能。另一方面,PyTorch 通常以立即执行操作符(即“急切执行”)的方式使用,通常被认为:具有最符合 Python 风格的接口、最容易调试,并提供最佳的整体用户体验。通过使用 Keras 3,你可以享受两者的最佳优点。在初始模型开发和调试过程中,你可以将后端设置为 PyTorch,而在生产模式下进行训练时,可以切换到 JAX 以获得最佳性能。

与最大数量的 AI 加速器和运行时环境的兼容性: 正如我们之前所讨论的(例如,在这里),我们的目标是与尽可能多的 AI 加速器和运行时环境兼容。在 AI 机器容量受限的时代,能够在不同机器类型之间切换是一个巨大的优势。当你使用 Keras 3 及其多后端支持进行开发时,你将自动增加可以训练和运行你模型的平台数量。例如,虽然你可能最习惯在 GPU 上运行 PyTorch,但只需将后端更改为 JAX,你也可以将模型配置为在 Google Cloud TPUs 上运行(——不过这可能取决于模型的具体细节)。

增加模型的采用率: 如果你的目标是让其他 AI/ML 团队使用你的模型,那么通过支持多个框架,你将增加潜在的受众群体。由于各种原因,一些团队可能会局限于特定的 ML 框架。通过在 Keras 中交付你的模型,你消除了采用的障碍。一个很好的例子是最近发布的 Google 的 Gemma 模型,我们将在下面更详细地讨论。

将数据输入管道与模型执行解耦: 一些框架鼓励使用某些数据存储格式和/或数据加载实践。一个经典的例子是 TensorFlow 的 TFRecord 数据格式,用于存储通常以 .tfrecord 文件形式存储的二进制记录序列。尽管 TensorFlow 原生支持解析和处理存储在 TFRecord 文件中的数据,但你可能会发现将它们输入到 PyTorch 的训练循环中会有些困难。PyTorch 训练的一个更合适的格式可能是 WebDataset。但是,训练数据的创建可能是一个漫长的过程,而将其维护为多种格式可能会变得极其昂贵。因此,训练数据的存储和维护方式可能会使团队不愿考虑使用其他框架。

Keras 3 帮助团队克服这一障碍,通过完全解耦数据输入管道和训练循环。你可以在 PyTorch、TensorFlow、Numpy、Keras 和其他库中定义输入数据管道,而无需考虑将用于训练循环的后端。使用 Keras 3 后,将训练数据存储在 TFRecord 文件中不再是采用 PyTorch 作为后端的障碍。

多框架支持的缺点

与市场上其他任何新软件解决方案一样,了解 Keras 3 的潜在缺点非常重要。软件开发中的一个通用准则是,越是在软件堆栈的上层,你对应用程序行为和性能的控制就越少。在 AI/ML 领域,成功的程度通常取决于模型超参数、初始化设置、环境配置等的精确调节,这样的控制可能至关重要。以下是一些需要考虑的潜在缺点:

运行时性能的潜在下降: 使用高级 Keras API 而不是直接使用框架 API,可能会对优化运行时性能造成一定的限制。在我们关于分析和优化 PyTorch 模型性能的系列文章中,我们展示了各种提高训练速度的工具和技巧。有时这些技巧需要直接、不加中介地使用 PyTorch 的 API。例如,Keras 的 API 目前对PyTorch 的 JIT 编译选项的支持非常有限(通过jit_compile设置)。另一个例子是 PyTorch 对缩放点积注意力的内置支持,而 Keras 层面(截至本文撰写时)并不支持这一功能。

跨框架支持的局限性: 虽然 Keras 的跨框架支持非常广泛,但你可能会发现它并不是包罗万象的。例如,一个覆盖的空白(截至本文撰写时)是分布式训练。虽然 Keras 引入了Keras 分布式 API来支持所有后端的数据和模型并行性,但目前仅对 JAX 后端实现了该功能。要在使用其他后端时运行分布式训练,你需要回退到相关框架的标准分布式 API(例如,PyTorch 的分布式数据并行 API)。

维护跨框架兼容性的开销: Keras 3 支持多种预构建模型,您可以复用它们(例如,这里)。然而,您可能不可避免地想要引入自己的自定义功能。虽然 Keras 3 支持自定义模型层、指标、训练循环等,但您需要小心避免破坏跨框架兼容性。例如,如果您使用 Keras 的后端无关 API(keras.ops)创建了一个自定义层,您可以放心,它会保留多后端支持。但是,有时您可能会选择依赖于框架特定的操作。在这种情况下,保持跨框架兼容性将需要为每个框架提供专门的实现,并根据所使用的后端进行适当的条件编程。目前自定义训练步骤训练循环的方法是框架特定的,这意味着它们同样需要为每个后端提供专门的实现,以保持跨框架兼容性。因此,随着模型复杂性的增加,维持这一独特能力所需的开销也可能增加。

我们仅列出了 Keras 3 及其多后端支持的一些潜在缺点,您可能还会遇到其他问题。尽管多框架的提供确实具有吸引力,但其采用并不一定是无代价的。借用统计推断领域著名定理的名字,可以说,当涉及到选择 AI/ML 开发方法时,“没有免费的午餐”。

Keras 3 的实践 — 一个玩具示例

正如我们在许多最近的文章中提到的,我们将定义的玩具模型是一个视觉变换器(ViT)支持的分类模型。我们将依赖于此 Keras教程中的参考实现。我们已根据ViT-Base架构(约 8600 万个参数)配置了我们的模型,将mixed_precision策略设置为使用bfloat16,并定义了一个带有随机输入数据的PyTorch 数据加载器

以下区块包含了配置设置,并紧接着是核心 ViT 模型组件的定义:

import os

# choose backend
backend = 'jax' # 'torch'
os.environ["KERAS_BACKEND"] = backend

import keras
from keras import layers
from keras import ops

# set mixed precision policy
keras.mixed_precision.set_global_policy('mixed_bfloat16')

# use ViT Base settings
num_classes = 1000
image_size = 224
input_shape = (image_size, image_size, 3)
patch_size = 16  # Size of the patches to be extract from the input images
num_patches = (image_size // patch_size) ** 2
projection_dim = 768
num_heads = 12
transformer_units = [
    projection_dim * 4,
    projection_dim,
]  # Size of the transformer layers
transformer_layers = 12

# set training hyperparams
batch_size = 128
multi_worker = False # toggle to use multiple data loader workers
preproc_workers = 0 if 'jax' else 16

# ViT model components:
# ---------------------

def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=keras.activations.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

class Patches(layers.Layer):
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def call(self, images):
        input_shape = ops.shape(images)
        batch_size = input_shape[0]
        height = input_shape[1]
        width = input_shape[2]
        channels = input_shape[3]
        num_patches_h = height // self.patch_size
        num_patches_w = width // self.patch_size
        patches = keras.ops.image.extract_patches(images, size=self.patch_size)
        patches = ops.reshape(
            patches,
            (
                batch_size,
                num_patches_h * num_patches_w,
                self.patch_size * self.patch_size * channels,
            ),
        )
        return patches

class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = ops.expand_dims(
            ops.arange(start=0, stop=self.num_patches, step=1), axis=0
        )
        projected_patches = self.projection(patch)
        encoded = projected_patches + self.position_embedding(positions)
        return encoded

使用核心组件,我们定义了一个 ViT 支持的 Keras 模型:

# the attention layer we will use in our ViT classifier
attention_layer = layers.MultiHeadAttention

def create_vit_classifier():
    inputs = keras.Input(shape=input_shape)
    # Create patches.
    patches = Patches(patch_size)(inputs)
    # Encode patches.
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    # Create multiple layers of the Transformer block.
    for _ in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Create a multi-head attention layer.
        attention_output = attention_layer(
            num_heads=num_heads, key_dim=projection_dim//num_heads, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, encoded_patches])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # Skip connection 2.
        encoded_patches = layers.Add()([x3, x2])

    # Create a [batch_size, projection_dim] tensor.
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.GlobalAveragePooling1D()(representation)
    representation = layers.Dropout(0.5)(representation)

    # Classify outputs.
    logits = layers.Dense(num_classes)(representation)

    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=logits)
    return model

# create the ViT model
model = create_vit_classifier()
model.summary()

在下一个区块中,我们定义了优化器、损失函数和数据集。

model.compile(
    optimizer=keras.optimizers.SGD(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )

def get_data_loader(batch_size):
    import torch
    from torch.utils.data import Dataset, DataLoader

    # create dataset of random image and label data
    class FakeDataset(Dataset):
        def __len__(self):
            return 1000000

        def __getitem__(self, index):
            rand_image = torch.randn([224, 224, 3], dtype=torch.float32)
            label = torch.tensor(data=[index % 1000], dtype=torch.int64)
            return rand_image, label

    ds = FakeDataset()
    dl = DataLoader(
        ds,
        batch_size=batch_size,
        num_workers=preproc_workers if multi_worker else 0,
        pin_memory=True
    )
    return dl

dl = get_data_loader(batch_size)

最后,我们使用 Keras 的Model.fit()函数开始训练:

model.fit(
    dl,
    batch_size=batch_size,
    epochs=1
)

我们在 Google Cloud Platform (GCP) g2-standard-16 虚拟机(配备单个 NVIDIA L4 GPU)上运行了上述脚本,并使用了专门的 深度学习虚拟机镜像(common-cu121-v20240514-ubuntu-2204-py310),并安装了 PyTorch(2.3.0)、JAX(0.4.28)、Keras(3.3.3)和 KerasCV(0.9.0)。请参阅 Keras 官方文档获取完整的安装说明。请注意,我们手动修改了由 Keras 进度条报告的步骤时间格式:

 formatted += f" {time_per_unit:.3f}s/{unit_name}"

使用 backend 标志,我们能够轻松地在 Keras 支持的后端之间切换,并比较每个后端的运行时性能。例如,当配置 PyTorch 数据加载器 并设置为 0 个工作线程时,我们发现 JAX 后端比 PyTorch 快约 24%。当将工作线程数设置为 16 时,这一差距缩小至约 12%。

自定义注意力层

我们现在定义了一个自定义注意力层,用 PyTorch 的 闪电注意力 实现替换了 Keras 的默认注意力 计算。请注意,这仅在后端设置为 torch 时有效。

class MyAttention(layers.MultiHeadAttention):
    def _compute_attention(
            self, query, key, value, attention_mask=None, training=None
    ):
        from torch.nn.functional import scaled_dot_product_attention
        query = ops.multiply(
            query, ops.cast(self._inverse_sqrt_key_dim, query.dtype))
        return scaled_dot_product_attention(
            query.transpose(1,2),
            key.transpose(1,2),
            value.transpose(1,2),
            dropout_p=self._dropout if training else 0.
            ).transpose(1,2), None

attention_layer = MyAttention

我们实验的结果总结在下面的表格中。请记住,相对性能结果可能会根据模型的细节和运行时环境有很大差异。

ViT 运行时(作者)

使用我们的自定义注意力层时,JAX 和 PyTorch 后端之间的差距几乎消失。这凸显了多后端解决方案的使用可能会以牺牲某些框架特有的优化为代价(在我们的示例中,是 PyTorch 的 SDPA)。

Keras 3 in Gemma

Gemma 是 Google 最近发布的一系列轻量级 开源模型。Keras 3 在 Gemma 发布中扮演了重要角色(例如,参见 此处),其多框架支持使得 Gemma 能够自动服务于各类 AI/ML 开发者——无论是 PyTorch、TensorFlow 还是 Jax。有关 Gemma API 的详细信息,请参阅官方 文档KerasNLP

以下代码大致基于官方的Gemma 微调教程。为了运行此脚本,请遵循必要的设置说明

import os
backend = 'jax' #'torch'
os.environ["KERAS_BACKEND"] = backend

num_batches = 1000
batch_size = 4 if backend == 'jax' else 2

# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
os.environ["KAGGLE_USERNAME"]="chaimrand"
os.environ["KAGGLE_KEY"]="29abebb28f899a81ca48bec1fb97faf1"
import keras
import keras_nlp
keras.mixed_precision.set_global_policy('mixed_bfloat16')

import json
data = []
with open("databricks-dolly-15k.jsonl") as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        # Format the entire example as a single string.
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        data.append(template.format(**features))

# Only use 1000 training batches, to keep it fast.
data = data[:num_batches*batch_size]
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)

gemma_lm.summary()
# Limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512

gemma_lm.compile(
   loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
   optimizer=keras.optimizers.SGD(learning_rate=5e-5),
   weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=batch_size)

当在上述相同的 GCP 环境中运行脚本时,我们发现使用 JAX 后端(每秒 6.87 个样本)与使用 PyTorch 后端(每秒 3.01 个样本)之间的运行时性能存在显著(且令人惊讶的)差异。这部分原因在于,JAX 后端允许将训练批次大小加倍。对这种差异原因的深入探讨超出了本文的范围。

正如我们之前的示例所示,我们通过在脚本顶部添加以下矩阵乘法操作的配置来展示了一种优化 PyTorch 运行时的方法:

import torch
torch.set_float32_matmul_precision('high')

这个简单的变化在使用 PyTorch 后端时带来了 29% 的性能提升。再次,我们可以看到应用框架特定优化的影响。实验结果总结在下面的表格中。

Gemma 微调运行时(作者提供)

结论

我们的演示表明,坚持使用与后端无关的 Keras 代码可能会带来显著的运行时性能损失。在每个示例中,我们都看到了一个简单的框架特定优化如何对我们选择的后端的相对性能产生显著影响。同时,我们讨论的多框架 AI/ML 开发的论点也相当有说服力。

如果你选择采用 Keras 作为开发框架,你可能需要考虑以一种设计代码的方式来实现和评估框架特定优化的机制。你还可以考虑在项目的早期阶段使用 Keras,并随着项目的成熟,针对揭示出的最合适的后端进行优化。

摘要

在本文中,我们探讨了新发布的 Keras 3 版本。Keras 3 不再是 TensorFlow 的附属品,而是提供了框架无关的 AI/ML 模型开发能力。正如我们讨论的,这种能力有几个显著的优势。然而,正如 AI 开发领域中常见的情况一样,“天下没有免费的午餐”——这种增加的抽象层次可能意味着我们对代码内部工作原理的控制力减弱,这可能导致训练速度变慢和成本增加。最佳的解决方案可能是结合使用 Keras 及其多框架支持,并采用专门的机制来纳入框架特定的修改。

重要的是,Keras 3 在您的项目中的适用性以及所需投资的性价比分析,将在很大程度上取决于多种因素,包括:目标受众、模型部署过程、项目时间表等。请将此帖子视为您深入探索的简要介绍。

多 GPU 微调 Llama 3.1 70B 模型,采用 FSDP 和 QLoRA

原文:towardsdatascience.com/multi-gpu-fine-tuning-for-llama-3-1-70b-with-fsdp-and-qlora-67a8a5b4f0d6?source=collection_archive---------1-----------------------#2024-08-08

使用仅有的 2 块 24 GB GPU 和大量 CPU 内存,你可以做什么

Benjamin MarieTowards Data Science Benjamin Marie

·发表于Towards Data Science ·8 分钟阅读·2024 年 8 月 8 日

--

使用 DALL-E 生成

微调最多 35B 参数的大型语言模型(LLMs)相对简单且廉价,因为只需要一块消费级 GPU 即可完成。理论上,使用单块消费级 GPU 微调更大的模型并非不可能,因为我们可以将模型的部分内容转移到 CPU 内存中。然而,即使使用高端 CPU,这个过程将会极其缓慢。

使用多个 GPU 是保持微调速度足够快的唯一替代方案。配置为 2 块 24 GB GPU 的系统提供了很多可能性。48 GB 的 GPU 内存足以微调 70B 模型,如 Llama 3 70B 和 Qwen2 72B。

在本文中,我将解释如何仅使用两块 GPU,通过 FSDP 和 QLoRA 微调 70B 参数的 LLM。

我首先解释什么是 FSDP,然后我们将看到如何修改标准的 QLoRA 微调代码,以便在多个 GPU 上运行。在实验和演示中,我使用了 Llama 3.1 70B,但对其他 LLM 也能起到类似的效果。硬件方面,我依赖于RunPod (推荐链接)提供的 2 块 RTX 3090 GPU。使用 2 块 RTX 4090 GPU 会更快,但成本也更高。

我还制作了一个笔记本,实施了本文中描述的代码。它可以在这里获取:

多头注意力 — 形式化解释与定义

原文:towardsdatascience.com/multi-head-attention-formally-explained-and-defined-89dc70ce84bd?source=collection_archive---------9-----------------------#2024-06-11

多头注意力的全面且详细的形式化

Jean Meunier-PionTowards Data Science Jean Meunier-Pion

·发表于Towards Data Science ·阅读时长 9 分钟·2024 年 6 月 11 日

--

带有多个头部的机器人,正在集中注意力 — 作者提供的图像(由 AI 生成,Microsoft Copilot)

多头注意力在变换器(Transformers)中扮演着至关重要的角色,而变换器已经彻底改变了自然语言处理(NLP)。理解这个机制是了解当前最先进语言模型的必要步骤。

尽管这个概念几年前已经提出,并且此后被广泛使用和讨论,但模糊的符号表示和缺乏正式定义,导致新手难以迅速揭开多头注意力机制的谜团。

本文的目标是提供一个全面且清晰的多头注意力形式化定义,使这一机制更容易理解。

由于要更好地理解新概念,关键是要自己动手实践,本文附带了几个练习/问题(以及解答),以帮助你准确理解多头注意力机制的作用。

免责声明: 在开始多头注意力的定义和解释之前,请注意,由于缺乏对 LaTeX 的支持,我不得不将方程转化为图像,以显示不同的数学对象。

输入

多头自注意力 — 手动实现

原文:towardsdatascience.com/multi-headed-self-attention-by-hand-d2ce1ae031db?source=collection_archive---------4-----------------------#2024-07-12

手动计算:现代人工智能的基石

Daniel WarfieldTowards Data Science Daniel Warfield

·发表于Towards Data Science ·6 分钟阅读·2024 年 7 月 12 日

--

“焦点”由 Daniel Warfield 使用 MidJourney 创作。所有图片除非另有说明,否则均为作者提供。文章最初发布于直观且详尽的解析

多头自注意力可能是机器学习中最重要的架构范式。本总结介绍了多头自注意力中的所有关键数学运算,使您能够在基础层面理解其内部工作原理。如果您想深入了解该主题背后的直觉,可以查看 IAEE 文章。

## Transformer — 直观且详尽的解析

探索现代机器学习浪潮:一步步拆解 Transformer

towardsdatascience.com

第一步:定义输入

多头自注意力(MHSA)在多种上下文中都有应用,每种应用可能会以不同的方式格式化输入。在自然语言处理的背景下,通常会使用词向量嵌入,并配以位置编码,来计算一个表示每个单词的向量。通常,无论数据类型如何,多头自注意力都会期望一系列向量,其中每个向量代表某种含义。

没有检索模型构建 RAG 是一个可怕的错误

原文:towardsdatascience.com/multi-rep-colbert-retrieval-models-for-rags-fe05381b8819?source=collection_archive---------0-----------------------#2024-09-17

这是我最喜欢的两种技术——一种更快,另一种更精确。

Thuwarakesh MurallieTowards Data Science Thuwarakesh Murallie

·发表于Towards Data Science ·阅读时间:9 分钟·2024 年 9 月 17 日

--

图片来自Alexander Grey

我构建 RAG 应用;这很有趣!

但我构建的应用在生产环境中表现不佳。它们是有前景的原型,但从未上线!

问题几乎总是出在检索上。来吧,这是 RAG 的核心。如果没有这个,我们还能构建什么呢?

直到我为更快或更好的检索对文档进行索引

索引帮助我们设计能够更快检索数据的解决方案。它显著减少了延迟,提升了整体应用体验。几乎每个我们构建的应用都使用了索引。它与 LLMs 或 RAGs 无关。

几乎所有数据库都支持索引。例如,Postgres可以进行B-Tree, GiST, SP-GiST, BRIN, GIN 和 Hash类型的索引。这是一个足够长的列表,可能会成为未来一篇独立文章的主题。

在这篇文章中,我将讨论我经常使用的流行索引策略,以便更好地进行文档检索。然而,这些技术是特定于 RAG 应用的。你很快就会明白为什么。

我常用的两种索引技术是多重表示和 ColBERT。这并不是我们拥有的唯一方法。而且…

HuggingFace 中的多选问答

原文:towardsdatascience.com/multichoice-question-answering-in-huggingface-96f61eb88d18?source=collection_archive---------9-----------------------#2024-02-07

揭示问答任务的强大能力

Mina GhashamiTowards Data Science Mina Ghashami

·发表于Towards Data Science ·阅读时长 15 分钟·2024 年 2 月 7 日

--

图片来自 unsplash.com

自然语言处理技术在问答(QA)任务中展现出了巨大的能力。在这篇文章中,我们利用 HuggingFace 库来解决一个多选问答挑战。

具体来说,我们在一个多选问题数据集上微调一个预训练的 BERT 模型,使用 Trainer API。这使得我们能够将预训练 BERT 的强大双向表示适配到我们的目标任务中。通过添加分类头,模型可以学习帮助从多个选项中选择正确答案的文本模式。然后,我们通过在保留的测试集上使用准确率来评估模型性能

Transformer 框架允许快速实验不同的模型架构、分词器选项和训练方法。在本次分析中,我们将展示一个逐步的过程,用于通过 HuggingFace Transformers 实现多选 QA 任务的竞争性性能。

第一步:安装并导入库

第一步是安装和导入相关库。要安装库,请使用以下pip install命令:

!pip install datasets transformers[torch] --quiet

多层感知器,解释:带有迷你二维数据集的视觉指南

原文:towardsdatascience.com/multilayer-perceptron-explained-a-visual-guide-with-mini-2d-dataset-0ae8100c5d1c?source=collection_archive---------1-----------------------#2024-10-25

分类算法

解剖一个小型神经网络的数学(带视觉展示)

Samy BaladramTowards Data Science Samy Baladram

·发布于 Towards Data Science ·13 分钟阅读·2024 年 10 月 25 日

--

⛳️ 更多 [分类算法](https://medium.com/@samybaladram/list/classification-algorithms-b3586f0a772c),解释:· 虚拟分类器 · K 近邻分类器 · 伯努利朴素贝叶斯 · 高斯朴素贝叶斯 · 决策树分类器 · 逻辑回归 · 支持向量分类器 ▶ 多层感知器

是否觉得神经网络无处不在?它们出现在新闻中、手机里,甚至出现在你的社交媒体动态中。但说实话——我们大多数人根本不清楚它们是如何运作的。那些复杂的数学和像“反向传播”这样的术语呢?

这是一个思考:如果我们将事情简化呢?让我们探索一个多层感知机(MLP)——最基本类型的神经网络——来使用一个小型网络分类一个简单的二维数据集,只使用少量的数据点。

通过清晰的视觉效果和逐步的解释,你将看到数学变得生动,观察数字和方程式如何在网络中流动,以及学习是如何真正发生的!

所有视觉效果:作者使用 Canva Pro 创建。优化为手机端显示;在桌面端可能显示过大。

定义

多层感知机(MLP)是一种神经网络类型,它使用连接的节点层来学习模式。它之所以得名,是因为它有多个层——通常包括一个输入层、一个或多个中间(隐藏)层和一个输出层。

每个节点都与下一层的所有节点相连。当网络学习时,它会根据训练示例调整这些连接的强度。例如,如果某些连接导致正确的预测,它们会变得更强。如果它们导致错误预测,它们会变得更弱。

通过这种例子学习的方式,帮助网络识别模式,并对它从未见过的新情况做出预测。

多层感知机(MLPs)被认为是神经网络和深度学习领域的基础,因为它们可以处理一些简单方法无法解决的复杂问题。

📊 使用的数据集

为了理解多层感知机(MLPs)的工作原理,我们从一个简单的例子开始:一个只有几个样本的迷你二维数据集。我们将使用相同的数据集来保持简洁,这个数据集来自我们之前的文章。

列:温度(0-3)、湿度(0-3)、打高尔夫(是/否)。训练数据集有 2 个维度和 8 个样本。

在直接进入训练之前,让我们先试着理解构成神经网络的关键部分及其如何协同工作。

步骤 0:网络结构

首先,我们来看看网络的各个部分:

节点(神经元)

我们从神经网络的基本结构开始。这个结构由许多个单独的单位组成,称为节点或神经元。

这个神经网络有 8 个节点。

这些节点被组织成一个个层次结构来协同工作:

输入层

输入层是我们开始的地方。它接收我们的原始数据,节点的数量与特征的数量相匹配。

输入层有 2 个节点,每个特征对应一个节点。

隐藏层

接下来是隐藏层。我们可以有一个或多个这样的层,且可以选择每层的节点数量。通常,我们在层数越深时,每层的节点数量会越来越少。

这个神经网络有 2 个隐藏层,分别包含 3 个节点和 2 个节点。

输出层

最后一层给出了我们的最终答案。输出层中节点的数量取决于我们的任务:对于二分类或回归问题,我们可能只有一个输出节点,而对于多类问题,我们则为每个类别设置一个节点。

这个神经网络的输出层只有 1 个节点(因为是二分类)。

权重

节点之间通过权重相互连接——权重是控制每个信息片段重要性的数字。每个节点之间的连接都有自己的权重。这意味着我们需要大量的权重:一层中的每个节点都连接到下一层的每个节点。

这个神经网络总共有 14 个权重。

偏置

除了权重,每个节点还拥有一个偏置——一个额外的数字,帮助它做出更好的决策。权重控制节点之间的连接,而偏置则帮助每个节点调整其输出。

这个神经网络有 6 个偏置值。

神经网络

总结来说,我们将使用并训练这个神经网络:

我们的网络由 4 层组成:1 个输入层(2 个节点)、2 个隐藏层(3 个节点和 2 个节点),以及 1 个输出层(1 个节点)。这形成了一个 2–3–2–1 的架构。

让我们看看这个新图示,它展示了我们的网络从上到下的结构。我已经更新了它,以便更容易理解数学过程:信息从顶部的节点开始,沿着各层向下流动,直到到达底部的最终答案。

现在我们理解了网络的构建方式,接下来让我们看看信息是如何在其中流动的。这叫做前向传播(forward pass)。

第一步:前向传播

让我们逐步了解网络是如何将输入转换为输出的:

权重初始化

在我们的网络开始学习之前,我们需要为每个权重设置一个初始值。我们选择在-1 到 1 之间的小随机数。使用随机数开始有助于我们的网络在没有任何早期偏好或模式的情况下学习。

所有权重都是从[-0.5, 0.5]范围内随机选择的。

加权和

每个节点以两步处理传入的数据。首先,它将每个输入乘以其权重,并将所有这些数字相加。然后,它再加上一个数字——偏置——以完成计算。偏置本质上是一个恒定输入为 1 的权重

激活函数

每个节点将其加权和通过激活函数进行处理,以生成输出。激活函数通过引入非线性行为,帮助我们的网络学习复杂的模式。

在我们的隐藏层中,我们使用 ReLU 函数(整流线性单元)。ReLU 很简单:如果一个数是正数,它保持不变;如果是负数,它变为零。

层级计算

这个两步过程(加权求和和激活)在每一层中依次进行。每一层的计算帮助我们一步一步地将输入数据转化为最终的预测值。

输出生成

最后一层给出了网络的最终答案。对于我们的是/否分类任务,我们在这一层使用一种特殊的激活函数,叫做 sigmoid

Sigmoid 函数将任何数值转化为 0 到 1 之间的值。这使得它非常适合用于是/否的决策,因为我们可以将输出视为一种概率:越接近 1 意味着越可能是“是”,越接近 0 意味着越可能是“否”。

这个前向传播过程将我们的输入转化为一个 0 到 1 之间的预测值。但是这些预测有多准确呢?接下来,我们将衡量我们的预测与正确答案之间的接近程度。

步骤 2:损失计算

损失函数

为了检查我们的网络表现如何,我们衡量它的预测与正确答案之间的差异。对于二分类问题,我们使用一种叫做 二元交叉熵 的方法,它能告诉我们预测值与真实值之间的偏差。

神经网络中的数学符号

为了提高我们网络的性能,我们需要使用一些数学符号。在继续之前,让我们先定义一下每个符号的含义:

权重和偏置

权重表示为矩阵,偏置表示为向量(或一维矩阵)。括号表示法[1]表示层的编号。

输入、输出、加权和以及激活后的值

节点中的值可以表示为向量,从而形成一致的数学框架。

总结 这些数学符号帮助我们精确地描述网络的运作过程:

让我们看一个图示,展示我们网络中所有数学运算的过程。每一层都有:

  • 连接层的权重(W)和偏置(b

  • 激活前的值(z

  • 激活后的值(a

  • 最终预测值(ŷ)和损失值(L)在末尾

让我们看看每一层到底发生了什么:

第一个隐藏层

· 取输入 x,与权重 W[1] 相乘,加入偏置 b[1] 得到 z[1]

·z[1] 应用 ReLU 函数得到输出 a[1]

第二个隐藏层

·a[1],与权重 W[2] 相乘,加入偏置 b[2] 得到 z[2]

·z[2] 应用 ReLU 函数得到输出 a[2]

输出层

·a[2],与权重 W[3] 相乘,加入偏置 b[3] 得到 z[3]

·z[3] 应用 sigmoid 来获得我们的最终预测 ŷ

现在我们看到网络中的所有数学公式,如何改进这些数字以获得更好的预测?这时反向传播就派上用场了——它向我们展示了如何调整我们的权重和偏差,以减少错误。

第 3 步:反向传播

在我们了解如何改进网络之前,让我们快速回顾一下我们需要的一些数学工具:

导数

为了优化我们的神经网络,我们使用梯度——这是一个与导数密切相关的概念。让我们回顾一下基本的导数规则:

偏导数

让我们澄清常规导数和偏导数之间的区别:

*常规导数 · 用于函数只有一个变量的情况

· 显示当唯一变量变化时函数的变化量

· 写作 df/dx

偏导数

· 用于函数有多个变量的情况

· 显示当一个变量变化时函数的变化量,同时保持其他变量不变(视为常数)

· 写作 ∂f/x

一些偏导数的例子

梯度计算与反向传播

回到我们的神经网络,我们需要确定如何调整每个权重和偏差以最小化误差。我们可以通过一种叫做反向传播的方法来实现,它向我们展示了改变每个值如何影响网络的误差。

由于反向传播是通过网络反向进行的,我们可以将图示翻转过来,看看它是如何工作的。

网络的矩阵规则

由于我们的网络使用矩阵(权重和偏差的组合),我们需要特殊的规则来计算变化如何影响我们的结果。这里有两个关键的矩阵规则。对于向量v, u(大小为 1 × n)和矩阵W, X(大小为n × n):

  1. 求和规则

    ∂(W + X)/∂W = I(单位矩阵,大小为n × n

    ∂(u + v)/∂v = I(单位矩阵,大小为n × n

  2. 矩阵-向量乘积规则

    ∂(vW)/∂W = v

    ∂(vW)/∂v = W

使用这些规则,我们得到:

激活函数的导数 ReLU 的导数对于向量az(大小为 1 × n),其中a = ReLU(z):

a/∂z = diag(z > 0)

创建一个对角矩阵,表示:如果输入为正,则为 1;如果输入为零或负,则为 0。

Sigmoid 的导数

对于a = σ(z),其中 σ 是 sigmoid 函数:

a/∂z = a ⊙ (1 - a)

这直接乘以元素(⊙表示逐位置相乘)。

二进制交叉熵损失函数的导数

对于一个单独的例子,损失函数为 L = -[y log(ŷ) + (1-y) log(1-ŷ)]:

L/∂ŷ = -(y-ŷ) / [ŷ(1-ŷ)]

到目前为止,我们可以将所有的偏导数总结如下:

下图展示了我们迄今为止得到的所有偏导数:

链式法则

在我们的网络中,变化通过多个步骤进行传播:一个权重影响它所在层的输出,这又影响下一层,依此类推,直到最终误差。链式法则告诉我们将这些逐步变化相乘,以找出每个权重和偏置如何影响最终的误差。

误差计算

我们并不是直接计算权重和偏置的导数,而是首先计算层误差∂L/∂(相对于预激活输出的梯度)。这使得后续计算我们应该如何调整早期层的权重和偏置变得更容易。

权重梯度和偏置梯度

使用这些层误差和链式法则,我们可以将权重和偏置的梯度表示为:

梯度向我们展示了网络中每个值如何影响网络的误差。我们随后对这些值进行小幅调整,以帮助网络做出更好的预测。

步骤 4:权重更新

更新权重

一旦我们知道每个权重和偏置如何影响误差(即梯度),我们就通过将这些值调整到与梯度相反的方向来改进我们的网络。这一步步地减少了网络的误差。

学习率与优化

我们并不是一次性做出大的变化,而是进行小而谨慎的调整。我们使用一个叫做学习率(η)的数值来控制每次调整的幅度:

  • 如果η太大:变化过大,可能会导致情况变得更糟

  • 如果η太小:变化很小,改进需要很长时间

这种通过小幅、可控的调整来优化的方式被称为随机梯度下降(SGD)。我们可以将其写作:

η(学习率)的值通常选择较小,通常在 0.1 到 0.0001 之间,以确保学习的稳定性。

我们刚才看到了网络如何通过一个示例进行学习。网络对数据集中每个示例重复这些步骤,在每轮实践中逐步改进。

步骤总结

以下是我们为训练网络在单个示例上所涉及的所有步骤:

扩展到完整的数据集

迭代

我们的网络会对数据集中的每个示例重复这四个步骤——前向传播、损失计算、反向传播和权重更新。遍历所有示例一次称为一次迭代

网络通常需要多次看到所有例子,才能熟练掌握任务,甚至多达 1000 次。每一次训练帮助它更好地学习模式。

批次

我们的网络不是一次从一个例子中学习,而是一次从一小组例子(称为批次)中学习。这有几个好处:

  • 运行更快

  • 学习更好的模式

  • 稳定地改善

在处理批次时,网络会先查看组内的所有例子,然后再做出改变。这比每看一个例子就改变一次值能得到更好的结果。

测试步骤

准备完全训练好的神经网络

训练完成后,我们的网络准备对它从未见过的新例子进行预测。它使用与训练相同的步骤,但只需向前传播通过网络进行预测。

进行预测

处理新数据时:

1. 输入层接收新值

2. 在每一层:

· 通过权重进行乘法运算并加上偏差

· 应用激活函数

3. 输出层生成预测(例如,二分类的概率值在 0 到 1 之间)

ID 9 的预测结果是 1(是)。

神经网络的确定性特征

当我们的网络两次看到相同的输入时,它会两次给出相同的答案(前提是我们没有改变它的权重和偏差)。网络处理新例子的能力来自于它的训练,而不是在预测时的任何随机性。

最后的备注

随着网络一次又一次地练习这些例子,它会在任务上变得更好。随着时间的推移,它犯的错误越来越少,预测也变得更加准确。这就是神经网络如何学习的过程:查看例子,找到错误,做出小的改进,并不断重复!

🌟 多层感知机分类器代码总结

现在让我们看看我们的神经网络是如何运作的。以下是一些 Python 代码,构建了我们一直在讨论的网络,使用的是我们刚刚学习的相同结构和规则。

import pandas as pd
import numpy as np
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score

# Create our simple 2D dataset
df = pd.DataFrame({
    '🌞': [0, 1, 1, 2, 3, 3, 2, 3, 0, 0, 1, 2, 3],
    '💧': [0, 0, 1, 0, 1, 2, 3, 3, 1, 2, 3, 2, 1],
    'y': [1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, 1]
}, index=range(1, 14))

# Split into training and test sets
train_df, test_df = df.iloc[:8].copy(), df.iloc[8:].copy()
X_train, y_train = train_df[['🌞', '💧']], train_df['y']
X_test, y_test = test_df[['🌞', '💧']], test_df['y']

# Create and configure our neural network
mlp = MLPClassifier(
    hidden_layer_sizes=(3, 2), # Creates a 2-3-2-1 architecture as discussed
    activation='relu',         # ReLU activation for hidden layers
    solver='sgd',              # Stochastic Gradient Descent optimizer
    learning_rate_init=0.1,    # Step size for weight updates
    max_iter=1000,             # Maximum number of epochs
    momentum=0,                # Disable momentum for pure SGD as discussed
    random_state=42            # For reproducible results
)

# Train the model
mlp.fit(X_train, y_train)

# Make predictions and evaluate
y_pred = mlp.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.2f}")

想了解更多?

  • 查看 scikit-learn 的MLPClassifier官方文档,了解更多详情及如何使用

  • 本文使用的是 Python 3.7 和 scikit-learn 1.5,但核心理念也适用于其他版本。

图片归属

本文中的所有图表和技术插图都是作者使用 Canva Pro 的商业许可证条款下授权设计元素制作的。

𝙎𝙚𝙚 𝙢𝙤𝙧𝙚 𝘾𝙡𝙖𝙨𝙨𝙞𝙛𝙞𝙘𝙖𝙩𝙞𝙤𝙣 𝘼𝙡𝙜𝙤𝙧𝙞𝙩𝙝𝙢𝙨 𝙝𝙚𝙧𝙚:

Samy Baladram

Samy Baladram

分类算法

查看列表8 个故事

𝙔𝙤𝙪 𝙢𝙞𝙜𝙝𝙩 𝙖𝙡𝙨𝙤 𝙡𝙞𝙠𝙚:

Samy Baladram

Samy Baladram

回归算法

查看列表5 个故事一个绑着辫子、戴着粉色帽子的卡通娃娃。这只“傀儡”娃娃,凭借其简单的设计和心形图案的衬衫,形象地代表了机器学习中的傀儡回归器概念。就像这个玩具般的人物是一个简化的、静态的人物代表一样,傀儡回归器是用作更复杂分析的基础模型。Samy Baladram

Samy Baladram

集成学习

查看列表4 个故事

使用 lme4 的多层次建模:分析电子商务销售

原文:towardsdatascience.com/multilevel-modelling-with-lme4-analysing-e-commerce-sales-4a36436c1e66?source=collection_archive---------12-----------------------#2024-07-23

多层次建模可以分析具有层级或聚类结构的数据。

Michael GroganTowards Data Science Michael Grogan

·发表于 Towards Data Science ·阅读时间 7 分钟·2024 年 7 月 23 日

--

来源:作者使用 R 生成的输出。

注意:原文可在此处查看 这里.

多层次建模在市场研究中尤为有用,通过对顾客按类别(如人口统计、购买习惯)进行细分,有助于理解企业如何吸引新顾客并提高现有顾客的忠诚度。

R 中的lme4库用于创建多层次模型。在该库中,一个重要的多层次建模实例是sleepstudy示例,其中使用多层次模型分析了在睡眠剥夺条件下,不同睡眠剥夺天数对参与者反应时间的影响。

我们如何将这样的模型应用于分析顾客数据呢?让我们来看一下!

背景

一个电子商务网站希望分析他们收集的关于网站活动的最新销售数据。具体来说,他们希望确定影响每位顾客消费的因素。

他们提供了一个包含以下信息的数据集:

多语言 RAG、算法思维、异常值检测与其他问题解决亮点

原文:towardsdatascience.com/multilingual-rag-algorithmic-thinking-outlier-detection-and-other-problem-solving-highlights-e09aa52e21ec?source=collection_archive---------3-----------------------#2024-06-06

TDS 编辑Towards Data Science TDS 编辑

·发表于 Towards Data Science ·发送至 通讯 ·阅读时长 4 分钟·2024 年 6 月 6 日

--

感到受启发,准备写下你的第一篇 TDS 文章了吗? 我们始终欢迎新作者的投稿

当我们思考问题解决时,我们的注意力通常集中在解决部分:强大的黑客技巧、新的神奇工具、几行代码使一切水到渠成。实际上,要使这些最终的解决方案发挥作用,还需要做很多工作——从深入理解问题本质,到勾画出一个可行的流程,确保我们找到的是一致的成功,而不是仅仅一个临时的应急措施。

本周的亮点以其全面的方式脱颖而出,致力于寻找有效的解决方案,处理那些偶尔棘手的挑战。它们为我们提供了一个关于从业者思维方式的视角,展示了他们在探索可用资源(例如数据、工具和时间)时,如何权衡不同工作流程的利弊。我们认为,这些内容可能会激发你从全新的视角看待当前正在进行的项目。祝您阅读愉快!

  • 数据科学家的算法思维要深入了解算法思维的好处——它意味着“结合严谨的逻辑与创造力来构建、解决和分析问题,通常借助计算机的帮助”——不要错过Chinmay Kakatkar的精彩文章。重点是编写高效代码,但你可以将这里阐述的原则应用到广泛的使用场景中。

  • 终极指南:在时间序列数据中发现异常值(第一部分)在数据集中检测模式并排除异常仍然是数据科学家的基本任务。Sara Nóbrega的新指南是一份广泛而实用的资源,概述了几种强大的技术,并深入探讨了如何为你正在进行的项目选择合适的技术。

  • Jet Sweep:优化路线,访问每个 NFL 球队的主场旅行商问题是一个经典的优化难题;Sejal Dua提供了一个有趣的理论复杂性解析,并介绍了一些新角度:我们看的是 NFL 体育场,而非销售路线,并且使用线性规划和地理空间数据来生成最佳的旅行路线,访问所有体育场。

图片由Kayla Duhon提供,来自Unsplash

  • 利用数学规划与列生成法解决资源规划问题在他的首次 TDS 文章中,Luis Fernando PÉREZ ARMAS 博士挑战了另一个著名的优化问题:最小顶点着色问题(也叫图着色问题),并深入探讨了它在现实世界中的应用,随后展示了如何使用数学规划和列生成法来解决这一问题。

  • 数据干扰提升实体嵌入 当类别特征有很多可能的取值(‘高基数’)时,建模和分析都会变得棘手。”Valerie Carey在她易于理解的解释中,以深入探讨实体嵌入作为应对高基数特征挑战的潜在方案开始,并提出了一种随机正则化方法,以改善其在神经网络模型中的泛化能力。

  • 跨语言探索 RAG 应用:与《米示拿》对话目前,已经有许多成熟的工作流程用于构建高效的 RAG 系统;Shlomo Tannor在他的实践教程中(又一次 TDS 首秀!)通过展示他如何构建一个多语言应用程序,向我们展示了英语用户如何从米示拿中获取信息,这是一部最初用希伯来语编写的古老拉比文献。

想要了解其他主题的推荐阅读吗?我们希望是这样——这里有一些我们最近的最爱:

感谢您支持我们作者的工作!我们非常喜欢发布新作者的文章,如果您最近写了一篇有趣的项目教程、教程或关于我们核心主题的理论反思,请不要犹豫,与我们分享

直到下一个 Variable,

TDS 团队

商业应用的多模态 AI 搜索

原文:towardsdatascience.com/multimodal-ai-search-for-business-applications-65356d011009?source=collection_archive---------2-----------------------#2024-11-01

使企业能够从其数据中提取真实价值

Umair Ali KhanTowards Data Science Umair Ali Khan

·发表于 Towards Data Science ·阅读时间:16 分钟 ·2024 年 11 月 1 日

--

如果你不是 Medium 会员,你可以通过 这个链接阅读完整文章。

商业文档,如复杂的报告、产品目录、设计文件、财务报表、技术手册和市场分析报告,通常包含多模态数据(文本以及图表、图像、地图、照片、信息图表、图示和蓝图等视觉内容)。从这些文档中找到正确的信息,需要对文本和相关图像进行语义搜索,以回答客户或公司员工提出的特定查询。例如,一家公司的产品可能通过标题、文字描述和图像来展示。类似地,一份项目提案可能包含文本、预算分配图表、显示地理覆盖范围的地图以及过去项目的照片。

对多模态信息的精准与快速搜索对于提升企业生产力至关重要。商业数据通常分布在不同的来源,既有文本格式也有图像格式,这使得高效地检索所有相关信息变得具有挑战性。尽管生成型 AI 方法,特别是那些利用大规模语言模型(LLM)来增强商业知识管理的方法(例如,检索增强生成图谱 RAGs等)在访问多模态的分散数据时存在一定限制,但统一不同数据类型的方法使用户能够通过自然语言提示查询各种格式的数据……

多模态嵌入:介绍

原文:towardsdatascience.com/multimodal-embeddings-an-introduction-5dc36975966f?source=collection_archive---------4-----------------------#2024-11-29

将文本和图像映射到一个共同的空间

Shaw TalebiTowards Data Science Shaw Talebi

·发布于Towards Data Science ·阅读时长 8 分钟·2024 年 11 月 29 日

--

这是更大系列中关于多模态 AI 的第二篇文章。在上一篇文章中,我们了解了如何增强大型语言模型(LLMs)以理解新的数据模态(例如图像、音频、视频)。其中一种方法依赖于编码器来生成非文本数据的向量表示(即嵌入)。在本文中,我将讨论多模态嵌入,并通过两个实际的使用案例展示它们的功能。

图像来自 Canva。

人工智能研究传统上分为不同的领域:自然语言处理(NLP)、计算机视觉(CV)、机器人学、人机交互(HCI)等。然而,许多实际任务需要整合这些不同的研究领域,例如自动驾驶汽车(CV + 机器人学)、AI 代理(NLP + CV + HCI)、个性化学习(NLP + HCI)等。

尽管这些领域旨在解决不同的问题并处理不同类型的数据,但它们都共享一个基本过程。即生成现实世界现象的有用数值表示

多模态大语言模型与苹果的 MM1

原文:towardsdatascience.com/multimodal-large-language-models-apples-mm1-c1e94d87a161?source=collection_archive---------0-----------------------#2024-04-13

本博客文章将深入介绍苹果“MM1:多模态 LLM 预训练的方式、分析与洞察”论文背后的架构和研究发现。

Matthew GuntonTowards Data Science Matthew Gunton

·发表于Towards Data Science ·9 分钟阅读·2024 年 4 月 13 日

--

图片由作者生成,使用 DALL-E

抽象化是计算机科学中最重要的概念之一,具有极其深远的影响。从简化的角度来看,抽象化是将某种方法应用于多个不同情况的能力。例如,如果你在工厂中创建了一种根据大小成功分类苹果的方法,那么你也可以将这一解决方案抽象化,用于同样的方式对橙子或桃子进行分类。因此,通过抽象化,一种非常强大的解决方案能够深刻影响世界的多个领域。

虽然大语言模型在处理文本输入时具有出色的推理能力,但最近我们已经能够将它们的输入抽象化,使它们能够通过图像和声音进行推理。

以下博客文章深入探讨了苹果 MM1 论文中的架构剖析及其在构建多模态大语言模型(MLLM)时的研究成果。

抽象化 LLM 输入

大语言模型背后的架构可以追溯到 2017 年发表的论文《Attention is All You Need》,其中引入了 Transformer 架构。

本文展示了如何将人类语言转化为神经网络处理的标记(在该论文中,转化为不同语言)。

图 1 来自《Attention is All You Need》

如你从图像中看到的,我们在早期进行了一次转换,将输入转换为标记(嵌入部分)。然而,并没有固有的理由认为只有文本数据才能映射为标记。因此,研究领域开始尝试将其他类型的数据映射为标记。

MM1 架构基础

来自论文的图 3 部分

苹果的模型有三个关键组件:视觉转换器(ViT)图像编码器、视觉-语言连接器和大型语言模型。假设你已经对 LLM 有一个不错的了解,知道它是如何工作的,那么我们就深入了解图像编码器和 VL 连接器。

图像编码器与视觉连接器

从抽象的角度看,我们可以将文本和图像视为不同种类的输入,但为了实现这一点,我们需要接受一个事实,即可能需要以不同的方式处理它们,以便将其转化为标记。目前,我们有两个不同的系统帮助我们将图像转换为 LLM 可以推理的标记:图像编码器和连接器。

首先,图像编码器负责将我们的图像转换为转换器模型可以理解的标记表示。

第二,连接器是将来自视觉编码器的数据转换成直接传递给大型语言模型的数据的部分。由于图像编码器返回的是标记,你可能会想,为什么我们还需要连接器。其背后的想法似乎是图像编码器在其标记中提供了过多的信息,因此为了在优化推理的同时减少成本,我们希望有选择地传递信息。

下面的图像展示了我们在此工作中的数据流。

来自“Honeybee: Locality-enhanced Projector for Multimodal LLM”的图 2

消融实验

机器学习中的消融研究围绕去除和修改模型的某些部分,以观察它们如何影响整体性能。苹果的研究聚焦于图像编码器的不同训练方式、VL 连接器的不同投影器以及不同的预训练数据。

让我们深入探讨主要的发现。

图像编码器消融实验

对于图像编码器,他们在 CLIP 和 AIM 模型、图像分辨率大小以及模型训练所用的数据集之间进行了变化。下面的图表展示了每个消融实验的结果。

来自论文的表 1

让我们逐一介绍上述的主要部分,并解释它们是什么。

CLIP代表对比语言图像预训练,旨在通过为需要被视作文本的事物提供名称,帮助模型学习视觉概念。如下面的图像所示,它将图像与文本编码配对,以便模型最终能够将视觉令牌(在下面的图像中表示为 I)与文本令牌 T 连接起来。这种方法称为对比训练。

图 1 来自《从自然语言监督学习可迁移的视觉模型》

AIM代表自回归图像模型,它通过重构损失优化算法进行训练。这里的目标是查看变换器是否能够重建(恢复)给定的图像。

图 2 来自《可扩展的大型自回归图像模型预训练》

图像分辨率在这里指的是输入到变换器中的像素数量。例如,378 x 378 的图像分辨率意味着我们将传入一个该大小的矩阵,然后将其转换为模型将要训练的嵌入。训练数据分为(DFN-2B)、(DFN-5B)、(DFN-5B + VeCap)和(ImageText-400M)四组。

作者发现图像分辨率最为重要,其次是模型大小,然后是训练数据内容。具体来说,他们发现图像分辨率越高,模型在零-shot 和少-shot 提示下的表现越好。由于训练和运行具有更高图像分辨率要求的模型需要更多的计算资源,这表明对于视觉变换器而言,计算资源仍然至关重要。

VL 连接器消融实验

对于 VL 连接器,他们测试了图像使用 64 或 144 个令牌,图像分辨率使用 224、336 和 378 进行测试,并在几种架构之间做了选择。下面我将简要介绍这些架构。

平均池化正如其名称所示,是取所有令牌的平均值,然后对这个平均值进行线性投影,使得网格为 8x8 或 12x12。

注意力池化假设图像令牌应被视为来自与文本令牌完全不同的群体。这里我们调整每个图像输入的令牌数量,在论文中称之为 k 可学习查询。研究人员仅考虑了 k 为 64 或 144 的情况。

卷积映射是 Honeybee 中的一种方法,使用 ResNet 动态决定从图像传递给 LLM 多少令牌。这在 C-Abstractor 模块中实现。

图 4 来自该论文

从上面的内容可以看出,不同的架构实际上对性能的影响很小。正如人们可能猜到的,分辨率更高的图像以及传递的 token 更多,提升了所有连接器的性能,但提升幅度并不显著。

这一发现表明,我们要么尚未找到一种显著更好的方式来将图像编码器与大语言模型连接起来,要么这一领域根本就不是优秀模型能够显著区分的地方。

预训练数据消融

来自论文的表 2

在这里,作者使用了四种不同的数据类型:带有描述的图像、合成描述的图像、交织的图像-文本数据,以及仅文本数据。他们总结出了四个经验教训,并用图表总结了性能变化。

来自论文的图 5a

第一,交织数据有助于少样本和仅文本性能,而带描述的数据有助于零样本性能。研究人员通过改变交织数据的比例,以下的图表展示了结果。正如你所看到的,使用交织数据训练的模型在少样本提示上表现明显优于使用完全文本或完全图像的模型。

来自论文的图 5b

第二,仅文本数据有助于少样本推理。在此背景下,仅文本数据意味着训练数据包含了图像示例和仅文本示例。这么做是为了确保模型能够理解人类语言和图像。比较仅描述与描述加文本的效果,除了 0-shot 推理外,所有任务都显示了明显的提升,但对于除 TextCore 测试以外的所有任务,交织数据的表现优于交织加文本。

来自论文的图 5c

第三,如果你能正确地将图像和文本混合在一起,就能获得非常强的性能。上面的图表展示了不同的图像与文本数据混合比例与仅文本数据的对比。由于目标是拥有一个多模态模型,因此他们从未测试过没有任何图像数据的性能表现。这里的作者指出,91/9 的比例产生了最 consistently 的优秀结果。

来自论文的图 5d

第四,合成数据有助于少样本学习。VeCap 代表视觉增强描述,它是一种创建描述的方法,确保这些描述能够准确表达图像中的关键视觉信息。相反,想象一个描述,可能解释了照片背后的意义,但并没有解释照片中的任何元素。如果你的数据抓取工具找到了描述不准确的图片,你通常会做这种处理。

作者在此得出结论,VeCap 在少样本推理中提供了“非平凡”的提升,但在质量上提升较小。这引发了关于 VeCap 性价比的问题。

结果

通过他们的消融实验结果,作者创建了两种形式的变换器:专家混合(Mixture-of-Expert)和常规形式。两个模型的编码器处理的是 378 x 378 的图像,且仅用 DFN-5B 数据集进行了预训练。它们的数据包含 45%的带字幕数据,45%的交错数据和 10%的仅文本数据(近似图像与文本数据的 91:9 比例)。VL 连接器有 144 个令牌,他们选择了一个 C 摘要器,尽管他们指出这是一个相对随意的选择。对于大语言模型本身,他们创建了一个 3B、7B 和 30B 参数的模型(其中 MoE 模型仅支持到 7B)。下图显示了这些模型的表现。

来自论文的表格 4

有趣的是,这个 30B 参数的模型与其他拥有数十亿更多参数的模型(如 LLaVA-NeXT-34B 等)表现相当,暗示参数大小与性能之间可能存在某种量子关系。

结束语

多模态大语言模型是这一领域令人兴奋的一个重要部分。随着我们找到更好的方法将不同类型的数据转化为令牌,我们可能会解锁这些变换器的更多应用。展望未来,考虑如何将除文本描述以外的其他感官输入(如声音、气味甚至触觉)结合进来并不不合理。数据质量很可能会变得愈加宝贵。

由于作者得出结论认为不同的语言连接器没有显著差异,因此很有趣的是,研究是否应该集中于图像编码器,还是我们只是尚未找到真正突破性的方式来使用 VL 连接器。

除了这篇特定的论文外,一个大问题是这些多模态大语言模型(MLLMs)在基准测试之外的表现如何。随着大语言模型(LLMs)的普及,常见的批评之一是围绕基准测试进行的比较。通常这些基准测试使用一致的数据集进行比较,允许某个模型通过过拟合来获得更好的表现,即使这种过拟合并非故意。像LLM Arena from lmsys这样的使用 ELO 棋类评级算法的方法,可能会提供一个更真实的模型表现比较。

总结来说,随着更多输入能够与大语言模型相连接,可以预期它们的应用领域将会增加。只有时间才能告诉我们如何让这项技术变得更加有用。

[1] McKinzie, B., 等人. “MM1:多模态大语言模型预训练方法、分析与洞察”(2024),arXiv

[2] Cha, J., 等人. “Honeybee:多模态大语言模型的局部增强投影器”(2023),arXiv

[3] Antoniadis, P. 等人. “机器学习:什么是消融研究?” (2024), arXiv

[4] Radford, A. 等人. “从自然语言监督中学习可迁移的视觉模型” (2021), arXiv

[5] El-Nouby, Al. 等人. “大规模预训练自回归图像模型” (2024), arXiv

[6] Vaswani, A. 等人. “注意力即一切” (2017), arXiv

[7] Lai, Z. 等人. “VeCLIP:通过视觉增强标题改进 CLIP 训练” (2023), arXiv

多模态模型 — 可以“看”和“听”的 LLM

原文:towardsdatascience.com/multimodal-models-llms-that-can-see-and-hear-5c6737c981d3?source=collection_archive---------7-----------------------#2024-11-19

包含示例 Python 代码的介绍

Shaw TalebiTowards Data Science Shaw Talebi

·发表于Towards Data Science ·阅读时长 9 分钟·2024 年 11 月 19 日

--

这是关于多模态 AI系列文章的第一篇。多模态模型 (MM) 是一种能够处理或生成多种数据模态的 AI 系统(例如文本、图像、音频、视频)。在本文中,我将讨论一种特殊类型的 MM,它建立在大型语言模型 (LLM) 的基础上。我将从高层次概述此类模型开始,然后分享使用 LLaMA 3.2 Vision 执行各种图像到文本任务的示例代码。

图片由Sincerely Media提供,来源于Unsplash

大型语言模型 (LLMs)在人工智能的研究与开发中标志着一次根本性的转变。然而,尽管它们的影响广泛,仍然存在根本的局限性

具体来说,LLMs 只能处理和生成文本,这使得它们对图像、视频、音频等其他模态视而不见。这是一个主要的限制,因为一些任务依赖于非文本数据,例如分析工程图纸、解读肢体语言或语音语调,以及解释图表和信息图。

多模态 RAG — 直观且详尽的解释

原文:towardsdatascience.com/multimodal-rag-intuitively-and-exhaustively-explained-5713d8069eb0?source=collection_archive---------4-----------------------#2024-07-25

人工智能 | 检索增强生成 | 多模态

现代 RAG 适用于现代模型。

Daniel WarfieldTowards Data Science Daniel Warfield

·发布于Towards Data Science ·10 分钟阅读·2024 年 7 月 25 日

--

“多彩团队”由 Daniel Warfield 使用 Midjourney 创作。除非另有说明,所有图片均由作者提供。文章最初发布于直观且详尽的解释

多模态检索增强生成(Multimodal Retrieval Augmented Generation,简称 RAG)是一种新兴的设计范式,它使得 AI 模型能够与文本、图像、视频等多种信息存储库进行交互。

在探讨这个话题时,我们将首先介绍什么是检索增强生成(RAG)、多模态的概念,以及这两者如何结合在一起,构建现代的多模态 RAG 系统。一旦我们理解了多模态 RAG 的基本概念,就可以使用 Google Gemini 和 CLIP 风格模型进行编码,亲自构建一个多模态 RAG 系统。

这对谁有用? 任何对现代人工智能感兴趣的人。

这篇文章的难度如何? 尽管多模态 RAG 处于人工智能的前沿,但它直观简单且易于接触。本文应该对资深 AI 研究人员有趣,同时对于初学者来说也足够简单。

前提条件:

检索增强生成简介

在深入了解多模态 RAG 之前,先简要回顾一下传统的检索增强生成(RAG)。基本概念是……

多模态 RAG:用 AI 处理任何文件类型

原文:towardsdatascience.com/multimodal-rag-process-any-file-type-with-ai-e6921342c903?source=collection_archive---------1-----------------------#2024-12-05

适合初学者的指南,包含示例(Python)代码

Shaw TalebiTowards Data Science Shaw Talebi

·发表于Towards Data Science ·12 分钟阅读·2024 年 12 月 5 日

--

这是关于多模态 AI系列文章的第三篇。在之前的文章中,我们分别讨论了多模态 LLMs 和嵌入模型。在本文中,我们将结合这些理念,推动多模态 RAG 系统的开发。我将首先回顾关键概念,然后分享实现该系统的示例代码。

图片来自 Canva。

类似 GPT、LLaMA 和 Claude 这样的语言模型通过预训练学习了大量的世界知识。这使得它们成为解决定制问题和回答复杂问题的强大工具。

然而,即使是最先进的语言模型也有它们不知道的知识。这包括组织内部的专有信息、模型预训练数据收集后发生的事件,以及在互联网上不常见的专业知识。

尽管这种无知限制了模型的开箱即用能力,但有一种流行的技术可以克服这些……

多项式朴素贝叶斯分类器

原文:towardsdatascience.com/multinomial-naive-bayes-classifier-c861311caff9?source=collection_archive---------10-----------------------#2024-03-14

一个完整的文本评论分类工作示例

Yoann MocquinTowards Data Science Yoann Mocquin

·发表于Towards Data Science ·阅读时间 16 分钟·2024 年 3 月 14 日

--

在这篇新文章中,我们将尝试理解多项式朴素贝叶斯分类器是如何工作的,并提供使用 Python 和 scikit-learn 的实际示例。

我们将看到:

  • 什么是多项式分布:与依赖于假定高斯分布的高斯朴素贝叶斯分类器不同,多项式朴素贝叶斯分类器依赖于多项式分布。

  • 创建依赖于贝叶斯定理的分类器的一般方法,以及基于朴素假设——即在给定目标类的条件下,输入特征相互独立。

  • 多项式分类器是如何通过学习/估计每个类别的多项式概率来“拟合”的——使用平滑技巧来处理空特征。

  • 如何计算新样本的概率,使用对数空间技巧以避免下溢。

所有图片由作者提供。

理解多项式分布

如果你已经熟悉多项式分布,可以继续阅读下一部分。

统计学必知:双变量正态投影解析

原文:towardsdatascience.com/must-know-in-statistics-the-bivariate-normal-projection-explained-ace7b2f70b5b?source=collection_archive---------8-----------------------#2024-08-14

该强大概念的推导与实践例子

Luigi BattistoniTowards Data Science Luigi Battistoni

·发表于Towards Data Science ·8 分钟阅读·2024 年 8 月 14 日

--

介绍

在统计学和机器学习中,理解变量之间的关系对于构建预测模型和分析数据至关重要。探索这些关系的基本技术之一是双变量投影,它依赖于双变量正态分布的概念。这项技术使得可以利用变量之间的依赖结构,来检查和预测一个变量在另一个变量的条件下的行为。

双变量投影有助于确定给定另一个变量特定值时一个随机变量的期望值。例如,在线性回归中,投影有助于估计因变量如何随自变量变化。

本文分为三个部分:在第一部分,我将探讨双变量投影的基本原理,推导其公式并展示其在回归模型中的应用。在第二部分,我将提供一些投影背后的直觉,并通过图示帮助更好地理解其含义。在第三部分,我将使用投影来推导线性回归的参数。

在我推导双变量投影公式时,我将使用一些著名的结果。为了不让读者感到过于沉重,我将提供证明……

必知的 Hive 大数据处理技术

原文:towardsdatascience.com/must-know-techniques-for-handling-big-data-in-hive-fa70e020141d?source=collection_archive---------10-----------------------#2024-08-14

HQL 的独特功能—PARTITIONED BY、STORED AS、DISTRIBUTE BY / CLUSTER BY、LATERAL VIEW 配合 EXPLODE 和 COLLECT_SET

Jiayan YinTowards Data Science Jiayan Yin

·发表于 Towards Data Science ·7 分钟阅读·2024 年 8 月 14 日

--

图片由 Christopher Gower 提供,来源于 Unsplash

在大多数技术公司中,数据团队必须具备强大的能力来管理和处理大数据。因此,熟悉 Hadoop 生态系统对于这些团队来说是至关重要的。由 Apache 开发的 Hive 查询语言(HQL)是数据专业人员在这个生态系统中操作、查询、转换和分析数据的强大工具。

HQL 提供了类似 SQL 的接口,使得在 Hadoop 中的数据处理对于广泛的用户群体来说既易于访问又用户友好。如果你已经精通 SQL,你会发现过渡到 HQL 并不难。然而,需要注意的是,HQL 包含了一些标准 SQL 中没有的独特功能和特性。在本文中,我将基于之前的经验,探索一些这些关键的 HQL 功能和特性,这些功能和特性需要超出 SQL 的特定知识。理解并利用这些能力对于任何从事 Hive 和大数据工作的人员来说都是至关重要的,因为它们构成了在 Hadoop 生态系统中构建可扩展和高效的数据处理管道和分析系统的基础。为了说明这些概念,我将提供带有模拟数据的用例…

我的 30 天地图挑战 2023

原文:towardsdatascience.com/my-30-day-map-challenge-2023-20a700d970e1?source=collection_archive---------5-----------------------#2024-06-06

选定地图主题和算法概述

Bryan R. VallejoTowards Data Science Bryan R. Vallejo

·发布于 Towards Data Science ·9 分钟阅读·2024 年 6 月 6 日

--

作者提供的图片。精选自我创作的地图拼贴。

介绍

正如你们中的许多人所知道的,30 天地图挑战是一个社交媒体活动,热爱地图的人们在每年的 11 月为期 30 天内围绕不同的主题进行地图设计。你可以在 官方网站 上找到更多信息。很少有人能够处理每个主题都能完全呈现的地图,并且每天在社交媒体上发布。另一方面,有些人会选择自己最喜欢的主题,只发布这些天的地图。我属于第二组人。

在这个故事中,我想描述我所经历的地图设计过程的概述。我选择了我最喜欢的主题,并使用算法为它们增色。我相信,创建能够展示“某种其他内容”的地图非常重要,而最好的方式就是通过算法的结果。地图应在第一眼给人一个清晰的视觉信息,如果可能的话,还能为经验丰富的地图阅读者提供更多的详细信息。

让我们来看看这些地图。

数据

临近点地图表示通过每个 OpenStreetMap(OSM)建筑物与 OSM 消防站之间的距离来衡量的接近度。每个点是每个建筑物的质心。然后,我使用最近邻 算法归属了欧几里得距离。这个算法节省了大量的计算时间。通过将接近度(距离)作为点的大小(最小且红色表示较远的建筑)为点图添加了不错的视觉效果。查看这个教程。

## 赫尔辛基地区大数据集的最近邻分析

BallTree 函数用于高效的距离分析:公交车站和餐馆

towardsdatascience.com

该地图是使用 QGIS 制作的。

作者提供的图像。点。

线条

该线路地图展示了 OSM 住宅楼与中央火车站之间的累计旅行距离。由于我有超过 19,000 座住宅楼,我需要构建正好 19070 条路线。幸运的是,我在芬兰阿尔托大学工作时有一台 16 核超级计算机,可以并行处理这个过程。路径是通过 OSMnx Python 库中的最短路径(Dijkstra 算法)生成的。地图中有一个很好的视觉效果,最短路径到市中心的线路宽度较细(白色),而较长的路径(红色)则宽一些。

这张地图是在 QGIS 中制作的。

图片来自作者。线路

多边形

该多边形地图是通过从赫尔辛基中央火车站生成的等时线制作的。生成 2 分钟等时线的过程非常具有挑战性,因为 Valhalla API 经常崩溃。但经过一些测试,终于成功了。我使用了routingpy Python 库来完成这个任务。

这张地图的视觉效果是移除海洋上的多边形。幸运的是,我成功地将由陆地限定的等时线与准确的海洋图层进行了裁剪。我不建议过于挑剔,因为这个裁剪过程花费了我很多时间。

我修改了这张地图,来自于原始帖子,因为最好保持等时线的自然边界。

这张地图是在 QGIS 中制作的。

图片来自作者。多边形

导航

导航地图显示了摩洛哥两点之间的两条路线。起点和终点是由一名需要帮助的学生提供的。我想了解如果一条路径按距离加权,另一条按时间加权,它们会有何不同。因此,最短路径和最快路径。路径是通过使用 OSMnx Python 库的最短路径(Dijkstra 算法)生成的。

我在地图背后添加了道路网络,并给它做了 3D 旋转。代码可以在线获取。

## Dijkstra 算法加权旅行时间在 OSM 网络中的应用

使用 OSMNX 1.6 找到最快和最短的路径

[towardsdatascience.com

这张地图是在 KeplerGl 中制作的。

图片来自作者。导航。

六边形

这张六边形地图是使用Kontur 人口数据制作的。该数据集按国家在 400 米的分辨率下聚合的人口,清晰地展示了人口密度。这个地图的挑战在于如何将加那利群岛加到地图的一侧。幸运的是,QGIS 可以添加不同位置的框架,因此我包括了所有的岛屿。

如果你愿意像这个例子一样在六边形中进行聚合,可以使用H3-pandas Python 库

使用 QGIS 制作的地图。

作者提供的图像。六边形

北美

这张地图是通过使用H3-pandas Python 库在全球范围内制作的。数据聚合在 H3 分辨率 6 中,使用的是包含全球蜂窝天线的 OpenCellID 数据。然后,我框定了北美地区。这张地图的挑战在于读取全球数据。幸运的是,这一过程可以通过使用Dask Python 库完成。

你可以在网上找到处理蜂窝天线数据所需的代码。

## 厄瓜多尔的移动网络运营商覆盖情况

蜂窝天线的 2G、3G 和 4G 发展地理可视化

[towardsdatascience.com

使用 QGIS 制作的地图。

作者提供的图像。北美。

南美

每次提到南美时,我都会想到厄瓜多尔海岸线的温暖海水。我想展示海表温度(SST)如何被上色而显得如此吸引人。在这张地图中,我使用了 2022 年 11 月的 MODIS OB DAAC 全球级数据集。然后,我使用 QGIS 框定了南美地区。这张地图的挑战是找到合适的颜色调色板。

如果你想独立处理 SST 数据,可以在这里找到代码:

## 使用 GEE 在全球范围内监测海表温度

如何使用 Python 为海洋监测创建 Streamlit 应用

[towardsdatascience.com

作者提供的图像。南美。

欧洲

制作这张地图是一个愉快的过程。我获得了 2022 年 5 月塔尔图(爱沙尼亚)共享单车数据的 GPS 位置数据集。然后,我使用 Python 库 Movingpandas 将高质量的 GPS 数据转换为轨迹(LineString),这使得操作变得更加方便。我添加了距离并为轨迹上色,呈现出人们如何使用自行车骑行短途和长途的精彩可视化效果。

塔尔图是 2024 年欧洲文化之都。一般来说,它被认为是一个 15 分钟城市,因为你可以骑车或步行到任何地方。游客会对这座城市如何轻松实现流动性感到着迷,而这张地图展示了城市各地共享单车的分布情况。

看一下展示如何使用 Movingpandas 的 Python 代码。

## 爱沙尼亚塔尔图的 Metallica 音乐会自行车共享系统运动

使用 Movingpandas 和 KeplerGl 进行 GPS 跟踪可视化 — 教程

[towardsdatascience.com

这张地图是在 QGIS 中制作的。

作者提供的图像。欧洲。

流动

这张地图是与 GIS4 WildlifeSEO BirdLife 的可视化项目的一部分。这个项目的挑战在于将所有的鸟类迁徙数据(GPS)融入到时空可视化中。这个项目使用了私人代码和算法,但幸运的是,我可以分享最终的可视化结果。事实上,KeplerGl 是进行时空可视化的最佳工具。

[## GIS4 Wildlife 移动分析

在“加拉帕戈斯海洋保护区”实施野生动物追踪器为地理空间解决方案提供了帮助……

www.gis4-wildlife.com

作者提供的图像。流动。

3D

和之前一样,我使用了Valhalla API来生成等时线。这里的挑战是为 3D 高程添加合适的数值。我需要加入回退的距离,以便较小的等时线呈现比较大的等时线更高的视图。虽然有点棘手,但对于 3D 地图来说效果不错。

这张地图是使用 KeplerGl 制作的。

作者提供的图像。3D

南极

当我听到南极时,首先浮现在我脑海中的就是那里的海洋哺乳动物,当然也包括企鹅。幸运的是,我找到了一层显示这些物种分布的图层,并了解到它们被认为是濒危物种。国际自然保护联盟(IUCN)提供了海洋-陆地哺乳动物的分布图层,而这个图层位于南极。这个可视化的挑战在于投影和叠加。幸运的是,QGIS 使用南极极地立体投影很好地完成了这项工作。

这张地图是在 QGIS 中制作的。由于数据许可规定,这张地图从此选择中移除。但你可以直接在 Twitter 上查看该地图。

👇🦭

南极海洋-陆地哺乳动物分布

加拉帕戈斯群岛因其丰富的海洋生物多样性,激发了我对运动生态学和地理空间技术的理解。渔业和鲨鱼种群无疑是全球关注的话题,因此当我听到“点”时,我想到了全球渔业观察的渔业数据。这张地图展示了 2021 年 9 月至 2022 年 2 月期间,加拉帕戈斯海洋保护区的渔业强度。这个指数被称为表观渔业努力,它为海洋空间规划提供了宝贵的见解。

这张地图的挑战在于点的大小和叠加的效果。这个组合展示了渔业的强度,通过点的大小和亮度来表示。很明显,我们可以看到渔船如何避免穿越加拉帕戈斯海洋保护区的边界,同时保护区内部的渔业活动更为局部,且多为小规模渔业。你可以进一步了解全球渔业地图数据以及渔业如何影响这个原始群岛中的海洋大型动物。

## 利用地理空间技术揭示渔业活动对海洋大型动物的风险

受到中国渔船威胁的鲸鲨——加拉帕戈斯群岛

[towardsdatascience.com

这张地图是在 QGIS 中制作的。

图片由作者提供。点

实验性

在地图创建过程中处理数据集、产品或中间层可能非常繁琐。我花了很长时间寻找最佳的颜色和大小组合,以正确传达信息。有一次,我想表示从步行网络到图书馆的距离。我找到了一个非常吸引人的调色板用于地图,但显然它并没有传达出正确的信息。

在那一刻,我意识到并不是每张“看起来不错”的地图都是做得好的,因此我把它称作实验性可视化。

这个实验是在 QGIS 中完成的。

图片由作者提供。实验性。

如果你想了解更多关于我的地图和编码教程,请随时联系我。

💼 Bryan R Vallejo LinkedIn

作为一名数据科学家的 7 个收入来源

原文:towardsdatascience.com/my-7-sources-of-income-as-a-data-scientist-aef6b6a41388?source=collection_archive---------0-----------------------#2024-10-01

作为一名实践中的数据科学家,我的所有赚钱方式

Egor HowellTowards Data Science Egor Howell

·发表于Towards Data Science ·阅读时间 7 分钟·2024 年 10 月 1 日

--

图片来源:Alexander GreyUnsplash

作为一名数据科学家,你有很多机会创造多个收入来源。在这篇文章中,我想分解一下我自己的 7 个收入来源,希望能给你一些可以尝试的思路。

这里的内容不构成财务建议。

全职工作

我的主要和最重要的收入来源是我作为数据科学家的全职工作。我从事这个职业已经超过三年了,我非常热爱它。

我每天都有机会处理许多有趣的问题,应用各种算法,并使用像生成式人工智能这样的前沿技术。

正如我在之前的文章中所讨论的,成为一名数据科学家并不容易。要进入这个领域需要付出大量的努力和学习,尤其是在现在这个时代。

然而,这些努力得到了非常丰厚的回报。我不能公开我的薪水以尊重我的雇主,但我赚的钱远远高于英国的平均薪水。

我的通勤时间超过四小时。每单程。

原文:towardsdatascience.com/my-commute-to-work-is-more-than-4-hours-each-way-fc8d7c2392ad?source=collection_archive---------1-----------------------#2024-09-27

我疯了吗,还是数据给出了不同的故事?

Jose ParreñoTowards Data Science Jose Parreño

·发布于Towards Data Science ·10 分钟阅读·2024 年 9 月 27 日

--

图片由Yusuf Mansoor提供,来源于Unsplash

我的通勤时间超过四小时。每单程。这意味着我每天要花费 8 小时通勤,再加上全天的工作。有些人可能会认为我疯了,但事实是:这些数据是有道理的。因此,当我和家人分析这些数据时,我们决定不搬到靠近我办公室的城市。这就是这一决定如何展开的故事——以及数据对此的反映。

附言 1:我与任何进行调查的公司或使用我计算通勤数据的公司没有任何关系。

附言 2:除非另有说明,所有图片均由我本人创作。

与我的数据科学团队的对话促使我写下了这篇文章。

在分享我的通勤故事之前,我觉得有必要介绍一下我为什么开始写这篇文章。随着亚马逊最近宣布让更多员工回到办公室,我们数据科学和数据工程团队有些不安。科技行业是那些远程办公和混合工作形式更为普遍的行业之一。在 Skyscanner,我们的数据科学团队有 40 多位成员,分布在伦敦、爱丁堡、巴塞罗那和深圳。其他团队也有类似的情况……

我的《前后处理测试简明指南》

原文:towardsdatascience.com/my-easy-guide-to-pre-vs-post-treatment-tests-0206f56f83a4?source=collection_archive---------4-----------------------#2024-06-13

《前后测试简介与代码示例》

Gustavo R SantosTowards Data Science Gustavo R Santos

·发表于 Towards Data Science ·13 分钟阅读·2024 年 6 月 13 日

--

图片来自 Towfiqu barbhuiyaUnsplash

介绍

我将从说起,A/B 测试对我来说一直不是一项强项。“好吧,那如果你正在写一篇关于它的文章呢?”你可能会这么想。

事实上,我不得不尽可能多地学习这个主题,寻找一种方法,在工作中快速顺利地执行这些测试。许多个月来,我一直避免进行这些测试,因为我无法找到任何能够立刻与我理解相契合的直接内容。

我读了许多关于 A/B 测试的概念性文章,其中一些是针对市场营销专业人员的,因此更多地涉及了如何确定样本量和测试周期的问题。但当谈到前后测试(也称为前后处理测试,我们可以说它是一种 A/B 测试类型)时,这些知识更难找到。

所以,希望这篇文章能够为所有寻求易于跟随的教程,使用 Python 执行这种类型测试的人提供一个良好的《前后测试》(或称前后测试)的介绍。

让我们深入了解一下。

概念

我在 DuckDB 中的第一亿条数据(行)

原文:towardsdatascience.com/my-first-billion-of-rows-in-duckdb-11873e5edbb5?source=collection_archive---------0-----------------------#2024-05-01

DuckDB 处理 450Gb 数据的初步印象,在实际项目中的应用

João PedroTowards Data Science João Pedro

·发表于 Towards Data Science ·阅读时间 12 分钟·2024 年 5 月 1 日

--

Duck 蓝图。由 Copilot Designer 生成。

介绍

人工智能、数据科学和数据工程领域正在全速发展。每天都有新的工具、新的范式和新的架构被创造出来,始终试图解决前一个问题。在这片充满新机遇的海洋中,了解一些现有工具来高效解决问题是很有趣的。而我说的不仅仅是技术细节,还有使用范围、优缺点、挑战和机会,这些都是通过实践获得的。

在这篇文章中,我将描述我在 DuckDB(一个用于在本地计算机上处理大量数据的新兴数据库)中的第一次体验,并重新审视我以前遇到的一个老问题——巴西电子投票箱日志的处理,以计算投票时间度量。正如你将通过这篇文章看到的,这是一个具有挑战性的问题,是对性能和用户体验评估的一个良好基准。

这个文章的目的是为你提供一些参考,特别是你想了解更多关于 DuckDB 的内容,我将涵盖技术方面的内容,比如运行问题和计算数据库性能,也会谈及一些“软”方面的内容,比如编程体验和可用性。

DuckDB 是一个开源项目 [OSD],作者与 DuckDB/DuckDB Labs 没有任何关系。所使用的数据符合 ODbL 许可证。这是一个完全免费的项目,不需要支付任何服务费、数据访问费或其他费用。

问题

这个问题包括处理电子投票箱日志中的记录,以获取关于巴西选民投票时间的统计指标。例如,计算公民投票所需的平均时间、收集指纹用于身份识别等等。这些指标应在不同的粒度层次进行聚合:从国家级、州级、选举区级到选举分区级。

如果你不知道,巴西拥有 100%电子投票系统,所有超过一亿的公民都在同一天投票,选举结果几乎实时计算并公布。投票由成千上万的电子投票箱收集,这些投票箱分布在全国各地。

电子投票箱。来自巴西最高选举法院的图片

电子投票箱是一种专用的微型计算机,用于选举工作,具有以下特点:坚固、小巧、轻便、具有能源自给能力,并具备安全功能 [4]。每台投票箱最多可以容纳 500 名选民,这是为了避免投票地点出现长时间排队的情况。

该系统由 TSE(最高选举法院)管理,TSE 通过其开放数据门户分享有关选举过程的数据 [ODbL]许可证。日志是文本文件,包含投票箱中所有事件的详细列表。

这就是挑战的开始。由于日志记录了每一个事件,因此可以从中计算出大量的指标;它是一个充满活力的信息源泉。但正是这种丰富性,使得它们变得极为难以处理,因为整个国家的记录总量已经达到了 450GB,TSV 文件中有超过 40 亿行。

除了数据量,另一方面使这项工作成为一个良好基准的原因是,达成我们最终目标所需的转换涉及了各种复杂度的操作,从简单的(如 wheregroup byorder by)到复杂的 SQL 操作(如窗口函数)。

DuckDB

在如此高的数据量下,人们可能会愿意调用传统的大数据工具,如 Apache Spark,并在集群中使用多个工作节点、几 GB 的 RAM 和十多个 CPU 来处理这些数据。

DuckDB 的创建就是为了挑战这一现状

正如其创始人在这个视频中所辩护的那样,这是一种数据库设计,旨在赋予单一机器处理大量数据的能力。

也就是说,放弃寻求复杂的行业解决方案——如 PySpark——或基于云的解决方案——如 Google BigQuery——而是使用本地进程数据库,利用标准 SQL 实现所需的转换。

总而言之,DuckDB 是一个内嵌式(运行在程序内部,没有独立进程,类似于 SQLite)、OLAP(针对分析负载进行优化)、能够处理传统格式数据(CSV、Parquet)的数据库,优化了在单台机器上处理大量数据的能力(不需要特别强大的机器)。

数据

一个投票箱的日志是一个标准化名称的单一 TSV 文件 — XXXXXYYYYZZZZ.csv,包含投票箱位置的元数据,其中前 5 位数字是城市代码,接下来的 4 位是选举区(一个地理州的子区域),最后 4 位是选举区段(即投票箱本身)。

巴西几乎有 50 万个投票箱,因此几乎有 50 万个文件。文件的大小取决于该选区选民的数量,范围从 1 到 500。这就是日志的样子:

2022-10-02 09:35:17 INFO 67305985 VOTA Voter was enabled
2022-10-02 09:43:55 INFO 67305985 VOTA Vote confirmed for [Federal Deputy]
2022-10-02 09:48:39 INFO 67305985 VOTA Vote confirmed for [State Deputy]
2022-10-02 09:49:10 INFO 67305985 VOTA Vote confirmed for [Senator]
2022-10-02 09:49:47 INFO 67305985 VOTA Vote confirmed for [Governor]
2022-10-02 09:50:08 INFO 67305985 VOTA Vote confirmed for [President]
2022-10-02 09:50:09 INFO 67305985 VOTA The voter's vote was computed
# Literal Translations to English
# Events that represent a vote

我们的目标是将这些原始信息转化为有关投票时间的统计指标(每个选民投票所需的时间是多少?每分钟计算多少票?),并在不同的粒度层级(国家、州、城市)上实现这一目标,为此我们将创建一个OLAP 立方体,如图所示:

| State         | City              | Mean Voting Time (seconds) | Max Votes Computed in 5 Min |
|---------------|-------------------|----------------------------|-----------------------------|
| Null          | Null              | 50                         | 260                         |
| São Paulo     | São Paulo         | 30                         | 300                         |
| São Paulo     | Campinas          | 35                         | 260                         |
| São Paulo     | Null              | 20                         | 260                         |
| Rio de Janeiro| Rio de Janeiro    | 25                         | 360                         |
| Minas Gerais  | Belo Horizonte    | 40                         | 180                         |
| Bahia         | Salvador          | 28                         | 320                         |
| Rio Grande ...| Porto Alegre      | 30                         | 300                         |
| ...           | ...               | ...                        | ...                         |

实现

设置环境

运行此项目所需的仅仅是一个安装了DuckDB 包的 Python 环境

pip install duckdb

转换数据

在接下来的部分中,我将描述每个转换的目标、DuckDB 如何执行每个转换、优点、挑战、结果和结论。

处理过程分为 4 个步骤:将 TSV 文件转换为 Parquet;过滤和清理;隔离选票及其属性;并计算 OLAP 立方体的指标。

处理步骤。图片由作者提供。

不幸的是,为了避免使这篇文章过于庞大,我不会详细解释每个转换。但所有代码都可以在GitHub 仓库中找到。

将 TSV 文件转换为 Parquet

对于任何想要处理大量数据的人来说,这是一个简单而不可或缺的步骤。在 DuckDB 中进行此操作非常直接。

首先,创建一个 DuckDB 会话:

cursor = duckdb.connect("")

在这个例子中,我们用一个空字符串实例化数据库连接器。这样做是为了表明 DuckDB 不应该创建自己的数据库文件,而是仅与系统文件交互。正如之前所提到的,DuckDB 是一个数据库,因此它具有创建表、视图等功能,但我们在这里不会探讨这些功能。我们将专注于将其用作转换引擎。

并定义以下查询:

query = f"""
    COPY (
        SELECT 
            * 
        FROM read_csv('/data/logs/2_{state}/*.csv', filename=True)
    ) TO '{state}.parquet' (FORMAT 'parquet');
"""
cursor.execute(query)

就这样!

让我们详细看看这个查询:

内部表达式只是一个标准的SELECT * FROM table 查询,唯一的区别是,DuckDB 可以直接引用文件,而不是引用表。

这个查询的结果可以导入到 pandas 数据框中进行进一步的表达,就像这样:

my_df = cursor.execute(query).df()

这使得 DuckDB 与 pandas 之间的无缝集成成为可能。

外部表达式是一个简单的 COPY … TO …,它将内部查询的结果写入文件。

在这个第一次转换中,我们开始看到 DuckDB 的一个优势——能够使用纯 SQL 与文件进行交互,而无需配置任何其他内容。上述查询与我们在标准 SGBD(如 PostgreSQL 和 MySQL)中执行的日常操作完全相同,唯一的区别是,我们不再操作表格,而是与文件进行交互。

原本我们有 450GB 的 TSV 文件,约 30 分钟 后,我们剩下了 97GB 的 Parquet 文件。

过滤并清除

如前所述,日志存储了选票箱上发生的每个事件。第一步旨在过滤出仅与投票相关的事件,如“选民投票选举了总统”、“选民采集了指纹”以及“投票已计算”,这些事件发生在选举日(这很重要,因为日志还记录了培训部分和其他行政程序)。

一个简单的查询,但包含大量文本和日期处理:

 VOTES_DESCRIPTIONS = [
    # VOTES
    "event_description = 'Aguardando digitação do título'",  
    # Awaiting voter's title (Voter Registration ID) input
    "event_description = 'Título digitado pelo mesário'",  
    # Voter's title entered by the poll worker
    "event_description = 'Eleitor foi habilitado'",  
    # Voter has been enabled
    "event_description ILIKE 'Voto confirmado par%'", 
    # Vote confirmed for ... could be [PRESIDENT, SENATOR, DEPUTY, ...]
    "event_description = 'O voto do eleitor foi computado'",  
    # Voter's vote has been computed
]

ACCEPTED_DATES = [
    '2022-10-02', '2022-10-30', # Constitutional date of the election filter
    '2022-10-03', '2022-10-31', 
]

query = F"""
    SELECT 
        *
    FROM (
        SELECT
            event_timestamp,
            event_timestamp::date AS event_date,
            event_type,
            some_id,
            event_system,
            event_description,
            event_id,

            REPLACE(SPLIT_PART(filename, '/', 5), '_new.csv', '') AS filename,

            -- Metadata from filename
            SUBSTRING( SPLIT_PART(SPLIT_PART(filename, '/', 5), '-', 2),  1, 5 ) AS city_code,
            SUBSTRING( SPLIT_PART(SPLIT_PART(filename, '/', 5), '-', 2),  6, 4 ) AS zone_code,
            SUBSTRING( SPLIT_PART(SPLIT_PART(filename, '/', 5), '-', 2), 10, 4 ) AS section_code,
            REPLACE(SPLIT_PART(filename, '/', 4), '2_', '') AS uf
        FROM
            {DATASET}
        WHERE 1=1
            AND ( {' OR '.join(VOTES_DESCRIPTIONS)} )
    ) _
    WHERE 1=1
    AND event_date IN ({', '.join([F"'{date}'" for date in ACCEPTED_DATES])})
"""

在这个查询中,突出了 DuckDB 的另一个优势:能够读取和写入分区数据。表分区在大数据的背景下非常重要,但在单机架构中尤为关键,因为我们在进行输入和输出操作时使用的是同一磁盘,即它要承受两次负担,每一个优化都非常欢迎。

原本我们有 97GB,但约 30 分钟后,我们剩下了 63GB 的 Parquet 文件。

隔离投票及其属性

由于每个投票由多行组成,我们需要将所有信息浓缩成一个唯一的记录,以便简化计算。这里情况变得复杂,因为查询变得复杂,而且不幸的是,DuckDB 无法一次性处理所有数据。

为了克服这个问题,我做了一个循环,以增量的方式处理数据切片:

for state in states:
    for date in ACCEPTED_DATES:
        for zone_group in ZONE_GROUPS:
            query = F"""
                COPY 
                {
                    complex_query_goes_here
                    .replace('<uf>', state)
                    .replace('<event_date>', date)
                    .replace('<zone_id_min>', str(zone_group[0]))
                    .replace('<zone_id_max>', str(zone_group[1]))
                } 
                TO 'VOTES.parquet' 
                (FORMAT 'parquet', PARTITION_BY (event_date, uf, zone_group), OVERWRITE_OR_IGNORE 1);
            """

实现细节并不重要,关键是我们不需要对代码做太多改动,就能逐步构建这个最终表格。由于每个处理的“切片”代表一个分区,通过将参数 OVERWRITE_OR_IGNORE 设置为 1,DuckDB 会自动覆盖该分区的任何现有数据,或者如果数据已存在,则忽略它。

原本我们有 63GB,约 1 小时 20 分钟后,最终得到了 15GB 的 Parquet 文件。

计算指标并构建 OLAP Cube

这是一个简单的步骤。现在,每个投票都由一个记录表示,所需的只是计算指标。

query_metrics = F"""
    SELECT
        turno, state,
        zone_code,
       section_code,

        COUNT(*) AS total_votes,
        COUNT( DISTINCT state || zone_code || section_code ) AS total_sections,

        SUM( vote_time ) AS voting_time_sum,
        AVG( vote_time  ) AS average_voting_time,

        MAX( nr_of_votes )   AS total_ballot_items_voted,
        SUM( nr_of_keys_pressed ) AS total_keys_pressed

    FROM
        source
    GROUP BY ROLLUP(turno, state, zone_code, section_code)
"""

由于我们需要计算多个粒度级别的指标,理想的做法是使用 GROUP BY + ROLLUP。

在这个案例中,DuckDB 表现得尤为突出:我们从 15GB 开始,36 秒后,文件大小缩减到了 88MB!

这是一个惊人的性能,它在不到一分钟的时间里,将超过 2 亿行数据按 4 个不同的粒度级别进行了分组,其中最高级别的基数为 2,最低级别的基数约为 200,000!

结果

下表总结了结果:

整个管道的执行时间约为 2 小时 30 分钟,运行在 WSL 上,配置如下:约 16GB 的 DDR4 内存,一颗 Intel 第 12 代 Core i7 处理器,和 1TB NVMe SSD。

在这个过程中,我注意到内存使用成为了瓶颈,因为 DuckDB 不断在磁盘的 .temp/ 目录中创建临时文件。另外,我在运行带有 Windows 函数的查询时遇到了很多问题:这些查询不仅执行时间超出了预期,还出现了程序随机崩溃的情况。

尽管如此,我认为最终的性能是令人满意的,毕竟我们讨论的是仅通过一台计算机(相对于计算机集群而言,性能不算特别强大)处理 1/2TB 的数据,并执行复杂的查询。

结论

事实上,处理数据有时就像提炼铀矿。我们从一大堆原材料开始,通过一个艰难、耗时且昂贵的过程(有时还可能会危及生命),提取出一小部分有用的精炼信息。

说正经的,在我的帖子中,我探索了很多数据处理的方式,讨论了工具、技术、数据架构……始终在寻找最好的做事方法。这类知识非常重要,因为它帮助我们选择合适的工具来做合适的事。本文的目标正是要了解 DuckDB 能解决什么样的工作,以及它适合什么样的使用体验。

总体来说,这是一次不错的体验。

与这个数据库的合作非常顺利,我几乎不需要配置任何东西,只需导入数据并使用普通的 SQL 语句进行操作。换句话说,对于那些已经懂 SQL 和一点 Python 的人来说,这个工具几乎没有什么初始使用门槛。在我看来,这正是 DuckDB 的巨大胜利。它不仅赋予我的机器处理 450GB 数据的能力,而且在环境(以及程序员)适应成本低的情况下达成了这一目标。

就处理速度而言,考虑到项目的复杂性、450GB 的数据量以及我没有优化数据库参数,2 小时 30 分钟是一个不错的结果。尤其是考虑到,如果没有这个工具,想要在我的电脑上完成这个任务几乎是不可能的,或者说是极其复杂的。

DuckDB 在 Pandas 和 Spark 之间某种程度上是一个折衷方案。对于小规模数据,Pandas 在可用性方面可能更具吸引力,尤其是对于那些有一定编程背景的人,因为这个包有许多内置的转换,这些在 SQL 中实现起来可能会很棘手。它还与许多其他 Python 包,包括 DuckDB,具有无缝集成。对于极大规模的数据,Spark 可能是更好的选择,具备并行处理、集群等特性。因此,DuckDB 填补了中型到不算太大的项目的空白,在这些项目中使用 Pandas 是不可行的,而使用 Spark 则过于复杂。

DuckDB 扩展了单台机器能够达到的极限,并扩大了可以在本地开发的项目范围,为大规模数据分析/操作带来了速度。毫无疑问,它是一个强大的工具,我会自豪地将它加入到我的工具箱中。

此外,我希望这篇文章能帮助你更好地了解 DuckDB。和往常一样,我并不是本文所涉及主题的专家,我强烈建议进一步阅读,以下是我的参考文献,代码也可以在 GitHub 上找到。

感谢阅读! 😉

参考文献

所有代码可在 此 GitHub 仓库*中找到。

想了解更多类似的作品吗?访问我的* 文章仓库*。

[1] 2022 年结果 — 传输文件用于汇总 — TSE 开放数据门户链接。[ODbL]

[2] Databricks。(2023 年 6 月 29 日)。数据 + AI 峰会主旨演讲,星期四第五部分 — DuckDB。YouTube。

[3]DuckDB 官方 文档。DuckDB。

[4]电子投票箱最高选举法院

[5] 维基百科贡献者。(2023 年 7 月 25 日)。OLAP 立方体。维基百科。

[6] Duckdb — GitHub。窗口性能 · 问题 #7809 · duckdb/duckdb

[7] Gunnarmorling。GitHub — gunnarmorling/1brc: 1️⃣🐝🏎️ 十亿行挑战 — 一次有趣的探索,了解如何快速汇总来自文本文件的 10 亿行数据,使用 Java

我掌握 SAP 数据模型的第一步

原文:towardsdatascience.com/my-first-steps-into-mastering-saps-data-models-4d20ad2485f2?source=collection_archive---------2-----------------------#2024-04-24

如果你是一个好奇的读者,想了解更多关于 SAP 数据模型的信息,那么你来对地方了!

Ayoub El OutatiTowards Data Science Ayoub El Outati

·发布于 Towards Data Science ·8 分钟阅读·2024 年 4 月 24 日

--

你好,Medium 的读者们!我很高兴与大家分享我最近在一个项目中的一些收获,在这个项目中,我深入探讨了 SAP 数据模型的复杂性。

出于保密原因,我无法分享所有的项目细节 🤫。不过,我会讨论一个我在面对 SAP 数据模型复杂性时遇到的挑战:SAP 数据架构是什么样的,所有内容如何整合成一个对业务用户有意义的连贯数据模型?

在这个项目中,我的主要工作是将数据集成到一个分析/挖掘平台中,使用 SAP 采购流程相关的业务数据。随着项目的进展,我遇到了一些关于数据建模的问题,并且需要在 SAP 数据库的各个表格之间找到我的方向。

图片来自Tingey Injury Law FirmUnsplash

我的目标是创建一个连贯的数据模型,能够高效地支撑未来的仪表盘、报告以及其他分析输出,所有这些都将在下一步计划中实现。为了达到这个目标,我必须完全理解 SAP 数据架构的基本知识、表映射等内容。这不是一项简单的任务,但我成功地掌握了这些新知识 🚀

1) SAP,企业资源规划软件的领导者

如你所知,SAP 是企业资源规划(ERP)软件的领先提供商之一,与 Oracle 和 Infor 等公司并列。目前,SAP ERP 软件有两代版本:SAP ECC 及其继任者 SAP S/4 HANA。

两个软件的组件涵盖了公司的所有职能,这些职能不仅是内向型的,还包括外向型的,例如客户侧(CRM)和供应商侧(SCM)。它由不同的模块组成,涵盖了公司所有的需求:销售、生产、物流、财务、人力资源、售后服务。

SAP ECC 模块 — 作者提供的图片

这是一个将所有业务流程整合到单一数据库中的软件包。它还自动考虑不同流程之间的相互依赖关系。换句话说,SAP 或一般的 ERP 软件,作为业务流程的核心,依靠大量的数据和事务驱动。

对于数据分析和过程优化项目,我们只能说,将 SAP 中的数据整合并为前端使用建模,以提供业务流程的端到端视图并识别低效之处,确实是有益的。然而,在没有首先了解数据在 SAP 中如何组织的情况下,这可能会非常具有挑战性。

2) SAP 的数据和表格类型

SAP 的数据是如何组织的?简而言之,SAP 数据库中的多个表格存储了所有由 SAP 事务生成的数据,例如在采购功能中创建供应商或采购订单。此外,需要注意的是,虽然一些表格随时间推移保持稳定,但其他表格则存储了捕捉到特定商业事件的数据,这些事件对日常运营至关重要。在 SAP 中创建供应商或采购订单的例子很好地说明了这一点。

让我们以汽车行业的一家公司为例。该公司会有一个供应商表,列出他们常合作的供应商,以采购制造汽车和其他物品所需的材料。在此表中创建一个新供应商是一个罕见的事件,可能一年只会发生几次。这个表格被称为主表,仅存储主数据。

与此同时,创建采购订单是一个操作任务。公司内部的采购团队通常使用 MM(物料管理)事务来创建采购订单(例如,数百万个采购订单)。因此,存储所有这些采购订单的表格被称为事务表,这种类型的数据被称为事务数据。

在 SAP ECC 和 S/4 HANA 中,两个表有特定的名称。供应商主表被称为 LFA1 表,而采购订单存储在 EKKO 表中。 对于第一次接触 SAP 表的人来说,可能无法仅凭表名就清楚每个表所代表的内容。例如,如果我提到 MCHA、MSEG、BSEG 等名称,你会知道这些是 SAP 表,但不一定知道它们存储的信息。它们是存储生产订单、财务账户还是发票?

别担心😉,在文章的末尾你会找到一个有用的网站,可以进一步了解 SAP 表格命名,照片来源:Siora PhotographyUnsplash

这正是为什么对我来说,进行研究、做笔记并记住在项目范围内的表格是如此重要。更重要的是,我需要理解每个表格中使用的关系和字段。在这个阶段,映射这些数据以便快速理解也同样重要。

3) SAP 采购表格的数据映射

在这类项目中,你可能会面临的另一个挑战是理解 SAP 的数据架构以及如何定义采购业务流程。幸运的是,我之前在与采购团队合作的流程改进项目中的经验,对于理解采购流程非常有帮助。在进入 SAP 数据模型或采购表格映射之前,让我们简要回顾一下采购流程中的主要对象。

简而言之,采购是从供应商那里购买某些物品的过程,可能是原材料、服务、工具等。该过程包括接收购买的物品、验证其状态,然后启动支付流程,通常称为应付账款流程。然而,这一过程超出了我们此次讨论的范围。

照片来源:Alexander GreyUnsplash

在整个过程中,采购团队承担着各种任务:他们验证采购需求单,创建采购订单,并将这些订单发送给供应商。大多数这些任务是通过企业资源计划(ERP)软件完成的,特别是 SAP 中的物料管理(MM)模块。在这一系列活动流程中,关键对象或组件从一个步骤过渡到下一个步骤:

  • 采购需求单: 这些通常由生产团队在内部创建,旨在通知采购团队某个特定物品需要为生产目的进行采购。

  • 采购订单: 这些订单由采购团队创建,包含有关待购买物品的详细信息,如数量、应接收采购订单的供应商及其他相关数据。

  • 收货单: 在收到采购订单中提到的商品后,供应商会提供收货单。该收货单对于核实仓库收到的商品是否与采购订单中的要求相符至关重要。

  • 发票收据: 该文档确认收到的商品和服务是正确的,并且与采购订单一致。它常用于启动支付流程以支付供应商。

SAP 中的每个对象都有其关联的交易表,存储所有已创建的对象。例如,如前所述,采购订单存储在 EKKO 表中。然而,在 SAP 表格中还有另一层复杂性需要考虑:头部和项目文档的概念。

在 SAP 中,像采购订单这样的对象或文档有两个层次的表示:头部层次和项目层次。这是因为,从概念上讲,文档始终由两个层次的信息组成。头部包含一般和汇总数据,而项目则对应文档中的具体行。

举个例子,假设你正在购买 PS5、《最后生还者 2》和一副耳机。

图片由 Samuel Regan-Asante 提供,来源于 Unsplash

您最终的订单将包括包含您的地址和订单号的头部。您购买的商品以及它们的数量(在这种情况下,每个项目一个)和价格,将在项目行中单独列出。

采购订单格式 — 作者提供的图像

然后,一个常见的问题是:EKKO 表代表的是头部数据还是项目数据?正确答案是头部数据项目数据存储在一个名为 EKPO 的单独表中。

这个核心概念非常重要,因为它适用于大多数 SAP 对象/文档。发票将有一个用于头部数据的表格和另一个用于项目的表格,货物收据也是如此。然而,采购申请是一个例外,它只有一个项目表。

为了整合和分析 SAP 采购的数据,我需要确定合适的表格来提取数据,并了解它们之间的关系,从而构建相应的数据模型。我所做的映射,旨在可视化这些数据元素是如何相互连接的,详细信息如下:

SAP 采购数据映射示例 — 作者提供的图像

显然,我所做的映射主要集中在采购流程的基本交易表上。可以将额外的表格,包括主数据表和其他交易表——例如,存储采购订单变更的表格——包含在映射中。

这个映射还突出了表格之间的关系,无论是“一对一”、“一对多”还是“多对多”。此外,它还包括构成每个对象主键的字段。

理解每个表的作用以及潜在的关系可能会很耗时。然而,有许多资源可以提供帮助,例如那些帮助我创建这个高级采购数据映射的资源。如果你对了解 SAP 表感兴趣,可以考虑访问这个网站。它提供了每个 SAP 表的详细概述,包括主键、字段、可能的值等。你只需要输入你正在寻找的表名,例如 MSEG,你就可以获得关于该表结构和存储信息类型的更多详细信息:

搜索 SAP 表 | LeanX

最后,如果这篇文章让你对其他业务功能的 SAP 数据模型产生了更多兴趣,别犹豫,快去查看 SAP 社区网站。你将会找到一些有趣的内容,例如同样的财务功能高级数据映射:

MM 和 FI 表的关系(基础)— SAP 社区

希望你喜欢这篇关于 SAP 数据模型的入门指南!欢迎留下掌声 👏、评论 🗨️ 或与同事分享 📨 这篇文章。

感谢阅读,别忘了关注我获取更多文章 🚀

我的自由职业经历:作为一名地理数据科学家在 UpWork 上的 10 个月

原文:towardsdatascience.com/my-freelance-experience-as-a-geo-data-scientist-on-upwork-after-10-months-9c5c293acad4?source=collection_archive---------0-----------------------#2024-10-14

如何获得客户、评价和稳定收入

Aleksei RozanovTowards Data Science Aleksei Rozanov

·发布于 Towards Data Science ·7 分钟阅读·2024 年 10 月 14 日

--

图片来源:Phil Desforges 通过 Unsplash

大约 10 个月前,我决定在几年的纯学术专业经验后,尝试一些自由职业/咨询。然而,我发现我的技能组合有些稀有且奇特。

一方面,我的“大数据与机器学习”硕士学位让我掌握了一系列非常有用的工具和工业数据处理、建模与分析的知识。然而,我距离自称为专家还差得远。可以说,我仅具备一名初级数据科学家的技能。

另一方面,作为来自地球科学领域的人,我知道如何处理地理空间数据,特别是遥感数据(例如卫星影像)。从 NASA 的 API 中提取一幅景象,进行掩膜处理并执行一些计算对我来说根本不是什么难事。然而,要称自己为该领域的专家,你需要对大气光学、物理学和地理信息系统(GIS)有更深入的理解。

因此,我渴望体验学术界之外的世界,最终尝试在这两个领域之间找到平衡。令人惊讶的是,这个决定证明是相当正确的,因为这一专业技能是具体的小众的

全球有成千上万的人在机器学习领域做得非常出色,开发出最疯狂的架构,也有很多人掌握了遥感和卫星数据处理的扎实知识。然而,能够将这两个领域结合起来的专业人士要少得多。我当时心想:我可能会成为其中之一!

我的 UpWork 个人资料公开页面。图片来自作者

我坚信,细分市场(niche)这个词在这个故事中非常重要。首先,我对它感到害怕。万一没人需要呢?万一没有项目,换句话说,没钱呢?万一这个专业领域太过具体呢?

但最终,它成了关键

UpWork似乎是获取客户的最大平台之一,提供了大量的地理数据科学项目。意识到其潜力后,我注册了账号,快速填写了个人资料,并在我的作品集中展示了一个个人项目。

没有虚假的期望:我知道,第一份完成的项目和客户评价是开始盈利的必要条件。

当然,我一开始提交的提案都是报酬最低的项目(所有工作报酬在25 美元50 美元之间)。这种工作有两个显著的问题:

  • 首先,即使任务看起来很简单,你也可能需要至少一个小时的时间与客户沟通、收集数据、上传结果以及进行分析或建模工作。在最好的情况下,你每小时能赚大约20 美元,但更常见的是接近10 美元12 美元

  • 其次,预算较低的客户往往是最挑剔的。平均而言,他们在项目进行过程中会提出新的需求,并频繁要求额外的修改。这些客户通常偏好固定价格的项目,也就是说,无论你投入多少努力、时间或资源,预算始终不变——而且通常非常少(即极其少)。情况就是这样。如果你有不同的经验,我很乐意听听!

当然,确实有一些天生较小的项目。我有过几次经验,当时我被要求修复别人的代码,而这些项目对我来说非常顺利且愉快(希望客户也同样喜欢)。对于这些小的修改,我按小时计费,让这份工作变得轻松愉快

然而,前两三个项目至关重要,因为那时我的个人资料几乎是空白的。没有任何实际工作经验,重要客户是不会信任我的个人资料的。关键是要持续申请工作。

个月的稳定努力之后,我开始从各类客户那里获得稳定的工作量。

最有收获的方面是我获得的知识:从掌握基本的沟通、推销和谈判技巧,到探索领域内的新方法、新模型、新库和新工具。

在成为自由职业者的个月后,我成功确保了每周 20 到 25 小时的稳定工作量,那时这对我来说是完美的。总体来说,我在 2024 年 1 月到 8 月之间在 UpWork 上的总收入为6,338 美元,对于一个经验几乎为零的人来说,我认为这是一个非常不错的数字。

不幸的是,最近我不得不暂停我的 UpWork 之旅,但我真的决定将来重新回归!

与此同时,我想分享一些关于这次经历的提示和关键收获。希望它们能帮助你取得成功!

I. 首先,请尽可能缩小范围。

不要仅仅称自己为机器学习工程师,而要称自己为金融数据科学家/数据标注员/医学计算机视觉专家/LLM 架构师,或者像我一样称自己为地理数据科学家。你可以接到通用数据科学项目,但你也会有一个固定的细分领域。

我参与的几乎所有项目都在地理空间领域,但其中一个项目是关于从科学文献中合成 NLP 数据集的。我完全有能力收集、构建和描述这类数据,尽管我官方上是一个地理方向的人。顺便说一句,客户给了我 5/5 的评分!所以,细分领域并不会限制你,不用担心。

II. 竞争越少越好。

听起来有些矛盾,但请相信我。我见过成百上千个有超过 50 个提案的项目。你不想参与这种竞争,不是因为你不配,而是因为客户根本无法充分评估所有专业人士的技能。他们只希望工作能完成。

所以只有在提交的申请不超过 10 个时才申请,或者申请,但避免对你能参与该合同的工作抱有任何不切实际的期望。

一个有超过 50 个提案的项目示例。图片来自作者

一个提案提交少于 5 个的更细分领域项目示例。图片来自作者

III. 不要忽视你在提交提案时写的内容。

通常情况下,UpWork 会隐藏客户的名字,但如果你查看他们的客户评价,你可能会发现其他自由职业者在评论中提到了这个人的名字,所以请用他们名字来称呼客户!

添加一个外部作品集,比如 GitHub/个人网站或任何其他平台,这些可以作为你技能的指示器。UpWork 允许你从你的 UpWork 作品集中附加最多 3 个项目,但通常这不够。所以一定要分享一些链接给客户,让他们更深入地了解你的其他项目。在我的情况下,许多客户提到我的 Medium 博客是雇佣我的关键原因之一。这是一个直接的方式来展示我的技术和沟通能力。

最后,不要把提案的重点放在你的个人特点上。我知道你想证明自己是这个项目的最佳人选。但没人关心,抱歉。客户唯一关心的,是问题是否能得到解决。所以,专注于此。展示你关心和热衷帮助的态度。把客户当作愿意治疗伤口的医生,而不是追逐下一个收入的商人。

IV. 避免与“转售商”打交道。

这类客户像是经纪人或代理人。他们在多个平台上寻找甜蜜的项目,签订合同后,再去 UpWork 雇佣其他人以较低价格完成工作。

为什么这不好?因为他们几乎不和你以及另一方的客户进行沟通。他们根本不在项目的背景中。

我非常不幸,第一次接到这种客户。他最后创建了一个群聊,把我和最初的客户加进来,让我不要告诉对方是我在做这项工作。

解决方案很简单:检查他们以前的项目。通常,代理商有大量的评价(不一定是负面的)。但项目的标题之间完全没有关联。如果这个人雇佣别人做数据分析、卫星图像处理、视频编辑、解答数学奥林匹克题目和 3D 设计,确实会让人感到怀疑!

V. 使用 GPT 让你的个人项目更出彩。你在 UpWork 上展示的任何作品都应该有一些背后的故事。刚开始时,你唯一能展示的可能是一个个人项目或某个课堂作业。利用 ChatGPT 创建一个引人注目的作品描述,让它看起来像一项真正的工作,且具有影响力和行业应用。

对于你的第一个客户,你会逐步添加更多作品,但根据我的经验,4-5 个作品示例就足够了

这是我自己的 UpWork 作品集的一部分。图片来自作者

希望这四个要点能为你提供信息和启发。我希望在开始这段旅程之前,10 个月前就知道这些。

如果你考虑创建一个新的副业,或者辞去朝九晚五的工作成为自由职业者,UpWork 可能是一个非常不错的选择,至少在最初的 1-1.5 年内。

祝你有启发性的项目、有趣的客户和快速的职业成长。如果你喜欢这篇文章,给它点个赞吧 😃

===========================================

我在 Medium 上的所有出版物都是免费且开放访问的,因此如果你在这里关注我,我会非常感激!

P.s. 我对(地理)数据科学、机器学习/人工智能和气候变化充满热情。如果你想一起做项目,请通过LinkedIn联系我,或者查看我的网站

🛰️关注以获取更多内容🛰️

我给想成为数据科学家的人的诚实建议

原文:towardsdatascience.com/my-honest-advice-for-someone-who-wants-to-become-a-data-scientist-1ecc018fb0b2?source=collection_archive---------1-----------------------#2024-08-12

我希望有人在我学习数据科学之前告诉我这些话

Egor HowellTowards Data Science Egor Howell

·发布于 Towards Data Science ·阅读时长 6 分钟 ·2024 年 8 月 12 日

--

图片来源:GRAY via Unsplash

所以,我收到了很多关于如何进入数据科学领域的咨询信息。虽然我已经写过几篇文章,详细介绍了我会遵循的路线图和过程,但我认为有时候听听残酷的真相也挺好的。这就是为什么在这篇文章中,我想给你们那些想成为数据科学家的朋友们提供我完全诚实的建议。

注意:你也可以观看这篇博客文章的视频版本。

现在就开始

我经常被问到,应该使用平台 A还是平台 B课程 X课程 Y,哪个更好?

我以前已经说过很多次了,现在再说一遍:这不重要。

任何大型课程提供商或高度评价的课程,尤其是在初学者和入门阶段,都会教你相同的内容。你花在选择课程的时间,不如花时间去真正做……

我一年后再次被聘用的经历……第一部分

原文:towardsdatascience.com/my-learning-to-being-hired-again-after-a-year-part-i-b99a11255c5d?source=collection_archive---------1-----------------------#2024-06-23

适合任何求职的人,不仅仅是科技行业的人

Amy MaTowards Data Science Amy Ma

·发表于Towards Data Science ·10 分钟阅读·2024 年 6 月 23 日

--

图片由Jacqueline Munguía提供,来源于Unsplash

一年前的今天,2023 年 5 月 13 日,我被裁员了。今天,我开始了新工作的第一天。在过去的一年里,我成为了一名母亲,并发现了自己从未意识到的部分。

我想分享一些我在这个旅程中的学习经验。但如果你在寻找破解编程面试或者应对行为问题的技巧,这不是那种类型的文章。这些文章通常会详细列出作者与多少家大公司进行了面试,收到的工作邀请,使用的备考资源,甚至提供他们面试过程的详细流程。他们总是以“这并不容易,我哭过,担心过,但我还是做到了。祝你好运!”来结束。虽然我尊重并感谢他们的坦率,但这些文章常常让我感到焦虑和不安。

这篇文章是为任何正在找工作的人写的,无论你处于哪个阶段或职位类型。我想特别向那些在求职过程中感到冷漠和沮丧的人伸出援手,因为我曾经也有过这样的感受。这是我的个人哲学之一:成功故事不会激励我,除非它们详细描述了艰难困苦以及如何克服。我想从错误和障碍中学习,而不是从别人开心的派对中得到启发。

重新找回自我:你不仅仅是一个工作

我的 Medium 之旅作为数据科学家:6 个月,18 篇文章和 3,000 个粉丝

原文:towardsdatascience.com/my-medium-journey-as-a-data-scientist-6-months-18-articles-and-3-000-followers-c449306e45f7?source=collection_archive---------4-----------------------#2024-11-11

Medium 写作者的真实数据、收入和数据驱动的增长策略

Yu DongTowards Data Science Yu Dong

·发表于Towards Data Science ·9 分钟阅读·2024 年 11 月 11 日

--

我在 2024 年 5 月开始在 Medium 上撰写数据科学和人工智能内容。这是我的第六个月,我刚刚达到了一个重要的里程碑——3,000 个粉丝!我为自己的成就感到非常自豪。

在这篇文章中,我将分享这段旅程是如何开始的,我写了什么,以及我学到了什么。此外,作为一名数据科学家,我总是喜欢分析我自己的数据。我收集了一个包含我的 Medium 统计数据的数据集,包括文章的浏览👀、阅读📖、点赞👏、收入💵等。加入我,一起用数据分析我的 Medium 经历,并分享我的数据驱动写作策略

图片由 DALL·E 创建

我的 Medium 之旅概览

一切是如何开始的

我的写作习惯远在我开始在 Medium 上写作之前就已经有了。自 2018 年起,我便开始运营我的数据科学作品集网站,那时我刚刚开始了我的第一份全职工作。我会在网站上发布文章,并偶尔分享到 LinkedIn。这有助于我与数据领域的朋友和同事建立联系。今年早些时候,我发布了一篇关于我对定制 GPT 的实验的文章,它在 LinkedIn 上达到了近 10k 的曝光量。这个成绩还不错,但……

我的每周日程安排作为高级数据科学经理

原文:towardsdatascience.com/my-weekly-calendar-as-a-senior-data-science-manager-d57112ae372d?source=collection_archive---------0-----------------------#2024-09-05

我的目标是覆盖 3 个 P:人员、项目和流程,按这个重要性顺序。我的目标是“50:50”法则。

Jose ParreñoTowards Data Science Jose Parreño

·发表于 Towards Data Science ·14 分钟阅读·2024 年 9 月 5 日

--

图片由 Djim Loic 提供,来源于 Unsplash

你是刚晋升的经理吗?或者,如果你已经是经理,你还记得从个人贡献者角色过渡到管理者时的经历吗?在写这篇文章时,我已经领导数据科学团队 4 到 5 年了,但我仍然记得作为经理初期的主要挑战之一:时间管理

数据科学家或高级数据科学家的个人贡献者通常专注于 1 到 2 个问题领域。IC(Individual Contributor,个人贡献者)角色应该为你提供空间,专注于一个深度问题,减少来自其他项目和行政工作的干扰。你可能会和工程同事或其他数据科学同事进行一些临时性的讨论,但这些是必要的,因为技术设计和系统依赖性对于交付至关重要。当然,必须参加与利益相关者的每周会议,但你可能每周只有 1 次这样的会议。因此,可以说,作为一个个人贡献者,你的每周大约有 20% 的时间用于开会,80% 用于技术工作。

现在…如果不是只处理一个项目,而是管理 3 到 5 个项目并且有 6 到 12 个直接下属会发生什么?如果你不小心

N-BEATS — 第一个在时间序列预测中有效的可解释深度学习模型

原文:towardsdatascience.com/n-beats-the-first-interpretable-deep-learning-model-that-worked-for-time-series-forecasting-06920daadac2?source=collection_archive---------2-----------------------#2024-05-11

易于理解的 N-BEATS 工作原理深度解析,以及你如何使用它。

Jonte DanckerTowards Data Science Jonte Dancker

·发表于Towards Data Science ·11 分钟阅读·2024 年 5 月 11 日

--

N-BEATS 架构(图片来自Oreshkin 等人)。

时间序列预测一直是深度学习和 Transformers 未能超越其他模型的唯一领域。

从 Makridakis M 比赛来看,获胜的解决方案总是依赖于统计模型。直到 M4 比赛,获胜的解决方案才是纯统计模型或机器学习与统计模型的混合。纯机器学习方法几乎无法超越竞争基准。

这一切在 2020 年由Oreshkin 等人发表的一篇论文中发生了变化。作者们发布了 N-BEATS,这是一种有前景的纯深度学习方法。该模型超越了 M4 竞赛的获胜方案。它是第一个超越成熟统计方法的纯深度学习方法。

N-BEATS 代表Neural Basis Expansion Analysis for Interpretable Time Series(可解释的时间序列神经基础扩展分析)。

在本文中,我将介绍 N-BEATS 背后的架构。但请不要担心,深入讲解将会易于理解。我还将展示如何使深度学习方法具有可解释性。然而,仅仅理解 N-BEATS 如何工作是不够的。因此,我将向你展示如何…

N-HiTS — 使深度学习在时间序列预测中更加高效

原文:towardsdatascience.com/n-hits-making-deep-learning-for-time-series-forecasting-more-efficient-d00956fc3e93?source=collection_archive---------3-----------------------#2024-05-30

深入探讨 N-HiTS 如何工作以及如何使用它

Jonte DanckerTowards Data Science Jonte Dancker

·发布于 Towards Data Science ·10 分钟阅读·2024 年 5 月 30 日

--

N-HiTS 架构(图片来自 Challu 和 Olivares 等人)。

2020 年,N-BEATS 是第一个在时间序列预测中超越统计模型和混合模型的深度学习模型。

两年后的 2022 年,一款新模型将 N-BEATS 从其宝座上推翻了。Challu 和 Olivares 等人发布了深度学习模型 N-HiTS。他们解决了 N-BEATS 在较长预测时间跨度上的两个短板:

  • 降低准确度并

  • 增加计算量。

N-HiTS 代表 Neural Hierarchical Interpolation for Time Series Forecasting。

该模型基于 N-BEATS 及其神经基扩展的思想。神经基扩展发生在多个跨层堆叠的块中。

在本文中,我将介绍 N-HiTS 背后的架构,特别是与 N-BEATS 的区别。但是不要担心,深入探讨将易于理解。不过,单单理解 N-HiTS 如何工作还不够。因此,我将向你展示如何在 Python 中轻松实现一个 N-HiTS 模型,并调整其超参数。

如果核心思想相同,N-BEATS 和 N-HiTS 有什么区别呢

N-of-1 试验与分析你自己的健身数据

原文:towardsdatascience.com/n-of-1-trials-and-analyzing-your-own-fitness-data-7ae390036094?source=collection_archive---------3-----------------------#2024-04-20

喝酒后我真的睡得更差吗?

Merete LutzTowards Data Science Merete Lutz

·发表于Towards Data Science ·阅读时间:11 分钟·2024 年 4 月 20 日

--

图片由Luke Chesser提供,来源于Unsplash

我第一次听说 N-of-1 试验是在 2018 年,当时我还是一名流行病学的硕士生。我在我的中级流行病学与临床研究方法课上,听了一位Dr. Eric Daza博士关于 N-of-1 研究设计的讲座。N-of-1 研究可以看作是一项临床试验,旨在研究某一干预措施对个体患者的疗效。当时,这种方法学是一个新兴的实践,具有个性化医学和优化个人医疗保健的有希望的前景。

在讲座的最后,他提到 N-of-1 实验是一些对 Fitbit 着迷并熟悉统计学的人的兴趣爱好。这些数据科学家能够获取自己的生物识别数据,进行实验和分析,以优化他们的睡眠、锻炼和饮食。我对此感到非常着迷。

这次讲座的时机简直太完美了。我在过去的五年里一直是女子排球队的学生运动员,正准备开始我将持续五年的职业排球生涯。我刚开始使用Whoop带子来追踪我的锻炼、睡眠和恢复,手头有了这些新的数据。我正在学习统计学和研究设计,而现在我知道如何将它们应用到实践中,也许还能在我的新职业生涯中占得先机。

在接下来的一年半里,我通过我的 Whoop 设备严格记录了所有数据,但最终停止佩戴该设备,因为当时公司不允许你下载自己的数据。快进到几年后,现在我是英属哥伦比亚大学的数据科学硕士生。我掌握了更多的分析方法,而且 Whoop 终于允许你访问你的旧数据!现在,比以往任何时候,我终于可以进行自己的 N-of-1 研究,回答当时我所困惑的一些问题。

在进行分析之前,我们必须首先定义 N-of-1 研究框架,并回顾一下历史上因果医学研究的方法。

随机对照试验(RCT)

现代医学研究的黄金标准是随机对照试验(RCT)。假设我们想知道一种新药是否能够降低心脏病发作的风险。在 RCT 中,一组患者会被随机分配到治疗组(服用药物)或对照组(服用安慰剂)。研究人员巧妙地设计了这个实验,使得两组之间的个体具有相似的特征,唯一的主要区别就是他们是否服用药物。我们会跟踪这些人一段时间,记录发生的心脏病发作事件。在实验结束时,我们统计每组中的数据,进行一些统计分析,以便比较两组之间的心脏病发作发生率,看看是否存在统计学上显著的差异。

RCT 是因果推断中极为强大的工具,使我们能够发现某种干预是否会导致期望的反应。它们是应用医学研究的历史支柱,但由于它们过于强调结果的普适性,受到了一定的限制。当 RCT 进行时,最终的目标是为更广泛的群体设定一个新的实践标准,而不仅仅是研究参与者。我们根据样本对整体人群做出推断,在此过程中,我们往往会忽略个体反应的差异。这种做法几乎与医学的目标相违背,因为医学的目标是为个体患者提供治疗。

N-of-1 试验

N-of-1 试验通过将 RCT(随机对照试验)研究设计应用于个体层面,解决了这个局限性。它们使我们能够探索患者对特定治疗的反应差异,并通过减少在次优治疗上花费的时间,帮助改善患者的治疗效果。尽管 N-of-1 研究的理念已经存在一段时间,但现在由于技术的进步,收集数据变得更加容易,这类研究也变得更加可行。

N-of-1 试验并不总是个性化医学的答案。对于像传染病这样的快速变化的疾病,你可能没有时间进行这种个性化的试验,最好采取更为普适的方法。然而,对于慢性病的治疗,N-of-1 试验提供了一条极具前景的途径,有助于改善健康结果。这些病症可能不会直接威胁生命,并且可以在长时间内观察到。这为尝试不同的干预提供了机会,旨在找到最佳治疗方法。

除了医学领域,你还可以将 N-of-1 试验应用到日常生活中。我们有多少人尝试过新的药物、饮食、补充品、锻炼或睡眠习惯,却很难确定它是否有效?很难明确说明干预是否有效,因为大多数现有的证据都是轶事性质的,或者难以量化。通过结合 N-of-1 研究框架和你从可穿戴健康追踪器获取的生物数据,你可以获得确凿的证据,从而做出你知道会带来改变的生活方式调整。

实践中的 N-of-1 试验

为了给你展示这种方法的实际应用,我将基于从 2018 年 4 月 27 日到 2019 年 10 月 5 日从我的 Whoop 手环收集的一些数据进行分析。我们这项 N-of-1 研究的研究问题是:

饮酒会导致睡眠质量差吗?

作为一名运动员和流行病学家,我非常清楚酒精对睡眠、运动表现和整体健康的危害。我常常听说运动员不应饮酒,然而,别人告诉你是一回事,而亲眼看到证据又是另一回事。当我开始佩戴 Whoop 时,我注意到每次饮酒后,睡眠评分(由 Whoop 应用计算得出的指标)都会下降。有时甚至第二天,我也能看到这种影响。这些观察促使我想进行自己的分析,现在我终于完成了。

数据说明

我们分析中关注的两个变量是睡眠表现评分和酒精消费量。睡眠表现评分的范围是从 0 到 100,是由 Whoop 应用通过生物数据计算得出的指标,包括呼吸频率、浅睡眠时长、慢波睡眠时长和快速眼动睡眠时长。

酒精消费变量是对问题“你昨天有喝酒吗?”的回答,该问题是 Whoop 用户在每天早晨醒来时回答的。我总是如实而一致地回答这些问题,尽管我们的数据存在局限,因为应用程序并未询问摄入了多少酒精。这意味着所有的酒精消费量都被视为相同,这排除了我们对关系进行更深入分析的机会。在我们的酒精数据中存在一些缺失值,但这些缺失的数据被用“没有”填充,因为根据个人经验,如果我前一晚喝酒了,我肯定会在应用程序中标明。

探索性数据分析

任何分析的第一步都是进行一些探索性数据分析(EDA)。这只是为了大致了解我们的数据是什么样的,并创建一个可视化图形,帮助我们指导进一步的研究。

图 1. 按酒精摄入水平分布的睡眠表现分数的探索性图表。

从上面的箱形图可以看出,当没有摄入酒精时,平均睡眠分数似乎较高,并且分布较为集中。奇怪的是,当没有摄入酒精时,睡眠表现分数似乎有更多的异常值。也许旅行日和时差反应能解释这些异常值,因为在这段样本期间,我曾出国旅行 5 次。

现在我们已经对感兴趣的数据有了一个初步了解,是时候深入进行统计分析了。

假设检验

为了回答我们的研究问题,我将进行假设检验。假设检验是一种统计方法,允许我们基于某些样本数据对总体做出推断。在这种情况下,我们试图推断我饮酒与当晚睡眠质量差之间是否存在关联。我们没有关于我一生中每个晚上酒精摄入与睡眠情况的数据,因此我们使用样本数据作为代理进行研究。

假设检验的第一步是制定假设。‘零假设’是假设没有有趣的现象发生,或者不存在关系或效果。在我们的案例中,零假设是:摄入酒精和未摄入酒精的晚上,平均睡眠表现没有差异。

‘备择假设’是与零假设相对立的假设,声称实际上有某种有趣的现象发生。在我们的例子中,备择假设是:摄入酒精和未摄入酒精的晚上,平均睡眠表现存在差异。

选择统计检验

要评估这些假设哪个为真,我们必须选择一个统计检验方法。我们想知道我在饮酒的夜晚与未饮酒的夜晚,平均睡眠得分是否不同,因此我们将使用均值差异来进行检验。具体来说,我们的检验统计量是:无酒精的平均睡眠表现 - 有酒精的平均睡眠表现

现在我们已经定义了框架,我们可以使用 R 来计算检验统计量并评估我们的假设。

在 R 中进行分析

从我们的样本数据中,我们可以计算出观察到的检验统计量。以下是 R 中的代码:

test_stat <- data |> 
  specify(formula = sleep_performance ~ alcohol) |> 
  calculate(
    stat = "diff in means",
    order = c("No", "Yes")
  )

我们的检验统计量是 8.01。这个数字意味着,在我没有饮酒的夜晚,平均睡眠得分比我饮酒的夜晚高 8.01 分。

接下来的分析步骤是从我们的样本数据生成一个零分布。零分布表示如果我们从总体中重复抽取样本,可能观察到的所有不同的检验统计量值。该分布反映了由于随机抽样引起的检验统计量的变异。下面是用 R 生成零分布的过程:

set.seed(42) #Setting seed for reproducibility

null_distribution <- data |> 
  specify(formula = sleep_performance ~ alcohol) |> 
  hypothesize(null = "independence") |> 
  generate(reps = 1000, type = "permute") |> 
  calculate(
    stat = "diff in means",
    order = c("No", "Yes")
  )

上述操作是从数据中进行有放回的抽样,并计算这些样本的均值差异。我们执行这一过程 1000 次,以生成足够大的分布,从而判断我们的观察到的检验统计量是否具有显著性。

在得到零分布和检验统计量后,我们可以计算出双侧 p 值,显著性水平为 0.05。p 值可以理解为在零假设成立的前提下,观察到的检验统计量与我们的观测值一样极端或更极端的概率。简单来说,它表示如果不存在真实的关联,看到这个结果的可能性有多大。我们在 R 中计算双侧 p 值,如下所示,因为我们关心检验统计量是否比预期更大或更小。

p_value <- null_distribution|> 
  get_p_value(test_stat, direction = "both")

我们的 p 值是 0.017,这意味着在显著性水平 alpha=0.05 下,我们的发现具有统计学意义,这是统计学中通常接受的显著性水平。这意味着我们发现的睡眠得分差异是显著的!我们有证据拒绝零假设并接受备择假设;在饮酒与未饮酒的夜晚之间,睡眠表现的均值存在差异。

我在下面包含了一个有用的可视化图,展示了零分布、检验统计量和 95%的分位数范围。灰色条形表示从我们的 1000 个样本中计算出的许多可能的检验统计量,橙色线表示这些值的密度。蓝色虚线表示该分布的 97.5 百分位数和 2.5 百分位数,在这些范围之外,我们的检验统计量(以红色表示)被显示为显著。

图 2。零假设下检验统计量的分布(酒精消费对平均睡眠得分无影响)

最终结论

好吧,结果证明我的教练们一直都是对的!我们的分析发现,当我没有消费酒精时,我的平均睡眠得分比我消费酒精时高出 8.01 分。这一差异被发现具有统计学意义,p 值为 0.017,这意味着我们拒绝零假设,支持备择假设。这个统计结果支持了我的个人经验,给了我一个可以信赖的定量结果。

进一步探索

既然我已经完成了这项初步分析,我可以在数据中探索更多的关联,甚至使用更复杂的方法,如预测和机器学习模型。

这项分析是一个非常基础的 N-of-1 研究示例,也有其局限性。我的研究是观察性的而非实验性的,我们不能声明因果关系,因为有许多其他混杂变量是我的 Whoop 设备没有测量到的。如果我想找到因果关系,我需要精心设计研究,记录所有可能的混杂因素,并找到一种方法让我在治疗过程中保持盲目。N-of-1 研究在临床环境之外很难进行,但我们仍然可以通过向数据提出简单问题来发现有意义的关联和关系。

我希望在这个教程之后,你能够主动从任何可以获得的健身追踪器中下载你的数据,并进行尝试。我知道每个人都能提出一个关于某个变量如何影响健康的假设,但大多数人没有意识到,你离获得该问题的量化答案比你想象的要近。

参考文献与进一步阅读

[1] Davidson, K., Cheung, K., Friel, C., & Suls, J. (2022)。将数据科学引入 N-of-1 设计、统计学、应用案例、未来及“ N-of-1”试验这一称谓。哈佛数据科学评论,(特刊第 3 期)。doi.org/10.1162/99608f92.116c43fe

[2] Lillie EO, Patay B, Diamant J, Issell B, Topol EJ, Schork NJ. N-of-1 临床试验:个性化医学的终极策略?个体化医学。2011 年 3 月;8(2):161–173。doi: 10.2217/pme.11.7。PMID: 21695041;PMCID: PMC3118090。

[3] Daza EJ. 使用反事实框架进行 N-of-1 试验的自我追踪时间序列数据因果分析。方法与信息医学。2018 年 2 月;57(1):e10-e21。doi: 10.3414/ME16–02–0044。2018 年 4 月 5 日电子出版。PMID: 29621835;PMCID: PMC6087468。

[4] Schork, N. 个性化医学:是时候进行单人试验了。自然 520,609–611(2015)。doi.org/10.1038/520609a

完美应对机器学习设计面试

原文:towardsdatascience.com/nailing-the-machine-learning-design-interview-6b91bc1d036c?source=collection_archive---------8-----------------------#2024-06-18

FAANG 设计面试的提示和技巧

Rhea GoelTowards Data Science Rhea Goel

·发布在 Towards Data Science ·8 分钟阅读·2024 年 6 月 18 日

--

通过 Dall-E 生成

我是亚马逊的高级应用科学家,曾在机器学习设计面试问题的两方角色中都有过经验。我希望分享我随着时间积累的所有技巧和窍门。通过阅读本文,你将了解面试中可能遇到的内容,面试官的期望,候选人常犯的错误和陷阱,以及如何根据职位的资历/级别调整你的回答。接下来,我还会发布一系列关于常见机器学习设计面试问题的文章(包括解决方案)。敬请关注!

什么是机器学习设计面试?

机器学习设计面试是一个问题解决会议,特别关注机器学习在商业应用中的使用。面试的目的是评估你是否能够将商业问题转化为机器学习问题,并提出一个端到端的策略,以在生产环境中应用机器学习算法。

期望的内容

你将被给定一个真实世界的商业问题,通常是与你所面试的公司或根据你的简历相关的专业领域的问题。你需要主导面试全过程,并经常与面试官确认方向及时间管理上的建议。讨论通常是开放式的,可能涉及白板工具(如Excalidraw)或共享文档(如Google Docs)。通常,这一轮面试不需要编程。

FAANG 及类似公司常见的机器学习设计问题:

  • 为电商平台设计一个推荐系统

  • 为银行应用设计一个欺诈检测系统

  • 设计一个系统,将客户服务票据自动分配给正确的处理团队

面试官在寻找什么

从高层次来看,面试官需要收集以下数据:

  1. 科学广度与深度: 你能为业务问题识别出机器学习解决方案吗?

  2. 问题解决: 你能充分理解业务用例/问题吗?

  3. 行业机器学习应用经验: 你能将机器学习算法应用于生产环境中吗?

具体来说,当你阐述解决方案时,面试官会关注以下内容:

  1. 理解业务用例/问题: 你会提出澄清问题,确保自己完全理解问题吗?你理解机器学习解决方案如何用于下游任务吗?

  2. 识别业务成功指标: 你能定义明确的业务指标来衡量成功吗?比如通过与问题相关联的指标,如点击率、收入或更低的解决时间?

  3. 将业务问题转化为机器学习问题: 你能识别出适用于此问题的正确机器学习算法类型吗,比如分类、回归、聚类或其他算法?

  4. 识别系统的高层次组件: 你能识别出整个系统的关键组件吗?你能展示各种在线和离线组件如何相互作用吗?你是否遵循一个有条理的思考过程:从数据收集、预处理、模型训练、部署,到用户服务层?

  5. 建议相关数据/特征: 你能识别出哪些数据和特征对模型的性能至关重要吗?你能推理出最佳的数据收集策略吗——通过人工标注收集真实数据、使用隐式数据(例如用户点击)还是使用一些自动标注方法?你能推理不同数据源的质量吗?

  6. 预测特征/标签可能存在的偏差或问题并提出缓解策略: 你能预测数据质量问题吗,比如缺失数据、稀疏特征或特征过多?你是否考虑过标签中的噪音?你能预见数据中的偏差吗,比如流行度偏差或位置偏差?你如何解决这些问题?

  7. 使用简单模型设定基准并推理为何需要更复杂的模型: 你能为该问题建议合适的算法吗?你是否建议构建一个基于启发式的方法的简单模型或轻量级模型,作为一个基准模型,以便在需要时评估更高级/复杂的模型?在从简单模型迁移到更复杂的模型时,你能推理出性能与复杂度的权衡吗?

  8. 训练流水线经验: 你能解释训练模型的不同步骤吗?你如何进行训练-测试-验证集的划分?你会使用什么损失函数?你会使用什么优化器?你会使用什么架构和激活函数?你会采取哪些步骤来防止过拟合?

  9. 提出离线评估指标和在线实验设计:你能为你的模型识别出合适的评估指标吗(例如,精确度、召回率)?你能提出一个好的在线实验设计吗?你是否提议采取错开拨号策略以减少意外问题时的影响范围?

常见的好与不好回答错误

#1 直接跳入模型

有些候选人直接跳到他们会使用的机器学习算法来解决问题,而没有先阐明业务应用、解决方案的目标以及成功的衡量标准。

不好的回答:“对于欺诈检测,我会使用深度神经网络,因为它很强大。”

好的回答:“这个解决方案是用来做每次刷卡的实时欺诈检测吗?这意味着我们需要一个快速且高效的模型。让我先确定可以用来训练这个模型的所有数据。首先,我有交易元数据,例如交易金额、地点和时间。我还可以使用这张卡片的历史交易数据——我可以查看过去 30 天的数据,从而减少需要实时分析的数据量,或者我可能会预先计算从交易历史中得出的分类/二元特征,比如‘is_transaction_30_days’(30 天内是否有交易)、‘most_frequent_transaction_location_30days’(过去 30 天最频繁的交易地点)等。最初,我将使用逻辑回归来设置基准,然后再考虑使用更复杂的模型,如深度神经网络,如果有必要的话。”

#2 过于泛泛而谈

你不仅仅需要提供一个套用的策略,还应该在每个步骤中加入与特定业务问题相关的具体例子。

不好的回答:“我会做探索性数据分析,去除异常值,然后构建模型预测用户参与度。”

好的回答:“我会分析历史用户数据,包括页面浏览量、点击率以及在网站上停留的时间。我会分析如产品类别、品牌等分类特征,如果有超过 75%的值缺失,我会将其删除。但在这一步我会小心,因为某些特征的缺失有时可能也包含很有价值的信息。逻辑回归模型可以作为起点,如果需要的话,我会进一步使用像随机森林之类的更复杂的模型。”

#3 只解决理想情况

如果候选人只谈论数据和建模策略,而没有讨论数据质量问题或其他在实际应用中可能遇到的细节,那么很容易看出其缺乏行业经验。

不好的回答:“我会使用过去的用户-商品点击数据来训练分类器,预测广告点击。”

好的回答:“过去的用户-商品点击数据可能会有位置偏差,因为在搜索结果中,排名靠前的商品更容易被点击。我会通过估计每个位置的点击概率(即倾向度),并使用逆权重倾向性来修正这个位置偏差,从而加权所有标签。”

#4 从最复杂的模型开始

你需要展示采取行动的偏好,使用易于开发、成本较低且时间消耗少的轻量级模型,并根据需要引入复杂性。

不良回答:“我将使用最先进的双编码器深度学习架构来构建推荐系统。”

良好回答:“我将从一个简单的协同过滤方法开始,建立一个基准模型。一旦了解其表现,我们可以根据初步结果的需要引入矩阵分解或深度学习模型,如双编码器。”

#5 在遇到难题时不做调整

面试官可能会打断你的策略,提出后续问题或提出替代情境,以了解你对不同技术的理解深度。你应能够在他们引入新挑战或变化时调整你的策略。

不良回答:“如果我们无法访问用户的个人身份信息,我们就无法构建个性化模型。”

良好回答:“对于选择不分享个人身份信息或过去互动数据的用户,我们可以将他们视为冷启动用户,并向他们推荐基于人气的内容。我们还可以加入在线会话 RNN,基于他们在会话中的活动调整推荐。”

根据层级调整回应

随着职位层级的提升,回答中的广度和深度预期也会增加。通过一个例子可以更好地说明这一点。假设你被要求设计一个在线支付平台的欺诈检测系统。

入门级(0-2 年相关行业经验)

对于这个层级,候选人应关注数据(特征、预处理技术)、模型(简单的基准模型、更复杂的模型、损失函数、优化方法)以及评估指标(离线指标、A/B 实验设计)。一个好的流程如下:

  1. 特征识别与预处理:例如交易金额、位置、时间以及其他代表支付历史的分类特征。

  2. 基准模型与进阶模型:例如使用逻辑回归模型作为基准,考虑在下一版本中采用梯度提升树。

  3. 评估指标:例如精准率、召回率、F1 分数。

中级经验(3-6 年相关行业经验)

对于这个层级,候选人应关注业务问题以及在生产环境中部署模型时的细节。一个好的流程如下:

  1. 业务需求:例如在召回率和精准率之间的权衡,我们希望减少欺诈金额,同时保持低假阳性率,以提供更好的用户体验;强调需要可解释的模型。

  2. 数据细节:例如欺诈交易的数量远少于非欺诈交易,可以使用像 SMOTE 这样的技术来解决类别不平衡问题。

  3. 模型权衡:例如,先使用基于启发式的基线模型,然后是逻辑回归,再接着是基于树的模型,因为它们比使用难以解释的非线性特征变换的逻辑回归更易于解释。

  4. 讨论部署细节:例如,实时交易处理,以及更新模型的节奏,以适应变化中的欺诈模式。

高级/员工/首席级别经验(6 年以上)

对于这个层级,候选人需要利用他们多年的经验,深入思考更广泛的生态系统,识别这个领域中的核心挑战,并强调不同的机器学习子系统如何协同解决更大的问题。解决如实时数据处理和确保模型对抗攻击的鲁棒性等挑战。提出一种多层次的方案:使用规则引擎系统进行即时标记,并使用深度学习模型进行模式识别。包括反馈回路和监控机制,以确保模型能够适应新的欺诈形式。同时,展示你对最新行业趋势的了解(例如,使用 GPU、表示学习、强化学习、边缘计算、联邦机器学习、构建不涉及 PII 数据的模型、机器学习中的公平性与偏见等)。

希望这份指南能帮助你顺利应对机器学习设计面试!如果有任何想法或经验分享,请留言评论。

其他有用资源

朴素贝叶斯,清晰解释

原文:towardsdatascience.com/naive-bayes-clearly-explained-880e00560050?source=collection_archive---------3-----------------------#2024-03-02

使用高斯方法在 scikit-learn 中解决鸢尾花数据集问题。

Yoann MocquinTowards Data Science Yoann Mocquin

·发布于 Towards Data Science ·10 分钟阅读·2024 年 3 月 2 日

--

在这篇文章中,我们将深入探讨一种特定的分类器,称为朴素贝叶斯分类器。这些方法依赖于贝叶斯定理和朴素假设,即每一对特征在给定类别标签的条件下是条件独立的。如果这对你来说没有意义,继续阅读!

作为一个示例,我们将使用著名的鸢尾花数据集(CC BY 4.0 许可证)和一种特定的朴素贝叶斯分类器,称为高斯朴素贝叶斯分类器。请记住,鸢尾花数据集由 4 个数值特征组成,目标是三种鸢尾花类型中的任意一种(setosa、versicolor、virginica)。

我们将把该方法分解为以下步骤:

所有图像由作者提供。

  1. 复习贝叶斯定理:这个定理提供了一个数学公式,允许我们估计一个给定样本属于某个类别的概率。

  2. 我们可以创建一个分类器,一个对于输入样本返回预测类别的工具,通过比较该样本属于每个类别的概率。

  3. 使用链式法则和条件独立性假设,我们可以简化概率公式。

  4. 然后,为了能够计算概率,我们……

命名实体识别揭秘——必备指南

原文:towardsdatascience.com/named-entity-recognition-unmasked-the-essential-guide-404ad0568964?source=collection_archive---------3-----------------------#2024-03-29

如何像专家一样使用 NER 从文本语料库中提取个人信息

RechitasinghTowards Data Science Rechitasingh

·发表于 Towards Data Science ·9 分钟阅读·2024 年 3 月 29 日

--

图片来源:Christopher GowerUnsplash

介绍

好吧,假设这样一种情况——你有大量的文章、期刊和博客,里面充满了你想处理的信息。现在假设你认为,如果社区能够有机会处理这些数据,可能对他们也会有所帮助,但是,你不想立即分享这些数据,因为它可能包含某些个人信息,未经这些人的同意不应共享。

由于向所有这些人请求许可不可行,你决定运用自己的技能,在 FERPA 指南下掩盖任何个人信息。公司在将数据分享出去进行分析或演示时,通常会对数据进行掩盖,数字数据更容易做到这一点。我们在这里也想做同样的事情,但针对的是文本数据。

现在,在这里,由于我们讨论的是文本数据,我们将使用自然语言处理(NLP)中的一种技术。引入命名实体识别(NER),一个可靠的 NLP 侦探,揭示那些隐藏的数据宝藏。这里的目的是识别个人信息。

让我们深入了解 NER 是如何工作的,NER 机制背后的概念,如何实现 NER,选择哪种解决方案方法以及为什么选择它,以及如何在 Python 中实现这一问题的解决方案。

命名实体识别(NER):技术解析

简单来说,NER 就是教计算机在文本中识别特定的“实体”,在这种情况下是个人身份信息(PII)。你可以把它想象成给程序一套荧光笔——一种标记名字,一种标记地点,一种标记公司,一种标记大学、学生 ID、电子邮件地址或任何能识别个人身份并危及其隐私的内容。下面是 NER 如何工作的一个简要介绍:

  • 基于规则的系统: 传统的方法。我们创建手写规则,例如“一个名字通常以大写字母开头”。对于基础案例有效,但可能变得非常复杂。此外,如果有很多规则,系统会变得越来越混乱。

  • 机器学习方法: 统计模型通过大量数据集进行学习。可以把它理解为向你的命名实体识别(NER)系统展示大量示例,让它自己找出模式。这就是机器学习在所有领域中的工作方式。然而,涉及文本数据时,它仍然可能存在性能问题。

  • 深度学习明星: 神经网络是解决文本、图像和视频数据相关问题的最著名模型方法。处理复杂语言,类似我们人类的方式。这些模型能够理解上下文,使其非常精确。唯一的条件是,你需要使用大量的数据,否则模型将大部分训练数据记住(过拟合)。虽然有控制过拟合的技术,但它仍然在大量数据的语料库中效果最好。

详细的技术分析:NER 背后的“大脑”

我们已经看到,NER 可以利用多种技术,每种技术都有其优势。下面是更深入的分析:

  1. 条件随机场(CRFs):

想象一下,你在教 NER 系统识别位置。你展示像“10 Made UP Street, London, UK”这样的地址示例。CRFs 在这方面表现出色,因为它们会查看整个词序列及其关系。它们考虑到“London”跟在数字后面,且“UK”跟在城市后面,这强烈暗示这是一个位置实体。这使得 CRFs 在像 NER 这样的任务中非常强大,因为上下文非常重要。

阅读这篇关于 CRF 及其背后数学原理的精彩文章:条件随机场解释,作者:Nikos Kafritsas

2. LSTM 网络(长短期记忆网络):

假设你想识别文本中的人名。LSTM 是继 RNN 之后的一项伟大发明,因为它们具有一种特殊的能力——记忆!是的,它们可以保持记忆或上下文,并根据上下文进行预测。与仅考虑当前单词的 CRFs 不同,LSTM 可以记住序列中的前一个单词,并不会偏离上下文。这对 NER 非常重要。为什么?因为这将帮助你理解它是指“Apple”公司,还是指“Apple”水果。

再举一个例子:“史密斯博士是著名的心脏病专家”这句话中,LSTM 能够记住头衔“博士”,并利用这个上下文正确地将“史密斯”分类为一个人名。

这里有一个现实世界的例子:假设你正在构建一个新闻分类模型,依据提到的人物来分类文章。基于 LSTM 的 NER 系统肯定可以识别诸如“巴拉克·奥巴马”或“Elon Musk”这样的实体,即使他们的名字出现在复杂的句子中并被正确分类。不错的实现,对吧?

阅读这篇文章了解Rian Dolphin的《LSTM 全面介绍》。

3. Transformers:

Transformers 是当前 NLP 领域的热门话题,NER 也不例外。这些模型使用一种注意力机制,就像你将所有注意力集中在画中的某个细节上一样。它们的做法是,关注整个句子中的相关词汇,而不仅仅是附近的词。可以想象成一个人看着他/她从未读过的文本。我们在各处扫视,聚焦(像聚光灯一样)不同的部分并捕捉意义。这种技术使得它们能够理解复杂的关系,甚至识别出那些不太显眼的实体。

例如,考虑这句话:“位于加利福尼亚的 Acme 公司首席执行官宣布了一款新产品的发布。” 基于 Transformer 的命名实体识别(NER)系统可以关注“CEO”和“Acme 公司”,即使它们之间被几个词分开。然后,它可以利用这种关注来正确地将“Acme 公司”分类为一个组织。

这种能力使得 Transformer 非常适合用于诸如在研究论文中识别医学术语或在社交媒体数据中识别特定产品名称等任务。

阅读这篇文章了解微调能做什么:使用 Transformers 和 Spacy 进行 NER by James Briggs。如果你仍在疑惑 Attention 是什么,先阅读这篇文章,它详细讲解了 Attention:你需要了解的所有关于‘Attention’和‘Transformers’的信息 — 深入理解 — 第一部分 by Arjun Sarkar

超越基础:新兴技术

NER 领域正在不断发展。以下是一些值得关注的令人兴奋的发展:

  • 双向 LSTM(BiLSTM): 这些是升级版的 LSTM,能够正向和反向处理文本。这使得它们能够更深入地理解上下文。它也有一些缺点,因为你无法使用它来进行预测,因为你要同时将句子正向和反向输入。所以系统能知道上下文。

  • 命名实体消歧(NED): 再次以苹果为例。假设你在文本中看到了“Apple”这个名字。是指科技巨头还是指水果?NER 可以与 NED 结合,以识别在上下文中最可能的含义。

通过理解这些技术并保持对最新进展的更新,你可以利用 NER 的力量从文本数据中提取有价值的个人信息,并推动你的研究工作。

命名实体识别(NER)实践:你的项目代码一瞥

是时候动手实践了!假设你正在使用 Python 和强大的 spaCy 库:

Python

python -m spacy download en_core_web_trf
pip install spacy
pip install nltk
import spacy
from spacy import displacy
import nltk  # Assuming you might use NLTK for additional tasks 
from nltk.corpus import stopwords  # Example of utilizing NLTK

# Load a powerful pre-trained NER model (adjust based on your needs)
nlp = spacy.load("en_core_web_trf")

# Define your text for analysis
text = """
Jane Doe, a researcher at Stanford University, recently published a paper on 
Natural Language Processing.  Dr. John Smith from MIT will be collaborating on the 
project.  They can be reached at jane.doe@stanford.edu and john_smith@mit.edu.
"""

# Process the text with NER
doc = nlp(text)

# Print identified entities 
print("Entities Found:")
for entity in doc.ents:
    print(entity.text, entity.label_)

# Visualize the NER results
displacy.render(doc, style="ent", jupyter=True) 

输出

作者的图片 — 上述代码的输出

如上所示,虽然 NER 能够完美识别人名和组织名,但却漏掉了它们的电子邮件地址。让我们探索一下为什么会这样以及如何解决。

漏掉电子邮件地址的原因

  1. NER 模型的局限性: 标准的 NER 模型通常训练于诸如人名、组织、地点等类别。虽然它们可能会识别一些电子邮件模式,但这并不是它们的主要强项。因此,在本例中它漏掉了电子邮件地址。

  2. 电子邮件地址的复杂性: 电子邮件格式可能出奇地多样。像“Gmail”和“Yahoo”这样的简单格式可能会被识别,但更复杂的模式可能会被漏掉。例如,它可能识别 gmail ID,但可能遗漏一些组织特定的 ID。再次强调,这就是在本例中发生的情况。

虽然我们知道原因,但我们可能更关注如何解决问题!

让我们使用以下三种技术中的一种来个性化并解决当前任务:

  • 正则表达式(Regex): 正则表达式允许构建特定的模式来匹配电子邮件地址。这一方法很早就被开发出来并且被广泛使用。你可以将它视为一种在编程中识别模式的硬编码方式。以下是一个基本示例:
import re

email_regex = r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b'
emails = re.findall(email_regex, text)
print(emails)
  • 专门化库: 你可以使用像 email_validator 这样的库。这些库专门用于电子邮件识别和验证,如果验证电子邮件是你的用例。

  • 增强你的 NER 模型: 你可以通过提供电子邮件地址作为额外的实体类型来微调你现有的模型。然而,这将需要更多的数据,并且可能需要更复杂的模型训练。这包括使用像 BERT 这样的预训练模型等。再次,阅读这篇文章 NER 与 Transformers 和 Spacy 由James Briggs撰写。它讨论了如何微调 roBERTA 并使用 spaCy。这样,你会更加清楚这一选项。

为了演示目的,我们选择第一种方法并将其实现到代码中。我们可以增加一个专门提取电子邮件的部分,如下所示:

Python

import re

email_regex = r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b'
emails = re.findall(email_regex, text)
print("Emails Found:", emails) 

输出

作者的图片 — 上述代码的输出

选择最佳方法

理想的解决方案取决于你的具体项目:

  • 简单的电子邮件 + 准确性: 正则表达式可能足够了。

  • 复杂的电子邮件 + 可靠性: 专门的电子邮件验证库是最安全的选择。

  • 广泛的 NER 再训练: 如果其他实体的 NER 准确性很重要,并且你有大量以电子邮件为重点的数据,重新训练你的模型可能是长期解决方案。你没猜错,你可以使用一些先进的技术,如通过微调 BERT 预训练模型,并将其应用于你的问题。使用这些技术之前需要考虑一些重要事项:

    • 数据: 微调通常需要相当数量的标注数据。如果你的数据有限,其他技术(如正则表达式)可能一开始更实用。

    • 复杂性: 微调涉及更多的设置,并且可能需要比使用正则表达式或基础库更多的计算资源。

提升你的 NER 技能的资源

前路如何

NER 仍然是一个热门的研究领域——准备迎接那些能理解复杂关系、识别自定义实体并跨语言工作的模型吧!这项技术正在革新我们从浩瀚的文本海洋中提取和利用信息的方式。

随时提出更技术性的问题或真实世界的例子,帮助我们完善这篇关于命名实体识别(NER)的指南!

关于作者

亲爱的读者们,我对这个话题充满热情,喜欢写关于数据科学的文章和发人深省的文章。最重要的是,我愿意接受反馈!

我很想知道你的评论。如果这篇文章对你有帮助,或者你有任何反馈,请不要犹豫,留言告诉我!如果你希望进一步解释某个话题,也可以留言,我会尽力在这里解答,或者写另一篇相关文章!

如果想了解更多关于我的信息,我这里有一篇小文章供你阅读:

[## 关于我 — Rechita Singh

数据奇才,思维探索者,在这过程中精通制作烂笑话的艺术

medium.com

让我们开始聊天吧……

导航成本与复杂性:思维混合 LLM 级联揭示了高效部署大语言模型的路径

原文:towardsdatascience.com/navigating-cost-complexity-mixture-of-thought-llm-cascades-illuminate-a-path-to-efficient-large-23291d1eda41?source=collection_archive---------5-----------------------#2024-03-06

Yuval ZukermanTowards Data Science Yuval Zukerman

·发布于Towards Data Science ·5 分钟阅读·2024 年 3 月 6 日

--

图片来源:Joshua SortinoUnsplash

如果我告诉你,在不牺牲准确性的情况下,你可以节省 60%或更多的 LLM API 开销,你会相信吗?令人惊讶的是,现在你可以。

大型语言模型(LLM)现在已经成为我们日常生活的一部分。公司利用这项技术来自动化流程、改善客户体验、开发更好的产品、节省开支等。

托管自己的 LLM 非常具有挑战性。它们提供了广泛的功能,但运行成本通常很高。它们通常需要复杂的基础设施和大量数据。成本和复杂性是你使用提示工程的原因。你甚至可能使用检索增强生成(RAG)来改善上下文并减少幻觉。通过这两种技术,你将 LLM 的运行工作卸载给像 OpenAI、Cohere 或 Google 这样的公司。然而,将 LLM 的应用扩展到新的使用场景,尤其是使用最新的强大模型,可能会产生一个之前没有预料到的新成本。较弱的模型可能更便宜,但你能信任它们回答复杂的问题吗?现在,新的研究展示了如何节省开支,并获得与强大模型相媲美,甚至更好的 LLM 结果。

了解 LLM 级联

在追求降低 LLM 成本的过程中,研究人员转向了 LLM 级联的概念。在黑暗时代,也就是 ChatGPT 推出之前,谷歌和多伦多大学的一个团队定义了这个术语,指的是利用概率计算,通过多个 LLM 获得最佳结果的程序。

最近,FrugalGPT 论文将级联定义为将用户查询依次发送给一系列 LLM,从较弱到较强,直到答案足够好为止。FrugalGPT 级联使用专门的模型来确定何时答案达到了质量阈值,足够好。

一篇来自乔治梅森大学、微软和弗吉尼亚理工大学的最新论文《‘基于思维混合表示的低成本推理的大型语言模型级联’》提供了一种替代方案:一种可以判断答案是否足够好的方法,而无需微调另一个模型。

思维混合 LLM 级联

与使用多个 LLM 不同,‘思维混合’(MoT)推理只使用两个——GPT 3.5 Turbo 和 GPT 4。前者被认为是‘较弱’的 LLM,而后者则是‘较强’的 LLM。作者利用 LLM 的‘答案一致性’来判断 LLM 的回答是否足够好。当 LLM 对类似提示给出一致的回答时,说明它对答案有信心。因此,当较弱的 LLM 的回答一致时,就不需要调用较强的 LLM。相反,当 LLM 缺乏信心时,它们会给出不一致的回答。这时,就需要更强的 LLM 来回答提示。(注:你也可以选择任何较弱/较强的 LLM 组合。)

这些提示本身使用少量示例上下文提示,以提高 LLM 回答质量。这样的提示通过提供类似问题和答案的示例来引导 LLM 的回应。

为了改进模型推理并简化一致性测量,研究人员通过‘混合’两种提示技术,提出了一种新的推理任务提示方法:

  • 思维链(CoT)提示鼓励 LLM 在给出最终答案之前生成中间步骤或推理过程。生成这些步骤有助于模型改善复杂任务的结果,同时提高答案的准确性。

  • 思维程序(PoT)扩展了思维链提示,使用模型的输出作为进一步提示的新输入。使用这种技术的提示通常要求模型用代码回答问题,而非人类语言。

论文还介绍了两种确定答案一致性的方法:

  • 投票:该方法通过类似的提示或改变响应温度选项,从 LLM 查询中采样多个答案。然后,衡量 LLM 的答案之间的相似度。与其他所有答案最为一致的答案被认为是正确的。团队还定义了一个灵活的“阈值”值,以平衡答案一致性与预算限制。

  • 验证:该方法比较 LLM 在两种不同思维表示(例如 CoT 和 PoT)下最一致的答案。如果两个提示的回答相同,则算法接受较弱 LLM 的答案。

由于投票方法需要多个提示,因此当有预算来指导阈值时,它可能更为适用。

结论:思维混合帮助你节省成本

让我们看看 MoT 技术节省了多少成本,以及它对答案准确性的影响。

研究人员使用以下公式来计算提示成本:

  • 较弱模型的提示成本(因为我们可能会提示它多次)

  • 答案评估过程的成本

  • 如果评估过程拒绝答案,我们将添加提示强模型的成本。

结果非常显著:

  • 使用 MoT 变体——将投票和验证与 CoT 和 PoT 结合——可以在仅使用 GPT-4 的 40%成本下,达到相当的表现。

  • 在与CREPE问答数据集的测试中,MoT 以仅 47%的成本超越了 GPT-4。

  • 将 PoT 与 CoT 结合使用,相较于单独使用其中一种技术,能提高决策的准确性。

  • 在使用投票方法时,即使提高阈值,额外的成本也未显著影响质量。

  • 一致性模型在可靠地识别正确的 LLM 答案方面证明了其有效性。它成功地预测了何时应当使用强模型来获得最佳结果。

在内部托管和管理大型语言模型(LLMs)带来了显著的挑战。它们带来了复杂性、高成本,并且需要大量的基础设施和数据资源。因此,对于希望利用 LLM 广泛能力的组织来说,LLM 呈现出实质性的障碍。这可能促使你转向托管的 LLM。然而,这种方法在扩展到新的应用场景时,往往给公司带来预料之外的成本增加和预算挑战。当整合最新的强大模型时,这一点尤为明显。为了避免这种命运,你面临一个新的困境:你能信任较弱且更实惠的模型吗?你能克服它们在处理复杂问题时的准确性顾虑吗?

LLM 级联与思维混合(MoT)提供了两项重要的进展:

  1. 相较于仅使用最新模型,能够实现显著的成本节省。

  2. 与最新模型相当的可验证结果。

这一突破为组织提供了一种实用且高效的方法,在 LLM 的强大能力与有效管理成本之间找到微妙的平衡。

Domino 的工作人员软件工程师 Subir Mansukhani 为这篇文章做出了贡献。

在数据马拉松中导航数据:见解与指导 [NeurIPS’23]

原文:towardsdatascience.com/navigating-data-in-datathons-insights-and-guidelines-at-neural-information-processing-systems-26ef8a1078d4?source=collection_archive---------11-----------------------#2024-02-09

如何在数据马拉松中处理数据

Carlos MouganTowards Data Science Carlos Mougan

·发布于 Towards Data Science ·阅读时间 8 分钟·2024 年 2 月 9 日

--

什么是数据马拉松?

数据马拉松或数据黑客马拉松,广义上定义为以数据或数据科学为中心的黑客马拉松,近年来变得越来越受欢迎,为参与者和组织提供了一个平台,能够在短时间内进行协作、创新和学习,专注于数据科学领域。

这些活动挑战参与者在限定的时间内解决与数据相关的问题,要求他们理解数据科学并对所使用的数据保持敏锐的意识。

问题是什么?

数据马拉松是高能量的活动,数据科学和机器学习从业者聚集一堂,共同解决紧迫问题,这不仅关乎创新,也关乎数据的有效处理。

尽管数据马拉松有显著的潜在好处,组织往往因缺乏明确的指南和最佳实践而在处理数据时面临困难,无法有效应对可能出现的问题。

这篇博客的目标是什么?

本文基于 2023 年神经信息处理系统会议上的一篇论文《如何在数据马拉松中处理数据》(How to Data in Datathons),深入探讨了准备和选择数据参加数据马拉松的关键方面,涉及:

— 数据对数据马拉松来说 合适 是什么意思?

— 多少数据是 足够 的数据?

— 我们如何识别、分类并使用 敏感 数据?

— 数据分析 准备好 了吗?

— 数据是否 可靠

该框架来源于艾伦·图灵研究所自 2016 年以来,与 60 多个合作组织共同举办 80 多场数据马拉松挑战赛的经验和洞察!!

它旨在提供一套指南和建议,帮助准备不同类型的数据,以应对数据马拉松,这些建议来源于广泛的数据马拉松组织经验。如果有兴趣,考虑作为参与者或挑战拥有者参加其中的一个数据学习小组活动;更多信息请见[此处]

这是艾伦·图灵研究所应用技能团队的合影,摄于 2023 年 5 月。

在数据马拉松中的数据评估

数据评估矩阵。摘自《如何在数据马拉松中使用数据》#NeurIPS23

对于数据马拉松来说,任何数据都不行。数据需要是“适用的”、“充分的”,并且要考虑到隐私问题。组织者和参与者常常会遇到这样的问题:什么样的数据适合数据马拉松?多少数据才算足够?我们如何处理敏感数据?每个维度都对于确保数据马拉松使用的数据适合、合乎伦理并有助于实现活动目标至关重要。让我们逐一探讨这些方面。

1. 数据适用性

数据的适用性关乎其在解决数据马拉松特定挑战问题时的相关性和实用性。本维度评估提供的数据是否与数据马拉松的目标一致,确保参与者能够获得合适的数据来进行工作。

  • 不足: 数据与数据马拉松的目标显然没有关联,使得参与者无法有效使用。例如,为一个关注金融预测的挑战提供天气数据完全不合适。

  • 发展中: 虽然数据与挑战有一定相关性,但缺乏进行全面分析或解决方案开发所需的关键元素或目标变量。

  • 功能性: 数据是相关的,可以直接应用于挑战。然而,通过加入更多变量或更详细的元数据,仍然有提升其价值的机会,这可能提供更深入的见解。

  • 最优: 提供的数据完美契合挑战需求,包含丰富的特征、相关的目标变量以及全面的元数据。这个层级代表了理想的情况,参与者可以获取所有进行分析和开发解决方案所需的信息。

2. 数据准备度

准备度评估数据的状态,涉及其是否为即时分析做好了准备。这包括数据的清洁度、完整性、结构和可访问性等因素,这些因素对数据马拉松的效率有着显著影响。

  • 不足: 数据要么没有收集,要么组织得非常糟糕,必须付出巨大努力才能使其可用。这种情况严重限制了在数据马拉松期间能取得的成果。

  • 正在开发中: 数据已被收集,但可能不完整、格式不一致或缺乏文档,需进行初步工作后才能开始有意义的分析。

  • 功能性: 虽然数据需要一些清理或预处理,但它大体上处于可以进行分析的状态。可能需要一些小的努力来整合数据源或正确格式化数据。

  • 最佳: 数据已准备好进行分析,文档齐全、清晰、结构化。参与者可以专注于应用数据科学技术,而无需担心数据准备工作。

3. 数据可靠性

数据可靠性涉及数据的准确性和偏差问题。它考量数据在多大程度上能够真实反映其所代表的现象或群体。

  • 不足: 数据存在严重偏差或包含显著错误,可能导致误导性结论。这类数据可能错误地代表某些群体或现象,导致分析结果偏差。

  • 正在开发中: 由于数据收集和记录过程中可能存在未知的偏差或潜在错误,数据的可靠性尚不确定。这种情况需要在解释时保持谨慎,并可能限制对结果的信心。

  • 功能性: 存在已知的偏差或问题,但可以通过仔细分析来解决,或承认这些问题是研究的局限性。这个层级的可靠性需要对数据的局限性保持透明。

  • 最佳: 数据被认为高度可靠,没有已知的显著偏差或错误。它准确地代表了目标现象,可以进行自信且稳健的分析。

4. 数据敏感性

数据敏感性涉及数据的隐私、机密性和伦理问题。它评估了使用和共享数据时所涉及的风险,特别是与个人或专有信息相关的风险。

  • 不足 (第四层级): 数据高度敏感,存在显著的法律、伦理或个人风险。这类数据通常不适合用于数据马拉松,因为其滥用或造成伤害的潜在风险较高。

  • 正在开发中 (第三层级): 虽然数据敏感性不如其他级别严重,但仍然需要严格的措施来保护隐私和机密性,这可能会限制其在数据马拉松等开放合作环境中的可用性。

  • 功能性 (第二层级): 数据敏感性通过去标识化或其他保护措施得到管理,但数据保护仍然是一个重要的考虑因素。参与者在分析过程中必须关注隐私问题。

  • 最佳 (第零层级/第一层级): 数据的敏感性风险最小,允许更为直接的共享和分析。这一层级适合促进开放合作,而不妥协隐私或伦理标准。

5. 充足性

充足性评估提供的数据量和类型是否足以有效解决挑战性问题。它考虑数据的数量、种类和粒度与数据挑战赛目标之间的关系。

  • 不足: 数据的数量或多样性过于有限,无法进行有意义的分析或得出可靠的结论。这样的不足可能严重妨碍数据挑战赛的成功。

  • 发展中: 虽然有一些数据可用,但其数量或质量可能不足以完全探索挑战性问题或构建稳健的模型。参与者可能会发现很难获得重要的见解。

  • 功能性: 提供的数据足够使参与者能够有意义地应对挑战性问题。尽管数据不全面,但它使参与者能够提取有用的见解并提出可行的解决方案。

  • 最佳: 数据丰富且多样,超出了数据挑战赛的基本要求。这个水平为参与者提供了丰富的探索空间,以便他们提出创新解决方案并进行深入分析。

见解与建议

数据学习小组(DSGs)是由英国国家数据科学与人工智能研究院——阿兰·图灵研究所组织的获奖协作型数据挑战赛活动。ADSGs 是由一个单一团队共同合作进行的(而不是多个团队彼此竞争)。数据学习小组的目标是为学术界和工业界的组织与参与者提供合作机会,利用数据科学和机器学习方法共同解决现实世界的挑战。DSGs 由一支专门的内部团队负责管理和准备,包括活动组织者和跨学科的学术支持人员。更多信息[这里]

成功的数据挑战赛是准备工作、灵活性以及组织者、挑战主办方和参与者共同努力的结果。我们概述了以下建议。

活动前:协作与对齐

成功的数据挑战赛的基础工作在活动之前就已奠定。与挑战主办方(商业合作伙伴)的早期合作至关重要。他们的领域专长和对数据的理解可以显著影响活动的方向和结果。他们对问题的理解和领域专长可以大大改善数据,早期的合作有助于对齐双方的目标和期望,从而增加活动取得成功的可能性。

随着数据挑战赛的临近,对数据准备情况进行合理性检查并根据经验调查员的反馈调整挑战问题是非常有益的,调查员能够平衡行业需求和研究需求,同时考虑参与者的视角。

数据挑战赛期间:适应与参与

现场活动是计划与现实的交汇点。PI(主要研究员)在引导参与者解决数据挑战并确保目标达成方面发挥着关键作用。此外,参与者的反馈是宝贵的资源。他们对数据的全新视角可以揭示新的洞察或识别改进的领域,使得数据马拉松成为一个动态环境,调整不仅是可能的,而且是受到鼓励的。

对实际应用案例感兴趣吗?在会议论文中,我们将 10 个应用案例映射到我们的框架中。

  1. Cefas: 环境、渔业与水产科学中心

  2. 谢菲联大学先进制造研究中心:基于多传感器的智能加工过程监控

  3. CityMaaS: 通过预测和个性化使城市中的人们更容易出行

  4. WWF: 保护区智能监测

  5. 英国南极考察:太空中的海豹

  6. DWP: 工作和养老金部

  7. 痴呆症研究所与 DEMON 网络:预测 DNA 序列与表观遗传状态之间的功能关系

  8. 自动化评估重症病人舌下微循环的灌注情况

  9. Entale: 播客发现的推荐系统

  10. Odin Vision: 探索 AI 支持的早期结直肠癌诊断决策

完整报告以及其他数据研究小组的成果可以在报告部分找到。

最后 10 个 DSG 报告的报告计数数据评估分类

结论

在本文中,我们从五个关键维度(适当性、准备性、可靠性、敏感性和充足性)分析了数据马拉松中的数据,这些维度来自于自 2016 年以来组织的 80 多场数据马拉松。通过这样做,我们希望在数据马拉松事件之前改善组织对数据的处理。

我们提出的定性分析提供了多个视角下的数据状态程度;这些程度可以根据需要调整或扩展,类似于美国国家航空航天局提供的技术准备度等级,后者随着时间的推移和进一步的工作不断扩展。

Bibtex 引用:

@inproceedings{
mougan2023how,
title={How to Data in Datathons},
author={Carlos Mougan and Richard Plant and Clare Teng and Marya Bazzi and Alvaro Cabrejas-Egea and Ryan Sze-Yin Chan and David Salvador Jasin and martin stoffel and Kirstie Jane Whitaker and JULES MANSER},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
year={2023},
url={https://openreview.net/forum?id=bjvRVA2ihO}
}

Mougan, C., Plant, R., Teng, C., Bazzi, M., Cabrejas-Egea, A., Chan, R. S.-Y., Jasin, D. S., Stoffel, M., Whitaker, K. J., & Manser, J. (2023). 如何在数据马拉松中使用数据。载于第三十七届神经信息处理系统大会数据集和基准跟踪

一张我(Carlos Mougan)在艾伦·图灵研究所的照片。(所有图片由作者提供并经授权使用)

数据科学导航:B2C 与 B2B 分析

原文:towardsdatascience.com/navigating-data-science-b2c-vs-b2b-analytics-a9ce007381b7?source=collection_archive---------5-----------------------#2024-07-29

客户类型如何塑造数据科学角色和方法论

Yu DongTowards Data Science Yu Dong

·发布于Towards Data Science ·阅读时间:10 分钟·2024 年 7 月 29 日

--

背景

在考虑加入一家公司或接受工作邀请时,我们通常会考虑行业、公司愿景、增长机会、文化等方面。今天,我想介绍另一个角度:公司是 B2B(企业对企业)还是 B2C(企业对消费者)。这一区分对数据科学角色的影响出乎意料地大。

作为一名数据科学家,已经有超过六年的经验,我在 B2C 公司(Ancestry.com,一家消费型家谱公司)工作了大约一半的时间,另外一半时间则在 B2B 公司(Brex,一家支出管理金融科技公司)工作。尽管这两个行业不同,但我注意到数据科学的方法论和由不同客户类型带来的挑战存在显著差异。在这篇文章中,我将讨论 B2B 和 B2C 公司之间在数据科学分析上的区别。

来源:DALL·E

I. 数据量与分析单元

B2C 和 B2B 公司之间最显著的区别之一是数据量和分析单元。

B2C:高数据量,聚焦个体用户

2024 年数据科学职位导航:角色、团队与技能

原文:towardsdatascience.com/navigating-data-science-jobs-in-2024-roles-teams-and-skills-c03193eb4c6e?source=collection_archive---------8-----------------------#2024-02-22

TDS 编辑Towards Data Science TDS 编辑

·发表于《Towards Data Science》 ·作为电子报发送 ·阅读时长 3 分钟·2024 年 2 月 22 日

--

无论你是申请第一次实习,还是负责管理一个由分析师和工程师组成的跨学科团队,数据科学职业都面临着一系列独特的挑战。其中一些可能比其他的更令人兴奋,而另一些则可能枯燥无味——这在任何工作中都是如此——但我们相信,可以将所有这些潜在的缺点视为加深知识、扩展技能并考虑新观点的机会。

本周的推荐内容汇集了围绕数据职业中常见障碍的各种观点和经验,并提出了有效的应对策略。无论你处于数据科学旅程的哪个阶段,我们希望你能浏览我们的推荐阅读,获得对自己工作有启发的见解。

  • 数据 ROI 金字塔:衡量与最大化数据团队价值的方法尽管Barr Moses的可操作路线图主要面向数据负责人和高管,但它对企业层级中的所有数据专业人士都是必备资源。毕竟,每个人都可以从了解自己的工作如何为业务做出贡献,以及如何向更广泛的非技术观众展示自己的影响力中获益。

  • 重建让我获得数据科学家职位的作品集 一年前,Matt Chapman 撰写了关于构建数据科学作品集的权威实践指南(并且在此过程中 非常 火爆)。在他最新的文章中,Matt 回顾了他的方法,并提出了几个关键更新,以实现更流畅的工作流程和更具可定制性的最终成果。

图片由 Anastase Maragos 提供,来源于 Unsplash

  • Spotify 高级数据科学家用来提高工作效率的 5 个习惯 在获得一份好的数据工作后,真正的挑战才开始:你可以做些什么来在新职位上脱颖而出,同时避免过度劳累和/或冒名顶替综合症?Khouloud El Alami 提出了五个具体的建议,你可以根据自己的需求加以调整,并且没有忽略任何细节。

  • 英特尔机器学习实习的 7 个教训 在银行业担任数据科学家多年后,Conor O'Sullivan 的职业生涯发生了重大转折,进入了科技巨头英特尔担任机器学习实习生;不要错过他在实习期间的经历和他在探索新行业和组织文化时学到的教训。

像往常一样,我们的作者在最近几周涵盖了广泛的主题,从 AI 的新兴技能到预测建模和深度学习。以下是我们不希望你错过的一些精彩文章。

  • 多模态模型在视觉文字谜题上的表现如何? Yennie Jun 测试了 GPT-4 Vision 和 Gemini Pro Vision 的能力,试图衡量它们在模型生成过程中呈现的创意水平。

  • 物理信息神经网络听起来可能是一个高深的理论概念,但正如 Shuai Guo 在一篇综合概述中所展示的那样,它们的现实世界应用非常广泛—并且增长速度健康。

  • 在一篇通俗易懂、图文并茂的解释文章中,施瑞娅·拉奥继续探索深度学习的基础话题,这次她深入解析了神经网络学习的过程。

  • 在她的 TDS 首秀中,尼西亚·拉马穆尔提展示了一个简单有效的框架,帮助你建立数据展示中的信心。

  • 如果你正在寻找一份关于 R 平方(决定系数)的清晰、详细指南,不要错过罗伯塔·罗卡的资源,她的指南将消除你对这一广泛使用的度量指标的任何困惑。

  • 想深入了解并动手实践吗?蒂姆·福斯特带我们一起走过使用线性求解器优化多维非线性神经网络的过程。

感谢你对我们作者工作的支持!如果你受到启发,想要加入他们的行列,为什么不写下你的第一篇文章?我们期待阅读

直到下一期《变量》专栏,

TDS 团队

使用 NetworkX 浏览网络:Python 中图形的简短指南

原文:towardsdatascience.com/navigating-networks-with-networkx-a-short-guide-to-graphs-in-python-c16cbafe8063?source=collection_archive---------1-----------------------#2024-11-18

Alina Grubnyak 提供的照片,来源于 Unsplash

探索 NetworkX,构建、分析和可视化 Python 中的图形。发现关联数据中的洞察。

Diego PenillaTowards Data Science Diego Penilla

·发布于 Towards Data Science ·14 分钟阅读·2024 年 11 月 18 日

--

在一个充满连接的世界中——从社交媒体上的朋友关系到复杂的交通网络——理解关系和模式是理解周围系统的关键。试想一下,如何将一个社交网络可视化,其中每个人都是一个点(“节点”),通过线(或“边”)与朋友相连。或者想象映射一个城市的地铁系统,其中每个车站是一个节点,每条路线是连接它们的边。

这就是 NetworkX 的亮点,它提供了一种强大的方式来构建、分析和可视化这些错综复杂的关系网络。

NetworkX 允许我们以传统表格方式难以实现的形式来表示数据,而图形格式使得这种表示变得既简便又自然。那些在电子表格中需要许多行和列才能定义的关系,可以通过直观的可视化方式捕捉,从而帮助我们理解和解读复杂数据。

Terry Vlisidis 提供的照片,来源于 Unsplash

该库允许我们将各种方法和算法应用于这些图形,每次通过新的方法重新构架数据时,都能提供新的洞察。

NetworkX

让我们首先分解一下什么是图。在网络分析中,图由节点(或顶点)和(或链接)组成。

  • 可以把节点看作是主要的实体,比如人或网页,边则是它们之间的连接——就像社交网络中的友谊或网站之间的超链接。

  • 一些边甚至带有权重,表示两个节点之间连接的强度、距离或成本。这个附加的信息层帮助我们分析两个节点不仅是否连接,还能知道它们的连接强度或紧密度。

这些图形可以用来建模各种系统,从社交网络到分子和交通网格。

让我们先看看如何使用networkx创建图。如果你还没有安装,可以先运行:

$ pip install networkx

创建图

要构建一个网络,我们将:

  1. 初始化网络:通过创建一个图G = nx.Graph()来初始化网络。

  2. 添加带属性的节点:使用 G.add_node() 来添加节点,每个节点可以存储自定义属性,如标签或年龄。

  3. 添加边:使用 G.add_edge() 来连接节点,每条边可以包含一个权重属性,表示连接的强度或成本。

  4. 可视化图形:使用 Matplotlib 函数,如nx.draw()nx.draw_networkx_edge_labels()来显示图形,展示节点标签和边权重,方便解释。

这是实现这一目标的代码:

import networkx as nx
import matplotlib.pyplot as plt

# Create a simple graph
G = nx.Graph()

# Add nodes with attributes (e.g., 'label' and 'age')
G.add_node(1, label="A", age=25)
G.add_node(2, label="B", age=30)
G.add_node(3, label="C", age=22)
G.add_node(4, label="D", age=28)

# Add weighted edges (node1, node2, weight)
G.add_edge(1, 2, weight=4)
G.add_edge(1, 3, weight=3)
G.add_edge(2, 4, weight=5)

# Retrieve and print node attributes
node_attributes = nx.get_node_attributes(G, 'age')  # Get 'age' attribute for all nodes
print("Node Attributes (Age):", node_attributes)

# Retrieve and print edge attributes
edge_weights = nx.get_edge_attributes(G, 'weight')  # Get 'weight' attribute for all edges
print("Edge Attributes (Weight):", edge_weights)

# Draw the graph with node and edge attributes
pos = nx.spring_layout(G)  # Layout for node positions
node_labels = nx.get_node_attributes(G, 'label')  # Get node labels for visualization
edge_labels = nx.get_edge_attributes(G, 'weight')  # Get edge weights for visualization

plt.figure(figsize=(6, 6))
nx.draw(G, pos, with_labels=True, node_color='skyblue', font_size=15, font_weight='bold', node_size=500)

# Draw the edge weights and node labels
nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)

plt.title("NetworkX Graph with Node and Edge Attributes")
plt.show()

图 1:一个带权重的图,节点为 1 到 4。图片来源:作者。

在这个例子中,我们初始化图形,然后创建:

  • 4 个节点(1、2、3、4)通过调用G.add_node(node, label, attr)

  • 3 条带权重的边连接这些节点:(1, 2)、(1, 3) 和 (2, 4),通过调用G.add_edge(node1, node2, attr)

在 NetworkX 中,节点和边都可以保存额外的属性,使得图形更加丰富和信息量更大。

  • 节点属性(通过nx.get_node_attributes(G, 'attribute')访问)允许每个节点存储数据,比如社交网络中某人的职业。

  • 边属性(通过nx.get_edge_attributes(G, 'attribute')访问)存储每个连接的信息,比如交通网络中的距离或旅行时间。这些属性增加了上下文和深度,使得网络的分析更加详细。

然后,我们使用 NetworkX 的弹簧布局pos = nx.spring_layout(G)来为节点定位,以便自然地将它们排布在图中。最后,nx.draw()nx.draw_networkx_edge_labels() 显示带有节点标签和边权重的图,清晰地呈现网络的结构和连接。

虽然这是一个相对简单的网络,但它展示了处理网络的基本方法:为了操作图形,我们需要处理节点及其连接,以及它们可能具有的任何属性。

空手道俱乐部网络

网络科学中最著名的例子之一是扎卡里空手道俱乐部,它常常用于说明社交网络分析和社区检测。该数据集是公有领域的,并且默认包含在 networkx 中。您可以按以下方式访问:

# Load the  Karate Club
G = nx.karate_club_graph()

# Draw the graph
plt.figure(figsize=(8, 8))
pos = nx.spring_layout(G)  # Layout for nodes -> treats nodes as repelling objects
nx.draw(G, pos, with_labels=True, node_color='skyblue', font_size=12, font_weight='bold', node_size=500)
plt.title("Zachary's Karate Club Network")
plt.show()

图 2:扎卡里空手道俱乐部网络。图片来自作者。

该网络代表了空手道俱乐部 34 名成员之间的友谊,它因两个派别之间的分裂而闻名,每个派别都有一个中心人物——Mr. HiOfficer

让我们来看看节点数据中包含的属性:

# looping over nodes
for node in G.nodes():
    print(f"Node: {node}, Node Attributes: {G.nodes[node]}")
Node: 0, Node Attributes: {'club': 'Mr. Hi'}
Node: 1, Node Attributes: {'club': 'Mr. Hi'}
Node: 2, Node Attributes: {'club': 'Mr. Hi'}
Node: 3, Node Attributes: {'club': 'Mr. Hi'}
Node: 4, Node Attributes: {'club': 'Mr. Hi'}
Node: 5, Node Attributes: {'club': 'Mr. Hi'}
Node: 6, Node Attributes: {'club': 'Mr. Hi'}
Node: 7, Node Attributes: {'club': 'Mr. Hi'}
Node: 8, Node Attributes: {'club': 'Mr. Hi'}
Node: 9, Node Attributes: {'club': 'Officer'}
Node: 10, Node Attributes: {'club': 'Mr. Hi'}
Node: 11, Node Attributes: {'club': 'Mr. Hi'}
Node: 12, Node Attributes: {'club': 'Mr. Hi'}
Node: 13, Node Attributes: {'club': 'Mr. Hi'}
Node: 14, Node Attributes: {'club': 'Officer'}
Node: 15, Node Attributes: {'club': 'Officer'}
Node: 16, Node Attributes: {'club': 'Mr. Hi'}
Node: 17, Node Attributes: {'club': 'Mr. Hi'}
Node: 18, Node Attributes: {'club': 'Officer'}
Node: 19, Node Attributes: {'club': 'Mr. Hi'}
Node: 20, Node Attributes: {'club': 'Officer'}
Node: 21, Node Attributes: {'club': 'Mr. Hi'}
Node: 22, Node Attributes: {'club': 'Officer'}
Node: 23, Node Attributes: {'club': 'Officer'}
Node: 24, Node Attributes: {'club': 'Officer'}
Node: 25, Node Attributes: {'club': 'Officer'}
Node: 26, Node Attributes: {'club': 'Officer'}
Node: 27, Node Attributes: {'club': 'Officer'}
Node: 28, Node Attributes: {'club': 'Officer'}
Node: 29, Node Attributes: {'club': 'Officer'}
Node: 30, Node Attributes: {'club': 'Officer'}
Node: 31, Node Attributes: {'club': 'Officer'}
Node: 32, Node Attributes: {'club': 'Officer'}
Node: 33, Node Attributes: {'club': 'Officer'}

节点属性club指的是每个节点所属的社区,例如“Officer”或“Mr. Hi”。我们将使用这些信息为图中的节点着色。

为此,我们将蓝色分配给标签为“Mr. Hi”的节点,将红色分配给标签为“Officer”的节点,并将它们存储在一个color_map列表中,然后我们可以使用nx.draw来可视化网络。

# Load the Karate Club 
G: nx.Graph = nx.karate_club_graph()

# Get the node labels
labels = nx.get_node_attributes(G, 'club')

# Map community labels to colors
color_map = []
for node in G.nodes():
    if labels[node] == 'Mr. Hi':
        # Assign blue color for 'Mr. Hi'
        color_map.append('blue')  
    else:
        # Assign red color for 'Officer'
        color_map.append('red')  

# Visualize the graph
plt.figure(figsize=(8, 8))
pos = nx.spring_layout(G)  

nx.draw(G, pos, with_labels=True, node_color=color_map, font_size=12, font_weight='bold', node_size=500, cmap=plt.cm.rainbow)
plt.title("Zachary's Karate Club Network with Ground Truth Communities")
plt.show()

图 3:空手道俱乐部网络中的“Mr. Hi”与“Officer”社区。图片来自作者。

图例说明,俱乐部的教练“Mr. Hi”与俱乐部的管理员“Officer”之间发生了冲突。这场分裂最终导致俱乐部分成了两个不同的团体,每个团体的中心人物分别是这两位领导者。

通过将这些关系表示为网络,我们可以直观地捕捉到这种分裂,并揭示数据中的模式和聚类——这些洞察在传统的表格格式中可能难以看到。

中心性

为了理解网络的结构和动态,识别最具影响力或战略性节点是至关重要的。这就是中心性度量的作用,它是网络科学中的一个关键概念。

它通过节点连接类型来衡量节点的位置,根据某些标准识别关键节点。常见的度量包括:

  • 度数中心性(仅指每个节点拥有的连接数)

  • 紧密中心性(一个节点能够多快地访问网络中的所有其他节点)。

  • 以及介数中心性(一个节点在其他节点之间的最短路径中出现的频率)

这些度量有助于揭示网络中的关键节点或瓶颈,从而提供对网络结构和动态的洞察。

import networkx as nx
import matplotlib.pyplot as plt

# Load the Karate Club 
G = nx.karate_club_graph()

# Compute centrality measures
degree_centrality = nx.degree_centrality(G)
betweenness_centrality = nx.betweenness_centrality(G)
closeness_centrality = nx.closeness_centrality(G)

# top 5 nodes by centrality for each measure
top_degree_nodes = sorted(degree_centrality, key=degree_centrality.get, reverse=True)[:5]
top_betweenness_nodes = sorted(betweenness_centrality, key=betweenness_centrality.get, reverse=True)[:5]
top_closeness_nodes = sorted(closeness_centrality, key=closeness_centrality.get, reverse=True)[:5]

# top 5 nodes for each centrality measure
print("Top 5 nodes by Degree Centrality:", top_degree_nodes)
print("Top 5 nodes by Betweenness Centrality:", top_betweenness_nodes)
print("Top 5 nodes by Closeness Centrality:", top_closeness_nodes)

# top 5 nodes for Degree Centrality
plt.figure(figsize=(8, 8))
pos = nx.spring_layout(G)  # Positioning of nodes
node_color = ['red' if node in top_degree_nodes else 'skyblue' for node in G.nodes()]

# draw top 5 nodes by degree centrality
nx.draw(G, pos, with_labels=True, node_color=node_color, font_size=15, font_weight='bold', node_size=500)
plt.title("Karate Club Network with Top 5 Degree Central Nodes")
plt.show()
Top 5 nodes by Degree Centrality: [33, 0, 32, 2, 1]
Top 5 nodes by Betweenness Centrality: [0, 33, 32, 2, 31]
Top 5 nodes by Closeness Centrality: [0, 2, 33, 31, 8]

图 4:空手道俱乐部网络中中心性最高的节点。图片来自作者。

对于节点033,我们看到这些节点在网络中最为中心,具有较高的度数、介数和紧密中心性。

它们在网络中的核心作用表明它们是高度连接的枢纽,通常充当其他成员之间的桥梁,并能够快速触及网络中的其他节点。这一定位使它们成为关键角色,在网络的流动性和结构中具有重要意义。

使用 Girvan-Newman 算法进行社区检测

一个社区 C 是一个节点集合(例如,社交网络中的个体、通过超链接连接的网页等),这些节点之间的连接比与网络中其他部分的连接更强。

在考虑中心性的可视化表示时,让我们将 Girvan-Newman 算法应用于此图。

  • 该算法通过逐步移除具有最高介数中心性的边缘,生成一系列社区划分。

  • 当算法第一次运行时,它识别出最重要的社区划分。

from networkx.algorithms.community import girvan_newman

# Load the Karate Club graph
G = nx.karate_club_graph()

# Apply Girvan-Newman community detection
comp = girvan_newman(G)
first_level_communities = next(comp)

# Visualize the first level of communities
pos = nx.spring_layout(G)
plt.figure(figsize=(8, 8))

# Color nodes by their community
node_colors = ['skyblue' if node in first_level_communities[0] else 'orange' for node in G.nodes()]
nx.draw(G, pos, with_labels=True, node_color=node_colors, font_size=12, node_size=500)

plt.title("Karate Club Network with Girvan-Newman Communities")
plt.show()

print("Detected Communities:", first_level_communities)

图 5:Girvan-Newman 算法在空手道俱乐部网络中的第一次划分。图像来自作者。

  • 由于girvan_newman(G)返回一个迭代器comp,调用next(comp)可以让你检索到第一个划分,即将网络分成两个社区的第一次划分。

让我们将检测到的社区与实际的节点标签club进行比较。

 print("Detected Communities:", first_level_communities)
# Print the actual communities (ground truth)
print("\nActual Communities (Ground Truth):")
mr_hi_nodes = [node for node, label in labels.items() if label == 'Mr. Hi']
officer_nodes = [node for node, label in labels.items() if label == 'Officer']

print(f"Mr. Hi's Community: {mr_hi_nodes}")
print(f"Officer's Community: {officer_nodes}")
Detected Communities: (
{0, 1, 3, 4, 5, 6, 7, 10, 11, 12, 13, 16, 17, 19, 21}, 
{2, 8, 9, 14, 15, 18, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33}
)

Actual Communities (Ground Truth):
Mr. Hi's Community: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 16, 17, 19, 21]
Officer's Community: [9, 14, 15, 18, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33]

Girvan-Newman 算法检测到的社区与实际的 Hi 先生和警官社区相似,但并不完全匹配。这是因为 Girvan-Newman 算法仅基于边缘的介数中心性来划分网络,而不依赖任何预定义的社区标签。

这种方法在没有标签的非结构化数据集中尤其有用,因为它揭示了基于网络结构特性的有意义的分组。这突出了社区检测中的一个关键问题:什么构成一个社区并没有严格的定义。

结果是,没有单一的“正确”方式来划分一个网络。不同的方法,由不同的度量驱动,可以得出不同的结果,每种结果根据具体的上下文提供有价值的见解。

来源

团体

在网络中,一个有用的概念是团体。在网络科学中,团体是指图中一个节点子集,其中每个节点都与该子集中的其他每个节点相连接。这意味着团体中的所有成员之间都有直接的关系,形成了一个紧密连接的群体。团体在研究复杂网络的结构时特别有用,因为它们通常代表了更大系统中高度连接或凝聚的群体。

例如在:

  • 在社交网络中:团体可以代表彼此认识的群体,如紧密联系的朋友圈或职业同事。

  • 在协作网络中:在一个协作网络中(例如,研究合作),团体可以揭示那些在相同主题或项目上共同合作的研究团队。

  • 在生物网络中:在生物网络中,团体可以表示在生物过程中紧密相互作用的功能性蛋白质或基因群体。

让我们找到空手道网络中的最大团体。我们将找到具有所有互联链接的最大人群。

import networkx as nx
import matplotlib.pyplot as plt

# Load the Karate Club graph
G = nx.karate_club_graph()

# Find all cliques in the Karate Club network
cliques = list(nx.find_cliques(G))

# Find the largest clique (the one with the most nodes)
largest_clique = max(cliques, key=len)

# Print the largest clique
print("Largest Clique:", largest_clique)

# Visualize the graph with the largest clique highlighted
plt.figure(figsize=(8, 8))
pos = nx.spring_layout(G)  # Layout for node positions
nx.draw(G, pos, with_labels=True, node_color='skyblue', font_size=12, node_size=500)

# Highlight the nodes in the largest clique
nx.draw_networkx_nodes(G, pos, nodelist=largest_clique, node_color='orange', node_size=500)

plt.title("Karate Club Network with Largest Clique Highlighted")
plt.show()

图 6:空手道俱乐部网络中的最大团体,节点 0、1、2、3 和 13 相互连接。图片由作者提供。

尽管在网络科学中定义“社区”存在挑战,但团体提供了一个具体且明确的概念,用于识别那些完全互联的群体,从而为结构化和非结构化网络提供有意义的见解。

最短路径

网络科学中的另一个有趣概念是最短路径。在图中的两个节点之间的最短路径指的是连接这两个节点的边的序列,同时最小化总距离或成本,这可以根据应用的不同有多种解释。这个概念在路由算法、网络设计、运输规划甚至社交网络分析等领域发挥着至关重要的作用。

NetworkX 提供了几种算法来计算最短路径,例如用于加权图的 Dijkstra 算法和用于无权图的广度优先搜索(BFS)。

图片来源:Ed 259Unsplash

让我们来看一个例子,我们将创建一个合成数据集,其中节点代表车站,边表示车站之间的连接。

  • 我们还将添加加权边时间,表示从一个站点到下一个站点所需的时间。
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt

# Simulate loading a CSV file (real example would load an actual CSV file)
# Define a more extensive set of stations and travel times between them
data = {
    'station_id': ['A', 'A', 'B', 'B', 'C', 'C', 'D', 'D', 'E', 'E', 'F', 'F', 'G', 'G', 'H'],
    'connected_station': ['B', 'C', 'A', 'C', 'A', 'D', 'C', 'E', 'B', 'F', 'D', 'G', 'E', 'H', 'F'],
    'time': [10, 20, 10, 15, 20, 10, 5, 15, 10, 25, 10, 5, 15, 10, 30]  # Travel times in minutes
}

# Create a DataFrame
df = pd.DataFrame(data)

# Create a graph from the DataFrame
G = nx.Graph()

# Add edges to the graph (station connections with weights as travel times)
for index, row in df.iterrows():
    G.add_edge(row['station_id'], row['connected_station'], weight=row['time'])

# Draw the graph
plt.figure(figsize=(8, 8))
pos = nx.spring_layout(G)  # Layout for node positions
nx.draw(G, pos, with_labels=True, node_size=500, node_color='skyblue', font_size=12, font_weight='bold')

# Draw edge weights (travel times)
edge_labels = nx.get_edge_attributes(G, 'weight')
nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)

plt.title("Expanded Transportation Network with Travel Times")
plt.show()

图 7:示例运输网络,其中节点代表车站,边代表时间或长度。图片由作者提供。

在这个例子中,我们使用 Dijkstra 算法来计算从车站 A 到车站 H 的最短路径,其中边的权重代表旅行时间。最短路径及其总旅行时间会被打印出来,路径在图中以红色突出显示,边的权重显示以指示车站之间的旅行时间。

# Compute the shortest path using Dijkstra's algorithm (considering the travel time as weight)
source = 'A'
target = 'H'

shortest_path = nx.shortest_path(G, source=source, target=target, weight='weight')
path_length = nx.shortest_path_length(G, source=source, target=target, weight='weight')

# Print the shortest path and its length
print(f"Shortest path from {source} to {target}: {shortest_path}")
print(f"Total travel time from {source} to {target}: {path_length} minutes")

# Visualize the shortest path on the graph
plt.figure(figsize=(8, 8))
nx.draw(G, pos, with_labels=True, node_size=500, node_color='skyblue', font_size=12, font_weight='bold')

# Highlight the shortest path in red
edges_in_path = [(shortest_path[i], shortest_path[i + 1]) for i in range(len(shortest_path) - 1)]
nx.draw_networkx_edges(G, pos, edgelist=edges_in_path, edge_color='red', width=2)

# Draw edge weights (travel times)
nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)

plt.title(f"Shortest Path from {source} to {target} with Travel Time {path_length} minutes")
plt.show()
Shortest path from A to H: ['A', 'B', 'E', 'G', 'H']
Total travel time from A to H: 45 minutes

图 8:给定图中输入节点 A 和 H 之间的最短路径—45 分钟。图片由作者提供。

该算法计算了最短路径及其总旅行时间,然后将其显示出来。从 A 到 H 的最短路径在图中以红色高亮显示,边的权重表示每个连接站点之间的时间,总和为 45。

尽管这只是一个简单的计算,但最短路径算法有着广泛的应用。在交通运输中,它们优化路线并减少旅行时间;在数字通信中,它们有效地路由数据。它们在物流中至关重要,用于降低成本,在供应链中确保及时交货,在社交网络中衡量个体之间的亲密度。理解最短路径使得各领域可以基于数据做出决策——从城市规划到网络基础设施——使其成为高效导航复杂系统的关键工具。

感谢阅读

我们已经使用 NetworkX 探索了网络科学中的几个基本概念,如最短路径算法、社区检测以及图论在建模和分析复杂系统中的强大能力。

如果你想继续学习,下面我放了一些链接供你参考 😃 如果你想深入了解社区检测算法,可以查看 CDLib 库。

  1. Networkx 教程

  2. CDLib, 一个用于社区检测的库

注意:在图上计算高级指标和度量常常可能是模糊的,甚至会导致误导。由于有这么多可用的潜在指标,很容易生成一些没有实际意义或可能会误表示网络真实结构的数字。选择正确的指标需要谨慎考虑,因为并非所有度量都能为每种类型的网络分析提供相关见解。如果你对此有所共鸣,可以点击这里查看更多信息:统计推断将数据与理论联系在一起,应用于网络科学

参考文献

导航慢变维度(SCD)与数据重述:全面指南

原文:towardsdatascience.com/navigating-slowly-changing-dimensions-scd-and-data-reinstatement-a-comprehensive-guide-f8b72ff90d98?source=collection_archive---------10-----------------------#2024-02-05

在企业数据仓库中高效管理维度变化和数据重述的策略

Kirsten Jiayi PanTowards Data Science Kirsten Jiayi Pan

·发表于 Towards Data Science ·6 分钟阅读·2024 年 2 月 5 日

--

想象一下,你是一名数据工程师,为一家大型零售公司工作,该公司在数据仓库中使用增量加载技术。该技术涉及选择性地仅更新或加载自上次更新以来新增或修改的数据。当产品研发部门决定更改当前产品的名称或描述时,可能会发生什么情况?这些更新将如何影响你现有的数据管道和数据仓库?你打算如何解决这些挑战?本文提供了一份全面指南,结合使用慢变维度(SCD),来解决数据重述过程中可能遇到的问题。

图片来源:unsplash.com/photos/macbook-pro-with-images-of-computer-language-codes-fPkvU7RDmCo

什么是慢变维度(SCD)?

慢变维度指的是维度值的偶发性变化,这些变化并不依赖于每日或定期的时间表,因为维度变化的频率通常低于系统中事务条目的变化频率。例如,一家珠宝公司,其顾客在其网站上提交新的订单,将在订单事实表中生成一条新的记录。另一方面,珠宝公司很少更改其产品名称和描述,但这并不意味着未来不会发生这种情况。

管理这些维度的变化需要采用慢变维度(SCD)管理技术,这些技术被分为定义的 SCD 类型,从类型 0 到类型 6,包括一些组合型或混合型类型。我们可以使用以下方法之一:

SCD 类型 0:忽略

维度值的变化被完全忽视,维度的值从初次在数据仓库中创建时起保持不变。

SCD 类型 1:覆盖/替换

当维度属性的前一个值不再相关或不重要时,适用这种方法。然而,不需要进行变化的历史跟踪。

SCD 类型 2:创建新的维度行

这种方法被推荐作为解决维度值变化的主要技术,涉及为维度创建第二行,包含开始日期、结束日期,并可能有一个“当前/过期”标志。它适用于我们像产品描述或地址变更等场景,确保历史数据的清晰划分。新的维度行会与新插入的事实行相关联,每个维度记录会根据插入时间与一部分事实行相关联——变更之前的行与旧的维度行相关联,变更之后的行与新的维度行相关联。

图 1(作者提供的图片):PRODUCT_KEY = “cd3004” 是对 PRODUCT_KEY = “cd3002” 的重新表述。

SCD 类型 3:创建“PREV”列

当旧值和新值都相关时,此方法适用,用户可能希望使用任一值进行历史分析。然而,将这种技术应用于所有维度属性并不实际,因为这将涉及为每个维度属性提供两列,或者如果需要保留多个“PREV”值,则需要更多列。应在合适的地方选择性使用。

图 2(作者提供的图片):PRODUCT_KEY = “cd3002” 被重新表述为新的 PRODUCT_NAME,旧的 PRODUCT_NAME 存储在 NAME_PREV 列中。

SCD 类型 4:快速变化的大型维度

如果在某个场景中,你需要捕获零售大型维度中每个维度属性的每次变化,比如一个拥有百万以上客户的巨大珠宝公司,怎么办?如上所述使用类型 2 会非常迅速地将客户维度表的行数膨胀到几十亿甚至上百亿行,而使用类型 3 是不可行的。

对于快速变化且数据量较大的维度表,一个更有效的解决方案是对属性进行分类(例如客户年龄类别、性别、购买力、生日等),并将它们分离到二级维度中,如客户档案维度。这个表作为“全覆盖”维度表,将每个类别的维度属性的所有潜在值预加载到表中,这样可以更好地管理变更的粒度,同时避免主客户维度中行的过度扩展。

例如,如果我们有 8 个年龄类别、3 种性别、6 个购买力类别和 366 种可能的生日。我们的“全覆盖”客户档案维度表包含上述所有组合,将是 8 x 3 x 6 x 366 组合,或 52704 行

我们需要为这个维度表生成 surrogate_key,并在事实表中建立与新外键的连接。当其中一个维度类别发生修改时,无需向客户维度中添加新行。相反,我们生成一行新的事实数据,并将其与客户维度和新的客户档案维度关联起来。

图 3(作者提供的图片):“全覆盖维度”表的实体关系图

SCD 类型 5:类型 4 的扩展

为了增强前面提到的类型 4 方法,我们可以在客户维度和客户档案维度之间建立连接。这种连接能够追踪特定客户的“当前”客户档案。该关键连接将客户与最新的客户档案相连接,允许从客户维度无缝地跳转到最新的客户档案维度,而无需通过事实表进行链接。

图 4(作者提供的图片):实体关系图显示了客户维度到客户档案维度的链接

SCD 类型 6:混合技术

通过这种方法,您可以将类型 2(新行)和类型 3(“PREV”列)结合在一起。这种混合方法融合了两种方法的优点。您可以通过“PREV”列来检索事实,它提供了历史值,并展示与特定时间点的产品类别相关的事实。同时,通过“new”列进行查询,可以获得当前和所有之前的产品类别的所有事实。

图 5(作者提供的图片):PRODUCT_ID = “cd3004” 是 PRODUCT_ID = “cd3002” 的重述,PRODUCT_ID = “cd3001” 在 LAST_ACTION 列中标记为“EXPIRED”

奖励与总结

通常,数据提取采用星型模式(STAR schema),其中包括一个事实表和多个维度表。在企业中,维度表存储所有描述性数据和主键,而事实表包含引用每个维度的主键的数字和加法数据。

图 6(作者提供的图片):星型模式的示意图

然而,如果您的营销销售数据提取是作为一个单一的去规范化表提供的,且没有独立的维度表,并且缺乏描述性数据的主键,那么将来更新产品名称可能会带来挑战。在现有的管道中处理这种情况会更为复杂。

描述性数据中缺少主键可能会在数据重述时引发问题,尤其是在处理大规模数据集时。例如,如果在重述提取中更新了产品名称,但没有唯一的product_key,增量加载管道可能会将其视为新产品,从而影响消费层的历史数据。为了解决这个问题,需要为产品维度创建surrogate_key,并建立一个映射表,将原始产品名称和重述后的产品名称进行关联,从而维护数据完整性。

总之,数据仓库设计的每个方面都应仔细考虑,尤其是要考虑潜在的边缘情况。

探索软演员评论家强化学习

原文:towardsdatascience.com/navigating-soft-actor-critic-reinforcement-learning-8e1a7406ce48?source=collection_archive---------11-----------------------#2024-12-12

在生物工程背景下理解 SAC 强化学习的理论与实现

Mohammed AbuSadehTowards Data Science Mohammed AbuSadeh

·发布于 Towards Data Science ·阅读时长 10 分钟·2024 年 12 月 12 日

--

由作者使用 ChatGPT-4o 生成的图像

介绍

强化学习(RL)研究领域在过去几年里取得了巨大进展。深度强化学习方法的使用,如近端策略优化(PPO)(Schulman,2017)和深度确定性策略梯度(DDPG)(Lillicrap,2015),使得智能体能够在高维环境中解决任务。然而,许多这些无模型的强化学习算法在训练过程中存在稳定性问题。这些挑战源于脆弱的收敛性、高方差的梯度估计、非常高的样本复杂度以及在连续动作空间中对超参数的敏感性。鉴于这些问题,有必要考虑一种新开发的强化学习算法,避免这些问题并扩展其在复杂现实世界问题中的适用性。这个新算法就是软演员评论家(SAC)深度强化学习网络。(Haarnoja,2018)

软演员评论家网络的模型架构。图像来源于(Du, 2023)

SAC 是一种离线策略(off-policy)Actor-Critic 深度强化学习算法,旨在解决其前辈算法的稳定性和效率问题。SAC 算法基于最大熵强化学习框架,目标是让网络的演员部分最大化预期奖励,同时最大化熵。它结合了离线策略更新和更稳定的随机 Actor-Critic 方法的表述。与 PPO 等在线策略方法不同,离线策略算法通过经验回放实现更快的学习和更好的样本效率,后者需要为每一步梯度更新收集新的样本。对于像 PPO 这样的在线策略方法,在学习过程中的每一步梯度更新都需要收集新的样本。使用随机策略和最大化熵的目的是通过鼓励行为中的更多随机性,促进算法的鲁棒性和探索性。此外,与 PPO 和 DDPG 不同,SAC 使用了双 Q 网络、独立的演员网络和熵调整,以提高将离线学习与高维非线性函数逼近相结合时的稳定性和收敛性。

离线策略强化学习方法对改善患者生活的生物工程系统产生了广泛的影响。更具体地说,强化学习已被应用于诸如机器人手臂控制、药物递送方法以及最著名的去 novo 药物设计等领域。(Svensson, 2024)Svensson 等人使用了多种在线和离线框架及不同类型的回放缓冲区,以学习基于 RNN 的分子生成策略,该策略能有效作用于 DRD2(一种多巴胺受体)。该论文指出,使用经验回放处理高、中、低评分分子在提高结构多样性和生成活性分子的数量方面产生了效果。回放缓冲区提高了训练代理的样本效率。他们还报告指出,使用离线策略方法,特别是 SAC,有助于通过防止模式坍塌来促进结构多样性。

理论解释

SAC 通过引入带有熵项的目标函数,使用“软”价值函数,Η(π(a|s))。因此,网络试图最大化终身奖励的预期回报和策略的熵。策略的熵被定义为一个随机变量的不可预测性,随着可能值范围的增大而增加。因此,新的熵正则化目标变为:

熵正则化目标

α 是一个温度参数,用于平衡探索与利用之间的关系。

在软值函数的实现中,我们的目标是最大化熵,因为算法会将相似 Q 值的动作分配相等的概率。最大化熵还有助于防止智能体选择利用近似 Q 值中的不一致性的动作。最终我们可以理解 SAC 如何通过允许网络进行更多探索,而不是对一类动作赋予过高的概率,从而改善脆弱性。这部分内容灵感来自 Vaishak V.Kumar 在《软演员-评论家解密》一文中对熵最大化的解释。

SAC 论文的作者讨论了,由于状态值函数逼近软值,因此实际上并不需要为策略训练单独的函数逼近器,因为它们根据以下公式与状态值相关联。然而,训练三个独立的逼近器能提供更好的收敛性。

软状态值函数

三个函数逼近器网络的特征如下:

  • 策略网络(Actor): 随机策略输出一组从高斯分布中采样的动作。策略参数通过最小化 Kullback-Leibler 散度来学习,公式如下:

最小化 KL 散度

KL 散度比较的是两个概率分布之间的相对熵或差异。因此,在该公式中,我们试图最小化策略函数与经过 Z 函数规范化的指数化 Q 函数之间分布的差异。由于目标密度函数是可微分的 Q 函数,因此我们对策略应用了重新参数化技巧,以减少方差的估计。

重新参数化策略

ϵₜ 是一个从高斯分布中采样的向量,描述了噪声。

然后,策略目标被更新为以下表达式:

策略目标

策略目标通过以下梯度估计来优化:

策略梯度估计器

  • Q 网络(Critic): 包含两个 Q 值网络,用于估计状态-动作对的期望回报。我们通过使用以下软贝尔曼残差来最小化软 Q 函数的参数:

软 Q 函数目标

其中:

即时 Q 值

Soft Q-函数目标最小化网络 Q 值估计与即时 Q 值之间的平方差。即时 Q 值(Q hat)是通过当前状态-动作对的奖励与下一时间戳中目标值函数的折扣期望之和得到的。最终,目标通过以下给定的随机梯度估计进行优化:

随机梯度估计器

目标值网络(Critic): 一个独立的 Soft 值函数,有助于稳定训练过程。Soft 值函数逼近器通过以下方式最小化平方残差误差:

Soft 值函数目标

这个 Soft 值函数目标最小化值函数与 Q 值的期望加上策略函数π的熵之间的平方差。此目标的负对数部分描述了策略函数的熵。我们还知道,信息熵是通过负号来计算,从而输出正的熵值,因为概率值(介于 0 和 1 之间)的对数总是负数。类似地,目标通过以下表达式进行优化,使用无偏的梯度估计器:

无偏梯度估计器

代码实现

本文中实现的代码来自以下 GitHub 仓库(quantumiracle,2023 年):

[## GitHub - quantumiracle/Popular-RL-Algorithms: PyTorch 实现 Soft Actor-Critic (SAC)…

Soft Actor-Critic (SAC)、Twin Delayed DDPG (TD3)、Actor-Critic (AC/A2C)、Proximal Policy 的 PyTorch 实现…

github.com](https://github.com/quantumiracle/Popular-RL-Algorithms?source=post_page-----8e1a7406ce48--------------------------------)

pip install gymnasium torch

SAC 依赖于使用连续动作空间的环境,因此所提供的仿真大多使用了机器人臂‘Reacher’环境,以及 Gymnasium 包中的 Pendulum-v1 环境。

Pendulum 环境在另一个仓库中运行,该仓库实现了相同的算法,但使用了较少过时的库(MrSyee,2020 年):

[## GitHub - MrSyee/pg-is-all-you-need: Policy Gradient is all you need! 一个详细的 PG 方法教程…

Policy Gradient 就是你所需要的!一个详细的 PG 方法教程。- MrSyee/pg-is-all-you-need

github.com](https://github.com/MrSyee/pg-is-all-you-need?tab=readme-ov-file&source=post_page-----8e1a7406ce48--------------------------------)

在网络架构方面,如理论解释中所提到,有三个主要组件:

策略网络: 实现一个高斯演员网络,计算动作分布的均值和对数标准差。

class PolicyNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.mean = nn.Linear(hidden_dim, action_dim)
        self.log_std = nn.Linear(hidden_dim, action_dim)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        mean = self.mean(x)
        log_std = torch.clamp(self.log_std(x), -20, 2)  # Limit log_std to prevent instability
        return mean, log_std

软 Q 网络: 估计给定状态-动作对下的期望未来回报,对于定义的最优策略而言。

class SoftQNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim):
        super(SoftQNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.out(x)

价值网络: 估计状态值。

class ValueNetwork(nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super(ValueNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        return self.out(x)

以下代码片段展示了更新与 SAC 算法相关的不同变量的关键步骤。首先,它从重放缓冲区采样一个批次进行经验回放。然后,在计算梯度之前,先将梯度初始化为零,以确保不会累积来自前一批次的梯度。接着执行反向传播,并在训练过程中更新网络的权重。最后,更新 Q 网络的目标值和损失值。这些步骤适用于所有三种方法。

def update(batch_size, reward_scale, gamma=0.99, soft_tau=1e-2):
    # Sample a batch
    state, action, reward, next_state, done = replay_buffer.sample(batch_size)
    state, next_state, action, reward, done = map(lambda x: torch.FloatTensor(x).to(device), 
                                                  [state, next_state, action, reward, done])

    # Update Q-networks
    target_value = target_value_net(next_state)
    target_q = reward + (1 - done) * gamma * target_value
    q1_loss = F.mse_loss(soft_q_net1(state, action), target_q.detach())
    q2_loss = F.mse_loss(soft_q_net2(state, action), target_q.detach())

    soft_q_optimizer1.zero_grad()
    q1_loss.backward()
    soft_q_optimizer1.step()

    soft_q_optimizer2.zero_grad()
    q2_loss.backward()
    soft_q_optimizer2.step()

    # Update Value Network
    predicted_q = torch.min(soft_q_net1(state, action), soft_q_net2(state, action))
    value_loss = F.mse_loss(value_net(state), predicted_q - alpha * log_prob)
    value_optimizer.zero_grad()
    value_loss.backward()
    value_optimizer.step()

    # Update Policy Network
    new_action, log_prob, _, _, _ = policy_net.evaluate(state)
    policy_loss = (alpha * log_prob - predicted_q).mean()
    policy_optimizer.zero_grad()
    policy_loss.backward()
    policy_optimizer.step()

    # Soft Update Target Network
    for target_param, param in zip(target_value_net.parameters(), value_net.parameters()):
        target_param.data.copy_(soft_tau * param.data + (1 - soft_tau) * target_param.data)

最后,要运行 sac.py 文件中的代码,只需运行以下命令:

python sac.py --train
python sac.py --test

结果与可视化

训练“Reacher”机器人手臂,(由作者生成)

在两个环境中训练 SAC 智能体时,我注意到问题的行动空间会影响训练的效率和表现。事实上,当我在简单的摆动环境中训练智能体时,学习收敛得更快,且波动较小。然而,由于 Reacher 环境包含了更复杂的连续动作空间,算法虽然训练得相对不错,但奖励的巨大跳跃并没有像预期那样显著。Reacher 环境的训练回合数是摆动环境的 4 倍。

通过最大化奖励的学习表现(由作者生成)

下面的动作分布图显示了策略在训练过程中探索的多样化动作,直到它收敛到一个最优策略。像 SAC 这样的熵正则化算法的特点来自于探索性的增加。我们还可以注意到,峰值对应于具有高预期回报的动作值,这推动策略收敛到更具确定性的行为。

行动空间使用分布(由作者生成)

说到更具确定性的行为,我们观察到,熵在训练回合数上平均有所减少。然而,这种行为是可以预期的,因为我们希望最大化熵的唯一原因是鼓励更多的探索。更高的探索主要发生在训练过程的早期,目的是耗尽大多数可能的状态-动作对,从而获得更高的回报。

训练回合的熵评估(由作者生成)

结论

SAC 算法是一种离策略强化学习框架,通过引入新的熵项,平衡了开发和探索。SAC 算法的主要目标函数包括最大化期望回报和训练过程中的熵,这解决了传统框架中存在的许多问题。使用双 Q 网络和自动温度调节来应对高样本复杂性、脆弱的收敛性和复杂的超参数调整问题。SAC 在连续控制任务领域中已经证明非常有效。关于动作分布和熵的结果表明,算法在早期训练阶段偏向于探索和多样化的动作采样。随着智能体的训练,它趋向于一个更具体的策略,减少了熵并达到了最优动作。因此,它已被有效地应用于生物工程中的各种领域,如机器人控制、药物发现和药物传递。未来的实现应着重于将该框架扩展到更复杂的任务并降低其计算复杂性。

参考文献

Lillicrap, T.P., Hunt, J.J., Pritzel, A., Heess, N., Erez, T., Tassa, Y., Silver, D. 和 Wierstra, D. (2015)。使用深度强化学习进行连续控制。[在线] arXiv.org。可用链接:arxiv.org/abs/1509.02971.

Schulman, J., Wolski, F., Dhariwal, P., Radford, A. 和 Klimov, O. (2017)。邻近策略优化算法。[在线] arXiv.org。可用链接:arxiv.org/abs/1707.06347.

Haarnoja, T., Zhou, A., Abbeel, P. 和 Levine, S. (2018)。软演员-评论家:带有随机演员的离策略最大熵深度强化学习。arXiv:1801.01290 [cs, stat]。[在线] 可用链接:arxiv.org/abs/1801.01290.

Du, H., Li, Z., Niyato, D., Yu, R., Xiong, Z., Xuemin, Shen 和 Dong In Kim (2023)。在无线边缘网络中启用 AI 生成内容(AIGC)服务。doi: https://doi.org/10.48550/arxiv.2301.03220.

Svensson, H.G., Tyrchan, C., Engkvist, O. 和 Morteza Haghir Chehreghani (2024)。利用强化学习进行新药设计。机器学习, 113(7), 第 4811–4843 页。doi: https://doi.org/10.1007/s10994-024-06519-w.

quantumiracle (2019)。GitHub — quantumiracle/Popular-RL-Algorithms: PyTorch 实现的软演员-评论家(SAC),双延迟 DDPG(TD3),演员-评论家(AC/A2C),邻近策略优化(PPO),QT-Opt,PointNet.. [在线] GitHub。可用链接:github.com/quantumiracle/Popular-RL-Algorithms [访问时间:2024 年 12 月 12 日]。

MrSyee (2019). GitHub — MrSyee/pg-is-all-you-need: 策略梯度就是你所需要的一切!一步一步教你掌握著名的 PG 方法。 [在线] GitHub. 可用链接:github.com/MrSyee/pg-is-all-you-need?tab=readme-ov-file.

2024 年人工智能景观的探索:趋势、预测与可能性

原文:towardsdatascience.com/navigating-the-ai-landscape-of-2024-trends-predictions-and-possibilities-41e0ac83d68f?source=collection_archive---------0-----------------------#2024-01-02

2024 年以创新的承诺召唤我们——这是人工智能与科技汇聚的一年,将重新绘制可能性的地图。

Vincent KocTowards Data Science Vincent Koc

·发表于Towards Data Science ·21 分钟阅读·2024 年 1 月 2 日

--

引言

欢迎进入 2024 年人工智能与科技所带来的可能性之旅。在这里,每一项预测都可能是通向充满创新、变革和更重要的机会的窗口,这种机会类似于 1950 年代的工业革命。50 年代见证了数字计算的崛起,重塑了各行各业及社会规范。今天,人工智能扮演着类似的角色,推动着下一个工业革命的到来。

就像战后科技繁荣一样,我们正在经历 2024 年的变革,行业转型、对新技能的需求增加,以及伦理问题日益突出。这里表达的观点完全是我个人的,基于我的商业和学术经验。

目录 — 所有预测与趋势:

  1. 生成性人工智能从炒作走向聚光灯中心

  2. 期待看到下一代神经网络

  3. 每个平台都将推出其向量数据功能

  4. 争夺硬件和平台供应链的控制权

  5. 2024 年是人工智能可穿戴设备与 XR 的元年

  6. 人工智能代理将开始与其他人工智能代理沟通

  7. 生成性人工智能的模式将扩展

  8. 消费者和监管机构推动更民主化的人工智能

  9. 人工智能驱动的全新时代营销战略

  10. “垃圾进,垃圾出”的困境愈加严重

  11. 为特定用途构建的小型基础模型成为常态

  12. 人工智能代理市场的黎明

  13. AI 产品将超越 SaaS 模型

  14. BYO AI 运动推动了对安全数字身份的需求

  15. 结论与关键收获

1) 生成式 AI 从炒作走向舞台中心

凭借着重新定义技术生态系统的潜力,生成式 AI 从一个“即将破灭”的炒作周期气球转变为核心战略,并且越来越成为一个确定的趋势,就像飞艇正在迎来复兴一样。

对于大型企业来说,这是一场如何解决问题和创新的范式转变,他们从实验阶段进入了采用生成式 AI 的阶段。这一技术性修正类似于云技术早期采用所带来的变革浪潮,预示着它将对技术生态系统产生类似甚至更大的影响。

97%的企业主已经认为,像 ChatGPT 这样的生成式 AI 工具将对他们的业务产生积极影响(福布斯

因此,2024 年的景象承诺将是一个生成式 AI 不仅仅是一个流行词,而是推动技术进步和业务转型的关键动力的时代。在这个时代,AI 战略就是战略

2) 下一代神经网络开始显现

随着对驱动 LLM 的变换器架构的一般认知能力产生怀疑,以及对 AGI 的竞争激发了研究界对下一代技术的探索。我们已经看到了近期的进展和现有技术的发展:

  • 曼巴这样的 SSM 模型;一个在选择性状态空间下,在线性时间序列建模中表现出色的模型,是一种替代当前运行大语言模型的变换器架构。它代表了 AI 处理和理解序列的飞跃,序列是人类认知的一个基本方面。

  • 神经符号 AI将神经网络的学习能力与符号 AI 的精准性相结合。这种混合方法承诺能更细致和复杂地理解复杂问题,架起了人类式推理和机器效率之间的桥梁。

  • 最后是AI 对齐,与自我纠错模型。其目标是创建能够自我适应和纠错的模型,无需持续的人类干预,朝着生成式独立且负责任的方向迈进,同时减少大规模幻觉风险。

这些新的模型架构正在勾画出一个未来,其中 AI 的能力可能与人类认知相媲美,而当前生成式 AI 的局限性仍然存在。尽管研究还处于初期,但预计会有人在这一领域敢于冒险,开创新局面。

3) 每个数据平台推出他们的向量数据解决方案

随着生成型人工智能对向量数据库的依赖,我们将看到所有主要数据平台玩家推出他们的解决方案,旨在解决这一问题。虽然向量数据库并不是一个新概念,但它们传统上不属于“现代数据栈”,过去它们主要用于搜索引擎和其他类型的机器学习。

向量数据库是更复杂的生成型人工智能应用案例的关键要求,比如对话记忆、搜索文档(RAG),以及像图像索引这样的多模态解决方案。毫不奇怪,在 2023 年中期,我们看到随着多模态模型变得更加可访问,需求激增。

2023 年 3 月 OpenAI 发布 GPT-4 后,向量数据库的需求 — Google 趋势(全球)

随着 Databricks 最近发布了他们的向量数据解决方案,我们可能会预期像 Snowflake 这样的现代数据平台重要玩家在 2024 年跟随这一趋势。其他相关技术提供商也可能会加入这一趋势,提供各种功能和服务。几乎所有数据库技术将在 2024 年开始自称为“向量存储”。

随着这一采用转变,我们可能会看到资深数据和软件工程师的技能期望发生变化,他们会更加熟悉与向量索引、语义搜索解决方案、嵌入以及可能的 BM25 等算法相关的概念。

为了扩展,我们可能会看到一些最初被生成型人工智能社区采纳的现代向量搜索解决方案的收购或进一步融资轮次,例如ChromaWeaveatePineconeQdrant

开源向量数据库仓库的 Github“星标”。增长与 GPT3.5 发布相符 — 来源:Star-History & Github

在向量存储的开源世界中,Milvus 目前是大众最喜爱的选择,并凭借其托管服务树立了企业级选择的地位,但最近 Qdrant 在 2023 年取得了异常进展,几乎呈指数增长。

如果你有空闲时间,可以阅读 AI 工程师 Prashanth Rao 所写的四篇详细的向量数据库解析,深入了解向量数据库解决方案的能力和各个提供商。

4) 急于控制硬件和平台供应链

随着人工智能的主导地位,端到端更好地控制供应链变得愈发重要,因为它掌握着创新的钥匙。

例如,NVIDIA 作为芯片制造领域的领导者,已经在 AI 竞赛中取得了明显的胜利。大家心中的问题是,NVIDIA 是否会将视野拓展到云计算领域,利用其硬件专长提供集成的 AI 云服务。NVIDIA 已经拥有一项名为Geforce Now的云流媒体服务,提供按需高端图形处理。这一举动可能会重新定义竞争格局,使 NVIDIA 对 AI 发展轨迹的影响力更为直接。

Nvidia 股价,过去 5 年(NASD:NVDA)— 来源: Nvidia 投资者关系

在另一端,有像 Amazon 这样的云 AI 提供商和像 Anthropic 与 Mistral 这样的新兴 AI 公司。他们目前对外部硬件资源的依赖引出了一个关键问题:他们是否会跟随 OpenAI 的脚步,开始自购芯片,以及 Google 的TPUCoral AI? 这一战略可能意味着 AI 硬件自给自足和定制化的转变,可能会为顶级 AI 提供商带来更为量身定制、高效的 AI 解决方案。

Google Coral — 本地 AI 开发板与硬件 — 来源: Coral AI

随着 AI 底层硬件领域的所有企业争相控制自己的供应链,芯片制造商 NVIDIA 会转向云计算吗?云 AI 服务商如 Amazon 或像 Anthropic 这样的新兴公司会跟随 OpenAI,开始自给自足地生产芯片吗?移动芯片制造商如 Qualcomm 是否会成为赢家,推动新一代 AI 移动设备和可穿戴设备的普及?

向垂直整合 AI 的趋势突显了一个更大的故事:对 AI 硬件的控制正变得与控制技术未来的发展密不可分。我们将看到新兴公司争相进入为云服务提供的定制 AI 芯片领域,各国纷纷加速推动微处理器的发展,最终 OpenAI 也将在 2024 年初进入芯片领域。

5) 2024 年是 AI 可穿戴设备与 XR 的时代

今年,我们正见证一场深刻的革命——AI 和可穿戴设备与扩展现实(XR)设备的结合,改变了我们与技术的互动方式。这些设备不仅仅是小工具,它们是我们数字自我的延伸,无缝融入了我们的日常生活。

Humane AI Pin — 来源: Humane

Humane AI’s Pin(由高通 Snapdragon 处理器提供支持)和Tab正在重新定义可穿戴设备的格局。这些设备让我们一窥未来——在这个未来中,可穿戴设备不再仅仅是用于追踪健康数据或接收通知。它们旨在提升人与人之间的互动,提供实时的 AI 辅助,并为我们提供增强现实的体验。

Rewind AI 挂坠 — 来源:Rewind

Rewind,另一款突破性产品,正在革新我们捕捉和重温回忆的方式。想象一下,一款设备不仅能记录我们经历的瞬间,还能让我们重新回顾、理解并反思我们的记忆。

但不仅仅是新兴玩家在掀起波澜。像苹果和OpenAI 正迈入可穿戴设备领域,承诺带来他们巨大的创新能力。苹果凭借其在创造引领潮流设备方面的记录,极有可能推出与其生态系统无缝集成的可穿戴设备,可能会与其 XR 设备Apple Vision Pro一同推出,提供无与伦比的用户体验。

另一方面,OpenAI 可能会利用其在 AI 领域的专业知识,推出能够将先进 AI 功能带到我们指尖(或者手腕)的可穿戴设备。AI 可穿戴设备与生成型 AI 模型(如 ChatGPT)的互动潜力尤为令人兴奋,它将提供前所未有的便利与智能结合。

Meta 是另一个值得关注的重要玩家。他们进入 VR 可穿戴设备领域,已经向我们展示了沉浸式数字体验的潜力。随着他们在这个领域的不断创新,我们可以期待不仅能增强现实的可穿戴设备,还能创造出全新的领域,让我们去探索和互动。

在 2024 年,AI 可穿戴设备将不仅仅是技术的延伸。它们正在成为个人伴侣、数字助手以及通向新现实的门户。它们代表了人与技术之间更为亲密互动关系的转变,在这种关系中,我们的数字世界与物理世界无缝交织。今年,我们不只是在佩戴技术;我们在与之共生。

6) AI 代理将开始与其他 AI 代理进行沟通

想象一下一个生态系统,在这个生态系统中,具有独特专业知识和知识库的专业代理相互互动与协作。随着 AI“代理”的发展,我们将继续看到这个领域在 2024 年不断增长。

AI 代理市场图(2023 年 12 月) — 图片来源:Olivia Moore

我们将看到代理的出现,超越了支持个体需求,如写电子邮件、解决客户支持问题或订购杂货,进入一个代理开始相互互动的生态系统。这一范式的转变与数据产品的概念密切相关,企业将有机会像在其模型和数据集上获利一样从其代理中获利。这将促使一个由互联的智能代理组成的新生态系统的诞生。

特斯拉 Optimus Gen 2 人形机器人 — 2023 年 12 月

我们看到来自波士顿动力特斯拉等公司的机器人和人形机器人正面临解决这一问题的挑战,因为各种机器人需要共存并在本地进行通信,以决定如何执行任务。

拥有特定垂直领域数据优势的公司,如彭博社(金融)和 LexisNexis(法律),有望成为这一领域的潜在领跑者。彭博社凭借其在金融数据领域的优势,可能推出复杂的金融代理,并且已经开始开发自己的 LLM,而 LexisNexis 则可以利用其庞大的法律信息库来开发法律代理。这些代理凭借各自深厚的数据优势,不仅为直接用户服务,还将成为其他企业和系统不可或缺的资源,推动一个新的数字化劳动力生态系统。

预计我们将开始看到超越数字劳动力的新代理解决方案,包括代理协调、管理、监控以及数字机器人流程自动化领域的参与者,如UiPath,同时制造业和工厂作业领域的人形机器人也将通过其在大规模自动化和机器人系统方面的经验,逐步进入这一领域。

7) 生成 AI 模式将扩展

超越文本、代码、图像、视频和音频,进入更具沉浸感的模式和感官体验,如 3D、基因组学、嗅觉、味觉,并将以初步形式进入市场。

生成空间 AI 以生成新城市布局 — www.generativespatialai.com/

生成 AI 将突破当前文本、代码、图像、视频和音频的边界。我们将迎来更多沉浸式的模式,帮助科学研究如蛋白质结构和材料,或引入额外的感官体验,如 3D、嗅觉和味觉。这些新颖的模式预计将在初期形式中出现,预示着生成 AI 应用的下一波浪潮。

高斯溅射演示 — 将视频转化为 3D 虚拟现实

自主人工智能代理和多模态模型的崛起,加上可穿戴设备和扩展现实(XR)技术的进步,正在为消费者创造更加沉浸式和互动性的体验铺路。

想象一下你的梦想变成了一个虚拟现实世界,每个游戏都有一个根据你的视角定制的独特世界?

随着最近的 3D 建模技术高斯溅射)使得视频可以转换为 3D 虚拟现实,我们将看到随着生成技术的发展,这一技术将达到新的高度。

A-Lab Berkeley,机器人测试新材料——来源:Marilyn Sargent/Berkeley Lab

最大的影响将来自于材料科学和基因组学。GNoME 模型由 Google DeepMind 开发,已经被用来在材料科学领域取得突破,发现新的晶体结构,推动更好的电池和更高效的计算机。

科学等领域将是最有可能发生深刻研究进展的地方。

8) 消费者和监管机构推动更加民主化的人工智能

人们将继续推动人工智能的更大普及和包容性,但由于开发基础性人工智能模型的复杂性和成本,仍然存在挑战。这一对立局面为公众对人工智能透明度和伦理监督的需求增加提供了背景。

欧盟人工智能法案——提议的风险等级——来源:EY

对隐私和人工智能社会影响的关注正在推动消费者和监管机构,特别是在像欧盟这样的地区,在 GDPR 推动下成为现代数据隐私法律的催化剂,呼吁对人工智能进行更严格的治理。今年,我们预计在建立人工智能模型审计框架、标准化准确性和推出人工智能系统“报告卡”方面会取得进展,但仍然任重道远。

基础模型提供商是否符合欧盟人工智能法案草案?——来源:斯坦福 CRFN

人工智能中的风险管理将变得更加突出。企业将面临一个人工智能不仅仅是创新工具,同时也受到严格监管的局面。统一的框架和标准将会出现,引导企业负责任地采用人工智能,并确保人工智能的融入主流社会是安全的,且与公众福利相一致。

9) 人工智能赋能的营销策略新时代

传统上占据企业预算大头的营销领域,现在正经历着一场变革性的景观。催化剂是什么?像 ChatGPT 这样的基于聊天的工具崛起。这些创新有可能导致传统搜索量的显著下降,从根本上改变消费者获取信息的方式。

搜索引擎与 Chat GPT 的辩论与统计 — 来源:Twitter Greg Sterling

在这个不断发展的情境中,市场营销人员发现自己处于一个十字路口。在这些由 AI 驱动的对话中,影响或监控品牌提及的能力仍处于初期阶段。因此,越来越多的趋势是为生成型 AI 世界调整营销策略。这种调整在短期内依赖传统媒体,利用其广泛的影响力来建立和维持品牌存在感。

与此同时,我们也见证了技术领域的重大变化。从基于浏览器的工具到设备端应用的转变正在加速。这一变革的先驱者包括像 Microsoft Co-Pilot、Google Bard 以及预计在 2024 年推出的苹果自有大型语言模型(LLM)等创新。这一转变标志着从以网页为中心的互动向更加集成、设备驱动的 AI 体验的范式转变。

新款 Microsoft Surface X 预计将于 2024 年底发布 — 来源:Microsoft

这种变化不仅仅是为了便利,它代表了用户互动范式的根本改变。随着 AI 更加无缝地集成到设备中,在线和离线互动之间的界限变得越来越模糊。用户可能会在更加个人化、情境感知的环境中与 AI 互动,从而带来更有机和更具吸引力的用户体验。对于像 Google、Microsoft 和 Apple 这样已经深耕营销服务领域的科技巨头来说,这是重新定义其产品和服务的机会。

ChatGPT 不知道我是谁 — 来源:Vincent Koc 和 OpenAI ChatGPT

我们可以预见到新的“答案分析”平台和营销运营模式的出现,以支持答案引擎优化。这些工具可能会专注于理解和利用 AI 驱动互动的细微差别,但也可能更好地利用训练数据,了解结果如何呈现给特定品牌或产品。

数字营销人员将开始更加深入地思考他们在这些训练数据集中如何被索引,就像他们曾经在搜索引擎中所做的那样。

常见爬取域名的截图,常用数据集用于训练大语言模型 — commoncrawl.github.io/cc-crawl-statistics/plots/domains.html

此外,像 OpenAI 这样的平台可能推出广告赞助的搜索结果或媒体测量工具,这可能会为数字广告引入一个新维度。这一发展不仅为品牌推广提供了新的途径,还可能挑战现有的数字营销策略,促使人们重新评估指标和投资回报率的评估方法。

随着大语言模型(LLM)逐渐进入设备,并逐步远离传统的网页界面,营销格局即将发生重大变化。营销人员必须适应这些变化,既要利用传统媒体,又要借助新兴的 AI 技术,以便在这个新的数字时代有效地与受众互动。这种结合传统媒体影响力与 AI 驱动的精准分析的双重方法,可能会成为 2024 年迅速发展的营销环境中成功的关键。

10) “垃圾进,垃圾出”困境加剧

随着组织越来越倾向于利用生成性 AI 模型并开发自己的定制解决方案,焦点迅速转向输入数据的质量。数据管理圈中经典的“垃圾进,垃圾出”这一说法再次浮现,因为数据质量已经重新成为讨论的重点。

趋势主题:数据质量 — 来源:Google Trends 全球趋势

组织和领导者面临着一个严峻的现实,那就是高质量、准确标注的数据是有效部署 AI 的基石。这个问题不仅仅是数据的可用性,更关乎数据的相关性、准确性以及它所提供的背景。偏差问题和误导性的训练数据可能导致模型输出的灾难性后果。

在用于大语言模型的各种训练数据中,发现了多种偏见。— 来源:大语言模型中的偏见与公平性:调查 — arxiv

然而,挑战并不止于此。曾经看似足够的现有数据和 AI 管道技术,如今正被推向极限。它们常常在面对高级 AI 模型的细致需求时显得捉襟见肘。能力与需求之间的这一差距,要求数据处理工具和方法的进化。

2023 数据完整性趋势与洞察报告揭示了数据质量差作为各行业组织普遍面临的主题,70%的那些对数据缺乏信任的组织将数据质量视为做出自信决策的最大挑战。

此外,追求 AI 卓越也伴随着一系列后勤难题。这些复杂 AI 模型所需的强大计算能力意味着对 GPU 的需求非常大。

以“GPU 短缺”为题的表情包

但这不仅仅是拥有投资硬件的财力问题;更重要的是这些资源在市场中的可获得性。随着越来越多的玩家进入 AI 领域,GPU 的竞争加剧,可能导致 AI 开发和部署的瓶颈,这进一步增加了组织适应 AI 常态的复杂性。

随着 2024 年的到来,我们见证了对数据质量和基础设施提升的重新关注,塑造着 AI 发展的轨迹。

11) 定制的小型基础模型普及

2024 年很可能成为小型基础模型的时代。这些专用的定制 AI 模型将走到前台,在效率和精准度上超越它们的通用版本。

现在,组织可以选择使用现成的已训练的通用大型语言模型,如 OpenAI GPT、Google Bard、Anthropic Claude [图表 b 中的 RL 模型],或者进入构建自有模型的领域。

LLM 开发阶段,由 InstructGPT 论文开创,最终引领至 ChatGPT。此图改编自 Chip Huyen 的文章“RLHF: Reinforcement Learning from Human Feedback” — 来源:Argilla

提供细调的选项,即以基础语言模型为基础,但教它新的东西(就像你可以想象将一辆车调校得更快),或者深入创建你自己的基础(底层)模型。

Open AI GPT 定制企业模型 — 来源:OpenAI

例如,像 GPT 这类拥有 1000 亿参数的通用模型(或其模型编程中的“开关”重建的费用大约超过 400 万美元。OpenAI 最近也开始为企业提供“自建 GPT”的服务,定价从 200 万到 300 万美元不等。

那些已经开发自有基础模型或对现有模型进行细调以适应特定用例的组织,正处于成功的有利位置。这种方法与 80:20 规则相契合,重点放在较小、量身定制的模型上,这些模型针对特定需求,而非试图通过通用解决方案吸引大众。

组织的真正价值在于能够开发这些针对特定任务的定制模型。这些模型不仅提供更高的准确性和相关性,还能带来新的货币化机会。在一个日益受到专业需求驱动的世界中,这些模型成为了无价的资产,提供既高效又经济可行的解决方案。

此外,趋势正在从单纯依赖大型通用模型转变,因为这些模型并不适用于所有需求。许多组织已经在这些广泛的模型基础上构建了解决方案,充当“薄包装器”,提供有限的定制和扩展性。尽管这些解决方案可能曾是一个踏脚石,但它们不太可能提供开发专有模型所能带来的长期价值。

通用模型可用性和质量预测 — 来源:Sequioa Capital

在这一不断发展的领域中,获胜者将是那些投资开发自己模型(无论是通用模型还是小型基础模型)以填补通用模型领域空白的企业。这一策略不仅提高了准确性和有效性,还减少了成本开销。较小的模型不仅运行成本更低,而且更容易适应和管理。

这引出了“建模总成本”的概念,即开发和维护这些模型的好处超出了初期投资,从长远来看,提供了一种更可持续、更具成本效益的方法。

展望 2024 年,显然能够创建并利用小型基础模型将成为 AI 市场竞争中的关键区分因素。这一转变标志着向更加个性化、高效且经济合理的 AI 解决方案迈出了重要一步。

12) AI 市场的代理商崛起

AI 市场正在成为适应性强且响应迅速的平台,重塑我们对技术交易和互动的思考方式。

OpenAI GPT 商店,2024 年推出 — 来源:OpenAI

领军的 OpenAI 准备在 2024 年初揭开备受期待的“GPT 市场”的面纱,为 AI 领域树立新的标杆。这一举措预计将打开大门,Meta 等其他科技巨头也会迅速跟进。我们可能会看到包括亚马逊、苹果,甚至字节跳动等老牌巨头,以及新兴初创公司在这一领域纷纷投入。

但这一革命的涟漪效应超越了传统的科技公司。随着自主代理和可穿戴技术的增长,我们将再次迎来苹果 iPhone 和 App Store 的时刻。开发者将把这视为下一个黄金机会。

在这一变革阶段,AI 市场现象正在从 B2B 扩展到 B2C 领域。我们可能会看到一系列不同的参与者尝试进入这一领域,每个玩家都带来独特的价值主张。从面向消费者的 AI 应用到企业级解决方案,这些市场中的产品种类将满足各种需求和期望。

13) AI 产品将超越 SaaS 模式

随着 AI 市场和工具的普及,传统定价策略正在重新评估,为迎合 AI 服务独特性质的创新方法铺平道路。

基于使用量的公司 — 来源:Open Venture Partners

我们可能会看到从传统应用商店定价模式向更动态的基于消费的计费系统的显著转变。这些模式类似于电话线路等公用事业计费方式,预计将变得更加普及,符合租赁 AI 代理的概念。在这种设置中,客户将根据其 AI 使用的程度和性质进行支付,提供一种灵活且可能更加公平的定价结构。

但这种演变并未止步于此。AI 市场和工具的日益普及也为各种定价策略和新型商业模式铺平了道路。这一变化源于对各种 AI 应用和服务的需求,每个应用和服务都有其独特的价值主张和使用模式。这可能会试行开发者和其 AI 服务或代理在市场上的收入分成(如 App Store)或使用费(如 Spotify)模式。

另一个新兴的模式可能是基于表现的定价,其中费用与 AI 工具所交付的成果或结果挂钩。这种模式在那些 AI 影响可以量化的领域,特别具有吸引力,例如营销分析、财务预测甚至创意产业。

定价中的捆绑与拆分策略 — 来源:Matt Brown

此外,随着 AI 不断渗透各个行业,跨行业的合作伙伴关系可能会催生捆绑服务。这些捆绑可能将 AI 工具与传统软件服务结合,提供一个全面的套餐,以满足更广泛的商业需求。

这些新定价模式和策略的出现反映了一个快速适应 AI 所带来的独特挑战和机遇的市场。随着企业和消费者对 AI 能力的认知不断加深,对灵活、透明且与价值对齐的定价模型的需求可能会愈加强烈。

14) BYO AI 运动推动了对安全数字身份的需求

扩展现代和生成性人工智能工具将推动数字足迹的扩展,从而需要安全、可携带的数字身份,而挑战在于如何平衡强大的安全性与用户可访问性。用户将期望个性化的体验,在这种体验中,偏好、历史和语境将成为使用许多 AI 服务的关键。

银行和电子政府平台正逐步成为这些单一数字身份和个人偏好的潜在托管者。这种整合指向一个更加简化和安全的数字存在。但这不仅仅关乎安全,它还关乎我们数字自我在各个平台之间的无缝集成。

“自带人工智能”(BYO AI)正是与此密切相关。想象一下,你能无缝地将你的数字偏好、学习风格,甚至购物习惯从一次数字互动传递到下一次。这种可携带性不仅仅是方便,它具有变革性。它使得以前无法达到的个性化和效率成为可能。可穿戴设备也正在成为管理我们数字身份的重要工具。通过不断从我们的互动中学习,它们演变成个人数据中心,不仅理解我们的偏好,还能预见我们的需求。

人工智能的融入工作环境意味着我们的数字偏好可能会自动调整办公室应用程序、通讯工具,甚至物理工作空间的设置。想象一下,进入会议室时,照明、温度,甚至数字显示屏都会自动根据你的偏好进行调整。

然而,这种个性化和数据整合的水平引发了关于隐私和数据使用的问题。随着这些数字身份变得更加复杂,并与人工智能紧密交织,它们被用于广告和数字体验提供商的超个性化的潜力是巨大的。这可能导致一种新的语境广告和消费者参与的时代,其中促销不仅仅是针对特定群体,而是深度融入我们的数字人格中。

身份与人工智能的融合将重新定义我们与技术的互动方式,无论是在个人领域还是职业领域,最终会导致更加个性化、高效和互联的生活方式。在我们拥抱这一未来时,伦理考虑和隐私保障的重要性比以往任何时候都更加突出。

结论与关键要点

展望 2024 年,人工智能和技术重塑我们世界的潜力是不容忽视的。每一个预测都为我们展示了一个创新、责任和包容并行的未来。

关键要点:

  1. 生成性人工智能崭露头角,成为核心技术战略:从炒作到各个领域的主流应用的转变。

  2. 神经网络的进展接近通用人工智能(AGI):像 Mamba 和神经-符号 AI 这样的新架构显著增强了认知能力。

  3. AI 可穿戴设备和扩展现实(XR)获得关注:在日常生活中提供增强的人工体验和互动。

  4. 互联的 AI 代理创建了一个新的生态系统:专门的 AI 代理相互沟通,革命性地改变了行业动态。

  5. 在 BYO AI 时代,安全的数字身份变得至关重要:由银行和电子政府平台管理的便携式数字身份崛起。

让我们以开放的心态和热情拥抱这段旅程,准备好成为塑造未来的一部分,这不仅是正在发生的事情,而是我们可以共同创造的未来。加入对话,分享你的见解,让我们共同设想并建设 2024 年的世界。

喜欢这个故事吗?

Vincent Koc 是一位高成就、以商业为导向的技术专家和企业家,拥有丰富的经验,专注于数据驱动和数字化领域。目前,Vincent 在澳大利亚担任数据领袖,同时在美国教授人工智能课程。

免费订阅,在 Vincent 发布新故事时获取通知。或者通过 LinkedInX(前 Twitter)关注他。

[## 每当 Vincent Koc 发布文章时,你都会收到电子邮件通知。

每当 Vincent Koc 发布文章时,你都会收到电子邮件通知。通过注册,如果你还没有 Medium 账户,将会创建一个账户……

medium.com](https://medium.com/subscribe/@vkoc?source=post_page-----41e0ac83d68f--------------------------------)

除非另有注明,所有图片均为作者提供

探索未来

原文:towardsdatascience.com/navigating-the-future-62ea60f27046?source=collection_archive---------5-----------------------#2024-01-10

大型多模态模型时代的自主机器人技术

Nate CibikTowards Data Science Nate Cibik

·发表于 Towards Data Science ·阅读时间 34 分钟·2024 年 1 月 10 日

--

图片由作者使用 DALL-E 3 创建。

在我最近的工作中,我探索了轻量级的 层次化视觉变压器 在有效执行多个计算机视觉任务学习和推理的能力,这些任务对于机器人感知至关重要。在我的 Multiformer 项目中,我运用了这种“共享主干”概念,通过一个共同的主干向多个任务头提供特征,这已成为多任务学习中的一种流行方法,特别是在自主机器人领域,因为研究表明,学习一个对多个任务有用的特征空间不仅能产生一个可以在给定单一输入下执行多任务的单一模型,还能通过利用从其他任务中学到的互补知识,提高每个任务的表现。

传统上,自主车辆(AV)的感知系统通过在多个计算机视觉任务上进行同步推理,来形成对周围环境的理解。因此,具有共同主干的多任务学习是一个自然的选择,提供了参数效率和单个任务性能的双赢解决方案。然而,大型多模态模型(LMMs)的崛起挑战了这一高效的多任务范式。利用 LMMs 创建的世界模型具备深刻的能力,可以在描述性和预测性层面理解传感器数据,超越了任务特定处理,能够全面理解环境及其未来状态(尽管需要更高的参数量)。

在这一新范式中,已被称为AV2.0的模式下,像语义分割和深度估计这样的任务,成为了拥有更深理解数据能力的模型的涌现能力,对于这些模型来说,执行这些任务变得没有意义,除非是为了将这些知识传递给人类。事实上,执行这些中间任务的整个目的,是将这些预测发送到感知、规划和控制算法的更深层次,最终描述自我与周围环境的关系,以及应采取的正确行动。相比之下,如果一个更大的模型能够描述完整的驾驶场景,并能够在相同输入条件下描述正确的驾驶行为,那么就不再需要知识的有损中间表示,网络可以直接学习对数据做出响应。在这一框架中,感知、规划和控制之间的界限被消除,创造出一种可以端到端优化的统一架构。

虽然这仍然是一个新兴的思想流派,但使用大规模语言模型(LMMs)构建的生成世界模型来实现端到端自主驾驶解决方案,已成为一个合理的长期赢家。它延续了通过序列建模公式简化以往复杂解决方案的趋势,这一趋势始于自然语言处理(NLP),迅速扩展到计算机视觉,现在似乎已经在强化学习(RL)中扎根。更进一步的是,这些曾经独立的研究领域正在一个共同的框架下统一,并因此相互加速发展。对于自动驾驶(AV)研究,接受这一范式转变意味着要迎接基础设施和方法论在大规模变换器模型的训练、微调和部署方面的快速加速,而来自多个学科的研究人员正继续加入并为这一“智能是序列建模问题”的现象注入动力。

但这对传统模块化自动驾驶堆栈意味着什么呢?像 Multiformer 这样的多任务计算机视觉模型是否注定会过时?似乎很明显,对于一些简单问题,比如需要在已知类别集合上进行基本图像分类的应用,大型模型显然是过度的。然而,对于像自主机器人这样的复杂应用,答案目前远不那么明确。大型模型存在严重的缺点,特别是在内存需求和资源消耗方面。它们不仅训练所需的财务(和环境)成本巨大,而且部署可能性也受到限制:模型越大,嵌入系统(机器人)就必须越大。因此,大型模型的开发有一个实际的进入壁垒,这势必会阻止小型公司采纳。不过,大型模型的能力所带来的吸引力已经在全球范围内推动了训练和部署方法的可访问性发展,而这一趋势注定会继续下去。

2019 年,Rich Sutton 在 AI 研究中提到“苦涩的教训”,指出无论是自然语言处理还是计算机视觉等领域,基于人类知识的复杂方法,融入手工设计元素,最终都会成为浪费时间的死胡同,这些方法会被更多依赖原始计算的大多数通用方法所取代。目前,大型变换器的出现以及各种问题巧妙地转化为自监督序列建模任务,是消除割裂和定制问题框架的主要推动力。现在,长期以来在强化学习(RL)和时间序列分析中的方法,包括像循环神经网络(RNN)这样的经典模型,都必须证明它们的有效性,否则就将像 SIFT 和基于规则的语言模型一样,进入退役的行列。在自动驾驶堆栈(AV stack)开发方面,我们是否应该尽早打破传统的束缚,转向大型世界建模,还是传统模块化驾驶堆栈的可访问性和可解释性能够经受住大型模型浪潮的冲击?

本文讲述了一个引人入胜的研究趋势交汇故事,这些趋势将指导我们朝着这个问题的有教育意义的答案迈进。首先,我们回顾传统的模块化自动驾驶(AV)堆栈开发,以及多任务学习如何通过在共享参数空间中利用泛化知识来提升性能。接下来,我们探索大语言模型(LLMs)的飞速崛起及其向多模态的扩展,展示它们在机器人学中的影响潜力。然后,我们了解强化学习(RL)中世界建模的历史,以及大语言模型的出现如何点燃一场强大的革命,将这些世界模型赋予类似当今大模型所具备的推理和语义理解能力。接着,我们对比了这种大型世界建模方法与传统自动驾驶堆栈开发的优缺点,显示出大型模型在简化架构、高维空间中的端到端优化以及卓越的预测能力方面具有巨大优势,但也以远高的参数数量为代价,带来了多个工程挑战。鉴于此,我们回顾了几种有前景的技术,用以克服这些工程挑战,使得这些大模型的开发和部署成为可能。最后,我们总结发现,尽管大型世界模型有望成为长期赢家,但从传统方法中汲取的经验教训仍将对其成功的最大化具有重要意义。我们以讨论结束,重点介绍了该激动人心领域未来工作的一些有前景的方向。

计算机视觉与自动驾驶中的多任务学习

多任务学习(MTL)是一个受到广泛关注的研究领域,通常被描述为人工智能(AI)向人类推理迈进的重要一步。正如Michael Crawshaw 对该主题的全面调查中所指出的,MTL 涉及同时在多个任务上训练一个模型,使其能够利用这些任务之间共享的信息。这种方法不仅在计算效率上具有优势,还由于学习到的特征具有互补性,导致任务性能得到提升。Crawshaw 的调查强调,MTL 模型通常通过学习更加稳健和泛化的表示,超越单任务模型的表现。

我们认为,多任务学习(MTL)比单任务学习更能准确反映人类的学习过程,因为跨领域整合知识是人类智能的核心特征。当一个新生儿学习走路或使用双手时,它积累了一般的运动技能,这些技能依赖于平衡和直觉物理的抽象概念。一旦这些运动技能和抽象概念被掌握,它们可以被重复使用并在生活后期用于更复杂的任务,例如骑自行车或走钢丝。

Crawshaw, 2020

MTL 的优势在自动驾驶汽车(AVs)中尤为重要,因为自动驾驶汽车需要实时推理多个相关的视觉任务,以做出安全的导航决策。MultiNet是为自动驾驶汽车设计的典型 MTL 模型,结合了道路分割、目标检测和分类等任务,统一在一个架构中。将 MTL 集成到自动驾驶汽车中带来了显著的优势,如更高的帧率和减少的内存占用,这对于不同规模的自动化机器人至关重要。

使用Multiformer对来自 RGB 输入的三个自动驾驶感知任务进行多任务推理。

基于 Transformer 的网络,如视觉 Transformer(ViT)及其衍生物,在计算机视觉中展示了令人难以置信的描述能力,且将 Transformer 与卷积架构结合,形成如分层 Transformer的形式(如金字塔视觉 Transformer v2(PVTv2)),已证明特别强大且易于训练,在近期模型中,如SegformerGLPNPanoptic Segformer,其性能持续超越ResNet主干,并且参数更少。受到对强大且轻量感知模块的需求激励,Multiformer结合了 MTL(多任务学习)和分层 Transformer 的描述能力,利用超过 8M(百万)参数,在语义分割、深度估计和 2D 目标检测上同时表现出色,并且可以轻松扩展到全景分割任务。

这张 Multiformer 的示意图显示,即使是统一的多任务视觉架构也很复杂,包含多个任务特定的模块。尽管它为轻量级感知模块提供了强大的基础,但它并不涉及规划或控制的推理,而且在没有重大修改的情况下,无法泛化到新任务。(图片来源:作者)

然而,构建一个完整的自主堆栈不仅仅需要一个感知模块。我们还需要规划和执行动作,因此我们需要添加一个规划与控制模块,该模块可以利用感知堆栈的输出,准确地跟踪并预测自我和环境的状态,从而发送代表安全驾驶动作的指令。一个有前景的选择是英伟达的DiffStack,它提供了一种可训练且可解释的轨迹预测、路径规划和控制建模的结合。然而,该模块需要 3D 代理位姿作为输入,这意味着我们的感知堆栈必须生成这些信息。幸运的是,已经有可用的 3D 物体检测算法,尤其是在获得准确深度信息的情况下,但我们的物体追踪将在这个困难任务中对准确性和时间一致性极为敏感,任何错误都会传播并降低下游运动规划和控制的质量。

DiffStack模块的示意图,该模块需要过去的轨迹(3D 代理位姿)作为输入。所有问题都被战略性地设计为可微分的,以便通过子模块进行反向传播,同时允许解释中间表示。然而,这些周期性的晶化信息是有损的,系统也从这些中间问题的表述中继承了一系列弱点。

的确,传统的自主堆栈模块化范式,从传感器输入到感知、规划和控制的各个阶段,本质上容易受到累积错误的影响。序列中的每个阶段都依赖于前一个阶段的准确性,这使得系统容易受到错误级联的影响,并妨碍通过中间信息的晶化进行端到端的错误修正。另一方面,模块化方法比端到端系统更具可解释性,因为中间表示可以被理解和诊断。正因为如此,端到端系统通常被回避,因为它们被视为“黑箱”解决方案,缺乏对于像自主导航这样的安全关键应用所能接受的可解释性。但如果可解释性问题可以克服呢?如果这些黑箱能够用简单的英语或其他自然语言解释它们做出的决定呢?进入自主机器人学的大语言模型时代,这一愿景不再是遥不可及的梦想,而是一个切实的现实。

自回归变换器与大语言模型的崛起

ChatGPT 被要求在一张截图中展示其才华。

在被证明是我们时代最具影响力的研究论文之一中,Vaswani 等人于 2017 年通过论文“Attention is All You Need”提出了变换器架构,革新了序列到序列(seq2seq)建模,提出了他们的注意力机制。这些创新模块克服了之前偏爱使用的 RNN 的弱点,通过有效地捕捉序列中的长距离依赖关系,并在计算过程中允许更多的并行化,带来了各种 seq2seq 任务的显著改善。一年后,谷歌的双向编码器表示变换器(BERT)通过引入基于掩码语言建模(MLM)的双向预训练目标,进一步增强了变换器在自然语言处理中的能力,融合了左右两侧的上下文,编码了更为细致的上下文理解,支持了情感分析、问答、机器翻译、文本摘要等多种语言任务。

在 2018 年中期,OpenAI 的研究人员展示了如何训练一个因果解码器单向转换器,使其能够处理字节对编码(BPE)文本标记,并使用生成预训练变换器(GPT)。他们发现,通过在自监督自回归语言建模任务上进行预训练,使用大量未标记文本数据,随后通过任务感知的输入变换(在必要时进行架构修改)进行任务特定的微调,可以产生在多种语言任务上显著提升的模型,进而改善了当时的技术水平。

虽然 GPT-1 中在标记空间中使用的任务感知输入变换可以被视为“提示工程”的早期形式,但这个术语通常指的是通过有策略的文本结构化,激发语言模型的多任务行为。2018 年,Salesforce 的研究人员通过他们具有影响力的多任务问答网络(MQAN)展示了这一点。通过将任务框架化为具有独特格式的文本字符串,作者训练了一个没有任务特定模块或参数的单一模型,使其在一组十项自然语言处理任务(他们称之为“自然语言十项全能”(decaNLP))上表现良好。

在 2019 年,OpenAI 发现,通过在推理时采用这种形式的提示工程,GPT-2展现了令人期待的零样本多任务表现,并且模型和数据集的规模与性能呈对数线性增长。尽管这些任务提示结构并不像 MQAN 那样明确地包含在训练数据中,但模型能够从之前见过的结构化语言中泛化知识,以完成当前的任务。该模型展示了具有 15 亿参数的令人印象深刻的无监督多任务学习(相比 GPT 的 1.17 亿参数),这表明这种语言建模方式为通用 AI 提供了一个有前景的路径,同时也引发了对未来的伦理担忧。

谷歌研究在 2019 年底开源了文本到文本的转换器(T5),其模型规模最大可达 110 亿参数。虽然 T5 同样基于自回归变换器构建,但它将自然语言问题表示为统一的文本到文本框架,使用完整的变换器架构(包括编码器),与 GPT 风格模型的下一个词预测任务有所不同。虽然这种文本到文本框架在需要更多控制任务训练和预期输出的应用中是一个强有力的选择,但 GPT 风格模型的下一个词预测方案因其任务无关的训练方式和自由形式的生成长篇连贯回应而被更为青睐。

然后在 2020 年,OpenAI 将模型和数据规模提升到了前所未有的高度,发布了GPT-3,其余的事情便成为了历史。在他们的论文《语言模型是少样本学习者》中,作者定义了一个“少样本”迁移范式,在这种范式中,他们提供了任何数量的示例,用于未见过的任务(以自然语言形式呈现),这些示例会被纳入模型的上下文,直到任务的最终开放式提示供模型完成。他们将其与“一次示例”对比,即在上下文中提供一个示例,以及“零示例”,即完全不提供示例。研究团队发现,在所有三种评估方法下,模型的表现随着参数数量的增加持续提升,直到达到了 1750 亿参数,这标志着发布的模型规模出现了历史性跃升。这一庞然大物实现了接近人类水平的通用少样本学习和文本生成能力,引起了主流关注,并激发了对这一趋势在 AI 研究中未来影响的担忧。那些担忧的人或许能暂时获得一丝安慰,因为在这些规模下,训练和微调这些模型远远超出了除了最大型组织之外的所有机构的能力范围,但这一现状显然会发生变化。

来自GPT-3论文的图表,展示了随着模型规模的增大,整体性能的提升。

GPT-3 在多个方面具有突破性意义,同时也标志着 OpenAI 开放性结束,它是第一个封闭源代码的模型。幸运的是,开源 LLM(大语言模型)研究的浪潮已经开始。EleutherAI 从 2020 年起发布了广受欢迎的一系列大规模开源 GPT-3 风格的模型,从GPT-Neo 2.7B开始,接着是 2021 年的GPT-J 6B,再到 2022 年的GPT-NeoX 20B,后者在基准测试中与 GPT-3.5 DaVinci 不相上下(所有模型均可在huggingface/transformers找到)。

随后的几年标志着基于 Transformer 的 LLM 的“寒武纪大爆炸”。研究兴趣的超新星产生了令人叹为观止的出版物清单,这些文献的全面回顾超出了本文的范围,但我建议读者参考Zhao 等人 2023的全面调查。几个值得一提的关键进展包括,当然,OpenAI 发布的 GPT-4,以及 Meta AI 发布的开源LLaMA模型,强大的Mistral 7B模型,及其专家混合(MoE)版本:Mixtral 8X7B,这些都发生在 2023 年。普遍认为,GPT-4 是一个 MoE 系统,而 Mixtral 8X7B 所展示的强大性能(在大多数基准测试中超过LLaMA 2 70B,并且推理速度快 6 倍)为此提供了有力证据。

来自Zhao 等人 2023调查的 LLaMA 后代家谱展示了 LLM 研究的规模。

为了简洁地总结过去几年 LLM 大爆炸的视觉概况,再次借用强大的 Zhao 等人 2023 年的调查是很有帮助的。请注意,这张图表仅包括参数超过 10B 的模型,因此遗漏了一些重要的小型模型,如 Mistral 7B。不过,它为最近的进展提供了有用的视觉锚点,同时也证明了 T5 和 GPT-3 发布后研究动能的巨大增长。

来自Zhao 等人 2023调查的 LLM 时间轴。

值得注意的是,尽管开源 LLM 在性能上明显落后于私人模型,但这一差距正在随着时间的推移逐渐缩小,开源模型似乎准备在不久的将来迎头赶上。现在正是熟悉将 LLM 集成到工作中的最佳时机。

图片由作者提供。请注意,为了公平对比,微调模型已从趋势线数据中移除。

大型多模态模型时代

在 LLM 的巨大成功基础上,人工智能的最新时代见证了 LMM 的诞生,代表了机器理解和与世界互动方式的范式转变。这些大型模型可以接收多种模态的数据作为输入,返回多种模态的数据作为输出,或者两者兼有,通过学习这些数据模态的共享嵌入空间并利用 LLM 对该空间进行序列建模。这使得 LMM 能够执行突破性的任务,如使用自然语言进行视觉问答,正如在这一大型语言与视觉助手(LLaVA)的演示中所展示的那样:

LLaVA演示了视觉问答,利用自然语言推理图像内容。

视觉-语言预训练的一个重要进展(VLP),OpenAI 的对比语言-图像预训练(CLIP)在 2021 年解锁了一个新的可能性,当时它建立了一种对比方法来学习共享的视觉和语言嵌入空间,使得图像和文本能够在一个共同的数字空间中表示,并基于余弦相似度得分进行匹配。CLIP 在计算机视觉领域引发了一场革命,当它能够以零-shot 的方式超越多项图像分类基准,超过了那些通过监督训练的专家模型,并激发了对零-shot 分类的研究兴趣。尽管它没有实现诸如视觉问答等能力,但训练 CLIP 会生成一个图像编码器,该编码器可以被移除并与 LLM 配对,从而创建一个 LMM。例如,LLaVA 模型(如上所示)使用预训练且冻结的 CLIP 图像编码器将图像编码为多模态嵌入空间,DeepMind 的Flamingo也是如此。

注意 — LMM 的术语并不完全一致。虽然“LMM”似乎已成为最流行的术语,但在其他地方这些模型被称为MLLMs,甚至是MM-LLMs

这些预训练 CLIP 编码器生成的图像嵌入可以与文本嵌入交替输入到自回归的变换器语言模型中。AudioCLIP将音频作为第三种模态加入到 CLIP 框架中,突破了环境声音分类(ESC)任务的现有技术水平。Meta AI 的有影响力的ImageBind提出了一个学习跨六种数据模态(图像、文本、音频、深度、热成像和惯性质量单元(IMU)数据)编码联合嵌入的框架,但表明所有模态的涌现对齐是通过仅将每种模态与图像对齐来实现的,证明了图像的丰富语义内容(“一张图胜过千言万语”)。PandaGPT将 ImageBind 的多模态编码方案与Vicuna LLM 结合,创建了一个能够理解这六种模态数据输入的 LMM,但与前面提到的其他模型一样,仍然仅限于文本输出。

图像可能是最具多功能性的模型输入格式,因为它可以用来表示文本、表格数据、音频,甚至在一定程度上,视频。而且,视觉数据远多于文本数据。如今,我们有手机/网络摄像头不断拍摄照片和视频。

文本是模型输出的更强大的模态。能够生成图像的模型只能用于图像生成,而能够生成文本的模型可以用于许多任务:总结、翻译、推理、问答等。

— Huyen 在其文章“多模态与大型多模态模型(LMMs)”中对数据模态优势的敏锐总结(2023 年)。

事实上,大多数 LMMs 的研究只提供了单模态语言输出,与之相比,返回多模态数据的模型发展相对滞后。那些寻求提供多模态输出的工作通常通过使用 LLM 解码文本来引导其他模态的生成(例如,当请求生成图像时,GPT-4将生成一个自然语言的专门提示并将其传递给DALL-E 3,后者则为用户创建图像),这种方式本质上引入了级联错误的风险,并且无法进行端到端的调优。NExT-GPT旨在解决这一问题,设计了一个可以端到端训练的全对全 LMM。在编码器方面,NExT-GPT 使用了上面提到的 ImageBind 框架。为了引导跨 6 种模态的解码,LMM 在一个定制的模态切换指令调优数据集 Mosit 上进行了微调,学习生成作为解码过程指令的特殊模态信号标记。这使得数据输出模态切换的处理可以通过端到端的学习来完成。

GATO是 DeepMind 在 2022 年开发的一款通用智能体,代表了 LMM(大规模语言模型)的卓越多功能性。该系统展示了前所未有的能力,能够在统一的学习框架内执行广泛的 604 项任务,从 Atari 游戏到使用真实机器人手臂堆叠积木等复杂控制任务。GATO 的成功证明了 LMM 在多样化环境和任务中模拟类人适应性的潜力,向实现人工通用智能(AGI)的难以捉摸的目标迈出了重要一步。

LMM 时代的世界模型

深度强化学习(RL)是一种流行且经过充分研究的解决复杂机器人问题的方法,最初在Atari 游戏中展示了超人类能力,随后又击败了世界顶级围棋玩家(围棋是一项著名的挑战性游戏,要求长远的战略眼光)。传统的深度强化学习算法通常分为无模型方法或基于模型的方法,尽管近期的研究通过将强化学习框架化为一个使用大型变换器模型的大规模序列建模问题,模糊了这两者的界限,借鉴了自然语言处理和计算机视觉领域中的成功趋势。

虽然无模型强化学习方法在设计和实现上明显更加高效,并且比基于模型的方法更容易设计,但它们通常样本效率较低,需要比人类更多的环境交互才能学会一项任务。基于模型的强化学习方法通过学习如何根据先前的状态和行动来建模环境的变化,来减少与环境的交互次数。这些模型可以用于预测环境的未来状态,但这也为强化学习系统增加了一种失败模式,因为它们必须依赖于这种建模的准确性和可行性。使用神经网络来学习动力学模型以训练强化学习策略已有很长的历史,可以追溯到1980 年代使用前馈网络,以及1990 年代使用递归神经网络(RNNs),后者由于能够在多步时间范围内进行建模和预测,成为了主要的研究方向。

2018 年,Ha & Schmidhuber 发布了一项具有突破性的研究成果,名为“递归世界模型促进策略演化”,在其中他们展示了通过超越简单的动态建模,扩展环境建模的力量,而是使用卷积变分自编码器(CVAE)和一个大型 RNN 的组合来建模环境本身的压缩时空潜在表示,从而形成了所谓的“世界模型”。策略完全在这个世界模型的表示中进行训练,并且因为它从未暴露于真实环境中,因此可以从可靠的世界模型中进行采样,从其学到的世界理解中模拟想象中的执行轨迹,为进一步训练策略提供有效的合成示例。这使得策略训练更加数据高效,这对于实际应用中的 RL 在真实世界领域的应用具有巨大优势,因为数据收集和标注通常非常资源密集。

Ha & Schmidhuber, 2018展示了 CarRacing-v0 和 DoomTakeCover-v0 的世界模型模拟。

Ha & Schmidhuber, 2018提供的基于 RNN 的世界模型的清晰示意图。

这种在世界模型想象中的学习诱人概念自此获得广泛关注。模拟策略学习(SimPLe)利用这一范式,在视频预测模型中训练 PPO 策略,仅用两小时的实时游戏经验就能在 Atari 游戏中实现最先进的表现。DreamerV2Dreamer的改进版)成为第一个通过想象学习的智能体,在 Atari 50M 基准测试上达到超人类表现的例子(尽管需要数月的游戏经验)。Dreamer 算法还被证明在DayDreamer中对真实机器人控制的在线学习中同样有效。

这张来自DREAMER V2的图表展示了通过之前的 SoTA 模型,Atari 性能的提升过程。

尽管它们在强化学习(RL)环境中最初证明难以训练,但变换器的诱人特性使其在另一个研究领域中产生了颠覆性影响。将强化学习框架化为序列建模问题有许多好处,即简化了架构和问题表述,并且变换器提供的数据和模型规模的可扩展性。轨迹变换器被训练用来预测未来的状态、奖励和动作,但仅限于低维状态,而决策变换器可以处理图像输入,但仅预测动作。

将强化学习,及更广泛的数据驱动控制,作为一个序列建模问题,可以处理通常需要不同解决方案的许多考虑因素:演员-评论家算法……行为策略估计……动力学模型……价值函数。这些问题都可以通过单一的序列模型统一处理,将状态、动作和奖励视为简单的数据流。该观点的优势在于,可以采用高容量的序列模型架构来解决问题,从而形成一种更简化的方法,这种方法可能会受益于大型无监督学习结果背后的可扩展性。

— 在轨迹变换器介绍中提供的动机

IRIS(通过内在语言进行自回归的想象)是一个最近的开源项目,构建了一个生成世界模型,其结构类似于VQGANDALL-E,结合了离散自编码器和 GPT 风格的自回归变换器。IRIS 通过模拟数百万条轨迹来学习行为,使用编码的图像令牌和策略动作作为输入,传递给变换器预测下一组图像令牌、奖励和回合终止状态。预测的图像令牌被解码为图像并传递给策略生成下一步动作,尽管作者承认,在潜在空间上训练策略可能会带来更好的性能。

IRIS 的结构,一个有前景的开源大规模世界模型。

IRIS 展示了通过完美预测这一轮 Pong 来表现其深刻的环境理解。

GAIA-1Wayve开发,采用自回归变换器世界建模方法,结合了使用扩散解码器的图像和视频生成,并且增加了文本调节作为输入模态。这使得视频生成可以在推理时通过自然语言进行指导,允许提示特定场景,如天气状况或智能体行为(例如汽车偏离车道)。然而,GAIA-1 仅限于图像和视频输出,未来的工作应探索输出的多模态性,使得模型能够解释它所看到的内容及其采取的行动,这有可能推翻关于端到端驾驶堆栈不可解释性的批评。此外,GAIA-1 在潜在空间中生成动作标记,但这些标记尚未解码。解码这些潜在空间中的动作将使得该模型可用于机器人控制,并提高其可解释性。进一步地,可以应用 ImageBind 的原理来扩展输入数据模态(例如,包括深度信息),从而有可能开发出更通用的内部世界表示,并改进下游生成。

来自GAIA-1论文的图表展示了视频生成能力和提示模态。

在这些世界模型发展的背景下,必须承认像 GAIA-1 这样的生成模型对合成数据生成领域可能带来的颠覆性影响。随着这些先进模型在创建逼真且多样化的数据集方面越来越得心应手,它们将彻底改变合成数据的生成方式。目前,主流的汽车合成数据生成方法是使用仿真和基于物理的渲染,通常在游戏引擎中生成场景,完全控制天气、地图和智能体。Synscapes是此类合成数据集生成的开创性工作,作者探讨了工程化数据生成过程的优势,以尽可能接近目标领域,从而应对合成数据与真实数据领域差距对知识迁移的负面影响。

尽管已经在多方面取得了进展来解决这个问题,这种从合成到真实的领域差距仍然是合成数据生成过程中的一种副产物,并且在领域之间知识的迁移性上提出了持续的挑战,阻碍了从仿真中学习的全部潜力。然而,从世界模型中采样合成数据是一种根本不同的做法,也是一个具有说服力的替代方案。模型在描述能力和环境知识上的任何提升都将相互促进,改善模型生成的合成数据质量。这些合成数据直接从模型学习到的分布中进行采样,减少了任何关于分布对齐的担忧——这种对齐是模型与被建模领域之间的对齐,而非涉及受完全不同力量影响的第三方领域。随着生成模型的不断进步,可以预见,这种类型的合成数据生成将取代今天复杂且本质上割裂的生成过程。

迈向未来:自动化系统中的多任务与大型世界模型

自动导航领域正在见证一种在场景理解方法上的有趣演变,这种演变受到了多任务视觉模型和大型世界模型发展推动的影响。我个人的工作以及其他领域专家的研究,成功地利用了多任务模型在感知模块中的应用,展示了其效能和效率。同时,像 Wayve 这样的公司正在开创将大型世界模型应用于自动化领域的先河,预示着可能的范式转变。

多任务视觉模型的紧凑性和数据效率使其成为感知模块中的自然选择。通过同时处理多个视觉任务,它们为传统模块化自动化堆栈提供了务实的解决方案。然而,在这种设计范式下,这些感知模块必须与下游的规划和控制模块结合,以实现自动化操作。这就创造了一系列复杂的组件,执行高度专业化的问题求解,这种结构本身就容易受到累积性错误的影响。每个模块能否表现良好取决于它从前一个环节接收到的信息质量,而在这一管道早期出现的错误可能会被放大。

尽管像英伟达的 DiffStack 这样的作品通过不同任务模块间的可微损失公式构建,旨在提供一个既可学习又可人类解释的最佳解决方案,但模块之间定期结晶化的中介数据表示本质上是一种有损压缩形式,会产生信息瓶颈。此外,将多个模型串联起来会积累它们各自的局限性,影响对世界的表征。

另一方面,LMM 作为世界模型的应用,如 Wayve 的 AV2.0 计划所示,暗示着一种不同的发展轨迹。这些模型以其庞大的参数空间为特点,提出了一种端到端的自治框架,涵盖了感知、规划和控制。尽管它们巨大的规模给训练和部署带来了挑战,但近期的进展正在缓解这些问题,并使得大型模型的使用变得更加可及。

展望未来,我们可以清楚地看到,训练和部署大型模型的障碍正在稳步减少。AI 领域的这一持续进展正在 subtly(微妙地)但显著地改变传统任务特定模型与更大模型之间的动态关系。虽然多任务视觉模型目前在某些方面,如规模和可部署性上占有优势,但在大型模型训练技术和计算效率上的持续进步正在逐步拉平竞争态势。随着这些障碍的不断降低,我们可能会看到人们对更全面和集成化模型的偏好发生转变。

将火种带给人类:让大型模型民主化

图像由作者使用 DALL-E 3 生成。

尽管大型模型具有令人印象深刻的能力,但它们也带来了重大挑战。训练所需的计算资源庞大,激起了关于环境影响和可访问性的担忧,并且为研究和开发设置了进入门槛。幸运的是,有几种工具可以帮助我们将大型基础模型(LFM)的强大功能带入现实:剪枝、量化、知识蒸馏、适配器模块、低秩适配、稀疏注意力、梯度检查点、混合精度训练和开源组件。这些工具箱为我们提供了一个有希望的方案,将从大型模型训练中获得的力量集中到可管理的规模。

一种直观的方法是训练一个大型模型直到收敛,移除对性能贡献最小的参数,然后对剩余网络进行微调。通过移除不重要的权重来减少神经网络的大小和推理成本的网络最小化方法被称为“剪枝”,其起源可以追溯到 1980 年代(参见 LeCun 等人 1989 年发布的《最优大脑损伤》)。2017 年,Nvidia 的研究人员提出了一种有影响力的网络剪枝方法,该方法使用泰勒展开估算去除给定神经元后损失函数的变化,提供了衡量其重要性的指标,从而帮助识别可以在对网络性能影响最小的情况下剪枝的神经元。剪枝过程是迭代的,每减少一次参数后都会进行一轮微调,直到达到所需的准确性和效率平衡。

与此同时,2017 年,谷歌的研究人员发布了一项关于网络量化的开创性工作,提供了一种正交方法来缩小大型预训练模型的大小。作者提出了一种影响深远的 8 位量化方案,适用于权重和激活(包括训练和推理框架),旨在通过使用仅限整数运算的推理提高移动 CPU 上的推理速度。这种量化形式已被应用于大语言模型(LLMs),使其能够在更小的硬件上进行推理(参见TheBloke在 Hugging Face hub 上提供的各种量化模型)。

另一种压缩大型、笨重模型能力的方法是知识蒸馏。2006 年,康奈尔大学的研究人员提出了后来被称为知识蒸馏的概念,在他们的工作 “Model Compression” 中首次展示了这一概念。这项工作成功地探索了训练小型紧凑模型,以近似大型笨重专家(尤其是大型集成模型)所学习的函数。作者利用这些大型专家为多个领域的大型未标注数据集生成标签,并证明了在由这些标签生成的数据集上训练的小模型,比在原始训练集上训练的等效模型在目标任务上表现更好。此外,他们训练小模型去目标化大模型产生的原始 logits,因为它们的相对值包含的信息比硬类标签或 softmax 概率更为丰富,后者在低概率范围的细节和梯度会被压缩。

Hinton 等人在 2015 年扩展了这一概念,并创造了“蒸馏”(distillation)一词,发表了论文 “Distilling Knowledge in a Neural Network”,通过训练小模型来目标化大专家产生的概率,而非原始的 logits,且在最终的 softmax 层中提高温度参数,以产生“一组合适的软目标”。作者指出,这个参数提供了一个可调的放大级别,用于细粒度信息在概率范围低端的放大,并发现容量较小的模型在较低的温度下能更好地过滤掉 logit 值低端的某些细节,集中模型有限的容量于更高层次的交互。他们进一步证明,使用他们的方法,基于原始训练集而非新的大型迁移数据集,同样能够取得良好效果。

在其他大型模型生成的数据上微调大型模型也是一种知识蒸馏的形式。Self-Instruct 提出了一个数据管道,利用大型语言模型(LLM)生成指令调优数据,尽管原始论文展示了如何在 GPT-3 的自我输出上进行微调,Alpaca则使用这种方法,通过 GPT-3.5 的输出对 LLaMA 进行了微调。WizardLM通过引入一种名为 Evol-Instruct 的控制生成指令复杂性的方法,扩展了 Self-Instruct 方法。VicunaKoala则使用了来自ShareGPT的真实人类/ChatGPT 交互进行指令调优。在Orca中,微软研究院警告道,尽管训练较小的模型模仿大型语言模型(LFM)输出时可能会学习到这些模型的写作风格,但它们往往无法捕捉到生成响应时的推理能力。幸运的是,他们的团队发现,通过在生成示例时使用系统指令(例如“逐步思考并证明你的回答”),以促使教师将其推理过程解释为响应的一部分,从而为较小的模型提供了有效的 LFM 思维窗口。Orca 2随后引入了提示删除(prompt erasure),强制较小的模型学习适当的推理策略以应对给定的指令。

上述方法都专注于将大型预训练模型的能力浓缩到可管理的规模,但对于这些大型模型的便捷微调又该如何处理呢?2017 年,Rebuffi 等人引入了适配器模块来进行模型微调。这些小型可训练矩阵可以插入到预训练且被冻结的计算机视觉模型中,以便快速适应新的任务和领域,并且只需要少量的示例。两年后,Houlsby 等人展示了这些适配器在自然语言处理(NLP)中的应用,将一个预训练的 BERT 模型转移到 26 个不同的自然语言分类任务中,取得了接近最先进的表现。适配器使得大型语言模型(LFM)的参数高效微调成为可能,并且可以轻松互换,以便在不同的专家模型之间切换,而不需要为每个任务训练和部署完全不同的模型,这将是非常昂贵且不可行的。

Houlsby 等,2019的示意图,展示了适配器模块在变换器层中的位置。相较于原始模型中的注意力层和前馈层,适配器包含的参数较少。只有绿色块在微调过程中进行训练。

2021 年,微软研究团队改进了这一概念,提出了一种开创性的方法,用于训练一种新的适配器形式——低秩适配(LoRA)。与将适配器矩阵像信用卡一样插入模型中(这种方式会降低模型的推理速度)不同,这种方法学习权重增量矩阵,推理时可以将其与冻结的权重结合,从而提供一种轻量级的适配器,用于在不同的微调任务之间切换基础模型,而不会增加推理延迟。它们通过将权重增量矩阵表示为低秩分解成两个较小的矩阵AB(它们的点积形成原始权重矩阵的形状),从而减少了可训练参数的数量,这一做法受到他们的假设启发(基于Aghajanyan 等,2020),即在微调过程中对权重的更新具有低的内在秩。

低秩适配(LoRA)的示意图。仅在微调过程中训练AB

稀疏变换器进一步探索了通过两种类型的分解自注意力来提高变换器的计算效率。值得注意的是,作者们还采用了梯度检查点,这是一种在训练大型网络时通过在反向传播过程中重新计算激活值而非将其存储在内存中的资源节省方法。该方法对于建模长序列的变换器特别有效,因为在此场景下,计算成本较大,相应的内存占用也较高。这提供了一个有吸引力的权衡:在训练过程中迭代速度略有下降,但可以显著减少 GPU 内存占用,从而使得在硬件限制下可以训练更多变换器层次,并处理更长的序列长度。为了进一步提高效率,稀疏变换器还采用了混合精度训练,其中网络权重以单精度浮点数存储,而激活值和梯度以半精度进行计算。这进一步减少了训练过程中的内存占用,并增加了在给定硬件预算下可训练的模型大小。

最后,一个重要的(也许有些显而易见的)工具,用于普及大型模型的开发和应用,是发布和利用预训练的开源组件。OpenAI 的无处不在的工作马 CLIP 是开源的,并且拥有商业许可的许可证,Meta 的突破性 LFM 发布 LLaMA 2 也是如此。像这些预训练的开源组件,整合了开发 LMM 的大部分繁重工作,因为这些模型通过微调能够迅速推广到新任务,且我们已经知道这在上述贡献的支持下是可行的。值得注意的是,NExT-GPT 仅使用现有的预训练组件和巧妙的对齐学习技术(仅对变换器的输入和输出进行训练投影,占总模型权重的 1%)构建了他们的全到全 LMM。只要最大的机构继续保持对开源理念的承诺,较小的团队将能够高效地做出深远的贡献。

正如我们所看到的,尽管大型模型的规模庞大,但仍有许多互补的方式可以用于它们的精细调优和部署。我们可以通过将其知识蒸馏到较小的模型中,或者将其权重量化为整数来压缩这些模型。我们还可以使用适配器、梯度检查点和混合精度训练高效地进行精调。来自大型研究机构的开源贡献以令人尊敬的速度持续推进,并且看起来正在缩小与封闭源能力之间的差距。在这种环境下,从传统问题表述转向大型序列建模的世界并非冒险的选择。一个最近且具有代表性的成功案例是LaVIN,它使用轻量级适配器将冻结的 LLaMA 转变为 LMM,仅用 3.8M 的参数训练了 1.4 小时,挑战了 LLaVA 的表现,而无需任何端到端的微调。

协同多种 AI 方法:结合多任务与大规模世界模型

虽然 LMM 提供了统一的自动导航解决方案,并且威胁到了模块化自动驾驶堆栈的主导地位,但它们在底层本质上仍然是模块化的,自 MTL(多任务学习)开始以来,LMM 研究中可以看到其遗产。其精神本质上是相同的:在中央网络中捕捉深刻且通用的知识,并使用任务特定的组件提取与特定任务相关的知识。从许多方面来看,LMM 研究是 MTL 的演变。它分享了相同的远见目标,即开发具有通用能力的模型,并标志着迈向 AGI(通用人工智能)的下一步重大进展。由此可见,MTL 的印记贯穿于 LMM 的设计之中。

在现代大规模多模态模型(LMM)中,输入数据的各个模态会分别编码到联合嵌入空间中,然后传递给语言模型,因此在这些编码器的实验中具有灵活性。例如,许多 LMM 中使用的 CLIP 图像编码器通常采用 ViT-L(307M 参数),而关于其他选项的实验工作较少。一个可能的竞争者是 PVTv2-B5,它只有 82M 参数,在 ImageNet 基准测试中的得分仅比 ViT-L 低 1.5%。很有可能,像 PVTv2 这样的分层 Transformer 能够创建有效的图像-语言对齐图像编码器,而且参数更少,从而显著减少 LMM 的总体大小。

同样,在解码器设计中应用多任务学习(MTL)的经验也有空间,以处理 LMM 所提供的输出数据模态。例如,Multiformer 中使用的解码器非常轻量,但能够从联合特征空间中提取准确的深度、语义分割和目标检测。将他们的设计原则应用于 LMM 的解码端,可能会产生这些模态的输出,且可以通过监督来构建更深层次、更广泛的中央嵌入空间知识。

另一方面,NExT-GPT 展示了在 LMM 输入端添加深度等数据模态的可行性和优势,因此将像 Multiformer 这样的模型中准确的多任务推理编码到 LMM 输入中,是未来研究的一个有趣方向。一个训练良好且具通用性的专家可能生成高质量的伪标签来处理这些额外的模态,从而避免在训练 LMM 时需要标注数据,但仍然可以使模型将嵌入空间与这些模态的可靠表示对齐。

无论如何,自动导航领域向 LMM 的过渡远不是一种敌对接管。数十年的多任务学习(MTL)和强化学习(RL)研究所获得的经验,在人工智能研究的前沿找到了一个令人兴奋的新舞台。自动驾驶公司在标注原始数据上花费了大量资金,许多公司可能拥有大量的顺序未标注数据,非常适合自监督世界建模任务。鉴于本文讨论的启示,我希望他们正在关注这一点。

结论

在本文中,我们看到了自动驾驶(AV)开发中范式转变的曙光,这种转变凭借其优势,可能威胁到模块化驾驶堆栈作为该领域主流方法的地位。AV2.0 的这一新方法在顺序世界建模任务中使用大语言模型(LMM),通过结合先前的传感器数据和控制操作来预测未来状态,同时还包括文本等其他模态,从而提供了一种简化问题陈述和统一架构中的感知、规划和控制的合成。之前,许多人认为端到端方法对于安全关键型部署来说过于黑箱化,因为它们的内部状态和决策过程不可解释。然而,随着大语言模型根据传感器数据做出驾驶决策,如果模型被提示,它有可能用自然语言解释它所感知到的内容以及背后的推理。这种模型还可以从自己想象中采样合成示例进行学习,减少对真实世界数据收集的需求。

尽管这种方法的潜力很诱人,但它需要非常大的模型才能有效,因此也继承了大模型的局限性和挑战。能够训练或微调一个数十亿参数的多层次大模型的机构屈指可数,而大模型伴随着很多效率问题,从计算成本到嵌入式硬件的尺寸都有考虑。然而,我们已经看到,存在许多强大的开源工具和获得商业许可的语言模型(LFM),各种参数高效微调方法使得定制变得可行,还有压缩技术使得在可控规模下部署成为可能。鉴于这些因素,回避采用大模型来解决像自动驾驶机器人这样的复杂问题似乎难以站得住脚,并且这将忽视通过不断发展的技术来为系统提供未来保障的价值,而不是固守那些可能已经达到瓶颈的方法。

尽管如此,小型多任务模型在其相对微小的规模上具有巨大的优势,这使得它们在可获取性和实验便捷性方面占据优势,同时简化了许多工程和预算决策。然而,任务特定模型的局限性带来了不同的挑战,因为这些模型必须以复杂的模块化架构进行组织,才能在自主系统堆栈中完成所有必要功能。这种设计导致信息通过感知、预测、规划,再到控制堆栈的顺序流动,增加了因这些顺序组件而带来的错误积累风险,并妨碍了端到端优化。此外,尽管在这一范式中整体参数量可能远低于其他架构,但堆栈的复杂性无疑更高,因为众多组件涉及各自研究领域的专门问题表述,需要一个由不同学科的高技能工程师组成的大团队来维护和开发。

大型模型已经展示了对信息进行推理的深刻能力,并能够在多个模态中将这种能力推广到新的任务和领域,这在深度学习领域中一直是一个难以实现的目标。长期以来,大家都知道,通过监督学习训练的模型在面对来自其训练分布之外的示例时非常脆弱,它们能够非常出色地执行某一(甚至多个)任务,但要称其为“智能”几乎不值得一提。如今,在几年的爆发式发展之后,2020 年似乎已成青铜时代,人工智能研究领域的“伟大白水牛”终于现身,首先表现为巨型聊天机器人的特性,现在它已轻松获得视觉和听觉的能力。这项技术以及它所引发的机器人技术革命似乎即将在几年内,甚至更早,为灵活的机器人控制提供突破,自动驾驶汽车(AVs)将是首批向世界展示这一技术力量的领域之一。

未来工作

如上所述,驱动许多大规模语言模型(LMM)的 CLIP 编码器通常由 ViT-L 构成,而我们已经错过了尝试更现代架构的时机。像 PVTv2 这样的层次化变换器在 ImageNet 上几乎能与 ViT-L 的表现相匹敌,但参数量却要少得多,因此它们可能是作为语言对齐图像编码器在紧凑型 LMM 中使用的有力候选者。

IRIS 和 GAIA-1 为使用 LMM 构建世界模型指明了前进的道路。然而,这两种模型的输出方式都有限。两者都使用自回归变换器来预测未来的帧和奖励,但虽然 GAIA-1 支持文本提示,两者都没有设计用于生成文本,这将是评估推理能力和解释失败模式的一个重要步骤。

在这一阶段,领域将从发布像 GAIA-1 这样的开源生成世界模型中受益匪浅,但需要采用全对全的模式方案,能够在输出中提供自然语言和动作。这可以通过添加适配器、编码器、解码器和修订问题声明来实现。很可能,组建这样一个架构所需的预训练组件已经存在,并且可以通过合理数量的可训练参数进行对齐,因此这为研究提供了一个开放的方向。

此外,正如 Mixtral 8X7B 所展示的那样,小型模型的 MoE(专家混合)配置可以超越大型单一模型的性能,未来的研究应该探索基于 MoE 配置的 LMM(大规模语言模型)世界模型。进一步地,将大型 MoE 模型蒸馏成单一模型已被证明是一种有效的模型压缩方法,并可能将大型世界模型的性能提升到下一个水平,因此这为创建 MoE LMM 世界模型提供了额外的动力。

最后,使用具有商业许可的合成数据对开源模型进行微调应该成为标准做法。因为 Vicuna、WizardLM 和 Orca 是通过 ChatGPT 的输出进行训练的,这些预训练的权重本质上仅限于研究用途,因此尽管这些发布提供了强大的微调 LLM 方法论,但它们并没有完全“民主化”这一力量,因为任何试图将这些方法创建的模型用于商业目的的人都必须花费必要的自然和财力资源来收集新的数据集并重复实验。应该推动使用像 Evol-Instruct 这样的商业许可开源模型生成合成指令调优数据集,而不是使用 ChatGPT,这样使用这些数据集训练的权重才能真正实现“民主化”,有助于提升那些资源较少的人的能力。

导航最新的生成式 AI 公告——2024 年 7 月

原文:towardsdatascience.com/navigating-the-latest-genai-model-announcements-july-2024-461f227f588f?source=collection_archive---------7-----------------------#2024-07-26

关于新模型 GPT-4o mini、Llama 3.1、Mistral NeMo 12B 以及其他生成式 AI 趋势的指南

Tula MastermanTowards Data Science Tula Masterman

·发表于 Towards Data Science ·阅读时间 7 分钟·2024 年 7 月 26 日

--

图片由作者使用 GPT-4o 生成,展示不同的模型

引言

自 2022 年 11 月 ChatGPT 发布以来,几乎每周都有新的模型、创新的提示方法、创新的代理框架或其他令人兴奋的生成式 AI 突破。2024 年 7 月也不例外:仅在本月,我们就见证了 Mistral Codestral MambaMistral NeMo 12BGPT-4o miniLlama 3.1 等模型的发布。这些模型在推理速度、推理能力、编程能力和工具调用性能等领域带来了显著提升,使它们成为企业使用的有力选择。

在本文中,我们将介绍最近发布的模型亮点,并讨论当今生成式 AI(GenAI)领域的一些主要趋势,包括增加上下文窗口大小、以及跨语言和跨模态的性能提升。

2024 年 7 月发布模型概述

Mistral Codestral Mamba

  • 概述:Codestral Mamba 7B 采用了Mamba 架构,旨在提供增强的推理和编码能力,而不是大多数语言模型使用的 Transformer 架构。该架构使得在更长序列中进行上下文检索成为可能,已被测试支持最多 256K tokens 的序列。相比之下,大多数基于 Transformer 的模型仅支持 8-128K tokens 的上下文窗口。Mamba 架构还能够比基于 Transformer 的模型提供更快的推理速度。

  • 可用性:Codestral Mamba 是一个开放源代码模型,采用 Apache 2.0 许可证。

  • 性能:Codestral Mamba 7B 在 HumanEval、MBPP、CruxE、HumanEval C++和 Human Eval JavaScript 基准测试中超越了 CodeGemma-1.1 7B、CodeLlama 7B 和 DeepSeekv1.5 7B。尽管其规模较小,但它在这些基准测试中的表现与 Codestral 22B 相似。

图像由作者基于 Mistral AI Codestral Mamba 公告的结果创建

Mistral NeMo 12B

  • 概述:Mistral NeMo 12B 是由 Mistral 和 Nvidia 联合推出,旨在提供一个具有竞争力的 12B 参数范围的语言模型,且其上下文窗口比大多数同类大小的模型要大得多。NeMo 12B 具有128K token 上下文窗口,而同样大小的模型 Gemma 2 9B 和 Llama 3 8B 仅提供 8K token 的上下文窗口。NeMo旨在支持多语言应用场景,并提供了一种新的分词器——Tekken,该分词器在压缩 85%的语言文本时,比 Llama 3 的分词器表现更优。HuggingFace 模型卡指出,NeMo 应在比早期 Mistral 模型更低的温度下使用,他们建议将温度设置为 0.3。

  • 可用性:NeMo 12B 是一个开放源代码模型(提供基础版和指令调优版检查点),采用 Apache 2.0 许可证。

  • 性能:Mistral NeMo 12B 在多个零样本和五样本基准测试中,比 Gemma 2 9B 和 Llama 3 8B 的表现高出多达 10%。在 WildBench 测试中,其表现几乎是 Mistral 7B 的 2 倍,WildBench 旨在衡量模型在需要复杂推理和多轮对话的现实任务中的表现。

图像由作者基于Mistral AI NeMo 公告的结果创建

Mistral Large 2

  • 概述Mistral Large 2 提供了 128K token 上下文窗口、改进的函数调用、支持多种语言和 80 多种编程语言。与 Codestral Mamba 和 NeMo 一样,Mistral Large 2 在大量代码的基础上进行训练,使其在与 GPT-4o、Claude 3 Opus 和 Llama 3.1 405B 的竞争中表现出色。在训练过程中,Mistral 团队 专注于减少模型出现幻觉的可能性,使得 Mistral Large 2 更倾向于回应它无法找到答案或缺乏提供回答所需的信息。

  • 可用性:Mistral Large 2 在 Mistral Research License 许可下提供。这允许在研究和非商业用途下进行实验和修改。对于那些有意在商业上使用 Mistral Large 2 的用户,可以通过 直接联系 Mistral AI 并申请 Mistral 商业许可。

  • 性能:Mistral Large 2 在函数调用任务上优于 GPT-4o 和 Claude 3 Opus,并且在遵循指令和对齐任务上表现与这些模型相似,经过 Wild Bench 和 Arena Hard 基准测试评估。

GPT-4o mini

  • 概述:GPT-4o mini 是一个小型、成本效益高的模型,支持文本和视觉,并提供具有竞争力的推理和工具调用性能。它具有 128K token 上下文窗口 和令人印象深刻的 16K token 输出长度。它是 OpenAI 中最具成本效益的模型,每百万输入 tokens 的费用为 15 美分,每百万输出 tokens 的费用为 60 美分。OpenAI 提到,这一价格比其 2022 年的 text-davinci-003 模型便宜 99%,这表明在相对较短的时间内,模型趋向于更便宜、更小巧且功能更强大。虽然 GPT-4o mini 不像 GPT-4o 那样支持图像、视频和音频输入,但 OpenAI 报告称这些功能即将推出。与 GPT-4o 一样,GPT-4o mini 采用了内置的安全措施,并且是第一个应用 指令层次结构 方法 的 OpenAI 模型,该方法旨在使模型 更能抵抗提示注入和越狱。GPT-4o mini 使用与 GPT-4o 相同的分词器,从而实现 在非英语文本上的性能提升。在 GPT-4o mini 宣布后不久,OpenAI 还宣布了一个 实验性 64K token 输出 的 GPT-4o,用户可以通过其 Alpha 项目进行体验。

  • 可用性:GPT-4o mini 是一个闭源模型,通过 OpenAI 的 Assistants API、Chat Completions API 和 Batch API 提供。它也可以通过 Azure AI 访问。

  • 性能:GPT-4o mini 在多个基准测试中优于同类尺寸的 Gemini Flash 和 Claude Haiku 模型,包括MMLU(大规模多任务语言理解,旨在衡量推理能力)、MGSM(多语言小学数学,衡量数学推理能力)、HumanEval(衡量编程能力)和MMMU(大规模多学科多模态理解与推理基准,衡量多模态推理能力)。

该图由作者根据Feynmans Learning Method结果绘制

Llama 3.1

  • 概述:Llama 3.1 引入了128K token 上下文窗口,这是相较于 Llama 3(仅三个月前在四月发布)的 8K token 上下文窗口的重大进步。Llama 3.1提供三种规模:405B、70B 和 8B。它提升了推理、工具调用和多语言性能。Meta 在其 Llama 3.1 的公告中称Llama 3.1 405B 是“首个前沿级别的开源 AI 模型”。这展示了开源社区的巨大进步,也体现了 Meta 致力于使 AI 更具可访问性的承诺。马克·扎克伯格在他的文章“开源 AI 是前进之路”中详细讨论了这一点。Llama 3.1 的公告还包括了关于如何启用常见用例的指南,如实时和批处理推理、微调、RAG、持续预训练、合成数据生成和蒸馏。Meta 还发布了Llama Reference System,以支持开发者基于 Llama 3.1 进行代理性用例开发,此外还发布了其他AI 安全工具,包括 Llama Guard 3 用于多语言的输入和输出内容的监控、Prompt Guard 用于减轻提示注入攻击、以及 CyberSecEval 3 用于降低生成式 AI 的安全风险。

  • 可用性:Llama 3.1 是一个开源模型。Meta 已更改其许可证,允许开发者使用 Llama 模型的输出用于训练和改进其他模型。模型可以通过 HuggingFace、llama.meta.com 以及其他合作平台如 Azure AI 获取。

  • 性能:每个 Llama 3.1 模型在其规模类别中,几乎在所有常见的语言模型基准测试中,包括推理、编码、数学、工具使用、长上下文和多语言性能上,都优于其他模型。

该图由作者根据Meta Llama 3.1 公告结果绘制

生成式 AI 模型的趋势

总的来说,所有规模的模型呈现出越来越强大的趋势,拥有更长的上下文窗口、更长的令牌输出长度以及更低的价格点。推动推理、工具调用和编码能力的提升反映了对能够代表用户执行复杂操作的代理系统的需求不断增加。为了创建有效的代理系统,模型需要理解如何分解问题、如何使用可用的工具以及如何在一次处理大量信息时进行调和。

最近,OpenAI 和 Meta 的公告反映了 AI 安全讨论的日益增加,两家公司展示了应对同一挑战的不同方法。OpenAI 采取了闭源的方式,通过采纳社会心理学和虚假信息领域专家的反馈以及实施新的训练方法来提高模型的安全性。相比之下,Meta 加大了其开源计划的力度,发布了新的工具,专注于帮助开发者减轻 AI 安全问题。

这张图片由作者使用 GPT-4o 创建,展示了一个闭源和开源模型竞争的竞技场。

结论

未来,我认为我们将继续看到通用模型和专业模型的进展,像 GPT-4o 和 Llama 3.1 这样的前沿模型将越来越擅长于分解问题,并在跨模态执行各种任务时表现得越来越好,而像 Codestral Mamba 这样的专业模型将在其领域中表现出色,并在处理更长上下文和微妙任务时变得更加得心应手。此外,我预计我们将看到新的基准,专注于模型在单次回合内同时执行多个指令的能力,以及越来越多的 AI 系统通过结合通用模型和专业模型协作执行任务。

此外,虽然模型性能通常是根据标准基准来衡量的,但最终重要的是人类如何看待模型的表现,以及模型在多大程度上有效地推动人类目标的实现。Llama 3.1 的发布包含了一张有趣的图表,展示了人们如何对 Llama 3.1 与 GPT-4o、GPT-4 和 Claude 3.5 的回答进行评分。结果表明,Llama 3.1 在超过 50%的示例中与其他模型平局,剩余的胜率大致在 Llama 3.1 和其挑战者之间分配。这一点具有重要意义,因为它表明开源模型现在可以在一个以前由闭源模型主导的领域中轻松竞争。

想要进一步讨论或合作?请通过 LinkedIn与我联系!

穿越迷宫:数据架构指南

原文:towardsdatascience.com/navigating-the-maze-a-guide-to-data-architectures-05e2ba412225?source=collection_archive---------10-----------------------#2024-05-13

数据是当今任何组织的命脉。但是,如何确保你的数据组织有序、安全并且能满足所有业务需求呢?这就是数据架构的作用所在。

Antonio GrandinettiTowards Data Science Antonio Grandinetti

·发表于Towards Data Science ·阅读时长 5 分钟·2024 年 5 月 13 日

--

数据架构是指信息系统中数据的整体设计和组织。没有一种万能的🪄解决方案,但了解不同的选项能帮助你选择最适合你特定需求的方法。

在这篇文章中,我想向你展示一些关键的数据架构概念和常见的解决方案。

支柱

图片由Clark Van Der Beken提供,来源于Unsplash

让我们来看一下数据架构的每个支柱,了解它们在创建一个良好管理的数据生态系统中的作用:

数据存储: 这是基础部分,数据的物理存放位置。它包括多种选项,例如:

  • 硬盘驱动器(HDDs): 传统存储,提供高容量但成本较低,访问速度较慢。

  • 固态硬盘(SSDs): 更快的访问时间,适用于频繁访问的数据,但每单位存储的价格通常较高。

探索新的 LLM 代理和架构类型

原文:towardsdatascience.com/navigating-the-new-types-of-llm-agents-and-architectures-309382ce9f88?source=collection_archive---------0-----------------------#2024-08-30

图像由作者使用 Dall-E 创作

ReAct 代理的失败为新一代代理及其可能性铺平了道路

Aparna DhinakaranTowards Data Science Aparna Dhinakaran

·发表于 Towards Data Science ·10 分钟阅读·2024 年 8 月 30 日

--

感谢 John Gilhuly 为本篇文章的贡献

如果 2023 年是检索增强生成(Retrieval Augmented Generation, RAG)的一年,那么 2024 年就是代理的年份。全球各地的公司正在尝试聊天机器人代理,像 MultiOn 这样的工具通过将代理连接到外部网站而得到发展,像 LangGraph 和 LlamaIndex Workflows 这样的框架帮助世界各地的开发人员构建结构化代理。

然而,尽管代理受到了广泛关注,但它们仍未在人工智能生态系统之外取得显著的突破。无论是消费者还是企业用户中,只有少数代理得以推广。

团队如何在新的框架和代理方向中航行?有哪些工具可用,你应该使用哪些工具来构建下一个应用程序?作为一家最近在产品中构建了一个复杂代理作为副驾驶的公司的领导者,我们对此话题有一些见解。

定义代理

首先,定义我们所说的代理是非常重要的。基于 LLM 的代理是将多个处理步骤(包括对 LLM 的调用)串联起来的软件系统,以实现期望的最终结果。代理通常具有一定的条件逻辑或决策能力,并且拥有在各步骤之间可以访问的工作记忆。

让我们深入了解今天代理的构建方式、现代代理存在的问题以及一些初步的解决方案。

ReAct 代理的失败

让我们诚实一点,智能体的概念并不新鲜。在过去一年中,AI Twitter 上推出了无数智能体,宣称具有惊人的智能表现。这些第一代智能体主要是ReAct(推理,行动)智能体。它们的设计目标是尽可能地抽象化,并承诺能带来广泛的结果。

不幸的是,第一代智能体架构确实面临了很多挑战。它们的高度抽象化使得使用起来困难,尽管它们做出了宏大的承诺,但最终它们几乎什么都做不了

针对此情况,许多人开始重新思考智能体的结构方式。在过去的一年里,我们见证了巨大的进展,这也将我们引入了下一代智能体。

什么是第二代智能体?

这一代新的智能体建立在定义智能体可以走的路径的原则上,这些路径比 ReAct 的开放性方式更为严格。无论智能体是否使用框架,我们已经看到一种趋势,那就是智能体的解决方案空间变得更小——也就是说,每个智能体能做的事情减少了。更小的解决方案空间意味着智能体更容易定义,这通常也会导致更强大的智能体。

第二代智能体涵盖了许多不同类型的智能体,但值得注意的是,我们今天看到的大多数智能体或助手都是通过无框架的代码编写的,具有 LLM 路由阶段,并且以迭代循环的方式处理数据。

什么构成了一个智能体?

许多智能体都有一个叫做路由器的节点或组件,用来决定智能体应该执行哪个下一步。路由器通常是指由 LLM 或分类器做出的路径选择决策。智能体在执行过程中可能会不断返回这个路由器,每次都带上一些更新的信息。路由器会利用这些信息,结合它对可能的下一步的已有知识,选择接下来的行动。

路由器本身有时由对 LLM 的调用提供支持。目前大多数流行的 LLM 都支持函数调用,它们可以从函数定义的 JSON 字典中选择一个组件进行调用。这种能力使得路由步骤的初步设置变得容易。然而,正如我们稍后将看到的,路由器通常是智能体中需要最多改进的步骤,因此这种设置的简便性掩盖了其背后的复杂性。

智能体可以执行的每一个动作通常由一个组件表示。组件是完成特定小任务的代码块。这些组件可以调用 LLM,或进行多次 LLM 调用,发起内部 API 调用,或仅仅运行某种应用代码。在不同的框架中,它们有不同的名称。在 LangGraph 中,这些称为节点;在 LlamaIndex Workflows 中,它们被称为步骤。一旦组件完成任务,它可能会返回到路由器,或者转到其他决策组件。

根据代理的复杂性,将组件作为执行分支或技能进行分组可能会有所帮助。假设你有一个客户服务聊天机器人代理。这个代理能够做的事情之一是检查订单的运输状态。为了实现这一功能,代理需要从用户的查询中提取订单 ID,创建一个 API 调用到后端系统,执行该 API,解析结果,并生成回应。每一个步骤都可能是一个组件,它们可以被归类为“检查运输状态”技能。

最后,许多代理在执行过程中会追踪共享状态或记忆。这使得代理更容易在各个组件之间传递上下文。

代理架构的示例

今天我们在代理部署中看到一些常见的模式。我们将在接下来的部分中概述所有这些架构,但下面的示例可能是最常见的。

在其最简单的形式下,代理或助手可能仅由一个 LLM 路由器和一个工具调用来定义。我们将这个第一个示例称为单路由器与函数。我们有一个单一的路由器,可以是一个 LLM 调用、一个分类器调用,或仅仅是简单的代码,来指引和协调要调用哪个函数。其理念是,路由器可以根据系统的输入来决定调用哪个工具或功能。单路由器来自于我们在这个架构中只使用了一个路由器这一事实。

图示由作者提供

我们看到的一个稍微复杂一些的助手是单路由器与技能。在这种情况下,路由器不再仅仅调用一个简单的工具或函数,而是可以调用一个更复杂的工作流或技能集,这些技能集可能包括多个组件,是一系列更深的链式动作。这些组件(LLM、API、工具、RAG 和代码调用)可以被循环和链式连接,以形成一个技能。

这可能是目前我们看到的最常见的生产环境中高级 LLM 应用团队使用的架构。

图示由作者提供

一般架构通过将 LLM 调用、工具和状态的分支混合在一起,变得更加复杂。在这个案例中,路由器决定调用其哪些技能(用红色标注)来回答用户的问题。它还可能根据这个问题更新共享状态。每个技能也可以访问共享状态,并可能涉及一个或多个 LLM 调用,以获取对用户的回应。

图示由作者提供

这通常仍然比较简单,然而,代理通常要复杂得多。随着代理变得更加复杂,你会开始看到一些框架被构建出来,以帮助降低这种复杂性。

代理架构框架

LangGraph

LangGraph 基于预先存在的 Pregel 图概念,但将其转化为代理。在 LangGraph 中,你定义代理可以沿着其移动的节点和边缘。虽然在 LangGraph 中定义一个路由节点是可能的,但除非你正在处理多代理应用,否则通常是不必要的。相反,原本可以在路由器中存在的相同条件逻辑现在存在于 LangGraph 引入的节点和条件边缘对象中。

这是一个 LangGraph 代理的示例,它可以回应用户的问候,或者执行某种 RAG 信息查找:

图示由作者提供

在这里,路由逻辑实际上存在于节点和条件边缘中,这些边缘根据函数响应决定是否在不同节点之间移动用户。在这个例子中,is_greetingcheck_rag_response 是条件边缘。定义其中一个边缘如下所示:

graph.add_conditional_edges("classify_input", is_greeting, {True: "handle_greeting", False: "handle_RAG"})

我们不是将所有的路由逻辑集中在一个节点中,而是将它分布在相关的边缘之间。这在你需要为代理强加一个预定义结构,并且希望将各个逻辑部分分开时特别有用。

LlamaIndex 工作流

LlamaIndex 工作流 这样的其他框架采用不同的方法,使用事件和事件监听器在节点之间移动。与 LangGraph 类似,工作流不一定需要一个路由节点来处理代理的条件逻辑。相反,工作流依赖于各个节点,或者它们称之为“步骤”,来处理传入事件,并广播传出的事件,以便其他步骤处理。这导致大多数工作流逻辑都在每个步骤内部处理,而不是在步骤和节点之间处理。

作为 LlamaIndex 工作流的反射型 SQL 生成代理(图示由作者提供)

CrewAI、Autogen、Swarm 等

还有其他框架旨在简化代理开发,其中包括一些专注于处理多个代理协同工作的框架。这个领域发展迅速,值得查看这些及其他框架。

考虑代理时的关键问题

是否应该使用框架来开发你的代理?

无论你使用什么框架,这些工具提供的额外结构都能帮助你构建代理应用。使用这些框架是否对创建更大、更复杂的应用有益是一个更具挑战性的问题。

我们在这一领域有一个相当强烈的观点,因为我们自己也构建了一个助手。我们的助手使用多层路由架构,具有与当前框架的一些抽象概念相呼应的分支和步骤。在 LangGraph 稳定之前,我们就开始构建我们的助手。因此,我们不断自问:如果我们从零开始,我们会使用当前的框架抽象吗?它们能胜任这一任务吗?

当前的回答是“尚不需要”。整体系统的复杂性太大,不适合基于 Pregel 的架构。如果你眯着眼睛看,可以将其映射为节点和边,但软件抽象可能会阻碍这一点。就目前而言,我们团队更倾向于使用代码而非框架。

然而,我们确实看到了代理框架方法的价值。也就是说,它确实强制执行了一种架构,这种架构具有一些最佳实践和良好的工具。它们也在不断进步,扩展它们的适用范围以及你可以用它们做什么。随着这些框架的改进,我们的回答很可能在不久的将来会发生变化。

你真的需要一个代理吗?

这引出了另一个重要问题:究竟什么类型的应用程序需要代理?毕竟,代理涵盖了广泛的系统 —— 而且如今关于什么是“代理化”的话题也充满了炒作。

这里有三个标准来判断你是否可能需要代理:

  • 你的应用程序是否基于输入数据遵循迭代流?

  • 你的应用程序是否需要根据先前采取的行动或途中反馈来调整并遵循不同的流程?

  • 是否存在可以采取的行动的状态空间?该状态空间可以通过多种方式遍历,而不仅仅限于线性路径。

需要预期的常见问题有哪些?

假设你对其中一个问题回答是肯定的,并且需要一个代理。以下是你在构建时需要注意的几个已知问题。

第一个是长期规划。虽然代理很强大,但它们在将复杂任务分解为逻辑计划方面仍然存在困难。更糟的是,它们常常会陷入循环,无法找到解决方案。代理还在格式不正确的工具调用方面存在困难。这通常是由于驱动代理的底层 LLM 所致。在每种情况下,通常需要人工干预来纠正方向。

另一个需要注意的问题是由于解决方案空间的广阔而导致的性能不一致。代理可以采取的可能行动和路径的数量庞大,使得难以实现一致的结果,并且往往推高成本。也许正因为如此,市场正在趋向于受限代理,只能从一组可能的行动中进行选择,有效地限制了解决方案空间。

应对这些挑战的一些策略是什么?

如前所述,最有效的策略之一是事先映射或缩小解决方案空间。通过彻底定义可能的行动和结果范围,你可以减少模糊性。将领域和业务启发式融入代理的指导系统也是一个简单且有效的方法,它为代理提供了做出更好决策所需的背景。明确行动意图(清晰定义每个行动的目的)并创建可重复的过程(标准化代理遵循的步骤和方法)也能增强可靠性,并在出现错误时更容易识别和纠正。

最后,通过代码和更可靠的方法进行协调,而不是仅仅依赖于 LLM 规划,可以显著提高代理的性能。这涉及到在可能的情况下用基于代码的路由器替代 LLM 路由器。通过使用基于代码的协调,你可以实现更确定性和可控的过程,减少常伴 LLM 规划的不可预测性。

结论

在充满 FOMO 的新框架和狂热生成式 AI 环境中,过度炒作的情况下,很容易忽视基本问题。在全力投入 MVP 之前,花时间思考现代代理框架在何时何地可能适合——或不适合——你的使用场景,始终是值得的。

有问题吗?欢迎在此联系我或 在 Slack 上 ,或者在我们每两周一次的 AI 研究论文 阅读会中找到我。

探索作为数据科学家的现实

原文:towardsdatascience.com/navigating-the-realities-of-being-a-data-scientist-429625f9baed?source=collection_archive---------8-----------------------#2024-02-24

作为数据科学家,我经常面临的一些挑战

Egor HowellTowards Data Science Egor Howell

·发表于 Towards Data Science ·9 分钟阅读·2024 年 2 月 24 日

--

由 ThisIsEngineering 提供的照片,来自 Pexels:www.pexels.com/photo/female-software-engineer-coding-on-computer-3861972/

表面上看,成为数据科学家似乎一切都很美好(至少我觉得这是我从我的文章中给出的印象!)。高薪、优厚的福利、灵活的工作时间和有趣的工作内容是人们在想到数据科学工作时常常浮现的几个方面。

事实上,我甚至写过一篇文章,讲述为什么数据科学是一个很酷的职业!

[## 数据科学为何如此棒的五个理由

列出成为数据科学家时应该考虑的五个关键原因

pub.towardsai.net](https://pub.towardsai.net/5-reasons-why-data-science-is-awesome-af53c94d3d96?source=post_page-----429625f9baed--------------------------------)

虽然这些都是无可否认的事实,但每个工作背后都有一些隐藏的困难,数据科学家也不例外。别误会,它是一个非常棒的工作,我非常热爱这个领域,但并非每一天都是完全光辉灿烂的。

这就是为什么在这篇文章中,我想深入探讨一些你必须接受的现实或原则,这些也是我经常面临的挑战。希望这能帮助任何想成为数据科学家的人,决定这个领域是否适合他们。

探索 LLM 代理的世界:初学者指南

原文:towardsdatascience.com/navigating-the-world-of-llm-agents-a-beginners-guide-3b8d499db7a9?source=collection_archive---------0-----------------------#2024-01-10

一步一步指南,发现并利用 LLM 代理和工具包的力量

Dominik PolzerTowards Data Science Dominik Polzer

·发表于 Towards Data Science ·19 分钟阅读·2024 年 1 月 10 日

--

思维链提示——图片由作者提供

目录

简介

什么是代理?

我们需要什么来构建自己的智能助手?

我们如何连接所有必要的技能?—— 代理背后的理论

代理是如何工作的?—— 思维链

代理执行器——代理背后的代理

从理论到实践

如何使用 SQLDatabaseToolkit?

实践教程

总结

简介

这篇文章讲述的是如何让 LLM(大语言模型)独立解决复杂任务。

人类在不断吸收信息、做出决策、采取行动、观察变化,然后做出下一个决策的能力上非常出色。我们的一生就是一个永无止境的观察、思考和行动的链条。我们希望将这一概念应用于 LLM,使其能够持续做出新的决策,从而逐步接近复杂问题的解决方案。

驾驭数据平台的成长痛点:从数据混乱到数据网格的路径

原文:towardsdatascience.com/navigating-your-data-platforms-growing-pains-a-path-from-data-mess-to-data-mesh-c16df72f5463?source=collection_archive---------3-----------------------#2024-04-03

与软件开发团队不同,数据团队缺乏应对可扩展性挑战的成熟方法。本文提供了一套指导原则,帮助你有效扩展数据平台,同时最大化其商业影响。

Mahdi KarabibenTowards Data Science Mahdi Karabiben

·发表于Towards Data Science ·10 分钟阅读·2024 年 4 月 3 日

--

图片由Jack Anstey提供,来自Unsplash

在开发软件组件时,开发人员可以利用多种框架、设计模式和原则来扩展他们的产品,并无缝调整架构以支持新的用例,处理不断增加的使用量和复杂性。这使得软件工程团队能够确保随着平台(及其价值)规模的增长,优化性能和可靠性。

然而,数据团队并不那么幸运。尽管数据平台生命周期的最初几个月通常充满了解决复杂技术挑战的兴奋和交付第一批数据产品的喜悦,但接下来的往往是日益复杂的螺旋、成本的上升以及回报的递减。

与我们作为数据团队需要应对的其他问题不同,我们在可扩展性方面的挑战与软件团队所面临的挑战本质上不同。在数据领域,这些挑战表现为不可避免的技术复杂性(比如在不断扩展的系统列表中混合多种模式来移动和转换数据),以及数据平台在公司中的独特定位(因为最终每个业务部门都将直接或间接地与其连接)。

所以,在这个后 MDS 时代,数据团队在支出方面受到严格审查,并不断被要求展示其价值,定义标准和原则以成功扩展数据平台比以往任何时候都更为重要。本文将聚焦于实现这一目标的五项关键原则,并提供应用这些原则的策略。

1. 不要忽视重要的东西(即商业价值)

在大多数情况下,数据平台有潜力成为公司最有价值的资产之一。然而,在证明其价值的过程中,不幸的是,数据团队常常专注于做什么而不是为什么做

如果你浏览网上的相关数据内容,你大多会看到复杂的数据架构,或者是对某些技术的赞扬/批评。另一方面,专注于支持下游用例或衡量数据项目影响的内容相对较少。我认为这是数据团队需要解决的一个关键问题的症状:我们很容易过于专注于我们正在构建的东西,而忽视了我们为什么要构建它

这个问题最终可能导致数据团队走上“为了构建而构建”的道路。相反,更具可扩展性的方法是寻找能够创造价值的数据项目,并持续反思(并在可能时衡量)所创造的价值。为此,我推荐以下四种策略:

找出哪些数据是重要的(并非所有数据都同等重要)

由于今天的技术使得生成、存储和转换海量数据变得极其容易,数据团队很快会被他们需要支持的快速增长的数据集所淹没。然而,您应该给予数据集的关注程度必须基于其重要性,这一重要性是通过将其与下游用例相连接来确定的。

将每个项目/计划与它将创造的价值联系起来

在一个数据项目获得批准之前,应该评估它将为业务带来多少价值(直接或间接)。这一重要的工作确保了与利益相关者的一致性,并允许你将数据项目映射到下游的倡议。为数据管道和其他类型的数据组件编写设计文档是确保你所构建的内容与明确的业务目标和指标一致的好方法。

持续寻找新的高影响力使用案例

数据团队通常依赖于从其他部门和业务单元接收请求和潜在项目。尽管依赖这些流程可能很有吸引力,但你应该始终关注那些可以利用数据的潜在价值创造领域。无论是内部数据产品还是可以嵌入公司产品中的数据应用,关键的使用案例可能会因为数据团队在自己的小角落里构建事物而被忽视。

经常向用户索取反馈

数据团队很容易集中注意力于错误的指标。例如,使用量的增加可能让我们感到有所成就。然而,在数据驱动的环境中,人们需要全方位的数据——因此,使用量的增加并不一定意味着一切顺利,或者使用量正在转化为价值创造的使用案例。随着平台的扩展,问题最终会显现出来(从数据质量问题到开发速度缓慢),而与用户的脱节意味着你将不得不追赶上去,才能重新获得他们的信任。相反,尝试经常要求反馈,无论是通过你提供的工具中的嵌入式反馈机制,还是每隔几个月发送给用户的简单反馈表单。

2. 自动化标准是你最好的朋友

扩展数据平台的最大挑战之一是被拉向多个方向。无论是由于需要新数据组件的项目涌入,还是无休止的临时请求,数据团队的旅程中总有一天会遇到瓶颈。那时,通常会做出一个不幸的决定:采用自助服务模式,但没有为其设定正确的基础。这意味着其他团队可以构建自己的数据集或管道来生成洞察——这一切开始时看起来不错,直到有人最终意识到没有任何数据集可以再信任。

采用没有治理的自助服务方法的陷阱(作者图片)

这里的建议不是避免自助服务。相反,你应该定义事情应该如何做的标准,并尽可能自动化其执行。这些标准可以包括像表格/列/模型命名规范和强制文档化等基本事项,或者更细致的做法,如强制测试和只在生产管道中使用受管控的数据集。

通过应用这样的标准——例如通过持续集成(CI)自动化——可以确保最小程度的一致性,并避免数据平台最终变成数据混乱或沼泽的情况。

然而,重要的是要注意,你选择应用的标准应该有明确的商业理由,并且不应在没有实际好处的情况下影响开发速度。标准化和迭代/交付速度之间的正确平衡取决于公司的背景和数据使用的具体场景,但在所有情况下,至少需要有最基本的标准化。

3. 经常问自己,是否拥有适合当前规模的工具

作为一名工程师,我知道很容易对你为了做 X 而构建的整洁的包,或是你用来做 Y 的开源项目(也许你还为其做出了贡献)产生过度依赖。然而,随着数据平台规模的增长,定期重新审视你的架构,找出工具需要升级(或降级)之处,是一项重要的工作。

迁移当然有成本,而且总是很棘手——但是在合适的时机进行迁移并且拥有一个替换平台组件的操作手册是每个数据团队必须掌握的关键技能。为了评估哪些工具可能不适合你当前的规模,我认为对于你的技术栈中的每个组件,必须有一个最新的、高层次的答案来回答以下问题:

  • 它的成本是多少?(包括维护它所需的工程时间、它的基础设施成本、定价等)这个成本我们是否能够接受?这个成本在明年内会如何变化?

  • 我们是否在它上面投入了资源(工程时间、资金等),而如果把这些资源投入到其他地方,可能会带来更多的商业价值?

  • 它是否具备我们当前所需的所有主要功能,以及我们在明年需要的功能?如果没有,是否有其他能够提供我们所需功能的选项?如果有这样的选项,迁移成本是多少?

通过进行这样的练习,你一定会遇到一些可以完全淘汰的领域(例如那些没有提供任何有意义价值的管道或工具),一些需要用另一个系统替换的领域(可能你构建了一个实时摄取过程,但所有当前对业务至关重要的消费场景只需要批处理),还有一些需要升级的领域(比如将内建的数据移动过程迁移到一个功能更全的成熟工具)。

作为一个更实际的例子,尽管 dbt 测试是确保关键资产数据质量的良好起点,并让消费者信任他们所使用的数据,但当你将更多系统添加到数据平台并开始支持生产(甚至是操作)用例时,它们的效用很快就会达到极限。一旦你在数据事件和调试数据问题上花费的精力开始影响你的路线图,也许是时候从 dbt 测试升级到数据可观察性组件了。

4. 不要过于雄心勃勃(优先解决你的问题)

这一原则与上面提到的原则是相辅相成的。就像在需要时执行迁移至关重要一样,避免不必要的迁移更为重要。

当前数据领域的发展速度意味着新技术、新范式和新架构是不断变化的。科技巨头们在不断重新思考他们的数据处理方式,不管你如何努力,平台的大部分部分总会落后于数据先行者几步。然而,这并不意味着坏事。

在你反思当前技术状态并尝试识别平台中的差距时,指导思想应该是你通过升级所期望生成的商业价值。没错,能力 X 在纸面上听起来不错,但它具体解决了哪些业务场景?你又可以将迁移工作投入到哪些其他项目中呢?

我相信保持与行业前沿同步,了解其他数据团队的历程非常重要,但决定采用新工具或新范式时,始终需要基于预期的商业价值。此外,始终牢记业务优先级,并随着数据团队周围环境的变化进行相应的更新,也是至关重要的。

5. 将部分责任向左(或右)转移

一旦覆盖了上述领域,尤其是当数据平台开始支持高影响力的项目时,数据团队可以开始转变为平台团队角色。这意味着你可以自信地开始在所有权上增添更多的细节,并将数据团队从管理业务指标和源数据问题的负担中解脱出来。

这次过渡的时机至关重要,因为责任的去中心化只有在你已经建立了正确的基础设施的前提下才能成功:

  • 使数据平台成为一个对业务成功至关重要的价值生成系统(这样其他团队也会被激励投入时间进行与数据相关的工作)

  • 标准化不同的流程,定义清晰的接口,减少平台贡献团队(无论是生产者、消费者还是贡献者)需要做的猜测工作

  • 拥有正确的工具集,能够让你轻松引入新团队,并驾驭更复杂的结构

此外,与其一次性将所有数据产品过渡到去中心化模式,不如分阶段进行过渡,这样你成功的几率更高。首先,从识别最关键的数据集和数据管道开始,你可以:

  • 与数据生产者建立契约

  • 将指标定义的所有权移交给相应的业务团队

一旦第一个计划成功,你可以根据业务的重要性,开始扩大符合新模式的数据管道范围。完全不需要将这个模式应用于所有数据集——如果一个数据管道还没有明确的业务用例,它就不需要进行严格的治理。

大规模运作的数据平台示意图(图片由作者提供)

上述要点,最终是本文讨论的前四个原则的总结。许多去中心化/数据网格项目失败,要么是因为缺乏正确的基础,要么是因为数据团队试图从完全集中化转向去中心化的过渡过于乐观(甚至不切实际),而不是分阶段进行转变。

在本文中,我们探讨了成功扩展数据平台并将其转变为一个具有价值生成能力的组成部分的五个原则。它们的应用不要求使用特定的工具或技术,而是可以在进行数据项目时普遍应用的指导方针(除了一些特定的边缘案例/小众行业)。

我们讨论的最终状态是一个混合设置,结合了集中式方法(数据团队拥有端到端的平台)和去中心化/数据网格方法(不同团队构建并拥有其数据产品的端到端),这种设置提升了你成功扩展平台的机会,尤其是在你从一种方法过渡到另一种方法时——但其是否适合你的团队,也取决于你所在的行业和运营的具体背景。

最后,如果只记住一个原则/规则,那就记住这一点:定期重新评估你的方法和你所走的道路是极其重要的。从 0 到 1 的过程与从 1 到 10(以及之后的过程)是截然不同的。 这是一个普遍适用的原则,我认为它对于数据平台尤为重要。

如果你想获取更多数据工程方面的内容,可以订阅我的通讯《Data Espresso》,在其中我讨论与数据工程和技术相关的各种话题:

[## Data Espresso | Mahdi Karabiben | Substack

数据工程更新和评论,伴随着你下午的浓缩咖啡。点击阅读《Data Espresso》,作者:Mahdi…

dataespresso.substack.com

速度需求:cuDF Pandas vs. Pandas

原文:towardsdatascience.com/need-for-speed-cudf-pandas-vs-pandas-16b87009aefa?source=collection_archive---------4-----------------------#2024-04-05

图片由作者提供(Dalle-3)

比较概述

Thomas ReidTowards Data Science Thomas Reid

·发表于 Towards Data Science ·阅读时间 11 分钟·2024 年 4 月 5 日

--

什么是 cuDF Pandas?

如果你是 Python 中 Pandas 库的用户,并且希望或需要最大化程序的运行时间,那么你有几个可供选择的选项。这些选项大多涉及使用外部库,取代现有的 Pandas 操作,并对大规模、高速数据处理进行了优化。此类库的例子包括 VAEX、POLARS、DuckDB 等。问题在于,这些库通常要求你在某种程度上重写代码,这可能不是你想做的,或者你没有能力去做。

如果你的系统幸运地配有 GPU,那么另一个更为近期的可选项就是 cuDF.pandas

cuDF.pandas 构建于 cuDF 之上,它是一个基于 GPU 的 Python DataFrame 库(基于 Apache Arrow 列存储内存格式),用于加载、连接、聚合、过滤以及其他数据操作。

要使用 cuDF Pandas,你只需在命令行运行 Python 时提供一个标志,或者在通过 Jupyter Notebook 运行 Python 时加载一个扩展。当支持 GPU 计算时(例如,有 NVIDIA GPU 可用,且 cuDF 知道如何运行 Pandas 代码),你的代码将在 GPU 上运行。在无法支持的情况下,cuDF…

速度需求:Streamlit 与 Functool 缓存

原文:towardsdatascience.com/need-for-speed-streamlit-vs-functool-caching-eb3b7426f209?source=collection_archive---------4-----------------------#2024-08-28

比较 pandas 和 polars 中 streamlit 缓存与 functools 缓存的性能。结果会让你惊讶!

Jose ParreñoTowards Data Science Jose Parreño

·发表于 Towards Data Science ·12 分钟阅读·2024 年 8 月 28 日

--

图片由 Oscar Sutton 提供,来源于 Unsplash

Streamlit 是我构建概念验证演示和分析仪表盘的默认框架。该框架的简洁性允许快速开发和容易维护。然而,简洁性的暗面是,它带有内置的设计假设,使其很难作为一款顶级生产工具使用。我们稍后会详细讨论这些问题,但这些假设的结果是,Streamlit 在处理和渲染应用时可能会非常慢

在这篇文章中,我想向你展示提高 Streamlit 应用速度的 2 种方法:使用内置的Streamlit 缓存功能和使用内置的functools 缓存功能。这两种方法都基于缓存的概念,即如果某个操作之前已经执行过,那么输出会被保存下来以便之后重用。

在进入结果之前,我认为理解以下 3 个基本理论是很重要的:Streamlit 是如何工作的,Streamlit 缓存如何工作,以及 functools 缓存是如何在后台运行的。

附言:所有图片均由我创作,除非另有说明。

Streamlit 会重新执行所有内容……

NeMo Guardrails:终极开源 LLM 安全工具包

原文:towardsdatascience.com/nemo-guardrails-the-ultimate-open-source-llm-security-toolkit-0a34648713ef?source=collection_archive---------3-----------------------#2024-02-09

探索 NeMo Guardrails 的实际应用案例

Wenqi GlantzTowards Data Science Wenqi Glantz

·发表于Towards Data Science ·阅读时间 13 分钟·2024 年 2 月 9 日

--

图片由 DALL-E 3 生成,由作者提供

在 LLM 安全性话题中,我们已经从不同角度探讨了OWASP top 10 for LLM applications、Llama Guard 和 Lighthouz AI。今天,我们将探索NVIDIA 开发的开源工具包 NeMo Guardrails,该工具包可轻松为基于 LLM 的对话系统添加可编程防护措施。

NeMo Guardrails vs. Llama Guard

NeMo Guardrails 与我们在上一篇文章中深入探讨过的 Llama Guard 有何不同?让我们并排比较它们的功能。

表格由作者提供

正如我们所看到的,Llama Guard 和 NeMo Guardrails 在本质上是不同的:

  • Llama Guard 是一个大型语言模型,基于 Llama 2 进行微调,并且是一个输入输出安全防护模型。它包含六个不安全类别,开发者可以通过添加额外的不安全类别来定制这些类别,以适应他们的输入输出审查需求。

  • NeMo Guardrails 是一个更为全面的 LLM 安全工具集,提供了更广泛的可编程功能…

使用 🤗 Accelerate 进行捷克文档中的 NER 任务,基于 XLM-RoBERTa 模型

原文:towardsdatascience.com/ner-in-czech-documents-with-xlm-roberta-using-accelerate-32a6baf3e91e?source=collection_archive---------9-----------------------#2024-11-12

在开发一个成功部署的文档处理模型过程中,我做出的决策

Bohumir BusoTowards Data Science Bohumir Buso

·发布于 Towards Data Science ·阅读时间:9 分钟·2024 年 11 月 12 日

--

图像由 Dall-E 生成

尽管我在机器学习项目中已有超过 8 年的经验,但这是我第一次从事 NLP 项目。我最初寻找现有的资源和代码,但发现相关材料非常有限,特别是关于捷克语文档中的 NER 任务。这激发我将开发过程中学到的所有知识整理到一个地方,希望能帮助未来的新人更高效地进步。因此,本文提供的是一个实用的入门介绍,而非深入的理论分析

由于敏感性考虑,特定的数值结果已被省略。

数据

包含实体的文档示例:债权人变量符号(红色)、姓氏(浅绿色)、名字(深绿色)、出生日期(蓝色)。敏感信息已被屏蔽。

任务 主要目标是通过以下任一标识符来识别与每个文档相关联的客户:

  • 债权人变量符号(约 20%的文档中出现)

  • 出生身份证明(约 60%的文档中出现)

  • 组合 姓名 + 姓氏 + 出生日期(约 50%的文档中出现)

大约 5%的文档没有包含任何识别实体。

数据集 在开发过程中,我使用了 710 个“真实”PDF 文档,将其分为三个集:600 个用于训练,55 个用于验证,55 个用于测试。

标签 我收到了一个包含实体提取为纯文本的 Excel 文件,需要手动标注文档文本。我使用了BIO标注格式,按照以下步骤进行操作:

  1. 打开每个文档(使用extract_text()函数,来自pdfminer.high_level模块)

  2. 将文本拆分为单词(使用 SpaCy 模型“xx_sent_ud_sm”,并做一些调整,如防止在连字符上拆分,以处理出生号码格式,例如‘84–12–10/7869’)

  3. 识别文本中的实体

  4. 为实体分配相应的标签,使用“O”标签标注所有其他词汇

替代方法 类似 LayoutLM 的模型,考虑输入标记的边界框,可能会提高质量。然而,我避免了这个选项,因为和往常一样(😮‍💨),我已经在数据准备上花费了大部分项目时间(例如,重新格式化 Excel 文件、修正数据错误、标注)。追求基于边界框的模型会需要更多的时间。

尽管正则表达式和启发式方法理论上可以处理这些简单的实体,我认为这种方法会无效,因为它需要过于复杂的规则才能准确地在其他潜在候选实体中识别出正确的实体(如律师名字、案件编号、其他诉讼参与者等)。另一方面,模型能够学习区分相关实体,因此不需要使用启发式方法。

模型(训练)

🤗 Accelerate

由于在封装器较不常见的时期开始工作,我习惯于编写自己的训练循环,这让我更容易调试 - 这一方法得到了🤗 Accelerate 的有效支持。这在这个项目中证明是有益的——我并不完全确定所需的数据和标签格式或形状,而我的数据与教程中常见的组织良好的示例不符,但在训练循环中完全访问中间计算使我能够快速迭代。

上下文长度 大多数教程建议将每个句子作为一个单独的训练示例。然而,在这个案例中,我决定更长的上下文更为合适,因为文档通常包含对多个实体的引用,其中许多是无关的(例如律师、其他债权人、案件编号)。这种更广泛的上下文帮助模型更好地识别相关客户。我使用每个文档的 512 个标记作为一个训练示例。这是大多数模型的常见最大限制,但足以容纳我文档中的所有实体。

子标记的标注 在🤗标记分类教程[1]中,推荐的方法是:

只标注给定单词的第一个标记。对同一单词的其他子标记分配*-100*

然而,我发现以下方法在他们的 NLP 课程中的标记分类教程[2]中效果更好:

每个词元都得到与它所在单词开头词元相同的标签,因为它们是同一实体的一部分。对于单词内部但不是开头的词元,我们将*B-*替换为*I-*

标签“-100”是一个特殊标签,它会被损失函数忽略。因此,我对他们的函数做了小的修改:

def align_labels_with_tokens(labels, word_ids):
    new_labels = []
    current_word = None
    for word_id in word_ids:
        if word_id != current_word:
            # Start of a new word!
            current_word = word_id
            label = -100 if word_id is None else labels[word_id]
            new_labels.append(label)
        elif word_id is None:
            # Special token
            new_labels.append(-100)
        else:
            # Same word as previous token
            label = labels[word_id]
            # If the label is B-XXX we change it to I-XXX
            if label % 2 == 1:
                label += 1
            new_labels.append(label)

    return new_labels

def tokenize_and_align_labels(examples):
    tokenizer = AutoTokenizer.from_pretrained("../model/xlm-roberta-large")
    tokenized_inputs = tokenizer(
        examples["tokens"], truncation=True, is_split_into_words=True,
        padding="max_length", max_length=512)
    all_labels = examples["ner_tags"]
    new_labels = []
    for i, labels in enumerate(all_labels):
        word_ids = tokenized_inputs.word_ids(i)
        new_labels.append(align_labels_with_tokens(labels, word_ids))

    tokenized_inputs["labels"] = new_labels
    return tokenized_inputs

我还使用了他们的postprocess()函数:

为了简化其评估部分,我们定义了这个*postprocess()*函数,它接收预测值和标签,并将它们转换为字符串列表。

def postprocess(predictions, labels):
    predictions = predictions.detach().cpu().clone().numpy()
    labels = labels.detach().cpu().clone().numpy()

    true_labels = [[id2label[l] for l in label if l != -100] for label in labels]
    true_predictions = [
        [id2label[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    return true_predictions, true_labels

类别权重

将类别权重纳入损失函数显著提高了模型性能。虽然这个调整看起来很简单——没有它,模型过于强调了“O”类别这个多数类——但令人惊讶的是,大多数教程中都没有提到这一点。我实现了一个自定义的compute_weights()函数来解决这个不平衡问题:

def compute_weights(trainset, num_labels):
    c = Counter()
    for t in trainset:
        c += Counter(t['labels'].tolist())
    weights = [sum(c.values())/(c[i]+1) for i in range(num_labels)]
    return weights

训练循环 我定义了两个额外的函数:PyTorch 的DataLoader()来管理批处理,以及一个main()函数来设置分布式训练对象并执行训练循环。

from accelerate import Accelerator, notebook_launcher
from collections import Counter
from datasets import Dataset
from datetime import datetime
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from transformers import AutoModelForTokenClassification
from transformers import XLMRobertaConfig, XLMRobertaForTokenClassification
from seqeval.metrics import classification_report, f1_score

def create_dataloaders(trainset, evalset, batch_size, num_workers):
    train_dataloader = DataLoader(trainset, shuffle=True, 
                          batch_size=batch_size, num_workers=num_workers)
    eval_dataloader = DataLoader(evalset, shuffle=False, 
                          batch_size=batch_size, num_workers=num_workers)
    return train_dataloader, eval_dataloader

def main(batch_size, num_workers, epochs, model_path, dataset_tr, dataset_ev, training_type, model_params, dt):
    accelerator = Accelerator(split_batches=True)
    num_labels = model_params['num_labels']

    # Prepare data #
    train_ds = Dataset.from_dict(
                {"tokens": [d[2][:512] for d in dataset_tr], 
                 "ner_tags": [d[1][:512] for d in dataset_tr]})
    eval_ds = Dataset.from_dict(
                {"tokens": [d[2][:512] for d in dataset_ev],
                 "ner_tags": [d[1][:512] for d in dataset_ev]})
    trainset = train_ds.map(tokenize_and_align_labels, batched=True,
                 remove_columns=["tokens", "ner_tags"])
    evalset = eval_ds.map(tokenize_and_align_labels, batched=True,
                 remove_columns=["tokens", "ner_tags"])
    trainset.set_format("torch")
    evalset.set_format("torch")
    train_dataloader, eval_dataloader = create_dataloaders(trainset, evalset,
                                          batch_size, num_workers)

    # Type of training #
    if training_type=='from_scratch':
        config = XLMRobertaConfig.from_pretrained(model_path, **model_params)
        model = XLMRobertaForTokenClassification(config)
    elif training_type=='transfer_learning':
        model = AutoModelForTokenClassification.from_pretrained(model_path, 
                    ignore_mismatched_sizes=True, **model_params)
        for param in model.parameters():
            param.requires_grad=False
        for param in model.classifier.parameters():
            param.requires_grad=True
    elif training_type=='fine_tuning':
        model = AutoModelForTokenClassification.from_pretrained(model_path,
                     **model_params)
        for param in model.parameters():
            param.requires_grad=True
        for param in model.classifier.parameters():
            param.requires_grad=True

    # Intantiate the optimizer #
    optimizer = torch.optim.AdamW(params=model.parameters(), lr=2e-5)

    # Instantiate the learning rate scheduler #
    lr_scheduler = ReduceLROnPlateau(optimizer, patience=5)

    # Define loss function #
    weights = compute_weights(trainset, num_labels)
    loss_fct = CrossEntropyLoss(weight=torch.tensor(weights))

    # Prepare objects for distributed training #
    loss_fct, train_dataloader, model, optimizer, eval_dataloader, lr_scheduler = accelerator.prepare(
        loss_fct, train_dataloader, model, optimizer, eval_dataloader, lr_scheduler)

    # Training loop #
    max_f1 = 0 # for early stopping
    for t in range(epochs):
        # training
        accelerator.print(f"\n\nEpoch {t+1}\n-------------------------------")
        model.train()
        tr_loss = 0
        preds = list()
        labs = list()
        for batch in train_dataloader:
            outputs = model(input_ids=batch['input_ids'],
                            attention_mask=batch['attention_mask'])
            labels = batch["labels"]
            loss = loss_fct(outputs.logits.view(-1, num_labels), labels.view(-1))
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
            tr_loss += loss
            predictions = outputs.logits.argmax(dim=-1)
            predictions_gathered = accelerator.gather(predictions)
            labels_gathered = accelerator.gather(labels)
            true_predictions, true_labels = postprocess(predictions_gathered, labels_gathered)
            preds.extend(true_predictions)
            labs.extend(true_labels)

        lr_scheduler.step(tr_loss)

        accelerator.print(f"Train loss: {tr_loss/len(train_dataloader):>8f} \n")
        accelerator.print(classification_report(labs, preds))

        # evaluation
        model.eval()
        ev_loss = 0
        preds = list()
        labs = list()
        for batch in eval_dataloader:
            with torch.no_grad():
                outputs = model(input_ids=batch['input_ids'],
                                attention_mask=batch['attention_mask'])
                labels = batch["labels"]
                loss = loss_fct(outputs.logits.view(-1, num_labels), labels.view(-1))

            ev_loss += loss
            predictions = outputs.logits.argmax(dim=-1)
            predictions_gathered = accelerator.gather(predictions)
            labels_gathered = accelerator.gather(labels)
            true_predictions, true_labels = postprocess(predictions_gathered, labels_gathered)
            preds.extend(true_predictions)
            labs.extend(true_labels)

        accelerator.print(f"Eval loss: {ev_loss/len(eval_dataloader):>8f} \n")
        accelerator.print(classification_report(labs, preds))

        accelerator.print(f"Current Learning Rate: {optimizer.param_groups[0]['lr']}")

        # checkpoint best model
        if f1_score(labs, preds) > max_f1:
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(f"../model/xlml_ner/{dt}/",
                               is_main_process=accelerator.is_main_process,
                               save_function=accelerator.save)
            accelerator.print(f"Model saved during {t+1}. epoch.")
            max_f1 = f1_score(labs, preds)
            best_epoch = t

        # early stopping
        if (t - best_epoch) > 10:
            accelerator.print(f"Early stopping after {t+1}. epoch.")
            break

    accelerator.print("Done!") 

一切准备就绪,模型已经可以开始训练了。我只需要启动这个过程:

label_list = [
    "O",
    "B-evcu", "I-evcu", # variable symbol of creditor
    "B-rc", "I-rc", # birth ID
    "B-prijmeni", "I-prijmeni", # surname
    "B-jmeno", "I-jmeno", # given name
    "B-datum", "I-datum", # birth date
]
id2label = {a: b for a,b in enumerate(label_list)}
label2id = {b: a for a,b in enumerate(label_list)}

num_workers = 6 # number of GPUs
batch_size = num_workers*2
epochs = 100
model_path = "../model/xlm-roberta-large"
training_type = "fine_tuning" # from_scratch / transfer_learning / fine_tuning
model_params = {"id2label": id2label, "label2id": label2id, "num_labels": 11}
dt = datetime.now().strftime("%Y%m%d_%H%M%S")
os.mkdir(f"../model/xlml_ner/{dt}")

notebook_launcher(main, args=(batch_size, num_workers, epochs, model_path,
                   dataset_tr, dataset_ev, training_type, model_params, dt),
                   num_processes=num_workers, mixed_precision="fp16", use_port="29502")

我发现使用 notebook_launcher() 很方便,因为它允许我在控制台中运行训练,并且之后可以轻松处理结果。

XLM-RoBERTa 基础版 vs 大型版 vs Small-E-Czech 我尝试了微调三种模型。XLM-RoBERTa 基础版模型[3]表现令人满意,但服务器容量也允许我尝试 XLM-RoBERTa 大型版模型[3],它的参数量是前者的两倍。

XLM-RoBERTa 是 RoBERTa 的多语言版本。它是在 2.5TB 过滤后的 CommonCrawl 数据上预训练的,包含 100 种语言。

大型模型在结果上略有提升,所以我最终部署了它。我也测试了 Small-E-Czech [4],一个在捷克网页数据上预训练的 Electra-small 模型,但它的表现很差。

微调 vs 迁移学习 vs 从头训练 除了微调(更新所有模型权重)之外,我还测试了迁移学习,因为有时会建议只训练最后一层(分类层)就足够了。然而,性能差异显著,微调更为优越。我还尝试了从头训练,通过仅导入模型的架构,随机初始化权重,然后进行训练,但如预期的那样,这种方法效果不佳。

RoBERTa vs LLM(Claude 3.5 Sonnet) 我简要探索了零样本 LLM,尽管我几乎没有进行提示工程(所以有点🥱)。模型在处理基本请求时表现困难,比如(我在实际提示中使用了捷克语):

查找债权人变量符号。这个数字由恰好 9 位连续数字(0–9)组成,不包含字母或其他特殊字符。通常前面会有以下缩写之一:‘ev.č.’、‘zn. opr’、‘VS. O’、‘evid. č. opr.’。相反,我对缩写为 ‘č.j.’ 的交易号不感兴趣。这个数字在文档中出现的频率较低,可能你找不到它,若无法找到,写上‘cannot find’。如果不确定,写上‘not sure’。

模型有时无法准确输出 9 位数字格式。后处理会过滤掉较短的数字,但也会出现很多误报的 9 位数字。

有时模型仅基于出生日期推测错误的 出生身份证号(即使温度设置为 0)。另一方面,模型在提取 姓名姓氏出生日期 方面表现优秀。

总体而言,即使在我之前的实验中,我发现大语言模型(LLMs)(在撰写时)在一般任务上表现较好,但在特定或非常规任务上的准确性和可靠性较差。在客户识别方面,两种方法的表现相差无几。由于内部原因,最终部署了 RoBERTa 模型。

后处理

值得注意的是,实施 后处理可以显著减少误报,提高整体性能。每个实体都遵循定制的过滤和验证规则:

  • 债务人变量符号 - 验证 9 位数字格式

  • 出生身份证号 - 强制使用 XXXXXX/XXX(X) 格式,并检查是否能被 11 整除

  • 姓名姓氏 - 使用 MorphoDiTa [5] 进行词形还原

  • 出生日期 - 强制使用 DD.MM.YYYY 格式

结论

微调后的模型成功部署,表现卓越在仅有 710 篇文档的 modest 数据集上超出预期表现

尽管大语言模型在一般任务中显示出潜力,但它们在专业任务上的准确性和可靠性不足。也就是说,随着大语言模型的持续改进,未来即便是微调,除非是高度专业化的任务,否则可能变得不再必要。

致谢 我想感谢 MartinTomášPetr 对本文章提出的宝贵建议。

来源 [1] Hugging Face,Transformers - Token classification

[2] Hugging Face,NLP Course — Token classification

[3] A. Conneau, K. Khandelwal, N. Goyal, V. Chaudhary, G. Wenzek, F. Guzman, E. Grave, M. Ott, L. Zettlemoyer 和 V. Stoyanov,无监督跨语言表示学习 at Scale(2019),CoRR abs/1911.02116

[4] M. Kocián, J. Náplava, D. Štancl 和 V. Kadlec,基于 Siamese BERT 的网络搜索相关性排序模型,在新的捷克数据集上进行评估(2021)

[5] J. Straková, M. Straka 和 J. Hajič. 用于形态学、词形还原、词性标注和命名实体识别的开源工具(2014),发表于《第 52 届计算语言学协会年会:系统展示论文集》,第 13–18 页,美国马里兰州巴尔的摩,2014 年 6 月。计算语言学协会。

网络分析、扩散模型、数据湖仓及更多:我们近期的最佳深度探讨

原文:towardsdatascience.com/network-analysis-diffusion-models-data-lakehouses-and-more-our-best-recent-deep-dives-927c5a9063b9?source=collection_archive---------7-----------------------#2024-11-14

TDS 编辑Towards Data Science TDS 编辑

·发表于Towards Data Science ·作为时事通讯发送 ·4 分钟阅读·2024 年 11 月 14 日

--

想写下你的第一篇 TDS 文章吗?我们始终欢迎新作者的投稿

我们在深度探讨页面上展示的文章包括前沿研究的详细讲解、数学概念的解析,以及关于构建和部署基于 LLM 工具的耐心教程。它们共同代表了我们一些最有深度、最有思考的故事。

本周,我们邀请社区成员从日常生活的快节奏中稍作停顿,抽出时间探索一系列近期的深度探讨——所有这些都提供了对关键数据科学和机器学习话题的细致见解。

你是否有心情动手搞点代码?还是更愿意思考一些关于人工智能的重大问题,这些问题正在塑造相关辩论?无论哪种方式,我们都为你准备好了:我们在本期《Variable》特别版中策划的内容涵盖了多个领域,提供了多个进入复杂(且迷人)对话的入口。选择你自己的冒险之旅吧!

  • 生成式 AI 安全框架:文本到图像模型的缓解堆栈指南 “鉴于与图像生成和图像修复功能相关的潜在风险,有必要在模型生命周期的不同阶段建立一个强大的安全缓解堆栈。” Trupti Bavalatti 详细解析了当前可用的不同方法,以应对生成式 AI 图像工具中固有的风险。

  • 去中心化推理中的 DCIN 温和介绍 去中心化协作智能网络是如何工作的?Marcello Politi 介绍了他和他的团队最近几个月的研究成果:一个由节点组成的网络,通过共享计算能力执行开源模型的推理,“计算是动态高效地分布的,同时保持高安全性,并且通过共享计算奖励用户。”

  • 悲惨世界社交网络分析:使用 Marimo 笔记本和 NetworkX Python 库 网络分析在日常数据科学工作流中有着广泛的应用——它也可以帮助我们在艺术作品中发现模式和关系。举个例子:Maria Mouschoutzi, PhD 的一个引人入胜的项目,该项目依赖于 NetworkX 库和 Marimo 笔记本,研究维克多·雨果的《悲惨世界》中所展现的复杂社交关系网。无论你是支持瓦尔让还是支持贾维尔,还是仅仅对学习新的数据科学工具感兴趣,都应该将其加入你的阅读清单。

图片由 boris misevic 提供,来自 Unsplash

  • 让光明降临!扩散模型与重新照明的未来 扩散模型几年前首次引起广泛关注,自那时以来,研究人员投入了大量时间和精力来优化其性能。 Pulkit Gera 提供了对这些模型在一个关键方面——重新照明的最新研究综述,重新照明是指“在给定输入场景的基础上,在特定目标照明条件下渲染场景的任务。”

  • 从蒙提霍尔问题中学习决策从一个新颖且发人深省的角度来看待一个知名话题总是值得的。这正是你将在Eyal Kazin的这篇简明而全面的《蒙提霍尔问题》入门中找到的内容,他从三个不同的角度探讨这个经典问题,深入挖掘其底层数学原理及实际应用。

  • 批判性审视 AI 图像生成“我想强调的一点是,这些并非不受文化和社会影响——无论这种影响是好是坏。”Stephanie Kirmer从整体上反思图像生成模型,考虑到其美学的局限性以及它们所反映的潜在偏见。

  • 论文解析:Attention Is All You Need很可能没有哪篇近期的机器学习论文像 Vaswani 等人 2017 年的里程碑贡献那样,生成如此多的指南、解释文章和教程。你可能会问,为什么还要分享另一篇?阅读Muhammad Ardi的精彩深度分析,我们猜测你会明白其中的要点:它在解读 Transformer 的关键组成部分时表现出色,且在理论与实践实现之间找到了很好的平衡。

  • SQL 与数据建模实践:深入探索数据湖仓对于任何初次接触数据库、复杂数据架构和/或 SQL 的人来说,我们强烈推荐Sarah Lea的这篇面向初学者的——但又详细而精确的——入门指南,内容涵盖了 SQL 和云应用数据建模的基础知识。

感谢你支持我们作者的工作!正如我们之前提到的,我们非常喜欢发表新作者的文章,所以如果你最近写了一篇有趣的项目解析、教程或关于我们核心主题的理论反思,请不要犹豫,与我们分享

直到下一个 Variable,

TDS 团队

网络分析示意图:传播公共卫生信息的指标

原文:towardsdatascience.com/network-analysis-illustrated-metrics-to-spread-public-health-information-4e259c3a5e33?source=collection_archive---------15-----------------------#2024-02-16

网络分析

本文通过视觉化的方式简化了网络指标,以公共卫生为例进行讲解。

Christabelle PabalanTowards Data Science Christabelle Pabalan

·发布于 Towards Data Science ·7 分钟阅读·2024 年 2 月 16 日

--

DALLE 生成的图像,作者提供

传播疾病预防信息

假设你是一名公共卫生官员,负责在一个人口密集的城市中传播关于疾病预防的重要信息。在传染性疾病的威胁下,你的任务非常明确:教育社区采取主动措施,保护他们的健康,并防止疾病传播。

你希望了解网络的动态,并识别出城市中的关键影响者和沟通渠道。通过绘制社交连接图,你可以洞察到接触不同群体的最有效方式。你还将识别出那些具有影响力的群体,他们可以作为信息传播者,快速传递关于疾病预防的信息。

网络分析

这就是网络分析发挥作用的地方。这个计算工具为研究个体在网络中如何互相连接并相互影响提供了一个共同的语言。它在各个领域都有应用……

神经网络(MLP)在时间序列预测中的实践应用

原文:towardsdatascience.com/neural-network-mlp-for-time-series-forecasting-in-practice-04c47c1e3711?source=collection_archive---------0-----------------------#2024-07-08

特征工程和构建 MLP 模型的实用示例

Daniel J. TOTHTowards Data Science Daniel J. TOTH

·发表于 Towards Data Science ·阅读时长 16 分钟·2024 年 7 月 8 日

--

介绍

时间序列,尤其是时间序列预测,是数据科学领域一个非常著名的问题,受专业人士和商业用户的广泛关注。

存在多种预测方法,可以将它们归类为统计方法或机器学习方法,便于理解和概览,但实际上,预测需求如此之高,现有的选项种类繁多。

机器学习方法被认为是时间序列预测中的最先进方法,并且因能够捕捉数据中的复杂非线性关系而越来越受欢迎,通常能提供更高的预测准确性[1]。其中,神经网络领域是一个广受关注的机器学习分支。特别是在时间序列分析中,循环神经网络已经被开发并应用于解决预测问题[2]。

数据科学爱好者可能会觉得这些模型背后的复杂性令人望而生畏,作为其中的一员,我可以告诉你,我也有同样的感觉。然而,本文旨在展示

尽管机器学习方法的最新进展非常显著,但在寻求特定问题的解决方案时,并不一定需要追求最复杂的应用。经过强化的成熟方法与强大的特征工程技术结合,依然能够提供令人满意的结果。

更具体地说,我应用了多层感知器模型,并分享了代码和结果,让你能够亲自体验如何有效地进行时间序列特征工程和预测。

本文目标

更准确地说,我想为自学的专业人士提供的内容可以总结为以下几点:

  1. 基于实际问题/数据进行预测

  2. 如何为捕捉时间模式工程化时间序列特征

  3. 构建一个能够利用混合变量(浮动和整数,通过嵌入处理为类别变量)的 MLP 模型

  4. 使用 MLP 进行点预测

  5. 使用 MLP 进行多步预测

  6. 使用置换特征重要性方法评估特征重要性

  7. 针对一组分组特征(多个组,分别针对每个组进行训练)重新训练模型,以细化分组特征的重要性

  8. 通过与UnobservedComponents模型进行比较来评估模型

关键技术术语

请注意,本文假定读者已经具备一些关键技术术语的基础知识,并不打算详细解释这些术语。以下列出了这些关键术语,并提供了参考,读者可查阅以便理解:

  1. 时间序列 [3]

  2. 预测 [4] — 在这种情况下,它将用于区分训练期间的模型输出

  3. 预测 [4] — 在这种情况下,它将用于区分测试期间的模型输出

  4. 特征工程 [5]

  5. 自相关 [6]

  6. 偏自相关 [6]

  7. MLP(多层感知器) [7]

  8. 输入层 [7]

  9. 隐藏层 [7]

  10. 输出层 [7]

  11. 嵌入 [8]

  12. 状态空间模型 [9]

  13. 未观察到的组件模型 [9]

  14. RMSE(均方根误差) [10]

  15. 特征重要性 [11]

  16. 置换特征重要性 [11]

数据探索

在分析过程中使用的核心包包括:用于数据处理的numpypandas,用于交互式图表的plotly,用于统计和状态空间建模的statsmodels,以及用于 MLP 架构的tensorflow

注意:由于技术限制,我将提供交互式绘图的代码片段,但此处展示的图表将是静态的。

import opendatasets as od
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import tensorflow as tf

from sklearn.preprocessing import StandardScaler
from sklearn.inspection import permutation_importance
import statsmodels.api as sm
from statsmodels.tsa.stattools import acf, pacf
import datetime

import warnings
warnings.filterwarnings('ignore')

数据通过opendatasets自动加载。

dataset_url = "https://www.kaggle.com/datasets/robikscube/hourly-energy-consumption/"
od.download(dataset_url)
df = pd.read_csv(".\hourly-energy-consumption" + "\AEP_hourly.csv", index_col=0)
df.sort_index(inplace = True)

请记住,数据清理是分析的关键第一步。如果你对细节感兴趣,特别是状态空间建模,请参考我之前的文章这里。☚📰 简而言之,进行了以下步骤:

  1. 识别缺失的时间戳(仅识别了单步缺失)

  2. 执行插补(使用前后记录的均值)

  3. 识别并删除重复项

  4. 将时间戳列设置为数据框的索引

  5. 将数据框的索引频率设置为每小时,因为这是进一步处理的要求

在准备好数据后,我们通过绘制 5 个随机时间戳样本来探索数据,并比较不同尺度下的时间序列。

fig = make_subplots(rows=5, cols=4, shared_yaxes=True, horizontal_spacing=0.01, vertical_spacing=0.04)

#  drawing a random sample of 5 indices without repetition
sample = sorted([x for x in np.random.choice(range(0, len(df), 1), 5, replace=False)])

# zoom x scales for plotting
periods = [9000, 3000, 720, 240]

colors = ["#E56399", "#F0B67F", "#DE6E4B", "#7FD1B9", "#7A6563"]

# s for sample datetime start
for si, s in enumerate(sample):

    # p for period length
    for pi, p in enumerate(periods):
        cdf = df.iloc[s:(s+p+1),:].copy()
        fig.add_trace(go.Scatter(x=cdf.index,
                                 y=cdf.AEP_MW.values,
                                 marker=dict(color=colors[si])),
                        row=si+1, col=pi+1)

fig.update_layout(
    font=dict(family="Arial"),
    margin=dict(b=8, l=8, r=8, t=8),
    showlegend=False,
    height=1000,
    paper_bgcolor="#FFFFFF",
    plot_bgcolor="#FFFFFF")
fig.update_xaxes(griddash="dot", gridcolor="#808080")
fig.update_yaxes(griddash="dot", gridcolor="#808080")

数据集的随机抽样和不同时间尺度的可视化。来源:作者

状态空间建模

通过仔细分析这个简单但有效的图表,我可以清楚地看到,分析应该考虑几个季节性效应:

  1. 能源消耗——通常——在夏季和冬季的中期达到峰值,无论选择哪个年份

  2. 在周一似乎会出现每周最小值模式

  3. 在夜间有一个每日最低值,白天有一个每日最高值。

进一步分析会揭示,数据集的年度模式有 2 个谐波,因为冬季和夏季的峰值水平不同。因此,考虑了以下状态空间模型,其中周期以小时为单位(见下文模型总结):

# splitting time series to train and test subsets
y_train = df.iloc[:-8766, :].copy()
y_test = df.iloc[-8766:, :].copy()

# Unobserved Components model definition
model = sm.tsa.UnobservedComponents(y_train,
                                    level='dtrend',
                                    irregular=True,
                                    stochastic_level = False,
                                    stochastic_trend = False,
                                    stochastic_freq_seasonal = [False, False, False],
                                    freq_seasonal=[{'period': 24, 'harmonics': 1},
                                                    {'period': 168, 'harmonics': 1},
                                                    {'period': 8766, 'harmonics': 2}])
# fitting model to train data
model_results = model.fit()

# printing statsmodels summary for model
print(model_results.summary())
Value of `irregular` may be overridden when the trend component is specified using a model string.

                           Unobserved Components Results                            
====================================================================================
Dep. Variable:                       AEP_MW   No. Observations:               112530
Model:                  deterministic trend   Log Likelihood            -1002257.017
                     + freq_seasonal(24(1))   AIC                        2004516.033
                    + freq_seasonal(168(1))   BIC                        2004525.664
                   + freq_seasonal(8766(2))   HQIC                       2004518.941
Date:                      Tue, 25 Jun 2024                                         
Time:                              08:13:35                                         
Sample:                          10-01-2004                                         
                               - 08-02-2017                                         
Covariance Type:                        opg                                         
====================================================================================
                       coef    std err          z      P>|z|      [0.025      0.975]
------------------------------------------------------------------------------------
sigma2.irregular  3.168e+06    1.3e+04    244.095      0.000    3.14e+06    3.19e+06
===================================================================================
Ljung-Box (L1) (Q):              104573.71   Jarque-Bera (JB):              2731.37
Prob(Q):                              0.00   Prob(JB):                         0.00
Heteroskedasticity (H):               1.04   Skew:                             0.35
Prob(H) (two-sided):                  0.00   Kurtosis:                         3.30
===================================================================================

Warnings:
[1] Covariance matrix calculated using the outer product of gradients (complex-step).

在不提前过多展开的情况下,我想指出,模型近似了过去 365 天的总能源消耗,误差约为~2%,从商业角度来看,我认为这是相当准确的。下面构建的 MLP 模型将通过与上述状态空间模型的比较来评估。

特征工程

在构建 MLP 模型之前,我们应使模型能够学习到独特的趋势和季节性效应。这可以通过向数据集添加新特征来实现,这些特征是从原始的 1D 时间序列数据派生而来的。为捕捉已经识别或未识别的模式,派生特征包括:

  1. 滞后

  2. 差异

  3. 滚动均值

  4. 滚动标准差

  5. 一天中的小时

  6. 一周中的天数

  7. 标记周末

这些派生的——以及数值型——特征可以在多个时间间隔中进行考虑。为了确定模型在哪些时间间隔中能够获益,强烈建议检查数据集的自相关特性。

dff = df.copy()
acorr = acf(dff.AEP_MW.values, nlags=2*366)     # autocorrelation
pacorr = pacf(dff.AEP_MW.values, nlags=2*366)   # partial autocorrelation

fig = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0)
fig.add_trace(go.Scatter(
    x=np.linspace(0, len(acorr), len(acorr)+1),
    y=acorr,
    name="Autocorrelation",
    marker=dict(color="rgb(180, 120, 80)")
), row=1, col=1)
fig.add_trace(go.Scatter(
    x=np.linspace(0, len(pacorr), len(pacorr)+1),
    y=pacorr,
    name="Partial Autocorrelation",
    marker=dict(color="rgb(80, 180, 120)")
), row=2, col=1)
fig.update_layout(
    font=dict(family="Arial"),
    margin=dict(b=4, l=4, r=4, t=4),
    showlegend=False,
    height=500,
    paper_bgcolor="#FFFFFF",
    plot_bgcolor="#FFFFFF")
fig.update_xaxes(griddash="dot", gridcolor="#808080", row=1, col=1)
fig.update_xaxes(griddash="dot", gridcolor="#808080", title_text="No. of lags", row=2, col=1)
fig.update_yaxes(griddash="dot", gridcolor="#808080", title_text="Autocorrelation", row=1, col=1)
fig.update_yaxes(griddash="dot", gridcolor="#808080", title_text="Partial Autocorrelation", row=2, col=1)

时间序列的自相关和部分自相关图。来源:作者

数据集具有很高的自相关性,这很合理,因为值大多在 10K MW 到 20K MW 之间波动,且从一个小时到下一个小时的过渡平滑。然而,专注于下图所示的部分自相关性,似乎在 24 小时的倍数以及最后几个小时中存在显著的相关性。因此,派生特征主要可以分类为:

  1. 每日(24 小时的倍数),

  2. 每小时(专注于最后几个小时)和

  3. 分类特征

dff = df.reset_index(drop=False)
dff["Datetime"] = pd.to_datetime(dff.Datetime.values)

# lags and difference of multiple days for capturing seasonal effects
for i in np.linspace(24, 15*24, 15, dtype=int):
    dff[f"lag_{i}"] = dff.AEP_MW.shift(i)
    dff[f"difference_{i}"] = dff.AEP_MW.diff(periods=i)

# rolling mean and standard deviation up to 3 days for capturing seasonal effects better
for i in np.linspace(24, 72, 3, dtype=int):
    dff[f"rolling_mean_{i}"] = dff.AEP_MW.rolling(window=i).mean()
    dff[f"rolling_std_{i}"] = dff.AEP_MW.rolling(window=i).std()

# lag, rolling mean, rolling standard deviation and difference up to 4 hours for capturing immediate effects
for i in range(2, 5, 1):
    dff[f"lag_{i}"] = dff.AEP_MW.shift(i)
    dff[f"rolling_mean_{i}"] = dff.AEP_MW.rolling(window=i).mean()
    dff[f"rolling_std_{i}"] = dff.AEP_MW.rolling(window=i).std()
    dff[f"difference_{i}"] = dff.AEP_MW.diff(periods=i)

# categorical features
dff["hour_of_day"] = dff.Datetime.dt.hour
dff["day_of_week"] = dff.Datetime.dt.day_of_week
dff["is_weekend"] = dff["day_of_week"].isin([5, 6]).astype(int)

# grouping derived features for later use in feature importance analysis
daily_lags = [col for col in dff.columns if all(["lag_" in col, len(col)>5])]
hourly_lags = [col for col in dff.columns if all(["lag_" in col, len(col)<=5])]
daily_differences = [col for col in dff.columns if all(["difference_" in col, len(col)>12])]
hourly_differences = [col for col in dff.columns if all(["difference_" in col, len(col)<=12])]
daily_rolling_means = [col for col in dff.columns if all(["rolling_mean_" in col, len(col)>14])]
hourly_rolling_means = [col for col in dff.columns if all(["rolling_mean_" in col, len(col)<=14])]
daily_rolling_stds = [col for col in dff.columns if all(["rolling_std_" in col, len(col)>13])]
hourly_rolling_stds = [col for col in dff.columns if all(["rolling_std_" in col, len(col)<=13])]
categoricals = ["hour_of_day", "day_of_week", "is_weekend"]

构建 MLP 模型

生成上述详细特征后,输入形状已知,可以构建 MLP 模型。需要注意的是,我们处理的是混合数据类型:浮动型和整数型。还请注意,尽管所有特征都是数值类型,整数型输入本质上是分类特征,应当视为分类特征来处理。

有一种方法可以使用例如独热编码技术对类别进行编码,但这会显著增加特征的数量,因为每个类别列都应该扩展为与类别数相等的列数(减去一个)[12]。我故意选择了嵌入方法,以限制特征数量,虽然这样做会使得模型的输入层更加复杂,因为类别数据首先通过嵌入转换为向量,再与浮动输入结合。

请查看代码部分后的图表以获得更清晰的理解。该架构是使用经验法则构建的,因为超参数调优不在本文范围内。然而,如果你对如何进行超参数调优的通用框架感兴趣,请查看📰☛ 我之前的文章(在这篇文章中,我使用 Optuna 作为贝叶斯搜索工具调优了 XGBoost 模型的最佳超参数)。

# segmenting last year as test data
inputs = dff.dropna().iloc[:, 2:].columns
xs_train = dff.dropna().iloc[:-8766, 2:]
xs_test = dff.dropna().iloc[-8766:, 2:]
ys_train = dff.dropna().iloc[:-8766, 1]
ys_test = dff.dropna().iloc[-8766:, 1]
embedding_dim = 4       # potential hyperparameter

# defining baseline NN model
float_inputs = tf.keras.layers.Input(shape=(len(inputs)-3,), name="float_inputs")           # floats can be directly used in model fitting
integer_inputs = tf.keras.layers.Input(shape=(3,), dtype="int32", name="integer_inputs")    # integers should be treated as categoricals ang get them embedded
embedding_layer = tf.keras.layers.Embedding(input_dim=3, output_dim=embedding_dim)          # embedding will be performed during model fitting
embedded_integer_inputs = embedding_layer(integer_inputs)
flattened_embeddings = tf.keras.layers.Flatten()(embedded_integer_inputs)                   
preprocessing_layers = tf.keras.layers.concatenate([float_inputs, flattened_embeddings])    # float and embedded inputs are combined
hidden_layers = tf.keras.layers.Dense(units=64, activation="relu")(preprocessing_layers)    # No. of hidden layers, No. of units, activation function are potential hyperparameters
hidden_layers = tf.keras.layers.Dense(units=32, activation="relu")(hidden_layers)
output = tf.keras.layers.Dense(units=1, activation="linear")(hidden_layers)                 # single unit for one step ahead, multiple units for multiple step prediction
model_NN_baseline = tf.keras.Model(inputs=[float_inputs, integer_inputs], outputs=output)

# compiling baseline NN model
model_NN_baseline.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.MeanSquaredError(),
    jit_compile=True)

# fitting baseline NN model
model_NN_baseline.fit(
    x=[xs_train.iloc[:, :-3], xs_train.iloc[:, -3:]],
    y=ys_train,
    validation_data=[[xs_test.iloc[:, :-3], xs_test.iloc[:, -3:]], ys_test],
    epochs=128,
    batch_size=64,
    verbose=1
)

使用 Tensorflow/Keras 创建的 MLP 架构。来源:作者

就点预测而言,结果非常准确。这是一个好兆头,说明所应用的特征工程原则正确地捕捉了数据中的潜在模式,模型能够将其泛化。

基准 MLP 模型的点预测与测试数据对比。来源:作者

点预测与测试集重叠,且两个图形轨迹几乎无法区分。更精确地说,预测(训练集)和预测值(测试集)的 RMSE 分别约为 19.3 和 18.9(相对误差约为 0.1%)。

特征重要性

是什么使得模型准确?所有派生特征是否同样重要,还是有一个子集在决定结果时具有更大的权重?这两个问题有其有效性,原因有二:

  1. 在实际场景中,尤其是在大数据情况下,训练模型的资源有限,所使用的数据量可能对是否能够训练模型产生重大影响。

  2. 如果没有任何解释,模型就像一个黑箱,这会带来关于其性能的不确定性。神经网络尤其容易成为黑箱模型,解释它们是一个独立的领域[11]。

目前有大量的模型解释技术,每种方法都有其优缺点。我选择了排列特征重要性方法,以便为模型解释提供一些见解。然而,我分析中的一个关键结论是,

模型解释技术仅仅是在特定范围内解释模型,而不一定是解释其背后的过程。现实可能与特征重要性分析大相径庭,因此不应将其视为自变量与目标变量之间因果关系的最终真相。

让我用我的分析结果来解释这一点。逐一置换特征,重新计算 RMSE 得分并记录相对于使用原始数据的预测 RMSE 的相对变化,将给出特征的相对重要性[13]。

# permutation feature importance
features = xs_test.columns
permutation_importance_results = {}
rmse = tf.keras.metrics.RootMeanSquaredError()
rmse_permuted = tf.keras.metrics.RootMeanSquaredError()
rmse.update_state(ys_test.values, model_NN_baseline.predict([xs_test.iloc[:, :-3], xs_test.iloc[:, -3:]], verbose=0).flatten())

for feature in features:

    xs_test_permuted = xs_test.copy()
    xs_test_permuted.loc[:, feature] = xs_test.loc[:, feature].sample(frac=1, axis=0, replace=False, random_state=42).values

    rmse_permuted.reset_state()
    rmse_permuted.update_state(ys_test.values, model_NN_baseline.predict([xs_test_permuted.iloc[:, :-3], xs_test_permuted.iloc[:, -3:]], verbose=0).flatten())

    permutation_importance_results[feature] = rmse_permuted.result().numpy() / rmse.result().numpy()

pi_results_sorted_keys = sorted(permutation_importance_results, key=permutation_importance_results.get, reverse=True)

fig3 = make_subplots()
fig3.add_trace(go.Bar(
    x=pi_results_sorted_keys,
    y=[permutation_importance_results[key] for key in pi_results_sorted_keys]))
fig3.update_layout(
    title="<b>Permutation Feature Importance</b>",
    font=dict(family="Arial"),
    margin=dict(b=4, l=4, r=4, t=36),
    showlegend=False,
    height=500,
    paper_bgcolor="#FFFFFF",
    plot_bgcolor="#FFFFFF"
)
fig3.update_xaxes(griddash="dot", gridcolor="#808080", row=1, col=1)
fig3.update_yaxes(griddash="dot", gridcolor="#808080", row=1, col=1)

置换特征重要性直方图。来源:作者

每小时和每日滞后以及差异似乎很重要,也许每小时的滚动均值也很重要。然而,每日和每小时滚动标准以及分类特征似乎相对较小,可以忽略不计,与上述特征相比。置换特征重要性的一项警告是,它没有考虑多重共线性,因此可能会给出不准确的结果。请记住,这些特征是从具有高自相关的数据集中推导出来的。

处理这种情况的一种可能方式是遵循scikit learn的指导:

对 Spearman 等级顺序相关性执行层次聚类,选择一个阈值,并从每个簇中保留一个特征。[13]

然而,我想专注于突出不准确之处,并通过逐一训练替代模型以分组特征来为数据集添加更多洞见。为此使用了相同的 MLP 架构,仅对输入层进行了调整,以适应数据的子集。以下组在特征工程部分创建并在此测试(训练/测试数据集的 RMSE 结果也分别报告):

  1. 每日滞后(942 和 994)

  2. 每日差异(1792 和 1952)

  3. 每小时滞后(686 和 611)

  4. 日常滚动均值和标准差(1710 和 1663)

  5. 每小时滚动均值和标准差(84.4 和 75.5)

显然,替代模型显示的结果与简单的置换特征重要性分析预期不同,且未处理多重共线性:例如,每日滚动特征的得分优于每日差异,且训练于每小时滚动特征的模型在所有替代模型中表现最佳,接近基线模型(RMSE 分别为百分比0.5%和0.1%)。

数据中的特定异常说明

我想强调 2008 年 10 月 20 日 14:00 观察到的一个非常特殊的异常情况。 这是有史以来记录的最高值,且没有明显的原因,数据集中之前和之后没有类似的数据点。

然而,由特征工程驱动的基线模型能够预测该数据点,并且不被认为是异常值!

基线 MLP 模型的点预测和观察到的潜在异常。来源:作者

模型是如何预测该数据点的呢?让我们使用替代模型进行推理。最佳的替代模型(每小时滚动特征)在该点附近似乎非常准确,但只能部分解释这一现象:

替代的 MLP 模型(利用每小时滚动特征)点预测和观察到的潜在异常。来源:作者

第二好的替代方案是利用每小时滞后的模型,但它完全没有解释为何会发生这种情况:

替代的 MLP 模型(利用每小时滞后特征)点预测和观察到的潜在异常。来源:作者

简而言之,每日差异可能包含有关潜在模式的重要信息。尽管单独使用每日差异组会给出更高的预测值,但基准模型似乎找到了特征权重的良好平衡。

替代的 MLP 模型(利用每日差异特征)点预测和观察到的潜在异常。来源:作者

多步预测模型

最后,模型架构已被修改,以生成多步预测。预测期为一年,如数据集发布者所建议的[14]。考虑到这种过程中的所有不确定性,特别是天气条件方面,考虑这么长的预测期可能没有意义。然而,这对于评估多步模型与状态空间模型的表现是一个有趣的练习,后者明确建模了跨年观察到的趋势和季节性效应(见下一节)。

实现多步模型的关键点如下:

  1. 目标是一个向量序列(为每个步骤定义的接下来的 8766 小时)

  2. 结果是,预测或预报是接下来 8766 小时(大约一年)的最后一行输入数据

  3. 由于资源限制,我不得不限制前一训练数据集的最后一年的训练数据。

  4. 输出层已相应修改,以给出所需的向量输出

first_index = -8766*5
last_index = -8766*2
final_index = -8766
inputs = dff.dropna().iloc[:, 2:].columns
xs_train = dff.dropna().iloc[first_index:last_index, 2:]
xs_train.iloc[:, :-3] = xs_train.iloc[:, :-3].astype(np.float32)
xs_test = dff.dropna().iloc[last_index:final_index, 2:]
xs_test.iloc[:, :-3] = xs_test.iloc[:, :-3].astype(np.float32)
ys_train = np.vstack([dff.dropna().iloc[i:i+8765, 1].astype(int).values for i in range(first_index, last_index, 1)])
ys_test = np.vstack([dff.dropna().iloc[i:i+8765, 1].astype(int).values for i in range(last_index, final_index, 1)])
embedding_dim = 4

# defining, compiling and training NN model for MULTIPLE STEP PREDICTIONS. Model architecture is the same, except output layer
float_inputs = tf.keras.layers.Input(shape=(len(inputs)-3,), name="float_inputs")
integer_inputs = tf.keras.layers.Input(shape=(3,), dtype="int32", name="integer_inputs")
embedding_layer = tf.keras.layers.Embedding(input_dim=3, output_dim=embedding_dim)
embedded_integer_inputs = embedding_layer(integer_inputs)
flattened_embeddings = tf.keras.layers.Flatten()(embedded_integer_inputs)
preprocessing_layers = tf.keras.layers.concatenate([float_inputs, flattened_embeddings])
hidden_layers = tf.keras.layers.Dense(units=64, activation="relu")(preprocessing_layers)
hidden_layers = tf.keras.layers.Dense(units=32, activation="relu")(hidden_layers)
output = tf.keras.layers.Dense(units=np.abs(final_index)-1, activation="linear")(hidden_layers)

model_NN_multistep = tf.keras.Model(inputs=[float_inputs, integer_inputs], outputs=output)
model_NN_multistep.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.MeanSquaredError(),
    jit_compile=True)
model_NN_multistep.fit(
    x=[xs_train.iloc[:, :-3], xs_train.iloc[:, -3:]],
    y=ys_train,
    validation_data=[[xs_test.iloc[:, :-3], xs_test.iloc[:, -3:]], ys_test],
    epochs=128,
    batch_size=64,
    verbose=1
)

对于可视化评估,可以看出模型试图对模式进行泛化:

多步 MLP 模型的预测与原始数据的对比。来源:作者

MLP 与状态空间模型

由于数据的泛化,RMSE 得分显著增加:训练集和测试集的得分分别为 1982 和 2017。然而,为了正确评估多步 MLP,我们应该使用另一个模型进行比较。正如我在前一节中提到的,状态空间模型提供了对跨年观察到的趋势和季节性效应的相当可理解的近似。这一特点使得它们相对容易解释,不像神经网络。主要原因是隐藏层有很多连接,理解它们是如何被激活的并不是一个直接的过程。[11]

我之前的文章中,☚📰我使用了一种简化但有意义的评估方法:比较过去一年内的总能耗。实际上,这是能耗时间序列下的曲线面积。可以直接比较原始数据和模型预测的值。对于UnobservedComponents模型:

y_train = df.iloc[:-8766, 0].values
y_test = df.iloc[-8766:, 0].values
observed_integral = np.cumsum([y_test[x] + (y_test[x+1] - y_test[x]) / 2 for x in range(len(y_test)-1)])[-1]
forecast = model_results.forecast(steps=8766)
UC_integral = np.cumsum([forecast[x] + (forecast[x+1] - forecast[x]) / 2 for x in range(len(forecast)-1)])[-1]

# calculating absolute and percentage error of forecast integral compared to observed integral
fcast_integral_abserror = UC_integral - observed_integral
fcast_integral_perror4 = (UC_integral - observed_integral) * 100 / observed_integral

print(f"Observed yearly energy demand: {'%.3e' % observed_integral} MWh")
print(f"Forecast yearly energy demand: {'%.3e' % UC_integral} MWh")
print(f"Forecast error of yearly energy demand: {'%.3e' % fcast_integral_abserror} MWh or {'%.3f' % fcast_integral_perror4} %")
Observed yearly energy demand: 1.312e+08 MWh
Forecast yearly energy demand: 1.283e+08 MWh
Forecast error of yearly energy demand: -2.832e+06 MWh or -2.159 %

对于 MLP 模型:

y_test = dff.dropna().iloc[-8766:-1, 1].values
observed_integral = np.cumsum([y_test[x] + (y_test[x+1] - y_test[x]) / 2 for x in range(len(y_test)-1)])[-1]
forecast = model_NN_multistep.predict([xs_test.iloc[-1:, :-3], xs_test.iloc[-1:, -3:]], verbose=0).flatten()
model_NN_multistep_integral = np.cumsum([forecast[x] + (forecast[x+1] - forecast[x]) / 2 for x in range(len(forecast)-1)])[-1]

# calculating absolute and percentage error of forecast integral compared to observed integral
fcast_integral_abserror = model_NN_multistep_integral - observed_integral
fcast_integral_perror4 = (model_NN_multistep_integral - observed_integral) * 100 / observed_integral

print(f"Observed yearly energy demand: {'%.3e' % observed_integral} MWh")
print(f"Forecast yearly energy demand: {'%.3e' % model_NN_multistep_integral} MWh")
print(f"Forecast error of yearly energy demand: {'%.3e' % fcast_integral_abserror} MWh or {'%.3f' % fcast_integral_perror4} %")
Observed yearly energy demand: 1.312e+08 MWh
Forecast yearly energy demand: 1.286e+08 MWh
Forecast error of yearly energy demand: -2.508e+06 MWh or -1.912 %

简而言之:它是-1.912% vs. -2.159%,偏向于 MLP 模型。请注意,这是通过使用 MLP 架构并结合一些简单的经验法则实现的,甚至没有考虑超参数调优或某些有效的模型训练特征,例如在评估指标达到平台期时减少学习率或提前停止。

结果应该是相当令人信服的,确实,通过利用相对简单的神经网络架构结合强大的特征工程技术,准确的预测工具已经在数据科学家的初级阶段触手可得。

资源

数据来源:

www.kaggle.com/datasets/robikscube/hourly-energy-consumption/ (CC0)

笔记本(仅代码,不包含输出):gist.github.com/danielandthelions/2e6f0edd30902113ad10fd9f20bda215

参考文献

[1] preset.io/blog/time-series-forecasting-a-complete-guide/

[2] www.ibm.com/topics/recurrent-neural-networks

[3] www.timescale.com/blog/time-series-analysis-what-is-it-how-to-use-it/

[4] plat.ai/blog/difference-between-prediction-and-forecast/

[5] dotdata.com/blog/practical-guide-for-feature-engineering-of-time-series-data/

[6] statisticsbyjim.com/time-series/autocorrelation-partial-autocorrelation/

[7] www.sciencedirect.com/topics/computer-science/multilayer-perceptron

[8] jina.ai/news/embeddings-in-depth/

[9] Hyndman, R.J., & Athanasopoulos, G. (2021) 《Forecasting: principles and practice》,第三版,OTexts:澳大利亚墨尔本。OTexts.com/fpp3。访问时间:2024 年 7 月 7 日

[10] statisticsbyjim.com/regression/root-mean-square-error-rmse/

[11] christophm.github.io/interpretable-ml-book/

[12] scikit-learn.org/stable/modules/preprocessing.html

[13] scikit-learn.org/stable/modules/permutation_importance.html#permutation-feature-importance

[14] www.kaggle.com/datasets/robikscube/hourly-energy-consumption/

神经网络用于灵活的多变量预测

原文:towardsdatascience.com/neural-networks-for-flexible-multivariate-forecasting-82194d6cca0f?source=collection_archive---------5-----------------------#2024-10-25

一步步轻松入门神经网络进行时间序列预测

Lucas SeeTowards Data Science Lucas See

·发表于Towards Data Science ·14 分钟阅读·2024 年 10 月 25 日

--

图片由Aron Visuals提供,来源于Unsplash

预测多个时间序列可能会迅速变得复杂;传统方法要么需要为每个序列单独建立模型(例如 SARIMA),要么要求所有序列必须相关(例如 VARMA)。神经网络提供了一种灵活的方法,能够使用单一模型进行多系列预测,无论这些系列是否相关。

此外,这种方法允许轻松地将外生变量纳入其中,并可以预测未来多个时间步的值,从而提供一个强大的通用解决方案,能够在各种情况下表现良好。

在本文中,我们将展示如何执行数据窗口化操作,将我们的数据从时间序列转化为监督学习格式,适用于单变量和多变量时间序列。一旦数据转换完成,我们将展示如何训练深度神经网络和 LSTM 来进行多变量预测。

检查我们的数据

我们将使用一个数据集,该数据集记录了 2013 年至 2016 年间印度德里的每日平均温度和湿度。该数据可在 Kaggle 上获取,并且根据CC0: 公共领域许可证进行使用,非常适合…

周期性函数的神经网络

原文:towardsdatascience.com/neural-networks-for-periodic-functions-648cfc940437?source=collection_archive---------6-----------------------#2024-01-17

当 ReLU 的外推能力不足时

Dr. Robert KüblerTowards Data Science Dr. Robert Kübler

·发布于Towards Data Science ·阅读时长 8 分钟·2024 年 1 月 17 日

--

Willian Justen de Vasconcellos拍摄,发布于Unsplash

神经网络被认为是任何函数的优秀逼近器——至少在我们不偏离数据集太远时。让我们看看这是什么意思。这里是一些数据:

图片由作者提供。

它不仅仅看起来像正弦波,它实际上就是正弦波,只是加了一些噪声。我们现在可以训练一个普通的前馈神经网络,拥有 1 个隐藏层,包含 1000 个神经元并使用 ReLU 激活。我们得到以下拟合结果:

图片由作者提供。

除了边缘部分,效果看起来相当不错。我们可以通过根据 Cybenko 的普适逼近定理,向隐藏层添加更多神经元来修复这个问题。但我想给你指出一些其他的东西:

图片由作者提供。

我们可以现在辩称,如果假设波形模式在外部继续,这种外推行为是不好的……

深度学习图解,第一部分:神经网络是如何工作的?

原文:towardsdatascience.com/neural-networks-illustrated-part-1-how-does-a-neural-network-work-c3f92ce3b462?source=collection_archive---------0-----------------------#2024-01-31

神经网络的图解和直观介绍

Shreya RaoTowards Data Science Shreya Rao

·发布于 Towards Data Science ·15 分钟阅读·2024 年 1 月 31 日

--

如果你读过我之前的文章,你就会知道接下来会讲什么。在这部分内容中,我们将复杂的概念通过图解使其变得有趣且易懂。如果你还没读过我之前的文章,我强烈推荐你从我的系列文章《机器学习入门包》开始阅读,因为你会发现这里的许多内容在后续的文章中仍然相关。

Shreya Rao

Shreya Rao

机器学习入门包

查看列表3 个故事

今天,我们要讲解的重点是神经网络。这只是我计划进行的一系列关于深度学习的文章中的第一篇。它将聚焦于简单的人工神经网络是如何学习的,并为你提供一个深入(哈,双关)理解神经网络是如何一层一层构建的——从神经元到神经元。这对我们后续的学习非常重要,超级必要。虽然我们会涉及一些数学细节,但不用担心,因为我们会逐步拆解并图解每一个步骤。等到这篇文章读完,你会发现它比看起来简单得多。

但是在我们探讨这个问题之前,你可能会想:我们为什么需要神经网络?随着如此多的…

Neural Speed:针对 4 位大型语言模型的 CPU 快速推理

原文:towardsdatascience.com/neural-speed-fast-inference-on-cpu-for-4-bit-large-language-models-0d611978f399?source=collection_archive---------2-----------------------#2024-04-18

比 llama.cpp 快最多 40 倍?

Benjamin MarieTowards Data Science Benjamin Marie

·发表于Towards Data Science ·5 分钟阅读·2024 年 4 月 18 日

--

使用 DALL-E 生成

在消费级硬件上运行大型语言模型(LLM)可能是一个挑战。如果 LLM 无法适应 GPU 内存,通常会应用量化技术来减小其大小。然而,即使经过量化,模型仍然可能过大,无法适应 GPU。另一种方法是使用优化过的框架,将其在 CPU RAM 上运行,例如 llama.cpp这样的 CPU 推理框架。

Intel 也在致力于加速 CPU 上的推理。他们提出了一种框架,Intel 为 Transformers 提供的扩展,该框架建立在 Hugging Face Transformers 之上,易于使用,能够充分利用 CPU。

使用Neural Speed(Apache 2.0 许可证),该框架依赖于 Intel 为 Transformers 提供的扩展,Intel 进一步加速了在 CPU 上对 4 位 LLM 的推理。根据 Intel 的说法,使用此框架可以使推理速度比 llama.cpp 快 40 倍

在本文中,我回顾了 Neural Speed 带来的主要优化。我展示了如何使用它并基准测试推理吞吐量。我还将其与 llama.cpp 进行了比较。

Neural Speed 对 4 位 LLM 的推理优化

在 2023 年 NeurIPS 大会上,Intel 展示了在 CPU 上进行推理的主要优化:

神经形态计算——一个更具前沿性、更环保的人工智能

原文:towardsdatascience.com/neuromorphic-computing-an-edgier-greener-ai-3911fab9fe09?source=collection_archive---------7-----------------------#2024-11-22

为什么计算机硬件和人工智能算法要通过借鉴大脑的灵感进行重新设计

Jonathan R. Williford, PhDTowards Data Science Jonathan R. Williford 博士

·发表于Towards Data Science ·14 分钟阅读·2024 年 11 月 22 日

--

神经形态计算不仅可能有助于将人工智能推向边缘计算,还能减少数据中心的碳排放。由作者通过 ImageGen 3 生成。

有关神经形态计算革命的定期宣告,这种革命借鉴大脑的灵感,重新思考神经网络及其运行的硬件。虽然该领域仍面临一些挑战,但在脉冲神经网络算法和神经形态硬件方面已经取得了坚实的成功,并持续稳步进展。这一进展为至少某些人工智能领域的颠覆铺平了道路,并将减少推理时每次计算的能耗,同时使人工智能能够更广泛地应用于边缘计算。在本文中,我将介绍一些神经形态计算和工程基础、训练、神经形态系统的优势以及剩余的挑战。

神经形态系统的经典应用场景是用于边缘设备,这些设备需要在本地进行计算,并且受限于能源,例如电池供电的设备。然而,最近对神经形态系统的一个新兴趣是减少数据中心的能耗,例如大型语言模型(LLM)所需的能量。例如,OpenAI 在 2023 年 12 月签署了购买价值 5100 万美元的神经形态芯片的意向书,来自 Rain AI。这是有道理的,因为 OpenAI 在推理上的开支非常大,有人估计 2024 年推理的花费大约为$40 亿。同时,英特尔的 Loihi 2 和 IBM 的 NorthPole(TrueNorth 的继任者)神经形态系统似乎都被设计用于服务器。

神经形态计算的前景可以大致分为两类:1)务实的、短期内已取得成功的应用;2)更具理想化的、狂热的神经科学家幻想,关于脉冲动力学如何赋予神经网络接近真实智能的潜力。当然,第二类让我更加兴奋,但本文将专注于第一类。而没有比深入探讨术语更激动人心的开始方式了。

术语

神经形态计算通常被定义为灵感来源于大脑的计算,但这一定义留给了很多想象空间。神经网络比经典计算更具神经形态特征,但如今神经形态计算特别关注使用基于事件的脉冲神经网络(SNNs)来提高能效。尽管 SNNs 是一种人工神经网络,但“人工神经网络”(ANNs)这一术语在神经形态文献中专门指的是更标准的非脉冲人工神经网络。Schuman 及其同事(2022)将神经形态计算机定义为非冯·诺依曼计算机,其中处理和内存都集中在人工神经元和突触中,而冯·诺依曼计算机则将处理与内存分开。

冯·诺依曼计算机基于数字信息操作,拥有分离的处理器和内存,并通过时钟进行同步,而神经形态计算机基于事件驱动的脉冲工作,计算和内存结合,并且是异步的。此内容由作者基于 Schuman 等人 2022 年的研究进行创作。

神经形态工程指的是硬件的设计,而“神经形态计算”则关注的是模拟的内容,而非模拟的平台。这两者密切相关,因为计算依赖于硬件的特性,而硬件的实现又依赖于实验证明最有效的方案。

另一个相关术语是神经 AI,其目标是使用人工智能来获得大脑的机制性理解,更关注生物学的真实感。神经形态计算关注神经科学作为手段,视大脑为可以用来实现目标的思想源泉,如在神经架构中的能效和低延迟。神经 AI 的相当一部分研究依赖于脉冲平均值,而非脉冲神经网络,这使得与大多数应用于离散任务的现代人工神经网络进行更为接近的比较成为可能。

基于事件的系统

由作者使用 ImageGen 3 生成。

神经形态系统是基于事件的,这是与现代人工神经网络系统工作方式的范式转变。即使是实时的人工神经网络系统,通常也一次处理一个帧,活动会从一层同步地传播到下一层。这意味着,在人工神经网络中,不携带信息的神经元需要与携带关键信息的神经元进行相同的处理。基于事件的系统是一种不同的范式,通常从传感器开始,并在需要处理信息的地方进行更多的工作。人工神经网络依赖于矩阵运算,这些运算无论矩阵中的值如何,都需要相同的时间和能量。神经形态系统则使用脉冲神经网络,其中工作的量取决于脉冲的数量。

传统的部署人工神经网络通常会连接到一个摄像头,该摄像头同步地记录单帧画面。然后,人工神经网络处理该帧。该帧的结果可能会输入到追踪算法中,进行进一步处理。

基于事件的系统可能从传感器开始,使用事件摄像头。每个像素在有变化超过阈值时会异步发送更新。因此,当场景中发生移动(而场景本身静止)时,与运动对应的像素会立即发送事件或脉冲,而无需等待同步信号。这些事件信号可以在几十微秒内发送,而传统摄像头可能以 24 Hz 的频率收集数据,并可能引入几十毫秒的延迟。除了更快地接收信息外,基于事件的系统中的信息更加稀疏,并且专注于运动。传统系统则必须依次通过每个网络层处理整个场景。

脉冲神经网络中的学习

训练脉冲神经网络的一种方法是使用人工神经网络作为教师。由作者使用 ImageGen 3 生成。

脉冲神经网络面临的主要挑战之一是如何训练它们。反向传播算法和随机梯度下降是训练人工神经网络的常见方法,但这些方法在脉冲神经网络中遇到了困难。如何训练脉冲神经网络尚未确立,以下方法是一些常见的应用方法:

  1. 人工神经网络到脉冲神经网络的转换

  2. 类似反向传播

  3. 突触可塑性

  4. 演化性

ANN 到 SNN 的转换

创建脉冲神经网络(SNN)的一种方法是跳过直接训练 SNN,而是训练 ANN。这种方法限制了可以使用的 SNN 类型和硬件。例如,Sengupta 等人(2019 年)使用没有泄漏或折返期的积分与脉冲(IF)神经元将 VGG 和 ResNets 转换为 ANN。他们引入了一种新颖的权重归一化技术来执行转换,该技术基于每个神经元的前突触权重设置其触发阈值。Priyadarshini Panda 博士在她的ESWEEK 2021 SNN 演讲中详细阐述了这一点。

优点

  1. 使深度 SNN 成为可能。

  2. 允许重用深度人工神经网络(ANN)的知识,例如训练、架构等。

缺点

  1. 限制架构仅适用于 ANN 以及转换过程。

  2. 网络未能学习如何利用脉冲神经网络(SNN)的特性,这可能导致较低的准确性和更长的延迟。

类似反向传播的方法和代理梯度下降

当前训练 SNN 的最常见方法是类似反向传播的方法。标准的反向传播方法无法训练 SNN,因为 1)脉冲阈值函数的梯度在阈值处非零,除非在阈值处它是未定义的,2)信用分配问题需要在时间维度上解决,而不仅仅是在空间维度(如颜色等)。

在人工神经网络(ANN)中,最常见的激活函数是 ReLU。而在脉冲神经网络(SNN)中,当膜电位超过某个阈值时,神经元会发放脉冲,否则不会。这被称为 Heaviside 函数。你也可以使用 sigmoid 函数,但那样就不再是脉冲神经网络。使用代理梯度的解决方案是在前向传播中使用标准的阈值函数,但在反向传播时使用“平滑”版本的 Heaviside 函数的导数,例如 sigmoid 函数(Neftci 等,2019 年;Bohte,2011 年)。

优点

  1. 与著名方法连接。

  2. 与转换方法相比,可能导致更节能的网络(Li 等,2022 年)

缺点

  1. 在空间和时间上都可能需要计算密集型的求解

突触可塑性

脉冲时序依赖性可塑性(STDP)是最著名的突触可塑性形式。在大多数情况下,当前突触(输入)脉冲在后突触脉冲之前立即到达时,STDP 会增强突触的强度。早期的模型在简单的无监督任务中显示了 STDP 的潜力,尽管在更复杂的模型和任务中获得良好效果仍然更为困难。

其他生物学习机制包括神经元和突触的修剪与创建、自稳可塑性、神经调节剂、星形胶质细胞和进化。最近还有一些证据表明,某些原始类型的知识可以通过表观遗传学传递下来。

优点

  1. 无监督

  2. 可以利用时间特性

  3. 生物启发

缺点

  1. 突触可塑性尚未完全理解,特别是在不同的时间尺度下。

  2. 难以与非平凡的网络一起工作。

进化优化

进化优化是另一种方法,具有一些很酷的应用,特别适合小型网络。Catherine Schuman 博士是这一领域的领先专家,她在 ICS 实验室做了一场关于神经形态计算的精彩演讲,视频可以在 YouTube 上观看。

优点

  1. 适用于许多任务、架构和设备。

  2. 可以学习拓扑和参数(减少对问题的知识需求)。

  3. 学习小型网络,从而导致较低的延迟。

缺点

  1. 对需要深度或大规模架构的问题无效。

神经形态系统的优点

能效

神经形态系统有两个主要优势:1)能效和 2)低延迟。关于能效,有很多令人兴奋的理由。例如,英特尔声称,他们的 Loihi 2 神经处理单元(NPU)在能耗上比传统人工神经网络(ANN)低 100 倍,同时速度比常规 ANN 快 50 倍。Chris Eliasmith 在YouTube 上的一场演讲中将神经形态硬件上的脉冲神经网络(SNN)与同一架构的标准硬件上的人工神经网络(ANN)进行了能效对比。他发现,Loihi 上的 SNN 在能效上比标准 NVIDIA GPU 上的 ANN 高 100 倍,比 NVIDIA Jetson GPU 上的 ANN 高 20 倍。它比英特尔神经计算棒(NCS)和 NCS 2 高效 5-7 倍。同时,SNN 的准确率为 93.8%,而 ANN 的准确率为 92.7%。

图表由作者根据 Chris Eliasmith 在www.youtube.com/watch?v=PeW-TN3P1hk&t=1308s的幻灯片重制,显示了神经形态处理器在效率上比传统处理器高出 5 到 100 倍,同时实现相似的准确率。

神经形态芯片更具能效,允许在低能耗的边缘设备上部署复杂的深度学习模型。2024 年 10 月,BrainChip 推出了 Akida Pico NPU,功耗低于 1 毫瓦,而英特尔的 Loihi 2 NPU 功耗为 1 瓦。这比常用于嵌入式人工神经网络的 NVIDIA Jetson 模块(功耗在 10 到 50 瓦之间)要低得多,而服务器 GPU 的功耗通常为 100 瓦左右。

比较人工神经网络(ANNs)和脉冲神经网络(SNNs)的能效是困难的,因为:1. 能效依赖于硬件,2. SNNs 和 ANNs 可以使用不同的架构,3. 它们适用于不同的问题。此外,SNNs 的能耗与脉冲的数量和时间步数成正比,因此需要尽量减少脉冲和时间步数,以实现最佳的能效。

理论分析通常用于估算 SNNs 和 ANNs 所需的能量,但这并没有考虑到用于 ANNs 的 CPU 和 GPU 与用于 SNNs 的神经形态芯片之间的所有差异。

观察自然界可以给我们提供一些未来可能实现的想法,Mike Davies 在英特尔的Architecture All Access YouTube 视频中分享了一个很好的轶事:

考虑到一只微小的凤头鹦鹉大脑的能力,这个两克重的大脑在大约 50 毫瓦的功率下运行。这个大脑使得凤头鹦鹉能够以每小时 20 英里的速度飞行,在觅食时能够导航未知的环境,甚至能学会将物体当作工具操作,并发出人类的语言。

在当前的神经网络中,存在大量的计算浪费。例如,一个图像编码器对一个空白页面的编码时间与对一本《沃尔多在哪里?》书中杂乱页面的编码时间相同。而在脉冲神经网络中,空白页面几乎没有单元激活,计算量也很小;而包含许多特征的页面则会激活更多的单元,使用更多的计算量。在现实生活中,视觉场中通常有一些区域包含更多的特征,需要比其他区域(如清晰的天空)更多的处理。在这两种情况下,SNNs 仅在需要工作时才会执行任务,而 ANNs 则依赖于矩阵乘法,这种计算方式很难稀疏使用。

这一点本身就令人兴奋。目前,很多深度学习涉及将大量音频或视频上传到云端,在那里数据被处理在庞大的数据中心,花费大量的能量用于计算和冷却计算设备,然后返回结果。通过边缘计算,你可以在本地设备上拥有更安全、更快速的语音识别或视频识别,且能耗比传统方法低几个数量级。

低延迟

当事件相机的像素接收器改变超过某个阈值时,它可以在微秒内发送一个事件或脉冲。它不需要等待快门或同步信号的发送。这一优势贯穿了 SNN 的事件驱动架构。单元可以立即发送事件,而不是等待同步信号。这使得类脑计算机在延迟方面比 ANN(人工神经网络)要快。因此,类脑处理对于需要低延迟的实时应用程序更有优势。如果问题允许批处理,并且您是通过吞吐量来衡量速度,这一优势会减弱,因为 ANN 可以更容易地利用批处理。然而,在实时处理(如机器人技术或用户接口)中,延迟更加重要。

缺点与挑战

无处不在的每一刻

挑战之一是类脑计算和工程正同时在多个层面上发展。模型的细节取决于硬件实现,实际化的模型的经验结果指导着硬件的发展。英特尔在其 Loihi 1 芯片上发现了这一点,并在其 Loihi 2 芯片中加入了更多的灵活性,然而,总会存在折中,且硬件和软件方面仍有许多进展需要做。

商业硬件的有限可得性

希望这种情况很快会改变,但商业硬件并不容易获得。BrainChip 的 Akida 是首个商业化的类脑芯片,尽管显然,它甚至不支持标准的漏积与发放(LIF)神经元。SpiNNaker 板曾经可以购买,这是一部分欧盟人类大脑计划的内容,但现在已经不再提供。英特尔通过英特尔类脑研究社区(INRC)项目向一些学术研究人员提供 Loihi 2 芯片。

数据集

类脑数据集的数量远少于传统数据集,但可以大得多。一些常见的小型计算机视觉数据集,如 MNIST(NMNIST,Li 等,2017)和 CIFAR-10(CIFAR10-DVS,Orchard 等,2015),已通过显示图像并使用基于事件的相机记录它们,转换为事件流。这些图像是在移动(或“扫视”)的过程中收集的,以增加处理的脉冲数量。对于更大的数据集,如 ES-ImageNet(Lin 等,2021),已使用事件相机的仿真。

通过静态图像衍生的数据集可能有助于将脉冲神经网络(SNN)与传统人工神经网络(ANN)进行比较,并且可能在训练或评估流程中有所帮助。然而,SNNs 天生是时序性的,如果你希望利用 SNN 的时序特性,用它们处理静态输入就没有太大意义。以下是一些能够利用 SNN 时序特性的相关数据集:

  • DvsGesture(Amir et al. 2017)——一个包含 11 种手势和臂部动作的数据集

  • Bullying10K(Dong et al. 2024)——一个用于欺凌识别的隐私保护数据集

可以通过标准的可见光摄像头数据生成合成数据,而无需使用昂贵的事件摄像头数据收集,然而这些数据不会展现事件摄像头所捕获的高动态范围和帧率。

Tonic 是一个示例 Python 库,可以轻松访问至少部分基于事件的数据集。这些数据集本身可能比传统数据集占用更多空间。例如,MNIST 的训练图像约为 10 MB,而在 N-MNIST 中则接近 1 GB。

需要考虑的另一件事是,数据集的可视化可能会很困难。即使是从静态图像衍生的数据集,也可能难以与原始输入图像匹配。此外,使用真实数据的好处通常是为了避免训练和推理之间的差距,因此,使用这些数据集的好处似乎取决于它们与部署或测试过程中使用的摄像头的相似性。

结论

由作者使用 ImageGen 3 和 GIMP 创建。

我们正处于神经形态计算的激动人心的时代,不仅硬件方面有了投资,脉冲神经网络(SNN)也在不断进展。虽然应用仍面临挑战,但已经有证明的案例表明,SNN 比传统的 ANN 在能效上更优,尤其是在标准服务器 GPU 上,同时延迟更低,精度相当。许多公司,包括英特尔、IBM、高通、模拟器件、Rain AI 和 BrainChip,都在投资神经形态系统。BrainChip 是首家将其神经形态芯片商业化的公司,而英特尔和 IBM 则分别进入了其研究芯片的第二代(Loihi 2 和 NorthPole)。过去几年,特别是在 Spikformer 论文(Zhou et al. 2022)和 SEW-ResNet 论文(Fang et al. 2021)之后,脉冲变压器和其他深度脉冲神经网络似乎出现了特别的成功。

参考文献

  • Amir, A., Taba, B., Berg, D., Melano, T., McKinstry, J., Di Nolfo, C., Nayak, T., Andreopoulos, A., Garreau, G., Mendoza, M., Kusnitz, J., Debole, M., Esser, S., Delbruck, T., Flickner, M., & Modha, D. (2017). 一个低功耗的全事件驱动手势识别系统. 7243–7252. openaccess.thecvf.com/content_cvpr_2017/html/Amir_A_Low_Power_CVPR_2017_paper.html

  • Bohte, S. M. (2011). 在分数预测脉冲神经元网络中进行误差反向传播。收录于 人工神经网络与机器学习 doi.org/10.1007/978-3-642-21735-7_8

  • Dong, Y., Li, Y., Zhao, D., Shen, G., & Zeng, Y. (2023). Bullying10K:一个大型神经形态数据集,用于隐私保护的欺凌识别。神经信息处理系统进展, 36, 1923–1937.

  • Fang, W., Yu, Z., Chen, Y., Huang, T., Masquelier, T., & Tian, Y. (2021). 脉冲神经网络中的深度残差学习。神经信息处理系统进展, 34, 21056–21069. proceedings.neurips.cc/paper/2021/hash/afe434653a898da20044041262b3ac74-Abstract.html

  • Li, C., Ma, L., & Furber, S. (2022). 快速脉冲神经网络的量化框架。前沿神经科学, 16. doi.org/10.3389/fnins.2022.918793

  • Li, H., Liu, H., Ji, X., Li, G., & Shi, L. (2017). CIFAR10-DVS:一个用于物体分类的事件流数据集。前沿神经科学, 11. doi.org/10.3389/fnins.2017.00309

  • Lin, Y., Ding, W., Qiang, S., Deng, L., & Li, G. (2021). ES-ImageNet: 一个用于脉冲神经网络的百万事件流分类数据集。前沿神经科学, 15. doi.org/10.3389/fnins.2021.726582

  • Neftci, E. O., Mostafa, H., & Zenke, F. (2019). 脉冲神经网络中的代理梯度学习:将基于梯度的优化方法引入脉冲神经网络。IEEE 信号处理杂志. doi.org/10.1109/MSP.2019.2931595

  • Orchard, G., Jayawant, A., Cohen, G. K., & Thakor, N. (2015). 使用扫视将静态图像数据集转换为脉冲神经形态数据集。前沿神经科学, 9. doi.org/10.3389/fnins.2015.00437

  • Schuman, C. D., Kulkarni, S. R., Parsa, M., Mitchell, J. P., Date, P., & Kay, B. (2022). 神经形态计算算法和应用的机遇。自然计算科学, 2(1), 10–19. doi.org/10.1038/s43588-021-00184-y

  • Sengupta, A., Ye, Y., Wang, R., Liu, C., & Roy, K. (2019). 《深入研究脉冲神经网络:VGG 和残差架构》. 神经科学前沿, 13. doi.org/10.3389/fnins.2019.00095

  • Zhou, Z., Zhu, Y., He, C., Wang, Y., Yan, S., Tian, Y., & Yuan, L. (2022 年 9 月 29 日). Spikformer:当脉冲神经网络遇到 Transformer。第十一届国际学习表征会议。 openreview.net/forum?id=frE4fUwz_h

资源

YouTube 上的讲座

最初发布于 https://neural.vision 于 2024 年 11 月 22 日。

用于训练物理(与计算机基础的)人工神经网络的新方法

原文:towardsdatascience.com/new-approach-for-training-physical-as-opposed-to-computer-based-artificial-neural-networks-e57517b70631?source=collection_archive---------9-----------------------#2024-08-12

由光波构建的神经网络可能使得人工智能系统更加多功能、可扩展且能效更高

LucianoSphere (Luciano Abriata, PhD)Towards Data Science LucianoSphere (Luciano Abriata, PhD)

·发布于Towards Data Science ·6 分钟阅读·2024 年 8 月 12 日

--

图片由 Dall-E 3 通过 ChatGPT 生成。

依赖于计算机内运行的深度人工神经网络的传统人工智能系统,需要大量的计算资源进行训练,这引发了对其可持续性的担忧。解决这个问题的一个有前景的途径是开发物理人工神经网络:这些系统通过物理方式使信息流动,模拟生物神经网络的结构,而不是像数字神经网络那样通过连接计算机中神经元的数字计算来实现。例如,在一种名为“光学神经网络”的物理神经网络子类中,光波被发射并结合起来以执行各种计算。但这些物理系统面临着独特的挑战,特别是在训练方面。最近发表于Nature杂志的一项研究(Xue et al Nature 632:280–286, 2024)提出了一种真正开创性的解决方案,利用物理学来应对这些挑战。这是迈向一个可能的未来的第一步,在这个未来中,人工智能系统运行在物理系统上,从而变得更加可管理、可扩展,并且……

下一代代理:释放动态上下文的力量

原文:towardsdatascience.com/next-level-agents-unlocking-the-power-of-dynamic-context-68b8647eef89?source=collection_archive---------5-----------------------#2024-05-15

如何提高 AI 代理的准确性、速度和令牌使用效率

Frank WittkampfTowards Data Science Frank Wittkampf

·发表于数据科学前沿 ·阅读时长 12 分钟·2024 年 5 月 15 日

--

大多数 AI 代理之间存在显著差异(图片由 MidJourney 生成)

引言

AI 代理的行为由两件事决定:(1) 它运行的模型和 (2) 你提供给模型的上下文。提供这些上下文的方式决定了代理输出的质量。你甚至可以说,在同一模型下运行的两个 AI 代理之间的行为完全由它们的上下文定义。那么,什么是代理的上下文呢?请参见下方的上下文类型(图片)以了解更多示例。

本文深入探讨了通过根据代理需求有选择地定制上下文,以提高 AI 代理效率和准确性的更高级方法。我将首先介绍 5 种策略,然后提供一些额外的实施技巧。本文的经验来源于在生产环境中与多代理团队合作的丰富经验,但同样适用于单代理系统。

AI 代理的上下文类型(图片来自作者)

理解上下文需求

大多数在线示例和教程简化了它们的上下文处理方式(例如应用固定的截止长度)。实际的 AI 应用需要更加复杂的方法。

不同的代理根据手头任务的不同具有不同的上下文需求。有些代理可能只需要前一条消息,而其他代理可能需要大量的历史数据才能准确地做出回应。这些差异表明,每个代理都需要一种量身定制的上下文管理方法。

考虑以下用户与两个代理(一个软件开发人员和一个软件审查员)之间的持续对话示例:

1 **User**: “Please improve the sorting function I wrote from the uploaded file”
2 **System_prompt_for_coder**: “You are a SW developer, you help the user to 
  develop …”
3 **Coding_agent**: “read_file_function”
4 **Function_response**: “coding snippet: ```一些返回的代码```py”
5 **System_prompt_for_coder**: “You are a SW developer, you help the user to
  develop …”
6 **Coding_agent**: “I’ve improved your code by adding … and … to your code, 
  here is the new function …”
7 **System_prompt_for_reviewer**: “you are a SW code reviewer, you …”
8 **Code_reviewer_agent**: “After reviewing the improved code, there 
  are a few more things that I would add to the code …"
9 **System_prompt_for_coder**: “you are a SW developer, you help the user to 
  develop …”
10 **Coding_agent**: “ … ”

这个例子中的上下文显然是重复且低效的。许多行是重复的(例如系统提示),而每个代理都得到了超过它所需的上下文。让我们回顾一下几种改善上下文处理的方法。

策略 1:消息标签化

消息标签化是管理和优化人工智能代理与其任务之间互动的关键工具。通过为对话中的每条消息分配元数据,我们可以智能地选择与代理当前任务最相关的信息。这一策略涉及几个关键策略:

相关性标签:每条消息都应标记上反映其与当前和未来互动相关性的标签。这个过程涉及分析消息的内容,并确定它对代理决策过程的潜在帮助。例如,包含问题、决策或见解的消息应该标记为高度相关。

持久性标签:对消息进行分类,依据其在时间上的持续性和有用性非常重要。有些消息,比如包含基础决策或里程碑通讯的消息,具有长期价值,应该在多个会话中保留。相反,系统消息可能只在特定时刻需要一次。这些应该在其即时相关性过去后从代理的记忆中排除。

来源与关联标签:这涉及识别每条消息的来源,无论是来自特定代理、用户、功能还是其他过程。这个标签有助于构建一个结构化且易于浏览的历史记录,使代理能够根据来源或任务相关性高效地检索和引用信息。

为消息的元数据应用智能标签可以使你进行智能选择。继续阅读以了解一些示例。

策略 2:代理特定的上下文需求

不同的代理有不同的需求。有些代理能在极少的信息下正常工作,而其他代理则需要大量上下文才能正确操作。这一策略建立在我们刚刚讨论的标签基础上。

关键上下文识别:识别哪些消息对每个特定代理至关重要并专注于这些消息,对于简化处理过程和提高响应准确性至关重要。让我们看一下上述上下文中的第 8 行。代码审查员只需要有限的上下文才能准确地完成工作。我们甚至可以相当确定地说,如果给它提供超过必要的上下文,它将给出更糟的答案。

那么它需要什么样的上下文呢?快速看一下,你会推测出代码审查员只需要它自己的系统提示,还需要它之前的最后一条代理消息,其中包含了最新版本的代码(第 6 行)。

因此,每个代理应该被配置为只选择它所需要的历史记录。代码审查员只查看最后 2 条消息,而代码编写者则需要更长的历史记录。

策略 3:优化系统提示

位置: 当你快速搜索代理和系统提示时,很明显,系统提示在代理中的位置非常重要。它应该是链中的第一条消息,还是最后一条消息?不同的使用案例会得出不同的观点和结果。例如,哪种方式能得到更好的结果?

1) user: "I visited dr. Fauci on Thursday, and got diagnosed with …"
2) system: "Extract all medically relevant info from the user prompt"

 1) system: "Extract all medically relevant info from the user prompt"
2) user: "I visited dr. Fauci on Thursday, and got diagnosed with …"

如果你在一个较大且更复杂的历史记录中进行测试,你会发现同样的指令放在不同的位置会产生不同的结果。可以明确的是,系统提示应该有意识地放置在特定位置,而且这取决于代理和使用场景。

注:根据我的经验,是否将系统提示放在链中的最后一条消息取决于 LLM。通常,将其作为最后一条消息是一个不错的选择(当然有明确的例外情况,比如信息提取)。

持久性: 其次,除了系统提示的位置,在大多数使用场景中,你应该确保系统提示被标记为“短暂的”。我所说的短暂消息,是指它在使用后不会保留在聊天历史中。

例如,agent1 可能会收到一个提示,要求它回顾并反思问题。接下来的一轮,agent1 被要求使用它的反思来找到问题的解决方案。在第二轮中,最初的提示(让其反思)会让代理感到困惑,导致答案更差。如果那个初始的系统提示不是持久性的,并且没有再次提供给代理,那就会更好。下面的例子中,这意味着在你第二次调用 LLM 时,第 2 行应该被删除。

#1 **User**: 
  "What is the meaning of the universe?"
#2 **System**: 
  "You're a philosopher; before you answer, step back and 
  reflect on what the user's question implies."
#3 **Agent_philo**: 
  "Reflecting on this question, one can consider the universe 
  through various philosophical lenses - existentialism views 
  it as a canvas for creating personal meaning, while cosmology 
  sees it as a playground for understanding physical laws. This 
  interplay between human perception and scientific exploration 
  is central to understanding our query."
#4 **System**: 
  "You're a philosopher, and you just reflected deeply on the 
  question posed by the user. Use that insightful reflection to 
  craft a comprehensive answer."
#5 **Agent_philo**: 
  "From an existentialist perspective, the meaning of the universe 
  is a personal journey, where each individual constructs their own 
  meaning through choices and life experiences. In contrast, from a 
  cosmological viewpoint, it is about exploring and understanding 
  the vast, intricate laws that govern the cosmos. These perspectives 
  help us grasp the multifaceted nature of our universe, blending our 
  existential quests with scientific inquiry."

这种方法确保任何旧的系统提示都从链中删除,并且每个提示都被合理地放置。现在我们拥有一个整洁的对话历史,从而实现更精确且可预测的交互。

策略 4:减少 RAG 中的冗余

我们可以撰写十几篇文章来优化你的代理,通过改善 RAG 的使用,但在这里我们会简要介绍几段。使用 RAG 时可能产生的大量令牌,已经足够让我们必须提到一些管理技巧。如果你还没有开始研究,这是一个你应该花大量时间研究的课题。

关于 RAG 的基础教程大多假设你或你的用户上传的文档是简单直接的。然而,在实际应用中,大多数文档都是复杂且不可预测的。我的经验是,很多文档都包含重复信息。例如,同样的信息往往会在 PDF 文章的引言、正文和结论中反复出现。或者,一份医疗文件中可能会有重复的医生更新,内容几乎完全相同。又或者,日志会一遍遍地重复。此外,特别是在生产环境中,当需要在大量文件中进行检索时,标准的 RAG 过程返回的内容可能会非常重复。

处理重复内容: 优化 RAG 上下文的第一步是识别并去除检索到的文档片段中的完全重复和近似重复内容,以防止冗余。完全重复的内容很容易识别。近似重复可以通过语义相似性来检测,方法是查看向量嵌入的多样性(多样化的片段具有相互间距离较大的向量),以及许多其他技术。如何实现这一点将极大地依赖于你的使用案例。以下是一些示例(来自 perplexity)

响应多样性: 另一种确保 RAG 响应多样性的方法是通过智能地对来自不同文件的内容进行分组。一种非常简单但有效的方法是,不仅仅按相似度取前 N 个文档,而是使用检索查询中的 GROUP BY。是否采用这种方法在很大程度上取决于你的使用案例。这里有一个示例(来自 perplexity)

动态检索: 既然本文讨论的是动态上下文,那么你如何将这种哲学引入到你的 RAG 过程中呢?大多数 RAG 过程会检索前 N 个结果,例如,前 10 个最相似的文档片段。然而,这并不是人类检索结果的方式。当你搜索信息时,你会去像 Google 这样的搜索引擎,直到找到正确的答案。这可能是第 1 个或第 2 个搜索结果,也可能是第 20 个。当然,这也取决于你的运气和耐力 😉。你可以以同样的方式构建 RAG。我们可以允许代理进行更具选择性的检索,只提供前几个结果,并让代理决定是否需要更多信息。

这是一个建议的做法。不要仅仅定义一个相似度的临界值,而是定义高、中、低三个临界点。例如,你的搜索结果可能有 11 个非常相似的文档,5 个中等相似的文档和 20 个稍微相似的文档。如果我们说代理每次获取 5 个文档,那么现在你让代理自己决定是否需要更多。你告诉代理它已经看到 11 个非常相似的文档中的 5 个,还有 25 个文档在其后。通过一些提示工程,你的代理在寻找数据时将开始更加理性地行动。

战术 5:上下文处理的高级策略

我将简要介绍几种策略,将动态上下文进一步推进。

即时元数据:如战术 1 所述,向消息添加元数据可以帮助你预选择特定代理所需的历史记录。对于大多数情况,简单的一个词文本标签就足够了。知道某个内容来自某个特定功能、代理或用户,可以让你为消息添加一个简单的标签,但如果你处理的是非常大的 AI 响应并且需要更多优化,那么有一种更高级的方式可以向消息添加元数据:通过 AI。

其中的一些例子包括:

  • 标记历史消息的一种简单方法是进行一次单独的 AI 调用(使用更便宜的模型),该调用会为消息生成标签。然而,现在你每次都需要进行两次 AI 调用,并且在流程中引入了额外的复杂性。

一种更优雅的生成标签的方式是让消息的原作者在编写响应的同时生成标签。

  • 让代理以 JSON 格式给出响应,其中一个元素是其正常的响应,另一个元素是内容的标签。

  • 使用多功能调用,并为代理提供它需要调用的功能,该功能定义了消息标签。

  • 在代理进行的任何功能调用中,保留一个必需的参数,其中包含一个标签。

    通过这种方式,你可以即时为功能内容生成标签。

另一个动态优化上下文的高级策略是对你的 RAG 进行预处理。

RAG 的双重处理:为了优化你的 RAG 流程,你可以考虑使用一个更便宜(且更快)的 LLM 来浓缩你的 RAG 结果,然后再将其提供给标准的 LLM。使用这种方法时的窍门是使用一个非常简单且非干扰性的提示,将原始的 RAG 结果浓缩或简化为更易消化的形式。

例如,你可能会使用一个更便宜的模型来提取特定信息,以减少重复,或者只选择与当前任务相关的文档部分。这确实需要你了解更便宜的模型的优缺点。当与更强大的模型结合使用时,这种方法可以节省大量成本(和提高速度)。

实施

好的,那么以上内容是否意味着我的每个代理都需要几页定制代码来优化其性能?我该如何将这些概念普遍化并加以扩展呢?

代理架构:这些问题的答案是,有一些清晰的方法可以进行设置。只需要一些前瞻性和规划。建立一个能够正确运行多种代理的平台要求你具备一个代理架构。如果你从一套明确的设计原则开始,那么利用动态上下文让你的代理更快、更便宜、更好,一次性满足所有需求,并不会非常复杂。

动态上下文配置是你代理架构中的一个元素。

动态上下文配置:正如本文所讨论的,每个代理都有独特的上下文需求。而管理这些需求可能会涉及管理所有可能代理上下文中的大量变化(请参见本文顶部的图像)。然而,好消息是,这些变化可以轻松地编码为几个简单的维度。让我给你一个例子,整合本文中的大多数概念。

让我们想象一个代理,它是一个软件开发人员,首先规划自己的行动,然后执行该计划。这个代理的上下文配置可能是:

  • 保留最初的用户问题

  • 保留计划

  • 忘记除最后一次代码修订和最后一条消息外的所有历史记录

  • 使用 RAG(上传的代码文件中)而不进行 RAG 凝缩

  • 始终将系统提示设置为最后一条消息

这个配置保存在该代理的上下文配置中。因此,现在你对 AI 代理的定义不仅仅是一个提示指令集合。你的代理还有一个特定的上下文配置。

你会看到,在不同的代理之间,这些配置可以非常有意义并且各不相同,它们使得代码的抽象化变得非常有意义,而不再是完全定制化的。

总结

正确管理动态上下文不仅能提升你的 AI 代理的性能,还能显著提高准确性、速度和令牌使用效率……你的代理现在更快、更好、更便宜,一次性满足所有需求。

你的代理不应仅仅由其提示指令定义,它还应该有自己的上下文配置。通过使用简单的维度来编码每个代理不同的配置,将极大地增强你通过代理实现的功能。

动态上下文只是你的代理架构中的一个元素。如果你想了解更多,欢迎邀请我讨论。如果你有任何问题或其他见解,可以在评论区与我交流,当然,如果你觉得这篇文章对你有帮助,也请给我点个赞或关注我。

编程愉快!

从你的 Rust 代码访问云文件的九个规则

原文:towardsdatascience.com/nine-rules-for-accessing-cloud-files-from-your-rust-code-d456c1e2ceb4?source=collection_archive---------7-----------------------#2024-02-07

升级 Bed-Reader:来自生物信息学库的实践经验

Carl M. KadieTowards Data Science Carl M. Kadie

·发表于 Towards Data Science ·阅读时长 21 分钟·2024 年 2 月 7 日

--

Rust 和 Python 直接从云端读取 DNA 数据 — 来源:openai.com/dall-e-2/。所有其他图片来自作者。

你希望你的 Rust 程序能够无缝地访问云端的文件数据吗?当我提到“云端文件”时,我指的是存储在 Web 服务器或云存储解决方案(如 AWS S3、Azure Blob Storage 或 Google Cloud Storage)中的数据。这里所说的“读取”包含了对文件内容的顺序检索——无论是文本还是二进制数据,从头到尾——并且具有根据需要定位并提取文件中特定部分的能力。

将程序升级以访问云文件,可以减少烦恼和复杂性:不再需要将文件下载到本地存储,也不再需要定期检查本地副本是否为最新。

不幸的是,将你的程序升级以访问云文件也可能会增加烦恼和复杂性:URLs 和凭证信息带来的烦恼,以及异步编程的复杂性。

Bed-Reader 是一个用于读取 PLINK Bed 文件的 Python 包和 Rust crate,这是一种在生物信息学中用于存储基因型(DNA)数据的二进制格式。应用户的要求,我最近更新了 Bed-Reader,使其能够选择性地直接从云存储读取数据。在此过程中,我总结了九条规则,可以帮助你为程序添加云文件支持。这些规则是:

  1. 使用 crate [object_store](https://crates.io/crates/object_store)(以及可能的 [cloud-file](https://crates.io/crates/cloud-file))按顺序读取云文件的字节。

  2. 通过两个嵌套的循环顺序读取云文件中的文本行。

  3. 随机访问云文件,即使是非常大的文件,也可以使用“范围”方法,同时遵守服务器施加的限制。

  4. 使用 URL 字符串和选项字符串访问 HTTP、本地文件、AWS S3、Azure 和 Google Cloud。

  5. 通过 [tokio](https://crates.io/crates/tokio)::test 测试 HTTP 和本地文件。

如果其他程序调用你的程序——换句话说,如果你的程序提供了一个 API(应用程序接口)——则有四条额外的规则适用:

6. 为了获得最佳性能,可以通过异步 API 将云文件支持添加到你的 Rust 库中。

7. 或者,为了最大程度的方便,可以通过传统的(“同步”)API 将云文件支持添加到你的 Rust 库中。

8. 通过使用文档测试中的隐藏行,遵循良好的 API 设计规则。

9. 包括一个运行时,但可以选择性地使用。

顺便提一下:为了避免模糊不清,我称这些为“规则”,但它们当然只是建议。

规则 1:使用 crate object_store(以及可能的 cloud-file)顺序读取云文件的字节。

强大的[object_store](https://crates.io/crates/object_store) crate 提供对存储在 HTTP、AWS S3、Azure、Google Cloud 和本地文件中的文件的完整内容访问。它是 Apache Arrow 项目的一部分,已下载超过 240 万次。

对于本文,我还创建了一个名为 [cloud-file](https://crates.io/crates/cloud-file) 的新 crate。它简化了 object_store crate 的使用。它包装并专注于 object_store 的一个有用子集。你可以直接使用它,或者将其代码提取出来供你自己使用。

让我们来看一个例子。我们通过计算云文件中包含的换行符数量来统计文件的行数。

use cloud_file::{CloudFile, CloudFileError};
use futures_util::StreamExt; // Enables `.next()` on streams.

async fn count_lines(cloud_file: &CloudFile) -> Result<usize, CloudFileError> {
    let mut chunks = cloud_file.stream_chunks().await?;
    let mut newline_count: usize = 0;
    while let Some(chunk) = chunks.next().await {
        let chunk = chunk?;
        newline_count += bytecount::count(&chunk, b'\n');
    }
    Ok(newline_count)
}

#[tokio::main]
async fn main() -> Result<(), CloudFileError> {
    let url = "https://raw.githubusercontent.com/fastlmm/bed-sample-files/main/toydata.5chrom.fam";
    let options = [("timeout", "10s")];
    let cloud_file = CloudFile::new_with_options(url, options)?;
    let line_count = count_lines(&cloud_file).await?;
    println!("line_count: {line_count}");
    Ok(())
}

当我们运行这段代码时,它返回:

line_count: 500

一些要点:

  • 我们使用 async(在这里使用的是[tokio](https://docs.rs/tokio/latest/tokio/))。我们将在规则 6 和 7 中进一步讨论这一选择。

  • 我们通过 CloudFile::new_with_options(url, options)? 将 URL 字符串和选项字符串转换为 CloudFile 实例。我们使用 ? 来捕获格式错误的 URL)。

  • 我们通过 cloud_file.stream_chunks().await? 创建一个二进制块流。这是代码首次尝试访问云文件的地方。如果文件不存在或无法打开,? 会返回一个错误。

  • 我们使用 chunks.next().await 来获取文件的下一个二进制块。(请注意 use futures_util::StreamExt;。)next 方法在所有块被检索完后返回 None

  • 如果确实有下一个块,但在检索时出现问题呢?我们将通过 let chunk = chunk?; 捕获任何问题。

  • 最后,我们使用快速的 [bytecount](https://docs.rs/bytecount/latest/bytecount/) crate 来计算换行符的数量。

与这种云解决方案相对比,想想你会如何为本地文件编写一个简单的行计数器。你可能会写成这样:

use std::fs::File;
use std::io::{self, BufRead, BufReader};

fn main() -> io::Result<()> {
    let path = "examples/line_counts_local.rs";
    let reader = BufReader::new(File::open(path)?);
    let mut line_count = 0;
    for line in reader.lines() {
        let _line = line?;
        line_count += 1;
    }
    println!("line_count: {line_count}");
    Ok(())
}

cloud-file版本和本地文件版本之间,有三个突出差异。首先,我们可以轻松地将本地文件作为文本读取。默认情况下,我们将云文件作为二进制文件读取(但请参见规则 2)。其次,默认情况下,我们同步读取本地文件,直到完成才会阻塞程序执行。另一方面,我们通常异步访问云文件,这样在等待相对较慢的网络访问完成时,程序的其他部分仍然可以继续运行。第三,像lines()这样的迭代器支持for循环。然而,像stream_chunks()这样的流则不支持,所以我们使用while let

我之前提到过,你不需要使用cloud-file包装器,而可以直接使用object_store crate。接下来,让我们看看只使用object_store方法时,如何计算云文件中的换行符:

use futures_util::StreamExt;  // Enables `.next()` on streams.
pub use object_store::path::Path as StorePath;
use object_store::{parse_url_opts, ObjectStore};
use std::sync::Arc;
use url::Url;

async fn count_lines(
    object_store: &Arc<Box<dyn ObjectStore>>,
    store_path: StorePath,
) -> Result<usize, anyhow::Error> {
    let mut chunks = object_store.get(&store_path).await?.into_stream();
    let mut newline_count: usize = 0;
    while let Some(chunk) = chunks.next().await {
        let chunk = chunk?;
        newline_count += bytecount::count(&chunk, b'\n');
    }
    Ok(newline_count)
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    let url = "https://raw.githubusercontent.com/fastlmm/bed-sample-files/main/toydata.5chrom.fam";
    let options = [("timeout", "10s")];

    let url = Url::parse(url)?;
    let (object_store, store_path) = parse_url_opts(&url, options)?;
    let object_store = Arc::new(object_store); // enables cloning and borrowing
    let line_count = count_lines(&object_store, store_path).await?;
    println!("line_count: {line_count}");
    Ok(())
}

你会发现代码与cloud-file代码非常相似。不同之处在于:

  • 与单个CloudFile输入不同,大多数方法需要两个输入:ObjectStoreStorePath。因为ObjectStore是一个不可克隆的特性,这里count_lines函数专门使用了&Arc<Box<dyn ObjectStore>>。或者,我们可以将函数设为泛型,并使用&Arc<impl ObjectStore>

  • 创建ObjectStore实例、StorePath实例和流需要比创建CloudFile实例和流多一些步骤。

  • 不再只处理一种错误类型(即CloudFileError),而是可能出现多种错误类型,因此我们退回使用了[anyhow](https://crates.io/crates/anyhow) crate。

无论你是直接使用object_store(目前下载量为 240 万次),还是通过cloud-file间接使用(目前下载量为 124 次 😀),都由你决定。

在本文的其余部分,我将重点讨论cloud-file。如果你想将cloud-file方法转换为纯object_store代码,可以查阅该方法的文档,并点击“source”链接。源代码通常只有一两行。

我们已经了解了如何顺序读取云文件的字节。接下来,让我们看看如何顺序读取它的行。

规则 2:通过两个嵌套循环顺序读取云文件的文本行。

我们经常需要顺序读取云文件的行。使用cloud-file(或object_store)来实现这一点需要两个嵌套循环。

外部循环像以前一样返回二进制块,但有一个关键的修改:我们现在确保每个块只包含完整的行,从行的第一个字符开始,到换行符结束。换句话说,块可能包含一行或多行完整的内容,但没有部分行。内部循环将块转换为文本,并迭代生成的一行或多行。

在这个例子中,给定一个云文件和一个数字n,我们找出索引位置n的那一行:

use cloud_file::CloudFile;
use futures::StreamExt;  // Enables `.next()` on streams.
use std::str::from_utf8;

async fn nth_line(cloud_file: &CloudFile, n: usize) -> Result<String, anyhow::Error> {
    // Each binary line_chunk contains one or more lines, that is, each chunk ends with a newline.
    let mut line_chunks = cloud_file.stream_line_chunks().await?;
    let mut index_iter = 0usize..;
    while let Some(line_chunk) = line_chunks.next().await {
        let line_chunk = line_chunk?;
        let lines = from_utf8(&line_chunk)?.lines();
        for line in lines {
            let index = index_iter.next().unwrap(); // safe because we know the iterator is infinite
            if index == n {
                return Ok(line.to_string());
            }
        }
    }
    Err(anyhow::anyhow!("Not enough lines in the file"))
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    let url = "https://raw.githubusercontent.com/fastlmm/bed-sample-files/main/toydata.5chrom.fam";
    let n = 4;

    let cloud_file = CloudFile::new(url)?;
    let line = nth_line(&cloud_file, n).await?;
    println!("line at index {n}: {line}");
    Ok(())
}

这段代码打印:

line at index 4: per4 per4 0 0 2 0.452591

一些值得注意的要点:

  • 关键方法是.stream_line_chunks()

  • 我们还必须调用std::str::from_utf8来创建文本。(可能会返回一个[Utf8Error](https://doc.rust-lang.org/std/str/struct.Utf8Error.html)。)此外,我们调用.lines()方法来创建一个行迭代器。

  • 如果我们想要行索引,必须自己创建。这里我们使用:

let mut index_iter = 0usize..;
...
let index = index_iter.next().unwrap(); // safe because we know the iterator is infinite

旁白:为什么要用两个循环?为什么cloud-file不定义一个返回每次一行的流?因为我不知道怎么做。如果有人能搞明白,请发送一个包含解决方案的 pull 请求给我!

我希望这能更简单些。我很高兴它是高效的。让我们通过下一步来回归简洁,看看如何随机访问云文件。

规则 3:使用范围方法随机访问云文件,即使是巨大的文件,同时尊重服务器设置的限制。

我在处理一种叫做 PLINK Bed 1.9 的基因组学文件格式。文件最大可达 1 TB。是不是太大,无法通过网络访问?不一定。有时候我们只需要文件的一小部分。此外,现代云服务(包括大多数网络服务器)可以高效地从云文件中提取感兴趣的区域。

让我们看一个例子。这个测试代码使用了一个名为read_range_and_file_sizeCloudFile方法。它读取一个*.bed 文件的前 3 个字节,检查文件是否以预期的字节开头,然后检查文件的预期长度。

#[tokio::test]
async fn check_file_signature() -> Result<(), CloudFileError> {
    let url = "https://raw.githubusercontent.com/fastlmm/bed-sample-files/main/plink_sim_10s_100v_10pmiss.bed";
    let cloud_file = CloudFile::new(url)?;
    let (bytes, size) = cloud_file.read_range_and_file_size(0..3).await?;

    assert_eq!(bytes.len(), 3);
    assert_eq!(bytes[0], 0x6c);
    assert_eq!(bytes[1], 0x1b);
    assert_eq!(bytes[2], 0x01);
    assert_eq!(size, 303);
    Ok(())
}

注意,在一次网络调用中,这个方法不仅返回请求的字节,还返回了整个文件的大小。

这里是一些高层次的CloudFile方法及它们在一次网络调用中可以检索的内容:

  • [read_all](https://docs.rs/cloud-file/0.1.0-beta.2/cloud_file/struct.CloudFile.html#method.read_all) — 作为内存中的[Bytes](https://docs.rs/bytes/latest/bytes/struct.Bytes.html)返回的整个文件内容

  • [read_range](https://docs.rs/cloud-file/0.1.0-beta.2/cloud_file/struct.CloudFile.html#method.read_range) — 从指定范围读取的[Bytes](https://docs.rs/bytes/latest/bytes/struct.Bytes.html)

  • [read_ranges](https://docs.rs/cloud-file/0.1.0-beta.2/cloud_file/struct.CloudFile.html#method.read_ranges) — 从指定范围读取的Vec类型的[Bytes](https://docs.rs/bytes/latest/bytes/struct.Bytes.html)

  • [read_range_and_file_size](https://docs.rs/cloud-file/0.1.0-beta.2/cloud_file/struct.CloudFile.html#method.read_range_and_file_size) — 从指定范围读取的[Bytes](https://docs.rs/bytes/latest/bytes/struct.Bytes.html)和文件的大小

  • [read_file_size](https://docs.rs/cloud-file/0.1.0-beta.2/cloud_file/struct.CloudFile.html#method.read_file_size) — 文件的大小

如果我们一次请求太多数据,这些方法可能会遇到两个问题。首先,我们的云服务可能会限制每次调用能检索的字节数。其次,通过同时发出多个请求而不是一次发一个请求,我们可能会得到更快的结果。

考虑这个例子:我们想要收集一个任意大小的文件中相邻 ASCII 字符的频率统计。例如,在一个包含 10,000 个相邻字符的随机样本中,或许“th”出现了 171 次。

假设我们的 Web 服务器支持 10 个并发请求,但每个请求只允许我们获取 750 字节。(8 MB 会是一个更常见的限制)。

感谢 Seattle Rust Meetup 的 Ben Lichtman(B3NNY)指引我正确的方向,帮助我向异步流添加了限制。

我们的主函数可能长这样:

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    let url = "https://www.gutenberg.org/cache/epub/100/pg100.txt";
    let options = [("timeout", "30s")];
    let cloud_file = CloudFile::new_with_options(url, options)?;

    let seed = Some(0u64);
    let sample_count = 10_000;
    let max_chunk_bytes = 750; // 8_000_000 is a good default when chunks are bigger.
    let max_concurrent_requests = 10; // 10 is a good default

    count_bigrams(
        cloud_file,
        sample_count,
        seed,
        max_concurrent_requests,
        max_chunk_bytes,
    )
    .await?;

    Ok(())
}

count_bigrams函数可以首先创建一个随机数生成器,并调用来查找云文件的大小:

#[cfg(not(target_pointer_width = "64"))]
compile_error!("This code requires a 64-bit target architecture.");

use cloud_file::CloudFile;
use futures::pin_mut;
use futures_util::StreamExt; // Enables `.next()` on streams.
use rand::{rngs::StdRng, Rng, SeedableRng};
use std::{cmp::max, collections::HashMap, ops::Range};

async fn count_bigrams(
    cloud_file: CloudFile,
    sample_count: usize,
    seed: Option<u64>,
    max_concurrent_requests: usize,
    max_chunk_bytes: usize,
) -> Result<(), anyhow::Error> {
    // Create a random number generator
    let mut rng = if let Some(s) = seed {
        StdRng::seed_from_u64(s)
    } else {
        StdRng::from_entropy()
    };

    // Find the document size
    let file_size = cloud_file.read_file_size().await?;
//...

接下来,根据文件大小,函数可以创建一个包含 10,000 个随机两字节范围的向量。

 // Randomly choose the two-byte ranges to sample
    let range_samples: Vec<Range<usize>> = (0..sample_count)
        .map(|_| rng.gen_range(0..file_size - 1))
        .map(|start| start..start + 2)
        .collect();

例如,它可能生成以下向量[4122418..4122420, 4361192..4361194, 145726..145728,]。但一次性获取 20,000 字节(我们假设这样)太多了。所以,我们将向量分成 27 个块,每个块不超过 750 字节:

 // Divide the ranges into chunks respecting the max_chunk_bytes limit
    const BYTES_PER_BIGRAM: usize = 2;
    let chunk_count = max(1, max_chunk_bytes / BYTES_PER_BIGRAM);
    let range_chunks = range_samples.chunks(chunk_count);

使用一些异步魔法,我们为每个 27 个块创建一个未来工作的迭代器,然后将该迭代器转换为流。我们告诉流最多同时调用 10 个请求。此外,我们还表示可以接受乱序的结果。

 // Create an iterator of future work
    let work_chunks_iterator = range_chunks.map(|chunk| {
        let cloud_file = cloud_file.clone(); // by design, clone is cheap
        async move { cloud_file.read_ranges(chunk).await }
    });

    // Create a stream of futures to run out-of-order and with constrained concurrency.
    let work_chunks_stream =
        futures_util::stream::iter(work_chunks_iterator).buffer_unordered(max_concurrent_requests);
    pin_mut!(work_chunks_stream); // The compiler says we need this

在最后一段代码中,我们首先在流中进行工作,并且—随着结果的到来—进行汇总。最后,我们对结果进行排序并打印出最好的结果。

 // Run the futures and, as result bytes come in, tabulate.
    let mut bigram_counts = HashMap::new();
    while let Some(result) = work_chunks_stream.next().await {
        let bytes_vec = result?;
        for bytes in bytes_vec.iter() {
            let bigram = (bytes[0], bytes[1]);
            let count = bigram_counts.entry(bigram).or_insert(0);
            *count += 1;
        }
    }

    // Sort the bigrams by count and print the top 10
    let mut bigram_count_vec: Vec<(_, usize)> = bigram_counts.into_iter().collect();
    bigram_count_vec.sort_by(|a, b| b.1.cmp(&a.1));
    for (bigram, count) in bigram_count_vec.into_iter().take(10) {
        let char0 = (bigram.0 as char).escape_default();
        let char1 = (bigram.1 as char).escape_default();
        println!("Bigram ('{}{}') occurs {} times", char0, char1, count);
    }
    Ok(())
}

输出是:

Bigram ('\r\n') occurs 367 times
Bigram ('e ') occurs 221 times
Bigram (' t') occurs 184 times
Bigram ('th') occurs 171 times
Bigram ('he') occurs 158 times
Bigram ('s ') occurs 143 times
Bigram ('.\r') occurs 136 times
Bigram ('d ') occurs 133 times
Bigram (', ') occurs 127 times
Bigram (' a') occurs 121 times

Bed-Reader 基因组学 crate 的代码使用相同的技术来从分散的 DNA 区域获取信息。当 DNA 信息到达时,可能是乱序的,代码会填充输出数组的正确列。

顺便提一下:此方法使用了迭代器、流和循环。我希望它能更简单。如果你能找到一种更简单的方法来获取区域的向量,同时限制最大块大小和并发请求数,请发给我一个 pull request。

这涵盖了访问存储在 HTTP 服务器上的文件,但 AWS S3 和其他云服务呢?本地文件怎么办?

规则 4:使用 URL 字符串和选项字符串来访问 HTTP、本地文件、AWS S3、Azure 和 Google Cloud。

object_store crate(以及cloud-file包装 crate)支持通过 URL 字符串或结构体指定文件。我建议使用 URL 字符串,但选择权在你。

让我们考虑一个 AWS S3 的示例。如你所见,AWS 访问需要凭证信息。

use cloud_file::CloudFile;
use rusoto_credential::{CredentialsError, ProfileProvider, ProvideAwsCredentials};

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    // get credentials from ~/.aws/credentials
    let credentials = if let Ok(provider) = ProfileProvider::new() {
        provider.credentials().await
    } else {
        Err(CredentialsError::new("No credentials found"))
    };

    let Ok(credentials) = credentials else {
        eprintln!("Skipping example because no AWS credentials found");
        return Ok(());
    };

    let url = "s3://bedreader/v1/toydata.5chrom.bed";
    let options = [
        ("aws_region", "us-west-2"),
        ("aws_access_key_id", credentials.aws_access_key_id()),
        ("aws_secret_access_key", credentials.aws_secret_access_key()),
    ];
    let cloud_file = CloudFile::new_with_options(url, options)?;

    assert_eq!(cloud_file.read_file_size().await?, 1_250_003);
    Ok(())
}

关键部分是:

 let url = "s3://bedreader/v1/toydata.5chrom.bed";
    let options = [
        ("aws_region", "us-west-2"),
        ("aws_access_key_id", credentials.aws_access_key_id()),
        ("aws_secret_access_key", credentials.aws_secret_access_key()),
    ];
    let cloud_file = CloudFile::new_with_options(url, options)?;

如果我们希望使用结构体而不是 URL 字符串,则变为:

 use object_store::{aws::AmazonS3Builder, path::Path as StorePath};

    let s3 = AmazonS3Builder::new()
        .with_region("us-west-2")
        .with_bucket_name("bedreader")
        .with_access_key_id(credentials.aws_access_key_id())
        .with_secret_access_key(credentials.aws_secret_access_key())
        .build()?;
    let store_path = StorePath::parse("v1/toydata.5chrom.bed")?;
    let cloud_file = CloudFile::from_structs(s3, store_path);

我更喜欢 URL 方法而不是结构体。我发现 URL 稍微简单一些,更加统一,跨云服务时也更容易互操作(例如与 Python)。

这里是我使用的三个 Web 服务的示例 URL:

  • HTTP — [www.gutenberg.org/cache/epub/100/pg100.txt](https://www.gutenberg.org/cache/epub/100/pg100.txt)

  • 本地文件 — file:///M:/data%20files/small.bed — 使用cloud_file::abs_path_to_url_string函数将完整的文件路径正确编码为 URL。

  • AWS S3 — s3://bedreader/v1/toydata.5chrom.bed

本地文件不需要选项。对于其他服务,这里是它们支持的选项和一些示例链接:

  • HTTP — [ClientConfigKey](https://docs.rs/object_store/latest/object_store/enum.ClientConfigKey.html#variant.Timeout)[("timeout", "30s")]

  • AWS S3 — [AmazonS3ConfigKey](https://docs.rs/object_store/latest/object_store/aws/enum.AmazonS3ConfigKey.html)[("aws_region", "us-west-2"), ("aws_access_key_id",), ("aws_secret_access_key",)]

  • Azure — [AzureConfigKey](https://docs.rs/object_store/latest/object_store/azure/enum.AzureConfigKey.html)

  • Google — [GoogleConfigKey](https://docs.rs/object_store/latest/object_store/gcp/enum.GoogleConfigKey.html)

现在我们可以指定和读取云文件,接下来我们应该创建测试。

规则 5:通过 tokio::test 对 HTTP 文件和本地文件进行测试。

object_store crate(以及 cloud-file)支持任何异步运行时。为了测试,Tokio 运行时使得在云文件上测试代码变得简单。下面是一个针对 HTTP 文件的测试:

[tokio::test]
async fn cloud_file_extension() -> Result<(), CloudFileError> {
    let url = "https://raw.githubusercontent.com/fastlmm/bed-sample-files/main/plink_sim_10s_100v_10pmiss.bed";
    let mut cloud_file = CloudFile::new(url)?;
    assert_eq!(cloud_file.read_file_size().await?, 303);
    cloud_file.set_extension("fam")?;
    assert_eq!(cloud_file.read_file_size().await?, 130);
    Ok(())
}

运行此测试命令:

cargo test

如果您不想通过测试访问外部 Web 服务器,您可以改为将本地文件当作云文件进行测试。

#[tokio::test]
async fn local_file() -> Result<(), CloudFileError> {
    use std::env;

    let apache_url = abs_path_to_url_string(env::var("CARGO_MANIFEST_DIR").unwrap()
             + "/LICENSE-APACHE")?;
    let cloud_file = CloudFile::new(&apache_url)?;
    assert_eq!(cloud_file.read_file_size().await?, 9898);
    Ok(())
}

这使用标准的 Rust 环境变量 [CARGO_MANIFEST_DIR](https://doc.rust-lang.org/cargo/reference/environment-variables.html) 来查找文本文件的完整路径。然后,它使用 cloud_file::abs_path_to_url_string 将该完整路径正确编码为 URL。

无论是在 HTTP 文件还是本地文件上进行测试,object_store 的强大功能意味着您的代码应该能在任何云服务上运行,包括 AWS S3、Azure 和 Google Cloud。

如果您只需要访问云文件供自己使用,您可以在这里停止阅读规则并跳到结论部分。如果您是为他人添加云访问到一个库(Rust crate),请继续阅读。

规则 6:为了获得最佳性能,通过异步 API 将云文件支持添加到您的 Rust 库中。

如果您提供 Rust crate 给他人使用,支持云文件为您的用户提供了极大的便利,但也有一定成本。让我们来看看我为 Bed-Reader 添加了云支持的基因组学 crate。

如前所述,Bed-Reader 是一个用于读取和写入 PLINK Bed 文件的库,PLINK Bed 文件是一种在生物信息学中用于存储基因型(DNA)数据的二进制格式。Bed 格式的文件可以大到一个 TB。Bed-Reader 为用户提供对大量数据子集的快速随机访问。它返回一个二维数组,用户可以选择 int8、float32 或 float64 格式。Bed-Reader 还为用户提供了 12 个元数据字段,其中六个与个体相关,六个与 SNP(大致来说,是 DNA 位置)相关。基因型数据通常比元数据大 100,000 倍。

PLINK 存储基因型数据和元数据。(图由作者提供。)

旁注:在此上下文中,“API”指的是应用程序编程接口。它是由诸如 Bed-Reader 之类的库代码提供的公共结构、方法等,供其他程序调用。

这里是使用 Bed-Reader 原始“本地文件”API 的示例代码。该代码列出了前五个个体 ID,前五个 SNP ID,以及每个独特的染色体编号。然后,它读取染色体 5 中的每个基因组值:

#[test]
fn lib_intro() -> Result<(), Box<BedErrorPlus>> {
    let file_name = sample_bed_file("some_missing.bed")?;

    let mut bed = Bed::new(file_name)?;
    println!("{:?}", bed.iid()?.slice(s![..5])); // Outputs ndarray: ["iid_0", "iid_1", "iid_2", "iid_3", "iid_4"]
    println!("{:?}", bed.sid()?.slice(s![..5])); // Outputs ndarray: ["sid_0", "sid_1", "sid_2", "sid_3", "sid_4"]
    println!("{:?}", bed.chromosome()?.iter().collect::<HashSet<_>>());
    // Outputs: {"12", "10", "4", "8", "19", "21", "9", "15", "6", "16", "13", "7", "17", "18", "1", "22", "11", "2", "20", "3", "5", "14"}
    let _ = ReadOptions::builder()
        .sid_index(bed.chromosome()?.map(|elem| elem == "5"))
        .f64()
        .read(&mut bed)?;

    Ok(())
}

这是使用新云文件 API 的相同代码:

#[tokio::test]
async fn cloud_lib_intro() -> Result<(), Box<BedErrorPlus>> {
    let url = "https://raw.githubusercontent.com/fastlmm/bed-sample-files/main/some_missing.bed";
    let cloud_options = [("timeout", "10s")];

    let mut bed_cloud = BedCloud::new_with_options(url, cloud_options).await?;
    println!("{:?}", bed_cloud.iid().await?.slice(s![..5])); // Outputs ndarray: ["iid_0", "iid_1", "iid_2", "iid_3", "iid_4"]
    println!("{:?}", bed_cloud.sid().await?.slice(s![..5])); // Outputs ndarray: ["sid_0", "sid_1", "sid_2", "sid_3", "sid_4"]
    println!(
        "{:?}",
        bed_cloud.chromosome().await?.iter().collect::<HashSet<_>>()
    );
    // Outputs: {"12", "10", "4", "8", "19", "21", "9", "15", "6", "16", "13", "7", "17", "18", "1", "22", "11", "2", "20", "3", "5", "14"}
    let _ = ReadOptions::builder()
        .sid_index(bed_cloud.chromosome().await?.map(|elem| elem == "5"))
        .f64()
        .read_cloud(&mut bed_cloud)
        .await?;

    Ok(())
}

当切换到云数据时,Bed-Reader 用户必须做出以下更改:

  • 他们必须在一个异步环境中运行,这里是 #[tokio::test]

  • 他们必须使用一个新的结构体 BedCloud,而不是 Bed。(另外,未展示的是,使用 BedCloudBuilder 而不是 BedBuilder。)

  • 它们提供了一个 URL 字符串和可选的字符串选项,而不是本地文件路径。

  • 他们必须在许多地方使用 .await,这些地方往往是不可预测的。(幸运的是,如果他们漏掉了某个地方,编译器会给出很好的错误提示。)

  • ReadOptionsBuilder 增加了一个新方法 read_cloud,用于配合之前的 read 方法。

从库开发者的角度来看,添加新的 BedCloudBedCloudBuilder 结构体需要增加大量的主代码和测试代码。在我的情况下,是 2200 行新的主代码和 2400 行新的测试代码。

旁注:此外,还可以参考 Mario Ortiz Manero 的文章 “我生命中的痛苦:在 Rust 中支持异步和同步代码”。

用户从这些改动中获得的好处是可以利用异步的高效性从云文件中读取数据。

这种好处值得吗?如果不值得,接下来我们将看一下替代方案。

规则 7:或者,为了最大程度的便利,通过传统的(“同步”)API 向你的 Rust 库添加云文件支持。

如果为你添加一个高效的异步 API 看起来太麻烦,或者对你的用户来说太混乱,那也有替代方案。也就是说,你可以提供一个传统的(“同步”)API。我在 Bed-Reader 的 Python 版本和支持该 Python 版本的 Rust 代码中就是这么做的。

旁注:见:用 Rust 编写 Python 扩展的九条规则:从升级 Python 生物信息学包 Bed-Reader 中获得的实际经验教训,《Towards Data Science》。

这里是 Python 调用的 Rust 函数,用于检查一个 *.bed 文件是否以正确的文件签名开始。

use tokio::runtime;
// ...
    #[pyfn(m)]
    fn check_file_cloud(location: &str, options: HashMap<&str, String>) -> Result<(), PyErr> {
        runtime::Runtime::new()?.block_on(async {
            BedCloud::new_with_options(location, options).await?;
            Ok(())
        })
    }

请注意,这不是一个异步函数。它是一个普通的“同步”函数。在这个同步函数内部,Rust 进行了一个异步调用:

BedCloud::new_with_options(location, options).await?;

我们通过将异步调用包装在 Tokio 运行时中来使其变为同步:

use tokio::runtime;
// ...

runtime::Runtime::new()?.block_on(async {
    BedCloud::new_with_options(location, options).await?;
    Ok(())
})

Bed-Reader 的 Python 用户之前可以使用命令 open_bed(file_name_string) 打开一个本地文件进行读取。现在,他们也可以用相同的命令 open_bed(url_string) 打开一个云文件进行读取。唯一的区别是他们传入的字符串格式。

这是规则 6 中的示例,使用更新后的 Python API 的 Python 代码:

 with open_bed(
      "https://raw.githubusercontent.com/fastlmm/bed-sample-files/main/some_missing.bed",
      cloud_options={"timeout": "30s"},
  ) as bed:
      print(bed.iid[:5])
      print(bed.sid[:5])
      print(np.unique(bed.chromosome))
      val = bed.read(index=np.s_[:, bed.chromosome == "5"])
      print(val.shape)

请注意,Python API 还提供了一个名为 cloud_options 的新可选参数。此外,幕后有一小段新代码,区分了表示本地文件和表示 URL 的字符串。

在 Rust 中,你也可以使用相同的技巧使 object_cloud 的调用变为同步。具体而言,你可以将异步调用包装在运行时中。好处是接口更简单,库代码更少。代价是效率比提供异步 API 要低。

如果你决定放弃“同步”替代方案,选择提供异步 API,你会发现一个新问题:如何在文档中提供异步示例。接下来我们将讨论这个问题。

规则 8:通过在文档测试中使用隐藏行,遵循良好的 API 设计规则。

文章中的所有规则 优雅 Rust 库 API 的九条规则:从将 Bed-Reader(一种生物信息学库)从 Python 移植到 Rust 中得到的实践经验 适用,特别是以下两条:

*编写良好的文档,保持设计的诚实。

创建不会让你尴尬的示例。*

这表明我们应该在文档中给出示例,但如何在异步方法和 await 中实现这一点呢?窍门在于在我们的文档测试中使用“隐藏行”。例如,以下是 [CloudFile::read_ranges](https://docs.rs/cloud-file/0.1.0-beta.2/cloud_file/struct.CloudFile.html#method.read_ranges) 的文档:

 /// Return the `Vec` of [`Bytes`](https://docs.rs/bytes/latest/bytes/struct.Bytes.html) from specified ranges.
    ///
    /// # Example
    /// ```

    /// 使用 cloud_file::CloudFile;

    ///

    /// # Runtime::new().unwrap().block_on(async {

    /// let url = "https://raw.githubusercontent.com/fastlmm/bed-sample-files/main/plink_sim_10s_100v_10pmiss.bim";

    /// let cloud_file = CloudFile::new(url)?;

    /// let bytes_vec = cloud_file.read_ranges(&[0..10, 1000..1010]).await?;

    /// assert_eq!(bytes_vec.len(), 2);

    /// assert_eq!(bytes_vec[0].as_ref(), b"1\t1:1:A:C\t");

    /// assert_eq!(bytes_vec[1].as_ref(), b":A:C\t0.0\t4");

    /// # Ok::<(), CloudFileError>(())}).unwrap();

    /// # 使用 {tokio::runtime::Runtime, cloud_file::CloudFileError};

    /// ```py

文档测试从 py` ```py. Within the doc test, lines starting with /// #` disappear from the documentation:

The hidden lines, however, will still be run by cargo test.

In my library crates, I try to include a working example with every method. If such an example turns out overly complex or otherwise embarrassing, I try to fix the issue by improving the API.

Notice that in this rule and the previous Rule 7, we added a runtime to the code. Unfortunately, including a runtime can easily double the size of your user’s programs, even if they don’t read files from the cloud. Making this extra size optional is the topic of Rule 9.

Rule 9: Include a runtime, but optionally.

If you follow Rule 6 and provide async methods, your users gain the freedom to choose their own runtime. Opting for a runtime like Tokio may significantly increase their compiled program’s size. However, if they use no async methods, selecting a runtime becomes unnecessary, keeping the compiled program lean. This embodies the “zero cost principle”, where one incurs costs only for the features one uses.

On the other hand, if you follow Rule 7 and wrap async calls inside traditional, “synchronous” methods, then you must provide a runtime. This will increase the size of the resultant program. To mitigate this cost, you should make the inclusion of any runtime optional.

Bed-Reader includes a runtime under two conditions. First, when used as a Python extension. Second, when testing the async methods. To handle the first condition, we create a Cargo feature called extension-module that pulls in optional dependencies pyo3 and tokio. Here are the relevant sections of Cargo.toml:


[功能]

extension-module = ["pyo3/extension-module", "tokio/full"]

默认 = []

[依赖]

#...

pyo3 = { version = "0.20.0", features = ["extension-module"], optional = true }

tokio = { version = "1.35.0", features = ["full"], optional = true }

```py

Also, because I’m using Maturin to create a Rust extension for Python, I include this text in `pyproject.toml`:

[tool.maturin]

features = ["extension-module"]


I put all the Rust code related to extending Python in a file called `python_modules.rs`. It starts with this [conditional compilation attribute](https://doc.rust-lang.org/reference/conditional-compilation.html#the-cfg-attribute):

![cfg(feature = "extension-module")] // 如果特性未开启则忽略文件


This starting line ensures that the compiler includes the extension code only when needed.

With the Python extension code taken care of, we turn next to providing an optional runtime for testing our async methods. I again choose Tokio as the runtime. I put the tests for the async code in their own file called `tests_api_cloud.rs`. To ensure that that async tests are run only when the `tokio` dependency feature is “on”, I start the file with this line:

![cfg(feature = "tokio")]


As per Rule 5, we should also include examples in our documentation of the async methods. These examples also serve as “doc tests”. The doc tests need conditional compilation attributes. Below is the documentation for the method that retrieves chromosome metadata. Notice that the example includes two hidden lines that start
`/// # #[cfg(feature = "tokio")]`

/// 每个 SNP(变异)的染色体

/// [...]

///

/// # 示例:

/// py /// use ndarray as nd; /// use bed_reader::{BedCloud, ReadOptions}; /// use bed_reader::assert_eq_nan; /// /// # #[cfg(feature = "tokio")] Runtime::new().unwrap().block_on(async { /// let url = "https://raw.githubusercontent.com/fastlmm/bed-sample-files/main/small.bed"; /// let mut bed_cloud = BedCloud::new(url).await?; /// let chromosome = bed_cloud.chromosome().await?; /// println!("{chromosome:?}"); // Outputs ndarray ["1", "1", "5", "Y"] /// # Ok::<(), Box<BedErrorPlus>>(())}).unwrap(); /// # #[cfg(feature = "tokio")] use {tokio::runtime::Runtime, bed_reader::BedErrorPlus}; ///


在这个文档测试中,当 `tokio` 功能“开启”时,示例使用 `tokio` 并在 Tokio 运行时中运行四行代码。当 `tokio` 功能“关闭”时,`#[cfg(feature = "tokio")]` 块中的代码消失,从而有效地跳过异步操作。

在格式化文档时,Rust 默认包含所有功能的文档,因此我们看到了这四行代码:

![](https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/4544233478a0a6911c16e85bb4a9ab4e.png)

总结规则 9:通过使用 Cargo 功能和条件编译,我们可以确保用户只为他们使用的功能付费。

# 结论

所以,就这样:九条规则帮助你在 Rust 程序中读取云文件。借助[`object_store`](https://docs.rs/object_store/latest/object_store/) crate 的强大功能,你的程序可以突破本地驱动器的限制,从 Web、AWS S3、Azure 和 Google Cloud 加载数据。为了让这一过程更简单,你还可以使用我为本文编写的全新[`cloud-file`](https://crates.io/crates/cloud-file)包装 crate。

我还应该提到,这篇文章仅探讨了`object_store`的一个子集功能。除了我们看到的,`object_store` crate 还处理写入文件和操作文件夹及子文件夹。另一方面,[`cloud-file`](https://crates.io/crates/cloud-file) crate 只处理读取文件。(但嘿,我很欢迎提交 Pull Request)。

你是否应该在程序中添加云文件支持?当然,这取决于。支持云文件为你的程序用户提供了巨大的便利。代价是使用/提供异步接口的额外复杂性。代价还包括像 Tokio 这样的运行时文件大小的增加。另一方面,我认为添加此类支持的工具已经非常好,而且尝试它们也很简单,所以不妨试试看!

感谢你与我一同踏上云端之旅。如果你选择支持云文件,我希望这些步骤能帮助你实现。

*请* [*关注 Carl 的 Medium 账号*](https://medium.com/@carlmkadie)*。我在 Rust 和 Python 的科学编程、机器学习和统计学方面写作。我通常每个月写一篇文章。*
posted @ 2025-01-09 18:54  绝不原创的飞龙  阅读(917)  评论(0)    收藏  举报