docs-merge-14
TowardsDataScience 2024 中文翻译(十五)
在浏览器中运行 Rust 的九条规则
从将 range-set-blaze 移植到 WASM 中获得的实践经验
·发布于Towards Data Science ·阅读时长:21 分钟·2024 年 10 月 8 日
--

在浏览器中运行 Rust — 来源:openai.com/dall-e-2/。其他所有图像均来自作者。
你想让你的 Rust 代码无处不在——从大型服务器到网页、机器人,甚至手表吗?在这篇三篇文章中的第二篇[1, 2, 3]中,我将向你展示如何使用 WebAssembly (WASM) 在用户的浏览器中直接运行 Rust 代码。
使用这种技术,你可以通过一个—或许是免费的—静态 web 服务器提供 CPU 密集型、动态网页。作为额外的好处,用户的数据永远不会离开他们的机器,避免了隐私问题。例如,我提供了一个工具,帮助朋友、跑步俱乐部成员和队友查找比赛结果。要查看该工具,请访问其网页,并点击“match”。

旁注:要了解更多关于匹配名称的信息,请参阅使用贝叶斯定理在列表中查找独特名称,《Towards Data Science》中的文章。
在浏览器中运行 Rust 面临一些挑战。你的代码无法像 Linux、Windows 或 macOS 那样访问完整的操作系统。你无法直接访问文件或网络。你仅能有限地访问时间和随机数。我们将探讨一些解决方法和变通方案。
将代码移植到浏览器中的 WASM 需要多个步骤和选择,且这些过程可能耗时。漏掉一步可能会导致失败。我们将通过提供九条规则来简化这一复杂性,接下来我们将详细探讨这些规则:
-
确认你现有的应用程序可以与 WASM WASI 一起工作,并创建一个简单的 JavaScript 网页。
-
安装
wasm32-unknown-unknown目标、wasm-pack、wasm-bindgen-cli以及用于测试的 Chrome 和 Chromedriver。 -
使你的项目成为
cdylib(以及rlib),添加wasm-bindgen依赖,并进行测试。 -
了解
wasm-bindgen支持哪些类型。 -
更改函数以使用支持的类型。将文件更改为通用的
BufRead。 -
适配测试,跳过那些不适用的测试。
-
如有必要,切换到适合 JavaScript 的依赖。运行测试。
-
将你的网页与函数连接起来。
-
将
wasm-pack添加到你的 CI(持续集成)测试中。
旁注:这些文章基于我在蒙特利尔的RustConf24上进行的三小时工作坊。感谢工作坊的参与者。同时也特别感谢来自西雅图 Rust Meetup 的志愿者,他们帮助测试了这些材料。这些文章替代了我去年写的一篇文章,并且包含了更新的信息。
正如在本系列的第一篇文章中所述,在逐条查看规则之前,我们先来定义一下术语。
-
Native:你的本地操作系统(Linux、Windows、macOS)
-
标准库(std):提供 Rust 的核心功能——
Vec、String、文件输入/输出、网络、时间等。 -
WASM:WebAssembly(WASM)是一种二进制指令格式,可以在大多数浏览器中运行(以及其他平台)。
-
WASI:WebAssembly 系统接口(WASI)允许非浏览器环境中的 WASM 访问文件 I/O、网络(尚未实现)和时间处理。
-
no_std:指示 Rust 程序不使用完整的标准库,使其适用于小型嵌入式设备或高度资源受限的环境。
-
alloc:在
no_std环境中提供堆内存分配功能(Vec、String等),这是动态管理内存所必需的。
基于我在[range-set-blaze](https://github.com/CarlKCarlK/range-set-blaze)数据结构项目中的经验,以下是我推荐的决策,逐一描述。为了避免含糊其辞,我将这些决策表达为规则。
规则 1:确认你现有的应用程序可以与 WASM WASI 一起工作,并创建一个简单的 JavaScript 网页。
使你的 Rust 代码能够在浏览器中运行,如果满足两个前提条件,将会更加容易:
-
让你的 Rust 代码在 WASM WASI 中运行。
-
让一些 JavaScript 在浏览器中运行。
对于第一个先决条件,请参阅《在 WASM WASI 上运行 Rust 的九条规则》(来自 Towards Data Science)。那篇文章——本系列的第一篇——详细说明了如何将你的代码从本地操作系统移植到 WASM WASI。通过这个迁移,你将完成一半的工作,接下来就可以在浏览器中运行 WASM 了。

我们希望运行代码的环境就像是一个逐渐收紧约束条件的维恩图。
通过你的测试确认代码在 WASM WASI 上运行:
rustup target add wasm32-wasip1
cargo install wasmtime-cli
cargo test --target wasm32-wasip1
对于第二个先决条件,展示你能创建一些 JavaScript 代码并在浏览器中运行。我建议将这个 index.html 文件添加到你项目的顶层:
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Line Counter</title>
</head>
<body>
<h1>Line Counter</h1>
<input type="file" id="fileInput" />
<p id="lineCount">Lines in file: </p>
<script>
const output = document.getElementById('lineCount');
document.getElementById('fileInput').addEventListener('change', (event) => {
const file = event.target.files[0];
if (!file) { output.innerHTML = ''; return } // No file selected
const reader = new FileReader();
// When the file is fully read
reader.onload = async (e) => {
const content = e.target.result;
const lines = content.split(/\r\n|\n/).length;
output.textContent = `Lines in file: ${lines}`;
};
// Now start to read the file as text
reader.readAsText(file);
});
</script>
</body>
</html>
现在,将此页面提供给你的浏览器。你可以通过编辑器扩展来提供网页。我使用 Live Preview 作为 VS Code 的扩展。或者,你也可以安装并使用独立的 Web 服务器,比如 Simple Html Server:
cargo install simple-http-server
simple-http-server --ip 127.0.0.1 --port 3000 --index
# then open browser to http://127.0.0.1:3000
现在你应该能看到一个网页,在该网页上你可以选择一个文件。页面上的 JavaScript 会统计文件中的行数。

让我们看一下 JavaScript 的关键部分,因为稍后我们会修改它以调用 Rust。
旁注:你必须学习 JavaScript 才能在浏览器中使用 Rust 吗?是,也不是。是的,你需要编写一些简单的 JavaScript 代码。不是,你可能不需要“学习” JavaScript。我发现 ChatGPT 足够强大,可以生成我所需的简单 JavaScript 代码。
- 查看用户选择了哪个文件。如果没有选择文件,则返回:
const file = event.target.files[0];
if (!file) { output.innerHTML = ''; return } // No file selected
- 创建一个新的
FileReader对象,进行一些设置,然后以文本形式读取文件:
const reader = new FileReader();
// ... some setup ...
// Now start to read the file as text
reader.readAsText(file);
- 这是设置步骤。它说:等待文件完全读取,将其内容作为字符串读取,按行拆分字符串,然后显示行数。
// When the file is fully read
reader.onload = async (e) => {
const content = e.target.result;
const lines = content.split(/\r\n|\n/).length;
output.textContent = `Lines in file: ${lines}`;
};
满足先决条件后,我们接下来安装所需的 WASM-in-the-Browser 工具。
规则 2:安装 wasm32-unknown-unknown 目标、wasm-pack、wasm-bindgen-cli,以及用于测试的 Chrome 和 Chromedriver。
我们从简单的开始,安装这三个工具:
rustup target add wasm32-unknown-unknown
cargo install wasm-pack --force
cargo install wasm-bindgen-cli --force
第一行安装了一个新的目标,wasm32-unknown-unknown。这个目标将 Rust 编译为 WebAssembly,而不对代码将运行的环境做任何假设。没有假设使得它适合在浏览器中运行。(关于目标的更多信息,请参见上一篇文章中的规则 #2。)
接下来的两行安装 wasm-pack 和 wasm-bindgen-cli,这两个命令行工具。第一个用于构建、打包和发布,生成适用于网页使用的格式。第二个则简化了测试过程。我们使用 --force 来确保工具是最新的并且互相兼容。
现在,我们进入了麻烦的部分,安装测试用的 Chrome 和 Chromedriver。测试用的 Chrome 是一个可自动化的 Chrome 浏览器版本,而 Chromedriver 是一个独立的程序,可以将你的 Rust 测试用例运行在测试用的 Chrome 中。
为什么安装它们很麻烦?首先,过程有些复杂。其次,测试用的 Chrome 版本必须与 Chromedriver 版本匹配。第三,安装测试用的 Chrome 会与当前安装的常规 Chrome 冲突。
在了解这些背景信息后,以下是我的建议。从将这两个程序安装到你主目录的专用子文件夹开始。
- Linux 和 WSL(Windows 子系统 Linux):
cd ~
mkdir -p ~/.chrome-for-testing
cd .chrome-for-testing/
wget https://storage.googleapis.com/chrome-for-testing-public/129.0.6668.70/linux64/chrome-linux64.zip
wget https://storage.googleapis.com/chrome-for-testing-public/129.0.6668.70/linux64/chromedriver-linux64.zip
unzip chrome-linux64.zip
unzip chromedriver-linux64.zip
- Windows(PowerShell):
New-Item -Path $HOME -Name ".chrome-for-testing" -ItemType "Directory"
Set-Location -Path $HOME\.chrome-for-testing
bitsadmin /transfer "ChromeDownload" https://storage.googleapis.com/chrome-for-testing-public/129.0.6668.70/win64/chrome-win64.zip $HOME\.chrome-for-testing\chrome-win64.zip
bitsadmin /transfer "ChromeDriverDownload" https://storage.googleapis.com/chrome-for-testing-public/129.0.6668.70/win64/chromedriver-win64.zip $HOME\.chrome-for-testing\chromedriver-win64.zip
Expand-Archive -Path "$HOME\.chrome-for-testing\chrome-win64.zip" -DestinationPath "$HOME\.chrome-for-testing"
Expand-Archive -Path "$HOME\.chrome-for-testing\chromedriver-win64.zip" -DestinationPath "$HOME\.chrome-for-testing"
旁白:抱歉,我还没有测试任何 Mac 系统的安装说明。请查看 Chrome for Testing 网页,然后尝试适配 Linux 的方法。如果你告诉我什么方法有效,我将更新这一部分内容。
这将安装 129.0.6668.70 版本,这是截至 2024 年 9 月 30 日的稳定版本。如果你愿意,可以查看 Chrome for Testing 可用性 页面,查看最新的稳定版本。
接下来,我们需要将这些程序添加到 PATH 中。我们可以临时添加它们,仅对当前的终端会话有效:
- Linux 和 WSL(仅限本次会话):
export PATH=~/.chrome-for-testing/chrome-linux64:~/.chrome-for-testing/chromedriver-linux64:$PATH
- Windows(仅限本次会话):
# PowerShell
$env:PATH = "$HOME\.chrome-for-testing\chrome-win64;$HOME\.chrome-for-testing\chromedriver-win64;$PATH"
# or, CMD
set PATH=%USERPROFILE%\.chrome-for-testing\chrome-win64;%USERPROFILE%\.chrome-for-testing\chromedriver-win64;%PATH%
另外,我们可以将它们永久添加到我们的 PATH 中,适用于所有未来的终端会话。请理解,这可能会干扰你访问常规版本的 Chrome。
Linux 和 WSL(然后重新启动终端):
echo 'export PATH=~/.chrome-for-testing/chrome-linux64:~/.chrome-for-testing/chromedriver-linux64:$PATH' >> ~/.bashrc
Windows(PowerShell,然后重新启动终端):
[System.Environment]::SetEnvironmentVariable("Path", "$HOME\.chrome-for-testing\chrome-win64;$HOME\.chrome-for-testing\chromedriver-win64;" + $env:PATH, [System.EnvironmentVariableTarget]::User)
安装完成后,你可以使用以下命令验证安装是否成功:
chromedriver --version
旁白:你可以跳过安装并使用测试用的 Chrome 和 Chromedriver 吗?可以,也不可以。如果跳过它们,你仍然能够从 Rust 创建 WASM。此外,你还可以在网页中通过 JavaScript 调用这个 WASM。
然而,你的项目——就像所有优秀的代码一样——应该已经包含了测试。如果跳过测试用的 Chrome,你将无法运行浏览器中的 WASM 测试用例。此外,浏览器中的 WASM 违反了 Rust 的 “如果它能编译,它就能工作” 原则。具体来说,如果你使用了不支持的功能,如文件访问,编译成 WASM 时不会捕获错误。只有测试用例才能捕捉到这种错误。因此,运行测试用例是至关重要的。
现在我们有了在浏览器中运行测试的工具,让我们尝试(几乎可以肯定会失败)运行这些测试。
规则 3:将项目设置为 cdylib(和 rlib),添加 wasm-bindgen 依赖,并进行测试。
wasm-bindgen 包是一组自动生成的 Rust 和 JavaScript 之间的绑定,允许 JavaScript 调用 Rust。
为了在浏览器中为 WASM 准备代码,你需要将项目设置为库项目。此外,你还需要添加并使用 wasm-bindgen 依赖。按照以下步骤操作:
-
如果你的项目是可执行文件,将其改为库项目,通过将
src/main.rs重命名为src/lib.rs。同时,注释掉你的main函数。 -
让你的项目同时创建静态库(默认)和动态库(WASM 所需)。具体来说,编辑
Cargo.toml以包含:
[lib]
crate-type = ["cdylib", "rlib"]
- 添加
wasm-bindgen依赖:
cargo add wasm-bindgen
cargo add wasm-bindgen-test --dev
- 创建或更新
.cargo/config.toml(不要与Cargo.toml混淆),并包括:
[target.wasm32-unknown-unknown]
runner = "wasm-bindgen-test-runner"
接下来,哪些函数你希望能在 JavaScript 中可见?用#[wasm_bindgen]标记这些函数,并将它们设为pub(公共)。在函数文件的顶部,添加use wasm_bindgen::prelude::*;。
旁白:目前,你的函数可能无法编译。我们将在后续规则中解决这个问题。
测试怎么办?在每个#[test]上添加#[wasm_bindgen_test]。在测试需要时,添加以下use语句和配置语句:
use wasm_bindgen_test::wasm_bindgen_test;
wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);
如果你愿意,可以在一个小型示例项目上尝试上述步骤。从 GitHub 安装该示例项目:
# cd to the top of a work directory
git clone --branch native_version --single-branch https://github.com/CarlKCarlK/rustconf24-good-turing.git good-turing
cd good-turing
cargo test
cargo run pg100.txt
在这里,我们看到这些修改应用到了小型示例项目的lib.rs:
// --- May fail to compile for now. ---
use wasm_bindgen::prelude::*;
// ...
#[wasm_bindgen]
pub fn good_turing(file_name: &str) -> Result<(u32, u32), io::Error> {
let reader = BufReader::new(File::open(file_name)?);
// ...
}
// fn main() {
// ...
// }
#[cfg(test)]
mod tests {
use wasm_bindgen_test::wasm_bindgen_test;
wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);
// ...
#[test]
#[wasm_bindgen_test]
fn test_process_file() {
let (prediction, actual) = good_turing("./pg100.txt").unwrap();
// ...
}
}
做了这些修改后,我们准备开始测试(并可能失败):
cargo test --target wasm32-unknown-unknown
在这个示例中,编译器抱怨浏览器中的 WASM 不喜欢返回元组类型,这里是(u32, u32)。它还抱怨不喜欢返回包含io::Error的Result。要解决这些问题,我们需要了解浏览器中的 WASM 支持哪些类型。这正是规则 4 的主题。
修复类型问题并能够运行测试后会发生什么?测试仍然会失败,但这时是运行时错误。浏览器中的 WASM 不支持从文件中读取。然而,示例测试尝试从文件中读取。在规则 5 中,我们将讨论针对类型限制和文件访问限制的解决方法。
规则 4:了解wasm-bindgen支持哪些类型。
JavaScript 可以看到的 Rust 函数必须具有wasm-bindgen支持的输入和输出类型。使用不支持的类型会导致编译器错误。例如,传递一个u32是可以的。但传递一个元组(u32, 32)则不行。
更一般来说,我们可以将 Rust 类型分为三类:“是的!”,“不是!”和“避免”。
是的!
这是 Rust 类型的类别,JavaScript(通过wasm-bindgen)能够很好理解。
我们从Rust 的简单复制类型开始:

有两个项目让我感到意外。首先,64 位整数在 JavaScript 端需要额外的工作。具体来说,它们需要使用 JavaScript 的[BigInt](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/BigInt)类。其次,JavaScript 不支持 128 位整数。128 位整数属于“不是!”类别。
现在转向与字符串和向量相关的类型:

这些超有用的类型使用堆分配的内存。由于 Rust 和 JavaScript 的内存管理方式不同,每种语言都会创建一份数据的副本。我曾以为通过从 JavaScript 向 Rust 传递&mut [u8](可变字节切片)可以避免这种分配。但这并没有奏效。它并没有零次拷贝或一次拷贝,而是拷贝了两次。
顺便提一下,对于
String和&str,wasm-bindgen还在 JavaScript 的 UTF-16 Unicode 编码和 Rust 的 UTF-8 Unicode 编码之间进行转换。
接下来,在 Rust 中我们喜欢我们的Option 和 Result 类型。我很高兴地报告它们属于“Yeps”。

Rust 的Some(3)变成 JavaScript 的3,而 Rust 的None变成 JavaScript 的null。换句话说,wasm-bindgen将 Rust 的类型安全的空值处理转换为 JavaScript 的传统方式。在两种情况下,null/None都按照各自语言的习惯方式处理。
Rust 的Result与Option的行为类似。Rust 的Ok(3)变成 JavaScript 的3,而 Rust 的Err("Some error message")变成 JavaScript 的异常,可以通过try/catch捕获。请注意,Rust 的Err中的值只能是实现了Into<JsValue>特征的类型。通常使用String会很有效。
最后,让我们看看struct、enum 和 JSValue,我们最后一组“Yeps”:

激动人心的是,JavaScript 可以构建并调用 Rust 结构体上的方法。为了实现这一点,你需要使用#[wasm_bindgen]标记结构体和任何可以由 JavaScript 访问的方法。
举个例子,假设你想避免将一个巨大的字符串从 JavaScript 传递到 Rust。你可以定义一个 Rust 结构体,按序处理一系列字符串。JavaScript 可以构建该结构体,将文件的块传入其中,然后请求结果。
JavaScript 处理 Rust 枚举的方式不太令人兴奋。它只能处理没有关联数据(类似 C 枚举)的枚举,并将其值视为整数。
在兴奋感的中间位置,你可以将不透明的 JavaScript 值作为JsValue传递给 Rust。Rust 随后可以动态检查该值,确定其子类型,或者——如果适用——调用其方法。
这就结束了“Yeps”部分。现在来看一下“Nopes”部分。
不行!
这是 Rust 类型的类别,JavaScript(通过wasm-bindgen)无法处理。

例如,不能通过引用传递&u8是可以接受的,因为你可以直接使用u8,而且它可能更高效。
不能返回字符串切片(&str)或常规切片(&[u8])有些令人烦恼。为了避免生命周期问题,你必须返回一个拥有所有权的类型,如String或Vec<u8>。
你不能接受可变的String引用(&mut String)。不过,你可以接受一个String的值,通过修改它后再返回修改后的String。
我们如何解决“不可行”的问题?可以用向量(Vec<T>)或结构体来替代固定长度数组、元组和 128 位整数。
Rust 有集合和映射。JavaScript 也有集合和映射。然而,wasm-bindgen 库不会自动在它们之间转换。那么,如何将 Rust 中的 HashSet 传递给 JavaScript 呢?将其包装在你自己的 Rust 结构体中并定义所需的方法。然后,使用 #[wasm-bindgen] 标记该结构体和这些方法。
现在是我们的第三类。
避免
这是 Rust 类型的类别,JavaScript(通过 wasm-bindgen)允许使用,但你不应该使用它们。

避免使用 usize 和 isize,因为大多数人会认为它们是 64 位整数,但在 WebAssembly(WASM)中,它们是 32 位整数。相反,使用 u32、i32、u64 或 i64。
在 Rust 中,char 是一个特殊的 u32,只能包含有效的 Unicode 标量值。相比之下,JavaScript 将 char 视为字符串。它检查 Unicode 有效性,但不强制字符串的长度为 1。如果你需要将 char 从 JavaScript 传递到 Rust,最好使用 String 类型,然后在 Rust 端检查长度。
规则 5:更改函数以使用支持的类型。将文件更改为通用的 BufRead。
通过了解 wasm-bindgen 支持的类型,我们可以修正我们希望暴露给 JavaScript 的函数。我们将规则 3 的示例函数保留为如下所示:
#[wasm_bindgen]
pub fn good_turing(file_name: &str) -> Result<(u32, u32), io::Error> {
let reader = BufReader::new(File::open(file_name)?);
// ...
}
我们现在通过移除 #[wasm_bindgen] pub 来更改函数。我们还将函数修改为从通用读取器中读取,而不是文件名。使用 BufRead 可以提供更多的灵活性,使得该函数能够接受不同类型的输入流,例如内存数据或文件。
fn good_turing<R: BufRead>(reader: R) -> Result<(u32, u32), io::Error> {
// delete: let reader = BufReader::new(File::open(file_name)?);
// ...
}
JavaScript 无法看到这个函数,因此我们创建一个包装函数来调用它。例如:
#[wasm_bindgen]
pub fn good_turing_byte_slice(data: &[u8]) -> Result<Vec<u32>, String> {
let reader = BufReader::new(data);
match good_turing(reader) {
Ok((prediction, actual)) => Ok(vec![prediction, actual]),
Err(e) => Err(format!("Error processing data: {e}")),
}
}
这个包装函数接受一个字节切片(&[u8])作为输入,这是 JavaScript 可以传递的内容。该函数将字节切片转换为一个读取器,并调用内部的 good_turing 函数。内部函数返回一个 Result<(u32, u32), io::Error>。包装函数将此结果转换为 Result<Vec<u32>, String>,这是 JavaScript 可以接受的类型。
一般来说,我只愿意对既能原生运行又能在浏览器中以 WASM 运行的函数做一些小修改。例如,在这里,我愿意将函数改为处理通用读取器,而不是文件名。当 JavaScript 兼容性需要进行重大且不符合惯例的更改时,我会创建一个包装函数。
在这个示例中,在做出这些更改后,主代码现在可以编译。然而,原始测试仍然无法编译。修复测试是规则 6 的主题。
规则 6:调整测试,跳过不适用的部分。
规则 3 提倡将每个常规测试(#[test])同时标记为 WASM 浏览器测试(#[wasm_bindgen_test])。然而,由于 WASM 在访问系统资源(如文件)方面的限制,并非所有来自原生 Rust 的测试都可以在 WebAssembly 环境中运行。
在我们的示例中,规则 3 给出的测试代码无法编译:
#[cfg(test)]
mod tests {
use super::*;
use wasm_bindgen_test::wasm_bindgen_test;
wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);
#[test]
#[wasm_bindgen_test]
fn test_process_file() {
let (prediction, actual) = good_turing("./pg100.txt").unwrap();
assert_eq!(prediction, 10223);
assert_eq!(actual, 7967);
}
}
这个测试代码失败的原因是我们更新后的 good_turing 函数需要一个通用的读取器,而不是文件名。我们可以通过从示例文件创建一个读取器来修复这个测试:
use std::fs::File;
#[test]
fn test_process_file() {
let reader = BufReader::new(File::open("pg100.txt").unwrap());
let (prediction, actual) = good_turing(reader).unwrap();
assert_eq!(prediction, 10223);
assert_eq!(actual, 7967);
}
这是一个不错的原生测试。不幸的是,我们无法将其作为 WASM 浏览器测试运行,因为它使用了文件读取器——这是 WASM 不支持的功能。
解决方案是创建一个额外的测试:
#[test]
#[wasm_bindgen_test]
fn test_good_turing_byte_slice() {
let data = include_bytes!("../pg100.txt");
let result = good_turing_byte_slice(data).unwrap();
assert_eq!(result, vec![10223, 7967]);
}
在编译时,这个测试使用宏 include_bytes! 将一个文件转换为 WASM 兼容的字节切片。good_turing_byte_slice 函数将字节切片转换为读取器,并调用 good_turing。 (include_bytes 宏是Rust 标准库的一部分,因此可以用于测试。)
请注意,额外的测试既是常规测试,又是 WASM 浏览器测试。我们尽量让我们的测试两者兼顾。
在我的 range-set-blaze 项目中,我几乎能够将所有测试都标记为常规测试和 WASM 浏览器测试。唯一的例外是:有一个测试使用了 Criterion 基准测试函数。Criterion 无法在 WASM 浏览器中运行,因此我仅将该测试标记为常规测试(#[test])。
在修复了我们的主代码(规则 5)和测试代码(规则 6)之后,我们能实际运行我们的测试吗?不一定,我们可能需要找到适合 JavaScript 的依赖。
附注:如果你使用的是 Windows 并运行 WASM 浏览器测试,你可能会看到“
ERROR tiny_http] Error accepting new client: A blocking operation was interrupted by a call to WSACancelBlockingCall. (os error 10004)”。这与您的测试无关,可以忽略它。
规则 7:如有必要,改用适合 JavaScript 的依赖项。运行测试。
依赖项
示例项目现在可以编译。然而,在我的 range-set-blaze 项目中,仅修复我的代码和测试还不够。我还需要修复一些依赖项。具体来说,我需要在 Cargo.toml 中添加如下内容:
[target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dev-dependencies]
getrandom = { version = "0.2", features = ["js"] }
web-time = "1.1.0"
这两个依赖使得随机数生成和提供了一个替代的时间库。默认情况下,浏览器中的 WASM 无法访问随机数或时间。这两个依赖封装了 JavaScript 函数,使得它们对 Rust 既可访问又符合惯用法。
附注:有关在 *Cargo.toml* 中使用 *cfg* 表达式的更多信息,请参见我的文章: 九个 Rust Cargo.toml 的陷阱与误区:掌握 Cargo.toml 格式规则,避免挫折 | Towards Data Science (medium.com).
在 WebAssembly — 分类 — crates.io 中查找其他类似的 JavaScript 包装库。我还没有尝试过,但看起来很有趣的流行库包括:
另请参见 上一篇文章 中的规则 7 — 关于 WASM WASI — 了解更多关于修复依赖问题的内容。在本系列的 下一篇文章 中 — 关于 no_std 和嵌入式 — 我们将深入探讨更多修复依赖的策略。
运行测试
在修复了依赖问题后,我们终于可以运行我们的测试,包括常规测试和浏览器中的 WASM 测试:
cargo test
cargo test --target wasm32-unknown-unknown
记住,在幕后,我们对 cargo test --target wasm32-unknown-unknown 的调用:
-
查看
.cargo/config.toml并找到wasm-bindgen-test-runner(规则 3)。 -
调用
wasm-bindgen-test-runner。 -
使用 Chromedriver 在 Chrome 中运行我们的测试以进行测试。(规则 2,确保 Chrome for Testing 和 Chromedriver 已添加到你的路径中)。
在我们的测试成功后,我们现在准备从网页调用我们的 Rust 代码。
规则 8:将你的网页与你的函数连接起来。
要从网页调用你的 Rust 函数,你必须首先为网页打包你的 Rust 库。在规则 2 中,我们安装了 wasm-pack。现在,我们运行它:
wasm-pack build --target web
这将编译你的项目,并创建一个 JavaScript 能够理解的 pkg 输出目录。
示例
在规则 1 中,我们创建了一个没有调用 Rust 的 index.html 文件。现在让我们将其修改为调用 Rust。以下是一个示例的 index.html,并附有相关变化的描述。
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Good-Turing Estimation</title>
</head>
<body>
<h1>Good-Turing Estimation</h1>
<input type="file" id="fileInput" />
<p id="lineCount"></p>
<script type="module">
import init, { good_turing_byte_slice } from './pkg/good_turing.js'; // These files are generated by `wasm-pack build --target web`
const output = document.getElementById('lineCount');
document.getElementById('fileInput').addEventListener('change', (event) => {
const file = event.target.files[0];
if (!file) { output.innerHTML = ''; return } // No file selected
const reader = new FileReader();
// When the file is fully read
reader.onload = async (e) => {
await init(); // Ensure 'good_turing_byte_slice' is ready
// View the memory buffer as a Uint8Array
const u8array = new Uint8Array(e.target.result);
try { // Actually run the WASM
const [prediction, actual] = good_turing_byte_slice(u8array);
output.innerHTML =
`Prediction (words that appear exactly once on even lines): ${prediction.toLocaleString()}<br>` +
`Actual distinct words that appear only on odd lines: ${actual.toLocaleString()}`;
} catch (err) { // Or output an error
output.innerHTML = `Error: ${err}`;
}
};
// Now start to read the file as memory buffer
reader.readAsArrayBuffer(file);
});
</script>
</body>
</html>
让我们了解一下关注的变化。
- 下面这行代码将两个函数从
pkg/good_turing.js模块文件导入到 JavaScript 中,该文件是我们使用wasm-pack创建的。默认函数init初始化我们用 Rust 生成的 WebAssembly (WASM) 模块。第二个函数good_turing_byte_slice通过将其名称放在大括号中显式导入。
import init, { good_turing_byte_slice } from './pkg/good_turing.js';
- 创建一个新的
FileReader对象,进行一些设置,然后将文件作为字节数组读取。
const reader = new FileReader();
// ... some setup code ...
// Now start to read the file as bytes.
reader.readAsArrayBuffer(file);
- 这是我们在文件完全读取后执行的代码设置:
reader.onload = async (e) => {
//...
};
- 这一行确保 WASM 模块已初始化。第一次调用时,模块会被初始化。之后的调用不会有任何动作,因为模块已经准备好了。
await init(); // Ensure 'good_turing_byte_slice' is ready
- 从读取的文件中提取字节数组。
// View the memory buffer as a Uint8Array
const u8array = new Uint8Array(e.target.result);
- 调用 Rust 生成的 WASM 函数。
const [prediction, actual] = good_turing_byte_slice(u8array);
附带说明:这里的
good_turing_byte_slice是一个常规(同步)函数。不过,如果你愿意,可以在 Rust 端标记为async,然后在 JavaScript 端用await调用它。如果你的 Rust 处理速度较慢,这可以让你的网页更加流畅。
- 显示结果。
output.innerHTML =
`Prediction (words that appear exactly once on even lines): ${prediction.toLocaleString()}<br>` +
`Actual distinct words that appear only on odd lines: ${actual.toLocaleString()}`;
- 如果有错误,显示错误信息。
try { // Actually run the WASM
// ...
} catch (err) { // Or output an error
output.innerHTML = `Error: ${err}`;
}
该 最终代码 在 GitHub 上,包含一个 README.md,解释了它的功能。点击 此链接 查看实时演示。
range-set-blaze
我根据用户的要求将 range-set-blaze 移植到 WASM,以便他们可以在自己的项目中使用它。[range-set-blaze](https://github.com/CarlKCarlK/range-set-blaze) 项目通常作为库在其他项目中使用。换句话说,你通常不会期待 range-set-blaze 成为网页的核心部分。然而,我确实做了一个小的演示页面。你可以 浏览它 或 查看它的 index.html。该页面展示了 range-set-blaze 如何将一个整数列表转换为排序后的不相交范围列表。
附带说明:免费在 GitHub 上托管你的 WASM 在浏览器中的项目
1. 在你的项目中创建一个
docs文件夹。2. 执行
wasm-pack build --target web。3. 将
index.html和pkg复制(而不是移动)到docs文件夹中。4. 删除
docs/pkg中的.gitignore文件。5. 将项目提交到 GitHub。
6. 前往 GitHub 上的项目。然后进入“设置”,“Pages”。
7. 设置分支(在我的情况下是
main)和文件夹为docs。保存。8. URL 将基于你的账户和项目名称,例如,
carlkcarlk.github.io/rustconf24-good-turing/9. 要更新,重复步骤 2 到 5(包括)步骤。
规则 9:将wasm-pack添加到你的 CI(持续集成)测试中。
你的项目现在已经可以在浏览器中编译为 WASM,运行测试通过,并展示在网页上。完成了吗?还没完全完成。因为,正如我在第一篇文章中所说:
如果不在 CI 中,它就不存在。
回想一下,持续集成(CI)是一个系统,每次你更新代码时,它都可以自动运行测试,确保你的代码继续按预期工作。在我的例子中,GitHub 托管了我的项目。以下是我添加到 .github/workflows/ci.yml 文件中的配置,以便在浏览器中测试我的 WASM 项目:
test_wasm_unknown_unknown:
name: Test WASM unknown unknown
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up Rust
uses: dtolnay/rust-toolchain@master
with:
toolchain: stable
target: wasm32-unknown-unknown
- name: Install wasm-pack
run: |
curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
- name: Run WASM tests with Chrome
run: |
rustup target add wasm32-unknown-unknown
wasm-pack test --chrome --headless
通过将浏览器中的 WASM 集成到 CI 中,我可以放心地为我的项目添加新代码。CI 会自动测试我的所有代码,确保它将来继续支持浏览器中的 WASM。
所以,这就是—将你的 Rust 代码移植到浏览器中的 WASM 的九条规则。这里有一件事让我感到惊讶:
不好的部分:
-
在浏览器中为 WASM 设置测试是很难的。特别是,Chrome 测试和 Chromedriver 的安装与管理非常困难。
-
在浏览器中使用 WASM 违反了 Rust 的格言:“如果它能编译,说明它能工作”。如果使用不支持的功能——例如,直接文件访问——编译器不会捕捉到错误。相反,错误会在运行时发生。
-
传递字符串和字节向量会创建数据的两个副本,一个在 JavaScript 端,一个在 Rust 端。
优点:
-
在浏览器中使用 WASM 既有用又有趣。
-
你可以标记你的常规测试,让它们也在浏览器中的 WASM 中运行。只需同时为测试添加这两个属性:
#[test]
#[wasm_bindgen_test]
- 你可以在浏览器中运行 WASM,而不需要移植到
no_std。不过,在浏览器中使用 WASM 是朝着在嵌入式/no_std环境中运行迈出的一个重要步骤。
敬请期待!在下一篇文章中,我们将看看如何将 Rust 代码移植到嵌入式环境并通过no_std运行。这将使你的代码能够在小型设备上运行,我觉得非常酷。
对未来的文章感兴趣吗?请在 Medium 上关注我。我写关于 Rust 和 Python、科学编程、机器学习和统计学的文章。我倾向于每个月写一篇文章。
在嵌入式系统上运行 Rust 的九条规则
将range-set-blaze移植到no_std的实践经验
·发表于Towards Data Science ·阅读时长:16 分钟·2024 年 10 月 13 日
--

Rust 在嵌入式设备上的运行 — 来源:openai.com/dall-e-2/。所有其他图示来自作者。
你想让你的 Rust 代码可以在各种设备上运行——从大型服务器到网页、机器人,甚至是手表吗?在这篇三部分系列文章的最后一篇中[1, 2, 3],我们将看到如何使用 Rust 在嵌入式设备上运行,方法是使用no_std。
将你的 Rust 项目移植到no_std环境,可以让你面向微控制器和深度嵌入式系统,从而为资源受限的环境创建高效的软件。例如,我使用即将发布的range-set-blaze版本,创建了一个 LED 动画序列和合成器,该软件在 Raspberry Pi Pico 上运行:
1 分钟的视频展示了在 Pico 上的 LED 动画
在没有标准库的情况下运行 Rust 会带来独特的挑战。由于没有操作系统的支持,一些功能,如文件 I/O、网络连接,甚至有时动态内存分配都不可用。在本文中,我们将探讨一些实用的策略,以克服这些限制。
将 Rust 移植到no_std环境需要仔细的步骤和选择,任何一步遗漏都可能导致失败。我们将通过遵循这九条规则来简化这一过程,接下来我们将详细探讨这些规则:
-
确保你的项目在 WASM WASI 和浏览器中的 WASM 环境下能够正常工作。
-
使用目标
thumbv7m-none-eabi和cargo tree来识别和修复与no_std不兼容的依赖项。 -
标记主(非测试)代码为
no_std和alloc。将std::替换为core::和alloc::。 -
使用 Cargo 功能使你的主代码能够选择性地使用
std来处理文件相关功能(等)。 -
了解为什么测试代码总是使用标准库。
-
创建一个简单的嵌入式测试项目。通过 QEMU 运行它。
-
在
Cargo.toml中添加适用于 WASM 和no_std的关键字和类别。 -
[可选] 使用预分配数据类型以避免
alloc。 -
将
thumbv7m-none-eabi和 QEMU 添加到你的 CI(持续集成)测试中。
附注:这些文章基于我在RustConf24上于蒙特利尔主持的一个三小时工作坊。感谢所有参与该工作坊的人员。特别感谢来自西雅图 Rust Meetup 的志愿者们,他们帮助测试了这份材料。这些文章更新了我去年撰写的一篇文章中的信息。
和本系列中的第一篇和第二篇文章一样,在逐条讲解规则之前,我们先来定义一些术语。
-
本地环境: 你的主操作系统(Linux,Windows,macOS)
-
标准库 (std):提供 Rust 的核心功能——
Vec,String,文件输入/输出,网络,时间处理。 -
WASM:WebAssembly(WASM)是一种二进制指令格式,能够在大多数浏览器中运行(以及更广泛的环境中)。
-
WASI:WebAssembly 系统接口(WASI)允许浏览器外部的 WASM 访问文件输入/输出、网络(尚未实现)以及时间处理。
-
no_std:指示 Rust 程序不使用完整的标准库,使其适用于小型嵌入式设备或资源极其有限的环境。
-
alloc:在
no_std环境中提供堆内存分配功能(Vec,String等),对于动态管理内存至关重要。
基于我在[range-set-blaze](https://github.com/CarlKCarlK/range-set-blaze)数据结构项目中的经验,以下是我推荐的决策,每个决策逐一描述。为了避免模糊不清,我将它们作为规则表达出来。
规则 1:确保你的项目能够与 WASM WASI 和浏览器中的 WASM 兼容。
在将 Rust 代码移植到嵌入式环境之前,请确保它能够在WASM WASI和WASM 浏览器中成功运行。这些环境暴露了与标准库的脱离相关的问题,并施加了类似嵌入式系统的约束。通过提前解决这些挑战,你将更接近在嵌入式设备上运行你的项目。
旁白:如果你不需要让你的项目同时在本地和/或 WASM 上运行,你可以跳过这一步。不过,你可能会发现之前文章中的一些步骤仍然有用——例如,运行在 32 位环境下和理解条件编译。

我们希望在其中运行代码的环境,可以看作是一个逐步收紧约束的维恩图。
运行以下命令以确认你的代码在 WASM WASI 和 WASM 浏览器中都能正常工作:
cargo test --target wasm32-wasip1
cargo test --target wasm32-unknown-unknown
如果测试失败或无法运行,请重新查看本系列早期文章中的步骤:WASM WASI和WASM 浏览器。
WASM WASI 文章还提供了关于理解 Rust 目标(规则 2)、条件编译(规则 4)和 Cargo 特性(规则 6)的关键背景知识。
一旦满足这些前提条件,下一步是看看我们是否能让依赖项在嵌入式系统上工作。
规则 2:使用目标thumbv7m-none-eabi和cargo tree来识别和修复与no_std不兼容的依赖项。
要检查你的依赖项是否与嵌入式环境兼容,可以为嵌入式目标编译项目。我建议使用thumbv7m-none-eabi目标:
-
thumbv7m— 代表 ARM Cortex-M3 微控制器,是一种流行的嵌入式处理器系列。 -
none— 表示没有可用的操作系统(OS)。在 Rust 中,这通常意味着我们无法依赖标准库(std),因此我们使用no_std。请记住,标准库提供了诸如Vec、String、文件输入/输出、网络和时间等核心功能。 -
eabi— 嵌入式应用程序二进制接口,一种定义嵌入式可执行文件调用约定、数据类型和二进制布局的标准。
由于大多数嵌入式处理器都共享no_std约束,因此确保与此目标的兼容性有助于确保与其他嵌入式目标的兼容性。
安装目标并检查你的项目:
rustup target add thumbv7m-none-eabi
cargo check --target thumbv7m-none-eabi
当我在range-set-blaze上进行此操作时,遇到了一些关于依赖项的错误,例如:

这表明我的项目依赖于num-traits,而num-traits又依赖于either,最终依赖于std。
错误信息可能会让人困惑。为了更好地理解情况,运行以下 cargo tree 命令:
cargo tree --edges no-dev --format "{p} {f}"
它显示了你的项目依赖关系及其激活的 Cargo 特性的递归列表。例如:
range-set-blaze v0.1.6 (C:\deldir\branches\rustconf24.nostd)
├── gen_ops v0.3.0
├── itertools v0.13.0 default,use_alloc,use_std
│ └── either v1.12.0 use_std
├── num-integer v0.1.46 default,std
│ └── num-traits v0.2.19 default,i128,std
│ [build-dependencies]
│ └── autocfg v1.3.0
└── num-traits v0.2.19 default,i128,std (*)
我们看到多个出现了名为 use_std 和 std 的 Cargo 特性,这强烈表明:
-
这些 Cargo 特性需要标准库。
-
我们可以关闭这些 Cargo 特性。
使用在第一篇文章中解释的技巧,规则 6,我们禁用了 use_std 和 std Cargo 特性。请记住,Cargo 特性是累加的,并且具有默认值。为了关闭默认特性,我们使用 default-features = false。然后,我们通过指定例如 features = ["use_alloc"] 来启用我们想要保留的 Cargo 特性。现在,Cargo.toml 文件内容如下:
[dependencies]
gen_ops = "0.3.0"
itertools = { version = "0.13.0", features=["use_alloc"], default-features = false }
num-integer = { version = "0.1.46", default-features = false }
num-traits = { version = "0.2.19", features=["i128"], default-features = false }
关闭 Cargo 特性并不总是足以使你的依赖项兼容 no_std。
例如,流行的 thiserror 包将 std 引入到你的代码中,并且没有提供禁用它的 Cargo 特性。然而,社区已经创建了 no_std 替代版本。你可以通过搜索,例如 crates.io/search?q=thiserror+no_std,来找到这些替代版本。
对于 range-set-blaze,仍然存在与包 [gen_ops](https://crates.io/crates/gen_ops) 相关的问题——这是一个非常方便的包,用于定义操作符如 + 和 &。该包使用了 std,但其实并不需要。我找到了需要的单行更改(使用我们将在规则 3 中讲解的方法)并提交了拉取请求。维护者接受了它,他们发布了更新版本:0.4.0。
有时,我们的项目无法禁用 std,因为我们需要像文件访问这样的能力来运行在完整的操作系统上。然而,在嵌入式系统上,我们愿意—事实上必须—放弃这些能力。在规则 4 中,我们将看到如何通过引入我们自己的 Cargo 特性使 std 使用变为可选。
使用这些方法解决了 range-set-blaze 中所有的依赖错误。然而,解决这些错误暴露出了主代码中的 281 个错误。进展!
规则 3:将主代码(非测试代码)标记为 no_std 和 alloc。将 std:: 替换为 core:: 和 alloc::。
在项目的 lib.rs(或 main.rs)顶部添加:
#![no_std]
extern crate alloc;
这意味着我们不会使用标准库,但仍然会分配内存。对于 range-set-blaze,这一变化将错误数量从 281 降低到 52。
许多剩余的错误是由于使用了 std 中的项,这些项在 core 或 alloc 中是可用的。由于 std 的大部分内容实际上是 core 和 alloc 的重新导出,我们可以通过将 std 引用切换到 core 或 alloc 来解决许多错误。这使我们能够在不依赖标准库的情况下保持必要的功能。
例如,对于以下每一行,我们都会遇到错误:
use std::cmp::max;
use std::cmp::Ordering;
use std::collections::BTreeMap;
将 std:: 改为 core:: 或(如果与内存相关)alloc:: 可以修复这些错误:
use core::cmp::max;
use core::cmp::Ordering;
use alloc::collections::BTreeMap;
一些功能,比如文件访问,仅限于 std —— 即它们在 core 和 alloc 之外定义。幸运的是,对于 range-set-blaze,切换到 core 和 alloc 解决了主代码中的 52 个错误。然而,这一修复暴露了测试代码中的 89 个错误。再次进展!
附注:你还可以通过 Clippy 规则 找到
std可以替换为alloc或core的地方。
我们将在规则 5 中处理测试代码中的错误,但首先,让我们弄清楚如果在完整操作系统上运行时需要文件访问等功能该怎么做。
规则 4:使用 Cargo 特性让你的主代码在文件相关(等)功能上可选使用 std。
如果我们需要两种版本的代码——一种用于在完整操作系统上运行,另一种用于嵌入式系统——我们可以使用 Cargo 特性(参见 第一篇文章 中的规则 6)。例如,定义一个名为 foo 的特性,它将是默认的。我们只会在启用 foo 时包含 demo_read_ranges_from_file 函数。
在 Cargo.toml(初步版)中:
[features]
default = ["foo"]
foo = []
在 lib.rs(初步版本)中:
#![no_std]
extern crate alloc;
// ...
#[cfg(feature = "foo")]
pub fn demo_read_ranges_from_file<P, T>(path: P) -> std::io::Result<RangeSetBlaze<T>>
where
P: AsRef<std::path::Path>,
T: FromStr + Integer,
{
todo!("This function is not yet implemented.");
}
这意味着只有在启用 Cargo 特性 foo 时才定义函数 demo_read_ranges_from_file。现在我们可以检查代码的不同版本:
cargo check # enables "foo", the default Cargo features
cargo check --features foo # also enables "foo"
cargo check --no-default-features # enables nothing
现在让我们通过将 foo 重命名为 std,为我们的 Cargo 特性起一个更有意义的名字。我们的 Cargo.toml(中间版)现在看起来是这样的:
[features]
default = ["std"]
std = []
在我们的 lib.rs 中,我们在顶部添加这些行,以便在启用 std Cargo 特性时引入 std 库:
#[cfg(feature = "std")]
extern crate std;
因此,lib.rs(最终版)看起来是这样的:
#![no_std]
extern crate alloc;
#[cfg(feature = "std")]
extern crate std;
// ...
#[cfg(feature = "std")]
pub fn demo_read_ranges_from_file<P, T>(path: P) -> std::io::Result<RangeSetBlaze<T>>
where
P: AsRef<std::path::Path>,
T: FromStr + Integer,
{
todo!("This function is not yet implemented.");
}
我们希望对 Cargo.toml 做最后一次更改。我们希望新的 Cargo 特性控制依赖和它们的特性。下面是最终版的 Cargo.toml:
[features]
default = ["std"]
std = ["itertools/use_std", "num-traits/std", "num-integer/std"]
[dependencies]
itertools = { version = "0.13.0", features = ["use_alloc"], default-features = false }
num-integer = { version = "0.1.46", default-features = false }
num-traits = { version = "0.2.19", features = ["i128"], default-features = false }
gen_ops = "0.4.0"
附注:如果你对
Cargo.toml中指定依赖和特性的格式感到困惑,看看我最近的文章:九个 Rust Cargo.toml 的陷阱与误区:掌握 Cargo.toml 格式规则,避免沮丧 在 Towards Data Science。
要检查你的项目是否同时能够在标准库(std)和无标准库环境下编译,使用以下命令:
cargo check # std
cargo check --no-default-features # no_std
使用 cargo check 已经能正常工作,你可能会认为 cargo test 会很直接。遗憾的是,并非如此。我们接下来看看这个问题。
规则 5:理解为什么测试代码总是使用标准库。
当我们使用 --no-default-features 编译项目时,它将在 no_std 环境中运行。然而,Rust 的测试框架总是包括标准库,即使在 no_std 项目中也是如此。这是因为 cargo test 需要 std;例如,#[test] 属性和测试框架本身都在标准库中定义。
结果是,运行:
# DOES NOT TEST `no_std`
cargo test --no-default-features
实际上并不会测试你代码的no_std版本。即使在真正的no_std环境中,std中那些不可用的函数,在测试时仍然可以访问。例如,下面的测试将在使用--no-default-features时成功编译和运行,尽管它使用了std::fs:
#[test]
fn test_read_file_metadata() {
let metadata = std::fs::metadata("./").unwrap();
assert!(metadata.is_dir());
}
此外,在std模式下进行测试时,你可能需要显式导入标准库中的某些功能。这是因为,即使在测试期间std可用,你的项目仍然是以#![no_std]编译的,这意味着标准前导并不会自动包含在作用域中。例如,你通常需要在测试代码中包含以下导入:
#![cfg(test)]
use std::prelude::v1::*;
use std::{format, print, println, vec};
这些导入将从标准库中引入必要的工具,以便在测试过程中可以使用它们。
要真正测试没有标准库的代码,你需要使用不依赖于cargo test的替代方法。我们将在下一条规则中探讨如何运行no_std测试。
规则 6:创建一个简单的嵌入式测试项目。使用 QEMU 运行它。
你无法在嵌入式环境中运行常规测试。然而,你可以 — 并且应该 — 至少运行一个嵌入式测试。我的哲学是,即使只有一个测试,也比没有测试要好得多。由于“如果它能编译,它就能工作”通常对no_std项目有效,一个(或几个)精心选择的测试可能会非常有效。
旁注:有希望以更正常的方式运行嵌入式测试[1][2]。据我所知,正常的本地测试没有简单的方法。如果有变化,请告诉我,我会更新这一部分内容。
要运行这个测试,我们使用 QEMU(快速仿真器,发音为“cue-em-you”),它允许我们在主操作系统(Linux、Windows 或 macOS)上模拟thumbv7m-none-eabi代码。
安装 QEMU。
查看 QEMU 的下载页面以获取完整信息:
Linux/WSL
-
Ubuntu:
sudo apt-get install qemu-system -
Arch:
sudo pacman -S qemu-system-arm -
Fedora:
sudo dnf install qemu-system-arm
Windows
-
方法 1:
qemu.weilnetz.de/w64。运行安装程序(告诉 Windows 它是可以的)。将"C:\Program Files\qemu\"添加到你的路径中。 -
方法 2:从
www.msys2.org/安装 MSYS2。打开 MSYS2 UCRT64 终端。pacman -S mingw-w64-x86_64-qemu。将C:\msys64\mingw64\bin\添加到你的路径中。
Mac
brew install qemu或sudo port install qemu
测试安装:
qemu-system-arm --version
创建一个嵌入式子项目。
为嵌入式测试创建一个子项目:
cargo new tests/embedded
这个命令生成一个新的子项目,包括位于tests/embedded/Cargo.toml的配置文件。
附注**:此命令还会修改您的顶级
Cargo.toml,将子项目添加到您的工作区。在 Rust 中,工作区是由顶级Cargo.toml中的[workspace]部分定义的相关包的集合。工作区中的所有包共享一个Cargo.lock文件,确保整个工作区的依赖版本一致。
编辑tests/embedded/Cargo.toml使其如下所示,但将"range-set-blaze"替换为您顶级项目的名称:
[package]
name = "embedded"
version = "0.1.0"
edition = "2021"
[dependencies]
alloc-cortex-m = "0.4.4"
cortex-m = "0.7.7"
cortex-m-rt = "0.7.3"
cortex-m-semihosting = "0.5.0"
panic-halt = "0.2.0"
# Change to refer to your top-level project
range-set-blaze = { path = "../..", default-features = false }
更新测试代码。
将tests/embedded/src/main.rs的内容替换为:
// Based on https://github.com/rust-embedded/cortex-m-quickstart/blob/master/examples/allocator.rs
// and https://github.com/rust-lang/rust/issues/51540
#![feature(alloc_error_handler)]
#![no_main]
#![no_std]
extern crate alloc;
use alloc::string::ToString;
use alloc_cortex_m::CortexMHeap;
use core::{alloc::Layout, iter::FromIterator};
use cortex_m::asm;
use cortex_m_rt::entry;
use cortex_m_semihosting::{debug, hprintln};
use panic_halt as _;
#[global_allocator]
static ALLOCATOR: CortexMHeap = CortexMHeap::empty();
const HEAP_SIZE: usize = 1024; // in bytes
#[alloc_error_handler]
fn alloc_error(_layout: Layout) -> ! {
asm::bkpt();
loop {}
}
#[entry]
fn main() -> ! {
unsafe { ALLOCATOR.init(cortex_m_rt::heap_start() as usize, HEAP_SIZE) }
// Test(s) goes here. Run only under emulation
use range_set_blaze::RangeSetBlaze;
let range_set_blaze = RangeSetBlaze::from_iter([100, 103, 101, 102, -3, -4]);
hprintln!("{:?}", range_set_blaze.to_string());
if range_set_blaze.to_string() != "-4..=-3, 100..=103" {
debug::exit(debug::EXIT_FAILURE);
}
debug::exit(debug::EXIT_SUCCESS);
loop {}
}
这部分main.rs代码大部分是嵌入式系统的模板代码。实际的测试代码是:
use range_set_blaze::RangeSetBlaze;
let range_set_blaze = RangeSetBlaze::from_iter([100, 103, 101, 102, -3, -4]);
hprintln!("{:?}", range_set_blaze.to_string());
if range_set_blaze.to_string() != "-4..=-3, 100..=103" {
debug::exit(debug::EXIT_FAILURE);
}
如果测试失败,它返回EXIT_FAILURE;否则,它返回EXIT_SUCCESS。我们使用hprintln!宏在仿真过程中将消息打印到控制台。由于这是一个嵌入式系统,代码会以无限循环的方式结束,以便持续运行。
添加支持文件。
在运行测试之前,您必须将两个文件添加到子项目中:来自 Cortex-M 快速入门仓库的build.rs和memory.x文件。
Linux/WSL/macOS
cd tests/embedded
wget https://raw.githubusercontent.com/rust-embedded/cortex-m-quickstart/master/build.rs
wget https://raw.githubusercontent.com/rust-embedded/cortex-m-quickstart/master/memory.
Windows (Powershell)
cd tests/embedded
Invoke-WebRequest -Uri 'https://raw.githubusercontent.com/rust-embedded/cortex-m-quickstart/master/build.rs' -OutFile 'build.rs'
Invoke-WebRequest -Uri 'https://raw.githubusercontent.com/rust-embedded/cortex-m-quickstart/master/memory.x' -OutFile 'memory.x'
另外,创建一个tests/embedded/.cargo/config.toml,并将以下内容添加到其中:
[target.thumbv7m-none-eabi]
runner = "qemu-system-arm -cpu cortex-m3 -machine lm3s6965evb -nographic -semihosting-config enable=on,target=native -kernel"
[build]
target = "thumbv7m-none-eabi"
此配置指示 Cargo 使用 QEMU 运行嵌入式代码,并将thumbv7m-none-eabi设置为子项目的默认目标。
运行测试。
使用cargo run(而不是cargo test)运行测试:
# Setup
# Make this subproject 'nightly' to support #![feature(alloc_error_handler)]
rustup override set nightly
rustup target add thumbv7m-none-eabi
# If needed, cd tests/embedded
cargo run
您应该看到日志消息,且进程应该无错误退出。在我的例子中,我看到:"-4..=-3, 100..=103"。
这些步骤可能看起来需要做很多工作,仅仅是为了运行一个(或几个)测试。然而,这主要是一次性的工作,主要是复制和粘贴。此外,它还使得在 CI 环境中运行测试成为可能(请参见规则 9)。替代方法——声称代码在no_std环境中运行良好,而实际上从未在no_std中运行过——可能会忽视关键问题。
下一个规则要简单得多。
规则 7:在Cargo.toml中,为 WASM 和no_std添加关键字和分类。
一旦您的包编译并通过了额外的嵌入式测试,您可能希望将其发布到crates.io,Rust 的包注册表。为了让其他人知道它兼容 WASM 和no_std,请将以下关键字和分类添加到您的Cargo.toml文件中:
[package]
# ...
categories = ["no-std", "wasm", "embedded"] # + others specific to your package
keywords = ["no_std", "wasm"] # + others specific to your package
请注意,对于分类,我们在no-std中使用了连字符。对于关键字,no_std(带下划线)比no-std更常用。您的包最多可以有五个关键字和五个分类。
这里有一个分类和关键字的列表,可能会对您有兴趣,并附有每个术语使用的 crate 数量:
-
分类 no-std(6884)
-
分类嵌入式(3455)
-
分类 wasm(2026)
-
Keyword wasm (1686)
-
Keyword no_std (1351)
-
Keyword no-std (1157)
-
Keyword embedded (925)
-
Keyword webassembly (804)
良好的类别和关键字将帮助人们找到你的包,但该系统是不正式的。没有机制检查你的类别和关键字是否准确,也不要求你提供它们。
接下来,我们将探索你可能遇到的最受限的环境之一。
规则 8:[可选] 使用预分配的数据类型以避免使用alloc。
我的项目range-set-blaze实现了一个动态数据结构,需要从堆中进行内存分配(通过alloc)。但是,如果你的项目不需要动态内存分配怎么办?那样的话,它可以运行在更加受限的嵌入式环境中——特别是那些程序加载时就已预分配所有内存的环境。
如果可能的话,避免使用alloc的原因:
-
完全确定的内存使用
-
降低运行时故障的风险(通常由内存碎片引起)
-
降低功耗
有些 crate 可以帮助你替换像Vec、String和HashMap这样的动态数据结构。这些替代方案通常要求你指定一个最大大小。下表展示了一些常用的 crate:

我推荐使用heapless crate,因为它提供了一系列协同工作的数据结构。
这是一个与 LED 显示相关的代码示例——使用heapless。这段代码创建了一个从字节到整数列表的映射。我们将映射中条目的数量和整数列表的长度限制为DIGIT_COUNT(在此例中为 4)。
use heapless::{LinearMap, Vec};
// …
let mut map: LinearMap<u8, Vec<usize, DIGIT_COUNT>, DIGIT_COUNT> = LinearMap::new();
// …
let mut vec = Vec::default();
vec.push(index).unwrap();
map.insert(*byte, vec).unwrap(); // actually copies
创建一个no_alloc项目的完整细节超出了我的经验范围。然而,第一步是从你的lib.rs或main.rs中删除这一行(在规则 3 中添加的):
extern crate alloc; // remove this
规则 9:将thumbv7m-none-eabi和 QEMU 添加到你的 CI(持续集成)测试中。
你的项目现在已经编译为no_std并通过了至少一个嵌入式特定的测试。你完成了吗?还没有。正如我在前两篇文章中所说:
如果它不在 CI 中,那就不存在。
记住,持续集成(CI)是一个每次更新代码时都能自动运行测试的系统。我使用 GitHub Actions 作为我的 CI 平台。以下是我添加到.github/workflows/ci.yml中的配置,用于在嵌入式平台上测试我的项目:
test_thumbv7m_none_eabi:
name: Setup and Check Embedded
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up Rust
uses: dtolnay/rust-toolchain@master
with:
toolchain: stable
target: thumbv7m-none-eabi
- name: Install check stable and nightly
run: |
cargo check --target thumbv7m-none-eabi --no-default-features
rustup override set nightly
rustup target add thumbv7m-none-eabi
cargo check --target thumbv7m-none-eabi --no-default-features
sudo apt-get update && sudo apt-get install qemu qemu-system-arm
- name: Test Embedded (in nightly)
timeout-minutes: 1
run: |
cd tests/embedded
cargo run
通过在 CI 中测试嵌入式和no_std,我可以确保我的代码将继续支持未来的嵌入式平台。
所以,事情就是这样——移植 Rust 代码到嵌入式的九条规则。要查看应用了这九条规则后整个range-set-blaze项目的快照,请参见这个 Github 分支。
这就是让我在移植到嵌入式时感到惊讶的地方:
坏的一面:
-
我们无法在嵌入式系统上运行现有的测试。相反,我们必须创建一个新的子项目并编写(一些)新的测试。
-
许多流行的库依赖于
std,因此找到或调整适用于no_std的依赖库可能会面临挑战。
好的一面:
-
Rust 中“如果能编译,就能运行”的说法在嵌入式开发中依然成立。这使我们能够在无需大量新测试的情况下,对代码的正确性充满信心。
-
尽管
no_std移除了我们对标准库的直接访问,但许多项目仍然可以通过core和alloc使用。 -
借助仿真,你可以在没有硬件的情况下为嵌入式系统进行开发。
感谢你加入我从 WASI 到 WebAssembly 再到嵌入式开发的旅程。Rust 凭借其在不同环境中高效且安全运行的能力,继续让我印象深刻。随着你探索这些不同的领域,我希望你能像我一样,感受到 Rust 的灵活性和强大。无论你是在处理云服务器、浏览器还是微控制器,我们讨论的工具都将帮助你自信地应对未来的挑战。
对未来的文章感兴趣吗?请在 Medium 上关注我。我写关于 Rust 和 Python、科学编程、机器学习以及统计学的文章。我倾向于每月写一篇文章。
在 WASM WASI 上运行 Rust 的九条规则
将range-set-blaze移植到这种容器化环境中的实践经验
·发表于Towards Data Science ·13 分钟阅读·2024 年 9 月 28 日
--

在类似容器的环境中运行 Rust —— 来源:openai.com/dall-e-2/。其他所有图表来自作者。
你想让你的 Rust 代码在任何地方运行吗——从大型服务器到网页、机器人,甚至是手表?在这三篇文章中的第一篇[1, 2, 3],我将详细描述实现这一目标的步骤。
在受限环境中运行 Rust 会遇到许多挑战。你的代码可能无法访问完整的操作系统,如 Linux、Windows 或 macOS。你可能无法(或根本无法)访问文件、网络、时间、随机数,甚至是内存。我们将探讨一些解决方法和应对策略。
这篇文章的重点是如何在“WASM WASI”这个类似容器的环境中运行代码。我们将看到,WASM WASI 可能(也可能不)在自身上具有实际用途。然而,它作为在浏览器或嵌入式系统中运行 Rust 的第一步,仍然具有价值。
将代码移植到 WASM WASI 上需要许多步骤和选择。做出这些选择可能非常耗时。错过一个步骤可能导致失败。我们将通过提供九条规则来简化这一过程,我们将在接下来的内容中详细探讨这些规则:
-
为失望做好准备:WASM WASI 很简单,但——目前——基本无用——除非作为一个垫脚石。
-
理解 Rust 的目标平台。
-
安装
wasm32-wasip1目标和 WASMTIME,然后创建“Hello, WebAssembly!”。 -
理解条件编译。
-
运行常规测试,但使用 WASM WASI 目标。
-
理解 Cargo 特性。
-
改变你能改变的事物:通过选择 Cargo 特性解决依赖问题,64 位/32 位问题。
-
接受你不能改变一切:网络、Tokio、Rayon 等等。
-
将 WASM WASI 添加到你的 CI(持续集成)测试中。
顺便说一句:这些文章基于我在 RustConf24 会议上提供的三小时研讨会。感谢参加该研讨会的所有人。同时,特别感谢来自西雅图 Rust Meetup 的志愿者们,他们帮助测试了这篇材料。这些文章替代了我去年写的 一篇文章,并提供了更新的信息。
在我们逐条查看规则之前,先定义一下术语。
-
Native:你的本地操作系统(Linux、Windows、macOS)
-
标准库 (std):提供 Rust 的核心功能——
Vec、String、文件输入/输出、网络、时间等。 -
WASM:WebAssembly(WASM)是一种二进制指令格式,可以在大多数浏览器中运行(以及其他平台)。
-
WASI:WebAssembly 系统接口(WASI)允许非浏览器环境中的 WASM 访问文件 I/O、网络(尚未支持)和时间处理。
-
no_std:指示 Rust 程序不使用完整的标准库,使其适用于小型嵌入式设备或资源受限的环境。
-
alloc:在
no_std环境中提供堆内存分配功能(Vec、String等),对动态管理内存至关重要。
有了这些术语,我们可以将代码运行的环境想象成一个渐进收紧约束的维恩图。本文详细介绍了如何从本地环境迁移到 WASM WASI。第二篇文章 讲述了如何进一步迁移到浏览器中的 WASM。最后一篇文章 涵盖了在 no_std 环境中运行 Rust 的方法,包括带有和不带 alloc 的情况,非常适合嵌入式系统。

根据我在数据结构项目 range-set-blaze 的经验,以下是我推荐的决策,逐条描述。为了避免模糊不清,我将它们表述为规则。
规则 1:为失望做好准备:WASM WASI 很简单,但——目前——大多数情况下没什么用——除非作为一个跳板。
2019 年,Docker 联合创始人 Solomon Hykes 在推特上发文:
如果 2008 年就有 WASM+WASI,我们就不需要创建 Docker 了。这就是它的重要性。服务器上的 WebAssembly 是计算的未来。一个标准化的系统接口是缺失的环节。希望 WASI 能胜任这一任务。
今天,如果你关注技术新闻,你会看到像这样的乐观标题:

如果 WASM WASI 真的是已经准备好并且有用,大家早就已经在使用它了。我们不断看到这些标题,表明它还没有准备好。换句话说,如果 WASM WASI 真的是准备好的,他们就不需要一直坚持说它准备好了。
截至 WASI Preview 1,情况如下:你可以访问某些文件操作、环境变量,并能访问时间和随机数生成。然而,尚不支持网络功能。
WASM WASI 可能 对某些类似 AWS Lambda 风格的 Web 服务有用,但即便如此也不确定。因为如果 WASM WASI 真的有用,你不想将你的 Rust 代码本地编译并以比 WASM WASI 快两倍且成本减半的方式运行吗?
也许 WASM WASI 对插件和扩展有用。在基因组学中,我有一个为 Python 编写的 Rust 扩展,我为 25 个不同的组合(5 个 Python 版本跨 5 个操作系统目标)编译它。即便如此,我也没有覆盖所有可能的操作系统和芯片系列。我能用 WASM WASI 替换这些操作系统目标吗?不行,它会太慢。我可以将 WASM WASI 作为第六个“万用”目标吗?也许可以,但如果我真的需要可移植性,我已经需要支持 Python,最好直接使用 Python。
那么,WASM WASI 到底有什么用呢?现在,它的主要价值在于它是朝着在浏览器或嵌入式系统上运行代码迈出的一步。
规则 2:了解 Rust 目标。
在规则 1 中,我稍微提到过“操作系统目标”。现在,让我们更深入地了解 Rust 目标——这是关于 WASM WASI 的关键信息,也适用于一般的 Rust 开发。
在我的 Windows 机器上,我可以将 Rust 项目编译为在 Linux 或 macOS 上运行。同样,从 Linux 机器上,我也可以将 Rust 项目编译为目标 Windows 或 macOS。以下是我用来在 Windows 机器上添加和检查 Linux 目标的命令:
rustup target add x86_64-unknown-linux-gnu
cargo check --target x86_64-unknown-linux-gnu
旁白:虽然
cargo check可以验证代码是否编译成功,但构建一个完全功能的可执行文件还需要额外的工具。要从 Windows 编译到 Linux(GNU),你还需要安装 Linux GNU 的 C/C++ 编译器和相应的工具链。这可能会有些棘手。幸运的是,对于我们关心的 WASM 目标,所需的工具链是很容易安装的。
要查看 Rust 支持的所有目标,请使用以下命令:
rustc --print target-list
它将列出超过 200 个目标,包括 x86_64-unknown-linux-gnu、wasm32-wasip1 和 wasm32-unknown-unknown。
目标名称包含最多四个部分:CPU 系列、厂商、操作系统和环境(例如,GNU 与 LVMM):

目标名称部分 — 图示来自作者
现在我们已经对目标有所了解,让我们继续安装我们需要的 WASM WASI 目标。
规则 3:安装 wasm32-wasip1 目标和 WASMTIME,然后创建“Hello, WebAssembly!”。
若要在浏览器外运行我们的 Rust 代码,我们需要针对 wasm32-wasip1(具有 WASI Preview 1 的 32 位 WebAssembly)进行编译。我们还将安装 WASMTIME,这是一个运行时,允许我们在浏览器外使用 WASI 运行 WebAssembly 模块。
rustup target add wasm32-wasip1
cargo install wasmtime-cli
为了测试我们的设置,让我们使用 cargo new 创建一个新的“Hello, WebAssembly!” Rust 项目。这将初始化一个新的 Rust 包:
cargo new hello_wasi
cd hello_wasi
编辑 src/main.rs 使其如下所示:
fn main() {
#[cfg(not(target_arch = "wasm32"))]
println!("Hello, world!");
#[cfg(target_arch = "wasm32")]
println!("Hello, WebAssembly!");
}
附注:我们将在规则 4 中更深入地研究
#[cfg(...)]属性,它支持条件编译。
现在,运行项目命令 cargo run,你应该看到 Hello, world! 被打印到控制台。
接下来,创建一个 .cargo/config.toml 文件,指定当目标为 WASM WASI 时,Rust 应如何运行和测试该项目。
[target.wasm32-wasip1]
runner = "wasmtime run --dir ."
附注:这个
.cargo/config.toml文件不同于主Cargo.toml文件,后者定义了项目的依赖项和元数据。
现在,如果你说:
cargo run --target wasm32-wasip1
你应该看到 Hello, WebAssembly!。恭喜!你刚刚成功地在类似容器的 WASM WASI 环境中运行了一些 Rust 代码。
规则 4:理解条件编译。
现在,让我们研究 #[cfg(...)] —— 这是 Rust 中用于条件编译代码的一个重要工具。在规则 3 中,我们看到:
fn main() {
#[cfg(not(target_arch = "wasm32"))]
println!("Hello, world!");
#[cfg(target_arch = "wasm32")]
println!("Hello, WebAssembly!");
}
#[cfg(...)] 行告诉 Rust 编译器根据特定条件包含或排除某些代码项。“代码项”指的是代码单元,如函数、语句或表达式。
使用 #[cfg(…)] 行,你可以条件性地编译你的代码。换句话说,你可以为不同的情况创建不同版本的代码。例如,当为 wasm32 目标编译时,编译器会忽略 #[cfg(not(target_arch = "wasm32"))] 块,只包含以下内容:
fn main() {
println!("Hello, WebAssembly!");
}
你通过表达式指定条件,例如 target_arch = "wasm32"。支持的键包括 target_os 和 target_arch。请参阅 Rust 参考文档中的完整列表的支持键。你还可以使用 Cargo 特性创建表达式,我们将在规则 6 中学习。
你可以使用逻辑运算符 not、any 和 all 组合表达式。Rust 的条件编译不使用传统的 if...then...else 语句。相反,你必须使用 #[cfg(...)] 及其否定来处理不同的情况:
#[cfg(not(target_arch = "wasm32"))]
...
#[cfg(target_arch = "wasm32")]
...
若要条件性地编译整个文件,将 #![cfg(...)] 放置在文件的顶部。(注意“!”)。当文件仅与特定目标或配置相关时,这非常有用。
你还可以在 Cargo.toml 中使用 cfg 表达式来条件性地包含依赖项。这使你能够根据不同的目标定制依赖项。例如,这表示“当不针对 wasm32 时,依赖于 Criterion 和 Rayon”。
[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
criterion = { version = "0.5.1", features = ["rayon"] }
附言:关于如何在
Cargo.toml中使用cfg表达式的更多信息,请参见我的文章:九个 Rust Cargo.toml 的注意事项与误区:掌握 Cargo.toml 格式规则,避免沮丧 | Towards Data Science (medium.com)。
规则 5:定期运行测试,但使用 WASM WASI 目标。
现在是时候尝试在 WASM WASI 上运行 你的 项目了。如规则 3 中所述,为你的项目创建一个 .cargo/config.toml 文件。它告诉 Cargo 如何在 WASM WASI 上运行和测试你的项目。
[target.wasm32-wasip1]
runner = "wasmtime run --dir ."
接下来,你的项目——像所有优秀的代码一样——应该已经包含了测试。我的 range-set-blaze 项目就包括了例如这样的测试:
#[test]
fn insert_255u8() {
let range_set_blaze = RangeSetBlaze::<u8>::from_iter([255]);
assert!(range_set_blaze.to_string() == "255..=255");
}
现在让我们尝试在 WASM WASI 上运行你项目的测试。使用以下命令:
cargo test --target wasm32-wasip1
如果这能正常工作,你可能就完成了——但它可能不会正常工作。当我在 range-set-blaze 上尝试时,我得到一个错误消息,抱怨在 WASM 上使用 Rayon。
error: Rayon cannot be used when targeting wasi32\. Try disabling default features.
--> C:\Users\carlk\.cargo\registry\src\index.crates.io-6f17d22bba15001f\criterion-0.5.1\src\lib.rs:31:1
|
31 | compile_error!("Rayon cannot be used when targeting wasi32\. Try disabling default features.");
为了解决这个错误,我们必须首先理解 Cargo 特性。
规则 6:了解 Cargo 特性。
为了解决规则 5 中的 Rayon 错误,理解 Cargo 特性的工作原理非常重要。
在 Cargo.toml 中,一个可选的 [features] 部分允许你根据启用或禁用的特性定义项目的不同配置或版本。例如,这是来自 Criterion 基准测试项目 的简化版 Cargo.toml 文件的一部分:
[features]
default = ["rayon", "plotters", "cargo_bench_support"]
rayon = ["dep:rayon"]
plotters = ["dep:plotters"]
html_reports = []
cargo_bench_support = []
[dependencies]
#...
# Optional dependencies
rayon = { version = "1.3", optional = true }
plotters = { version = "⁰.3.1", optional = true, default-features = false, features = [
"svg_backend",
"area_series",
"line_series",
] }
这定义了四个 Cargo 特性:rayon、plotters、html_reports 和 cargo_bench_support。由于每个特性可以被包含或排除,这四个特性会创建 16 种可能的项目配置。还要注意特殊的默认 Cargo 特性。
一个 Cargo 特性可以包含其他 Cargo 特性。在这个例子中,特殊的 default Cargo 特性包含了三个其他 Cargo 特性——rayon、plotters 和 cargo_bench_support。
一个 Cargo 特性可以包含一个依赖项。上面的 rayon Cargo 特性包含了 rayon crate 作为一个依赖包。
此外,依赖包可能有自己的 Cargo 特性。例如,上述 plotters Cargo 特性包含了 plotters 依赖包,并启用了以下 Cargo 特性:svg_backend、area_series 和 line_series。
你可以指定在运行 cargo check、cargo build、cargo run 或 cargo test 时启用或禁用哪些 Cargo 特性。例如,如果你正在处理 Criterion 项目,并希望仅检查 html_reports 特性而不使用任何默认特性,你可以运行:
cargo check --no-default-features --features html_reports
这个命令告诉 Cargo 默认不包含任何 Cargo 特性,而是特别启用 html_reports Cargo 特性。
在你的 Rust 代码中,你可以根据启用的 Cargo 特性来包含/排除代码项。语法使用 #cfg(…),遵循规则 4:
#[cfg(feature = "html_reports")]
SOME_CODE_ITEM
通过对 Cargo 特性的理解,我们现在可以尝试修复在 WASM WASI 上运行测试时遇到的Rayon错误。
规则 7:改变你能改变的事情:通过选择 Cargo 特性来解决依赖问题,64 位/32 位问题。
当我们尝试运行cargo test --target wasm32-wasip1时,错误信息的部分内容是:Criterion ... Rayon cannot be used when targeting wasi32\. Try disabling default features. 这表明我们应该在针对 WASM WASI 时禁用 Criterion 的rayon Cargo 特性。
为了实现这一点,我们需要在Cargo.toml中进行两个更改。首先,我们需要在[dev-dependencies]部分禁用 Criterion 的rayon特性。所以,这个初始配置:
[dev-dependencies]
criterion = { version = "0.5.1", features = ["html_reports"] }
变成这样,我们显式地关闭了 Criterion 的默认特性,然后启用了除了rayon以外的所有 Cargo 特性。
[dev-dependencies]
criterion = { version = "0.5.1", features = [
"html_reports",
"plotters",
"cargo_bench_support"],
default-features = false }
接下来,为了确保rayon在非 WASM 目标下仍然可用,我们通过在Cargo.toml中添加条件依赖来将其重新启用,如下所示:
[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
criterion = { version = "0.5.1", features = ["rayon"] }
通常,当目标是 WASM WASI 时,你可能需要修改依赖项及其 Cargo 特性以确保兼容性。有时这个过程很简单,但有时也会充满挑战——甚至是无法解决的,正如我们将在第 8 条规则中讨论的那样。
旁注:在本系列的第三篇文章中——关于
no_std和嵌入式——我们深入探讨了修复依赖项的策略。
在再次运行测试后,我们越过了之前的错误,遇到了一个新的错误,这也算是进步!
#[test]
fn test_demo_i32_len() {
assert_eq!(demo_i32_len(i32::MIN..=i32::MAX), u32::MAX as usize + 1);
^^^^^^^^^^^^^^^^^^^^^ attempt to compute
`usize::MAX + 1_usize`, which would overflow
}
编译器抱怨u32::MAX as usize + 1溢出。在 64 位 Windows 上,这个表达式不会溢出,因为usize与u64相同,能够容纳u32::MAX as usize + 1。然而,WASM 是一个 32 位环境,所以usize与u32相同,导致表达式超出了限制。
这里的修复是将usize替换为u64,确保表达式不会溢出。更一般来说,编译器不会总是捕捉到这些问题,因此审查你使用usize和isize是很重要的。如果你在引用 Rust 数据结构的大小或索引,usize是正确的。然而,如果你处理的是超过 32 位限制的值,应该使用u64或i64。
旁注:在 32 位环境中,Rust 数组、
Vec、BTreeSet等理论上可以容纳最多 2³²−1 = 4,294,967,295 个元素。然而,这只是基于可寻址内存的理论限制。旁注 旁注:实际的最大元素数量更加有限。Rust 将我们的分配限制为一个
[isize](https://doc.rust-lang.org/std/primitive.pointer.html#method.offset),因此是 2³¹−1(大约 20 亿)字节。如果每个元素是例如 2 字节,我们最多可以有约 10 亿个元素。
所以,我们修复了依赖项问题并解决了usize溢出问题。但是,我们能修复所有问题吗?不幸的是,答案是否定的。
规则 8:接受你不能改变所有东西:网络、Tokio、Rayon 等等。
WASM WASI Preview 1(当前版本)支持文件访问(在指定的目录内)、读取环境变量、以及处理时间和随机数。然而,与你期望的完整操作系统相比,它的功能是有限的。
如果你的项目需要访问网络、使用 Tokio 进行异步任务,或使用 Rayon 进行多线程操作,不幸的是,这些功能在 Preview 1 版本中不被支持。
幸运的是,预计 WASM WASI Preview 2 将在这些限制方面有所改进,提供更多功能,包括更好的网络支持,甚至可能支持异步任务。
规则 9:将 WASM WASI 添加到你的 CI(持续集成)测试中。
所以,你的测试在 WASM WASI 上通过了,项目也成功运行。就这样结束了吗?还不完全。因为,正如我喜欢说的:
如果它不在 CI 中,那就意味着它不存在。
持续集成(CI)是一种系统,它可以在你每次更新代码时自动运行测试,确保你的代码按预期继续工作。通过将 WASM WASI 集成到 CI 中,你可以确保未来的更改不会破坏你的项目与 WASM WASI 目标的兼容性。
在我的情况下,我的项目托管在 GitHub 上,使用 GitHub Actions 作为 CI 系统。以下是我添加到 .github/workflows/ci.yml 中的配置,用来在 WASM WASI 上测试我的项目:
test_wasip1:
name: Test WASI P1
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up Rust
uses: dtolnay/rust-toolchain@master
with:
toolchain: stable
targets: wasm32-wasip1
- name: Install Wasmtime
run: |
curl https://wasmtime.dev/install.sh -sSf | bash
echo "${HOME}/.wasmtime/bin" >> $GITHUB_PATH
- name: Run WASI tests
run: cargo test --verbose --target wasm32-wasip1
通过将 WASM WASI 集成到 CI 中,我可以自信地向我的项目中添加新代码。CI 会自动测试我的所有代码,确保它在未来仍然支持 WASM WASI。
所以,事情就是这样 —— 将 Rust 代码移植到 WASM WASI 的九条规则。这里是我在移植到 WASM WASI 时的惊讶之处:
缺点:
-
当前在 WASM WASI 上运行的实用性较低。然而,它有潜力在未来变得有用。
-
在 Rust 中,有一句常见的说法:“如果它能编译,就说明它能工作。” 不幸的是,这对于 WASM WASI 并不总是成立。如果你使用了一个不被支持的特性,比如网络功能,编译器不会捕捉到错误。相反,它会在运行时失败。例如,这段代码在 WASM WASI 上能编译并运行,但由于不支持网络功能,它总是返回错误。
use std::net::TcpStream;
fn main() {
match TcpStream::connect("crates.io:80") {
Ok(_) => println!("Successfully connected."),
Err(e) => println!("Failed to connect: {e}"),
}
}
优点:
-
在 WASM WASI 上运行是朝着在 浏览器中运行代码 和嵌入式系统上运行迈出的第一步。
-
你可以在 WASM WASI 上运行 Rust 代码,而不需要移植到
no_std。(移植到no_std是本系列的 第三篇文章 讨论的话题。) -
你可以在 WASM WASI 上运行标准的 Rust 测试,这使得验证你的代码变得容易。
-
.cargo/config.toml文件和 Rust 的--target选项使得在不同的目标平台上配置和运行代码变得异常简单,包括 WASM WASI。
敬请关注!在下一篇文章中,你将看到如何将 Rust 代码移植到浏览器中的 WASM 运行——这项能力我觉得非常有用。之后,最后一篇文章将讲解如何将代码移植到嵌入式系统,我觉得这非常酷。
附注: 对未来的文章感兴趣吗?请 在 Medium 上关注我。我写关于 Rust 和 Python、科学编程、机器学习和统计学的文章。我通常每月写一篇文章。
九个 Rust Cargo.toml 的 Wat 和 Wat Not
掌握 Cargo.toml 格式规则,避免挫败感
·发布于 Towards Data Science ·阅读时长 8 分钟 ·2024 年 7 月 24 日
--

Rust Cargo 惊讶 — 来源: openai.com/dall-e-2/。所有其他图表来自作者。
在 JavaScript 和其他语言中,我们称 这种令人惊讶或不一致的行为为“Wat!” [即“什么!?”]。例如,在 JavaScript 中,一个空数组加上一个空数组会产生一个空字符串,[] + [] === ""。Wat!
在另一个极端,某些语言表现出令人惊讶的一致性。我称之为“Wat Not”。
Rust 通常(远远)比 JavaScript 更一致。然而,一些与 Rust 相关的格式却带来惊喜。具体来说,这篇文章将讨论 Cargo.toml 中的九个“wat”和“wat not”。
回想一下,Cargo.toml 是定义 Rust 项目配置和依赖项的清单文件。它的格式 TOML(Tom's Obvious, Minimal Language)表示嵌套的键/值对和/或数组。JSON 和 YAML 是类似的格式。Tom 设计 TOML 时,像 YAML 一样,但不同于 JSON,他的目标是让人类更容易阅读和写作。
这段包含九个“wat”和“wat not”的旅程可能不会像 JavaScript 的怪癖那样有趣(谢天谢地)。然而,如果你曾经觉得 Cargo.toml 的格式令人困惑,希望这篇文章能让你对自己感觉更好。而且,更重要的是,当你学习了这九个“wat”和“wat not”之后,希望你能更轻松、更高效地编写 Cargo.toml。
本文并不是要“修复” Cargo.toml。该文件格式在其主要用途上表现得非常出色:指定 Rust 项目的配置和依赖项。相反,本文的目的是帮助理解这种格式及其怪癖。
什么?1:依赖项与配置文件部分名称
你可能知道如何向Cargo.toml添加[dependencies]部分。这样的部分指定发布依赖项,例如:
[dependencies]
serde = "1.0"
类似地,你可以通过[dev-dependencies]部分指定开发依赖,通过[build-dependencies]部分指定构建依赖。
你可能还需要设置编译器选项,例如优化级别和是否包含调试信息。你可以通过配置文件部分来设置这些选项,分别用于发布、开发和构建。你能猜出这三个部分的名称吗?是[profile]、[dev-profile]和[build-profile]吗?
不!是[profile.release]、[profile.dev]和[profile.build]。什么?
[dev-profile]比[profile.dev]更好吗?[dependencies.dev]比[dev-dependencies]更好吗?
我个人更喜欢带点的名称。(在“什么?不是 9”中,我们将看到点的强大功能。)不过,我也愿意记住依赖项和配置文件的工作方式不同。
什么?2:依赖继承
你可能会争辩说,点在配置文件中是可以的,但连字符在依赖项中更好,因为[dev-dependencies]继承自[dependencies]。换句话说,[dependencies]中的依赖项在[dev-dependencies]中也可以使用。那么,这是否意味着[build-dependencies]继承自[dependencies]?
不![build-dependencies]并不继承自[dependencies]。什么?
我觉得这种Cargo.toml的行为既方便又令人困惑。
什么?3:默认键
你可能知道,这样写不好:
[dependencies]
serde = { version = "1.0" }
你可以这样写:
[dependencies]
serde = "1.0"
这里的原则是什么?在一般的 TOML 中,如何指定一个键为默认键?
你不能!一般的 TOML 没有默认键。什么?
Cargo TOML 对[dependencies]部分中的version键进行特殊处理。这是 Cargo 特有的功能,而不是一般的 TOML 功能。据我所知,Cargo TOML 没有其他默认键。
什么?4:子特性
通过Cargo.toml的[features],你可以创建不同依赖项版本的项目。这些依赖项本身也可能在特性上有所不同,我们称之为子特性。
在这里,我们创建了项目的两个版本。默认版本依赖于getrandom,并使用默认特性。wasm版本依赖于getrandom,并使用js子特性:
[features]
default = []
wasm = ["getrandom-js"]
[dependencies]
rand = { version = "0.8" }
getrandom = { version = "0.2", optional = true }
[dependencies.getrandom-js]
package = "getrandom"
version = "0.2"
optional = true
features = ["js"]
在这个示例中,wasm是我们项目的一个特性,依赖于别名getrandom-rs,它表示带有js子特性的getrandom crate 的版本。
那么,我们如何在避免冗长的[dependencies.getrandom-js]部分的同时,给出相同的规范呢?
在[features]中,将getrandom-js替换为"getrandom/js"。我们可以简单地写:
[features]
default = []
wasm = ["getrandom/js"]
[dependencies]
rand = { version = "0.8" }
getrandom = { version = "0.2", optional = true }
什么?
通常,在Cargo.toml中,特性规范如wasm = ["getrandom/js"]可以列出
-
其他特性
-
依赖别名
-
依赖项
-
一个或多个依赖项“斜杠”一个子特性
这不是标准 TOML,而是Cargo.toml特有的简写。
奖励:猜猜你如何使用简写来表示 wasm 特性应该包括 getrandom,并且有两个子特性:js 和 test-in-browser?
答案:列出依赖项两次。
wasm = ["getrandom/js","getrandom/test-in-browser"]
Wat 5:目标的依赖项
我们已经看到如何为发布、调试和构建指定依赖项。
[dependencies]
#...
[dev-dependencies]
#...
[build-dependencies]
#...
我们已经看到如何为各种特性指定依赖项:
[features]
default = []
wasm = ["getrandom/js"]
你如何猜测我们为各种目标(例如某个版本的 Linux、Windows 等)指定依赖项呢?
我们使用 target.*TARGET_EXPRESSION* 来前缀 [dependences],例如:
[target.x86_64-pc-windows-msvc.dependencies]
winapi = { version = "0.3.9", features = ["winuser"] }
按照一般 TOML 的规则,这也意味着我们可以这样说:
[target]
x86_64-pc-windows-msvc.dependencies={winapi = { version = "0.3.9", features = ["winuser"] }}
Wat!
我觉得这个前缀语法很奇怪,但我无法提出更好的替代方案。不过,我确实好奇为什么特性不能以相同的方式处理:
# not allowed
[feature.wasm.dependencies]
getrandom = { version = "0.2", features=["js"]}
Wat Not 6:目标 cfg 表达式
这是我们的第一个“Wat Not”,即它是以其一致性让我感到惊讶的部分。
你可以使用 cfg 表达式(用单引号包裹)来代替具体的目标,如 x86_64-pc-windows-msvc。例如,
[target.'cfg(all(windows, target_arch = "x86_64"))'.dependencies]
我不认为这算是“wat!”,我认为这很棒。
回顾 cfg,它是“配置”(configuration)的缩写,是 Rust 用来有条件编译代码的机制。例如,在我们的 main.rs 中,我们可以写:
if cfg!(target_os = "linux") {
println!("This is Linux!");
}
在 Cargo.toml 中,目标表达式几乎支持整个 [cfg](https://doc.rust-lang.org/reference/conditional-compilation.html) 小语言。
all(), any(), not()
target_arch
target_feature
target_os
target_family
target_env
target_abi
target_endian
target_pointer_width
target_vendor
target_has_atomic
unix
windows
cfg 小语言中唯一不支持的部分是(我认为)你不能通过 --cfg 命令行参数设置一个值。另外,像 test 这样的 cfg 值没有意义。
Wat 7:目标的配置文件
回顾 Wat 1,你可以通过 [profile.release]、[profile.dev] 和 [profile.build] 设置编译器选项。例如:
[profile.dev]
opt-level = 0
你能猜到如何为特定目标(如 Windows)设置编译器选项吗?是这样吗?
[target.'cfg(windows)'.profile.dev]
opt-level = 0
否。相反,你需要创建一个名为 .cargo/config.toml 的新文件,并添加以下内容:
[target.'cfg(windows)']
rustflags = ["-C", "opt-level=0"]
Wat!
一般来说,Cargo.toml 只支持 target.*TARGET_EXPRESSION* 作为依赖项部分的前缀。你不能为配置文件部分添加前缀。但是,在 [.cargo/config.toml](https://doc.rust-lang.org/cargo/reference/config.html) 中,你可以有 [target.*TARGET_EXPRESSION*] 部分。在这些部分中,你可以设置环境变量来设置编译器选项。
Wat Not 8:TOML 列表
Cargo.toml 支持两种列表语法:
-
内联数组
-
表格数组
这个例子同时使用了以下两种:
[package]
name = "cargo-wat"
version = "0.1.0"
edition = "2021"
[dependencies]
rand = { version = "0.8" }
# Inline array 'features'
getrandom = { version = "0.2", features = ["std", "test-in-browser"] }
# Table array 'bin'
[[bin]]
name = "example"
path = "src/bin/example.rs"
[[bin]]
name = "another"
path = "src/bin/another.rs"
我们能把表格数组改成内联数组吗?可以!
# Inline array 'bin'
bins = [
{ name = "example", path = "src/bin/example.rs" },
{ name = "another", path = "src/bin/another.rs" },
]
[package]
name = "cargo-wat"
version = "0.1.0"
edition = "2021"
[dependencies]
rand = { version = "0.8" }
# Inline array 'features'
getrandom = { version = "0.2", features = ["std", "test-in-browser"] }
我们能把特性的内联数组改成表格数组吗?
否。内联数组(在这里是字符串类型)不能表示为表格数组。然而,我认为这算是一个“wat not”,而不是“wat!” 因为这是通用 TOML 的限制,而不仅仅是 Cargo.toml 的限制。
顺便提一下,YAML 格式和 TOML 格式一样,提供了两种列表语法。不过,YAML 的两种语法都能与简单值一起使用。
Wat Not 9:TOML 内联、节和点
这是一个典型的Cargo.toml。它混合了部分语法,如[dependences],以及内联语法,如getrandom = {version = "0.2", features = ["std", "test-in-browser"]}。
[package]
name = "cargo-wat"
version = "0.1.0"
edition = "2021"
[dependencies]
rand = "0.8"
getrandom = { version = "0.2", features = ["std", "test-in-browser"] }
[target.x86_64-pc-windows-msvc.dependencies]
winapi = { version = "0.3.9", features = ["winuser"] }
[[bin]]
name = "example"
path = "src/bin/example.rs"
[[bin]]
name = "another"
path = "src/bin/another.rs"
我们能把它重写成 100%内联的形式吗?可以。
package = { name = "cargo-wat", version = "0.1.0", edition = "2021" }
dependencies = { rand = "0.8", getrandom = { version = "0.2", features = [
"std",
"test-in-browser",
] } }
target = { 'cfg(target_os = "windows")'.dependencies = { winapi = { version = "0.3.9", features = [
"winuser",
] } } }
bins = [
{ name = "example", path = "src/bin/example.rs" },
{ name = "another", path = "src/bin/another.rs" },
]
我们也可以将其重写为最大化使用部分:
[package]
name = "cargo-wat"
version = "0.1.0"
edition = "2021"
[dependencies.rand]
version = "0.8"
[dependencies.getrandom]
version = "0.2"
features = ["std", "test-in-browser"]
[target.x86_64-pc-windows-msvc.dependencies.winapi]
version = "0.3.9"
features = ["winuser"]
[[bin]]
name = "example"
path = "src/bin/example.rs"
[[bin]]
name = "another"
path = "src/bin/another.rs"
最后,让我们谈谈点号。在 TOML 中,点号用于分隔嵌套表中的键。例如,a.b.c表示表a中的表b中的键c。我们能用“很多点号”重写我们的例子吗?可以:
package.name = "cargo-wat"
package.version = "0.1.0"
package.edition = "2021"
dependencies.rand = "0.8"
dependencies.getrandom.version = "0.2"
dependencies.getrandom.features = ["std", "test-in-browser"]
target.x86_64-pc-windows-msvc.dependencies.winapi.version = "0.3.9"
target.x86_64-pc-windows-msvc.dependencies.winapi.features = ["winuser"]
bins = [
{ name = "example", path = "src/bin/example.rs" },
{ name = "another", path = "src/bin/another.rs" },
]
我很欣赏 TOML 在处理部分、内联和点号方面的灵活性。我将这种灵活性视为一种“wat not”。你可能会觉得它提供的所有选择令人困惑。不过,我喜欢Cargo.toml让我们能够充分发挥 TOML 的强大功能。
你可以在GitHub 上查看这个例子。
结论
Cargo.toml是 Rust 生态系统中的一个重要工具,它提供了简单性和灵活性的平衡,既适合初学者,也适合经验丰富的开发者。通过我们探索的九个 wats 和 wat nots,我们看到这个配置文件有时因其怪癖而令人惊讶,但也因其一致性和强大功能而让人印象深刻。
理解这些怪癖可以帮助你避免潜在的挫折,并使你能够最大化地利用Cargo.toml。从管理依赖关系和配置文件到处理特定目标的配置和特性,在这里获得的洞察将帮助你编写更高效、更有效的Cargo.toml文件。
本质上,虽然Cargo.toml可能有些特殊性,但这些特性通常根源于实际的设计选择,优先考虑功能性和可读性。接受这些怪癖,你会发现Cargo.toml不仅能满足你项目的需求,还能提升你的 Rust 开发体验。
请 关注 Carl 在 Medium 上的文章。我在 Rust 和 Python 的科学编程、机器学习以及统计学方面写作。我通常每个月写一篇文章。
图解 NLP,第一部分:文本编码
一份关于文本到数字翻译的图解指南,带有代码
·发布于Towards Data Science ·阅读时长 11 分钟·2024 年 11 月 19 日
--
欢迎回到这个互联网角落,在这里我们将复杂的机器学习概念形象化——最终发现它们其实并没有那么复杂!
今天,我们将启动一个关于自然语言处理(NLP)的新系列。这令人兴奋,因为 NLP 是我们在各个地方看到的所有炫酷的大型语言模型(LLMs)的基础——想想 Claude、GPT 和 Llama。
简单来说,NLP 帮助机器理解人类语言——无论是理解、分析,还是生成语言。
如果你一直在跟随我们的深度学习之旅,我们已经了解到,神经网络的核心原理非常简单:它们接受输入,进行数学运算,然后输出结果。

然而,为了让神经网络做到这一点,输入和输出必须以它们能够理解的格式呈现:数字。
这一规则适用于我们处理简单模型时……

NLP 插图,第二部分:词嵌入
一本关于词嵌入的插图和直观指南
·发布于Towards Data Science ·阅读时间:8 分钟·2024 年 11 月 27 日
--
欢迎来到我们 NLP 系列的第二部分。如果你已经阅读过第一部分,你会记得我们正在解决的挑战是将文本转换为数字,以便将其输入到我们的机器学习模型或神经网络中。
一本关于文本到数字转换的插图指南,包含代码
towardsdatascience.com
之前,我们探讨了一些基本的(也非常初步的)方法,如词袋模型和 TF-IDF。虽然这些方法能够完成任务,但我们也看到了它们的局限性——主要是它们无法捕捉单词的深层含义或单词之间的关系。
这就是词嵌入发挥作用的地方。它们提供了一种更智能的方式来将文本表示为数字,不仅捕捉了单词本身,还能够表达它们的含义和上下文。
让我们通过一个简单的类比来分解这个概念,使它变得非常直观。
假设我们想将电影表示为数字。以电影利刃出鞘为例。

来源:维基百科
我们可以通过在不同特征上对电影进行评分来用数字表示一部电影,比如…
NLP:房地产出租房源的文本摘要与关键词提取——第一部分
在出租房源数据上实施 NLP 技术的实际应用,例如文本摘要、NER、主题建模和文本分类
·发表于Towards Data Science ·阅读时间 10 分钟·2024 年 7 月 8 日
--
介绍
自然语言处理(NLP)可以显著提升出租房源描述的分析和可用性。在本次实践中,我们将探索 NLP 技术的实际应用,例如文本摘要、命名实体识别(NER)和主题建模,以提取洞察并丰富东京 Airbnb 房源数据的描述。使用公开的可用数据和像 spaCy 与 SciKit-Learn 这样的工具,您可以跟随教程,复制结果,或将这些技术应用于自己的文本数据,只需进行最小的调整。代码库可在GitHub上获取,您可以进行分叉并进行实验。

本文展示了使用多种 NLP 技术,从房产出租描述数据(左)中提取信息,并将其转化为更具信息量的描述(右)。文中的所有图片均由作者制作。代码和 Jupyter 笔记本可在GitHub上找到,数据可以在insideairbnb.com上获取,并遵循创意共享署名协议。
第一部分(本文)涵盖基础内容: 目标、数据及其准备工作,以及用于提取关键词和文本摘要的各种技术,如命名实体识别(NER)、TF-IDF / 句子评分、以及谷歌的 T5(文本到文本的转换器)。我们还将涉及如何利用这些见解来提升用户体验 — 包括服务建议。
第二部分(即将发布) 涵盖主题建模和文本预测:第二部分将展示如何在无标签数据上执行主题建模。即将发布的文章将讨论诸如聚类等技术,帮助揭示隐藏的主题,并构建一个预测模型,以根据房源类别和主题对租赁房源进行分类。
目标
任务很简单:
给定的示例输入: 租赁描述
生成输出:
-
关键词: “商业街”、”商店”、或 “靠近车站”
关键词有助于可视化数据、揭示主题、识别相似性,并改善前端的搜索功能。有关如何使用这些关键词的建议,请参见本文底部。
-
摘要: 一到两句话,约 80 个字符。
摘要提供简洁的信息,通过快速传达列表中的最重要方面,提升用户体验。
-
主题/话题: “优越的交通连接”、”适合家庭入住”
对共享相同主题的房源进行分类可以作为推荐系统,帮助用户找到符合他们偏好的房源。与单个关键词不同,这些主题可以涵盖多个关键词(如厨房、桌子、单人床、长期出租 => “数字游牧者友好”)。我们将在第二部分(即将发布的文章)深入讨论这个问题。
章节:
-
数据与准备
获取数据、清理数据、定制词形还原
-
文本摘要
TFIDF/句子评分、深度学习、LLM(T5)、评估
-
使用 NER 提取关键词
正则表达式、匹配器、深度学习
-
服务建议
1. 数据与准备
我们的数据集由来自insideairbnb.com的租赁房源描述组成,遵循创意共享署名 4.0 国际许可证。我们专注于物业所有者撰写的文本。数据包含近 15,000 个租赁描述,主要为英文。用日文书写的记录(令人惊讶的是,只有少数几条!)在数据清理过程中已被移除,数据清理还包括去除重复记录和刮取器留下的 HTML 残余。由于大量数据去重,可能是由于网络抓取工具的副产品,或者可能是更复杂的问题(例如,房东发布了多个相同的房源),数据清理使得原始数据量减少了约一半。
1a. spaCy 流水线
一旦数据清洗完成,我们就可以开始构建 spaCy 管道。我们可以从一个空白模板开始,或者使用像 en_core_web_sm 这样的预训练模型来处理英文文档。这个模型包含一个强大的管道,包含:
-
分词(Tokenization): 将文本拆分为单词、标点符号等。
-
词性标注(Part-of-Speech Tagging): 将单词标记为名词、动词等。
-
依存句法分析(Dependency Parsing): 识别单词之间的关系。
-
句子分割器(Sentencizer): 将文档拆分为句子。
-
词形还原(Lemmatization): 将词汇简化为其基本形式(例如,seeing、see、saw、seen)。
-
属性规则(Attribute Ruler): 添加、删除或更改标记的属性。
-
命名实体识别(NER): 识别命名实体的类别(人名、地名等)。
1b. 自定义词形还原
即使是像 en_core_web_sm 这样的经过严格测试的管道,通常也需要进行调整以涵盖特定的用例。例如,租赁行业中常用的缩写(例如,br 代表卧室,apt 代表公寓,st 代表街道)可以通过自定义词形还原引入到管道中。为了评估这一点,我们可以比较在有和没有自定义词形还原的管道中,token.lemma_的数量。如果需要,还可以使用其他更强大的预制管道,如 en_core_web_md(中型)或 en_core_web_lg(大型)。
在生产级项目中,需要更全面的列表,可能还需要更严格的数据清洗。例如,表情符号和类似表情符号的符号经常出现在受文化影响的写作中,如日本用户的写作中。这些符号可能会引入噪音,需要特定的处理,如删除或转换。其他数据预处理,如更强大的句子边界检测器,也可能是必要的,以处理缺少空格的句子,例如“这是一个句子。这也是。还有这个。还有这个。但是,不,这个 Next.js 是一个有效的术语,而不是两个句子!”
2. 文本摘要
在东京选择租赁选项可能让人不知所措。每个房源都声称是理想的家。然而,数据显示,房产描述常常不尽如人意——它们可能过于冗长,令人沮丧地简短,或者被不相关的细节弄得杂乱无章;这就是为什么文本摘要技术非常有用的原因。

句子评分,以选择最具信息量的句子作为摘要(右图),来自描述(左图)。
2a. 难度级别:简单 — TF-IDF
一种典型的文本摘要方法是使用 TF-IDF(词频-逆文档频率)技术。TF-IDF 同时考虑一个单词在特定文档(例如租赁列表)中的出现频率以及它在整个数据集或语料库中出现的稀有程度。这项技术对于各种文本分析任务也非常有用,如索引、相似度检测和聚类(我们将在第二部分中探讨)。
根据检测到的关键词的相关性计算句子的排名。
该技术的另一种变体是基于词共现的句子评分。与 TF-IDF 类似,这种方法通过比较文档中的单词出现情况来计算得分。该方法快速且简单,不需要额外的工具或其他文档的意识。即使在前端使用 TypeScript,你也可以随时执行此操作,尽管不推荐这样做。
然而,像这样的抽取式摘要技术有一个缺陷:它们只找到文档中的最佳句子,这意味着所选句子中的拼写错误或其他问题会出现在摘要中。这些拼写错误还会影响评分,使得该模型对错误不够宽容,而没有包括在所选句子(或句子)中的重要信息可能会被遗漏。
2b. 级别:中级 — 深度学习
除了基于频率的方法,我们还可以利用深度学习的力量进行文本摘要。序列到序列(Seq2Seq)模型是一种神经网络架构,旨在将序列从一种形式转换为另一种形式。在文本摘要任务中,这些模型充当复杂的翻译器。
一个 Seq2Seq 模型通常由两部分组成:编码器和解码器。编码器处理整个输入文本,捕捉其含义和结构。然后,这些信息被压缩成一个隐藏的表示。接下来,解码器使用来自编码器的隐藏表示生成新的序列——文本摘要。在训练过程中,解码器学习如何将捕捉原始文本关键信息的编码表示进行转换。与抽取式方法不同,这些模型执行抽象式摘要:用自己的话生成摘要,而不是直接从文本中提取句子。
2c. 级别:高级 — 预训练语言模型
对于抽象式摘要,可以考虑使用 T5(Text-To-Text Transfer Transformer)模型。虽然 t5-small 提供了一个不错的起点,但你可能会通过更大的模型,如 t5-base 或 t5-large,获得更好的结果。请注意,较大的模型可能需要更多的计算资源,并且运行时间较长。

大型语言模型(LLMs)能够以富有创意的方式进行文档总结(不仅仅是复制句子),但要获得最佳效果,可能需要在摘要过程中、摘要前后进行额外的步骤,包括适当的提示工程。
像 T5(Text-To-Text Transfer Transformer)或 BERT(Bidirectional Encoder Representations from Transformers)这样的预训练语言模型,可以显著提升摘要效果,适用于那些拥有资源和设置能力的人。然而,尽管这些模型对于大文本有效,但它们可能对于这个特定的用例来说有些过于复杂。它不仅需要更多的设置来发挥最佳效果,还包括提示工程(预处理)、重新训练或微调,以及后处理(如语法、文本大写,甚至事实检查和合理性检查),以引导模型达到期望的输出。
2.d 评估文本摘要

提取式(左)与抽象式(右)文本摘要。鉴于摘要质量是主观的,哪种摘要更优并不总是明确的。考虑到所需的努力、成本和计算能力,比较变得更加复杂。
从上面的图片可以看到,当比较使用 TFIDF 的“简单”模型与使用 LLM 的复杂模型时,哪种模型更优并不总是显而易见的。评估文本摘要系统的质量是一个复杂的挑战。与有明确单一答案的任务不同,对于给定文本,并没有一个完美的摘要。人类可以优先考虑原始内容的不同方面,这使得设计与人类判断完全一致的自动化评估指标变得更加困难。
像 ROUGE(召回导向总结评估指标)这样的评估指标旨在实现这一点。通过比较生成的摘要与人工编写的摘要之间的 n-gram(词组序列)重叠,ROUGE 系统地评分摘要的质量。该方法依赖于一组人工编写的摘要作为评估的基准,但这些人工摘要通常并不总是可用的。
3. 使用命名实体识别(NER)进行关键词提取
尽管摘要很有帮助,但关键词有不同的用途。关键词捕捉了潜在租客可能关注的最关键方面。为了提取关键词,我们可以使用 NLP 技术,例如命名实体识别(NER)。这个过程不仅仅是识别频繁出现的词汇。通过考虑诸如词语共现和与租赁列表领域相关性等因素,我们可以提取出关键信息。这些信息可以是单个词,例如‘豪华的’(形容词)、‘银座’(地点),或者像‘安静的环境’(名词短语)或‘靠近新宿’(接近性)这样的短语。

评估 NER:SpaCy 内置的命名实体识别(NER)表现良好,但某些实体类型可能需要额外的训练数据以达到最佳准确度。(NER 代表命名实体识别,GPE:地理政治实体)
3a. 难度:简单 — 正则表达式
字符串操作中的“find”函数,加上正则表达式,可以完成关键词查找的工作。然而,这种方法需要一个详尽的单词和模式列表,而这在某些情况下并不实际。如果有一个详尽的关键词列表可供查找(例如,金融相关项目中的股票交易所缩写),正则表达式可能是最简单的方式。
3b. 水平:中级 — 匹配器
虽然正则表达式可以用于简单的关键词提取,但由于需要大量的规则列表,覆盖所有情况变得非常困难。幸运的是,大多数自然语言处理(NLP)工具都具备开箱即用的命名实体识别(NER)功能。例如,Natural Language Toolkit(NLTK)有命名实体分块器,而 spaCy 则有匹配器(Matcher)。
匹配器允许你根据语言特征,如词性标签或特定关键词,定义模式。这些模式可以与租赁描述进行匹配,从而识别相关的关键词和短语。这种方法能够捕捉单个词(如东京)和有意义的短语(如美丽的房子),这些更能代表房产的卖点。
noun_phrases_patterns = [
[{'POS': 'NUM'}, {'POS': 'NOUN'}], #example: 2 bedrooms
[{'POS': 'ADJ', 'OP': '*'}, {'POS': 'NOUN'}], #example: beautiful house
[{'POS': 'NOUN', 'OP': '+'}], #example: house
]
# Geo-political entity
gpe_patterns = [
[{'ENT_TYPE': 'GPE'}], #example: Tokyo
]
# Proximity
proximity_patterns = [
# example: near airport
[{'POS': 'ADJ'}, {'POS': 'ADP'}, {'POS': 'NOUN', 'ENT_TYPE': 'FAC', 'OP': '?'}],
# example: near to Narita
[{'POS': 'ADJ'}, {'POS': 'ADP'}, {'POS': 'PROPN', 'ENT_TYPE': 'FAC', 'OP': '?'}]
]
3c. 水平:高级 — 基于深度学习的匹配器
即使使用匹配器,一些术语也可能由于句子中单词的上下文未被规则匹配捕获。例如,匹配器可能会漏掉像“离上野公园一箭之遥”这样的术语,因为它无法通过任何预定义的模式,或者将“新宿歌舞伎町”误认为是一个人名(它是一个区域,或者是地点(LOC))。
在这种情况下,基于深度学习的方法可能更为有效。通过在包含相关关键词的大量租赁列表上进行训练,这些模型能够学习单词之间的语义关系。这使得这种方法更能适应不断变化的语言使用,并能够揭示潜在的洞察。
使用 spaCy 进行基于深度学习的 NER 非常简便。然而,这种方法的主要构建块通常是标注的训练数据,正如本次练习中的情况一样。标签是目标术语和实体名称的配对(例如:“a stone throw away”是名词短语——或者如图所示:新宿歌舞伎町是一个地点(LOC),而非人名),并以特定的方式格式化。与基于规则的方法不同,基于规则的方法是通过内置功能将术语描述为名词、地点等,而在此方法中,需要数据探索或领域专家来发现我们想要识别的目标术语。
本文的第二部分将讨论使用聚类、引导法和其他方法从数据中发现主题或标签的技术,以进行主题建模。
4. 食用建议
提取的关键词对后台和前端应用都非常有价值。我们可以利用它们进行各种后续分析,如主题和话题探索(将在第二部分讨论)。在前端,这些关键词可以帮助用户找到具有相似特征的列表——可以将它们视为 Instagram 或 Twitter 上的标签(但这是自动的!)。你还可以突出显示这些关键词,或者让它们成为可点击的链接。例如,命名实体识别(NER)可以识别出诸如“Iidabashi”或“Asakusa”这样的地点。当用户将鼠标悬停在这些关键词上时,弹出窗口可以显示有关这些地点的相关信息。
摘要提供了列表的简洁概述,非常适合快速掌握关键信息,或用于移动设备显示。

关键词和文本摘要可以丰富用户体验。在这个例子中,我们使用提取的文本摘要来快速概览列表描述。选择的关键词(例如 LOC)也被用来提供更多列表描述的背景。这一过程可以在后台进行(以提高加载速度),也可以在前端进行(以提高便利性)。
向前迈进
在本文中,我们展示了各种自然语言处理(NLP)技术的实际应用,如文本摘要和命名实体识别(NER)在租赁列表数据集上的应用。这些技术通过提供简洁、信息丰富且易于搜索的租赁列表,能够显著改善用户体验。
在接下来的文章(第二部分)中,我们将使用聚类等方法来发现隐藏的主题和标签。这将使我们能够构建一个强大的模型,充当推荐引擎。我们还将进一步探索像主题建模和文本分类等高级 NLP 技术,以增强租赁列表描述的分析和可用性。
以上です★これからもうよろしくおねがいします☆また今度。
**注意:
- Github 仓库:https://github.com/kristiyanto/nlp_on_airbnb_dataset 2) 数据 (创作共用 4.0 国际许可协议):** https://insideairbnb.com/get-the-data/ 3) 本文中的所有图片由作者制作。
无代码 GenAI 代理工作流编排:带本地 Mistral AI 模型的 AutoGen Studio
AutoGen 和 Mistral AI 介绍:
·发布于 Towards Data Science ·9 分钟阅读·2024 年 1 月 23 日
--
免费链接 — 请帮忙点赞这个 LinkedIn 帖子
AutoGen 是微软开发的一个框架,旨在简化多代理应用程序的开发,特别是在编排大型语言模型(LLM)代理方面。
多代理应用程序涉及多个 LLM 或多模态代理或实体在整个工作流中相互作用,以实现特定的目标或任务。这些代理可以是 LLM 代理、检索代理或其他能够做出独立决策、调用功能或采取行动的代理。
如果你想了解更多关于 AutoGen 的信息,可以参考我之前的文章:AutoGen 深入浅出.
Mistral AI 是一家法国的人工智能公司,由前 Meta 和 Google 的研究人员于 2023 年 4 月创立。该公司专注于开发开放的大型语言模型(LLM),并强调开源人工智能模型的重要性。
在本文中,我们将重点介绍 AutoGen Studio 直观的无代码平台与本地集成的 Mistral AI 模型的革命性融合。这种结合不仅仅是让人工智能更容易应用,它还促进了我们如何与不同的生成性 AI 代理互动、部署和从中受益,尤其是在许多现实行业工作流程中。
没有 GPU,没派对:使用 Vertex AI 自定义作业微调 BERT 进行情感分析
通过无服务器作业加速训练过程
·发表于Towards Data Science ·13 分钟阅读·2024 年 6 月 3 日
--

TL;DR:如何在 Vertex 上使用 GPU 启动 Pytorch 训练作业。包含示例代码。
在我之前的文章中,我提到过,当本地训练大规模模型时,如果资源有限,这并不是一个好习惯。有时候你根本没有选择,但有时候,你可以使用像 Google Cloud Platform 这样的云服务提供商,它能够显著加速你的训练过程,方法如下:
-
提供您定制配置(内存、GPU 等)的先进机器
-
允许您同时启动多个作业,并最终选择最佳模型
更不用说,将训练任务卸载到云端将减轻您个人机器的负担。我已经亲眼看到,个人笔记本电脑训练模型一周后电池几乎融化。假期回来后,我的触控板居然快要掉出来了。
在本文中,我们将以一个具体的用例为例,展示如何在社交媒体评论上微调 BERT 模型进行情感分析。正如我们所看到的,使用 CPU 训练这种模型是非常繁琐且不理想的。因此,我们将探讨如何利用 Google Cloud Platform 通过仅花费 60 美分来加速这一过程,使用 GPU 进行训练。
总结
-
什么是 BERT
-
什么是情感分析
-
获取并准备数据
-
使用小型 BERT 预训练模型
-
创建数据加载器。
-
编写主要脚本以训练模型。
-
将脚本 Docker 化。
-
构建并推送镜像到 Google Cloud。
-
在 Vertex AI 上创建一个作业。
什么是 BERT?
BERT 代表双向编码器表示(Bidirectional Encoder Representations from Transformers),由 Google 于 2018 年开源。它主要用于 NLP 任务,因为它被训练用来捕捉句子的语义并提供丰富的词向量(表示)。与其他模型如 Word2Vec 和 Glove 的不同之处在于,它使用 Transformers 来处理文本。Transformers(如果你想了解更多,可以参考我之前的文章)是一类神经网络,它们有点像 RNN,可以双向处理序列,因此能够捕捉到例如一个词的上下文。
什么是情感分析?
情感分析是 NLP 领域中的一项特定任务,目标是将文本分类为与其情感色彩相关的类别。情感色彩通常表现为积极、消极或中立。它通常用于分析文字记录、社交媒体上的帖子、产品评论等。
在社交媒体数据上微调 BERT 模型。
获取和准备数据。
我们将使用的数据集来自 Kaggle,你可以在这里下载:www.kaggle.com/datasets/farisdurrani/sentimentsearch(CC BY 4.0 许可证)。在我的实验中,我只选择了来自 Facebook 和 Twitter 的数据集。
以下代码片段将处理 csv 文件,并将数据分割为 3 部分(训练集、验证集和测试集),然后保存到你指定的位置。我建议将它们保存在 Google Cloud Storage 中。
你可以使用以下命令运行脚本:
python make_splits --output-dir gs://your-bucket/
import pandas as pd
import argparse
import numpy as np
from sklearn.model_selection import train_test_split
def make_splits(output_dir):
df=pd.concat([
pd.read_csv("data/farisdurrani/twitter_filtered.csv"),
pd.read_csv("data/farisdurrani/facebook_filtered.csv")
])
df = df.dropna(subset=['sentiment'], axis=0)
df['Target'] = df['sentiment'].apply(lambda x: 1 if x==0 else np.sign(x)+1).astype(int)
df_train, df_ = train_test_split(df, stratify=df['Target'], test_size=0.2)
df_eval, df_test = train_test_split(df_, stratify=df_['Target'], test_size=0.5)
print(f"Files will be saved in {output_dir}")
df_train.to_csv(output_dir + "/train.csv", index=False)
df_eval.to_csv(output_dir + "/eval.csv", index=False)
df_test.to_csv(output_dir + "/test.csv", index=False)
print(f"Train : ({df_train.shape}) samples")
print(f"Val : ({df_eval.shape}) samples")
print(f"Test : ({df_test.shape}) samples")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--output-dir')
args, _ = parser.parse_known_args()
make_splits(args.output_dir)
数据大致应如下所示:

(作者提供的图片)
使用小型 BERT 预训练模型。
对于我们的模型,我们将使用轻量级 BERT 模型 BERT-Tiny。该模型已经在大量数据上进行了预训练,但不一定是社交媒体数据,也不一定是为了进行情感分析而预训练的。因此,我们将对其进行微调。
它仅包含 2 层,每层有 128 个单元,完整的模型列表可以在这里查看,如果你想使用更大的模型。
首先让我们创建一个main.py文件,包含所有必要的模块:
import pandas as pd
import argparse
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
import logging
import os
os.environ["TFHUB_MODEL_LOAD_FORMAT"] = "UNCOMPRESSED"
def train_and_evaluate(**params):
pass
# will be updated as we go
让我们也在一个专用的requirements.txt文件中写下我们的需求。
transformers==4.40.1
torch==2.2.2
pandas==2.0.3
scikit-learn==1.3.2
gcsfs
我们现在将加载两部分数据来训练我们的模型:
-
分词器,它将负责将文本输入拆分为 BERT 训练时使用的词元。
-
模型本身。
你可以从 Huggingface这里获得这两者。你也可以将它们下载到 Cloud Storage 中。我就是这么做的,因此会通过以下方式加载它们:
# Load pretrained tokenizers and bert model
tokenizer = BertTokenizer.from_pretrained('models/bert_uncased_L-2_H-128_A-2/vocab.txt')
model = BertModel.from_pretrained('models/bert_uncased_L-2_H-128_A-2')
现在让我们将以下内容添加到文件中:
class SentimentBERT(nn.Module):
def __init__(self, bert_model):
super().__init__()
self.bert_module = bert_model
self.dropout = nn.Dropout(0.1)
self.final = nn.Linear(in_features=128, out_features=3, bias=True)
# Uncomment the below if you only want to retrain certain layers.
# self.bert_module.requires_grad_(False)
# for param in self.bert_module.encoder.parameters():
# param.requires_grad = True
def forward(self, inputs):
ids, mask, token_type_ids = inputs['ids'], inputs['mask'], inputs['token_type_ids']
# print(ids.size(), mask.size(), token_type_ids.size())
x = self.bert_module(ids, mask, token_type_ids)
x = self.dropout(x['pooler_output'])
out = self.final(x)
return out
在这里稍作休息。我们在重用现有模型时有几种选择。
-
迁移学习:我们冻结模型的权重,并将其作为“特征提取器”。因此,我们可以在后续添加额外的层。这在计算机视觉中非常常见,比如 VGG、Xception 等模型可以在小数据集上被重新训练,作为自定义模型的一部分。
-
微调:我们解冻模型的全部或部分权重,并在自定义数据集上重新训练模型。这是在训练自定义 LLM 时的首选方法。
关于迁移学习和微调的更多细节,请参见这里:
在模型中,我们选择解冻整个模型,但也可以选择冻结预训练 BERT 模块中的一层或多层,看看它对性能的影响。
这里的关键是,在 BERT 模块后添加一个全连接层,将其“连接”到我们的分类任务,因此最终的层有 3 个单元。这将使我们能够重用预训练 BERT 的权重,并将我们的模型调整到我们的任务。
创建数据加载器
要创建数据加载器,我们将需要上述加载的 Tokenizer。Tokenizer 接受一个字符串作为输入,并返回多个输出,其中包括我们可以找到的标记(在我们的案例中是‘input_ids’):

BERT 的分词器有点特殊,它会返回多个输出,但最重要的是input_ids:它们是用于编码我们句子的标记。它们可能是单词,或者单词的一部分。例如,单词“looking”可能由两个标记组成:“look”和“##ing”。
现在让我们创建一个数据加载器模块来处理我们的数据集:
class BertDataset(Dataset):
def __init__(self, df, tokenizer, max_length=100):
super(BertDataset, self).__init__()
self.df=df
self.tokenizer=tokenizer
self.target=self.df['Target']
self.max_length=max_length
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
X = self.df['bodyText'].values[idx]
y = self.target.values[idx]
inputs = self.tokenizer.encode_plus(
X,
pad_to_max_length=True,
add_special_tokens=True,
return_attention_mask=True,
max_length=self.max_length,
)
ids = inputs["input_ids"]
token_type_ids = inputs["token_type_ids"]
mask = inputs["attention_mask"]
x = {
'ids': torch.tensor(ids, dtype=torch.long).to(DEVICE),
'mask': torch.tensor(mask, dtype=torch.long).to(DEVICE),
'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long).to(DEVICE)
}
y = torch.tensor(y, dtype=torch.long).to(DEVICE)
return x, y
编写训练模型的主脚本
首先,让我们定义两个函数来处理训练和评估步骤:
def train(epoch, model, dataloader, loss_fn, optimizer, max_steps=None):
model.train()
total_acc, total_count = 0, 0
log_interval = 50
start_time = time.time()
for idx, (inputs, label) in enumerate(dataloader):
optimizer.zero_grad()
predicted_label = model(inputs)
loss = loss_fn(predicted_label, label)
loss.backward()
optimizer.step()
total_acc += (predicted_label.argmax(1) == label).sum().item()
total_count += label.size(0)
if idx % log_interval == 0:
elapsed = time.time() - start_time
print(
"Epoch {:3d} | {:5d}/{:5d} batches "
"| accuracy {:8.3f} | loss {:8.3f} ({:.3f}s)".format(
epoch, idx, len(dataloader), total_acc / total_count, loss.item(), elapsed
)
)
total_acc, total_count = 0, 0
start_time = time.time()
if max_steps is not None:
if idx == max_steps:
return {'loss': loss.item(), 'acc': total_acc / total_count}
return {'loss': loss.item(), 'acc': total_acc / total_count}
def evaluate(model, dataloader, loss_fn):
model.eval()
total_acc, total_count = 0, 0
with torch.no_grad():
for idx, (inputs, label) in enumerate(dataloader):
predicted_label = model(inputs)
loss = loss_fn(predicted_label, label)
total_acc += (predicted_label.argmax(1) == label).sum().item()
total_count += label.size(0)
return {'loss': loss.item(), 'acc': total_acc / total_count}
我们离让主脚本运行起来越来越近了。让我们将各个部分拼接在一起。我们有:
-
一个
BertDataset类,用于处理数据的加载 -
一个
SentimentBERT模型,它基于我们的 Tiny-BERT 模型并添加了一个额外的层来适应我们的自定义用例 -
train()和eval()函数,用于处理这些步骤 -
一个
train_and_eval()函数,将所有内容组合在一起
我们将使用argparse来使我们能够通过参数启动脚本。这些参数通常是训练/评估/测试文件,用于将数据集传递给我们的模型,模型存储路径以及与训练相关的参数。
import pandas as pd
import time
import torch.nn as nn
import torch
import logging
import numpy as np
import argparse
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
logging.basicConfig(format='%(asctime)s [%(levelname)s]: %(message)s', level=logging.DEBUG)
logging.getLogger().setLevel(logging.INFO)
# --- CONSTANTS ---
BERT_MODEL_NAME = 'small_bert/bert_en_uncased_L-2_H-128_A-2'
if torch.cuda.is_available():
logging.info(f"GPU: {torch.cuda.get_device_name(0)} is available.")
DEVICE = torch.device('cuda')
else:
logging.info("No GPU available. Training will run on CPU.")
DEVICE = torch.device('cpu')
# --- Data preparation and tokenization ---
class BertDataset(Dataset):
def __init__(self, df, tokenizer, max_length=100):
super(BertDataset, self).__init__()
self.df=df
self.tokenizer=tokenizer
self.target=self.df['Target']
self.max_length=max_length
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
X = self.df['bodyText'].values[idx]
y = self.target.values[idx]
inputs = self.tokenizer.encode_plus(
X,
pad_to_max_length=True,
add_special_tokens=True,
return_attention_mask=True,
max_length=self.max_length,
)
ids = inputs["input_ids"]
token_type_ids = inputs["token_type_ids"]
mask = inputs["attention_mask"]
x = {
'ids': torch.tensor(ids, dtype=torch.long).to(DEVICE),
'mask': torch.tensor(mask, dtype=torch.long).to(DEVICE),
'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long).to(DEVICE)
}
y = torch.tensor(y, dtype=torch.long).to(DEVICE)
return x, y
# --- Model definition ---
class SentimentBERT(nn.Module):
def __init__(self, bert_model):
super().__init__()
self.bert_module = bert_model
self.dropout = nn.Dropout(0.1)
self.final = nn.Linear(in_features=128, out_features=3, bias=True)
def forward(self, inputs):
ids, mask, token_type_ids = inputs['ids'], inputs['mask'], inputs['token_type_ids']
x = self.bert_module(ids, mask, token_type_ids)
x = self.dropout(x['pooler_output'])
out = self.final(x)
return out
# --- Training loop ---
def train(epoch, model, dataloader, loss_fn, optimizer, max_steps=None):
model.train()
total_acc, total_count = 0, 0
log_interval = 50
start_time = time.time()
for idx, (inputs, label) in enumerate(dataloader):
optimizer.zero_grad()
predicted_label = model(inputs)
loss = loss_fn(predicted_label, label)
loss.backward()
optimizer.step()
total_acc += (predicted_label.argmax(1) == label).sum().item()
total_count += label.size(0)
if idx % log_interval == 0:
elapsed = time.time() - start_time
print(
"Epoch {:3d} | {:5d}/{:5d} batches "
"| accuracy {:8.3f} | loss {:8.3f} ({:.3f}s)".format(
epoch, idx, len(dataloader), total_acc / total_count, loss.item(), elapsed
)
)
total_acc, total_count = 0, 0
start_time = time.time()
if max_steps is not None:
if idx == max_steps:
return {'loss': loss.item(), 'acc': total_acc / total_count}
return {'loss': loss.item(), 'acc': total_acc / total_count}
# --- Validation loop ---
def evaluate(model, dataloader, loss_fn):
model.eval()
total_acc, total_count = 0, 0
with torch.no_grad():
for idx, (inputs, label) in enumerate(dataloader):
predicted_label = model(inputs)
loss = loss_fn(predicted_label, label)
total_acc += (predicted_label.argmax(1) == label).sum().item()
total_count += label.size(0)
return {'loss': loss.item(), 'acc': total_acc / total_count}
# --- Main function ---
def train_and_evaluate(**params):
logging.info("running with the following params :")
logging.info(params)
# Load pretrained tokenizers and bert model
# update the paths to whichever you are using
tokenizer = BertTokenizer.from_pretrained('models/bert_uncased_L-2_H-128_A-2/vocab.txt')
model = BertModel.from_pretrained('models/bert_uncased_L-2_H-128_A-2')
# Training parameters
epochs = int(params.get('epochs'))
batch_size = int(params.get('batch_size'))
learning_rate = float(params.get('learning_rate'))
# Load the data
df_train = pd.read_csv(params.get('training_file'))
df_eval = pd.read_csv(params.get('validation_file'))
df_test = pd.read_csv(params.get('testing_file'))
# Create dataloaders
train_ds = BertDataset(df_train, tokenizer, max_length=100)
train_loader = DataLoader(dataset=train_ds,batch_size=batch_size, shuffle=True)
eval_ds = BertDataset(df_eval, tokenizer, max_length=100)
eval_loader = DataLoader(dataset=eval_ds,batch_size=batch_size)
test_ds = BertDataset(df_test, tokenizer, max_length=100)
test_loader = DataLoader(dataset=test_ds,batch_size=batch_size)
# Create the model
classifier = SentimentBERT(bert_model=model).to(DEVICE)
total_parameters = sum([np.prod(p.size()) for p in classifier.parameters()])
model_parameters = filter(lambda p: p.requires_grad, classifier.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
logging.info(f"Total params : {total_parameters} - Trainable : {params} ({params/total_parameters*100}% of total)")
# Optimizer and loss functions
optimizer = torch.optim.Adam([p for p in classifier.parameters() if p.requires_grad], learning_rate)
loss_fn = nn.CrossEntropyLoss()
# If dry run we only
logging.info(f'Training model with {BERT_MODEL_NAME}')
if args.dry_run:
logging.info("Dry run mode")
epochs = 1
steps_per_epoch = 1
else:
steps_per_epoch = None
# Action !
for epoch in range(1, epochs + 1):
epoch_start_time = time.time()
train_metrics = train(epoch, classifier, train_loader, loss_fn=loss_fn, optimizer=optimizer, max_steps=steps_per_epoch)
eval_metrics = evaluate(classifier, eval_loader, loss_fn=loss_fn)
print("-" * 59)
print(
"End of epoch {:3d} - time: {:5.2f}s - loss: {:.4f} - accuracy: {:.4f} - valid_loss: {:.4f} - valid accuracy {:.4f} ".format(
epoch, time.time() - epoch_start_time, train_metrics['loss'], train_metrics['acc'], eval_metrics['loss'], eval_metrics['acc']
)
)
print("-" * 59)
if args.dry_run:
# If dry run, we do not run the evaluation
return None
test_metrics = evaluate(classifier, test_loader, loss_fn=loss_fn)
metrics = {
'train': train_metrics,
'val': eval_metrics,
'test': test_metrics,
}
logging.info(metrics)
# save model and architecture to single file
if params.get('job_dir') is None:
logging.warning("No job dir provided, model will not be saved")
else:
logging.info("Saving model to {} ".format(params.get('job_dir')))
torch.save(classifier.state_dict(), params.get('job_dir'))
logging.info("Bye bye")
if __name__ == '__main__':
# Create arguments here
parser = argparse.ArgumentParser()
parser.add_argument('--training-file', required=True, type=str)
parser.add_argument('--validation-file', required=True, type=str)
parser.add_argument('--testing-file', type=str)
parser.add_argument('--job-dir', type=str)
parser.add_argument('--epochs', type=float, default=2)
parser.add_argument('--batch-size', type=float, default=1024)
parser.add_argument('--learning-rate', type=float, default=0.01)
parser.add_argument('--dry-run', action="store_true")
# Parse them
args, _ = parser.parse_known_args()
# Execute training
train_and_evaluate(**vars(args))
这很好,但不幸的是,这个模型需要很长时间才能训练完成。实际上,训练约 470 万参数时,每一步大约需要 3 秒,在一台配备 Intel 芯片、16GB 内存的 MacBook Pro 上进行训练。

每步 3 秒,对于 1238 步和 10 个 epoch 的训练来说,可能是相当长的时间…
没有 GPU,就没有派对。
如何使用 Vertex AI 并启动派对?
简短回答:Docker 和 gcloud。
如果您的笔记本电脑没有强大的 GPU(就像我们大多数人一样),和/或您不想烧坏笔记本电脑的散热风扇,您可能希望将脚本移至 Google Cloud 等云平台(免责声明:我在工作中使用 Google Cloud)。
Google 的一个优点是,当您使用 Gmail 账户创建自己的项目时,会提供 300 美元的信用额度。
一如既往,当需要将代码转移到其他地方时,Docker 通常是首选解决方案。
将脚本 Docker 化
让我们编写一个启用了 GPU 的 Docker 镜像。您可以在官方 Docker 仓库中找到许多 Docker 镜像,我选择了 pytorch/pytorch:2.2.2-cuda11.8-cudnn8-runtime,因为我使用的是 Pytorch 2.2.2 版本。请确保选择一个带有 CUDA 的版本,否则您将需要在 Dockerfile 中自己安装它,相信我,除非必须,否则您不希望这样做。
这个 Dockerfile 会预安装必要的 CUDA 依赖和驱动程序,确保我们能够在自定义训练作业中使用它们,并在调用镜像时,使用您传递的参数运行 Python main.py 文件。
FROM pytorch/pytorch:2.2.2-cuda11.8-cudnn8-runtime
WORKDIR /src
COPY . .
RUN pip install --upgrade pip && pip install -r requirements.txt
ENTRYPOINT ["python", "main.py"]
在 Google Cloud 上构建并推送镜像
一旦我们的镜像准备好构建,我们需要构建它并将其推送到一个注册表。您可以选择任何注册表,但 Google Cloud 提供了一个名为 Artefact Registry 的服务。因此,您将能够非常轻松地将镜像存储在 Google Cloud 上。
在您的目录根目录下创建这个小文件,并确保 Dockerfile 文件与其在同一级别:
# build.sh
export PROJECT_ID=<your-project-id>
export IMAGE_REPO_NAME=pt_bert_sentiment
export IMAGE_TAG=dev
export IMAGE_URI=eu.gcr.io/$PROJECT_ID/$IMAGE_REPO_NAME:$IMAGE_TAG
gcloud builds submit --tag $IMAGE_URI .
运行 build.sh 文件,等待几分钟镜像构建完成后,您应该会看到类似的内容:
eu.gcr.io/
在 Vertex AI 上创建作业
一旦您的镜像构建并推送到 Artefact Registry,我们就可以告诉 Vertex AI 在我们想要的任何机器上运行这个镜像,包括带有强大 GPU 的机器!Google 在您创建自己的 GCP 项目时会提供 300 美元的信用额度,这对于运行我们的模型是非常足够的。
费用详情请见 这里。在我们的案例中,我们将选择每小时 0.24 美元的 n1-standard-4 机器,并附加每小时 0.40 美元的 NVIDIA T4 GPU。

(来源:Google Cloud)

(来源:Google Cloud)
创建一个 job.sh 文件,内容如下,指定您所在的地区和使用的机器类型。如果您在其他地区,费用可能会有所不同,请参考上面的链接。
您还需要传递参数给您的训练脚本。
gcloud ai custom-jobs create的语法由两部分组成:
与作业本身相关的参数:
--region、--display-name、--worker-pool-spec、--service-account和--args与训练相关的参数:
--training-file、--epochs等。后者需要以
--args开头,表示所有后续的参数都与训练 Python 脚本相关。例如:假设我们的脚本有 2 个参数 x 和 y,我们可以这样写:
--args=x=1,y=2
# job.sh
export PROJECT_ID=<your-project-id>
export BUCKET=<your-bucket-id>
export REGION="europe-west4"
export SERVICE_ACCOUNT=<your-service-account>
export JOB_NAME="pytorch_bert_training"
export MACHINE_TYPE="n1-standard-4" # We can specify GPUs here
export ACCELERATOR_TYPE="NVIDIA_TESLA_T4"
export IMAGE_URI="eu.gcr.io/$PROJECT_ID/pt_bert_sentiment:dev"
gcloud ai custom-jobs create \
--region=$REGION \
--display-name=$JOB_NAME \
--worker-pool-spec=machine-type=$MACHINE_TYPE,accelerator-type=$ACCELERATOR_TYPE,accelerator-count=1,replica-count=1,container-image-uri=$IMAGE_URI \
--service-account=$SERVICE_ACCOUNT \
--args=\
--training-file=gs://$BUCKET/data/train.csv,\
--validation-file=gs://$BUCKET/data/eval.csv,\
--testing-file=gs://$BUCKET/data/test.csv,\
--job-dir=gs://$BUCKET/model/model.pt,\
--epochs=10,\
--batch-size=128,\
--learning-rate=0.0001
在 Vertex AI 上运行任务
启动脚本并进入你的 GCP 项目,在 Vertex 菜单下的“训练”部分。

(图片来自作者)
启动脚本并进入控制台。你应该会看到任务状态为“等待中”,然后是“训练中”。
为了确保正在使用 GPU,你可以检查任务及其资源:

(图片来自作者)
这表明我们正在使用 GPU 进行训练,因此现在应该能够显著加速!让我们来看一下日志:

运行 1 个周期不到 10 分钟,而在 CPU 上需要 1 小时/周期!我们已经将训练任务转移到 Vertex,并加速了训练过程。我们可以选择启动其他配置不同的任务,而不会超载我们笔记本的能力。
那么模型的最终准确率如何呢?在经过 10 个周期后,它大约为 94% 到 95%。我们可以让它继续运行更长时间,看看分数是否提高(我们还可以添加早停回调来避免过拟合)

我们的模型表现如何?

(图片来自作者)
到派对时间了!
不让任何标签被遗漏:层次化类别的替代编码方法
寻找一个适用于当前和未来编码的系统
·发布于 Towards Data Science ·阅读时间 15 分钟·2024 年 5 月 17 日
--

图片来源:Gabriel Tenan 通过 Unsplash
作为数据科学家的我,经常会遇到很多标签。数据中包含了邮政编码标签、性别标签、医疗诊断标签、职位标签、股票代码标签,等等。标签可以是简单的(如 S、M、L 尺码的衬衫)或复杂的(国际疾病分类系统编码了超过 70,000 种医疗状况,这些编码有时可能非常具体,令人忍俊不禁).
当它们出现在数据中时,我们称这些标签为类别特征。具有“很多”可能值的类别特征被称为高基数类别特征。高基数类别特征在机器学习模型中使用时会带来困难。大量的维度使得它们无法直接使用,或使用起来不切实际(例如“维度灾难”)。因此,采用了各种编码方案来简化这些特征。
低频或未见过的编码也为高基数类别特征带来了挑战。例如,一些邮政编码的区域人口稀少,而其他邮政编码区域则拥有数百万居民;我们对某些编码的信心更高。此外,常见的编码集,如医疗诊断编码,定期更新,导致在训练时没有可用的未见过值。未见过和低频…
不,你不需要一个新的微服务架构
因为你几乎肯定已经拥有一个,只是没有明确意识到罢了
·发表于 Towards Data Science ·阅读时间:6 分钟 ·2024 年 10 月 29 日
--

企业架构作为一个混乱的系统——由 DALL-E 生成
如果你觉得 AI 生成的文章-图像确实很好地捕捉了你公司系统架构的本质,那么这篇文章就是为你准备的。
毫无疑问,将复杂任务分解为更小、更易管理的子任务,对于任何类型的问题解决都是有帮助的。这一点对于将我们的业务流程数字化的 IT 系统也同样适用。因此,架构师们遵循了 IT 领域中经过验证的“分而治之”的路径,将我们的系统拆分成多个小型应用程序/服务,分别为不同的业务领域执行特定的任务。
随着我们企业复杂性的增长,代表数字化业务流程的互联应用程序/服务的系统也变得过于复杂。因此,我们不断尝试保持秩序和结构,以防整个系统崩溃并彻底停止工作——这实际上就是企业架构,如果你曾经好奇那些坐在象牙塔里的人试图通过他们的架构组件图实现什么目标的话。
企业架构
关于如何通过合适的架构来驯服或防止混乱,已有大量的文献。
2024 年诺贝尔奖:人工智能突破大奖
人工智能诺贝尔辩论后的教训
·发表在Towards Data Science·阅读 7 分钟·2024 年 12 月 10 日
--

人工智能生成的图像。
自 1901 年颁发第一届诺贝尔奖以来,年底时期已成为了解各个领域杰出个人及其贡献的激动人心时刻。
今年的诺贝尔奖季节格外引人注目 —— 也有些争议 —— 因为在物理学和化学类别中对人工智能进展的特别认可。
今年的奖项突显了人工智能的巨大潜力,并提出了关于在计算方法重新定义传统领域的时代科学学科性质的紧迫问题。
在这篇文章中,我们旨在探讨人工智能在 2024 年诺贝尔奖中的角色,讨论事情已经平息后的争议,并邀请您分享您对此事的看法!
人工智能是否会成为未来诺贝尔奖的一个长久存在?
2024 年诺贝尔物理学奖

尽管计算机不能像人类那样“思考”,但计算机算法现在可以模仿类似于人类的功能,如记忆和学习。
没有人能把 AI 逼进角落!
关于转型的两个简短故事,以及如果你想成为“AI 启用型”公司,应该做些什么
·发布于 Towards Data Science ·阅读时间 7 分钟·2024 年 11 月 13 日
--

由 ChatGTP 生成
我与许多产品公司交流时,他们都很难理解“转型为 AI”对他们来说意味着什么。在这篇文章中,我分享了成为 AI 启用型企业的意义,以及你可以做些什么去实现这一目标。不是通过列举你必须做的事情,而是通过两个故事。第一个故事讲的是数字化——一个非数字化公司转型为数字化公司的意义。因为转型为 AI 遵循相同的路径;这是一种“相似但不同”的转型。第二个故事讲的是为什么过去几年许多产品公司在 AI 和数据科学投资中失败,因为他们把 AI 置于了角落里。
但在我们深入之前,请记住,成为 AI 启用型公司是一场转型,或者说是一段旅程。而要顺利开始这段旅程并成功到达目的地,最好是你清楚知道自己要去哪里。那么:什么是“AI 启用型”公司?
成为 AI 启用型公司意味着能够利用 AI 技术把握一个机会,或者获得你原本无法获得的竞争优势。
那么,在完成转型之后,你怎么知道自己是否成功了呢?你可以问自己这个问题:
我们现在能做什么,之前做不到的?我们现在能利用一个以前无法利用的机会吗?
或者更直接地说:我们现在是否能够利用一个以前无法利用的机会?
这个问题与人工智能并无特别关系。它适用于任何组织为获得新能力而进行的转型。因此,如果你希望转型到人工智能领域,也有很多可以从其他转型中学习的经验。
轶事 1:数字化的故事

由 ChatGPT 生成
在过去的几十年里,一些大型企业经历了一个巨大的转变,称为数字化。这是一个过程,企业从将信息技术作为日常工作的工具,转变为将信息技术作为战略资产,以获得竞争优势。几年前,我在石油和天然气行业待过一段时间,参与了大规模的数字化工作。如果你没有在石油和天然气行业工作过,你可能会惊讶地发现,这个庞大的行业在很大程度上仍然没有实现数字化。当然,这个行业自计算机问世以来就一直在使用计算机,但那时它们只是工具:CAD 工具用于设计,物流系统用于项目和生产计划,CRM 系统用于管理员工和客户,等等。然而,一个公司相对于另一个公司的竞争力,主要体现在员工对钢铁、管道和机械的知识上,了解流体如何通过管道流动,如何在恶劣环境下安装重型设备,以及这个行业中的许多其他事情。计算机一直被看作是完成工作的工具,信息技术也一直被视为需要尽量减少的开销。数字化正是旨在改变这种思维方式的转型。
为了使信息技术在竞争中发挥杠杆作用,企业必须从把信息技术视为开销的思维转变为把信息技术视为投资机会。通过投资自己的信息技术,你可以创造出竞争对手没有的工具和产品,从而为自己赢得竞争优势。
但是,投资于内部软件开发是昂贵的,因此,为了确定正确的投资以将竞争转向自己一方,你需要所有工程师、钢铁和机械专家开始思考,你能通过计算机以一种服务于这个目标的方式解决哪些问题和挑战。这是因为,关于如何改善你的产品和服务的知识,掌握在员工的脑海中:与客户沟通的销售人员,触及市场趋势的营销人员,设计和制造资产的产品人员,以及设计、制造和测试最终产品的工程师。这些人必须内化使用计算机技术来整体改善业务的理念,并付诸实践。这就是数字化的目标。
但是你已经知道这些了,对吧?那为什么还要重复呢?
因为向 AI 转型的故事其实和数字化转型是完全相同的;你只需要将“数字化转型”替换为“向 AI 转型”。因此,从数字化项目中有很多东西可以学习。如果你够幸运,你可能已经理解了什么是数字化公司,那么你实际上也知道数字化转型意味着什么。
轶事二:数据科学的三个时代

由 ChatGPT 生成
工业 AI 和数据科学的历史较短,起始于 2010-2012 年。虽然从这段历史中可以学到一些东西,但我立刻要说:目前还没有什么灵丹妙药能让 AI 的转型一蹴而就。但作为一个行业,我们在逐步进步。我把这段历史分为三个不同的时代,根据各家公司在首次启动 AI 项目时的处理方式来划分。
在第一阶段,想要使用 AI 和机器学习的公司大量投资于大型数据基础设施,并雇佣了一大批数据科学家,将他们集中在一个房间里,期待奇迹发生。但什么也没发生,基础设施和人力成本极高,因此这种方法很快被放弃。这个思路受到 Twitter、Facebook、Netflix 和 Google 等成功案例的启发,但这些企业的规模并不适用于大多数公司。经验教训。
在第二阶段,借鉴了第一阶段的经验,AI 顾问们建议应该从识别自己领域中的“杀手级”AI 应用开始,组建一支小型的数据科学家团队,做一个最小可行产品(MVP),并在此基础上不断迭代。这将为你提供一个高价值的项目和示范案例,让你能向整个公司展示 AI 的辉煌。大家都会惊叹不已,看到希望,AI 转型就会完成。因此,公司雇佣了一个小团队的数据科学家,把他们安排在一个角落里,期待奇迹的出现。但什么也没发生。
之所以在这种环境下没有奇迹发生,是因为被聘用来帮助转型的数据科学家和 AI/ML 专家并不了解业务。他们既不了解你,也不了解你客户的痛点。他们不知道业务领域的希望、梦想和目标。而且,更重要的是,了解这些的人——你们组织中的产品人员、管理者和工程师——他们并不懂数据科学家,也不了解 AI,更不清楚 AI 能用来做什么。他们也不理解数据科学家在说什么。在这些群体还没学会互相“交流”之前,魔法是不可能发生的。因为,在那之前,AI 转型根本没有发生。
这就是为什么当你检查自己是否已经完成转型时,重要的不是问自己“你能做什么”,而是问“你将做什么”。AI 团队能帮助应用 AI 来抓住机会,但除非他们知道该做什么,否则这一切不会发生。
这是一项沟通的工作。就是让合适的人互相交流。但跨越这些边界的沟通是具有挑战性的,这也导致了我们现在所处的局面:
第三个时代——虽然目前仍然没有银弹,当前的建议如下:
-
找到一位在 AI 和机器学习方面有经验的人。这是一个专业领域,你需要具备相关能力。除非你拥有卓越的人才,否则不要试图一夜之间将其他领域的专家转型为数据科学家。从零开始建立一个团队需要时间,并且他们在开始时没有经验。如果需要,毫不犹豫地去外部寻找帮助你入门的人。
-
将数据科学家与领域专家和产品开发团队对接,让他们一起提出你业务中的第一个 AI 应用。它不一定是杀手级应用——只要能找到任何可能有用的应用,就足够了。
-
继续开发解决方案,并展示给组织中的其他人。
这项工作的重点不是要命中靶心,而是要提出一个全公司都能识别、理解并评论的可行 AI 示例。如果领域专家和产品团队的人出来说:“但你解决了错误的问题!你应该做的是……”那么你可以把它视为一次胜利。到那时,你已经让关键资源开始对话,合作找出新的、更好的解决方案,解决你已经设定的那些问题。
在我担任数据科学家期间,“角落里的数据科学家”陷阱是导致团队或组织在初始人工智能(AI)项目中失败的主要原因之一。如果 AI 资源没有与产品团队密切互动,那应该被视为注定失败。你需要让产品团队推动 AI 项目——这是确保 AI 解决方案能有效解决正确问题的关键。
总结
-
成为一个 AI 驱动的产品组织的转型,建立在数字化驱动的基础上,并遵循类似的路径:成功的关键是与领域专家和产品团队合作,让他们了解并应用 AI 带来的扩展问题解决能力。
-
AI 和机器学习是一个复杂的专业领域,你需要有熟练的专家。之后,关键是将这些资源与领域专家和产品团队紧密结合,以便他们能够开始共同解决问题。
另外:不要把 AI 放在角落里!

转型过程。图示由作者与 ChatGPT 和 GIMP 合作完成。
非线性:线性回归能与梯度提升竞争吗?
线性模型可以通过数据预处理来处理非线性关系。但它们能接近更复杂的模型吗?
·发表于 Towards Data Science ·阅读时间 8 分钟·2024 年 10 月 11 日
--

[图片来源:作者]
几周前,我在 LinkedIn 上发布了一篇帖子。
这篇帖子是基于以下图表,比较了两种模型的预测结果:线性回归和 CatBoost。

[图片来源:作者]
这篇帖子的核心观点是,像 CatBoost 这样的梯度提升模型似乎能提供更“合理”的预测变量与目标变量之间的关系解释(即房屋状况与房价之间的关系)。
事实上,许多自然界中的关系是非线性的。
该帖子收到了若干反对意见,其中以下评论因获得大量点赞而特别引人注目:

这篇 LinkedIn 帖子的评论区。
这引发了一场讨论,我发现了以下评论(由同一作者写的…)
非响应偏差:沉默的大多数如何决定了一场选举,并让一个深受喜爱的出版商陷入困境

这是一个统计偏差的介绍,它通过完全缺乏数据来显现其残酷的存在。
·发表于 Towards Data Science ·阅读时长 26 分钟·2024 年 8 月 27 日
--
在一个阴沉的十一月傍晚,富兰克林·德拉诺·罗斯福总统在距白宫约 300 英里的春伍德坐定,准备收听收音机。

从空中俯瞰春伍德,位于纽约州海德公园,面朝西南(1932 年 6 月 6 日)。来源:富兰克林·D·罗斯福总统图书馆与博物馆(许可:公有领域照片)
坐落在哈德逊河东岸起伏的树林山丘之间,春伍德是富兰克林·罗斯福的终生家园。它是总统的世界中心。一个熟悉而舒适的地方,他在其定义时代的十二年总统任期内一次次归来。¹

1941 年 7 月,西春伍德客厅的景象。来源:美国国会图书馆/维基共享资源。 (公有领域照片)
在那个十一月的晚上,罗斯福和他的家人因为一个特殊的原因坐在收音机前。
这是 1936 年 11 月 3 日的晚上。
归一化折扣累积增益(NDCG)——终极排名度量标准
NDCG——用于评估推荐系统的排名感知度量标准
·发表于 Towards Data Science ·阅读时长 10 分钟·2024 年 10 月 15 日
--
推荐系统无处不在。既然你正在阅读这篇文章,很有可能是 Medium 向你推荐的。这篇文章将探讨 NDCG——归一化折扣累积增益,作为评估任何推荐系统模型的排名感知度量标准。

图像由 Gemini 生成的 AI 生成
什么是推荐系统?
推荐系统帮助用户根据他们的偏好或行为发现相关的项目,如产品、个人资料、帖子、视频、广告或信息。这些平台处理数百万个项目,展示最相关的内容是提高用户参与度和商业指标的关键。亚马逊、LinkedIn、Twitter、Instagram、Reddit、Spotify、YouTube、Netflix、Medium 和 Quora 等公司在他们的应用中使用推荐系统。
这些系统通常是两阶段系统,包含一个检索模型和一个排名模型。检索模型根据相似度度量,从数百万个项目中筛选出最相关的项目,并将其传递给排名模型。排名模型对这些项目进行更精细的排序。
不是所有 HNSW 索引都一样
克服主要的 HNSW 挑战,提升你的 AI 生产工作负载效率
·发表于Towards Data Science ·7 分钟阅读·2024 年 7 月 3 日
--

图片由Talha Riaz提供,来源于Pexels
层次可导航小世界(HNSW)算法因其在大规模数据搜索中的高效性和准确性而广受欢迎,是搜索任务和像 RAG 这样的 AI/LLM 应用的常见选择。然而,设置和维护 HNSW 索引也伴随着一系列挑战。让我们一起探讨这些挑战,提供一些解决方案,甚至看看如何通过解决其中一个问题来一举两得。
内存消耗
由于其嵌入的层次结构,HNSW 的一个主要挑战是其高内存使用量。但很少有人意识到,内存问题不仅仅局限于存储初始索引所需的内存。这是因为,当 HNSW 索引被修改时,存储节点及其连接所需的内存会进一步增加。稍后的部分将对此进行更详细的解释。内存意识至关重要,因为数据需要的内存越多,计算(搜索)的时间就越长,维护工作负载的成本也会越高。
构建时间

图片由Andrea De Santis提供,来源于Unsplash
在创建索引的过程中,节点会根据它们与图中其他节点的接近程度被添加到图中。对于每个节点,在图的每个层级都会保持一个包含其最近邻居的动态列表。这个过程涉及对列表进行迭代,并执行相似度搜索,以确定节点的邻居是否更接近查询。这一计算密集型的迭代过程显著增加了索引的总体构建时间,负面影响用户体验,并导致云计算使用成本的增加。
参数调优
HNSW 在构建过程中需要预定义的配置参数。优化 HNSW 的这些参数:M(每个节点的连接数)和 ef_construction(用于索引构建过程中最近邻的动态列表大小)对平衡搜索速度、准确性和内存使用至关重要。不正确的参数设置可能导致性能下降,并增加生产成本。微调这些参数对于每个索引都是独特的,并且是一个持续的过程,通常需要频繁重建索引。
重建索引

图片来自 Robin Jonathan Deutsch 在 Unsplash
重建 HNSW 索引是将 HNSW 应用于生产工作负载时最耗资源的环节之一。与传统数据库不同,传统数据库可以通过简单地删除表中的一行来处理数据删除,而在向量数据库中使用 HNSW 往往需要完全重建索引,以保持最佳的性能和准确性。
为什么重建是必要的?
由于其分层图结构,HNSW 本身并不适合处理频繁变化的动态数据集。添加新数据或删除现有数据对于保持数据的最新状态至关重要,尤其是在像 RAG 这样的用例中,RAG 旨在提高搜索相关性。
大多数数据库采用名为“硬删除”和“软删除”的概念。硬删除是永久删除数据,而软删除将数据标记为“待删除”,然后稍后移除。软删除的问题在于,待删除的数据在被永久移除之前仍然占用大量内存。在使用 HNSW 的向量数据库中尤其如此,因为内存消耗本身已经是一个显著的问题。
HNSW 创建了一个图,其中节点(向量)是根据它们在向量空间中的接近度连接的,遍历 HNSW 图就像跳表一样。为了支持这一点,图的层次结构被设计成某些层具有非常少的节点。当向量被删除时,特别是那些位于节点极少的层,这些层在图中作为关键连接点时,整个 HNSW 结构可能会变得碎片化。这种碎片化可能导致某些节点(或层)与主图断开连接,这就需要重建整个图,或者至少会导致搜索效率下降。
HNSW 随后使用软删除技术,这种技术将向量标记为删除,但并不立即移除它们。这种方法减少了频繁完全重建的开销,尽管仍然需要定期重建,以保持图的最优状态。
解决 HNSW 的挑战
那么,我们有哪些方法可以应对这些挑战呢?以下是一些对我有效的方法:
- 向量量化 — 向量量化(VQ)是一个过程,它将来自向量空间ℝ^k 的 k 维向量映射到一个有限的向量集合中,这些向量被称为码字(例如,使用Linde-Buzo-Gray (LBG)算法),这些码字组成了一个码本。每个码字Yi都有一个相关区域,称为Voronoi 区域,它根据与码字的接近度将整个空间ℝ^k 划分为若干区域(见下图)。当输入向量提供时,它会与码本中的每个码字进行比较,以找到最接近的匹配。这是通过识别与输入向量的欧几里得距离最小的码字来完成的。与其传输或存储整个输入向量,不如传输最接近码字的索引(编码)。在检索向量(解码)时,解码器从码本中检索相应的码字。该码字被用作原始输入向量的近似值。重建的向量是原始数据的近似,但由于 VQ 过程的特点,它通常保留了最重要的特征。VQ 是减少索引构建时间和存储 HNSW 图所需内存量的一种流行方法。然而,重要的是要理解,它也会降低搜索结果的准确性。

一个二维向量空间示例(为了简化)。图片由作者提供。
2. 经常重建索引 — 克服 HNSW 扩展内存挑战的一种方法是频繁重建索引,从而清除那些标记为“待删除”的节点,这些节点占用空间并降低搜索速度。考虑在这些时候制作索引的副本,这样你就不会遭受完全的停机时间(然而,这会需要大量内存——这已经是 HNSW 的一个大问题)。
3. 并行索引构建 — 并行构建索引涉及将数据和分配的内存进行分区,并将索引过程分配到多个 CPU 核心上。在这个过程中,所有操作都映射到可用的内存中。例如,系统可能会将数据分割成可管理的块,将每个块分配给不同的处理器核心,并让它们同时构建各自的索引部分。这种并行化方法可以更好地利用系统资源,从而加快索引创建速度,特别是对于大规模数据集。这是一种比传统的单线程构建更快速的索引构建方式;然而,当整个索引无法装入内存,或者当 CPU 核心不足以支持在要求的时间框架内完成工作负载时,仍然会遇到挑战。

并行处理。图片来自作者。
使用自定义构建加速器:一种不同的方法
尽管上述策略可以提供帮助,但它们通常需要相当高的专业知识和开发能力。因此,引入了 GXL,一个新型的付费工具,旨在增强 HNSW 索引的构建。它使用 APU,GSI Technology的计算内存关联处理单元,通过其数百万个比特处理器在内存中进行计算。这种架构能够进行大规模的并行处理,快速计算最近邻距离,从而显著加快大规模动态数据集的索引构建时间。它采用一种自定义算法,结合了向量量化,并通过使用独特的硬件并行化来克服相似性搜索的瓶颈,从而减少整体索引构建时间。
让我们来看一些基准测试数据:

图片来自作者。图片来源:Ron Bar Hen
基准测试比较了 HNSWLIB 和 GXL-HNSW 在不同数据集大小(deep10M、deep50M、deep100M 和 deep500M——这些都是 deep1B 的子集)下的构建时间,使用的参数为 M = 32 和 ef-construction = 100。测试在一台配备 Intel(R) Xeon(R) Gold 5218 CPU @ 2.30GHz 的服务器上进行,使用一个 NUMA 节点(32 个 CPU 核心,380GB 内存,7 个 LEDA-S APU 卡)。
结果清楚地表明,GXL-HNSW 在所有数据集大小上都显著优于 HNSWLIB。例如,GXL-HNSW 在 1 分 35 秒内构建了 deep10M 数据集,而 HNSWLIB 需要 4 分 44 秒,速度提升因子为 3.0。随着数据集大小的增加,GXL-HNSW 的效率进一步提升,deep50M 的速度提升因子为 4.0,deep100M 为 4.3,deep500M 为 4.7。这一持续的改进突显了 GXL-HNSW 在处理大规模数据时的优越表现,使其成为大规模数据集相似性搜索的更高效选择。
总结来说,虽然 HNSW 在向量搜索和 AI 流程中非常有效,但它面临着一些严峻的挑战,如索引构建时间较慢和内存使用量较高,而由于 HNSW 复杂的删除管理,这些问题更加严重。解决这些挑战的策略包括通过频繁重建索引来优化内存使用、对索引实施向量量化以及并行化索引构建。GXL 提供了一种有效结合这些策略的方法。这些方法有助于在依赖 HNSW 的系统中保持准确性和效率。通过减少构建索引所需的时间,索引重建不再是一个耗时的问题,使我们能够一举两得——解决内存扩展问题和长时间索引构建问题。试试看哪种方法最适合你,希望这能帮助你提高整体生产工作负载的性能。
NuCS:一个用于研究、教学和生产应用的约束求解器

照片来自 Eric Prouzet 在 Unsplash
纯 Python 的极速约束求解
·发表于 Towards Data Science ·阅读时间 6 分钟·2024 年 11 月 22 日
--
TLDR
NuCS是一个Python 库,用于求解约束满足和优化问题(CSP 和 COP),是我作为副项目开发的。由于它完全用 Python 编写,NuCS 易于安装,并且可以用几行代码建模复杂问题。NuCS 求解器非常快速,因为它得益于Numpy和Numba的支持。
许多问题都可以形式化为 CSP(约束满足问题)。这也是为什么像 NuCS 这样的约束库对开发者或数据科学家来说非常有帮助。
让我们考虑著名的 N 皇后问题,其要求在一个N x N的棋盘上放置N个皇后,使得它们互不威胁。

8 皇后问题的解法。来源:Yue Guo
14200个12 皇后问题的解法在不到2 秒的时间内,在一台运行以下程序的 MacBook Pro M2 上被找到:
-
Python 3.11,
-
Numpy 2.0.1,
-
Numba 0.60.0 和
-
NuCS 3.0.0.
(venv) ➜ nucs git:(main) time NUMBA_CACHE_DIR=.numba/cache python -m nucs.examples.queens -n 12 --log_level=ERROR --processors=6
{
'ALG_BC_NB': 262006,
'ALG_BC_WITH_SHAVING_NB': 0,
'ALG_SHAVING_NB': 0,
'ALG_SHAVING_CHANGE_NB': 0,
'ALG_SHAVING_NO_CHANGE_NB': 0,
'PROPAGATOR_ENTAILMENT_NB': 0,
'PROPAGATOR_FILTER_NB': 2269965,
'PROPAGATOR_FILTER_NO_CHANGE_NB': 990435,
'PROPAGATOR_INCONSISTENCY_NB': 116806,
'SOLVER_BACKTRACK_NB': 131000,
'SOLVER_CHOICE_NB': 131000,
'SOLVER_CHOICE_DEPTH': 10,
'SOLVER_SOLUTION_NB': 14200
}
NUMBA_CACHE_DIR=.numba/cache python -m nucs.examples.queens -n 12 6.65s user 0.53s system 422% cpu 1.699 total
什么是约束编程?
约束编程是一种求解组合优化问题的范式。在约束编程中,用户通过声明约束条件来明确可行解的限制,这些约束指定了解决方案所需的属性。求解器结合约束传播和回溯算法来寻找解决方案。
作为示例,以下是使用 NuCS 的魔法序列问题模型(找到一个序列x_0, … x_n-1,使得对于每个i在[0, n-1]中,x_i是i在序列中出现的次数):
class MagicSequenceProblem(Problem):
def __init__(self, n: int):
super().__init__([(0, n)] * n)
for i in range(n):
self.add_propagator((list(range(n)) + [i], ALG_COUNT_EQ, [i]))
# redundant constraints
self.add_propagator((list(range(n)), ALG_AFFINE_EQ, [1] * n + [n]))
self.add_propagator((list(range(n)), ALG_AFFINE_EQ, list(range(n)) + [n]))
在 NuCS 中,约束被称为传播器。
传播器(这里是ALG_COUNT_EQ)简单地表明x_i是序列中i出现的次数。以下两个ALG_AFFINE_EQ传播器是冗余的,这意味着它们对于 NuCS 找到解并不是必需的,但它们加速了求解过程。
查看文档以获取 NuCS 支持的传播器完整列表。请注意,NuCS 中的大多数传播器是全局(即n元)并实现了最先进的传播算法。
Python
Python 是数据科学家首选的编程语言:它具有简单的语法,日益壮大的社区以及大量的数据科学和机器学习库。
但另一方面,Python 被认为是一种较慢的语言:根据基准测试,它的速度可能比 C 慢 50 到 100 倍。
选择 Python 来开发高性能的约束编程库并非显而易见,但我们将看到,Numpy(高性能计算包)和 Numba(Python 的即时编译)结合使用,极大地提升了性能。
已经有许多尝试在 Python 中编写约束求解器,但这些要么很慢,要么只是封装器,并依赖于用 Java 或 C/C++编写的外部求解器。
Numpy
NumPy将类似 C 和 Fortran 的语言的计算能力带到了 Python 中。
强大的 N 维数组 NumPy 的向量化、索引和广播概念使其既快速又多功能…
numpy.org](https://numpy.org/?source=post_page-----7b260afc2fe4--------------------------------)
在 NuCS 中,一切都是 Numpy 数组。
这使得可以利用 Numpy 的索引和广播能力,并编写紧凑的传播器,例如Max_i x_i <= y
def compute_domains_max_leq(domains: NDArray, parameters: NDArray) -> int:
x = domains[:-1]
y = domains[-1]
if np.max(x[:, MAX]) <= y[MIN]:
return PROP_ENTAILMENT
y[MIN] = max(y[MIN], np.max(x[:, MIN]))
if y[MIN] > y[MAX]:
return PROP_INCONSISTENCY
for i in range(len(x)):
x[i, MAX] = min(x[i, MAX], y[MAX])
if x[i, MAX] < x[i, MIN]:
return PROP_INCONSISTENCY
return PROP_CONSISTENCY
Numba
Numba 是一个开源的即时编译(Just-In-Time)编译器,它将 Python 和 NumPy 代码的子集转换为快速的机器码。
@njit(parallel=True) def simulator(out): # 并行迭代循环 for i in prange(out.shape[0]): out[i] = run_sim()…
numba.pydata.org](https://numba.pydata.org/?source=post_page-----7b260afc2fe4--------------------------------)
在以下示例中,我们找到了12 皇后问题的 14200 个解(请注意,我们在这里使用的是单处理器)。
NUMBA_DISABLE_JIT=1 python -m nucs.examples.queens -n 12 --log_level=ERROR 179.89s user 0.31s system 99% cpu 3:00.57 total
通过启用即时编译(Just-In-Time compilation),我们实现了x60的加速:
NUMBA_CACHE_DIR=.numba/cache python -m nucs.examples.queens -n 12 3.03s user 0.06s system 99% cpu 3.095 total
为了让 Numba JIT 编译你的代码,你应该:
-
避免面向对象编程(OOP),
-
使用支持的类型或 Numpy 数组,
-
使用 Python 语言的子集,
-
使用 Numpy 函数的子集。
在 NuCS 中,这些指南已经成功地为以下问题实现:
-
传播器(参见
nucs.readthedocs.io/en/latest/reference.html#propagators了解在 NuCS 中实现的传播器列表), -
一致性算法(参见
nucs.readthedocs.io/en/latest/reference.html#consistency-algorithms了解在 NuCS 中实现的一致性算法列表), -
启发式方法(参见
nucs.readthedocs.io/en/latest/reference.html#heuristics了解在 NuCS 中实现的启发式方法列表)。
得益于 Numpy 和 Numba,NuCS 在性能上与用 Java 或 C/C++编写的解算器相当。
请注意,由于 Python 代码是编译的且结果被缓存,当你第二次运行程序时,性能将显著提高。
示例
NuCS 提供了许多模型,用于经典的约束编程问题,如:
-
一些加密算术谜题:Alpha,Donald,
-
平衡不完全区组设计问题,
-
Golomb 标尺问题,
-
背包问题,
-
魔术序列问题,
-
魔方问题,
-
准群问题,
-
n 皇后问题,
-
舒尔引理问题,
-
体育赛事调度问题,
-
数独问题。
其中一些示例需要一些高级技术:
-
冗余约束,
-
自定义启发式方法,
-
自定义一致性算法
大多数这些模型也可以在CSPLib中找到,它是 CSP 相关问题的宝典。
统计与日志记录
在搜索解决方案时,NuCS 还会聚合一些统计数据:
{
'ALG_BC_NB': 262006,
'ALG_BC_WITH_SHAVING_NB': 0,
'ALG_SHAVING_NB': 0,
'ALG_SHAVING_CHANGE_NB': 0,
'ALG_SHAVING_NO_CHANGE_NB': 0,
'PROPAGATOR_ENTAILMENT_NB': 0,
'PROPAGATOR_FILTER_NB': 2269965,
'PROPAGATOR_FILTER_NO_CHANGE_NB': 990435,
'PROPAGATOR_INCONSISTENCY_NB': 116806,
'SOLVER_BACKTRACK_NB': 131000,
'SOLVER_CHOICE_NB': 131000,
'SOLVER_CHOICE_DEPTH': 10,
'SOLVER_SOLUTION_NB': 14200
}
在这里我们可以看到:
-
约束一致性计算了 262006 次,
-
2268895 个传播器被应用,但其中 990435 次无效,同时发现不一致 116806 次,
-
共计 131000 次选择和回溯,最大选择深度为 10,
-
最终,找到了 14200 个解。
通过与模型互动并理解它如何影响统计数据,已被证明是一种非常有用的练习,能够最大化利用 NuCS。
NuCS 还提供了一些基本的日志记录功能。
NUMBA_CACHE_DIR=.numba/cache python -m nucs.examples.golomb -n 10 --symmetry_breaking --log_level=INFO
2024-11-12 17:27:45,110 - INFO - nucs.solvers.solver - Problem has 82 propagators
2024-11-12 17:27:45,110 - INFO - nucs.solvers.solver - Problem has 45 variables
2024-11-12 17:27:45,110 - INFO - nucs.solvers.backtrack_solver - BacktrackSolver uses variable heuristic 0
2024-11-12 17:27:45,110 - INFO - nucs.solvers.backtrack_solver - BacktrackSolver uses domain heuristic 0
2024-11-12 17:27:45,110 - INFO - nucs.solvers.backtrack_solver - BacktrackSolver uses consistency algorithm 2
2024-11-12 17:27:45,110 - INFO - nucs.solvers.backtrack_solver - Choice points stack has a maximal height of 128
2024-11-12 17:27:45,172 - INFO - nucs.solvers.backtrack_solver - Minimizing variable 8
2024-11-12 17:27:45,644 - INFO - nucs.solvers.backtrack_solver - Found a (new) solution: 80
2024-11-12 17:27:45,677 - INFO - nucs.solvers.backtrack_solver - Found a (new) solution: 75
2024-11-12 17:27:45,677 - INFO - nucs.solvers.backtrack_solver - Found a (new) solution: 73
2024-11-12 17:27:45,678 - INFO - nucs.solvers.backtrack_solver - Found a (new) solution: 72
2024-11-12 17:27:45,679 - INFO - nucs.solvers.backtrack_solver - Found a (new) solution: 70
2024-11-12 17:27:45,682 - INFO - nucs.solvers.backtrack_solver - Found a (new) solution: 68
2024-11-12 17:27:45,687 - INFO - nucs.solvers.backtrack_solver - Found a (new) solution: 66
2024-11-12 17:27:45,693 - INFO - nucs.solvers.backtrack_solver - Found a (new) solution: 62
2024-11-12 17:27:45,717 - INFO - nucs.solvers.backtrack_solver - Found a (new) solution: 60
2024-11-12 17:27:45,977 - INFO - nucs.solvers.backtrack_solver - Found a (new) solution: 55
{
'ALG_BC_NB': 22652,
'ALG_BC_WITH_SHAVING_NB': 0,
'ALG_SHAVING_NB': 0,
'ALG_SHAVING_CHANGE_NB': 0,
'ALG_SHAVING_NO_CHANGE_NB': 0,
'PROPAGATOR_ENTAILMENT_NB': 107911,
'PROPAGATOR_FILTER_NB': 2813035,
'PROPAGATOR_FILTER_NO_CHANGE_NB': 1745836,
'PROPAGATOR_INCONSISTENCY_NB': 11289,
'SOLVER_BACKTRACK_NB': 11288,
'SOLVER_CHOICE_NB': 11353,
'SOLVER_CHOICE_DEPTH': 9,
'SOLVER_SOLUTION_NB': 10
}
[ 1 6 10 23 26 34 41 53 55]
定制化
最后,NuCS 是一个非常开放的平台,几乎可以定制任何内容:
-
传播器,
-
一致性算法,
-
启发式方法,
-
解算器。
在以下Golomb 标尺示例中,在使用之前注册了一个自定义一致性算法:
consistency_alg_golomb = register_consistency_algorithm(golomb_consistency_algorithm)
solver = BacktrackSolver(problem, consistency_alg_idx=consistency_alg_golomb)
结论
总结来说,NuCS 是一个功能丰富的约束解算器库。尽管它完全用 Python 编写,但性能非常快,可以应用于广泛的领域:研究、教学和生产。
如果你想参与 NuCS 的开发,欢迎随时在 Github 上联系我!
一些有用的链接,供进一步了解:
-
Pip 包:
pypi.org/project/NUCS/
如果你喜欢这篇关于 NuCS 的文章,请拍手50 次!
Numpy 的随机选择在 Go 语言中的实现
·发表于Towards Data Science ·阅读时间 5 分钟·2024 年 1 月 23 日
--

由 ChatGPT 生成
最近我帮助实现了一些 Java 逻辑,这些逻辑本可以通过简单调用 Numpy 的random.choice来实现。这最终成了一个任务,让我有机会去深入了解那些你每天都在使用,但从未真正有时间完全理解其工作原理的东西。同时,我也有一段时间想开始学习 Go,那么为什么不一举两得,再次用 Go 实现 random.choice 呢?
random.choice 允许我们根据指定的概率从提供的集合中采样 N 个元素。重要的是(对于激励这项工作的使用案例),它允许我们进行无重复的采样。也就是说,如果集合中的一个元素已经被采样,它将不会再次被采样。例如,如果我们有一个集合 [A, B, C],其对应的概率为 [0.1, 0.7, 0.2],并且我们想无重复地采样 3 个元素,那么大多数时候我们会得到 [B, C, A] 作为输出。如果我们进行有重复的采样,预期的输出将是 [B, B, B]。
首先,让我们定义 Go 函数的签名。我们希望它尽可能与 Numpy 的对应函数保持一致。
func ChoiceT any ([]T, error) {}
关于函数签名,有几点需要注意:
-
我们使用了泛型类型 T。这允许我们对不同类型的数组调用该函数(只要它满足类型约束,而在我们这个例子中并没有类型约束)。这应该模仿 Python 中
random.choice的语义,也就是说,它不关心输入数组中元素的类型。 -
我们还传递了一个指向随机数生成器(rng)对象的指针,供我们在抽样时使用。我从 Jax 中学到了这种定义随机函数的风格(与直接访问全局 rng 实例相比)。根据我的经验,这简化了测试和可重复性。
-
这个函数有两个返回值,一个是样本数组,另一个是
error类型。这是 Go 语言中处理“异常”执行流的方式(Go 没有断言或异常)。
现在我们需要弄清楚如何使用仅在[0, 1]之间均匀抽样的浮动随机数(由 rng 返回)从由probs参数定义的离散概率分布中抽取元素。幸运的是,有一种方法可以完美解决这个问题。
CDF 反演法
首先,CDF 代表累积分布函数。在离散情况下,它可以表示为一个数组,其中索引i处的元素等于所有输入概率在位置i及之前的累计和。让我们用一个简单的帮助函数来实现这个公式。
func CDF(probs []float64) []float64 {
cdf := make([]float64, len(probs))
cum := 0.0
for i := range cdf {
cum += probs[i]
cdf[i] = cum
}
return cdf
}
通过离散概率分布的 CDF 和随机数生成器,我们可以通过以下方式从输入集合中抽取元素:
-
从统一分布中随机抽取一个介于[0, 1]之间的浮动数值。
-
找到第一个 CDF 值大于等于随机浮动数的索引。
-
返回原集合中该索引位置的元素。
为了理解它为何有效,我们可以进行一个简单的视觉实验。我们可以把 CDF 数组中的值看作是放置在区间[0, 1]上每个箱子的右边界。每个箱子的宽度与输入的概率成正比。当生成介于[0, 1]之间的均匀随机浮动数时,我们可以把它看作是将球随机投掷到这个区间,并选择我们撞到的箱子。撞到箱子的概率正好与输入概率成正比(这正是我们需要的)。下面是我们最后一个例子——集合[A, B, C]及其相关概率[0.1, 0.7, 0.2]的视觉示范。

由作者在 Excalidraw 中创建
要获得箱子的索引,我们可以返回第一个右边界大于或等于抽样值的索引。同样,这里有一个简单的帮助函数来实现这个功能:
func FindIndexFromRight(val float64, cdf []float64) int {
for i, cumProb := range cdf {
if cumProb >= val {
return i
}
}
return len(cdf) - 1
}
将所有内容整合在一起
这样,我们就具备了实现带有重复的random.choice所需的一切。无重复抽样则需要多一个技巧。为了确保我们不会抽取已经被抽取过的元素,可以在抽样后将其概率置为 0。然而,这样会使我们的离散概率分布失效,因为其总和将不再等于 1。为了修正这一点,我们需要通过将概率除以新的总和来重新归一化概率。作为额外的优化,我们可以直接对 CDF 进行重新归一化,而不需要先重新归一化输入概率再计算 CDF。将所有内容整合在一起:
func ChoiceT any ([]T, error) {
if !replace && (size > len(arr)) {
return nil, errors.New("cannot sample more than array size without replacements")
}
samples := make([]T, size)
probsCopy := make([]float64, len(probs))
copy(probsCopy, probs)
for i := 0; i < size; i++ {
cdf := CDF(probsCopy)
if !replace {
total := cdf[len(cdf)-1]
for cdfInd := range cdf {
cdf[cdfInd] /= total
}
}
randFloat := rng.Float64()
sampledIndex := FindIndexFromRight(randFloat, cdf)
samples[i] = arr[sampledIndex]
if !replace {
probsCopy[sampledIndex] = 0.0
}
}
return samples, nil
}
再次提醒几个需要注意的事项:
-
如果我们进行不放回抽样,就无法抽取超过输入集合初始大小的样本。
-
Go 传递“切片”(没有定义大小的数组)作为可变参数。因此,我们复制输入的概率,以避免用掩码修改原始数组。
驱动代码如下:
rngSource := rand.NewSource(time.Now().UnixNano())
rng := rand.New(rngSource)
arr := []string{"A", "B", "C", "D"}
probs := []float64{0.1, 0.6, 0.2, 0.1}
samples, err := Choice(arr, 3, false, probs, rng)
if err != nil {
log.Fatal(err)
}
fmt.Println("Samples: ", samples)
就是这样,随时可以在评论区提问。
自定义物体检测:探索 YOLO 基本原理并在自定义数据上训练
利用预训练模型、增强图像与边界框,揭示卷积神经网络在物体检测中的强大力量
·发布于 Towards Data Science ·15 分钟阅读·2024 年 1 月 8 日
--
深度学习在过去十年中取得了巨大的进展,尽管早期的模型难以理解和应用,但现代框架和工具使得每个拥有一定代码理解的人都可以为计算机视觉任务训练自己的神经网络。
在本文中,我将全面展示如何加载和增强数据以及边界框,训练物体检测算法,并最终查看我们在测试图像中检测物体的准确性。尽管现有的工具包随着时间的推移变得更易于使用,但仍然存在一些可能遇到的陷阱。
计算机视觉(CV)简介
计算机视觉是一个非常流行且更为广泛的研究与应用领域。深度学习的进展,尤其是过去十年,极大地加速了我们对深度学习的理解以及它广泛的应用潜力。
为什么我们现在看到这些进展?正如 Keras 库的创始人 François Chollet 所描述的那样,我们目睹了 CPU 计算能力的提升,增长了大约 5000 倍,正是由于…
物体检测基础 — 综合初学者指南(第一部分)
在这篇易于理解的多部分初学者指南中,学习这个高级计算机视觉任务——物体检测的基础知识
·发表于 Towards Data Science ·9 分钟阅读 ·2024 年 2 月 5 日
--

图片来自 Javier García 由 Unsplash 提供
如今,配备最新驾驶辅助技术(如车道检测、盲点监测、交通信号识别等)的汽车已经相当普遍。如果我们稍微退后一步,去了解一下幕后发生了什么,作为数据科学家,我们很快会意识到,系统不仅仅是在分类物体,还在实时地定位它们。
这些功能是物体检测系统实际应用的典型示例。驾驶辅助技术、工业机器人和安全系统都利用物体检测模型来检测感兴趣的物体。物体检测是一个高级计算机视觉任务,涉及物体的定位和分类。
在本文中,我们将深入探讨物体检测任务的细节。我们将学习与之相关的各种概念,以帮助我们理解新颖的架构(将在后续文章中讨论)。我们将涵盖理解物体检测模型所需的关键方面和概念,尤其是从迁移学习的角度。
关键概念与构建模块
物体检测包括两个主要子任务,定位和分类。物体的分类是容易理解的。但是,我们如何定义物体的定位呢?让我们了解一些关键概念:
边界框
在目标检测任务中,我们通过一个矩形框来标识给定物体的位置。这个规则的矩形框被称为边界框,用于物体的定位。通常,输入图像的左上角被设为原点或(0,0)。一个矩形边界框通过其左上角和右下角的 x 和 y 坐标来定义。让我们通过图像来直观地理解这一点。图 1(a)展示了一个示例图像,其原点设置在左上角。

图 1:(a)包含不同物体的示例图像,(b)每个物体的边界框,标注了左上角和右下角的顶点,(c)识别边界框的另一种方法是使用左上角坐标及其宽度和高度参数。来源:作者
图 1(b)显示了每个已识别物体及其相应的边界框。需要注意的是,边界框通过其左上角和右下角的坐标进行标注,这些坐标是相对于图像原点的。通过 4 个值,我们可以唯一地识别一个边界框。另一种识别边界框的方法是使用左上角坐标以及其宽度和高度值。图 1(c)展示了这种识别边界框的替代方法。不同的解决方案可能使用不同的方法,这通常是基于个人的偏好。
目标检测模型在每个训练样本中需要每个物体的边界框坐标以及类别标签。同样,在推理阶段,目标检测模型会为每个识别出的物体生成边界框坐标和类别标签。
锚框
每个目标检测模型通过扫描大量可能的区域来识别/定位给定图像中的物体。在训练过程中,模型学习确定哪些扫描到的区域是感兴趣的,并调整这些区域的坐标以匹配真实的边界框。不同的模型可能会以不同的方式生成这些感兴趣区域。然而,最流行和广泛使用的方法是基于锚框。对于给定图像中的每个像素,都会生成多个不同尺寸和宽高比(宽度与高度的比率)的边界框。这些边界框被称为锚框。图 2 展示了给定图像中特定像素的不同锚框。

图 2:给定图像中特定像素(红色标出)对应的不同锚框。来源:作者
锚框的维度通过两个参数进行控制,尺度表示为 s 𝜖 (0,1],宽高比表示为 r >0。正如图 2 所示,对于高度为 h 和宽度为 w 的图像,以及特定的 s 和 r 值,可以生成多个锚框。通常,我们使用以下公式来计算锚框的维度:
wₐ=w.s√r
hₐ = h.s / √r
其中 wₐ 和 hₐ 分别是锚框的宽度和高度。锚框的数量和尺寸要么是预定义的,要么是在训练过程中由模型自行选择的。为了更好地理解,模型在每个像素位置生成多个锚框,并在训练过程中学习调整和匹配这些锚框与真实边界框。
边界框和锚框是理解整个物体检测任务的关键概念。在深入了解这些架构的具体工作方式之前,我们首先要理解评估这些模型性能的方式。以下是一些重要的评估指标:
交并比(IOU)
物体检测模型通常会生成一些锚框,然后根据实际边界框进行调整。但我们如何知道何时匹配发生,或者匹配的效果如何呢?
杰卡德指数 是一种用于确定两个集合之间相似度的度量。在物体检测中,杰卡德指数也被称为交并比(Intersection Over Union,简称 IOU)。它的计算公式为:
IOU = | Bₜ ∩ Bₚ | / | Bₜ ∪ Bₚ |
其中 Bₜ 是真实边界框,Bₚ 是预测边界框。简单来说,它是一个介于 0 和 1 之间的分数,表示预测边界框和真实边界框之间重叠区域与并集区域的面积比。重叠越多,分数越高。接近 1 的分数表示几乎完美的匹配。图 3 展示了预测边界框与真实边界框在样本图像中的不同重叠情况。

图 3:交并比(IOU)是衡量预测边界框与真实边界框匹配程度的指标。重叠越多,得分越高。来源:作者
根据问题陈述和数据集的复杂性,设置不同的 IOU 阈值来确定哪些预测边界框应被认为是有效的。例如,基于 MS-COCO 的物体检测挑战使用 0.5 的 IOU 阈值来将预测边界框视为真正的正样本。
平均精度均值(MAP)
精度和召回率是常用于理解分类器在机器学习上下文中性能的指标。以下公式定义了这些指标:
精度 = TP / (TP + FP)
召回率 = TP / (TP + FN)
其中,TP、FP 和 FN 分别代表真阳性、假阳性和假阴性的结果。精度和召回率通常一起使用来生成精度-召回曲线,以对模型性能进行稳健的量化。这是因为精度和召回率的对立特性,即随着模型的召回率增加,精度开始下降。PR 曲线用于计算F1 分数、曲线下面积 (AUC) 或 平均精度 (AP) 指标。平均精度是通过在不同召回阈值下计算精度的平均值来得到的。图 4(a) 显示了一个典型的 PR 曲线,图 4(b) 展示了如何计算 AP。

图 4:a) 一个典型的 PR 曲线展示了模型在不同召回值下的精度。这是一条向下倾斜的图表,因为精度和召回率的度量是相对的;(b) PR 曲线用于计算聚合/综合分数,如 F1 分数、曲线下面积 (AUC) 和平均精度 (AP);(c) 平均精度均值 (mAP) 是一个稳健的综合指标,用于理解模型在不同阈值下对所有类别的表现。每条彩色曲线表示基于每个类别的特定 IOU 阈值的不同 PR 曲线。来源:作者
图 4(c) 展示了平均精度指标如何扩展到物体检测任务。如图所示,我们在不同的 IOU 阈值下计算 PR 曲线(这是针对每个类别进行的)。然后,我们取所有类别的平均精度值的均值,得到最终的 mAP 指标。这个综合指标是对给定模型在不同类别和阈值下性能的稳健量化。通过将性能缩小到一个可量化的指标,可以轻松地在相同的测试数据集上比较不同模型的表现。
另一个用于评估物体检测模型的指标是每秒帧数 (FPS)。该指标表示模型每秒可以分析多少输入图像或帧以检测物体。这是实时应用场景(如安全视频监控、人脸检测等)中一个重要的指标。
通过掌握这些概念,我们现在准备好理解物体检测的一般框架了。
物体检测框架
物体检测是一个重要且活跃的研究领域。多年来,已经开发并在实际应用中使用了许多不同但有效的架构。物体检测的任务要求所有这些架构解决一系列子任务。在我们深入了解具体模型如何处理这些任务之前,先来理解一下应对物体检测的一般框架。该框架包括以下步骤:
-
区域建议网络
-
定位和分类预测
-
输出优化
现在让我们详细了解一下这些步骤。
区域建议
顾名思义,物体检测框架中的第一步是提出感兴趣区域(ROI)。ROI 是输入图像中,模型认为物体存在的可能性较高的区域。物体存在或不存在的可能性通过一个称为物体性得分的分数来定义。那些物体性得分大于某个阈值的区域会传递到下一阶段,而其他区域则被丢弃。
例如,查看图 5,了解模型提出的不同感兴趣区域(ROI)。需要注意的是,在这一阶段会生成大量的 ROI。基于物体性得分阈值,模型会将 ROI 分类为前景或背景,仅将前景区域传递到下一步进行进一步分析。

图 5:区域提议是物体检测框架中的第一步。感兴趣区域以红色矩形框的形式突出显示。模型将图像中可能性较高的区域(高物体性得分)标记为前景区域,其余标记为背景区域。来源:作者
生成感兴趣区域(ROI)有多种不同的方法。早期的模型通常使用选择性搜索及相关算法来生成 ROI,而新型更复杂的模型则利用深度学习模型来完成这一任务。我们将在接下来的文章中讨论具体架构时进一步探讨这些方法。
定位与分类预测
物体检测模型与我们通常使用的分类模型有所不同。物体检测模型会为每个前景区域生成两个输出,前景区域来自于上一阶段的结果:
-
物体类别:这是典型的分类目标,目的是为每个提议的前景区域分配一个类别标签。通常,会使用预训练的网络从提议区域提取特征,然后利用这些特征来预测类别。像在 ImageNet 或 MS-COCO 上训练的最先进模型,涵盖了大量类别并广泛采用迁移学习。需要注意的是,我们为每个提议区域生成类别标签,而不是像典型的分类任务那样为整个图像生成一个单一标签。
-
边界框坐标:边界框定义为一个包含 4 个值的元组,分别表示 x、y、宽度和高度。在这一阶段,模型会为每个提议的前景区域生成一个元组(同时包括物体类别)。
输出优化
如前所述,物体检测模型在第一步提出了大量的 ROI(感兴趣区域),然后在第二步进行边界框和类别预测。虽然在第一步中对 ROI 进行了某种程度的过滤(基于物体得分区分前景与背景区域),但在第二步中仍然有大量区域用于预测。为如此大量的提议区域生成预测确保了对图像中各种物体的良好覆盖。然而,也有一些区域对同一物体存在大量重叠。例如,看看图 6(a)中为同一个物体预测的 6 个边界框。这可能会导致难以准确统计输入图像中不同物体的数量。

图 6 (a) 物体检测模型为同一物体生成 6 个重叠较多的边界框。 (b) 使用 NMS 优化后的输出。来源:作者
因此,在这个框架中有第三个步骤,涉及输出的优化。这个优化步骤确保每个输入图像中的每个物体只有一个边界框和类别预测。进行这种优化的方法有多种。目前,最流行的方法被称为非极大值抑制(NMS)。顾名思义,NMS 会分析每个物体的所有边界框,找到具有最大概率的那个,并抑制其余的边界框(请参见图 6(b),展示了应用 NMS 后的优化输出)。
这就结束了对一般物体检测框架的高层次理解。我们讨论了定位和分类图像中物体的三个主要步骤。在接下来的文章中,我们将基于这些理解,探讨具体的实现方法及其关键贡献。
目标检测:COCO 和 YOLO 格式,以及它们之间的转换
学习 COCO 和 YOLOv5 格式的结构,以及如何将它们相互转换。
·发表于Towards Data Science ·阅读时间:7 分钟·2024 年 2 月 11 日
--

图片由Matt Briney提供,来源于Unsplash
如果你想在没有 Premium Medium 账号的情况下阅读这篇文章,可以通过以下好友链接访问 😃
www.learnml.wiki/object-detection-coco-and-yolo-formats-and-conversion-between-them/
引言
用于训练目标检测模型的图像标注可以有不同的格式,即使它们包含相同的信息。在现有的不同格式中,有两个非常常用的格式是 COCO JSON 格式和 YOLOv5 PyTorch TXT 格式。前者因微软于 2015 年发布的 MS COCO 数据集[1]而闻名,该数据集是最广泛用于目标检测、分割和图像标注任务的之一。另一方面,YOLOv5 PyTorch TXT 格式的流行则是因为由 Ultralytics 开发的 YOLOv8 架构(目标检测的最先进模型)[2]将其作为输入格式。
本文首先将介绍这两种格式流行的基础,正如上文所述,它们分别是 MS COCO 数据集和 Ultralytics 的 YOLOv8 架构。
关于 LLM、梯度和量子力学
量子计算是否能帮助我们提高训练大型神经网络语言模型(LLM)的能力?
·发表于 Towards Data Science ·13 分钟阅读·2024 年 11 月 12 日
--

图片来源:Alessio Soggetti (@asoggetti) 来自 Unsplash.com
什么是“训练”?
在人工智能(AI)研究的术语中,“训练”是指优化一个统计模型,通常表现为一个神经网络,使其能够根据输入数据和这些预测的好坏(“成本”或“损失”函数)来做出预测。这样的过程可以通过三种主要范式来进行:监督学习、无监督学习(通常是自回归的),以及强化学习。在监督学习中,每个数据点都有标签,因此可以将模型的预测结果与真实值进行直接比较(例如:这是猫还是狗的图像)。在无监督学习中,没有明确的标签,但通过从数据本身提取的特征进行比较(例如:预测句子中的下一个单词)。最后,强化学习是基于通过统计模型与环境的互动来优化一系列决策(预测)的长期回报(例如:在黄色交通灯前,汽车应该减速还是加速?)。
在所有这些情况下,模型参数的优化是一个漫长的过程,需要…
OLAP 已死——还是它并未死?
OLAP 在现代分析时代的命运
·发表于 Towards Data Science ·13 分钟阅读·2024 年 10 月 21 日
--
1993 年,E.F. Codd 及其团队提出了 OLAP(在线分析处理)这一术语,用于描述从不同角度回答多维分析查询的技术。OLAP 主要包括三个关键操作:
-
汇总:在更高层次的聚合中总结数据,
-
下钻:导航至更详细的数据层级,
-
切片与切块:从不同的视角选择和分析数据。
现在浏览网络,似乎每个数据分析问题都与流行的自助式 BI 相关,关注使用强化 AI 分析大数据的平台。像 LinkedIn 和 Reddit 这样的平台上充斥着关于过时 OLAP 的缺点的无尽讨论,相较于最新的数据分析趋势,OLAP 显得不合时宜。所以是的,我们可以自信地宣布:OLAP 已死。但等等……它真的死了吗?

RIP OLAP(作者提供的图片—AI 生成)
我是谁,为什么写这篇文章?
在我们深入讨论这个有争议的话题之前,先让我介绍一下自己,并解释为什么我要用这篇文章打扰你们。我在 icCube 工作,在那里,我解决客户的技术难题。有时,销售团队会邀请我参加潜在客户的演示,而几乎每次,关于数据可扩展性的核心问题都会被提出来——如何处理客户(即将成为的)大数据。作为一个技术性和务实的人,我的天真、非销售性的回答通常是:
我们是否可以先定义一下实际问题,看看我们是否真的需要讨论大数据?
哎呀 😉 我早就说过,我骨子里是个技术宅。所以,在这篇文章中,我想澄清一下 OLAP 在 2024 年的含义,以及它能解决的挑战。我将从我在 icCube 的经验出发,所以我可能有些偏见,但我会尽力保持客观。欢迎在评论中分享你的想法。
OLAP != OLAP Cube
OLAP 通常(如果不是总是)与 OLAP Cube 互换使用——即一个在多维空间中预聚合值的物化结构。基于这个错误的定义,很容易理解为什么人们会说 OLAP 已过时,因为技术的进步已经减少了对预聚合的需求。
然而,OLAP 并不等同于 OLAP Cube。如果要从关于 OLAP 的各种定义和讨论中突出一件事,那就是 OLAP 包含了一套高效分析多维数据的概念和方法。
Chris Webb 在一篇文章中很好的捕捉了这一点,回顾了过去的日子:
我所说的“OLAP”是指一个集中的模型,不仅包含所有数据,还包括表如何连接、度量如何聚合、复杂计算和关键绩效指标(KPIs)等内容。
在他的文章“OLAP 已死了吗”中,Chris Webb 还提到了FASMI 测试,作为用五个关键词来评定 OLAP 系统的一种方式:“快速共享多维信息分析”(Fast Analysis of Shared Multidimensional Information)。
FAST : means that the system is targeted to deliver most
responses to users within about five seconds, with the
simplest analyses taking no more than one second and
very few taking more than 20 seconds.
ANALYSIS : means that the system can cope with any business logic
and statistical analysis that is relevant for the
application and the user, and keep it easy enough for
the target user.
SHARED : means that the system implements all the security
requirements for confidentiality (possibly down to cell
level).
MULTIDIMENSIONAL : is our key requirement. If we had to pick a one-word
definition of OLAP, this is it. The system must provide
a multidimensional conceptual view of the data,
including full support for hierarchies and multiple
hierarchies, as this is certainly the most logical way
to analyze businesses and organizations.
INFORMATION : is all of the data and derived information needed,
wherever it is and however much is relevant for the
application.
我觉得很有趣的是意识到这个定义最早出现在 2005 年的一篇副标题为:
对经常被误用的 OLAP 术语的分析。
所以,很明显,这种混淆并不是新鲜事,我们的营销和销售同事也对此做出了贡献。请注意,这个定义并未指定 OLAP 系统应如何实现。OLAP Cube 只是实现 OLAP 解决方案的一种可能技术。
根据我的数据领域经验,多维(MULTIDIMENSIONAL)和共享(SHARED)是关键要求。我会把“共享”(SHARED)替换为“安全”(SECURED),并且让“下到单元格级别”(down to cell level)成为必选项——一个带有安全约束的复杂多维数据模型不可避免地意味着最终会有一个复杂的安全配置文件。请注意,FASMI 测试并没有规定分析数据的绝对大小。
在深入探讨五个关键术语并展示它们如何应用于现代工具之前,让我们首先挑战一些广泛认同的观念。
数据分析 != 大数据分析
不可避免地,大数据的论点被用来主张 OLAP 已死。
我完全不同意这一说法。然而,让我们看看 Jordan Tigani 在他 2023 年初发布的文章“大数据已死”中是如何开头的:
当然,在“大数据”工作组购买了所有新工具并完成从旧系统的迁移后,人们发现他们仍然很难理解数据的含义。如果他们真心关注的话,可能已经注意到,数据的规模其实根本不是问题所在。
这是一个非常引人入胜且富有启发性的帖子,超越了市场营销的炒作。我觉得我没有必要在这里重复我在工作中以更小的规模所经历的事情。他的结论是:
大数据是真实存在的,但大多数人可能不需要为此担忧。你可以通过以下一些问题来判断自己是否是“大数据中的佼佼者”:
你真的在生成大量的数据吗?
如果是这样,你真的需要一次性使用大量的数据吗?
如果是这样,数据真的太大,无法放入一台机器吗?
如果是这样,你确定你不是一个数据囤积者吗?
如果是这样,你确定总结数据会更好些吗?
如果你对这些问题的回答是否定的,你可能是新一代数据工具的理想候选者,这些工具帮助你处理实际拥有的数据规模,而不是那些让你担心将来可能拥有的大规模数据。
到目前为止我没有什么可补充的。本文后续部分,我们将探讨现代 OLAP 工具如何帮助你管理你正在处理的数据规模。
数据分析 != 自助式 BI
不可避免地,自助式 BI 是另一个用来主张 OLAP 已死的论点。
商务用户可以自主访问和处理原始企业数据,无需依赖数据专业人员的支持。这种方法使用户能够使用易于操作的工具和界面,进行分析、生成报告并创建仪表盘。
如果我们承认所需的分析对于任何商务人士来说足够简单,或者工具足够先进,可以处理更复杂的分析和安全配置,那么潜在的前提是数据已经清理干净,准备好用于商业决策。
在 icCube 中,在客户项目的启用阶段,80% 的时间都花在清理和理解实际数据及其背后的商业模型上。令人惊讶的是,这一时间的相当大一部分也花在与少数同时了解技术和业务的人员沟通上。这并不奇怪,因为数据模型通常会在多年内演变,变得越来越复杂,而人们也会来来去去。
但假设原始数据是干净的,并且商务用户完全理解它。那么当成百上千个报告被创建出来时(很可能是访问 OLTP 数据库,因为在创建分析数据仓库时没有 IT 部门的参与),会发生什么呢?它们彼此一致吗?它们遵循相同的业务规则吗?它们的计算是正确的吗?它们会引起性能问题吗?
假设一切正常,那么你如何维护这些报告?更重要的是,如何管理底层原始数据中的任何必要更改,因为没有简单的方法知道数据是在哪里使用的?
所以,类似于大数据的论点,我不认为自助 BI 是解决每个现代分析挑战的真正方案。事实上,从长远来看,它可能会带来更多问题。
数据分析 != 生成式 AI 数据分析
最后是 AI 的论点。你不再需要 OLAP 引擎,顺便提一下,你也不再需要任何分析工具。AI 来了,统治一切!我有点夸张,但考虑到目前围绕 AI 的所有炒作,我并没有太远离现实 😉
更严肃地说,在 icCube,即使我们目前对于使用 AI 生成 MDX 代码或分析数据持怀疑态度,这并不意味着我们反对 AI。恰恰相反,事实上。我们最近推出了一个聊天机器人小工具,帮助终端用户理解他们的数据。我们正在积极研究如何利用 AI 提高客户的生产力。我们面临的实际问题主要是:
-
它的准确性不足以交给无法分辨幻觉的终端用户。
-
对于那些在领域内是专家且能够理解并修正幻觉的终端用户,提供这些功能就是过度设计了。
-
每个查询的成本(即 LLM 推理成本)。
但不要只听我说 — 我想强调一下 Marco Russo 的实践和类似的观点。你可以通过这里查看他的视频。如果时间紧迫,可以跳到 32 分钟的位置,那里 Marco 分享了他对于使用 ChatGPT 生成 DAX 代码的看法。
目前,生成式 AI 还无法取代任何 OLAP 系统,当然也不能作为 OLAP 已死的论据。
现在,让我们回到 FASMI 测试,看看定义 OLAP 系统的五个关键术语。
FASMI 测试:快速
means that the system is targeted to deliver most responses to users
within about five seconds, with the simplest analyses taking no more than
one second and very few taking more than 20 seconds.
提供快速响应时间的分析查询不再是 OLAP 系统的专利。然而,它仍然是 OLAP 系统的一项附加优势,因为 OLAP 系统专门为此类查询量身定制。一个显著的优势是,它有助于避免对 OLTP 数据库(或任何实际数据源)的过载,因为:
-
可能已经创建了一个专用的数据仓库。
-
它可能作为实际数据源前的缓存。
这个中间层的另一个好处是,它可以帮助降低访问底层原始数据的成本。
FASMI 测试:分析
means that the system can cope with any business logic and statistical
analysis that is relevant for the application and the user, and keep it
easy enough for the target user.
OLAP 系统旨在执行复杂的分析查询,因此提供了一系列通常在其他系统中无法直接获得的功能。这些功能包括:
-
切片和切块功能:允许用户从不同的视角和维度探索数据。
-
自然导航:支持通过父/子层次结构在多维模型中直观导航。
-
聚合度量:支持各种聚合,如求和、最小值、最大值、开盘值、闭盘值等。
为了支持所有这些功能,需要一种专门的查询语言。MDX(多维表达式)是多维分析的事实标准。
我们经常与客户使用的一些高级功能,可能是非标准的,包含:
-
时间周期比较:便于进行基于时间的分析,如同比分析。
-
计算度量:支持在设计或运行时创建临时计算。
-
计算成员:类似于计算度量,但可以应用于任何维度。例如,它们可以用于创建辅助维度,成员基于当前评估上下文进行统计。
-
高级数学运算:提供向量和其他结构,优雅地执行复杂的数学计算(统计、回归等)。
-
MDX 扩展:函数、Java 代码嵌入、结果后处理等。
FASMI 测试:共享
means that the system implements all the security requirements for
confidentiality (possibly down to cell level).
根据我的经验,我认为这是继多维模型之后的第二个最重要的需求。在每个需要安全性的客户模型中,定义适当的授权成为一个重大挑战。
我建议通过将单元格级别粒度设为强制要求来改进 FASMI 测试。
微软分析服务、icCube,以及其他平台可能允许在多维模型中直接定义安全性,使用 MDX 语言(将在下一点介绍)。这种方法非常自然,通常与公司层级安全结构自然对接。
在多维模型层级定义安全性尤为重要,尤其当模型是由多个数据源构建时。例如,在没有此功能的情况下,应用企业安全策略到来自 IoT 传感器等来源的数据可能会非常复杂。
自从 FASMI 测试推出以来,将分析功能嵌入应用程序已经成为一个关键需求。许多 OLAP 系统,包括微软分析服务和 icCube,现在支持在运行时动态创建安全配置文件 —— 一旦用户身份验证通过 —— 基于不同的用户属性。一旦定义了此安全模板,它将在每次用户登录系统时动态应用。
FASMI 测试:多维
is our key requirement. If we had to pick a one-word definition of OLAP,
this is it. The system must provide a multidimensional conceptual view of
the data, including full support for hierarchies and multiple
hierarchies, as this is certainly the most logical way to analyze
businesses and organizations.
我完全同意。多维模型对数据分析至关重要,因为它提供了一种结构化的方法,可以从多个角度分析复杂数据(数据并非孤立存在),并且通常与企业层级安全框架对接。
对业务用户直观易懂
该模型反映了企业自然思考数据的方式——无论是产品、客户还是时间段。对于非技术用户来说,这种方式更为直观,允许他们在不需要理解复杂 SQL 查询的情况下探索数据。诸如父子层级和多对多关系等关键特性也被无缝集成。
增强的数据聚合与汇总
该模型旨在处理跨维度的聚合(如求和、平均、计数),这对于在不同层级汇总数据至关重要。它非常适合创建仪表板,展示高层次的概览,并能够根据需要深入探讨更详细的见解。
促进时间序列分析
时间是许多数据分析类型中的关键维度,例如跟踪趋势、预测和衡量一段时间内的表现。多维模型可以轻松地将时间作为一个维度进行集成,从而实现时间序列分析,如同比(年对年或月对月)的比较。
现实世界中的数据复杂性
尽管无代码数据工具的兴起,现实世界中的数据项目很少是简单的。数据源往往杂乱无章,随着时间的推移不断演变,带有不一致性,增加了复杂性。使用传统 SQL 方法访问原始数据可能会遇到挑战。考虑到熟练人才的短缺,首先建立一个清晰的语义层是明智之举,以确保数据的正确使用,并为未来的数据驱动决策提供良好的基础。
分析中的信任与可靠性
一个定义良好的多维模型(或语义层)的一个主要优势是,它能够建立客户对分析结果的信任。这个强大的模型允许有效的测试,使得在当今快速变化的环境中能够灵活应对。
感知的灵活性不足
OLAP 中的语义层在数据访问之前是一个至关重要的步骤,尽管它最初看起来可能限制了灵活性,但它确保从一开始就正确地建模数据,从而简化了未来的报告。在许多情况下,这种“灵活性不足”更多的是一种感知,而非现实。现代 OLAP 工具,如 icCube,不依赖于过时且繁琐的流程来创建 OLAP 数据立方体,甚至支持增量更新。例如,icCube 的类别功能允许在运行时创建新的维度。
总结来说,尽管与直接访问原始数据相比,OLAP 和维度模型在灵活性上可能给人一种印象,但它们在处理复杂业务逻辑和安全性方面依然提供了至关重要的优势。
FASMI 测试:信息
is all of the data and derived information needed, wherever it is and
however much is relevant for the application.
从各种来源提取数据——无论是 SQL、NoSQL、物联网、文件还是 SaaS 平台——已经不再是 OLAP 系统的专属功能。然而,OLAP 系统仍然具有一个关键优势:它们专门设计用于创建一个安全的多维模型,作为您的分析需求的事实语义层。
FASMI 测试:2024 年依然相关吗?
FASMI 测试的原始定义旨在为在线分析处理(OLAP)系统提供清晰而易于记忆的描述:共享多维信息的快速分析。我相信这个定义依然相关,而且比以往任何时候都更加必要。在 2024 年,人们不应该再将 OLAP 与其过去的某个实现——过时的 OLAP 立方体混淆。
你在 2024 年需要 OLAP 吗?
作为一个务实的人,在不了解你当前的数据分析挑战之前,我不会建议使用特定的工具。我建议仔细识别你的当前需求,然后寻找合适的工具。最重要的是,如果你对当前的分析平台感到满意,不要仅仅为了使用最新的流行工具而更换它。
然而,如果你是:
-
在查询复杂的多维商业模型时感到困难,
-
在应用必须与公司层级安全模型对齐的复杂安全性时遇到困难,
-
在编写复杂的计算以进行高级分析时遇到困难,
-
为了管理数百个或数千个截然不同的查询/仪表板而感到困难,
-
在不到一秒钟的时间内打开仪表板时感到困难,
-
在从不同系统中获取和合并数据时遇到困难,
-
在信任你的分析洞察时感到困难,
那么值得考虑现代的 OLAP 系统。请放心,它们并没有过时,而且还会持续一段时间。现代 OLAP 工具正在积极开发,并在 2024 年保持相关性。此外,它们受益于最新的技术进展:
-
大数据技术,
-
自助服务功能,
-
生成性 AI,
实现新功能或完善现有功能以提高最终用户的生产力。这将是未来一篇文章的主题。敬请期待!
有兴趣的读者可以在此维基百科页面上探索可用的 OLAP 服务器。
忽略变量偏差

Rothstein, A.,摄影师。(1939)农场家庭共进晚餐。蒙大拿州费尔菲尔德农场,蒙大拿州费尔菲尔德农场美国大提顿县,1939 年 5 月。[照片] 来源于国会图书馆,www.loc.gov/item/2017777606/.
介绍一种特别狡猾的偏差,这种偏差常常侵入许多回归模型中
·发布于Towards Data Science ·20 分钟阅读·2024 年 8 月 6 日
--
从 2000 年到 2013 年,涌现出大量研究,显示青少年冒险行为的发生率与他们与家人一起用餐的频率之间有着显著的相关性。
一项又一项的研究似乎都得出了相同的结论:
青少年每周与家人一起用餐的次数越多,他们沉溺于物质滥用、暴力、犯罪、破坏公物以及许多其他问题行为的概率就越低。
家庭用餐频率更高也与减少压力、减少儿童抑郁症发生率以及减少自杀念头的频率相关。一起用餐还与自尊心的提高和青少年情感健康的普遍提升相关。
很快,媒体捕捉到了这些研究结果,并将其包装成易于理解的简短信息,像这样:
“研究表明,家庭一起吃饭的频率越高,孩子们吸烟、饮酒、吸毒、抑郁、患上饮食障碍以及考虑自杀的可能性就越小,他们更有可能做......
OMOP 与 DataSHIELD:提升隐私保护医疗分析的完美匹配?
探索 DataSHIELD 与 OHDSI/OMOP 之间的协同效应,以促进协作医疗分析
·发布于数据科学前沿 ·阅读时间:8 分钟·2024 年 7 月 3 日
--
背景
跨境或多站点的数据共享可能会面临挑战,原因包括法规和法律的差异,以及有关数据隐私、安全性和所有权的担忧。然而,进行这些共享的需求正在增长。
大规模跨国和多站点的临床研究可以生成更强大和及时的证据,从而改善医疗保健。为了解决这一问题,罗氏的联邦开放科学团队相信,联邦分析(增强隐私的去中心化统计分析)是促进更多多站点和数据驱动的合作的有前景的解决方案。
高质量(经过筛选的)患者级别数据的可获得性和可访问性仍然是进展中的一个持续瓶颈。联邦模型是医学领域中进行协作分析和机器学习的一个推动因素,它无需将任何敏感的患者级数据转移。
用于分析的联邦模型
联邦范式的理念是将分析带到数据中,而不是将数据带到分析中。
这意味着数据保持在各自组织的边界内,协作分析工作并不意味着将数据复制到本地基础设施之外,也不意味着对数据进行无限制查询的访问。
它有许多优点,包括:
-
降低数据暴露风险
-
没有难以追踪和管理的数据副本离开场所
-
避免了构建数据湖的前期成本和努力
-
跨越监管边界
-
尝试不同分析方法和功能的互动方式
让我们用一个简化的例子来说明,假设有来自三个不同医院的糖尿病患者。假设外部数据科学家想要分析患者的平均年龄。

简化的联邦分析示意图(图片来源:作者)
远程数据科学家并未完全获得数据拥有者的信任,不能访问数据,无法访问任何行级数据,也不能随意发送查询(如 DataFrame.get),但他们可以调用联邦函数,并在网络中获取聚合的均值。
数据拥有者允许远程数据科学家在指定的队列和变量(例如年龄)上运行联邦函数均值。
这些高级分析能力在进行观察性研究时为评估不同地区群体的治疗效果等提供了极大的附加值和支持。
这就是数据科学家使用流行的联邦分析解决方案 DataSHIELD 时的视角。

分析脚本的截图(图片来源:作者)
DataSHIELD 是什么?
DataSHIELD是一个允许您在不查看数据或推断其中任何敏感信息的情况下分析敏感数据的系统。
它源于学术项目 DataSHIELD(利物浦大学)和 obiba.org(麦吉尔大学)。
它是一个开源解决方案,托管在GitHub上,这有助于建立信任和透明度,因为这段代码在数据拥有者的基础设施防火墙后运行。
它已经在市场上存在了超过十年,并且在多个成功的项目中得到了应用。

RStudio 或 Jupyter R 笔记本是与联邦网络互动的常见方式(图片来源:作者)
DataSHIELD 的主要优势包括:
-
具有披露检查和智能聚合结果的高级联邦分析函数
-
联邦认证与授权,使数据拥有者能够完全控制谁对其数据做什么
-
用于自动化架构各部分的 API
-
内置扩展机制,允许创建自定义联邦函数
-
额外功能的社区包
-
完全透明,所有代码都可在 GitHub 上找到
数据拥有者负责:
-
在其基础设施中部署本地 DataSHIELD Opal 和 Rock 节点
-
管理用户、权限(从函数到变量)
-
配置披露检查过滤器
-
审查和接受自定义函数及其本地部署
数据分析师是:
-
调用联邦函数并聚合结果,通常能提供高精度的结果,而不是使用元分析,且始终确保数据泄露保护。
-
编写和测试他们的自定义联邦函数,然后将其与网络共享,由数据拥有者在所有节点上部署并用于协作分析工作。
观察性健康数据科学与信息学的优势
OHDSI以其数据协调和标准化著称,称为观察性医疗结果合作伙伴关系(OMOP)通用数据模型(CDM)。
当前版本的标准是 5.4,尽管它正在不断发展,以适应来自现实世界应用的反馈和新需求,但它已经成熟并得到了 OHDSI 生态系统中工具的支持,如ATLAS、HADES和Strategus。
OHDSI 技术栈已有十多年历史,拥有许多成功的实际应用。
OHDSI 不要求医院和其他数据源将其数据或 API 暴露到互联网,因此可以通过将分析规范交付给数据拥有者来执行分析查询和算法,数据拥有者执行分析,审查输出结果并通过安全渠道将其发送到分析方。OHDSI 提供端到端的工具来支持这一工作流程的所有步骤。
集成的商业价值
DataSHIELD 虽然需要连接到其分析服务器的 API(Opal),但它通过使用一组不泄露数据的分析函数和内置的高级泄露检查,提供了一种互动的分析数据方式,同时保护数据隐私。
这使得分析更加敏捷、探索性(在一定程度上),并使数据分析师能够尝试不同的分析方法,从数据中学习。
在传统的 OHDSI 方法中,代码固定在定义好的研究定义中,由数据拥有者手动执行。这导致获得结果的等待时间较长(依赖人工),可能需要几周甚至几个月,具体取决于各组织的情况。而在描述的联邦分析方法中,结果可以在几秒钟内获得。
另一方面,不需要手动审查返回给外部分析师的结果,数据拥有者应信任内置的联邦函数和泄露检查。同时,联邦方法需要互联网连接。
利益总结:
-
DataSHIELD 使结果能够立即且自动地可用
-
内置联邦聚合提高了准确性
-
泄露保护保护原始数据
-
重新利用在 OMOP CDM 数据协调中的投资
-
通过使用 OMOP 进行协调来提高数据质量 → 更高质量的分析结果
换句话说,可以兼顾两全其美,以改进现实世界医疗保健应用中的分析结果。
集成场景
我们与 DataSHIELD 团队 合作,确定了四种主要的集成场景。我们的角色(联邦开放科学团队)不仅仅是表达我们对集成的兴趣和商业理由,而是定义可行的集成架构和概念验证定义。
选项 1. 从 OMOP CDM 数据源提取、加载并转换(ETL)数据到 DataSHIELD 数据存储(项目开始时)。

(图片由作者提供)
在这种方法中,我们使用经典的 ETL 方法从 OHDSI 数据源提取数据,并将其转换为即将成为数据源的数据,然后将其作为资源添加或直接导入到 DataSHIELD Opal 服务器。
选项 2. 将 OMOP CDM 作为 DataSHIELD 中原生支持的数据源。

(图片由作者提供)
DataSHIELD 支持多种数据源(如 CSV 等平面文件、XML、JSON 等结构化数据、关系型数据库等),但不直接支持 OHDSI OMOP CDM 数据源。
dsOMOP 库(正在开发中)的目标是为 DataSHIELD 提供扩展,以便为 OMOP CDM 数据源提供一流的支持。
选项 3. 使用 REST API 根据需要检索数据子集。

(图片由作者提供)
该选项不绕过 OHDSI 堆栈的 API 层,而是作为 DataSHIELD API 到 OHDSI 工具 API 的桥接、编排和翻译层。
选项 4. 将 DataSHIELD 嵌入 OHDSI 堆栈中。

(图片由作者提供)
这意味着两个生态系统的深度集成,以最大化其效益,但需要付出较高的努力和两个团队(DataSHIELD 和 OHDSI 技术团队)之间的协调。
采用障碍
这两种解决方案和社区在使用各自的工具和方法进行成功的分析项目方面有着良好的记录。过去,DataSHIELD 方面曾有过有限的尝试来接纳 OMOP CDM 和查询库(即 GitHub — sib-swiss/dsSwissKnife,早期的 github.com/isglobal-brge/dsomop)。
我们试图解决的主要问题是对联邦模型的认知仍然有限,我们在鹿特丹的 OHDSI Europe 2024 研讨会上愉快地展示了这一点,并获得了非常积极的反馈,大家认识到未来集成的好处。来自数据分析师视角的联邦分析如何工作的实操演示非常有助于传达信息。关于计划集成的主要问题是“何时”而非“为什么”,我们认为这是一个好兆头,也是对未来的鼓励。
两个技术生态系统(DataSHIELD,OHDSI)都已经成熟,然而它们的整合正在开发中(截至 2024 年 6 月),尚未准备好投入生产。DataSHIELD 可以在没有 OMOP CDM 的情况下使用,尽管数据质量和协调问题被认识到,OMOP 从未是联合项目的直接要求或指导。
如果项目更多集中于长期合作而非单次分析,联合网络的价值可能会更高,建立网络的初始成本(从各个角度看)在多个研究执行时能够得到重复利用。目前,在这一领域已有进展的迹象,尽管大多数联合项目是单一研究项目。
未来步骤
我们对 OHDSI 和 DataSHIELD 整合的潜力和未来持乐观态度。这是行业期望发生的事情,并且得到了两个社区的热烈欢迎。
dsOMOP R 库在 DataSHIELD 中的开发最近得到了加速。
预计结果将提供数据源整合的端到端解决方案(策略二),并允许进一步发展和更紧密地合作这两个生态系统。预期整合的实际应用总是收集宝贵反馈和发现问题的最佳方式。
作者特别感谢Jacek Chmiel对本文的重大影响,以及以下帮助塑造此项目的人员:Jacek Chmiel、Rebecca Wilson、Olly Butters和Frank DeFalco,以及 Roche 的 Federated Open Science 团队。
在 Power BI 中处理预计算层次数据
虽然层次结构是数据中的常见概念,但一些来源以不寻常的格式提供数据。通常,我们在最低层级获取值。但当我们得到预先聚合的值时,会发生什么呢?在这里,我将深入探讨这个话题。
·发表于Towards Data Science ·8 分钟阅读·2024 年 5 月 3 日
--

照片由ThisisEngineering提供,来源于Unsplash
介绍和数据
让我们设定一个场景:我们有一个包含行政费用的组织。
费用可以发生在国家、州和商店级别。
请看以下表格:

图 1 — 数据在预期位置的值(图源:作者)
我们看到两行分别是两家商店的费用,一行是南卡罗来纳州的组织费用。
我可以使用这些数据来计算费用的总和,并得出南卡罗来纳州所有商店的总费用。
但是,当源系统以不同的形式提供数据时,怎么办?
例如,像这样:

图 2 — 南卡罗来纳州预先聚合的值数据(图源:作者)
第三行包含了南卡罗来纳州两家商店的预先聚合的总和,以及南卡罗来纳州的组织费用。
简单地将这三行相加会得到错误的结果,因为结果中会重复计算这两家商店的费用:

图 3 — 聚合包含预先聚合值的数据时的错误结果(图源:作者)
挑战是:如何计算每个层级中的正确结果?
我的解决方案方法必须考虑以下几点:
-
我不能更改数据源中的数据。
-
我必须在数据模型中添加一些计算以纠正结果。
-
我必须在层级的每个级别执行不同的计算。
但是我在哪里以及如何进行操作呢?
我有三种方法可以解决这个问题:
-
添加一个计算列来获得正确的结果。
-
添加一个度量值来计算正确的结果。
-
使用可视化计算。
计算列
好的,让我们开始添加一些计算列。
首先,我需要知道每一行在层级中的级别。为此,我需要一个名为“路径长度”的列。这样的列通常用于处理父子层级。
因此,我添加了两列新列,以便更好地导航层级:

图 4 — 用于层级导航的额外计算列(作者提供的图)
我使用了以下表达式来计算 HierachyPath 列:
HierarchyPath =
'Cost Data'[Country]
& IF (
'Cost Data'[State] <> 'Cost Data'[Country],
"|" & 'Cost Data'[Country]
)
& IF (
'Cost Data'[Store] <> 'Cost Data'[State],
"|" & 'Cost Data'[Store]
)
然后,我使用了PATHLENGHTH()函数来计算“路径长度”列:
Path Length = PATHLENGTH('Cost Data'[HierarchyPath])
接下来,我可以编写一个表达式,执行以下步骤来处理表中的每一行:
-
获取当前职位的值。
-
获取当前职位在层级中下方的值的总和。
-
从第 2 步的总和中扣除当前行中的值。
结果是一个包含上面第一张图片中值的列。
Corrected Expenses =
VAR CurrentExp = 'Cost Data'[Expenses]
VAR CurrentLevel = 'Cost Data'[Path Length]
VAR CurrentPath = 'Cost Data'[HierarchyPath]
VAR ChildExpenses =
CALCULATE(SUM('Cost Data'[Expenses])
,REMOVEFILTERS('Cost Data')
,'Cost Data'[Path Length] = CurrentLevel + 1
,CONTAINSSTRING('Cost Data'[HierarchyPath], CurrentPath)
)
RETURN
CurrentExp - ChildExpenses
关键在于“ChildExpenses”变量的表达式。该表达式计算了当前职位下、同一父级下的所有行的总和。
请注意,在 Power BI 中调用CALCULATE()函数计算一个计算列时,会触发上下文转换。
如果你不熟悉上下文转换的概念,确保阅读我解释它的文章:
行上下文和筛选上下文是 DAX 中的常见概念。但我们可以通过上下文转换在这两者之间切换。
towardsdatascience.com
这是该列的结果:

图 5 — 计算列的结果以获得正确的结果(作者提供的图)
这列替代了原始的 Expenses 列。
我将原始的 Expenses 列重命名为“Expense_Original”,并将计算列重命名为“Expenses”。由于 Expense_Original 列对报告没有用处,因此它在数据模型中是隐藏的。
现在,我可以直观地创建报告了:

图 6 — Power BI 中重命名的原始 Expenses 列和计算列并排显示(作者提供的图)
这是所需的结果。
但让我们看看我是否能创建一个度量值来计算正确的结果。
度量值
要编写一个度量值,我必须分别处理每个层级。
我不能使用与计算列相同的方法,因为在每个上层级(如国家或州)下,商店级别有多行数据。
结果是以下的 DAX 代码:
Expenses (Corrected) =
VAR CurrentExp = [Expenses (Original)]
VAR CurrentLevel = SELECTEDVALUE('Cost Data'[Path Length])
VAR CurrentPath = SELECTEDVALUE('Cost Data'[HierarchyPath])
VAR CurrentCountry = SELECTEDVALUE('Cost Data'[Country])
VAR CurrentState = SELECTEDVALUE('Cost Data'[State])
VAR CurrentStore = SELECTEDVALUE('Cost Data'[Store])
VAR StateExpenses =
-- Get the pre-aggregated value of the Expenses for the State
CALCULATE([Expenses (Original)]
,REMOVEFILTERS('Cost Data')
,'Cost Data'[Path Length] = CurrentLevel + 1
,CONTAINSSTRING('Cost Data'[HierarchyPath], CurrentPath)
)
RETURN
SWITCH(TRUE()
-- Calculation at the lowest level (Store)
-- But only when the Store has a different name than the State
,NOT ISBLANK(CurrentStore) && CurrentStore <> CurrentState
,CurrentExp
-- Detract the Expenses from the sum at the State level when the "Store" has the same name as the State
-- These are the rows with the Expenses for the State
,NOT ISBLANK(CurrentStore) && CurrentStore = CurrentState
,CurrentExp - StateExpenses
-- Calculate the Sum at the state level
,NOT ISBLANK(CurrentState) && ISBLANK(CurrentStore)
-- First, calculate the Sum for all Stores
-- But only when the Stores have a different name than the State
,CALCULATE([Expenses (Original)]
,REMOVEFILTERS('Cost Data')
,'Cost Data'[Country] = CurrentCountry
,'Cost Data'[State] = CurrentState
,'Cost Data'[Store] <> CurrentState
)
-- At this stage, each row in the Visual has multiple Data rows.
-- Therefore, SELECTEDVALUE() for the path doesn't return any value.
-- Now add the sum for all Stores, detracting the duplicate value for the "Stores" with the same name as the State
+
(
CALCULATE([Expenses (Original)]
,REMOVEFILTERS('Cost Data')
,'Cost Data'[Country] = CurrentCountry
,'Cost Data'[State] = CurrentState
,'Cost Data'[Store] = CurrentState
)
-
CALCULATE([Expenses (Original)]
,REMOVEFILTERS('Cost Data')
,'Cost Data'[Country] = CurrentCountry
,'Cost Data'[State] = CurrentState
,'Cost Data'[Store] <> CurrentState
)
)
-- Calculate the corrected Sum for the Country
-- Must use the same logic as above, but by moving one level above, considering only the Country and the State
,CALCULATE([Expenses (Original)]
,REMOVEFILTERS('Cost Data')
,'Cost Data'[Country] = CurrentCountry
,'Cost Data'[State] <> CurrentCountry
) +
(
CALCULATE([Expenses (Original)]
,REMOVEFILTERS('Cost Data')
,'Cost Data'[Country] = CurrentCountry
,'Cost Data'[State] = CurrentCountry
)
-
CALCULATE([Expenses (Original)]
,REMOVEFILTERS('Cost Data')
,'Cost Data'[Country] = CurrentCountry
,'Cost Data'[State] <> CurrentCountry
)
)
)
我在代码中添加了大量的注释。
因此,我不会详细解释度量值的每一步。
然而,这种方法非常复杂,无法与使用计算列的方法的简便性相比。
可视化计算
最后,我可以使用 Power BI 中的最新功能之一:可视化计算。
可视化计算可以直接在视觉效果中添加计算,而无需将度量值添加到数据模型中。
这为我们提供了一些激动人心的可能性,并且消除了为满足特定视觉效果需求而编写度量值的必要。
我在下面的参考部分添加了一些关于这个话题的链接。
在这里,我尝试使用这个新功能来实现一个简单的解决方案。
然而,在进行了大量的研究和反复试验后,我仍然没有找到一个有效的解决方案。
我找到了解决方案来计算每个商店的正确结果,但对于州和国家的计算则没有成功:
Visual calculation =
VAR CurrentCountry = [Country]
VAR CurrentState = [State]
RETURN
SWITCH(TRUE()
,[State] <> [Store] && ISATLEVEL([Store])
,[Expenses (Original)]
,[State] = [Store] && ISATLEVEL([Store])
,[Expenses (Original)] -
CALCULATE(SUM([Expenses (Original)])
,[State] <> [Store]
,[Country] = CurrentCountry
,[State] = CurrentState
)
)
这是这个公式的结果:

图 7 — 可视化计算的结果(图由作者提供)
我尝试找出计算州和国家结果的解决方案,但没有成功。
一个关键细节是,我们使用这种方法时不会得到总计。这可能是由于公式未完成的原因,但这是一个重要的细节。
尽管在这个特定场景中没有成功,我还是会将这个新功能纳入我的技能范围。
我鼓励你关注这个新的激动人心的功能。它为找到计算特定于某个视觉效果的结果提供了新的可能性,这些结果不会在其他地方重用。
Amit Chandak 写了关于这个话题的介绍:
[## 在 Power BI 中理解可视化计算:数据分析的革命
在 2024 年 2 月,Power BI 推出了一个正在预览中的具有颠覆性意义的功能:可视化计算。这个…
结论
当我开始写这篇文章时,我打算为你提供三个有效的解决方案来解决这个问题。
我找到了两种可行的解决方案。第一种是最直接且高效的,而第二种则是一个很好的操作层次结构的练习。
一段时间前,我写了一篇简短的文章,讲解了为什么预聚合数据对我们不利:
我的一个客户总是希望在他的 Excel 文件中为报告预计算聚合。以下是我们应该避免的原因...
towardsdatascience.com
现在,我有一个更多的例子,证明我之前写的那句话是对的。
一个重要的收获是,准备和格式化数据以支持简易的解决方案至关重要,即使这意味着要额外付出努力去寻找解决方案。
参考文献
下面是一些关于新可视化计算功能的参考资料:
MS Power 博客上的可视化计算(预览)(2024 年 2 月)
SQLBI 文章页面(包含多个相关主题的文章,更多内容敬请期待)
我使用了 Contoso 示例数据集,就像在我之前的文章中一样。你可以从微软这里免费下载 ContosoRetailDW 数据集。
Contoso 数据可以在 MIT 许可证下自由使用,具体请参见这里。
我提取了数据的一个子集,并进行了处理,以获得所需的数据。
请考虑关注我并订阅,以便在我添加新内容时立即收到电子邮件:
[## 每当 Salvatore Cagliari 发布新内容时,收到电子邮件通知。
每当 Salvatore Cagliari 发布新内容时,收到电子邮件通知。通过注册,你将创建一个 Medium 账户,如果你还没有的话...
medium.com](https://medium.com/@salvatorecagliari/subscribe?source=post_page-----4a215b96b99c--------------------------------)
我将我的文章免费开放给每个人,尽管 Medium 有付费墙。这允许我从每个读者那里赚取一点收入,但我关闭了付费墙,以便你能免费阅读我的文章。
你可以通过以下方式支持我的工作,这是我在空闲时间进行的工作:
buymeacoffee.com/salvatorecagliari
或扫描此二维码:

任何支持都非常感谢,这能帮助我腾出更多时间为你创造更多内容。
非常感谢。
关于 Hopfield 网络
从一般模型到特例
·发表于Towards Data Science ·13 分钟阅读·2024 年 10 月 12 日
--

2024 年诺贝尔物理学奖授予了 John Hopfield 和 Geoffrey Hinton,其中 John Hopfield 因其在所谓的 Hopfield 网络方面的工作而获奖。
几十年前,我的博士论文研究的是 Hopfield 风格的网络。我觉得这是写这篇文章的好时机。
我将重点讨论 Hopfield 网络。我会从最简单的例子开始,演示它是如何工作的,然后引入能量函数的概念并在局部最小化它,接着讨论我博士论文中提出的特例,再回到更一般的情况,讨论联想记忆和优化应用,最后再次回到特例,在更具体的场景中实例化这些应用。
让我们从头开始。首先,什么是 Hopfield 网络?
请看下面的图片。
如下解释。我们有两个神经元,每个神经元的取值为 1(可以理解为“激活”)或-1(可以理解为“未激活”)。这两个神经元通过一个正权重连接,表示具有正向影响的突触。这意味着,从任意一个神经元的角度来看,假设是第一个神经元,它希望另一个神经元与它保持相同的状态——“激活”或“未激活”。
关于雅各布·伯努利、大数法则以及中心极限定理的起源

公有领域/公有领域/CC BY-SA 3.0/作者图片/公有领域
通过历史的长镜头探索大数法则弱法则和中心极限定理
·发表于Towards Data Science ·阅读时长 16 分钟·2024 年 1 月 23 日
--
在我之前的文章中,我向大家介绍了中心极限定理。我们解构了它的定义,探讨了它的应用,并在模拟中看到了它的神奇效果。
我在那篇文章的结尾提到了一个哲学性问题,这个问题由 17 世纪著名数学家提出,关于当自然界面对大量物体时,如何表现。这个问题最终引导我们发现了中心极限定理,并且是在一个多世纪后才被发现的。
在本文中,我将深入探讨这个问题,以及思考过这个问题的数学家的生平,并且揭示从中展开的重大发现。
大数法则弱法则的发现
一切始于雅各布·伯努利。大约在 1687 年,来自今天瑞士巴塞尔的伯努利家族的大儿子,年仅 32 岁的他,开始致力于他的大作《猜想的艺术》(Ars Conjectandi)的第四部分,也是最后一部分。在第四部分中,伯努利专注于概率及其在…
关于 AWS Trainium 和 Inferentia 的可编程性
加速 AI/ML 模型训练与自定义运算符 — 第四部分
·发表于 Towards Data Science ·阅读时间:12 分钟·2024 年 11 月 1 日
--

图片由 Agata Bres 提供,来源于 Unsplash
在本文中,我们继续探索通过自定义运算符开发来优化机器学习(ML)工作负载运行时性能的机会。这次,我们重点介绍 AWS Neuron SDK 提供的工具,用于在 AWS Trainium 和 AWS Inferentia 上开发和运行新内核。随着推动 AI 革新的低层次模型组件(例如,注意力层)的快速发展,用于训练和运行 ML 模型的加速器的可编程性变得至关重要。特别是专用 AI 芯片,必须提供一种有价值的替代方案,以应对广泛使用且具有深远影响的通用 GPU(GPGPU)开发框架,如 CUDA 和 Triton。
在之前的文章中(例如,这里和这里),我们探讨了在 AWS 定制的 AI 芯片上构建和运行 ML 模型的机会,使用的是专用的AWS Neuron SDK。在 SDK 的最新版本(2.20.0)中,AWS 引入了用于开发自定义内核的Neuron 内核接口(NKI),该内核支持底层加速器NeuronCore-v2,这一加速器为Trainium和Inferentia2提供动力。NKI 接口与另一个 API 配合使用,该 API 使NeuronCore-v2能够进行编程,即Neuron 自定义 C++操作符。在本文中,我们将探讨这两种机会并展示其应用。
免责声明
重要的是,本篇文章不应被视为官方AWS Neuron SDK 文档的替代。在撰写本文时,Neuron SDK 的自定义内核开发 API 仍处于 Beta 阶段,在你阅读本文时可能会发生变化。我们分享的示例仅供演示用途,不能保证其最优性、鲁棒性、耐用性或准确性。请不要将我们提到的任何平台、工具、API 等视为对其使用的支持。任何项目的最佳选择取决于具体的使用案例,需进行适当的调查和分析。
为神经核心开发自定义内核
尽管 Neuron SDK 支持的 ML 模型列表不断增长,但一些操作仍然不受支持或实现不尽如人意。通过开放 Neuron 内核自定义 API,SDK 使开发者能够创建和/或优化他们所需的底层操作,极大地增加了在 Trainium 和 Inferentia 上运行 ML 工作负载的机会。
如我们在本系列的上一篇文章中讨论的,要充分利用这些 AI 芯片的强大性能,需要详细了解其底层架构。
神经核心架构
NKI 文档包含了一个专门的章节介绍 NeuronCore-v2 的架构设计及其对自定义操作符开发的影响。重要的是,Neuron 核心与其 AI 加速器对等体(例如 GPU 和 TPU)之间有许多差异。针对 Neuron 核心的优化需要一套独特的策略和技能。
与其他专用 AI 芯片类似,NeuronCore-v2 包括多个内部加速引擎,每个引擎专门执行某些类型的计算。各引擎可以异步并行运行。Neuron 编译器负责将机器学习模型转化为低级操作,并优化每个操作所使用的计算引擎。
Tensor 引擎专门用于矩阵乘法。Vector 引擎和Scalar 引擎都操作张量,其中 Vector 引擎专注于归约操作,Scalar 引擎则专注于非线性函数。GpSimd 引擎是一个通用引擎,能够运行任意 C/C++程序。请注意,尽管NKI接口暴露了对所有四个计算引擎的访问,自定义 C++操作符是专门为GpSimd引擎设计的。
每个引擎的能力的更多细节可以在架构文档中找到。此外,NKI 指令集架构(ISA)文档提供了不同低级操作所运行的引擎的详细信息。
Neuron 芯片的另一个重要方面是其内存架构。一个 Neuron 设备包括三种类型的内存,HBM、SBUF 和 PSUM。深入了解每种内存的容量和功能对优化内核开发至关重要。
根据架构概述,你可能会得出结论,Neuron 内核开发需要高度专业的知识。虽然这对于创建充分优化、充分利用 Neuron 核心所有功能的内核来说是正确的,但我们的目标是展示 Neuron 自定义内核 API 的可访问性、价值和潜力——即使对于非专家开发者也是如此。
自定义 NKI 内核
NKI接口是一个 Python 级 API,它向 ML 开发者公开了 Neuron 核心计算引擎和内存资源的使用。 NKI 入门指南详细介绍了设置说明,并通过一个简单的“hello world”内核提供了一个平稳的入门体验。NKI 编程模型指南详细说明了典型 NKI 内核的三个阶段(加载输入、在计算引擎上运行操作和存储输出),并介绍了 NKI Tile 和基于 Tile 的操作。NKI 教程展示了多种 NKI 内核示例应用程序,每个示例都引入了新的核心 NKI API 和功能。考虑到这些示例内核的假定最优性,开发新内核的一种可能策略是:1) 确定一个与你想要实现的操作相似的示例,2) 然后将其作为基准,迭代优化并调整,以实现你所需的特定功能。
NKI API 参考手册详细介绍了用于内核开发的 Python API。其语法和语义与Triton和NumPy类似,NKI 语言的定义旨在最大化可访问性和易用性。然而,需要注意的是,NKI 内核开发仅限于NKI库中定义的操作,这些操作(截至本文撰写时)比Triton和NumPy等库中的操作要少且受限。
示例 — 一个 GIOU 内核
与我们的上一篇文章一样,我们通过构建广义交集比率(GIOU)操作的自定义实现,来评估 NKI 的使用。由于 GIOU 涉及逐像素操作,我们参考了NKI 编程指南中的exp内核,并将 NKI 的高级张量索引技术融入到我们的实现中。为了在 CPU 环境中便于调试,我们还添加了使用nki.simulate_kernel和nki.language.device_print.htmlAPI 运行代码的选项。
import torch
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
import numpy as np
simulate = False
try:
# if torch libraries are installed assume that we are running on Neuron
import torch_xla.core.xla_model as xm
import torch_neuronx
from torch_neuronx import nki_jit
device = xm.xla_device()
# empty implementation
def debug_print(*args, **kwargs):
pass
except:
# if torch libraries are not installed assume that we are running on CPU
# and program script to use nki simulation
simulate = True
nki_jit = nki.trace
debug_print = nl.device_print
device = 'cpu'
@nki_jit
def giou_kernel(preds_ptr,
targets_ptr,
output_ptr):
epsilon = 1e-5
TILE_M = nl.tile_size.pmax # 128
TILE_N = nl.tile_size.psum_fmax # 512
TILE_N_OUT = TILE_N // 4
p_1, p_2 = preds_ptr.shape
t_1, t_2 = targets_ptr.shape
o_1, o_2 = output_ptr.shape
# verify input
# batch size must be multiple of 128
assert p_1 % TILE_M == 0
assert p_1 == t_1
assert p_1 == o_1
# num boxes box *4 must be multiple of 512
assert p_2 % TILE_N == 0
assert p_2 == t_2
assert p_2 // 4 == o_2
num_tiles_m = p_1 // TILE_M
num_tiles_n = p_2 // TILE_N
# Generate tensors for advanced indexing
i_p = nl.arange(TILE_M)[:, None]
i_f = nl.arange(TILE_N // 4)[None, :]
i_f_0 = (4 * i_f)
i_f_1 = (4 * i_f + 1)
i_f_2 = (4 * i_f + 2)
i_f_3 = (4 * i_f + 3)
# Use affine_range to loop over tiles
for m in nl.affine_range(num_tiles_m):
for n in nl.affine_range(num_tiles_n):
# Load input data from HBM
preds = nl.load(preds_ptr[m * TILE_M:(m + 1) * TILE_M,
n * TILE_N:(n + 1) * TILE_N])
targets = nl.load(targets_ptr[m * TILE_M:(m + 1) * TILE_M,
n * TILE_N:(n + 1) * TILE_N])
debug_print('preds', preds)
preds_left = preds[i_p, i_f_0]
preds_top = preds[i_p, i_f_1]
preds_right = preds[i_p, i_f_2]
preds_bottom = preds[i_p, i_f_3]
gt_left = targets[i_p, i_f_0]
gt_top = targets[i_p, i_f_1]
gt_right = targets[i_p, i_f_2]
gt_bottom = targets[i_p, i_f_3]
# Compute the area of each box
area1 = (preds_right - preds_left) * (preds_bottom - preds_top)
area2 = (gt_right - gt_left) * (gt_bottom - gt_top)
# Compute the intersection
left = nl.maximum(preds_left, gt_left)
top = nl.maximum(preds_top, gt_top)
right = nl.minimum(preds_right, gt_right)
bottom = nl.minimum(preds_bottom, gt_bottom)
inter_w = nl.maximum(right - left, 0)
inter_h = nl.maximum(bottom - top, 0)
inter_area = inter_w * inter_h
union_area = area1 + area2 - inter_area
iou_val = inter_area / nl.maximum(union_area, epsilon)
# Compute the smallest enclosing box
enclose_left = nl.minimum(preds_left, gt_left)
enclose_top = nl.minimum(preds_top, gt_top)
enclose_right = nl.maximum(preds_right, gt_right)
enclose_bottom = nl.maximum(preds_bottom, gt_bottom)
enclose_w = nl.maximum(enclose_right - enclose_left, 0)
enclose_h = nl.maximum(enclose_bottom - enclose_top, 0)
enclose_area = enclose_w * enclose_h
# Compute GIOU
delta_area = (enclose_area - union_area)
enclose_area = nl.maximum(enclose_area, epsilon)
giou = iou_val - delta_area / enclose_area
# Store results
nl.store(output_ptr[m * TILE_M:(m + 1) * TILE_M,
n * TILE_N_OUT:(n + 1) * TILE_N_OUT],
giou)
为了运行我们的 GIOU 内核,我们生成了两批随机框并将其输入到我们的函数中:
# generate random data in np
np.random.seed(0)
batch_size = 1024
n_boxes = 256
img_size = 256
boxes = []
for i in range(2):
# Randomly generate box sizes and positions
box_sizes = np.random.randint(1, img_size, size=(batch_size,n_boxes,2))
top_left = np.random.randint(0, img_size-1, size=(batch_size,n_boxes,2))
bottom_right = np.clip(top_left + box_sizes, 0, img_size - 1)
# Concatenate top-left and bottom-right coordinates
rand_boxes = np.concatenate((top_left, bottom_right), axis=2)
boxes.append(rand_boxes.astype(np.float32))
out = np.empty((batch_size, n_boxes), np.float32)
# convert tensors to PyTorch
t_boxes_0 = torch.tensor(boxes[0]).to(device)
t_boxes_1 = torch.tensor(boxes[1]).to(device)
t_out = torch.tensor(out).to(device)
if simulate:
# the simulation API requires numpy input
nki.simulate_kernel(giou_kernel,
boxes[0].reshape((batch_size, -1)),
boxes[1].reshape((batch_size, -1)),
out)
else:
giou_kernel(t_boxes_0.view((batch_size, -1)),
t_boxes_1.view((batch_size, -1)),
t_out)
为了评估我们 NKI 内核的性能,我们将其与以下在 PyTorch 中实现的 GIOU 朴素实现进行比较:
def torch_giou(boxes1, boxes2):
# loosely based on torchvision generalized_box_iou_loss code
epsilon = 1e-5
# Compute areas of both sets of boxes
area1 = (boxes1[...,2]-boxes1[...,0])*(boxes1[...,3]-boxes1[...,1])
area2 = (boxes2[...,2]-boxes2[...,0])*(boxes2[...,3]-boxes2[...,1])
# Corners of intersection
lt = torch.max(boxes1[..., :2], boxes2[..., :2])
rb = torch.min(boxes1[..., 2:], boxes2[..., 2:])
# Width and height of intersection
wh = (rb - lt).clamp(min=0)
# Area of the intersection
inter = wh[..., 0] * wh[..., 1]
# Union of the two boxes
union = area1 + area2 - inter
iou = inter / union.clamp(epsilon)
# Corners of enclosing box
lti = torch.min(boxes1[..., :2], boxes2[..., :2])
rbi = torch.max(boxes1[..., 2:], boxes2[..., 2:])
# Width and height of the enclosing box
whi = (rbi - lti).clamp(min=0)
# Area of the enclosing box
areai = (whi[..., 0] * whi[..., 1]).clamp(epsilon)
return iou - (areai - union) / areai
我们使用以下基准测试工具来比较我们两个函数的运行时性能:
import time
def benchmark(f, warmup_iters=20, ntrials: int = 100):
def run(*args, **kwargs):
# warmup
for _ in range(warmup_iters):
f(*args, **kwargs)
start_time = time.time()
for _ in range(ntrials):
f(*args, **kwargs)
end_time = time.time()
# Calculate average time per iteration
avg_time = (end_time - start_time) / ntrials
return avg_time
return run
avg_time = benchmark(torch_giou)(t_boxes_0, t_boxes_1)
print(f'torch_giou: {avg_time}')
avg_time = benchmark(giou_kernel)(t_boxes_0.view((batch_size, -1)),
t_boxes_1.view((batch_size, -1)),
t_out)
print(f'giou_kernel: {avg_time}')
运行环境
我们在一个Amazon EC2 inf2.xlarge实例上运行了脚本(该实例包含两个Neuron 核心和四个 vCPU)。我们使用了在撰写本文时最新版本的Deep Learning AMI for Neuron,即“Deep Learning AMI Neuron (Ubuntu 22.04) 20241027”,以及AWS Neuron 2.20.1和PyTorch 2.1。
结果
我们自定义的 GIOU 内核展示了平均运行时间为 0.211 毫秒,相比之下,原始实现为 0.293 毫秒,性能提升了 39%。请注意,这些结果仅适用于我们的示例。其他操作符,特别是包含矩阵乘法(并利用 Tensor 引擎)的操作,可能会展示不同的比较结果。
优化 NKI 内核性能
我们内核开发的下一步——超出本文范围——是使用专用的Neuron Profiler分析 GIOU 内核的性能,以识别瓶颈并优化实现。有关更多详情,请参阅NKI 性能指南。
Neuron 自定义 C++ 运算符
创建自定义 Neuron 内核的第二种方法是为GpSimd 引擎构建 C++ 运算符。这种方法在Neuron 自定义 C++ 运算符开发者指南中进行了描述,并在Neuron 自定义 C++ 运算符在 MLP 中的应用和Neuron 自定义 C++ 运算符性能优化教程中进行了演示。
Neuron 自定义 C++ 运算符为 GpSimd 引擎上的“内核融合”提供了机会,通过将多个低级操作组合成一个内核执行。这种方法可以显著减少与以下两项操作相关的开销:1)加载多个单独的内核,2)在不同的内存区域之间传输数据。
玩具示例——一个 GIOU C++ 内核
在下面的代码块中,我们实现了一个用于 Neuron 的 C++ GIOU 运算符,并将其保存到名为giou.cpp的文件中。我们的内核使用TCM 访问器来优化内存读写性能,并应用多核设置,以便使用 GpSimd 的八个内部处理器。
#include <stdint.h>
#include <stdlib.h>
#include <torch/torch.h>
#include <neuron/neuron-utils.hpp>
#include <algorithm>
// input boxes of shape 1024x256x4
// output scores of shape 1024x256
torch::Tensor giou(const torch::Tensor& t_pred,
const torch::Tensor& t_target) {
size_t num_samples = t_pred.sizes()[0];
size_t num_boxes = t_pred.sizes()[1];
torch::Tensor t_out = get_dst_tensor();
// get the number of GpSimd processors (8 in NeuronCoreV2)
uint32_t cpu_count = get_cpu_count();
// get index of current processor
uint32_t cpu_id = get_cpu_id();
// divide the batch size into 8 partitions
uint32_t partition = num_samples / cpu_count;
// use tcm buffers to load and write data
size_t tcm_in_size = num_boxes*4;
size_t tcm_out_size = num_boxes;
float *tcm_pred = (float*)torch::neuron::tcm_malloc(
sizeof(float)*tcm_in_size);
float *tcm_target = (float*)torch::neuron::tcm_malloc(
sizeof(float)*tcm_in_size);
float *tcm_output = (float*)torch::neuron::tcm_malloc(
sizeof(float)*tcm_in_size);
auto t_pred_tcm_acc = t_pred.tcm_accessor();
auto t_target_tcm_acc = t_target.tcm_accessor();
auto t_out_tcm_acc = t_out.tcm_accessor();
// iterate over each of the entries in the partition
for (size_t i = 0; i < partition; i++) {
// load the pred and target boxes into local memory
t_pred_tcm_acc.tensor_to_tcm<float>(tcm_pred,
partition*cpu_id + i*tcm_in_size,
tcm_in_size);
t_target_tcm_acc.tensor_to_tcm<float>(tcm_target,
partition*cpu_id + i*tcm_in_size,
tcm_in_size);
// iterate over each of the boxes in the entry
for (size_t j = 0; j < num_boxes; j++) {
const float epsilon = 1e-5;
const float* box1 = &tcm_pred[j * 4];
const float* box2 = &tcm_target[j * 4];
// Compute area of each box
float area1 = (box1[2] - box1[0]) * (box1[3] - box1[1]);
float area2 = (box2[2] - box2[0]) * (box2[3] - box2[1]);
// Compute the intersection
float left = std::max(box1[0], box2[0]);
float top = std::max(box1[1], box2[1]);
float right = std::min(box1[2], box2[2]);
float bottom = std::min(box1[3], box2[3]);
float inter_w = std::max(right - left, 0.f);
float inter_h = std::max(bottom - top, 0.f);
float inter_area = inter_w * inter_h;
// Compute the union area
float union_area = area1 + area2 - inter_area;
// IoU
float iou_val = inter_area / std::max(union_area, epsilon);
// Compute the smallest enclosing box
float enclose_left = std::min(box1[0], box2[0]);
float enclose_top = std::min(box1[1], box2[1]);
float enclose_right = std::max(box1[2], box2[2]);
float enclose_bottom = std::max(box1[3], box2[3]);
float enclose_w = std::max(enclose_right - enclose_left, 0.f);
float enclose_h = std::max(enclose_bottom - enclose_top, 0.f);
float enclose_area = std::max(enclose_w * enclose_h, epsilon);
float result = iou_val - (enclose_area-union_area)/enclose_area;
tcm_output[j] = result;
}
// write the giou scores of all boxes in the current entry
t_out_tcm_acc.tcm_to_tensor<float>(tcm_output,
partition*cpu_id + i*tcm_out_size,
tcm_out_size);
}
torch::neuron::tcm_free(tcm_pred);
torch::neuron::tcm_free(tcm_target);
return t_out;
}
我们需要一个单独的shape.cpp文件,用于定义 GIOU 函数的输出形状并将自定义运算符注册到 Neuron 库中:
#include <stdint.h>
#include <stdlib.h>
#include <torch/torch.h>
#include "torchneuron/register.h"
torch::Tensor giou_shape(torch::Tensor boxes1, torch::Tensor boxes2) {
torch::Tensor t_out = torch::zeros({boxes1.sizes()[0],
boxes1.sizes()[1]},
torch::kFloat);
return t_out;
}
NEURON_LIBRARY(my_ops, m) {
m.def("giou", &giou_shape, "giou");
}
build.py 脚本编译 C++ 运算符,并将其作为 Python API 暴露:
import os
import torch_neuronx
from torch_neuronx.xla_impl import custom_op
custom_op.load(
name='giou',
compute_srcs=['giou.cpp'],
shape_srcs=['shape.cpp'],
build_directory=os.getcwd(),
multicore=True,
verbose=True
)
编译脚本生成一个libgiou.so库,其中包含我们 C++ GIOU 运算符的实现。在下面的代码块中,我们加载库并使用上面定义的基准工具来衡量自定义内核的性能:
from torch_neuronx.xla_impl import custom_op
custom_op.load_library('libgiou.so')
avg_time = benchmark(torch.ops.my_ops.giou)(t_boxes_0, t_boxes_1)
print(f'C++ giou: {avg_time}')
运行时环境
我们使用与 NKI 实验相同的 Neuron 环境来编译和测试我们的 C++内核。请注意安装步骤,这些步骤是开发自定义 C++运算符所必需的。
结果
我们的 C++ GIOU 内核展示了平均运行时间为 0.061 毫秒——几乎比我们的基线实现快五倍。这大概是“内核融合”所带来的结果,如上所述。
结论
下表总结了我们实验的运行时结果。

不同 GIOU 实现的平均时间(越低越好)——作者
请记住,这些结果仅适用于本研究中使用的玩具示例和运行环境。其他内核的比较结果可能会大不相同——这取决于它们能够在多大程度上利用 Neuron 核心的内部计算引擎。
下表总结了我们观察到的 AWS Neuron 内核定制的两种方法之间的一些差异。

内核定制工具的比较(作者)
通过其高级 Python 接口,NKI API 以一种易于访问和用户友好的方式向机器学习开发者暴露了 Neuron 加速引擎的强大功能。低级 C++自定义运算符库则提供了更高的可编程性,但仅限于GpSimd 引擎。通过有效结合这两种工具,开发者可以充分利用 AWS Neuron 架构的能力。
摘要
随着人工智能革命的全面推进,许多公司正在开发先进的 AI 芯片,以满足日益增长的计算需求。尽管公开公告通常强调这些芯片的运行时性能、成本节约和能效,若要使这些芯片及其软件栈在机器学习开发中真正可行,还需要一些核心能力。这些能力包括强大的调试工具、性能分析与优化工具、可编程性等。
在这篇文章中,我们聚焦于用于编程 AWS 自研 AI 加速器的工具,Trainium和Inferentia,并展示了它们在构建自定义机器学习操作中的应用。这些工具使开发者能够优化 AWS AI 芯片上机器学习模型的性能,并开辟了创新和创意的新机遇。
关于四舍五入或分箱数据的统计分析
谢普德修正提供了近似值,但误差依然存在。分析界限为这些误差的大小提供了洞察。
·发表于 Towards Data Science ·阅读时间 9 分钟·2024 年 1 月 4 日
--

图片由 charlesdeluvio 提供,来自 Unsplash
想象一下有一份以英寸为单位、精确到英寸的长度测量数据。这份数据可能代表着某个医学研究中参与者的身高,形成了一个来自感兴趣群体的样本。我们的目标是估计这个群体的平均身高。
假设有一个算术平均值为 70.08 英寸。关键的问题是:这个数据有多准确?尽管样本量很大,实际上每个测量值的精度仅限于英寸。因此,即使有大量数据,我们也只能谨慎地假设真实的平均身高在 69.5 英寸到 70.5 英寸之间,并将其四舍五入到 70 英寸。
这不仅仅是一个可以轻易忽视的理论问题。例如,假设我们要计算以公制单位表示的平均身高。1 英寸等于正好 2.54 厘米,因此我们可以轻松地将测量值从英寸转换为更精细的厘米刻度,然后计算平均值。然而,考虑到英寸级别的精度,我们只能自信地断言,平均身高在 177 厘米到 179 厘米之间。问题是:我们能否自信地得出结论,平均身高是精确的178 厘米?
舍入误差或量化误差可能会带来巨大的后果——比如改变选举结果,或改变弹道导弹的航向,导致意外死亡和伤害。舍入误差如何影响统计分析是一个复杂的问题,本文将对此进行阐明。
Sheppard 修正
假设我们观察到由连续随机变量X生成的值,这些值已被舍入或分箱. 这些观察值遵循离散随机变量Y的分布,定义如下:

其中h是箱宽,⌊ ⋅ ⌋表示下取整函数。例如,X可以生成长度测量值。由于舍入不是可逆操作,仅从舍入后的值重建原始数据是不可能的。
以下近似关系涉及这些分布的均值和方差,称为Sheppard 修正 [Sheppard 1897]:

例如,如果我们给定的是四舍五入到英寸的测量值,h = 2.54 cm,并观察到标准差为 10.0 cm,则 Sheppard 的第二矩修正要求我们假设原始数据的标准差实际上是σ = 9.97 cm。对于许多实际应用,修正值非常小。即使标准差与箱宽的量级相似,修正也仅占原值的 5%。
如果满足以下条件,则可以应用 Sheppard 修正[Kendall 1938, Heitjan 1989]:
-
X的概率密度函数应足够平滑,并且其导数在尾部趋向于零,
-
箱宽h不应过大(h < 1.6 σ),
-
样本大小N不应过小也不应过大(5 < N < 100)。
前两个要求呈现为典型的“无免费午餐”情况:为了检查这些条件是否成立,我们首先必须知道真实的分布。特别是第一个条件是局部条件,因为它涉及密度的导数,而仅凭舍入或分箱数据我们无法稳健地估计这些导数。
样本大小不宜过于大的要求并不意味着随着样本大小增大,四舍五入误差的传播(绝对值)变得更加难以控制。相反,这个要求是针对当试图将四舍五入/分箱引入的偏差与较大样本中标准误差的减小进行比较时,Sheppard 修正可能不再充分的情况。
对均值估计中四舍五入误差的总变差界限
Sheppard 修正只是近似值。例如,通常情况下,估计均值的偏差,E[Y] - E[X],实际上是非零的。我们想要计算该偏差绝对值的一些上界。最简单的界限是期望值单调性的结果,以及四舍五入/分箱最多可以将值改变h / 2 的事实:

如果没有关于X分布的额外信息,我们无法改进这个界限:假设X的概率质量高度集中在一个分箱的中点之上,那么所有由X生成的值都将被+ h / 2 平移,最终得到一个Y值,从而实现上界。
然而,可以根据[定理 2.3 (i), Janson 2006]给出以下精确公式:

这里,φ( ⋅ )表示X的特征函数,即未知概率密度函数p( ⋅ )的傅里叶变换。这个公式意味着以下界限:

我们可以为我们最喜爱的分布计算这个界限,例如支持在区间[a, b]上的均匀分布:

在这里,我们使用了平方倒数和的著名值。例如,如果我们从区间b - a = 10 cm 的均匀分布中抽样,并计算已四舍五入到精度h = 2.54 cm 的数据的均值,那么估算均值的偏差最多为 1.1 毫米。
通过与[Ushakov & Ushakov 2022]中进行的计算非常相似的计算,我们还可以给出当从方差为σ²的正态分布中抽样时的四舍五入误差界限:

指数项随着分箱宽度减小而快速衰减。例如,给定标准差σ = 10 cm 和分箱宽度h = 2.54 cm,估计均值的四舍五入误差约为 10^(-133),即对于任何实际应用来说,它是可以忽略不计的。
应用[Ushakov 1999]的定理 2.5.3,我们可以给出一个更一般的界限,用概率密度函数p( ⋅ )的总变差V(p)来表示,而不是其特征函数:

其中

该计算类似于[Ushakov & Ushakov 2018]中提供的计算。例如,具有区间[a, b]的均匀分布的总变差为 2 / (b - a),因此上述公式通过特征函数的模提供与之前计算相同的界限。
总变差界限使我们能够提供一个实际使用的公式,用于估算基于直方图(具有箱宽h)的四舍五入误差上界:

这里,n_k是落入第k个箱的观察值数量。
作为一个数值示例,我们分析了由美国疾病控制与预防中心(CDC 2022)调查的N = 412,659 人的身高数据,单位为英寸。该数据的平均身高以公制单位表示为 170.33 厘米。由于样本量较大,标准误差σ / √N非常小,仅为 0.02 厘米。然而,由于四舍五入的误差可能更大,总变差的界限可以估算为 0.05 厘米。在这种情况下,统计误差可以忽略不计,因为身高差异在不到一厘米的范围内通常不会有实际意义。然而,对于其他需要高精度估算平均值的测量数据,当数据受到量化处理时,仅计算标准误差可能不足够。
基于 Fisher 信息的界限
如果概率密度函数p( ⋅ )是连续可微的,我们可以将其总变差V(p)表示为对导数模的积分。应用Hölder 不等式,我们可以通过 Fisher 信息I(p)的平方根对总变差进行界定:

因此,我们可以写下一个额外的上界,用于计算四舍五入或分箱数据的均值时的偏差:

这个新的界限具有(理论上的)兴趣,因为 Fisher 信息是密度函数的一个特征,通常比其总变差更常用。
更多的界限可以通过已知的 Fisher 信息上界找到,许多可以在[Bobkov 2022]中找到,其中包括以下涉及概率密度函数三阶导数的内容:

值得注意的是,Fisher 信息在某些量子力学的表述中也具有重要意义,其中它作为哈密顿量的一部分,负责引发量子效应[Curcuraci & Ramezani 2019]。人们可能会思考,是否存在一个具体而有意义的联系,将量子化的物理物质与受到“常规”量子化的经典测量联系起来。然而,值得注意的是,这种推测可能源于数学上的错觉。
结论
Sheppard 修正是一种可以用来修正基于舍入或分箱数据计算均值、方差和其他(中心)矩时的误差的近似方法。
尽管 Sheppard 修正对于均值的影响为零,但实际误差可能与标准误差相当,甚至超过,尤其是在样本量较大的情况下。我们可以通过考虑概率密度函数的总变差来约束基于舍入或分箱数据计算均值时的误差,这一数量可以从分箱数据中估算出来。
在估计均值时,舍入误差的额外约束可以通过 Fisher 信息以及未知分布的概率密度函数的高阶导数来表达。
参考文献
[Sheppard 1897] Sheppard, W.F. (1897). “关于根据等距分割量表排列的数据,计算最可能的频率常数值。”伦敦数学会会刊 s1–29: 353–380。
[Kendall 1938] Kendall, M. G. (1938). “Sheppard 修正有效的条件。”皇家统计学会会刊 101(3): 592–605。
[Heitjan 1989] Daniel F. Heitjan (1989). “从分组的连续数据中推断:综述。”Statist. Sci. 4 (2): 164–179。
[Janson 2006] Janson, Svante (2005). “连续随机变量的舍入和振荡渐近行为。”概率年鉴 34 (5): 1807–1826。
[Ushakov & Ushakov 2022] Ushakov, N. G., & Ushakov, V. G. (2022). “样本量较大时舍入对假设检验的影响。”Stat 11(1): e478。
[Ushakov 1999] Ushakov, N. G. (1999). “特征函数的选定主题。”De Gruyter。
[Ushakov & Ushakov 2018] Ushakov, N. G., Ushakov, V. G. 统计分析舍入数据:测量误差与舍入误差。J Math Sci 234 (2018): 770–773。
[CDC 2022] 美国疾病控制与预防中心(CDC)。行为风险因素监测系统调查数据 2022 年。乔治亚州亚特兰大:美国卫生与公共服务部,疾病控制与预防中心。
[Bobkov 2022] Bobkov, Sergey G. (2022). “Fisher 信息的上界。”《电子概率杂志》27: 1–44。
[Curcuraci & Ramezani 2019] L. Curcuraci, M. Ramezani (2019). “量子势能和波函数温度的热力学推导。”《物理学 A:统计力学及其应用》530: 121570。
告别令人困惑的 Python 错误信息

灵活的错误信息增强库 — PrettyError
·发表于 Towards Data Science ·阅读时长 8 分钟 ·2024 年 3 月 3 日
--
编程是一项活动,我们可能会用 20% 的时间将想法写成代码,剩下的 80% 时间则用来清理错误和修复漏洞。错误信息绝对是我们每天都会看到的东西。然而,你是否曾经在 Python 错误信息上遇到过困难呢?
例如,错误信息可能非常冗长,这虽然不坏,但很难区分不同的部分,并快速找到我们需要的信息。堆栈跟踪有时也太过庞大和复杂,难以理解。除非我们重写 Exception 类,否则也不容易定制错误信息,而这可能再次令人不知所措。
在本文中,我将介绍一个名为 PrettyError 的库,它可以帮助我们解决上述所有痛点,甚至更多。它有许多酷炫的功能,可以简化我们的调试过程,并帮助我们在编程工作中节省大量时间。
1. 安装与快速入门

使你成为更优秀数据科学家的一个心态转变
事实上,任何优秀的员工都应该采纳这种心态
·发表于Towards Data Science ·6 分钟阅读·2024 年 4 月 15 日
--

图片来自Katerina May的作品,发布于Unsplash
从量化金融转行后,我作为数据科学家的第一次经历是在咨询行业。在我刚开始在麦肯锡的工作时,收到的大部分反馈并不是关于我的代码或技术能力,而是类似于“你需要将工作与公司/组织的更高层次优先事项联系起来”,“你应该提供更有力的洞察”或“你需要成为一个思想合伙人”之类的建议。
那时,作为少数几位数据科学家之一,我身处一片咨询通才的海洋,这些反馈最初对我来说感觉像模糊的咨询术语,我更希望有人批评我的代码。现在,作为一名经理,回头看,我意识到这些看似不相关的点其实都与一件事有关——那就是我在当时作为初级个体贡献者时,忽视了的心态,我当时专注于提升技术能力。
我曾经专注于执行任务,而不是像问题的所有者一样思考;然而,回头看,认为我的工作只是把任务做得好是一个错误。从那时起的这些年,我渐渐相信,拥有一种所有者心态是区分高绩效者与其他同事的关键因素之一。
知识图谱中的本体推理
KGs Insights
一本 Python 实操指南,帮助理解通过遵循逻辑过程生成新知识的原理
·发布于数据科学前沿 ·阅读时间 9 分钟·2024 年 11 月 15 日
--

图 1 — 一种端到端的过程,展示了如何通过本体推理将初始语句推导出推论
介绍
推理能力是 AI 系统中广泛讨论的话题。这些能力通常与大型语言模型(LLMs)相关,后者在从大量数据中提取模式方面特别有效。
在这个学习过程中捕获的知识使得大型语言模型(LLMs)能够执行各种语言任务,例如问答和文本摘要,展示了类似于人类推理的技能。
仅仅说“LLMs 无法推理”并没有帮助,因为显然它们可以完成一些人类会用推理来做的事情。 — *Jeremy Howard |
Fast.AI 联合创始人 — 斯坦福大学数字学者*
尽管 LLMs 能够识别和匹配数据中的模式,但在需要结构化和正式推理的任务中,LLMs 仍然存在局限性,尤其是在需要严谨逻辑过程的领域。
这些局限性突出了模式识别与正确逻辑推理之间的区别,而人类并不总是能察觉到这种区别。
无监督的 LLM 评估
大型语言模型输出评估实践指南
·发布于 Towards Data Science ·12 分钟阅读·2024 年 11 月 2 日
--
评估 AI 生成的输出对于构建强大的大型语言模型应用至关重要,因为它可以将复杂的 AI 应用拆分为带有内建错误控制的简单阶段。
在监督模式下评估生成输出相对简单,其中“正确答案”可以通过计算或由人工评估者提示。
与此同时,在许多实际的 LLM 应用中,监督方法过于限制性,亟需能够应对开放式问题的评估方法。构建无监督评估器的最简单方法是让 LLM 自我评估。然而,生成模型在检测自己输出中的错误的能力尚未得到充分理解。
我们证明了自我评估的质量可以通过迭代自我反思来提高。类似于“思维链”技术,这种方法在推理时以计算量换取最终结果的稳健性。
示例的 Google Colab 笔记本链接:
colab.research.google.com/drive/1q_dChQBMbnUXZ377JVwYsjvn7lZ_7qlZ?usp=sharing

图片来源:Flux 1. 专业版提示“机器人评估其他机器人”
引言
在构建使用大型语言模型的处理管道时,常被提到的问题是生成输出的质量。如果有良好的评估过程,它可以突出表现不佳的情况,并触发 LLM 微调、提示调整、人工代理的介入——或者同时进行这些操作。
这里是一个典型的使用评估进行训练的工作流程:LLM 遍历输入数据集,评估器检测到的任何输出差异都用来生成合成数据以微调模型。只有当目标质量指标达到时,应用程序才会部署。

作者提供的图像:LLM 微调的评估循环
在生产环境中使用 LLM 评估器非常类似——只是检测到的差异通常会被发送给人工代理,以确保即使触发了错误标志,工作流仍然可以继续进行。
然而,构建一个好的 LLM 评估器并非易事。这个问题的复杂性来源于两个实际的限制:
首先,在评估中尽量减少人工参与是非常理想的。例如,想象一个聊天机器人与用户互动,并错过了一个常见的省略语模式(用一个词代替完整的输出句子):
机器人: 这是正确的吗?
用户: 正确
机器人: 对不起,我没有明白。请再试一次。
用户: 是的,这是正确的
给定这个对话部分,人工应该能够轻松地指出聊天机器人回答中的不足,并建议进行微调。然而,为了发现这个问题,评估者必须阅读整个对话(可能非常长)。这种方法在大规模应用中行不通——这意味着我们应该努力实现无人工评估。
其次,在不知道“真实答案”的情况下判断 LLM 输出的过程,其复杂度与原始任务相当。这意味着,最先进的 LLM 最多只能使用一个具有类似能力的评估器(很可能是它自己),从而引发关于评估有效性的问题。
监督评估
如果我们看一下今天广泛研究的 LLM 评估方法,我们会发现它们大多数集中在监督或半监督的应用场景中。
如果训练数据集附带“真实答案”,那么评估变得非常简单——甚至可以驱动像 DSPy 这样的优化框架。当测试企业 LLM 应用程序时,如果与历史案例进行对比(这些案例由人工代理处理),其中“真实答案”相当于那些代理的判断。
另一个检验输出与“真实答案”对比的机会出现在 LLM 输出可以被正式验证时——例如可以编译和测试的计算机代码。尽管计算机程序可以用多种不同的方式编写,但正确的代码应通过测试,无论选择哪种实现路径。
生成输出无法正式验证的情况通常需要将人工引入环路。例如,RLHF 可以根据人类偏好的序列对 LLM 输出进行评分,从而引导网络走向复杂且微妙的策略。
无监督自我评估
与此同时,有许多开放性评估案例,无法实现“地面真实”方法,而 RLHF 太长或太昂贵。这解释了人们对无监督自我评估技术的兴趣。
假设我们有一个开放性的大型语言模型评估问题,通常需要人工参与——比如“这个聊天机器人如何改进”——我们能做些什么来实现自动化呢?
如果我们假设当代的大型语言模型具备丰富的语义表示并且本身具有自我评估能力,那么就可以建立一个经济的评估框架。这意味着你可以直接要求模型评估其输出,或者使用另一个大型语言模型执行相同的任务,以避免它们的训练集之间的交叉污染。
不幸的是,LLM 自我判断的简单实现可能会失败。 其中有两个可能的原因。
第一个原因 是最显而易见的:大型语言模型不能保证在其训练集中没有反映的主题上生成连贯的输出,或者在无法通过检索工具(如网页搜索)获取的信息上生成连贯的输出。例如,考虑这个 Google 地图片段,它展示了加利福尼亚州一个名为 Mineral 的未合并小镇和一条小溪 (Martin Creek),这条小溪流经该小镇。

Mineral, CA 附近(来源:Google 地图)
Mineral 小镇没有自己的维基百科页面,而且大多数大型语言模型在训练中不使用 GIS 数据,所以当被问到哪条小溪流经该小镇时,LLM 很容易产生幻觉:

GPT-4o 幻觉示例
Mill Creek 是一条真实的水体,绕过了 Mineral 所在的(更大的)普查指定地点。然而,Mill Creek 并没有穿过 Mineral 小镇本身,因此 GPT-4o 的回答是错误的。
直观地感觉,普通的 LLM 推理失败是因为主题超出了其知识集,并且无法通过常见工具(如网页搜索)获取,且这种情况不容易自我检测。
第二个 使自我评估失败的原因稍微微妙一些,且由大型语言模型输出的概率特性驱动。它在评估需要多次观察的情况下很容易展示。例如,考虑以下用户与聊天机器人的对话:
Bot: 我是一个聊天机器人,帮助你找到最适合的手机套餐。
User: 我需要至少 500GB 的数据,且价格低于 $50
Bot: “Ultimate Freedom” 套餐提供无限制的互联网访问,价格为 $79.99
User: 这或许真的有效。
Bot: 你想更改当前的套餐吗?
User: “Ultimate Freedom” 仅限国内使用吗?我能开通国际漫游吗?
Bot: 对不起,我没听清楚。请再试一次。
User: 我现在已经好了,请不要做任何更改。
Bot: 我是一个聊天机器人,帮助你找到最适合的手机套餐。
如果我们反复要求 LLM 列出基于这段对话可以对机器人进行的改进,我们会注意到输出列表在不同运行之间并不稳定。一个版本可能是这样的:
{["Ensure the bot can handle follow-up questions after suggesting a plan.",
"Improve bot's recognition of farewell messages.",
"Make sure the bot can assist with additional features or package inquiries."]}
另一种可能是这样的:
{["Ensure the bot can switch contexts and answer follow-up questions.",
"Enable the bot to recognize when a conversation ends.",
"Improve the bot's ability to provide options for additional services",
"Support input about cost limitations, e.g. users stating a desired price."]}
尽管这两个答案有很大重叠,但显然在任何一种情况下都没有生成类似于人类专家可能给出的那种详尽的建议列表。
自我反思的意外力量
一旦我们概述了评估的典型失败模式,使用 LLM 来评判自己似乎是个坏主意。毕竟,这听起来像是在要求一个勤奋的学生重新检查自己的答案。因为一个好学生不会犯很多拼写错误,重新检查仅仅反映现有的知识,不应导致改进。
然而,这正是我们对 LLM 的直觉可能完全错误的地方。
事实上,大多数 LLM 都能进行纠正性自我评估,即使主题超出了它们的知识库。
为了说明这一现象,我们回到 GPT-4o 的例子,其中幻觉涉及穿过矿物镇的水体,加州。有趣的是,这种特定的幻觉可以在自我评估过程中被消除:

GPT-4o 中的自我评估能够逆转幻觉
那么,魔力在哪里呢?
在这个例子中,LLM 没有知识或工具来得到正确的答案,所以它幻觉出“最有可能”的完成。然而,当被要求自我评估时,它得出结论:它可以访问的事实与之前的陈述并不一致。即使 GPT-4o 不知道正确答案,它也能否定错误的答案。
一个更复杂的模型(比如 GPT-4o1)可能稍微难以以相同的方式处理,因为它倾向于生成更细致的回应:

GPT-4o1 中的幻觉更加细致。
与其在自己无法验证的主题上产生幻觉,GPT-4o1 可能会选择回答一个它从未被问到的问题——比如“矿物镇附近流经的主要水体是哪一条?”。这种回避意味着,像“判断对错”这样的直接自我评估提示可能会失败。
然而,以一种更加深思熟虑的方式要求自我评估,仍然可能成功,即使这需要多次迭代:

LLMs 以迭代方式自我反思的能力,当然是众所周知的,并且在代码生成等应用中已经有所应用,我们这里只是将相同的技巧扩展到自我评估。
记忆化的“预期”能力
迭代反思的相同思路也适用于那些倾向于生成不完整输出的 LLM 任务。如果我们回顾机器人对话示例,并允许 LLM 在记忆化改进列表上进行迭代,我们将发现模型通常不会对第一次的结果“满意”。
换句话说,如果我们制定一个这样的提示:
iterative_prompt = """
Consider the following dialog between the user and the chatbot.
The bot's goal is to suggest a cheaper mobile plan based on the information the user provides.
The user's responses are not guaranteed to be consistent or coherent at all times.
This dialog was evaluated by an LLM and this evaluation is provided below.
You job is to assess the quality of evaluation and respond with "success"=True and repeat the original action list if there is nothing significant to add.
If there is something missing in evaluation, respond with "success"=False and a new list of action items to create better user experience integrating the old list with new suggestions. Make sure the list items are unique and not repetitive.
"""
然后通常需要对改进列表进行 2 到 4 轮的检查,直到 LLM 得出推荐并宣布评估任务成功:
🍩
success='False' action_items=['Enable bot to understand user inquiries about add-on packages related to international calls.', "Improve bot's understanding to handle informal or casual goodbyes such as 'byebye'."]
🍩
success='False' action_items=['Enable bot to understand user inquiries about add-on packages related to international calls.', "Improve bot's understanding to handle informal or casual goodbyes such as 'byebye'.", "Enhance the bot's capability to suggest plans that are closer to the user's budget, such as recommending plans around $10 instead of $14 when the user specifies a $10 budget."]
🍩
success='False' action_items=['Enable bot to understand user inquiries about add-on packages related to international calls.', "Improve bot's understanding to handle informal or casual goodbyes such as 'byebye'.", "Enhance the bot's capability to suggest plans that are closer to the user's budget, such as recommending plans around $10 instead of $14 when the user specifies a $10 budget.", 'Ensure the bot confirms if the user is interested in plans without inclusive international minutes given their travel habits.', 'Add functionality for the bot to suggest alternative communication methods like VoIP for international calls if budget constraints are strict.', "Improve the bot's ability to suggest plans that balance cost with user requirements, such as considering travel habits and required features."]
🍩
success='True' action_items=['Enable bot to understand user inquiries about add-on packages related to international calls.', "Improve bot's understanding to handle informal or casual goodbyes such as 'byebye'.", "Enhance the bot's capability to suggest plans that are closer to the user's budget, such as recommending plans around $10 instead of $14 when the user specifies a $10 budget.", 'Ensure the bot confirms if the user is interested in plans without inclusive international minutes given their travel habits.', 'Add functionality for the bot to suggest alternative communication methods like VoIP for international calls if budget constraints are strict.', "Improve the bot's ability to suggest plans that balance cost with user requirements, such as considering travel habits and required features."]
在这一初步的“一轮热身”对话后,我们可以给模型提供更多的示例对话,看看会发生什么。
与人类评估者的做法类似,GPT-4o 模型认为许多对话样本不值得生成新的推荐(一次模型运行就足够了)——然而,有些可能会引发更长时间的深思:

来自ExpBot 数据集的前 50 个对话中的 LLM 调用次数,直到收敛(图表由作者提供)
最终的结果将是一个相当详尽的关于改进聊天机器人的推荐列表:
Final recommendations:
["Improve the bot's ability to avoid repetitive greetings and restarts when the user's input is vague or repeated, creating a more fluid conversation flow.",
"Enhance the bot's active listening skills to acknowledge user needs and concerns before suggesting starting over, to better handle user dissatisfaction.",
"Include a function allowing users to ask follow-up questions for more details about the suggested plan, such as data overage charges and roaming fees.",
"Develop a mechanism for the bot to detect and correct minor typographical errors and currency symbol mismatches in user inputs.",
"Provide alternative suggestions that might not fit all criteria but offer significant savings or benefits in other areas based on the provided user data.",
"Implement a feedback system enabling users to rate the accuracy or helpfulness of the plan suggestion provided, allowing for iterative improvements.",
"Incorporate a bot training mechanism to ensure it can handle responses that are non-standard in format or include extraneous details not directly related to the plan.",
"Add the ability for the bot to suggest seeking human assistance when complex queries or dissatisfaction arise that the bot cannot resolve.",
"Enhance the bot's language processing capabilities to accurately interpret various phrasings and informal expressions from the user.",
"Increase the bot's capability for dynamic clarification requests, creating a smoother interaction flow.",
"Refine the bot's ability to verify user information effectively to reduce misunderstandings and user frustration.",
"Improve the bot's handling of unrealistic and inconsistent user inputs to guide the conversation back to relevant queries.",
"Integrate a process for flagging nonsensical data entries and guide the user toward providing accurate information.",
"Provide clearer explanations or breakdowns of the suggested plan's features, especially if different from the user's mentioned requirements.",
"Improve response to questions unrelated to starting new calculations to avoid redundant loops."]
关于此示例的一些技术说明:
-
为了简化,我们将评估和生成合并为一个提示,依赖于OpenAI 的结构化输出来生成期望的结果。
-
记忆化的固有限制在于需要按顺序处理样本。这在处理大型数据集时可能需要一些时间,并且还阻止我们通过批处理调用使用低成本推理。
为了进一步提高性能,我们可以利用这样一个事实:数据集中的大多数样本并不会生成新的见解。这意味着我们可以通过按顺序迭代一个小子集的样本来生成初步的推荐列表,并通过DataChain 库(或通过OpenAI API批量处理)并行服务剩余的数据集,以标记“有趣”的案例,从而节省 30%到 50%的时间(或费用),具体取决于您的偏好。
结论
LLM 可以且应该用于无监督评估(包括自我评估)。关键是它需要一个经过深思熟虑的方法——这通常会发展成一种迭代的方式来改进和完善判断。
这是 Google Colab 中示例实现的链接:
colab.research.google.com/drive/1q_dChQBMbnUXZ377JVwYsjvn7lZ_7qlZ?usp=sharing
开源数据可观察性与 Elementary — 从零到英雄(第一部分)
一份我在刚开始时希望能拥有的循序渐进的实操指南
·发表于 Towards Data Science ·阅读时长 6 分钟·2024 年 9 月 10 日
--
数据可观察性及其重要性常常被讨论和撰写成现代数据和分析工程的重要方面。市场上有许多具有各种功能和价格的工具可供选择。在这篇两部分的文章中,我们将重点介绍开源版本的 Elementary,这个平台是众多数据可观察性工具之一,专为 dbt 量身定制,并与之无缝对接。我们将从零开始设置,并计划在第二部分结束时理解其工作原理以及在不同数据场景下的可行性。在开始之前,我还想声明,我与 Elementary 没有任何关联,所有观点均为我个人的看法。
在第一部分中,我们将设置 Elementary 并检查如何读取 Elementary 的每日报告。如果你已经对这部分感到熟悉,并且有兴趣查看不同类型的数据测试以及哪种测试最适合哪种场景,你可以直接跳到第二部分:
我已经使用 Elementary 有一段时间了,我作为数据工程师的经验是积极的,我的团队对结果的看法也很正面。我们的团队使用 Elementary 进行自动化的每日监控,配合自托管的 Elementary 仪表盘。Elementary 还拥有一个非常方便的云平台……
开源模型、温度缩放、重排序等:不要错过我们近期的 LLM 必读文章
·发表于 Towards Data Science ·作为 新闻简报 发送 ·阅读时间 3 分钟 ·2024 年 5 月 16 日
--
想要写下你的第一篇 TDS 文章吗?我们始终欢迎新作者的贡献。
新的 LLM(大语言模型)几乎每天都在不断涌现,而它们带来的工具和工作流程则以更快的速度繁荣发展。我们认为现在是时候回顾一下近期关于这一不断变化领域的讨论了,而我们也想不到比精选过去几周内一些最强文章的方式更好的方式来做到这一点。
我们整理的这些文章既涉及高层次的议题,也探讨了细致的问题,因此无论你是对 AI 伦理、开源技术的演变,还是创新的 RAG 方法感兴趣,我们确信你会在这里找到能够激发你兴趣的内容。让我们深入探讨。
-
变化的潮流:开源 LLM 相较于封闭源 LLM 的竞争优势最初的生成式 AI 工具由像 OpenAI 发布的专有模型引领。 Leonie Monigatti 的新文章聚焦于一个新兴趋势:小型开源基础模型的崛起及其日益主导地位,这些模型因数据安全、可定制性和成本等因素而受到关注。
-
聊天机器人道德问题? 我们知道,当被要求提供事实信息时,LLM 可能会产生幻觉;那么,当用户开始向它们询问以伦理为重点的建议时会发生什么呢?Eyal Aharoni和Eddy Nahmias展示了他们在这一棘手问题上的最新研究,并探讨了聊天机器人“能够在特定、受控情况下模仿或合成人的道德话语”这一现象所固有的危险。
-
LLM 的推荐是否可以被操控以提升产品的可见性? 电子商务是一个已经容易受到操控和可疑商业行为影响的领域。正如Parul Pandey在她对一篇近期论文的分析中所展示的那样,LLM 凭借其快速、大规模生成文本和其他媒体的能力,已经准备好利用这个生态系统中的各种漏洞和盲点。

图片由Thomas Kelley提供,来源:Unsplash
-
LLM 中的温度缩放与束搜索文本生成,面向机器学习相关领域 在一篇内容详尽、举例丰富的指南中,Mike Cvet深入解析了在生成式 AI 工作流中“温度”这一概念:它是一个修改模型输出序列可预测性的参数,掌握其细微差别有助于从业者更有效地使用 AI 工具。
-
如何通过重新排序改进 LLM RAG 检索 在检索增强生成(RAG)初步兴奋过后,许多从业者很快意识到,RAG 系统通常可以从更先进的精炼方法中受益。Dr. Leon Eversberg的最新教程带我们了解了一种工作流,它利用两步检索(使用开源的双编码器和交叉编码器)以获得更好的结果。
正如他们一贯所做的那样,我们的作者们在近期几周涉及了许多其他话题,创作了一些高质量的文章;以下是一个代表性样本:
-
在她精彩的客户生命周期价值系列文章的完结篇中,Katherine Munro提供了可用预测方法的详细概述,以及市场营销人员和数据科学家可以从每种方法中期待的结果。
-
每一篇Sachin Date的深度剖析都值得庆祝,而最新的这篇也不例外:它是对统计收敛的细致探索,通过一个 19 世纪船难的故事讲述。
-
在她最新的初学者友好指南中,Srijanie Dey 博士转向了 Llama 3,并深入剖析了其变换器架构的细微差别。
-
Murto Hilali在分子生物学、计算生物学和人工智能的交叉领域写作,展示了他是如何构建一个多分类器模型来预测突变对蛋白质相互作用的影响的。
-
如果你正在考虑从物理学(及相关领域)转向数据科学的职业转型,不要错过Sara Nóbrega的实用指南,基于她个人的历程和一路上积累的经验。
-
对于任何刚刚踏入深度学习领域的人,Shreya Rao带着她的新作回来了,这是一本面向初学者的、巧妙插图的卷积神经网络入门书。
-
揭示 Kolmogorov-Arnold 网络(KANs)的论文才刚刚两周,但已经在该领域掀起了巨大波澜。Theo Wolf的首篇 TDS 文章帮助我们理解 KANs 是如何工作的,以及它为何如此引起关注。
感谢你支持我们作者的工作!我们非常喜欢发布新作者的文章,因此,如果你最近写了一个有趣的项目演示、教程,或对我们核心主题的理论反思,别犹豫,与我们分享。
直到下一个变量,
TDS 团队
打开人工大脑:用于 LLM 检查的稀疏自编码器
|LLM|可解释性|稀疏自编码器|可解释人工智能|
深入探讨使用稀疏自编码器进行 LLM 可视化和解释
·发表于Towards Data Science ·13 分钟阅读·2024 年 11 月 16 日
--

图片由作者使用 DALL-E 创建
一切事物都受到解释的支配,任何解释在特定时间的主导地位是权力的体现,而非真理。——弗里德里希·尼采
随着人工智能系统规模的增长,理解其机制变得愈加困难且迫切。今天,关于模型的推理能力、潜在的偏见、幻觉、以及大语言模型(LLM)的其他风险和局限性都在讨论之中。
探索人工智能的极限:为什么掌握模式不等于真正的推理
towardsdatascience.com
OpenAI 嵌入技术与聚类分析在调查分析中的应用 — 操作指南
如何从调查数据中获取洞察,并使用嵌入技术和大语言模型提取话题
·发布于 Towards Data Science ·7 分钟阅读·2024 年 10 月 25 日
--

图片来源:Olav Ahrens Røtne 在 Unsplash
离我换工作已经整整 4 个月了,这段时间我已经安顿下来并继续我的小型副项目。
最新的一个工具是一个调查分析工具,是我公司的一位产品负责人请求的。他必须每季度查看成千上万的公司范围的调查回应,试图提取出可执行的业务改进措施
现在这个工具正在被使用,并且(希望)为产品负责人和分析师节省了大量时间,我编写了这篇操作指南,帮助你创建类似的工具。
项目流程
我们将调查问卷的回答作为输入数据框(使用 pandas)。关键列是每个用户留下的评论。其他字段如部门、职位和提交日期也可以进行分析,但为了最小可行产品,我决定保持变量简单且少。
本研究的数据是通过在线调查收集的,并且将保持机密。为了本文的目的,实际数据未显示;图片展示了一个较小样本的分析过程……
OpenAI o1:这是将重塑我们所知道的每个知识领域的神秘力量吗?
我第一次接触 o1 模型
·发表于Towards Data Science ·6 分钟阅读·2024 年 9 月 16 日
--

由 DALL-E 生成的一幅图像,提示与博客标题完全相同。
我第一次接触 o1 模型
2024 年 9 月 12 日上午 10 点,我正在亚利桑那州立大学的“生成式人工智能前沿课题”课堂上上课。这是一个研究生水平的课程。就在前一天,即 9 月 11 日,我提交了一份团队作业,内容是尝试识别 GPT-4 生成的缺陷和错误输出(本质上是通过提示 GPT-4,看看它是否会在琐碎问题或高中水平推理问题上犯错),这是另一个研究生课程“自然语言处理课题”的一部分。我们识别出了 GPT-4 的几个小错误,其中之一是 无法数出单词 strawberry 中字母 r 的数量。在提交这份作业之前,我在网上查阅了几篇同行评审的论文,这些论文指出了 GPT-4 出错的地方和原因,以及如何改正这些错误。我看到的大多数文献指出了 GPT-4 出错的两个主要领域,分别是 规划和推理。
这篇论文¹(尽管已接近一年)深入探讨了多个案例,其中 GPT-4 无法回答一些涉及简单计数、简单算术、基础逻辑,甚至常识的 trivial 问题。论文¹ 认为这些问题需要一定程度的推理,而 GPT-4 完全无法进行推理,因此几乎总是会答错这些问题。作者还指出,推理是一个(非常)计算上困难的问题。尽管 GPT-4 计算资源密集,但它的计算密集型特性并未针对涉及推理的问答设计。其他几篇论文也呼应了 GPT-4 无法推理或规划²³的观点。
好的,让我们回到 9 月 12 日。我的课大约在上午 10:15 结束,然后我直接从课堂回到家,打开手机上的 YouTube,一边享用我的早午餐。我的 YouTube 首页上的第一个推荐视频是 OpenAI 发布的名为 “Building OpenAI o1” 的视频。他们宣布这个模型是一个专门的推理模型,并表示它将在推理和回答问题时花费更多时间,从而提供更准确的答案。他们表示,在 RL(强化学习)方面投入的计算时间比之前的模型更多,以生成连贯的思维链⁴。实质上,他们使用强化学习训练了思维链生成过程(以生成和完善自身生成的思维链过程)。在 o1 模型中,工程师们可以向模型提出问题,询问它在思维链过程中为什么会出错(每当它出错时),模型可以识别出错误并自我纠正。模型可以自我质疑并反思(见“LLM 中的反思”)其输出并加以修正。
在另一个视频 “Reasoning with OpenAI o1” 中,Jerry Tworek 展示了之前的 OpenAI 模型和市场上大多数其他大型语言模型(LLM)在以下提示上通常会失败:
“假设地球上的物理定律成立。将一个小草莓放入一个普通的杯子里,并将杯子倒扣放在桌子上。然后,有人把这个杯子放入微波炉。现在草莓在哪里?请逐步解释你的推理过程。”
以下是 GPT-4 的传统答案:

图 1:GPT-4 在草莓杯子问题上的错误回答
相对较新的 GPT-4o 也答错了:

图 2:GPT-4 o 在草莓杯子问题上的错误回答
GPT o1 给出了正确答案:

图 3:GPT o1 在草莓杯子问题上的正确回答
如果你点击模型回答开头的下拉菜单(见图 4),你会看到它展示了自己的思考过程(链式思维),OpenAI 的研究人员声称 o1 模型已经通过强化学习训练,使这个思考链变得更加完善。另外,有趣的是,Jason Wei(你可以在视频《构建 OpenAI o1》中看到他坐在底行第三个位置),他曾在 Google 发布了链式思维的论文,现在是 OpenAI 的员工,正在致力于将他在 Google 发现的链式思维过程整合到 o1 模型中。

图 4:GPT o1 的链式思维引导
现在,让我们回到我的团队在我的任务中发现的计数问题。
草莓这个词中有多少个字母 r?
让我们在 GPT-4o 上运行这个问题:

图 5:GPT4o 在回答草莓问题时,错误地计算了字母 r 的个数。
一个非常简单的计数问题,它做错了。
让我们在新的 GPT o1 上运行这个问题:

图 6:GPT o1 正确地回答了草莓问题中字母 r 的个数。
GPT o1 通过思考几秒钟后得出了正确答案。OpenAI 的研究人员表示,它会反复检查自己的回答,并通过思考找到正确的答案。看来模型在解决许多学术考试问题方面确实有了显著的进步。
不管怎样,在我打开 X.com(前身为 Twitter)后,我看到几个人展示了他们试图让 o1 模型失败的尝试。这是我看到的一个有趣的例子(来自@creeor 的this tweet),在这个例子中,模型未能回答一个非常简单的问题,而答案就在问题本身。于是我在我的账户上尝试了完全相同的提示,但它给出了错误的答案(见图 7)。

图 7:即使在调整谜语后,OpenAI o1 仍然无法解答简单的谜题。显示出模型仍然依赖于它在训练过程中记住的很多内容,并没有充分发挥其推理能力。
当我问它它在谈论的这个经典谜语是什么时,它告诉我一个它从互联网记住的谜语。很有趣的是,看到这些模型有时会依赖于记忆的内容,而不是通过真正的推理来解决问题。尽管在基准测试上有了显著进展和改进,但 AI 模型在某些领域仍然存在困难,特别是在那些需要深入推理或以细致的方式理解上下文的任务中。尽管基准测试可以显示进展,实际应用往往暴露了其局限性。正是通过持续的测试、反馈和实际应用案例,这些模型才能不断得到完善。

图 8:o1 盲目地从它记住的谜语中回答问题。它没有阅读给定的问题,并尝试按所呈现的方式作答。
大约一年半前曾有一份关于ChatGPT 错误汇编。模型错误的汇编对于理解和改进 AI 系统非常宝贵。我相信人们很快会推出一份关于 o1 模型的错误汇编。
尽管我完全同意链式思维过程有利于 AI 和人类的学习,但真正的学习确实来自于经验和犯错误。
我将继续在我的 Medium 页面上发布关于 o1 模型的发现。关注我的账户以保持更新。感谢你抽时间阅读我的 Medium 文章。
参考文献:
[1] Arkoudas, Konstantine。“GPT-4 无法推理。” arXiv 预印本 arXiv:2308.03762(2023 年)。
[2] Aghzal, Mohamed, Erion Plaku 和 Ziyu Yao。“往前看:测试 GPT-4 在路径规划中的极限。” arXiv 预印本 arXiv:2406.12000(2024 年)。
[3] Kambhampati, Subbarao 等人。“LLMs 不能规划,但可以在 LLM-Modulo 框架中帮助规划。” arXiv 预印本 arXiv:2402.01817(2024 年)。
[4] Wei, Jason 等人。“链式思维提示在大型语言模型中引发推理。” 神经信息处理系统进展 35(2022 年):24824–24837。
OpenAI 提示缓存监控

图像由 AI(Dalle-3)生成
使用 Python 和聊天完成 API 的工作示例
·发表于Towards Data Science ·阅读时长 9 分钟·2024 年 12 月 10 日
--
作为他们最近 DEV Day 演示的一部分,OpenAI 宣布提示缓存现在已经适用于各种模型。写这篇文章时,这些模型包括:
GPT-4o、GPT-4o mini、o1-preview 和 o1-mini,以及这些模型的微调版本。
这一消息不容小觑,因为它将帮助开发者节省成本并减少应用程序运行时延迟。
对支持的模型的 API 调用将在提示超过 1,024 个令牌时自动受益于提示缓存。API 会缓存已计算过的提示的最长前缀,从 1,024 个令牌开始,并按 128 个令牌的增量递增。如果你重用带有共同前缀的提示,OpenAI 会自动应用提示缓存折扣,而无需你更改 API 集成。
作为 OpenAI API 开发者,你唯一可能需要担心的是如何监控你的提示缓存使用情况,即检查它是否已被应用。
在本文中,我将向你展示如何使用 Python、Jupyter Notebook 和一个聊天完成示例来实现这一点。
安装 WSL2 Ubuntu
OpenAI 与开源多语言嵌入模型
选择最适合你数据的模型
·发表于 Towards Data Science ·阅读时间 12 分钟·2024 年 2 月 24 日
--

我们将使用欧盟人工智能法案作为我们的嵌入模型比较的数据语料库。图像由 Dall-E 3 生成。
OpenAI 最近发布了他们的新一代嵌入模型,称为embedding v3,他们描述这些模型是性能最强的嵌入模型,具有更高的多语言性能。该模型有两种类型:一种较小的,称为text-embedding-3-small,另一种较大且更强大的,称为text-embedding-3-large。
关于这些模型的设计和训练方式,公开的信息非常少。与他们之前的嵌入模型发布(2022 年 12 月,ada-002 模型类)一样,OpenAI 再次选择了一种闭源的方法,模型只能通过付费 API 访问。
但是,性能真的那么好吗,值得付费吗?
这篇文章的动机是通过实证比较这些新模型与它们的开源对手的表现。我们将依赖一个数据检索工作流,在该工作流中,必须根据用户查询找到语料库中最相关的文档。
我们的语料库将是欧洲人工智能法案,该法案目前正处于最终验证阶段。这个语料库的一个有趣特点是,除了它是全球首个关于人工智能的法律框架外,它还提供 24 种语言版本。这使得我们能够比较不同语言家族之间的数据检索准确性。
本文将涵盖以下两个主要步骤:
-
从多语言文本语料库中生成一个定制的合成问答数据集
-
比较 OpenAI 与最先进的开源嵌入模型在此自定义数据集上的准确性。
为了重现本文中展示的结果,代码和数据已经公开在这个 Github 仓库中。请注意,《欧盟人工智能法案》作为示例使用,本文中遵循的方法可以适应其他数据语料库。
生成自定义的问答数据集
让我们首先开始生成自定义数据上的问答数据集(Q/A),用于评估不同嵌入模型的性能。生成自定义问答数据集有两个好处。首先,它避免了偏见,确保数据集未参与嵌入模型的训练,这种情况可能发生在参考基准上,如MTEB。其次,它使得评估可以根据特定数据语料库进行定制,这在检索增强应用(RAG)等情况下尤为重要。
我们将遵循Llama Index 文档中建议的简单流程。首先将语料库拆分成多个块。然后,对于每个块,使用大型语言模型(LLM)生成一组合成问题,使得问题的答案位于相应的块中。该过程如下图所示:

为你的数据生成问答数据集,方法参考Llama Index
实现这个策略在使用像 Llama Index 这样的 LLM 数据框架时非常直接。语料库的加载和文本的拆分可以通过高阶函数方便地完成,如下面的代码所示。
from llama_index.readers.web import SimpleWebPageReader
from llama_index.core.node_parser import SentenceSplitter
language = "EN"
url_doc = "https://eur-lex.europa.eu/legal-content/"+language+"/TXT/HTML/?uri=CELEX:52021PC0206"
documents = SimpleWebPageReader(html_to_text=True).load_data([url_doc])
parser = SentenceSplitter(chunk_size=1000)
nodes = parser.get_nodes_from_documents(documents, show_progress=True)
在这个示例中,语料库是《欧盟人工智能法案》的英文版,直接从网络上获取,使用的是这个官方网址。我们使用的是 2021 年 4 月的草案版本,因为最终版本尚未提供所有欧盟语言。在这个版本中,网址中的英文可以替换为其他 23 种欧盟官方语言中的任何一种,以获取不同语言的文本(例如,BG 代表保加利亚语,ES 代表西班牙语,CS 代表捷克语,等等)。

下载《欧盟人工智能法案》在 24 种官方欧盟语言中的链接(来自欧盟官网)
我们使用 SentenceSplitter 对象将文档拆分成 1000 个标记的块。对于英文文本,这大约会生成 100 个块。
然后,每个文档块作为上下文提供给以下提示(Llama Index 库中建议的默认提示):
prompts={}
prompts["EN"] = """\
Context information is below.
---------------------
{context_str}
---------------------
Given the context information and not prior knowledge, generate only questions based on the below query.
You are a Teacher/ Professor. Your task is to setup {num_questions_per_chunk} questions for an upcoming quiz/examination.
The questions should be diverse in nature across the document. Restrict the questions to the context information provided."
"""
该提示的目的是生成关于文档块的问题,就像老师在准备即将到来的小测验一样。每个文档块要生成的问题数量通过参数‘num_questions_per_chunk’传递,我们将其设置为两个。然后,可以通过调用 Llama Index 库中的 generate_qa_embedding_pairs 来生成问题:
from llama_index.llms import OpenAI
from llama_index.legacy.finetuning import generate_qa_embedding_pairs
qa_dataset = generate_qa_embedding_pairs(
llm=OpenAI(model="gpt-3.5-turbo-0125",additional_kwargs={'seed':42}),
nodes=nodes,
qa_generate_prompt_tmpl = prompts[language],
num_questions_per_chunk=2
)
我们在此任务中依赖于 OpenAI 的 GPT-3.5-turbo-0125 模型,根据 OpenAI 的说法,这是该系列的旗舰模型,支持 16K 的上下文窗口,并针对对话进行了优化(platform.openai.com/docs/models/gpt-3-5-turbo)。
生成的对象‘qa_dataset’包含问题和答案(文档块)对。作为生成问题的示例,以下是前两个问题的结果(其中‘答案’是第一个文档块的文本):
根据说明性备忘录,关于人工智能的统一规则提案(人工智能法案)的主要目标是什么?
根据上下文信息,关于人工智能的统一规则提案如何在促进人工智能在欧盟的应用的同时,解决与使用人工智能相关的风险?
文档块和问题的数量取决于语言,从英语的约 100 个文档块和 200 个问题,到匈牙利语的 200 个文档块和 400 个问题不等。
OpenAI 嵌入模型的评估
我们的评估函数遵循Llama Index 文档,由两个主要步骤组成。首先,所有答案(文档块)的嵌入存储在一个 VectorStoreIndex 中,以便高效检索。然后,评估函数循环遍历所有查询,检索最相似的前 k 个文档,并通过 MRR(平均倒数排名)来评估检索的准确性。
def evaluate(dataset, embed_model, insert_batch_size=1000, top_k=5):
# Get corpus, queries, and relevant documents from the qa_dataset object
corpus = dataset.corpus
queries = dataset.queries
relevant_docs = dataset.relevant_docs
# Create TextNode objects for each document in the corpus and create a VectorStoreIndex to efficiently store and retrieve embeddings
nodes = [TextNode(id_=id_, text=text) for id_, text in corpus.items()]
index = VectorStoreIndex(
nodes, embed_model=embed_model, insert_batch_size=insert_batch_size
)
retriever = index.as_retriever(similarity_top_k=top_k)
# Prepare to collect evaluation results
eval_results = []
# Iterate over each query in the dataset to evaluate retrieval performance
for query_id, query in tqdm(queries.items()):
# Retrieve the top_k most similar documents for the current query and extract the IDs of the retrieved documents
retrieved_nodes = retriever.retrieve(query)
retrieved_ids = [node.node.node_id for node in retrieved_nodes]
# Check if the expected document was among the retrieved documents
expected_id = relevant_docs[query_id][0]
is_hit = expected_id in retrieved_ids # assume 1 relevant doc per query
# Calculate the Mean Reciprocal Rank (MRR) and append to results
if is_hit:
rank = retrieved_ids.index(expected_id) + 1
mrr = 1 / rank
else:
mrr = 0
eval_results.append(mrr)
# Return the average MRR across all queries as the final evaluation metric
return np.average(eval_results)
嵌入模型通过embed_model参数传递给评估函数,对于 OpenAI 模型来说,这是一个初始化了模型名称和模型维度的 OpenAIEmbedding 对象。
from llama_index.embeddings.openai import OpenAIEmbedding
embed_model = OpenAIEmbedding(model=model_spec['model_name'],
dimensions=model_spec['dimensions'])
dimensions API 参数可以缩短嵌入(即移除序列末尾的一些数字),而不会失去其表示概念的特性。OpenAI 例如在它们的公告中建议,在 MTEB 基准测试中,嵌入可以缩短到 256 的大小,同时仍然优于未缩短的text-embedding-ada-002嵌入(其大小为 1536)。
我们在四个不同的 OpenAI 嵌入模型上运行了评估函数:
-
两个版本的
text-embedding-3-large:一个是最低维度(256),另一个是最高维度(3072)。它们分别称为‘OAI-large-256’和‘OAI-large-3072’。 -
OAI-small:
text-embedding-3-small嵌入模型,维度为 1536。 -
OAI-ada-002:传统的
text-embedding-ada-002模型,维度为 1536。
每个模型都在四种不同的语言上进行了评估:英语(EN)、法语(FR)、捷克语(CS)和匈牙利语(HU),分别涵盖了日耳曼语系、罗曼语系、斯拉夫语系和乌拉尔语系的示例。
embeddings_model_spec = {
}
embeddings_model_spec['OAI-Large-256']={'model_name':'text-embedding-3-large','dimensions':256}
embeddings_model_spec['OAI-Large-3072']={'model_name':'text-embedding-3-large','dimensions':3072}
embeddings_model_spec['OAI-Small']={'model_name':'text-embedding-3-small','dimensions':1536}
embeddings_model_spec['OAI-ada-002']={'model_name':'text-embedding-ada-002','dimensions':None}
results = []
languages = ["EN", "FR", "CS", "HU"]
# Loop through all languages
for language in languages:
# Load dataset
file_name=language+"_dataset.json"
qa_dataset = EmbeddingQAFinetuneDataset.from_json(file_name)
# Loop through all models
for model_name, model_spec in embeddings_model_spec.items():
# Get model
embed_model = OpenAIEmbedding(model=model_spec['model_name'],
dimensions=model_spec['dimensions'])
# Assess embedding score (in terms of MRR)
score = evaluate(qa_dataset, embed_model)
results.append([language, model_name, score])
df_results = pd.DataFrame(results, columns = ["Language" ,"Embedding model", "MRR"])
结果中的准确度(以 MRR 衡量)如下所示:

OpenAI 模型表现的总结
正如预期的那样,对于大型模型,随着嵌入大小增大到 3072,表现有所提升。与小型和传统 Ada 模型相比,大型模型的表现虽有所提升,但依然比我们预期的要小。为进行比较,我们还在下方报告了 OpenAI 模型在 MTEB 基准测试中的表现。

OpenAI 嵌入模型的表现,详见它们的官方公告。
值得注意的是,在我们的评估中,大型、小型和 Ada 模型之间的表现差异,比 MTEB 基准测试中观察到的差异要小,这反映了一个事实,即在大型基准测试中观察到的平均表现,并不一定能反映在定制数据集上获得的结果。
开源嵌入模型的评估
关于嵌入的开源研究非常活跃,新的模型定期发布。一个不错的保持最新发布模型的地方是Hugging Face 😊 MTEB 排行榜。
本文中的比较,我们选择了一组最近发布的四种嵌入模型(2024 年)。选择标准是它们在 MTEB 排行榜上的平均得分以及处理多语言数据的能力。以下是所选模型的主要特征总结。
所选的开源嵌入模型
-
E5-Mistral-7B-instruct(E5-mistral-7b):微软的这款 E5 嵌入模型是从Mistral-7B-v0.1初始化的,并在多语言数据集的混合上进行微调。该模型在 MTEB 排行榜上表现最佳,但也是最大的模型(14GB)。
-
multilingual-e5-large-instruct(ML-E5-large):微软的另一款 E5 模型,旨在更好地处理多语言数据。它是从xlm-roberta-large初始化的,并在多语言数据集的混合上进行训练。它比 E5-Mistral 小得多(小 10 倍),但上下文大小也更小(514)。
-
BGE-M3:该模型由北京人工智能研究院设计,是其针对多语言数据的最先进嵌入模型,支持 100 多种工作语言。截至 2024 年 2 月 22 日,它尚未在 MTEB 排行榜上进行基准测试。
-
nomic-embed-text-v1(Nomic-Embed):该模型由Nomic设计,声称在性能上优于 OpenAI Ada-002 和 text-embedding-3-small,同时仅为 0.55GB 大小。有趣的是,该模型是第一个完全可复现和可审计的模型(开放数据和开源训练代码)。
评估这些开源模型的代码与用于 OpenAI 模型的代码类似。主要的变化在于模型规格,必须指定附加的细节,如最大上下文长度和池化类型。然后,我们对四种语言中的每种模型进行了评估:
embeddings_model_spec = {
}
embeddings_model_spec['E5-mistral-7b']={'model_name':'intfloat/e5-mistral-7b-instruct','max_length':32768, 'pooling_type':'last_token',
'normalize': True, 'batch_size':1, 'kwargs': {'load_in_4bit':True, 'bnb_4bit_compute_dtype':torch.float16}}
embeddings_model_spec['ML-E5-large']={'model_name':'intfloat/multilingual-e5-large','max_length':512, 'pooling_type':'mean',
'normalize': True, 'batch_size':1, 'kwargs': {'device_map': 'cuda', 'torch_dtype':torch.float16}}
embeddings_model_spec['BGE-M3']={'model_name':'BAAI/bge-m3','max_length':8192, 'pooling_type':'cls',
'normalize': True, 'batch_size':1, 'kwargs': {'device_map': 'cuda', 'torch_dtype':torch.float16}}
embeddings_model_spec['Nomic-Embed']={'model_name':'nomic-ai/nomic-embed-text-v1','max_length':8192, 'pooling_type':'mean',
'normalize': True, 'batch_size':1, 'kwargs': {'device_map': 'cuda', 'trust_remote_code' : True}}
results = []
languages = ["EN", "FR", "CS", "HU"]
# Loop through all models
for model_name, model_spec in embeddings_model_spec.items():
print("Processing model : "+str(model_spec))
# Get model
tokenizer = AutoTokenizer.from_pretrained(model_spec['model_name'])
embed_model = AutoModel.from_pretrained(model_spec['model_name'], **model_spec['kwargs'])
if model_name=="Nomic-Embed":
embed_model.to('cuda')
# Loop through all languages
for language in languages:
# Load dataset
file_name=language+"_dataset.json"
qa_dataset = EmbeddingQAFinetuneDataset.from_json(file_name)
start_time_assessment=time.time()
# Assess embedding score (in terms of hit rate at k=5)
score = evaluate(qa_dataset, tokenizer, embed_model, model_spec['normalize'], model_spec['max_length'], model_spec['pooling_type'])
# Get duration of score assessment
duration_assessment = time.time()-start_time_assessment
results.append([language, model_name, score, duration_assessment])
df_results = pd.DataFrame(results, columns = ["Language" ,"Embedding model", "MRR", "Duration"])
下面报告了以 MRR 为标准的准确度结果。

开源模型的性能总结
BGE-M3 表现最佳,平均表现紧随其后的是 ML-E5-Large、E5-mistral-7b 和 Nomic-Embed。BGE-M3 模型尚未在 MTEB 排行榜上进行基准测试,我们的结果表明它可能会排在其他模型之前。有趣的是,虽然 BGE-M3 是针对多语言数据进行优化的,但它在英语上的表现优于其他模型。
我们还报告了每个嵌入模型的处理时间。

通过英文问答数据集的处理时间(以秒为单位)
E5-mistral-7b 比其他模型大 10 倍以上,毫不意外地是最慢的模型。
结论
让我们将这八个测试模型的性能并排展示在一个图中。

八个测试模型的性能总结
从这些结果中可以得出以下关键观察:
-
开源模型的表现最好。由北京人工智能学会开发的 BGE-M3 模型表现最佳。该模型的上下文长度与 OpenAI 模型相同(8K),大小为 2.2GB。
-
OpenAI 模型范围的一致性。大(3072)、小型和遗留版 OpenAI 模型的表现非常相似。然而,减少大模型(256)的嵌入大小会导致性能下降。
-
语言敏感性。几乎所有模型(除了 ML-E5-large)在英语上的表现最好。在捷克语和匈牙利语等语言中,性能差异较大。
那么,你是应该选择付费的 OpenAI 订阅,还是托管一个开源嵌入模型?
OpenAI 的 最近价格调整使得访问其 API 更加实惠,现在每百万个 token 的费用为 $0.13。处理每月百万次查询(假设每次查询大约涉及 1K token)大约需要 $130。根据你的使用场景,租用和维护自己的嵌入服务器可能不具有成本效益。
然而,成本效益并非唯一的考虑因素。其他因素如延迟、隐私和数据处理工作流的控制也可能需要被考虑。开源模型提供了完全控制数据的优势,从而增强了隐私性和定制性。另一方面,OpenAI 的 API 存在延迟问题,有时会导致响应时间延长。
总结来说,选择开源模型还是像 OpenAI 这样的专有解决方案并不是一个简单的答案。开源嵌入模型提供了令人信服的选择,结合了性能和对数据的更大控制。相反,OpenAI 的产品可能仍然吸引那些优先考虑便利性的用户,尤其是当隐私问题不是主要关注点时。
有用的链接
-
配套的 Github 仓库:
github.com/Yannael/multilingual-embeddings -
文本嵌入:全面指南
-
如何为你的 RAG 找到最佳的多语言嵌入模型
注意:
-
除非另有说明,所有图片均由作者提供
-
《欧盟人工智能法案草案》根据委员会的文件重用政策发布,基于决定 2011/833/EU,可用于商业或非商业目的。
喜欢这篇文章吗?分享你的想法,给它点个赞,或 在 LinkedIn 上与我联系。
打开潘多拉的盒子:征服数据云迁移和新领域项目中的 7 个“邪恶使者”
克服云迁移挑战的指南
·发表于Towards Data Science ·阅读时长 15 分钟·2024 年 10 月 8 日
--

“尽管有警告,潘多拉依然充满好奇,她打开了罐子,释放了世界的恶行——只留下希望被困其中。” [照片由Bailey Heedick拍摄,来源于Unsplash]
潘多拉, 这位第一位凡人女性,是由众神创造的,作为宙斯计划的一部分,旨在惩罚人类因为普罗米修斯偷取火种[1]。
她被赋予了美貌和智慧,而宙斯将她送往厄庇墨修斯,普罗米修斯的兄弟。作为结婚礼物,宙斯送给潘多拉一个罐子(常被解释为“盒子”)并警告她绝对不要打开它 [1]。
尽管有警告,潘多拉充满好奇。她打开了罐子,释放了 世界上带来邪恶的存在——只留下希望被困其中 [2]。
从那时起,“打开潘多拉的盒子” 就成了做或开始一件将导致许多无法预见的问题的代名词 [3]。
将这与我的职业生涯进行对比,唯一让我感觉像是打开了“潘多拉盒子”的时刻是当我几年前开始从事数据云迁移/新建项目时。
有趣的是,多年来这个想法没有改变,即便是参与了另外两个几乎完全相同的项目之后。
每一个新的数据云迁移项目,不仅让我经历了新的“祸根”,而且我还…
使用 Elementary 实现开源数据可观测性——从零到英雄(第二部分)
带你免费提升 dbt 测试水平的指南
·发表于 Towards Data Science ·阅读时间 6 分钟·2024 年 9 月 10 日
--

图片来源:Caspar Camille Rubin 于 Unsplash
在上一部分中,我们已在 dbt 仓库中设置了 Elementary,并希望它也能在生产环境中运行。在本部分中,我们将更详细地探讨 Elementary 中可用的测试,并通过示例说明哪些测试适用于哪些数据场景。
如果你错过了第一部分,下面是它的内容:
开源数据可观测性与 Elementary —— 从零到英雄(第一部分)
在运行报告时,我们看到在 Elementary Cloud 中有一个“测试配置”标签。这是报告中一个便捷的 UI 部分,只在云端版本中可用,但我们也可以在 Elementary 的开源版本中通过 .yaml 文件创建测试配置。它与设置原生 dbt 测试类似,遵循类似的 dbt 层次结构,更具体的配置会覆盖更高层次的配置。
你可以设置哪些测试?Elementary 将它们分为 3 大类:模式测试、异常测试和 Python 测试。我们一起来逐个了解它们的工作原理:
模式测试:
运营数据与分析数据
在企业中,数据的区别是什么,我们应该如何对待数据?
·发布于Towards Data Science ·7 分钟阅读·2024 年 11 月 7 日
--
不幸的是,我们仍然对运营数据和分析数据的确切含义存在很大的困惑。因此,我们仍在努力寻找一种合适的方式,从企业层面整体处理数据。
所谓“数据的巨大分界”,是我们今天在数据架构中面临的许多挑战的根源。运营数据与分析数据之间的区分,在当前的定义下并没有实际帮助。

图片来源:作者,灵感来自Zhamak Dehghani 的《数据的巨大分界》
我在之前的文章中已经讨论过这个问题,并在《数据网格中的挑战与解决方案》系列的第一部分中做出了关键声明:
为了解决脆弱的 ETL 管道问题,我们不妨完全不在运营数据与分析数据之间划定严格的界限。相反,我们应该只区分源数据与衍生数据——这两者都可以用于运营和分析目的。
这一点非常基础,我想在这里进一步阐述,以明确为什么我如此坚持普遍数据供应,它能有效弥合两者之间的差距。
误解
优化:用最简单的术语解读排队理论
你是否曾经在超市、餐厅或银行排队等候,盼望自己的号码能尽快到?
·发布于 Towards Data Science ·11 分钟阅读·2024 年 3 月 11 日
--

图像由 DALL-E 生成
引言
排队理论可以为这种常见的困扰提供解决方案。顾名思义,排队理论运用数学模型来评估排队或等待线,旨在优化操作效率。
以超市为例,通过分析顾客排队情况,超市能够确定服务顾客所需的最佳收银台数量和员工配备,从而在不影响顾客等待时间的前提下,提升服务效率。分析的目标是找到顾客满意度与资源限制之间的平衡。
然而,尽管排队理论非常有用,它常常与复杂的概率论和数学联系在一起。这也是为什么在这篇文章中,我将尽量简化这一概念,而不是深入繁重的数学内容,目的是为您提供排队理论的整体框架,最重要的是,讲解所有术语是如何相互关联的。
目录
- 排队理论的应用
优化定价和促销中的非线性处理效应
因果 AI,探索因果推理与机器学习的结合
·发表于Towards Data Science ·12 分钟阅读·2024 年 5 月 24 日
--

这系列文章讲了什么?
欢迎来到我的因果 AI 系列文章,在这里我们将探索因果推理与机器学习模型的结合。你将看到多个不同业务情境下的实际应用。
在上一篇文章中,我们讨论了使用双重机器学习和线性规划来优化处理策略。这次我们将继续优化的主题,探讨优化定价与促销中的非线性处理效应。
如果你错过了上一篇关于双重机器学习和线性规划的文章,可以在这里查看:
因果 AI,探索因果推理与机器学习的结合
towardsdatascience.com
介绍
本文将展示我们如何优化定价中的非线性处理效应(但这些理念也可以应用于市场营销和其他领域)。
在本文中,我将帮助你理解:
-
为什么在定价中非线性处理效应如此常见?
-
我们的因果人工智能工具箱中有哪些工具适用于估算非线性处理效应?
-
非线性编程如何用于优化定价?
-
一个使用 Python 的案例研究,展示了我们如何结合因果人工智能工具箱和非线性编程来优化定价预算。
完整的笔记本可以在这里找到:
[## causal_ai/notebooks/using dml and lp to optimise treatment strategies.ipynb at main ·…
本项目介绍了因果人工智能(Causal AI)及其如何推动业务价值。- causal_ai/notebooks/using dml and lp to…
为什么在定价中非线性处理效应如此常见?
递减收益
让我们以零售商调整产品价格为例。最初,降低价格可能会导致销售量显著增加。然而,随着价格继续降低,销售的增长可能会开始趋于平稳。我们称之为递减收益。如下面所示,递减收益的效果通常是非线性的。

用户生成的图片
递减收益可以在定价之外的多个领域观察到。一些常见的例子包括:
-
营销 — 增加社交媒体投入可以提高客户获取,但随着时间的推移,瞄准新的、未开发的受众会变得越来越困难。
-
农业 — 向田地添加肥料最初可以显著提高作物产量,但这种效果很快就会开始递减。
-
制造 — 向生产过程中添加更多工人将提高效率,但每增加一名工人对整体产出的贡献可能会减少。
这让我开始思考,如果递减收益如此常见,那么我们的因果人工智能工具箱中有哪些技术可以应对这一问题?
我们的因果人工智能工具箱中有哪些方法适合估算非线性处理效应?
工具箱
我们将提出两个关键问题,帮助我们识别哪些因果人工智能工具箱中的方法适合解决定价问题:
-
它能处理连续性处理吗?
-
它能捕捉非线性处理效应吗?
以下是我们如何评估每种方法适用性的总结:
-
倾向得分匹配(PSM)— 处理需要是二元的 ❌
-
倾向得分逆向匹配(IPSM)— 处理需要是二元的 ❌
-
T 学习者(T-Learner)— 处理需要是二元的 ❌
-
双重机器学习(DML)— 处理效应是线性的 ❌
-
双重鲁棒学习者(DR)— 处理需要是二元的 ❌
-
S-Learner — 如果使用适当的机器学习算法(例如梯度提升),它可以处理连续处理和处理与结果之间的非线性关系💚
S-Learner
S-Learner 中的“S”来自于它是一个“单一模型”。一个任意的机器学习模型被用来预测结果,使用处理、混杂因素和其他协变量作为特征。这个模型随后被用来估计在不同处理条件下潜在结果的差异(从而给我们带来处理效应)。
S-Learner 有许多优点:
-
它可以处理二元和连续性处理。
-
它可以使用任何机器学习算法,赋予我们灵活性来捕捉特征和处理之间的非线性关系。
一个警告:正则化偏差!现代机器学习算法使用正则化来防止过拟合——但这可能对因果问题产生负面影响。以梯度提升树方法中的超参数max features为例——在多个树中,可能会出现处理未被包含在模型中的情况。这会削弱处理效应。
在使用 S-Learner 时,我建议仔细考虑正则化参数,例如将max features设置为 1.0(有效地关闭特征正则化)。
如何使用非线性编程来优化定价?
价格优化
假设我们有多种产品,并且想要在给定的促销预算下优化它们的价格。对于每个产品,我们训练一个 S-Learner(使用梯度提升),将处理设置为折扣水平,将结果设置为总订单数。我们的 S-Learner 输出一个复杂模型,可以用来估计不同折扣水平的效应。那么我们如何优化每个产品的折扣水平呢?
响应曲线
优化技术,如线性(甚至非线性)编程,依赖于响应的清晰函数形式。像随机森林和梯度提升这样的机器学习技术并不会给我们提供这个(与线性回归不同)。然而,响应曲线可以将 S-Learner 的输出转化为一种综合形式,展示结果如何响应处理。
如果你还不太能想象我们如何创建响应曲线,别担心,我们将在 Python 案例研究中详细讲解!
米哈利斯-门农方程
有几种方程可以用来将 S-Learner 映射到响应曲线。其中之一就是米哈利斯-门农方程。
米哈利斯-门农方程通常用于酶动力学(研究酶催化化学反应的速率)中,用来描述酶促反应的速率。

用户生成的图片
-
v — 是反应速度(这是我们转化后的响应,所以在我们的定价示例中是订单的总数)
-
Vmax — 是最大反应速度(我们称之为 alpha,这是一个我们需要学习的参数)
-
Km — 是底物浓度(我们称之为 lambda,这是一个我们需要学习的参数)
-
S — 是迈克利斯常数(这是我们的处理变量,所以在定价示例中是折扣水平)
它的原理也可以应用于其他领域,特别是在处理那些由于饱和因素导致输入增加不能按比例增加输出的系统时。下面我们展示不同的 alpha 和 lambda 值如何影响曲线:
def michaelis_menten(x, alpha, lam):
return alpha * x / (lam + x)

用户生成的图像
一旦我们获得了响应曲线,接下来我们可以考虑优化问题。迈克利斯-孟东方程给出了一个非线性函数。因此,非线性规划是一个合适的选择。
非线性规划
在我上一篇文章中我们介绍了线性规划。非线性规划类似,但目标函数和/或约束条件本质上是非线性的。
序列最小二乘法规划(SLSQP)是一种用于解决非线性规划问题的算法。它允许同时处理等式约束和不等式约束,因此在我们的使用场景中是一个合理的选择。
-
等式约束,例如总促销预算等于£100k
-
不等式约束,例如每个产品的折扣在£1 到£10 之间
SciPy 提供了一个易于使用的 SLSQP 实现:
[## minimize(method='SLSQP') - SciPy v1.13.0 Manual
如果 jac 在['2-point', '3-point', 'cs']中,使用相对步长进行数值近似 jac。绝对…
接下来,我们将展示 S-Learner、迈克利斯-孟东方程和非线性规划结合的强大威力!
案例研究
背景
历史上,促销团队一直依靠他们的专家判断来为他们的三大主打产品设置折扣。考虑到当前的经济状况,他们被迫将整体促销预算削减 20%。于是,他们求助于数据科学团队,咨询如何在减少订单量损失的同时做到这一点。
数据生成过程
我们设置了一个具有以下特点的数据生成过程:
-
4 个与订单数量有复杂关系的特征
-
一个遵循迈克利斯-孟东方程的处理效果
def data_generator(n, tau_weight, alpha, lam):
# Set number of features
p=4
# Create features
X = np.random.uniform(size=n * p).reshape((n, -1))
# Nuisance parameters
b = (
np.sin(np.pi * X[:, 0])
+ 2 * (X[:, 1] - 0.5) ** 2
+ X[:, 2] * X[:, 3]
)
# Create treatment and treatment effect
T = np.linspace(200, 10000, n)
T_mm = michaelis_menten(T, alpha, lam) * tau_weight
tau = T_mm / T
# Calculate outcome
y = b + T * tau + np.random.normal(size=n) * 0.5
y_train = y
X_train = np.hstack((X, T.reshape(-1, 1)))
return y_train, X_train, T_mm, tau
X 特征是混杂变量:

用户生成的图像
我们使用数据生成器为三个产品创建样本,每个产品有不同的处理效果:
np.random.seed(1234)
n=100000
y_train_1, X_train_1, T_mm_1, tau_1 = data_generator(n, 1.00, 2, 5000)
y_train_2, X_train_2, T_mm_2, tau_2 = data_generator(n, 0.25, 2, 5000)
y_train_3, X_train_3, T_mm_3, tau_3 = data_generator(n, 2.00, 2, 5000)
S-Learner
我们可以通过使用任何机器学习算法,并将处理和协变量作为特征来训练一个 S-Learner:
def train_slearner(X_train, y_train):
model = LGBMRegressor(random_state=42)
model.fit(X_train, y_train)
yhat_train = model.predict(X_train)
mse_train = mean_squared_error(y_train, yhat_train)
r2_train = r2_score(y_train, yhat_train)
print(f'MSE on train set is {round(mse_train)}')
print(f'R2 on train set is {round(r2_train, 2)}')
return model, yhat_train
我们为每个产品训练一个 S-Learner:
np.random.seed(1234)
model_1, yhat_train_1 = train_slearner(X_train_1, y_train_1)
model_2, yhat_train_2 = train_slearner(X_train_2, y_train_2)
model_3, yhat_train_3 = train_slearner(X_train_3, y_train_3)
目前这只是一个预测模型——下面我们可视化它在这项工作中的表现:

用户生成的图像
提取处理效果
接下来我们将使用我们的 S-learner 来提取整个处理值范围(折扣金额)的处理效果,同时将其他特征保持在其平均值。
我们首先提取整个处理值范围的预期结果(订单数量):
def extract_treated_effect(n, X_train, model):
# Set features to mean value
X_mean_mapping = {'X1': [X_train[:, 0].mean()] * n,
'X2': [X_train[:, 1].mean()] * n,
'X3': [X_train[:, 2].mean()] * n,
'X4': [X_train[:, 3].mean()] * n}
# Create DataFrame
df_scoring = pd.DataFrame(X_mean_mapping)
# Add full range of treatment values
df_scoring['T'] = X_train[:, 4].reshape(-1, 1)
# Calculate outcome prediction for treated
treated = model.predict(df_scoring)
return treated, df_scoring
我们对每个产品执行此操作:
treated_1, df_scoring_1 = extract_treated_effect(n, X_train_1, model_1)
treated_2, df_scoring_2 = extract_treated_effect(n, X_train_2, model_2)
treated_3, df_scoring_3 = extract_treated_effect(n, X_train_3, model_3)
然后我们提取当处理设置为 0 时的预期结果(订单数量):
def extract_untreated_effect(n, X_train, model):
# Set features to mean value
X_mean_mapping = {'X1': [X_train[:, 0].mean()] * n,
'X2': [X_train[:, 1].mean()] * n,
'X3': [X_train[:, 2].mean()] * n,
'X4': [X_train[:, 3].mean()] * n,
'T': [0] * n}
# Create DataFrame
df_scoring = pd.DataFrame(X_mean_mapping)
# Add full range of treatment values
df_scoring
# Calculate outcome prediction for treated
untreated = model.predict(df_scoring)
return untreated
再次,我们对每个产品执行此操作:
untreated_1 = extract_untreated_effect(n, X_train_1, model_1)
untreated_2 = extract_untreated_effect(n, X_train_2, model_2)
untreated_3 = extract_untreated_effect(n, X_train_3, model_3)
我们现在可以计算整个处理值范围的处理效果:
treatment_effect_1 = treated_1 - untreated_1
treatment_effect_2 = treated_2 - untreated_2
treatment_effect_3 = treated_3 - untreated_3
当我们将其与从数据生成器保存的实际处理效果进行比较时,我们可以看到 S-Learner 在估计整个处理值范围的处理效果方面非常有效:

用户生成的图像
现在我们拥有了这些处理效果数据,可以用它为每个产品构建响应曲线。
米氏-孟东方程(Michaelis-Menton)
为了构建响应曲线,我们需要一个曲线拟合工具。SciPy 有一个很好的实现,我们将使用它:
[## scipy.optimize.curve_fit - SciPy v1.13.0 手册]
scipy.optimize. curve_fit ( f , xdata , ydata , , , , , bounds = (-inf, inf) , , , * , , , ** kwargs ) [source] 使用…
docs.scipy.org](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.curve_fit.html?source=post_page-----011ce140d180--------------------------------)
我们首先设置我们想要学习的函数:
def michaelis_menten(x, alpha, lam):
return alpha * x / (lam + x)
然后我们可以使用 curve_fit 来学习 alpha 和 lambda 参数:
def response_curves(treatment_effect, df_scoring):
maxfev = 100000
lam_initial_estimate = 0.001
alpha_initial_estimate = max(treatment_effect)
initial_guess = [alpha_initial_estimate, lam_initial_estimate]
popt, pcov = curve_fit(michaelis_menten, df_scoring['T'], treatment_effect, p0=initial_guess, maxfev=maxfev)
return popt, pcov
我们对每个产品执行此操作:
popt_1, pcov_1 = response_curves(treatment_effect_1, df_scoring_1)
popt_2, pcov_2 = response_curves(treatment_effect_2, df_scoring_2)
popt_3, pcov_3 = response_curves(treatment_effect_3, df_scoring_3)
我们现在可以将学习到的参数输入到米氏孟东方程中,帮助我们可视化曲线拟合的效果:
treatment_effect_curve_1 = michaelis_menten(df_scoring_1['T'], popt_1[0], popt_1[1])
treatment_effect_curve_2 = michaelis_menten(df_scoring_2['T'], popt_2[0], popt_2[1])
treatment_effect_curve_3 = michaelis_menten(df_scoring_3['T'], popt_3[0], popt_3[1])
我们可以看到曲线拟合做得非常好!

用户生成的图像
现在我们拥有了每个产品的 alpha 和 lambda 参数,我们可以开始考虑非线性优化……
非线性编程
我们首先开始收集所有优化所需的信息:
-
所有产品的列表
-
总促销预算
-
每个产品的预算范围
-
从米氏-孟东反应曲线中提取的每个产品的参数
# List of products
products = ["product_1", "product_2", "product_3"]
# Set total budget to be the sum of the mean of each product reduced by 20%
total_budget = (df_scoring_1['T'].mean() + df_scoring_2['T'].mean() + df_scoring_3['T'].mean()) * 0.80
# Dictionary with min and max bounds for each product - set as +/-20% of max/min discount
budget_ranges = {"product_1": [df_scoring_1['T'].min() * 0.80, df_scoring_1['T'].max() * 1.2],
"product_2": [df_scoring_2['T'].min() * 0.80, df_scoring_2['T'].max() * 1.2],
"product_3": [df_scoring_3['T'].min() * 0.80, df_scoring_3['T'].max() * 1.2]}
# Dictionary with response curve parameters
parameters = {"product_1": [popt_1[0], popt_1[1]],
"product_2": [popt_2[0], popt_2[1]],
"product_3": [popt_3[0], popt_3[1]]}
接下来我们设置目标函数——我们希望最大化订单数,但由于我们将使用最小化方法,因此返回预期订单总数的负值。
def objective_function(x, products, parameters):
sum_orders = 0.0
# Unpack parameters for each product and calculate expected orders
for product, budget in zip(products, x, strict=False):
L, k = parameters[product]
sum_orders += michaelis_menten(budget, L, k)
return -1 * sum_orders
最后我们可以运行优化,确定分配给每个产品的最优预算:
# Set initial guess by equally sharing out the total budget
initial_guess = [total_budget // len(products)] * len(products)
# Set the lower and upper bounds for each product
bounds = [budget_ranges[product] for product in products]
# Set the equality constraint - constraining the total budget
constraints = {"type": "eq", "fun": lambda x: np.sum(x) - total_budget}
# Run optimisation
result = minimize(
lambda x: objective_function(x, products, parameters),
initial_guess,
method="SLSQP",
bounds=bounds,
constraints=constraints,
options={'disp': True, 'maxiter': 1000, 'ftol': 1e-9},
)
# Extract results
optimal_treatment = {product: budget for product, budget in zip(products, result.x, strict=False)}
print(f'Optimal promo budget allocations: {optimal_treatment}')
print(f'Optimal orders: {round(result.fun * -1, 2)}')
输出向我们展示了每个产品的最优促销预算:

用户生成的图像
如果你仔细检查响应曲线,你会发现优化结果是直观的:
-
稍微减少产品 1 的预算。
-
显著减少产品 2 的预算。
-
显著增加产品 3 的预算。
结语。
今天我们讨论了 S-Learner、Michaelis-Menten 方程和非线性规划的强大结合!以下是一些结语:
-
如前所述,使用 S-Learner 时要小心正则化偏差!
-
S-Learner 的一个很好的替代方法是使用 DML,但在训练模型之前对处理进行转换——然而,这意味着你需要对处理的函数形式有一定的先验知识。
-
我选择使用 Michaelis-Menten 方程来构建我的响应曲线——然而,这可能不适合你的问题,可以通过其他更合适的转换方法来替代。
-
使用 SLSQP 来解决非线性规划问题可以让你灵活地使用等式和不等式约束。
-
你收集的数据很可能是观察性数据——这带来了一些挑战,尤其是在你将收集到的折扣值范围上——这些值可能会集中在一个特定的区域。使用某种 Shapley 方法来创建用于生成响应曲线的数据,在这种情况下可能更为合适。
-
我选择专注于定价和促销,但这个框架可以扩展到营销预算。
如果你想继续深入了解因果 AI,关注我——在下一篇文章中,我们将讨论如何衡量营销活动的内在因果影响。
使用线性求解器优化神经网络
如何使用线性求解器优化多维非线性神经网络。
·发布于Towards Data Science ·19 分钟阅读·2024 年 2 月 16 日
--

图片来自Sam Moghadam Khamseh在Unsplash上的分享
注意:如果你没有 Medium 订阅,可以免费阅读文章(如果你有订阅,请继续阅读,谢谢!🥰)
最近,我遇到了一个问题,要求我创建一个能够接受多个输入特征并预测连续输出的模型。
然后,我需要从该模型中获得最佳输出,对我来说是最低的数值。换句话说,我需要解决一个优化问题。
问题是(直到那个阶段我才意识到)我所在的环境不允许我使用非线性方法或复杂的框架——所以不能使用神经网络,不能使用非线性求解器,什么都不行……
但是,我创建的模型运行得很好(考虑到我用来训练它的数据点数量很少),而且我不想删除所有代码,从头开始使用线性模型。
所以,在喝了一杯咖啡之后☕,我决定使用这个已经训练好的非线性模型,生成一些小的线性模型。然后,我可以使用线性求解器来解决优化问题。
非线性函数的优化通过分段线性化
如何使用免费的混合整数线性求解器来解决非线性优化问题。一个逐步的示例。
·发表于 Towards Data Science ·12 分钟阅读·2024 年 1 月 10 日
--

注意:如果你没有中等订阅权限,你可以在这里免费阅读文章(如果你有订阅,请继续在这里阅读,谢谢!🥰)
本文是试图理解如何将包含非线性项的模型转化为线性模型的结果。通常,人们希望这样做有几个原因:
-
非线性模型可能非常难以求解。
-
线性模型可能更容易理解和解释。
-
线性模型通常求解速度更快。
-
还有许多其他(特定案例的)原因。
如果你所在的大学或公司没有访问强大(混合整数)非线性求解器的权限,比如BARON、ANTIGONE等,并且你遇到了非线性问题,那么一个选择是尝试将模型线性化,以便使用线性求解器。
我知道你现在可能会想,在“仅非线性且没有整数变量”的情况下,你也可以使用像IPOPT这样的公开可用求解器。或者如果你想使用其他……
使用替代模型进行优化,通过符号回归实现
通过符号回归方法识别代数替代模型来优化黑箱系统的一种可能性。
·发布于 Towards Data Science ·16 分钟阅读·2024 年 1 月 19 日
--

图片由 Jeremy Bishop 提供,来自 Unsplash
注:如果你没有 Medium 订阅,你可以免费阅读本文(如果你有订阅,请继续阅读,谢谢!🥰)
执行优化是一个非常有趣的任务。在我们的日常生活中,我们可能会对如何以最短的时间到达工作地点感兴趣,或者可能会关注如何调整我们研磨咖啡的最佳颗粒大小,以制作一杯非常美味的咖啡 ☕。各行各业也在关注优化,例如供应链、碳排放或废物积累等问题。
设置优化的方式有很多,具体取决于特定情况的不同。为了本篇文章,我将这些情况分为两部分:
一方面,我们可能了解驱动研究系统的物理、化学或生物学原理。基于这些知识,我们可以建立代数方程,准确地描述我们观察到的现象(第一性原理)。这些情况允许使用现成的求解器,如GLPK、BARON、ANTIGONE、SBB或其他,因为我们有…
线性规划简介 — 第一部分
使用 R 优化生产
·发表于Towards Data Science ·阅读时间 8 分钟·2024 年 6 月 11 日
--
去年,我的一位朋友找到了我,他在一家小型的家族拥有的钢铁和金属公司工作。他想知道是否能创建一个工具,帮助他解决在切割钢梁时如何最小化浪费的问题。这听起来像是一个适合用线性规划解决的问题!
当我刚开始时,关于如何在 R 中使用线性规划的初学者文章并不多,而且这些文章对于数学基础不深的人来说很难理解。到 2023 年初,ChatGPT 在 R 中使用线性规划的方面表现并不出色,因此我曾经希望能有一份这样的指南。
本系列是我尝试编写的一份指南。希望它能对某些人有所帮助。
本系列的第一部分将介绍 R 语言中的线性规划概念,并通过一个基本的例子进行讲解。在第二部分中,我将教你如何创建一个更高级的模型。如果需求足够,我可能会扩展到第三部分,在其中我会详细介绍如何创建一个 Shiny 应用,实际使用线性规划来优化工作。
线性规划
线性规划涉及到找到线性函数的最优解。常见的例子有背包问题(给定一组物品,每个物品有一个价值和重量,要求找出哪些物品应放入背包,以使得不超过重量限制的同时,最大化所选物品的总价值)和旅行商问题(给定一组...
在 Amazon SageMaker 实时推理上优化 Mistral7B 的部署
利用由 DJL Serving 和 Nvidia TensorRT 驱动的大型模型推理容器
·发布于Towards Data Science ·9 分钟阅读·2024 年 2 月 21 日
--

生成式人工智能领域以空前的速度持续扩展,每天都有更多的大型语言模型(LLM)家族问世。在每个家族中,也有不同规模的模型,例如 Llama7b、Llama13B 和 Llama70B。无论选择哪个模型,托管这些 LLM 进行推理时都会面临相同的挑战。
这些大型语言模型(LLM)的规模仍然是最紧迫的挑战,因为很难/几乎不可能将许多这样的 LLM 部署到单一的 GPU 上。为了解决这个问题,有几种不同的方法,比如模型分割。通过模型分割,你可以使用管道并行或张量并行等技术,将模型分割到多个 GPU 上。除模型分割外,其他常用方法还包括将模型权重进行量化,通过降低精度来减少模型本身的大小,代价是精度的损失。
虽然模型的规模本身就是一个巨大的挑战,但在文本生成中,保留先前的推理/注意力也是一个挑战,对于基于解码器的模型。使用这些模型进行文本生成并不像传统方法那么简单……
通过权重量化优化深度学习模型
权重量化的实际应用及其对模型大小和性能的影响。
·发布于数据科学前沿 ·阅读时间 14 分钟·2024 年 6 月 7 日
--

图片来自作者
📚什么是深度学习中的量化?
为什么我们需要量化?
让我们来谈谈深度学习中的量化。你是否曾经想过,为什么量化在深度学习中如此重要?尽管深度学习和大型语言模型(LLMs)非常强大,但它们也面临许多挑战。由于这些模型非常庞大,它们的计算需求也相当高——需要大量的计算能力和内存,这使得在资源有限的地方使用它们变得十分困难。此外,在进行预测时,它们甚至可能消耗大量能源,这就导致如果计算资源有限,推理变得不可能。
量化通过调整模型大小,使其更加易于管理,同时几乎不影响其性能,帮助解决这些问题。这涉及到修改模型参数的数量和数据类型的精度。通过这种方式,模型变得更轻便、更快速,这意味着它们可以在更多地方运行并消耗更少的能源。
优化云端 Spot 市场中的 AI 开发实例类型选择
深度学习实例选择——第二部分
·发布于Towards Data Science ·9 分钟阅读·2024 年 1 月 22 日
--

图片来源:Mike Enerio 来自Unsplash
本文由Tomer Berkovich、Yitzhak Levi和Max Rabin合作撰写。
对于机器学习(ML)工作负载,选择合适的实例是一个重要的决策,可能对开发的速度和成本产生重大影响。在上一篇文章中,我们详细阐述了这一过程,提出了一个用于做出这一重要决策的度量标准,并强调了在做决策时应考虑的许多因素。在本文中,我们将展示通过在选择云端实例时考虑Spot 实例的可用性来降低 AI 模型训练成本的机会。
使用 Spot 实例降低成本
在云计算中,最重要的节省成本的机会之一就是利用低成本的Amazon EC2 Spot 实例。Spot 实例是来自多余云服务容量的折扣计算引擎。作为交换,AWS 保留在几乎没有预警的情况下中断实例的权利。因此,Spot 实例的使用仅适用于容错性强的工作负载。幸运的是,通过有效使用模型检查点,可以设计出具有容错能力并能够利用 Spot 实例的机器学习训练工作负载。事实上,Amazon SageMaker,AWS 的机器学习开发管理服务,通过管理完整的 Spot 生命周期使得在 Spot 实例上进行训练变得更加简单。
预测 Spot 实例容量的挑战
不幸的是,Spot 实例容量(用于衡量 Spot 实例的可用性)会不断波动,且很难预测。Amazon 提供了部分帮助,通过Spot placement score(SPS)功能评估所选实例类型的Spot 实例容量,该功能可以指示 Spot 请求在特定区域或可用区(AZ)成功的可能性。当你有自由选择在多个不同位置训练模型时,这个功能尤其有用。然而,SPS 功能并不提供任何保证。
当你选择在一个或多个 Spot 实例上训练模型时,你将面临风险,即所选实例类型可能没有任何 Spot 容量(即你的训练任务无法启动),或者更糟糕的是,你可能会进入一个反复迭代的周期,在这个周期中,训练仅仅进行了少数训练步骤,并在没有任何实质性进展的情况下被停止——这可能会增加你的训练成本而没有任何回报。
在过去几年里,Spot 实例的使用挑战在多 GPU EC2 实例类型(如 g5.12xlarge 和 p4d.24xlarge)中尤为突出。对强大训练加速器的需求大幅增加(部分受生成式 AI 领域进展推动),再加上全球供应链的中断,使得几乎不可能依赖多 GPU Spot 实例进行 ML 训练。自然的应对措施是使用更昂贵的 按需实例(OD)或 预留实例。然而,在我们 之前的文章 中,我们强调了考虑多种不同备选方案来选择实例类型的价值。在本帖中,我们将展示通过将多 GPU 按需实例替换为多个单 GPU Spot 实例所带来的潜在收益。
尽管我们的演示将使用亚马逊 Web 服务(AWS),但在其他云服务平台(CSPs)上也能得出类似的结论。请不要将我们选择的 CSP 或服务解读为一种推荐。最适合您的选项将取决于项目的具体细节。此外,请考虑到我们展示的成本节省类型可能无法在您的项目中复现,和/或我们提出的解决方案可能不适用(例如,出于本帖讨论范围之外的某些原因)。在将其应用到您的用例之前,请务必对提案的相关性和有效性进行详细评估。
当多个单 GPU 实例比单个多 GPU 实例更好时
如今,在多个 GPU 设备上并行训练 AI 模型——这一过程被称为分布式训练——已经变得很常见。如果不考虑实例定价,当您在选择一个多 GPU 实例和多个单 GPU 实例(相同类型)之间做出选择时,通常会选择多 GPU 实例。分布式训练通常需要大量的数据通信(例如,梯度共享)在 GPU 之间。单个实例中 GPU 的近距离位置有助于更高的网络带宽和更低的延迟。此外,一些多 GPU 实例还包括专用的 GPU 到 GPU 互连,可以进一步加速通信(例如,NVLink 在 p4d.24xlarge 实例上)。然而,当 Spot 容量仅限于单个 GPU 实例时,以更低的成本在多个单 GPU 实例上训练就显得更具吸引力。至少,这值得评估其节省成本的潜力。
优化多个 EC2 实例之间的数据通信
当分布式训练在多个实例上运行时,GPU 通过宿主机之间的网络相互通信。为了优化训练速度并减少网络瓶颈的可能性和/或影响,我们需要确保最小的网络延迟和最大的数据信息流量。这些因素可能会受到多个因素的影响。
实例并置
网络延迟会受到 EC2 实例相对位置的巨大影响。理想情况下,当我们请求多个云端实例时,希望它们都位于同一个物理机架上。实际上,如果没有适当的配置,它们甚至可能不在同一个城市。我们将在下面的示范中使用一个VPC 配置对象,通过编程将 Amazon SageMaker 训练任务指定为使用一个Amazon 虚拟私有云(VPC)的单一子网。这种方法将确保所有请求的训练实例都在同一个可用区(AZ)中。然而,仅仅在同一可用区并置可能不足以满足需求。此外,我们描述的方法涉及选择与特定可用区(例如,具有最高Spot 放置分数)关联的子网。理想的 API 应能在任何具有足够容量的可用区内满足请求。
控制实例位置的更好方法是将它们启动在一个放置组中,特别是一个集群放置组。这样不仅能保证所有实例都在同一个可用区,还会将它们放置在“网络的同一高带宽分割段”上,从而最大化它们之间网络流量的性能。然而,截至本文撰写时,SageMaker 不提供指定放置组的选项。为了利用放置组,我们需要使用替代的训练服务解决方案(如下所示)。
EC2 网络带宽约束
确保考虑你选择的 EC2 实例所支持的最大网络带宽。特别注意,单 GPU 机器的网络带宽通常被描述为“最多”某个 Gbps 值。确保理解这意味着什么,以及它如何影响训练过程中的速度。
请记住,GPU 之间的数据通信(例如,梯度共享)可能需要与其他通过网络流动的数据共享有限的网络带宽,例如流入训练实例的训练样本或上传到持久存储的训练产物。考虑减少每类数据负载的方式,以最小化网络瓶颈的可能性。
弹性网络适配器(EFA)
越来越多的 EC2 实例类型支持 弹性网络适配器(EFA),这是一种专用网络接口,用于优化节点间的通信。使用 EFA 可以对训练工作负载的运行时性能产生决定性影响。请注意,EFA 网络通道的带宽与标准网络的带宽不同。在写作时,EFA 功能的详细文档很难获得,通常最好通过试验和错误来评估其影响。在相关情况下,考虑使用支持 EFA 类型的EC2 实例。
玩具示例
我们现在将演示在四个单 GPU EC2 g5 Spot 实例(ml.g5.2xlarge 和 ml.g5.4xlarge)与一个单四 GPU 按需实例(ml.g5.12xlarge)上的训练性能对比。我们将使用下面的训练脚本,该脚本包含一个基于 Vision Transformer(ViT)的分类模型(在合成数据上进行训练)。
import os, torch, time
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from timm.models.vision_transformer import VisionTransformer
batch_size = 128
log_interval = 10
# use random data
class FakeDataset(Dataset):
def __len__(self):
return 1000000
def __getitem__(self, index):
rand_image = torch.randn([3, 224, 224], dtype=torch.float32)
label = torch.tensor(data=[index % 1000], dtype=torch.int64)
return rand_image, label
def mp_fn():
local_rank = int(os.environ['LOCAL_RANK'])
dist.init_process_group("nccl")
torch.cuda.set_device(local_rank)
# model definition
model = VisionTransformer()
loss_fn = torch.nn.CrossEntropyLoss()
model.to(torch.cuda.current_device())
model = DDP(model)
optimizer = torch.optim.Adam(params=model.parameters())
# dataset definition
num_workers = os.cpu_count()//int(os.environ['LOCAL_WORLD_SIZE'])
dl = DataLoader(FakeDataset(), batch_size=batch_size, num_workers=num_workers)
model.train()
t0 = time.perf_counter()
for batch_idx, (x, y) in enumerate(dl, start=1):
optimizer.zero_grad(set_to_none=True)
x = x.to(torch.cuda.current_device())
y = torch.squeeze(y.to(torch.cuda.current_device()), -1)
with autocast(enabled=True, dtype=torch.bfloat16):
outputs = model(x)
loss = loss_fn(outputs, y)
loss.backward()
optimizer.step()
if batch_idx % log_interval == 0 and local_rank == 0:
time_passed = time.perf_counter() - t0
samples_processed = dist.get_world_size() * batch_size * log_interval
print(f'{samples_processed / time_passed} samples/second')
t0 = time.perf_counter()
if __name__ == '__main__':
mp_fn()
以下代码块演示了我们如何使用 SageMaker Python 包(版本 2.203.1)运行实验。请注意,对于四实例的实验,我们配置了使用一个具有单一子网的 VPC,如上所述。
from sagemaker.pytorch import PyTorch
# Toggle flag to switch between multiple single-GPU nodes and
# single multi-GPU node
multi_inst = False
inst_count=1
inst_type='ml.g5.12xlarge'
use_spot_instances=False
max_wait=None #max seconds to wait for Spot job to complete
subnets=None
security_group_ids=None
if multi_inst:
inst_count=4
inst_type='ml.g5.4xlarge' # optinally change to ml.g5.2xlarge
use_spot_instances=True
max_wait=24*60*60 #24 hours
# configure vpc settings
subnets=['<VPC subnet>']
security_group_ids=['<Security Group>']
estimator = PyTorch(
role='<sagemaker role>',
entry_point='train.py',
source_dir='<path to source dir>',
instance_type=inst_type,
instance_count=inst_count,
framework_version='2.1.0',
py_version='py310',
distribution={'torch_distributed': {'enabled': True}},
subnets=subnets,
security_group_ids=security_group_ids,
use_spot_instances=use_spot_instances,
max_wait=max_wait
)
# start job
estimator.fit()
请注意,我们的代码依赖于第三方timmPython 包,且我们在源代码目录根目录中的 requirements.txt 文件中指定了该包。这假设 VPC 已配置为启用互联网访问。另外,您可以定义一个私有的 PyPI 服务器(如此处所述),或者创建一个预先安装了第三方依赖项的自定义镜像(如此处所述)。
结果
我们在下表中总结了实验结果。按需定价来自于SageMaker 定价页面(截至本文写作时,2024 年 1 月)。竞价节省的值是从已完成任务的托管竞价训练节省报告中收集的。请参阅EC2 竞价定价文档以了解报告的竞价节省是如何计算的。

实验结果(作者)
我们的结果清楚地表明,当使用四个单 GPU 竞价实例而不是一个四 GPU 按需实例时,节省的潜力是相当大的。进一步的结果表明,尽管按需 g5.4xlarge 实例的成本较高,但由于增加的 CPU 功率和/或网络带宽,以及更高的竞价节省,最终导致了更大的节省。
重要的是,请记住,相对性能结果可能会根据您的任务细节以及实验运行时的竞价价格而大幅变化。
使用集群放置组强制执行 EC2 实例协同定位
在上一篇文章中,我们描述了如何在未管理的服务之上创建一个自定义的管理环境,例如Amazon EC2。其中列出的一个动机因素是希望在多实例设置中对设备放置进行更大的控制,例如,使用集群放置组,如上所述。在本节中,我们展示了如何使用集群放置组创建多节点设置。
我们的代码假设存在一个默认 VPC,并且(一次性)创建了一个集群放置组,这里演示使用了AWS Python SDK(版本 1.34.23):
import boto3
ec2 = boto3.client('ec2')
ec2.create_placement_group(
GroupName='cluster-placement-group',
Strategy='cluster'
)
在下面的代码块中,我们使用AWS Python SDK来启动我们的 Spot 实例:
import boto3
ec2 = boto3.resource('ec2')
instances = ec2.create_instances(
MaxCount=4,
MinCount=4,
ImageId='ami-0240b7264c1c9e6a9', # replace with image of choice
InstanceType='g5.4xlarge',
Placement={'GroupName':'cluster-placement-group'},
InstanceMarketOptions={
'MarketType': 'spot',
'SpotOptions': {
"SpotInstanceType": "one-time",
"InstanceInterruptionBehavior": "terminate"
}
},
)
请参阅我们的上一篇文章,了解如何一步步将其扩展为自动化训练解决方案的提示。
摘要
在这篇文章中,我们演示了如何通过选择灵活的训练实例类型来提高利用 Spot 实例容量的能力,从而降低整体训练成本。
随着 AI 模型规模的不断扩大以及 AI 训练加速器成本的不断上涨,探索减少训练开销的方式变得愈加重要。这里介绍的技术只是优化成本性能的几种方法之一。我们鼓励你查阅我们的上一篇文章,以获取更多关于该领域其他机会的见解。
使用强化学习优化库存管理:一个实用的 Python 指南
一份关于如何在 Python 中应用 Q 学习方法以优化库存管理和降低成本的完整指南
·发布于 Towards Data Science ·13 分钟阅读·2024 年 10 月 3 日
--

库存管理 — 我们解决的是什么问题?
假设你在管理一家自行车店。每天,你需要决定从供应商那里订购多少辆自行车。如果你订购太多,就会产生高额的库存持有成本(存储自行车的费用);如果订购太少,你可能会错失潜在的销售机会。在这里,挑战在于制定一个(订购)策略,能够最佳地平衡这些权衡。库存管理在多个行业中至关重要,其目标是确定最佳的定期订货数量,以最大化盈利。
为什么选择强化学习来进行库存管理?
之前,我们讨论过使用动态规划(DP)和马尔可夫决策过程(MDP)来解决这个问题 这里。然而,DP 方法需要一个完整的环境模型(在这种情况下,我们需要知道需求的概率分布),而这一点可能并不总是可用或实际可行的。
这里介绍了强化学习(RL)方法,它通过遵循“数据驱动”的方法克服了这个挑战。
目标是构建一个“数据驱动”的代理,通过与环境(不确定性)互动来学习最佳策略(订购多少)。RL 方法去除了对环境模型先验知识的需求。本文探索了 RL 方法,特别是 Q 学习,来寻找最优的库存策略。
如何框定库存管理问题?
在深入了解 Q 学习方法之前,了解库存管理问题的基础非常重要。从本质上讲,库存管理是一个顺序决策问题,今天做出的决策会影响明天的结果和可用的选择。让我们分解这个问题的关键要素:状态,不确定性和重复决策。
状态:当前的情况是什么?
在自行车店的背景下,状态代表当前关于库存的情况。它由两个关键组成部分定义:
α(Alpha):你目前店内的自行车数量。(称为现有库存)
β(Beta):你昨天订购的自行车,预计明天早上到达(36 小时交货时间)。这些自行车仍在运输途中。(称为在途库存)
一起,(α,β) 形成了状态,提供了在任何给定时刻你库存状态的快照。
不确定性:可能会发生什么?
该问题中的不确定性来源于每天自行车需求的随机性。你无法确切知道有多少客户会进店并要求自行车,这使得预测确切需求变得具有挑战性。
决策:你每天应该订购多少辆自行车?
作为自行车店的老板,你每天都面临一个重复的决策:你应该从供应商那里订购多少辆自行车?你的决策需要考虑到当前库存的状态(α,β),以及第二天客户需求的不确定性。
管理自行车店库存的典型 24 小时周期如下所示:
6 PM: 观察当前库存的状态 St:(α,β)。(状态)
6 PM: 做出决定,确定要订购多少辆新自行车。(决策)
6 AM: 收到 36 小时前订购的自行车。
8 AM: 开店迎接顾客。
8 AM — 6 PM: 一整天感受客户需求。(不确定性)
6 PM: 关店并准备进入下一个周期。
下图展示了库存管理过程的图示:

一个典型的 24 小时库存管理周期 — 图片来源:作者
什么是强化学习?
强化学习(RL)是一种数据驱动的方法,侧重于学习如何通过一系列决策(遵循策略)来最大化累积奖励。它类似于人类和动物通过试错来学习采取什么行动。在库存管理的背景下,RL 可用于学习优化的订货策略,从而最小化库存管理的总成本。
强化学习方法的关键组件是:
代理商:与环境交互的决策者。
环境:代理商与之交互的外部系统。在此案例中,环境是随机的客户需求。
状态:环境的当前情况或快照。
动作:代理商做出的决策或选择。
奖励:反馈信号,告诉代理商其表现如何。
代理商(决策者)的目标是学习最优策略,该策略是从状态到动作的映射,可以最大化累计奖励。
在库存管理的背景下,策略告诉代理商根据当前库存状态和客户需求的不确定性每天应该订购多少辆自行车。
实现库存优化问题的强化学习
Q 学习是一种无模型的强化学习算法,它学习在任何给定状态下选择最优动作的策略。与需要完整环境模型的动态规划方法不同,Q 学习通过与环境的交互(这里是指不确定性和它获得的奖励)直接学习,更新 Q 表。
Q 学习的关键组件
在我们的案例中,代理商是决策者(自行车店老板),而环境是客户需求。状态由当前库存水平(alpha, beta)表示,动作是订购多少辆自行车。奖励是与持有库存和错失销售相关的成本。Q 表是一个存储每个状态-动作对的预期未来奖励的表格。
Q 表的初始化
在本工作中,Q 表被初始化为名为 Q 的字典。状态通过元组(alpha, beta)表示,其中:alpha 是库存中的物品数量(现有库存)。beta 是已订购物品的数量(待订购库存)。
动作是每个状态下可以采取的可能库存订购数量。对于每个状态(alpha, beta),可能的动作取决于库存中剩余空间的大小(剩余容量 = 库存容量 — (alpha + beta))。限制条件是订购的物品数量不能超过库存的剩余容量。
Q 值的示意设计如下图所示:

Q 字典的示意设计如下所示 — 图片来源:作者
Q 字典可以初始化为:
def initialize_Q(self):
# Initialize the Q-table as a dictionary
Q = {}
for alpha in range(self.user_capacity + 1):
for beta in range(self.user_capacity + 1 - alpha):
state = (alpha, beta)
Q[state] = {}
max_action = self.user_capacity - (alpha + beta)
for action in range(max_action + 1):
Q[state][action] = np.random.uniform(0, 1) # Small random values
return Q
正如上述代码所示,Q 值(Q[state][action])被初始化为较小的随机值,以鼓励探索。
Q 学习算法
Q 学习方法根据来自环境的奖励(在此是与环境的交互)更新状态-动作对的表格。以下是该算法的三个步骤:

Q 学习方程 — 图片来源:作者
其中,s 是当前状态,a 是所采取的行动,s' 是下一个状态,( α ) 是学习率,( γ ) 是折扣因子。
我们将公式拆解,并在下面分成三部分重新写出:

Q 学习方程 — 图片来源:作者
上述公式的 Python 代码转换如下:
def update_Q(self, state, action, reward, next_state):
# Update the Q-table based on the state, action, reward, and next state
best_next_action = max(self.Q[next_state], key=self.Q[next_state].get)
# reward + gamma * Q[next_state][best_next_action]
td_target = reward + self.gamma * self.Q[next_state][best_next_action]
# td_target - Q[state][action]
td_error = td_target - self.Q[state][action]
# Q[state][action] = Q[state][action] + alpha * td_error
self.Q[state][action] += self.alpha * td_error
在上述函数中,每一行的等效方程已经作为注释显示在每一行上方。
在 Q 学习中模拟过渡和奖励以进行库存优化
当前状态由一个元组 (alpha, beta) 表示,其中:alpha 是当前的现有库存(库存中的物品数量),beta 是当前的在订库存(已订购但尚未收到的物品数量),init_inv 通过将 alpha 和 beta 相加来计算总的初始库存。
接下来,我们需要使用泊松分布模拟顾客需求,lambda 值为“self.poisson_lambda”。这里,需求表现出顾客需求的随机性:
alpha, beta = state
init_inv = alpha + beta
demand = np.random.poisson(self.poisson_lambda)
注意:泊松分布被用来模拟需求,这是建模随机事件(如顾客到达)时的常见选择。然而,我们可以使用历史需求数据或与环境的实时交互来训练模型。从本质上讲,强化学习就是从数据中学习,它不需要先验的模型知识。
现在,"下一个 alpha"(现有库存)可以表示为 max(0, init_inv - demand)。这意味着如果需求大于初始库存,那么新的 alpha 将为零;如果不大于,则为 init_inv - demand。
成本分为两部分。持有成本:通过将商店中的自行车数量与单位持有成本相乘来计算。然后,我们有另一个成本,即缺货成本。这是我们需要支付的错失需求的成本。这两部分构成了我们尝试通过强化学习方法最大化的“奖励”。(更好的表述是我们希望最小化成本,因此我们最大化奖励)。
new_alpha = max(0, init_inv - demand)
holding_cost = -new_alpha * self.holding_cost
stockout_cost = 0
if demand > init_inv:
stockout_cost = -(demand - init_inv) * self.stockout_cost
reward = holding_cost + stockout_cost
next_state = (new_alpha, action)
Q 学习中的探索 — 利用
在 Q 学习方法中选择行动涉及一定程度的探索,以便全面了解 Q 表中所有状态的 Q 值。为此,在每次选择行动时,有一个 epsilon 的概率我们采取探索方法,并“随机”选择一个行动,而在 1-ϵ 的概率下,我们从 Q 表中选择最优的行动。
def choose_action(self, state):
# Epsilon-greedy action selection
if np.random.rand() < self.epsilon:
return np.random.choice(self.user_capacity - (state[0] + state[1]) + 1)
else:
return max(self.Q[state], key=self.Q[state].get)
训练 RL 代理
RL 代理的训练通过“train”函数完成,具体步骤如下:首先,我们需要初始化 Q(空字典结构)。然后,在每一批次中收集经验(self.batch.append((state, action, reward, next_state))),并在每批次结束时更新 Q 表(self.update_Q(self.batch))。每批次的最大回合数被限制为“max_actions_per_episode”。回合数是指代理与环境互动以学习最优策略的次数。
每个回合从随机分配的状态开始,当动作数量低于 max_actions_per_episode 时,批次的数据收集继续进行。
def train(self):
self.Q = self.initialize_Q() # Reinitialize Q-table for each training run
for episode in range(self.episodes):
alpha_0 = random.randint(0, self.user_capacity)
beta_0 = random.randint(0, self.user_capacity - alpha_0)
state = (alpha_0, beta_0)
#total_reward = 0
self.batch = [] # Reset the batch at the start of each episode
action_taken = 0
while action_taken < self.max_actions_per_episode:
action = self.choose_action(state)
next_state, reward = self.simulate_transition_and_reward(state, action)
self.batch.append((state, action, reward, next_state)) # Collect experience
state = next_state
action_taken += 1
self.update_Q(self.batch) # Update Q-table using the batch
示例案例和结果
这个示例案例展示了如何将上述所有代码组合在一起,并查看 Q 学习智能体如何学习库存管理的最优策略。在这里,user_capicty(存储容量)为 10,表示库存可以容纳的物品总数(容量)。然后,poisson_lambda是需求分布中的λ值,其值为 4。持有成本为 8,表示每晚将一个物品保留在库存中的成本,而缺货成本则是失去需求的成本(假设当天有顾客需要该物品,而你却没有该物品在库存中)为 10。gamma值小于 1,用于在方程中折扣未来奖励(0.9),其中alpha(学习率)为 0.1。epsilon项用于控制探索-开发的困境。回合数为 1000,每个批次包含 1000 个(每回合的最大动作数)。
# Example usage:
user_capacity = 10
poisson_lambda = 4
holding_cost = 8
stockout_cost = 10
gamma = 0.9
alpha = 0.1
epsilon = 0.1
episodes = 1000
max_actions_per_episode = 1000
定义了这些初始参数后,我们可以定义 ql Python 类,然后使用该类进行训练,再通过模块“get_optimal_policy()”获取最优策略。
# Define the Class
ql = QLearningInventory(user_capacity, poisson_lambda, holding_cost, stockout_cost, gamma,
alpha, epsilon, episodes, max_actions_per_episode)
# Train Agent
ql.train()
# Get the Optimal Policy
optimal_policy = ql.get_optimal_policy()
结果
现在我们已经得到了通过 Q 学习方法找到的策略,我们可以将结果可视化并查看其表现。x 轴是状态,是(alpha,beta)的元组,y 轴是 Q 学习在每个状态下找到的“订单数量”。

Q 学习策略下每个状态(x 轴)对应的订单数量(y 轴) — 图片来源:作者
通过查看图表,可以得到几个启示。首先,当我们向右移动时,可以看到订单数量减少。当我们向右移动时,alpha 值增加(现有库存),这意味着我们需要“订购”更少,因为现有的库存可以满足需求。其次,当 alpha 保持不变时,随着 beta 的增加,我们减少了新地点的订单数量。这可以理解为,当“我们有更多的物品在‘订购中’时”,我们不需要增加订单。
将 Q 学习策略与基准策略进行比较
现在我们使用 Q 学习来找到策略(在给定状态下订购多少物品),我们可以将其与基准策略(一个简单的策略)进行比较。基准策略就是“按政策订购”,这意味着你查看现有库存和正在订购的库存,然后订购到“满足目标水平”。我们可以在这里编写简单的 Python 代码来实现这个策略:
# Create a simple policy
def order_up_to_policy(state, user_capacity, target_level):
alpha, beta = state
max_possible_order = user_capacity - (alpha + beta)
desired_order = max(0, target_level - (alpha + beta))
return min(max_possible_order, desired_order)
在代码中,target_level是我们希望订购的库存目标值。如果 target_level = user_capacity,那么我们只是为了填满库存。首先,我们可以比较这些不同方法的策略。对于每个状态,按照简单策略和 Q-learning 策略,订购的“数量”会是多少?在下图中,我们绘制了两种策略的比较。

比较 Q-Learning 和简单策略之间的订购策略,针对每个状态 — 图片来源:作者
简单策略只是按照一定的顺序进行订购,以满足库存需求,而 Q-learning 策略的订购量通常低于简单策略的订购量。
这可以归因于“poisson_lambda”在这里是 4,意味着需求远低于库存容量=10,因此订购“高数量的自行车”并不是最优选择,因为它有很高的持有成本。
我们还可以比较在应用这两种策略时,你能够获得的总累计奖励。为此,我们可以使用test_policy函数,该函数是“QLearningInventory”中特别设计用来评估策略的:
def test_policy(self, policy, episodes):
"""
Test a given policy on the environment and calculate the total reward.
Args:
policy (dict): A dictionary mapping states to actions.
episodes (int): The number of episodes to simulate.
Returns:
float: The total reward accumulated over all episodes.
"""
total_reward = 0
alpha_0 = random.randint(0, self.user_capacity)
beta_0 = random.randint(0, self.user_capacity - alpha_0)
state = (alpha_0, beta_0) # Initialize the state
for _ in range(episodes):
action = policy.get(state, 0)
next_state, reward = self.simulate_transition_and_reward(state, action)
total_reward += reward
state = next_state
return total_reward
该函数的工作方式是,它从一个新的状态(state = (alpha_0, beta_0))开始,然后根据该状态从策略中获得动作(订购数量),执行操作并查看奖励和下一个状态,过程继续进行,直到达到总的回合数,同时收集总奖励。

管理库存的总成本,遵循 Q-Learning 策略和简单策略 — 图片来源:作者
上面的图表比较了遵循“Q-Learning”和“简单策略”时,管理库存的总成本。目标是最小化运行库存的成本。由于我们模型中的“奖励”代表了这个成本,因此我们将总成本设置为-总奖励。
使用 Q-Learning 策略运行库存将导致比简单策略更低的成本。
GitHub 中的代码
本博客的完整代码可以在 GitHub 仓库中找到,链接:这里。
总结和主要收获
在这篇文章中,我们讨论了如何使用强化学习(特别是 Q-Learning)来优化库存管理。我们开发了一个 Q-learning 算法,通过与环境(不确定性)的互动来学习最优的订购策略。在这里,环境是客户的“随机”需求(自行车买家),状态是当前的库存状态(alpha, beta)。Q-learning 算法能够学习到最优策略,从而最小化库存管理的总成本。
主要收获
-
Q-Learning:Q-Learning 是一种无模型的强化学习算法,可以在不需要完整环境模型的情况下找到最优库存策略。
-
状态表示:库存管理中的状态由当前手头库存和订购库存表示,状态 = (α, β)。
-
成本降低:我们可以看到,相比简单的按容量订购的策略,Q-learning 策略能够带来更低的成本。
-
灵活性:Q-learning 方法非常灵活,可以应用于我们有历史需求数据的情况,或者我们可以与环境互动,学习最优策略。
-
数据驱动决策:正如我们所展示的,强化学习(RL)方法不需要任何关于环境模型的先验知识,因为它是从数据中学习的。
参考文献
[1] A. Rao, T. Jelvis,《强化学习基础与金融应用》(2022)。
[2] S. Sutton, A. Barto,《强化学习:导论》(2018)。
[3] W. B. Powell,《顺序决策分析与建模:用 Python 建模》(2022)。
[4] R. B. Bratvold,《做出正确决策》(2010)。
在你离开之前!🦸🏻♀️
优化营销活动:使用预算化的多臂赌博机
通过演示、我们的新解决方案和一段视频
·发表于 Towards Data Science ·9 分钟阅读·2024 年 8 月 16 日
--

图片由作者使用 GPT-4o 创建
让我们直接进入一个实际的例子。假设一家银行或电信公司为现有客户推出一款新产品/计划。为了推广这款产品,公司为其销售代表创建了多个呼叫模板(脚本)。目标是有效地说服客户购买新产品或注册新计划。
以下是该活动的运作方式:
-
呼叫脚本创建: 市场团队开发了多个版本的呼叫脚本,每个脚本采用不同的方法来推广新产品或新计划。
-
代理呼叫: 销售代理使用这些脚本来呼叫一部分客户。每次客户互动时,都会使用预定义的脚本之一。
-
数据收集: 在呼叫过程中,公司收集客户的回应数据,如表现出的兴趣、提出的问题,以及最终的转化率(即有多少客户购买新产品或注册新计划)。
-
实时分析: 公司实时分析这些数据,以评估每个脚本的有效性。这一分析有助于确定哪些脚本在将客户转化为新计划/产品方面更成功。
-
策略更新: 基于持续分析,公司动态调整每个脚本的使用频率。转化率较高的脚本会使用得更频繁,确保随着时间推移活动变得越来越有效。
接下来,我们展示如何使用传统的多臂老丨虎丨机问题来建模这个简单版本的活动。随着我们加入更多细节使得该模型更具现实性,我们展示了现有解决方案及其简单适配的不足之处。然后,我们提出了一种新的预算化多臂老丨虎丨机算法,该算法来自我们为 KDD 2024 会议提交并接受的论文,在这项任务中表现非常出色。我们还提供了代码链接和一段简短的视频,总结了这篇论文。
在这个故事中,使用“我们”是因为我和 Marco Heyden 一起撰写了这篇文章(Linkedin,Github),他是算法思想和我们论文[1]的作者。所有后续的图表都是由我们创建的,并使用了这个Jupyter notebook中的代码。
多臂老丨虎丨机问题的解决方案
我们的场景类似于多臂老丨虎丨机(MAB)问题。假设一个玩家在赌场里面对一台带有多个臂的老丨虎丨机(“老丨虎丨机”),每个臂的回报(奖励)分布未知。玩家的目标是通过决定玩哪些臂以及每个臂玩多少次来最大化他们的总奖金。挑战在于平衡探索(尝试不同的臂以收集有关它们奖励的信息)和利用(使用收集到的信息来玩具有最高已知奖励的臂)。
在我们的示例中,每个电话脚本就像一只老丨虎丨机的机械臂,其中奖励是脚本的成功。如果客户注册了新计划或购买了新产品,奖励为 1,否则为 0。例如,三个电话脚本的转化率分别为0.1、0.3和0.7,其成功率遵循伯努利分布,期望值分别为0.1、0.3和0.7。下图展示了不同策略的累计奖励。紫色线代表使用脚本 1,转化率为0.1,而绿色线代表使用脚本 3,转化率为0.7。这些线定义了可能奖励的范围。浅蓝色线显示了每次呼叫随机选择一个脚本的策略的累计奖励。在一个现实的环境中,只有转化率的估计值可用,优秀策略的累计奖励应该接近绿色线,并且至少高于浅蓝色线。

解决多臂老丨虎丨机问题的一个流行策略是上置信界限(UCB)算法[2]。它为每个臂(电话脚本)的期望奖励分配一个上置信界限,并选择具有最高上置信界限的臂进行执行。通过这种方式,算法在利用已知高奖励的同时,积极探索不确定性高的动作。从数学上讲,UCB 算法选择臂i,其解为:

-
rᵢ(t) 是手臂 i 在时间 t 时的经验平均奖励。
-
Nᵢ(t) 是手臂 i 到时间 t 时已被拉动的次数。
-
t 是到目前为止总共进行的游戏次数。
下图中的白线展示了这个策略在我们示例中的应用。

这个上界是基于 Chernoff-Hoeffding 上界,假设收益分布支持区间为 [0,1],这正是我们的情况。对于支持区间不同的奖励分布 [aᵢʳ, bᵢʳ],其中 aᵢʳ 和 aᵢʳ 是有限的,UCB 应相应地进行缩放:

预算的重要性
到目前为止,我们关注的是在给定次数的调用后最大化奖励总和。然而,期望不同脚本的调用具有相同的持续时间是不现实的。如果市场营销团队的能力不足以在给定时间预算内(例如几个月)接触到所有客户,那么更实际的做法是最大化给定通话持续时间内的累积奖励,而不是调用次数。
在我们的示例中,假设 脚本 1、脚本 2 和 脚本 3 的通话持续时间(即成本)是常数(我们稍后会放宽这个假设),分别为 1、2 和 8 分钟。如果我们现在根据总通话持续时间而不是通话次数来绘制结果,那么始终使用 脚本 2 的策略将成为最佳选择,而使用 脚本 3 的策略将成为最差选择。仅考虑转化率的上述 UCB 策略现在表现远不如随机策略。

通过用 rᵢ(t) 与 cᵢ* 对应值标准化奖励估计,并按照上述公式(2)调整 UCB 策略,可以不难修正 UCB 策略:

通过这种方式更新的 UCB 策略再次表现得相当好:

随机通话持续时间
请注意,假设固定通话时长也是不现实的。当奖励和成本都不固定时,有几种方法可以将 UCB 策略扩展到这种情况,例如:
约简型: 通过将奖励与成本比率 vᵢ=rᵢ/cᵢ 视为单一随机变量,并拉动具有最高上界 UCBᵢᵛ 的手臂:

联合型: 通过忽略成本的变化,使用它们的估计值 cᵢ(t) 来缩放奖励的 UCBᵢʳ,类似于公式(3):

组合型: 通过拉动最大化上界奖励与下界成本比率 UCBᵢʳ/LCBᵢᶜ 的手臂:

在(6)中,我们假设奖励来自 [0,1] 区间,为了简化公式。
上述所有策略都有问题,要么过于乐观,要么过于悲观,要么仅仅优化了错误的量。
简约策略(4)旨在最大化奖励-成本比的期望值 𝔼(vᵢ)=𝔼(rᵢ/cᵢ)。这与最大化奖励同时保持成本在给定预算范围内的目标不同。对于足够高的预算,后者等同于最大化奖励和成本期望值的比率 𝔼(rᵢ)/𝔼(cᵢ)。要理解为何 𝔼(rᵢ/cᵢ)≠𝔼(rᵢ)/𝔼(cᵢ),请注意,如果rᵢ和cᵢ是独立同分布的伯努利随机变量,则 𝔼(rᵢ)/𝔼(cᵢ)=1,而 𝔼(rᵢ/cᵢ) 是无穷大或未定义,具体取决于如何处理除以零的情况。在这种情况下,vᵢ=rᵢ/cᵢ*的支持也是无限的,导致 UCB 公式(4)无效。
因为联合策略没有考虑成本变化,它通常会产生过于紧凑的上置信界限,从而限制了策略的探索部分。每当成本的经验均值超过真实均值时,例如在对称成本分布的 50%的情况下,就会发生这种情况。
复合策略明确地对成本的不确定性进行了建模。然而,这种模型过于保守,允许分母为零甚至负数(!)。因此,复合策略在探索上花费了过多资源,并低估了开发步骤的价值。
还需要注意的是,至今讨论的平均奖励的上界可能远高于 1,尽管我们知道奖励的取值范围是{0,1}。在这种情况下,这个界限显然过于宽松,这可能增加探索,从而部分抵消考虑联合策略中固定成本的效果。
以下图展示了我们示例设置中三种策略的表现,在该设置中,我们现在使用支持范围为[1,10]的贝塔分布来建模通话时长,且脚本 1、脚本 2和脚本 3的期望值分别为1、2和8分钟。简约策略的表现几乎与随机选择脚本一样差,复合策略表现稍好,而联合策略则是明显的赢家。但是,是否有方法能够超越联合策略呢?

我们的解决方案:不对称置信区间
我们的ω-UCB 策略解决了之前描述的解决方案的不足之处。虽然我们与复合策略使用相同的 UCBᵢʳ/LCBᵢᶜ比例开始,但我们采用不同的方法来计算这些界限。我们的置信区间更为精确,并且始终保持在奖励或成本的支持范围内。具体来说,设p为一个有界随机变量——可以是奖励或成本——其支持范围为[a,b]。我们按如下方式计算p的置信区间[LCBᵖ, UCBᵖ*]。



-
μ 是时间 t 时刻 p 的经验均值,
-
σ² 是时间 t 时刻 p 的经验方差,
-
N 是到时间 t 为止对 p 的观察次数,
-
t 是当前的时间,
-
c 是期望能可靠估计均值和方差所需的臂拉次数;在实际操作中,c=30 效果良好,
-
ρ 是 ω-UCB 的参数;ρ=1 提供了更好的渐近性质,但对于实际应用,当游戏次数为 ~10⁴ 或更少时,我们推荐使用 ρ=1/4。
以下图展示了 ω-UCB 的表现。使用它几乎能够获得最大可能的累积奖赏。

我们还制作了一个 2 分钟的视频,概述了 ω-UCB 的思想:
该视频由 Marco Heyden 为 KDD 2024 制作。如果你喜欢它,请随时在 YouTube 上为它点赞
最后的想法
到现在为止,你已经掌握了如何通过即时客户反馈实时优化营销活动的洞察力。我们已经描述了一种强大的算法来帮助你做到这一点。然而,单凭这一点来保证成功,未免过于乐观。下面,我们概述了一些额外的考虑因素,这些因素可以进一步增强活动的效果。
首先,奖赏不太可能立即得知。通常,最好的期望是来自客户的兴趣信号。因此,构建一个可靠的奖赏代理,可能通过使用来自之前活动的数据,是至关重要的。
接下来,本讨论集中在为平均或代表性客户选择最佳脚本。稍微深入一点,可能不同的脚本对不同的客户群体效果更好。最佳脚本可能因群体而异。一种简单的方法是对客户进行细分,并将每个细分-脚本组合视为我们之前描述的预算化多臂强盗算法中的一个独立臂。在之前的文章中,我讨论了识别有趣客户细分的方法;对于一个活动来说,选择适当的目标变量来应用该方法是非常重要的。
患者规则归纳法发现了比以往报告的好 35% 的片段
[towardsdatascience.com
最后,除了客户特征外,“环境”因素如一天中的时间或一周中的某天,也可能影响脚本的相对表现。为了考虑所有这些因素,你可能会考虑将方法扩展到上下文预算化强盗问题,这也是另一篇文章的主题。
参考文献
[1] Marco Heyden, Vadim Arzamasov, Edouard Fouché, 和 Klemens Böhm。“带有非对称置信区间的预算化多臂强盗问题。”KDD ’24
[2] 彼得·奥尔(Peter Auer)、尼科洛·切萨-比安基(Nicolo Cesa-Bianchi)和保罗·费舍尔(Paul Fischer)。 “多臂赌博机问题的有限时间分析。” 机器学习 47 (2002):235–256。
使用 Python 优化数据分析的内存消耗——从 400 到 0.1

作者在 Canva 中创建
降低代码的内存消耗意味着降低硬件需求
·发表于Towards Data Science ·阅读时长 9 分钟·2024 年 6 月 3 日
--
有许多文章告诉我们如何提高代码的性能。当然,性能非常关键,尤其是在我们使用 Python 进行数据分析时。
然而,我认为内存消耗同样重要,尤其是在处理大型数据集或当我们有有限的硬件资源来运行任务时,内存消耗有时甚至更为重要。
在本文中,我将介绍几种有效的技巧,用于在不降低性能的情况下减少常见数据分析活动中的内存消耗。
1. 内存消耗测量

作者在 Canva 中创建
在我分享任何减少内存消耗的技巧之前,我们需要有测量内存消耗的方法。上周,我写了一篇文章,详细介绍了几种测量内存消耗的方法。如果你感兴趣,请查看它。
实践中的多任务学习模型优化
什么是多任务学习模型,如何优化它们
·发表于Towards Data Science ·阅读时长 6 分钟·2024 年 3 月 29 日
--

为什么选择多任务学习
多任务学习
多任务学习(MTL)[1]是机器学习领域的一项技术,我们利用单一模型同时学习多个任务。

多任务学习模型(图源:作者)
从理论上讲,这种方法允许任务之间的知识共享,并且比单任务训练取得更好的结果。此外,由于模型试图学习一个表示来优化多个任务,因此过拟合的可能性较低,从而实现更好的泛化能力。
多任务学习是一种归纳迁移方法,通过使用相关任务训练信号中包含的领域信息作为归纳偏置,从而改善泛化能力。它通过在并行学习任务的同时使用共享表示来实现;每个任务所学到的知识可以帮助其他任务更好地学习。[2]
在实际应用中,大型推荐和搜索系统通常基于多个指标来衡量用户满意度,例如停留时间、点击率等…
优化 Pandas 代码:操作顺序的影响
PYTHON 编程
学习如何重排代码以实现显著的速度提升。
·发布于Towards Data Science ·阅读时间 9 分钟·2024 年 3 月 18 日
--

图片来自Nick Fewings提供,来源于Unsplash
Pandas 提供了一个出色的数据框操作框架。在数据科学中,我们处理的可能是小型、大型——有时是非常大的数据框。分析小型数据框可能非常迅速,但即使对一个大型数据框进行单一操作,也可能需要相当长的时间。
在本文中,我将展示如何通过一个几乎不需要成本的操作:调整数据框的操作顺序,来缩短这段时间。
假设我们有如下的数据框(dataframe):
import pandas as pd
n = 1_000_000
df = pd.DataFrame({
letter: list(range(n))
for letter in "abcdefghijklmnopqrstuwxyz"
})
拥有一百万行和 25 列的数据框非常庞大。在当前的个人计算机上,对这样的数据框进行多次操作会显著影响性能。
假设我们想要筛选行,满足以下条件:a < 50_000 且 b > 3000,并选择五个列:take_cols=['a', 'b', 'g', 'n', 'x']。我们可以通过以下方式实现:
subdf = df[take_cols]
subdf = subdf[subdf['a'] < 50_000]
subdf = subdf[subdf['b'] > 3000]
使用 Aho-Corasick 算法优化 Spark 中的 Sigma 规则
扩展 Spark 以提高处理多个搜索词的性能
·发布于 Towards Data Science ·8 分钟阅读·2024 年 6 月 20 日
--

图片来自 Unsplash 上的 Aditya Chinchure
在将我们的入侵检测系统部署到CCCS的生产环境时,我们观察到许多 SigmaHQ 规则使用了非常庞大的搜索模式列表。这些列表用于测试CommandLine是否包含特定字符串,或者CommandLine是否以某个子字符串开始或结束。
我们特别感兴趣的是研究涉及“包含”条件的规则,因为我们怀疑这些条件可能会让 Spark 的评估变得耗时。以下是一个典型的 Sigma 规则示例:
detection:
selection_image:
- Image|contains:
- '\CVE-202'
- '\CVE202'
- Image|endswith:
- '\poc.exe'
- '\artifact.exe'
- '\artifact64.exe'
- '\artifact_protected.exe'
- '\artifact32.exe'
- '\artifact32big.exe'
- 'obfuscated.exe'
- 'obfusc.exe'
- '\meterpreter'
selection_commandline:
CommandLine|contains:
- 'inject.ps1'
- 'Invoke-CVE'
- 'pupy.ps1'
- 'payload.ps1'
- 'beacon.ps1'
- 'PowerView.ps1'
- 'bypass.ps1'
- 'obfuscated.ps1'
完整的可疑程序名称规则可以在这里找到
该规则说明了如何使用CommandLine|contains和Image|endswith。一些 Sigma 规则在<field>|contains条件下有成百上千的搜索词。
在 Spark SQL 中应用 Sigma 规则
在CCCS我们将 Sigma 规则转换为可执行的 Spark SQL 语句。为此,我们扩展了 SQL Sigma 编译器,并加入了一个自定义后端。它将上述规则转换成如下语句:
select
map(
'Suspicious Program Names',
(
(
(
Imagepath LIKE '%\\cve-202%'
OR Imagepath LIKE '%\\cve202%'
)
OR (
Imagepath LIKE '%\\poc.exe'
OR Imagepath LIKE '%\\artifact.exe'
...
OR Imagepath LIKE '%obfusc.exe'
OR Imagepath LIKE '%\\meterpreter'
)
)
OR (
CommandLine LIKE '%inject.ps1%'
OR CommandLine LIKE '%invoke-cve%'
OR CommandLine LIKE '%pupy.ps1%'
...
OR CommandLine LIKE '%encode.ps1%'
OR CommandLine LIKE '%powercat.ps1%'
)
)
) as sigma_rules_map
我们在一个 Spark 结构化流处理作业中运行上述语句。在对事件进行单次扫描时,Spark 评估了多个(数百个)Sigma 规则。sigma_rules_map列保存了所有这些规则的评估结果。通过这个映射,我们可以确定哪些规则是命中的,哪些不是。
正如我们所看到的,这些规则通常涉及将事件的属性(如CommandLine)与多个字符串模式进行比较。
其中一些测试是精确匹配的,如CommandLine = ‘something’。其他则使用startswith,并呈现为Imagepath LIKE ‘%\\poc.exe’。
Equals、startswith和endswith执行得非常快,因为这些条件都在事件的某个特定位置上锚定。
然而,像contains这样的测试呈现为CommandLine LIKE ‘%hound.ps1%’,这需要 Spark 扫描整个属性,以找到字母‘h’的可能起始位置,然后检查它后面是否跟着字母‘o’、‘u’等。
在内部,Spark 使用UTF8String,它抓取第一个字符,扫描缓冲区,如果找到匹配项,就继续使用matchAt函数比较剩余的字节。以下是UTF8String.contains函数的实现。
public boolean contains(final UTF8String substring) {
if (substring.numBytes == 0) {
return true;
}
byte first = substring.getByte(0);
for (int i = 0; i <= numBytes - substring.numBytes; i++) {
if (getByte(i) == first && matchAt(substring, i)) {
return true;
}
}
return false;
}
equals、startswith和endswith条件也使用matchAt函数,但与contains不同,这些条件知道从哪里开始比较,因此执行速度非常快。
为了验证我们关于contains条件执行成本高的假设,我们进行了一个快速且简单的实验。我们删除了所有 Sigma 规则中的contains条件,看看这会如何影响整体执行时间。结果差异显著,这鼓励我们继续推进实现自定义 Spark Catalyst 函数来处理涉及大量搜索词的contains操作的想法。
Aho-Corasick 算法
一些研究使我们找到了Aho-Corasick 算法,它似乎非常适合这个用例。Aho-Corasick 算法构建一个前缀树(字典树),并且可以在一次扫描要测试的文本时评估多个contains表达式。
这是如何使用 Robert Bor 在 GitHub 上提供的 Aho-Corasick Java 实现:github.com/robert-bor/aho-corasick
// create the trie
val triBuilder = Trie.builder()
triBuilder.addKeyword("test1")
triBuilder.addKeyword("test2")
trie = triBuilder.build()
// apply the trie to some text
aTextColumn = "some text to scan for either test1 or test2"
found = trie.containsMatch(aTextColumn)
设计一个aho_corasick_in Spark 函数
我们的函数需要两样东西:要测试的列和要查找的搜索模式。我们将实现一个具有以下签名的函数:
boolean aho_corasick_in(string text, array<string> searches)
我们修改了 CCCS Sigma 编译器,使其生成使用aho_corasick_in函数的 SQL 语句,而不是生成多个 OR 连接的 LIKE 谓词。在下面的输出中,您会注意到使用了aho_corasick_in函数。我们传递了要测试的字段和一个包含搜索词的字符串数组。以下是我们自定义编译器处理多个contains条件的输出:
select
map(
'Suspicious Program Names',
(
(
(
Imagepath LIKE '%\\cve-202%'
OR Imagepath LIKE '%\\cve202%'
)
OR (
Imagepath LIKE '%\\poc.exe'
OR Imagepath LIKE '%\\artifact.exe'
...
OR Imagepath LIKE '%\\meterpreter'
)
)
OR (
aho_corasick_in(
CommandLine,
ARRAY(
'inject.ps1',
'invoke-cve',
...
'hound.ps1',
'encode.ps1',
'powercat.ps1'
)
)
)
)
) as sigma_rules_map
注意,aho_corasick_in函数接受两个参数:第一个是列,第二个是字符串数组。现在,让我们实际实现aho_corasick_in函数。
实现 Catalyst 函数
我们没有找到很多关于如何实现 Catalyst 函数的文档,因此,我们使用了现有函数的源代码作为参考。我们以regexp(str, regexp)函数为例,因为它会预编译其正则表达式模式,然后在处理行时使用该模式。这类似于预构建 Aho-Corasick 字典树,然后将其应用于每一行。
我们的自定义 Catalyst 表达式接受两个参数。因此,它是一个BinaryExpression,有两个字段,Spark 将其命名为left和right。我们的 AhoCorasickIn 构造函数将text列参数赋值给left字段,将searches字符串数组赋值给right字段。
在初始化 AhoCorasickIn 时,我们还会评估cacheTrie字段。该评估测试searches参数是否为可折叠表达式,即常量表达式。如果是,它会进行求值,并期望得到一个字符串数组,然后使用该数组调用createTrie(searches)。
createTrie函数遍历搜索词并将它们添加到trieBuilder,最终构建出一个 Aho-Corasick 字典树。
case class AhoCorasickIn(text: Expression, searches: Expression)
extends BinaryExpression
with CodegenFallback
with ImplicitCastInputTypes
with NullIntolerant
with Predicate {
override def prettyName: String = "aho_corasick_in"
// Assign text to left field
override def left: Expression = text
// Assign searches to right field
override def right: Expression = searches
override def inputTypes: Seq[DataType] = Seq(StringType, ArrayType(StringType))
// Cache foldable searches expression when AhoCorasickIn is constructed
private lazy val cacheTrie: Trie = right match {
case p: Expression if p.foldable => {
val searches = p.eval().asInstanceOf[ArrayData]
createTrie(searches)
}
case _ => null
}
protected def createTrie(searches: ArrayData): Trie = {
val triBuilder = Trie.builder()
searches.foreach(StringType, (i, s) => triBuilder.addKeyword(s.toString()))
triBuilder.build()
}
protected def getTrie(searches: ArrayData) = if (cacheTrie == null) createTrie(searches) else cacheTrie
override protected def nullSafeEval(text: Any, searches: Any): Any = {
val trie = getTrie(searches.asInstanceOf[ArrayData])
trie.containsMatch(text.asInstanceOf[UTF8String].toString())
}
override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): AhoCorasickIn =
copy(text = newLeft, searches = newRight)
}
nullSafeEval方法是 AhoCorasickIn 的核心。Spark 会为数据集中的每一行调用 eval 函数。在nullSafeEval中,我们检索cacheTrie并使用它来测试text字符串参数。
评估性能
为了比较aho_corasick_in函数的性能,我们编写了一个小型基准测试脚本。我们比较了执行多个LIKE操作与单个aho_corasick_in调用的性能。
select
*
from (
select
text like '%' || uuid() || '%' OR
text like '%' || uuid() || '%' OR
text like '%' || uuid() || '%' OR
...
as result
from (
select
uuid()||uuid()||uuid()... as text
from
range(0, 1000000, 1, 32)
)
)
where
result = TRUE
使用aho_corasick_in的相同实验:
select
*
from (
select
aho_corasick_in(text, array(uuid(), uuid(),...) as result
from (
select
uuid()||uuid()||uuid()... as text
from
range(0, 1000000, 1, 32)
)
)
where
result = TRUE
我们进行了这两个实验(like 与 aho_corasick_in),在一个包含 200 个字符的text列上,变化了搜索词的数量。下面是一个对数图,比较了这两个查询。

作者提供的图片
该图显示了随着我们向“LIKE”查询中添加更多搜索词,性能如何下降,而使用aho_corasick_in函数的查询在搜索词数量增加时保持相对稳定。在 100 个搜索词时,aho_corasick_in函数的运行速度比多个 LIKE 语句快五倍。
我们发现,使用 Aho-Corasick 只有在超过 20 个搜索时才有利。这可以通过构建字典树的初始成本来解释。然而,随着搜索词数量的增加,前期成本逐渐得到回报。这与 LIKE 表达式相反,增加更多 LIKE 表达式会使查询的成本变得更高。
接下来,我们将搜索词的数量设置为 20,并改变 text 字符串的长度。我们观察到,在不同的字符串长度下,LIKE 函数和 aho_corasick_in 函数的耗时大致相同。在这两个实验中,执行时间取决于 text 字符串的长度。

图片来源:作者
需要注意的是,构建字典树所产生的开销将取决于查询执行计划中的 Spark 任务数量。Spark 会为执行计划中的每个任务实例化表达式(即:为每个任务实例化新的 AhoCorasickIn 对象)。换句话说,如果你的查询使用了 200 个任务,那么 AhoCorasickIn 构造函数将被调用 200 次。
总结来说,使用的策略将取决于搜索词的数量。我们将这一优化集成到我们的 Sigma 编译器中。在给定的阈值下(例如 20 个词),它会使用 LIKE 语句;而超过这个阈值时,它则会使用 aho_corasick_in 函数进行查询。
当然,这个阈值将取决于你的实际数据以及 Spark 执行计划中的任务数量。
我们在生产数据和真实的 SigmaHQ 规则上进行的初步实验结果表明,应用 aho_corasick_in 函数可以将我们的处理速率(每秒事件数)提高 1.4 倍。

图片来源:作者
结论
在本文中,我们展示了如何实现一个原生的 Spark 函数。这个 Catalyst 表达式利用了 Aho-Corasick 算法,可以同时测试多个搜索词。然而,和任何方法一样,它也有其权衡。使用 Aho-Corasick 需要构建一个字典树(前缀树),当仅使用少量搜索词时,这可能会导致性能下降。我们的编译器使用一个阈值(即搜索词的数量)来选择最优策略,从而确保查询执行的最高效率。
在免费的 T4 GPU 上优化小型语言模型
使用直接偏好优化(DPO)微调 Phi-2 的全面指南
·发布于Towards Data Science ·阅读时长 11 分钟·2024 年 1 月 30 日
--

“小型”语言模型(LLMs)正在迅速成为人工智能领域的一项变革性技术。
与传统的语言模型(LLMs)需要大量计算资源不同,这些模型要小得多且更高效。尽管它们的性能可以与较大的模型相媲美,但它们能够轻松在标准设备(如笔记本电脑)上运行,甚至可以部署到边缘设备。这也意味着它们可以轻松定制并集成到你的数据集上使用。
在本文中,我将首先解释模型微调和对齐过程的基础知识和内部工作原理。然后,我将引导你通过使用一种名为直接偏好优化(Direct Preference Optimization,DPO)的创新方法对 2 亿参数的小型语言模型 Phi 2 进行偏好微调的过程。
由于模型体积小且采用了量化和 QLoRA 等优化技术,我们将能够通过使用免费的 T4 GPU 在 Google Colab 中执行此过程!这需要一些调整 Hugging Face 用于训练其 Zephyr 7B 模型的设置和超参数。
目录:
- 为什么我们需要微调…
用线性编程优化超级碗方格游戏
一种灵感来源于数独的方式,旨在最小化竞争优势
·发表于Towards Data Science ·10 分钟阅读·2024 年 2 月 6 日
--
超级碗星期天历来通过满足不同兴趣的观众,吸引了美国观众的关注。足球迷为赛场上的激烈动作而兴奋,而普通观众则被这场盛会所吸引——从引人注目的广告到炫目的中场秀。然而,如今的超级碗超越了传统的两类观众模式。例如,今年,泰勒·斯威夫特的忠实粉丝群体(“Swifties”)可能会调频观看她与特拉维斯·凯尔西的互动,因为堪萨斯城酋长队试图卫冕他们的冠军头衔。此外,随着法律体育博彩市场的蓬勃发展,再加上比赛在拉斯维加斯举行,为经验丰富和业余的投注者都带来了更多肾上腺素的刺激。

图片来源:Robert Hernandez Villalta:www.pexels.com/photo/nfl-stadium-field-full-with-crowd-watching-the-game-during-daytime-128457/
本文适用于那些了解足球一般规则,但更偏向于“好玩”的人,而非铁杆粉丝。如果你喜欢在家庭和/或朋友聚会时享受各种薯片和蘸酱,并且不介意以类似“宾果”游戏的方式为一年一度的超级碗比赛增添一些兴奋感,那么请允许我向你介绍超级碗方格游戏。
作者注:如果你已经知道超级碗方格游戏的规则,可以跳过下一部分。
优化 PySpark 中的数据处理性能
PySpark 技术与策略,解决常见的性能挑战:一个实用的操作指南
·发表于Towards Data Science ·阅读时间:9 分钟·发布日期:2024 年 11 月 7 日
--
Apache Spark由于其强大的分布式数据处理能力,近年来已成为领先的数据分析引擎之一。PySpark 是 Spark 的 Python API,常用于个人和企业项目中,以解决数据挑战。例如,我们可以使用 PySpark 高效地实现时间序列数据的特征工程,包括数据摄取、提取和可视化。然而,尽管 PySpark 能够处理大规模数据集,但在一些特定场景下,如极端数据分布和复杂的数据转换流程,性能瓶颈仍然可能出现。
本文将探讨在Databricks上使用 PySpark 进行数据处理时常见的性能问题,并介绍各种优化策略,以实现更快的执行速度。

图片来源:Veri Ivanova 来自Unsplash
假设你开设了一家在线零售店,提供多种产品,主要面向美国客户。你计划通过分析当前交易的购买习惯来满足现有客户的更多需求,并吸引更多新客户。这促使你投入大量精力处理交易记录,作为准备步骤。
#0 模拟数据
我们首先模拟了 100 万条交易记录(在实际的大数据场景中,预计会处理更大的数据集),这些记录包含了客户 ID、购买的产品和交易细节,如支付方式和总金额。值得一提的是,客户 ID #100 的产品代理商有着庞大的客户群,因此在你的店铺中占据了大部分代发货的购买。
以下是演示此场景的代码:
import csv
import datetime
import numpy as np
import random
# Remove existing ‘retail_transactions.csv’ file, if any
! rm -f /p/a/t/h retail_transactions.csv
# Set the no of transactions and othet configs
no_of_iterations = 1000000
data = []
csvFile = 'retail_transactions.csv'
# Open a file in write mode
with open(csvFile, 'w', newline='') as f:
fieldnames = ['orderID', 'customerID', 'productID', 'state', 'paymentMthd', 'totalAmt', 'invoiceTime']
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for num in range(no_of_iterations):
# Create a transaction record with random values
new_txn = {
'orderID': num,
'customerID': random.choice([100, random.randint(1, 100000)]),
'productID': np.random.randint(10000, size=random.randint(1, 5)).tolist(),
'state': random.choice(['CA', 'TX', 'FL', 'NY', 'PA', 'OTHERS']),
'paymentMthd': random.choice(['Credit card', 'Debit card', 'Digital wallet', 'Cash on delivery', 'Cryptocurrency']),
'totalAmt': round(random.random() * 5000, 2),
'invoiceTime': datetime.datetime.now().isoformat()
}
data.append(new_txn)
writer.writerows(data)
在模拟数据之后,我们使用 Databrick 的 Jupyter Notebook 将 CSV 文件加载到 PySpark DataFrame 中。
# Set file location and type
file_location = "/FileStore/tables/retail_transactions.csv"
file_type = "csv"
# Define CSV options
schema = "orderID INTEGER, customerID INTEGER, productID INTEGER, state STRING, paymentMthd STRING, totalAmt DOUBLE, invoiceTime TIMESTAMP"
first_row_is_header = "true"
delimiter = ","
# Read CSV files into DataFrame
df = spark.read.format(file_type) \
.schema(schema) \
.option("header", first_row_is_header) \
.option("delimiter", delimiter) \
.load(file_location)
我们还创建了一个可重用的装饰器工具,用于衡量和比较每个函数内不同方法的执行时间。
import time
# Measure the excution time of a given function
def time_decorator(func):
def wrapper(*args, **kwargs):
begin_time = time.time()
output = func(*args, **kwargs)
end_time = time.time()
print(f"Execution time of function {func.__name__}: {round(end_time - begin_time, 2)} seconds.")
return output
return wrapper
好的,所有准备工作已经完成。接下来我们将探讨以下几个章节中执行性能的不同潜在挑战。
#1 存储
Spark 使用弹性分布式数据集(RDD)作为其核心构建块,数据默认通常保存在内存中。无论是执行计算(如连接和聚合)还是在集群中存储数据,所有操作都会在统一区域中贡献内存使用。

一个包含执行内存和存储内存的统一区域(图源:作者)
如果设计不当,可能导致可用内存不足。这会导致过多的分区溢出到磁盘,从而导致性能下降。
缓存和持久化中间结果或频繁访问的数据集是常见的做法。虽然缓存和持久化具有相同的目的,但它们的存储级别可能有所不同。应当合理利用资源,以确保高效的读写操作。
例如,如果转换后的数据会在不同的后续阶段中重复用于计算和算法,建议对这些数据进行缓存。
代码示例: 假设我们想要调查使用数字钱包作为支付方式的不同交易记录子集。
- 低效 — 没有缓存
from pyspark.sql.functions import col
@time_decorator
def without_cache(data):
# 1st filtering
df2 = data.where(col("paymentMthd") == "Digital wallet")
count = df2.count()
# 2nd filtering
df3 = df2.where(col("totalAmt") > 2000)
count = df3.count()
return count
display(without_cache(df))
- 高效 — 缓存关键数据集
from pyspark.sql.functions import col
@time_decorator
def after_cache(data):
# 1st filtering with cache
df2 = data.where(col("paymentMthd") == "Digital wallet").cache()
count = df2.count()
# 2nd filtering
df3 = df2.where(col("totalAmt") > 2000)
count = df3.count()
return count
display(after_cache(df))
缓存之后,即使我们想要根据不同的交易金额阈值或其他数据维度来过滤转换后的数据集,执行时间也会更易于控制。
#2 洗牌
当我们执行如连接 DataFrame 或按数据字段分组的操作时,会发生洗牌。这是必要的,目的是将所有记录重新分布到集群中,并确保具有相同键的记录位于同一个节点。这有助于同时处理并合并结果。

洗牌连接(图源:作者)
然而,这种洗牌操作是代价高昂的——由于数据在节点间的移动,执行时间长且额外的网络开销。
为了减少洗牌操作,有几种策略:
(1) 对于小数据集,使用广播变量,将只读副本发送到每个工作节点进行本地处理
虽然“较小”数据集通常定义为每个执行器最大内存阈值为 8GB,但广播的理想大小应通过针对特定案例的实验来确定。

广播连接(作者图片)
(2) 提前过滤,尽早尽可能减少处理的数据量;
(3) 控制分区数量,以确保最佳性能
代码示例: 假设我们想返回与我们的州列表匹配的交易记录及其全名
- 低效——大数据集与小数据集之间的 shuffle 连接
from pyspark.sql.functions import col
@time_decorator
def no_broadcast_var(data):
# Create small dataframe
small_data = [("CA", "California"), ("TX", "Texas"), ("FL", "Florida")]
small_df = spark.createDataFrame(small_data, ["state", "stateLF"])
# Perform joining
result_no_broadcast = data.join(small_df, "state")
return result_no_broadcast.count()
display(no_broadcast_var(df))
- 高效——使用广播变量将大数据集与小数据集合并
from pyspark.sql.functions import col, broadcast
@time_decorator
def have_broadcast_var(data):
small_data = [("CA", "California"), ("TX", "Texas"), ("FL", "Florida")]
small_df = spark.createDataFrame(small_data, ["state", "stateFullName"])
# Create broadcast variable and perform joining
result_have_broadcast = data.join(broadcast(small_df), "state")
return result_have_broadcast.count()
display(have_broadcast_var(df))
#3 倾斜性
数据有时会分布不均,尤其是用于处理的键字段。这会导致分区大小不平衡,其中某些分区比平均值大或小得多。
由于执行性能受到最长运行任务的限制,因此需要解决过载节点的问题。
一种常见的方法是加盐。其原理是通过向倾斜键添加随机数,使得数据在分区中更加均匀分布。假设在基于倾斜键进行聚合时,我们将使用加盐后的键进行聚合,然后再使用原始键进行聚合。另一种方法是重新分区,它通过增加分区的数量来帮助数据更均匀地分布。

数据分布——加盐前后的情况(作者图片)
代码示例: 我们想聚合一个不对称的数据集,主要由客户 ID #100 引起的倾斜。
- 低效——直接使用倾斜键
from pyspark.sql.functions import col, desc
@time_decorator
def no_salting(data):
# Perform aggregation
agg_data = data.groupBy("customerID").agg({"totalAmt": "sum"}).sort(desc("sum(totalAmt)"))
return agg_data
display(no_salting(df))
- 高效——使用加盐的倾斜键进行聚合
from pyspark.sql.functions import col, lit, concat, rand, split, desc
@time_decorator
def have_salting(data):
# Salt the customerID by adding the suffix
salted_data = data.withColumn("salt", (rand() * 8).cast("int")) \
.withColumn("saltedCustomerID", concat(col("customerID"), lit("_"), col("salt")))
# Perform aggregation
agg_data = salted_data.groupBy("saltedCustomerID").agg({"totalAmt": "sum"})
# Remove salt for further aggregation
final_result = agg_data.withColumn("customerID", split(col("saltedCustomerID"), "_")[0]).groupBy("customerID").agg({"sum(totalAmt)": "sum"}).sort(desc("sum(sum(totalAmt))"))
return final_result
display(have_salting(df))
向倾斜键添加一个随机的前缀或后缀都可以有效。通常,5 到 10 个随机值是一个很好的起点,可以在扩展数据和保持高复杂性之间取得平衡。
#4 序列化
人们通常更倾向于使用用户定义函数(UDFs),因为它在定制数据处理逻辑方面更灵活。然而,UDFs 是按行逐一操作的。代码需要被 Python 解释器序列化,发送到执行器 JVM,然后再反序列化。这会产生高昂的序列化开销,且阻碍 Spark 对代码的优化和高效处理。
简单直接的方法是尽可能避免使用 UDFs。
我们应首先考虑使用内置 Spark 函数,这些函数可以处理聚合、数组/映射操作、日期/时间戳以及 JSON 数据处理等任务。如果内置函数无法满足你的需求,确实可以考虑使用pandas UDFs。与 UDFs 相比,它们建立在 Apache Arrow 基础上,具有更低的开销和更高的性能。
代码示例: 交易价格根据来源州进行折扣。
- 低效 — 使用 UDF
from pyspark.sql.functions import udf
from pyspark.sql.types import DoubleType
from pyspark.sql import functions as F
import numpy as np
# UDF to calculate discounted amount
def calculate_discount(state, amount):
if state == "CA":
return amount * 0.90 # 10% off
else:
return amount * 0.85 # 15% off
discount_udf = udf(calculate_discount, DoubleType())
@time_decorator
def have_udf(data):
# Use the UDF
discounted_data = data.withColumn("discountedTotalAmt", discount_udf("state", "totalAmt"))
# Show the results
return discounted_data.select("customerID", "totalAmt", "state", "discountedTotalAmt").show()
display(have_udf(df))
- 高效 — 使用内置的 PySpark 函数
from pyspark.sql.functions import when
@time_decorator
def no_udf(data):
# Use when and otherwise to discount the amount based on conditions
discounted_data = data.withColumn(
"discountedTotalAmt",
when(data.state == "CA", data.totalAmt * 0.90) # 10% off
.otherwise(data.totalAmt * 0.85)) # 15% off
# Show the results
return discounted_data.select("customerID", "totalAmt", "state", "discountedTotalAmt").show()
display(no_udf(df))
在这个示例中,我们使用内置的 PySpark 函数“when”和“otherwise”来有效地按顺序检查多个条件。基于我们对这些函数的熟悉,示例几乎是无限的。例如,pyspark.sql.functions.transform,一个帮助对输入数组中的每个元素应用转换的函数,自 PySpark 3.1.0 版本开始引入。
#5 溢出
如在存储部分讨论的那样,溢出是由于内存不足以容纳所有所需数据,导致将临时数据从内存写入磁盘。我们提到的许多性能问题都与溢出有关。例如,在分区之间洗牌大量数据的操作,容易导致内存耗尽并随之发生溢出。

由于内存不足引起的溢出不同场景(图像由作者提供)
审查 Spark UI 中的性能指标至关重要。如果我们发现溢出(内存)和溢出(磁盘)的统计数据,那么溢出可能是长时间运行任务的原因。为了解决这个问题,可以尝试实例化一个每个工作节点有更多内存的集群,例如通过调节配置值spark.executor.memory来增加执行进程的内存大小;另外,我们还可以配置spark.memory.fraction来调整分配给执行和存储的内存量。
总结
我们遇到了一些常见的导致 PySpark 性能下降的因素,以及可能的改进方法:
-
存储:使用缓存和持久化存储常用的中间结果
-
Shuffle:为小数据集使用广播变量,以促进 Spark 的本地处理
-
偏斜:执行加盐或重新分区以更均匀地分布偏斜数据
-
序列化:更倾向于使用内置 Spark 函数以优化性能
-
溢出:调整配置值以明智地分配内存
最近,自适应查询执行(AQE)被提出用于基于运行时统计信息对查询进行动态规划和重新规划。这支持查询执行过程中发生的不同查询重新优化特性,从而成为一种出色的优化技术。然而,在初期设计阶段理解数据特征仍然至关重要,因为这有助于制定更好的策略,以编写有效的代码和查询,并利用 AQE 进行微调。
在你离开之前
如果你喜欢这篇文章,欢迎关注我的Medium 页面和LinkedIn 页面。通过这样做,你可以及时获取有关数据科学副项目、机器学习运维(MLOps)示范以及项目管理方法学的精彩内容。
用于数据摄取、验证、处理和测试的 Python 技巧与技术:实用的操作流程
towardsdatascience.com ## 使用 PySpark 在 Databricks 上进行时间序列特征工程
探索 PySpark 在时间序列数据中的潜力:摄取、提取和可视化数据,并附带实践…
towardsdatascience.com
针对变长输入序列优化 Transformer 模型
PyTorch NestedTensors、FlashAttention2 和 xFormers 如何提升性能并降低 AI 成本
·发表于Towards Data Science ·阅读时间:14 分钟·2024 年 11 月 26 日
--

图片来源:Tanja Zöllner 于Unsplash
随着生成型 AI(genAI)模型的普及和规模不断扩大,与其训练和部署相关的计算需求和成本也在增加。优化这些模型对于提升其运行时性能和降低运营成本至关重要。现代生成型 AI 系统的核心是 Transformer 架构及其注意力机制,而该机制通常计算密集型。
在上一篇文章中,我们展示了如何通过优化注意力核显著加速 Transformer 模型的性能。在本文中,我们继续探讨如何解决变长输入序列的挑战——这是现实世界数据的固有特性,包括文档、代码、时间序列等。
批处理变长输入的挑战
在典型的深度学习工作负载中,单个样本会在被复制到 GPU 并馈送给 AI 模型之前,被分成多个批次。批处理可以提高计算效率,并且通常有助于模型在训练中的收敛。通常,批处理涉及沿着一个新的维度—批次维度—将所有样本张量进行堆叠。然而,torch.stack要求所有张量具有相同的形状,而这对于可变长度的序列并不适用。
填充及其低效性
解决这个挑战的传统方法是将输入序列填充到一个固定长度,然后执行堆叠。这个解决方案需要在模型内部进行适当的掩码处理,以确保输出不受无关张量元素的影响。在注意力层的情况下,填充掩码表示哪些标记是填充的,不应该被关注(例如,见PyTorch MultiheadAttention)。然而,填充可能会浪费大量的 GPU 资源,增加成本并减慢开发进程。对于大规模 AI 模型来说,这一点尤其如此。
不填充,拼接
避免填充的一种方法是将序列沿着现有的维度进行拼接,而不是将它们沿着新维度进行堆叠。与torch.stack不同,torch.cat允许形状不同的输入。拼接的输出是一个单一的序列,其长度等于所有单个序列的长度之和。为了使这个解决方案有效,我们的单一序列需要通过一个注意力掩码进行补充,以确保每个标记只关注同一原始序列中的其他标记,这个过程有时被称为文档掩码。设所有单个序列的长度之和为N,并采用“大 O”符号,该掩码的大小需要是O(N²),就像一个朴素的注意力层的计算复杂度一样(该层只在计算完注意力分数后才应用掩码),使得这个解决方案极其低效。
注意力层优化
解决此问题的方案是专门化的注意力层。与标准的注意力层需要执行完整的 O(N²) 注意力分数,然后屏蔽掉无关的部分不同,这些优化过的注意力内核只计算 重要的分数。在本文中,我们将探讨几种不同的解决方案,每种都有其独特的特点。这些方案包括:
集成到现有的 HuggingFace 模型中
对于使用预训练模型的团队来说,转向这些优化可能看起来具有挑战性。我们将展示 HuggingFace 的 API 如何简化这一过程,使开发者能够以最小的代码修改和努力将这些技术集成进来。
免责声明
-
请不要将我们使用任何平台、库或优化技术的做法解读为对其使用的推荐。适合您的最佳选择将很大程度上依赖于您具体的应用场景。
-
此处讨论的部分 API 处于原型或测试阶段,未来可能会有所变化。
-
提供的代码示例仅用于演示目的。我们不对其准确性、最优性或稳健性做出任何声明。
特别感谢 Yitzhak Levi 和 Peleg Nahaliel 对本文的贡献。
玩具 LLM 模型
为了方便讨论,我们将定义一个简单的生成模型(部分灵感来自于 GPT 模型,详见 这里)。有关构建语言模型的更全面指南,请参阅在线的诸多优秀教程之一(例如,这里)。
Transformer 模块
我们首先构建一个基本的 Transformer 模块,特别设计用于便于实验不同的注意力机制和优化方法。尽管我们的模块执行与标准 Transformer 模块相同的计算,但我们对常规操作符的选择做了些微调整,以支持 PyTorch NestedTensor 输入(如 这里 所述)。
# general imports
import time, functools
# torch imports
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
# Define Transformer settings
BATCH_SIZE = 32
NUM_HEADS = 16
HEAD_DIM = 64
DIM = NUM_HEADS * HEAD_DIM
DEPTH = 24
NUM_TOKENS = 1024
MAX_SEQ_LEN = 1024
PAD_ID = 0
DEVICE = 'cuda'
class MyAttentionBlock(nn.Module):
def __init__(
self,
attn_fn,
dim,
num_heads,
format=None,
**kwargs
):
super().__init__()
self.attn_fn = attn_fn
self.num_heads = num_heads
self.dim = dim
self.head_dim = dim // num_heads
self.norm1 = nn.LayerNorm(dim, bias=False)
self.norm2 = nn.LayerNorm(dim, bias=False)
self.qkv = nn.Linear(dim, dim * 3)
self.proj = nn.Linear(dim, dim)
# mlp layers
self.fc1 = nn.Linear(dim, dim * 4)
self.act = nn.GELU()
self.fc2 = nn.Linear(dim * 4, dim)
self.permute = functools.partial(torch.transpose, dim0=1, dim1=2)
if format == 'bshd':
self.permute = nn.Identity()
def mlp(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return x
def reshape_and_permute(self,x, batch_size):
x = x.view(batch_size, -1, self.num_heads, self.head_dim)
return self.permute(x)
def forward(self, x_in, attn_mask=None):
batch_size = x_in.size(0)
x = self.norm1(x_in)
qkv = self.qkv(x)
# rather than first reformatting and then splitting the input
# state, we first split and then reformat q, k, v in order to
# support PyTorch Nested Tensors
q, k, v = qkv.chunk(3, -1)
q = self.reshape_and_permute(q, batch_size)
k = self.reshape_and_permute(k, batch_size)
v = self.reshape_and_permute(v, batch_size)
# call the attn_fn with the input attn_mask
x = self.attn_fn(q, k, v, attn_mask=attn_mask)
# reformat output
x = self.permute(x).reshape(batch_size, -1, self.dim)
x = self.proj(x)
x = x + x_in
x = x + self.mlp(self.norm2(x))
return x
Transformer 解码器模型
基于我们可编程的 Transformer 模块,我们构建了一个典型的 Transformer 解码器模型。
class MyDecoder(nn.Module):
def __init__(
self,
block_fn,
num_tokens,
dim,
num_heads,
num_layers,
max_seq_len,
pad_idx=None
):
super().__init__()
self.num_heads = num_heads
self.pad_idx = pad_idx
self.embedding = nn.Embedding(num_tokens, dim, padding_idx=pad_idx)
self.positional_embedding = nn.Embedding(max_seq_len, dim)
self.blocks = nn.ModuleList([
block_fn(
dim=dim,
num_heads=num_heads
)
for _ in range(num_layers)])
self.output = nn.Linear(dim, num_tokens)
def embed_tokens(self, input_ids, position_ids=None):
x = self.embedding(input_ids)
if position_ids is None:
position_ids = torch.arange(input_ids.shape[1],
device=x.device)
x = x + self.positional_embedding(position_ids)
return x
def forward(self, input_ids, position_ids=None, attn_mask=None):
# Embed tokens and add positional encoding
x = self.embed_tokens(input_ids, position_ids)
if self.pad_idx is not None:
assert attn_mask is None
# create a padding mask - we assume boolean masking
attn_mask = (input_ids != self.pad_idx)
attn_mask = attn_mask.view(BATCH_SIZE, 1, 1, -1) \
.expand(-1, self.num_heads, -1, -1)
for b in self.blocks:
x = b(x, attn_mask)
logits = self.output(x)
return logits
可变长度序列输入
接下来,我们创建一个包含可变长度序列的数据集,其中每个序列由随机生成的令牌组成。为简便起见,我们(任意)选择了一个固定的序列长度分布。在实际场景中,序列长度的分布通常反映了数据的性质,例如文档或音频片段的长度。需要注意的是,长度分布直接影响由填充引起的计算低效。
# Use random data
class FakeDataset(Dataset):
def __len__(self):
return 1000000
def __getitem__(self, index):
length = torch.randint(1, MAX_SEQ_LEN, (1,))
sequence = torch.randint(1, NUM_TOKENS, (length + 1,))
inputs = sequence[:-1]
targets = sequence[1:]
return inputs, targets
def pad_sequence(sequence, length, pad_val):
return torch.nn.functional.pad(
sequence,
(0, length - sequence.shape[0]),
value=pad_val
)
def collate_with_padding(batch):
padded_inputs = []
padded_targets = []
for b in batch:
padded_inputs.append(pad_sequence(b[0], MAX_SEQ_LEN, PAD_ID))
padded_targets.append(pad_sequence(b[1], MAX_SEQ_LEN, PAD_ID))
padded_inputs = torch.stack(padded_inputs, dim=0)
padded_targets = torch.stack(padded_targets, dim=0)
return {
'inputs': padded_inputs,
'targets': padded_targets
}
def data_to_device(data, device):
if isinstance(data, dict):
return {
key: data_to_device(val,device)
for key, val in data.items()
}
elif isinstance(data, (list, tuple)):
return type(data)(
data_to_device(val, device) for val in data
)
elif isinstance(data, torch.Tensor):
return data.to(device=device, non_blocking=True)
else:
return data.to(device=device)
训练/评估循环
最后,我们实现了一个 main 函数,用于在可变长度的输入序列上执行训练/评估。
def main(
block_fn,
data_collate_fn=collate_with_padding,
pad_idx=None,
train=True,
compile=False
):
torch.random.manual_seed(0)
device = torch.device(DEVICE)
torch.set_float32_matmul_precision("high")
# Create dataset and dataloader
data_set = FakeDataset()
data_loader = DataLoader(
data_set,
batch_size=BATCH_SIZE,
collate_fn=data_collate_fn,
num_workers=12,
pin_memory=True,
drop_last=True
)
model = MyDecoder(
block_fn=block_fn,
num_tokens=NUM_TOKENS,
dim=DIM,
num_heads=NUM_HEADS,
num_layers=DEPTH,
max_seq_len=MAX_SEQ_LEN,
pad_idx=pad_idx
).to(device)
if compile:
model = torch.compile(model)
# Define loss and optimizer
criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_ID)
optimizer = torch.optim.SGD(model.parameters())
def train_step(model, inputs, targets,
position_ids=None, attn_mask=None):
with torch.amp.autocast(DEVICE, dtype=torch.bfloat16):
outputs = model(inputs, position_ids, attn_mask)
outputs = outputs.view(-1, NUM_TOKENS)
targets = targets.flatten()
loss = criterion(outputs, targets)
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
@torch.no_grad()
def eval_step(model, inputs, targets,
position_ids=None, attn_mask=None):
with torch.amp.autocast(DEVICE, dtype=torch.bfloat16):
outputs = model(inputs, position_ids, attn_mask)
if outputs.is_nested:
outputs = outputs.data._values
targets = targets.data._values
else:
outputs = outputs.view(-1, NUM_TOKENS)
targets = targets.flatten()
loss = criterion(outputs, targets)
return loss
if train:
model.train()
step_fn = train_step
else:
model.eval()
step_fn = eval_step
t0 = time.perf_counter()
summ = 0
count = 0
for step, data in enumerate(data_loader):
# Copy data to GPU
data = data_to_device(data, device=device)
step_fn(model, data['inputs'], data['targets'],
position_ids=data.get('indices'),
attn_mask=data.get('attn_mask'))
# Capture step time
batch_time = time.perf_counter() - t0
if step > 20: # Skip first steps
summ += batch_time
count += 1
t0 = time.perf_counter()
if step >= 100:
break
print(f'average step time: {summ / count}')
带填充的 PyTorch SDPA
在我们的基准实验中,我们配置了 Transformer 块,以使用 PyTorch 的 SDPA 机制。在我们的实验中,我们分别进行了训练和评估,分别使用和不使用 torch.compile。这些实验在 NVIDIA H100 上运行,使用 CUDA 12.4 和 PyTorch 2.5.1。
from torch.nn.functional import scaled_dot_product_attention as sdpa
block_fn = functools.partial(MyAttentionBlock, attn_fn=sdpa)
causal_block_fn = functools.partial(
MyAttentionBlock,
attn_fn=functools.partial(sdpa, is_causal=True)
)
for mode in ['eval', 'train']:
for compile in [False, True]:
block_func = causal_block_fn\
if mode == 'train' else block_fn
print(f'{mode} with {collate}, '
f'{"compiled" if compile else "uncompiled"}')
main(block_fn=block_func,
pad_idx=PAD_ID,
train=mode=='train',
compile=compile)
性能结果:
-
评估:没有 torch.compile 时为 132 毫秒 (ms),使用 torch.compile 时为 130 毫秒
-
训练:没有 torch.compile 时为 342 毫秒,使用 torch.compile 时为 299 毫秒
优化可变长度输入
在本节中,我们将探讨几种优化技术,用于处理 Transformer 模型中的可变长度输入序列。
填充优化
我们的第一个优化并不是针对注意力内核,而是针对我们的填充机制。我们不再将每批次中的序列填充到一个固定长度,而是将它们填充到批次中最长序列的长度。以下代码块展示了我们修改后的拼接函数和更新的实验。
def collate_pad_to_longest(batch):
padded_inputs = []
padded_targets = []
max_length = max([b[0].shape[0] for b in batch])
for b in batch:
padded_inputs.append(pad_sequence(b[0], max_length, PAD_ID))
padded_targets.append(pad_sequence(b[1], max_length, PAD_ID))
padded_inputs = torch.stack(padded_inputs, dim=0)
padded_targets = torch.stack(padded_targets, dim=0)
return {
'inputs': padded_inputs,
'targets': padded_targets
}
for mode in ['eval', 'train']:
for compile in [False, True]:
block_func = causal_block_fn\
if mode == 'train' else block_fn
print(f'{mode} with {collate}, '
f'{"compiled" if compile else "uncompiled"}')
main(block_fn=block_func,
data_collate_fn=collate_pad_to_longest,
pad_idx=PAD_ID,
train=mode=='train',
compile=compile)
将每批次中最长期列的长度作为填充目标会略微加速性能:
-
评估:没有 torch.compile 时为 129 毫秒,使用 torch.compile 时为 116 毫秒
-
训练:没有 torch.compile 时为 337 毫秒,使用 torch.compile 时为 294 毫秒
使用 PyTorch NestedTensors 的 SDPA
接下来,我们利用 PyTorch NestedTensors 在评估模式下对 SDPA 的内置支持。目前这是一个原型功能,PyTorch NestedTensors 支持将不同长度的张量进行分组,这些张量有时被称为 锯齿状 或 不规则 张量。在下面的代码块中,我们定义了一个拼接函数,将我们的序列组合成 NestedTensors。我们还定义了一个 indices 条目,以便我们能够正确计算 位置嵌入。
PyTorch 的 NestedTensors 支持由有限数量的 PyTorch 操作提供。解决这些限制可能需要一些创造性。例如,只有在 NestedTensors 具有完全相同的“锯齿”形状时,才能支持它们之间的加法。在下面的代码中,我们使用一种变通方法,确保indices条目与模型输入共享相同的形状。
def nested_tensor_collate(batch):
inputs = torch.nested.as_nested_tensor([b[0] for b in batch],
layout=torch.jagged)
targets = torch.nested.as_nested_tensor([b[1] for b in batch],
layout=torch.jagged)
indices = torch.concat([torch.arange(b[0].shape[0]) for b in batch])
# workaround for creating a NestedTensor with identical "jagged" shape
xx = torch.empty_like(inputs)
xx.data._values[:] = indices
return {
'inputs': inputs,
'targets': targets,
'indices': xx
}
for compile in [False, True]:
print(f'eval with nested tensors, '
f'{"compiled" if compile else "uncompiled"}')
main(
block_fn=block_fn,
data_collate_fn=nested_tensor_collate,
train=False,
compile=compile
)
尽管使用 torch.compile 时,NestedTensor 优化导致的步长时间为 131 毫秒,接近我们的基线结果,但在编译模式下,步长时间降至 42 毫秒,取得了令人印象深刻的约 3 倍的提升。
FlashAttention2
在我们之前的文章中,我们演示了使用FlashAttention及其对 Transformer 模型性能的影响。在本文中,我们演示了使用flash_attn_varlen_func来自flash-attn (2.7.0),这是一个专为可变大小输入设计的 API。使用此功能时,我们将批次中的所有序列拼接成一个单一的序列。我们还创建了一个cu_seqlens张量,它指向拼接张量中每个单独序列开始的位置。下面的代码块包括我们的整理函数,接着是评估和训练实验。请注意,flash_attn_varlen_func目前不支持 torch.compile(截至本文撰写时)。
def collate_concat(batch):
inputs = torch.concat([b[0] for b in batch]).unsqueeze(0)
targets = torch.concat([b[1] for b in batch]).unsqueeze(0)
indices = torch.concat([torch.arange(b[0].shape[0]) for b in batch])
seqlens = torch.tensor([b[0].shape[0] for b in batch])
seqlens = torch.cumsum(seqlens, dim=0, dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(seqlens, (1, 0))
return {
'inputs': inputs,
'targets': targets,
'indices': indices,
'attn_mask': cu_seqlens
}
from flash_attn import flash_attn_varlen_func
fa_varlen = lambda q, k, v, attn_mask: flash_attn_varlen_func(
q.squeeze(0),
k.squeeze(0),
v.squeeze(0),
cu_seqlens_q=attn_mask,
cu_seqlens_k=attn_mask,
max_seqlen_q=MAX_SEQ_LEN,
max_seqlen_k=MAX_SEQ_LEN
).unsqueeze(0)
fa_varlen_causal = lambda q, k, v, attn_mask: flash_attn_varlen_func(
q.squeeze(0),
k.squeeze(0),
v.squeeze(0),
cu_seqlens_q=attn_mask,
cu_seqlens_k=attn_mask,
max_seqlen_q=MAX_SEQ_LEN,
max_seqlen_k=MAX_SEQ_LEN,
causal=True
).unsqueeze(0)
block_fn = functools.partial(MyAttentionBlock,
attn_fn=fa_varlen,
format='bshd')
causal_block_fn = functools.partial(MyAttentionBlock,
attn_fn=fa_varlen_causal,
format='bshd')
print('flash-attn eval')
main(
block_fn=block_fn,
data_collate_fn=collate_concat,
train=False
)
print('flash-attn train')
main(
block_fn=causal_block_fn,
data_collate_fn=collate_concat,
train=True,
)
这一优化的影响是显著的,评估时间为 51 毫秒,训练时间为 160 毫秒,分别比我们的基线实验提高了 2.6 倍和 2.1 倍的性能。
XFormers 内存高效注意力
在我们之前的文章中,我们演示了使用xFormers (0.0.28)中的memory_efficient_attention操作。在这里,我们演示了使用BlockDiagonalMask,该操作专为任意长度的输入序列设计。所需的整理函数出现在下面的代码块中,接着是评估和训练实验。请注意,在训练模式下,torch.compile 失败。
from xformers.ops import fmha
from xformers.ops import memory_efficient_attention as mea
def collate_xformer(batch):
inputs = torch.concat([b[0] for b in batch]).unsqueeze(0)
targets = torch.concat([b[1] for b in batch]).unsqueeze(0)
indices = torch.concat([torch.arange(b[0].shape[0]) for b in batch])
seqlens = [b[0].shape[0] for b in batch]
batch_sizes = [1 for b in batch]
block_diag = fmha.BlockDiagonalMask.from_seqlens(seqlens, device='cpu')
block_diag._batch_sizes = batch_sizes
return {
'inputs': inputs,
'targets': targets,
'indices': indices,
'attn_mask': block_diag
}
mea_eval = lambda q, k, v, attn_mask: mea(
q,k,v, attn_bias=attn_mask)
mea_train = lambda q, k, v, attn_mask: mea(
q,k,v, attn_bias=attn_mask.make_causal())
block_fn = functools.partial(MyAttentionBlock,
attn_fn=mea_eval,
format='bshd')
causal_block_fn = functools.partial(MyAttentionBlock,
attn_fn=mea_train,
format='bshd')
print(f'xFormer Attention ')
for compile in [False, True]:
print(f'eval with xFormer Attention, '
f'{"compiled" if compile else "uncompiled"}')
main(block_fn=block_fn,
train=False,
data_collate_fn=collate_xformer,
compile=compile)
print(f'train with xFormer Attention')
main(block_fn=causal_block_fn,
train=True,
data_collate_fn=collate_xformer)
在没有 torch.compile 的情况下,评估和训练的步长时间分别为 50 毫秒和 159 毫秒。使用 torch.compile 进行评估时,步长时间为 42 毫秒。
结果
以下表格总结了我们优化方法的结果。

不同优化方法的步长时间结果(越低越好)—— 作者
我们玩具模型的最佳表现是xFormer 的 memory_efficient_attention,它在评估时提供了约 3 倍的性能提升,在训练时提供了约 2 倍的性能提升。我们提醒不要根据这些结果得出任何结论,因为不同的注意力函数的性能影响可能会根据特定的模型和用例有显著的变化。
为可变长度输入优化 HuggingFace 模型
上述描述的工具和技术在从头创建模型时很容易实现。然而,现如今,ML 开发人员采用现有的(预训练的)模型并对其进行微调以适应其用例并不罕见。虽然我们描述的优化可以在不改变模型权重集和不改变模型行为的情况下集成,但目前尚不完全清楚如何做才是最好的方法。在理想情况下,我们的 ML 框架将允许我们编程使用针对可变长度输入优化的注意力机制。在本节中,我们演示了如何为可变长度输入优化 HuggingFace 模型。
玩具 HuggingFace 模型 - GPT2LMHeadModel
为了便于讨论,我们创建了一个玩具示例,在其中我们训练了一个 HuggingFace 的GPT2LMHead模型,处理可变长度的序列。这需要根据 HuggingFace 的输入规范调整我们的随机数据集和数据填充整理函数。
from transformers import GPT2Config, GPT2LMHeadModel
# Use random data
class HuggingFaceFakeDataset(Dataset):
def __len__(self):
return 1000000
def __getitem__(self, index):
length = torch.randint(1, MAX_SEQ_LEN, (1,))
input_ids = torch.randint(1, NUM_TOKENS, (length,))
labels = input_ids.clone()
labels[0] = PAD_ID # ignore first token
return {
'input_ids': input_ids,
'labels': labels
}
return input_ids, labels
def hf_collate_with_padding(batch):
padded_inputs = []
padded_labels = []
for b in batch:
input_ids = b['input_ids']
labels = b['labels']
padded_inputs.append(pad_sequence(input_ids, MAX_SEQ_LEN, PAD_ID))
padded_labels.append(pad_sequence(labels, MAX_SEQ_LEN, PAD_ID))
padded_inputs = torch.stack(padded_inputs, dim=0)
padded_labels = torch.stack(padded_labels, dim=0)
return {
'input_ids': padded_inputs,
'labels': padded_labels,
'attention_mask': (padded_inputs != PAD_ID)
}
训练函数
我们的训练函数实例化了一个基于请求的GPT2LMHeadModel的GPT2Config,并在我们的可变长度序列上进行训练。
def hf_main(
config,
collate_fn=hf_collate_with_padding,
compile=False
):
torch.random.manual_seed(0)
device = torch.device(DEVICE)
torch.set_float32_matmul_precision("high")
# Create dataset and dataloader
data_set = HuggingFaceFakeDataset()
data_loader = DataLoader(
data_set,
batch_size=BATCH_SIZE,
collate_fn=collate_fn,
num_workers=12 if DEVICE == "CUDA" else 0,
pin_memory=True,
drop_last=True
)
model = GPT2LMHeadModel(config).to(device)
if compile:
model = torch.compile(model)
# Define loss and optimizer
criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_ID)
optimizer = torch.optim.SGD(model.parameters())
model.train()
t0 = time.perf_counter()
summ = 0
count = 0
for step, data in enumerate(data_loader):
# Copy data to GPU
data = data_to_device(data, device=device)
input_ids = data['input_ids']
labels = data['labels']
position_ids = data.get('position_ids')
attn_mask = data.get('attention_mask')
with torch.amp.autocast(DEVICE, dtype=torch.bfloat16):
outputs = model(input_ids=input_ids,
position_ids=position_ids,
attention_mask=attn_mask)
logits = outputs.logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
loss = criterion(logits.view(-1, NUM_TOKENS), labels.flatten())
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
# Capture step time
batch_time = time.perf_counter() - t0
if step > 20: # Skip first steps
summ += batch_time
count += 1
t0 = time.perf_counter()
if step >= 100:
break
print(f'average step time: {summ / count}')
带填充的 SDPA
在下面的回调中,我们使用默认的序列填充整理器调用我们的训练函数。
config = GPT2Config(
n_layer=DEPTH,
n_embd=DIM,
n_head=NUM_HEADS,
vocab_size=NUM_TOKENS,
)
for compile in [False, True]:
print(f"HF GPT2 train with SDPA, compile={compile}")
hf_main(config=config, compile=compile)
结果的步骤时间在未使用 torch.compile 时为 815 毫秒,使用 torch.compile 时为 440 毫秒。
FlashAttention2
我们现在利用 HuggingFace 的内置支持 FlashAttention2,通过将attn_implementation参数设置为“flash_attention_2”。在后台,HuggingFace 将取消填充填充的数据输入,然后将其传递给我们上面看到的优化过的flash_attn_varlen_func函数:
flash_config = GPT2Config(
n_layer=DEPTH,
n_embd=DIM,
n_head=NUM_HEADS,
vocab_size=NUM_TOKENS,
attn_implementation='flash_attention_2'
)
print(f"HF GPT2 train with flash")
hf_main(config=flash_config)
结果的时间步长为 620 毫秒,相较于未编译模式提高了 30%(仅通过一个简单的开关切换)。
FlashAttention2 与未填充输入
当然,在合并函数中填充序列,结果却又将其解填充,这似乎毫无意义。在最近的HuggingFace 更新中,已增加了对将连接(未填充)序列传递给选定模型的支持。不幸的是(截至本文写作时),我们的 GPT2 模型未包括在内。然而,添加支持只需要在modeling_gpt2.py中增加五行代码,即可将序列position_ids传播到flash-attention 内核。完整的补丁如下所示:
@@ -370,0 +371 @@
+ position_ids = None
@@ -444,0 +446 @@
+ position_ids=position_ids
@@ -611,0 +614 @@
+ position_ids=None
@@ -621,0 +625 @@
+ position_ids=position_ids
@@ -1140,0 +1145 @@
+ position_ids=position_ids
我们定义了一个合并函数,将我们的序列连接在一起,并在未填充的序列上训练我们的 HuggingFace 模型。(另请参见内置的DataCollatorWithFlattening工具。)
def collate_flatten(batch):
input_ids = torch.concat([b['input_ids'] for b in batch]).unsqueeze(0)
labels = torch.concat([b['labels'] for b in batch]).unsqueeze(0)
position_ids = [torch.arange(b['input_ids'].shape[0]) for b in batch]
position_ids = torch.concat(position_ids)
return {
'input_ids': input_ids,
'labels': labels,
'position_ids': position_ids
}
print(f"HF GPT2 train with flash, no padding")
hf_main(config=flash_config, collate_fn=collate_flatten)
结果的步骤时间为 323 毫秒,比在填充输入上运行 flash-attention 快了 90%。
结果
我们的 HuggingFace 实验结果总结如下。

不同优化方法的步骤时间结果(越低越好)——作者
通过一点小努力,我们将运行时性能提升了 2.5 倍,相比未编译的基线实验,提升了 36%,相较于编译版本。
在本节中,我们展示了 HuggingFace APIs 如何让我们利用 FlashAttention2 中的优化内核,显著提升现有模型在不同长度序列上的训练性能。
总结
随着 AI 模型在流行度和复杂度上的不断增长,优化它们的性能已成为减少运行时间和成本的关键,尤其是对于像注意力层这样的计算密集型组件。在这篇文章中,我们继续探索注意力层的优化,并展示了提升 Transformer 模型性能的新工具和技术。欲了解更多关于 AI 模型优化的见解,请务必查看本系列的第一篇文章以及我们在这个话题上的其他多篇文章。
最优分配与匈牙利算法

匈牙利算法在实际应用中的效果!图片来自作者。
本文提供了一个逐步示例,展示了匈牙利算法如何在图上解决最优分配问题。
·发表于 Towards Data Science ·11 分钟阅读·2024 年 7 月 7 日
--
我写这篇文章的原因是因为我花了几天时间才理解匈牙利算法在图上的工作原理。矩阵版本更容易理解,但它没有提供所需的洞察力。我在网上找到的所有优秀资料都未能让我清晰地理解算法为什么要按这样做。
对我来说,将算法描述转化为一个可工作的示例并不简单。尽管我们今天拥有的各种大型语言模型(LLM)工具帮助我用多种方式重新措辞算法描述,但当我要求它们生成一个逐步的工作示例时,它们都失败了。所以我坚持不懈地生成了一个匈牙利算法在图上应用的示例。我在这里展示这个逐步示例以及从这次练习中获得的直觉,希望能帮助其他想学习这个精彩算法以解决最优分配问题的人。
最优分配问题是从一组节点到另一组节点找到一对一匹配,其中每一对节点之间的边都有一个相关的成本,生成的匹配必须确保总成本最小。
这个问题是普遍存在的。它可以是将一组人分配到一组工作,每个工作需要为特定工作支付特定的价格,然后问题就是将人分配到工作上,以使总成本最小化。它也可以是将一组可用的出租车分配给需要出租车的一组人,每辆出租车需要特定的时间到达特定的人。每次我们预定 Uber 或 Ola 时,都会使用匈牙利算法来解决这个问题。
分配问题最好用二分图来表示,二分图是一个具有两个不同节点集的图,且边从不连接同一集合中的节点。我们以出租车匹配问题为例,二分图显示了节点之间的所有可能连接。边的权重表示特定出租车到达特定人的时间(以分钟为单位)。如果一个集合中的所有节点都与另一个集合中的所有节点连接,则该图被称为完全图。

用于出租车匹配问题的二分图。图片由作者提供
在这个例子中,我们需要将 4 个人映射到 4 辆出租车上。我们可以暴力破解求解:
-
对于第一个人,有四个可能的出租车分配,我们可以选择其中任何一辆
-
对于第二个人,剩下三辆出租车,我们从这三辆中选择一辆
-
对于第三个人,我们选择剩下的两辆出租车中的一辆
-
对于最后一个人,我们选择剩余的出租车
所以可能的分配总数是 4 x 3 x 2 x 1,即 4 的阶乘。然后我们计算所有这些 4!种分配的总成本,并选择成本最低的那一项。对于较小规模的分配问题,暴力破解仍然是可行的。但是随着n(人或出租车的数量)的增加,n!会变得非常大。
另一种方法是贪心算法。你选择一个人,将最小成本的出租车分配给他,然后选择下一个人,给他分配剩余出租车中最小成本的那一辆,以此类推。在这个例子中,对于最后一个人,不可能进行最小成本分配,因为可以最短时间到达他的出租车已经分配给了另一个人。所以算法选择了下一个可用的最小成本出租车。因此,最后一个人由于其他人的贪心而受到影响。贪心算法的解决方案可以在这里的图中看到。虽然在这个例子中,贪心方法确实达到了最低成本 36,但不能保证这种贪心方法会得到最优分配。

贪心分配。图片由作者提供
匈牙利算法提供了一种高效的寻找最优解的方法。我们将首先从矩阵算法开始。我们可以将图表示为一个邻接矩阵,其中每条边的权重是矩阵中的一个条目。图和其邻接矩阵可以在这里看到。

二分图及其邻接矩阵。图片由作者提供
这里是匈牙利算法在邻接矩阵上操作的步骤:
-
对于邻接矩阵中的每一行,找出并减去该行所有条目的最小值。
-
对于邻接矩阵中的每一列,找出并减去该列所有条目的最小值。
-
用最少的线覆盖矩阵中的所有零。
a. 计算每一行和每一列中的零的数量
b. 首先在零最多的行/列上画线,其次是零第二多的行/列,以此类推。
c. 如果覆盖所有零所需的线数少于矩阵的行/列数,则继续执行第 4 步,否则进入第 5 步
-
找到未被线覆盖的最小条目,并将该条目从所有未覆盖的条目中减去,同时将它添加到所有被两条线(水平线和垂直线)覆盖的条目中,然后进入第 3 步。
-
可以通过从仅有一个零的行/列开始来生成最佳分配。
前两步很简单。第三步需要按照零的数量从高到低的顺序划去行或列。如果划去的行和列的数量不等于需要匹配的节点数,则需要创建额外的零——这在第四步中完成。重复执行第三步和第四步,直到划去足够的行和列。这样,矩阵中就有足够的零用于最佳匹配。
这个出租车匹配示例的匈牙利算法步骤显示在这个动画 GIF 中。
匈牙利算法在邻接矩阵上的动画。图片由作者提供
对于那些难以跟随 GIF 动画的人,这里是显示步骤的图片。

匈牙利算法在邻接矩阵上的应用。图片由作者提供
算法在已划去的行/列数等于需要匹配的节点数时终止。最后一步是找到分配结果,可以通过首先分配对应零条目的边来轻松完成,在这个例子中,可能是(P1,T2),通过选择第一行中的第一个零。这样我们就不能再将 T2 分配给其他人,因此 T2 列中的第二个零可以被移除。P4 行中唯一剩下的零表示它必须分配给 T1,因此下一个分配是(P4,T1)。此时,T1 列中的第二个零也可以被移除,P2 行只剩下一个零。第三个分配因此是(P2,T3)。最后的分配就是(P3,T4)。读者可以通过加总这些分配对应的边的权重来计算总成本,结果为 36。
如果我们查看 GIF 动画,整个分配过程就更直观了,我们有一个连接所有节点的子图,并且我们可以创建一个交替路径,其中的边交替为已匹配(绿色)和未匹配(红色)。
现在我们已经看到了匈牙利算法在邻接矩阵上的应用,我们知道这些步骤为何是这样的了吗?究竟创建覆盖零的最小数量的线条告诉我们什么?为什么最小的线条数量必须等于要匹配的节点数量才能停止算法?我们如何理解步骤 3 中的这个奇怪规则,即要创建额外的零,我们需要找到未覆盖的最小值,将其从所有未覆盖的条目中减去,同时加到覆盖了两次的条目中?
为了获得更好的理解,我们需要看到匈牙利算法是如何在图上工作的。为此,我们需要将最优分配问题视为匹配需求和供应。这需要为每个节点创建标签,表示供应和需求的数量。
现在我们需要一些符号来解释这个标签过程。二分图中有两个不同的节点集合,我们称它们为 X 和 Y。所有属于集合 X 的节点用x表示,属于 Y 的节点用y表示。标签则分别为l(x)和l(y),连接x和y的边的代价是w(x,y)。标签必须是可行的,这意味着当我们希望最小化成本时,l(x)+l(y)≤w(x,y),而当我们希望最大化成本时,l(x)+l(y)≥w(x,y)。
在我们的例子中,人们的需求是他们必须立即乘坐出租车,零等待时间。出租车的供应是到达这些人所需的时间。初始的可行标签是,将每辆出租车到达四个乘客所需的最短时间作为出租车的标签,而乘客的标签则为零。
一旦我们有了可行的标签,我们就可以创建一个等式子图,其中我们只选择那些满足等式的边:l(x)+l(y)=w(x,y)。初始标签和结果等式子图(高亮显示的边)如图所示。

初始标签的等式子图。图像由作者提供
我们看到在这个等式子图中有些节点没有连接,为了解决这个问题,我们需要修正标签。我们通过查看那些未连接的节点,并使用可以更新标签的最小值来更新这些节点的标签,以便建立连接。
在我们的例子中,如果 P3 的标签是 0,那么它无法连接。为了使这个标签增加到足以建立连接的最小值是 1,这个值是通过查看连接到 P3 的每一条边上的最小松弛(Δ=w(x,y)-(l(x)+l(y)))得出的。同样,对于 P4,其标签是根据它的边上的最小松弛来更新的。更新标签后的结果等式图在图中显示。

未连接节点的标签修正。图像由作者提供
我们现在可以尝试在等式子图上找到匹配,我们会看到只有 3 个节点可以匹配。这是因为在等式图中,我们还没有一条能够连接所有节点的交替路径。由于没有足够的边来创建这样的交替路径,我们需要再次修订标签以添加额外的边。但是这次,等式子图已经与所有节点建立了连接。所以我们这次添加的边应该有助于扩展交替路径。

交替路径。图由作者提供
我们观察由该交替路径连接的所有节点(图中通过红色和绿色边表示),并提出问题:这些节点的标签可以通过什么最小松弛量进行修订以添加一条边。为了找到这个最小松弛量,我们查看所有连接到不在交替路径中的 X 节点的边的松弛量,如下一张图所示,并计算最小值。对于这个例子,最小松弛量是 2(来自边 P3-T3)。

扩展交替路径的边选项。图由作者提供
交替路径中所有节点的标签需要更新,但当我们调整需求(通过添加最小松弛量),我们还需要从供应中减少相同的值(通过从中减去最小松弛量),以确保等式图中的现有边不发生变化。修订后的标签和更新后的等式图如图所示。

标签修订后的等式子图。图由作者提供
我们现在可以看到,有一条交替路径连接了所有节点。现在可以通过使用交替路径并在匹配和未匹配的边之间交替来找到匹配(见图)。请注意,交替路径与每个节点有两条边相连。节点下方的数字表示交替路径连接这些节点的顺序。

匈牙利算法在图上的结果。图由作者提供
可以通过从图中选择所有用绿色高亮的边来读取分配情况。也就是说,(P1,T4),(P2,T1),(P3,T3) 和 (P4,T2),这将得到一个总成本 8+10+11+7 = 36。整个过程的动画可以在这里的 GIF 中看到。
匈牙利算法在图上的应用。图由作者提供
我们从图上的匈牙利算法看到,我们始终只通过添加一个额外边所需的最小值来调整供应和需求。因此,这个过程保证了我们最终能得到最优成本。这个过程的数学公式和证明在许多网上资源中都有清晰的阐述,但由于我们必须跟踪子图和子集,图的数学公式并不容易理解。一个展示这一过程的例子很有帮助,至少我希望是这样。
我们还可以看到与我们在邻接矩阵上所做的步骤之间的相似性。通过在图上应用算法获得的洞察,我们可以看出,覆盖零的最小线数需要等于要匹配的节点数,才能确保最大匹配。邻接矩阵中创建额外零的规则并没有提供直观的理解,但基于未包含在交替路径中的边的最小松弛对图上标签的修正,立即提供了连接。
话虽如此,操作矩阵的算法更容易理解和实现,这也是为什么网上有很多关于用这种方式解释匈牙利算法的信息。但我希望你们中的一些人同意我, 一旦我们看到匈牙利算法在图上的应用,这种清晰度为我们提供的理解程度是理解这个简洁算法的关键。
这篇文章到这里就结束了。我还录制了一段视频,介绍了这些内容,但视频时长为 45 分钟,因为我似乎随着年龄增长说话变得更慢了。也许有一天我会把视频链接放在这里。
在 Azure 中编排动态时间序列管道
探索如何使用 Azure Data Factory(ADF)和 Databricks 构建、触发和参数化一个时间序列数据管道,并附有逐步教程。
·发布于Towards Data Science ·阅读时间 8 分钟·2024 年 5 月 31 日
--
在上一篇故事中,我们回顾了 PySpark 在 Databricks 上处理时间序列数据的潜力。我鼓励你通过这里了解更多内容。在不配置独立 Spark 实例的情况下,我们可以通过 Databricks 上的 PySpark 摄取静态和流数据,执行数据转换,提取有用的时间相关特征,并构建可视化。当处理企业级数据的大规模复杂转换时,PySpark 的可扩展性和性能特别具有优势,甚至可以处理 PB 级别的数据。
所有特征工程任务都成功地在一个 Databricks 笔记本中完成。然而,这只是构建数据中心系统时数据工程故事的一部分。数据管道的核心部分在于数据编排。
数据编排通常指的是对数据流进行集中控制,以便我们可以自动化、管理和监控整个数据管道。

图片由Julio Rionaldo提供,来自Unsplash
Azure Data Factory (ADF)与 Azure Databricks
为了满足这些需求,行业中最流行的解决方案之一是从ADF平台运行Azure Databricks笔记本。
ADF 是一个基于云的、无服务器且完全托管的数据集成服务。尽管Databricks Workflow提供了一个很好的替代方案,涵盖了部分 ADF 的功能,但选择 ADF 仍然有若干关键优势。例如,ADF 是一个成熟的工具,能够使用连接器与各种数据存储进行集成,包括像 Salesforce 这样的 SaaS 应用和像 Amazon Redshift、Google BigQuery 这样的大数据源。因此,它在数据摄取和集成方面表现良好,尤其是当当前系统与 Databricks 以外的数据系统存在复杂依赖关系时。此外,ADF 简化并便于使用拖放和低代码界面快速构建基本管道。
在这个实践过程中,我们将深入探讨数据工程项目,并探索 ADF 如何帮助构建一个动态的、骨架型的数据管道,用于时间序列数据。 我将展示如何在 Azure Databricks 上挂载云存储,通过嵌入的 Notebook 转换数据,并通过 ADF 中的自定义设置动态编排数据。让我们开始吧!
初始设置
首先有几个云组件和服务。
#1 创建一个 Azure 资源组
这个容器用于保存和分组 Azure 解决方案的资源。我们将把必要的云服务组件放入这个逻辑组中,以便更容易进行构建或部署。

Azure 资源组(作者提供的图片)
#2 创建一个 Azure Data Lake Gen 2 存储账户
你可以根据性能和复制需求选择合适的存储账户。在高级选项卡中,我们启用了分层命名空间以设置Data Lake Storage Gen 2。这使得既可以存储结构化数据,也可以存储非结构化数据。

存储账户(作者提供的图片)
#3 设置 Azure Databricks 服务
如果你之前使用过 Databricks,Azure Databricks 服务大体相同。此外,它与其他 Azure 服务原生集成,并提供统一的计费平台。这里有两个层级:(1) 标准层——足以满足我们在此的概念验证需求;(2) 高级层——具有标准层的功能,额外提供Unity Catalog和可能对于拥有多个 Databricks 工作区的大型企业所需的高级网络功能。

Azure Databricks 工作区(图片来源:作者)
#4 注册应用程序
该服务将帮助将 Azure 存储挂载到 Databricks,因此请确保记下应用 ID 和租户 ID,最重要的是应用的密钥值,在你重新访问时是无法查看的。

应用注册 — 设置(图片来源:作者)

应用注册 — 信息(图片来源:作者)

应用注册 — 客户端密钥(图片来源:作者)
然后,授予应用服务对应用服务的访问权限。这是通过将“Storage Blob Data Contributor”角色分配给我们刚注册的应用来实现的。

存储账户 — 授予访问权限(1/3)(图片来源:作者)

存储账户 — 授予访问权限(2/3)(图片来源:作者)

存储账户 — 授予访问权限(3/3)(图片来源:作者)
#5 创建 Azure SQL 数据库
为了存储转换后的数据框,我们搜索 Azure SQL 资源并选择“单一数据库”作为资源类型。SQL 数据库服务器提供了不同的计算硬件、最大数据大小等选项。你可以在调整服务器配置时即时查看估算的费用摘要。

创建 SQL 数据库(1/2)(图片来源:作者)

创建 SQL 数据库(2/2)(图片来源:作者)
完成所有初始设置后,你就可以探索这些服务是如何相互连接的。
准备数据编排管道
#1 导入数据
我们首先将电力消耗数据上传到 Azure Data Lake Gen2。这个数据集[许可证为数据库:开放数据库,内容:数据库内容],来自 Kaggle,采样频率为每分钟一次,数据时间从 2006 年 12 月到 2010 年 11 月。

上传输入数据(图片来源:作者)
接下来,我们在 Azure Databricks 工作区创建一个 Notebook,并通过定义参数,使用之前存储的 ID 值来挂载存储。
# Define the configuration specifications
configs = {"fs.azure.account.auth.type": "OAuth",
"fs.azure.account.oauth.provider.type": "org.apache.hadoop.fs.azurebfs.oauth2.ClientCredsTokenProvider",
"fs.azure.account.oauth2.client.id": "<Client ID>",
"fs.azure.account.oauth2.client.secret": "<Client Secret>",
"fs.azure.account.oauth2.client.endpoint": "https://login.microsoftonline.com/<Tenant ID>/oauth2/token"
}
dbutils.fs.mount(
source = "abfss://input@adlstsdp.dfs.core.windows.net/", # URI of the object storage
mount_point = "/mnt/adlstsdp/input", # local path in the /mnt directory
extra_configs = configs)
为了验证文件访问,我们可以运行以下命令:
dbutils.fs.ls(“/mnt/adlstsdp/input”)
# Output: [FileInfo(path='dbfs:/mnt/adlstsdp/input/household_power_consumption.csv', name='household_power_consumption.csv', size=132960755, modificationTime=1716798010000)]
#2 在 Azure Databricks 中嵌入 Notebook
本节中的大部分源代码基于我的上一篇文章。其思路是进行数据清理、转换和特征工程(创建时间相关特征和移动平均特征)。转换后的数据最终写入 Azure 数据库表中。
你可以查看下面的完整代码,了解其实现过程。
# Define file location, file typem and CSV options
file_location = "/mnt/adlstsdp/input/household_power_consumption.csv"
file_type = "csv"
schema = "Date STRING, Time STRING, Global_active_power DOUBLE, Global_reactive_power DOUBLE, Voltage DOUBLE, Global_intensity DOUBLE, Sub_metering_1 DOUBLE, Sub_metering_2 DOUBLE, Sub_metering_3 DOUBLE"
first_row_is_header = "true"
delimiter = ";"
# Read CSV files
org_df = spark.read.format(file_type) \
.schema(schema) \
.option("header", first_row_is_header) \
.option("delimiter", delimiter) \
.load(file_location)
# Data cleansing and transformation
from pyspark.sql.functions import *
cleaned_df = org_df.na.drop()
cleaned_df = cleaned_df.withColumn("Date", to_date(col("Date"),"d/M/y"))
cleaned_df = cleaned_df.withColumn("Date", cleaned_df["Date"].cast("date"))
cleaned_df = cleaned_df.select(concat_ws(" ", to_date(col("Date"),"d/M/y"), col("Time")).alias("DateTime"), "*")
cleaned_df = cleaned_df.withColumn("DateTime", cleaned_df["DateTime"].cast("timestamp"))
df = cleaned_df.groupby("Date").agg(
round(sum("Global_active_power"), 2).alias("Total_global_active_power"),
).sort(["Date"])
# Add time-related features
df = df.withColumn("year", year("Date"))
df = df.withColumn("month", month("Date"))
df = df.withColumn("week_num", weekofyear("Date"))
# Add lagged value features of total global active power
from pyspark.sql.window import Window
from pyspark.sql.functions import lag
windowSpec = Window.orderBy("Date")
df = df.withColumn("power_lag1", round(lag(col("Total_global_active_power"), 1).over(windowSpec), 2))
# Create delta field
df = df.withColumn("power_lag1_delta", round(col("power_lag1") - col("Total_global_active_power"), 2))
# Create window average fields
def add_window_avg_fields(df, window_sizes):
for idx, window_size in enumerate(window_sizes, start=1):
window_col_name = f"avg_power_lag_{idx}"
windowSpec = Window.orderBy("Date").rowsBetween(-window_size, 0)
df = df.withColumn(window_col_name, round(avg(col("Total_global_active_power")).over(windowSpec), 2))
return df
window_sizes = [14, 30]
df = add_window_avg_fields(df, window_sizes)
# Create Exponentially Weighted Moving Average (EWMA) fields
import pyspark.pandas as ps
ps.set_option('compute.ops_on_diff_frames', True)
def add_ewma_fields(df, alphas):
for idx, alpha in enumerate(alphas, start=1):
ewma_col_name = f"ewma_power_weight_{idx}"
windowSpec = Window.orderBy("Date")
df[ewma_col_name] = df.Total_global_active_power.ewm(alpha=alpha).mean().round(2)
return df
alphas = [0.2, 0.8]
df_pd = df.pandas_api()
df_pd = add_ewma_fields(df_pd, alphas)
df = df_pd.to_spark()
# Write transformed dataframe to the database table "electric_usage_table"
df.write.format("jdbc") \
.option("url", "jdbc:sqlserver://sql-db-dp.database.windows.net:1433;databaseName=sql-db-dp") \
.option("dbtable", "dbo.electric_usage_table") \
.option("user", "<username>") \
.option("password", "<password>") \
.mode("overwrite") \
.save()
#3 在 ADF 中构建基本管道
在 ADF 中,我们将“Notebook”活动添加到管道环境中,然后配置它以引用 Databricks 文件夹中的所需 Notebook。设置 Databricks 连接服务,然后在 ADF 中验证并发布整个活动管道。然后,您可以在“调试”模式下运行管道。

管道运行的成功状态(图片由作者提供)
活动状态显示为“已成功”,这意味着数据应该已迁移并插入到 Azure SQL 数据库表中。我们可以使用查询编辑器查看结果以进行验证。

查询 Azure SQL 数据库的结果(图片由作者提供)
#4 自动化管道
ADF 提供的功能远超上述简单实现。例如,我们可以通过创建基于存储的事件触发器来自动化管道。确保Microsoft.EventGrid已注册为您账户订阅中的资源提供者之一,然后设置触发器:每当新数据集上传到存储帐户时,管道将自动执行。

在 ADF 中设置新的触发器(图片由作者提供)
这种类型的触发器在行业中有各种应用场景,例如监控库存水平以补充供应链订单,或追踪客户互动以实现数字营销中的个性化推荐。
#5 参数化 Notebook 变量
为了进一步构建更具动态性的数据信息管道,我们可以使变量更加参数化。例如,在时间序列数据的特征工程中,数据特征的窗口大小最初可能并未优化。窗口大小可能需要根据季节性模式或下游模型微调进行调整。对于这种情况,我们可以通过以下设置进行修改。

设置管道运行的参数(图片由作者提供)
在 Notebook 中,添加以下代码以创建一个小部件,可以从 ADF 管道获取参数输入:
# Additional code: Access the current value of the widget
inputWindowSizes = dbutils.widgets.get("inputWindowSizes")
window_sizes = inputWindowSizes.split(",")
# Original function for adding window average features
df = add_window_avg_fields(df, window_sizes)
在调整设置和 Notebook 代码后,我们可以通过提供窗口大小参数值,如 30 和 60,来运行管道。

为管道运行输入窗口大小值(图片由作者提供)
最后,我们可以通过 ADF 或 Databricks 工作区再次监控管道状态。
总结
在我们的实践探索中,我们主要使用 ADF 与 Azure Databricks 来编排一个动态的时间序列数据管道:
-
设置云资源用于计算、分析和存储。
-
从数据摄取到存储,构建数据管道的骨架。
-
通过创建触发器和参数化变量,为管道带来灵活性。
在企业层面,可能会实施更多的复杂云架构,以满足不断变化的需求,如数据流、模型监控和多模型管道。因此,团队在治理政策和支出管理上的协作变得至关重要,以实现性能、可靠性和成本效益之间的精细平衡。
在您离开之前
如果您喜欢这篇文章,我邀请您关注我的Medium 页面和LinkedIn 页面。通过这样做,您可以随时了解与数据科学侧项目和机器学习运维(MLOps)演示方法相关的精彩内容。
发现 LangChain 在客户分析中的潜力与局限性,并附带实际的实施案例…
towardsdatascience.com ## 管理机器学习系统的技术债务
探索通过实施代码持续降低快速交付成本的实践
towardsdatascience.com
组织的机器学习投资是(或应该是)渐进式的
将机器学习系统嵌入生产环境仍然是一个艰难的任务(对于大多数公司来说)
·发表于Towards Data Science ·阅读时长 7 分钟·2024 年 7 月 23 日
--

图片来源:Glen Carrie @ Unsplash.com
你是否听说过有公司成功地将机器学习系统在一夜之间集成到他们的业务流程中,彻底改变了组织的运作方式,一天到晚就发生了翻天覆地的变化?
没错,我也是!
而你知道吗,大多数机器学习模型永远无法投入生产?
将生产级系统整合到业务流程中是极其困难的。所谓生产级系统,指的是那些具有一定可靠性并能为公司带来价值的系统。将机器学习系统嵌入到组织中不是一夜之间可以完成的任务,坦率来说,数据科学和机器学习之所以被误解,正是因为领导者在过程中迷失了方向。特别是,我看到在尝试先进行机器学习实验时,通常会犯两种错误:
- 不正确的期望:这个问题非常普遍,问题出在机器学习供应商身上。对机器学习和人工智能系统的高期望通常源于那些想要销售这些系统的人(或媒体炒作)。但请听我说:每个机器学习系统都有误差,这是无法避免的。
在工具类中组织 Python 函数
PYTHON 编程
探索工具类如何提供增强的命名空间来组织相关函数。
·发表于 Towards Data Science ·13 分钟阅读 ·2024 年 4 月 24 日
--

图片来源:Matteo Grassi 在 Unsplash
Python 提供了强大的面向对象编程(OOP)工具。您可以创建各种类型的类,使用继承和组合,创建可调用对象;基本上,您可以创建各种自定义类。
在本文中,我们将讨论所谓的工具类,也称为助手类或命名空间类。它们在 Python 代码库中并不常见,但我猜其中一个原因是许多 Python 程序员根本没有听说过它们,因此并没有意识到他们手中拥有这个强大的工具。
同时,您应该注意这类类的局限性。如果过度使用,它们可能会使代码变得过于复杂且难以理解——尽管您为了实现这些类付出了额外的工作。
Python 工具类主要用作一组静态方法,这些方法无需实例化即可作为相关函数的命名空间。这类类并不设计用于创建对象(实例);相反,它们在一个共同的命名空间下逻辑地将相关函数分组,使代码更有组织且更易于访问。
ORPO:无监督微调(SFT)步骤的偏好优化
一种比 DPO 更便宜的对齐方法,性能相当
·发布于 Towards Data Science ·阅读时长 7 分钟·2024 年 4 月 10 日
--

由 DALL-E 生成
现在有许多方法可以将大型语言模型(LLM)与人类偏好对齐。带有人工反馈的强化学习(RLHF)是最早的一种方法,并且催生了 ChatGPT,但 RLHF 成本非常高。DPO、IPO 和 KTO 比 RLHF 显著便宜,因为它们不需要奖励模型。
虽然 DPO 和 IPO 更便宜,但它们仍然需要训练两个不同的模型。一个模型用于监督微调(SFT)步骤,即训练模型回答指令,然后使用该 SFT 模型进行初始化和作为参考,来对齐人类偏好。
ORPO 是另一种新的大型语言模型(LLM)对齐方法,但这一方法甚至不需要 SFT 模型。使用 ORPO,LLM 可以共同学习如何回答指令和人类偏好。
在这篇文章中,我将解释 ORPO 并回顾其性能。我展示了如何使用它将 Mistral 7B 转换为一个聊天模型,使用普通消费者硬件即可实现。
联合 SFT 和偏好优化
本文介绍了 ORPO:
在医疗健康领域克服 LLM 挑战:生产环境中的实际开发策略
生成性人工智能
一篇关于我遇到的最常见 LLM 开发挑战、有效的缓解策略和一场决定我职业生涯的面试错误的文章
·发表于 Towards Data Science ·阅读时间:9 分钟·2024 年 11 月 5 日
--
引言
我一直是那种喜欢深入研究某一领域并专注到痴迷的人。当我从数据科学硕士毕业时,我的痴迷对象是计算机视觉;特别是将计算机视觉应用于神经科学或心理健康领域。尽管我的导师们建议我拓宽视野,尽早迈入职场,但我还是决定成为一名“计算机视觉工程师”(不过“机器学习工程师”也可以),专注于心理健康领域。我压抑了内心的疑虑,相信有合适的团队会认同我的“专业能力”。

由 DALL·E 生成的图片
幸运的是,我的理论似乎奏效了;我获得了几家心理健康公司的面试机会。但接着,我犯下了我面试中的最大错误之一。在我的首选公司——一家我非常喜爱的公司——的最终面试中,我犯了一个错误,每当我反思时,都会感到非常尴尬。这个职位主要集中在自然语言处理(NLP),需要处理文本数据,但我忍不住表达了我对……
克服保护共享生成式 AI 环境中的安全挑战
确保在多租户中的安全 AI
·发表于 Towards Data Science ·阅读时长 13 分钟·2024 年 12 月 2 日
--
免费链接 => 请帮助点赞 这篇 LinkedIn 帖子 来传播本文。
引言
让我们首先概述一下该领域的情况:许多组织正在搭乘生成式 AI 的浪潮。一份最新报告显示,超过 65%的组织报告称他们已经将生成式 AI 应用于其业务流程。然而,经过更深入的检查发现,绝大多数这些应用要么处于初期阶段,要么处于概念设计阶段,这主要是由于公司对成功部署它们的能力存在过度乐观的偏见。

图片来源:Pawel Czerwinski 通过 Unsplash
概念与生产之间的差距源于几个挑战:数据集成问题、遗留系统的限制、使用案例的投资回报考虑,以及安全壁垒。在本文中,我们将重点讨论一个关键的安全问题——多租户中的资源。
在生成式人工智能驱动的应用程序中,通常不仅仅是关于撰写文本或回复。大多数应用程序会执行数据查找操作,向大型语言模型(LLM)提供相关信息,以确保输出质量。当一个人工智能模型应用于针对多个客户或内部部门的生成式人工智能时,我们通常会让每个客户或部门拥有不同的…
访问之战:克服(无意的)数据监狱
即使你能看到数据,它也可能完全无用。
·发表于Towards Data Science ·5 分钟阅读·2024 年 6 月 17 日
--

感谢 ChatGPT 4o 对数据监狱图像的解释,接下来我会更好地定义这个概念……
更好的数据胜过巧妙的算法,但更多的数据胜过更好的数据。
— 彼得·诺维格
我做了一个东西。这很有趣,我认为它带来了(或者希望它将带来)价值。但它也付出了代价,这是我在我的行业中变得非常熟悉的代价。数据难以访问不应该是(而且不必是)常态。我把这个称为 数据监狱。* 数据很容易输入,但很难取出。而且在许多情况下,数据监狱的“铁栏”是透明的。你并不知道它很难访问,直到你真的需要它。
定义‘数据监狱’
让我首先确保我们都清楚我所说的数据监狱是什么意思。基本上,数据监狱描述的是这样一种情形:尽管数据在技术上是可用的,但它被困在格式中,限制了其轻松访问、分析和有效使用。常见的罪魁祸首包括 PDF 和其他未设计为便于数据提取和处理的文档格式。
我正在解决问题的背景
西雅图公立学校(SPS)在 2023/2024 学年接近尾声时宣布,由于预算缺口超过每年 1 亿美元且持续增长,他们无法克服这一困境。随后,一个项目和分析启动,旨在确定并关闭西雅图近 70 所小学中的最多 20 所。
我是其中一所小学学生的家长。像许多其他在没有太多预警的情况下被推向这个项目的家长一样,尽管学区通过其网页指向了多个PDF 文件,提供了相关数据,但我仍然对数据的开放性和可用性感到沮丧。
当然,也可以有人去逐一复制粘贴每个 PDF 中的数据,但这将花费大量的时间。
当然,也有人可以查看那些已经公开的先前分析(同样是通过 PDF 提供),但这些分析可能只是间接相关。
当然,有人可以通过 CSV 请求这些数据,但这些请求仅由2 个兼职工作人员支持,获取数据的时间通常是以月为单位,而非天。
因此,我花了一些时间来获取我认为任何人都需要的数据,以便合理判断哪些学校(如果有的话)应该关闭。显而易见的信息,如预算、入学人数和设施数据——过去 3 年每所学校的相关数据。
幸运的是,我不需要手动复制粘贴数据。相反,我使用Python来抓取 PDF,从而获得一个任何人都可以用来进行强有力分析的数据集。尽管如此,这仍然花费了很长时间。
当数据被解锁时,可能发生的事情
从我开始收集数据的几周后,你可以看到最终的产品。我开发的应用程序托管在Streamlit平台上,这是一个非常简洁的平台,提供了所有的框架和支持,能够快速实现数据探索或为你的代码提供用户界面。你可以将精力集中在解决问题上,而不是纠结于按钮、HTML 等细节。

该应用程序的默认设置是没有学校关闭,提供了一个基线。用户可以选择学校,查看在学校关闭后的前后数据,包括指标和地图视角,以了解学生如何被重新分配到其他学校。图片由作者提供。
我的探索开始时是对预算和招生本身的检查,但很快转变为一种理解关闭学校所带来的影响的方式——具体来说,学生如何根据招生边界之间的现有关系以及学生在这些边界内外的就读情况进行重新分配。
所以,这就成了我所创建内容的主要使用场景:
作为社区成员,从容量角度来看,特定的学校关闭情景如何影响其他周围学校?

通过加载第一个示例,我们可以看到 16 所学校被标记为关闭,要求将 3400 多名学生重新分配到其他学校,并且大多数学校的容量百分比都显著增加。这种情况导致另外 24 所学校的容量超过了 100%。图像来源:作者。
所有数据都可以通过下面的表格快速下载,用户可以迅速操作并观察自己的场景。例如:“如果他们关闭了我的学校,会怎么样?”
一个无意的顿悟!
在分析这些数据时,我确实做出了一个有趣的观察。这个观察是在完成一个相对简单的线性回归后做出的。回归的 y 轴截距大约为 76 万美元,这代表了学校开放的预估基准成本。简单来说,通过关闭一所学校,重新分配员工和预算,学区可能会看到每所学校平均节省约 76 万美元。因此,关闭多达 20 所学校,保持人员水平并重新分配学生,可能会节省超过 1500 万美元。但这与关闭所需弥补的 1 亿美元赤字之间存在很大差距。这可能需要进一步的分析——如果我能接触到更好(甚至更多)的数据就好了……
突破困境是一个选择
当我进行这个练习时,越来越明显的是,信息自由法案(FOIA)和公共记录法为突破数据监禁提供了一个机会(也许是无意的),当一些简单的抓取技能无法发挥作用时。
其他人可能已经请求过这些数据,获得了必要的批准,并收到了这些数据。尽管共享给请求者的数据被视为公共数据,但它并没有以便捷的方式提供给其他人。这就是问题所在。为什么我不能直接查看并使用别人已经请求并获得的数据呢?
总结
所以——我做了一个东西。我通过使用一个工具从 PDF 中抓取数据。但我也向西雅图公立学校和 Seattle.gov 提出了请求,要求获取过去两年内通过公共请求和信息自由法案(FOIA)提供的所有公共学校数据。这些响应和请求本身也是公共记录。
但是,对于那些没有编写代码抓取数据技能的人来说,这些数据仍然触手可及,却被 PDF、网页和图片所锁住。事情不必是这样的,也不应该是这样的。
当然,有关于首先统一数据格式的讨论是非常必要的。像Delta Lake这样的标准表格格式,看起来是一个非常可扩展且合理的解决方案(感谢Robert Dale Thompson),但即使是使过去的 FOIA(信息自由法案)和公共记录请求的数据,在现有网站如data.seattle.gov上可访问,也似乎是最基本的要求。
让我们携手解锁公共数据的潜力。查看我的Streamlit 应用程序,了解如何通过易于访问的数据带来实际的变化。通过联系当地代表并支持推动透明度的倡议,加入我一起倡导开放数据。与您的社区分享自己的经验和知识,传播意识并推动变革。我们共同努力,可以打破这些数据监狱,确保信息真正对每个人都可获取。
一旦我获得数据,就会有更多内容发布。
过采样与欠采样,详解:带有迷你二维数据集的视觉指南
数据预处理
人为地生成和删除数据,以实现更大的利益
·发表于Towards Data Science ·阅读时间:9 分钟·2024 年 10 月 26 日
--

⛳️ 更多的[数据预处理](https://medium.com/@samybaladram/list/data-preprocessing-17a2c49b44e4),详解:· 缺失值填补 · 分类编码 · 数据缩放 · 离散化 ▶ 过采样与欠采样 · 数据泄漏在预处理中的影响
收集一个每个类别的样本数量完全相同的数据集可能是一项挑战。实际上,情况很少是完美平衡的,在构建分类模型时,这可能成为一个问题。当一个模型在这样的数据集上进行训练时,如果某一类的样本比另一类多,通常它会更擅长预测较大的类别,而在预测较小的类别时表现较差。为了解决这个问题,我们可以使用过采样和欠采样等策略——为较小类别创造更多样本,或从较大类别中删除一些样本。
有许多不同的过采样和欠采样方法(如 SMOTE、ADASYN 和 Tomek Links 等这些令人害怕的名字),但是似乎没有很多资源可以直观地比较它们的工作方式。所以,在这里,我们将使用一个简单的二维数据集来展示应用这些方法后数据发生的变化,以便我们看到每种方法的输出有多大不同。你将会在可视化中看到,这些不同的方法给出了不同的解决方案,谁知道,也许其中某个方法适合你的具体机器学习挑战!

所有可视化:作者使用 Canva Pro 创建。优化为移动设备使用;在桌面上可能会显得过大。
定义
过采样
过采样可以使数据集更平衡,当一个组的样本比另一个组少得多时,它通过增加较小组的样本副本来工作。这有助于数据集更均衡地代表两个组。
欠采样
另一方面,欠采样通过删除较大组中的一些样本,直到它几乎与较小组的大小相同。最终,数据集确实会变小,但两个组的样本数量会更为相似。
混合采样
将过采样和欠采样结合起来,称为“混合采样”。它通过增加较小组样本的副本来增大较小组的规模,同时也通过删除较大组中的一些样本来缩小较大组的规模。它试图创建一个更加平衡的数据集——既不太大也不太小。

📊 使用的数据集
让我们使用一个简单的人工高尔夫数据集来展示过采样和欠采样。这个数据集展示了在特定天气条件下,一个人做什么类型的高尔夫活动。

列:温度(0–3),湿度(0–3),高尔夫活动(A=标准球场,B=练习场,C=室内高尔夫)。训练数据集有 2 个维度和 9 个样本。
⚠️ 请注意,虽然这个小数据集有助于理解概念,但在实际应用中,你应该在应用这些技术之前使用更大的数据集,因为使用过少的数据进行采样可能会导致结果不可靠。
过采样方法
随机过采样
随机过采样是一种简单的方法,通过复制较小组的样本,直到所有类别平衡。
👍 最适合需要快速平衡的小型数据集
👎 不推荐用于复杂的数据集


随机过采样简单地复制较小组(A)中选定的样本,同时保持较大组(B 和 C)中的所有样本不变,如右图中的 A×2 标记所示。
SMOTE
SMOTE(合成少数类过采样技术)是一种过采样技术,通过对较小组进行插值来生成新样本。与随机过采样不同,它不仅仅是复制已有的样本,而是利用较小组的样本生成它们之间的样本。
👍 最适合在你有足够的样本并且需要数据多样性的情况
👎 如果样本数量非常少,不推荐使用
👎 如果数据点过于分散或噪声过大,不推荐使用


SMOTE 通过选择 A 点的成对样本,并在它们之间的某个位置生成新样本,从而创建新的 A 样本。类似地,一个新的 B 点会在随机选取的 B 点对之间生成。
ADASYN
ADASYN(自适应合成)类似于 SMOTE,但专注于在较小组的难学部分生成新的样本。它会找到那些最难分类的样本,并在这些样本周围生成更多的新点。这有助于模型更好地理解具有挑战性的区域。
👍 如果数据的某些部分比其他部分更难分类,最佳选择
👍 最适合处理具有挑战性区域的复杂数据集
👎 如果你的数据相对简单直接,不推荐使用


ADASYN 在较小组(A)的“难学区域”生成更多合成点,这些区域是 A 点与其他组(B 和 C)接近的地方。它还会在类似区域生成新的 B 点。
欠采样方法
欠采样通过缩小较大组的大小,使其与较小组的大小更为接近。有几种方法可以做到这一点:
随机欠采样
随机欠采样通过随机移除较大组中的样本,直到其大小与较小组相同。与随机过采样一样,这种方法相当简单,但它可能会丢失一些重要信息,这些信息展示了各组之间的差异。
👍 最适合处理非常大的数据集,尤其是当样本重复较多时
👍 如果你需要一个快速、简单的解决方案,最佳选择
👎 如果你较大组中的每个样本都很重要,不推荐使用
👎 如果你不能接受丢失任何信息,不推荐使用


随机欠采样通过从较大的组(B 和 C)中随机删除样本,同时保持较小组(A)中的所有样本不变。
Tomek 链接
Tomek Links是一种欠采样方法,它使得组之间的“边界”更清晰。它搜索来自不同组的非常相似的示例对。当它找到一对示例,这些示例是彼此最接近的邻居但属于不同的组时,它会从较大组中移除这个示例。
👍 当你的组之间重叠过多时,效果最佳
👍 最适合清理杂乱或噪声数据
👍 当你需要组之间清晰的边界时,效果最佳
👎 如果你的组已经很分开,不推荐使用


Tomek Links 识别来自不同组(A-B,B-C)的点对,这些点对是彼此最接近的邻居。然后,从较大组(B 和 C)中移除这些点对,而保留所有较小组(A)中的点。
Near Miss
Near Miss是一组基于不同规则的欠采样技术:
-
Near Miss-1:保留较大组中与较小组中的示例最接近的示例。
-
Near Miss-2:保留较大组中与较小组中三个最接近邻居的平均距离最小的示例。
-
Near Miss-3:保留较大组中距离自己组内其他示例最远的示例。
这里的主要思想是保留较大组中最具信息量的示例,去除那些不太重要的示例。
👍 当你需要控制保留哪些示例时,效果最佳
👎 如果你需要一个简单、快速的解决方案,不推荐使用


NearMiss-1 保留来自较大组(B 和 C)中与较小组(A)最接近的点,同时去除其他点。这里,仅保留距离 A 点最近的 B 和 C 点。
ENN
Edited Nearest Neighbors(ENN)方法去除那些可能是噪声或离群点的示例。对于较大组中的每个示例,它检查其大多数最近邻是否属于同一组。如果不属于同一组,它会移除该示例。这有助于创建更清晰的组边界。
👍 最适合清理杂乱的数据
👍 当你需要移除离群点时,效果最佳
👍 最适合创建更清晰的组边界
👎 如果你的数据已经干净且组织良好,不推荐使用


ENN 从较大组(B 和 C)中移除其大多数最近邻属于不同组的点。在右侧的图中,划掉的点被移除,因为它们的大多数最近邻来自其他组。
混合采样方法
SMOTETomek
SMOTETomek的工作原理是,首先使用 SMOTE 为较小的类别创建新样本,然后通过使用 Tomek 链接去除“混淆”样本来清理杂乱的边界。这有助于创建一个更加平衡、边界更清晰且噪声更少的数据集。
👍 最适用于极度不平衡的数据
👍 最适合当你需要更多样本并且边界更清晰时
👍 最适合处理噪声较多且重叠的类别
👎 如果数据已经清理和整理得很好,建议不要使用
👎 不推荐用于小型数据集

SMOTETomek 结合了两个步骤:首先应用 SMOTE,在现有 A 点之间的线段上创建新的 A 点(如中间图所示),然后从较大类别(B 和 C)中去除 Tomek 链接。最终结果是更加平衡的类别,且它们之间的边界更加清晰。
SMOTEENN
SMOTEENN的工作原理是,首先使用 SMOTE 为较小类别创建新样本,然后通过使用 ENN 清理两个类别,去除那些与邻近样本不匹配的样本。像 SMOTETomek 一样,这有助于创建一个更干净的数据集,并使类别之间的边界更加清晰。
👍 最适合同时清理两个类别
👍 最适合当你需要更多样本但又希望数据更干净时
👍 最适合处理大量离群点时
👎 如果数据已经清理和整理得很好,建议不要使用
👎 不推荐用于小型数据集

SMOTEENN 结合了两个步骤:首先使用 SMOTE 在现有 A 点之间的线段上创建新的 A 点(如中间图所示),然后应用 ENN 去除那些邻近点大多来自不同类别的较大类别(B 和 C)中的点。最终图显示了清理后的平衡数据集。
⚠️ 使用重采样方法时的风险
重采样方法可以很有帮助,但也存在一些潜在的风险:
过采样:
-
生成人工样本可能会给出不真实的模式,这些模式在现实生活中并不存在。
-
模型可能因为合成样本而变得过于自信,这会在应用于实际情况时导致严重失败。
-
如果重采样操作不当(比如在数据拆分用于交叉验证之前),可能会发生数据泄漏的风险。
欠采样:
-
你可能会永久丢失重要信息。
-
你可能会意外地破坏类别之间的重要边界,从而导致对问题的误解。
-
你可能会创建与真实世界条件差异过大的人工类别分布。
混合方法:
- 结合两种方法的错误可能会使问题变得更糟,而不是更好。
在使用重采样方法时,很难找到在不改变数据中重要模式的情况下,既能解决类别不平衡问题又不影响模型性能的平衡点。根据我的经验,错误的重采样实际上可能会损害模型性能,而不是提高它。
在进行重采样之前,尝试使用那些自然能更好处理不平衡数据的模型,例如基于树的算法。重采样应作为更广泛策略的一部分,而不是解决类别不平衡的唯一方案。
🌟 超采样与欠采样代码总结
对于代码示例,我们将使用由 [imblearn](https://imbalanced-learn.org/stable/index.html) 库提供的方法:
import pandas as pd
from imblearn.over_sampling import SMOTE, ADASYN, RandomOverSampler
from imblearn.under_sampling import TomekLinks, NearMiss, RandomUnderSampler
from imblearn.combine import SMOTETomek, SMOTEENN
# Create a DataFrame from the dataset
data = {
'Temperature': [1, 0, 1, 3, 2, 3, 1, 3, 4],
'Humidity': [0, 2, 1, 1, 3, 2, 3, 4, 4],
'Activity': ['A', 'A', 'B', 'B', 'B', 'C', 'C', 'C', 'C']
}
df = pd.DataFrame(data)
# Split the data into features (X) and target (y)
X, y = df[['Temperature', 'Humidity']], df['Activity'].astype('category')
# Initialize a resampling method
# sampler = RandomOverSampler() # Random OverSampler for oversampling
sampler = SMOTE() # SMOTE for oversampling
# sampler = ADASYN() # ADASYN for oversampling
# sampler = RandomUnderSampler() # Random UnderSampler for undersampling
# sampler = TomekLinks() # Tomek Links for undersampling
# sampler = NearMiss(version=1) # NearMiss-1 for undersampling
# sampler = EditedNearestNeighbours() # ENN for undersampling
# sampler = SMOTETomek() # SMOTETomek for a combination of oversampling & undersampling
# sampler = SMOTEENN() # SMOTEENN for a combination of oversampling & undersampling
# Apply the resampling method
X_resampled, y_resampled = sampler.fit_resample(X, y)
# Print the resampled dataset
print("Resampled dataset:")
print(X_resampled)
print(y_resampled)
技术环境
本文使用 Python 3.7、pandas 1.3 和 imblearn 1.2。虽然讨论的概念普遍适用,但不同版本的具体代码实现可能会有所不同。
关于插图
除非另有说明,所有图片均由作者创作,包含了来自 Canva Pro 的授权设计元素。
𝙎𝙚𝙚 𝙢𝙤𝙧𝙚 𝘿𝙖𝙩𝙖 𝙋𝙧𝙚𝙥𝙧𝙤𝙘𝙚𝙨𝙨𝙞𝙣𝙜 𝙢𝙚𝙩𝙝𝙤𝙙𝙨 𝙝𝙚𝙧𝙚:

数据预处理
查看列表6 个故事


𝙔𝙤𝙪 𝙢𝙞𝙜𝙝𝙩 𝙖𝙡𝙨𝙤 𝙡𝙞𝙠𝙚:

分类算法
查看列表8 个故事



回归算法
查看列表5 个故事!一个戴着粉色帽子、扎着辫子的卡通娃娃。这个“假人”娃娃,凭借其简单的设计和心形图案的衬衫,在视觉上代表了机器学习中的假回归模型(dummy regressor)概念。就像这个玩具般的形象是一个简化的、静态的人物表达,假回归模型是一些基本的模型,用作更复杂分析的基准。

在 Python 中覆盖对象:棘手、危险且强大
PYTHON 编程
虽然覆盖对象是 Python 编程中的典型技巧,但它可能会导致意想不到的效果。你需要了解如何使用它,才能发挥其优势。
·发表于Towards Data Science ·27 分钟阅读·2024 年 4 月 10 日
--

在 Python 中覆盖对象可能是危险的:不要在未深思熟虑的情况下进行操作。照片由Raúl Nájera提供,来源于Unsplash
我曾经覆盖过许多不同的对象。如果你也曾在 Python 中编写过代码,可能也有过类似的经历。这是因为,在 Python 中,覆盖对象是语言的核心特性之一。
我不仅覆盖了变量,还覆盖了函数、类和类方法——甚至是异常:
[## 如何在 Python 中覆盖 AssertionError 并使用自定义异常
Python 的 assert 语句使用的是 AssertionError。了解如何使用其他异常来代替它。
betterprogramming.pub](https://betterprogramming.pub/how-to-overwrite-asserterror-in-python-and-use-custom-exceptions-c0b252989977?source=post_page-----04b12a9b1a7e--------------------------------)
事实上,我们会区分覆盖变量和可调用对象。这两者之间的区别实际上相当重要,后者更为棘手。尽管我们的重点是覆盖可调用对象,因为它是一种更高级的技巧,我们也会讨论覆盖变量,因为它为我们的讨论提供了一个良好的起点。
P-Companion:亚马逊的多元化补充产品推荐原则框架
深入探讨亚马逊的补充产品推荐框架
·发布于 Towards Data Science ·11 分钟阅读·2024 年 10 月 1 日
--
介绍
补充产品推荐(CPR)在电子商务平台的成功中变得越来越重要。CPR 的目标是提供最相关的产品,这些产品通常是一起购买的。手机和手机壳经常一起购买;网球拍购买时,通常会一起购买网球;购买笔记本电脑后,通常会购买鼠标。本文将讨论亚马逊如何将 CPR 作为一种产品到产品的推荐问题解决:
给定一个“查询”产品,目标是推荐与该“查询”产品相关且多样的补充产品,使其可以一起购买,从而满足共同的需求。

补充产品推荐场景(图片来源 [1])
预测这样的补充产品是一个非平凡的任务。让我们通过一个简单的例子来了解在解决这个问题时所面临的挑战。假设网球拍是“查询产品”,平台会显示三份相关产品列表。
-
列表 1 包括其他三款相似的网球拍。
-
列表 2 包括三款网球。
将你的 TypeScript 客户端打包成 Python 后端

图片由 Markus Spiske 提供,来源于 Unsplash
完整的实操指南
将你的 React 应用与 FastAPI Web 服务器结合
·发布于 Towards Data Science ·阅读时间 6 分钟·2024 年 4 月 5 日
--
在本指南中,你将学习如何将一个简单的 TypeScript React 应用程序 打包成一个 Python 包,并通过你的 FastAPI Python Web 服务器提供服务。如果你想查看完整的代码,请查看 客户端 和 服务器 的代码库。让我们开始吧!
在开发过程中,你可能会使用两个不同的集成开发环境(IDE):
-
TypeScript 或 JavaScript React 应用窗口,运行在专用监听端口(例如:5173)上,用于提供客户端/前端页面。
-
Python FastAPI,运行在另一个端口(例如:8080)上,用于提供 REST API。
换句话说,你有两个不同的服务器在本地运行。每当你想调用 FastAPI 服务器时,浏览器需要与两个不同的服务器进行交互。

本地开发(图示来自作者)
虽然在本地(localhost)运行时一切正常,但当你将代码部署到生产环境时,浏览器会遇到“跨源请求被阻止”的错误。在将代码部署到生产环境之前,最佳实践是将客户端页面和 REST API 都从同一个后端 Web 服务器提供服务。这样浏览器将只与一个后端交互,这对安全性、性能和简洁性更有益。

准备上线(图示来自作者)
1. 创建一个简单的 React 应用
首先,在你的 workspace 目录中,使用 vite 创建一个新的 TypeScript React 应用程序:
~/workspace ➜ npm create vite@latest
✔ Project name: … vite-project
✔ Select a framework: › React
✔ Select a variant: › TypeScript
然后,进入新的项目目录,安装依赖并运行应用程序(localhost:5173):
~/workspace ➜ cd vite-project
~/workspace/vite-project ➜ npm install
~/workspace/vite-project ➜ npm run dev
你应该会看到类似的内容:

第一个 Vite React 模板(图片来自作者)
现在,让我们对模板做一个小修改——我们将添加一个异步 HTTP 调用来获取未来 FastAPI 后端的状态:
function App() {
...
const [health, setHealth] = useState('');
useEffect(() => {
const getStatus = async () => {
const response = await fetch('/v1/health-check/liveness', {
method: 'GET',
});
let status: { [status: string]: string } = {};
try {
status = await response.json();
} catch (err) {
console.log(`failed to get backend status. ${err}`);
}
setHealth(status['status'] || 'unknown');
};
getStatus();
}, []);
return (
...
<div>Backend Status: {health}</div>
...
)
}
现在我们应该能看到类似这样的结果:

使用后端调用(图片来自作者)
此时,后端状态是 unknown,因为我们尚未实现它。别担心,我们很快就会处理它。最后,让我们构建客户端,方便稍后进行打包:
~/workspace/vite-project ➜ npm run build
打包输出应创建一个 dist 文件夹,其中包含最终优化后的代码,如下所示:
└── dist/
├── assets/
├── static/
└── index.html
2. 构建 Python 包
此时,我们切换到 Python 环境。我更喜欢在 虚拟环境 中工作以实现隔离。在一个专用的虚拟环境中,我们将安装 twine 和 build 来创建我们的 Python 包:
~/workspace/vite-project ➜ python3 -m venv venv
~/workspace/vite-project ➜ . venv/bin/activate
~/workspace/vite-project (venv) ➜ python -m pip install --upgrade pip
~/workspace/vite-project (venv) ➜ pip install twine==5.0.0 build==1.2.1
在根文件夹(vite-project)中创建一个新的 setup.py 文件,内容如下:
from setuptools import setup
from pathlib import Path
cwd = Path(__file__).parent
long_description = (cwd / "README.md").read_text()
setup(
name="vite-project",
version="0.0.1",
package_dir={"vite_project": "dist"},
package_data={"vite_project": ["**/*.*"]},
long_description=long_description,
long_description_content_type="text/markdown",
)
然后运行以下命令来创建包:
~/workspace/vite-project (venv) ➜ python setup.py sdist -d tmp
~/workspace/vite-project (venv) ➜ python -m build --wheel --outdir tmp
~/workspace/vite-project (venv) ➜ twine upload -u ${USERNAME} -p ${PASSWORD} --repository-url ${REPO_URL} tmp/*
上面的最后一行是可选的,如果你打算将你的包上传到远程仓库,例如 PyPI、JFrog Artifactory 等。
3. 创建一个 FastAPI Python Web 服务器
最后一步是构建 Python 服务器并使用客户端包。为此,我们将:
-
创建一个新的
backend目录。 -
创建一个新的虚拟环境。
-
安装相关的包和我们的客户端包:
~/workspace/backend ➜ python3 -m venv venv
~/workspace/backend ➜ . venv/bin/activate
~/workspace/backend (venv) ➜ python -m pip install --upgrade pip
~/workspace/backend (venv) ➜ pip install fastapi==0.110.0 uvicorn==0.29.0
~/workspace/backend (venv) ➜ pip install ~/workspace/vite-project/tmp/vite-project-0.0.1.tar.gz
请注意,我们从之前创建的本地路径安装了我们的客户端包。如果你将包上传到远程仓库,你可以使用以下命令安装:
~/workspace/backend (venv) ➜ pip install --extra-index-url https://${USERNAME}:${PASSWORD}@${REPO_URL} vite-project==0.0.1
接下来,让我们创建一个简单的 Python 服务器(2 个文件):
main.py
from distutils.sysconfig import get_python_lib
from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from backend.health_router import router
from uvicorn import run
def create_app():
app = FastAPI(
title="Backend Server",
)
app.include_router(router)
client_path = f"{get_python_lib()}/vite_project"
app.mount("/assets", StaticFiles(directory=f"{client_path}/assets"), name="assets")
app.mount("/static", StaticFiles(directory=f"{client_path}/static"), name="static")
@app.get("/{catchall:path}")
async def serve_react_app(catchall: str):
return FileResponse(f"{client_path}/index.html")
return app
def main():
app = create_app()
run(app, host="0.0.0.0", port=8080)
if __name__ == "__main__":
main()
health_router.py
from typing import Literal
from typing_extensions import TypedDict
from fastapi import APIRouter, status
STATUS = Literal["success", "error", "partial", "unknown"]
class ReturnHealthcheckStruct(TypedDict):
status: STATUS
router = APIRouter(
prefix="/v1/health-check",
tags=["Health Check"],
)
@router.get(
"/liveness",
summary="Perform a Liveness Health Check",
response_description="Return HTTP Status Code 200 (OK)",
status_code=status.HTTP_200_OK,
response_model=ReturnHealthcheckStruct,
)
async def liveness() -> ReturnHealthcheckStruct:
return {"status": "success"}
在上面的实现中,我们通过挂载 static 和 assets 文件夹以及任何其他客户端文件来支持从客户端应用程序提供静态文件,并由我们的 Python 服务器进行服务。
我们还创建了一个简单的 GET 端点 v1/health-check/liveness,它返回一个简单的 {“status": “success"} JSON 响应。这样我们就能确保服务器同时处理客户端静态文件和服务器端 RESTful API。
现在,如果我们访问 localhost:8080,我们可以看到我们的客户端正在运行。注意下面的 后端状态,现在它是 success(而不是 unknown)。

同时运行 Python 服务器与 React 应用程序(图片来自作者)
总结
在本教程中,我们创建了一个简单的 React 应用程序,该应用程序向后端发出单个请求。我们将这个客户端应用程序封装为一个 Python 包,并通过我们的 FastAPI Python Web 服务器提供服务。
采用这种方法可以让你在两个领域中利用最好的工具:前端使用 TypeScript 和 React,后端使用 Python 和 FastAPI。然而,我们希望保持这两个组件之间的高度内聚和低耦合。这样,你将获得所有的好处:
-
效率,通过将前端和后端分离到不同的代码库,每个部分可以由不同的团队进行开发。
-
稳定性和质量,通过锁定版本化的客户端包,并仅在服务器准备好支持新客户端版本时才进行更新。
-
安全性——浏览器仅与一个后端服务器进行交互。我们无需启用 CORS 或任何其他可能妥协安全性的解决方法。
-
简单性——通过使用单一服务器进行工作
PAGA 解释:单细胞数据的图形抽象
如何从更广阔的视角看待数据,进而获得对其深层含义的洞察
·发表于Towards Data Science ·阅读时长 6 分钟·2024 年 6 月 6 日
--

图片来源:Clint Adair 在Unsplash
在单细胞基因组学数据中,我们对个体细胞的数万个特征进行分析,无论是在基因表达、蛋白质表达,还是其他全基因组测量模式下,我们通常都会努力从数据中提取出高层次的总结。这可以通过多种方式实现,例如差异基因表达,我们通过对细胞群体进行统计检验,比较不同群体之间哪些基因在统计上具有显著性;或者数据可视化,我们将由多个特征测量所带来的高维数据压缩成 2 或 3 个维度,从而帮助我们理解数据。
理解单细胞数据的一种有效方式是将其转化为图——即一组由节点和边组成的元组。在这种情况下,节点代表细胞,边则定义了细胞之间的连接。这种基于元组的数据结构提供了一种灵活的方式来将细胞分组,探索不同类型细胞之间的关系(例如,病变细胞与正常细胞的区别,不同发展阶段的细胞),以及可视化数据集的整体结构。
Pandas 列:括号索引(df[‘x’])与点语法(df.x)
PANDAS 数据科学应用
使用哪种方式有区别吗?也许某种方式比另一种更快?
·发表于 Towards Data Science ·阅读时间:5 分钟·2024 年 3 月 1 日
--

点语法在 Python 中非常流行,在 Pandas 中也广泛使用。图片来源:Alejandro Barba 提供的 Unsplash
使用 Pandas 时,大多数数据科学家会选择 df['x'] 或 df["x"]——其实这两个用法没有太大区别,只要你选定其中一个并坚持使用即可。你可以在这里查看更多相关内容:
[## 所以,Python 中使用单引号(‘)还是双引号(“)?
很多人认为你应该在 Python 中偏好使用单引号而不是双引号。你真的应该这么做吗?
因此,从现在开始,无论我写的是 df["x"],它都等同于 df['x']。不过,还有另一种选择。你也可以使用 df.x。虽然这种方法不太常见,但它能提高代码的可读性,前提是列名是一个有效的 Python 标识符。¹
选择哪种语法真的有区别吗?本文旨在从两个最重要的角度讨论这个问题:可读性和性能。
优缺点
这两种方法——df["x"] 和 df.x——是从数据框(此处为 df)中访问列(此处为 "x")的常见方式。在数据科学领域,大多数…
数据工程师的 Pandas
高级技巧:高效处理和加载数据
·发布于 Towards Data Science ·阅读时间 9 分钟·2024 年 2 月 10 日
--

使用 Kandinsky 生成的 AI 图像
在这篇文章中,我想谈谈我喜欢并且在编写 ETL 应用程序时经常使用 Pandas 的一些功能,用来处理数据。我们将讨论探索性数据分析、数据清洗和数据框变换。我将展示一些我最喜欢的技巧,以优化内存使用并高效处理大量数据。使用 Pandas 处理相对较小的数据集通常不会有问题。它能够轻松处理数据框中的数据,并提供一组非常方便的命令来处理这些数据。当涉及到更大数据框(1GB 及以上)上的数据变换时,我通常会使用 Spark 和分布式计算集群。Spark 能处理数 TB 和 PB 级的数据,但运行这些硬件可能会花费大量的资金。因此,当我们在内存资源有限的环境中处理中等大小的数据集时,Pandas 可能是一个更好的选择。
Pandas 和 Python 生成器
在我之前的一篇文章中,我写到了如何使用 Python 中的生成器高效处理数据 [1]。
Pandas: 从杂乱到优雅
这就是如何让你的 pandas 代码既易于阅读又坚不可摧。
·发布于 Towards Data Science ·阅读时间 9 分钟 ·2024 年 3 月 31 日
--

在 pandas DataFrame 上编写脚本可能会变成一堆尴尬的(并不那么)老式意大利面式代码。我和我的同事们经常使用这个包,尽管我们尽力遵循良好的编程实践,比如将代码拆分成模块和进行单元测试,但有时我们还是会互相妨碍,编写出令人困惑的代码。
我收集了一些技巧和陷阱,帮助避免在编写 pandas 代码时出现问题,从而使代码更加清晰和万无一失。希望你也能从中受益。我们将借助 Robert C. Martin 的经典著作《Clean Code》,特别是结合 pandas 包的背景进行讨论。简短总结见文末。
不该做的事
让我们从一些受真实案例启发的错误模式开始分析。稍后,我们将尝试重构这些代码,以提升可读性和可控性。
可变性
Pandas DataFrame 是值可变的 [2, [3]](https://realpython.com/python-mutable-vs-immutable-types/) 对象。每当你更改一个可变对象时,它会影响你最初创建的相同实例,并且它在内存中的物理位置保持不变。相比之下,当你修改一个不可变对象(例如字符串)时,Python 会创建一个全新的对象,并将其存储在新的内存位置,然后替换为对该新对象的引用。
这是关键点:在 Python 中,对象通过赋值传递给函数[4, [5]](https://realpython.com/python-pass-by-reference/)。 看图:df 的值在作为参数传递给函数时被赋值给变量 in_df。即使它们使用不同的变量名,原始的 df 和函数内部的 in_df 都指向相同的内存位置(括号中的数字值)。在修改其属性时,可变对象的位置保持不变。现在,所有其他作用域也能看到这些变化——它们都指向相同的内存位置。

修改可变对象在 Python 内存中的情况。
实际上,由于我们已修改了原始实例,因此返回 DataFrame 并将其赋值给变量是多余的。这个代码有相同的效果:

修改可变对象在 Python 内存中的情况,冗余赋值已被移除。
提醒:现在函数返回 None,因此要小心不要将 df 覆盖为 None,如果你确实执行赋值:df = modify_df(df)。
相反,如果对象是不可变的,它会在整个修改过程中改变内存位置,就像下面的例子一样。由于红色字符串无法修改(字符串是不可变的),绿色字符串是在旧字符串之上创建的,但它是一个全新的对象,声明了一个新的内存位置。返回的字符串不是同一个字符串,而返回的 DataFrame 是完全相同的 DataFrame。

修改不可变对象在 Python 内存中的情况。
关键是,在函数内修改 DataFrame 会产生全局效果。如果你没有牢记这一点,可能会:
-
意外修改或删除部分数据,以为操作只发生在函数作用域内——其实并非如此,
-
失去对何时向
DataFrame添加内容的控制,例如在嵌套的函数调用中。
输出参数
我们稍后会解决这个问题,但在我们进入do之前,这里还有另一个don't
前一节的设计实际上是一种反模式,称为输出参数[1 p.45]。通常,函数的输入将用于创建输出值。如果传递一个参数给函数的唯一目的是修改它,使得输入参数改变其状态,那么这就挑战了我们的直觉。这种行为被称为函数的副作用[1 p.44],应当做好文档记录并尽量减少,因为它们迫使程序员记住那些发生在后台的事情,从而使得脚本容易出错。
当我们阅读一个函数时,我们习惯了信息通过参数进入函数并通过返回值输出。我们通常不期望信息会通过参数输出。[1 p.41]
如果函数有双重责任:既修改输入又返回输出,情况会变得更糟。考虑这个函数:
def find_max_name_length(df: pd.DataFrame) -> int:
df["name_len"] = df["name"].str.len() # side effect
return max(df["name_len"])
它的确返回了一个值,正如你所期望的,但它也永久修改了原始的DataFrame。这个副作用会让你感到惊讶——函数签名中没有任何东西表明我们的输入数据会被影响。在下一步中,我们将看到如何避免这种设计。
应该做的事
减少修改
为了消除副作用,在下面的代码中,我们创建了一个新的临时变量,而不是修改原始的DataFrame。符号lengths: pd.Series表示该变量的数据类型。
def find_max_name_length(df: pd.DataFrame) -> int:
lengths: pd.Series = df["name"].str.len()
return max(lengths)
这种函数设计更好,因为它封装了中间状态,而不是产生副作用。
另一个提醒:请注意深拷贝和浅拷贝的区别[6]。在上面的示例中,我们修改了原始df["name"] Series的每个元素,因此旧的DataFrame和新的变量没有共享的元素。然而,如果你直接将原始列之一赋值给一个新变量,底层元素在内存中仍然引用相同的对象。请参阅以下示例:
df = pd.DataFrame({"name": ["bert", "albert"]})
series = df["name"] # shallow copy
series[0] = "roberta" # <-- this changes the original DataFrame
series = df["name"].copy(deep=True)
series[0] = "roberta" # <-- this does not change the original DataFrame
series = df["name"].str.title() # not a copy whatsoever
series[0] = "roberta" # <-- this does not change the original DataFrame
你可以在每一步后打印出DataFrame,以观察效果。记住,创建深拷贝会分配新的内存,所以考虑一下你的脚本是否需要高效地使用内存。
分组相似的操作
也许由于某种原因,你想存储那个长度计算的结果。把它附加到函数内部的DataFrame中依然不是一个好主意,因为这违反了副作用的规则,并且使一个函数承担了多重责任。
我喜欢每个函数一个抽象层次的规则,它的意思是:
我们需要确保函数中的所有语句都处于相同的抽象层次。
在一个函数内混合不同的抽象层次总是令人困惑。读者可能无法分辨某个表达式是核心概念还是细节。[1 p.36]
另外,尽管我们现在并不专注于面向对象代码,但我们还是可以借鉴 OOP 中的单一职责原则[1 p.138]。
为什么不提前准备好数据呢?我们将数据准备和实际计算分成不同的函数:
def create_name_len_col(series: pd.Series) -> pd.Series:
return series.str.len()
def find_max_element(collection: Collection) -> int:
return max(collection) if len(collection) else 0
df = pd.DataFrame({"name": ["bert", "albert"]})
df["name_len"] = create_name_len_col(df.name)
max_name_len = find_max_element(df.name_len)
创建name_len列的单一任务已外包给另一个函数。它不会修改原始的DataFrame,并且一次只执行一个任务。稍后我们通过将新列传递给另一个专用函数来获取最大元素。注意,聚合函数对于Collections来说是通用的。
让我们通过以下步骤来优化代码:
-
我们可以使用
concat函数,并将其提取到一个名为prepare_data的单独函数中,这样可以将所有数据准备步骤集中在一个地方, -
我们还可以利用
apply方法,处理单个文本,而不是Series中的文本, -
让我们记住使用浅拷贝和深拷贝,具体取决于是否应该修改原始数据:
def compute_length(word: str) -> int:
return len(word)
def prepare_data(df: pd.DataFrame) -> pd.DataFrame:
return pd.concat([
df.copy(deep=True), # deep copy
df.name.apply(compute_length).rename("name_len"),
...
], axis=1)
可重用性
我们拆分代码的方式确实让我们以后很容易回到脚本,拿出整个函数并在另一个脚本中重用。我们喜欢这样!
还有一件事我们可以做,以提高可重用性:将列名作为参数传递给函数。重构有点过头,但有时为了灵活性或可重用性,还是值得的。
def create_name_len_col(df: pd.DataFrame, orig_col: str, target_col: str) -> pd.Series:
return df[orig_col].str.len().rename(target_col)
name_label, name_len_label = "name", "name_len"
pd.concat([
df,
create_name_len_col(df, name_label, name_len_label)
], axis=1)
可测试性
你是否曾在对预处理数据集进行几周的实验后,才发现预处理有问题?没有?真幸运。我实际上因为注释错误不得不重复一批实验,如果我在一开始测试过一些基本的函数,完全可以避免这种情况。
重要的脚本应该进行测试 [1 p.121, 7]。即使这个脚本只是一个辅助工具,我现在也会尝试至少测试最关键的、最底层的函数。让我们回顾一下从一开始到现在的步骤:
1. 我甚至不愿意考虑测试这个,它非常冗余,我们已经忽视了副作用。它还测试了许多不同的功能:计算名称长度和聚合最大元素的结果。而且它失败了,你预见到了吗?
def find_max_name_length(df: pd.DataFrame) -> int:
df["name_len"] = df["name"].str.len() # side effect
return max(df["name_len"])
@pytest.mark.parametrize("df, result", [
(pd.DataFrame({"name": []}), 0), # oops, this fails!
(pd.DataFrame({"name": ["bert"]}), 4),
(pd.DataFrame({"name": ["bert", "roberta"]}), 7),
])
def test_find_max_name_length(df: pd.DataFrame, result: int):
assert find_max_name_length(df) == result
2. 这样好多了——我们集中精力处理一个任务,因此测试变得简单。我们也不需要像之前那样过于关注列名。然而,我认为数据格式阻碍了验证计算正确性。
def create_name_len_col(series: pd.Series) -> pd.Series:
return series.str.len()
@pytest.mark.parametrize("series1, series2", [
(pd.Series([]), pd.Series([])),
(pd.Series(["bert"]), pd.Series([4])),
(pd.Series(["bert", "roberta"]), pd.Series([4, 7]))
])
def test_create_name_len_col(series1: pd.Series, series2: pd.Series):
pd.testing.assert_series_equal(create_name_len_col(series1), series2, check_dtype=False)
3. 这里我们清理了桌面。我们彻底测试了计算函数,抛开了 Pandas 的封装。在专注于一个问题时,更容易想到边界情况。我发现我想测试DataFrame中可能出现的None值,最终我不得不改进我的函数,以使这个测试通过。一个 bug 被发现了!
def compute_length(word: Optional[str]) -> int:
return len(word) if word else 0
@pytest.mark.parametrize("word, length", [
("", 0),
("bert", 4),
(None, 0)
])
def test_compute_length(word: str, length: int):
assert compute_length(word) == length
4. 我们只差对find_max_element的测试了:
def find_max_element(collection: Collection) -> int:
return max(collection) if len(collection) else 0
@pytest.mark.parametrize("collection, result", [
([], 0),
([4], 4),
([4, 7], 7),
(pd.Series([4, 7]), 7),
])
def test_find_max_element(collection: Collection, result: int):
assert find_max_element(collection) == result
我从不忘提到的单元测试的另一个好处是,它是文档化代码的一种方式,因为不懂代码的人(比如未来的你)可以仅通过查看测试,轻松了解输入和期望输出,包括边界情况。双重收获!
结论
这些是我在编写代码和审查别人代码时发现有用的一些技巧。我并不是告诉你某种编码方式就是唯一正确的——你可以从中选择你需要的,你决定是要一个快速的临时解决方案,还是要一个高度精炼和经过测试的代码库。我希望这篇文章能帮助你更好地组织脚本,使你对它们更加满意并对它们的可靠性感到更加自信。
如果你喜欢这篇文章,请告诉我。我很高兴能得到反馈。祝你编码愉快!
摘要
没有一种唯一正确的编码方式,但以下是一些使用 pandas 进行脚本编写的灵感:
禁止事项:
在函数内尽量不要频繁修改你的
*DataFrame*,因为这样你可能会失去对其中内容的控制,无法知道哪些数据被追加或移除。不要编写修改
*DataFrame*并且不返回任何结果的方法,因为这会让人感到困惑。必做事项:
创建新对象,而不是修改原始的
*DataFrame*,并记得在需要时进行深拷贝,在一个函数内只执行类似层次的操作,
设计灵活且可重用的函数,
测试你的函数,因为这能帮助你设计更简洁的代码,防止出现错误和边界情况,还能免费为代码生成文档。
参考文献
-
[1] Robert C. Martin,《代码整洁之道:敏捷软件工艺手册》(2009),Pearson Education, Inc.
-
[2] pandas 文档 - 包概述 — 数据的可变性与拷贝,
pandas.pydata.org/pandas-docs/stable/getting_started/overview.html#mutability-and-copying-of-data -
[3] Python 的可变类型与不可变类型:有何区别?,
realpython.com/python-mutable-vs-immutable-types/ -
[4] 理解 Python 对象的可变性:5 个层次,
medium.com/techtofreedom/5-levels-of-understanding-the-mutability-of-python-objects-a5ed839d6c24 -
[5] Python 中的引用传递:背景与最佳实践,
realpython.com/python-pass-by-reference/ -
[6] Python 对象的浅拷贝与深拷贝,
realpython.com/copying-python-objects/ -
[7] Brian Okken,《Python Testing with pytest》,第二版(2022),The Pragmatic Programmers, LLC.
这些图表是我使用Miro制作的。封面图片也是我使用Titanic数据集和 GIMP(涂抹效果)制作的。
Pandas 索引和标题,你曾感到困惑过吗?

由作者在 Canva 中创建
从单级索引和标题到多级索引和标题,为什么以及如何实现?
·发表于Towards Data Science ·阅读时间:9 分钟·2024 年 6 月 9 日
--
有很多关于 Pandas 的教程,但大多数都在试图告诉我们一些技巧。我仍然记得当我刚接触 Python Pandas 库时,我曾对索引和标题感到困惑,尤其是在它们有多个级别的时候。
在这篇文章中,我将重点介绍与索引和标题相关的概念,以及 Pandas 数据框的重塑操作。我知道多级索引和多级标题是最让人困惑的概念,因此我将尽力通过许多示例在一个部分中进行解释。希望这篇文章能帮助你理解这些概念!
1. 基础

由作者在 Canva 中创建
让我们从创建一个简单的数据框开始,我们需要它来展示后面的示例。如果你是 Pandas 的新手,这也是一个很好的热身机会。
1.1 创建带有标题的数据框
import pandas as pd
data = {
'Name': ['Alice', 'Bob', 'Chris'],
'Age': [25, 30, 35]
}…
Pandas:我参与贡献一个重要开源项目的经历
开源
你也许值得参与其中
·发布于 Towards Data Science ·18 分钟阅读 ·2024 年 4 月 19 日
--

图片来源:Markus Winkler 在 Pexels
开源项目通常依赖于众多人的贡献,以保持它们无错、安全、更新并不断向前发展。
这些人是谁?嗯,“他们”可能就是你我!除了花费一些时间和精力外,没有任何东西能阻止我们贡献。所以,值得我们花时间吗?如果我们决定尝试一下,究竟涉及些什么呢?
我决定参与贡献一个广为人知且被广泛使用的开源库—— Pandas.
希望我在这个过程中所经历的能够为你提供一些关于所涉及内容的见解,并可能突出一种你也许能够参与其中的方式。或者,至少它能够让你对幕后发生的事情有一个很好的了解!
介绍
在我之前的一篇 文章 中,我提到了使用 pandas 库生成 箱线图 时遇到的一个 bug。
我最初查找了这个 bug,看看是否有其他人遇到同样的问题。我发现了一个 stackoverflow 的帖子,看起来也正是我遇到的那个问题:
Pandas vs. Polars — 该是时候切换了吗?
想要将数据处理管道的速度提高至 10 倍吗?也许是时候说再见给 Pandas 了。
·发表于Towards Data Science ·阅读时间 7 分钟·2024 年 4 月 7 日
--

图片来源:Hans-Jurgen Mager提供,来自Unsplash
在一个按秒计费的计算世界里,尽可能地将计算时间最小化是合乎逻辑的。甚至要更进一步。
Python 庞大的数据处理生态系统对于初学者来说非常友好,但随着数据集的增大,扩展变得十分困难。并行处理、查询优化和延迟计算在 Pandas 中是闻所未闻的概念,但如果你想在大规模生产环境中使用 Python,这些概念是必须理解的。
介绍Polars。它是一个从头开始编写的 Python 库,专注于性能优化。Polars 具有一个用 Rust 编写的多线程查询引擎,这意味着你可以期待看到比 Pandas 快得多的数据处理速度,甚至比 Pandas 快 30 到 50 倍。
今天,你将看到如何通过一系列 4 个基准测试,比较 Polars 与 Pandas 在处理一个有 1100 万行的 CSV 文件时的表现。
但首先,让我们来回顾一下为什么你应该考虑将 Polars 作为 Pandas 的替代品。
Pandas vs. Polars — 为什么作为数据专业人士你应该考虑 Polars
论文回顾 — 用于软件开发的交互式代理
对“ChatDev”AI 代理论文的详细回顾
·发表于 Towards Data Science ·11 分钟阅读·2024 年 6 月 8 日
--

ChatDev论文封面截图
在阅读并回顾了生成代理论文后,我决定探索 AI 编码代理的世界。我这段旅程的下一站是题为“用于软件开发的交互式代理”的论文,也被称为ChatDev。
ChatDev 通过利用大型语言模型,通过人类用户与 AI 代理之间仅用自然语言交流,简化整个软件开发过程,从而为软件开发呈现了一种创新的范式。
正如你所猜测的,这是一个雄心勃勃的任务,这也使得这篇论文成为一篇同样令人兴奋的阅读材料。
从本质上讲,ChatDev 是一个虚拟的、由聊天驱动的软件开发公司,将软件代理聚集在一起,进行编码、设计、测试并生产指定的应用程序。
在这篇文章中,我们将解释这项工作的动机,接着深入探讨 ChatDev 的架构。最后,我们将展示这篇论文的发现,并分享我们对这项工作的个人看法。开始吧!
为什么我们希望 AI 来构建软件应用程序?
对许多人来说,软件是我们世界的魔法。就像在神秘的领域里,巫师能够施展魔法创造物理对象一样,在我们现实中,软件工程师能够创建各种程序,来增强、自动化并提升我们的生活。
然而,构建软件并非易事。它需要硬技能、团队合作、经验、直觉和品味。它也是昂贵的。
这些因素使得自动化创建软件变得困难。
世界各地的许多人和企业都希望创建软件程序以获取利润或娱乐,但他们没有足够的技能和资本去实现这一目标。这使得我们面临着巨大的未开发潜力和无法满足的机会,这些机会本可以改善人们的生活并丰富经济。
然而,近年来人工智能,特别是深度学习和大型语言模型的进步,使得我们能够在适度的成功水平上应对这一挑战。
在 ChatDev 中,研究人员提出了一个雄心勃勃的任务,旨在通过利用大型语言模型的力量生成整个软件程序。
ChatDev 架构
ChatDev 是一家虚拟的、由聊天驱动的软件开发公司,模仿了传统的瀑布模型来构建软件。它通过精确地将开发过程分为四个不同的时间顺序阶段:设计、编码、测试和文档编写。
每个阶段开始时,会招聘一组专业的软件代理。例如,一个阶段可能涉及招聘首席技术官(CTO)、程序员和设计师代理。

截图 — 阶段 + 聊天链 — 来自 ChatDev 论文
每个阶段进一步细分为原子级的聊天,这些聊天被称为聊天链,代表着两个代理之间的中间任务解决聊天序列。每个聊天的设计目的是实现一个特定目标,并计入构建所需应用程序的总体目标中。
这些聊天按顺序连接在一起,以便将来自两个 AI 代理的先前聊天结果传播到后续的涉及两个其他 AI 代理的聊天中。
解决代码幻觉问题
ChatDev 解决的关键挑战之一是代码幻觉问题,这种问题可能出现在直接使用大型语言模型(LLM)生成整个软件系统时。
这些幻觉可能包括不完整的功能实现、缺失的依赖关系和未发现的错误。研究人员将这一现象归因于两个主要原因:
1. 缺乏细粒度和具体性:尝试一次性生成所有代码,而不是将目标分解为如语言选择和需求分析等阶段,可能会导致大型语言模型(LLM)的混乱。
2. 缺乏交叉检查和自我反思:对某个代理所做工作的反馈不充分或不具针对性,导致生成的代码不正确,而大型语言模型(LLM)没有纠正这些错误。
为了解决这些挑战,ChatDev 采用了一种新颖的方法,将开发过程分解为顺序的原子子任务,每个子任务涉及两个角色之间的协作互动和交叉检查。
这是一个高效的框架,能够促进代理之间的强大合作,从而提高目标软件的质量控制,确保代理能够成功构建所需的软件。
ChatDev 阶段的结构
每个阶段都以角色 专业化步骤开始,在该步骤中为该阶段招募合适的代理并赋予他们必须执行的角色。
链条中的每个聊天都由两个代理组成,它们分别承担以下角色之一:
-
指导者 代理:发起对话,并引导对话朝向任务完成。
-
助手 代理:遵循指导者代理给出的指令,并朝着完成任务的方向努力。
指导者和助手通过多轮对话进行合作,直到他们一致认为已经成功完成任务。
阶段 1 — 设计
该阶段涉及 CEO、CTO 和 CPO 代理。
在初始阶段,角色专业化是通过初始提示实现的。初始提示是一种来自CAMEL 论文的技术,旨在将原始陈述扩展为更具体的提示,并为指导者和助手代理提供明确的目标,以便它们共同完成任务。

截图 — 初始提示示例 — 来自于CAMEL 论文
类似于在生成代理中,ChatDev 中也使用了记忆流(Memory Stream)。记忆流包含了每个阶段和每个特定链条的对话历史。
与生成代理中的记忆流不同,ChatDev 的研究人员没有使用检索模块,也没有实现记忆反射。这可能是由于阶段和链条的顺序性质,使得从前一步骤流动的信息是可预测且容易访问的。
为了完成一个聊天,指导者和助手在多轮对话结束时通过相同的格式发出相同的消息,例如“
自我反思机制用于当两个代理已达成共识,但没有使用预期的字符串来结束他们的对话时。在这种情况下,系统会创建一个助手的伪自我,并与后者发起一个新的聊天(更多细节请参见上图)。

截图 — 设计阶段的步骤 — 来自于ChatDev 论文
在此聊天中,伪自我要求助手总结助手与指导者之间的对话历史,以便从对话中提取结论性信息。
阶段 2 — 编码
该阶段涉及 CTO、程序员和设计师代理。
编码阶段进一步分解为以下几个聊天:
生成完整代码: CTO 指示程序员根据设计阶段产生的规范编写代码。这些规范包括所选的编程语言(例如 Python),当然还包括要构建的应用程序类型。程序员认真地生成代码。
设计图形用户界面: 程序员指示设计师设计相关的用户界面。设计师则通过文本到图像的工具(如稳定扩散或 OpenAI 的 DALLe 等扩散模型)提出一个具有图标的友好图形用户界面供用户交互。然后,程序员将这些视觉资产集成到应用程序中。
ChatDev 使用面向对象编程语言(如 Python)生成代码,因为它具有强大的封装性和通过继承实现的重用性。此外,系统只向代理展示代码的最新版本,并从记忆流中删除之前的代码版本,以减少幻觉现象。

截图 — 思维指令 — 来自 ChatDev 论文
为了进一步抵抗幻觉,采用了思维指令。在思维指令中,代理之间的角色暂时交换。例如,CTO 和程序员交换角色片刻。这时,CTO 询问未实现的方法,使得程序员能够专注于代码库的特定部分。
本质上,通过思维指令,将一个大的任务(例如实现所有未实现的方法)拆分成更小的任务(例如,先实现方法 1,再实现方法 2,依此类推)。思维指令本身来源于链式思维提示。
第三阶段 — 测试
测试阶段涉及将所有组件集成到系统中,并使用来自解释器的反馈信息进行调试。此阶段涉及三个角色:程序员、审查者和测试者。
参与的聊天内容如下:
同行评审: 审查者代理检查源代码,识别潜在问题而不运行它(静态调试)。审查者代理尝试发现明显的错误、遗漏和可以改进的代码。
系统测试: 测试者代理通过程序员代理使用解释器(动态调试)进行的测试,验证软件的执行,重点评估通过黑盒测试应用程序的性能。
此处再次使用思维指令来调试程序的特定部分,测试者分析错误、提出修改建议,并根据这些建议指导程序员。
此外,ChatDev 允许人类客户以自然语言提供反馈和建议,这些反馈和建议会被纳入审查和测试过程。
第四阶段 — 文档编写
文档编写阶段包括生成软件系统的环境规范和用户手册。此阶段涉及四个角色:CEO、CPO、CTO 和程序员。
使用少量示例的提示,通过上下文示例,代理生成各种文档文件。
首席技术官指示程序员提供配置说明和依赖要求(例如 Python 的 requirements.txt),而首席执行官将需求和系统设计传达给首席产品官,后者生成用户手册。

截图 — 文档编写阶段步骤 — 来自ChatDev 论文
大型语言模型用于根据提供的提示和示例生成文档,生成一整套完整的文档文件,以支持软件系统的部署和使用。
评估与观察
在 70 个软件任务的评估中,ChatDev 展示了令人印象深刻的结果:
-
每个软件平均生成 17.04 个文件,包括代码文件、设计师创建的资产文件和文档文件。
-
生成的软件通常包含 39 到 359 行代码,平均为 131.61 行,这部分是由于通过面向对象编程实现代码复用。
-
审阅者和程序员之间的讨论导致了近 20 种代码漏洞的识别和修改,例如“模块未找到”,“属性错误”和“未知选项”错误。
-
测试人员和程序员之间的互动导致识别和解决了 10 多种潜在的错误,其中最常见的是由于令牌长度限制或外部依赖问题导致的执行失败。
-
使用 ChatDev 开发的软件平均成本为 0.2967 美元,显著低于传统定制软件开发公司的开支。
-
开发小型软件平均需要 409.84 秒,当然与人类软件公司开发类似应用所需的几周(或几个月)时间相比,表现得非常有利。

截图 — ChatDev 生成软件所需时间分析 — 来自ChatDev 论文
研究人员承认的局限性
尽管这些结果令人鼓舞,研究人员也承认了若干局限性。
即使使用低温(例如 0.2),研究人员仍然观察到生成的代码输出中存在随机性。这意味着同一个应用程序的代码在不同的运行之间可能会有所不同。因此,研究人员承认在此阶段,ChatDev 最好用于头脑风暴或创意工作。
有时,由于用户体验差或需求理解错误,软件无法满足用户需求。
此外,设计师代理缺乏视觉和风格一致性可能会显得突兀。这是因为生成与给定风格或品牌一致的视觉资产仍然困难(现在可以通过 LoRAs 来解决这个问题)。

截图 — ChatDev 生成的五子棋游戏示例 — 来自ChatDev 论文
研究人员还强调了当前 LLM(大语言模型)的偏见,这导致生成的代码看起来不像任何一个人类开发者可能写的代码。
最后,研究人员指出,利用他们的资源,很难对 ChatDev 生成的软件进行全面评估。对所生成应用程序的真正评估需要人类的参与,包括:
-
软件工程师
-
设计师/用户体验专家
-
测试人员
-
用户
个人批评
就个人而言,尽管这项工作是一个令人兴奋的进展,我也想表达我对其中一些内容的保留意见。
首先,现在大多数软件团队都采用敏捷开发方法,这使得在面对变化的用户需求时能提供更多的灵活性。瀑布式开发虽然仍用于某些项目,但如今已不再是常态。看看 ChatDev 如何迭代以适应更动态的软件开发生命周期将会非常有趣。
我建议我们用一个更直接、更精炼的提示替代起始提示,这个提示应直接来自用户。起始提示可能会编造需求,或者无法完全捕捉最终用户的意图。
当时使用的模型(gpt 3.5 turbo)仅有 16K tokens 的上下文窗口,这大大限制了使用 ChatDev 构建应用程序的范围和复杂度。
看起来 ChatDev 生成的代码并不是在沙盒内执行,而是直接在用户的机器上运行。这带来了许多安全风险,未来需要加以解决。

动画 — ChatDev 可视化工具 — 来自ChatDev 源代码
ChatDev 对我来说并不完全有效。当我尝试运行它生成一个棋盘游戏时,它确实生成了一些代码,但在运行时,我只看到一个空白的桌面应用程序。这可能是因为我使用的是 Python 3.12,而论文中使用的是 Python 3.8。
结语
ChatDev 代表了实现为软件开发构建具有代理能力的 AI 系统愿景的令人兴奋的一步。通过使用一个多阶段过程,结合具有记忆和反思能力的大型语言模型,ChatDev 展示了高效且具有成本效益的软件生成潜力。
虽然仍然存在一些挑战需要克服,比如解决基础语言模型的偏差问题以及确保系统的鲁棒性评估,但 ChatDev 范式代表了我们继续推动 AI 能力边界时,展现出来的令人兴奋的可能性。
如果你对 AI 代理感兴趣并希望进一步探索这个领域,我强烈建议阅读 ChatDev 论文。你可以在这里访问它。
此外,研究人员已开源了一个名为 SRDD(软件需求描述数据集)的多样化数据集,旨在促进基于自然语言的软件创建研究。你可以在这里找到该数据集。
至于我,我将继续探索 AI 代理,研究我自己的Python AI 代理库,阅读更多论文,并通过每日推文在 Twitter/X 上分享我的思考和发现。
我最近推出了 Kiseki Labs,一家通过工作坊、战略咨询和定制解决方案帮助企业实施生成式 AI 的咨询公司。如果你有兴趣合作,可以在 kisekilabs.com 预约免费咨询,或者通过 X 或 LinkedIn 与我联系。
论文解析:Attention Is All You Need
完整的从零开始实现 Transformer 的指南
·发表于Towards Data Science ·阅读时间 42 分钟·2024 年 11 月 3 日
--

图片由Samule Sun提供,来源于Unsplash
介绍
正如标题所示,本文将从零开始用 PyTorch 实现 Transformer 架构——没错,真的是从零开始。在开始之前,让我简要介绍一下该架构。Transformer 首次出现在 2017 年 Vaswani 等人撰写的论文《Attention Is All You Need》中[1]。该神经网络模型旨在执行seq2seq(序列到序列)任务,它接收一个序列作为输入,并期望返回另一个序列作为输出,应用场景包括机器翻译和问答系统。
在 Transformer 出现之前,我们通常使用基于 RNN 的模型,如 LSTM 或 GRU,来完成seq2seq任务。这些模型确实能够捕获上下文,但它们是按顺序处理的。这种方法使得捕捉长程依赖变得具有挑战性,尤其是当重要的上下文信息远在当前时间步之前时。相比之下,Transformer 能够自由地关注序列中任何它认为重要的部分,而不受顺序处理的限制。
Transformer 组件
论文解读:神经风格迁移
使用深度学习将您的照片变成画作——从零开始实现 NST,使用 PyTorch
·发表于Towards Data Science ·阅读时长:23 分钟·2024 年 12 月 3 日
--

图片由Birmingham Museums Trust提供,来源于Unsplash
介绍
最近,“生成式 AI”成为全球热议的话题,这得益于像 ChatGPT、Gemini、Claude 等公开发布的 AI 模型。正如我们所知,这些模型最初只能理解和生成文本,但不久之后,它们也获得了对图像执行相同操作的能力。更具体地说,关于图像数据的生成模型,实际上我们可以使用很多不同的模型变体,每个模型都有其特定的目的。到目前为止,我已经在 Medium 上发布了关于图像数据生成 AI 的一些文章,比如自动编码器和变分自动编码器(VAE),这些文章我会在本文末尾附上链接。在今天的文章中,我将介绍另一种令人着迷的生成算法:神经风格迁移(NST)。
NST 首次出现在 2015 年 Gatys 等人撰写的论文《A Neural Algorithm of Artistic Style》中[1]。论文中解释了他们的主要目标是将一幅图像(通常是一幅画)的艺术风格转移到另一幅图像上,这也是“风格迁移”名称的由来。请看…
论文解读:U-Net
一个关于最流行的语义分割模型之一的 PyTorch 实现。
·发表在Towards Data Science·阅读 17 分钟·2024 年 9 月 20 日
--

照片由Caleb Jones在Unsplash拍摄
U-Net 简介
当我们谈论图像分割时,我们不应忘记 U-Net,这是一个神经网络架构,最早由 Ronneberger 等人于 2015 年提出[1]。该模型最初旨在执行医学图像的分割任务。后来,其他研究人员发现该架构实际上也可用于一般的语义分割任务。此外,还可以利用该模型进行其他任务,如超分辨率(即将低分辨率图像放大为高分辨率图像)和扩散(即从噪声生成图像)。在本文中,我想向您展示如何使用 PyTorch 从头开始实现 U-Net。您可以在图 1 中看到整个 U-Net 架构。通过查看这个结构,我认为这个网络是如何得名的就很明显了。

图 1. U-Net 架构[1]。
架构中有几个关键组件。首先是收缩路径,也称为编码器。该组件负责逐渐缩小空间维度…
论文解析:Vision Transformer(ViT)
通过从零开始的 PyTorch 实现探索 Vision Transformer(ViT)。
·发表于 Towards Data Science ·阅读时长 19 分钟 ·2024 年 8 月 13 日
--

简介
Vision Transformer —— 通常缩写为 ViT —— 可以被视为计算机视觉领域的一个突破。在与视觉相关的任务中,通常使用基于 CNN 的模型,而这些模型迄今为止总是比任何其他类型的神经网络表现得更好。直到 2020 年,一篇名为“一张图片值 16×16 个单词:用于大规模图像识别的 Transformer”的论文由 Dosovitskiy 等人 [1] 发布,才提供了比 CNN 更强的能力。
在 CNN 中,单个卷积层通过使用卷积核提取特征。由于卷积核的大小相对于输入图像较小,因此它只能捕捉到该小区域内的信息。换句话说,我们可以简单地说它侧重于提取局部特征。为了理解图像的全局上下文,需要堆叠多个卷积层。ViT 通过直接从初始层捕获全局信息来解决这个问题。因此,在 ViT 中堆叠多个层可以实现更加全面的信息提取。
大语言模型时代评估的范式转变
大语言模型在评估方法上需要一些微妙、概念上简单但却非常重要的变化。
·发表于 Towards Data Science ·阅读时长 7 分钟·5 天前
--
在我的职业生涯中,我一直在为机器学习系统构建评估。在 Quora 担任数据科学主管时,我们为信息流排序、广告、内容审核等构建了评估系统。我的 Waymo 团队为自动驾驶汽车构建了评估系统。最近,在我们的金融科技创业公司 Coverbase 中,我们使用大语言模型来缓解第三方风险管理的困难。通过这些经验,我逐渐意识到,大语言模型在评估方法上需要一些微妙、概念上简单但却非常重要的变化。
本文的目标不是为您的大语言模型应用提供具体的评估技术,而是提出这三种范式转变:
-
评估是蛋糕,而不再是糖霜。
-
基准评估差异。
-
将人工分拣作为评估的重要组成部分。
我需要说明的是,我的讨论主要集中在大语言模型的应用上,而非基础模型的开发。此外,尽管标题如此,我在这里讨论的内容也适用于其他生成系统(受到我在自动驾驶领域经验的启发),不仅仅是大语言模型应用。
1. 评估是蛋糕,而不再是糖霜。
评估在机器学习(ML)发展中一直很重要,不论是否涉及大语言模型(LLM)。但我认为,评估在大语言模型的发展中尤为重要,原因有二:
a) 评估的重要性上升,因为在构建 LLM 应用时,调整的自由度较低,导致非评估工作的时间减少。在 LLM 开发中,基于基础模型(例如 OpenAI 的 GPT 或 Anthropic 的 Claude 模型)进行开发时,应用层可以调整的参数更少。而这些参数的调整速度也要快得多(警告:调整速度快,并不代表能更快达到正确的结果)。例如,修改提示语(prompt)显然比为一个梯度提升决策树编写一个新的手工特征要快速得多。因此,非评估工作的时间减少,导致评估所占时间比例增加。

作者提供的图片
b) 评估的绝对重要性上升,因为生成性人工智能的输出自由度更高,使得评估任务变得更加复杂。与分类或排序任务相比,生成性人工智能任务(例如:写一篇关于 X 的文章、制作 Y 的图像、为自动驾驶车辆生成轨迹)可以有无数种可接受的输出。因此,评估是将高维空间投射到低维空间的过程。例如,对于一个 LLM 任务,可以衡量:“输出文本是否真实?”,“输出是否包含有害内容?”,“语言是否简洁?”,“是否经常以‘当然!’开头?”等。如果二元分类任务中的精确度和召回率是对这些二元输出的无损测量(衡量你所看到的内容),那么我之前列出的 LLM 任务的示例指标就是对输出文本的有损测量(衡量你所看到内容的低维表示)。这要正确执行要困难得多。
这种范式转变对团队规模和招聘在大语言模型(LLM)应用项目中的实践意义重大。
2. 基准对比差异。
这是理想的场景:我们设定一个目标指标并不断在其上进行改进。

作者提供的图片
现实情况呢?
你几乎无法在图表中绘制超过 2 个连续的点!
这些可能对你来说很熟悉:
在第一次发布后,我们获得了一个更大的数据集,因此新的度量值与旧的度量值不再是苹果对苹果的比较。我们也无法在新数据集上重新运行旧模型——也许系统的其他部分已经升级,我们无法查看旧的提交以重现旧模型;也许评估指标是一个 LLM 作为评判者,而数据集非常庞大,因此每次评估运行的成本非常高,等等。
在第二次发布后,我们决定更改输出模式。例如,之前我们指示模型输出“是/否”答案;现在我们指示模型输出“是/否/也许/我不知道”。因此,之前精心策划的真实标签集不再有效。
在第三次发布后,我们决定将单一的 LLM 调用分解为两个调用的组合,并且需要评估这些子组件。我们需要为子组件评估准备新的数据集。
….
关键是,在 LLM 时代,开发周期往往太快,以至于无法对同一度量进行纵向跟踪。
那么解决方案是什么?
测量增量。
换句话说,接受图表上只有两个连续数据点的事实。关键是确保每个模型版本比前一个版本更好(在当时你所知道的情况下),即使很难知道其表现的绝对水平。
假设我有一个基于 LLM 的语言辅导器,它首先将输入分类为英语或西班牙语,然后提供语法提示。一个简单的度量可以是“英语/西班牙语”标签的准确率。现在,假设我对提示做了一些修改,想知道新提示是否提高了准确率。与其手动标记一个大型数据集并计算其准确率,另一种方法是只关注旧提示和新提示产生不同标签的数据点。我不能通过这种方式知道任何模型的绝对准确率,但我会知道哪个模型的准确率更高。

作者提供的图片
我需要澄清的是,我并不是说基准测试绝对值没有价值。我只是说我们应该意识到这样做的成本,而基准测试增量——尽管不能完全替代——可以是一种更加具有成本效益的方式,来得到一个方向性的结论。这种范式转变的一个更根本的原因是,如果你从零开始构建你的机器学习模型,你通常必须整理一个大的训练集,因此评估数据集通常是这个过程的副产品。而这在零-shot 和少-shot 学习中(例如,LLM)并非如此。
作为第二个例子,假设我有一个基于 LLM 的度量:我们使用一个独立的 LLM 来判断我在 LLM 语言辅导器中生成的解释是否足够清晰。有人可能会问:“既然评估现在是自动化的,基准测试增量仍然比基准测试绝对值便宜吗?”是的。因为度量现在更复杂了,你可以不断改进度量本身(例如,进行基于 LLM 的度量的提示工程)。首先,我们仍然需要评估评估;基准测试增量可以告诉你新版本的度量是否更好。其次,随着基于 LLM 的度量的发展,如果我们只专注于比较 LLM 语言辅导器模型的两个相邻版本,就不必费心去填补所有旧版本的基准结果,用新的基于 LLM 的度量版本来重新评估。
基准测试增量可以成为一种有效的内部循环、快速迭代机制,同时节省外部循环、低频率迭代中基准测试绝对值或纵向跟踪的高成本。
3. 将人工筛选作为评估的一个组成部分。
如上所述,精心筛选一个黄金数据集,一劳永逸地将其作为永恒基准使用的梦想可能无法实现。筛选将成为开发过程中的一个不可或缺的、持续的部分,无论是直接筛选 LLM 输出,还是筛选 LLM 作为判断者或其他更复杂的指标。我们应继续使评估尽可能可扩展;关键是,尽管如此,我们不应指望能够消除人工筛选。我们越早接受这一点,就能越早在工具投资上做出正确决策。
因此,无论我们使用何种评估工具,无论是内部工具还是外部工具,都应该有一个便捷的人工筛选界面。一个简单的界面可以像下面这样。结合前面提到的差异基准,界面可以呈现并排面板,用户可以轻松浏览结果。它还应允许你轻松记录筛选的笔记,以便将其作为黄金标签用于未来的基准测试(从而减少未来的筛选负担)。

图像来自作者
更高级的版本理想情况下应该是盲测,即筛选者不知道哪个版本是哪个。我们通过数据反复验证了,当不进行盲测时,开发人员即便是出于最佳意图,也会有潜在的偏见,偏向自己开发的版本。
一旦发现这三种范式转变,适应起来其实相当直接。挑战不在于解决方案的复杂性,而在于在激动人心的快速开发节奏中,提前识别这些转变。我希望分享这些思考能帮助其他在自己工作中面临类似挑战的人。
Parquet 文件格式:你需要了解的一切
新的数据格式需要新的存储方式。了解你需要知道的关于 Parquet 文件格式的所有内容
·发表于 Towards Data Science ·8 分钟阅读·2024 年 7 月 18 日
--

作者提供的图片
随着近年来数据量呈指数增长,最大的挑战之一是找到最优的方式存储各种数据格式。与(不久前的)过去不同,当时关系型数据库被视为唯一的选择,现在的组织希望能够对原始数据进行分析——比如社交媒体情感分析、音频/视频文件等——这些数据通常无法以传统(关系型)方式存储,或者以传统方式存储需要大量的努力和时间,这样会增加整体的分析时间。
另一个挑战是如何坚持使用传统的方法以结构化的方式存储数据,但又不需要设计复杂且耗时的 ETL 工作负载,将这些数据迁移到企业数据仓库中。此外,如果你组织中的一半数据专业人员精通 Python(如数据科学家、数据工程师),另一半(如数据工程师、数据分析师)精通 SQL,那么你会坚持让“Python 用户”学习 SQL 吗?还是反过来呢?
使用 LayoutLM 和 Label Studio 解析您的发票
使用 Transformers 库、Label Studio 和 AWS S3 对您的发票进行 LayoutLM 的微调。
·发表于Towards Data Science ·阅读时长 34 分钟·2024 年 4 月 16 日
--
从发票中提取信息长期以来一直是公司、机构和会计人员重复且繁琐的任务。
这个任务可以自动化吗?答案是肯定的。
这就是机器学习的承诺:处理成千上万的文档并提取所有相关信息。
许多公司,如Rossum、Digitoo或Docsumo,都是基于这一简单想法创建的,并且累计筹集了数亿美元,证明了这种技术的需求。
您也可以创建自己的解决方案。
在本文中,我将引导您通过构建一个微调您公司文档的发票解析器的过程。
我们介绍LayoutLM,这是由微软开发的一个著名模型,用于从文档中提取信息。为了定制适合我们特定需求的解决方案,我们使用Label Studio进行文档标注,这是一款开源标注工具,并将其与我们的远程存储AWS S3连接。
让我们开始吧!

使用 Label Studio 进行发票标注,以便进行 LayoutLM 训练(图片来自作者)
LayoutLM:文档图像理解的布局
粒子群优化
优化任意函数的最迷人的方式
·发布于数据科学探索 ·阅读时间 7 分钟·2024 年 1 月 10 日
--

图片由James Wainscoat提供,来自Unsplash
无论我们处理的是机器学习、运筹学还是其他数值领域,我们都有一个共同的任务,那就是优化函数。根据不同的领域,出现了一些常用的方法:
-
在机器学习中,当训练神经网络时,我们通常使用梯度下降。这之所以有效,是因为我们处理的函数是可微的(至少在几乎所有地方——见 ReLU)。
-
在运筹学中,我们经常处理可以通过线性(或凸)规划解决的线性(或凸)优化问题。
如果我们能够应用这些方法,那总是非常棒的。然而,对于优化一般函数——所谓的黑箱优化——我们必须借助其他技术。其中一个特别有趣的技术是所谓的粒子群优化,在本文中,我将向你展示它是如何工作的以及如何实现它。
请注意,这些算法并不总是能给出最佳解,因为它是一种高度随机和启发式的算法。尽管如此,它仍然是你工具箱中一个很好的技术,当你遇到难以优化的函数时,应该尝试一下!
粒子群优化
在 Python Pytest 中将函数传递到测试文件
PYTHON 编程
这是一个非常常见的问题,但解决方法非常简单:使用 fixture。
·发布于Towards Data Science ·7 分钟阅读·2024 年 5 月 27 日
--

Fixtures 有助于在测试函数中重用对象,包括函数。图片由rivage提供,Unsplash
当你使用 Pytest 进行 Python 单元测试时,可以通过在conftest.py文件中定义的 Pytest fixtures 将对象传递给测试文件。我用了复数形式,因为你可以定义任意数量的conftest.py文件,而它们的位置决定了作用域。
当我刚接触 Python 单元测试时,我常常想,我该如何定义一个测试辅助函数,并在特定的测试函数中使用它。如果这个函数只在一个测试文件中需要,没问题:只需在这个文件中定义函数,然后就可以使用了。
如果你需要在多个测试文件中使用同一个函数怎么办?一种解决方案是在所有这些文件中重新定义该函数——但这显然违反了DRY 原则(Don’t Repeat Yourself)。那么,如何实现呢?
我注意到很多人都在问这个问题——所以写了这篇文章。你会发现,只要你了解pytest fixtures 的概念,解决方案其实非常简单,甚至是相当自然的。
Fixtures
本文并不打算介绍 Pytest fixtures。因为我计划写一篇...
Python 中的路径表示
停止使用字符串表示路径,改用 pathlib
·发表于Towards Data Science ·阅读时间:5 分钟·2024 年 8 月 20 日
--

由Pawel Czerwinski拍摄,图片来源于Unsplash
与文件系统打交道是看似微不足道的任务,但即使是经验丰富的开发者,也可能因此出其不意。我是第一个承认——我也犯过不少错误。我遇到的最常见的反模式之一就是在 Python 中将文件路径表示为字符串。
是时候重新考虑这种方法了。
在今天的文章中,我们将探讨为什么使用字符串(甚至是os模块)来表示文件路径是一场灾难的前奏。我们将深入了解最佳实践,并看到pathlib包如何帮助你编写更清晰、更易维护的代码。
为什么使用字符串表示路径是个糟糕的主意
如果你曾经参与过需要在不同操作系统上运行的项目,那么你一定知道处理路径的痛苦。不同的系统有不同的路径表示方式。基于 Unix 的系统(如 Linux 和 macOS)使用正斜杠/,而 Windows 使用反斜杠\。这是一个小细节,但如果不小心,它会引发巨大的麻烦。
# Unix (e.g. Linux, OSX, etc.)
/home/this/is/a/path/to/a/directory
# Windows
C:\home\this\is\a\path\to\a\directory
爪子、利爪和代码:6 个必知的 Python 示例
与 Python 一起庆祝国际猫咪日
·发表于 Towards Data Science ·阅读时长 11 分钟·2024 年 7 月 31 日
--

图片由 Seidenperle 提供,来自 Pixabay
国际猫咪日(8 月 8 日)是一个完美的时机,能够将我们对猫咪的喜爱与数据科学和 Python 编程的力量结合起来。如果你是猫咪主人——或者对猫咪有任何了解——我相信你会知道它们以一些古怪的行为著称,比如似乎攻击任何移动的物体(或者至少会盯着它们看,如果它们像我其中一只猫一样懒惰……)、狂吃食物以及长时间睡觉。
在本文中,我们将通过 6 个简单的 Python 代码示例来探索这些迷人的特征,适合初学者,从简单的示例到更高级的概念,如面向对象编程、GUI 开发和单元测试,这些都是经验丰富的软件开发人员在实践中经常使用的。
请放心,Python 不会吃掉你的猫咪
这些示例旨在简洁且易于初学者理解,但它们也构成了现实世界应用中常用的重要概念的基础:
-
攻击移动物体(带时间暂停的基础函数)
-
吃饭习惯(具有可选输入的函数)
-
睡眠模式(面向对象编程和类型提示)
在 Python 中使用 PCA 和 K-Means 进行交通数据分析
基于每小时交通数据,减少维度并对台北捷运车站进行聚类
·发表于 Towards Data Science ·阅读时间 8 分钟·2024 年 5 月 7 日
--

台北铁路地图(实际上基于罗马化标准引入)包括高铁、台铁、台北捷运及其他线路。图片来源:Taiwan J。
主成分分析(PCA)已经在交通数据中用于检测异常,但它也可以用于捕捉交通站点历史数据的模式,类似于它在顾客购买数据上的应用。
在本文中,我们将讨论:
-
PCA 做了什么技巧
-
应用 PCA 后我们可以做什么
-
玩得开心!
看看我们的数据集:
完整代码也包含在上述的 Kaggle 数据集中。
-
在每小时交通数据上使用 PCA
-
对 PCA 结果进行聚类
-
关于台北捷运交通的洞察
-
主要收获
1. PCA 做了什么技巧
简而言之,PCA 通过找到特征的线性组合来总结数据,这可以理解为拍摄一个三维物体的几张照片,并且它会在交给你之前自然地按最具代表性的照片到最不具代表性的照片进行排序。
输入是我们的原始数据,PCA 会输出 2 个有用的结果:Z 和 W。通过将它们相乘,我们可以得到重构数据,即原始数据,但有一些可以容忍的信息丢失(因为我们已经减少了维度)。
我们将在下面的实践中解释这 2 个输出矩阵与我们的数据的关系。
2. 应用 PCA 后我们可以做什么
在将 PCA 应用于我们的数据以降维后,我们可以将其用于其他机器学习任务,如聚类、分类和回归。
在本文稍后的台北捷运案例中,我们将在较低维度的数据上进行聚类,其中几个维度可以解释为一天中不同时间段的乘客比例,例如早晨、午间和傍晚。那些白天具有相似乘客比例的车站会被认为在同一簇中(它们的模式相似!)。
3. 看看我们的交通数据集!
我们这里使用的数据是台北捷运系统每小时交通数据,包含列:date, hour, origin, destination, passenger_count。
在我们的案例中,我只保留工作日的数据,因为在工作日不同车站之间有更多有趣的模式。例如,位于住宅区的车站白天可能有更多的通勤乘客进入,而在晚上,位于商业区的车站可能有更多人进站。

位于住宅区的车站白天可能有更多的通勤乘客进入。
上面的图是 4 个不同车站的每小时交通趋势(乘客进入车站的数量)。红色的两条线是新埔和永安市场,这两个车站实际上位于新北市的超拥挤区域。另一方面,蓝色的两条线是台北市政府和忠孝复兴站,这里是大多数公司所在地,也是商业活动发生的地方。
这些趋势反映了这些区域和车站的特点,我们可以注意到,比较它们在通勤时间(早上 7 点到 9 点,以及下午 5 点到 7 点)期间的趋势时,差异最为明显。
4. 使用 PCA 处理每小时交通数据
为什么在进行进一步的机器学习任务之前要进行降维?
有两个主要原因:
-
随着维度的增加,所有数据点在许多方面看起来都变得稀疏且不相似,这就是所谓的“维度灾难”。
-
由于交通数据的高维特性,它难以可视化和解释。
通过应用 PCA,我们可以识别出不同车站的交通趋势最为明显和具有代表性的时间段。直观地,通过之前显示的图,我们可以假设早上 8 点和下午 6 点左右的时间段可能足够具有代表性,能够对车站进行聚类。
记得我们在上一节提到过 PCA 的有用输出矩阵 Z 和 W 吗?在这里,我们将用 MRT 案例来解释它们。
原始数据,X

-
索引:车站
-
列:小时
-
数值:特定小时进入的乘客比例(对于每个车站:#乘客 / #总乘客数)
通过这样的 X,我们可以通过以下代码应用 PCA:
from sklearn.decomposition import PCA
n_components = 3
pca = PCA(n_components=n_components)
X_tran = StandardScaler().fit_transform(X)
pca.fit(X_tran)
在这里,我们指定参数n_components为 3,这意味着 PCA 将为我们提取最重要的三个主成分。
请注意,这就像是“拍摄多个三维物体的照片,并按最具代表性的顺序排序,我们选择前三张照片”,因此如果我们将n_components设为 5,我们将得到两张额外的照片,但我们的前三张照片仍然保持不变!
PCA 输出,W 矩阵
W可以被视为每个特征(即小时数)相对于我们的“图片”的权重,或者更具体地说,主成分。
pd.set_option('precision', 2)
W = pca.components_
W_df = pd.DataFrame(W, columns=hour_mapper.keys(), index=[f'PC_{i}' for i in range(1, n_components+1)])
W_df.round(2).style.background_gradient(cmap='Blues')

对于我们的三个主成分,我们可以看到 PC_1 在夜间时段的权重大,而 PC_2 在中午时段的权重大,PC_3 则与早晨时间相关。
PCA 输出,Z 矩阵
我们可以将Z矩阵解读为车站的表现。
Z = pca.fit_transform(X)
# Name the PCs according to the insights on W matrix
Z_df = pd.DataFrame(Z, index=origin_mapper.keys(), columns=['Night', 'Noon', 'Morning'])
# Look at the stations we demonstrated earlier
Z_df = Z_df.loc[['Zhongxiao_Fuxing', 'Taipei_City_Hall', 'Xinpu', 'Yongan_Market'], :]
Z_df.style.background_gradient(cmap='Blues', axis=1)

在我们的案例中,既然我们已经解读了 W 矩阵并理解了每个成分的潜在含义,我们可以为这些主成分命名。
这四个车站的 Z 矩阵表明,前两个车站的夜间时段占比更大,而另外两个车站则更多是在早晨时段。这一分布也支持我们在 EDA 中的发现(回想一下早期部分这四个车站的折线图)。
5. 使用 K 均值聚类对 PCA 结果进行聚类
在得到 PCA 结果后,我们将进一步根据车站的交通模式(由 3 个主成分表示)对车站进行聚类。
在最后一部分中,Z 矩阵表示了各个车站在夜间、中午和早晨的表现。
我们将根据这些表示对车站进行聚类,使得同一组中的车站在这三个时间段的客流分布相似。
聚类方法有很多种,比如 K 均值、DBSCAN、层次聚类等等。由于这里的主要话题是展示 PCA 的便捷性,我们将跳过实验哪些方法更合适的过程,直接使用K 均值。
from sklearn.cluster import KMeans
# Fit Z matrix to K-Means model
kmeans = KMeans(n_clusters=3)
kmeans.fit(Z)
在拟合 K 均值模型后,让我们通过plotly绘制三维散点图来可视化这些聚类结果。
import plotly.express as px
cluster_df = pd.DataFrame(Z, columns=['PC1', 'PC2', 'PC3']).reset_index()
# Turn the labels from integers to strings,
# such that it can be treated as discrete numbers in the plot.
cluster_df['label'] = kmeans.labels_
cluster_df['label'] = cluster_df['label'].astype(str)
fig = px.scatter_3d(cluster_df, x='PC1', y='PC2', z='PC3',
color='label',
hover_data={"origin": (pca_df['index'])},
labels={
"PC1": "Night",
"PC2": "Noon",
"PC3": "Morning",
},
opacity=0.7,
size_max=1,
width = 800, height = 500
).update_layout(margin=dict(l=0, r=0, b=0, t=0)
).update_traces(marker_size = 5)

6. 台北捷运交通洞察——聚类结果

-
Cluster 0:白天有更多的乘客,因此它可能是“居住区”组。
-
Cluster 2:傍晚时段有更多的乘客,因此它可能是“商业区”组。
-
Cluster 1:白天和夜间都有大量的乘客进入车站,解释这些车站的特性更加复杂,因为不同车站可能有不同的原因。接下来,我们将深入分析该聚类中的两个极端案例。

例如,在集群 1中,拥有最多乘客的台北车站——台北车站,是台北的一个重要交通枢纽,通勤者可以在这里从公交和铁路系统换乘到地铁。因此,早晚高峰时段的高流量模式是显而易见的。
相反,台北动物园站也位于集群 1 中,但“白天和晚上的人流都很大”并不适用。实际上,在这两个时段,那里的人不多,因为周围没有多少居民,而且大多数市民平日很少去台北动物园。
这两座车站的模式并不相似,尽管它们在同一集群中。这意味着集群 1 可能包含了许多实际上并不相似的车站。因此,未来我们需要对 K-Means 的超参数进行微调,例如集群的数量,像轮廓系数和肘部法则这样的方式会很有帮助。
结论
总结来说,
-
在交通数据上应用 PCA 来降低维度可以通过提取 3 个重要时段(早晨、中午、晚间)来实现,这些时段来自 21 个工作小时。
-
PCA 的输出是 W 和 Z 矩阵,其中 Z 可以视为车站在主成分(时间段)上的表示,而 W 可以视为主成分(时间段)在原始特征(小时)上的表示。
-
考虑到 W 矩阵有助于我们理解每个主成分的潜在含义。
-
聚类方法可以应用于 PCA 输出的 Z 矩阵。
请注意,为了专注于本文的主题,我们在这里跳过了 EDA 和超参数调优部分,但它们实际上非常重要。
感谢你读到这里!
希望你在台北的线上之旅愉快 🫶
进一步阅读
-
KMeans 超参数解释与示例,Sujeewa Kumaratunga 博士
-
如何在 Python 中结合 PCA 与 K-means 聚类?,Elitsa Kaloyanova
参考文献
-
DSCI 563 讲义,UBC 数据科学硕士课程,Varada Kolhatkar
-
高维数据上的 K 均值聚类,shivangi singh
-
维度灾难 — 机器学习的“诅咒”,Shashmi Karanam
除非另有说明,所有图片均来自作者。
Pearson 与 Spearman 相关性:在变量之间找到和谐
你应该为你的任务使用哪种相关性度量?了解你需要知道的所有关于 Pearson 和 Spearman 相关性的知识
·发表于Towards Data Science ·阅读时间:7 分钟·2024 年 1 月 18 日
--
想象一个交响乐团在演出前调音。每个音乐家都调整他们的音符,以与其他人和谐地融合,确保无缝的音乐体验。在数据科学中,数据集中的变量可以与乐团的音乐家相比较:理解它们之间的和谐或不和谐至关重要。

相关性是一个统计度量,像是乐团的指挥,指导我们理解数据中复杂的关系。在这里,我们将重点讨论两种相关性:Pearson和Spearman。
如果我们的数据是一个交响乐,Pearson 和 Spearman 就是我们的乐团指挥:他们有着独特的风格来诠释这场交响乐,各自有着不同的优势和细微之处。理解这两种不同的方法将帮助你提取洞察并理解变量之间的联系。
Pearson 相关性
Pearson 相关系数,用r表示,量化了两个变量之间线性关系的强度和方向…
受感知启发的图卷积用于音乐理解任务
这篇文章讨论了MusGConv,一种受感知启发的图卷积模块,适用于符号音乐应用。
·发表于Towards Data Science ·阅读时间 10 分钟·2024 年 7 月 9 日
--

引言
在音乐信息研究(MIR)领域,理解和处理乐谱的挑战一直在不断引入新的方法和途径。最近,许多基于图的技术被提出,作为解决音乐理解任务的方法,如声音分离、拍子检测、作曲家分类和罗马数字分析。
本文讨论了我最近的一篇论文,其中我介绍了一种新的图卷积模块,名为MusGConv,专门用于处理乐谱数据。MusGConv利用音乐感知原理,提升了图卷积在应用于音乐理解任务中的效率和性能。
理解问题
传统的 MIR 方法通常依赖于音乐的音频或符号表示。音频能够捕捉声音波的强度随时间变化,而符号表示如 MIDI 文件或乐谱则编码了离散的音乐事件。符号表示特别有价值,因为它们提供了更高层次的信息,这对于音乐分析和生成等任务至关重要。
然而,现有的基于符号音乐表示的技术通常借鉴计算机视觉(CV)或自然语言处理(NLP)方法。例如,将音乐表示为矩阵格式的“钢琴卷轴”并将其类似于图像,或将音乐表示为一系列符号并通过序列模型或变换器处理。这些方法虽然有效,但可能无法完全捕捉到音乐的复杂多维特性,包括音符的层次关系和复杂的音高-时间关系。一些最新的方法已经提出将音乐乐谱建模为图,并应用图神经网络来解决各种任务。
将音乐乐谱作为图
基于图神经网络(GNN)的方法在音乐乐谱中的基本思想是将音乐乐谱建模为一个图,其中音符是顶点,边是基于音符之间的时间关系构建的。为了从音乐乐谱创建图,我们可以考虑四种类型的边(请参见下图以查看乐谱上图的可视化):
-
起始边:连接具有相同起始时间的音符;
-
连续边(或下一个边):如果音符 x 的偏移量与音符 y 的起始时间对应,则连接音符 x 和音符 y;
-
期间边:如果音符 y 的起始时间落在音符 x 的起始时间和结束时间之间,则连接音符 x 和音符 y;
-
休止边(或静音边):连接休止符前的最后一个音符和其后的第一个音符。

GNN 可以处理从音符和这四种类型关系中创建的图。
介绍 MusGConv
MusGConv 旨在利用音乐乐谱图并通过将音乐感知原理融入图卷积过程来增强这些图。它专注于音乐的两个基本维度:音高和节奏,考虑它们的相对表示和绝对表示。

绝对表示指的是可以归属于每个音符的特征,例如音符的音高或拼写、持续时间或任何其他特征。另一方面,相对特征是通过音符对之间计算的,例如两个音符之间的音程、它们的起始时间差,即它们发生的时间等。
MusGConv 的主要特征
-
边特征计算:MusGConv 基于音符之间的起始时间、持续时间和音高计算边特征。可以对边特征进行归一化,以确保它们在神经网络计算中更加有效。
-
相对和绝对表示:通过同时考虑相对特征(作为边特征的音高之间的距离)和绝对值(作为节点特征的实际音高和时间),MusGConv 可以根据实际情况调整和使用更相关的表示。
-
与图神经网络的集成:MusGConv 模块可以轻松与现有的 GNN 架构集成,几乎不增加额外的计算成本,并且可以用于改进音乐理解任务,例如声部分离、和声分析、节奏检测或作曲家识别。
相对表示和绝对表示的重要性与共存,可以从音乐中的移调角度理解。想象一下相同的音乐内容被移调。那么,音符之间的音程关系保持不变,但每个音符的音高发生了变化。

相同内容通过大三度移调。顶部和底部音符之间的关系相同,但绝对音高发生了变化。
理解图神经网络(GNNs)中的消息传递
为了充分理解 MusGConv 卷积模块的内部工作原理,首先需要解释消息传递的原理。
什么是消息传递?
在 GNN 的上下文中,消息传递是一个过程,其中图中的顶点与它们的邻居交换信息,以更新自身的表示。这种交换使每个节点能够从图中收集上下文信息,然后用于预测任务。
消息传递过程通过以下步骤定义:
-
初始化:每个节点被分配一个特征向量,其中可以包含一些重要的属性。例如,在乐谱中,这可能包括每个节点/音符的音高、时值和起始时间。
-
消息生成:每个节点生成一条消息发送给它的邻居。消息通常包含节点的当前特征向量以及描述节点之间关系的任何边特征。消息可以是邻居节点特征的线性变换。
-
消息聚合:每个节点从其邻居处收集消息。聚合函数通常是一个置换不变的函数,例如求和、平均值或最大值,它将这些消息合并成一个单一的向量,确保节点能够捕获来自其整个邻域的信息。
-
节点更新:聚合后的消息用于更新节点的特征向量。此更新通常涉及应用神经网络层(如全连接层),然后是非线性激活函数(如 ReLU)。
-
迭代:步骤 2 至步骤 4 会根据指定的迭代次数或层数重复执行,从而使信息能够在图中传播。每次迭代时,节点会将来自越来越大邻域的信息整合进来。
MusGConv 中的消息传递
MusGConv 通过将绝对特征作为节点特征以及相对音乐特征作为边特征来改变标准的消息传递过程。这个设计是为了适应音乐数据的特点。
MusGConv 卷积通过以下步骤定义:
-
边缘特征计算:在 MusGConv 中,边缘特征通过音符的起始时间、持续时间和音高的差异来计算。此外,还包括音高类间隔(不考虑八度的音符间距离),提供了一种简化但有效的方法来量化音乐间隔。
-
消息计算:MusGConv 中的消息不仅包括源节点的当前特征向量,还包括从源节点到目标节点的上述边缘特征,使得网络在消息传递过程中能够利用邻居的绝对和相对信息。
-
聚合与更新:MusGConv 使用求和作为聚合函数,然而,它将当前节点的表示与其邻居信息的总和进行连接。

MusGConv 图卷积模块。
通过这样设计消息传递机制,MusGConv 试图保持音乐的相对感知特性(如音程和节奏),从而产生更有意义的音乐数据表示。
如果缺少边缘特征或故意不提供,则 MusGConv 计算两个节点之间的边缘特征为它们节点特征的绝对差异。带有边缘特征的 MusGConv 版本在实验中被称为 MusGConv(+EF)。
应用与实验
为了展示 MusGConv 的潜力,我将在下面讨论论文中进行的任务和实验。所有模型无论任务如何,都按下图所示的管道设计。当使用 MusGConv 时,GNN 模块被 MusGConv 模块替代。
我决定将 MusGConv 应用于四个任务:声音分离、作曲家分类、罗马数字分析和和弦进行检测。这些任务从图学习的角度呈现了不同的分类。声音分离是一个链路预测任务,作曲家分类是一个全局分类任务,和弦进行检测是一个节点分类任务,而罗马数字分析可以看作是一个子图分类任务。因此,我们不仅从音乐分析的角度,而且从整个图深度学习任务分类的范围来探索 MusGConv 的适用性。

一般图形管道在符号音乐理解任务中的示例
声音分离
声音分离是从多声部音乐片段中检测出各个单声部流的方法。以往的方法采用了 GNN 来解决这个任务。从 GNN 的角度来看,声音分离可以视为一个链路预测任务,即对于每一对音符,我们预测它们是否被一条边连接。链路预测的结果应当是一个图,其中在同一声部中的连续音符应该是连接在一起的。然后,声部就是预测图的连通分量。关于使用 GNN 进行声音分离的更多信息,请参考这篇论文。
对于声音分离,上述图中的流程适用于架构中的 GNN 编码器部分。链路预测部分则作为任务特定模块进行处理。使用 MusGConv 时,只需将 GNN 编码器中的卷积块替换为 MusGConv。这个简单的替换使得预测更加准确,错误更少。
由于深度学习系统的解释并非易事,因此很难准确指出性能提升的原因。从音乐的角度来看,同一声部中的连续音符通常具有较小的相对音高差异。MusGConv 的设计确实通过相对边特征突出了音高差异。然而,我还需要补充说,从个别观察来看,音乐并不严格遵循任何规则。
作曲家分类
作曲家分类是根据某些音乐片段识别作曲家的过程。以前基于 GNN 的方法处理此任务时,类似于上面展示的流程,它们接收一个分数图作为输入,然后包含一些全局池化层,将音乐片段的图转化为一个向量。然后,从该向量进行分类处理,类别即为预定义的作曲家。
再次强调,MusGConv 通过替换 GNN 卷积块,易于实现。在实验中,使用 MusGConv 确实在解决这个任务中非常有益。我的直觉是,相对特征与绝对特征相结合,能为作曲风格的组成提供更好的洞察。
罗马数字分析
罗马数字分析是一种和声分析方法,其中和弦用罗马数字表示。预测罗马数字的任务相当复杂。先前的架构使用了 GNN 和顺序模型的混合。此外,罗马数字分析是一个多任务分类问题,通常一个罗马数字会被分解成更简单的单独任务,以减少独特罗马数字的类别词汇量。最后,罗马数字分析的图形化架构还包括在图卷积之后的一个起始收缩层,该层将图转换为有序序列。这个起始收缩层收缩同时发生的音符组,并在分类时将它们分配到相同的标签。因此,这可以视为一个子图分类任务。我认为这个模型的解释值得单独写一篇文章,因此,我建议阅读这篇论文以获取更多见解。
然而,图中的一般图形管道仍然适用。顺序模型与多任务分类过程以及起始收缩模块完全属于任务特定的部分。然而,用 MusGConv 模块替换图卷积块似乎对这个任务和架构没有影响。我将此归因于任务和模型架构本身过于复杂。
终止式检测
最后,让我们讨论终止式检测。检测终止式可以视为与检测乐句结尾相似,它是音乐分析中的一个重要方面。之前的终止式检测方法采用了带有编码器-解码器 GNN 架构的 GNN。每个音符,直到现在我们知道它也对应图中的一个节点,被分类为终止式音符或非终止式音符。终止式检测任务包括许多特殊情况,如非常严重的类别不平衡以及注释歧义。如果你感兴趣,我再次建议查阅这篇论文。
在编码器中使用 MusGConv 卷积有助于检测终止式。我认为,相对和绝对特征的结合以及 MusGConv 的设计能够追踪在终止式附近经常发生的声部连接模式。
结果与评估
广泛的实验表明,MusGConv 在上述音乐理解任务中能够超越最先进的模型。下表总结了这些改进:

(F1)表示宏 F1 得分,其他情况下显示的是简单的准确率得分。
然而,不管表格多么没有生气,我更倾向于不深入探讨更多细节,以保持这篇博客的生动性和讨论的方向。因此,我邀请你查看原始论文,以获取有关结果和数据集的更多细节。
总结与讨论
MusGConv 是一个用于音乐的图卷积模块。它提供了一种简单的、受感知启发的图卷积方法,当应用于音乐理解任务时,能够提高 GNN(图神经网络)的表现。它的简洁性是其有效性的关键。在某些任务中,它非常有益,而在其他任务中则不那么明显。音乐中相对和绝对特征的归纳偏向是一个巧妙的技巧,可以神奇地提升你的 GNN 结果,但我的建议是始终保持一些保留意见。尽管可以尝试 MusGConv,但也不要忘记探索其他所有有趣的图卷积模块可能性。
如果你有兴趣尝试 MusGConv,相关代码和模型可以在 GitHub 上找到。
备注与致谢
本文中的所有图片均由作者提供。我想感谢我的共同作者 Francesco Foscarin,他为这项工作的贡献。
使用特征子集更有效地执行异常值检测
确定相关子空间:特征的子集,它们可以让你在表格数据上更有效地进行异常值检测
·发表于Towards Data Science ·阅读时长 28 分钟·2024 年 11 月 24 日
--
本文是系列文章的一部分,涉及在数据中识别异常值的挑战和可以使用的技术,包括使用 PCA、距离度量学习、共享最近邻、频繁模式异常因子、计数异常值检测器(一种基于多维直方图的方法),以及doping技术。本文还包含了我书中的一段摘录,Python 中的异常值检测。
在这里,我们探讨了一种技术,旨在创建一系列较小的异常值检测器,而不是单一的异常值检测器来检查数据集中的所有特征,每个检测器都只处理特征的子集(称为子空间)。
异常值检测中的挑战
在对表格数据进行异常值检测时,我们关注的是数据中最不寻常的记录——这些记录要么与同一数据集中的其他记录相比最为特殊,要么与之前的数据相比最为特殊。
寻找最有意义的异常值存在许多挑战,尤其是没有明确的统计异常定义来确定哪些数据中的异常值应被视为最强的异常值。此外,最相关的异常值(而不一定是最统计异常的)将取决于你的项目,并可能随着时间的推移而变化。
异常值检测中还存在许多技术挑战,其中之一是数据中存在许多特征时所出现的困难。如在之前与 Counts Outlier Detector 和 Shared Nearest Neighbors 相关的文章中所述,当特征很多时,我们通常会面临一个称为维度灾难的问题。
这对异常值检测有多个影响,包括使得距离度量变得不可靠。许多异常值检测算法依赖于计算记录之间的距离——为了识别作为异常值的记录,这些记录与其他记录相比非常相似,且与大多数其他记录有显著不同——也就是说,记录与其他记录的距离是近少数记录,远多数记录。
例如,如果我们有一个包含 40 个特征的表格,数据中的每条记录可以被视为 40 维空间中的一个点,并且其异常性可以通过该点与该空间中其他点之间的距离来评估。因此,这就需要一种衡量记录之间距离的方法。使用了多种度量方法,其中欧几里得距离相当常见(假设数据是数值型的,或已转换为数值)。因此,每条记录的异常性通常是根据它与数据集中其他记录之间的欧几里得距离来衡量的。
然而,当我们处理许多特征时,这些距离计算可能会出现问题,实际上,即使只有十个或二十个特征,距离度量的问题也可能出现,尤其是当特征数量达到三十、四十个甚至更多时。
需要注意的是,处理大量特征的问题并非所有异常值检测器都会遇到。例如,在使用单变量测试(例如 z-score 或四分位数范围测试,这些测试逐个考虑每个特征,独立于其他特征——在A Simple Example Using PCA for Outlier Detection中有更详细的描述)时,它们通常不会显得很重要,或者在使用像FPOF这样的类别型异常值检测器时。
然而,常用的大多数异常值检测器都是数值型多变量异常值检测器——这些检测器假设所有特征都是数值型的,并且通常会同时处理所有特征。例如,LOF(局部异常因子)和 KNN(k-近邻)是最广泛使用的两种检测器,它们都基于记录与其他记录在高维空间中的距离来评估每个记录的异常性。
基于记录与其他数据点的距离判断异常值的示例
请看下面的图。这展示了一个包含六个特征的数据集,并通过三个二维散点图表示。图中包括了两个可以合理认为是异常值的点,P1 和 P2。
现在来看 P1,它至少在特征 A 上与其他点相距较远。也就是说,仅考虑特征 A 时,P1 很容易被标记为异常值。然而,大多数检测器会考虑每个点到其他点的距离,使用所有六个维度,这不幸的是意味着 P1 未必会因为高维空间中的距离计算方式而显得突出。P1 在其他五个特征上是相当典型的,因此它在六维空间中的距离可能是相对正常的。
然而,我们可以看到,这种通用的异常值检测方法——通过检查每个记录到其他记录的距离——是相当合理的:P1 和 P2 是异常值,因为它们与其他点的距离较远(至少在某些维度上)。

KNN 和 LOF 算法
由于 KNN 和 LOF 是非常常用的异常值检测器,我们将在这里仔细研究它们,然后特别探讨如何在这些算法中使用子空间。
使用 KNN 异常值检测器时,我们选择一个 k 值,决定每个记录与多少个邻居进行比较。假设我们选择 10(在实际应用中,这通常是一个相对典型的值)。
对于每个记录,我们测量它与其 10 个最近邻的距离,这有助于我们了解每个点的孤立性和远离程度。然后,我们需要根据这 10 个距离为每个记录创建一个单一的异常值评分(即一个数字)。通常,我们会取这些距离的均值或最大值。
假设我们选择最大值(使用均值、中位数或其他函数也类似,尽管每个方法有其细微差别)。如果某个记录与其第 10 近邻的距离异常大,这意味着最多有 9 个记录与其相对较近(可能更少),而它与大多数其他点的距离则异常远,因此可以认为它是异常值。
使用 LOF 离群点检测器时,我们采用类似的方法,尽管其工作方式有所不同。我们同样查看每个点到其 k 个最近邻的距离,然后将其与这些 k 个邻居到它们的 k 个最近邻的距离进行比较。因此,LOF 通过衡量每个点相对于其邻域内其他点的离群程度来评估离群点。
也就是说,虽然 KNN 使用全局标准来判断与邻居之间的距离是否异常大,但 LOF 则使用局部标准来判断这些距离是否异常大。
LOF 算法的细节实际上要复杂一些,这两个算法(以及这些算法的多种变体)在具体差异上的含义在《Python 中的离群点检测》中有更详细的介绍。
这些本身就是有趣的考虑因素,但此处的主要观点是,KNN 和 LOF 都根据记录与其最近邻的距离来评估记录。而且,如果同时使用大量特征,这些距离度量可能会表现不佳(甚至完全失效),而通过一次只处理少量特征(子空间)可以大大减少这种情况。
使用子空间的思想在检测器不使用距离度量的情况下仍然有用,但在使用基于距离计算的检测器时,使用子空间的一些好处可能会更加明显。而且,像 KNN 和 LOF 这样的距离使用方法在检测器中是相当常见的。除了 KNN 和 LOF 之外,例如,Radius、ODIN、INFLO、LoOP 检测器,以及基于采样和聚类的检测器,都使用距离。
然而,维度灾难的问题也可能出现在其他检测器中。例如,ABOD(基于角度的离群点检测器)使用记录之间的角度来评估每条记录的离群程度,而不是使用距离。但其思路相似,使用子空间在使用 ABOD 时也同样有效。
此外,我将在下面介绍的其他子空间的好处同样适用于许多检测器,无论是否使用距离计算。不过,维度灾难在离群点检测中是一个严重的问题:当检测器使用距离计算(或类似的度量,如角度计算),且特征数量很多时,这些距离计算可能会失效。在上面的图示中,P1 和 P2 在仅考虑六个维度的情况下可能被良好检测到,甚至在使用 10 或 20 个特征时也可能如此,但如果有 100 个维度的话,所有点之间的距离最终可能会非常相似,P1 和 P2 就不再显得异常了。
特征数适中时的问题
除了与处理大量特征相关的问题外,即便是在特征数量相对较少的情况下,我们在试图识别数据集中最不寻常的记录时也可能会遇到困难。
虽然大量的特征可能会使记录之间计算出的距离变得毫无意义,但即使是适度数量的特征,也可能使仅在一两个特征上表现异常的记录更难被识别出来。
再次考虑之前展示的散点图,并在这里重复展示。点 P1 在特征 A 上是异常值(但在其他五个特征上并非如此)。点 P2 在特征 C 和 D 上表现异常,但在其他四个特征上并非如此。然而,当考虑这些点到其他点在 6 维空间中的欧几里得距离时,它们可能并不能可靠地突出表现为异常值。使用曼哈顿距离以及大多数其他距离度量时也是如此。

左侧窗格展示了点 P1 在 2D 数据空间中的位置。考虑到特征 A,这个点是异常的,但如果使用 6D 数据空间中的欧几里得距离,甚至在此图所示的 2D 数据空间中,它的异常性就较小了。这是一个使用额外特征可能适得其反的例子。在中间窗格中,我们看到另一个点——点 P2,它在 C-D 子空间中是异常值,但在 A-B 或 E-F 子空间中并不是。我们只需要特征 C 和 D 来识别这个异常值,再次强调,包含其他特征只会让 P2 更难被识别出来。
例如,即使在最左侧图表中显示的二维空间中,P1 与大多数其他点的距离也并不异常远。它的异常之处在于周围没有其他点(KNN 和 LOF 会检测到这一点),但是 P1 到该二维空间中其他点的距离并不异常:它与大多数其他点对之间的距离相似。
使用 KNN 算法,我们很可能能够检测到这一点,至少当 k 设置得相对较低时,例如设为 5 或 10——大多数记录的第 5 个(以及第 10 个)最近邻距离比 P1 更近。不过,当将所有六个特征纳入计算时,这一点就不如仅查看特征 A 或仅查看最左侧的图表(仅包含特征 A 和 B)时那么明显。
当仅考虑特征 C 和 D 时,点 P2 很明显是异常值。使用 KNN 检测器,假设 k 值为 5,我们可以识别它的 5 个最近邻,并且这些点到 P2 的距离会比该数据集中的典型点距离要大。
使用 LOF 检测器,同样设置 k 值为 5,我们可以比较 P1 或 P2 与其 5 个最近邻的距离与它们的 5 个最近邻之间的距离,在这种情况下,P1 或 P2 到它们的 5 个最近邻的距离会被发现异常大。
至少当只考虑特征 A 和 B,或特征 C 和 D 时,这一点比较直接,但当考虑完整的 6 维空间时,它们变得更难以识别为异常值。
尽管许多离群点检测器可能仍然能够在六个或稍多维度的情况下识别 P1 和 P2,但显然使用较少的特征会更容易且更可靠。要检测 P1,我们实际上只需要考虑特征 A;而要识别 P2,我们实际上只需要考虑特征 C 和 D。将其他特征包含进来反而会增加难度。
这实际上是离群点检测中的一个常见主题。我们通常处理的数据集有许多特征,每个特征都可能有用。例如,如果我们有一个包含 50 个特征的表格,可能所有 50 个特征都是相关的:无论是这些特征中的任何一个出现稀有值都可能是有趣的,还是这些 50 个特征中的两个或多个特征出现稀有值组合都可能是有趣的。那么,对于分析来说,保留所有 50 个特征是值得的。
但是,要识别任何一个异常,通常只需要少量特征。事实上,记录在所有特征中异常的情况非常罕见。基于许多特征的稀有组合来判断记录是否异常也非常罕见(请参阅Counts Outlier Detector以获取更多解释)。
任何给定的离群点可能在一两个特征上具有罕见的值,或在一对特征中具有罕见的值组合,或者在三个或四个特征的组合中具有罕见的值。即使其他特征可能对于检测其他行的异常很重要,识别该行的异常时,只有这些特征是必要的。
子空间
为了解决这些问题,离群点检测中的一个重要技术是使用子空间。术语子空间仅指特征的子集。在上面的例子中,如果我们使用以下子空间:A-B、C-D、E-F、A-E、B-C、B-D-F,以及 A-B-E,那么我们有七个子空间(五个二维子空间和两个三维子空间)。创建这些子空间后,我们会在每个子空间上运行一个(或多个)检测器,因此每条记录至少会运行七个检测器。
从现实角度看,当特征数超过六个时,子空间变得更加有用,通常即使是子空间本身也会有超过六个特征,而不仅仅是二三个特征,但目前来看,使用特征数量较少的小子空间的简单情况相对容易理解。
使用这些子空间,我们可以更可靠地将 P1 和 P2 识别为离群点。P1 很可能会被运行在特征 A-B、A-E 和 A-B-E 上的检测器打分较高。P2 很可能会被运行在特征 C-D 上的检测器识别出来,可能还会被运行在特征 B-C 上的检测器检测到。
然而,我们必须小心:仅使用这七个子空间,而不是覆盖所有特征的单一 6 维空间,将会错过任何罕见的特征组合,例如 A 和 D,或 C 和 E。这些组合可能会被覆盖所有六个特征的检测器检测到,也可能不会被检测到,但绝对无法通过一组检测器来检测到,因为这些检测器根本不会检查这些特征组合。
使用子空间确实有一些显著的好处,但也存在错过相关离群点的风险。我们将在下面介绍一些生成子空间的技术,以缓解这个问题,但仍然可以考虑在完整的数据空间上运行一个或多个离群点检测器。通常来说,使用离群点检测时,除非我们应用多种技术,否则很难找到我们感兴趣的完整离群点集。尽管子空间的使用非常重要,但仍然通常有必要使用多种技术,这可能包括在完整数据上运行一些检测器。
类似地,在每个子空间中,我们可能会执行多个检测器。例如,我们可能会同时使用 KNN 和 LOF 检测器,以及 Radius、ABOD,甚至可能是其他一些检测器——再次强调,使用多种技术可以帮助我们更好地覆盖希望检测的异常范围。
子空间的进一步动机
因此,我们已经看到了一些使用子空间的动机:我们可以缓解维度灾难,并且可以减少那些基于小数量特征且在众多特征中丢失的异常情况,从而无法可靠识别的情况。
除了处理类似的情况外,使用子空间进行离群点检测还有许多其他优势。包括:
-
由于使用集成方法的精度提升 — 使用多个子空间使我们能够创建集成(离群点检测器集合),这使我们能够结合多个检测器的结果。通常,使用检测器的集成方法比使用单一检测器提供更高的精度。这与集成预测器在分类和回归问题中通常比单一预测器更强的方式类似(尽管也存在一些实际差异)。在这里,使用子空间时,每个记录会被多次检查,这比任何单一检测器提供了更稳定的评估。
-
可解释性 — 结果可以更具可解释性,而可解释性通常是离群点检测中的一个关键问题。在离群点检测中,我们经常标记异常记录,认为它们可能是某种程度上的问题或关注点,通常这些记录会被手动检查。了解为什么它们不寻常是有效且高效地进行此操作所必需的。通过检测器标记的异常点,尤其是那些检查了许多特征的检测器,手动评估起来可能尤其困难;另一方面,通过只使用少数特征的检测器标记的异常点则更易于评估。
-
更快的系统 — 使用更少的特征可以帮助我们创建更快(且占用内存更少)的检测器。这可以加速拟合和推理过程,尤其是在处理执行时间与特征数量呈非线性关系的检测器时(例如,许多检测器的执行时间与特征数量的平方成正比)。根据检测器的不同,使用例如 20 个检测器,每个覆盖 8 个特征,可能比使用一个覆盖 100 个特征的单一检测器执行得更快。
-
并行执行 — 由于我们使用许多小检测器而不是一个大型检测器,因此可以在硬件资源支持的情况下,将拟合和预测步骤并行执行,从而加快执行速度。
-
随时间调节的简易性 — 使用许多简单的检测器可以创建一个更易于随时间调节的系统。在离群点检测中,我们通常只是在评估一个单一的数据集,并希望识别其中的离群点。但通常也会定期执行离群点检测系统,例如监控工业过程、网站活动、金融交易、输入到机器学习系统或其他软件应用的数据、这些系统的输出等。在这些情况下,我们通常希望随着时间的推移改进离群点检测系统,使我们能够更好地关注更相关的离群点。拥有一组简单的检测器,每个检测器基于少量特征,可以使这项工作更具可管理性。这使得我们能够随着时间的推移,增加更有用的检测器的权重,减少不太有用的检测器的权重。
选择子空间
如前所述,我们需要针对每个被评估的数据集,确定适当的子空间。然而,找到相关的子空间集合可能很困难,或者至少难以找到最优的子空间集合。也就是说,假设我们有兴趣找到任何异常的值组合,确定哪些特征集包含最相关的异常组合可能是困难的。
例如,如果一个数据集有 100 个特征,我们可以训练 10 个模型,每个模型覆盖 10 个特征。我们可以使用例如前 10 个特征作为第一个检测器,接下来的 10 个特征作为第二个检测器,依此类推。如果前两个特征存在一些具有异常值组合的行,我们将检测到这一点。但如果存在与第一个特征相关的异常值组合,而这些异常值组合与其余 90 个未包含在同一模型中的特征相关,我们将无法检测到这些。
我们可以通过使用更多的子空间来提高将相关特征组合在一起的可能性,但要确保所有应该在一起的特征集至少出现一次是很困难的,特别是在数据中存在基于三个、四个或更多特征的相关异常值时——这些特征必须至少在一个子空间中一起出现才能被检测到。例如,在一个员工费用表格中,你可能希望识别部门、费用类型和金额的稀有组合。如果是这样,这三个特征必须至少在一个子空间中一起出现。
所以,我们有以下几个问题:每个子空间中应该包含多少个特征,哪些特征应该放在一起,应该创建多少个子空间。
需要考虑的组合非常多。如果有 20 个特征,那么可能的子空间有²²⁰个,略多于一百万个。如果有 30 个特征,则有超过十亿个。如果我们提前决定每个子空间中包含多少个特征,组合的数量会减少,但仍然非常庞大。如果有 20 个特征,我们希望每个子空间有 8 个特征,则有 20 选 8,即 125,970 个组合。如果有 30 个特征,我们希望每个子空间有 7 个特征,则有 30 选 7,即 2,035,800 个组合。
我们可能希望采取的一种方法是保持子空间较小,这样有助于提高可解释性。最具可解释性的选项是每个子空间使用两个特征,它还允许简单的可视化。然而,如果我们有 d 个特征,我们将需要 d*(d-1)/2 个模型来覆盖所有组合,这可能是不可行的。如果有 100 个特征,我们将需要 4,950 个检测器。我们通常需要每个检测器使用至少几个特征,但不一定是大量特征。
我们希望使用足够的检测器,每个检测器使用足够的特征,使得每对特征理想情况下至少出现一次,并且每个检测器中的特征足够少,以便这些检测器之间具有显著不同的特征。例如,如果每个检测器使用 100 个特征中的 90 个,我们将很好地覆盖所有特征组合,但子空间仍然会相当大(从而抵消使用子空间的很多好处),而且所有子空间也会彼此非常相似(从而抵消创建集成方法的很多好处)。
虽然每个子空间中使用的特征数量需要平衡这些问题,但创建的子空间数量则稍微直观一些:从准确性角度看,使用更多子空间是绝对更好的,但计算开销更大。
寻找有用子空间的方法有几种广泛的策略。我在这里快速列出这些方法,然后在下面详细讨论一些。
-
基于领域知识——在这里我们考虑哪些特征组合可能具有我们认为值得注意的值组合。
-
基于关联 — 只有当一组特征以某种方式关联时,才可能出现值的异常组合。在预测问题中,我们通常希望最小化特征之间的相关性,但在异常值检测中,这些特征是最有用的,应该一起考虑。如果有异常情况出现,具有最强关联的特征将包含最有意义的异常值。
-
基于找到非常稀疏的区域 — 如果记录与数据中的大多数其他记录不同,通常会被视为异常值,这意味着它们位于数据的稀疏区域。因此,可以将包含大且几乎为空的区域的子空间视为有用的子空间。
-
随机 — 这是稍后将介绍的一种技术,称为FeatureBagging,虽然它可能不是最优的,但它避免了昂贵的关联搜索和稀疏区域搜索,并且在使用多个子空间的情况下,能够较好地工作。
-
穷举搜索 — 这是Counts Outlier Detector使用的方法。这仅限于具有少量特征的子空间,但结果是高度可解释的。它还避免了与选择可能子空间的子集相关的任何计算或偏差。
-
使用与已知异常值相关的特征 — 如果我们有一组已知的异常值,并且能够确定它们为何是异常值(相关特征),并且我们不希望识别未知的异常值(仅识别这些特定的异常值),那么我们可以利用这一点,确定与每个已知异常值相关的特征集,并为所需的各个特征集构建模型。
接下来我们将更详细地探讨其中的一些内容。
领域知识
让我们以一个数据集为例,特别是下面展示的开销表。如果我们查看此表格,我们可能能够确定我们感兴趣或不感兴趣的异常值类型。帐户和金额的异常组合,以及部门和帐户的异常组合,可能会引起我们的兴趣;而费用日期和时间的组合可能不会是有用的组合。我们可以继续以这种方式进行,创建少数几个子空间,每个子空间通常包含两个、三个或四个特征,这可以实现非常高效且可解释的异常值检测,标记出最相关的异常值。

开销表
这种方法可能会遗漏数据中存在某些关联的情况,尽管这种关联并不显而易见。因此,除了利用领域知识外,搜索数据中的关联也许是值得的。例如,我们可以发现特征之间的关系,例如,测试是否可以通过其他特征使用简单的预测模型准确预测某些特征。如果发现了这样的关联,它们值得进一步调查。
然而,发现这些关联可能对某些目的有用,但可能对异常值检测过程有用也可能无用。例如,如果账户和时间之间存在关系,这可能仅仅是因为人们通常使用的报销流程,而这种偏差可能是值得关注的,但更可能并不重要。
随机特征子空间
如果没有领域知识可以借鉴,随机创建子空间可能是有效的。这种方法快速且能够创建一组子空间,这些子空间往往能够捕捉到最强的异常值,尽管它也可能遗漏一些重要的异常值。
以下代码提供了创建一组随机子空间的示例。这个示例使用了一组八个特征,命名为 A 到 H,并基于这些特征创建子空间。
每个子空间首先选择目前使用最少的特征(如果有平局,则随机选择一个)。它使用一个名为ft_used_counts的变量来跟踪这一点。接着,它会逐一添加特征到这个子空间,每一步选择在当前子空间中与其他子空间中出现次数最少的特征。它使用一个名为ft_pair_mtx的特征来跟踪每对特征已经一起出现在多少个子空间中。通过这种方式,我们创建了一组子空间,使得每对特征大致相同的频率出现在子空间中。
import pandas as pd
import numpy as np
def get_random_subspaces(features_arr, num_base_detectors,
num_feats_per_detector):
num_feats = len(features_arr)
feat_sets_arr = []
ft_used_counts = np.zeros(num_feats)
ft_pair_mtx = np.zeros((num_feats, num_feats))
# Each loop generates one subspace, which is one set of features
for _ in range(num_base_detectors):
# Get the set of features with the minimum count
min_count = ft_used_counts.min()
idxs = np.where(ft_used_counts == min_count)[0]
# Pick one of these randomly and add to the current set
feat_set = [np.random.choice(idxs)]
# Find the remaining set of features
while len(feat_set) < num_feats_per_detector:
mtx_with_set = ft_pair_mtx[:, feat_set]
sums = mtx_with_set.sum(axis=1)
min_sum = sums.min()
min_idxs = np.where(sums==min_sum)[0]
new_feat = np.random.choice(min_idxs)
feat_set.append(new_feat)
feat_set = list(set(feat_set))
# Updates ft_pair_mtx
for c in feat_set:
ft_pair_mtx[c][new_feat] += 1
ft_pair_mtx[new_feat][c] += 1
# Updates ft_used_counts
for c in feat_set:
ft_used_counts[c] += 1
feat_sets_arr.append(feat_set)
return feat_sets_arr
np.random.seed(0)
features_arr = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']
num_base_detectors = 4
num_feats_per_detector = 5
feat_sets_arr = get_random_subspaces(features_arr,
num_base_detectors,
num_feats_per_detector)
for feat_set in feat_sets_arr:
print([features_arr[x] for x in feat_set])
通常,我们会创建更多的基础检测器(每个子空间通常对应一个基础检测器,尽管我们也可以在每个子空间上运行多个基础检测器),但是为了简化,本示例仅使用了四个。这将输出以下子空间:
['A', 'E', 'F', 'G', 'H']
['B', 'C', 'D', 'F', 'H']
['A', 'B', 'C', 'D', 'E']
['B', 'D', 'E', 'F', 'G']
这里的代码将创建具有相同特征数量的子空间。让子空间覆盖不同数量的特征也有一个好处,因为这可以引入更多的多样性(这在创建集成模型时非常重要),但无论如何,通过使用不同的特征已经能够提供强大的多样性(只要每个子空间使用相对较少的特征,从而使得子空间大体上具有不同的特征)。
拥有相同数量的特征有几个好处。它简化了模型的调优,因为许多异常值检测器使用的参数依赖于特征的数量。如果所有子空间都有相同数量的特征,它们也可以使用相同的参数。
它还简化了分数组合,因为检测器之间的可比性更强。如果使用不同数量的特征,可能会生成在不同尺度上的分数,难以进行比较。例如,使用 k-最近邻(KNN)时,如果特征更多,我们预期邻居之间的距离会更大。
基于相关性的特征子空间
在创建子空间时,其他条件相同的情况下,尽可能将关联特征放在一起是非常有用的。在下面的代码中,我们提供了一个示例,展示如何根据相关性选择子空间。
测试关联性有几种方法。我们可以创建预测模型,尝试从每个特征中预测其他单一特征(这将捕捉到特征之间甚至是相对复杂的关系)。对于数值特征,最简单的方法可能是检查斯皮尔曼相关性,这种方法虽然无法检测到非单调关系,但能发现大多数强关联关系。以下代码示例使用的就是这种方法。
要执行代码,我们首先指定所需的子空间数量以及每个子空间中的特征数量。
该方法首先通过寻找特征之间的所有配对相关性并将其存储在一个矩阵中来执行。然后我们创建第一个子空间,方法是首先找到相关矩阵中的最大相关性(这会将两个特征添加到该子空间),然后遍历待添加到该子空间的其他特征数量。对于每个特征,我们从相关矩阵中选择一对特征,要求其中一个特征已经在子空间内,而另一个特征还未在子空间内。一旦该子空间包含足够的特征,我们就创建下一个子空间,继续选择相关矩阵中剩余的最大相关性,以此类推。
在这个例子中,我们使用一个真实数据集——baseball数据集,来自 OpenML(该数据集具有公开许可)。这个数据集包含了一些较强的相关性。例如,击球次数(At bats)和得分(Runs)之间的相关性为 0.94,这意味着任何显著偏离这一模式的值很可能是异常值。
import pandas as pd
import numpy as np
from sklearn.datasets import fetch_openml
# Function to find the pair of features remaining in the matrix with the
# highest correlation
def get_highest_corr():
return np.unravel_index(
np.argmax(corr_matrix.values, axis=None),
corr_matrix.shape)
def get_correlated_subspaces(corr_matrix, num_base_detectors,
num_feats_per_detector):
sets = []
# Loop through each subspace to be created
for _ in range(num_base_detectors):
m1, m2 = get_highest_corr()
# Start each subspace as the two remaining features with
# the highest correlation
curr_set = [m1, m2]
for _ in range(2, num_feats_per_detector):
# Get the other remaining correlations
m = np.unravel_index(np.argsort(corr_matrix.values, axis=None),
corr_matrix.shape)
m0 = m[0][::-1]
m1 = m[1][::-1]
for i in range(len(m0)):
d0 = m0[i]
d1 = m1[i]
# Add the pair if either feature is already in the subset
if (d0 in curr_set) or (d1 in curr_set):
curr_set.append(d0)
curr_set = list(set(curr_set))
if len(curr_set) < num_feats_per_detector:
curr_set.append(d1)
# Remove duplicates
curr_set = list(set(curr_set))
if len(curr_set) >= num_feats_per_detector:
break
# Update the correlation matrix, removing the features now used
# in the current subspace
for i in curr_set:
i_idx = corr_matrix.index[i]
for j in curr_set:
j_idx = corr_matrix.columns[j]
corr_matrix.loc[i_idx, j_idx] = 0
if len(curr_set) >= num_feats_per_detector:
break
sets.append(curr_set)
return sets
data = fetch_openml('baseball', version=1)
df = pd.DataFrame(data.data, columns=data.feature_names)
corr_matrix = abs(df.corr(method='spearman'))
corr_matrix = corr_matrix.where(
np.triu(np.ones(corr_matrix.shape), k=1).astype(np.bool))
corr_matrix = corr_matrix.fillna(0)
feat_sets_arr = get_correlated_subspaces(corr_matrix, num_base_detectors=5,
num_feats_per_detector=4)
for feat_set in feat_sets_arr:
print([df.columns[x] for x in feat_set])
这将产生:
['Games_played', 'At_bats', 'Runs', 'Hits']
['RBIs', 'At_bats', 'Hits', 'Doubles']
['RBIs', 'Games_played', 'Runs', 'Doubles']
['Walks', 'Runs', 'Games_played', 'Triples']
['RBIs', 'Strikeouts', 'Slugging_pct', 'Home_runs']
PyOD
PyOD可能是目前 Python 中最全面、最常用的数值表格数据异常值检测工具。它包括大量的检测器,涵盖从非常简单到非常复杂的多种方法——其中包括几种基于深度学习的方法。
现在我们已经了解了子空间在异常值检测中的作用,接下来我们将看看 PyOD 提供的两个处理子空间的工具:SOD 和 FeatureBagging。两个工具都会识别一组子空间,对每个子空间执行检测器,并将结果合并为每条记录的单一分数。
无论是否使用子空间,都必须确定使用什么基础检测器。如果不使用子空间,我们将选择一个或多个检测器,并在整个数据集上运行它们。如果使用子空间,我们同样选择一个或多个检测器,并在每个子空间上运行这些检测器。如上所述,LOF 和 KNN 可以是合理的选择,但 PyOD 还提供了许多其他的选择,这些选择在每个子空间上执行时也能很好地工作,例如基于角度的离群点检测器(ABOD)、基于高斯混合模型(GMM)的模型、核密度估计(KDE)等。PyOD 之外的其他检测器也能非常有效地工作。
SOD(子空间离群点检测)
SOD(子空间离群点检测)专门设计来处理如上所示的情况。SOD 的工作方式类似于 KNN 和 LOF,通过为每个点识别 k 个邻居的邻域,这些邻居被称为参考集。不过,参考集是通过不同的方式找到的,使用的是一种叫做共享最近邻(SNN)的方法。
共享最近邻在这篇文章中有详细描述,但一般来说,如果两个点是由相同机制生成的,它们不仅会接近,而且往往会有许多相同的邻居。因此,任何两个记录的相似度可以通过它们共享的邻居数量来衡量。基于这一点,邻域可以通过使用不仅是欧几里得距离最小的点集(如 KNN 和 LOF 方法所做的),还可以使用共享最多邻居的点来识别。这种方法即使在高维空间中,甚至在有许多无关特征的情况下,也能保持鲁棒性:即使在这些情况下,邻居的排名顺序仍然保持有意义,因此即使在无法计算具体距离的情况下,最近邻集仍然能够可靠地找到。
一旦我们得到了参考集,就可以用它来确定子空间,这里指的是解释参考集方差最大的一组特征。一旦识别出这些子空间,SOD 就会检查每个点到数据中心的距离。
下面我提供了一个使用 SOD 的快速示例。假设已经安装了 pyod,安装步骤如下:
pip install pyod
我们将使用一个合成数据集作为示例,这使我们可以通过实验数据和模型超参数,更好地理解每个检测器的优缺点。这里的代码提供了一个处理 35 个特征的示例,其中两个特征(特征 8 和 9)是相关的,其他特征是无关的。一个单独的离群点是通过这两个相关特征的不寻常组合创建的。
SOD 能够将已知的一个异常点识别为最显著的异常点。我将污染率设置为 0.01,以指定返回(假设有 100 条记录)仅一个异常点。然而,当测试超过 35 个特征时,SOD 将此点的得分大大降低。此示例将参考集的大小指定为 3;使用不同的值可能会得到不同的结果。
import pandas as pd
import numpy as np
from pyod.models.sod import SOD
np.random.seed(0)
d = np.random.randn(100, 35)
d = pd.DataFrame(d)
#A Ensure features 8 and 9 are correlated, while all others are irrelevant
d[9] = d[9] + d[8]
# Insert a single outlier
d.loc[99, 8] = 3.5
d.loc[99, 9] = -3.8
#C Execute SOD, flagging only 1 outlier
clf = SOD(ref_set=3, contamination=0.01)
d['SOD Scores'] = clf.fit (d)
d['SOD Scores'] = clf.labels_
我们下面展示了四个散点图,显示了 35 个特征中的四对特征。在每个图中,已知的异常点被标记为星号。我们可以在第二个面板中看到特征 8 和 9(这两个相关特征),并且可以看到该点明显是一个异常点,尽管它在所有其他维度中是典型的。

测试 SOD 在 35 维数据上的表现。数据中插入了一个异常点,并且可以在第二个面板中清楚地看到特征 8 和 9。尽管该点在其他维度中是典型的,它还是被 SOD 标记为最显著的异常点。第三个面板也包含特征 9,我们可以看到该点在这里有些不同寻常,尽管与其他维度中的许多点相比并无太大区别。特征 8 和 9 之间的关系最为相关,而 SOD 似乎能够检测到这一点。
FeatureBagging
FeatureBagging 的设计目的是解决与 SOD 相同的问题,尽管它采用了不同的方法来确定子空间。它完全随机地创建子空间(与上面的示例略有不同,后者会记录每对特征一起被放入子空间的频率,并尝试平衡这一点)。它还会对每个基本检测器进行行采样,从而在检测器之间提供更多的多样性。
使用指定数量的基本检测器(默认值为 10,尽管使用更多的检测器会更好),每个检测器都会选择一组随机的行和特征。对于每个检测器,可以选择的最大特征数被指定为一个参数,默认为所有特征。因此,对于每个基本检测器,FeatureBagging 会:
-
确定要使用的特征数量,直到达到指定的最大值。
-
随机选择如此多的特征。
-
随机选择一组行。这是一个与行数相同大小的自助样本。
-
创建一个 LOF 检测器(默认为此;也可以使用其他基本检测器)来评估子空间。
一旦完成,每一行将通过每个基本检测器进行评分,然后必须将这些评分合并为每一行的单一最终评分。PyOD 的 FeatureBagging 提供了两种合并评分的选项:使用最大评分和使用平均评分。
正如我们在上面的散点图中所看到的,某些子空间中的点可能是强异常值,而在其他子空间中则不是,从这些子空间中取平均分数,可能会稀释它们的得分,进而削弱使用子空间的好处。在其他异常值检测的集成方法中,使用均值通常有效,但在处理多个子空间时,使用最大值通常是两个选项中更好的。这样,我们根据记录在最异常的子空间中的得分来给每个记录打分。这也不是完美的,可能还有更好的选择,但使用最大值既简单又几乎总是比均值更可取。
任何检测器都可以在子空间中使用。PyOD 默认使用 LOF,就像最初描述 FeatureBagging 的论文一样。LOF 是一个强大的检测器,是一个合理的选择,尽管你可能会发现使用其他基础检测器能得到更好的结果。
在原始论文中,子空间是随机创建的,每个子空间使用 d/2 到 d-1 之间的特征,其中 d 是特征的总数。一些研究人员指出,原始论文中使用的特征数量可能远大于适当的数量。
如果特征的总数较大,同时使用超过一半的特征将使维度灾难生效。而且,在每个检测器中使用许多特征将导致检测器之间的相关性(例如,如果所有基础检测器都使用 90%的特征,它们将使用大致相同的特征,并且倾向于为每条记录打上相似的分数),这也可能会消除创建集成模型的许多好处。
PyOD 允许设置每个子空间中使用的特征数量,通常应该设置得比较低,并创建大量的基础估计器。
使用其他检测器
在本文中,我们探讨了子空间作为改善异常值检测的一种方式,包括减少维度灾难、提高可解释性、支持并行执行、简化时间上的调优等。每个因素都是重要的考虑因素,使用子空间通常非常有帮助。
然而,通常还有其他方法可以用于这些目的,有时作为替代方案,有时与使用子空间结合使用。例如,为了提高可解释性,重要的是尽可能选择本身具有可解释性的模型类型(例如,单变量测试,如 z 分数测试、计数异常值检测器,或 PyOD 提供的一个名为 ECOD 的检测器)。
当主要关注点是减少维度灾难时,在这里再次,查看那些在许多特征上表现良好的模型类型可能会很有用,例如隔离森林(Isolation Forest)或计数异常值检测器(Counts Outlier Detector)。此外,执行单变量测试或应用PCA也是有益的。
正在进行的异常值检测项目
在构建子空间时需要注意的一点是,如果子空间是基于相关性或稀疏区域形成的,那么随着数据的变化,相关的子空间可能会发生变化。随着新特征之间的关联出现或新的稀疏区域形成,这些区域可能对识别离群点有用,但如果不定期重新计算子空间,可能会错过这些信息。通过这种方式找到相关的子空间可以非常有效,但可能需要在某些时间表上进行更新,或者当已知数据发生变化时进行更新。
结论
在表格数据的离群点检测项目中,通常值得考虑使用子空间,特别是在我们有许多特征的情况下。使用子空间是一种相对简单的技术,并具有许多显著的优势。
在面临大数据量、执行时间或内存限制等问题时,使用PCA也可能是一种有用的技术,并且在某些情况下可能比创建子空间更有效,尽管使用子空间(因此,使用原始特征而不是 PCA 生成的组件)通常更具可解释性,而可解释性在离群点检测中往往非常重要。
子空间可以与其他技术结合使用,以提高离群点检测的效果。例如,使用子空间可以与其他方法结合来创建集成:通过子空间(在集成中的不同检测器使用不同的特征)以及不同的模型类型、不同的训练数据、不同的预处理等,能够创建更大的集成。这可以带来一些额外的好处,但也会增加计算量。
所有图片均由作者提供
Spark 流处理中 Sigma 规则检测的性能洞察
在网络安全日志中利用 Sigma 规则进行异常检测:关于性能优化的研究
·发布于 Towards Data Science ·14 分钟阅读·2024 年 6 月 1 日
--

图片来源:Ed Vazquez,Unsplash
加拿大网络安全中心(CCCS)的职责之一是尽快检测异常并发布缓解措施。
在将我们的 Sigma 规则检测投入生产时,我们在 Spark 流处理应用程序中发现了一个有趣的现象。运行一个包含 1000 条 Sigma 检测规则的大型 SQL 语句的速度比运行五个单独的查询要慢,每个查询应用 200 条 Sigma 规则。这令人惊讶,因为运行五个查询会迫使 Spark 阅读源数据五次,而不是一次。有关更多细节,请参考我们的系列文章:
## 使用 Sigma 规则进行异常检测(第一部分):利用 Spark SQL 流处理
Sigma 规则用于检测网络安全日志中的异常。我们使用 Spark 结构化流处理来评估 Sigma…
towardsdatascience.com
鉴于我们需要执行的大量遥测数据和检测规则,每一点性能提升都能带来显著的成本节省。因此,我们决定调查这个奇特的观察结果,旨在解释它并可能发现额外的性能提升机会。在这个过程中我们学到了一些东西,并希望与更广泛的社区分享。
简介
我们的直觉是,我们达到了 Spark 代码生成的限制。因此,了解一下这个主题的背景是必要的。2014 年,Spark 引入了代码生成来评估形如(id > 1 and id > 2) and (id < 1000 or (id + id) = 12)的表达式。Databricks 的一篇文章对此做了很好的解释:Spark SQL 的激动人心的性能提升即将来临
两年后,Spark 引入了全阶段代码生成。这项优化将多个操作符合并成一个单一的 Java 函数。像表达式代码生成一样,全阶段代码生成消除了虚拟函数调用,并利用 CPU 寄存器处理中间数据。然而,它与表达式级别的生成不同,它是应用于操作符级别的。操作符是执行计划中的节点。欲了解更多信息,请阅读Apache Spark 作为编译器:在笔记本电脑上每秒连接十亿行数据
为了总结这些文章,让我们生成这个简单查询的执行计划:
explain codegen
select
id,
(id > 1 and id > 2) and (id < 1000 or (id + id) = 12) as test
from
range(0, 10000, 1, 32)
在这个简单的查询中,我们使用了两个操作符:Range 用于生成行,Select 用于执行投影。我们可以在查询的物理计划中看到这些操作符。注意节点旁边的星号(*)以及它们相关的[codegen id : 1]。这表示这两个操作符通过全阶段代码生成被合并为一个单一的 Java 函数。
|== Physical Plan ==
* Project (2)
+- * Range (1)
(1) Range [codegen id : 1]
Output [1]: [id#36167L]
Arguments: Range (0, 10000, step=1, splits=Some(32))
(2) Project [codegen id : 1]
Output [2]: [id#36167L, (((id#36167L > 1) AND (id#36167L > 2)) AND ((id#36167L < 1000) OR ((id#36167L + id#36167L) = 12))) AS test#36161]
Input [1]: [id#36167L]
生成的代码清楚地展示了这两个操作符是如何被合并的。
Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */ return new GeneratedIteratorForCodegenStage1(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=1
/* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */ private Object[] references;
/* 008 */ private scala.collection.Iterator[] inputs;
/* 009 */ private boolean range_initRange_0;
/* 010 */ private long range_nextIndex_0;
/* 011 */ private TaskContext range_taskContext_0;
/* 012 */ private InputMetrics range_inputMetrics_0;
/* 013 */ private long range_batchEnd_0;
/* 014 */ private long range_numElementsTodo_0;
/* 015 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] range_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[3];
/* 016 */
/* 017 */ public GeneratedIteratorForCodegenStage1(Object[] references) {
/* 018 */ this.references = references;
/* 019 */ }
/* 020 */
/* 021 */ public void init(int index, scala.collection.Iterator[] inputs) {
/* 022 */ partitionIndex = index;
/* 023 */ this.inputs = inputs;
/* 024 */
/* 025 */ range_taskContext_0 = TaskContext.get();
/* 026 */ range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();
/* 027 */ range_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 028 */ range_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 029 */ range_mutableStateArray_0[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 0);
/* 030 */
/* 031 */ }
/* 032 */
/* 033 */ private void project_doConsume_0(long project_expr_0_0) throws java.io.IOException {
/* 034 */ // common sub-expressions
/* 035 */
/* 036 */ boolean project_value_4 = false;
/* 037 */ project_value_4 = project_expr_0_0 > 1L;
/* 038 */ boolean project_value_3 = false;
/* 039 */
/* 040 */ if (project_value_4) {
/* 041 */ boolean project_value_7 = false;
/* 042 */ project_value_7 = project_expr_0_0 > 2L;
/* 043 */ project_value_3 = project_value_7;
/* 044 */ }
/* 045 */ boolean project_value_2 = false;
/* 046 */
/* 047 */ if (project_value_3) {
/* 048 */ boolean project_value_11 = false;
/* 049 */ project_value_11 = project_expr_0_0 < 1000L;
/* 050 */ boolean project_value_10 = true;
/* 051 */
/* 052 */ if (!project_value_11) {
/* 053 */ long project_value_15 = -1L;
/* 054 */
/* 055 */ project_value_15 = project_expr_0_0 + project_expr_0_0;
/* 056 */
/* 057 */ boolean project_value_14 = false;
/* 058 */ project_value_14 = project_value_15 == 12L;
/* 059 */ project_value_10 = project_value_14;
/* 060 */ }
/* 061 */ project_value_2 = project_value_10;
/* 062 */ }
/* 063 */ range_mutableStateArray_0[2].reset();
/* 064 */
/* 065 */ range_mutableStateArray_0[2].write(0, project_expr_0_0);
/* 066 */
/* 067 */ range_mutableStateArray_0[2].write(1, project_value_2);
/* 068 */ append((range_mutableStateArray_0[2].getRow()));
/* 069 */
/* 070 */ }
/* 071 */
/* 072 */ private void initRange(int idx) {
/* 073 */ java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
/* 074 */ java.math.BigInteger numSlice = java.math.BigInteger.valueOf(32L);
/* 075 */ java.math.BigInteger numElement = java.math.BigInteger.valueOf(10000L);
/* 076 */ java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
/* 077 */ java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
/* 078 */ long partitionEnd;
/* 079 */
/* 080 */ java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
/* 081 */ if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 082 */ range_nextIndex_0 = Long.MAX_VALUE;
/* 083 */ } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 084 */ range_nextIndex_0 = Long.MIN_VALUE;
/* 085 */ } else {
/* 086 */ range_nextIndex_0 = st.longValue();
/* 087 */ }
/* 088 */ range_batchEnd_0 = range_nextIndex_0;
/* 089 */
/* 090 */ java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
/* 091 */ .multiply(step).add(start);
/* 092 */ if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 093 */ partitionEnd = Long.MAX_VALUE;
/* 094 */ } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 095 */ partitionEnd = Long.MIN_VALUE;
/* 096 */ } else {
/* 097 */ partitionEnd = end.longValue();
/* 098 */ }
/* 099 */
/* 100 */ java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
/* 101 */ java.math.BigInteger.valueOf(range_nextIndex_0));
/* 102 */ range_numElementsTodo_0 = startToEnd.divide(step).longValue();
/* 103 */ if (range_numElementsTodo_0 < 0) {
/* 104 */ range_numElementsTodo_0 = 0;
/* 105 */ } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
/* 106 */ range_numElementsTodo_0++;
/* 107 */ }
/* 108 */ }
/* 109 */
/* 110 */ protected void processNext() throws java.io.IOException {
/* 111 */ // initialize Range
/* 112 */ if (!range_initRange_0) {
/* 113 */ range_initRange_0 = true;
/* 114 */ initRange(partitionIndex);
/* 115 */ }
/* 116 */
/* 117 */ while (true) {
/* 118 */ if (range_nextIndex_0 == range_batchEnd_0) {
/* 119 */ long range_nextBatchTodo_0;
/* 120 */ if (range_numElementsTodo_0 > 1000L) {
/* 121 */ range_nextBatchTodo_0 = 1000L;
/* 122 */ range_numElementsTodo_0 -= 1000L;
/* 123 */ } else {
/* 124 */ range_nextBatchTodo_0 = range_numElementsTodo_0;
/* 125 */ range_numElementsTodo_0 = 0;
/* 126 */ if (range_nextBatchTodo_0 == 0) break;
/* 127 */ }
/* 128 */ range_batchEnd_0 += range_nextBatchTodo_0 * 1L;
/* 129 */ }
/* 130 */
/* 131 */ int range_localEnd_0 = (int)((range_batchEnd_0 - range_nextIndex_0) / 1L);
/* 132 */ for (int range_localIdx_0 = 0; range_localIdx_0 < range_localEnd_0; range_localIdx_0++) {
/* 133 */ long range_value_0 = ((long)range_localIdx_0 * 1L) + range_nextIndex_0;
/* 134 */
/* 135 */ project_doConsume_0(range_value_0);
/* 136 */
/* 137 */ if (shouldStop()) {
/* 138 */ range_nextIndex_0 = range_value_0 + 1L;
/* 139 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localIdx_0 + 1);
/* 140 */ range_inputMetrics_0.incRecordsRead(range_localIdx_0 + 1);
/* 141 */ return;
/* 142 */ }
/* 143 */
/* 144 */ }
/* 145 */ range_nextIndex_0 = range_batchEnd_0;
/* 146 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localEnd_0);
/* 147 */ range_inputMetrics_0.incRecordsRead(range_localEnd_0);
/* 148 */ range_taskContext_0.killTaskIfInterrupted();
/* 149 */ }
/* 150 */ }
/* 151 */
/* 152 */ }
project_doConsume_0函数包含了评估(id > 1 and id > 2) and (id < 1000 or (id + id) = 12)的代码。请注意,这段代码是如何被生成来评估这个特定表达式的。这是表达式代码生成的一个示例。
整个类是一个具有processNext方法的操作符。这个生成的操作符同时执行投影(Projection)和范围(Range)操作。在第 117 行的 while 循环中,我们可以看到生成行的代码和一个特定的调用(不是虚拟函数)project_doConsume_0。这展示了全阶段代码生成(Whole-Stage Code Generation)是如何工作的。
性能分析
现在我们对 Spark 的代码生成有了更好的理解,让我们尝试解释为什么将一个包含 1000 个 Sigma 规则的查询拆分成较小的规则会表现得更好。我们来考虑一个评估两个 Sigma 规则的 SQL 语句。这些规则非常简单:Rule1 匹配 Imagepath 以 ‘schtask.exe’ 结尾的事件,Rule2 匹配 Imagepath 以 ‘d:’ 开头的事件。
select /* #3 */
Imagepath,
CommandLine,
PID,
map_keys(map_filter(results_map, (k,v) -> v = TRUE)) as matching_rules
from (
select /* #2 */
*,
map('rule1', rule1, 'rule2', rule2) as results_map
from (
select /* #1 */
*,
(lower_Imagepath like '%schtasks.exe') as rule1,
(lower_Imagepath like 'd:%') as rule2
from (
select
lower(PID) as lower_PID,
lower(CommandLine) as lower_CommandLine,
lower(Imagepath) as lower_Imagepath,
*
from (
select
uuid() as PID,
uuid() as CommandLine,
uuid() as Imagepath,
id
from
range(0, 10000, 1, 32)
)
)
)
)
标记为 #1 的选择执行检测,并将结果存储在名为 rule1 和 rule2 的新列中。选择 #2 将这些列重新组合到一个名为 results_map 的单一列中,最后选择 #3 将映射转换为一个匹配规则的数组。它使用 map_filter 仅保留实际匹配的规则条目,然后使用 map_keys 将映射条目转换为匹配规则名称的列表。
让我们打印出这个查询的 Spark 执行计划:
== Physical Plan ==
Project (4)
+- * Project (3)
+- * Project (2)
+- * Range (1)
...
(4) Project
Output [4]: [Imagepath#2, CommandLine#1, PID#0, map_keys(map_filter(map(rule1, EndsWith(lower_Imagepath#5, schtasks.exe), rule2, StartsWith(lower_Imagepath#5, d:)), lambdafunction(lambda v#12, lambda k#11, lambda v#12, false))) AS matching_rules#9]
Input [4]: [lower_Imagepath#5, PID#0, CommandLine#1, Imagepath#2]
请注意,节点 Project (4) 不是由代码生成的。节点 4 包含一个 lambda 函数,它是否阻止了整个阶段的代码生成?稍后会详细讨论这个问题。
这个查询并不是我们想要的。我们希望生成一个事件表,并且有一列显示匹配的规则。如下所示:
+--------------------+--------------------+--------------------+--------------+
| Imagepath| CommandLine| PID| matched_rule|
+--------------------+--------------------+--------------------+--------------+
|09401675-dc09-4d0...|6b8759ee-b55a-486...|44dbd1ec-b4e0-488...| rule1|
|e2b4a0fd-7b88-417...|46dd084d-f5b0-4d7...|60111cf8-069e-4b8...| rule1|
|1843ee7a-a908-400...|d1105cec-05ef-4ee...|6046509a-191d-432...| rule2|
+--------------------+--------------------+--------------------+--------------+
这很简单。我们只需要展开 matching_rules 列。
select
Imagepath,
CommandLine,
PID,
matched_rule
from (
select
*,
explode(matching_rules) as matched_rule
from (
/* original statement */
)
)
这会产生两个额外的操作符:Generate (6) 和 Project (7)。然而,还有一个新的 Filter (3)。
== Physical Plan ==
* Project (7)
+- * Generate (6)
+- Project (5)
+- * Project (4)
+- Filter (3)
+- * Project (2)
+- * Range (1)
...
(3) Filter
Input [3]: [PID#34, CommandLine#35, Imagepath#36]
Condition : (size(map_keys(map_filter(map(rule1, EndsWith(lower(Imagepath#36),
schtasks.exe), rule2, StartsWith(lower(Imagepath#36), d:)),
lambdafunction(lambda v#47, lambda k#46, lambda v#47, false))), true) > 0)
...
(6) Generate [codegen id : 3]
Input [4]: [PID#34, CommandLine#35, Imagepath#36, matching_rules#43]
Arguments: explode(matching_rules#43), [PID#34, CommandLine#35, Imagepath#36], false, [matched_rule#48]
(7) Project [codegen id : 3]
Output [4]: [Imagepath#36, CommandLine#35, PID#34, matched_rule#48]
Input [4]: [PID#34, CommandLine#35, Imagepath#36, matched_rule#48]
explode 函数会为数组中的每个元素生成一行。当数组为空时,explode 不会生成任何行,实际上过滤掉了那些数组为空的行。
Spark 有一个优化规则,用于检测 explode 函数,并产生这个额外的条件。该过滤器是 Spark 尝试尽可能短路处理的做法。这个规则的源代码,名为 org.apache.spark.sql.catalyst.optimizer.InferFiltersFromGenerate,是这样解释的:
从 Generate 推断过滤器,以便可以在连接和数据源之前,提前移除本该被此 Generate 移除的行。
有关 Spark 如何优化执行计划的更多细节,请参阅 David Vrba 的文章 Mastering Query Plans in Spark 3.0。
另一个问题出现了:我们是否从这个额外的过滤器中受益?注意,这个额外的过滤器同样没有被整个阶段的代码生成,可能是因为 lambda 函数的原因。让我们尝试表达相同的查询,但不使用 lambda 函数。
另外,我们可以将规则结果放入映射中,展开映射,并过滤掉不需要的行,从而绕过 map_filter。
select
Imagepath,
CommandLine,
PID,
matched_rule
from (
select
*
from (
select
*,
explode(results_map) as (matched_rule, matched_result)
from (
/* original statement */
)
)
where
matched_result = TRUE
)
选择 #3 操作将映射展开成两个新列。matched_rule 列将保存键,表示规则名称,而 matched_result 列将包含检测测试的结果。为了过滤行,我们只保留 matched_result 为正的行。
物理计划表明,所有节点都被整个阶段的代码生成,合并成一个 Java 函数,这非常有前景。
== Physical Plan ==
* Project (8)
+- * Filter (7)
+- * Generate (6)
+- * Project (5)
+- * Project (4)
+- * Filter (3)
+- * Project (2)
+- * Range (1)
让我们进行一些测试,以比较使用 map_filter 和使用 explode 然后 filter 的查询性能。
我们在一台配备 4 个 CPU 的机器上运行了这些测试。我们生成了 100 万行数据,每行包含 100 条规则,每条规则评估 5 个表达式。这些测试共运行了 5 次。
平均而言
-
map_filter 耗时 42.6 秒
-
explode_then_filter 耗时 51.2 秒
所以,map_filter 略微更快,尽管它没有使用 WholeStageCodeGen。
然而,在我们的生产查询中,我们执行了更多的 Sigma 规则——共计 1000 条规则。这包括 29 个正则表达式,529 个等号,115 个以“开始”为前缀,2352 个以“结束”为后缀的表达式,以及 5838 个包含表达式。让我们再次测试查询,但这次我们将每条规则的表达式数量从 5 增加到 7。当这样做时,我们在日志中遇到了以下错误:
Caused by: org.codehaus.commons.compiler.InternalCompilerException: Code grows beyond 64 KB
我们尝试增加了spakr.sql.codegen.maxFields和spark.sql.codegen.hugeMethodLimit,但从根本上讲,Java 类的函数大小限制为 64 KB。此外,JVM JIT 编译器限制其只能编译小于 8 KB 的函数。
然而,查询仍然能够正常运行,因为 Spark 会在某些执行计划的部分回退到火山执行模型(Volcano execution model)。毕竟,WholeStageCodeGen 只是一个优化。
运行与之前相同的测试,但每条规则使用 7 个表达式而不是 5 个时,explode_then_filter 比 map_filter 快得多。
-
map_filter 耗时 68.3 秒
-
explode_then_filter 耗时 15.8 秒
增加表达式数量导致 explode_then_filter 的部分代码不再进行 WholeStageCodeGen 优化。特别是,由规则 org.apache.spark.sql.catalyst.optimizer.InferFiltersFromGenerate 引入的 Filter 操作符过大,无法包含在 WholeStageCodeGen 中。让我们看看如果排除 InferFiltersFromGenerate 规则会发生什么:
spark.sql("SET spark.sql.optimizer.excludedRules=org.apache.spark.sql.catalyst.optimizer.InferFiltersFromGenerate")
正如预期的那样,两个查询的物理计划中都不再有额外的 Filter 操作符。
== Physical Plan ==
* Project (6)
+- * Generate (5)
+- Project (4)
+- * Project (3)
+- * Project (2)
+- * Range (1)
== Physical Plan ==
* Project (7)
+- * Filter (6)
+- * Generate (5)
+- * Project (4)
+- * Project (3)
+- * Project (2)
+- * Range (1)
移除规则确实对性能产生了显著影响:
-
map_filter 耗时 22.49 秒
-
explode_then_filter 耗时 4.08 秒
这两个查询在移除规则后都受益匪浅。鉴于性能的改善,我们决定将 Sigma 规则的数量增加到 500,并将复杂度提高到 21 个表达式:
结果:
-
map_filter 耗时 195.0 秒
-
explode_then_filter 耗时 25.09 秒
尽管复杂性增加了,但两个查询仍然能提供相当不错的性能,其中 explode_then_filter 显著优于 map_filter。
探讨 Spark 所采用的不同代码生成方式是很有趣的。尽管我们目前可能没有从 WholeStageCodeGen 中获益,但我们仍然可以从表达式生成中获得优势。
表达式生成不受与整体代码生成相同的限制。非常大的表达式树可以被拆分成更小的树,Spark 的 spark.sql.codegen.methodSplitThreshold 控制如何拆分这些树。虽然我们尝试了这一属性,但并没有观察到显著的改进。默认设置似乎已足够令人满意。
Spark 提供了一个名为 spark.sql.codegen.factoryMode 的调试属性,可以设置为 FALLBACK、CODEGEN_ONLY 或 NO_CODEGEN。我们可以通过设置 spark.sql.codegen.factoryMode=NO_CODEGEN 来关闭表达式代码生成,这会导致性能急剧下降:
使用 500 条规则和 21 个表达式:
-
map_filter 花费了 1581 秒
-
explode_then_filter 花费了 122.31 秒。
即使并非所有操作符都参与整体代码生成,我们仍然观察到表达式代码生成带来了显著的好处。
结果

图片由作者提供
在我们的最佳案例中,评估 10,500 个表达式所需的时间为 25.1 秒,处理 100 万行数据时,我们实现了每个 CPU 每秒处理 1.04 亿个表达式的非常可观的速度。
本研究的启示是,在评估大量表达式时,我们通过将使用 map_filter 的查询转换为使用先 explode 后 filter 的方法可以获得好处。此外,org.apache.spark.sql.catalyst.optimizer.InferFiltersFromGenerate 规则在我们的使用案例中似乎并不有利,因此我们应该将该规则从查询中排除。
它解释了我们最初的观察结果吗?
在生产工作中实施这些经验教训带来了显著的好处。然而,即使在进行了这些优化后,将大型查询拆分为多个较小查询仍然提供了优势。经过进一步调查,我们发现这不仅仅是由于代码生成,实际上有一个更简单的解释。
Spark 流处理通过运行微批次直到完成,然后在开始新的微批次之前检查点其进度来工作。
在每个微批次中,Spark 必须完成所有任务,通常是 200 个。然而,并非所有任务的难度相同。Spark 采用轮询策略将行分配给这些任务。因此,某些任务可能会包含大属性的事件,例如非常大的命令行,导致某些任务快速完成,而其他任务则需要更长时间。例如,这里展示了微批任务执行时间的分布。中位数任务时间为 14 秒。然而,最慢的任务竟然需要 1.6 分钟!

图片由作者提供
这确实揭示了一个不同的现象。每个微批中,Spark 会等待一些滞后任务,这导致许多 CPU 空闲,这也解释了为什么将大型查询拆分为多个较小的查询会导致整体性能更快。
这张图展示了 5 个较小的查询在同一个 Spark 应用中并行运行。Batch3 在等待一个拖慢的任务,而其他查询继续进展。

图片由作者提供
在这些等待期间,Spark 可以利用空闲的 CPU 来处理其他查询,从而最大化资源利用率和整体吞吐量。
结论
在本文中,我们概述了 Spark 的代码生成过程,并讨论了内置优化可能并不总是产生理想的结果。此外,我们展示了将查询从使用 lambda 函数重构为使用简单的 explode 操作后,性能得到了提升。最后,我们得出结论,尽管拆分大查询确实提升了性能,但推动这些提升的主要因素是执行拓扑结构,而非查询本身。
高性能 IPv4 范围 Spark 连接
实用指南:优化 Spark 中的非等值连接
·发布于Towards Data Science ·9 分钟阅读·2024 年 1 月 25 日
--

图片由 John Lee 提供,来源于 Unsplash
通过 IP 地理位置信息丰富网络事件是一个至关重要的任务,特别是对于加拿大网络安全中心等组织,这是加拿大的国家计算机安全事件响应团队(CSIRT)。在本文中,我们将展示如何优化 Spark SQL 连接,特别关注涉及不等式条件的场景——这是处理 IP 地理位置数据时常见的挑战。
作为网络安全从业者,我们依赖于通过 IP 地理位置数据库丰富网络事件,这就要求我们采用高效的策略来处理非等值连接。尽管有许多文章阐述了 Spark 支持的各种连接策略,但这些策略在实际应用中的效果仍然是业内专业人士关注的一个问题。
David Vrba 的深刻文章,“关于 Spark 3.0 中的连接”,发布于 Towards Data Science,是一篇宝贵的资源。它解释了 Spark 选择特定连接策略的条件。在他的文章中,David 简要地指出,优化非等值连接的方法是将其转化为等值连接。
本文旨在提供一个实用指南,帮助优化非等值连接的性能,特别是聚焦于与地理位置表中的 IP 范围进行连接的情况。
为了举例说明这些优化,我们将回顾在我们之前的文章中介绍的地理位置表。
+----------+--------+---------+-----------+-----------+
| start_ip | end_ip | country | city | owner |
+----------+--------+---------+-----------+-----------+
| 1 | 2 | ca | Toronto | Telus |
| 3 | 4 | ca | Quebec | Rogers |
| 5 | 8 | ca | Vancouver | Bell |
| 10 | 14 | ca | Montreal | Telus |
| 19 | 22 | ca | Ottawa | Rogers |
| 23 | 29 | ca | Calgary | Videotron |
+----------+--------+---------+-----------+-----------+
等值连接
为了说明 Spark 如何执行等值连接,我们将通过考虑一个假设的场景来开始我们的探索。假设我们有一个事件表,每个事件都与特定的 owner 相关联,该 owner 由 event_owner 列表示。
+------------+--------------+
| event_time | event_owner |
+------------+--------------+
| 2024-01-01 | Telus |
| 2024-01-02 | Bell |
| 2024-01-03 | Rogers |
| 2024-01-04 | Videotron |
| 2024-01-05 | Telus |
| 2024-01-06 | Videotron |
| 2024-01-07 | Rogers |
| 2024-01-08 | Bell |
+------------+--------------+
让我们更仔细地看看 Spark 如何处理这个等值连接:
SELECT
*
FROM
events
JOIN geolocation
ON (event_owner = owner)
在这个例子中,等值连接建立在 events 表和 geolocation 表之间。连接的标准是基于 events 表中的 event_owner 列与 geolocation 表中的 owner 列是否相等。
正如 David Vrba 在他的博客中所解释的:
如果存在等值条件,并且连接键是可排序的,Spark 会选择使用 SMJ 进行连接规划。
Spark 将执行排序合并连接(Sort Merge Join),通过对左侧的 event_owner 和右侧的 owner 进行哈希操作来分配两表的行。哈希到同一 Spark 分区的两表行将由同一 Spark 任务处理——一个工作单元。例如,Task-1 可能接收:
+----------+-------+---------+-----------+-----------+
| start_ip | end_ip| country | city | owner |
+----------+-------+---------+-----------+-----------+
| 1 | 2 | ca | Toronto | Telus |
| 10 | 14 | ca | Montreal | Telus |
+----------+-------+---------+-----------+-----------+
+------------+--------------+
| event_time | event_owner |
+------------+--------------+
| 2024-01-01 | Telus |
| 2024-01-05 | Telus |
+------------+--------------+
注意 Task-1 仅处理数据的一个子集。连接问题被划分为多个较小的任务,其中只需要左右两边的部分行。此外,Task-1 处理的左右两边行必须匹配。因为无论来自 events 表还是 geolocation 表,“Telus” 的每次出现都会哈希到相同的分区。我们可以确定其他 Task-X 不会有 owner 为“Telus”的行。
一旦数据按照上面所示的方式划分,Spark 将对两边的数据进行排序,因此这种连接策略被称为排序合并连接(Sort Merge Join)。合并操作通过取左边的第一行并测试其是否与右边匹配来进行。一旦右边的行不再匹配,Spark 将从左边提取行。它将不断从两边队列中提取,直到两边没有剩余行。
非等值连接
现在我们更好地理解了等值连接是如何执行的,接下来我们将其与非等值连接进行对比。假设我们有事件表,其中包含 event_ip,并且我们希望向该表添加地理位置信息。
+------------+----------+
| event_time | event_ip |
+------------+----------+
| 2024-01-01 | 6 |
| 2024-01-02 | 14 |
| 2024-01-03 | 18 |
| 2024-01-04 | 27 |
| 2024-01-05 | 9 |
| 2024-01-06 | 23 |
| 2024-01-07 | 15 |
| 2024-01-08 | 1 |
+------------+----------+
为了执行此连接,我们需要确定 event_ip 所在的 IP 范围。我们通过以下条件来实现这一点:
SELECT
*
FROM
events
JOIN geolocation
ON (event_ip >= start_ip and event_ip <= end_ip)
现在,让我们考虑 Spark 如何执行此连接。在右边(地理位置表),没有可以哈希并分配行的键。无法将此问题划分为可以在计算集群中分发并并行执行的较小任务。
在这种情况下,Spark 被迫采用更为资源密集型的连接策略。正如 David Vrba 所说:
如果没有等值条件,Spark 必须使用广播嵌套循环连接(BroadcastNestedLoopJoin,BNLJ)或笛卡尔积连接(Cartesian Product Join,CPJ)。
这两种策略都涉及暴力破解问题;对于左侧的每一行,Spark 会对右侧的每一行测试“between”条件。它别无选择。如果右侧的表足够小,Spark 可以通过将右侧表复制到每个读取左侧的任务中来优化,这种情况称为 BNLJ(嵌套循环连接)情况。然而,如果左侧太大,每个任务将需要读取右侧和左侧的表,这种情况称为 CPJ(连接点连接)情况。在这两种情况下,这两种策略都是非常昂贵的。
那么,我们如何改进这种情况呢?诀窍是引入连接条件中的相等性。例如,我们可以简单地展开地理位置表中的所有 IP 范围,为每个 IP 生成一行。
这是在 Spark 中容易实现的;我们可以执行以下 SQL 来展开所有 IP 范围:
SELECT
country,
city,
owner,
explode(sequence(start_ip, end_ip)) AS ip
FROM
geolocation
sequence函数创建一个从start_ip到end_ip的 IP 值数组。explode函数将这个数组展开成单独的行。
+---------+---------+---------+-----------+
| country | city | owner | ip |
+---------+---------+---------+-----------+
| ca | Toronto | Telus | 1 |
| ca | Toronto | Telus | 2 |
| ca | Quebec | Rogers | 3 |
| ca | Quebec | Rogers | 4 |
| ca | Vancouver | Bell | 5 |
| ca | Vancouver | Bell | 6 |
| ca | Vancouver | Bell | 7 |
| ca | Vancouver | Bell | 8 |
| ca | Montreal | Telus | 10 |
| ca | Montreal | Telus | 11 |
| ca | Montreal | Telus | 12 |
| ca | Montreal | Telus | 13 |
| ca | Montreal | Telus | 14 |
| ca | Ottawa | Rogers | 19 |
| ca | Ottawa | Rogers | 20 |
| ca | Ottawa | Rogers | 21 |
| ca | Ottawa | Rogers | 22 |
| ca | Calgary | Videotron | 23 |
| ca | Calgary | Videotron | 24 |
| ca | Calgary | Videotron | 25 |
| ca | Calgary | Videotron | 26 |
| ca | Calgary | Videotron | 27 |
| ca | Calgary | Videotron | 28 |
| ca | Calgary | Videotron | 29 |
+---------+---------+---------+-----------+
有了两侧的键,我们现在可以执行等值连接,Spark 可以高效地分配问题,从而实现最优性能。然而,在实际应用中,这种情况并不现实,因为真正的地理位置表通常包含数十亿行。
为了解决这个问题,我们可以通过增加映射的粗糙度来提高效率。我们可以将 IP 范围映射到 IP 空间中的段,而不是将 IP 范围映射到每个单独的 IP。假设我们将 IP 空间分割成 5 个一组的段。分段后的空间大致如下所示:
+---------------+-------------+-----------+
| segment_start | segment_end | bucket_id |
+---------------+-------------+-----------+
| 1 | 5 | 0 |
| 6 | 10 | 1 |
| 11 | 15 | 2 |
| 16 | 20 | 3 |
| 21 | 25 | 4 |
| 26 | 30 | 5 |
+---------------+-------------+-----------+
现在,我们的目标是将 IP 范围映射到它们重叠的段。类似于我们之前做的那样,我们可以展开 IP 范围,但这次我们将按 5 个一组的段进行操作。
SELECT
country,
city,
owner,
explode(sequence(start_ip / 5, end_ip / 5)) AS bucket_id
FROM
geolocations
我们观察到某些 IP 范围共享相同的bucket_id。范围 1–2 和 3–4 都属于段 1–5。
+----------+--------+---------+-----------+-----------+-----------+
| start_ip | end_ip | country | city | owner | bucket_id |
+----------+--------+---------+-----------+-----------+-----------+
| 1 | 2 | ca | Toronto | Telus | 0 |
| 3 | 4 | ca | Quebec | Rogers | 0 |
| 5 | 8 | ca | Vancouver | Bell | 1 |
| 10 | 14 | ca | Montreal | Telus | 2 |
| 19 | 22 | ca | Ottawa | Rogers | 3 |
| 19 | 22 | ca | Ottawa | Rogers | 4 |
| 23 | 29 | ca | Calgary | Videotron | 4 |
| 23 | 29 | ca | Calgary | Videotron | 5 |
+----------+--------+---------+-----------+-----------+-----------+
此外,我们注意到一些 IP 范围是重复的。IP 范围 23–29 的最后两行与段 20–25 和 26–30 重叠。类似于我们展开单个 IP 的情况,我们仍然在重复行,但重复的程度要小得多。
现在,我们可以利用这个分桶表来执行我们的连接。
SELECT
*
FROM
events
JOIN geolocation
ON (
event_ip / 5 = bucket_id
AND event_ip >= start_ip
AND event_ip <= end_ip
)
连接中的相等条件使得 Spark 能够执行排序合并连接(Sort Merge Join,SMJ)策略。“between”条件消除了 IP 范围共享相同bucket_id的情况。
在这个示例中,我们使用了 5 个一组的段;然而,实际上,我们会将 IP 空间分割成 256 个一组的段。这是因为全球 IP 地址空间由互联网号码分配局(IANA)监管,传统上,IANA 按 256 个 IP 的块分配地址空间。
使用 Spark 的approx_percentile函数分析真正的地理位置表中的 IP 范围可以发现,大多数记录的跨度小于 256,而大于 256 的记录非常少。
SELECT
approx_percentile(
end_ip - start_ip,
array(0.800, 0.900, 0.950, 0.990, 0.999, 0.9999),
10000)
FROM
geolocation
这意味着大多数 IP 范围都被分配了一个bucket_id,而少数较大的范围则被展开,导致展开后的表格大约包含额外 10% 的行。
执行的查询与一个真实的地理定位表可能如下所示:
WITH
b_geo AS (
SELECT
explode(
sequence(
CAST(start_ip / 256 AS INT),
CAST(end_ip / 256 AS INT))) AS bucket_id,
*
FROM
geolocation
),
b_events AS (
SELECT
CAST(event_ip / 256 AS INT) AS bucket_id,
*
FROM
events
)
SELECT
*
FROM
b_events
JOIN b_geo
ON (
b_events.bucket_id = b_geo.bucket_id
AND b_events.event_ip >= b_geo.start_ip
AND b_events.event_ip <= b_geo.end_ip
);
结论
总之,本文展示了通过实现一种涉及分割 IP 范围的映射技术,将非等值连接转化为等值连接的实际示例。需要特别注意的是,这种方法不仅限于 IP 地址,还可以应用于任何由带区或范围构成的数据集。
有效地映射和分割数据是数据工程师和分析师工具箱中一个有价值的工具,为 Spark SQL 连接中的非等值条件所带来的挑战提供了实际的解决方案。
使用 LangChain 和 LLM 进行客户分析
发现 LangChain 在统计计算、洞察生成、可视化以及为客户分析进行对话中的潜力与限制——包括实现代码
·发布于Towards Data Science ·阅读时长 11 分钟·2024 年 2 月 13 日
--
许多企业拥有大量存储在数据库中的专有数据。然而,这些数据复杂且难以被用户接触,因此他们常常难以识别趋势并提取可操作的洞察。这就是商业智能(BI)仪表盘发挥重要作用的地方,它是用户与数据汇总视图互动的起点。
BI 仪表盘的瓶颈
一个有效的 BI 仪表盘[/how-to-build-effective-and-useful-dashboards-711759534639]应该只包含对目标受众相关的信息,避免将杂乱的视觉元素堆砌在一起。但这并不能很好地解决一个挑战。有时候,用户可能突然有了额外的查询,或者希望探索仪表盘上没有显示的新分析视角。如果他们没有任何技术背景,无法动态调整可视化的底层逻辑,那么仪表盘可能无法满足他们的需求。

图片由Emily Morter提供,来源于Unsplash
最近的框架LangChain通过其先进的语言处理能力,降低了与数据交互的技术门槛,从而为企业提供了潜在的新机遇。让我们来探讨一下它的基本工作原理。
LangChain 的工作原理
大语言模型(LLMs),如 ChatGPT 和 Llama,具有强大的语言理解和文本生成能力。作为一个开源库,LangChain 将 LLMs 集成到应用程序中。它提供了多个模块,以便高效交互和简化工作流,例如:
-
文档加载器: 促进从各种来源加载数据,包括 CSV 文件、SQL 数据库和公共数据集(如 Wikipedia)。
-
代理: 使用语言模型作为推理引擎,决定采取哪些行动以及行动的顺序。它通过不断循环思考-行动-观察,直到任务完成。
-
链: 与代理不同,链由预定的动作序列组成,这些序列是硬编码的。它通过引导多个工具执行高级指令来解决复杂和明确的任务。
-
记忆: 目前的测试版支持访问过去消息的窗口,这为应用程序提供了对话接口。
以这些模块为基础,我们将开始编写一个简单的应用程序,利用大语言模型(LLM)。在这个实践过程中,我们将扮演商业用户的角色,尝试通过输入自然语言查询来进行探索性数据分析。

LLM 驱动的应用程序工作流程(图片来自作者)
假设你计划为一家零售店进行客户分析,因此你收集了过去 12 个月的销售数据。你的目标是更好地了解客户的不同方面,如人口统计、消费行为和产品类别。
从 Kaggle 获得的数据集,其许可证为CC0: 公共领域,包含多个字段,包括交易 ID、交易日期、客户 ID、性别、年龄、产品类别、购买的产品单位数、单价以及交易的总金额。我们可以开始分析了。
初始设置
我们需要正确设置环境和配置,才能在 Python 中使用 LangChain。
-
配置 Python 环境,并安装 LangChain 库以及其他必要的依赖项,如SQLite和 Pandas
-
配置 OpenAI 密钥以查询 GPT 模型
-
将 CSV 文件‘retail_sales_dataset.csv’导入 SQLite 数据库中的表格
# Import necessary libraries and modules
from langchain.chat_models import ChatOpenAI
import sqlite3
import pandas as pd
# Set the OpenAI API key
OPENAI_API_KEY = "<OpenAI API key>"
# Initialize the Langchain ChatOpenAI model
llm = ChatOpenAI(openai_api_key=OPENAI_API_KEY, model_name="gpt-3.5-turbo-1106")
# Connect to the SQLite database
connection = sqlite3.connect("customer.db")
# Convert DataFrame to a SQLite table named "RetailSalesTable"
df.to_sql("RetailSalesTable", connection, if_exists='replace')
创建一个 LangChain 应用程序
#1 生成基础统计数据
每年每月的交易数量是多少**?
要查询与销售 SQL 表相关的基本统计数据,我们使用create_sql_agent代理助手。两个参数verbose和return_intermediate_steps都设置为 True,以便在执行过程中显示内部状态和步骤。这将帮助我们迭代评估并优化与代理的沟通方式。
# Import necessary libraries and modules
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.sql_database import SQLDatabase
# Create an instance of SQLDatabase using the 'customer.db' SQLite database
db = SQLDatabase.from_uri('sqlite:///customer.db')
# Create an SQL agent executor with specified parameters
agent_executor = create_sql_agent(
llm=llm,
toolkit=SQLDatabaseToolkit(db=db, llm=llm),
agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
handle_parsing_errors=True,
verbose=True,
return_intermediate_steps=True
)
# Define user input
user_inquiry = "What is the number of transactions per year and month?"
# Run the agent to generate a response
agent_executor.run(user_inquiry)
输出的完整内容如下所示。

代理的输出步骤——基础统计(1/2)(图由作者提供)

代理的输出步骤——基础统计(2/2)(图由作者提供)
代理编写了查询,包括纠正“日期”列的格式并生成正确的结果。它成功展示了从 2023 年 1 月到 2024 年 1 月的交易次数分解。
#2 执行特征工程
不同年龄组客户的交易数量是多少?
这一次,我们稍微修改了查询,重点关注“年龄组”维度。这旨在评估代理生成不直接来自数据集的新特征的能力。应用类似的概念,您还可以探索按工作日/周末分类的交易日期、购买频率等维度的统计信息。

代理输出步骤的一部分——特征工程(图由作者提供)
代理尝试编写查询,但不幸的是给出了错误的答案。看起来代理过于简化了“年龄组”这一概念,因此没有将其归类为一个独立的维度,而只是将其视为“年龄”。
改进:使用带背景信息的提示模板
当我们发现模型误解了我们的意图,或者缺乏某些知识时,可以使用类PromptTemplate为语言模型创建参数化的提示。在这种情况下,我在显示用户查询之前补充了有关客户分析中额外特征的示例背景信息。这样做旨在为模型提供更清晰的指南,并传达我们的人工意图,以生成适当的机器生成响应。
from langchain_core.prompts import PromptTemplate
# Create the prompt template
template = PromptTemplate(
input_variables=["user_inquiry", "background_info"],
template="""{background_info}
Question: {user_inquiry}
"""
)
# Define the background information
background_info = """
As the customer analyst, my role is to analyze the transaction patterns of customers. The feature engineering in table 'RetailSalesTable' is crucial for statistical exploration. For example:
- column 'Age' can be grouped into bins of age ranges, such as 21-25, 26-30, and so on.
Understanding the data in these columns helps us gain insights about our customers, enabling us to offer personalized services and develop effective marketing strategies.
"""
# Define user input
user_inquiry = "What is the number of transactions across different age ranges of customers?"
# Run the agent with the formatted template
agent_executor.run(template.format(background_info=background_info, user_inquiry=user_inquiry))
以下是相应输出的关键亮点。

代理的改进输出——特征工程(图由作者提供)
在提示的帮助下,代理现在成功地将交易分类到多个年龄组中。值得一提的是,可能还有其他方式实现相同目标,例如使用少量示例来演示问答对。
#3 为多维特征绘制图表
我们使用create_sql_agent代理计算基本统计数据并生成洞察。通过提示设计的艺术,代理能够执行任务。为了支持提出 SQL 查询无法独立完成的新问题,我们需要开发我们的自定义工具。
显示一个分组条形图,可视化以下问题的答案:产品类别、平均总金额和性别之间的关系是什么?
在这个例子中,查询涉及数据探索和可视化。我们将基于create_sql_agent创建我们的代理,并添加工具PythonREPLTool以执行 Python 命令,如使用Matplotlib进行可视化。
让我们看看这些工具在实践中的实现。
from langchain import LLMChain
from langchain.agents import (AgentExecutor, Tool, ZeroShotAgent)
from langchain_experimental.tools import PythonREPLTool
# Define a description to suggest how to determine the choice of tool
description = (
"Useful when you require to answer analytical questions about customers. "
"Use this more than the Python REPL tool if the question is about customer analytics,"
"like 'How many customers are there?' or 'count the number of transactions by age group'. "
"Try not to use clause in the SQL."
)
# Create a Tool object for customer data with the previously defined agent executor 'create_sql_agent' and description
customer_data_tool = Tool(
name="Customer",
func=agent_executor.run,
description=description,
)
# Create the whole list of tools
tools = [PythonREPLTool()]
tools.append(customer_data_tool)
# Define the prefix and suffix for the prompt
prefix = "Below are tools that you can access:"
suffix = (
"Pass the relevant part of the request directly to the Customer tool.\n\n"
"Request: {input}\n"
"{agent_scratchpad}"
)
# Create the prompt using ZeroShotAgent
# Use agent_scratchpad to store the actions previously used, guiding the subsequent responses.
agent_prompt = ZeroShotAgent.create_prompt(
tools, prefix=prefix, suffix=suffix, input_variables=["input", "agent_scratchpad"]
)
# Create an instance of ZeroShotAgent with the LLMChain and the allowed tool names
zero_shot_agent = ZeroShotAgent(
llm_chain=LLMChain(llm=llm, prompt=agent_prompt),
allowed_tools=[tool.name for tool in tools]
)
# Create an AgentExecutor which enables verbose mode and handling parsing errors
agent_executor = AgentExecutor.from_agent_and_tools(
agent=zero_shot_agent, tools=tools, verbose=True, handle_parsing_errors=True
)
# Define user input
user_inquiry = "Use a grouped bar graph to visualize the result of the following inquiry: " \
"What are the relationships between product category, average total amount, and gender?"
# Run the agent to generate a response
agent_executor.run(user_inquiry)
SQL 代理输出的流程与我们之前讲解的例子相似,因此在此省略。接下来的 Python REPL 工具的输出如下所示。

代理的输出 — 绘图(1/2)(图片由作者提供)

代理的输出 — 绘图(2/2)(图片由作者提供)
自定义工具组合成功地将自然语言查询转化为 SQL 查询。然后,汇总的查询结果用于生成分组条形图,通过 x 轴、y 轴和图例清晰有效地展示关系。
尽管整体设计和执行过程看起来顺利,但当前设计确实存在一些限制。例如,假设我们想要生成一个包含大多数交易数据点的散点图,那么执行过程应该生成一个长查询输出,涵盖所有相关信息。然而,代理的输出可能并不理想,因为代理偶尔会使用LIMIT子句(该子句限制元组的数量),或者查询结果超出最大令牌限制(在我们的案例中为 4096 个令牌)。因此,生成的可视化种类可能会受到限制。
#4 进行连贯对话
实际上,业务用户在收到客户分析结果后通常会有后续问题。为了解决这些情况,我们需要增强现有的基本 LLM 应用,使其更具对话性。我们添加了内存缓冲区以保留过去的互动,使 LLM 能够生成针对当前对话上下文的响应。这通过不断存储 LLM 输出并在生成响应前引用内存存储来实现。
初始问题:顾客在不同季节如何调整他们的购物习惯?
后续问题:您能详细说明一下吗?
我们补充并修正了以下自定义工具组合:
from langchain.memory import ConversationBufferMemory
# Skipped here - Define your own prefix, suffix, and description with "chat_history" for the prompt
# Keep the original list of tools
# Create the prompt using ZeroShotAgent with additonal "chat_history" as input variables
agent_prompt = ZeroShotAgent.create_prompt(
tools, prefix=prefix, suffix=suffix, input_variables=["input", "chat_history", "agent_scratchpad"],
)
# Create an instance of ZeroShotAgent with the LLMChain and the allowed tool names
zero_shot_agent = ZeroShotAgent(
llm_chain=LLMChain(llm=llm, prompt=agent_prompt),
allowed_tools=[tool.name for tool in tools]
)
# Initiate memory which allows for storing and extracting messages
memory = ConversationBufferMemory(memory_key="chat_history")
# Create an AgentExecutor with memory parameter
agent_chain = AgentExecutor.from_agent_and_tools(
agent=zero_shot_agent, tools=tools, verbose=True, handle_parsing_errors=True, memory=memory
)
# Define initial question as user input
user_inquiry = "How do customers adapt their shopping habits during different seasons?"
# Run the agent to generate a response
agent_executor.run(user_inquiry)
# Define follow-up question as user input
user_inquiry = "Can you elaborate more?"
# Run the agent to generate another response
agent_executor.run(user_inquiry)
代理的回答:

代理的输出 — 初始问题(1/2)(图片由作者提供)

代理的输出 —— 后续问题(2/2)(图片由作者提供)
在后续问题“你能详细说明一下吗?”中,我们故意没有提供任何提示/关键词来引导提问,但该代理展示了其在继续分析不同季节的购物习惯方面的能力。这表明使用内存的有效性,并通过提供跨产品类别和季节的更深入描述展示了其优势。
总结
我们进行了实验,探索了 LangChain 在基于 LLM 开发客户分析应用中的关键功能和潜在方法:
-
通过使用
create_sql_agent代理查询数据库并获取相关的统计信息来计算统计数据。 -
洞察生成,通过应用提示模板来定义关键数据特征。
-
可视化,通过使用自定义代理和工具
PythonREPLTool的组合。 -
会话功能,通过添加内存缓冲区来存储和检索聊天历史。
自然语言查询中的措辞选择通常与数据库模式中的措辞不完全一致。观察发现,LangChain 执行器有时不能按预期工作,甚至可能会hypothesize,特别是在识别数据关系以生成图表时。因此,代码开发需要反复调试。虽然 LangChain 框架在处理多样化的客户分析任务时可能仅能提供有限的可靠性和效果,但当用户有迫切需求从传统分析仪表板之外发现洞察时,它仍然能提供一些边际优势。
这个应用程序的设计只是初步阶段,还有更多的可能性等待发掘。例如,客户数据有时以文本格式存在,如客户评论或产品描述。LangChain 提供了tagging function,我们可以通过标注情感、语言、风格等,进行全面分析。
在你离开之前
如果你喜欢这篇文章,邀请你关注我的Medium 页面和LinkedIn 页面。这样,你可以保持对数据科学副项目、机器学习运维(MLOps)演示和项目管理方法论相关精彩内容的更新。
探索通过实现代码可持续减轻快速交付成本的实践
towardsdatascience.com ## 在生产环境中监控机器学习模型:为何以及如何?
我们的模型在不断变化的世界中如何受到影响?这是一篇聚焦于漂移示例的分析,并实现了基于 Python 的…
towardsdatascience.com
追求 p99 的危险
隐藏的相关性可能误导优化策略
·发表于Towards Data Science ·阅读时间:5 分钟·2024 年 6 月 5 日
--

图片来源:Chun Kit Soo 通过Unsplash
p99,即 99%的观察值落在其下的值,在各行业中广泛用于跟踪和优化最差情况性能。例如,页面加载时间、购物订单的完成时间或货物配送时间都可以通过跟踪 p99 来优化。
虽然 p99 无疑具有价值,但我们必须认识到,它忽略了最顶部的 1%观察值,而这些值在与其他关键业务指标相关时,可能会产生出乎意料的大影响。如果盲目追求 p99 而不检查这些相关性,可能会破坏其他业务目标。
在本文中,我们将通过一个包含虚拟数据的例子分析 p99 的局限性,了解何时应依赖 p99,并探索替代的指标。
相关性难题
假设有一个电子商务平台,团队的任务是优化购物车结账体验。团队收到顾客的反馈,称结账速度相比其他平台较慢。因此,团队抓取了最新的 1,000 个结账数据并分析了结账所需的时间。(我为此创建了一些虚拟数据,你可以自由使用并进行修改)
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style="ticks", font_scale = 1)
order_time = pd.read_csv('https://gist.githubusercontent.com/kkraoj/77bd8332e3155ed42a2a031ce63d8903/raw/458a67d3ebe5b649ec030b8cd21a8300d8952b2c/order_time.csv')
fig, ax = plt.subplots(figsize=(4,2))
sns.histplot(data = order_time, x = 'fulfillment_time_seconds', bins = 40, color = 'k', ax = ax)
print(f'p99 for fulfillment_time_seconds: {order_time.fulfillment_time_seconds.quantile(0.99):0.2f} s')

订单结账时间的分布。图像来源:作者。
正如预期,大多数购物车结账似乎在几秒钟内完成。99%的结账都发生在 12.1 秒内。换句话说,p99 是 12.1 秒。有一些长尾案例需要 30 秒才能完成。由于这些案例非常少,它们可能是异常值,可以安全忽略,对吧?
现在,如果我们不暂停并分析最后一句话的含义,可能会非常危险。忽视前 1%真的安全吗?我们确定结账时间与其他业务指标没有相关性吗?
假设我们的电商公司也关注总商品交易额(GMV),并且有一个整体的公司级目标是增加 GMV。在忽略前 1%之前,我们应该立即检查结账时间是否与 GMV 相关。
from matplotlib.ticker import ScalarFormatter
order_value = pd.read_csv('https://gist.githubusercontent.com/kkraoj/df53cac7965e340356d6d8c0ce24cd2d/raw/8f4a30db82611a4a38a90098f924300fd56ec6ca/order_value.csv')
df = pd.merge(order_time, order_value, on='order_id')
fig, ax = plt.subplots(figsize=(4,4))
sns.scatterplot(data=df, x="fulfillment_time_seconds", y="order_value_usd", color = 'k')
plt.yscale('log')
ax.yaxis.set_major_formatter(ScalarFormatter())

订单价值与履行时间的关系。图像由作者提供。
哦,天哪!不仅购物车价值与结账时间相关,而且结账时间越长,价值增加得越快。忽视结账时间前 1%的代价是什么?
pct_revenue_ignored = order_value.loc[order_time.fulfillment_time_seconds>order_time.fulfillment_time_seconds.quantile(0.99), 'order_value_usd'].sum()/order_value.order_value_usd.sum()*100
print(f'If we only focussed on p99, we would ignore {pct_revenue_ignored:0.0f}% of revenue')
## >>> If we only focussed on p99, we would ignore 27% of revenue
如果我们只关注 p99,我们将忽视 27%的收入(这比我们认为忽视的 1%高出 27 倍)。也就是说,结账时间的 p99 相当于收入的 p73。在这种情况下,专注于 p99 无意中对业务造成了损害。它忽视了我们最高价值购物者的需求。
df.sort_values('fulfillment_time_seconds', inplace = True)
dfc = df.cumsum()/df.cumsum().max() # percent cumulative sum
fig, ax = plt.subplots(figsize=(4,4))
ax.plot(dfc.fulfillment_time_seconds.values, color = 'k')
ax2 = ax.twinx()
ax2.plot(dfc.order_value_usd.values, color = 'magenta')
ax.set_ylabel('cumulative fulfillment time')
ax.set_xlabel('orders sorted by fulfillment time')
ax2.set_ylabel('cumulative order value', color = 'magenta')
ax.axvline(0.99*1000, linestyle='--', color = 'k')
ax.annotate('99% of orders', xy = (970,0.05), ha = 'right')
ax.axhline(0.73, linestyle='--', color = 'magenta')
ax.annotate('73% of revenue', xy = (0,0.75), color = 'magenta')

订单履行时间和订单价值的累积分布函数。图像由作者提供。
如上所示,我们可以看到结账时间的百分位数和 GMV 之间存在较大差异。GMV 曲线在订单的 99 百分位附近急剧上升,导致前 1%的订单对 GMV 的影响过大。
这不仅仅是我们虚拟数据的一个现象。这种极端相关性是相当常见的。例如,Slack 的前 1%客户占50%的收入。UPS 的约 12%的收入来自仅 1 个客户(亚马逊)。
平衡的方法
为了避免仅优化 p99 带来的陷阱,我们可以采取一种更全面的方法。
一种解决方案是同时跟踪 p99 和 p100(最大值)。这样,我们就不会忽视高价值用户。
另一种解决方案是使用按收入加权的 p99(或按总商品交易额、利润或任何其他相关业务指标加权),这会赋予收入较高的观察数据更大的权重。该指标确保优化工作优先考虑最有价值的交易或流程,而不是将所有观察数据视为相同。
最后,当绩效与业务指标之间存在高度相关性时,更严格的 p99.5 或 p99.9 可以减少忽视高价值用户的风险。
小结
仅仅依赖于像 p99 这样的指标来进行优化是很有诱惑力的。然而,正如我们所看到的,忽略掉前 1% 的观测值可能会对其他大量业务结果产生负面影响。追踪 p99 和 p100,或使用基于收入加权的 p99,可以提供更全面的视角,并降低仅仅针对 p99 进行优化的风险。至少,让我们记住,避免狭隘地专注于某个性能指标,而忽视整体客户结果。
本文在 Perplexity(用于定义和背景研究,聊天 这里)和 ChatGPT(用于拼写检查,聊天 这里)的帮助下完成。
从零开始的置换特征重要性
理解置换在可解释人工智能领域的重要性
·发表于Towards Data Science ·阅读时间 10 分钟·2024 年 4 月 24 日
--

(来源:作者)
如果你深入了解最先进的 XAI 方法,你会发现它们都涉及置换。SHAP、LIME、PDPs & ICE 图、ALE 和 Friedman 的 H-stat 都依赖于此。这也是为什么理解置换及其局限性对于该领域如此重要的原因。所以,让我们从最简单的 XAI 方法——置换特征重要性(PFI)开始。
为了深入理解这种方法,我们将:
-
使用 Python 从零开始计算 PFI。
-
解释该方法背后的选择,包括为什么要进行置换、重复以及使用哪种度量标准。
-
讨论置换的局限性。
你可以在GitHub上找到完整的项目。
你可能也会喜欢这段关于该主题的视频。如果你想了解更多,可以查看我的课程——Python 中的 XAI。如果你注册我的通讯,你可以免费访问。
为什么选择 PFI?
Phi-3 与高度高效的 iPhone LLMs 开始
本文将深入探讨 Phi-3 论文的发现,以及像 Phi-3 这样的模型发布所带来的一些影响。
·发表于 Towards Data Science ·阅读时间 8 分钟 ·2024 年 5 月 9 日
--

图片来源:作者 — 由 Stable Diffusion 2.1 生成
我之前的文章的读者可能还记得我曾讨论过 “教材就是你所需要的一切”,这是微软的一篇论文,展示了优质数据如何对模型性能产生超乎想象的影响。那里的研究结果直接反驳了“模型必须巨大才能具备能力”这一观点。该论文背后的研究人员继续进行相关工作,并发布了我认为非常令人兴奋的成果。
本文标题解释了或许是最大的发现:“Phi-3 技术报告:在你的手机上本地运行的高度有能力的语言模型”。
让我们深入了解作者从 Phi-2 模型中做出的改变、他们如何训练模型以及它如何在你的 iPhone 上运行。
关键术语
在我们深入探讨架构之前,有几个关键概念需要了解。如果你已经知道这些内容,可以跳过到下一部分。
模型的参数是指模型在训练过程中学习的权重和偏置的数量。如果你有 10 亿个参数,那么你就有 10 亿个权重和偏置来决定模型的表现。参数越多,神经网络的复杂性也就越高。头指的是变换器中自注意力机制所拥有的键、值和查询向量的数量。层指的是变换器神经网络中存在的神经单元的数量,而隐藏维度则是指典型隐藏层内神经元的数量。
分词器是将输入文本转换为嵌入的程序,变换器(transformer)随后将处理这些嵌入。词汇表大小指的是模型训练时使用的唯一标记数量。变换器的块结构是指我们在为特定模型选择层、头、激活函数、分词器和层归一化时所采用的组合方式。

图 2 来自 “GQA: 训练通用多查询变换器模型”
分组查询注意力(GQA)是一种优化多头注意力的方法,旨在减少训练和推理过程中的计算开销。正如下面的图像所示,GQA 采取了中庸之道——我们不再将 1 个值和 1 个键与 1 个查询配对,而是采用 1:1:M 的方式,其中“多”的数量小于所有查询的总数。这样做是为了从多查询注意力(MQA)中仍然获得训练成本的好处,同时最大限度地减少因此而带来的性能下降。
Phi 3 架构
让我们从该模型背后的架构开始。研究人员发布了 3 个不同的仅解码器模型,phi-3-mini、phi-3-small 和 phi-3-medium,每个模型都有不同的超参数。
-
phi-3-mini
-
38 亿个参数
-
32 个头
-
32 层
-
3072 个隐藏维度
-
4k token 默认上下文长度
-
32064 词汇表大小
-
权重存储为 bfloat16
-
训练于 3.3 万亿个标记
-
-
phi-3-small
-
70 亿个参数
-
32 个头
-
32 层
-
4096 个隐藏维度
-
8k token 默认上下文长度
-
100352 词汇表大小
-
权重存储为 bfloat16
-
训练于 4.8 万亿个标记
-
-
phi-3-medium
-
140 亿个参数
-
40 个头
-
40 层
-
3072 个隐藏维度
-
训练于 4.8 万亿个标记
-
讲解一下这些模型之间的差异,phi-3-mini模型使用了典型的多头注意力进行训练。尽管论文中没有明确指出,但我怀疑由于该模型的规模大约是其他两个模型的一半,训练多头注意力的开销是可以接受的。自然地,当他们扩展到 phi-3-small 时,选择了分组查询注意力,每个键连接到 4 个查询。
此外,他们尽量使phi-3-mini的块结构与 LLaMa-2 结构保持一致。这里的目标是让开源社区能够继续在 LLaMa-2 的基础上,使用 Phi-3 进行研究。这是进一步理解该块结构能力的合理方式。
然而,phi-3-small并没有使用 LLaMa 的块结构,而是选择使用tiktoken分词器,交替使用密集注意力层和新的块稀疏注意力。此外,他们在这些模型的训练数据集中加入了 10%的多语言数据。
训练和数据的最佳组合
与 Phi-2 类似,研究人员主要投资于优质数据。他们采用了之前用于生成数据训练模型的相似“教育价值”范式,并选择使用比上次更多的数据。他们将数据生成分为两个阶段。
第一阶段涉及找到他们认为对用户具有高“教育价值”的网页数据。这里的目标是为模型提供一般知识。第二阶段则选取第一阶段数据的一个子集,生成能够教会模型如何进行逻辑推理或获得特定技能的数据。
这里的挑战在于确保来自每个语料库的数据组合适合正在训练的模型的规模(即phi-3-small与phi-3-mini)。这是“数据最优”机制的理念,意味着你提供给 LLM 进行训练的数据能够为其块结构提供最佳能力。换句话说,如果你认为数据是训练一个优秀 LLM 的关键,那么通过数据展示给模型的技能的正确组合,可能和找到优质数据一样重要。研究人员强调,他们希望模型具备比知识更强的推理能力,因此选择了更多来自第二阶段语料库的数据,而不是第一阶段的数据。

图 2 来自论文,突出显示数据最优性的潜在关系
有趣的是,当他们用与训练phi-3-small时大致相同的数据组合来训练phi-3-medium时,他们发现从 7B 参数到 14B 的改进远不如从 3.8B 到 7B 的改进那么显著。作者怀疑这并非块结构的限制,而是他们用来训练phi-3-medium的数据组合的问题。
训练后
团队使用了监督微调(Supervised Fine Tuning, SFT)和直接偏好优化(Direct Preference Optimization, DPO)来提升模型的训练后性能。对于想要深入了解 DPO 的读者,可以参考我在这里的博客文章。监督微调是一种迁移学习方法,通过使用自定义数据集来提高大语言模型(LLM)在该数据集上的能力。作者们通过 SFT 提升了模型在数学、编程、推理和安全等多个领域的能力。随后,他们使用 DPO 优化聊天功能,引导模型远离不希望的回答,朝向理想的回应。
正是在这一阶段,作者们将phi-3-mini的上下文窗口从 4k 个标记扩展到了 128k 个标记。他们使用的方法称为“长绳法”(Long Rope)。作者声称,在这两种上下文类型之间,性能是一致的,这一点非常重要,因为上下文长度大幅增加。如果有足够的兴趣,我将单独写一篇博客文章来探讨这篇论文中的发现。
手机使用的量化
尽管这些模型较小,但要让它们在手机上运行仍然需要进一步的优化。通常,LLM 的权重是以浮点数形式存储的;例如,Phi-3 的原始权重是bfloat16格式,这意味着每个权重在内存中占用 16 位。虽然 16 位看似微不足道,但当考虑到模型中有约 10⁹个参数时,你就会意识到每增加一位位数,所需的存储量就会迅速增加。
为了解决这个问题,作者们将权重从 16 位压缩到 4 位。基本思路是减少存储每个数字所需的位数。例如,数字 2.71828 可以压缩为 2.72。虽然这是一个有损操作,但它仍然能保留大部分信息,同时大大减少存储需求。

图 1 来自论文
作者们在配备 A16 芯片的 iPhone 上运行了量化后的模型,发现该模型每秒能生成最多 12 个标记。作为对比,一台运行 LLaMa-2 量化 4 位的 M1 MacBook 每秒大约生成 107 个标记。我见过的最快标记生成速度(Groq)为每秒 853.35 个标记。考虑到这才是刚刚起步,能够在这款模型上看到 iPhone 生成标记的速度已经非常惊人。推理速度似乎只会越来越快。
将 Phi-3 与搜索结合使用
小型模型的一个局限性是它能够在网络中存储信息的位置较少。因此,我们看到 Phi-3 在需要广泛知识的任务上不如 LLaMa-2 等模型表现得那么好。
作者建议,将 Phi-3 与搜索引擎结合,模型的能力将得到显著提升。如果情况确实如此,这让我认为检索增强生成(RAG)可能会长期存在,成为帮助小型模型达到大型模型性能的关键部分。

图 4 来自论文,展示了搜索如何提高 Phi-3 的性能
结论
总结来说,我们正在看到高性能小型模型的开端。虽然训练这些模型仍然在很大程度上依赖于高性能硬件,但它们的推理过程正变得越来越普及。这引入了一些有趣的现象。
首先,能够在本地运行的模型几乎可以完全保持私密性,允许用户向这些大型语言模型(LLM)提供他们可能不愿意通过互联网发送的数据。这为更多的使用场景打开了大门。
其次,这些模型将推动移动硬件的性能提升。因此,我预计高端智能手机上会有更多的系统级芯片(SoC),特别是具有共享内存的 SoC,以便在 CPU 和 GPU 之间共享内存,最大化推理速度。此外,拥有高质量接口的硬件将变得至关重要。像 MLX 这样的库,专为苹果硅设计,可能会成为任何新硬件进入消费硬件市场的必需品。
其次,正如本文所示,高质量数据在许多方面可以超越 LLM 中的网络复杂性,因此,不仅仅是寻找,而是生成高质量数据的竞争将只会加剧。
现在是构建的激动人心时刻。
[1] Abdin, M., 等人 “Phi-3 技术报告:在手机本地运行的高性能语言模型”(2024 年),arXiv
[2] Ding, Y., 等人 “LongRoPE:将 LLM 上下文窗口扩展到超过 200 万个令牌”(2024 年),arXiv
[3] Gerganov, G., 等人 “llama.cpp 在苹果硅 M 系列上的性能”(2023 年),GitHub
[4] Ainslie, J., 等人 “GQA:从多头检查点训练通用多查询变换器模型”(2023 年),arXiv
哲学与数据科学 — 深入思考数据
第三部分:因果关系
·发表于 Towards Data Science ·阅读时间 12 分钟·2024 年 1 月 4 日
--

图片来源:Cottonbro Studios,来自 Pexels.com
我的希望是,通过本文的阅读,你能够很好地理解哲学思维中关于因果关系的观点是如何应用于你作为数据科学家的工作的。理想情况下,你将拥有更深的哲学视角,为你的工作提供更好的背景和上下文!
这是关于哲学与数据科学的多篇系列文章中的第三部分。第一部分讲述了决定论理论如何与数据科学相连接,第二部分探讨了认识论这一哲学领域如何帮助你作为数据科学家进行批判性思考。
第一部分:决定论
第二部分:认识论
towardsdatascience.com
简介
我喜欢许多哲学话题,它们会从表面看似显而易见的概念出发,例如因果关系,最终让你意识到它远不像你想的那样简单。例如,不查定义,试着思考一下…
带强制函数的物理信息神经网络
直接使用神经网络解微分方程(附代码)
·发表于Towards Data Science ·阅读时长 7 分钟·2024 年 5 月 27 日
--

图片由 agsandrew 提供,来自 iStock
在物理学、数学、经济学、工程学以及许多其他领域,微分方程通过变量的导数来描述一个函数。简单来说,当一个变量相对于其他变量的变化率涉及其中时,通常会遇到微分方程。许多例子描述了这些关系。微分方程的解通常通过解析或数值方法推导得出。
在推导解析解时,可能会遇到繁琐甚至在某些情况下无法完成的任务,而物理信息神经网络(PINN)则直接从微分方程中得到解,绕过了解析过程。这种解微分方程的创新方法是该领域的重要发展。
作者的上一篇文章中,使用 PINN 找到了解一个描述简单电子电路的微分方程的解。本文探讨了在通过强制函数驱动电路时,寻找解的更具挑战性的任务。考虑下列串联电子电路,其中包含电阻R、电容C、电感L,以及一个正弦电压源V sin(ωt)。该电路中电流* i(t)* 的行为由方程 1 描述,这是一个带有强制函数Vω/L cos(ωt)的二阶非齐次微分方程。

图 1: 带正弦电压源的 RLC 电路

方程 1
解析解
方程 1 的解析解要求根据λ和ω₀的关系解决三种不同情况。如下所示,每种情况都会得到一个复杂且独特的i(t)公式。在稍后的结果部分,将把这些解析解与 PINN 求解得到的结果进行比较。PINN 直接从微分方程中产生解,而无需考虑这些特殊情况。
(作者使用拉普拉斯变换技术的详细解析解可在此处查看。)
案例 1:欠阻尼(λ/2 < ω₀)
阻尼指的是电路从初始过渡到平衡状态的速度。欠阻尼响应试图快速过渡,但通常会经历过度和欠调的周期性波动,才最终达到平衡。

方程 2
案例 2:过阻尼(λ/2 > ω₀)
过阻尼响应会从初始过渡缓慢地过渡到平衡状态,而不会经历过度和欠调的周期性波动。

方程 3
案例 3:临界阻尼(λ/2 = ω₀)
临界阻尼响应介于欠阻尼和过阻尼之间,提供了从初始过渡到平衡状态的最快响应。

方程 4
PINN 解
PyTorch 代码可在此处获取。
神经网络通常通过输入和期望输出的配对进行训练。输入数据被传入神经网络,反向传播调整网络的权重和偏置,以最小化目标函数。目标函数表示神经网络输出与期望输出之间的误差。
相比之下,PINN 的目标函数需要三个分量:残差分量(obj ᵣₑₛ)和两个初始条件分量(obj ᵢₙᵢₜ₁ 和 obj ᵢₙᵢₜ₂)。这些分量结合起来构成目标函数:

方程 5

PINN 目标函数
残差
残差分量是物理信息的关键所在。这个分量包含了输出的导数,约束网络必须符合定义的微分方程。残差方程(方程 6)是通过重新排列方程 1 得到的。

方程 6
在训练过程中,t 的值被输入到神经网络中,产生一个残差。然后反向传播将目标函数的残差部分减少到接近 0 的最小值,涵盖所有训练点。残差部分由以下公式给出:

方程 7
方程 6 所要求的第一和第二阶导数,di/dt 和 d²i/dt²,由 PyTorch 和 TensorFlow 神经网络平台中的自动微分功能提供。
初始条件 1
在这个电路示例中,第一个初始条件要求当输入 t = 0 时,PINN 的输出为 i(t) = 0。这是因为在 t = 0 时,正弦源 V sin(t) = 0,导致电路中没有电流流动。初始条件 1 的目标部分由方程 8 给出。在训练过程中,反向传播将减少该部分到接近 0 的值。

方程 8
初始条件 2
第二个初始条件要求当输入 t = 0 时,L di/dt = 0。这个条件源自于基尔霍夫电压定律(即,闭合回路中电压降的总和为零)。具体来说,在 t = 0 时,电路中存在以下条件:
-
电压源 V sin(ωt) = 0
-
电容器 C 的初始电荷为 Q = 0,因此电容器电压为 V_cap = Q/C = 0
-
电阻器 R 两端的电压为 V_res = iR = 0,因为 i(t) = 0(初始条件 1)
-
电感器 L 两端的电压为 V_ind = L di/dt
-
根据上述条件,电路中电压降的总和减少为 L di/dt = 0
初始条件 2 的目标部分由方程 9 给出。反向传播将减少该部分到接近 0 的值。

方程 9
目标图
以下图显示了训练过程中目标值的减少:

目标图
结果
以下测试案例比较了训练后的 PINN 对每种情况的适当解析解的响应。电路元件的值被选择以产生欠阻尼、过阻尼和临界阻尼响应,如上所讨论。所有三种情况都由一个正弦电压源驱动,V = 10 伏特 和 ω = 1.8 弧度/秒。对于每种情况,电容器和电感器的值分别为 C = 0.3 法拉 和 L = 1.51 亨利。每种情况的电阻器 R 的值如下所示。
欠阻尼 (R = 1.2 欧姆)

欠阻尼测试案例
过阻尼 (R = 6.0 欧姆)

过阻尼测试案例
临界阻尼 (R = 4.487 欧姆)

临界阻尼测试案例
结论
本文使用具有自定义目标函数的神经网络成功解决了描述正弦源驱动电子电路的微分方程。通常,微分方程的解是通过繁琐的解析过程或数值方法推导出来的。这里提供的示例展示了神经网络能够以直接和高效的方式准确解决这些方程。如三个测试案例所示,神经网络的响应与解析解完全一致。
附录:PINN 训练笔记
-
PINN 结构:
-
输入层,1 个输入
-
隐藏层,128 个神经元,使用 GELU 激活函数
-
隐藏层,128 个神经元,使用 GELU 激活函数
-
输出层,1 个神经元,使用线性激活函数
-
-
PINN 在 0 到 20 秒的时间域内使用 220 个点进行训练。点的数量由域的持续时间和每秒点数的超参数控制,测试案例中每秒设置为 11 个点。这个值为每个正弦驱动源周期提供了足够的训练点,ω = 1.8。对于更高的ω值,需要更多的点每秒,例如,ω = 4.0需要 25 个点/秒。
-
PINN 在每次从所有训练点中采样 32 个点的批次中进行训练。训练点在每个 epoch 都会随机打乱。
-
学习率在训练开始时设置为 0.01,并且每 2000 个 epoch 减少 0.75 倍。
-
目标图是成功训练的重要指标。随着训练的进行,目标值应该下降几个数量级,并在接近 0 的一个小值处达到最低。如果训练没有产生这个结果,则需要调整超参数。建议首先尝试增加 epoch 数量,然后增加每秒训练点数。
本文的 PDF 版本可通过 此处获取。
除非另有说明,所有图片均由作者提供。
物理启发的神经网络:面向应用的指南
PINN 在现实世界中的成功案例的全面概述
·发布于Towards Data Science ·阅读时长 36 分钟·2024 年 2 月 9 日
--

图片由 DALL-E 生成。
当谈到将机器学习应用于物理系统建模时,越来越多的实践者开始远离纯粹的数据驱动策略,转而接受一种混合思维方式,在这种方式下,丰富的先验物理知识(例如,控制微分方程)与数据一起用于增强模型训练。
在这种背景下,物理启发的神经网络(PINNs)作为一个多功能概念应运而生,并在有效解决现实世界问题方面取得了许多成功案例。
作为一个渴望采用 PINNs 的实践者,我非常希望了解最新的训练算法进展,以及 PINNs 在现实应用中的新颖用例。然而,我常见的一个痛点是,尽管有大量的研究论文/博客总结了有效的 PINN 算法,但关于 PINNs 新颖用例的概述却很难找到。一个明显的原因是,与领域无关的训练算法不同,PINN 用例的报告分散在各个工程领域,并且对于通常专注于某一特定领域的实践者来说,难以直接访问。因此,我经常发现自己…
physipy:使 Python 具备单位意识
第一部分:physipy 将米和焦耳引入 Python
·发表于Towards Data Science ·阅读时长 9 分钟·2024 年 4 月 24 日
--
你是否曾经使用 Python 进行工程/科学计算,却迷失或困惑于变量的单位,例如“这个值是米还是毫米?”或者你意识到在某一时刻你把电流与电阻相加了——这是不可能的?正如每个物理教师曾经说过的那样:你不能把胡萝卜和西红柿加在一起。
好的,physipy正是为了解决这些问题而存在的。

图片由Artturi Jalli提供,来自Unsplash
目录:
· 什么是 physipy? · 一步一步理解 physipy
∘ 使用 physipy 计算体重指数 BMI
∘ 使用 numpy 数组的牛顿运动定律
∘ 使用 NumPy 函数的欧姆定律
∘ 爱因斯坦的质量-能量等价关系,适用于常见粒子,使用 favunit
∘ 自由落体与内置 favunit
∘ 使用 Matplotlib 绘制物体位置和速度
· 总结
所有图片由作者提供。
什么是 physipy?
皮埃尔-西蒙·拉普拉斯、逆概率和中心极限定理

CC BY-SA 4.0/公共领域/公共领域/图片由作者提供/公共领域
关于拉普拉斯对逆概率的精彩解答及其中心极限定理的发现
·发表于Towards Data Science ·27 分钟阅读·2024 年 3 月 5 日
--
在 1600 年代末,雅各布·伯努利思考了一个有趣的问题:如何估计一个样本空间无法完全访问的事件的概率?比如,如何估计在一生中被天上的闪电击中的概率?或者,如果你偏好一个不那么戏剧化的情境,如何估计一个充满未知数量黑白票的抽屉中黑票的真实比例?这个实验情境显然是二项式的:一定数量的独立试验,每次试验只有两种可能结果之一。雅各布·伯努利关于这个二项式实验的沉思使他发现了大数法则(弱法则)。那大约是在 1689 年。
1733 年,在伯努利发现大数法则(WLLN)后的四十年,一位名叫亚伯拉罕·德·莫伊弗的杰出法国人,在英格兰过着捉襟见肘的流亡生活,他弄清楚了如何准确计算伯努利二项式思想实验中表达的概率。德·莫伊弗的方法被称为德·莫伊弗定理(详见这里),成为 18 世纪和 19 世纪最具影响力的发现之一……
在你的公司中推销(AI)创新
启动你当前工作中的 AI 之旅的关键步骤
·发布于Towards Data Science ·6 分钟阅读·2024 年 7 月 16 日
--

图片来自KindelMedia,发布于Pexels
我听过很多次数据科学家因为公司缺乏有趣的项目而感到沮丧。说服业务利益相关者和管理层启动 AI 项目可能是一个挑战。虽然通常数据科学家不负责思考和提出需要优先考虑的项目,但我见证了数据科学家与数据经理和产品经理一起如何影响产品路线图,帮助引入更多创新和有影响力的项目。
在这篇博客文章中,我将分享一些我看到过的成功影响团队或公司文化、推动更多创新的机器学习(ML)或人工智能(AI)项目的步骤和策略。请注意,这不是一蹴而就的事情,而是一个过程,在这个过程中,你的知识和动力可以帮助公司中的其他人跳出思维定势,看到机器学习和人工智能的潜力。
推动公司创新和 AI 的关键步骤和策略包括:提高认知、通过使用案例激发灵感、寻找赞助人和创意,以及优先级排序。
1. 提高对人工智能的认知
第一步是提高组织内对人工智能能够做什么和不能做什么的认知。许多人对人工智能的理解有限,这可能导致既有怀疑也有不切实际的期望。
第一步的最终目标是帮助你周围的人对 AI 产生敏感度。这种敏感度包括:什么是 ML 和 AI 的区别,我可以用传统的 ML 解决哪些问题(分类、回归、时间序列等),GenAI 带来了哪些新机会(文本生成、图像生成、少样本分类等)。一些实现这种意识的策略包括:
-
研讨会和培训:这些可以在公司内部组织,或者你也可以推荐在线课程。第二种选择通常更快速且成本较低;像“AI For Everyone”和“Generative AI For Everyone”这样的课程,来自deeplearning.ai,总是一个不错的起点。
-
赋能并鼓励每个人使用 GenAI:可以通过随意地解释你自己如何利用 GenAI,分享通过它获得的图片和诗歌,或者挑战为什么他们还没有使用它来实现这一点。试着理解是否有具体的担忧阻碍了人们的使用(例如,“我不信任它处理我的数据”),并分享可以帮助缓解这些感知风险的工具或技巧。
-
展示 ML / AI 项目:积极参与公司内的演示、全员大会或内部知识共享会议。你可以分享你或你的团队已经实现的 ML 或 AI 项目。确保提供合适的技术细节,使人们能够跟上你的演讲,并突出项目的潜力、影响和收获也非常重要。还可以有趣地分享这些项目如何与“传统软件开发”或公司其他类型的项目有所不同。
2. 通过相关的应用案例进行启发
你周围的人已经对 AI 和 ML 有一定的了解和敏感度,知道存在的模型类型、它们的潜力以及这些类型的项目如何运作,太好了!下一步是开始引入可以启发你公司项目的应用案例。这些案例可以来自竞争对手或类似行业,也可以来自适用于大多数公司的通用应用案例(用户细分、客户/用户流失预测、销售预测等)。
展示竞争对手或其他公司如何利用 AI,可以有效地展示其潜力并激发下一步的行动。在展示应用案例时,可以着重讲解该案例解决的问题、取得的实际效益,并类比如何在你的公司中应用类似的解决方案。类似地,对于更一般的应用案例,例如用户细分,展示在你公司中可能产生的应用类型(动态定价、个性化、改进沟通等)也会很有趣。
如果已经有一些团队在进行竞争对手分析(通常是用户研究人员),确保他们也在考虑 ML / AI 功能。帮助他们提高敏感度,理解这些解决方案如何在底层运作,从而进一步丰富他们的研究,并发现你公司可能的 AI 机会。
3. 寻找你的赞助人和使用案例

现在大家已经意识到 AI 是什么,以及它能在公司中解决哪些类型的问题和使用案例。如果你做得对,你应该能够让一些人对这一切潜力感到非常兴奋!
这种兴奋感可能会转化为人们直接找上你,分享其他使用案例,提出问题,甚至询问某个问题是否可以通过 AI 来解决。这些人就是你的赞助人:组织内的支持者,他们可以支持并倡导 AI 项目。根据公司规模和文化变革的需求,这些赞助人可能会接近到足以影响最高层的决策。然而,能够激励业务利益相关者也是足够好的,因为他们可以推动通过 AI 解决自己的目标。
你已经为 AI 项目的想法埋下了种子。现在你可以开始为公司特定问题或目标提出具体的 AI 解决方案。感谢之前在意识、使用案例和赞助人方面的工作,这些提案应该会更容易被接受!
然而,最有趣的部分是,等待使用案例也会反过来找上你。你的 AI 赞助人和公司中的其他人现在能够将问题和目标与 AI 解决方案联系起来。你可能会惊讶于从这个方向出现的使用案例数量。你所建立的意识将自然地引导出更多有根据且相关的建议。
4. 评估和优先排序使用案例
到此为止,你可能已经能够收集到多个倡议的想法,并且得到了管理层的支持,分配一些时间来处理这些想法。但你如何决定从哪里开始呢?可能开始时选择潜力最大的倡议是有意义的,但预测创新的投资回报率,尤其是 AI 项目的回报率,可能会因为它们固有的不确定性而变得具有挑战性。然而,有一些关键点可以帮助你做出决定:
-
聚焦于公司内的特定战略痛点或机会。
-
使用行业基准来估算成功率和潜在收入。
-
评估潜在的好处,同时也要考虑可行性和风险。
-
区分探索性项目(高不确定性,长期)和利用性项目(低不确定性,短期)。
尝试从探索性的想法(快速收益)开始,以便更快地证明价值、获得关注并建立信任。一旦这些管理得当,也许你可以开始引入探索性想法(长期目标与月球任务),它们旨在实现更长远、更大的转型,但也涉及更高的失败风险。平衡持续交付与改进以及月球任务的探索,对于长期保持信任,同时探索真正的创新至关重要。
在上一篇文章“从正确的起点开始机器学习产品的倡议”中,我深入探讨了如何从一开始就成功启动机器学习倡议,并管理其固有的不确定性。
学到的前三个教训:问题、规模和数据
towardsdatascience.com
总结
在公司中推广人工智能是一个长期的旅程,而不是一蹴而就的事情。根据我的经验,重要的是从生成意识和教育开始,展示应用案例,并与处于合适位置的支持者对齐。只有这样,提出应用案例时才能获得关注;甚至其他人也可能会带着相关的想法来找你!一旦收集了一些应用案例,并且有足够的带宽和支持来优先考虑某些投入,就该专注于战略性问题,准确量化机会和潜力,并在快速收益与长期目标之间取得平衡。
我们正处在一个每个人都在讨论人工智能的时刻。特别是,各家公司都在思考他们的(生成性)人工智能战略,以及这一新技术如何改变业务和工作方式。这对你有利:现在是开始引入这些步骤的好时机,因为人们特别渴望学习、尝试并利用人工智能。
面向产品的机器学习:数据科学家的指南
如何构建用户喜爱的机器学习产品。
·发表于 Towards Data Science ·阅读时间:23 分钟·2024 年 10 月 14 日
--

图片来源:Pavel Danilyuk: www.pexels.com/photo/a-robot-holding-a-flower-8438979/
数据科学为探索新概念并验证其可行性提供了丰富的机会,所有这些都是为了构建功能和产品背后的“智能”。然而,大多数机器学习(ML)项目都失败了!这不仅仅是因为工作的本质具有实验性。项目可能缺乏目的性,或者没有与实际问题相结合,而将机器学习整合到产品中需要致力于长期的解决问题、投资数据基础设施以及多方技术专家的参与。本文旨在帮助在规划阶段减少这些风险,快速失败,同时培养成为以产品为导向的数据科学家。
本文提供了一种规划机器学习(ML)产品的结构化方法,通过介绍产品设计文档的关键领域来进行讲解。我们将涵盖明确需求、理解数据限制以及定义成功标准等内容,这些内容决定了你构建成功机器学习产品的方式。这些文档应该具有灵活性,可以用来找出最适合你团队的方案。
我很幸运曾在初创公司工作,成为小型而灵活团队的一员,在这里,角色和责任通常是交叉的。我提到这一点是因为下面讨论的话题跨越了传统的边界,涉及项目管理、产品、UI/UX、市场营销等多个领域。我发现,那些能够跨越这些边界并以同理心进行协作的人,能够创造出优秀的产品,成为更好的同事。
为了说明这个过程,我们将通过一个假设的快递公司提出的功能请求来进行讲解:
“作为一家快递公司,我们希望提高在包裹预计延迟时,提前向用户发出警告的能力。”
问题定义
本节旨在简洁地描述问题及项目动机。由于开发通常跨越数月或数年,这不仅能确保每个人从同一页面开始,尤其是在机器学习领域,它有助于你在面对挑战和实验失败时始终保持坚定。以项目启动为起点。鼓励开放合作,并努力揭示所有跨职能团队中的假设,确保从第一天起在产品战略和愿景上达成一致。
实际撰写陈述时,首先要用你自己的话重述问题。对我而言,将其写成长篇并逐渐缩减,使得聚焦于具体细节变得更容易。在我们的示例中,我们从一个功能需求开始。它提供了一些方向,但在具体要求上留有模糊空间。例如,“提高我们的能力”暗示着现有系统——我们是否能访问现有的数据集?“提前警告”虽然信息模糊,但表明如果包裹延迟,客户会收到主动提示。这些都对我们如何构建系统产生影响,并为评估项目的可行性提供了机会。
我们还需要理解项目背后的动机。虽然我们可以假设新功能将提供更好的用户体验,但商业机会在哪里?在定义问题时,始终要将其与更大的商业战略联系起来。例如,改进延迟通知不仅仅是为了构建一个更好的产品——它关系到减少客户流失和提高满意度,从而增强品牌忠诚度并降低支持成本。这才是你衡量项目成功的真正标准。
在团队中共同解决问题是所有工程师应该培养的技能——这不仅是面试过程中常常考察的内容,而且,正如之前所讨论的,它帮助设定项目和战略的期望,确保每个人自上而下都能认同。如果一开始就没有达成一致,可能会对项目造成灾难性的影响,甚至几年后仍然如此。不幸的是,这正是 Babylon 健康聊天机器人的命运。Babylon 的目标是通过使用人工智能提供准确的诊断,来彻底改革医疗保健。然而,令公司吃亏的是,它过于简化了医疗保健的复杂性,尤其是在不同地区和患者群体之间。例如,在英国,发烧症状可能意味着普通感冒,但在东南亚可能意味着更为严重的疾病。缺乏清晰性并对人工智能能力的过度承诺导致了系统实际能做的与现实世界医疗环境中所需的严重不匹配(sifted.eu/articles/the-rise-and-fall-of-babylon)。
需求和约束
在定义了问题及其重要性后,我们可以开始记录交付项目的需求并设定范围。这些需求通常分为两类:
-
功能性需求,即从用户的角度定义系统应做什么。这些需求直接关联到用户期望的功能和交互。
-
非功能性需求,即系统如何运行——性能、安全性、可扩展性和可用性。
如果你曾经使用过敏捷框架,你会熟悉用户故事——从用户的角度讲述功能的简短、简单的描述。我发现,作为团队共同定义这些需求是对齐的好方法,首先从记录功能性需求开始,确保从用户的角度出发。然后,将这些需求映射到用户旅程中,识别出机器学习模型能够提供价值的关键时刻。这种方法有助于在早期确立明确的边界,减少“范围蔓延”的可能性。如果你的项目没有传统的最终用户,或许你是在替代现有的流程?和一线人员交谈——无论是操作员工还是流程工程师,他们是你的领域专家。
从一组简单的故事中,我们可以构建可操作的模型需求:
用户将接收到什么信息?
作为一名等待交付的客户,我希望及时而清晰地收到关于我的包裹是否延迟或准时的通知,这样我可以相应地规划我的一天。
用户将如何收到警告?
作为一名等待交付的客户,我希望通过我偏好的通信渠道(短信或本地应用)接收关于我的包裹是否延迟的通知,这样我可以采取行动,而无需一直检查应用程序。
系统可以使用哪些用户特定的数据?
作为一名关注隐私的客户,我只希望使用诸如我的地址等必要的信息来预测我的包裹是否延迟。
如果做得对,这些需求应该能够约束你关于数据、模型和训练评估的决策。如果你发现存在冲突,可以根据用户影响和可行性进行权衡。让我们分析上面的用户故事,看看我们的机器学习策略会受到哪些约束:
用户将接收到什么信息?
- 如果仅需要延迟通知,模型可以保持简单(例如二分类);更详细的输出需要更复杂的模型和额外的数据。
用户将如何收到警告?
- 实时警告需要低延迟的系统,这就对模型和预处理的复杂度提出了限制。
系统可以使用哪些用户特定的数据?
- 如果我们只能使用有限的用户特定信息,模型的准确性可能会受到影响。另一方面,使用更详细的用户特定数据需要获得用户同意,并增加了数据存储的复杂性,以便遵守数据隐私最佳实践和法规。
考虑用户促使我们在设计时将伦理和隐私嵌入其中,从而打造人们信任的产品。我们的训练数据是否导致包含偏见的输出,歧视某些用户群体?例如,低收入地区可能因基础设施较差而影响交付时间——这一点在数据中是否得到了公平反映?我们需要确保模型不会延续或放大现有的偏见。不幸的是,这类案例层出不穷,例如美国广泛使用的基于机器学习的再犯风险评估工具 COMPAS,它被证明高估了黑人被告的再犯风险,而低估了白人被告的风险(www.propublica.org/article/how-we-analyzed-the-compas-recidivism-algorithm)。
除了伦理,我们还需要考虑其他非功能性需求,比如性能和可解释性:
-
透明性与可解释性:我们将模型呈现为多少“黑箱”?错误预测或程序缺陷的后果是什么?这些问题并不容易回答。展示更多关于模型如何做出决策的信息需要强大的模型以及使用可解释的模型,如决策树。像 SHAP(Shapley 加性解释)和 LIME(局部可解释模型无关解释)这样的技术可以帮助解释不同特征如何影响预测,尽管这有可能让用户感到信息过载。在我们的示例中,告诉用户为什么一个包裹会被延迟是否能建立信任?通常,模型的可解释性能增加内部利益相关者的认同。
-
实时或批量处理:实时预测需要低延迟的基础设施和流式数据管道。批量预测则可以在定期的时间间隔内处理,通常对于不那么紧急的需求已经足够。选择实时预测或批量预测会影响解决方案的复杂性,并且会影响哪些模型适合部署。例如,简单的模型或优化技术可以减少延迟。稍后会详细讨论这个问题。
从营销中借来的一个技巧是创建用户画像。通常,这基于通过正式访谈和调查收集的市场研究,以了解用户的需求、行为和动机。然后根据共同的特征(如人口统计学、目标和挑战)进行细分。由此,我们可以为每个细分群体开发详细的档案,给他们起名字并赋予背景故事。在规划过程中,用户画像帮助我们理解模型预测将如何被接收,并在不同情境下引发的行动。
以莎拉为例,她是一个“忙碌的家长”角色。她优先考虑速度和简洁性。因此,她重视关于包裹延迟的及时简洁的通知。这意味着我们的模型应侧重于快速的二元预测(延迟或准时),而不是详细的输出。最后,由于莎拉更喜欢通过她的手机接收实时通知,模型需要无缝集成到低延迟的系统中,以便提供即时更新。
通过记录功能性和非功能性需求,我们定义了“我们在构建什么”以满足用户需求,并结合“为什么”这与业务目标相符。
建模方法
现在是时候思考我们如何满足需求了。这从用机器学习(ML)术语来描述问题开始,记录输入类型(特征)、输出类型(预测)以及学习它们之间关系的策略。至少需要一些起点,我们知道这将是一个实验性的过程。
对于我们的例子,输入特征可能包括交通数据、天气报告或包裹详情,同时需要一个二元预测:“延迟”或“准时”。显然,我们的问题需要一个二元分类模型。对我们来说,这很简单,但对于其他产品背景,有多种方法可供选择:
监督学习模型:需要一个带标签的数据集进行训练。
-
分类模型:二元分类对于利益相关者来说容易实现和解释,非常适合最小可行产品(MVP)。但这会牺牲多类分类所提供的更细致的见解,比如我们案例中的延迟原因。然而,这通常需要更多的数据,意味着更高的成本和开发时间。
-
回归模型:如果目标是一个连续值,比如包裹延迟的准确时间(例如,“您的包裹将延迟 20 分钟”),回归模型将是合适的选择。这些输出也会受到更多不确定性的影响。
无监督学习模型:处理无标签数据。
-
聚类模型:在包裹延迟的背景下,聚类可以在探索阶段用于根据相似特征(如地区或常见交通问题)对交付进行分组。发现这些模式可以为产品改进提供信息,或指导用户细分,以便个性化功能/通知。
-
降维:对于具有大量特征空间的噪声数据集,可以使用主成分分析(PCA)或自动编码器等降维技术来减少计算成本和过拟合,通过使用较小的模型来牺牲一些特征上下文的损失。
生成模型:通过对标签数据和无标签数据的训练,生成新的数据。
-
生成对抗网络(GANs):对于我们来说,GANs 可以有限地用于模拟一些稀有但影响深远的配送延迟场景,比如极端天气条件或突发交通事件,前提是需要容忍边缘案例。然而,这些网络因训练难度大、计算成本高而著名,并且需要确保生成的数据具有现实性。对于早期产品来说,这通常不太适用。
-
变分自编码器(VAEs):VAEs 的使用场景与 GANs 相似,具有更高的控制能力,能更好地控制生成输出的范围。
-
大语言模型(LLMs):如果我们想将基于文本的数据(如客户反馈或司机笔记)纳入预测,LLMs 可以帮助生成摘要或洞察。然而,实时处理是一个挑战,尤其是当计算负载很重时。
强化学习模型:这些模型通过与环境互动学习,并通过奖励或惩罚来接收反馈。对于一家配送公司,强化学习可以用于根据实际配送结果优化系统。然而,这对于最小可行产品(MVP)来说并不太适合。
在问题的初步框架设计中,随着我们从数据探索和早期模型训练中获得洞察,问题的定义会发生变化是很正常的。因此,首先从一个简单、可解释的模型开始,测试可行性。然后,通过增加更多的特征、调整超参数,逐步增加复杂性,再探索更复杂的模型,如集成方法或深度学习架构。这种方式既能保持较低的成本和开发时间,又能快速推向市场。
机器学习与传统软件开发在估算开发时间方面有显著区别,工作中的很大一部分是由实验组成的。在实验中,结果总是未知的,所需的实验次数也无法预料。这意味着你提供的任何估算都应该包括较大的预留时间,或者要有这样的期望:它会随时变化。如果产品特性不是至关重要的,我们可以通过从简单模型开始并为后续逐步改进做计划,来提供更紧凑的时间估算。
开发模型所需的时间是任何项目中的一项重大成本。根据我的经验,即使是简单模型的快速结果,在后续环节中也会带来巨大好处,允许你将工作交接给前端开发人员和运营团队。为此,我有一些建议。首先,快速失败,优先进行最少工作量、最大成功可能性的实验。然后根据你的学习成果调整计划。虽然这很明显,但人们确实很难接受失败。所以,要支持你的团队,这是过程的一部分。我的第二个建议是,做好调研!查找类似问题的例子以及它们是如何被解决的,或者没有解决的。尽管机器学习最近的火爆趋势让它变得更加流行,但这个领域已经存在很长时间了,而且十有八九已经有人至少在某种程度上解决了与你的问题相关的挑战。保持关注文献,使用像 Papers with Code、Hugging Face 的每日论文或 AlphaSignal 这样的站点,后者提供了一个很好的邮件通讯。对于数据库,可以尝试 Google Scholar、Web of Science 或 ResearchGate。令人沮丧的是,获取主要期刊的费用成为了进行全面文献回顾的一个重大障碍。Sci-Hub…
数据需求
现在我们知道我们的“黑盒”会做什么,那么我们应该往里面放什么呢?是时候考虑数据了,根据我的经验,这是设计中最关键的一部分,关系到降低风险。目标是为获取足够的、相关的、高质量的数据创建一个早期的路线图。这包括训练数据、潜在的内部或外部数据源,以及评估数据的相关性、质量、完整性和覆盖范围。处理隐私问题,并规划数据的收集、存储和预处理,同时考虑应对类不平衡等限制的策略。
如果没有充分考虑项目的数据需求,你将面临预算爆炸并且永远无法完全交付的风险,特斯拉自动驾驶就是一个这样的例子。他们在数据收集方面的挑战突显了低估实际数据需求的风险。从一开始,系统就受到早期采用者车辆所捕获数据的限制,至今仍然缺乏实现真正自动驾驶所需的传感器深度(spectrum.ieee.org/tesla-autopilot-data-deluge)。
如果你正在开发的功能已经是手动过程的一部分,那么数据源的获取会变得容易得多。如果是这样,你可能已经有现有的数据集和性能基准。如果没有,可以从内部寻找。大多数组织都会收集大量的数据,可能是系统日志、CRM 数据或用户分析。然而请记住,垃圾进,垃圾出!如果数据集从一开始就没有为机器学习构建,通常会缺乏训练所需的质量。它们可能不够丰富,或无法完全代表手头的任务。
如果不成功,你需要向外部寻求帮助。从专为机器学习设计的高质量公开数据集开始,如 Kaggle、UCI ML 数据库和 Google Dataset Search。
如果特定问题的数据不可用,可以尝试更一般的公开数据集。浏览像恩隆邮件数据集(用于文本分析和自然语言处理)、政府人口普查数据(用于基于人口的研究)或商业发布的数据集,如 IMDb 电影评论数据集(用于情感分析)等数据泄漏。如果这也失败了,你可以开始从多个来源汇总数据来创建一个丰富的数据集。这可能涉及从电子表格、API,甚至是爬取网页中提取数据。无论哪种情况,挑战在于确保你的数据是干净的、一致的,并且格式适合机器学习的用途。
最坏的情况是,你从零开始,需要自己收集原始数据。特别是处理视频、图像或文本等非结构化数据时,这将非常昂贵且耗时。在某些情况下,数据收集可以通过进行调查、设置传感器或物联网设备,甚至发起众包标签挑战来实现自动化。
无论如何,手动标注几乎总是必要的。这里有许多高度推荐的现成解决方案,包括 LabelBox、Amazon SageMaker Ground Truth 和 Label Studio。它们都能加速标注过程,并帮助维持质量,即使在随机抽样的大数据集上也是如此。
如果还不清楚,随着你从内部数据源转向手动收集数据,构建适合机器学习的数据库的成本和复杂性会显著增加,项目的风险也会随之增加。虽然这不是项目的致命问题,但考虑你的时间表和预算限制是非常重要的。如果你只能收集一个小数据集,你可能只能采用较小的模型解决方案,或者对像 Hugging Face 和 Ollama 这样的基础模型进行微调。此外,确保你有一笔额外的预算来应对项目后期获取更多数据的需要。这一点很重要,因为理解项目所需的数据量只能通过解决机器学习问题来获得答案。因此,通过确保你有途径获取更多数据来提前缓解风险。常见的做法是引用“餐巾纸背面的估算”作为数据需求的合理估计。但这实际上只适用于一些非常明确的问题,如图像分类和传统的机器学习问题。
如果明显无法收集足够的数据,生成模型在产生合成训练数据方面取得了一些有限的成功。例如,美国运通公司开发的欺诈检测系统就采用了这种技术,通过模拟卡号和交易来检测与实际欺诈的差异或相似之处 (masterofcode.com/blog/generative-ai-for-fraud-detection)。
一旦建立了基础数据集,你需要了解其质量。我发现手动操作问题非常有效。它能提供关于有用特征和未来挑战的洞察,同时为模型性能设定现实的期望。在这个过程中,你还能早期发现数据质量问题和覆盖范围的漏洞。动手处理数据,积累领域知识,并注意以下几点:
-
数据的相关性:确保现有数据能反映你解决问题的努力。以我们的例子为例,交通报告和配送距离是有用的,但客户购买历史可能并不相关。识别数据的相关性有助于减少噪音,同时让较小的数据集和模型更有效。
-
数据质量:注意任何你发现的偏差、缺失数据或异常情况,这对后续构建数据预处理管道非常有用。
-
数据的完整性和覆盖面:检查数据是否充分覆盖所有相关的场景。以我们的例子为例,数据可能需要涵盖城市中心和更为偏远的地区,忽略这一点会影响模型的泛化能力。
-
类别不平衡:了解类别或目标变量的分布,以便在可能的情况下收集更多数据。希望在我们的案例中,“延迟”包裹将是一个稀有事件。在训练过程中,我们可以实施成本敏感学习来应对这一点。就个人而言,我总是通过像 SMOTE(合成少数类过采样技术)或自适应合成(ADASYN)采样等技术,在过采样少数类时获得更多成功。
-
数据的时效性:考虑数据需要多么实时才能做出准确的预测。例如,可能需要实时交通数据才能做出最准确的预测。
当涉及到更全面的质量分析时,探索性数据分析(EDA)是发现模式、识别异常以及更好理解数据分布的有效方法。我将在另外一篇文章中详细介绍 EDA,但通过可视化数据趋势、使用相关性矩阵以及理解异常值,可以揭示潜在的特征重要性或挑战。
最后,考虑不仅仅是解决眼前的问题——要考虑数据的长期价值。它是否可以用于未来的项目或扩展到其他模型?例如,交通和配送数据最终可以帮助优化整个物流链中的配送路线,从而提高效率并在长远来看降低成本。
成功指标——找到足够好
在训练模型时,快速的性能提升往往会伴随着收益递减的阶段。这可能导致没有方向的试验和错误,同时打击士气。解决方案是什么?从一开始就定义“足够好”的训练指标,以确保达到最小的阈值,从而实现项目的业务目标。
为这些指标设定可接受的阈值需要对产品有广泛的了解,并具备沟通技术与业务观点之间差距的软技能。在敏捷方法中,我们将这些称为验收标准。这样做可以让我们快速推出最低规格的版本,然后进行迭代。
什么是业务指标? 业务指标是衡量任何项目成功的真正标准。这些可以是降低客户支持成本或增加用户参与度,并在产品上线后进行衡量,因此也称为线上指标。在我们的例子中,如果准确率为 80%,但能减少 15%的客户服务成本,那可能是可以接受的。实际上,您应该使用单一模型和单一业务指标进行跟踪,这样可以保持项目的专注,并避免在何时成功交付的问题上产生歧义。您还需要确定如何跟踪这些指标,查找业务团队应该能够访问的内部仪表盘和分析工具,如果没有,可能就不是业务的驱动力。
平衡业务和技术指标:找到一个“足够好”的性能,首先要理解现实世界中事件的分布,然后将其与用户(因此也影响业务)的反应联系起来。以我们的快递员示例为例,我们期望延迟的包裹是一个罕见事件,因此对于我们的二分类器来说,存在类别不平衡。这使得仅使用准确度不合适,我们需要考虑用户对预测的反应:
-
假阳性(预测存在延迟但实际上没有)可能会给客户带来烦人的通知,但当包裹随后按时到达时, inconvenience 很小。避免假阳性意味着优先考虑高精度。
-
假阴性(未能预测到延迟)很可能会导致更高的客户挫败感,因为客户没有收到包裹且没有提前警告,降低了重复业务的机会,并增加了客户支持成本。避免假阴性意味着优先考虑高召回率。
对于我们的示例,业务可能更看重高召回率的模型。然而,对于准确度不足 100%的模型,仍然需要在精度和召回率之间取得平衡(我们不能通知每个客户包裹延迟)。这种权衡最适合通过 ROC 曲线来说明。对于所有分类问题,我们通过 F1 得分来衡量精度和召回率的平衡,对于不平衡的类别,我们可以扩展为加权 F1 得分。
平衡精度和召回率是一门精细的艺术,可能会对用户产生意想不到的后果。为了说明这一点,考虑像 Google 日历这样的服务,它提供公司和个人用户账户。为了减少经常收到假会议请求的企业的负担,工程师可能会优先考虑高精度的垃圾邮件过滤。这确保了大多数假会议会被正确标记为垃圾邮件,但也会以较低的召回率为代价,导致一些合法会议被错误标记为垃圾邮件。然而,对于个人账户来说,收到假会议请求的情况要少得多。随着账户使用时间的增长,由于模型召回率较低,合法会议被错误标记的风险变得显著。在这种情况下,用户对服务的负面看法会变得非常重要。
如果我们将我们的快递员示例视为回归任务,目标是预测延迟时间,那么像 MAE 和 MSE 这样的指标是合适的选择,它们对你的产品有略微不同的含义:
-
平均绝对误差(MAE):这是一个直观的指标,用于衡量平均预测值与实际值的接近程度。因此,它是一个简单的指标,用于评估发送给用户的延迟估计的准确性。
-
均方误差(MSE):由于差异被平方,这会更加惩罚较大的错误,因此如果延迟预测中的重大错误对用户满意度的影响更大,MSE 就显得很重要。然而,这也意味着该指标对离群值更为敏感。
如上所述,这是将模型指标转化为每个人都能理解的术语,并传达权衡的过程。这是一个协作过程,因为与用户和产品更接近的团队成员会更好地理解需要推动的业务指标。找到那个能够指引项目朝着同一方向前进的单一模型指标。
最后一点,我发现涉及机器学习的项目往往会过度承诺可以交付的内容。通常这种现象来自组织的高层,在那里产品或投资者之间会产生炒作。这对项目和你的理智都是不利的。应对这种情况的最佳方法是通过在设计中沟通现实的期望值,这些期望值与问题的复杂性相匹配。永远记住,承诺少一些,交付多一些总是更好。
高层系统设计
到目前为止,我们已经涵盖了数据、模型和指标,并讨论了如何处理我们的功能需求。现在,是时候关注非功能需求,特别是可扩展性、性能、安全性和部署策略了。对于机器学习系统,这涉及到使用系统上下文或数据流图来记录系统架构。这些图表将关键组件表示为模块,定义了输入、转换和输出。展示系统各部分如何交互,包括数据摄取、处理管道、模型服务和用户界面。通过这种方式,确保了系统的模块化,使得团队可以在不影响整个管道的情况下,隔离并解决问题,随着数据量或用户需求的增长,从而最小化瓶颈或成本上升相关的风险。
一旦我们的模型训练完成,我们需要有一个计划,将模型部署到生产环境,使其能够被用户或下游系统访问。常见的方法是通过 REST API 暴露模型,其他服务或前端可以进行请求。对于实时应用,像 AWS Lambda 或 Google Cloud Functions 这样的无服务器平台非常适合低延迟(只需管理冷启动)。如果吞吐量是一个要求,那么可以使用批处理处理和可扩展的数据管道,如 AWS Batch 或 Apache Spark。我们可以将机器学习系统设计的考虑因素分解为以下几项:
基础设施和可扩展性:
首先,我们需要选择系统的基础设施。具体来说,我们将把系统部署在哪里:本地、云端,还是作为一种混合解决方案。云平台,如 AWS 或 Google Cloud,提供了基于需求的自动扩展,既可以垂直扩展(更大的机器),也可以水平扩展(增加更多的机器)。考虑系统如何应对 10 倍或 100 倍的数据量。Netflix 通过他们的技术博客提供了关于如何在大规模下运作的宝贵见解。例如,他们开源了他们的容器编排平台 Titus,Titus 自动化了在 AWS EC2 实例上通过自动扩展组部署成千上万的容器(netflixtechblog.com/auto-scaling-production-services-on-titus-1f3cd49f5cd7)。有时候,如果处理敏感数据,可能需要本地基础设施。这能提供更好的安全控制,但在维护和扩展时成本较高。无论如何,准备好使用基础设施即代码工具(如 Terraform 和 AWS CloudFormation)对基础设施进行版本控制,并实现自动化部署。
性能(吞吐量和延迟):
对于实时预测,性能至关重要。有两个关键指标需要考虑:吞吐量,衡量系统每秒能够处理的请求数量(即每秒请求数);延迟,衡量返回预测所需的时间。如果你预期使用相同的输入进行多次预测,则可以考虑为部分或整个管道添加缓存,以减少延迟。通常,水平扩展更为优先,以便在高峰时期响应流量激增并减少单点瓶颈。这强调了在系统设计过程中做出的关键决策将直接影响性能。例如,Uber 围绕 Cassandra 数据库构建了他们的核心服务,专门优化低延迟实时数据复制,确保快速访问相关数据。(www.uber.com/en-GB/blog/how-uber-optimized-cassandra-operations-at-scale/)。
安全性:
对于机器学习系统,安全性适用于用户请求的 API 认证。这通常是标准的做法,采用像 OAuth2 这样的认证方法,并通过速率限制、阻止 IP 地址列表和遵循 OWASP 标准来保护端点。此外,确保任何存储的用户数据在静态和传输过程中都经过加密,并且对于内部和外部用户都实施严格的访问控制策略。
监控与警报:
同样,考虑监控以维护系统健康至关重要。跟踪关键性能指标(KPI),如吞吐量、延迟和错误率,并设置警报,以便在这些指标低于可接受的阈值时通知工程师。这可以在服务器端(例如你的模型端点)或客户端(例如用户端)进行,以包括网络延迟。
成本考虑:
为了简化基础设施管理,基于云的系统成本可能迅速增加。首先估算处理数据、模型训练和服务所需的实例数量,并将这些与项目预算和不断增长的用户需求进行平衡。大多数云平台提供成本管理工具,帮助你跟踪支出并优化资源。
MLOps:
从一开始就要包含一个有效管理模型生命周期的计划。目标是加速模型迭代,自动化部署,并保持对指标和数据漂移的强大监控。这使你能够从简单做起,并迅速进行迭代!使用 Git 进行代码的版本控制,并使用 DVC(数据版本控制)来跟踪数据模型的变更。像 MLFlow 或 Weights & Biases 这样的工具跟踪实验,而 CI/CD 管道则自动化测试和部署。一旦部署,模型需要实时监控,使用像 Prometheus 和 Grafana 这样的工具来检测数据漂移等问题。
高级系统设计可以降低风险,确保你的团队能够适应并随着系统的成长而演化。这意味着设计一个与模型无关且能够扩展的系统,通过将系统分解为模块化组件,构建一个强大的架构,以支持快速试验与错误、可扩展的部署和有效的监控。
通过模拟机器学习进行原型设计
现在我们有了一种交付项目需求的方法,至少从机器学习的角度来看。为了完善我们的设计,我们现在可以概述一个产品原型,重点关注用户界面和体验(UI/UX)。在可能的情况下,原型应该是互动式的,验证该功能是否能为用户提供真正的价值,准备好在用户体验上进行迭代。由于我们知道机器学习是耗时且资源密集型的,你可以将模型设计和原型放在一边,而无需一个完整的机器学习组件。记录你将如何模拟这些输出,并测试端到端系统,详细描述在设计文档中原型制作所用的工具和方法。这一点非常重要,因为原型可能是你第一次收集反馈并完善设计的机会,可能会发展成 V1 版本。
为了模拟我们的机器学习,我们用一个简单的占位符替代预测并模拟输出。这可以简单地是生成随机预测或构建一个基于规则的系统。原型设计 UI/UX 涉及使用像 Figma 这样的设计工具创建原型,或者使用 Postman 和 Swagger 进行 API 原型设计。
一旦你的原型准备好,就让人们来体验它,无论你多么害羞。大公司通常有这方面的资源,但小团队也可以自己创建用户小组。我在当地大学取得了很好的成功——学生们喜欢参与新事物,亚马逊购物券也很有帮助!收集反馈,进行迭代,并开始基本的 A/B 测试。当你接近发布产品时,可以考虑更高级的方法,如多臂老丨虎丨机测试。
苹果有一篇很好的文章,作为用这种方式模拟机器学习的例子。在类似 Siri 的对话式数字助手的用户测试中,他们使用人类操作员来模拟一个原型助手,在对话风格上进行变化——如健谈、不健谈或模仿用户的风格。通过这种方式,他们展示了用户更喜欢那些模仿自己健谈程度的助手,从而提高了可信度和可爱度。所有这一切都无需投入大量的机器学习开发来测试用户体验(arxiv.org/abs/1904.01664)。
从中我们可以看出,模拟 ML 组件将重点放在结果上,使我们能够更改输出格式,测试正向和负向流程,并找到边缘情况。我们还可以衡量感知性能的限制,以及我们如何管理用户的挫败感,这对我们能够构建的模型的复杂性和基础设施成本有重要影响。所有这一切都不需要考虑模型的准确性。最后,内部分享原型有助于获得业务领导的支持,没有什么比把项目交到人们手中更能激发支持和承诺了。
收集反馈并进行迭代
当你进入开发和部署阶段时,你不可避免地会发现需求发生变化,实验会带来一些意想不到的结果。你需要进行迭代!通过版本控制记录变化,通过重新审视问题定义、重新评估数据质量和重新评估用户需求来整合反馈循环。这一过程从持续监控开始,随着产品的成熟,应用统计测试来检测预测分布的变化(数据漂移),以识别性能退化。实施在线学习来应对这一变化,或者如果可能的话,将用户反馈方法集成到 UI 中,以帮助揭示真实的偏见并建立信任,所谓的“人机协同”。首先积极寻求内部反馈,然后通过访谈和小组了解用户的反馈,了解他们如何与产品互动以及如何产生新的问题。使用 A/B 测试来比较你选择的模型版本,了解它对用户行为和相关产品/业务指标的影响。
ML 项目通过在整个模型生命周期中采用敏捷方法论能够带来好处,帮助我们管理 ML 中固有的不确定性和变化,这一切从规划过程开始。小步开始,快速测试,不要害怕快速失败。将其应用于规划和发现阶段,可以降低风险,同时交付一个不仅有效而且与用户产生共鸣的产品。
人工智能深度网络模型是否正在趋同?
人工智能模型是否正在朝着统一的现实表征演化?柏拉图式表征假设认为,机器学习模型正在趋同。
·发表于Towards Data Science ·8 分钟阅读·2024 年 5 月 23 日
--
一篇近期的 MIT 论文引起了我的注意,因为其提出了一个令人印象深刻的观点:人工智能模型正在趋同,即使是在不同的模态——视觉和语言之间。“我们认为,人工智能模型中的表征,特别是深度网络的表征,正在趋同”,这就是柏拉图式表征假设论文的开头。
但是,不同的模型,经过不同数据集的训练并用于不同的应用场景,如何能够趋同?是什么导致了这种趋同?
✨这是付费文章。如果你不是 Medium 会员,你可以在我的新闻通讯中免费阅读此文: Qiubyte.

柏拉图的洞穴寓言,由Jan Saenredam(公有领域)创作。
1. 柏拉图式表征假设
我们认为,不同神经网络模型中数据点的表示方式正日益相似。这种相似性跨越了不同的模型架构、训练目标,甚至数据形式。

柏拉图式表征假设。视觉表征X和文本表征Y都是共同现实Z的投影。(来源:论文)
介绍
本文的核心论点是,来源和形式各异的模型正在趋向于一种现实的表征——即描述我们观察到并用于训练模型的世界事件的联合分布。
作者认为,这种趋向柏拉图式表示的收敛性是由模型所训练的底层数据结构和数据本身的性质驱动的,以及模型本身日益增长的复杂性和能力。随着模型接触到更多样的数据集和更广泛的应用,它们需要一种能够捕捉所有数据类型中常见的基本属性的表示。

《洞穴寓言》的插图,摘自柏拉图的《理想国》(艺术作品来自4edges,来源:Wikipedia)
2. AI 模型会收敛吗?
各种规模的 AI 模型,即使是基于不同架构构建并为不同任务训练的模型,也开始表现出在数据表示上的收敛迹象。随着这些模型的规模和复杂度不断增长,输入数据变得更加庞大和多样,它们处理数据的方式开始趋于一致。
在不同数据模态——视觉或文本上训练的模型也会收敛吗?答案可能是是的!
2.1 能说话的视觉模型
这种对齐跨越了视觉和文本数据——论文随后确认,这一理论的局限性在于它只关注这两种模态,而没有涉及音频或机器人对世界的感知等其他模态。支持这一点的一个案例[1]是LLaVA,该案例展示了通过 2 层 MLP 将视觉特征投影到语言特征中,从而实现了最先进的结果。

LLaVA 如何将视觉特征映射到语言模型的概述。(来源:LLaVA,CC-BY)
2.2 能看见的语言模型
另一个有趣的例子是大型语言模型的视力检查[2],它探讨了大型语言模型在理解和处理视觉数据方面的程度。该研究使用代码作为图像和文本之间的桥梁,作为将视觉数据输入 LLM 的创新方法。论文揭示了 LLM 可以通过代码生成图像,这些图像虽然可能看起来不真实,但仍包含足够的视觉信息来训练视觉模型。

语言模型能看见吗?(source)
2.3 更大的模型,更强的对齐
不同模型的对齐与其规模相关。例如,训练用于CIFAR-10 分类的较大模型,表现出比小模型更强的对齐性。这意味着随着当前构建模型的趋势向 10 亿和 100 亿级别发展,这些巨型模型将会更加一致。
“所有强大的模型都是相似的,每个弱模型都是以自己独特的方式弱。”
3. 为什么 AI 模型会收敛?

AI 模型的学习过程,f ∗ 是训练后的模型,𝐹 F 是函数类,𝐿 L 是依赖于模型 𝑓 f 和来自数据集的输入 𝑥 x 的损失函数,𝑅 R 表示正则化函数,𝐸 E 表示数据集的期望值。每种颜色代表收敛的一个原因。 (来源:论文)
在训练一个 AI 模型时,有一些因素对 AI 模型为何会收敛贡献最大:
3.1 任务变得更加通用
随着模型被训练以同时解决越来越多的任务,其解决方案空间变得越来越小且更加受限。更高的通用性意味着尝试学习更接近现实的数据点。

一个模型能够解决的任务越多,它就被迫学习一个在解决所有这些任务时都有效的非重叠表示。 (来源:论文)
柏拉图表示假说 论文将其表述为 多任务扩展假说:
“能够胜任 N 个任务的表示比能够胜任 M < N 个任务的表示要少。随着我们训练更多通用的模型以同时解决更多任务,我们应该预期可能的解决方案会更少。”
换句话说,解决复杂问题的方案比解决简单问题的方案要窄得多。当我们训练越来越通用的模型,且这些模型在庞大的、跨不同模态的互联网数据集上进行训练时,你可以想象解决方案空间会是多么的受限。
3.2 模型变得越来越大
随着模型的能力增强,通过更复杂的架构、更大的数据集或更复杂的训练算法,这些模型开发出的表示方式变得更加相似。

更大的假设空间比小的假设空间更容易收敛到一个解。 (来源:论文)
尽管 柏拉图表示假说 论文并未为他们所称之为 能力假说 提供证明或示例——即“更大的模型比小的模型更容易收敛到共享表示”,但似乎显而易见的是,至少更大的模型有更多的能力去得出共同的解空间,远超过小模型。
随着 AI 模型的规模扩大,得益于它们的深度和复杂性,它们具备了更强的抽象能力。这使得它们能够捕捉数据的基本概念和模式,同时抛弃噪声或异常值,从而得出一个更加通用且可能更接近现实世界的表示。
3.3 简单性偏向
想象一下,在两个不同任务上训练两个大规模神经网络:一个模型必须能够识别图像中的面孔,另一个模型被训练来解读面孔的情绪。最初,这两个任务似乎没有什么关系——但是你会惊讶地发现两个模型最终会在面部特征表示上趋于相似吗?毕竟,一切归结于准确识别和解读面部关键点(眼睛、鼻子、嘴巴等)。

深度神经网络倾向于更简单的函数。(来源:论文)
有多篇文献指出深度神经网络有一种倾向,倾向于找到更简单、更通用的解决方案[3,4,5]。换句话说,深度网络偏爱简单的解决方案。通常被称为简约偏差,论文将其表述为:
深度网络偏向于找到对数据的简单拟合,并且模型越大,偏差越强。因此,随着模型的增大,我们应预期其收敛到更小的解决空间。
为什么神经网络会表现出这种行为?网络表现出简约偏差主要是因为用于训练它们的学习算法的基本属性。算法倾向于偏好更简单、可泛化的模型,这是为了防止过拟合并增强泛化能力。在训练过程中,简单的模型更有可能出现,因为通过捕捉数据中的主导模式,它们可以更有效地最小化损失函数。
简约偏差在训练过程中充当了一种自然的调节器。它推动模型朝向一种最佳的数据表示和处理方式,这种方式不仅能跨任务通用,而且足够简单,便于高效学习和应用,从而增加了模型学习到共同假设空间的机会。
4. 这种收敛性的影响
那么,如果模型正在收敛,又会怎么样呢?首先,这表明不同模态的数据比以前认为的更有用。从预训练的 LLM 微调视觉模型,或反之,可能会得到出乎意料的好结果。
论文中指出的另一个影响是“规模化可能减少幻觉和偏见”。这一论点是,随着模型的规模扩大,它们可以从更大、更具多样性的数据库中学习,从而帮助它们形成更准确、更健壮的世界理解。这种增强的理解使得模型能够做出更加可靠且更少偏见的预测和输出。

VISION 模型随着能力的增加而收敛。(来源:论文)
5. 一点怀疑
在考虑论文中提出的论点时,必须考虑一些局限性,几乎所有这些局限性都在论文中有所讨论。
首先,论文假设现实世界的双射投影,其中一个现实世界概念 Z 有可以学习的投影 X 和 Y。然而,某些概念是独特地固有于某一模态的。有时,语言能够表达一种概念或情感,而许多图像无法做到,反之,语言也可能无法替代图像来描述视觉概念。
其次,正如前面提到的,论文关注两种模态:视觉和语言。第三,关于“AI 模型正在趋同”的论点仅适用于多任务 AI 模型,而不适用于特定模型,如 ADAS 或情感分析模型。
最后,尽管论文表明不同模型的对齐度有所增加,但并未表明这些模型的表示变得相似。大模型之间的对齐分数确实高于小模型,但即使如此,0.16/1.00 的分数仍然留给研究一些悬而未解的问题。
🌟 加入 1000+人一起学习 Python🐍,机器学习/机器学习操作/人工智能🤖,数据科学📈,以及大语言模型 🗯
关注我,并查看我的X/Twitter,我每天都会为你提供更新。
[## QiuByte | Hesam Sheikh | Substack
人工智能、编程和机器学习,仅在简易的方式下。点击阅读《QiuByte》,由 Hesam Sheikh 主办,Substack…
感谢阅读,
— Hesam
[1] 刘浩,李晨,吴奇,李洋杰。《视觉指令调优》。发表于 NeurIPS,2023。
[2] Sharma, P., Rott Shaham, T., Baradad, M., Fu, S., Rodriguez-Munoz, A., Duggal, S., Isola, P., and Torralba, A. 《语言模型的视觉检查》。发表于 arXiv 预印本,2024。
[3] H. Shah,K. Tamuly,《神经网络中简单性偏差的陷阱》,2020 年,arxiv.org/abs/2006.07710
[4] 关于简单性偏差的简短说明
[5] 深度神经网络在初始化时偏向简单函数
与 LLM 一起玩 20 个问题游戏
学习 LLM 架构技术、解析输出、测试设计以及系统的性能测量
·发布于Towards Data Science ·阅读时间:7 分钟·2024 年 12 月 6 日
--

20 个问题游戏是一个经典的猜谜游戏,适合两名玩家。一名玩家想出一个物体、人物或地方,另一名玩家轮流问是非问题,试图猜出它是什么。目标是在 20 个问题内猜对。如果到第 20 个问题时没人猜出来,思考者将揭示答案,并且回合结束。真实的游戏可以在这里找到,我鼓励你试着猜一些简单的东西,比如汽车或苹果。
本文的目标
构建最佳的 LLM 代理,使其成为这款游戏中的猜测者。
要遵循的步骤
代码基础
从开始这个项目的第一刻起,有几件事我就非常清楚。我希望猜测者代理能够接收到一份非常清晰的、所有之前问题和答案的列表,然后被提示想出下一个要问的问题。
请让这个 AI 的准确性更低一点
揭开“准确性”在数据科学和人工智能中的面纱
·发布于Towards Data Science ·7 分钟阅读·2024 年 4 月 16 日
--
准确性是一个大家直觉上都认为自己理解的词,许多人也认为它越高越好。
随着人工智能(AI)受到越来越多关注以及人们对输出结果的可靠性或准确性问题的认知提升,了解数据产品(如 AI)并不遵循其他技术的一致性或准确性规则变得尤为重要。
混淆矩阵
为了说明这一点,我将介绍“混淆矩阵”这一概念。对于那些为分类目的构建预测模型的数据科学家来说,这个概念一定非常熟悉。虽然其他人可能不太了解,但我发现这个概念、方法论以及其中涉及的人类和商业互动是理解机器学习中准确性术语的一个有用案例研究。它是一个帮助理解这些术语中的细微差别和权衡的有用可视化工具。

作者提供的混淆矩阵模板
当我们谈论总准确性时,指的是所有预测结果中正确预测的数量(即上图中绿色框的总和),占所有预测结果的比例(即上图四个框的总和)。因此,在这里你可能会听到诸如“我们的孕妇测试准确率是 99%...”之类的术语。
在 R 中使用 Google Earth 绘制高尔夫球场
一份关于如何在 Google Earth 中绘制高尔夫球场并将其在 R 中呈现的用户指南。
·发表于 Towards Data Science ·阅读时间 7 分钟·2024 年 5 月 6 日
--
在数据可视化的世界里,我们常常被条形图、折线图和饼图淹没。但事实并不一定非得如此——在本文中,我将展示如何绘制高尔夫球场。

图片由作者提供
介绍
全球大约有 40,000 个高尔夫球场。如果你像我一样热衷于高尔夫和数据,那么这篇文章就是为你准备的。在绘制一个球场后,我们还可以做一些很酷的事情:
-
将击球数据叠加到球场地图上—— 这可以通过使用 Plotly 包进行一些操作,或者通过从 Google Earth 下载单独的高尔夫球击打“位置标记”并将其作为点绘制到地图上。对于你第一次打破 90 杆、80 杆、70 杆等的时刻,这可能是一个不错的纪念方式。
-
计算球场指标——一般来说,大家都知道 Pebble Beach 的果岭很小,被称为“小邮票果岭”。或者 Whistling Straits 有很多沙坑。通过描绘球场元素的多边形,另一个附带的好处是可以计算每个元素的面积。这使我们能够为我们绘制的任何球场推导出平均果岭大小、沙坑数量、球道平均宽度等信息。
上述要点并不是一个全面的列表,但它为项目提供了一个未来扩展的路线图。
目录
我将首先介绍项目的步骤,然后详细讲解每一步:
-
在 Google Earth 中描绘高尔夫球场的多边形,代表高尔夫球场的各个元素(发球台、球道、沙坑、果岭、水域、障碍物)
-
从 Google Earth 下载多边形为 KML 文件
-
读取 KML 数据到 R 中并进行一些轻度的数据清洗/操作
-
使用 ggplot2 绘制高尔夫球场
在 Google Earth 中绘制多边形
首先,让我们前往 Google Earth,选择一个我们想要绘制的高尔夫球场。我们将以位于威斯康星州的 Erin Hills 为例。通常,熟悉球场布局或提前打开一张球场地图有助于通过卫星影像更容易地识别每个洞的具体位置。
我们需要通过点击左上角的蓝色“+ 新建”按钮来创建一个新项目。接下来,我们将使用本地的 KML 文件并点击“创建”。最后,为项目命名,通常可以直接使用所绘制课程的名称。项目的名称将是我们完成后下载的 KML 文件的名称。

来自 Google Earth 的图片,由作者编辑
现在我们的项目文件夹已经设置好,可以开始绘制了。免责声明:Erin Hills 有 138 个沙坑,我是通过亲身体验发现绘制它们有些繁琐……不过,还是让我们先前往第一个发球台开始绘制吧。
到达第一个发球台后,首先识别该洞的关键元素。Erin Hills 的第一个洞有水障碍和一个位于果岭左侧的危险区,果岭有一个左转的狗腿,几处沙坑等等。要开始绘制,请点击“添加路径或多边形”,这是位于顶部工具栏中第二个从左边开始的图标,图标形状为一条连接的点线。这将初始化一个类似铅笔的工具,我们可以用它来进行绘制。
附加说明:你可以通过同时按住 Shift 键并按左箭头或右箭头来旋转屏幕。
我通常从发球台开始,然后向果岭方向绘制。要求每个绘制的多边形都必须形成一个闭合的形状,这意味着你需要回到原始的起点。完成一个多边形后,将其保存到项目中并为其命名。为每个多边形命名时,使用一致的命名规范也非常重要,例如course_hole_element,在这种情况下可以翻译为:erin_hills_hole_1_tee,或者erin_hills_hole_5_fairway,等等。我们稍后会在 R 代码中使用字符串匹配来提取每个多边形名称中的这些关键信息。这将帮助我们创建一个多边形元素到颜色的映射,也就是告诉 ggplot2 如何为每个多边形上色。因此,如果“沙坑”是元素,那么我们希望将其着色为棕黄色。如果“水障碍”是元素,它应该是蓝色的。这样也可以帮助我们提取课程名称和洞号,从而提供更多的绘图功能。
下图是 Erin Hills 的第 15 洞(我在这里打球时最喜欢的洞)。左侧是原始的 Google Earth 图像,中间是我们绘制过后的图像,右侧是使用 ggplot2 渲染后的图像。我选择不绘制球场的长草区、树木、球车道等元素。



(左) 来自 Google Earth 的照片,(中) 来自 Google Earth 的照片,由作者编辑,(右) 作者提供的图片
一旦我们完成了球洞或球场的绘制,就该把所有辛苦工作的成果导出为 KML 文件了。可以通过点击屏幕左侧项目所在位置的三个竖点来完成此操作。该项目与 geoJSON 数据最为兼容,我们可以在接下来的步骤中轻松地将 KML 文件转换为 geoJSON 格式。现在我们准备好进入 R 了。
在 R 中绘图
我们需要准备的包有:sf(用于处理地理空间数据)、tidyverse(用于数据清理和绘图)、stringr(用于字符串匹配)和geojsonsf(用于将 KML 转换为 geoJSON)。我们的第一步是读取 KML 文件,这可以通过st_read()函数来实现。
# load libraries
library(sf)
library(tidyverse)
library(stringr)
library(geojsonsf)
kml_df <- st_read("/Users/adambeaudet/Downloads/erin_hills.kml")
太好了!现在我们应该已经在 R 中获取了高尔夫球场的 KML 数据。数据框应该有两列:Name(项目名称,或者在我们这个案例中是球场名称),和geometry(一个包含所有构成我们描绘的多边形的单个点的列表)。如前所述,让我们将 KML 数据转换为 geoJSON,并提取球场名称和洞号。
# convert from KML to geoJSON
geojson_df <- st_as_sf(kml_df, "POLYGON")
# extracting course name and hole number from polygon name
# assuming "course_hole_element" naming convention is used for polygons
geojson_df$course_name <- str_match(geojson_df$Name, “^(.+)_hole”)[,2]
geojson_df$hole_num <- gsub(“.*_hole_(\\d+)_.*”, “\\1”, geojson_df$Name)
为了使我们的地图指向正北方,我们需要以一种保持方向性的方式进行投影。我们可以使用st_transform()函数来做到这一点。
# define a CRS so map always points due north
crs <- "+proj=lcc +lat_1=33 +lat_2=45 +lat_0=39 +lon_0=-96 +x_0=0 +y_0=0 +datum=WGS84 +units=m +no_defs"
# transform data to CRS
geojson_df <- st_transform(geojson_df, crs)
我们几乎准备好绘制了,但首先,我们需要告诉 ggplot2 如何为每个多边形上色。下面是我的项目使用的调色板,您也可以根据需要自定义。
可选:在此步骤中,我们还可以使用st_centroid()函数计算我们的多边形的重心,这样我们就可以将洞号叠加到每个果岭上。

图片由作者提供
geojson_df <- geojson_df %>%
mutate(color = case_when(
grepl(“_tee$”, Name) ~ “#57B740”,
grepl(“_bunker$”, Name) ~ “#EDE6D3”,
grepl(“_water$”, Name) ~ “#2243b6”,
grepl(“_fairway$”, Name) ~ “#57B740”,
grepl(“_green$”, Name) ~ “#86D14A”,
grepl(“_hazard$”, Name) ~ “#094d1d”
)) %>%
mutate(centroid = st_centroid(geometry))
我们正式准备好绘图了。我们可以结合使用geom_sf()、geom_text(),如果想要更花哨一点,还可以使用geom_point()来绘制地图上的击球位置。我通常会去掉网格线、坐标轴标签和图例,以保持界面的简洁。
ggplot() +
geom_sf(data = geojson_df, aes(fill = color), color = "black") +
geom_text(data = filter(geojson_df, grepl("_green$", Name)),
aes(x = st_coordinates(centroid)[, 1],
y = st_coordinates(centroid)[, 2],
label = hole_num),
size = 3, color = "black", fontface = "bold", hjust = 0.5, vjust = 0.5) +
scale_fill_identity() +
theme_minimal() +
theme(axis.title.x = element_blank(),
axis.title.y = element_blank(),
axis.text.x = element_blank(),
axis.text.y = element_blank(),
plot.title = element_text(size = 16),
panel.grid.major = element_blank(),
panel.grid.minor = element_blank()) +
theme(legend.position = "none") +
labs(title = 'Erin Hills | Hartford, WI')
就这样——在 R 中绘制的高尔夫球场,真是个好主意!
要查看我在撰写本文时所绘制的其他课程,您可以访问我的 Shiny 应用:abodesy14.shinyapps.io/golfMapsR/
如果您跟随本教程并且玩得开心,或者感兴趣的话,欢迎尝试绘制您最喜欢的高尔夫球场,并为我维护的golfMapsR仓库创建一个 Pull Request:github.com/abodesy14/golfMapsR
通过一些共同的努力,我们可以创建一个关于全球高尔夫球场的可绘制数据库!
使用 PointNet 和 PyTorch3D 进行点云分类
·发表于 Towards Data Science ·阅读时间:11 分钟·2024 年 3 月 22 日
--

来自 ModelNet 数据集的类别为“monitor”的物体
通过此文章,您可以使用 Google Colab 笔记本来跟进学习。
在今天这个技术迅猛发展的时代,3D 技术正变得不可或缺。原型设计、虚拟试穿、虚拟和增强现实体验、数字双胞胎、测量、医疗假肢以及电影和游戏行业只是 3D 技术冰山一角。LinkedIn 估计,到 2028 年,全球对 3D 内容的需求将超过 30 亿美元,并且没有放缓的迹象。从《冰雪奇缘》到《堡垒之夜》,可以说,3D 模型正成为新的照片。
随着对 3D 数据需求的增加,对于有效分类和理解 3D 数据的方法的需求也在增长。2016 年由斯坦福大学研究人员发明的PointNet在快速发展的机器学习领域中可谓是一个化石,然而它经受住了时间的考验。直到 2023 年,研究人员仍然发布了基于 PointNet 架构的变种,适用于各种任务,包括:
-
脉冲神经网络
-
通过多孔介质预测流体流动
-
森林场景中的垂直结构分割
-
3D 人脸验证
-
基于雷达的人体活动识别
-
单孔膜冷却
-
以及更多内容
Polars + NVIDIA GPU 教程
使用 Polars 与 NVIDIA GPU 结合可以加速你的数据管道
·发表于 Towards Data Science ·5 分钟阅读·2024 年 9 月 17 日
--

图片来源:noaa @ Unsplash.com
在 Python 中处理庞大数据集一直是一个挑战。Python 语言并非专门为处理大量数据而设计,像原生 SQL 系统或 Spark 那样。
在 Python 中处理二维数据集最著名的库,毫无疑问是pandas。虽然易于使用,且每个数据科学家都在使用,但 Pandas 是用 Python 和 C 编写的,使得它在对大数据进行操作时稍显繁琐且较慢。如果你是数据科学家,你一定经历过等待 200 年才能完成一个group by操作的痛苦。
旨在解决这个问题的库之一是polars ——一个极为高效的 Python 包,能够处理大规模数据集,主要有以下几个原因:
-
它是用 Rust 编写的
-
它自动利用多线程
-
它通过使用懒计算推迟大部分计算
而且……今天之后,你可以利用 NVIDIA 的硬件来最大化polars的 GPU 引擎能力。
在这篇博客文章中,我们将看到如何利用polars+GPU大幅加速你的数据管道。
环境设置
策略梯度方法在强化学习中的应用
使用策略梯度方法在 Python 中教车穿越山脉:强化学习的数学深度分析
·发表于 Towards Data Science ·阅读时长 26 分钟·2024 年 5 月 29 日
--

图像由 DALL-E 生成
假设你正在试图教一只狗去取回一个球。一开始,狗狗根本不知道在你扔球的时候应该做什么。它可能会朝不同的方向跑,忽视球,或者做一些完全无关的事情。你的目标是教会狗狗取回球并把它带回给你。
每次狗狗做某事时,你要么奖励它一块零食,要么什么也不做。如果狗狗跑向球,你就给它奖励。如果它做了其他事情,你就不给奖励。这一套狗狗用来决定该做什么的准则或策略称为“策略”。最初,这些准则是随机的,但通过训练,它们会变得更加专注于取回球。
随着时间的推移,狗狗学会了跑向球并获得奖励。它开始更频繁地采用这种策略,因为它能带来奖励。这本质上就是策略梯度方法的工作原理。在本文中,我们将探索它们的机制和数学原理,并使用 OpenAI Gym 来训练一辆车穿越山脉。让我们开始吧!
目录
策略梯度:RLHF 的基础
理解策略优化及其在强化学习中的应用
·发表于Towards Data Science ·15 分钟阅读·2024 年 2 月 6 日
--

尽管强化学习(RL)在多种应用中非常有用,但它在大规模语言模型(LLMs)对齐过程中的作用至关重要,特别是在强化学习与人类反馈(RLHF)中的应用。不幸的是,RL 在 AI 社区中并不广为人知。也就是说,许多从业者(包括我自己)更熟悉监督学习技术,这导致了对使用 RL 的潜在偏见,尽管它具有巨大的实用性。在这一系列概述中,我们的目标是通过全面回顾 RL,从基本思想入手,逐步过渡到现代算法,如近端策略优化(PPO) [7],这些算法在 RLHF 中被广泛使用,从而减少这种偏见。

现代强化学习算法的分类(来自[5])
本概述。 如上所示,存在两种类型的无模型强化学习(RL)算法:Q 学习和策略优化。之前,我们学习了 Q 学习、强化学习的基础知识,以及这些思想如何可以推广到语言模型微调。在本概述中,我们将概述策略优化和策略梯度这两个在实践中广泛应用的思想……
卷积神经网络(CNN)的池化层
什么是池化层及其不同类型
·发表于 Towards Data Science ·阅读时间 7 分钟·2024 年 1 月 20 日
--

”www.flaticon.com/free-icons/neural-network" title=”神经网络图标”>由 Freepik 设计的神经网络图标 — Flaticon。
背景
在我之前的文章中,我们介绍了卷积神经网络(CNN)背后的关键组成部分——卷积层。
卷积层使得神经网络能够学习最佳的卷积核,以解码或分类我们的输入图像。
如果你不太了解,卷积核是一个小矩阵,它在输入图像上滑动,并且在每一步应用卷积操作。根据卷积核的结构,它会对输入图像产生不同的效果。它可以进行模糊、锐化,甚至检测边缘(Sobel 算子)。
在 CNN 中,卷积操作的输出称为特征图。
下面是一个卷积的示例图,其中我们对结果图像进行了模糊处理:

下面是一个示例卷积,应用于对灰度图像进行模糊效果处理。图示由作者创建。
如果你想了解卷积是如何工作的完整解析,可以查看我之前的相关文章:
将 Twitter 的异常检测算法移植到 Swift
从 Twitter 到 Swift:构建异常检测。
·发表于Towards Data Science ·12 分钟阅读·2024 年 11 月 29 日
--

Twitter(现为 X),在 2015 年开发了一种异常检测算法,用于追踪其数百万用户之间的趋势
[## GitHub - twitter/AnomalyDetection:使用 R 进行异常检测
使用 R 进行异常检测。通过在 GitHub 上创建账户,贡献于 twitter/AnomalyDetection 的开发。
这个完全使用 R 语言制作的软件包仍然非常实用。它被设计用于检测全球性和局部性的异常,并且能够成功地检测各种异常。有关它能检测和不能检测的内容,请查看Anomaly.io 对原始算法的测试,它非常全面。
为什么要移植到 Swift?
为什么不呢 🤷♂️?我感到无聊。
理解 Twitter 的异常检测算法
Twitter 的异常检测算法是一个统计框架,旨在检测时间序列数据集中的异常值或离群点。
该算法有两个主要的核心组成部分。
- 季节性分解:该算法...
视觉变换器的位置嵌入解析
视觉变换器解析系列
位置嵌入在视觉变换器中的数学原理与代码
·发表于 Towards Data Science ·11 分钟阅读·2024 年 2 月 27 日
--
自 2017 年“Attention is All You Need”¹提出以来,变换器已确立为自然语言处理(NLP)领域的最新技术。2021 年,An Image is Worth 16x16 Words² 成功将变换器应用于计算机视觉任务。此后,许多基于变换器的架构被提出用于计算机视觉。*
本文探讨了为何位置嵌入是视觉变换器中的必要组成部分,并分析了不同文献中位置嵌入的实现方式。文章包含了位置嵌入的开源代码和概念解释,所有代码均使用 PyTorch Python 包。

图片来源:BoliviaInteligente via Unsplash
本文是一个系列文章的一部分,深入探讨了视觉变换器的内部工作原理。这些文章中的每一篇都可以通过 Jupyter 笔记本以可执行代码的形式查看。系列中的其他文章包括:
-
视觉变换器解析→ Jupyter 笔记本
-
视觉变换器的注意力机制解析
-
视觉变换器的位置嵌入,解释
-
标记到标记的视觉变换器,解释
目录
-
为什么使用位置嵌入?
-
注意力不变性直到置换
-
文献中的位置嵌入
-
一个位置嵌入示例
— 定义位置嵌入
— 将位置嵌入应用于标记
-
结论
— 进一步阅读
— 引用
为什么使用位置嵌入?
Attention is All You Need¹中指出,由于变换器缺乏递归或卷积,它们无法学习关于一组标记顺序的信息。如果没有位置嵌入,变换器对标记的顺序是不可变的。对于图像来说,这意味着图像的各个块可以被打乱,而不影响预测的输出。
让我们来看一个关于块顺序的示例,图像为 Luis Zuno(@ansimuz)创作的像素艺术作品《黄昏山脉》³。原始作品已经被裁剪并转换为单通道图像。这意味着每个像素的值在零和一之间。单通道图像通常以灰度显示;然而,为了更容易观察,我们将其以紫色调显示。
mountains = np.load(os.path.join(figure_path, 'mountains.npy'))
H = mountains.shape[0]
W = mountains.shape[1]
print('Mountain at Dusk is H =', H, 'and W =', W, 'pixels.')
print('\n')
fig = plt.figure(figsize=(10,6))
plt.imshow(mountains, cmap='Purples_r')
plt.xticks(np.arange(-0.5, W+1, 10), labels=np.arange(0, W+1, 10))
plt.yticks(np.arange(-0.5, H+1, 10), labels=np.arange(0, H+1, 10))
plt.clim([0,1])
cbar_ax = fig.add_axes([0.95, .11, 0.05, 0.77])
plt.clim([0, 1])
plt.colorbar(cax=cbar_ax);
#plt.savefig(os.path.join(figure_path, 'mountains.png'), bbox_inches='tight')
Mountain at Dusk is H = 60 and W = 100 pixels.

代码输出(图片由作者提供)
我们可以将这张图像分割成大小为 20 的块。(关于将图像分割成块的更深入解释,请参见视觉变换器文章。)
P = 20
N = int((H*W)/(P**2))
print('There will be', N, 'patches, each', P, 'by', str(P)+'.')
print('\n')
fig = plt.figure(figsize=(10,6))
plt.imshow(mountains, cmap='Purples_r')
plt.clim([0,1])
plt.hlines(np.arange(P, H, P)-0.5, -0.5, W-0.5, color='w')
plt.vlines(np.arange(P, W, P)-0.5, -0.5, H-0.5, color='w')
plt.xticks(np.arange(-0.5, W+1, 10), labels=np.arange(0, W+1, 10))
plt.yticks(np.arange(-0.5, H+1, 10), labels=np.arange(0, H+1, 10))
x_text = np.tile(np.arange(9.5, W, P), 3)
y_text = np.repeat(np.arange(9.5, H, P), 5)
for i in range(1, N+1):
plt.text(x_text[i-1], y_text[i-1], str(i), color='w', fontsize='xx-large', ha='center')
plt.text(x_text[2], y_text[2], str(3), color='k', fontsize='xx-large', ha='center');
#plt.savefig(os.path.join(figure_path, 'mountain_patches.png'), bbox_inches='tight')
There will be 15 patches, each 20 by 20.

代码输出(图片由作者提供)
这个说法是,视觉变换器将无法区分原始图像和一个将块打乱后的版本。
np.random.seed(21)
scramble_order = np.random.permutation(N)
left_x = np.tile(np.arange(0, W-P+1, 20), 3)
right_x = np.tile(np.arange(P, W+1, 20), 3)
top_y = np.repeat(np.arange(0, H-P+1, 20), 5)
bottom_y = np.repeat(np.arange(P, H+1, 20), 5)
scramble = np.zeros_like(mountains)
for i in range(N):
t = scramble_order[i]
scramble[top_y[i]:bottom_y[i], left_x[i]:right_x[i]] = mountains[top_y[t]:bottom_y[t], left_x[t]:right_x[t]]
fig = plt.figure(figsize=(10,6))
plt.imshow(scramble, cmap='Purples_r')
plt.clim([0,1])
plt.hlines(np.arange(P, H, P)-0.5, -0.5, W-0.5, color='w')
plt.vlines(np.arange(P, W, P)-0.5, -0.5, H-0.5, color='w')
plt.xticks(np.arange(-0.5, W+1, 10), labels=np.arange(0, W+1, 10))
plt.yticks(np.arange(-0.5, H+1, 10), labels=np.arange(0, H+1, 10))
x_text = np.tile(np.arange(9.5, W, P), 3)
y_text = np.repeat(np.arange(9.5, H, P), 5)
for i in range(N):
plt.text(x_text[i], y_text[i], str(scramble_order[i]+1), color='w', fontsize='xx-large', ha='center')
i3 = np.where(scramble_order==2)[0][0]
plt.text(x_text[i3], y_text[i3], str(scramble_order[i3]+1), color='k', fontsize='xx-large', ha='center');
#plt.savefig(os.path.join(figure_path, 'mountain_scrambled_patches.png'), bbox_inches='tight')

代码输出(图片由作者提供)
显然,这是与原始图像非常不同的图像,你不希望视觉变换器将这两张图像视为相同的。
注意力不变性直到置换
让我们探讨一下关于视觉 transformer 对标记顺序不变性的说法。与标记顺序不变的 transformer 组件是注意力模块。虽然本文不以详细解释注意力模块为重点,但我们需要具备基础理解。欲了解更多关于视觉 transformer 中注意力的详细讲解,请参见 注意力文章。
注意力是通过三个矩阵计算的——Q 查询、K 键和 V 值——每个矩阵都是通过将标记通过线性层生成的。一旦生成了 Q、K 和 V 矩阵,注意力就可以通过以下公式计算。
其中 Q, K, V 分别是查询、键和值;dₖ 是一个缩放值。为了展示注意力对标记顺序的不变性,我们将从三个随机生成的矩阵开始,表示 Q、K 和 V。Q、K 和 V 的形状如下:

Q、K 和 V 的维度(图像来源:作者)
在这个示例中,我们将使用 4 个投影长度为 9 的标记。矩阵将包含整数,以避免浮动点乘法错误。一旦生成,我们将在所有三个矩阵中交换标记 0 和标记 2 的位置。交换位置的矩阵将用下标 s 表示。
n_tokens = 4
l_tokens = 9
shape = n_tokens, l_tokens
mx = 20 #max integer for generated matricies
# Generate Normal Matricies
np.random.seed(21)
Q = np.random.randint(1, mx, shape)
K = np.random.randint(1, mx, shape)
V = np.random.randint(1, mx, shape)
# Generate Row-Swapped Matricies
swapQ = copy.deepcopy(Q)
swapQ[[0, 2]] = swapQ[[2, 0]]
swapK = copy.deepcopy(K)
swapK[[0, 2]] = swapK[[2, 0]]
swapV = copy.deepcopy(V)
swapV[[0, 2]] = swapV[[2, 0]]
# Plot Matricies
fig, axs = plt.subplots(nrows=3, ncols=2, figsize=(8,8))
fig.tight_layout(pad=2.0)
plt.subplot(3, 2, 1)
mat_plot(Q, 'Q')
plt.subplot(3, 2, 2)
mat_plot(swapQ, r'$Q_S$')
plt.subplot(3, 2, 3)
mat_plot(K, 'K')
plt.subplot(3, 2, 4)
mat_plot(swapK, r'$K_S$')
plt.subplot(3, 2, 5)
mat_plot(V, 'V')
plt.subplot(3, 2, 6)
mat_plot(swapV, r'$V_S$')

代码输出(图像来源:作者)
注意力公式中的第一次矩阵乘法是 Q·Kᵀ=A,其中结果矩阵 A 是一个正方形,大小等于标记的数量。当我们用 Qₛ 和 Kₛ 计算 Aₛ 时,结果 Aₛ 的行 [0, 2] 和列 [0,2] 会与 A 中的行和列交换。
A = Q @ K.transpose()
swapA = swapQ @ swapK.transpose()
modA = copy.deepcopy(A)
modA[[0,2]] = modA[[2,0]] #swap rows
modA[:, [2, 0]] = modA[:, [0, 2]] #swap cols
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(8,3))
fig.tight_layout(pad=1.0)
plt.subplot(1, 3, 1)
mat_plot(A, r'$A = Q*K^T$')
plt.subplot(1, 3, 2)
mat_plot(swapA, r'$A_S = Q_S * K_S^T$')
plt.subplot(1, 3, 3)
mat_plot(modA, 'A\nwith rows [0,2] swaped\n and cols [0,2] swaped')

代码输出(图像来源:作者)
下一个矩阵乘法是 A·V=A,其中结果矩阵 A 的形状与初始的 Q、K 和 V 矩阵相同。当我们用 Aₛ 和 Vₛ 计算 Aₛ 时,结果 Aₛ 的行 [0,2] 会与 A 中的行交换。
A = A @ V
swapA = swapA @ swapV
modA = copy.deepcopy(A)
modA[[0,2]] = modA[[2,0]] #swap rows
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(12, 7))
fig.tight_layout(pad=1.0)
plt.subplot(2, 2, 1)
mat_plot(A, r'$A = A*V$')
plt.subplot(2, 2, 2)
mat_plot(swapA, r'$A_S = A_S * V_S$')
plt.subplot(2, 2, 4)
mat_plot(modA, 'A\nwith rows [0,2] swaped')
axs[1,0].axis('off')

代码输出(图像来源:作者)
这证明了在输入到注意力层中的标记顺序发生变化时,输出的注意力矩阵中相应的标记行也会发生变化。这是直观的,因为注意力是在计算标记之间的关系。没有位置信息的情况下,改变标记顺序不会改变标记之间的关系。我不太明白为什么这种输出的排列不能传递位置信息给 transformer。然而,我所读的所有资料都表示这不够,因此我们接受这一点并继续前进。
文献中的位置嵌入
除了位置嵌入的理论依据外,使用位置嵌入的模型比不使用位置嵌入的模型具有更高的准确性。然而,目前没有明确的证据表明哪种类型的位置嵌入优于其他类型。
在Attention is All You Need¹中,他们使用了固定的正弦位置嵌入。他们指出,他们曾尝试过学习的位置嵌入,但观察到“几乎相同的结果”。请注意,这个模型是为 NLP 应用设计的,特别是翻译任务。作者最终选择了固定嵌入,因为它可以适应不同长度的短语。在计算机视觉应用中,这可能不会是一个问题。
在An Image is Worth 16x16 Words²中,他们将位置嵌入应用于图像。他们对四种不同的位置信息嵌入(包括固定和可学习的设置)进行了消融研究。该研究包括没有位置嵌入、1D 位置嵌入、2D 位置嵌入和相对位置嵌入。他们发现,带有位置嵌入的模型明显优于没有位置嵌入的模型。然而,不同类型的位置信息嵌入之间或固定与可学习嵌入之间的差异很小。这与[1]中的结果一致,即位置嵌入是有益的,尽管选择的具体嵌入并无太大影响。
在Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet⁴中,他们使用了一种正弦位置嵌入,描述上与[2]中的相同。他们发布的代码与[1]中的正弦位置嵌入公式一致。此外,他们发布的代码将位置嵌入固定,而不是将其作为一个通过正弦初始化的学习参数。
一个位置嵌入的示例
定义位置嵌入
现在,我们可以具体看看正弦位置嵌入的细节。该代码基于公开可用的 GitHub 代码,适用于Tokens-to-Token ViT⁴。从功能上讲,位置嵌入是一个与令牌形状相同的矩阵。它看起来像这样:

位置嵌入矩阵的形状(图像由作者提供)
来自[1]的正弦位置嵌入公式如下:
其中,PE是位置嵌入矩阵,i是令牌的数量,j是令牌的长度,d是令牌的长度。
在代码中,它看起来像这样:
def get_sinusoid_encoding(num_tokens, token_len):
""" Make Sinusoid Encoding Table
Args:
num_tokens (int): number of tokens
token_len (int): length of a token
Returns:
(torch.FloatTensor) sinusoidal position encoding table
"""
def get_position_angle_vec(i):
return [i / np.power(10000, 2 * (j // 2) / token_len) for j in range(token_len)]
sinusoid_table = np.array([get_position_angle_vec(i) for i in range(num_tokens)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
让我们生成一个示例位置嵌入矩阵。我们将使用 176 个令牌。每个令牌的长度为 768,这是 T2T-ViT⁴代码中的默认值。矩阵生成后,我们可以将其绘制出来。
PE = get_sinusoid_encoding(num_tokens=176, token_len=768)
fig = plt.figure(figsize=(10, 8))
plt.imshow(PE[0, :, :], cmap='PuOr_r')
plt.xlabel('Along Length of Token')
plt.ylabel('Individual Tokens');
cbar_ax = fig.add_axes([0.95, .36, 0.05, 0.25])
plt.clim([-1, 1])
plt.colorbar(label='Value of Position Encoding', cax=cbar_ax);
#plt.savefig(os.path.join(figure_path, 'fullPE.png'), bbox_inches='tight')

代码输出(图像由作者提供)
让我们放大查看令牌的开始部分。
fig = plt.figure()
plt.imshow(PE[0, :, 0:301], cmap='PuOr_r')
plt.xlabel('Along Length of Token')
plt.ylabel('Individual Tokens');
cbar_ax = fig.add_axes([0.95, .2, 0.05, 0.6])
plt.clim([-1, 1])
plt.colorbar(label='Value of Position Encoding', cax=cbar_ax);
#plt.savefig(os.path.join(figure_path, 'zoomedinPE.png'), bbox_inches='tight')

代码输出(图像由作者提供)
它确实具有正弦结构!
将位置嵌入应用于令牌
现在,我们可以将位置嵌入添加到标记中!我们将使用《黄昏山脉》³,并采用与上述相同的补丁标记化方法。这样,我们将得到 15 个标记,每个标记的长度为 20²=400。有关补丁标记化的更多详细信息,请参阅视觉变换器文章。回想一下,这些补丁看起来是这样的:
fig = plt.figure(figsize=(10,6))
plt.imshow(mountains, cmap='Purples_r')
plt.hlines(np.arange(P, H, P)-0.5, -0.5, W-0.5, color='w')
plt.vlines(np.arange(P, W, P)-0.5, -0.5, H-0.5, color='w')
plt.xticks(np.arange(-0.5, W+1, 10), labels=np.arange(0, W+1, 10))
plt.yticks(np.arange(-0.5, H+1, 10), labels=np.arange(0, H+1, 10))
x_text = np.tile(np.arange(9.5, W, P), 3)
y_text = np.repeat(np.arange(9.5, H, P), 5)
for i in range(1, N+1):
plt.text(x_text[i-1], y_text[i-1], str(i), color='w', fontsize='xx-large', ha='center')
plt.text(x_text[2], y_text[2], str(3), color='k', fontsize='xx-large', ha='center')
cbar_ax = fig.add_axes([0.95, .11, 0.05, 0.77])
plt.clim([0, 1])
plt.colorbar(cax=cbar_ax);
#plt.savefig(os.path.join(figure_path, 'mountain_patches_w_colorbar.png'), bbox_inches='tight')

代码输出(图片由作者提供)
当我们将这些补丁转换为标记时,结果如下所示
tokens = np.zeros((15, 20**2))
for i in range(15):
patch = gray_mountains[top_y[i]:bottom_y[i], left_x[i]:right_x[i]]
tokens[i, :] = patch.reshape(1, 20**2)
tokens = tokens.astype(int)
tokens = tokens/255
fig = plt.figure(figsize=(10,6))
plt.imshow(tokens, aspect=5, cmap='Purples_r')
plt.xlabel('Length of Tokens')
plt.ylabel('Number of Tokens')
cbar_ax = fig.add_axes([0.95, .36, 0.05, 0.25])
plt.clim([0, 1])
plt.colorbar(cax=cbar_ax)

代码输出(图片由作者提供)
现在,我们可以创建一个形状正确的位置嵌入:
PE = get_sinusoid_encoding(num_tokens=15, token_len=400).numpy()[0,:,:]
fig = plt.figure(figsize=(10,6))
plt.imshow(PE, aspect=5, cmap='PuOr_r')
plt.xlabel('Length of Tokens')
plt.ylabel('Number of Tokens')
cbar_ax = fig.add_axes([0.95, .36, 0.05, 0.25])
plt.clim([0, 1])
plt.colorbar(cax=cbar_ax)

代码输出(图片由作者提供)
现在我们准备将位置嵌入添加到标记中。位置嵌入中的紫色区域会让标记变得更暗,而橙色区域会让标记变得更亮。
mountainsPE = tokens + PE
resclaed_mtPE = (position_mountains - np.min(position_mountains)) / np.max(position_mountains - np.min(position_mountains))
fig = plt.figure(figsize=(10,6))
plt.imshow(resclaed_mtPE, aspect=5, cmap='Purples_r')
plt.xlabel('Length of Tokens')
plt.ylabel('Number of Tokens')
cbar_ax = fig.add_axes([0.95, .36, 0.05, 0.25])
plt.clim([0, 1])
plt.colorbar(cax=cbar_ax)

代码输出(图片由作者提供)
你可以看到原始标记的结构以及位置嵌入中的结构!这两部分信息都已经传递到变换器中。
结论
现在,你应该对位置嵌入如何帮助视觉变换器学习有所直觉。本文中的代码可以在GitHub 仓库找到,专门用于这一系列内容。T2T-ViT 论文中的代码⁴可以在这里找到。祝你愉快地进行变换!
本文已由洛斯阿拉莫斯国家实验室批准发布,发布编号为 LA-UR-23–33876。相关代码已获批准,并根据 O#4693 颁发了 BSD-3 开源许可证。
进一步阅读
要了解更多关于在 NLP 中的位置信息嵌入,请参阅
- 《位置编码在变换器模型中的温和介绍》:
machinelearningmastery.com/a-gentle-introduction-to-positional-encoding-in-transformer-models-part-1/
若要查看关于视觉变换器的广泛视频讲座(相关章节已标注),请参见
-
视觉变换器及其应用:
youtu.be/hPb6A92LROc?si=GaGYiZoyDg0PcdSP— 视觉变换器对补丁位置不变 10:44–12:52 (
youtu.be/hPb6A92LROc?t=644&si=Keu-5i9BQ5c69mxz)— 位置嵌入 12:52–14:15 (
youtu.be/hPb6A92LROc?t=772&si=spdlYZl-TRgbGgzn)
引用
[1] Vaswani 等(2017 年)。注意力机制就是你所需要的一切。 doi.org/10.48550/arXiv.1706.03762
[2] Dosovitskiy 等(2020 年)。一张图片价值 16x16 个词:用于大规模图像识别的变换器。 doi.org/10.48550/arXiv.2010.11929
[3] 路易斯·祖诺 (@ansimuz)。 黄昏时山脉背景。 许可证 CC0: opengameart.org/content/mountain-at-dusk-background
[4] 袁等人 (2021). Tokens-to-Token ViT: 从头开始在 ImageNet 上训练视觉变换器。 doi.org/10.48550/arXiv.2101.11986
→ GitHub 代码: github.com/yitu-opensource/T2T-ViT
统计功效分析,解密
从基础原理推导样本量方程
·发表于 Towards Data Science ·9 分钟阅读·2024 年 12 月 3 日
--
在规划在线实验时,出现了一个关键问题:
需要多少观察值才能自信地检测到一个有意义的效应?
在本文中,我们旨在全面展示样本量确定的机制——也称为统计功效分析。通过从基础原理推导样本量方程,我们将揭开这一过程的神秘面纱,帮助您深入理解统计学的基础知识。通过本指南的学习,您将能够清晰且自信地计算最小样本量。
由于计算会根据我们是测量比例还是连续性结果有所不同,我们将分别研究这两种情况。
1 比例度量
1.1 设计实验
假设我们想评估重新设计的主页对注册账户的访客比例的影响。我们设计了一个实验,使得实验组的访客看到新主页,而对照组的访客看到旧主页。

作者生成的图像
1.2 明确假设
AI 代理与 CrewAI 的强大协作
一个实战营销用例
·发表于Towards Data Science ·阅读时间 11 分钟·2024 年 2 月 12 日
--

图片由作者提供(通过)
在本文中,我将向你展示如何编排 AI“木偶”来成功应对现实生活中的营销挑战,其中代理们共同合作,完成以下任务:
-
分析客户数据资料
-
为目标营销选择理想产品
-
为这些产品创建有吸引力的推广文案。
我们将使用一个新的框架——CrewAI,使得自主 AI 代理能够协作并实现共同目标。所有内容都有文档记录在提供的Colab 笔记本中,因此你可以复制并将其适应到你自己的用例中。
什么是 CrewAI?
CrewAI是一个新的框架,旨在促进 AI 代理之间的协作。代理可以扮演特定角色,分享共同目标,并作为一个整体高效地运作。它是开源的,基于Langchain构建的。在同一领域的一些替代方案包括微软的 AutoGen和ChatDev。

图片来自官方文档
CrewAI 的主要概念围绕三个核心实体:代理、任务和小组。
-
代理:这些是独立单元,编程用于执行任务、做出决策并与其他代理沟通。它们可以使用工具,这些工具可以是简单的搜索功能,或者是涉及其他链条、API 等的复杂集成。
-
任务:任务是人工智能代理需要完成的工作或职责。任务可以包含额外的信息,例如哪个代理应该执行此任务以及他们可能需要哪些工具。
-
团队是由每个具有特定角色的代理组成的团队,这些代理共同努力实现一个共同的目标。组建团队的过程包括汇集代理、定义他们的任务并确定任务执行顺序。
为了测试它的能力,我设计了以下情景……
营销挑战
假设你是一个地方零售商的负责人。下周,你将对 12 种产品进行促销活动。你会将哪些产品推广给哪些客户?正如你所想,向男性客户推广口红几乎没有意义。这个过程能通过人工智能进行优化吗?
通过使用忠诚卡和数据挖掘,您可以访问一个包含客户个性分析的数据库。这有助于了解客户的特点、偏好和行为。通过精心设计的提示(见笔记本),我几秒钟内便获得了这一客户数据集:
birth_year,sex,marital_status,yearly_household_income_percentile,number_of_toddlers,number_of_teens,highest_education,monthly_spend_on_wine,monthly_spend_on_vegetables,monthly_spend_on_toys,last_month_coupon_use
1975,F,married,65,1,0,masters,30,65,20,high
1992,M,single,30,0,0,bachelors,15,40,0,none
1980,F,married,80,0,2,bachelors,50,80,85,low
1968,M,divorced,45,0,0,high_school,45,50,0,high
1990,F,single,25,0,0,associates,10,35,0,low
1985,M,married,90,2,1,phd,80,100,120,high
2000,F,single,40,0,0,bachelors,25,55,0,none
1972,M,married,70,0,1,masters,60,70,40,low
1988,F,married,55,1,0,associates,20,40,30,high
1970,M,single,60,0,0,high_school,35,60,0,none
对于可能的促销产品,我提出了以下几项:
Fresh Lettuce
Diapers
Irish whiskey
laundry detergent
Chips
Spaghetti cans (ready to eat)
Minecraft Video Game
Mascara
Toilet Paper (best value)
Wagyu beef steak
Organic avocados
Cigarettes
这些产品是特别挑选的,因为我预期会看到某些关联。例如,婴儿尿布的促销与幼儿数量之间的关系,以及和牛牛排与家庭收入之间的关系。代理们能否通过常识合作发现这种逻辑?
实现
我的第一组将通过为每个特定客户选择最合适的三款产品来精准定位产品。在这个团队中,我创建了三个代理:
-
首席促销总监:负责主要任务,监督其他人的工作
-
客户画像专家:用于了解客户
-
产品专家:精通将产品与客户匹配
每个代理都需要一个角色、目标和背景故事。我们可以使用自然语言来描述这一点。在我的笔记本中,它变成了这样:
profiler = Agent(
role='profiler',
goal='''From limited data, you logically deduct conclusions about people.''',
backstory='You are an expert psychologist with decades of experience.',
llm=llm,
verbose=True,
allow_delegation=True
)
product_specialist = Agent(
role='product specialist',
goal='''Match the product to the customer''',
backstory='You have exceptional knowledge of the products and can say
how valuable they are to a customer.',
llm=llm,
verbose=True,
allow_delegation=True
)
Chief_Promotional_Director = Agent(
role="Chief Promotion Director",
goal='''
Oversee the work done by your team to make sure it's the best possible
and aligned with the product's goals, review, approve, ask clarifying
question or delegate follow up work if necessary to make decisions''',
backstory='''
You're the Chief Promotion Officer of a large retailer. You're
launching a personalized ad campaign, trying to make sure your team
is crafting the best possible content for the customer.''',
tools=[],
llm=llm,
verbose=True
)
以下任务交给了代理首席促销总监,他会将工作分配给其他代理:
select_3_products_task = f'''You're creating a targeted marketing campaign
tailored to what we know about our customers.
For each customer, we have to choose exactly three products to promote
in the next campaign. Make sure the selection is the best possible and
aligned with the customer. Review, approve, ask clarifying question or
delegate follow up work if necessary to make decisions. When delegating
work send the full draft as part of the information.
This is the list of all the products participating in the campaign: {products}.
This is all we know so far from the customer: {customer_description}.
To start this campaign we will need to build first an understanding of our
customer. Once we have a profile about the customers interests, lifestyle and
means and needs, we have to select exactly three products that have the
highest chance to be bought by them.
Your final answer MUST be exactly 3 products from the list, each with a short
description why it matches with this customer. '''
所有代理彼此互动,直到首席促销总监满意并结束任务。为了更好地理解当这组代理被解除任务时发生了什么,我创建了这个示意图:

作者图片
第二组将为每个产品编写简短的促销文本。文本内容也应与可以从客户数据中推断出的信息一致。为此,我创建了一个额外的创意内容创作者代理:
creative_content_creator_agent = Agent(
role="Creative Content Creator",
goal=dedent("""\
Develop compelling and innovative content
for ad campaigns, with a focus customer specific ad copies."""),
backstory=dedent("""\
As a Creative Content Creator at a top-tier digital marketing
agency, you excel in crafting advertisements that resonate with
potential customers. Your expertise lies in turning marketing
strategies into engaging stories that capture attention and
inspire buying action."""),
llm=llm,
verbose=True
)
该代理将与之前的“老板”代理合作,组成新的团队,执行以下任务:
get_ad_campaign_written_task = f'''
You're creating a targeted marketing campaign tailored to what we know
about our customers.
For each customer, we have chosen three products to promote in the next
campaign. This selection is tailored specifically to the customer: {selection}
To end this campaign succesfully we will need a promotional message
advertising these products to the customer with the ultimate intent that
they buy from us. This message should be around 3 paragraphs, so that it can
be easily integrated into the full letter. For example:
Tired of making dinner, get our best ready made canned tuna.
Your lifestyle deserves a taste of this fresh lobster.
In the weekends, go on a day trip with the kids with these new lunch box containers.
You need to review, approve, and delegate follow up work if necessary to
have the complete promotional message. When delegating work send the full
draft as part of the information.
Your final answer MUST include the 3 products from the list,
each with a short promotional message.'''
每个代理都与他们将使用的大型语言模型连接。我将他们连接到Mistral API。我首先集成了他们的 Mistral 7B Instruct v0.2 模型,但它推理得非常奇怪,例如对于选择尿布的解释:“虽然客户目前是单身,但她未来可能会有孩子。尿布是未来为父母准备的实际投资,因为它们是照顾婴儿和幼儿所必需的。”
通过升级到更大的型号:Mixtral 8x7B,这个问题得以解决。我再也没见过类似的逻辑缺陷。
结果
在我设置好两组人员后,是时候让他们开始分析 10 个客户了。通过将详细模式设置为 2,我能够在屏幕后面查看并跟踪代理的思考和互动。例如:
Thought: Do I need to use a tool? Yes
Action: Ask question to co-worker
Action Input: profiler|Can you provide a customer profile based on the given demographic and behavioral data?|The customer was born in 1968\. Its sex is M. Its marital status is divorced. Its yearly household income percentile is 45\. The customer has 0 toddlers. The customer has 0 teenage kids. The customer highest education is high_school. The customer spends on average 45 $ on wine. The customer spends on average 50 $ on vegetables. The customer spends on average 0 $ on toys. The customer's coupon use is high.
Observation: The profiler responds with: Our customer is a 54-year-old divorced man with a mid-level income. He has no children living with him. He has a high school education. He enjoys wine and vegetables, and does not spend on toys. He is likely to be budget-conscious and looking for deals, given his high coupon use.
Thought: Do I need to use a tool? Yes
Action: Ask question to co-worker
Action Input: product specialist|Which three products would you recommend for this customer, based on his profile?|54-year-old divorced man with a mid-level income, high school education, enjoys wine and vegetables, budget-conscious, and uses coupons.
Observation: The product specialist responds with: Based on the customer's profile, I would recommend:
1\. Laundry detergent: A necessary household item, and we can offer a deal or discount to appeal to his budget-conscious nature.
2\. Fresh lettuce and vegetables: Given his interest in vegetables, promoting fresh produce is a good fit.
3\. Wine or Irish whiskey: Since he spends a significant amount on wine, promoting a higher-end option like Irish whiskey could be appealing, especially if we offer a discount.
Thought: Do I need to use a tool? No
Final Answer:
1\. Laundry detergent : A necessary household item, and we can offer a deal or discount to appeal to his budget-conscious nature.
2\. Fresh lettuce and vegetables : Given his interest in vegetables, promoting fresh produce is a good fit.
3\. Irish whiskey : Since he spends a significant amount on wine, promoting a higher-end option like Irish whiskey could be appealing, especially if we offer a discount.
Observation: I have provided the three recommended products for the customer based on their profile.
现在,让我们详细分析两个客户。第一个客户是:
1990,F,single,25,0,0,associates,10,35,0,low
将其交给分析代理,其回应是:“我们的客户是一位 20 多岁末或 30 岁初的单身女性,收入中等。她喜欢葡萄酒和蔬菜,可能是自己做饭。她没有孩子,具有较高的教育背景,这表明她可能更加注重健康并且生活节奏繁忙。她不常使用优惠券,说明她可能更偏向于便利而非省钱。”
产品专家代理回复:“根据客户档案,我推荐以下产品:”
-
新鲜生菜:考虑到她对蔬菜的兴趣,她可能会喜欢新鲜生菜作为健康的餐食配料。
-
睫毛膏:作为单身女性,她可能会喜欢睫毛膏等美容产品,这与她的高学历和忙碌的生活方式相符。
-
意大利面罐头(即食):由于她忙碌的生活方式,她可能没有时间做饭,因此即食意大利面是一个便捷的选择。
这些选择非常合理!现在让我们看看其他小组提出的促销信息。我已将其粘贴在这个示意图中。产品信息非常精准:

作者提供的图片
让我们看看他们如何处理第二位客户:
1985,M,married,90,2,1,phd,80,100,120,high
构建的个人档案非常基础,但完整:“客户出生于 1985 年,男性,已婚,年收入位于前 90%,有 2 个幼儿和 1 个青少年孩子,最高学历博士,花费 80 美元买葡萄酒,100 美元买蔬菜,120 美元买玩具,使用优惠券频繁。”
以下是所选择并推广的产品:

作者提供的图片
产品选择和促销文案与客户已知的信息非常匹配。他们因为客户收入较高推荐了和牛肉,因为他是一个有 3 个孩子的忙碌父母,而游戏则是因为他有青少年孩子。
拉远一点,这里是我从整个小组输出中得到的观察:
✔️ 尿布只推荐给有幼儿的客户
✔️ 睫毛膏只推荐给女性客户
✔️ 和牛肉只推荐给最高收入的客户
✔️ 洗衣粉被推荐了 4 次。两次推荐给女性,也有两次推荐给未婚(!)男性
✔️ 香烟从未被推荐过(拍了拍 Mixtral —— 好 AI)
✔️ 所有的配套文本都根据客户的情况量身定制,而不会显得过于具体或包含虚假信息
✔️ 我从未看到不合适的产品建议
这正是一个营销人员所期待的!团队似乎具备一些常识,并且能够围绕这些常识进行推理。
伦理性
当我说常识时,请注意,它并不是普遍适用的常识。每个大型语言模型都固有地带有一定的偏见,这取决于训练数据和训练优化(例如,RLHF)。
举个例子:我发现,在我的例子中,模型只向未婚男性推荐洗衣粉。这是偶然的,还是假设已婚男性在家庭中不洗衣服?也许是因为模型只见过那些只有女性做洗衣的夫妻?这从统计学上可能是正确的,但事实是我们不知道还存在哪些其他偏见。我们应该小心在没有审查的情况下将这些框架投入生产,因为它们可能会传播关于性别、政治、宗教、种族等的刻板印象。这仍然是一个正在进行的研究领域,旨在减少这些风险。更多的阅读资料我可以推荐:OpenAI 的 博客文章、🤗 评估库 和 实践建议。
结论
我的结果表明,AI 代理能够根据客户的个人资料推理出最佳产品,并提供精准的促销信息,量身定制给每个客户,考虑到年龄、性别、婚姻状况和收入等因素。
我们看到自主 AI 代理协作、委派或被分配子任务、呼叫帮助并互相检查工作。当它们这样做时,它们的输出质量显著提升,从而做出更可靠、更符合常识的决策。像 CrewAI 这样的框架通过允许自然语言指令提供了一种直接而有效的方式来利用这一能力。我还讨论了大型语言模型中内嵌的常识,以及偏见的存在,应该努力追求公平。
这只是一个简单的例子,专注于营销。还有许多其他的应用场景:数据解析、自动社交媒体发布、Markdown 校验器或城市旅行规划器……可能性无穷无尽。也许它能帮助你应对下一个挑战?
CrewAI 强大的 AI 代理协作:营销应用案例 [## 首页
用于编排角色扮演、自主 AI 代理的前沿框架。通过促进协作智能…
docs.crewai.com [## GitHub - joaomdmoura/crewAI:用于编排角色扮演、自主 AI 代理的框架。通过…
用于编排角色扮演、自主 AI 代理的框架。通过促进协作智能,CrewAI…
通过让 LLMs 访问库,利用自然语言请求进行强大的数据分析和绘图
在你的网页浏览器中,让大型语言模型(LLMs)为你分析和绘制数据
LucianoSphere (Luciano Abriata, PhD)
·发表于Towards Data Science ·22 分钟阅读·2024 年 1 月 24 日
--

这张图是作者根据他自己的 Web 应用截图整理的。
介绍:使用大型语言模型自动化数据分析
最近,我开始探索如何使用大型语言模型(LLMs)自动化数据分析,这样你就可以用自然语言向它们提问关于数据集的问题,它们会通过生成并运行代码来回答这些问题。我将这一切实现为一个 Web 应用程序,(我和你!)可以尝试这一方法的强大功能和局限性,目前完全依赖程序编写标准的 JavaScript:
如何使用大型语言模型将有关数据集的问题转化为即时运行的代码,从而提供…
towardsdatascience.com
正如我在那篇文章中所解释的,我的主要兴趣是解决这个问题:
我能否用自己的话向 LLM 询问一个数据集的问题,并让它通过必要的数学或脚本解释这些问题并给出答案?
使用 CUPED 和双重机器学习为实验赋能
因果 AI,探索因果推理与机器学习的结合
·发布于Towards Data Science ·阅读时长 17 分钟·2024 年 8 月 15 日
--

图片来源:Karsten Würth 于Unsplash
这系列文章讲的是什么?
欢迎来到我的因果 AI 系列文章,我们将在这里探索因果推理如何融入机器学习模型。你将看到跨多个业务场景的多种实际应用。
在上一篇文章中,我们介绍了如何通过因果图保护需求预测。今天,我们将重点讨论如何使用 CUPED 和双重机器学习为实验赋能。
如果你错过了上一篇关于保护需求预测的文章,可以在这里查看:
因果 AI,探索因果推理与机器学习的结合
towardsdatascience.com
引言
在本文中,我们将评估 CUPED 和双重机器学习是否能够增强你的实验效果。我们将通过案例研究来探索以下几个方面:
-
实验的构建模块:假设检验、功效分析、自助法。
-
什么是 CUPED,它如何帮助增强实验效果?
-
CUPED 和双重机器学习在概念上有哪些相似之处?
-
什么时候我们应该使用双重机器学习,而不是 CUPED?
完整的笔记本可以在这里找到:
[## causal_ai/notebooks/powering your experiments - cuped.ipynb at main · raz1470/causal_ai
该项目介绍了因果人工智能及其如何推动商业价值。 - causal_ai/notebooks/powering your experiments…
案例研究
背景
你最近加入了一个领先的在线零售商的实验团队,该零售商以其庞大的产品目录和动态用户群体而闻名。数据科学团队已部署了一种先进的推荐系统,旨在提升用户体验并推动销售。该系统与零售平台实时集成,涉及大量基础设施和工程成本。
财务团队迫切希望了解系统的财务影响,特别是它相比没有推荐的基准情境所产生的额外收入。为了评估推荐系统的有效性,您计划进行一个随机对照实验。
数据生成过程:实验前
我们首先创建一些实验前的数据。我们使用的数据生成过程具有以下特点:
-
3 个观察到的协变量与先前销售的近期性(x_recency)、频率(x_frequency)和价值(x_value)相关。
-
1 个未观察到的协变量,用户的月收入(u_income)。

用户生成的图像
- 使用协变量之间的复杂关系来估计我们的目标指标——销售值:

用户生成的图像
下面的 python 代码用于创建实验前的数据:
np.random.seed(123)
n = 10000 # Set number of observations
p = 4 # Set number of pre-experiment covariates
# Create pre-experiment covariates
X = np.random.uniform(size=n * p).reshape((n, -1))
# Nuisance parameters
b = (
1.5 * X[:, 0] +
2.5 * X[:, 1] +
X[:, 2] ** 3 +
X[:, 3] ** 2 +
X[:, 1] * X[:, 2]
)
# Create some noise
noise = np.random.normal(size=n)
# Calculate outcome
y = np.maximum(b + noise, 0)
# Scale variables for interpretation
df_pre = pd.DataFrame({"noise": noise * 1000,
"u_income": X[:, 0] * 1000,
"x_recency": X[:, 1] * 1000,
"x_frequency": X[:, 2] * 1000,
"x_value": X[:, 3] * 1000,
"y_value": y * 1000
})
# Visualise target metric
sns.histplot(df_pre['y_value'], bins=30, kde=False)
plt.xlabel('Sales Value')
plt.ylabel('Frequency')
plt.title('Sales Value')
plt.show()

用户生成的图像
实验的基本构建块:假设检验、效能分析、自助法
在我们进入 CUPED 之前,我认为有必要先讲解一些关于实验的基础知识。
假设检验
假设检验有助于判断实验中观察到的差异是否具有统计学意义,或者仅仅是随机噪音。在我们的实验中,我们将用户分为两组:
-
对照组:不接受推荐。
-
实验组:接受系统的个性化推荐。
我们将假设定义如下:
-
原假设 (H₀):推荐系统对收入没有影响。任何观察到的差异都是由于偶然因素造成的。
-
备择假设 (Hₐ):推荐系统增加了收入。接受推荐的用户相比未接受推荐的用户产生了显著更多的收入。
为了评估假设,你将比较控制组和处理组的平均收入。然而,有几个方面需要注意:
-
第一类错误(假阳性):如果实验得出结论认为推荐系统显著提高了收入,尽管实际上它没有任何效果。
-
第二类错误(Beta,假阴性):如果实验未能发现推荐系统在收入上有显著增加,尽管它实际上确实带来了有意义的增长。
-
显著性水平(Alpha):如果将显著性水平设置为 0.05,则意味着你接受 5%的概率错误地得出推荐系统提高收入的结论,尽管实际上并没有提升(假阳性)。
-
功效(1 — Beta):达到 0.80 的功效意味着在推荐系统确实有作用的情况下,你有 80%的机会检测到收入的显著增长。较高的功效可以降低假阴性的风险。
当你开始思考实验设计时,你会设定一些初步目标:
-
你希望可靠地检测到效应 — 确保你平衡了检测虚假效应与未能检测到真实效应的风险。
-
尽可能快 — 财务部门在催促你!
-
保持样本量尽可能具有成本效益 — 数据科学团队的商业案例表明该系统将大幅增加收入,因此他们不希望控制组的样本量过大。
但是,如何实现这些目标呢?接下来让我们深入探讨功效分析!
功效分析
当我们讨论实验的功效时,通常是指确定在给定置信度下,检测到某一特定大小效应所需的最小样本量的过程。功效分析包括 3 个组件:
-
效应大小 — H₀和 Hₐ的平均值之间的差异。我们通常需要根据对业务/行业的理解,做出合理的假设。
-
显著性水平 — 错误得出存在效应结论的概率,通常设定为 0.05。
-
功效 — 正确检测到效果的概率,通常设定为 0.80。
我最初发现这些直觉很难理解,但可视化真的可以帮助理解。让我们试试吧!关键区域是 H₀和 Hₐ交叉的地方——看看它是否能帮助你将上面讨论的各个组件联系起来……

用户生成的图片
更大的样本量会导致更小的标准误差。标准误差较小时,H₀和 Hₐ的抽样分布变得更窄,重叠部分减少。这种重叠的减少使得检测差异变得更加容易,从而提高了统计功效。
以下函数展示了如何使用 statsmodels Python 包来进行功效分析:
from typing import Union
import pandas as pd
import numpy as np
import statsmodels.stats.power as smp
def power_analysis(metric: Union[np.ndarray, pd.Series], exp_perc_change: float, alpha: float = 0.05, power: float = 0.80) -> int:
'''
Perform a power analysis to determine the minimum sample size required for a given metric.
Args:
metric (np.ndarray or pd.Series): Array or Series containing the metric values for the control group.
exp_perc_change (float): The expected percentage change in the metric for the test group.
alpha (float, optional): The significance level for the test. Defaults to 0.05.
power (float, optional): The desired power of the test. Defaults to 0.80.
Returns:
int: The minimum sample size required for each group to detect the expected percentage change with the specified power and significance level.
Raises:
ValueError: If `metric` is not a NumPy array or pandas Series.
'''
# Validate input types
if not isinstance(metric, (np.ndarray, pd.Series)):
raise ValueError("metric should be a NumPy array or pandas Series.")
# Calculate statistics
control_mean = metric.mean()
control_std = np.std(metric, ddof=1) # Use ddof=1 for sample standard deviation
test_mean = control_mean * (1 + exp_perc_change)
test_std = control_std # Assume the test group has the same standard deviation as the control group
# Calculate (Cohen's D) effect size
mean_diff = control_mean - test_mean
pooled_std = np.sqrt((control_std**2 + test_std**2) / 2)
effect_size = abs(mean_diff / pooled_std) # Cohen's d should be positive
# Run power analysis
power_analysis = smp.TTestIndPower()
sample_size = round(power_analysis.solve_power(effect_size=effect_size, alpha=alpha, power=power))
print(f"Control mean: {round(control_mean, 3)}")
print(f"Control std: {round(control_std, 3)}")
print(f"Min sample size: {sample_size}")
return sample_size
那么,让我们用预实验数据来进行测试!
exp_perc_change = 0.05 # Set the expected percentage change in the chosen metric caused by the treatment
min_sample_size = power_analysis(df_pre["y_value"], exp_perc_change

用户生成图像
我们可以看到,鉴于我们目标指标的分布,要检测出 5%的增加,我们需要的样本量是 1,645。
数据生成过程:实验数据
在你急于设置实验之前,你决定利用实验前的数据来模拟实验。
以下函数随机选择用户进行治疗,并应用治疗效应。在函数的最后,我们记录治疗前后均值的差异以及真实的 ATE(平均治疗效应):
def exp_data_generator(t_perc_change, t_samples):
# Create copy of pre-experiment data ready to manipulate into experiment data
df_exp = df_pre.reset_index(drop=True)
# Calculate the initial treatment effect
treatment_effect = round((df_exp["y_value"] * (t_perc_change)).mean(), 2)
# Create treatment column
treated_indices = np.random.choice(df_exp.index, size=t_samples, replace=False)
df_exp["treatment"] = 0
df_exp.loc[treated_indices, "treatment"] = 1
# treatment effect
df_exp["treatment_effect"] = 0
df_exp.loc[df_exp["treatment"] == 1, "treatment_effect"] = treatment_effect
# Apply treatment effect
df_exp["y_value_exp"] = df_exp["y_value"]
df_exp.loc[df_exp["treatment"] == 1, "y_value_exp"] = df_exp["y_value"] + df_exp["treatment_effect"]
# Calculate mean diff before treatment
mean_t0_pre = df_exp[df_exp["treatment"] == 0]["y_value"].mean()
mean_t1_pre = df_exp[df_exp["treatment"] == 1]["y_value"].mean()
mean_diff_pre = round(mean_t1_pre - mean_t0_pre)
# Calculate mean diff after treatment
mean_t0_post = df_exp[df_exp["treatment"] == 0]["y_value_exp"].mean()
mean_t1_post = df_exp[df_exp["treatment"] == 1]["y_value_exp"].mean()
mean_diff_post = round(mean_t1_post - mean_t0_post)
# Calculate ate
treatment_effect = round(df_exp[df_exp["treatment"]==1]["treatment_effect"].mean())
print(f"Diff-in-means before treatment: {mean_diff_pre}")
print(f"Diff-in-means after treatment: {mean_diff_post}")
print(f"ATE: {treatment_effect}")
return df_exp
我们可以输入之前计算的最小样本量:
np.random.seed(123)
df_exp_1 = exp_data_generator(exp_perc_change, min_sample_size)
让我们首先检查一下我们为接受治疗的用户创建的数据,帮助你理解该函数的作用:

用户生成图像
接下来,让我们看一下该函数打印的结果:

用户生成图像
有趣的是,我们看到在选择需要治疗的用户后,但在我们进行治疗之前,已经存在均值的差异。这个差异是由随机因素造成的。这意味着,当我们查看治疗后的用户差异时,我们并没有正确估计 ATE(平均治疗效应)。当我们讲解 CUPED 时,会回到这一点。

用户生成图像
接下来,我们将探讨一种比仅仅取均值差异更复杂的推断方法……
自助法
自助法是一种强大的统计技术,涉及带有放回的重抽样数据。这些重抽样的数据集被称为自助法样本,帮助我们估计从原始数据中提取统计量(如均值或中位数)的变异性。在实验中,这种方法尤其具有吸引力,因为它使我们能够计算置信区间。让我们通过一个简单的例子一步步演示…
你已经进行了一个实验,控制组和治疗组各有 1k 用户。
-
创建自助法样本 —— 从控制组和治疗组中随机选择(有放回)1k 用户。这为控制组和治疗组各提供一个自助法样本。
-
重复这个过程 n 次(例如 10k 次)。
-
对每一对自助法样本,计算控制组和治疗组之间的均值差异。
-
现在我们有了一个分布(由 10k 个自助法样本的均值差异组成),我们可以使用它来计算置信区间。

用户生成图像
将其应用于我们的案例研究
让我们通过案例研究来说明它是如何工作的。下面我们使用 sciPy stats Python 包来帮助计算自助法置信区间:
from typing import Union
import pandas as pd
import numpy as np
from scipy import stats
def mean_diff(group_a: Union[np.ndarray, pd.Series], group_b: Union[np.ndarray, pd.Series]) -> float:
'''
Calculate the difference in means between two groups.
Args:
group_a (Union[np.ndarray, pd.Series]): The first group of data points.
group_b (Union[np.ndarray, pd.Series]): The second group of data points.
Returns:
float: The difference between the mean of group_a and the mean of group_b.
'''
return np.mean(group_a) - np.mean(group_b)
def bootstrapping(df: pd.DataFrame, adjusted_metric: str, n_resamples: int = 10000) -> np.ndarray:
'''
Perform bootstrap resampling on the adjusted metric of two groups in the dataframe to estimate the mean difference and confidence intervals.
Args:
df (pd.DataFrame): The dataframe containing the data. Must include a 'treatment' column indicating group membership.
adjusted_metric (str): The name of the column in the dataframe representing the metric to be resampled.
n_resamples (int, optional): The number of bootstrap resamples to perform. Defaults to 1000.
Returns:
np.ndarray: The array of bootstrap resampled mean differences.
'''
# Separate the data into two groups based on the 'treatment' column
group_a = df[df["treatment"] == 1][adjusted_metric]
group_b = df[df["treatment"] == 0][adjusted_metric]
# Perform bootstrap resampling
res = stats.bootstrap((group_a, group_b), statistic=mean_diff, n_resamples=n_resamples, method='percentile')
ci = res.confidence_interval
# Extract the bootstrap distribution and confidence intervals
bootstrap_means = res.bootstrap_distribution
bootstrap_ci_lb = round(ci.low,)
bootstrap_ci_ub = round(ci.high)
bootstrap_mean = round(np.mean(bootstrap_means))
print(f"Bootstrap confidence interval lower bound: {bootstrap_ci_lb}")
print(f"Bootstrap confidence interval upper bound: {bootstrap_ci_ub}")
print(f"Bootstrap mean diff: {bootstrap_mean}")
return bootstrap_means
当我们为我们的案例研究数据运行这个方法时,我们可以看到我们现在有了一些置信区间:
bootstrap_og_1 = bootstrapping(df_exp_1, "y_value_exp")

用户生成图像
我们的真实 ATE(平均处理效应)是 143(来自我们实验数据生成函数的实际处理效应),该值位于我们的置信区间内。然而,值得注意的是,均值差异并没有发生变化(它仍然是 93,就像我们在仅计算控制组和处理组均值差异时一样),并且处理前的差异仍然存在。
那么,如果我们想要得出更窄的置信区间怎么办?有没有办法处理处理前的差异呢?这将我们引导到 CUPED……
什么是 CUPED,它如何帮助提升实验效能?
背景
CUPED(使用实验前数据的控制实验)是一种强大的技术,旨在通过微软研究人员开发的技术来提高实验的准确性。原始论文对于任何对实验感兴趣的人来说,都是一篇具有洞察力的阅读材料:
ai.stanford.edu/~ronnyk/2009controlledExperimentsOnTheWebSurvey.pdf
CUPED 的核心思想是使用实验开始前收集的数据来减少目标指标的方差。通过这样做,你可以使实验更敏感,从而获得两个主要好处:
-
你可以在相同的样本量下检测到更小的效果。
-
你可以在更小的样本量下检测到相同的效果。
可以把它想象成去除“背景噪声”,这样你可以更清晰地看到“信号”。
方差、标准差、标准误差
当你了解 CUPED 时,可能会听到人们谈论它如何减少方差、标准差或标准误。如果你像我一样,可能会忘记这些概念之间的关系,因此在我们继续深入之前,让我们回顾一下这些概念!
-
方差:方差衡量每个数据点与均值的平方偏差的平均值,反映了数据集内的整体分布或离散程度。
-
标准差:标准差是方差的平方根,表示每个数据点与均值的平均距离,并提供了一个更易于理解的离散度量。
-
标准误差:标准误差量化了样本均值作为总体均值估计的精度,计算方法是将标准差除以样本大小的平方根。
CUPED 是如何工作的?
为了理解 CUPED 是如何工作的,我们先来分解一下……
实验前协变量 — 在 CUPED 的最简单实现中,实验前协变量是实验开始前的一段时间内测量的目标指标。因此,如果你的目标指标是销售额,那么你的协变量可以是每个客户在实验前 4 周的销售额。
你的协变量与目标指标相关,并且不受处理的影响,这一点非常重要。这也是为什么我们通常使用控制组的处理前数据。
回归调整 — 使用线性回归来建模协变量(实验前测量)与目标指标(实验期间测量)之间的关系。然后,我们可以通过去除协变量的影响来计算 CUPED 调整后的目标指标:

用户生成的图片
值得注意的是,去除协变量的均值是为了使结果变量围绕均值进行中心化,这样当与原始目标指标进行比较时,更容易解释。
方差减少 — 经过回归调整后,我们目标指标的方差减少了。方差较小意味着控制组和实验组之间的差异更容易被检测到,从而增加了实验的统计效能。
将其应用到我们的案例研究中
让我们通过案例研究来说明它是如何工作的。下面我们将 CUPED 编写成一个函数:
from typing import Union
import pandas as pd
import numpy as np
import statsmodels.api as sm
def cuped(df: pd.DataFrame, pre_covariates: Union[str, list], target_metric: str) -> pd.Series:
'''
Implements the CUPED (Controlled Experiments Using Pre-Experiment Data) technique to adjust the target metric
by removing predictable variation using pre-experiment covariates. This reduces the variance of the metric and
increases the statistical power of the experiment.
Args:
df (pd.DataFrame): The input DataFrame containing both the pre-experiment covariates and the target metric.
pre_covariates (Union[str, list]): The column name(s) in the DataFrame corresponding to the pre-experiment covariates used for the adjustment.
target_metric (str): The column name in the DataFrame representing the metric to be adjusted.
Returns:
pd.Series: A pandas Series containing the CUPED-adjusted target metric.
'''
# Fit control model using pre-experiment covariates
control_group = df[df['treatment'] == 0]
X_control = control_group[pre_covariates]
X_control = sm.add_constant(X_control)
y_control = control_group[target_metric]
model_control = sm.OLS(y_control, X_control).fit()
# Compute residuals and adjust target metric
X_all = df[pre_covariates]
X_all = sm.add_constant(X_all)
residuals = df[target_metric].to_numpy().flatten() - model_control.predict(X_all)
adjustment_term = model_control.params['const'] + sum(model_control.params[covariate] * df[pre_covariates].mean()[covariate] for covariate in pre_covariates)
adjusted_target = residuals + adjustment_term
return adjusted_target
当我们将其应用到我们的案例研究数据中,并将调整后的目标指标与原始目标指标进行比较时,我们看到方差已减少:
# Apply CUPED
pre_covariates = ["x_recency", "x_frequency", "x_value"]
target_metric = ["y_value_exp"]
df_exp_1["adjusted_target"] = cuped(df_exp_1, pre_covariates, target_metric)
# Plot results
plt.figure(figsize=(10, 6))
sns.kdeplot(data=df_exp_1[df_exp_1['treatment'] == 0], x="adjusted_target", hue="treatment", fill=True, palette="Set1", label="Adjusted Value")
sns.kdeplot(data=df_exp_1[df_exp_1['treatment'] == 0], x="y_value_exp", hue="treatment", fill=True, palette="Set2", label="Original Value")
plt.title(f"Distribution of Value by Original vs CUPED")
plt.xlabel("Value")
plt.ylabel("Density")
plt.legend(title="Distribution")

用户生成的图片
它是否减少了标准误差?
现在我们已经应用了 CUPED 并减少了方差,让我们运行引导法函数,看看它产生了什么影响:
bootstrap_cuped_1 = bootstrapping(df_exp_1, "adjusted_target")

用户生成的图片
如果将其与我们之前使用原始目标指标的结果进行比较,你会发现置信区间更窄了:
bootstrap_1 = pd.DataFrame({
'original': bootstrap_og_1,
'cuped': bootstrap_cuped_1
})
# Plot the KDE plots
plt.figure(figsize=(10, 6))
sns.kdeplot(bootstrap_1['original'], fill=True, label='Original', color='blue')
sns.kdeplot(bootstrap_1['cuped'], fill=True, label='CUPED', color='orange')
# Add mean lines
plt.axvline(bootstrap_1['original'].mean(), color='blue', linestyle='--', linewidth=1)
plt.axvline(bootstrap_1['cuped'].mean(), color='orange', linestyle='--', linewidth=1)
plt.axvline(round(df_exp_1[df_exp_1["treatment"]==1]["treatment_effect"].mean(), 3), color='green', linestyle='--', linewidth=1, label='Treatment effect')
# Customize the plot
plt.title('Distribution of Value by Original vs CUPED')
plt.xlabel('Value')
plt.ylabel('Density')
plt.legend()
# Show the plot
plt.show()

用户生成的图片
引导法(bootstrap)均值差异也逐渐接近真实的处理效应。这是因为 CUPED 对处理控制组和实验组之间的先前差异非常有效。
它是否减少了最小样本量?
下一个问题是,它是否减少了我们所需的最小样本量。让我们来看看吧!
treatment_effect_1 = round(df_exp_1[df_exp_1["treatment"]==1]["treatment_effect"].mean(), 2)
cuped_sample_size = power_analysis(df_exp_1[df_exp_1['treatment'] == 0]['adjusted_target'], treatment_effect_1 / df_exp_1[df_exp_1['treatment'] == 0]['adjusted_target'].mean())

用户生成的图片
所需的最小样本量从 1,645 减少到了 901。无论是财务团队还是数据科学团队都将非常高兴,因为我们可以在较短的时间内用更小的控制样本来进行实验!
CUPED 和双重机器学习之间有哪些概念上的相似之处?
背景
当我第一次读到 CUPED 时,我就想到了双重机器学习及其相似之处。如果你不熟悉双重机器学习,查看我在系列文章中的早期内容:
因果 AI,探索因果推理与机器学习的结合
towardsdatascience.com
请注意双重机器学习中的第一阶段结果模型:
- 结果模型(去噪): 用于仅使用控制特征估计结果的机器学习模型。然后计算结果模型的残差。
这在概念上与我们使用 CUPED 的做法非常相似!
它与 CUPED 相比如何?
让我们通过案例研究数据,看看是否得到类似的结果:
# Train DML model
dml = LinearDML(discrete_treatment=False)
dml.fit(df_exp_1[target_metric].to_numpy().ravel(), T=df_exp_1['treatment'].to_numpy().ravel(), X=df_exp_1[pre_covariates], W=None)
ate_dml = round(dml.ate(df_exp_1[pre_covariates]))
ate_dml_lb = round(dml.ate_interval(df_exp_1[pre_covariates])[0])
ate_dml_ub = round(dml.ate_interval(df_exp_1[pre_covariates])[1])
print(f'DML confidence interval lower bound: {ate_dml_lb}')
print(f'DML confidence interval upper bound: {ate_dml_ub}')
print(f'DML ate: {ate_dml}')

用户生成的图像
我们得到几乎完全相同的结果!
当我们绘制残差时,可以看到方差像在 CUPED 中一样减少(尽管我们没有添加均值来进行缩放以便于解释):
# Fit model outcome model using pre-experiment covariates
X_all = df_exp_1[pre_covariates]
X_all = sm.add_constant(X)
y_all = df_exp_1[target_metric]
outcome_model = sm.OLS(y_all, X_all).fit()
# Compute residuals and adjust target metric
df_exp_1['outcome_residuals'] = df_exp_1[target_metric].to_numpy().flatten() - outcome_model.predict(X_all)
# Plot results
plt.figure(figsize=(10, 6))
sns.kdeplot(data=df_exp_1[df_exp_1['treatment'] == 0], x="outcome_residuals", hue="treatment", fill=True, palette="Set1", label="Adjusted Target")
sns.kdeplot(data=df_exp_1[df_exp_1['treatment'] == 0], x="y_value_exp", hue="treatment", fill=True, palette="Set2", label="Original Value")
plt.title(f"Distribution of Value by Original vs DML")
plt.xlabel("Value")
plt.ylabel("Density")
plt.legend(title="Distribution")
plt.show()

用户生成的图像
“那又怎么样?”我听到你问!
首先,我认为这对任何使用双重机器学习的人来说是一个有趣的观察——第一阶段结果模型有助于减少方差,因此我们应该能够获得类似 CUPED 的好处。
其次,它提出了一个问题:每种方法什么时候合适?让我们通过回答这个问题来结束本部分…
我们什么时候应该使用双重机器学习而不是 CUPED?
有几个原因说明为什么倾向于使用 CUPED 可能是合理的:
-
这更容易理解。
-
它更易于实现。
-
这是一个模型,而不是三个模型,这意味着你会面临更少的过拟合挑战。
然而,也有一些例外情况,双重机器学习优于 CUPED:
- 偏倚的处理分配——当处理分配存在偏倚时,例如使用观察数据时,双重机器学习可以处理这个问题。我的上一篇文章深入探讨了这一点:
因果 AI,探索因果推理与机器学习的结合
towardsdatascience.com](/de-biasing-treatment-effects-with-double-machine-learning-63b16fcb3e97?source=post_page-----34dc2f3d3284--------------------------------)
- 异质治疗效果——当你想了解个体层面的效果时,例如找出哪些人值得发送折扣,双重机器学习可以提供帮助。在我之前关于优化治疗策略的文章中,有一个很好的案例研究说明了这一点:
因果 AI,探索因果推理与机器学习的结合
towardsdatascience.com](/using-double-machine-learning-and-linear-programming-to-optimise-treatment-strategies-920c20a29553?source=post_page-----34dc2f3d3284--------------------------------)
最后的思考
今天我们进行了快速浏览实验的过程,涵盖了假设检验、功效分析和自助法(bootstrapping)。然后,我们探讨了 CUPED 如何减少标准误差并提高实验的功效。最后,我们提到了它与双重机器学习(double machine learning)的相似性,并讨论了何时应该使用每种方法。关于 CUPED,还有一些额外的关键点值得一提:
-
我们不必使用线性回归——如果我们有多个协变量,其中一些具有非线性关系,我们可以使用机器学习技术,例如提升方法(boosting)。
-
如果我们选择使用机器学习技术,我们需要确保不会对数据进行过拟合。
-
何时运行 CUPED 需要仔细思考——你是打算在实验开始前运行 CUPED,然后进行功效分析以确定减少后的样本量?还是打算在实验结束后运行它以减少标准误差?
如果你想继续深入了解因果 AI,请跟随我——在下一篇文章中,我们将探讨多重共线性是否在市场营销组合模型中影响你的因果推断!
产品分析师的实用计算机模拟
第二部分:使用自助法进行观测和 A/B 测试
·发表于 Towards Data Science ·21 分钟阅读·2024 年 4 月 30 日
--

图片来自 DALL-E 3
在系列的第一部分中,我们讨论了计算机模拟的基本思想,以及如何利用它们来回答“如果……会怎样”问题。在谈论模拟时,无法不提自助法。
统计学中的自助法是一种实用的计算机方法,用于估计概率分布的统计量。它基于通过蒙特卡罗方法从现有样本中反复生成样本。这种方法可以简单快捷地估计复杂模型的各种统计量(如置信区间、方差、相关性等)。
当我在统计学课程中学习自助法时,感觉有点像是黑科技。你不需要学习多个公式和不同情况下的标准,只需写几行代码,就能为任何自定义和复杂的用例获得置信区间估计。这听起来就像魔法一样。
这确实是事实。现在,甚至你的笔记本电脑都能在几分钟甚至几秒钟内运行成千上万次模拟,自助法是你分析工具包中一个强大的工具,可以帮助你在许多情况下解决问题。因此,我认为学习或刷新你对它的理解是值得的。
在本文中,我们将讨论自助法背后的思想,了解何时应该使用它,学习如何为不同的度量获取置信区间,并分析 A/B 测试的结果。
什么是自助法(bootstrap)?
事实上,引导法非常简单。我们需要通过有放回地从样本分布中抽取元素来进行模拟,然后基于这个分布得出结论。
让我们看一个简单的例子,假设我们有四个元素:1、2、3 和 4。然后,我们可以模拟许多其他包含 4 个元素的集合,每个元素可能是 1、2、3 或 4,且每个元素的概率相等,并使用这些模拟来了解例如均值如何变化。

引导法的统计意义在于,我们认为实际总体与我们的样本具有完全相同的分布(或者总体由无数个我们的样本副本组成)。然后,我们假设我们了解总体,并利用它来理解数据中的变异性。
通常,在使用经典统计方法时,我们假设变量遵循某种已知分布(例如正态分布)。然而,在引导法(Bootstrap)中,我们不需要对分布的性质做任何假设。它非常方便,甚至可以帮助分析非常复杂的自定义指标。
几乎不可能搞错引导法估计。所以,在很多情况下,我更倾向于使用它而不是经典统计方法。唯一的缺点是计算时间。如果你在处理大数据时,模拟可能需要几个小时,而经典统计方法的估计则可以在几秒钟内完成。
然而,有些情况下,如果没有引导法,获得估计值会非常具有挑战性。让我们讨论一下引导法的最佳使用场景:
-
如果你的数据中有异常值或有影响力的数据点;
-
如果你的样本相对较小(大约少于 100 个案例);
-
如果你的数据分布与正态分布或其他理论分布差距较大,例如,它有多个峰值;
-
如果你在处理自定义指标(例如,在 SLA 内完成的案件占比或百分位数)。
引导法是一个奇妙且强大的统计概念。我们来尝试将它用于描述性统计。
处理观测数据
首先,让我们从观测数据开始,并使用一个合成数据集。假设我们正在帮助一家健身俱乐部设立一个新的健身计划,帮助客户为伦敦马拉松做准备。我们得到的第一组试验组包含 12 名客户,并测量了他们的结果。
这是我们拥有的数据。

我们为每个 12 名客户收集了三个字段:
-
races_before— 客户在我们的计划之前参加的比赛次数, -
kms_during_program— 客户在我们计划中跑步的公里数, -
finished_marathon— 该计划是否成功,客户是否完成了伦敦马拉松。
我们的目标是建立一个以目标为导向的公平项目,激励我们的客户更多地与我们一起训练,并取得更好的成绩。因此,我们希望在客户在准备过程中跑了至少 150 公里但未能完成马拉松时退还他们的费用。然而,在启动这个项目之前,我们想做一些估算:客户在准备期间跑了多远,以及预计退款的比例。我们需要这些数据来确保我们的业务盈利且可持续。
估计平均值
让我们从估计平均距离开始。我们可以尝试利用我们在数学统计学方面的知识,并使用置信区间的公式。
为此,我们需要对这个变量的分布做出假设。最常用的是正态分布。让我们试试。
import numpy as np
from scipy.stats import norm, t
def get_normal_confidence_interval(data, confidence=0.95):
# Calculate sample mean and standard deviation
sample_mean = np.mean(data)
sample_std = np.std(data, ddof=1)
n = len(data)
# Calculate the critical value (z) based on the confidence level
z = norm.ppf((1 + confidence) / 2)
# Calculate the margin of error using standard error
margin_of_error = z * sample_std / np.sqrt(n)
# Calculate the confidence interval
lower_bound = sample_mean - margin_of_error
upper_bound = sample_mean + margin_of_error
return lower_bound, upper_bound
get_normal_confidence_interval(df.kms_during_program.values)
# (111.86, 260.55)
另一种通常与真实数据一起使用的分布是 t 检验分布,它给出一个更宽的置信区间(因为它假设比正态分布更胖的尾部)。
def get_ttest_confidence_interval(data, confidence=0.95):
# Calculate sample mean and standard deviation
sample_mean = np.mean(data)
sample_std = np.std(data, ddof=1)
n = len(data)
# Calculate the critical value (z) based on the confidence level
z = t.ppf((1 + confidence) / 2, df=len(data) - 1)
# Calculate the margin of error using standard error
margin_of_error = z * sample_std / np.sqrt(n)
# Calculate the confidence interval
lower_bound = sample_mean - margin_of_error
upper_bound = sample_mean + margin_of_error
return lower_bound, upper_bound
get_ttest_confidence_interval(df.kms_during_program.values)
# (102.72, 269.69)
我们的样本中有一些例子。另外,还有一个离群值:一位拥有 12 场比赛经验的客户,成功跑了近 600 公里为马拉松做准备,而其他大多数客户的跑步距离不到 200 公里。

因此,这是一个使用自助法技术来更好地理解分布和置信区间的绝佳案例。
让我们创建一个函数来计算并可视化置信区间:
-
我们运行
num_batches次模拟,进行有放回的抽样,并计算平均距离。 -
然后,基于这些变量,我们可以得到一个 95%的置信区间:该分布的 2.5%和 97.5%百分位数。
-
最后,我们可以在图表上可视化分布。
import tqdm
import matplotlib.pyplot as plt
def get_kms_confidence_interval(num_batches, confidence = 0.95):
# Running simulations
tmp = []
for i in tqdm.tqdm(range(num_batches)):
tmp_df = df.sample(df.shape[0], replace = True)
tmp.append(
{
'iteration': i,
'mean_kms': tmp_df.kms_during_program.mean()
}
)
# Saving data
bootstrap_df = pd.DataFrame(tmp)
# Calculating confidence interval
lower_bound = bootstrap_df.mean_kms.quantile((1 - confidence)/2)
upper_bound = bootstrap_df.mean_kms.quantile(1 - (1 - confidence)/2)
# Creating a chart
ax = bootstrap_df.mean_kms.hist(bins = 50, alpha = 0.6,
color = 'purple')
ax.set_title('Average kms during the program, iterations = %d' % num_batches)
plt.axvline(x=lower_bound, color='navy', linestyle='--',
label='lower bound = %.2f' % lower_bound)
plt.axvline(x=upper_bound, color='navy', linestyle='--',
label='upper bound = %.2f' % upper_bound)
ax.annotate('CI lower bound: %.2f' % lower_bound,
xy=(lower_bound, ax.get_ylim()[1]),
xytext=(-10, -20),
textcoords='offset points',
ha='center', va='top',
color='navy', rotation=90)
ax.annotate('CI upper bound: %.2f' % upper_bound,
xy=(upper_bound, ax.get_ylim()[1]),
xytext=(-10, -20),
textcoords='offset points',
ha='center', va='top',
color='navy', rotation=90)
plt.xlim(ax.get_xlim()[0] - 20, ax.get_xlim()[1] + 20)
plt.show()
让我们从少量的批次开始,以便快速查看初步结果。
get_kms_confidence_interval(100)

使用自助法,我们得到了一个稍微窄一些并偏向右侧的置信区间,这与我们的实际分布相符:(139.31, 297.99) 与 (102.72, 269.69)。
然而,经过 100 次自助法模拟后,分布并不十分清晰。让我们尝试增加更多的迭代次数。我们可以看到我们的分布包含了多个模态——对于含有一个离群值、两个离群值、三个离群值等的样本。

随着更多迭代的进行,我们可以看到更多的模态(因为离群值的出现次数更少),但所有的置信区间都非常接近。
在自助法的情况下,增加更多的迭代次数并不会导致过拟合(因为每次迭代是独立的)。我会把它想象成是增加图像的分辨率。
由于我们的样本量较小,运行大量模拟不会花费太多时间。即使是 100 万次自助法迭代,也只需要大约 1 分钟。
估计自定义指标
正如我们讨论的,使用自助法在处理不像平均值那么直观的指标时非常有用。例如,你可能想要估计中位数或在 SLA 内完成的任务比例。
你甚至可能用 bootstrap 做一些更不寻常的事情。假设你想给客户提供折扣,如果你的配送延迟了:延迟 15 分钟给予 5%的折扣,延迟 1 小时给予 10%的折扣,延迟 3 小时给予 20%的折扣。
理论上,使用简单的统计方法来获取这种情况下的置信区间可能会很有挑战性,因此 bootstrap 将非常有价值。
让我们回到我们的跑步程序,估算退款的比例(当一个客户跑了 150 公里却没能完成马拉松)。我们将使用类似的函数,但将计算每次迭代的退款比例,而不是均值。
import tqdm
import matplotlib.pyplot as plt
def get_refund_share_confidence_interval(num_batches, confidence = 0.95):
# Running simulations
tmp = []
for i in tqdm.tqdm(range(num_batches)):
tmp_df = df.sample(df.shape[0], replace = True)
tmp_df['refund'] = list(map(
lambda kms, passed: 1 if (kms >= 150) and (passed == 0) else 0,
tmp_df.kms_during_program,
tmp_df.finished_marathon
))
tmp.append(
{
'iteration': i,
'refund_share': tmp_df.refund.mean()
}
)
# Saving data
bootstrap_df = pd.DataFrame(tmp)
# Calculating confident interval
lower_bound = bootstrap_df.refund_share.quantile((1 - confidence)/2)
upper_bound = bootstrap_df.refund_share.quantile(1 - (1 - confidence)/2)
# Creating a chart
ax = bootstrap_df.refund_share.hist(bins = 50, alpha = 0.6,
color = 'purple')
ax.set_title('Share of refunds, iterations = %d' % num_batches)
plt.axvline(x=lower_bound, color='navy', linestyle='--',
label='lower bound = %.2f' % lower_bound)
plt.axvline(x=upper_bound, color='navy', linestyle='--',
label='upper bound = %.2f' % upper_bound)
ax.annotate('CI lower bound: %.2f' % lower_bound,
xy=(lower_bound, ax.get_ylim()[1]),
xytext=(-10, -20),
textcoords='offset points',
ha='center', va='top',
color='navy', rotation=90)
ax.annotate('CI upper bound: %.2f' % upper_bound,
xy=(upper_bound, ax.get_ylim()[1]),
xytext=(-10, -20),
textcoords='offset points',
ha='center', va='top',
color='navy', rotation=90)
plt.xlim(-0.1, 1)
plt.show()
即便只有 12 个样本,我们也获得了一个小于 2 倍的置信区间。我们可以以 95%的置信度得出结论,少于 42%的客户将有资格获得退款。

这是一个很好的结果,考虑到数据量如此之小。然而,我们还可以更进一步,尝试对因果效应进行估算。
效果估算
我们有关于这场马拉松之前的几场比赛的数据,并且可以看到这个值与预期距离之间的相关性。我们也可以使用 bootstrap 来处理这个问题。我们只需要在当前过程中加入线性回归步骤。
def get_races_coef_confidence_interval(num_batches, confidence = 0.95):
# Running simulations
tmp = []
for i in tqdm.tqdm(range(num_batches)):
tmp_df = df.sample(df.shape[0], replace = True)
# Linear regression model
model = smf.ols('kms_during_program ~ races_before', data = tmp_df).fit()
tmp.append(
{
'iteration': i,
'races_coef': model.params['races_before']
}
)
# Saving data
bootstrap_df = pd.DataFrame(tmp)
# Calculating confident interval
lower_bound = bootstrap_df.races_coef.quantile((1 - confidence)/2)
upper_bound = bootstrap_df.races_coef.quantile(1 - (1 - confidence)/2)
# Creating a chart
ax = bootstrap_df.races_coef.hist(bins = 50, alpha = 0.6, color = 'purple')
ax.set_title('Coefficient between kms during the program and previous races, iterations = %d' % num_batches)
plt.axvline(x=lower_bound, color='navy', linestyle='--', label='lower bound = %.2f' % lower_bound)
plt.axvline(x=upper_bound, color='navy', linestyle='--', label='upper bound = %.2f' % upper_bound)
ax.annotate('CI lower bound: %.2f' % lower_bound,
xy=(lower_bound, ax.get_ylim()[1]),
xytext=(-10, -20),
textcoords='offset points',
ha='center', va='top',
color='navy', rotation=90)
ax.annotate('CI upper bound: %.2f' % upper_bound,
xy=(upper_bound, ax.get_ylim()[1]),
xytext=(10, -20),
textcoords='offset points',
ha='center', va='top',
color='navy', rotation=90)
# plt.legend()
plt.xlim(ax.get_xlim()[0] - 5, ax.get_xlim()[1] + 5)
plt.show()
return bootstrap_df
我们可以查看分布。由于置信区间大于 0,因此我们可以以 95%的置信度说存在某种效应。

你可以发现分布是双峰的,每个峰值对应着一种情形:
-
12 周围的组件与没有异常值的样本有关——它是对前几场比赛对程序中预期距离的影响的估算,如果我们忽略掉异常值的话。
-
第二个组件对应的是数据集中存在一个或多个异常值的样本。
所以,通过查看 bootstrap 分布,能对不同情境做出估算真的很酷。
我们已经学会了如何使用 bootstrap 与观测数据,但它的核心用途是 A/B 测试。因此,接下来让我们继续看第二个例子。
A/B 测试的模拟
bootstrap 的另一个常见应用是设计和分析 A/B 测试。让我们看一个例子。它也将基于一个合成数据集,展示折扣对客户留存的影响。假设我们正在做一个电子杂货产品,想要测试我们推出的 20 欧元折扣营销活动是否会影响客户的消费。
关于每个客户,我们知道他们的居住国家、与他们同住的家庭成员数量、该国的平均年薪以及他们在我们商店里消费的金额。

功效分析
首先,我们需要设计实验,并且理解每个实验组中需要多少客户,以便我们可以有信心地得出结论。这个步骤叫做功效分析。
让我们快速回顾一下 A/B 测试的基本统计理论和主要指标。每个测试都基于原假设(即当前现状)。在我们的案例中,原假设是“折扣不影响客户对我们产品的消费”。然后,我们需要收集控制组和实验组的客户消费数据,并估算在原假设成立的情况下看到这样或更极端结果的概率。这个概率叫做 p 值,如果 p 值足够小,我们可以得出结论,认为我们有足够的数据来拒绝原假设,进而说明处理对客户的消费或留存产生了影响。
在这种方法中,有三个主要指标:
-
效应大小——我们希望能够检测到的度量的最小变化,
-
统计显著性等于假阳性率(拒绝原假设时实际没有效应的概率)。最常用的显著性水平是 5%。然而,根据你的假阳性容忍度,你可能会选择其他值。例如,如果实施变更的成本较高,你可能希望使用较低的显著性阈值。
-
统计功效表示在实际存在与效应大小相等或更大的效应时,拒绝原假设的概率。人们通常使用 80%的阈值,但在某些情况下(即你希望更有信心地排除负面效应),你可能会选择 90%甚至 99%。
我们需要所有这些数值来估算实验中的客户数量。让我们尝试在我们的案例中定义这些数值,以便更好地理解它们的含义。
我们将从效应大小开始:
-
我们预计通过我们的活动,客户的留存率将至少变化 3 个百分点,
-
我们希望能够发现客户消费变化达到 20 欧元或更多。
对于统计显著性,我将使用默认的 5%阈值(这样,如果我们在 A/B 测试分析中看到效应,我们可以 95%的信心确认效应存在)。我们将目标设定为 90%的统计功效阈值,以便如果实际效应等于或大于效应大小,我们将在 90%的情况下发现这种变化。
让我们从统计公式开始,这将帮助我们快速获得估算值。统计公式假设我们的变量具有特定的分布,但通常它们可以帮助你估算样本数量的大小。稍后,我们将使用自助法(bootstrap)获得更准确的结果。
对于留存率,我们可以使用标准的比例检验。我们需要知道实际值以估算标准化效应大小。我们可以通过实验前的历史数据获得这个值。
import statsmodels.stats.power as stat_power
import statsmodels.stats.proportion as stat_prop
base_retention = before_df.retention.mean()
ret_effect_size = stat_prop.proportion_effectsize(base_retention + 0.03,
base_retention)
sample_size = 2*stat_power.tt_ind_solve_power(
effect_size = ret_effect_size,
alpha = 0.05, power = 0.9,
nobs1 = None, # we specified nobs1 as None to get an estimation for it
alternative='larger'
)
# ret_effect_size = 0.0632, sample_size = 8573.86
我们使用单边检验,因为从商业角度来看,是否存在负效应或没有效应并无区别,因为我们不会实施此变更。使用单边检验而非双边检验可以提高统计功效。
我们可以类似地估算客户价值的样本大小,假设正态分布。然而,实际分布并不是正态的,因此我们应该期待从自助法(bootstrap)中获得更精确的结果。

让我们写代码。
val_effect_size = 20/before_df.customer_value.std()
sample_size = 2*stat_power.tt_ind_solve_power(
effect_size = val_effect_size,
alpha = 0.05, power = 0.9,
nobs1 = None,
alternative='larger'
)
# val_effect_size = 0.0527, sample_size = 12324.13
我们得到了每个测试所需的样本大小的估算。然而,也有一些情况是你拥有有限数量的客户,并且希望了解你可以获得的统计效能。
假设我们只有 5000 个客户(每组 2500 个)。那么,对于留存分析,我们将能够实现 72.2%的统计效能,对于客户价值分析则是 58.7%(假设期望的统计显著性和效应大小)。
唯一的区别是这次我们指定了nobs1 = 2500,并将power保留为None。
stat_power.tt_ind_solve_power(
effect_size = ret_effect_size,
alpha = 0.05, power = None,
nobs1 = 2500,
alternative='larger'
)
# 0.7223
stat_power.tt_ind_solve_power(
effect_size = val_effect_size,
alpha = 0.05, power = None,
nobs1 = 2500,
alternative='larger'
)
# 0.5867
现在,是时候使用自助法进行效能分析了,我们将从客户价值测试开始,因为它更容易实现。

让我们讨论一下使用自助法进行效能分析的基本思路和步骤。首先,我们需要明确我们的目标。我们希望估算统计效能与样本大小的关系。如果用更实际的术语来说,我们希望知道在客户消费增加 20 欧元或以上的情况下,我们能够拒绝原假设并实施这一变化的百分比。因此,我们需要模拟一系列这样的实验,并计算出在多少情况下,我们能够看到指标的统计显著变化。
让我们来看一个实验,并将其分解为步骤。第一步是生成实验数据。为此,我们需要从总体中随机抽取一个与样本大小相等的子集,随机将这些客户分配到控制组和实验组,并为处理组添加一个与效果大小相等的效应。所有这些逻辑都在下面的get_sample_for_value函数中实现。
def get_sample_for_value(pop_df, sample_size, effect_size):
# getting sample of needed size
sample_df = pop_df.sample(sample_size)
# randomly assign treatment
sample_df['treatment'] = sample_df.index.map(
lambda x: 1 if np.random.uniform() > 0.5 else 0)
# add efffect for the treatment group
sample_df['predicted_value'] = sample_df['customer_value'] \
+ effect_size * sample_df.treatment
return sample_df
现在,我们可以像通常进行 A/B 测试分析那样处理这个合成实验数据,运行一系列自助法模拟,估算效果,然后为这个效果获取一个置信区间。
我们将使用线性回归来估算处理效果。如上一篇文章中讨论的那样,值得将能够解释结果变量(客户消费)的特征添加到线性回归中。我们将把家庭成员数量和平均工资加入回归中,因为它们与客户消费正相关。
import statsmodels.formula.api as smf
val_model = smf.ols('customer_value ~ num_family_members + country_avg_annual_earning',
data = before_df).fit(disp = 0)
val_model.summary().tables[1]

我们将把进行多个自助法模拟和估算处理效果的所有逻辑放入get_ci_for_value函数中。
def get_ci_for_value(df, boot_iters, confidence_level):
tmp_data = []
for iter in range(boot_iters):
sample_df = df.sample(df.shape[0], replace = True)
val_model = smf.ols('predicted_value ~ treatment + num_family_members + country_avg_annual_earning',
data = sample_df).fit(disp = 0)
tmp_data.append(
{
'iteration': iter,
'coef': val_model.params['treatment']
}
)
coef_df = pd.DataFrame(tmp_data)
return coef_df.coef.quantile((1 - confidence_level)/2),
coef_df.coef.quantile(1 - (1 - confidence_level)/2)
下一步是将这些逻辑整合在一起,运行一系列这样的合成实验,并保存结果。
def run_simulations_for_value(pop_df, sample_size, effect_size,
boot_iters, confidence_level, num_simulations):
tmp_data = []
for sim in tqdm.tqdm(range(num_simulations)):
sample_df = get_sample_for_value(pop_df, sample_size, effect_size)
num_users_treatment = sample_df[sample_df.treatment == 1].shape[0]
value_treatment = sample_df[sample_df.treatment == 1].predicted_value.mean()
num_users_control = sample_df[sample_df.treatment == 0].shape[0]
value_control = sample_df[sample_df.treatment == 0].predicted_value.mean()
ci_lower, ci_upper = get_ci_for_value(sample_df, boot_iters, confidence_level)
tmp_data.append(
{
'experiment_id': sim,
'num_users_treatment': num_users_treatment,
'value_treatment': value_treatment,
'num_users_control': num_users_control,
'value_control': value_control,
'sample_size': sample_size,
'effect_size': effect_size,
'boot_iters': boot_iters,
'confidence_level': confidence_level,
'ci_lower': ci_lower,
'ci_upper': ci_upper
}
)
return pd.DataFrame(tmp_data)
让我们为sample_size = 100运行这个模拟,并查看结果。
val_sim_df = run_simulations_for_value(before_df, sample_size = 100,
effect_size = 20, boot_iters = 1000, confidence_level = 0.95,
num_simulations = 20)
val_sim_df.set_index('simulation')[['sample_size', 'ci_lower', 'ci_upper']].head()
我们得到了 20 个模拟实验的数据。我们知道每个实验的置信区间,现在我们可以估计功效。

如果置信区间的下限大于零,我们将拒绝零假设,所以让我们计算这些实验的比例。
val_sim_df['successful_experiment'] = val_sim_df.ci_lower.map(
lambda x: 1 if x > 0 else 0)
val_sim_df.groupby(['sample_size', 'effect_size']).aggregate(
{
'successful_experiment': 'mean',
'experiment_id': 'count'
}
)

我们从仅 20 个模拟实验和 1000 个自助法模拟开始,以估计其置信区间。如此少的模拟可以帮助我们迅速获得一个低分辨率的图像。考虑到我们从经典统计中得到的估计,我们应该预期大约 10K 的样本量能够提供所需的统计功效。
tmp_dfs = []
for sample_size in [100, 250, 500, 1000, 2500, 5000, 10000, 25000]:
print('Simulation for sample size = %d' % sample_size)
tmp_dfs.append(
run_simulations_for_value(before_df, sample_size = sample_size, effect_size = 20,
boot_iters = 1000, confidence_level = 0.95, num_simulations = 20)
)
val_lowres_sim_df = pd.concat(tmp_dfs)
我们得到了与理论估算相似的结果。让我们尝试使用更多的模拟实验(100 次和 500 次实验)来运行估算。我们可以看到,12.5K 的客户将足以达到 90%的统计功效。
我已将所有功效分析结果添加到图表中,这样我们就可以清晰地看到关系。

在这种情况下,您可能已经看到,自助法可能需要大量的时间。例如,仅仅为了对 3 个样本量进行 500 次实验模拟来准确估计功效,我花了将近 2 小时。

现在,我们可以估计对于 12.5K 样本量,效应大小与统计功效之间的关系。
tmp_dfs = []
for effect_size in [1, 5, 10, 15, 20, 25, 30, 40, 50, 60, 70, 80, 90, 100]:
print('Simulation for effect size = %d' % effect_size)
tmp_dfs.append(
run_simulations_for_value(before_df, sample_size = 12500, effect_size = effect_size,
boot_iters = 1000, confidence_level = 0.95, num_simulations = 100)
)
val_effect_size_sim_df = pd.concat(tmp_dfs)
我们可以看到,如果对客户消费的实际影响高于 20 欧元,我们将获得更高的统计功效,并且我们能够在超过 90%的情况下拒绝零假设。但在不到 50%的情况下,我们将无法发现 10 欧元的效应。

让我们继续进行保留率的功效分析。完整的代码与客户消费分析的结构相似。我们将在下文详细讨论其中的细节。
import tqdm
def get_sample_for_retention(pop_df, sample_size, effect_size):
base_ret_model = smf.logit('retention ~ num_family_members', data = pop_df).fit(disp = 0)
tmp_pop_df = pop_df.copy()
tmp_pop_df['predicted_retention_proba'] = base_ret_model.predict()
sample_df = tmp_pop_df.sample(sample_size)
sample_df['treatment'] = sample_df.index.map(lambda x: 1 if np.random.uniform() > 0.5 else 0)
sample_df['predicted_retention_proba'] = sample_df['predicted_retention_proba'] + effect_size * sample_df.treatment
sample_df['retention'] = sample_df.predicted_retention_proba.map(lambda x: 1 if x >= np.random.uniform() else 0)
return sample_df
def get_ci_for_retention(df, boot_iters, confidence_level):
tmp_data = []
for iter in range(boot_iters):
sample_df = df.sample(df.shape[0], replace = True)
ret_model = smf.logit('retention ~ treatment + num_family_members', data = sample_df).fit(disp = 0)
tmp_data.append(
{
'iteration': iter,
'coef': ret_model.params['treatment']
}
)
coef_df = pd.DataFrame(tmp_data)
return coef_df.coef.quantile((1 - confidence_level)/2), coef_df.coef.quantile(1 - (1 - confidence_level)/2)
def run_simulations_for_retention(pop_df, sample_size, effect_size,
boot_iters, confidence_level, num_simulations):
tmp_data = []
for sim in tqdm.tqdm(range(num_simulations)):
sample_df = get_sample_for_retention(pop_df, sample_size, effect_size)
num_users_treatment = sample_df[sample_df.treatment == 1].shape[0]
retention_treatment = sample_df[sample_df.treatment == 1].retention.mean()
num_users_control = sample_df[sample_df.treatment == 0].shape[0]
retention_control = sample_df[sample_df.treatment == 0].retention.mean()
ci_lower, ci_upper = get_ci_for_retention(sample_df, boot_iters, confidence_level)
tmp_data.append(
{
'experiment_id': sim,
'num_users_treatment': num_users_treatment,
'retention_treatment': retention_treatment,
'num_users_control': num_users_control,
'retention_control': retention_control,
'sample_size': sample_size,
'effect_size': effect_size,
'boot_iters': boot_iters,
'confidence_level': confidence_level,
'ci_lower': ci_lower,
'ci_upper': ci_upper
}
)
return pd.DataFrame(tmp_data)
首先,由于我们有一个二元的保留结果(客户是否会在下个月回归),我们将使用逻辑回归模型,而不是线性回归。我们可以看到,保留率与家庭规模有关。可能是因为当你为家人购买许多不同类型的产品时,找到一个能满足你所有需求的其他服务变得更困难。
base_ret_model = smf.logit('retention ~ num_family_members', data = before_df).fit(disp = 0)
base_ret_model.summary().tables[1]

此外,函数get_sample_for_retention具有一些较为复杂的逻辑来调整处理组的结果。让我们一步一步地分析。
首先,我们在整个样本数据上拟合一个逻辑回归模型,并使用该模型来预测保留的概率。
base_ret_model = smf.logit('retention ~ num_family_members', data = pop_df)\
.fit(disp = 0)
tmp_pop_df = pop_df.copy()
tmp_pop_df['predicted_retention_proba'] = base_ret_model.predict()
然后,我们获取了一个与样本量相等的随机样本,并将其分为对照组和实验组。
sample_df = tmp_pop_df.sample(sample_size)
sample_df['treatment'] = sample_df.index.map(
lambda x: 1 if np.random.uniform() > 0.5 else 0)
对于处理组,我们通过预期的效应大小来增加保留概率。
sample_df['predicted_retention_proba'] = sample_df['predicted_retention_proba'] \
+ effect_size * sample_df.treatment
最后一步是基于概率定义客户是否被保留。我们使用了均匀分布(在 0 和 1 之间的随机数)来进行计算:
-
如果从均匀分布中抽取的随机值低于概率值,那么客户就会被保留(这是按照指定的概率发生的),
-
否则,客户已经流失。
sample_df['retention'] = sample_df.predicted_retention_proba.map(
lambda x: 1 if x > np.random.uniform() else 0)
你可以运行一些模拟来确保我们的抽样函数按预期工作。例如,通过这个调用,我们可以看到控制组的客户保留率为 64%,与总体一致,而实验组的客户保留率为 93.7%(这是在 effect_size = 0.3 时的预期结果)。
get_sample_for_retention(before_df, 10000, 0.3)\
.groupby('treatment', as_index = False).retention.mean()
# | | treatment | retention |
# |---:|------------:|------------:|
# | 0 | 0 | 0.640057 |
# | 1 | 1 | 0.937648 |
现在,我们还可以运行模拟,看看达到 90% 统计功效所需的样本量。我们可以看到,12.5K 的样本量对于客户保留来说也是足够的。

结果分析
我们可以使用线性回归或逻辑回归分析结果,或者利用我们已经拥有的函数来计算 bootstrap CI。
value_model = smf.ols(
'customer_value ~ treatment + num_family_members + country_avg_annual_earning',
data = experiment_df).fit(disp = 0)
value_model.summary().tables[1]

因此,我们得出了对于客户支出 25.84 欧元的统计显著结果,其 95% 置信区间为 (16.82, 34.87)。
使用 bootstrap 函数时,置信区间将非常接近。
get_ci_for_value(experiment_df.rename(
columns = {'customer_value': 'predicted_value'}), 1000, 0.95)
# (16.28, 34.63)
同样,我们也可以使用逻辑回归进行客户保留分析。
retention_model = smf.logit('retention ~ treatment + num_family_members',
data = experiment_df).fit(disp = 0)
retention_model.summary().tables[1]

再次强调,bootstrap 方法提供了接近的置信区间(CI)估计。
get_ci_for_retention(experiment_df, 1000, 0.95)
# (0.072, 0.187)
对于逻辑回归,解释系数可能会有些棘手。不过,我们可以使用一种简便方法:对于数据集中的每个客户,计算客户在控制组和处理组中的概率,然后查看概率之间的平均差异。
experiment_df['treatment_eq_1'] = 1
experiment_df['treatment_eq_0'] = 0
experiment_df['retention_proba_treatment'] = retention_model.predict(
experiment_df[['retention', 'treatment_eq_1', 'num_family_members']]\
.rename(columns = {'treatment_eq_1': 'treatment'}))
experiment_df['retention_proba_control'] = retention_model.predict(
experiment_df[['retention', 'treatment_eq_0', 'num_family_members']]\
.rename(columns = {'treatment_eq_0': 'treatment'}))
experiment_df['proba_diff'] = experiment_df.retention_proba_treatment \
- experiment_df.retention_proba_control
experiment_df.proba_diff.mean()
# 0.0281
所以,我们可以估算出客户保留的影响为 2.8%。
恭喜!我们终于完成了完整的 A/B 测试分析,并能够估算客户支出和客户保留的影响。我们的实验是成功的,因此在实际生活中,我们会开始考虑将其投入生产。
你可以在GitHub上找到这个示例的完整代码。
摘要
让我快速回顾一下我们今天讨论的内容:
-
Bootstrap 的主要思想是从你的样本中进行有放回的模拟,假设总体分布与我们拥有的数据分布相同。
-
Bootstrap 在数据点较少、数据中有异常值或数据偏离理论分布时表现尤为出色。Bootstrap 还可以帮助你估算自定义指标。
-
你可以使用 bootstrap 来处理观察数据,例如,获取你的值的置信区间。
-
此外,bootstrap 被广泛应用于 A/B 测试分析——既用于估算处理的影响,也用于进行功效分析来设计实验。
非常感谢你阅读这篇文章。如果你有任何后续问题或评论,请在评论区留言。
参考文献
除非另有说明,所有图片均由作者制作。
本文的灵感来自 Florent Buisson 的书籍《使用 R 和 Python 进行行为数据分析》。
产品分析师的实用计算机模拟
第一部分:针对场景预测的任务特定方法
·发布于 Towards Data Science ·20 分钟阅读·2024 年 4 月 19 日
--

图片来自 DALL-E
在产品分析中,我们经常遇到“如果怎样”的问题。我们的团队不断发明各种方法来改进产品,并希望了解它如何影响我们的 KPI 或其他指标。
让我们来看一些例子:
-
假设我们在金融科技行业,面对新的规定,要求我们检查来自首次捐赠或向特定国家汇款超过 10 万美元的客户的更多文件。我们希望了解这一变化对我们的运营需求的影响,以及是否需要雇佣更多的客服人员。
-
让我们切换到另一个行业。我们可能希望通过推出一个新的奖励计划,鼓励出租车司机晚间工作或接受长途订单。在推出这一变化之前,估算奖励的预期规模并进行成本与收益分析对我们来说至关重要。
-
以最后一个例子为例,让我们看看主要的客户支持 KPI。通常,公司会跟踪平均等待时间。改善这一指标有很多可能的方式。我们可以增加夜班,雇佣更多的客服人员,或者利用大型语言模型(LLMs)快速回答问题。为了优先考虑这些想法,我们需要估算它们对我们的 KPI 的影响。
当你第一次看到这样的问题时,它们看起来可能非常令人畏惧。
如果有人要求你计算月活跃用户数或 7 天留存率,这是直接的。你只需访问你的数据库,写 SQL 语句并使用现有的数据。
当你需要计算一些不存在的东西时,事情变得更加具有挑战性(也更为激动人心)。计算机模拟通常是此类任务的最佳解决方案。根据维基百科,模拟是对一个可能存在于现实世界中的过程或系统的模拟性表现。因此,我们将尝试模仿不同的情景,并将其用于决策过程。
模拟是一个强大的工具,可以在各种情况下帮助你。因此,我想通过这系列文章与大家分享计算机模拟的实际案例:
-
在本文中,我们将讨论如何利用模拟来估算不同的情景。你将了解模拟的基本概念,并看到它们如何解决复杂的任务。
-
在第二部分,我们将不再讨论情景分析,而是聚焦于计算机模拟的经典方法——自举法(bootstrap)。自举法可以帮助你为度量指标获取置信区间,并分析 A/B 测试。
-
我想将第三部分的内容奉献给基于代理的模型。我们将模拟客户服务(CS)代理的行为,了解我们的过程变动如何影响客户服务的关键绩效指标(KPI),如队列长度或平均等待时间。
所以,是时候开始并讨论我们将在本文中解决的任务了。
我们的项目:为英语课程启动测试
假设我们正在开发一款教育技术产品,帮助人们学习英语。我们正在进行一项测试,这项测试可以从不同角度评估学生的知识水平(阅读、听力、写作和口语)。这项测试将为我们和我们的学生提供清晰的当前水平认知。
我们已经决定为所有新学生启动该测试,以便评估他们的初始水平。同时,我们会建议现有学生在下次回归服务时通过这项测试。
我们的目标是根据提交的测试数量建立时间预测。由于这些测试的某些部分(写作和口语)需要我们的老师进行人工审核,我们希望确保我们有足够的能力按时完成这些测试的审核。
让我们试着对问题进行结构化。我们有两组学生:
-
第一组是现有学生。在分析中精准是一个好习惯,因此我们将把他们定义为在本次上线之前已开始使用我们服务的学生。我们需要在他们下次交易时进行一次检查,因此在处理他们时会出现一个显著的需求激增。之后,这一群体的需求将变得微不足道(只有偶尔的重新激活)。
-
新学生希望能够继续加入我们的课程。因此,我们应该预期这一群体的需求会保持稳定。
现在,是时候思考我们如何估算这两组客户的需求了。
对于新学生的情况就比较直接——我们需要预测每周的新客户数量,并用它来估算需求。因此,这就是经典的时间序列预测任务。
预测现有客户需求的任务可能会更具挑战性。直接的方法是构建一个模型,预测学生下次回归服务的周数,并用它来进行估算。这是一个可行的解决方案,但对我来说有点过于复杂。
我更倾向于采用另一种方法。我会模拟我们在一段时间前启动这个测试的情景,并使用之前的数据。这样,我们将获得“此模拟启动”之后的所有数据,并能够计算出所有的指标。所以,这实际上是场景模拟的基本思路。
很好,我们有了计划。让我们继续执行。
建模新客户需求
在进行分析之前,让我们先检查一下我们拥有的数据。我们记录了课程完成事件的记录。我们知道每个事件的用户标识符、日期、模块和课程编号。我们将使用每周数据,以避免季节性波动,并捕捉有意义的趋势。

让我分享一些关于教育过程的背景信息。学生们主要来到我们的服务平台,从零开始学习英语,并通过六个模块(从预 A1 到 C1)。每个模块包含 100 节课程。
这些数据是专门为这个用例生成的,因此我们正在使用一个合成数据集。
首先,我们需要计算出我们想要预测的指标。我们将为学生提供在完成第一节演示课后参加初步评估测试的机会。所以,我们可以轻松地计算出通过第一节课的客户数,或者按他们的首次日期来汇总用户数据。
new_users_df = df.groupby('user_id', as_index = False).date.min()\
.rename(columns = {'date': 'cohort'})
new_users_stats_df = new_users_df.groupby('cohort')[['user_id']].count()\
.rename(columns = {'user_id': 'new_users'})
我们可以查看数据,发现整体呈增长趋势,并伴随有一些季节性影响(例如,夏季或圣诞节期间客户的加入人数较少)。

在预测中,我们将使用Prophet——Meta 的一个开源库。它在处理商业数据时效果非常好,因为它可以预测非线性趋势,并自动考虑季节性影响。你可以很容易地从 PyPI 安装它。
pip install prophet
Prophet 库期望的数据框架包含两列:ds表示时间戳,y表示我们想要预测的指标。此外,ds必须是日期时间列。因此,我们需要将数据转换为所需的格式。
pred_new_users_df = new_users_df.copy()
pred_new_users_df = pred_new_users_df.rename(
columns = {'new_users': 'y', 'cohort': 'ds'})
pred_new_users_df.ds = pd.to_datetime(pred_new_users_df.ds)
现在,我们准备好进行预测了。像往常一样,在机器学习中,我们需要初始化并拟合一个模型。
from prophet import Prophet
m = Prophet()
m.fit(pred_new_users_df)
下一步是预测。首先,我们需要创建一个未来的数据框,指定预测的周期数和它们的频率(在我们的例子中是每周)。然后,我们需要调用predict函数。
future = m.make_future_dataframe(periods= 52, freq = 'W')
forecast_df = m.predict(future)
forecast_df.tail()[['ds', 'yhat', 'yhat_lower', 'yhat_upper']]
结果,我们得到了预测值(yhat)和置信区间(yhat_lower 和 yhat_upper)。

如果没有图表,结果很难理解。让我们使用 Prophet 函数来更好地可视化输出结果。
m.plot(forecast_df) # forecast
m.plot_components(forecast_df) # components
预测图表向你展示了带有置信区间的预测结果。

组件视图让你了解趋势和季节性效应之间的分布。例如,第二张图显示了夏季期间的季节性下降,以及九月初的增长(当人们可能更有动力开始学习新事物时)。

我们可以将所有这些预测逻辑放入一个函数中。这样以后会对我们很有帮助。
import plotly.express as px
import plotly.io as pio
pio.templates.default = 'simple_white'
def make_prediction(tmp_df, param, param_name = '', periods = 52):
# pre-processing
df = tmp_df.copy()
date_param = df.index.name
df.index = pd.to_datetime(df.index)
train_df = df.reset_index().rename(columns = {date_param: 'ds', param: 'y'})
# model
m = Prophet()
m.fit(train_df)
future = m.make_future_dataframe(periods=periods, freq = 'W')
forecast = m.predict(future)
forecast = forecast[['ds', 'yhat']].rename(columns = {'ds': date_param, 'yhat': param + '_model'})
# join to actual data
forecast = forecast.set_index(date_param).join(df, how = 'outer')
# visualisation
fig = px.line(forecast,
title = '<b>Forecast:</b> ' + (param if param_name == '' else param_name),
labels = {'value': param if param_name == '' else param_name},
color_discrete_map = {param: 'navy', param + '_model': 'gray'}
)
fig.update_traces(mode='lines', line=dict(dash='dot'),
selector=dict(name=param + '_model'))
fig.update_layout(showlegend = False)
fig.show()
return forecast
new_forecast_df = make_prediction(new_users_stats_df,
'new_users', 'new users', periods = 75)
我更倾向于与我的利益相关者分享一个更具样式的可视化版本(尤其是对于公开演讲),所以我也将其添加到了函数中。

在这个例子中,我们使用了默认的 Prophet 模型,并得到了一个相当可信的预测结果。然而,在某些情况下,你可能需要调整参数,因此我建议你阅读Prophet 文档以了解更多可能的调整选项。
例如,在我们的案例中,我们相信我们的受众将以相同的速度继续增长。然而,这可能并非如此,你可能预计它会在约 100 个用户时达到上限。让我们更新一下对饱和增长的预测。
# adding cap to the initial data
# it's not required to be constant
pred_new_users_df['cap'] = 100
#specifying logistic growth
m = Prophet(growth='logistic')
m.fit(pred_new_users_df)
# adding cap for the future
future = m.make_future_dataframe(periods= 52, freq = 'W')
future['cap'] = 100
forecast_df = m.predict(future)
我们可以看到,预测结果发生了显著变化,并且增长停滞在每周约 100 个新客户。

在这种情况下,查看组件图表也很有趣。我们可以看到季节性效应保持不变,而趋势已变为逻辑增长(正如我们指定的那样)。

我们已经了解了一些关于调整预测的能力。然而,针对未来的计算,我们将使用一个基本模型。我们的业务仍然相对较小,且最有可能尚未达到饱和状态。
我们已经得到了所有新客户所需的估算,并准备开始考虑现有客户的需求。
建模现有客户的需求
第一个版本
我们方法的关键点在于模拟我们在某段时间前进行该测试的情况,并使用这些数据计算需求。我们的解决方案基于一个理念,即我们可以利用过去的数据,而不是预测未来。
由于存在显著的年度季节性,我将使用过去一年的数据来自动考虑这些效应。我们希望在 4 月初启动这个项目。因此,我将使用 2023 年 4 月 2 日那一周的数据。
首先,我们需要筛选出与现有客户相关的数据,时间是 2023 年 4 月初。我们已经预测了新用户的需求,因此在此次估算中不需要考虑新用户。
model_existing_users = df[df.date < '2023-04-02'].user_id.unique()
raw_existing_df = df[df.user_id.isin(model_existing_users)]
然后,我们需要对这些用户的需求进行建模。我们将为现有学生提供下次使用我们产品时参加测试的机会。所以,我们需要定义每个客户在推出后何时返回我们的服务,并按周汇总客户数量。这一点其实并不复杂。
existing_model_df = raw_existing_df[raw_existing_df.date >= '2023-04-02']\
.groupby('user_id', as_index = False).date.min()\
.groupby('date', as_index = False).user_id.count()\
.rename(columns = {'user_id': 'existing_users'})
我们得到了第一次估算。如果我们在 2023 年 4 月推出这个测试,第一周大约会有 1.3K 个测试,第二周 0.3K,第三周 80 个案例,之后会更少。

我们假设 100%的现有客户都会完成测试,我们需要对其进行检查。在实际任务中,值得考虑转化率,并相应调整数字。为了简便起见,这里我们将继续使用 100%的转化率。
所以,我们完成了第一次建模。其实一点也不难。但这个估算足够好吗?
考虑长期趋势
我们正在使用去年的数据。然而,一切都在变化。让我们来看看活跃客户的数量变化情况。
active_users_df = df.groupby('date')[['user_id']].nunique()\
.rename(columns = {'user_id': 'active_users'})

我们可以看到它在稳步增长。我预计它会继续增长。所以,考虑到这一点,我们应该调整我们的预测以反映这个同比增长(Year-over-Year)。我们可以重新使用我们的预测函数,并利用预测值计算同比增长,以使其更加准确。
active_forecast_df = make_prediction(active_users_df,
'active_users', 'active users')

让我们根据我们的预测计算同比增长,并调整模型的预测值。
# calculating YoYs
active_forecast_df['active_user_prev_year'] = active_forecast_df.active_users.shift(52)
active_forecast_df['yoy'] = active_forecast_df.active_users_model/\
active_forecast_df.active_user_prev_year
existing_model_df = existing_model_df.rename(
columns = {'date': 'model_date', 'existing_users': 'model_existing_users'})
# adjusting dates from 2023 to 2024
existing_model_df['date'] = existing_model_df.model_date.map(
lambda x: datetime.datetime.strptime(x, '%Y-%m-%d') + datetime.timedelta(364)
)
existing_model_df = existing_model_df.set_index('date')\
.join(active_forecast_df[['yoy']])
# updating estimations
existing_model_df['existing_users'] = list(map(
lambda x, y: int(round(x*y)),
existing_model_df.model_existing_users,
existing_model_df.yoy
))
我们也完成了现有学生的估算。所以,我们准备将两部分合并,得到最终结果。
将一切汇总
初步结果
现在,我们可以将所有之前的估算结合起来,看看最终的图表。为此,我们需要将数据转换为通用格式,并添加不同的分段,以便我们能够区分新客户和现有客户的需求。
# existing segment
existing_model_df = existing_model_df.reset_index()[['date', 'existing_users']]\
.rename(columns = {'existing_users': 'users'})
existing_model_df['segment'] = 'existing'
# new segment
new_model_df = new_forecast_df.reset_index()[['cohort', 'new_users_model']]\
.rename(columns = {'cohort': 'date', 'new_users_model': 'users'})
new_model_df = new_model_df[(new_model_df.date >= '2024-03-31')
& (new_model_df.date < '2025-04-07')]
new_model_df['users'] = new_model_df.users.map(lambda x: int(round(x)))
new_model_df['segment'] = 'new'
# combining everything
demand_model_df = pd.concat([existing_model_df, new_model_df])
# visualisation
px.area(demand_model_df.pivot(index = 'date',
columns = 'segment', values = 'users').head(15)[['new', 'existing']],
title = '<b>Demand</b>: modelling number of tests after launch',
labels = {'value': 'number of test'})

我们应该预期在推出后的第一周大约会有 2.5K 的测试需求,主要来自现有客户。然后,在接下来的四周内,我们将审查现有用户的测试,新增用户每周大约只有 100 到 130 个案例。
太好了。现在,我们可以将我们的估算与同事分享,这样他们也可以规划自己的工作。
如果我们遇到需求限制怎么办?
在现实生活中,当无法将新功能推广到 100%的客户时,你经常会面临容量限制的问题。所以,是时候学习如何应对这种情况了。
假设我们发现我们的教师每周只能检查 1000 个测试。那我们就需要分散需求,避免糟糕的客户体验(当学生需要等待好几周才能得到结果时)。
幸运的是,我们可以通过将测试分批推出给现有客户来轻松实现。我们可以在第一周为所有新加入的用户和 X%的现有客户开启此功能。然后,我们可以在第二周增加 Y%的现有客户,等等。最终,我们将评估所有现有学生,并仅从新用户中产生持续的需求。
让我们制定一个推出计划,确保不会超过 1000 个容量的阈值。
由于我们肯定要为所有新学生推出,因此让我们从他们开始并将其加入计划中。我们将按分段存储所有的需求估算在raw_demand_est_model_df数据框中,并使用我们之前得到的new_model_df估算进行初始化。
raw_demand_est_model_df = new_model_df.copy()
现在,我们可以汇总这些数据并计算剩余的容量。
capacity = 1000
demand_est_model_df = raw_demand_est_model_df.pivot(index = 'date',
columns = 'segment', values = 'users')
demand_est_model_df['total_demand'] = demand_est_model_df.sum(axis = 1)
demand_est_model_df['capacity'] = capacity
demand_est_model_df['remaining_capacity'] = demand_est_model_df.capacity \
- demand_est_model_df.total_demand
demand_est_model_df.head()

让我们将这个逻辑放到一个单独的函数中,因为我们将在每次迭代后需要它来评估我们的估算结果。
import plotly.graph_objects as go
def get_total_demand_model(raw_demand_est_model_df, capacity = 1000):
demand_est_model_df = raw_demand_est_model_df.pivot(index = 'date',
columns = 'segment', values = 'users')
demand_est_model_df['total_demand'] = demand_est_model_df.sum(axis = 1)
demand_est_model_df['capacity'] = capacity
demand_est_model_df['remaining_capacity'] = demand_est_model_df.capacity \
- demand_est_model_df.total_demand
tmp_df = demand_est_model_df.drop(['total_demand', 'capacity',
'remaining_capacity'], axis = 1)
fig = px.area(tmp_df,
title = '<b>Demand vs Capacity</b>',
category_orders={'segment': ['new'] + list(sorted(filter(lambda x: x != 'new', tmp_df.columns)))},
labels = {'value': 'tests'})
fig.add_trace(go.Scatter(
x=demand_est_model_df.index, y=demand_est_model_df.capacity,
name='capacity', line=dict(color='black', dash='dash'))
)
fig.show()
return demand_est_model_df
demand_plan_df = get_total_demand_model(raw_demand_est_model_df)
demand_plan_df.head()
我还在此函数的输出中添加了一个图表,帮助我们轻松评估结果。

现在,我们可以按周计划现有客户的推出进程。
首先,让我们转变当前针对现有学生的需求模型。我希望它能够按周序列号进行索引,并显示 100%的需求估算。然后,我可以通过将需求乘以权重,并根据启动日期和周数计算日期,顺利得到每批次的估算。
existing_model_df['num_week'] = list(range(existing_model_df.shape[0]))
existing_model_df = existing_model_df.set_index('num_week')\
.drop(['date', 'segment'], axis = 1)
existing_model_df.head()

所以,例如,如果我们为 10%的随机客户推出评估测试,那么我们预计第一周会进行 244 个测试,第二周 52 个测试,第三周 14 个测试,等等。
我将使用相同的估算来处理所有批次。我假设所有相同大小的批次在接下来的几周内将产生相同数量的测试。所以,我没有考虑与每个批次的启动日期相关的任何季节性影响。
这个假设大大简化了你的过程。在我们的案例中这是相当合理的,因为我们将在 4 到 5 周内完成推出,而且在这段时间内没有显著的季节性影响。然而,如果你想要更准确(或有显著季节性波动),你可以通过重复我们之前的过程来为每个批次建立需求估算。
让我们从 2024 年 3 月 31 日那一周开始。如我们之前看到的,我们有 888 个测试的剩余容量。如果我们为 100%的现有客户推出测试,第一周我们将需要检查大约 2400 个测试。因此,我们只准备向一部分客户推出。让我们来计算一下。
cohort = '2024-03-31'
demand_plan_df.loc[cohort].remaining_capacity/existing_model_df.iloc[0].users
# 0.3638
使用更简洁的数字操作会更容易,所以让我们将数字四舍五入到 5%的倍数。我已将数字向下舍入,以留出一些缓冲。
full_demand_1st_week = existing_model_df.iloc[0].users
next_group_share = demand_plan_df.loc[cohort].remaining_capacity/full_demand_1st_week
next_group_share = math.floor(20*next_group_share)/20
# 0.35
由于我们将进行几次迭代,我们需要追踪已有客户中启用了新功能的百分比。此外,检查我们是否已经处理了所有客户以避免重复计数也是值得的。
enabled_user_share = 0
# if we can process more customers than are left, update the number
if next_group_share > 1 - enabled_user_share:
print('exceeded')
next_group_share = round(1 - enabled_user_share, 2)
enabled_user_share += next_group_share
# 0.35
此外,将我们的推出计划保存在一个单独的变量中将会很有帮助。
rollout_plan = []
rollout_plan.append(
{'launch_date': cohort, 'rollout_percent': next_group_share}
)
现在,我们需要估算这一批次的预期需求。3 月 31 日对 35%的客户进行测试,将不仅会在第一周产生需求,还会在接下来的几周产生需求。因此,我们需要计算这一批次的总需求,并将其纳入我们的计划中。
# copy the model
next_group_demand_df = existing_model_df.copy().reset_index()
# calculate the dates from cohort + week number
next_group_demand_df['date'] = next_group_demand_df.num_week.map(
lambda x: (datetime.datetime.strptime(cohort, '%Y-%m-%d') \
+ datetime.timedelta(7*x))
)
# adjusting demand by weight
next_group_demand_df['users'] = (next_group_demand_df.users * next_group_share).map(lambda x: int(round(x)))
# labelling the segment
next_group_demand_df['segment'] = 'existing, cohort = %s' % cohort
# updating the plan
raw_demand_est_model_df = pd.concat([raw_demand_est_model_df,
next_group_demand_df.drop('num_week', axis = 1)])
现在,我们可以重用函数get_total_demand_mode,它帮助我们分析当前的需求与容量之间的平衡。
demand_plan_df = get_total_demand_model(raw_demand_est_model_df)
demand_plan_df.head()
我们在第一周已经利用了大部分容量。我们仍然有一些空闲资源,但我们故意决定为可持续性留出一些缓冲。我们可以看到,在三周后,这一批次几乎没有需求。

到此为止,我们完成了第一轮迭代,可以进入下一个周期——2024 年 4 月 4 日。我们可以在这一周再检查 706 个案例。

我们可以重复整个过程直到本周结束,然后继续进入下一个阶段。我们可以进行迭代,直到我们将项目推广到 100%的现有客户(enabled_user_share等于 1)。
我们可以在不违反每周 1000 次测试容量限制的情况下,在四周内将我们的测试推出到所有客户。最后,我们将得到如下的每周预测。

我们还可以查看我们在模拟过程中记录的推出计划。所以,我们需要在 3 月 31 日那一周对随机选择的 35%的客户进行测试,接着在下周对下一个 20%的客户进行测试,再对接下来的两周分别对 25%和 20%的现有用户进行测试。之后,我们将把我们的项目推广到所有现有学生。
rollout_plan
# [{'launch_date': '2024-03-31', 'rollout_percent': 0.35},
# {'launch_date': '2024-04-07', 'rollout_percent': 0.2},
# {'launch_date': '2024-04-14', 'rollout_percent': 0.25},
# {'launch_date': '2024-04-21', 'rollout_percent': 0.2}]
所以,恭喜你。我们现在有了一个可持续推出新功能的计划。
跟踪学生的表现变化
我们已经做了很多工作来估算需求。我们通过模拟一年前项目的推出过程,进行扩展并评估后果,从而运用了模拟的思路。所以,这绝对是一个模拟的例子。
然而,我们主要使用的是你每天都会用到的基本工具——一些 Pandas 数据整理和算术运算。在文章的最后部分,我想向你展示一个稍微复杂的案例,我们将需要为每个客户独立模拟该过程。
产品需求随着时间的推移通常会发生变化,我们的项目也经历了这种情况。你和团队决定,如果能够让学生在学习过程中跟踪进展(不仅仅是在一开始),那会更好。因此,我们希望能够让学生在每个模块结束后进行一次表现测试(如果距离上次测试已经超过一个月),或者如果学生在三个月的缺席后重新返回服务。
现在,测试分配的标准相当复杂。然而,我们仍然可以使用相同的方法,通过查看前一年的数据来进行判断。但是,这一次,我们需要查看每个客户的行为,并确定他们在什么情况下会得到一个测试。
我们将同时考虑新客户和现有客户,因为我们希望评估跟踪测试对所有客户的影响。我们不需要发布前的数据,因为第一个测试将在下一次有效交易时分配,之前的历史数据不再重要。因此,我们可以将其过滤掉。
sim_df = df[df.date >= '2023-03-31']
让我们还定义一个函数,用于计算两个日期字符串之间的天数。这对我们在实现时会非常有帮助。
def days_diff(date1, date2):
return (datetime.datetime.strptime(date2, '%Y-%m-%d')\
- datetime.datetime.strptime(date1, '%Y-%m-%d')).days
让我们从一个用户开始,详细讨论一下逻辑。首先,我们将筛选与该用户相关的事件,并将其转换为字典列表。这样处理数据对我们来说会更加方便。
user_id = 4861
user_events = sim_df[sim_df.user_id == user_id]\
.sort_values('date')\
.to_dict('records')
# [{'user_id': 4861, 'date': '2023-04-09', 'module': 'pre-A1', 'lesson_num': 8},
# {'user_id': 4861, 'date': '2023-04-16', 'module': 'pre-A1', 'lesson_num': 9},
# {'user_id': 4861, 'date': '2023-04-23', 'module': 'pre-A1', 'lesson_num': 10},
# {'user_id': 4861, 'date': '2023-04-23', 'module': 'pre-A1', 'lesson_num': 11},
# {'user_id': 4861, 'date': '2023-04-30', 'module': 'pre-A1', 'lesson_num': 12},
# {'user_id': 4861, 'date': '2023-05-07', 'module': 'pre-A1', 'lesson_num': 13}]
为了模拟我们的产品逻辑,我们将逐一处理用户事件,并在每个阶段检查客户是否符合评估的条件。
让我们讨论一下我们需要维护哪些变量,以便能够判断客户是否符合测试条件。为此,让我们回顾一下客户可能会进行测试的所有情况:
-
如果之前没有测试记录 -> 我们需要知道他们是否曾经通过过测试。
-
如果客户完成了模块并且距离上次测试已经超过一个月 -> 我们需要知道最后一次测试的日期。
-
如果客户在三个月后返回 -> 我们需要存储最后一课的日期。
为了能够检查所有这些标准,我们只需要使用两个变量:最后的测试日期(如果之前没有测试,则为None)和上次课程日期。此外,我们还需要存储所有生成的测试,以便稍后进行计算。让我们初始化所有变量。
tmp_gen_tests = []
last_test_date = None
last_lesson_date = None
现在,我们需要按事件进行迭代并检查标准。
for rec in user_events:
pass
让我们从最初的测试开始,逐一检查我们的标准。在这种情况下,last_test_date将等于None。在"分配"测试之后,更新last_test_date变量是非常重要的。
if last_test_date is None: # initial test
last_test_date = rec['date']
# TBD saving the test info
在已完成的模块的情况下,我们需要检查它是否是该模块的最后一课,并且是否已经过去了 30 天以上。
if (rec['lesson_num'] == 100) and (days_diff(last_test_date, rec['date']) >= 30):
last_test_date = rec['date']
# TBD saving the test info
最后一种情况是客户已经三个月没有使用我们的服务。
if (days_diff(last_lesson_date, rec['date']) >= 30):
last_test_date = rec['date']
# TBD saving the test info
此外,我们需要在每次迭代时更新last_lesson_date以保持准确性。
我们已经讨论了所有的构建块,现在可以将它们结合起来,进行所有客户的模拟。
import tqdm
tmp_gen_tests = []
for user_id in tqdm.tqdm(sim_raw_df.user_id.unique()):
# initialising variables
last_test_date = None
last_lesson_date = None
for rec in sim_raw_df[sim_raw_df.user_id == user_id].to_dict('records'):
# initial test
if last_test_date is None:
last_test_date = rec['date']
tmp_gen_tests.append(
{
'user_id': rec['user_id'],
'date': rec['date'],
'trigger': 'initial test'
}
)
# finish module
elif (rec['lesson_num'] == 100) and (days_diff(last_test_date, rec['date']) >= 30):
last_test_date = rec['date']
tmp_gen_tests.append(
{
'user_id': rec['user_id'],
'date': rec['date'],
'trigger': 'finished module'
})
# reactivation
elif (days_diff(last_lesson_date, rec['date']) >= 92):
last_test_date = rec['date']
tmp_gen_tests.append(
{
'user_id': rec['user_id'],
'date': rec['date'],
'trigger': 'reactivation'
})
last_lesson_date = rec['date']
现在,我们可以汇总这些数据。由于我们再次使用的是去年的数据,我将按大约 80%的年增长率调整这个数字,正如我们之前估算的那样。
exist_model_upd_stats_df = exist_model_upd.pivot_table(
index = 'date', columns = 'trigger', values = 'user_id',
aggfunc = 'nunique'
).fillna(0)
exist_model_upd_stats_df = exist_model_upd_stats_df\
.map(lambda x: int(round(x * 1.8)))
我们为初始测试得到了相似的估算。在这个案例中,“初始测试”部分等于我们之前估算中新增需求和现有需求的总和。

所以,观察其他部分要更有趣,因为它们会是我们之前计算的增量。我们可以看到每周大约有 30 到 60 个案例来自于 5 月开始的模块完成的客户。
几乎不会出现重新激活的情况。在我们的模拟中,我们每年总共得到了 4 个案例。

恭喜!现在问题已经解决,我们找到了一个很好的方法,可以让我们在没有高级数学的情况下,仅通过模拟就能做出精确的估算。你可以使用类似的方法。
你可以在GitHub上找到这个示例的完整代码。
摘要
让我快速回顾一下今天讨论的内容:
-
计算机模拟的主要思想是基于你的数据进行模仿。
-
在许多情况下,你可以将问题的框架从预测未来转变为使用你已经拥有的数据,并模拟你感兴趣的过程。因此,这种方法非常强大。
-
在本文中,我们通过一个端到端的示例介绍了情景估算。我们展示了如何构建复杂问题,并将其拆分成更多定义明确的问题。我们还学会了如何处理约束并计划逐步推出。
非常感谢你阅读本文。如果你有任何后续问题或评论,请在评论区留言。
参考文献
除非另有说明,所有图像均由作者制作。
面向产品分析师的实用计算机仿真
第三部分:建模运营队列
·发布于 Towards Data Science ·阅读时长 23 分钟·2024 年 5 月 24 日
--

图像由 DALL-E 3 生成
今天,我想给大家展示一个离散事件仿真方法的例子。我们将模拟客户支持团队,并决定采用何种策略来提高其绩效。但首先,请允许我分享一下我的个人经历。
我第一次在大学里学习离散仿真。我的一门课程是排队论,为了获得期末成绩,我需要实现一个机场仿真并计算一些关键绩效指标(KPI)。不幸的是,由于我已经全职工作,我错过了所有的研讨会,因此对这一主题背后的理论以及如何进行操作毫无头绪。
我下定决心要获得一个优异的成绩,因此找了一本书,读懂了基础知识,并花了几个晚上进行实现。虽然这对我来说相当有挑战性,因为一段时间没写代码了,但我还是搞明白了,并且得到了 A 的成绩。
在那时(就像很多学生一样),我感觉这些信息对我未来的工作没有什么帮助。然而,后来我意识到,许多分析性任务都可以通过这种方法来解决。所以,我想和大家分享这个方法。
代理基础仿真最明显的应用场景之一是运营分析。大多数产品都有客户支持,客户可以在此寻求帮助。客户支持团队通常会关注以下指标:
-
平均解决时间 — 客户向客服提出问题到收到第一次回复所经历的时间,
-
队列大小,即目前我们积压了多少任务。
没有适当的模型,理解我们的变化(例如,增加夜班或仅仅增加代理数量)如何影响 KPI 可能会比较困难。模拟将帮助我们做到这一点。
所以,让我们不要浪费时间,继续前进。
模拟和建模的基础
从最开始开始。我们将建模系统。系统是一个由实体(例如人、服务器甚至机械工具)组成的集合,这些实体相互作用以实现某个逻辑目标(即回答客户问题或通过机场的边检)。
你可以根据研究目标定义系统所需的粒度级别。例如,在我们的案例中,我们希望研究代理的效率和时间表的变化如何影响平均的 CS 工单解决时间。因此,系统将仅是一个代理集合。然而,如果我们想要模拟将一些工单外包给不同的外包公司,我们就需要在模型中包含这些合作伙伴。
系统由一组变量描述——例如队列中的工单数量或当前在工作时刻的代理数量。这些变量定义了系统状态。
系统有两种类型:
-
离散 — 当系统状态瞬间发生变化时,例如,新的工单被添加到队列中,或者代理完成了他们的班次。
-
连续——当系统不断演变时。例如,飞行中的飞机,坐标、速度、高度等参数在飞行过程中持续变化。
对于我们的任务,我们可以将系统视为离散的,并使用离散事件模拟方法。这是指系统只能在有限的时间点发生变化。这些时间点是事件发生的时刻,并立即改变系统状态。
因此,整个方法基于事件。我们将逐一生成和处理事件,以模拟系统如何运作。我们可以使用时间线的概念来结构化事件。
由于这个过程是动态的,我们需要跟踪模拟时间的当前值,并能够将其从一个值推进到另一个值。在模拟模型中,显示当前时间的变量通常称为模拟时钟。
我们还需要一个机制来推进模拟时间。推进时间有两种方法:
-
下一事件时间推进 — 我们从一个事件的时间戳移动到下一个事件的时间戳。
-
固定增量时间推进 — 我们选择一个时间段,例如 1 分钟,每次按照这个时间段调整时钟。
我认为第一种方法更容易理解、实现和调试。因此,在本文中我将坚持使用这种方法。
让我们回顾一个简单的例子,理解它是如何运作的。我们将讨论一个简化的 CS 工单队列案例。
我们开始模拟,初始化模拟时钟。有时候,人们会使用零作为初始值。我更喜欢使用真实的时间数据和实际的日期时间。
这是我们系统的初始状态。我们的时间线上有两个与客户请求相关的事件。

下一步是将模拟时钟推进到我们的时间线上的第一个事件——9:15 的客户请求。

现在是处理这个事件的时候了。我们应该找到一个代理来处理这个请求,分配请求给他们,并生成一个完成任务的事件。事件是我们模拟的主要驱动力,所以如果一个事件创造了另一个事件是可以的。
查看更新后的时间线,我们可以看到,最紧迫的事件不是第二个客户请求,而是第一个任务的完成。

所以,我们需要将时钟推进到 9:30,并处理下一个事件。请求的完成不会创造新的事件,因此在此之后,我们将转到第二个客户请求。

我们将重复这个从一个事件到另一个事件的过程,直到模拟结束。
为了避免无限循环的过程,我们需要定义停止标准。在这种情况下,我们可以使用以下逻辑:如果时间线上没有更多事件,我们应该停止模拟。在这个简化的例子中,我们的模拟将在完成第二个任务后停止。
我们已经讨论了离散事件模拟的理论,并理解了它是如何工作的。现在,是时候实践并在代码中实现这种方法了。
程序架构
面向对象编程
在我的日常工作中,我通常使用过程式编程范式。我为一些重复性的任务创建函数,但除此之外,我的代码是相当线性的。这是数据处理任务中的一种标准方法。
在这个例子中,我们将使用面向对象编程。所以,如果你之前没有使用过 Python 类,或者需要复习一下这个主题,那就花点时间来复习它吧。
面向对象编程(OOP)基于对象的概念。对象由数据(称为属性的某些特征)和行为(函数或方法)组成。整个程序描述了不同对象之间的交互。例如,如果我们有一个代表客户服务代理的对象,它可以具有以下属性:
-
属性:名称、代理开始工作的日期、他们在任务上花费的平均时间或当前状态(
"out of office"、"working on task"或"free")。 -
方法:返回名称,更新状态或开始处理客户请求。
为了表示这样的一个对象,我们可以使用 Python 类。让我们为客户服务代理编写一个简单的类。
class CSAgent:
# initialising class
def __init__(self, name, average_handling_time):
# saving parameters mentioned during object creation
self.name = name
self.average_handling_time = average_handling_time
# specifying constant value
self.role = 'CS agent'
print('Created %s with name %s' % (self.role, self.name))
def get_name(self):
return self.name
def get_handling_time(self):
return self.average_handling_time
def update_handling_time(self, average_handling_time):
print('Updating time from %.2f to %.2f' % (self.average_handling_time,
average_handling_time))
self.average_handling_time = average_handling_time
这个类定义了每个代理人的姓名、平均处理时间和角色。我还添加了一些可以返回内部变量的函数,遵循封装模式。另外,我们有一个update_handling_time函数,可以让我们更新代理人的表现。
我们创建了一个类(一个解释任何类型 CS 代理的对象)。现在让我们创建该对象的实例——代理人 John Doe。
john_agent = CSAgent('John Doe', 12.3)
# Created CS agent with name John Doe
当我们创建类的实例时,__init__函数会被执行。我们可以使用__dict__属性将类字段以字典的形式展示出来。这在很多情况下都很有用,例如,如果你想把一系列对象转换为数据框架。
print(john_agent.__dict__)
# {'name': 'John Doe', 'average_handling_time': 12.3, 'role': 'CS agent'}
我们可以尝试执行一个方法并更新代理人的表现。
john_agent.update_handling_time(5.4)
# Updating time from 12.30 to 5.40
print(john_agent.get_handling_time())
# 5.4
今天我们将使用的面向对象编程的基本概念之一是继承。继承允许我们有一个高层的祖先类,并在子类中使用其特性。想象一下,我们不仅想要 CS 代理人,还想要 KYC 代理人。我们可以创建一个高层的Agent类,包含共同功能,并仅为 KYC 和 CS 代理人定义一次。
class Agent:
# initialising class
def __init__(self, name, average_handling_time, role):
# saving parameters mentioned during object creation
self.name = name
self.average_handling_time = average_handling_time
self.role = role
print('Created %s with name %s' % (self.role, self.name))
def get_name(self):
return self.name
def get_handling_time(self):
return self.average_handling_time
def update_handling_time(self, average_handling_time):
print('Updating time from %.2f to %.2f' % (self.average_handling_time,
average_handling_time))
self.average_handling_time = average_handling_time
现在,我们可以为这些代理类型创建独立的类,并定义稍微不同的__init__和get_job_description函数。
class KYCAgent(Agent):
def __init__(self, name, average_handling_time):
super().__init__(name, average_handling_time, 'KYC agent')
def get_job_description(self):
return 'KYC (Know Your Customer) agents help to verify documents'
class CSAgent(Agent):
def __init__(self, name, average_handling_time):
super().__init__(name, average_handling_time, 'CS agent')
def get_job_description(self):
return 'CS (Customer Support) answer customer questions and help resolving their problems'
为了指定继承关系,我们在当前类名后面的括号中提到基类。使用super(),我们可以调用基类的方法,例如__init__,用自定义的role值创建对象。
让我们创建对象并检查它们是否按预期工作。
marie_agent = KYCAgent('Marie', 25)
max_agent = CSAgent('Max', 10)
print(marie_agent.__dict__)
# {'name': 'Marie', 'average_handling_time': 25, 'role': 'KYC agent'}
print(max_agent.__dict__)
# {'name': 'Max', 'average_handling_time': 10, 'role': 'CS agent'}
让我们更新 Marie 的处理时间。尽管我们没有在KYCAgent类中实现这个函数,它使用的是基类中的实现,效果很好。
marie_agent.update_handling_time(22.5)
# Updating time from 25.00 to 22.50
我们还可以调用我们在类中定义的方法。
print(marie_agent.get_job_description())
# KYC (Know Your Customer) agents help to verify documents
print(max_agent.get_job_description())
# CS (Customer Support) answer customer questions and help resolving their problems
所以,我们已经涵盖了面向对象编程范式和 Python 类的基础。我希望这对你有所帮助,起到了复习作用。
现在,是时候回到我们的任务和我们需要的仿真模型了。
架构:类
如果你之前没怎么使用过面向对象编程,可能会觉得从过程式编程切换到面向对象编程有些困难。这需要一些时间来完成思维方式的转变。
一个生活小窍门是使用现实世界的类比(例如,很明显代理人是一个具有一些特征和行为的对象)。
同时,不要害怕犯错。程序架构有好有坏:有些架构随时间推移更容易阅读和维护。然而,即使在成熟的软件工程师之间,对于最佳实践也有很多争论,所以我建议不要过于纠结于让它在分析性临时研究中变得完美。
让我们思考一下在仿真中需要哪些对象:
-
System——我们任务中最顶层的概念。系统将表示当前状态并执行仿真。 -
正如我们之前讨论的,系统是一个实体的集合。因此,我们需要的下一个对象是
Agent。这个类将描述在任务中工作的代理人。 -
每个代理人都有自己的时间表:代理人工作的小时数,所以我将其提取到一个单独的类
Schedule中。 -
我们的代理人将处理客户请求。所以,这不言而喻——我们需要在系统中表示它们。此外,我们将在
System对象中存储一个已处理请求的列表,以便在仿真结束后获取最终统计数据。 -
如果没有空闲的代理人接手新的客户请求,该请求将被放入队列中。因此,我们将拥有一个
RequestQueue对象来存储所有客户请求,采用 FIFO 逻辑(先进先出)。 -
以下是一个重要的概念,即
TimeLine,它代表了我们需要按时间顺序处理的事件集合。 -
TimeLine将包括事件,因此我们也将为它们创建一个Event类。由于我们将有许多不同类型的事件需要以不同方式处理,我们可以利用面向对象编程的继承机制。我们将在下一节中更详细地讨论事件类型。
就这样。我已经把所有的类和它们之间的关系放入了图表中,以便更清晰地表达它。在开始实现之前,我使用这样的图表来获得系统的高层次视图——它有助于提前考虑架构。

正如你可能已经注意到的,图表并不是特别详细。例如,它没有包含所有字段名称和方法。这是故意的。这个架构将作为一个宏观视图来指导开发。因此,我不想花太多时间列出所有字段和方法的名称,因为这些细节可能会在实现过程中发生变化。
架构:事件类型
我们已经讨论了程序架构,现在是时候思考我们仿真的主要驱动因素——事件。
让我们讨论一下我们需要生成哪些事件来保持系统的运转。
-
我将从"Agent Ready"事件开始。它表示代理人开始工作并准备接手任务(如果队列中有等待的任务)。
-
我们需要知道代理人何时开始工作。这些工作时间可以依赖于代理人以及星期几。可能的话,我们甚至希望在仿真过程中更改这些时间表。由于在系统初始化时我们并不知道完成仿真需要多长时间,创建所有的"Agent Ready"事件是相当具有挑战性的。因此,我建议使用一个周期性的"Plan Agents Schedule"事件来为第二天创建准备工作的事件。
-
另一个必不可少的事件是"New Customer Request"——一个表示我们收到了新的客户服务请求,我们需要么开始处理它,么将其放入队列中的事件。
-
最后的事件是"Agent Finished Task",它表示代理人完成了他正在处理的任务,并且可能准备好接手一个新的任务。
就这样。这四个事件足以运行整个仿真。
类似于类的定义,对于系统建模没有对错之分。你可以使用略有不同的事件集。例如,你可以添加一个“开始任务”事件,以便明确表示。
实现
你可以在GitHub上找到完整的实现。
我们已经定义了我们解决方案的高层结构,因此我们准备好开始实现它了。让我们从我们的仿真核心——系统类开始。
初始化系统
让我们从系统类的__init__方法开始。
首先,让我们思考一下我们希望为仿真指定的参数:
-
agents— 将在 CS 团队中工作的代理集, -
queue— 当前的客户请求队列(如果有的话), -
initial_date— 由于我们同意使用实际时间戳而非相对时间戳,我将指定开始仿真的日期, -
logging— 定义我们是否希望打印一些调试信息的标志, -
customer_requests_df— 包含我们希望处理的客户请求集的信息的数据框。
除了输入参数,我们还将创建以下内部字段:
-
current_time— 我们将初始化为指定初始日期的 00:00:00 的仿真时钟, -
timeline对象,我们将用它来定义事件的顺序, -
processed_request— 一个空列表,我们将在其中存储处理过的客户请求,以便在仿真后获取数据。
现在是时候采取必要的行动来初始化系统了。只剩下两个步骤:
-
计划代理在第一天工作。我将生成并处理一个带有初始时间戳的相应事件。
-
通过将相应的“新客户请求”事件添加到时间线上来加载客户请求。
这是执行所有这些初始化系统操作的代码。
class System:
def __init__(self, agents, queue, initial_date,
customer_requests_df, logging = True):
initial_time = datetime.datetime(initial_date.year, initial_date.month,
initial_date.day, 0, 0, 0)
self.agents = agents
self.queue = RequestQueue(queue)
self.logging = logging
self.current_time = initial_time
self._timeline = TimeLine()
self.processed_requests = []
initial_event = PlanScheduleEvent('plan_agents_schedule', initial_time)
initial_event.process(self)
self.load_customer_request_events(customer_requests_df)
它还没有工作,因为它链接到一些未实现的类和方法,但我们会逐一解决这些问题。
时间线
让我们从系统定义中使用的类开始。第一个是TimeLine。它唯一的字段是事件列表。此外,它实现了一些方法:
-
添加事件(并确保它们按时间顺序排列),
-
返回下一个事件并将其从列表中删除,
-
显示剩余的事件数量。
class TimeLine:
def __init__(self):
self.events = []
def add_event(self, event:Event):
self.events.append(event)
self.events.sort(key = lambda x: x.time)
def get_next_item(self):
if len(self.events) == 0:
return None
return self.events.pop(0)
def get_remaining_events(self):
return len(self.events)
客户请求队列
我们在初始化中使用的另一个类是RequestQueue。
没有什么意外的情况:请求队列由客户请求组成。让我们从这个构建块开始。我们知道每个请求的创建时间以及处理该请求时,代理需要多少时间。
class CustomerRequest:
def __init__(self, id, handling_time_secs, creation_time):
self.id = id
self.handling_time_secs = handling_time_secs
self.creation_time = creation_time
def __str__(self):
return f'Customer Request {self.id}: {self.creation_time.strftime("%Y-%m-%d %H:%M:%S")}'
它是一个简单的数据类,仅包含参数。这里唯一的新内容是我重写了__str__方法,以改变打印函数的输出。这对调试非常有用。你可以自己比较。
test_object = CustomerRequest(1, 600, datetime.datetime(2024, 5, 1, 9, 42, 1))
# without defining __str__
print(test_object)
# <__main__.CustomerRequest object at 0x280209130>
# with custom __str__
print(test_object)
# Customer Request 1: 2024-05-01 09:42:01
现在,我们可以继续处理请求队列。与时间轴类似,我们实现了方法来添加新请求、计算队列中的请求并获取队列中的下一个请求。
class RequestQueue:
def __init__(self, queue = None):
if queue is None:
self.requests = []
else:
self.requests = queue
def get_requests_in_queue(self):
return len(self.requests)
def add_request(self, request):
self.requests.append(request)
def get_next_item(self):
if len(self.requests) == 0:
return None
return self.requests.pop(0)
代理
我们需要初始化系统的另一件事是代理。首先,每个代理都有一个时间表——根据工作日来决定他们的工作时间。
class Schedule:
def __init__(self, time_periods):
self.time_periods = time_periods
def is_within_working_hours(self, dt):
weekday = dt.strftime('%A')
if weekday not in self.time_periods:
return False
hour = dt.hour
time_periods = self.time_periods[weekday]
for period in time_periods:
if (hour >= period[0]) and (hour < period[1]):
return True
return False
我们唯一有的关于时间表的方法是检查在指定时刻代理是否在工作。
让我们定义代理类。每个代理将具有以下属性:
-
id和name— 主要用于日志记录和调试目的, -
schedule— 我们刚刚定义的代理时间表对象, -
request_in_work— 连接到客户请求对象,显示代理当前是否忙碌。 -
effectiveness— 显示代理与解决特定任务的预期时间相比效率的系数。
我们为代理实现了以下方法:
-
理解他们是否能接受新的任务(即他们是否空闲并仍在工作)。
-
启动和结束处理客户请求。
class Agent:
def __init__(self, id, name, schedule, effectiveness = 1):
self.id = id
self.schedule = schedule
self.name = name
self.request_in_work = None
self.effectiveness = effectiveness
def is_ready_for_task(self, dt):
if (self.request_in_work is None) and (self.schedule.is_within_working_hours(dt)):
return True
return False
def start_task(self, customer_request):
self.request_in_work = customer_request
customer_request.handling_time_secs = int(round(self.effectiveness * customer_request.handling_time_secs))
def finish_task(self):
self.request_in_work = None
将初始客户请求加载到时间轴中
我们在系统的__init__函数中唯一缺失的部分(除了稍后我们将详细讨论的事件处理)是load_customer_request_events函数的实现。这个相当直接。我们只需要将其添加到我们的System类中。
class System:
def load_customer_request_events(self, df):
# filter requests before the start of simulation
filt_df = df[df.creation_time >= self.current_time]
if filt_df.shape[0] != df.shape[0]:
if self.logging:
print('Attention: %d requests have been filtered out since they are outdated' % (df.shape[0] - filt_df.shape[0]))
# create new customer request events for each record
for rec in filt_df.sort_values('creation_time').to_dict('records'):
customer_request = CustomerRequest(rec['id'], rec['handling_time_secs'],
rec['creation_time'])
self.add_event(NewCustomerRequestEvent(
'new_customer_request', rec['creation_time'],
customer_request
))
很好,我们已经弄清楚了主要类。那么,接下来我们就开始实现事件。
处理事件
如前所述,我将使用继承方法并创建一个Event类。目前,它只实现了__init__和__str__函数,但它可能帮助我们为所有事件提供额外的功能。
class Event:
def __init__(self, event_type, time):
self.type = event_type
self.time = time
def __str__(self):
if self.type == 'agent_ready_for_task':
return '%s (%s) - %s' % (self.type, self.agent.name, self.time)
return '%s - %s' % (self.type, self.time)
然后,我为每种可能具有略微不同初始化方式的事件类型实现了一个单独的子类。例如,对于AgentReady事件,我们还有一个Agent对象。更重要的是,每个事件类都实现了process方法,该方法以system作为输入。
class AgentReadyEvent(Event):
def __init__(self, event_type, time, agent):
super().__init__(event_type, time)
self.agent = agent
def process(self, system: System):
# get next request from the queue
next_customer_request = system.queue.get_next_item()
# start processing request if we had some
if next_customer_request is not None:
self.agent.start_task(next_customer_request)
next_customer_request.start_time = system.current_time
next_customer_request.agent_name = self.agent.name
next_customer_request.agent_id = self.agent.id
if system.logging:
print('<%s> Agent %s started to work on request %d' % (system.current_time,
self.agent.name, next_customer_request.id))
# schedule finish processing event
system.add_event(FinishCustomerRequestEvent('finish_handling_request',
system.current_time + datetime.timedelta(seconds = next_customer_request.handling_time_secs),
next_customer_request, self.agent))
class PlanScheduleEvent(Event):
def __init__(self, event_type, time):
super().__init__(event_type, time)
def process(self, system: System):
if system.logging:
print('<%s> Scheeduled agents for today' % (system.current_time))
current_weekday = system.current_time.strftime('%A')
# create agent ready events for all agents working on this weekday
for agent in system.agents:
if current_weekday not in agent.schedule.time_periods:
continue
for time_periods in agent.schedule.time_periods[current_weekday]:
system.add_event(AgentReadyEvent('agent_ready_for_task',
datetime.datetime(system.current_time.year, system.current_time.month,
system.current_time.day, time_periods[0], 0, 0),
agent))
# schedule next planning
system.add_event(PlanScheduleEvent('plan_agents_schedule', system.current_time + datetime.timedelta(days = 1)))
class FinishCustomerRequestEvent(Event):
def __init__(self, event_type, time, customer_request, agent):
super().__init__(event_type, time)
self.customer_request = customer_request
self.agent = agent
def process(self, system):
self.agent.finish_task()
# log finish time
self.customer_request.finish_time = system.current_time
# save processed request
system.processed_requests.append(self.customer_request)
if system.logging:
print('<%s> Agent %s finished request %d' % (system.current_time, self.agent.name, self.customer_request.id))
# pick up the next request if agent continue working and we have something in the queue
if self.agent.is_ready_for_task(system.current_time):
next_customer_request = system.queue.get_next_item()
if next_customer_request is not None:
self.agent.start_task(next_customer_request)
next_customer_request.start_time = system.current_time
next_customer_request.agent_name = self.agent.name
next_customer_request.agent_id = self.agent.id
if system.logging:
print('<%s> Agent %s started to work on request %d' % (system.current_time,
self.agent.name, next_customer_request.id))
system.add_event(FinishCustomerRequestEvent('finish_handling_request',
system.current_time + datetime.timedelta(seconds = next_customer_request.handling_time_secs),
next_customer_request, self.agent))
class NewCustomerRequestEvent(Event):
def __init__(self, event_type, time, customer_request):
super().__init__(event_type, time)
self.customer_request = customer_request
def process(self, system: System):
# check whether we have a free agent
assigned_agent = system.get_free_agent(self.customer_request)
# if not put request in a queue
if assigned_agent is None:
system.queue.add_request(self.customer_request)
if system.logging:
print('<%s> Request %d put in a queue' % (system.current_time, self.customer_request.id))
# if yes, start processing it
else:
assigned_agent.start_task(self.customer_request)
self.customer_request.start_time = system.current_time
self.customer_request.agent_name = assigned_agent.name
self.customer_request.agent_id = assigned_agent.id
if system.logging:
print('<%s> Agent %s started to work on request %d' % (system.current_time, assigned_agent.name, self.customer_request.id))
system.add_event(FinishCustomerRequestEvent('finish_handling_request',
system.current_time + datetime.timedelta(seconds = self.customer_request.handling_time_secs),
self.customer_request, assigned_agent))
实际上,事件处理的业务逻辑就到此为止。我们需要完成的唯一部分就是将所有内容组合起来,运行我们的仿真。
在系统类中将所有内容整合在一起
正如我们讨论的那样,System类将负责运行仿真。因此,我们将在那里完成剩下的部分。
这是剩余的代码。让我简要地向你介绍一下要点:
-
is_simulation_finished定义了我们的仿真停止标准——队列中没有请求,且时间轴中没有事件。 -
process_next_event从时间线中获取下一个事件并对其执行process。这里有一个小细节:我们可能会遇到一种情况,模拟永远不会结束,因为"Plan Agents Schedule"事件不断发生。因此,在处理这种事件时,我会检查时间线中是否还有其他事件,如果没有,我就不再处理它,因为我们不再需要安排代理了。 -
run_simulation是控制我们世界的函数,但由于我们的架构相当不错,它只有几行代码:我们检查是否能完成模拟,如果不能,我们就处理下一个事件。
class System:
# defines the stopping criteria
def is_simulation_finished(self):
if self.queue.get_requests_in_queue() > 0:
return False
if self._timeline.get_remaining_events() > 0:
return False
return True
# wrappers for timeline methods to incapsulate this logic
def add_event(self, event):
self._timeline.add_event(event)
def get_next_event(self):
return self._timeline.get_next_item()
# returns free agent if we have one
def get_free_agent(self, customer_request):
for agent in self.agents:
if agent.is_ready_for_task(self.current_time):
return agent
# finds and processes the next event
def process_next_event(self):
event = self.get_next_event()
if self.logging:
print('# Processing event: ' + str(event))
if (event.type == 'plan_agents_schedule') and self.is_simulation_finished():
if self.logging:
print("FINISH")
else:
self.current_time = event.time
event.process(self)
# main function
def run_simulation(self):
while not self.is_simulation_finished():
self.process_next_event()
这是一段漫长的旅程,但我们做到了。做得好!现在,我们拥有了所有需要的逻辑。让我们进入有趣的部分,使用我们的模型进行分析。
你可以在GitHub上找到完整的实现。
分析
我将使用一个合成的客户请求数据集来模拟不同的运营设置。

首先,让我们运行我们的系统并查看指标。我将从 15 个工作正常时间的代理开始。
# initialising agents
regular_work_week = Schedule(
{
'Monday': [(9, 12), (13, 18)],
'Tuesday': [(9, 12), (13, 18)],
'Wednesday': [(9, 12), (13, 18)],
'Thursday': [(9, 12), (13, 18)],
'Friday': [(9, 12), (13, 18)]
}
)
agents = []
for id in range(15):
agents.append(Agent(id + 1, 'Agent %s' % id, regular_work_week))
# inital date
system_initial_date = datetime.date(2024, 4, 8)
# initialising the system
system = System(agents, [], system_initial_date, backlog_df, logging = False)
# running the simulation
system.run_simulation()
执行结果后,我们得到了system.processed_requests中的所有统计数据。让我们编写几个辅助函数,方便分析结果。
# convert results to data frame and calculate timings
def get_processed_results(system):
processed_requests_df = pd.DataFrame(list(map(lambda x: x.__dict__, system.processed_requests)))
processed_requests_df = processed_requests_df.sort_values('creation_time')
processed_requests_df['creation_time_hour'] = processed_requests_df.creation_time.map(
lambda x: x.strftime('%Y-%m-%d %H:00:00')
)
processed_requests_df['resolution_time_secs'] = list(map(
lambda x, y: int(x.strftime('%s')) - int(y.strftime('%s')),
processed_requests_df.finish_time,
processed_requests_df.creation_time
))
processed_requests_df['waiting_time_secs'] = processed_requests_df.resolution_time_secs - processed_requests_df.handling_time_secs
processed_requests_df['waiting_time_mins'] = processed_requests_df['waiting_time_secs']/60
processed_requests_df['handling_time_mins'] = processed_requests_df.handling_time_secs/60
processed_requests_df['resolution_time_mins'] = processed_requests_df.resolution_time_secs/60
return processed_requests_df
# calculating queue size with 5 mins granularity
def get_queue_stats(processed_requests_df):
queue_stats = []
current_time = datetime.datetime(system_initial_date.year, system_initial_date.month, system_initial_date.day, 0, 0, 0)
while current_time <= processed_requests_df.creation_time.max() + datetime.timedelta(seconds = 300):
queue_size = processed_requests_df[(processed_requests_df.creation_time <= current_time) & (processed_requests_df.start_time > current_time)].shape[0]
queue_stats.append(
{
'time': current_time,
'queue_size': queue_size
}
)
current_time = current_time + datetime.timedelta(seconds = 300)
return pd.DataFrame(queue_stats)
此外,让我们做几个图表并计算每周的指标。
def analyse_results(system, show_charts = True):
processed_requests_df = get_processed_results(system)
queue_stats_df = get_queue_stats(processed_requests_df)
stats_df = processed_requests_df.groupby('creation_time_hour').aggregate(
{'id': 'count', 'handling_time_mins': 'mean', 'resolution_time_mins': 'mean',
'waiting_time_mins': 'mean'}
)
if show_charts:
fig = px.line(stats_df[['id']],
labels = {'value': 'requests', 'creation_time_hour': 'request creation time'},
title = '<b>Number of requests created</b>')
fig.update_layout(showlegend = False)
fig.show()
fig = px.line(stats_df[['waiting_time_mins', 'handling_time_mins', 'resolution_time_mins']],
labels = {'value': 'time in mins', 'creation_time_hour': 'request creation time'},
title = '<b>Resolution time</b>')
fig.show()
fig = px.line(queue_stats_df.set_index('time'),
labels = {'value': 'number of requests in queue'},
title = '<b>Queue size</b>')
fig.update_layout(showlegend = False)
fig.show()
processed_requests_df['period'] = processed_requests_df.creation_time.map(
lambda x: (x - datetime.timedelta(x.weekday())).strftime('%Y-%m-%d')
)
queue_stats_df['period'] = queue_stats_df['time'].map(
lambda x: (x - datetime.timedelta(x.weekday())).strftime('%Y-%m-%d')
)
period_stats_df = processed_requests_df.groupby('period')\
.aggregate({'id': 'count', 'handling_time_mins': 'mean',
'waiting_time_mins': 'mean',
'resolution_time_mins': 'mean'})\
.join(queue_stats_df.groupby('period')[['queue_size']].mean())
return period_stats_df
# execution
analyse_results(system)
现在,我们可以使用这个函数来分析模拟结果。显然,15 个代理对于我们的产品来说不够,因为经过三周后,我们队列中有超过 4000 个请求,平均解决时间大约是十天。如果我们只有 15 个代理,客户对我们的服务肯定会非常不满意。


让我们找出需要多少个代理才能应对需求。我们可以进行多个模拟,使用不同数量的代理并比较结果。
tmp_dfs = []
for num_agents in tqdm.tqdm(range(15, 105, 5)):
agents = []
for id in range(num_agents):
agents.append(Agent(id + 1, 'Agent %s' % id, regular_work_week))
system = System(agents, [], system_initial_date, backlog_df, logging = False)
system.run_simulation()
tmp_df = analyse_results(system, show_charts = False)
tmp_df['num_agents'] = num_agents
tmp_dfs.append(tmp_df)
我们可以看到,在大约 25 到 30 个代理的情况下,不同周的指标大致相同,因此有足够的容量来处理不断增加的请求,队列不会一周一周地增长。


如果我们模拟 30 个代理的情况,可以看到从星期二到星期五,队列从 13:50 开始一直到工作日结束都为空。代理们会在星期一处理我们在周末积累的大量队列。

在这种设置下,平均解决时间为 500.67 分钟,平均队列长度为 259.39。
让我们尝试思考一些可能改善我们运营团队的方案:
-
我们可以再雇佣五个代理,
-
我们可以开始利用 LLMs,减少 30%的处理时间,
-
我们可以调整代理的工作时间表,以便在周末和晚间时段提供覆盖。
既然我们现在有了模型,我们可以轻松地估算所有机会并选择最可行的方案。
前两种方法很简单。接下来我们讨论如何调整代理的工作时间安排。我们所有的代理都在周一至周五,从 9 点到 18 点工作。我们尝试让他们的覆盖时间分布更均匀一些。
首先,我们可以覆盖更早和更晚的时间段,将代理分为两组。一组的工作时间是从 7 点到 16 点,另一组是从 11 点到 20 点。
其次,我们可以更加均匀地分配他们的工作日。我使用了一种相当直接的方法。

实际上,你可以进一步减少周末的代理人数,因为我们在周末的需求要少得多。这可以进一步改善你的指标。然而,额外的效果会是边际性的。
如果我们对所有这些情境进行仿真,令人惊讶的是,如果我们只是调整代理的工作时间,KPI 会有显著的提升。如果我们再雇佣五个员工或提高代理的绩效 30%,也无法取得如此显著的改善。

让我们看看代理工作时间的变化如何影响我们的关键绩效指标(KPI)。分辨时间只会在非工作时间(从 20 点到 7 点)增长,队列大小也永远不会超过 200 个案件。


这是一个非常好的结果。我们的仿真模型帮助我们优先考虑运营变更,而不是雇佣更多员工或投资于大语言模型工具的开发。
我们在本文中讨论了这种方法的基本原理。如果你想深入了解并将其应用于实践,这里有一些可能有用的建议:
-
在开始将这些模型投入生产之前,值得先进行测试。最简单的方法是模拟你当前的状况,并比较主要 KPI。如果它们相差很大,那么你的系统并没有很好地反映现实世界,你需要在决策前使其更加准确。
-
当前的指标是以客户为中心的。我使用了平均解决时间作为主要 KPI 来做决策。在商业中,我们还关心成本。因此,从运营角度来看这个问题也是值得的,也就是说,衡量代理没有任务可做的时间百分比(这意味着我们在为他们支付薪水,但他们什么都不做)。
-
在现实生活中,可能会出现突发情况(例如,由于产品中的一个 bug,客户请求数量翻倍),因此我建议你使用这些模型来确保你的客服团队能够应对这种情况。
-
最后但同样重要的是,我使用的模型完全是确定性的(每次运行返回相同的结果),因为处理时间是为每个客户请求定义的。为了更好地理解指标的变动性,您可以为每个代理指定处理时间的分布(取决于任务类型、星期几等),并在每次迭代中从该分布中获取处理时间。然后,您可以多次运行仿真并计算您的指标的置信区间。
摘要
那么,让我们简要总结一下今天讨论的要点:
-
我们已经学习了离散事件仿真方法的基础,这种方法有助于模拟具有可计数事件的离散系统。
-
我们已经复习了 Python 中的面向对象编程和类,因为这种范式比数据分析师通常使用的常规过程化代码更适合此任务。
-
我们已经构建了 CS 团队的模型,并能够估算不同潜在改进对我们的 KPI(解决时间和队列大小)的影响。
非常感谢您阅读本文。如果您有任何后续问题或评论,请在评论区留言。
参考
除非另有说明,否则所有图片均由作者制作。
数据分析与预处理实用指南
数据清理、转换和验证技术,以确保数据质量
·发布于Towards Data Science ·49 分钟阅读·2024 年 10 月 31 日
--

图片来自Danist Soh提供,来源于Unsplash
在这个项目中,我们将利用一个源自虚构公司的数据集,该数据集包括人口统计数据以及对员工进行的心理测评结果。
关键变量包括**年龄**、**性别**、**教育水平**和**薪资**,这些在公司环境中至关重要。主要目标是对这些数据进行预处理,确保后续分析的数据质量和一致性。
虽然数据集是虚构的,但它有效地模拟了现实世界的场景,变量经过精心选择,能够代表与商业环境相关的实际和适用的信息。所有项目文件和额外资源可以在我的 GitHub 上访问:
[## GitHub - Anello92/data-preprocessing-guide
通过在 GitHub 上创建帐户,贡献 Anello92/data-preprocessing-guide 的开发。
在这个项目中,我们将深入探讨基本的预处理技术,解决常见的挑战并找出解决方案。项目的结构将引导我们从数据导入的初步阶段开始…
使用潜在狄利克雷分配(LDA)进行主题建模的实用指南
在最多减少 99%的训练时间内获得更好的结果
·发布于 Towards Data Science ·14 分钟阅读·2024 年 1 月 6 日
--
潜在狄利克雷分配(简称 LDA)是一种混合成员(“软聚类”)模型,通常用于推断文档讨论的内容。当你阅读本文时,可以轻松推断它是关于机器学习、数据科学、主题建模等方面的。但当你面对上百万个文档时,无法手动阅读并标记每一个文档来提取模式和趋势。你需要像 LDA 这样的机器学习模型来帮助你。
LDA 即使在你不处理文本数据时也能有用。文本是经典的应用场景,但并不是唯一的。如果你在一家在线商店工作,你可以使用 LDA 推断产品的软分类。在分类设置中,“巧克力”必须归类为“零食”这一类,而 LDA 允许“巧克力”同时归入多个类别,如“零食”、“烘焙”、“饮料”和“酱料”。你还可以将 LDA 应用于点击流数据,根据观察到的用户行为对页面进行分组和分类。
由于 LDA 是一个概率模型,它可以很好地与其他概率模型(如泊松分解)结合使用。你可以通过 LDA 嵌入项目,然后使用 PF学习用户偏好。在新闻文章的背景下,这可以为“冷启动”推荐提供帮助,当一篇文章刚发布时(或许用于推送通知?),在新闻变得过时之前。
我的资历?我花了一个学期专注于贝叶斯推理算法,并从头开始编写 LDA 代码,以理解其内部工作原理。之后,我在一家新闻集团工作,创建了一个必须扩展到数百万篇文章的 LDA 管道。在这个规模上,许多小的选择可能决定了模型运行时间是几天还是一年。可以说,我比绝大多数数据科学家更了解 LDA。
在所有那段时间里,我从未遇到过一本能够解释如何正确使用 LDA 的资源,特别是在大规模应用时。本文可能是第一个。希望它对你有用,不管你是谁。简而言之:
-
使用 spaCy 而不是 NLTK 进行分词
-
使用特定的 scikit-learn 实现的 LDA
-
将 learning_mode 设置为“在线”
-
知道哪些超参数范围是合理的
-
通过随机搜索选择超参数,使用验证熵作为标准
我假设读者已经熟悉 LDA 的工作原理及其作用。许多文章已经对其进行了说明。我不会重复那些容易找到的信息。
免责声明:本文的内容可能已经过时一两年,因为我已经很长时间没有使用 LDA,但我相信一切仍然是准确的。
为什么选择潜在狄利克雷分配(LDA)?
LDA 及其相关方法(NMF,PF,截断 SVD 等)不过是针对计数数据进行修改的高级 PCA。 (顺便问一下,你看过这个精彩的 PCA 解释 吗?)LDA 与其他方法的不同之处在于,它通过以下特性创建人类可解释的嵌入,呈现为主题:
- 非负。显然,计数不能为负数,但真正重要的是 非负性迫使模型学习部分。我最喜欢的 短篇论文之一 说明了非负性如何迫使模型学习面部的部分,比如鼻子、眼睛、嘴巴等。相比之下,PCA 的加载向量是抽象的,因为你可以从一个部分中减去另一个部分。

从非负矩阵分解中学习到的面部部位。来源:Gillis (2014)
-
和为 1。LDA 中的嵌入是比例。该模型假设混合成员关系,因为文本是复杂的,且很少仅涉及单一主题。
-
稀疏性。嵌入大多是零。每篇文档预计只会讨论少数几个主题。没有人会写一篇包含 100 个主题的文章。
-
人类可解释的加载向量。在 PCA 和其他嵌入算法中,通常无法明确每个维度的含义。而在 LDA 中,你可以通过查看最高概率的词语(“top n words”)来理解每个维度(“主题”)的含义。
一个常见的误解是 LDA 是一种 NLP 算法。事实上,只要数据不太稀疏,你可以在任何计数数据上使用 LDA。LDA 所做的仅仅是创建一个低维度的可解释的计数嵌入。你可以在用户的购买历史或浏览历史上应用 LDA,以推断出不同类型的购物习惯。我过去曾这样使用过,它效果出奇的好。Blei 教授曾在一次研讨会上提到,有一位经济学研究者正是在用 LDA 做类似的实验;那时我感到非常欣慰。
LDA 的输出经常被误解。人们将其当作分类算法,而不是混合归属模型。当 LDA 说一篇文档是 60% 政治和 40% 经济时,它实际上是在说该文档同时是政治和经济的,比例分别是 60% 和 40%。有些人误解为“文档被归类为政治,但模型不太确定”。如果是一篇长篇文章,模型可能非常确定这篇文档既是政治也是经济。
也有替代方法,比如top2vec,它的概念上与word2vec相似。非常酷!然而,我认为 LDA 在几个方面优于 top2vec:
-
LDA 是一种多重归属模型,而 top2vec 假设每篇文档只属于一个主题。如果你的语料库很简单,每篇文档都紧扣一个主题,那么 top2vec 是有意义的。
-
top2vec 使用距离来推断主题,这并不符合直观理解。由于维度诅咒的存在,距离这一概念在高维空间中变得模糊不清。那么这些距离意味着什么呢?作为一个过于简化的例子,假设三个主题在一条数轴上:食物 — 体育 — 科学。如果一篇文档讨论的是食品科学,它就位于中间,结果变成了一篇体育文档?实际上,距离在高维空间中并非如此工作,但我的保留意见应该是显而易见的。
提示 #1:使用 spaCy 代替 NLTK 进行分词和词形还原
语料库在输入 LDA 之前需要经过处理。如何处理呢?spaCy 在业界非常流行,而 NLTK 在学术界非常受欢迎。它们各有优劣。在工作环境中,NLTK 并不真正可接受——不要因为你在学校用它习惯了就继续使用它。
NLTK 以慢著称。我没有进行过自己的比较,但这个人报告称,使用 spaCy 代替 NLTK 在分词时速度提升了 20 倍。
令人惊讶的是,目前还不清楚 LDA 是否从词干提取或词形还原中受益。我见过不同的观点和实验,结果互有胜负。这篇论文声称词干提取会使主题更糟。进行词形还原的主要原因是为了使主题更加可解释,通过将词素归结为一个标记。
我不会对是否应该进行词形还原提供意见,但如果你决定进行词形还原,spaCy 的词形还原速度和效果都比 NLTK 更好。在 NLTK 中,我们需要设置一个词性标注管道,然后将其传递给 WordNet 词形还原器,该词形还原器会在词汇数据库中查找单词。spaCy 使用 word2vec 自动推断词性,这样它就能正确地进行词形还原——使用起来更加简单,速度也更快。
使用 spaCy 时,确保使用基于 word2vec 的en_core_web_lg,而不是基于 transformer 的 en_core_web_trf 语言模型。虽然 transformer 稍微准确一些(也许准确率提升 1%),但根据 spaCy 的速度基准测试,它的速度慢了大约 15 倍。我自己在工作中也观察到了这一差距。对于数百万篇文章,transformer 实在太慢了,因为处理所有内容需要几个月的时间才能完成词形还原和分词。
提示 #2:使用 scikit-learn,而不要触碰其他包来进行 LDA。
这或许是最重要且最令人惊讶的建议:无条件使用 sklearn 的 LDA 实现。性能差异简直无法比拟。我们将其与两个流行的 LDA 模型拟合包进行比较:
-
mallet 使用折叠吉布斯采样,一种 MCMC 算法。(如果你想了解更多关于 MCMC 的内容,可以查看 我的文章。)MCMC 以慢和不可扩展而著称。更糟糕的是,吉布斯采样经常卡在局部极值;大多数 NLP 问题是高度多模态的,这使得 mallet 无法应用于真实世界的任务。
-
gensim 使用随机变分推断(SVI),这是随机梯度下降的贝叶斯类比。作为 LDA 更新规则的一部分,gensim 选择精确计算 digamma 函数,这是一项极为昂贵的操作。而 sklearn 选择了对其进行近似计算,从而实现了 10 到 20 倍的速度提升。更糟糕的是,gensim 的 SVI 实现是错误的,没有任何函数参数可以修复它。准确地说:如果你一次性输入整个语料库,gensim 的 SVI 会正常运行;但如果你在每次迭代时提供一个样本,gensim 的 LDA 就永远无法收敛。
这一点关于 gensim 的发现让我很吃惊。它是一个非常流行的包(每月超过 300 万次下载!),专门用于主题建模——它怎么可能比 sklearn 差,sklearn 是一个通用包呢?在工作中,我花了很多天进行故障排除。我深入研究了源代码。结果,我发现源代码的更新方程有错误。
我在学校时从头开始编写了使用 SVI 训练的 LDA。它运行得非常低效(我是数据科学家,不是机器学习工程师!),但输出是正确的。我知道模型在每次迭代时应该如何更新。gensim 的实现是不正确的。仅仅在第一次迭代之后,结果就偏差如此之大,我不得不将手动计算与 gensim 的输出进行比较,才搞清楚出了什么问题。如果你从 100 篇文档中抽样来输入 SVI 的一次迭代,gensim 会认为你的整个语料库仅有 100 篇文档,尽管你是从一百万篇文档中抽样的。你无法在 update()方法中告诉 gensim 语料库的大小。
如果你一次性提供整个语料库,gensim 运行得很好。然而,在工作中,我处理了数百万篇新闻文章,根本无法将所有内容都放入内存。在处理大规模语料库时,gensim 完全失败。
sklearn 的版本实现是正确的。
提示 #3:使用随机变分推断(SVI)算法进行训练
既然我们已经确定不应该使用除 sklearn 之外的任何工具,我们将参考sklearn 的 LDA 函数。我们将特别讨论学习方法参数: “批量”与“在线”(SVI)类似于线性回归中的“IRLS”与“SGD”。
线性回归的运行时间为 O(n³)。IRLS 需要一次性处理整个数据集。如果我们有一百万个数据点,IRLS 需要 10¹⁸单位的时间。使用 SGD,我们可以在每次迭代中抽取 1,000 个数据点,并运行 1,000 次迭代来逼近 IRLS 的精确解,这将消耗 10⁹ x 10³ = 10¹²单位的时间。在这种情况下,SGD 的运行速度是 IRLS 的一百万倍!SGD 预计会有一些不完美,因为它只是逼近 IRLS 的最优解,但通常足够接近。
使用 SVI 方法时,那个直觉就不适用了:“在线”比“批量”更合适,而且运行速度更快。它是严格更优的。没有任何理由使用“批量”模式。SVI 论文深入探讨了这一点:

一般来说,“在线”模式的训练时间仅为“批量”模式的 10%,且能获得相同的结果。为了在大语料库上正确使用“在线”模式,你必须将 total_samples 设置为语料库中所有文档的总数;否则,如果样本量仅占语料库的一小部分,LDA 模型将无法在合理的时间内收敛。你还需要使用 partial_fit()方法,一次处理一个小批量数据。我将在下一节中讨论其他设置。
提示 #4:了解超参数的合理搜索空间
根据 sklearn 的参数,LDA 有六个可调的超参数:
-
n_components(默认为 10):主题的数量。显而易见。
-
doc_topic_prior(默认为 1/n_components):局部参数的先验。贝叶斯先验相当于正则化,等同于用虚假数据进行填充。doc_topic_prior × n_components表示每篇文档中添加的虚假词汇数量。如果你分析的是推文,1 到 2 个虚假词汇可能是有意义的,但 1000 个虚假词汇完全没有意义。如果你分析的是短篇小说,1 到 2 个虚假词汇几乎可以忽略不计,而 1000 个虚假词汇则是合理的。请运用你的判断力。通常,除非每篇文档非常长,否则值设置为 1 以下。你的搜索空间可以设置为{0.001, 0.01, 0.1, 1}。
-
topic_word_prior(默认为 1/n_components):全局参数的先验。再说一遍,贝叶斯先验相当于正则化,等同于用虚假数据进行填充。topic_word_prior × n_components × n_features表示在任何训练之前,模型中添加的虚假词汇数。n_features 是模型或语料库中标记的数量。如果该乘积为 1000,并且你分析的推文每条平均 10 个词,那么你就会向语料库中添加 100 条虚假推文。请运用你的判断力。
-
learning_decay(默认为 0.7):确定每次迭代时步长的缩小程度。较低的 learning_decay 值使得步长更慢地缩小——模型可以在多模态目标函数中探索更多模式,但收敛速度较慢。你必须将 learning_decay 设置为 0.5 < learning_decay ≤ 1,才能使 LDA 收敛(这适用于任何 SGD 算法,必须满足Robbins-Monro 条件)。有趣的是,gensim 的默认值是 0.5,这会误导不了解的用户,训练一个无法收敛的模型。从经验上来看,0.7 到 0.8 之间的值能获得最佳结果。
-
learning_offset(默认为 10):确定初始步长。较高的值会导致较小的初始步长。从经验来看,当 batch_size 相对于语料库中的文档数量较小时,模型会从较高的 learning_offset 中受益,通常设置在 100 以上。你希望采取较大的步伐。搜索{1, 2, 3, 4}的效果不如搜索{1, 10, 100, 1000}。
-
batch_size(默认值 = 128):每次迭代时 SVI 看到的文档数量。可以将其视为一个不精确的指南针。batch_size 越大,你对自己朝正确方向迈步的确定性就越强,但计算的时间也会越长。根据我的经验,128 太小了,因为步骤往往走错方向,这使得模型更难以收敛。我推荐一个大约 2–10 千的 batch_size,SVI 可以轻松处理。如果计算时间不成问题,更大的 batch_size 几乎总是更好。在超参数调优时,我通常会在心里设定一个固定数量的(带替换的)文档,比如 500k,并设置运行 50 次 batch_size 为 10,000 的迭代,或 250 次 batch_size 为 2,000 的迭代,以比较哪个设置能在计算上获得更多的回报。然后,我会保持这些设置,进行更多的迭代训练。你需要为
partial_fit()方法提供一个随机采样的文档,大小为 batch_size。
提示#5:使用随机搜索和熵损失调优超参数
在如今的时代,随机搜索应该是超参数调优的默认算法。 在仅 60 次迭代内,随机搜索有超过 95%的概率找到搜索空间中最佳 5%的超参数(证明)。当然,如果你的搜索空间完全错过了最佳区域,你永远无法获得良好的性能。
这篇论文由 Bergstra 和 Bengio 撰写,说明了随机搜索能够合理地击败网格搜索。网格搜索对不影响特定用例的超参数赋予了过多的关注。如果两个超参数中只有一个对目标有显著影响,那么一个 3x3 的网格仅会尝试那个超参数的三个值;而 9 点的随机搜索则会尝试该超参数的九个不同值,这给了你更多的机会去找到一个优秀的值。网格搜索也常常会忽略那些表现优良的狭窄区域。
使用 SVI 拟合的 LDA 有六个可调超参数(如果使用全批次,则只有三个)。如果我们想为每个超参数尝试少至三个值,那么我们的网格搜索将经历 3⁶ = 729 次迭代。使用随机搜索将其减少到 60 次(通常)能获得更好的结果,这显而易见。
随机搜索应该配置为“智能地”采样。n_components可以从离散均匀分布中采样,但其他超参数,如doc_topic_prior,应从对数正态分布或对数均匀分布中采样,也就是说,与其在{1, 2, 3, 4}中采样,不如在{0.01, 0.1, 1, 10}之间均匀采样更为智能。
如果你想稍微比随机搜索做得更好,你可以通过hyperopt 包使用 TPE。与使用高斯过程的贝叶斯优化不同,TPE 设计上更适合混合使用连续和离散(n_components)超参数。然而,考虑到投入的工作量,它带来的改进非常有限,因此在大多数情况下不值得使用。
好的,现在我们已经确认随机搜索比网格搜索更好……那我们如何知道哪个超参数组合表现最佳呢?
主题建模有一个特定的度量指标:主题一致性。它有多种形式,例如 UMass 和 UCI。根据我的经验,一致性在实际应用中并不是一个好的度量标准,因为它通常无法在验证集上计算。当一个词汇没有出现在验证集中时,这个度量就会试图除以零。主题一致性对于超参数调优是没用的。
传统上,语言模型的评估使用困惑度,定义为 2^熵。然而,当超参数不好时,这个数值可能非常大,导致数值溢出错误。sklearn 的 LDA 有一个score方法,它是负熵的近似值。使用 sklearn 的 score。分数越高越好。(如果 score 方法仍然遇到溢出问题,你将需要自己创建对数困惑度方法。)
提示:你可以为主题创建先验
LDA 的输出可能非常不一致且随机。这是任何 NLP 问题的固有特性。目标函数是多模态的,而 SVI LDA 只适合于一个单一模式。即使使用完全相同的设置重新运行 LDA,也可能得到不同的主题。
有时候,我们需要更好地控制 LDA 所学习的主题。例如,业务相关方可能需要确保存在十个特定的主题。你可以尝试一遍又一遍地运行 LDA,直到这十个主题出现,但你更有可能在玩轮盘时运气更好。
解决方案?尽管 sklearn 文档中说 topic_word_prior 接受一个单一的浮动值,它其实可以接受一个矩阵! 我深入源码发现,sklearn 实际上创建了一个矩阵,矩阵中的所有元素都是输入的浮动值。然而,如果你提供了正确维度的矩阵,LDA 会使用你提供的矩阵。

一个好的先验会在模型训练开始之前,通过为每个文档中的某些单词进行颜色编码。来源:rawpixel
假设你需要一个篮球话题和一个高尔夫话题。你可以将一个话题的先验填充为包含高概率篮球相关词汇的分布。同样地,处理高尔夫话题,然后将另一个话题的先验填充为均匀分布。当你训练模型时,LDA 会更有可能生成这两个话题。
我说的是更有可能。LDA 是通过随机方法拟合的。我们无法根据初始设置预料它最终会在哪里结束。
然而,通过对设置进行一些调整,我们可以增加这些话题出现的可能性:提高 learning_offset 值,并增加 learning_decay 值,同时进行更多迭代(因为模型变得更慢,收敛需要更多时间)。相反,这两个超参数的低值将立即抹去你设置的任何先验。
最后
希望本文能清楚地表明,99%的训练时间减少并不是为了吸引眼球。一个对 LDA 知之甚少的人,合理的做法是使用 NLTK 进行分词,采用 gensim 的随机变分推断算法,然后在一个低效的搜索空间内进行网格搜索。从 NLTK 切换到 spaCy 可以提升 8 到 20 倍的速度,但这是模型管道中的一个单独且相对较小的部分。我们将重点关注模型训练方面。遵循本文中的所有建议可以带来以下改进:
-
对 LDA 不太熟悉的人可能会使用 gensim。sklearn 对目标函数的实现本身就能将训练时间缩短 10 到 20 倍。我们保守估计,它可以将训练时间缩短至原来的 10%。
-
或者,LDA 不熟悉的人可能会从 sklearn 开始,但使用‘批处理’模式。从全批次变分推断切换到随机变分推断可以将时间缩短 10 倍。这也将训练时间缩短至 10%。
-
我们需要调优六个超参数。如果我们想尝试每个参数的 3 个不同值并进行网格搜索,那将需要 729 次迭代。而随机搜索只需要 60 次迭代就能表现得很好,而且它很可能会超过网格搜索。这相当于减少了 10 倍的计算量,将训练时间缩短到原来的 1%。
将模型训练时间减少 100 倍并不是唯一的结果。如果你按照本文中的建议进行操作,模型应该会生成更合适的主题,使其更有意义。
数据科学的很多部分仅仅是对算法的表面理解,并随机地投掷东西,看看什么会有效。专业知识常常被标签化为过于学究(尤其在“科学”领域!)。然而,深入理解让我们能够更加高效地使用工具,我敦促大家深入研究我们选择使用的工具。


浙公网安备 33010602011771号