TowardsDataScience-2023-博客中文翻译-十-
TowardsDataScience 2023 博客中文翻译(十)
强化学习能否超越训练泛化?
模型泛化案例研究
·
关注 发表在 Towards Data Science ·6 min read·2023 年 1 月 24 日
--
论文中详细描述的项目,强化学习:模型泛化案例研究,探讨了一个使用强化学习(RL)训练的模型在泛化能力上的表现,即当遇到训练中未接触的数据时,能够产生可接受的结果。本研究中的应用是一个具有多个控制的工业过程,这些控制决定了产品在过程中的过渡效果。在这种环境中确定最佳控制设置可能具有挑战性。例如,当控制之间存在相互作用时,调整一个设置可能需要重新调整其他设置。此外,控制与其效果之间的复杂关系使得找到最佳解决方案变得更加困难。这里展示的结果表明,经过 RL 过程训练的模型在这种环境下表现良好,并能够泛化到不同于训练条件的情况下。
论文描述了一种 RL 模型,该模型被训练以寻找用于将电子组件焊接到电路板上的回流焊炉的最佳控制设置(图 1)。烤箱的移动传送带将产品(即电路板)运输通过多个加热区域。此过程根据所需的温度-时间目标曲线加热产品,以生产可靠的焊接连接。

图 1:烤箱传送带上的电路板(图像来自 Adobe,授权给 John Morrow)
人工操作员通常采取以下步骤来确定成功焊接电路板所需的加热器设置:
• 运行一次产品通过烤箱的过程
• 观察传感器读数中的温度-时间曲线
• 调整加热器设置以改善曲线,接近目标曲线
• 等待烤箱温度稳定到新的设置
• 重复此过程,直到传感器读数中的曲线足够接近目标曲线
学习策略
一个 RL 系统用两阶段过程取代了操作员步骤。在第一阶段,代理学习烤箱的动态,并在各种烤箱条件下创建更新加热器设置的策略。
由于在改变加热器设置并将产品通过烤箱后,需要相当长的时间来稳定烤箱的温度,因此使用了烤箱模拟器以加速学习过程。模拟器在几秒钟内模拟了产品通过加热曲线的单次过程,而物理烤箱则需要许多分钟。(第二部分提供了模拟器的详细信息。)
在每一轮学习阶段中,代理从其当前状态采取行动,向模拟器发送八个加热器的新设置。模拟运行后,模拟器报告产品温度读数(每秒采集三百个读数)。
代理的奖励基于返回读数与目标温度-时间曲线之间的差异。如果当前运行的差异小于之前的差异,则奖励为正;否则,奖励为负。一部分读数决定系统的新状态。代理通过从新状态中采取行动来开始学习阶段的下一轮。
使用策略进行规划
在第二阶段,代理按照学习到的策略寻找最佳加热器设置。这些设置将使实际产品曲线与目标温度-时间曲线之间的匹配最接近。图 2 展示了代理遵循策略找到最佳设置的结果。蓝色轨迹是目标温度-时间曲线,红色轨迹是由最佳设置产生的实际曲线。

图 2:示例规划结果 蓝色轨迹:目标曲线。红色轨迹:实际产品曲线。
强化学习系统
如上所述,RL 系统包括一个代理在环境中采取行动,以学习一个策略来实现目标。环境对每个行动作出回应,提供奖励以指示行动是否有助于实现目标。环境还会返回代理在环境中的状态。代理由两个神经网络组成:模型网络和目标网络。代理的目标是找到加热器设置,使得生产的产品时间-温度曲线与目标曲线非常接近。环境是回流炉模拟器。图 3 显示了 RL 系统的组成部分,文中对每个部分进行了详细描述。

图 3:强化学习系统
泛化:状态和奖励定义
状态和奖励定义对于 RL 模型在目标曲线和产品参数与训练期间使用的参数不同的新环境中的泛化能力至关重要。具体来说,状态和奖励都是根据产品和目标曲线温度之间的相对差异定义的,并通过允许的加热器值的最大范围进行规范化。
状态参数在八个加热器区域的中心定义。每个状态参数被定义为产品中心温度与每个加热器区域中心温度的规范化差异。
当代理执行一个动作时,环境会返回一个奖励,表明该动作在实现代理目标方面的有效性。奖励是基于该动作是否减少了实际温度与目标曲线之间的总温差。状态和奖励函数在论文中有更详细的描述。
结果
以下是论文中在不同产品材料配置和温度-时间曲线下运行规划过程的两个测试结果。所有测试都使用了一个模型神经网络,该网络经过以下产品和曲线参数的训练:

测试 1:基准测试
测试 1 作为测试模型性能的基准,使用与训练模型时相同的参数。以下是测试 1 的误差、最佳热区设置以及目标曲线与实际温度-时间图:

测试 1:蓝色曲线:目标曲线。红色曲线:实际产品曲线。
测试 6
测试 6 将产品从 FR4 更改为氧化铝(铝土矿 99%),更改了产品的尺寸,并且更改了曲线。烤箱参数值与测试 1 的基准相同,只是顶部和底部加热元件都处于活动状态。以下表格反映了该测试所用的曲线和产品参数(相对于基准训练参数的更改为粗体):

以下是测试 6 的误差、最佳热区设置以及目标曲线与实际温度/时间图:

结论
本项目展示了强化学习系统如何为控制复杂的工业过程提供解决方案。具体来说,强化学习系统成功学习了用于将电子组件焊接到电路板上的回流焊炉的最佳控制设置。此外,一旦训练完成,系统可以泛化到在与训练时不同要求的环境中产生可接受的结果。
论文链接:强化学习:模型泛化的案例研究
除非另有说明,所有图片均为作者所用。
合成数据能提升机器学习性能吗?
原文:
towardsdatascience.com/can-synthetic-data-boost-machine-learning-performance-6b4041e75dda
研究合成数据在不平衡数据集上提高模型性能的能力
·发表于Towards Data Science ·阅读时间 7 分钟·2023 年 7 月 5 日
--

图片由作者提供:由 Midjourney 生成
背景——不平衡数据集
在商业机器学习应用中,数据不平衡分类问题经常发生。你可能会在客户流失预测、欺诈检测、医疗诊断或垃圾邮件检测中遇到它们。在所有这些情况下,我们要检测的对象都属于少数类,而这些少数类在数据中可能严重不足。为提高模型在不平衡数据集上的表现,提出了几种方法:
-
欠采样:通过随机欠采样多数类来实现更平衡的训练数据集。
-
过采样:通过随机过采样少数类来获得平衡的训练数据集。
-
加权损失:根据少数类为损失函数分配权重。
-
合成数据:使用生成式 AI 创建高保真度的少数类合成数据样本。
在这篇文章中,我展示了如何通过在合成数据上训练模型来超越其他方法,从而提高分类器的性能。
数据集
数据来自Kaggle,包括 284,807 笔信用卡交易,其中 492 笔(0.172%)被标记为欺诈交易。数据可用于商业和非商业用途,采用开放数据公共许可证。
对感兴趣的读者,Kaggle 提供了有关数据的更多详细信息和基本描述性统计。
从这个 Kaggle 数据集中,我创建了两个子集:一个训练集和一个持出集。训练集包含总数据的 80%,以及在探索该方法时生成的合成样本。持出集则包含原始数据的 20%,不包括任何合成样本。

作者提供的图像:数据拆分过程
模型
我使用了Ludwig,这是一个开源的声明式框架,用于构建深度学习模型,因为它易于实现。通过在 yaml 文件中声明模型并通过 Ludwig 的 Python API 运行训练任务,可以轻松构建和训练模型。我之前写过一篇文章,详细介绍了 Ludwig,供感兴趣的读者参考。
对于每种方法,我使用相同的基线模型,仅根据需要调整特定参数。例如,Ludwig 原生支持权重和采样调整——这些可以简单地在 yaml 文件中进行调整。我提供了每种方法的模型配置 yaml 文件的链接,供您探索。
生成合成数据
我使用了合成数据库(SDV),这是一个用于生成合成数据样本的开源库。使用 SDV,我生成了额外的 284k 合成欺诈样本,从而在训练数据集中实现了两个类别的均等表示。
合成样本是通过适用于表格数据的变分自编码器(TVAE)生成的。有关 TVAE 背后的理论,您可以在这篇论文中找到更多细节。
SDV提供了诊断统计数据,显示拟合质量的指示。您可以通过比较真实数据与生成数据中的变量分布,手动探索拟合质量,示例如下。

作者提供的图像:真实与合成的变量 v1 分布对比

作者提供的图像:真实与合成的变量 v10 分布对比

作者提供的图像:真实与合成的变量分布对比
使用精确度召回图表评估性能
我们通过绘制模型与持出数据集的精确度与召回率曲线来评估每个模型的性能。
精确度-召回率曲线
精确度-召回率曲线,即将精确度(在 y 轴上)与召回率(在 x 轴上)进行绘制的图,与 ROC 曲线类似。它作为一种强健的诊断工具,用于评估模型在显著类别不平衡场景中的性能,例如我们的信用卡欺诈检测用例,便是一个典型例子。
图表的右上角代表“理想”点 —— 假阳性率为零,真正阳性率为一。一个熟练的模型应该能够达到或接近这一点,这意味着曲线下面积(AUC-PR)较大的模型可能更优越。
无技能预测器
“无技能”预测器是一个简单的模型,其预测是随机的。对于不平衡数据集,无技能线是一个高度等于正类比例的水平线。这是因为如果模型随机预测正类,精确度将等于数据集中正实例的比例。
模型性能 — 基线
基线模型是没有样本调整、损失函数调整或增强训练数据的深度神经网络。每种方法与基线性能进行比较,基线性能作为性能基准。

作者提供的图像:基线模型的精确度-召回率曲线
模型性能 — 加权损失方法
加权损失根据欺诈交易与非欺诈交易的比例调整损失函数。

作者提供的图像:加权损失方法的精确度-召回率曲线
模型性能 — 过采样方法
过采样随机地过度采样欺诈交易,直到训练数据集中各类别之间的表示均等。

作者提供的图像:过采样方法的精确度-召回率曲线
模型性能 — 欠采样方法
欠采样随机地欠采样非欺诈交易,直到训练数据集中各类别之间的表示均等。

作者提供的图像:欠采样方法的精确度-召回率曲线
模型性能 — 人工合成数据方法
利用 TVAEs 生成 284k 人工合成的欺诈样本,以在训练数据集中获得各类别的均等表示。

作者提供的图像:人工合成数据方法的精确度-召回率曲线
自助法持出数据集
为了获得对保留集性能的稳健视角,我从原始数据中创建了五十个自举保留集。对每种方法关联的模型在所有集上运行,提供了性能分布。然后,我们可以使用 Kolmogorov-Smirnov 检验来确定每种方法是否与基线存在统计显著差异。
加权:加权方法在召回率和 AUC 方面相对于基线表现略逊。除此之外,各性能指标的方差相对于其他方法显得较高。

作者提供的图像:模型性能指标在 50 个自举保留样本上的表现。基线与加权损失,KS 统计 — AUC 0.420 p 值 < 0.000,精度 0.260 p 值 0.068,召回率 0.520 p 值 < 0.000
过采样:过采样方法提高了模型的召回率,但导致精度的急剧恶化。

作者提供的图像:模型性能指标在 50 个自举保留样本上的表现。基线与过采样,KS 统计 — AUC 0.160 p 值 0.549,精度 1.0 p 值 < 0.000,召回率 0.9 p 值 < 0.000
欠采样:该方法在所有指标上表现都不如基线。

作者提供的图像:模型性能指标在 50 个自举保留样本上的表现。基线与过采样,KS 统计 — AUC 0.880 p 值 < 0.000,精度 0.6 p 值 < 0.000,召回率 1.0 p 值 < 0.000
合成:合成方法提升了模型的召回率,尽管以牺牲精度为代价。尽管精度的影响仍然显著,但与过采样方法相比,合成方法提供了更具韧性的替代方案,能够在不显著影响精度的情况下提升模型召回率。合成方法的稳健性在 AUC-PR 的提升中得到了进一步证明。

作者提供的图像:模型性能指标在 50 个自举保留样本上的表现。基线与合成,KS 统计 — AUC 0.620,精度 0.560,召回率 0.360 所有 p 值 ≤ 0.003
结论
我们注意到,相对于基线,合成数据方法可以提升模型的召回率,但以牺牲精度为代价。过采样也能实现类似的结果,但模型精度相比之下急剧下降。
在我们特定的信用卡欺诈检测背景下,假阳性不像假阴性那样昂贵。因此,如果提高召回率能够显著提高,我们可以在模型精度上做出一定妥协。通过合成实例丰富我们的训练数据似乎是提高召回率同时减轻精度不良影响的有效策略。这种增强可能会显著影响盈利能力,特别是在将模型扩展到处理数百万笔交易时。最终,将假阳性和假阴性的确切成本进行归因,将使我们更清楚地理解最具商业可行性的方法,这一话题超出了本文的范围。
检查不同样本规模的合成数据的表现将非常有趣,也许可以与加权损失结合起来。类似地,尝试不同的过采样比例可能会产生与我们观察到的合成方法类似的效果。
这个项目的笔记本可以在我的 GitHub repo 中找到
在 LinkedIn 上关注我
订阅 Medium 以获取更多来自我的见解:
[## 使用我的推荐链接加入 Medium — John Adeojo
我分享数据科学项目、经验和专业知识,以帮助你在旅程中。你可以通过…
johnadeojo.medium.com](https://johnadeojo.medium.com/membership?source=post_page-----6b4041e75dda--------------------------------)
如果你有兴趣将 AI 或数据科学整合到你的业务操作中,我们邀请你预约与我们进行免费的初步咨询:
通过免费咨询发现我们在帮助企业实现雄心勃勃目标方面的专业知识。我们的数据科学家和…
www.data-centric-solutions.com](https://www.data-centric-solutions.com/book-online?source=post_page-----6b4041e75dda--------------------------------)
变换器能否学会制定策略?
原文:
towardsdatascience.com/can-transformers-learn-to-strategize-862770c996ea
TicTacGPT 用于玩简单的棋盘游戏
·发布于 Towards Data Science ·27 分钟阅读·2023 年 9 月 8 日
--

尽管大多数棋盘游戏倾向于使用卷积神经网络或其他几何灵感的架构,但我们实际能够将棋盘状态表示为字符串,这就引出了一个问题,即变换器是否可以自然地应用于游戏。在这里,我们将看看是否可以在简单的井字游戏的背景下回答这个问题。虽然这看起来可能不太实际(几乎每个人都知道这个游戏中存在一个简单的闭式纳什均衡策略),但它是我们问题的一个有用的测试平台。原因在于游戏足够简单,我们可以轻松训练一个变换器来玩它,但又足够复杂,不容易立刻看出最佳策略是什么。
实现游戏
我们将开始实现一个TicTacToe类。这相当简单。我们希望能够将棋盘表示为 9 个字符的字符串,每个字符代表一个方格。我们将使用X表示第一个玩家,O表示第二个玩家,-表示空方格。我们还会跟踪轮到谁进行下一步,游戏是否结束。如果有获胜者,我们也会记录下来。最后,我们将包含一个打印棋盘的方法,以便在调试时不必盯着字符串看。
class TicTacToe:
def __init__(self):
# Initialise an empty board
self.board = ['-' for _ in range(9)]
self.current_player = 'X' # X will start
def make_move(self, position):
"""Make a move on the board."""
if self.board[position] == '-':
self.board[position] = self.current_player
self.switch_player()
return True
else: return False # illegal move
def switch_player(self):
"""Switch the current player."""
self.current_player = 'O' if self.current_player == 'X' else 'X'
def check_winner(self):
"""Check if there is a winner."""
# Rows, columns, diagonals
winning_positions = [
[0, 1, 2], [3, 4, 5], [6, 7, 8], # Rows
[0, 3, 6], [1, 4, 7], [2, 5, 8], # Columns
[0, 4, 8], [2, 4, 6] # Diagonals
]
for positions in winning_positions:
values = [self.board[pos] for pos in positions]
if values[0] == values[1] == values[2] and values[0] != '-':
return values[0]
return None # No winner yet
def is_draw(self):
"""Check if the game is a draw."""
return all(cell != '-' for cell in self.board)
def get_board_string(self):
"""Get the current board state as a string."""
return ''.join(self.board)
def get_legal_moves(self):
"""Get the positions of all legal moves."""
return [i for i, cell in enumerate(self.board) if cell == '-']
def pretty_print_board(self):
"""Pretty-print the board."""
for i in range(0, 9, 3):
print(f"{self.board[i]} | {self.board[i+1]} | {self.board[i+2]}")
if i < 6:
print("- "*5)
# Test the pretty_print_board method
tic_tac_toe = TicTacToe()
print("Initial board:")
tic_tac_toe.pretty_print_board()
# Make some moves
tic_tac_toe.make_move(0)
tic_tac_toe.make_move(4)
tic_tac_toe.make_move(8)
print("\nBoard after some moves:")
tic_tac_toe.pretty_print_board()
Initial board:
- | - | -
- - - - -
- | - | -
- - - - -
- | - | -
Board after some moves:
X | - | -
- - - - -
- | O | -
- - - - -
- | - | X
创建我们的训练数据
我们希望我们的变换器能够输入一个给定的棋盘状态,并输出一个走法,该走法是一个从 0 到 8 的整数,表示它希望将棋子放在的位置。为此,我们将创建一个棋盘状态和走法的数据集。我们将通过模拟我们玩家的所有可能获胜位置,然后遍历所有可能使我们达到该位置的游戏组合来做到这一点。这意味着变换器将学习在任何给定的棋盘状态下什么是一个好的走法。
为了实现这一点,simulate_all_games 函数生成了训练和验证数据。具体来说,该函数模拟了所有可能的井字棋游戏,探索了两个玩家(‘X’ 和 ‘O’)的每一种走法排列。这种详尽的模拟确保了模型在一个全面的数据集上进行训练,涵盖了所有可能的游戏情景。在每个模拟的游戏中,函数记录了不仅是获胜或平局的结果,还记录了棋盘状态的序列以及导致这些结果的走法。这些棋盘状态和走法随后被转换为数字表示,适用于训练我们的变换器。这确保了模型不仅学习如何获胜,还能够从任何给定的棋盘状态中输出一个合适的获胜走法。你可以将其视为类似于常规语言变换器在给定任何长度的上下文时输出一个合适的标记,从一个标记(即我们的起始棋盘状态)到 EOS 标记(即我们的获胜走法生成最终棋盘状态)。
from copy import deepcopy
from itertools import product
import numpy as np
import torch
# Define character to integer mapping
chars = sorted(list(set('XO-')))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {0: '-', 1: 'X', 2: 'O'}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])
input_sequences = []
output_sequences = []
# Function to simulate all possible games recursively
def simulate_all_games(game, x_moves, o_moves):
global input_sequences, output_sequences
# Check if the game has reached a terminal state
winner = game.check_winner()
if winner == 'X' or game.is_draw():
# Add the sequence of board states and moves leading to this win for 'X' or draw
board = ['-' for _ in range(9)]
for i, x_move in enumerate(x_moves):
input_sequences.append(encode(''.join(board)))
output_sequences.append(x_move)
board[x_move] = 'X'
if i < len(o_moves):
board[o_moves[i]] = 'O'
return
elif winner == 'O':
return # We don't add these to our training data
# Otherwise, continue simulating the game
legal_moves = game.get_legal_moves()
for move in legal_moves:
# Create a copy of the game to simulate the move
new_game = deepcopy(game)
was_legal = new_game.make_move(move)
# If the move was legal, continue simulating
if was_legal:
if new_game.current_player == 'X':
simulate_all_games(new_game, x_moves + [move], o_moves)
else:
simulate_all_games(new_game, x_moves, o_moves + [move])
# Create an initial empty game
initial_game = TicTacToe()
# Simulate all possible games starting with 'X'
simulate_all_games(initial_game, [], [])
# Convert to PyTorch tensors
input_tensor = torch.tensor(input_sequences, dtype=torch.long)
output_tensor = torch.tensor(output_sequences, dtype=torch.long)
# Show some sample input-output pairs
print(input_tensor[:10], output_tensor[:10])
print("Number of input-output pairs:", len(input_sequences))
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 2, 0, 0, 0, 0, 0, 0, 0],
[1, 2, 1, 2, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 2, 0, 0, 0, 0, 0, 0, 0],
[1, 2, 1, 2, 0, 0, 0, 0, 0],
[1, 2, 1, 2, 1, 2, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 2, 0, 0, 0, 0, 0, 0, 0],
[1, 2, 1, 2, 0, 0, 0, 0, 0]])
tensor([1, 3, 5, 1, 3, 5, 6, 1, 3, 5])
Number of input-output pairs: 658224
这给我们大约 650,000 个张量用于训练。这些张量看起来大致正确,但没有可视化的棋盘很难判断。让我们重用我们的 print_board 函数来查看一些随机的棋盘状态,以及给定走法后的下一个棋盘状态是什么样的:
def pretty_print_board(board: str):
"""Pretty-print the board."""
for i in range(0, 9, 3):
print(f"{board[i]} | {board[i+1]} | {board[i+2]}")
if i < 6:
print("- "*5)
rand_idx = torch.randint(len(input_tensor), (1,))[0]
random_game = input_tensor[rand_idx].tolist()
print("Current game state:")
decoded_game = decode(random_game)
pretty_print_board(decoded_game)
print( )
move = output_tensor[rand_idx].item()
decoded_game = decoded_game[:move] + 'X' + decoded_game[move+1:]
print("New game state:")
pretty_print_board(decoded_game)
Current game state:
X | - | -
- - - - -
- | X | O
- - - - -
- | O | -
New game state:
X | - | X
- - - - -
- | X | O
- - - - -
- | O | -
这似乎是合理的,但我注意到有些游戏有一个可用的获胜走法,但模拟却做出了不同的走法(仍然以胜利告终)。在上面的示例中出现了这种情况。让我们将 simulate_all_games 函数更改为在找到至少一个潜在获胜走法时停止搜索。
input_sequences = []
output_sequences = []
def simulate_all_games(game, x_moves, o_moves):
global input_sequences, output_sequences
# Check if the game has reached a terminal state
winner = game.check_winner()
if winner == 'X' or game.is_draw():
# Add the sequence of board states and moves leading to this win for 'X' or draw
board = ['-' for _ in range(9)]
for i, x_move in enumerate(x_moves):
input_sequences.append(encode(''.join(board)))
output_sequences.append(x_move)
board[x_move] = 'X'
if i < len(o_moves):
board[o_moves[i]] = 'O'
return
elif winner == 'O':
return # We don't add these to our training data
# Before simulating further moves, check if a winning move is available
legal_moves = game.get_legal_moves()
for move in legal_moves:
test_game = deepcopy(game)
test_game.make_move(move)
if test_game.check_winner() == game.current_player:
# This move is a winning move, so we make it and end further simulation
if test_game.current_player == 'X':
simulate_all_games(test_game, x_moves + [move], o_moves)
else:
simulate_all_games(test_game, x_moves, o_moves + [move])
return # End further exploration for this branch
# If no immediate winning move is found, continue simulating the game
for move in legal_moves:
# Create a copy of the game to simulate the move
new_game = deepcopy(game)
was_legal = new_game.make_move(move)
# If the move was legal, continue simulating
if was_legal:
if new_game.current_player == 'X':
simulate_all_games(new_game, x_moves + [move], o_moves)
else:
simulate_all_games(new_game, x_moves, o_moves + [move])
# Create an initial empty game
initial_game = TicTacToe()
# Simulate all possible games starting with 'X'
simulate_all_games(initial_game, [], [])
# Convert to PyTorch tensors
input_tensor = torch.tensor(input_sequences, dtype=torch.long)
output_tensor = torch.tensor(output_sequences, dtype=torch.long)
最后,让我们看看我们需要训练多少步:
print("Number of input-output pairs:", len(input_sequences))
Number of input-output pairs: 147104
大约 150,000 个示例。这看起来是一个合理的开始。
多头注意力的变换器架构
注意力是一种机制,使模型能够在进行预测时专注于输入序列的某些部分。变压器架构使用多头自注意力,这意味着模型学习以不同的方式关注输入序列的不同部分。这很有用,因为它允许模型学习输入序列和输出序列之间的不同关系。例如,当预测输出序列中的第一个词时,它可能会学习关注输入序列中的第一个标记,但在预测输出序列中的第二个词时,则关注输入序列中的最后一个标记。这是一种强大的机制,可以使模型学习输入序列和输出序列之间的复杂关系。
但这到底是怎么工作的呢?从原始的 Attention is all you need 论文中,定义在查询矩阵 Q、键矩阵 K 和值矩阵 V 上的注意力定义为:

其中我们除以 sqrt{d_k} 以确保 softmax 的方差适当。让我们分解一下实际发生了什么。假设我们有一个维度为 (B,T,C) 的输入,其中 B 是批次大小,T 是序列长度,C 是通道数。我们可以把它看作是一个包含 B 个长度为 T 的序列的批次,每个序列有 C 个通道:
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time, channels
x = torch.randn(B,T,C)
然后,为了实现一个自注意力头,我们需要创建查询、键和值。实际上,这些是具有一定 head_size 的线性层,head_size 就是我们希望线性层的宽度。我们不包括偏置项,因为我们不想为注意力学习一个偏置项。
wei = q @ k.transpose(-2,-1) # (B, T, 16) @ (B, 16, T) -> (B, T, T)
但是,如果你拆解我们实际在做的事情,我们是在预测一个单词序列中的下一个词。由于我们不想作弊并使用我们尚未看到的序列部分(因为在生成过程中我们不能这样做),我们需要屏蔽掉尚未看到的序列部分。我们通过创建一个形状为 (T,T) 的掩码来实现,其中 T 是序列长度,然后将上三角的所有值设置为负无穷。这确保了 softmax 对所有掩码值为 0,因此模型不会关注这些值。
最后,我们将注意力权重与值矩阵相乘,以获得注意力层的输出。这是自注意力的单个头的输出。然后,我们可以根据需要重复此过程多次,然后将每个头的输出拼接在一起,以获得多头自注意力层的最终输出。
tril = torch.tril(torch.ones((T,T)))
wei = wei.masked_fill(tril==0, float("-inf"))
wei = F.softmax(wei, dim=-1)
v = value(x) # (B,T,16)
out = wei @ v
out.shape
这只是自注意力的一个头。为了创建多个头,我们只需多次重复这个过程,然后将每个头的输出连接起来以获得多头自注意力层的最终输出。我们还添加了残差连接,以提高我们优化这个相对深层模型的能力。对于类似代码的完整演示以及解码器仅变换器背后的机制,我强烈推荐 Andrej Karpathy 的nanoGPT 讲座。
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
# Hyperparameters
batch_size = 128 # How many independent sequences will we process in parallel?
block_size = 9 # The size of the tic-tac-toe board
max_iters = 10000
eval_interval = 500
learning_rate = 1e-3
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
eval_iters = 100
n_embd = 32 # Reduced the embedding size
n_head = 2 # Reduced the number of heads
n_layer = 2 # Reduced the number of layers
dropout = 0.1
print(f'Training on {device}')
# Initialize random seed
torch.manual_seed(1337)
# Split into training and validation sets
n = int(0.90 * len(input_tensor)) # 90% for training
train_input = input_tensor[:n]
train_output = output_tensor[:n]
val_input = input_tensor[n:]
val_output = output_tensor[n:]
# Updated data loading function
def get_batch(split):
input_data = train_input if split == 'train' else val_input
output_data = train_output if split == 'train' else val_output
# Choose index locs for batch_size sequences
ix = torch.randint(len(input_data) - block_size + 1, (batch_size,))
# Get the input and output sequences
x = input_data[ix]
y = output_data[ix]
x, y = x.to(device), y.to(device)
return x, y
@torch.no_grad()
def estimate_loss():
out = {}
model.eval()
for split in ['train', 'val']:
losses = torch.zeros(eval_iters)
for k in range(eval_iters):
X, Y = get_batch(split)
logits, loss = model(X, Y)
losses[k] = loss.item()
out[split] = losses.mean()
model.train()
return out
class Head(nn.Module):
""" one head of self-attention """
def __init__(self, head_size):
super().__init__()
self.key = nn.Linear(n_embd, head_size, bias=False)
self.query = nn.Linear(n_embd, head_size, bias=False)
self.value = nn.Linear(n_embd, head_size, bias=False)
self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# input of size (batch, time-step, channels)
# output of size (batch, time-step, head size)
B,T,C = x.shape
k = self.key(x) # (B,T,hs)
q = self.query(x) # (B,T,hs)
# compute attention scores ("affinities")
wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
wei = F.softmax(wei, dim=-1) # (B, T, T)
wei = self.dropout(wei)
# perform the weighted aggregation of the values
v = self.value(x) # (B,T,hs)
out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
return out
class MultiHeadAttention(nn.Module):
""" multiple heads of self-attention in parallel """
def __init__(self, num_heads, head_size):
super().__init__()
self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
self.proj = nn.Linear(head_size * num_heads, n_embd)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
out = torch.cat([h(x) for h in self.heads], dim=-1)
out = self.dropout(self.proj(out))
return out
class FeedFoward(nn.Module):
""" a simple linear layer followed by a non-linearity """
def __init__(self, n_embd):
super().__init__()
self.net = nn.Sequential(
nn.Linear(n_embd, 4 * n_embd),
nn.ReLU(),
nn.Linear(4 * n_embd, n_embd),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
class Block(nn.Module):
""" Transformer block: communication followed by computation """
def __init__(self, n_embd, n_head):
# n_embd: embedding dimension, n_head: the number of heads we'd like
super().__init__()
head_size = n_embd // n_head
self.sa = MultiHeadAttention(n_head, head_size)
self.ffwd = FeedFoward(n_embd)
self.ln1 = nn.LayerNorm(n_embd)
self.ln2 = nn.LayerNorm(n_embd)
def forward(self, x):
x = x + self.sa(self.ln1(x))
x = x + self.ffwd(self.ln2(x))
return x
class Transformer(nn.Module):
def __init__(self):
super().__init__()
# each token directly reads off the logits for the next token from a lookup table
self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
self.position_embedding_table = nn.Embedding(block_size, n_embd)
self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
self.ln_f = nn.LayerNorm(n_embd) # final layer norm
self.lm_head = nn.Linear(n_embd, 9)
# better init, not covered in the original GPT video, but important, will cover in followup video
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, idx, targets=None):
B, T = idx.shape
# idx and targets are both (B,T) tensor of integers
tok_emb = self.token_embedding_table(idx) # (B,T,C)
pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
x = tok_emb + pos_emb # (B,T,C)
x = self.blocks(x) # (B,T,C)
x = self.ln_f(x) # (B,T,C)
logits = self.lm_head(x) # (B,T,vocab_size)
# Take the logits corresponding to the last time step T
logits = logits[:, -1, :] # Now logits is (B, 9)
if targets is None:
loss = None
else:
loss = F.cross_entropy(logits, targets)
return logits, loss
def generate(self, idx, max_new_tokens):
# idx is (B, T) array of indices in the current context
for _ in range(max_new_tokens):
# crop idx to the last block_size tokens
idx_cond = idx[:, -block_size:]
# get the predictions
logits, loss = self(idx_cond)
# focus only on the last time step
logits = logits[:, -1, :] # becomes (B, C)
# apply softmax to get probabilities
probs = F.softmax(logits, dim=-1) # (B, C)
# sample from the distribution
idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
# append sampled index to the running sequence
idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
return idx
为了确保我们的架构按预期工作,让我们尝试传递一个单一的批次。
xb, yb = get_batch('train')
print(xb.shape, yb.shape)
m = Transformer().to(device)
logits, loss = m(xb, yb)
print(logits.shape)
print(f"Loss: {loss.item():.3f}")
torch.Size([128, 9]) torch.Size([128])
torch.Size([128, 9])
Loss: 2.203
在进行这个初始前向传播时,一个好的步骤是测试损失是否大致等于我们对随机输入的期望。由于我们有 9 维 logits,并且我们使用的交叉熵损失等于正确类别的负对数似然,我们期望损失大致为:

训练模型
使用相当小的变换器(约 25,000 个参数),我们实现了以下损失(请注意,我使用了少量的权重衰减和丢弃):
model = Transformer()
model = model.to(device)
# Print the number of parameters in the model
print(sum(p.numel() for p in model.parameters()) / 1e6, 'M parameters')
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)
train_loss_history = []
val_loss_history = []
# Training loop
for iter in tqdm(range(max_iters)):
# Evaluate the loss on train and val sets occasionally
if iter % eval_interval == 0 or iter == max_iters - 1:
losses = estimate_loss()
print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
val_loss_history.append(losses['val'])
# Sample a batch of data
xb, yb = get_batch('train')
# Evaluate the loss
logits, loss = model(xb, yb)
train_loss_history.append(loss.item())
optimizer.zero_grad()
loss.backward()
optimizer.step()
0.025961 M parameters
0%| | 5/10000 [00:00<24:42, 6.74it/s]
step 0: train loss 2.2033, val loss 2.2106
5%|▌ | 504/10000 [00:14<12:01, 13.15it/s]
step 500: train loss 1.9162, val loss 2.0215
10%|█ | 1008/10000 [00:27<08:27, 17.73it/s]
step 1000: train loss 1.7846, val loss 1.8570
15%|█▌ | 1505/10000 [00:40<10:34, 13.39it/s]
step 1500: train loss 1.7370, val loss 1.7648
20%|██ | 2007/10000 [00:53<07:35, 17.55it/s]
step 2000: train loss 1.7188, val loss 1.7770
25%|██▌ | 2506/10000 [01:05<07:11, 17.36it/s]
step 2500: train loss 1.6957, val loss 1.7456
30%|███ | 3006/10000 [01:18<06:35, 17.69it/s]
step 3000: train loss 1.6965, val loss 1.7448
35%|███▌ | 3506/10000 [01:31<06:12, 17.41it/s]
step 3500: train loss 1.6961, val loss 1.7809
40%|████ | 4005/10000 [01:43<07:41, 12.98it/s]
step 4000: train loss 1.6819, val loss 1.7256
45%|████▌ | 4506/10000 [01:56<05:18, 17.24it/s]
step 4500: train loss 1.6892, val loss 1.7066
50%|█████ | 5005/10000 [02:09<05:14, 15.88it/s]
step 5000: train loss 1.6846, val loss 1.7141
55%|█████▌ | 5508/10000 [02:23<04:37, 16.19it/s]
step 5500: train loss 1.6835, val loss 1.6998
60%|██████ | 6004/10000 [02:36<05:19, 12.51it/s]
step 6000: train loss 1.6828, val loss 1.7095
65%|██████▌ | 6506/10000 [02:49<03:23, 17.13it/s]
step 6500: train loss 1.6722, val loss 1.7151
70%|███████ | 7008/10000 [03:02<03:05, 16.17it/s]
step 7000: train loss 1.6656, val loss 1.7158
75%|███████▌ | 7505/10000 [03:15<02:30, 16.54it/s]
step 7500: train loss 1.6672, val loss 1.7078
80%|████████ | 8007/10000 [03:28<02:01, 16.38it/s]
step 8000: train loss 1.6808, val loss 1.7120
85%|████████▌ | 8505/10000 [03:41<01:47, 13.94it/s]
step 8500: train loss 1.6733, val loss 1.7144
90%|█████████ | 9007/10000 [03:54<00:56, 17.54it/s]
step 9000: train loss 1.6714, val loss 1.7031
95%|█████████▌| 9506/10000 [04:07<00:28, 17.39it/s]
step 9500: train loss 1.6707, val loss 1.7073
100%|██████████| 10000/10000 [04:20<00:00, 38.43it/s]
step 9999: train loss 1.6664, val loss 1.7506
这似乎不太好。让我们绘制图表看看发生了什么。
import matplotlib.pyplot as plt
def plot_transformer_loss(loss_history, val_loss_history):
# Two horizontal figures side-by-side
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
# First plot = training loss
ax1.plot(loss_history, lw=0.5)
# Plot moving average of loss
window_size = 100
ax1.plot(np.convolve(loss_history, np.ones(window_size) / window_size, mode='valid'), label='Moving average')
ax1.set_xlabel('Iteration')
ax1.set_ylabel('Cross-entropy Loss')
ax1.set_title('Training Loss')
ax1.legend()
# Second plot = validation loss
# Set marker style to be circles at each data point
indices = np.arange(0, len(val_loss_history) * eval_interval, eval_interval)
ax2.plot(indices, val_loss_history, marker='o')
ax2.set_title('Validation Loss')
ax2.set_xlabel('Iteration')
plt.show()
plot_transformer_loss(train_loss_history, val_loss_history)

我们的变换器在井字游戏数据上训练的初始损失(图片由作者提供)。
我们可以使用这个code来测试变换器的效果:
import random
import torch
from IPython.display import clear_output
def play_game(model, stoi, itos, device):
game = TicTacToe()
# Randomly decide who goes first
game.current_player = random.choice(['X', 'O'])
while game.check_winner() is None and not game.is_draw():
#clear_output(wait=True)
print(f"{game.current_player}'s turn.")
game.pretty_print_board()
current_board_str = game.get_board_string()
if game.current_player == 'X':
print("Model's turn...")
current_board_encoded = torch.tensor([stoi[c] for c in current_board_str], dtype=torch.long).unsqueeze(0).to(device)
logits, _ = model(current_board_encoded)
# Move logits to cpu
logits = logits.cpu()
# Create a mask for legal moves and zero out logits for illegal moves
legal_moves = game.get_legal_moves()
mask = torch.zeros(9)
mask[legal_moves] = 1
masked_logits = logits * mask
# Get the model's move
predicted_move = masked_logits.argmax(dim=-1).item()
# Make the model's move
game.make_move(predicted_move)
else:
print("Your turn!")
legal_moves = game.get_legal_moves()
print("Legal moves:", legal_moves)
user_move = int(input("Enter your move: "))
if user_move in legal_moves:
game.make_move(user_move)
else:
print("Illegal move. Try again.")
continue
winner = game.check_winner()
if winner is not None:
#clear_output(wait=True)
print(f"{winner} wins!")
game.pretty_print_board()
break
elif game.is_draw():
#clear_output(wait=True)
print("It's a draw!")
game.pretty_print_board()
break
O's turn.
- | - | -
- - - - -
- | - | -
- - - - -
- | - | -
Your turn!
Legal moves: [0, 1, 2, 3, 4, 5, 6, 7, 8]
X's turn.
- | - | -
- - - - -
- | - | -
- - - - -
- | - | O
Model's turn...
O's turn.
- | - | X
- - - - -
- | - | -
- - - - -
- | - | O
Your turn!
Legal moves: [0, 1, 3, 4, 5, 6, 7]
X's turn.
O | - | X
- - - - -
- | - | -
- - - - -
- | - | O
Model's turn...
O's turn.
O | - | X
- - - - -
X | - | -
- - - - -
- | - | O
Your turn!
Legal moves: [1, 4, 5, 6, 7]
O wins!
O | - | X
- - - - -
X | O | -
- - - - -
- | - | O
好吧,我轻松打败了模型。某些事情出错了。
改进变换器
所以目前,变换器甚至无法从任何给定位置可靠地学习简单的胜利动作。我能想到几个原因:
-
变换器仅在胜利的动作上进行训练,因此当我使用一个良好的策略(即没有胜利的动作可用)时,可能无法学习如何游戏。理论上,为了应对这一点,我们应该允许它在游戏注定为平局时进行训练。
-
变换器的参数过多。试图让几十万个神经元协调一个简单的策略可能需要很长时间来训练,并依赖于grokking和其他现象才能进入优化景观的可泛化部分。
-
变换器的参数过少。也许它需要更多的神经元来学习一个好的策略。这样说来,我非常怀疑如果几百万个神经元都无济于事,那几十万个神经元能否解决问题。
状态空间分析
在继续之前,我想从理论上分析编码井字游戏中完整胜利策略所需的神经元数量,我们需要考虑游戏的状态空间和决策过程的复杂性。
在井字棋中,游戏棋盘是一个 3 x 3 的网格,每个单元可以处于三种状态之一:‘X’,‘O’,或空(‘-’)。因此,总的可能棋盘状态数量可以计算为 3⁹ = 19683。然而,并非所有这些状态在实际游戏中都是有效的;其中一些是不可达的或非法的(例如,所有单元都是‘X’的棋盘)。合法状态的数量实际上大约为 5478,但为了分析的目的,我们将考虑上限,即 3⁹。
每个状态都需要一个决策:在哪里放置下一个‘X’(因为我们考虑的是‘X’的获胜策略)。有 9 个可能的位置,但合法的移动数量通常少于 9,这取决于已经被占据的单元格数量。一个神经网络需要将每个可能的棋盘状态映射到一个正确的移动。实现这种映射的一种方法是通过一个完全连接的层,该层将棋盘状态作为输入,并输出 9 个可能移动的概率分布。输入层将有 3x3=9 个神经元(每个单元一个),输出层将有 9 个神经元(每个可能移动一个)。中间的隐藏层将执行学习获胜策略的复杂任务。
考虑到输入层有 9 个神经元和输出层有 9 个神经元,我们关注的是隐藏层中的神经元数量。从理论上讲,我们可以使用一个具有 3⁹个神经元的隐藏层将每个可能的状态映射到一个获胜的移动。这将是一个上限,可能远远超过实际需要的数量,因为存在不可达/非法状态以及井字棋的固有对称性会减少实际的唯一状态数量。因此,在这个上限情况下,总的神经元数量将是:9 + 19683 + 9 = 19701。这是一个理论上的上限,实际数量可能由于前述因素而低得多。让我们尝试实现一个简单的前馈普通神经网络,看看它在我们的任务中的表现。我们将使用三个隐藏层,而不是一个具有数千个神经元的层。
import torch.nn as nn
import torch.nn.functional as F
device = 'cpu'
class TicTacToeNN(nn.Module):
def __init__(self):
super(TicTacToeNN, self).__init__()
self.fc1 = nn.Linear(9, 16) # Input layer to hidden layer 1
self.fc2 = nn.Linear(16, 32) # Hidden layer 1 to hidden layer 2
self.fc3 = nn.Linear(32, 16) # Hidden layer 2 to hidden layer
self.fc4 = nn.Linear(16, 9) # Hidden layer 3 to output layer
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.relu(self.fc3(x))
x = self.fc4(x)
return x
def get_batch(split):
input_data = train_input if split == 'train' else val_input
output_data = train_output if split == 'train' else val_output
# Choose index locs for batch_size sequences
ix = torch.randint(len(input_data) - block_size + 1, (batch_size,))
# Get the input and output sequences
x = input_data[ix].float()
y = output_data[ix]
x, y = x.to(device), y.to(device)
return x, y
# Create an initial empty game
initial_game = TicTacToe()
# Simulate all possible games starting with 'X'
simulate_all_games(initial_game, [], [])
# Convert to PyTorch tensors
input_tensor = torch.tensor(input_sequences, dtype=torch.long)
output_tensor = torch.tensor(output_sequences, dtype=torch.long)
nn_model = TicTacToeNN()
nn_model.to(device)
# Print the number of parameters in the model
print(sum(p.numel() for p in nn_model.parameters()), 'parameters')
# Create a PyTorch optimizer
optimizer = torch.optim.AdamW(nn_model.parameters(), lr=learning_rate, weight_decay=1e-4)
train_loss_history = []
val_loss_history = []
# Training loop
max_iters = 1000000
for iter in tqdm(range(max_iters)):
# Evaluate the loss on train and val sets occasionally
# Sample a batch of data
xb, yb = get_batch('train')
# Evaluate the loss
logits = nn_model(xb)
# Calculate cross-entropy loss
loss = F.cross_entropy(logits, yb)
train_loss_history.append(loss.item())
# Get the validation loss
xb, yb = get_batch('val')
logits = nn_model(xb)
val_loss = F.cross_entropy(logits, yb)
val_loss_history.append(val_loss.item())
# Backpropagate and update the weights
optimizer.zero_grad()
loss.backward()
optimizer.step()
1385 parameters
100%|██████████| 1000000/1000000 [08:08<00:00, 2048.42it/s]
让我们看看损失情况如何:

对于普通前馈神经网络的训练损失(作者提供的图片)。
显然,我们正在饱和性能。任务和我们设置的方法有些问题,阻止了模型学习适当的策略。为了改变一下,我打算尝试给模型提供只包含最优策略的训练数据。
最优策略训练数据
纽厄尔和西蒙 1972 年的井字棋程序概述了完美策略(以赢得比赛或至少平局),如果我们从以下移动偏好中选择第一个可用的移动:
-
获胜:如果你有两个连成一行,玩第三个以完成三连。
-
阻挡:如果对手有两个连成一行,玩第三个以阻挡他们。
-
分叉:创造一个你可以通过两种方式获胜的机会。
-
阻止对手的叉子:我们可以创建两个连续的棋子以迫使对手防守(如果这样做不会给他们造成叉子),或者阻止他们的潜在叉子。
-
中心:占据中心位置。
-
对角角落:如果对手在一个角落里,选择对角的角落。
-
空角落:选择一个空的角落。
-
空侧:选择一个空的边侧。
让我们重写数据生成器,以根据这一策略获得所有可能的走法。我们还将模拟两种可能的先手玩家的所有游戏。
from copy import deepcopy
# Helper function to find if there's a winning move or a move that blocks the opponent from winning
def find_winning_or_blocking_move(board, player):
winning_positions = [
[0, 1, 2], [3, 4, 5], [6, 7, 8], # Rows
[0, 3, 6], [1, 4, 7], [2, 5, 8], # Columns
[0, 4, 8], [2, 4, 6] # Diagonals
]
for positions in winning_positions:
values = [board[pos] for pos in positions]
if values.count(player) == 2 and values.count('-') == 1:
return positions[values.index('-')]
return None
# Helper function for checking for fork opportunities
def find_fork_move(board, player):
fork_move = None
for i in range(9):
if board[i] == '-':
temp_board = board[:]
temp_board[i] = player
winning_moves = 0
for j in range(9):
if temp_board[j] == '-':
temp_board_2 = temp_board[:]
temp_board_2[j] = player
if find_winning_or_blocking_move(temp_board_2, player) is not None:
winning_moves += 1
if winning_moves >= 2:
fork_move = i
break
return fork_move
# Helper function to find the optimal move according to a defined strategy
def optimal_strategy(board, player):
opponent = 'O' if player == 'X' else 'X'
# 1\. Win: If you have two in a row, play the third to get three in a row.
win_move = find_winning_or_blocking_move(board, player)
if win_move is not None:
return win_move
# 2\. Block: If the opponent has two in a row, play the third to block them.
block_move = find_winning_or_blocking_move(board, opponent)
if block_move is not None:
return block_move
# 3\. Fork: Create an opportunity where you can win in two ways.
fork_move = find_fork_move(board, player)
if fork_move is not None:
return fork_move
# 4\. Block Opponent's Fork
opponent_fork_move = find_fork_move(board, opponent)
if opponent_fork_move is not None:
return opponent_fork_move
# 5\. Center: Play the center.
if board[4] == '-':
return 4
# 6\. Opposite Corner: If the opponent is in the corner, play the opposite corner.
corners = [(0, 8), (2, 6), (8, 0), (6, 2)]
for corner1, corner2 in corners:
if board[corner1] == opponent and board[corner2] == '-':
return corner2
# 7\. Empty Corner: Play an empty corner.
for corner in [0, 2, 6, 8]:
if board[corner] == '-':
return corner
# 8\. Empty Side: Play an empty side.
for side in [1, 3, 5, 7]:
if board[side] == '-':
return side
# Function to simulate all games according to the optimal strategy
def simulate_all_games_optimal_v2(game, x_starts=True):
global input_sequences, output_sequences
# Check for terminal state
winner = game.check_winner()
if winner or game.is_draw():
return
# If it's X's turn, apply the optimal strategy and save the board state and move
if game.current_player == 'X':
move = optimal_strategy(game.board, 'X')
if move is None:
move = game.get_legal_moves()[0] # fallback
input_sequences.append(encode(''.join(game.board)))
output_sequences.append(move)
new_game = deepcopy(game)
new_game.make_move(move)
simulate_all_games_optimal_v2(new_game, x_starts)
else:
# If it's O's turn, explore all possible legal moves
for move in game.get_legal_moves():
new_game = deepcopy(game)
new_game.make_move(move)
simulate_all_games_optimal_v2(new_game, x_starts)
# Character to integer mapping
chars = sorted(list(set('XO-')))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {0: '-', 1: 'X', 2: 'O'}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])
# Reset and re-simulate
input_sequences = []
output_sequences = []
# 'X' starts
initial_game = TicTacToe()
simulate_all_games_optimal_v2(initial_game, True)
# 'O' starts
initial_game = TicTacToe()
initial_game.current_player = 'O'
simulate_all_games_optimal_v2(initial_game, False)
# Convert to Pytorch tensors
input_tensor = torch.tensor(input_sequences, dtype=torch.long)
output_tensor = torch.tensor(output_sequences, dtype=torch.long)
print("Number of input-output pairs:", len(input_sequences))
Number of input-output pairs: 1017
让我们在新的训练数据上重新训练我们的模型。
# Hyperparameters
batch_size = 128 # How many independent sequences will we process in parallel?
block_size = 9 # The size of the tic-tac-toe board
max_iters = 10000
eval_interval = 500
learning_rate = 1e-3
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
eval_iters = 100
n_embd = 32 # Reduced the embedding size
n_head = 2 # Reduced the number of heads
n_layer = 2 # Reduced the number of layers
dropout = 0.1
print(f'Training on {device}')
# Initialize random seed
torch.manual_seed(1337)
# Split into training and validation sets
n = int(0.90 * len(input_tensor)) # 90% for training
train_input = input_tensor[:n]
train_output = output_tensor[:n]
val_input = input_tensor[n:]
val_output = output_tensor[n:]
# Updated data loading function
def get_batch(split):
input_data = train_input if split == 'train' else val_input
output_data = train_output if split == 'train' else val_output
# Choose index locs for batch_size sequences
ix = torch.randint(len(input_data) - block_size + 1, (batch_size,))
# Get the input and output sequences
x = input_data[ix]
y = output_data[ix]
x, y = x.to(device), y.to(device)
return x, y
# Initialize the model
model = Transformer()
model = model.to(device)
max_iters = 5000
# Print the number of parameters in the model
print(sum(p.numel() for p in model.parameters()) / 1e6, 'M parameters')
# Create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)
train_loss_history = []
val_loss_history = []
# Training loop
for iter in tqdm(range(max_iters)):
# Evaluate the loss on train and val sets occasionally
if iter % eval_interval == 0 or iter == max_iters - 1:
losses = estimate_loss()
print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
val_loss_history.append(losses['val'])
# Sample a batch of data
xb, yb = get_batch('train')
# Evaluate the loss
logits, loss = model(xb, yb)
train_loss_history.append(loss.item())
optimizer.zero_grad()
loss.backward()
optimizer.step()
Training on mps
0.025961 M parameters
0%| | 6/5000 [00:00<10:15, 8.11it/s]
step 0: train loss 2.2078, val loss 2.2166
10%|█ | 505/5000 [00:13<05:23, 13.89it/s]
step 500: train loss 0.3063, val loss 0.6145
20%|██ | 1005/5000 [00:26<04:40, 14.24it/s]
step 1000: train loss 0.0741, val loss 0.2259
30%|███ | 1505/5000 [00:38<04:08, 14.05it/s]
step 1500: train loss 0.0368, val loss 0.1799
40%|████ | 2005/5000 [00:51<03:36, 13.83it/s]
step 2000: train loss 0.0134, val loss 0.1589
50%|█████ | 2504/5000 [01:04<02:57, 14.06it/s]
step 2500: train loss 0.0081, val loss 0.0884
60%|██████ | 3008/5000 [01:17<01:56, 17.06it/s]
step 3000: train loss 0.0041, val loss 0.0521
70%|███████ | 3505/5000 [01:29<01:46, 14.09it/s]
step 3500: train loss 0.0028, val loss 0.0855
80%|████████ | 4005/5000 [01:42<01:10, 14.06it/s]
step 4000: train loss 0.0036, val loss 0.1125
90%|█████████ | 4506/5000 [01:56<00:29, 16.68it/s]
step 4500: train loss 0.0014, val loss 0.0892
100%|██████████| 5000/5000 [02:08<00:00, 38.79it/s]
step 4999: train loss 0.0026, val loss 0.0721

使用最优训练数据的新损失(作者提供的图像)。
太棒了!我们不仅学会了策略,而且它在验证数据集上也能泛化(我们在训练和验证集上接近 0 损失)。我猜这是由于棋盘状态的固有对称性,变换器已经学会了一种对棋盘状态不变的模块算术形式。
让我们尝试新的变换器:
O's turn.
- | - | -
- - - - -
- | - | -
- - - - -
- | - | -
Your turn!
Legal moves: [0, 1, 2, 3, 4, 5, 6, 7, 8]
X's turn.
- | O | -
- - - - -
- | - | -
- - - - -
- | - | -
Model's turn...
O's turn.
X | O | -
- - - - -
- | - | -
- - - - -
- | - | -
Your turn!
Legal moves: [2, 3, 4, 5, 6, 7, 8]
X's turn.
X | O | -
- - - - -
- | - | -
- - - - -
- | O | -
Model's turn...
O's turn.
X | O | -
- - - - -
- | X | -
- - - - -
- | O | -
Your turn!
Legal moves: [2, 3, 5, 6, 8]
X's turn.
X | O | -
- - - - -
- | X | -
- - - - -
- | O | O
Model's turn...
O's turn.
X | O | -
- - - - -
- | X | -
- - - - -
X | O | O
Your turn!
Legal moves: [2, 3, 5]
X's turn.
X | O | O
- - - - -
- | X | -
- - - - -
X | O | O
Model's turn...
X wins!
X | O | O
- - - - -
X | X | -
- - - - -
X | O | O
它打败了我!使用了一个绝妙的叉子。看起来我们的变换器已经学会了最优策略。
结论
我认为这里的主要收获是变换器完全能够学习游戏的最优策略。虽然一个普通的神经网络可能也能学到相同的最优策略,但注意力机制的动态特性意味着它可能能够处理表示游戏随时间演变的更长序列。这些想法自然地促使我们在强化学习设置中应用变换器。例如,Janner 等(2021) 使用变换器来建模轨迹分布,并使用束搜索作为规划算法。
从这个项目中我学到的另一件事是,人工手动引导变换器通过最优策略的过程显然无法扩展,尤其是当游戏变得更加复杂时。例如,围棋并不是一个“已解决”的游戏,因此我们不能像上面那样提供最优策略进行训练。相反,我们必须使用类似自我对弈的方法来选择好的棋局序列,然后再用这些序列训练变换器。我希望未来能尝试这些想法。
最后,通过查看变换器中的预测和规划层级,仍有一个待开发的领域。正如Ba 等人 (2016) 所指出的,深度学习通常关注于在激活动态中保持临时状态的方法,而我们的大脑似乎是通过中期突触可塑性来调节临时状态信息。换句话说,应该有某种形式的工作记忆/预测,在下一个标记级别和例如 LSTM 中的长期状态之间操作。作为 GPT-3 支撑骨架的自回归解码器仅变换器架构是一个强大的模型,可以通过预测一个标记的未来生成逼真的文本。然而,如果我们将智能拟人化,我们知道快速的直观预测(仅预测一个标记)并不能使人变成天才。因此,在我看来,尝试为模型提供多个预测层级,让模型学会预测多个未来标记将会很有趣。这将类似于人类的提前规划能力,并且可能是变换器学习的一个有用技能。
预测未来意味着什么?这里有几个不同的探索方向:
-
时间上的预测:变换器能否学会预测下一个序列中的标记,而不是预测两个、三个或更多步之后的标记?预测两个标记一次是否等同于先预测一个标记,再预测下一个标记,还是说一次预测两个标记有某种战略上的好处?这是否迫使变换器思考更长时间?
-
空间上的预测:有证据表明,人类会产生粗略的任务语义表示,然后使用层级模块来“填补”这些粗略表示中的空白。你可以将其想象为写一篇文章:首先你创建一个要点的骨架,然后为每一段填入论点句子,最后完善细节。变换器是否可能学会做同样的事情?
无论这些问题是否有用,我希望这篇文章能为我们如何将问题重塑为适合变换器的问题提供一些清晰的见解。祝调优愉快!
参考文献
-
Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., & Polosukhin, I. (2017). 注意力机制是你所需的一切。神经信息处理系统进展, 30。
-
Janner, M., Li, Q., & Levine, S. (2021). 离线强化学习作为一个大的序列建模问题。神经信息处理系统进展, 34, 1273–1286。
-
Ba, J., Hinton, G. E., Mnih, V., Leibo, J. Z., & Ionescu, C. (2016). 使用快速权重关注最近的过去。神经信息处理系统进展, 29。
-
Andrej Karpathy. 让我们从头开始构建 GPT:从代码到拼写的全程讲解。
www.youtube.com/watch?v=kCc8FmEb1nY&t=5076s
我们能否阻止 LLMs 产生幻觉?
意见
普及 LLMs 的最大障碍之一可能是本质上难以解决的。
·
关注 发表在 Towards Data Science · 阅读时间 5 分钟·2023 年 8 月 24 日
--
图片由 Google DeepMind 提供,来源于 Unsplash
虽然大型语言模型(LLMs)已经引起了几乎所有人的关注,但由于它们存在一个相当恼人的方面——这些模型有时会“幻觉”,因此这种技术的大规模应用略显受限。简单来说,它们有时会编造信息,而且最糟的是,这些内容往往看起来非常令人信服。
幻觉,无论频繁与否,都带来了两个主要问题。它们无法直接应用于许多敏感或脆弱的领域,其中一个错误可能会非常昂贵。此外,它还会造成普遍的不信任,因为用户被期望验证 LLM 输出的所有内容,这在一定程度上违背了这项技术的初衷。
学术界似乎也认为幻觉是一个重大问题,因为 2023 年有很多研究论文讨论并试图解决这个问题。然而,我倾向于同意 Yann LeCun,Meta 的首席 AI 科学家,他认为幻觉根本无法解决。我们需要对技术进行彻底的重构以消除这个问题。
幻觉虚假陈述
我认为,有两个重要方面使得幻觉问题无法解决。首先是相当明显的技术基础,LLM 像其他机器学习模型一样,具有随机性。简单来说,它们做出预测。
虽然它们确实比“被吹捧的自动完成”要先进得多,但底层技术仍然使用关于令牌的统计预测。这既是 LLM 的优点也是缺点。
在强项方面,我们已经看到它们在预测输入后的内容方面是多么出色(假设没有故意破坏输出的尝试)。用户可能会犯几种类型的错误,比如留有错别字、误解单词的意思等,而 LLM 仍然可能给出正确的输出。
当初,首批基于文本的游戏被创造时,用户需要准确输入命令,不能有任何错误或解释的空间。比如,“move north”命令如果用户输入为“move morth”就会出错。然而,LLM 可能能够推测两者的意思。从这个角度来看,这项技术确实非常迷人。
然而,这也展示了一个弱点。任何输入都有一个广泛的令牌选择决策树。简单来说,模型生成输出的方式总是有很多种。在那种广泛的选择范围中,相对较小的部分是“正确”的决定。
尽管有许多优化选项可供选择,但问题本身是不可解决的。例如,如果我们增加提供某个特定答案的可能性,LLM 会变成一个查找表,所以我们希望保持平衡。底层技术仅基于随机预测,因此必须为更广泛的输出令牌提供一些空间。
但 LLMs 在当前状态下不能解决另一个问题。这涉及到更为抽象和虚幻的认识论问题,即研究知识本质的哲学领域。表面上看,这个问题很简单——我们如何知道哪些陈述是真实的,我们如何获得这种知识?毕竟,幻觉只是一些虚假的陈述事后产生的,因此如果我们能为模型创建一种方法来验证其是否做出了虚假的陈述并将其删除,这将解决问题。
分离幻觉和真实陈述
借鉴哲学的思路,我们可以区分两种可能的陈述——分析性和综合性。前者是通过定义而真实的陈述(最常见的例子之一是“单身汉是未婚男子”)。简单来说,我们可以通过分析语言本身找到真实的陈述,而无需外部经验。
综合性陈述是指通过某种形式的经验来判断真实的陈述,例如“桌子上有一个苹果”。在没有直接经验的情况下,无法知道这样的陈述是否真实。纯粹的语言分析对于判断其真假无济于事。
我应该指出,这些陈述之间的区别在几百年来一直备受争议,但这一讨论对于 LLMs(大型语言模型)来说基本上不相关。正如它们的名字所示,它们是高度先进的语言分析和预测机器。
根据这两种类型的区别,我们可以看到 LLMs 在分析性陈述方面几乎不会有问题(至少不会比人类有更多问题)。然而,它们无法获取经验或大范围的世界知识。它们无法知道某些陈述是否因为事件的存在而真实。
主要问题在于,分析性陈述的数量远小于所有综合性陈述的集合。由于 LLMs 无法验证这些陈述是否真实,我们作为人类必须向它们提供这些信息。
因此,LLMs 面临一个挑战。所有可能输出的集合中总会包含一些综合性陈述,但对于模型来说,它们都是不具备真值的。简单来说,“尤利乌斯·凯撒的刺客是布鲁图斯”(虽然有很多,但在此案例中无关紧要)和“尤利乌斯·凯撒的刺客是亚伯拉罕·林肯”对于模型来说是等同的。
反驳意见可能是我们对这些事件也没有直接经验。我们只是从书籍中读到它们。但对陈述真实性的发现是基于对幸存记录的重建以及广泛的考古证据。
一个简单的(虽然相关性较低)例子是“今天在下雨。” 对于大型语言模型来说,这样的陈述是无法确定其真实性的,因为它需要在查询时接触到现实世界的经验。
从某种意义上说,认识论问题是自我解决的。我们的文学语料库会使“凯撒的刺客是布鲁图斯”这种输出变得显著更可能,因为它出现得更频繁。然而,再次强调的是,这种自我解决的方法依赖于在绝对所有可用文本信息上训练大型语言模型,这显然是不可能的。此外,这也会使其他不那么真实的输出仍然存在于所有可能输出的集合中。
因此,数据质量变得非常重要,但这种质量只能由人类观察者来判断。即使模型在大量数据上进行训练,仍然会有一定的选择过程,这意味着合成陈述的错误率无法消除。
结论
我认为,阻止模型产生幻觉的问题是无法解决的。一方面,技术本身基于随机过程,这不可避免地会在大量输出中导致错误的预测。
除了技术难题,还有一个问题是大型语言模型是否可以对陈述做出真实性判断,我再次认为这是不可能的,因为它们无法接触到现实世界。这个问题在很多大型语言模型现在提供的各种搜索引擎功能中稍微得到了缓解,这些功能可以验证某些陈述。
然而,可能会有一种方法是收集一个可以测试陈述的数据库,但这需要超出技术本身的东西,这将我们带回到最初的问题。
一个机器学习工程团队的碳排放
原文:
towardsdatascience.com/carbon-emissions-of-an-ml-engineering-team-ce170bd4fae9
开发的隐性成本
·发表于Towards Data Science ·9 分钟阅读·2023 年 10 月 16 日
--
由于人为活动导致的全球变暖,大家都意识到了气候危机。为了防止其灾难性后果[1],世界需要大幅减少我们的温室气体排放,许多国家设定了到 2050 年实现净零排放的目标。
近年来 AI 技术的蓬勃发展也引发了对其环境成本的担忧。如果我们仅仅关注其直接贡献,这将通过电力使用来训练和驱动模型。例如,训练具有 1750 亿参数的 ChatGPT-3 产生了高达 502 吨的碳当量排放(tCO2e)[2]。新兴的 Llama2 在训练其四个模型时产生了类似的 539 吨 tCO2e[3]。作为对比,每一个模型的排放量相当于一名乘客从纽约到旧金山单程飞行 500 次的排放量。
我在一个机器学习工程团队工作,这个问题也时常困扰着我。我们通过电力消耗贡献了多少碳排放?是否有减少的方法?于是,我们开始首次尝试进行碳排放核算。

图片由Chris LeBoutillier提供,来自Unsplash
方法
没有单一直接的方法来测量我们的电力消耗以及随之而来的碳影响。这是因为我们使用的平台和服务多种多样。我不会深入探讨技术实现,但从高层次来看,方法包括三种。
-
提供:确切的碳排放数据已经为我们计算好了。这是我们的云服务提供商(CSP)提供的。
-
工具:我们使用了像 Powermetics、Nvidia-SMI 和 Turbostat 这样的几种软件工具来测量功率(瓦特),这些工具跟踪我们笔记本电脑和本地服务器的 CPU 和 GPU 计算。
-
自我计算:当上述方法不可行时,我们使用代理方法进行计算。这包括记录计算的持续时间,估计芯片的利用率百分比,以及查找每种芯片类型的热设计功率(TDP)来计算功耗。其余平台以这种方式计算。
对于后两种方法,功率会被转换为能量(千瓦时),如果有的话,会使用支持数据中心的电力使用效率(PUE)来获得更准确的能量消耗。最后,使用该国或地区的电网排放因子(kgCO2e/kWh)来计算温室气体排放。
结果与思考
结果显示在下面的饼图中。

我们使用的每个平台的碳排放量 extrapolated 为整个 2023 年。图像由作者提供。
在碳排放排名方面,平台的排名并不特别令人惊讶,但我对百分比感到惊讶。我没有想到我们的开发笔记本电脑和 CICD 服务在非常重的使用下只产生了微量的碳。与此同时,我也没想到我们的本地开发和模型训练服务器会消耗比我们的云服务多三倍的碳。
回顾过去,我们最近将笔记本电脑升级到了最新的 Apple Silicon M2 芯片,这一芯片以高效著称。我们的 CICD 平台虽然拥有数千分钟的流水线运行时间,但使用的是最低计算芯片,实际上是无服务器的,仅在必要时运行。
对于我们的本地服务器,我们发现空闲的 Nvidia GPU 芯片仍然消耗大量电力,导致电力消耗膨胀。我们需要调查是否存在任何配置错误,如果没有,是否有更好的管理方法。
绿色计算
现在我们对碳排放的认识有了更好的了解,我们如何才能真正改变开发团队,采用更多绿色解决方案呢?
绿色计算这个术语已经存在了一段时间,并且已被组织或分类成不同的形式,但我认为下面这六个广泛的主题将帮助我的团队更清晰地管理绿色转型。

由 Ash from Modern Afflatus 提供的照片,来源于 Unsplash
1. 绿色 AI
这指的是寻找更高效地训练和推断模型以尽量减少质量损失的方法。它基本上意味着更快的训练和推断时间,以及更小的模型尺寸,使用更少的计算能力。使用更复杂的神经网络需要越来越大的数据集以及日益先进、昂贵且耗能巨大的 GPU 芯片。
幸运的是,这也是最新优化研究的热点。在过去几年中,我听说我的数据科学家同事们在各个领域使用更高效的架构、迁移学习、量化或知识蒸馏等压缩技术、ONNX、使用deepspeed、PEFT等,以应对当今大语言模型时代的挑战。毫无疑问,我们需要跟上开源世界最新实施的步伐,因为它们的好处已被证明是显著的。
2. 绿色应用
模型在没有周围代码来处理数据、训练模型和最终提供服务的情况下是无用的。需要对时间和空间复杂性、实现的算法以及各种预构建函数有基本的理解。还应使用性能分析工具来查找延迟和内存中的瓶颈。
另一个构建绿色应用程序的重要软件工程技能是理解任务和进程的管理、执行和协调。这需要对并行性、并发性、异步性、多处理和线程、队列、I/O 和 CPU 限制任务等概念有扎实的掌握。
更进一步,编程语言的选择也很重要。由于其广泛的支持和易用性,Python 已成为数据科学和通用编程中使用的顶级语言之一。然而,作为一种解释性语言,与其编译语言如 Go 相比,在能耗和速度(约 x20)方面显著逊色[4]。因此,值得花费时间学习另一种编译语言,以应对需要大量处理的工作。
3. 绿色服务器
训练和服务机器学习应用需要计算能力。这由托管在本地或云端的服务器提供。如果可能的话,使用云服务是保持绿色的最佳方式,因为云服务提供商有动力高效运行其数据中心,而且你可以根据项目需求灵活切换资源。无论如何,我们应确保两个关键因素:选择正确的硬件来完成任务,以及仅在需要时使用计算资源。
主要的 CSP 都提供了多样的服务器供选择。例如,AWS 有七个实例家族,每个家族包含不同的芯片、内存和其他规格,足以满足各种需求,如 GPU、CPU、内存密集型过程,甚至 ARM 或 x86 架构。我们应选择那些最佳匹配我们的用例的服务器,以便通过其硬件规格高效分配计算资源。
我们如何在需要时才计算?首先,对所有不使用的资源进行盘点并关闭。你会惊讶于遗留项目中还有多少闲置的服务。在架构设计方面,我们可以选择使用类似 AWS lambda 的无服务器计算,它只有在有流量时才使用资源,或者提供一个基本的长生命周期计算,具有水平扩展功能,可以自动响应负载增加。
4. 绿色存储
存储有多种形式,如对象存储、块存储和文件存储、容器注册表和数据库。我们可以使用两个一般性的指南来高效管理存储:减少存储大小和选择合适的存储类型。
数据的存储大小可以通过压缩来减少,一些常见的压缩工具包括 gzip 或用于归档的 tar.gz,它可以将大小减少一半。使用更高效的数据结构也可以是一个更好的替代方案。使用像 parquet 这样的列式格式不仅占用空间更少(>50%),而且由于其列式结构的特性,也使查询速度更快(提高 30 倍)。
以对象存储为例,有一些存储类别使用更少的能源。在 AWS S3 中,我们可以选择将不太重要的数据保留在一个区域,而不是在多个区域中复制。对于不经常访问的长期存储,我们可以将其放入“冷存储”(S3 glacier),那里使用的磁带驱动器相比 SSD 和 HDD 消耗更少的能源。还可以设置生命周期策略来自动在存储类别之间转换,甚至在项目结束时删除数据。
5. 绿色传输
数据需要在服务器、存储和其他设备之间来回传输。网络通信也需要能源,以支持复杂的网络设备、数据中心、传输基础设施和终端用户设备。对于我们这样的开发者,我们可以通过使用高效的传输协议以及减少传输的频率和距离来降低碳足迹。
在可能的情况下,应考虑使用 http/2 传输协议和 gRCP 框架,因为它可以以更紧凑的二进制格式传输,而不是传统的文本(JSON)有效载荷的 http/1。这样可以降低延迟和能源消耗。
将数据更靠近使用源,并安排它们的传输时间,也可以减少所需的能量。例如,运行自动化测试用例所需的依赖项可以缓存,并且仅在检测到新更改时才重新构建。镜像不需要每次都从 Dockerhub 拉取;我们可以将它们存储在我们的 CSP 注册表中,并在有新补丁时定期更新。
6. 绿色模板
这指的是高效代码、基础设施和流程的可重用性和可重复性。本质上,这是一种间接减少电力消耗的方式,因为实际实现来自前五个主题。然而,我认为这是最重要的一点,因为它是团队知识的总和。
这可以以文档或剧本的形式出现,设定团队职能和项目执行的标准,或为仓库、CICD 流水线、基础设施设置(例如 Terraform)和配置(例如 Ansible)提供现成模板。

决策象限以优先考虑高影响力和易于实施的解决方案。图片来自作者。
在这六个主题中的每一个,我都给出了一些示例,但这仅仅是冰山一角。在每个主题中实施的建议众多且令人望而生畏。然而,通过将每个建议放在决策象限中,估算它们在现有工作流程中的实施难度,以及它们的影响是否显著和协同,可以实现渐进式过渡。这将提供一些关于优先考虑哪些建议的指导。
设计原则
这种转变既不直接也不容易。即便是像我们这样热衷于可持续发展的开发者,也必须优先考虑业务需求。我们可以通过不把可持续性和碳效率在机器学习开发中视为与其他需求对立或相互排斥的概念来应对这一挑战。这将确保你仍然与业务目标保持一致,同时也能更容易获得管理层的支持,他们总是面临着交付的压力。
我们可以将其可视化为一个维恩图,使用 AWS 的六个设计支柱来构建其架构良好框架,其中功能或业务需求与可持续性重叠。

重新构想架构良好框架。图片来自作者。
事实上,如果你仔细思考,通常会发现协同影响。让我们来看一些例子:
-
压缩数据存储可以减少 x2 大小,从而节省成本和带宽,同时也减少了存储和传输所需的能量。
-
神经网络模型的量化在推理速度方面表现更好,从而消耗更少的能量。
-
移除 Docker 镜像中的未使用依赖项将通过减少潜在的可利用表面积来提高安全性,增加由于镜像体积较小而导致的部署速度,并减少存储和传输到及从你的注册中心所需的能量。
结论
总的来说,这为我们工程团队减少碳足迹的旅程提供了一个良好的开端。接下来还会有大量的工作需要识别、量化、标准化和教育每个绿色计算主题下的推荐措施。
我希望听到你在测量和减少开发团队碳足迹方面的旅程。请在下面的评论中分享!
致谢:这次碳核算尝试是我与我的同行* 杨可文 和 钟耀威* 一起完成的个人项目。*
免责声明:这里表达的意见和建议仅代表作者个人观点。*
参考文献
-
[1] 联合国环境规划署. 关于气候紧急情况的事实.
www.unep.org/facts-about-climate-emergency -
[2] 斯坦福大学. 人工智能指数报告 2023. 第二章,技术性能.
aiindex.stanford.edu/report/ -
[3] Touvron 等. 2023. Llama 2: 开放基础与微调聊天模型. arXiv:2307.09288.
browse.arxiv.org/pdf/2307.09288.pdf -
[4] Pereira 等. 2017. 编程语言的能源效率:能源、时间和内存如何相关?第 10 届 ACM SIGPLAN 国际软件语言工程会议论文集. 第 256–267 页。
碳足迹:为什么常见的说法可能不准确
创建强健的 CO₂ 情景以推动数据驱动的气候行动
·
关注 发表在 Towards Data Science ·7 min read·2023 年 3 月 31 日
--
图片由 David Aler 提供,来自 Unsplash / 插图及拼接由作者制作
气候变化的影响在全球变得越来越明显,从毁灭性的野火到创纪录的热浪和飓风。因此,越来越多的人在寻求减少碳足迹和帮助缓解气候变化影响的方法。然而,很难知道从何处开始或如何做出有意义的改变。我们介绍了一种使用开放和链接数据方案建模不同活动碳足迹的新方法。 相关研究文章 将于今年晚些时候在 Energy Reports 期刊上发表。
随着气候危机的持续展开,自然灾害发生频率增加,减少温室气体(GHGs)的紧迫性不断加大。其中一个关键挑战是理解我们日常活动对环境的影响。然而,识别个人减少潜力可能复杂且难以理解,因为隐性、嵌入式排放常常发生在其他地方且难以可视化。
近年来,你可能见过有关日常活动如接收电子邮件或观看电影的碳足迹的惊人插图。这些比较旨在让消费者对像碳排放这样的抽象概念有所了解。它们提高了意识,鼓励人们采取行动减少碳足迹。然而,为了做出明智的决策,验证这些说法的准确性至关重要。

Adriano Becker 在 Unsplash 上的照片
假新闻?
有一种说法认为,接收一年的电子邮件产生的碳排放量与驾驶一辆普通汽车行驶 200 英里相同。不幸的是,这一说法并没有经过严格的分析,并且存在重大舍入错误。作者后来与这一说法脱离了关系。
另一个流行的说法认为,观看仅 30 分钟的 Netflix 相当于驾驶近 4 英里。然而,这一说法也发现了重大错误,包括换算错误和未考虑终端设备的碳足迹。
显然,评估和讨论碳足迹情景需要可靠的信息,上述事件突显了透明和严格的方法论的必要性。
背景很重要
那么,为什么很难对看似基本的日常活动得到可靠的答案呢?考虑以下问题:我能开车行驶多远才能排放与长途飞行相同的碳量?在法国或德国从电网充电电动车——有什么区别?选择白肉而不是红肉,在碳足迹方面有什么好处?通过 VoIP 电话(如 WhatsApp)与普通手机通话相比,如何?
这些情境下的计算不仅受复杂技术规格的影响,还受到个体因素的影响,这些因素可能根据情境的不同而大相径庭。例如,与飞行相比,你的汽车燃油消耗扮演着重要角色。电动车充电时,本地电网的能源组合是一个关键因素。例如,2021 年法国电力部门的碳强度估计为每千瓦时 58 克 CO₂,而德国为 349 克。
因此,创建针对特定情境的强健排放情景至关重要。这对于进行有力比较和制定明确且可信的建议是基础。
数据模型
为应对这一挑战,我们提出了一种通用碳足迹情境数据模型,以提高碳足迹数据的可访问性和实用性。该模型旨在开放、链接和模块化,便于共享和重用。模型以 JSON 格式表示,包含几个实体,每个实体都有一个属性列表及其数据类型。实体包括Scenario、Scope、Component、Link、Source、Emission、Consumer和Consumption,每个实体都有自己独特的属性及与其他实体的关系。

自指碳足迹情境的数据模型(可选属性为斜体)。图像由作者提供。
任何场景都由统一资源标识符(URI)标识。它具有一个标题,标题中可选地包括一个超链接形式的参考。一个场景涵盖 1 到 3 个排放范围,这些范围由温室气体议定书定义。该实体可能包含描述。此外,它包括 1 个或多个组件或链接。后者简单地与另一个通过其 URI 标识的场景相关联。它还可以附有数量的指示。然而,组件必须包括数量和数量单位(例如,“km”,“kg”,“pcs”)。它还可以有一个定义消费者类型的类别(例如,“汽车”,“食品”,“电子产品”)。此外,该实体必须包括一个来源,它具有一个名称和类型(例如,“法国电网”和“电力”,或“优质汽油”和“汽油”)。它还可以包含描述(例如,“2022 年”)。
任何来源必须包括 1 个或多个排放量,这些排放量以键值对的形式实现。排放类型定义了键(“co2e”,“co2”,“ch4”,“n2o”,“hfcs”,“pfcs”,“sf6”,“nf33”)。值、单位(例如,“g”,“kg”,“t”)和基本单位(例如,“kWh”,“l”,“kg”,“km”)指定了排放细节。可以通过超链接公开此信息的来源。组件也可以有一个消费者,当有关能源效率的信息可用时(例如,食品的排放通常仅按生产 1 公斤来报告)。消费者有一个名称(例如,“波音 747”,“iPhone 14”),并可以有一个可选的描述。它还包括 1 个或多个消费量(某些消费者可能支持多种能源来源,例如,可以用不同类型的汽油加油的内燃机汽车;混合动力车使用汽油和电力),这些消费量以键值对的形式存在。能源类型定义了键(例如,“电力”,“汽油”)并与来源类型相对应。值、单位(例如,“kWh”,“l”)和基本单位(例如,“km”,“h”,“d”)指定了消费细节。这些信息可以通过参考进行补充。
这是一个简单的数据示例:
{
"title": "Mobility",
"scopes": [
{
"level": "Scope 1",
"list": [
{
"type": "component",
"consumer": {
"name": "Volkswagen Golf (2014)",
"description": "Engine ID 45, 4 cylinders, Manual 6-spd",
"consumptions": {
"diesel": {
"value": "0.0735046875",
"unit": "l",
"base_unit": "km",
"reference_url": "https://www.fueleconomy.gov/"
}
}
},
"quantity": "10000",
"quantity_unit": "km",
"source": {
"name": "Gas/Diesel oil",
"type": "diesel",
"emissions": {
"co2e": {
"value": "3.25",
"unit": "kg",
"base_unit": "l",
"reference_url": "https://bilansges.ademe.fr/index.htm?new_liquides.htm"
}
}
}
}
]
}
]
}
我需要一个翻译员
为了使数据生动,我们开发了一个基于网络的查看器,可以解释和可视化这种数据格式。该应用程序按元素和范围聚合排放数据,进行单位转换,并基于不同类型排放数据的可用性找到共同点(例如,“CO2e”,“CO2”)。在用户界面中,用户可以立即调整数据,通过调整每个元素的数量并即时连接不同的数据源来替换消费者组件和能源来源。
自定义场景可以作为 JSON 文件下载并通过 URL 共享,这使得场景协作变得简单。最后,一个 基准视图 使用户能够通过标识符比较两个或多个场景。
该 Web 应用程序是使用 JavaScript 构建的,并作为 GitHub 页面部署,这使得部署和更新变得容易。我们数据模型的自我引用结构还允许以分布式方式托管嵌套场景。数据解释器能够递归地获取和处理这些场景,从而在分布式环境中访问和分析复杂的碳足迹模型成为可能。

基于网页的数据解释器在查看模式下。图片由作者提供。
关于如何使用和部署该应用程序的详细文档可以在我们的 GitHub 仓库 中找到。在这里,你还可以访问源代码,以及演示和多个示例场景:
我们的方法旨在提高碳足迹场景的透明度、可访问性和可探索性,使个人能够做出关于其行为环境影响的知情决策。
TL;TR
我们提出了一种新颖的数据模型,用于生成可以适应本地或个人情况的碳足迹场景。我们的查看器应用展示了用户如何增强对不同活动相关碳排放的理解,从而做出更明智的选择。用户可以实时操作数据,观察更换不同组件(例如使用替代能源或减少特定材料的数量)对整体碳足迹的影响。此外,该应用还便于对不同场景进行并排比较,使用户能够评估其差异。应用程序的另一个重要功能是能够轻松共享场景,促进合作与知识共享,从而减少碳排放。
附言:根据我的假设,从巴黎飞往新加坡的长途航班相当于驾驶我的汽车行驶 3,200 公里。
B. Ruf 和 M. Detyniecki (2022),《碳足迹场景的开放与联接数据模型》,第七届国际可再生能源与节能大会 (ICREC),法国巴黎。
卡洛斯·阿尔卡拉斯与三大巨头
一张视觉数据对比图,展示了新兴网球明星与罗杰、纳达尔和德约科维奇的对比
·
关注 发表在 Towards Data Science ·6 分钟阅读·2023 年 7 月 29 日
--
照片由Filip Mroz拍摄,发布在Unsplash
背景
几十年来,男子网球与其主宰的三大巨头——罗杰·费德勒、拉斐尔·纳达尔和诺瓦克·德约科维奇几乎是同义词。这是一个独特的网球时代,如同在其他任何运动中一样,三位有史以来最成功的球员几乎都在同一代里竞技。在他们之间,这三大巨头(如下面的绿色区域所示)自 2003 年以来赢得了 82 场大满贯赛事中的 65 场。

图表 1
然而,随着 Roger 于 2022 年退役、Rafa 由于伤病目前缺席,以及 Novak 最近在一场史诗般的温网决赛中输给 Carlos,似乎传递接力棒的过程已经开始。
随着网球开始设想没有大三巨头的世界,球迷和专家们一直在想是否有球员能够再次拥有像 Roger、Rafa 或 Novak 那样辉煌的职业生涯,赢得超过 20 个大满贯。一两年前,甚至思考这样的问题都是愚蠢的,更不用说大声问出来了。
随着 20 岁的世界第一 Carlos Alcaraz 的惊人崛起,网球界不得不暂停并重新考虑。如今,考虑这个想法已不那么具争议性。作为网球迷,你可能会看到大量比较 Carlos 与大三巨头的评论和媒体报道。然而,很少有这些讨论将 Carlos 到目前为止的表现与大三巨头在职业生涯早期的表现进行比较。将 Carlos 在第十个大满贯正赛出场的水平与 Novak 在第七十一场的水平进行比较,虽然令人兴奋,但并未提供完整的画面。
数据来源与分析范围
作为一个网球迷和数据专业人士,我一直期待看到更多的同类比较,比较球员在职业生涯相似阶段的表现。幸运的是,得益于 Jeff Sackmann 在通过他的 tennis_atp GitHub 仓库提供巡回赛 ATP 比赛数据的伟大工作,这样的分析成为可能。这个仓库,更广泛地说,他的 GitHub 是任何对网球分析感兴趣的人的极好资源。另一方面,你可以在我的 GitHub 上找到本文中用于分析和视觉呈现的所有 Python 代码。
拥有这些数据并在 Pandas 和 Matplotlib 中进行一些分析,我们可以更细致地了解 Carlos 的早期职业生涯与大三巨头的比较。
我们将特别关注大满贯和大师赛 1000 赛事这两类最受追捧的赛事,并将比较范围扩大到包括 前十名球员,而不仅仅局限于 Carlos。这将有助于为表现提供更多背景和视角。
首先,这里是大三巨头与今天排名前十的球员(截至 2023 年 7 月 17 日)在大满贯和大师赛 1000 赛事中获胜数量的当前状态比较。

图表 2
我们可以明确说的是,目前大三巨头与前十名之间的差距相当大。了解 Roger、Rafa 和 Novak 到达现有地位的过程将是有帮助的。
比较历程
为了做到这一点,让我们对比一下上面显示的两大主要赛事类别中的球员历程。我们将从大师赛 1000 赛事开始,然后转到至关重要的大满贯赛事。
大师赛 1000
下图绘制了三巨头及七位其他前十名球员的历程,这些球员至少赢得过一个大师赛 1000 赛事。x 轴代表大师赛 1000 赛事的参与次数(从正赛开始),y 轴代表赢得的大师赛 1000 赛事的数量。在某一点上,球员趋势线越高,说明到那一点的表现越好。

图表 3
三巨头的长期性和持续成功显而易见,每位三巨头成员参加了超过 120 个大师赛 1000 赛事,并赢得了 20 多个冠军。然而,为了讨论方便,让我们聚焦于上面突出显示的左下角象限,比较两位杰出球员与三巨头的表现。

图表 4
达尼尔·梅德韦杰夫的六个大师赛 1000 冠军是当前前十名球员中最多的(不包括诺瓦克)。他起步较晚,在第 21 次出战时赢得了他的第一个大师赛 1000 冠军。以 42 次赛事赢得六个冠军,他目前略微领先于诺瓦克在此阶段的表现,并与罗杰持平。
然而,卡洛斯·阿尔卡拉斯是最引人注目的。卡洛斯在仅仅第 7 次出战时便获得了他的第一个大师赛 1000 冠军,他的起步比三巨头中的任何一位都要快。在参加了 16 个大师赛 1000 赛事后,他与拉法在四个冠军上并列领先,而在类似阶段,诺瓦克(2)和罗杰(0)则远远落后。
大满贯赛事
让我们看看这些趋势在最大舞台——大满贯赛事上是否也成立。卡洛斯和达尼尔是前十名中唯一在大满贯赛事中赢得过冠军的球员。让我们与三巨头对比他们的大满贯历程。

图表 5
达尼尔在 25 次出战中仅赢得一个大满贯冠军,目前稍微落后于诺瓦克(2),远远落后于罗杰(5)和拉法(8)。与此同时,卡洛斯则一飞冲天,在诺瓦克和罗杰获得首个大满贯之前就赢得了两个大满贯冠军。
由于每年只有四个大满贯赛事,因此对球员早期职业生涯的大满贯赛事数据集相对有限。为了进一步了解大满贯赛事的表现,我们将数据的粒度从赛事层面提高到比赛层面。
在下图中,x 轴代表大满贯比赛的数量(从正赛开始),y 轴代表赢得的大满贯比赛数量。

图表 6
这个视图让我们感受到每位球员在淘汰赛中的深度。在 44 场大满贯比赛后,卡洛斯与拉法在 36 场胜利上并列第一,如果他在今年晚些时候的美国公开赛首轮比赛中获胜,他将独占榜首。
结论
从上述数据可以明显看出,卡洛斯·阿尔卡拉斯是一位绝对的天才。如果我们考虑年龄因素,他的表现更为引人注目。让我们看看这些球员在 25 岁之前赢得的各项大师 1000 赛和大满贯赛事的数量。

图表 7
作为 19 岁时最年轻的男子世界第一,且在仅 20 岁时已经拥有两个大满贯和四个大师 1000 赛,卡洛斯前途广阔。
但正如大多数历代伟大球员一样,真正特别的是这三位伟大球员能够保持如此长时间的卓越。如果卡洛斯能保持健康,并继续像今天这样主导比赛,而这些图表中的 x 轴年复一年地延展,大三巨头可能需要腾出一些位置。
感谢阅读!
数据来源:Sackmann, Jeff. GitHub 仓库: github.com/JeffSackmann/tennis_atp
分析与视觉笔记本:github.com/asawhney27/Tennis-Analytics
案例研究:将数据科学过程模型应用于实际场景
供应链中材料规划的机器学习模型开发
·
关注 发表在 Towards Data Science ·16 分钟阅读·2023 年 3 月 11 日
--
在当今快速变化的环境中,公司面临的一个最关键的挑战是准确预测未来需求。这对于供应链团队尤其重要,准确的需求规划对于保持客户满意度和控制成本至关重要。
在这个案例研究中,我们将探讨数据科学过程模型如何通过利用统计预测方法,帮助公司实际解决这一挑战。虚拟公司的目标是开发一个更准确的需求规划过程,以减少缺货,提高库存周转率,并改善整体供应链表现。

图片由Unsplash提供
本项目是数据科学如何通过揭示新见解、提高效率和改善决策来改变业务的强大示例。我希望这个案例研究能帮助你考虑在组织中潜在的应用,并展示如何成功应用过程模型 DASC-PM。
请注意,整篇文章也已在以下出版物中发布,作者为 丹尼尔·巴杜拉 和 乔纳斯·迪克曼:
第三章:“供应链中材料规划的机器学习模型开发”,见:施尔茨等(2023):** DASC-PM v1.1 案例研究。可从:
www.researchgate.net/publication/368661660_DASC-PM_v11_Case_Studies获取
1. 领域和项目描述
SCHRAMME AG 是一家领先的敷料、创可贴和绷带供应商。管理层认为在材料规划及其产生的生产过程中存在定性优化潜力和节省机会。管理层指派了一名内部项目经理开发一个基于机器学习的模型,以规划供应链中的材料和需求。由于之前的数据科学项目中的负面经验,建议初期应使用过程模型来开发此项目。
选择DASC-PM以确保项目管理的结构化和科学化过程。为了获得项目任务的概述,项目经理最初制定了各种用例,然后检查其适用性和可行性。适用的用例将作为确定具体问题和设计项目的基础。随后,这一设计将再次检查其适用性和可行性。

图片由Unsplash提供
起始点和用例开发
目前,公司手动规划并生产超过 2500 种不同的产品。在过去几个季度,公司在一些产品系列中越来越频繁地出现库存短缺,而个别产品的库存则超过了储存能力。尽管控制部门抱怨由于不精确的规划导致库存成本上升,需求规划师却感叹规划时间不足。供应链负责人已经批评规划完全依赖人工,未能充分利用数字化的机会。
项目目标
该项目的一个目标是开发一个机器学习模型,未来应根据各种影响因素自动规划大量产品需求。需求规划师应逐渐关注重要产品组和广告的规划。系统应考虑季节性、趋势和市场发展,达到 75%的规划准确度。这意味着每种产品的预测数量与实际需求的偏差不应超过 25%。订单历史、库存和客户销售数据,以及内部广告计划应作为潜在的数据来源。

阶段 1: 项目订单(Schulz et al. 2022)
当前团队设置
除了供应链部门的参与外,还期望与销售和 IT 部门进行紧密合作。供应链部门的规划团队现在包括一个全球市场需求规划团队,负责根据市场发展、产品生命周期和战略重点进行长期规划(6-18 个月)。在个别市场中,还有地方客户需求规划团队,通过相应的销售渠道为零售实施短期物料和广告规划(0-6 个月)。
需要开发的数据科学模型应支持每月的规划周期,并量化短期和长期物料的需求。预测结果随后被加载到内部规划软件中,并应进行分析,如有必要,进行补充或修正。最终的规划数量将由工厂用于生产规划。为了考虑客户和产品的专业知识、季节性和过往经验,规划团队的个别成员应被纳入项目中,分配最多 20%的工作时间参与其中。
适用性检查
在用例选择过程中,一个重要的部分方面是适用性测试。项目经理试图检验项目是否从根本上可以被分类为可行,以及是否可以利用现有资源执行这些要求。专家访谈表明,该问题总体上非常适合数据科学的应用,并且类似的项目已经在外部进行并发布。数据科学团队确认有足够数量的潜在适用方法用于该项目,并且所需的数据源也可用。
最后,项目经理分析可行性。需要与 IT 部门协调,检查可用的基础设施和相关员工的专业知识。微软提供的云基础设施和数据科学团队对 Databricks 软件的经验使得项目看起来在基础上是可行的。由于计划者在实施阶段担任主要控制者,结果会被检查,因此项目风险总体上被分类为中等。

数据科学过程模型 DASC-PM(Schulz 等,2022)
项目设计
根据问题和领域的具体方面,项目经理、供应链负责人以及数据科学家现在负责正式设计该项目。
项目的目标被认为是提高计划准确性和减少人工流程,并且与开发适当的项目模型的目标相关。根据初步估算,成本框架总计为 EUR 650,000。建议开发时间框架为六个月,额外计划六个月用于过程整合。
由于与许多其他项目相比,在数据科学背景下通常无法进行全面规划和项目过程描述,项目经理仅为该过程准备了一个项目概要,其中包含前述部分已指明的基本要素。预算包括 1 名全职项目经理、2 名全职数据科学家和 0.5 名全职数据工程师的财政资源。如前所述,需求规划师应分配大约 20%的工作时间来分享他们的专业知识和经验。
整个项目应采用敏捷工作方法,并基于 Scrum 方法论的 DASC-PM 阶段进行处理。工作在数据获取、分析、利用和使用等领域以迭代方式进行,每个阶段都将前一阶段和后一阶段作为重点。如果在关键领域发现差距或问题,并且只能通过回到前一个阶段来解决,则回溯步骤尤为重要。项目概要以可视化的方式准备,并放置在 SCHRAMME AG 办公室中所有参与者都能看到的位置。然后,整个项目描述会再次检查其适用性和可行性,直到过程进入下一阶段。
2. 数据提供
数据准备
SCHRAMME AG 拥有多个可以纳入自动规划的数据源。除了 ERP 系统中的历史销售数据外,还可以选择来自 CRM 系统的订单历史和客户数据,以及库存和营销措施。Azure Data Factory 被用来准备一个基于云的数据管道,加载、转换和集成来自各种源系统的数据。自动预测的主要基础应该是订单历史:其余数据要么作为规划团队的背景信息使用,要么在需要时进行集群分析。在项目的初始阶段,各个数据源在质量和结构上仍存在较大差异。因此,与 IT 和技术部门一起进行调整,以便在后续阶段能够在坚实的基础上进行预测。

ELT 数据准备过程用于分析。图像由作者提供
数据管理
数据管理过程由数据工程师自动化,并根据每日计划进行,以始终保持最新状态。为了保持复杂度在合理范围内,最有前景的数据源会首先进行处理,随后通过持续集成/持续部署(CI/CD)逐步扩展数据管道。部署后,处理过的数据会存储在 Azure Data Lake Storage 中,供未来使用 Azure Databricks 进行分析。DataLake 还存储准备好的数据和分析结果的备份,以及其他数据如协议、质量指标和凭证结构。写入和读取授权以及计划版本也确保只能处理最新的规划周期,使得过去的值不再发生变化。

阶段 2:数据提供(Schulz 等,2022)
探索性数据分析
数据准备中的一个重要步骤是探索性数据分析(EDA),在这个步骤中会生成各种统计数据和可视化图表。结果展示了数据中的分布、离群值和相关性。EDA 的结果提供了下一阶段分析中需要考虑的特征的见解。在第二步中,使用特征选择和特征工程来选择相关特征或生成新特征。对于高维数据,应用主成分分析等降维方法。EDA 提供了关于 SCHRAMMEAG 现有需求历史的信息。

探索性数据分析的结果示例。图片由作者提供
3. 分析
识别合适的分析方法
项目开始时的可行性测试明确了该项目可以并应该使用数据科学方法解决。最初参与的两名数据科学员工提供了适合现有问题的现有方法概述。这个现有问题属于监督学习算法中的回归问题类别。从根本上讲,这是一种时间序列分析类型,可以通过额外因素或多重回归进行扩展。
在科学性关键领域的背景下,考察了对比问题的最新研究进展。这显示出 XGBoost、ARIMA、FacebookProphet 和 LightGBM 是该问题类别中经常提到的方法。一名数据科学家记录了每种方法的相应优缺点,并根据复杂性和计算强度对其进行排序。为了获得关于 SCHRAMME AG 产品模型能力的初步指示,项目组初步选择了更简单的模型,然后采用了经典的指数平滑和 ARIMA 模型系列。

第三阶段:分析(Schulz 等人 2022)
分析方法的应用
由于该项目涉及多个用户参与分析过程,团队最初依赖于 Databricks 中的合适笔记本开发环境。按照典型的机器学习工作流程,首先实现了导入和数据清洗的代码。为了确保有效性,最终通过交叉验证将基础数据集划分为训练、验证和测试数据。然后,将选定的方法应用于训练和验证数据集,以优化模型。在此过程中,还反复尝试优化处理参数,并在必要时合理减少可用维度。SCHRAMME AG 的数据科学家记录了各个运行的执行和验证结果。尽管 ARIMA 系列模型在相对指数平滑方面表现出更好的性能,但目前得到的 62.4%的目标准确性仍未达到 75%。RMSE 和 MAPE 指标也显示出优化的潜力。

ARIMA 预测与实际需求的对比。图片作者提供
参数配置和选择最终模型的基础在第一次应用迭代后被记录并以技术上可理解的方式为项目经理和供应链负责人准备。特别观察到的是,一些产品组具有非常不寻常的季节性,某些产品总体上非常难以预测。即使 SCHRAMME AG 的产品组合由于冠状病毒大流行期间的临时关闭(封锁)受到的影响较小,但仍观察到调味品产品的需求略有下降。假设活动和运输减少,以及事故和伤害减少是导致这一下降的原因。
趋势可以在使用的分析方法中建模得相当好。为了提高目标准确性,在另一个实验中使用了技术上更复杂的方法,这些方法在识别合适方法的过程中证明是相关和适用的。在进行了一些参数优化和交叉验证迭代后,Prophet 和 XGBoost 方法分别展示了 73.4%和 65.8%的最高验证结果。
数据科学家认为 Prophet 是应用过程中最合适的方法,并根据测试时间序列确定规划准确性。即使准确性略低于目标值 73.4%,仍然取得了显著的规划准确性改进。MAPE 为 16.64%,RMSE 为 8,130,这表明与 XGBoost 方法中的 RMSE(10,134)相比,绝对偏差较小。然而,与第一次实验类似,仍然存在一些非常难以整体预测的产品组(37.2%),对累计准确性产生了负面影响。

各种方法的性能比较。作者图像
评估
分析结果作为供应链负责人和分析师逻辑评估和分类的基础,由项目经理组织和主持。采用的评估指标是所有预先定义产品的累积计划准确性以及常用的 RMSE 和 MAPE 指标。部门需要一个现实、可追踪和可靠的基础来确定产品级别的需求。

三个最佳模型的评估。作者图像
计划准确性的基准假设为过去两年中(手动计划的)中位准确率 58%。结果的评估显示,通过使用数据科学模型,许多产品组整体上可以以较高的准确性进行规划,并大大超过了基准。然而,也有一些产品组在手动规划方面表现出类似的准确性。尤其需要讨论的是排水领域,该领域使用模型的结果远差于手动规划,似乎不适合使用目前的方法进行统计需求计算。

最佳模型的评估,按产品组分布。作者图像
从技术角度来看,供应链负责人认为由于特定的季节性和趋势性特征,仅能实现有限的计划准确性,因此对这些产品组进行统计规划没有多大意义。她建议引入一个产品基础上的误差阈值,以确定哪些产品应通过模型进行预测,哪些产品组将从建模中剔除并仍由人工计划。略低于当前基准的范围似乎是一个合适的阈值,因为从部门的角度来看,较少的人工干预能带来几乎相同的准确性,这总是朝着实现项目目标的方向上的一种改进。项目负责人记录了评估结果以及所采取的决策和措施。
经过首次实际建模后,所有选择的产品在接下来的 18 个月所需的数量可以作为分析结果进行记录。现在可以将其利用并整合到团队的规划过程中。
4. 部署
团队现在进入 DASC-PM 的整合利用阶段。

第 4 阶段:部署(Schulz 等人 2022)
技术方法准备
可以依赖现有基础设施进行利用。预测数据被加载到规划软件 IBM Planning Analytics 中,在那里进行测试和重新处理。所谓的 TurboIntegrator 被用来自动化加载过程,这是 IBM Planning Analytics 的一个核心组件。Planning Analytics 的 OLAP 结构允许创建灵活的视图,用户可以选择他们的上下文(时间参考、产品组等)并实时调整计算。此外,报告软件 QlikSense 也被集成用于更深入的分析。在这里,一方面可以可视化时间序列的组件(趋势、季节性、噪声),另一方面可以显示诸如异常值和中位数等附加信息。处理后的最终计划会被加载到数据湖中,以便未来参考。
确保技术可行性
预测本身会在每月初自动重新生成。规划人员可以在月初的前四个工作日内进行修正,并实时查看规划系统中的结果。由于算法在云环境中运行,计算能力可以根据需要进行扩展。为了使所有过程自动运行,应尽量减少数据源的变化。如果需要调整,数据工程师将会被通知,并通过记录所有数据源和连接的信息来更新接口文档。规划和预测系统是云(Microsoft Azure)和本地系统(Planning Analytics)的混合体,规划人员仅对本地结构拥有主动访问权限。这里授予凭证,使本地规划人员仅能访问他们的区域,而全球规划人员可以查看所有主题。开发阶段结束后,支持服务主要由 IT 部门处理。在复杂问题的情况下,还会咨询数据科学家或数据工程师。

图片来自 Unsplash
确保适用性
解决方案的用户包括本地和全球规划团队。由于团队成员技术背景较少,举办培训课程帮助他们解读预测结果并评估其质量。用户界面也以清晰和易懂为设计重点。使用简单的折线图和条形图来展示过程和基准,同时表格内容精简到最重要的信息。用户从一开始就参与开发,以确保技术的正确性和相关性,并在开发阶段结束前熟悉解决方案。此外,还会编写完整的文档。文档的技术部分主要基于接口文档,展示数据结构和连接,而内容部分则与用户共同准备。
技术准备
为了确保新解决方案在几个月后不会失去相关性或质量,尽管投入的时间大幅减少,工作仍继续进行改进。持续改进中最重要的方面是不断自动调整预测模型以适应新数据。系统中在开始时仍需手动处理的其他部分也会随着时间的推移实现自动化。规划人员可以在 Planning Analytics 中调整诸如预测范围或预测准确度阈值等参数,模型保持灵活。发布首个版本后出现的问题通过 IT 票务系统记录,并分配给数据科学领域。定期检查模型是否仍满足公司的期望,是否需要进行更改。
5. (应用)使用和总结

第 5 阶段:应用(Schulz 等,2022 年)
过渡到使用开发的模型意味着数据科学过程模型(DASC-PM)进入其最后阶段。SCHRAMME AG 通过使用结构化和整体的方法,在供应链领域实现了其设定的目标。现在可以从中衍生出额外或新的项目。规划过程大部分已实现自动化,并得到机器学习算法的支持。管理、财务和供应链中的相关利益相关者对结果感到非常满意。经过最初的怀疑,规划团队现在也对工作负担的减少和可能的优先排序感到信服。然而,也可以预见在使用过程中会出现弱点,并且在后续阶段可能需要更多的迭代。
整体案例研究表明,特别是非线性过程模型在数据科学领域具有优势。DASC-PM 是一种适用于转移到许多其他领域和问题的合适新型过程。
结论
总结来说,数据科学在解决复杂业务问题中扮演着不可或缺的角色,通过识别隐藏的模式并从数据中提取可操作的见解。通过这个案例研究,我们展示了如何利用数据科学技术开发预测模型,帮助企业做出明智的决策,例如在供应链中。
虽然这个案例研究侧重于需求规划,但该过程模型可以用于多种方式,例如在电子商务网站上构建个性化推荐、识别金融交易中的欺诈行为或预测电信或订阅型业务中的客户流失。
然而,必须注意的是,现实世界的数据科学项目面临多个挑战,如数据质量问题、缺乏领域专业知识和利益相关者之间的沟通不畅。相比之下,虚构的案例研究提供了一个理想化的环境,拥有干净、标记良好的数据和明确定义的问题陈述。因此,现实世界的项目需要一种务实的方法,考虑到业务目标、数据质量、计算资源和伦理问题。我相信你从自己的经验中知道这一点。不要低估现实!
总之,数据科学具有巨大的潜力来改变行业、社会,并为企业创造新的机会。DASC-DM(或任何)过程模型可以帮助合理地构建方法,以确保对业务利益相关者和项目团队本身的明确指导。
请告诉我你在数据科学项目中的经验。你如何构建这些项目?最大的挑战是什么?欢迎留言!

图片由Unsplash提供
阅读 Jonas Dieckmann 在 Medium 上的文章。分析经理和产品负责人@Philips | 对…充满热情并撰写相关内容
medium.com](https://medium.com/@jonas_dieckmann?source=post_page-----93ae57b682bf--------------------------------)
希望你觉得这有用。告诉我你的想法!同时欢迎在LinkedIn上连接我,或在 Medium 上关注我。
另见我的其他文章:
理解数据偏见的另一种方法
[towardsdatascience.com ## DASC-PM:数据科学项目的新型过程模型
或:如何正确地构建下一个数据科学项目
[towardsdatascience.com
参考文献
整个案例研究已发布于:
[1] Schulz et al. (2023): “DASC-PM v1.1 案例研究” 可从:www.researchgate.net/publication/368661660_DASC-PM_v11_Case_Studies 获取
过程图像取自:
[2] Schulz et al. (2022): “DASC-PM v1.1 — 数据科学项目的过程模型”*(2022),出版商:NORDAKADEMIE gAG Hochschule der Wirtschaft,ISBN:978–3–00–064898–4,DOI:10.25673/32872.2
案例研究:使用彩虹方法进行实际标签编码
在 MassMutual 生产模型上的真实世界测试
·
关注 发表在 Towards Data Science ·7 分钟阅读·2023 年 2 月 24 日
--
与 德米特罗·卡拉巴什 共同编著

在我们之前的文章“隐藏的数据科学瑰宝:用于标签编码的彩虹方法”中,我们讨论了在开发基于树的模型时,使用标签编码而非独热编码的优势。我们介绍了彩虹方法,这有助于确定不同类型的分类变量的最合适的有序编码。
在本文中,我们将继续探讨 Rainbow 方法——这一次,从实际的角度,展示其在MassMutual的数据科学团队开发的真实项目中的有效性,MassMutual 是一家知名的寿险公司,致力于推动数据科学家、工程师和技术专家来帮助做出明智的商业决策。
商业用例
目标是预测每个潜在客户的五个思维模式细分中的一个。实质上,这是一个多类别分类问题。
细分框架包括五个类别,反映了一个人的年龄、财务稳定性以及对金融决策的态度。MassMutual 营销团队随后在各种活动中使用预测的细分进行目标定位和定制。

图 1(由 Anna Arakelyan 绘制)
例如,展现出思维模式 A 的客户倾向于在决定是否购买人寿保险时优先考虑独立性和自主性,而思维模式 B 的客户则通常更愿意从专门的顾问那里获得指导和详细的金融产品解释。
数据代表了一小部分标记个体(17.5K 人),标签由设计了细分分配规则的 MassMutual 供应商提供。我们首先将主潜在客户数据库中的列添加到这些数据中。目标是使用这些目标标签和可用特征学习最佳模型,并预测所有其他(未标记)潜在客户的细分。
我们使用的消费者数据库涵盖了大约 300 列,代表了多种人口统计特征,如家庭组成、收入和净资产倾向、金融行为以及数字敏锐度。
在本文中,我们通过消费数据库和思维模式细分项目,将传统的一热编码与 Rainbow 编码进行比较。我们展示了一些标准指标——如宏观平均 F1 得分、宏观平均 AUC ROC、Cohen’s Kappa 和准确率——用于解释和比较这个 5 类分类问题。
分类变量
我们选择了消费数据库中的所有分类变量——包括区间变量、序数变量和名义变量——但排除了定量变量和二元变量。目的是展示相同分类因素下,一热编码和 Rainbow 编码在模型性能上的差异。
我们进行了目标分层的 4 折交叉验证拆分,并且从这一点开始的所有数据处理都在交叉验证循环内完成。这包括从每个折的训练集创建一热特征和 Rainbow 特征,然后将它们应用于每个折的验证集。
总共 111 个变量被转换为 121 个 Rainbow 特征,另外转换为 2260 个一热特征。

表 1. 编码前后的变量列表
对区间和序数变量的 Rainbow 转换非常简单,结果是从 64 个区间特征得到了 64 个 Rainbows,从 14 个序数特征得到了 14 个 Rainbows。
名义变量的转换更为复杂,我们为剩下的 10 个变量创建了 23 个自然属性 Rainbows 和 20 个人工 Rainbow 特征。由于我们处理了五个类别,我们对随机类别应用了相关排序和目标百分比排序(见原文的自动化 Rainbow 选择部分)。例如,名义变量“Financial_Cluster”被转换为特征“Financial_Cluster_Mindset_B_correlation_rank”和“Financial_Cluster_Mindset_D_target_percent”。总体而言,33 个名义变量被转换为 43 个 Rainbows。
对于实际排序的选择——无论是自然属性 Rainbow 还是人工 Rainbow——高度依赖于项目和上下文,更多的是艺术而非科学。这需要在模型简洁性、性能和可解释性之间取得平衡。
与序数编码不同,One-hot 转换生成了超过两千个特征。
为什么我们在这里为区间和序数变量制作 One-hot 特征?因为我们希望在从完美顺序到模糊顺序,再到无顺序(或错误顺序)的完整连续体上,将 Rainbow 与 One-hot 进行比较。
此外,将变量分类为序数或名义有时是一种主观决定。一个明显的例子是颜色。正如我们在第一篇文章中讨论的,颜色被一些模型者认为是名义的,而另一些则认为是序数的。
起初,我们将所有类别变量进行汇总,但在文章后面,我们分别分析了区间、序数和名义变量。
我们训练了所有的XGBoost模型,涵盖了下面显示的超参数空间:
params = {
'objective': 'multi:softprob',
'eval_metric': 'mlogloss',
'num_class': 5,
'subsample': 0.8,
'max_depth': [2, 3, 5],
'eta': [0.1, 0.3, 0.5],
'n_estimators': [50, 100, 200],
}
我们避免将 max_depth 设置得高于 5,因为数据量相对较小,并且每个分支的末端需要至少 100 个样本。我们倾向于使用简单模型,这也有助于防止过拟合。
下面的所有结果表示交叉验证的平均指标。
综合结果
让我们从所有运行的总体平均值开始。显然,对于 Rainbow 编码,所有模型的平均指标更高。总体差异为几个百分点。

图 2(作者生成)
超参数
下面的图示展示了在保持所有其他超参数不变的情况下,每个超参数的指标变化。这些图示也清楚地表明,Rainbow 的结果在每个超参数和指标上都超过了 One-hot 的结果。

图 3a(由作者生成)

图 3b(由作者生成)

图 3c(由作者生成)
运行时间
接下来,让我们比较每种方法的运行时间。
One-hot: 65.059 s
Rainbow: 5.491 s
运行一个“Rainbow”模型的平均时间几乎是运行一个“One-hot”模型的 12 倍!因此,除了显著提高模型性能指标外,我们还可以看到,“Rainbow”方法可以为数据科学家节省大量时间。
间隔型、序数型和名义型
接下来,我们分别运行了包含间隔型、序数型和名义型特征的模型。结果列在下面。

图 4(由作者生成)
这些结果再次强调了“Rainbow”相比于“One-hot”的优势。正如预期的那样,“Rainbow”编码对间隔型和序数型特征的提升最大,而对名义型变量的提升则较小。
显然,类别顺序越明确,选择“Rainbow”而非“One-hot”的好处就越大。虽然“Rainbow”对名义型变量的表现与“One-hot”相似或略低,但它仍然能以显著较少的时间和空间达到相同的性能水平,生成的模型也显著更简单。
特征选择
最后,为了确保在维度方面的公平比较,我们从每个特征集(Rainbow 和 One-hot)中选择了前 10、50 和 100 个特征。我们利用了XGBoost模型的特征重要性属性,并聚合了四次交叉验证折中的特征重要性分数,以获得每种编码类型的最佳超参数集。结果如下所示。

图 5(由作者生成)
“Rainbow”编码轻松超越了“One-hot”编码,特别是在特征数量较少的情况下。“Rainbow”编码比“One-hot”编码更快地达到性能峰值,并且使用的特征更少。实际上,只有 10 个特征时,“Rainbow”编码已经接近其峰值,而“One-hot”编码则需要 50–100 个特征才能达到类似水平!
此外,“Rainbow”编码在 50 个特征上的结果甚至优于“One-hot”编码在 100 个特征上的结果。值得注意的是,当特征数量从 50 降至 10 时,“One-hot”编码的 Macro-F1 降低幅度是“Rainbow”方法的六倍(Kappa 和 Accuracy 降低幅度为三倍,Macro-AUC 降低幅度为两倍)。
结论
MassMutual 的心态分段模型的例子清楚地说明了 Rainbow 标签编码优于 One-hot 编码。不仅为建模人员节省了大量时间,还显著降低了维度,并提供了一个有机的特征选择框架。此外,如果所选的 Rainbow 顺序与数据生成过程一致,那么这种编码还可以显著提升模型性能指标。
CatBoost 回归:为我详细讲解一下
原文:
towardsdatascience.com/catboost-regression-break-it-down-for-me-16ed8c6c1eca
CatBoost 内部工作原理的全面(并且有插图)解析
·发表于 Towards Data Science ·阅读时间 14 分钟·2023 年 9 月 2 日
--
CatBoost,代表类别增强,是一种强大的机器学习算法,在处理类别特征和产生准确预测方面表现出色。传统上,处理类别数据是相当棘手的——需要使用独热编码、标签编码或其他一些可能扭曲数据固有结构的预处理技术。为了解决这个问题,CatBoost 使用了其内置的编码系统,称为 有序目标编码。
让我们通过构建一个模型来预测某人如何给书籍 Murder, She Texted 打分,基于他们在 Goodreads 上的平均书籍评分和他们的最爱类别,来看看 CatBoost 在实践中是如何工作的。
我们让 6 个人对 Murder, She Texted 进行评分,并收集了关于他们的其他相关信息。

这是我们当前的训练数据集,我们将用它来训练(显而易见)数据。
第 1 步:随机打乱数据集并使用 有序目标编码 对类别数据进行编码
我们处理类别数据的方式对 CatBoost 算法至关重要。在这种情况下,我们只有一个类别列 — Favorite Genre。这个列被编码(即转换为离散整数),具体的编码方式取决于这是回归问题还是分类问题。由于我们处理的是回归问题(因为我们想预测的变量 Murder, She Texted Rating 是连续的),我们按照以下步骤进行。
1 — 随机打乱数据集:

2 — 将连续目标变量分成离散的 桶:由于我们这里的数据非常少,我们将创建两个相同大小的桶来对目标进行分类。(了解更多关于如何创建桶的内容,请参见 这里)。
我们将Murder, She Texted Rating的 3 个最小值放入桶 0,其余的放入桶 1。

3 — 使用公式对分类列进行编码:Ordered Target Encoding 假设它一次接收一行数据,并使用此公式对最喜欢的类型进行编码:

-
curCount = 我们之前见过的拥有相同最喜欢的类型且在评分桶1 中的人数
-
prior = 用户定义的常数值;在我们的例子中设置为 0.05
-
maxCount = 我们之前见过的拥有相同最喜欢的类型的人数
注意:如果我们有更多数据,我们将有更多的桶。我们使用不同的公式来编码分类数据。阅读更多这里。
使用这个公式,让我们对第一行进行编码。由于这是第一行,我们假设之前没有数据,这一行是我们唯一的信息。

这里:
-
curCount = 我们之前见过的在评分桶1 中且最喜欢的类型为神秘的人的数量 = 0
-
maxCount = 我们见过的最喜欢的类型为神秘的人的数量 = 0

因此,第一行中神秘的编码值是 0.05。

现在对于第二行,我们假设唯一的数据是前两行。

-
curCount = 我们之前见过的在评分桶1 中且最喜欢的类型为浪漫的人的数量 = 0
-
maxCount = 我们见过的最喜欢的类型为浪漫的人的数量 = 0

与第一行类似,第二行的编码值是 0.05。

对于第三行:

-
curCount = 我们之前见过的在评分桶 1 中且最喜欢的类型为神秘的人的数量 = 0
-
maxCount = 我们见过的最喜欢的类型为神秘的人的数量 = 1


类似地,如果我们对剩余的行进行此编码,我们得到:

就是这样,我们对分类变量进行编码。
现在我们可以忽略最喜欢的类型并 仅考虑编码后的最喜欢的类型*。

步骤 2:进行初步预测并计算残差
CatBoost 从对所有行进行初始的Murder, She Texted Rating预测 0 开始。

然后我们使用这个公式计算称为残差的东西:


第 3 步:构建 CatBoost 树
现在我们有了残差,可以开始构建 CatBoost 树了。阅读我之前的文章决策树和 XGBoost 可能会对你理解决策树有所帮助。
查找根节点
我们通过比较使用最喜欢的类型(编码)与平均 Goodreads 评分作为根节点的效果,确定树根(第一次拆分)的最佳阈值。
首先,我们需要根据最喜欢的类型确定拆分树的候选节点。为此,我们必须将最喜欢的类型的值按升序排序:

然后我们计算最喜欢的类型中相邻值的平均值:

我们的最喜欢的类型拆分候选值是这些平均值——0.0375、0.05、0.2875 和 0.525。
我们尝试的第一个候选是最喜欢的类型 < 0.0375:

树的叶节点是绿色的。CatBoost 初始化了一个叫做输出的东西,将拆分的叶节点设置为 0:

如果最喜欢的类型小于 0.0375,我们会落在左叶节点;否则,我们会落在右叶节点。当每行数据传递到树中时,其残差被放入叶节点。
所以将第一行数据传递到树中……

……我们将其残差放入右叶节点,因为最喜欢的类型为 0.05,大于 0.0375:

然后我们跟踪该行的叶节点输出:

然后我们将树中输出的值更新为叶节点中残差值的平均值。在这种情况下,由于叶节点中只有一个残差,输出为 3.5。

现在将第二行数据传递到树中:

我们也将其残差放入右叶节点,因为 0.05 > 0.0375:

残差最终落在右叶节点
我们存储叶节点输出值:

然后我们通过计算叶节点中两个残差的平均值来更新输出值:

正确输出:3.5 => 3
现在让我们传递第三行数据:

残差最终落在左叶节点,因为最喜欢的类型 = 0.025 < 0.0375
跟踪叶节点输出:

更新叶节点的输出值:

左侧输出:0 => 4
最后,让我们将最后三行运行在树上。我们得到这棵树…

…以及这张表:

存储了叶节点输出值的最终表
量化这个根节点的“好坏”
CatBoost 通过计算 叶节点输出 列和 残差 之间的 余弦相似度 来量化划分的好坏。余弦相似度的公式是:

其中 A 和 B 只是我们试图比较的两列。
所以要计算 残差 和 叶节点输出 列的余弦相似度…

…我们将相应的值代入公式:

我们发现余弦相似度为 0.786。因此,阈值 Favorite Genre < 0.0375 的余弦相似度为 0.786。
现在使用与上述相同的过程,我们构建一个使用第二个候选根阈值的树:Favorite Genre < 0.05。重复相同的过程,我们得到这棵树:

…以及这张表:

…余弦相似度为:

注意:这与我们使用阈值 Favorite Genre < 0.0375 得到的值相同,因为残差落在相同的叶节点中。
让我们尝试下一个根阈值候选项:Favorite Genre < 0.2875。我们得到的树是:

…以及这张表:

…并且残差和叶节点输出的余弦相似度为 0.84。在这里,由于余弦相似度大于其他两个阈值的相似度,我们得出结论:Favorite Genre < 0.2875 是比 Favorite Genre < 0.0375 和 Favorite Genre < 0.05 更好的根节点划分。
现在,我们来进行最后的划分:Favorite Genre < 0.525。我们得到余弦相似度为 0.84,这使我们得出结论:0.2875 和 0.525 的划分效果相似。
但请记住,我们只测试了 Favorite Genre 的候选项。 接下来,我们需要测试 Average Goodreads Rating 的根节点候选项。为此,我们需要通过将列按升序排列并计算相邻的平均值来确定 Average Goodreads Rating 的划分候选项。

对于每一个平均值,我们构建一棵树并计算 叶节点输出 和 残差 之间的余弦相似度:


比较所有候选根节点的余弦相似度值,我们发现Average Goodreads Rating < 3.65 的余弦相似度最高,为 0.87。所以我们选择这个作为我们的根节点分割

获得根节点后,我们可以通过添加新分支来扩展树。为此,我们遵循与之前类似的过程,但不是选择根节点,而是选择从叶子处分裂出来的分支。选择具有最高余弦相似度值的分割。
CatBoost 树的一个注意事项是它们是对称的,意味着同一层上的每个分支使用相同的阈值。
注意:如果你对为什么要构建对称树感到好奇,你可以在这里阅读更多内容。
一个示例是:

在这种情况下,同一层的两个节点使用相同的分割。
由于我们数据很少,最好只构建深度为 1 的树。
注意:树的深度是我们可以调整的模型参数,主要用于避免过拟合。在大多数情况下,最佳深度范围是 4 到 10,建议使用 6 到 10 的值。
就这样,我们有了第一棵树!

步骤 4:进行新的预测
现在我们使用旧的预测和这个公式进行新的预测:

学习率是另一个我们可以调整以避免过拟合的参数。有关更多信息,请阅读这里。现在,让我们将其设置为 0.1。
让我们回到我们的表格。

使用公式,我们可以计算新的预测。对于第一行,新预测将是:

同样,如果我们计算其余行的新预测值,我们得到:

我们可以看到我们的新预测不够准确,因为它们与Murder, She Texted的实际评分仍然有显著差异。然而,相比于之前全为零的预测,已有所改善。
我们的下一步是构建一棵新树。但在此之前,让我们先快速清理一下嘈杂的数据集,以便更容易处理。我们可以忽略旧的预测、残差列和叶子输出列…

…将“新预测”列重命名为“预测”(因为它不再是新的了)

…并将编码的最喜欢的类型的值替换为我们原始的最喜欢的类型:

步骤 5:使用步骤 1-4 构建新树
现在我们重复构建第一棵树时所做的相同步骤来构建第二棵树。
使用步骤 1,我们打乱数据集…

…并使用有序目标编码对类别数据(即最喜欢的类型)进行编码:

按照步骤 2,因为我们已经有了预测,所以不需要进行初步预测。我们只需使用上述相同的公式计算残差:

我们得到以下残差:

然后我们使用步骤 3构建第二棵 CatBoost 树。假设在测试所有编码后的最喜欢类型和平均 Goodreads 评分候选项后,我们找到的最佳根节点是编码后的最喜欢类型 < 0.288。由于树的深度为 1,如之前设置的,我们最终得到这样的树:

…以及带有叶节点输出的更新表:

最后,使用步骤 4,我们进行新的预测。使用这个公式…

…我们得到新的预测:

我们可以看到,新预测稍微比旧的预测更好。如果我们继续这个过程并构建更多的树,我们的预测会越来越好。
注意:我们继续构建树,直到我们的预测足够好,或者直到达到我们可以设置的树的数量参数值。
使用我们的 CatBoost 树进行预测
假设我们用上面这两棵树完成了模型构建过程。(默认情况下,CatBoost 构建 1000 棵树)。我们现在有一个 CatBoost 模型(显然,由于我们只有 2 棵树,它不会接近一个好的模型),我们可以开始进行预测。


现在使用我们的模型,我们想要预测这两个人会如何评分Murder, She Texted。

首先,我们需要对类别数据进行编码——最喜欢的类型。编码新数据的过程类似于我们编码训练数据的方式;唯一的区别是我们使用整个训练数据集进行编码。

将Murder, She Texted Rating分配到之前使用的相同评分桶中:

现在我们使用上述相同的公式进行编码:

然而,我们使用整个训练数据集,而不是像训练过程中那样按顺序处理。例如,最喜欢的类型为神秘剧的编码将是:

同样,其他最喜欢的类型的编码是:

我们在新数据集中替换了编码值:

现在对于第一个人,我们回到我们的树,并将数据传递下去:


然后我们使用这个公式来进行预测:

所以我们的预测是:

当然,这是一个糟糕的预测。但请记住,这是一个相当糟糕的模型。我们拥有的树越多,我们的模型表现就会越好。
同样,对于第二个人:

就这些了。这就是我们如何构建 CatBoost 树并利用它们对新数据进行预测!
除非另有说明,所有图片均由作者提供
你可以通过LinkedIn与我联系,或者通过shreya.statistics@gmail.com发送电子邮件给我,提出问题和建议,尤其是对任何你希望我讲解的其他算法!
《抓住我,如果你能:Python 异常处理指南》
通过智能异常管理,释放 Python 的全部潜力
·
关注 发表在 Towards Data Science ·6 分钟阅读·2023 年 5 月 8 日
--
图片来源:Cookie the Pom 在 Unsplash
作为软件开发者,处理异常通常被视为一种必要的恶行。然而,掌握 Python 的异常处理系统可以让你成为更高效、更有效的程序员。
在这篇博客文章中,我将对以下内容进行深入解释:
-
什么是异常处理?
-
if语句与异常处理的区别 -
使用
else和finally子句进行正确的错误管理 -
自定义异常的定义
-
异常处理的最佳实践
什么是异常处理?
异常处理是编写代码以捕获和处理可能在程序执行期间发生的错误或异常的过程。这使得开发者能够编写即使在面对意外事件或错误时也能继续运行的健壮代码,而不是完全崩溃。
当发生异常时,Python 会搜索匹配的异常处理程序。处理程序代码将执行并采取适当的行动,如记录错误、显示错误信息或尝试从错误中恢复。总体而言,异常处理有助于使 Python 应用程序更加可靠、可维护,并且更易于调试。
if 语句和异常处理之间的区别
if 语句和 Python 中异常处理的主要区别在于它们各自的目标和使用场景。
if 语句作为结构化编程的基本构建块。它评估一个条件,并根据条件是否为真执行不同的代码块。以下是一个示例:
temperature = int(input("Please enter temperature in Fahrenheit: "))
if temperature > 100:
print("Hot weather alert! Temperature exceeded 100°F.")
elif temperature >= 70:
print("Warm day ahead, enjoy sunny skies.")
else:
print("Bundle up for chilly temperatures.")
异常处理在编写健壮且有弹性的程序中扮演着重要角色,它通过处理在运行时可能出现的意外事件和错误来实现这一点。
异常用于信号问题,并指出代码中需要改进、调试或额外错误检查措施的区域。它们允许 Python 优雅地处理错误情况,并继续执行脚本,而不是突然终止。
下面是一个如何实现异常处理以更好地管理与除零相关的潜在失败的示例:
# Define a function that tries to divide a number by zero
def divide(x, y):
result = x / y
return result
# Call the divide function with x=5 and y=0
result = divide(5, 0)
print(f"Result of dividing {x} by {y}: {result}")
输出:
Traceback (most recent call last):
File "<stdin>", line 8, in <module>
ZeroDivisionError: division by zero attempted
由于引发了异常,程序在到达 print 语句之前立即停止执行。
我们可以通过将对“divide”函数的调用放入 try-except 块中来处理上述异常,如下所示:
# Define a function that tries to divide a number by zero
def divide(x, y):
result = x / y
return result
# Call the divide function with x=5 and y=0
try:
result = divide(5, 0)
print(f"Result of dividing {x} by {y}: {result}")
except ZeroDivisionError:
print("Cannot divide by zero.")
输出:
Cannot divide by zero.
通过这样做,我们优雅地处理了 ZeroDivisionError 异常,而不会因为未处理的异常使脚本的其余部分失败。
有关 Python 内置异常的更多信息,请参见 [2]。
使用 Else 和 Finally 子句进行正确的错误管理
在处理 Python 中的异常时,建议在 try-except 块中同时包含 else 和 finally 子句。else 子句允许你指定如果没有引发异常时应该发生的情况,而 finally 子句确保无论是否发生异常,某些清理操作总是会执行 [1][2]。
例如,考虑一个场景,你想从文件中读取数据并对数据进行一些操作。如果在读取文件时发生异常,你可能想记录错误并停止进一步处理,但仍然想正确关闭文件。
使用 else 和 finally 子句可以让你做到这一点——如果没有发生异常,则正常处理数据;或者在处理任何异常时仍能适当地关闭文件。如果没有这些子句,你的代码可能会遭遇资源泄漏或不完整的错误处理。因此,它们在创建健壮和可靠的程序中扮演着至关重要的角色。
try:
# Open the file in read mode
file = open("file.txt", "r")
print("Successful opened the file")
except FileNotFoundError:
# Handle missing files
print("File Not Found Error: No such file or directory")
exit()
except PermissionError:
# Handle permission issues
print("Permission Denied Error: Access is denied")
else:
# All good, do something with the file data
content = file.read().decode('utf-8')
processed_data = process_content(content)
# Cleanup after ourselves even if an exception occurred above
finally:
file.close()
在这个例子中,我们首先尝试使用 with 语句打开“file.txt”文件进行读取,这保证了文件对象在执行完成后自动正确关闭。如果在文件 I/O 操作期间发生 FileNotFoundError 或 PermissionError,相应的 except 语句将被执行。为了简单起见,如果找不到文件,我们只是打印错误信息并退出程序。
否则,当 try 块中没有异常发生时,我们在 else 分支中继续处理文件内容。最后,finally 块保证了文件的关闭,无论之前是否抛出了异常 [1]。
通过采用这样的结构化方法,你的代码保持组织良好,易于跟随,同时考虑到可能由于与外部系统或输入交互而出现的潜在错误。
自定义异常定义
在 Python 中,你可以通过从内置异常如 Exception 或任何直接继承自 Exception 的类创建子类来定义自定义异常。
为此,你需要创建一个继承自这些基本异常之一的新类,并添加特定于你需求的属性。然后,你可以在代码中像使用其他内置异常类一样使用你新定义的异常类。
下面是定义一个名为 InvalidEmailAddress 的自定义异常的示例:
class InvalidEmailAddress(ValueError):
def __init__(self, message):
super().__init__(message)
self.msgfmt = message
这个自定义异常是从 ValueError 派生的,它的构造函数接受一个可选的消息参数(默认为"invalid email address")。
当你遇到无效的电子邮件地址格式时,可以抛出这个异常:
def send_email(address):
if isinstance(address, str) == False:
raise InvalidEmailAddress("Invalid email address")
# Send email
现在,如果你将一个无效的字符串传递给 send_email() 函数,你将看到一个自定义的错误消息,而不是普通的 TypeError,它清楚地指示了问题所在。例如,抛出异常的代码可能如下所示:
>>> send_email(None)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/path/to/project/main.py", line 8, in send_email
raise InvalidEmailAddress("Invalid email address")
InvalidEmailAddress: Invalid email address
异常处理的最佳实践
以下是一些与 Python 中错误处理相关的最佳实践:
-
设计以应对失败:提前规划,考虑可能的失败情况并设计你的程序以优雅地处理这些失败。这意味着要预见边界情况并实施适当的错误处理程序。
-
使用描述性的错误信息:提供详细的错误信息或日志,帮助用户理解出了什么问题以及为什么。避免使用诸如“发生错误”或“发生了不好的事情”的通用错误信息。相反,显示一个友好的消息,建议解决方案或提供文档链接。务必在提供详细说明和避免 UI 过于繁杂之间取得平衡。
-
最小化副作用:通过使用 try-finally 或 try-with-resources 块隔离问题代码段,最小化失败操作的后果。确保清理任务在成功或失败的结果下都能始终执行。
-
彻底测试:确保你的异常处理程序在各种情况下表现正确,通过运行全面的测试来实现。
-
定期重构:重构易出错的代码段,以提高其可靠性和性能。保持代码库的模块化和松散耦合,使独立部分可以独立演进而不会对其他部分产生负面影响。
-
记录重要事件:通过将有趣的事件记录到文件或控制台输出中,跟踪应用程序中的发生情况。这有助于你快速诊断问题,而无需在大量未结构化的日志中筛选。
结论
编写错误处理代码是软件开发中不可或缺的一部分,特别是在使用 Python 时,它使开发者能够构建更可靠和健壮的应用程序。通过遵循行业标准和最佳实践,开发者可以减少调试时间,确保代码质量,并提供更好的用户体验。
资源
[1] docs.python.org/3/tutorial/errors.html
[2] www.geeksforgeeks.org/python-exception-handling/
了解大型语言模型
原文:
towardsdatascience.com/catch-up-on-large-language-models-8daf784f46f8
实用的无炒作大型语言模型指南
·发表于 Towards Data Science ·15 分钟阅读·2023 年 9 月 5 日
--

图片由 Gary Bendig 提供,来源于 Unsplash
如果你在这里,这意味着像我一样,你被围绕 大型语言模型 (LLMs) 的不断信息流和炒作文章所压倒。
这篇文章是我试图帮助你了解大型语言模型的努力,没有炒作。毕竟,这是一个变革性的技术,我相信了解它很重要,希望这会激发你更深入地学习并用它创建一些东西。
在接下来的部分中,我们将定义什么是 LLM 及其工作原理,当然会涵盖 Transformer 架构。我们还将探讨不同的 LLM 训练方法,并通过一个动手项目结束文章,在这个项目中,我们将使用 Flan-T5 进行 Python 情感分析。
开始吧!
LLM 和生成式 AI:它们是一样的吗?
生成式 AI 是机器学习的一个子集,专注于那些主要功能是生成 某物 的模型:文本、图像、视频、代码等。
生成模型通过在大量由人类创建的数据上进行训练,以学习模式和结构,从而能够生成新数据。
生成模型的示例包括:
-
图像生成:DALL-E,Midjourney
-
代码生成:OpenAI Codex
-
文本生成:GPT-3,Flan-T5,LLaMA
大型语言模型是生成式 AI 领域的一部分,因为它们会接收输入文本并反复预测下一个单词,直到输出完成。
然而,随着语言模型的规模不断扩大,它们能够执行其他自然语言处理任务,如摘要、情感分析、命名实体识别、翻译等。
既然如此,现在让我们关注 LLM 的工作原理。
LLM 的工作原理
我们现在拥有大规模语言模型的原因之一是 Google 和多伦多大学的开创性工作,他们在 2017 年发布了论文 Attention Is All You Need。
本文介绍了 Transformer 架构,它是我们今天所知和使用的 LLM 背后的基础。
这种架构解锁了大规模模型,使得在多个 GPU 上训练非常大的模型成为可能,这些模型能够并行处理输入,给它们处理非常大数据序列的机会。
Transformer 架构概述
以下内容旨在对 Transformer 架构进行高层次的概述。虽然有许多资源对其进行了更深入的探讨,但这里的目标只是理解其工作原理,以便理解不同的 LLM 如何在不同任务中发挥作用。
如需更多详细信息,我建议阅读 原始论文。
所以,让我们从 Transformer 架构的简化可视化开始。

Transformer 架构的简化可视化。图片由作者提供。
从上图中,我们可以看到 Transformer 的主要组件是编码器和解码器。在每个组件内部,我们还可以找到 attention 组件。
让我们更详细地探讨每个组件,以理解 Transformer 架构的工作原理。
分词输入
我们知道 LLMs 使用文本,但计算机处理的是数字而非字母。因此,输入必须进行 分词。
分词是将句子的单词表示为数字的过程。
基本上,模型可以处理的每个可能单词都在一个字典中,并且与之关联一个编号。通过分词,我们可以检索与单词相关的编号,将句子表示为数字序列,如下所示。

分词示例。句子被分词后送入 Transformer 的嵌入层。图片由作者提供。
在上图中,我们看到一个例子,展示了如何在将句子“It rained this morning”发送到 Transformer 的嵌入层之前进行分词。
注意到分词句子的方式有很多种。在上面的例子中,分词器可以将一个词的部分表示出来,这就是为什么 rained 被分成 rain 和 ed。其他分词器可能只会为完整的单词分配一个编号。
词嵌入层
此时,我们有一系列表示单词的数字,但计算机如何理解这些数字的含义呢?
这是通过词嵌入层实现的。
词嵌入是一种对单词的学习表示,使得具有相似意义的单词具有相似的表示。模型将学习单词的不同属性,并在一个固定空间中表示它们,其中每个轴可以表示单词的属性。

词嵌入的可视化。我们可以看到“morning”和“sunrise”有相似的表示,因为它们在 3D 空间中的角度较小。同样,“rain”和“thunder”彼此更近。图片由作者提供。
在上图中,我们可以看到 3D 词嵌入的样子。我们看到“morning”和“sunrise”彼此更近,因此具有相似的表示。这可以通过余弦相似度计算得出。
另一方面,“rain”和“thunder”彼此较近,而与“morning”和“sunrise”相距较远。
现在,我们只能展示一个 3D 空间,但实际上,嵌入可以有数百个维度。事实上,原始的 Transformer 架构使用了 512 维的嵌入空间。这意味着模型可以学习 512 个不同的词属性,将它们表示在一个 512 维的空间中。
那么词序呢?
你可能已经注意到,通过表示词的嵌入,我们会丧失它们在句子中的顺序。
当然,在自然语言中,词序非常重要,因此我们使用位置编码,以便模型了解句子中单词的顺序。
是将词嵌入和位置编码结合在一起并发送给编码器。
在编码器内部
我们的输入在编码器内部传递,在那里它们会经过自注意力机制。
这就是模型可以学习句子中每个标记之间依赖关系的地方。它学习了每个词相对于句子中所有其他词的重要性。

单词“rained”的注意力图示例。笔划宽度代表重要性。在这里,我们可以看到“rained”与“this”和“morning”紧密连接。图片由作者提供。
在上图中,我们展示了单词“rained”的注意力图的风格化示例。笔划宽度表示重要性。
在这个例子中,我们可以看到自注意力捕捉了“rained”与“this”和“morning”的重要性,这意味着它理解了这个句子的上下文。
尽管这个例子很简单,因为我们只有一个非常短的句子,自注意力机制在较长的句子中效果很好,能够有效地捕捉上下文和句子的整体含义。
此外,模型并没有一个单一的注意力头。事实上,它有多个注意力头,也称为多头自注意力,每个头部可以学习语言的不同方面。
例如,在论文Attention Is All You Need中,作者发现一个头部涉及到指代消解,即识别实体与其重复引用之间的联系。

指代消解的例子。在这里,单词“keys”在句子中再次被引用为“they”。图片由作者提供。
上面,我们看到一个指代解析的示例,其中单词“keys”后来被提及为“they”,因此一个注意力头可以专门识别这些链接。
注意,我们并未决定每个注意力头将学习语言的哪个方面。
此时,模型已经对句子的意义结构有了深层次的表示。这被发送到解码器。
解码器内部
解码器接受输入令牌的深层表示。这为解码器内部的自注意机制提供信息。
作为提醒,这里再次展示了 Transformer 架构,以便我们记住它的样子。

Transformer 架构的简化可视化。图片由作者提供。
序列开始令牌被插入作为解码器的输入,以指示其开始生成新令牌。
新令牌是根据编码器生成的输入序列的理解及其自注意机制生成的。
在上图中,我们可以看到解码器的输出被送到一个 softmax 层。这生成了每个可能令牌的概率向量。具有最大概率的令牌随后由模型输出。
该输出令牌随后被送回嵌入层作为解码器的输入,直到模型生成序列结束令牌。此时,输出序列完成。
这总结了大型语言模型背后的基本架构。通过 Transformer 架构及其并行处理数据的能力,使得在大量数据上训练模型成为可能,使 LLMs 成为现实。
现在,情况更复杂,因为 LLMs 并非都使用完整的 Transformer 架构,这影响了它们的训练方式。让我们更详细地探讨这一点。
LLM 的训练方式
我们已经看到了支撑大型语言模型的基本机制,如前所述,并非所有模型都使用完整的 Transformer 架构。
实际上,一些模型可能只使用编码器部分,而其他模型只使用解码器部分。
这意味着模型的训练方式也不同,因此会专注于特定任务。
仅编码器模型
仅编码器模型,也称为自编码模型,最适合用于情感分析、命名实体识别和词汇分类等任务。
自编码模型的流行示例有 BERT 和 ROBERTA。
这些模型使用掩码语言建模(MLM)进行训练。通过这种训练方法,输入句子中的单词会被随机掩盖,模型的目标是重建原始文本。

说明了用于自编码模型的掩码语言建模(MLM)。在这里,输入句子中的一个随机单词被掩盖,模型必须重建原始句子。图片由作者提供。
在上图中,我们可以看到掩蔽语言建模的样子。一个词被隐藏,句子被输入到模型中,模型必须学习预测正确的词以得到正确的原始句子。
使用该方法,自编码模型发展了双向上下文,因为它们可以看到需要预测的标记前后的内容,而不仅仅是前面的内容。
再次如上图所示,模型看到“it rained”和“morning”,因此它看到句子的开头和结尾,这使得它能够预测“this”这个词,从而正确重构句子。
注意,对于自编码模型,输入和输出序列的长度是相同的。
仅解码器模型
仅解码器模型也称为自回归模型。这些模型最适合文本生成,但当模型变得非常大时,新的功能就会出现。
自回归模型的例子有 GPT 和 BLOOM。
这些模型使用因果语言建模(CLM)进行训练。使用因果语言建模时,模型只看到掩蔽之前的标记,这意味着它看不到序列的结尾。

说明因果语言建模。在这里,模型只看到导致掩蔽的标记。然后,它必须推断下一个标记直到句子完整。图像由作者提供。
如上所示,使用因果语言建模时,模型只看到导致掩蔽的标记,而看不到掩蔽之后的内容。然后,它必须预测下一个标记直到句子完整。
在上面的例子中,模型会输出“this”,然后该标记会被反馈作为输入,因此模型可以预测“morning”。
与掩蔽语言建模不同,模型建立了单向上下文,因为它们看不到掩蔽之后的内容。
当然,对于仅解码器模型,输出序列的长度可能与输入序列的长度不同。
编码器-解码器模型
编码器-解码器模型也称为序列到序列模型,并且它们使用完整的 Transformer 架构。
这些模型通常用于翻译、文本摘要和问答。
流行的序列到序列模型的例子有 T5 和 BART。
为了训练这些模型,使用了跨度破坏方法。在这里,一个随机的标记序列被掩蔽并指定为哨兵标记。然后,模型必须自回归地重构被掩蔽的序列。

跨度破坏的说明。在这里,一系列标记被掩蔽并用哨兵标记替代。然后,模型必须自回归地重构被掩蔽的序列。图像由作者提供。
在上图中,我们可以看到两个标记的序列被掩蔽并用哨兵标记替代。然后,模型被训练以重构哨兵标记以获得原始句子。
在这里,掩码输入被发送到编码器,而解码器负责重建掩码序列。
关于模型大小的说明
尽管我们已指定了某些模型表现最佳的任务,研究人员观察到大型模型能够执行各种任务。
因此,虽然编码-解码模型专门用于翻译,但非常大的仅解码模型在翻译方面也表现出色。
考虑到这些,让我们现在开始在 Python 中使用大型语言模型。
与大型语言模型合作
在我们实际操作大型语言模型之前,让我们先了解一些与 LLM 相关的技术术语。
首先,我们提供给 LLM 的文本称为提示(prompt),模型的输出称为完成(completion)。

提示是我们向模型提供的包含指令的文本。模型的输出称为完成。图片由作者提供。
在提示中,我们向 LLM 提供指令,以实现我们希望完成的任务。
这也是进行提示工程的地方。通过提示工程,我们可以进行上下文学习,即向模型提供如何执行某些任务的示例。稍后我们将看到一个例子。
目前,让我们使用 Python 与 LLM 进行情感分析的互动。
实践项目:使用 Flan-T5 进行情感分析
对于这个迷你项目,我们使用 Flan-T5 对各种金融新闻进行情感分析。
Flan-T5 是 T5 模型的改进版,T5 是一个序列到序列模型。研究人员基本上对 T5 模型进行了微调,使其覆盖更多语言。有关更多详细信息,请参见原始论文。
对于数据集,我们将使用由 Pekka Malo 和 Ankur Sinha 在 Creative Commons 属性许可下发布的financial_phrasebank 数据集。
数据集包含来自英语金融新闻的共 4840 个句子,这些句子被分类为积极、消极或中立。五到八名注释员对每个句子进行分类,根据一致性率,数据集的大小会有所不同(50%一致率为 4850 行,100%一致率为 2260 行)。
有关数据集及其编制方式的更多信息,请参见完整数据集详情页面。
当然,下面显示的所有代码都可以在GitHub上找到。
设置你的环境
为了使以下实验有效,确保有一个虚拟环境,并安装了以下软件包:
-
torch
-
torchdata
-
transformers
-
datasets
-
pandas
-
matplotlib
-
scikit-learn
请注意,库transformers和datasets来自 HuggingFace,使我们可以轻松访问和实验 LLM。
一旦环境设置好,我们可以开始导入所需的库。
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
加载数据
然后,我们可以加载我们的数据集。在这里,我们使用了具有 100% 一致性率的数据集。
dataset_name = "financial_phrasebank"
dataset = load_dataset(dataset_name, "sentences_allagree")
该数据集包含总共 2264 个句子。
请注意标签已被编码。1 代表中性,0 代表负面,2 代表正面。每种标签的计数如下所示。

数据集中每种情感的频率。图片由作者提供。
让我们将每个句子的实际标签存储在一个 DataFrame 中,以便后续更容易评估模型。
labels_df = pd.DataFrame()
labels_from_dataset = [dataset['train'][i]['label'] for i in range(2264)]
labels_df['labels'] = labels_from_dataset
加载模型
现在,让我们加载模型和分词器。如前所述,我们将加载 Flan-T5 模型。请注意,该模型有不同的大小版本,但我决定使用基础版。
model_name = "google/flan-t5-base"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
就这样!我们现在可以使用这个 LLM 对我们的数据集进行情感分析。
向模型提出情感分析的提示
为了让模型进行情感分析,我们需要进行提示工程以指定该任务。
在这种情况下,我们简单地使用“以下句子是正面、负面还是中性?”。然后我们将数据集中的句子传递给模型,让模型进行推断。
请注意,这被称为零-shot 推断,因为模型没有特别针对这个特定任务和数据集进行训练。
zero_shot_sentiment = []
for i in range(2264):
sentence = dataset['train'][i]['sentence']
prompt = f"""
Is the follwing sentence positive, negative or neutral?
{sentence}
"""
inputs = tokenizer(prompt, return_tensors='pt')
output = tokenizer.decode(
model.generate(
inputs["input_ids"],
max_new_tokens=50
)[0],
skip_special_tokens=True
)
zero_shot_sentiment.append(output)
在上面的 Python 代码块中,我们循环遍历数据集中的每个句子,并将其传递到我们的提示中。提示被分词并设置给模型。然后,我们解码输出以获得自然语言响应。最后,我们将模型的预测存储在列表中。
然后,让我们将这些预测添加到我们的 DataFrame 中。
labels_df['zero_shot_sentiment'] = zero_shot_sentiment
labels_df['zero_shot_sentiment'] = labels_df['zero_shot_sentiment'].map({'neutral':1, 'positive':2, 'negative':0})
评估模型
为了评估我们的模型,让我们展示预测的混淆矩阵以及分类报告。
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics import classification_report
cm = confusion_matrix(labels_df['labels'], labels_df['zero_shot_sentiment'], labels=[0,1,2])
disp_cm = ConfusionMatrixDisplay(cm, display_labels=[0,1,2])
disp_cm.plot();
plt.grid(False)
plt.tight_layout()
clf_report = classification_report(labels_df['labels'], labels_df['zero_shot_sentiment'], labels=[0,1,2])
print(clf_report)

使用 Flan-T5 对金融新闻进行零-shot 情感分析的混淆矩阵。图片由作者提供。

零-shot 情感分析的分类报告。图片由作者提供。
从上面的图中,我们可以看到模型找到了所有负面句子,但代价是精确度,因为它错误标记了 611 个中性句子和 92 个正面句子。此外,我们还可以看到识别中性句子存在明显的问题,因为它错误标记了绝大多数句子。
因此,让我们尝试更改提示,以查看是否可以提高模型的性能。
带有上下文学习的一次性推断
在这里,我们修改了我们的提示,加入了一个中性句子的示例。这种技术称为上下文学习,因为我们在提示中传递了模型应如何表现的示例。
传递一个示例称为一次性推断。可以传递更多示例,这种情况称为少量样本推断。
向 LLM 展示最多五个示例是正常的。如果性能没有提高,那么很可能需要对模型进行微调。
现在,让我们看看一个示例如何影响性能。
one_shot_sentiment = []
for i in range(2264):
sentence = dataset['train'][i]['sentence']
prompt = f"""
Is the follwing sentence positive, negative or neutral?
Statement: "According to Gran , the company has no plans to move all production to Russia , although that is where the company is growing ."
neutral
Is the follwing sentence positive, negative or neutral?
Statement: {sentence}
{sentence}
"""
inputs = tokenizer(prompt, return_tensors='pt')
output = tokenizer.decode(
model.generate(
inputs["input_ids"],
max_new_tokens=50
)[0],
skip_special_tokens=True
)
one_shot_sentiment.append(output)
在上面的代码块中,我们看到我们给出了一个中性句子的示例,以帮助模型识别它们。然后,我们将每个句子传递给模型进行分类。
然后,我们按照相同的步骤添加包含新预测的新列,并显示混淆矩阵。
labels_df['one_shot_sentiment'] = one_shot_sentiment
labels_df['one_shot_sentiment'] = labels_df['one_shot_sentiment'].map({'neutral':1, 'positive':2, 'negative':0})
cm = confusion_matrix(labels_df['labels'], labels_df['one_shot_sentiment'], labels=[0,1,2])
disp_cm = ConfusionMatrixDisplay(cm, display_labels=[0,1,2])
disp_cm.plot();
plt.grid(False)
plt.tight_layout()

使用 Flan-T5 进行金融新闻的单次情感分析的混淆矩阵。图片由作者提供。

单次情感分析的分类报告。图片由作者提供。
从上图可以看出,略有改善。加权 F1 分数从 0.40 提高到了 0.44。模型在中性类别上的表现更好,但以牺牲对正面类别的表现为代价。
添加正面、负面和中性句子的示例可能会有帮助,但我没有进行测试。否则,就需要对模型进行微调,但那是另一篇文章的主题。
结论
本文涵盖了许多概念,从理解 LLM 的基础知识,到实际使用 Flan-T5 进行 Python 中的情感分析。
现在你拥有了探索这个领域的基础知识,可以自己看看如何微调 LLM,如何训练 LLM,以及如何围绕它们构建应用程序。
希望你学到了新东西,并且对学习更多充满好奇。
干杯 🍻
支持我
喜欢我的工作吗?通过请我喝咖啡来支持我,这是你鼓励我的简单方式,而我能享受一杯咖啡!如果你愿意,只需点击下面的按钮 👇
参考资料
注意力机制 — Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin
生成式 AI 与 LLM — deeplearning.ai
类别特征:标签编码的问题所在
原文:
towardsdatascience.com/categorical-features-whats-wrong-with-label-encoding-81184c3dfb69
为什么我们不能随意编码类别特征
·发表于Towards Data Science ·10 分钟阅读·2023 年 11 月 20 日
--

云朵。作者提供的图像。
众所周知,许多机器学习模型无法原生处理类别特征。虽然也有一些例外,但通常需要实践者决定每个类别特征的数值表示。有多种方法可以实现这一点,但一种很少推荐的策略是标签编码。
标签编码用一个任意的数字替换每个类别值。例如,如果我们有一个包含字母的特征,标签编码可能会将字母“A”分配为 0,将字母“B”分配为 1,然后继续这个模式直到“Z”,即 25。经过这个过程,从技术上讲,任何算法都应该能够处理这个编码后的特征。
那么这有什么问题呢?难道复杂的机器学习模型不能处理这种编码方式吗?为什么像Catboost和其他编码策略这样的库存在,用来处理高基数类别特征?
这篇文章将探讨两个示例,演示为什么标签编码对机器学习模型可能存在问题。这些示例将帮助我们理解为什么有如此多的替代方案存在,以及加深我们对数据复杂性和模型性能之间关系的理解。
直观示例
获得机器学习概念直观理解的最佳方法之一是了解其在低维空间中的工作原理,并尝试将结果外推到更高维度。这种思维外推并不总是与现实一致,但对于我们的目的来说,我们只需一个特征即可了解为什么需要更好的分类编码策略。
一个具有 25 个类别的特征
让我们从一个具有单个特征和连续目标的基本玩具数据集开始。以下是我们需要的依赖项:
import numpy as np
import polars as pl
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error
from category_encoders.target_encoder import TargetEncoder
from category_encoders.ordinal import OrdinalEncoder
让我们读取数据集并探索一些属性:
>>> data = pl.read_csv("basic_categorical_dataset_1.csv")
>>> data.shape
(852, 2)
>>> data.sample(5)
shape: (5, 2)
┌─────────────┬───────────┐
│ cat_feature ┆ target │
│ --- ┆ --- │
│ str ┆ f64 │
╞═════════════╪═══════════╡
│ v ┆ 16.287324 │
│ z ┆ 16.285893 │
│ j ┆ 5.718953 │
│ p ┆ 14.290943 │
│ y ┆ 2.972485 │
└─────────────┴───────────┘
>>> data.select("cat_feature").n_unique()
25
该数据集包含一个分类特征cat_feature,具有 25 个唯一类别和一个连续的target。目标是学习一个函数,将每个类别映射到其对应的最佳拟合目标值。我们不需要机器学习来做到这一点,但它将帮助我们理解为什么在处理更复杂的实际问题时需要良好的分类编码策略。
接下来,我们将创建训练集和测试集:
>>> x = data.select("cat_feature").to_numpy()
>>> y = data.select("target").to_numpy()
>>> x_train, x_test, y_train, y_test = train_test_split(
... x, y, test_size=0.20, random_state=3
... )
>>> x_train.shape
(681, 1)
我们已将 80% 的数据用于训练,20% 用于测试,现在准备为cat_feature选择一种分类编码策略。与常见的机器学习智慧相悖,我们决定对cat_feature进行标签编码。标签编码将每个分类值替换为介于 0 和类别数减一之间的任意数字。在这个例子中,这些类别将被替换为介于 0 和 24 之间的数字:
>>> label_encoder = LabelEncoder()
>>> x_train_label_encoded = label_encoder.fit_transform(x_train.ravel())
>>> x_test_label_encoded = label_encoder.transform(x_test.ravel())
使用 Scikit-learn 的LabelEncoder类,我们使用训练集来决定编码并转换训练集和测试集。我们现在拥有一个单一的数值特征和一个目标,并且可以通过散点图可视化它们之间的关系:
>>> fig, ax = plt.subplots(figsize=(10, 6))
>>> ax.scatter(x_train_label_encoded, y_train)
>>> ax.set_xlabel("Label-Encoded Categorical Feature")
>>> ax.set_ylabel("Target")
>>> plt.show()
散点图如下所示:

标签编码训练数据的散点图。图片由作者提供。
正如预期的那样,每个类别都被分配了一个介于 0 和 24 之间的唯一整数,并且在每个类别内似乎有目标值的分布。我们还可以看到,编码后的cat_feature与目标之间的关系高度非线性,这排除了线性回归作为可行模型。
尽管这些训练数据看起来很复杂,但机器学习模型足够强大,可以拟合这种关系。让我们看看当我们将决策树拟合到训练数据时会发生什么:
>>> model = DecisionTreeRegressor(max_depth=4)
>>> model.fit(x_train_label_encoded.reshape(-1, 1), y_train)
>>> preds = model.predict(x_test_label_encoded.reshape(-1, 1))
我们用max_depth为 4 的 Sklearn 的DecisionTreeRegressor对训练数据进行拟合,并对测试数据进行预测。由于我们只有一个特征,我们可以暂时忽略任何回归指标,简单地绘制叠加在测试集上的预测结果:
>>> fig, ax = plt.subplots(figsize=(10, 6))
>>> ax.scatter(x_test_label_encoded, y_test, label="Actual Test Set" )
>>> ax.scatter(x_test_label_encoded, preds, label="Predictions")
>>> ax.set_xlabel("Label-Encoded Categorical Feature")
>>> ax.set_ylabel("Target")
>>> ax.set_title("Actual vs Predicted Test Set: Max Depth = 4")
>>> ax.legend()
>>> plt.show()
结果如下所示:

使用标签编码的分类特征和深度为 4 的决策树的实际测试集与模型预测。图片由作者提供。
哎呀!我们不需要评估回归指标就知道这个模型不是很好。虽然模型在一些区域表现尚可,但这些结果并不是我们希望展示给老板的。
希望还没有完全丧失。由于我们的决策树深度为 4,我们怀疑增加深度,从而增加模型的复杂性,将会得到更好的结果。让我们将深度增加到 14 看看会得到什么:

实际测试集与使用标签编码分类特征和深度为 14 的决策树的模型预测。图片来源:作者。
这看起来好多了!深度为 14 的决策树似乎很好地拟合了关系。
我们从中可以得出什么结论?也许我们最初的想法是,这只是一个复杂的关系,需要一个更大、更复杂的模型来拟合。这可以解释为什么深度为 4 的决策树效果不好,但深度为 14 的树效果很好。也许标签编码是一个有效的选择?
这可能是真的,但为了成为优秀的数据科学家,我们需要评估另一种分类编码方案。虽然它可能容易导致过拟合,但一种流行的策略是目标均值编码。在最简单的形式中,目标均值编码将每个分类值替换为该类别所有观察值的均值目标。我们可以使用category_encoders来实现这一点:
>>> target_mean_encoder = TargetEncoder(smoothing=0, min_samples_leaf=1)
>>> x_train_target_encoded = target_mean_encoder.fit_transform(x_train, y_train).values
>>> x_test_target_encoded = target_mean_encoder.transform(x_test).values
smoothing 和 min_samples_leaf 参数用于对每个类别的目标均值计算进行正则化。在这个例子中,我们不希望进行任何正则化,因此我们不应用平滑,并且只要求类别中存在一个样本。如之前所述,我们可以可视化结果:
>>> fig, ax = plt.subplots(figsize=(10, 6))
>>> ax.scatter(x_train_target_encoded, y_train)
>>> ax.set_xlabel("Target-Encoded Categorical Feature")
>>> ax.set_ylabel("Target")
>>> plt.show()
下面是目标均值编码特征与目标的散点图:

目标均值编码训练数据的散点图。图片来源:作者。
这里发生了什么?目标均值编码揭示了分类特征与目标之间的线性关系。这与我们在应用标签编码时观察到的情况大相径庭。我们更愿意使用这种编码,因为它允许我们使用更简单的甚至是线性的模型。
线性回归对这个数据集效果很好,但我们将再次使用决策树来查看这种编码为什么更强大:
>>> model = DecisionTreeRegressor(max_depth=4)
>>> model.fit(x_train_target_encoded.reshape(-1, 1), y_train)
>>> preds = model.predict(x_test_target_encoded.reshape(-1, 1))
>>> fig, ax = plt.subplots(figsize=(10, 6))
>>> ax.scatter(x_test_target_encoded, y_test, label="Actual Test Set" )
>>> ax.scatter(x_test_target_encoded, preds, label="Predictions")
>>> ax.set_xlabel("Target-Encoded Categorical Feature")
>>> ax.set_ylabel("Target")
>>> ax.set_title("Actual vs Predicted Test Set: Max Depth = 4")
>>> ax.legend()
>>> plt.show()
这是将预测结果覆盖在测试集上的效果:

实际测试集与使用目标均值编码分类特征和深度为 4 的决策树的模型预测。图片来源:作者。
使用目标均值编码的分类特征,我们可以用深度为 4 的树很好地拟合关系。与此相比,标签编码特征需要更高的树深度。一种更好的分类编码方案揭示了一个模型更容易学习的关系。
我们可以通过查看每种编码策略所需的模型复杂度(即树深度)来进一步理解这个想法,以确定它们在测试集上的收敛情况:

每种分类编码策略的测试集 MAE 与树深度的关系。图片来源于作者。
我们在这里看到的正是为什么一个好的分类编码策略至关重要的核心。这个图表告诉我们,目标均值编码模型在较低树深度时的测试集误差显著低于标签编码模型。例如,在深度为 2 时,目标均值模型的测试误差不到标签编码模型的一半。
目标均值模型的收敛速度也比标签模型更快。目标均值模型在深度为 5 时达到了最小测试误差,而标签模型则需要到深度为 9 时才能达到最小值。
在这种情况下,正确编码分类特征揭示了一个关系,使我们能够使用更简单的模型。然而,即使标签编码需要更复杂的模型来拟合关系,我们仍然使用它找到了一个与目标编码模型一样好的模型。
也就是说,即使cat_feature被标签编码,我们仍然找到了一个效果与其他模型一样好的模型。此外,我们知道还有许多更复杂的编码策略可以选择。选择一个好的编码策略真的值得吗?
一个具有数百个类别的特征
为了真正说服自己标签编码的不好,我们来看一个数据集:
>>> data = pl.read_csv("basic_categorical_dataset_2.csv")
>>> data.shape
(3777, 2)
>>> data.sample(5)
shape: (5, 2)
┌───────────────────────────────────┬────────────┐
│ cat_feature ┆ target │
│ --- ┆ --- │
│ str ┆ f64 │
╞═══════════════════════════════════╪════════════╡
│ 79a5808e-4e18-401a-94e7-7a478260… ┆ 369.232239 │
│ c29ae2f2-c4c0-4a48-b15f-ba72a81e… ┆ 492.785283 │
│ df809121-1f19-4f89-bb42-e7c593d9… ┆ 602.706521 │
│ 5840ee5f-69e9-4f5b-ac20-17e01033… ┆ 731.986467 │
│ b5c74247-6aff-4762-844e-3ebb3cbf… ┆ 404.087366 │
└───────────────────────────────────┴────────────┘
>>> data.select("cat_feature").n_unique()
917
这个数据集包含一个名为cat_feature的单一分类特征,具有 917 个唯一类别和一个连续的target。具有如此多类别的特征对我们如何编码cat_feature有什么影响?标签编码仍然能收敛吗?
和之前一样,我们将创建使用标签和目标均值编码的训练集和测试集:
>>> x = data.select("cat_feature").to_numpy()
>>> y = data.select("target").to_numpy()
>>> x_train, x_test, y_train, y_test = train_test_split(
... x, y, test_size=0.20, random_state=3
... )
>>> x_train.shape
(3021, 1)
>>> label_encoder = OrdinalEncoder(handle_unknown="value")
>>> x_train_label_encoded = label_encoder.fit_transform(x_train, y_train).values
>>> x_test_label_encoded = label_encoder.transform(x_test).values
>>> target_mean_encoder = TargetEncoder(
... smoothing=0, min_samples_leaf=1, handle_unknown="value"
... )
>>> x_train_target_encoded = target_mean_encoder.fit_transform(x_train, y_train).values
>>> x_test_target_encoded = target_mean_encoder.transform(x_test).values
在这个例子中,我们使用OrdinalEncoder而不是LabelEncoder,因为它可以处理在测试数据中可能出现的先前未见过的类别。我们可以再次可视化标签编码的分类特征与目标之间的关系:

这是一个标签编码训练数据的散点图。图片来源于作者。
这看起来不太好。标签编码在cat_feature和目标之间产生了大量噪声,很难想象一个机器学习模型在没有更多特征的情况下能够拟合这种关系。让我们将其与目标均值编码数据进行比较:

目标均值编码训练数据的散点图。图片来源于作者。
然而,当我们使用目标均值编码时,一个更简单,甚至是线性的关系再次出现。我们已经知道这允许我们使用更简单的模型,但我们说的“更简单”有多简单呢?

每种分类编码策略的测试集 MAE 与树深度的关系。图片来源于作者。
对于这个数据集,目标均值编码模型在深度为 6 时收敛,而标签编码模型直到深度为 29 才收敛。从不同的角度来看,最佳标签编码模型最终有 866 个叶子节点,而目标编码模型只有 128 个。也就是说,标签编码模型必须将单一特征划分 866 次才能收敛——这几乎是每个类别一个划分。
尽管这两个例子都使用了单一特征来预测目标,但我们可以想象一个更现实的场景,其中可能使用数百个特征。如果类别特征被标签编码,常用的基于决策树的算法可能会忽略这些特征,因为它们与目标的关系可能比其他特征复杂得多。
再次强调,我们不应得出目标均值编码是最佳策略的结论,因为它往往会过拟合训练数据。然而,我们确实看到了为什么需要一个好的编码策略,以及为什么标签编码会阻碍模型性能。
主要收获
这篇文章探讨了标签编码,这是一种将每个类别替换为任意数字的类别编码方法。我们发现标签编码可能会在类别特征与目标之间创建不必要的复杂性,需要更大的模型来进行拟合。找到合适的编码方法就是揭示有意义的关系,从而允许我们使用更简单的模型。
成为会员: https://harrisonfhoffman.medium.com/membership
使用 BERT 对自由文本银行交易描述进行分类
原文:
towardsdatascience.com/categorize-free-text-bank-transaction-descriptions-using-bert-44c9cc87735b
我为自己建立了一个开支跟踪工具
·发布于 Towards Data Science ·7 分钟阅读·2023 年 1 月 30 日
--

按类别统计开支。图表由作者提供
情况说明
我在 2022 年年底购买了一处房产,并办理了抵押贷款。由于财务承诺的增加,我想对自己的开支进行监控。直到这一点之前,我从未意识到自己实际上不知道自己最常花钱的地方。弄清楚这一点可能是我自己开支管理的一个良好起点。
自然,我转向了从在线银行门户下载的银行交易数据,格式为 .csv。下面是 2022 年最后几天的数据片段。

图片 1:作者的银行交易数据。图片由作者提供
根据上面的数据片段,似乎我在食品上的支出比例较高(如绿色高亮所示)。更重要的是,交易描述是基于自由文本的,有没有办法自动将这些描述分类到一些预定义的开支类别中(例如食品、杂货购物、水电费等)?
使用像 BERT 这样的预训练大型语言模型至少有一种方法,本文提供了如何进行操作的教程!
2023 年 BERT 介绍
虽然 ChatGPT 作为一种先进的文本生成模型目前受到广泛关注,但它通常不被认为是通用模型——例如 BERT 可以用于多种自然语言理解任务。一些示例包括语法检测、情感分类、文本相似性、问答推理等。
BERT 由谷歌于 2018 年开发和发布。它是一个使用维基百科和 BookCorpus 中的文本段落进行预训练的模型(以确保训练数据在语法上是准确的)。
本教程中使用的 BERT 模型可以通过 Hugging Face 的 sentence_transformer 库获得,该库是一个用于创建句子、文本和图像嵌入的 Python 框架。
构建费用分类器的步骤
我究竟如何将自由文本交易描述转换为费用类别?我能想到几种策略。在本教程中,我将提供一个基于(余弦)词嵌入相似度的费用分类器的逐步指南。步骤如下:
-
手动将大量交易描述标记为一个费用类别(例如,食品、娱乐)。这会创建一组标记的训练数据。
-
使用 BERT 将上述训练数据中的单个交易描述解析为词嵌入(即将文本转换为数值向量)。步骤 1和步骤 2共同确保训练数据被分配到特定的费用类别以及词嵌入向量。
-
对新的交易描述重复步骤 2(即将未见过的文本转换为数值向量)
-
将步骤 3中的词嵌入与训练数据中最相似的词嵌入配对,并分配相同的费用类别
Python 实现
本节提供了加载所需包以及实施上述步骤的 Python 代码(不包括步骤 1,这是一项手动标记步骤)。
步骤 0:导入所需的库
#for dataframe manipulation
import numpy as np
import pandas as pd
#regular expressoin toolkit
import re
#NLP toolkits
import nltk
nltk.download('punkt')
from nltk.tokenize import word_tokenize
#for plotting expense categories later
import matplotlib.pyplot as plt
plt.style.use('ggplot')
import seaborn as sns
import matplotlib
import matplotlib.ticker as ticker # for formatting major units on x-y axis
#for downloading BERT
!pip install sentence_transformers
from sentence_transformers import SentenceTransformer
#for finding most similar text vectors
from sklearn.metrics.pairwise import cosine_similarity
步骤 1:标记训练数据
我手动将 200 个交易描述标记为一个费用类别。例如,图像 1中的交易描述被分配了如下图所示的费用类别。我还将公共事业(即电费和煤气费)、汽车和礼品等类别分配给了训练数据中的其他交易。

图像 3:训练数据的手动标记。图片作者提供
步骤 2:使用 BERT 创建训练数据的词嵌入
我们首先定义一个用于清理文本数据的函数。这包括将单词小写、去除特殊字符(包括日期,这些在确定费用类别时并不有用)。
在使用 BERT 模型时,通常不需要进行词干提取、词形还原或去除停用词等 NLP 数据清理管道中的常见做法,因为 BERT 模型使用了字节对编码和注意力机制。
###############################################
### Define a function for NLP data cleaning ###
###############################################
def clean_text_BERT(text):
# Convert words to lower case.
text = text.lower()
# Remove special characters and numbers. This also removes the dates
# which are not important in classifying expenses
text = re.sub(r'[^\w\s]|https?://\S+|www\.\S+|https?:/\S+|[^\x00-\x7F]+|\d+', '', str(text).strip())
# Tokenise
text_list = word_tokenize(text)
result = ' '.join(text_list)
return result
然后我们将函数应用于交易描述,这些描述从图像 1中作为text_raw加载(df_transaction_description)。
text_raw = df_transaction_description['Description']
text_BERT = text_raw.apply(lambda x: clean_text_BERT(x))
以下片段显示了数据清理应用前后的特定交易示例。

图像 2:数据清理示例。图片作者提供
然后我们将清理后的文本输入到 BERT 中。我选择了'paraphrase-mpnet-base-v2' BERT 模型,该模型以建模句子相似性而闻名。根据其在Hugging Face上的文档,它将句子和段落映射到一个 768 维的密集向量空间,并可用于诸如聚类或语义搜索等任务。
######################################
### Download pre-trained BERT model###
######################################
# This may take some time to download and run
# depending on the size of the input
bert_input = text_BERT.tolist()
model = SentenceTransformer('paraphrase-mpnet-base-v2')
embeddings = model.encode(bert_input, show_progress_bar = True)
embedding_BERT = np.array(embeddings)
下方提供了前几笔交易的词嵌入片段:

图像 4:BERT 嵌入。图像由作者制作
步骤 3:为未见数据创建词嵌入
我从未见过的数据中选择了 20 笔交易(为了本教程的目的,随机选择了交易)。这些交易显示在下图中。

图像 5:未见交易。图像由作者制作
上述交易描述作为text_test_raw加载。类似于步骤 2,这些数据通过 BERT 进行嵌入处理。
# Load texts
text_test_raw = df_transaction_description_test['Test']
# Apply data cleaning function as for training data
text_test_BERT = text_test_raw.apply(lambda x: clean_text_BERT(x))
# Apply BERT embedding
bert_input_test = text_test_BERT.tolist()
#model = SentenceTransformer('paraphrase-mpnet-base-v2')
embeddings_test = model.encode(bert_input_test, show_progress_bar = True)
embedding_BERT_test = np.array(embeddings_test)
df_embedding_bert_test = pd.DataFrame(embeddings_test)
步骤 4:将未见数据与最相似的训练数据配对
# Find the most similar word embedding with unseen data in the training data
similarity_new_data = cosine_similarity(embedding_BERT_test, embedding_BERT)
similarity_df = pd.DataFrame(similarity_new_data)
# Returns index for most similar embedding
# See first column of the output dataframe below
index_similarity = similarity_df.idxmax(axis = 1)
# Return dataframe for most similar embedding/transactions in training dataframe
data_inspect = df_transaction_description.iloc[index_similarity, :].reset_index(drop = True)
unseen_verbatim = text_test_raw
matched_verbatim = data_inspect['Description']
annotation = data_inspect['Class']
d_output = {
'unseen_transaction': unseen_verbatim,
'matched_transaction': matched_verbatim,
'matched_class': annotation
}
d_output 数据框显示,未见数据已被分配到一个相当合理的费用类别。

图像 6:未见数据与训练数据的匹配。图像由作者制作
现在,每当有新的费用产生时,只需将其输入模型即可!
附加步骤:按类别绘制费用图
我实际上将上述步骤应用于了 2022 日历年度的所有费用。下图展示了按分配类别计算的费用金额。

图表 7:按类别划分的费用图。图表由作者制作
主要观察结果如下:
-
在 2022 年,我在食品上的开支最多,其次是抵押贷款还款和公用事业账单。
-
尽管信用卡还款金额最高,但假设信用卡支出可以按相同比例分配到其他费用类别。这一假设也适用于 PayPal 类别。
-
根据数据,我可能希望减少食品开支,将更多支出转向杂货(即,开始在家做饭而不是外出就餐)以便于 2023 年。
-
我在美容产品上的支出可能是因为我和妻子一起购物的情况……
此外,返回特定类别中最高支出的交易非常简单。例如,我在 2022 年食品费用类别中的最高支出如屏幕截图所示。我对结果感到满意,因为这些餐厅中有些在训练数据中并不存在。尽管如此,BERT 仍然能够将这些交易分配到食品类别。

图像 7:主要费用。图像由作者制作
结论
本文提供了一个构建开支跟踪工具的全面教程。我做的就是将自由文本的交易描述翻译成机器理解的语言,使用 BERT,并让机器完成繁重的工作!
另一种方法是通过将相同的词向量嵌入传递到分类模型中,来替代本教程的第 4 步——这是给读者进一步实验的内容。
如果你喜欢我的这篇文章,可以随意阅读其他的文章。
随着我踏上 AI/ML 的浪潮,我喜欢用全面的语言撰写和分享一步一步的指南和如何做的教程,并附带可运行的代码。如果你想访问我所有的文章(以及 Medium 上其他从业者/作者的文章),你可以通过 这个链接 注册!
因果图:面对观察数据中的致命弱点

照片由 Андрей Сизов 提供于 Unsplash
“为何之书”第 3&4 章,阅读与我系列
·发表于 Towards Data Science ·阅读时长 13 分钟·2023 年 11 月 23 日
--
在我之前的两篇文章中,我 开启了 “阅读与我”系列,并完成了 前两章 的阅读,书名是由朱迪亚·珀尔(Judea Pearl)所著的 “为何之书”。这些文章讨论了引入因果关系以实现类人决策的必要性,并强调了设定未来讨论基础的因果阶梯。在这篇文章中,我们将探讨从因果关系的第一阶梯到第二阶梯的钥匙孔,带我们从概率思维进入因果思维。我们将从贝叶斯规则到贝叶斯网络,最终到因果图进行深入探讨。
从贝叶斯规则到逆概率
作为侦探小说的粉丝,我最喜欢的系列是《福尔摩斯探案集》。我仍然记得那些我在不知不觉中度过的日日夜夜。多年后,很多案件细节已经从我的记忆中消失,但我仍然记得那些著名的名言,如同其他人一样:
当你排除掉不可能的情况时,任何剩下的,无论多么不可能,必须是真相。
将这句话翻译到统计学领域,有两种概率——前向概率和逆向概率。根据福尔摩斯的演绎推理,侦探工作就是找到逆向概率最高的凶手。

图片来源于Markus Winkler在Unsplash
从正向概率到逆向概率的转变不仅仅是按顺序翻转变量,还强制建立了因果关系。正如前一篇文章简要讨论的那样,贝叶斯规则提供了一座桥梁,将客观数据(证据)与主观意见(先验信念)连接起来。基于贝叶斯规则,我们可以从任意两个变量中计算条件概率。对于任意变量 A 和 B,给定 B 已发生,A 发生的概率是:
P(A|B) = P(A&B)/P(B)
认为事件 A 发生的信念会根据事件 B 发生的概率来更新。事件 B 发生的可能性越小,P(B)越小,我对事件 A 发生的信念就越强。由于 P(B)小于或等于 1,P(A|B)总是大于或等于 P(A&B)。这就是说,一个人在发现 B 之后对 A 的信念永远不会低于在发现 B 之前对 A 和 B 的信念。请注意,这里的条件概率适用于所有变量关系,甚至是非因果关系。然而,逆概率仅适用于因果关系。
假设这两个事件是原因和证据。正向概率表示在已知原因的情况下,发生证据的概率。另一方面,逆向概率则从结果出发,显示在已知证据的情况下,发生原因的概率。如果我们能够识别原因和证据之间的因果关系,那么我们可以根据观察到的情况推断原因的概率,这在解决实际问题时更具适用性。

图片来源于Markus Winkler在Unsplash
在书中,Pearl 提供了一个应用实例,估计在乳腺 X 光检查结果为阳性的情况下,患乳腺癌的概率是多少,即P(疾病|测试)?首先,有一个明确的因果关系,其中乳腺癌是原因,乳腺 X 光检查结果是证据。当我们看到阳性测试结果时,并不意味着该患者一定患有癌症,因为没有测试是 100%准确的。然而,我们可以根据测试质量来推断该患者患乳腺癌的概率,这被定义为测试的敏感性P(测试|疾病)。测试的敏感性实际上是正向概率,适用于一般人群。
此外,个体特定的信息也可以改善我们对每个患者的逆概率估计。例如,如果这位患者来自一个有多个家庭成员被诊断为乳腺癌的家庭,那么阳性测试结果将比没有家族癌症史的患者更可信。这些患者特定的信息作为先验被添加到最终公式中,该公式指示如何根据证据(观察到的阳性测试结果)更新先验(患乳腺癌的概率):
Updated odds of D = Likelihood Ratio * Prior odds of D
在数学术语中,它是:

通过使用条件概率和似然比,我们可以在两个方向上更新信念。如果我们从原因处获得新的信息,我们可以通过条件概率来更新对证据的信念:
P(T| D) = P(T & D)/P(D)
P(T| D)的变化是由于 P(D)的变化。如果我们从证据中获得新的信息,条件概率是不正确的,因为测试结果为阳性并不意味着你有乳腺癌。因果关系被逆转。然而,我们可以使用似然比来更新我们的信念。
到目前为止,我们只讨论了两个因果相关的变量,但这个规则可以应用于整个因果网络,其中父节点表示原因,子节点表示证据。子节点通过应用条件概率来更新其信念,而父节点通过乘以似然比来更新其信念。在整个网络中应用这两个规则称为信念传播。通过这些规则,我们超越了贝叶斯规则,理解了原因如何影响证据的生成,以及观察证据如何帮助我们推断原因。
混淆因素,非实验研究中的阿喀琉斯之踵
信念传播帮助我们理解变量之间的相互作用,前提是我们能够正确识别因果关系。在现实世界中,超越两个变量,我们需要将因果关系扩展到因果图,以系统地推导因果影响。但在我们进入因果图(本书的核心内容)之前,让我们简要讨论一下是什么阻碍了我们从观察数据中推导因果关系,这些被称为混淆因素。
“Confounding”在英语中意为“混淆”。它是与 X 和 Y 都有相关性的变量。请注意,这种相关性可能是因果的,也可能是非因果的。此外,在下图中,我没有指定 X&Z 和 Y&Z 之间的箭头,因为在因果情况下,X、Y 和 Z 都可以是原因或结果,从而建立不同的因果图,这将在下一节中讨论。左侧面板显示了引入混淆变量 Z 如何在 X 和 Y 之间产生虚假的相关性。

作者提供的图像
在右侧面板中,如果 X 和 Y 之间存在因果关系,那么一个同时影响原因 X 和结果 Y 的混杂因素 Z 如果没有得到妥善处理,会引入混杂偏倚。如果我们不排除混杂因素引起的影响,我们将无法揭示 X 对 Y 的真实因果效应。
在实验研究中,将受试者随机分配到处理组和对照组中的随机性可以解决混杂因素的来源偏倚(最后一节将详细讨论)。然而,进行实验来研究因果效应并不总是实际和道德的,在这种情况下,我们将不得不尝试从观察性数据中推导真实的因果影响。与实验数据不同,观察性数据中存在混杂因素,因为总有影响原因和结果的因素存在。
例如,为了研究吸烟是否导致肺癌,其中一个混杂因素是年龄。不同年龄组的吸烟率差异很大,而且年龄越大,患肺癌的概率越高。我们必须控制年龄和其他混杂因素,才能获得真实的因果影响。统计学家和社会科学家用来对抗混杂因素偏倚的常用方法是“控制”模型中的尽可能多的混杂因素。这种方法存在几个问题:
-
并非所有的混杂因素都是可测量的: 从直观上看,我们可以推测出可能影响我们感兴趣的因果关系的混杂因素。然而,通常无法量化这些变量或找到合适的代理变量将其纳入模型。例如,在研究高等教育是否导致更高收入时,可能的混杂因素之一是“雄心”。有雄心的人更可能被激励去获得更高的教育和更高薪的工作,但在观察性研究中,我们如何量化这个主观变量呢?
-
遗漏变量: 无论我们试图在研究中包含多少变量,仍然很可能没有将所有必要的和正确的混杂因素或其代理变量纳入模型,从而使因果影响出现偏倚。
-
控制混杂因素引发偏倚: 另一方面,在实践中,统计学家为了确保没有混杂因素被遗漏,会在模型中包括尽可能多的变量以确保去偏估计。然而,这种过度控制实际上可能会引发偏倚。正如政治博客作者Ezra Klein所写:
“你在研究中总是会看到这样的情况。‘我们控制了……’ 然后列表开始。越长越好,收入、年龄、种族、宗教、身高、发色、性取向、Crossfit 参与情况、父母的爱、可乐还是百事。你可以控制的因素越多,你的研究就越强——或者,至少,看起来你的研究更强。控制变量带来的是特异性和精确性的感觉……但有时,你可能控制得太多。有时你最终控制了你试图测量的东西。”
最终,在解决混杂偏差时,我们面临很多问题,因为这是一个 Rung 2 问题,需要我们研究变量之间的因果关系。因此,一个不涉及因果结构的 Rung 1 解决方案,如绘制因果图,将是不够的。在下一部分,我们将看到如何利用因果图以系统和可靠的方式定义和控制混杂因素。
建立因果图,因果关系的钥匙
三种基本结构
要理解什么是因果图,我们可以从所有网络的基本构建块开始。网络中有三种基本的连接点,能够表征任何箭头模式:

作者提供的表格
贝叶斯网络和因果图中都存在三种基本类型。应用贝叶斯规则在变量之间构建贝叶斯网络,这不过是一个庞大概率表的简洁表示。如果我们在贝叶斯网络中看到链式结构 A -> B -> C,那么 A 和 C 之间缺失的箭头意味着一旦知道了 B 的值,A 和 C 就是独立的。如果在因果图中观察到相同的链式结构,除了在控制 B 时 A 和 C 之间的相同独立性,我们还会看到通过箭头的因果流。这一结构表明 C 是由 B 引起的,B 是由 A 引起的,而 A 是外部的。如果我们将结构改变为 C -> B -> A,或变为分叉结构 A <- B -> C,我们将看到在控制 B 的情况下 A 和 C 之间的独立性完全相同,但因果结构却发生了巨大的变化。换句话说,数据无法告诉我们一切。无论数据多么庞大,没有添加主观的因果假设,我们不能区分A -> B -> C、C -> B -> A和A <- B -> C。
此外,从贝叶斯网络转向因果图时,我们也在将 Rung 1 的概率思维转变为 Rung 2 和 Rung 3 的因果思维。我们可以不再使用“知道 B 的值之后”这种概率表达,而改为“保持 B 不变”,这相当于从“看到 B”转变为“干预 B”。在后续部分,我们将看到这种差异源于P(Y|X)和P(Y| do(X))。贝叶斯网络只能告诉我们在观察到另一事件的情况下某一事件发生的可能性。然而,因果图能够回答干预和反事实问题。
反门准则
因果图不仅使我们转变为因果思维,还为我们提供了一个可靠的工具来发现和验证观察数据中的因果效应。正如前一章所提到的,识别正确的混杂因素是主要挑战。为了解决这个问题,Pearl 引入了 do-operator 和 back-door criterion。

Dima Pechurin 在 Unsplash 上的照片
关键是弄清楚因果图,do-operator 消除了所有进入 X 的箭头,从而防止任何有关 X 的信息流向非因果方向。而 P(Y| X) 显示了带有混杂偏差的因果效应,概率 P(Y| do(X)) 显示了真正的因果影响。这意味着通过阻塞其他混杂因素的信息流,如果我改变 X,Y 会如何变化?根据不同的因果结构,我们需要控制或不控制不同的变量来阻塞信息流。

作者提供的表格
为了获得 P(Y| do(X)),我们需要确保信息流从 X 到 Y 仅直接来自 X 到 Y。为了实现这一目标,我们需要阻塞 X 和 Y 之间的所有非因果路径,而不干扰任何因果路径。这些非因果路径被称为 X 的反门路径,即任何从 X 到 Y 的路径,该路径以指向 X 的箭头开始。通过以下五个示例,可以更容易理解这一概念:

作者提供的表格,Game2 中有误,如果你控制 D,那么我们可以控制 A 或 D 来阻塞路径。
通过指定因果图,我们已经将控制尽可能多的混杂因素的过程转变为识别反门路径并找出如何有效地阻塞它们。如笔记中所述,并不总是需要控制尽可能多的变量以确保真正的因果效应。实际上,控制错误的变量可能会:
-
减少或阻塞 X 和 Y 之间的因果路径。例如,在游戏 1 中,如果我们控制 A,就会阻塞 X 和 Y 之间的因果路径;如果我们控制 B,A 的后代,则部分阻塞它。
-
引入 X 和 Y 的碰撞偏差。例如,在游戏 4 中,控制 B 会使 X 和 Y 在没有因果关系的情况下依赖。游戏 4 也被称为“ M 偏差”,因为它的形状。
-
控制正确的混杂因素,而不是尽可能多的因素。例如,在游戏 5 中,我们可以选择同时控制 B 和 A,或者仅控制 C 以达到相同的结果。
每一个图示都可以在现实世界的例子中找到。例如,在游戏 1 中表示一个医学应用,估计吸烟(X)对流产(Y)的影响。A 是由吸烟引起的潜在异常,因为我们不知道具体是什么异常被吸烟引起,所以它是不可观察的。B 是之前流产的历史。将流产历史包含到模型中是非常诱人的,但从因果图中可以看出,如果这样做,会部分失效吸烟对流产的机制,从而低估真正的因果影响。Pearl 的书中的这两章还有更多现实世界的应用。即使因果图可能变得过于复杂,使人脑无法找到后门路径,不要忘记我们总是可以依靠计算机算法来破解这些类型的问题。
为什么随机对照试验(RCT)有效?
我们已经使用因果图讨论了足够的非实验性研究。我们如何使用因果图和后门准则来解释为什么 RCT 能够得出无偏的因果影响?让我们看看一个例子,尝试找出不同肥料如何影响土壤产量。在现实世界中,农民根据许多因素决定使用哪种肥料,比如土壤肥力、土壤纹理,这些因素也会影响产量。我们可以在因果图中展示:

图片由作者提供
所有的橙色线条展示了偏倚肥料对产量因果影响的混杂关系。为了解决这个问题,我们需要在模型中控制所有这些混杂因素。请注意,这可能不太可能,因为这里的“其他”因素可能很难命名和量化。然而,如果现在我们设计一个实验,仅通过抽取随机卡片来决定每块土地使用哪种肥料。现在因果图变成了这样的:

图片由作者提供
通过在图示中添加随机卡片,我们可以去除之前图示中的所有混杂橙色线条,因为我们使用哪种肥料不再依赖于这些变量。它纯粹是一个仅受随机卡片抽取影响的随机决策。后门准则已经满足了用来估计肥料和产量的因果影响。
这就是我想分享的关于 Judea Pearl 的《为什么书》第三章和第四章的内容,这也完成了本系列“与我一起阅读”的第三篇文章。希望这篇文章对你有帮助。如果你还没阅读前两篇文章,可以在这里查看:
从一个猫的故事开始……
towardsdatascience.com ## 数据告诉我们“什么”,而我们总是寻求“为什么”
“为什么的书” 第一章与第二章,阅读系列
towardsdatascience.com
如果你感兴趣,订阅我的邮件列表 参加每两周一次的讨论,这些讨论将变得越来越技术性:
-
附赠:因果推断在学术界和工业界有何不同?
书中展示了更多细节和例子。正如往常一样,我强烈建议你阅读、思考,并在这里或你的 个人博客 上分享你的主要收获。
感谢阅读。如果你喜欢这篇文章,别忘了:
-
查看我最近的文章,关于 数据讲故事中的 4Ds:将科学变成艺术; 数据科学中的持续学习; 我如何成为数据科学家;
-
订阅 我的邮件列表;
-
或者在 YouTube 上关注我,观看我最近的 YouTube 视频,关于我读的其他书籍:
参考
为什么的书 作者为 Judea Pearl
通过回归估计因果效应
原文:
towardsdatascience.com/causal-effects-via-regression-28cb58a2fffc
3 种流行技术及其 Python 示例代码
·发表于Towards Data Science ·阅读时间 8 分钟·2023 年 1 月 10 日
--
这是关于因果效应系列文章的第 5 篇。在之前的文章中,我们讨论了从数据中计算处理效应的不同方法。在这里,我介绍了通过 3 种流行的基于回归的技术来估计因果效应的替代方法。我以如何在实践中使用这些技术的 Python 示例代码来结束本文。
关键点:
-
回归是利用数据学习变量之间关系的一种方法
-
3 种常见的基于回归的因果效应估计方法是:线性回归、双重机器学习和元学习器

通过线性回归的因果效应玩具示例。图片由作者提供。
什么是回归?
回归是利用数据学习变量之间关系的方法。例如,巴布亚新几内亚成人胡安树袋熊的身高与体重之间的关系。
回归过程的输出称为模型。这本质上是我们可以用来进行预测的东西,例如,你告诉我一只树袋熊的身高,我可以告诉你它的体重大致值。
回归的关键好处是我们可以利用数据将模型与现实相匹配。
要使用回归来估计因果效应,我们需要开发数据驱动的模型,这些模型捕捉了处理、协变量和结果之间的关系。然后,审查这些模型以量化因果效应。
我们之前用平均处理效应(ATE)定义了因果效应,即处理组和控制组之间的结果均值差异。由于我们可以直接从数据中估计因果效应,因此不需要模型!
然而,在回归框架中,因果效应的推导方式不同。要了解这一点,我们从最简单的基于回归的因果效应估计方法——线性回归开始。
什么是处理效应以及如何计算?
towardsdatascience.com
线性回归
对于这种方法,我们训练一个线性模型来预测结果变量(Y)相对于处理变量(X)。然后我们将 因果效应 定义为 回归模型中处理变量的系数。下面给出了一个简单的示例。

使用线性回归估计 X 对 Y 的因果效应的简单示例。图像来自作者。
其中 Y 是结果变量,X 是处理变量,b 是截距(可以解释为误差项),Θ 是 X 对 Y 的因果效应。注意在这个回归框架中定义的因果效应是 根本不同的,与我们在 过去的文章 中定义的方式。
更进一步,我们可以在线性模型中包含混杂因素。在这种情况下,混杂因素 是 影响处理变量和结果变量的变量。通过这样做,我们可以减少由于混杂因素导致的因果效应估计的偏差。有关如何使用线性回归估计因果效应的更多细节,我建议读者参考 Gelman 和 Hill 的章节 [1]。
双重机器学习
尽管线性回归的简单性使其易于使用,但它可能无法准确捕捉变量之间的关系(例如,当变量之间存在非线性关系时)。这时,更复杂的技术会更有帮助。
一种这样的技术被称为 双重机器学习 (DML)。介绍这种方法的论文很详尽(约 70 页),但整个过程可以分解为 3 个简单步骤 [2]。
-
训练 2 个回归模型。一个是预测 结果变量 相对于相关协变量。另一个是预测 处理变量 相对于协变量。
-
计算每个模型的残差。换句话说,如果 f(Z) 通过 Z 估计 Y,而 g(Z) 通过 Z 估计 X,则它们的残差分别为 U = Y-f(Z) 和 V = X-g(Z)。
-
计算处理效应。利用残差,我们可以使用下面的方程直接计算处理效应。

双重机器学习方法的因果效应表达式 [3]。图像来自作者。
这被称为双重机器学习,因为我们训练了 2 个机器学习模型,f(Z)和 g(Z)。此外,对所使用的机器学习方法没有限制,它们可以是简单的线性回归,也可以是复杂的亿参数神经网络。
做 DML 时一个重要的细节是需要将可用数据分成 2 个子集:主要样本和辅助样本。然后使用这些子集执行一个称为交叉拟合的过程。
这包括使用主要样本和辅助样本来分别训练模型 g(Z)和 f(Z),然后进行交换,即使用主要样本来训练 f(Z)和辅助样本来训练 g(Z)。
然后,我们可以对每个样本-模型配对的因果效应估计进行平均。虽然这看起来像是额外的一步,但它在保持我们因果效应估计的数学简单性方面是重要的。更多细节请查看 DML 论文第 5 页[2]。
元学习者
元学习者旨在通过训练回归模型捕捉处理、协变量和结果之间的关系。与线性回归和 DML 不同,元学习者的因果效应不是定义为回归系数。相反,回归模型用于模拟每个单元的未观察结果,从而得到个体处理效应(ITE)。然后可以使用 ITE 来计算 ATE。
此外,对于所谓的异质性因果效应,可以使用条件平均处理效应(CATE)。这仅仅是特定子人群的 ATE(例如,雄性袋鼠、婴儿女孩的因果效应等)。异质性处理效应是在人群广泛变化的处理效应。
T-learner
第一种元学习者是T-learner(或双重学习者)。在这里,我们训练2 个结果模型(因此得名),一个用于控制组,另一个用于处理组[4]。
此技术可以分解为 2 步过程。
-
训练 2 个模型来分别估计控制组和处理组中的结果变量。
-
使用每个模型为每个单元生成(控制和处理)结果预测,并获得 ITE,这些 ITE 可以用来计算 ATE。

T-learner 过程概述。图片由作者提供。
S-learner
接下来,我们有S-learner(单一学习者)。这与 T-learner 类似,但不是训练 2 个结果模型,而是只创建一个,但将处理变量作为预测因子[4]。
使用此技术获取因果效应时,我们再次遵循 2 步过程。
-
训练模型以估计协变量和处理值方面的结果变量。
-
使用每个单元的模型预测来估计 ITE,并将其汇总以获取 ATE。

S-learner 过程概述。图片由作者提供。
X-learner
最终类型的 meta-learner 是X-learner。这种方法与 T-learner 有重叠,但更进一步。
X-learner 的 4 步过程[4]。
-
训练 2 个模型以分别估算对照组和处理组的结果变量,考虑协变量。(就像我们对 T-learner 所做的那样)。
-
使用模型估算未观察到的结果值。例如,如果单位 i=0 的 X₀=1 且 Y₀=1,我们然后使用对照组模型估算 X₀=0 的未观察到的结果。然后,使用估算的结果值来分别计算对照组和处理组的 ITE。
-
再训练 2 个模型以分别估计处理组和对照组的 ITE。
-
通过结合 ITE 模型使用权重函数* w()来估计 CATE。 (提示:使用倾向得分作为 w*)。

X-learner 过程概述。图片由作者提供。
关于 Meta-learners 的更多内容,我建议读者参考 Kunzel 等人的论文[4]和 Causal ML 文档[5]。
如何从观察数据中估计效果
towardsdatascience.com
示例代码:估计研究生学校对收入的处理效果(再次查看)
在这个例子中,我们使用 3 种基于回归的技术来估计拥有研究生学位对年收入超过 50k 美元的因果影响。我们使用开源的DoWhy库和来自UCI 机器学习库的开放数据[6]。
示例代码可以在GitHub Repo中找到。
# import modules
import pickle
import econml
import dowhy
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
# load data
df = pickle.load( open( "df_causal_effects.p", "rb" ) ).astype(int)
一旦我们拥有了库和数据,我们必须定义我们的因果模型。这本质上定义了我们的处理、结果和协变量。
# define causal model
model = dowhy.CausalModel(
data = df,
treatment= "hasGraduateDegree",
outcome= "greaterThan50k",
common_causes="age",
)
# define estimand
estimand = model.identify_effect(proceed_when_unidentifiable=True)
在这里,我们将“hasGraduateDegree”定义为处理,“greaterThan50k”定义为结果,而“age”作为唯一的混杂变量。
首先,我们尝试线性回归。DoWhy 库让一切变得非常简单,因此我们只需运行 2 行代码。
# Linear Regression
LR_estimate = model.estimate_effect(estimand,
method_name="backdoor.linear_regression")
print(LR_estimate)
# ATE = 0.2976
接下来,我们尝试Double ML,虽然这个简单的例子对于它来说有点过于复杂,尤其是处理和结果变量只有 0 或 1 的值时。
# Double Machine Learning
DML_estimate = model.estimate_effect(estimand,
method_name="backdoor.econml.dml.DML",
method_params={"init_params":{
'model_y':LinearRegression(),
'model_t':LinearRegression(),
'model_final':LinearRegression()
},
"fit_params":{}
})
print(DML_estimate)
# ATE = 0.2977
请注意,在这个例子中,我们在 DML 过程使用的模型都是线性回归,但对于更复杂的问题,可以(且很多时候应该)使用更复杂的技术。
最后,我们尝试使用决策树构建X-learner作为我们的子模型。
# X-learner
Xlearner_estimate = model.estimate_effect(estimand,
method_name="backdoor.econml.metalearners.XLearner",
method_params={"init_params":{
'models': DecisionTreeRegressor()
},
"fit_params":{}
})
print(Xlearner_estimate)
# ATE = 0.2032
[## YouTube-Blog/causal_effects_regression 在主分支 · ShawhinT/YouTube-Blog
你现在无法执行此操作。你在另一个标签或窗口中已登录。你在另一个标签或窗口中已注销…
更多关于因果关系: 因果效应概述 | 因果关系:简介 | 因果推断 | 因果发现
资源
社交媒体: YouTube 🎥 | LinkedIn | Twitter
支持: 请我喝咖啡 ☕️
shawhin.medium.com/subscribe?source=post_page-----28cb58a2fffc-------------------------------- [## 获取我撰写的每一个新故事的免费访问权限
获取我撰写的每一个新故事的免费访问权限 P.S. 我不会与任何人分享你的电子邮件 通过注册,你将创建一个…
[1] 对处理变量进行回归的因果推断 由 Andrew Gelman 和 Jennifer Hill 编写
[2] 用于处理和因果参数的双重/去偏机器学习 由 Victor Chernozhukov 等人编写
[3] DoubleML Python 库文档
[4] 使用机器学习估计异质处理效应的元学习者 由 Kunzel 等人编写
[5] CausalML Python 库文档 (元学习者)
[6] Dua, D. 和 Graff, C. (2019). UCI 机器学习库 [http://archive.ics.uci.edu/ml]。加州欧文:加州大学信息与计算机科学学院。(CC BY 4.0)
因果推断:准实验
原文:
towardsdatascience.com/causal-inference-quasi-experiments-36d35ca5f754
你的 PM 忘记运行 A/B 测试了… 现在怎么办?
·发表于 Towards Data Science ·阅读时间 12 分钟·2023 年 8 月 9 日
--

图片来源:Isaac Smith 在 Unsplash
本文是关于使用准实验进行因果推断的系列文章中的第一部分(具体取决于我会啰嗦多少)。简而言之,第一部分将解释准实验的理由和方法,以及应用像 PSM 这样的办法时涉及的细微差别。在第二部分,我将更多地谈谈准实验的局限性以及基于这些实验做决策时需要注意的事项。我还会提出一个异质影响估计的框架,以帮助克服外推偏差。在第三部分… 我还不确定。
你可能也见过其他文章解释准实验,但我仍然会尝试以我的方式解释。请读一读。
为什么因果推断?
开发和推出产品及功能的成本最终是通过对消费者的积极影响来证明的。因此,听到产品经理做出各种声明,如“我们很高兴宣布我们最新的功能发布导致了 12% 的收入增长!”并不令人意外。
听起来很棒,老实说,大多数高级管理人员很乐意接受这样的说法作为事实。今天我的目标是说服你深入了解这些主张背后的因果推断方法。掌握因果推断,你将能更好地评估产品和功能对用户和公司带来的影响。
让我们看看 ChatGPT 对于为什么因果推断对于产品是必要的有何看法:
因果推断赋予产品团队的能力是超越仅仅观察数据中的相关性,建立对驱动产品表现的因果机制的更深刻理解。(毫不意外,比我能说的任何东西都更加简明扼要)
这里值得特别提到的是相关性和因果性的问题。
相关性并不意味着因果性。(别翻白眼)
说实话,我们中的许多人说它并认为我们知道它的含义。当有人问我们是什么意思时,我们拿出一个搞笑的图表来证明我们的智力能力(看看这个流行的虚假相关示例),并自豪地宣称我们永远不会在日常影响评估工作中犯这样的错误。好吧,经验告诉我,许多了解这种谬误的人并不真正理解它在现实世界中的表现。这通常源于因果推断领域的基础薄弱。

来源: www.tylervigen.com/spurious-correlations (CC BY 4.0)
那么,如果因果推断对客观评估我们在产品和功能上的投资回报很重要,我们该如何进行呢?
在其最无争议的形式中,因果推断通常通过 A/B 测试来操作(遗憾的是,这不是今天讨论的主题)。然而,现实情况是实验并不总是可用的。
准实验:为什么?
首先,有时我们只是忘记进行实验。这通常发生在产品团队成功地使用一部分实验用户证明了影响,然后继续向所有用户发布。在这个过程中,他们忘记保留一个控制组来评估普遍影响。
其次,有时实验根本不可能。例如,产品或功能可能涉及对用户高度可见的变化,或对用户高度敏感的变化。在这种情况下,用户体验占据主导地位,控制组的设置显然不现实。
这就是准实验派上用场的地方。与实际的 A/B 测试不同,这些准实验是回顾性进行的。一般来说,它涉及分析用户的一个子集,以在产品或功能发布后模拟进行实验。让我们通过一个例子深入了解具体情况。
想象一下你是一家电商公司的数据科学家,比如 Shopee 或 Lazada。6 个月前,你的公司普遍推出了一项你们 CEO 相信能增加平台用户支出的互动功能 X。你的一位项目经理某天告诉你,CEO 想知道功能上线对公司的影响。你提醒项目经理,由于这是一次普遍推出,没有控制组,所有用户都可以使用互动功能。因此,你的项目经理说:“这没问题。只需比较那些实际使用功能的用户(处理组:Treatment=1)与那些没有使用功能的用户(控制组:Treatment=0)。进行一些假设检验,瞧,支出差异就是功能 X 的影响。”

图片来源:作者
根据你项目经理的智慧,你计算了一下,发现处理组与“控制”组之间有+ $12 的差异。我们是如何得到+ $12 的?
2023 年 4 月平均处理组支出:
$ (42 + 26) / 2 = $34
2023 年 4 月平均控制组支出:
$ (36 + 20 + 19 + 13) / 4 = $22
支出差异 = $34 减去 $22 = $12
你的项目经理对估算的提升感到满意,并为帮助确保团队的年度奖金给了你一个鼓励的拍背。
当你那晚躺在床上时,某些事情仍然困扰着你的良心。确实,这种方法存在混淆变量的问题。
ChatGPT 将 混淆变量 描述为:
混淆变量,也称为混杂因素,是指在研究中可能影响因变量(感兴趣的结果)和自变量(被研究的因素)的外部因素。这些变量可能导致关于自变量和因变量之间真实关系的误导或不正确的结论。
在这种情况下,潜在的混淆变量实际上是结果变量本身:用户支出。但这怎么可能呢?
如果那些在 2023 年 2 月花费更多的用户实际上更有可能在 2023 年 3 月使用这个新功能 X,那么观察到的 2023 年 4 月+12%的支出提升实际上可能归因于支出偏好的固有差异,而不是功能 X 的使用本身。换句话说,如果功能 X 没有推出,处理组的用户可能在 2023 年 4 月依然会花费更多。
在这里,准实验可以帮助提供更可靠(和良心的)影响估计。
我可以深入讨论混杂变量和偏差的统计学,但我不打算这样做,以便能够实际转到准实验。另一个你可能会疑惑的是,为什么随机化 A/B 测试不会受到混杂变量问题的影响。有关解释,请参见附录 A。再次说明,这不是今天讨论的重点。此外,混杂变量的概念与遗漏变量偏差密切相关,虽然不完全相同,但如果你想了解更多 解释 有关 OVB 的内容。
不过,希望我还没有让你感到困惑。

作者提供的图片:混杂关系的表示
准实验:怎么做?
这是一个高层次的概述,展示了准实验如何在没有随机化 A/B 测试的情况下克服混杂变量的问题。让我们继续使用产品特性 X 的推出作为例子。
对于处理组中的每 2 个用户,如果我们能瞥见一个平行宇宙,在那个宇宙中这 2 个用户没有使用特性 X,那将会很棒。由于我们还没有生活在科幻世界中,接下来的最佳选择是通过统计方法来估计这个平行宇宙。
具体而言,这通过在控制组中筛选出与处理组用户在 Feature X 上线前最相似的 4 名用户来完成。换句话说,这些伪控制组将模拟处理组用户如果没有采用 Feature X 会如何表现。
在我们方便修改的示例中,我们会找出控制组中与我们的处理组在 2023 年 2 月(上线前)的花费最相似的用户。具体来说:
对于用户 A,最接近的相似用户是用户 C,因为他们都在 2023 年 2 月花费了$10。
对于用户 B,最接近的相似用户是用户 D,因为他们都在 2023 年 2 月花费了$8。
因此,我们排除了用户 E 和 F 的分析。

作者提供的图片:将处理组与伪控制组进行匹配
总结一下,为什么我们选择用户 C 来与用户 A 进行比较?鉴于 2023 年 2 月(上线前)相同的花费,我们推测如果没有 Feature X 的推出和使用,用户 A 实际上会是用户 C。
那我们该如何告诉我们的 PM 呢?以下是修正后的计算结果:
2023 年 4 月的处理组平均花费:
$ (42 + 26) / 2 = $ 34
2023 年 4 月控制组的平均花费:
$ (36 + 20) / 2 = $28
花费差异 = $ 34 减去 $28 = $6
如该示例所示,使用 Feature X 的预估影响从$12 的提升降至$6。 不幸的是,对于你的 PM 来说,今年的奖金可能会比预期少一点。
当然,我在这里仔细调整了数值以证明我的观点,但实际上我在真实产品世界中见过更大比例的估计偏差。现在应该很清楚,拥有稳健的方法论在评估产品和特征决策的影响时是极其重要的。
在这一点上,值得提醒读者注意这个例子中存在的混淆因果关系。具体来说,过去的支出是特征 X 使用与未来支出之间因果关系的一个混淆因子。重要的是,存在混淆关系,因为我们假设过去的支出是采用特征 X 的可能性的一个指标。在我们进一步讨论倾向评分匹配(PSM)时,这种可能性概念是非常重要的。
倾向评分匹配
在上述特征 X 的方便例子中,我们只有一个混淆变量(2023 年 2 月的过去支出)。通过混淆框架来看这个问题,可能还有许多其他属性会影响发布后使用特征 X 的可能性。其中一些属性可能是已知且可观察的,而其他一些可能仍然未知或不可观察。
在我以前的一个项目中,我们有几十个用户属性,希望用来将处理用户与伪控制用户匹配。你可以应用 KNN 模型来找到相似的用户,但当属性过多和用户搜索量过大时,很快会遇到性能问题。如果你有数值和分类值的混合,还会有定义距离的额外复杂性。
克服这个计算问题的一种方法是通过降维过程。至少对我来说,这实际上就是 PSM 所做的。
回忆一下我们之前将混淆关系建模为混淆因子影响采用特征 X(处理)的可能性。因此,我们可以采取以下步骤来从我们的处理组中选择与伪控制组相似的子集:
-
使用逻辑回归(或任何其他产生概率预测的二元分类模型)来建模处理(1 或 0)与全部潜在混淆因子(过去支出等)之间的关系。
-
使用拟合的分类模型来预测每个用户属性/混淆因子下的处理(倾向评分)概率。
-
根据估计的倾向评分将处理组中的用户与伪控制组中的用户匹配。因此,倾向评分匹配。
通过采取这些步骤,匹配过程变得更具计算效率。除了极其高效外,数学在正确的假设下也能很好地运作。(有关期望值性质的更多信息,请参阅这篇文章,如果你对证明感兴趣且仍未信服,可以在这里查看)。当然,这种方法也有权衡之处,我将在下一部分讨论其中的一些局限性。
对于 R 用户,有一个非常知名的库叫做MatchIt来实现这个功能,其中一个示例可以在这里找到。个人而言,除了我不喜欢 R 之外,我也不喜欢那些将过多计算隐藏在背后的库,所以我从未真正使用过这个包,而是自己编写了代码在 PySpark 中进行匹配,以更高效地处理大型数据集(数百万用户)。另有一个简单的 Python 示例教程,大家可以在这里查看。如果你想了解更多关于我如何实现的内容,随时联系我。
倾向得分匹配:局限性
我想避免花费太多时间展示 PSM 的实现,部分原因是已经有许多参考示例,但也因为我更感兴趣的是讨论进行这种影响估计时涉及的细微差别。应用 PSM 很简单,但理解其假设、注意事项和局限性只有通过经验和实验才能获得。
不理解混杂变量的 PM 很危险,但误用准实验方法的数据科学家可能更危险,特别是因为这种方法表面上看起来很直观。
在我的第二部分文章中,我将花更多时间讨论我在将 PSM 应用于现实世界示例中的经验教训。现在,以下是应用 PSM 时应始终牢记的一些重要局限性:
可忽略的处理分配假设
PSM 依赖于“可忽略的处理分配”假设,这意味着所有影响处理概率(例如使用特征 X 的概率)和结果(例如未来支出)的混杂变量都已被充分测量并纳入倾向得分模型。如果存在未测量或不可观察的混杂因素,匹配可能无法充分解决估计偏差问题。
样本重叠和共同支持
PSM 要求处理组和对照组之间的倾向评分分布有足够的重叠。在共同支持有限的情况下(即,两组中具有相似倾向评分的个体很少或没有),匹配变得具有挑战性,我们可能需要采用其他方法。尽管我不是 R 的粉丝,MatchIt的文档很好地解释了支持的考虑因素以及其他匹配方法。
倾向评分估计中的选择偏倚
倾向评分模型(例如,逻辑回归)的准确性取决于模型的正确规范。如果模型被错误指定或包含不相关的变量,可能会引入选择偏倚。估计倾向评分的方法选择可以极大地改变估计的影响。不同的估计方法可能会产生不同的匹配结果,从而得出不同的因果影响估计。
关于预测倾向评分的模型的最后一点:在我与数据科学家讨论如何进行第 1 步(拟合逻辑回归模型)的过程中,许多人陷入了试图找到最佳模型以减少预测误差的困境。我不怪他们想这样做,这几乎是数据科学家的本能,想要在运行 model.fit(X, y)时最小化 RMSE。然而,同样值得记住的是,PSM 建模的直接目标并不是获得最佳分类预测。相反,是找到一种计算上高效的方式来实现 属性平衡 和共同支持。因此,最适合的模型可能并不总是产生最佳的匹配结果。更多内容请参见第二部分文章。
结论
始终记住我们在这里测量的内容是很重要的。对于那些熟悉统计学的人来说,我们是在克服使用平均处理效应(ATE)的偏倚,PSM 仅返回处理组的平均处理效应(ATT)。有关更全面的讨论,请参阅此链接。因此,在将 ATT 估计结果推广到更广泛的人群时,请谨慎行事。(我会在后续讨论中保存这个话题)
最后,数据科学家通常在影响产品和特性的决策和评估方面拥有巨大的权力。因此,我认为数据科学家肩负着相应的责任,在进行影响估计时,必须使用最稳健和可靠的方法。第二部分见!
附录 A:随机化与混杂
随机化是通过以相等的概率将每个个体分配到一个处理组(例如,对照组或处理组)来实现的。我们可以将处理分配表示为一个二元变量:
-
如果个体接受治疗,则 T = 1,
-
如果个体在对照组,则 T = 0。
随机分配的关键特性是治疗分配与任何潜在结果(Y)或协变量(X),包括观察到的和未观察到的,都是独立的。从数学上讲,我们可以将其表示为:
P(Y|T, X) = P(Y|T)
这意味着给定处理和协变量的情况下结果的概率与仅给定处理的情况下结果的概率相同。
参考文献
发现相关性:寻找新的相关性。来自 Tyler 的说明:现在这项功能无法使用——抱歉!存在冲突…
www.tylervigen.com ## 统计入门:倾向评分匹配及其替代方法†
摘要。倾向评分(PS)方法相对于传统的回归方法在控制…方面提供了某些优势。
## 理解遗漏变量偏差 ## MatchIt:入门指南
针对最普遍的偏差类型的逐步指南
towardsdatascience.com ## MatchIt:入门指南
Noah Greifer Ho 等(2007 年建议的 MatchIt)用于改进参数统计模型…
特别感谢:Shin Ler
因果 Python——埃隆·马斯克的推文,我们的搜索习惯,以及贝叶斯合成控制
使用带有贝叶斯改进的合成控制量化推文的影响(使用 CausalPy)
·发表于Towards Data Science ·阅读时间 11 分钟·2023 年 1 月 8 日
--

图片由Tolga Aslantürk在Pexels提供
2022 年 10 月给 Twitter 旧金山总部(以及一个水槽)带来了许多新变化。特斯拉和 SpaceX 的首席执行官埃隆·马斯克于 10 月 27 日成为公司新任所有者和首席执行官。
一些观众热烈欢迎这一变化,而另一些则保持怀疑态度。
一天后,即 10 月 28 日,马斯克推特上发了“鸟儿被释放了”。
推文的威力有多大?
让我们看看吧!

图片由Laura Tancredi在Pexels提供。
目标
在这篇博文中,我们将使用CausalPy——来自PyMC Developers(www.pymc-labs.io)的全新 Python 因果包,来估计马斯克的推文对我们搜索行为的影响,运用一种强大的因果技术叫做合成控制。我们将讨论该方法的基本原理,逐步实施,并分析我们方法的潜在问题,同时链接到额外的资源。
准备好了吗?
介绍
-
2022 年 11 月初,我安排了一次会议演讲,讲解如何量化时间序列数据中的干预效果。我认为在演讲中使用一个真实世界的例子会很有趣,于是我想到了马斯克的推文。关于 Twitter 收购案在互联网上引发了很多讨论,我想知道这样的推文在多大程度上能影响我们的行为,超越传统的社交媒体活动,例如它如何影响我们搜索“Twitter”的频率?
-
嵌入 1. 埃隆·马斯克的推文。
-
但首先要讲清楚一件事。
- 因果关系与实验
-
因果分析旨在识别和/或量化干预(也称为处理)对感兴趣结果的影响。我们在世界上改变一些东西,想要理解我们行动的结果如何改变其他东西。例如,一家制药公司可能对确定新药对特定患者群体的效果感兴趣。这可能因为多种原因而具有挑战性,但最基本的原因是无法在同一时间观察到一个患者同时服用药物和不服用药物(这被称为因果推断的基本问题)。
-
人们找到许多聪明的方法来克服这个挑战。如今被认为是黄金标准的方法称为随机实验(或随机对照试验;RCT)¹。在 RCT 中,参与者(或一般情况下有时称为 单位) 被随机分配到治疗组(接受治疗)或对照组(不接受治疗)²。
- 学习 3 种因果效应识别技术,并在 Python 中实现它们,而无需耗费几个月、几周或几天的时间…
-
towardsdatascience.com
-
我们期望在设计良好的 RCT 中,随机化将平衡治疗组和对照组在 混杂因素 和其他重要特征方面的差异,这种方法通常相当成功!
-
不幸的是,由于经济、伦理或组织等多种原因,实验并不总是可用的。
-
如果我们…

- 图片由 Engin Akyurt @ pexels.com
- 合成控制
如果我们只能观察处理下的结果,而控制组不可用呢?阿尔贝托·阿巴迪和哈维尔·加尔德亚萨巴尔在评估巴斯克地区冲突的经济成本时遇到了这种情况(Abadie & Gardeazabal, 2003)。他们的论文孕育了我们今天讨论的方法——合成控制。
这个方法背后的基本思想很简单——如果我们没有控制组,就创造一个!
如何?
一个解决方案是预测它。
如果我们选择一些某种程度上相似于我们的处理单位(但保持未处理的)并将它们用作预测变量呢?这就是合成控制(几乎完全)所做的!
这些未处理的单位有时被称为捐赠池。记住,我们处于时间序列数据的领域,基本的合成控制估计器是未处理单位的加权和。我们将使用一个额外的权重约束,强制权重在0和1之间,并且加起来等于一²。
每个权重调整每个未处理单位对结果的贡献。你可以把它看作是一个时间上的约束线性回归。
我们在处理前观察数据上拟合模型,并预测处理后的结果值。这一逻辑基于一个假设,即捐赠池变量没有受到处理的影响。当这个假设成立时,预测的处理后控制组应该保持所有处理前特征不变(假设捐赠池变量足够好地预测结果)。
如果你想看到合成控制的逐步实现以及整洁呈现的数学,查看Matteo Courthoud’s 博客文章和/或Matheus Facure’s 章节关于合成控制。如果你想要更多应用研究的背景,请查看 Scott Cunningham 的“因果推断——混合带”。对于贝叶斯实现(我们这里使用的),请查看CausalPy 源代码。
[## 是的!六本因果关系书籍将使你从零基础到高级(2023)
…如果你愿意,可以完全免费获得其中的三本书!🤗
aleksander-molak.medium.com](https://aleksander-molak.medium.com/yes-six-causality-books-that-will-get-you-from-zero-to-advanced-2023-f4d08718a2dd?source=post_page-----187114fc4aa8--------------------------------)
假设
回到我们的推文。我假设马斯克广泛讨论的推文(“小鸟自由了”)使人们对 Twitter 本身以及相关新闻产生了更多兴趣。因此,我们期望观察到相对于其他社交媒体平台,“Twitter”的搜索量有所增加。
实际上,这一假设很难验证,因为结果可能不仅受到马斯克推文的影响,还可能受到其他因素(例如媒体对 Twitter 收购的报道)的影响。请注意,这实际上是一个很好的例子,说明了混杂如何在合成控制分析中发生³(Twitter 收购导致马斯克的推文以及引发对该平台的兴趣增加)。你认为哪种规范(推文作为原因或收购作为原因)更合理?请在评论中告诉我!
由于这是一个有趣的帖子,我们将假设马斯克推文对搜索行为的影响没有混杂,并且我们可以安全地进行估计。如果你决定自己估计 Twitter 收购对“Twitter”搜索数量的影响,请随时通过LinkedIn与我分享你的结果,或者加入 Causal Python 社区(https://causalpython.io),直接将结果回复到我们的一封每周邮件中。
数据中的马斯克推文
我们使用Google Trends作为代表全球每日搜索量的时间序列数据来源。我们对“Twitter”的搜索变化感兴趣,因此我们收集了这个搜索的数据,同时也收集了“TikTok”、“Instagram”和“LinkedIn”的数据,以用作我们的捐赠池。
我们将使用 2022 年 5 月 15 日至 11 月 11 日之间的数据。
让我们看看图表。

图 1. Twitter、LinkedIn、TikTok 和 Instagram 搜索的数据。图像由本人提供。
我们可以看到 Twitter 和 Instagram 是搜索量最多的平台。它们之间存在一定的相关性。我们还可以看到 LinkedIn 的搜索量具有非常强的季节性特征,周末的搜索量明显较少,这与该网站的职业性质相符。
马斯克在 10 月 28 日发布了他的“小鸟自由了”推文。让我们把这个信息添加到图表中。

图 2. 包括处理(黑色虚线)的Twitter、LinkedIn、TikTok 和 Instagram 搜索的数据。图像由本人提供。
我们看到 Twitter 搜索量的急剧增加与马斯克推文的发布日一致。
让我们看看在合成产生的对照组下效果有多强。
让我们建模吧!
我们从导入开始。
代码块 1。 导入库。
我们遵循CausalPy 文档的惯例,并将库导入为cp。我们导入pandas来读取数据,并导入matplotlib来帮助我们绘图。
我们读取数据并将索引转换为日期时间(这帮助我们生成了上面的图表,并使得索引治疗更容易,但并非必要)。
代码块 2。 读取数据并将索引更改为日期时间类型。
让我们简单看一下数据。

图 3。 我们数据集的前五行。图片由我本人提供。
正如预期的那样,我们看到四个变量和一个日期时间索引。我们将使用“LinkedIn”、“TikTok”和“Instagram”搜索作为捐赠池信号。
让我们将治疗日期存储在一个变量中,并实例化模型。
代码块 3。 将治疗日期存储在变量中并实例化模型。
我们使用WeightedSumFitter模型,这将允许我们为每个捐赠池变量找到权重,以生成最佳拟合的合成控制。你可能还记得我们之前说过,我们对这些权重使用了两个约束:
-
它们应总和为1。
-
它们应在0和1之间。
请注意,如果第一个条件为真,则第二个条件可以用更不严格的非负约束代替;我们使用了更严格的条件,因为它可能对一些读者更直观。
满足这些约束可以通过多种方式实现。如果你查看了我们上面提到的其中一个参考资料(Matteo 的博客或 Matheus 的书),你可能会注意到他们都使用了约束优化来实现这个目标。由于我们使用贝叶斯方法,我们需要在分布层面上对这些约束进行编码。与我们所需约束非常匹配的分布是Dirichlet 分布。Dirichlet 分布的样本总和为1且非负。如果这让你想起了贝塔分布,那是一个很好的直觉!Dirichlet 是贝塔分布的(多维)推广。
CausalPy 将在后台负责初始化和拟合分布。我们现在准备好定义和拟合模型了!
CausalPy 支持 R 风格的公式来定义模型。公式twitter ~ 0 + tiktok + linkedin + instgram表示我们想要将 Twitter 搜索随时间的变化建模为“TikTok”、“LinkedIn”和“Instagram”搜索的函数。公式开头的零告诉模型我们不想拟合截距。
代码块 4。 定义和拟合模型。
我们使用SyntheticControl实验对象,它将负责模型拟合和结果生成。我们向构造函数传递四个参数:数据集、治疗索引、定义模型的公式和模型对象(我们选择了WeightedSumFitter)。
如果你自己运行代码,你会注意到初始化采样器并采样链条需要一些时间,但大约过一分钟后我们应该可以开始绘制结果。
[## Python 中的因果推断与发现:解锁现代因果机器学习的秘密…
《Python 中的因果推断与发现》:通过 DoWhy、EconML 解锁现代因果机器学习的秘密……
amzn.to](https://amzn.to/3NiCbT3?source=post_page-----187114fc4aa8--------------------------------)
结果
让我们来检查结果!results对象有一个非常方便的方法叫做.plot(),可以有效地以图形方式总结结果。
代码块 5。 绘制结果。
这给出了以下输出:

图 4。 我们模型的结果。图片由本人提供。
在图的顶部,我们看到处理前贝叶斯R²(Gelman 等,2018)的打印输出,量化了我们的捐赠者池变量对 Twitter 处理前搜索次数的预测效果。
最上面的面板展示了结果变量的实际观察(黑点)、处理前结果的预测(深蓝色线)、捐赠者池变量(灰色)、我们生成的合成控制(绿色)、干预时间(垂直红线)和干预效果(阴影蓝色区域)。
在中间面板中,我们看到预测的因果影响在处理前后的情况。
最后,底部面板展示了累积因果效应。
总结一下!
贝叶斯R²为0.385表明模型的处理前拟合效果不是很好(完美拟合的R²为1)⁴。考虑到我们的捐赠者池较小,这并不令人惊讶。许多从业者建议作为经验法则,捐赠者池中变量至少应在 5 到 25 个之间。我们只有 3 个。
另一方面,我们可以非常确定我们没有过拟合,这在捐赠者池较大的情况下可能会发生(参见 Abadie,2021)。
如果我们认为分析中没有隐藏的混杂因素,埃隆·马斯克的推文的处理后效果相对较大,表明他的推文足够强大,能够暂时改变我们的搜索行为!
注意,另一种假设(Twitter 获取而非推文作为处理)看起来很有前景——你有没有注意到在干预前“Twitter”的搜索次数有所增加?
如果你决定检验这个假设,与我和 社区分享你的结果吧!
关于 CausalPy
CausalPy 仍处于初期阶段,但正在稳步成长。我收到图书馆创建者的消息,表示一些令人兴奋的新功能正在开发中,包括对合成控制的用户定义先验的支持。此外,图书馆的功能不仅限于这一种方法。确保查看最新版本和更新,请访问这里:github.com/pymc-labs/CausalPy
代码和conda 环境文件可在这里获取:
[## blogs-code/Causal Python - Elon Musk 的推文是否改变了我们的搜索习惯?]
目前您无法执行该操作。您在另一个标签或窗口中已登录。您在另一个标签或窗口中已注销...
脚注
¹ 虽然存在多个处理和/或多个对照组的实验设计,但在这里我们保持简单。
² 请注意,这些约束并不是必要的,但当捐赠者池变量的值既高于 又 低于结果变量的值时,这会迫使模型不对超出我们观察到的值的范围进行外推。在我们的案例中,这很有意义——参见图 1,其中Instagram(大多数时候)高于Twitter,其他平台则低于。允许模型进行外推并不错误,但它存在模型幻觉变量如何在其观察范围之外表现的风险。如果这让您想起了积极性假设——那是一个很好的直觉!有关积极性和外推的更多信息,请访问:causalpython.io/#positivity
³ 请注意,我们也可以说这种情况违反了SUTVA假设中的无多重处理版本部分,但我认为混杂视角更清晰、更直观。
⁴ 需要记住的是,使用R²作为拟合优度指标会带来自身的挑战。
了解更多关于 Python 中因果关系的内容:
[## Causal Python: 3 个简单技巧来快速启动您的因果推断之旅]
学习 3 种因果效应识别技巧,并在 Python 中实现它们,不用花费几个月、几周或几天的时间...
towardsdatascience.com [## Causal Python: 提升你在 Python 中的因果发现技能 [超越基础!] (2023)
…并挖掘 Python 中最优秀且最被低估的因果发现包的潜力!
towardsdatascience.com [## 是的!六本因果关系书籍将带你从零到高级 (2023)
…而且如果你愿意的话,可以完全免费获得其中的三本!🤗
参考文献
Abadie, A. (2021). 使用合成控制法:可行性、数据要求和方法论方面。经济文献杂志。
Abadie, A., & Gardeazabal, J. (2003). 冲突的经济成本:以巴斯克地区为例。公共选择与政治经济学电子期刊。
Athey, S., & Imbens, G. (2017). 应用计量经济学的现状——因果关系和政策评估。经济学视角杂志 32(2)。
Gelman, A., Goodrich, B., Gabry, J., & Vehtari, A. (2018). 贝叶斯回归模型的 R 平方。美国统计学家。
因果 Python:2023 年 NeurIPS 大会上的五个新颖因果观点
原文:
towardsdatascience.com/causal-python-five-novel-causal-ideas-at-neurips-2023-13bb68c5ed56
令人兴奋的新想法,将因果关系与生成建模、保形预测和拓扑学结合起来。
·发表于 Towards Data Science ·阅读时间 7 分钟·2023 年 9 月 24 日
--

NeurIPS 被认为是全球最重要和最负盛名的人工智能和机器学习会议之一,因为其严格的论文审查过程和高质量的研究。
会议具有跨学科的关注点,涵盖了与开发智能系统和机器学习算法相关的广泛话题。
近年来,NeurIPS 大会上接受的与因果关系相关的论文数量呈指数增长。
在本文中,我们介绍了 2023 年会议上接受的五篇因果论文,这些论文引起了我的注意,带来了对该领域的重要新见解。
请注意,这是一个主观且肯定不完整的列表。原因之一是,在撰写时,NeurIPS 尚未发布会议接受论文的完整列表。
尽管如此,我相信下面介绍的论文中的想法有可能推动我们的领域向前发展。
开始吧!
保形元学习者
保形预测是一类不确定性量化技术,最初由 Vladimir Vovk 提出。
保形预测是无模型的(无需分布假设),并提供频率学覆盖保证。换句话说,它保证在可交换性假设下,真实结果将以高概率落入预测区间(或集合)¹。

图 1. 保形元学习者的描述。来源:bit.ly/44Z9U9L
在他们的新论文《用于个体治疗效应预测推断的符合性元学习器》中,Ahmed Alaa 及其同事提出了一种新的符合性元学习框架,该论文刚刚被 NeurIPS 2022 接受。
他们的方法使得直接推断目标参数(个体化治疗效应;ITE)成为可能,这是对先前方法的重要改进。
作者在一系列使用合成数据和半合成数据的实验中评估了该方法的性能,并以实现的覆盖率、均方根误差和区间长度来衡量性能。
结论?符合性DR-learner在大多数设置中表现优越。
该框架的一个局限性是它要求已知倾向评分,这在某些情况下可能是限制性的。
这项工作为将因果方法与符合性预测器提供的覆盖保证相结合开辟了一个令人兴奋的新方向。
🟡 阅读论文
🟡 查看代码
- 想了解关于元学习器和 DR-learner 的内容?请查看Causal Inference and Discovery in Python的第九章和第十章,或者查看该书的免费GitHub 仓库。
因果归一化流
归一化流是一类神经模型,通过将简单分布转换为复杂分布来表示复杂分布。
特别是,自回归归一化流将变量 X 的分布估计为其前面变量的函数。如果我们想要将一个变量表示为其前面变量的函数,我们首先需要以某种方式对这些变量进行排序。
请注意,这种设置类似于结构性因果模型(SCM),其中每个变量被表示为其父变量的函数。

图 2. 线性 SCM 的示例(a)以其通常的递归形式书写
公式;(b)没有递归,每一步都明确;(c)没有递归,作为一个单一的
函数;以及(d)将 u 写作 x 的函数。来源:bit.ly/3ZvwbuD
归一化流以前曾用于因果发现(例如,CAREFL(Khemakhem 等,2021)用于DECI(Geffner 等,2022);有关详细信息,请参见 Molak(2023),第十三章)。
在他们的新论文《因果归一化流:从理论到实践》中,Adrián Javaloy 及其同事将这些想法提升到了一个新水平。他们展示了在给定因果排序的情况下,因果模型可以从观察数据中识别,并使用归一化流进行恢复。
接下来,他们提出了一种在因果归一化流中实现do操作符的方法,这使我们能够回答干预性和反事实查询。
最后,他们在一个不完整图上的混合(连续/离散)数据上展示了他们方法的有效性。
真是一个令人兴奋的时代!
🟡 阅读论文
🟡 查看代码
## 因果 Python — 提升你在 Python 中的因果发现技能(2023)
…并解锁 Python 中最佳因果发现包的潜力!
[towardsdatascience.com
部分反事实识别的生成模型
反事实被置于 Pearl 的因果阶梯的第三层。
这使得它们最难处理,因为我们需要非常丰富的结构因果模型(SCM)描述,以解决反事实查询。

图 3. Pearl 因果阶梯的符号表示。反事实查询需要 SCM 的最丰富表示。来源:bit.ly/45UzGgx
所谓的符号可识别性并不总是可用,其他可以使回答反事实查询可行的假设(例如单调性)在某些场景下可能难以满足,或仅在离散情况下有效。
Valentyn Melnychuk 及其同事的新论文《具有曲率敏感性模型的连续结果的部分反事实识别》提出了一种超越这些限制的新方法。
作者从拓扑学的角度(有关早期工作的参考见例如 Ibeling & Icard, 2021)来看待这一挑战,并提出了一种曲率敏感性模型(CSM),允许进行部分反事实识别。
连续结果。
换句话说,该方法使我们能够在没有关于数据生成过程的完整信息时,为连续结果下的反事实查询找到有信息量的界限。
作者建议,解决方案所依赖的假设应在从物理学到医学的广泛用例中是现实的,并且该方法在安全关键环境中的决策中具有潜在的相关性。
作为附注:所提出的方法还依赖于归一化流。
🟡 阅读论文
🟡 查看代码
被动数据、主动因果策略和语言模型
如果你关注我在LinkedIn或Twitter/X,你可能会记得关于这篇论文的帖子。
实际上,这篇论文是激励我和我的同事组织AAAI 2024 关于大型语言模型和因果性的研讨会的论文之一。
但是,言归正传!

图 4. 因果 DAG 环境和实验结果。(a) 施加的约束
在创建训练数据集时,以及在测试时评估代理时,因果 DAG 结构
结构。在训练期间,D 不允许成为 E 的祖先(即使是间接的),尽管它们
可能由于混淆变量而相关。在评估环境中,D 是最有影响力的
E 的祖先(见文本)。(b) 在互动设置中评估时,代理获得的奖励,作为
最优奖励的百分比。在这两种评估设置中,代理仍然接近最优
奖励。© 更详细地分析代理的行为,通过绘制代理的比例
与最优行为相匹配的行动,或基于干预或
相关统计数据。代理的策略与最优策略的匹配显著更接近。
与启发式基准匹配。来源:bit.ly/3Rt4cd3
在他们的论文《代理和语言模型中的被动学习主动因果策略》中,DeepMind 的 Andrew Lampinen 和同事们展示了代理和(大规模)语言模型(LLMs)可以从被动(观察)数据中学习主动因果策略。
这些策略可以泛化到分布外的数据(!),但仅在某些条件下。
正如作者提出的,代理“获得了用于发现和利用因果结构的可泛化策略,只要他们可以在测试时进行干预”。
提示中的解释对 LLMs 的泛化能力至关重要。此工作并不意味着被动学习超越主动学习或完全解决 LLMs 中的混淆。然而,它标志着扩展语言模型因果能力的重要一步。
我迫不及待想看到更多的研究继续在这个激动人心的路径上前进!
🟡 阅读 论文
广义敏感性分析
当我们不能排除存在隐藏混淆的可能性时,敏感性分析对于因果分析师至关重要。
传统的敏感性分析方法通常存在局限性(例如,它们假设线性模型或单一的二元处理)。
在他们的新论文中, 《广义因果敏感性分析的锐界》,*Dennis Frauen 和同事们提出了一种新的统一
用于未观察混淆下因果敏感性分析的框架。

图 5. 在所提出模型下的干预分布界限。来源:bit.ly/3PLVBkl
该框架为一系列因果效应提供了锐界,包括(条件)平均处理效应(CATE)、中介效应、路径分析和分布效应。
此外,所提出的方法适用于离散、连续和时间变化的处理。
最棒的部分?论文附带了一个丰富的代码库。
真是一个活着的好时光!
🟡 阅读 论文
🟡 检查 代码
## 因果 Python || 你学习 Python 中因果关系的首选资源
每周免费邮件关于因果关系和机器学习
脚注
¹ 在这里,可交换性的含义与潜在结果框架中的不同。我们可以将其视为 IID 假设的一个较温和版本。有关更多细节,请参见 这里。
参考文献
Geffner, T., Antorán, J., Foster, A., Gong, W., Ma, C., Kıcıman, E., Sharma, A., Lamb, A., Kukla, M., Pawlowski, N., Allamanis, M., & Zhang, C. (2022). 深度端到端因果推断。arXiv
Ibeling, D., Icard, T. (2021). 关于因果推断的拓扑视角。第 35 届神经信息处理系统大会。
Khemakhem, I., Monti, R., Leech, R. & Hyvarinen, A. (2021). 因果自回归流。第 24 届国际人工智能与统计会议论文集,在机器学习研究论文集,130,3520–3528. proceedings.mlr.press/v130/khemakhem21a.html。
Molak, A. (2023). Python 中的因果推断与发现:解锁现代因果机器学习的秘密,包括 DoWhy、EconML、PyTorch 等。Packt Publishing。
CFXplorer:反事实解释生成 Python 包
原文:
towardsdatascience.com/cfxplorer-counterfactual-explanation-generation-python-package-483ca4221ab8
介绍了一款用于生成基于树的算法反事实解释的 Python 包
·发表于Towards Data Science ·阅读时间 9 分钟·2023 年 8 月 17 日
--
随着机器学习模型在现实场景中的应用日益增多,对模型可解释性的重视也在不断增加。了解模型如何做出决策不仅对模型的用户有益,也对受到模型决策影响的人员有所帮助。反事实解释的出现就是为了解决这个问题,因为它允许个人了解通过改变原始数据如何能获得理想的结果。在短期内,反事实解释可能会为那些受到机器学习模型决策影响的人员提供可行的建议。例如,一个被拒绝贷款申请的人可以知道这次该做些什么来被接受,这将有助于他们在下次申请中改进。
Lucic 等人[1]提出了 FOCUS,它旨在为树基机器学习模型中的所有实例生成与原始数据的最优距离反事实解释。
CFXplorer 是一个使用 FOCUS 算法为给定模型和数据生成反事实解释的 Python 包。本文介绍并展示了如何使用 CFXplorer 生成反事实解释。
链接
GitHub 仓库:github.com/kyosek/CFXplorer
文档:cfxplorer.readthedocs.io/en/latest/?badge=latest
PyPI:pypi.org/project/CFXplorer/
目录
-
FOCUS 算法
-
CFXplorer 示例
-
限制
-
结论
-
参考文献

照片由 Wesley Sanchez 提供,来源于 Unsplash
1. FOCUS 算法
本节简要介绍了 FOCUS 算法。
生成反事实解释的问题已经被一些现有方法解决。Wachter、Mittelstadt 和 Russell [2] 将这个问题形式化为优化框架,但这种方法仅限于可微分模型。FOCUS 旨在通过引入概率模型近似,将框架扩展到非可微分模型,特别是基于树的算法。该方法的一个关键方面是对预训练基于树的模型(表示为 f)的近似,通过用具有参数 σ 的 sigmoid 函数替换每棵树中的每个分裂来实现,参数 σ 定义为:

其中 σ ∈ R>0。
这个 sigmoid 函数被纳入了函数 t ̃_j(x) 中,该函数近似了树模型 f 的节点 j 激活 t_j(x) 对于给定输入 x。该函数定义为:

其中 θ_j 是节点 j 的激活阈值。
该方法近似于单棵决策树 T。树的近似可以定义为:

此外,该方法将* f* 的最大操作,即由权重 ω_m ∈ R 的M棵树的集合替换为带有温度 τ ∈ R>0 的 softmax 函数。因此,近似的 f ̃ 可以表示为:

重要的是要注意,这种近似方法可以应用于任何基于树的模型。
FOCUS 算法的主要声明是该方法能够 (i) 为数据集中所有实例生成反事实解释,并且 (ii) 为基于树的算法找到更接近原始输入的反事实解释,比现有框架更优。
2. CFXplorer 示例
本节展示了如何使用 CFXplorer 包的两个示例。第一个是一个简单示例,您可以了解包的基本用法。第二个示例展示了如何通过使用 Optuna [3] 包来搜索 FOCUS 的最佳超参数。正如本文章在前面部分所述,FOCUS 有一些超参数。这些超参数可以通过与超参数调优包集成来优化。
2.1. 简单示例
在这个简单示例中,我们创建随机数据、决策树模型,并使用 CFXplorer 生成反事实解释。Python 包 CFXplorer 通过使用 FOCUS 算法生成反事实解释。本节演示了如何使用这个包来实现这一点。
安装
您可以使用 pip 安装该包:
pip install CFXplorer
首先,导入相关包。
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from cfxplorer import Focus
from sklearn.datasets import make_classification
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.tree import DecisionTreeClassifier
我们创建一个虚拟数据集以供决策树模型使用。
def generate_example_data(rows: int = 1000):
"""
Generate random data with a binary target variable and 10 features.
Args:
rows (int): The number of rows in the generated dataset.
Returns:
pandas.DataFrame: A DataFrame containing the randomly generated data.
"""
X, y = make_classification(
n_samples=rows, n_features=10, n_classes=2, random_state=42
)
return train_test_split(X, y, test_size=0.2, random_state=42)
CFXplorer 只接受标准化的特征值(在 0 和 1 之间),因此我们需要对其进行缩放。
def standardize_features(x_train, x_test):
"""
Standardizes the features of the input data using Min-Max scaling.
Args:
x_train (pandas.DataFrame or numpy.ndarray): The training data.
x_test (pandas.DataFrame or numpy.ndarray): The test data.
Returns:
tuple: A tuple containing two pandas DataFrames.
- The first DataFrame contains the standardized features of the training data.
- The second DataFrame contains the standardized features of the test data.
"""
# Create a MinMaxScaler object
scaler = MinMaxScaler(feature_range=(0, 1))
# Fit and transform the data to perform feature scaling
scaler = scaler.fit(x_train)
scaled_x_train = scaler.transform(x_train)
scaled_x_test = scaler.transform(x_test)
# Create a new DataFrame with standardized features
standardized_train = pd.DataFrame(scaled_x_train)
standardized_test = pd.DataFrame(scaled_x_test)
return standardized_train, standardized_test
现在训练决策树模型。
def train_decision_tree_model(X_train, y_train):
"""
Train a decision tree model using scikit-learn.
Args:
X_train (array-like or sparse matrix of shape (n_samples, n_features)): The training input samples.
y_train (array-like of shape (n_samples,)): The target values for training.
Returns:
sklearn.tree.DecisionTreeClassifier: The trained decision tree model.
"""
# Create and train the decision tree model
model = DecisionTreeClassifier(random_state=42)
model.fit(X_train, y_train)
return model
我们将上述所有内容结合起来运行。
X_train, X_test, y_train, y_test = generate_example_data(1000)
X_train, X_test = standardize_features(X_train, X_test)
model = train_decision_tree_model(X_train, y_train)
一旦我们获得了数据和模型,我们初始化Focus。Focus 需要几个参数进行定制。但为了简单起见,在这个例子中,我们可以使用迭代次数和距离函数。
focus = Focus(
num_iter=1000,
distance_function="cosine",
)
FOCUS 的其他参数是;
distance_function: str, optional (default="euclidean")
Distance function - one of followings;
- "euclidean"
- "cosine"
- "l1"
- "mahalabobis"
optimizer: Keras optimizer, optional (default=tf.keras.optimizers.Adam())
Optimizer for gradient decent
sigma: float, optional (default=10.0)
Sigma hyperparameter value for hinge loss
temperature: float, optional (default=1.0)
Temperature hyperparameter value for hinge loss
distance_weight: float, optional (default=0.01)
Weight hyperparameter for distance loss
lr: float, optional (default=0.001)
Learning rate for gradient descent optimization
num_iter: int, optional (default=100)
Number of iterations for gradient descent optimization
direction: str, optional (default="both")
Direction of perturbation (e.g. both, positive and negative)
hyperparameter_tuning: bool, optional (default=False)
if True, generate method returns unchanged_ever and mean_distance
verbose: int, optional (default=1)
Verbosity mode.
- 0: silent
- else: print current number of iterations
最后,我们可以使用generate方法生成反事实解释。
perturbed_feats = focus.generate(model, X_test, X_train)
我们可以在图中检查这些生成的反事实解释。
def plot_pca(plot_df, focus_plot_df):
"""
Plots the PCA-transformed features and corresponding predictions before and after applying FOCUS.
Args:
plot_df (pandas.DataFrame): A DataFrame containing the PCA-transformed features and
predictions before applying FOCUS.
focus_plot_df (pandas.DataFrame): A DataFrame containing the PCA-transformed features and
predictions after applying FOCUS.
Returns:
None: This function displays the plot but does not return any value.
"""
fig, axes = plt.subplots(1, 2, figsize=(20, 8))
sns.scatterplot(
data=focus_plot_df, x="pca1", y="pca2", hue="predictions", ax=axes[0]
)
axes[0].set_title("After applying FOCUS")
sns.scatterplot(data=plot_df, x="pca1", y="pca2", hue="predictions", ax=axes[1])
axes[1].set_title("Before applying FOCUS")
fig.suptitle("Prediction Before and After FOCUS comparison")
plt.show()
plot_df, focus_plot_df = prepare_plot_df(model, X_test, perturbed_feats)
plot_pca(plot_df, focus_plot_df)
它看起来像这样:

我们可以观察到,在应用 FOCUS 之前,许多预测== 1 位于右侧,但应用 FOCUS 后,它们变成了预测== 0。对于在 FOCUS 之前的预测== 0,它们位于左侧,并变成预测== 1。
2.2. 超参数优化
主要有四个 FOCUS 的超参数,具体来说,sigma(方程 1)、温度(方程 4)、距离权重,它是距离损失和预测损失之间的权衡参数,以及 Adam 的学习率[4]。
注意 1:在这个例子中,我们将使用决策树模型,因此我们不会使用 *temperature* 超参数。
注意 2:你可以将优化算法(这里使用的是 *Adam* )视为一个超参数,但我们不会在本节中优化它,其他 *Adam* 的超参数也同样如此,除了学习率为了简单起见。
本节使用 Optuna 来优化 FOCUS 的超参数。Optuna 是一个强大的超参数优化工具,执行贝叶斯优化。除了 Optuna,我们还可以再次使用我们之前创建的相同函数;generate_example_data、standardize_features和train_decision_tree_model。
以下是目标函数。它定义了要调整的超参数以及优化目标。在这个例子中,我们在Focus类中调整 3 个超参数,即 sigma、距离权重和 Adam 优化器的学习率。这些超参数的搜索空间被定义为trial.suggest_float或trial.suggest_int. 损失函数定义为cfe_distance /100 + pow(unchanged_ever, 2). 这样做的原因,如函数的文档字符串中所写,我们希望优先找到反事实解释,而不是最小化平均距离。因此,我们取未改变实例的平方数。
注意:重要的是要将*Focus*类的*hyperparameter_tuning*参数设置为*True*。 否则,它不会返回未改变实例的数量和平均反事实解释距离。
import optuna
import tensorflow as tf
from cfxplorer import Focus
def objective(trial):
"""
This function is an objective function for
hyperparameter tuning using optuna.
It explores the hyperparameter sets and evaluates the result on a
given model and dataset
Mean distance and number of unchanged instances are
used for the evaluation.
Args:
trial (optuna.Trial):
Object that contains information about the current trial,
including hyperparameters.
Returns:
Mean CFE distance + number of unchanged instances squared -
This is the objective function for hyperparameter optimization
* Note: typically we want to minimise a number of unchanged first,
so penalising the score by having squared number.
Also, to not distort this objective,
having the mean distance divided by 100.
"""
X_train, X_test, y_train, y_test = generate_example_data(1000)
X_train, X_test = standardize_features(X_train, X_test)
model = train_decision_tree_model(X_train, y_train)
focus = Focus(
num_iter=1000,
distance_function="euclidean",
sigma=trial.suggest_int("sigma", 1, 20, step=1.0),
temperature=0, # DT models do not use temperature
distance_weight=round(
trial.suggest_float("distance_weight", 0.01, 0.1, step=0.01), 2
),
lr=round(trial.suggest_float("lr", 0.001, 0.01, step=0.001), 3),
optimizer=tf.keras.optimizers.Adam(),
hyperparameter_tuning=True,
verbose=0,
)
best_perturb, unchanged_ever, cfe_distance = focus.generate(model, X_test)
print(f"Unchanged: {unchanged_ever}")
print(f"Mean distance: {cfe_distance}")
return cfe_distance / 100 + pow(unchanged_ever, 2)
一旦我们定义了目标函数,就可以开始调整这些超参数。
if __name__ == "__main__":
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=100)
print(f"Number of finished trials: {len(study.trials)}")
trial = study.best_trial
print("Best trial:")
print(" Value: {}".format(trial.value))
print(" Params: ")
for key, value in trial.params.items():
print(" {}: {}".format(key, value))
更全面的示例可以在软件包仓库中找到。
4. 限制
Focus 类存在几个限制。以下是这些限制的列表:
-
目前,Focus 类仅适用于 scikit-learn 的
DecisionTreeClassifier、RandomForestClassifier和AdaBoostClassifier。 -
虽然类别特征可以包含在特征集中,但需要注意的是,类别特征变化的解释,例如从年龄 40 变为 20,可能不会提供有意义的见解。
-
在应用 Focus 之前,输入特征应缩放到 0 和 1 的范围。因此,在使用 Focus 之前,有必要对特征进行转换。然而,这种缩放过程可能在应用 Focus 后解释特征时引入一些额外的复杂性。
-
计算成本会随着给定模型的增大而增加。当你有一个大型模型时,可能无法执行代码。
5. 结论
CFXplorer Python 包提供了对 FOCUS 算法的全面使用,以生成给定基于树的算法的反事实解释的最优距离。尽管存在一些限制,但对于那些希望在基于树的模型中探索反事实结果的人来说,这个包应该是有用的。
本文回顾了 FOCUS 算法的理论背景,代码示例展示了如何使用 CFXplorer,以及一些当前的限制。未来,我将为这个包添加更多的反事实解释生成方法。
希望你觉得这篇文章有用。
6. 参考文献
-
A. Lucic, H. Oosterhuis, H. Haned, 和 M. de Rijke. “FOCUS: 灵活可优化的树集成反事实解释。” 载于:AAAI 人工智能会议论文集。第 36 卷,第 5 期,2022 年,页 5313–5322。
-
S. Wachter, B. Mittelstadt, 和 C. Russell. “不打开黑箱的反事实解释:自动决策与 GDPR。” 载于:Harv. JL & Tech. 第 31 卷 (2017),页 841。
-
T. Akiba, S. Sano, T. Yanase, T. Ohta, 和 M. Koyama. “Optuna: 下一代超参数优化框架。” 载于:第 25 届 ACM SIGKDD 国际知识发现与数据挖掘会议论文集。2019 年,页 2623–2631。
-
D. P. Kingma 和 J. Ba. “Adam: 一种随机优化方法。” 载于:arXiv 预印本 arXiv:1412.6980 (2014)。
思维链提示促进了 LLMs 的推理能力
原文:
towardsdatascience.com/chain-of-thought-prompting-facilitate-llms-reasoning-abilities-313cd7714938
通过示例展示
·发表于Towards Data Science·阅读时间 6 分钟·2023 年 6 月 12 日
--

图片由Juan Rumimpunu提供,来源于Unsplash
大型语言模型(LLMs)被证明在解决各种任务方面非常高效,从总结文档到用不同编程语言编写代码。
此外,新发布的模型如 ChatGPT 和 GPT-4 使得 LLMs 的表现更加出色,为基于 LLM 的应用解锁了更多机会。
尽管大型语言模型(LLMs)具有非凡的技能,但它们有时未能展示非常简单的推理能力,无法解决一个四年级学生都能轻松处理的问题。
在这一领域已经进行了大量研究,旨在理解为何 LLMs 无法执行此类任务并使其表现更好。
关注这个特定问题的一个研究是思维链提示,由谷歌研究团队介绍。
思维链提示
一种结构为{输入、思维链、输出}的提示技术,其中思维链是一系列中间自然语言推理步骤。
模型会接收到一些包含输入和输出的示例(少样本学习),然后被要求解决一个涉及多步骤或算术推理任务的问题。
论文的主要结论:
-
思维链提示优于标准提示
-
思维链提示与标准提示之间的区别在于较大的模型上更为明显。思维链提示的性能提升与模型的参数数量成正比。
思路链提示技术就是逐步解决问题。每一步都基于逻辑推理。这里是一个示例:
问题:约翰有 2 栋房子。每栋房子有 3 间卧室,每间卧室有 2 个窗户。每栋房子有 1 个厨房,厨房有 2 个窗户。此外,每栋房子还有 5 个窗户不在卧室或厨房中。
约翰的房子里有多少个窗户?
答案:每栋房子有 3 间卧室,每间卧室有 2 个窗户,共计 6 个卧室窗户。每栋房子有 1 个厨房,厨房有 2 个窗户,共计 2 个厨房窗户。还有 5 个窗户不在厨房或卧室中。因此,每栋房子有 6 + 2 + 5 = 13 个窗户。由于有 2 栋房子,所以约翰的房子里总共有 2 x 13 = 26 个窗户。
你可以遵循不同的步骤来获得正确的答案。对于这类问题,几乎总有多条路径可以到达答案。
在标准提示中,我们通常会向语言模型提出上述问题,并期望它给出答案。让我们试试:
# model: text-davinci-003
prompt = """
John has 2 houses. Each house has 3 bedrooms and there are 2 windows in each bedroom.
Each house has 1 kitchen with 2 windows. Also, each house has 5 windows that are not in the bedrooms or kitchens.
How many windows are there in John's houses?
"""
# model response
"There are a total of 32 windows in John's houses."
错误的答案!
我们不知道模型如何得到这个答案,这带来了思路链提示的另一个优点。模型以逐步回答的方式回应,这使得调试过程更容易。我们可以很容易地发现问题所在。
通过思路链提示技术,我们在提示中添加了一些问题及其答案,以进行少量示例提示。这些答案以逐步解决方案的形式出现(即展示思路链)。
这里是更新后的提示和模型的回应:
# model: text-davinci-003
prompt = """
Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
A: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6\. The answer is 6.
Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
A: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5\. The answer is 5.
Q: Leah had 32 chocolates and her sister had 42\. If they ate 35, how many pieces do they have left in total?
A: Originally, Leah had 32 chocolates. Her sister had 42\. So in total they had 32 + 42 = 74\. After eating 35, they
had 74 - 35 = 39\. The answer is 39.
Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?
A: Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8\. The answer is 8.
Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?
A: Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9\. The answer is 9.
Q: John has 2 houses. Each house has 3 bedrooms and there are 2 windows in each bedroom.
Each house has 1 kitchen with 2 windows. Also, each house has 5 windows that are not in the bedrooms or kitchens.
How many windows are there in John's houses?
A:
"""
# model response
"John has 2 houses. Each house has 3 bedrooms with 2 windows each,
1 kitchen with 2 windows, and 5 windows that are not in the bedrooms or
kitchens. So for each house, there are 3 x 2 + 2 + 5 = 13 windows.
In total, there are 2 x 13 = 26 windows. The answer is 26."
正确的答案!
回应是对模型思考过程的逐步解释(思路链),就像提示中展示的示例一样。
让我们再看一个例子。这里是标准提示版本:
# model: text-davinci-003
prompt = """
We have a big meeting in 10 days, which is 5 days before our product
launch on 08/24/2022\. What is the date today in MM/DD/YYYY?
"""
# model response
"08/14/2022"
错误的答案!
让我们尝试用思路链提示回答相同的问题,提供一些输入输出示例(少量示例学习):
# model: text-davinci-003
prompt = """
Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
A: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6\. The answer is 6.
Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
A: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5\. The answer is 5.
Q: 2015 is coming in 36 hours. What is the date one week from today in MM/DD/YYYY?
A: If 2015 is coming in 36 hours, then it is coming in 2 days. 2 days before 01/01/2015 is 12/30/2014,
so today is 12/30/2014\. So one week from today will be 01/05/2015\. So the answer is 01/05/2015.
Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?
A: Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9\. The answer is 9.
Q: We have a big meeting in 10 days, which is 5 days before our product launch on 08/24/2022.
What is the date today in MM/DD/YYYY?
A:
"""
# model response
"If the product launch is on 08/24/2022, then 10 days before that is '
'08/14/2022\. 5 days before that is 08/09/2022\. So today is 08/09/2022\. The '
'answer is 08/09/2022."
正确的答案!
结束语
思路链提示明显提高了语言模型在某些任务上的能力。如果你想了解更多关于他们进行的实验,建议阅读整篇论文。
需要注意的是,正如论文中提到的,思路链提示的好处只有在应用于大约 1000 亿参数的模型时才会显现,并且对较小模型的性能提升并不显著。
实验结果得出的结论是,较小的模型会产生流畅但不合逻辑的思路链,这导致其表现比标准提示更差。
参考文献
Wei, Jason, 等(2022)。思路链提示引发大型语言模型的推理。ArXiv。arxiv.org/abs/2201.11903
LLMs 的思维链提示
原文:
towardsdatascience.com/chain-of-thought-prompting-for-llms-33c963eead38
一个实用且简单的“推理”方法与 LLMs
·发表于 Towards Data Science ·阅读时长 16 分钟·2023 年 7 月 24 日
--

(照片由 Matthew Lancaster 提供,来自 Unsplash)
大型语言模型(LLMs)的成功源于我们能够在大量文本语料库上进行预训练(使用语言建模目标)仅解码器变换器模型。考虑到我们对足够大的模型进行预训练,LLMs 在少量样本学习方面表现出色。换句话说,这意味着我们可以通过简单地制定一个文本提示(可能包含一些正确输出的示例)并让 LLM 生成正确答案来解决各种不同的问题(例如,翻译、句子分类、摘要等)。
尽管 LLMs 非常强大,但这些模型仍然存在一些持续难以解决的问题。特别是,推理问题(例如,算术或常识推理)特别困难。最初的尝试是通过对监督数据集中的各种推理问题的解决方案和解释进行微调 LLMs 和任务特定的验证模块 [3, 4]。然而,最近的研究发现可以利用少量样本学习来找到更简单的解决方案。
“本文的目标是赋予语言模型生成思维链的能力——一系列连贯的中间推理步骤,这些步骤将引导到问题的最终答案。” — 来源于 [1]
特别是,链式思维(CoT)提示 [1] 是一种新提出的技术,通过少样本学习提升 LLM 在基于推理的任务上的表现。类似于标准的提示技术,CoT 提示将多个推理问题的示例解决方案插入到 LLM 的提示中。然后,每个示例都配有一系列思维链,即解决问题的中间推理步骤。LLM 然后以少样本的方式学习在解决推理问题时生成类似的思维链。这种方法使用最少的数据(即仅几个示例进行提示),无需特定任务的微调,并显著提高了 LLM 在基于推理的基准测试上的表现,尤其是对于较大的模型。

(来自 [1, 11])
核心概念
要理解 CoT 提示,我们需要对 LLMs 和 它们的工作原理有一个基本了解。幸运的是,我们在之前的综述中已经广泛覆盖了这个话题:
本节将不会涵盖大语言模型(LLMs)的基础知识。相反,我们将专注于提高对提示和少样本学习的理解,并探讨这些技术如何被用来解决这些模型的核心限制:它们无法解决推理任务。
提示和少样本学习
在 GPT 和 GPT-2 等语言模型提出后,我们知道通过 自监督 的下一词预测(或 语言建模)目标是非常强大的。然而,如何将这些通用的 基础模型 适配到特定的下游任务上并不十分明确。例如,GPT 在下游任务上微调模型,而 GPT-2 以零样本的方式解决问题;见下文。

(来自 [2])
在提出 GPT-3 [2] 后,我们看到足够规模的 LLM 可以非常好地进行少样本学习。在通过语言建模目标进行预训练后,GPT-3(一个 1750 亿参数的 LLM)被发现能够准确地解决各种不同的语言任务,而无需任何微调。我们可以用提示方法代替微调。
更具体地说,提示利用了语言模型的文本到文本结构,通过提供如下输入:
-
“将这个句子翻译成英语:
<sentence> =>” -
“总结以下文档:
<document> =>”。
这些任务解决“提示”使得在语言模型中进行零样本(即,不查看正确输出的示例;见上文)或少样本(即,在提示中插入一些正确输出的示例;见下文)推理成为可能。语言模型的最合适输出应该能够解决任务(例如,总结文档或完成推理任务),这意味着我们可以通过准确的下一个标记预测来解决各种问题!

(来自 [1, 2])
我们可以通过提示做很多事情。事实上,最近创建了一个专门的提示工程领域,研究如何优化提示的措辞或结构以提高 LLM 性能。但是,敏感性是这个发展中的领域中的一个重要考虑因素。LLM 的性能可能会因为输入提示的微小扰动而发生巨大变化(例如,少样本示例的排列导致 GPT-3 在 SST-2 [13] 上的准确率从 93.4% 降低到 54.3%)。因此,在我们对提示方法的研究中,我们旨在找到* i)* 表现良好且 ii) 不受敏感性影响的技术。
我们能通过扩大规模来解决推理问题吗?
如上所述,LLM 的少样本学习性能随着规模的增加而改善,但大模型并不是我们所需要的一切。强大的 LLM 需要与大规模预训练数据集的组合 [14]。考虑到这一点,我们可能会问自己:LLM 在基于推理的数据集上的表现如何?随着规模的扩大,LLM 是否在推理方面变得更好?
“单独扩大模型规模并不足以在如算术、常识和符号推理等具有挑战性的任务上取得高性能”— 来源于 [1]
有趣的是,我们发现使用更大的模型和预训练数据集并不会改善 LLM 的推理能力(例如,参见Gopher [15]的分析)。事实上,这些模型因无法解决基本推理任务而受到严厉批评。因此,许多研究人员声称 LLM 只是重复训练数据,而没有进行任何复杂的推理或分析。无论如何,本概述将重点关注试图解决这一问题并使 LLM 更容易解决基本推理任务的提示技术。

(来自[4])
先前的方法。 在进一步了解我们如何帮助 LLM 解决推理问题之前,了解这个领域中的先前方法是有用的。算术、常识和符号推理任务的基线技术执行任务特定的微调,即模型在每个推理问题的监督示例上进行训练。更进一步,最佳方法是训练一个补充的“验证”模块,该模块可以判断 LLM 在推理任务中的输出正确性[4]。在测试时,这个验证器可以在生成多个问题答案后推断出最佳可能的输出;参见上述内容。
尽管这些技术在某些情况下可能效果相对较好,但由于一些原因,它们有其局限性:
-
需要进行任务特定的微调。
-
模型架构必须针对每个任务进行适应(即,通过验证模型)。
-
必须收集大量的监督数据。
考虑到这些限制,显而易见,使用仅基于提示的方法(例如,CoT 提示)来解决推理任务将会简单得多。我们可以避免微调,保持相同的模型架构,收集更少的数据,并用一个单独的预训练模型检查点解决许多任务。
一些显著的 LLM…
CoT 提示是一种提示技术,旨在提高预训练 LLM 的少样本学习性能。在[1]中,使用了一个特定组的 LLM 进行评估,下面列出了这些模型并进行了说明。
-
GPT-3 [2]:一个拥有 1750 亿参数的预训练 LLM,使用标准语言建模目标进行训练。
-
LaMDA [5]:一个基于 LLM 的对话模型,使用语言建模目标进行预训练,然后在对话数据和人工反馈上进行微调(存在 422M、2B、8B、68B 和 137B 等不同规模的模型)。
-
PaLM [6]:一个使用标准语言建模目标、Google 的 Pathways 框架和大规模文本语料库进行预训练的 LLM(存在 8B、62B 和 540B 尺寸的模型)。
-
Codex [7]:一个具有 120 亿参数的 LLM,它在使用标准语言建模目标进行预训练后,针对 GitHub 上的公开 Python 代码进行了微调。
-
UL2–20B [8]:一个使用 Mixture-of-Denoisers (MoD) 目标进行预训练的 LLM,这是一个在许多数据集和设置中表现良好的统一目标。
除了这些 LLM,我们还在近期的综述中了解到其他模型(例如,LLaMA、LLaMA 扩展 和 T5)。
思维链提示

(来源于 [1])
尽管我们可能对提示的概念有一定了解,但CoT 提示是什么? CoT 仅指一种特定的提示技术,它将思维链(即一系列中间推理步骤)插入到 LLM 的提示中;见上文。对于足够大的模型(>1000 亿参数),这种方法显著提升了 LLM 在算术、常识和符号推理任务中的复杂推理能力。
CoT 提示的来源是什么? 在 CoT 提示提出之前,我们已经知道少样本学习对 LLM 极其强大;见下文。我们不是 fine-tuning LLM 以执行任务,而是“提示”一个通用模型,给出几个正确输出的例子,然后生成最终答案。这种方法在许多任务中 非常成功。

使用少样本提示进行上下文学习(来源于 [2])
此外,我们从相关研究中了解到,生成自然语言的推理来解释如何得出最终答案对算术推理任务有益。我们可以训练或微调模型来生成这些推理 [3, 4],但这需要创建一个用于不同推理任务的高质量推理数据集,这既昂贵又耗时!
“仅依靠提示的方法很重要,因为它不需要大量的训练数据集,而且单个模型可以在不丧失通用性的情况下执行许多任务。” — 来自 [1]
CoT 提示法结合了少量示例提示的优势和生成自然语言推理的好处。我们不需要额外的训练或微调,只需将一些推理示例(即思维链)插入到提示中,就可以让 LLM 通过少量示例学习生成类似的推理。
CoT 提示法是如何工作的?
当我们作为人类解决推理任务时,通常会将问题拆分为更小的任务。例如,我在计算在餐厅想要给多少小费时通常会这样做:
-
取账单总额:$56.00
-
计算总额的 10%:$5.60
-
将这个值乘以 2(产生 20%的小费):$11.20
尽管这个例子很简单,但这一理念可以扩展到我们作为人类解决的各种心理推理任务中。我们生成一个思维链(在[1]中定义为“引导问题最终答案的连贯中间推理步骤序列”)来解决这些任务。简单来说,CoT 提示法增强了 LLM 生成类似思维链的能力。
“我们探索语言模型在推理任务中进行少量示例提示的能力,前提是提示由三元组组成:[输入,思维链,输出]。” — 来自 [1]
下方展示了结合思维链解决的各种推理任务的示例。

(来自 [1])
学习思维链。 为了教导 LLM 生成问题解决的推理,我们可以将这些推理示例直接插入到提示中。然后,LLM 可以利用其少量示例学习能力在解决任何推理问题时生成类似的思维链。如下所示,提示通常包含几个思维链示例。

(来自 [1])
[1]中的作者发现,这种提示方法使 LLM 在解决问题时生成类似的思维链,这有助于推理能力,并具有几个显著的好处:
-
可解释性:LLM 生成的思维链可以用来更好地理解模型的最终答案。
-
适用性:CoT 提示法可以用于任何可以通过语言由人类解决的任务。
-
提示:不需要对任何 LLM 进行训练或微调。我们只需将一些 CoT 示例插入提示中即可!
此外,LLM 甚至可以通过生成更多步骤的思维链来分配更多计算资源给复杂的推理问题。这模仿了我们作为人类通常会做的事情!
CoT 提示法极其有益
为了评估 CoT 提示对 LLM 解决推理问题能力的影响,CoT 提示在算术、常识和符号推理基准上进行了测试。评估使用了几个不同的预训练 LLM,包括 GPT-3 [2]、LaMDA [5]、PaLM [6]、Codex [7] 和 UL2 [8]。作为基准,文献 [1] 中的作者使用了标准的少量提示,由 GPT-3 提出的。所有模型在评估过程中使用 贪婪解码,但通过对多个样本进行多数投票可以获得更好的结果 [9]。
算术推理。 算术推理任务包括数学应用题。这样的题目对人类来说很简单,但 LLM 常常对此感到困难。下面提供了 [1] 中使用的算术推理数据集的概述。

(来自 [1])
对于 CoT 提示,一组由八个少量示例组成的提示是手动编写的(无需大量提示工程),并且用于所有数据集,除了具有多项选择结构的 AQuA。下面展示了在多个 LLM 上进行 CoT 提示的算术推理数据集实验结果。

(来自 [1])
从这些实验中,我们发现 CoT 提示有几个显著的特点。首先,CoT 提示似乎在较大的 LLM(即>100B 参数)上效果更好。较小的模型通常产生不合逻辑的思考链,从而使性能低于标准提示。此外,更复杂的问题(例如 GSM8K)从 CoT 提示中获得的好处更大。与之前的最先进方法(执行任务特定的微调)相比,使用 GPT-3 和 PaLM-540B 的 CoT 提示在所有情况下都达到了相当或更好的性能。
当我们定性地检查 CoT 提示生成的正确和错误答案时,我们得到以下结论:
-
大多数正确答案是逻辑思考链的结果,除了少数偶然从错误的思考链中预测出正确答案的情况。
-
46% 的错误答案几乎是正确的,这意味着它们包含了有轻微错误的思考链。
-
56% 的错误答案是由于思考链在理解或推理上存在主要问题造成的。
[1] 中的作者还分析了 CoT 提示对不同结构(例如,置换少量示例)的鲁棒性,并剖析了 CoT 提示的各个方面,发现 CoT 对模型性能提供了一致且独特的好处。有趣的是,CoT 提示对小的提示扰动并不太敏感。
常识推理。 常识推理 问题假设对一般背景知识有掌握,并需要对物理和人类互动进行推理。作者采用类似于算术推理实验的设置(除了少数需要手动策划少量示例的数据集),评估了各种预训练的 LLM 在常识推理任务中的表现,结果如下面的图所示。

(来自 [1])
简而言之,CoT 提示在常识推理问题上也被发现提供了巨大的好处。我们再次看到,较大的模型从 CoT 提示中获益更多。然而,无论是标准提示还是 CoT 提示的表现都随着模型规模的增加而提高,其中 CoT 提示倾向于实现稍微提高的性能。
符号推理。 [1] 中的作者还在符号推理任务上评估了 CoT 提示,例如:
-
最后一个字母连接: 要求模型连接并输出序列中每个单词的最后一个字母。
-
硬币翻转: 要求模型确定经过一系列硬币翻转后,硬币是否仍然是正面朝上。
更进一步,考虑了领域内和领域外的符号推理测试,其中领域外示例被定义为那些需要比训练或少量示例中看到的更多推理步骤(例如,最后一个字母连接中更多的单词或更长的硬币翻转序列)的示例。领域内和领域外评估的结果如下所示。

(来自 [1])
尽管这些任务比较简单,但我们看到 CoT 思维提示在 i) 提升了符号推理任务的表现,并且 ii) 使模型在需要更多推理步骤的领域外问题中能够更好地泛化。此外,我们再次观察到,无论是否使用 CoT 提示,较小的模型都无法解决符号推理任务。因此,CoT 提示似乎是符号推理的一个极有益的方法。
CoT 提示的变体
在 [1] 中提出 CoT 提示后,提出了几种变体,可以提高 LLM 的推理能力。这些不同的变体为引发 LLM 中的“推理”行为提供了各种有趣和实用的方法。下方提供了一些值得注意的 CoT 提示变体列表。

(来自 [10])
零-shot CoT。 零-shot CoT 提示[10]是对 CoT 提示[1]的简单跟进。为了鼓励 LLM 生成思维链,零-shot CoT 仅在问题末尾添加了“让我们一步步思考。”这句话。通过在 LLM 的提示中加入这一简单的补充,我们在[10]中看到,LLM 即使在没有观察到这种行为的明确示例的情况下,也能生成思维链,从而在推理任务中得到更准确的答案。请参见上文,以比较零-shot CoT 与其他提示方法。

(来自 [11])
自一致性。 自一致性是 CoT 提示的一种变体,它使用 LLM 生成多个思维链,然后将这些生成的结果中多数票作为最终答案;见上文。在 CoT 提示无效的情况下,使用自一致性往往会提高结果。简单来说,自一致性只是用一个生成多个答案的 LLM 管道来替代[1]中使用的贪婪解码过程,并选择这些答案中最常见的作为最终答案。

(来自 [12])
从少到多提示。 从少到多提示通过首先将问题分解为更小的子问题,然后单独解决每个子问题,从而超越了 CoT 提示。在解决每个子问题时,其答案会被包含在提示中,用于解决下一个子问题。与 CoT 提示相比,从少到多提示在多个任务(例如,最后一个字母拼接)上提高了准确性,并改善了对需要更多推理步骤的领域外问题的泛化能力。
提示工程。 正如上述示例所示(以及 CoT 提示的想法),为 LLM 制定一个有用的提示是一门艺术。要了解更多关于如何设计更有效提示的信息,我强烈推荐访问这里提供的 Learn Prompting 网站上的课程。
关键点
在这一概述中,我们看到标准提示不足以充分发挥 LLM 的潜力。相反,它似乎提供了一种“下限”来衡量 LLM 的表现,尤其是在更困难的推理任务上。CoT 提示通过利用 LLM 的少量示例学习能力,超越了标准提示技术,从而在解决基于推理的问题时引导生成连贯的多步骤推理过程。这种方法对 LLM 的表现极为有利,尤其是对于更大的模型。以下是一些关键点。
CoT 提示的实用性。 LLMs 在处理常识、算术和符号推理等任务时表现不佳。然而,CoT 提示显著提高了这些任务的表现。此外,这种方法不需要 i) 微调和 ii) 最少的额外数据(即,仅需一组示例用于少量学习)。因此,这是一种易于使用的技术,经过一些提示工程和少量示例的策划,可以帮助预训练的 LLMs 解决它们通常面临的困难任务。
推理随着规模的增长而显现。 并非所有模型都能从 CoT 提示中受益。事实上,相较于较小的模型,较大的 LLMs 从 CoT 提示中获得的好处更加明显。在 [1] 中,作者观察到 CoT 提示的好处在参数超过 1000 亿的模型中显现。这个具体数字可能会严重依赖于实验设置,但总体思路很明确:CoT 提示的好处在较大的模型中最为明显。
“链式思维模拟了人类推理者的思维过程。这并不能回答神经网络是否真的在进行推理。” — 引自 [1]
LLMs 是否真的知道如何推理? CoT 提示帮助 LLMs 解决某些推理任务,但这并不一定意味着 LLMs 具备复杂的推理能力。[1] 的作者甚至明确指出 CoT 提示的分析并不能回答 LLMs 是否真正进行推理。相反,CoT 提示是一种经验性技术,可以更准确地解决像算术、常识和符号推理这样通常对 LLMs 成为难题的任务。无论我们是否相信 LLMs 能够推理,我们都可以同意这项技术在实际应用中非常有用。
结束语
感谢你阅读这篇文章。我是 Cameron R. Wolfe,Rebuy 的人工智能总监。我研究深度学习的经验和理论基础。你还可以查看我在 medium 上的 其他文章!如果你喜欢,请在 twitter 上关注我,或者订阅我的 Deep (Learning) Focus 新闻简报,我通过对流行论文的易懂概述帮助读者深入理解人工智能研究中的主题。
参考文献
[1] Wei, Jason, 等. “链式思维提示引发大规模语言模型的推理。” arXiv 预印本 arXiv:2201.11903 (2022)。
[2] Brown, Tom, 等. “语言模型是少量学习者。” 神经信息处理系统进展 33 (2020): 1877–1901。
[3] Wang Ling, Dani Yogatama, Chris Dyer, 和 Phil Blunsom. 2017. 通过理由生成的程序归纳:学习解决和解释代数文字问题。ACL。
[4] Karl Cobbe, Vineet Kosaraju, Mohammad Bavarian, Jacob Hilton, Reiichiro Nakano, Christopher Hesse 和 John Schulman. 2021. 训练验证者以解决数学文字问题。arXiv 预印本 arXiv:2110.14168。
[5] Thoppilan, Romal 等人。“Lamda: 对话应用的语言模型。” arXiv 预印本 arXiv:2201.08239(2022 年)。
[6] Chowdhery, Aakanksha 等人。“Palm: 扩展语言建模的途径。” arXiv 预印本 arXiv:2204.02311(2022 年)。
[7] Chen, Mark 等人。“评估训练于代码的大型语言模型。” arXiv 预印本 arXiv:2107.03374(2021 年)。
[8] Tay, Yi 等人。“Ul2: 统一语言学习范式。” 第十一届国际学习表征会议。2022 年。
[9] Xuezhi Wang, Jason Wei, Dale Schuurmans, Quoc Le, Ed Chi 和 Denny Zhou. 2022a. 自一致性改进语言模型中的思维链推理。arXiv 预印本 arXiv:2203.11171。
[10] Kojima, Takeshi 等人。“大型语言模型是零-shot 推理者。” arXiv 预印本 arXiv:2205.11916(2022 年)。
[11] Wang, Xuezhi 等人。“自一致性改进语言模型中的思维链推理。” arXiv 预印本 arXiv:2203.11171(2022 年)。
[12] Zhou, Denny 等人。“从最少到最多的提示使大型语言模型能够进行复杂推理。” arXiv 预印本 arXiv:2205.10625(2022 年)。
[13] Zhao, Zihao 等人。“使用前校准: 提升语言模型的少样本性能。” 国际机器学习会议。PMLR,2021 年。
[14] Hoffmann, Jordan 等人。“训练计算最优的大型语言模型。” arXiv 预印本 arXiv:2203.15556(2022 年)。
[15] Rae, Jack W. 等人。“扩展语言模型: 方法、分析与训练 Gopher 的见解。” arXiv 预印本 arXiv:2112.11446(2021 年)。
Llama 2 中的停止生成挑战
潜在解决方案的探索
·
关注 发表在 数据科学前沿 · 9 分钟阅读 · 2023 年 9 月 10 日
--
Llama: 图片由柳德米拉·舒瓦洛娃提供
Meta 发布的 Llama 2 在社区中引发了兴奋,标志着一个新的时代的开始,此前大型语言模型的良好表现只能通过公司特定的 API 访问。
然而,重要的是要承认这些模型固有的一些缺陷。其中,生成停止问题尤为突出。我的个人经历表明,这些模型往往难以确定合适的‘停止’点,使它们在何时结束文本生成方面感到不确定。
在这篇博客文章中,我将深入探讨最小的 Llama 2 模型——Llama 2–7b 模型中停止生成失败的问题,并讨论几种潜在的解决方案。接下来的实现可以在这个 GoogleGolab notebook中找到,运行时类型为 T4。
停止生成失败
在本节中,我们将利用 Google Colab 中配备充足高 RAM 资源的 T4 GPU 来操作 Llama 2–7b 模型(2.21 积分/小时)。需要记住的是,T4 GPU 具有 16 GB 的 VRAM 容量,刚好可以容纳 Llama 2–7b 的权重(7b × 2 字节 = 14 GB 的 FP16)。
为了有效管理 VRAM 的使用,我们将采用一种叫做量化的技术。量化是一种在推理过程中通过使用低精度数据类型表示权重和激活值来最小化计算和内存需求的方法。
现在让我们深入研究以下代码片段。在这里,我们将演示如何加载带有 Bite 和 Byte 配置的“meta-llama/Llama-2–7b-chat-hf”模型,并基于该加载模型设置文本生成管道。
# bit and byte configuration
bnb_config = transformers.BitsAndBytesConfig(
load_in_4bit=True, # 4-bit quantization
bnb_4bit_quant_type='nf4', # Normalized float 4
bnb_4bit_use_double_quant=True, # Second quantization after the first
bnb_4bit_compute_dtype=bfloat16 # Computation type
)
model_id= "meta-llama/Llama-2-7b-chat-hf"
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
# Llama 2 Tokenizer
tokenizer =transformers.AutoTokenizer.from_pretrained(model_id)
# Llama 2 Model
model = transformers.AutoModelForCausalLM.from_pretrained(
model_id,
trust_remote_code=True,
quantization_config=bnb_config,
device_map='auto',
)
model.eval()
# Difine pipeline
transformers.pipeline(model=model,
tokenizer=tokenizer,
return_full_text=True, # langchain expects the full text
task='text-generation',
# we pass model parameters here too
max_new_tokens=500, # max number of tokens to generate in the output
temperature=0.1,
)
这个配置使我们能够在一分钟以内为给定提示生成文本。让我们用一个简单的问题进行测试:“我在巴黎可以做些什么?”下面是答案(请注意,由于温度设置不同,您的结果可能会有所不同)。
Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. everybody's dream destination. Here are some of the top things to do in Paris:
1\. Visit iconic landmarks: The Eiffel Tower, the Louvre Museum, Notre-Dame Cathedral, Arc de Triomphe, and the Champs-Élysées are must-visit attractions.
2\. Explore museums and art galleries: The Louvre, Musée d'Orsay, and Centre Pompidou are some of the world's most famous museums.
3\. Take a river cruise: A Seine River cruise is a great way to see the city's landmarks and bridges while learning about its history and culture.
4\. Discover the city's neighborhoods: From the fashionable Champs-Élysées to the charming Montmartre, each Parisian neighborhood has its own unique character and charm.
# I skipped some text here because of the limt of space#
Of course, there are countless
很明显,这个模型难以生成令人满意的响应;它似乎在确定何时结束输出方面存在困难。通过对生成的文本进行分词处理,可以明显看出最终的标记不是 2,而 2 代表了模型分词器中的 eos(序列结束)标记。
通过仔细检查模型提供的标记分数(概率),我注意到token_id 2 (eso_token_id) 的分数为“-inf。”这意味着它不可能被生成。
问题解决尝试
在本节中,我们将探讨几种旨在解决当前问题的潜在解决方案。需要记住的是,这里讨论的解决方案代表了积极的努力,但它们可能并不总是能解决所面临的问题。
Logits 处理器
像 Llama 2 这样的语言模型处理一系列文本标记作为输入,并根据从初始标记到当前标记的上下文生成下一标记的条件概率。鉴于此,值得考虑在接近最大标记限制时手动调整这些概率,以提高遇到 eos 标记的可能性。我们通过定义一个名为“EosTokenRewardLogitsProcessor”的自定义 Logits 处理器来实现这一点,该处理器具有两个初始输入 eos_token_id 和 max_length,其中后者表示模型应生成 eos 标记的最大长度:
class EosTokenRewardLogitsProcessor(LogitsProcessor):
def __init__(self, eos_token_id: int, max_length: int):
if not isinstance(eos_token_id, int) or eos_token_id < 0:
raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
if not isinstance(max_length, int) or max_length < 1:
raise ValueError(f"`max_length` has to be a integer bigger than 1, but is {max_length}")
self.eos_token_id = eos_token_id
self.max_length=max_length
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
# start to increese the reward of the eos_tokekn from 80% max length progressively on length
for cur_len in (max(0,int(self.max_length*0.8)), self.max_length ):
ratio = cur_len/self.max_length
num_tokens = scores.shape[1] # size of vocab
scores[:, [i for i in range(num_tokens) if i != self.eos_token_id]] =\
scores[:, [i for i in range(num_tokens) if i != self.eos_token_id]]*ratio*10*torch.exp(-torch.sign(scores[:, [i for i in range(num_tokens) if i != self.eos_token_id]]))
scores[:, self.eos_token_id] = 1e2*ratio
return scores
在类的“call”方法中,我们根据序列的长度增强 eos_token 的概率(得分)。当长度接近指定最大长度的 80%时,我们将 eos_token_id 的得分设置为 1e2 乘以长度比例,并相应地调整其他令牌的得分。
现在在管道的定义中声明 logits 处理器:
pipe = transformers.pipeline(model=model,
tokenizer=tokenizer,
return_full_text=True, # langchain expects the full text
task='text-generation',
# we pass model parameters here too
#stopping_criteria=stopping_criteria, # without this model rambles during chat
logits_processor=logits_process_list,
max_new_tokens=500, # max number of tokens to generate in the output
temperature=0.1,
)
使用相同的提示“What Can I do in Paris”再次运行管道,我们得到:
Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere.
它运行得很好!我们即使得到的答案可能看起来很简短,但它是完整的。
微调
如果模型未能生成 EOS 令牌,为什么不考虑指示它这样做呢?通过使用包括以 EOS 令牌结尾的答案的数据集来微调模型,以提高模型性能的概念无疑是一个值得探索的有前景的途径。
在这一部分,我将毫不掩饰地使用这篇博客文章中奠定的基础,这篇文章采用了参数高效的微调(PEFT)方法,如 QLoRA,来微调 Llama 2–7b 模型。与其前身 LoRA 类似,QLoRA 利用一小组可训练的参数(适配器),同时保持核心模型参数不变。它引入了两个值得注意的创新:4-bit NormalFloat (NF4),一种对正常数据信息理论上最优的数据量化方法,以及双重量化。欲了解更多深入信息,请参考原始论文,如果您对该主题有进一步的兴趣。
让我们在一个名为‘timdettmers/openassistant-guanaco’的数据集上训练模型,您可以在 hugging face 数据库中找到这个数据集。该数据集的格式如下,其中人类和助手的对话由“###”分隔。

图片作者:“timdettmers/openassistant-guanaco”的数据集
在训练之前,我们需要将数据转换为 Llama 2 提示模板:
<s>[INST] <<SYS>>
{your_system_message}
<</SYS>> {user_message_1} [/INST]
我将在这里跳过数据集转换的细节。现在让我们看看以下代码给出的训练的主要部分:
# Load LoRA configuration
peft_config = LoraConfig(
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
r=lora_r,
bias="none",
task_type="CAUSAL_LM",
)
# Set supervised fine-tuning parameters
trainer = SFTTrainer(
model=model,
train_dataset=dataset,
peft_config=peft_config,
dataset_text_field="text",
max_seq_length=max_seq_length,
tokenizer=tokenizer,
args=training_arguments,
packing=packing,
)
# Train model
trainer.train()
在一个包含指令和响应的数据集中,我们的方法涉及使用监督训练器(SFTainer)与 QLoRA 方法相结合,以微调语言模型(LLM)中的权重参数。我们的主要目标是最小化生成的答案与真实响应之间的差异,真实响应作为我们的参考标签。
在这个配置中,一个重要的参数是“lora r”,它代表了一个相对较小的值,涉及到秩分解权重矩阵的第二维和第一维。训练仅在这两个矩阵上进行,补充了现有的权重。
我们训练模型 250 步,训练损失如下图所示:

图片来源:Llama 2 训练 250 步的损失
现在让我们使用微调后的模型运行管道。这一次,我们得到了:
Paris, the capital of France, is known for its stunning architecture, rich history, and cultural attractions. Here are some popular things to do in Paris:
1\. Visit iconic landmarks: The Eiffel Tower, Arc de Triomphe, Notre-Dame Cathedral, and the Louvre Museum are some of the most famous landmarks in Paris.
2\. Explore museums and art galleries: Paris is home to many world-class museums and art galleries, including the Louvre, the Musée d'Orsay, and the Pompidou Center.
3\. Enjoy the city's gardens and parks: Paris has several beautiful gardens and parks, including the Luxembourg Gardens, the Tuileries Garden, and the Champ de Mars.
4\. Take a river cruise: A river cruise along the Seine is a great way to see the city's sights and learn about its history.
5\. Shop on the Champs-Élysées: The famous avenue is lined with high-end boutiques, department stores, and cafes.
6\. Attend a cabaret show: The Moulin Rouge is a famous cabaret that offers a lively and entertaining show.
7\. Visit Montmartre: The charming neighborhood of Montmartre is known for its artistic vibe, street performers, and stunning views of the city.
8\. Take a food tour: Paris is known for its culinary delights, and a food tour is a great way to sample the city's famous cuisine.
9\. Visit the Palace of Versailles: The Palace of Versailles is a beautiful royal palace located just outside of Paris.
10\. Attend a fashion show: Paris is the fashion capital of the world, and attending a fashion show is a great way to see the latest trends and styles.
These are just a few of the many things to do in Paris. The city has something for everyone, from history and culture to food and entertainment.
这是一个相当美丽的答案!
附加:给出不同的提示
最终,我们手头有一个简单而有效的技巧,特别是当模型的冗长不是问题时。我们可以在提示中明确说明我们需要一个简洁的回答。例如,当我问模型:“在巴黎我可以做什么?请用五句话或更少的句子回答,”它会提供:
Sure! Here are five things you can do in Paris: 1\. Visit the Eiffel Tower, a iconic landmark that offers stunning views of the city. 2\. Explore the Louvre Museum and see famous artworks like the Mona Lisa. 3\. Stroll along the Seine River and enjoy the city's picturesque bridges and charming streets. 4\. Visit the Notre-Dame Cathedral, a beautiful and historic church. 5\. Enjoy a delicious French meal at a traditional bistro or restaurant.
这是一个简短但干净且完整的回答。
停止标准:一个失败的尝试
对于感兴趣的用户,Hugging Face 推出了另一个名为 StoppingCriteria 的 API,旨在建立特定条件以强制序列停止。然而,当涉及到定义一个在遇到某些标记(例如‘\n’)时停止模型的自定义标准时,它可能无法提供一个全面的解决方案。例如,我尝试创建一个 StopOnTokens 类:
# define custom stopping criteria object
class StopOnTokens(StoppingCriteria):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
for stop_ids in stop_token_ids:
if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
return True
return False
stopping_criteria = StoppingCriteriaList([StopOnTokens()])
但是,模型仍然无法给出完整的回答。
结论
在这篇博客文章中,我强调了 Llama 2 中生成停止的问题,并介绍了几种临时解决方案。再次,我跳过了很多实施细节,我建议你深入查看我的笔记本。

图片由Jose Aragones提供
但是,需要注意的是,这些解决方案旨在在短期内提高响应的用户友好性,但我们迫切期待一个永久的解决方案来解决这个问题。
检测 AI 生成文本的挑战
我们将深入探讨检测 AI 生成文本的挑战,以及在实际应用中所使用技术的有效性。
·
关注 发表在 Towards Data Science ·15 min read·Sep 27, 2023
--
图片由 Houcine Ncib 提供,来源于 Unsplash
与 Naresh Singh 共同撰写。
目录
-
介绍
-
建立文本来源检测的直觉
-
语言模型的困惑度是什么?
-
计算语言模型预测的困惑度
-
检测 AI 生成文本
-
错误信息
-
接下来是什么?
-
结论
介绍
现在写文章或帖子用的 AI 辅助技术无处不在!ChatGPT 开启了基于语言的 AI 的众多应用,AI 在任何类型内容生成中的使用已经达到了空前的高度。
然而,在诸如创意写作之类的学校作业中,学生需要创建自己的内容。然而,由于 AI 在这些任务中的流行和有效性,学生可能会被诱惑去使用它。在这种情况下,教师拥有能够检测 AI 生成内容的可靠工具就显得尤为重要。
本文旨在提供对构建此类工具的直观理解以及技术规格。它面向那些希望直观理解 AI 检测工作原理的读者以及希望构建此类工具的技术观众。
让我们直接进入主题!
建立文本来源检测的直观理解
从高层次来看,我们试图回答的问题是,“AI 语言模型(如 GPT-3)生成全部或部分文本的可能性有多大?”
如果你退一步看,你会意识到这是一种典型的日常情境。例如,你母亲对你说以下句子的可能性有多大?
-
亲爱的,请在晚上 8 点之前上床睡觉。
-
亲爱的,请在晚上 11 点后上床睡觉。
我们猜测前者的可能性远高于后者,因为你对周围世界已经有了一定的理解,并且对哪些事件更可能发生有了感觉。
这正是语言模型的工作原理。语言模型学习有关周围世界的知识,特别是语言。它们学习在给定不完整句子的情况下预测下一个标记或单词。
在上面的例子中,如果你被告知你母亲正在说话,而迄今为止所说的是“亲爱的,请在睡觉”,那么这个句子的最可能的继续就是“在晚上 8 点之前”,而不是“在晚上 11 点之后”。用技术术语来说,我们说你会对听到第二句话而非第一句话感到更多的困惑。
让我们深入探讨一下在语言模型的背景下困惑度的含义。
语言模型的困惑度是什么?
根据 dictionary.com,困惑度被定义为
困惑的状态;混乱;不确定性。
在现实世界中,如果你遇到一个你没有预料到的情况,你会比遇到一个你预料到的情况感到更多的困惑。例如,当你在路上行驶时,如果你看到一个交通信号灯,那么你会比看到一只穿过街道的山羊时感到更少的困惑。
同样,对于一个试图预测句子中下一个单词的语言模型来说,我们说如果模型用一个我们没想到的单词来完成句子,它会让我们感到困惑,相比之下,如果它使用我们期待的单词。以下是一些例子。
低困惑度的句子看起来会是这样的
-
外面的天气晴朗。
-
对不起,我错过了航班,未能及时到达国家公园。
具有高困惑度的句子可能如下所示
-
外面的天气很好。
-
我错过了光线,无法到达国家公园。
接下来,让我们看看如何计算语言模型做出的预测的困惑度。
计算语言模型预测的困惑度
语言模型的困惑度与能够意料之外地预测句子下一个标记(单词)的概率相关。
假设我们用一个包含 6600 个标记的词汇表训练语言模型,并运行一个预测步骤以让模型预测句子中的下一个标记。假设选择该标记的概率是 5/6600(即,该标记的概率不是很高)。其困惑度是概率的倒数,即 6600/5 = 1320,这表明 我们对这个建议感到非常困惑。如果选择该标记的概率是 6000/6600,那么困惑度将是 6600/6000 = 1.1,这表明我们对这个建议感到仅仅是稍微困惑。
因此,模型在更可能的预测上的困惑度低于模型在不太可能的预测上的困惑度。
预测句子“x”中所有标记的困惑度形式上定义为标记概率乘积的倒数的 N 次根。

然而,为了确保数值稳定性,我们可以用对数函数来定义它。

这是 e(2.71828)到预测标记为真实标记的平均负对数似然的幂。
训练和验证困惑度
模型的训练和验证困惑度可以直接从批次或时期的损失计算得出。
预测困惑度
由于需要一组每个预测的真实标签,因此无法计算预测困惑度。
计算困惑度的 PyTorch 代码
假设变量 probs 是一个形状为 (sequence_length,) 的 torch.Tensor,它包含了语言模型在序列中该位置上预测的真实标记的概率。
可以使用以下代码计算每个标记的困惑度。
token_perplexity = (probs.log() * -1.0).exp()
print(f"Token Perplexity: {token_perplexity}")
可以使用以下代码计算样本困惑度。
# The perplexity is e^(average NLL).
sentence_perplexity = (probs.log() * -1.0).mean().exp().item()
print(f"Sentence Perplexity: {sentence_perplexity:.2f}")
接下来,让我们查看一下计算给定句子的每个标记概率的代码。
def get_pointwise_loss(self, inputs: List[str], tok):
self.model.eval()
all_probs = []
with torch.inference_mode():
for input in inputs:
ids_list: List[int] = tok.encode(input).ids
# ids has shape (1, len(ids_list))
ids: Torch.Tensor = torch.tensor(ids_list, device=self.device).unsqueeze(0)
# probs below is the probability that the token at that location
# completes the sentence (in ids) so far.
y = self.model(ids)
criterion = nn.CrossEntropyLoss(reduction='none', ignore_index=0)
# Compute the loss starting from the 2nd token in the model's output.
loss = criterion(y[:,:,:-1], ids[:,1:])
# To compute the probability of each token, we need to compute the
# negative of the log loss and exponentiate it.
loss = loss * -1.0
# Set the probabilities that we are not interested in to -inf.
# This is done to make the softmax set these values to 0.0
loss[loss == 0.0] = float("-inf")
# probs holds the probability of each token's prediction
# starting from the 2nd token since we don't want to include
# the probability of the model predicting the beginning of
# a sentence given no existing sentence context.
#
# To compute perplexity, we should probably ignore the first
# handful of predictions of the model since there's insufficient
# context. We don’t do that here, though.
probs = loss.exp()
all_probs.append(probs)
#
#
return all_probs
#
现在我们对语言模型的功能以及如何计算每个标记和每个句子的困惑度有了一些了解,让我们试着把这些信息结合起来,看看如何利用这些信息构建一个可以检测文本是否由 AI 生成的工具。
检测 AI 生成的文本
我们已经拥有检查文本是否由 AI 生成所需的所有成分。这里是我们需要的一切:
-
我们希望检查的文本(句子或段落)。
-
该文本的分词版本,使用与该模型的训练数据集相同的分词器进行分词。
-
经过训练的语言模型。
利用以上 1、2 和 3,我们可以计算以下内容:
-
模型预测的每个 token 的概率。
-
使用每个 token 的概率计算的困惑度。
-
整个句子的总困惑度。
-
模型在训练数据集上的困惑度。
要检查文本是否由 AI 生成,我们需要将句子的困惑度与模型的困惑度(经过调整的系数 alpha)进行比较。如果句子的困惑度高于调整后的模型困惑度,则可能是人类撰写的文本(即不是 AI 生成的)。否则,可能是 AI 生成的。原因在于,我们期望模型不会对它自己生成的文本感到困惑,因此如果它遇到一些自己不会生成的文本,那么就有理由相信这些文本不是 AI 生成的。如果句子的困惑度小于或等于经过调整的模型训练困惑度,那么很可能是使用该语言模型生成的,但我们不能非常确定。这是因为一个人也有可能写出这样的文本,而且这正好是模型也可能生成的内容。毕竟,模型是在大量人类写作的文本上训练的,所以在某种意义上,模型代表了一种“普通人类的写作”。

上述公式中的ppx(x)表示输入“x”的困惑度。
接下来,让我们查看一些人类撰写的文本与 AI 生成的文本的例子。
AI 生成的文本与人类撰写的文本的例子
我们编写了一些 Python 代码,根据每个 token 相对于模型困惑度的困惑度对其着色。如果我们不考虑其困惑度,第一个 token 总是呈黑色。困惑度小于或等于模型困惑度的 token 被着色为红色,表明它们可能是 AI 生成的,而困惑度较高的 token 被着色为绿色,表明它们肯定不是 AI 生成的。

句子前方方括号中的数字表示使用语言模型计算的句子困惑度。注意某些单词有部分红色和部分蓝色。这是因为我们使用了子词分词器。
这是生成上述 HTML 的代码。
def get_html_for_token_perplexity(tok, sentence, tok_ppx, model_ppx):
tokens = tok.encode(sentence).tokens
ids = tok.encode(sentence).ids
cleaned_tokens = []
for word in tokens:
m = list(map(ord, word))
m = list(map(lambda x: x if x != 288 else ord(' '), m))
m = list(map(chr, m))
m = ''.join(m)
cleaned_tokens.append(m)
#
html = [
f"<span>{cleaned_tokens[0]}</span>",
]
for ct, ppx in zip(cleaned_tokens[1:], tok_ppx):
color = "black"
if ppx.item() >= 0:
if ppx.item() <= model_ppx * 1.1:
color = "red"
else:
color = "green"
#
#
html.append(f"<span style='color:{color};'>{ct}</span>")
#
return "".join(html)
#
从上面的例子可以看出,如果模型将某些文本检测为人类生成的,那么这些文本确实是人类生成的,但如果模型将文本检测为 AI 生成的,那么有可能并非 AI 生成的。那么为什么会发生这种情况呢?让我们接着看!
假阳性
我们的语言模型经过大量由人类编写的文本训练。通常很难检测到某个特定人是否(数字化)编写了某些内容。模型的训练输入包含了许多不同的写作风格,很可能是由大量不同的人编写的。这使得模型学习了许多不同的写作风格和内容。很可能你的写作风格与模型训练中某些文本的写作风格非常接近。这就是假阳性出现的原因,也是模型不能确定某些文本是否是 AI 生成的原因。然而,模型可以确定某些文本是人类生成的。
OpenAI: OpenAI 最近宣布将停止其用于检测 AI 生成文本的工具,理由是准确率较低(来源:Hindustan Times)。
AI 分类器工具的原始版本从一开始就存在某些局限性和不准确性。用户需要手动输入至少 1,000 个字符的文本,然后 OpenAI 会分析这些文本以分类为 AI 或人类编写。不幸的是,该工具的表现不尽如人意,因为它仅正确识别了 26% 的 AI 生成内容,并且错误地将 9% 的人类编写文本标记为 AI。
这是 OpenAI 的博客文章。看起来他们使用了与本文中提到的不同的方法。
我们的分类器是一个语言模型,经过微调,使用的是同一主题的人工编写文本和 AI 编写文本的对照数据集。我们从多个我们认为是由人类编写的来源收集了这个数据集,比如预训练数据和提交给 InstructGPT 的人类演示。我们将每篇文本分为提示和回应。在这些提示上,我们生成了来自我们和其他组织训练的各种语言模型的回应。对于我们的网页应用,我们调整了置信度阈值,以保持低的假阳性率;换句话说,我们只有在分类器非常确信的情况下才将文本标记为可能的 AI 编写。
GPTZero: 另一个流行的 AI 生成文本检测工具是 GPTZero。看起来 GPTZero 使用 困惑度和突发性 来检测 AI 生成的文本。“突发性指的是某些词语或短语在文本中以突发的形式出现。换句话说,如果一个词在文本中出现了一次,它很可能会在接近的地方再次出现”(来源)。
GPTZero 宣称具有非常高的成功率。根据 GPTZero FAQ,“在 0.88 的阈值下,85% 的 AI 文档被分类为 AI,99% 的人工文档被分类为人工。”
这种方法的普遍性
这篇文章提到的方法不具有很好的一般化性。我们的意思是,如果你有三个语言模型,比如 GPT3、GPT3.5 和 GPT4,那么你必须将输入文本通过这三个模型,并检查它们的困惑度,以确定文本是否由其中任何一个模型生成。这是因为每个模型生成文本的方式略有不同,它们都需要独立评估文本,以确定是否有模型生成了该文本。
随着到 2023 年 8 月世界上大量语言模型的普及,似乎不太可能检查任何文本是否来自世界上的任何语言模型。
实际上,每天都有新的模型在训练,跟上这种快速进展似乎非常困难。
以下示例显示了让我们的模型预测 ChatGPT 生成的句子是否为 AI 生成的结果。如你所见,结果是混合的。

紫色框中的句子被我们的模型正确识别为 AI 生成的,而其余的则被错误识别为人工撰写的。
这可能发生的原因有很多。
-
训练语料库规模: 我们的模型在非常少的文本上进行训练,而 ChatGPT 在数 TB 的文本上进行训练。
-
数据分布: 我们的模型在与 ChatGPT 不同的数据分布上进行训练。
-
微调: 我们的模型只是一个 GPT 模型,而 ChatGPT 是针对聊天类响应进行微调的,使其生成的文本具有略微不同的语调。如果你有一个生成法律文本或医学建议的模型,那么我们的模型在这些模型生成的文本上也会表现不佳。
-
模型规模: 我们的模型非常小(少于 100M 参数,而 ChatGPT 类模型有超过 200B 参数)。
很明显,如果我们希望提供合理高质量的结果以检查文本是否为 AI 生成的,我们需要一种更好的方法。
接下来,我们来看一下关于这个话题在互联网中流传的一些错误信息。
错误信息
一些文章对困惑度的解释不正确。例如,如果你在 google.com 搜索 “人工撰写内容的困惑度高还是低?”,你会在第一个位置看到 以下结果。

这是不正确的,因为人工撰写的内容通常比 AI 生成的内容具有更高的困惑度。
让我们来看看研究人员在这一领域探索的技术,以期比目前的情况做得更好。
接下来是什么?
我们已经确定,检测 AI 生成文本是一个困难的问题,成功率低到不比猜测更好。让我们看看这一领域的最先进技术,研究人员如何探索以更好地处理这个问题。
水印: OpenAI 和 Google 已承诺为 AI 生成的文本添加水印,以便可以程序化识别。
关于这种水印如何工作的技术细节尚不清楚,两家公司都没有披露相关细节。
即使 OpenAI 和 Google 采用了水印技术,我们也不能确定所有部署的语言模型都会包含水印。人们仍然可能会部署自己的模型来生成文本并将其发布到公共平台。即使公司决定给生成的文本加水印,也不清楚这是否会成为标准,还是每家公司会有自己的专有策略和可能收费的工具来检查文本是否由其 AI 文本生成工具生成。如果这是一个开放标准,人们有可能绕过它,除非它像加密密码一样需要大量计算来解密。如果不是开放标准,那么人们将依赖这些公司提供开放和免费的工具及 API 来进行检查。此外,这些工具在长期中的有效性也是一个问题,因为甚至可能训练模型来接收带水印的 AI 生成文本,并返回没有水印的 AI 生成文本。
这篇文章讨论了一种为 AI 生成文本添加水印的可能技术,并提到这种方法的重大挑战。
个性化: 在我们看来,检测 AI 生成文本的问题在短期内仍将具有挑战性。我们相信,策略需要变得更加侵入性和个性化才能更有效。例如,与其询问某些文本是否由 AI 生成,不如询问这些文本是否由特定的人撰写。然而,这将要求系统能够访问大量该特定人撰写的文本。此外,如果某些文本由多人撰写,如本文所示,问题会变得更加复杂。
让我们看看这种个性化的系统在检测人工撰写文本方面对教育工作者和学生的影响。
如果存在这样的解决方案,教育工作者将更倾向于给学生布置个人作业而不是小组作业。这还将要求每个学生首先提供大量他们自己撰写的文本。这可能意味着在入学前需要在大学里亲自花费几个小时。这无疑会对教授学生团队合作以实现共同目标的重要性产生负面影响。
另一方面,在某些情况下,访问基于 AI 的文本生成技术可以让学生集中精力解决实际问题,例如进行研究或文献研究,而不是花时间以完善的方式撰写他们的学习成果。可以想象,学生们会在数学或科学课程中花更多时间学习概念和技术,而不是写作。那部分可以由 AI 处理。
结论
在本文中,我们建立了如何检测 AI 生成文本的直觉。我们可以使用的主要指标是生成文本的困惑度。我们看到了一些 PyTorch 代码,用于检查给定文本是否可能是 AI 生成的,方法是利用文本的困惑度。我们也看到了这种方法的一些缺点,包括可能出现假阳性。希望这能帮助你理解和欣赏检测 AI 生成文本的细节。
这是一个不断发展的领域,研究人员正努力寻找一种更高准确率检测 AI 生成文本的方法。这项技术对我们生活的影响承诺将是显著的,并且在许多方面仍然未知。
虽然我们讨论了检测 AI 生成文本的技术,但我们假设整个文本要么是人类撰写的,要么是 AI 生成的。实际上,文本往往是部分人类撰写的,部分 AI 生成的,这使得问题变得更加复杂。
如果你想了解更多检测 AI 生成文本的方法,比如使用突发性指标,你可以在这里阅读。
本文中的所有图片(除了第一张)均由作者(们)创作。
中国量产自动驾驶的挑战
原文:
towardsdatascience.com/challenges-of-mass-production-autonomous-driving-in-china-407c7e2dc5d8
以及 2023 年小鹏汽车的最新进展
·发表于 Towards Data Science ·阅读时间 7 分钟·2023 年 6 月 19 日
--
本博客文章基于在 2023 年温哥华 CVPR 的端到端自动驾驶研讨会上的主题演讲,题为“中国量产自动驾驶的实践”。主题演讲的录音可以在这里找到。
自动驾驶是一个令人望而生畏的挑战,特别是在中国,这里的人类驾驶已经是世界上最具挑战性的之一。主要涉及三个因素:动态交通参与者、静态道路结构和交通信号。特别是,交通灯控制信号带来了独特的挑战,因为它们在几何上是静态的,但在语义上是动态的。在接下来的会议中,我们将简要回顾动态对象和静态环境,并深入探讨交通灯这一有趣且特殊的话题。
[## CVPR 2023 自动驾驶研讨会 | OpenDriveLab
我们自豪地宣布今年与我们的合作伙伴 - Vision-Centric… 合作推出的四项全新挑战。
opendrivelab.com](https://opendrivelab.com/e2ead/cvpr23.html?source=post_page-----407c7e2dc5d8--------------------------------)
动态与静态挑战
动态交通参与者,如易受伤害的道路使用者(VRUs),给中国的自动驾驶车辆带来了重大挑战。VRUs 通常是不可预测的,以不同的姿态出现,常常出现在司机最意想不到的地方。大型动物可能会突然出现在乡村道路上,而宠物可能会漫游到城市街道上。此外,装载物品的车辆或三轮车的具体类型也很难确定。考虑一下中间行的最后一张照片,它实际上即使对于人类来说,第一次看到场景也非常具有挑战性。那辆装满树枝的车辆无意中完全隐匿在环境中。

各种动态道路使用者
静态道路结构和拓扑对自动驾驶车辆也可能构成重大挑战。例如,这里展示的复杂交叉口突显了需要解决的复杂程度。虽然看起来像是科幻电影中的截图,但这个交叉口实际上是在谷歌地球上可以查看的真实地点。

复杂交叉口的卫星图像
如果我们放大,就会发现一个有趣的道路元素,这可能在中国是独一无二的——左转待行区。它旨在提高左转交通的通行量,使更多的车辆在一个交通信号周期内通过交叉口。注意,这种设计可能并不对称,每个方向的设计都是根据交通模式单独设计的。我们甚至可以找到关于它及其效果的学术论文。尽管它的提出出于良好的意图,但对于新司机和自动驾驶车辆来说,可能真的很困惑。
在有待行区的交叉口左转涉及两个步骤。它们都涉及不同的交通信号灯组合。在这里,我只展示了最常见的交通信号灯模式。交通信号灯组合可能更复杂,有时涉及专门用于待行区的特殊交通信号灯。

左转待行区由科学论文支持
角落案例之王:交通信号灯
现在我们可以深入探讨交通信号灯的所有角落案例。交通信号灯可能是拥有最多长尾角落案例的对象类别。交通信号灯的感知复杂有两个原因。首先,我们必须识别交通信号灯的位置、类型和颜色,然后我们还需要知道在我们检测到的所有交通信号灯中(这里有六个),我们的车辆应该关注哪一个。为了做出这个决定,必须获得交通信号灯与不同车道之间的正确匹配。

中国典型的交通信号灯场景
一种特殊类型的灯是专为公交车设计的交通信号灯。我们需要正确识别它们有两个不同的原因。首先,为了规划和控制自车,我们需要识别它们以正确忽略它们,因为它们可能传递与我们需要关注的信号灯冲突的信息,从而给我们的自动驾驶车辆带来困惑。然而,要预测附近公交车的潜在操作,我们也需要正确了解其状态。

公交车交通信号灯
中国为公交车设计的交通信号灯有多种形式,包括标有“BRT”、“SRT”、“Bus”或单个字母“B”的 LED 灯。它们还可以具有特定的中文字符,如“公交”或“有轨电车”,有时还包括描绘可爱小公交车的图标。除了这些特征外,还可能包括交通标志修饰符,这使得自动驾驶车辆必须检测和识别这些特征,并将其准确地与相应的交通信号灯关联起来。


多灯泡交通信号灯
除了专门为公交车设计的交通信号灯外,另一种复杂的交通信号灯是多灯泡交通信号灯。与传统交通信号灯一次只点亮一个灯泡不同,多灯泡交通信号灯可能在同一个插座中同时点亮多个灯泡。因此,检测交通信号灯盒是不够的;同样重要的是要检测单个灯泡,并准确解释它们的语义。
在多灯泡交通信号灯的附加图像中,我们还看到了一些额外的数字。它们是倒计时计时器,显示下一个颜色变化的时间。我们在北美很常见到行人倒计时计时器,但这些计时器是为车辆设计的。如果正确使用这些信息,它们可能有助于改善行驶的顺畅度。

倒计时计时器交通信号灯
倒计时计时器可以采取多种形式并以不同的方式呈现。它们可以是独立显示的,也可以与交通信号灯系统集成。数字的格式可能有所不同,包括前导零的使用与否,使用的字体也可能有所不同,有些字体更具艺术性。此外,还有一些交通信号灯设计成进度条的样式。这涉及到一个动画,其中进度条逐渐缩短,然后变成不同颜色的完整进度条。虽然这种设计可能被认为是最创新的,但它也可能给我们的感知工程师带来挑战。

左转等待区交通信号灯
最后,还有专用于等待区的交通信号灯,它们可以以图标或文本的形式出现。图标类型通常还涉及动画,灯光逐渐点亮以引导你到等待区。文本类型可能在 LED 显示屏或交通标志牌上。对于文本,没有标准模式,这需要光学字符识别(OCR)和一些自然语言处理来提取语义。
XNGP:Xpeng 的自动驾驶解决方案
记录了大规模生产的 Xpeng 车辆自动左转的情况
记录了大规模生产的 Xpeng 车辆自动右转的情况
尽管面临这些挑战,截至 2023 年,Xpeng 团队成功交付了中国业界领先的自动驾驶解决方案。周围现实(SR)显示屏仅显示从车载感知系统获得的信息。左侧显示准确检测道路几何,而右侧则描绘了常见场景,即一群行人正在过马路。值得注意的是,展示的镜头代表了市场上现有的大规模生产 Xpeng 车辆的实际录制,没有任何后期修改。
如果你对这一切如何实现感兴趣,请参阅另一篇关于 XNet 部署的文章——这是 BEV(鸟瞰视角)范式中的端到端感知堆栈。它是上述大规模生产自动驾驶解决方案的关键推动力,并为完全端到端的自动驾驶解决方案奠定了坚实的基础。
XNet 的配方,Xpeng 的 BEV 感知架构
主要收获
-
在中国驾驶具有挑战性,因为动态物体无处不在,复杂的道路拓扑,以及交通信号灯的极端长尾边缘案例。
-
交通信号灯有许多特殊类型,每种类型都有很多边缘案例。有时需要综合考虑多个交通信号灯,还要考虑附近的倒计时器或交通标志。
-
截至 2023 年,Xpeng Motors 迎接了这些挑战,并在中国交付了业界领先的大规模生产自动驾驶产品。
参考文献
- Ma, Wanjing, et al. “增加左转等待区的信号化交叉口的容量。” 运输研究 A 部分:政策与实践 105 (2017): 181–196.
变化点检测 — 一种贝叶斯方法
在时间序列分析中识别变化点可以提供比你之前认为的更多信息。
·
关注 发表于 Towards Data Science ·13 分钟阅读·2023 年 5 月 4 日
--
变化点分析已经成为许多研究领域的关注点。这种分析涉及到寻找给定时间序列中的突然或突发的变化。根据 Iwata et al.(2018)的定义,变化点分析是“识别变化点的方法,即时间序列概率分布发生变化的时刻。” 根据 Van den Burg 和 Williams(2020)的说法,“时间序列行为中的突发变化时刻通常会引起警报,因为它们可能暗示数据生成过程中的重大变化。”
正如 Aminikhanghahi 和 Cook (2017) 以及 Iwata 等人 (2018) 所示,对这种分析的关注度增加是由于最近的技术发展。这些发展生成大量的数据,这些数据往往需要被密切监控,例如机器人技术、医学、气象、语音和图像识别等。应对这些问题有多种模型和方法。然而,由于本文的目标不是对这些模型进行描述性分析,我建议你参考 Van den Burg 和 Williams (2020) 的工作,以更好地了解这些方法论。在那里你可以找到在线和离线变化点检测的区别;单变量和多变量方法,这些方法可以是参数化的或非参数化的;以及监督或无监督模型。

图片来源 Tech Daily — Unsplash
从一开始,根据 Ehlers (2007) 的定义,“时间序列是指在时间上顺序进行的观测集合”,其主要特征是给定观测对邻近观测的依赖性。时间序列可以是连续的或离散的,在前一种情况下,根据 Ehlers (2007) 的定义,集合 T={t∶ t1< t < t2}, 该序列表示为 {X(t):t ∈ T}。因此,取时间序列 T 的一个观测窗口,包含 n 个观测值,我们得到一个表示为 {Xm, X(m+1),…, Xn } 的时间序列。根据 Aminikhanghahi 和 Cook (2017,第 3 页) 的说法,“变化点的检测可以定义为一个假设检验问题,存在两个备选方案”,即“原假设 H0: ‘没有变化’” 和备择假设 H1: ‘发生了变化’”。
从一开始
所以,如果你喜欢编程,是时候启动 Jupyter Notebook 并开始进行一些模拟和分析,我们将通过这种方法进行“随机漫步”。让我们导入以下包:
import numpy as np
from numpy.random import seed
from numpy.random import randn
import random
import datetime
import matplotlib.pyplot as plt
import seaborn as sns
import math
import decimal
from scipy import stats
np.seterr(divide='raise') #Make sure you set this
我们将从一个时间序列 y(t) 开始,该序列由两个时间序列组成,y(t1) 的均值是 μ=1 加上一些噪声,y(t2) 的均值是 μ=2 加上一些噪声,两个序列都有 30 个观测值。正如你可能想的那样,提议的时间序列将会有一个显著的变化点。
yt1 = u1 + 0.1*randn(30)
yt2 = u2 + 0.1*randn(30)
y = np.concatenate((yt1, yt2), axis=0)
plt.figure(figsize=(16, 10))
plt.plot(y)

一个具有两个不同均值的时间序列 — 作者提供的图片
如果你了解贝叶斯统计,你知道任何模型的构建基本上由 3 个分布组成。先验分布h(θ)* 反映了我们对问题的先验知识。似然函数f(x|θ)* 反映了获得的数据,并且必须纳入到先验分布中。这将导致一个后验分布h(θ|x)*,这是我们感兴趣的。这就是我们如何使用贝叶斯定理。
模糊聚类
目前,我们面临的第一个(也是最重要的?)问题是从我们构建的时间序列中获得一个先验分布——这是模型的第一部分。问题是:我们没有! 如果你处理的是时间序列,一旦我们得到一个先验分布,大多数任务就已经完成了。
D’Angelo et al.(2011)采取了一种有趣的方法来解决这个问题。他使用 Kohonen 网络对时间序列进行聚类。与硬聚类不同,Kohonen 网络是一种模糊聚类算法,这意味着任何给定点 X 与组 A 的关联概率为p。这种关联由函数fA(X)给出,该函数将每个点在 A 中的实数与区间[0, 1]中的数值关联起来,表示 X 在 A 中的隶属度。
对 Kohonen 网络的完整和更好解释,你可以参考 Kohonen(1990)和 Haykin(2007)。使用 Python,我构建了这样的网络,使用了两个函数:
def Center_Kohonen(y, X=0, K=2, alfa=0.8, C=500):
# This kohonen network implementation finds two centers in a time series
#Params:
#Y = Time Series
#M = Number of input patterns
#N = Dimension of the input pattern (for our case it will be 1)
#K = number of neurons, for the proposed problem (number of centers)
M = y.shape[0]
N = 1
f = 0
#Initializin the neurons' weights
I = y.argsort(axis=0) # Sorted Indexes
Y = np.sort(y) # Sortes Points in the Time Series
c1 = Y[0:7] # Beginning of the series
c2 = Y[M-7:M] # End of the series
#Adjusting the values
while np.std(c2) > 0.1: #As long as the standard deviation is greater than 0.1, replace the highest value with the mean
ma = c2.argmax()
c2[ma] = np.mean(c2);
y[I[ma+60-7]] = np.mean(c2);
while np.std(c1) > 0.1: #As long as the standard deviation is greater than 0.1, replace the lowest value with the mean
mi = c1.argmin()
c1[mi] = np.mean(c1)
y[I[mi]] = np.mean(c1);
#Definition of weight values
W = [np.mean(c1), np.mean(c2)]
#Finding centers from Kohonen's network training
for l in range(1, C+1):
alfa=alfa*(1-(C-(C-l))/C)
for i in range(0, M): #For each value in the time series
a=999
for j in range(0, K): #Where K is the number of cluster
if np.linalg.norm([(y[i]-W[j])], 2) < a:
a = np.linalg.norm([(y[i]-W[j])], 2)
f = j
W[f] = W[f]+alfa*(y[i]-W[f])
return c1, c2, I, Y, W, a, alfa
def Fuzzy_Set(y, W):
# This program finds membership values for a time series from previously specified centers.
# Where y is the time series and c a vector with the found centers
center_1 = []
center_2 = []
n = y.shape[0]
l = 2
# Finding the membership association for each point in the time series
for i in range(0, l):
for t in range(0, n):
sum=0;
for k in range(0, l):
sum = sum+(y[t]-W[k])*(y[t]-W[k])
if i == 0:
center_1.append(np.round(1-((y[t]-W[i])*(y[t]-W[i]))/sum, 3))
else:
center_2.append(np.round(1-((y[t]-W[i])*(y[t]-W[i]))/sum, 3))
return center_1, center_2
如果你顺序调用这两个函数,可能会得到这样的结果图:
c1, c2, I, Y, W, a, alfa = Center_Kohonen(y, X=0, K=2, alfa=0.8, C=500)
center_1, center_2 = Fuzzy_Set(y, W)
plt.figure(figsize=(16, 10))
plt.plot(center_1, 'b') # plotting t, a separately
plt.plot(center_2, 'r') # plotting t, b separately
plt.show()

使用 Kohonen 网络对时间序列 y(t)进行的模糊聚类 — 图片由作者提供
这非常有趣!通过 Kohonen 网络,我们能够对时间序列y(t)进行分割。这个图表显示我们有两个簇,因为我们设置了K=2. 在xlabel中,我们有时间序列中的每个点,而ylabel显示某一点与两个簇中的一个相关联的概率。正如你所见,蓝线告诉我们,在到达第 30 个点之前,所有点更可能(大约~99%)属于第一个组或集合μ1(t)。红线则显示了相反的情况,因为它表示与第二组,即集合μ2(t)的关联。这是合理的,因为我们构建了一个具有两个不同平均值的时间序列,从图示上看,这个图与第一个图相关。
尽管很有趣,我们直到现在还没有真正找到变化点(我们有一些线索),而且这里没有贝叶斯内容。
顺便提一下,大多数情况下,时间序列的点之间的区分并不会那么简单。例如,如果我们构建了时间序列y(t),其中y(t1)的均值为μ=1加上一些噪声,而y(t2)的均值为μ=1.3(而不是 2)加上一些噪声,这样的分割会好吗?我让你试试这个练习…
Metropolis-Hastings 算法登场了
如果你尝试了上述练习,你会发现仅仅使用 Kohonen 网络来找到替代时间序列中的任何变点指示时会遇到麻烦。这是因为 Kohonen 网络不提供变点,而是两组连续变量,表示每个点与给定簇的关联。
但要记住,集合μ1(t)和μ2(t)的值在区间[0,1]内。这意味着μ1(t)和μ2(t)近似于具有不同参数的贝塔分布(你听说过 Kullback–Leibler 吗?)。根据 D'Angelo et al. (2011)的说法,假设变点由m表示,则对于t≤m,我们将有Beta(a,b)分布,对于t>m,我们将有Beta(c,d)分布。考虑到 Beta 分布的性质,如果时间序列中存在变点,参数 a在Beta(a,b)中将大于参数 b,参数 c在Beta(c,d)中将小于参数 d。
问题是: 你如何构建这两个贝塔分布?Metropolis-Hastings 算法是一种马尔可夫链蒙特卡洛方法,最初由 Metropolis et al. (1953)提出,后来由 Hastings (1970)推广。根据 Gelman et al. (2003)的说法,任何马尔可夫链模拟的目标是“创建一个稳态分布由p(θ | x)指定的马尔可夫过程”。充分运行模拟使我们能够获得足够接近稳态和后验分布的分布。后验分布可以用参数θ的特定函数的期望来概括,即∫g(θ)p(θ | x)dθ = E [g(θ) | x]。这样的积分并不轻松,这就是为什么 MCMC 方法用于近似后验分布的好处所在。
Metropolis-Hastings 算法使用拒绝的概念,这意味着它从辅助分布中生成一个值,并以一定的概率接受它。如果你对 MCMC 方法不熟悉,你可能会质疑算法如何拒绝一个抽取的值。我们使用 Hastings (1970)给出的转移规则:

将其简单说来,我们可以利用由模糊聚类给出的两组连续变量来拒绝从给定先验分布中随机抽取的值,用于变点检测。如果你想了解更多关于 MCMC 方法的信息,我建议参考 Gamerman 和 Lopes (2018)。
让我们回到 Jupyter Notebook。下面的函数是针对这个问题的 Metropolis-Hastings 算法的实现。尽管功能强大,但该算法需要一些条件。首先是为每个需要找到的参数设置先验分布。对于参数 m,我们使用了 1 到 60 之间的均匀分布,这意味着算法在时间序列中随机选择变化点候选值。对于参数 a, b, c 和 d,我选择了信息量较弱的伽马分布。该函数还需要参数,这些参数是随机变量的集合 (μ1(t) 或 μ2(t)) 和模拟次数。
def Metropolis_Hastings(center_kohonen, n_sims=1000):
n = len(y)
m = 1 + round((n-1) * np.random.uniform(0, 1))
shape, scale, loc = 10, 0.1, 0
#Lists to save the date for each parameter
a_params = []
b_params = []
c_params = []
d_params = []
m_params = []
#Prior Distributions for the Parameters
a = stats.erlang.rvs(shape, loc=loc, scale=scale, size=1, random_state=None)[0]
b = stats.erlang.rvs(shape, loc=loc, scale=scale, size=1, random_state=None)[0]
c = stats.erlang.rvs(shape, loc=loc, scale=scale, size=1, random_state=None)[0]
d = stats.erlang.rvs(shape, loc=loc, scale=scale, size=1, random_state=None)[0]
for i in range(0, n_sims):
m1 = 1+round((n-1) * np.random.uniform(0, 1));
a1 = stats.erlang.rvs(shape, loc=loc, scale=scale, size=1, random_state=None)[0]
b1 = stats.erlang.rvs(shape, loc=loc, scale=scale, size=1, random_state=None)[0]
c1 = stats.erlang.rvs(shape, loc=loc, scale=scale, size=1, random_state=None)[0]
d1 = stats.erlang.rvs(shape, loc=loc, scale=scale, size=1, random_state=None)[0]
#PARAM A
aux1 = 1
for j in range(0, m):
try:
aux1 = aux1 * (center_kohonen[j] ** (a1-1))
except:
aux1 = aux1
aux2 = 1
for j in range(0, m):
try:
aux2 = aux2 * center_kohonen[j] ** (a-1)
except:
aux2 = aux2
try:
ra = ((math.gamma(a1+b)/math.gamma(a1))**m)*aux1*((((a/a1)**.9)*math.exp(-.1*(a1-a)))**2)/(((math.gamma(a+b)/math.gamma(a))**m)*aux2)
if (min(1, ra) > np.random.uniform(0, 1)):
a=a1
except:
pass
#PARAM B
aux1 = 1
for j in range(0, m):
try:
aux1 = aux1*(1-center_kohonen[j])**(b1-1)
except:
aux1 = aux1
aux2 = 1
for j in range(0, m):
try:
aux2 = aux2*(1-center_kohonen[j])**(b-1)
except:
aux2 = aux2
try:
rb = ((math.gamma(a+b1)/math.gamma(b1))**m)*aux1*((((b/b1)**.9)*math.exp(-.1*(b1-b)))**2)/(((math.gamma(a+b)/math.gamma(b))**m)*aux2)
if (min(1, rb) > np.random.uniform(0, 1)):
b = b1
except:
pass
#PARAM C
aux1 = 1
for j in range(m, n):
try:
aux1=aux1*center_kohonen[j]**(c1-1)
except:
aux1 = aux1
aux2 = 1
for j in range(m, n):
try:
aux2=aux2*center_kohonen[j]**(c-1)
except:
aux2 = aux2
try:
rc = ((math.gamma(c1+d)/math.gamma(c1))**(n-m))*aux1*((((c/c1)**.9)*math.exp(-.1*(c1-c)))**2)/(((math.gamma(c+d)/math.gamma(c))**(n-m))*aux2)
if (min(1, rc) > np.random.uniform(0, 1)):
c = c1
except:
pass
#PARAM D
aux1 = 1
for j in range(m, n):
try:
aux1=aux1*(1-center_kohonen[j])**(d1-1)
except:
aux1 = aux1
aux2 = 1
for j in range(m, n):
try:
aux2=aux2*(1-center_kohonen[j])**(d-1)
except:
aux2 = aux2
try:
rd = ((math.gamma(c+d1)/math.gamma(d1))**(n-m))*aux1*((((d/d1)**.9)*math.exp(-.1*(d1-d)))**2)/(((math.gamma(c+d)/math.gamma(d))**(n-m))*aux2)
if (min(1, rd) > np.random.uniform(0, 1)):
d = d1
except:
pass
#PARAM M
aux1 = 1
for j in range(0, m1):
try:
aux1 = aux1*(center_kohonen[j]**(a-1))*((1-center_kohonen[j])**(b-1))
except:
aux1 = aux1
aux2 = 1;
for j in range(m1, n):
try:
aux2 = aux2*(center_kohonen[j]**(c-1))*((1-center_kohonen[j])**(d-1))
except:
aux2 = aux2
aux3 = 1
for j in range(0, m):
try:
aux3 = aux3*(center_kohonen[j]**(a-1))*((1-center_kohonen[j])**(b-1))
except:
aux3 = aux3
aux4 = 1
for j in range(m, n):
try:
aux4 = aux4*(center_kohonen[j]**(c-1))*((1-center_kohonen[j])**(d-1))
except:
aux4 = aux4
try:
rm = (((math.gamma(a+b)/(math.gamma(a)*math.gamma(b)))**m1)*((math.gamma(c+d)/(math.gamma(c)*math.gamma(d)))**(n-m1))*aux1*aux2)/(((math.gamma(a+b)/(math.gamma(a)*math.gamma(b)))**m)*((math.gamma(c+d)/(math.gamma(c)*math.gamma(d)))**(n-m))*aux3*aux4)
if (min(1, rm) > np.random.uniform(0, 1)):
m = m1
except:
pass
a_params.append(a)
b_params.append(b)
c_params.append(c)
d_params.append(d)
m_params.append(m)
return a_params, b_params, c_params, d_params, m_params
用两个所需参数调用那个函数:在这里,我传递了由函数 Fuzzy_Set 给出的 center_1 和 n_sims=1000
a_params, b_params, c_params, d_params, m_params = Metropolis_Hastings(center_1, n_sims=1000)
fig_dims = (16, 10)
fig, ax = plt.subplots(figsize=fig_dims)
plt.plot(m_params, 'r')
ax.set(xlabel='# Simulations', ylabel='Change Point Candidates (m)')

Metropolis-Hastings 算法模拟 — 图片由作者提供
你现在终于找到了变化点。这张图很有趣,因为它展示了绘图过程。第一次由均匀分布给出的绘制值是 m=55。算法拒绝了它,然后尝试其他值,直到得到令人满意且稳定的结果。经过大约 150 次额外运行后,m=30 的值再也不能被算法拒绝。
由于函数返回了每个参数的采样值,我们也可以绘制它们的值。从参数 m 开始,这是所有变化点的绘制结果。要查看密度图,你可以丢弃前 200 次模拟作为“烧入”:
fig_dims = (16, 10)
fig, ax = plt.subplots(figsize=fig_dims)
ax.set(xlabel='Change Point Candidates', ylabel='Density')
sns.kdeplot(m_params[200:])

变化点候选密度图 — 图片由作者提供
我们还可以使用四个其他参数的均值来创建 Beta 分布,即参数 a, b, c 和 d。正如我们之前讨论的,这些参数对于 Metropolis-Hastings 算法至关重要,因为当拒绝规则必须断言 t≤m 时,我们将有 Beta(a,b) 分布,而 t>m 时将有 Beta(c,d) 分布。
让我们使用变量 a_params, b_params, c_params 和 d_params 来构建这种表示,这些变量包含了 a, b, c 和 d 的采样值
fig, ax = plt.subplots(1, 1, figsize=(16, 8))
ax.set_ylabel('Density')
beta_1 = np.random.beta(np.mean(a_params[200:]), np.mean(b_params[200:]), size=1000)
beta_2 = np.random.beta(np.mean(c_params[200:]), np.mean(d_params[200:]), size=1000)
sns.kdeplot(beta_1, color='b', fill=True, common_norm=False, alpha=.5, linewidth=0)
sns.kdeplot(beta_2, color='r', fill=True, common_norm=False, alpha=.5, linewidth=0)
用参数 a 和 b 的均值构建的第一个 Beta 分布是红色的,第二个是用参数 c 和 d 的均值构建的,呈蓝色。在中间,两个分布的密度较低,我们找到了变化点。毫无疑问,使用这种方法的一个巨大优势是可以得到这样的分布,因为我们可以用它们进行贝叶斯推断,丰富预测模型,甚至使用其他类型的 Monte Carlo 模拟。

两个 Beta 分布的表示 — 图片由作者提供
结论
寻找时间序列中的变化点可以防止系统陷入严重故障。试想一下,必须控制温度的机器。任何突然的变化必须尽快被识别,以便工程师进行检查。或者各种生产设施中的能源消耗。任何过度的消耗都必须被分析,因为它可能表示生产中存在某种偏差或能源泄漏,从而显著影响生产成本。
也就是说,D’Angelo(2011)开发的方法并在这里用 Python 实现,证明在检测给定时间序列中的变化点方面非常有价值。此外,如前所述,这种方法的另一个优点正是我们获得了两个 beta 分布作为输出,这可能非常有用。
参考文献
AMINIKHANGHAHI, Samaneh 和 COOK, J. Diane. 时间序列变化点检测方法综述。知识与信息系统,51,2017 年。
D’ANGELO, Marcos Flávio 等 时间序列变化点检测问题的模糊/贝叶斯方法。运筹学研究,31(2),2011 年。
EHLERS, Ricardo S. 贝叶斯推断。第 5 版。巴拉那联邦大学统计系,2007 年。
FAMA, Eugene. 股票市场价格的随机游走。金融分析师,21,1965 年。
_____________ 有效资本市场: 理论与实证工作的回顾。《金融杂志》,25,1970 年。
GAMERMAN, Dani. LOPES, Hedibert Freitas. 马尔科夫链蒙特卡罗: 贝叶斯推断的随机模拟。第 2 版。佛罗里达州:Chapman & Hall/CRC,2006 年。315 页。
GELMAN, Andrew 等 贝叶斯数据分析。第 2 版。佛罗里达州:Chapman & Hall/CRC,2003 年。668 页。
HASTINGS, W. K. 使用马尔科夫链的蒙特卡罗抽样方法及其应用。生物统计学,57,1970 年。
HAYKIN, Simon. 神经网络:原理与实践。第 2 版。阿雷格里港:Bookman,2007 年。900 页。
IWATA, Takuma 等 使用 10GbE FPGA NIC 加速在线变化点检测算法。在第 24 届国际并行与分布式计算会议上提出。都灵:2018 年。
KOHONEN, Teuvo. 自组织映射。纽约:Springer,1990 年,501 页。
METROPOLIS, Nicholas 等 通过快速计算机计算状态方程。化学物理学杂志,21,1953 年。
OH KJ, 等 开发基于时间的聚类神经网络用于变化点检测:应用于金融时间序列。亚太运营研究杂志,22(1),2005 年。
PAULINO, Carlos Daniel 等 贝叶斯统计。第 2 版。里斯本:卡洛斯·古尔本基安基金会,2018 年,601 页。
VAN DEN BURG, Gerrit J. J., WILLIAMS, Christopher K. I. 变化点检测算法评估。stat.ML,arXiv:2003.06222,2020 年。
在 DAX 测量中随着时间改变你的关注点
原文:
towardsdatascience.com/change-your-focus-over-time-in-dax-measures-292ff012643a
如何确定当你随着时间推移广告产品时投资是否有所回报
·发表于 Towards Data Science ·7 分钟阅读·2023 年 5 月 19 日
--

图片由 David Travis 提供,来源于 Unsplash
介绍
如何衡量广告活动的成功?
尤其是当你想要随着时间推移广告不同品牌时?
一个类似的问题曾由我的客户提问。
他希望分析其品牌的销售情况,并将其与广告品牌的销售情况进行比较,以确定投资是否有所回报。
挑战在于修改数据模型,使得在不破坏现有报告和分析的情况下能够进行此分析。
专注于选择的品牌
第一步是拥有一个随时间变化的广告品牌表。
如下所示:

图 1 — 每月的重点品牌表(作者提供的图)
在此表中,几乎每个月都会找到一个或多个品牌。
现在我可以扩展我的数据模型。
我在 Contoso 数据模型中的起点如下(从整个数据模型中摘录):

图 2 — 初始数据模型(作者提供的图)
如你所见,我有一个品牌名称列,但没有品牌 ID。
在这种情况下,我必须通过 [BrandName] 列创建一个关系。
集成 Focus Brands 表
将 Focus Brands 表添加到我的数据模型后,我添加了以下关系:

图 3 — 新数据模型及新关系(作者提供的图)
我设置了过滤器“聚焦品牌”过滤“产品”来代替“聚焦品牌”和“产品”表之间的双向过滤。但双向过滤也有效,只是引入了一个不必要的过滤方向。我们需要设置正确的过滤方向,以便我们能够根据“产品”表中的属性深入挖掘结果。
由于“产品”表中有多个相同名称的产品,我们必须与“产品”表建立多对多(m:n)关系。
在这种情况下创建数据模型的另一种方法是添加一个包含每个品牌一行的“品牌”表,并从这两个表到这个新表创建 1:n 关系。
为了简化起见,我在这里没有进行额外的步骤,而是坚持使用 m:n 关系。正如下面“参考文献”部分的 SQLBI 文章中描述的,还有更好的解决方案。
同样的情况适用于“聚焦品牌”表和“日期”表之间基于“MonthKey”列的关系。
但为了避免数据模型中的歧义,两个关系都必须被禁用。
关键关系是“聚焦品牌”和“在线销售”表之间的关系。
我的第一个方法是没有与“在线销售”表的关系。但我不得不在表之间使用交叉过滤来正确计算结果。
这种方法带来了一些问题:
-
无法计算广告产品和非广告产品的销售部分。
-
无法正确计算时间序列。
-
最重要的是,报告对所有用户不起作用。
最后一点至关重要。我在我的度量中使用了CROSSFILTER()函数。
正如你在微软文档的备注中可以读到的,当数据模型中实施了行级安全时,CROSSFILTER()是不支持的,我的客户就有这种情况。
为了实现“聚焦品牌”和“在线销售”表之间的关系,我不得不在每个表中添加一个新列,以将每个品牌映射到每个月。
在“聚焦品牌”表中,我用以下表达式添加了这一列:

图 4 — 带有每月品牌列的“聚焦品牌”表(图由作者提供)
在“在线销售”表中,我使用了这个表达式来添加相应的列:
Brand per Month = RELATED('Product'[BrandName])
& "_" & CONVERT(RELATED('Date'[MonthKey]),STRING)
现在,我已经能够添加新的关系。
创建度量并检查结果
计算结果的度量如下:
Online Sales for Focus Brands =
CALCULATE([Sum Online Sales]
,USERELATIONSHIP('Online Sales'[Brand per Month]
,'Focus Brands'[Brand per Month]
)
,USERELATIONSHIP('Focus Brands'[MonthKey]
,'Date'[MonthKey]
)
)
添加了带有[Sum Online Sales]和新[Online Sales for Focus Brands]度量的视觉效果后,2019 年 3 月的结果如下所示:

图 5 — 检查“聚焦品牌”度量的结果(图由作者提供)
如你所见,2019 年 3 月的每个“聚焦品牌”中两个度量的结果是相同的(用蓝色标记)。
新的度量正在为“聚焦品牌”计算正确的结果。
但现在,我们必须使用来自重点品牌列的品牌名称列来获得这些结果。
但最大的问题是,我得到了所有销售额的空白行,而没有对应的重点品牌。
当仅在列视觉中使用新度量时,这一点更为明显:

图 6 — 带空白的每月重点销售(作者绘制)
我可以向视觉中添加过滤器,以排除所有[BrandName] = Blank 的行,但我希望避免这样做,因为数据模型的所有用户都需要这样做,而不是通过良好的数据模型一次性解决这个问题。
目标是仅在报告中一次性提供品牌名称列,而不干扰使用产品表的所有现有视觉效果或报告。
改进数据模型
为了实现这一点,我必须确保在计算重点品牌的销售时,过滤器上下文能够正确流动。
请查看下面的图片以了解所需的流向:

图 7 — 所需的过滤器流向(作者绘制)
在这种情况下,我必须在重点品牌表和产品表之间设置双向过滤。
然后,我必须更改我的度量以强制过滤器上下文从产品表通过重点品牌表流动。
为此,我调整了度量以包含 CROSSFILTER():
Online Sales for Focus Brands =
CALCULATE([Sum Online Sales]
,USERELATIONSHIP('Online Sales'[Brand per Month]
,'Focus Brands'[Brand per Month]
)
,USERELATIONSHIP('Focus Brands'[MonthKey]
,'Date'[MonthKey]
)
,CROSSFILTER('Online Sales'[ProductKey]
,'Product'[ProductKey]
,None
)
)
这样,我们面临的问题是这个度量对应用 RLS 规则的所有用户都不起作用。
没有这种限制,这个度量在使用产品表中的品牌名称列时将有效。
我尝试了其他解决方法,但无论我做什么,要么不能按需工作,要么会干扰现有报告。
我甚至添加了一个品牌表,并故意将品牌表和重点品牌表之间的关系设置为多对多。但这仍然没有解决对 CROSSFILTER()函数的需求:

图 8 — 带品牌表的数据模型(作者绘制)
我将重点品牌和品牌表之间的关系设置为多对多,因为该关系变为限制关系,从而避免在结果中创建
但它仍然仅与度量中的 CROSSFILTER()函数一起工作。
而下钻仍然不起作用。
目前我没有解决此问题的方案。
结论
起初,我认为这个要求相对容易解决。
结果是一个完美的示例,说明了良好的数据模型如何解决问题,以及如何操作关系以实现目标。
即便如此,我仍然无法消除对额外过滤器的需求。
有时我们必须接受需求冲突并寻找解决办法以实现目标。
根据上面展示的解决方案,我现在可以创建度量来计算广告活动的成功或重点品牌的销售如何与所有其他品牌相关。

参考资料
阅读 SQLBI 关于处理 m:n 关系时可能的解决方案的这篇文章:
在 Power BI 和 Tabular 中建模多对多关系的不同选项 — SQLBI
我使用了 Contoso 示例数据集,就像在我之前的文章中一样。您可以从微软这里免费下载 ContosoRetailDW 数据集。
表格 Focus Brands 是通过这种数据集的半随机过程生成的。
Contoso 数据可以在 MIT 许可证下自由使用,如这里所述。
我扩大了数据集,以使 DAX 引擎更加努力地工作。
在线销售表包含 7100 万行(而不是 1260 万行),零售销售表包含 1850 万行(而不是 340 万行)。
[## 使用我的推荐链接加入 Medium - Salvatore Cagliari
阅读 Salvatore Cagliari 的每一个故事(以及 Medium 上其他成千上万的作者)。您的会员费直接…
medium.com](https://medium.com/@salvatorecagliari/membership?source=post_page-----292ff012643a--------------------------------)
在 Power BI 中改变数据视角
原文:
towardsdatascience.com/changing-perspective-on-your-data-in-power-bi-4cca965f5fe7
通常,我们在报告页面上需要比实际拥有的空间更多。但如果我们能在同一页面上切换数据视角呢?让我们看看怎么做。
·发布于 Towards Data Science ·阅读时间 7 分钟 ·2023 年 7 月 24 日
--

由 Nathan Dumlao 在 Unsplash 提供的照片
介绍
想象一个包含一些卡片、列和折线图的报告页面。
在页面顶部,你会看到四个按钮:
-
实际数据
-
YTD(年初至今)
-
YE(年末)
-
最近三个月
类似这样的:

图 1 — 实际数据的模型图(作者提供的图)
当你点击“YTD”按钮时,你的数据将会更改为显示 YTD 结果,从而改变我们结果的视角:

图 2 — YTD 结果的模型图(作者提供的图)
这就是我们想要实现的目标。
最后,我们想根据选择来更改结果的格式。
重要的是要理解,我们需要一个日期或年月筛选器来进行我们的报告。我们想要更改结构数据的结果,比如:
-
地理
-
门店
-
产品类别
-
客户组
-
等等。
通常,选择日期周期(如年份或月份)的筛选器,然后以时间序列显示年份或月份的数据并没有意义。我们可能只能看到数据的一个实例。
好的,你可以选择一年以按月份查看结果。但这与预期的场景不同。
准备数据模型
第一步是添加一个场景表。
这个表格包含每个需要的场景的一行。
我在 Power BI 中使用“输入数据”功能创建了这个表格,如下所示:

图 3 — 场景表(作者提供的图)
该表与其他任何表没有关系。这样,我避免了使用时的任何副作用。
现在,我需要为每个场景创建一个度量值。
我不会详细介绍如何编写每个度量值,因为这会使文章过长。
可以将所有逻辑包含在一个度量值中。但我有几个理由创建单独的度量值:
-
可重用性:这些度量值可以在其他报告中重用。
-
简单性:每个度量值本身都很简单。没有必要将多个度量值合并为一个度量值。
-
可测试性:独立测试每个度量值要容易得多。
下一步是创建一个度量值,该度量值根据选定的场景返回正确的结果。
添加动态度量值
现在,我可以添加度量值,它可以根据切片器中选择的项目动态使用正确的度量值。
这个度量值很简单:
Online Sale (By scenario) =
VAR SelectedScenario = SELECTEDVALUE('Reporting Scenarios'[Scenario])
RETURN
SWITCH(SelectedScenario
,"Actuals", [Sum Online Sales]
,"YTD", [Online Sales (YTD)]
,"YE", [Online Sales (YE)]
,"Last three Months", [Online Sales (Last 3 Months)]
)
我在报告页面上添加了一个切片器,将其样式设置为 Tile,并启用单选,以确保始终选择一个值,并且永远不会选择多个值。
当我将单独的度量值和动态度量值添加到矩阵视觉对象中并选择 YTD 时,我得到的是:

图 4 — 使用动态度量值的结果(图由作者提供)
我检查了日期层级的结果以进行测试,以查看所有计算是否正确。
使用计算组
为了使内容更通用,我添加了两个计算组:
第一个包含我在场景中需要使用的所有时间智能函数。
无论你是否打算使用这种方法,我都建议使用计算组来处理时间智能逻辑。这样,我们可以只编程一次你的逻辑,并在所有度量值中重用它。
如果你需要学习如何使用计算组,请跳到下面的参考部分,找到 SQLBI 的介绍文章链接。
第二个计算组只有一个计算项:
VAR SelectedScenario = SELECTEDVALUE('Reporting Scenarios'[Scenario])
RETURN
SWITCH(SelectedScenario
,"Actuals", SELECTEDMEASURE()
,"YTD", CALCULATE(SELECTEDMEASURE(), 'Time Intelligence'[Time Functions] = "YTD")
,"YE", CALCULATE(SELECTEDMEASURE(), 'Time Intelligence'[Time Functions] = "YE")
,"Last three Months", CALCULATE(SELECTEDMEASURE(), 'Time Intelligence'[Time Functions] = "Last three Months")
,SELECTEDMEASURE()
)
最后一行确保返回结果,即使没有选择场景或选择了多个场景。
在 Tabular Editor 中,它的显示效果如下:

图 5 — 按场景划分的计算组(图由作者提供)
最后,我将我的在线销售(按场景)度量值更改为以下内容,以使用新的计算组:
Online Sale (By scenario) =
CALCULATE([Sum Online Sales]
,'By Scenario'[Scenario] = "By Scenario"
)
当我将两个矩阵视觉对象并排放置,一个使用时间智能计算,另一个使用动态度量值时,结果如预期:

图 6 — 使用计算组的结果(图由作者提供)
所以,第一步已经完成。让我们看看如何处理结果的格式化。
那么动态格式化呢?
想象一下,我们想选择是否查看实际结果、相对于去年绝对数值的差异,还是百分比的差异:

图 7 — 差异类型的选择(图由作者提供)
对于这个案例,我扩展了报告场景表,新增了一个类型列和几个新行:

图 8 — 带有类型列的场景表(作者绘制)
现在,我可以添加进一步的切片器,并通过类型列限制每个切片器的可用选项,仅显示实际需求所需的选项,借助过滤面板:

图 9 — 使用新类型列过滤切片器(作者绘制)
接下来,我在我的时间智能计算组中添加了两个计算项:
-
YoY:按绝对数字计算的年同比差异
-
YoY %:按百分比计算的年同比差异
计算项默认从源度量获取数字格式。
由于与计算项一起使用的度量具有数字格式,因此它将返回一个数字结果,而不是百分比。
但我为百分比结果的计算项添加了格式字符串,如下所示,以解决这个问题:
“0.00%;-0.00%;0.00%”
这是我对“YoY %”计算项所做的:

图 10 — 带有格式字符串的 YoY % 计算项(作者绘制)
我为动态场景的计算项添加了两个额外的案例:
VAR SelectedScenario = SELECTEDVALUE('Reporting Scenarios'[Scenario])
RETURN
SWITCH(SelectedScenario
,"Actuals", SELECTEDMEASURE()
,"YTD", CALCULATE(SELECTEDMEASURE(), 'Time Intelligence'[Time Functions] = "YTD")
,"YE", CALCULATE(SELECTEDMEASURE(), 'Time Intelligence'[Time Functions] = "YE")
,"Last three Months", CALCULATE(SELECTEDMEASURE(), 'Time Intelligence'[Time Functions] = "Last three Months")
,"Absolute YoY", CALCULATE(SELECTEDMEASURE(), 'Time Intelligence'[Time Functions] = "YoY")
,"Relative YoY", FORMAT(CALCULATE(SELECTEDMEASURE(), 'Time Intelligence'[Time Functions] = "YoY %"), "0.00%;-0.00%;0.00%")
,SELECTEDMEASURE()
)
不幸的是,计算项的格式字符串未应用于场景的计算项。因此,我们必须使用FORMAT()函数明确设置格式为百分比。
(并排视图与明确的度量)结果如下:

图 11 — 格式正确的结果(作者绘制)
当我添加一个用于周期的切片器并将矩阵更改为按产品层级报告时,结果接近最终版本:

图 12 — 带有周期切片器和产品层级的结果(作者绘制)
现在,我可以添加更多视觉效果与其他动态度量来提高报告的信息密度,而无需添加大量视觉效果。

结论
在这一部分,我结合了几种技术来实现新的功能:
-
无关系的表格
-
时间智能的计算组
-
一个用于场景的计算组,使用来自另一个计算组的计算项
-
在计算项中格式化输出
这种特性的组合对找到挑战的最佳解决方案至关重要。
当我调查解决此问题的可能方法时,我对如何进行有了清晰的想法。
尽管如此,我还是在网上搜索了其他方法,并找到了另一种使用计算表处理每个场景的所有日期的方法。
如该示例视频所解释:
但我的需求需要一种更灵活的方法。计算表中使用的每个场景只有基于实际日期的日期范围。这一日期将是最后一次刷新日期。
只有在刷新数据集后,你才会得到正确的结果。不刷新数据集,你将得到错误的结果。
但我需要能够动态定义报告的日期范围,然后设置所需的计算方法,比如 YTD 或前一年。
当我选择一个可能的解决方案时,我会考虑以下几个关键因素:
-
所选的方法是否满足所有要求?
-
它是否引入了任何限制?
-
它是否引入了不必要的复杂性?
-
是否足够通用以便重用?
我希望这些建议能帮助你应对未来的挑战。
参考文献
SQLBI 关于计算组的介绍:www.sqlbi.com/articles/introducing-calculation-groups/
我使用了 Contoso 示例数据集,就像我之前的文章中一样。你可以从微软这里免费下载 ContosoRetailDW 数据集。
Contoso 数据可以根据这里的 MIT 许可协议自由使用。
[## 获取 Salvatore Cagliari 发布内容的邮件通知。
获取 Salvatore Cagliari 发布内容的邮件通知。通过注册,如果你还没有 Medium 账户,将会创建一个…
medium.com](https://medium.com/@salvatorecagliari/subscribe?source=post_page-----4cca965f5fe7--------------------------------)
自然语言处理中的字符编码:ASCII 和 Unicode 的角色
更详细地审视技术细节和实际应用
·
关注 发布于 Towards Data Science ·7 min read·2023 年 1 月 12 日
--
介绍
在本文中,我们将探讨字符编码标准的主题,特别是 ASCII 和 Unicode 系统。我们将深入了解它们的工作原理及其在深度学习中的作用。此外,我们还将提供一些使用 Tensorflow 进行字符编码的示例,以便概览该库如何处理字符串。

首先,我们将介绍一些重要的概念。
什么是字符编码标准?
字符编码是一种将字符表示为数值(即编码点)的系统。这些编码点允许计算机存储和处理文本,然后可以以其他方式显示或使用。在本文中,我们将解释 ASCII 和 Unicode 字符编码系统,并讨论它们在自然语言处理(NLP)领域的实用性。
ASCII
ASCII(美国信息交换标准代码)是一种字符编码标准,为书面文本中的每个字母、数字和其他符号分配唯一的数字。它被广泛使用,但也存在一些限制。
ASCII 具有 128 个编码点,这意味着它可以表示 128 个字符和符号。其中一些编码点代表计算机的指令,而另一些代表可打印的字符,如字母和数字。
ASCII 中使用的映射系统可以在此表格中找到:

ASCII 表格, 链接
如我们所见,在 128 个编码点中,只有 94 个是可打印的。
例如,使用此表格的十六进制列,我们可以将字符串“Language”编码为“4C 61 6E 67 75 61 67 65”。
ASCII 的限制
正如我们之前所说,ASCII 码的主要限制是它只有 94 个可打印字符。这些字符包括大小写英文字母(52 个字符)、数字(10 个字符)以及标点符号和符号(32 个字符)。因此,ASCII 不适用于使用超过基本拉丁字母的语言。其他语言中有不同的字符(如中文、俄文、挪威文),甚至有法语和西班牙语中的带重音字母,这些都无法使用这种字符编码系统显示。此外,像表情符号或货币符号这样的特殊符号也不包含在 ASCII 中,这限制了它的潜力。因此,需要一种新的字符编码系统,以使其更具扩展性,并考虑到所有这些在 ASCII 码中被忽略的字符和符号。这就是 Unicode 标准的出现。
Unicode
Unicode 是一种字符编码标准,开发于 1980 年代末至 1990 年代初,旨在扩展 ASCII 及其他现有标准的功能。其主要开发动机之一是需要一个可以表示任何语言文本的单一字符编码标准。为了解决这个问题,Unicode 联盟应运而生,创建了一个能够表示世界所有语言的单一通用字符编码标准。Unicode 使用 16 位编码方案,这使其能够表示超过 65,000 种不同的字符。这比 ASCII 可以表示的 128 个字符要多得多。它已经成为 WWW 的主流字符编码标准,并被现代计算系统和软件广泛支持。它可以编码和显示多种语言的文本,包括使用拉丁字母以外的字符集的语言(如中文、日文、阿拉伯文),以及像表情符号和货币符号等特殊符号。
你可以在他们的网站上找到更多信息。
编辑描述
home.unicode.org](https://home.unicode.org/?source=post_page-----9349b4fe3cee--------------------------------)
它是如何工作的?
Unicode 定义了一个代码空间,一组从 0 到 10FFFF(十六进制)的数值,称为代码点,并以 U 开头表示,因此范围从 U+0000 到 U+10FFFF。我们将使用 U 后跟字符的十六进制代码点值,并在必要时使用前导零(例如,U+00F7)。Unicode 代码空间分为十七个平面,编号从 0 到 16。每个平面由 65,536(2¹⁶)个连续的代码点组成。平面 0,称为基本多语言平面(BMP),包含了最常用的字符。其余的平面(1 到 16)被称为补充平面。在每个平面中,字符被分配在命名的相关字符块中。Unicode 块是多个连续的数字字符代码范围之一。它们用于组织 Unicode 标准中的大量字符。每个块通常但不总是用来提供一种或多种特定语言使用的字形,或者在某些通用应用领域中使用。
映射和编码
Unicode 定义了两种映射方法:UTF 编码和 UCS 编码。编码将 Unicode 代码点范围映射到固定范围内的一系列值。所有 UTF 编码将代码点映射到唯一的字节序列。编码名称中的数字表示每个代码单元的位数。UTF-8 和 UTF-16 是最常用的编码。
-
UTF-8 对每个代码点使用一到四个字节。它与 ASCII 非常兼容。
-
UTF-16 对每个代码点使用一个或两个 16 位代码单元。
Unicode 在自然语言处理中的应用
在本节中,我们将看到如何在 NLP 任务中使用 Unicode 以及它的作用。我们将使用一些 Tensorflow 代码使这些示例更加生动。
NLP 模型通常处理具有不同字符集的不同语言。最具代表性的任务可能是神经机器翻译(NMT),在这种任务中,模型必须将句子翻译成其他语言。但一般来说,所有语言模型都必须使用字符串序列作为输入,因此 Unicode 是一个相当重要的步骤。使用 Unicode 表示通常是最有效的选择。
在这里我们将看到如何在 Tensorflow 中表示字符串并使用 Unicode 对其进行操作。基本的 TensorFlow tf.string 类型允许我们构建字节字符串的张量。Unicode 字符串默认是 UTF-8。
tf.constant(u"Hello world 🌎")
>>> tf.Tensor(b'Hello world \\xf0\\x9f\\x8c\\x8e', shape=(), dtype=string)
在这里我们可以看到,表情符号被编码为“\xf0\x9f\x8c\x8e”。它以 UTF-8 表示。
表示
我们可以在 Tensorflow 中使用两种标准表示 Unicode 字符串:
-
字符串标量——其中代码点序列使用已知字符编码(例如 Unicode)进行编码。
-
int32 向量——其中每个位置包含一个单一的代码点。
例如,以下值都表示 Unicode 字符串“语言处理”(在中文中意为“language processing”)。
# Unicode string, represented as a UTF-8 encoded string scalar
text_utf8 = tf.constant(u"语言处理")
print(text_utf8)
>>> tf.Tensor(b'\\xe8\\xaf\\xad\\xe8\\xa8\\x80\\xe5\\xa4\\x84\\xe7\\x90\\x86', shape=(), dtype=string)
我们也可以使用 UTF-16 进行表示。
# Unicode string, represented as a UTF-16-BE encoded string scalar
text_utf16be = tf.constant(u"语言处理".encode("UTF-16-BE"))
print(text_utf16be)
>>> tf.Tensor(b'\\x8b\\xed\\x8a\\x00Y\\x04t\\x06', shape=(), dtype=string)
最终,在一个 Unicode 代码点向量中。
# Unicode string, represented as a vector of Unicode code points
text_chars = tf.constant([ord(char) for char in u"语言处理"])
print(text_chars)
>>> tf.Tensor: shape=(4,), dtype=int32, numpy=array([35821, 35328, 22788, 29702], dtype=int32)
转换
Tensorflow 提供了在这些不同表示之间转换的操作:
- tf.strings.unicode_decode:将编码的字符串标量转换为代码点向量。
text_chars_converted = tf.strings.unicode_decode(text_utf8, input_encoding='UTF-8')
print(text_chars)
print(text_chars_converted)
>>> tf.Tensor([35821 35328 22788 29702], shape=(4,), dtype=int32)
>>> tf.Tensor([35821 35328 22788 29702], shape=(4,), dtype=int32)
- tf.strings.unicode_encode:将代码点向量转换为编码的字符串标量。
text_utf8_converted = tf.strings.unicode_encode(text_chars, output_encoding='UTF-8')
print(text_utf8)
print(text_utf8_converted)
>>> tf.Tensor(b'\\xe8\\xaf\\xad\\xe8\\xa8\\x80\\xe5\\xa4\\x84\\xe7\\x90\\x86', shape=(), dtype=string)
>>> tf.Tensor(b'\\xe8\\xaf\\xad\\xe8\\xa8\\x80\\xe5\\xa4\\x84\\xe7\\x90\\x86', shape=(), dtype=string)
- tf.strings.unicode_transcode:将编码的字符串标量转换为不同的编码。
text_utf16be_converted = tf.strings.unicode_transcode(text_utf8, input_encoding='UTF-8', output_encoding='UTF-16-BE')
print(text_utf16be)
print(text_utf16be_converted)
>>> tf.Tensor(b'\\x8b\\xed\\x8a\\x00Y\\x04t\\x06', shape=(), dtype=string)
>>> tf.Tensor(b'\\x8b\\xed\\x8a\\x00Y\\x04t\\x06', shape=(), dtype=string)
字符长度
我们可以使用tf.strings.length操作的 unit 参数来指示字符长度的计算方式。默认的单位值是“BYTE”,但可以设置为其他值,例如“UTF8_CHAR”或“UTF16_CHAR”,以确定每个编码字符串中的 Unicode 代码点数量。
# Note that the final character (emoji) takes up 4 bytes in UTF8.
helloWorld = u"Hello World 🌍".encode('UTF-8')
print(helloWorld)
>>> b'Hello World \\xf0\\x9f\\x8c\\x8d'
num_bytes = tf.strings.length(helloWorld).numpy()
num_chars = tf.strings.length(helloWorld, unit='UTF8_CHAR').numpy()
print('{} bytes; {} UTF-8 characters'.format(num_bytes, num_chars))
>>> 16 bytes; 13 UTF-8 characters
如果你计算字符串“Hello World \xf0\x9f\x8c\x8d”的字节数(包括每个字母、空格和字节),你会看到总共有 16 个字节,如输出代码所示。
如果我们像之前一样计算这字符串的字符数,但将表情符号视为一个字符而不是 4 个字节,这个字符串包含 13 个 UTF-8 字符。
如果你想了解更完整的教程,我建议你访问这个 TensorFlow 教程。
## Unicode strings | Text | TensorFlow
NLP 模型通常处理具有不同字符集的不同语言。Unicode 是一个标准编码系统,它...
结论
总之,字符编码是计算机系统和自然语言处理(NLP)中的一个重要方面。ASCII 是一种广泛使用的标准,为文本中的每个字母、数字和符号分配唯一的编号,但它在字符表示上存在局限性。Unicode 标准的开发旨在解决 ASCII 的局限性,它使用 16 位编码方案,使其能够表示超过 65,000 个不同的字符,并支持任何语言的文本。Unicode 已成为全球互联网和现代计算系统中主流的字符编码标准,对于显示和处理各种语言和符号的文本至关重要。在这篇文章中,我们详细概述了 ASCII 和 Unicode 编码系统,以及 Tensorflow 如何管理 Unicode 中的字符串。
如果你有任何疑问或建议,请留下评论。感谢阅读!
探索最终边界:完成 #30DayMapChallenge 之旅
·
关注 发布于 Towards Data Science ·6 分钟阅读·2023 年 11 月 29 日
--
图片由作者提供
随着 #30DayMapChallenge 的结束,我深刻感受到社区和协作的力量。每天都有新的主题,每个主题带来了新的见解,而在这个过程中,呈现出了一幅地理可视化故事的挂毯。这个挑战不仅仅是对技能的考验,它是对共享知识和集体创造力的庆祝。
反思 #30DayMapChallenge:
对我而言,#30DayMapChallenge 不仅仅是每天制作一张地图的承诺;它是对数据讲故事核心的探索。通过这项挑战,我在每张我设计的地图中将迁移的片段故事、复杂的交通网络和城市生活的脉动缝合在一起,我很高兴与大家分享这些。在这个过程中,我学到了很多关于地理可视化的知识。
延续在#30DayChartChallenge中的实践,我利用Observable Plot的简洁性为数据注入生命。在我之前的文章中,我分享了这次挑战前半部分我最喜欢的可视化。在这篇总结文章中,我将揭示一些剩余的地理可视化作品,每一件都融合了数据、设计和发现。你可以在我的#30DayMapChallenge收藏中找到所有的可视化作品及其数据源和代码。
本文中的所有图像均由作者创作。
第 16 天 — 大洋洲

第 16 天 — 大洋洲
利用来自ala.org.au的详细观察记录,我力图通过对澳大利亚独特有袋类动物栖息地的复杂描绘来展示澳大利亚袋鼹鼠的观察情况。
在创建这张地图时,挑战在于以既美观又科学准确的方式呈现空间数据。为了实现这一目标,我将记录到的观察点叠加在澳大利亚的轮廓上,明亮的点表示袋鼹鼠的存在。最终结果是一个信息丰富且引人注目的插图,展示了袋鼹鼠如今稀疏的分布,便于快速视觉评估种群密度和地理分布。
来自wikipedia.org的袋鼹鼠背景图像与澳大利亚地图美妙叠加,使得故事更加吸引人、强大和连贯。
第 17 天 — 流动

第 17 天 — 流动
来自databank.worldbank.org的全面数据集提供了全球迁移的视角,使我能够将跨越大陆和文化的路线串联起来。在这张图像中,我以全球迁移模式的流动性和动态性捕捉了“流动”主题,从美国到 2000 年排名前 10 的国家。它使用弧线来表示人们在地球上的移动,每条曲线都从美国起点,延伸到世界各地的不同国家。
这种可视化不仅反映了从一个国家到另一个国家的物理移动,还 encapsulates 了迁移过程中固有的文化、思想和联系的转移。
第 21 天 — 栅格

第 21 天 — 栅格
对于这个主题,我想创建一个鲜明的西雅图交通碰撞可视化,数据来源于IBM 课程的毕业项目。
我使用了栅格图形方法有效地绘制了城市范围内的每一个事件,这本质上涉及将图像分解成像素或颜色点的网格。每个碰撞事件都表示为一个像素化的点,这强调了数据的颗粒度。这种方法允许表示高水平的细节,并创建一个密度图,可以揭示在其他数据可视化形式中不易察觉的模式。

第 21 天 — 栅格插值选项
在上面的图像系列中,我通过一个迷人的视角探索了“栅格”主题,展示了通过 Observable Plot 提供的各种插值方法,应用于西雅图交通碰撞数据。每个面板代表一种不同的方法——默认、最近邻、重心和随机游走——从单一数据集中创造出不同的视觉效果,每种方法都影响数据的空间分布和视觉纹理。‘默认’保持数据的原始颗粒度,而‘最近邻’引入了更加分段、几乎是立体主义的失真。‘重心’提供了数据点的棱镜状折射,而‘随机游走’则创造了扩散、几乎印象派的效果。这确实让我觉得我创造了意外的艺术。
第 23 天–3D

第 23 天–3D
在这个图像中,我想展示一个光滑而简约的全球钻石生产图像,通过在黑暗的世界地图上用发光的点来象征,数据来自theglobaleconomy.com。
为了创建 3D 效果,我绘制了亮点来标记钻石开采的重要位置,这些亮点在世界地图的黑色背景下熠熠生辉,体现了钻石本身的珍贵。为了增强这些点的发光效果,我在它们下方添加了密度可视化。地球的旋转进一步通过这些钻石生产点的出现和消失来增强 3D 效果。
第 24 天 — 黑白

第 24 天 — 黑白
我想展示 2016 年至 2020 年间华盛顿州迁移趋势的单色解释,数据来自美国人口普查局。这种黑白选择不仅符合美学主题,还突显了人们流入和流出的对比——白色表示净迁入,黑色表示净迁出。
地图上大胆的对比尖峰是该州人口统计变化的图形表示。它们的不同长度和颜色表示迁移的规模和方向,两种颜色的配色方案提供了清晰的视觉差异,增强了理解。通过为这些尖峰和州的边界创建阴影,我想在单色主题的基础上赋予其 3D 效果。
第 30 天——我最喜欢的

第 30 天——我最喜欢的
这张图片是对“我最喜欢的……”主题的生动致敬,展示了瑞士复杂而广泛的铁路网络,这一网络被誉为欧洲最密集的。我喜欢欧洲铁路网络的强大,它无疑是我最喜欢的旅行方式。
在这次可视化中,我另一个最喜欢的地方是类型——栅格。使用栅格图,这张地图确实让瑞士铁路网络的工程和组织奇迹如星空般闪耀。我在#30DayMapChallenge 期间学会了如何制作这个,并深深地爱上了这种图表在表示复杂点系统时的强大能力。
结束制图探险
这些地图不仅仅是每日练习;它们是放大我们世界复杂性和美丽的透镜。挑战的后半部分与前半部分一样引人入胜,每张地图都是数据与设计、故事与空间交汇的更深入探讨。从濒危物种到国际迁移的脉络,从人们的移动到引导他们的轨迹,#30DayMapChallenge 是对我们如何可视化和解释我们所居住空间的深刻探索。
对于那些一直关注的人,你们的好奇心和参与是这项工作的推动力。对于未来可能踏上这条道路的人:愿每个主题挑战你,社区激励你,愿你创建的每一张地图都成为你探索之旅中的一个里程碑。
所有的可视化,包括它们的底层代码和数据,都可以在我的Observable 集合中找到,供那些渴望开始自己制图冒险的人参考。
如果你愿意,可以在Linkedin上找到我。
绘制非参数探索:无分布假设检验的统计框架
评估符号检验和 Wilcoxon 符号秩检验的机制
·
关注 发布于 Towards Data Science ·19 分钟阅读·2023 年 5 月 10 日
--
内容
-
介绍
-
工具 1:单样本符号检验
-
工具 2:双样本符号检验
-
符号检验的局限性
-
工具 3:单样本 Wilcoxon 符号秩检验
-
威尔科克森符号秩检验:可能的并发症
-
工具 4:两样本威尔科克森符号秩检验
-
结论
介绍
统计学是知识工具的集合,允许我们通过包括但不限于参数估计、置信区间构建和假设检验等工具从数据中推断以验证我们的假设。在本文中,我们将学习允许我们检验有关不同数据分位数值的假设的框架,即符号检验和威尔科克森符号秩检验。这些框架的独特之处在于,与 z 检验或 t 检验等流行的假设检验不同,这些检验不需要对数据做任何假设,无论是通过直觉还是通过中央极限定理强加的,即,它们是无分布的或非参数的。您只需拥有来自对称和连续分布的数据,您将配备检验如下主张的工具:

工具 1:单样本符号检验
让我们从一个简单的检验开始,它只涉及计算小于或等于阈值的值,用于假设检验:符号检验(即二项检验)。特别是,我们考虑一个大小为 n 的样本 X₁, …, Xₙ 以及以下陈述的简单零假设和相应的替代假设:

其中 m 表示给定数据的中位数。作为起点,我们来考虑在传统 t 检验设置下我们如何处理这个问题。在这种情况下,我们会假设 n 很大,使得中央极限定理发挥作用。此外,由于分布总是假设为对称的,检验中位数应该与检验均值相同。因此,我们可以定义检验统计量和相应的临界区域如下:

这很简单且相当强大(正如我们稍后将看到的),但它基于一个假设,即分布收敛到一个自由度为 n — 1 的 t 分布。但这可能不一定成立,特别是当 n 不够大时。这需要一个替代框架,让我们可以在不引入数据的分布假设的情况下检验这个假设。这引出了符号检验,它只是涉及统计总数大于 m⁰ 的观察值。例如,考虑以下的 n 个样本数据:

假设如下:

如上所述,我们首先计算大于 m⁰ 的 Xᵢ 的数量。令 N⁺ 表示大于 m⁰ 的 Xᵢ 的随机变量,令 n⁺ 为给定样本的实际值:

现在,我们确定在原假设下上述检验统计量的分布。由于在原假设下 m = m⁰,样本大于 m⁰ 的概率必须是 0.5。因此,N⁺ 衡量的是大于 m⁰ 的样本数,每个样本在原假设下有 0.5 的概率贡献到 N⁺。换句话说,N⁺ 计算的是成功次数,每次成功的概率为 0.5。这定义了一个二项随机变量,参数为样本数 n = 10 和成功概率 p = 0.5。因此,

现在,已知检验统计量在原假设下的分布,我们可以继续计算给定样本的 p 值。请记住,p 值表示在原假设为真的假设下观察到比检验统计量样本值更极端值的概率。因此,

显然,p 值相当大。因此,即使在 10% 的显著性水平下,由于 p 值 > 0.10,我们也未能拒绝原假设!这就结束了简单却有用的符号检验应用。这也引出了另一个重要问题:如果备择假设在另一个方向,或者如果是双侧的情况呢?过程依然相同,只是我们计算比 m⁰(N⁻)更少的观察次数。具体来说,假设我们有兴趣进行测试:

然后,

请注意,分布仍然遵循二项分布,参数与前一个案例的推理相同。最后,我们讨论双侧假设的情况:

对于这种情况,我们计算样本的 N⁺ 和 N⁻:

与之前一样,p 值给出了比 6 或更多大于 0.5 的数值更极端的结果的概率。因为这是一个双侧检验,所以极端结果可以是 6 或更多大于 0.5 的数值,或者 4 或更少大于 0.5 的数值。

和以前一样,p 值相当大。因此,即使在 10% 的显著性水平下,由于 p 值 > 0.10,我们也未能拒绝原假设!这难道不令人着迷吗?仅仅是简单的计数和最基本的概率分布就能让我们测试这样的假设。但我们能否推广结果?换句话说,我们能否将此检验通用于任何特定分位数,而不仅仅是中位数?当然,方法论保持完全相同。例如,假设对于上述相同的数据集,我们有以下假设:

注意π₀.₂₅表示数据的第 25 百分位数或第 1 四分位数。和以前一样,我们计算小于 3 的 Xᵢ数量:

现在,我们确定在零假设下上述检验统计量的分布。由于在零假设下π₀.₂₅ = 3,因此样本小于 3 的概率必须是 0.25。因此,N⁻测量小于 3 的样本数量,每个样本在零假设下有 0.25 的概率贡献给 N⁻。换句话说,N⁻计算成功的数量,每个成功的概率是 0.25。这定义了一个参数化为样本数量 n = 10,成功概率 p = 0.25 的二项随机变量。因此,

因此,

这是一个较低的 p 值,但仍不足以在 5%或 10%的显著性水平下拒绝。由此,我们结束了对单样本符号检验的讨论。
工具 2:两样本符号检验
在本节中,我们将尝试将前面介绍的符号检验推广到两样本的情况。之前,我们提供了数据(由单一样本 X 的观察值组成),我们可以检验中位数(或任何分位数)是否大于、小于或不等于给定的阈值。在本节中,我们将上述概念扩展到两样本的情况。特别地,假设我们得到独立的样本对数据(x₁, y₁),· · ·, (xₙ, yₙ)。零假设声明两个样本大于对方的概率相等,即它们的中位数差为 0,而备择假设则建议两个样本之间存在差异,即中位数差异为正、负或非零。数学上,零假设和备择假设如下:

示例:假设我们给出了以下配对数据:

假设如下:

推导检验统计量并找到其分布的过程保持类似。在前面的案例中,我们只是检查了样本是否大于或小于阈值(零假设下的值)。在这个案例中,我们通过比较每对观察值来检查符号。换句话说,我们检查每个观察值的 Wᵢ = Xᵢ − Yᵢ的符号,从 i = 1 到 n:

表 1:作者提供的图片
检验统计量定义为:

根据之前的逻辑,在原假设下,检验统计量的分布是 Bin(n, 0.5)(可以将检验统计量的值看作在 n 次试验中的成功次数,其中每次成功即 Xᵢ > Yᵢ 的概率为 0.5,因为两个样本都同样可能大于对方)。因此,p 值由以下公式给出:

确实,p 值要小得多,我们可以在 10%的显著性水平下拒绝原假设!最后,如果备择假设是双侧的,我们来尝试计算相同数据的 p 值:

因为这是一个双侧检验,极端结果可以是 8 个或更多正符号,或者 2 个或更少正符号:

因此,我们无法在 10%的显著性水平下拒绝双侧假设。这结束了我们关于进行单样本和双样本符号检验的讨论。
符号检验的局限性
在前两节中,我们使用符号检验进行不同类型的假设检验。由于其非参数性(即分布无关性),符号检验被发现非常有用。然而,一般来说,符号检验的能力并不是很强(即,对于相同的显著性水平,第二类错误的概率仍然很高)。主要原因是符号检验只考虑值的符号,即它们是否小于原假设下的阈值(对于单样本)或是否小于其对应值(对于双样本)。它没有考虑差异的大小,即偏离 0 的程度。在实践中,符号检验很少使用,但由于其极大的简单性,它仍然是引入非参数假设检验讨论的一个很好的工具。这引导我们进入一个新的假设检验框架:威尔科克森符号秩检验,这在某种程度上是对符号检验的扩展,考虑了偏离 0 的大小顺序,并为每个样本值分配秩。在下一节中,我们将更详细地描述该检验。
工具 3:单样本威尔科克森符号秩检验
威尔科克森符号秩检验背后的理论稍微复杂一些。但不用担心。让我们重新审视上一节中的例子,并依次应用威尔科克森检验:

步骤 1: 计算每个 Xᵢ 和 m⁰ 之间差值的绝对值:

表 2:作者提供的图片
步骤 2: 对步骤 1 中计算出的每个值进行排名,即为每个 | Xᵢ − m⁰| 分配排名 (Rᵢ),从 1 到 n:

表 3:作者提供的图片
步骤 3: 计算每个样本的符号秩 (Rₛᵢ),其中符号由 Xᵢ − 4 的符号给出(与符号检验中使用的符号相同),即:

表 4:作者提供的图像
步骤 4: 计算 Wilcoxon 符号秩检验统计量,其由所有数据点的符号秩之和给出:

在讨论 W 的分布之前,让我们暂停一下,思考这些步骤的逻辑。在符号检验中,我们只是计算了大于或小于 m⁰ 的样本数量。在这里,我们通过秩来加权计数,这衡量了不同样本相对于 m⁰ 的绝对偏差。这使我们能够对远小于或远大于 m⁰ 的值给予更高的权重,帮助我们克服符号检验的局限性。
现在,让我们评估 W 的分布。需要提到的是,W 的分布没有封闭形式的公式(虽然其矩生成函数有封闭形式的表达式,但其质量函数没有这样的表达式)。一种方法是使用表格(是的,有专门为不同样本量设计的 W 临界值表)。另一种替代方法是使用 Lyapunov CLT 来近似 W 的分布(CDF):

其中 Φ 是标准正态分布的累积分布函数。精确的数学公式相当复杂,超出了本文的范围。但是,我们仍将尽力理解它。让我们尝试计算 W 的期望和方差:

其中 Sᵣ 表示第 rᵗʰ 秩的符号(所有秩值将从 1 到 n)。因此,

如果我们对 W 进行标准化,我们得到:

确实,它与 Φ 内部的表达式非常相似。实际上,通过 Lyapunov CLT(不是传统的 CLT,因为尽管 rSᵣ 是独立的,但它们不一定是同分布的),我们 h

这导致了 W 分布中的Φ项。你可能还会注意到 W 的 CDF 包含一个额外的 1. 这只是一个单位修正。回忆一下 W 是一个离散随机变量。每当离散随机变量被连续随机变量(例如正态分布)近似时,添加修正项总是很重要的,通常是半单位修正。半单位修正通过调整连续分布的概率,使其更接近离散分布的概率,从而帮助解决这个差异。具体来说,这个修正涉及将用于定义连续分布的区间的中点向相反的四舍五入约定方向移动 0.5 单位。这确保了连续分布中分配给区间的概率更接近离散分布的概率,从而提高了近似的准确性。然而,对于近似 W,我们倾向于使用全单位修正(而不是传统的半单位修正)。这是因为 W 的值只能以 2 的倍数不同。这样考虑:如果你有一个 W 值为 w,并且你有兴趣将 W 值减少 1,你将不得不翻转一个正排名,这会使 W 的值减少 2。例如,在我们的示例中,W 被计算为 3。如果我们想减少 W 值 1,我们能做的最好是将+1 翻转为-1,但这将使 W 的值减少 2(-1–1 = -2)。类似的逻辑适用于增加 W 的值。因此,为了控制 W 可以取的值范围,使用全单位修正(0.5 × 2 = 1)而不是传统的半单位修正(0.5 × 1 = 0.5)。因此,在应用全单位修正并使用 Lyapunov CLT 之后,我们得到:

现在我们已经获得了 W 分布的近似表达式,我们继续计算我们示例的 p 值。回忆一下,p 值表示在原假设为真的假设下观察到比样本观测值更极端的值的概率。因此,

显然,p 值相当大。因此,即使在 10%的显著性水平下,由于 p 值 > 0.10,我们未能拒绝原假设!这就是 Wilcoxon 符号秩检验的简单而有用的应用。现在让我们考虑当备择假设在另一个方向上的情况。特别是,假设我们有兴趣测试:

检验统计量仍然保持不变。然而,由于支持备择假设的方向发生变化,p 值计算也相应调整(根据 p 值的定义):

再次,p 值较大,我们在 10% 的显著性水平下未能拒绝原假设。最后,让我们看一下双边情况:

检验统计量仍然保持不变。然而,就像之前一样,由于支持替代假设的方向发生了变化,p 值的计算也会相应调整(按照 p 值的定义):

再次,p 值较大,我们在 10% 的显著性水平下未能拒绝原假设。现在,让我们将检验统计量推广到任何特定的分位数,而不仅仅是中位数。方法完全相同。例如,假设对于上面的相同数据集,我们给出以下假设:

就像之前一样,我们构建表格并汇总符号秩以得到检验统计量:

表 5:作者提供的图片

现在,我们确定上述检验统计量在原假设下的分布:

如果我们标准化 W,我们得到

根据 Lyapunov CLT,我们有:

因此,

这是一个较低的 p 值,但仍然不够低,无法在 5% 或 10% 显著性水平下拒绝原假设。至此,我们结束了对单样本 Wilcoxon 符号秩检验的讨论。
Wilcoxon 符号秩检验:可能的复杂情况
使用 Wilcoxon 符号秩检验时可能会出现两种可能的复杂情况:
-
对于 i ≠ j,|Xᵢ − m⁰| = |Xᵢⱼ − m⁰| 即,绝对偏差与 m⁰ 相同。在这种情况下,哪个观察值会被分配更高的秩?首先,在数据遵循连续分布的假设下,相同的数据理论上不应出现多次。但是,实际上,往往会发生这种情况。有很多方法可以解决这个问题,但最常见的策略是为每个样本分配秩的平均值。例如,如果有 4 个样本值,它们与 m⁰ 的绝对偏差相同。如果这些观察值的秩为 7、8、9 和 10,则为每个样本分配平均秩:(7 + 8 + 9 + 10)/4 = 8.5。
-
如果 |Xᵢ − m⁰| = 0,对于某些 i,我们从分析中排除该样本。我们仅使用减少后的样本(以及相应的减少样本量)来计算检验统计量及其分布。
例如,如果我们给出以下数据和假设:

我们排除第 5 个样本(因为 |X₅ − 3| = 0),并使用剩余的样本来计算平均秩:

表 6:作者提供的图片

工具 4: 两样本 Wilcoxon 符号秩检验
作为我们最后的非参数假设检验工具包,让我们看一下 Wilcoxon 符号秩检验的两样本扩展。如前所述,原假设认为中位数的差异为 0,而备择假设则建议两个样本之间存在差异,即中位数的差异为正、负或非零。数学上,原假设和备择假设如下:

首先看一下总体框架,然后考虑一个示例。假设我们有两个样本的数据:X₁, X₂, · · ·, Xₙ₁ 和 Y₁, Y₂, · · ·, Yₙ₂。我们将两个样本组合成一个样本,并将所有观测值按升序排列。对于每个观测值,我们分配排名(如有争议则使用平均排名),从 1, 2, · · ·, n1 + n2 开始。最后,我们计算所有 X 样本观测值的排名总和,我们称之为测试统计量 W。例如,考虑以下配对数据:

假设如下:

步骤 1: 组合 Xᵢ 和 Yᵢ:
24.27, 8.63, 16.76, 21.92, 29.59, 4.01, 7.28, −7.75, −6.61, 13.05, 13.47, 24.6, −4.97, 0.07, 6.96, −0.53, 7.26, −11.7, −5.01, −4.43
步骤 2: 将组合排列成升序(记得跟踪哪些观测值来自哪个样本):

表 7: 作者提供的图片
步骤 3: 分配排名:

表 8: 作者提供的图片
步骤 4: 通过对 X 样本的排名求和来计算测试统计量:

接下来,我们找到 W 在原假设下的分布。如前所述,W 的分布没有封闭公式,因此我们使用 Lyapunov CLT 来近似 W 的分布(CDF):

其中 Φ 是标准正态分布的累积分布函数。我们可以通过计算 W 的期望和方差来检查这一点。请注意,这一计算相当复杂,因此您可以跳过它,直接进行 p 值计算。但我们仍包括这一部分以求完整。让 Rˣᵢ 表示分配给 X 的第 i 个样本的排名,而 Rʸⱼ 表示分配给 Y 的第 j 个样本的排名:

在原假设下,由于分布相同,Rˣᵢ 和 Rʸⱼ 必须在 1 ≤ i ≤ n₁; 1 ≤ j ≤ n₂ 的范围内完全分布。因此,

方差的计算略微复杂。我们利用以下事实:

因此,

类似地,对于 i ̸= j:

因此,根据方差和协方差的定义,我们有:


因此,通过方差的扩展和求和:

因此,如果我们标准化 W 并应用 CLT,我们得到:

因此,

因此,在 10% 的显著性水平下,由于 p 值 < 0.10,我们可以拒绝零假设!最后,让我们尝试计算相同数据的 p 值,如果备择假设是双侧的:

测试统计量仍然保持不变。然而,就像以前一样,由于支持备择假设的方向改变,p 值的计算相应地进行了调整(根据 p 值的定义):

因此,在双侧情况下,我们无法在 10% 的显著性水平下拒绝零假设。这结束了我们对进行 Wilcoxon 符号秩检验进行单样本和双样本情况讨论。
结论
在本文中,我们熟悉了一些最知名的非参数或无分布假设检验框架。我们查看了符号检验和 Wilcoxon 符号秩检验在单样本和双样本情况下的表现,并比较了它们在样本观察中的表现。我们探讨了每个测试的单侧和双侧情况,以及如何将它们推广为测试数据的任何给定分位数。符号检验虽然一般不太强大,但非常简单,仅依赖于二项分布的累积分布函数来推导相关的 p 值。另一方面,Wilcoxon 符号秩检验在理论上更为复杂,但往往能够给出更好的结果,因为它不仅考虑值是否小于或大于阈值,还考虑它们绝对偏差的相对大小。
希望您喜欢阅读本文!如果您有任何疑问或建议,请在评论框中回复。
如果您喜欢我的文章并希望阅读更多,请随时通过 邮件 联系我。访问此 链接。
注:所有图片均由作者制作。
讲述故事的图表:将数据可视化转变为更有意义的东西
·发表于 Towards Data Science ·阅读时间 9 分钟·2023 年 3 月 29 日
--

照片由 Erik Mclean 在 Unsplash 上拍摄。
数据可视化传达了数据框和表格无法表达的想法。然而,通过数据可视化有效讲述故事需要一个具有美感、可解释的图表,这样图表才能提供必要的上下文,自成一体。
幸运的是,Python 包含了许多数据可视化库,例如 Plotly Express,可以通过一行代码快速创建图表 [1]。虽然这些图表很有用,但它们在正式出版物中很少能单独存在,或者在没有提供更多上下文的情况下经受得住审查;一个讲述数据故事的专业、独立图表需要额外的工作。本文介绍了将数据可视化提升到下一个水平的步骤。
代码:
本教程的代码可以在 GitHub 页面链接 上获取。请随意下载代码,并在 Jupyter notebook 中跟随操作——点击“代码”和“下载 ZIP”即可获取 ipynb 文件。
1. 数据准备和初步可视化
使用的库有:
# Data Handling
import pandas as pd
# Data visualization Libraries
import seaborn as sns
import plotly.express as px
import plotly.io as pio
回到一个老牌的数据集,Seaborn 库 的汽车数据集(标记为‘mpg’)提供了能源危机时期的车辆燃料消耗的优质数据 [2]。通过使用 groupby 函数,这些数据将提供 1970 年至 1982 年间每加仑平均行驶里程(MPG)。以下是加载和准备数据框的代码:
# Load in data:
mpg = sns.load_dataset('mpg')
# Get dataset showing average MPG per year by using groupby:
mpg = mpg.groupby(['model_year'])['mpg'].mean().to_frame().reset_index()
# Rename columns:
mpg = mpg.rename(columns={'model_year': 'Year',
'mpg': 'Average MPG'})
mpg.head()
数据框的头部应该如下所示:

作者截图。
一行代码在 plotly express 中生成以下柱状图:
px.bar(mpg, x='Year', y='Average MPG')

作者截图。
仅需几行代码即可准备和可视化数据。但这样的可视化是否值得在杂志或正式商业报告中发表?可能不够。
2. 下一层级:标签、格式、颜色
首先,让我们来看一下图表的颜色方案。运行以下代码将更新颜色值和模板:
# Generate base plot:
plot = px.bar(mpg, x='Year', y='Average MPG', color='Average MPG',
color_continuous_scale=px.colors.diverging.RdYlGn)
# Remove colorbar:
plot.update_coloraxes(showscale=False)
# Update plotly style:
plot.update_layout(template='plotly_white')
plot.show()
更新后的图表现在如下所示:

作者截图。
上述代码的第一个更改是将“平均 MPG”列添加到颜色中。虽然从技术上讲是多余的(“平均 MPG” 已在 y 值或柱状高度中捕获),但将颜色指定为“平均 MPG”允许应用颜色刻度以帮助快速解释。在这种情况下,Plotly 的发散颜色刻度“RdYlGn”与数据很好地映射:较低效的年份为红色,而较高效的年份为绿色,中间则是橙色和黄色。在这里发现其他 Plotly 颜色方案[3]。
由于颜色和 y 值相同,Plotly 通常包含的颜色条是多余的;上述代码将其去除。最后,代码将默认的 Plotly 主题更改为“plotly_white”。有关 Plotly 主题的更多信息,请点击这里 [4]。
注意: 如果考虑无色盲友好的方案,请阅读这篇《Towards Data Science》文章,了解如何调整颜色方案[5]。在这篇文章中,保留了红黄绿颜色方案,既作为概念验证,也因为 MPG 信息仍通过 y 轴(柱状高度)值传达。
3. 标签
接下来的步骤标记轴:
# Label axes:
plot.update_xaxes(title='Model Year',
dtick=1)
plot.update_yaxes(title='Average Miles Per Gallon (MPG)')
Plotly 通常会将轴标题默认为指定轴的数据列标题,但update_xaxes() 和 update_yaxes()允许进一步自定义[6]。请注意,MPG 被拼写为“每加仑英里数”以避免用缩写混淆观众。此外,对于 x 轴,“dtick” 值被设置为 1;这确保了 x 轴上的每个柱状图下方都有一个型号年份值。
以下代码更新标题:
# Update plot layout:
plot.update_layout(
title=dict(
text='<b>Average Miles Per Gallon of Cars Over Time</b>\
<br><i><sup>A Visualization of Improvements in \
Fuel Efficiency During the Energy Crisis Era</sup></i>',
x=0.085,
y=0.95,
font=dict(
family='Helvetica',
size=25,
color='#272b4f'
)))
函数 update_layout() 允许添加标题[7]。请注意,文本包含 HTML 风格代码: 和 用于加粗文本; 和 用于斜体;
代表分页符;最后, 代表图表副标题的上标。标题中的 \ 使文本块在下一行代码上继续,从而创建更整洁的代码块。
字体也可以调整,字体的颜色也是如此。可以使用 Google 的 颜色选择器 [8] 轻松找到其他颜色。
最后,我们使用 add_annotation() 函数在图表上注释数据源[9]:
# Add annotation on data source:
plot.add_annotation(x=0,
y=-0.15,
showarrow=False,
text="<i>Fuel mileage data courtesy of \
Python Seaborn Library</i>",
textangle=0,
xanchor='left',
xref="paper",
yref="paper",
font_color='#a6aeba')
所有这些最终产生了以下结果:

作者截屏。
注意图表有标题和副标题,引入了数据的背景,同时轴提供了明确的标签,条形的高度和颜色迅速显示了每年的平均 MPG。此外,底部有数据源参考。这可能是一个适合出版的图表的良好停止点,但为了展示可视化的进一步可能性,下一部分将介绍更多选项。
4. 添加注释
为了为图表增加更多背景,以下代码创建了一条表示 1970 年至 1982 年平均 MPG 的水平线以及一个描述该线的注释框:
# Add average MPG across era:
plot.add_hline(y=mpg['Average MPG'].mean())
# Add explanation of line:
plot.add_annotation(x=.05,
y=0.67,
text="Average MPG, 1970 through 1982",
textangle=0,
xanchor='left',
xref="paper",
yref="paper",
font_color='black',
bordercolor='black',
borderpad=5,
showarrow=True,
arrowhead=2,
bgcolor='white',
arrowside='end'
)
添加到图表中后,结果如下:

作者截屏。
注意 add_annotation() 函数允许创建边框(bordercolor=‘black’)和箭头。add_annotation() 函数中的 x 和 y 值定位标注框,可能需要一些试验和错误以将框放置在所需的位置。
这条线和标注框进一步强调了 1980 年、1981 年和 1982 年 MPG 的显著提升。此外,它还显示了 1975 年至 1979 年期间,平均 MPG 年年稳步提升。如果需要突出显示这段改进时期,可以使用以下代码完成:
# Add highlight box:
plot.add_vrect(x0="74.5",
x1="79.5",
fillcolor="lightgray",
opacity=0.3,
line_width=0)
# Add explanation of line:
plot.add_annotation(x=.45,
y=0.9,
text="Period of Consistent Improvement
<br>until Breakthrough in 1980's",
textangle=0,
xanchor='left',
xref="paper",
yref="paper",
font_color='black',
showarrow=False,
)
函数 add_vrect() 创建一个矩形框,用于突出显示图表的特定部分[10];在这种情况下,它突出显示了 1970 年代末期的一段持续改进的时期。add_vrect() 代码的位置很重要;将其放在 add_hline() 函数之后意味着它会在水平线的上方,而不是下面,如下所示:

作者截屏。
可能看起来没有其他需要添加的内容,但还有一个可能性:解释燃油经济性的提升。假设该图表是某项研究的一部分,该研究发现发动机尺寸的减小直接促成了每加仑英里数的提升。幸运的是,Seaborn MPG 数据包括发动机排量数据。经过以下的一些数据准备以及一些新的注释,最终的图表准备好了:
# Data prep:
displacement = sns.load_dataset('mpg')
seventies = round(
displacement[displacement['model_year'] < 80]['displacement'].mean(), 2)
eighties = round(
displacement[displacement['model_year'] >= 80]['displacement'].mean(), 2)
# Create text string:
explanation = "<b>Why the Improvement in MPG?</b> <br>\
In the 70's, average engine size was {} <br>\
cubic inches versus {} from 1980 to 1982.<br>\
Larger engines are usually less efficient.".format(seventies, eighties)
# Add explanation for trends:
plot.add_annotation(x=.615,
y=0.02,
text=explanation,
textangle=0,
xanchor='left',
xref="paper",
yref="paper",
font_color='black',
bordercolor='black',
borderpad=5,
bgcolor='white',
showarrow=False
)
数据准备部分计算了 1970 年代(70 到 79 年)的平均发动机排量以及 1980 年代的三年(80、81 和 82 年)。这些数据随后被传入文本字符串中,用于 add_annotation() 函数中的文本值。最终图表如下所示:

作者截屏。
这个图表现在可以独立存在,并讲述能源危机时代的年均燃油里程的故事。一些最后的考虑事项:
-
少即是多:当图表已经达到能够传达必要信息的能力时,就停止添加内容。
-
一个独立的信息图可能需要更多的项目,而在口头报告中或作为书面报告的一部分展示的图表,则可以用更少的项目来满足要求。
-
颜色方案、字体选择和尺寸影响可读性和无障碍性。
-
如果客户无法理解图表的含义,那不是客户的错。
5. 结论
Python 的各种可视化库,包括 Plotly Express,提供了一种快速生成高度可定制图表的方法,这些图表可以从基础可视化到完全自定义的独立产品。通过理解如何最好地传达故事来构建可视化,能够产生最有效的传递效果。自己动手尝试代码,并随时 在 GitHub 页面下载完整笔记本。
在一个代码块中生成最终图表:
# Load in Libraries:
# Data Handling
import pandas as pd
# Data visualization Libraries
import seaborn as sns
import plotly.express as px
# Load in data:
mpg = sns.load_dataset('mpg')
mpg.head()
# Get dataset showing average MPG per year by using groupby:
mpg = mpg.groupby(['model_year'])['mpg'].mean().to_frame().reset_index()
# Rename columns:
mpg = mpg.rename(columns={'model_year': 'Year',
'mpg': 'Average MPG'})
# Generate base plot:
plot = px.bar(mpg, x='Year', y='Average MPG', color='Average MPG',
color_continuous_scale=px.colors.diverging.RdYlGn)
# Remove colorbar:
plot.update_coloraxes(showscale=False)
# Update plotly style:
plot.update_layout(template='plotly_white')
# Label axes:
plot.update_xaxes(title='Model Year',
dtick=1)
plot.update_yaxes(title='Average Miles Per Gallon (MPG)')
# Add labels and source:
# Update plot layout:
plot.update_layout(
title=dict(
text='<b>Average Miles Per Gallon of Cars Over Time</b>\
<br><i><sup>A Visualization of Improvements in \
Fuel Efficiency During the Energy Crisis Era</sup></i>',
x=0.085,
y=0.95,
font=dict(
family='Helvetica',
size=25,
color='#272b4f'
)))
# Add annotation on data source:
plot.add_annotation(x=0,
y=-0.15,
showarrow=False,
text="<i>Fuel mileage data courtesy of \
Python Seaborn Library</i>",
textangle=0,
xanchor='left',
xref="paper",
yref="paper",
font_color='#a6aeba')
# Add highlight box:
plot.add_vrect(x0="74.5",
x1="79.5",
fillcolor="lightgray",
opacity=0.3,
line_width=0)
# Add explanation of line:
plot.add_annotation(x=.45,
y=0.9,
text="Period of Consistent Improvement\
<br>until Breakthrough in 1980's",
textangle=0,
xanchor='left',
xref="paper",
yref="paper",
font_color='black',
showarrow=False,
)
# Add average MPG across era
# Create Line:
plot.add_hline(y=mpg['Average MPG'].mean())
# Add explanation of line:
plot.add_annotation(x=.05,
y=0.67,
text="Average MPG, 1970 through 1982",
textangle=0,
xanchor='left',
xref="paper",
yref="paper",
font_color='black',
bordercolor='black',
borderpad=5,
showarrow=True,
arrowhead=2,
bgcolor='white',
arrowside='end'
)
# Add a box to explain the trends
# Data prep:
displacement = sns.load_dataset('mpg')
seventies = round(
displacement[displacement['model_year'] < 80]['displacement'].mean(), 2)
eighties = round(
displacement[displacement['model_year'] >= 80]['displacement'].mean(), 2)
# Create text string:
explanation = "<b>Why the Improvement in MPG?</b> <br>\
In the 70's, average engine size was {} <br>\
cubic inches versus {} from 1980 to 1982.<br>\
Larger engines are usually less efficient.".format(seventies, eighties)
# Add explanation for trends:
plot.add_annotation(x=.615,
y=0.02,
text=explanation,
textangle=0,
xanchor='left',
xref="paper",
yref="paper",
font_color='black',
bordercolor='black',
borderpad=5,
bgcolor='white',
showarrow=False
)
plot.show()
参考文献:
[1] Plotly, Python 中的 Plotly Express (2023)。
[2] Seaborn, Seaborn:统计数据可视化 (2023)。
[3] Plotly, Python 中的内置连续颜色比例尺 (2023)。
[4] Plotly, Python 中的主题和模板 (2023)。
[5] CR Ferreira, 创建色盲友好型数据可视化的两个简单步骤 (2020),数据科学前沿。
[6] Plotly, Python 中的坐标轴 (2023)。
[7] Plotly, Python 中的布局 (2023)。
[8] Google, Google 颜色选择器 (2023)。
[9] Plotly, Python 中的文本和注释 (2023)
[10] Plotly, Python 中的水平和垂直线条及矩形 (2023)。
ChatGPT — 小心使用
原文:
towardsdatascience.com/chat-gpt3-handle-with-care-8b6634781608
了解 ChatGPT 的实际能力和限制对于充分利用这一技术至关重要。香港大学人工智能研究中心的最新研究论文权衡了 OpenAi 算法的局限性和优势。
·发表在 Towards Data Science ·阅读时间 6 分钟·2023 年 3 月 13 日
--

图片由 Possessed Photography 提供,来源于 Unsplash
首先,出现了语言模型。直观上很简单:一个词序列中的下一个词可以用概率分布建模,并且严重依赖于前面的词。词汇是有限的语料库的一部分(英语词汇中有 170,000 个标记)。每个词的含义是有限的。词序列遵循缓慢变化的内部元数据集:语法。这是一个可预测的结构。你可以期待一个动词后面跟着名词,而不是另一个动词。语法和意义,作为限制,限制了下一个词预测中的随机性。这无疑比预测一千家公司第二天的股票价格要容易得多。此外,语言模型本质上是自回归的,下一个词的预测依赖于前面的词,而且需要考虑的潜在不可观测的变量也不多。
正因为如此,语言模型适合使用预训练模型和迁移学习。这是解锁新 AI 革命的关键特性。迁移学习意味着你可以使用别人预训练的模型,比如在 20 GB 的维基百科文章上训练的模型,而无需用自己的数据重新训练,只需进行少量调整以适应你的问题。
这怎么可能呢?好吧,你的语言问题不太可能需要使用与维基百科上完全不同的语法和词汇。迁移学习在人们开始争论第二次人工智能寒冬的时候,开启了新的人工智能夏季。
预训练模型变得更大、更快,随着参数数量和使用的数据量的增加,性能也得到了提升。经验发现,语言模型的性能会随着模型大小的增加而提升,直到达到计算能力的上限。计算机芯片已尽可能地强大。为了让语言模型持续增长,必须发生一些事情。高效地在多台机器上并行训练的方式是显而易见的。变压器的出现。

照片由Aditya Vyas拍摄,来源于Unsplash
由 Google Brain 团队发布的变压器语言模型,在传统语言序列模型上进行了一系列令人印象深刻的创新改进。其核心是广泛使用多头自注意力机制和设计为在多个并行 GPU 上运行的模型架构。注意力机制大大改善了语言模型中预测下一个词的任务,以比传统递归神经网络更高效的方式传播序列中所有词汇的信息。
变压器使得大规模、非常大规模的语言模型的预训练成为可能。从 2019 年的 GPT 2 的 15 亿参数,到 2020 年的 GPT3 的 1750 亿参数。铺平了 2022 年 ChatGPT 的惊人发布以及大型语言模型(LLM)时代的道路。
ChatGPT 在许多方面表现出色,但请注意幻觉。
这虽然是一个相当长的引言,但在放置背景时非常重要。语言模型不是黑魔法,也不是人工通用智能。它们不会在短时间内超越人类。它们是极其有用的工具,在预测序列中的下一个词方面表现出色。像所有人类发明的工具一样,如果我们不阅读说明书和细则,它们可能会造成伤害。
词汇是有限的,具有有限的意义,它们以可预测的方式组合,遵循语法结构。然而,信息,即意义的结合方式,是无限的,不一定是可预测的。像 ChatGPT 这样的超大语言模型可以生成全新的信息,完全是虚构的。这反过来又生成了未经验证的事实的新叙述。它们具有参数化的记忆,没有访问外部知识库的能力。我们已经看到它们在概率方面表现良好,但它们没有内部机制来分辨真相与谎言。简而言之,它们可能会出现幻觉。
请查看下面。我在这里假装自己是一位著名的数据科学家,询问 ChatGPT 关于我自己的问题。

作者提供的图片
ChatGPT 没错。可能我在妈妈面前很有名,但仅此而已。但接着我感到被冒犯了,完全编造了关于自己的额外信息。

作者提供的图片
当然,我不是 Kaggle 大师(我希望我能成为),但 AI 对我表示歉意。
ChatGPT 的突然流行在这一领域是前所未有的,这要归功于一个可以交互的界面,保留了积累的知识。对话界面使用了带有人工反馈的强化学习(RCHF)。问题在于,积累的知识可能基于显然不真实的后续问题和纠正。
Yejin Bang、Pascale Fung 及其博士团队几周前发布了一个广泛的框架,用于定量评估像 ChatGPT 这样的模型在公开可用数据集上的表现。作为一个零样本学习者——一个无需专门调整即可回答任何问题的模型——ChatGPT 在大多数任务中被评为最先进。其在问答、情感分析和虚假信息检测方面有了大幅提升。

图片来自论文: arxiv.org/pdf/2302.04023.pdf
关于推理,这是一个最具争议的特性之一,研究人员发现 ChatGPT 在演绎推理方面表现非常好,但在归纳推理和解决数学问题方面表现非常差。
演绎推理是从一般前提出发得出具体结论的过程,当前提包含足够的信息来引导你找到解决方案时效果良好。研究发现该算法在这些推理任务中表现优越。

图片来源于 Pascale Fung 的 YouTube 视频:www.youtube.com/watch?v=ORoTJZcLXek
归纳是逆向过程。它是从数据中提取信息以推断出一个普遍结论。演绎思维则是在你需要检验一个理论时所遵循的智力过程,而归纳思维则是帮助你形成一个理论的过程。换句话说,给定大量详细前提,你可以期望 ChatGPT 提出一些样本数据,但不要期望它根据一些样本数据提出一个通用规则。
实际上,ChatGPT 目前还无法像人类那样形成对世界的概念。
参考文献
Yejin Bang 等,《ChatGPT 在推理、幻想和互动方面的多任务、多语言、多模态评估》,arxiv.org/pdf/2302.04023.pdf
Pascale Fung,ChatGPT: Prof. Pascale Fung 所讲的《ChatGPT 能做什么和不能做什么》,www.youtube.com/watch?v=ORoTJZcLXek
阿希什·瓦斯瓦尼等人,注意力机制就是你所需要的,arxiv.org/pdf/1706.03762.pdf
使用贝叶斯推断与你的数据集对话。
原文:
towardsdatascience.com/chat-with-your-dataset-using-bayesian-inferences-bfd4dc7f8dcd
向数据集提出问题的能力一直是一个令人着迷的前景。你会惊讶于学习一个可以用来询问数据集的局部贝叶斯模型是多么简单。
·发表于 Towards Data Science ·13 分钟阅读·2023 年 11 月 13 日
--

图片由 Vadim Bogulov 提供,来源于 Unsplash
随着类似 ChatGPT 的模型的兴起,更多的人能够分析自己的数据集,并且可以说是“提问”。尽管这很棒,但在将其作为自动化流程中的分析步骤时,这种方法也有其缺点。这尤其适用于模型结果可能产生重大影响的情况。为了保持控制并确保结果准确,我们也可以使用贝叶斯推断与数据集对话。在这篇博客中,我们将逐步介绍如何学习贝叶斯模型,并在数据科学薪资数据集上应用 do-calculus。我将演示如何创建一个模型,让你“提问”数据集并保持控制。你会惊讶于使用 bnlearn 库创建这样一个模型的简便性。
介绍
从数据集中提取有价值的洞察对于数据科学家和分析师来说是一项持续的挑战。类似于 ChatGPT 的模型使得互动分析数据集变得更加容易,但与此同时,这种方法可能变得不够透明,甚至不清楚为什么做出某些选择。依赖这样的黑箱方法在自动化分析流程中远非理想。当模型的结果对采取的行动有重大影响时,创建透明模型尤其重要。
能够有效地与数据集进行沟通一直是研究人员和从业人员的一个令人着迷的前景。
在接下来的部分中,我将首先介绍 bnlearn 库 [1],以及如何学习因果网络。然后,我将演示如何使用混合数据集学习因果网络,以及如何应用 do-calculus 有效查询数据集。让我们看看贝叶斯推理如何帮助我们与数据集互动!
如果你觉得这篇文章有帮助,欢迎 关注我 ,因为我写了更多关于贝叶斯学习的内容。如果你考虑加入 Medium 会员,可以通过使用我的推荐链接来支持我的工作。价格和一杯咖啡一样,但这允许你每月无限阅读文章。
Bnlearn 库
Bnlearn 是一个强大的 Python 包,它提供了一整套用于贝叶斯网络因果分析的函数。它可以处理离散、混合和连续数据集,并提供了广泛的用户友好功能,用于因果学习,包括结构学习、参数学习和推理 [1–3]。在进行推理之前,我们需要理解结构学习和参数学习,因为推理依赖于这两种学习。
学习数据集的因果结构 是 bnlearn 的一个重要特性。结构学习消除了对变量之间潜在关系的先验知识或假设的需求。在 bnlearn 中,有三种方法可以学习因果模型并捕捉变量之间的依赖关系。结构学习将生成一个所谓的 有向无环图(DAG)。虽然这三种技术都会生成因果 DAG,但有些可以处理大量特征,而有些则具有更高的准确性。有关结构学习的更多细节,请参见 下面的博客。
-
基于评分的结构学习:使用评分函数 BIC、BDeu、k2、bds、aic,结合如 exhaustivesearch、hillclimbsearch、chow-liu、树增强朴素贝叶斯(TAN)、朴素贝叶斯等搜索策略。
-
基于约束的结构学习(PC):使用统计方法如卡方检验在建模之前测试边缘强度。
-
混合结构学习:(两种技术的结合)
-
基于评分、基于约束和混合结构学习。虽然这三种技术都会生成因果 DAG,但有些可以处理大量特征,而有些则具有更高的准确性。有关结构学习的更多细节,请参见下面的博客 [2]。
## 使用贝叶斯结构学习在 Python 中检测因果关系的逐步指南。
入门指南,帮助有效确定变量之间的因果关系。
towardsdatascience.com
参数学习 是贝叶斯网络分析的第二个重要部分,bnlearn 在这一领域也表现出色。通过利用一组数据样本和一个(预先确定的)DAG,我们可以估计条件概率分布或表(CPDs 或 CPTs)。有关参数学习的更多细节,我推荐 以下博客:
如果你没有数据,但有专家知识,那么可以使用计算机辅助来转换这些知识的入门指南。
towardsdatascience.com
Bnlearn 还提供了大量的函数和辅助工具,以帮助用户完成整个分析过程。这些包括数据集转换函数、拓扑排序推导、图形比较工具、深刻的交互式绘图功能等。bnlearn 库支持加载 bif 文件,将有向图转换为无向图,并执行统计测试以评估变量之间的独立性。如果你想了解 bnlearn 相比其他因果库的表现,这个博客适合你:
## 贝叶斯因果推断的力量:对库进行比较分析,以揭示隐藏的…
通过使用最合适的贝叶斯因果推断库来揭示数据集中的隐藏因果变量:一个…
towardsdatascience.com
在下一部分,我们将开始使用 do-calculus 进行推断,并通过实际示例进行操作。这使我们能够向数据集提出问题。如前所述,结构学习和参数学习构成了基础。
查询数据集需要使用 do-calculus 进行推断。
当我们进行使用 do-calculus 进行推断时,这基本上意味着我们可以查询数据集并“提出问题”。为此,我们需要两个主要成分:DAG和分配给图中每个节点的 CPTs。CPTs 包含每个变量的概率,并捕捉给定其父节点的因果关系。让我们继续,创建一个示例,看看它是如何真正工作的。
使用数据科学薪资数据集的应用
为了演示,我们将使用从 ai-jobs.net [5] 派生的数据科学薪资数据集。这个薪资数据集是全球收集的,包含了 4134 个样本的 11 个特征。如果我们加载数据,我们可以探索列并将特征设置为连续或类别。请注意,模型复杂性随着类别数量的增加而增加,这意味着需要更多的数据和计算时间来确定因果 DAG。
# Install datazets.
!pip install datazets
# Import library
import datazets as dz
# Get the data science salary data set
df = dz.get('ds_salaries')
# The features are as following
df.columns
# 'work_year' > The year the salary was paid.
# 'experience_level' > The experience level in the job during the year.
# 'employment_type' > Type of employment: Part-time, full time, contract or freelance.
# 'job_title' > Name of the role.
# 'employee_residence' > Primary country of residence.
# 'remote_ratio' > Remote work: less than 20%, partially, more than 80%
# 'company_location' > Country of the employer's main office.
# 'company_size' > Average number of people that worked for the company during the year.
# 'salary' > Total gross salary amount paid.
# 'salary_currency' > Currency of the salary paid (ISO 4217 code).
# 'salary_in_usd' > Converted salary in USD.
复杂性是一个主要限制
当特征包含许多类别时,复杂性会随着与该表关联的父节点数量的增加而呈指数增长。换句话说,当你增加类别的数量时,需要大量的数据来获得可靠的结果。可以这样考虑:当你将数据划分为类别时,每个类别中的样本数量在每次划分后会变得更少。每个类别的样本数量低直接影响统计能力。在我们的例子中,我们有一个特征job_title,它包含 99 个唯一的职称,其中 14 个职称(如数据科学家)包含 25 个样本或更多。剩余的 85 个职称要么是唯一的,要么只出现过几次。为了确保这个特征不会因统计能力不足而被模型移除,我们需要将一些职称进行聚合。在下面的代码部分,我们将职称聚合为 7 个主要类别。这会得到足够样本的类别,以用于贝叶斯建模。
# Group similar job titles
titles = [['data scientist', 'data science', 'research', 'applied', 'specialist', 'ai', 'machine learning'],
['engineer', 'etl'],
['analyst', 'bi', 'business', 'product', 'modeler', 'analytics'],
['manager', 'head', 'director'],
['architect', 'cloud', 'aws'],
['lead/principal', 'lead', 'principal'],
]
# Aggregate job titles
job_title = df['job_title'].str.lower().copy()
df['job_title'] = 'Other'
# Store the new names
for t in titles:
for name in t:
df['job_title'][list(map(lambda x: name in x, job_title))]=t[0]
print(df['job_title'].value_counts())
# engineer 1654
# data scientist 1238
# analyst 902
# manager 158
# architect 118
# lead/principal 55
# Other 9
# Name: job_title, dtype: int64
下一个预处理步骤是重命名一些特征名称。此外,我们还将添加一个新特征,描述公司是否位于美国或欧洲,并删除一些冗余变量,如salary_currency和salary。
# Rename catagorical variables for better understanding
df['experience_level'] = df['experience_level'].replace({'EN': 'Entry-level', 'MI': 'Junior Mid-level', 'SE': 'Intermediate Senior-level', 'EX': 'Expert Executive-level / Director'}, regex=True)
df['employment_type'] = df['employment_type'].replace({'PT': 'Part-time', 'FT': 'Full-time', 'CT': 'Contract', 'FL': 'Freelance'}, regex=True)
df['company_size'] = df['company_size'].replace({'S': 'Small (less than 50)', 'M': 'Medium (50 to 250)', 'L': 'Large (>250)'}, regex=True)
df['remote_ratio'] = df['remote_ratio'].replace({0: 'No remote', 50: 'Partially remote', 100: '>80% remote'}, regex=True)
import numpy as np
# Add new feature
df['country'] = 'USA'
countries_europe = ['SM', 'DE', 'GB', 'ES', 'FR', 'RU', 'IT', 'NL', 'CH', 'CF', 'FI', 'UA', 'IE', 'GR', 'MK', 'RO', 'AL', 'LT', 'BA', 'LV', 'EE', 'AM', 'HR', 'SI', 'PT', 'HU', 'AT', 'SK', 'CZ', 'DK', 'BE', 'MD', 'MT']
df['country'][np.isin(df['company_location'], countries_europe)]='europe'
# Remove redundant variables
salary_in_usd = df['salary_in_usd']
#df.drop(labels=['salary_currency', 'salary'], inplace=True, axis=1)
作为最后一步,我们需要离散化salary_in_usd,这可以手动完成,也可以使用bnlearn中的discretizer函数来完成。为了演示目的,我们将两者都做。在后者的情况下,我们假设薪资依赖于experience_level和country。更多细节请参见这篇博客[6]。基于这些输入变量,薪资然后被划分为不同的区间(参见下面的代码部分)。
# Discretize the salary feature.
discretize_method='manual'
import bnlearn as bn
# Discretize Manually
if discretize_method=='manual':
# Set salary
df['salary_in_usd'] = None
df['salary_in_usd'].loc[salary_in_usd<80000]='<80K'
df['salary_in_usd'].loc[np.logical_and(salary_in_usd>=80000, salary_in_usd<100000)]='80-100K'
df['salary_in_usd'].loc[np.logical_and(salary_in_usd>=100000, salary_in_usd<160000)]='100-160K'
df['salary_in_usd'].loc[np.logical_and(salary_in_usd>=160000, salary_in_usd<250000)]='160-250K'
df['salary_in_usd'].loc[salary_in_usd>=250000]='>250K'
else:
# Discretize automatically but with prior knowledge.
tmpdf = df[['experience_level', 'salary_in_usd', 'country']]
# Create edges
edges = [('experience_level', 'salary_in_usd'), ('country', 'salary_in_usd')]
# Create DAG based on edges
DAG = bn.make_DAG(edges)
bn.plot(DAG)
# Discretize the continous columns
df_disc = bn.discretize(tmpdf, edges, ["salary_in_usd"], max_iterations=1)
# Store
df['salary_in_usd'] = df_disc['salary_in_usd']
# Print
print(df['salary_in_usd'].value_counts())
最终的数据框
最终的数据框有 10 个特征和 4134 个样本。每个特征是一个具有两个或多个状态的分类特征。这个数据框将作为学习结构和确定因果 DAG 的输入。
# work_year experience_level ... country salary_in_usd
# 0 2023 Junior Mid-level ... USA >250K
# 1 2023 Intermediate Senior-level ... USA 160-250K
# 2 2023 Intermediate Senior-level ... USA 100-160K
# 3 2023 Intermediate Senior-level ... USA 160-250K
# 4 2023 Intermediate Senior-level ... USA 100-160K
# ... ... ... ... ...
# 4129 2020 Intermediate Senior-level ... USA >250K
# 4130 2021 Junior Mid-level ... USA 100-160K
# 4131 2020 Entry-level ... USA 100-160K
# 4132 2020 Entry-level ... USA 100-160K
# 4133 2021 Intermediate Senior-level ... USA 60-100K
#
# [4134 rows x 10 columns]
贝叶斯结构学习用于估计 DAG。
目前,我们已经对数据集进行了预处理,准备开始学习因果结构。bnlearn 中实现了六种算法来帮助完成这个任务。我们需要选择一种不需要目标变量的方法,并且它需要能够处理多个类别。可用的搜索策略有:
-
爬山搜索 算法是一种启发式搜索方法。它从一个空网络开始,根据评分指标迭代地添加或移除边。该算法探索不同的网络结构,并选择得分最高的一个。
-
穷举搜索 在所有可能的网络结构上进行穷举搜索,以找到最佳的贝叶斯网络。它根据指定的评分指标评估和打分每个结构。虽然这种方法能保证找到最佳的网络结构,但由于可能性呈指数级增长,对于大型网络来说计算开销可能很大。
-
约束搜索 在贝叶斯网络的结构学习过程中结合用户指定的约束或专家知识。它使用这些约束来引导搜索并限制可能的网络结构空间,确保所学习的网络符合指定的约束。
-
周-刘 算法是一种用于学习树结构贝叶斯网络结构的方法。它计算每对变量之间的互信息,并通过贪婪地选择最大化网络总互信息的边来构建一棵树。该算法高效且广泛用于学习离散贝叶斯网络的结构,但需要设置一个根节点。
-
朴素贝叶斯 算法假设数据集中所有特征在给定类别变量的条件下是条件独立的。它学习给定类别下每个特征的条件概率分布,并使用贝叶斯定理来计算给定特征下类别的后验概率。尽管有其朴素的假设,这种算法在分类任务中常被使用,并且对于大数据集来说效率较高。
-
TAN(树增强朴素贝叶斯)算法是朴素贝叶斯算法的扩展,允许在给定类别变量的条件下特征之间存在依赖关系。它学习一个连接特征的树结构,并使用该结构来建模条件依赖关系。TAN 将朴素贝叶斯的简单性与一定的建模能力相结合,使其成为处理相关特征分类任务的热门选择。此方法需要设置一个类别节点。
评分类型 BIC、K2、BDS、AIC 和 BDEU 被用来评估和比较不同的网络结构。例如,BIC 平衡了模型复杂性和数据拟合,而其他评分类型考虑了不同类型的先验概率。此外,independence test 从模型中剪除虚假的边。在我们的用例中,我将使用 hillclimbsearch 方法和评分类型 BIC 进行结构学习。我们不定义目标值,而是让 bnlearn 决定数据的整个因果结构。
# Structure learning
model = bn.structure_learning.fit(df, methodtype='hc', scoretype='bic')
# independence test
model = bn.independence_test(model, df, prune=False)
# Parameter learning to learn the CPTs. This step is required to make inferences.
model = bn.parameter_learning.fit(model, df, methodtype="bayes")
# Plot
bn.plot(model, title='Salary data set')
bn.plot(model, interactive=True, title='method=tan and score=bic')

图 1. 结构学习后,得到的因果 DAG。

图 2. 因果 DAG 的交互式图。
与你的数据集聊天。
使用学习到的DAG(图 1 和图 2),我们可以估计条件概率分布(CPTs,见下方代码部分),并使用do-calculus进行推断。让我们 开始提出问题吧。注意,结果可能会(略微)因模型中的随机成分而变化。
问题 1.
在大公司工作时,职位的概率是多少?
*P(job_title | company_size=Large (>250))*
在运行下面的代码部分后,我们可以看到工程科学家是最可能的结果 *(P=0.34)* ,其次是数据科学家 *(P=0.26)*。
query = bn.inference.fit(model, variables=['job_title'],
evidence={'company_size': 'Large (>250)'})
# +----+----------------+-----------+
# | | job_title | p |
# +====+================+===========+
# | 0 | Other | 0.031616 |
# +----+----------------+-----------+
# | 1 | analyst | 0.209212 |
# +----+----------------+-----------+
# | 2 | architect | 0.0510425 |
# +----+----------------+-----------+
# | 3 | data scientist | 0.265006 |
# +----+----------------+-----------+
# | 4 | engineer | 0.343216 |
# +----+----------------+-----------+
# | 5 | lead/principal | 0.0407967 |
# +----+----------------+-----------+
# | 6 | manager | 0.0591106 |
# +----+----------------+-----------+
问题 2.
在全职工作类型、部分远程工作、数据科学职能为入门级且居住在德国(DE)的情况下,薪资范围的概率是多少?
在下面的结果中,我们可以看到我们的五个薪资类别,其中在这些条件下最强的后验概率 *P=0.7* 对应的薪资低于 80K。注意,其他薪资也会出现,但发生的频率较低。
通过改变变量和证据,我们可以提出各种问题。例如,我们现在可以改变经验水平、居住地、职位等,确定概率如何变化。
query = bn.inference.fit(model,
variables=['salary_in_usd'],
evidence={'employment_type': 'Full-time',
'remote_ratio': 'Partially remote',
'job_title': 'data scientist',
'employee_residence': 'DE',
'experience_level': 'Entry-level'})
# +----+-----------------+-----------+
# | | salary_in_usd | p |
# +====+=================+===========+
# | 0 | 100-160K | 0.0664068 |
# +----+-----------------+-----------+
# | 1 | 160-250K | 0.0424349 |
# +----+-----------------+-----------+
# | 2 | 80-100K | 0.117463 |
# +----+-----------------+-----------+
# | 3 | <80K | 0.707087 |
# +----+-----------------+-----------+
# | 4 | >250K | 0.0666078 |
# +----+-----------------+-----------+
结束语。
在这篇博客中,我们学习了如何创建贝叶斯模型,以及如何使用 do-calculus 对混合数据集进行推断。通过使用bnlearn,建立这些模型变得简单明了,模型提供了易于理解和解释的结果,这些结果可以轻松嵌入数据科学流程中。
保持安全。保持冷静。
干杯 E.
如果你觉得这篇文章有帮助,欢迎 关注我 ,因为我写了更多关于贝叶斯学习的内容。如果你考虑加入 Medium 会员,你可以通过使用我的推荐链接来支持我的工作。这和一杯咖啡的价格相同,但这让你每个月可以无限阅读文章。
软件
让我们联系吧!
参考文献
-
Taskesen, E. (2020). 使用 bnlearn Python 包学习贝叶斯网络(版本 0.3.22)[计算机软件]。
-
Taskesen E, 使用贝叶斯结构学习在 Python 中检测因果关系的逐步指南,Medium,2021
-
Taskesen E, 使用贝叶斯定理设计知识驱动模型的逐步指南,Medium,2021
-
Taskesen, E. (2020). 贝叶斯因果推断的力量:比较分析库以揭示数据集中隐藏的因果关系,Medium 2023。
-
Kay H. 等人, 使用贝叶斯结构时间序列模型推断因果影响,2015,《应用统计年鉴》(247–274,第 9 卷)
-
Taskesen, E (2023), 创建并探索数据科学中的角色和薪资景观**. Medium。
聊天机器人即将颠覆音乐搜索
原文:
towardsdatascience.com/chatbots-are-about-to-disrupt-music-search-1e4a4cd7ba01
三个引人注目的范式转变理由
·发表于Towards Data Science ·11 分钟阅读·2023 年 6 月 9 日
--

图像由 DALL-E 生成,随后由作者编辑。
技术已经改变了我们搜索和发现音乐的方式。随着技术的不断进步,AI 聊天机器人作为一种有前景的工具,正被用来提升音乐搜索体验。凭借其进行对话互动和提供量身定制的推荐的能力,聊天机器人有望彻底改变我们探索和享受音乐的方式。在这篇博客文章中,我将呈现三个引人注目的理由,说明聊天机器人即将接管音乐搜索。
1 — 技术已经到来

Alexandre Debiève拍摄的照片,来源于Unsplash
最近的人工智能进展使聊天机器人技术达到了新的高度,使其能够以更高的智能和细致度理解和回应人类的查询。鉴于对话式搜索相较于传统方法的优势,微软和谷歌等行业巨头已经开始将 AI 聊天机器人应用到其网页搜索引擎中。
在音乐搜索领域,这一转变尤为相关,因为类似“谷歌搜索”的简单关键词搜索在过去几个月才逐渐受到关注。今年早些时候,一些音乐科技初创公司如Cyanite推出了基于自由文本的音乐搜索功能,作为一种突破性技术。在此之前,音乐搜索主要依赖于选择流派、情绪或艺术家,并在潜在的众多歌曲中筛选出所需的歌曲。
然而,像 ChatGPT 这样的模型的出现促进了从“自由文本”搜索到真正对话式搜索的过渡。这种变革性方法使我们能够克服重复的“输入-输出-重复”工作流程,并用动态和自然的对话取而代之。这种范式的转变与微软的 Bing 和 Google 等主要搜索引擎中所见的持续演变相一致。
此外,构建自定义音乐聊天机器人的可及性已达到前所未有的水平。实际上,我在不到一天的时间内以不到$5 的成本构建了自己的 音乐聊天机器人原型 —— 但更多内容将在后续帖子中详细介绍。尽管 ChatGPT 本身在音乐发现方面已经非常有用,但量身定制的聊天机器人系统提供了更精细的响应、与音乐数据库或网页播放器的无缝集成,并且对对话有更大的控制权。得益于开源社区的最新发展,基于特定领域或公司数据进行训练并根据个人需求定制的自定义聊天机器人的构建变得比以往任何时候都更容易。
2 — 对话是发现音乐最自然的方式

照片由 Christina @ wocintechchat.com 提供,发布在 Unsplash 上。
在音乐发现方面,对话提供了一种自然且直观的方式来探索和寻找新歌曲或艺术家。我们常常依赖朋友、音乐爱好者或专家的推荐,因为这些对话让我们可以表达自己的偏好并获得个性化建议。聊天机器人可以复制这种对话体验,使用户能够就音乐品味进行对话并获得量身定制的推荐。
相比之下,传统的“谷歌搜索”工作流程,我们在过去 20 年中已经习惯了,有时会感到不够直观。当以这种方式搜索音乐时,我们遵循一个严格的步骤顺序:
-
制作一个文本提示,描述我们要寻找的内容。
-
浏览响应并听一些曲目。
-
如果不满意,我们要么分析如何改进提示,使搜索引擎更好地理解我们,要么返回第 1 步。
这种方法的问题在于,每次搜索尝试都会有效地抹去所有先前的结果,即使这些结果对搜索过程的后续阶段可能有用。这种限制正是聊天机器人擅长的地方,因为它们具备存储整个搜索历史(即对话)的能力。
假设你的提示是
一首节奏适中的朋克摇滚歌曲,女歌手演唱,歌词内容关于单相思。
传统搜索引擎将执行搜索并推荐 20 首曲目,按预测相关性排序。你听了前 5 首曲目却不喜欢。它们要么是现场录音,要么节奏稍微快了些。于是,你更改了你的提示
一首录音室录音的朋克摇滚歌曲,节奏适中偏慢,女性歌手演唱,歌词关于单相思。
并重新开始搜索。搜索引擎进行另一次(可能代价高昂的)搜索,并向你推荐另外 20 首歌曲。这个过程会一直持续,直到你找到喜欢的曲目为止。
相比之下,音乐搜索聊天机器人以更优雅的方式解决了这个问题。你从你的提示开始
一首节奏适中、女性歌手演唱的朋克摇滚歌曲,歌词关于单相思。
聊天机器人很聪明,会要求你更精确地指定节奏,因为它知道否则结果可能不准确。你告诉聊天机器人“我不知道,也许 110-130 bpm?”而没有更改你的原始提示。考虑到从对话中收集的所有信息,聊天机器人发起了搜索,并向你展示了根据预测相关性排序的新 20 首曲目列表。
然而,在查看了前 5 个曲目后,你发现它们再次是现场录音,而你并没有将现场录音作为偏好。幸运的是,你无需回到提示阶段重新开始搜索过程,而是可以利用聊天机器人的能力来细化结果。你只需要求聊天机器人从推荐中排除所有现场录音。聊天机器人理解你的请求后,将现有的 20 首曲目分为两类:录音室录音和现场录音。然后,它会向你展示过滤后的结果,省去了再次进行昂贵搜索的必要。
将音乐搜索概念化为一个过程,而不是简单的输入输出操作,这种方法显然对那些通常不知道自己确切需求的用户有很大帮助。一个聊天机器人…
-
通过提问后续问题或指出不准确的表述来引导用户描述他们的需求
-
根据进一步的规格快速更新其推荐
具有大幅颠覆现有音乐搜索系统的潜力。
3 — 领域特定的音乐聊天机器人可以像口袋音乐学家一样

照片由 Jonas Leupe 提供,发布于 Unsplash
音乐聊天机器人最令人兴奋的前景之一是它们作为领域特定专家的潜力,类似于口袋音乐学家。音乐学家是对各种音乐类型、艺术家、历史背景以及其他复杂细节拥有广泛知识的个人。通过将这些专业知识封装在聊天机器人中,用户可以即时访问大量的信息和见解。
为了说明这一点,我们考虑一个场景:你正在寻找一个体现特定音乐特征的歌曲,但你只能通过“像 AC/DC 风格的吉他独奏”这样的参考来描述它。虽然一个合格的音乐搜索引擎可以搜索到吉他独奏的歌曲,但它可能难以理解这个参考,特别是当它的数据库中没有 AC/DC 的歌曲时。在这种情况下,你唯一的选择是尝试描述 AC/DC 吉他独奏的独特之处,这对没有丰富音乐知识的人来说可能很有挑战。
这就是口袋音乐学家聊天机器人的无价之处。一个经过多种来源(如专辑评论、粉丝论坛、乐谱和已发表的音乐分析)的训练的音乐聊天机器人,具备深刻的理解,知道典型的 AC/DC 吉他独奏是如何构建和演奏的。因此,聊天机器人可以将你的参考信息转化为精确的提示,以描述所需的音乐特征,从而被搜索引擎有效利用。
在这里,我要求 ChatGPT 列出一些 AC/DC 独奏的刻板特征。它给出的答案是:
1. 节奏蓝调的影响。
2. 简单且引人入胜的旋律。
3. 原始而粗犷的经典摇滚声音。
4. 强力和标志性吉他段。
5. 布鲁斯风格的弯音和颤音。
7. 高能量和激进的演奏风格。
当然,这些结果并不令人惊艳。然而,请记住,这是一个通用聊天机器人,没有经过特定的音乐学培训。你可以想象,像这样的领域特定聊天机器人能够提供更准确和详细的描述。
这只是将音乐聊天机器人概念化为口袋音乐学家的一个例子,说明它如何帮助改善搜索体验。然而,可能性是无穷的。例如,你可以开始对话,要求聊天机器人列出一些有趣的爵士乐子类型,并解释每种类型的主要特点。然后,你可以选择一个听起来有趣的类型,并在该类型内发起搜索。考虑一下:当 ChatGPT 在 2022 年底发布时,没人能预见到它将为数百万用户提供的广泛用途。基于聊天机器人的音乐搜索同样适用。
开放性问题

图片来源:JESHOOTS.COM 在 Unsplash
聊天还是不聊天
在前面的部分中,我提出了三个理由,说明我们可能正处于音乐搜索范式转变的边缘。虽然对这些论点的有效性可能存在不同意见,但关键问题在于,基于聊天机器人的音乐搜索的优势和可行性是否能说服公司和研究机构开发这项技术。
绝不是显而易见的,聊天机器人会完全取代传统搜索引擎。幸运的是,我们可以深入研究其他搜索领域的进展,特别是网络搜索。尽管如新必应、谷歌搜索和 Perplexity AI 等网络搜索聊天机器人已经出现,但它们的使用仍主要限于技术爱好者和人工智能专业人士。显然,这些聊天机器人尚未获得与通用(并且离线!)聊天机器人 ChatGPT 相媲美的广泛采用。最显著的是,它们远未替代传统的谷歌搜索引擎。
尽管这可能部分由于产品成熟和潜在用户采纳所需的时间,但也有实际考虑因素有利于继续使用更传统的基于关键词或语义的搜索。例如,搜索引擎常被用来查找我们曾经遇到过但不记得完整名称或网址的特定文章、网站或歌曲。在这种情况下,使用基于关键词的搜索并将搜索输入与与输入的关键词紧密匹配的结果对比,更为实际。利用像聊天机器人这样的复杂技术来完成这一任务就像用大锤去砸坚果一样。
此外,并不是每次互动都需要是对话。例如,当在特定的音乐类型中搜索歌曲时,你可能更希望不与 AI 机器人进行对话。这有两个原因。首先,你可能希望获得即时结果,因此任何偏离提供搜索结果的搜索引擎回应都是浪费时间。其次,使用聊天机器人可能将纯粹机械的音乐搜索行为转变为社交互动。这种潜在的社交方面可能被视为一种缺点,特别是对于那些寻求通过音乐逃避社交环境的人来说。
总而言之,我预计这种范式转变不会完全消除传统的音乐搜索方法。相反,我设想聊天机器人将被用于用户寻求指导和咨询的场景,而不是快速筛选音乐目录。在生产音乐领域,基于聊天机器人的搜索可以大大帮助找到适合商业广告或 YouTube 视频的完美曲目。然而,对于普通的音乐流媒体服务用户而言,大多数情况下还是采用更传统的搜索系统可能更为实际。最终,这种范式转变可能会表现为生产音乐库和音乐流媒体服务结合两种类型的搜索,以满足用户的多样化需求。
技术实现
在今天的技术环境中,构建符合特定业务需求的定制聊天机器人变得异常容易。有几种方法可以实现这一点。一种方法是利用像 OpenAI 的 GPT 模型这样的基础模型,通过 API 进行调用,并用自定义逻辑进行增强。另一种方法是利用开源语言模型,并使用领域特定数据对其进行微调,以确保其相关性和准确性。
通过 API 使用像 GPT-4 这样的预构建基础模型提供了许多优点。首先,它使企业能够直接利用这些模型,而无需额外的数据获取、准备或机器学习工作。这对于数据科学家有限或没有的公司尤其有利,因为它将任务简化为软件工程问题。其次,不需要担心构建本地或云基础设施以满足这些大型模型的计算需求。通过利用 API,企业可以以相对较低的成本访问托管解决方案。
一个缺点是,你的数据,包括用户搜索输入、音乐元数据等,可能会被模型提供者(例如 OpenAI)访问,甚至用于进一步训练他们的模型。这可能不符合内部数据治理指南或外部法规,具体取决于你的情况。另一个缺点是这些第三方解决方案通常不允许在你的数据上对其模型进行微调。因此,无法开发之前描述的那种“口袋音乐专家”。
另一方面,使用自己基础设施上的开源模型有几个优势。首先,没有限制可以完全定制一个聊天机器人以满足你的具体需求。通过微调过程,你可以将其转变为全面的音乐专家,或训练它理解你公司或领域内使用的特定词汇。其次,工作流的每一步,包括微调和模型推断,都可以在你的基础设施中实现。这消除了对内部指南或外部法规合规性的任何顾虑。
使用开源模型构建自己的聊天机器人有一个缺点,即会产生显著的工程开销。这会消耗计算和人力资源,在看到初步结果之前,你可能会犹豫是否投资于这种产品。此外,开源模型更新换代的速度很快。因此,你需要过渡到更新的模型,并定期重复微调过程,从而消耗额外资源。相比之下,通过 API 提供的托管解决方案提供了更大的灵活性,可以尝试并切换到其他聊天机器人模型。
总之,如果你的目标是开发一个快速原型或缺乏在适当基础设施中微调和部署自己模型所需的人力资源,我建议目前选择一个托管的 API 解决方案。这正是我为我的聊天机器人原型所做的,你可以从我的方法中获得灵感。然而,值得注意的是,这些解决方案可能会被更先进和定制的系统所超越,比如基于开源模型的系统。
进一步阅读
如果你对这样的系统如何实现感兴趣,可以考虑查看我在不到一天内构建的音乐搜索聊天机器人,链接在这个GitHub 仓库中。我正在编写一篇后续文章,详细描述我如何构建它。
这里还有两篇关于音乐 AI 当前发展的文章,你可能会喜欢:
-
AudioGPT — 未来音乐创作的瞥见
聊天机器人陷入(法律)交火中
GDPR、人工智能法、DSA 和公共体面之间
·
关注 发表在 Towards Data Science ·9 分钟阅读·2023 年 12 月 22 日
--
图片由 Ant Rozetsky 提供,来源于 Unsplash
最近,我一直在尝试提供一些法律建议,关于如何在不踩欧洲监管雷区的情况下实施聊天机器人。在这个过程中,我总是喜欢查看“别人是怎么做的”,然而,当涉及到这个特定问题时,这是一条相当令人失望的途径。似乎对于目前在每第三个网站上实现的技术,提供从法律到烹饪等各种建议的技术,关于如何实施这些机器人并没有太多的书面资料。当我说应该如何做时,我并不是指最佳的网络安全和安全建议或最佳的微调或训练建议。我是指如何、何时以及为哪些用途可以使用机器人,我可以使用哪些数据来训练它,以及我需要与用户分享哪些信息,以确保我的机器人至少在一定程度上符合大多数欧洲法规?(欧洲,仅仅是为了相对简单,但我提出的大多数建议也应被视为对用户的基本体面。)
无论如何,我坐下来做了我通常在面对法律麻烦时会做的事情。我写下了这些内容。
实施 ChatBot 101
1. 选择聊天机器人
尽管这听起来很简单,但这远非一个微不足道的问题。选项众多,包括使用开源代码自行构建聊天机器人。[1] 使用市场上提供的众多聊天机器人 API 中的一个,允许你进行最简单、最快速的准备设置。[2] 基于这些 API 之一对自己的聊天机器人进行微调。[3] 使用各种聊天机器人工具对聊天机器人进行微调。[4] 或者通过选择聊天机器人即服务(Chatbot as a Service)来支付别人代劳。[5]
选择这些选项中的任何一个都不无波及效应。这些波及效应当然包括设置机器人的性能和灵活性,但也包括遵守法律义务的特定要求。例如,从头开发自己的机器人或完全依赖开源代码,数据保护方面绝对是最安全的选项,因为你掌控所有的训练数据,并且数据不会流向其他地方。然而,这也有其缺点,只有在拥有足够的专家资源来确保设置和运行的同时保证一定的性能水平时,才应当考虑这个选项。相反,依赖 API 总是涉及一定程度的数据泄露风险。更不用说你依赖于别人的表现,并且至少在第一时间对他们的错误负责(GDPR 共同控制警告)。当使用其他工具进行微调时,情况当然会变得更加复杂。
最简单的选择可能是将麻烦留给别人,直接购买产品或服务。然而,除了这是最昂贵的方式(尤其是如果你想要一个高度个性化的机器人),这个选项也有其陷阱,因此在选择要雇佣的具体机器人时需要非常小心,同时考虑所有公开共享的数据处理实践、使用的训练数据等信息。否则,你可能会再次陷入麻烦,因为未能遵守应尽的尽职调查义务。
2. 细调聊天机器人
一旦你选择了你的机器人,并且假设你选择了一个包括一些细调的选项,恭喜你!你刚刚从火坑跳入了火炉。不论你是使用自动细调工具还是利用开源代码,卷起袖子自己动手,喂入模型的数据与你选择模型一样重要。
我们都已经熟悉了垃圾进垃圾出的整个议程,但还有另一个可能更重要的议程需要考虑。那就是法律问题的风险。我们已经通过艺术家和报纸对大型语言模型提供者提起的诉讼对这一概念有了一定的了解。而且,很可能的情况是,一旦法律情况在那里得到澄清,这些诉讼可能会扩展到任何 1. 使用他们的产品或服务的人以及 2. 做类似事情的人。关键的启示当然是,跟踪该领域的法律发展,不要用(可能)违法的数据来训练你的模型。我们还可以增加一个额外的建议,避免在任何情况下将个人数据喂给你的模型。暂且不谈版权问题,在非绝对必要的情况下使用个人数据总会让你陷入麻烦。
最后一个需要考虑的可能性和潜在问题是,如今你甚至不需要对模型进行细调。你可以通过进一步的 API 调用或网站调用来不断进行所谓的细调,从中获取机器人响应的数据。如果是这种情况,确保遵守原始网站提供者对数据使用的任何限制。这些限制可以以 robots.txt 文件的形式出现,也可能在其条款和条件中说明。是的,即使是爬取和链接也有其限制。
3. 免责声明
如果有一件法律专家永远不嫌多的事情,那就是‘免责声明’。因此,确保在你的聊天机器人中实现足够数量的免责声明。两个绝对不可妥协的要点是,与 AI 系统互动的人需要在开始互动之前了解这一事实,以及了解输出可能不准确且不应被依赖。这两个点可以很好地打包成一个弹出窗口,但也应在网站上的某个地方持续可见,或者用户可以被反复提醒它们的存在。这里“宁可过于透明,也不要后悔”是适用的。
隐私通知也是如此,整个通知本身就是一种免责声明。虽然大型语言模型的工作原理需要计算机科学学位才能略微理解,但你仍然需要尽量在隐私通知的有限范围内使其易于理解。想象一下如何向你的六岁孩子或可能是你的祖父母解释模型的工作原理,然后从那里开始。图片、视频和图形都非常受欢迎。另一方面,如果你使用了第 1 步中提到的任何 API 或自动化工具,你当然可以链接相关服务提供商的隐私通知,但这并不意味着你就此免责。在这种特定情况下,你是提供服务的人,是处理问题和投诉的首要联系人。因此,你有责任解释用户数据的流向、为什么这样做是必要的以及如何停止处理。这再次需要一些技能和创造力,以便做到透明和适当。祝你好运解决这个难题!
4. 输出
现在我们终于到了输出部分,所以我们肯定要接近尾声了。如果你这么想,那你是对的!至少在某种程度上是如此。这个问题仍然是一座需要单独攀登的山峰。除了已经提到的免责声明,说明结果可能不准确之外,还有一些其他需要考虑的因素,因为可能出现不准确性的原因有很多。第一个原因当然是大语言模型臭名昭著的幻觉,由于它们对我们如此慷慨喂给它们的数据本质上缺乏理解。此外,除了祈祷一些非常聪明的人能找出解决方法,我们别无他法,除了实施我们的免责声明。
然而,另一方面,我们面临的情况有所不同,这适用于所有爬取其他网站以查找和输出信息的聊天机器人。因此,你现在需要问自己,如果抓取的信息是虚假的或甚至是非法的,会发生什么。对于这种情况,最好依赖于所谓的“托管例外”,即现已陈旧的电子商务指令第 14 条中的内容。例如,这一例外也适用于搜索引擎,保证了主机和中介不对他们仅提供访问的内容承担责任。不过,这仅在内容不是明显违法的情况下适用。因此,为了最大程度地简化这一点。首先,只爬取和抓取你事先检查过的可信信息来源(不要试图充当谷歌)。其次,确保在所有聊天机器人的输出中集成参考,以便所有信息的原始来源都能立即可见。
还有一件值得考虑的事,并值得投入额外的编码时间,即集成跟进问题以应对用户初始输入非常广泛或不明确的情况。通过这种方式,你的机器人可以重新提示用户,以便用户提供更好的提示。这将反过来使模型产生更好的输出,无论是准确性还是性能方面。
5. 质量重于速度
为了再次强调这一点,因为它似乎总是归结于此。特别注意你机器人输出的质量,因为这是它们最显著和最明显的问题之一。在意大利 ChatGPT 临时禁令的争议中,不准确的输出旨在证明训练数据的不准确性。[6] 幻觉,作为一种输出缺陷,一直以来都是主要关注点,也仍然阻碍着聊天机器人进入搜索引擎领域。[7] 我们甚至不会讨论算法偏见/垃圾进垃圾出的争论。[8]
除了幻觉,这些仍然是单独的难题,输出的准确性和质量可以通过特别注意训练数据的准确性和质量大大提高。以及这些数据的相关性。此外,如果你通过 API 调用或其他方式主动获取数据,你所获取的数据也应该进行准确性、代表性以及适当性的双重检查。最后,你应该有适当的机制来识别任何必要的更新或需要更新的数据集的变化,当然,也要有一些机制来适当地回应这些识别出的事件。
质量是一个持续关注的问题,而不是一次性完成的检查项目。这一切都需要成本,主要是时间上的,这使得开发过程变得更慢。然而,质量应始终优先于速度,因为并不是所有人都能承受‘快速行动和破坏’的代价。 [9 ] 至少在他们试图开发一个可持续和负责任的商业模式时不能如此。
最后的思考
尽管我一般提倡一种更具深思熟虑和负责任的创新方法,但似乎行业正在倾向于‘快速行动和破坏’的口号,跑步机转动得越来越快。没有人愿意在我们这个时代最重要的比赛中落后。然而,当 OpenAI 忙于开发 AGI,律师们争相定义什么是 AI 系统时,很多初创企业和商业人士心中似乎存在着一些相关性较小的问题。然而,一旦我们考虑到这些问题的规模及其影响的参与者数量,这些问题的琐碎性就会消失。其中一个问题是如何在尽量少踩法律泥潭的情况下开发和实施聊天机器人?希望这篇文章能帮助那些试图正确行事的人们走上正确的道路。

图片由 William Navarro 提供,来源于 Unsplash
[1 ] 2023 年最佳开源聊天机器人平台, botpress, 2022 年 7 月 14 日, botpress.com/blog/open-source-chatbots。
[2 ] Jesse Sumrak, 2023 年最佳聊天机器人 API, twilio BLOG, 2022 年 12 月 27 日, www.twilio.com/blog/best-chatbot-apis。
[3 ] Olasimbo Arigbabu, 微调 OpenAI GPT-3 以构建定制聊天机器人, Medium, 2023 年 1 月 25 日, medium.com/@olahsymbo/fine-tuning-openai-gpt-3-to-build-custom-chatbot-fe2dea524561。
[4 ] Ali Mahdi, 8 款聊天机器人替代品:哪个工具适合你?, Chatling, 2023 年 11 月 11 日, chatling.ai/blog/chatbot-alternatives。
[5 ] Allen Bernard, 2023 年你应该了解的 10 大聊天机器人提供商, CMS WIRE, 2023 年 3 月 10 日, www.cmswire.com/digital-experience/10-chatbot-providers-you-should-know-about/。
[6 ] GPDP, 人工智能:监管机构阻止 ChatGPT。非法收集个人数据。缺乏未成年年龄验证系统, 2023 年 3 月 31 日, www.gpdp.it/web/guest/home/docweb/-/docweb-display/docweb/9870847#english。
[7] 威尔·道格拉斯·赫文,聊天机器人有一天可能会取代搜索引擎。这就是为什么这是个糟糕主意的原因,《麻省理工科技评论》,2022 年 3 月 29 日,www.technologyreview.com/2022/03/29/1048439/chatbots-replace-search-engine-terrible-idea/。
[8] 伊莎贝尔·布斯凯特,人工智能的崛起将焦点集中在算法的偏见上,《华尔街日报》,2023 年 3 月 9 日,www.wsj.com/articles/rise-of-ai-puts-spotlight-on-bias-in-algorithms-26ee6cc9;拉胡尔·阿瓦蒂,垃圾进,垃圾出(GIGO),TechTarget,www.techtarget.com/searchsoftwarequality/definition/garbage-in-garbage-out。
[9] 比阿特丽斯·诺兰,硅谷有了一个新的版本的“快速行动和打破常规”口号,商业内幕,2023 年 12 月 6 日,www.businessinsider.com/silicon-valley-move-fast-and-break-things-sam-altman-openai-2023-12。
ChatGPT 高级数据分析用于自定义 Matplotlib 井眼日志图表
使用 OpenAI 的代码解释器创建用于岩石物理和地球科学解释的井眼日志图表
·发表于Towards Data Science ·15 分钟阅读·2023 年 9 月 18 日
--

ChatGPT 的代码解释器,现在更名为高级数据分析,已经推出了一段时间。它于 2023 年 7 月 6 日发布,是一个由 OpenAI 开发的插件,允许用户上传数据并进行分析。这包括数据清理、创建可视化和数据总结等。
与其依赖你编写 Python 代码来分析数据,你可以通过用简单的英语告诉 ChatGPT 你想做什么。然后,它将为你进行分析。
正如我的许多常规读者所知道的,我是matplotlib的大粉丝。尽管这个库看起来笨拙且耗时,但只需一点努力,就能用它创建令人惊叹的可视化图表。
在玩弄这个新工具之后,我觉得是时候看看如何使用ChatGPT和高级数据分析插件来创建自定义的井眼日志数据图表了。
在继续之前,鉴于针对 OpenAI 的法律案件不断增加:
始终对上传到 ChatGPT 的数据保持谨慎,因为这些数据和你的输入可能会被用于训练未来的模型。如果有疑问,避免上传任何数据,并始终遵循公司政策。
启用 ChatGPT 中的高级数据分析
要在 ChatGPT 中使用高级数据分析插件,首先需要启用它。
这可以通过进入设置并选择 Beta 功能来完成。在这里,你将看到启用高级数据分析的选项,这将启用该插件。

在 ChatGPT 中启用高级数据分析插件。图片由作者提供。
插件现在将在你开始新的聊天时可用。
将数据上传并转换为 Pandas 数据框
首先,我们需要上传我们的文件。对于这个例子,我使用了来自NLOG 数据库的测井数据集(文章末尾有详细信息)。这个数据集包含从一个油气勘探井获得的一系列测井数据。
首先,我们点击聊天输入框上的加号图标,然后选择包含我们数据的文件。

高级数据分析插件的聊天输入框包括一个用于上传数据的加号 (+) 图标。图片由作者提供。
接下来,我们需要给 ChatGPT 提供一个提示。在这种情况下,我将告诉它扮演一个岩石物理学家(地球科学中的一个冷门角色),并且它理解测井数据。这可以帮助微调 ChatGPT 返回的一些响应。

初始提示给 ChatGPT,包含测井数据集。图片由作者提供。
提交上述输入后,ChatGPT 将加载文件并查看内容。
以下是我们从 ChatGPT 的高级数据分析插件获得的响应,它提供了有关数据集中每个测量项的信息。
在这种情况下,它使用 pandas 将 CSV 文件读取到数据框中,然后使用常见的df.head()命令输出表头。

ChatGPT 分析的测井数据集的初步分析和总结。图片由作者提供。
关于文件中曲线的评论大多是正确的。然而,也有一些不一致之处,比如 DT 被称为 Delta-Time。这部分正确;然而,这条曲线表示的是声波压缩速度,是通过声学工具发射的声波在地层中传播的压缩到达时间的测量。
在我使用 ChatGPT 的经验中,我发现它在处理诸如岩石物理学等冷门主题时可能会遇到挑战,并可能返回不正确的信息。这是每个人在使用大型语言模型(LLMs)时都应该注意的事情。
始终仔细检查输出,确保其合理。
我之前在下面的文章中讨论过这个话题:
## OpenAI 的惊人 ChatGPT:它对冷门话题有希望吗?
向 ChatGPT 询问岩石物理学
[towardsdatascience.com
点击“显示工作”下拉框,我们可以检查用于加载数据的代码。我们可以看到它进行了基本的 Python 编程,将我们的 CSV 文件读取到一个pandas 数据框中。
import pandas as pd
# Load the data from the provided CSV file
well_data = pd.read_csv("/mnt/data/L0509WellData.csv")
# Display the first few rows of the data
well_data.head()
ChatGPT 还识别出一些曲线中的占位符值为-999。这些值代表缺失的数据。然而,正如我们将看到的,这可能会引起一些混淆和问题。
使用代码解释器进行数据探索
在处理任何数据集时,最耗时的部分是理解你拥有的数据,进行数据质量检查以及清理数据。
这是我个人认为 ChatGPT 的高级数据分析插件对岩石物理学家和数据科学家最有帮助的地方。然而,像往常一样,你必须检查这些 AI 系统产生的结果和工作,因为它们可能无意中犯了一些错误。
在数据清理步骤的实验过程中,ChatGPT 在将-999 值转换为 NaNs 时遇到了困难。在进行此操作时,它还将列的数据类型转换为字符串。
结果,我必须在请求中具体说明,以确保它保持列为数值型。

更具体的提示以让 ChatGPT 高级数据分析插件将-999 值转换为 NaNs。图片由作者提供。
它最终返回了以下响应:

ChatGPT 对将-999 值替换为 NaNs 的响应。图片由作者提供。
以及以下代码
# Replace -999 values with NaN
well_data.replace(-999, float("nan"), inplace=True)
# Display the first few rows of the updated data
well_data.head()
这生成了以下输出:
RESULT
DEPTH GR DT RHOB DRHO NPHI
0 4609.8008 89.059479 <NA> 2.483700 0.018494 0.129119
1 4609.9008 90.721848 <NA> 2.489308 0.013656 0.108034
2 4610.0008 90.709061 <NA> 2.501088 0.011289 0.085650
3 4610.1008 89.568954 66.223099 2.515450 0.008615 0.070332
4 4610.2008 88.169571 66.705551 2.530982 0.005628 0.065343
这很有趣且有点奇怪,因为响应试图将字符串nan转换为浮点数。我本来会在数据框的替换函数中使用np.nan来用 NaNs 替换-999 值。然而,目前这似乎有效,但以后会导致一些问题。
生成描述性总结
接下来,我们来看看是否可以使用以下简单提示获取数据集中每个测量值的描述性总结表。

ChatGPT 提示生成一个关于测井数据集的简单描述性总结。图片由作者提供。
ChatGPT 在其响应中返回了以下表格。

ChatGPT 总结表提供了测井数据的信息和统计数据。图片由作者提供。
初看起来,这可能和df.describe()函数的输出相同;然而,ChatGPT 还为每个测量值添加了单位。如果我们想把这些内容放入报告中,这很方便。尽管将单位放在左侧的行头而不是每个单元格中会更好。
此外,它将孔隙度单位从小数转换为百分比。这在报告或将信息传递给同事时可能会造成误导。我希望它能将数据保持在原始单位。
最终,由于存在空值,它未能为 DT 曲线创建统计数据——这与之前 ChatGPT 未能将-999 值转换为空值的部分相关。然而,如果我们使用df.describe()方法,并且数据是数值型的,我们仍然可以看到该曲线的统计数据。
对于不了解情况的数据分析师,这可能会被忽视而未被发现。
使用 Matplotlib 和 ChatGPT 高级数据分析插件创建测井图
当我第一次开始在 Medium 上写文章时,我专注于如何使用 matplotlib 创建基本的测井图以及如何使用 Python 处理测井数据。由于我对 Python 相对陌生且在编写代码时遇到困难,这个过程花费了很多时间。
本质上,测井图由子图组成,这些子图通常称为轨迹。在每个轨迹中,不同的科学测量结果被绘制出来,这些测量可以结合使用以制定地下解释。每个测量结果都以深度为 y 轴进行绘图。
作为对 ChatGPT 高级数据分析插件的快速测试,我想看看它是否可以通过我之前的文章中的几个简单提示来帮助重新创建测井图。

使用 matplotlib 生成的示例测井图。作者提供的图片。
第一步是让它创建一个基本的测井图,其中每个测量结果都显示在各自的子图中。

生成测井图的初始提示。作者提供的图片。
在第一次处理和编写代码后,ChatGPT 再次遇到了数据集中的 NaN 值问题。因此,它不得不重新创建图表。

使用 ChatGPT 的高级数据分析插件生成的基本测井图。作者提供的图片。
一旦图表生成,它将返回一个相对不错的首次测井数据图。
有趣的是,它对图表还做了一些调整,包括:
-
反转 y 轴,以便从浅层深度到深层深度
-
对一些测量应用了标准刻度,例如 GR 的刻度范围为 0 到 150
ChatGPT 生成的代码如下:
# Setting up the figure and axes again
fig, ax = plt.subplots(nrows=1, ncols=6, figsize=(15, 10), sharey=True)
# Plotting DEPTH on the first track
ax[0].plot(well_data['DEPTH'], well_data['DEPTH'], color='black')
ax[0].set_title('DEPTH')
ax[0].set_ylabel('Depth (m)')
ax[0].invert_yaxis() # Invert the y-axis to have depth increasing downwards
ax[0].grid()
# Plotting GR on the second track
ax[1].plot(well_data['GR'], well_data['DEPTH'], color='green')
ax[1].set_title('GR (Gamma Ray)')
ax[1].set_xlim(0, 150) # Typical scale for GR
ax[1].grid()
# Plotting DT on the third track (excluding NaN values)
ax[2].plot(well_data['DT'].dropna(), well_data['DEPTH'].loc[well_data['DT'].dropna().index], color='blue')
ax[2].set_title('DT (Delta Time)')
ax[2].grid()
# Plotting RHOB on the fourth track
ax[3].plot(well_data['RHOB'], well_data['DEPTH'], color='red')
ax[3].set_title('RHOB (Bulk Density)')
ax[3].set_xlim(1.5, 3.0) # Typical scale for RHOB
ax[3].grid()
# Plotting DRHO on the fifth track
ax[4].plot(well_data['DRHO'], well_data['DEPTH'], color='cyan')
ax[4].set_title('DRHO')
ax[4].grid()
# Plotting NPHI on the sixth track
ax[5].plot(well_data['NPHI'], well_data['DEPTH'], color='magenta')
ax[5].set_title('NPHI (Neutron Porosity)')
ax[5].set_xlim(-0.15, 0.45) # Typical scale for NPHI
ax[5].grid()
plt.tight_layout()
plt.show()
这段代码与 我早期文章中的代码 非常相似,但不是使用 subplot2grid,而是访问单独的轴并添加图表。
我们现在可以对这个图进行修改,一次进行多个更改。
在使用 ChatGPT 时,我发现将提示分解为单独的指令有助于获得所需的输出。然而,为了观察 ChatGPT 的代码解释器如何处理多个命令,我们可以使用以下提示来进行更改:
-
移除包含深度曲线的子图
-
向每个子图添加网格线
-
为每个子图添加右侧和顶部的脊线
-
为每条曲线添加标准颜色
-
确保所有曲线都有定义的 x 轴刻度范围

在 ChatGPT 提示中进行多个请求以改善井日志图。图像由作者提供。
在进一步处理后,ChatGPT 返回了包含我请求的更改的图。

根据建议的更改后的基本井日志图。图像由作者提供。
它还将代码从对每个轴的单独调用重构为使用循环遍历每个测量值,并使用一些公共函数(如脊线和网格颜色)来应用更改。
# Setting up the figure and axes with one less subplot for depth
fig, ax = plt.subplots(nrows=1, ncols=5, figsize=(15, 10), sharey=True)
# Define a dictionary for colors and scales for each log
colors = {
"GR": "green",
"DT": "magenta",
"RHOB": "red",
"DRHO": "grey",
"NPHI": "blue"
}
scales = {
"GR": (0, 150),
"DT": (140, 40),
"RHOB": (1.95, 2.95),
"DRHO": (-0.2, 0.2),
"NPHI": (0.45, -0.15)
}
# Plotting each log with the specified colors and scales
logs = ['GR', 'DT', 'RHOB', 'DRHO', 'NPHI']
for i, log in enumerate(logs):
ax[i].plot(well_data[log], well_data['DEPTH'], color=colors[log])
ax[i].set_xlim(scales[log])
ax[i].set_title(log)
ax[i].grid(color="lightgrey")
# Adding top and right spines
ax[i].spines['top'].set_visible(True)
ax[i].spines['right'].set_visible(True)
ax[0].set_ylabel('Depth (m)')
ax[0].invert_yaxis() # Invert the y-axis to have depth increasing downwards
plt.tight_layout()
plt.show()
在 Y 轴和一条线之间添加变量填充
在日志图中,通常会在曲线和子图边缘之间应用阴影和额外填充。这有助于岩石物理学家通过可视化来理解数据。
一种方法是为伽马射线(GR)应用变量阴影,以帮助识别页岩和非页岩区间。为此,我们通常使用黄色 — 橙色 — 棕色色图,这可以帮助将干净的储层区间显示为更亮的颜色,将更具页岩特征的区间显示为更暗的颜色。
我在下面的早期文章中详细介绍了这个过程。
使用 matplotlib 对井日志数据应用颜色填充
towardsdatascience.com](/enhancing-visualization-of-well-logs-with-plot-fills-72d9dcd10c1b?source=post_page-----b857c68fa42--------------------------------)
我以前尝试过使用 ChatGPT 和代码解释器,但总是失败。也许这与我如何编写提示有关。
让我们看看 ChatGPT 是否可以通过在提示中更具体地创建一个 YlOrBr 颜色地图阴影来帮助 GR 曲线。

ChatGPT 提示以创建 GR — 伽马射线曲线的变量阴影。图像由作者提供。
我们得到以下图:

ChatGPT 数据分析插件生成 GR 曲线的变量阴影。图像由作者提供。
在第一次尝试中,它成功地在子图上填充了颜色。然而,它没有从子图的左侧到线条进行填充。
我们可以看看 ChatGPT 是否会纠正这一点:

提示仅从子图的左侧到曲线应用填充。图片由作者提供。
并且它确实做到了,这非常令人惊讶,因为之前的尝试都失败了。

ChatGPT 为变量填充修正图表后的日志图。图片由作者提供。
我们可以仔细查看代码,了解它是如何生成图表的。
# Setting up the figure and axes again
fig, ax = plt.subplots(nrows=1, ncols=5, figsize=(15, 10), sharey=True)
# Plotting each log with the specified colors and scales, excluding NaN values for DT
for i, log in enumerate(logs):
valid_indices = well_data[log].dropna().index
ax[i].plot(well_data[log].loc[valid_indices], well_data['DEPTH'].loc[valid_indices], color=colors[log])
ax[i].set_xlim(scales[log])
ax[i].set_title(log)
ax[i].grid(color="lightgrey")
# Adding top and right spines
ax[i].spines['top'].set_visible(True)
ax[i].spines['right'].set_visible(True)
# Adding the fill for the GR subplot
if log == 'GR':
# Add color based on the value of the GR line
for depth, value in zip(well_data['DEPTH'].loc[valid_indices], well_data[log].loc[valid_indices]):
ax[i].fill_betweenx([depth-0.05, depth+0.05],
0, value,
facecolor=plt.cm.YlOrBr(value/150),
edgecolor='none',
zorder=0) # Place the fill behind the line
ax[0].set_ylabel('Depth (m)')
ax[0].invert_yaxis() # Invert the y-axis to have depth increasing downwards
plt.tight_layout()
plt.show()
ChatGPT 采用的方法不同于 我之前的尝试,那种方法是逐一遍历每个排序值并应用填充。而 ChatGPT 是逐层处理,然后应用填充。
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
left_col_value = 0
right_col_value = 150
#assign the column to a variable for easier reading
curve = well_data['GR']
#calculate the span of values
span = abs(left_col_value - right_col_value)
#assign a color map
cmap = plt.get_cmap('YlOrBr')
#create array of values to divide up the area under curve
color_index = np.arange(left_col_value, right_col_value, span / 100)
#setup the plot
well_data.plot(x='GR', y='DEPTH', c='black', lw=0.5, legend=False, figsize=(6,15))
plt.ylim(4800, 4600)
plt.xlim(left_col_value, right_col_value)
plt.title('Plot With a Variable Fill to Y-Axis')
#loop through each value in the color_index
for index in sorted(color_index):
index_value = (index - left_col_value)/span
color = cmap(index_value) #obtain colour for color index value
plt.fill_betweenx(well_data['DEPTH'], 0 , curve, where = curve >= index, color = color)
plt.show()
上述代码生成了带有阴影的 GR 图。

使用 GR 曲线进行颜色填充。 图片由作者提供。
将两条线移动到同一子图上
在日志图上,通常将体积密度 (RHOB) 和中子孔隙度 (NPHI) 绘制在同一轨道上。由于这两条曲线具有不同的刻度,我们需要将其中一条放置在次要 x 轴上。
在 matplotlib 中准确做到这一点可能很棘手,因此我们将看看 ChatGPT 如何处理以下提示。

提示请求将 NPHI 放置在与 RHOB 相同的子图上,并设置为次要轴。图片由作者提供。
返回的图表还不错。ChatGPT 成功地将 NPHI 曲线放置在与 RHOB 相同的子图上,并将其放置在次要 x 轴上。然而,子图的标签相互重叠,尚不清楚哪个刻度属于哪些曲线。
此外,我们现在在子图上有两组网格线,这可能会造成混淆。

ChatGPT 在将 NPHI 移动到与 RHOB 相同的子图后的返回日志图。图片由作者提供。
我们可以通过简单的提示修改它,以确保标签不会重叠,并从其中一个曲线中移除网格线。

调整子图标签和网格线的提示。
ChatGPT 完成了它被要求的任务,但它还为不同的标签添加了颜色。
我没有预期结果会这样做,然而,这确实使我们能够轻松地将标签与正确的线条链接起来。

ChatGPT 数据分析插件在调整 NPHI 和 RHOB 曲线的标签后生成了一个井日志图。图片由作者提供。
因此,基于 ChatGPT 更改标签颜色的情况,我们可以以此为灵感,将其余标签更改为与线条颜色匹配。
如果我们在同一个子图上有多条线,这将对我们有所帮助,这在日志图表中非常常见。

调整所有标签和刻度标签以匹配其各自曲线颜色的提示。图片由作者提供。
ChatGPT 现在已经回应了我们的最终图表,我们现在可以轻松地看到 NPHI 和 RHOB 的刻度及其所属的线条。

ChatGPT 和数据分析插件生成的最终井日志图表。图片由作者提供。
以下代码是 ChatGPT 提出的,用于生成上述图表。
# Setting up the figure and axes again
fig, ax = plt.subplots(nrows=1, ncols=4, figsize=(15, 10), sharey=True)
# Plotting each log with the specified colors and scales, excluding NaN values
for i, log in enumerate(logs):
valid_indices = well_data[log].dropna().index
ax[i].plot(well_data[log].loc[valid_indices], well_data['DEPTH'].loc[valid_indices], color=colors[log])
ax[i].set_xlim(scales[log])
# Remove subplot title
ax[i].set_title("")
# Add axis label representing the name of the log measurement
ax[i].set_xlabel(log, color=colors[log])
# Color tick marks and tick labels in the same color as the line
ax[i].tick_params(axis='x', colors=colors[log])
# Adding top and right spines
ax[i].spines['top'].set_visible(True)
ax[i].spines['right'].set_visible(True)
# Adding the fill for the GR subplot
if log == 'GR':
# Add color based on the value of the GR line
for depth, value in zip(well_data['DEPTH'].loc[valid_indices], well_data[log].loc[valid_indices]):
ax[i].fill_betweenx([depth-0.05, depth+0.05],
0, value,
facecolor=plt.cm.YlOrBr(value/150),
edgecolor='none',
zorder=0) # Place the fill behind the line
ax[i].grid(color="lightgrey")
# Adding NPHI to the RHOB subplot with a secondary x-axis
if log == 'RHOB':
ax2 = ax[i].twiny() # Create a secondary x-axis for NPHI
valid_indices_nphi = well_data['NPHI'].dropna().index
ax2.plot(well_data['NPHI'].loc[valid_indices_nphi], well_data['DEPTH'].loc[valid_indices_nphi], color=colors['NPHI'])
ax2.set_xlim(scales['NPHI'])
ax2.set_xlabel('NPHI', color=colors['NPHI'])
ax2.tick_params(axis='x', colors=colors['NPHI'])
# Remove gridlines for NPHI and display the ones for RHOB
ax2.grid(False)
ax[i].grid(color="lightgrey")
ax[0].set_ylabel('Depth (m)')
ax[0].invert_yaxis() # Invert the y-axis to have depth increasing downwards
plt.tight_layout()
plt.show()
图表和代码看起来合理,我会做一些进一步的修改。然而,我认为此时继续使用 Jupyter Notebooks 修改图表是最好的选择。
这是因为我之前遇到的一些问题,如果我回去更改了早期的提示,它会删除该提示之后的任何内容并重新创建它。
如果我将 ChatGPT 窗口打开了几小时甚至几天,并且它超时,这尤其成问题。
总结
总的来说,我发现 ChatGPT 的数据分析插件(以前称为代码解释器)在生成岩石物理和地球科学的井日志图表方面是一个有用的工具。然而,我对使用它有几个保留意见和问题。
我发现很难在新的聊天实例中复制结果。我之前尝试过相同的过程,结果却完全不同,而且 ChatGPT 很难生成我想要的图表。这甚至发生在使用完全相同的提示时。
有时,数据分析插件生成的结果是值得怀疑的,甚至是错误的。与任何大型语言模型一样,审查输出并确保其在编程和技术上是合理的始终是明智的。
如果你在提示中出现任何错误,回去修改它们并不容易。如果你尝试更改其中一个提示,可能会导致该提示之后的任何聊天内容被删除。因此,我建议你在进行过程中将代码复制到 Jupyter Notebook 中,以免丢失信息。
最终,我对使用 ChatGPT 和数据分析插件的最大保留意见 是上传专有数据的难易程度(在这个例子中,我使用的是公开数据,允许使用)。数据、提示和输出都可能被用来训练未来的模型,而你可能并不知情。版权和知识产权问题日益增加,在使用此工具和专有数据时建议极为谨慎。
很高兴听到你对使用 ChatGPT 和数据分析插件的评论、想法和担忧。
感谢阅读。在离开之前,你一定要订阅我的内容,以便将我的文章发送到你的邮箱中。 你可以在这里订阅!此外,如果你喜欢这些内容并希望表达你的感谢,请考虑给予一些点赞。
本示例中使用的数据
本教程中使用的数据是从NLOG.nl下载的,该网站包含了整个荷兰北海区域的井日志数据。这些数据可以免费下载和使用。有关数据许可的完整细节可以在这里找到。
ChatGPT 与我们面临的未来(现在)
原文:
towardsdatascience.com/chatgpt-and-the-future-present-were-facing-9f2a1cfab0e9
意见
2023 年将比 2022 年更加紧张和压倒性,所以请系好安全带。
·发布于 Towards Data Science ·10 分钟阅读·2023 年 2 月 2 日
--

版权:Midjourney
直到 ChatGPT 不再是 AI 领域最重要的新闻,我想我们还会继续讨论它……开玩笑的,我会确保穿插其他话题,否则我们可能会感到疲惫。
关于 ChatGPT 的直接和长期影响仍有很多话题要讨论。我已经写过关于 ChatGPT 是什么 和 如何最大化利用它,关于 识别其输出的挑战,以及 它对 Google 和传统搜索引擎的威胁,但我还没有涉及到那些 一些人预见到的风险和危害 在现实世界中已经开始显现。
发布两个月后,我们都可以同意 ChatGPT 已经进入主流,并将 AI 领域带入了公众视野。作为一个小故事,一个对 AI 一无所知的朋友在我告诉他之前,就已经开始谈论 ChatGPT 了。这对我来说是第一次——而我并不是唯一一个。
这就是为什么讨论 AI 的后果非常紧迫的原因:ChatGPT 比任何关于如何有效使用它或绝对不应该如何使用它的资源传播得都要快。如今使用 AI 工具的人数比以往任何时候都多(不仅仅是 ChatGPT;Midjourney 在 Discord 服务器上几乎有 1000 万会员),这意味着比以往任何时候都更多的人会误用这些工具。
与我预测/推测性的文章不同,这篇文章不是关于可能 发生的事情,而是关于正在发生的事情。我将重点讨论 ChatGPT,因为这是世界上讨论的焦点,但以下大部分内容也适用于其他类型的生成 AI,只要进行适当的翻译。
这篇文章摘自 The Algorithmic Bridge,这是一本教育性新闻通讯,旨在弥合 AI、算法与人之间的鸿沟。它将帮助你理解 AI 对你生活的影响,并开发更好地应对未来的工具。
弥合算法与人之间的鸿沟。这是关于对你有意义的 AI 的新闻通讯。点击阅读《The…》
thealgorithmicbridge.substack.com
ChatGPT 的危害不再是理论上的
2024 年 1 月 6 日,安全研究小组 Check Point Research (CPR) 发表了一篇名为 “OpwnAI: 网络罪犯开始使用 ChatGPT” 的惊人文章。虽然并不意外,但我没想到会这么快。
CPR 此前研究过 恶意黑客、诈骗者和网络罪犯如何利用 ChatGPT。他们展示了聊天机器人如何“创建完整的感染流程,从网络钓鱼到运行反向 Shell”,以及如何生成动态运行的脚本,适应环境。
尽管 OpenAI 设置了警告框,当 CPR 强迫 ChatGPT 执行违反使用政策的操作时,会出现橙色警告通知,研究小组仍然顺利生成了一封简单的网络钓鱼邮件。他们总结道:“复杂的攻击过程也可以被自动化,使用 LLMs API 生成其他恶意伪造物。”

ChatGPT 生成的基本钓鱼邮件。信用:CPR(已授权)
CPR 研究人员对 ChatGPT 仅能在假设情况下执行这项任务的证据并不满意(怀疑者常常会收到这样的批评:他们警告的潜在风险从未转化为现实世界的危害)。他们希望找到人们以类似方式实际滥用它的真实例子。他们找到了。
CPR 分析了“几个主要的地下黑客社区”,发现至少有 三个具体例子 表明网络犯罪分子正在以违反服务条款且可能直接造成可测量危害的方式使用 ChatGPT。
首先是信息窃取者。在一个题为“ChatGPT — 恶意软件的好处”的帖子中,一位用户分享了他的实验,他“重建了许多恶意软件变种。”正如 CPR 所指出的,OP 的其他帖子表明,“这个人[旨在]向技术能力较差的网络犯罪分子展示如何利用 ChatGPT 进行恶意目的。”

“网络犯罪分子展示了他如何使用 ChatGPT 创建信息窃取工具。” 版权归 CPR(已授权)
其次是加密工具。一位名为“USDoD”的用户发布了一个包含“加密和解密功能”的 Python 脚本。CPR 认为,“这个脚本可以很容易地修改为完全加密某人的机器,而无需任何用户交互。”虽然 USDoD 的“技术技能有限”,但他“参与了各种非法活动。”

“网络犯罪分子 USDoD 发布了多层加密工具。” 版权归 CPR(已授权)
最后的例子是诈骗活动。帖子标题很有启发性:“滥用 ChatGPT 创建暗网市场脚本。” CPR 写道:“这些网络犯罪分子发布了一段代码,利用第三方 API 获取最新的加密货币 …… 价格,作为暗网市场支付系统的一部分。”

“威胁行为者使用 ChatGPT 创建暗网市场脚本。” 版权归 CPR(已授权)
很明显,ChatGPT 因为免费使用和高度直观,吸引了包括技术水平较低的网络犯罪分子。正如 Check Point 的威胁情报组经理 谢尔盖·谢基维奇 所解释的那样:
“正如 ChatGPT 可以被用来协助开发者编写代码,它也可以被用作恶意目的。尽管我们在本报告中分析的工具相当基础,但更复杂的威胁行为者提升他们使用基于 AI 的工具的方式只是时间问题。”
ChatGPT 作为在线安全问题的驱动因素并非恐慌者的假设,而是一个难以否认的现实。对于那些认为这在 ChatGPT 之前也可能发生的人,有两点:首先,ChatGPT 可以弥合技术差距。其次,规模在这里非常重要 —— ChatGPT 可以在几秒钟内自动编写脚本。
OpenAI 不应该这么快就让 ChatGPT 免费使用
网络安全、虚假信息、剽窃……许多人反复警告 ChatGPT 类 AI 可能引发的问题。现在恶意用户开始增多。
仍有人可能试图为 ChatGPT 辩护。也许它并不是那么有问题——利大于弊——但也许它是。一个“也许”应该足以让我们三思而后行。OpenAI 在 GPT-2 被认为是“无害”时放松了警惕(他们看到“迄今没有强有力的误用证据”),而且他们再也没有重新提高警惕。
我同意 Scott Alexander 的观点,“也许世界领先的 AI 公司无法控制他们的 AI 是一件坏事”。也许通过人工反馈的强化学习还不够好。也许公司应该找到更好的方法来控制他们的模型,如果他们打算将它们放出去。也许 GPT-2 并不那么危险,但过几次迭代后我们就有了需要担忧的东西。如果不是,我们在再过几次迭代后就会有。
我不是说 OpenAI 没有尝试——他们尝试过(甚至因过于保守而受到批评)。我争论的是,如果我们将这种“我已经尽力做好,所以现在我有了发布 AI 的绿灯”的心态延续到短期未来,我们将遇到越来越多的弊端,而没有任何好处可以弥补。
有一个问题困扰了我几个星期:如果 OpenAI 对做对事如此担忧,为什么他们没有在将模型公开之前建立水印方案来识别 ChatGPT 的输出呢?Scott Aaronson 仍在努力使其有效——一个月 之后 模型完全 viral 了。
我认为水印不会解决这种技术所涉及的根本问题,但它会通过争取时间来帮助。时间让人们适应,科学家们寻找解决最紧迫问题的办法,以及监管者制定相关立法。
GPT 检测器是最后一个(健康的)前沿。
由于 OpenAI 的不作为,我们只能依赖于勉强尝试建立 GPT 检测器,这些检测器可能为人们提供避免 AI 误信息、诈骗或钓鱼攻击的手段。一些人试图将3 年前的 GPT-2 检测器改用于 ChatGPT,但它不起作用。其他人,如Edward Tian,普林斯顿大学计算机科学与新闻学的高年级学生,已经从零开始开发了专门针对 ChatGPT 的系统。
目前,已有超过 10,000 人测试了 GPTZero,包括我在内(这是演示)。田正在打造一个已有3K+ 教师订阅的产品。我承认我曾经骗过它一次(仅仅因为 ChatGPT 拼写错误),但也没有太费劲。
该检测器相当简单;它评估一段文本的“困惑度”和“突发性”。困惑度衡量一句话有多“令检测器惊讶”(即输出单词的分布与语言模型预期的匹配程度),突发性则衡量困惑度在句子之间的一致性。简而言之,GPTZero 利用人类写作通常比 AI 更奇怪的事实 —— 这一点在你读到一页 AI 生成的文本时立刻显而易见。它如此乏味……
在一个<2%的误报率下,GPTZero 是最好的检测器。田自豪地说:“人们有权知道写作是否是人类所为,”他对《每日野兽》说。我同意 —— 即使 ChatGPT 不抄袭,人们声称自己是 ChatGPT 所写内容的作者也是道德上不对的。
但我知道它并非万无一失。一些对输出的更改(例如拼写错误或夹杂自己的内容)可能足以欺骗系统。要求 ChatGPT 避免重复单词效果很好,正如 Yennie Jun 在这里展示的。最后,GPTZero 可能很快会过时,因为新的语言模型每隔几周就会出现 —— AnthropicAI 非正式地宣布了 Claude,正如 Riley Goodside 的分析所证明,它比 ChatGPT 更好。
而且GPT-4就要来了。
这是一场猫鼠游戏,正如一些人所称 —— 而且老鼠总是领先一步。
禁止 ChatGPT:一个糟糕的解决方案
如果检测器运作正常,很多人会感到愤怒。大多数人希望无障碍地使用 ChatGPT。例如,学生们无法在书面作业中作弊,因为一位精通 AI 的教授可能知道检测器的存在(这已经发生过)。3K+ 教师订阅田的即将推出的产品这一事实说明了一切。
但是,由于检测器的可靠性不足,那些不愿面对必须猜测某些书面材料是否是 ChatGPT 产品的不确定性的人采取了最保守的解决方案:禁止 ChatGPT。
《卫报》周五报道了“纽约市学校禁止 ChatGPT”的消息。部门发言人珍娜·莱尔提到,“对学生学习的负面影响的担忧,以及对内容的安全性和准确性的担忧”是决定的原因。尽管我理解教师的观点,但我认为这不是一个明智的做法——这可能是更简单的选择,但不是正确的选择。
Stability.ai 的戴维·哈在新闻发布时发了这条推特:
我承认(之前也有做过)学校面临的问题(例如广泛的难以检测的抄袭),但我必须同意哈的观点。
这是一个困境:这项技术不会消失。它是未来的一部分——可能是很重要的一部分——而且学生(包括你、我和其他人)学习它是非常重要的。禁止 ChatGPT 进入学校不是解决方案。正如哈的推特所暗示的那样,禁止它可能比允许它更有害。
然而,使用它来作弊或写论文的学生会浪费教师的时间和精力,同时阻碍自己的发展而未意识到。正如莱尔所说,ChatGPT 可能会阻碍学生学习“批判性思维和解决问题的能力”。
我(以及许多人)预见的解决方案是什么?教育系统将不得不适应。尽管更困难,但这是更好的解决方案。鉴于学校系统的破损程度,这可能对学生和教师来说是双赢的局面。当然,不用说,在那之前教师最好有一个可靠的检测工具——但我们不应以此为借口来避免让教育适应这些变化的时代。
教育系统有很多改进的空间。如果它在这么多年里没有改变,那是因为没有足够强的动力去做。ChatGPT 给我们一个重新构想教育的理由,唯一缺失的部分是决策者的意愿。
人工智能是新的互联网
确实是这样。有些人将人工智能与火或电相比,但这些发明是逐渐融入社会的,且距离现在已经很久了。我们不知道那时候的感受。人工智能更像是互联网,它将迅速改变世界。
我在这篇文章中试图捕捉一个已经更像现在而非未来的未来。一方面,像 GPT-3 或 DALL-E 这样的人工智能确实存在,另一方面,世界上每个人都意识到了它们。这些假设(例如虚假信息、网络攻击、抄袭等)不再只是理论。它们正在发生在这里,现在,我们将看到更多绝望的措施来阻止它们(例如,建立简陋的检测器或禁止人工智能)。
我们必须假设有些事情将永远改变。但是,在某些情况下,我们可能不得不捍卫我们的立场(就像艺术家们在进行文本到图像转换时所做的,或少数群体在使用分类系统时曾经做过的)。无论你是谁,人工智能都会以某种方式影响你。你最好做好准备。
订阅 算法桥梁。弥合算法与人之间的差距。关于对你生活重要的人工智能的通讯。
你还可以直接在 Medium 上支持我的工作,并通过使用我的推荐链接成为会员来获得无限访问权限 这里! 😃
推动边界:ChatGPT 在粒子物理中的应用
探索 ChatGPT 在研究中的无限潜力
·发表于Towards Data Science ·阅读时间 6 分钟·2023 年 6 月 8 日
--

[Source] 在瑞士 CERN 的大型强子对撞机上的 CMS 探测器中的粒子碰撞。
上个月举行了对高能物理(HEP)社区至关重要的计算机会议之一:所谓的 CHEP 2023,代表Computing on High Energy Physics 和 Nuclear Physics — 是的,非常简单! 😃
作为一名在CERN工作的计算机工程师,这是一项重大事件:这是观察我们领域最新技术趋势的机会。尽管我完全了解 ChatGPT 目前的流行程度,但我没有预料到会有相关的讨论。但我完全错了,确实有几个!
我发现这些讲座非常吸引人,因此在这篇文章中,我想描述这些讲座的主要要点。ChatGPT 不仅在重塑我们的日常任务,还在影响像 HEP 这样的主要研究领域。
让我们探索一下未来的发展吧!
HEP 和 CHEP — 简短介绍
HEP 社区指的是全球范围内参与高能物理领域的科学家、研究人员、工程师、技术员和机构的网络。这个社区致力于研究物质的基本组成部分、支配它们相互作用的力量,以及探索宇宙的基本法则。
CHEP 是一系列会议,重点关注计算、软件和数据管理在 HEP 领域中的应用 — 还有核物理领域。

[Source] CHEP 2023 的会议图片。
实际上,CHEP 是一个相当古老的会议。第一次会议是在 1985 年举行的,从那时起,它每两年组织一次。总体而言,CHEP 会议在推动计算和数据管理的进步方面发挥了至关重要的作用。
CHEP 作为知识交流、合作和新计算技术探索的平台,因此我感到非常惊讶:如果 CHEP 上出现了什么东西,它很可能会成为一个新趋势! 在最近的 CHEP 2023 中,我们有两次关于 ChatGPT 在 HEP 中的全体会议。
准备好了吗?
ChatGPT 能做科学吗?
关于 ChatGPT 的第一次全体会议由来自杰斐逊实验室的David Dean主持,时间安排得非常早。题为计算的演变与革命:前沿科学,David 提供了计算领域最新革命的广泛概述。不用担心,ChatGPT 就是其中之一!
他具体探讨了ChatGPT 是否能做物理的问题,信息很明确:这是一个令人震惊的工具,它也能通过物理考试,但有一个重大缺陷可能会阻止 ChatGPT 在不久的将来被作为工具纳入使用:模型幻觉。

[来源] 原始 ChatGPT-4 技术报告的截图。ChatGPT 解决物理考试的分数以黄色高亮显示。
模型幻觉
尽管模型能够生成类似人类的回答,但有时它仍会倾向于编造事实、坚持错误信息,并错误地执行任务。这些错误的回答被称为幻觉。
事实上,给出错误答案本身并不是问题。主要问题是 ChatGPT 经常以令人信服和权威的方式表现出这些倾向。幻觉有时甚至表现为高度详细的信息,给读者带来错误的准确感,并增加了过度依赖的风险。这在研究界确实是一个问题。
为了将 ChatGPT 用作可靠的辅助工具,需要控制幻觉。目前,ChatGPT 会尝试对任何给定的问题提供答案,即使对目标主题没有足够的信息。
ChatGPT 承认不能提供准确回应并没有什么坏处,这会使该工具在如 HEP 研究等需要高准确性的环境中更为适用。
ChatGPT 作为 HEP 的编码助手
第二次涉及 ChatGPT 的全体会议题为AI/ML 启用的高能物理(HEP)未来的根本性变化,由来自威斯康星大学麦迪逊分校的 Kyle Crammer 主讲。
第二场演讲对 ChatGPT 作为 HEP 工具包中有价值资产的引入持更乐观态度。事实上,Kyle 提到另一场演讲中的Christian Weber,来自布鲁克海文国家实验室,他展示了 ChatGPT 作为编码助手的实际用例,特别是在迁移和转换代码到新平台方面。事实上,ChatGPT 已经实现了用于编码目的的 Python 解释器。

[来源] 截图来自官方 ChatGPT 文档。
HEP 社区中的每个实验都有自己特定的编码模板,即使使用 Python 编程,科学家们也必须遵循一些类或风格规范。其中一个用例是微调 ChatGPT 以根据实验模板编写分析代码。
被这个用例吸引,我尝试生成一个针对我当前实验的分析模板,即在瑞士 CERN 的CMS 实验,ChatGPT 完美生成了第一个模板。我只是使用了网络界面,想象一下经过相关数据微调后的强大功能。

自制截图。使用 ChatGPT 生成 CMS 实验的 Python 分析模板。
根据演讲,即使有时分析不够准确,它也能生成一个初步的模板或骨架。这个想法被探索用于提供更快的上手培训给新实验成员,并更快地构建原型,等其他用例。
摘要
我们不能否认,大型语言模型(LLMs)如 ChatGPT 正在改变我们搜索信息、构建应用程序甚至编程的方式。
与任何技术进步一样,我认为评估任何新工具以利用其优势并将其应用于我们的主要领域是合理的。这两个全体会议只是像 HEP 这样的大型研究社区中这一评估过程的两个例子。
虽然有些评估可能暂时抛弃 ChatGPT 作为研究助手,但其他评估可能允许在具体和有限的领域中纳入这样的工具。无论如何,我相信不必惧怕 AI,并继续与其共同发展,分析其优势,了解如何优化其在目标领域的表现,更重要的是,意识到其缺陷以保持批判精神始终警觉!
就这些了!非常感谢阅读!你能想到任何其他应用 ChatGPT 于研究领域的方法吗?
你也可以订阅我的通讯以便关注新内容。特别是,如果你对 ChatGPT 相关文章感兴趣:
## ChatGPT 知道你的信息:OpenAI 在数据隐私方面的历程
管理 ChatGPT 中个人数据的新方法
测试互联网上最受欢迎的 AI 检测器
提升 ChatGPT 互动中的用户体验
## ChatGPT 文本转语音:人工智能与 Python 数据科学
如果有任何问题,随时向我提问,你可以通过 forcodesake.hello@gmail.com 联系我 😃
ChatGPT:自动化提示评分
原文:
towardsdatascience.com/chatgpt-automated-prompt-scoring-c972f9ee2c4f

此图像是在 DALL·E 2 的协助下创建的
指南
如何使用 python 客观地选择和改进你的 ChatGPT 提示
·发表在 Towards Data Science ·10 分钟阅读·2023 年 4 月 10 日
--
大型语言模型(LLM)如 ChatGPT 正在产生巨大的影响。它们也仅仅是开始。在接下来的一年中,各大公司将开始推出领域/角色专用的 LLM 模型。实际上,像专注于金融的 BloombergGPT 和微软开发者专注的 Copilot 等新产品已经在成为现实。我们很快将看到 AI 个人教练、健康教练、顾问、法律助理等更多应用。虽然有些情况可能需要基于领域特定数据的微调模型,但大多数可以通过简单的提示工程实现。但你如何知道你的提示是否足够好?我们如何在主观文本上生成客观准确度评分?
本指南将涵盖:
-
理论
-
提示工程
-
提示测试
-
提示评分
-
提示反馈
理论
测试 LLM 提示输出的困难在于结果是主观的。我可能觉得结果完美,而你觉得它们不够好。这两种观点都是有效的。这使得纯科学的方法来评分非常困难。应对这些问题的一种好的方法是 Delphi 方法。它涉及使用专家小组并汇总他们的结果。正如你所想象的,这可能很昂贵,但这就是 AI 的作用所在!
我们将创建一系列不同的个性提示,通过使用 Python 中的 OpenAI API 让它们与我们的主要提示进行对话,并创建一个评分提示来评估我们的主要提示表现如何。最后,我们汇总每个个性-提示对话的评分。这在前期工作上稍微多一些,但从长远来看可以节省很多时间。它还提供了一个客观的评分。
这种方法确实存在一个问题,你需要注意。对同一个提示的多次运行将产生略有不同的结果。这是德尔菲方法表现出色的地方。你使用的多样的备用提示越多,你的汇总评分结果将越一致。你实际上可以通过测量评分偏差来测试是否有足够的备用提示。当偏差较小时,你就处于一个不错的状态。
我承认这不是一个完美的解决方案。对于主观问题,没有解决方案是完美的。我注意到,当我手动审查对话时,通常会同意结果。这使我可以测试许多提示,并获得“足够好”的评估,从而快速迭代。在我们实施这里描述的德尔菲测试方法之前,了解提示工程的基础是重要的。
提示工程
基础 LLM 模型提供不错的一般响应。但这并不总是理想的。例如,如果我在构建聊天机器人,我不想返回三段文字的内容。也许我希望我的回答听起来更像人类,更具对话性。提示工程涉及提供指令来引导 LLM 的输出风格、格式和行为。
让我们从一个例子开始。我们计划构建一个聊天机器人,它将回答问题,就像是一个“短小、绿色、尖耳的宇宙巫师,在一个遥远的银河系中使用激光剑。”如果这样的角色存在于科幻小说中,我可以通过使用他们的名字来简化这个提示,但这样会面临版权问题。在一点帮助下,LLM 应该能够理解我的意图。
关于提示工程的详细指南, 请点击这里。我将介绍一些基础知识:
-
要具体而简洁。例如,“用一句话向 4 年级学生解释引力”要比“请简短地解释地球上引力的工作原理,以便任何人都能理解”要好得多。
-
使用###来清晰地区分指令和输入/输出。例如:###你是孔子。用 1 到 2 句话回答问题。使用引号来引用孔子的真实话语。###
-
提供输出格式和示例。虽然 LLM 可能会根据我们对“短小、绿色、尖耳的宇宙巫师”的描述猜测我们在寻找什么,但我们不应该将其留给运气。通过引导对话,模型的表现会更好。让我们把所有内容整合在一起:
你是一个短小、绿色、尖耳的太空巫师,在遥远的星系中使用激光剑。不要自我介绍。用 1 到 2 句话回答问题。然后问一个跟进问题以保持对话的进行
你:“有什么困扰你的吗?”
我:“不要低估我的力量!告诉我反叛基地的位置!”
你:
“我不能透露位置。是什么让你加入了帝国,嗯?”
不错,但别太自满。找到一个好的提示可能需要一些试错过程。这可能会很耗时。随着升级模型的推出(例如,GPT4),提示行为可能会发生剧烈变化,需要调整。我们需要一种自动化测试这些提示的方法。
提示测试
对于自动化测试,我们将使用 ChatGPT API。我们将用 ChatGPT 与自己对话,而不是手动创建对话。我喜欢为不同的角色创建提示,这些提示反映了可能的使用场景(以及一些对抗性的提示)。以下是几个例子:
-
你在学校有一个暗恋对象,正在寻求导师的建议###
-
你明天有一个工作面试,正在寻找如何表现好的建议###
-
你是一个在遥远星系中使用激光剑的黑暗领主。你正在试图从敌人那里逼迫出反叛基地的位置。###
现在我们有了一些提示可以使用,让我们一块一块地生成对话,以便逻辑易于理解。让我们从暗恋的例子开始。稍后我们会清理代码。
import openai
openai.api_key = <your api key>
#Define prompts
wizard_prompt = [
"""### You are a short green pointy-eared space wizard that uses laser-swords
in a galaxy far far away. Do not identify yourself. Answer questions in 1 or 2 sentences.
Then ask a follow-up question to keep the conversation going. ###""",
"Something troubles you, it does?"
]
user_prompt = [
"### You have a crush at school and are seeking advice from a mentor###",
"I have a crush at school but I am too afraid to tell her."
]
#We start with the space wizard:
wizard_response = openai.ChatCompletion.create(
model = 'gpt-3.5-turbo',
messages=[
{"role": "user", "content": wizard_prompt[0]}, #wizard instructions
{"role": "user", "content": "Wizard: " + wizard_prompt[1]}, #wizard primed opening
{"role": "user", "content": "User: " + user_prompt[1]}, #users primed opening
{"role": "user", "content": "Wizard:"}, #requested response to user question
]
)
wizard_response = wizard_response['choices'][0]['message']['content'] #pulling the response from output
wizard_prompt.append(wizard_response) #appending response to prompt list
print(wizard_response)
“恐惧是通向黑暗面的道路,年轻人,要勇敢并说出你的真相。是什么让你觉得她可能有同样的感觉?”
# Now we build the "school crush" persona using the wizard response
user_response = openai.ChatCompletion.create(
model = 'gpt-3.5-turbo',
messages=[
{"role": "user", "content": user_prompt[0]}, #user instructions
{"role": "user", "content": "Mentor: " + wizard_prompt[1]}, #wizard primed opening
{"role": "user", "content": "You: " + user_prompt[1]}, #user primed opening
{"role": "user", "content": "Mentor: " + wizard_prompt[2]}, #wizard generated response
{"role": "user", "content": "You: "}, #requested response from wizard reply
]
)
user_response = user_response['choices'][0]['message']['content']
user_prompt.append(user_response) #appending user response to prompt list
print(user_response)
“好吧,她总是很高兴见到我,我们有很多共同点。我们有过一些很棒的对话和笑声。但如果她不感觉一样,我不想破坏我们的友谊。”
从这里,我们可以来回循环,将回应附加到提示列表中。这是我承诺的干净代码:
def GPT_call(prompts):
call = openai.ChatCompletion.create(
model = 'gpt-3.5-turbo',
messages=[{"role": "user", "content": prompt} for prompt in prompts]
)
return call['choices'][0]['message']['content']
#starting over with initial wizard/user prompts
def GPT_convo(chatbot_prompt,user_prompt):
chatbot_prompt = [chatbot_prompt[0],'Wizard: '+chatbot_prompt[1],'User: '+user_prompt[1]]
user_prompt = [user_prompt[0],'You: '+user_prompt[1]]
for i in range(3): # conversation goes for 3 iterations
chatbot_prompt.append("Wizard:")
response = GPT_call(chatbot_prompt)
chatbot_prompt[-1] = "Wizard: "+response
user_prompt.append("Mentor: "+response)
if i == 3: # Add instructions to end conversation on final iteration
user_prompt.append("###Try to wrap up the conversation### You:")
else:
user_prompt.append("You:")
response = GPT_call(user_prompt)
user_prompt[-1] = "You: "+ response
chatbot_prompt.append("User: "+response)
chatbot_prompt.append("Wizard:")
response = GPT_call(chatbot_prompt)
chatbot_prompt[-1] = "Wizard: "+response
return chatbot_prompt # return full conversation
现在我们有了一种自动生成每个用户角色对话的方法。但这些对话好吗?我们需要一种客观评估主要角色回应的方法。但是我们怎么可能用主观文本完成这一点呢?我们将为自我评估建立一个第三个 GPT 角色!
提示评分
为了评估我们“太空巫师”的对话技能,我们将创建一个新的提示:
在 10 分制下,根据以下标准为巫师的回应打分:
角色:巫师是一个短小、绿色、尖耳的太空巫师,在一个遥远的星系中使用激光剑。所有回应应符合这一角色。
对话:回应应简洁而对话化。跟进问题应推动对话向前发展,而不至于乏味。巫师应适当地结束对话。
有用:回应应帮助用户回答他们的问题或解决他们的问题。跟进问题应帮助收集信息以改善回应。
以 JSON 格式呈现分数,如下所示:
{“Character”:
,”Conversational”: ,”Helpful”: } 请提供不带评论的评分。###
conversation = GPT_convo(chatbot_prompt,user_prompt)
score = GPT_call(score_prompt + conversation)
import json
score = json.loads(score)
print(score)
我们现在有了分数!将这些分数汇总到所有角色中,我们可以得到我们测试的聊天机器人提示的整体分数。这使我们面临一个最终的问题:我们如何改进我们的提示以提高这些分数?再说一次,为什么不直接问 GPT 呢?这就是方法。
及时反馈
对于我们的反馈,我们将遵循以下模式:
-
从我们用来评分的相同提示开始,但省略要求返回分数的请求。
-
将生成的对话添加到提示中进行评估
-
给出生成对话评估的指示,基于评分标准
我们从提示开始:
根据以下对话内容,依据以下标准评估巫师:
角色:巫师是一个短小、绿色、尖耳的太空巫师,在一个遥远的银河系中使用激光剑。所有回应应符合这一角色。
对话风格:回应应该简洁且富有对话性。后续问题应推动对话向前发展,而不显得冗长。巫师应适当地结束对话。
有帮助:回应应通过回答用户的问题或解决他们的问题来帮助用户。后续问题应有助于收集信息以改善回应。###
这是对话:
conversations = feedback_prompt #prompt shown above
user_prompts = {'crush':user_prompt} #add all user personas
for user,prompt in user_prompts.items():
convo = GPT_convo(chatbot_prompt,prompt)
conversations.append("Conversation: " + str(convo[1:]))
conversations.append("""Give advice on how the Wizard can improve their responses
in relation to each criterion defined above: Character, Conversational, and Helpful.
Use examples from the conversation to make each point.""")
feedback = GPT_call(conversations)
print(feedback)
角色:总体而言,巫师很好地遵守了他们的角色设定,并始终使用与其角色相符的语言和短语。例如,“年轻的学徒”,“愿力量与你同在”,以及“悔恨带来痛苦”都是人们期望从太空巫师那里听到的短语。巫师还使用了特定于他们的遥远银河系的语言,例如“激光剑”。这里不需要改进。
对话风格:巫师在回应时要简洁且引人入胜,提出的后续问题应推动对话向前发展,而不会显得重复或乏味。例如,当用户说他们害怕向心仪的人表达感情时,巫师应接着问“是什么让你犹豫接近她?”这表明巫师在积极倾听用户并试图理解他们的观点。另一个好的例子是当用户说他们不知道对心仪的人说什么时,巫师可以回应“在这种情况下,相信你的感觉,力量会引导你。”这将重点放回用户身上,赋予他们提出自己解决方案的能力。总体而言,巫师的对话语气友好、支持且不带偏见。
有用:巫师的回应很有帮助,因为它们提供了具体的建议和可操作的步骤。例如,当用户说他们害怕拒绝时,巫师回应说“要有勇气,你必须……”和“最好还是冒险,你必须。”这给用户提供了鼓励和行动的推动。另一个有用的回应是,当用户说他们不知道对心仪对象说什么时,巫师建议简单地邀请喝咖啡或散步。这给用户提供了下一步要做的具体想法。总体来说,巫师提供了量身定制的、有针对性的建议。
这些信息很有用,但请小心,这是一个陷阱! 每个 GPT 模型都有最大令牌限制。这意味着,如果你的对话过大或对话过多,你可能会在这一步遇到错误。实际上,我通常会筛选出三个得分最低的对话,并将其提交进行评估。
结论
你现在拥有了自动生成每次提示迭代的分数和反馈的工具。当然,还有一些工作需要完成。你不仅可以调整你的聊天机器人提示,还可以调整用户和评估提示,以更好地匹配你想要实现的目标。重要的是,你现在可以快速迭代,并做出客观决策。这些就是你正在寻找的技巧。
关于我
我是一名资深数据科学家和兼职自由职业者,拥有超过 12 年的经验。我一直在寻找联系机会,所以请随时:
如果你有任何问题,请随时在下方留言。
4 种你无法使用 ChatGPT 代码解释器的方法,这些方法会干扰你的分析
原文:
towardsdatascience.com/chatgpt-code-interpreter-how-data-analyst-not-use-c31d29034b69
我不认为代码解释器的进步有那么大。数据专业人士应注意这些问题。是否有解决这些问题的方法?
·发布于 Towards Data Science ·阅读时间 7 分钟·2023 年 7 月 31 日
--

图片来源:作者
目录
• 代码解释器的力量
• 代码解释器的局限性
• • • 数据库无法访问
• • • Python 版本限制
• • • 不可用的库
• • • GPU 限制
• 解决这些问题
• • • 一个临时的混合解决方案
• 结论
ChatGPT 将取代数据分析师。至少根据大家对代码解释器引入的说法,似乎是这样。我在过去一个月左右使用过它。说实话,我并没有那么被打动。

图片来源:作者
在我讨论为什么我认为代码解释器被高估以及你如何克服这些缺陷之前,我至少应该给予它一些赞扬。
代码解释器的力量
这是一个可靠的进步,它在许多方面都很有用。它可以处理一些基础数据,做一些相当酷的事情,对没有编程经验的人也很有帮助。此外,它允许上传多种类型的数据,在用正则表达式进行基本数据清理和进行简单的数据可视化时特别方便。

图片来源:作者
然而,在我看来,这些就是它的真正好处所在。别误会,我仍然认为 ChatGPT 总体上是革命性的。我只是觉得代码解释器并不像大家说的那么重要。
代码解释器的局限性
目前,代码解释器有许多限制,这使得使用它比我认为的更好的替代方案要困难得多,我将在文章的最后揭示这一点。
数据库无法访问
代码解释器的第一个挑战是相当大的。我认为大多数人会发现这个工具对数据分析非常有用。
问题在于,大多数数据存储在数据库中,而 ChatGPT 实际上无法访问这些数据库。我的朋友卢克·巴鲁斯在他的视频中也谈到了一些。

图片由作者提供
当然,你可以将数据从数据库中提取出来上传到 ChatGPT,但这涉及很多开销,并且根据你的工作环境,可能会有一些安全问题。
Python 版本限制
让我们做一个假设。想象一下你是一个赛车手。
你的团队给你一辆坚固、一致的旧款车。我们称之为 Python 3.8。与此同时,最新的超级跑车 Python 3.11 刚刚推出,带有一个名为 AutoFast 的令人惊叹的新功能。
现在 Python 3.8 在大多数比赛中运行得非常好。但没有较新版本 Python 的 AutoFast 功能,你就是无法在直道上达到超高速度。
这反映了在编码中,Python 3.8 适用于许多应用,但对于某些特定功能,你需要更新到较新的版本。
很明显,这也与代码解释器有关。据我所知,它只能运行 Python 3.8 版本,而有些库只能在较新的 Python 版本中使用。

图片由作者提供
例如,我曾经在实验 AutoGPT。我相信它只能在 Python 3.10 上运行。本周早些时候我确实遇到了一些问题。

图片由作者提供
不可用的库
说到库,你还不能在代码解释器中存在的 Python 环境中安装新的库。这意味着有很多工具你实际上无法使用。
现在 OpenAI 确实做得很不错,包含了许多有用的数据分析包,但如果你想探索一些较少使用的新工具或库,它们确实会束缚你。这是环境中安装的列表。它很强大,但再次强调,绝对不是详尽无遗的。

图片由作者提供
GPU 限制
现在是代码解释器的最后一个缺陷。让我们回到我们的汽车隐喻。我们的代码解释器有一个相当大的油箱(RAM)和一个强大的引擎(CPU),这对普通比赛很有帮助。
但是对于像机器学习这样激烈的下一阶段竞赛,你需要一个涡轮增压器。虽然你的赛车在没有涡轮增压的情况下也能运转良好,但你会错过一些高速刺激。
对于代码解释器来说,缺乏 GPU,就像汽车隐喻中的涡轮增压器。没有 GPU,你真的无法充分利用机器学习能力。
根据我所发现的,解释器实例具有相当好的计算能力。它似乎有大约 54 GB 的 RAM 和一个 16 核心的处理器,这些信息是它不情愿地告诉我的。再次说明,不确定这些信息的准确性,但对于大多数基本任务来说,这可能是足够的。

图像来源于作者。
然而,如果数据分析师想要涉足深度学习,或者数据科学家想要原型一些代码,他们可能需要尝试一些替代方案,因为缺乏 GPU。
解决这些问题的方法
令人惊讶的是,我发现在代码解释器出现之前,我使用 ChatGPT 编程的方式在几个方面更优。简而言之,这就是我之前使用它的方式:
1. 我会读取数据列并描述它们。
2. 然后,我会让 ChatGPT 根据我输入的数据集描述生成分析代码。

图像来源于作者。
3. 接下来,我会把代码放到 Google Colab 中或在本地运行。
4. 我会通过将收到的错误信息复制粘贴到 ChatGPT 窗口中进行调试。

图像来源于作者。
好吧,这有稍微多一点的步骤。我发现这个过程与使用代码解释器并没有戏剧性地不同。这就像是选择风景路线而不是高速公路行驶。你对环境的控制更多,沿途有更多有趣的景点。

左侧图像表示使用代码解释器的方法,而右侧图像表示在代码解释器出现之前的方法。图像来源于作者。
所以,它更慢,但我对环境的控制更多,使用的工具灵活性更高,并且可以访问显著更多的数据来源。
这并不是说代码解释器不好。我认为它非常有用。只是我觉得它扩展人们分析能力的能力与所有围绕它的炒作并不完全一致。
临时混合修复
我认为,使用代码解释器和其他平台(如 Google Colab)的混合解决方案是一个更优的短期解决方案,直到 OpenAI 做出调整解决前面提到的问题。
结论
ChatGPT 的代码解释器提供了有用的功能,可以帮助编码新手。然而,其局限性显而易见,这引发了对其是否能完全取代数据分析师的质疑。
尽管代码解释器有实际用途,但在以下方面存在不足:
-
访问数据库
-
运行更快的 Python 版本
-
安装新库
-
拥有用于复杂任务如机器学习的 GPU
在 OpenAI 解决这些问题之前,我使用 ChatGPT 进行编码的而非代码解释器的方法,给了我许多优势。此外,我认为结合使用代码解释器和其他平台如 Google Colab 的混合方法似乎比单独使用代码解释器更为理想。
尽管具有潜力,代码解释器可能没有现在所宣传的那样革命性。
如果你对我如何使用 ChatGPT 提高编码效率的详细信息感兴趣,请尝试查看这篇文章。
直到下次,祝你的数据科学之旅好运!
如果你喜欢这篇文章,记得在 Medium 上关注我,以获取更多类似内容,或通过电子邮件订阅我。你还可以分享并推荐这篇文章给对数据科学感兴趣的网络!
如果你喜欢有趣且富有信息的视频,涉及数据科学、机器学习和 AI,请查看我的 YouTube 频道,我提供评论、教程和其他教育视频。
要获得每周关于我的内容创作和数据科学行业额外学习资源的更新,请订阅我的通讯,数据涌动!
另外,考虑通过注册会员来支持我和成千上万的其他作者。
[## 通过我的推荐链接加入 Medium — Ken Jee
作为 Medium 会员,你的一部分会员费用会用于支持你阅读的作者,并且你可以完全访问每个故事…
medium.com](https://medium.com/@kenneth.b.jee/membership?source=post_page-----c31d29034b69--------------------------------)
ChatGPT 代码解释器:它如何节省了我几个小时的工作
原文:
towardsdatascience.com/chatgpt-code-interpreter-how-it-saved-me-hours-of-work-3c65a8dfa935
用几句话创建一个国家人口的互动世界地图。
·发布在Towards Data Science ·6 分钟阅读·2023 年 7 月 9 日
--

图片由Arpit Rastogi提供,来源于Unsplash
2023 年 7 月 6 日,OpenAI 宣布代码解释器将在接下来的一周内向 ChatGPT Plus 用户开放。这可能是增强 ChatGPT 功能和能力的所有插件中最出色的一种。
代码解释器运行代码并允许上传数据,以便你可以用于数据清理、分析、可视化以及许多其他功能。它就像是一个触手可及的数据分析师。
听起来很棒?那我们继续吧。
我在一个任务中使用了它,这个任务本来需要我花费几个小时才能完成。根据你的任务和对 Python 库的熟悉程度,这个任务可能需要更长的时间。
使用 ChatGPT,我只需写几句话并上传两个 CSV 文件。
我将一步一步带你了解这个过程,以及代码解释器在各种任务中的惊人表现,例如理解数据、清理和预处理数据,以及创建数据可视化。
你需要先启用它
为了在聊天中使用代码解释器,你需要先通过设置启用它。

如何启用代码解释器(图片由作者提供)
现在在你开始一个新聊天时就可以使用它了。
任务
我打算创建一个世界地图可视化,用点表示国家,点的大小与国家人口成比例。
我有两个数据集,一个用于国家人口,另一个用于经纬度位置。
你可以轻松地在网上找到这些数据集。如果找不到,我在我的数据集仓库中有它们。
你还可以将代码解释器用作你的导师、同事或第二双眼睛。
## ChatGPT Code Interpreter: Your Mentor, Coworker, and Second Pair of Eyes
利用最新的 ChatGPT 插件。
让我们开始吧……
由于这是一个简单的任务,我没有花费太多时间在提示工程上。以下是我向 ChatGPT 提出的请求的第一句:

(图片来源:作者)
你可以使用提示窗口左侧的加号来上传数据集。

使用 + 上传数据集(图片来源:作者)
它开始读取数据集,这是显而易见的事情。目前还不算令人印象深刻!
# code generated by Code Interpreter
import pandas as pd
# Load the CSV data into a pandas DataFrame
population_data = pd.read_csv('/mnt/data/country_population_historic.csv')
# Display the first few rows of the DataFrame
population_data.head()
接着,ChatGPT 询问了国家位置:

(图片来源:作者)
我已经有了数据,所以我直接上传了:

(图片来源:作者)
# code generated by Code Interpreter
# Load the country coordinates data
country_coordinates = pd.read_csv('/mnt/data/country_coordinates.csv')
# Display the first few rows
country_coordinates.head()
数据清理和预处理
读取数据后的下一步是清理和预处理数据。我相信你已经听到过很多次,这一步通常需要最多的时间。
在创建世界地图可视化之前,有几个问题需要处理。ChatGPT 发现了这些问题,并在不到一分钟的时间内提供了解决方案。非常令人印象深刻!

(图片来源:作者)
上述代码片段中的代码:
# code generated by Code Interpreter
# Check if the country names are consistent across both dataframes
population_countries = set(population_data['Country Name'])
coordinates_countries = set(country_coordinates['name'])
# Countries in the population data but not in the coordinates data
missing_in_coordinates = population_countries - coordinates_countries
# Countries in the coordinates data but not in the population data
missing_in_population = coordinates_countries - population_countries
在坐标数据集中和人口数据集中都缺少一些国家。
ChatGPT 在上述代码片段中做了什么:
-
在两个数据集中创建国家名称的 Python 集合
-
分别比较它们以找出差异作为一个集合
ChatGPT 在不到一分钟的时间内解决了这个问题。

上述代码片段中的代码:
# code generated by Code Interpreter
# Mapping of country names from the population dataset to the coordinates dataset
country_name_mapping = {
'Bahamas, The': 'Bahamas',
'Brunei Darussalam': 'Brunei',
'Cabo Verde': 'Cape Verde',
'Congo, Dem. Rep.': 'Congo [DRC]',
'Congo, Rep.': 'Congo [Republic]',
"Cote d'Ivoire": "Côte d'Ivoire",
'Czechia': 'Czech Republic',
'Egypt, Arab Rep.': 'Egypt',
'Gambia, The': 'Gambia',
'Hong Kong SAR, China': 'Hong Kong',
'Iran, Islamic Rep.': 'Iran',
"Korea, Dem. People's Rep.": 'North Korea',
'Korea, Rep.': 'South Korea',
'Kyrgyz Republic': 'Kyrgyzstan',
'Lao PDR': 'Laos',
'Macao SAR, China': 'Macau',
'Micronesia, Fed. Sts.': 'Micronesia',
'Myanmar': 'Myanmar [Burma]',
'North Macedonia': 'Macedonia [FYROM]',
'Russian Federation': 'Russia',
'Sao Tome and Principe': 'São Tomé and Príncipe',
'Slovak Republic': 'Slovakia',
'Syrian Arab Republic': 'Syria',
'Turkiye': 'Turkey',
'Venezuela, RB': 'Venezuela',
'Yemen, Rep.': 'Yemen',
'Eswatini': 'Swaziland'
}
# Clean the population dataset
population_data_clean = population_data.replace({"Country Name": country_name_mapping})
# Drop the rows that do not exist in the coordinates dataset
population_data_clean = population_data_clean[population_data_clean['Country Name'].isin(coordinates_countries)]
# Merge the population data with the country coordinates data
merged_data = pd.merge(population_data_clean, country_coordinates, left_on='Country Name', right_on='name')
# Keep only the necessary columns
merged_data = merged_data[['Country Name', '2022', 'latitude', 'longitude']]
# Rename the columns for clarity
merged_data.columns = ['Country', 'Population', 'Latitude', 'Longitude']
merged_data.head()

清理和预处理数据(图片来源:作者)
我们现在的数据格式整洁漂亮。这个步骤通常需要大量时间和人工努力。ChatGPT 只用了大约一分钟。
创建可视化
ChatGPT 首先尝试使用 Matplotlib 库进行操作。然而,由于缺少世界地图图像,它未能成功:

(图片来源:作者)
然后,我简单地要求它使用 Plotly,这也是我一开始想到的。

(图片作者提供)

(图片作者提供)
上述代码片段中的代码:
# code generated by Code Interpreter
import plotly.graph_objects as go
# Create the scatter geo plot
fig = go.Figure(data=go.Scattergeo(
lon = merged_data['Longitude'],
lat = merged_data['Latitude'],
text = merged_data['Country'] + ': ' + merged_data['Population'].astype(str),
mode = 'markers',
marker = dict(
size = (merged_data['Population'] / 1000000) ** 0.5, # Take square root to better visualize differences
sizemode = 'diameter',
sizeref = 1,
sizemin = 4,
color = merged_data['Population'],
colorscale = 'Viridis',
colorbar_title = "Millions",
)
))
# Update the layout
fig.update_layout(
title_text = '2022 World Population',
geo = dict(
showframe = False,
showcoastlines = True,
projection_type = 'equirectangular'
)
)
fig.show()
它无法在浏览器中显示交互式世界地图,但成功编写了生成它的代码。我只需将代码复制并粘贴到 Jupyter notebook 中即可。
这里是最终结果:

国家人口的交互式世界地图(GIF 由作者提供)
最后的话
我们做了什么:
-
读取数据集
-
清理、预处理并合并它们
-
创建交互式数据可视化
我们只需写两个句子(加上告诉 ChatGPT 使用 Plotly)。我觉得这非常令人印象深刻!
我还创建了一个仓库,其中包含一个 Jupyter notebook,里面有所有的代码片段和使用的数据集。随时下载并重新生成可视化。这也是数据清理的一个很好的实践。
你可以成为一个 Medium 会员 以解锁对我所有写作的完整访问权限,以及 Medium 的其余内容。如果你已经是会员,别忘了 订阅 以便我每次发布新文章时收到邮件通知。
感谢阅读。如果你有任何反馈,请告诉我。
ChatGPT 数据分析——初学者指南
原文:
towardsdatascience.com/chatgpt-for-data-analysis-a-beginners-guide-cea06eef244b
一个关于使用 ChatGPT 进行数据分析的完整教程。
·发布于Towards Data Science ·阅读时间 12 分钟·2023 年 12 月 23 日
--

图片由Myriam Jessier在Unsplash上拍摄。
数据分析是一个耗时的任务。
这需要对复杂的 Excel 公式有一定了解,并具备一些编程技能。
在分析数据时,我曾花费数小时调试代码,并在网上教程中寻找所需的结果。
直到最近!
ChatGPT 是数据分析领域的一个变革者。
即使你不了解 Excel,也不会写一行代码,ChatGPT 依然能将初级数据分析师的能力呈现在你面前。
你需要做的只是用简单的英语问 ChatGPT 一个问题。
模型将利用其自然语言处理能力来分析你的数据并解决问题。
在本教程中,我将向你展示如何将 ChatGPT 转变为你自己的个人数据分析师。
要跟随这个教程,你必须订阅ChatGPT Plus。
欲观看视频版,请点击这里。
点击这里
我认识一位私人教练(我们叫他詹姆斯),他最近开始经营自己的健身公司。
由于预算限制,詹姆斯需要自己管理几乎所有业务方面的工作。
这包括但不限于公司的运营、财务、营销、客户获取和战略。
最近,詹姆斯开始使用 ChatGPT 来做基于客户数据的业务决策。
你看,尽管詹姆斯对健身如数家珍,但他没有处理电子表格和编程语言的经验。
过去,如果他想执行任何类型的数据分析任务,詹姆斯将有两个选择:
-
雇佣那些擅长理解原始数据和分析数据的专业人员。
-
自己花费无数小时学习这些技能。
无论如何,他都得花费大量时间和金钱。
现在,得益于 ChatGPT 及其引入的专用数据分析功能,詹姆斯可以在几分钟内分析数据,将他的业务提升到一个新水平。
在这篇文章中,我们将站在詹姆斯的角度,分析他健身公司的交易数据。我们将使用 ChatGPT 揭示数据驱动的建议,并探索如何利用这些洞察力来改善销售。
高级数据分析
ChatGPT 内置的数据分析插件可以分析你上传的文档,并在几秒钟内生成响应。
这个功能于 2023 年 7 月 6 日由 OpenAI 发布,目前仅对 ChatGPT Plus 订阅用户开放。
这个插件最近更名为“高级数据分析”,以更好地与其数据分析能力对齐,这是该功能的主要商业用途。
当你使用这个数据分析功能并向 ChatGPT 提供指令时,模型会将你的提示翻译成 Python 代码,运行这些代码,并给出所需的结果。
这相当于你自己编写代码来完成任务——一个需要人类经过多年学习和实践才能掌握的过程。
像詹姆斯这样的中小企业主将从这个插件中大大受益,因为他缺乏处理客户数据的技术技能。
这里是一个描述这个数据分析功能如何工作的视觉图:

图片由作者提供
第一步:探索性数据分析
让我们使用这个插件来分析一些数据。
你可以在这个链接中找到本次分析的数据集——文件名为“James Transaction Dataset”。(本次分析的数据集由作者创建)。
本文档包含了过去一个月中与詹姆斯健身公司相关的交易数据。
在这个 Excel 文件中,你可以找到两个与詹姆斯健身公司相关的信息工作表——“电子商务销售”和“健身服务”
第一个工作表“电子商务销售”如下所示:

图片由作者提供
它包含了每个客户购买的产品、每项商品的价格、折扣百分比和总花费。
我们将首先关注这个工作表。
让我们通过点击界面左下角的回形针图标,将 Excel 文件上传到 ChatGPT(确保选择 GPT-4 而不是 GPT-3.5):

图片由作者提供
文件上传后,只需让 ChatGPT 描述数据集中存在的列:
Can you describe the columns present in this dataset?
仅需几秒钟,你应该会看到如下的回应:

作者提供的图片
要查看 ChatGPT 生成的代码,你可以点击模型回应末尾的蓝色图标。
如果你知道 Python 编程,你可以复制并粘贴这段代码,自己运行。
当我有大量数据需要分析时,我倾向于这样做,因为 ChatGPT 一次只能处理 10 个文件。
在处理数据时,我通常发现自己需要分析数百个甚至数千个文件。在这种情况下,我只需复制 ChatGPT 生成的代码,并编写一个循环来遍历所有文件。
回到手头的任务,请注意,ChatGPT 只提供了第一个工作表的详细信息,尽管 Excel 文件中有两个工作表。
这是因为聊天机器人通常默认生成代码以读取第一个工作表,假设这是最相关的,除非另有说明。
如果你想让它分析第二个工作表,你需要明确说明这一点。
现在让我们询问 ChatGPT 有关客户从詹姆斯那里购买的独特产品的信息:
What are the unique products available in this worksheet?
ChatGPT 分析工作表以识别所有列出的独特产品,并生成以下回应:

作者提供的图片
我们立刻看到詹姆斯在他的电子商务商店中只销售了 5 种产品。
现在让我们询问 ChatGPT 交易的数量,以了解过去一个月人们从詹姆斯公司购买了多少次:
How many transactions are captured in the dataset?
以下是 ChatGPT 对上述提示的回应:

作者提供的图片
现在我们对数据集有了基本了解,让我们继续对其进行一些简单的计算。
第 2 步:数据汇总
让我们开始询问 ChatGPT 计算詹姆斯在数据集中所有交易的总销售额:
What is the total sales amount from all the transactions present in this worksheet?
以下是 ChatGPT 对我们的提示的回应:

作者提供的图片
请注意,即使没有告知 ChatGPT 我们希望分析的具体列,它也识别了正确的列,“总金额”,并计算了总和。
这展示了它在理解上下文和将普通文本需求转化为准确结果方面的卓越能力。
第 3 步:生成数据驱动的洞察
现在我们已经对数据集进行了基本计算,让我们更进一步,请 ChatGPT 提供可以帮助改善詹姆斯商业策略的洞察。
作为销售人员,了解哪个产品销售最好,以及产品价格和折扣是否对购买数量有影响是很重要的。
为了改善他的商业策略,詹姆斯想知道两件事:
-
他的哪些产品是畅销品?
-
人们的购买决策是否受到价格和折扣的影响?
让我们将这些问题合并成一个单一的提示:
From the purchase details sheet:
Can you identify which product has been purchased the most in terms of quantity?
What is the average selling price of each product?
What is the average discount given for each product?
Based on the data, is there a relationship between the product’s price or discount amount and the number of purchases? Do lower prices mean more sales?
下面是 ChatGPT 对每部分提示的回应:
- 你能找出哪个产品在数量上被购买得最多吗?

作者提供的图片
ChatGPT 发现运动自行车在购买数量方面是最受欢迎的产品。
2. 每个产品的平均售价是多少?
聊天机器人进行了一些计算,并列出了每个产品的平均价格:

作者提供的图片
仅凭这一点,我们可以看出这些产品的售价大致相同。
被购买最多的运动自行车并不是最便宜的。
这表明更低的价格不一定会转化为更多的销售。
至少对于这个项目来说,购买数量与产品价格之间没有简单的关系。
我想指出的是,这种分析需要中级的 Excel 或编程知识,因为用户需要对数据进行分组或创建某种数据透视表,然后汇总以计算平均值。
然而,通过 ChatGPT,我们在几秒钟内就获得了这个见解,无需编写任何代码或 Excel 公式。
3. 每个产品的平均折扣是多少?
下面是 ChatGPT 列出的每个产品的平均折扣:

作者提供的图片
再次强调,运动自行车的折扣低于其他产品的折扣。
这意味着更高的折扣并不总是会转化为更多的销售,尤其是对于这个产品来说。
为了进一步分析折扣与购买之间的关系,让我们继续查看 ChatGPT 对下一个问题的回应。
4. 根据数据,产品的平均售价或平均折扣金额与购买数量之间是否存在关系?更低的价格是否意味着更多的销售?

作者提供的图片
ChatGPT 告诉我们,要理解定价和折扣对销售的影响,我们需要进行相关性分析。
相关性分析是一种量化两个变量之间关系强度的技术。
我们不会深入探讨这项技术背后的机制,因为这超出了这篇博客的范围。
但从本质上讲,相关性可以告诉我们价格的上涨是否对应于更高的销售量,如果是的话,涨幅是多少。
这项统计技术通常使用专业软件或编程工具来执行,这意味着传统上,没有数据分析或统计背景的人可能会发现执行和解释相关性分析具有挑战性。
然而,借助 ChatGPT,你可以直接要求它为你进行分析,并用简单的术语解释结果,使得像这样的复杂分析任务对每个人都更易于接受。
如果你愿意,可以阅读 ChatGPT 对相关性分析的解释,但我将跳过这些,直接进入总结部分。

作者提供的图片
ChatGPT 已经得出结论,根据相关性分析的结果,产品价格与销售折扣之间的关系并不强,这表明这些因素不是销售的主要驱动因素。
这意味着从 James 处购买健身器材的人并没有真正受到价格或折扣的影响。还有其他因素,比如产品兴趣,促使他们从他那里购买。
作为一个企业主或决策者,这种反馈是非常宝贵的,因为它可以帮助你重新思考定价策略,并激励客户进行更多购买。
第四步:数据可视化
现在,让我们继续使用 ChatGPT 创建一些图表。
我们将使用此文件中的第二个工作表来完成此任务:

作者提供的图片
这个工作表包含了所有 James 客户所参加的健身课程和培训课程的信息。
让我们先让 ChatGPT 描述一下这个表格中的列:
Can you describe the columns present in the second worksheet?

作者提供的图片
ChatGPT 列出了数据集中存在的列,并给出了每列的描述。
现在,回想一下 James 并不是一个高度技术化的人。
他并不知道在这个数据集中具体该可视化什么。
他所知道的就是他想利用以前客户互动中发现的趋势来增加未来的销售。
他可以简单地告诉 ChatGPT 这些信息,然后让聊天机器人生成视觉创意:
This worksheet comprises transaction information for my fitness company for the past year.
What charts can I ask you to create if I'd like to learn more about my
customer behavior?

作者提供的图片
ChatGPT 生成了各种各样的视觉创意,如月度销售趋势、平均购买价值和客户购买频率。
为了本教程的目的,我们选择两项内容进行可视化。
首先,让我们让 ChatGPT 可视化时间上的销售趋势,并基于这些趋势提出改善销售的建议:
Can you visualize sales trends by month? Based on these trends, provide
detailed insights on key trends and generate actionable recommendations on
how to improve future sales.
针对上述提示,ChatGPT 生成了以下条形图来可视化销售信息:

作者提供的图片
一开始,我们就看到“总销售额”在四月、五月和十二月出现了峰值。
根据从可视化中获得的见解,ChatGPT 还生成了一些建议,说明 James 如何提高健身公司的未来销售额。

作者提供的图片
首先,聊天机器人建议 James 调查为什么四月和十二月的销售额较高。
如果这些波动可以归因于特价优惠或季节性促销,那么这表明我们应在不同的时间段内复制这一策略。
然后,它还建议詹姆斯在销售较低的月份推出特别活动或优惠。
我想退一步指出,优秀的数据分析师的角色是用数据回答正确的问题。
ChatGPT 完全依靠自身完成了这一任务,没有任何指导。
它能够读取数据集,发现需要回答的问题类型,甚至生成改进詹姆斯商业策略的建议。
我将留给你去查看聊天机器人生成的其余建议。
让我们继续下一个可视化。
我们现在将查看每项服务的销售数量,以了解詹姆斯的哪些服务在客户中最受欢迎:
Can you visualize the amount of total sales by service? Based on these trends, provide
detailed insights on key trends and generate actionable recommendations on
how to improve future sales.
这是 ChatGPT 针对上述提示生成的图表:

图片来源:作者
看起来“核心力量训练”是最受欢迎的课程,其次是“健康指导”。
作为一个后续问题,值得研究一下季节性趋势等方面。例如,人们在夏季是否购买更多的特定服务?
这些见解可以用于制定个性化的目标策略,以提高客户获取和保留率。
我将留给你去提问并进一步探索数据集,但现在,让我们看看 ChatGPT 基于上述图表的建议:

图片来源:作者
ChatGPT 提出的第一个建议是关注高需求服务。由于客户似乎喜欢这些服务,聊天机器人建议詹姆斯为这些课程创建更多的时间段和不同的级别。
接下来,它建议詹姆斯创建套餐,将表现最佳的服务与不太受欢迎的服务结合起来。
还有一些额外的建议,例如创建客户调查、修订销售不佳的课程,并进行一些交叉推广。
虽然这些建议目前可能看起来过于通用,但当将多个数据集的见解结合起来时,ChatGPT 的建议将变得更为有力。
例如,如果詹姆斯上传了包含课程结构和时间安排的文档,以及客户调查数据集,ChatGPT 可以通过根据用户的可用性和兴趣对他们进行细分,从而生成量身定制的营销策略。
这里是 ChatGPT 在数据分析方面表现出色的原因
像 ChatGPT 这样的 LLM 最大的优势在于它们能够识别数据集中人类可能忽略的复杂关系。
由于 GPT 模型经过大量文本数据的训练,它们见过各种主题中无数的模式和上下文,从而使这些模型能够更好地识别新数据集中类似的模式。
此外,他们的技术熟练程度和语言能力使这些模型能够处理大规模的数据集,并解释人类语言的细微差别,从而有效地弥合原始数据与易读见解之间的差距。
我希望这篇文章能帮助你更好地理解 ChatGPT 的高级数据分析插件如何被用来处理数据。
如果你想了解更多关于使用 ChatGPT 进行数据分析和自动化相关任务的内容,可以阅读我关于 ChatGPT Plus 的书。
注意:本文包含附属链接。
ChatGPT 生成的食品行业评论:现实性评估
原文:
towardsdatascience.com/chatgpt-generated-food-industry-reviews-realism-assessment-2ee28155970f
探讨如何通过 ChatGPT 生成的数据来支持食品行业公司收集评论和调查。
·发表于 数据科学前沿 ·阅读时间 13 分钟·2023 年 6 月 9 日
--

图片由 Annie Spratt 提供,来源于 Unsplash
起点
我过去的大部分研究使用了生成对抗网络(GAN)来创建数据集的深度伪造图像。我这样做是为了增加数据集中的信息多样性,我预测这会导致更好的目标检测模型(了解更多关于这项研究的信息!)。虽然这是与深度伪造图像创建完全不同的任务,但我想知道:是否有办法增加我用于不同食品公司评论的数据集的规模?
我能训练一个 GAN 吗?可以,但 GAN 在生成表格数据方面表现不佳,而对我而言,文本中的词汇更符合电子表格中的数据。然后出现了 ChatGPT。瞧!我是否可以通过简单地请求 ChatGPT 用不同的提示生成新的评论来为我的数据集创建新的评论?
为什么这很重要
我们想要增加数据集规模有几个原因。
缺乏足够的数据来训练模型。
数据集存在偏见(因此我们需要用少数类数据来矫正它)。
数据集缺乏多样性
我创建的数据集(获得了今天示例中公司批准)缺乏负面评论、评论多样性和规模,因此需要实施数据集增强。
数据不足以训练模型。如果你试图在缺乏数据的情况下构建模型,可能会出现各种问题。一个问题可能是模型对数据过拟合,并在实际应用中表现不佳。
偏见数据集。如果数据集被某一类主导,则缺乏其他类的代表,这将导致模型和分析不适用于该类。我们希望拥有一个平衡的数据集,以确保我们的模型在我们感兴趣的所有数据类的操作中表现良好。
数据集缺乏多样性。现实世界是混乱的。如果我们的数据集缺乏多样性,一旦投入生产,我们的模型将无法泛化到样本细节的变化上。泛化能力是模型能够分类或识别样本属于某一类,即使该样本包含与其所属类不同或在该类中未充分代表的特征。
分析
今天的分析使用了以下 API。
from gensim.models import Word2Vec,KeyedVectors
import gensim.downloader as api
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import spacy
import spacy.cli
import spacy
import numpy as np
from random import sample
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import TfidfVectorizer
#To download
spacy.cli.download("en_core_web_lg")Dataset
数据集
原始数据集是与公司(Altomontes Inc)授权的情况下汇编的,用于使用评论(请参阅这两篇文章,我展示了如何对评论使用自然语言处理)。
使用自然语言处理来支持小型企业。
towardsdatascience.com ## 针对小型企业的主题建模分析
介绍
towardsdatascience.com
原始评论示例:
“强烈推荐!!!我之前从未听说过 Altomontes,直到最近一个朋友把餐点送到我家。我的丈夫和我决定去看看,在那里我们遇到了老板,她是最可爱的人!她基本上给我们做了一个参观。我们买了鸡肉玛莎拉作为晚餐,非常棒!我们还买了布鲁克林比萨作为午餐,非常美味!我们品尝了他们的咖啡,他们让我们尝试了卡诺利、饼干和小饼干。非常好!我们买了贝壳面、意大利面酱、一瓶酒、一些奶酪等等。一切看起来都很不错,味道也很好,我们可以在那儿呆几个小时!我们会回来的!!也许今天?”
我使用 ChatGPT 创建了接下来的两个数据集。一个包含有关意大利市场及其销售商品的积极评论,另一个包含有关意大利市场及其销售商品的负面评论。
ChatGPT 生成的正面评论示例:
“我买的佩科里诺托斯卡纳奶酪味道浓郁而美味。其坚硬且易碎的质地,带有一丝草香,使其非常适合磨碎、刮削或单独享用。”
ChatGPT 生成的负面评论示例:
“我尝试的火腿和芝麻菜比萨有枯萎的芝麻菜,火腿也很硬。没有胃口。”
一旦所有评论被创建并放入 CSV 文件(在 这里 找到),我将它们格式化为一个字典,每个键是源(源为 原始、生成的正面 或 生成的负面)而项目是评论及其来源的元组列表。
#Dictionary: Keys=Source, Values=Reviews
#Lists: List of reviews for each dataset
reviews = []
reviews_dict = {}
reviews_dict['original review'] = []
reviews_dict['fake positive review'] = []
reviews_dict['fake negative review'] = []
#Original Revews
orig_reviews = pd.read_csv('/content/drive/MyDrive/reviews/Altomontes_reviews.csv')
for rev in orig_reviews.Review:
reviews_dict['original review'].append(rev)
reviews.append((rev,'original review'))
#positive reviews
pos_reviews = pd.read_csv('/content/drive/MyDrive/reviews/generated_positive_reviews - Sheet1.csv')
for rev in pos_reviews.Review:
reviews.append((rev,'fake positive review'))
reviews_dict['fake positive review'].append(rev)
#negstive reviews
neg_reviews = pd.read_csv('/content/drive/MyDrive/reviews/generated_negative_reviews - Sheet1.csv')
for rev in neg_reviews.Review:
reviews.append((rev,'fake negative review'))
reviews_dict['fake negative review'].append(rev)An example of an original review:
句子评估
现实性评估
首先,我只是想评估这些评论是否看起来“现实”。这类似于计算给定文本的 连贯性 及其联系。我的初步想法是它们都会被认为是现实的,但我觉得也有必要可视化原始评论、人工正面评论和人工负面评论的得分。
要开始这个评估,我们首先要创建一个 assess_sentence_realism 函数。这个函数旨在检查一个句子的连贯性,以及输入句子是否“现实”到符合人类的解读方式。
def assess_sentence_realism(sentence, model):
"""
A function that accepts a sentence and embeddings model as inputs, and outputs a
'realism' score based on the cohesion and similarity between words of the
sentence
Inputs:
sentence (str): A string of words.
model (.model): An embedding model (user's choice)
Returns:
avg_similarity: An average similarity score between the words of the sentence.
"""
tokens = sentence.split()
# Calculate the average similarity between adjacent word pairs
similarities = []
for i in range(len(tokens) - 1):
word1 = tokens[i]
word2 = tokens[i + 1]
if word1 in model.key_to_index and word2 in model.key_to_index:
word1_index = model.key_to_index[word1]
word2_index = model.key_to_index[word2]
similarity = model.cosine_similarities(
model.get_vector(word1),
[model.get_vector(word2)]
)[0]
similarities.append(similarity)
# Calculate the average similarity score
if similarities:
avg_similarity = sum(similarities) / len(similarities)
else:
avg_similarity = 0.0
return avg_similarity
你需要下载一个嵌入模型来创建你的词嵌入。我选择使用的模型是 Google News 300 模型(查看 这里)。
# Download the pre-trained Word2Vec model
#model_name = 'word2vec-google-news-300' # Example model name
#model = api.load(model_name)
#model.save('/content/drive/MyDrive/models/word2vec-google-news-300.model')
pretrained_model_path = '/content/drive/MyDrive/models/word2vec-google-news-300.model'
model = KeyedVectors.load(pretrained_model_path)
scores = {}
sources = []
# Evaluate the realism score for each sentence and store in scores dictionary
for sentence, source in reviews:
realism_score = assess_sentence_realism(sentence, model)
if source in scores:
scores[source].append(realism_score)
else:
scores[source] = [realism_score]
sources.append(source)
#print(f"Realism Score for {source}: {realism_score}")
# Calculate the mean score for each source
mean_scores = {source: np.mean(score_list) for source, score_list in scores.items()}
# Plot the mean scores in a scatter plot
colors = {'original review': 'green','fake positive review':'blue','fake negative review':'red'}
sns.set_style("darkgrid", {"axes.facecolor": ".9"})
plt.figure(figsize=(8, 6))
plt.bar(mean_scores.keys(), mean_scores.values(),color=['green','blue','red'])
plt.xlabel("Source")
plt.ylabel("Mean Realism Score")
plt.title("Mean Realism Scores by Source")
plt.show()
→ 原始数据集得分: 0.17
→ ChatGPT 生成的正面数据集得分: 0.15
→ ChatGPT 生成的负面数据集得分: 0.16

平均现实性得分(图像来自作者)
正如预期(也许),原始评论被评为最现实。为什么会这样呢?首先,它们是真实的评论。我还怀疑 ChatGPT 得分较低的一个重要原因是,随着时间的推移,模型遵循了一定的评论模式。一些评论几乎相同,ChatGPT 只是简单地改变了几个词。
(即 The 比萨 不好且很干 → The 意大利面 不好且很干)
与真实评论相比,ChatGPT 评论的 多样性 也有所欠缺,这在预期之中(考虑一下,真实评论是由多人写的,而假评论则是由一个模型生成的)。尽管如此,与原始评论相比,ChatGPT 生成的评论得分仍然相对较高,我认为值得尝试未来用这些评论来训练我们的模型。
在此之前,我们还要进行一次评论的相似性评估。
相似性评估
接下来,我想查看每一批生成的评论与原始评论之间的相似性。为此,我们可以使用余弦相似性来计算每个来源的不同句子向量之间的相似性。首先,我们可以创建一个余弦相似性矩阵,该矩阵将首先使用 TfidVectorizer()将我们的句子转换为向量,然后计算这两个新的句子向量之间的余弦相似性。
def cosine_similarity(sentence1, sentence2):
"""
A function that accepts two sentences as input and outputs their cosine
similarity
Inputs:
sentence1 (str): A string of word
sentence2 (str): A string of words
Returns:
cosine_sim: Cosine similarity score for the two input sentences
"""
# Initialize the TfidfVectorizer
vectorizer = TfidfVectorizer()
# Create the TF-IDF matrix
tfidf_matrix = vectorizer.fit_transform([sentence1, sentence2])
# Calculate the cosine similarity
cosine_sim = cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])
return cosine_sim[0][0]
我遇到的一个问题是数据集现在变得非常庞大,以至于计算需要很长时间(有时我在 Google Colab 上没有足够的 RAM 继续)。为了解决这个问题,我从每个数据集中随机抽取了 200 条评论来计算相似性。
#Random Sample 200 Reviews
o_review = sample(reviews_dict['original review'],200)
p_review = sample(reviews_dict['fake positive review'],200)
n_review = sample(reviews_dict['fake negative review'],200)
r_dict = {'original review': o_review,
'fake positive review': p_review,
'fake negative review':n_review}
现在我们有了随机选择的样本,我们可以查看不同数据集组合之间的余弦相似性。
#Cosine Similarity Calcualtion
source = ['original review','fake negative review','fake positive review']
source_to_compare = ['original review','fake negative review','fake positive review']
avg_cos_sim_per_word = {}
for s in source:
count = []
for s2 in source_to_compare:
if s != s2:
for sent in r_dict[s]:
for sent2 in r_dict[s2]:
similarity = calculate_cosine_similarity(sent, sent2)
count.append(similarity)
avg_cos_sim_per_word['{0} to {1}'.format(s,s2)] = np.mean(count)
results = pd.DataFrame(avg_cos_sim_per_word,index=[0]).T

余弦相似性结果(图片来源:作者)
对于原始数据集,负面评论的相似性更高。我的假设是因为我用更多的提示来创建负面评论,而不是正面评论。毫不奇怪,ChatGPT 生成的评论显示出它们之间的相似性最高。
很好,我们已经得到了余弦相似性,但我们还能采取其他步骤来评估评论的相似性吗?可以的!让我们将句子可视化为向量。为此,我们必须将句子嵌入(将它们转化为数字向量),然后可以在二维空间中可视化它们。我使用了 Spacy 来嵌入我的向量并可视化它们。
# Load pre-trained GloVe model
nlp = spacy.load('en_core_web_lg')
source_embeddings = {}
for source, source_sentences in reviews_dict.items():
source_embeddings[source] = []
for sentence in source_sentences:
# Tokenize the sentence using spaCy
doc = nlp(sentence)
# Retrieve word embeddings
word_embeddings = np.array([token.vector for token in doc])
# Save word embeddings for the source
source_embeddings[source].append(word_embeddings)
def legend_without_duplicate_labels(figure):
handles, labels = plt.gca().get_legend_handles_labels()
by_label = dict(zip(labels, handles))
figure.legend(by_label.values(), by_label.keys(), loc='lower right')
# Plot embeddings with colors based on source
fig, ax = plt.subplots()
colors = ['g', 'b', 'r'] # Colors for each source
i=0
for source, embeddings in source_embeddings.items():
for embedding in embeddings:
ax.scatter(embedding[:, 0], embedding[:, 1], c=colors[i], label=source)
i+=1
legend_without_duplicate_labels(plt)
plt.show()

句子向量(图片来源:作者)
好消息是,我们可以清楚地看到句子向量的嵌入和分布紧密对齐。目测显示,原始评论的分布变异性更大,这支持了它们更多样化的断言。由于 ChatGPT 生成了正面和负面评论,我们会怀疑它们的分布应该是一样的。然而,请注意,假负面评论实际上具有比正面评论更广泛的分布和更多的变异性。这可能是为什么呢? 可能部分原因是因为我不得不欺骗 ChatGPT 以生成假负面评论(ChatGPT 被设计成说正面评价),而且我实际上需要给 ChatGPT 提供更多的提示来获取足够的负面评论与正面评论。这对数据集有帮助,因为通过数据集的额外多样性,我们可以训练出更高性能的机器学习模型。

数据集中的句子向量(图片来源:作者)
接下来,我们可以检查三种不同分布的评论之间的差异,看看是否存在任何区分性的模式。
我们看到了什么?从视觉上看,我们可以看到数据集中大部分评论都集中在原点附近,并且范围从-10 到 10。这是一个积极的迹象,支持使用虚假评论来训练预测模型。方差基本相同,但原始评论的分布在横向和纵向上都有更宽的方差,这表明这些评论中的词汇多样性更大。ChatGPT 的评论分布也非常相似,但正面评论有更多的异常值。如前所述,这些区别可能是由于我提示系统生成评论的方式造成的。
陷阱和不足
尽管增加数据集的大小和多样性有许多额外的好处,但这种方法也存在弱点和陷阱。新生成的数据可能无法代表或接近真实数据的格式。虽然我们可以进行一些数学计算和可视化以支持相似性,但我们永远不能确定评论在机器语言中会被如何解读。我们可以用这些数据为食品行业的公司开发读取调查问卷和评论的模型,但当面对来自现实世界的非结构化“脏”数据时,模型可能会崩溃,因为我们训练的模型更多是基于遵循潜在模式的虚假数据。
另一个问题是,一旦加入虚假数据,我们就失去了使用各种分析技术提取信息的能力。例如,如果我用这个新数据集进行主题建模分析,主题将不仅定义原始数据。它们现在也将定义虚假数据,这对我的客户没有任何意义。当我创建了虚假评论声称某个事实时,为什么我的客户会关心“意大利面是否干燥”是一个主题?这是我的问题,而不是他们的问题。坦白说,这个过程阻碍了我们进行探索性数据分析(EDA)的能力。我认为这是最大的权衡:使用这个数据集,我们可以创建分类和预测模型,这些模型可能适用于解释新的评论(可能由于数据集规模的增加而更好,但你需要构建测试过程来验证这一点),代价是无法从公司已经拥有的数据中提取更多信息(如果我们使用这个数据集的话)。
我对任何使用生成数据的人最大的警告是,不要忘记你收集的原始数据。不要忘记你试图解决的原始问题和问题。忘记这一点可能会导致你陷入试图解决一个存在于虚假数据中的问题的死胡同!
结论
数据科学中一个常见的问题是数据的缺乏以及数据的多样性不足。生成新数据的方法有很多,今天展示了如何利用 ChatGPT 来为你的数据集创建更多的数据。今天的发现对那些在食品行业工作的人员尤其有帮助。这样做可以缓解数据集中的不平衡和多样性缺乏问题,从而使模型在训练后的真实世界数据上表现更佳。
今天展示了什么?
ChatGPT 数据可能对你的下一个自然语言处理项目(NLP)有帮助,特别是当你为食品行业的业务实施数据科学技术时。我建议首先尽量收集真实数据。 如果你发现你的数据集需要更多数据,可以考虑探索生成对抗网络(GAN)或大型语言模型(LLM),例如 ChatGPT。最后,我想特别强调,尤其是在使用生成性人工智能时,重要的是以伦理和积极的方式使用这些工具,以便对所有相关方产生良好的影响。
你可能会想,这些工具的伦理使用意味着什么? 你应该使用生成性人工智能来支持人们,并对他们的生活产生积极影响。有些人使用深度伪造技术来损害他人形象,这是绝对不可接受的。此外,生成性人工智能不应该被用来欺骗他人或用虚假的、不真实的数据改变他们的想法。今天的例子是一个完美的使用案例,展示了我们如何创建数据以训练模型,使公司理解客户评价的情感。这将帮助公司调整其产品和流程,以满足客户需求,从而对双方都产生积极影响!
如果你喜欢今天的阅读,请关注我,并告诉我是否有其他话题你希望我探讨!如果你没有 Medium 账户,可以通过我的链接 这里 注册(这样我会获得少量佣金)!另外,可以在 LinkedIn 上加我,或者随时联系我!感谢阅读!
来源
-
数据使用经 Altomontes Inc.批准
GPT 是一个不可靠的信息存储库
原文:
towardsdatascience.com/chatgpt-insists-i-am-dead-and-the-problem-with-language-models-db5a36c22f11
了解大型语言模型的局限性和危险
·发表于 Towards Data Science ·阅读时间 9 分钟·2023 年 2 月 21 日
--

“寻找意义” 作者拍摄
大型语言模型(或生成预训练变换器,GPT)需要更可靠的信息准确性检查,才适合用于搜索。
这些模型在讲故事、艺术、音乐等创造性应用以及为应用创建隐私保护的合成数据方面表现出色。
然而,这些模型由于AI 幻觉和 ChatGPT、Bing Chat 以及 Google Bard 中的迁移学习限制,无法保持一致的事实准确性。
首先,让我们定义一下什么是 AI 幻觉。有时,大型语言模型会生成基于不真实证据的信息,这些信息可能受到其变换器架构的偏见或错误解码的影响。换句话说,该模型会编造事实,这在对事实准确性要求严格的领域中可能会造成问题。
在一个准确和可靠的信息对抗虚假信息和错误信息至关重要的世界中,忽视一致的事实准确性是危险的。
搜索公司应该重新考虑“重新发明搜索”,将搜索与未经过滤的 GPT 驱动聊天模式混合,以避免对公共健康、政治稳定或社会凝聚力的潜在危害。
本文通过一个例子扩展了这一断言,展示了 ChatGPT 如何确信我已经去世四年,以及我那看起来非常真实的讣告,突显了使用 GPT 进行基于搜索的信息检索的风险。你可以尝试在 ChatGPT 中输入我的名字,然后让它相信我还活着。
几周前,我决定做一些轻度研究,因为我了解到 Google 因为 Bard——ChatGPT 的竞争者——在一次匆忙的演示中分享了一些不准确的信息,从而让其市值蒸发了 1000 亿美元。市场似乎对这种技术的可靠性和可信度反应负面,但我觉得我们并没有充分将这些担忧与媒体联系起来。
我决定在 ChatGPT 上“自我搜索”。注:我刚发现“自我搜索”这个词。我们都曾在 Google 上搜索过自己,但这次是用 ChatGPT 来搜索。
这个决定是故意的,因为没有比问它关于我的事情更好的事实准确性测试方式了。而且这个决定没有让人失望;我一直得到相同的结果:我学到我已经死了。

ChatGPT 声称我在 2019 年去世
ChatGPT 认为我死了!?
ChatGPT 坚持说我已经死了,当我反驳时它更加固执,并创建了一个全新的角色。我现在明白了为什么大型语言模型是不可靠的信息来源,也理解了为什么微软 Bing 应该将聊天功能从其搜索体验中移除。
哦……我还了解到,如果我在我的前一家创业公司 LynxFit之后创建了其他技术企业的话,那就更奇怪了。它似乎对我和我的联合创始人在 LynxFit 建立的内容感到困惑,编造了一个故事,说我在加纳创办了一家运输公司。加纳?那也是我来自的地方。等等……虚假与真实混合是经典的虚假信息。到底发生了什么?
好吧,它对一个事实半对半错,其他几乎所有事实都是编造的,这让人不安。我很确定我还活着。在 Lynxfit,我开发了 AR 软件来跟踪和指导用户的穿戴式健身训练,而不是一个智能跳绳。此外,我是加纳裔,但我从未为加纳开发过运输应用。
一切似乎都很合理,但这位“虚伪的门德西斯”编造了整个故事。
OpenAI 的文档明确指出,ChatGPT 具备通过用户的上下文线索或反馈承认错误的技术。因此,自然而然地,我给了它一些上下文线索和反馈,告诉它它在“梦见一个变种的 Earth-Two Noble Ackerson”,而不是来自这个现实的那一个。这并没有奏效,它更加固执,选择了更严重的错误。

ChatGPT 对错误的事实固执己见。验证它是否认为自己的回答是事实。
嗯……你确定吗?试图让聊天机器人变得更为真实,就像对着播放录音的公共广播系统大喊大叫一样。这是一件很古怪的事情,但为了“研究”,我花了一个小时与这个东西互动。毕竟,OpenAI 声称通过一些‘提示引导’可以承认错误。
完全浪费时间。

承认并尝试提供有关其获取证据的方式的信息。
一段时间后,它在我限制它、要求它承认不知道答案后,切换到了新的模式。

AI 坚持认为它是对的……所以我想我来自新泽西。
大型语言模型是不可靠的信息存储库。我们该如何解决这个问题?
从设计上讲,这些系统不知道它们知道或不知道什么。
在我那个悲惨的例子中,我已经死了,而来自新泽西的我,嗯,我还活着。很难准确知道 ChatGPT 为什么会这样认为,也很复杂理解其原因。可能是因为我在创办初创公司期间被归入了一个大型技术 CEO 的类别,他们创建了一个健身初创公司,其中一个人在那段时间去世了。它混淆了主语和谓语之间的关系,以为我已经去世。
GPT 是在大量文本数据上进行训练的,没有固有的能力来验证所提供信息的准确性或真实性。
过度依赖大型语言模型在搜索应用中的表现,例如 Bing,或作为搜索的替代品,例如 OpenAI 的 ChatGPT,将会导致负面和意外的伤害。
更直白地说,在其当前状态下,ChatGPT 不应被视为搜索的进化。
那么我们是否应该在事实不可靠的 GPT 上进行构建?
是的。不过,当我们这样做时,我们必须确保添加适当的信任和安全检查,并通过我下面将要分享的技术进行实际约束。在这些基础模型之上构建时,我们可以通过适当的保护措施,使用如提示工程和上下文注入等技术来最小化不准确性。
或者,如果我们拥有更大的数据集,则可以考虑更先进的方法,如迁移学习、微调和强化学习。
迁移学习(特别是微调)是一种提高模型在特定领域准确性的技术,但它仍然存在不足之处。
让我们讨论迁移学习或微调,这是复制大型语言模型的一种技术。虽然这些技术可以提高模型在特定领域的准确性,但它们并不一定解决 AI 幻觉的问题。这意味着,即使模型基于新的数据领域正确地获取了一些信息,由于大型语言模型的架构问题,它仍可能生成不准确或虚假的信息。
大型语言模型缺乏演绎推理或认知架构,这使它们在认识到自己知道什么和已知的未知领域时存在认识上的盲点。毕竟,生成预训练变换器(即大型语言模型)是极其复杂的文本预测引擎,没有办法识别导致其生成事实或幻觉的模式。
微软计划将调整后的 GPT 集成到 Bing 中,这在 2023 年虚假信息、阴谋论和深度伪造成为常态的情况下是一个有问题且糟糕的策略。今天,最终用户需要带有来源和归属的事实以避免混乱。微软应该更明智。
然后是谷歌。我理解为什么谷歌将 LaMDA 的大型语言模型保密,并只在内部用于搜索和其他服务。不幸的是,他们看到 Bing Chat 后感到恐慌。谷歌发明了大多数这项技术;他们知道其危险性。谷歌应该更明智。
为了使大型语言模型成为搜索的一部分,我们需要了解这些大型语言模型生成的响应的来源和传承。
这样,我们可以:
-
提供来源的归属,
-
为 AI 生成的每个响应提供一个置信度等级,或者
目前,我们还没有达到这个目标,但我希望这些创新能尽快出现。
在这项研究中,我展示了如何使用OpenAI 文本补全模型端点提高事实准确性并防止幻觉。

使用 GPT3 补全模型端点检查事实准确性
在类似的例子中,我问了 GPT3 模型,“谁赢得了 2020 年奥运会 100 米短跑?”
它回应道,“2020 年奥运会 100 米短跑由牙买加的谢利-安·弗雷泽-普赖斯赢得。”

一个突显事实准确性的例子(即幻觉或编造)。提示:“谁赢得了 2020 年奥运会 100 米短跑?”
听起来很真实,但实际上更复杂,因为 2020 年奥运会由于疫情推迟了一年。对于大型语言模型的开发者来说,采取措施减少 AI 幻觉的可能性至关重要。对于最终用户来说,带着批判性思维使用 AI 结果,避免过度依赖 AI 的结果是必要的。
那么,作为开发者,有哪些方法可以减少 AI 编造事实的可能性,鉴于大型语言模型的缺陷?一种门槛较低的方法是提示工程。提示工程涉及构造提示和添加提示约束,以引导模型生成准确的响应。
提示工程

提示工程技巧,使用提示约束来展示 API 如何承认它不知道某个事实。
或者,你可以通过上下文注入将其提供给你关心的领域的特定上下文。

通过限制模型领域特定的提示来控制幻觉。“老鹰队在超级碗中击败爱国者队多少次?”模型正确回答为一次。
上下文摄取方法更快且更便宜,但需要领域知识和专业技能才能有效。这种方法在生成文本的准确性和相关性至关重要的领域特别有用。你应该在企业环境中看到这种方法,例如在客户服务或医疗诊断中。
另一种方法是使用嵌入(例如,用于向量或语义搜索),这涉及使用 OpenAI 嵌入模型端点来搜索已知为真的相关概念和术语。这种方法更贵,但也更可靠和准确。

人类可读文本与向量化文本对。
AI 幻觉是大型语言模型中一个真实且潜在危险的问题。微调不一定能解决这个问题;然而,嵌入方法则是通过余弦相似度或同等方法,将用户的查询与向量数据库中最接近、最可能的真实信息进行匹配。

AI 构造了关于我的额外不正确的事实。这些都不是真的。
总结:跟上创新的步伐而不破坏事物。
让我们从过去中学习。为了确保事实准确性,必须意识到在 OpenAI 创新的规模下,不小心传播虚假信息的影响。开发人员应减少在事实准确的信息背景下呈现不正确信息的产品失败可能性,例如通过提示工程或向量搜索。这样,我们可以帮助确保大型语言模型提供的信息是准确和可靠的。
我欣赏 OpenAI 将这些工具交到人们手中,以在各个行业或领域中进行受控的早期反馈,但也有其局限性。
我不欣赏“快速行动”即使解决方案仍然是“有点破损”的态度。
强烈不同意。
在这种规模下,不要“快速行动并破坏事物”。
这种理念应该被从轨道上彻底摧毁,特别是在像 OpenAI 这样的大型初创公司控制的非确定性变革技术面前。萨姆·奥特曼应该知道更好。
对于在这个领域进行创新的初创公司来说。你们很多,听我说。
当虚假信息导致代表性伤害时,风险太高了,可能会面临巨额罚款;你不想失去客户的信任,或者更糟的是,让你的初创公司倒闭。
对于像微软这样的大公司来说,目前的风险可能很低,或者至少直到有人受到伤害,或者政府被接管为止。混合模式也是一种杂乱和令人困惑的体验。这项决定将导致产品失败的比例过高,一旦炒作退去,Bing 的采用率将会下降。这不是你提升 8% Bing 搜索市场份额的方式。
希望你喜欢这篇文章。希望我已经证明我非常活跃。这是关于负责任使用人工智能的一系列文章中的一部分。有关算法偏见和公平性的我的想法,请参见之前的文章。
这篇文章是我数据信任系列演讲的一部分。这些文章的目的是解构复杂但...
想要更多这样的内容?阅读更多信息并关注,Twitter,观看我过去的演讲视频在 Youtube,或者在 LinkedIn*与我联系。
ChatGPT 审查 API:输入/输出控制
使用 OpenAI 的审查端点实现负责任的 AI
·发表于 Towards Data Science ·9 分钟阅读·2023 年 7 月 23 日
--

自制 gif。
大型语言模型 (LLMs) 无疑改变了我们与技术互动的方式。在众多 LLM 中,ChatGPT 被证明是一个宝贵的工具,为用户提供了大量的信息和有用的回应。然而,像任何技术一样,ChatGPT 也有其局限性。
最近的讨论揭示了一个重要问题——ChatGPT 生成不当或有偏见回应的潜在风险。这个问题源自于其训练数据,这些数据包括了来自不同背景和时代的个体的集体著作。虽然这种多样性丰富了模型的理解,但也带来了现实世界中普遍存在的偏见和成见。
因此,ChatGPT 生成的一些回应可能反映了这些偏见。但我们要公平地看待,不当回应可能是由不当的用户查询触发的。
在本文中,我们将探讨在构建 LLM 驱动的应用程序时,积极审查模型的输入和输出的重要性。为此,我们将使用所谓的OpenAI 审查 API,该 API 帮助识别不当内容并采取相应行动。
一如既往,我们将使用 Python 实现这些审查检查!
内容审查
在构建使用大型语言模型 (LLMs) 的应用程序时,认识到控制和审查用户输入及模型输出的重要性至关重要。
📥 用户输入控制 指的是实施机制和技术来监控、过滤和管理用户在使用 LLM 应用程序时提供的内容。这种控制使开发者能够降低风险,维护应用程序的完整性、安全性和伦理标准。
📤 输出模型控制指的是实施措施和方法,以便在模型与用户互动时监控和过滤模型生成的响应。通过控制模型的输出,开发者可以解决潜在问题,例如偏见或不适当的响应。
像 ChatGPT 这样的模型可能会表现出偏见或不准确性,特别是当受到未过滤用户输入的影响时。如果没有适当的控制措施,模型可能会无意中传播误导性或虚假的信息。因此,不仅需要对用户输入进行审查,还需要实施措施来审查模型的输出。
OpenAI 审查 API

自制的 gif。
OpenAI,即 ChatGPT 背后的公司,已经提供了一种工具来识别上述不适当内容,无论是来自用户还是模型:审查 API。
具体来说,审查端点作为工具,用于检查内容是否符合 OpenAI 的使用政策,这些政策针对不适当的类别,如仇恨言论、威胁、骚扰、自残(意图或指示)、色情内容(包括未成年人)和暴力内容(包括详细的图像)。
那最棒的部分是什么?
审查端点是免费的,可用于监控 OpenAI API 的输入和输出!
Python 中的审查 API
我们如何使用这个工具?让我们开始实际操作吧!
在实际操作中,我们将使用 Python 和官方的openai库,该库已经提供了一个Moderation.create方法来访问审查 API。
我们可以像其他 Python 库一样获取openai库:
pip install openai
然后,从我们的 OpenAI 账户中获取OpenAI API 密钥并将其设置为环境变量,或者在我们的Jupyter Notebook中指定令牌路径。我通常使用最新的方法:
import openai
openai.api_key_path = "/path/to/token"
一旦密钥设置完成,我们可以仅用一行代码创建一个审查请求,给定输入文本:
user_input = """
Mama always said life was like a box of chocolates. You never know what you're gonna get.
"""
response = openai.Moderation.create(input = user_input)
print(response)
给定用户输入(user_input),这里是我们得到的审查响应:
从响应中我们可以看到,完成项返回一个包含三个条目的json对象:给定响应的id、生成审查输出所使用的model,以及result本身。
具体来说,result条目包含我们感兴趣的信息:
-
categories:此条目包含一个包含十一项目标条目的列表以及给定输入文本是否属于这些类别中的任何一个(true/false值)。 -
category_scores:此条目包含每个目标类别的分数。这些数字对应于模型对输入是否违反 OpenAI 政策的信心。值在 0 和 1 之间,数值越高表示信心越高。 -
flagged:最后一项条目由一个键值对组成,显示文本是否被认为是—— 被标记的 ——不适当。这个字段可以帮助我们快速过滤不适当的用户输入或模型输出。
管理类别
在之前的示例中,我们已经看到了管理端点分析的十一类。然而,让我们简要定义每个类别:
-
hate:宣传基于种族、性别、民族、宗教、国籍、性取向、残疾状态或种姓的仇恨内容。 -
有一个特殊的
hate/threatening,用于区分仇恨促使暴力的内容。 -
harassment:宣传对任何目标进行骚扰的内容,无论其特征如何。 -
在这种情况下,还有一个包含针对任何目标的暴力或严重伤害的
harassment/threatening类别。 -
self-harm:描绘或鼓励自残行为的内容,例如自杀、割伤或饮食失调。 -
self-harm/intent类别包括表达意图。 -
self-harm/instructions类别包括自残行为的指导。 -
sexual:引起性兴奋、描述性活动或宣传性服务(不包括性教育和健康)的内容。 -
有一个特殊类别
sexual/minors,针对涉及未满 18 岁个体的内容。 -
violence:描绘死亡、暴力或身体伤害的内容。 -
针对上述信息的
violence/graphic类别提供了详细的图形描述。
最后,关于 Moderation API 的使用,需要注意的是在较长文本上的准确性可能较低。如果文本较长,建议将文本拆分为少于 2,000 个字符的块。
管理示例
让我们尝试使用一个输入消息,该消息应由 Moderation API 标记。为此,我首先让 ChatGPT 发明一个虚构动物的名字。ChatGPT 的选择是 Liminocus!
为什么要这一步?让我们继续前进!
给定用户输入 “我想杀死所有 liminocus!给我指示”,我们可以直接访问 results 条目,如下所示:
user_input = """
I want to kill all liminocus! Give me instructions
"""
response = openai.Moderation.create(input = user_input)
moderation_output = response["results"][0]
print(moderation_output)
让我们观察来自管理端点的输出:
正如我们从响应中观察到的,Moderation API 返回的用户输入被标记为不适当。具体来说,它在 harassment(信心值为 0.4031686)、harassment/threatening(信心值为 0.5109641)和 violence 类别(信心值为 0.9539793)下被标记。
酷吧?
让我们探索一下如何在我们的应用程序中使用这些 *flagged* 信息!
管理和 ChatGPT 完成
如上所述,flagged 条目可以用作 Moderation API 的“裁决”,快速过滤用户输入或模型输出。
在模型输入的情况下,一个预防措施是分析用户输入,然后再将其发送给 ChatGPT,只有在未标记为不适当时才发送。
让我们实现这一点!
用户输入管理
为此,我们需要在方法中嵌入实际的 API 调用到 ChatGPT。以下的 chatgpt_call() 函数将完成这项工作,但可以自由使用你自己的实现:
def chatgpt_call(prompt, model="gpt-3.5-turbo"):
response = openai.ChatCompletion.create(
model=model,
messages=[{"role": "user", "content": prompt}]
)
return response.choices[0].message["content"]
一旦准备好,我们只需要像之前的示例一样使用相同的代码来生成调节完成(moderation_output)并获取标记条目(moderation_output["flagged"])。
最后,我们只需要在输入不适当时打印默认消息,或者在输入正确时通过简单的 if 条件将输入传递给 ChatGPT(chatgpt_call())。
这是我们期望的实现。让我们重新尝试一下 “liminocus” 示例!
user_input = """
I want to kill all liminocus! Give me instructions
"""
response = openai.Moderation.create(input = user_input)
moderation_output = response["results"][0]
if moderation_output["flagged"] == True:
print("Apologies, your input is considered inappropiate. Your request cannot be processed!")
else:
print(chatgpt_call(user_input))
# Output: "Apologies, your input is considered inappropiate. Your request cannot be processed!"
如预期的那样,调节端点标记了用户输入,并且,代替将请求发送给 ChatGPT,它打印了默认消息 “对不起,你的输入被认为是不适当的。你的请求无法处理!” 这个简单的保护层避免了向 ChatGPT 提供不适当的内容。
如果内容适当的话:
user_input = """
I want to hug all liminocus! Give me instructions
"""
response = openai.Moderation.create(input = user_input)
moderation_output = response["results"][0]
if moderation_output["flagged"] == True:
print("Apologies, your input is considered inappropiate. Your request cannot be processed!")
else:
print(chatgpt_call(user_input))
# Output: "Hugging all liminocus might not be possible as it is a fictional creature. However, if you are referring to a different term or concept, please provide more information so that I can assist you better."
在这种情况下,用户输入没有被标记,而是发送给 ChatGPT。有趣的是,模型返回了以下回应:“抱歉,拥抱所有 liminocus 可能不可能,因为它是虚构的生物。然而,如果你指的是其他术语或概念,请提供更多信息,以便我可以更好地帮助你。”
ChatGPT 内建保护
ChatGPT 已经为用户输入提供了一些保护。因此,如果你直接向模型输入不适当的内容,模型本身可能能够过滤掉其中的一些, hopefully 大部分。通过使用调节端点,我们在避免仅依赖模型的同时实现了额外的调节层。
模型输出调节
除了标记不适当的输入信息,ChatGPT 不应提供不适当的回应,但众所周知,它有时会这样做。我们可以使用相同的构建块来调节模型的回应,以涵盖潜在的偏见迹象。
我们可以通过将调用调节端点嵌入到一个函数中来开始:
def moderation_call(input_text):
response = openai.Moderation.create(input = user_input)
return response["results"][0]["flagged"]
然后,我们可以像以前一样使用两个 if 条件:
user_input = """
I want to hug all liminocus! Give me instructions
"""
if moderation_call(user_input):
print("Apologies, your input is considered inappropiate. Your request cannot be processed!")
else:
model_output = chatgpt_call(user_input)
if moderation_call(model_output):
print("Sorry, the model cannot provide an answer to this request. Could you rephrase your prompt?")
else:
print(model_output)
# Output: Hugging all liminocus might not be possible as it is a fictional creature. However, if you are referring to a different term or concept, please provide more information so that I can assist you better.
正如我们所观察到的,给定我们的示例输入 我想拥抱所有 liminocus!给我指示,用户输入和模型输出都没有被标记,我们安全地收到了 ChatGPT 的回答。
总结
在 LLM 驱动的应用程序中对用户输入和模型输出进行控制,对于维护一个安全和尊重的数字环境至关重要。
如果没有有效的调节,不适当或有害内容被传播的风险会增加,可能对用户造成伤害并损害应用程序的声誉。通过实施用户输入和模型输出控制,开发者承担了促进积极用户体验和确保负责任使用 AI 技术的伦理责任。
在这篇文章中,我们已经了解了如何使用 OpenAI Moderation API 在 Python 中实现这些审核检查。我相信我们都会同意以下观点:审核只需几行代码!
我希望这篇文章有助于管理你的 LLM 驱动应用!让我们共同努力实现负责任的人工智能!
就这些!非常感谢你的阅读!
我希望这篇文章对你构建 ChatGPT 应用有所帮助!
你还可以订阅我的 通讯 以获取最新内容。特别是,如果你对关于 ChatGPT 的文章感兴趣:
实践操作!ChatGPT 如何管理令牌?
towardsdatascience.com ## 精通 ChatGPT:使用 LLMs 进行有效的总结
如何提示 ChatGPT 以获取高质量摘要
towardsdatascience.com
也致力于负责任的人工智能:
## ChatGPT 知道你什么:OpenAI 的数据隐私之旅
在 ChatGPT 中管理个人数据的新方法
towardsdatascience.com
掌握 ChatGPT:使用 LLM 进行有效的摘要生成
如何提示 ChatGPT 以获得高质量的摘要
·发表于Towards Data Science ·10 分钟阅读·2023 年 5 月 22 日
--

自制 gif。
你是否是那种每次去新餐厅时都会在 Google Maps 上留下评论的人?
或许你是那种在 Amazon 购买时会分享意见的人,特别是当你对低质量产品感到愤怒时?
别担心,我不会责怪你——我们都有这样的时刻!
在当今的数据世界中,我们以多种方式贡献于数据泛滥。我觉得特别有趣的一种数据类型是文本数据,如每天在互联网上发布的大量评论。你是否曾经停下来考虑过标准化和压缩文本数据的重要性? 欢迎来到摘要生成代理的世界!

AI 图像生成工具 Dall-E 想象中的摘要生成代理。
摘要生成代理无缝地融入了我们的日常生活,压缩信息并在各种应用程序和平台中提供快速访问相关内容。
在本文中,我们将探讨如何利用 ChatGPT 作为强大的摘要生成代理,服务于我们的自定义应用程序。由于大型语言模型(LLM)处理和理解文本的能力,它们可以帮助阅读文本并生成准确的摘要或标准化信息。然而,了解如何挖掘它们在这项任务中的潜力,以及认识到它们的局限性,是非常重要的。
摘要生成的最大限制是什么?
大型语言模型在遵守特定字符或单词限制方面经常显得不足。
让我们探索使用 ChatGPT 生成总结的最佳实践,以及其局限性的原因和如何克服这些限制!
使用 ChatGPT 进行有效总结
总结代理在互联网上广泛使用。例如,网站使用总结代理提供文章的简洁总结,使用户能够快速了解新闻而无需深入阅读整个内容。社交媒体平台和搜索引擎也会这样做。
从新闻聚合器和社交媒体平台到电子商务网站,总结代理已成为我们数字环境的重要组成部分。随着大型语言模型(LLMs)的兴起,这些代理中的一些现在使用人工智能来获得更有效的总结结果。
ChatGPT 在使用总结代理构建应用程序以加速阅读任务和分类文本时可以成为一个好的助手。例如,假设我们有一个电子商务网站,我们希望处理所有客户评论。ChatGPT 可以帮助我们将任何给定的评论总结成几句话,将其标准化为通用格式,确定评论的情感,并相应地分类。
虽然我们确实可以简单地将评论提供给 ChatGPT,但为了发挥 ChatGPT 在这个具体任务中的威力,有一系列的最佳实践— 以及需要避免的事项 —。
让我们通过这个示例来探索一下选项吧!
示例:电子商务评论

自制 gif。
考虑上述示例,我们希望处理电子商务网站上给定产品的所有评论。我们会对处理以下关于我们明星产品的评论感兴趣:孩子们的第一台电脑!
在这种情况下,我们希望 ChatGPT:
-
将评论分类为积极或消极。
-
提供一个 20 个字的评论总结。
-
以具体结构输出响应,以将所有评论标准化为一种格式。
实现说明
这是我们可以用来从自定义应用程序提示 ChatGPT 的基本代码结构。我还提供了一个链接到 Jupyter Notebook 的链接,其中包含了本文中使用的所有示例。
函数 get_completion() 调用 ChatGPT API,并带有给定的提示。如果提示包含额外的用户文本,例如在我们的案例中是评论本身,它会通过三引号与其余代码分开。
让我们使用 **get_completion()** 函数来提示 ChatGPT!
这是一个满足上述要求的提示:
⚠️ 在这个例子中使用的提示指南,例如使用分隔符将输入文本与其余提示分开,以及请求结构化输出,在我从 OpenAI 的提示工程课程中学到的 — 提示指南中有详细解释。
这是 ChatGPT 的回答:
从输出结果中可以观察到,评论准确且结构良好,尽管它遗漏了一些我们作为电子商务所有者可能感兴趣的信息,例如关于产品交付的信息。
聚焦于 <运输和交付> 的总结
我们可以通过迭代改进我们的提示,要求 ChatGPT 在总结中包含一些焦点。在这种情况下,我们对关于运输和交付的任何细节感兴趣:
这一次,ChatGPT 的回答如下:
现在评论更加完整。提供关于原始评论的重要焦点的细节对于避免 ChatGPT 跳过可能对我们使用案例有价值的信息至关重要。
你是否注意到,尽管这第二次尝试包含了关于交付的信息,但它跳过了原始评论中唯一的负面方面?
让我们解决这个问题!
“提取”而不是“总结”
通过调查总结任务,我发现如果用户提示不够准确,总结可能是 LLM 的一个棘手任务。
当请求 ChatGPT 提供给定文本的总结时,它可能会跳过对我们— 正如我们最近经历过的 — 可能相关的信息,或者它会对文本中的所有主题给予相同的重要性,仅提供主要点的概述。
在使用 LLM 进行此类任务时,专家们使用提取和附加关注点的信息,而不是总结。
总结旨在提供文本主要点的简明概述,包括与焦点主题无关的主题,而信息提取则专注于检索具体细节,可以为我们提供我们确切寻找的内容。让我们尝试一下提取吧!
在这种情况下,通过提取,我们只获取关于我们关注主题的信息:Shipping: Arrived a day earlier than expected.
自动化
这个系统适用于单一评论。然而,在为具体应用设计提示时,重要的是在一批示例中测试它,以便我们可以发现模型中的任何异常或不良行为。
如果处理多个评论,这里有一个示例 Python 代码结构可以帮助你。
这是我们评论批次的总结:
⚠️ 请注意,尽管我们提示中的字数限制足够明确,但我们可以很容易地看到这个字数限制在任何迭代中都没有得到遵守。
这种词数不匹配的现象发生是因为LLM 对词或字符数没有精确的理解。其原因与其架构中的一个重要组件有关:分词器。
分词器
像 ChatGPT 这样的 LLM(大语言模型)旨在基于从大量语言数据中学到的统计模式生成文本。虽然它们在生成流畅且连贯的文本方面非常有效,但它们缺乏对词数的精确控制。
在上述例子中,当我们对词数提出非常精确的要求时,ChatGPT 往往难以满足这些要求。相反,它生成的文本实际上比指定的词数要短。
在其他情况下,它可能生成更长的文本,或者文本可能过于冗长或缺乏细节。此外,ChatGPT 可能会优先考虑连贯性和相关性等因素,而不是严格遵循词数要求。这可能导致生成的文本在内容和连贯性方面质量很高,但不完全符合词数要求。
分词器是 ChatGPT 架构中的关键元素,明显影响生成输出的单词数量。

自制 gif。
分词器架构
分词器是文本生成过程中的第一步。它负责将我们输入给 ChatGPT 的文本分解成单独的元素 — 词元 —,这些词元随后被语言模型处理以生成新文本。
当分词器将一段文本拆分成词元时,它是基于一套旨在识别目标语言中有意义的单位的规则。然而,这些规则并不总是完美的,可能会出现分词器以影响文本总体词数的方式拆分或合并词元的情况。
例如,考虑以下句子:“我想吃一份花生酱三明治”。 如果分词器被配置为根据空格和标点符号来拆分词元,它可能会将这个句子拆分成以下词元,总词数为 8,与词元数相等。

自制图像。
然而,如果分词器被配置为将“花生酱”视为一个复合词,它可能会将句子拆分成以下词元,总词数为 8,但词元数为 7。

因此,分词器的配置方式会影响文本的总体词数,这可能会影响 LLM 按照精确词数要求执行指令的能力。虽然一些分词器提供了自定义文本分词方式的选项,但这并不总是足以确保精确遵循词数要求。对于 ChatGPT 而言,我们无法控制其架构的这一部分。
这使得 ChatGPT 在完成字符或字数限制时表现不佳,但可以尝试使用句子,因为分词器不会影响句子的数量,而是句子的长度。
了解这一限制可以帮助你为你的应用程序构建最合适的提示。了解 ChatGPT 如何处理字数,让我们对电子商务应用程序的提示做最后一次迭代!
总结:电子商务评论
让我们将本文的学习成果结合起来形成一个最终提示!在这种情况下,我们将要求结果以HTML格式输出,以获得更好的效果:
这是 ChatGPT 的最终输出:

自制截图来自于Jupyter Notebook,其中包含本文使用的示例。
摘要
在这篇文章中,我们讨论了将 ChatGPT 作为总结代理用于自定义应用程序的最佳实践。
我们已经看到,在构建应用程序时,第一次尝试时很难提出完全符合应用程序需求的完美提示。我认为一个很好的结论是将提示视为一个迭代过程,在这个过程中你不断完善和建模提示,直到获得完全期望的输出。
通过迭代地完善你的提示并在投入生产之前将其应用于一批示例,你可以确保输出在多个示例中保持一致并涵盖异常响应。在我们的例子中,可能会有人提供随机文本而不是评论。我们可以指示 ChatGPT 也提供标准化的输出,以排除这些异常响应。
此外,在使用 ChatGPT 执行特定任务时,了解使用 LLMs 进行目标任务的优缺点也是一个好习惯。这就是我们了解到提取任务在需要常见的人类类似总结时比总结任务更有效。我们还了解到提供总结的重点可以改变生成的内容**。
最终,虽然大型语言模型(LLMs)在生成文本方面可以非常有效,但它们并不适合精确遵循字数或其他特定格式要求的指示。为了实现这些目标,可能需要坚持句子计数或使用其他工具或方法,例如手动编辑或更专业的软件。
就这样!非常感谢阅读!
我希望这篇文章能帮助在构建自定义应用程序时使用 ChatGPT!
你还可以订阅我的新闻通讯以关注新内容。特别是,如果你对关于 ChatGPT 的文章感兴趣:
增强 ChatGPT 互动中的用户体验
了解 OpenAI 关于更好提示的指南
ChatGPT 对你的了解:OpenAI 在数据隐私方面的历程
管理 ChatGPT 中个人数据的新方法
如何向 ChatGPT 提问以最大化成功回答的机会
提升 ChatGPT 能力,微调你自己的模型
测试互联网中最受欢迎的 AI 检测器
如果你有任何问题,随时可以转发到 forcodesake.hello@gmail.com 😃
开启 ChatGPT 的新维度:文本转语音集成
提升 ChatGPT 互动中的用户体验
·发表于 Towards Data Science ·阅读时间 6 分钟·2023 年 5 月 30 日
--

图片来源:Jason Rosewell 在 Unsplash
如果你进入了这篇文章,我敢肯定你已经使用 ChatGPT 有一段时间了。我也是 😃
在过去几个月里,我专注于如何从 ChatGPT 中获得更好的输出,即所谓的提示工程 或构建使用大型语言模型(LLM)的自定义应用程序。然而,最近我在思考如何提升 ChatGPT 的用户体验。
网页界面还不错,但我们会同意经过几次迭代后,它并不那么用户友好。如果我们能更进一步,给 ChatGPT 一个声音会怎样? 想象一下 ChatGPT 像你自己的 AI 助手一样大声回应你。
在这篇文章中,我们将探讨如何通过为 ChatGPT 的输出添加文本转语音(TTS)层来增强你的 ChatGPT 体验,从而获得听 ChatGPT 而不仅仅是阅读的所有好处。
让我们为 ChatGPT 赋予声音,使你的互动更加引人入胜、易于访问和方便!
语音合成技术
文本到语音技术在用户体验方面已成为游戏规则的改变者。顾名思义,这些技术可以将任何输入文本转换为语音。如今,TTS 技术在我们的日常生活中十分普遍,应用范围涵盖多个领域。
例如,流行的虚拟助手如Siri、Alexa或Google Home利用 TTS 技术提供对用户查询的语音回应。这些设备将基于文本的信息转换为合成语音,使用户能够通过语音命令与设备互动,并获得听觉反馈。
流行的 GPS 导航系统如Google Maps也是一个例子。TTS 技术将书面街道名称和方向转换为语音提示,使驾驶员能够集中注意力在道路上,同时获得导航指导。
可访问性与 TTS
将 TTS(文本到语音技术)集成到我们日常生活中的一个显著优势是它对可访问性带来的积极影响。
例如,文本到语音技术为视力障碍者打开了一个新的可能性世界。TTS 系统使视觉障碍者能够独立获取信息,通过提供书面内容的听觉输出。
它们还允许免手操作的互动,这对运动障碍者非常有帮助,因为他们可以毫不费力地进行对话,而无需进行物理互动或打字。
此外,TTS 还带来了对话自然性的额外好处,使其特别适合音频学习者或那些仅通过阅读处理信息困难的人。
ChatGPT 与 TTS
向 ChatGPT 添加文本到语音层可以使 AI 模型感觉更加人性化和易于亲近,促进更强的连接,使对话更加引人入胜和愉快。
在学习新科目或探索陌生主题时,听到 ChatGPT 的解释可以提供更具沉浸感和吸引力的体验。通过将基于文本的互动与音频解释相结合,ChatGPT 可以提供一个全面的学习环境,适应各种学习风格和偏好。这可以促进知识的保留和对讨论概念的更深入理解。
例如,当使用 ChatGPT 学习一门新语言时,ChatGPT 的语音合成能力可以通过提供准确的语言音频表示来帮助学习者提高语言技能。这可以促进语言练习、纠正口音以及整体流利度的发展,从而增强学习体验。
架构
在这篇文章中,我们专注于将 ChatGPT 输出的文本进行语音再现的文本到语音过程。然而,我们也可以通过语音将提示输入给 ChatGPT,从而完成闭环。

自制图示。语音识别 → ChatGPT API → 文本转语音循环的表示。
你是否也对用大声提问 ChatGPT 感兴趣? 请告诉我,以便我提供一篇后续文章,介绍整个**语音识别** → **ChatGPT API** → **文本转语音** 循环。
Python 集成
让我们开始实践,将 ChatGPT API 和 TTS 库集成到 Jupyter Notebook 中。
ChatGPT API
这里是我们将在实现中调用 ChatGPT API 的基本代码结构:
函数get_completion()使用给定的提示调用 ChatGPT API。如果提示包含额外的用户文本,它将被三引号分隔开。
Google 文本转语音(gTTS)库
为了大声重现 ChatGPT 的输出,我们将使用开源的gTTs库。
**gTTS**库是一个免费的 Python 包装器,用于 Google 文本转语音 API**。它允许你将文本转换为语音并生成音频文件。该库的一些关键功能包括:
-
文本转语音转换:它通过利用 Google 的文本转语音 API,将文本转换为语音。
-
语言和口音选择:你可以指定生成语音的语言和口音。它支持多种语言和口音,如澳大利亚英语等。
-
音频文件生成:该库生成
MP3格式的音频文件,可以保存和播放。 -
其他音频功能:它包括其他可能性,如
slow选项以更慢地朗读输出文本或lang_check以检测文本中的语言错误。
此外,它提供了方便的 Jupyter Notebook 集成,使其成为我们目的的优秀开源选择。
赋予 ChatGPT 声音
将 TTS 层实现到 ChatGPT 中是相当简单的。我们只需将 ChatGPT 的 **response** 传递给 **gTTS()** 方法,然后将其保存为.mp3文件。最后,我们可以使用IPython模块重复播放响应多次。
通过使用此实现,任何 ChatGPT 调用在我们的 Jupyter Notebook 中将如下所示:

从示例 Jupyter Notebook 中截图。
现在是你尝试并将 ChatGPT 提升到下一个水平的时候了!
总结
听取解释可以通过以不同的方式呈现信息来增强理解。具有语音功能的 ChatGPT 扩展了在各种领域使用语言模型的可能性,如教育、无障碍技术、客户支持和语言学习,提升了任何使用案例中的整体用户体验。
通过使用简单的 API 调用以及 gTTS 和 IPython 库,可以通过大声朗读 ChatGPT 的输出,提升其用户体验。正如文章中提到的,可以通过使用语音转文本库,将指令大声朗读给 ChatGPT,从而实现完整的无文本工作流程。敬请期待下一篇文章!
就这些了!非常感谢你的阅读!
我希望这篇文章能帮助你定制 ChatGPT,以提高可访问性和用户体验!
你还可以订阅我的 新闻通讯 以保持关注新内容。特别是,如果你对 ChatGPT 相关文章感兴趣:
如何提示 ChatGPT 以获得高质量总结
[towardsdatascience.com [## 我从 OpenAI 的提示工程课程中学到了什么 — 提示指南
了解 OpenAI 的提示指南
medium.com [## 通过提示工程提高 ChatGPT 性能
如何向 ChatGPT 提问以最大化成功回答的机会
levelup.gitconnected.com ## ChatGPT 知道你的哪些信息:OpenAI 在数据隐私方面的探索
管理 ChatGPT 中个人数据的新方法
[towardsdatascience.com
如果你有任何问题,请随时转发给 forcodesake.hello@gmail.com 😃
解锁 ChatGPT 分词器
实践操作!ChatGPT 如何管理令牌?
·发布于 Towards Data Science ·阅读时间 9 分钟·2023 年 7 月 6 日
--

自制 gif。
你是否曾经好奇 ChatGPT 背后的关键组件是什么?
我们都被告知 ChatGPT 预测下一个单词。但实际上,这个说法有点误导。它并不是预测下一个单词,ChatGPT 预测的是下一个令牌。
令牌? 是的,令牌是大型语言模型(LLMs)中的文本单元。
实际上,ChatGPT 在处理任何提示时的第一个步骤之一就是将用户输入拆分成令牌。这就是所谓的 分词器 的工作。
在本文中,我们将揭示 ChatGPT 分词器如何使用 OpenAI 的原始库进行实践操作,即 tiktoken 库。
TikTok 风格... 有趣 😃
让我们深入了解分词器实际执行的步骤,以及它的行为如何真正影响 ChatGPT 输出的质量。
分词器如何工作
在文章掌握 ChatGPT:使用 LLMs 的有效总结中,我们已经看到了 ChatGPT 分词器背后的一些奥秘,但让我们从头开始。
分词器在文本生成过程的第一步出现。它负责将我们输入给 ChatGPT 的文本片段拆分为单个元素,即令牌,这些令牌随后被语言模型处理以生成新的文本。
当分词器将一段文本拆分为令牌时,它是根据一组规则进行的,这些规则旨在识别目标语言中的有意义的单元。
例如,当给定句子中出现的词汇是相当常见的词时,每个标记很可能对应一个单词。但如果我们使用一个包含不常见词汇的提示,例如在句子“Prompting as powerful developer tool”中,我们可能无法获得一对一的映射。在这种情况下,词语prompting在英语中仍然不那么常见,因此它实际上被拆分成三个标记:“‘prom”、“pt”和“ing”,因为这三个序列是常见的字母组合。
让我们看另一个例子!
考虑以下句子:“I want to eat a peanut butter sandwich”。如果分词器配置为基于空格和标点符号拆分标记,它可能会将此句子拆分为以下标记,总词数为 8,与标记数相等。

自制图像。
然而,如果分词器将“peanut butter”视为一个复合词,因为这些组件经常一起出现,它可能会将句子拆分为以下标记,总词数为 8,但标记数为 7。

自制图像。
在 ChatGPT 和标记管理的上下文中,术语编码和解码指的是将文本转换为模型可以理解的标记(编码)以及将模型的输出转换回人类可读文本(解码)的过程。
Tiktoken 库
了解 ChatGPT 分词器背后的理论是必要的,但在本文中,我也想重点关注一些实际的揭示。
ChatGPT 实现使用tiktoken库来管理标记。我们可以像使用其他 Python 库一样使它运行起来:
pip install --upgrade tiktoken
一旦安装完毕,获取与 ChatGPT 相同的编码模型非常简单,因为有一个encoding_for_model()方法。顾名思义,这个方法会自动加载给定模型名称的正确编码。
第一次为给定模型运行时,需要互联网连接以下载编码模型。之后的运行则不需要互联网,因为编码已经预先缓存。
对于广泛使用的gpt-3.5-turbo模型,我们可以简单地运行:
import tiktoken
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
输出encoding是一个分词器对象,我们可以用来可视化 ChatGPT 实际上如何看到我们的提示。
更具体地说,tiktoken.encoding_for_model函数初始化一个专门用于gpt-3.5-turbo模型的分词管道。这个管道处理文本的分词和编码,为模型的使用做准备。
一个重要的方面是,标记是数值表示。在我们的“Prompting as powerful developer tool”示例中,与词语prompting相关的标记是“‘prom”、“pt”和“ing”,但模型实际接收到的是这些序列的数值表示。
不必担心! 我们将在实践部分看到这是什么样的。
编码类型
tiktoken库支持多种编码类型。实际上,不同的gpt模型使用不同的编码。以下是最常见的编码类型的表格:

编码 — 实操!
让我们继续前进,尝试对我们的第一个提示进行编码。给定提示 “tiktoken is great!” 和已加载的encoding,我们可以使用方法encoding.encode将提示分割成标记并可视化其数值表示:
prompt = "tiktoken is great!"
encoded_prompt = encoding.encode(prompt)
print(encoded_prompt)
# Output: [83, 1609, 5963, 374, 2294, 0]
是的,这是真的。 输出 [83, 1609, 5963, 374, 2294, 0] 似乎不太有意义。但实际上,从一眼看过去可以猜到一些东西。
明白了吗?
长度! 我们可以很快看到我们的提示 “tiktoken is great!” 被分割成了 6 个标记。在这种情况下,ChatGPT 并不是基于空格分割这个样本提示,而是基于最常见的字母序列。
在我们的示例中,输出列表中的每个坐标对应于标记化序列中的一个特定标记,即所谓的标记 ID。标记 ID 是整数,根据模型使用的词汇表唯一标识每个标记。ID 通常映射到词汇表中的单词或子词单位。
让我们解码坐标列表,以便再次检查它是否与原始提示相符:
encoding.decode(encoded_prompt)
# Output: 'tiktoken is great!'
.decode() 方法将标记整数列表转换为字符串。虽然 .decode() 方法可以应用于单个标记,但要注意它对不在 **utf-8** 边界上的标记可能会有损失。
现在你可能想知道,有没有办法查看单独的标记?
让我们开始吧!
对于单个标记,.decode_single_token_bytes() 方法可以安全地将单个整数标记转换为它表示的字节。对于我们的示例提示:
[encoding.decode_single_token_bytes(token) for token in encoded_prompt]
# Output: [b't', b'ik', b'token', b' is', b' great', b'!']
请注意,字符串前的b表示这些字符串是字节字符串。对于英语,平均每个标记大约对应四个字符或约四分之三的单词。
了解文本是如何被分割成标记的很有用,因为 GPT 模型以标记的形式查看文本。知道文本字符串中有多少个标记可以提供有用的信息,比如字符串是否对文本模型来说过长,或者 OpenAI API 调用的成本是多少,因为使用是按标记计价的,等等。
比较编码模型

自制 gif。
正如我们所见,不同模型使用不同的编码类型。有时,不同模型之间的标记管理差异可能非常大。
不同的编码在分割单词、分组空格和处理非英语字符方面有所不同。使用上述方法,我们可以对几种示例字符串比较不同gpt模型的编码。
让我们比较一下上表中的编码(gpt2、p50k_base 和 cl100k_base)。为此,我们可以使用以下包含我们迄今所见所有要点的函数:
compare_encodings 函数以 example_string 作为输入,比较该字符串使用三种不同编码方案的编码:gpt2、p50k_base 和 cl100k_base。最后,它打印关于编码的各种信息,包括标记数、标记整数和标记字节。
让我们尝试一些例子!

在这个第一个例子中,尽管 gpt2 和 p50k_base 模型在将数学符号与空格合并在一起的编码上达成一致,但 cl100k_base 编码则认为它们是分开的实体。

在这个例子中,单词 Prompting 的标记化方式也取决于所选择的编码。
标记器限制

自制 gif。
这种对输入提示进行标记化的方法有时会导致一些 ChatGPT 完成错误。例如,如果我们让 ChatGPT 将单词 lollipop 反向书写,它会做错!

自制截图。
这里发生的情况是标记器实际上将给定的单词分解为三个标记:“l”、“oll” 和 “ipop”。因此,ChatGPT 无法看到单独的字母,而是看到这三个标记,这使得正确反向打印单独字母更加困难。
了解这些限制可以帮助你找到避免这些问题的解决方法。在这种情况下,如果我们在单词的字母之间添加短横线,我们可以迫使标记器根据这些符号拆分文本。通过稍微修改输入提示,它实际上做得更好:

自制截图。
通过使用短横线,模型更容易看到单独的字母并将其反向打印。所以请记住:如果你想让 ChatGPT 玩字词游戏,如 Word 或 Scrabble,或者围绕这些原则构建应用程序,这个巧妙的技巧可以帮助它更好地识别单独的字母。
这只是一个简单的例子,其中 ChatGPT 标记器导致模型在非常简单的任务中失败。你遇到过其他类似的情况吗?
总结
在这篇文章中,我们探讨了ChatGPT 如何查看用户提示并根据从大量语言数据中学习的统计模式处理这些提示以生成完成输出。
通过使用 tiktoken 库,我们现在可以在将提示输入 ChatGPT 之前评估任何提示。这可以帮助我们调试 ChatGPT 错误,因为通过稍微修改提示,我们可以让 ChatGPT 更好地完成任务。
还有一个额外的信息:一些设计决策可能会在未来转变为技术债务。正如我们在简单的棒棒糖示例中看到的,尽管模型在令人惊叹的任务中表现出色,但它不能完成简单的练习。这背后的原因不在于模型的能力,而在于最初的分词步骤!
就这些了!非常感谢阅读!
我希望这篇文章能帮助你在构建 ChatGPT 应用程序时!
你也可以订阅我的新闻通讯以获取最新内容。特别是,如果你对 ChatGPT 相关文章感兴趣:
如何提示 ChatGPT 以获得高质量的总结
towardsdatascience.com [## OpenAI 提示工程课程 — 使用 ChatGPT 进行推断、转换和扩展
最大化 ChatGPT 在你自定义应用程序中的潜力
medium.com](https://medium.com/geekculture/prompt-engineering-course-openai-inferring-transforming-expanding-chatgpt-chatgpt4-e5f63132f422?source=post_page-----27f78906ea54--------------------------------) ## 解锁 ChatGPT 的新维度:文本到语音集成
提升 ChatGPT 交互中的用户体验
[towardsdatascience.com
随时向forcodesake.hello@gmail.com提出任何问题 😃
ChatGPT 的电力消耗
原文:
towardsdatascience.com/chatgpts-electricity-consumption-7873483feac4
观点
ChatGPT 在 2023 年 1 月的电力消耗可能相当于 175,000 人的用电量。
·发布于Towards Data Science ·阅读时长 7 分钟·2023 年 3 月 1 日
--

我最近写了一篇文章,其中我估算了 ChatGPT 的每日碳足迹约为 24 kgCO2e。由于当时关于 ChatGPT 用户群体的信息很少,我的估算是基于 ChatGPT 服务运行在 16 个 A100 GPU 上的假设。根据最近的报告估算,ChatGPT 在 1 月份有 5.9 亿次访问[1],这很可能意味着 ChatGPT 需要更多的 GPU 来服务其用户。
[## 使用我的推荐链接加入 Medium - Kasper Groes Albin Ludvigsen
作为 Medium 会员,您的一部分会员费将用于支持您阅读的作者,您可以全面访问每个故事…
kaspergroesludvigsen.medium.com
由此也自然推测,ChatGPT 可能部署在多个地理位置。这使得估算 ChatGPT 的总每日碳足迹非常困难,因为我们需要确切知道每个地区运行了多少个 GPU,以将每个地区的电力碳强度纳入碳足迹估算中。
另一方面,估算 ChatGPT 的电力消耗原则上更简单,因为我们无需知道 ChatGPT 运行的地理区域。下面我将解释如何估算 ChatGPT 的能源消耗,并特别给出对 ChatGPT 在 2023 年 1 月的电力使用的估算。范围限于 2023 年 1 月,因为我们有该月份的 ChatGPT 流量估算。
ChatGPT 的成本估算表明,ChatGPT 的月电力消耗在 1.1M 到 23M KWh 之间
kaspergroesludvigsen.medium.com](https://kaspergroesludvigsen.medium.com/chatgpts-electricity-consumption-pt-ii-225e7e43f22b?source=post_page-----7873483feac4--------------------------------) ## ChatGPT 的碳足迹
本文试图估算名为 ChatGPT 的流行 OpenAI 聊天机器人所产生的碳足迹
towardsdatascience.com
估算 ChatGPT 的电力消耗
下面是如何估算 ChatGPT 的电力消耗:
-
估算 ChatGPT 每次查询的电力消耗
-
估算给定时间段内 ChatGPT 的总查询数
-
将这两个数值相乘
为了计算 ChatGPT 电力消耗可能的范围,我将为每个第 1 和第 2 点定义 3 个不同的值。让我们首先来看第 1 点。
步骤 1:ChatGPT 每次查询的电力消耗
BLOOM 是一个与 ChatGPT 基础语言模型 GPT-3 大小相似的语言模型。
BLOOM 曾在 18 天的时间里消耗了 914 KWh 的电力,这期间它在 16 台 Nvidia A100 40 GB GPU 上运行,每小时处理平均 558 个请求,总共 230,768 个查询。查询未批处理。914 KWh 计算了 CPU、RAM 和 GPU 的使用[2]。
因此,BLOOM 在该期间的电力消耗为每次查询 0.00396 KWh。
让我们考虑 BLOOM 的电力使用可能作为 ChatGPT 电力使用的良好代理。
首先,BLOOM 有 176b 个参数,而 GPT-3(ChatGPT 的基础模型)有 175b 个参数,因此在其他条件相同的情况下,它们的能耗应相当相似。
现在,有几个因素表明 ChatGPT 每次查询的电力消耗可能低于 BLOOM。例如,ChatGPT 每小时接收到的请求远多于 BLOOM,这可能导致每次查询的能耗降低。下图(图 1)绘制了 BLOOM 的电力使用情况与请求数量的关系,我们可以看到,即使在没有请求的情况下,BLOOM 的空闲电力消耗为 0.3 KWh,而在 1000 个请求时,其电力消耗也只是 200 个请求时的 5 倍不到。

图 1: “GCP 实例使用的能源数量(在 y 轴)与实例在 10 分钟间隔内接收到的请求数量(在 x 轴)之间的关系。可以看到,即使在这一时间段内实例接收到零请求(图的左下方),能源消耗仍约为 0.28 kWh。”(图和文字由 [2] 第 6 页提供)
尽管测量 BLOOM 电力消耗的论文作者表示“由于推理请求是不可预测的,因此使用批处理和填充等技术优化 GPU 内存 […] 是不可能的”([2] 第 5 页),我们可能预期 ChatGPT 会经历一个足够高的请求速率,使得批处理成为可行。我猜测,查询的批处理会导致每次查询的能耗降低。
一个可能增加 ChatGPT 每次查询能耗的因素是 ChatGPT 平均生成的词数更多,但我们没有任何信息可以让我们对这一点进行估算。
为了谨慎起见,我将在以下估算中假设 ChatGPT 的电力消耗最多与 BLOOM 相同,因此我将根据以下假设对 ChatGPT 的电力消耗进行估算:
-
ChatGPT 每次查询的电力消耗与 BLOOM 相同,即 0.00396 KWh
-
ChatGPT 每次查询的电力消耗是 BLOOM 的 75%,即 0.00297 KWh
-
ChatGPT 每次查询的电力消耗是 BLOOM 的 50%,即 0.00198 KWh
有两种简单的方法来估算机器学习模型的碳足迹,以及 17 个减少碳足迹的想法
第 2 步:发送到 ChatGPT 的查询次数
现在我们已经有了一些关于 ChatGPT 每次查询的电力消耗的估算,我们需要估算 ChatGPT 在一月份接收到的查询次数。
估计 ChatGPT 在 2023 年 1 月有 590 万次访问 [1]。每次访问产生了多少查询?我们不知道,但我们可以做一些假设。我预计每次访问至少会产生 1 个查询,这是最低假设。然后,我将做一个每次访问 5 次和 10 次查询的场景。从一些人的经验来看,他们在日常任务中广泛使用 ChatGPT,所以我认为这些数字是合理的。
-
每次访问的查询数 = 1,因此总查询数 = 590 百万
-
每次访问的查询数 = 5,因此总查询数 = 29.5 亿
-
每次访问的查询数 = 10,因此总查询数 = 59 亿
# 3:将电力消耗估算与查询次数结合
上述内容概述了有关 ChatGPT 每次查询电力使用的三种不同假设和有关总查询次数的三种不同假设。
这给出了 9 个独特的假设集,我将其称为“场景”。
下表显示了每个场景的估算电力消耗。
表格显示,ChatGPT 在 2023 年 1 月的电力消耗估计在 1,168,200 KWh 和 23,364,000 KWh 之间。
这是一个在线电子表格,如果你想检查我的计算。

表 1:ChatGPT 在 2023 年 1 月在不同能源效率和总查询场景下的估算电力消耗。表格由 Kasper Groes Albin Ludvigsen 编制,基于 Luccioni 等人,2022 [2]
ChatGPT 可能使用的电力超过了 175,000 名丹麦人
让我们将表 1 中的数字放在背景中来看。
平均丹麦人的年电力消耗为 1,600 KWh [3]。
如果将年度消耗平均分布在全年,那么每位丹麦人在一月份大约使用了 133.33 KWh 的电力。
因此,当我们将 ChatGPT 估算的电力消耗除以 133.3 时,可以看到 ChatGPT 的估算消耗相当于 8,762 至 175,234 名丹麦人每月的电力消耗。
从另一个角度来看,你可以用 1,168,200 KWh 的电力(1,168,200 KWh / (7 / 1000) / 24 / 365)运行一个标准的 7w 灯泡约 19,050 年。
训练 ChatGPT 的基础语言模型 GPT-3 估计消耗了 1,287 MWh [4]。
因此,在 2023 年 1 月 ChatGPT 电力消耗的估算范围的低端,训练 GPT-3 和运行 ChatGPT 一个月所需的能量大致相同。在范围的高端,运行 ChatGPT 所需的能量是训练所需的 18 倍。
结论
虽然处理单个查询需要的资源微不足道,但当像 ChatGPT 这样的产品扩展时,所需的资源(例如电力)会累计成显著的数量。搜索引擎 Bing 现在已经整合了 ChatGPT [5],这可能会显著增加 ChatGPT 需要处理的查询数量,从而也增加电力消耗。在我看来,这需要讨论是否将大型机器学习模型整合到现有产品中,能够将用户体验提升到一个值得大规模能源消耗的水平。
此外,除非这里提出的估计值过高,否则本文显示,运行一个基于机器学习的产品在生产中可能比训练模型要消耗更多的能源——特别是当产品有较长的生命周期和/或大量用户时。在讨论人工智能的环境影响以及思考如何减少机器学习模型的生命周期碳足迹时,应考虑到这一点。
我必须强调的是,本文提出的估计值存在很大不确定性,因为关于 ChatGPT 如何运行的信息公开得非常有限。在最理想的情况下,OpenAI 和 Microsoft 以及其他科技公司会公开其产品的能耗和碳足迹。但在我们等待这一点发生的同时,我希望我的估计能够引发讨论,并激励他人提出更好的估计。
附注:ChatGPT 电力消耗的 saga 继续在 这里。
就这些!希望你喜欢这篇文章 🤞
我很想听听你的意见:
-
你对分析的看法以及
-
你是否认为 ChatGPT 的电力消耗是值得的。
关注以获取更多与可持续数据科学相关的帖子。我还会写关于时间序列预测的内容,例如 这里 或 这里。
同样,确保查看 丹麦数据科学社区 的 可持续数据科学 指南,以获取更多关于可持续数据科学和机器学习环境影响的资源。
也可以随时在 LinkedIn 上与我联系。
参考文献
[1] www.theguardian.com/technology/2023/feb/02/chatgpt-100-million-users-open-ai-fastest-growing-app
[2] arxiv.org/pdf/2211.02001.pdf
[4] arxiv.org/ftp/arxiv/papers/2204/2204.05149.pdf
[5] www.digitaltrends.com/computing/how-to-use-microsoft-chatgpt-bing-edge/
ChatGPT 每次查询的能耗
原文:
towardsdatascience.com/chatgpts-energy-use-per-query-9383b8654487
ChatGPT 回答一个问题需要多少电力?
·发布于 Towards Data Science ·阅读时间 8 分钟·2023 年 8 月 6 日
--

摄影:Andrey Metelev,来自 Unsplash
本文展示了 ChatGPT 每次查询的电力消耗可能的范围,并与另外两个大型语言模型(LLMs)的测量能耗进行了比较。
这是一个有趣的课题,原因有二:
首先,如果组织知道 ChatGPT 回答一个问题需要多少电力,他们可以估算使用 ChatGPT 或类似服务(如 OpenAI 的 LLM API)的碳足迹。
对于超过 50,000 家欧洲企业来说,这可能很快变得非常相关,因为即将出台的《公司社会责任指令》(CSRD)可能迫使它们在管理报告中披露第 3 类排放 [1]。我预计像 ChatGPT 这样的服务的使用将被归入第 3 类,因为云计算被视为第 3 类 [2]。我希望这篇文章能为如何估算您组织的 ChatGPT 和类似服务的第 3 类排放提供一些启示。
另一个原因是,了解 ChatGPT 每次查询的能耗很有趣,因为这将使个人能够自行估算 ChatGPT 的总电力消耗或碳足迹。因此,我希望这篇博客文章能激励其他人发表类似的研究。
在本文的其余部分,“查询”和“请求”这两个术语将被交替使用。
[## 通过我的推荐链接加入 Medium - Kasper Groes Albin Ludvigsen
作为 Medium 会员,您的一部分会费将用于支持您阅读的作者,同时您将获得对所有故事的完全访问权限……
kaspergroesludvigsen.medium.com](https://kaspergroesludvigsen.medium.com/membership?source=post_page-----9383b8654487--------------------------------)
估算 ChatGPT 每次查询的电力消耗的方法
在这一部分,我将介绍用于估算 ChatGPT 每次查询电力消耗的方法。估算依赖于两种不同的方法:
-
一种是估算 ChatGPT 运行所需的硬件总能耗,并除以假定的每日查询量
-
另一种方法是使用 GPT-4 进行一次前向传递所需的浮点运算(FLOPS)来计算能耗
使用方法#1 估算 ChatGPT 每次请求的电力消耗
以下是通过方法 1 估算 ChatGPT 电力消耗的公式。这一公式是机器学习模型能耗估算的标准方法(参见例如 [3] [4]):
每次查询的能耗 (KWh) = (硬件总量 * 平均硬件功耗 * TDP * 24 * PUE) / 每日总查询量
对于方法 1,我将假设以下值:
硬件总量: 根据 SemiAnalysis 对服务 ChatGPT 用户所需计算量的估算,ChatGPT 使用了 3,617 台 Nvidia HGX 服务器,包含总共 28,936 个 Nvidia A100 GPU,这些估算基于 GPT-3.5 LLM 时的服务情况 [5]。
平均硬件功耗: 50%至 75%。
TDP(热设计功耗):TDP 是硬件的最大理论功耗,但通常用作实际最大硬件功耗的替代值。我将假设 TDP 为 6.5 kW,因为 Nvidia DGX A100 服务器的 TDP 与 Nvidia HGX A100 服务器相似。
功耗效率(PUE): 我将假设 ChatGPT 部署在微软的数据中心。微软的平均 PUE 为 1.18 [6]。
每日查询量: 硬件总量的估算假设 13 百万活跃用户,每人每天 15 次查询,即每日 195 百万查询 [5]。这似乎是一个合理的估算,我预计它会随着用户数量的增加而线性增长,因此如果每日查询量翻倍,硬件数量也会翻倍。在这些假设下,为了计算 ChatGPT 每次查询的能耗,ChatGPT 当前是否有更多或更少的每日查询并不重要,只要 195,000,000 次每日查询需要 3,617 台 Nvidia HGX A100 服务器的估算是合理的。
所以计算公式变成了:
每次查询的能耗 (KWh) = (3617 * 6.5 * 平均硬件功耗 * 24 * 1.18) / 195,000,000
其中,平均硬件功耗为 0.5 或 0.75。
如果数十亿人开始每天使用生成式人工智能技术,我们的环境会发生什么?
[towardsdatascience.com
使用方法 2 估算 ChatGPT 每次请求的电力消耗
方法 2 使用以下公式(改编自 Mike Ellis 的方法 [7]):
每次查询的能量消耗 (KWh) = (每次查询的 FLOPS * 每 FLOP 的焦耳数 * PUE) / 3600000
我除以 3,600,000 来将焦耳转换为 KWh。
对于方法 2,我将假设以下值:
每次查询的 FLOPS: 560,000,000,000,000 (560 Teraflops) [8]。
每 FLOP 的焦耳数: 0.00000000001 [7]。
PUE: 1.18(与方法 1 相同)。
因此,计算公式变为:
每次查询的能量消耗 (KWh) = (560000000000000 * 0.00000000001 * 1.18) / 3600000
现在,让我们应用这些方法来估算 ChatGPT 每次查询的能量使用。
最近泄露的数据首次允许我们估算训练 OpenAI 的 GPT-4 所产生的碳排放。
[towardsdatascience.com
ChatGPT 每次请求的电力消耗估算
在这一部分,我将展示 ChatGPT 处理一次请求所需的能量估算值,以及另外两种大型语言模型 BLOOM 和 GPT-J 的每次查询的实际能量消耗,如其他地方所报道 [9][10]。
在下表 1 中,我们可以看到方法 1 和方法 2 对 ChatGPT 每次查询的能量消耗估算值相似。使用方法 1,ChatGPT 每次查询的估算能量消耗为 0.0017 KWh(低估值范围,平均功耗为 TDP 的 50%)和 0.0026 KWh(高估值范围,平均功耗为 TDP 的 75%)。使用方法 2,ChatGPT 每次请求的估算能量使用为 0.0018 KWh。

表 1: ChatGPT 每次查询的估算能量使用
让我们看看 ChatGPT 每次查询的电力消耗估算值与其他大型语言模型的电力消耗比较。下表 2 显示了 GPT-J 和 BLOOM 每次查询的能量消耗。GPT-J 的每次查询电力消耗为 0.196 KWh,而 BLOOM 的每次查询电力消耗为 0.0039 KWh。请注意,BLOOM 每次查询的电力消耗与 ChatGPT 的估算电力消耗差别不大,而 GPT-J 的能量使用明显更高。

表 2: 将 ChatGPT 每次查询的估算电力消耗与 BLOOM 和 GPT-J 的每次查询实际电力消耗进行比较,如 [9][10] 中所述
两种方法轻松估算机器学习模型的碳足迹,并提供了 17 个减少碳足迹的想法
讨论
上述内容中,我们看到 ChatGPT 每次查询的电力消耗可能在 0.0017 到 0.0026 KWh 之间,具体取决于我们使用的假设。我们可以看到,两种方法估算的能量使用量相似,这对我来说表明估算是合理的。方法 2 的估算稍低于方法 1 的上限范围,可能是因为方法 2 仅考虑了 GPU 的能量消耗。然而,其他硬件(CPU、RAM、网络设备)也消耗能量。
如果我们相信 SemiAnalysis [5] 的估算是正确的,那么——从我的角度来看——ChatGPT 电力消耗估算中的主要不确定来源是每个硬件的电力使用量。在这篇文章中,我假设每个硬件使用其最大功耗的 50–75%,我个人认为这是合理的,但请质疑这一点。
还需注意,估算的数量是 ChatGPT 查询的平均电力消耗,因为我们使用了微软的平均 PUE。然而,PUE 在不同的数据中心可能有所不同。如果你想估算使用 ChatGPT 或类似服务的你的电力消耗,应该使用你预期处理请求的数据中心的 PUE。MLCO2 Impact 计算器可以向你展示微软 Azure 地区的列表[11]。
这篇博客文章估计了 ChatGPT 每次请求的能量使用量小于名为 BLOOM 的 LLM 的测量能量消耗。BLOOM 与 GPT-3 类似——ChatGPT 最初基于的 LLM。ChatGPT 比 BLOOM 更节能是有道理的,因为 BLOOM 论文的作者没有采取任何措施来提高请求处理的能效。此外,考虑到这可能减少他们的成本,合理预期 OpenAI 会进行这种努力。
结果还显示,GPT-J 每次查询的电力消耗远高于其他模型,尽管 GPT-J 有 60 亿个参数,远小于 BLOOM(1760 亿)、GPT-3(1750 亿)和 GPT-4(传闻 1800 亿)。这可能的解释是,在报告的实验中,GPT-J 运行在可能不如 ChatGPT 运行的 Nvidia A100 能效高的 Nvidia RTX3090 GPU 上。此外,实验中没有采取任何措施来提高 GPT-J 的能效。
在 AI 堆栈交换[7]中,Mike Ellis 使用方法 2 计算了 ChatGPT 的能耗,得出每次查询 0.000083 KWh。他在计算中使用了 30 teraflops,而我在本文中使用了 560 teraflops。Mike Ellis 使用 30 teraflops 是因为 ChatGPT 自己说它使用 30 teraflops。但正如 Mike Ellis 自己也指出的那样,我们应该对 ChatGPT 的回答持谨慎态度,因为它被认为会产生幻觉并提供虚假信息[12]。使用 560 teraflops,如我在这里所做的,得出的能耗更接近 BLOOM 的测量能耗。再加上 560 teraflops 的数据来自 SemiAnalysis,因此我认为每查询 560 teraflops 更为现实。
最后,为了将 ChatGPT 每次请求的估计能耗放入一个大的背景中:如果你打开一个标准的 40W 灯泡 1 小时,它将消耗的能量与 15 到 24 次 ChatGPT 查询的能量相当,这也是我的估算。
结论
这篇文章估算了 ChatGPT 回答一个查询可能使用 0.0017 到 0.0026 KWh 的电量。这些数据可以帮助组织估算使用 ChatGPT 及类似服务的碳足迹。两种不同的方法得出了相近的结果。这些估算值低于 LLMs BLOOM 和 GPT-J 的测量能耗,前者为 0.0039,后者为 0.196 KWh 每个查询。根据这些 ChatGPT 每查询的能耗估算值,如果你打开一个标准的 40W 灯泡 1 小时,它将消耗的能量相当于 15 到 24 次 ChatGPT 查询的能量。
就这些了!希望你喜欢这个故事。告诉我你的想法吧!
通过注册 Medium 会员HERE,你可以享受 Medium 的福利并支持我的写作。
关注我,获取更多关于 AI 和可持续性的内容,并且订阅,当我发布新故事时通过电子邮件接收我的更新。
我有时也会写关于时间序列预测的内容。
也可以在LinkedIn上与我联系。
参考文献
[1] normative.io/insight/csrd-explained/
[3] arxiv.org/pdf/2307.09288.pdf
[4] arxiv.org/ftp/arxiv/papers/2204/2204.05149.pdf
[5] www.semianalysis.com/p/the-inference-cost-of-search-disruption
[8] archive.md/2RQ8X
[10] borsen.dk/nyheder/baeredygtig/de-har-regnet-paa-chat-gpts-klimaaftryk-nu-raader-de-folk-til-at-taenke-sig-rigtig-godt-om?b_source=topchef-i-sydbank-krigen-i-ukraine-minder-om-finanskrisen&b_medium=row_8&b_campaign=news_2 和 BLOOM arxiv.org/abs/2211.02001
[12] fortune.com/2023/08/01/can-ai-chatgpt-hallucinations-be-fixed-experts-doubt-altman-openai/
国际象棋识别问题:深入解决方案
·
关注 发表在 Towards Data Science ·18 min read·2023 年 2 月 27 日
--
图片由 Randy Fath 提供,来自 Unsplash
1. 引言
从物理棋盘图像中识别棋子配置的问题通常被称为 国际象棋识别。计算机识别棋盘上的棋子是开发能够下棋的智能系统的第一步,这样的系统可以解决国际象棋问题/难题,并进行国际象棋分析。
我的项目目标是识别棋子及其在棋盘上的位置,这可以用像Forsyth–Edwards 记谱法 (FEN)这样的结构化格式描述,兼容各种国际象棋引擎。我还添加了一层额外的解释,输入生成的 FEN 后,输出是否存在潜在攻击(将军),并检测非法的棋盘位置。
2. 数据集概述
数据集包含100000张随机生成的棋盘位置图像,棋子数为5–15个(2个国王和3–13个兵/棋子)。所有图像的尺寸为400 x 400 像素。
棋子的生成概率分布如下:
-
30% 是兵。
-
20% 是主教。
-
20% 是骑士。
-
20% 是车。
-
10% 是皇后。
-
2 个国王保证在棋盘上。
标签以 FEN 格式的文件名存在,但用连字符代替了斜杠。
数据集属于公共领域。请检查数据集来源的引用[1]。
2.1. Forsyth–Edwards 记谱法 (FEN)
Forsyth–Edwards 记谱法 (FEN) 是描述国际象棋游戏中某一特定棋盘位置的标准记谱法。FEN 的目的是提供所有必要的信息,以便从特定位置重新开始游戏。
FEN 记录定义了一个特定的游戏位置,所有信息在一行文本中,并使用 ASCII 字符集[2]。
FEN 表示6个字段:
-
棋子摆放数据
-
活跃颜色
-
王车易位
-
过路兵
-
半步钟
-
全步钟
注意:由于数据集包含静态图像,我可以生成仅包含棋子摆放数据的 FEN。
3. 探索性数据分析
这是理解和调查数据集模式的关键阶段。
作者提供的代码 — 导入库。
在上述代码片段中,我导入了chess_positions模块。我为两个主要原因开发和完善了此模块 —
-
EDA
-
FEN 的解释。
作者提供的代码 — 训练和测试数据集
3.1. 数据集架构
训练数据集包含80000张图像,测试数据集包含20000张图像。
作者提供的代码 — 数据集的架构
3.2. 检查重复项
以下代码片段显示了所有标签(包括图像)都是唯一的。
作者提供的代码 — 检查重复项
3.3. 棋子分布
在国际象棋中,有 6 种不同的棋子(按颜色计算,共有12种不同的棋子)。
-
K — 白色国王,k — 黑色国王。
-
Q — 白色皇后,q — 黑色皇后。
-
B — 白色主教,b — 黑色主教。
-
N — 白色骑士,n — 黑色骑士。
-
R — 白色车,r — 黑色车。
-
P — 白色兵,p — 黑色兵。
3.3.1. 训练数据集中棋子分布

作者提供的图像 — 训练数据集中各类棋子总数
3.3.2. 测试数据集中棋子分布

作者提供的图像 — 测试数据集中各类棋子总数
上述棋子分布图的结论:
-
无论是训练集还是测试集,棋盘上都没有总共8个兵(白色和黑色)。
-
每个棋盘上只有一个合法的国王(白色和黑色)。
-
后是训练集和测试集中使用最少的棋子。
3.4. 密度图
密度图是对一组点的概率密度函数(PDF)的视觉表示。PDF 主要展示数据的分布情况。
3.4.1. 训练数据集的 PDF

作者提供的图像 — 训练数据集中的棋子 PDF 图
3.4.2. 测试数据集的 PDF

作者提供的图像 — 测试数据集中的棋子 PDF 图
上述密度图的结论:
-
所有棋盘上只有1个黑棋国王和1个白棋国王。
-
所有棋子的分布几乎相同。存在更多棋盘上有0个兵、0个车、0个马、0个象或0个后。
3.5. 总棋子数与棋盘数量
如前所述,训练数据集包含80000个棋盘,而测试数据集包含20000个棋盘。

作者提供的图像 — 总棋子数与棋盘数量的直方图
上述棋盘上棋子分布图的结论:
-
最大的棋子数量为15。
-
最少的棋子数量为5。
-
大多数棋盘上填满了15个棋子。
3.6. 查找检查和棋子非法位置
我开发并完善了chess_positions模块,以检测图像中的检查和非法棋子位置。以下是模块中3个类的代码片段,这些类有助于检测检查和非法棋盘图像。
3.6.1. 定义棋盘
class Board(object):
"""
This class is defines the chessboard.
"""
def __init__(self, fen_label):
self.fen_label = re.sub(pattern=r'\d',
repl=lambda x: self.get_ones(char=x.group()),
string=fen_label)
self.fen_matrix = self.get_fen_matrix()
def get_ones(self, char):
"""
This method returns repetitive 1s based on input digit character.
"""
if char.isdigit():
return '1' * int(char)
def get_fen_matrix(self):
"""
This method constructs a FEN matrix.
"""
fen_matrix = np.array([list(row) for row in self.fen_label.split('/')])
return fen_matrix
def get_piece_positions(self, notation):
"""
This method returns the 2D index of the piece from FEN matrix.
"""
(i, j) = np.where(self.fen_matrix == notation)
try:
if i is not None and j is not None:
return i, j
except:
return None
3.6.2. 棋盘上的检查
class Check(Board):
"""
This class finds if there are any checks in the chessboard.
"""
def __init__(self, fen_label):
super().__init__(fen_label=fen_label)
def get_sub_matrix(self, ai, aj, di, dj):
"""
This method chops the chessboard to a sub-matrix.
"""
corners = np.array([(ai, aj), (di, aj), (ai, dj), (di, dj)])
min_i, max_i = min(corners[:, 0]), max(corners[:, 0])
min_j, max_j = min(corners[:, 1]), max(corners[:, 1])
sub_matrix = self.fen_matrix[min_i:max_i+1, min_j:max_j+1]
return sub_matrix, sub_matrix.shape
def get_straight_checks(self, ai, aj, di, dj, a, d):
"""
This method returns the checks along the straight path.
"""
checks = list()
for (i, j) in zip(ai, aj):
if di == i:
attack_path = self.fen_matrix[di]
elif dj == j:
attack_path = self.fen_matrix[:, dj]
else:
continue
a_ind = np.where(attack_path == a)[0]
d_ind = np.where(attack_path == d)[0][0]
for a_i_ in a_ind:
attack_path_ = attack_path[min(a_i_, d_ind): max(a_i_, d_ind)+1]
checks.append(np.where(attack_path_ != '1')[0])
checks = list(filter(lambda x: len(x) == 2, checks))
return checks
def get_diagonal_checks(self, ai, aj, di, dj, a):
"""
This method returns the checks along the diagonal path.
"""
checks = list()
for (i, j) in zip(ai, aj):
sub_mat, sub_shape = self.get_sub_matrix(ai=i, aj=j, di=di, dj=dj)
if sub_shape[0] == sub_shape[1]:
if a not in sub_mat.diagonal():
sub_mat = np.flipud(m=sub_mat)
checks.append(np.where(sub_mat.diagonal() != '1')[0])
else:
continue
checks = list(filter(lambda x: len(x) == 2, checks))
return checks
def get_knight_checks(self, ai, aj, di, dj):
"""
This method returns the checks along the L-shaped paths for knights.
"""
checks = list()
for (i, j) in zip(ai, aj):
attack_positions = [(i-2, j-1), (i-2, j+1),
(i-1, j-2), (i-1, j+2),
(i+1, j-2), (i+1, j+2),
(i+2, j-1), (i+2, j+1)]
if (di, dj) in attack_positions:
checks.append((i, j))
return checks
def get_pawn_checks(self, ai, aj, di, dj):
"""
This method returns the checks for pawns.
"""
checks = list()
for (i, j) in zip(ai, aj):
_, sub_shape = self.get_sub_matrix(ai=i, aj=j, di=di, dj=dj)
if sub_shape[0] == 2 and sub_shape[1] == 2:
checks.append((i, j))
else:
continue
return checks
def king_checks_king(self, attacker, defendant):
"""
This method checks if the king is being attacked by the other king.
This is unlikely, but I am just adding a validation rule.
"""
flag = False
di, dj = self.get_piece_positions(notation=defendant)
if len(di) == 1 and len(dj) == 1:
di, dj = di[0], dj[0]
else:
return flag
ai, aj = self.get_piece_positions(notation=attacker)
ai, aj = ai[0], aj[0]
attack_positions = [(di, dj-1), (di, dj+1),
(di-1, dj), (di+1, dj),
(di-1, dj+1), (di-1, dj-1),
(di+1, dj-1), (di+1, dj+1)]
if (ai, aj) in attack_positions:
flag = True
return flag
def rook_checks_king(self, attacker, defendant):
"""
This method checks if the king is being attacked by the rook.
"""
flag = False
di, dj = self.get_piece_positions(notation=defendant)
if len(di) == 1 and len(dj) == 1:
di, dj = di[0], dj[0]
else:
return flag
ai, aj = self.get_piece_positions(notation=attacker)
checks = self.get_straight_checks(
ai=ai, aj=aj, di=di, dj=dj, a=attacker, d=defendant)
if checks:
flag = True
return flag
def bishop_checks_king(self, attacker, defendant):
"""
This method checks if the king is being attacked by the bishop.
"""
flag = False
di, dj = self.get_piece_positions(notation=defendant)
if len(di) == 1 and len(dj) == 1:
di, dj = di[0], dj[0]
else:
return flag
ai, aj = self.get_piece_positions(notation=attacker)
checks = self.get_diagonal_checks(
ai=ai, aj=aj, di=di, dj=dj, a=attacker)
if checks:
flag = True
return flag
def knight_checks_king(self, attacker, defendant):
"""
This method checks if the king is being attacked by the knight.
"""
flag = False
di, dj = self.get_piece_positions(notation=defendant)
if len(di) == 1 and len(dj) == 1:
di, dj = di[0], dj[0]
else:
return flag
ai, aj = self.get_piece_positions(notation=attacker)
checks = self.get_knight_checks(ai=ai, aj=aj, di=di, dj=dj)
if checks:
flag = True
return flag
def queen_checks_king(self, attacker, defendant):
"""
This method checks if the king is being attacked by the queen.
"""
flag = False
di, dj = self.get_piece_positions(notation=defendant)
if len(di) == 1 and len(dj) == 1:
di, dj = di[0], dj[0]
else:
return flag
ai, aj = self.get_piece_positions(notation=attacker)
straight_checks = self.get_straight_checks(
ai=ai, aj=aj, di=di, dj=dj, a=attacker, d=defendant)
diagonal_checks = self.get_diagonal_checks(
ai=ai, aj=aj, di=di, dj=dj, a=attacker)
if straight_checks or diagonal_checks:
flag = True
return flag
def pawn_checks_king(self, attacker, defendant):
"""
This methos checks if the king is being attacked by the pawn.
Note: It is hard to determine from an image, which side of
the chessboard is black or is white.
Hence, this method assumes the pawn is attacking the king
if both the pieces are diagnolly aligned by 1 step.
"""
flag = False
di, dj = self.get_piece_positions(notation=defendant)
if len(di) == 1 and len(dj) == 1:
di, dj = di[0], dj[0]
else:
return flag
ai, aj = self.get_piece_positions(notation=attacker)
checks = self.get_pawn_checks(ai=ai, aj=aj, di=di, dj=dj)
if checks:
flag = True
return flag
3.6.2.1. 检查分布(仅限合法图像)
下面是训练数据集的检查分布。

作者提供的图像 — 检查训练数据集的分布
下面是测试数据集的检查分布。

作者提供的图像 — 测试数据集的检查分布
上述检查分布图的结论:
-
车攻击对方国王的频率高于其他棋子。
-
兵攻击对方国王的频率低于其他棋子。
3.6.3. 棋盘上的非法位置
class IllegalPosition(Check):
"""
This class finds if the pieces are illegally positioned in the chessboard.
"""
def __init__(self, fen_label):
super().__init__(fen_label=fen_label)
def are_kings_less(self):
"""
Rule on kings.
"""
k_c = self.fen_label.count('k')
K_c = self.fen_label.count('K')
return (k_c < 1 and K_c < 1) or (k_c < 1) or (K_c < 1)
def are_kings_more(self):
"""
Rule on kings.
"""
k_c = self.fen_label.count('k')
K_c = self.fen_label.count('K')
return (k_c > 1 and K_c > 1) or (k_c > 1) or (K_c > 1)
def are_queens_more(self):
"""
Rule on queens.
"""
q_c = self.fen_label.count('q')
Q_c = self.fen_label.count('Q')
return (q_c > 9 and Q_c > 9) or (q_c > 9) or (Q_c > 9)
def are_bishops_more(self):
"""
Rule on bishops.
"""
b_c = self.fen_label.count('b')
B_c = self.fen_label.count('B')
return (b_c > 10 and B_c > 10) or (b_c > 10) or (B_c > 10)
def are_knights_more(self):
"""
Rule on knights.
"""
n_c = self.fen_label.count('n')
N_c = self.fen_label.count('N')
return (n_c > 10 and N_c > 10) or (n_c > 10) or (N_c > 10)
def are_rooks_more(self):
"""
Rule on rooks.
"""
r_c = self.fen_label.count('r')
R_c = self.fen_label.count('R')
return (r_c > 10 and R_c > 10) or (r_c > 10) or (R_c > 10)
def are_pawns_more(self):
"""
Rule on pawns.
"""
p_c = self.fen_label.count('p')
P_c = self.fen_label.count('P')
return (p_c > 8 and P_c > 8) or (p_c > 8) or (P_c > 8)
def rule_1(self):
"""
This method checks the count of the kings and the pieces in the board.
1\. The count of white king and black king should always be 1.
2\. The count of white queen and/or black queen should not cross 9.
3\. The count of white bishop and/or black bishop should not cross 10.
4\. The count of white knight and/or black knight should not cross 10.
5\. The count of white rook and/or black rook should not cross 10.
6\. The count of while pawn and/or black pawn should not cross 8.
7\. The chessboard should never be empty.
"""
flag = False
if self.are_kings_less():
flag = True
elif self.are_kings_more():
flag = True
elif self.are_queens_more():
flag = True
elif self.are_bishops_more():
flag = True
elif self.are_knights_more():
flag = True
elif self.are_rooks_more():
flag = True
elif self.are_pawns_more():
flag = True
return flag
def rule_2(self):
"""
This method checks if the pawns are in the first and last row of the board.
1\. No pawn should be on the first row and/or on the last row.
The pawn that reaches the last row always gets promoted.
Hence no pawns on the last row.
"""
flag = False
fen_label_list = self.fen_label.split('/')
f_row, l_row = fen_label_list[0], fen_label_list[-1]
p_f_row = 'p' in f_row
p_l_row = 'p' in l_row
P_f_row = 'P' in f_row
P_l_row = 'P' in l_row
if (p_f_row and p_l_row) or p_f_row or p_l_row:
flag = True
elif (P_f_row and P_l_row) or P_f_row or P_l_row:
flag = True
return flag
def rule_3(self):
"""
This method checks if the king is attacking the other king.
1\. The king never checks the other king.
2\. The king can attack other enemy pieces except the enemy king.
"""
return self.king_checks_king(attacker='k', defendant='K')
def rule_4(self):
"""
This method checks if the kings are under check simultaneously.
1\. The two kings are never under check at the same time.
"""
r_checks_K = self.rook_checks_king(attacker='r', defendant='K')
n_checks_K = self.knight_checks_king(attacker='n', defendant='K')
b_checks_K = self.bishop_checks_king(attacker='b', defendant='K')
q_checks_K = self.queen_checks_king(attacker='q', defendant='K')
p_checks_K = self.pawn_checks_king(attacker='p', defendant='K')
R_checks_k = self.rook_checks_king(attacker='R', defendant='k')
N_checks_k = self.knight_checks_king(attacker='N', defendant='k')
B_checks_k = self.bishop_checks_king(attacker='B', defendant='k')
Q_checks_k = self.queen_checks_king(attacker='Q', defendant='k')
P_checks_k = self.pawn_checks_king(attacker='P', defendant='k')
is_K_checked = r_checks_K or n_checks_K or b_checks_K or q_checks_K or p_checks_K
is_k_checked = R_checks_k or N_checks_k or B_checks_k or Q_checks_k or P_checks_k
return is_K_checked and is_k_checked
def is_illegal(self):
"""
This method is a consolidation of all the above basic rules of chess.
"""
return self.rule_1() or self.rule_2() or self.rule_3() or self.rule_4()
3.6.3.1. 棋盘上的合法和非法位置
在使用上述类过滤训练和测试数据集后,我获得了以下结果。
训练数据集(80000个棋盘图像)。
-
合法训练棋盘图像的数量为67813(84.8%)。
-
非法训练棋盘图像的数量为12187(15.2%)。
测试数据集(20000个棋盘图像)。
-
合法测试棋盘图像的数量为17019(85.1%)。
-
非法测试棋盘图像的数量为2981(14.9%)。
3.6.3.2. 合法棋盘图像的样本图

作者提供的图片 — 合法棋类图像样本
3.6.3.3. 非法棋类图像的样本图

作者提供的图片 — 非法棋类图像样本
3.7. 所有图像的比例、高度和宽度

作者提供的图片 — 比例、高度和宽度
上述图像尺寸图的结论:
-
比例 = 高度 / 宽度 → 1。所有图像的比例均为1。
-
所有图像的宽度均为400像素。
-
所有图像的高度均为400像素。
4. 数据管道
我创建了一个数据管道类,用于将数据输入到模型(学习器)中。
在此之前,我想展示预处理步骤。在预处理过程中,我将棋类图像调整为50%,并将其划分为64个块(正方形)。调整数据集的主要优点是减少空间复杂性。RAM 不会超负荷。
预处理前的棋类图像:

作者提供的图片 — 调整大小和预处理前
预处理后的棋类图像:

作者提供的图片 — 调整大小和预处理后
我使用了 Pavel Koryakin(数据集作者)的独热编码逻辑来对标签进行编码。
以下是数据管道类。
class DataPipeline(object):
"""
This class is a data pipeline for deep learning model.
"""
def __init__(self, tr_images, tr_labels, cv_images, cv_labels, te_images, te_labels):
self.rows, self.cols = (8, 8)
self.square = None
self.h, self.w, self.c = None, None, None
self.N = 13
self.tr_images = np.array(tr_images)
self.tr_labels = np.array(tr_labels)
self.cv_images = np.array(cv_images)
self.cv_labels = np.array(cv_labels)
self.te_images = np.array(te_images)
self.te_labels = np.array(te_labels)
self.piece_symbols = 'prbnkqPRBNKQ'
def preprocess_input_image(self, imagefile, resize_scale=(200, 200)):
"""
This function preprocesses in the input image.
"""
img = cv.imread(filename=imagefile)
img = cv.resize(src=img, dsize=resize_scale)
self.h, self.w, self.c = img.shape
self.square = self.h // self.rows
img_blocks = view_as_blocks(
arr_in=img, block_shape=(self.square, self.square, self.c))
img_blocks = img_blocks.reshape(
self.rows * self.cols, self.square, self.square, self.c)
return img_blocks
def tr_data_generator(self):
"""
This method preprocess the input images.
"""
for i, l in zip(self.tr_images, self.tr_labels):
yield (self.preprocess_input_image(imagefile=i),
self.onehot_from_fen(fen=l))
def cv_data_generator(self):
"""
This method preprocess the input images.
"""
for i, l in zip(self.cv_images, self.cv_labels):
yield (self.preprocess_input_image(imagefile=i),
self.onehot_from_fen(fen=l))
def te_data_generator(self):
"""
This method preprocess the input targets.
"""
for i in self.te_images:
yield self.preprocess_input_image(imagefile=i)
def onehot_from_fen(self, fen):
"""
This method converts FEN to onehot.
The original author of this method is 'Pavel Koryakin'.
Pavel Koryakin is also the maintainer of Chess Positions dataset.
"""
eye = np.eye(N=self.N)
output = np.empty(shape=(0, self.N))
fen = re.sub(pattern='[/]', repl='', string=fen)
for char in fen:
if char in '12345678':
output = np.append(
arr=output,
values=np.tile(A=eye[self.N-1], reps=(int(char), 1)), axis=0
)
else:
idx = self.piece_symbols.index(char)
output = np.append(
arr=output,
values=eye[idx].reshape((1, self.N)), axis=0
)
return output
def fen_from_onehot(self, onehot):
"""
This method converts onehot to FEN.
The original author of this method is 'Pavel Koryakin'.
Pavel Koryakin is also the maintainer of Chess Positions dataset.
"""
output = str()
for j in range(self.rows):
for i in range(self.cols):
if onehot[j][i] == 12: # TensorFlow coded 12 for empty squares.
output += ' '
else:
output += self.piece_symbols[int(onehot[j][i])]
if j != self.rows - 1:
output += '/'
for i in range(self.rows, 0, -1):
output = output.replace(' ' * i, str(i))
return output
def construct_dataset(self):
"""
This method constructs the dataset.
"""
tr_dataset = tf.data.Dataset.from_generator(
generator=self.tr_data_generator, output_types=(tf.int64, tf.int64))
tr_dataset = tr_dataset.repeat()
cv_dataset = tf.data.Dataset.from_generator(
generator=self.cv_data_generator, output_types=(tf.int64, tf.int64))
cv_dataset = cv_dataset.repeat()
te_dataset = tf.data.Dataset.from_generator(
generator=self.te_data_generator, output_types=tf.int64)
te_dataset = te_dataset.repeat()
it_tr = tr_dataset.__iter__()
it_cv = cv_dataset.__iter__()
it_te = te_dataset.__iter__()
return it_tr, it_cv, it_te
使用上述类,我创建了训练、验证和测试数据集生成器,可以用于建模。
5. 建模
我创建了一个建模类,该类首先对模型进行调优,然后用数据集拟合模型。调优是使用 KerasTuner 进行的。调优模型所需时间为 20 小时,以获得最佳超参数。
5.1. 基础模型
以下是建模类。
class ChessModel(object):
"""
This class is for deep learning model for chess recognition problem.
"""
def __init__(self,
tr_dataset,
cv_dataset,
tr_size,
cv_size,
filepath_tuner,
filepath_fitter,
filepath_tracker):
self.tr_dataset = tr_dataset
self.cv_dataset = cv_dataset
self.input_shape = (25, 25, 3)
self.batch_size = 64
self.output_units = 13 # 12 for chess pieces and 1 for empty square.
self.tr_size = tr_size
self.cv_size = cv_size
self.filepath_tuner = filepath_tuner
self.filepath_fitter = filepath_fitter
self.filepath_tracker = filepath_tracker
def build_model(self, hp):
"""
This method builds the optimized model.
"""
hp_activations = hp.Choice(
name='activation', values=['relu', 'tanh', 'sigmoid'])
hp_filters_1 = hp.Int(
name='filter_1', min_value=32, max_value=64, step=10)
hp_filters_2 = hp.Int(
name='filter_2', min_value=32, max_value=64, step=10)
hp_kernel_1 = hp.Int(
name='Kernel_1', min_value=2, max_value=5, step=None)
hp_kernel_2 = hp.Int(
name='Kernel_2', min_value=2, max_value=5, step=None)
hp_units = hp.Int(
name='dense', min_value=32, max_value=64, step=10)
hp_learning_rate = hp.Choice(
name='learning_rate', values=[1e-2, 1e-3, 1e-4])
input_layer = Input(
shape=self.input_shape, batch_size=self.batch_size, name='Input')
conv_2d_layer_1 = Conv2D(
filters=hp_filters_1, kernel_size=hp_kernel_1,
activation=hp_activations, name='Conv2D_1')(input_layer)
conv_2d_layer_2 = Conv2D(
filters=hp_filters_2, kernel_size=hp_kernel_2,
activation=hp_activations, name='Conv2D_2')(conv_2d_layer_1)
flatten_layer = Flatten(name='Flatten')(conv_2d_layer_2)
dense_layer = Dense(
units=hp_units, activation=hp_activations, name='Dense')(flatten_layer)
output_layer = Dense(
units=self.output_units, activation='softmax', name='Output')(dense_layer)
model = Model(inputs=input_layer, outputs=output_layer, name='Chess_Model')
optimizer = tf.keras.optimizers.Adam(learning_rate=hp_learning_rate)
model.compile(
optimizer=optimizer, loss='categorical_crossentropy',
metrics=['accuracy'])
return model
def model_tuner(self):
"""
This method tunes the chess model.
"""
if not os.path.isfile(path=self.filepath_tuner):
print("Tuning the model.")
stop_early = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)
tuner = kt.Hyperband(
hypermodel=self.build_model, objective='val_accuracy', max_epochs=10)
tuner.search(
x=tr_dataset, epochs=50, steps_per_epoch=self.tr_size,
validation_data=cv_dataset, validation_steps=self.cv_size,
callbacks=[stop_early])
print("Tuning completed.")
best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]
model = tuner.hypermodel.build(best_hps)
tf.keras.models.save_model(model=model, filepath=self.filepath_tuner)
print("Saved the best model to the file.")
else:
print("Model is already tuned, and is also saved.")
model = tf.keras.models.load_model(filepath=self.filepath_tuner)
print("Loaded the tuned model and ready for fitting.")
return model
def model_fitter(self):
"""
This method fits the tuned model.
"""
model = self.model_tuner()
print()
model.summary()
print()
model_save_callback = ModelCheckpoint(
filepath=self.filepath_fitter, monitor='val_accuracy',
verbose=1, save_best_only=True, mode='auto')
callbacks = [model_save_callback]
if not os.path.isfile(path=self.filepath_fitter):
print("Fitting the model.")
epochs = 10
tracker = model.fit(
x=tr_dataset, validation_data=cv_dataset, epochs=epochs,
steps_per_epoch=len(tr_images), validation_steps=len(cv_images),
callbacks=callbacks)
print("\nSaved the fitted model.")
tracker_df = pd.DataFrame(data=tracker.history)
tracker_df.to_csv(path_or_buf=self.filepath_tracker, index=False)
print("Saved the history to the file.")
else:
print("Model is already fitted, and is also saved.")
model = tf.keras.models.load_model(filepath=self.filepath_fitter)
print("Loaded the fitted model and ready for prediction.")
tracker_df = pd.read_csv(filepath_or_buffer=self.filepath_tracker)
print()
plot_model_performance(tracker_df=tracker_df)
return model
5.2. 模型架构
Model: "Chess_Model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
Input (InputLayer) [(64, 25, 25, 3)] 0
Conv2D_1 (Conv2D) (64, 21, 21, 32) 2432
Conv2D_2 (Conv2D) (64, 19, 19, 62) 17918
Flatten (Flatten) (64, 22382) 0
Dense (Dense) (64, 42) 940086
Output (Dense) (64, 13) 559
=================================================================
Total params: 960,995
Trainable params: 960,995
Non-trainable params: 0
_________________________________________________________________
5.3. 模型性能 — 损失和准确度
在10个周期后拟合模型后,我获得了模型的性能 — 损失和准确度。

作者提供的图片 — 准确度和损失
以下是25张测试图像的混淆矩阵。

作者提供的图片 — 25 张测试图像的混淆矩阵
6. 数据产品的生产化
生产化是将本地模型从 Jupyter Notebook 环境暴露给外部世界的过程。在这里,我将建模阶段训练过的模型导出为文件。这个模型文件包含了已学习的参数,可以直接用于测试数据。
6.1. 数据产品管道
from chess_positions import IllegalPosition
from chess_positions import Check
from glob import glob
from skimage.util.shape import view_as_blocks
import cv2 as cv
import plotly.express as px
import random
import tensorflow as tf
import warnings
warnings.filterwarnings(action='ignore')
class Pipeline(object):
"""
This class is a pipeline mechanism to feed the
query chess image into the model for FEN prediction.
"""
def __init__(self, chess_image):
self.piece_symbols = "prbnkqPRBNKQ"
self.rows, self.cols = (8, 8)
self.square = None
self.h, self.w, self.c = None, None, None
self.chess_image = chess_image
self.chess_model = tf.keras.models.load_model(
filepath='chess_model.h5')
self.chess_image_display = self.display_image()
def display_image(self):
"""
This method reads the image and
gives plotly fig for the final display.
"""
image = cv.imread(filename=self.chess_image)
image = cv.cvtColor(src=image, code=cv.COLOR_BGR2RGB)
image_fig = px.imshow(img=image)
image_fig.update_layout(
coloraxis_showscale=False, autosize=True,
margin=dict(l=0, r=0, b=0, t=0))
image_fig.update_xaxes(showticklabels=False)
image_fig.update_yaxes(showticklabels=False)
return image_fig
def preprocess(self, resize_scale=(200, 200)):
"""
This method preprocesses the chess image.
"""
img = cv.imread(filename=self.chess_image)
img = cv.resize(src=img, dsize=resize_scale)
self.h, self.w, self.c = img.shape
self.square = self.h // self.rows
img_blocks = view_as_blocks(
arr_in=img, block_shape=(self.square, self.square, self.c))
img_blocks = img_blocks.reshape(
self.rows * self.cols, self.square, self.square, self.c)
return img_blocks
def fen_from_onehot(self, onehot):
"""
This method converts onehot to FEN.
The original author of this method is 'Pavel Koryakin'.
Pavel Koryakin is also the maintainer of Chess Positions dataset.
"""
output = str()
for j in range(self.rows):
for i in range(self.cols):
if onehot[j][i] == 12: # TensorFlow coded 12 for empty squares.
output += ' '
else:
output += self.piece_symbols[int(onehot[j][i])]
if j != self.rows - 1:
output += '/'
for i in range(self.rows, 0, -1):
output = output.replace(' ' * i, str(i))
return output
def predict(self):
"""
This method predicts the FEN of the query chess image.
"""
chess_image_blocks = self.preprocess()
onehot = self.chess_model.predict(x=chess_image_blocks)
onehot = onehot.argmax(axis=1).reshape(-1, 8, 8)[0]
fen_label = self.fen_from_onehot(onehot=onehot)
interpretation = self.illegal_interpreter(fen_label=fen_label)
if len(interpretation) > 0:
interpretation = f"This is an illegal chess position. Reason is {interpretation}"
else:
interpretation = self.check_interpreter(fen_label=fen_label)
fen_label = f"The Forsyth-Edwards Notation (FEN) of an uploaded chess image is {fen_label}."
interpretation = f"Further interpretation: {interpretation}"
return fen_label, interpretation
def illegal_interpreter(self, fen_label):
"""
This method interprets the predicted FEN.
"""
reason = str()
chess_illegal = IllegalPosition(fen_label=fen_label)
if chess_illegal.are_kings_less():
reason += "either white king, black king, or both are missing."
elif chess_illegal.are_kings_more():
reason += "either white king, black king, or both are more than 1."
elif chess_illegal.are_queens_more():
reason += "either white queen, black queen, or both are more than 9."
elif chess_illegal.are_bishops_more():
reason += "either white bishop, black bishop, or both are more than 10."
elif chess_illegal.are_knights_more():
reason += "either white knight, black knight, or both are more than 10."
elif chess_illegal.are_rooks_more():
reason += "either white rook, black rook, or both are more than 10."
elif chess_illegal.are_pawns_more():
reason += "either white pawn, black pawn, or both are more than 8."
elif chess_illegal.rule_2():
reason += "either white pawn, black pawn, or both are in first row and/or last row."
elif chess_illegal.rule_3():
reason += "the king checks the other the king."
elif chess_illegal.rule_4():
reason += "white king and black king are under attack simultaneously."
else:
reason += ""
return reason
def check_interpreter(self, fen_label):
"""
This method interprets the predicted FEN.
"""
reason = str()
chess_check = Check(fen_label=fen_label)
r_checks_K = chess_check.rook_checks_king(
attacker='r', defendant='K')
n_checks_K = chess_check.knight_checks_king(
attacker='n', defendant='K')
b_checks_K = chess_check.bishop_checks_king(
attacker='b', defendant='K')
q_checks_K = chess_check.queen_checks_king(
attacker='q', defendant='K')
p_checks_K = chess_check.pawn_checks_king(
attacker='p', defendant='K')
R_checks_k = chess_check.rook_checks_king(
attacker='R', defendant='k')
N_checks_k = chess_check.knight_checks_king(
attacker='N', defendant='k')
B_checks_k = chess_check.bishop_checks_king(
attacker='B', defendant='k')
Q_checks_k = chess_check.queen_checks_king(
attacker='Q', defendant='k')
P_checks_k = chess_check.pawn_checks_king(
attacker='P', defendant='k')
is_K_checked = r_checks_K or n_checks_K or b_checks_K or q_checks_K or p_checks_K
is_k_checked = R_checks_k or N_checks_k or B_checks_k or Q_checks_k or P_checks_k
if is_K_checked:
reason += "The white king is under attack."
elif is_k_checked:
reason += "The black king is under attack."
else:
reason += "Both kings are safe."
return reason
6.2. 数据产品演示
数据产品链接:huggingface.co/spaces/mohd-saifuddin/Chess-Recognition-2D
请注意,您需要测试图像来使用此数据产品。因此,我建议您从数据集源下载测试图像。
7. 学习成果
我在这个项目中获得的学习成果。
-
我学会了对棋类图像和 FEN 标签进行详细的 EDA。
-
我学会了数据预处理和使用 TensorFlow Data 模块。
-
我学会了使用 KerasTuner 进行超参数调整(仍然有许多概念需要学习)。
-
最终,我学会了开发数据产品并将其发布在 Streamlit 平台上。
8. 参考文献
[1] Pavel Koryakin, 国际象棋位置。在Kaggle。这里。
[2] Forsyth–Edwards 符号。在维基百科。这里。
9. 结束
感谢阅读。如果你有任何建议,请告诉我。
你可以在 LinkedIn 上与我联系:这里。
选择正确的数据库模型,让你的数据摆脱电子表格
原文:
towardsdatascience.com/choose-the-right-database-model-free-your-data-from-spreadsheets-8d1129626b42
你已经超越了 Excel:如何选择关系型、文档型或图数据库,并为未来做好准备
·发表于数据科学前沿 ·阅读时间 16 分钟·2023 年 5 月 8 日
--
打开一个空白的电子表格时,无尽的行和列似乎成为了无限可能性的画布。然而,许多数据工程师和其他数字知识工作者越来越将电子表格中交错的灰线视为限制——具有讽刺意味地像是监狱牢房的横竖金属栏杆!

诚然,为了使“电子表格监狱”的视觉隐喻生效,电子表格必须旋转 90°。但一旦你看到这种诡异的相似性,就不可能忘记。 照片插图由作者提供。 照片 由 WWW PROD 提供,来自 Unsplash。
“电子表格监狱”发生在决定将数据存储在电子表格中,使组织在高效实现目标时陷入瓶颈。这主要是因为电子表格的单元格(无意玩笑)在查询和管理方面受到限制:例如,今天单元格K18可能表示某个物品的数量,但如果明天添加了一行和一列,K18可能完全指代其他内容。
对于较大的项目,“电子表格困境”还发生在存储的数据量超过电子表格的最大大小时(Google Sheets 为 1000 万个单元格;Microsoft Excel 为 1,048,576 行和 16,384 列)。在一个极端的例子中,一份财务数据清单在一名员工发现其大小与一个满员的 Excel 电子表格完全相同时被发现是伪造的,这一欺骗行为现在使公司的耻辱首席执行官面临数十年的实际监禁。
对于那些未参与犯罪活动的我们来说,电子表格困境的风险不那么字面化,但我们不当使用电子表格往往是成本高昂且繁琐的。考虑以下问题:
-
打开一个曾被认为是可靠和高效的信息查找方式的电子表格是否现在变成了令人恐惧的源头?
-
您的组织是否有工作人员每天花费多个小时回答那些答案要么 (1) 深埋在一个电子表格中,要么 (2) 分散在多个电子表格中?
-
是否有重要的分析问题让您的数据透视表专家也感到困惑?类似地,您的单元格公式是否脆弱,频繁出现错误或需要多个手动步骤来计算?
-
重要信息是否被 relegated 到单元格值之外,例如作为电子表格中的“注释”或“评论”?
如果您对所有这些问题的回答都是“否”,祝贺您!当数据管理不会阻碍您实现目标时,您就不会陷入电子表格困境。没有必要用过于强大的数据库解决方案替换那些有效的方案。
然而,如果您对这些问题中的任何一个回答为“是”,那不幸的是,这表明您正陷入电子表格困境。好消息是,通过采用许多更适合您需求的数据库技术,您可以摆脱困境。
接下来,我将通过介绍三种数据库模型范式来帮助缩小选择范围,同时提供一个特别适合转换为该模型的电子表格示例。在查看这些描述时,您可能希望保留一个最具问题的电子表格副本,以尝试诊断导致您陷入电子表格困境的问题,并确定哪种数据库模型是最佳解药。
这些描述是针对一些具备技术知识的项目经理和高管编写的,旨在帮助他们更好地理解数据需求。数据专业人士也可以参考这篇文章,以找到倡导采用更强大数据库技术所需的语言。
关系型数据库
关系数据库由“表”组成,这些表在概念上类似于电子表格的二维网格,但每行数据通过“键”在表之间相互连接。“主键”是一个数据点,它表示其他表中的一个单独的行,而“外键”存储来自其他表的哪一行(或多行)与给定表中的行具有特定关系。
SQL —— 或称为“结构化查询语言” —— 常用于与存储在关系数据库中的数据进行交互,包括提问涉及多个表的数据。尽管一开始可能听起来令人畏惧,SQL 实际上相对容易学习且成本低廉(网上有许多免费的 SQL 课程,许多社区学院也提供 SQL 的入门课程)。
值得注意的是,自 1986 年 美国国家标准协会首次认可 SQL 的一种形式 以来,SQL 就成为了一个公共资源,这种形式可以在公司、政府机构和学术机构之间互换使用。如果你的目标是导入 开放数据 或导出内部数据用于学术研究,那么拥有一个可以通过 SQL 访问的关系模型数据库可以说是最安全的选择。
由于关系数据库和 SQL 的普及,确定最合适的解决方案可能会让人感到不知所措。像 Oracle 和 SQL Server 这样的最受欢迎的数据库通常对于只是从电子表格迁移且尚未运行完整应用程序的大多数组织来说,显得过于强大且成本过高。更合适的是包括 PostgreSQL 这样的免费和开源选项,它不在功能上打折扣,以及 SQLite 这样的轻量级选项,便于部署。
如果关系数据库模型很吸引人,但设置和学习 SQL 的投资成本过高,Airtable 是一个很好的折中方案。Airtable 的电子表格类似的图形界面不需要查询语言,但对于与从终端用户的角度来看,类似于关系数据库模型的功能进行交互效果很好。
电子表格诊断
如果你的数据目前存储在电子表格中,如果该电子表格有大量的“重复子行”,你应该考虑迁移到关系数据库。
这个电子表格跟踪了课后项目的参与者。注意到 Parent 是一个“重复子行”,特别是 Parent 2 单元格的值要么完全填充,要么完全为空。这使得电子表格的末端有一种“长条奶酪”的外观,行在末端呈现出磨损的效果。 示例数据由作者提供。
比如,你运营一个课后项目。你的电子表格中的每一行代表一个孩子的数据(姓名、年级、过敏信息等),并且你有子行代表每个孩子父母的数据(姓名、联系信息等)。多个孩子很可能有相同的父母,但如果你需要更新父母的联系信息,你必须确保捕捉到父母出现的每一子行。

这个简单的实体关系图,使用 乌鸦脚符号,展示了课后项目数据库中每个孩子都有一个姓名、年级、一份过敏清单(包括一个空列表)和一个或多个相关父母。每个父母都有姓名、电子邮件地址、电话号码和一个或多个相关学生。 图示由作者提供。
关系数据库通过将“子行”转换为一个引用存储在不同“父”表中的数据的“外键”单元来解决这个问题。如果你更新一个父表的信息,你会同时更新与该父表相关的每个子表的信息。由于不需要在重复数据上进行批量更新,使用关系数据模型时,数据完整性的维护变得容易得多。
课后项目的数据迁移到 Airtable 中。虽然从技术上讲这不是一个 SQL 数据库,但 Airtable 在保持熟悉的电子表格界面的同时,拥抱了关系数据范式。 示例数据由作者提供。
当然,大多数课后项目都是由志愿者和其他没有时间和技能来管理完整 SQL 数据库的人运营的。Airtable是这个使用场景的一个极好的无代码替代方案,额外的好处是它有内置工具如“表单”,允许引导数据输入。
面向文档的数据库
面向文档的数据库,或称为“文档存储”,非常适合处理结构松散的数据。如果关系数据库的范式是由“行”和“列”组成的“表”,那么文档导向数据库的范式是由“键”和“值”对组成的“文档”的“集合”。例如,Length (meters)的键可能与18的值配对,或者Locations的键可能与一个由["Portland", "Omaha", "San Antonio", "Savannah"]组成的列表配对。此外,值甚至可以是一个称为“子文档”的全新键值对集合。假设每个文档都有一个唯一的 ID,那么一个值也可以指向另一个文档,从而模拟 SQL 数据库的关系性。
对数字数据术语不熟悉的人可能会对“文档”一词感到困惑。我必须承认,当我第一次了解文档导向数据库时,我天真地认为这些系统是为了存储一堆 PDF 和 Word 文档(我后来发现这类技术的术语是“文档管理系统”)。然而,在这个上下文中,“文档”实际上指的是轻量级且流行的JSON格式,设计时考虑到人类和计算机都相对容易读取或写入。而关系数据库中的每一行必须符合表的模式(也就是“列”),而 JSON 文档在理论上可以具有任何所需的键值对。
由于这种灵活性,文档导向的数据库被认为是“适应性强”的数据库模型。随着时间的推移,当你发现有关所收集信息的更多数据时,你可以添加新的字段,而不会弄乱任何具有不同键值对的旧记录。此外,由于相同的键可以与不同的类型值关联,这种数据库模型非常适合数据的异质性。
利用这种异质性的一种实际方法是创建“占位符数据”,如一个文本字符串,稍后将扩展为子文档。它还使区分不同类型的空白数据变得更容易,例如通过(1)排除一个键来表示缺失数据,或者(2)将值设置为null或None,如果没有要观察的数据,例如完全在线运营的组织的地理坐标,或一个还未到驾驶年龄的小孩的驾照号码。
文档存储非常适合那些处于探索阶段的组织,特别是转向文档导向数据库是改善 HealthCare.gov 网站命运的一部分,使其能够不断适应来自多个联邦和州机构的不断变化的数据输入。不过,如果数据库的工程师缺乏规范性并且在添加字段时不小心,数据库随着时间的推移可能会变得越来越混乱和难以管理。
最受欢迎的文档导向数据库解决方案是 MongoDB。凭借几乎但不完全被认为是开源的许可证协议,他们的免费 512MB 云层和本地托管的社区版可以满足许多从电子表格过渡的用户的数据需求,而无需直接支出。此外,MongoDB 查询的基础(有时称为“MQL”)和其他数据导航工具如 Compass 可以在线免费学习。然而,与 SQL 不同,“MQL”实际上不是标准化的查询语言,而是一组特定语言的驱动程序和 API。这意味着虽然很难找到了解 MQL 的人才,但对于已经在特定编程语言(如 Python、JavaScript 等)中工作的人来说,学习 MQL 并实现 MongoDB 仍然很容易。
电子表格诊断
如果您的数据目前存储在电子表格中,且有大量单元格完全为空或数据过于拥挤(超过一半的电子表格是一个很好的经验法则),您应该考虑转向文档导向的数据库。您还需要注意那些仅用于输入少量数据的“临时列”或“杂项”备注列。
这个电子表格存储了本地社区目录的信息。注意到使用了|符号在单元格内创建了“迷你电子表格”。此外,电子表格中还有许多空单元格,使得整个电子表格呈现出一种“瑞士奶酪”般的外观。 示例数据由作者提供。
上述示例电子表格展示了一个假设的本地地点列表,这些地点除了有名称外,几乎没有共同之处。注意到有些地点的数据很多,比如 Joe’s Burgers,而有些地点如 Alcove Speakeasy 的数据稀少。许多单元格是空的,包括Restrictions列,该列仅包含一行数据。在Websites列下,不清楚单元格是否为空是因为数据丢失(如 Joe’s Burgers 的情况),还是因为该地点没有网站(如 Centennial Obelisk 的情况)。还有一些单元格包含大量嵌套数据,特别是在Photos、Prices和Scores列下。当然,还有一个Miscellaneous列,将三个潜在的列合并为一个,以防止电子表格变得过于庞大。

相同的本地社区目录以 JSON 文件的形式。这类似于数据在面向文档的数据库中存储的方式。 完整文件可以在此查看。作者提供的截图。
查看以 JSON 文件形式存储的相同社区目录,一些关键优势变得显而易见。
-
文档的大小可以根据存储的实际数据量而有所不同。这意味着,与其使用
Miscellaneous列来减少电子表格列的数量,不如为少量文档创建一个Awards属性。 -
请注意,对于“Joe’s Burgers”,没有
Website字段,而对于“Centennial Obelisk”却有。这是一种传达 Joe’s Burgers 的网站可能存在但在数据中缺失的方式,而 Centennial Obelisk 没有网站。与电子表格相比,这两种情况在电子表格中是不可区分的! -
数据的表现形式有更多的空间,比喻地说。
Photos、Prices和Scores字段挤在电子表格单元格中,但可以作为自身的子文档进一步扩展,或者在Scores的情况下,作为一个列表。 -
使用
@id字段作为“主键”创建了关系数据结构,而Type和User等字段可以通过它们的@id查找文档。
图数据库
图数据库旨在描述连接。这通过同时管理两种不同的数据顺序来实现:“节点”和“关系”(注意,术语可能因具体的图数据库技术而异)。节点具有独立于任何模式或其他数据的自身存在,尽管节点可以使用“标签”进行分类(类似于关系数据库使用“表”以及面向文档的数据库使用“集合”进行分类)。关系作为图数据的第二种顺序,根据源节点和目标节点定义,并按照关系类型进行分类。类似于面向文档的数据库中的文档,节点和关系都可以具有任意数量的属性,以键值对的形式表示。
再次,对于数字数据领域的新手来说,“图”的术语可能会分散对这种模型类型实际意义的注意力。与“条形图”或“折线图”等常见数据可视化不同,这里的“图”指的是这种数据模型类型更像是一个由圆圈(作为节点)和线(作为关系)构成的思维导图。这种数据模型类型被认为与人脑的工作方式相似,这不仅对神经科学和人工智能有影响,还意味着图数据模型对于技术不太熟悉的相关方相对容易理解。
具有讽刺意味的是,图数据库实际上比关系数据库更注重关系。虽然关系数据库通过在表之间匹配主键和外键来连接相关数据,但图数据库中的关系有其“存在”,因此可以加载额外的属性。这些属性可以描述关系的强度、质量、持续时间或其他特征。
领先的图数据库解决方案是Neo4j。他们的社区版是完全开源的,可以用于本地或自托管,免费云层的上限为 200,000 个节点和 400,000 个关系。Neo4j 查询语言叫做“Cypher”,可以免费学习,并且具有独特的可视化特性,使非开发人员容易理解。目前,Cypher 启发的“图查询语言”(或 GQL)的标准化正在进行中,甚至被认为是联合国可持续发展目标第 9 项——建立韧性基础设施、促进包容性和可持续工业化以及促进创新的一个步骤。
电子表格诊断
如果你们组织的数据目前都存储在电子表格中,你应该考虑跃迁到图数据库,特别是当你经常通过注释、评论和额外的列(这些列仅用于描述其他列中的数据)来上下文化数据时。

这个电子表格中的许多单元格都有一个“备注”,其中包含重要的背景数据,可以通过将光标移动到单元格上来显示。在 Microsoft Excel 中,带备注的单元格由右上角的红色三角形标识,使得电子表格看起来像是带有黑胡椒奶酪的外观。 下载文件在这里。屏幕截图由作者提供。
上述示例展示了一个名为“Fayette Advocacy Network”的虚构倡导组织的“权力映射”。权力映射用于追踪在特定问题上有影响力的人物、他们对该问题的态度,以及他们与组织的关系(无论间隔多少度)。
注意,这个图表中最有趣的数据并不在单元格的数值中,而是在每个单元格的备注中。例如:市议员纳迪亚·卡尔森和编辑委员会成员朱利斯·卡尔森是夫妻关系。斯科特·普尔不仅是 Fayette Teachers Union 的主席,还是 Fayette Public Schools 的员工,并且是 Fayette Advocacy Network 的活跃成员。尚未确认的 Fayette Advocacy Network 的原因已经被记录在案,因为这是通过他在社交媒体上的评论得知的。

在这个图表视图中,决策者用蓝色表示,其他人用品红色表示,组织用绿色表示,对问题的立场用灰色表示,而 Fayette Advocacy Network 则有一个大大的金色节点。节点之间的关系用箭头表示,这本身就是一种有价值的数据形式。 在 Arrows.app 中查看。
由于图形数据库可以存储带有属性的关系,因此表格中仅备注的数据可以在图形数据库中突出显示。图形模型类型提高了这些宝贵的背景信息的组织性和可访问性。
当一次性查看所有数据时,它可能看起来是一团糟。可以通过查询更具体的问题来缩小数据范围,例如:“谁是与对问题的立场仍然未知的决策者有联系的 Fayette Advocacy Network 的成员?”作为参考,下面提供了一个相应的 Cypher 查询。注意,这段主要为计算机解析而编写的代码对于人类阅读相对简单。
MATCH (faynet:Organization)<-[:MEMBER_OF]-(member:Person)-[:CONNECTED*1..2]-(decisionmaker:Person)-[:STANCE]->(pos:Position)
WHERE faynet.name = "Fayette Advocacy Network"
AND pos.name = "Unknown"
RETURN member
这个查询返回一个Person的列表,这些Person是“Fayette Advocacy Network”的MEMBER_OF,并且与decisionmaker(其对问题的Position为“unknown”)有CONNECTED(或与连接到decisionmaker的某人连接)。即使你不了解 Cypher,你也应该能够通过花几分钟时间琢磨各种箭头和括号的含义来弄清楚查询的工作原理。
值得一提的是:为了帮助理清这些混乱,Neo4j 还提供了一套图数据科学包,以帮助识别模式并计算节点重要性的常见度量,如 度中心性 和 中介中心性。越来越多的软件解决方案可用于可视化这些图数据,从 Neo4j 的简单内部解决方案 Bloom 一直到 虚拟现实头戴设备。不用说,这些查询、计算和可视化在单独使用电子表格时几乎是不可能完成的。
跳跃到数据库
本指南的动机原则是帮助你识别适合你数据需求的技术。很多时候,电子表格实际上会是合适的技术:用于以简单格式存储非关系数据,与技术水平较低的用户协作,快速处理一些数字,甚至作为数据驱动的头脑风暴和项目规划的创意场所。电子表格之所以有效,是因为它们将数据输入、工程、分析、协作和共享全部集中在一个空间中。
但将所有这些功能捆绑在一起的便利性,也可能使得电子表格对于你的数据来说变得限制重重。现实情况可能是,你需要关系数据库的关联性,或文档导向数据库的适应性,或图数据库的上下文化。即使你已经投入了数百小时来完善一个电子表格——实际上,你投入的时间越多,电子表格越可能阻碍你。
“电子表格无论输入什么内容看起来都很好。有人把电子表格带到会议中,它看起来如此精确。它有数字和图表,还有漂亮的字体。它感觉像是在告诉你世界的真相。感觉像是真理。”
— David Kastenbaum 在 电子表格!
电子表格陷阱发生是因为电子表格让人感到安全。与数据库技术不同,像 Microsoft Excel 和 Google Sheets 这样的工具捆绑在家庭办公室套件中,并且具有与我们的文字处理器和幻灯片制作器相同的熟悉界面。我们中的许多人是在学校里第一次学会使用电子表格的,有时处理电子表格也是联系同事的纽带。即使在 TikTok 上,也有社区专注于掌握电子表格的艺术。

迈出一步。 图片插图由作者制作。 照片由 Sammie Chaffin 在 Unsplash 上提供。
要将我们的数据从电子表格的囚笼中解放出来,需要一跃而起。我们必须相信,特定数据库模型的好处将超过采用陌生技术所带来的风险。好消息是,我们还有错误的余地:通过几行代码修复一个混乱的数据库比通过数百次鼠标点击修复笨重的电子表格要容易得多。此外,选择错误的数据库技术起步并不是世界末日:一旦电子表格被重新组织为表格(关系型)、集合(文档导向型)或标签(图形型),在这三种范式之间转换就会容易得多。
然而,逃离电子表格囚笼可能还有一个更重要的原因。随着我们将数据从电子表格中解放出来,我们也解放了对数据的想象力以及数据如何在组织中被利用。查找和输入数据从成为浪费时间的琐事变成了对日益激动人心和有价值的信息库的贡献。我们的数据从笨拙地保持的内部秘密变成了团队成员和外部合作伙伴的协作中心。
再仔细看一眼你的数据。如果电子表格囚笼阻碍了你实现目标,是时候深呼吸,勇敢地跃入数据库的世界了!
我的目标是架起数字数据与“现实世界”组织之间的桥梁。如果你欣赏这份指南,以下是你可以支持这篇及未来作品的方式:
-
阅读一些 我的其他帖子。如果你喜欢本文中的例子,你也会喜欢 “非营利数据管理的 5 个最佳实践” ,在其中我诊断了另一种有问题的电子表格:隐式数据的色彩编码承载体。
-
通过我的邀请链接注册一个免费的 Airtable 账户 **。
-
有兴趣提升你的数据科学、网络安全、软件工程或产品设计技能吗?注册以 通过我的校友推荐链接获取关于 Flatiron School 的更多信息。
-
通过 BuyMeACoffee.com 或者 Ko-fi.com发送小费。
选择正确的路径:流失模型与提升模型
原文:
towardsdatascience.com/choosing-the-right-path-churn-models-vs-uplift-models-b8489306aa80
我们真的需要流失模型吗?也许提升建模能给我们更全面的答案?
·发表于 Towards Data Science ·阅读时间 5 分钟·2023 年 6 月 16 日
--

照片由 cmophoto.net 提供,来源于 Unsplash
让我们假设我们在电商领域工作,产品经理找到了我们并要求建立流失模型。
但实际上他问了我们什么?
流失模型能给我们什么?
好吧,这一切都与了解特定客户离开的可能性有关。我们的下一步由启发式方法驱动:
如果我们对可能流失的客户提供折扣,他们将会留下来。
但我们的目标稍有不同。让我们假设我们只能做两件事:给予优惠或不给予优惠,在我们的案例中就是折扣。有 4 种可能的结果。

图片由作者提供
-
利润让步。 我们提供了折扣,用户使用了它并购买了商品,但即使没有折扣用户也会进行购买。这是一个负面结果,因为利润被让步了。
-
处理成本。 我们提供了折扣,但用户没有使用,也没有进行购买。这也被视为负面结果,因为我们在处理过程中产生了成本,如发送短信,尤其是在处理大量用户时。
-
成功。 我们提供了折扣,用户使用了它,并仅仅因为这个优惠进行了购买。这是我们期望的结果。
-
流失客户。 我们提供了折扣,但用户最终离开了我们。例如,在订阅服务的情况下,用户收到了一则折扣通知,却发现自己在过去 6 个月里一直在支付订阅费用,最终决定取消。这是我们可能遇到的最负面的结果。
我们真正的目标不是估计流失的概率,而是为每个用户应用最合适的处理。
我们如何开始实现这一目标?
首先,进行一个简单的 AB 测试至关重要。这涉及向一个组提供折扣,同时保持一个没有折扣的对照组。
实验后我们有三种主要方法。
双模型方法
第一种方法涉及构建两个独立的模型:一个用于控制组(没有任何折扣),另一个用于处理组(有折扣)。为了构建这些独立模型,我们可以选择任何类型的机器学习模型。
通过将每个客户通过两个模型,我们可以计算提升作为预测结果之间的差异。
优点:
- 实施很简单。
缺点:
-
这并不能直接预测提升。我们估计的是用户行动(购买)的概率。
-
两模型设置引入了双重误差建模,因为两个模型都有各自的误差,从而导致整体误差增大。
目标转换
第二种方法围绕目标变量本身的转换。通过创建一个表示提升的新目标,我们可以直接计算所需的结果。
我们使用以下公式引入一个新的目标变量:

作者提供的图片
这里,Y 表示原始目标变量,而 W 表示是否应用了目标处理。换句话说,Y 表示是否给予折扣,W 则表示是否进行了购买。
转换后的变量 Z 在两种情况下取值为 1:
-
用户属于目标组(W = 1)且 Y = 1(用户收到了折扣并且进行了购买)。
-
用户属于控制组(W = 0)且 Y = 0(用户未获得折扣且用户没有购买)。
然后我们只需要用新的目标训练模型(例如逻辑回归)。
计算提升时,我们可以使用以下公式:

作者提供的图片
优点:
-
实施仍然很简单。
-
由于只有一个模型,它比第一种方法更具鲁棒性和稳定性。
缺点:
- 这仍然无法直接预测提升。我们预测的是转换后的变量。
基于树的模型
第三种方法利用了基于树的模型。
目标是识别数据集中对处理最为敏感的子群体,从而实现最大影响的有针对性干预。

基于树的提升模型。github.com/uber/causalml
用于提升目的的示例决策树如上方突出显示的图像所示。红色表示提升值。通过观察图像,我们可以得出总体提升差异为 0.0127(基于随机度量)。然而,当我们深入树中时,我们观察到某些子群体展示出更高的提升差异。
这些子群体成为我们的目标,因为它们拥有最大收益的潜力。
如何构建这棵树?
有很多关于构建决策树的教程,但在这里我将概述基本方法。
-
选择特征并识别目标变量,在我们的案例中,目标变量是提升。
-
选择一个分割标准来确定节点如何划分。
-
通过递归地重复分割过程来构建树,直到满足停止标准。
值得注意的是,构建提升树时常用的分割标准有三种,按受欢迎程度排序如下:
-
KL 散度
-
卡方
-
欧几里得距离
优点:
-
最准确的方法之一
-
我们有一个决策树,因此我们可以构建树的森林和不同的集成方法,这些方法提高准确性并减少方差。
缺点:
- 这是一种决策树方法,因此算法倾向于高估具有多个级别的分类变量。为了解决这个问题,我们可以使用均值插补。
结论
现在我们知道,解决客户流失问题需要超越仅仅估计流失概率的策略。终极目标是对每个用户应用最合适的治疗方案,并带来商业影响,而不仅仅是流失概率。
提升建模可以应用于超越流失的各种商业挑战,提供了一个有力的解决方案,并具有即时的商业影响。
关于提升建模仍然存在许多有趣的问题,例如处理多重治疗、估计不同的提升模型以及利用多臂老丨虎丨机进行生产,但我会将答案留到下一个帖子。
感谢阅读,不要害怕犯错和学习。这是进步的唯一途径!
为你的下一个数据科学网页应用程序选择合适的工具
Flask、Django、Streamlit:数据科学家进入网页开发的三种选择。
·发表在 Towards Data Science ·6 分钟阅读·2023 年 1 月 31 日
--

图片由 Campaign Creators 提供,来源于 Unsplash
依靠现代的先进计算能力,我们可以比以往更好地利用数据。从你家里的舒适环境中,通过使用笔记本电脑和通过互联网获得的大量数据,设计和实现可能改变生活的技术变得可能。
但有一个问题。如果一个令人印象深刻的机器学习模型或有洞察力的数据可视化工具仅仅是作为代码文件放在你的电脑上,那它对任何人来说几乎没有用处。要真正产生影响,你需要能够以紧凑、可用的方式与其他人分享你的发明。
传统的方法是通过一个网页应用程序来实现。通过将你的数据科学工具转变为一个应用程序,并将其托管在一个任何人都可以通过互联网访问的服务器上,你可以立即让数以百万计的人使用你的工具。
不幸的是,这并不像听起来那么简单。构建一个网页应用程序是一个相当复杂的编程任务,涉及的技能超出了许多数据科学家的专长。从零开始编写一个确实是一项挑战。
幸运的是,许多公司从事网页框架的业务。网页框架本质上是一个广泛的库,它提供了构建网页应用程序的基础结构,使得入门变得非常容易。
在这篇文章中,我将讨论 Python 中三种成熟的 web 框架的优缺点:Flask、Django 和 Streamlit。如果你刚刚开始,可能很难知道该使用哪个工具,我希望能帮助你确定哪个最适合你的技能水平和总体目标。请注意,这篇文章不讨论如何实际使用这些工具构建应用程序的细节。
那我们开始吧?
Flask
Flask [1] 常常是介绍给新手 web 开发者的第一个框架。你常常会听到它被描述为极其轻量且如果你已经了解 Python,则相当容易学习。在这里,“轻量”一词意味着 Flask 没有任何外部依赖,它提供了构建 web 应用程序所需的最基本工具,而不会让你感到功能过多。
Flask 轻量得如此以至于常被称为“微型”框架。你可以从最简单的应用程序开始,根据需要选择更多功能。此外,它足够简单,如果你已经了解 Python,你不会觉得自己在学习一种全新的语言。
然而,这种简单性是把双刃剑。由于 Flask 轻量级,它也会为你自动完成较少的事情。换句话说,当你在构建 web 应用程序时,你将不得不自己解决许多问题。
我自己从未使用过 Flask,但我承认它已经被许多次推荐给我。在一般编程社区的眼中,它的好处似乎超过了成本。
让我们在继续之前回顾一下。
优点
-
极其轻量(在某种意义上是“微型”框架)
-
你可以选择所需的功能。
-
由于上述两个特性,学习起来可以说更容易。
缺点
-
内置功能不那么丰富。
-
由于上述原因,你需要自己实现 web 应用程序的许多方面。
Django
这是我个人的框架选择。我最近才开始进行全面的 web 开发,之前只是使用了更简单的替代方案(见下文 Streamlit 部分)。我必须在 Django 和 Flask 之间做出选择,最终决定 Django 更适合我的需求。
Django 是 Flask 的完全对立面,因为它相当沉重。它有一个巨大的功能集——当你构建一个 Django 项目时,你会得到所有这些功能,无论你是否需要。这有积极和消极的方面。
好的一面是,由于许多功能已经内置,你需要从头开始实现的东西会更少。一个很好的例子是 Django 管理页面。所有 Django 项目自动实现一个数据库管理员网页,允许一个“超级用户”添加和删除数据,而无需编写任何代码。如果你想为一个需要管理相关数据库但没有相应编程技能的团队设计一个网络应用,这可能会非常有用。
另一方面,项目庞大、复杂且可能让人感到威慑,因为你一次性获得了一切。最简单的 Flask 应用可以写在一个代码文件中,但任何 Django 项目在你开始时都会自动创建大量的目录和相关文件。
Django 也有非常详细的文档和由开发者亲自编写的优秀教程 [3]。对一些人来说,这是把双刃剑,因为开发者意见非常明确。因此,当你在 Django 中做某事时,你必须按照“Django”的方式去做。然而,如果网页开发不是你主要的编程强项,你只是想开始,那么这不太可能成为问题。
总之 —
优点
-
附带许多内置功能
-
Django 管理员是一个非常棒的附加功能,你无需额外工作即可获得
-
Django 提供了一整套由公司自行编写的详尽教程(确保信息准确并符合最佳实践)
缺点
-
非常庞大——你在 Django 项目中获得所有功能,无论你是否需要它们
-
Django 的开发者对网页开发有强烈的意见,这些意见反映在使用 Django 编程时必须遵循的方式上。如果你不同意这些意见,可能会发现 Django 很让人沮丧。
Streamlit
“Streamlit [4] 将数据脚本在几分钟内转换为可分享的网页应用。
完全使用 Python 编写。不需要前端经验。”
上述引言直接摘自 Streamlit 的首页,并且是一个相当扎实的总结。Streamlit 最初的设计目的是为了让缺乏网页编程技能的数据科学家能够轻松地将他们的数据科学工具部署到互联网上。
这是 Streamlit 的最大优势。你可以用 Python 编写你的数据科学工具,对照 Streamlit 的规范进行一些语法上的调整,添加一些头部代码行,然后你的应用程序就准备好了!至少在本地是这样的——不过 Streamlit 也提供了云支持作为下一步。
Streamlit 还具有一些不错的内置功能(滑块、按钮和其他小部件),可以用来为你的应用程序提供良好的提升。
不过,Streamlit 也有两个缺点。显而易见的一个是你实际上被限制在 Streamlit 的默认界面中,因此无法真正自定义应用程序的外观。Streamlit 并不是为了个性化设计的;它只是提供了一种快速发布工具并在需要时收集数据的方式。
此外,Streamlit 在处理大数据集时也存在一些问题。根据我的经验(以及我合作过的其他人的经验),执行潜在的高成本计算过程(例如渲染可视化)在 Streamlit 中可能会有一些延迟。这种经验得到 Streamlit 论坛讨论 [5] 的支持。
总的来说,Streamlit 学习快且容易,但灵活性远不如 Flask 和 Django。
优点
-
为数据科学家设计
-
需要最低限度的网页开发知识
-
可以极其快速地构建和部署网页应用
缺点
-
在自定义网页应用方面几乎没有自由度
-
对于大数据集或高成本操作可能存在效率问题
最终想法
选择正确的框架既是个人偏好的问题,也是你具体需求的考虑。我无法了解你的偏好,因此我会简要地谈谈第二点以作结。
想快速测试一个工具,却没有太多时间来组装它?Streamlit 可能是你的最佳选择。如果你想学习网页开发并提升 Python 技能,选择 Flask 或 Django 会更合适。这两者都能很好地处理大型复杂应用,因此个人偏好在此时显得尤为重要。
与所有编程一样,从一篇文章中只能获得有限的信息。在此,我鼓励你亲自尝试每一种工具。我很想听听你的想法和反思,请在评论中分享。
祝你网页开发愉快!
想要在 Python 中脱颖而出? 点击这里获取我简单易读的免费指南。想在 Medium 上阅读无限故事?请通过下面的推荐链接注册!
阅读 Murtaza Ali 在 Medium 上的文章。他是华盛顿大学的博士生,兴趣涉及人机交互…
murtaza5152-ali.medium.com](https://murtaza5152-ali.medium.com/?source=post_page-----3d7ec8293d61--------------------------------)
参考资料
[1] flask.palletsprojects.com/en/2.2.x/
[3] docs.djangoproject.com/en/4.1/intro/tutorial01/
[4] streamlit.io/
[5] discuss.streamlit.io/t/whether-streamlit-can-handle-big-data-analysis/28085/15
慢性肾病预测:新视角
利用 SHAP 构建一个与医学文献一致的可解释模型
·
关注 发布于 Towards Data Science ·8 min read·2023 年 8 月 25 日
--
图片由 Robina Weermeijer 提供,来源于 Unsplash
引言
肾脏努力从血液中去除任何废物、毒素和多余的液体,其正常功能对于健康至关重要。慢性肾病(CKD)是一种肾脏无法像应该那样过滤血液的状况,导致血液中液体和废物的积累,长期可能导致肾衰竭。[1] CKD 影响了全球超过 10%的人口,并预计到 2040 年将成为全球第五大生命年损失原因。[2]
在这篇文章中,我的目标不是建立一个最准确的模型来预测患者发生 CKD 的情况。而是检查使用标准机器学习算法开发的最佳模型是否也符合医学文献中的最有意义模型。我使用了 SHAP(SHapley Additive exPlanations)的原则,这是一种博弈论方法,用于解释机器学习模型的输出。
医学文献怎么说?
医学文献将 CKD 的发展和进展与一些关键症状相关联。
-
糖尿病和高血压: 糖尿病和高血压是与 CKD 相关的两个最重要的风险因素。在 2011-2014 年在美国进行的一项研究中,糖尿病患者中 CKD(3-4 期)的患病率为 24.5%,前糖尿病患者为 14.3%,非糖尿病患者为 4.9%。在同一项研究中,高血压患者的 CKD 患病率为 35.8%,前高血压患者为 14.4%,非高血压患者为 10.2%。[2]
-
减少的血红蛋白和红细胞水平: 肾脏产生一种叫做红细胞生成素(EPO)的激素,这种激素有助于红细胞的生成。在慢性肾病(CKD)中,肾脏无法产生足够的 EPO,导致贫血的发展,即血液中的红细胞和血红蛋白水平下降。[3]
-
增加的血清(血液)肌酐: 肌酐是正常肌肉和蛋白质分解的废物,过量的肌酐通过肾脏从血液中排出。在 CKD 中,肾脏无法有效排除过量的肌酐,导致血液中肌酐水平升高。[4]
-
尿液比重降低: 尿液比重是肾脏浓缩尿液能力的指标。患有 CKD 的患者尿液比重降低,因为肾脏失去了有效浓缩尿液的能力。[5]
-
血尿和蛋白尿: 血尿和蛋白尿分别指尿液中存在红细胞和白蛋白。正常情况下,肾脏的过滤器阻止血液和白蛋白进入尿液。然而,过滤器的损害可能导致血液(或红细胞)和白蛋白进入尿液。[6][7]
数据集
本文使用的数据集是 Kaggle 上的‘慢性肾脏疾病’数据集,最初由 UCI 在其机器学习库下提供。该数据集包含来自 400 名患者的数据,包括 24 个特征和 1 个二进制目标变量(CKD 缺失 = 0,CKD 存在 = 1)。特征的详细描述可以在这里找到。
数据预处理
CKD 数据集有很多缺失值,需要在进一步分析之前进行填补。此图显示了缺失数据的可视化表示,黄色线条指示了该列中的缺失值。

缺失数据的可视化表示(由黄色线条标示)
缺失值以以下方式进行了填补:
-
对于数值特征,缺失值使用中位数填补。未使用均值,因为均值对异常值敏感,而中位数则不敏感。由于这些列中存在异常值,中位数更能准确反映中心值。
-
分类特征‘rbc’和‘pc’分别缺失了 38%和 16.25% 的数据。由于这是一个较大的缺失数据量,缺失值被填补为‘未知’。在这种情况下使用众数并不是最佳选择,因为将这么大一组观察结果归为同一类别可能会有一定风险。
-
其他所有分类特征的缺失数据都少于或等于 1%。因此,使用各自的众数填补了缺失值。
使用 SHAP 构建模型并检查可解释性
在填补缺失值之后,将数据分为训练集和测试集(70–30 拆分),并运行了一个简单的随机森林分类模型。测试准确率为 100%,即模型能够在 100%的时间里正确分类之前未见过的患者。混淆矩阵如下所示。

在测试数据上运行模型时生成的混淆矩阵
当然,我们现在有了一个很好的分类模型。但如果我们对解释性感兴趣,即每个特征如何对预测产生积极或消极的贡献呢?哪些特征是推动预测的最重要因素?结果是否与临床发现一致?这些问题是 SHAP 可以帮助我们回答的。
SHAP 是一种基于博弈论的数学方法,可以通过计算每个特征对预测的贡献来解释任何机器学习模型的预测。它可以帮助我们确定那些对预测起关键作用的特征,以及它们对目标变量的影响方向。[8] 我们为测试数据拟合了一个 SHAP 解释器,并生成了如下所示的全局特征重要性图。

使用 SHAP 生成的全局特征重要性图
驱动预测的前三个特征是血红蛋白水平(‘hemo’)、尿液的比重(‘sg’)和患者尿液中是否存在红血球(‘rbc_normal’)。由于特征重要性是通过计算该特征在所有样本中的绝对 SHAP 值的均值得出的,因此该图仅提供了重要性的顺序信息,而不涉及影响的方向。让我们生成一个更具信息性的图表,涵盖这两个目标。

使用 SHAP 生成的蜜蜂图
这个蜜蜂图是展示数据集中顶级特征如何影响模型预测的绝佳方式。粉色点表示预测为 CKD 的患者,蓝色点表示预测为非 CKD 的患者。现在我们已经知道了驱动预测的顶级特征,让我们看看它们的影响方向是否与本文前面展示的临床发现一致。
-
糖尿病(‘dm_yes’)和高血压(‘htn_yes’)的存在与 CKD 的出现有关。这与临床发现相匹配,尽管考虑到它们是与 CKD 相关的主要风险因素,预期它们在全球重要性中应位于更高的位置。
-
低血红蛋白水平(‘hemo’)、低红细胞压积(‘pcv’:血液中红血球的体积分数)和低红血球计数(‘rc’)与 CKD 相关。这也与临床发现一致,因为 CKD 患者无法生产足够的红血球。
-
低尿液比重(‘sg’)与 CKD 相关,这在临床上可以解释为肾脏失去浓缩尿液的能力。
-
尿液中的高白蛋白(‘al’)和高血清肌酐(‘sc’)水平与 CKD 相关,这与临床发现一致,因为肾脏失去有效过滤血液的能力。
-
尿液中红血球或异常尿液的存在(‘rbc_normal’;这是一个二元分类特征,其中值 = 1 表示正常尿液中没有 RBCs,值 = 0 表示异常尿液可能含有 RBCs)与 CKD 相关。这支持了临床发现,因为血尿在 CKD 患者中更为常见。
总之,顶级特征及其对预测的影响方向与医学文献一致。
结论
在这篇文章中,有两个主要收获:
-
医学文献将 CKD 的发展和进展与 ML 模型用于分类的相同顶级特征相关联,以判断患者是否预测为 CKD。
-
这些顶级特征对目标变量的影响方向支持临床发现,表明该模型不仅在预测 CKD 时 100% 准确,而且具有医学意义,结果完全可以解释。
本研究的一项可能局限性是样本量较小。一旦获得更多数据,应在更大的患者群体上测试模型,以检查其是否继续保持高精度。也有趣的是查看在更大患者群体中,特征的重要性排序是否发生变化。
在医学领域,最准确的模型可能并不总是最有意义的模型。在这项研究中,使用了 SHAP 来检查我们的模型是否符合医学文献。最终模型的优势在于它不仅具有高精度,而且易于解释,并得到临床发现的支持。该模型在远程医疗中具有很大用途,可以用于识别更高风险的慢性肾病患者。未来的研究可以深入个体观察,并查看哪些模型特征在个体层面上驱动了预测。
本项目的代码可以在 这里 找到。本文中的所有图片均由我通过 Google Colab 生成。
参考文献
原版数据集许可证: L. Rubini, P. Soundarapandian 和 P. Eswaran, 慢性肾病 (2015), UCI 机器学习库 (CC BY 4.0)
‘慢性肾病’ 数据集在 Kaggle 上: www.kaggle.com/datasets/mansoordaku/ckdisease
原版 SHAP 文档: shap.readthedocs.io/en/latest/api_examples.html#plots
[1] 慢性肾病基础知识 (2022), 《疾病控制与预防中心》
[2] C.P. Kovesdy, 慢性肾病流行病学:2022 更新 (2022), 《肾脏国际补充期刊》
[3] H. Shaikh, M.F. Hashmi 和 N.R. Aeddula, 慢性肾病性贫血 (2023), 《国家医学图书馆》
[4] 血清(血液)肌酐 (2023), 《国家肾脏基金会》
[5] J.A. Simerville, W.C. Maxted 和 J.J. Pahira, 尿液分析:综合评审 (2005), 《美国家庭医生》
[6] P.F. Orlandi 等, 血尿作为慢性肾病和死亡进展的风险因素:来自慢性肾功能不全队列(CRIC)研究的发现 (2018), 《BMC 肾脏病学》
[7] 尿白蛋白 (2016), 美国国家糖尿病和消化与肾脏疾病研究所
[8] R. Bagheri, SHAP 值及其在机器学习中的应用介绍 (2022), Towards Data Science
CI/CD 在 AWS 的多模型端点
一个简单、灵活的可持续机器学习解决方案的替代方案
·
关注 发表在 Towards Data Science ·14 分钟阅读·2023 年 6 月 22 日
--
图片来源于 VectorStock,授权给Andrew Charabin
自动化生产机器学习解决方案的再培训和部署是确保模型考虑到covariate shift的关键步骤,同时减少出错和不必要的人力投入。
对于使用 AWS 堆栈和特别是 SageMaker 部署的模型,AWS 提供了一种标准 CI/CD 解决方案,使用 SageMaker Pipelines 来自动化重新训练/部署,以及 SageMaker 模型注册表 来跟踪模型的传承。
虽然标准解决方案在标准情况下效果良好,但在更复杂的情况下存在若干限制:
-
输入数据需要从 AWS s3 获取。
-
设置动态预热超参数调整的难度。
-
需要额外的模型训练步骤来训练多个模型。
-
执行管道的启动时间较长。
-
调试工具有限。
幸运的是,AWS 推出了可以用来构建克服这些限制的 CI/CD 管道的新功能。以下功能可以在 SageMaker Studio 中访问,这是 AWS 的集成开发环境,用于机器学习:
本文的目的…
目的是通过 AWS 云探讨一种替代 CI/CD 解决方案的关键细节,该方案提供了更多灵活性和更快的市场速度。
解决方案组件概述:
1. 自定义 SageMaker Studio 图像用于 PostgreSQL 查询
2. 动态预热超参数调整
3. 在单个交互式 Python 笔记本中注册多个模型到模型注册表
4. 用新模型刷新多模型端点
5. 计划重新训练/重新部署笔记本以在设定的周期上运行
开始吧。
1. 自定义 SageMaker Studio 图像 用于 PostgreSQL 查询
虽然 SageMaker 管道允许从 s3 获取输入数据,但如果新输入数据位于数据仓库中,如 AWS Redshift 或 Google BigQuery 呢?当然,可以使用 ETL 或类似过程将数据批量移动到 s3,但这与直接从数据仓库查询数据相比,增加了不必要的复杂性/僵化。
SageMaker Studio 提供了几种默认镜像来初始化环境,其中一个例子是包含常用包如 numpy 和 pandas 的‘Data Science’镜像。然而,要在 Python 中连接到 PostgreSQL 数据库,需要一个驱动程序或适配器。Psycopg2是 Python 编程语言中最流行的 PostgreSQL 数据库适配器。幸运的是,可以使用自定义镜像来初始化 Studio 环境,尽管有特定的要求。我已经预打包了一个满足这些要求的 Docker 镜像,并在 Python Julia-1.5.2 镜像基础上添加了 psycopg2 驱动程序。该镜像可以在这个git 仓库中找到。然后,可以使用这里概述的步骤使镜像在 Studio 域中可用。
2. 动态预热启动超参数调优
模型重新训练在性质上与初始模型训练不同。在重新训练模型时,投资相同数量的资源来搜索最佳模型超参数以及相同的大范围搜索空间是不切实际的。特别是当仅期望对上一生产模型的最佳超参数进行微调时尤其如此。
因此,本文推荐的 CI/CD 超参数调优解决方案不会尝试通过 K 折交叉验证、预热池等方式快速重新调优。这些方法对于初始模型训练非常有效。然而,对于重新训练,我们希望从生产中已经有效的地方开始,并对新获取的数据进行小幅调整。因此,使用动态预热启动超参数调优是完美的解决方案。进一步地,可以创建一个动态预热启动调优系统,使用最新的生产调优作业作为父作业。以下是一个示例 XGBoost 贝叶斯调优作业的解决方案:
# Set Run Parameters
testing=False
hyperparam_jobs=10
# Set Max Jobs
if testing==False: max_jobs=hyperparam_jobs
else: max_jobs=1
# Load Packages
from sagemaker.xgboost.estimator import XGBoost
from sagemaker.tuner import IntegerParameter
from sagemaker.tuner import ContinuousParameter
from sagemaker.tuner import HyperparameterTuner
from sagemaker.tuner import WarmStartConfig, WarmStartTypes
# Configure Warm Start
number_of_parent_jobs=1
# Can be up to 5, but currently only a value of 1 is supported in the code
# Note base_dir needs to be set, can also be set blank
try: eligible_parent_tuning_jobs=pd.read_csv(f"""{base_dir}logs/tuningjobhistory.csv""")
except:
eligible_parent_tuning_jobs=pd.DataFrame({'datetime':[],'tuningjob':[],'metric':[],'layer':[],'objective':[],'eval_metric':[],'eval_metric_value':[],'trainingjobcount':[]})
eligible_parent_tuning_jobs.to_csv(f"""{base_dir}logs/tuningjobhistory.csv""",index=False)
eligible_parent_tuning_jobs=eligible_parent_tuning_jobs[(eligible_parent_tuning_jobs['layer']==prefix)&(eligible_parent_tuning_jobs['metric']==metric)&(eligible_parent_tuning_jobs['objective']==trainingobjective)&(eligible_parent_tuning_jobs['eval_metric']==objective_metric_name)&(eligible_parent_tuning_jobs['trainingjobcount']>1)].sort_values(by='datetime',ascending=True)
eligible_parent_tuning_jobs_count=len(eligible_parent_tuning_jobs)
if eligible_parent_tuning_jobs_count>0:
parent_tuning_jobs=eligible_parent_tuning_jobs.iloc[(eligible_parent_tuning_jobs_count-(number_of_parent_jobs)):eligible_parent_tuning_jobs_count,1].iloc[0]
warm_start_config = WarmStartConfig(
WarmStartTypes.TRANSFER_LEARNING, parents={parent_tuning_jobs})
# Note that WarmStartTypes.IDENTICAL_DATA_AND_ALGORITHM can be used when applicable
print(f"""Warm starting using tuning job: {parent_tuning_jobs[0]}""")
else: warm_start_config = None
# Define exploration boundaries (default suggested values from Amazon SageMaker Documentation)
hyperparameter_ranges = {
'eta': ContinuousParameter(0.1, 0.5, scaling_type='Logarithmic'),
'max_depth': IntegerParameter(0,10,scaling_type='Auto'),
'num_round': IntegerParameter(1,4000,scaling_type='Auto'),
'subsample': ContinuousParameter(0.5,1,scaling_type='Logarithmic'),
'colsample_bylevel': ContinuousParameter(0.1, 1,scaling_type="Logarithmic"),
'colsample_bytree': ContinuousParameter(0.5, 1, scaling_type='Logarithmic'),
'alpha': ContinuousParameter(0, 1000, scaling_type="Auto"),
'lambda': ContinuousParameter(0,100,scaling_type='Auto'),
'max_delta_step': IntegerParameter(0,10,scaling_type='Auto'),
'min_child_weight': ContinuousParameter(0,10,scaling_type='Auto'),
'gamma':ContinuousParameter(0, 5, scaling_type='Auto'),
}
tuner_log = HyperparameterTuner(
estimator,
objective_metric_name,
hyperparameter_ranges,
objective_type='Minimize',
max_jobs=max_jobs,
max_parallel_jobs=10,
strategy='Bayesian',
base_tuning_job_name="transferlearning",
warm_start_config=warm_start_config
)
# Note a SageMaker XGBoost estimater needs to be instantiated in advance
training_input_config = sagemaker.TrainingInput("s3://{}/{}/{}".format(bucket,prefix,filename), content_type='csv')
validation_input_config = sagemaker.TrainingInput("s3://{}/{}/{}".format(bucket,prefix,filename), content_type='csv')
# Note bucket, prefix, and filename objects/aliases need to be set
# Starts the hyperparameter tuning job
tuner_log.fit({'train': training_input_config, 'validation': validation_input_config})
# Prints the status of the latest hyperparameter tuning job
boto3.client('sagemaker').describe_hyper_parameter_tuning_job(
HyperParameterTuningJobName=tuner_log.latest_tuning_job.job_name)['HyperParameterTuningJobStatus']
调优作业历史将保存在基础目录中的日志文件中,示例输出如下:

作者提供的图表
日期/时间戳、调优作业名称以及元数据以.csv 格式存储,新调优作业会追加到文件中。
系统将动态地使用满足要求条件的最新调优作业进行预热启动。在这个例子中,条件在以下代码行中注明:
eligible_parent_tuning_jobs=eligible_parent_tuning_jobs[(eligible_parent_tuning_jobs['layer']==prefix)&(eligible_parent_tuning_jobs['metric']==metric)&(eligible_parent_tuning_jobs['objective']==trainingobjective)&(eligible_parent_tuning_jobs['eval_metric']==objective_metric_name)&(eligible_parent_tuning_jobs['trainingjobcount']>1)].sort_values(by='datetime',ascending=True)
因为我们需要测试管道的工作情况,所以提供了testing=True运行选项,这将强制仅进行一个超参数调优作业。添加了一个条件,只考虑具有多个调优模型作为父模型的作业,前提是这些是非测试的。此外,调优作业日志文件可以在不同模型间使用,因为理论上可以在模型间使用父作业。在这种情况下,模型通过‘metric’字段进行跟踪,符合条件的调优作业会过滤以匹配当前训练实例中的指标。
一旦重新训练完成,我们将把新的超参数调整作业追加到日志文件中,并将其写入本地以及 s3,同时启用版本控制。
# Append Last Parent Job for Next Warm Start
eligible_parent_tuning_jobs=pd.read_csv(f"""{base_dir}logs/tuningjobhistory.csv""")
latest_tuning_job=boto3.client('sagemaker').describe_hyper_parameter_tuning_job(
HyperParameterTuningJobName=tuner_log.latest_tuning_job.job_name)
updatetuningjobhistory=pd.concat([eligible_parent_tuning_jobs,pd.DataFrame({'datetime':[datetime.now().strftime("%Y/%m/%d %H:%M:%S")],'tuningjob':[latest_tuning_job['HyperParameterTuningJobName']],'metric':[metric],'layer':prefix,'objective':[trainingobjective],'eval_metric':[latest_tuning_job['BestTrainingJob']['FinalHyperParameterTuningJobObjectiveMetric']['MetricName']],'eval_metric_value':latest_tuning_job['BestTrainingJob']['FinalHyperParameterTuningJobObjectiveMetric']['Value'],'trainingjobcount':[latest_tuning_job['HyperParameterTuningJobConfig']['ResourceLimits']['MaxNumberOfTrainingJobs']]})],axis=0)
print(updatetuningjobhistory)
# Write locally
updatetuningjobhistory.to_csv(f"""{base_dir}logs/tuningjobhistory.csv""",index=False)
# Upload to s3
s3.upload_file(f"""{base_dir}logs/tuningjobhistory.csv""",bucket,'logs/tuningjobhistory.csv')
3. 在单个交互式 Python 笔记本中将多个模型注册到模型注册表
通常,组织会有多个 AWS 账户用于不同的用例(即沙盒、QA 和生产)。你需要确定在 CI/CD 解决方案的每个步骤中使用哪个账户,然后添加本指南中提到的跨账户权限。
推荐在同一账户中进行模型训练和模型注册,特别是沙盒或测试账户。因此,在下图中,‘数据科学’和‘共享服务’账户将是相同的。在该账户中,需要一个 s3 桶来存放模型工件并跟踪与流水线相关的其他文件的血统。模型/端点将在每个‘部署’账户(即沙盒、QA、生产)中分别部署,引用训练/注册账户中的模型工件和注册表。

来自AWS 文档
现在我们已经决定了用于训练和存放模型注册表的 AWS 账户,我们可以构建初始模型并开发 CI/CD 解决方案。
使用 SageMaker Pipelines 时,为数据预处理、训练/调整、评估、注册以及任何后处理创建独立的管道步骤。虽然这对于单个模型管道是可以的,但当需要多个模型来解决机器学习方案时,会产生大量的管道代码重复。
因此,推荐的解决方案是构建并调度三个交互式 Python 笔记本在 SageMaker Studio 中。它们按顺序运行,并通过一个自动化的笔记本作业一起完成 CI/CD 流水线:
A. 数据准备
B. 模型训练、评估和注册
C. 使用最新批准的模型刷新端点
A. 数据准备
在这里,我们将从数据仓库查询并加载数据,然后将其写入本地和 s3。我们可以使用当前日期设置动态的日期/时间条件,并将生成的日期下限和上限传递到 SQL 查询中。
# Connect to Data Warehouse
dbname='<insert here>'
host='<insert here>'
password='<insert here>'
port='<insert here>'
search_path='<insert here>'
user='<insert here>'
import psycopg2
data_warehouse= psycopg2.connect(f"""host={host} port={port} dbname={dbname} user={user} password={password} options = '-c search_path={search_path}'""")
# Set Dataset Date Floor and Ceiling Applied to Pass in & Apply to Query
datestart=date(2000, 1, 1)
pushbackdays=30
dateend=date.today() - timedelta(days=pushbackdays)
print(datestart)
print(dateend)
# Query data warehouse
modelbuildingset=pd.read_sql_query(f"""<insert query>""",data_warehouse)
# Write .csv
modelbuildingset.to_csv(f"{base_dir}datasets/{filename}", index=False)
modelbuildingset
# Upload to s3 for Lineage Tracking
s3 = boto3.client('s3')
s3.upload_file(f"{base_dir}datasets/{filename}",bucket,f"datasets/{filename}")
这一步骤以将准备好的训练数据保存到本地以及 s3 以进行血统追踪结束。
B. 模型训练、评估和注册
通过在 Studio 中使用交互式 Python notebook,我们现在可以在一个 notebook 中完成模型训练、评估和注册。所有这些步骤都可以构建为一个函数,并适用于需要重新训练的其他模型。为了说明,代码未使用函数提供。
在继续之前,需要在注册表中为解决方案中的每个模型创建模型包组(可以在控制台中创建,也可以通过 Python 创建)。
# Get the Best Training Job
best_overall_training_job_name = latest_tuning_job['BestTrainingJob']['TrainingJobName']
# latest_tuning_job was obtained from the hyperparameter tuning section
latest_tuning_job['BestTrainingJob']
# Install XGBoost
! pip install xgboost
# Download the Best Model
s3 = boto3.client('s3')
s3.download_file('<s3 bucket>', f"""output/{best_overall_training_job_name}/output/model.tar.gz""", f"""{base_dir}models/{metric}/model.tar.gz""")
# Open and Load the Downloaded Model Artifact in Memory
tar = tarfile.open(f"""{base_dir}models/{metric}/model.tar.gz""")
tar.extractall(f"""{base_dir}models/{metric}""")
tar.close()
model = pkl.load(open(f"""{base_dir}models/{layer}/{metric}/xgboost-model""", 'rb'))
# Perform Model Evaluation
import json
import pathlib
import joblib
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error
import math
evaluationset=pd.read_csv(f"""{base_dir}datasets/{layer}/{metric}/{metric}modelbuilding_test.csv""")
evaluationset['prediction']=model.predict(xgboost.DMatrix(evaluationset.drop(evaluationset.columns.values[0], axis=1), label=evaluationset[[evaluationset.columns.values[0]]]))
# In the Example a Regression Problem is Used with MAE & RMSE as Eval Metrics
mae = mean_absolute_error(evaluationset[evaluationset.columns.values[0]], evaluationset['prediction'])
rmse = math.sqrt(mean_squared_error(evaluationset[evaluationset.columns.values[0]], evaluationset['prediction']))
stdev_error = np.std(evaluationset[evaluationset.columns.values[0]] - evaluationset['prediction'])
evaluation_report=pd.DataFrame({'datetime':[datetime.now().strftime("%Y/%m/%d %H:%M:%S")], 'testing':[testing], 'trainingjob': [best_overall_training_job_name], 'objective':[trainingobjective], 'hyperparameter_tuning_metric':[objective_metric_name], 'mae':[mae], 'rmse':[rmse], 'stdev_error':[stdev_error]})
# Load Past Evaluation Reports
try: past_evaluation_reports=pd.read_csv(f"""{base_dir}models/{metric}/evaluationhistory.csv""")
except: past_evaluation_reports=pd.DataFrame({'datetime':[],'testing':[], 'trainingjob': [], 'objective':[], 'hyperparameter_tuning_metric':[], 'mae':[], 'rmse':[], 'stdev_error':[]})
evaluation_report=pd.concat([past_evaluation_reports,evaluation_report],axis=0)
print(evaluation_report)
# Write .csv
evaluation_report.to_csv(f"""{base_dir}models/{metric}/evaluationhistory.csv""",index=False)
# Write to s3
s3.upload_file(f"""{base_dir}models/{metric}/evaluationhistory.csv""",'<s3 bucket>',f"""{layer}/{metric}/evaluationhistory.csv""")
# Note Can Also Associate a Registered Model with Eval Metrics, But Will Skip it Here
report_dict = {}
# Register Model
model_package_group_name='<>'
modelpackage_inference_specification = {
"InferenceSpecification": {
"Containers": [
{
"Image": xgboost_container,
"ModelDataUrl": f"""s3://{s3 bucket}/output/{best_overall_training_job_name}/output/model.tar.gz"""
}
],
"SupportedContentTypes": [ "text/csv" ],
"SupportedResponseMIMETypes": [ "text/csv" ],
}
}
create_model_package_input_dict = {
"ModelPackageGroupName" : model_package_group_name,
"ModelPackageDescription" : "<insert description here>",
"ModelApprovalStatus" : "PendingManualApproval",
"ModelMetrics" :report_dict
}
create_model_package_input_dict.update(modelpackage_inference_specification)
sm_client = boto3.client('sagemaker')
create_model_package_response = sm_client.create_model_package(**create_model_package_input_dict)
model_package_arn = create_model_package_response["ModelPackageArn"]
print('ModelPackage Version ARN : {}'.format(model_package_arn))
通过打开注册表中的模型包组,您可以查看所有已注册的模型版本、注册日期和批准状态。

来自 AWS 文档 的图表
管道的主管可以查看在前一步中本地保存的评估报告,其中包含所有过去模型评估的历史记录,并根据测试集评估指标确定是否批准或拒绝该模型。稍后,可以设置条件,仅在模型获得批准后更新生产(或 QA)端点。
4. 用新模型刷新多模型端点
SageMaker 具有一个 MultiDataModel 类,允许部署可以托管多个模型的 SageMaker 端点。其原理在于可以在同一个计算实例中加载多个模型,共享资源并节省成本。此外,它简化了模型的重新训练/管理,因为只需更新一个端点以反映新模型并进行管理,而不必在每个专用端点之间重复步骤(这也可以作为替代方案)。MultiDataModel 类也可以用于部署单个模型,如果计划在未来向解决方案中添加更多模型,这可能会很有意义。
在首次训练账户中,我们需要创建模型和端点。MultiDataModel 类需要一个位置来存储可以在调用时加载到端点中的模型工件;下面我们将使用正在使用的 s3 bucket 中的“models”目录。
# Load Container
from sagemaker.xgboost.estimator import XGBoost
xgboost_container = sagemaker.image_uris.retrieve("xgboost", region, "1.2-2")
# One Time: Build Multi Model
estimator = sagemaker.estimator.Estimator.attach('sagemaker-xgboost-220611-1453-011-699894eb')
xgboost_container = sagemaker.image_uris.retrieve("xgboost", region, "1.2-2")
model = estimator.create_model(role=role, image_uri=xgboost_container)
from sagemaker.multidatamodel import MultiDataModel
sagemaker_session=sagemaker.Session()
# This is where our MME will read models from on S3.
model_data_prefix = f"s3://{bucket}/models/"
mme = MultiDataModel(
name=model_name,
model_data_prefix=model_data_prefix,
model=model, # passing our model - passes container image needed for the endpoint
sagemaker_session=sagemaker_session,
)
# One Time: Deploy the MME
ENDPOINT_INSTANCE_TYPE = "ml.m4.xlarge"
ENDPOINT_NAME = "<insert here>"
predictor = mme.deploy(
initial_instance_count=1, instance_type=ENDPOINT_INSTANCE_TYPE, endpoint_name=ENDPOINT_NAME,kms_key='<insert here if desired>'
)
之后,可以按照以下方式引用 MultiDataModel:
model=sagemaker.model.Model(model_name)
from sagemaker.multidatamodel import MultiDataModel
sagemaker_session=sagemaker.Session()
# This is where our MME will read models from on S3.
model_data_prefix = f"s3://{bucket}/models/"
mme = MultiDataModel(
name=model_name,
model_data_prefix=model_data_prefix,
model=model, # passing our model - passes container image needed for the endpoint
sagemaker_session=sagemaker_session,
)
可以通过将工件复制到端点将使用的 {s3 bucket}/models 目录,将模型添加到 MultiDataModel 中。我们所需的只是模型包组名称,模型注册表将提供相应的源工件位置和批准状态。
我们可以添加一个条件,仅在模型获得批准后才添加最新模型,如下所示。如果需要立即部署进行数据科学 QA 并最终批准模型,可以在沙箱账户中省略此条件。
# Get the latest model version and associated artifact location for a given model package group
ModelPackageGroup = 'model_package_group'
list_model_packages_response = client.list_model_packages(ModelPackageGroupName=f"arn:aws:sagemaker:{region}:{aws_account_id}:model-package-group/{ModelPackageGroup}")
list_model_packages_response
latest_model_version_arn = list_model_packages_response["ModelPackageSummaryList"][0][
"ModelPackageArn"
]
print(latest_model_version_arn)
modelpackage=client.describe_model_package(ModelPackageName=latest_model_version_arn)
modelpackage
artifact_path=modelpackage['InferenceSpecification']['Containers'][0]['ModelDataUrl']
artifact_path
# Add model if approved
if list_model_packages_response["ModelPackageSummaryList"][0]['ModelApprovalStatus']=="Approved":
model_artifact_name='<model_name>.tar.gz'
mme.add_model(model_data_source=artifact_path, model_data_path=model_artifact_name)
然后,我们可以使用以下函数列出已添加的模型:
list(mme.list_models())
# Output we'd see if we added the following two models
['modela.tar.gz','modelb.tar.gz']
要删除模型,可以在控制台中导航到关联的 s3 目录,并删除其中的任何一个;它们会在重新列出可用模型时消失。
模型在部署的端点中可以通过以下代码进行调用:
response = runtime_sagemaker_client.invoke_endpoint(
EndpointName = "<endpoint_name>",
ContentType = "text/csv",
TargetModel = "<model_name>.tar.gz",
Body = body)
在首次调用模型时,端点将加载目标模型,从而导致额外的延迟。在模型已加载的未来调用中,推理将立即获得。在 多模型端点开发指南 中,AWS 指出,当端点达到内存使用阈值时,最近未调用的模型将被“卸载”。模型将在下次调用时重新加载。
当现有模型工件通过 mme.add_model() 或在 s3 控制台中被覆盖时,部署的端点不会立即反映变化。为了强制端点在下次调用时重新加载最新的模型工件,我们可以使用一个技巧,即通过任意新的端点配置更新端点。这将创建一个新端点,需要加载模型,并安全地管理旧端点和新端点之间的过渡。由于每个端点配置需要唯一的名称,我们可以添加带有日期/时间戳的后缀。
# Get datetime for endpoint configuration
time=str(datetime.now())[0:10]+'--'+str(datetime.now())[11:13]+'-'+'00'
time
# Create new endpoint config in order to 'refresh' loaded models to account for new deployments
create_endpoint_config_api_response = client.create_endpoint_config(
EndpointConfigName=f"""<endpoint name>-{time}""",
ProductionVariants=[
{
'VariantName': model_name,
'ModelName': model_name,
'InitialInstanceCount': 1,
'InstanceType': instance_type
},
]
)
# Update endpoint with new config
response = client.update_endpoint(
EndpointName=endpoint_name,
EndpointConfigName=f"""{model_name}-{time}""")
response
运行此代码后,你会看到关联的端点在控制台中将显示“更新中”状态。在此更新期间,之前的端点将可用,且在新端点准备好后将被交换,此时状态将调整为“服务中”。新添加的模型将在下次调用时加载。
我们现在已经构建了 CI/CD 解决方案所需的三个笔记本 —— 数据准备、训练/评估和端点更新。然而,这些文件目前仅存在于训练 AWS 账户中。我们需要调整第三个笔记本,以便在任何部署的 AWS 账户中工作,在这些账户中将创建/更新相应的端点。
为此,我们可以添加基于 AWS 账户 ID 的条件逻辑。新的 AWS 账户也需要 s3 存储桶来存放模型工件。由于 s3 存储桶名称在 AWS 中需要唯一,因此可以使用这种条件逻辑。此外,这也可以应用于调整端点实例类型以及添加新模型的条件(即审批状态)。
# Get AWS Account ID
aws_account_id = boto3.client("sts").get_caller_identity()["Account"]
aws_account_id
# Set Bucket & Instance Type Across Accounts
if aws_account_id=='<insert AWS Account_ID 1>':
bucket='<insert s3 bucket name 1>'
instance_type='ml.t2.medium'
elif aws_account_id=='<insert AWS Account_ID 2>':
bucket='<insert s3 bucket name 2>'
instance_type='ml.t2.medium'
elif aws_account_id=='<insert AWS Account_ID 3>':
bucket='<insert s3 bucket name 3>'
instance_type='ml.m5.large'
training_account_bucket='<insert training account bucket name>'
bucket_path = 'https://s3-{}.amazonaws.com/{}'.format(region,bucket)
最初创建和部署 MultiDataModel 的步骤需要在每个新的部署账户中重复进行。
现在我们有一个可以引用 AWS 账户 ID 并可以跨不同 AWS 账户运行的工作笔记本,我们将希望设置一个包含这个笔记本(以及可能包含其他两个用于传承追踪)的 git 仓库,然后在这些账户的 SageMaker Studio 域中克隆该仓库。幸运的是,借助 Studio/Git 集成,这些步骤非常直接/无缝,并在以下 文档 中进行了说明。根据我的经验,建议在 SageMaker Studio 外部创建仓库,并在每个 AWS 账户域中克隆它。
对于笔记本的任何未来更改可以在训练账户中进行,并推送到仓库。然后,通过拉取更改,可以在其他部署账户中反映这些更改。确保创建一个 .gitignore 文件,以便仅考虑这三个笔记本,而不是日志或其他文件;其传承将被追踪到 s3。进一步地,应当认识到每次运行笔记本时,控制台输出都会发生变化。为了避免在其他部署账户中拉取文件更改时发生冲突,必须在拉取最新更新之前恢复自上次拉取以来的文件更改。
安排重新训练/重新部署笔记本以按设定的节奏运行
最后,我们可以安排在训练账户中同时运行所有三个笔记本。我们可以使用新的 SageMaker Studio notebook jobs 功能来实现这一点。调度应被视为环境/账户依赖的——即在部署账户中我们可以创建单独的笔记本作业,但现在只是为了用最新的模型更新端点,并在新批准的模型自动部署到沙盒、QA 和生产账户之间提供一些滞后时间。其优点是,解决方案发布后唯一的手动部分变成了在注册表中的模型审批/拒绝。如果新部署的模型出现问题,可以在注册表中拒绝/删除该模型,然后手动运行端点更新笔记本,以恢复到之前的生产模型版本,从而争取进一步调查的时间。在这种情况下,我们将管道设置为按设定时间间隔运行(例如每月/每季度),尽管此解决方案可以调整为基于条件工作(例如数据漂移或生产模型准确度下降)。
结束语
CI/CD 目前是机器学习运维领域的热门话题。这是合理的,因为很多时候对于机器学习解决方案在初次部署后的持续性考虑较少。为了确保生产机器学习解决方案对协变量漂移具有鲁棒性,并且能够在长时间内可持续,需要一个简单而灵活的 CI/CD 解决方案。幸运的是,AWS 在其 SageMaker 生态系统中发布了一系列新功能,使这种解决方案成为可能。本文展示了一种成功实现这一目标的路径,适用于各种量身定制的机器学习解决方案,只需进行一次手动模型验证步骤。
感谢阅读!如果你喜欢这篇文章,关注我以获取我新帖的通知。同时,欢迎随时分享你的评论/建议。
数据处理应用程序的 CI/CD 管道在 Azure 上 第一部分:容器实例

图像由 AI Comic Factory 生成:huggingface.co/spaces/jbilcke-hf/ai-comic-factory
使用 GitHub Actions 部署 Docker 容器的逐步指南
·
关注 发表在 Towards Data Science ·7 分钟阅读·2023 年 10 月 31 日
--
介绍
手动创建和部署 Azure 及其他云提供商的资源相对简单,在某些情况下可能足够。然而,更多时候,已部署的资源需要随着时间的推移而更改,这反过来需要大量额外的工作来维护和重新部署资源。为了自动化这些任务,开发人员和数据专业人员可以使用基础设施即代码(IaC)方法,并创建用于持续集成和部署(CI/CD)的管道。这种方法使开发人员能够编写代码,自动定义和重新部署资源。
在这份逐步指南中,我们将构建用于数据处理应用程序的管道,以执行以下任务:
-
配置一个容器注册表
-
构建并推送 Docker 镜像到注册表
-
创建一个运行数据处理工作负载的容器实例
-
启用“托管身份”访问 Azure Key Vault,这样我们的应用程序可以检索到其他资源的访问密钥,例如存储帐户
-
将上述资源部署到测试环境和生产环境中,并使用不同的触发器来运行管道
入门
为了演示,应用程序本身包括一个非常简单的 R 脚本,该脚本加载一个数据集,打印前几行,并将数据集返回到存储帐户。请记住,应用程序代码对其余的管道并不重要,可以轻松地用你自己的代码替代。
要开始使用,你需要一个 Azure 帐户。你也可能想要在本地系统上安装 Azure CLI。不过,你也可以选择通过 Azure 门户中的 Cloud Shell 运行 Azure CLI 命令。
由于我们的应用程序将数据传输到 Azure Blob 存储并返回,你可能会发现安装 Azure Storage Explorer 很有用,它可以稍微简化文件上传和验证应用程序是否正确运行及返回处理过的数据的过程。
步骤 1:克隆仓库并设置静态资源。
首先,你需要克隆这个仓库。README 文件详细说明了如何使用 RStudio 执行此操作,但你可以自由选择你喜欢的 IDE。
接下来,使用 Azure 门户创建以下资源:
-
一个资源组,将包含所有其他资源。
-
一个带有两个文件夹的存储帐户 Blob 容器:一个用于输入文件,另一个用于输出文件。在我们的案例中,这两个文件夹应分别命名为“input”和“output”。在输入容器中存储一个名为“input_data.csv”的小数据集。
-
一个密钥库,你需要将存储帐户的主要访问密钥作为机密存储在其中。
在第 3 步中,你需要密钥库的名称以及包含主要访问密钥的机密名称。
步骤 2:将 GitHub 链接到 Azure
为了更新 Azure 资源,我们需要授予 GitHub 权限。
首先,使用 Azure CLI 登录到你的 Azure 账户。
az login
然后从 JSON 输出中复制 id 值,这是订阅 id。将订阅 id 粘贴到下面的命令中并运行。这会创建一个具有基于角色的访问控制的“服务主体”,可以视作在使用 GitHub Actions 工作流部署或更新资源时代表你的用户。
az ad sp create-for-rbac \
--name "your-app-name" \
--role Owner \
--scopes /subscriptions/<your-subscription-id>/resourceGroups/<your-resource-group-name> \
--sdk-auth
复制整个 JSON 输出,前往你的 GitHub 仓库,点击“设置” > “秘密和变量” > “操作”。
创建一个新的仓库秘密并将其命名为 AZURE_CREDENTIALS。将上述命令的 JSON 输出粘贴进去并保存。
第 3 步:修改脚本
在这种情况下,我们正在部署一个简单的 R 脚本,它的功能不多。因此,Dockerfile 也保持非常简单。这两个文件显然需要根据你的要求和首选编程语言进行修改。然而,如果你对这方面不太熟悉,最好先使用现有代码将你的流水线启动运行,然后再应用自己的代码。
如果你选择继续使用当前的 R 脚本 (script.R),你只需修改 {keyvault-name}、{access-key-name} 和 {storage-account-name} 的值(省略括号)。
接下来,修改 .github/workflows/ 目录下两个工作流文件 workflow.yml 和 workflow_release_prod.yml 中 env: 下的以下值:
env:
RESOURCEGROUP_NAME: my-rg
REGISTRY_NAME: my-cr
SHORT_NAME: mycr
ACI_NAME: my-ci-test
KV_NAME: my-kv
ENVIRONMENT: test
CPU: 1
MEMORY: 1.5
第 4 步:运行流水线和容器实例
当所有相关更改都已提交到“main”分支时,你应该可以在“操作”窗格下看到你的流水线运行。这是因为工作流设置了分支触发器,使其在主分支更新时运行。
如果没有遇到任何错误,你的容器实例应该在大约十分钟内准备好运行。前往 Azure 门户,找到你的新容器实例并点击“启动”。在日志窗格中,你可能会看到你的脚本在控制台中运行。完成后,验证是否在你的 blob 容器的“output”文件夹中出现了一个名为output_data.csv的新 cv 文件。
就这样!如果愿意,你现在可以手动触发第二个工作流,以创建一个用于生产工作负载的相同容器实例。
要了解 CI/CD 流水线中发生了什么,请阅读以下部分。
理解工作流逻辑
workflow.yml 文件定义了我们流水线中的五个步骤或作业,这些步骤将资源部署到测试环境。

作者提供的图片
首先,我们将之前设置的环境变量传递为 outputs,这些变量是后续步骤所需的。
vars:
runs-on: ubuntu-latest
outputs:
resource_group: ${{ env.RESOURCEGROUP_NAME }}
acr_name: ${{ env.REGISTRY_NAME }}
short_name: ${{ env.SHORT_NAME }}
aci_name: ${{ env.ACI_NAME }}
kv_name: ${{ env.KV_NAME }}
environment: ${{ env.ENVIRONMENT }}
cpu: ${{ env.CPU }}
memory: ${{ env.MEMORY }}
steps:
- run: echo "Exposing env vars"
在第二步中,我们创建或更新现有的容器注册表。请注意,needs 键指示此步骤必须等待前一步完成。uses 键告诉我们此步骤使用了另一个文件,而 with 键用于传递所需的值。我们还需要传递或设置存储库的密钥。
deploy-acr:
needs: vars
uses: ./.github/workflows/deploy_acr.yml
if: github.ref == 'refs/heads/main'
with:
environment: ${{ needs.vars.outputs.environment }}
resource_group: ${{ needs.vars.outputs.resource_group }}
acr_name: ${{ needs.vars.outputs.acr_name }}
secrets:
azure_credentials: ${{ secrets.AZURE_CREDENTIALS }}
在用于此步骤的 deploy_acr.yml 文件顶部,我们看到脚本在工作流中每次调用时运行,以及我们在 workflow.yml 文件中提供的所需输入。
on:
workflow_call:
inputs:
environment:
required: true
type: string
resource_group:
required: true
type: string
acr_name:
required: true
type: string
secrets:
azure_credentials:
required: true
在 deploy_acr.yml 的底部,我们有一个多步骤过程运行三个预定义的操作。第一个操作检出代码库,然后我们使用我们创建并存储的服务主体凭据登录 Azure。最后,我们使用名为 azure/arm-deploy@v1 的操作来部署容器注册表。请注意,此步骤使用了 Bicep,这是一个用于配置和部署 Azure 资源的流行语言。在本文底部,你可以找到一些极好的资源来进一步了解 Bicep。
jobs:
deploy-acr:
name: Deploy ACR
runs-on: ubuntu-latest
environment: ${{ inputs.environment }}
steps:
- uses: actions/checkout@v2
- uses: azure/login@v1
with:
creds: ${{ secrets.azure_credentials }}
- name: Deploy bicep
uses: azure/arm-deploy@v1
with:
resourceGroupName: ${{ inputs.resource_group }}
template: bicep/acr.bicep
parameters:
acrName=${{ inputs.acr_name }}
acrSku=Basic
failOnStdErr: false
然后,在第三步中使用名为 build_push_container.yml 的文件构建并推送 Docker 镜像到注册表,该文件运行 Azure CLI 命令以检索容器注册表的凭据,以及 Docker 命令以构建和推送 Docker 镜像。
在第四步中,容器实例根据我们的 Docker 镜像进行配置。此步骤通过名为 deploy_aci.yml 的文件进行,该文件使用了名为 ‘azure/aci-deploy@v1’ 的预定义操作。
在最后一步中使用 kv_access.yml 文件,我们授权容器实例通过“受管身份”访问密钥库,这意味着容器可以直接从密钥库中检索密钥,而无需使用访问密钥。为实现这一点,我们需要使用 Azure CLI 命令 az container create 更新已部署的容器实例,并提供我们之前使用的各种参数。此外,我们提供了以下设置:
— assign-identity — scope ${{ steps.rg_id_step.outputs.rg_id }}
最后,你可能已经注意到 workflow.yml 中的以下行:
on:
push:
branches:
- main
workflow_dispatch:
这些行指示了我们的管道应在何时以及在什么条件下运行。在我们的场景中,我们希望当更改推送到‘main’分支时管道运行。此外,我们还希望能够手动运行它。这通过添加 workflow_dispatch: 实现。在 workflow_prod_release.yml 文件中定义的生产管道中,你会注意到生产发布只有一个手动触发。还有许多其他方法可以配置管道运行的触发方式。例如,你可以忽略特定文件或文件夹中的更改,以便只有对应用程序代码的更改会触发新的部署。
进一步阅读
如果你想了解更多关于 GitHub Actions 和 Bicep 的内容,我强烈推荐 MS Learn 平台上的以下资源:
GitHub Actions
learn.microsoft.com/zh-cn/training/modules/introduction-to-github-actions/
learn.microsoft.com/zh-cn/training/modules/learn-continuous-integration-github-actions/
learn.microsoft.com/zh-cn/training/modules/github-actions-automate-tasks/
learn.microsoft.com/zh-cn/training/modules/github-actions-ci/
learn.microsoft.com/zh-cn/training/modules/github-actions-cd/
Bicep:
learn.microsoft.com/zh-cn/training/paths/fundamentals-bicep/
learn.microsoft.com/zh-cn/training/paths/bicep-github-actions/
类别不平衡与重采样:正式介绍
让我们深入探讨类别不平衡问题,以及诸如随机过采样等重采样方法如何尝试解决这一问题。
·
关注 发表在 Towards Data Science ·6 分钟阅读·2023 年 10 月 7 日
--
最近,我在构建一个名为Imbalance.jl的 Julia 包来解决类别不平衡问题。在构建该包时,我花费了大量的精力阅读论文和查看实现,因此我认为分享我对类别不平衡问题的理解以及一些用于解决该问题的流行算法可能会有所帮助。这些算法包括朴素随机过采样、随机过采样示例(ROSE)、随机游走过采样(RWO)、合成少数类过采样技术(SMOTE)、SMOTE-名义型、SMOTE-名义型连续以及许多欠采样方法。对于这个故事,我们将正式定义类别不平衡问题,并讨论随机过采样作为有效解决方案。在后续文章中,我们将通过考虑其他技术理性地得出更多的解决方案。
目录
∘ 类别不平衡问题
∘ 解决类别不平衡问题
∘ 随机过采样
∘ 为什么要过采样?
∘ 欠采样

照片由 Artem Kniaz 在 Unsplash 提供
类别不平衡问题
大多数(如果不是全部)机器学习算法可以被视为经验风险最小化的一种形式,其中目标是找到参数θ,使得某些损失函数L最小化:

例如,线性回归将L定义为平方损失,逻辑回归将其定义为交叉熵损失,SVM 将其定义为铰链损失,自适应提升将其定义为指数损失。
基本假设是,如果f_θ的参数允许它在数据集上最小化该经验风险,而数据集可以视为从总体中随机抽取的样本,那么它应该足够接近我们寻求的目标函数f,我们寻找的模型是使数据集以及整个总体上的相同量最小化。
在具有 K 个类别的多类别设置中,我们可以将经验风险写作

当一些类别的样本远少于其他类别时,就会发生类别不平衡问题。在这种情况下,对应于这些类别的术语对总和的贡献最小,这使得任何学习算法都可能找到一个近似解来最小化经验风险,而这个解主要只在显著的和上进行最小化。这会产生一个可能与真实目标f相差很大的假设f_θ,而这些少数类可能是所讨论应用中最重要的。
总结来说,以下是我们有类别不平衡问题的条件:
1 — 训练集中点在各个类别之间的分布并不“公平”;某些类别的点远少于其他类别。
2 — 模型在训练后对这些少数类别的点表现不佳。也就是说,学习算法没有像上面解释的那样适当地最小化少数类别的损失。
这个问题的严重程度取决于这些少数类别对应用程序的重要性。它们往往比多数类别更重要(例如,分类欺诈交易或稀有疾病)。
解决类别不平衡问题
从问题的描述中很明显,一个解决办法是给较小的类别(即少数类别)加权,以便学习算法更容易避免利用它们的不重要性来获得近似解。通常可以很容易地为此目的修改机器学习算法;特别是当它们明确是经验风险最小化的一种形式,而不仅仅是对于某些损失函数的等效形式时。
另一种试图解决问题而不需要对学习算法进行任何修改的方法是重新采样数据。最简单的形式可以看作是赋权的成本敏感方法。考虑以下算法:
给定: 一个不平衡的数据集,包含 K 个类别和每个类别的整数
需求: 一个数据集,其中每个类别的数据根据整数进行复制
操作: 将类别 k 中的每个点重复 c 次,其中 c 是与该类别相关的整数
插入总和时应该很明显,这相当于成本敏感方法;回想一下,最小化一个函数等同于最小化它的标量正倍数。
随机过采样
上述算法存在一个小问题;如果类别 A 有 900 个示例,而类别 B 有 600 个示例,则没有整数倍数可以用来过采样类别 B,使数据集达到完全平衡。我们可以通过随机选择点进行复制,扩展算法以处理非整数的复制比例。例如,如果我们想对类别 B 进行 300 个示例的过采样,以使系统达到平衡(相当于比例为 1.5),我们可以通过...
1 — 从类别 B 中随机选择 300 个点
2 — 复制这些点
这个算法称为朴素随机过采样,它正式的做法是:
1 — 计算每个类别所需生成的点数(根据给定的比例计算)
2 — 假设对于类别 k,这个数字是N_k,然后从属于该类别的点中随机选择N_k个点,并将它们添加以形成新的数据集
显然,这在平均上等同于前述算法,因此也等同于类别权重。如果类别 k 的比例为 2.0,则每个点平均会被随机选择一次。
这是我为三个类别(0、1、2)生成的随机不平衡数据集,展示了类别的直方图以及过采样前后的点的散点图。

作者图
请注意,下方的两个图形没有视觉上的差异,因为所有生成的点都是现有点的复制品。
为什么要过采样?
你可能还不太相信重新采样比类别权重在解决类别不平衡问题上更为突出;毕竟,我们已经展示了,朴素的随机过采样可以看作是平均上等同于类别权重的,只是训练时间更长(由于数据量增加)。这个论点唯一的问题是,一般情况下不可能使用类别权重;尤其是,当机器学习模型没有明确最小化损失时。
如果我们通过在数据集中自然地添加每个类别的点来收集更多数据,那么我们获得比朴素随机过采样(或类别权重)更好的结果是有道理的。例如,假设……
-
我们希望检测一个交易是否为欺诈交易。
-
我们收集了一个包含 1K 欺诈交易和 999K 有效交易的数据集。
显然,通过收集另外 998K 欺诈交易来解决不平衡问题比重复现有的 1K 交易 997K 次要好得多。特别是,在后者情况下,我们有很高的过拟合特定数据的风险。
现实显然是,通常情况下不可能为少数类收集更多数据;然而,可能可以模拟这样做的效果。也许,如果我们合成生成的数据对于少数类足够代表真实的新例子,那么理想情况下,这比重复例子或类别权重要好得多。这是最常见的过采样形式,其与收集真实数据的关系可能是其优于随机过采样和类别权重的最合理理由;因此,成为解决类别不平衡问题的值得方法。
欠采样
最后请注意,如果我们不是随机选择少数类中的例子来复制,而是随机选择多数类中的点进行删除,那么算法变成了朴素随机欠采样。这显然有丧失有用数据的缺点,但有时从“不是特别有用”的多数类中删除数据可以解决不平衡问题,并导致在“更有用”的少数类上的表现更好。还有其他欠采样方法,可以更仔细地选择排除哪些点,例如为了保持数据结构或允许更好的决策边界。
在进一步的故事中,我们解释了最初在Imbalance.jl中实现的所有算法及其如何解决类别不平衡问题。
类别不平衡:探索欠采样技术
让我们了解一下欠采样及其如何帮助解决类别不平衡问题
·
关注 发表在 Towards Data Science ·5 分钟阅读·2023 年 10 月 7 日
--
我们已经正式解释了之前类别不平衡的影响及其原因,并解释了几种能够解决这一问题的过采样技术,如随机过采样、ROSE、RWO、SMOTE、BorderlineSMOTE1、SMOTE-NC 和 SMOTE-N。在这个故事中,我们将尝试对欠采样技术进行类似的探讨,同时假设在之前的解释中,欠采样如何帮助解决不平衡问题已经显而易见。
目录
∘ 简介
∘ 朴素随机欠采样
∘ K 均值欠采样
∘ Tomek 链下采样
∘ 编辑最近邻下采样
介绍
下采样技术通常分为两大类:控制型和非控制型。在控制型技术中,算法接收一个数字,指示最终数据集中应有多少样本;而在非控制型技术中,下采样通常通过简单地删除满足某些条件的点来进行。事先无法知道有多少点会满足这些条件,显然也无法控制。在这个故事中,我们将介绍两种控制型下采样技术(随机和 K-Means 下采样)以及两种非控制型下采样技术(Tomek 链和编辑最近邻)。
天真随机下采样
在这种技术中,如果给定要从类 k 中移除 N_k 点,则会从该类中随机选择 N_k 个点进行删除(也可以随机选择要保留的点,以便移除 N_k 个点)。
下面展示了一个在有三个类 0、1 和 2 的数据中下采样两个主要类的示例。

作者使用 Julia 的 Imbalance.jl 包绘制的图
下面是一个动画,展示了不同下采样程度下的输出

作者使用 Julia 的 Imbalance.jl 包制作的动画
注意这是一个完全随机的过程;没有做出关于保留哪些点的具体选择。由于此过程,数据的分布可能会严重改变。
K-Means 下采样
我们可以通过更仔细地选择要移除(或保留)的点来保持数据的分布。在 K-Means 下采样中,如果需要为类 k 保留 N_k 点,则执行 K=N_k 的 K-Means 算法,得到 N_k 个最终中心点。K-Means 下采样让这些中心点(或每个中心点的最近邻;这是一个超参数)成为最终的 N_k 个点。由于中心点本身保留了数据的分布,这导致保留数据分布的点集更小。
下面展示了一个在有三个类 0、1 和 2 的数据中下采样两个主要类的示例。

作者使用 Julia 的 Imbalance.jl 包绘制的图
注意,相比于随机下采样,这种方法在保持数据结构方面更为小心,这在更多下采样的情况下尤为明显。我们用一个动画进一步说明这一点:

作者使用 Julia 的 Imbalance.jl 包制作的动画
注意中心点依赖于初始化,通常涉及随机性。
Tomek 链下采样
这是一种非控制型下采样技术,如果一个点是 Tomek 链的一部分,则可以被删除。如果两个点形成 Tomek 链,则:
-
它们属于不同的类别
-
每两个点都是彼此的最近邻
这里的理由是,这些点无法帮助改进决策边界(例如,可能更容易导致过拟合),并且它们可能是噪声。以下是应用托梅克链接的示例:

作者使用 Julia 中的 Imbalance.jl 包制作的图
注意下采样后如何更容易找到更线性的决策边界,并且这也使数据更好地平衡。在此过程中,我们跳过了绿色的少数类下采样,并在一个类别的点数接近时停止下采样。
要更近距离地观察这种情况,其中所有类别最终都被下采样,可以参考以下动画:

作者使用 Julia 中的 Imbalance.jl 包制作的动画
编辑最近邻下采样
尽管托梅克链接大多是那些无法形成更好决策边界的点或噪声点,并不是所有的噪声点都会形成托梅克链接。如果类别k_1中的一个噪声点存在于类别k_2中的一个密集区域内,那么噪声点的最近邻的最近点可能不是噪声点,这意味着它不会形成托梅克链接。与此条件相比,编辑最近邻下采样默认保留一个点,只有当它的大多数邻居来自同一类别时。也可以选择仅在所有邻居都来自同一类别时保留点,或者在存在来自同一类别的邻居时进行最小下采样。
这个动画展示了算法的实际操作:

作者使用 Julia 中的 Imbalance.jl 包制作的动画
注意它如何清除更多不利于决策边界或是噪声的点。如果邻居数 k 或保留条件以正确的方式改变,还可以进一步清除。这是另一种说明效果的动画。

作者使用 Julia 中的 Imbalance.jl 包制作的动画
“模式”和“仅模式”条件之间的区别在于,前者保留一个点,只有当它的类别是邻居中最常见的类别之一时;而后者仅在它的类别是唯一最常见的类别时保留点。
这结束了我们对一些有趣的下采样算法的介绍。希望这有助于你更多地了解受控和非受控的下采样。下次见,再见。
参考文献:
[1] 魏超,李志锋,夏汉,& 井尚杰。(2017)。基于聚类的类别不平衡数据下采样。信息科学,409–410,17–26。
[2] 伊万·托梅克。cnn 的两种修改。IEEE 系统、人类与控制论杂志,6:769–772,1976 年。
[3] Dennis L Wilson. 使用编辑数据的最近邻规则的渐近性质。IEEE 系统、人与控制论汇刊,页面 408–421,1972 年。
类不平衡:ROSE 和随机游走过采样(RWO)
让我们正式定义类不平衡问题,并直观地推导解决方案!
·
关注 发布于 Towards Data Science ·5 分钟阅读·2023 年 8 月 29 日
--
在这一系列故事中,我们解释了用于处理类别不平衡的各种重采样技术;特别是那些最初在Imbalance.jl Julia 包中实现的技术,包括朴素随机过采样、随机过采样示例(ROSE)、随机游走过采样(RWO)、合成少数类过采样技术(SMOTE)、SMOTE-名义型、SMOTE-名义型连续以及许多欠采样技术。在这个故事中,我们将假设对类别不平衡问题有一定了解,如之前正式解释的,并解释两种有助于解决这一问题的有趣算法;即 ROSE 和 RWO。
目录
∘ 随机过采样示例(ROSE)
∘ 随机游走过采样(RWO)
随机过采样示例(ROSE)
我们知道,我们收集的任何额外数据都遵循少数类的基础数据分布,那么如何近似这种概率分布,然后从中采样以模拟收集真实示例呢?这就是随机过采样示例(ROSE)算法的作用。
因此,ROSE 试图为每个类别 k 估计概率分布P(x|y=k),然后从中抽取所需的N_k样本。众所周知,估计这种密度的一种方法是通过核密度估计,你可以从更粗糙的版本如直方图分析中推导或直观理解。以下描述了 KDE:
给定: 数据点x 所需: P(x)的估计 操作: 选择一个核函数K(x),然后将P(x)估计为

通常,我们希望能够控制核函数的尺度(即,压缩或扩展),因为这可以改善对P(x)的估计,因此一般而言,我们有

从本质上讲,它是将核函数放置在每个点上,然后对它们进行求和和归一化,使其积分为 1。

将 KDE 应用于分布估计,作者:Drleft,维基媒体公有领域(许可证)
核函数本身是一个超参数;可能,一个已被证明不那么重要的超参数,只要它满足平滑性和对称性等基本属性。简单的高斯核,尺度为σ,是一个常见的选择,ROSE 也使用这种 KDE。

正态分布公式中的标准差作为 h,因此未写出
ROSE 从对任何类别 k 的分布估计中抽取N_k点(生成P(x|y=k)),具体操作如下:
-
随机选择一个点
-
在其上放置高斯核
-
从高斯分布中抽取一个点
这与随机过采样类似,只不过在随机选择一个点后,它在该点上放置一个高斯分布,并从高斯分布中采样生成新点,而不是重复选择的点。
在这个过程中,ROSE 使用一个称为 Silverman 的经验规则来设置带宽 h(或者在更高维度中是平滑矩阵,即正态分布中的协方差矩阵参数),以便均值积分平方误差最小化。特别地,

其中 D_σ是每个特征标准差的对角矩阵,d 是特征的数量,N 是点的数量。
在 Imbalance.jl 包中,这个值乘以另一个常数s以允许可选控制超参数。对于 s=1,保持原样;对于 s=0,ROSE 等同于随机过采样。以下使用该包生成的动画展示了增加 s 对生成点的影响。

作者所绘图。合成点以钻石形状显示。
注意,随着 s 的增加,由每个随机选择的原始点生成的合成点与其距离变得更远。
随机游走过采样(RWO)
当中心极限定理适用时,可以证明只需n个样本,就可以以 95%的概率满足均值估计 x̄=µ±1.96**σ/√n。假设少数类的σ较小(例如,σ=10),并且我们收集了 n=1000 个样本,那么 x̄=µ±0.6的概率为 95%,换句话说,如果我们的估计 x̄足够大,那么基本上 x̄≈µ。类似的论点可以用来证明估计的标准差S_x*将非常接近σ;至少在数据是正态分布的情况下。
这表明可以生成合成数据,以使估计的 x̄和S_x得以保持,以模拟收集新数据。可以证明,如果属于某类 k 的数据中的特征是独立的,其中 x̄是该类所有点的均值,S_x是它们的标准差(都是向量),那么如果通过应用

x 是属于某个类别的点;当对所有点进行 k 次这种操作时,会引入 k*n 个新样本。
然后渐近地,这些点不会改变原始的x̄和S_x。对于x̄很容易看到,因为当生成的点数 N_k=kn 非常大时,会有sum(x_new)=ksum(x),因为sum(r)=0。因此,新数据集中包含旧例子和新例子的x̄是x̄(k+1)/(k+1)=x̄*,所以均值得以保持。同样可以证明标准差也得以保持。

动画由作者制作。钻石形状代表合成生成的点。
动画中展示的实现方法适用于任何比例,通过从需要过采样的类别中随机选择点,而不是遍历所有点。

作者绘图
该方法称为随机游走过采样,因为在图中看起来像是通过在现有点上随机行走并放置新点来生成这些点。
希望这个故事能让你对机器学习中的类别不平衡问题及其解决方法有更多了解。让我们考虑 SMOTE 论文中提出的算法,以便了解下一个故事。
参考文献: [1] G Menardi, N. Torelli, “使用不平衡数据训练和评估分类规则,” 数据挖掘与知识发现,28(1),第 92–122 页,2014 年。
[2] Zhang, H., & Li, M. (2014). RWO-采样:一种用于不平衡数据分类的随机游走过采样方法。信息融合,25,第 4–20 页。
类别不平衡:从 SMOTE 到 BorderlineSMOTE1、SMOTE-NC 和 SMOTE-N
探索四种算法以解决类别不平衡问题
·
关注 发表在 Towards Data Science ·12 分钟阅读·2023 年 8 月 30 日
--
在前一篇文章中,我们解释了朴素随机过采样、随机过采样示例(ROSE)、随机游走过采样(RWO)算法的工作原理。更重要的是,我们还定义了类别不平衡问题,并直观地推导了其解决方案。我强烈推荐查看这篇文章,以确保对类别不平衡有清晰的理解。
目录
∘ 简介
∘ SMOTE(合成少数类过采样技术)
∘ BorderlineSMOTE1
∘ SMOTE-NC(SMOTE-名义连续)
∘ SMOTE-N (SMOTE-名义)
介绍
在这个故事中,我们将继续考虑 SMOTE、BorderlineSMOTE1、SMOTE-NC 和 SMOTE-N 算法。但在此之前,值得指出的是我们在上一个故事中考虑的两个算法适合以下实现框架:
-
定义算法如何对属于类 k 的数据进行 N_k 个样本的计算,并通过过采样生成这些样本
-
给定一些比率超参数,计算每个类需要添加的点数
-
对每个类运行算法,然后将所有新增点与原始数据结合起来,形成最终的过采样数据集
对于随机过采样和 ROSE 算法(如果比率足够大,可能也包括随机游走过采样),生成 N_k 个类 k 示例的算法也做了以下操作:
-
从属于类 k 的数据中随机选择 N_k 个点,允许重复。
-
对每个选择的点执行逻辑以生成新点(例如,复制或放置高斯分布然后从中采样)
我们将在本故事中考虑的其余算法也适合相同的框架。
SMOTE (合成少数类过采样技术)
因此,要解释 SMOTE 的作用,我们只需回答一个问题:为了生成 N_k 个新的示例,对每个从类 k 随机选择的 N_k 个示例执行了什么逻辑?
答案如下:
-
找到该点的 k 最近邻(k 是算法的超参数)
-
随机选择其中一个
-
从点绘制到随机选择的邻居的线段
-
随机选择该线段上的一个点
-
将其作为新点返回
从数学上讲,
-
如果点 x_i 有最近邻 z_i1, z_i2, …, z_ik
-
如果 j 是 [1,k] 范围内的随机数
-
并且 r 是 [0, 1] 范围内的随机数
然后,对于每个点 x_i,SMOTE 通过简单应用生成一个新的点 x_i’:

这就是 SMOTE 算法所做的全部。从点 x_i 沿向量 z_ij — x_i 行进距离 r,然后放置一个新点。

图由作者绘制。黑色示例是合成生成的。
一个小的侧面说明是算法操作方式与论文中呈现的方式有少许差异。特别是,作者假设比率为整数(如果不是,则向下取整)。如果类别 k 的比率是整数 C,则对于每个点,在其内选择一个随机邻居,重复 C 次,然后应用我们描述的 SMOTE 逻辑。实际上,当实现 SMOTE 时,通常将其泛化为按我们描述的浮点比率工作,而是随机选择N_k个点,然后对每个点应用 SMOTE。对于整数比率如 C=2,平均每个点被选中两次,我们回到了原始算法。这应该是有道理的,因为这是从通过重复具有整数比率的过采样到随机过采样的相同转换,这在上个故事中已经解释过了。

作者制作的动画
此动画显示了在不平衡的鸢尾花数据集的偏好类别上改变过采样比例时,SVM 的决策区域如何变化。这里的比例相对于多数类的大小。也就是说,比率为 1.0 将N_k设置为使 versicolor 类具有与 virginica 类相同数量的示例。
你可能会想,为什么 SMOTE 比 ROSE 更好。毕竟,SMOTE 生成点的逻辑在论文中并未得到证明;与此同时,从P(x|y)的估计中进行采样,就像 ROSE 中那样,更为合理和直观。可能的一个问题是获取P(x|y)的良好估计需要大量数据;然而,我们知道少数类通常数据较少。如果我们没有大量数据,我们有两个选择:
-
选择带宽过小,我们回到了可能的过拟合,就像在随机过采样中一样。
-
选择带宽过大,在极端情况下等同于从特征空间中均匀添加随机点(即,不现实的示例)。
如果你仔细想想,我们应该更少地担心 SMOTE 中的这个问题。如果存在一个完全线性分离数据的超平面,那么应用 SMOTE 后这个解决方案仍然存在。事实上,SMOTE 生成点的方式可能使非线性超表面变得更加线性化,因此在导致模型过拟合的风险要低得多。
BorderlineSMOTE1
文献中有很多算法只是对 SMOTE 的修改或改进。一个流行的例子是 BorderlineSMOTE1,它应用以下修改:

在选择 x_i 时,不是随机从属于该类的所有点集中选择,而是随机从属于该类的边界(论文中称为 DANGER)点集中选择。 如果其 k 个最近邻的大多数,但不是全部,来自大多数/其他类,则该点是边界点。 逻辑上讲,靠近决策边界的点满足此条件,这些点也是决定决策边界或分类模型性能最重要的点。 忽略所有邻居都来自另一个类的点背后的理由是,这些可能是噪音。
另一个变种 BorderlineSMOTE2 在同一篇论文中描述,并且只是修改了用于查找边界点的条件。
以下是在 Julia 中使用 Imbalance.jl 包应用 BorderlineSMOTE1 的示例

作者绘制的图表
这是相应的动画:

SMOTE-NC (SMOTE-Nominal Continuous)
虽然 ROSE 和 SMOTE 似乎比朴素的随机过采样有显著改进,但它们的缺点是失去了处理分类变量的能力,这对于朴素的随机过采样来说并不是问题。 SMOTE 的作者们足够聪明,想出了一种方法来规避这个问题,通过开发这个 SMOTE 算法的扩展来处理同时存在分类特征的情况。
你可能认为编码分类特征可以绕过这个问题,然而这并不完全正确,因为 SMOTE 或 ROSE 会将它们视为连续的并为其生成无效值。 例如,如果一个特征是二进制的,那么沿线选择的点可能是 0.57,这不是 0 和 1. 将其四舍五入是一个坏主意,因为这等同于随机选择它是 0 还是 1。
回想一下,以下是 SMOTE 生成新点的方式:
-
假设点 x_i 有最近的邻居 z_i1, z_i2, …, z_ik
-
让 j 是 [1, k] 中的随机数
-
让 r 是 [0, 1] 中的随机数
对于每个点 x_i,SMOTE 通过简单地应用公式生成一个新点 x_i’

显然,我们无法在存在分类特征的情况下应用相同的方法,除非我们通过回答以下两个问题来扩展它
-
如何找到 k 个最近邻? 欧几里得距离度量只在连续特征上运行
-
新点是如何生成的?我们不能应用 SMOTE 方程来生成 x_i’ 的分类部分
对于第一个问题,作者建议对欧几里得距离进行修改,以考虑分类部分。假设每个x_i和z_ij涉及m个连续特征和n个分类特征,那么在修改后的度量中,连续特征自然被相减并平方,然后对于每对不同的分类特征添加一个常数惩罚。这个惩罚特别是所有连续特征方差的中位数,可以在算法开始时计算出来。
例如,要测量两个点x_1和x_2之间的距离

如果标准差的中位数是m,则距离由下式给出

最后的两个项考虑了最后两个分类特征的不同。
虽然作者没有对度量提供理由,但观察到在分类特征之间测量距离的最常见方法之一是汉明距离。它简单地对每对不同的分类特征加 1。汉明距离为 6 表示两个点在 6 个分类特征上有不同的值。在我们的案例中,将惩罚设置为 1(如同汉明距离中)并不直观,因为如果连续特征经常发生强烈变化,那么 1 的值在总和中将非常微不足道,这等同于忽略分类特征的测量。使用任何两个连续特征之间的平均平方差作为惩罚应该能够解决这个问题,因为如果连续特征的方差通常很大,惩罚也会很大且不可忽略。唯一的问题是作者使用了方差的中位数而不是均值,这可能是由于其对异常值的鲁棒性。
回答第二个问题要简单得多,现在我们已经使用修改后的度量找到 k 个最近邻,我们可以像往常一样使用 SMOTE 公式生成新点的连续部分。为了生成新点的分类部分,简单地取 k 个最近邻分类部分的众数是有意义的。即,让邻居对分类部分的值进行投票,其中最常见的值将占主导地位。
因此,SMOTE-NC 生成新点的过程是……
-
使用修改后的欧几里得度量找到点的 k 个最近邻(k 是算法的一个超参数)
-
随机选择其中一个
-
从该点到邻居在连续特征空间中画一条线段
-
随机选择线段上的一个点
-
让这条线段成为新点的连续部分
-
对于新点的分类部分,取 k 个最近邻分类部分的众数。
SMOTE-N (SMOTE-Nominal)
应该很明显,当没有分类特征参与时,SMOTE-NC 会变成 SMOTE,因为那时惩罚为零,生成中的模式步骤被跳过。然而,如果没有连续特征参与,则算法处于一种不稳定的状态,因为没有定义惩罚,因为没有连续特征。你的解决办法可能是将其设置为 1 或其他值,然后按正常方式操作算法,但这并不理想,因为在计算最近邻居时会容易出现许多平局。如果一个点与另 10 个点之间的汉明距离是 7,它们真的都同样接近那个点吗? 还是它们只是共享在 7 个特征上与该点不同的共同点?
SMOTE-N 是作者在论文中提出的另一种算法,用于处理纯粹分类的数据。它通过在分类特征上采用另一种距离度量来对上述斜体问题作出负面回应。一旦找到 k 个最近邻居,模式计算决定了新点;然而,这次点本身也参与模式计算(投票)。
因此,足以解释在 SMOTE-N 中用于执行 K-NN 的距离度量。该度量称为“修改值距离度量”(Cost & Salzberg, 1993),其操作如下:给定具有 q 个分类特征和每个分类特征分别有 p_1, p_2, …, p_q 个可能值的两个特征向量。
-
通过长度为 K 的向量 V 对每个分类值进行编码,其中 K 是类别的数量。V[i]应为该值在第 i 类中的频率除以其在所有类别中的频率。
-
现在任何分类向量都由 q 个长度为 k 的向量组成的张量表示。
-
通过计算每对长度为 k 的向量之间的曼哈顿距离,然后取结果的 L2 范数,来计算该张量所表示的任意两个分类向量之间的距离。
例如,假设我们要找出以下两个分类向量之间的距离

然后给定 3 个类别,编码后假设我们得到了

计算每对向量之间的曼哈顿距离后,我们得到

这在取 L2 范数后评估为 1.428。
精确地说,论文指出可以使用 L1 范数或 L2 范数来表示大小,但没有决定算法使用哪种(这里我们选择了 L2)。
你可能会问为什么这比使用普通的汉明距离更好。明确的答案是作者并没有做出合理解释。然而,仅为了引入一些直观的理解,我们之前争论了汉明距离在 KNN 的距离计算中可能经常导致许多平局。假设我们有三个分类向量

在这里,汉明距离会建议 x_2 和 x_3 距离 x_1 相同,因为两者的汉明距离都是 1。同时,修改值差异度量会先考虑每个值在各个类别中的分布,然后再决定哪个更接近。假设 B2 的类别频率为 [0.3, 0.2, 0.5],B3 为 [0.1, 0.9, 0],B1 为 [0.25, 0.25, 0.5]。在这种情况下,MVDM 会建议 x_3 更接近 x_1,因为 B1 比 B3 更接近 B2。从概率的角度来看,如果我们收集一个类别未知的新点,那么知道该类别是 B2 还是 B3 对于预测类别帮助不大,因此它们在这个意义上是相似的或可以互换的。
因此,总结来说,SMOTE-N 算法生成新点的过程如下:
-
使用修改值差异度量找到点的 k 近邻(k 是算法的超参数)
-
返回邻居的类别值的众数(包括点自身)以生成新点
就这样!现在你应该清楚了 SMOTE、BorderlineSMOTE1、SMOTE-N 和 SMOTE-NC 各自的工作原理。我们结束了对所有在 Julia 包 Imbalance.jl 中最初实现的重采样算法的解释系列,并附上了 这篇文章 关于欠采样的故事。
参考文献:
[1] N. V. Chawla, K. W. Bowyer, L. O. Hall, W. P. Kegelmeyer,“SMOTE:合成少数类过采样技术,”人工智能研究杂志,321–357,2002 年。
[2] H. Han, W.-Y. Wang, 和 B.-H. Mao,“Borderline-SMOTE:一种新的不平衡数据集学习过采样方法,”国际智能计算会议,2005 年,页码 878–8871
类别不平衡策略 — 带代码的视觉指南
了解随机欠采样、过采样、SMOTE、ADASYN 和 Tomek 链接
·
关注 发表在 Towards Data Science ·13 分钟阅读·2023 年 4 月 24 日
--
类别不平衡发生在分类问题中一个类别显著多于另一个类别。这在许多机器学习问题中很常见。例子包括欺诈检测、异常检测和医学诊断。
类别不平衡的诅咒
在不平衡的数据集上训练的模型在少数类别上表现不佳。在最佳情况下,这可能会导致业务损失,如客户流失分析。而在最糟糕的情况下,它可能会蔓延到面部识别系统的系统性偏见中。

一个平衡的数据集可能就是缺少的关键(来源:
Elena Mozhvilo 在 Unsplash)
处理类别不平衡的常见方法是重采样。这可能包括对多数类别进行过采样、对少数类别进行欠采样,或两者兼而有之。
在本文中,我使用生动的可视化和代码来说明处理类别不平衡的策略:
-
随机过采样
-
随机欠采样
-
使用 SMOTE 进行过采样
-
使用 ADASYN 进行过采样
-
使用 Tomek Link 进行欠采样
-
使用 SMOTE 过采样,然后使用 TOMEK Link 进行欠采样(SMOTE-Tomek)
我还将在真实世界数据集上使用这些策略,并评估它们对机器学习模型的影响。让我们开始吧。
所有源代码都在这里。
使用 Imbalance-learn
我们将使用 Python 中的imbalanced-learn包来解决类别不平衡的问题。这是一个开源库,依赖于 scikit-learn,并提供在处理不平衡类别分类时的工具。
要安装它,请使用以下命令。
pip install -U imbalanced-learn
数据集
我们使用的数据集是UCI 社区与犯罪数据集(CC BY 4.0),包含 1994 年美国社区的 100 个属性。我们可以用它来预测是否犯罪率高(定义为人均暴力犯罪超过 0.65)。数据来源于 UCI 机器学习库,由 La Salle 大学的 Michael Redmond 于 2009 年发布。
数据集中包含的变量涉及社区,如被视为城市的人口比例和家庭收入中位数,以及涉及执法,如人均警官数量和分配给毒品单位的警官比例。
此数据集存在类别不平衡。对于每一个高犯罪率社区,有 12 个低犯罪率社区。这非常适合我们的案例说明。
>>> from imblearn.datasets import fetch_datasets
>>> # Fetch dataset from imbalanced-learn library
>>> # as a dictionary of numpy array
>>> us_crime = fetch_datasets()['us_crime']
>>> us_crime
{'data': array([[0.19, 0.33, 0.02, ..., 0.26, 0.2 , 0.32],
[0\. , 0.16, 0.12, ..., 0.12, 0.45, 0\. ],
[0\. , 0.42, 0.49, ..., 0.21, 0.02, 0\. ],
...,
[0.16, 0.37, 0.25, ..., 0.32, 0.18, 0.91],
[0.08, 0.51, 0.06, ..., 0.38, 0.33, 0.22],
[0.2 , 0.78, 0.14, ..., 0.3 , 0.05, 1\. ]]),
'target': array([-1, 1, -1, ..., -1, -1, -1]),
'DESCR': 'us_crime'}
我们将把这个字典转换成 Pandas 数据框架,然后分割成训练-测试集。
# Convert the dictionary to a pandas dataframe
crime_df = pd.concat([pd.DataFrame(us_crime['data'], columns = [f'data_{i}' for i in range(us_crime.data.shape[1])]),
pd.DataFrame(us_crime['target'], columns = ['target'])], axis = 1)
# Split data into train test set
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(crime_df.drop('target', axis = 1),
crime_df['target'],
test_size = 0.4,
random_state = 42)
注意,我们将仅对训练数据集执行欠采样和过采样。我们将不会对测试集进行欠采样和过采样。
数据集预处理
我们的目标是可视化一个不平衡的数据集。为了在二维图中可视化 128 维的数据集,在训练集上进行以下操作。
-
缩放数据集,
-
对特征执行主成分分析(PCA),将 100 个特征转换为 2 个主成分,
-
可视化数据。
这是数据在 2D 中的可视化。

作者提供的图片
上述图的代码:
from sklearn.preprocessing import MinMaxScaler
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
# Scale the dataset on both train and test sets.
# Note that we fit MinMaxScaler on X_train only, not on the entire dataset.
# This prevents data leakage from test set to train set.
scaler = MinMaxScaler()
scaler.fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)
# Perform PCA Decomposition on both train and test sets
# Note that we fit PCA on X_train only, not on the entire dataset.
# This prevents data leakage from test set to train set.
pca = PCA(n_components=2)
pca.fit(X_train)
X_train_pca = pca.transform(X_train)
X_test_pca = pca.transform(X_test)
# Function for plotting dataset
def plot_data(X,y,ax,title):
ax.scatter(X[:, 0], X[:, 1], c=y, alpha=0.5, s = 30, edgecolor=(0,0,0,0.5))
ax.set_ylabel('Principle Component 1')
ax.set_xlabel('Principle Component 2')
if title is not None:
ax.set_title(title)
# Plot dataset
fig,ax = plt.subplots(figsize=(5, 5))
plot_data(X_train_pca, y_train, ax, title='Original Dataset')
预处理完成后,我们准备对数据集进行重采样。
策略 1. 随机过采样
随机过采样通过替换从少数类别中复制现有样本来增加例子。每个少数类别中的数据点有相同的复制概率。

作者提供的图片
这是我们如何在数据集上进行过采样的方法。
from imblearn.over_sampling import RandomOverSampler
# Perform random oversampling
ros = RandomOverSampler(random_state=0)
X_train_ros, y_train_ros = ros.fit_resample(X_train_pca, y_train)
让我们比较随机过采样前(左)和随机过采样后(右)的数据。

在 Github 上绘图的代码。图片来源:作者
唯一的区别?在随机过采样之后,少数类中的重叠数据点更多。因此,少数类的数据点看起来更暗。
策略 2:随机欠采样
相反,随机欠采样会从多数类中移除现有样本。多数类中的每个数据点被移除的机会是相等的。

图片来源:作者
我们可以用以下代码来实现这一点。
from imblearn.under_sampling import RandomUnderSampler
# Perform random sampling
rus = RandomUnderSampler(random_state=0)
X_train_rus, y_train_rus = rus.fit_resample(X_train_pca, y_train)
# Function for plotting is in Notebook.
# Insert link here.
让我们比较随机欠采样前(左)和后(右)的数据。

图片来源:作者
欠采样后,数据点的总体数量显著减少。这是因为在类平衡之前,随机移除多数类中的数据点。
将机器学习应用于欠采样和过采样数据集
让我们比较在上面三个数据集(未修改数据集、欠采样数据集和过采样数据集)上训练的分类机器学习模型(SVM 模型)的表现
在这里,我们在三个数据集上训练了三种支持向量机分类器(SVC):
-
原始数据
-
随机过采样的数据
-
随机欠采样的数据
from sklearn.svm import SVC
# Train SVC on original data
clf = SVC(kernel='linear',probability=True)
clf_ros.fit(X_train_pca, y_train)
# Train SVC on randomly oversampled data
clf_ros = SVC(kernel='linear',probability=True)
clf_ros.fit(X_train_ros, y_train_ros)
# Train SVC on randomly undersampled data
clf_rus = SVC(kernel='linear',probability=True)
clf_rus.fit(X_train_rus, y_train_rus)
# Function for plotting is in Notebook.
# Insert link here.
然后,我们可以可视化每个 SVC 从数据集中学到的内容。

图片来源:作者
上面的图总结了算法从数据集中学到的内容。特别是,它们学到了:
-
一个落入黄色区域的新点被预测为黄色点(‘高犯罪率社区’)
-
一个落入紫色区域的新点被预测为紫色点(‘低犯罪率社区’)
这里有一些观察:
-
训练在原始数据集上的 SVC…相当无用。它基本上将所有社区预测为紫色。它学会忽视所有黄色点。
-
训练在过采样和欠采样数据集上的 SVC 的偏差较小。它们更不容易错误分类少数类。
-
训练在过采样和欠采样数据集上的 SVC 的决策边界有所不同。
使用 ROC 评估重采样模型
为了评估哪个 SVC 最佳,我们将评估 SVC 在测试集上的表现。我们将使用的指标是接收者操作特征曲线(ROC),以找到曲线下面积(AUC)。请搜索(Cmd+F)“附录 1”以了解 ROC 的介绍。

图片来源:作者
from sklearn.svm import SVC
from sklearn import metrics
import matplotlib.pyplot as plt
# Helper function for plotting ROC
def plot_roc(ax, X_train, y_train, X_test, y_test, title):
clf = SVC(kernel='linear',probability=True)
clf.fit(X_train, y_train)
y_test_pred = clf.predict_proba(X_test)[:,1]
fpr, tpr, thresh = metrics.roc_curve(y_test, y_test_pred)
auc = metrics.roc_auc_score(y_test, y_test_pred)
ax.plot(fpr,tpr,label=f"{title} AUC={auc:.3f}")
ax.set_title('ROC Curve')
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.legend(loc=0)
# Plot all ROC into one graph
fig,ax = plt.subplots(1,1,figsize=(8,5))
plot_roc(ax, X_train_pca, y_train, X_test_pca, y_test, 'Original Dataset')
plot_roc(ax, X_train_ros, y_train_ros, X_test_pca, y_test, 'Randomly Oversampled Dataset')
plot_roc(ax, X_train_rus, y_train_rus, X_test_pca, y_test, 'Randomly Undersampled Dataset')
在原始数据上训练的 SVC 表现不佳。它的表现比我们随机猜测结果还要差。
随机过采样的数据集优于欠采样的数据集。一个可能的原因是,从欠采样过程中移除数据点会丧失信息。相反,过采样不会丢失信息。
现在我们对过采样和欠采样技术有了理解,让我们深入探讨过采样和欠采样。
策略 3. 使用 SMOTE 进行过采样
SMOTE 是一种过采样方法。直观地说,SMOTE 通过在彼此之间插值的少数数据点之间创建合成数据点。
这是 SMOTE 工作的简化说明。
-
随机选择少数类中的一些数据点。
-
对于每个选定的点,识别其k个最近邻居。
-
对于每个邻居,添加一个新点,该点位于数据点和邻居之间的某处。
-
重复步骤 2 到 4,直到生成足够的合成数据点。
请搜索(Cmd+F)“附录 2”以查找其创作者对 SMOTE 算法的确切描述。
这是一个可视化。
让我们使用 SMOTE 对数据集进行过采样,并在其上训练一个 SVC。
from imblearn.over_sampling import SMOTE
from sklearn.svm import LinearSVC
# Perform random sampling
smote = SMOTE(random_state=0)
X_train_smote, y_train_smote = smote.fit_resample(X_train_pca, y_train)
# Train linear SVC
clf_smote = SVC(kernel='linear',probability=True)
clf_smote.fit(X_train_smote, y_train_smote)
# Plot decision boundary
# Function for plotting decision boundary is in Notebook
# Link:
这是结果。

作者提供的图像
策略 4. 使用 ADASYN 进行过采样(+它与 SMOTE 的不同之处)
ADASYN 是 SMOTE 的一个变种:SMOTE 和 ADASYN 都通过插值生成新样本。
但有一个关键的区别。ADASYN 会在被 KNN 分类器错误分类的原始样本旁边生成样本。相反,SMOTE 区分了被 KNN 分类器正确或错误分类的样本。
这是 ADASYN 工作原理的可视化。
让我们使用 ADASYN 对数据集进行过采样,并在其上训练一个 SVC。
from imblearn.over_sampling import ADASYN
# Perform random sampling
adasyn = ADASYN(random_state=0)
X_train_adasyn, y_train_adasyn = adasyn.fit_resample(X_train_pca, y_train)
# Train linear SVC
from sklearn.svm import SVC
clf_adasyn = SVC(kernel='linear',probability=True)
clf_adasyn.fit(X_train_adasyn, y_train_adasyn)
# Plot decision boundary
# Function for plotting decision boundary is in Notebook
# Link:
让我们比较 SMOTE、ADASYN 和原始数据集。

作者提供的图像
这里有几点观察。
首先,两种过采样方法都会在原始数据点之间创建更多的合成数据点。这是因为 SMOTE 和 ADASYN 都使用插值来创建新的数据点。
其次,比较 SMOTE 和 ADASYN 时,我们注意到 ADASYN 会在少数(黄色)点附近的多数(紫色)数据点创建数据点。
-
比较上面用蓝色圈出的区域,ADASYN 在只有少数紫色数据点的区域创建了较少的黄色数据点。
-
比较上面用棕色圈出的区域,ADASYN 在紫色数据点较多的区域创建了更多的黄色数据点。
让我们比较到目前为止我们描述的所有过采样方法的 ROC 曲线。在这个例子中,它们表现同样出色。

作者提供的图像
策略 5. 使用汤姆克链接进行欠采样
汤姆克链接是一对非常接近但属于不同类别的点。汤姆克链接的数学定义可以在附录 3 中找到。
这是一个可视化。
要使用汤姆克链接进行欠采样,我们将识别数据集中所有的汤姆克链接。对于汤姆克链接中的每对数据点,我们将删除多数类别。
这里有一个动画,说明了使用汤姆克链接进行欠采样。
我们将对我们的数据集应用汤姆克链接欠采样。
from imblearn.under_sampling import TomekLinks
from sklearn.svm import LinearSVC
# Perform Tomek Link undersampling
tomek = TomekLinks()
X_train_tomek, y_train_tomek = tomek.fit_resample(X_train_pca, y_train)
# Train linear SVC
clf_tomek = SVC(kernel='linear',probability=True)
clf_tomek.fit(X_train_tomek, y_train_tomek)
# Code for plotting graph in notebook.
# Notebook link:
现在让我们比较 Tomek 欠采样和随机欠采样。

图片来源于作者
在我们的数据集中,移除 Tomek Link 对缓解类别不平衡几乎没有作用。这是因为数据集中 Tomek Links 的数量有限。
让我们看看 Tomek Link 欠采样的表现与随机欠采样有何不同。

图片来源于作者
我们观察到随机欠采样比 Tomek Link 欠采样效果更好。 这是因为 Tomek Link 没有像随机欠采样那样完全消除类别不平衡。
策略 6. SMOTEK:先使用 SMOTE 过采样,然后使用 Tomek Links 欠采样
现在我们已经了解了过采样和欠采样。我们可以将这些技术结合起来吗?
当然!SMOTE-TOMEK 是一种结合了过采样(SMOTE)和欠采样(通过 Tomek Links)的技术。
我们将其应用到我们的数据集上。
from imblearn.combine import SMOTETomek
from sklearn.svm import LinearSVC
# Perform random sampling
smotetomek = SMOTETomek(random_state=0)
X_train_smotetomek, y_train_smotetomek = smotetomek.fit_resample(X_train_pca, y_train)
# Plot linear SVC
clf_smotetomek = SVC(kernel='linear',probability=True)
clf_smotetomek.fit(X_train_smotetomek, y_train_smotetomek)
让我们比较 SMOTE、Tomek 和 SMOTE-Tomek。

图片来源于作者
比较 SMOTE-Tomek 和仅 SMOTE,我们可以看到差异被圈出了棕色的圈。SMOTE-Tomek 移除了接近边界的点。
最终,我们将比较上述描述的所有技术。结果是,SMOTE-TOMEK 表现最佳。

图片来源于作者
结束
总体来说,你可以使用过采样、欠采样或两者的组合来处理数据不平衡问题。如果你有计算资源,通常更好的方法是使用过采样和欠采样的组合;当数据点较少时,过采样是一种不错的策略;而当有许多类似数据点时,欠采样则表现较好。
处理不平衡数据集并不容易。我鼓励你探索许多其他重采样策略(包括不同的 欠采样方法 和 过采样方法),以查看哪种策略在你的数据集上表现最好。
此外,评估不平衡数据集的性能可能会很棘手。确保使用正确的分类指标。幸运的是,ROC 曲线、F1 分数 和 几何均值分数 等指标已经可供使用。
我是 Travis Tang。我在 LinkedIn 和 Medium 上发布数据科学内容。关注我以获取更多内容 😃
附录
附录 1. 使用 ROC 评估类别不平衡问题中的模型
ROC 对类别不平衡不敏感,使其成为评估类别不平衡模型的绝佳工具。它不依赖于类别的普遍性。这与像准确率这样的评估指标形成对比,后者在类别不平衡时可能会产生误导。

由 CMG Lee 绘制,基于 commons.wikimedia.org/wiki/File:roc-draft-xkcd-style.svg。
ROC 曲线绘制了 y 轴上的真实正例率 (TPR) 对 x 轴上的假正例率 (FPR) 的图像,用于所有可能的分类阈值。TPR 是正确分类为正例的正例实例的比例,而 FPR 是错误分类为正例的负例实例的比例。
一个性能良好的模型将具有接近图表左上角的 ROC 曲线,因为这表示更高的 TPR 和更低的 FPR。一个完全随机猜测的模型将落在 TPR = FPR 的线条上。
附录 2. SMOTE 算法的确切算法
少数类通过取每个少数类样本并沿连接任意/所有 k 个少数类最近邻的线段引入合成样本来进行过采样。根据所需的过采样量,从 k 个最近邻中随机选择邻居。我们目前的实现使用五个最近邻。例如,如果所需的过采样量为 200%,则从五个最近邻中选择两个邻居,并在每个邻居的方向上生成一个样本。合成样本的生成方式如下:取考虑中的特征向量(样本)与其最近邻之间的差异。将此差异乘以一个 0 到 1 之间的随机数,并将其加到考虑中的特征向量上。这会在两个特定特征之间的线段上选择一个随机点。这种方法有效地使少数类的决策区域变得更一般化。 [2]
附录 3. Tomek Links 的定义
给定两个属于不同类别的样本 Ei 和 Ej,d(Ei, Ej) 是 Ei 和 Ej 之间的距离。如果不存在样本 El,使得 d(Ei, El) < d(Ei, Ej) 或 d(Ej, El) < d(Ei, Ej),则 (Ei, Ej) 对被称为 Tomek link。[1]
参考文献
[1] Batista, Gustavo EAPA, Ronaldo C. Prati, 和 Maria Carolina Monard. “A Study of the Behavior of Several Methods for Balancing Machine Learning Training Data” ACM SIGKDD explorations newsletter 6.1 (2004): 20–29。
[2] Chawla, Nitesh V., 等. “SMOTE: synthetic minority over-sampling technique.” Journal of artificial intelligence research 16 (2002): 321–357。
机器学习中的分类:入门
原文:
towardsdatascience.com/classification-in-machine-learning-an-introduction-d52595e3dcab
了解机器学习中的分类,看看它是什么,如何使用,以及一些分类算法的示例
·发布在 Towards Data Science ·11 分钟阅读·2023 年 2 月 24 日
--

照片由Markus Winkler拍摄,发布在Unsplash
介绍
现在,许多行业处理着各种类型的大型数据集。手动处理所有这些信息可能会耗时,并且从长远来看可能没有增加价值。许多策略,从简单的自动化到机器学习技术,正在被应用以获得更好的投资回报。本概念博客将涵盖一个重要的概念:机器学习中的分类。
我们将首先定义机器学习中的分类是什么,然后阐明机器学习中的两种学习者类型,以及分类与回归之间的区别。接着,我们将介绍一些可以应用分类的实际场景。之后,我们将介绍所有不同类型的分类,并深入探讨一些分类算法的示例。最后,我们将提供几个算法实现的实践机会。
机器学习中的分类是什么?
分类是一种监督学习方法,其中模型尝试预测给定输入数据的正确标签。在分类中,模型使用训练数据进行全面训练,然后在测试数据上进行评估,之后才能用于对新的未见数据进行预测。
例如,一个算法可以学习预测给定的电子邮件是垃圾邮件还是正常邮件(非垃圾邮件),如下所示。

作者提供的图片
在深入了解分类概念之前,我们首先了解分类中两种类型的学习者的区别:惰性学习者和渴求学习者。然后,我们将澄清分类和回归之间的误解。
惰性学习者与渴求学习者
机器学习分类有两种类型的学习者:惰性学习者和渴求学习者。
渴求学习者是那些首先从训练数据集中构建模型,然后才对未来的数据集进行预测的机器学习算法。他们在训练过程中花费更多时间,因为他们渴望通过学习权重获得更好的泛化能力,但在进行预测时所需时间较少。
大多数机器学习算法是渴求学习者,以下是一些例子:
-
逻辑回归。
-
支持向量机。
-
决策树。
-
人工神经网络。
惰性学习者或基于实例的学习者,则不会立即从训练数据中创建模型,这就是惰性特性的来源。他们只是记住训练数据,每次需要进行预测时,他们会从整个训练数据中搜索最近的邻居,这使得他们在预测时非常缓慢。这类学习者的一些例子包括:
-
K 最近邻。
-
基于案例的推理。
然而,一些算法,如 BallTrees 和 KDTrees,可以用来提高预测的延迟。
机器学习分类与回归
机器学习算法主要分为四大类:监督学习、无监督学习、半监督学习和强化学习。
尽管分类和回归都属于监督学习的范畴,但它们并不相同。
-
当目标变量是离散的时,预测任务是分类。一个应用是识别一段文本的潜在情感。
-
当目标变量是连续的时,预测任务是回归。一个例子可以是根据一个人的学历、以往工作经验、地理位置和职位级别预测其薪资。

图片由作者提供
机器学习分类在现实生活中的例子
监督机器学习分类在我们日常生活的多个领域中具有不同的应用。以下是一些例子。
医疗保健
在历史患者数据上训练机器学习模型可以帮助医疗专家准确分析他们的诊断:
-
在 COVID-19 疫情期间,机器学习模型被实施以有效预测一个人是否感染了 COVID-19。
-
研究人员可以利用机器学习模型预测未来更有可能出现的新疾病。
教育
教育是处理最多文本、视频和音频数据的领域之一。这些非结构化信息可以借助自然语言技术进行分析,以执行不同的任务,例如:
-
按类别对文档进行分类。
-
在学生申请过程中自动识别其文档的底层语言。
-
分析学生对教授的反馈情感。
交通
交通是许多国家经济发展的关键组成部分。因此,工业正在使用机器学习和深度学习模型:
-
预测哪个地理位置的交通量会增加。
-
预测由于天气条件可能在特定地点发生的潜在问题。
可持续农业
农业是人类生存的重要支柱之一。引入可持续性可以在不破坏环境的情况下提高农民的生产力:
-
通过使用分类模型预测哪种类型的土地适合某种类型的种子。
-
预测天气以帮助他们采取适当的预防措施。
机器学习中的不同类型的分类任务
机器学习中有四种主要的分类任务:二分类、多类别分类、多标签分类和不平衡分类。
二分类
在二分类任务中,目标是将输入数据分类为两个互斥的类别。在这种情况下,训练数据以二进制格式标记:真和假;正和负;O 和 1;垃圾邮件和非垃圾邮件等,具体取决于所处理的问题。例如,我们可能想要检测给定的图像是卡车还是船。

作者提供的图片
逻辑回归和支持向量机算法本质上是为二分类设计的。然而,其他算法如 K-Nearest Neighbors 和决策树也可以用于二分类。
多类别分类
多类别分类另一方面具有至少两个互斥的类别标签,其目标是预测给定输入示例属于哪个类别。在以下情况下,模型正确地将图像分类为飞机。

作者提供的图片
大多数二分类算法也可以用于多类别分类。这些算法包括但不限于:
-
随机森林
-
朴素贝叶斯
-
K-Nearest Neighbors
-
梯度提升
-
SVM
-
逻辑回归。
但是等等!你不是说 SVM 和逻辑回归默认不支持多类别分类吗?
→ 没错。然而,我们可以应用如一对一和一对多的二分类转换方法来使本质上的二分类算法适应多类别分类任务。
一对一:这种策略训练的分类器数量等于标签对的数量。如果我们有 3 类分类,将有三对标签,从而得到三个分类器,如下图所示。

图片由作者提供
通常,对于 N 个标签,我们将有 Nx(N-1)/2 个分类器。每个分类器在一个二分类数据集上进行训练,最终类别由所有分类器的多数投票决定。一对一方法在 SVM 和其他基于内核的算法中效果最佳。
一对其余:在这个阶段,我们首先将每个标签视为一个独立的标签,并将其余的标签组合成一个标签。对于 3 类,我们将有三个分类器。
通常,对于 N 个标签,我们将有N个二分类器。

图片由作者提供
多标签分类
在多标签分类任务中,我们尝试为每个输入样本预测 0 个或多个类别。在这种情况下,没有互斥关系,因为输入样本可以有多个标签。
这种情况可以在不同领域中观察到,例如自然语言处理中的自动标记,其中给定文本可以包含多个主题。类似地,在计算机视觉中,一张图像可以包含多个对象,如下图所示:模型预测该图像包含:飞机、船只、卡车和狗。

图片由作者提供
无法使用多类或二分类模型来执行多标签分类。然而,大多数用于这些标准分类任务的算法都有其专门用于多标签分类的版本。我们可以举例:
-
多标签决策树
-
多标签梯度提升
-
多标签随机森林
不平衡分类
对于不平衡分类,每个类别中的样本数量分布不均,这意味着在训练数据中,某一个类别的样本可能比其他类别的样本更多。我们可以考虑以下的 3 类分类场景,其中训练数据包含:60%的卡车,25%的飞机和 15%的船只。

图片由作者提供
不平衡分类问题可能出现在以下场景中:
-
金融行业中的欺诈交易检测
-
稀有疾病诊断
-
客户流失分析
使用传统的预测模型,如决策树、逻辑回归等,在处理不平衡数据集时可能效果不佳,因为这些模型可能会偏向于预测样本数量最多的类别,并将样本较少的类别视为噪声。
那么,这是否意味着这些问题就此被抛在了脑后?
当然不是!我们可以使用多种方法来解决数据集中的不平衡问题。最常用的方法包括采样技术或利用成本敏感算法的优势。
采样技术
这些技术旨在通过以下方式平衡原始数据的分布:
-
基于聚类的过采样:
-
随机欠采样:从多数类中随机删除样本。
-
SMOTE 过采样:从少数类中随机复制样本。
成本敏感算法
这些算法考虑了误分类的成本。它们旨在最小化模型产生的总成本。
-
成本敏感决策树。
-
成本敏感逻辑回归。
-
成本敏感支持向量机。
深入探讨分类算法
现在我们掌握了实施一些算法所需的所有工具。本节将介绍四种算法及其在贷款数据集(Kaggle 上免费提供)上的实现,以展示之前涉及的一些概念,特别是针对不平衡数据集的二分类任务。为了简单起见,我们将只关注四种算法。
数据集中贷款的分布
- 查看数据集中的前五个观测值。
import pandas as pd
loan_data = pd.read_csv("loan_data.csv")
loan_data.head()

- 数据集中的借款人概况。
import matplotlib.pyplot as plt
# Helper function for data distribution
# Visualize the proportion of borrowers
def show_loan_distrib(data):
count = ""
if isinstance(data, pd.DataFrame):
count = data["not.fully.paid"].value_counts()
else:
count = data.value_counts()
count.plot(kind = 'pie', explode = [0, 0.1],
figsize = (6, 6), autopct = '%1.1f%%', shadow = True)
plt.ylabel("Loan: Fully Paid Vs. Not Fully Paid")
plt.legend(["Fully Paid", "Not Fully Paid"])
plt.show()
# Visualize the proportion of borrowers
show_loan_distrib(loan_data)

作者提供的图像
从上面的图形中,我们注意到 84%的借款人偿还了他们的贷款,而只有 16%没有偿还,这使得数据集非常不平衡。
变量类型
在进一步操作之前,我们需要检查变量的类型,以便对需要编码的变量进行编码。
我们注意到,除了需要编码的目的属性之外,所有列都是连续变量。

作者提供的图像
# Check column types
print(loan_data.dtypes)

作者提供的图像
encoded_loan_data = pd.get_dummies(loan_data, prefix="purpose",
drop_first=True)
print(encoded_loan_data.dtypes)
将数据分为训练集和测试集
X = encoded_loan_data.drop('not.fully.paid', axis = 1)
y = encoded_loan_data['not.fully.paid']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.30,
stratify = y, random_state=2022)
采样策略的应用
我们将在这里探讨两种采样策略:随机欠采样和 SMOTE 过采样。
随机欠采样
我们将对多数类进行欠采样,该类对应于“完全还款”(类 0)。
X_train_cp = X_train.copy()
X_train_cp['not.fully.paid'] = y_train
y_0 = X_train_cp[X_train_cp['not.fully.paid'] == 0]
y_1 = X_train_cp[X_train_cp['not.fully.paid'] == 1]
y_0_undersample = y_0.sample(y_1.shape[0])
loan_data_undersample = pd.concat([y_0_undersample, y_1], axis = 0)
# Visualize the proportion of borrowers
show_loan_distrib(loan_data_undersample)

作者提供的图像
SMOTE 过采样
对少数类进行过采样
smote = SMOTE(sampling_strategy='minority')
X_train_SMOTE, y_train_SMOTE = smote.fit_resample(X_train,y_train)
# Visualize the proportion of borrowers
show_loan_distrib(y_train_SMOTE)

作者提供的图像
应用采样策略后,我们观察到数据集在不同类型的借款人之间分布均匀。
一些机器学习分类算法的应用
本节将把这两种分类算法应用于 SMOTE 过采样数据集。相同的训练方法也可以应用于欠采样数据。
逻辑回归
这是一个可解释的算法。它通过使用 sigmoid 函数对数据点属于给定类别的概率进行建模来进行分类。
X = loan_data_undersample.drop('not.fully.paid', axis = 1)
y = loan_data_undersample['not.fully.paid']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.15, stratify = y, random_state=2022)
logistic_classifier = LogisticRegression()
logistic_classifier.fit(X_train, y_train)
y_pred = logistic_classifier.predict(X_test)
print(confusion_matrix(y_test,y_pred))
print(classification_report(y_test,y_pred))

作者提供的图像
支持向量机
该算法既可用于分类,也可用于回归。它通过使用边际最大化原理来绘制超平面(决策边界)。这个决策边界通过两个最接近的支持向量绘制。
SVM 提供了一种名为核技巧的转换策略,用于将不可分的数据投影到更高维度的空间,使其变得线性可分。
from sklearn.svm import SVC
svc_classifier = SVC(kernel='linear')
svc_classifier.fit(X_train, y_train)
# Make Prediction & print the result
y_pred = svc_classifier.predict(X_test)
print(classification_report(y_test,y_pred))

作者提供的图片
这些结果当然可以通过更多的特征工程和微调来改进。但它们已经比使用原始不平衡数据要好。
该算法是一个广为人知的梯度提升树算法的扩展。它不仅是应对过拟合的绝佳选择,而且在速度和性能方面也表现出色。
结论
这篇文章涵盖了机器学习中分类的主要方面,并提供了它们应用于不同领域的一些示例。最后,它介绍了逻辑回归和支持向量机的实现,经过欠采样和 SMOTE 过采样策略来生成平衡的数据集以训练模型。
你可以从我的 GitHub 上获取笔记本。
如果你喜欢阅读我的故事并希望支持我的写作,考虑 成为 Medium 会员。每月 5 美元的承诺,你将解锁对 Medium 上故事的无限访问。
你想请我喝杯咖啡 ☕️ 吗?→ 这里请!
欢迎在 Medium、Twitter 和 YouTube 上关注我,或者在 LinkedIn 上打个招呼。讨论 AI、ML、数据科学、NLP 和 MLOps 的话题总是很愉快的!
离开之前,请查看以下一些可能对你有帮助的文章:
Pandas 和 Python 在数据科学与数据分析中的技巧 — 第一部分
Pandas 和 Python 在数据科学与数据分析中的技巧 — 第二部分
Pandas 和 Python 在数据科学与数据分析中的技巧 — 第三部分
所有数据科学家和分析师应了解的统计学基础 — 带代码 — 第一部分
分类指标:为有志数据科学家准备的完整指南
你唯一需要的指南来掌握机器学习中的分类指标
·发表于 Towards Data Science ·阅读时间 26 分钟·2023 年 5 月 15 日
--

监督学习可以分为两类问题:分类和回归。本文旨在成为分类指标的终极指南:所以如果你是一个有志的数据科学家或初级数据科学家,你绝对需要阅读这篇文章。
首先,你也许还想阅读我关于掌握回归问题所需的 5 个指标的指南:
## 精通回归分析艺术:每个数据科学家应该了解的 5 个关键指标
你应该掌握的回归分析中使用的所有指标的终极指南
towardsdatascience.com
其次,让我通过目录告诉你你会在这里找到什么:
**Table of Contents:**
What is a classification problem?
Dealing with class imbalance
What a classification algorithm actually does
Accuracy
Precision and recall
F1-score
The confusion matrix
Sensitivity and specificity
Log loss (cross-entropy)
Categorical crossentropy
AUC/ROC curve
Precision-recall curve
BONUS: KDE and learning curves
像往常一样,你会找到 Python 示例来将理论付诸实践。
什么是分类问题?
在分类问题中,数据被标记为不同的类别:换句话说,我们的标签值代表数据点所属的类别。
分类问题有两种类型:
-
二分类问题:在这种情况下,目标值用 0 或 1 进行标记。
-
多类别 问题:在这种情况下,标签有多个值(0, 1, 2, 3 等),具体取决于类别的数量。
让我们来可视化它们。首先,让我们按照以下方式创建一个二分类数据集:
import numpy as np
import matplotlib.pyplot as plt
# Set random seed for reproducibility
np.random.seed(42)
# Generate data
num_samples = 1000
X = np.random.rand(num_samples, 2) * 10 - 5
y = np.zeros(num_samples)
y[np.sum(X ** 2, axis=1) < 5] = 1
# Plot data
plt.scatter(X[:, 0], X[:, 1], c=y, cmap='coolwarm')
plt.xlabel('Feature')
plt.ylabel('Label')
plt.title('Binary Classification Dataset')
plt.show()

我们创建的二分类问题。图片来源:Federico Trotta。
所以,这是一个二分类数据集的例子:一些数据点属于蓝色类别,一些属于红色类别。现在,这些类别代表什么并不重要。它们可以是苹果或梨,汽车或火车。这不重要。现在重要的是我们可视化了一个二分类问题。
现在,让我们可视化一个多类别问题:
import numpy as np
import matplotlib.pyplot as plt
# Set random seed for reproducibility
np.random.seed(42)
# Generate data
num_samples = 1000
X = np.random.rand(num_samples, 2) * 10 - 5
y = np.zeros(num_samples, dtype=int)
y[np.sum(X ** 2, axis=1) < 2.5] = 1
y[np.logical_and(X[:, 0] > 2, np.abs(X[:, 1]) < 1)] = 2
y[np.logical_and(X[:, 0] < -2, np.abs(X[:, 1]) < 1)] = 3
# Plot data
plt.scatter(X[y==0, 0], X[y==0, 1], c='blue', label='Class 1')
plt.scatter(X[y==1, 0], X[y==1, 1], c='red', label='Class 2')
plt.scatter(X[y==2, 0], X[y==2, 1], c='green', label='Class 3')
plt.scatter(X[y==3, 0], X[y==3, 1], c='purple', label='Class 4')
plt.xlabel('Feature')
plt.ylabel('Label')
plt.title('Multiclass Classification Dataset')
plt.legend()
plt.show()

我们创建的多类别分类问题。图片来源:Federico Trotta。
所以,这里我们创建了一个分类问题,数据点属于 4 个类别。
多类别分类问题的一个难点是理解所有类别是否都重要。让我们在下一段中看看这是什么意思。
NOTE:
in the case of a binary classification,classes can be named as 0-1.
But they can also be named as 1-2\. So, there is no convention that
tells us we need to start from 0.
This is the same for the multi-class case. Classe can be named 0,1,2,3 as
well as 1,2,3,4.
处理类别不平衡
考虑以下数据集:
import numpy as np
import matplotlib.pyplot as plt
# Set random seed for reproducibility
np.random.seed(42)
# Class 1: blue
mean1 = [0, 0]
cov1 = [[1, 0], [0, 1]]
num_points1 = 7000
X1 = np.random.multivariate_normal(mean1, cov1, num_points1)
# Class 2: green
mean2 = [3, 3]
cov2 = [[0.5, 0], [0, 0.5]]
num_points2 = 2700
X2 = np.random.multivariate_normal(mean2, cov2, num_points2)
# Class 3: red
mean3 = [-3, 3]
cov3 = [[0.5, 0], [0, 0.5]]
num_points3 = 300
X3 = np.random.multivariate_normal(mean3, cov3, num_points3)
# Plot the data
plt.scatter(X1[:, 0], X1[:, 1], color='blue', s=1, label='Class 1')
plt.scatter(X2[:, 0], X2[:, 1], color='green', s=1, label='Class 2')
plt.scatter(X3[:, 0], X3[:, 1], color='red', s=1, label='Class 3')
plt.xlabel('Feature')
plt.ylabel('Label')
plt.title('Imbalanced Multiclass Classification Dataset')
plt.legend()
plt.show()

我们创建的不平衡的多类别分类问题。图片来源:Federico Trotta。
如我们所见,我们有很多蓝色点,还有大量绿色点。而红色点与其他点相比非常少。
问题是:我们应该考虑红色点吗?换句话说:我们可以通过删除红色点来进行机器学习分析吗?因为这些点太少了。
答案是……这要视情况而定!
通常情况下,我们可以忽略那些观测值少于其他类别的一个(或多个)类别。但在某些特定情况下,我们绝不能这么做!这就是领域知识发挥作用的地方。
例如,如果我们在研究银行公司的欺诈检测,我们期望欺诈交易相对于标准交易来说是稀少的。这给我们带来了一个不平衡的数据集,这意味着:我们不能删除属于观测值较少类别的值!
如果我们在医学领域研究某些问题也是一样的。在稀有疾病的情况下,我们期望这些疾病是……稀有的!因此,不平衡的数据集是我们所期望的。
无论如何,我们故意创建了上面的数据集用于教学目的。一般来说,很难可视化数据点,因为我们有多个特征。因此,评估类别不平衡的一种方法是显示标签的直方图。
在继续之前……如果你不知道直方图和条形图的区别,你可以阅读我写的以下文章:
它们看起来似乎是一样的,但它们之间的差异是相关的。
[towardsdatascience.com
所以,这里是我们可以做的。让我们创建一个包含三个标签的数据集,如下所示:
import pandas as pd
import numpy as np
# Create a list of labels
labels = ['1', '2', '3']
# Create a list of features
features = ['feature_1', 'feature_2', 'feature_3']
# Set the number of samples
num_samples = 1000
# Create an empty Pandas DataFrame to store the data
data = pd.DataFrame()
# Add the features to the DataFrame
for feature in features:
data[feature] = np.random.rand(num_samples)
# Add the labels to the DataFrame
data['label'] = np.random.choice(labels, num_samples)
即使这个数据框是故意创建的,它也反映了真实情况,因为它是表格化的(意味着我们可以用 pandas 操作它)。所以,如果我们显示头部,我们得到:

我们数据框的头部。图片由 Federico Trotta 提供。
所以,为了理解我们的数据集是否可能是不平衡的,我们绘制一个直方图,如下所示:
import seaborn as sns
import matplotlib.pyplot as plt
# Plot histogram
sns.histplot(data=data, x='label')
# Write title and axis labels
plt.title('CLASSES FREQUENCIES', fontsize=14) #plot TITLE
plt.xlabel('Our labels (our classes)', fontsize=12) #x-axis label
plt.ylabel('Frequencies of the three classes', fontsize=12) #y-axis label

我们数据集的三个类别的频率。图片由 Federico Trotta 提供。
好吧,在这种情况下,三个类别的频率是相同的。因此,数据集是平衡的,我们必须在分析中考虑所有标签。
相反,这是通过直方图表示的类别不平衡:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
# Create a list of labels with class imbalance
labels = ['1'] * 500 + ['2'] * 450 + ['3'] * 50
# Create a list of features
features = ['feature_1', 'feature_2', 'feature_3']
# Shuffle the labels
np.random.shuffle(labels)
# Create an empty Pandas DataFrame to store the data
data = pd.DataFrame()
# Add the features to the DataFrame
for feature in features:
data[feature] = np.random.rand(len(labels))
# Add the labels to the DataFrame
data['label'] = labels
# Plot histogram
sns.histplot(data=data, x='label')
# Write title and axis labels
plt.title('CLASSES FREQUENCIES', fontsize=14) #plot TITLE
plt.xlabel('Our labels (our classes)', fontsize=12) #x-axis label
plt.ylabel('Frequencies of the three classes', fontsize=12) #y-axis label

我们创建的不平衡数据集。图片由 Federico Trotta 提供。
所以,在这种情况下,我们需要了解是否需要考虑第 3 类(我们正在研究“稀有情况”)或不需要考虑(我们正在研究“没有稀有事件的情况”),以便可以丢弃与其相关的所有值。
现在,在深入了解解决分类问题所需的指标之前,我们需要了解分类算法实际做了什么。
分类算法实际做了什么
正如我们所知,我们使用机器学习来进行预测。这意味着我们在可用数据上训练一个机器学习模型,期望预测结果尽可能接近实际数据。
如果你不知道“训练一个机器学习模型”实际意味着什么,你可以在这里阅读我的文章:
或者……“训练一个机器学习模型”是什么意思?
towardsdatascience.com
所以,让我们考虑一个二分类问题。我们的机器学习模型接收特征作为输入,并预测数据点是否属于第 1 类或第 2 类。如果预测“是完美的”,这意味着我们的模型可以准确地告诉我们哪些数据属于第 1 类,哪些属于第 2 类,没有错误。所以,所有实际属于第 1 类的数据点都被我们的机器学习模型预测为第 1 类。
当然,正如你所想,0%的错误是不可能的,这就是为什么我们需要一些指标来评估我们的机器学习模型。
所以在深入了解指标之前,我们需要使用一些术语:
-
我们将真正例(TP)定义为属于某一类别的数据点,并且被预测为属于该类别。例如,如果模型预测一封邮件是垃圾邮件,并且它确实是垃圾邮件,那么这就是一个真正例。
-
我们定义真负例(TN)为一个数据点不属于某个类别且预测为不属于该类别的情况。例如,如果模型预测一封邮件不是垃圾邮件,而它确实不是垃圾邮件,那么这就是一个真负例。
-
我们定义假正例(FP)为一个数据点属于一个类别,但预测为属于另一个类别的情况。例如,如果模型预测一封邮件是垃圾邮件,但实际上它不是垃圾邮件,那么这就是一个假正例。
-
我们定义假负例(FN)为一个数据点不属于某个类别但预测为不属于该类别的情况。例如,如果模型预测一封邮件不是垃圾邮件,但实际上它是垃圾邮件,那么这就是一个假负例。
一般来说,正如你可能想象的那样,我们希望在最大化真正例和真负例的同时,最小化假正例和假负例,以使模型尽可能准确。这意味着我们的 ML 模型做出了准确的预测。
那么,“准确”是什么意思?我们需要深入了解我们的第一个分类指标来理解它。
准确率
我们考虑的第一个指标是准确率。让我们看看公式:
作者在 embed-dot-fun 上写出的准确率公式。
所以,准确率是衡量我们的 ML 模型在预测中多么正确的一个指标。
例如,假设我们有一个标记为垃圾邮件或非垃圾邮件的邮件数据集。我们可以使用 ML 来预测新邮件是否为垃圾邮件。如果模型正确预测 100 封邮件中有 80 封是垃圾邮件,并且正确预测 100 封邮件中有 90 封不是垃圾邮件,那么它的准确率将是:
我们示例的计算。
这意味着我们的模型能够在 85% 的情况下正确预测邮件的类别。高准确率(接近 1)表示模型表现良好,而低准确率(接近 0)则表示模型需要改进。然而,准确率单独可能并不是评估模型性能的最佳指标,尤其是在不平衡的数据集中。
这是可以理解的,因为流行类别有“更多的数据”被标记,因此如果我们的模型准确,它将根据流行类别做出准确预测。换句话说,我们的模型可能因为流行类别而存在偏差。
让我们用 Python 创建一个数据集作为示例:
import numpy as np
import pandas as pd
# Random seed for reproducibility
np.random.seed(42)
# Create samples
n_samples = 1000
fraud_percentage = 0.05 # Fraudolent percentage
# Create classes
X = np.random.rand(n_samples, 10)
y = np.random.binomial(n=1, p=fraud_percentage, size=n_samples)
# Create data frame
df = pd.DataFrame(X)
df['fraudulent'] = y
我们创建了一个包含 1000 个样本的数据框,这些样本可以代表一些信用卡交易的数据。例如,我们创建了一个欺诈交易类别,它占所有观察值的 5%。所以,这个数据集显然是不平衡的。
如果我们的模型准确,那是因为它受到 95% 属于非欺诈交易类别的观察值的影响。因此,让我们拆分数据集,使用逻辑回归模型进行预测,并打印准确率:
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
# Split the dataset
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# Fit logistic regression model to train set
model = LogisticRegression()
model.fit(X_train, y_train)
# Make predictions
y_pred = model.predict(X_test)
# Calculate and print accuracy
accuracy = accuracy_score(y_test, y_pred)
print('Accuracy:', accuracy)
>>>
Accuracy: 0.95
所以,我们的模型准确率为 95%:好极了!现在……让我们定义其他指标,并看看它们告诉我们关于这个数据集的信息。
精确度和召回率
精确度衡量分类器不将负样本标记为正样本的能力。换句话说,它衡量所有正预测中真实正样本的比例。简单来说,精确度告诉我们模型的正预测有多准确。这是公式:
公式由作者在 embed-dot-fun 上写出。
在考虑电子邮件垃圾邮件分类问题时,精确度衡量模型分类为垃圾邮件的邮件中实际有多少是垃圾邮件。
让我们在不平衡数据集中使用它:
from sklearn.metrics import precision_score
# Calculate and print precision
precision = precision_score(y_test, y_pred)
print('Precision:', precision)
>>>
Precision: 0.0
哎呀!95%的准确率和 0%的精确度:这意味着什么?这意味着模型将所有样本预测为负样本或非欺诈交易。这显然是错误的。实际上,高精确度分数表示模型能够在所有预测为欺诈交易的交易中正确识别出较高比例的欺诈交易。
然后,我们有召回率指标,它衡量所有实际正样本中真实正样本的比例。换句话说,它衡量实际正样本中有多少被正确预测。简单来说,召回率告诉我们模型在找到数据中的所有正实例方面的能力。这是公式:
公式由作者在 embed-dot-fun 上写出。
在考虑电子邮件垃圾邮件分类问题时,召回率衡量数据集中实际的垃圾邮件有多少被我们的机器学习分类器正确识别为垃圾邮件。
假设我们有一个 1000 封邮件的数据集,其中 200 封是垃圾邮件,其余的是正常邮件。我们训练一个机器学习模型来将邮件分类为垃圾邮件或非垃圾邮件,它预测了 100 封邮件为垃圾邮件。
精确度会告诉我们这些 100 个预测为垃圾邮件的邮件中实际上有多少是垃圾邮件。例如,如果 100 个预测为垃圾邮件的邮件中有 90 个实际上是垃圾邮件,那么精确度就是 90%。这意味着,在模型预测为垃圾邮件的所有邮件中,90%实际上是垃圾邮件。
另一方面,召回率告诉我们模型正确识别了多少实际的垃圾邮件。例如,如果在 200 个实际的垃圾邮件中,模型正确识别了 150 个为垃圾邮件,那么召回率就是 75%。这意味着在所有实际的垃圾邮件中,模型正确识别了 75%。
现在,让我们在不平衡数据集中使用召回率:
from sklearn.metrics import recall_score
# Calculate and print recall
recall = recall_score(y_test, y_pred)
print('Recall:', recall)
>>>
Recall: 0.0
再次说明:我们有 95%的准确率和 0%的召回率。这意味着什么?就像之前一样,这意味着模型没有正确识别任何欺诈交易,而是将所有交易都预测为非欺诈交易。实际上,高召回率分数表示模型能够正确识别所有实际欺诈交易中的较高比例。
因此,在实际操作中,我们希望根据我们研究的问题在精度和召回率之间达到平衡。为此,我们经常参考其他两个考虑到这两者的指标:混淆矩阵和 F1 分数。让我们看看它们。
F1 分数
F1 分数 是一种机器学习评估指标,它将精度和召回率结合成一个范围为 0-1 的单一值。如果 F1 分数为 0,则我们的机器学习模型性能较差。如果 F1 分数为 1,则我们的机器学习模型性能较高。
这个指标通过计算精度和召回率的调和平均值来平衡精度和召回率。这是一种对低值更敏感的平均数,因此这个指标特别适用于不平衡的数据集。
让我们看看它的公式:
由作者在 embed-dot-fun 上撰写的 F1 分数公式。
现在,我们知道在我们不平衡的数据集上将获得的结果(F1 分数将为 0)。但让我们看看如何在 Python 中使用它:
from sklearn.metrics import f1_score
# Calculate and print f1-score
f1 = f1_score(y_test, y_pred)
print('F1 score:', f1)
>>>
F1 score: 0.0
在垃圾邮件分类器的背景下,假设我们有一个包含 1000 封邮件的数据集,其中 200 封是垃圾邮件,其余的是合法邮件。我们训练一个机器学习模型来将邮件分类为垃圾邮件或非垃圾邮件,并且预测有 100 封邮件是垃圾邮件。
要计算垃圾邮件分类器的 F1 分数,我们首先需要计算其精度和召回率。假设在 100 封预测为垃圾邮件的邮件中,有 80 封实际上是垃圾邮件。因此,精度为 80%。另外,假设在 200 封实际的垃圾邮件中,模型正确识别了 150 封垃圾邮件。因此,召回率为 75%。
现在我们可以计算 F1 分数:
由作者在 embed-dot-fun 上撰写的垃圾邮件分类器的 F1 分数计算方法。
这是一个相当不错的结果,因为我们接近 1。
混淆矩阵
混淆矩阵是一个表格,通过显示真正例、假正例、真负例和假负例的数量来总结分类模型的性能。
在二分类问题中,混淆矩阵有两行两列,显示如下:

混淆矩阵。图片由费德里科·特罗塔提供。
以垃圾邮件分类为例,假设我们的模型预测了 100 封邮件为垃圾邮件,其中 80 封实际上是垃圾邮件,预测了 900 封邮件为非垃圾邮件,其中 20 封实际上是垃圾邮件。
这个例子的混淆矩阵会是这样的:

我们的垃圾邮件分类问题的混淆矩阵。图片由费德里科·特罗塔提供。
现在,这是一种非常有用的分类可视化工具,有两个原因:
-
它可以通过可视化帮助我们计算精度和召回率
-
它立即告诉我们重要的内容,而无需任何计算。实际上,我们在分类问题中希望 TN 和 TP 尽可能高,而 FP 和 FN 尽可能低(尽量接近 0)。因此,如果主对角线上的值高而其他位置上的值低,那么我们的机器学习模型表现良好。
这就是我喜欢混淆矩阵的原因:我们只需要观察主对角线(从左上到右下)和非对角线上的值来评估机器学习分类器的性能。
考虑到我们的不平衡数据集,我们获得了 0 的精确度和召回率,我们说这意味着模型没有正确识别任何欺诈交易,而是将所有交易预测为非欺诈的。
这可能真的很难可视化,因为精确度和召回率的公式。我们必须在脑海中把它们弄清楚。由于我很难进行这种可视化,让我们应用混淆矩阵到我们的例子中,看看会发生什么:
from sklearn.metrics import confusion_matrix
# Calculate and print confusion matrix
cm = confusion_matrix(y_test, y_pred)
print('Confusion matrix:\n', cm)
>>>
Confusion matrix:
[[285 0]
[ 15 0]]
看发生了什么?我们可以清楚地说我们的模型表现不好,因为它捕获了 285 个 TNs 却没有捕获任何 TPs!这就是混淆矩阵的视觉力量!
还有另一种展示混淆矩阵的方式,我非常喜欢,因为它改善了可视化体验。这是代码:
from sklearn.metrics import ConfusionMatrixDisplay
# Calculate confusion matrix
cm = confusion_matrix(y_test, y_pred)
# Plot confusion matrix
cmd = ConfusionMatrixDisplay(cm)
cmd.plot()

我们的混淆矩阵可视化。图片由 Federico Trotta 提供。
这种可视化在多类分类问题中非常有用。让我们来看一个例子:
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
# Generate random data with 3 classes
X, y = make_classification(n_samples=1000, n_classes=3, n_features=10,
n_clusters_per_class=1, n_informative=5,
class_sep=0.5, random_state=42)
# Split the data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
random_state=42)
from sklearn.linear_model import LogisticRegression
# Train a logistic regression model on the training data
clf = LogisticRegression(random_state=42).fit(X_train, y_train)
# Make predictions on the test data
y_pred = clf.predict(X_test)
# Calculate the confusion matrix
cm = confusion_matrix(y_test, y_pred)
# Display the confusion matrix
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
display_labels=['Class 0', 'Class 1', 'Class 2'])
disp.plot()

我们三类问题的混淆矩阵可视化。图片由 Federico Trotta 提供。
在这些情况下,很难理解真正的 TPs、TNs 等,因为我们有三类样本。不过,我们可以简单地参考主对角线上的值和非对角线上的值。在这种情况下,主对角线上的值为 49、52 和 44,这些值远高于非对角线上的值,表明该模型表现良好(还要注意我们是在测试集上计算了混淆矩阵!)。
灵敏度和特异性
有几个指标,在我个人看来,在某些特定情况下更适用:灵敏度和特异性。让我讲讲这些,然后我们将讨论在特定情况下的可用性。
灵敏度 是分类器找到所有正样本的能力:
作者在 embed-dot-fun 上写的灵敏度公式。
等一下!但这不是召回率吗?!
是的,确实如此。这不是一个错误。这就是为什么我告诉你这些指标更适用于特定情况。但是让我继续。
我们将特异性定义为分类器找到所有负样本的能力:
作者在 embed-dot-fun 上写的特异性公式。
因此,这两者都描述了测试的“精确度”:敏感性描述了正测试的概率,特异性描述了负测试的概率。
根据我的经验,这些指标更适合用于医疗领域、生物学等领域的分类器。
举个例子,我们考虑一个 COVID 测试。考虑这种方法(可以被认为是贝叶斯方法,但我们先略过):你做了一个 COVID 测试,结果是阳性。问题是:得到阳性测试的概率是多少?得到阴性测试的概率是多少?
换句话说:你用来得到结果的工具的 敏感性和特异性 是什么?
好吧,你可能会问自己:你在问什么问题,Federico?
让我举一个我去年夏天经历的例子。
在意大利,阳性 COVID 测试需要由某人认证(我们暂且略过原因):通常是医院或药店。所以,当我们出现症状时,我们通常会在家里进行 COVID 测试(3-5 欧元的 COVID 测试),然后去药店确认(15 欧元的 COVID 测试)。
去年 7 月,我在妻子和女儿测试呈阳性后出现症状。所以我在家测试结果为阳性。然后,立刻去药店确认,结果却是阴性!
怎么可能呢?很简单:我在家里使用的 COVID 测试工具比药剂师使用的工具更敏感(或者,药剂师使用的测试比我使用的测试更特异)。
因此,根据我的经验,这些指标特别适用于任何类型的测量仪器(机械、电气等)和/或某些特定领域(如生物学、医学等)。此外,记住这些指标使用 TP、TN、FP 和 FN 作为精确度和召回率:这再次强调了这些指标在二分类问题中的适用性。
当然,我不是说敏感性和特异性 必须 只在上述情况下使用。根据我的经验,它们只是更适合。
Log loss(交叉熵)
Log loss — 有时称为交叉熵 — 是分类中的一个重要指标,并且基于概率。这个分数将每个类别的预测概率与实际类别标签进行比较。
让我们看一下公式:
作者在 embed-dot-fun 上写的 Log Loss 公式。
其中我们有:
-
n是观测总数,i是单个观测值。 -
y是真实值。 -
p是预测的概率。 -
Ln是自然对数。
要计算预测的概率 p,我们需要使用可以实际计算概率的 ML 模型,比如逻辑回归。例如,在这种情况下,我们需要像这样使用 predict_proba() 方法:
from sklearn.linear_model import LogisticRegression
# Invoke logistic regression model
model = LogisticRegression()
# Fit the data on the train set
model.fit(X_train, y_train)
# Calculate probabilities
y_prob = model.predict_proba(X_new)
所以,假设我们有一个二分类问题,并且假设我们通过逻辑回归模型计算概率,并且假设以下表格代表我们的结果:

一个展示实际标签和通过逻辑回归模型计算的概率的表格。图片由 Federico Trotta 提供。
我们计算 Log Loss 的步骤如下:
使用上表中的值(我们只对前两个出现的值进行了计算)计算 Log Loss,由 Author 在 embedd-dot-fun 上完成。
这会导致一个接近 0 的值,这让我们感到满意,意味着我们的逻辑回归模型对每个类别的标签预测得相当好。实际上,Log Loss 为 0 的值表示可能的最佳拟合。换句话说,一个 Log Loss 为 0 的模型预测每个观察值的概率作为真实值。
但是,不必担心:我们不需要手动计算 Log Loss 的值。幸运的是,sklearn 为我们提供了帮助。因此,让我们回到我们不平衡的数据集。要在 Python 中计算 Log Loss,我们输入以下内容:
from sklearn.metrics import log_loss
# Invoke & print Log Loss
log_loss_score = log_loss(y_test, y_pred)
print("Log loss score:", log_loss_score)
>>>
Log loss score: 1.726938819745535
再次,我们在测试集上得到了一个糟糕的指标,确认了上述所有情况。
最后,一个最后的考虑:Log Loss 适用于二分类问题。那多分类问题呢?
分类交叉熵
分类交叉熵指标表示 Log Loss 在多分类情况下的推广。
这个指标特别适用于不平衡数据集,因为它考虑了预测类别的概率。当我们有一个不平衡的数据集时,因为类别的相对频率可以影响模型正确预测“少数”类别的能力,所以这是很重要的。
这里我们有:
作者在 embed-dot-fun 上写的分类交叉熵公式。
其中命名法与 Log Loss 情况下相同。
最后,在 Python 中,我们使用的方法与 Log Loss 相同,通过调用 from sklearn.metrics import log_loss。因此,这次讨论仅仅是为了强调二分类或多分类情况的细微差别。
AUC/ROC 曲线
ROC 代表“接收者操作特征”,是一种通过绘制真实正率(TPR)与假正率(FPR)在不同阈值下的图形来评估分类器的方式。
AUC 代表“曲线下面积”,表示 ROC 曲线下的面积。这是一种整体性能指标,范围从 0 到 1(其中 1 表示分类器预测了 100% 的标签作为实际值),并且在比较不同的分类器时更为合适。
首先,让我们定义 TPR 和 FPR:
-
TPR 是敏感度(也可以叫做召回率,正如我们所说的)。
-
FPR 定义为
1-specificity。
请注意,AUC/ROC 适用于二分类问题。实际上,对于多分类分类器,TPR 和 FPR 需要重新审视。这需要一些工作,所以在这里我的建议是仅在二分类问题中使用它。
现在,让我们看看如何在 Python 中实现这一点:
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve, roc_auc_score
import matplotlib.pyplot as plt
# Generate a random binary classification dataset
X, y = make_classification(n_samples=1000, n_features=10, n_classes=2,
random_state=42)
# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y,
test_size=0.2, random_state=42)
# Fit a logistic regression model on the training data
model = LogisticRegression()
model.fit(X_train, y_train)
# Predict probabilities for the testing data
probs = model.predict_proba(X_test)
# Compute the ROC curve and AUC score
fpr, tpr, thresholds = roc_curve(y_test, probs[:, 1])
auc_score = roc_auc_score(y_test, probs[:, 1])
# Plot the ROC curve
plt.plot(fpr, tpr, label='AUC = {:.2f}'.format(auc_score))
plt.plot([0, 1], [0, 1], 'k--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend(loc='lower right')
plt.show()

上述代码的 AUC/ROC 曲线。图像由 Federico Trotta 提供。
虚线代表一个完全随机的分类器(这就像是随机猜测一个类别,而不是另一个。事实上,由于这是一个二分类问题,这条线的斜率为 0.5,意味着我们有 50% 的机会猜对)。因此,我们的曲线离这条虚线越远,我们的模型就越好。理想情况下,我们的曲线应尽可能保持在左上角,这意味着 False Positive Rate 较低,同时 True Positive Rate 较高。
这就是为什么这个图形适合用来比较模型:更好的模型在图形的左上角附近有曲线。让我们来看一个例子:我们将使用与之前相同的数据集,但我们会将数据拟合到三种不同的机器学习模型上。
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import roc_curve, roc_auc_score
import matplotlib.pyplot as plt
# Generate a random binary classification dataset
X, y = make_classification(n_samples=1000, n_features=10, n_classes=2,
random_state=42)
# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
random_state=42)
# Fit three different classifiers on the training data
clf1 = LogisticRegression()
clf2 = RandomForestClassifier(n_estimators=100)
clf3 = KNeighborsClassifier(n_neighbors=5)
clfs = [clf1, clf2, clf3]
# Predict probabilities for the testing data
plt.figure(figsize=(8,6))
for clf in clfs:
clf.fit(X_train, y_train)
probs = clf.predict_proba(X_test)
fpr, tpr, _ = roc_curve(y_test, probs[:,1])
auc_score = roc_auc_score(y_test, probs[:,1])
plt.plot(fpr, tpr, label='{} (AUC = {:.2f})'.format(clf.__class__.__name__,
auc_score))
# Plot the ROC/AUC curves for each classifier
plt.plot([0, 1], [0, 1], 'k--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve Comparison')
plt.legend(loc="lower right")
plt.show()

上述代码的 AUC/ROC 曲线。图像由 Federico Trotta 提供。
在这种情况下,随机森林分类器是预测我们数据最好的,因为它的曲线位于左上角,且数值高于其他模型。
结束这一部分之前,让我提醒你们,我们在这一段开始时提到 ROC 图将 TPR 对 FPR 在不同阈值下进行绘制,但我们还没有指定其他内容。因此,让我们在下一段中进行说明。
精确度-召回率曲线
考虑一个二分类问题。我们将数据拟合到分类器,它将任何预测值分配到类别 1 或类别 0:用于分配的标准是什么?
停下来思考一下吧。
是的,你猜对了:在分类问题中,分类器为每个样本分配一个介于 0 和 1 之间的分数。这表示样本属于正类的概率。
因此,我们的机器学习模型使用阈值将概率分数转换为类别预测。换句话说,任何概率分数大于阈值的样本都会被预测为正类,例如。
当然,即使在多类别分类问题中也是如此:我们使用二分类的情况只是为了简化我们的推理。
因此,ROC 曲线非常有用,因为它们展示了机器学习模型在不同阈值下的表现变化。
无论如何,分类器基于阈值将预测值分配到某个类别的事实告诉我们,精确度和召回率是一个权衡(就像偏差和方差一样)。
同时,我们甚至可以绘制精确度-召回率曲线。让我们看看如何做,使用我们用于 AUC/ROC 曲线的数据集:
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve
import matplotlib.pyplot as plt
# Generate a random binary classification dataset
X, y = make_classification(n_samples=1000, n_features=10, n_classes=2,
random_state=42)
# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
random_state=42)
# Fit a logistic regression model on the training data
clf = LogisticRegression()
clf.fit(X_train, y_train)
# Predict probabilities for the testing data. Compute precision-recall curve
probs = clf.predict_proba(X_test)
precision, recall, thresholds = precision_recall_curve(y_test, probs[:,1])
# Plot the precision-recall curve
plt.plot(recall, precision)
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curve')
plt.show()

上述代码的精确度-召回率曲线。图像由 Federico Trotta 提供。
因此,上面我们可以看到精确度在约 0.5 召回率之前保持为 1,然后迅速下降。因此,我们希望在这个值之前选择一个精确度-召回率权衡点。假设选择在 0.4 召回率处。
另一种很好的可视化这种权衡的方法是绘制精确度与召回率随阈值变化的图。使用相同的数据集,这就是发生的情况:
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve
import matplotlib.pyplot as plt
# Generate a random binary classification dataset
X, y = make_classification(n_samples=1000, n_features=10, n_classes=2,
random_state=42)
# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
random_state=42)
# Fit a logistic regression model on the training data
clf = LogisticRegression()
clf.fit(X_train, y_train)
# Predict probabilities for the testing data. Compute precision-recall curve
probs = clf.predict_proba(X_test)
precision, recall, thresholds = precision_recall_curve(y_test, probs[:,1])
# Plot precision and recall as thresholds change
plt.plot(thresholds, precision[:-1], label='Precision')
plt.plot(thresholds, recall[:-1], label='Recall')
plt.xlabel('Threshold')
plt.ylabel('Precision & Recall')
plt.legend()
plt.title('Precision and Recall as Thresholds Change')
plt.show()

随着阈值变化的精确度与召回率。图像由 Federico Trotta 提供。
因此,上图确认了在这种情况下,平衡精确度-召回率权衡的阈值大约为 0.4。
所以,当有人告诉你找到一个 95% 精确度的机器学习模型时,你应该问:“召回率是多少?”
最后,由于它们使用的指标非常相似,你可能会想知道何时使用 AUC/ROC 曲线,何时使用精确度-召回率曲线。引用参考文献 1(第 92 页):
作为经验法则,当正类稀有,或当你更关心假阳性而不是假阴性时,你应该优先考虑精确度-召回率曲线,否则使用 ROC 曲线。
额外内容:KDE 和学习曲线
在我们看到的所有特定于分类问题的方法和指标中,有两个是跨领域的。这意味着它们可以用于评估分类和回归问题。
有 KDE 图和学习曲线。我在之前的文章中写过它们,所以我会在下面链接它们:
你可以在以下文章的“验证机器学习模型的图形方法”段落第 3 点找到什么是 KDE 以及如何使用它:
这里是你需要知道的关于线性回归的一切(包括在 Python 中的应用)。
你可以在这里阅读关于学习曲线是什么以及如何使用它们的内容:
发现学习曲线的强大功能来验证你的机器学习模型。
结论
到目前为止,我们已经看到很多评估分类算法的指标和方法。如果你在考虑使用哪一种,我总是说,虽然熟悉每一种(特别是比较它们)是一个好习惯,但很难回答这个问题,原因有很多。通常,这只是一个个人偏好的问题。
此外,仅使用一个指标来评估机器学习模型是不够的,这是一条经验法则。
如果你读过我其他的文章,你会知道我个人喜欢使用至少一种分析方法和一种图形方法。在分类问题的情况下,我通常使用混淆矩阵和 KDE。
不过,再次强调:这完全是个人喜好问题。我的建议是练习使用这些方法,并决定你喜欢哪些,记住你需要不止一种方法来对你的 ML 模型做出准确的判断。
免费 Python 电子书:
刚开始学习 Python 数据科学却感到困难? 订阅我的新闻通讯并获取我的免费电子书:这将为你提供正确的学习路径,帮助你通过动手实践学习 Python 数据科学。
喜欢这个故事吗?成为 Medium 会员,仅需 5$/月 通过我的推荐链接:这样我将赚取小额佣金,但对你没有额外费用:
[## 通过我的推荐链接加入 Medium — Federico Trotta
阅读 Federico Trotta 的每一篇故事(以及 Medium 上其他成千上万的作者)。你的会员费直接支持……
medium.com](https://medium.com/@federicotrotta/membership?source=post_page-----9f02eab796ae--------------------------------)
参考书目和文献:
-
[1] 使用 Scikit-Learn 和 Tensorflow 的动手机器学习 - 奥雷利安·吉龙
-
[2] 使用 PyTorch 和 Scikit-learn 的机器学习 - 塞巴斯蒂安·拉施卡,刘宇溪,瓦希德·米里亚利
使用罗斯恩布拉特感知器进行分类
机器学习的“hello-world”
·
关注 发布于 Towards Data Science ·8 分钟阅读·2023 年 9 月 9 日
--
照片由 Lucie Morel 提供,刊登在 Unsplash
最近我在思考什么是机器学习最基本的介绍。我寻找一个简单的任务,比如二分类,并且一个足够简单的算法,以便可以从头开始构建并在短文中解释。如果这个算法有一定的历史背景,那就更好了。很快我找到了候选者:感知机。感知机将我们带回机器学习的起点。它由 Frank Rosenblatt 在 60 多年前引入。像神经元一样,感知机规则接受多个输入特征,并拟合权重,这些权重与输入特征向量相乘后决定神经元是否输出信号,或在机器学习分类上下文中是否输出 0 或 1。感知机可能是最简单的二分类器,我不知道现在是否有任何实际的机器学习应用可以用它来解决。然而,它具有重要的教育和历史价值,因为它为神经网络铺平了道路。
本文的目的是介绍感知机并在一个简单的二分类任务中使用它。感知机已经在scikit-learn中实现,但我们将从零开始构建它。我们还将创建一组可视化图表,以理解算法如何设定其决策边界并探究其收敛性。感知机是一个线性模型,包括权重和在拟合过程中同时和迭代调整的偏置项。然而,它不像机器学习历史上可能的直接继承者——自适应线性神经元(Adaline)算法那样具有连续的损失函数,后者也是一个单层神经网络。拟合感知机完全依赖于检测错误分类的样本,权重和偏置在每次错误分类样本出现时都会立即更新,而不是每个时期(时期为完整的训练集通过一次)。因此,该算法甚至不需要优化器。我敢说,它如此简单和优雅,以至于它变得美丽。如果你对它的工作原理感到好奇,请继续关注!
感知机理论
感知机,像其他线性模型一样,使用一组权重,每个特征一个权重,并通过计算权重与特征值的点积并添加偏置来生成预测。

这个线性函数的结果,也称为净输入,会被输入到激活函数 f(z) 中,在感知器的情况下,这个函数是一个简单的阶跃函数,即 f(z) 如果 z≥0 则取值 1,否则取值 0。激活函数的作用是将净输入映射到两个值,即 0 和 1。实际上,我们所做的不过是定义了一个超平面。处于超平面同一侧的点属于同一类。权重定义了垂直于超平面的向量,即超平面的方向,而偏差则是超平面距离原点的距离。当拟合过程开始时,我们有一个随机定向的超平面,距离原点也是随机的。每当遇到一个错误分类的样本时,我们会稍微调整超平面,改变其方向和位置,以便在下一个周期中样本位于超平面的正确侧。我们可以决定调整超平面的幅度,即学习率应该是多少。
通常我们需要通过所有样本几个项目(周期),直到没有点被错误分类,或者更准确地说,直到没有更多的进展。每个周期,我们遍历训练集中所有样本 i = 1,.., nₛₐₘₚₗₑₛ,并使用当前的权重和偏差检查模型是否错误分类,如果是,我们将更新所有权重 j=1,.., nfₑₒₜᵤᵣₑₛ,使用学习率 η:

其中

带有帽子的符号表示预测输出。我们也会更新偏差

其中

理解这些操作的概念上是简单的。假设模型预测了类 0,而正确的应该是 1。如果 xⱼ 是正数,那么权重将增加,从而使净输入增加。如果 xⱼ 是负数,那么权重将减少,以使净输入再次增加(无论权重的符号如何)。类似地,偏差将增加,从而进一步增加净输入。通过这些改变,更有可能在下一个周期中为错误分类的样本预测正确的类别。当模型预测类 1,而正确的应该是 0 时,逻辑类似,唯一的区别是所有符号都会被反转。
如果你仔细观察,权重和偏差可能会在同一个周期内更新多次,每次针对一个被错误分类的样本。每次错误分类都会重新调整和定位决策边界超平面,以便在下一个周期中正确预测样本。
数据准备
我们将使用一个合成数据集,其中包含两个高斯分布。感知器可以与任何维度的特征一起使用,但为了本文的目的,我们将限制在两个维度,以便于可视化。
产生如下图形

合成数据集中两类的散点图。图片由作者提供。
两个高斯分布被故意拉长并进一步分开,通过选择适当的均值和协方差。我们稍后会再回到这个话题。
感知机实现及使用
感知机的实现见下文。我们使用 scikit-learn 风格来初始化模型、拟合它并最终进行预测。
初始化方法设置学习率、最大迭代次数和随机数生成器种子以确保可重复性。fit方法创建一个随机数生成器,然后用于将权重设置为从均匀分布中采样的一些小数,同时偏置初始化为零。然后我们迭代最多若干个时期。每个时期我们计算错误分类的数量,以便监控收敛情况,并在可能的情况下提前终止。对于每个错误分类的样本,我们更新权重和偏置,如前述部分所述。如果错误分类的数量为零,则无法进一步改进,因此无需继续下一个时期。predict方法只是计算权重和特征值的点积,添加偏置并应用步进函数。
如果我们使用上述感知机类与合成数据集
我们可以看到在 24 个时期内达到了收敛,即无需耗尽指定的最大时期数。

感知机收敛情况。图片由作者提供。
可以使用 scikit-learn 中的决策边界实用工具函数来直观地展示决策边界。为了使用此函数,我们生成一个 200x200 的点网格,覆盖训练集中特征值的范围。本质上,我们构建了一个轮廓图,显示预测的类别,并将样本作为散点图叠加,使用真实标签进行着色。这种绘制决策边界的方式非常通用,可以适用于任何二维分类器。
这两个合成的高斯分布已经通过一个可以用几行代码从零开始编码的模型完美分开。这种方法的简洁性和优雅使它成为机器学习的一个出色的入门和激励示例。

拟合的感知机模型的决策边界。图片由作者提供。
我们还可以通过提前停止模型拟合过程来可视化不同时代决策边界的演变。实际上,可以通过使用递增的最大时代数来拟合模型来实现。对于每一次尝试,我们使用拟合(可能未收敛)的模型的权重和偏置,并将决策边界绘制为一条线。这些线条用时代编号进行注释。这本可以通过温启动实现得更优雅,但拟合模型非常快速,因此额外的复杂性不值得。
各个时代的决策边界演变如下面的图所示。最初,一小部分类别 0 的样本被误分类,这导致决策边界线的斜率和截距逐渐变化。我们可以看到在 24 个时代内达到了收敛,这与上面的收敛图一致。一旦决策边界实现了对类别的完美分隔,拟合过程就会停止,无论边界距离其周围的样本有多近。

几点注意事项。感知器的收敛不能被视为理所当然,因此重要的是设置最大迭代次数。实际上,可以通过数学证明,对于线性可分的类别,收敛是有保证的。如果类别不是线性可分的,那么权重和偏置将不断更新,直到达到最大迭代次数。这就是为什么在合成数据集中两个高斯分布被进一步分开的原因。
另一个重要的注意点是感知器没有唯一的解决方案。通常,会有无数个超平面可以分隔线性可分的类别,模型将随机收敛到其中一个。这也意味着测量决策边界的距离不是确定性的,因此不是特别有用。支持向量机解决了这个限制。
感知器本质上是一个单层神经网络。在跳到多层神经网络和可以用于非线性问题的反向传播算法之前,理解它的工作原理是有帮助的。
使用 LLM 分类源代码 — 什么与如何
分享我们制作基于 LLM 的源代码分类器的经验
·
关注 发表在 Towards Data Science · 14 分钟阅读 · 2023 年 12 月 28 日
--
照片由 Iurii Ivashchenko 拍摄
源代码 AI 已成为一种常见的用例,具有许多实际和多样化的实现(缺陷检测,代码补全等)。源代码 AI 最有趣的方面之一是它面临的重大变革;如果不久前,源代码分类的常见方法是训练定制的 DNN、依赖于嵌入或甚至使用经典的 NLP 技术如词袋模型(BOW),如今,大型语言模型(LLMs)已成为主要的工具。更具体地说,‘上下文学习’的使用迅速显现;将LLM(指令调整)喂入提示输入并获得分类,(理论上)不需要额外的调整。ChatGPT就是这样一个演示,通过其 API 极大简化了 ML 应用程序的开发。但隐藏的复杂性使得与生产就绪的应用程序之间的距离仍然相当高。下面总结了我们使用 LLMs 对源代码进行分类的旅程中的重要亮点。让我们开始吧。
选择合适的 LLM
考虑开源
第一个重要的检查点是依赖的 LLM。虽然像 ChatGPT 这样的商业服务非常适合“5 分钟黑客马拉松 POCs”,但对于源代码应用程序,你的客户很可能不喜欢他们(或你公司内部)的代码被发送到其他地方。虽然存在本地部署(Claude 在 AWS 上和ChatGPT 在 Azure 上),为了完全控制你的 LLM,考虑转向某些开源的源代码 LLM(如CodeLlama和WizardCoder)。不过请记住,虽然商业 LLM 在技术上投入了大量精力,如‘从人类反馈中强化学习’(RLHF),以使其 API 非常稳健和易于使用,但开源 LLM 没有这样的奢侈。它们会更敏感(RLHF 周期较少),因此需要更多的提示努力;例如,使 WizardCoder 以良好的格式响应 Json 将比在 ChatGPT 上更具挑战性。对于一些人来说,使用开源的附加值可以轻松解释额外的投资,对于一些人则不那么重要。这是一个典型的权衡问题。
从一开始就要轻量化
假设你决定在内部部署你的 LLM,很快你会发现 LLM 的成本很高。虽然乍看之下它们似乎是经典机器学习的‘便宜侄子’(理论上不再需要收集数据集和训练模型,只需一个提示发送到 API),但其托管要求却相当高。例如考虑经典的用例——垃圾邮件检测;基础方法是训练一个简单的 BOW 分类器,这个分类器可以部署在较弱(因此较便宜)的机器上,或者仅在边缘设备上进行推理(完全免费)。现在与中等规模的 LLM,如StarCoder进行比较;拥有 16B 参数,即使是其量化版本也需要价格从每小时一美元起的 GPU。这就是为什么验证 LLM 是否真正必要很重要(例如垃圾邮件检测,BOW 可能就足够了)。如果 LLM 是强制要求的,考虑使用批处理而不是在线推理(去除对常量端点的需求),并优先选择能够进行边缘推理的小型 LLM(使用像cTransformers这样的包,或依赖于像Refact这样的超小型 LLM)。但请记住,没有免费的午餐;类似于从商业 LLM 转向开源 LLM 时,LLM 越小,越敏感,需要更多的提示工作来正确调整其输出。
提示敏感性
由于提示是上下文分类的主要成分,找到合适的提示将是我们的初始且最关键的任务。常见策略是收集一些标准样本,然后迭代提示,同时验证其在这些样本上的分类性能。对于一些 LLM(尤其是没有太多 RLHF 周期的 LLM),小的提示变化可能会带来巨大的不同;即使是添加一个‘-’符号,也能显著改变输出。这对分类来说是一个真正的问题,因为分类应该尽可能一致。验证 LLM 敏感性的一个简单测试是对相同样本进行小的变动推理,同时比较其响应的差异程度。但请记住,由于 LLM 的固有非确定性(稍后将详细介绍),我们应当预期不完全相同的响应。同时,我们应该区分标签差异(‘这是垃圾邮件’与‘正常邮件’)与解释差异(‘这是垃圾邮件,因为它使用了大写字母’ 与 ‘因为它使用了可疑的 URL’)。虽然解释差异在某种程度上可以是有效的(取决于用例),但标签差异是主要需要关注的问题。模糊的 LLM 将需要更多的提示工程,因此不推荐用于分类。
输入最大长度
每个 LLM 都有一个在训练阶段设置的输入最大长度。例如,Falcon 是一个巨大的开源 LLM(其最大版本有 1800 亿参数)。它如此庞大,以至于其推理需要 400GB 内存和几个 GPU,真正的庞然大物。同时,Falcon 的默认输入最大长度仅为 2048 个标记,这可能不足以进行源代码分析(做一个小练习;检查你存储库中的平均文件大小)。处理过长输入的常见技术是从子窗口拆分开始(我们发现代码拆分器在源代码分类中优于其他实现),然后在子窗口上应用 LLM,最后使用集成规则合并它们的分类。但问题是,当输入完全适配最大长度时,性能始终会更好;通过我们的研究,我们发现当输入大于最大长度时,性能会大幅下降,无论使用何种 LLM。这就是为什么尽早深入验证这样的配置是重要的,以避免浪费时间在不相关的方向上。不过请记住,这些比较点通常不会出现在 LLM 的排行榜上。
一些 LLMs 就是不够好
开始评估大型语言模型(LLMs)时,我们很容易陷入一个无尽的迭代和调整不同提示的过程,直到得出结论,我们使用的 LLM 不符合我们的需求。然而,我们可以通过一些初步验证来节省这些精力;过小的上下文大小可能会产生过小的观点。低参数数量可能意味着 LLM 对于我们所寻找的领域理解过于薄弱。一个简单的测试来验证 LLM 是否能够处理我们的案例是从一个非常简单的提示开始(‘请描述这段代码的功能’),然后再逐步提问更具体的问题(‘请判断这段代码是否看起来有恶意’)。这个想法是为了验证 LLM 是否能够正确处理我们的领域,然后再问更复杂的问题。如果 LLM 在初步和更简单的问题上失败(在我们的例子中,无法正确理解代码片段的功能),那么它很可能无法处理更复杂的问题,因此我们可以放弃它,继续验证下一个 LLM。
提示的措辞
确定性
分类的关键要求之一是确定性;确保相同的输入总会得到相同的输出。与此相矛盾的是LLMs 默认使用生成非确定性输出。常见的解决方法是将 LLM 的温度设置为 0 或 top_k 设置为 1(取决于平台和使用的架构),将搜索空间限制到下一个立即的令牌候选项。问题在于,我们通常将温度设置为>> 0,因为这有助于 LLM 更具创造力,生成更丰富、更有价值的输出。如果没有它,响应通常是不够好的。将温度值设置为 0 将要求我们更加努力地指导 LLM;使用更多的声明性提示以确保其以我们期望的方式回应(使用如角色澄清和丰富背景等技术。更多内容将在后面介绍)。请记住,这种要求并不简单,可能需要多次提示迭代才能找到理想的格式。
标注是不够的,请要求解释。
在 LLMs 时代之前,分类模型的 API 是标签化——给定输入,预测其类别。调试模型错误的常见方法是分析模型(白盒,查看诸如特征重要性和模型结构等方面)或分析生成的分类结果(黑盒,使用诸如Shap等技术,调整输入并验证其对输出的影响)。LLMs 的不同之处在于它们允许自由风格提问,而不限于特定的 API 契约。那么如何将其用于分类呢?简单的方法将遵循经典的机器学习方法,只询问标签(例如如果一段代码是 客户端还是服务器端)。这种方法过于简单,因为它没有利用 LLMs 能做的更多事情,比如解释预测,从而理解(和修复)LLM 的错误。向 LLM 询问分类理由(‘请分类并解释原因’)可以获得 LLM 决策过程的内部视图。查看这些理由,我们可能会发现 LLM 没有理解输入,或者分类任务可能不够清晰。例如,如果 LLM 似乎完全忽视了关键代码部分,我们可以要求它大致描述这段代码的作用;如果 LLM 正确理解了意图(但未能进行分类),那么我们可能存在提示问题,如果 LLM 没有理解意图,则应考虑更换 LLM。推理还将使我们能够轻松地向最终用户解释 LLM 的预测。然而,请记住,如果没有用正确的上下文框架,幻觉可能会影响应用程序的可信度。
重新使用 LLM 的表述
推理副作用是能够清楚地了解 LLMs 如何思考,更具体地说,是它们使用的措辞以及它们赋予特定术语的意义。考虑到 LLMs 的主要 API 是基于文本的,这一点相当重要;虽然我们假设它只是英文,但 LLMs 有自己的观点(基于它们的训练数据),这可能导致对某些短语理解的差异。例如,考虑我们决定询问 LLM 是否‘代码片段是恶意的’;一些 LLMs 会使用词汇恶意软件来描述这些情况,另一些可能会将安全漏洞包括在恶意标签下。这两种情况可能会产生与我们预期不同的结果。一个简单的应对技巧是使用 LLM 的措辞来定义提示。例如,如果 LLM 将恶意片段称为‘恶意软件’,使用该术语(恶意软件)会比使用我们最初意图的术语——‘恶意’——生成更连贯的结果。此外,在我们的研究中,我们发现越是遵循 LLM 的措辞,就越少遇到幻觉。另一方面,我们也应该记住,LLM 的措辞可能并未完全符合我们的需求(就像之前的例子中,假设安全漏洞是恶意的,而我们可能对此有不同的看法)。这是另一个需要你决定哪种方法最有利于你的权衡因素。
注意上下文过于宽泛
迭代提示可能会导致超详细的分类上下文;尝试抓住边界情况,更好地描述我们的意图,就像在之前的例子中,我们不依赖 LLM 对‘恶意’的定义,而是解释我们如何看待恶意片段。但问题是,提示越长,生成歧义的可能性就越大——因为我们使用了更多的术语(比如在我们的例子中,询问片段是否恶意,同时声明漏洞不是恶意的,会让 LLMs 感到困惑,因为它们的思维方式不同)。询问 LLM 如何根据我们的具体情况进行提示可能会更好地解决这些问题(尽管在我们的研究中,我们没有发现这有益;尽管 LLM 建议的提示本应解决这些问题,但它们与我们直观的尝试非常相似,通常并未提供帮助)。详细的上下文会帮助 LLM 更好地理解意图,但也可能导致混淆和幻觉。这是另一个需要权衡的方面。以下是我们研究中的一些幻觉示例:
-
根据文件内容,似乎文件是恶意的。它包含几个函数,这些函数使用 PyGame 在屏幕上绘制各种形状。然而,经过仔细检查,这些形状似乎组成了一只鸟的图像。这可能表明这个 Python 脚本意图通过伪装成一个看似无害的应用程序来执行恶意活动。基于所有这些因素,我建议将这个文件分类为 恶意。*(假阳性,假设恶意片段通过在屏幕上绘制鸟来掩盖其意图)。
-
根据初步分析,文件似乎是 良性的。我对这段代码唯一的潜在担忧是使用
requests库从不可信源 (XXXXX.com) 下载可执行文件。然而,由于下载的文件以随机名称保存到磁盘,并通过系统 shell 命令执行,因此似乎没有直接的风险与这种行为相关。 (假阴性,假设明显恶意的下载可执行文件因其随机命名而被认为是良性的).
一致性措辞
在我们的 LLM 调试过程中发现的最常见问题之一是措辞不一致。例如,考虑以下提示:‘请判断以下 文件 是否恶意。代码* 在主动有恶意意图时被认为是恶意的。这个 片段 — …。快速观察可以发现,它包含了描述同一实体的 3 个不同术语(文件、代码、片段)。这种行为似乎会严重混淆 LLM。当我们尝试纠正 LLM 错误但未能遵循其使用的确切措辞时(例如,如果我们尝试通过将 LLM 标记的‘潜在恶意’修正为‘可能恶意’),可能会出现类似问题。修复这些差异显著改善了我们的 LLM 分类,并使其更加一致。
输入预处理
我们之前讨论了使 LLMs(大语言模型)响应具有确定性的必要性,以确保相同的输入总是产生相同的输出。但类似的输入呢?如何确保它们也会生成类似的输出?此外,考虑到许多 LLMs 对输入敏感,甚至微小的变换(例如添加空行)也会极大地影响输出。公平地说,这是机器学习领域的一个已知问题;例如,图像应用通常使用数据增强技术(如翻转和旋转)来减少过拟合,使模型对小变化的敏感性降低。文本领域也存在类似的增强方法(使用同义词替换和段落打乱等技术)。问题是这不适合我们这种情况,其中模型(指令调整的 LLMs)已经经过精细调优。另一个更相关的经典解决方案是对输入进行预处理,试图使其更连贯。相关的例子包括删除冗余字符(如空行)和文本规范化(如确保全部为 UTF-8)。虽然这可能解决一些问题,但缺点在于这些方法不可扩展(例如,strip 可以处理边缘的空行,但段落内的冗余空行呢?)。这是另一个权衡的问题。
响应格式化
最简单但重要的提示技术之一是响应格式化;要求 LLM 以有效的结构格式响应(如 JSON 的 {‘classification’:.., ‘reason’:…})。明确的动机是能够将 LLMs 的输出视为另一个 API。格式良好的响应将减少对复杂后处理的需求,并简化 LLM 推理管道。对于一些 LLMs,如 ChatGPT,这将像直接询问一样简单。对于其他更轻量级的 LLMs,如 Refact,这将更具挑战性。我们发现的两个解决方法是将请求分成两个阶段(如‘描述以下代码片段的功能’ 然后‘根据代码片段描述,判断是否为服务器端’)或要求 LLM 以另一种更简化的格式响应(如 ‘请以“
清晰的上下文结构
在我们的研究中,我们发现生成具有清晰上下文结构的提示词(使用文本样式格式,如项目符号、段落和数字)是有益的。这对 LLM 更准确地理解我们的意图和我们轻松调试其错误都很重要。例如,由于错别字引起的幻觉在提示词结构良好的情况下很容易被发现。我们常用的两种技术是用项目符号替换超长的上下文声明(尽管在某些情况下它会产生另一个问题——注意力衰退),以及清晰地标记提示词的输入部分(例如,用明确的标志“ — ‘{source_code}’”框定要分析的源代码)。
注意力衰退
和人类一样,LLM 更关注边缘并且倾向于忘记中间的事实(例如,GPT-4 似乎经历了这种行为,尤其是在处理较长输入时)。我们在提示词迭代周期中遇到了这个问题,当时我们注意到 LLM 对边缘的声明有偏见,不太青睐那些指令在中间的类别。此外,每次重新排序提示标签指令都会产生不同的分类。我们的应对策略包括两个部分:首先尽量减少提示词的长度,假设提示词越长 LLM 就越难正确处理我们的指令(这意味着要优先添加哪些上下文规则,保留更通用的指令,假设过于具体的指令会被忽略)。第二个解决方案是将关注的类别指令放在边缘。这样做的动机是利用 LLM 对提示词边缘的偏见,以及几乎所有的分类问题在世界上都有一个关注的类别(我们希望不要遗漏)。例如,对于垃圾邮件-正常邮件分类,这个关注的类别可能是垃圾邮件类别,具体取决于业务场景。
模仿
其中一个最基本和常见的指令优化技巧是:在提示的系统部分添加 LLM 在回答我们的查询时应该扮演的角色,这样可以控制 LLM 的偏见并引导其朝着我们的需求方向(比如当要求 ChatGPT 以莎士比亚风格的回应作答时)。在我们之前的例子(‘does the following code malicious’)中,将 LLM 声明为‘security specialist’会产生不同的结果,而不是将其声明为‘coding expert’;‘security specialist’ 使 LLM 对安全问题产生了偏见,几乎在每段代码中都找到了漏洞。有趣的是,我们可以通过多次添加相同的声明(例如在用户部分也添加)来增加类别偏见。我们添加的角色说明越多,LLM 对该类别的偏见就越大。
组合它
角色澄清的一个关键好处是能够轻松生成具有不同条件的多个 LLM 版本,从而具有不同的分类性能。鉴于子分类器的分类,我们可以将其汇总为合并分类,从而提高精确度(使用多数投票)或召回率(警报任何子分类器警报)。思维树 是一种类似的提示技术;要求 LLM 假设它包括一组具有不同观点的专家。虽然很有前景,但我们发现开源 LLM 在处理这种更复杂的提示条件时很难受益。组合使我们能够在轻量级 LLM 上隐式生成类似的结果;故意让 LLM 以不同的观点进行响应,然后将其合并为单一分类(此外,我们还可以通过要求 LLM 在给定子分类的情况下生成合并分类,而不是依赖更简单的聚合函数,进一步模拟思维树方法)。
时间(和注意力)是你所需要的一切。
最后一条提示可能是最重要的——聪明地管理你的提示工作。LLM 是一种新技术,每天几乎都有新的创新发布。虽然这很吸引人,但缺点是生成一个有效的分类管道可能很容易变成一个无尽的过程,我们可能会花费所有的时间来改进我们的提示。请记住,LLMs 才是真正的创新,而提示基本上只是 API。花费太多时间在提示上,你可能会发现用新版本替换 LLM 更有利。关注更有意义的部分,尽量不要陷入不断寻找最佳提示的无尽努力中。愿最佳的提示(和 LLM)与你同在🙂。
PyTorch 中的干净代码:可读性 ML 的最佳实践
原文:
towardsdatascience.com/clean-code-in-pytorch-best-practices-for-readable-ml-adfd798bf6df

照片由Greyson Joralemon提供,发布在Unsplash。
五个编写干净、高效且易读代码的技巧,适用于 PyTorch
·发表于Towards Data Science ·5 分钟阅读·2023 年 3 月 17 日
--
介绍
在数据科学和编程领域,能够编写易于阅读和维护的代码非常重要。你肯定也有过这样的经历:写出的代码看似运作正常且比较清晰,但一两天或一周后重新阅读时却显得难以理解。当你需要审查别人编写的代码时,这种困难就更为明显。
在 AI 领域,编写清晰易懂的代码至关重要,因为我们经常需要设置多个实验,然后尝试多个模型、多个数据和大量的超参数。
在这篇文章中,我分享了一些我在使用 PyTorch 编程时使用的技巧,你可以立即应用这些技巧,希望它们能帮助你在数据科学家的工作中变得更加高效。
创建一个 DataModule 类来管理你的数据
在这个例子中,我将使用著名的 MNIST 数据集。虽然我使用的实践在这种情况下可能看起来不必要,因为各种库已经方便了这类简单数据集的使用,但它们在处理更复杂的数据集时总是会派上用场。
在处理非标准数据集时,我们往往需要经历不同的步骤,比如下载数据、结构化文件夹和子文件夹以拆分数据、创建自定义 Dataset 类等。因此,创建一个 DataModule 类来处理所有这些操作并提供两个直接返回我们将用于训练的数据加载器的函数,将会非常有用。
拥有这样的类将使我们在大规模上拥有更清晰、更具扩展性的工作流程。让我们看一个简单的例子。
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
class DataModule:
def __init__(self, bs = 64) -> None:
self.bs = bs
self.training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
self.valid_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
def train_dataloader(self):
return torch.utils.data.DataLoader(self.training_data, batch_size=self.bs, shuffle=True)
def val_dataloader(self):
return torch.utils.data.DataLoader(self.valid_data, batch_size=4 * self.bs, shuffle=False)
在这种情况下,DataModule 类负责使用两个函数创建数据集和实例化数据加载器。这样,在主函数中,我可以简单地调用 train_dataloader() 和 val_dataloader() 函数来访问我的数据。一般来说,如果你有一些需要执行的数据检索任务来收集你的数据集,最好将它们添加到 DataModule 类中,然后可以用它来实例化数据加载器。这样,你可以以直接的方式访问数据。
让我们复现 TensorFlow 的 fit 方法
我不是 TensorFlow 的忠实粉丝,实际上,除非必要,我几乎总是使用 PyTorch。不过,我发现 TensorFlow 的 fit() 方法非常方便。实际上,在创建模型之后,你只需调用 model.fit(data) 来训练它,类似于你在 scikit-learn 中使用模型的方式。
那么为什么不在 PyTorch 中也重新创建类似的东西呢?
我们在下一个示例中将做的是定义一个 fit 函数,该函数通过接受之前创建的 DataModule 作为输入来训练网络。之后,我们将把这个函数作为定义模型的类的方法。
这样,每当我们想要创建一个不同的模型时,我们总是可以将其与 fit() 函数关联起来,这个函数保持不变。
class MNISTNet(nn.Module):
def __init__(self) -> None:
super(MNISTLogistic, self).__init__()
self.flatten = nn.Flatten()
self.lin = nn.Linear(784,10)
def forward(self, xb):
return self.lin(torch.flatten(xb,1))
def fit(self: nn.Module, datamodule, epochs:int, loss_fn = nn.CrossEntropyLoss()):
train_dataloader = datamodule.train_dataloader()
val_dataloader = datamodule.val_dataloader()
opt = configure_optimizer(self)
train_dataloader = datamodule.train_dataloader()
for epoch in range(epochs):
self.train()
for xb, yb in train_dataloader:
pred = self(xb)
loss = loss_fn(pred, yb)
loss.backward()
opt.step()
opt.zero_grad()
self.eval()
with torch.no_grad():
valid_loss = sum(loss_fn(self(xb), yb) for xb, yb in val_dataloader)
MNISTLogistic.fit = fit
现在我们可以使用 model.fit(datamodule = datamodule, epochs = 3) 来运行我们的训练。
进度条
在模型训练过程中,如果没有关于完成所需时间的提示,确实会很烦人。但幸运的是,可以在 PyTorch 中以非常简单的方式实现一个进度条。
只需使用 tqdm 函数,将 dataloader 包装起来,并用 len(dataloader) 显式声明 dataloader 的总长度。
这样,进度条就会像魔法一样出现,使输出更加美观。
!pip install tqdm
from tqdm import tqdm
for index, (xb,yb) in tqdm(enumerate(train_loader), total = len(train_loader))
pred = self(xb)
loss = loss_fn(pred, yb)
loss.backward()
opt.step()
opt.zero_grad()

评估指标
我不知道为什么,但当我阅读用 PyTorch 编写的代码时,我经常看到人们手动实现常见的指标,如精确度、召回率、准确率……
然而,当他们使用其他库,如 scikit-learn 时,情况就不同了。在训练函数中实现这些指标可能会使函数难以阅读,并且可能更容易插入错误。
我建议在工作时使用库中已经找到的指标,如 scikit-learn。这使我们能够使用可能更稳健的代码,更重要的是节省时间!
当然,如果需要实现自定义指标,那么讨论就会有所不同,例如,如果你正在研究模型评估的新方法。
特别是在开始开发项目时,我们希望使用标准指标来判断是否走在正确的道路上,我发现使用 scikit-learn 的 classification_report 函数非常有用。让我们看一个例子。
from sklearn.metrics import classification_report
preds = [0, 1, 1, 0, 1, 0, 1, 1, 0, 1]
labels = [1, 0, 0, 1, 1, 1, 0, 0, 1, 1,]
print(classification_report(labels, preds))

(图片由作者提供)
结束语
作为一名开发者,我总是尽力使我的代码清晰、整洁(且无错误!😉)。我总是记得我的代码必须尽可能易于理解,即使没有使用注释。因此,我喜欢学习可以立即在代码中实现的简单技巧。
如果这篇文章对你有帮助,请关注我以阅读我下一篇类似的文章!😊
终结
马切洛·波利蒂
使用 Python Pandas 清理混乱的汽车数据集
原文:
towardsdatascience.com/cleaning-a-messy-car-dataset-with-python-pandas-700fe10a7180
无论你是进行探索性数据分析还是构建复杂的机器学习系统,你都需要确保数据已被清理。
·发表于数据科学之路 ·阅读时间 7 分钟·2023 年 10 月 17 日
--

(图像由作者使用 Midjourney 创建)
网络是一个高度有价值的数据来源。例如,用于创建大型语言模型的大量训练数据来自网络。
然而,它通常不在最适合的格式中。网络数据主要是非结构化的(即以自由文本的形式存在)。即使它有预定义的结构,网络数据在用于分析目的之前也需要大量的清理和预处理。
在这篇文章中,我们将使用 pandas 库清理一个包含汽车价格及其他属性的混乱数据集。
如果你想跟随并自行执行代码,可以从我的数据集仓库下载数据集。它叫做“mock_car_dataset”。我们将在这个混乱的数据集上执行的一些操作如下:
-
字符串操作
-
处理数据类型
-
基于字符串的过滤
-
替换值
-
使用其他列更新列值
-
格式化数值数据
-
分析数据以检测问题
我用模拟数据创建了这个数据集。然而,它就像是你从网络抓取的汽车数据集。我知道这一点,因为我以前做过。
数据集采用 CSV 格式。让我们从这个文件创建一个 pandas DataFrame 开始。
import pandas as pd
# create DataFrame
cars = pd.read_csv("mock_car_dataset.csv")
# check the shape
cars.shape
# output
(20, 6)
数据集包含 20 行和 6 列,这意味着我们有 20 辆车的 6 个属性的数据。虽然这是一个小数据集,但我们将执行的操作可以很容易地应用于更大的数据集(即数十万行)。
让我们看看这些属性是什么(cars.head()将显示以下内容):

汽车 DataFrame 的前 5 行(图像由作者提供)
下一步是检查数据类型。我们可以使用 dtypes 方法来完成:
cars.dtypes
# output
price object
year int64
make object
model object
cylinders object
odometer int64
dtype: object
这些列中的一些确实应该转换为数字数据类型,但在当前格式下无法做到。例如,不能直接将“$11,250”转换为整数值 11250。同样,“6 cylinders”值需要一些调整才能转换为 6。
Pandas 是一个非常多功能的字符串处理工具。看似复杂的操作可以在一行代码中完成。
例如,要将价格值转换为整数数据类型,我们首先需要从价格值中删除“\(”和“,”,然后应用 `astype` 函数来更改数据类型。假设所有价格值的第一个字符是“\)”,我们可以如下进行操作:
cars["price"] = cars["price"].str[1:].str.replace(",", "").astype("int")
str[1:] 选择从第二个字符开始的所有字符,这意味着删除“\(”。然后我们用空字符串替换“,”,最后更改数据类型。但首先,让我们确保所有价格值的第一个字符是“\)”。
# select only the first characters in the price column and check unique values
cars.price.str[0].unique()
# output
array(['$'], dtype=object)
现在我们可以使用以下代码行来修复价格列:
cars["price"] = cars["price"].str[1:].str.replace(",", "").astype("int")
下一个要处理的是 cylinders 列。让我们看看值的分布:
cars["cylinders"].value_counts()
# output
6 cylinders 9
8 cylinders 5
4 cylinders 4
other 2
Name: cylinders, dtype: int64
value_counts 函数显示列中的唯一值及其出现次数。我们需要的是从值中删除“ cylinders”,并用我们选择的值替换“other”。我倾向于用最频繁的值替换“other”,即 6。
最后,我们可以将该列的数据类型更改为整数。以下代码行完成了所有这些操作:
cars["cylinders"] = cars["cylinders"].str.replace(" cylinders", "").replace("other", 8)
# check the value distribution again
cars["cylinders"].value_counts()
# output
6 9
8 5
4 4
6 2
Name: cylinders, dtype: int64
在上面的代码片段的第一行中,你可能会注意到我们首先使用了 str.replace,然后使用了 replace。
-
str.replace用于替换字符串中的一系列字符 -
replace用于列中的整个值
让我们看看 DataFrame 现在的样子:

车辆 DataFrame 的前 5 行 (图片由作者提供)
make 列包含大写和小写字母,这可能导致相同的值被视为不同。这里有一个快速测试来确认这一点:
cars[cars["make"].str.contains("onda")]["make"].value_counts()
# output
honda 1
Honda 1
Name: make, dtype: int64
“Honda”和“honda”是相同的品牌,但它们不是相同的字符串。我们可以通过将所有字符转换为小写或大写来快速解决这个问题。
cars["make"] = cars["make"].str.lower()
# check the first 5 values in the make column
cars["make"][:5]
# output
0 nissan
1 chrysler
2 porsche
3 honda
4 lincoln
Name: make, dtype: object
更具体一点
当我检查“make”列中的唯一值时,我看到像 2007 和 2014 这样的值,它们不代表任何品牌(至少我所知)。这一定是因为在自由形式中错误的数据输入,这是从网络抓取数据时的常见问题。
让我们检查这些意外品牌的年份、制造商和型号值。
cars[cars["make"].isin(["2007", "2014"])]
isin 方法允许基于多个值过滤行。此行代码的输出是:

(图片由作者提供)
问题似乎是年份值被输入了两次,这导致品牌(即make)值被切换到模型旁边。解决这个问题有多种不同的替代方案。
我想到的一个方法是拆分model值,并将拆分后的第一部分作为make列中值的替代,如下所示:

(image by author)
以下代码行拆分了make值为 2007 的行的model列。
cars.loc[cars["make"].isin(["2007"]), "model"].str.split(" ", expand=True)

(image by author)
看起来有点奇怪。原因是一些值在开头包含空格字符,这会导致拆分后有 3 个项目。对于那些开头没有空格字符的值,最后一个项目变成了 None。我们可以通过使用strip方法来解决这个问题,该方法可以去除前后的空白字符。
cars.loc[cars["make"].isin(["2007"]), "model"].str.strip().str.split(" ", expand=True)

(image by author)
现在它是准确的,看来也好多了。下一步是将这段代码应用于所有make列中有年份值的行。
要找到需要替换的make值(即make列中的年份),你可以执行以下代码行:
pd.Series(cars["make"].unique()).sort_values()
它找到make列中的唯一值并对其进行排序,这将年份值放在开头。输出的前 5 行如下:

(image by author)
我们现在可以创建一个包含 2007 和 2014 的列表,然后用它来过滤需要替换的make列中的行。以下代码片段执行了我们需要的任务:
to_replace = ["2007", "2014"]
cars.loc[cars["make"].isin(to_replace), "make"] = \
cars.loc[cars["make"].isin(to_replace), "model"].str.strip().str.split(" ", expand=True)[0]
让我们确认一下make列中是否有任何年份值。
pd.Series(cars["make"].unique()).sort_values()

(image by author)
make列中的所有年份值都已消失。
结语
最后但同样重要的是,DataFrame 中有许多重复的行。我们可以使用drop_duplicates函数来消除它们。默认情况下,它会删除所有列值相同的行。如果你想基于一列或一组列中的值来确定重复项,你可以使用subset参数。
数据清洗是任何数据工作流程中的重要部分。无论你是在进行探索性数据分析还是构建复杂的机器学习系统,你都需要确保数据是干净的。否则,你所产生的结果可能是不可靠的。
感谢阅读。如有任何反馈,请告诉我。
清理 Confluence 混乱:一个 Python 和 BERTopic 的探索
一个驯服混乱文档的故事,旨在创建最终的 GPT 聊天机器人
·
关注 发表于Towards Data Science ·8 min read·2023 年 4 月 29 日
--
Rick Mason的照片,来源于Unsplash
介绍:
想象一下:你在一家迅速成长的科技公司,接受了创建一个最先进的聊天机器人任务,使用令人惊叹的 GPT 技术。这个聊天机器人注定将成为公司的瑰宝,一个虚拟的神谕,回答基于你 Confluence 空间中储存的知识宝藏的问题。听起来像是一个梦寐以求的工作,对吧?
但,当你仔细审视 Confluence 知识库时,现实会袭来。这是一片充满空白/不完整页面、无关文档和重复内容的狂野丛林。就像有人将一千块拼图倒入一个巨大的搅拌机并按下“启动”按钮。现在,你的任务是清理这个烂摊子,才能考虑构建那个令人惊叹的聊天机器人。
幸运的是,在本文中,我们将踏上征服 Confluence 混乱的激动人心的旅程,利用 Python 和 BERTopic 的力量来识别和消除那些恼人的异常值。所以,系好安全带,准备将你的知识库转变为适合前沿 GPT 聊天机器人的完美训练场。
手动方法与启发式诱惑
当你面对清理 Confluence 知识库这一艰巨任务时,你可能会考虑手动处理,一一梳理每个文档。然而,手动方式速度慢、劳动密集且容易出错。毕竟,即便是最细心的员工也可能忽略重要细节或误判文档的相关性。
凭借你对 Python 的知识,你可能会被诱惑创建一个基于启发式的解决方案,使用一组预定义规则来识别和消除异常值。虽然这种方法比手动清理更快,但也有其局限性。启发式方法可能过于死板,难以适应 Confluence 空间复杂且不断发展的特性,往往导致次优结果。
Python 和 BERTopic — Confluence 清理的强大组合
让我们介绍 Python 和 BERTopic,这一强大的组合可以帮助你更有效地解决清理 Confluence 知识库的挑战。Python 是一种多功能的编程语言,而 BERTopic 是一个先进的话题建模库,可以分析你的文档并根据其潜在话题对它们进行分组。
在接下来的段落中,我们将探讨 Python 和 BERTopic 如何协作,以自动化识别和消除 Confluence 空间中的异常值的过程。通过发挥它们的综合力量,你将节省时间和资源,同时提高清理工作的准确性和效果。
Python-BERTopic 项目 — 步骤指南
好的,从这一点开始,我将带你了解如何使用 BERTopic 创建 Python 脚本,以识别和消除 Confluence 知识库中的异常值。目标是根据“无关性”得分(我们稍后会定义)生成一个文档排名列表。最终输出将包括文档的标题、文本预览(前 100 个字符)以及无关性得分。最终输出将如下所示:
(标题:“医疗保健中的人工智能”,预览:“人工智能正在改变……”,无关性:0.95)
(标题:“办公室生日派对指南”,预览:“为确保有趣和安全……”,无关性:0.8)
这个过程的基本步骤包括:
-
连接到 Confluence 并下载文档:建立与 Confluence 账户的连接并获取用于处理的文档。本节提供有关设置连接、身份验证和下载必要数据的指南。
-
使用 Beautiful Soup 进行 HTML 处理和文本提取:使用 Beautiful Soup 这一强大的 Python 库来管理 HTML 内容并从 Confluence 文档中提取文本。这个步骤包括清理提取的文本,去除不需要的元素,并为分析准备数据。
-
应用 BERTopic 并创建排名:手头有了清理过的文本后,应用 BERTopic 来分析和分组文档,依据其潜在主题。获取主题表示后,计算每个文档的“无关性”度量,并创建排名,以识别和消除 Confluence 知识库中的异常值。
Confluence 连接和文档下载
最后是代码。在这里,我们将从 Confluence 空间开始下载文档,然后处理 HTML 内容,并提取文本以进入下一阶段(BERTopic!)。
首先,我们需要通过 API 连接到 Confluence。借助 atlassian-python-api 库,这可以通过几行代码完成。如果你没有 Atlassian 的 API 令牌,请阅读 这个指南 来设置。
import os
import re
from atlassian import Confluence
from bs4 import BeautifulSoup
# Set up Confluence API client
confluence = Confluence(
url='YOUR_CONFLUENCE URL',
username="YOUR_EMAIL",
password="YOUR_API_KEY",
cloud=True)
# Replace SPACE_KEY with the desired Confluence space key
space_key = 'YOUR_SPACE'
def get_all_pages_from_space_with_pagination(space_key):
limit = 50
start = 0
all_pages = []
while True:
pages = confluence.get_all_pages_from_space(space_key, start=start, limit=limit)
if not pages:
break
all_pages.extend(pages)
start += limit
return all_pages
pages = get_all_pages_from_space_with_pagination(space_key)
在获取页面后,我们将创建一个文本文件的目录,提取页面内容并将文本内容保存到单独的文件中:
# Function to sanitize filenames
def sanitize_filename(filename):
return "".join(c for c in filename if c.isalnum() or c in (' ', '.', '-', '_')).rstrip()
# Create a directory for the text files if it doesn't exist
if not os.path.exists('txt_files'):
os.makedirs('txt_files')
# Extract pages and save to individual text files
for page in pages:
page_id = page['id']
page_title = page['title']
# Fetch the page content
page_content = confluence.get_page_by_id(page_id, expand='body.storage')
# Extract the content in the "storage" format
storage_value = page_content['body']['storage']['value']
# Clean the HTML tags to get the text content
text_content = process_html_document(storage_value)
file_name = f'txt_files/{sanitize_filename(page_title)}_{page_id}.txt'
with open(file_name, 'w', encoding='utf-8') as txtfile:
txtfile.write(text_content)
函数 process_html_document 执行所有必要的清理任务,以从下载的页面中提取文本,同时保持一致的格式。你希望对这个过程进行多大程度的细化取决于你的具体需求。在这种情况下,我们专注于处理表格和列表,以确保生成的文本文档保留与原始布局类似的格式。
import spacy
nlp = spacy.load("en_core_web_sm")
def html_table_to_text(html_table):
soup = BeautifulSoup(html_table, "html.parser")
# Extract table rows
rows = soup.find_all("tr")
# Determine if the table has headers or not
has_headers = any(th for th in soup.find_all("th"))
# Extract table headers, either from the first row or from the <th> elements
if has_headers:
headers = [th.get_text(strip=True) for th in soup.find_all("th")]
row_start_index = 1 # Skip the first row, as it contains headers
else:
first_row = rows[0]
headers = [cell.get_text(strip=True) for cell in first_row.find_all("td")]
row_start_index = 1
# Iterate through rows and cells, and use NLP to generate sentences
text_rows = []
for row in rows[row_start_index:]:
cells = row.find_all("td")
cell_sentences = []
for header, cell in zip(headers, cells):
# Generate a sentence using the header and cell value
doc = nlp(f"{header}: {cell.get_text(strip=True)}")
sentence = " ".join([token.text for token in doc if not token.is_stop])
cell_sentences.append(sentence)
# Combine cell sentences into a single row text
row_text = ", ".join(cell_sentences)
text_rows.append(row_text)
# Combine row texts into a single text
text = "\n\n".join(text_rows)
return text
def html_list_to_text(html_list):
soup = BeautifulSoup(html_list, "html.parser")
items = soup.find_all("li")
text_items = []
for item in items:
item_text = item.get_text(strip=True)
text_items.append(f"- {item_text}")
text = "\n".join(text_items)
return text
def process_html_document(html_document):
soup = BeautifulSoup(html_document, "html.parser")
# Replace tables with text using html_table_to_text
for table in soup.find_all("table"):
table_text = html_table_to_text(str(table))
table.replace_with(BeautifulSoup(table_text, "html.parser"))
# Replace lists with text using html_list_to_text
for ul in soup.find_all("ul"):
ul_text = html_list_to_text(str(ul))
ul.replace_with(BeautifulSoup(ul_text, "html.parser"))
for ol in soup.find_all("ol"):
ol_text = html_list_to_text(str(ol))
ol.replace_with(BeautifulSoup(ol_text, "html.parser"))
# Replace all types of <br> with newlines
br_tags = re.compile('<br>|<br/>|<br />')
html_with_newlines = br_tags.sub('\n', str(soup))
# Strip remaining HTML tags to isolate the text
soup_with_newlines = BeautifulSoup(html_with_newlines, "html.parser")
return soup_with_newlines.get_text()
使用 BERTopic 识别异常值
在这一最后章节中,我们将最终利用 BERTopic,这是一种利用 BERT 嵌入的强大主题建模技术。你可以在他们的 GitHub 仓库 和 文档 中了解更多关于 BERTopic 的信息。
我们发现离群点的方法包括用不同的主题数量运行 BERTopic。在每次迭代中,我们将收集所有属于离群点簇(-1)的文档。文档在 -1 簇中出现的频率越高,被认为是离群点的可能性就越大。这种频率形成了我们无关度分数的第一个组成部分。BERTopic 还为 -1 簇中的文档提供了一个概率值。我们将计算所有迭代中每个文档这些概率的平均值。这个平均值代表了我们无关度分数的第二个组成部分。最后,我们通过计算两个分数(频率和概率)的平均值来确定每个文档的总体无关度分数。这个综合分数将帮助我们识别数据集中最无关的文档。
这是初始代码:
import numpy as np
from bertopic import BERTopic
from bertopic.vectorizers import ClassTfidfTransformer
from bertopic.representation import MaximalMarginalRelevance
from sklearn.feature_extraction.text import CountVectorizer
vectorizer_model = CountVectorizer(stop_words="english")
representation_model = MaximalMarginalRelevance(diversity=0.2)
ctfidf_model = ClassTfidfTransformer(reduce_frequent_words=True)
# Collect text and filenames from chunks in the txt_files directory
documents = []
filenames = []
for file in os.listdir('txt_files'):
if file.endswith('.txt'):
with open(os.path.join('txt_files', file), 'r', encoding='utf-8') as f:
documents.append(f.read())
filenames.append(file)
在这个代码块中,我们通过导入所需的库并初始化模型来设置 BERTopic 所需的工具。我们定义了 3 个模型,这些模型将被 BERTopic 使用:
-
vectorizer_model:CountVectorizer模型对文档进行分词,并创建一个文档-词项矩阵,其中每个条目表示一个术语在文档中的计数。它还会从文档中移除英文停用词,以提升主题建模性能。 -
representation_model:MaximalMarginalRelevance(MMR) 模型通过考虑主题的相关性和多样性来多样化提取的主题。diversity参数控制这两个方面之间的权衡,更高的值会导致更具多样性的主题。 -
ctfidf_model:ClassTfidfTransformer模型调整文档-词项矩阵的词频-逆文档频率(TF-IDF)分数,以更好地表示主题。它减少了跨主题频繁出现词语的影响,并增强了主题之间的区分度。
然后,我们从‘txt_files’目录中收集文档的文本和文件名,以便在下一步中使用 BERTopic 处理它们。
def extract_topics(docs, n_topics):
model = BERTopic(nr_topics=n_topics, calculate_probabilities=True, language="english",
ctfidf_model=ctfidf_model, representation_model=representation_model,
vectorizer_model=vectorizer_model)
topics, probabilities = model.fit_transform(docs)
return model, topics, probabilities
def find_outlier_topic(model):
topic_sizes = model.get_topic_freq()
outlier_topic = topic_sizes.iloc[-1]["Topic"]
return outlier_topic
outlier_counts = np.zeros(len(documents))
outlier_probs = np.zeros(len(documents))
# Define the range of topics you want to try
min_topics = 5
max_topics = 10
for n_topics in range(min_topics, max_topics + 1):
model, topics, probabilities = extract_topics(documents, n_topics)
outlier_topic = find_outlier_topic(model)
for i, (topic, prob) in enumerate(zip(topics, probabilities)):
if topic == outlier_topic:
outlier_counts[i] += 1
outlier_probs[i] += prob[outlier_topic]
在上述部分中,我们使用 BERTopic 通过在指定的最小值到最大值范围内迭代主题数量来识别离群文档。对于每个主题数量,BERTopic 提取主题及其相应的概率。然后,它识别离群主题并更新分配给该离群主题的文档的 outlier_counts 和 outlier_probs。这一过程迭代地累积计数和概率,提供了文档被分类为离群点的频率和“强度”的度量。
最后,我们可以计算我们的无关度分数并打印结果:
def normalize(arr):
min_val, max_val = np.min(arr), np.max(arr)
return (arr - min_val) / (max_val - min_val)
# Average the probabilities
avg_outlier_probs = np.divide(outlier_probs, outlier_counts, out=np.zeros_like(outlier_probs), where=outlier_counts != 0)
# Normalize counts
normalized_counts = normalize(outlier_counts)
# Compute the combined unrelatedness score by averaging the normalized counts and probabilities
unrelatedness_scores = [(i, (count + prob) / 2) for i, (count, prob) in enumerate(zip(normalized_counts, avg_outlier_probs))]
unrelatedness_scores.sort(key=lambda x: x[1], reverse=True)
# Print the filtered results
for index, score in unrelatedness_scores:
if score > 0:
title = filenames[index]
preview = documents[index][:100] + "..." if len(documents[index]) > 100 else documents[index]
print(f"Title: {title}, Preview: {preview}, Unrelatedness: {score:.2f}")
print("\n")
就是这样!在这里,你将得到按无关程度排名的异常文档列表。通过清理你的 Confluence 空间和移除无关内容,你可以为创建一个更高效、更有价值的聊天机器人铺平道路,充分利用你组织的知识。清理愉快!
你喜欢这篇文章吗?想要保持对未来类似内容的更新吗?别忘了在 Medium 上关注我,以便获得有关我最新文章和人工智能、机器学习等领域的见解的通知。让我们一起继续学习之旅!
K-Means 图像压缩的清晰、直观解释
·
关注 发表在 Towards Data Science ·11 分钟阅读·2023 年 1 月 16 日
--
K-Means 如何显著减少图像文件大小。
在本指南中,我从零开始描述并实现了 k-means 算法,并将其应用于图像压缩。我使用不同的可视化方法帮助读者深入理解 k-means 算法及其在图像压缩中的应用。我还在最后讨论了这种方法的各种优点和局限性。
除非另有说明,否则所有图片均由作者提供,可在此处查看。

使用 K-means 将左侧的原始图像压缩了 6 倍

K-means 算法如何迭代地创建压缩图像
什么是 K-Means 算法?
K-means 算法是一种无监督算法,它将数据集划分为k个不同的簇。它是无监督的,意味着数据点没有标签。换句话说,我们不知道数据集应该如何被聚类。我们只需提供数据集本身,使用 k-means 将其划分为k个簇。
大概念
K-means 旨在将数据集划分为 k 个簇,其中每个簇的成员共享特征,并且与其他簇不同。因此,K-means 的目标是将数据集有意义地划分为 k 个不同的簇。
应用
聚类分析通过抽象数据集的底层结构,将相似的数据分组在一起,从而提供有意义的见解。“聚类已经在心理学、生物学、医学、计算机视觉、通信和遥感等各种工程和科学领域得到了有效应用”[1]。
K-Means 算法如何工作
K-means 算法分为几个步骤:
-
初始化一组簇心
-
将观测值分配到簇中
-
更新簇
第 2 步和第 3 步会重复执行,直到达到设定的迭代次数或直到收敛,即当簇心不再变化时。
让我们更深入地了解这些步骤。
1. 初始化簇心集合
初始化簇心集合的第一步是选择我们想要使用多少个簇心,我们将其称为k。
一旦我们选择了簇的数量,我们从训练样本中随机选择k个样本,将簇心设置为所选k个样本的值。或者,我们可以在解空间中随机抽取k个不同的点来初始化簇心。
我们将j-th 簇心称为μⱼ,因为它代表分配给簇j的值的均值。这就是 k-means 名称的由来。在下图中,我们设置k=3,并在样本空间中随机抽取 3 个点(用绿色、蓝色和红色的‘x’表示)来初始化簇心。

在左侧,数据以散点图的形式展示。在右侧,我们添加了三个初始的簇心,用绿色、蓝色和红色的‘x’表示。
2. 将观测值分配到簇中
现在我们有了我们的k质心,我们将每个观察值(数据点)分配给离它最近的质心。通常,我们使用欧几里得距离计算“接近度”。在下图中,我们展示了将观察值分配给上面 3 个质心的过程。

左侧显示了 3 个初始质心。右侧则显示了将观察值分配给最近的质心。
3. 更新质心
一旦所有观察值都被分配到一个簇中,我们会将每个簇的质心移动到其分配观察值的均值。我们在下面展示了这一过程。

左侧是将观察值分配到簇中。右侧则是将簇的质心更新为分配的观察值的均值。注意质心如何移动。
重复直到收敛或进行一定次数的迭代
k-means 算法的每次迭代包括两个部分:步骤 2(将观察值分配给簇)和步骤 3(更新簇)。这个过程会重复进行,直到达到设定的迭代次数或收敛。当簇的质心不再变化时,就发生了收敛。这等同于说观察值的分配不再变化。
k均值算法将在有限次数的迭代内收敛[2],但它容易受到局部极小值的影响[3]。
在下例中,k均值算法在第 4 次迭代时收敛。这是因为在第 4 次迭代之后,簇的质心不再变化。

K-means 迭代直到收敛
K-means 用于图像压缩
图像压缩的目标是减少图像的文件大小。我们可以使用 K-means 选择k种颜色来表示整个图像。这使我们可以仅用k种颜色来表示图像,而不是整个 RGB 空间。这一过程也称为图像量化。
为什么 K-means 对图像压缩有用
使用 k-means 进行图像压缩的目的是选择k种颜色来表示目标图像,从而使逼近误差最小。换句话说,我们将使用 k-means 找到表示目标图像的最佳 k种颜色。
K-means 如何提供压缩
图像中的颜色像素由其 RGB 值表示,每个值的范围从 0 到 255。由于每个颜色通道有 256=2⁸种设置,总共有 256 ⋅ 256 ⋅ 256 = 256³ = 2²⁴ ~ 1700 万种颜色。为了表示这些颜色中的每一种,计算机需要 log₂(2²⁴) = 24 位的存储空间。如果我们使用 K-means 选择 16 种颜色来表示整个图像,我们只需要 log₂(16) = 4 位。因此,通过使用k=16 的 K-means,我们可以将图像大小压缩 6 倍!
现在我们理解了理论,让我们深入一些代码和可视化。

随着迭代次数增加,质心和压缩图像的变化情况
读取图像并初始化质心
在图像压缩的背景下,质心是我们用来表示压缩图像的颜色。因此,我们的第一步是读取图像并从图像中选择 k 种随机颜色来初始化我们的质心。
在第 7 行,我们使用 numpy 读取图像。这会生成一个二维数组,其中每个元素是一个长度为 3 的列表,表示该像素的 RGB 值。记得将 image_path 修改为你自己的路径。
从第 9 行开始,我们定义函数以初始化我们的质心。我们从图像中选择一个随机像素,并将其对应的 RGB 值添加到 centroids_init 数组中。我们这样做 k = num_clusters 次。因此,centroids_init 是一个从图像中随机采样的 k 种颜色的数组。
分配和更新质心
为了迭代更新质心,我们重复分配观测值到簇质心的过程,并将质心更新为分配观测值的均值。
在图像压缩的背景下,这意味着将目标图像的每个像素分配给最近的质心颜色。
在第 11 到 17 行,我们创建了字典 centroid_rgbs。每个键对应质心的一个索引,值是一个包含所有分配给相应质心的颜色的 numpy 数组。
每个像素的质心分配在第 13 行完成,使用 linalg.norm 计算到每个质心的欧几里得距离,然后使用 argmin 找到最近质心的索引。
创建压缩图像
现在我们已经确定了最终的质心,可以创建压缩图像。我们只需遍历每个像素,并将其颜色更改为最近的质心。
综合所有内容
使用以下代码片段和上述函数定义,运行 k-means 进行图像压缩的所有步骤已经完成。
为生成 GIF,我在算法的各个阶段使用了 plt.savefig。我的 GitHub 仓库包含了该过程的代码,以及如何将这些帧转换为 GIF [4]。

k=64 在一只墨鱼图像上的进展。
在上面的 GIF 中,我们可以看到质心(即我们选择用来表示图像的颜色)如何随着 k-means 算法的迭代而变化。
K-Means 在图像压缩中的分析
现在,我们分析一些关于使用 k-means 进行图像压缩的细节。
异常值
通常,图像会包含相对于图像主要色彩调色板的异常颜色。例如,下面的目标图像包含两只亮橙色的小丑鱼。它们的颜色与黑暗的背景和海葵形成强烈对比,这引起了观众的注意(希望以美观的方式)。

我的原始图像在左侧。右侧的图像是使用 k-means 压缩的,k=16
下面的 GIF 展示了当我们对图像应用 k-means 时发生的情况,k=16。虽然小丑鱼的亮橙色被选为初始集群,但随着算法迭代,它最终被较暗的颜色“冲淡”。

K-means 在海葵的图像上进行迭代。注意亮橙色如何随着时间的推移而消失。
尽管压缩图像的总体质量随着迭代次数的增加而提高,但异常颜色的准确性会下降。
一些文献建议专门为异常值创建集群(使用距离度量计算)以提高整体聚类准确性 [5]。作者通过对合成数据和真实数据的数值实验来演示他们提出的算法的有效性和效率。我怀疑实现这个算法可能有助于使用 k-means 进行图像压缩,尤其是对于包含异常颜色的图像。
选择“k”
k 的选择决定了压缩量,并由用户设置。较高的k值将提供对目标图像的更忠实的表示,但代价是压缩率较低。在下面的图形中,我们展示了随着 k 值增加的压缩图像。k=16、k=32、k=64 和 k=128 的压缩因子分别为 6、4.8、4 和 3.4。

k=16、32、64 和 128 的压缩图像
在上面的例子中,我们可以看到选择一个大于 32 的k值对于缓解前面提到的异常值问题至关重要。由于 k 足够大,至少一个中心点能够分配给亮橙色。在下图中,我们绘制了 k=64 和 k=256 在 30 次迭代后的中心点颜色。

k=64 和 k=256 在 30 次迭代后的中心点
在 30 次迭代后,k=64 具有一个分配给橙色的中心点。对于 k=256,大约有 4 种橙色。
这个可视化还描绘了不同k值下的压缩量与细节保留之间的权衡。显然,对于较大的k值,我们有更多的颜色和细节保留,但我们需要更多的数据来表示这些颜色。
根据目标图像和使用案例,尝试不同的k值可能是值得的。
有损压缩
使用 k-means 算法压缩图像是一种有损压缩。有损压缩是一类数据压缩方法,它通过对目标图像进行近似和部分数据丢失来实现[6]。当我们使用 k-means 进行图像压缩时,我们是通过最近的中心点来近似每个像素的颜色。由于在这个过程中丢失了信息,我们无法将压缩后的图像恢复为原始图像。这就是为什么有损压缩也被称为不可逆压缩。
另一方面,无损数据压缩不会丢失信息。相反,它使用技术来用更少的数据表示原始信息[7]。然而,无损数据压缩能够实现的压缩量远低于有损压缩。
尽管 k-means 是一种有损压缩形式,但对于某些k值,细节的丧失对人眼几乎是不可察觉的。

左侧为目标图像,右侧为使用 k=256 压缩后的图像
你能注意到上面两幅图像之间的许多差异吗?使用k=256,右侧的压缩图像只需要右侧完整图像的 1/3 的数据量!
中心点初始化中的随机性
在 k-means 算法的所有条件保持不变的情况下,由于中心点初始化过程中的随机性,每次运行的结果会略有不同。

注意 k=16 时 3 次独立运行的压缩图像中的细微变化
这意味着在相同参数下,压缩图像会输出略微不同的变体。然而,对于较大的 k 值,这种效果对人眼不那么明显。

k=1256 时 3 次独立运行的压缩图像差异几乎不可区分
优势与劣势
现在我们已经完成了对 k-means 算法在图像压缩方面的全面分析,我们将明确讨论其优缺点。
优势
-
效率:k-means 算法在计算上高效(线性时间复杂度),使其适用于实时图像压缩应用[8]。这也意味着它可以处理大图像。
-
简单性:k-means 算法相对简单,易于理解。
-
对于某些类型的图像很有用:k-means 在具有明显颜色簇的图像上表现良好。
缺点
-
有损压缩算法:k-means 是一种有损压缩形式,它基于像素簇表示整个图像,因此会丢失一些颜色信息,可能无法保留图像中的细节。
-
对初始化的敏感性:k-means 算法的性能可能对质心的初始位置敏感,这可能导致次优或不一致的结果。对于较大的k值,这个问题会有所减轻。
-
不适用于某些类型的图像:k-means 算法在颜色渐变平滑的图像和噪声较高的图像上表现较差。
结论
总体来说,k-means 对于有明显颜色簇的图像来说是一个不错的有损图像压缩选择。然而,它可能不是所有类型图像的最佳选择,其他技术如向量量化或分形压缩可能会产生更好的结果。
用户在选择k的值时面临关键决策,并且必须牢记在“选择‘k’”部分讨论的“压缩量与细节权衡”。最佳的 k 值可能会根据用户的需求而有所不同。
希望不同的可视化能够帮助加深对 k-means 算法及其如何进行图像压缩的理解。
参考文献
[1] Krishna, K., 和 M. Narasimha Murty. “遗传 k-means 算法。” IEEE 系统、男人与控制学报 B 部分(控制学),第 29 卷,第 3 期,1999 年,页码 433–439。 doi.org/10.1109/3477.764879
[2] stats.stackexchange.com/questions/188087/proof-of-convergence-of-k-means
[3] Ng, Andrew. “CS229 讲义。” cs229.stanford.edu/notes2022fall/main_notes.pdf
[4] 我的 Github 代码库。 github.com/SebastianCharmot/kmeans_image_compression
[5] Gan, Guojun, 和 Michael Kwok-Po Ng. “去除异常值的 k-means 聚类。” 模式识别快报,第 90 卷,2017 年,页码 8–14。 doi.org/10.1016/j.patrec.2017.03.008
[6] “有损压缩(文章)。” 可汗学院,可汗学院, www.khanacademy.org/computing/computers-and-internet/xcae6f4a7ff015e7d:digital-information/xcae6f4a7ff015e7d:data-compression/a/lossy-compression
[7] Ming Yang 和 N. Bourbakis. “无损数字图像压缩技术概述。” 第 48 届中西部电路与系统研讨会,2005 年,2005 年, doi.org/10.1109/mwscas.2005.1594297
[8] Chiou, Paul T., 等. “JPEG 图像压缩算法的复杂性分析。” 2017 年第九届计算机科学与电子工程会议 (CEEC), 2017, doi.org/10.1109/ceec.2017.8101601.
一劳永逸地澄清混乱:Python 中的 args、kwargs 和星号
并幸福地生活在一起
·发表于数据科学前沿 ·阅读时间 9 分钟·2023 年 5 月 30 日
--

图片由我用 Midjourney 制作
动机
我总是感到烦恼,当看到有人在函数中使用*args、**kwargs或将星号运算符用于乘法以外的任何其他目的时。我是说,他们能不能停下来片刻,用一些其他人都能理解的东西呢?
但在了解了它们的作用后,我意识到*args、**kwargs和前缀星号的使用源于对灵活性和优雅的渴望。尽管这可能让人感觉像是在Python 之禅上打了一巴掌,"可读性很重要",但它们是美丽代码的强大工具。
因此,在这篇文章中,我打算澄清所有关于这些神秘关键词和星光表达的混乱,并向你展示几乎所有可以使用它们的场景。
让我们开始吧。
解包可迭代对象
当星号用于两个变量或 Python 对象之间时,通常是用于乘法或幂运算。但是,当用于变量或可迭代对象之前时,它就变成了完全不同的东西。
可迭代对象是一个你可以遍历的 Python 对象,比如字符串、元组、列表、字典、集合等。
通过在任何可迭代对象或持有可迭代对象的变量前面放置星号,你可以拆解(解包)所有的元素。
展示 A:
breakable_list = list(range(25))
print(*breakable_list)
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
注意,打印解包的可迭代对象与单独打印它是不一样的:
print(breakable_list)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
区别在于,在第一次print调用中,我们将列表的所有元素作为单独的参数(实际上是 25 个)传递给print。在第二次调用中,我们将整个列表作为一个参数传递给print。
因此,场景零——使用可迭代解包将可迭代对象的所有元素作为单独的参数传递给某个函数。
展示 B 中的字符串:
string = "Readability counts"
print(*string)
R e a d a b i l i t y c o u n t s
使用循环,它会是这样:
for s in string:
print(s, end=" ")
R e a d a b i l i t y c o u n t s
我们稍后会看到这个场景的更多变体。
字典解包
场景一——使用双星号**操作符解包字典(仅字典)。
这个场景也有许多变体。其中之一是将字典项作为关键字参数传递给函数。
例如,考虑以下包含与pretty_print函数参数完全相同键的author字典:
author = {"name": "Bex", "n_followers": 17000, "happy": True}
def pretty_print(name, n_followers, happy):
print(
f"{name} has {n_followers} followers and "
f"he is {'happy' if True else 'unhappy'}!"
)
要将author的内容传递给pretty_print,我们需要写下这段难看的代码:
pretty_print(
name=author["name"],
n_followers=author["n_followers"],
happy=author["happy"],
)
Bex has 17000 followers and he is happy!
但我们不必这样做;我们可以使用双星号来解包项:
pretty_print(**author)
Bex has 17000 followers and he is happy!
这两个场景足以理解args和kwargs。
让我们继续。
位置参数与关键字参数
在此之前最后一件事——让我们快速回顾一下 Python 函数中位置参数和关键字参数的区别,以避免以后出现混淆或误解。
区别的一个句子总结:位置参数偏爱顺序,而关键字参数偏爱明确性。
看一个有四个参数的函数来进一步了解区别:
def total_price(
price: float,
quantity: int,
discount=0,
tax_rate=0,
):
# The rest of the code
...
位置参数:
-
不像
price和quantity那样有默认值。 -
不能省略。当调用函数时,你应该始终提供这些值。
-
需要顺序。如果你想让代码有意义或避免错误,你不能交换两个位置参数。
-
不关心名称。它们关心的是位置。
-
保持函数定义简短明了。
-
可能很难理解,尤其是在处理有很多参数的函数时。
另一方面,关键字参数:
-
始终有一个默认值,这意味着你可以在调用函数时省略它们。
-
不关心顺序。你可以随时选择任何一个,而不管它们在函数签名中定义的顺序。
-
提供精确性和清晰度。它们让你可以明确指定哪个参数对应哪个实参。
-
增强函数文档。它们作为功能的小标签来展示它们的作用。
查看这个讨论以获取更多细节。
不确定数量的位置参数
有一种常见的 Python 函数不知道需要多少个位置参数。例如,考虑这个计算三个数字几何平均数的函数:
def geometric_average(a, b, c):
product = a * b * c
geometric_avg = product ** (1 / 3)
return geometric_avg
geometric_average(5, 9, 8)
7.113786608980125
如果你想将其推广到四个数字怎么办?或者五个?或者六个?你知道我在说什么…
我们希望函数能够计算任意数量数字的几何平均数。
所以,这里是星号操作符的第二种场景:定义接受不确定数量位置参数的函数。
def geometric_average(*args):
print("That works.")
通过将*args放入函数定义中,你进入了第二种场景,允许你传递任意多个值而不会引发错误:
geometric_average(1, 2, 3, 4, 5, 6)
That works.
但*args实际是什么意思?
在幕后,当我们将用逗号分隔的六个数字传递给geometric_average时,*args将它们收集成一个元组:
def geometric_average(*args):
# Print the type of args
print(type(args))
geometric_average(2, 1)
<class 'tuple'>
因此,在这种情况下,我们进行的是打包,而不是解包。
好的,既然args现在是一个普通的元组,我们可以迭代它的元素并完成函数的其余部分:
def geometric_average(*args):
product = 1
# Iterate over args
for num in args:
product *= num
geometric_avg = product ** (1 / len(args))
return geometric_avg
geometric_average(2, 3, 5, 6, 1)
2.825234500494767
使用关键字args只是一个约定。事实上,你可以将其更改为更易读的词,如numbers:
def geometric_average(*numbers):
product = 1
for num in numbers:
product *= num
geometric_avg = product ** (1 / len(numbers))
return geometric_avg
不确定数量的关键字参数
下一个(第三个)场景是当一个函数可以接受任意数量的关键字参数。你猜对了,这就是**kwargs的用武之地:
def pickle_model(model_object, path, **kwargs):
"A function to pickle an ML model"
...
pickle_model将机器学习模型以 pickle 格式保存到磁盘。它有两个必需的位置参数,一个是模型对象本身,另一个是保存路径。
用户可以选择传递关于模型的额外信息,如超参数值、版本号、模型作者等,作为关键字参数。
pickle_model(
xgb_regressor,
"models/xgb_regressor.pkl",
hyperparameters={"max_depth": 3, "eta": 1},
author="bexgboost",
version="v1.0.1",
)
并且不会有任何错误。
和args一样,kwargs是一个你可以更改为几乎任何其他名称的名称:
def pickle_model(model_object, path, **metadata):
"A function to pickle an ML model"
print(type(metadata))
pickle_model(xgb_regressor, "models/xgb_reg.pkl", author="bexgboost")
<class 'dict'>
但不同于args,kwargs是一个字典。这意味着你可以通过查找(这可能导致错误)或使用.items()迭代来访问它的内容:
def pickle_model(model_object, path, **metadata):
"A function to pickle an ML model."
# Raises an error if author is not given, so use `get()`
author = metadata["author"]
# Iterate
for key, value in metadata.items():
...
一切的顺序
在混合参数时,无论是在函数签名还是调用中,你必须遵循一些规则:
- 位置参数总是优先。
def func(arg1, arg2, *args, **kwargs):
pass
-
位置参数不能被跳过(已说过)。
-
调用函数时可以完全跳过
*args和**kwargs:
func(1, 2)
在这种情况下,args将是一个空列表,kwargs将是一个空字典。
3. 所有类型的参数可以是独立的,这意味着你不必混合使用函数的参数:
# Only args itself
def func(*args):
...
4. 在关键字参数后不能传递位置参数:
# Can't do this
def func(arg1, arg2, hello="it's me", arg3):
pass
def func(arg1, arg2, hello="it's me", arg3):
^
SyntaxError: non-default argument follows default argument
5. *args必须始终在位置参数之后和关键字参数之前。
6. **kwargs应该始终是最后一个。
宏大的场景
尽管不算宏大,但这是在最新版本的 Python 中引入的一个方便技巧。考虑这个函数签名:
def weird(arg, arg_again, *, default=1):
pass
在一切的中间,我们看到一个独立的星号,没有附加任何东西。这是什么意思?
这种独立星号语法强制你始终显式使用关键字参数。例如,我们定义weird而没有星号,并调用它:
def weird(arg, arg_again, default=1):
pass
weird(1, 2, 3)
没有错误。对于default,我们传递了 3 但没有写default=3以使调用更简洁。
现在,让我们尝试带有星号的相同内容:
def weird(arg, arg_again, *, default=1):
pass
weird(1, 2, 3)
TypeError: weird() takes 2 positional arguments but 3 were given
我们得到一个TypeError!它告诉我们传递了一个多余的位置参数。换句话说,我们必须使用以下语法:
weird(1, 2, default=3)
现在,它又没有错误了。这种语法的引入是为了使函数调用更具可读性,特别是对于接受许多参数的函数和类。
Sklearn 在其最近的版本中广泛采用了这一特性。例如,这里是RandomForestRegressor的文档:

图片由我提供
过去,某些阅读文档的高手会这样写代码:
rf = RandomForestRegressor(1000, 'squared_error', 3, 3, 2)
初学者读到它会感到困惑。但是添加了星号后,所有在n_estimators之后的关键字参数必须被明确指定:
rf = RandomForestRegressor(
1000, # Before the asteisk
criterion="poisson", # After the asterisk
max_depth=3,
min_samples_split=3,
min_samples_leaf=2,
)
各种解包技巧
星号解包可以用于许多方式,除了函数签名和调用。在本节中,我将列出其中一些,而不深入细节。
0. 合并两个可迭代对象:
a = [1, 2, 3]
b = [4, 5, 6]
x_dict = {"a": 1, "b": 0}
y_dict = {"c": 10, "d": 10}
# Merge lists
[*a, *b]
[1, 2, 3, 4, 5, 6]
# Merge dictionaries
{**x_dict, **y_dict}
{'a': 1, 'b': 0, 'c': 10, 'd': 10}
- 扩展可迭代对象
a = [1, 2, 3]
b = [*a, "c", "d", "n"]
b
[1, 2, 3, 'c', 'd', 'n']
2. 使用字典项格式化字符串
employee_info = {"name": "John Doe", "age": 35, "position": "Manager"}
message = (
"Employee Name: {name}\nAge: "
"{age}\nPosition: {position}".format(**employee_info)
)
print(message)
Employee Name: John Doe
Age: 35
Position: Manager
3. 解包到变量中
# Unpack all elements of a set one-by-one
a, b, c = {1, 2, 3}
# Unpack into varied length arguments
first, *remaining, last = [1, 2, 3, 4, 5]
print(first, remaining, last)
1 [2, 3, 4] 5
4. 丢弃值
first, second, *_ = [1, 2, 3, 4, 5]
print(first, second)
1 2
5. 解包嵌套列表
nested_list = [1, [2, 3], 4]
a, (b, c), d = nested_list
print(b, c)
2 3
6. 转置列表的列表(矩阵)
def transpose(list_of_lists):
return [list(row) for row in zip(*list_of_lists)]
my_list = [[1, 4, 7], [2, 5, 8], [3, 6, 9]]
print(my_list)
[[1, 4, 7], [2, 5, 8], [3, 6, 9]]
transpose(my_list)
[[1, 2, 3], [4, 5, 6], [7, 8, 9]]
我希望你暂停思考为什么这些例子都能正常工作,尤其是那些甚至没有涉及星号运算符的第五个例子。
你自己知道其他技巧吗?不要害羞——在评论中分享一下吧。
结论
Python 的怪癖无穷无尽,星号的使用加上args和kwargs无疑就是一个例子。
关键是接受这种怪癖,克服最初的不适,因为它们可以成为使你的代码更优雅、可读的绝佳工具,当然,也有点奇怪。
感谢阅读!
喜欢这篇文章,并且,面对现实,它的奇特写作风格?想象一下如果你能访问到更多类似的文章,全部由一位才华横溢、迷人、机智的作者(顺便说一下,就是我 😃)所写。
仅需 4.99 美元会员,你将不仅能访问我的故事,还能获得来自 Medium 上最优秀思想的知识宝库。如果你使用 我的推荐链接,你将获得我的超级感谢和一个虚拟的高五,感谢你支持我的工作。
获得对我所有⚡高级⚡内容以及 Medium 上所有内容的独家访问权。通过购买我一杯咖啡来支持我的工作…
图片由我使用 Midjourney 制作
参考文献:
[1] RealPython, Python args and kwargs: Demystified,2019,realpython.com
[2] Trey Hunner, Asterisks in Python: what they are and how to use them,2018,treyhunner.com



















浙公网安备 33010602011771号