Python-加速指南-全-
Python 加速指南(全)
原文:Fast Python
译者:飞龙
前置内容
前言
几年前,我们团队正在开发的一个基于 Python 的流水线突然停止了。一个进程一直在使用 CPU 并且没有完成。这个功能对公司至关重要,我们需要尽快解决这个问题。我们检查了算法,它看起来没问题——事实上,它是一个非常简单的实现。经过几个小时几位工程师的共同努力,我们发现问题归结为在一个列表上进行搜索——一个非常大的列表。在将列表转换为集合后,问题就轻易解决了。我们最终得到了一个更小的数据结构,搜索时间从小时缩短到毫秒。
那时我有了几个顿悟:
-
这是一个简单的问题,但我们的开发流程并没有关注性能问题。例如,如果我们经常使用性能分析器,我们会在几分钟内而不是几小时后发现性能问题。
-
这是一个双赢的局面:我们最终节省了时间和内存。是的,在许多情况下,需要做出权衡,但在其他情况下,有一些非常有效的结果而没有副作用。
-
从更广阔的角度来看,这种情况也是一个双赢的局面。首先,更快的结果对公司财务状况大有裨益。其次,一个好的算法使用更少的 CPU 时间,这意味着更少的电力消耗,而减少电力消耗(即资源使用)对地球更有益。
-
虽然我们的单一案例对节省能源帮助不大,但我突然意识到许多程序员正在设计类似的解决方案。
我决定写这本书,让其他程序员也能从我的顿悟中受益。我的目标是帮助经验丰富的 Python 程序员设计并实现更高效的解决方案,同时理解潜在的权衡。我想通过讨论纯 Python 和重要的 Python 库,从算法的角度出发,考虑现代硬件架构及其影响,并讨论 CPU 和存储性能,来全面地探讨这个主题。我希望这本书能帮助你在 Python 生态系统开发中更加自信地面对性能问题。
致谢
我要感谢开发编辑 Frances Lefkowitz,她的耐心无限。我还要感谢我的女儿和妻子,在我写这本书的过去几年里,她们不得不忍受我的缺席。还要感谢 Manning 出版社的生产团队,他们帮助创建了这本书。
致所有审稿人:Abhilash Babu、Jyotheendra Babu、Andrea Smith、Biswanath Chowdhury、Brian Griner、Brian S Cole、Dan Sheikh、Dana Robinson、Daniel Vasquez、David Paccoud、David Patschke、Grzegorz Mika、James Liu、Jens Christian B. Madsen、Jeremy Chen、Kalyan Reddy、Lorenzo De Leon、Manu Sareena、Nik Piepenbreier、Noah Flynn、Or Golan、Paulo Nuin、Pegah T. Afshar、Richard Vaughan、Ruud Gijsen、Shashank Kalanithi、Simeon Leyzerzon、Simone Sguazza、Sriram Macharla、Srutu Shivakumar、Steve Love、Walter Alexander Mata López、William Jamir Silva 和 Xie Yikuan——您的建议帮助使这本书变得更好。
关于这本书
这本书的目的是帮助你编写在 Python 生态系统中的更高效的应用程序。更高效意味着你的代码将使用更少的 CPU 循环、更少的存储空间和更少的网络通信。
本书对性能问题采取了整体的方法。我们不仅讨论了纯 Python 中的代码优化技术,还考虑了广泛使用的数据库,如 NumPy 和 pandas 的有效使用。由于 Python 在某些情况下性能不足,当我们需要更多速度时,我们也考虑了 Cython。与这种整体方法一致,我们还讨论了硬件对代码设计的影响:我们分析了现代计算机体系结构对算法性能的影响。我们还考察了网络架构对效率的影响,并探讨了 GPU 计算在快速数据分析中的应用。
谁应该阅读这本书?
这本书旨在面向中级到高级的读者。如果你浏览了目录,你应该能认出大多数技术,你很可能已经使用过其中的一些。除了关于 IO 库和 GPU 计算的部分,提供的入门材料很少:你需要已经了解基础知识。如果你目前正在编写代码以实现高性能,并面临处理大量数据时的实际挑战,那么这本书适合你。
为了从这本书中获得最大收益,你应该至少有几年的 Python 经验,并了解 Python 控制结构以及列表、集合和字典。你应该有一些 Python 标准库的经验,如 os、sys、pickle 和 multiprocessing。为了充分利用我这里介绍的技术,你也应该对标准数据分析库有一定的了解,如 NumPy——至少对数组有最小了解——以及 pandas——对数据框有一些经验。
如果你对通过 C 或 Rust 的外部语言接口加速 Python 代码的方式有所了解,即使你没有直接的接触,或者知道像 Cython 或 Numba 这样的替代方法,这将很有帮助。处理 Python 中的 IO 的经验也会有所帮助。鉴于 IO 库在文献中较少被探索,我们将从 Apache Parquet 这样的格式和 Zarr 这样的库开始。
你应该了解 Linux 终端(或 MacOS 终端)的基本 shell 命令。如果你使用的是 Windows,请安装基于 Unix 的 shell 或者熟悉命令行或 PowerShell。当然,你需要在电脑上安装 Python 软件。
在某些情况下,我会提供关于云的一些提示,但阅读这本书并不需要云访问或相关知识。如果你对云方法感兴趣,那么你应该知道如何进行基本操作,比如创建实例和访问云提供商的存储。
虽然你不需要在该领域接受学术训练,但了解复杂度成本的基本概念会有所帮助——例如,直观地认为与数据线性扩展的算法比与数据指数扩展的算法更好。如果你计划使用 GPU 优化,在这个阶段你不需要了解任何相关内容。
本书是如何组织的:一个路线图
本书中的章节大多是独立的,你可以跳转到对你重要的任何章节。但话说回来,本书分为四个部分。
第一部分,基础方法(第一章至第四章),涵盖了入门材料。
-
第一章介绍了问题,并解释了为什么我们必须关注计算和存储中的效率。它还介绍了本书的方法,并提供了针对你需求的导航建议。
-
第二章涵盖了原生 Python 的优化。我们还讨论了 Python 数据结构的优化、代码分析、内存分配和懒加载编程技术。
-
第三章讨论了 Python 中的并发和并行性以及如何最大限度地利用多进程和多线程(包括使用线程进行并行处理时的限制)。本章还涵盖了异步处理,作为一种处理多个低负载并发请求的高效方式,这在网络服务中很常见。
-
第四章介绍了 NumPy 库,这是一个允许你高效处理多维数组的库。NumPy 是所有现代数据处理技术的核心,因此被视为一个基础库。本章分享了特定的 NumPy 技术,以开发更高效的代码,例如视图、广播和数组编程。
第二部分,硬件(第五章和第六章),主要关注从常见硬件和网络中提取最大效率。
-
第五章涵盖了 Cython,它是 Python 的超集,可以生成非常高效的代码。Python 是一种高级解释型语言,因此并不期望它在硬件级别上进行优化。有一些语言,如 C 或 Rust,被设计成在硬件级别上尽可能高效。Cython 属于这个语言领域:虽然它与 Python 非常接近,但它编译成 C 代码。生成最有效的 Cython 代码需要关注代码如何映射到高效的实现。在本章中,我们学习如何创建高效的 Cython 代码。
-
第六章讨论了现代硬件架构对高效 Python 代码设计的影响。鉴于现代计算机的设计方式,一些看似不直观的编程方法可能比预期的更有效。例如,在某些情况下,处理压缩数据可能比处理未压缩数据更快,即使我们需要付出解压缩算法的代价。本章还涵盖了 CPU、内存、存储和网络对 Python 算法设计的影响。我们讨论了 NumExpr 库,它可以通过利用现代硬件架构的特性来使 NumPy 代码更高效。
第三部分,现代数据处理的应用和库(第七章和第八章),探讨了现代数据处理中使用的典型应用和库。
-
第七章专注于尽可能高效地使用 pandas,这是 Python 中使用的 DataFrame 库。我们将探讨与 pandas 相关的技术来优化代码。与本书中的大多数章节不同,这一章是从早期章节开始的。pandas 建立在 NumPy 之上,因此我们将从第四章中学到的知识中汲取,并发现与 NumPy 相关的优化技术。我们还探讨了如何使用 NumExpr 和 Cython 来优化 pandas。最后,我介绍了 Arrow 库,这个库除了其他功能外,还可以用于提高处理 pandas 数据框的性能。
-
第八章探讨了数据持久性的优化。我们讨论了 Parquet,一个用于高效处理列式数据的库,以及 Zarr,它可以处理非常大的磁盘数组。我们还开始讨论如何处理大于内存的数据集。
第四部分,高级主题(第九章和第十章),处理了两种最终且截然不同的方法:使用 GPU 以及使用 Dask 库。
-
第九章探讨了使用图形处理单元(GPU)处理大型数据集的用途。我们将看到 GPU 计算模型——使用许多简单的处理单元——对于处理现代数据科学问题来说是相当合适的。我们将采用两种不同的方法来利用 GPU。首先,我们将讨论提供类似接口的现有库,例如 CuPy 作为 NumPy 的 GPU 版本。其次,我们将介绍如何从 Python 生成在 GPU 上运行的代码。
-
第十章讨论了 Dask 库,这是一个允许你编写可扩展到多台机器的并行代码的库——无论是在本地还是云端——同时提供类似于 NumPy 和 pandas 的熟悉接口。
本书还包括两个附录。
-
附录 A 将指导你安装使用本书示例所需的所有软件。
-
附录 B 讨论了 Numba,它是 Cython 的替代品,用于生成高效的底层代码。Cython 和 Numba 是生成底层代码的主要途径。为了解决现实世界的问题,我推荐使用 Numba。那么,为什么我会在整章中专门介绍 Cython,而将 Numba 放在书的后面呢?因为本书的主要目的是为你提供一个坚实的 Python 生态系统高效编写代码的基础,而 Cython 的额外障碍使我们能够更深入地理解正在发生的事情。
关于代码
本书包含许多源代码示例,既有编号的列表,也有与普通文本混排。在这两种情况下,源代码都以固定宽度字体如这样来格式化,以将其与普通文本区分开来。有时代码也会**加粗`**,以突出显示与章节中先前步骤不同的代码,例如当新功能添加到现有代码行时。
在许多情况下,原始源代码已被重新格式化;我们添加了换行并重新调整了缩进,以适应书中的可用页面空间。在极少数情况下,即使这样也不够,列表中还包括了行续接标记(➥)。此外,当代码在文本中描述时,源代码中的注释通常也会从列表中删除。许多列表旁边都有代码注释,突出显示重要概念。
你可以从本书的 liveBook(在线)版本中获取可执行的代码片段,网址为livebook.manning.com/book/fast-python。本书中示例的完整代码可在 GitHub 上下载,网址为github.com/tiagoantao/python-performance,以及 Manning 网站www.manning.com。当发现错误或 Python 和现有库的重大发展需要一些修订时,我将更新存储库。因此,请预期本书存储库中会有一些变化。你将在存储库中找到一个每个章节的目录。
无论你偏好哪种代码风格,我都已经调整了这里的代码,使其在印刷书籍中运行良好。例如,我倾向于使用长且描述性的变量名,但这些在书籍形式的限制下并不适用。我尝试使用表达性的名称并遵循标准的 Python 约定,如 PEP8,但书籍的可读性更为重要。同样,对于类型注解也是如此:我愿意使用它们,但它们会妨碍代码的可读性。在极少数情况下,我会使用算法来提高可读性,即使它并不处理所有边缘情况或对解释增加多少帮助。
在大多数情况下,本书中的代码将与标准的 Python 解释器兼容。在某些有限的场景中,需要 IPython,特别是为了方便的性能分析。你也可以使用 Jupyter Notebook。
关于安装的详细信息可以在附录 A 中找到。如果任何章节或部分需要特殊软件,将在适当的位置注明。
liveBook 讨论论坛
购买《Fast Python》包括对 liveBook 的免费访问,这是 Manning 的在线阅读平台。使用 liveBook 的独特讨论功能,你可以对整本书、特定章节或段落附加评论。为自己做笔记、提出和回答技术问题,以及从作者和其他用户那里获得帮助都非常简单。要访问论坛,请访问livebook.manning.com/book/fast-python/discussion。你还可以在livebook.manning.com/discussion了解更多关于 Manning 论坛和行为准则的信息。
Manning 对我们读者的承诺是提供一个平台,让读者之间以及读者与作者之间可以进行有意义的对话。这并不是作者在特定程度上参与的承诺,作者对论坛的贡献仍然是自愿的(且未付费)。我们建议你尝试向作者提出一些挑战性的问题,以免他们的兴趣分散!论坛和以前讨论的存档将可以在书籍印刷期间通过出版社的网站访问。
硬件和软件
你可以使用任何操作系统来运行本书中的代码。话虽如此,Linux 是大多数生产代码倾向于部署的地方,因此它是首选的系统。MacOS X 也应该无需任何调整即可运行。如果你使用 Windows,我建议你安装 Windows Subsystem for Linux (WSL)。
所有操作系统的替代方案是 Docker。你可以使用存储库中提供的 Docker 镜像。Docker 将提供一个容器化的 Linux 环境来运行代码。
我建议您至少拥有 16 GB 的内存和 150 GB 的空闲磁盘空间。第九章包含与 GPU 相关的内容,需要至少基于 Pascal 架构的 NVIDIA GPU;过去五年中发布的多数 GPU 都应满足这一要求。关于如何准备您的计算机和软件以充分利用本书的更多详细信息,请参阅附录 A。
关于作者
Tiago Rodrigues Antão 拥有信息学学士学位和生物信息学博士学位。他目前从事生物技术领域的工作。Tiago 使用 Python 及其所有库来执行科学计算和数据工程任务。他经常使用诸如 C 和 Rust 等底层编程语言来优化算法的关键部分。他目前在一个基于 Amazon AWS 的基础设施上开发,但他在职业生涯的大部分时间里使用的是本地计算和科学集群。
除了在工业界工作外,他在科学计算学术方面的经验包括在剑桥大学和牛津大学担任两个数据分析博士后。作为蒙大拿大学的科研人员,他从零开始创建了用于分析生物数据的整个科学计算基础设施。
Tiago 是 Python 编写的生物信息学主要软件包Biopython的合著者之一,也是《Python 生物信息学食谱》(Packt,2022 年)一书的作者,该书已进入第三版。他还撰写和合著了许多在生物信息学领域的重大科学论文。
关于封面插图
《快速 Python》封面上的插图标题为“Bourgeoise de Passeau”,或“Bourgeoise of Passeau”,取自雅克·格拉塞·德·圣索沃尔(Jacques Grasset de Saint-Sauveur)1797 年出版的作品集。每一幅插图都是手工精心绘制和着色的。
在那些日子里,人们通过他们的服饰就能轻易地识别出他们的居住地以及他们的职业或社会地位。Manning 通过基于几个世纪前丰富多样的地区文化的封面设计,庆祝计算机行业的创新精神和主动性,这些文化通过如这一系列图片被重新带回生活。
第一部分. 基础方法
在本书的第一部分,我们将讨论关于使用 Python 的性能的基础方法。我们将涵盖原生 Python 库和基本数据结构,以及 Python 如何在没有外部库的情况下利用并行处理技术。还包括一个关于 NumPy 优化的完整章节。虽然 NumPy 是一个外部库,但它对现代数据处理如此关键,以至于它和纯 Python 方法一样基础。
1 数据处理效率的迫切需求
本章涵盖
-
应对数据指数级增长带来的挑战
-
比较传统和最近的计算架构
-
Python 在现代数据分析中的角色和不足
-
提供高效 Python 计算解决方案的技术
每时每刻都在以极高的速度从广泛的来源收集大量数据。无论是否有当前的使用需求,都会被收集。无论是否有处理、存储、访问或从中学习的方法,都会被收集。在数据科学家能够分析它之前,在设计师、开发人员和政策制定者能够用它来创建产品、服务和计划之前,软件工程师必须找到存储和处理它的方法。现在比以往任何时候,这些工程师都需要更有效的方法来提高性能和优化存储。
在这本书中,我分享了我自己在工作中使用的性能和存储优化策略的集合。简单地向问题投入更多的机器通常既不可能也不有帮助。因此,我在这里介绍的方法更多地依赖于理解和利用我们手头拥有的东西:编码方法、硬件和系统架构、可用的软件,当然还有 Python 语言、库和生态系统的细微差别。
正如俗语所说,Python 已经成为处理或至少粘合围绕这一数据洪流的重型工作的首选语言。确实,Python 在数据科学和数据工程中的流行是推动语言增长的主要驱动力之一,帮助它成为大多数开发者调查中排名前三的最受欢迎的语言之一。Python 在处理大数据方面有其独特的优势和局限性,其缺乏速度无疑带来了挑战。从积极的一面来看,正如您将看到的,有许多不同的角度、方法和解决方案可以使 Python 更有效地处理大量数据。
在我们到达解决方案之前,我们需要完全理解问题(s),这正是我们在第一章的大部分内容中将要做的。我们将花几分钟时间更仔细地看看数据洪流带来的计算挑战,以便我们能够确定我们正在处理的是什么。接下来,我们将检查硬件、网络和云架构在其中的作用,以了解为什么像增加 CPU 速度这样的旧解决方案已经不再足够。然后,我们将转向 Python 在处理大数据时面临的特定挑战,包括 Python 的线程和 CPython 的全局解释器锁(GIL)。一旦我们完全理解了需要新的方法来提高 Python 的性能,我将在本书中概述您将学习的解决方案。
1.1 数据洪流有多糟糕?
你可能已经知道两个计算定律,摩尔定律和埃德霍尔姆定律,它们共同描绘了数据指数增长以及计算系统处理这些数据的滞后能力。埃德霍尔姆定律指出,电信中的数据速率每 18 个月翻一番,而摩尔定律预测,每两年可以在微芯片上放置的晶体管数量翻一番。我们可以将埃德霍尔姆的数据传输速率视为收集数据量的代理,将摩尔定律的晶体管密度视为计算硬件的速度和容量的指标。当我们把它们放在一起时,我们发现我们收集数据速度和存储处理数据的能力之间存在六个月的滞后。由于指数增长在文字上难以理解,我在一张图表中绘制了这两个定律,如图 1.1 所示。

图 1.1 摩尔定律与埃德霍尔姆定律的比例表明,硬件将始终落后于生成数据的数量。此外,随着时间的推移,差距将会增大。
该图表所描述的情况可以看作是我们需要分析的内容(埃德霍尔姆定律)与我们进行该分析的能力(摩尔定律)之间的斗争。实际上,这个图表描绘的比现实中的情况更加乐观。我们将在第六章讨论摩尔定律与现代 CPU 架构的背景下了解这一点。为了集中讨论数据增长,让我们看看一个例子,即互联网流量,它是可用数据的间接衡量标准。如图 1.2 所示,多年来互联网流量的增长与埃德霍尔姆定律非常吻合。

图 1.2 过去几年全球互联网流量的增长,以每月千兆字节为单位。(来源:en.wikipedia.org/wiki/Internet_traffic.)
此外,人类在过去两年内产生的 90%的数据(参见“大数据及其意义”,mng.bz/v1ya)。这些新数据的品质是否与其规模成比例是另一个问题。重要的是,产生出的数据需要被处理,而处理则需要资源。
不仅可用的数据量给软件工程师带来了障碍,这些新数据的表示方式也在本质上发生变化。一些预测到 2025 年,大约 80%的数据将是非结构化的(“挖掘非结构化数据的力量”,mng.bz/BlP0)。我们将在本书的后面部分详细讨论这个问题,但简单来说,非结构化数据从计算的角度来看使得数据处理更加复杂。
我们如何处理所有这些数据增长?据《卫报》报道,超过 99% 产生出来的数据从未被分析。mng.bz/Q8M4。阻碍我们利用如此多的数据的一部分原因是,我们缺乏分析它的有效程序。
数据的增长以及随之而来的对更多处理的需求已经发展成为关于计算的最有害的咒语之一:“如果你有更多的数据,只需向它投入更多的服务器。”由于许多原因,这通常不是一个可行或适当的解决方案。相反,当我们需要提高现有系统的性能时,我们可以查看系统架构和实现,并找到我们可以优化性能的地方。我已经数不清有多少次只是通过在审查现有代码时关注效率问题,就能实现性能的十倍提升。
非常重要的是要理解,需要分析的数据量增加与分析所需的基础设施复杂度之间的关系几乎不是线性的。解决这些问题需要开发者投入更多的时间和创造力,而不仅仅是机器。这不仅适用于云环境,也适用于内部集群,甚至适用于单机实现。一些用例将有助于阐明这一点。例如:
-
你的解决方案只需要一台计算机,但突然你需要更多的机器。 添加机器意味着你将不得不管理机器的数量,在它们之间分配工作负载,并确保数据被正确分区。你可能还需要添加一个文件系统服务器到你的机器列表中。维护一个服务器农场或仅仅是云的成本,在质量上远比维护一台计算机的成本要高得多。
-
你的解决方案在内存中运行良好,但随着数据量的增加,它不再适合你的内存。 处理存储在磁盘中的新数据量通常需要重写你的代码。当然,代码本身的复杂性也会增加。例如,如果主数据库现在在磁盘上,你可能需要创建一个缓存策略。或者你可能需要从多个进程中并发读取,或者更糟糕的是,并发写入。
-
你使用的是 SQL 数据库,突然你达到了服务器的最大吞吐量容量。 如果只是读取容量问题,那么你可能会通过仅仅创建几个读取副本来存活。但如果是写入问题,你该怎么办?也许你会设置分片。¹ 或者你决定完全改变你的数据库技术,以支持一些据说性能更好的 NoSQL 变体?
-
如果你依赖于基于供应商专有技术的云系统,你可能会发现无限扩展的能力更多的是营销话术而不是技术现实。 在许多情况下,如果你遇到性能限制,唯一现实的解决方案是改变你正在使用的科技,这种改变需要巨大的时间、金钱和人力。
我希望这些例子能说明,增长不仅仅是“增加更多机器”的问题,而是需要在多个方面进行大量工作以应对增加的复杂性。即使是“简单”如在一个计算机上实现的并行解决方案,也可能带来并行处理的所有问题(竞争、死锁等)。这些更有效的解决方案可以对复杂性、可靠性和成本产生重大影响。
最后,我们可以提出这样的观点:即使我们能够线性扩展我们的基础设施(实际上我们做不到),我们也需要考虑伦理和生态问题:预测表明与“数据海啸”相关的能源消耗占全球电力生产的 20%(《数据海啸》,mng.bz/X5GE),而且在我们更新硬件时,也存在垃圾填埋处理的问题。
好消息是,在处理大数据时提高计算效率可以帮助我们降低计算账单、解决方案架构的复杂性、我们的存储需求、我们的上市时间和我们的能源足迹。有时,更有效的解决方案甚至可能带来最低的实施成本。例如,合理使用数据结构可能在没有显著开发成本的情况下减少计算时间。
另一方面,我们将要探讨的许多解决方案都会有开发成本,并且它们自身也会增加一定的复杂性。当你查看你的数据和其增长预测时,你将不得不做出判断,决定在哪里进行优化,因为没有明确的食谱或适合所有情况的解决方案。尽管如此,可能只有一条普遍适用的规则:如果这个解决方案对 Netflix、Google、Amazon、Apple 或 Facebook 有益,那么很可能它对你来说不是一个好的选择,除非,当然,你在这家公司工作。
我们大多数人将看到的数据量将远远低于最大的科技公司使用的量。它仍然会很大,仍然很难处理,但可能低几个数量级。在我看来,普遍认为那些公司适用的解决方案也适合我们所有人的观点是错误的。通常,更简单的解决方案对我们大多数人来说更合适。
正如你所见,这个新世界,数据量和算法的复杂度都在极端增长,需要更复杂的技巧来高效且成本意识地进行计算和存储。请别误会我:有时你确实需要扩展你的基础设施。但在架构和实现你的解决方案时,你仍然可以使用关注效率的相同思维方式。只是技术将不同。
1.2 现代计算机架构和高性能计算
创建更有效的解决方案并不是在抽象的虚空中发生的。首先,我们必须考虑我们的领域问题——也就是说,你正在尝试解决什么实际问题。同样重要的是我们的解决方案将运行的计算架构。计算架构在确定最佳优化技术方面发挥着重要作用,因此在我们设计软件解决方案时必须考虑它们。在本节中,我们将探讨影响我们解决方案设计和实现的 主要架构问题。
1.2.1 计算机内部的变化
计算机内部正在发生激进的变化。首先,我们有 CPU,它们主要通过增加并行单元的数量来提高处理能力,而不是像过去那样单纯提高速度。计算机还可以配备图形处理单元(GPU),这些单元最初是为了图形处理而开发的,但现在也可以用于通用计算。实际上,许多高效的 AI 算法实现都是针对 GPU 进行的。不幸的是,至少从我们的角度来看,GPU 的架构与 CPU 完全不同:它们由成千上万的计算单元组成,这些单元预计将在所有单元上执行相同的“简单”计算。内存模型也完全不同。这些差异意味着编程 GPU 需要与编程 CPU 截然不同的方法。
要了解我们如何使用 GPU 进行数据处理,我们需要了解它们的原始用途和架构影响。正如其名所示,GPU 是为了帮助图形处理而开发的。一些计算需求最密集的应用实际上是游戏。游戏以及一般的图形应用,都在不断更新屏幕上的数百万个像素。为了解决这个问题而设计的硬件架构拥有许多小的处理核心。GPU 拥有数千个核心是很常见的,而 CPU 通常只有不到 10 个。GPU 核心相对简单,每个核心上运行的是相同的代码。因此,它们非常适合运行大量的相似任务,比如更新像素。
由于 GPU 中处理能力的巨大,人们试图利用这种能力来完成其他任务,这看起来是通用计算在图形处理单元(GPGPU)上的出现。由于 GPU 架构的组织方式,它们主要适用于本质上是大规模并行的任务。结果发现,许多现代 AI 算法,如基于神经网络的算法,往往具有大规模并行性。因此,两者之间有一个自然契合点。
不幸的是,CPU 和 GPU 之间的区别不仅仅是核心数量和它们的复杂性。GPU 内存,尤其是在计算能力最强的 GPU 上,与主内存是分开的。因此,也存在在主内存和 GPU 内存之间传输数据的问题。所以,当我们针对 GPU 时,我们必须考虑两个巨大的问题。
由于在第九章中将会变得清晰的原因,使用 Python 编程 GPU 比针对 CPU 要困难得多,也不太实用。尽管如此,仍然有足够的空间从 Python 中利用 GPU。
虽然不如 GPU 的进步那么时尚,但 CPU 编程的方式也经历了巨大的变化。而且,与 GPU 不同,我们可以很容易地在 Python 中使用这些 CPU 的大部分变化。制造商现在提供 CPU 性能提升的方式与过去不同。受物理定律驱动的解决方案是内置更多的并行处理,而不是更多的速度。摩尔定律有时被表述为每 24 个月速度翻倍,但实际上这并不是正确的定义:它实际上与每两年晶体管密度翻倍相关。速度和晶体管密度之间的线性关系在十多年前就已经破裂,速度自那时以来基本上已经达到了平台期。鉴于数据量和算法复杂性的持续增长,我们陷入了危险的境地。CPU 制造商提出的第一个解决方案是允许更多的并行性:每台计算机更多的 CPU,每个 CPU 更多的核心,以及同时多线程。处理器不再真正加速顺序计算,而是允许更多的并发执行。这种并发执行需要我们在编程计算机时进行范式转变。以前,当你更换 CPU 时,程序的速度会“神奇地”增加。现在,提高速度取决于程序员是否意识到底层架构向并行编程范式的转变。
我们在编程现代 CPU 方式上有很多变化,正如你将在第六章中看到的,其中一些变化非常不符合直觉,值得从一开始就密切关注。例如,尽管近年来 CPU 速度已经趋于平稳,但 CPU 仍然比 RAM 快几个数量级。如果没有 CPU 缓存,那么 CPU 将大部分时间都处于空闲状态,因为它们会花费大部分时间等待 RAM。这意味着有时与压缩数据(包括解压缩成本)相比,使用原始数据更快。为什么?如果你可以将压缩块放在 CPU 缓存中,那么那些原本会空闲等待 RAM 访问的周期就可以用来使用 CPU 周期解压缩数据,而此时还有富余的 CPU 周期可以用于计算!类似的论点也可以适用于压缩文件系统:它们有时可能比原始文件系统更快。在 Python 世界中有直接的应用;例如,通过更改一个简单的布尔标志,关于 NumPy 数组内部表示的选择,你可以利用缓存局部性问题,显著加快你的 NumPy 处理速度。表 1.1 包含了不同类型内存的访问时间和大小,包括 CPU 缓存、RAM、本地磁盘和远程存储。这里的关键点不是精确的数字,而是大小和访问时间上的数量级差异。
表 1.1 假设但现实的现代桌面内存层次结构,包括大小和访问时间
| 类型 | 大小 | 访问时间 |
|---|---|---|
| CPU | ||
| L1 缓存 | 256 KB | 2 ns |
| L2 缓存 | 1 MB | 5 ns |
| L3 缓存 | 6 MB | 30 ns |
| RAM | ||
| DIMM | 8 GB | 100 ns |
| 二级存储 | ||
| SSD | 256 GB | 50 µs |
| HDD | 2 TB | 5 ms |
| 三级存储 | ||
| NAS - 网络访问服务器 | 100 TB | 网络依赖 |
| 云专有 | 1 PB | 提供商依赖 |
表 1.1 包含了三级存储,它发生在计算机外部。那里也发生了变化,我们将在下一节中讨论。
1.2.2 网络的变化
在高性能计算环境中,我们使用网络作为添加更多存储的方式,特别是增加计算能力的方式。虽然我们希望使用单台计算机解决问题,但有时依赖计算集群是不可避免的。针对具有多台计算机的架构进行优化——无论是在云端还是在本地——将是我们追求高性能旅程的一部分。
使用许多计算机和外部存储会带来与分布式计算相关的一系列新问题:网络拓扑、跨机器共享数据以及管理跨网络的进程。有许多例子。例如,在需要高性能和低延迟的服务上使用 REST API 的成本是多少?我们如何处理远程文件系统的惩罚;能否减轻这些惩罚?
我们将尝试优化我们对网络堆栈的使用,为此,我们必须了解图 1.3 中显示的所有级别。在网络之外,我们有我们的代码和 Python 库,它们对下面的层做出选择。在网络堆栈的顶部,数据传输的一个典型选择是基于 JSON 的有效负载的 HTTPS。
虽然这对于许多应用来说是一个完全合理的选择,但在网络速度和延迟很重要的案例中,有更高效的替代方案。例如,二进制有效负载可能比 JSON 更有效。此外,HTTP 可能被直接 TCP 套接字所取代。但还有更激进的替代方案,比如替换 TCP 传输层:大多数互联网应用协议都使用 TCP,尽管有一些例外,如 DNS 和 DHCP,它们都是基于 UDP 的。TCP 协议非常可靠,但为了这种可靠性,必须付出性能的代价。有时,UDP 较小的开销将是一个更有效的替代方案,而额外的可靠性是不需要的。

图 1.3 通过网络堆栈的 API 调用。了解网络通信的可用的替代方案可以显著提高基于互联网的应用程序的速度。
在传输协议之下,我们有互联网协议(IP)和物理基础设施。在设计我们的解决方案时,物理基础设施可能很重要。例如,如果我们有一个非常可靠的本地网络,那么可以丢失数据的 UDP 将比在不可靠的网络中更有可能成为一个替代方案。
1.2.3 云
在过去,大多数数据处理实现都是为了在单个计算机或由运行工作负载的同一家组织维护的本地集群上运行。目前,所有服务器都是“虚拟”的,并由外部实体维护的基于云的基础设施越来越普遍。有时,就像所谓的无服务器计算一样,我们甚至不直接处理服务器。
云不仅仅是增加更多的计算机或网络存储。它还涉及一系列关于如何处理存储和计算资源的专有扩展,这些扩展在性能方面有影响。此外,虚拟计算机可能会对某些 CPU 优化造成干扰。例如,在裸机机器上,你可以设计出考虑缓存局部性问题的解决方案,但在虚拟机上,你无法知道你的缓存是否被同时执行的其他虚拟机抢占。我们如何在这样的环境中保持算法的高效性?此外,云计算的成本模型完全不同——时间就是金钱——因此,高效的解决方案变得更加重要。
云端中的许多计算和存储解决方案也是专有的,并且具有非常具体的 API 和行为。使用这些专有解决方案也会对性能产生后果,这是需要考虑的。因此,尽管大多数与传统集群相关的问题也适用于云端,但有时会有一些特定的问题需要单独处理。现在我们已经了解了将塑造我们应用程序的架构可能性和限制,让我们转向 Python 在高性能计算中的优缺点。
1.3 与 Python 的限制共事
Python 在现代数据处理应用中被广泛使用。与任何语言一样,它都有其优点和缺点。使用 Python 有很多很好的理由,但在这里我们更关注处理 Python 在高性能数据处理中的限制。
让我们不要美化现实:Python 在处理高性能计算方面明显准备不足。如果性能和并行性是唯一考虑因素,没有人会使用 Python。Python 有一个惊人的用于数据分析的库生态系统、优秀的文档和令人惊叹的支持社区。这就是我们使用它的原因,而不是计算性能。
有一种说法,大意是这样的:“没有慢的语言,只有慢的语言实现。”我希望你能允许我不同意这一点。要求像 Python(或者,比如说,JavaScript)这样的动态、高级语言的实现者与像 C、C++、Rust 或 Go 这样的低级语言在速度上进行竞争是不公平的。
像动态类型和垃圾回收这样的特性会在性能上付出代价。这是可以接受的:在许多情况下,程序员的时间比计算时间更有价值。但让我们不要逃避现实:更声明性和动态的语言在计算和内存上也会付出代价。这是一个平衡。
话虽如此,这并不是表现不佳的语言实现的借口。在这方面,作为你可能正在使用的旗舰 Python 实现,CPython 的表现如何?进行完整分析并不容易,但你可以做一个简单的练习:编写一个矩阵乘法函数并计时。然后,例如,用另一个 Python 实现(如 PyPy)运行它。然后,将你的代码转换为 JavaScript(由于语言也是动态的,这是一个公平的比较;不公平的比较将是 C 语言)并再次计时。
提前剧透:CPython 的表现不会太好。我们有一个自然速度较慢的语言,以及一个似乎没有将速度作为主要考虑的旗舰实现。现在,好消息是大多数这些问题都可以克服。许多人已经开发出了应用程序和库,可以缓解大多数性能问题。你仍然可以用 Python 编写代码,使其具有非常小的内存占用并表现出色。你只需在编写代码时注意 Python 的瑕疵即可。
注意:在本书的大部分内容中,当我们提到 Python 时,我们指的是 CPython 实现。所有违反此规则的例外都将明确指出。
由于 Python 在性能方面的限制,有时仅仅优化我们的 Python 代码可能是不够的。在这种情况下,我们最终将不得不用更低级的语言重写那部分代码,或者至少用注释来标记我们的代码,以便通过某些代码转换工具将其重写为更低级的语言。我们需要重写的代码部分通常非常小,所以我们绝对不是要放弃 Python。当我们进行这个最后的优化阶段时,可能超过 90%的代码仍然是 Python。这正是许多核心科学库(如 NumPy、scikit-learn 和 SciPy)实际的做法:它们最计算密集的部分通常是用 C 或 Fortran 实现的。
1.3.1 全局解释器锁
在关于 Python 性能的讨论中,其 GIL(全局解释器锁)不可避免地会被提及。GIL 究竟是什么?虽然 Python 有线程的概念,但 CPython 有一个 GIL,它只允许在某个时间点只有一个线程执行。即使在多核处理器上,你也只能得到一个在某个时间点执行的线程。
Python 的其他实现,如 Jython 和 IronPython,没有 GIL,并且可以使用现代多处理器的所有核心。但 CPython 仍然是所有主要库开发的基础实现。此外,Jython 和 IronPython 分别依赖于 JVM 和.NET。因此,鉴于其庞大的库基础,CPython 最终成为默认的 Python 实现。我们将在书中简要讨论其他实现,特别是 PyPy,但在实践中,CPython 是女王。
要了解如何绕过 GIL,记住并发和并行之间的区别是有用的。你可能记得,并发是指一定数量的任务可以在时间上重叠,尽管它们可能不是同时运行的。例如,它们可以交错。并行是指任务同时执行。因此,在 Python 中,并发是可能的,但并行性不是……或者它是吗?
没有并行性的并发仍然非常有用。最好的例子来自 JavaScript 世界和 Node.JS,它被广泛用于实现 Web 服务器的后端。在许多服务器端 Web 任务中,大部分时间实际上都花在等待 I/O 上;这是一个线程自愿放弃控制权以便其他线程可以继续计算的好时机。现代 Python 有类似的异步功能,我们将会讨论它们。
但回到主要问题:GIL 是否会对性能造成严重的惩罚?在大多数情况下,答案是令人惊讶的“不”。这有两个主要原因:
-
大多数高性能代码,那些紧密的内循环,可能需要用更低级的语言来编写,正如我们之前讨论的那样。
-
Python 为底层语言提供了释放 GIL 的机制。
这意味着当您进入用底层语言重写的代码部分时,您可以指示 Python 与您的底层实现并行继续其他 Python 线程。您只有在安全的情况下才应该释放 GIL——例如,如果您不写入可能被其他线程使用的对象。
此外,多进程(即同时运行多个进程)不受 GIL 的影响,GIL 只影响线程,因此在纯 Python 中仍有大量空间来部署并行解决方案。
因此,从理论上讲,GIL 是与性能相关的问题,但在实践中,它很少是造成无法克服问题的根源。我们将在第三章深入探讨这个主题。
1.4 解决方案的总结
本书是关于从 Python 中获得高性能的,但代码不是在真空中运行的。只有当您考虑数据、算法需求以及计算架构的更广泛视角时,您才能设计出高效的代码。虽然在一本书中不可能详细讨论每个架构和算法细节,但我的目标是帮助您理解 CPU 设计、GPU、存储替代方案、网络协议和云架构以及其他系统考虑(图 1.4)的影响,以便您可以为提高 Python 代码的性能做出明智的决定。本书应该使您能够评估您的计算架构的优点和缺点,无论它是单台计算机、带 GPU 的计算机、集群还是云环境,并实施必要的更改以充分利用它。

图 1.4 在选择高性能编码解决方案时必须考虑底层硬件架构。
本书的目标是向您介绍一系列解决方案,并展示每个解决方案的最佳应用方式和地点,以便您可以根据自己的资源、目标和问题选择并实施最有效的解决方案。我们将花费大量时间通过实例来让您亲自看到这些方法的效果,无论是积极的还是消极的。没有强制要求应用所有方法,也没有规定应用它们的顺序。每种方法在性能和效率方面都有其更大的或较小的收益,以及其权衡。如果您了解您系统中的资源和可用于改进该系统方面的策略,您就可以选择在哪里花费时间和资源。为了帮助您理解这些方法,表 1.2 总结了本书中介绍的技术及其针对的系统开发过程组件或领域。
表 1.2 书中每章的目的
| 领域 | 应用 | 章节 |
|---|---|---|
| 充分利用您的 Python 解释器 | Python 解释器 | 2 从内置功能中提取最大性能 |
| 理解 Python 的内部功能以从您的计算机中提取最大计算能力 | Python 解释器 | 3 并发、并行和异步处理 |
| 从数据科学的基本库中提取最大性能 | Python 库 | 4 高性能 NumPy |
| 探索当 Python 不够用时低级语言的性能 | Python 库 | 5 使用 Cython 重新实现关键代码 |
| 理解硬件对计算性能的影响 | 硬件 | 6 内存层次结构、存储和网络 |
| 从表格数据中提取最大性能 | Python 库 | 7 高性能 pandas 和 Apache Arrow |
| 通过使用现代 Python 持久化库提高存储效率 | Python 库 | 8 存储大数据 |
| 从 Python 理解 GPU 计算的重要性及其使用 | 硬件 | 9 使用 GPU 计算进行数据分析 |
| 处理需要多台计算机处理的应用程序 | Python 库和硬件 | 10 使用 Dask 分析大数据 |
那张表格中有很多内容,所以让我强调一下主要关注领域的实际应用。阅读这本书后,您将能够查看原生 Python 代码并理解内置数据结构和算法的性能影响。您将能够检测并替换效率低下的结构,以更合适的解决方案:例如,在重复搜索固定列表时用集合替换列表,或者为了速度使用非对象数组而不是对象列表。您还将能够对现有性能不佳的算法进行以下操作:(1)分析代码以找到导致性能问题的部分,以及(2)确定优化这些代码片段的最佳方式。
正如所述,本书旨在通过改进我们使用它们的方式,解决广泛使用的 Python 数据处理和分析库生态系统(如 pandas 和 NumPy)。在计算方面,这是一大块材料,所以我们不会讨论非常高级的库。例如,我们不会讨论优化 TensorFlow 的使用,但我们会讨论使底层算法更有效的技术。
关于数据存储和转换,您将能够查看数据源并理解其在高效处理和存储方面的不足。然后您将能够以这种方式转换数据,即所有所需信息仍然保持不变,但数据的访问模式将大大提高效率。最后,您还将了解 Dask,这是一个基于 Python 的框架,允许您开发可以从小型单机扩展到非常大的计算机集群或云计算解决方案的并行解决方案。
这与其说是一本食谱书,不如说是一本关于优化思考方式以及性能提升研究领域的介绍。因此,讨论的方法应该大部分能够适应硬件、软件、网络、系统甚至数据本身的变更。尽管并非每种技术都会在每个情况下都盈利,甚至可用,但从头到尾阅读这本书是理解您选项、开阔思路并可能提出一些自己解决方案的最可靠方式。一旦您接触到了各种可能性,您就可以将这本书作为参考,挑选您想要实施的技术。
注意软件设置:在继续之前,请务必查看附录 A,了解设置环境选项的描述,这样您就可以在每个练习中运行代码示例。代码列表本身位于github.com/tiagoantao/python-performance。
摘要
-
是的,这个陈词滥调是正确的:数据量很大,如果我们想从数据中提取最大价值,就必须提高处理数据的效率。
-
算法复杂性的增加给计算成本带来了额外的压力,我们将不得不找到减轻计算影响的方法。
-
计算架构存在很大的异质性:网络现在也包括基于云的方法。在我们的计算机内部,现在有强大的 GPU,其计算范式与 CPU 大不相同。我们需要能够利用这些资源。
-
Python 是一种用于数据分析的出色语言,周围环绕着完整的数据处理库和框架生态系统。但它在性能方面也存在严重问题。我们需要能够绕过这些问题,以便使用复杂的算法处理大量数据。
-
虽然我们将要处理的一些问题可能很困难,但它们大多是可解决的。本书的目标是向您介绍大量的替代解决方案,并教会您如何以及在哪里最好地应用每个解决方案,以便您可以选择并实施您遇到任何问题的最有效解决方案。
¹ 分片是指将数据分割,使其部分存储在不同的服务器上。
2 从内置功能中提取最大性能
本章涵盖
-
分析代码以找到速度和内存瓶颈
-
更高效地使用现有的 Python 数据结构
-
理解 Python 分配典型数据结构的内存成本
-
使用惰性编程技术处理大量数据
有许多工具和库可以帮助我们编写更高效的 Python。但在我们深入研究所有外部选项以提高性能之前,让我们首先更仔细地看看我们如何编写更高效的纯 Python 代码,无论是在计算性能还是 IO 性能上。事实上,尽管并非所有,许多 Python 性能问题都可以通过更加关注 Python 的限制和能力来解决。
为了展示 Python 自身用于提高性能的工具,让我们在一个假设的、尽管是现实的问题上使用它们。假设你是一名数据工程师,负责准备全球气候数据分析。数据将基于美国国家海洋和大气管理局(NOAA)的集成表面数据库(mng.bz/ydge)。你面临紧迫的截止日期,并且只能使用大部分标准的 Python。此外,由于预算限制,购买更多处理能力是不可能的。数据将在一个月后开始到达,你计划在数据到达之前利用这段时间来提高代码性能。那么,你的任务就是找到需要优化的地方,并提高它们的性能。
你首先想做的事情是对将要处理数据的现有代码进行性能分析。你知道你现有的代码很慢,但在尝试优化它之前,你需要找到瓶颈的实证证据。性能分析很重要,因为它允许你以严格和系统的方式搜索代码中的瓶颈。最常见的替代方案,即猜测,在这里特别无效,因为许多减速点可能非常不直观。
优化纯 Python 代码是低垂的果实,但也是大多数问题倾向于驻留的地方,因此优化通常非常有优势。在本章中,我们将看到纯 Python 提供了哪些现成功能来帮助我们开发更高效的代码。我们将从使用几个性能分析工具对代码进行性能分析开始,以检测问题区域。然后我们将关注 Python 的基本数据结构:列表、集合和字典。我们的目标将是提高这些数据结构的效率,并以最佳方式分配内存以实现最佳性能。最后,我们将看到现代 Python 惰性编程技术如何帮助我们提高数据管道的性能。
本章将仅讨论在不使用外部库的情况下优化 Python,但我们仍将使用一些外部工具来帮助我们优化性能和访问数据。我们将使用 Snakeviz 来可视化 Python 性能分析的结果,以及 line_profiler 逐行分析代码。最后,我们将使用 requests 库从互联网下载数据。
如果你使用 Docker,默认镜像已经包含了你需要的一切。如果你遵循附录 A 中 Anaconda Python 的说明,你就可以开始了。现在,让我们通过从气象站下载数据并研究每个站点的温度来开始我们的性能分析过程。
2.1 使用 IO 和计算工作负载进行性能分析
我们的首要目标将是下载气象站的数据,并获取该站某一年份的最低温度。NOAA 网站上的数据有 CSV 文件,每年一个,每个站点一个。例如,文件www.ncei.noaa.gov/data/global-hourly/access/2021/01494099999.csv包含了 2021 年站点 01494099999 的所有条目。这包括其他条目,如温度和压力,可能每天记录几次。
让我们编写一个脚本来下载一组站点在多年间的数据。在下载所需数据后,我们将得到每个站点的最低温度。
2.1.1 下载数据和计算最低温度
我们的脚本将有一个简单的命令行界面,其中我们传递一个站点列表和一个感兴趣的年份间隔。以下是解析输入的代码(代码可在02-python/sec1-io-cpu/load.py中找到):
import collections
import csv
import datetime
import sys
import requests
stations = sys.argv[1].split(",")
years = [int(year) for year in sys.argv[2].split("-")]
start_year = years[0]
end_year = years[1]
为了简化编码部分,我们将使用 requests 库来获取文件。以下是下载服务器数据的代码:
TEMPLATE_URL = "https://www.ncei.noaa.gov/data/global-hourly/access/{year}/
➥ {station}.csv"
TEMPLATE_FILE = "station_{station}_{year}.csv"
def download_data(station, year):
my_url = TEMPLATE_URL.format(station=station, year=year)
req = requests.get(my_url) ①
if req.status_code != 200:
return # not found
w = open(TEMPLATE_FILE.format(station=station, year=year), "wt")
w.write(req.text)
w.close()
def download_all_data(stations, start_year, end_year):
for station in stations:
for year in range(start_year, end_year + 1):
download_data(station, year)
① Requests 库使得访问网络内容变得简单。
此代码将把所有请求的站点在所有年份下载的每个文件写入磁盘。现在,让我们将所有温度放入一个文件中:
def get_file_temperatures(file_name):
with open(file_name, "rt") as f:
reader = csv.reader(f)
header = next(reader)
for row in reader:
station = row[header.index("STATION")]
# date = datetime.datetime.fromisoformat(row[header.index('DATE')])
tmp = row[header.index("TMP")]
temperature, status = tmp.split(",") ①
if status != "1": ②
continue
temperature = int(temperature) / 10
yield temperature
① 温度字段的格式包括一个表示数据质量状态的字段。
② 我们忽略数据不可用的条目。
现在,让我们获取所有温度和每个站点的最低温度:
def get_all_temperatures(stations, start_year, end_year):
temperatures = collections.defaultdict(list)
for station in stations:
for year in range(start_year, end_year + 1):
for temperature in get_file_temperatures(
➥ TEMPLATE_FILE.format(station=station, year=year)):
temperatures[station].append(temperature)
return temperatures
def get_min_temperatures(all_temperatures):
return {station: min(temperatures) for station, temperatures in
➥ all_temperatures.items()}
现在我们可以将所有东西结合起来:下载数据,获取所有温度,计算每个站点的最低温度,并打印结果:
download_all_data(stations, start_year, end_year)
all_temperatures = get_all_temperatures(stations, start_year, end_year)
min_temperatures = get_min_temperatures(all_temperatures)
print(min_temperatures)
例如,为了加载 2021 年站点 01044099999 和 02293099999 的数据,我们执行以下操作
python load.py 01044099999,02293099999 2021-2021
输出结果为
{'01044099999': -10.0, '02293099999': -27.6}
现在,真正的乐趣开始了。我们的目标是继续从许多站点下载多年的大量数据。为了处理这些数据量,我们希望使代码尽可能高效。使代码更高效的第一步是以有组织和彻底的方式进行性能分析,以找到减缓其速度的瓶颈。为此,我们将使用 Python 内置的性能分析工具。
2.1.2 Python 内置的性能分析模块
由于我们希望确保我们的代码尽可能高效,我们首先需要做的是找到该代码中现有的瓶颈。我们的第一步将是分析代码以检查每个函数的时间消耗。为此,我们通过 Python 的 cProfile 模块运行代码。此模块是 Python 内置的,允许我们从代码中获取分析信息。确保您不要使用 profile 模块,因为它慢得多;它仅在你自己开发分析工具时有用。
我们可以通过以下方式运行分析器:
python -m cProfile -s cumulative load.py 01044099999,02293099999 2021-2021 > profile.txt
记住,使用 -m 标志运行 Python 将执行模块,因此我们正在运行 cProfile 模块。这是 Python 推荐的模块,用于收集分析信息。我们要求按累积时间排序的统计信息。使用该模块的最简单方法是将我们的脚本传递给分析器,如下所示:
375402 function calls (370670 primitive calls) in 3.061 seconds ①
Ordered by: cumulative time
ncalls tottime percall cumtime percall filename:lineno(function)
158/1 0.000 0.000 3.061 3.061 {built-in method builtins.exec}
1 0.000 0.000 3.061 3.061 load.py:1(<module>)
1 0.001 0.001 2.768 2.768 load.py:27(download_all_data)
2 0.001 0.000 2.766 1.383 load.py:17(download_data)
2 0.000 0.000 2.714 1.357 api.py:64(get)
2 0.000 0.000 2.714 1.357 api.py:16(request)
2 0.000 0.000 2.710 1.355 sessions.py:470(request)
2 0.000 0.000 2.704 1.352 sessions.py:626(send)
3015 0.017 0.000 1.857 0.001 socket.py:690(readinto)
3015 0.017 0.000 1.829 0.001 ssl.py:1230(recv_into)
[...]
1 0.000 0.000 0.000 0.000 load.py:58(get_min_temperatures) ②
① 基本摘要信息可以在第一行找到:函数调用次数和总运行时间。
② 我们代码的计算成本(计算在 get_min_temperatures 中完成)可以忽略不计。
输出按累积时间排序,即在一个特定函数内部花费的所有时间。另一个输出是每个函数的调用次数。例如,只有一个对 download_all_data 的调用(负责下载所有数据),但其累积时间几乎等于脚本的总体时间。你会注意到有两个名为 percall 的列。第一个列表示不包括所有子调用花费的时间在函数上的时间。第二个列包括子调用花费的时间。在 download_all_data 的例子中,很明显,大部分时间被一些子函数消耗了。
在许多情况下,当你有一些像这里这样的密集型 I/O 时,有很强的可能性 I/O 在所需时间方面占主导地位。在我们的例子中,我们既有网络 I/O(从 NOAA 获取数据)也有磁盘 I/O(将其写入磁盘)。网络成本可能差异很大,甚至在运行之间,因为它们依赖于沿途的许多连接点。由于网络成本通常是最大的时间消耗,让我们尝试减轻这些成本。
2.1.3 使用本地缓存以减少网络使用
为了减少网络通信,当我们在第一次下载文件时,让我们保存一个副本以备将来使用。我们将建立一个本地数据缓存。我们将使用与之前相同的代码,除了函数 download_all_data(代码可以在 02-python/sec1-io-cpu/load_cache.py 中找到):
import os
def download_all_data(stations, start_year, end_year):
for station in stations:
for year in range(start_year, end_year + 1):
if not os.path.exists(TEMPLATE_FILE.format(
➥ station=station, year=year)): ①
download_data(station, year)
① 我们检查文件是否已存在,如果不存在则下载它。
代码的第一次运行将与之前的解决方案花费相同的时间,但第二次运行将不需要任何网络访问。例如,给定与之前相同的运行,它从 2.8 秒减少到 0.26 秒——超过一个数量级。记住,由于网络访问的高变异性,下载文件的时间在你的情况下可能会有很大的变化。这是考虑缓存网络数据的另一个原因:拥有更可预测的执行时间:
python -m cProfile -s cumulative load_cache.py 01044099999,02293099999
➥ 2021-2021 > profile_cache.txt
现在,消耗时间的地方不同了:
299938 function calls (295246 primitive calls) in 0.260 seconds
Ordered by: cumulative time
ncalls tottime percall cumtime percall filename:lineno(function)
156/1 0.000 0.000 0.260 0.260 {built-in method builtins.exec}
1 0.000 0.000 0.260 0.260 load_cache.py:1(<module>)
1 0.008 0.008 0.166 0.166 load_cache.py:51(
➥ get_all_temperatures)
33650 0.137 0.000 0.156 0.000 load_cache.py:36(
➥ get_file_temperatures)
[...]
1 0.000 0.000 0.001 0.001 load_cache.py:60(
➥ get_min_temperatures)
虽然运行时间减少了一个数量级,但 IO 仍然是首要的。现在,不是网络而是磁盘访问。这主要是由于计算实际上很低。
警告:缓存,正如这个例子所示,可以以数量级的方式加快代码速度。然而,缓存管理可能会出现问题,并且是 bug 的常见来源。在我们的例子中,文件随时间不会改变,但有许多缓存用例,源可能会改变。在这种情况下,缓存管理代码需要识别这个问题。我们将在本书的其他部分重新审视缓存。
现在,我们将考虑一个 CPU 是限制因素的情况。
2.2 对代码进行性能分析以检测性能瓶颈
在这里,我们查看 CPU 是过程中耗时最多的资源的代码。我们将取 NOAA 数据库中的所有站点并计算它们之间的距离,这是一个复杂度为n2的问题。
在仓库中,你会找到一个文件(02-python/sec2-cpu/locations.csv),其中包含所有站点的地理坐标(代码可以在02-python/sec2-cpu/distance_cache.py中找到):
import csv
import math
def get_locations():
with open("locations.csv", "rt") as f:
reader = csv.reader(f)
header = next(reader)
for row in reader:
station = row[header.index("STATION")]
lat = float(row[header.index("LATITUDE")])
lon = float(row[header.index("LONGITUDE")])
yield station, (lat, lon)
def get_distance(p1, p2): ①
lat1, lon1 = p1
lat2, lon2 = p2
lat_dist = math.radians(lat2 - lat1)
lon_dist = math.radians(lon2 - lon1)
a = (
math.sin(lat_dist / 2) * math.sin(lat_dist / 2) +
math.cos(math.radians(lat1)) * math.cos(math.radians(lat2)) *
math.sin(lon_dist / 2) * math.sin(lon_dist / 2)
)
c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
earth_radius = 6371
dist = earth_radius * c
return dist
def get_distances(stations, locations):
distances = {}
for first_i in range(len(stations) - 1):
first_station = stations[first_i] ②
first_location = locations[first_station]
for second_i in range(first_i, len(stations)):
second_station = stations[second_i] ②
second_location = locations[second_station]
distances[(first_station, second_station)] = get_distance(
first_location, second_location)
return distances
locations = {station: (lat, lon) for station, (lat, lon) in get_locations()}
stations = sorted(locations.keys())
distances = get_distances(stations, locations)
① 这是计算两个站点之间距离的代码。
② 由于我们正在比较所有站点之间的相互关系,其复杂度为 n2。
之前的代码将运行很长时间。它也消耗了很多内存。如果你有内存问题,限制你正在处理的站点的数量。现在,让我们使用 Python 的性能分析基础设施来查看大部分时间花在了哪里。
2.2.1 可视化性能分析信息
再次,我们使用 Python 的性能分析基础设施来找到延迟执行的部分代码。但为了更好地检查跟踪,我们将使用一个外部可视化工具,SnakeViz (jiffyclub.github.io/snakeviz/)。
我们首先保存一个性能分析跟踪:
python -m cProfile -o distance_cache.prof distance_cache.py
-o参数指定了性能信息将被存储的文件。之后,我们像往常一样调用我们的代码。
注意 Python 提供了pstats模块来分析写入磁盘的跟踪。你可以执行python -m pstats distance_cache.prof,这将启动一个命令行界面来分析我们脚本的成本。你可以在 Python 文档或第五章的性能分析部分找到更多关于此模块的信息。
为了分析这些信息,我们将使用基于网络的可视化工具 SnakeViz。你只需要执行 snakeviz distance_cache.prof。这将启动一个交互式浏览器窗口(图 2.1 展示了截图)。
熟悉 SnakeViz 界面
这将是玩 SnakeViz 界面一段时间的好时机。例如,你可以将样式从 Icicle 改为 Sunburst(虽然可能更可爱,但信息较少,因为文件名消失了)。重新排列底部的表格。检查深度和截止值。别忘了点击一些彩色块,最后通过点击调用栈并选择 0 条记录返回主视图。

图 2.1 使用 SnakeViz 检查脚本分析信息
大部分时间都是在 get_distance 函数内部度过的,但具体在哪里呢?我们可以看到一些数学函数的成本,但 Python 的分析并不允许我们以细粒度查看每个函数内部发生的事情。我们只能得到每个三角函数的汇总视图。是的,我们在 math.sin 上花费了一些时间,但鉴于我们在几行代码中使用它,我们到底在哪些地方付出了高昂的代价?为此,我们需要请线分析模块帮忙。
2.2.2 行分析
内置分析,就像我们之前使用的那样,允许我们找到导致巨大延迟的代码片段。但它的功能是有限的。我们将在下面讨论这些限制,并介绍行分析作为找到代码中进一步性能瓶颈的方法。
要了解 get_distance 每一行的成本,我们将使用 line_profiler 包,该包可在 github.com/pyutils/line_profiler 找到。使用行分析器相当简单:你只需要在 get_distance 上添加一个注释:
@profile
def get_distance(p1, p2):
你可能已经注意到我们没有从任何地方导入 profile 注释。这是因为我们将使用来自 line_profiler 包的便利脚本 kernprof,它会处理这个问题。那么,让我们在我们的代码中运行行分析器:
kernprof -l lprofile_distance_cache.py
准备好线分析器所需的仪器可能会显著减慢代码,降低几个数量级。让它运行一分钟或更长时间,然后中断它(如果你让它完成,kernprof 可能会运行数小时)。如果你中断它,你仍然会有一个跟踪记录。分析器结束后,你可以使用以下命令查看结果:
python -m line_profiler lprofile_distance_cache.py.lprof
如果您查看 2.1 列表中显示的输出,您会看到其中有很多耗时较长的调用。因此,我们可能希望优化这段代码。在这个阶段,因为我们只讨论性能分析,所以我们在这里停止,但之后,我们需要优化这些行(我们将在本章后面进行)。如果您对优化这段代码感兴趣,可以查看第六章关于 Cython 的内容或附录 B 关于 Numba 的内容,因为它们提供了最直接提高速度的方法。
列表 2.1 line_profiler包为我们代码的输出
Timer unit: 1e-06 s
Total time: 619.401 s ①
File: lprofile_distance_cache.py
Function: get_distance at line 16
Line # Hits Time Per Hit % Time Line Contents ②
==============================================================
16 @profile
17 def get_distance(p1, p2):
18 84753141 36675975.0 0.4 5.9 lat1, lon1 = p1
19 84753141 35140326.0 0.4 5.7 lat2, lon2 = p2
20
21 84753141 39451843.0 0.5 6.4 lat_dist = math.
➥ radians(lat2 -lat1)
22 84753141 38480853.0 0.5 6.2 lon_dist = math.
➥ adians(lon2 - lon1)
23 84753141 28281163.0 0.3 4.6 a = (
24 169506282 84658529.0 0.5 13.7 math.sin(lat_dist / 2)
➥ * math.sin(
➥ lat_dist / 2) +
25 254259423 118542280.0 0.5 19.1 math.cos(math.radians(
➥ lat1)) * math.cos(
➥ math.radians(
➥ lat2)) *
26 169506282 81240276.0 0.5 13.1 math.sin(lon_dist / 2)
➥ * math.sin(
➥ lon_dist / 2)
27 )
28 84753141 65457056.0 0.8 10.6 c = 2 * math.atan2(
➥ math.sqrt(a),
➥ math.sqrt(1 - a))
29 84753141 29816074.0 0.4 4.8 earth_radius = 6371
30 84753141 33769542.0 0.4 5.5 dist = earth_radius * c
31
32 84753141 27886650.0 0.3 4.5 return dist
① 我们代码的总运行时间
② 我们为每个正在性能分析的行获得的信息。对于每一行,我们得到该行被调用的次数、该行上花费的总时间、每次调用的时间和该行上花费的时间百分比。
希望您会发现line_profiler的输出比内置性能分析器的输出更直观。
2.2.3 要点:代码性能分析
正如我们所见,整体内置性能分析是一个很好的起点;它也比行性能分析快得多。但是,行性能分析提供了更多的信息,主要是因为内置的 Python 性能分析没有在函数内部提供细分。相反,Python 的性能分析只提供每个函数的累积值,以及显示在子调用上花费的时间。在特定情况下,可以知道子调用是否属于另一个函数,但通常情况下,这是不可能的。性能分析的整体策略需要考虑所有这些因素。
我们在这里使用的策略是一种通常合理的做法:首先,尝试使用内置的 Python 性能分析模块cProfile,因为它速度快,并提供了一些高级信息。如果这还不够,使用行性能分析,它提供了更多的信息,但速度较慢。记住,我们在这里主要关注定位瓶颈;后面的章节将提供优化代码的方法。有时仅仅改变现有解决方案的一部分是不够的,需要进行整体重构;我们也将适时讨论这一点。
其他性能分析工具
如果您正在性能分析代码,许多其他实用工具都可能很有用,但如果没有提到这些工具之一,性能分析部分将不会完整,那就是timeit模块。这可能是新来者最常用的性能分析方法,您可以在互联网上找到无数使用timeit模块的示例。使用timeit模块的最简单方法是使用 IPython 或 Jupyter Notebook,因为这些系统使timeit非常流畅。只需将%timeit魔法命令添加到您想要性能分析的代码中即可,例如,在 iPython 中:
In [1]: %timeit list(range(1000000))
27.4 ms ± 72.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [2]: %timeit range(1000000)
189 ns ± 22.6 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops
➥ each)
这将给出你正在分析的功能的多次运行时间。魔法将决定运行多少次以及报告哪些基本统计信息。在上一个代码片段中,你看到了 range(1000000) 和 list(range(1000000)) 之间的差异。在这个特定案例中,timeit 显示 range 的惰性版本比急切版本快两个数量级。
你可以在 timeit 模块的文档中找到更多详细信息,但对于大多数用例,IPython 的 %timeit 魔法就足够访问其功能了。我们鼓励你使用 IPython 和它的魔法,但在本书的大部分内容中,我们将使用标准解释器。你可以在这里了解更多关于 %timeit 魔法的详情:ipython.readthedocs.io/en/stable/interactive/magics.html。
现在你已经熟悉了工具集和性能分析的方法,让我们将注意力转向不同的主题:优化 Python 数据结构的用法。
2.3 优化基本数据结构以提高速度:列表、集合和字典
接下来,我们将尝试找出 Python 基本数据结构的低效使用,并更高效地重写代码片段。为了演示这个过程,我们将继续使用来自 NOAA 的温度数据。但在这里,我们的挑战是确定在指定时间段内某个站点是否发生了特定的温度。
我们将重用本章第一部分中的代码来读取数据(代码可以在 02-python/sec3-basic-ds/exists_temperature.py 中找到)。为了这个示例,我们感兴趣的是从 2005 年到 2021 年该站点的 01044099999 号站点的数据:
stations = ['01044099999']
start_year = 2005
end_year = 2021
download_all_data(stations, start_year, end_year)
all_temperatures = get_all_temperatures(stations, start_year, end_year)
first_all_temperatures = all_temperatures[stations[0]]
first_all_temperatures 包含了该站点的温度列表。我们可以使用 print(len(first_all_temperatures), max(first_all_temperatures), min(first_all_temperatures)) 来获取一些基本统计数据。我们共有 141,082 条记录,最高温度为 27.0 摄氏度,最低温度为 -16.0 摄氏度。
2.3.1 列表搜索的性能
检查温度是否在列表中是 temperature in first_all_temperatures 的问题。让我们大致估算一下检查 -10.7 是否在列表中需要多少时间:
%timeit (-10.7 in first_all_temperatures)
我的计算机上的输出如下:
313 µs ± 6.39 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
现在我们尝试使用一个我们知道不在列表上的值来执行这个查询:
%timeit (-100 in first_all_temperatures))
结果是:
2.87 ms ± 20.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
这大约比我们搜索 -10.7 的速度慢一个数量级。
为什么第二次搜索的性能如此低?因为为了完成这个搜索,in操作符从列表的开始进行顺序扫描。这种方法意味着,在最坏的情况下,整个列表都将被搜索,这正是我们要找的元素(-100)不在列表上的情况。对于小列表,从顶部开始搜索并直接通过,只会增加微不足道的时间。但随着列表的增长,以及你可能需要在那些不断增长的列表上进行的搜索数量,时间会显著增加。
在这个阶段,我们还没有可以比较的数字,但可以安全地假设在毫秒甚至微秒之间的范围并不令人鼓舞。这应该在数量级上少得多的时间内完成。
2.3.2 使用集合进行搜索
让我们看看是否可以通过将数据结构从列表切换到集合来做得更好。让我们将有序列表转换为集合,并尝试进行搜索
set_first_all_temperatures = set(first_all_temperatures)
%timeit (-10.7 in set_first_all_temperatures)
%timeit (-100 in set_first_all_temperatures)
时间成本如下
62.1 ns ± 3.27 ns per loop (mean ± std. dev. of 7 runs,
➥ 10,000,000 loops each)
26.6 ns ± 0.115 ns per loop (mean ± std. dev. of 7 runs,
➥ 10,000,000 loops each)
这比上一节中的解决方案快几个数量级!但为什么会有这样的改进?主要有两个原因:一个是与集合大小相关,另一个是与复杂度相关。复杂度部分将在下一小节中讨论。这里我们将看看集合大小的作用。
关于大小,记住原始列表有 141,082 个元素。但使用集合后,所有重复的值都会合并成一个单一值——原始列表中有很多重复的元素。集合的大小减少到print(len(set_ first_all_temperatures)),即 400 个元素(减少了 350 倍)。难怪搜索速度如此之快,因为结构的大小要小得多。
我们应该意识到列表中可能存在的重复元素,并知道使用集合有潜在的优势,这样搜索就可以在更小的数据结构上发生。但 Python 中列表和集合的实现之间也存在更深刻的差异。
2.3.3 Python 中列表、集合和字典的复杂度
之前示例中性能的提升主要归因于当我们从列表切换到集合时,数据结构的大小实际上减少了。这引发了一个问题:如果没有重复,列表和集合的大小相同会怎样?让我们找出答案。我们可以用范围来模拟这种情况,这将指定所有元素都将不同:
a_list_range = list(range(100000))
a_set_range = set(a_list_range)
%timeit 50000 in a_list_range
%timeit 50000 in a_set_range
%timeit 500000 in a_list_range
%timeit 500000 in a_set_range
因此,我们现在有一个从 0 到 99,999 的范围,它既实现了列表也实现了集合。我们在这两种数据结构中搜索了 50,000 和 500,000。以下是时间:
455 µs ± 2.68 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
40.1 ns ± 0.115 ns per loop (mean ± std. dev. of 7 runs,
➥ 10,000,000 loops each)
936 µs ± 9.37 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
28.1 ns ± 0.107 ns per loop (mean ± std. dev. of 7 runs,
➥ 10,000,000 loops each)
集合实现仍然有更好的性能。这是因为 Python(更准确地说,CPython)中集合是通过哈希实现的。因此,查找一个元素的成本就是搜索哈希的成本。哈希函数有很多种类,需要处理许多设计问题。但是,当我们比较列表和集合时,我们可以一般假设集合查找主要是常数时间,并且对于大小为 10 或 1000 万的集合,性能都会很好。这实际上并不完全正确,但为了直观理解为什么集合查找比列表查找有优势,这是合理的。
记住,集合通常像字典一样实现,没有值,这意味着当你在一个字典键上搜索时,你得到与在集合上搜索相同的性能。然而,集合和字典并不是他们看起来那么万能的银弹。例如,如果你想搜索一个区间,一个有序列表会大大更有效率。在一个有序列表中,你可以找到最小元素,然后从这个点开始遍历,直到你找到区间之上的第一个元素然后停止。在集合或字典中,你将不得不对区间中的每个元素进行查找。所以如果你知道你要搜索的值,那么字典可以非常快。但如果你在一个区间中查找,那么它突然就不再是合理的选择;一个带有二分查找算法的有序列表会表现得更好。
由于列表在 Python 中如此普遍且易于使用,存在许多更适合的数据结构。但强调列表是一个具有许多良好用例的基本数据结构是值得的。关键是注意你的选择,而不是摒弃列表。
提示:在使用in搜索大型列表时要小心。如果你浏览 Python 代码,使用in在列表中查找元素的模式(列表对象的index方法在实践中是同一件事)相当常见。对于小列表来说,时间惩罚相当小,完全合理,但对于大列表来说,可能会很严重。
从一个非常实际的软件工程角度来看,使用in与列表一起使用可能会从开发中的未注意问题变成生产中的重大问题。常见的模式是开发者使用小的数据示例进行测试,因为在大多数单元测试中,提供大数据通常不切实际。然而,实际数据可能非常大,一旦引入,它可能会使生产系统停止运行。
一个更系统的解决方案是在不同的测试阶段——从单元测试到端到端测试——用非常大的数据集测试代码——也许不是总是,但至少偶尔。这不应该被理解为反对使用列表中的in的论点。只是要注意由于数据大小导致的开发期间和生产期间性能差异。
顺便说一下,对于大多数搜索操作,比列表、集合或字典有实质上更好的数据结构家族:树。但在本章中,我们正在评估 Python 的内置数据结构,这些数据结构不包括树。
选择合适的算法和数据结构是许多书籍的主题,也是计算机科学学位课程中最难的部分之一。这里的目的不是对这一主题进行详尽的讨论,而是让你了解 Python 中最常见的替代方案。如果你认为现有的 Python 数据结构不能满足你的需求,你可能需要考虑其他类型的数据结构。本书的重点是 Python,但其他资源将涵盖 Python 之外的数据结构;例如,Michael T. Goodrich、Roberto Tamassia 和 Michael H. Goldwasser 所著的《Python 中的数据结构和算法》(Wiley 2013 年出版),提供了良好的介绍。
另一个有用的资源是 Python 自己的关于时间复杂性的数据(wiki.python.org/moin/TimeComplexity)。在这里,你可以查找许多 Python 数据结构上广泛操作的复杂度。
到目前为止,在本章中,我们一直关注时间性能。但处理大数据集的性能问题时,这并不是唯一因素。让我们转向另一个重要因素:节省内存。
2.4 寻找过度的内存分配
内存消耗对于性能至关重要,而不仅仅是可能耗尽内存。有效的内存分配可以允许在同一台机器上并行运行更多的进程。更重要的是,合理的内存使用可能允许内存中的算法。
让我们回到我们熟悉的场景,NOAA 数据库,看看我们如何减少数据的磁盘消耗。为此,我们将从研究数据文件的内容开始。我们的目标是加载一些这些文件并对字符分布进行一些统计分析。
def download_all_data(stations, start_year, end_year):
for station in stations:
for year in range(start_year, end_year + 1):
if not os.path.exists(TEMPLATE_FILE.format(
➥ station=station, year=year)):
download_data(station, year)
def get_all_files(stations, start_year, end_year):
all_files = collections.defaultdict(list)
for station in stations:
for year in range(start_year, end_year + 1):
f = open(TEMPLATE_FILE.format(station=station, year=year), 'rb')
content = list(f.read())
all_files[station].append(content)
f.close()
return all_files
stations = ['01044099999']
start_year = 2005
end_year = 2021
download_all_data(stations, start_year, end_year)
all_files = get_all_files(stations, start_year, end_year)
all_files现在有一个字典,其中每个条目包含与一个站点相关的所有文件的内容。让我们研究一下这个字典的内存使用情况。
2.4.1 在 Python 内存估计的雷区中导航
Python 在sys模块中提供了一个名为getsizeof的函数,该函数据说返回对象占用的内存。我们可以使用以下代码了解我们的字典占用的内存:
print(sys.getsizeof(all_files))
print(sys.getsizeof(all_files.values()))
print(sys.getsizeof(list(all_files.values())))
结果是:
240
40
64
getsizeof可能不会返回你所期望的结果。磁盘上的文件大小在兆字节范围内,所以低于 1 KB 的估计听起来相当可疑。实际上,getsizeof返回的是容器的大小(第一个是字典,第二个是迭代器,第三个是列表)而不包括内容。因此,我们必须考虑占用内存的两个方面:容器的内容和容器本身。
注意,语言中getsizeof的实现没有问题;只是出乎意料的使用者通常期望的是不同的东西——即它会返回对象中引用的所有内容的内存占用。如果你阅读官方文档,你甚至可以找到递归实现的说明,这种实现可以解决大多数问题。对我们来说,getsizeof的复杂性主要是一个深入讨论 CPython 内存分配的起点。
让我们获取一些关于我们站点数据的基本信息:
station_content = all_files[stations[0]]
print(len(station_content))
print(sys.getsizeof(station_content))
输出如下:
17
248
我们的字典只有一个条目,对应一个单独的站点。它包含一个包含 17 个条目的列表。列表本身占用 248 字节,但请记住,这并不包括内容。现在让我们检查第一个条目的大小:
print(len(station_content[0]))
print(sys.getsizeof(station_content[0]))
print(type(station_content[0]))
长度为 1,303,981,对应文件的大小。我们有getsizeof为 10,431,904。这大约是底层文件大小的八倍。为什么是八倍?因为每个条目都是一个指向字符的指针,而指针的大小是 8 字节。在这个阶段,这看起来相当糟糕,因为我们有一个大的数据结构,但我们还没有计算实际的数据。让我们看看单个字符:
print(sys.getsizeof(station_content[0]))
print(type(station_content[0]))
这在大小上非常巨大。输出是 28,类型为int。所以每个字符,本应只占用 1 个字节,现在却由 28 个字节表示。因此,列表的大小为 10,431,904,加上 28 * 1,303,981(36,511,468),总计为 46,943,372。这是原始文件大小的 36 倍!幸运的是,情况并没有看起来那么糟糕,但我们能做得更好。我们将从看到 Python(或者说 CPython)在内存分配方面相当智能开始。
CPython 可以以更复杂的方式分配对象,结果我们的内存分配计算方法相当天真。让我们只计算内部内容的大小,但不是通过遍历矩阵中的所有整数,而是确保我们不会重复计数。在 Python 中,如果一个对象被多次使用,它将获得相同的id。所以如果我们看到相同的id多次,我们应该只计算一个内存分配:
single_file_data = station_content[0]
all_ids = set()
for entry in single_file_data:
all_ids.add(id(entry)) ①
print(len(all_ids))
①id函数允许我们获取对象的唯一 ID。
之前的代码获取了我们所有数字的唯一标识符。在 CPython 中,这实际上是内存位置。CPython 足够智能,能够看到相同的字符串内容被反复使用——记住,每个 ASCII 字符都由一个介于 0 到 127 之间的整数表示——因此,之前代码的输出是 46。
因此,简单的内存分配将会很糟糕,但 Python(更准确地说,CPython)要聪明得多。这种解决方案的内存成本仅仅是列表基础设施(10,431,904)。请注意,在我们的情况下,我们只有 46 个不同的字符;对于如此小的子集,Python 在智能内存分配方面相当出色。不要期望这种情况总是发生,因为这将取决于你的数据模式。
Python 中的对象缓存和复用
Python 试图在对象复用方面尽可能聪明,但我们需要对期望保持谨慎。第一个原因是这依赖于实现。CPython 在这一点上与其他 Python 实现不同。
另一个原因是,即使是 CPython,在版本之间对其大多数分配策略也没有做出任何承诺。适用于您特定版本的方法可能在不同的版本中会发生变化。
最后,即使你有固定的版本,事情是如何工作的可能并不完全明显。考虑以下 Python 3.7.3 中的代码(在其他版本中可能会有所不同):
s1 = 'a' * 2
s2 = 'a' * 2 ①
s = 2
s3 = 'a' * s ②
s4 = 'a' * s
print(id(s1))
print(id(s2))
print(id(s3))
print(id(s4))
print(s1 == s4) ③
① 这里我们通过将 a 乘以 2 来得到字符串 aa。
② 这里我们通过将 a 乘以 s(即 2)来得到字符串 aa。
③ 所有这些字符串在内容上是相等的。
结果将是:
140002256425568
140002256425568
140002256425904
140002256425960
True
当字符串的大小作为变量时,分配器无法确定内容是否相同,即使大小相同。如果这个简单的例子是这样的,那么更复杂的情况会怎样?当然,你仍然可以使用对分配器工作方式的知识,对于代码,你控制着 Python 版本,这使得这一点特别有意义。但相应地调整你的期望。
我们正在使用基于数字列表的文件表示法。如果我们考虑替代表示法会怎样呢?
2.4.2 一些替代表示法的内存占用
现在我们将考虑一些表示文件的简单替代方案。有些会更好;有些会更差。这里的主要目的是理解每种替代方案的成本。我们不是用整数来表示每个字符,而是可以使用长度为 1 的字符串——类似于这样:
single_file_str_list = [chr(i) for i in single_file_data]
这种方法甚至比我们之前使用的方法更差。只需看看每个字符串的大小:
print(sys.getsizeof(single_file_str_list[0]))
这返回 50,而之前整数表示法的表示只有 28。这是一个退步,所以我们不会这样做。
Python 中大量小对象的内存开销相当大。为什么小数字需要 28 字节,单个字符字符串需要 50 字节?实际上,每个 Python 对象至少需要 24 字节的内存开销,并且你必须将对象类型的开销加到这个开销上,这会因类型而异。正如我们所看到的,字符串的开销比字节数组大(图 2.2)。
字符串和数字的内部表示
Python 对字符串有一个高效的内部表示,它可以变化,因此可能会让人对内存分配的预期感到困惑:
from sys import getsizeof
getsizeof('')
getsizeof('c')
getsizeof('c' * 10000)
getsizeof('ç' * 10000)
getsizeof('ç')
getsizeof('😐')
getsizeof('😐' * 10000)
输出将是:
49
50
10049
10073
74
80
40076
空字符串占用 49 字节;c 字符串占用 50 字节;10,000 个 c 字符串占用 10049 字节。到目前为止,一切看起来都很好。但是,带有 cedilla 的 c 字符串占用 74 字节,10,000 个 ç 字符串占用 10,073 字节。如果你现在有点困惑,知道一个单独的困惑表情符号占用 80 字节,10,000 个这样的表情符号占用 40,076 字节。
Python 3 字符串表示 Unicode 字符,但有一个细微差别:内部表示是作为表示的字符串的函数进行优化的。有关详细信息,请参阅 PEP 393--灵活的字符串表示。对于 Latin-1 字符(ASCII 的超集),Python 使用 1 字节(带有 cedilla 的 c 字符串是那个集合的一部分),但对于其他类型的字符,它可能需要多达 4 字节(在我们的困惑表情符号的情况下)。但从我们的角度来看,字符串的大小很难计算。
整数也有一个优化的实现。精度是任意的,但对于适合 30 位的有符号整数,我们得到较小的表示可能性,即 28 字节(数字 0 是一个例外;它只由 24 字节表示,你可能还记得这是由于 CPython 的对象开销而可能的最小对象大小)。

图 2.2 字符串和字节的对象开销
对于文件,有一个更明显的表示方式:我们不需要使用一个字符字符串的列表,而可以使用整个文件的字符串:
single_file_str = ''.join(single_file_str_list)
print(sys.getsizeof(single_file_str))
这将是 1,304,030 字节的大小——我们文件的大小加上字符串对象的开销。虽然这是一个明显且简单的解决方案,但我们将继续使用字节序列容器的方案,因为事实证明,这些方案仍然可以改进。
2.4.3 使用数组作为列表紧凑表示的替代方案
在这里,我们将探讨一个替代元素容器的内存效率可能显著更高的方法:数组。让我们重新审视我们的 get_all_files 函数的实现:
def get_all_files_clean(stations, start_year, end_year):
all_files = collections.defaultdict(list)
for station in stations:
for year in range(start_year, end_year + 1):
f = open(TEMPLATE_FILE.format(station=station, year=year), 'rb')
content = f.read() ①
all_files[station].append(content)
f.close()
return all_files
① 原始实现是 content = list(f.read())。
这行代码 content = list(f.read()) 将 read 函数的输出转换成了一个列表。现在,我们没有使用列表调用来实现它,而是返回了一个字节数组。让我们检查对象的大小:
print(type(single_file_data))
print(sys.getsizeof(single_file_data))
类型是 bytes,包括数据在内的总大小是 1,304,014 字节。
数组的大小是固定的,并且只能包含相同类型的对象。因此,它们的表示可以做得更加紧凑:它可以与对象开销一起存储。回想一下,对于我们的整数,存储一个实际只有 1 字节数据的存储空间需要 28 字节。
列表中的内存占用
当你分配一个列表时,Python 会为潜在的将来添加创建额外的空间,因此列表通常比预期的空间要多。这使得插入操作的成本大大降低,因为不需要在每次添加新元素时分配内存——只需当分配的额外空间耗尽时。当然,成本是内存开销。一般来说,这种开销并不严重,除非你有大量的微小列表;也就是说,“大量微小对象”的论点对于列表尤其正确。虽然了解这一点很有趣,但通常情况下,这种开销是可以接受的。
与数组管理相关的代码大多可在array模块中找到。然而,除了本章之外,我们不会再使用array模块;相反,我们将使用 NumPy,它在许多方面都取代了它。但这里的重点与模块本身关系不大,更多的是理解和消除对象开销。
在这个阶段,你应该已经对 Python 中对象内存分配的成本和陷阱有了深刻的认识。最后,我们现在将尝试理解如何计算 Python 对象的内存使用情况。
2.4.4 系统化我们所学的知识:估计 Python 对象的内存使用
在这个阶段,你有了理解内存分配工作原理的基础。现在,你已经掌握了基本原理,我们将尝试编写一些代码,使我们能够将上一节的所有知识汇总到一个实用函数中,该函数可以给出良好的内存占用近似值。
现在,我们将提炼本节其余部分所学的所有零散知识。在接下来的讨论中,我们将编写一个函数,该函数将返回对象的估计内存大小。它将返回所有对象的大小以及容器上的开销。如果你查看以下代码,你应该能够找到 ID 跟踪、容器计数(包括需要跟踪键和值的映射对象,如字典),以及字符串和数组管理。
计算通用对象的大小实际上是一个真正的雷区(在一般情况下,仅使用 Python 方法实际上是不可能的)。我们列表 2.2 中的代码试图通过不重复计算重复的对象和报告容器和内容全尺寸的容器/迭代器(如字符串或数组;代码可在02-python/sec4-memory/compute_allocation.py中找到)来变得聪明。
列表 2.2 计算通用 Python 对象的大小
from array import array
from collections.abc import Iterable, Mapping
from sys import getsizeof
from types import GeneratorType
def compute_allocation(obj):
my_ids = set([id(obj)]) ①
to_compute = [obj]
allocation_size = 0
container_allocation = 0 ②
while len(to_compute) > 0:
obj_to_check = to_compute.pop()
allocation_size += getsizeof(obj_to_check)
if type(obj_to_check) in [str, array]: ③
continue
elif isinstance(obj_to_check, GeneratorType): ④
continue
elif isinstance(obj_to_check, Mapping): ⑤
container_allocation += getsizeof(obj_to_check)
for ikey, ivalue in obj_to_check.items():
if id(ikey) not in my_ids:
my_ids.add(id(ikey))
to_compute.append(id(ikey))
if id(ivalue) not in my_ids:
my_ids.add(ivalue)
to_compute.append(id(ivalue))
elif isinstance(obj_to_check, Iterable): ⑥
container_allocation += getsizeof(obj_to_check)
for inner in obj_to_check:
if id(inner) not in my_ids:
my_ids.add(id(inner))
to_compute.append(inner)
return allocation_size, allocation_size - container_allocation
① 我们需要存储先前看到的对象的 ID,以避免重复计算它们。
② 我们还将返回在列表或字典等容器中花费的内存。
③ 字符串和数组是可迭代的,它们返回它们内容的大小;我们不希望重复计算内容。
④ 我们将忽略生成器的内容。
⑤ 对于映射,我们需要计算键和值。
⑥ 最后,对于其他迭代器,我们需要检查其大小。
在这里,我们使用迭代方法来计算内存分配。这是一种适合递归实现的算法,但由于 Python 缺乏对良好尾调用优化和递归实现的适当支持,我们将使用迭代方法。
从使用系统编程语言如 C 或 Rust 实现的系统库中计算对象大小将主要取决于该实现以某种形式提供这些信息。对于这些库,请查阅文档以获取详细信息。
警告:有 Python 内存分析库可供尝试使用。我对一些可用工具的估计可靠性有混合的经验,这在 Python 内存估计的雷区中并不令人惊讶。如果您使用它们,请小心。
检查 Python 内存分配还有更多底层方法,但我们将在我们使用 NumPy 时讨论这些方法。在本章中,我们限制自己使用 Python 而不使用外部库。
2.4.5 要点:估计 Python 对象内存使用量
总结来说,估计内存对象的大小并不像人们预期的那样简单。sys.getsizeof不会报告所有对象的大小,因此需要额外努力才能准确计算对象大小。在一般情况下,这个问题甚至无法解决:用低级语言编写的库可能不会报告它们所做的分配的大小。
精简内存分配有几个副作用。一个是允许在内存是限制因素的情况下运行更多并行进程,因为有时就是这样。另一个优点是,它可能为使用内存算法腾出空间,这些算法比需要磁盘空间且由于磁盘访问而速度慢得多的算法要快。
2.5 使用延迟和生成器进行大数据管道
现在,我们将注意力转向 Python 3 中广泛引入的一个特性:延迟语义。延迟语义将任何计算推迟到数据需要时再进行,而不是在此之前。这对于处理大量数据非常有帮助,因为有时计算(和相关内存分配)不需要进行,或者可以分散到一段时间内进行。如果您使用生成器,您已经在使用延迟语义。Python 3 比 Python 2 懒惰得多,因为range、map和zip等函数变成了延迟的。懒惰方法将允许您处理更多数据,通常内存使用量会显著减少,并且可以更轻松地在代码中创建数据管道。
2.5.1 使用生成器而不是标准函数
让我们回顾本章第一节的原始代码:
def get_file_temperatures(file_name):
with open(file_name, "rt") as f:
reader = csv.reader(f)
header = next(reader)
for row in reader:
station = row[header.index("STATION")]
# date = datetime.datetime.fromisoformat(
➥ row[header.index('DATE')])
tmp = row[header.index("TMP")]
temperature, status = tmp.split(",")
if status != "1":
continue
temperature = int(temperature) / 10
yield temperature ①
① 定义中的yield表示生成器。
get_file_temperatures是一个生成器(注意yield)。让我们运行这个生成器:
temperatures = get_file_temperatures(TEMPLATE_FILE.format(
➥ station="01044099999", year=2021))
print(type(temperatures))
print(sys.getsizeof(temperatures))
报告的类型将是 generator,结构的大小将是 112。实际上,由于生成器是惰性的,并没有做太多事情。只有当你开始迭代它时,代码才会根据需要执行:
for temperature in temperatures: ①
print(temperature)
① 每次 for 循环重复时,生成器代码将被调用以提供新的值。
这种方法有几个优点。首先,也是最大的优点是,你不需要为所有温度分配内存,因为每个温度都会依次处理。与此相对比的是,列表需要内存来同时维护所有温度。这在函数返回包含许多元素的大型数据结构时可能非常重要——这可能是代码能否执行的关键。
第二,有时我们不需要得到所有结果,因此急切地执行只会浪费时间。例如,假设你想编写一个函数来查看是否至少有一个温度低于零。你不需要得到所有结果:一旦有一个值低于零,计算就可以停止。
制作生成器的急切版本非常简单,就像这样:
temperatures = list(temperatures)
在这种情况下,你失去了生成器的优势,但有时这可能是有用的。例如,当计算时间不长且列表表示法使用的内存可容忍时,在需要多次访问结果的情况下,一个急切版本更有意义。
注意:Python 2 和 Python 3 之间最大的区别之一是许多原本急切的内置函数变成了惰性的。例如,在我们的案例中,zip、map 和 filter 在 Python 2 中的行为会有很大的不同。
生成器可以用来减少内存占用,在某些情况下,还可以减少计算时间。所以当你编写返回序列的代码时,问问自己是否将其转换为生成器是有意义的。
摘要
-
以直观、非经验的方式检测性能瓶颈并不容易。分析是找到性能缺乏的确切位置的必要第一步。“直觉”在寻找性能问题时往往是不正确的,而经验方法几乎总是获胜。
-
Python 的内部分析系统非常有用,但有时很难解释。像 SnakeViz 这样的可视化工具可以帮助我们理解分析信息。
-
Python 的内部分析系统在帮助我们找到瓶颈发生的确切位置方面存在重大局限性。像 line_profiler 这样的工具在收集信息时运行速度非常慢,但可以提供更高的精确度。
-
虽然 CPU 性能通常是我们的性能优化的首要考虑,但内存使用同样重要,并且可以产生重大的间接效益。例如,一个内存优化不良且需要内存不足算法的解决方案有时可以被完全内存方法所取代,从而产生巨大的时间收益。
-
Python 提供了基本的数据结构,这些结构可以被正确使用或误用,从而影响性能。例如,在无序列表中搜索元素可能会变得相当昂贵。我们必须注意 Python 基本数据结构上许多操作的复杂度成本。这些数据结构出现在所有 Python 程序中,通常是低垂的果实,可以对性能产生巨大影响。
-
对 Python 数据结构的计算复杂度——大 O 符号——有基本的了解对于编写高效代码至关重要。确保定期检查这些内容,因为 Python 版本的变化可能会替换底层实现,从而改变算法的性能。
-
懒惰的编程技巧使我们能够开发出内存占用较小的程序。它们有时还可能使我们完全避免计算中的大部分内容。
-
本章的所有内容都具有广泛的应用性——既包括性能分析,也包括纯 Python 优化——并且可以在本书其余部分讨论的任何技术之前使用。
3 并发、并行性和异步处理
本章涵盖
-
使用异步处理设计减少等待时间的应用程序
-
Python 中的线程及其在编写并行应用程序上的限制
-
使多进程应用程序充分利用多核计算机
现代 CPU 架构允许同时执行多个顺序程序,从而在处理速度上实现令人印象深刻的提升。实际上,速度可以增加到可用的并行处理单元(例如,CPU 核心)的数量。坏消息是,为了充分利用所有这些并行处理速度来为我们的程序服务,我们需要使我们的代码并行化,而 Python 不适合编写并行解决方案。大多数 Python 代码都是顺序的,因此它无法使用所有可用的 CPU 资源。此外,Python 解释器的实现,正如我们将看到的,并没有针对并行处理进行优化。换句话说,我们通常的 Python 代码无法利用现代硬件的能力,并且它将始终以比硬件允许的速度慢得多。因此,我们需要设计技术来帮助 Python 利用所有可用的 CPU 功率。
在本章中,我们将学习如何做到这一点,从一些你可能熟悉的一般方法开始,但它们在 Python 中有着一些独特的技巧。我们将以 Python 的方式讨论并发、多线程和并行性,包括围绕多线程编程的一些强烈限制。
我们还将了解异步编程方法,这些方法允许我们高效地服务许多并发请求,而无需任何并行解决方案。异步编程已经存在了一段时间,在 JavaScript/Node.JS 世界中很受欢迎,但直到最近,它才在 Python 中通过添加新的模块来实现异步编程的标准化。
对于本章,让我们假设你是一家大型软件公司的开发者。你被分配开发一个预期将非常快速的 MapReduce 框架。所有数据都将存储在内存中,并且所有操作都必须在单台计算机上完成。此外,你的服务必须能够同时处理来自多个客户端的请求,其中大多数是自动化的 AI 机器人。为了应对这个项目,你将使用并发和并行编程技术,包括多线程和多进程,以加快 MapReduce 请求的处理速度。此外,你还将使用异步编程来高效地处理来自用户的许多并发查询。
我们将问题分为两部分。在章节的第一部分,我们将构建一个能够同时处理多个请求的服务器。然后我们需要创建 MapReduce 框架本身,这将在 3.1 节之后占据本章的大部分内容。我们将考虑三种不同的方法来构建框架:顺序、多线程和多进程。这将让你看到几种方法在工作,以及它们的优点、权衡和限制。在最后一节,我们将两部分结合起来,将服务器与 MapReduce 框架连接起来,使我们能够理解如何以高效的方式构建一个集成所有部分的解决方案。

图 3.1 章节路线图
为了使章节主题和组织结构更加清晰,图 3.1 提供了一个章节的视觉路线图。它展示了我们将要调查的方法,以及它们如何相互连接和/或一起使用。在每个框的右上角,你可以找到你将学习每个技术的章节编号。
顺序处理、并发和并行
在我们开始之前,让我们简要回顾一下顺序处理、并发和并行的含义。尽管这些都是基本概念,但许多经验丰富的开发者仍然会混淆它们,所以这里有一个快速复习,以确保我们都在以相同的方式使用这些术语。
并行是解释起来最容易的概念:当任务同时运行时,它们就会并行运行。并发任务可以以任何顺序运行:它们可能会并行或顺序运行,这取决于语言和操作系统。所以所有并行任务都是并发的,但反之则不然。
“顺序”这个术语可以用两种不同的方式使用。首先,它可以意味着一组任务需要以严格顺序运行。例如,要在电脑上写东西,你首先得打开它:顺序或序列是由任务本身强加的。第二个任务只能在第一个任务执行之后发生。
然而,有时“顺序”被用来表示系统对任务执行顺序强加的限制。例如,在机场,一次只允许一个人通过金属探测器,即使两个人同时也能通过。
最后,还有中断的概念:当一个任务被中断(非自愿地)以便运行另一个任务时,就会发生这种情况。这与任务之间的调度策略有关,需要软件或硬件来完成,这被称为调度器。
预先中断多任务处理的替代方案是协作多任务处理:你的代码负责告诉系统何时可以被中断并交换给另一个任务。以下图试图使这些概念更加清晰。

理解顺序、并发和并行模型。顺序执行发生在所有任务按顺序执行且不会被中断的情况下。无并行性的并发执行增加了任务可能被另一个任务中断并在之后恢复的可能性。并行性发生在多个任务同时运行时。即使有并行性,由于处理器/核心的数量可能不足以处理所有任务,预 emption 仍然很常见。一个理想的场景是有比任务更多的处理器:这允许所有任务并行执行,而不需要任何预 emption。
注意:我们不会介绍 Python 多线程和多进程的基本功能。如果你需要复习,许多入门指南可以填补这个空白,包括由 Matthew Fowler 编写的《Python Concurrency with Asyncio》(Manning,2022 年;www.manning.com/books/python-concurrency-with-asyncio))。
3.1 编写异步服务器的框架
虽然我们的主要任务是使用 MapReduce 框架来处理请求,但我们将首先构建服务器接收请求的部分(即,为客户端提供一个接口)。正确处理请求将是本章其余部分的内容。在本节中,我们将编写一个服务器,它将接受来自所有客户端的连接并接收 MapReduce 请求(数据和代码)。通过这样做,我们将看到异步编程如何帮助创建高效的服务器,即使不使用并行性。
异步编程的兴起
异步编程在 JavaScript 世界中变得流行,尤其是在 NodeJS 服务器上。当我们有很多需要监控的慢速 I/O 流时,这是一个特别好的模型。最明显的例子是 Web 服务器,其中大多数用例的数据交换量有限,处理速度也很快,通常在毫秒级别。但异步模型也可以帮助我们编写干净的并发和并行程序。此外,正如我们在本章的其余部分将看到的,异步方法对于更传统的数据分析场景也很有用。
为了澄清,异步与单线程、多线程或多进程是正交的。你可以在这些之上拥有异步系统。
首先,让我们看看同步处理造成的主要问题之一,这样我们就可以将同步解决方案与异步解决方案进行比较。在 Python 世界中,同步编程是一种更常见的方法,可能是大多数 Python 程序员首先会尝试的方法。但是,一个同步(单进程)的服务器版本在等待用户输入时会阻塞。由于用户可能需要 1 毫秒或 1 小时来实际写入请求,在同步世界中,这意味着在此期间所有其他客户端都会处于等待状态。这里有三种可能的解决方案(图 3.2):
-
我们只是阻塞(图 3.2 中的标签 1)。这意味着在处理该连接的同时,其他所有事情(例如,关注其他用户)都将无响应。这种对所有连接的阻塞是不可接受的。
-
我们有一个多线程或多进程解决方案,其中启动单个线程或进程来处理请求(图 3.2 中的标签 2)。这意味着主进程被释放以处理其他传入的请求。对于有大量 IO 通道产生少量信息的情况,单线程解决方案是可能的,并且更轻量级。
-
最后,当一个阻塞调用发生时,代码的另一种选择是以某种方式释放执行控制,以便在数据到达时可以执行其他代码片段(图 3.2 中的标签 3)。这是我们在这里将要探索的解决方案:使用单个线程的异步处理。

图 3.2:涵盖同步单进程/线程、同步多进程和异步单进程服务器实现的几种架构
有许多替代方案可以替代这三种选项。例如,本章末尾的解决方案实际上将是方案 2 和 3 的混合。在方案 2 中,另一个非常常见的替代方案是拥有一个预先启动的进程池来加速响应。对于方案 3,我们假设计算任务可以被中断(我们将在本章后面放松这个假设)。最后,正如你可能所知,Python 中的多线程代码通常是(尽管不总是)非并行的;我们将在稍后讨论这个问题。当我们处理 MapReduce 项目时,我们将讨论所有这些解决方案以及它们的问题。我们将遵循图 3.1 中描述的过程。首先,我们将尝试一个完全没有并行的简单解决方案。¹然后,我们尝试基于线程的解决方案,但这并没有满足我们的需求。之后,我们将开发一个多进程解决方案,这将最终提高性能。在最后一节中,当我们把本节开发的网络接口与多进程解决方案结合起来时,我们会发现线程代码仍然有用,尽管不是用于并行处理。
提示:这里提出的解决方案是 一个 解决方案。有大量的替代方法。即使这是可能的最佳解决方案(它不是),也为了解释而做出了让步:最佳 的定义因你的标准而异。此外,不同的问题可能需要完全不同的方法。
你应该从这里得到的是一套明确的规则,而是一套技术和见解,这将帮助你为你的问题和具体标准制定最佳解决方案。
现在,让我们回到我们的异步、单线程和单进程服务器。
3.1.1 实现与客户端通信的框架
我们的服务器将基于 TCP 协议,并在端口 1936 上响应。它在 03-concurrency/sec1-async/server.py 仓库中可用。以下是处理客户端请求的框架的顶层结构:
import asyncio ①
import pickle
results = {}
async def submit_job(reader, writer): ②
job_id = max(list(results.keys()) + [0]) + 1
writer.write(job_id.to_bytes(4, 'little')) ③
results[job_id] = job_id * 3
async def get_results(reader, writer):
job_id = int.from_bytes(await reader.read(4), 'little') ③
pickle.dump(results.get(job_id, None), writer)
async def accept_requests(reader, writer):
op = await reader.read(1) ③
if op[0] == 0:
await submit_job(reader, writer)
elif op[0] == 1:
await get_results(reader, writer)
async def main():
server = await asyncio.start_server(
accept_requests, '127.0.0.1', 1936) ④
async with server: ⑤
await server.serve_forever() ⑥
asyncio.run(main()) ⑦
① 我们使用 Python 的 asyncio 库。
② 我们的所有函数都被声明为 async。
③ 这些行可能会阻塞并暂停周围的全部其他代码。
④ 我们使用 asyncio 的 start_server 来为每个连接调用 accept_requests。我们的服务器将在本地接口 127.0.0.1 端口 1936 上监听。
⑤ 可以使用 async 关键字与关键字一起使用,使其非阻塞。
⑥ 我们正在告诉我们的服务器对象永久地处理请求。
⑦ 这是我们的代码的入口点:主函数被运行。
这段代码现在只是一个框架;我们将在本章的最后部分完成代码,当我们把所有东西结合起来时。尽管如此,这里还有很多东西要解释。基本问题是为什么这样做?为什么不直接做一个“典型”的同步版本?
主要原因是与注释 3 相关的函数——在我们的案例中,是从网络读取和写入——可能需要不确定的时间。此外,网络速度比 CPU 速度慢几个数量级。如果我们在这里阻塞,我们将远远低于潜在的性能,并且还会让其他用户无谓地等待。
你刚才看到的所有 Python 基础设施——async、await 和 asyncio 模块——都是为了防止在单线程应用程序中阻塞调用,从而阻止与它无关的其他代码部分。
3.1.2 使用协程编程
像上一节中那样的 异步 函数(即使用 async def 创建的函数)被称为 协程。协程是自愿释放执行控制的函数。系统中还有另一个部分,即执行器,它管理所有协程并根据某种策略运行它们。
当你在另一个协程内部使用 await 调用一个协程时,你实际上是在告诉 Python 在这个阶段可以转移控制权。这被称为 合作式调度,因为释放控制权是自愿的,并且需要由协程代码显式完成。
将协程系统与大多数操作系统的线程系统进行比较:在那里,线程被强制抢占,无法控制其运行的时间。这被称为抢占式调度。典型的线程代码不需要显式标记可能被中断的位置,因为它将被强制中断。从这种意义上说,Python 线程的工作方式类似于操作系统线程。这只是为了async代码,才适用自愿抢占。
一个典型的例子可能是一个等待某些网络数据并写入磁盘(通常是 IO 任务)的程序。它可以这样工作——请注意,这是顺序的,因此不需要线程:
-
主程序使用异步执行器安排两个协程:一个等待网络连接,另一个写入磁盘。
-
执行器选择——可能是随机地——开始执行网络协程。
-
网络协程设置网络监听。然后它等待连接。目前没有连接,因此它自愿告诉执行器做其他事情。
-
执行器启动磁盘协程。
-
磁盘协程开始向磁盘写入。与 CPU 速度相比,写入速度较慢,因此协程告诉执行器做其他事情。
-
执行器继续执行网络协程。
-
目前还没有连接请求;网络协程让出。
-
执行器安排磁盘协程。
-
磁盘协程完成写入并终止其工作。
-
由于没有更多的事情要做,执行器让网络协程无限期地运行。如果网络协程让出,执行器就回到它那里。
-
网络协程最终响应一个连接或可能超时。
-
执行器结束并返回控制权给主程序。
本节,连同最后一节,提供了协程的示例——记住,所有async def都是协程。但让我们用以下代码片段进行一个小(一个脚手架)测试:
import asyncio
async def accept_requests(reader, writer):
op = await reader.read(1)
# ...
result = accept_requests(None, None)
print(type(result))
让我们看看async在这里提供了什么。如果上面的代码没有async关键字,我们预计它会在reader.read()抛出异常,因为reader将是None。然而,像这样调用accept_requests并不执行一个函数,而是返回一个协程,这正是async def实际创建的内容。
我们代码中的await调用告诉 Python,accept_requests可以在该点暂停,并且可能运行其他任务。因此,当我们等待reader发送数据时,Python 可以执行其他任务,直到数据到达。如果协程在执行延迟和可以暂停的意义上有点像第二章中讨论的生成器,那么你就找到了正确的方向。
3.1.3 从简单的同步客户端发送复杂数据
为了与我们的服务器交互,我们将编写一个简单的同步客户端。这作为一个更同步类型代码的例子,这在 Python 世界中更为常见,并且对于我们的客户端需求来说已经足够了。但更重要的是,我们也借此机会展示如何在进程之间进行数据和代码的通信。虽然服务器将在以后进一步开发,但这实际上是客户端的最终版本。
我们的客户端将提交我们的代码(代码可以在03-concurrency/sec1-async/client.py中找到)和数据,然后探测服务器直到返回答案:
import marshal ①
import pickle ②
import socket
from time import sleep
def my_funs(): ③
def mapper(v):
return v, 1
def reducer(my_args):
v, obs = my_args
return v, sum(obs)
return mapper, reducer
def do_request(my_funs, data):
conn = socket.create_connection(('127.0.0.1', 1936)) ④
conn.send(b'\x00')
my_code = marshal.dumps(my_funs.__code__) ⑤
conn.send(len(my_code).to_bytes(4, 'little'))
conn.send(my_code)
my_data = pickle.dumps(data)
conn.send(len(my_data).to_bytes(4, 'little'))
conn.send(my_data)
job_id = int.from_bytes(conn.recv(4), 'little') ⑥
conn.close()
print(f'Getting data from job_id {job_id}')
result = None
while result is None: ⑦
conn = socket.create_connection(('127.0.0.1', 1936))
conn.send(b'\x01')
conn.send(job_id.to_bytes(4, 'little'))
result_size = int.from_bytes(conn.recv(4), 'little')
result = pickle.loads(conn.recv(result_size))
conn.close()
sleep(1)
print(f'Result is {result}')
if __name__ == '__main__':
do_request(my_funs, 'Python rocks. Python is great'.split(' '))
① marshal用于提交代码。
② pickle用于提交大多数高级 Python 数据结构。
③ 我们的功能定义在一个返回函数的单个函数中。
④ 我们在这里创建一个网络连接。
⑤ 我们创建我们代码的字节表示。
⑥ 我们接收 job_id 并自行处理编码。
⑦ 我们将一直连接,直到结果准备好。
这里也有很多东西要解释。让我们从网络代码开始:我们使用 Python 的 socket 接口创建一个 TCP 连接,并使用该 API 发送和接收数据。所有调用都是可能阻塞的,这对于这个客户端来说是可接受的。
在此代码中需要保留的最重要部分可能是传输数据的各种替代方案。在 Python 中,pickle模块是最常见的序列化数据的方式,然后可以跨进程传输。但它并不是一个万能的解决方案;例如,它不能用来传输代码。对于代码,我们使用marshal模块。我们还使用int对象的to_bytes函数来提醒我们自己可以在更多边缘情况下自行处理编码。其中最常见的是当我们需要一个既紧凑又快速的解决方案时——这两件事pickle都做不到。当然,在这种情况下,我们将承担编码/解码的负担。当我们处理 IO 时,我们将重新审视这个问题。
我们在返回函数的my_funs函数内部传输我们的代码。另一种选择是使用对象。要使用此代码,打开终端并使用以下命令启动服务器:
python server.py
然后使用以下命令运行客户端:
python client.py
输出将是:
Getting data from job_id 1
Result is [Number between 1 and 4]
3.1.4 进程间通信的替代方法
客户端/服务器通信的一个更常见的方法是使用 HTTPS 上的 REST 接口,但当我们试图理解底层概念时,REST 并不是很有帮助。我们将在第六章中重新审视替代网络通信策略的性能影响。无论如何,一个现实实现肯定需要至少某种形式的加密。
3.1.5 吸取的经验:异步编程
异步编程在处理大量同时用户请求时可以非常有效。为了async能够提高响应时间,必须满足两个条件。首先,与外部进程的通信必须受到限制。其次,每个请求的 CPU 处理量也应该较小。由于这两个条件通常与 Web 服务器相关,因此async编程通常对大多数 Web 应用都有帮助。
此外,尽管我们在本节中集中讨论了async编程的基础,但关于 Python 的核心异步功能还有很多可以说的。我鼓励您了解语言功能,如异步迭代器(async for)和上下文管理器(async with)。还有许多异步库,如 aiohttp,用于 HTTP 通信,作为已知同步请求库的替代品。
3.2 实现基本的 MapReduce 引擎
现在让我们继续本章的主要目标,即实现一个 MapReduce 框架。在第一部分,我们处理了框架周围的通信架构。在本节中,我们将开始实现解决方案的核心。本节将设置一个基本解决方案,我们将在本章的后面部分从中推导出更高效的计算版本。
3.2.1 理解 MapReduce 框架
让我们从分解一个 MapReduce 框架开始,看看它包含哪些组件。从理论角度来看,MapReduce 计算被分为至少两个部分:map 和 reduce 部分。让我们通过一个典型的 MapReduce 应用程序的例子:单词计数,来观察这一过程。在这种情况下,我们将使用莎士比亚的《暴风雨》中的两行:“我是个傻瓜。为我所高兴的事而哭泣。”您可以在图 3.3 中看到这个输入。除了 map 和 reduce 之外,在实践中,还需要存在其他组件。例如,map 的结果在发送到 reduce 进程之前需要被打乱:如果单词am的两个实例被发送到不同的 reduce 进程,计数将不会正确。

图 3.3 使用单词计数作为示例的map_reduce框架的基本原理。传统的 MapReduce 框架有几个进程或线程实现 map 和结果步骤。在许多情况下,这些可以分布在多台计算机上。
单词计数可以通过一个 map 函数实现,该函数会对每个找到的单词发出一个条目,计数为 1,而 reduce 函数会汇总所有相同单词的 map 条目。因此,map 会发出:
I, 1
am, 1
a, 1
fool, 1
To, 1
weep, 1
at, 1
what, 1
I, 1
am, 1
glad, 1
of, 1
然后,reduce 会生成:
I, 2
a, 1
fool, 1
To, 1
weep, 1
at, 1
what, 1
am, 2
glad, 1
of, 1
在某个地方,我们需要对结果进行洗牌,以便一个唯一的单词只被一个 reduce 函数看到。例如,如果 am 被两个不同的 reduce 函数看到,那么当我们想要看到一个计数为 2 的时候,我们最终会得到两个计数为 1。在我们的服务器中,洗牌函数是内置的;用户不需要提供它。
3.2.2 开发一个非常简单的测试场景
记住,我们正在自己实现一个 MapReduce 框架。虽然我们不会是用户,但我们需要测试我们的 MapReduce 框架。为此,我们将回到 MapReduce 最常见的练习:在文本中统计单词。然后我们的框架将被用于许多其他问题,但为了框架的基本测试,统计单词就足够了。
实现此功能的用户代码将像以下这样简单。记住,这并不是我们被委托去做的事情;这只是我们将用于测试的例子:
emitter = lambda word: (word, 1) ①
counter = lambda (word, emissions): (work, sum(emissions))
① 我们将故意使用函数式表示法,因为 MapReduce 具有函数式起源。如果你使用 PEP 8,你的语法检查器将会抱怨,因为 PEP 8 说:“始终使用 def 语句而不是将 lambda 表达式直接绑定到标识符的赋值语句。”这种报告方式将取决于你的代码检查器。你是否有偏好使用这种表示法还是 PEP 8 的那种,后者将采用 def emitter(word) 的形式。我们将使用这段代码来测试我们在本章中构建的框架。
3.2.3 尝试实现一个 MapReduce 框架
记住,之前的代码是用户将要编写的代码。我们现在将实现一个 MapReduce 引擎,这是我们真正的目标,它将统计单词并做更多的事情。我们将从一个能工作但不多的事情开始,然后在接下来的章节中通过使用线程、并行性和异步接口(第一版在 03-concurrency/sec2-naive/naive_server.py 中可用)来开发一个高效的引擎:
from collections import defaultdict
def map_reduce_ultra_naive(my_input, mapper, reducer):
map_results = map(mapper, my_input)
shuffler = defaultdict(list)
for key, value in map_results:
shuffler[key].append(value)
return map(reducer, shuffler.items())
你现在可以使用以下内容:
words = 'Python is great Python rocks'.split(' ')
list(map_reduce_ultra_naive(words, emiter, counter))
list 强制惰性映射调用实际执行(如果你对惰性语义有疑问,请参阅第二章),因此你会得到以下输出:
[('Python', 2), ('is', 1), ('great', 1), ('rocks', 1)]
尽管从概念上看,之前的实现相当清晰,但从操作角度来看,它未能满足 MapReduce 框架最重要的操作期望——即其函数是并行运行的。在接下来的几节中,我们将确保在 Python 中创建一个高效的并行实现。
3.3 实现一个 MapReduce 引擎的并发版本
让我们再试一次,这次实现一个并发框架,这次通过使用多线程。我们将使用 concurrent.futures 模块中的线程执行器来管理我们的 MapReduce 任务。我们这样做是为了得到一个不仅并发而且并行(即,允许我们使用所有可用的计算能力)的解决方案——至少这是我们希望做到的。
3.3.1 使用 concurrent.futures 实现线程化服务器
我们从concurrent.futures开始,因为它比最常用的threading和multiprocessing模块更声明式和更高级。这些是领域中的基础模块,我们将在下一节中使用multiprocessing,因为它的低级接口将允许我们更精确地分配 CPU 资源。
这里是新的版本(代码可在03-concurrency/sec3-thread/ threaded_mapreduce_sync.py中找到):
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor as Executor ①
def map_reduce_still_naive(my_input, mapper, reducer):
with Executor() as executor: ②
map_results = executor.map(mapper, my_input) ③
distributor = defaultdict(list) ④
for key, value in map_results:
distributor[key].append(value)
results = executor.map(reducer, distributor.items()) ③
return results
① 我们使用concurrent.futures模块中的线程 executor。
② executor 可以作为上下文管理器工作。
③ Executors 有一个具有阻塞行为的 map 函数。
④ 我们使用一个非常简单的洗牌函数。
我们的功能再次接受一些输入,包括mapper和reducer函数。concurrent.futures模块中的 executor 负责线程管理,尽管我们可以指定我们想要的线程数。如果没有指定,默认值与os.cpu_count相关;实际的线程数在不同的 Python 版本中有所不同。这总结在图 3.4 中。

图 3.4 我们的 MapReduce 框架的线程执行
记住,我们需要确保相同对象(在我们的例子中是一个单词)的结果被发送到正确的 reduce 函数。在我们的情况下,我们在distributor默认字典中实现了一个非常简单的版本,为每个单词创建一个条目。
之前的代码可能有一个相当大的内存占用,特别是因为洗牌器将所有结果都保留在内存中,尽管是以紧凑的方式。但为了简单起见,我们将保持原样。
使用concurrent.futures管理工作者的数量基本上是一个黑盒。作为这样的模块,我们不知道它被优化了什么。因此,如果我们想确保我们正在提取最大的性能,我们必须完全控制执行的方式。如果你想微调工作者管理,你需要直接使用threading模块。² 我们将在下一节中看到如何做到这一点。
你可以尝试这个解决方案,与
words = 'Python is great Python rocks'.split(' ')
print(list(map_reduce_still_naive(words, emiter, counter)))
并且输出将与上一节相同。
然而,上一个解决方案有一个问题:它不允许与正在进行的程序有任何交互。也就是说,当你执行 executor.map 时,你必须等待完整的解决方案计算完成。对于只有五个单词的例子来说,这并不相关,但你可能希望对于非常长的文本有一些反馈。例如,你希望在代码运行时能够报告完成的百分比。这需要一种稍微不同的解决方案。
3.3.2 使用 futures 的异步执行
首先,让我们只编写 map 部分来理解正在发生的事情(代码可在03-concurrency/sec3-thread/threaded_mapreduce.py中找到):
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor as Executor
def async_map(executor, mapper, data):
futures = []
for datum in data:
futures.append(executor.submit(mapper, datum)) ①
return futures
def map_less_naive(executor, my_input, mapper):
map_results = async_map(executor, mapper, my_input)
return map_results
① 在调用 executor 时,我们使用 submit 而不是 map。
当 executor 的 map 函数等待结果时,submit不会等待。我们将在稍后运行时看到这意味着什么。
我们将改变我们的发射器,以便能够跟踪正在发生的事情:
from time import sleep
def emitter(word):
sleep(10)
return word, 1
sleep调用是为了减慢代码的执行速度,这样我们甚至可以通过一个简单的示例来跟踪正在发生的事情。让我们使用我们的 map 函数:
with Executor(max_workers=4) as executor:
maps = map_less_naive(executor, words, emitter)
print(maps[-1])
如果你打印列表中的最后一个项目,你可能会得到一些意外的东西:
<Future at 0x7fca334e0e50 state=pending>
你不会得到('rocks', 1),而是一个未来。未来代表一个可能的结果,它可以被await处理,并检查其状态。我们现在可以允许用户以这种方式跟踪进度:
with Executor(max_workers=4) as executor: ①
maps = map_less_naive(executor, words, emitter)
not_done = 1
while not_done > 0: ②
not_done = 0
for fut in maps:
not_done += 1 if not fut.done() else 0 ③
sleep(1) ④
print(f'Still not finalized: {not_done}')
① 我们只放置了四个执行者,这样我们可以跟踪进度,因为我们有五个任务。
② 当还有任务需要完成时,我们会打印状态。
③ 检查未来是否完成
④ 我们会稍微休息一下,因为我们不希望出现大量的文本
如果你运行前面的代码,你会得到几行显示“仍然未最终确定...”。通常在前 10 秒内,你会看到五个,然后只剩下一个。因为有四个工作者,前四个需要 10 秒来完成,然后最后一个才能开始。鉴于这是并发代码,每次运行的结果可能会有所不同,因此线程被抢占的方式可能会每次都不同:它是非确定性的。
还有一块最后的拼图需要完成,这将包含在最后版本的线程执行者中:我们需要一种方式让调用者能够得知进度。调用者将不得不传递一个回调函数,该函数将在发生重要事件时被调用。在我们的情况下,这个重要事件将是跟踪所有 map 和 reduce 作业的完成。这已经在以下代码中实现:
def report_progress(futures, tag, callback): ①
done = 0
num_jobs = len(map_returns)
while num_jobs > done:
done = 0
for fut in futures:
if fut.done():
done +=1
sleep(0.5)
if callback:
callback(tag, done, num_jobs - done)
def map_reduce_less_naive(my_input, mapper, reducer, callback=None):
with Executor(max_workers=2) as executor:
futures = async_map(executor, mapper, my_input)
report_progress(futures, 'map', callback) ②
map_results = map(lambda f: f.result(), futures) ③
distributor = defaultdict(list)
for key, value in map_results:
distributor[key].append(value)
futures = async_map(executor, reducer, distributor.items())
report_progress(futures, 'reduce', callback) ④
results = map(lambda f: f.result(), futures) ⑤
return results
① report_progress将需要一个回调函数,该函数将每半秒调用一次,并带有关于已完成作业的统计信息。
② 我们会报告所有 map 任务的进度。
③ 因为结果实际上是未来对象,所以我们需要从未来对象中获取这些结果。
④ 我们会报告所有 reduce 任务的进度。
⑤ 因为结果实际上是未来对象,所以我们需要从未来对象中获取这些结果。
因此,在 map 和 reduce 运行期间,每 0.5 秒,用户提供的回调函数将被执行。回调函数可以像你想要的那样简单或复杂,尽管它应该很快,因为其他所有内容都将等待它。对于我们在测试中使用的单词计数示例,我们有一个非常简单的示例:
def reporter(tag, done, not_done):
print(f'Operation {tag}: {done}/{done+not_done}')
注意,回调函数的签名不是任意的:它必须遵循report_progres强加的协议,该协议需要标签以及已完成和未完成的任务数量作为参数。
如果你运行
words = 'Python is great Python rocks'.split(' ')
results = map_reduce_less_naive(words, emitter, counter, reporter)
你将看到几行打印操作正在进行的状态,然后是结果:
Operation map: 3/5
Operation reduce: 0/4
('is', 1)
('great', 1)
('rocks', 1)
('Python', 2)
例如,使用返回值作为指示器来取消 MapReduce 框架的执行并不困难。这将允许我们更改回调函数的语义以中断进程。
不幸的是,这个解决方案是并发的但不是并行的。这是因为 Python(或者更确切地说,CPython)一次只执行一个线程,归功于臭名昭著的 CPython GIL(全局解释器锁)。让我们在下一节更详细地看看 GIL 以及它是如何处理线程的。
3.3.3 GIL 和多线程
虽然 CPython 使用操作系统线程,因此它们是抢占式线程,但 GIL 施加了一个限制,使得一次只能运行一个线程。因此,你可能在多核计算机上运行一个多线程程序,但最终你将没有任何并行性。实际上,情况可能更糟:在多核计算机上线程切换的性能可能会相当差,这是由于 GIL(它不允许一次运行多个线程)与 CPU 和操作系统之间的摩擦,而 CPU 和操作系统实际上被优化来做相反的事情。
这本书包括一个关于多线程的章节,因为任何与性能相关的书籍如果没有它就不完整。但说实话,如果你想获得性能,Python 线程很少是最好的解决方案。
GIL 问题被高估了。事实上,如果你需要在线程级别进行高性能代码,Python 可能本身就太慢了。至少 CPython 实现,可能还有 Python 的动态特性,都会带来成本。你将希望用 C 或 Rust 这样的低级语言实现任何极其高效的代码,或者使用 Cython 或 Numba 这样的系统,我们将在后面学习。
GIL 提供了几个用于其他语言实现的低级代码的逃生路线:当你进入你的低级解决方案时,你实际上可以释放 GIL 并充分利用并行性。这正是 NumPy、SciPy 和 scikit-learn 等库所做的事情。它们用 C 或 Fortran 编写的多线程代码会释放 GIL 并且实际上是并行的。所以你的代码在多线程世界中仍然可以是并行的。只是并行部分不会用 Python 编写。
但你仍然可以用纯 Python 编写高效的并行代码,并在 Python 有意义的计算粒度级别上做到这一点。你不是通过多线程,而是通过多进程来做到这一点。
PyPy
虽然 CPython 是 Python 的标准实现,但还存在其他实现,如 IronPython 和 Jython,分别用于.NET 和 JVM。另一个值得提到的实现是 PyPy,它不是一个解释器,而是一个即时编译器。PyPy 不是 CPython 的替代品,因为许多 CPython 库不能直接与它一起使用。但如果支持的库是你需要的,它可能是一个更快的实现。虽然 PyPy 在许多情况下比 CPython 快,但它仍然有一个 GIL,所以它不会解决这个问题。在这本书中,我们将坚持使用 CPython,但在效率合理的子集情况下,PyPy 可能是一个潜在的替代方案。
最后一点:如果你把 Python 实现PyPy和包仓库PyPI搞混,要知道你并不孤单。
3.4 使用多进程实现 MapReduce
由于 GIL(全局解释器锁),我们的多线程代码实际上并不是并行的。我们可以转向两个方向来解决这个问题:我们可以用 C 或 Rust 这样的底层语言重新实现我们的 Python 代码,或者像本节中的解决方案一样,转向多进程以实现并行性并利用所有可用的 CPU 功率。底层解决方案将在后面的章节中讨论。
3.4.1 基于 concurrent.futures 的解决方案
理论上,基于 concurrent.futures 的解决方案相当简单。该模块的设计目标之一是使其容易将 ThreadPoolExecutor 中的一个导入更改为 ProcessPoolExecutor(此代码位于 03-concurrency/sec4-multiprocess/ futures_mapreduce.py):
from concurrent.futures import ProcessPoolExecutor as Executor
如果你替换上一节异步版本的这一行,你会注意到有问题,因为代码似乎在减少部分冻结了。我们需要深入挖掘;因此,我们将从上一节构建一个更详细的 report_progress 函数:
def report_progress(futures, tag, callback):
done = 0
while num_jobs > done:
done = 0
for fut in futures:
if fut.done():
done +=1
print(fut)
print(fut.exception())
sleep(0.5)
if callback:
callback(tag, done, not_done)
我们刚刚添加了两个打印语句。如果我们再次运行代码,我们会得到:
<Future at 0x7f1ffff104c0 state=finished raised PicklingError>
Can't pickle <function <lambda> at 0x7f2000131ca0>: attribute lookup
<lambda> on __main__ failed
结果表明 lambda(记住我们的 counter 函数是作为 lambda 编写的)无法被序列化。因此,多进程通信是通过 pickle 模块完成的。因此,我们的 counter 函数不能直接传输到子进程中。我们可以将其重写为一个 def 函数:
def counter(emitted):
return emitted[0], sum(emitted[1])
这解决了我们的特定测试示例,但更重要的观点是,你不能简单地从线程执行器直接替换为基于进程的执行器。其他可能不同的功能有哪些?转向下一节以了解详情。
使用 Python 的多进程模块进行数据和代码共享的问题
我们已经看到,在默认的 pickle 配置下,无法在进程间传输 lambda 表达式。或者,如果你想这么做,你必须实现自己的协议。
通常,如果 pickle 无法处理,那么你必须单独处理这个问题,因为多进程依赖于 pickle 进行通信。这可能包括来自外部库的对象,特别是如果它们有非 Python 实现。
文件指针、数据库连接和套接字要么无法传输,要么需要额外小心。使用线程,所有这些对象类型都可以共享,尽管我们需要检查它们是否线程安全。pickle 的另一个问题是它相当慢。在需要传输大量数据的情况下,这可能会完全抵消使用多进程的目的。
使用 Python 的通信原语对于粗粒度处理且通信量低的情况是完全可以的。但在通信开销很大的场景中,应该有所注意,因为通过多进程获得的速度可能会在通信时间中丢失。
3.4.2 基于 multiprocessing 模块的解决方案
concurrent.futures为我们提供了一个非常简单的接口来进行并发处理。对于更明显的问题,它在程序员生产力和计算性能方面都可以非常高效。但编程简单性是有代价的:我们失去了对代码执行方式的控制。futures 的执行顺序是什么?虽然我们定义了最大工作线程数,但在某个特定时间点有多少是真正可用的?进程是回收还是为每个任务从头创建?在concurrent.futures中,这是由执行器定义的,我们无法控制它。
在我们的案例中,我们实际上希望实施一些策略以实现高性能。例如,我们希望在任务到达之前创建所有进程,或者在没有任何任务时保持进程存活。这是因为当请求到达时创建和销毁进程有开销,我们更愿意在处理这些请求时支付这个开销。我们将首先自己创建一个进程池。
我们将从一个简单的解决方案开始,这个解决方案不允许我们实时跟踪进度(代码位于03-concurrency/sec4-multiprocess/mp_mapreduce_0.py):
from collections import defaultdict
import multiprocessing as mp ①
def map_reduce(my_input, mapper, reducer):
with mp.Pool(2) as pool: ②
map_results = pool.map(mapper, my_input) ③
distributor = defaultdict(list)
for key, value in map_results:
distributor[key].append(value)
results = pool.map(reducer, distributor.items()) ③
return results
① 我们导入 multiprocessing 模块。
② 我们创建了一个包含两个进程的池。
③ 池提供了一个同步的 map 函数。
这段代码是最简单的。唯一的新行是创建一个Pool。每当请求 MapReduce 操作时,都会创建池,因此它不是跨多个调用持久存在的:我们为每次执行支付了池创建的代价。
CPU_count与sched_getaffinity的比较来确定池的大小
之前的代码指定了在我们的池中创建两个进程。在大多数情况下,你希望这个数字是计算能力的函数。池的默认值是os.cpu_count,这实际上是一个误称:它通常报告的是超线程的数量,而不是 CPU 的数量。
一个稍微更严谨的替代方案是使用len(os.sched_getaffinity()),因为它报告了你可访问的所有核心。你的电脑可能有更多,但操作系统、容器或虚拟机可能限制了你对其中一部分的访问。
警告:Pool.map函数的语义是急切的,而内置的 map 函数是懒散的。因此,这段代码在语义上并不等价:
map(fun, data)
Pool.map(fun, data)
第一次返回立即发生,并没有执行fun。list(map(fun, data))将是急切等价的形式。典型的发展模式是将Pool.map替换为 map,因为如果在同一进程中调试代码会更简单。然而,这并不完全正确。在multiprocessing方面,你还有imap,这是一个更懒散的版本,以及map_async中的异步版本。
3.4.3 监控多进程解决方案的进度
如目前所示,map_async不支持进度跟踪。请注意,它有回调支持,但它只在所有结果都准备好时调用callback函数。我们希望有更细粒度的功能:每次迭代器中的元素准备好时都能进行调用。这正是我们需要用于进度跟踪的。
我们将修改代码以支持它。虽然有一个Pool.map_async函数在许多场合可能很有用,但其回调系统只报告执行的最后阶段,这还不够。我们需要一个更底层的解决方案(此代码位于03-concurrency/sec4-multiprocess/mp_mapreduce.py):
def async_map(pool, mapper, data):
async_returns = []
for datum in data:
async_returns.append(pool.apply_async( ①
mapper, (datum, ))) ②
return async_returns
def map_reduce(pool, my_input, mapper, reducer, callback=None):
map_returns = async_map(pool, mapper, my_input)
report_progress(map_returns, 'map', callback)
map_results = [ret.get() for ret in map_returns] ③
distributor = defaultdict(list)
for key, value in map_results:
distributor[key].append(value)
returns = async_map(pool, reducer, distributor.items())
results = [ret.get() for ret in returns]
return results
① 我们使用Pool.apply_async来启动单个任务。
② 注意,函数的参数被指定为一个元组。
③ 使用异步对象的 get 方法获取结果
代码与concurrent.futures解决方案没有太大区别——以至于我们可能会问concurrent.futures缺乏灵活性是否值得额外的简单性。
我们现在的map_reduce函数使用用户提供的池,允许池的回收利用。这通常比每次进行新操作时启动新进程更有效。在我们的例子中,几乎没有开销,但在更复杂的例子中,每个进程的初始化可能既耗时又消耗资源。
要调用此代码,我们现在必须先创建池。这是简单的:
pool = mp.Pool()
results = map_reduce(pool, words, emitter, counter, reporter)
pool.close() ①
pool.join() ②
① 我们关闭池。
② 我们等待所有进程终止。
我们在这里可以使用上下文管理器中的池,但这让我们看到清理池不仅仅是关闭所有进程;它还包括等待它们退出——join调用。相对于close的更强替代方案是terminate,这将强制现有进程终止,即使没有完成任何正在进行的工作。
过度分配或未充分分配 CPU 资源
当我们创建池时,我们使用默认大小,os.cpu_count(),但有许多情况下你将想要未充分分配资源。甚至有情况,过度分配资源也是可以接受的。
最常见的未充分分配资源的原因是当进程是 I/O 密集型时:过多的 I/O 操作很容易使机器崩溃,因为许多进程可能造成的 I/O 负载超过了机器的处理能力。这对于磁盘 I/O 尤其如此。
如果进程消耗了大量的内存,那么你也需要相应地减少资源分配,因为你可能因为内存缓存的使用而降低了性能。在最坏的情况下,如果计算机内存耗尽,操作系统可能会开始杀死进程。
过度分配资源的一个典型情况是当你正在等待网络。这意味着进程可能会长时间处于空闲状态,因此 CPU 资源是可用的。
反讽的是,CPU 过度分配对于一些 CPU 密集型过程可能是有用的——例如,当每个进程的 CPU 使用率不是连续的,而是以突发方式出现,或者在计算真正开始之前有大量设置时间时。
report_progress函数几乎相同:当作业完成时调用回调。将Future.done的调用替换为AsyncReturn.ready:
def report_progress(map_returns, tag, callback):
done = 0
num_jobs = len(map_returns)
while num_jobs > done:
done = 0
for ret in map_returns:
if ret.ready():
done += 1
sleep(0.5)
if callback:
callback(tag, done, num_jobs - done)
你现在可以运行代码,并且一切正常。但是,这个解决方案是否足够快?
3.4.4 分块传输数据
为了回答解决方案是否足够快的问题,我们需要将其与其他东西进行比较。正如我们在上一章中看到的那样,并且将在稍后重提,块化对于磁盘写入可以显著加快磁盘写入操作。这种技术对于 CPU 成本和进程间通信也有好处吗?
为了回答这个问题,我们将对我们的 MapReduce 架构进行一些小的修改。我们将在开始时添加一个分割阶段,将数据作为块而不是单个元素发送。图 3.5 向图 3.3 中引入了一个新的步骤,该步骤负责分割。

图 3.5 带分割的map_reduce框架(在我们的情况下实际上是块化)
我们的分割实际上相当简单,但高级 MapReduce 框架可以在这里做更多高级优化。我们将首先查看提交块化作业并在池进程上解块(此代码位于03-concurrency/sec4-multiprocess/chunk_mp_mapreduce.py)的代码:
def chunk(my_iter, chunk_size): ①
chunk_list = []
for elem in my_iter:
chunk_list.append(elem)
if len(chunk_list) == chunk_size:
yield chunk_list
chunk_list = []
if len(chunk_list) > 0:
yield chunk_list
def chunk_runner(fun, data): ②
ret = []
for datum in data:
ret.append(fun(datum))
return ret
def chunked_async_map(pool, mapper, data, chunk_size): ③
async_returns = []
for data_part in chunk(data, chunk_size): ④
async_returns.append(pool.apply_async( ⑤
chunk_runner, (mapper, data_part)))
return async_returns
① 我们现在有了块生成器,它将分割迭代器为大小为chunk_size的列表。
② 在池进程上执行块运行器以解包块列表。
③ 我们必须通过添加一些中间件来解包列表,以使我们的函数能够向池提交作业。
④ 这里我们调用块函数。
⑤ 我们现在调用中间件而不是直接调用最终函数。
chunked_async_map是将在池之间分配工作的代码。它调用chunk生成器将输入分割为大小为chunk_size的块。请注意,它不再直接调用所需函数:在池进程上首先运行的是chunk_runner,它将迭代块中的每个元素并调用实际的工作函数fun。
你可能会认为chunk生成器的实现更简单,如下所示:
def chunk0(my_list, chunk_size):
for i in range(0, len(my_list), chunk_size):
yield my_list[i:i + chunk_size]
实现的问题在于它需要len(my_list),因此限制了我们的输入必须是列表。迭代器可以是懒惰的,因此占用更少的内存,并且可能需要更少的 CPU 来处理。
现在,我们需要修改我们的顶级 MapReduce 函数:
def map_reduce(
pool, my_input, mapper, reducer, chunk_size, callback=None): ①
map_returns = chunked_async_map(pool, mapper, my_input, chunk_size) ①
report_progress(map_returns, 'map', callback)
map_results = []
for ret in map_returns:
map_results.extend(ret.get()) ②
distributor = defaultdict(list)
for key, value in map_results:
distributor[key].append(value)
returns = chunked_async_map(
pool, reducer, distributor.items(), chunk_size) ③
report_progress(returns, 'reduce', callback)
results = []
for ret in returns:
results.extend(ret.get())
return results
① 我们添加chunk_size作为参数。
② 我们使用extend而不是append。
③ 我们添加chunk_size作为参数。
唯一的注意事项是,每次执行的结果不再是单个元素,而是一系列元素。因此,我们必须extend列表而不是append到它。
为了进行更好的速度测试,我们将使用托尔斯泰的《安娜·卡列尼娜》,可在 Project Gutenberg(gutenberg.org/files/1399/1399-0.txt)找到。以下是调用代码:
words = [word ①
for word in map(lambda x: x.strip().rstrip(),
' '.join(open(
'text.txt', 'rt', encoding='utf-8').readlines()).split(' '))
if word != '' ]
chunk_size = int(sys.argv[1]) ②
pool = mp.Pool()
counts = map_reduce(pool, words, emitter, counter, chunk_size, reporter)
pool.close()
pool.join()
for count in sorted(counts, key=lambda x: x[1]): ③
print(count)
① 这将所有文本读入一个列表。
② 块大小是一个命令行参数。
③ 我们按升序打印所有单词计数。
我已经用块大小为 1、10、100、1,000 和 10,000 运行了之前的代码。每种情况的时间都在表 3.1 中展示。
表 3.1 不同块大小的运行时间
| 块大小 | 时间(秒) |
|---|---|
| 1 | 114.2 |
| 10 | 12.3 |
| 100 | 4.3 |
| 1,000 | 3.1 |
| 10,000 | 3.1 |
表 3.1 中的数字不言自明:chunking 可以极大地提高我们框架的性能。chunking 是一个如此重要的概念,我们将在其他章节中再次讨论它。
提示:如果你使用Pool对象的map,chunking 会免费为你实现。你只需添加chunksize参数。对于map_async和imap也是如此。更一般地说,当你使用并行库时,务必检查是否提供了 chunking 功能。在许多情况下,你不需要自己实现它。
共享内存
这里提出的(隐式)消息传递解决方案的替代方案将是使用共享内存服务。像 Python 内置库中可用的那些简单的共享内存模型,因其非常容易出错而臭名昭著,因此我们在此不予讨论。如果你需要内存共享,你可能处于必须使用底层解决方案实现代码的情况。我们将在使用与我们的 Python 代码相关联的底层方法进行处理的后续上下文中讨论共享内存。
3.5 将一切整合:异步多线程和 multiprocessing MapReduce 服务器
我们在这一章中测试了各种方法及其组合,包括并行性、并发性、线程同步和异步编程。现在,我们将选择其中最有效的策略,并在返回到我们开发极快 MapReduce 框架的示例问题时将它们组合起来。让我们回顾一下我们问题的所有参数:所有数据都在内存中,所有工作将在一台计算机上完成,我们的系统将处理来自多个客户端的请求,包括自动化的 AI 机器人。在本节的最后,我们将最终开发一个完整的解决方案,最终得到一个异步 TCP 服务器后面的多进程 MapReduce 实现,该服务器将回答来自多个客户端的查询。
我们已经构建了两个部分:上一节中的分块 MapReduce 实现和我们在 3.1 节中制作的客户端。我们可以直接使用这两个部分。对于这个解决方案中的其他所有内容,请继续阅读。
3.5.1 架构一个完整的高性能解决方案
我们将根据图 3.6 设计架构。与所有客户端交互的前端将是异步的。工作将通过队列发送到另一个线程。那个线程将负责管理一个进程池,该进程池将执行 MapReduce 工作。

图 3.6 MapReduce 服务器的最终架构
我们的前端 TCP 服务器将在一个异步循环中实现。将会有一个第二线程,它只负责管理 MapReduce 多进程池。
两个线程之间的通信将使用来自queue模块的Queue。入口代码将设置异步服务器和负责管理 MapReduce 池的线程(代码位于03-concurrency/sec5-all/server.py):
import asyncio
from queue import Queue, Empty
import multiprocessing as mp
import types
work_queue = Queue()
results_queue = Queue()
results = {}
def worker(): ①
pool = mp.Pool() ②
while True:
job_id, code, data = work_queue.get() ③
func = types.FunctionType(code, globals(), 'mapper_and_reducer')
mapper, reducer = func()
counts = mr.map_reduce(pool, data, mapper, reducer, 100, mr.reporter)
results_queue.put((job_id, counts)) ④
pool.close()
pool.join()
async def main():
server = await asyncio.start_server(accept_requests, '127.0.0.1', 1936)
worker_thread = threading.Thread(target=worker) ⑤
worker_thread.start() ⑥
async with server:
await server.serve_forever()
asyncio.run(main())
① 这个函数在新线程中被调用。
② 池是在工作线程内部创建的。
③ 工作线程等待一些工作要做。
④ 结果被放入响应队列。
⑤ 准备一个线程,将其指向 worker 作为起点。
⑥ 线程被启动。
我们的主要入口点main像以前一样准备异步基础设施,并创建并启动一个将管理 MapReduce 池的线程,该线程在worker函数中实现。worker创建多进程池并处理来自异步服务器的请求。通信是通过 FIFO(先进先出)队列完成的。queue模块确保队列是同步的(即,有锁定机制来确保线程不会导致不一致的状态)。有一个用于接收工作的队列,另一个用于返回结果。worker中的所有函数都是阻塞的,因为在初始化时没有要处理的内容:客户端正由异步部分分发。
注意,当使用多进程而不是线程进行通信时,队列也是一个很好的方式。multiprocessing模块有一个特定的Queue类用于此目的,因为管理进程间通信比多线程版本更困难。一些负担被传递给了用户。只有可以被 pickle 序列化的对象才能通过队列。由于使用了pickle和进程间通信,速度可能会成为一个问题,所以请注意这一点。
让我们从作业提交开始。异步部分现在被编码如下:
async def submit_job(job_id, reader, writer):
writer.write(job_id.to_bytes(4, 'little'))
writer.close()
code_size = int.from_bytes(await reader.read(4), 'little')
my_code = marshal.loads(await reader.read(code_size))
data_size = int.from_bytes(await reader.read(4), 'little')
data = pickle.loads(await reader.read(data_size))
work_queue.put_nowait((job_id, my_code, data)) ①
① 我们将数据写入 work_queue,非阻塞方式。
我们现在的submit_job函数终于开始发挥作用了:它将作业提交到work_queue,这将由运行worker函数的线程获取。我们使用put_nowait来避免在放入结果时阻塞。在我们的情况下,这不应该发生,因为队列是在没有关于大小的限制的情况下初始化的。然而,我们在这里考虑了队列有大小限制的可能性,在这种情况下,你需要在创建队列的调用中考虑并实现这一点。
其余的异步代码如下:
def get_results_queue():
while results_queue.qsize() > 0: ①
try:
job_id, data = results_queue.get_nowait() ②
results[job_id] = data
except Empty: ③
return
async def get_results(reader, writer):
get_results_queue()
job_id = int.from_bytes(await reader.read(4), 'little')
data = pickle.dumps(None)
if job_id in results:
data = pickle.dumps(results[job_id])
del results[job_id]
writer.write(len(data).to_bytes(4, 'little'))
writer.write(data)
async def accept_requests(reader, writer, job_id=[0]):
op = await reader.read(1)
if op[0] == 0:
await submit_job(job_id[0], reader, writer)
job_id[0] += 1
elif op[0] == 1:
await get_results(reader, writer)
① 我们获取队列的大小以查看是否有东西到达。
② 我们从 results_queue 读取响应,非阻塞方式。
③ 我们为空队列做准备。
accept_requests与第一部分完全相同,这里仅为了完整性而展示。
get_results只是在开头新增了一行:调用get_results_queue,该队列负责检查 MapReduce 是否完成,并将结果转移到results字典中。值得注意的是,通过qsize确定的队列大小只是一个近似值,因此我们必须考虑空队列,并在等待消息到达之前避免阻塞。
使用线程和进程进行锁定和低级同步
避免使用大多数低级原语进行锁定!threading和multiprocessing都支持大量标准同步原语,包括锁和信号量,以及一些其他原语。然而,我们在这里的观点是,如果你需要使用这些低级构造,你可能仍然需要用更低级的语言实现代码。因此,我们将在本书的后面部分处理这些机制,那时我们将通过重新实现标准 Python 之外的代码部分来提高性能。
虽然与性能没有直接关系,但一个常用的多进程相关通信原语是Pipe,因为它允许使用标准输入和输出通道与外部应用程序进行通信。
3.5.2 创建服务器的健壮版本
到目前为止,我们很少关注错误和意外输入。在这个阶段,我们将使我们的代码更加健壮。这将显著增加我们的实现规模。我们将确保当服务器关闭时,异步服务器能够优雅地停止:工作线程终止,并且池被正确关闭。
我们的main函数将需要稍微增强一些:
-
我们需要捕获用户中断请求(通常是 Control-C)并在该阶段进行清理。
-
由于我们的异步服务器现在可以被取消,我们也必须捕获这一点。
-
我们必须有一种方式来通知工作线程它需要进行清理。
这里是实现(代码位于03-concurrency/sec5-all/server_robust.py):
import signal
from time import sleep as sync_sleep
def handle_interrupt_signal(server): ①
server.close() ②
while server.is_serving(): ③
sync_sleep(0.1)
def init_worker():
signal.signal(signal.SIGINT, signal.SIG_IGN) ④
async def main():
server = await asyncio.start_server(accept_requests, '127.0.0.1', 1936)
mp_pool = mp.Pool(initializer=init_worker) ⑤
loop = asyncio.get_running_loop()
loop.add_signal_handler(signal.SIGINT, partial(
handle_interrupt_signal, server=server)) ⑥
worker_thread = threading.Thread(target=partial(worker, pool=mp_pool))
worker_thread.start()
async with server:
try:
await server.serve_forever()
except asyncio.exceptions.CancelledError: ⑦
print('Server cancelled')
work_queue.put((-1, -1, -1)) ⑧
worker_thread.join() ⑨
mp_pool.close()
mp_pool.join()
print('Bye Bye!')
① 我们定义了中断的信号处理程序。
② 我们请求服务器停止。
③ 我们等待服务器完成请求。
④ 我们忽略中断信号以确保它不会传播到池中。
⑤ 我们确保多进程池已初始化(即忽略池忽略输入信号)。
⑥ 我们向我们的异步处理器添加信号处理。
⑦ 我们捕获取消操作以通知用户。
⑧ 我们发送-1,这被我们的工作进程解释为登出命令。
⑨ 我们等待所有线程最终完成。
如果你查看main,你会注意到现在为了效率,池在这里创建,每个进程都有一个名为init_worker的初始化函数。这是因为当我们按下 Control-C 时,我们不希望池被中断,因为信号被传播到池中的所有进程。因此,我们使用signal库,并指示每个池进程(signal.SIG_IGN)忽略中断信号(signal.SIGINT)。
我们希望主线程能够捕获中断信号并正确处理它。因为我们希望能够从信号控制异步代码,我们需要使用不同的方式来捕获它:我们调用循环中的add_signal_handler。我们需要传递服务器对象,我们通过部分函数应用来实现这一点。处理程序handle_interrupt_signal取消服务器,并等待它不再提供服务,因为取消可能不会立即发生。
当我们运行异步服务器时,现在我们需要注意取消操作,因此我们捕获该异常。最后,我们需要要求监控线程进行清理。因为信号只传递给主线程,我们需要通过某种通信机制来完成这项工作:我们只需发送带有-1的job_id的work。
在多线程和多进程代码中管理错误和异常
调试多线程和多进程代码可能会非常令人沮丧,即使使用简单的进程间通信模型也是如此。我们只是触及了表面,并且有些不切实际地假设架构表现良好。如果你在代码中进行并发处理,你应该考虑实施良好的日志记录来帮助你捕获问题。在可能的情况下,你应该尽量确保问题与并发无关(例如,尝试在任何问题代码不在池或单独的线程中运行,而是在单个进程的单个线程上运行)。例如,你可以暂时用list(map)替换multiprocessing.Pool.map。
由于工作线程需要显式清理,我们需要实现这一点
def worker(pool):
while True:
job_id, code, data = work_queue.get()
if job_id == -1:
break ①
func = types.FunctionType(code, globals(), 'mapper_and_reducer')
mapper, reducer = func()
counts = mr.map_reduce(pool, data, mapper, reducer, 100, mr.reporter)
results_queue.put((job_id, counts))
print('Worker thread terminating')
① 如果job_id是-1,我们就会跳出循环。
摘要
-
当通信需求和所需处理量都较小的时候,异步编程可以是一种有效地处理许多同时请求的方法;这是与 Web 服务器最常见的一种模式。
-
从某种意义上说,Python 是一种慢速语言,因为它有一个慢速的旗舰实现。这使得运行并行代码的能力变得更加重要。
-
Python 的线程对于性能提升来说并不出色。全局解释器锁(GIL)要求一次只能有一个线程运行。话虽如此,Python 的一些其他实现(例如 IronPython)没有 GIL,并且线程代码可以是并行的。
-
线程在架构设计方面仍然非常有用。虽然它不是提高性能的最佳途径,但不要完全摒弃它。在本书范围之外,还有其他一些观点,其中它仍然相关。
-
使用 Python 多进程,即使只是使用纯 Python 代码,也可以利用计算机上的所有 CPU 核心。
-
通常最好保持计算粒度粗略,过多的通信可能会减慢你的解决方案。当你跨进程通信时,确保通信的开销不是性能瓶颈的主要来源。
-
在开发并行代码时,远离共享内存和低级锁。如果你认为你需要它们,那么请使用更低级的语言实现一个顺序解决方案。在具有复杂通信模式的并行解决方案中进行调试非常困难,因为并行系统中的通信大多是不可预测的。
¹ 尽管我们将要实现的第一种异步通信解决方案对我们用例来说可能很原始,但在其他情况下它非常好。例如,对于大多数 Web 服务器来说,这是完全合理的,正如 NodeJS 所展示的那样。像往常一样,什么是原始的或什么是最优的,取决于你具体的问题。
另一个选择是自行实现一个concurrent.futures执行器,但在这种情况下,你仍然需要理解底层模块,如threading和multiprocessing。
4 高性能 NumPy
本章涵盖
-
从性能角度重新发现 NumPy
-
利用 NumPy 视图提高计算效率和内存节省
-
介绍数组编程作为一种范例
-
配置 NumPy 内部结构以提高效率
NumPy 对于使用 Python 进行数据分析的重要性难以言表。这本书甚至可以被称为高性能 Python 与 NumPy。NumPy 将在你的堆栈中某个地方被发现:你使用 pandas 吗?NumPy。你使用 scikit-learn 吗?NumPy。Dask?NumPy。SciPy?NumPy。Matplotlib?NumPy。TensorFlow?NumPy。如果你在 Python 中进行数据分析,几乎可以肯定你的答案中包括 NumPy。
NumPy 是一个 Python 库,它提供了多维或 N 维数组对象,如矩阵(二维),以及操作这些数组的功能。其实现效率极高,核心是用 Fortran 和 C 编写的。许多数据分析问题可以通过 N 维数组在核心进行建模;这就是为什么 NumPy 在这个领域无处不在。
考虑到 NumPy 在 Python 数据分析中的重要性和广泛使用,一些与之相关的主题将在其他章节中讨论,特别是:
-
在第五章中使用 Cython 对函数进行向量化
-
第六章中数组的内部内存组织
-
在第六章使用 NumExpr 进行快速数值表达式评估
-
在第八章和第十章中使用大于内存的数组
-
在第八章和第十章中高效存储数组
-
在第九章中使用 GPU 计算进行数组处理
在本章中,我们将从 NumPy 的复习开始。虽然本书假设你已经接触过 NumPy,但即使你在使用这个库,也可能是在间接使用。例如,你可能在使用 pandas 或 Matplotlib,但自己进行的直接 NumPy 编程很少。这次复习侧重于从性能角度的 NumPy 概念。如果你觉得你需要更全面的介绍,网上有无数免费示例。官方的示例非常好:numpy.org/devdocs/user/quickstart.html。NumPy 网站还提供了一个精选的学习资源列表,位于numpy.org/learn/。
在入门篇之后,我们将探讨数组编程作为一种编程模型,其中一次对多个原子值执行操作。这种方法在性能方面非常有价值,同时也是编写代码的一种优雅方法。在第章的最后一部分,我们将讨论 NumPy 内部架构和依赖对其性能的影响——我们将学习如何对其进行微调。
4.1 从性能角度理解 NumPy
在本节以及本章的其余部分,我们将通过使用一个实际例子来学习关键概念和技术:简单图像处理例程的开发。从最初接触来看,图像是二维数组(即矩阵),因此很容易进行 NumPy 操作。因此,我们假设我们正在开发一款新的图像处理软件。再次强调,虽然本节类似于 NumPy 入门教程,但它强调 NumPy 对性能的影响。所以即使你已经了解了 NumPy 的基础知识,你在这里也可能学到一些新的东西。
4.1.1 数组的副本与视图
我们的首要任务是读取一个图像文件并对它执行几个旋转操作。我们将直接使用 NumPy 来旋转图像,而不是使用我们读取图像时使用的 Pillow 图像库提供的函数。我们将学习如何使用 NumPy 的内存复制和视图创建来完成这项工作,以便我们可以比较它们的效率。视图基于共享相同内存的数组,但以不同的方式解释它,因此它们通常更高效,尽管它们并不总是可以使用,正如我们将看到的。
让我们从加载一个熟悉的图像开始,即来自 Manning 出版物的标志,然后从中获取一个 NumPy 数组。请注意,一些操作,如旋转图像,可以被视为不过是对数组进行不同方式的解释:列变成行,行变成列。这正是 NumPy 视图的作用:对相同原始数据的另一种解释。我们将把这个过程分解成非常小的步骤,因为我们需要仔细考虑和理解每一行(代码可以在 04-numpy/sec1-basics/image_processing.py 中找到):
import sys
import numpy as np
from PIL import Image
image = Image.open("../manning-logo.png").convert("L") ①
print("Image size:", image.size)
width, height = image.size
image_arr = np.array(image)
print("Array shape, array type:", image_arr.shape, image_arr.dtype)
print("Array size * item size: ", image_arr.nbytes)
print("Array nbytes:", image_arr.nbytes)
print("sys.getsizeof:", sys.getsizeof(image_arr))
① convert("L") 操作将图像转换为灰度。
我们使用 Pillow 库加载 Manning 标志(图 4.1),将图像转换为灰度。每个像素将由一个无符号字节表示。图像的大小为 182 × 45。输出如下:
Image size: (182, 45)
Array shape, array type: (45, 182) uint8
Array size * item size: 8190
Array nbytes: 8190
sys.getsizeof: 8302
然后,我们通过使用函数 np.array 获取表示图像数据的数组。NumPy 能够与 Pillow 图像一起工作,并不是因为 NumPy 知道图像是什么,而是因为图像对象实现了 __array__interface__,这是 NumPy 用于构建数组表示的方法。
我们随后打印出数组的shape,其大小为 45 × 182。请注意,对于图像而言——宽度在前,高度在后——这一惯例与 NumPy 所遵循的惯例相反,NumPy 的惯例源自数学——行数在前,列数在后。这种细微差别实际上比看起来更重要,我们将在下一小节讨论不同数据视图时开始看到这一点。但当我们讨论第六章中的内存表示时,问题的严重性将变得尤为明显。
然后我们打印数组的类型,它为uint8(即 8 位无符号整数,或一个字节)。uint8足以保存足够的信息以用于灰度图像。随后,我们以两种不同的方式打印数组占用的内存:(1)通过将数组中的项目数(45 * 182 = 8190)乘以每个元素的大小(在我们的例子中是 1 字节),或者(2)我们可以直接使用nbytes字段。
最后,我们使用第二章中介绍的getsizeof函数来获取数组对象的大小。这包括原始数组(8190)加上 Python 和 NumPy 的开销和元数据(总计 8302)。
我们已经看到了几种确定我们数组大小的方法——有和没有对象开销。现在我们将翻转图像上下颠倒。图像翻转可以通过复制数组或简单地改变我们对原始数据的解释来完成,因此它是一个介绍视图的好例子。在得到两个翻转之后,我们将遮挡原始图像的一半:
flipped_from_view = np.flipud(image_arr) ①
flipped_from_copy = np.flipud(image_arr).copy() ②
image_arr[:, :width//2] = 0
removed = Image.fromarray(image_arr, "L")
image.save("image.png")
removed.save("removed.png")
flipped_from_view_image = Image.fromarray(flipped_from_view, "L")
flipped_from_view_image.save("flipped_view.png")
flipped_from_copy_image = Image.fromarray(flipped_from_copy, "L")
flipped_from_copy_image.save("flipped_copy.png")
① 在垂直轴上翻转图像。这是通过视图完成的。
② 从翻转的图像中创建一个副本
图 4.1 显示了四个图像的组合。flipped_from_view是从image_arr的视图中创建的。这意味着当你用image_arr[:, :width//2] = 0更改image_arr时,flipped_from_view也会被更改,因为原始数组是共享的。

图 4.1 四个图像:原始 Manning 标志(image.png)、左侧被遮挡(removed.png)、从副本中垂直翻转(flipped_copy.png)和从视图翻转(flipped_view.png)
视图共享底层原始数组;副本不共享。因此,flipped_from_copy的图像不会受到image_arr上变化的影响。作为旁注,Image.fromarray创建原始数组的副本。这就是为什么 image.png 和 removed.png 不同的原因。如果它提供了一个视图,那么这些图像将是相等的。
支持图像的底层数据结构在图 4.2 中显示。请注意,原始图像已被销毁,image_arr包含被遮挡的数组。

图 4.2 许多 NumPy 操作可以生成新的对象或视图,这些对象或视图与原始对象共享原始数据。有时这可能不可能或不是所希望的,数据会被复制。
有时可能不希望进行视图共享。例如,你可能不想遮挡原始图像。在这种情况下,你需要创建一个副本,以便原始对象保持完整。也有时候不可能以视图的形式获取数据;我们将在本节稍后看到一个例子。
可以看到数组是否基于另一个数组:
print(flipped_from_copy.base, flipped_from_view.base)
print(flipped_from_view.base is image_arr) ①
print(flipped_from_view.base == image_arr) ②
① 检查是否是同一个对象
② 检查两个数组的所有值是否相等
输出是:
None [[ 0 .... <long array>]]
True
[[ True True True ... True True True]
flipped_from_copy.base 将会是 None,因为它是一个全新的副本。flipped_from_view.base 将会有一个值——一个矩阵。我们可以通过使用 is 来检查它是否与 image_arr 是同一个对象。请注意:如果你使用 ==,你将得到数组所有元素逐个比较的结果。对于 is,你将得到 True;对于 ==,你将得到一个包含 True 的数组。
提示:任何视图的基础不是视图所派生的对象,而是链中的第一个对象。例如,如果你有 v2 = v1[:-1] 和 v1 = arr[::-1],v2.base is arr 和 v1.base is arr 都是 True,但 v2.base is v1 不是。
正如我们所看到的,Numpy 对象有一组元数据,如 shape 或 dtype。原始数组数据位于一个名为 data 的字段中,它指向一个 Python 内置类型,一个 memoryview。memoryview 类提供了许多基本功能,从 Python 端处理具有相同类型的分配内存的块。例如,它实现了索引、切片和内存共享。
有可能查询 NumPy 数组是否共享内存,这比仅仅是视图更通用,因为 memoryviews 可以同时属于其他对象,以及,可能地,其他数组,而无需使用视图来创建它们:
print(np.shares_memory(image_arr, flipped_from_copy),
np.shares_memory(image_arr, flipped_from_view))
np.shares_memory 对于 image_arr 和 flipped_from_copy 将是 False,因为我们正在处理一个副本。对于 image_arr 和 flipped_from_view,它将是 True。一般来说,如果 base 是共享的,那么内存也是共享的,但反之不一定成立:内存可以共享,而不一定有相同的 base。
提示:确定两个数组是否共享内存不是一个简单的问题。在复杂场景中,可能需要花费很多时间,这使得它变得不切实际。在这些情况下,有一个更快的函数 may_share_memory,可以提供一个猜测,即两个数组是否共享内存。
吸收要点
这里要记住的关键点是:视图通常比副本要高效得多。这有两个主要原因。首先,当你复制一个数组时,你需要支付复制原始数据的计算成本,而使用视图时,只需重新创建视图信息。可能更重要的是,当你复制一个数组时,你需要加倍所需的内存,这在处理大型内存数组时可能不可行。无论如何,不要忘记,如果你更改任何视图,所有共享相同内存的对象都将受到影响。
让我们运行一个简短的示例来了解这种效率差异的影响。我们将创建一个具有可变大小的数组,并测量创建视图和副本所需的时间(表 4.1):
import sys
import timeit ①
import numpy as np
for size in [
1, 10, 100, 1000, 10000, 100000, 200000, 400000, 800000, 1000000]:
print(size)
my_array = np.arange(size, dtype=np.uint16)
print(sys.getsizeof(my_array))
print(my_array.data.nbytes)
view_time = timeit.timeit(
"my_array.view()",
f"import numpy; my_array = numpy.arange({size})")
print(view_time)
copy_time = timeit.timeit(
"my_array.copy()",
f"import numpy; my_array = numpy.arange({size})")
print(copy_time)
copy_gc_time = timeit.timeit(
"my_array.copy()",
f"import numpy;
import gc; gc.enable(); my_array = numpy.arange({size})")
print(copy_gc_time)
print()
① 如果你使用 iPython,请记住 %timeit 魔法命令是可用的。
表 4.1 比较复制和视图之间的时间和内存分配
| 数组大小 | 数组内存(b) | 视图时间 | 复制时间 |
|---|---|---|---|
| 1 | 2 | 0.171 | 0.281 |
| 10 | 20 | 0.137 | 0.259 |
| 100 | 200 | 0.139 | 0.286 |
| 1000 | 2000 | 0.162 | 0.502 |
| 10000 | 20000 | 0.142 | 2.275 |
| 100000 | 200000 | 0.138 | 31.257 |
| 200000 | 400000 | 0.152 | 67.005 |
| 400000 | 800000 | 0.144 | 354.287 |
| 800000 | 1600000 | 0.177 | 547.843 |
| 1000000 | 2000000 | 0.142 | 729.966 |
复制的内存负担很容易理解:每次复制时,内存需求量都会加倍。¹ 复制的计算负担则不那么明显:从直观的角度来看,你可以假设复制数组所需的时间与其大小成正比。如果你查看表格,有时线性关系会中断。这一点将在第六章中解释。
吸取的经验
正如我们所见,视图可以节省你计算时间和内存,因此你应该尽可能使用视图。视图的最大缺点是它们并不总是可行的;有时没有其他选择,只能复制。尽管如此,NumPy 的视图机制既强大又灵活,因此可以在许多情况下使用。让我们更深入地了解视图机制,以便我们可以看到如何在各种情况下使用它。
4.1.2 理解 NumPy 的视图机制
为了能够充分利用视图进行高效处理,我们首先必须了解它们是如何工作的。视图的灵活性主要来自两块元数据:第一块,我们刚刚看到的是形状。第二块是步长;我们将在稍后给出步长的更精确定义。首先,让我们看看几个说明不同形状和步长值的例子。
让我们首先分配一个包含 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 的数组,它是 4 字节无符号整数,并查看它的步长和形状。我们稍后将这个简单的线性数组重塑成二维结构:
import numpy as np
linear = np.arange(10, dtype=np.uint32)
数组将在内存中连续分配。这意味着 0 将被 1 接着,1 被 2 接着,2 被 3,以此类推。在这个阶段,这可能会显得很明显,但正如我们将看到的,它远非如此。
警告 NumPy 中展示的连续分配是一个简化的例子——尽管对于这个案例是正确的——出于教学目的。从现在开始,我们将看到非连续数组的例子。但我们将在第六章中推迟关于分配的重要细节。作为一个建议,如果你第一次看到有关数组连续性的讨论,请先不要立即查看那一章。首先确保基本概念正确。
让我们获取一个 2 × 5 矩阵的视图,该视图与相同的数据。之后,让我们创建另一个视图,一个 5 × 2 矩阵,它是通过 2 × 5 矩阵的转置得到的。我们想了解原始数组与新矩阵之间的关系:
m2x5 = linear.reshape((2, 5))
print(np.shares_memory(linear, m2x5))
print("2x5", m2x5.shape)
print("2x5 corners", m2x5[0, 0], m2x5[0, 4],
m2x5[2, 0], m2x5[2, 4])
m5x2 = m2x5.T
print(np.shares_memory(m2x5, m5x2))
print("5x2", m5x2.shape)
print("5x2 corners", m5x2[0, 0], m5x2[0, 1],
m5x2[4, 0], m5x2[4, 1])
我们首先要确保矩阵共享内存(即,它们是视图,而不是副本)。np.shares memory 显示 linear、m5x2 和 m2x5 都共享内存,这在图 4.3 中有所描述。我们还打印了新声明矩阵的角点,以便它们的边界清晰。

图 4.3 这三个数组,无论维度如何,都共享相同的内存。
问题是,NumPy 如何知道如何从相同的内存中找到一个元素?形状可能足以区分一维数组和矩阵之一。但是,如何区分来自相同内存的两个不同形状的矩阵呢?这就是步长的用途:
print("linear", linear.strides)
print("2x5 strides", m2x5.strides)
print("5x2 strides", m5x2.strides)
结果将是 4, (4, 20) 和 (20, 4)。
那么,步长就是你需要跳过多少字节才能到达一个维度上的下一个元素。让我们通过将其与前面三个例子联系起来,来使这个概念更加清晰。
linear 变量的步长是 4:这意味着只有一个维度,并且要从当前元素跳到下一个元素,你需要跳过 4 字节,这是我们所选择的数据类型的大小,np.uint32。在图 4.4 中,每次你在线性数组中前进一个位置,你就在内存中前进 4 字节以获取值。索引 i 的内存位置然后是 stride * i 或 4 * i。

图 4.4 在一维数组中跳转到相邻元素。这有一个单一的步长:数字是我们数据类型的大小。
对于二维数组,事情会变得稍微复杂一些。对于我们的 2 × 5 数组,如果你想向前跳过一列,那么元素是相邻的,所以需要跳过一个元素——当前这个元素——(如图 4.5)。鉴于我们的元素大小是 4 字节,我们得到的步长是 4。但是,如果你想向前跳过一行,那么你必须跳过当前元素加上额外的 4(因为每一行有五个元素)。所以 5 个元素乘以元素大小 4 等于 20。在这种对内存的解释中,你可以通过函数 strides[0]*i + strides[1]*j(20*i + 4*j)来获取元素 i,j。

图 4.5 在二维数组中跳转到 2 × 5 的相邻元素。我们需要两个步长:每个维度一个。
m5x2 的步长是 20, 4:这意味着有两个维度,并且要到达下一行,你需要跳过 20 字节(五个列,每个列 4 字节)。下一个列是下一个值,距离 4 字节。
图 4.6 应该会使这一点更加清晰。(这将与数组的内部表示有显著差异,我们将在第六章中考虑这一点。)

图 4.6 在二维数组中跳转到 2 × 5 的相邻元素。行步长将取决于列数。
许多 NumPy 操作都是视图转换。例如,反转数组也可以表示为一个视图:
back = linear[::-1]
print("back", back.shape, back.strides, back[0], back[-1])
注意,现在的步长是 -4。NumPy 可以创建可以向后移动的视图(如图 4.7)。

图 4.7 反转一维数组将反转步长的信号。
可以用类似的方法处理二维数组。如果它们被反转,m2x5 和 m5x2 的步长会是什么?
将变换渲染为视图的能力并不总是可行的,因为这取决于在现有视图和新视图之间建立线性关系的能力。例如,让我们考虑一个 20 × 5 的矩阵,然后从三个行中选择一行,从两个列中选择一列,创建一个 7 × 3 的矩阵。最后,让我们得到这个 7 × 3 矩阵的一维版本:
a100 = np.arange(100, dtype=np.uint8).reshape(20, 5)
a100_step_3_2 = a100[::3, ::2]
print(a100_step_3_2.shape, a100_step_3_2.strides)
print(np.shares_memory(a100, a100_step_3_2))
a100_step_3_2_linear = a100_step_3_2.reshape(21)
print(np.shares_memory(a100_step_3_2, a100_step_3_2_linear))
使用步长 15, 2,7 × 3 矩阵仍然可以渲染为视图。将其转换为单维表示的看似更容易的过程无法作为视图完成,因此会创建一个副本。
NumPy 所称的花式索引始终以副本的形式渲染。以下是一个从 2 × 5 矩阵中获取五个交替值(上下)的例子:
import numpy as np
m5x2 = np.arange(10).reshape(2, 5)
my_rows = [0, 1, 0, 1, 0]
my_cols = [0, 1, 2, 3, 4]
alternate = m5x2[my_rows, my_cols]
print(m5x2)
print(alternate)
print(np.shares_memory(m5x2, alternate))
为了刷新您对花式索引的记忆,它需要一个索引列表,每个数组维度一个索引,并返回对应于列表上位置的元素。让我们通过表 4.2 中的矩阵来观察这一过程。
表 4.2 原始矩阵
| 0 | 1 | 2 | 3 | 4 |
|---|---|---|---|---|
| 5 | 6 | 7 | 8 | 9 |
对于行列表 [0, 1, 0, 1, 0] 和列列表 [0, 1, 2, 3, 4],我们将有一个 alternate 列表 [0 6 2 8 4]。np.shares_memory 将为 False。
要点
首先,视图可以比复制更有效率。其次,视图可以像我们在本节中学到的那样使用。让我们将这些用途应用到我们的图像处理代码中,这样我们就可以看到使用视图的一些优点以及一些陷阱。
4.1.3 利用视图提高效率
我们现在将通过仅使用视图的操作来转换我们的示例图像。这是在大数据背景下记住我们为什么要这样做的好时机:对于非常大的数组,内存可能有限,复制大数组也会产生时间成本。
让我们先垂直和水平翻转图像:
import numpy as np
from PIL import Image
image = Image.open("../manning-logo.png").convert("L")
width, height = image.size
image_arr = np.array(image)
print("original array", image_arr.shape, image_arr.strides, image_arr.dtype)
image.save("view_initial.png")
invert_rows_arr = image_arr[::-1, :] ①
print("invert rows", invert_rows_arr.shape, invert_rows_arr.strides,
np.shares_memory(invert_rows_arr, image_arr))
Image.fromarray(invert_rows_arr).save("invert_x.png")
invert_cols_arr = image_arr[:, ::-1]
print("invert columns", invert_cols_arr.shape, invert_cols_arr.strides,
np.shares_memory(invert_cols_arr, image_arr))
Image.fromarray(invert_cols_arr).save("invert_y.png")
① 我们可以使用数组 reverse 在一个维度上翻转图像。记住,在前一节中我们使用了 Image.flipud 来水平翻转图像。
在这个阶段,代码应该很容易阅读:对于 182 × 24 的曼宁标志,原始数组具有 45, 182 的形状,步长为 182, 1。我们正在处理 1 字节的无符号整数数据类型,因此每个像素一个字节。当我们水平翻转图像(即通过行翻转)时,唯一改变的是第二个步长,从 1 变为 -1。相反,当我们垂直翻转图像时,第一个步长从 182 变为 -182。
现在我们尝试旋转图像。我们将使用三种方法:重塑(reshape)、转置(.T)和 90 度旋转(.rot90)。我们这样做是为了检查三种不同策略的输出以及它们在内部是如何表示的:
view_swap_arr = image_arr.reshape(image_arr.shape[1], image_arr.shape[0]) ①
print("view_swap", view_swap_arr.shape, view_swap_arr.strides)
Image.fromarray(view_swap_arr, "L").save("view_swap.png")
trans_arr = image_arr.T
print("transpose", trans_arr.shape, trans_arr.strides)
Image.fromarray(trans_arr, "L").save("transpose.png")
rot_arr = np.rot90(image_arr)
print("rot", rot_arr.shape, rot_arr.strides)
Image.fromarray(rot_arr, "L").save("rot90.png")
① 可以使用 swapaxes 方法完成相同的事情。
要检查我们是否有视图或副本,我们正在打印memoryview对象,这将给我们一个反映内存位置的十六进制数。所有的内存位置都将相同,因此这些操作创建了视图。
现在,让我们包括一个切片——仅从标志中提取的单词Manning,在这种情况下,它也是一个视图:
slice_arr = image_arr[15:, 77:]
print("slice_arr", slice_arr.shape, slice_arr.strides,
np.shares_memory(slice_arr, image_arr))
Image.fromarray(slice_arr, "L").save("slice.png")
形状和步长如表 4.3 所示。
表 4.3 轴交换、转置和旋转的形状和步长
| 操作 | 形状 | 步长 |
|---|---|---|
swapaxes |
182 45 | 45 1 |
transpose |
182 45 | 1 182 |
rot90 |
182 45 | -1 182 |
slice |
30 105 | 182 1 |
你能预测图像将如何显示吗,特别是轴交换、旋转和转置?结果如图 4.8 所示。

图 4.8 使用原始 Manning 标志的直接数组操作产生的图像
小贴士:为了便于解释,我们只使用了具有视图的一维和二维数组,但所有这些 NumPy 机制都可以用于高阶数组。
警告:直接更改步长和形状值是可能的。模块numpy.lib.stride_tricks有一个名为as_strided的函数,它接受一个现有数组、所需的形状和步长值,并返回结果。
事实是,这样的函数位于名为stride_tricks的模块中,这应该引起一些注意。此外,该函数的帮助页面以以下文本开头:“此函数必须极其小心地使用。”
问题在于你可以传递任意值(即,错误的值)给函数。这是少数几个你可以在 Python 中造成内存损坏的情况之一,因为错误的值将被用来访问错误的内存位置。您的程序可能会崩溃,甚至可能暴露敏感信息。
吸取的经验
视图可以显著减少计算和内存成本。虽然它们不能在所有情况下使用,但 NumPy 的视图机制非常灵活,可以在许多可以用 NumPy 机制重新解释数据的情况中使用:形状和步长。
关于视图和副本的性能方面,我们已经说得足够多了。现在,让我们看看如何更有效地编程 NumPy,正如您将看到的,这意味着使用更多惯用的 NumPy。
4.2 使用数组编程
数组编程是一种编程模型,其中操作一次应用于数组的所有值。这种模型用于科学和高性能编程。它有两个主要目的:使代码编写更具声明性和可读性,并使代码运行更高效。结果证明,NumPy可以主要使用数组编程技术来完成。
我们应该在数组库和数组编程之间做出区分。一个简单的例子将使这种区别变得清晰。(这个例子只是为了教学目的,我不期望任何人使用非数组解决方案。)
注意:虽然没有人会在这个简单的例子中使用非数组解决方案,但看到使用非数组语法的 NumPy 代码是很常见的。在这个例子中,数组方法的好处是显而易见的,但完全有可能,并且是不希望的,以第一解决方案中展示的较低效的方式使用 NumPy。在使用 NumPy 和许多其他类似库(例如 pandas)时,最重要的技能之一是识别何时数组代码可以替换原始 Python 代码。
假设你想计算两个向量的和。从用户的角度来看,这里是一个常见的、非数组编程版本(代码位于04-numpy/sec3-vectorize/array_and_broadcasting.py):
import numpy as np
def sum_arrays(a, b): # Assumes both are the same size
my_sum = np.empty(a.size, dtype=a.dtype)
for i, (a1, b1) in enumerate(zip(np.nditer(a), np.nditer(b))):
my_sum[i] = a1 + b1
return my_sum.reshape(a.shape)
我们明确编写了代码来遍历所有元素。这种实现充满了问题,我们将在稍后讨论。现在,让我们仅将其与数组版本进行比较:
a + b
显而易见的优势是,数组版本更加简洁和声明性,仅此一点就足以使数组习语更可取。但从我们高性能的角度来看,还有一些更相关的事情正在发生。
数组示例可能也快得多。第一个示例是 Python 代码,具有已经讨论过的原生 Python 代码固有的速度限制。重载+运算符中的数组示例将尽可能快地实现。这意味着非 Python 实现通常在 C 或 Fortran 中,可能利用了向量化 CPU 操作,甚至可能运行在 GPU 上。
在这个阶段,我们不会过分关注+运算符的实现方式。我们只需要掌握使用数组语法的潜在性能优势。
4.2.1 总结
高效的代码可以是干净的代码。关于高效代码必须始终是脏的这种神话,我希望在这里被搁置。数组实现既更高效又更干净。
现在我已经提倡使用数组编程,让我们简要介绍另一个 NumPy 概念:广播。广播很重要,因为它允许我们更好地利用数组编程,并编写更清晰、更简洁的代码。
4.2.2 NumPy 中的广播
为了理解广播,让我们首先深入探讨我们纯 Python 实现的两个数组相加。那段代码在这里重复出现:
def sum_arrays(a, b):
my_sum = np.empty(a.size, dtype=a.dtype) ①
for i, (a1, b1) in enumerate(zip(np.nditer(a), np.nditer(b))):
my_sum[i] = a1 + b1
return my_sum
① 从第一个数组假设形状和类型
在我们讨论这段代码的主要问题之前,值得注意的一个优点是:它使用了np.empty。这个函数为数组分配内存,但不会用任何值初始化它。当你创建一个大型数组时,这可以节省时间。你必须确保稍后初始化数组,就像我们在这里做的那样,否则你将得到垃圾数据。因此,这不是一个通用的解决方案,但它可以在许多情况下提高性能。
之前代码的主要问题是没有输入检查,因此它接受任何形状的数组,包括可能不兼容的数组。看似简单的解决方案是强制 a 和 b 必须具有相同的形状,并相应地重塑输出。这会起作用,但使用起来会很笨拙。例如,如果你要添加一个包含 100,000 个元素的数组,并希望将所有元素增加 1,你必须这样做:
array_100000 = np.arange(100000)
sum_arrays(array_100000, np.ones(array_100000.shape))
这意味着只是为了将我们的原始数组增加 1,就需要分配和初始化一个大数组。这不仅难看,而且在处理大数据时可能过于昂贵。
如果我们能写 sum_arrays(array_100000, 1) 会怎么样?实际上,在 NumPy 中我们可以这样做!以下代码是完全有效的:
array_100000 = np.arange(100000)
array_100000 += 1 ①
① 我们稍后会看到,这并不等同于编写 array_100000 = array_100000 + 1。
array_100000 是一个有 100,000 个位置的数组,而 1 是一个原子值(即它们有不同的类型)。这是一个广播的例子,广播是一组合理的规则,使得 NumPy 可以将运算符应用于不同维度的数组。
下面是一些应用广播规则的实用示例。我们还比较了它们与可以与广播运算符混淆的函数。让我们从一维数组开始:
a = np.array([0, 20, 21, 9], dtype=np.uint8)
b = np.array([10, 2, 25, 5], dtype=np.uint8)
print("add one", a + 1)
print("multiply by two", a * 2)
print("add a vector", a + [10, 2, 25, 5])
print("multiply by a vector", a * [10, 2, 25, 5])
print("dot (inner) product", a.dot(b))
print("matmul (inner product)", a @ b)
+ 运算符在第一次打印时将数组中所有元素加 1。在第二次打印时,它将逐个元素相加:[0 20 21 9] + [10, 2, 25, 5] = [10 22 46 14]。
注意,* 运算符与数组一起工作的方式类似:第一种情况 * 2 将所有值乘以 2。第二种情况将逐个元素相乘:[0 20 21 9] + [10, 2, 25, 5] = [0 40 525 45]。
内积是通过 np.dot 和 @ (np.matmul) 运算符实现的(关于 @ 的更多信息稍后介绍)。
现在让我们看看一些使用广播的矩阵示例:
x = np.array([[0, 20], [250, 500], [1, 2]],
dtype=np.uint8)
y = np.array([[1, 10], [25, 5]], dtype=np.uint8)
print("add a matrix to itself", x + x)
print("add a matrix with column size", x + [1, 2])
# print(x + [-1, -2, -3])
print("add a matrix with row size", (x.T + [-1, -2, -3]).T)
print("inner product", np.inner(a, b))
print("matrix multiplication", x.dot(y))
# print(x.T.dot(y))
print("matmul", x @ y)
x[:, 0] = 0
print("assignement broadcasting", x)
将矩阵加到自身会产生预期的结果,因为所有值都加倍。你还可以将一维数组加到矩阵上:它应该有列的数量,并且将按行应用。你不能做相反的操作(即,添加一个行数为数的数组)。但有一种快速的方法(即,不涉及过多的复制或非数组编程)可以实现这一点,那就是转置矩阵——记住,这是一个非常快的视图操作——然后转置结果。
提示 NumPy 运算符并不直接映射到标准的数学期望。例如,* 并不是数学上的矩阵乘法;那将是 np.dot。
吸取的经验教训
关于广播还有很多东西要学习,但在这里,我们从性能的角度来看,我们已经涵盖了其基本用法。从这个角度来看,关于广播最重要的记住的事情是,它的实现通常是向量化的,正如我们在上一节中讨论的,向量化的实现可以快几个数量级。我们现在可以回到我们的图像处理代码,希望有了足够的动力来考虑基于数组的方案。
4.2.3 应用数组编程
现在我们将这些基于数组的编程技术应用到我们的图像处理程序中。在采用这种更高效的方法的同时,我们还需要学会如何处理一些潜在的问题。不要让这些问题让你对应用这些方法感到气馁:基于数组的编程比基于for循环的典型命令式编程更高效、更优雅。
我们现在将尝试使我们的图像变亮。记住,我们正在使用每个像素 1 字节——在转换图像时选择L选项——其值在 0 到 255 之间变化。我们将通过两种不同的方法来增加亮度:给每个像素加 5,以及将值加倍,如下所示:
import numpy as np
from PIL import Image
image = Image.open("../manning-logo.png").convert("L")
width, height = image.size
image_arr = np.array(image)
brighter_arr = image_arr + 5
Image.fromarray(brighter_arr).save("brighter.png")
brighter2_arr = image_arr * 2
Image.fromarray(brighter2_arr).save("brighter2.png")
在一个理想的世界里,我们的问题将会得到解决。但是,如果你看图 4.9 的结果,有些东西需要纠正。

图 4.9 原始图像变亮没有达到预期的效果。
加倍数字看起来是正确的,但我们增加 5 的那个数字有问题。记住,图像由 0 到 255 范围的 1 字节无符号整数表示。因为 255 加上 5,你最终得到溢出:260 变成 4。因此,颜色变得相当黑。
存在一个更隐蔽的问题——因为它没有被注意到——与加倍数字有关。这个数字看起来是正确的,但同样存在相同的问题。为了理解发生了什么,让我们打印原始图像和brighter2上的最大值:
print(image_arr.max(), image_arr.dtype)
print(brighter2_arr.max(), brighter2_arr.dtype)
原始图像的最大值是 255(全白),而加倍图像的最大值是 254,这是怎么回事?记住,在二进制中,255 是 0x11111111(8 位,全部为 1),2 * 255 是 510,0x1111111110(8 位更高,全部为 1,最后一位为 0),但溢出切掉了右边的位。我们最终得到 0x111111110(即,254)。图像看起来是正确的,但并不是。
警告:在选择数据类型时,请务必非常小心。当没有内存或速度方面的考虑时,你可以放宽要求,选择一个较宽的类型。但如果你需要尽可能节省内存,确保不要选择一个过小的类型。除非你有足够的内存和时间,否则除了对边缘情况有良好的测试覆盖外,没有一般性的规则。
最简单的方法,尽管不是更节省内存的解决方案,是使用更大的数据类型。例如:
brighter3_arr = image_arr.astype(np.uint16)
brighter3_arr = brighter3_arr * 2
print(brighter3_arr.max(), brighter3_arr.dtype)
brighter3_arr = np.minimum(brighter3_arr, 255) ①
print(brighter3_arr.max(), brighter3_arr.dtype)
brighter3_arr = brighter3_arr.astype(np.uint8)
print(brighter3_arr.max(), brighter3_arr.dtype)
Image.fromarray(brighter3_arr).save("brighter3.png")
① 不要混淆最小值和 min。min 返回数组的最小值;最小值通过广播选择最小的值。
我们可以假设一个值不可能比最大白色还要白。因此,我们将所有大于 255 的值转换为 255。我们首先将原始数组转换为 2 字节的无符号整数;然后数组被翻倍。当我们检查最大值时,它是正确的值,510。然后我们使用np.minimum从 255 和数组的每个元素中选择最小值,这是一个广播创建副本的典型例子。这使得最大值可以用单个字节表示。最后,我们将它重新铸造成 8 位无符号整数,此时最大值现在是正确的。在下一节中,我们将探讨一种更节省内存的方法来完成这项工作。²
在某些铸造过程中,有更有效的方法,例如让乘法自动返回np.uint16,我们将在下一小节中看到。
最后,让我们再次看看我们翻倍值的方式。之前,我们使用:
brighter3_arr = brighter3_arr * 2
这种方言创建了一个中间数组,该数组将用于存储乘法的结果。然后,变量brighter3_arr被替换为新的变量。但是,在短时间内——实际上,直到运行垃圾回收——这两个数组都将存在于内存中。这在内存和时间上对非常大的数组都可能是问题。但事实上,我们可以做得更好:
brighter3_arr *= 2
在这种情况下,NumPy 理解数组将要被修改,并执行原地乘法。这意味着不需要加倍内存,也不需要在初始化和垃圾回收上浪费时间。最终结果是一样的,但从性能角度来看,这两种方言完全不同。x = x * 2使用了双倍内存和更多时间。x *= 2更高效,应该尽可能使用。
吸取的经验教训
使用像 NumPy 或 pandas 这样的库时,问题的一部分是倾向于回到非数组编程方言,这些方言效率不高。这并非有意为之,但我们倾向于保持“正常”的方言,除非我们有意改变。让我们继续深入研究数组编程,因为它在处理性能问题时提供了更多的好处,而且因为我们想坚持这种范式。
4.2.4 培养矢量化思维
矢量化纯Python 代码并不比非矢量化代码更高效。np.vectorize的文档对此非常明确。
但在本书的许多章节中,我们将以矢量化术语进行思考:当我们讨论 Cython、pandas、CPU 矢量化以及,达到极致的 GPU 处理时。如果你在遇到这些概念之前先在纯 Python 和 NumPy 中接触过这些概念,那么这可能会使你在这些章节中的学习曲线更加容易。换句话说,本节中强调的矢量化思维在以后将会非常有用。
要理解矢量化以及 NumPy 的通用函数,我们将回到一个熟悉的例子。在前一小节中,我们看到了 brighter_image = image * 2 可能会溢出。一个不会溢出的函数会是什么样子?使用矢量化函数就很简单:
def double_wo_overflow(v):
return min(2 * v, 255)
现在,我们将矢量化这个函数并将其应用到我们的图像上:
import numpy as np
from PIL import image
vec_double_wo_overflow = np.vectorize(
double_wo_overflow, otypes=[np.uint8]) ①
brighter_arr = vec_double_wo_overflow(image_arr)
print(brighter_arr.max(), brighter_arr.dtype)
Image.fromarray(brighter_arr).save("vec_brighter.png")
① 我们需要指定输出类型。
np.vectorize 接收一个典型(即非矢量化)的函数,在这种情况下,允许它应用于每个标量。然后我们将新的 vec_double_wo_overflow 应用到我们的图像上。它将逐元素计算。
小贴士 虽然 np.vectorize 实质上是一个 for 循环,但在理论上,它可以并行调用以使用机器的所有核心,从而可能加快速度。掌握这种编程模式和其并行化的潜力,将使你对 GPU 的工作方式有更深刻的理解。如果你在这里学习了这个概念,那么在讨论第十章中的 GPU 优化时,你会更容易理解。
为了强调这段代码并不更快,我们的函数的 %timeit 测试结果显示在毫秒范围内。对于 *,我们处于 微秒 范围。
np.vectorize 比这要复杂得多,因为它允许支持广播规则。为了举例说明这一点,让我们以一个彩色图像为例。我们将使用一个名为“圣帕特里克节极光”的 NASA 图像(images.nasa.gov/details-GSFC_20171208_Archive_e000760)。
让我们从读取图像开始,现在的图像表示略有不同:
import numpy as np
from PIL import Image
image = Image.open("../aurora.jpg")
width, height = image.size
image_arr = np.array(image)
print(image_arr.shape, image_arr.dtype)
图像大小为 2040 × 1367。因为我们正在读取一个彩色图像,默认模式是 RGB(即三个通道:红色、绿色和蓝色),每个通道将有一个无符号字节。每个像素现在占用 3 个字节,而不是 1 个。因此,我们有一个形状为 2048 × 1367 × 3 的三维 NumPy 数组。当我们有一个 RGB 图像时,将图像转换为灰度是一个简单的算法:我们计算三个分量的平均值,这成为我们的灰度强度:
def get_grayscale_color(row):
mean = np.mean(row) ①
return int(mean) ②
vec_get_grayscale_color = np.vectorize(
get_grayscale_color, otypes=[np.uint8],
signature="(n)->()") ③
grayscale_arr = vec_get_grayscale_color(image_arr)
print(grayscale_arr.max(), grayscale_arr.dtype, grayscale_arr.shape)
Image.fromarray(grayscale_arr).save("grayscale.png")
① 这是三个 RGB 值的平均值。
② 平均值将是一个浮点数,因此我们将其转换为 int。
③ 我们覆盖了矢量化函数签名的默认期望。
默认情况下,np.vectorize 将向我们的函数发送一个标量,但我们可以更改签名以接受和返回其他类型的对象。在我们的情况下,我们希望我们的函数接受一个数组(即三个分量)并返回一个标量,因此签名是 (n)→()。最终结果如图 4.10 所示。

图 4.10 我们简单灰度化算法的结果
警告 这里值得重复的是,我们展示的解决方案是为了说明矢量化,它们可能不是最有效的。在这种情况下,最有效和 最佳 的解决方案可能是
grayscale_arr = np.mean(image_arr, axis=2).astype(np.uint8)
这将计算最后一个轴(即颜色通道所在的轴)上的平均值。虽然迭代解法是最差的选择,但这并不意味着创建自己的向量化函数总是最佳方法。在这种情况下,最佳方法是充分理解内置向量化函数的使用。
吸取的经验教训
从性能角度来看,纯 Python 向量化实际上并不是一个选项。但在其他上下文中,例如使用 Cython 和 GPU 时,向量化可以在性能上产生几个数量级的效应。所以如果你理解了向量化方法的一般原理,那么理解 Cython 和 GPU 章节将会容易得多。
现在我们已经考虑了 NumPy 编程性能的基础,我们将看看 NumPy 的内部结构以及如何优化它们以提高性能。结果是我们安装 NumPy 时所做的选择在速度和进程选择方面可以产生相当大的差异。
4.3 调优 NumPy 内部架构以提高性能
在本节中,我们将深入研究 NumPy,学习如何确保它配置为最大性能。我们首先概述 NumPy 的内部架构。
使 NumPy 成为一个高性能库的许多内部实现并不是用纯 Python 编写的。这并不令人惊讶,对吧?非纯 Python 部分,特别是外部库,可以进行配置,这些选择可以对性能产生巨大影响。
下一个章节的主题可能对大多数只关心 Python 方面的程序员来说有点枯燥。如果你信任你的 NumPy 库状态良好(或者你无法控制它),那么你可以自由地跳到最后一个小节,即 NumPy 中的线程,因为它在 Python 方面非常实用。如果你控制着整个 Python 栈,喜欢系统级问题,并且想要确保你提取了所有可能的性能,请继续阅读。
4.3.1 NumPy 依赖概述
许多科学库,无论是基于 Python 的还是不是,都依赖于两个广泛使用的线性代数库 API:BLAS(基本库代数系统)和 LAPACK(线性代数包)。
BLAS 实现了一套处理数组和矩阵的基本函数。BLAS 库将实现如向量加法和矩阵乘法等函数。在此基础上,LAPACK 实现了多个线性代数算法。例如,奇异值分解(SVD)是主成分分析的基础。图 4.11 展示了 NumPy 的架构。

图 4.11 NumPy 栈,包括库依赖
有许多替代实现,你选择哪一个会有操作上的后果。例如,在netlib.org上可用的标准 LAPACK 实现是非线程的,并且在现代架构上效率不高。一个常见的 BLAS/LAPACK 替代品是 OpenBLAS,另一个是 Intel MKL;两者都是线程化的。如果你的 NumPy 实现是线程化的,那么你使用计算机资源的方式可能会有很大的不同。例如,如果你有一个非线程化的版本,你可以分配尽可能多的进程,但如果你使用的是多线程的 BLAS 和 LAPACK,你必须小心不要过度承诺 CPU 资源。
因此,了解你的 NumPy 实现依赖于哪些库对于理解下一步的优化步骤很重要。理论上,你可以通过以下方式检测依赖关系:
import numpy as np
np.show_config()
在实践中,你可能需要访问你的文件系统和包管理系统来了解发生了什么。例如,当我使用 Anaconda Python 链接 MKL 时,部分输出如下:
lapack_opt_info:
libraries = [
'lapack', 'blas', 'lapack', 'blas', 'cblas', 'blas', 'cblas', 'blas']
library_dirs = ['/home/tra/anaconda3/envs/book-mkl/lib']
language = c
define_macros = [('NO_ATLAS_INFO', 1), ('HAVE_CBLAS', None)]
include_dirs = ['/home/tra/anaconda3/envs/book-mkl/include']
如果你这样做,你可能会很幸运,但在我这个例子中,这非常没有信息量(book-mkl来自我给环境的命名;从中无法推断出任何信息)。为了确定我使用的是什么,我最终做了ls -l '/home/tra/anaconda3/envs/book-mkl/lib/libcblas.so*'并注意到这一点:libcblas .so.3 → libmkl_rt.so。因此,MKL 似乎是链接的库。你可能需要研究你自己的情况下链接了哪些库。
提示:NumPy 不仅使用 BLAS 和 LAPACK,还提供了对它们的 Python 接口,以便你可以直接访问它们(或者更确切地说,SciPy 提供了这个接口)。
SciPy 是 NumPy 的姐妹库,它们有着密切的历史关系。SciPy 实现了比 NumPy 更高级的功能。
由于 NumPy 和 SciPy 有着如此紧密的关系,你可能会被它们的 API 所困惑。SciPy 的文档对此非常明确。以下是 SciPy 线性代数模块scipy.linalg的文档的逐字复制:
See also: `numpy.linalg` for more linear algebra functions. Note that
although `scipy.linalg` imports most of them, identically named
functions from `scipy.linalg` may offer more or slightly differing
functionality.
因此,有时 SciPy 会导入并重新导出 NumPy 函数。有时 API 可能略有不同。有时实现可能完全不同。
如果你想,你可以直接将 BLAS 作为 LAPACK 访问:你可以在scipy.linalg.blas和scipy.linalg.lapack中找到相应的 Python API。如果你认为在 NumPy 而不是 SciPy 中拥有这个接口更有意义,要知道你并不孤单。
因此,你可以直接从 Python 中使用这些库,但从性能角度来看,使用底层语言的库更有用。我们在这里不再进一步讨论,但将在 Cython 章节中重新讨论这个问题。
在我们回到关于 NumPy 内部更实际的问题之前,我们仍然需要通过查看 Python 发行版来讨论我们的安装对性能的影响。
4.3.2 如何调整 Python 发行版中的 NumPy
在这里,我们将探讨一些确保你的 NumPy 针对你的发行版进行优化的技巧。由于无法涵盖所有现有的 Python 发行版和操作系统,我们将涵盖来自python.org的标准 Python 和一个 Anaconda Python。我们将使用 Linux,因为这也是操作系统相关的。评论将以通用方式编写,以便在其他场景中也有用。
如果你在一个标准发行版上安装 NumPy,你很可能会通过使用pip install numpy来安装它。这只有在你的操作系统中已经安装了 BLAS 和 LAPACK 的情况下才会工作。问题变成了安装了什么版本? 最常见的版本是原始 NetLib 版本,它速度慢且不是线程化的。这从性能角度来看将是可怕的。你将不得不确保(1)你安装了更高效的版本,如 OpenBLAS 或 Intel 的 MKL,并且(2)按照之前讨论的np.show_config方法,你正在链接系统中最快的 BLAS/LAPACK 版本。
如果你使用另一个发行版,那么那个发行版的打包系统很可能已经为你处理了 BLAS 和 LAPACK。此外,依赖性安装也将相当合理。例如,在使用 Anaconda Python 时,当你执行conda install numpy,此时你可能会得到 OpenBLAS,这对于大多数情况来说都是可以的。
虽然大多数科学 Python 发行版默认安装的版本可能已经足够好,但你可能想要考虑其他替代方案。大多数发行版都允许你这样做。例如,在使用 Anaconda 时,你可以通过以下方式使用 MKL 安装 NumPy:
conda create -n book-mkl blas=*=mk
conda activate book-mkl
conda install numpy
我们创建了一个名为book-mkl的新环境,以确保当前环境不受影响并保持其默认设置。然后我们安装blas=*=mk,这指定了 MKL 构建的 BLAS。有了这个核心,我们可以继续安装 NumPy。
小贴士:对于大多数用例,你需要确保你没有使用速度较慢的 NetLib BLAS 实现。在许多情况下,OpenBLAS 或 MKL 将足够好。如果你使用其他实现,你将不得不研究它们。
如果你真的需要从你的系统中榨取尽可能多的性能,那么你必须自己基准测试替代实现。虽然互联网上有些基准测试可用,但你应该为你的特定代码设计一个测试,因为不同的实现有不同的优势,基准测试没有一种适合所有情况。
现在你已经正确配置了 NumPy 安装,让我们来利用它。
4.3.3 NumPy 中的线程
NumPY 是否使用线程取决于 BLAS/LAPACK 的实现。大多数 BLAS/LAPACK 库的实现都是线程化的——NumPy 释放了全局解释器锁(GIL),所以我们在这里谈论的是真正的并行性——你可以利用它们。但是有两个注意事项:(1)大多数但并非所有实现都是线程化的,(2)你可能实际上更喜欢 BLAS/LAPACK 是单线程的。
想象以下场景在我们的图像处理应用程序中。你有成千上万张图像需要处理,因此你在八核机器上启动了 8 个并行进程。如果你的 NumPy 是线程化的,你最终会有 8 个进程,每个进程运行 8 个线程,总共 64 个并发线程。我们希望复合最大值仅为 8。
为了更高效,通常有一个线程数为一个的八个进程,而不是一个有八个线程的进程。记住,每个 Python 进程中的非 BLAS 代码将是单线程的,所以你只有在 NumPy/BLAS 内部时才使用这八个线程。
有时候,我们可能希望减少 BLAS 和 LAPACK 使用的线程数量,可能减少到一。这应该很容易,对吧?不幸的是,没有直接控制 NumPy 中线程数量的方法。你必须配置 BLAS/LAPACK 实现;因此,代码是不可移植的。
对于 NetLib,因为它单线程,所以很简单。然而,如果你在寻找性能,你应该避免使用它。OpenBLAS 和 Intel 的 MKL 有不同的接口,因为它们仍然可能有一个用于线程的更低级别的依赖:它们可能被编译并依赖于 OpenMP。因此,你必须配置你的 BLAS/LAPACK 实现,可能还要配置使用的多进程库。
对于 OpenBLAS,在你调用 Python 代码之前,请执行:
export OPENBLAS_NUM_THREADS=1 ①
export GOTO_NUM_THREADS=1 ②
export OMP_NUM_THREADS=1 ③
① 配置 OpenBLAS 的标准方式
② 基于原始包 GotoBLAS2 的 OpenBLAS 的遗留变量
③ 万一 OpenBLAS 使用了 OpenMP
对于 MKL,你需要:
export MKL_NUM_THREADS=1 ①
export OMP_NUM_THREADS=1 ②
① 配置 MKL 的标准方式
② 万一 OpenBLAS 使用了 OpenMP
对于其他库,你需要自己检查。在使用像 OpenMP 这样的底层线程库及其配置要求时,要小心潜在的依赖关系。
从实际角度出发,还有一个问题:当你将代码从一个计算机移动到另一个计算机时——比如说,从你的开发机器到生产环境——链接的库可能会发生变化。你可以调整你的初始化代码来适应这种情况,或者始终为所有库设置所有变量,这是一个更实际的解决方案。
总结
也许深入研究 NymPy 底层库的细节并不是这本书最令人兴奋的收获。但如果你想高效地使用 NumPy 及其之上的所有堆栈,了解你的 NumPy 实现所依赖的库、为你的 Python 发行版优化 NumPy 以及找出底层库是否使用多线程是很重要的。
摘要
-
与复制相比,数组视图在内存和性能方面都可以非常高效。在可能的情况下应该考虑使用它们。
-
NumPy 的视图机制非常灵活,可以用极低的计算和内存成本渲染现有数据的不同视角。
-
理解形状(即数组每个维度上的元素数量)和步长(即在每个维度上跳过的字节数以找到下一个元素)是充分利用视图的基础。形状和步长都可以从视图到视图进行更改,以不同的方式呈现数据。
-
数组编程(即对整个数组执行声明性操作,而不是逐个元素使用命令式风格)可以提供数量级的性能提升。在可能的情况下应该使用它。
-
NumPy 的广播规则——NumPy 可以灵活地在不同维度的数组上使用运算符——允许更高效和优雅的编程。
-
NumPy 的内部架构可以针对计算性能进行优化。NumPy 依赖于 BLAS 和 LAPACK 库,并且为这些库提供了不同的选择。
-
确保检查你的 NumPy 实现是否使用了针对你架构的最有效库。
-
在 NumPy 中进行并行编程可能很棘手,因为 NumPy 的线程语义依赖于代数库的线程语义。
-
在使用基于 Python 的进程多处理在 NumPy 实现之上之前,确保你理解 NumPy 的底层库是否使用了多线程。如果它们没有使用多线程,你可能应该更改底层库。如果它们确实使用了多线程,你应该小心不要在本身是线程化的 NumPy 调用上使用多进程。
-
关于 NumPy 还有很多可以说的,因为它是 Python 数据分析的核心。鉴于其重要性,我们将在其他章节重新审视这个库。
¹ 对于非常小的数组,这并不成立,因为元数据将是分配的重要部分,但对于我们关心的较大数组来说,这是一个非常好的近似,因为 Python 和 NumPy 的 96 字节开销是可以忽略不计的。
² 对于这个问题有一个非常简单的解决方案,但从教学角度来看并不很有用。你能想到吗?作为一个提示,考虑使用最大值而不是最小值。
第二部分. 硬件
本书第二部分关注从提取常用硬件最大性能的角度开发 Python 解决方案。我们首先讨论使用更接近硬件的低级语言来从 CPU 中提取更多速度。具体来说,我们专注于 Cython,它是 Python 的超集,可以生成高效的 C 代码。然后,我们关注现代硬件架构以及它们有时需要反直觉的方法来提取最大性能。我们的讨论包括现代 Python 库,如 NumExpr,是如何设计来利用硬件的。
5 使用 Cython 重实现关键代码
本章涵盖
-
如何更高效地重实现 Python 代码
-
从数据处理的角度理解 Cython
-
分析 Cython 代码
-
使用 Cython 实现高效的 NumPy 函数
-
释放 GIL 以实现真正的线程并行
Python 很慢。标准实现很慢,语言的动态特性会付出性能代价。许多 Python 库之所以性能良好,正是因为它们部分是用底层语言实现的,从而提供了高效的数据处理算法。但有时我们可能需要在比 Python 更快的“某种东西”中实现我们自己的高性能算法。在本章中,我们将考虑 Cython,它是 Python 的超集,可以转换为 C,并且比 Python 性能好得多。
除了 Cython 之外,还有许多其他可以与 Python 集成以提高性能的替代方案,因此我们将从对可用选项的简要概述开始。之后,我们将深入探讨 Cython。
如果您以前从未使用过 Cython,本介绍将为您提供足够的数据分析背景知识,因此,与 NumPy 结合,因为它是数据分析的核心库。然后我们将讨论 Cython 分析和优化。我们还将以允许 NumPy 释放 GIL 并进行并行多线程的方式编写 Cython 代码。最后,我们将完成一个通用并行线程示例,其中我们的 Cython 代码将自行释放 GIL。
但首先,让我们了解一下 Cython 的其他替代方案。可能存在一个比 Cython 更适合您个人情况的替代方案。这尤其适用于您已经熟悉像 C 或 Rust 这样的底层语言的情况。
5.1 高效代码重实现技术概述
Cython 是许多替代方案之一,可以以更高效的方式重实现代码。我们将使用它,因为它不需要我们比 Python 学习更多;Cython 是 Python 的超集。但您应该了解替代方案,要么是因为您可能已经了解其中的一些,要么是因为您有一些需要使用其他东西的限制。
替代方案有四种形式(表 5.1):
-
现有库
-
Numba
-
更快的语言,如 Cython、C 和 Rust
-
如 PyPy、Jython、IronPython 和 Stackless Python 这样的替代 Python 实现
表 5.1 不同高效代码重实现的方法:库、更快语言、替代 Python 实现、Numba
| 库 | 底层语言 | 替代 Python 实现 | 即时编译器 |
|---|---|---|---|
| NumPy、SciPy、scikit-learn、PyTorch | C、Rust、Fortran、C++、Go、Cython | PyPy、IronPython、Jython、Stackless Python | Numba |
NumPy 是一个现有库的例子,它提供了 Python 的高效实现。有许多库以高效的方式实现了各种功能(例如,pandas、scikit-learn)。在实现你自己的代码之前,请确保它不是在现有的库中完成的。
值得考虑的第二个选项是 Numba。Numba 是一个即时编译器,它将 Python 的一部分转换为快速的本地代码。Numba 通常比另一种语言更容易使用,甚至比 Cython 还容易。Numba 为许多库提供了优化,包括 NumPy 和 pandas,以及多种架构,包括 GPU。我们在本书中优先考虑 Cython 的原因是我们有一个隐藏议程,那就是解释事物的工作原理,而 Numba 有一些神奇之处。它试图通过生成的代码变得智能。但在这本书中,我们试图理解如何创建更高效的代码;因此,我们想要比仅仅提供一个魔法解决方案更深入一些。从实际的角度来看,你不应该放弃 Numba;相反,对于大多数情况,它可能提供与 Cython 相当的性能提升,但工作量更少。话虽如此,当有需要时,我们将在本书中提及 Numba;确实,附录 B 是专门为其编写的。
本章我们将要遵循的路径是在更低级语言中重新实现代码。碰巧这个语言,Cython,非常接近 Python,但你有很多其他选择。C 可能是最常见的,但列表几乎永远不会结束:C++、Rust、Julia 等等。Cython 在很多方面使你的生活变得更轻松,因为它与 Python 紧密集成。如果你决定使用另一种语言,你需要研究如何将其他语言的代码链接到 Python 上。对于 C 和 C++,你可能想要考虑 Python 内置的ctypes模块或 SWIG(swig.org)。
最后,你可以考虑使用不同于 CPython 的 Python 实现。如果你依赖于 Java,你可以考虑 Jython。如果你依赖于.Net,那么考虑 IronPython。CPython 最现实的替代品是 PyPy,它更快,因为它是一个即时编译器。虽然 PyPy 在某种程度上是可行的,但它有一些限制,即哪些库可以在其上运行。在这个阶段,对于大多数用例,CPython 仍然是现实的最现实选择。
许多这些替代方案可以相互结合使用。本章的核心是将外部库 NumPy 与底层语言 Cython 绑定。
既然你已经了解了最重要的替代方案,让我们通过一个具体的 Cython 示例来工作。这将展示我们如何轻松地通过 Cython 获得显著的计算性能。
5.2 Cython 快速浏览
虽然这不是一本入门书籍,但可以合理假设可能会有很多读者从未使用过 Cython。在本节中,我们将通过一个小项目来展示基础知识,重点关注性能。我们将避免许多关于 Cython 编译的细节,尽管这些细节很重要,但不是理解性能问题的基本要素。你可以在互联网上找到大量关于 Cython 编译和显式内存管理的教程,如果你想学习基础知识,Cython 的项目文档是一个很好的起点 (cython.readthedocs.io/en/latest/)。
在这里,我们将从上一章的例子,图像处理,构建一个过滤器,该过滤器接受一个图像,生成一个灰度版本,然后根据另一个相同大小的图像上的值进行变暗。我们的第一次实现可能不一定更快——我们将在本章后面得到这个结果——但它将介绍在加速代码之前需要的 Cython 基本概念。图 5.1 提供了一个示例输出以使这一点更清晰。

图 5.1 原始图像(此处为灰度图)、处理后的图像以及应用的过滤器
让我们继续在 Cython 中实现过滤代码。为了进行性能比较,仓库中的代码允许你运行原生 Python 实现。为了参考,原生 Python 代码在我的电脑上需要 35 秒。
5.2.1 Cython 中的简单实现
我们的图像过滤器将使用 NumPy 和 Pillow 进行图像处理,就像上一章一样。图像将是彩色的;因此,将有三个 RGB 分量。过滤器的值将在 0 到 255 之间变化,0 表示没有变暗,255 表示完全变黑。我们的代码首先将每个像素转换为灰度,然后进行变暗。
我们现在将使用 Cython 进行第一次实现。我们将把代码分成两部分:在正常的 .py 文件中调用 Cython 代码的 Python 代码,以及真正的 Cython 代码,它是以 .pyx 扩展名编写的文件。.pyx 来自 Pyrex,这是 Cython 最初分叉的项目。Python 代码位于 05-cython/sec1-intro/apply_filter.py:
import numpy as np
from PIL import Image
import pyximport; ①
pyximport.install(
language_level=3, ②
setup_args={'include_dirs': np.get_include()}) ③
import cyfilter ④
image = Image.open("../../04-numpy/aurora.jpg")
gray_filter = Image.open("../filter.png").convert("L")
darken_arr = cyfilter.darken_naive(image_arr, gray_arr)
Image.fromarray(darken_arr).save("darken.png")
① pyximport 将负责编译和加载 Cython 代码。
② 我们需要 Python 3。
③ 我们需要使用 NumPy 头文件进行编译。
④ 我们将在一个名为 cyfilter 的模块中实现 Cython 代码。
之前代码中唯一概念上新的部分与 pyximport 相关。这将负责编译和链接 Cython 代码。
链接 Cython 代码
记住,Cython 是 Python 的超集,编译成 C,使代码可以作为外部扩展使用。这并不像简单地导入原生 Python 模块那样简单。
有几种方法可以处理整个过程,如 Cython 文档中所述。我们不会在这里详细介绍所有这些方法,但我会指出三种值得注意的方法:
-
这里使用的方法
pyximport将负责以透明和简单的方式编译代码并进行链接。每次你导入一个 Cython 模块时,代码可能会被转换为 C 语言,编译并链接,因此你将在启动时付出性能代价,但仅限于启动时。只需确保在分析性能时,以某种方式排除这部分时间。 -
如果你使用 Jupyter/IPython Notebook,则
%cython魔法是可用的。有关详细信息,请参阅 Cython 或 IPython 文档。 -
直接在
.pyx文件上调用cython。这种方法将需要我们自行进行链接。我们将在本章后面使用它,但只是为了看看生成的代码。
如果你使用 Jupyter 或更普遍地使用 IPython,%cython 会为你处理一切,而且非常简单;如果你使用 Jupyter 或 IPython 进行工作,我推荐使用它。但如果你计划将代码分发给一般的 Python 用户,则不能使用它。
直接调用 cython 主要用于检查 C 语言代码。对于任何其他用例,实际上并没有真正的实用理由去做这件事。
当将代码投入生产或准备分发给用户时,应使用其他替代方案。很可能会需要为目标架构预编译您的 Cython 代码,因为要求用户拥有完整的 C 编译器栈通常要求过多。准备代码以分发是一个相当复杂的话题,我们在此不予讨论,并且对于本书的目的来说,这并不非常相关。
目前,Cython 代码与 Python 版本完全相同。它只是位于一个 pyx 文件中(05-cython/sec2-intro/cyfilter.pyx):
#cython: language_level=3 ①
import numpy as np
def darken_naive(image, darken_filter):
nrows, ncols, _rgb_3 = image.shape
dark_image = np.empty(shape=(nrows, ncols), dtype=np.uint8)
for row in range(nrows):
for col in range(ncols):
pixel = image[row, col]
mean = np.mean(pixel)
dark_pixel = darken_filter[row, col]
dark_image[row, col] = int(mean * (255 - dark_pixel) / 255)
return dark_image
① 第一行实际上并不是注释。它是指导 Cython 为 Python 3 版本编译。
除了第一行和函数名之外,代码与 Python 版本相同。如果你调用 Python 顶级文件,你将会非常失望。在我的电脑上,它花费了 33 秒。仅比原生 Python 版本少 2 秒。在下一节中,我们将编写一些运行得更快的东西,并确定为什么之前的版本运行缓慢。
作为编译型语言的 Cython
Cython 是一种编译型语言,而不是像 Python 那样解释执行。这带来了许多后果;其中之一涉及到你会在哪里发现某些类型的错误。例如,这段代码
def so_wrong():
return a + 1
将仅在 Python 运行时失败,因此你可以将其部署到生产环境中而无需担心问题。但是,当你尝试将其作为 Cython 程序编译时,它将立即失败。在这方面,Cython 有助于捕获一些错误,但请准备好 Cython 编译器会对你提出一些 Python 不会捕获的错误进行抱怨。
5.2.2 使用 Cython 注释提高性能
在我们深入探讨为什么之前的代码运行缓慢的原因之前,让我们先生成一个更快的版本,以便我们可以进行比较。更快的版本依赖于使用 Cython 注释系统:
#cython: language_level=3
import numpy as np
cimport numpy as cnp ①
def darken_annotated(
cnp.ndarray[cnp.uint8_t, ndim=3] image, ②
cnp.ndarray[cnp.uint8_t, ndim=2] darken_filter): ③
cdef int nrows = image.shape[0] ④
cdef int ncols = image.shape[1]
cdef cnp.uint8_t dark_pixel, mean ⑤
cdef cnp.ndarray[np.uint8_t] pixel
cdef cnp.ndarray[cnp.uint8_t, ndim=2]
dark_image = np.empty(shape=(nrows, ncols), dtype=np.uint8)
for row in range(nrows):
for col in range(ncols):
pixel = image[row, col]
mean = (pixel[0] + pixel[1] + pixel[2]) // 3 ⑥
dark_pixel = darken_filter[row, col]
dark_image[row, col] = mean * (255 - dark_pixel) // 255
return dark_image
① 我们导入 NumPy 的 C 级定义。
② 我们将第一个参数作为具有三个维度的 C 级 NumPy 数组输入;记住,对于彩色图像,每个像素有三个组件(RGB)。类型是 8 位无符号整数。
③ 第二个参数是一个 8 位无符号整数的二维数组。
④ 我们还指定了所有局部变量。元组赋值在 Cython 中有时是不可能的,所以我们把元组赋值拆分成两个。
⑤ 变量的类型必须在函数的开始处指定,因此我们在进入内部循环之前定义内部循环变量。
⑥ 这是一种计算平均值稍微高效一些的方法。
这段代码中真正重要的区别是 Cython 注释的使用。例如,nrows变成了cdef int nrows:我们通知 Cython 该变量是int类型。这个定义将在 C 级别生效,因为 Cython 代码被转换为 C。可以对外部库如 NumPy 进行 C 级别的定义;这正是cimport numpy as cnp所导入的。然后我们可以指定数组。
警告:Cython 级别的类型注释与现代 Python 级别的注释完全不同,两者之间没有任何关系。
注意,除了平均值计算和元组赋值拆分外,代码完全相同。也就是说,两个for循环具有相同的复杂度。
运行时间从原生 Python 版本的约 30 秒降低到 1.5 秒——快了 20 倍。现在我们正在取得进展。¹
小贴士:注释你的所有 Cython 变量。
这段代码还可以进一步加快速度。我们将在本章后面讨论加快速度的技术。但到目前为止,我们将把注意力转向理解为什么注释如此重要。
5.2.3 为什么注释对性能至关重要
为什么注释对性能如此重要?为了回答这个问题,我们需要查看我们函数的 C 生成代码。即使你不知道 C,也没有理由害怕:阅读 C 代码比编写它容易得多,你将很容易理解其精髓。
我们将使用一个简单的例子:给一个数字加 4。以下是带注释和不带注释的代码(此代码位于05-cython/sec2-intro/add4.pyx):
#cython: language_level=3
def add4(my_number):
i = my_number + 4
return i
def add4_annotated(int my_number):
cdef int i
i = my_number + 4
return i
简单吗?为了能够看到 C 生成的代码,我们将直接运行 Cython:
cython add4.pyx
这将生成一个名为add4.c的 C 文件。我的 Cython 版本有近 3000 行。没有必要恐慌:其中大部分是样板代码,Cython 使得找到我们的代码并了解生成的代码变得非常容易。你可以在 C 注释中找到你代码的每一行,这样你就可以知道正在生成什么。例如,这是由 Cython 生成的:
/* "add4.pyx":9
*
*
* def add4_annotated(int my_number):
* cdef int i
* i = my_number + 4
*/
对于我们每个函数,Cython 生成两个 C 函数:一个 Python 包装器——负责 Python 和 C 之间的接口——以及函数的正确实现。
因此,对于add4,我们有一个 C 包装器,如下所示:
static PyObject *__pyx_pw_4add4_1add4(
PyObject *__pyx_self, PyObject *__pyx_v_my_number)
如果你从 Python 的角度思考,而不是 C 或 Cython,这很有道理:函数返回一个 Python 对象,static PyObject *。self也是一个 Python 对象,PyObject *__pyx_self,参数是PyObject *__pyx_v_my_number。Python 只知道对象。
这是add4_annotated包装器的签名:
static PyObject *__pyx_pw_4add4_3add4_annotated(
PyObject *__pyx_self, PyObject *__pyx_arg_a) {
类型完全相同,这很有道理:记住 Python 只看到对象。因此,在 Python 接口级别上,函数是相同的。
两个包装器都做了大量的打包和拆包,然后调用实现。以下是add4实现的签名:
static PyObject *__pyx_pf_4add4_add4(
CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v_my_number)
类型大多相同。这很有道理,因为我们没有注释函数和 Cython 来为最通用的情况创建代码。
这是add4_annotated函数的签名:
static PyObject *__pyx_pf_4add4_2add4_annotated(
CYTHON_UNUSED PyObject *__pyx_self, int __pyx_v_my_number) {
----
注意,类型my_number现在是一个原生 C 类型,而不是 Python 对象,int __pyx_v_my_number。
如果你查看这两个函数的包装器,你会注意到,注释函数的包装器更复杂,有很多类型的打包和拆包。非注释包装器可以直接将参数传递给实现,因为它处理的是 Python 对象。但注释包装器必须管理 Python 对象到整数的转换。
在所有这些预备知识之后,我们来到了基本问题:i = my_number + 4的实现。这是非注释版本的内容:
__pyx_t_1 = __Pyx_PyInt_AddObjC(__pyx_v_my_number, __pyx_int_4, 4, 0, 0);
if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 5, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_1);
__pyx_v_i = __pyx_t_1;
__pyx_t_1 = 0;
这段代码正在调用函数__Pyx_PyInt_AddObjC将一个整数添加到一个对象上。这个函数在源代码add4.c中定义。欢迎你检查这个怪物。你会找到很多对 CPython 函数的调用,很多 C if语句,在某些情况下,甚至有goto调用。记住,所有这些都是为了将 4 加到一个变量上。这非常繁琐。
这段代码还有一个严重的问题:因为我们正在处理 Python 对象,Cython 无法释放 GIL。Python 代码是 GIL 绑定的,但在某些情况下,底层代码可以释放 GIL。按照目前的代码,GIL 无法被释放,因此,我们无法有并行线程。
实际上,最大的问题是前者——求和的实现——因为它大部分时间都在管理 Python 对象。即使我们能够,尽管我们做不到,制作它的并行版本,Python 操作求和的损失将远大于使用少量并行核心的收益。
不再拖延,以下是来自add4_annotated的注释版本:
__pyx_v_i = (__pyx_v_my_number + 4)
这是一个简单的 C 级别添加,因此,它应该比非注释版本快几个数量级。
吸取的教训
注释帮助 Cython 从你的代码的 C 版本中移除大量的 CPython 基础设施。因此,注释的 Cython 将比非注释的 Cython 运行得更快,我建议尽可能使用注释的 Cython。
5.2.4 为函数返回添加类型
你也可以这样为函数返回添加类型:
cdef int add4_annotated_cret(int my_number):
return my_number + 4
注意,不仅返回类型是 int,而且函数现在使用 cdef 而不是 def 定义。这是一个只能从 C 调用的函数。如果你尝试从 Python 调用它,它将不会工作,因为没有包装器。这有什么优势呢?对于仅从其他 Cython 函数调用的 Cython 函数,可以声明一个函数可以从 Python 和 Cython 使用(即,如果从 Python 或 Cython 调用,将使用不同的接口):
---
cpdef int add4_annotated_cpret(int my_number):
return add4_annotated_cret(my_number)
----
在这种情况下,你将得到包装器和底层实现。
所以你有三种声明函数的方式:仅使用 Cython,它通过 cdef 接口;Python 和 Cython 都可以使用,通过 cpdef 接口;以及仅使用 Python,通过 def 接口。任何需要通过 Python 原生接口的时候,你都会付出性能的代价。当你通过仅使用 Cython 的接口时,你的负担会更小。(正如我们在上一节中看到的,def 函数会生成两个级别的接口,但这并不是使用 def 时可以保证的实现细节。)
为什么不总是使用 cpdef 而不是 def 和 cdef 呢?有时你想要使用 def 作为 cpdef 和 cdef 来对函数实现施加额外的限制,有时你有明确的需要添加注解。当你使用 Python 无法理解的数据类型时,需要 cdef,例如,在 Cython 代码中编写的 C 指针。
吸取的经验教训
总是用 Cython 注释类型;好处巨大,唯一的缺点是编写注释的麻烦。如果可能,使用 cdef。如果不可能,考虑重构代码,以便你有一个在 def/cpdef 中的 Python 链接部分,以及在 cdef 中的计算密集部分。
现在我们对为什么注释对性能很重要有了更深入的理解,让我们继续通过分析我们的 Cython 代码来微调我们的代码。
5.3 分析 Cython 代码
让我们回到我们的基于 Cython 的图像过滤代码。虽然它比纯 Python 实现快得多,但仍然感觉有点慢。毕竟,应用一个简单的过滤器需要超过 1 秒。虽然我们的直觉使我们怀疑存在问题,但正如我们在第二章中提到的,性能分析中的直觉往往会导致不良结果,所以我们将从 Cython 视角重新审视分析,以严格找到剩余的瓶颈。
Cython 分析与原生 Python 的分析很好地集成。第二章中介绍的分析技术可以在这里直接使用。因此,我们将对函数进行逐行分析,以找到延迟的来源。
5.3.1 使用 Python 的内置分析基础设施
我们将开始使用内置的分析基础设施。我们首先需要做的是注释我们的 Cython 代码,以便生成可分析代码。这相当简单。以下是我们的暗色注释函数的注释(代码位于 05-cython/sec3-profiling/cython_prof.py):
# cython: profile=True ①
import numpy as np
cimport cython
cimport numpy as cnp
def darken_annotated(
cnp.ndarray[cnp.uint8_t, ndim=3] image,
cnp.ndarray[cnp.uint8_t, ndim=2] darken_filter):
cdef int nrows = image.shape[0]
...
① 我们告诉 Cython 我们想要我们的代码被配置文件。
这就像在代码中添加全局指令一样简单。如果你出于某种原因不想对你的文件中的特定函数进行检测,只需添加 @cython.profile(False)。
作为对第二章中配置文件示例的变体,让我们使用内置的 pstat 模块来获取配置文件统计信息。以下是我们的调用函数(见 code/05-cython/sec3-profiling/apply_filter_prof.py):
import cProfile ①
import pstats ②
import pyximport
import numpy as np
from PIL import Image
pyximport.install(
setup_args={
'include_dirs': np.get_include()})
import cyfilter_prof as cyfilter
image = Image.open("../../04-numpy/aurora.jpg")
gray_filter = Image.open("../filter.png").convert("L")
image_arr, gray_arr = np.array(image), np.array(gray_filter)
# We just want to profile this
cProfile.run("cyfilter.darken_annotated(image_arr, gray_arr)",
"apply_filter.prof") ③
s = pstats.Stats("apply_filter.prof")
s.strip_dirs().sort_stats("time").print_stats() ④
① 我们将负责在内部运行配置文件代码。
② pstats 模块处理分析器的输出。
③ 我们在我们的函数上调用分析器。
④ 我们使用 pstats 模块来打印收集到的统计信息。
关于配置文件,这段代码中没有 Cython 特有的内容。输出如下:
Tue May 10 14:43:03 2022 apply_filter.prof
5 function calls in 0.707 seconds
Ordered by: internal time
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.707 0.707 0.707 0.707 cyfilter_prof.pyx:9
(darken_annotated)
1 0.000 0.000 0.707 0.707 {built-in method
builtins.exec}
1 0.000 0.000 0.707 0.707 <string>:1(<module>)
1 0.000 0.000 0.707 0.707
{cyfilter_prof.darken_annotated}
1 0.000 0.000 0.000 0.000 {method 'disable' of
'_lsprof.Profiler' objects}
关于输出的详细信息请参阅第二章。正如我们在第二章中看到的那样,内置的配置文件有时并不像我们希望的那样具有信息量。考虑到这一点,让我们重新审视行配置文件,现在是在 Cython 的背景下。
5.3.2 使用 line_profiler
我们将使用 line_profiler 模块,就像在第二章中一样。为此,我们必须指示 Cython 为行配置文件检测我们的代码(有关详细信息,请参阅存储库中的 05-cython/sec3-profiling/cython_lprof.py):
# cython: linetrace=True ①
# cython: binding=True ②
# cython: language_level=3
import numpy as np
cimport cython
cimport numpy as cnp
cpdef darken_annotated(
cnp.ndarray[cnp.uint8_t, ndim=3] image,
cnp.ndarray[cnp.uint8_t, ndim=2] darken_filter):
cdef int nrows = image.shape[0] # Explain
cdef int ncols = image.shape[1]
cdef cnp.uint8_t dark_pixel
cdef cnp.uint8_t mean # define here
cdef cnp.ndarray[cnp.uint8_t] pixel
cdef cnp.ndarray[cnp.uint8_t, ndim=2]
dark_image = np.empty(shape=(nrows, ncols), dtype=np.uint8)
for row in range(nrows):
for col in range(ncols):
pixel = image[row, col]
mean = (pixel[0] + pixel[1] + pixel[2]) // 3
dark_pixel = darken_filter[row, col]
dark_image[row, col] = mean * (255 - dark_pixel) // 255
return dark_image
① 我们需要 Python 类型的绑定。
② 我们告诉 Cython 生成行跟踪代码。
我们需要做的唯一改变是指导 Cython 生成行跟踪检测代码。我们使用指令 # cython: linetrace=True 来实现这一点。你还可以通过不使用指令,而是注释每个你想要配置文件检测的函数来逐个函数地激活行跟踪:
@cython.binding(True)
@cython.linetrace(True)
你可能还记得,从第二章中,行跟踪非常慢。因此,Cython 要求你不仅用 linetrace 注释你的 Cython 代码,而且当你 使用 代码时,你必须明确请求跟踪。为了看到这个动作,让我们检查调用此函数的代码(即 Python 端):
import pyximport
import line_profiler ①
import numpy as np
from PIL import Image
pyximport.install(
language_level=3,
setup_args={
'options': {"build_ext":
{"define": 'CYTHON_TRACE'}}, ②
'include_dirs': np.get_include()})
import cyfilter_lprof as cyfilter
image = Image.open("../../04-numpy/aurora.jpg")
gray_filter = Image.open("../filter.png").convert("L")
image_arr, gray_arr = np.array(image), np.array(gray_filter)
profile = line_profiler.LineProfiler(
cyfilter.darken_annotated) ③
profile.runcall(cyfilter.darken_annotated, image_arr, gray_arr)
profile.print_stats()
① 我们导入 line_profiler。
② 我们需要用 CYTHON_TRACE 宏编译 C 代码。
③ 我们将在这里显式调用 line_profiler。
我们必须记住激活执行检测的 C 代码。C 代码被 C 宏封装,并且只有当编译器传递指令 CYTHON_TRACE 时才会编译。我们通过 distutils 系统指导 pyximport 来完成此操作。C 宏系统和 Python 构建基础设施超出了本书的范围。但你要确保在编译你的 C 代码的系统(记住,pyximport 是几个选项之一)中定义了 CYTHON_TRACE 宏。
在这里,我们直接从我们的代码中配置line_profiler机制;在第二章中,我们通过调用kernprof来使用不同的方法。我们创建LineProfiler对象,在其中调用darken_annotate,并打印统计信息。作为一个练习,在我们查看结果之前,思考一下你对瓶颈位置的预期。然后看看下面的图。

我们函数的运行时间为 3.5 秒;这比之前的整个运行时间 1.5 秒多得多。记住,行分析非常昂贵,并产生开销。我们不应该将标准分析与行分析的绝对时间进行比较,因为它们预期会有所不同。我们还需要在分析时更加耐心。
表面上看似无辜的赋值pixel = image[row, col]占用了 60%的时间。这是你预期的吗?
理解什么在拖慢我们速度的最简单方法再次是执行cython cyfilter_lprof.py并查看生成的代码。这个 C 分析案例比第一个示例更容易;在那个例子中,我们可以使用 Cython 通过cython -a cyfilter_lprof.py生成的网络报告。一个 HTML 文件被创建——cyfilter_lprof.html——可以用任何网络浏览器打开。图 5.2 显示了我们的函数的主视图。你可以点击每一行,并查看为其生成的 C 代码。灰色(在浏览器版本中为黄色)的线条暗示了与 Python 机制的交互:如果需要 Python 机制,你可以肯定性能会受到影响。

图 5.2 cyfilter_lprof.html 的网页输出。灰色(在浏览器版本中为黄色)的线条暗示了与 Python 的交互。
如果你展开第 22 行的代码——我们的“无辜”的赋值pixel = image[row, col]——你会看到这里不仅仅是赋值。有许多 C 调用会引发许多性能标志:__Pyx_PyInt_From_int、PyTuple_New、__Pyx_PyObject_GetItem和__Pyx_SafeReleaseBuffer似乎都在为应该是一个简单赋值的操作调用很多可能很慢的东西,至少在 C 级别上是这样。
吸取的经验教训
尽管我们通过之前的 Cython 注解通知了 Cython 低级类型,但分析过程表明它仍在操作 NumPy 数组。如果代码操作与 Python 对象(如 NumPy 数组)交互,整个 Python 机制仍然需要运行代码。涉及这个机制是降低性能的可靠方法。
那么,问题变成了:我们能否更有效地查看这些数组,并使无辜赋值的代码更简单?结果是我们可以做到。
5.4 使用 Cython 内存视图优化数组访问
为了加速我们的代码,我们需要将与 Python 对象的交互减少到尽可能少的程度——理想情况下为零。我们需要移除 Python 内置函数和 Python 对 NumPy 数组的视图。在我们的当前示例中,我们仍然有作为 Python 对象的数组,我们需要改变这一点。
结果表明,Cython 对于 NumPy 数组有一个内存视图的概念,这与我们在上一章中探索的同等名称的概念有些相似。Cython 可以直接访问原始数组表示,而无需使用 Python 的对象机制。我们将我们的 Cython 代码分为两个函数:一个用于处理 Python 对象,这些对象永远无法那么快,另一个以 C 级别速度工作。所以一个接收 NumPy 数组并准备内存视图,另一个应用图像过滤器(见存储库中的 05-cython/sec4-memoryview)。让我们从接收 NumPy 数组并准备内存视图的函数开始:
cpdef darken_annotated(
cnp.ndarray[cnp.uint8_t, ndim=3] image,
cnp.ndarray[cnp.uint8_t, ndim=2] darken_filter):
cdef int nrows = image.shape[0]
cdef int ncols = image.shape[1]
cdef cnp.ndarray[cnp.uint8_t, ndim=2] dark_image =
np.empty(shape=(nrows, ncols), dtype=np.uint8)
cdef cnp.uint8_t[:,:] dark_image_mv ①
cdef cnp.uint8_t [:,:,:] image_mv
cdef cnp.uint8_t[:,:] darken_filter_mv
dark_image_mv = dark_image ②
darken_filter_mv = darken_filter
image_mv = image
darken_annotated_mv(image_mv,
darken_filter_mv, dark_image_mv) ③
return dark_image
① 这声明了一个将指向 dark_image 原始数据的内存视图。
② 在这里,我们让 Cython 将视图解析为 NumPy 数组的原始数据。
③ 最后,我们调用一个新的函数,darken_annotated_mv,它只处理视图。
注意声明内存视图的语法:它有一个 C 类型,并且需要知道维度(例如,[:,:,:] 用于 image 的三个维度)。Cython 将确保内存视图变量指向具有正确步长和形状的数组原始数据。
现在我们来看新的内部函数:
cpdef darken_annotated_mv(
cnp.uint8_t[:,:,:] image_mv, ①
cnp.uint8_t[:,:] darken_filter_mv,
cnp.uint8_t[:,:] dark_image_mv): ②
cdef int nrows = image_mv.shape[0]
cdef int ncols = image_mv.shape[1]
cdef cnp.uint8_t dark_pixel
cdef cnp.uint8_t mean # define here
cdef cnp.uint8_t[:] pixel
for row in range(nrows):
for col in range(ncols):
pixel = image_mv[row, col]
mean = (pixel[0] + pixel[1] + pixel[2]) // 3
dark_pixel = darken_filter_mv[row, col]
dark_image_mv[row, col] = mean * (255 - dark_pixel) // 255
① 我们将输入参数的类型从 Numpy 数组更改为视图。
② 输出现在是一个参数。
代码最终与原始版本非常相似。输入类型已从数组更改为内存视图,为了得到一个更干净的版本,我们将输出视图作为参数传递。
如下所示,性能通过行分析有所提升。

现在速度提高了 50%,并且“无辜”的赋值 pixel = image_mv[row, col] 现在在意义上更加“无辜”,因为它不需要 Python 对象管理代码。然而,洞察力表明,对于简单的图像处理,它仍然花费了太多的时间。
结果表明,在这个代码中仍然有很多 Python 交互。如果我们运行 cython -a,生成一个代码网页,为每行代码着色以显示 Python 交互的数量,我们会得到如图 5.3 所示的大量标记交互。

图 5.3 基于内存视图函数的网页输出。灰色(在浏览器版本中为黄色)的线条暗示了与 Python 的交互。
5.4.1 吸取的经验
通常,创建 NumPy 数组的内存视图是值得花费时间的,这将允许 Cython 与原始数组表示进行交互,并避免 Python 机制。性能提升可能是显著的。但进一步的性能分析表明,我们仍在与机制交互,因此仍然表现不佳。
我们接下来的问题是:我们能否消除剩余的(Python)交互以显著优化我们的代码?答案仍然是,我们可以。
5.4.2 清理所有与 Python 的内部交互
有三种类型的交互负责 cython -a,标记图 5.3 中的行,仍在继续:
-
我们有一个
cpdef函数,它生成一个带有 Python 桩的 C 函数。我们可以用cdef来替换它。 -
函数隐式返回一个
None对象,正如所有 Python 函数所期望的那样。这意味着管理一个 Python 对象,即使只是一个None。 -
NumPy 内存视图仍在尝试帮助您进行边界检查(即,如果您放入一个无效的索引,Python 异常机制将被激活)。
让我们一举解决所有这些问题。我们只需更改函数定义:
@cython.boundscheck(False) ①
cdef void darken_annotated_mv( ②
cnp.uint8_t[:,:,:] image_mv,
cnp.uint8_t[:,:] darken_filter_mv,
cnp.uint8_t[:,:] dark_image_mv) nogil: ③
① 我们停用边界检查。
② 我们有一个 cdef(即,没有 Python 桩),我们声明返回类型为 C void。
③ 我们现在可以告诉 Cython 这个函数可以释放 GIL。
我们将在稍后重新讨论边界检查;在一般情况下,您的代码现在可以没有边界检查保护栏而崩溃。在这种情况下这不是问题,但本章后面我们将看到这可能会成为一个问题。
nogil 注解是可选的;在这个阶段没有获得任何好处。这将使我们能够拥有真正的并行性,这是一个我们将在本章后面再次讨论的主题。如果存在此注解但未移除所有 Python 连接,Cython 将会抱怨。所以,我们之所以能够这样做,仅仅是因为我们已经做了所有其他改变。
在我的电脑上,这是一台装有 1.6GHz 英特尔 i5 CPU 的笔记本电脑,现在只需要 0.04 秒。记住,我们最初是用 35 秒进行原生 Python 实现和 18 秒进行原始 Cython 实现开始的。
吸取的经验教训
为了消除 Cython 和 Python 之间最后残留的交互,我们可以更改函数定义以避免使用 Python 函数进行调用或返回。通过将此过程与注释 Cython 代码、添加函数返回类型注解以及使用内存视图而不是原始 NumPy 数组相结合,我们有一个消除更快 Cython 和较慢 Python 之间交互的策略。
之后,我们将讨论边界检查和其他 NumPy 优化。我们还将讨论并行性,但到目前为止,我们将探讨在 Cython 中实现 NumPy 通用函数。这非常有用,因为通用函数适用于 NumPy 广播规则。
5.5 在 Cython 中编写 NumPy 通用函数
我们现在将通过在 Cython 中编写通用函数来解决图像过滤问题的一种替代方案。记得从上一章中,通用函数机制通过提供广播等好处使生活变得更简单。通用函数附带所有这些额外的好处,并且,正如我们稍后将看到的,它们在某种程度上类似于 GPU 编程范式。然而,请记住,它们不是通用的计算解决方案,我们将在下一节中看到一个这样的例子。
我们在第四章中也了解到,通用函数是逐元素操作的。在我们的例子中,这意味着逐像素操作。我们的代码将由两部分组成:通用函数和注册它的代码。让我们从通用函数开始(代码可在05-cython/sec5-ufunc目录下的仓库中找到):
# cython: language_level=3
import numpy as np
cimport cython
cimport numpy as cnp
cdef void darken_pixel(
cnp.uint8_t* image_pixel, ①
cnp.uint8_t* darken_filter_pixel,
cnp.uint8_t* dark_image_pixel) nogil:
cdef cnp.uint8_t mean
mean = (image_pixel[0] + image_pixel[1] + image_pixel[2]) // 3
dark_image_pixel[0] = mean * (255 - darken_filter_pixel[0]) // 255
① 注意指针符号的使用(*)。
我们现在是在逐像素的基础上操作,因此代码更简单,因为我们不需要在整个数组/图像上使用for循环。
基本的区别在于,我们传递的是一个数字的指针,cnp.unit8_t,而不是一个数字,cnp.uint8_t *。如果你不习惯像 C 这样的底层语言,这个概念可能对你来说很新。就我们的实际用途而言,这不会带来很多后果,但对于更复杂的例子,你应参考 Cython 的文档。唯一有意义的后果是,输出将被写入“输入变量”。最后,该函数被标记为nogil,允许它并行运行:没有引用 Python 对象,因此并行执行器可以释放 GIL。
我们在本章中提到的通用函数,与上一章的情况相同,是一个通用的通用函数,因为第一个参数image_pixel不是一个原始类型,而是一个数组:彩色像素有三个 RGB 分量。
现在我们需要包装我们的(通用)通用函数。遗憾的是,模板代码有点长且稍微复杂:
cdef cnp.PyUFuncGenericFunction loop_func[1]
cdef char all_types[3] ①
cdef void *funcs[1] ②
loop_func[0] = cnp.PyUFunc_FF_F
all_types[0] = cnp.NPY_UINT8 ③
all_types[1] = cnp.NPY_UINT8
all_types[2] = cnp.NPY_UINT8
funcs[0] = <void*>darken_pixel ④
darken = cnp.PyUFunc_FromFuncAndDataAndSignature( ⑤
loop_func, funcs, all_types,
1, ⑥
2, ⑦
1, ⑧
0,
"darken",
"Darken a pixel", 0
"(n),()->()" ⑨
)
① 我们需要一个变量来指定所有输入和输出的类型。
② 实现通用函数的所有函数
③ 我们指定了两个输入参数和一个输出参数的类型。
④ 实现通用函数的功能列表
⑤ 创建包装后的通用函数
⑥ 输入类型的数量
⑦ 输入参数的数量
⑧ 输出参数的数量
⑨ Numpy 签名
我们需要指定所有参数的数据类型,这用all_types编码。此外,通用函数的签名(n),()→()意味着(n)是一个包含初始像素三个颜色分量的数组,()是一个表示灰度变暗像素的原始值,以及输出(),另一个表示灰度像素的原始类型。
最令人困惑的部分是拥有多个函数来渲染实现的能力;注意,我们有一个名为funcs的函数列表,而不是一个单独的函数。在我们的情况下,我们只需要一个函数,即darken_pixel,但我们可以为不同的输入或输出参数使用不同的函数——比如说,一个用于NPY_UINT8,另一个用于NPY_UINT16。
这现在可以像任何其他通用函数一样使用。在我们的情况下:
import pyximport
import numpy as np
from PIL import Image
pyximport.install(
language_level=3,
setup_args={
'options': {"build_ext": {"define": 'CYTHON_TRACE'}},
'include_dirs': np.get_include()})
import cyfilter_uf as cyfilter
image = Image.open("../../04-numpy/aurora.jpg")
gray_filter = Image.open("../filter.png").convert("L")
image_arr, gray_arr = np.array(image), np.array(gray_filter)
darken_arr = cyfilter.darken(image_arr, gray_arr)
5.5.1 主要收获
在 Cython 中编写 NumPy 通用函数通常是可能的,也是首选的,尤其是它们自带一些节省时间的内置功能。然而,在某些情况下,NumPy 通用函数不足以实现一个算法——例如,当你需要检查数组中的其他位置的状态,而不仅仅是当前位置时。为了处理这种情况以及 Cython 中数组处理的其他问题,我们现在将考虑一个新的示例:生命游戏。
5.6 Cython 中的高级数组访问
在本节中,我们将通过深入研究优化数组访问来巩固我们对 Cython 和 NumPy 交互的理解。具体来说,我们将进行低级别的多线程并行处理,最终绕过 GIL 对一次运行单个 Python 线程的限制。
我们将通过一个新的示例项目来观察这些过程的具体实现:我们将创建康威生命游戏的彩色版本(详情请见conwaylife.com/)。康威生命游戏是一个零玩家游戏,它从初始状态自动进化;设计有趣的初始状态是乐趣的一部分。游戏的状态由任意大小的网格组成,每个细胞可以有两种状态:活着或死亡。随着时间的推移,每个细胞将根据以下规则改变其状态:
-
任何有两个或三个邻居的活细胞都会存活。
-
一个有三个邻居的死亡细胞变成活细胞。
-
所有其他细胞都会死亡或保持死亡状态。
世界是环绕的,这意味着最左边的列会查看最右边的列来计算邻居,反之亦然。同样的情况也适用于顶部和底部的行。
图 5.4 展示了三个随时间变化的示例。第一个示例是一个永恒地改变方向从垂直到水平的破折号。第二个是一个稳定的盒子,第三个完全死亡。

图 5.4 使用生命游戏的标准规则三个示例
我们将使用一个名为 QuadLife 的扩展,²,其中每个活细胞可以有四种不同的状态:红色、绿色、蓝色和黄色。我更喜欢这个扩展,仅仅因为它看起来更酷。它包括两个新的规则:
-
如果某个颜色在邻居中占多数,那么这个颜色将成为新细胞的颜色。
-
如果所有三个活邻居颜色都不同,新细胞将采用四种可能颜色中的剩余颜色。
与之前的 Cython 示例一样,我们的实现将包括两个组件:调用 Python 代码,其中计算密集的部分用 Cython 实现。
Python 部分现在应该很熟悉且简单。它在仓库(05-cython/sec6-quadlife)中可用。这里提供了带有注释说明的内容:
import sys
import numpy as np
import pyximport
pyximport.install( ①
language_level=3,
setup_args={
'include_dirs': np.get_include()})
import cquadlife as quadlife
SIZE_X = int(sys.argv[1]) ②
SIZE_Y = int(sys.argv[2])
GENERATIONS = int(sys.argv[3])
world = quadlife.create_random_world(SIZE_Y, SIZE_X) ③
for i in range(GENERATIONS): ④
world = quadlife.live(world)
① 我们设置 pyximport 以包含 NumPy。
② 我们从命令行读取参数。
③ 我们使用一个(稍后定义)的函数来创建一个随机世界。
④ 我们根据用户指定的代数应用 Quadlife 算法。
我们通过传递所需的 X 和 Y 分辨率以及代数数来调用此脚本。目前脚本不输出任何内容,它只是运行游戏;稍后我们将对结果做一些有趣的事情。
首先,我们将使用 create_random_world 生成一个随机世界,这对于测试来说已经足够好了;稍后我们将考虑更好的替代方案。我们将使用用户指定的 SIZE_Y、SIZE_X 维度的 NumPy 数组。它将被填充在 0 和 4 之间的随机值。0 代表一个死细胞。然后我们运行名为 live 的模拟函数 GENERATIONS 次:第一次调用获取随机世界,然后其输出依次传递给自己。代码中没有展示任何概念上的新内容,应该很容易理解。
现在我们考虑我们的 Cython 代码。创建初始随机世界实际上并不需要优化,因为它只在开始时被调用一次:
#cython: language_level=3
import numpy as np
cimport cython
cimport numpy as cnp
def create_random_world(y, x):
cdef cnp.ndarray [cnp.uint8_t, ndim=2] world =
np.random.randint(0, 5, (y, x), np.uint8)
return world
现在有趣的部分开始了。我们的实现将包括概念上新的技术,但我们将在此基础上构建我们在前面章节中学到的内容。
5.6.1 绕过 GIL 对同时运行多个线程的限制
首先,我们想要确保我们的内部循环可以无 GIL(全局解释器锁)。为此,我们创建了一个 Cython 最高级 live 函数,它主要将 NumPy 数组转换为内存视图:
def live(cnp.ndarray[cnp.uint8_t, ndim=2] old_world):
cdef int size_y = old_world.shape[0]
cdef int size_x = old_world.shape[1]
cdef cnp.ndarray[cnp.uint8_t, ndim=2] extended_world =
np.empty((size_y + 2, size_x + 2), dtype=np.uint8) # empty
cdef cnp.ndarray[cnp.uint8_t, ndim=2] new_world =
np.empty((size_y, size_x), np.uint8)
cdef cnp.ndarray[cnp.uint8_t, ndim=1] states = np.empty((5,), np.uint8)
live_core(old_world, extended_world, new_world, states)
return new_world
转换为内存视图将由 live_core 函数签名强制执行(见以下讨论),但我们仍然需要一个可以将 Python 对象转换为可能无 GIL 表示的层的。old_world 是输入世界;new_world 将有输出。extended_world 和 states 是 live_core 内部变量,我们将在这里预分配。在我们展示 live_core 中的核心算法之前,让我们讨论我们将如何算法优化其一部分。
在生命游戏中,棋盘的边缘是相连的;例如,最左列的细胞将“观察”最右列的细胞的状态来计算它们的新状态。为了避免大量的边缘情况测试,这将增加许多 if 语句并因此增加计算时间,我们将在之前提到的变量 extended_world 中实现一个临时的扩展世界,其维度为 (y+2, x+2)。扩展边界会复制另一侧发生的情况,如图 5.5 所示。

图 5.5 用于计算新世界的扩展棋盘
这个算法的目的是在计算新棋盘时允许一种略微更有效的方法:我们不需要if语句来处理边界条件。这是以内存为代价完成的:我们现在需要存储一个新的大版本的棋盘。在高性能计算问题中,我们经常需要做出这类权衡(即内存与计算之间的权衡)。很难提出通用的指导方针来决定权衡。它将取决于特定算法的计算和内存成本以及你拥有的资源。
以下是实现这个扩展世界的代码。请注意,代码没有边界测试,因此减少了使用if语句的计算时间:
@cython.boundscheck(False) ①
@cython.nonecheck(False)
@cython.wraparound(False)
cdef void get_extended_world( ②
cnp.uint8_t[:,:] world,
cnp.uint8_t[:,:] extended_world): ③
cdef int y = world.shape[0]
cdef int x = world.shape[1]
extended_world[1:y+1, 1:x+1] = world ④
extended_world[0, 1:x+1] = world[y-1, :] # top
extended_world[y+1, 1:x+1] = world[0, :] # bottom
extended_world[1:y+1, 0] = world[:, x-1] # left
extended_world[1:y+1, x+1] = world[:, 0] # right
extended_world[0, 0] = world[y-1, x-1] # top left
extended_world[0, x+1] = world[y-1, 0] # top right
extended_world[y+1, 0] = world[0, x-1] # bottom left
extended_world[y+1, x+1] = world[0, 0] # bottom right
① 使用关闭边界、None 和环绕检查
② 我们使用cdef来避免全局解释器锁(GIL)。
③ 我们在函数签名上对一切进行类型化。
④ 扩展 _world 中间的 world 副本可能很昂贵。
在extended_world中间的world副本可能需要高昂的计算和内存代价,但计算部分可能通过更简单的核心算法得到补偿。³ 但至少从教学目的来看,它使核心算法大大简化,这对于学习目的来说很重要。
你可能会注意到前一个函数中的许多行看起来可以用更便捷的记法来编写。例如,也许
extended_world[1:y+1, 1:x+1] = world
可能可以写成:
extended_world[1:-1, 1:-1] = world
然而,我们发现我们无法进行这类重写,因为当我们关闭环绕检查以避免支付生成的 C 代码的代价时,我们必须花费时间在环绕验证上。此外,关闭环绕检查意味着我们无法使用负索引。这种权衡——即无法编写某些惯用语——是值得的:环绕检查需要 CPython 机制,这会减慢速度,因此我们的没有环绕检查的实现要快得多,并且我们需要关闭它以释放 GIL,因为环绕检查使用了 Python 机制。
警告:不进行环绕或边界检查可能会导致你的代码出现段错误。如果你看到这些错误,请确保在开发期间关闭装饰器。你的代码必须足够健壮,能够容忍移除这些和其他检查。
我们还使用了之前讨论过的优化:cdef,参数和变量的完整类型化,以及使用内存视图代替 NumPy 数组。如果你使用cython -a cquadlife.pyx,你将不会在浏览器版本的先前代码中看到表示 Python 交互行的黄色线条。
主要改变状态的实施利用了扩展世界。接下来的代码实现了 QuadLife 游戏规则。因为它相当长,我们将仔细注释,包括我们可能之前已经解决的问题。
@cython.boundscheck(False) ①
@cython.nonecheck(False)
@cython.wraparound(False)
cdef void live_core( ②
cnp.uint8_t[:,:] old_world, ③
cnp.uint8_t[:,:] extended_world,
cnp.uint8_t[:,:] new_world,
cnp.uint8_t[:] states): ④
cdef cnp.uint16_t x, y, i ⑤
cdef cnp.uint8_t num_alive, max_represented
cdef int size_y = old_world.shape[0]
cdef int size_x = old_world.shape[1]
get_extended_world(old_world, extended_world) ⑥
for x in range(size_x):
for y in range(size_y):
for i in range(5):
states[i] = 0
for i in range(3):
states[extended_world[y, x + i]] += 1
states[extended_world[y + 2, x + i]] += 1
states[extended_world[y + 1, x]] += 1
states[extended_world[y + 1, x + 2]] += 1
num_alive = states[1] + states[2] +
states[3] + states[4] ⑦
if num_alive < 2 or num_alive > 3:
# Too few or too many neighbors
new_world[y, x] = 0
elif old_world[y, x] != 0:
# Stays alive
new_world[y, x] = old_world[y, x]
elif num_alive == 3: # Will be born
max_represented = max(states[1],
max(states[2], max(states[3],
states[4]))) ⑧
if max_represented > 1:
# majority rule for color
for i in range(1, 5):
if states[i] == max_represented: ⑨
new_world[y, x] = i
break
else:
# diversity - use whichever color doesn't exist
for i in range(1, 5):
if states[i] == 0: ⑩
new_world[y, x] = i
break
else:
new_world[y, x] = 0 # stays dead
① 我们关闭了很多检查机制:边界和 None,以及环绕检查。
② 我们使用 cdef 来避免传递标准 Python 对象。我们还声明返回类型为 void,这在 C 语言中表示什么都没有。
③ 我们输入所有参数。
④ 一些内部变量(states 和 extended_world)是在外部分配的,我们使用了可用的内存。
⑤ 我们输入所有局部变量。
⑥ 当我们调用 get_extended_world 时,所有内容都是预先分配的。
⑦ 实现 sum(states[:1])
⑧ 实现 max(states[:1])
⑨ 实现 states[1:].index(max_represented)
⑩ 实现 states[1:].index(0)
这个函数很复杂,但其中许多技术是在之前引入的;在这里,它们在一个更现实的例子中被结合起来。所以仔细阅读代码和注释,您就会明白一切。
您可能会在代码中看到一些奇怪之处——特别是用非声明性版本替换了sum和index。我们这样做是因为sum和index会使用 CPython 机制,而我们想避免这种情况。对于max函数,也有类似的论点,但在那种情况下,替换版本无法在一次调用中比较所有值。当您使用通用函数时,您可能想对它们进行性能分析,并可能用优化的非通用函数来替换它们。
注意:由于生命游戏可以产生漂亮的演变可视化效果,我们将制作一个简单的图形用户界面。我们将使用 Python 内置的tkinter模块来创建 GUI,以及外部库 Pillow 来进行图像处理。我们不会在这里讨论代码,因为它超出了本书的范围,但您可以在05-cython/sec6-quadlife/ gui.py目录下的仓库中找到它。
我们的实施是完整的,但现在我们想衡量我们从代码中获得的速度提升。
5.6.2 基本性能分析
您可以在仓库中找到原生 Python 版本。我们将使用这个版本来与 Cython 版本进行一些基本比较。在我的电脑上,以 1000 × 1000 的分辨率运行 Python 版本,200 代需要略少于 1000 秒,这低于 17 分钟。Cython 代码需要 2.5 秒。
警告:我们的实现是内存密集型的。如果您用高分辨率测试它,请务必小心。实际上,本书的一个主题是考虑算法使用的内存量:如果它们可以在内存中运行,它们将比需要磁盘存储进行持续计算的速度快得多。只要可能,我们将尝试使用内存中的算法。如果不可以,我们通常将不得不优化存储以使处理高效进行。
现在,让我们考虑一个非常大的地图,400 × 900,000,只运行四代。在我的电脑上,这需要 44 秒。现在,如果我们为相同的四代运行 900,000 × 400 的转置地图,你认为需要多少时间?总细胞数和代数相同,但结果却大相径庭。它只需要 20 秒。在理论层面上明显相同的问题,结果却截然不同。对这个令人困惑的差异的答案将在第六章中探讨。更奇怪的是,你在你的电脑上得到的关系可能与我电脑上得到的关系完全不同。
在我们讨论本章的最后一个主题,即 Cython 的无 GIL 多线程之前,让我们先快速地从一个 Game of Life 生成一个酷炫的视频。这个过程将使我们考虑计算复杂度,并思考理论在帮助我们编写更高效程序中的作用。
5.6.3 使用 Quadlife 的空间战示例
在存储库中,你可以找到使用包含“飞船”和“防御”的起始状态的代码来生成视频。代码本身对于优化目的并不太相关,所以我们在这里不讨论它。如果你想复制它,你需要 Python 的 Pillow 库进行图像处理和 ffmpeg 生成视频。你可以找到生成它的主要 shell 脚本在05-cython/generate_video.sh中。图 5.6 显示了起始状态,颜色已反转。

图 5.6 QuadLife 模拟的起始状态视频
视频可在www.youtube.com/watch?v=E0B1fDKU_MI找到。图书馆(www.conwaylife.com)也允许你获取视频中使用的类似图案。
你可以使用存储库中的代码(05-cython/patterns.py)来生成类似的电影。此代码将在 400 × 250 的地图上运行太空船模型的 Game of Life 400 代,耗时不到 1 秒。以 1920 × 1080 的 HD 分辨率运行相同的 400 代需要大约 11 秒;800 帧需要 22 秒;而在 3840 × 2160 的 4K 分辨率下,400 帧需要 48 秒。以每秒 40 帧的速度,90 分钟的 4K 游戏大约需要 196 分钟来生成。但好处是,同样的视频在纯 Python 解决方案中生成需要 54 天。
计算复杂性的作用
虽然这不是一本理论书籍,但无法否认计算复杂性和其基础理论的重要性。这涉及到算法消耗的资源——通常,但不仅限于时间和内存。
例如,我们系统的运行时间成本随着我们计算的代数数量的线性增长。但如果我们计算的是一个边长为 n 的正方形世界,增长是二次的:20 的平方不是 10 的平方的两倍慢;实际上要慢四倍。同样,一个大小为 200 的平方比一个 10×10 的平方慢 400 倍(而不是 20 倍)。在这种情况下,算法对于内存需求也是二次的。
在大数据不断增长的世界中,这意味着一些算法的扩展性可能非常差,最终可能需要被完全不同的解决方案所取代。在这本书中,我们不会正式讨论计算复杂度理论,但有时我们不得不解决其背后的直觉。
我们仍然有足够的动力做得更好。因为我们的代码不与 Python 对象交互,我们可以释放 GIL 并使用多个线程进行真正的并行。让我们在最后一节中这样做。
5.7 使用 Cython 进行并行处理
在我们已经做了清理 GIL 绑定代码的所有准备工作之后,引入多进程解决方案现在相当直接。我们的方法将利用 Cython 的内部并行功能,代码(在05-cython/sec7-parallel中可用)相当简单。
Cython 提供了基于 OpenMP 的声明式并行函数。OpenMP 是一个多平台库,提供并行原语。它提供的一个函数是并行范围函数,该函数将对for循环的内容进行多线程处理;使用它相当简单:
from cython.parallel import prange ①
@cython.boundscheck(False)
@cython.nonecheck(False)
@cython.wraparound(False)
cdef void live_core(
cnp.uint8_t[:,:] old_world,
cnp.uint8_t[:,:] extended_world,
cnp.uint8_t[:,:] new_world,
cnp.uint8_t[:] states) nogil: ②
cdef cnp.uint32_t x, y, i
cdef cnp.uint8_t num_alive, max_represented
cdef int size_y = old_world.shape[0]
cdef int size_x = old_world.shape[1]
get_extended_world(old_world, extended_world)
for x in prange(size_x): ③
for y in range(size_y):
...
① 我们导入 prange 函数。
② nogil 现在是强制性的。
③ 我们只需将 range 替换为 prange。
就这么简单。当然,记住,通过清理所有与 GIL 相关的代码,我们在到达这一步之前已经做了大部分工作。我们还需要记住注释get_extended_world函数:
@cython.boundscheck(False)
@cython.nonecheck(False)
@cython.wraparound(False)
cdef void get_extended_world(
cnp.uint8_t[:,:] world,
cnp.uint8_t[:,:] extended_world) nogil:
...
Cython 提供了一些函数来覆盖 OpenMP,使并行代码更容易编写。这在大多数情况下都非常有用。基本要求是清除所有与 GIL 相关的调用。
我们主要关注 GIL 的 Python 和线程并行之间的交互。我们也关注基于 OpenMP 的 Cython 原语来编写并行代码。房间里的大象是整个并行处理领域。在这里,我们提供了在 Python 空间中释放真正并行线程处理的基本构建块。但并行编程技术通常是一个独立的主题,你应该咨询其他资源。
注意,实际上,使用 Cython 你是在 C 范式内部。虽然没有理由忽略 Cython 的 OpenMP 功能,但请记住,你并不局限于它。你可以使用其他基于 C 的并行库:只是你需要自己走得更低级一些。
概述
-
原生 Python,CPython,不足以实现复杂操作的最快代码。
-
有许多选项可以加速你的基于 Python 的代码:使用优化库、底层语言、Numba,甚至其他 Python 实现,如 PyPy。
-
Cython 是 Python 的超集,编译成 C,提供类似 C 的速度,而无需学习新语言。
-
Cython 可以以类似于 Python 代码的方式进行分析。
-
编写高效的 Cython 代码需要注释 Cython 变量以提供类型提示,这些提示与标准 mypy 的提示不同,有时还需要分析 Cython 生成的 C 代码。
-
Cython 提供的 C 代码浏览器允许你轻松地识别与 Python 解释器交互的行,因此它们是重写以更高效方式执行的潜在候选者。
-
你应该尽可能地减少与 CPython 的交互,甚至应该考虑重构你的代码,使得实现中的昂贵内循环不与 CPython 交互。这可以轻易地以数个数量级加速你的 Cython 代码。
-
Cython 与 NumPy 集成,允许高效地操作数组。资源,如 memoryviews,允许 Cython 和 NumPy 之间直接通信,从而消除了低效的 Python 解释器作为中间人的作用。
-
与 CPython 无关的代码是迈向 GIL 无关代码的第一步。如果我们能够释放 GIL,我们就可以从我们的 Cython 代码中使用并行多线程。
-
记住考虑 Numba 作为 Cython 的替代方案:在许多情况下,它更容易使用,尽管不如 Cython 可定制。
¹ 关于如何测量时间的详细信息,请参阅第二章。使用 IPython/Jupyter,你可以使用%timeit魔法。使用标准 Python,你可以使用timeit模块。或者,就像我在这个例子中所做的那样,我简单地测量了进程的运行时间。在某些情况下,一个粗略的近似值可能是一个好的开始。
² 你可以在 LifeWiki(conwaylife.com/)上找到关于游戏变体和生命游戏一般信息的大量资料。
³ 要确定是否得到补偿,需要进行仔细的分析,你现在知道如何做了。
6 内存层次结构、存储和网络
本章涵盖
-
有效利用 CPU 缓存和主内存
-
使用 Blosc 访问压缩数组数据
-
使用 NumExpr 加速 NumPy 表达式
-
为非常快速的网络设计客户端/服务器架构
众所周知,硬件会影响性能。但硬件如何与性能互动并不总是那么明显。本章的目标是帮助您更好地理解您的机器如何影响您的速度,以及您可以在硬件端做些什么来提高性能。为此,我们将仔细研究现代硬件和网络架构对 Python 高效数据处理的影响。
由于硬件考虑,软件开发中存在许多反直觉的后果。例如,有相当多的情况,处理压缩数据比处理未压缩数据要快。传统观点认为,解压缩和分析数据的花费将远远高于仅分析数据。毕竟,当我们解压缩时,我们是在增加更多的计算。那么,如何才能在计算上更有效率呢?事实证明,现代硬件架构可以对“明显”的观察进行一些小把戏。
为了充分利用现代硬件的性能,我们需要理解是什么使得一些默认假设如此反直觉。为了获得这种理解,我们将从性能角度出发,对现代计算机架构进行介绍。这个主题本身就可以写成一本书,但我们将专注于不那么直观的特性:我们将研究内存层次结构,从 CPU 缓存到广域网,经过 RAM、硬盘和本地网络。
我们对使计算更快、存储处理更高效感兴趣,从大小和速度两个角度来看。在理解了现代硬件架构的一些影响之后,我们将看到一些 Python 库如何充分利用硬件。首先,我们将探索 Blosc,一个用于压缩二进制数据的高性能库,如何被用来生成与未压缩数组访问时间大致相同的紧凑表示的 NumPy 数组。在我们分解这个过程时,您将看到智能地使用 CPU 缓存可以使压缩和解压缩时间几乎无关紧要。然后,我们将探讨 NumExpr 如何通过智能处理数据的方式,在缓存意识上加速处理非常大的 NumPy 数组。
最后,我们将转换话题,讨论在基于非常快速本地网络的集群或云上执行计算的影响。我们用于进行数据分析的大部分代码都是在集群或云上运行的,这些集群或云可能基于这些类型的网络实现,因此了解如何进行这一操作是有用的。
解释性能
因为这是一本硬件相关的章节,所以你得到的结果可能与我的不同,因为你的硬件与我的不同。可能适合我机器缓存的,可能不适合你的。此外,如果你在带有用户界面的机器上运行此代码,由于所有其他进程与你代码的同时运行,缓存使用将主要不可预测。
这里展示的所有基准测试都是在没有用户界面或其他大型进程运行的服务器上运行的。规格如下:Intel Xeon 8375C CPU @ 2.90GHz,32 核心,L1 缓存 2 MB,L2 缓存 40 MB,L3 缓存 54 MB,DRAM 16 GB。我们将在 NumExpr 部分给出一个具体示例,说明结果如何因你的硬件而大幅变化。
让我们从现代硬件架构的回顾开始,重点关注可能对高效 Python 编码产生反直觉后果的问题。这一章需要安装 Blosc (conda install blosc)。如果你使用 Docker,主镜像包含了你需要的一切。
6.1 现代硬件架构如何影响 Python 性能
在本节中,我们将概述当前硬件架构的现状,重点关注它们对高效 Python 开发的较少直观但至关重要的影响。硬件架构包括计算机内部的内容——CPU、内存和本地存储——以及网络。当我们查看本地存储,尤其是网络架构时,我们有时也会涉猎系统软件架构问题:即文件系统和网络协议。再次强调,这里的话题可以轻易地填满几本书,因此我们将我们的关注点缩小到对 Python 性能有直接影响的问题,并且有 Python 库可以解决这些问题。
我们将从看似微不足道的例子开始,这个例子将作为动机,并希望说服你理解这些硬件和系统问题如何影响性能——以及你可以做些什么。如果你对在某些操作中获得高达两个数量级的性能提升感兴趣,其中你原本预期没有任何提升,请继续阅读。
6.1.1 现代架构对性能的反直觉影响
我们的简单例子将是简单地取一个 NumPy 平方 矩阵并复制一行和一列的值。从性能的角度来看,加倍一行应该与加倍一列花费的时间完全相同,前提是矩阵是平方的。这很明显!或者不是吗?
为了找出答案,让我们评估加倍一行和加倍一列的性能成本:
import numpy as np
SIZE = 100
mat = np.random.randint(10, size=(SIZE, SIZE))
double_column = 2 * mat[:, 0] ①
double_row = 2* mat[0, :]
① 在 Jupyter 上的 IPython 允许我们在最后两行之前添加 %timeit 来进行性能分析。
我们创建一个介于 0 和 9 之间的随机矩阵。我们开始时的大小为 100。稍后,我们将改变这个大小,使用 1000 和 10,000。
再次,请注意矩阵是方阵。这意味着 double_column 和 double_row 将需要相同数量的操作。常识表明这是一个微不足道的问题(即不值得我们花费时间),并且加倍一列或加倍一行的时间成本应该大致相同。在这种情况下,常识是错误的。
让我们从之前的代码开始,一个大小为 100 的方阵——因此,有 10,000 个元素。鉴于默认整数表示为 8 字节,我们有 80 KB。在我的电脑上,加倍一列的平均时间为 750 纳秒;加倍一行,715 纳秒。没有太大的差异,考虑到操作的粒度,差异可能是由用于分析此操作的仪器引起的。到目前为止,一切都很明显。
让我们增加矩阵的大小到 1000,因此有 100 万个元素和 8 MB。现在我们有 1.99 微秒和 1.5 微秒。再次,并没有什么真正引人注目的。
让我们将大小增加到 10,000。这样一个矩阵将占用大约 800 MB,所以请确保你有足够的内存来执行此操作。加倍一列需要 4.51 微秒;加倍一行,74.9 微秒!
对于这个更大的矩阵,加倍一列比加倍一行快 16 倍。让这个事实沉淀下来:有关硬件架构和 NumPy 内部表示的某些方面使得两个看似相同的操作在性能上相差一个数量级以上!
这里有两个问题在起作用。一个是 CPU 缓存和主存储器之间的关系。另一个是矩阵的内部表示。它们共同导致了性能差异。我们将在本章后面深入探讨这两个问题。
6.1.2 CPU 缓存如何影响算法效率
我们首先考虑瞬态记忆。我们通常从 DRAM 的角度思考,但计算实际上发生在 CPU 寄存器(即最低级别的内存)中,并经过 CPU 缓存的几层。表 6.1 展示了一个现代机器可能的样子。
表 6.1 假设但现实的现代桌面计算机的内存层次结构,大小和访问时间
| 类型 | 大小 | 访问时间 |
|---|---|---|
| CPU | ||
| L1 缓存 | 256 KB | 2 ns |
| L2 缓存 | 1 MB | 5 ns |
| L3 缓存 | 6 MB | 30 ns |
| RAM | ||
| DIMM | 8 GB | 100 ns |
L1 缓存时间接近现代 CPU 的周期速度。记住,2 GHz 意味着 2 × 10⁹ 个周期/秒,而纳秒是 10^(-9) 秒。
如果 CPU 需要的数据可以在 L1 缓存中找到(命中率),则速度将匹配。然而,使用 DRAM 意味着 CPU 将长时间空闲:90% 的时间可能只是在等待数据读取,这并非不可能。
现在,我们可以解释我们的原始示例:为什么方阵中列的加倍与行的加倍可能具有完全不同的时间成本。所以,如果你有一个像
| I11 | I12 | I13 | I14 |
|---|---|---|---|
| I21 | I22 | I23 | I24 |
| I31 | I32 | I33 | I34 |
| I41 | I42 | I43 | I44 |
它必须在内存中顺序表示:
| I11 | I12 | I13 | I14 | I21 | I22 | I23 | I24 | I31 | I32 | I33 | I34 | I41 | I42 | I43 | I44 |
|---|
当你访问元素 I11 时,CPU 会带来一些额外的元素进入内存,而不仅仅是单个元素。所以,如果你做 2I11,2I12,2I13,2I14,从内存到缓存的移动将只有一次。但是,如果你做 2I11,2I21,2I31,2I41,因为它们不是连续的,每次你进行操作时都会有内存移动,这相对是一个非常昂贵的操作。所以,第一种情况是四个加倍操作加一次内存移动,而第二种情况是四个加倍操作加四次内存移动。
当然,我们的例子是一个简化。根据矩阵的大小和缓存的大小,CPU 可能一次就能将所有数据都带入;这就是为什么我们看不到非常小的矩阵有差异。但是,如果矩阵足够大,这种影响就会变得非常明显,以至于可以使一个操作比另一个操作贵一个数量级。
提示:在矩阵表示方面,还有一个需要考虑的问题:我们可以连续表示每一行,或者连续表示每一列。前者在基于 C 的代码中很常见,而后者在基于 Fortran 的代码中很常见。这对我们来说非常重要,因为 NumPy 的后端可以用这两种语言之一实现,因此我们必须注意后端实现,以便设计如何访问数据。
下两个部分将演示如何使用两个库,Blosc 和 NumExpr,来有效地利用 CPU 缓存。
6.1.3 现代持久存储
另一个潜在的问题领域是持久存储。最常见的是本地存储,无论是硬盘驱动器(HDD)还是固态驱动器(SDD)。持久性内存的访问速度比瞬态内存慢得多:SSD 的访问时间在微秒级别,而 HDD 在毫秒级别。虽然我们在这里不会进一步探讨这个话题(尽管这些话题中的某些将在第八章以不同形式出现),请注意,下一节中介绍的瞬态内存的技术同样适用于存储。例如,有些情况下,处理压缩文件比处理原始文件更快;解压缩的成本可以大大低于从磁盘读取更多(原始)数据。
除了瞬态内存和本地持久内存之外,我们还有远程存储和远程计算。从理论上讲,远程存储和远程计算比本地存储慢得多。例如,当你访问互联网上的存储服务器时,访问时间很长且不可预测。尽管如此,现代本地计算集群可以拥有非常快的骨干网络。有多快?访问远程服务器可能比从本地磁盘获取数据更快!我们将在本章的最后部分讨论这一点的影响。正如你将看到的,当我们在一个本地快速网络上工作时,我们用来访问远程网络服务的标准网络协议可能不够快。
吸取的经验
我希望你能从这一节中得到这样的认识:关于计算和内存局部性的某些旧假设可能是错误的。正如我们所见,同样的操作可能具有截然不同的成本,这取决于内存是如何分配的。因此,如果我们想提高 CPU 效率,我们需要确保算法所需的信息尽可能接近 CPU。此外,DRAM 的接近性还不够,因为访问它可能导致 CPU 饥饿,并使 CPU 空闲多个周期。尽可能让数据在 L1 缓存中可用是我们的目标。
但这个洞还要深:有时在运行时解压缩数据(即使用昂贵的解压缩算法)可能比使用原始数据更快。这正是 Blosc 允许我们做到的,我们将在下一节中探讨这一点。
6.2 使用 Blosc 高效存储数据
Blosc 是一个高性能的压缩框架,旨在使处理压缩数据比处理其未压缩版本更快。这怎么可能呢?记得上一节提到的,如果需要处理的数据位于 DRAM 的较远位置,CPU 大部分时间都会处于饥饿状态。如果我们用于(解)压缩数据的 CPU 周期数足够少,以至于它们发生在 CPU 饥饿时间内,那么压缩实际上就是“免费的”。
6.2.1 压缩数据;节省时间
为了了解在某些情况下处理压缩数据如何比处理原始数据更快,我们将探讨三种创建 NumPy 数组的替代方法,然后使用 NumPy 和 Blosc 将它们存储到磁盘上并检索它们。我们将研究每种方法的时序和磁盘空间影响。这比看起来要复杂得多。
让我们从创建数组和支持函数开始:
import os
import blosc
import numpy as np
random_arr = np.random.randint(
➥ 256, size=(1024, 1024, 1024)).astype(np.uint8)
zero_arr = np.zeros(shape=(1024, 1024, 1024)).astype(np.uint8)
rep_tile_arr = np.tile(
np.arange(256).astype(np.uint8),
4*1024*1024).reshape(1024,1024,1024)
def write_numpy(arr, prefix):
np.save(f"{prefix}.npy", arr) ①
os.system("sync") ②
def write_blosc(arr, prefix, cname="lz4"):
b_arr = blosc.pack_array(arr, cname=cname) ③
w = open(f"{prefix}.bl", "wb")
w.write(b_arr)
w.close()
os.system("sync") ②
def read_numpy(prefix):
return np.load(f"{prefix}.npy")
def read_blosc(prefix):
r = open(f"{prefix}.bl", "rb")
b_arr = r.read()
r.close()
return blosc.unpack_array(b_arr) ④
① NumPy 可以原生处理磁盘持久性。
② sync 强制磁盘刷新。
③ 如果我们想在 Blosc 中写入 NumPy 数组,我们需要将其打包。
④ 如果我们想在 Blosc 中读取 NumPy 数组,我们需要将其解包。
我们首先创建三个数组:一个只包含零,另一个包含最多 256 的瓦片值,还有一个包含随机值。我们使用这三种数组类型,因为它们代表了关于压缩的非常不同的情境:零数组很容易压缩,瓦片数组平均压缩时间,而随机数组基本上无法压缩。
然后,我们创建一组辅助函数来读取和写入磁盘上的数组。在写入方面,这并不简单:因为我们想公平地基准测试写入函数的总成本,我们需要强制操作系统刷新缓冲区——因此,我们使用了sync。sync在 Windows 上不可用。
让我们现在基准测试写入部分:
os.system("sync")
%time write_numpy(zero_arr, "zero")
%time write_blosc(zero_arr, "zero")
%time write_numpy(rep_tile_arr, "rep_tile")
%time write_blosc(rep_tile_arr, "rep_tile")
%time write_numpy(random_arr, "random")
%time write_blosc(random_arr, "random")
我们首先调用sync,这将尽可能清理操作系统的 IO 缓冲区。然后我们使用%time来计时写入函数。虽然对于写入部分来说可能安全,至少对于写入部分,我们可以使用%timeit,但我们想避免任何操作系统优化我们的调用的可能性,这会使基准测试难以解释。表 6.2 提供了时间结果。
表 6.2 使用 NumPy 和 Blosc 不同数组类型写入时间的秒数
| 数组 | NumPy | Blosc |
|---|---|---|
zero |
7.49 | 0.53 |
rep_tile |
7.49 | 0.53 |
random |
7.5 | 8.13 |
对于zero和rep_tile数组,Blosc 大约快 15 倍。对于random数组,NumPy 略快。哪种情况更常见?与现实最不接近的情况是随机情况:表格中的数据往往有一定的模式。对于zero和rep_tile情况,并不过于乐观。
因此,出人意料的是,Blosc 在时间效率方面更高效。但关于磁盘空间占用呢?这是对于非常大的数据集的一个重要指标,我们预计 Blosc 将优于其他。rep_tile是 200 倍小,zero是 250 倍小,而random大小相同。
6.2.2 读取速度(和内存缓冲区)
现在我们将检查读取速度。从理论上讲,这只是读取文件的问题,对吧?然而,鉴于我们已经将它们写入磁盘,操作系统可能将它们存储在中间内存缓冲区中,从而提供了一个有偏差的性能视图。换句话说,缓存可能会使分析不可靠。为了进行公平的比较,我们需要确保我们真正是从磁盘而不是从临时内存缓冲区中读取,这会快得多。因此,我们必须刷新缓冲区。
解决这个问题的粗暴方法是通过重启计算机。另一种方法是告诉操作系统使所有缓存失效。不幸的是,这取决于操作系统。在这里,我将给你一个在 Debian/Ubuntu 及其衍生版本上这样做的方法。这不能在 Windows 或 Mac 上工作,或在某些其他 Linux 发行版上。如果你处于这种情况,你将不得不调查如何使用你的操作系统来完成它。
作为 root 用户,使用以下命令:
sync; echo 3 > /proc/sys/vm/drop_caches
现在,我们可以预期数据不在瞬态缓冲区内存中读取:
%time _ = read_numpy("zero")
%time _ = read_blosc("zero")
%time _ = read_numpy("rep_tile")
%time _ = read_blosc("rep_tile")
%time _ = read_numpy("random")
%time _ = read_blosc("random")
时间信息在表 6.3 中提供。
表 6.3 使用 NumPy 和 Blosc 的不同数组时间读取时间(秒)
| 数组 | NumPy | Blosc |
|---|---|---|
zero |
7.02 | 0.63 |
rep_tile |
7.04 | 0.61 |
random |
7.37 | 8.58 |
这与写入时间有类似的模式。对于非随机数据,Blosc 明显优于 NumPy,这意味着你应该考虑它用于大数据量。
到目前为止,我们并没有太在意我们使用的是哪种压缩算法;我们只是接受了默认设置。但是 Blosc 允许你从许多算法中选择。
6.2.3 不同压缩算法对存储性能的影响
此处的目标不是对现有算法及其基准速度进行详尽的调查。相反,我只想让你知道存在不同的算法,未来可能会添加更多。这些算法的速度和效率因使用方式和时间而异。有了这种认识,你应该能够选择最适合你特定数据和需求的一个。为了说明性能的潜在变化,让我们比较两种算法:LZ4 和 Zstandard。
到目前为止,我们将停止将数据写入磁盘,因为现在应该很明显,对其进行基准测试很繁琐。我们只进行内存操作;也就是说,我们将使用 BLOSC 压缩数据,因为我们已经看到了其优越的性能。首先,我们将使用 LZ4,然后是 Zstandard:
%timeit rep_lz4 = blosc.pack_array(rep_tile_arr,
cname='lz4') ①
rep_lz4 = blosc.pack_array(rep_tile_arr, cname='lz4')
%timeit rep_std = blosc.pack_array(rep_tile_arr,
cname='zstd') ②
rep_std = blosc.pack_array(rep_tile_arr, cname='zstd')
print(len(rep_lz4) // 1024)
print(len(rep_std) // 1024)
① 我们使用 LZ4 作为压缩算法创建一个内存表示。
② 我们使用 Zstandard 作为压缩算法创建一个内存表示。
表 6.4 提供了时间和大小结果。
表 6.4 使用 LZ4 和 Zstandard 的压缩时间和大小
| LZ4 | Zstandard | |
|---|---|---|
| 时间 (ms) | 527 | 919 |
| 大小 (KB) | 5204 | 366 |
如果你回忆一下上一节,LZ4 表示形式比标准的 NumPy 表示形式小 200 倍。这意味着 Zstandard 压缩算法比 NumPy 小 2800 倍(200 倍乘以 14——LZ4 和 Zstandard 的比例)。
Blosc 还有额外的技巧:它不仅提供各种算法,还允许你在运行时更改输入数据的表示。这可以进一步减小压缩数据的大小。让我们看看它是如何工作的。
6.2.4 利用数据表示的见解来提高压缩
假设你知道你的数据有一些规律性模式;例如,序列中的数字很常见。比如说,你的数据包括以下 8 位编码数字的序列:
3,4,5,6
这通常以二进制形式编码如下:
00000011/00000100/00000101/00000110
现在想象一下,你取每个数的最高位并对其进行编码,然后是次高位,依此类推,直到每个数的第 8 位。这将看起来像:
00000000000000000000011110011010
第二种模式看起来规律得多。这正是压缩器能够高效工作的原因。Blosc 允许你做 precisely this:
for shuffle in [blosc.BITSHUFFLE, blosc.NOSHUFFLE]:
a = blosc.pack_array(rep_tile_arr, shuffle=shuffle)
print(len(a))
打乱版本的文件大小为 4,600,034 字节,而未经处理的版本为 5,345,500 字节。打乱操作略微慢一些,596 毫秒比 524 毫秒。
吸取的教训
智能地使用内存层次结构和 CPU 处理可以加速一些基本的数组操作。Blosc 允许我们在许多情况下,比使用原始数据更快地访问压缩存储的数据表示。这除了通常的具有更小的持久数据集的好处之外。
让我们更进一步,在分析数据时使用类似的技术。为此,我们将探索 NumExpr 库。
6.3 使用 NumExpr 加速 NumPy
Blosc 是我们如何智能地使用内存层次结构来加速数据处理的例子之一。但我们可以通过使用 NumExpr 加速 NumPy 表达式来进一步采取这种方法。
NumExpr 是 NumPy 的数值表达式评估器,它可以比 NumPy 更快。它接受一个表达式——比如说,a + b——并计算其结果。但是等等:这有什么意义呢?NumPy 不是也在做这件事吗?确实,NumExpr 用一种试图在处理大型数据集时更有效地重新组织计算的引擎替换了 NumPy 的一些功能。NumExpr 使用的一种技术依赖于不生成表达式部分的全中间表示:计算是在设计用来适合 L1 缓存的块中进行的。
6.3.1 快速表达式处理
现在让我们看看几个 NumExpr 如何改变表达式评估性能的例子:
import numpy as np
import numexpr as ne
a = np.random.rand(100000000).reshape(10000,10000)
b = np.random.rand(100000000).reshape(10000,10000)
f = np.random.rand(100000000).reshape(10000,10000)
.copy('F') ①
%timeit a + a
%timeit ne.evaluate('a + a') ②
%timeit f + f
%timeit ne.evaluate('f + f')
%timeit a + f
%timeit ne.evaluate('a + f')
%timeit a**5 + b
%timeit ne.evaluate('a**5 + b')
%timeit a**5 + b + np.sin(a) + np.cos(a)
%timeit ne.evaluate('a**5 + b + sin(a) + cos(a)')
① 这个矩阵使用 Fortran 标准表示。
② NumExpr 提供了 evaluate 函数来处理表达式。
我们首先创建三个正方形矩阵来支持我们的性能评估。最后一个具有 Fortran 组织。%timeit基准测试的输出总结在表 6.5 中。
表 6.5 NumPy 与 NumExpr 执行时间比较。数值以毫秒为单位。
| 表达式 | Numpy 平均时间 | NumExpr 平均时间 | 加速比 |
|---|---|---|---|
| a + a | 224 | 58 | 3.8 |
| f + f | 224 | 58 | 3.8 |
| a + f | 577 | 153 | 3.7 |
| a**5 + f | 1690 | 87 | 19.4 |
| a**5 + f + sin(a) + cos(a) | 3840 | 153 | 25.1 |
在我们的硬件配置中,NumExpr 更高效。看看以不同格式表示的矩阵操作,C 与 Fortran 相比。当所有数组格式统一时,这些操作的成本更高。随着表达式的复杂化,NumExpr 的优势增加,因为更多的空间可用于优化技术。
然而,这些例子过于美化 NumExpr。有些情况下,NumExpr 可能会降低性能。本节的其余部分将展示一些缺点。我们将从由硬件引起的性能的定性变化开始。
6.3.2 硬件架构如何影响我们的结果
正如我在本章开头提到的,你的结果可能与这里显示的结果大相径庭。为了使这一点清楚,我将比较我用来编写此文本的机器(即运行 Linux GUI 和文本编辑器的机器)。我不会在这里列出我的 CPU 的缓存大小,因为这可能会产生误导:有如此多的进程同时访问缓存,以至于无论如何都无法估计 L1 缓存。表 6.6 比较了在服务器上使用 NumExpr 进行算术运算与我的笔记本电脑的性能提升速度。
表 6.6 硬件架构对性能的影响:硬件作为速度提升的函数
| 表达式 | 服务器速度提升 | 笔记本速度提升 |
|---|---|---|
| a + a | 3.8 | 0.7 |
| f + f | 3.8 | 0.8 |
| a + f | 3.7 | 1.3 |
| a**5 + f | 19.4 | 11.5 |
| a**5 + f + sin(a) + cos(a) | 25.1 | 6.7 |
使用 NumExpr 带来的性能优势严重降低。实际上,对于一些操作,NumExpr 的实现突然变得比 NumPy慢。其中一个主要原因是与典型本地机器上 CPU 缓存不可预测的可用性有关,因为许多进程正在运行并竞争缓存。
提示:不要期望在运行大量其他应用程序(例如所有基于 UI 的应用程序,如你的文本编辑器或浏览器)的本地机器上基于缓存优化获得大的速度提升。CPU 缓存将会有巨大的竞争,运行结果可能会有很大的差异。这些技术将在服务器上发挥作用,但在典型的开发机器上则不会。因此,在测试任何利用 CPU 缓存的技术优势时,请在服务器上进行测试。
如前例所示,并非所有场景都适合使用 NumExpr。让我们详细说明何时、何地以及为什么 NumExpr 不是最佳选择。
6.3.3 当 NumExpr 不适用时
有几种情况下 NumExpr 可能会产生不利影响。让我们来讨论一下。
最重要因素是数组的大小:NumExpr 通常在较大的数组上表现更好。让我们重复一些之前的例子,但这次使用小型数组:
small_a = np.random.rand(100).reshape(10, 10)
small_b = np.random.rand(100).reshape(10, 10)
%timeit small_a + small_a
%timeit ne.evaluate('small_a + small_a')
%timeit small_a**5 + small_b + np.sin(small_a) + np.cos(small_a)
%timeit ne.evaluate('small_a**5 + small_b + sin(small_a) + cos(small_a)')
使用 NumExpr 进行加法运算会慢 15 倍,而如果你使用 NumExpr 进行复杂表达式计算,速度仍然会慢 30%。然而,这并不是一个严重的问题:在大多数情况下,你不会试图优化小型数组:大数据,而不是小数据,才是我们的问题。
NumExpr 性能下降的另一个原因是当你在对运行其他进程的机器进行基准测试时,比如你自己的本地机器。NumExpr 在服务器上表现更好,特别是如果你可以控制运行的应用程序数量。这意味着在学术界常见的共享集群上,NumExpr 的性能会有所变化。最后,NumExpr 只支持 NumPy 操作符的一个子集,因此某些操作无法通过 NumExpr 得到提升。
既然我们已经讨论了几种优化使用瞬态内存的方法,我们将完全改变我们的焦点,并讨论本地网络。现代本地网络可以比访问本地持久存储更快,这个事实再次颠覆了一些常见的假设。
6.4 使用本地网络性能的影响
当我们为网络编写代码时,我们正在处理一个可能具有完全不同特性的基础设施。很多时候我们假设网络是远方的——存在速度、延迟和弹性问题。但在许多高性能场景中,当我们本地使用网络时,这些假设并不适用。现代网络交换机可以支持高达 2 Tb/s 的主干通信,每个网络端口高达 56 Gb/s。作为一个参考,大多数本地磁盘每秒支持 6 Gb/s。思考一下:在一个高性能的本地网络中,与另一台计算机交谈比与本地磁盘交谈要快。如果你处于节点之间有快速网络的情况,请继续阅读。
用于通信的典型软件网络框架完全不适用于处理现代本地网络的速度。你认为使用 HTTPS 上的 REST 调用查询本地磁盘会有效率吗?鉴于高性能本地网络比本地磁盘访问更快,我们必须找到一种更有效的方式进行通信。
在我们设计解决方案之前,让我们了解为什么标准方法将不会有效率。在本节中,我们将实现一个基于 REST 的非后端的服务。一个代码片段服务允许你将文本存储起来,以便通过互联网与他人分享。我们将编写一个客户端,用于发送文本进行存储和请求读取文本。我们还将编写服务器,用于存储文本并在请求时提供服务。有关此类实际服务的示例,请参阅pastebin.com/。我们将假设我们的客户端和服务器将在一个非常快速的本地网络上运行。
6.4.1 REST 调用中的低效来源
在我们设计一个高效解决方案之前,让我们先了解典型 REST 实现中的性能瓶颈。在 REST 中,客户端/服务器通信通常是通过 HTTPS 上的 JSON 有效载荷来完成的。JSON 是一种文本格式,因此需要解析时间,并且占用大量空间。HTTPS 通过使用公钥加密技术增加了 HTTP 协议的认证和加密。因此,HTTPS 在 HTTP 之上增加了大量的处理。
HTTP 协议在其之上的一个名为传输控制协议(TCP)的互联网协议上完成所有工作。TCP 在两个端点之间建立连接抽象——在我们的例子中,是客户端和服务器。连接确保数据按顺序到达且无丢失。但这个协议很重,至少对于我们的非常快的网络来说是这样:仅仅建立连接就需要至少三个数据包在客户端和服务器之间来回传递。
在我们建立了 TCP 连接之后,HTTP 的安全部分需要完成,这部分由传输层安全(TLS)协议委托处理。这个协议执行握手,这需要在客户端和服务器之间来回传递几个数据包。鉴于它涉及到加密,这将非常计算密集。请注意,计算时间成本是相对的:在一个非常快的网络中,它将是计算的一个很大部分;如果你在与世界另一端的服务器进行计算,同样的计算时间在整体计算中将是微不足道的。
现在我们已经准备好发送我们的有效载荷,这是一个需要文本解析的详细 JSON 有效载荷。最后,我们还需要在 HTTPS 和 TCP 级别关闭连接。
从消息的角度来看,通信至少需要交换 20 个网络数据包,可能更多。考虑到本地网络的速率,绝大多数的交换都将花费在协议上。让我们使用仅两个数据包来实现这一点:最基本的一组——一个用于请求,另一个用于响应。
6.4.2 基于 UDP 和 msgpack 的简单客户端
我们的实现将会非常简单。比起理解代码,理解我们的实现带来的权衡更为重要。
让我们从客户端开始。我们的客户端将向 pastebin 服务器发送文本,然后检索它。它开始如下:
import socket
host = '127.0.0.1'
port = 54321
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) ①
① 创建一个新的 UDP,SOCK_DGRAM 套接字
我们不是使用整个基于 TCP 的 HTTPS 堆栈,而是完全摒弃了应用协议,并用 UDP(用户数据报协议)替换了 TCP。UDP 不建立连接;它只是发送数据包。将 UDP 想象成邮政服务,而 TCP 则是电话服务。在邮政服务中,信件可能会丢失,顺序错误地交付,或者错误地路由。在电话通话中,流是按顺序交付的,没有丢失信息。从开销角度来看,UDP(邮政)比 TCP(电话)更轻。
之前的代码片段使用了低级模块 socket 来创建 UDP 通信端点。我们指定服务器地址为 127.0.0.1——在本例中,是本地机器的地址——我们将使用端口 54321。
这个解决方案隐含了一些假设,你需要确保这些假设适用于你的情况:
-
不使用加密通道,我们使我们的实现容易受到窃听和数据更改的影响。在本地高性能网络中通信时,这比在互联网上要小得多的问题。相当实用地,如果一个安全威胁可以访问你的网络骨干,你面临的问题比访问数据更大:你有一个被破坏的基础设施。
-
UDP 协议不保证数据包的交付。这意味着我们的解决方案可能在客户端和服务器之间丢失数据。在高性能的本地网络中,这个问题比在互联网上要少得多。尽管如此,它确实可能发生,我们将在本章的最后一个小节中解决这个问题。
让我们现在通过向服务器发送文本并检索它来完善我们的客户端:
import msgpack ①
def send_text(sock, text):
pack = msgpack.packb({'command': 0, 'text': text}) ②
sock.sendto(pack, (host, port)) ③
text_id_enc = sock.recv(10240) ④
return int.from_bytes(text_id_enc, byteorder='little')
def request_text(sock, text_id):
pack = msgpack.packb({'command': 1, 'text_id': text_id})
sock.sendto(pack, (host, port))
text = sock.recv(10240)
return text
text_id = send_text(sock, 'trial text')
returned_text = request_text(sock, text_id)
① 我们使用外部 msgpack 库来编码复杂的数据结构。
② 我们使用 msgpack 将字典打包到字节数组中。
③ 我们向服务器发送 UDP 消息。
④ 我们从服务器接收响应。
我们使用一个名为send_text的函数来向服务器发送文本。请求包括命令0,表示存储文本和文本本身。我们本可以更明确地编码命令,例如使用字符串“store text”,但这将更加冗长,因此效率更低。我们以原样发送文本,但根据我们在前几节中看到的内容,压缩文本可能是一个可行的选项,尤其是如果我们预计将传输大量文本。
我们从服务器获得的答案没有使用 msgpack 进行编码。鉴于我们将获得存储文本的数字 ID,我们使用更简单的方法:从字节流中重建一个整数。这应该比 msgpack 更快。
request_text函数有一个命令代码1和一个用 msgpack 打包的数字 ID。发送消息后,我们收到文本。
最后,我们向服务器发送文本,然后通过使用文本 ID 来获取文本。现在,我们将实现服务器端。之后,我们将重新访问客户端,使其对消息丢失更加健壮。
6.4.3 基于 UDP 的服务器
服务器代码基于内置模块socketserver,它提供了编写基于套接字的服务器的实用类:
import os
import socketserver
import msgpack
class UDPProcessor(socketserver.BaseRequestHandler): ①
def handle(self): ②
request = msgpack.unpackb(self.request[0])
socket = self.request[1]
if request['command'] == 0:
text = request['text']
w = open(f'texts/{self.server.snippet_number}.txt', 'w')
w.write(text)
w.close()
socket.sendto(self.server.snippet_number.to_bytes(
4, byteorder='little'), self.client_address)
self.server.snippet_number += 1
elif request['command'] == 1:
text_id = request['text_id']
f = open(f'texts/{text_id}.txt')
text = f.read()
f.close()
socket.sendto(text.encode(), self.client_address)
host = '127.0.0.1'
port = 54321
try:
os.mkdir('texts')
except FileExistsError:
pass
with socketserver.UDPServer((host, port), UDPProcessor)
as server: ③
server.snippet_number = 0 ④
server.serve_forever()
① 我们在处理类内部实现我们的服务器处理代码。
② 处理类需要创建处理方法来实现功能。
③ 我们创建一个 UDP 服务器。
④ 我们初始化一个用于文本 ID 的内部变量。
处理函数首先获取命令以决定执行哪个操作。存储请求将接受提供的文本并将其写入磁盘。检索请求将获取与提供的 ID 相关的文本。
之前小节中介绍了与性能相关的概念:它们使用 msgpack 和 UDP。现在我们有了服务器和客户端,让我们使客户端更加健壮。
6.4.4 客户端处理基本恢复
我们的原生客户端发送一条消息并等待响应。UDP 不保证数据包的交付,因此我们需要添加超时机制。话虽如此,在一个高性能的本地网络中,UDP 数据包丢失应该是一个罕见的事件。
这里是我们的实现,使用装饰器:
import functools
def timeout_op(func, max_attempts=3):
@functools.wraps(func)
def wrapper(*args, **kwds):
attempts = 0
while attempts < max_attempts:
try:
return func(*args, **kwds)
except socket.timeout:
print('Timeout: retrying')
attempts += 1
return None
return wrapper
@timeout_op
def send_text(sock, text):
...
@timeout_op
def request_text(sock, text_id):
...
sock = socket.socket(
socket.AF_INET,
socket.SOCK_DGRAM)
sock.settimeout(1.0) ①
① 我们为套接字设置了超时。
我们只需将我们的装饰器应用到send_text和request_text上。默认情况下,套接字是阻塞的;也就是说,它等待直到收到消息,因此我们在socket之后使用settimeout使其非阻塞,并在 1.0 秒后如果没有收到消息则返回。这种简单的超时机制应该足以精确处理客户端,因为网络应该足够可靠,以确保大多数 UDP 数据包不会丢失。
在服务器端也做类似的事情也是合理的。在服务器端,重复操作的问题在于语义:如果您最终保存了两次 pastebin,您正在消耗磁盘资源。作为一般规则,请小心处理您创建的操作的语义。它们可能在没有造成损害的情况下不可重复。
6.4.5 网络计算优化的其他建议
我们的实现使用 UDP 显著减少了消息开销,但有时您可能需要使用 TCP,甚至是在 TCP 之上的 HTTPS 或其他协议。如果这种情况发生,这里有一些提示:
-
如果您的客户端将向服务器发送多个请求,请尝试为所有请求使用相同的连接。 这样,您只需支付一次建立和拆除连接的成本。
-
有时在高峰使用之前预先打开 TCP 连接是可能的。 这样就可以将建立连接的成本支付在关键时间路径之外。这种技术与之前的技术一样,通常与数据库连接一起使用,并被称为 连接池。
-
如果 UDP 太简单而 TCP 太重,考虑使用新的 QUIC 协议。 QUIC 最初代表“快速 UDP 互联网连接”。正如旧名称所表明的,它试图弥合在 UDP 之上拥有连接优势的差距。
摘要
-
注意内存层次结构对于高效程序的设计至关重要。大多数程序员至少对 RAM、磁盘存储和网络访问的影响有所了解,但往往对 CPU 缓存和瞬态 RAM 内存之间的关系的影响了解较少。
-
DRAM 内存访问会导致 CPU 饥饿,使 CPU 空闲许多周期。确保尽可能多的数据在 CPU 缓存中可用可以显著提高处理速度。
-
避免 CPU 饥饿的算法可能会更高效,但有时是以不直观的方式。例如,处理压缩数据可能比处理原始数据更快,因为(解)压缩算法的成本可能低于从 RAM 获取(更大的)未压缩数据所需的成本。
-
在许多情况下,Blosc 允许你以比使用原始数据更快的速度访问数据的压缩存储表示。这除了通常的拥有更小持久数据集的好处之外。
-
NumExpr 可以在比 NumPy 本身更短的时间内,使用更少的内存处理类似 NumPy 的表达式。NumExpr 利用智能 L1 缓存等技巧来加速评估,有时甚至可以超过一个数量级。
-
一些现代的本地网络速度如此之快,以至于通过网络访问其他计算机可能比访问本地磁盘更快。
-
标准的 REST API 太慢且效率低下,无法充分利用快速本地网络。
-
通过在多个层面进行改进,网络通信可以变得更快:选择最佳的传输协议(TCP 与 UDP),不使用 HTTPS,以及使用比 JSON 更快的序列化数据方式。
第三部分:现代数据处理的应用和库
书的第三部分最直接应用于数据问题,因为它涵盖了广泛使用的 Python 分析库。我们首先讨论了始终存在的 pandas 库来处理数据框。我们还探讨了 Apache Arrow,这是一个现代库,可以在其他任务中帮助加速 pandas 处理。然后我们讨论了旨在从持久化中提取最大性能的库。我们考察了用于 N 维数组的 Zarr 和用于数据框的 Parquet。处理大于内存数据集的话题也被引入。
7 高性能的 pandas 和 Apache Arrow
本章涵盖
-
使用 pandas 的数据框创建优化内存使用
-
降低 pandas 操作的计算成本
-
使用 Cython、NumExpr 和 Numpy 加速 pandas 操作
-
使用 Apache Arrow 优化 pandas
数据分析本质上等同于使用 pandas。pandas 是一个数据框库,或者是一个处理表格数据的库。pandas 是 Python 世界中处理内存中表格数据的既定标准。在本章中,我们将讨论优化 pandas 使用的途径。这将是一个双管齐下的方法:我们将直接优化 pandas 的使用,并且我们还将使用 Apache Arrow 来优化它。
Apache Arrow 提供了语言无关的功能,以有效地访问列式数据,以便在不同语言实现之间共享这些数据,并将数据传输到不同的进程甚至不同的计算机。它可以通过引入更快的算法来执行基本操作,如读取 CSV 文件,将 pandas 数据框转换为低级语言的格式以进行更快的处理,以及增强序列化机制以在不同计算机之间传输数据框。
我们将首先考虑一些优化 pandas 使用的技术。在这里,我们将注意力分配在时间和内存之间。鉴于 pandas 是一个内存库,我们希望确保我们拥有尽可能小的内存占用,这不仅允许我们进行复杂的数据分析,而且在我们需要考虑磁盘实现之前,可以加载尽可能多的数据(我们将在第十章中更多关于磁盘实现的内容)。
接下来,我们将使用之前学习过的库——NumPy、Cython 和 NumExpr 来优化 pandas 数据框的处理。由于 pandas 基于 NumPy,实际上使用 Cython 和 NumExpr 来优化 pandas 是相当容易的。
然后,我们将了解 Apache Arrow,并从两个角度进行讨论。首先,我们将看到 Apache Arrow 如何为 pandas 提供标准算法的替代实现。例如,在 Arrow 中读取大型 CSV 文件并将其转换为 pandas,与直接在 pandas 中读取相比,是否有性能优势?其次,我们将使用 Arrow 高效地将数据框传输到其他低级编程语言,这样我们就可以通过使用更高效的算法实现来加速处理。
让我们从使用标准 pandas 优化数据加载开始。如果你使用 conda,你应该安装 PyArrow:在我写这本书的时候,pip install pyarrow 似乎是最有效的解决方案。如果你使用 Docker,请使用镜像 tiagoantao/python-performance-dask。
7.1 加载数据时的内存和时间优化
我们的首要任务是优化内存使用和 pandas 数据框的加载速度。在下一节中,我们将优化数据分析操作。在我们的例子中,我们将使用纽约市著名的黄色出租车行程记录。纽约市出租车和豪华轿车委员会(TLC)在mng.bz/516D提供了一个公开的行程数据集。我们将使用 2020 年 1 月的黄色出租车数据。我们提供了每趟出租车行程的信息,包括开始和结束时间、乘客数量、车费金额、小费等。
我们将首先在本地上下载数据。虽然 pandas 可以直接从远程源下载,但我们不希望每次从网络上加载数据框时都等待,因为这会花费很多时间;我们也不想持续访问数据服务器。我们的目标将是双重的:确定 pandas 加载整个表和不同列所需的内存量,并减少内存使用。您可以使用 wget 下载 566 MB 的数据(tiago.org/yellow_tripdata_2020-01.csv.gz)。
7.1.1 压缩数据与未压缩数据对比
让我们从加载数据开始(本节中的代码可在07-pandas/sec1-intro/read_csv.py中找到):
import pandas as pd
df = pd.read_csv("yellow_tripdata_2020-01.csv")
在我的电脑上,这需要大约 10 秒。正如您在前几章中看到的,压缩数据可能会对处理时间产生积极影响。让我们尝试使用 xz 压缩文件并加载它。您需要安装 xz,然后可以使用yellow_tripdata_2020-01.csv:
df = pd.read_csv("yellow_tripdata_2020-01.csv.xz")
pandas 足够智能,可以根据扩展名推断压缩类型,尽管您可以覆盖它。在我的电脑上,这需要 15 秒。虽然这个数字比未压缩版本要差,但文件大小现在只有 74 MB,减少了七倍。在这种情况下,我们无法减少时间,因此我们需要在磁盘空间和时间打开文件之间做出妥协。您如何平衡这两者将取决于您特定问题的需求。我们将使用 Apache Arrow 重新审视这个问题;目前,表 7.1 提供了不同算法的时间和大小。当然,时间取决于运行此硬件,但不同压缩程序之间的关系才是关键。根据您的使用情况,您可能需要尽可能快地读取,或者可能磁盘上的大小最重要。
表 7.1 CSV 数据压缩对文件大小和 pandas 打开时间的影响
| 应用 | 读取时间(秒) | 大小(MB) |
|---|---|---|
| None | 10 |
566 |
gip |
12 |
105 |
bzip2 |
26 |
103 |
xz |
15 |
74 |
然而,不要将不同压缩算法之间的文件相对大小视为神圣的法则。请确保使用您的数据进行测试,以查看您能得到哪些比率。
本节要点
在这个例子和本章的整个过程中,重要的不是相对数值,而是不同实现和算法可以产生实质上不同的结果的见解。在加载数据时优化内存和时间(或更现实地说,内存或时间)时,有两个要点至关重要。首先,了解底层算法,而不是将其视为黑盒,将允许你发展适当的性能预期。其次,当然,你应该针对你的特定条件进行性能分析,并确定是节省时间还是内存更重要。
7.1.2 列的类型推断
当你加载数据时,你将至少得到以下警告,至少在 pandas 版本 1.0.5 中:
DtypeWarning: Columns (6) have mixed types. Specify the dtype option on
➥ import or set low_memory=False
这条信息表明数据加载器无法正确推断所有列的类型。
警告:抵制住按照 pandas 警告中建议的将 low_memory=False 设置为 False 的诱惑;对于大量数据,你的代码很可能耗尽内存并崩溃。
警告信息通常是某些列正在使用过于通用的数据类型加载的事实的一个明显的迹象。例如,一个整数列被提升为对象,所需的内存也相应增加。我们将在本章后面具体讨论这些例子。
在深入到每一列之前,我们先来确定整个数据框占用了多少内存。pandas 提供了一种比之前章节中介绍的一般方法更具体的方法:
df.info(memory_usage="deep")
简化输出如下:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6405008 entries, 0 to 6405007
Data columns (total 18 columns):
# Column Dtype
--- ------ -----
0 VendorID float64
1 tpep_pickup_datetime object
2 tpep_dropoff_datetime object
3 passenger_count float64
4 trip_distance float64
5 RatecodeID float64
6 store_and_fwd_flag object
7 PULocationID int64
8 DOLocationID int64
9 payment_type float64
...
17 congestion_surcharge float64
dtypes: float64(13), int64(2), object(3)
memory usage: 2.0 GB
我们得到了每个列类型、条目数量和内存使用情况的信息。在我们的例子中,我们的 566 MB 文件正在扩展到 2 GB!鉴于我们讨论的是文本格式的输入,这似乎有些过分。
让我们检查每个列的占用情况以及唯一值的数量:
def summarize_columns(df):
for c in df.columns:
print(c, len(df[c].unique()),
df[c].memory_usage(deep=True) // (1024**2), sep="\t")
summarize_columns()
简化输出如下:
tpep_pickup_datetime 2134342 object 464
passenger_count 11 float64 48
trip_distance 5606 float64 48
RatecodeID 8 float64 48
store_and_fwd_flag 3 object 401
PULocationID 261 int64 48
payment_type 6 float64 48
fare_amount 5283 float64 48
improvement_surcharge 3 float64 48
total_amount 12488 float64 48
congestion_surcharge 8 float64 48
类型为对象的列每个占用超过 400 MB(关于大小的详细信息,请参阅第二章)。float64 每个浮点数需要 64 位——因此,每个值占用 8 字节,从而占用 48 MB。对于 int64 也是如此。我们能否通过更改列的类型来减少内存占用?是的,我们可以——大幅减少。
让我们从 tpep_pickup_datetime 和 tpep_dropoff_datetime 开始。它们的值,正如名称所暗示的,是带有时间的日期。你可以使用 df["tpep_pickup_datetime"].head() 来检查这些值。让我们将这些列转换为 datetime 格式:
df["tpep_pickup_datetime"] = pd.to_datetime(df["tpep_pickup_datetime"])
df["tpep_dropoff_datetime"] = pd.to_datetime(df["tpep_dropoff_datetime"])
就这样简单的更改就将每个列的占用从 464 MB 减少到 48 MB,并将数据框从 2 GB 减少到 1.2 GB。我希望这足以让你相信正确加载数据类型的重要性。
此外,还有几个变量是离散的,并且具有少量可能的值。例如,payment_type 只能具有六个数值,但使用的是 8 字节的 float64。让我们将其重新编码为单字节:
import numpy as np
df["payment_type"] = df["payment_type"].astype(np.int8)
8 位有符号整数的类型来自 NumPy。pandas 当然依赖于 NumPy。
不幸的是,这次尝试将失败。如果你检查该列,你会看到其中有一些缺失数据(NA)值以及数值。我们可以将 NA 重新编码为 0,因为 0 值没有被其他方式使用;如果使用了 0,我们就需要选择另一个值:
df["payment_type"] = df["payment_type"].fillna(0).astype(np.int8)
如预期的那样,这个变化将列的大小从 48 MB 减少到 6 MB——从每个值 8 字节到每个值 1 字节。我们有六个列可以从 64 位缩小到 8 位,还有两个需要 16 位。这意味着又减少了 450 MB。我们现在大约只剩下 750 MB。
不要被解决那个例子有多容易所欺骗;编码和处理缺失值是一个复杂的问题,通常并不容易解决。虽然这个例子很容易解决,但并非所有列都那么简单。
store_and_fwd_flag 是一个更复杂列的例子。它是一个布尔标志,表示票价是否因为存储票价的服务器不可用而被保留在车辆内存中。对于许多缺失值的记录,它是未知的(即,我们不知道它是真还是假)。如果没有缺失值,我们可以将列表示为 boolean,每个值占用 1 位。鉴于我们需要表示第三种状态,我们必须使用下一个可用的数据容器,它是 8 位宽。因此,处理缺失值使我们在这个列上的内存分配增加了八倍。我们最终这样做:
df["store_and_fwd_flag"] = df["store_and_fwd_flag"].fillna(" ").apply(ord)
➥ .apply(
lambda x: [32, 78, 89].index(x) - 1).astype(np.int8)
我们将 NA 转换为空格,并获取每个字符的 ASCII 码值:空格为 32,N 为 78,Y 为 89。使用索引函数,我们可以将 NAs(32) 编码为 -1,N(78) 编码为 0,Y(89) 编码为 1。
经验总结
我们将在本书的剩余部分重新审视表示缺失值的复杂主题,特别是在下一章讨论持久性和 Parquet 文件格式时。现在,知道一个常见的内存浪费(以及因此的操作速度)是比必要的更通用的列数据类型:数据类型越宽,内存占用就越大,操作速度就越慢。更改列类型并不总是简单的事情,但它可以显著减少内存占用。
7.1.3 数据类型精度的效果
我们可以使用的一种减少内存占用的技术是使用相同的数据类型但具有更低的精度。例如,我们可以将一些现金值从 float64 转换为 float32(即,双精度到单精度),从而减少 50%的内存使用:
df["fare_amount_32"] = df["fare_amount"].astype(np.float32)
现在,我们需要评估降低精度对我们表示值的能力的影响。这里的一个简单方法可以是:
(df["fare_amount_32"] - df["fare_amount"]).abs().sum()
我们可以计算双精度和单精度之间的差异。我们必须在相加之前获取绝对值,以避免抵消。整个数据框的总误差高达 0.063:
df = pd.read_csv(
"yellow_tripdata_2020-01.csv.gz",
dtype={ ①
"PULocationID": np.uint8, ②
"DOLocationID": np.uint8
},
parse_dates=[ ③
"tpep_pickup_datetime",
"tpep_dropoff_datetime"],
converters={ ④
"VendorID":
lambda x: np.int8(["", "1", "2"].index(x)), ⑤
"store_and_fwd_flag":
lambda x: ["", "N", "Y"].index(x) - 1,
"payment_type":
lambda x: -1 if x == "" else int(x),
"RatecodeID":
lambda x: -1 if x == "" else int(x),
"passenger_count":
lambda x: -1 if x == "" else int(x)
}
)
① 我们为每个列指定不同的类型。
② 我们将一些列从 64 位整数限制为 8 位整数。
③ 有一种稍微不同的方式来指定日期类型。
④ 我们创建了几个转换器,主要是为了重新编码 NAs。
⑤ 我们明确地将 VendorID 转换为 np.int8。
在这个阶段,df.info(memory_usage="deep") 将报告 757.4 MB 的内存使用量。注意,然而,大多数数值类型将报告 64 位长度,包括我们用 np.int8 调用包装的 VendorID——显然没有效果。
我们有几个列显然还可以更小。因为我们选择不那么精确,所以我们将所有 64 位浮点数减少到 16 位。我们还将所有 64 位整数减少到 8 位,因为 8 位整数的范围-128 到 127 对于这个特定情况来说是足够的:
for c in df.columns:
if df[c].dtype == np.float64:
df[c] = df[c].astype(np.float16)
if df[c].dtype == np.int64:
df[c] = df[c].astype(np.int8)
我们现在的内存占用已经降至 250.4 MB。记住,我们最初是 2 GB。还不错。
7.1.4 重新编码和减少数据
如果你真的需要,你可以尝试进一步减少操作。例如,几个数值列使用了少量不同的值。让我们尝试找到那些:
for c in df.columns:
cnts = df[c].value_counts(dropna=False) ①
if len(cnts) < 10: ②
print(cnts)
① value_counts 返回每个值重复的次数。
② 我们打印所有具有不到 10 个不同值的列。
在这里,我们选择打印所有具有不到 10 个不同值的列——10 是一个任意值,你可以更改。我们已经改进了一些这些列,但有两个由 16 位浮点数表示,可以减少。improvement_surcharge 只有三个不同的值:0,0.3 和-0.3。这些可以很容易地重新编码为,比如说,0,-1 和 1,然后再重新转换。congestion_surcharge 只有以下值:-2.5,-0.75,0.5,0.0,0.5,0.75,2.0,2.5 和 2.5。虽然你可以制作某种表格,但如果将所有值乘以 4,它们就变成了整数。你可以使用 8 位整数表示,并通过乘以 4 来编码它们,通过平均除以 4 来解码它们。
最后,有一个终极解决方案来节省内存:避免加载我们不需要的数据部分。对于我们的下一个任务,如果我们只加载接单和派单日期时间以及拥堵附加费就足够了。我们可以用 pandas 轻松做到这一点:
df = pd.read_csv(
"yellow_tripdata_2020-01.csv.gz",
dtype={
"congestion_surcharge": np.float16,
},
parse_dates=[
"tpep_pickup_datetime",
"tpep_dropoff_datetime"],
usecols=[
"congestion_surcharge",
"tpep_pickup_datetime",
"tpep_dropoff_datetime"],
)
这段代码只需要 109.9 MB,大约是原始 2 GB 需求的 5%。我们是通过减少列数和转换其中一些列的负载来做到这一点的。
使用 inplace=True 的虚假安全性
大多数 pandas 方法都有在原地更改现有数据结构的能力,而不是返回一个新的数据框/序列。你可以通过牺牲原始数据来节省一半的内存。例如,你可以通过以下方式删除所有包含 NAs 的行:
new_df = df.dropna()
但最终你会有三个数据框,占用两倍的内存。或者,你可以使用:
df.dropna(inplace=True)
这将改变原始数据框的状态,因此在所有情况下可能都不会工作,但在许多情况下,这可以是一个简单的解决方案,将内存消耗减半。
但是要小心:在操作执行过程中,pandas 将为两个数组分配空间。因此,在执行过程中,内存需求将加倍。从某种意义上说,这主要是一个便利功能,你可以在使用默认调用而不使用inplace之后用del来复制。
Arrow 通过self_destruct参数提供了一个更有趣的内存管理方法,我们将在本章稍后讨论。
这一部分的代码展示了为列的数据表示做出正确选择的影响。在实践中,pandas 读取器可以在开始时完成所有这些操作:
df = pd.read_csv(
"yellow_tripdata_2020-01.csv.gz",
dtype={ ①
"VendorID": np.int8,
"trip_distance": np.float16,
"PULocationID": np.uint8,
"DOLocationID": np.uint8,
},
parse_dates=[ ②
"tpep_pickup_datetime",
"tpep_dropoff_datetime"],
converters={ ③
"VendorID":
lambda x: np.int8(["", "1", "2"].index(x)),
"store_and_fwd_flag":
lambda x: ["", "N", "Y"].index(x) - 1,
"payment_type":
lambda x: -1 if x == "" else int(x),
"RatecodeID":
lambda x: -1 if x == "" else int(x),
"passenger_count":
lambda x: -1 if x == "" else int(x)
}
)
① 我们可以指定某些列的期望类型。
② 日期与其他类型分开处理。
③ 我们可以在加载时进行转换。
注意,整数和浮点类型始终是较大的类型,因此你可能需要降级。现在我们已经以内存高效的方式加载数据,让我们看看 pandas 如何可以加快数据分析速度。
吸取的经验教训
我们已经进行了这些练习,以证明仔细选择数据类型可以减少使用的内存量。缩小数据类型和扩大精度是两种具有相对较少权衡的方法。
关于这些通用方法,还有更多需要了解的,我们将在稍后回到它们。减少数据和数据表示的数量将是本书最后一章的主要内容。
关于更改数据类型的好消息是,通常你不需要在加载后转换数据,因为 pandas 会为你完成这个工作。在下一个示例中,我们将使用read_csv在加载时进行大部分转换。
7.2 提高数据分析速度的技术
现在我们来访问纽约市的行程记录以进行一些统计分析。例如,我们将确定支付中作为小费的比率。我们不会专注于统计分析本身,因为这不是一本关于数据科学的书。相反,我们想要找出如何高效地访问信息,这样我们就可以在需要时进行这种分析。在这里,我们将考虑数据框索引技术和遍历行的策略。
我们首先加载数据。我们只需要三个字段(代码在07-pandas/ sec2-intro/index.py中):
df = pd.read_csv(
"yellow_tripdata_2020-01.csv.gz",
dtype={
"congestion_surcharge": np.float16,
},
parse_dates=[
"tpep_pickup_datetime",
"tpep_dropoff_datetime"],
usecols=[
"congestion_surcharge",
"tpep_pickup_datetime",
"tpep_dropoff_datetime"],
)
7.2.1 使用索引加速访问
让我们访问具有特定接车时间的所有记录:
df[df["tpep_pickup_datetime"] == "2020-01-06 08:13:00"]
对于我的机器,timeit报告的平均值为 17.1 毫秒。我们可以尝试按接车列对数据框进行排序:
df_sorted = df.sort_values("tpep_pickup_datetime")
df_sorted[df_sorted["tpep_pickup_datetime"] == "2020-01-06 08:13:00"]
不幸的是,从时间角度来说,结果在同一数量级:pandas 在获取行时忽略了列排序。如果我们使用索引,我们可以期望得到完全不同的执行时间:
df_pickup = df.set_index("tpep_pickup_datetime")
df_pickup_sorted = df_pickup.sort_index()
df_pickup.loc["2020-01-06 08:13:00"]
df_pickup_sorted.loc["2020-01-06 08:13:00"]
在这种情况下,我们在tpep_pickup_datetime上索引数据框。如果我们不排序数据框,就像df_pickup中那样,我们什么也得不到。但对于df_pickup_sorted,它现在被索引并且按tpep_pickup_datetime排序,我们达到了 395 微秒,这比之前快了 40 多倍。
这个解决方案有很多注意事项,最明显的一个是它只能用于索引。所以,如果你想使用另一个字段,你将不得不在另一列上索引,或者有一个多列的索引。因此,使用索引不是一个通用的解决方案,但我们的例子表明,你应该仔细选择如何构建你的索引以获得性能。当你依赖于索引——许多 pandas 使用忽略了索引——你是在用查询的一致性换取潜在的加速。例如,如果你想了解所有支付拥堵费的车次,你可以使用以下代码:
df[
(df["tpep_pickup_datetime"] == "2020-01-06 08:13:00") &
(df["congestion_surcharge"] > 0)]
注意,查询语言对待这两个列使用的是相同的语言。
但如果有一个索引,例如,你可以使用以下代码:
my_time = df_pickup_sort.loc["2020-01-06 08:13:00"]
my_time[my_time["congestion_surcharge"] > 0]
小贴士:这里提出的所有关于索引的参数都可以在连接框架(df.join)时使用,对性能的影响更大。
现在我们已经了解了使用索引的影响,让我们实际使用整个数据集来计算支付金额的平均小费比例。
7.2.2 行迭代策略
我们现在将考虑不同的方法来遍历数据帧。我们将计算我们数据集中支付金额中作为小费的比例,这将需要遍历所有记录以获取小费和总金额。
我们将首先读取数据,并删除所有总金额为零的记录(见 07-pandas/sec2-speed/traversing.py):
df = pd.read_csv("../sec1-intro/yellow_tripdata_2020-01.csv.gz")
# ^^ replace
df = df[(df.total_amount != 0)]
df_10 = df.sample(frac=0.1)
df_100 = df.sample(frac=0.01)
注意,我们将主数据帧的样本量减少到 10% 和 1%,这将在我们稍后的性能测试中有所帮助。
让我们先使用一种传统的 Python 技术(即,不使用基于 pandas 或 NumPy 的方法,如向量化):对所有记录进行 for 循环遍历:
def get_tip_mean_explicit(df):
all_tips = 0
all_totals = 0
for i in range(len(df)): ①
row = df.iloc[i] ②
all_tips += row["tip_amount"]
all_totals += row["total_amount"]
return all_tips / all_totals
① 使用行数进行典型的 Python for 循环
② 通过位置访问行
这段代码代表了一个没有 pandas 或 NumPy 经验的 Python 开发者通常会编写的代码。坦白说,性能非常糟糕:在我的计算机上,timeit 报告的测量结果是 分钟。¹
有两种基于 for 的方法可以提供更好的结果。第一种是基于数据帧的 iterrows 方法:
def get_tip_mean_iterrows(df):
all_tips = 0
all_totals = 0
for i, row in df.iterrows():
all_tips += row["tip_amount"]
all_totals += row["total_amount"]
return all_tips / all_totals
这段代码仍然是一个 for 循环,但在这个例子中,我们使用了一个 pandas 迭代器,它返回当前位置和一行。时间略有改善,但仍然非常糟糕。
小贴士:如果你刚开始使用 pandas 和 NumPy,习惯使用 for 循环来进行大量计算是完全正常的。虽然你的中期方法应该考虑其他技术,如向量化(我们稍后会看到一些例子),但在短期内,在你不习惯更高效的方法时,考虑避免使用显式和基于 iterrows 的惯用用法。在所有基于 for 的方法中,itertuples 可能会为你节省最多的时间,同时仍然使用你感到舒适的方法。无论如何,你应该尽快训练自己使用 pandas 的显式迭代。
我们最终的基于 for 的方法基于 itertuples,其中我们使用一个每次返回一行元组的迭代器:
def get_tip_mean_itertuples(df):
all_tips = 0
all_totals = 0
for my_tuple in df.itertuples():
all_tips += my_tuple.tip_amount
all_totals += my_tuple.total_amount
return all_tips / all_totals
虽然习惯用法仍然相同,但平均时间下降到 18 秒!
我们现在将考虑基于 pandas 的方言。我们将首先使用 apply,这在概念上类似于使用 map 函数,其中每一行都是单独处理的:
def get_tip_mean_apply(df):
frac_tip = df.apply(
lambda row: row["tip_amount"] / row["total_amount"],
axis=1 ①
)
return frac_tip.mean() ②
① 可以对每一列调用 apply,这是默认的,或者对每一行调用,这是我们将会使用的。默认的轴 0 将执行列处理,而轴 1 将执行行处理。
② 我们使用数据框的均值函数来计算最终值。
使用 apply 将时间减少到 9.5 秒,这是之前解决方案的两倍快——不错,但还不够好。
在我们讨论基于矢量的解决方案之前,让我们尝试一个稍微不同的 apply 语法:
def get_tip_mean_apply2(df): # df_10: 14.9s
frac_tip = df.apply(
lambda row: row.tip_amount / row.total_amount,
axis=1
)
return frac_tip.mean()
这里的区别在于我们使用 row.tip_amount 和 row.total_amount 而不是 row["tip_amount"] 和 row["total_amount"]。我们的基于对象属性的方法实际上比字典方法要慢。在我的电脑上平均成本是 14 秒。
小贴士:随着 pandas 版本的更新,不同方法之间的性能关系可能会改变,因为没有保证算法保持不变。这对于访问行对象的值以及不同的算法都是有效的。所以,如果性能不够好,总是基准测试不同的算法方法(for、apply、矢量化等)和对象访问模式(对象属性查找与字典查找)。不要将本书中呈现的关系视为圣经。
现在让我们看看使用 pandas 最佳实践进行计算的影响。我们将从矢量化方法开始:
def get_tip_mean_vector(df):
frac_tip = df["tip_amount"] / df["total_amount"]
return frac_tip.mean()
我们从数据框中提取 tip_amount 序列,并将其除以 total_amount 序列。然后我们使用均值。所需时间减少了几个数量级:现在平均时间为 32 毫秒。
记住,我们在这里工作的例子是一个简单的例子,用来说明不同策略遍历行的重要性。对于更复杂的计算,尝试提出一个矢量化方法。如果这不可能,你可能能够将计算分成几个部分,这样你就有了一个独立的矢量化(即快速)部分。然后你可以将非矢量化部分减少到可能的最小成本。
吸取的经验
数据加载过程通常被忽视,因为它可以导致内存使用增加,从而进一步影响操作的效率。在加载时正确设置列类型是优化内存的主要方式,尽管扩展精度参数也可以有所帮助。pandas 在加载时可以处理类型,所以你不需要之后再进行操作。
加载数据后访问数据是另一个可以检查以节省操作时间的潜在过程。在访问数据时提高效率有两种一般策略。首先,使用索引可能会有所帮助,但它也带来了一些缺点。您可以使用行迭代。一些遍历数据帧的方法比其他方法更快。一般规则是,基于行的分析通常应以声明性方式处理:如果可能,则向量化,至少避免显式迭代。现在我们已经考虑了一些 pandas 原生优化方法,让我们考虑使用更低级技术的方法。
7.3 在 NumPy、Cython 和 NumExpr 之上使用 pandas
在接下来的几节中,我们将借鉴前面章节(第四章 NumPy、第五章 Cython 和第六章 NumExpr)中关于这些技术的讨论,因此在这里我们只需对这些技术如何应用于 Python 进行快速浏览。如果您需要回顾基础知识,请随时在需要时回到那些章节。
本节的目标是从 pandas 的角度研究 NumPy、Cython 和 NumExpr,并了解它们如何提高数据分析的性能。为了进行这项调查,我们将回顾上一节中使用的示例:计算每次车费中作为小费的支付金额的比例。
7.3.1 显式使用 NumPy
我们在上一节中使用的方法都是隐式地 NumPy 方法,因为 pandas 位于 NumPy 之上。也可以显式地使用 NumPy。我们首先获取序列的 NumPy 底层表示,然后使用 NumPy 操作(见07-pandas/sec3-numpy-numpexpr-cython/traversing.py):
df_total = df["total_amount"].to_numpy() ①
df_tip = df["tip_amount"].to_numpy()
print(type(df_tip)) ②
def get_tip_mean_numpy(df_total, df_tip):
frac_tip = df_total / df_tip ③
return frac_tip.mean() ④
① to_numpy 引用了底层的 NumPy 数组。
② 类型现在是 numpy.ndarray,而不是 pandas.Series。
③ 我们使用数组上的向量化除法进行除法,而不是使用序列。
④ 我们使用 NumPy 的均值,而不是 pandas。
基于 NumPy 的向量化代码在我的电脑上运行时间为 11 毫秒,而向量化 pandas 版本则需要 35 毫秒。之前的操作适用于 NumPy 数组,但不适用于 pandas 序列。²
提示 to_numpy 方法返回对底层 pandas 数组的引用。如果您更喜欢副本,因为您可能想要进行不会反映在原始数据上的更改,那么请使用 to_numpy(copy=True) 方法。记住,如果您复制,内存使用量将加倍,并且复制操作会有时间成本。
7.3.2 在 NumExpr 之上使用 pandas
我们可以使用 NumExpr 查询数据,而不是使用 pandas 的查询引擎。NumExpr 是一个表达式评估器,其性能可以显著优于 NumPy,这不仅归功于其高度高效的线程化实现,还归功于其对中间内存的明智使用,使得许多计算主要基于 CPU 缓存完成。更多详情,请参阅上一章。
这里是对支付比例作为小费的简单实现:
def get_tip_mean_numexpr(df):
return df.eval("(tip_amount / total_amount).mean()", engine="numexpr")
如预期的那样,pandas 通过支持数据框和序列扩展了 NumExpr 的语言。在前一个案例中,我们在评估字符串中引用了tip_amount和total_amount列,这些列被解析为相关的 pandas 列。也可以在局部命名空间中使用 pandas 的 eval。例如,前面的代码可以实施如下:
def get_tip_mean_numexpr(df):
return pd.eval("(df.tip_amount / df.total_amount).mean()", engine="numexpr")
这将允许你在 eval 调用中引用多个数据框。你也可以使用非 pandas 变量;有关详细信息,请参阅 pandas 网站上的 eval 文档(mng.bz/aMAX)。
关于性能?它与矢量化 pandas 解决方案处于同一水平——大约 35 毫秒——因此比 11 毫秒的 NumPy 解决方案慢得多。发生了什么?
首先,我们通过将字符串解析成可执行代码来付出代价。正如我们在上一章所看到的,这意味着NumExpr 只有在处理的数据量足够大,足以证明开销是合理的时才有效用。确定“足够大”需要逐个案例进行评估:你将不得不对你的数据进行一些分析。在这个特定案例中,我们使用了一个足够大的数据集。那么发生了什么?
随着公式复杂性的增加,NumExpr 生成高效计算策略的能力也增加,因为缓存命中变得更加常见。让我们考虑一个,可以说是人为的,例子,即我们使用四次求和的分数:
tip_amount / total_amount + tip_amount / total_amount + tip_amount / total_amount + tip_amount / total_amount
这里是 NumPy 和 NumExpr 的实现:
def get_tip_mean_numpy4(df_total, df_tip):
frac_tip = (
df_total / df_tip +
df_total / df_tip +
df_total / df_tip +
df_total / df_tip )
return frac_tip.mean()
def get_tip_mean_numexpr4(df):
return df.eval(
"tip_amount / total_amount +"
"tip_amount / total_amount +"
"tip_amount / total_amount +"
"tip_amount / total_amount", engine="numexpr").mean()
唯一发生变化的是公式的复杂性。NumPy 的解决方案略微下降,比线性略差,平均为 55 毫秒。而 NumExpr 则保持在 35 毫秒!这并不是说 NumExpr 的性能使得公式的复杂性变得无关紧要,但它强烈暗示了我们之前章节中观察到的现象:如果我们的实现尽可能依赖 CPU 缓存并避免访问 DRAM,那么我们可以获得显著的性能提升,这看起来似乎是反直觉的。
吸取的教训
NumExpr 在处理大量数据和复杂公式时更有效率,这意味着它将在最复杂的情况下有所帮助。正如你所看到的,在我们的第一种方法中,NumExpr 的表现不如 NumPy 和 pandas。这又是另一个例子,表明你不应该总是使用“最佳”技术,而应该评估你的数据集和算法,以确定最有效的解决方案。一种方法并不适合所有人,这本书中提出的最复杂和最优雅的解决方案,即 NumExpr,不应被视为绝对真理。现在让我们看看如何使用 Cython 与 pandas 来提高我们练习的性能。
7.3.3 Cython 和 pandas
我们现在将使用 Cython 重新实现提示分析代码,并分析潜在的性能优势。我们将故意使这个子节简短,因为没有真正的 Cython 和 pandas 之间的 直接 关系。在这种情况下,我们使用 Cython 的能力完全基于 NumPy,如图 7.1 所示。因此,如果你已经阅读了 Cython 章节的话,你应该对下面的代码感到很熟悉。在这里,我们将使用 Cython 重新创建我们的示例:我们将从一开始就使用 Cython 的最佳实践。

图 7.1 使用 Cython 通过 NumPy 的 pandas
纯 Python 调用代码是(见 07-pandas/sec3-numpy-numpexpr-cython/traversing_cython_top.py):
import pandas as pd
import numpy as np
import pyximport ①
pyximport.install( ②
setup_args={
'include_dirs': np.get_include()})
import traversing_cython_impl as cy_impl
df = pd.read_csv("../sec1-intro/yellow_tripdata_2020-01.csv.gz")
df = df[(df.total_amount != 0)]
df_total = df["total_amount"].to_numpy() ③
df_tip = df["tip_amount"].to_numpy()
get_tip_mean_cython = cy_impl.get_tip_mean_cython ④
① 我们为了方便使用 pyximport 系统。
② 我们需要包含 NumPy C 包含文件。
③ 我们需要访问序列的 NumPy 表示。
④ 我们调用函数的 Cython 实现。
Cython 代码是(见 07-pandas/sec3-numpy-numpexpr-cython/traversing_cython_impl.pyx):
import numpy as np
cimport cython ①
cimport numpy as cnp ②
@cython.boundscheck(False) ③
@cython.nonecheck(False)
@cython.wraparound(False)
@cython.cdivision(True) ④
cdef cnp.float64_t get_tip_mean_cython_impl( ⑤
cnp.float64_t[:] df_total, ⑥
cnp.float64_t[:] df_tip)
nogil: ⑦
cdef cnp.float64_t frac_tip ⑧
cdef int array_size = df_total.shape[0]
cdef cnp.float64_t result = 0
for i in range(array_size):
result += df_tip[i] / df_total[i]
return result / array_size ⑨
def get_tip_mean_cython(df_total, df_tip): ⑩
return get_tip_mean_cython_impl(df_total, df_tip)
① 访问 Cython 支持函数
② 导入对 C 级 NumPy 函数的访问
③ 禁用所有与 Python 绑定的代码,包括数组边界检查、None 检查和包装索引
④ 我们还使用了不带任何检查的 C-除法。这是这里引入的唯一新概念。
⑤ 我们使用 C-only 函数(cdef)并指定返回类型(cnp.float64_t)。这允许 Cython 在签名中避免 Python 开销。
⑥ 我们将输入类型定义为 64 位浮点数的内存视图,这比一般的 Python 对象甚至 NumPy 数组要快得多。
⑦ 由于我们已经完全清理了函数与 Python 的交互,当调用时可以释放 GIL。
⑧ 我们为所有变量使用 Cython 类型。
⑨ 我们只在最后进行除法,这要高效得多。
⑩ 我们有一个可以从 Python 调用的“桥梁”函数。这将隐式地将 NumPy 数组转换为内存视图。
为了完整性,我使用我们在 Cython 章节中学到的实践对所有的代码进行了注释。这里唯一的新特性是使用 cdivision 注解,它不会因为分母为 0 而引发 Python 错误,并且更高效。如果发生除以 0 的情况,代码将会崩溃。记住,我们在原生 Python 代码中已经仔细移除了所有 total_amount 为 0 的行,所以对于我们的情况,我们会没事的。
如果你修改了代码,请记得使用 cython -a 生成与 Python 解释器潜在交互的 HTML 报告(即潜在的性能瓶颈)。性能?8.51 毫秒。这是最好的!
吸取的经验
由于 pandas 是建立在 NumPy 之上的,因此当我们使用 pandas 时,使用 NumPy 是隐含的。尽管如此,如果我们明确要求 NumPy 数据结构,我们还可以进一步提高性能速度。Cython 也可以与 pandas 中的 NumPy 数据结构一起使用来提高性能速度。当处理非常大的数据框和复杂的算法时,NumExpr 往往是速度的最佳选择。由于涉及许多变量,包括硬件、软件、数据和长期及短期目标,因此不可能将这些解决方案相互比较。最适合我的情况和需求的方法可能不是最适合你的最佳解决方案。理解和实验这些不同的方法应该会给你提供选择适当策略所需的见解,以用于特定的项目。
接下来,我们将考虑另一种优化 pandas 的方法。我们将使用 Apache Arrow 来提高一些常见操作的性能,例如从磁盘读取数据。
7.4 使用 Arrow 将数据读入 pandas
在本节中,我们将使用 Apache Arrow 加速数据加载到 pandas 数据框中。但在我们开始之前,让我们退一步,了解 pandas 和 Arrow 之间有些令人困惑的关系。
注意:本节旨在介绍 pandas 和 Arrow 如何协同工作,而不是专门介绍 Arrow,尽管我们将在本节末尾进行一些简单的分析。有关 Arrow 的更多信息,请参阅 arrow.apache.org/。
7.4.1 pandas 和 Apache Arrow 之间的关系
Apache Arrow 实质上是一种针对列式数据的语言无关内存格式。它是语言无关的,这意味着它与 Python/pandas 或任何其他语言无关。在核心上,它是一组库,用于在非常快速的底层语言(如 C、Rust 或 Go)中执行基本操作,尽管有时也存在在高级语言(如 JavaScript)中的实现。对于较慢的语言,较快的实现被封装在一层中。例如,Arrow 的 Python “实现”实际上是对 C++ 实现的封装。
在这里,我们将把 Arrow 视为 pandas 的辅助工具:我们希望 Arrow 加速 pandas 数据分析的部分功能,而不是完全取代 pandas。也许在未来,随着其分析实现的增长,Arrow 可以作为完整的 pandas 替代品,但目前还不是这种情况。
在本节中,我们将用 Arrow 替换 pandas 的持久化机制(即使用 Arrow 读取 CSV 文件)并简要介绍 Arrow 分析。在下一节中,我们将讨论使用提供的 IPC 服务器 Plasma 的 Arrow 高效进程间通信(IPC)机制。在这两种情况下,我们的目标是确定我们是否可以提高速度。
Arrow 有更多可用的功能。我们将在下一章更详细地研究持久化部分,特别是文件格式及其对效率的影响。Arrow 还能够处理几个不同的持久化后端。此外,还有远程过程调用(RPC)的功能,可以在不同的计算机之间发送数据。当前 Arrow 架构的概述如图 7.2 所示。

图 7.2 PyArrow 的内部架构
在 Python 之上是一个 C Arrow 实现的 Python 包装器。C 部分由几个组件组成——其中最重要的是一个初露锋芒的分析引擎,它可以利用 GPU 计算,以及一个可以处理文件系统、Amazon S3 和 Hadoop 等多个后端以及 CSV、Parquet 和 JSON 等多个文件格式的持久化层。最后,还有进程间和机器间功能,允许与其他进程(可能运行在其他机器上)进行高效通信。最重要的是,底层的 Arrow 数据格式适用于许多编程语言和硬件架构。
现在,让我们看看 Arrow 在数据读取方面的效率与 pandas 相比如何。
7.4.2 读取 CSV 文件
在本章的第一部分,我们使用 pandas 读取纽约市出租车行程的信息 CSV 文件。Apache Arrow 相比于 pandas,现代化改进的一个问题就是文件读取。它有一个多线程的读取器,并且在推断列类型方面更加智能。正如本节开头所述,pandas 和 Arrow 的特性在本书出版后可能会演变,它们之间的关系可能会发生变化。在撰写本文时,Apache Arrow 拥有更现代的架构,而 pandas 需要支持长期成功的应用遗产,因此很难赶上。
让我们加载相同的 CSV 文件并研究加载数据的内存占用(见 07-pandas/sec4-arrow-intro/read_csv.py):
from pyarrow import csv ①
table = csv.read_csv("../sec1-intro/yellow_tripdata_2020-01.csv.gz")
tot_bytes = 0
for name in table.column_names: ②
col_bytes = table[name].nbytes
col_type = table[name].type
print(name, col_bytes // (1024 ** 2))
tot_bytes += col_bytes
print("Total", tot_bytes // (1024 ** 2))
① 我们从 PyArrow 导入 CSV 处理器。
② 我们遍历所有列以获取类型和分配。
在我的电脑上,操作时间从 12 秒减少到 2 秒,减少了六倍。以下是输出摘要:
VendorID int64 48
tpep_pickup_datetime timestamp[s] 48
passenger_count int64 48
trip_distance double 48
store_and_fwd_flag string 34
total_amount double 48
Total 865
没有任何帮助的情况下,Arrow 的内存消耗为 865 MB,而 pandas 为 2 GB。如果我们帮助 pandas,那么我们可以降低到 250 MB,尽管使用原始的 Arrow 来帮助 pandas 并不公平。话虽如此,Arrow 的表现如何?
Arrow 表现相当不错:没有领域知识,自动化的系统可能做得更好,但不会太多。VendorID 只有三个值(1、2 和 null)并且可以减少到 int8,这对于大多数整数是有效的。但在使用双精度浮点数方面,Arrow 无法知道我们是否可以接受更小的表示。
注意 Arrow 类型与 Python 和 NumPy/pandas 不同,尽管从一种类型到另一种类型的转换很容易。虽然从程序的角度来看很简单,但在类型表示的方式上存在重要差异。
你可能已经注意到,包括null的VendorID被编码为整数。在 pandas/NumPy 中,除非我们重新编码 NA,否则这是不可能的。Arrow 对缺失值的实现与 NumPy/pandas 完全不同:有一个额外的位数组,每行有一个条目,指示值是否缺失。这意味着,在大多数类型的适度内存成本下,缺失值不需要在类型本身上进行任何表示。因此,整数可以用整数表示,无需对缺失值进行重新编码。
使用 Arrow 对缺失值的表示,我们可以显著减少许多列的内存需求,但在某些情况下我们仍然需要通知 Arrow。为此,PyArrow 提供了ConvertOptions类:
convert_options = csv.ConvertOptions(
column_types = {
"VendorID": pa.bool_() ①
},
true_values=["Y", "1"], ②
false_values=["N", "2"])
table = csv.read_csv(
"../sec1-intro/yellow_tripdata_2020-01.csv.gz",
convert_options=convert_options ③
)
print(
table["store_and_fwd_flag"].unique(),
table["store_and_fwd_flag"].nbytes // (1024 ** 2),
table["store_and_fwd_flag"].nbytes // 1024
)
① Arrow 可以从 true_values 和 false_values 推断出 store_and_fwd_flag 的类型为布尔型,但鉴于 VendorID 是数字型,Arrow 需要明确声明其类型。
② 我们通知 Arrow 将 Y 和 1 转换为 true。这仅用于布尔类型的列。对于 false 值,我们执行类似的操作。
③ 我们将转换选项传递给 CSV 读取器。
从技术上来说,VendorID不是一个boolean,但由于它只有两个可能的值,我们将其重新编码为布尔型以节省内存。
由于缺失值是单独处理的,具有少量可能状态但包含缺失值的列的内存成本要高效得多,其中boolean是最极端的情况。
对于store_and_fwd_flag,我们显示三个值,以及内存占用。因为它低于 1 MB,所以我们还将值打印为 790 KB。
不幸的是,我们将不得不将 Arrow 转换为 pandas 进行分析,因为内部格式不同。因此,我们付出了内存和时间上的代价:
table_df = table.to_pandas()
如预期的那样,pandas 版本需要更多的内存。例如,store_and_fwd_flag现在是一个对象。即使你将其转换为更有效的类型,它也比 Arrow 对缺失值的表示方式要紧凑。
在这个阶段,看起来我们使用 Arrow 读取文件所获得的许多收益都丢失了,但事实并非如此。有两个要点需要考虑:转换数据的时间成本和内存占用。这两者都可以解决。
转换的时间成本是读取文件时收益的一小部分:在 pandas 中读取大约需要 12 秒。在 Arrow 中读取大约需要 2 秒,然后加上转换时间。在我的电脑上,这个时间是 23 毫秒,与读取收益相比可以忽略不计。
内存问题确实是存在的,并且按照当前的方法,在某个时刻,我们需要两倍多的内存。幸运的是,Arrow 提供了一个解决方案。你可以要求 Arrow 在转换过程中自我销毁 Arrow 结构。这不会在任何时刻增加内存消耗,但会以销毁 Arrow 版本为代价:
mission_impossible = table.to_pandas(self_destruct=True)
要点
正如这个例子应该清楚地表明的那样,与 pandas 相比,Arrow 在时间和空间效率方面都提供了现代功能,并且改进了效率。到目前为止,Arrow 在提供与 pandas 相同的分析能力方面还相去甚远,因此它最好与 pandas 集成使用。但让我们花几分钟时间看看使用 Arrow 进行数据分析的样子,并尝尝未来版本 Arrow 可能实现的功能。
7.4.3 使用 Arrow 进行分析
由于这本书是关于你现在可以做什么来提高效率,我们对 Arrow 感兴趣的是它是如何服务于 pandas 的。但事实是,Arrow 可以独立进行数据分析。尽管与 pandas 相比,它在分析功能方面严重缺乏,但随着时间的推移,其分析能力只会不断提高。
为了看到 Arrow 如何进行数据分析,我们再次回到我们对出租车司机小费的分析:
import pyarrow.compute as pc
t0 = table.filter(
pc.not_equal(table["total_amount"], 0.0))
pc.mean(pc.divide(t0["tip_amount"], t0["total_amount"]))
代码的细节以及由此延伸的设计理念与 pandas 有很大差异。首先,在这个阶段,接口是低级别的:如果你尝试使用整数(0)而不是浮点数(0.0)来执行 not_equal 操作,代码将会失败,因为类型不同。其次,接口在功能上比面向对象的方式更为实质;注意我们使用数组参数调用函数,而不是在数组上调用方法。最后,错误报告基于错误代码,而不是抛出异常。无论 API 偏好如何,通过将底层的 C Arrow 库与最少的 Python 习惯用法映射在一起,都可以获得潜在的好处,其中之一就是速度。
在我的电脑上,计算小费比例的时间成本大约是 15 毫秒。这个时间大约是等效 pandas 版本(get_tip_mean_vector)的一半。
我们已经看到了 Arrow 如何高效地将数据加载到 pandas 中。但是,我们还可以利用 Arrow 和 pandas 的第二种方式,接下来我们就转向这一点。让我们看看 Arrow 如何帮助不同语言之间的互操作性;记住,Arrow 在不同实现之间提供了一个标准的内存格式。
7.5 使用 Arrow 互操作将工作委托给更高效的语言和系统
Arrow 的一个优点是其标准的内存格式,它允许数据结构表示在许多不同语言之间的实现中共享。它是通过零拷贝来共享的,或者至少,它以高效的方式传输数据结构表示。在本节中,我们将探讨为什么 Arrow 架构比其他替代方案更高效。我们还将实现一个使用 Arrow 的 Plasma 服务器进行进程间通信的示例。
本节的主要目标是说明 Arrow 如何使进程间通信更加高效。鉴于 Plasma 正在积极开发中,并且你实际上可以在进程间显式地实现内存共享,这里的内容更多的是作为一个设计模式的说明。当然,代码是完全功能性和可用的。
7.5.1 Arrow 语言互操作架构的影响
想象一个场景,你大部分的处理工作在 Python 中完成,但你需要使用一段 R 代码来进行一些分析。有几种进行进程间和语言间通信的方法,但考虑两种典型的非 Arrow 场景和两种典型的 Arrow 场景,如图 7.3 所示,如下所述:
-
第一种方法是将我们的数据以 Python(例如,CSV)的文件格式写入,然后在 R 中读取。这可能非常节省内存,但由于磁盘使用,时间成本将非常低。
-
第二种方法将使用 rpy2³,这将需要将 pandas 转换为 R 数据框,这在 R 世界中大致相当于 pandas。这将至少使转换成本加倍,无论是时间还是内存。实际上,正如我们将在下面的讨论中看到的那样,情况比这还要糟糕。
-
使用 Arrow,第三种方法是将 pandas 数据框转换为 Arrow,然后传递给 R,再从 Arrow 转换,并处理它。这需要两个转换的时间(pandas 到 Arrow 和 Arrow 到 R)。在内存方面,结果将取决于你是否可以进行破坏性操作,这不会增加任何消耗;如果不能,你将需要在转换期间大约加倍你的内存。
-
最后,使用 Arrow 的最佳情况是,如果你在 Python 和 R 中同时处理数据,仅基于 Arrow(例如,不使用 pandas),那么这只是一个传递内存指针的问题,这在处理和内存方面都不会增加任何成本。

图 7.3 Python 和 R 之间互操作的一些替代方案
初看起来,方案 2(从 pandas 的本地格式转换为 R 的格式)似乎比方案 3(从 pandas 转换为 Arrow,然后从 Arrow 转换为 R)更高效。然而,这并不是事实,原因有几个:
-
内存中的 Arrow 格式被所有 Arrow 实现共享,无论语言如何。这意味着共享一个 Arrow 数据结构本质上就是共享一个内存指针。
-
当从 pandas 转换为 R 时,转换必须在一边进行,要么是 Python 要么是 R,这意味着它将是非本地的,并且在其中一个格式(即非本地的)中非常低效。
-
Arrow 转换器是用 C/C++编写的,多线程的,从头开始设计,以确保尽可能高效。您在前一节关于使用 CSV 读取器性能的讨论中已经看到了这种哲学的效果。
另一个重要的原因是:在复杂系统中,您将需要 2n 个转换器。例如,如果您使用 pandas、Java、R 和 Rust,您将需要 pandas/Java、pandas/R、Java/R、Java/Rust、pandas/Rust 和 R/Rust 转换器。但如果您使用 Arrow 作为中间格式,那么您只需要四个:pandas/Arrow、Java/Arrow、R/Arrow 和 Rust/Arrow。如果您使用更多的系统,组合会爆炸式增长。
本书不涉及 Python 之外的语言的使用,但您可以自由地尝试其他语言选项,以衡量每种语言的性能影响。
接下来我们将要做的,完全使用 Python,是展示如何高效地进行数据进程间通信。在许多实际场景中,您的一个或多个组件可能会用不同的语言实现。
7.5.2 使用 Arrow 的 Plasma 服务器进行零拷贝数据操作
让我们回到我们熟悉的老 NYC 出租车数据集,并计算一组数据统计。我们将将其分为三个进程:一个读取数据并将其提交处理,另一个进行分析,第三个仅显示结果。考虑这种架构有两个主要原因:(1)可能有一个算法的实现可用,它位于一个单独的进程中,无法直接链接到 Python,并且(2)我们可能更喜欢将更昂贵的处理代码与分析代码分开。
Arrow 提供了一个名为 Plasma 的服务器,用于管理共享内存:它允许您注册、读取和写入对象,以及所有查询现有对象目录的操作。这简化了进程以标准方式找到彼此的方式。这个服务器是本地的——也就是说,不能通过网络访问,但可以通过本地套接字访问。它主要存在是为了促进内存共享,轻松找到现有对象,并允许在进程生命周期不重叠的情况下共享内存(即,消费者进程在生产者死亡后启动)。
我们需要做的第一件事是启动 Plasma 服务器:
plasma_store -s /tmp/fast_python -m 1000000000
这将使用 UNIX 套接字/tmp/fast_python,这是一种进程间通信的形式,以允许进程通信。将分配一个千兆字节的共享空间。
我们的第一个进程负责加载 CSV 文件并将其放入 Plasma:我们连接到 Plasma 套接字,使用 Arrow 读取文件,并将其存入 Plasma(见07-pandas/sec5-arrow-plasma/load_csv.py):
import os
import sys
import pyarrow as pa
from pyarrow import csv
import pyarrow.plasma as plasma
csv_name = sys.argv[1]
client = plasma.connect("/tmp/fast_python") ①
convert_options = csv.ConvertOptions( ②
column_types={
"VendorID": pa.bool_()
},
true_values=["Y", "1"],
false_values=["N", "2"])
table = csv.read_csv(
csv_name
convert_options=convert_options
)
pid = os.getpid()
plid = plasma.ObjectID(
f"csv-{pid}".ljust(20, " ").encode("us-ascii")) ③
client.put(table, plid) ④
① 我们通过套接字连接到 Plasma。
② 我们假设 CSV 是纽约出租车格式。
③ 我们为我们的表格创建一个 ID。
④ 我们将对象放入等离子体中。
当我们将对象放入等离子体时,我们必须给它一个 ID(即命名它)。我们将使用以csv-开头的名称,后跟我们的进程的 PID(进程 ID)。这个名称对我们的目的来说已经足够好了,但您可能需要一个与其他名称字符串冲突可能性更小的名称作为更通用的解决方案。等离子体需要一个 20 字节的 ID,因此我们用空格填充名称以达到 20 个字节的大小,然后使用 US-ASCII 编解码器对字符串进行编码,这将返回一个字节数组。只要它将一个字符转换为 1 字节,您可以使用任何编解码器,否则您将得到一个太长的字节数组。
我们使用对象 ID 来表示和查找我们的表格。还有其他可能更复杂的技巧,例如使用对象元数据,但只要我们有找到我们感兴趣对象的方法,对象 ID 就足以用于说明目的。
警告:如果您没有足够的可用内存,等离子体将驱逐较旧的对象。
在我们实现其他两个过程之前,让我们创建一个支持脚本以列出当前在等离子体中可用的所有 CSV 文件,这允许我们监控等离子体中的内容。我们还将监控结果,我们将使用前缀result-命名(见07-pandas/sec5-arrow-plasma/list_csvs.py):
import pyarrow as pa
import pyarrow.plasma as plasma
client = plasma.connect("/tmp/fast_python")
all_objects = client.list() ①
for plid, keys in all_objects.items():
try:
plid_str = plid.binary().decode("us-ascii")
except UnicodeDecodeError: ②
continue
if plid_str.startswith("csv-"):
print(plid_str, plid)
print(keys)
elif plid_str.startswith("result-"):
print(plid_str, plid)
print(keys)
① 我们列出所有对象。
② 由于 ID 的解码可能不会产生有效的字符串,我们需要捕获异常。
在我们获取所有对象的列表后,该列表作为字典返回,我们寻找以csv-或result-开头的 ID。因为并非所有 ID 都可以实际转换为字符串(即,其他东西可能在等离子体服务器中共享),我们小心地捕获所有无法解码的异常,然后忽略它们,因为它们并不是真正的错误。
对于所有相关情况,我们打印解码后的 ID、原始 ID 和一些相关元数据。以下是一个示例:
csv-579123 ObjectID(6373762d35373931323320202020202020202020)
{'data_size': 822037944, 'metadata_size': 0, 'ref_count': 0,
'create_time': 1616361341, 'construct_duration': 0,
'state': 'sealed'}
现在,让我们实现我们的计算服务器。它将处于永恒循环中,寻找以csv-开头的对象。如果找到一个对象并且结果尚不存在,那么执行计算并将结果根据以result-开头的对象 ID 命名转换放入等离子体中(见07-pandas/sec5-arrow-plasma/compute_stats.py):
import time
import pandas as pd
import pyarrow as pa
from pyarrow import csv
import pyarrow.compute as pc
import pyarrow.plasma as plasma
client = plasma.connect("/tmp/fast_python")
while True:
client = plasma.connect("/tmp/fast_python")
all_objects = client.list()
for plid, keys in all_objects.items():
plid_str = ""
try:
plid_str = plid.binary().decode("us-ascii")
except UnicodeDecodeError:
continue
if plid_str.startswith("csv-"):
original_pid = plid_str[4:]
result_plid = plasma.ObjectID(
f"result-{original_pid}".ljust(
20, " ")[:20].encode("us-ascii"))
if client.contains(result_plid): ①
continue
print(f"Working on: {plid_str}")
table = client.get(plid) ②
t0 = table.filter(
pc.not_equal(table["total_amount"], 0.0))
my_mean = pc.mean(
pc.divide(t0["tip_amount"], t0["total_amount"])).as_py()
result_plid = plasma.ObjectID(
f"result-{original_pid}".ljust(20, " ")[:20]
.encode("us-ascii"))
client.put(my_mean, result_plid) ③
time.sleep(0.05)
① 我们检查结果是否已存在。
② 我们从等离子体获取表格。
③ 我们将结果放入等离子体。
我们的大部分代码已在当前和上一节中展示。唯一的概念创新是使用contains函数来查看结果是否已存在,以及使用get函数来获取表格。
最后,让我们看看结果(见07-pandas/sec5-arrow-plasma/show_results.py):
import pyarrow as pa
import pyarrow.plasma as plasma
client = plasma.connect("/tmp/fast_python")
all_objects = client.list()
for plid, keys in all_objects.items():
try:
plid_str = plid.binary().decode("us-ascii")
except UnicodeDecodeError:
pass
if plid_str.startswith("result-"):
print(plid_str, client.get(plid, timeout_ms=0))
这段代码实际上并没有什么新东西,但请注意,我们在最后一行指定超时为 0 时不阻塞地获取对象。等离子体默认具有阻塞语义,但您可以根据需要指定超时。
关于 Plasma(例如,使用更底层的 API 获取和放置对象或如何有效地传输 pandas 对象)还有很多可以说的。然而,由于 Arrow/Plasma 架构的开发仍在演进中,因此了解我们在这里所涉及到的 IPC 概念可能更为重要。
我们将在下一章讨论数据持久性时重新审视 Arrow。Arrow 在存储方面有很多可提供的内容。
摘要
-
pandas 是 Python 世界中最广泛使用的数据分析库,但它并不是在设计时优先考虑效率,无论是在计算效率还是内存存储效率方面。
-
加载数据的非常简单技术可以通过例如忽略一些不需要用于计算的列或提前通知 pandas 每个列的类型来显著减少 pandas 数据框的内存占用。
-
聪明地使用索引可以减少处理时间,尽管 pandas 索引的灵活性有所限制。
-
不同的行迭代策略在性能上可能会有两个数量级以上的差异。尽可能避免显式循环,并使用向量化操作。
-
虽然 pandas 是建立在 NumPy 之上的,但有时明确要求从 pandas 中提取 NumPy 数据结构可以进一步提高性能速度。
-
Cython 可以与 pandas 一起使用,尽管是通过 NumPy 数据结构间接使用,并且可以显著提高速度。
-
对于非常大的数据框和复杂的公式,NumExpr 可能是进行数据分析的有效策略。
-
Apache Arrow 可以执行许多任务。在本章中,我们重点关注它如何补充 pandas,特别是作为数据快速读取器。但请确保检查项目的其他功能。
-
Arrow 的架构,通过其 Plasma 服务器,可以在同一台机器上的进程之间进行高效的数据传输。这对于使用不同语言和框架处理数据非常有用,因为该格式在所有实现的语言中都是共享的。
¹ 如果你打算运行这个示例,考虑使用子采样数据框 df_10 和 df_100:你仍然可以感受到所需时间,而无需等待那么久。
² 当然,从 Python 的角度来看,这种差异只是概念上的:如果你向函数传递了一个序列,它也会正常工作,但使用 pandas 对象。我们在这里做一个实现点,即使语言更加灵活。
³ 如果你同时使用 R(在数据科学领域,Python 与 R 的搭配相当常见)——考虑通过使用 rpy2 包来整合两者,该包将在 Python 中嵌入 R 进程,并为 Python/R 通信提供优雅的原语。
8 存储大数据
本章涵盖
-
了解 fsspec,一个在文件系统之上的抽象库
-
使用 Parquet 高效存储异构列式数据
-
使用 pandas 或 Parquet 等内存库处理数据文件
-
使用 Zarr 处理同构的多维数组数据
在处理大数据时,持久性至关重要。我们希望尽可能快地访问——读取和写入——数据,最好是来自多个并行进程。我们还希望持久表示紧凑,因为存储大量数据可能很昂贵。
在本章中,我们将探讨几种使数据持久化存储更有效的方法。我们将从对 fsspec 的简要讨论开始,fsspec 是一个抽象访问本地和远程文件系统的库。虽然 fsspec 并不直接涉及性能问题,但它是一个现代库,被许多应用程序用于处理存储系统,并且在高效的存储实现中经常被使用。
然后,我们将考虑 Parquet,这是一种用于持久化异构列式数据集的文件格式。Parquet 通过 Apache Arrow 项目在 Python 中得到支持,该项目在前一章中已介绍。
接下来,我们将讨论对非常大的数据集进行分块读取,有时称为离核方法。通常,我们存储的数据集无法一次性全部在内存中处理。分块读取允许您使用您已经熟悉的软件库分部分批处理数据,这是一种简单但非常有效的策略。我们的示例将从一个大的 pandas 数据框转换成 Parquet 文件。最后,我们将探讨 Zarr,这是一种用于在持久内存中存储多维同构数组(即 NumPy 数组)的现代格式和库。
对于本章,您需要安装 fsspec、Zarr 和 Arrow,后者提供了 Parquet 接口。要安装 conda,您可以使用conda install fsspec zarr pyarrow。tiagoantao/python-performance-dask Docker 镜像包含了所有必要的库。让我们先对 fsspec 库进行简要概述,它允许我们使用相同的 API 处理不同类型的本地和远程文件系统。
8.1 文件访问的统一接口:fsspec
存储文件系统有很多系统,从古老的本地文件系统,到云存储如 Amazon S3,再到 SFTP 和 SMB(Windows 文件共享)等协议。列表很长,特别是如果我们考虑到还有许多其他类似文件系统的对象:例如,zip 文件是一个文件和目录容器,HTTP 服务器有一个可遍历的树,等等。
处理每种类型的文件系统意味着需要学习每种类型的不同编程 API——这是一个费时甚至痛苦的过程。fsspec 就是这样一种库,它通过统一的 API 抽象出许多文件系统类型。使用 fsspec,你只需要学习一个 API 就可以与许多文件系统类型交互。有几个小问题:例如,你不能期望本地文件系统的行为与远程文件系统相同,但该库通过最小开销大大简化了对文件系统的访问。
8.1.1 使用 fsspec 在 GitHub 仓库中搜索文件
为了说明 fsspec 的工作原理,我们将使用它遍历 GitHub 仓库以查找 zip 文件,然后确定这些 zip 文件是否包含 CSV 文件。在这个练习中,我们将 GitHub 仓库视为文件系统。这并不像听起来那么牵强。当你这么想的时候,GitHub 仓库本质上是一个带有版本化内容的目录树。
对于一个示例仓库,我们将使用本书中的仓库。在 08-persistence/ 01-fspec 中,你可以找到一个名为 dummy.zip 的 zip 文件,其中包含两个虚拟 CSV 文件。我们的代码将遍历仓库,找到所有 zip 文件——在我们的情况下,只有 dummy.zip 存在——打开它们,并使用 pandas 的 describe 命令来总结所有 CSV 文件。
让我们先使用 fsspec 访问仓库并列出根目录:
from fsspec.implementations.github import GithubFileSystem
git_user = "tiagoantao"
git_repo = "python-performance"
fs = GithubFileSystem(git_user, git_repo)
print(fs.ls(""))
我们导入 GithubFileSystem 类,传递用户和仓库名,并列出顶级目录。请注意,根目录由空字符串表示,而不是典型的 /。fsspec 提供了许多其他类来访问存储,如本地文件系统、压缩文件、Amazon S3、Arrow、HTTP、SFTP 等。
fs 对象具有与 Python 文件系统接口的几个常用方法。例如,为了遍历文件系统,我们需要这样做以找到所有 zip 文件,存在一个 walk 方法,它与 os 模块的 walk 方法非常相似:
def get_zip_list(fs, root_path=""):
for root, dirs, fnames in fs.walk(root_path):
for fname in fnames:
if fname.endswith(".zip"):
yield f"{root}/{fname}"
get_zip_list 是一个生成器,它产生所有现有 zip 文件的完整路径。请注意,如果 root_path 是 /,则代码与 os.walk 完全相同。
fsspec 接口限制
虽然 fsspec 为文件系统提供了一个统一且简单的接口,但它不能隐藏所有的语义差异。实际上,在某些情况下,我们并不希望它隐藏所有的差异。以 GitHubFileSystem 为例,这里有两个可能看到差异的情况:
-
额外功能——你可以在任何时间点导航仓库,而不仅仅是当前主分支的时间点。你可以指定一个分支或标签,fsspec 将允许你在那个精确点检查仓库。
-
限制——你不仅会遇到远程文件系统的典型问题(例如,如果你没有连接到互联网,代码将无法工作),而且如果你多次查询服务器,它将对你进行速率限制。
现在我们有了仓库中 zip 文件列表,作为一个初步的、天真的解决方案,我们将从仓库复制 zip 文件到本地文件系统。这里的想法是我们将本地打开它们,看看它们是否有 CSV 文件:
def get_zips(fs):
for zip_name in get_zips(fs):
fs.get_file(zip_name, "/tmp/dl.zip")
yield zip_name
现在,我们可以检查文件内部的内容。为此,我们再次天真地使用 Python 内置的 zipfile 模块:
import zipfile
import pandas as pd
def describe_all_csvs_in_zips(fs):
for zip_name in get_zips(fs):
my_zip = zipfile.ZipFile("/tmp/dl.zip") ①
for zip_info in my_zip.infolist(): ②
print(zip_name)
if not zip_info.filename.endswith(".csv"):
continue
print(zip_info.filename)
my_zip_open = zipfile.ZipFile("/tmp/dl.zip")
df = pd.read_csv(zipfile.Path(my_zip_open,
➥ zip_info.filename).open())
print(df.describe())
① 我们在这里使用 zipfile 模块打开文件。
② 注意 infolist 方法是 zipfile 模块的特有方法,这是需要学习的。
注意我们需要学习的新 API,用于 zipfile。我们从构造函数开始,然后使用了 infolist 方法,但由于 zipfile 的语义,我们可能需要在列表中间重新打开 zip 文件。
8.1.2 使用 fsspec 检查 zip 文件
那段之前的代码列表只是简单说明了 fsspec 帮我们避免的 混乱。fsspec 提供了访问 zip 文件的接口,因此我们可以像这样重写代码:
from fsspec.implementations.zip import ZipFileSystem
def describe_all_csvs_in_zips(fs):
print(zip_name)
for zip_name in get_zips(fs):
my_zip = ZipFileSystem("/tmp/dl.zip")
for fname in my_zip.find(""): ①
if not fname.endswith(".csv"):
continue
print(fname)
df = pd.read_csv(my_zip.open(fname)) ②
print(df.describe())
① find 方法,以及所有其他方法,对所有类型的文件系统都存在,而不仅仅是 zip。
② 与 find 方法一样,open 也适用于所有类型的文件系统。
除了创建 ZipFileSystem 对象外,接口与 GitHub 和常见的 Python 文件接口完全相同。不需要学习 zipfile 接口。
8.1.3 使用 fsspec 访问文件
你也可以使用 fsspec 直接打开文件,尽管其语义与标准的 open 函数略有不同。例如,要使用 fsspec 的 open 打开 zip 文件,我们使用以下代码:
dlf = fsspec.open("/tmp/dl.zip")
with dlf as f: ①
zipf = zipfile.ZipFile(f) ②
print(zipf.infolist())
dlf.close()
① 打开文件时,我们需要使用 with 语句。
② 我们再次使用 Python 的 zipfile 模块来解析文件。
输出结果为:
[
<ZipInfo filename='dummy1.csv' filemode='-rw-rw-r--' file_size=22>,
<ZipInfo filename='dummy2.csv' compress_type=deflate
filemode='-rw-rw-r--' file_size=56 compress_size=54>
]
注意在 open 之后需要使用 with 语句来获取合适的文件描述符,这与仅使用 open 函数的典型方法不同。
8.1.4 使用 URL 链接透明地遍历不同的文件系统
让我们回到 GitHub 仓库内的 zip 文件。注意,因为我们可以将 zip 文件解释为文件的容器,所以这个 zip 文件就像在另一个文件系统中有一个文件系统。fsspec 有一种声明式的方法,允许我们轻松地访问我们的数据:URL 链接。有时你可以取一个流并将其重新解释为文件系统。一个例子将使这一点更清晰;让我们打印 dummy1.csv 的内容:
d1f = fsspec.open("zip://dummy1.csv::/tmp/dl.zip", "rt")
with d1f as f:
print(f.read())
注意 URL 链接的实际应用:我们从 /tmp/dl.zip 中获取 dummy1.csv。你不需要显式打开 zip 文件;fsspec 为你处理了这一点。
记得我们称我们的 get_zips 实现为“天真”吗?这是因为我们不需要显式下载文件,多亏了 URL 链接:
d1f = fsspec.open(
"zip://dummy1.csv::github://tiagoantao:python-performance@/08-"
"persistence/sec1-fsspec/dummy.zip")
with d1f as f:
print(pd.read_csv(f))
我们硬编码完整的链式 URL 以清楚地说明一个明确的用法示例。
8.1.5 替换文件系统后端
现在,因为 fsspec 抽象了文件系统接口,所以很容易替换文件系统实现。例如,让我们将 GitHub 替换为本地文件系统。这很简单:
import os
from fsspec.implementations.local import LocalFileSystem
fs = LocalFileSystem()
os.chdir("../..")
这假设您正在从目录 08-persistence/sec1-fsspec 运行脚本;因此,../.. 将是本书库的根目录。
我们使用 LocalFileSystem 而不是 GitHubFileSystem,这就是全部。因为我们在这个代码中从仓库顶部向下运行了两个层级,我们需要向上移动到树的一级——因此,需要使用 chdir。现在代码的工作方式是在本地文件系统之上,而不是 GitHub。例如,运行 describe_all_csvs_in_zips(fs)。
8.1.6 与 PyArrow 交互
最后,值得注意的是,我们在上一章中讨论的 PyArrow 可以直接与 fsspec 交互:
from pyarrow import csv
from pyarrow.fs import PyFileSystem, FSSpecHandler
zfs = ZipFileSystem("/tmp/dl.zip")
arrow_fs = PyFileSystem(FSSpecHandler(zfs))
my_csv = csv.read_csv(arrow_fs.open_input_stream("dummy1.csv"))
这里重要的是,Arrow 有文件系统的概念,这使得它可以自然地与 fsspec 集成。Arrow 文件系统可以通过 pyarrow.fs.FSSpecHandler 与 fsspec 互连。一旦以这种方式映射了 fsspec 文件系统,就可以在它之上透明地使用 Arrow 文件系统原语。
提示 fsspec 支持从远程服务器部分下载数据的能力,这在可能只需要大文件的一部分的大数据场景中可能很重要。只有当我们尝试使用的服务器类型支持部分文件下载时,才能这样做。例如,GitHub 不支持它;相反,S3 支持。您可以通过在调用 open 时激活缓存来启用此功能,使用参数 cache_type 并将其值设置为 readahead。
这有点跑题了,因为 fsspec 并不直接与性能相关,尽管它被用于许多与性能相关的库中,如 Dask、Zarr 和 Arrow。现在,让我们回到我们的正常编程计划,探讨高效存储异构列式数据(即数据框)的方法。
8.2 Parquet:一种高效的列式数据存储格式
将数据存储在 CSV 中充满了问题。首先,因为它们无法容纳每列的类型,所以在列中意外值并不少见。此外,该格式本身效率低下。例如,您可以用二进制形式比文本形式更紧凑地表示数字。此外,您不能在常数时间内跳转到特定的行或列,因为无法计算其位置,因为 CSV 中的每一行大小都可能不同。
Apache Parquet 正在成为最常用的格式,用于高效存储异构列式数据。这意味着您可以访问所需的列,并且还可以使用数据压缩和列编码格式来提高性能。
在本节中,我们将学习如何使用 Parquet 存储数据框,借鉴上一章中纽约市出租车数据。在完成这个任务的过程中,我还会介绍许多 Parquet 功能的概览。
警告:Parquet 是一个起源于 Java 世界、特别是在 Hadoop 生态系统中的文件格式。虽然可用的 Python 实现非常适合生产目的,但它们并没有完全实现规范。例如,我们无法详细指定我们想要如何编码列;我们也不能检查列是如何存储的——我将在下面展示这一点。但对于绝大多数用例,必要的功能是存在的,并且随着时间的推移只会增加。
就作为一个提醒,出租车数据集包含了关于一段时间内纽约市所有出租车行程的信息。信息包括但不限于行程的开始和结束时间、开始和结束位置、费用、税费和小费。我们将从使用与上一章相同的文件开始,该文件包括 2020 年 1 月的出租车行程。我们将首先将 CSV 文件转换为 Parquet。为此,我们将使用上一章中介绍的 Apache Arrow。代码可以在 08-persistence/ sec2-parquet/start.py 中找到:
import pyarrow as pa
from pyarrow import csv
import pyarrow.parquet as pq
table = csv.read_csv(
"../../07-pandas/sec1-intro/yellow_tripdata_2020-01.csv.gz")
pq.write_table(table, "202001.parquet")
我们简单地使用 PyArrow 的 Parquet 模块中的 write_table。最终我们得到一个 111 MB 的二进制文件。压缩后的 CSV 文件是 105 MB,原始未压缩版本是 567 MB。因为 Parquet 是一种结构化二进制格式,所以我们应预期相同内容的大小会有所不同。这里的重点不是纠结于细节,而是要了解大小关系。
8.2.1 检查 Parquet 元数据
让我们通过检查文件来发现一些 Parquet 的特性:
parquet_file = pq.ParquetFile("202001.parquet")
metadata = parquet_file.metadata
print(metadata)
print(parquet_file.schema)
group = metadata.row_group(0)
print(group)
简化的输出是:
<pyarrow._parquet.FileMetaData object at 0x7f90858879f0>
created_by: parquet-cpp-arrow version 4.0.0
num_columns: 18
num_rows: 6405008
num_row_groups: 1
format_version: 1.0
serialized_size: 4099
<pyarrow._parquet.ParquetSchema object at 0x7f9193aeed00>
required group field_id=0 schema {
optional int32 field_id=1 VendorID (Int(bitWidth=8, isSigned=false));
optional int64 field_id=2 tpep_pickup_datetime (
Timestamp(isAdjustedToUTC=false, timeUnit=milliseconds,
is_from_converted_type=false, force_set_converted_type=false));
....
<pyarrow._parquet.RowGroupMetaData object at 0x7f90858ad0e0>
num_columns: 18
num_rows: 6405008
total_byte_size: 170358087
我们首先打印文件的元数据。这里我们只是获取一些摘要信息,例如有 18 列和 6,405,008 行。Parquet 还告诉我们文件中只有一个行组。行组是总行数的分区:在较大的文件中,可能有多个行组。行组将包含组中所有行的列数据。记住,Parquet 中的信息是按列组织的。这很快就会变得清晰。
然后我们打印文件的架构。简化的版本是:
required group field_id=0 schema {
optional int32 field_id=1 VendorID (Int(bitWidth=8, isSigned=false)); ①
optional int64 field_id=2 tpep_pickup_datetime (
Timestamp(isAdjustedToUTC=false, timeUnit=milliseconds,
is_from_converted_type=false, force_set_converted_type=false));
optional double field_id=5 trip_distance;
optional binary field_id=7 store_and_fwd_flag (String);
}
① 这里是 VendorID 的定义,它有 8 位宽,且未签名。
这段代码列出了我们数据中的所有列。例如 VendorID 是一个 int32,但请注意位宽是 8 位,且是无符号的。VendorID 只有两个可能的值加上一个空值,所以将其实现减少到仅 8 位无符号位是有意义的。理论上,Parquet 支持这种减少,甚至可以减少到更少的位。
然后是 tpep_pickup_datetime,这是一个时间戳。从存储的角度来看,时间单位是最重要的变量,因为更高的精度需要更多的空间。pandas 默认为纳秒精度。注意 store_and_fwd_ flag:文本以通用二进制数据存储。
8.2.2 Parquet 的列编码
现在我们来看看几个列的现有元数据:
tip_col = group.column(13) # tip_amount
print(tip_col)
简化的输出是:
physical_type: DOUBLE
num_values: 6405008
path_in_schema: tip_amount
statistics: ①
has_min_max: True
min: -91.0
max: 1100.0
null_count: 0
distinct_count: 0
num_values: 6405008
physical_type: DOUBLE
logical_type: None
converted_type (legacy): NONE
compression: SNAPPY ②
encodings: ('PLAIN_DICTIONARY', 'PLAIN', 'RLE')
has_dictionary_page: True
① 列的统计信息从这里开始。
② 列中使用的压缩算法
元数据从物理类型、值的数量和列名开始。统计信息(似乎有一个负的小费——可能是输入错误)显示$-91 为最小值,$1,000 为最大小费。现在事情开始变得真正有趣。
关于数据的存储,Parquet 可以压缩列,这可以节省磁盘空间。压缩列还可以提供与上一章讨论的缓存管理问题相关的潜在计算收益。不同的列可以有不同的压缩类型或根本不进行压缩。
在我们的例子中,使用的是 Snappy 压缩算法。与 gzip(也是一个选项)相比,Snappy 在压缩和速度之间进行了权衡。确保在使用时检查 Arrow 实现了哪些压缩算法。Facebook 在facebook.github.io/zstd/#benchmarks有一些基准测试信息,可以帮助你做出决定。
例如,你可以使用 ZSTD:
pq.write_table(table, "202001_std.parquet", compression="ZSTD")
在本例中,我们使用 ZSTD 对所有列进行压缩。在这种情况下,从 Snappy 的 110 MB 降至 82 MB。
Parquet 不仅可以直接使用值编码列,还可以使用字典,其中长值被转换为间接引用,这可能会节省大量磁盘空间。为了了解这如何有所帮助,考虑小费是以双精度表示的,需要 64 位,而小费只有 3626 个不同的值:
print(len(table["tip_amount"].unique()))
字典可以将编码从每值 64 位减少到 12 位,这足以编码高达 4096 个值。我们还需要存储字典,这对于 3626 个值来说是多余的。然而,因为我们有很多不同的值,使用字典可能没有意义。你可以使用write_table来控制是否使用字典存储列。
最后但同样重要的是,请注意编码也使用了 RLE,即运行长度编码。让我们用一个有点愚蠢的例子来看看 RLE 的优势。让我们创建一个包含VendorID列的数据帧,后面跟着另一个也带有VendorID列,但有序:
import pyarrow.compute as pc
silly_table = pa.Table.from_arrays([
table["VendorID"],
table["VendorID"].take(
pc.sort_indices(table["VendorID"]))],
["unordered", "ordered"]
)
所以这是相同的数据,有序和无序版本。现在让我们看看每个列在 Parquet 文件中占用多少空间:
pq.write_table(silly_table, "silly.parquet")
silly = pq.ParquetFile("silly.parquet")
silly_group = silly.metadata.row_group(0)
print(silly_group.column(0))
print(silly_group.column(1))
无序文件占用 953,295 字节,而有序文件仅占用 141 字节!RLE 的工作方式是存储值和重复次数。对于有序的VendorID列,我们有一个极端案例:我们只有三个值(1、2和null),它们是有序的。所以理论上,RLE 可以存储:1.0 2094439 / 2.0 4245128 / null 65441。
RLE 可以相当大幅度地压缩数据。虽然就效率而言我们的案例是极端的,但 RLE 通常适用于有序字段或值较少的字段。然而,如果你偏离了这些假设,确保你评估你获得的压缩效益。
较小的文件有助于存储 和 处理时间。记住,在第六章中提到,如果您可以将数据存储在更快的内存类型中,您有时可以在性能上获得数量级的提升。
该格式是可扩展的,因此您可以期待随着时间的推移开发出一种新的高效存储数据的方式。该格式还允许数据分区,从效率角度来看具有几个优点。让我们用一个例子来说明。
8.2.3 使用数据集进行分区
为了阐明分区意味着什么以及涉及的过程,让我们使用 VendorID 和 passenger_count 对我们的数据集进行分区。由于分区不能基于空值,我们将从数据集中删除这些值。我们只为这次练习这样做;通常情况下,您不能仅仅为了方便就删除空值行:
from pyarrow import csv
import pyarrow.compute as pc
import pyarrow.parquet as pq
table = csv.read_csv(
"../../07-pandas/sec1-intro/yellow_tripdata_2020-01.csv.gz")
table = table.filter(
pc.invert(table["VendorID"].is_null())) ①
table = table.filter(pc.invert(table["passenger_count"].is_null()))
pq.write_to_dataset(
table, root_path="all.parquet",
partition_cols=["VendorID", "passenger_count"])
① 再次注意,使用 Arrow 进行计算时的语法与 pandas 非常不同。
pandas 中第一行过滤语句的等价操作将是 table = table[~table ["VendorID"].isna()]。
如果您查看 all.parquet,您会发现一些惊喜:最大的惊喜是它不再是文件,而是一个目录!简化的内容可能如下所示:
.
├── VendorID=1
│ ├── passenger_count=0
│ │ └── e59ac47b5193411e9772bfee9d423d61.parquet
│ ├── passenger_count=1
│ │ └── ee90fe5b818d4a37a32b5a415915610b.parquet
│ └── passenger_count=9
│ └── 002ff0bba1d340abb6174c5c64f779d7.parquet
└── VendorID=2
├── passenger_count=0
│ └── 5809e29649524202a9b3cef5371c46d9.parquet
└── passenger_count=9
└── feaff7a23bbf4ae2b687b34dcaa10afb.parquet
目录结构反映了我们的分区策略。第一级目录中每个 VendorID 都有一个条目,第二级目录中每个 passenger_count 都有一个条目。
您现在有两个选择。最简单的一个——也许不那么有趣——是将所有内容都加载为一个表格:
all_data = pq.read_table("all.parquet/")
在这里,您将拥有所有数据作为一个正常的表格。您也可以采取以下措施达到相同的效果:
dataset = pq.ParquetDataset("all.parquet/")
ds_all_data = dataset.read()
但是,作为另一种选择,您也可以单独加载每个 parquet 文件。例如,让我们加载包含三个乘客的供应商 ID 1 的分区文件:
import os
data_dir = "all.parquet/VendorID=1/passenger_count=3"
parquet_fname = os.listdir(data_dir)[0] ①
v1p3 = pq.read_table(f"{data_dir}/{parquet_fname}")
print(v1p3)
① parquet 文件的名字并不保证,所以我们获取目录中的第一个文件。
如果您查看输出,您会注意到列 VendorID 和 passenger_count 缺失,因为它们可以从目录中推断出来。
警告:每个目录中的内容可能不同。在我们的案例中,使用 PyArrow,它是一个单独的 Parquet 文件。例如,您可以让 Parquet 将每个分区进一步拆分为一个文件,按行组拆分。因此,请确保您调查数据实际上是如何写入磁盘的,并相应地调整代码。
从性能角度来看,分区的目的是什么?我们现在可以单独加载每个 Parquet 文件并相应地处理每个文件。例如,我们可以通过在同一台机器上使用多个进程来提高性能,每个进程分析每个文件。我们甚至可以在不同的机器上处理不同的文件。隐含地,文件系统可以更高效,因为并发加载在不同的磁盘部分进行。我们还可以通过不加载分区列来获得内存上的收益。最后,分区开辟了并发写入的途径,这从并行性中获得了性能提升。我们将在第 8.4 节中更详细地讨论并发写入。
数据的分区方式从性能角度来看很重要。例如,供应商 1 的数据是供应商 2 的一半,这意味着处理供应商 2 的成本可能大约是供应商 1 的两倍。这种加倍可能会导致您等待所有分区中最慢的一个,因为您希望尽可能均匀。与passenger_count相比,VendorID可能是一个不错的选择。Parquet 有许多更多功能,但从性能角度来看,我们现在已经很好地概述了我们可以从该格式中受益的方式。
8.3 以传统方式处理大于内存的数据集
在本节中,我们将使用 Parquet 和 CSV 文件来介绍两种处理大于内存的数据的简单技术:内存映射和分块。处理这两个任务还有更复杂的方法,我们将在第 8.4 节以及下一章中讨论它们。但是,分块和内存映射是支撑更复杂库的重要概念。因此,理解它们不仅本身有效,而且对于理解更高级的技术也是基本的。
8.3.1 使用 NumPy 进行内存映射文件
当内存的一部分直接与文件系统的一部分关联时,就会发生内存映射。在 NumPy 的具体情况下,持久化到存储的数组可以使用正常的 NumPy API 进行评估,NumPy 将负责将我们从数组中需要的任何部分带到 RAM 中。在大多数情况下,这是由操作系统内核为 NumPy 透明完成的。相反,当我们写入时,它将更改持久表示。因为您正在评估内存,这可以以数量级的方式加快您的代码。图 8.1 描述了内存映射。

图 8.1 内存映射将文件的一部分映射到内存中。
在这种情况下,我们将使用一个简单的抽象示例来创建一个大数组并访问它。您可以决定数组的大小。为此练习,我建议一个比您的内存更大的大小,但您有足够的磁盘空间。分配相当简单:
import numpy as np
SIZE_IN_GB = 10 ①
array = np.memmap("data.np", mode="w+",
dtype=np.int8, shape=(SIZE_IN_GB * 1024, 1024, 1024))
print(array[-1, -1, :10])
① 将大小更改为适合您机器的值,如之前所述
np.memmap调用非常简单:您传递给它一个文件名、一个打开模式以及数组的类型和形状。如果您列出磁盘上的文件,您将找到一个大小为 10GB 的文件。
数组将以所有零初始化;因此,打印将显示一个包含 10 个零的数组。现在让我们向数组中的所有元素添加 2:
array += 2
接口与内存中的 NumPy 数组完全相同。但您会注意到这个操作将花费几秒钟。时间增加是因为整个大文件正在被改变;这不是一个快速的内存操作。
现在让我们打开文件并打印最后一个值:
array = np.memmap("data.np", mode="r",
dtype=np.int8)
print(array.shape)
print(array[:-10])
输出是:
(10737418240,)
[2 2 2 ... 2 2 2]
这里的重要点是数组的形状并没有与其保存,所以如果你不指定形状进行映射,你会得到一个线性数组。因此,你需要确保你恢复所需的形状。然后我们打印该数组的最后 10 个元素,我们得到十个 2。
NumPy 写时复制
NumPy 内存映射允许你使用一种称为 写时复制 的技术。这允许你将多个磁盘数组副本加载到内存中,并在内存使用方面支付显著较低的价格。这种技术在许多情况下都容易出 bug,主要是因为 Python 不是处理共享数据结构最好的语言,而且当更改底层文件时,内存映射语义变得不明确。我认为除非你 确定 你只会执行读操作,否则这种技术的优势不足以证明其风险。如果你想研究这种技术,我推荐 Itamar Turner-Trauring 的优秀文章,可在 pythonspeed.com/articles/reduce-memory-array-copies/ 找到。
我通常会避免使用执行并发写入和共享的显式内存映射技术,除非你绝对确定每个进程只进行读取。此外,如果你是低级库的开发者,你可能会使用带有写入的内存映射,但你可能不会使用 Python 来实现最有效的部分,所以这个问题将在其他语言中得到解决。
记住,即使你没有直接使用内存映射,你使用的许多框架也会隐式地这样做,所以理解它是很有用的。现在让我们讨论另一种处理大文件的技术:分块。
8.3.2 数据帧的块读取和写入
分块,正如其名所示,意味着以,嗯,块的形式处理文件。你按部分读取(或写入)文件。如果你使用 Zarr(见第 8.4 节)或 Dask(见第十章),你肯定会处理分块。
在这里,我们将回到我们信任的出租车示例。我们将以块的形式将文件从 CSV 转换为 Parquet,尽管文件足够小,在大多数计算机上我们可以在内存中完成这个操作,但让我们假设我们在一个内存受限的机器上,并且无法在内存中加载整个文件。
我们将使用 pandas 读取 CSV 文件,并使用 Arrow 写入 Parquet 版本。我们可以用 Arrow 完成所有操作,这将更高效,但我们想展示 pandas 分块接口:
import pandas as pd
table_chunks = pd.read_csv(
"../../07-pandas/sec1-intro/yellow_tripdata_2020-01.csv.gz",
chunksize=1000000
)
print(type(table_chunks)) ①
for chunk in table_chunks: ②
print(chunk.shape)
① 类型将是 pandas.io.parsers.TextFileReader。
② 每个块将是一个数据帧。
我们只需要将 chunksize 参数添加到 read_csv 中。你将不会从 read_csv 获得一个数据帧,而是一个块生成器。然后每个块将是一个数据帧,最大行数为 100 万行。
我们现在将进行适当的转换。首先,我们需要重新打开文件。我们已经遍历了所有块一次,所以我们需要回到开始:
table_chunks = pd.read_csv(
"../../07-pandas/sec1-intro/yellow_tripdata_2020-01.csv.gz",
chunksize=1000000,
dtype={
"VendorID": float,
"passenger_count": float,
"RatecodeID": float,
"PULocationID": float,
"DOLocationID": float,
"payment_type": float,
}
)
我们还需要指定一些列的数据类型;某些列的类型可能会从一块数据变化到另一块数据。这种情况在整数列中尤为常见,尤其是那些包含空值的列。当存在空值时,类型会被提升为float,因为在 pandas 中无法用整数来表示空值。
现在,我们将遍历块并创建 Parquet 文件:
first = True
writer = None
for chunk in table_chunks:
chunk_table = pa.Table.from_pandas(chunk) ①
schema = chunk_table.schema
if first:
first = False
writer = pq.ParquetWriter(
"output.parquet", schema=schema) ②
writer.write_table(chunk_table)
writer.close()
① 我们将 pandas 框架转换为 Arrow 表。
② 我们创建一个写入对象。在初始化时我们需要指定模式。
ParquetWriter接口允许我们在同一文件中写入多个表。每个表将写入一个单独的 Parquet 行组。在某种程度上,它将是一个块。
我们可以通过几种方式读取 Parquet 数据:
pf = pq.ParquetFile("output.parquet")
print(pf.metadata)
for groupi in range(pf.num_row_groups): ①
group = pf.read_row_group(groupi)
print(type(group), len(group))
break
table = pf.read()
table = pq.read_table("output.parquet")
① 我们可以单独读取每一行组。
Parquet 文件的元数据将指示存在七个行组。Parquet 允许我们按行组读取。如果您有足够的内存,有两个接口可以在ParquetFile的read方法或parquet模块的read_table方法中使用,它们负责读取所有行组并在内存中创建一个表。
借助分块的概念,这使我们能够分部分加载数据并进行处理,我们现在将来看看 Zarr。Zarr 是一个库,它允许我们操作非常大的同构 N 维数组(即 NumPy 对象)。
8.4 Zarr 用于大数组持久化
现实中存在的一些最大的数据集并不是异构的表格数据帧,而是多维同构数组。因此,高效地存储这些较大的数组非常重要。
Zarr 允许我们使用不同的后端和不同的编码格式高效地存储同构的多维数组。像并发写入这样的功能可以非常有效地生成数据。
存在着一些非常成熟的表示数组数据的标准(例如 NetCDF 和 HDF5),但在这个案例中,我们将使用新兴的格式 Zarr。Zarr 在优化方面比其他任何格式都要好,对于高效处理非常有用。例如,它允许并发写入和不同的文件结构组织,这两者都可以对性能产生巨大影响。并发写入允许许多并行进程同时在同一结构上工作。不同的文件结构使我们能够利用文件系统的性能特性。
虽然 Zarr 是一种文件格式,但它始于 Python 空间,由一个名为 Zarr 的库实现。因此,您可以确信 Python 版本实现了该格式的所有主要功能。如果您计划使用其他编程语言的 Zarr 文件,您应该首先检查这些语言的库是否支持这些功能。在某种程度上,Zarr 与 Parquet 相反:Parquet 是从 Java 生态系统迁移到 Python 的,因此 Python 对 Parquet 的支持仍然不全面。对于 Zarr,Python 实现是金标准。
Zarr 起源于生物信息学领域,我们将使用一个生物信息学示例。我们将使用来自一个名为 HapMap 的旧基因组项目的数据(www.genome.gov/10001688/international-hapmap-project)。该项目为人类群体中的许多个体的基因组变异(DNA 字母的变化)提供了信息。你不需要了解这个练习中的任何科学细节。我们将随着我们的进展介绍所需的最小概念。
对于我们的示例,我们将从一个预先准备好的 Zarr 数据库开始,该数据库是我从 Plink 格式的 HapMap 数据生成的(www.cog-genomics.org/plink/2.0/)。你不需要担心原始格式,但如果你对它感兴趣并且为了完整性,你可以在08-persistence/sec4-zarr/hapmap 仓库中找到生成你应使用的 Zarr 数据库的代码。预先准备好的 Zarr 文件可以在tiago.org/db.zarr.tar.gz找到。它包括跨越几个人类群体的 210 个个体的遗传信息。
我们的一个目标将是生成另一个 Zarr 数据库,可以用于执行主成分分析(PCA)——在基因组学中常见的无监督机器学习技术——这将需要重新格式化我们从原始数据库中拥有的数据。我们不会在这里运行 PCA,但只是为该操作准备文件。
8.4.1 理解 Zarr 的内部结构
让我们先看看数据库中有什么。在我们遍历数据库的同时,我们将提醒自己涉及的必要基因组概念:
import zarr
genomes = zarr.open("db.zarr")
genomes.tree() ①
① 打印文件内容的树形结构
Zarr 是一个用于数组的树形容器,因此我们有一个目录结构,其中叶子节点是数组。我们文件的简化版本如下:
├── chromosome-1
│ ├── alleles (318558,) <U2
│ ├── calls (318558, 210) uint8
│ └── positions (318558,) int64
├── chromosome-10
│ ├──alleles (216535,) <U2
│ ├──calls (216535, 210) uint8
│ └── positions (216535,) int64
数据按染色体分割,每个染色体都有一个层次结构。每个染色体都有一个基因型位置列表(对于这些位置,我们得到 DNA 字母),这些位置在 positions 中。每个位置的可能等位基因(即 DNA 字母)在 alleles 数组中。主要矩阵在 calls 中,其中对于 210 个个体,我们有每个标记的等位基因。因此,由于第 1 个染色体中有 318,558 个标记,calls 矩阵将是 318,558 × 210。对于每个个体和标记,有两个调用,这些调用将用一个单独的数字编码。
我们的目标是创建一个所有调用拼接的矩阵,以提交给 PCA 实现。不要担心遗传学;从我们的角度来看,重要的是我们有一个二维的 calls 矩阵,其中的 0/1/2 值以 8 位无符号整数编码,以及两个一维数组,一个包含 64 位整数(positions),另一个包含最多两个字符的字符串(alleles)。
在我们深入性能相关的问题之前,让我们简要讨论如何遍历 Zarr 数据。我们可以这样遍历整个结构:
def traverse_hierarchy(group, location=""):
for name, array in group.arrays(): ①
print(f"{location}/{name} {array.shape} {array.dtype}")
for name, group in group.groups(): ②
my_root = f"{location}/{name}"
print(my_root + "/")
traverse_hierarchy(group, my_root)
traverse_hierarchy(genomes)
① 获取组内所有组
② 获取组内所有数组
当 Zarr 读取文件时,它返回一个 Group 对象。groups 方法将返回一个生成器,包含所有子组,因此我们可以依赖它来遍历 Zarr 存储库。
你也可以使用类似目录的简单命名法来访问内容,这取决于你的主观偏好。例如:
in_chr_2 = genomes["chromosome-2"]
pos_chr_2 = genomes["chromosome-2/positions"]
calls_chr_2 = genomes["chromosome-2/calls"]
alleles_chr_2 = genomes["chromosome-2/alleles"]
in_chr_2 将有一个名为 pos_chr_2 的 Group,calls_chr_2 和 alleles_chr_2 分别有相应的数组 chromosome-2/positions、chromosome-2/calls 和 chromosome-2/alleles,准备使用。
让我们从我们的数据结构中获取一些信息:
print(in_chr_2.info)
输出是:
Name : /chromosome-2
Type : zarr.hierarchy.Group
Read-only : False
Store type : zarr.storage.DirectoryStore
No. members : 3
No. arrays : 3
No. groups : 0
Arrays : alleles, calls, positions
我们有一个 Group,包含三个成员,它们恰好都是数组;名称内部也可能有子组。
Zarr 支持许多类型的存储:在我们的例子中,我们使用 zarr.storage .DirectoryStore,但你也可以找到内存、zip 文件、DBM 文件、SQL、fsspec、Mongo 等类的实现。
正如我们很快就会看到的,DirectoryStore在支持高级并行功能方面非常有帮助,但就目前而言,让我们看看它使用的目录结构。如果你还没有注意到,db.zarr不是一个文件,而是一个目录。以下代码片段是目录结构的简化版本:
.
├── chromosome-1
│ ├── alleles
│ └── calls
│ └── positions
├── chromosome-10
│ ├── alleles
│ └── calls
│ └── positions
...
目录结构模仿了 Zarr 组结构,这使得开发变得容易。
8.4.2 Zarr 中数组的存储
现在我们来讨论数组是如何存储的,这是一个更加复杂和有趣的主题:
print(pos_chr_2.info)
现在我们将重新排序输出并将其分成几个部分。让我们从一些基本信息开始:
Type : zarr.core.Array
Data type : int64
Shape : (333056,)
Order : C
Read-only : False
Store type : zarr.storage.DirectoryStore
你应该能够用我们在前几章学到的知识来解释这条信息:这个对象是一个 zarr.core.Array,数据类型是 64 位整数,包含 333,056 个数字,数组是 C 顺序的,并且可以写入。
现在我们来看看块形状:
Chunk shape : (41632,)
Chunks initialized : 8/8
记住,分块是将一个大数组分割成更小的相等部分(块)的一种方式,这样就可以更容易地操作(图 8.2)。

图 8.2 一个大数组文件可以被分割成等大小的块以分别处理。
Zarr 告诉我们每个块的大小是 41,632 个元素;因此,我们最终得到八个块来容纳 333,056 个元素。当我们创建支持脚本来创建预准备版本中的数组时,我们有点天真,没有指定块大小,因此 Zarr 尝试猜测一个合理的值。块大小可以在创建时指定。我们将在后面的部分中看到原因。
注意,所有块都已初始化:然而,在某些情况下,并非所有块都需要初始化(例如,空数组)。未初始化的块可以潜在地节省大量磁盘空间。我们将在创建数组时看到这一点。
如果你进入db.zarr/chromosome-2/positions目录,你会找到八个文件,分别命名为0到7;这是一个块对应的文件。这种分离将使得在 Zarr 中实现并发写入——这是一个在许多数组存储系统中找不到的复杂特性——变得更加容易。
最后,Zarr 数组可以被压缩,从而节省大量磁盘空间和潜在的处理时间,正如本章前面所讨论的。以下是描述这一部分的输出:
Compressor : Blosc(cname='lz4', clevel=5,
shuffle=SHUFFLE, blocksize=0)
No. bytes : 2664448 (2.5M)
No. bytes stored : 687723 (671.6K)
Storage ratio : 3.9
在我们的案例中,数据使用 Blosc 和 LZ4 算法存储。原始大小是 2,664,448 字节——333,056 个元素乘以 8 字节用于 64 位整数,最终存储为 687,723 字节,因此压缩了 3.9 倍。鉴于数组是同质的,我们应该期望,平均而言,压缩将优于异构数据帧的整体压缩。当然,这个期望是针对平均情况;例如,随机数组非常难以压缩。
对于calls数组,我们有类似的输出,但适应了二维。以下是print(calls_chr_2.info)的简化版本:
Shape : (333056, 210)
Chunk shape : (41632, 27)
Chunks initialized : 64/64
在这种情况下,我们有一个 333,056 × 210 维度的矩阵和二维块。
提示:你可以将 N 维数组分成少于 N 维的维度。例如,我们的二维数组可以只在一个维度上分块。如果你需要同时处理一个维度上的所有信息,这个选择可能是有意义的。与所有分块决策一样,这取决于你的用例。
每个维度被分成八个区间,总共 64 个块。如果你列出db.zarr/chromosome-2/calls的内容,你会找到 64 个方便命名的文件 X.Y,其中 X 和 Y 从 0 到 7 变化,这指的是每个维度上的块编号。
最后,我们有一个包含所有等位基因的数组,它是一个由两个字符组成的字符串(例如,AT、CG、TC 等)。print(alleles_chr_2.info)的简化输出如下:
Data type : <U2
这个输出是一个固定两字节大小的 Unicode 字符串。记得在第二章中提到,Python 字符串表示是复杂的——或者根据观点的不同,可能是繁琐的——评估 Python 字符串的字节数远非易事。
为了高效访问,如果我们有固定大小的字符串和可预测大小的表示,那就很有帮助。Zarr 提供了两种内置的字符串表示:如果你只有 ASCII 字符,你可以使用字节数组;如果你有超过 ASCII 字符的情况,Zarr 提供了一个固定大小的 Unicode 表示,这与可变大小的 Python 字符串实现相反。如果你需要可变长度的字符串和不同的编码,Zarr 也提供了相应的编码器,但要注意这种灵活性的性能影响;如果可能且在存储方面合理,分配一个固定长度的长字符串。
现在我们已经了解了 Zarr 数据的组织方式,让我们创建一个数组,它是所有染色体上所有位置的连接。我们需要一个来自所有染色体的单个矩阵,因为 PCA 需要一个单独的矩阵作为输入。
8.4.3 创建新数组
我们现在将创建一个新的数组,它可以用于像 PCA 这样的无监督学习算法。这仅仅是所有调用数组(即所有染色体的调用)的连接。
在我们开始之前,我们必须知道我们需要分配的数组的大小。为了了解这一点,我们遍历现有的 Zarr 文件以提取每个染色体上的标记数量,这取决于你的具体问题:
import zarr
genomes = zarr.open("db.zarr")
chrom_sizes = []
for chrom in range(1, 23):
chrom_pos_array = genomes[f"chromosome-{chrom}/positions"]
chrom_sizes.append(chrom_pos_array.shape[0])
total_size = sum(chrom_sizes)
这段代码只是检查所有一维数组的第一维以确定位置。有了这些信息,我们可以计算包含所有数据的 Zarr 数组的大小。
在掌握了总大小之后,我们可以分配数组:
CHUNK_SIZE = 20000
all_calls = zarr.open(
"all_calls.zarr", "w",
shape=(total_size, 210), ①
dtype=np.uint8, # type change
chunks=(CHUNK_SIZE,))
① 210 是我们数据集中个体的数量。
在性能方面最重要的参数是块大小。我们选择了一个值,使得每个块的大小超过 1 MB,尽管你可能需要根据你的具体情况调整块大小。20,000 乘以 210 的总数大约是 4 MB,但我们预计会有一些压缩。我们假设所有个体将一次性读取,所以我们只在单个维度上进行分块。你可以自由地调整块大小,你会看到明显的性能差异。
决定块大小的通用想法
对于块大小,很难制定一般性的规则。你需要查看你的算法和用例。话虽如此,这里有一些基本的规则:
-
你不希望块太小;它们通常至少应该是 1 MB 或更大。
-
你的块应该能够轻松地适应内存。
-
在不同维度上尝试不同的值。这些值可能会对性能和你在内存中完成所有工作的能力产生重要影响。
-
块大小和存储类型不是正交的。例如,
DirectoryStore由于同一目录中文件过多导致的文件系统性能问题,在处理成千上万的块时扩展性不好。在这种情况下,Zarr 提供了NestedDirectoryStore来将块分散到子目录中。但重要的是,如果你理解不同存储的限制,并根据这些限制来参数化分块,那就更好了。
让我们获取all_calls的信息。简而言之是:
Type : zarr.core.Array
Data type : uint8
Shape : (3976554, 210)
Chunk shape : (20000, 210)
No. bytes : 835076340 (796.4M)
No. bytes stored : 345
Storage ratio : 2420511.1
Chunks initialized : 0/199
需要注意的最重要的问题是存储的字节数,以及与之相关的初始化的块数量。虽然预期的总大小是 796.4 MB,但实际使用中只有 345 字节(!)因为还没有保存数据(即没有初始化任何块)。默认情况下,Zarr 假设如果未初始化,数组中的所有值都是 0。如果在当前阶段列出all_calls.zarr目录,你会发现它是空的,并且它根本不占用任何空间。
实际上,一个名为.zarray的隐藏文件包含一些元数据。如果您打开该文件,您将找到一个包含我们传递给 Zarr 数组创建的参数的 JSON 版本以及其他默认值的文件。
8.4.4 Zarr 数组的并行读写
现在让我们创建一个单一的连接数组。我们需要一个包含所有 PCA 分析数据的单个数组。
我们将讨论两个版本:第一个是顺序版本,第二个是并行版本。以下是第一个:
def do_serial():
curr_pos = 0
for chrom in range(1, 23):
chrom_calls_array = genomes[f"chromosome-{chrom}/calls"]
my_size = chrom_calls_array.shape[0]
all_calls[curr_pos: curr_pos + my_size, :] = chrom_calls_array
curr_pos += my_size
do_serial()
print(all_calls.info)
此代码只是将所有染色体调用按顺序复制到all_calls数组中。请注意,所有存储管理都是在典型的 NumPy 接口之上完全抽象的。
在您运行代码后,如果您打印all_calls的信息,几乎没有变化:
No. bytes : 835076340 (796.4M)
No. bytes stored : 297035153 (283.3M)
Storage ratio : 2.8
Chunks initialized : 199/199
现在所有块都已初始化,存储占用 283.3 MB——与总字节数 796.4 MB 相比,存储比为 2.8。如果您列出all_calls.zarr目录,您将找到 199 个文件:每个块一个。
之前的代码运行需要几秒钟。虽然我不会要求您运行一个包含数以 TB 计的数据的示例,这将花费数小时,但很容易看出,对于更多数据,进行此类转换所需的时间可能会变得过长。
因此,作为第二个版本,我们将创建一个并行版本,该版本将从染色体数组中读取并将写入all_calls数组。读取和写入都将并行进行。
并非许多库支持并行写入,但 Zarr 支持。通过将每个块放入单独的文件中,目录存储使得 Zarr 实现并行写入变得容易。在这种情况下,基于文件系统性能属性的一个简单设计为一个非常重要的特性打开了可能性。
理论上,您可以使用您喜欢的任何大小进行写入,但按块进行写入将是最有效的,因为 Zarr 将不必处理同一文件上的并发写入。基本点是您的块大小应与您的用例相匹配,如果可能的话,您应该尝试按块处理数据。
在我们的情况下,我们不能简单地按染色体进行;我们必须按块写入。以下是写入块的一般函数:
def process_chunk(genomes, all_calls, chrom_sizes, chunk_size, my_chunk):
all_start = my_chunk * chunk_size ①
remaining = all_start
chrom = 0
chrom_start = 0
for chrom_size in chrom_sizes: ②
chrom += 1
remaining -= chrom_size
if remaining <= 0:
chrom_start = chrom_size + remaining
remaining = -remaining
break
while remaining > 0: ③
write_from_chrom = min(remaining, CHUNK_SIZE)
remaining -= write_from_chrom
chrom_calls = genomes[f"chromosome-{chrom}/calls"]
all_calls[all_start:all_start + write_from_chrom, :] = chrom_calls[
chrom_start: chrom_start + write_from_chrom, :]
all_start = all_start + write_from_chrom
① 首次写入位置是块号乘以块大小。
② 我们遍历所有染色体大小,直到找到开始的位置。
③ 一个块可能需要多个染色体。
如果您不完全理解之前的代码,请不要担心:该代码特定于领域。重要的是一般方法。我们正在尝试以块为基础的方式工作,这对于不适合内存的大文件是合适的。
我们现在可以使用一个简单的多进程池和一个映射调用来处理每个块:
from functools import partial
from multiprocessing import Pool
partial_process_chunk = partial(
process_chunk, genomes,
all_calls, chrom_sizes, CHUNK_SIZE)
def do_parallel():
with Pool() as p:
p.map(partial_process_chunk, range(all_calls.nchunks))
do_parallel()
我们通过定义partial_process_chunk进行部分函数应用,以便Pool.map调用更容易。然后我们使用多进程池来处理我们的映射;更多细节,请参阅第三章。
摘要
-
fsspec 作为文件存储的统一接口,允许使用相同的 API 在许多不同的后端之间进行操作。
-
由于与 fsspec 有统一的 API,替换后端的工作变得大大简化。
-
虽然 fsspec 与性能没有直接关系,但几个高级库都使用了它,包括 Arrow 和 Zarr。
-
Parquet 是一种列式数据格式,它允许更有效地存储数据:数据被类型化,可能被压缩,并且按列组织。
-
Parquet 使用复杂的数据编码策略,如字典或运行长度编码,允许非常紧凑的表示,特别是对于具有明显模式和重复的数据。此外,该格式是可扩展的,未来可能会有更多的性能提升。
-
Parquet 允许数据分区,这为程序员提供了并行处理数据的能力。
-
处理大于内存的文件最常见的技术是分块。pandas、Parquet 和 Zarr 等许多库都支持分块。
-
Zarr 是一个用于处理同构多维数组的现代库。它起源于 Python 世界,并提供了基于 NumPy 的接口。
-
Zarr 默认支持并行性。支持并发写入过程是一个值得注意的特点,因为在其他库中这是不常见的功能。
第四部分. 高级主题
第四部分涵盖了高级主题。我们首先讨论了使用图形处理单元(GPUs)处理大数据的优势。结果证明,GPU 的计算模型非常适合处理大型数据集,尤其是 N 维数组。我们通过介绍 Dask 来结束本节:这是一个基于 Python 的框架,可以在多台计算机上执行并行处理,使我们能够在需要使用复杂算法处理大量数据时扩展到多台机器。
9 使用 GPU 计算进行数据分析
本章涵盖
-
使用 GPU 架构改进许多数据分析算法
-
使用 Numba 将 Python 代码转换为高效的 GPU 低级代码
-
编写高度并行的 GPU 代码以处理矩阵
-
使用来自 Python 的 GPU 原生数据分析库
图形处理单元(GPUs)最初是为了使图形应用程序更高效而设计的:绘图和动画软件、计算机辅助设计和当然,游戏!
在某个时候,变得很明显,GPU 不仅能够进行图形处理,还可以用于各种计算,因此出现了通用计算在图形处理单元(GPGPUs)上的应用。GPU 吸引人的地方在于它们比 CPU 具有更多的计算能力。它们已经在许多应用中取得了成功,例如科学计算和人工智能。它们在数据科学和使计算更有效方面有巨大的应用。
GPGPUs 使用 GPU 硬件强加的架构和编程范式,并考虑了两个关键因素。首先,它们需要做大量的计算,因为图形非常数据密集。其次,它们需要同时处理许多相似的数据点,因为图形多处理器中的每个像素都是在同一时间计算的。这些需求对 GPU 设计有很大影响。例如,GPU 有许多、许多处理单元,通常有数千个,它们同时在执行大多数相似的任务。相比之下,典型的 CPU 只有几个处理单元,每个处理单元在相同的时间点执行不同的事情。GPU 的处理速度来自处理单元的数量。实际上,每个单独的核心并不非常快,至少与 CPU 核心相比。因此,GPU 是高度并行的。
这些关键的硬件差异意味着为 GPU 编码与为 CPU 编码非常不同。这不仅仅是一个重新编译现有代码的问题。为 GPU 编码,至少当我们明确关注它时,意味着我们在程序员思维模式上发生了巨大的范式转变。
GPU 计算在数据分析中可能具有优势,但许多人因为需要不同的思维方式而放弃学习为 GPU 编码。因此,本章重点介绍高性能 GPU 编码中最重要的一步:转向那种新的思维方式。与本书中的其他章节不同,我将将其视为一种方法和思维方式介绍的更多。因此,我将在一定程度上简化材料并跳过细节。我们不会讨论一些重要的话题,如线程同步,并且我们将假设可以轻易并行化的问题,这样我们就可以专注于理解编程范式。数据科学中的许多问题实际上很容易并行化,因此这里的内容对我们的领域非常适用。
小贴士:您可以从其他来源获取更多关于 GPU 计算的深入信息。我推荐 NVIDIA 的 CUDA C++ 编程指南 (mng.bz/61Bp)。虽然它针对 C 和 C++,但前四章将为您提供对 GPU 架构和编程概念的多数语言无关的视角。Bob Robey 和 Yuliana Zamora 的《并行与高性能计算》(Manning,2021)的第三部分专门介绍 GPU;您可以通过以下链接免费阅读第九章,“GPU 架构和概念”:mng.bz/oJ6y。
我们将从查看 GPU 架构及其对算法和软件开发的影响开始。我将假设您没有先前的知识,并展示 GPU 的工作原理。
由于 Python 代码不能直接在 GPU 上运行,我们将使用 Numba,它是一种将 Python 转换为机器代码的翻译器,它适用于 GPU 和 CPU。Numba 在运行时将您的 Python 代码编译为与您的 CPU 或 GPU 兼容的更低级表示。您可以在附录 B 中找到 Numba 的介绍。我们的示例将明确在 GPU 上部署 Python 代码。
在完成困难的部分,即理解 GPU 编程模型的基础之后,我们将利用高级数据分析库。虽然您可以直接编程 GPU,但您也可以通过库使用 GPU,这些库为您处理了大多数实现细节。在这里,您将隐式地使用 GPU,因为外部库将在 GPU 上部署计算。例如,我们将用 CuPy 替换 NumPy。正如我们将看到的,尽管库消除了在 GPU 上运行代码的大部分负担,但这并不仅仅是替换库的问题。让我们从理解编码范式和性能所需的必要变化的角度开始理解 GPU 架构。
注意:本章需要访问 GPU——即最新的 NVIDIA GPU(即 Pascal 架构或更新的架构)。因此,本章依赖于供应商。虽然我更愿意进行供应商无关的内容,但现实是 GPU 计算主要在 NVIDIA GPU 上使用 CUDA 架构进行。这在 Python 世界中尤为重要,例如 CuPy 或 cuDF 这样的库。如果您想研究供应商无关的 GPGPU 计算方法,请查看 OpenCL (www.khronos.org/opencl/) 或 Vulkan (www.vulkan.org/)).
您必须确保您的安装具有进行 GPGPU 计算所需的全部 NVIDIA 驱动程序。您需要安装 CUDA 工具包以及 CuPy,以便在本章中运行软件。这可以通过 conda 完成,命令为 conda install -c rapidsai -c nvidia -c numba -c conda-forge cupy cudatoolkit。有一个名为 tiagoantao/python-performance-gpu 的用于 GPU 处理的 Docker 镜像。
9.1 理解 GPU 计算能力
对于某些类别的算法,GPU 可以比 CPU 表现得更好几个数量级。在本节中,我们将探讨 GPU 架构,目的是了解何时以及为什么 GPU 在数据分析问题中可以更有效率。
9.1.1 理解 GPU 的优势
为了理解为什么 GPU 如此高效,我们将通过一个实际例子来展示一个简化的 CPU 和 GPU 执行的概念模型。这个简单例子的目的是让你了解为什么 GPU 在许多,但只有某些类别的并行问题中表现如此出色。
考虑这样一个简单的问题:获取一个包含 100 个元素的数组,并返回其加倍版本:
import numpy as np
a = np.ones(100)
b = np.empty(100)
for i in range(100):
b[i] = 2 * a[i]
如果你有一个简单的单线程单核 CPU,那么前面代码的低级实现,例如伪代码汇编器,可能是这样的:
TMPVAR = A[0] ①
TMPVAR = 2 * TMPVAR ②
B[0] = TMPVAR ③
TMPVAR = A[1]
TMPVAR = 2 * TMPVAR
B[1] = TMPVAR
...
TMPVAR = A[99]
TMPVAR = 2 * TMPVAR
B[99] = TMPVAR
① 将数组 A 的第一个元素取出并放入一个名为 TMPVAR 的寄存器中
② 将寄存器中的值加倍
③ 将寄存器的值放入数组 B 的第一个位置
这个伪代码将数组A的第一个元素放入寄存器中,将其值加倍,并将其放在数组B的第一个元素上。这会重复进行,直到我们的数组中的所有 100 个元素。
现在,记得从第六章中提到的,从主存中检索值是一个极其昂贵的操作。这里假设我们的简单 CPU 没有缓存。我们的读取和写入操作TMPVAR = A[0]和B[0] = TMPVAR每个都需要 90 个时间单位,我们的加倍操作TMPVAR = 2 * TMPVAR需要 2 个时间单位。我们有 100 次读取,100 次写入和 100 次加倍:10090 + 10090 + 100*2。这总共是 18,200 个时间单位。记住,我们的简单 CPU 是顺序的,所以一个操作只能在之前的操作完成后才能开始。
现在想象一个完全不同的执行模型,其中你有 100 个线程并行运行,并且每个线程执行一次内存读取,然后是一次加倍操作,最后是一次写入。假设读取和写入的成本相同:90 个时间单位。但是加倍的成本要高得多;我们有大量的计算单元,所以它们按单位来说较慢——比如说,40 个时间单位。
所以,所有线程同时发出内存读取请求。100 个时间单位后,所有线程收到数据,然后花费 40 个时间单位进行计算。记住,所有线程都是同时并行操作的,并且彼此独立。最后,它们以 100 个时间单位的成本并行写入内存。总成本是 100 + 40 + 100 = 240 个时间单位。
因此,我们的“CPU”需要 18,200 个时间单位,而我们的“GPU”只需要 240 个时间单位:这使得“GPU”大约快 75 倍。然而,如果你只想对一个单个值进行操作,那么“CPU”将需要 202 个时间周期,而“GPU”将需要 240 个时间周期。
这个例子应该能让你对 GPU 计算的优势和劣势有所了解。本质上,GPU 非常适合处理内存延迟以及在大量数据上执行类似操作,但它们在处理单个操作方面效率不高。用比喻来说,我们可以将 CPU 比作法拉利,将 GPU 比作公交车。如果你只需要运送 5 个人,法拉利会打败公交车。但如果你需要运送 500 人,这甚至不是一场公平的竞争。
许多开发者在学习或使用 GPU 时面临的最大障碍之一是克服他们直观的感觉,即大多数代码应该是顺序的。虽然确实有很多代码是顺序的,但许多(如果不是大多数)计算中的昂贵部分是非常并行的。一个典型的例子是屏幕上的像素。一个分辨率为 1920 × 1080 的 HD 屏幕上有 200 万个像素。每个像素都是独立处理的,因此,至少在理论上,我们可以并行处理所有这些像素。或者考虑 N 维数组,这正是数据科学中使用的数据结构类型:数组的每个元素都可以单独计算,因此所有元素都可以潜在地并行计算。所以 GPU 非常适合许多有趣的问题。
为了更清楚地了解为什么 GPU 适合高度并行的问题,我们需要更仔细地研究这些机器的架构。我们将在下一节中这样做。
9.1.2 CPU 和 GPU 之间的关系
GPU 的计算模型与 CPU 非常不同。为了有效地编程 GPU——实际上,为了编程 GPU——我们需要了解其底层架构以及它与我们所习惯的不同之处。
从所有意义上讲,GPU 都是一个协处理器。CPU 是主处理器,为其编写的代码控制计算的最高层。围绕 GPU 计算术语的命名清楚地说明了这种关系:主机指的是 CPU,而设备是 GPU。主机中的代码驱动整体计算过程。
从性能角度来看,对于绝大多数 CPU 和 GPU 架构,CPU 和 GPU 有不同的内存银行,它们彼此分离。因此,我们有主机内存(即主机可用的内存)和设备内存(即 GPU 可用的内存)。

图 9.1 CPU 和 GPU 有独立的内存空间。我们需要将我们的数据传输到与 GPU 相关的内存中进行计算,然后将结果传输回与 CPU 相关的内存,这可能会花费大量时间。
将数据传输到 GPU 内存以及从 GPU 内存中传输数据的成本可以对性能产生巨大影响,尤其是如果我们 GPU 上的计算量有限时。图 9.1 描绘了这种关系。
GPU 供应商和软件可移植性
在我们讨论 GPU 的内部架构之前,我想就 GPU 供应商和软件可移植性提出一个观点。对于通用计算,有两个主要的 GPU 供应商:NVIDIA 和 AMD。从理论上讲,存在无供应商特定的接口,允许我们以非供应商依赖的方式编程。如果您对无供应商特定感兴趣,您可以查看相关的软件解决方案,例如 OpenCL 或 Vulkan 计算 API。
实际情况是,NVIDIA 几乎完全主导了通用计算市场。这可以从 GPU 的底层编程中看出,NVIDIA 的 Compute Unified Device Architecture (CUDA)占主导地位,以及在 Python 级别,许多支持 GPU 的数据分析库都是基于 CUDA 的,例如 CuPy、CuDF、cuML 和 BlazingSQL。
在本章中,我们将仅基于 NVIDIA/CUDA 的 API。然而,从概念上讲,这些信息也可以转移到 AMD 空间以及无供应商特定的库中。
供应商依赖的问题延伸到了术语学:无供应商特定的术语可能与 NVIDIA 的术语不同。我将尽可能多地展示无供应商特定的术语,但当我遇到已经成为通用术语的 NVIDIA 等效术语时,我会给出相应的 NVIDIA 术语。
命名的问题进一步复杂化,因为除了供应商特定的术语外,还有一些基于 GPU 图形起源的词汇。例如,CUDA 核心也是一个流处理器——不要与流式多处理器混淆——以及着色器。
9.1.3 GPU 的内部架构
GPU 有几个流式多处理器(SMs)。数量可以从 1 到 30 甚至更多。每个 SM 由许多流处理器(SPs)组成,有时也称为 CUDA 核心。每个 SM 有许多 SPs。参见图 9.2 以了解主要 GPU 组件的简化概述。
例如,NVIDIA RTX 2070 基于 NVIDIA Turing 106 GPU,该 GPU 有 36 个 SM(Streaming Multiprocessors,流式多处理器),每个 SM 有 64 个 CUDA 核心,总共 2304 个 CUDA 核心;这个 GPU 就像能够同时运行 2,304 个线程。正如我之前所说的,我们在这里选择简化,因为架构实际上要复杂得多。对于数据科学来说,特别重要的是可以用于 AI 计算的新的张量核心;我们不会在本章中涵盖这些内容,但您可能想要研究它们以用于更高级的使用。
内存组织也很重要:每个 SM 都有一定量的 L1 缓存(检查第六章以了解缓存概念),这些缓存可以由同一 SM 中的所有 SP(Streaming Processors,流处理器)共享。我们不会直接使用 L1 缓存,它可以用于在同一个 SM 上运行的线程之间共享状态。还有 L2 缓存,它由所有 SM 共享,最后是 GPU 主内存。例如,TU 106 GPU 每个 SM 有 64 KB 的 L1 缓存和 4 MB 的 L2 缓存,RTX 2070 配备了 8 GB 的主内存。

图 9.2 主要 GPU 组件的简化概述:包含流处理器的流多处理器(CUDA 核心)和本地缓存。GPU 包括所有流多处理器以及一些额外的缓存和 GPU 主内存。
这将如何影响我们的编码和性能设计?你需要明确了解架构才能在上面运行代码,正如我们接下来将要看到的。
9.1.4 软件架构考虑因素
现在我们来看看硬件架构如何影响 GPU 代码设计。我们将查看运行我们之前的简单示例(将矩阵乘以 2)所需的步骤。在本章的后面,我们将实际编写这段代码,但现在让我们回顾一下高级步骤。
我们在 CPU 附近的内存中有我们的矩阵,所以我们需要做的第一件事是将它传输到 GPU 内存。这个操作可能相当昂贵,尤其是如果我们没有在 GPU 上做很多计算的话。
想象一个 1024 × 64 的矩阵(即 65,536 个元素)。在 GPU 中,每个元素将作为一个单独的线程来计算。因此,我们将有 65,536 个线程。每个线程将运行 相同的代码。线程需要被分成线程块;一个线程块中的所有线程都放置在同一个 SM 中,并且可以共享内存和同步原语。在我们的例子中,由于算法非常简单,不需要在不同线程之间共享任何内容,但我们仍然需要将我们的代码分成线程块。
例如,如果我们假设每个块有 32 个线程,那么我们需要 2048 个块。每个块可以在不同的 SM 上执行。
我们如何调用代码?记住 CPU 驱动一切,所以 CPU 将调用 GPU 上的一个入口点来驱动所有计算。这个入口点的名称是 内核函数。我们现在对编写 GPU 代码的基本原理有一个非常基本的了解:存在一个入口点——内核函数——以及相同的代码在许多线程上运行。
我们将部署低级代码到 GPU 上。因为 Cython 没有等效功能将 Python 转换为 OpenCL C(或 CUDA C),所以我们将使用 Numba。如果你从未使用过 Numba,请参阅附录 B,其中介绍了这项技术。
9.2 使用 Numba 生成 GPU 代码
在一些基本准备之后,我们最终将使用 Numba 编写我们的第一个 GPU 程序。为了理解 GPU 编码的基本问题,我们将从最简单的例子开始:将数组值加倍。之后,我们将实现一个 Mandelbrot 生成器,你可以在附录 B 中将其与 CPU 版本进行比较。再次提醒,如果你从未使用过 Numba,可以考虑先查看那个附录,其中介绍了在 CPU 上使用 Numba。
9.2.1 Python GPU 软件安装
在我们开始运行 GPU 代码之前,我们需要确保所有 GPU 的驱动程序和所需的软件都已安装。安装软件并不总是简单的事情。在这里,不可能为不同的操作系统和架构提供一般性的说明,但可以给出一些指导方针。
你可能需要安装内核驱动程序,可能需要重新启动你的机器。你需要 CUDA 工具包,它有不同的版本;如果你使用 Anaconda,执行conda install cudatoolkit可能是最简单的方法。
Numba 有测试现有基础设施和报告现有硬件和库的能力。要检查,请在以下 shell 中运行:
numba -s
你将非常详细地了解你的系统。为了确定 GPU 是否可用,你需要查看硬件是否被检测到并且库是否可用。对于硬件,搜索类似以下内容:
__CUDA Information__
CUDA Device Initialized : True
CUDA Driver Version : 11020
CUDA Detect Output:
Found 1 CUDA devices
id 0 b'Tesla T4' [SUPPORTED]
compute capability: 7.5
pci device id: 30
pci bus id: 0
Summary:
1/1 devices are supported
这段代码将允许你检查设备是否被检测并支持。一些较旧的 GPU 可能不受支持。你还需要确保所有库都被找到。报告的另一部分将包括类似以下内容:
CUDA Libraries Test Output:
Finding cublas from Conda environment
named libcublas.so.11.2.0.252
trying to open library... ok
Finding cusparse from Conda environment
named libcusparse.so.11.1.1.245
trying to open library... ok
Finding cufft from Conda environment
named libcufft.so.10.2.1.245
trying to open library... ok
Finding curand from Conda environment
named libcurand.so.10.2.1.245
trying to open library... ok
Finding nvvm from Conda environment
named libnvvm.so.3.3.0
trying to open library... ok
Finding libdevice from Conda environment
searching for compute_20... ok
searching for compute_30... ok
searching for compute_35... ok
searching for compute_50... ok
这段代码将允许你发现所有有问题的库。现在让我们写一些 GPU 代码。
9.2.2 使用 Numba 的 GPU 编程基础
在我们开始编写代码之前,让我们通过查看我们不希望做的事情来调整我们的心态。记住,我们只是尝试简单地加倍一个数组,所以以下是一个具有 CPU 思维的潜在解决方案:
def double_not_this(my_array):
for position in range(my_array): ①
my_array[position] *= 2
① for 是一个顺序操作。
这段代码是对数组的顺序循环。但我们的 GPU 代码将使用每个元素一个线程,所以我们的代码应该只处理一个单个元素。稍后,我们将确保 GPU 将我们的代码应用于数组的所有元素。以下是第一个版本:
from numba import cuda
@cuda.jit ①
def double(my_array):
position = cuda.grid(1) ②
my_array[position] *= 2
① 编译函数到 CUDA
② cuda.grid 访问数组中要处理的当前位置。
这里没有输出可以显示,因为所有这些都在 GPU 上发生。
是的,一个函数调用只处理单个元素!这种方法与我们习惯的非常不同——在非矢量化方法之外。
我们使用cuda.jit装饰器来注释我们的函数,这样 Numba 就会生成我们代码的 CUDA 版本。然后我们使用“魔法”函数cuda.grid来获取我们将要更改的单个唯一位置;我们稍后会看到那里发生了什么。最后,我们根据位置更改数组中的单个条目。
我们的功能不能返回值,因为它将被实现为一个 GPU 内核函数,所以我们需要传递参数以适应返回值。如果我们尝试执行代码
my_array = np.ones(1000)
double(my_array)
我们将得到一个错误:
Kernel launch configuration was not specified. Use the syntax:
kernel_functionblockspergrid, threadsperblock
这个错误是因为我们必须还告诉 Numba 如何分配计算。正如在架构部分所暗示的,我们必须将计算分成线程块,每个块分成网格。这将有效:
import numpy as np
blocks_per_grid = 50
threads_per_block = 20
my_array = np.ones(1000)
doubleblocks_per_grid, threads_per_block ①
assert (my_array == 2).all() ②
① 函数调用的语法不是非常符合习惯。
② 我们想检查函数是否已应用于数组的所有元素。
注意,调用函数的语法并不十分地道。最后,我们使用 assert 来检查所有元素是否为 2。我们试图小心行事,因为如果我们提供的方块数量不正确,可能不是所有的数组都会被计算。
我们在每个方块中发出 20 个线程,因为我们有 1,000 个元素,我们需要 50 个方块。通常,32 个线程是常见的。考虑到 GPU 的内存层次结构,同一方块中的线程可以非常快速地共享一些状态。在这里,我们不会关注这些类型的算法,因为它们属于更高级的阶段。因此,我们可以相当灵活地在 GPU 上分配我们的代码。
话虽如此,有时不可能使 blocks * threads 等于我们数组中的元素数量(例如,在大小为素数的数组中)。在这种情况下,我们必须指定一个比我们的数组稍大的 blocks * threads。以下是一个例子:
threads_per_block = 16
blocks_per_grid = 63
my_array = np.ones(1000)
doubleblocks_per_grid, threads_per_block
assert (my_array == 2).all()
在这种情况下,我们将有总共 1,008 个线程(16*63)。如果你很幸运,这段代码会工作。也有可能它会崩溃!
你正在调用位置 0 到 1007 的代码,最后八个位置,即 1000 到 1007,尚未分配。现在,你必须停止使用 Python 的思维方式,并记住你的代码已被转换为低级语言。这种转换意味着所有标准的 Python 边界检查将不可用,而且你可能会得到“奖品”,即内存分配错误,或者更糟糕的是,一个静默的错误。我们稍后会看到一个这种错误的例子。
解决这个问题相当简单:
@cuda.jit
def double_safe(my_array):
position = cuda.grid(1)
if position > my_array.shape[0]:
return
my_array[position] *= 2
我们检查位置是否大于数组大小,如果是,则返回。现在我们可以自信地调用代码:
my_array = np.ones(1000)
double_safeblocks_per_grid, threads_per_block
assert (my_array == 2).all()
最后,在这个阶段,我们正确地在 GPU 上调用了代码!
现在,让我们回到 cuda.grid 的“魔法”,以获取要计算的位置。我们将通过自己编写代码来理解这个调用发生了什么。有时必须自己编写代码(例如,对于具有超过三个维度的数组):
@cuda.jit
def double_safe_explicit(my_array):
position = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
if position >= my_array.shape[0]:
return
my_array[position] *= 2
线程调用可以访问它正在运行的方块和线程。它还可以访问方块维度。cuda.blockIdx 给出当前线程正在运行的方块索引。cuda.blockDim 提供方块维度,而 cuda.threadIdx 给出组内的线程。有了这些信息,你可以确保每个线程都指向数组中的不同位置。
线程的位置信息可以是一维、二维和三维:你可能已经注意到了所有 CUDA 调用中的 .x 参数;如果你有二维和三维数组,可以使用 .y 和 .z 参数。
现在,让我们看看同一个函数,但针对二维数组:
@cuda.jit
def double_matrix_unsafe(my_matrix):
x, y = cuda.grid(2)
my_matrix[y, x] *= 2
我们现在使用 cuda.grid(2) 来获取两个索引。注意,我们回到了不安全的代码,因为我们现在可以相当肯定我们会触发一个错误。让我们运行这段代码:
threads_per_block_2d = 16, 16
blocks_per_grid_2d = 63, 63
my_matrix = np.ones((1000, 1000))
double_matrix_unsafeblocks_per_grid_2d, threads_per_block_2d
print((my_matrix == 2).all())
这将打印True,因为矩阵的所有元素现在都是 2。注意,我们定义的块和线程是二维的,就像我们的数据一样。
如果您运行这段代码并且足够幸运以至于它没有崩溃,它几乎肯定会返回错误的结果。这种结果发生是因为我们没有测试矩阵边界,当您超出一行时,您将落在矩阵的下一行上,这在单维数组中是不可能的。因此,可能存在值为 4 而不是 2 的位置,因为代码将在那里执行两次。
正确处理这个问题相当简单,正如我们之前看到的。以下是我们最终的版本,我们在其中也明确指出了矩阵索引:
@cuda.jit
def double_matrix(my_matrix):
x = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
y = cuda.blockIdx.y * cuda.blockDim.y + cuda.threadIdx.y
if x >= my_matrix.shape[0]:
return
if y >= my_matrix.shape[1]:
return
my_matrix[y, x] *= 2
现在我们已经涵盖了基础知识,让我们使用 GPU 重新创建我们的 Mandelbrot 示例。
9.2.3 重新审视使用 GPU 的 Mandelbrot 示例
我们已经几乎涵盖了所有的 Numba 概念,所以让我们创建一个使用 GPU 的 Mandelbrot 渲染器。为了将这些概念结合起来构建渲染器,我们将采取一条迂回的路线,遵循看似合理的步骤,但最终会遇到几个不工作的死胡同。这里的目的是说明并解释为什么这些步骤不工作,希望这种理解能帮助您避免仅仅跟随直觉的陷阱。
让我们先实现 Mandelbrot 函数来计算单个点的值:
from numba import cuda
@cuda.jit(device=True)
def compute_point(c):
i = -1
z = complex(0, 0)
while abs(z) < 2:
i += 1
if i == 255:
break
z = z**2 + c
return 255 - (255 * i)
注意device=True添加到cuda.jit装饰器中。我们正在告诉 Numba 这个函数需要在设备内部调用。与内核函数不同,设备函数可以返回值。
接下来,我们将实现一个看似合理但实际上不工作的第一个版本:
@cuda.jit
def compute_all_points_doesnt_work(start, end, size, img_array):
x, y = cuda.grid(2)
if x >= img_array.shape[0] or y >= img_array.shape[1]:
return
mandel_x = (end[0] - start[0])*(x/size) + start[0]
mandel_y = (end[1] - start[1])*(y/size) + start[1]
img_array[y, x] = compute_point(complex(mandel_x, mandel_y))
当这个函数编译时,如果您尝试调用它,您将得到以下结果:
NotImplementedError: (UniTuple(float64 x 2), (-1.5, -1.3))
这里的问题在于 Numba 无法处理元组作为输入参数(至少在目前这个时间点是这样)。但这指向了一个更大的问题,我们需要记住:一些 Python 功能不被 Numba 支持。因此,请务必检查 Numba 的文档(numba.pydata.org/)以确定哪些函数被支持。在这里讨论哪些具体功能不被支持是没有意义的,因为 Numba 一直在变化。完全有可能在我写这段话和您阅读这段话之间,Numba 已经支持了新的功能。
因此,为了使这个解决方案工作,我们必须创建一个不带元组作为输入参数的版本:
@cuda.jit
def compute_all_points(startx, starty, endx, endy, size, img_array):
x, y = cuda.grid(2)
if x >= img_array.shape[0] or y >= img_array.shape[1]:
return
mandel_x = (end[0] - startx)*(x/size) + startx
mandel_y = (end[1] - starty)*(y/size) + starty
img_array[y, x] = compute_point(complex(mandel_x, mandel_y))
在最后一行有一个值得注意的细节:记住,对于 NumPy 数组,y坐标在前,所以我们应该写img_array[y, x]。
现在让我们进行调用:
from math import ceil
import numpy as np
from PIL import Image
size = 2000
start = -1.5, -1.3
end = 0.5, 1.3
img_array = np.empty((size, size), dtype=np.uint8)
threads_per_block_2d = 16, 16
blocks_per_grid_2d = ceil(size / 16), ceil(size / 16)
compute_all_pointsblocks_per_grid_2d,
threads_per_block_2d
img = Image.fromarray(img_array, mode="P")
img.save("mandelbrot.png")
我希望您在这里注意的最重要的一点是块数的指定:鉴于我们在每个维度上每个块有 16 个线程,我们需要有size / 16个块。由于这个数字可能不是整数,我们必须向上取整以确保所有点都被覆盖。
我们可以计时这个操作:
In [3]: %timeit compute_all_points[blocks_per_grid_2d, ...
72.6 ms ± 50.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
这与附录 B 中展示的最佳 CPU 版本的 539 ms 相比。然而,说实话,这并不是一个公平的比较,因为它将一个较差的 CPU 与一个优秀的 GPU 进行了比较。此外,还有许多其他因素,如算法类型和 CPU 到 GPU 的内存传输,对速度有巨大影响。尽管如此,应该很清楚,对于某些算法,GPU 可以提供比 CPU 更高的性能。
现在我们已经有了基于 GPU 的 Mandelbrot 生成器,性能有了显著提升,让我们再创建另一个 Mandelbrot 生成器。这次,我们将使用 NumPy 向量化来实现,因为它在加速数据分析方面非常有用,正如我们之前所看到的。
9.2.4 Mandelbrot 代码的 NumPy 版本
我们最终的版本是一个在 GPU 上运行的 NumPy 通用函数。我们已经讨论了所有必要的部分,所以应该很容易将这些部分组合起来。以下是计算点及其向量化版本:
from cuda import vectorize
size = 2000
start = -1.5, -1.3
end = 0.5, 1.3
def compute_point_255_fn(c):
i = -1
z = complex(0, 0)
while abs(z) < 2:
i += 1
if i == 255: ①
break
z = z**2 + c
return 255 - (255 * i) // 255
compute_point_vectorized = vectorize(
["uint8(complex128)"], target="cuda")(compute_point_255_fn)
① 我们将使用一个更简单的点计算版本,其中交互限制是硬编码的。
这段代码的唯一小创新是在最后一行的vectorize调用中使用target="cuda"。
记住上一节中我们提到,我们需要准备一个数组,其中包含我们想要进行计算的位位置:
def prepare_pos_array(start, end, pos_array):
size = pos_array.shape[0]
startx, starty = start
endx, endy = end
for xp in range(size):
x = (endx - startx)*(xp/size) + startx
for yp in range(size):
y = (endy - starty)*(yp/size) + starty
pos_array[yp, xp] = complex(x, y)
pos_array = np.empty((size, size), dtype=np.complex128)
img_array = np.empty((size, size), dtype=np.uint8)
我们现在可以计时这个版本的执行:
In [6]: %timeit compute_point_vectorized(pos_array)
222 ms ± 3.05 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
这些数字比之前的 GPU 版本更差,但仍然比 CPU 版本好。这里的模式与 CPU 版本不同:在 CPU 版本中,最快的代码是使用通用函数。
由于计算模型,NumPy 的功能在 CUDA 中受到限制。如果有一个原生的 NumPy GPU 实现不是很好吗?这就是 CuPy 的出现……
9.3 GPU 代码的性能分析:CuPy 应用的案例
在本节中,我们将使用 NumPy 的原生 GPU 版本:CuPy 来实现一个解决方案。
注意:许多基于 CPU 的数据分析库都有 GPU 对应版本。因此,你可以使用 GPU,即使对 GPU 代码的工作原理知之甚少。因此,我们将首先列出现有的基于 GPU 的数据分析库版本。
在我们创建我们的 CuPy 解决方案后,我们将使用我们的代码来讨论分析 GPU 代码的技术。我们的 CuPy 示例将作为介绍分析 GPU 解决方案性能的工具的借口。但在讨论代码或分析之前,让我们先概述现有的基于 GPU 的数据科学库。
9.3.1 基于 GPU 的数据分析库
如果你可以访问 GPU,你不必从头开始编写代码。有几个基于 GPU 的库提供了类似的功能——许多情况下与现有的 CPU 数据库接口非常接近。在许多情况下,你不需要了解任何关于 GPU 编程的知识。表 9.1 提供了一个当前存在的库列表及其 CPU 对应版本。
表 9.1 基于 GPU 的库及其 CPU 对应版本
| GPU | CPU | 目的 |
|---|---|---|
| cuBLAS | BLAS | 基本线性代数 |
| CuPy | NumPy | N 维数组处理 |
| CuDF | pandas | 列数据分析 |
| CuGraph | 数据帧的图算法 | |
| CuML | scikit-learn | 机器学习 |
| BlazingSQL | 基于列数据的 SQL 接口 |
其他库能够加速现有的分析代码。例如,cuDNN 可以提高 PyTorch 或 TensorFlow 等机器学习库的性能。
你可以根据 GPU 数据分析项目考虑使用这些库。例如,我们将基于 CuPy 开发一个项目。
9.3.2 使用 CuPy:NumPy 的 GPU 版本
我们将使用高级数据科学库 CuPy 开发一个项目。CuPy 是 NumPy 的 GPU 版本。许多高级 GPU 库与它们的 CPU 对应库具有相似的接口,因此在那个层面上不需要介绍太多新信息。但是,除了能够展示基于 GPU 的数据科学代码的真实示例外,我们还可以使用这个示例中生成的代码来介绍 GPU 代码的剖析工具。我们的项目将是,不出所料,基于 CuPy 数组的曼德布罗特生成器。
9.3.3 与 CuPy 的基本交互
在我们实现曼德布罗特生成器之前,让我们做一些基本的 CuPy 工作,这将使我们能够讨论 CuPy 中的底层机制。我们将简单地创建一个 5000 × 5000 的矩阵并将其加倍:
import numpy as np
import cupy as cp
size = 5000
my_matrix = cp.ones((size, size), dtype=cp.uint8)
print(type(my_matrix))
np_matrix = my_matrix.get()
print(type(np_matrix))
2 * my_matrix
2 * np_matrix
虽然 CuPy 和 NumPy 具有相似的接口,但它们是不同的库,并暴露了不同的对象类型。你可能会在许多分析中同时导入它们,这并不罕见。
my_matrix 的类型将是 cupy._core.core.ndarray,而 np_matrix 的类型将是 numpy.ndarray。my_matrix 的数据位于 GPU 内存中,因此当你想要对其执行操作时,不会有从 CPU 到 GPU 的内存传输。例如,乘法 2 * my_matrix 完全在 GPU 上执行。当你显式地执行 my_matrix.get() 时,将会发生从 GPU 端的内存传输,这将创建原始矩阵的独立 NumPy 表示。
基本上对 GPU 代码进行性能分析时,不应使用典型的 Python 工具,如 timeit 模块或 IPython 的 %timeit 魔法。GPU 代码独立于 CPU 代码执行,CPU 视角下的执行时间并不能代表 GPU 成本。
CuPy 提供了一种简单的机制来剖析代码。让我们运行 2 * my_matrix 200 次,看看它的成本:
from cupyx.time import repeat
print(repeat(lambda : 2 * my_matrix, n_repeat=200))
在我的机器上的输出是:
<lambda> : CPU: 60.910 us +/-14.344
(min: 19.158 / max: 101.755) us
GPU-0: 785.708 us +/-12.013
(min: 749.760 / max: 822.656) us
因此,平均每次执行需要 60 µs 的 CPU 时间和 785 µs 的 GPU 时间。我在一个配备 2.50 GHz 英特尔 Xeon 处理器的 Tesla T4 GPU 上运行了此代码。
现在,让我们继续前进,并最终使用 CuPy 实现 Mandelbrot 生成器。我们在这里的更大目标是展示接口,因为按照设计,它应该与 NumPy 类似。我们也不会讨论 CuPy 相比 NumPy 的局限性,因为这些会随着时间的推移而变化,当你阅读这篇文章时,它们可能已经发生了变化。
我们下一个两个 Mandelbrot 实现的目标是探索如何从 GPU 中提取最大性能。我们将编写使用 CuPy 在 GPU 上工作的处理函数。我们的第一个将展示 CuPy 与 Numba 的交互。
9.3.4 使用 Numba 编写 Mandelbrot 生成器
CuPy 与 Numba 无缝交互:你可以编写一个 Numba 装饰的函数,并使用 CuPy。
提示:CuPy 有一个将 Python 代码转换为 GPU 代码的转换器,它与 Numba 有一定的竞争力。在当前阶段,与 Numba 相比,它对 Python 特性的支持相当有限。我建议首先尝试 Numba,尽管随着时间的推移,也许 CuPy 的本地转换器会变得更加功能完善。
以下是我们使用 Numba 编写的 Mandelbrot 生成器的实现,它与 CuPy 一起工作:
from math import ceil
import numpy as np
import cupy as cp
from numba import cuda
from PIL import Image
size = 2000
start = -1.5, -1.3
end = 0.5, 1.3
@cuda.jit
def compute_all_mandelbrot(startx, starty, endx, endy, size, img_array):
x, y = cuda.grid(2)
if x >= img_array.shape[0] or y >= img_array.shape[1]:
return
mandel_x = (end[0] - startx)*(x/size) + startx
mandel_y = (end[1] - starty)*(y/size) + starty
c = complex(mandel_x, mandel_y)
i = -1
z = complex(0, 0)
while abs(z) < 2:
i += 1
if i == 255:
break
z = z**2 + c
img_array[y, x] = i
与我们在 Numba GPU 部分讨论的内容相比,从概念上讲,这段代码中没有什么真正新的东西。调用代码也是如此:
threads_per_block_2d = 16, 16
blocks_per_grid_2d = ceil(size / 16), ceil(size / 16)
cp_img_array = cp.empty((size, size), dtype=cp.uint8)
compute_all_mandelbrotblocks_per_grid_2d, threads_per_block_2d
剩下的唯一事情就是保存我们的图像:
img = Image.fromarray(cp.asnumpy(cp_img_array), mode="P")
img.save("imandelbrot.png")
在这里,我们需要将我们的 CuPy 数组转换为 NumPy 版本,以便能够使用 Pillow 库创建图像表示。这意味着数据将从 GPU 传输到 CPU 内存。
让我们做一些基本的性能分析:
from cupyx.time import repeat
print(repeat(
lambda: compute_all_mandelbrotblocks_per_grid_2d, threads_per_block_2d,
n_repeat=200))
这里报告的性能是:
<lambda> : CPU: 684.475 us +/-76.369
(min: 629.685 / max: 1387.853) us
GPU-0:70604.003 us +/-89.377
(min:70519.264 / max:71290.688) us
看起来repeat更倾向于报告 70,600 µs 而不是更直观的 70 ms。现在我们已经有了基于 CuPy 的 Mandelbrot 生成器的第一个版本,让我们来做第二个版本,将 CUDA C 代码嵌入到我们的 Python 代码中。
9.3.5 使用 CUDA C 编写 Mandelbrot 生成器
我们将创建一个向量化函数来生成 Mandelbrot 集。我们的向量化函数将接收一个包含所有位置的矩阵,并计算每个位置的 Mandelbrot 值。我们将使用 CUDA C 实现我们的函数。
与 NumPy 版本一样,我们首先准备位置数组。我们实际上会在 NumPy 中做这件事,然后将其传输到 CuPy:
def prepare_pos_array(start, end, pos_array):
size = pos_array.shape[0]
startx, starty = start
endx, endy = end
for xp in range(size):
x = (endx - startx)*(xp/size) + startx
for yp in range(size):
y = (endy - starty)*(yp/size) + starty
pos_array[yp, xp] = complex(x, y)
pos_array = np.empty((size, size), dtype=np.complex64)
prepare_pos_array(pos_array)
cp_pos_array = cp.array(pos_array)
输入准备代码与之前完全相同。在最后一行,我们将 NumPy 数组转换为 GPU 上的 CuPy 版本,这需要内存传输。
我们现在必须准备threads_per_block和blocks_per_grid变量。为了使我们的 C 代码尽可能简单,我们将在一个维度上工作,而不是两个维度:
threads_per_block = 16 ** 2
blocks_per_grid = ceil(size / 16) ** 2
我们根据需要调整一维块和每个块中的线程数。以下是我们的实现:
c_compute_mandelbrot = cp.RawKernel(r'''
#include <cupy/complex.cuh>
extern "C" __global__
void raw_mandelbrot(const complex<float>* pos_array,
char* img_array) {
int x = blockDim.x * blockIdx.x + threadIdx.x;
int i = -1;
complex<float> z = complex<float>(0.0, 0.0);
complex<float> c = pos_array[x];
while (abs(z) < 2) {
i++;
if (i == 255) break;
z = z*z + c;
}
img_array[x] = i;
}
''', 'raw_mandelbrot')
这本书的目的是不教你 C 语言,所以我们不会深入探讨这个列表的细节,但代码的设计是以简洁为原则的,应该容易理解。就像以前一样,我们不是关注我们如何决定计算的位置,而是关注blockDim.x * blockIdx.x threadIdx.x。实际上,C 代码是将矩阵视为一维数组,这样是可行的。
最后,让我们使用前面的函数从位置数组计算 Mandelbrot 集:
c_compute_mandelbrot((blocks_per_grid,),
(threads_per_block,), (cp_pos_array, cp_img_array))
img = Image.fromarray(cp.asnumpy(cp_img_array), mode="P")
img.save("cmandelbrot.png")
注意调用函数时的语法,同时指定每个块和每个块中的线程数:这与 Numba 方法不同。我们通过将 CuPy 数组传输到 NumPy 版本来打印它来结束。
让我们做一些基本的性能分析:
from cupyx.time import repeat
print(repeat(
lambda: c_compute_mandelbrot((blocks_per_grid,),
(threads_per_block,), (cp_pos_array, cp_img_array)),
n_repeat=200))
在我的机器上,我得到:
<lambda> : CPU: 6.677 us +/- 2.769
(min: 4.377 / max: 25.978) us
GPU-0: 3149.825 us +/-801.397
(min: 2635.584 / max: 5881.088) us
这个 3.1 毫秒的结果比 Numba 版本快 20 倍。如果你的 Numba 代码仍然不够快,你可以采取一个最后的步骤:嵌入 CUDA C 实现。
既然我们已经有一些利用 GPU 的代码,让我们了解一下一些 GPU 性能分析工具。
9.3.6 GPU 代码的性能分析工具
在这里,我们将使用 NVIDIA 性能分析工具的一些基本功能来分析我们 Mandelbrot 实现的性能。这些分析工具是通用的:它们不依赖于 CuPy,甚至不依赖于 Python——你可以用任何 GPU 代码使用它们。为了演示这一点,我们将使用向量化 GPU 实现来分析我们 Mandelbrot 代码的 NumPy 版本。
我们将使用 NVIDIA 的 Nsight Systems 来进行性能分析。我们将假设离线使用,并使用 Nsight 的 GUI 分别捕获性能分析并进行分析。这是最灵活的方法,因为它假设 GPU 机器与分析机器是分开的——例如,当 GPU 机器在云端,而你查看性能数据在本地机器上时。
安装 Nsight Systems 后,我们可以通过以下步骤轻松地进行代码分析:
nsys profile -o numba python mandelbrot_numba.py
nsys profile -o c python mandelbrot_c.py
要使用向量化 GPU 实现来分析 NumPy 版本,我们可以这样做:
nsys profile -o numpy python ../sec3-gpu/mandelbrot_numpy.py
我们现在有三个性能分析跟踪:numba.qdrep、c.qdrep 和 numpy.qdrep。
记住,我们 NumPy 版本的平均 timeit 是 222 毫秒,而 Numba 版本的 GPU 成本是 cupyx.time.repeat 的 70 毫秒,CUDA C 版本是 3 毫秒。
我们可以从每个版本收集一些基本的性能统计信息。让我们从 NumPy 版本开始:
nsys stats numpy.qdrep
这段代码将产生大量的输出。让我们关注主要的 GPU 调用:
Time% Total ns Calls Avg ns Min ns Max ns StdDev Name
------ --------- ----- --------- --------- ---------- ------ --------------
96.1 368748545 1 368748545 368748545 368748545 0 cuMemcpyDtoH
3.6 13654495 1 13654495 13654495 13654495 0 cuMemcpyHtoD
0.1 540957 2 270478 234726 306231 50561 cuMemAlloc
0.1 371672 1 371672 371672 371672 0 cuModLdDataEx
0.0 133176 1 133176 133176 133176 0 cuLinkComplete
0.0 66248 1 66248 66248 66248 0 cuLinkCreate
0.0 49602 1 49602 49602 49602 0 cuMemGetInf
0.0 37495 1 37495 37495 37495 0 cuLaunchKernel
0.0 2071 1 2071 2071 2071 0 cuLinkDestroy
注意,我们的实现花费了大量时间在 GPU 内部复制数据:这些是 cuMemcpyDtoH 和 cuMemcpyHtoD 调用,它们占用了超过 99% 的时间。
我们还可以检查仅计算(即内核)部分的成本。以下是一个简化的 NumPy 版本:
Time(%) Total Time (ns) Name
------- --------------- ----------------------------------------------
100.0 365860777 cudapy::__main__::__vectorized_compute_point ...
时间成本是 365860777 纳秒,或 365 毫秒。
使用 Numba 的 CuPy 版本的时间成本为:
Time(%) Total Time (ns) Name
------- --------------- ------------------------------------------
100.0 189965876 cudapy::__main__::compute_all_mandelbrot ...
结果是 180 毫秒。
最后,CUDA C 的时间成本为:
Time(%) Total Time (ns) Name
------- --------------- --------------
100.0 5876134 raw_mandelbrot
结果是 5.8 毫秒。
NumPy 版本比 Numba 慢两倍。对于内核执行,C 版本比 Numba 版本快 32 倍。
Nsight Systems 有一个出色的 GUI,通过nsys-ui调用,允许您实时探索跟踪并跟踪执行。虽然很难在屏幕截图中捕捉到这种动态,但图 9.3 显示了我们的 Mandelbrot 生成器 C 版本的跟踪的放大部分。应用程序可以通过 CPU 和 GPU 事件进行跟踪,但在这里我们专注于 GPU 事件。您可以看到两个相关的块。首先,在左侧的块中,主机到设备的传输正在将 NumPy 矩阵的位置复制到 GPU 上的 CuPy 版本:cp_pos_array = cp.array(pos_array)。第二个块实际上正在执行 Mandelbrot 计算。

图 9.3 Nsight Systems 的 GUI。左上角窗口:所有进程的概览,显示 GPU 和 CPU 的使用情况。主窗口:执行的时间视图。左下角:几个 GPU 操作的时间统计。右下角:主窗口中一个块的详细信息。
总结本节最重要的两个要点:
-
就像 CPU 一样,如果您遇到性能问题,最好是使用适当的分析来量化问题,而不是猜测。
-
如果存在模仿您已知的 CPU API 的 GPU 库,那么最有效的方法可能是使用这些库,而不是从头开始编写代码来实现相同的功能。
摘要
-
虽然 CPU 提供了一些非常快且可以处理不同问题的计算单元,但 GPU 通常提供成千上万的计算单元,这些单元速度较慢,预期执行相似的工作负载。
-
GPU 提供了非常适合高效数据处理的计算能力,因为许多数据科学问题依赖于矩阵等数据结构,这些数据结构可以通过在多个数据点上使用相同的算法进行并行化。
-
虽然有多个 GPU 制造商,但事实上,GPU 计算的标准是基于 NVIDIA 硬件的。
-
当为 GPU 编写代码时,我们需要意识到 GPU 的计算模型与 CPU 非常不同,并且需要与传统顺序 CPU 计算不同的思维方式。
-
标准 Python 代码不能直接在 GPU 上运行;我们需要考虑替代方案来探索 GPU 的强大功能。
-
已经有许多 Python 库允许使用 GPU,而无需直接了解如何编程它们。
-
许多 Python 库几乎是现有 CPU 版本的直接替代品。例如,CuPy 在 GPU 上工作时提供了一个与 NumPy 相似的接口,而 cuDF 则有一个与 pandas 相似的接口。
-
Numba 可以为 GPU 生成代码,但仅仅用 Numba 注释现有的 Python 代码价值不大:代码需要重新设计,以探索许多算法在大型数组上工作时的极端并行性。
-
Numba 代码,即使是针对 GPU 的,也可以与 NumPy 无缝交互,允许它卸载高度并行的算法,同时仍然与传统的 Python 数据分析堆栈集成。
10 使用 Dask 分析大数据
本章节涵盖
-
在具有极大数据集的许多机器上扩展计算
-
介绍 Dask 的执行模型
-
使用
dask.distributed调度器执行代码
处理大量数据有时需要不止一台计算机,因为数据量过大难以处理,或者算法需要大量的计算能力。在本书的这一阶段,我们已经知道如何设计更高效的计算过程,以及如何更智能地存储和结构化我们的数据以便处理。最后一章将介绍如何 扩展——也就是说,使用多台计算机来执行计算。
为了扩展,我们将使用 Dask,这是一个用于分析并行计算的库。Dask 与 Python 生态系统中的其他库,如 NumPy 和 pandas,集成得非常好。Dask 将满足我们扩展 (即使用多台计算机) 的需求。然而,它也可以用于扩展 (即更有效地使用单台计算机中的计算结果)。从这个意义上说,它可以作为第三章中关于并行性的材料的替代品。
除了 Dask 之外,还有其他替代方案,Spark 是最常见的一个。Spark 来自 Java 领域,与 Dask 相比,与其他 Python 库的集成程度较低。因此,我更喜欢使用原生 Python 解决方案,这简化了与 Python 生态系统的交互。这里练习的许多概念仍然可以用于其他框架。
Dask 有几个不同的编程接口。在较高层次上,一些 API 与 NumPy、pandas 和其他分析库类似。然而,如果您熟悉原始库,Dask 的接口易于使用,它允许使用超出内存的对象,如数据框和数组,这是 pandas 或 NumPy 所不具备的。在较低层次上,一个接口基于 concurrent.futures(参见第三章),另一个允许您使用 Dask 并行化更通用的代码(即不仅基于数组和数据框)。
本章的主要目标是帮助您理解 Dask 的底层执行模型,以及调度替代方案和超出内存的数据使用。虽然我们将讨论一些性能问题,但我相信从理解 Dask 计算模型的底层结构中可以获得更多收益。执行环境可能差异很大——从单机到非常大的集群——这可能会使具体的性能建议变得无效,甚至有害。因此,本章采用了与本书大部分章节不同的方法:您将获得基本构建块,并将不得不根据您的特定环境对其进行调整。
与 pandas 或 NumPy 等库相比,Dask 具有不同的——即更惰性的——执行模型。因此,第一部分将涵盖语义上的重大差异。由于我们希望对模型有坚实的基础,第一部分将完全不考虑并行性。我们也不会考虑大于内存的数据结构。第一部分将使用基于 Dask 数据框接口的说明性示例,这与 pandas 类似。
在第二节中,我将讨论大于内存数据集的分区,并展示 Dask 模型的某些性能影响。我还会介绍一些最佳实践来加速计算。
在第三节中,我们将了解 Dask 的分布式调度器,它允许我们智能地将计算分布在多台计算机和架构上:从 HPC 集群到云或启用 GPU 的机器。由于要求读者拥有集群或云来运行此代码可能过于苛刻,我们的示例可以在单台机器上运行,但也很容易扩展。
我们将从 Dask 的执行模型开始。鉴于 Dask 的惰性特性,它与传统库(如 pandas 或 NumPy)相比,存在一些重要的概念性差异,在我们实际实现并行解决方案之前需要理解。本章代码需要使用 Dask 运行。您还需要 Graphviz 库来绘制任务图。使用 conda,执行 conda install dask。目前,使用 pip 安装 Graphviz 库桥接器似乎更容易(pip install graphviz),即使使用 conda。您还应该确保 Graphviz 主应用程序已安装。Docker 镜像是 tiagoantao/python-performance-dask。
10.1 理解 Dask 的执行模型
并行解决方案通常非常困难,尤其是在分布式架构上执行时。在我们深入探讨使用 Dask 的并行性之前,让我们确保我们理解其执行模型。我们将使用类似 pandas 的解决方案在 Dask 中编写,并忽略其底层实现:我们不在乎它是串行还是并行。将我们的讨论限制在模型执行上,将使我们能够理解 Dask 与本例中的 pandas 之间的差异。然后,在下一节中,我们将使用并行和分布式解决方案。
在此示例中,我们将从美国人口普查中获取有关 50 个美国州的税收数据。对于每个州,我们将有关于所有收集的税收的信息,包括从每个税收来源收集的金额细分。换句话说,我们将能够看到收集的税收总额,以及作为所得税、销售税、财产税等的金额。我们正在尝试决定在哪里买房,而我们考虑的一个因素是我们将不得不支付的财产税金额。因此,我们想知道哪些州从财产税中获得大量税收收入,哪些州从财产税中获得的总税收收入很少。我们唯一关心的是确定整个税收收入中来自财产税的比例。
我们将要处理的数据表相当小,因此使用 pandas 处理是微不足道的,但数据大小在这里不是重点。我们想要理解的是执行模型。数据可以在mng.bz/41ND找到。
10.1.1 用于比较的 pandas 基准
让我们从 pandas 版本作为基准开始。我们需要读取文件并清理数据,然后我们将计算每个州的财产税比例:
import numpy as np
import pandas as pd
taxes = pd.read_csv("FY2016-STC-Category-Table.csv", sep="\t")
taxes["Amount"] = taxes["Amount"].str.replace(",",
"").replace("X", np.nan).astype(float) ①
pivot = taxes.pivot_table(index="Geo_Name",
columns="Tax_Type", values="Amount") ②
has_property_info = pivot[pivot["Property Taxes"].notna()].index
pivot_clean = pivot.loc[has_property_info]
frac_property = pivot_clean["Property Taxes"] / pivot_clean["Total Taxes"]
frac_property.sort_values()
① 我们清理“金额”列,以便将其转换为浮点数。
② 我们沿着“税种”进行数据透视。
我们首先读取文件,该文件包括美国州(称为Geo_Name)、税种类型以及州为该税种类型收集的金额:
Geo_Name Tax_Type Amount
Alabama Total Taxes 10,355,317
Alabama Property Taxes 362,515
Alabama Sales and Gross Receipts Taxes 5,214,390
Alabama License Taxes 575,510
Alabama Income Taxes 4,098,278
Alabama Other Taxes 104,624
Connecticut Total Taxes 15,659,420
Connecticut Property Taxes X
Connecticut Sales and Gross Receipts Taxes 6,518,905
Connecticut License Taxes 454,779
Connecticut Income Taxes 8,322,645
Connecticut Other Taxes 363,091
...
然后,对于包含非数字的“金额”列,我们将Xs 转换为NAs。为了我们的计算,我们需要只有数字,而不是字符串。
接下来,我们进行数据透视:我们创建一个新表,每个税种一列,每个美国州一行。这种表示方式使计算更简单,因为我们只需要一行就能获取所有信息。结果是:
index Income Taxes Total Taxes ... Property Taxes
Alabama 4098278 10355317 362515
Colorado 711711 12887859 NaN
然后,我们删除所有没有财产税的行,并最终计算来自财产税的税收百分比:
Nebraska 0.000024
New Jersey 0.000147
Iowa 0.000147
Massachusetts 0.000213
....
Alaska 0.124577
New Hampshire 0.154625
Wyoming 0.177035
DC 0.326369
Vermont 0.338844
现在我们来看 Dask 版本。
10.1.2 开发基于 Dask 的数据帧解决方案
如您所见,在等效的 Dask 版本中,代码几乎相同:
import numpy as np
import dask.dataframe as dd ①
taxes = dd.read_csv("FY2016-STC-Category-Table.csv", sep="\t")
taxes["Amount"] = taxes["Amount"].str.replace(",",
"").replace("X", np.nan).astype(float)
taxes["Tax_Type"] = taxes["Tax_Type"].astype(
"category").cat.as_known() ②
pivot = taxes.pivot_table(index="Geo_Name",
columns="Tax_Type", values="Amount")
has_property_info = pivot[
~pivot["Property Taxes"].isna()].index ③
pivot_clean = pivot.loc[has_property_info]
frac_property = pivot_clean["Property Taxes"] / pivot_clean["Total Taxes"]
① 我们导入 Dask 的数据帧接口。
② 我们需要指定 Tax_Type 为分类数据以进行数据透视。
③ 我们使用 isna 而不是 notna,因为 Dask 不支持 notna。
如您所见,代码非常相似。它感觉相同,您几乎可以原样复制,只需在导入dask.dataframe而不是pandas时进行替换。
警告:虽然 Dask 的数据帧接口确实与 pandas 相似,但一些功能尚未实现或略有不同。我们刚刚看到了notna的案例以及将列标注为分类数据的需要,但还有很多其他案例。我精心设计了此示例以避免实现中的一些额外差异。重点是:风味相同,许多操作非常相似,但 pandas 和 Dask 数据帧之间仍然存在实现差距。
之前的代码如果你使用 pandas,可能不会得到你预期的结果。它做了什么?虽然 Dask 代码非常相似,但它正在做完全不同的事情。
print(frac_property) 并不提供结果。相反,Dask 准备了一个执行计划——任务图——来计算结果。任务图有节点表示需要执行的操作,有边表示操作之间的依赖关系。让我们考虑一个具体的例子。
Dask 可以导出任务图的可视化:
frac_property.visualize(filename="10-property.svg", rankdir="LR")
图 10.1 显示了与代码的前两行对应的任务图部分。pd.read_csv 由第一个节点表示。赋值语句 taxes["Amount"].str.replace(",", "").replace("X", np.nan).astype(float) 的左侧部分由底部一行表示,而正确的赋值 taxes["Amount"] = ... 由右侧最后一个节点表示。

图 10.1 我们财产税计算的任务图开始。包括 CSV 读取和 Amount 的重新编码。
计算现在已准备好运行。我们可以通过运行以下命令来获取结果:
frac_property_result = frac_property.compute()
这将计算结果并返回一个与上一个 pandas 示例中相同的 pandas 数据框。
在这个阶段,我们不在乎计算是如何执行的。它可能是串行、线程、多进程、在集群、在 GPU 上或在云上执行的。本节要理解的基本点是 Dask 在执行上是懒惰的:你编写的代码创建了一个稍后要执行的任务图。
现在你已经理解了 Dask 懒惰方法(计算在需要时进行)与 pandas(或 NumPy)的急切方法(计算在指定时进行)之间的区别,让我们深入探讨算法成本。
10.2 Dask 操作的计算成本
让我们讨论几个 Dask 操作的算法成本。我们的讨论将与执行环境无关。虽然算法和实际执行平台这两个主题在实践中非常交织在一起,但理解算法复杂性的后果本身更容易。与 Dask 分割不适合内存的数据的方式相关的后果,与你要执行的任何运行基础设施无关。我们将考虑的问题相当普遍,其他并行处理方法,如 Spark,也有类似的问题。
我们将执行一些非常简单的任务。首先,我们将创建一个名为 year 的列,它只包含 Survey_Year 的最后两位数字。因此,2016 变成了 16。然后,我们将创建一个名为 k_amount 的列,它是 Amount 但以千为单位(即除以 1,000)。接下来,我们将获取具有最大值的州。最后,我们将按总税收金额对州进行排序。
我们将重用上一节的数据。尽管数据集很小,但我们仍然可以强制 Dask 以类似大型数据集的方式对计算进行分区。无论如何,“大”这个概念将取决于你拥有的硬件。
在我们分区之前,让我们先加载数据并创建一个只有最后两位数字的年份列(例如,2016 转换为 16):
import numpy as np
import dask.dataframe as dd
taxes = dd.read_csv("FY2016-STC-Category-Table.csv", sep="\t")
taxes["year"] = taxes["Survey_Year"] - 2000
taxes.visualize(filename="10-single.svg", rankdir="LR")
如果你可视化任务图,如图 10.2 所示,你会看到文件读取后跟随执行减法操作所需的操作。


现在我们来看看如果我们对数据进行分区,任务图会是什么样子。
10.2.1 为处理分区数据
CSV 输入相当小,不到 15 KB,但为了理解正在发生的事情,让我们假设我们只能同时处理 5 KB。我们可以告诉 read_csv 处理最大块或块,即 5,000 字节:
taxes = dd.read_csv("FY2016-STC-Category-Table.csv",
sep="\t", blocksize=5000)
taxes["year"] = taxes["Survey_Year"] - 2000
taxes.visualize(filename="10-block.svg", rankdir="LR")
如果我们可视化任务图,现在我们有三个独立的分区,如图 10.3 所示。15 KB 被分成三部分,以便最多处理 5,000 字节的数据。


现在我们有三个数据帧的分区,是时候考虑 Dask 中数据帧的实现方式了。记住,我们的目标是强制 Dask 对数据进行分区,以了解分区如何影响 Dask 任务图。图 10.4 提供了数据分区的高级概述。在实现中,三个分区被表示为三个 pandas 数据帧。

图 10.4 Dask 中数据帧实现的高级视图
对于 Dask 数组(NumPy 数组的 Dask 等价物),也采用了类似的实现策略。Dask 数组在每个分区中实现为一个 NumPy 数组。作为基于 Python 的解决方案,Dask 利用现有的库来支持其内部工作。让我们回到实现我们的解决方案。
现在我们已经看到了分区对任务图的影响,让我们看看一种减少重复计算的策略。
10.2.2 持久化中间计算
正如我们在上一节中讨论的,金额列在获得正确的数字之前需要一些解析。由于我们的大部分计算都将依赖于正确解析的数字,我们希望避免每次需要该列时都重新计算字符串转换。相反,我们可以要求 Dask 持久化计算的中间状态:
taxes["Amount"] = taxes["Amount"].str.replace(",",
"").replace("X", np.nan).astype(float)
taxes = taxes.persist()
taxes.visualize(filename="10-persist.svg", rankdir="LR")
虽然 .persist 调用的语义将取决于特定的调度器,但让我们假设所有节点的计算都已启动,因此这将执行任务图以清除 Amount 列。在 .persist 之后,金额的计算将不再重复。对于这个例子,我们只处理轻量级计算,但持久化调用也可能生成非常长的计算图。
维护数据分区的优势在于我们仍然在整个计算环境中拥有它,并且可以对其启动并行查询。这与 compute 相比:你会得到所有数据,如果你需要在此基础上进行更多计算,你需要重新分区数据并将其发送到所有执行计算的进程。此外,如果完整的数据框大于你的内存,compute 会崩溃你的本地进程。
提示:跨处理单元传输数据可能非常昂贵,尤其是在网络上,因为需要序列化数据。另一方面,我们无法持久化一切,因为这可能需要太多内存。通常,那些频繁重用且产生少量数据的图节点是 persist 方法的良好候选。
现在我们优化了部分计算,让我们回到计算具有最大税额的状态,并按总税收收集量对状态进行排序。
10.2.3 分布式数据框上的算法实现
对于某些操作,与我们所习惯的顺序实现相比,算法的分布式实现可能具有完全不同的成本——在我们的例子中,比较 Dask 数据框和 pandas 数据框。
让我们执行一个简单的操作。记住,我们想要将前一个数据集的 Amount 列转换为千美元:
taxes["k_amount"] = taxes["Amount"] / 1000
taxes.visualize(filename="10-k.svg", rankdir="LR")
该操作的任务图非常简单,如图 10.5 所示。在这种情况下,计算在所有分区上并行进行,这非常高效。

图 10.5 在分区上可以发生许多计算,而无需它们之间有任何通信。
现在让我们考虑一个在分布式系统中实现更困难的操作。让我们想象我们想要从计算出的金额中找到最大值。你能想象这个任务的样子吗?代码如下:
max_k = taxes["k_amount"].max()
max_k.visualize(filename="10-max_k.svg", rankdir="LR")
注意,每个分区都有一个节点负责计算该分区的最大值。不幸的是,所有分区的最大值必须合并到一个单独的进程中来计算所有分区的最大值。这个过程有一些影响。当计算最大值时,在最终节点计算所有分区的最大值时,并行性停止。因此,必须将持有数量的分区中的数据传输到计算最大值的节点。在这个操作的作业图(图 10.6)中,你可以看到从三个被称为series-max-chunk的任务到series-max-agg任务的转换中的数据传输。换句话说,我们需要从三个并行任务到一个单独的任务来计算最大值:这个任务成为了持续并行性的瓶颈。

图 10.6 某些计算需要减少并行性——例如,计算最大值。
一般原则是,与 pandas 或 NumPy 相比,使用 Dask 的操作成本可能会有很大的变化。如果操作需要在进程之间进行通信或停止并行处理,你可以预期成本会增加。具体的含义可能因你的底层架构而异。

图 10.7 某些计算,如sort_values,可能很复杂且昂贵。
如果你不了解一个操作的作业图拓扑结构,你可以简单地渲染该操作的作业图来查看结构并找到潜在的瓶颈。例如,图 10.7 显示了(相当昂贵)的操作sort_values的作业图。在这个特定情况下,barrier和shuffle-collect任务在整个图中对并行性造成了损失:
sv = taxes.sort_values("k_amount")
sv.visualize(filename="10-sv.svg", rankdir="LR")
有时你可能需要检查两个操作的作业图,因为 Dask 可能足够智能以优化它们。例如,跟随groupby的操作可能以完全不同的方式优化。优化甚至可能在不同版本的 Dask 之间变化,因此没有一般规则,除了检查 Dask 如何渲染你使用的操作。
我在这里没有列出便宜或昂贵的 Dask 操作有两个主要原因。首先,操作是否昂贵取决于你的执行环境。例如,云与大型多核计算机会有所不同。其次,Dask 始终在发展,实现可能在不同版本之间发生变化。重要的是要理解其背后的原理。
10.2.4 数据重分区
在某些情况下,根据任务图和执行环境,计算的粒度可以通过重新分区数据来受益。例如,假设我们完成了一部分非常昂贵的计算,需要更多的分区和潜在的更多计算节点。在此过程之后,如果我们进入操作的低强度部分,我们可能会减少分区数量。在我们的例子中,让我们将分区从三个减少到两个:
taxes2 = taxes.repartition(npartitions=2)
taxes2.visualize(filename="10-repart.svg", rankdir="LR")
如您在任务图(图 10.8)中看到的,这种重新分区的问题在于三个分区中的两个被合并成了一个新分区。如果两个新分区有相似的数据量,将会更有效率。

图 10.8 使用不同数量的任务重新分区数据
现在我们来看看平衡两个分区需要什么。我们遇到的第一问题是 Dask 和 pandas 之间的一些语义差异。在我们继续之前,我们应该处理这些问题。
repartition方法允许我们不仅根据分区数量来分割数据,还可以通过在分区之间划分数据帧索引来实现。我们需要知道索引,但我们可以这样做:
print(taxes.index)
输出是:
Dask Index Structure:
npartitions=3
int64
...
...
...
dtype: int64
Dask Name: assign, 6 tasks ①
① 我们在这里得到的是一系列要运行的任务,而不是结果。
我们将得到一系列要运行的任务,而不是最终的索引。因此,我们需要compute索引,这涉及到所有的计算成本,这在大多数情况下,就抵消了我们正在进行优化的目的。
另一个潜在的替代方案是获取每个分区的边界。Dask 允许我们通过以下方式获取每个分区的边界:
print(taxes.divisions)
可惜,输出是:
(None, None, None, None)
我们没有得到索引列的值,而是得到了None。显然,我们不能使用None来计算更好的重新分区数据的方法。
在这个阶段,我们处于一个令人沮丧的情况:我们希望得到索引来重新分区数据,但索引却无处可寻。当你使用 Dask 的read_csv时,你将不会得到带有值的索引——即使你已持久化数据帧(至少在当前版本的 Dask 中是这样)。
我们可以设置索引以到达一个起始点来按索引重新分区。让我们使用Geo_Name和Tax_Type作为索引列:
taxes = taxes.set_index(["Geo_Name", "Tax_Type"])
不幸的是,在当前版本的 Dask 中,这个操作不起作用,因为 Dask 不支持多索引。这里的一个重要教训是,尽管 Dask 尽可能地模仿 pandas 的接口和语义,但由于处理分布式数据结构的复杂性,它仍然存在一些限制和差异。Dask 尽力而为,但请期待一些我们刚才讨论过的差异。
提示:务必检查您是否拥有最新的 Dask 版本。也许这里记录的一些限制已经被解决了。
好的,让我们尝试使用单列索引:
taxes = taxes.set_index(["Geo_Name"])
print(taxes.npartitions)
print(taxes.divisions)
输出是:
3
('Alabama', 'Iowa', 'North Carolina', 'Wyoming')
这个结果很好,因为我们有三个清晰的分区:一个从阿拉巴马州开始,另一个从爱荷华州开始,最后一个从北卡罗来纳州到怀俄明州。
如果 Dask 是惰性的,它是如何知道分区的?我们没有明确要求它计算新的数据帧。尽管如此,在某些情况下,set_index是急切的:触发渲染将要被索引的数据帧所需的全部计算,这可能会使用大量的计算资源。
重要的是,Dask 并不总是惰性的,所以对于某些操作,你需要注意潜在的计算成本。你应该检查你使用的操作文档,特别是如果它们的性能让你怀疑其中一些是急切的。
现在我们终于有了索引,让我们将我们的数据从三个分区重新分区到两个分区:
taxes2 = taxes.repartition(divisions=[
"Alabama", "New Hampshire", "Wyoming"])
print(taxes2.npartitions)
print(taxes2.divisions)
输出结果为:
2
('Alabama', 'New Hampshire', 'Wyoming')
请记住,作为一个一般规则,重新分区是一个昂贵的操作,并且只有在得到小的或中间结果时,或者如果你相信(即,进行了性能分析)重新分区是有益的情况下,才应该进行。
我们对 Dask 数据帧接口与 pandas 之间关系所提出的通用论点可以扩展到 Dask 数组与 NumPy 数组之间的关系。Dask 数组主要使用惰性操作实现,并且只实现了 NumPy 接口的一部分,它们有时会有略微不同的语义。
接下来,我们将我们的分布式数据帧存储到磁盘上。
10.2.5 持久化分布式数据帧
要将我们的taxes2数据帧存储到磁盘上,我们可以简单地这样做:
taxes2.compute().to_csv("taxes2_pandas.csv")
在这种情况下,我们将taxes2分布式数据帧计算到一个 pandas 数据帧中。我们让 pandas 负责写入数据。将所有数据从计算节点传输到我们的主节点可能太昂贵,或者数据可能根本不适合内存,所以这个选项可能不可行。
警告:请注意persist的含义。在这里,我们使用persist表示将数据传输到持久存储,如硬盘。然而,Dask 也有.persist方法,该方法在每个分区上计算并存储对象。
我们可以通过以下方式请求 Dask 让节点写入数据:
taxes2.to_csv("partial-*.csv")
请记住,我们有两个分区,所以你最终会得到不止一个 CSV 文件:partial-0.csv和partial-1.csv,两者都有标题。如果你想得到一个单一的 CSV 文件,你必须相应地连接这些文件。
Parquet 格式(见第八章)实际上可以渲染每个分区独立导出数据的单个持久版本:
taxes2.to_parquet("taxes2.parquet")
如果你查看文件系统,你会找到一个包含以下内容的目录:
taxes2.parquet/
_common_metadata
_metadata
part.0.parquet
part.1.parquet
这可以被视为单个 Parquet“文件”。以下是一个使用 Apache Arrow(见第七章)的简单示例:
from pyarrow import parquet
taxes2_pq = parquet.read_table("taxes2.parquet")
taxes_pd = taxes2_pq.to_pandas()
因此,作为格式,Parquet 适合分布式写入,同时仍然提供对所有数据的统一视图。
在这个阶段,我们了解了 Dask 任务生成的工作原理,但关于执行的部分我们谈论得很少。现在是时候进行最后一步:使用 Dask 在异构架构上进行高效的并行计算——即调度。
10.3 使用 Dask 的分布式调度器
我们已经看到 Dask 倾向于惰性,并且只创建一个计算图,这个图最终必须被评估。为了将计算图的节点评估分布到计算资源上,Dask 使用一个调度器。当你计算一个任务图而没有明确配置调度器时,Dask 会自动使用一个默认设置,这个设置取决于你的集合。让我们以数据框为例:
import dask
from dask.base import get_scheduler
import dask.dataframe as dd
df = dd.read_csv("FY2016-STC-Category-Table.csv")
print(get_scheduler(collections=[df]).__module__)
函数get_scheduler将返回一个执行任务图的函数。在我们的例子中,它定义在输出中打印的模块中:
'dask.threaded'
正如名称所暗示的,数据框的默认调度器是线程并行的。Dask 还提供了另外两个简单的调度器:一个多进程调度器和单线程调度器。单线程调度器特别适合调试和性能分析,因为它通过严格的顺序性降低了复杂性。单线程调度器是调试的一个很好的选择。但对于生产环境,我们将使用一个更复杂的调度器:新的 Dask 分布式调度器取代了所有其他 Dask 调度器,同时提供了更多的灵活性。
分布式调度器允许你在多台机器上调度任务。这个调度器有针对 HPC 集群、SSH 连接和云提供商等的实现。它还有一个在本地机器上运行的实现,可以是单线程或多线程,也可以基于多个进程;因此,它包括了所有内置调度器的计算方法。
接下来,我们将使用本地机器配置,这样你就不需要访问集群或云服务,但所有基本构建块都将可用于从单台机器扩展。
注意:作为提醒,本内容与第三章存在一定程度的重叠:你可以使用 Dask 在单台计算机上并行化代码,就像我们使用 Python 的本地库那样。但 Dask 在扩展(即使用多台机器)时能带来最大的价值。
我们将使用前一章中的 Mandelbrot 生成场景来练习 Dask 的数组接口。记住,还有其他方法可以使 Mandelbrot 实现更高效(例如,Cython 或 Numba)。鉴于我们正在对数组进行纯 Python 实现,Cython 或 Numba 的实现将更加高效。实际上,对于非常大的图像,最有效的实现将是 Dask 与 Cython 或 Numba 结合使用。让我们首先看看dask.distributed的架构。
10.3.1 dask.distributed 架构
图 10.9 所示的架构包含以下组件:
-
单个集中式调度器——这个调度器负责为所有工人调度任务。调度器有一个用户可以用来检查计算性能的 Web 仪表板。
-
工人——工人负责执行工作负载。每台机器可能有多个工人。你可以配置一个工人拥有你想要的任意数量的线程。因此,在实践中,你通过线程——比如说,一个拥有与 CPU 核心一样多线程的单个工人——或者通过每个 CPU 核心一个工人的进程来实现并行化。每个工人还有一个仪表板和一个名为nanny的小型附加进程,用于持续监控工人的状态。
-
客户端——这些客户端可以连接到 Dask,使用调度器在上面部署任务,并检查调度器和工人仪表板。

图 10.9 Dask 的执行架构
通常,组件还将包括某种类型的共享存储,如共享文件系统,但这将取决于你的具体情况。
在我们的情况下,我们只使用一台机器。所以,我们将在一个机器上启动所有进程。有更简单的方式来部署架构,但这种方式使得所有组件都非常明确。
让我们从调度器开始:
dask-scheduler --port 8786 --dashboard-address 8787
我们可以这样在调度器相同的机器上启动工人:
dask-worker --nprocs auto 127.0.0.1:8786
--nprocs auto让脚本决定在我们的机器上启动多少个工人以及多少个线程。
在我的机器上,它有四个核心和每个核心两个线程,我最终有四个工人,每个工人有两个线程。我们可以从调度器仪表板中获取这些信息:将你的浏览器指向 http://127.0.0.1:8787,并在菜单中选择“Workers”标签。我在图 10.10 中得到了结果。

图 10.10 Dask 仪表板上的所有工人列表
决定每台机器的工人数和每个工人的线程数相当复杂,这取决于工作负载。对于许多基于 NumPy 的问题,当正确配置时,它是多线程的,我们可以从每台机器上单个 Python 和单个工人开始。NumPy 将使用所需的线程数,我们将把它留给库来做出决定。对于使用 Numba 或 Cython 优化的代码,也可以提出允许脚本确定工人数和线程数的类似论点,因为两者都可以释放 GIL。话虽如此,你的工作负载可能不同。在我们的情况下,确实如此:大部分负担都在纯 Python 上,所以我们将使用每个核心一个进程。
我们将用以下内容替换之前定义的工人:
dask-worker --nprocs 4 --nthreads 1 --memory-limit 1GB 127.0.0.1:8786
我们使用四个进程,每个进程使用一个线程。我们还指定了 1GB 的内存,这是我机器上可用总内存的一半:我这样做是因为我在我的本地机器上运行了更多东西,但在专用机器上你可能可以设置得更高。确保根据你的配置调整这些值。
为了教学目的,我们将工作进程减少到两个,这样我们就可以讨论工作进程之间的通信,每个工作进程仅使用 250 MB。在我的机器上,代码如下:
dask-worker --nprocs 2 --nthreads 1 --memory-limit 250MB 127.0.0.1:8786
接下来,我们将使用 Python 代码连接到我们的调度器。但在解决具体问题之前,让我们检查基础设施:
from pprint import pprint
import dask.dataframe as dd
from dask.distributed import Client
client = Client('127.0.0.1:8786') ①
print(client)
for what, instances in client.get_versions().items(): ②
print(what)
if what == 'workers':
for name, instance in instances.items():
print(name)
pprint(instance)
else:
pprint(instances)
① 我们连接到 dask-scheduler 启动时指定的端口上的调度器。
② get_versions返回有关 Dask 系统中各种组件的信息。
我们通过创建一个指向入口点的Client对象来连接到调度器。第一次打印返回:
<Client: 'tcp://192.168.2.20:8786' processes=2 threads=2, memory=500.00 MB>
此输出反映了我们创建的基础设施:两个工作进程,每个工作进程一个线程,每个工作进程 250 MB。
之后,我们打印所有相关组件的所有软件版本。调度器如下:
scheduler
{'host': {'LANG': 'en_US.UTF-8',
'LC_ALL': 'None',
'OS': 'Linux',
'OS-release': '5.13.0-19-generic',
'byteorder': 'little',
'machine': 'x86_64',
'processor': 'x86_64',
'python': '3.9.7.final.0',
'python-bits': 64},
'packages': {'blosc': '1.9.2',
'cloudpickle': '1.6.0',
'dask': '2021.01.0+dfsg',
'distributed': '2021.01.0+ds.1',
'lz4': None,
'msgpack': '1.0.0',
'numpy': '1.19.5',
'python': '3.9.7.final.0',
'toolz': '0.9.0',
'tornado': '6.1'}}
您可以找到有关主机信息,包括操作系统、处理器类型和 Python 版本,以及安装的库。
这里是针对我们两个工作进程和客户端的简略版本:
workers
tcp://127.0.0.1:32931
{'host': {'LANG': 'en_US.UTF-8',
...
'python-bits': 64},
'packages': {'blosc': '1.9.2',
....
'tornado': '6.1'}}
tcp://127.0.0.1:34719
{'host': {'LANG': 'en_US.UTF-8',
...
'tornado': '6.1'}}
client
{'host': {'LANG': 'en_US.UTF-8',
...
'tornado': '6.1'}}
在具有许多机器的异构集群中,确保库版本兼容性非常重要。在我们的情况下,因为我们的机器同时是客户端、调度器和两个工作进程,所以我们可以确信所有版本都是同步的。然而,当涉及多个机器时,您可能需要调试库版本。现在让我们部署我们的代码。
10.3.2 使用 dask.distributed 运行代码
我们首先连接到调度器:
from dask.distributed import Client
client = Client('127.0.0.1:8786')
客户端将在所有调用中隐式使用,除非我们覆盖它。记住,从上一节中,Dask 数据结构有默认的调度器,但默认调度器将被分布式调度器自动替换。
提示:客户端对象提供了一个与concurrent.futures API 非常相似的显式接口。如果您想使用此类接口,请参阅第三章。在这里,我们将通过数据科学类型接口使用分布式框架——在这种情况下,是dask.array,它模仿 NumPy。
我们将使用 NumPy 通用函数方法。实际上,计算 Mandelbrot 集单点的代码与上一章完全相同:
def compute_point(c):
i = -1
z = complex(0, 0)
max_iter = 200
while abs(z) < 2:
i += 1
if i == max_iter:
break
z = z**2 + c
return 255 - (255 * i) // max_iter
要计算 Mandelbrot 集,我们必须准备一个矩阵,其中每个单元格都有一个二维位置,该位置由一个复数编码。在上一章中,我们使用了以下代码:
def prepare_pos_array(size, start, end, pos_array):
size = pos_array.shape[0]
startx, starty = start
endx, endy = end
for xp in range(size):
x = (endx - startx)*(xp/size) + startx
for yp in range(size):
y = (endy - starty)*(yp/size) + starty
pos_array[yp, xp] = complex(x, y) ①
① 你能猜出这一行的作用吗?
理论上,之前的代码是可行的,但在实践中,它将是一场灾难。请注意,最后一行并不是在数组单元格中存储位置:它实际上是在任务图中创建一个计算结果的作业。为了使这一点更清晰,让我们创建一个 3 × 3(即块大小为 3)的小图像:
size = 3
pos_array = da.empty((size, size), dtype=np.complex128)
prepare_pos_array(3, start, end, pos_array)
pos_array.visualize("10-size3.png", rankdir="LR")
图 10.11 描述了任务图。创建了九个任务,每个任务更新一个单独的像素/单元格。这种解决方案适用于微小的图像,但不适用于较大的图像。对于 1000 × 1000 的图像,我们将处理 100 万个任务。

图 10.11 仅运行九个像素的初始化代码的任务图
一个理论上的替代方案是创建一个本地 NumPy 数组,在本地初始化它,然后分散它。只要 NumPy 数组可以适应内存,这种方法就会有效,但这会违背 Dask 的部分目的:处理大于内存的数据结构。
作为一种更现实的替代方案,Dask 允许我们对整个数据结构的每个独立分区进行计算,从而大大减少了任务的数量:
size = 1000
range_array = da.arange(0, size*size).reshape(size, size)
range__array = pos_array.rechunk(size // 2, size // 2) ①
range__array.visualize("10-rechunk.png", rankdir="TB")
range_array = range_array.persist()
① 我们将数组分块为四个大小为(500, 500)的块。
我们现在使用 1000 × 1000 的图像大小。我们将使用一个范围数字初始化数组,这将允许我们计算二维坐标(有关详细信息,请参阅以下代码片段)。我们从一个大小为 1000 × 1000 的一维数组开始,然后将其重塑为(1000, 1000)。
然后,我们将数组以(500, 500)的块重新分块,最终得到四个块。最后,我们将数组在四个块之间持久化,以准备计算二维位置。
现在我们来准备位置数组。也就是说,我们将创建一个数组,它将二维位置编码为复数:
def block_prepare_pos_array(size, pos_array):
nrows, ncols = pos_array.shape
ret = np.empty(shape=(nrows,ncols), dtype=np.complex128)
startx, starty = start
endx, endy = end
for row in range(nrows):
x = (endx - startx) * ((pos_array[row, 0] // size ) / size) + startx
for col in range(ncols):
y = (endy - starty) * ((pos_array[row, col] % size) / size) + starty
ret[row, col] = complex(x, y)
return ret
此函数将范围数组转换为基于原始单元格中值的坐标数组。不必过于担心算法:它将一维坐标转换为二维坐标。基本点是另一个;让我们看看这段代码来发现它:
pos_array = da.blockwise(
lambda x: block_prepare_pos_array(size, x),
'ij', range_array, 'ij', dtype=np.complex128)
pos_array.visualize("10-blockwise.png", rankdir="TB")
这段代码告诉 Dask 将block_prepare_pos_array的初始化代码应用于四个块中的每一个。我们指定range_array作为输入参数。请注意,有两个参数:ij(即i和j),它告诉 Dask 输入参数的形状与输出参数之间的关系(即它们具有相同的形状)。
此代码仅创建四个任务,如图 10.12 所示。如果我们使用原始代码,我们将有 100 万个任务。

图 10.12 使用分块函数运行初始化代码的任务图
现在,是时候在我们的矩阵上调用我们的 Mandelbrot 代码了:
from PIL import Image
u_compute_point = da.frompyfunc(compute_point, 1, 1)
image_arr = u_compute_point(pos_array)
image = Image.fromarray(image_np, mode="P")
image.save("mandelbrot.png")
我们使用frompyfunc,它将原生 Python 函数转换为 NumPy ufunc。然后我们调用它来处理pos_array矩阵。
接下来,我们将对我们的代码进行一些非常基本的性能分析。主要目的是观察使用较大图像对性能的影响。执行一些简单性能分析的代码如下:
from time import time
def time_scenario(size, persist_range, persist_pos, chunk_div=10):
start_time = time()
size = size
range_array = da.arange(0, size*size).reshape(size, size).persist()
range_array = range_array.rechunk(size // chunk_div, size // chunk_div)
range_array = range_array.persist() if persist_range else range_array
pos_array = da.blockwise(
lambda x: block_prepare_pos_array(size, x),
'ij', range_array, 'ij', dtype=np.complex128)
pos_array = pos_array.persist() if persist_pos else pos_array
image_arr = u_compute_point(pos_array)
image_arr.visualize("task_graph.png", rankdir="TB")
image_arr.compute()
return time() - start_time
我们的功能运行 Mandelbrot 代码,如前所述,这允许我们参数化图像的大小。我们还允许参数化分块以及持久化两个中间数组。该函数返回执行所需的时间,这可以作为我们代码的粗略性能分析。
让我们以大小为 500(即 500 × 500 的图像大小)和块除数为 2(即四个块)运行代码:
size = 500
time_scenario(size, False, False, 2)
图 10.13 展示的任务图显示我们有四个块,计算的数量——lambda 和 frompyfunc 节点——也是四个。

图 10.13 使用四个块计算 Mandelbrot 大小的任务图
现在,让我们以 5,000(即 5,000 × 5,000 的图像)的大小和 10(即 100 块)的块除数运行之前的代码。但在我们这样做之前,打开指向 http://127.0.0.1:8787(即 Dask 的控制台)的网页浏览器:
size = 5000
client = client.restart() ①
time_scenario(size, True, True, 10)
① 在运行此行之后,请重新加载控制台上的网页浏览器以获取一个干净的版本。
当你运行 time_scenario 时,你会看到计算过程中的实时动画。虽然我无法在这里展示视频,但图 10.14 展示了计算进行时的控制台。主控制台上有五个图表。记住,我们有四个工作者,我们将有一个与图 10.13 类似的拓扑结构但具有 100 列的任务图,而不仅仅是四个:
-
左上角的小图表报告了所有工作者存储的字节数。
-
左侧的第二张图表反映了每个工作者的内存使用情况,因此它是左上角图表的更详细版本。
-
左下角的图表列出了每个工作者上正在处理的任务数量。
-
主要图表(右上角)的 X 轴是时间,Y 轴是工作者。每个块代表任务图中的一个任务。不同的颜色(在此以灰度显示)分配给不同的任务类型。
-
最后,右下角的图表显示了所有任务的状态

图 10.14 Dask 控制台的主界面
我鼓励你探索 Dask 控制台的所有页面。例如,配置文件页面将为你提供代码配置文件的 SnakeViz 类型可视化,而图形页面将显示任务图中所有任务的实时状态。
最后,让我们探索 Dask 如何处理比内存更大的数据集。
10.3.3 处理比内存更大的数据集
记住,Dask 允许你处理比内存更大的数据集。当你使用多台计算机和更多的内存时,Dask 会通过这些计算机分发数据结构。
但处理比内存更大的数据集的最后一招解决方案是将它们溢出到磁盘:即暂时存储在磁盘上。然而,正如你所预期的,性能会付出代价。
我们将运行我们的 Mandelbrot 代码,以一个非常大的尺寸 10000 × 10000。在一种情况下,我们将 .persist 中间数组,但在另一种情况下则不会:
size = 10000
print(size, False, False, time_scenario(size, False, False))
print(size, True, True, time_scenario(size, True, True))
这段代码在我的电脑上的输出是:
10000 False False 696
10000 True True 752
第二个版本较慢,因为持久化中间矩阵和持续计算所需的内存大于工作者的可用内存。这个问题可以在仪表板上轻松看到。例如,图 10.15 显示了左上角的两个图表。第一个图表的标题清楚地表明发生了溢出。此外,两个图表上的不同颜色(此处以灰度表示)表明数据存储在不同的位置。

图 10.15 Dask 仪表板的左上部分显示了溢出。
溢出可能导致延迟,其速度比完全内存版本慢几个数量级。有几种替代方案可以考虑,最明显的是增加更多内存或更多机器。无论如何,如果你有溢出,要么尽量避免它,要么确保性能不会受到重大影响。
本章介绍了 Dask 基本概念。有了这些信息,你可以理解在使用 Dask 时涉及的性能问题的基本模块,并将它们应用于完全不同的底层架构,每个架构都有其特定的性能瓶颈。
摘要
-
Dask 允许你在多台机器之间分配计算。
-
Dask 还允许你处理大于内存的对象,如数据框和数组。
-
Dask 实现了广泛使用的 API 的子集,如 pandas 和 NumPy,但 Dask API 的语义不同,因为 Dask 主要是懒加载的,而不是急切加载的。
-
在 Dask 中,你可以在执行前检查任务图,这让你能够理解将要执行的计算,在某些情况下,可以考虑优化它的替代方案。
-
Dask 允许对计算执行进行精细控制。例如,你可以要求计算节点在本地持久化数据,以便在后续计算中有效地重用它。
-
如果重新分区数据有助于加快后续计算,你可以跨节点重新分区数据。然而,重新分区会带来性能成本,因为需要跨节点传输数据。
-
在较低的任务级别,Dask 依赖于 pandas 和 NumPy,在这个级别上,你可以直接使用这些库。
-
Dask 提供了类似于第三章中讨论的
concurrent.futures类型的接口。 -
数据分析的基本算法必须考虑 Dask 对分区数据的利用。分区使得几个算法的性能与 pandas 或 NumPy 中可用的顺序版本有显著不同。
-
Dask 提供了多个调度器;其中,分布式调度器允许计算在广泛的架构中部署——从单台机器到非常大的集群。
-
Dask 提供了一个调度器
dask.distributed,它可以在多种架构上分配任务,从单台机器到科学集群或云端。 -
dask.distributed提供了一个强大的仪表板,可以用来分析和分析分布式应用程序。
附录 A. 设置环境
本附录涵盖了
-
设置 Anaconda
-
设置你的 Python 发行版
-
设置 Docker
-
硬件考虑
本附录提供了一些关于如何设置你的环境的建议。我们将使用 Python 3.10 作为我们代码的基础。
你可以使用你喜欢的任何操作系统。如今,大多数生产代码通常部署在 Linux 上,但你也可以使用 Windows 或 Mac OS X。使用 Mac 和 Linux 之间几乎没有区别。使用 Windows 会更困难。如果你选择这样做,我建议你安装一些 Unix 工具,比如 Bash shell,或者你可以使用 Windows 子系统中的 Linux;Cygwin 也是一个选择。
一个适用于所有操作系统的替代方案是使用提供的——并且完全可选的——Docker 镜像,其中包含所需的软件。如果你选择这条路径,请确保为你的操作系统安装 Docker。我提供了一个默认的 Docker 镜像,尽管某些章节需要专门的软件,在这种情况下,我会提供专门的 Docker 镜像。如果需要,你可以在那些章节中找到具体的说明。
在这本书中使用的软件的完整列表——太长了,这里无法列出——可以在各种Dockerfile中的代码仓库中找到。即使你不使用 Docker,这个列表也可能很有用。本书的仓库可以在github.com/tiagoantao/python-performance找到。
A.1 设置 Anaconda Python
Anaconda Python 可能是数据科学和工程中最常见的发行版。我建议使用它来运行本书中的代码。安装 Anaconda 后,为本书创建一个环境:
conda create -n python-performance python=3.10 ipython=8.3
conda install pandas numpy requests snakeviz line_profiler blosc
一些章节将需要额外的软件。在这种情况下,我建议你克隆原始环境并为每个章节创建一个新的环境。这可以通过以下方式完成:
conda create --clone python-performance -n NEW_NAME
克隆后,你可以在新环境中安装任何所需的软件,而不会影响原始环境。
我还建议你为每个章节创建一个单独的环境,以避免不同包和库之间的冲突。即使在使用像 conda 这样的优秀包管理器的情况下,包管理仍然存在问题,因此拥有单独的环境会更容易。
更新 conda 环境
如果你是一个长期使用 Anaconda 的用户,那么创建一个新的环境可能更好,使用:
conda create -n python-performance python=3.10
conda activate python-performance
更新旧环境可能需要很长时间,甚至可能失败。如果你是 Anaconda 的新用户,你应该考虑为书中的材料或,最好是每个章节,创建一个单独的环境。
A.2 安装你自己的 Python 发行版
你可以选择你喜欢的任何 Python 发行版,但我强烈推荐 Anaconda Python,这是数据科学和高性能计算中的事实上的标准。如果你安装了 Anaconda(不是较小的版本 Miniconda),你将获得我们将会使用的绝大多数软件。我在整本书中假设你使用 Anaconda。如果你使用其他发行版,你可能需要调整章节中提供的安装说明(对于具有特定需求的章节)。我建议检查像 Poetry([https://python-poetry.org/](https://python-poetry.org/))这样的工具,它可以部分帮助你进行包管理。
你将需要一些非常标准的库,如 NumPy 和 SciPy。为了生成图表,我们将使用 matplotlib。不同的章节需要特定的库,如 Cython、Numba、Apache Arrow 或 Apache Parquet。如果你既不使用 conda 也不使用 Poetry,你可能会依赖于pip来安装软件。
A.3 使用 Docker
如果你想要避免包安装,或者你使用 Windows 并希望有一个更典型的环境,我提供了包含运行代码所需一切内容的 Docker 镜像。这些 Docker 镜像将为你提供一个 Linux 环境,无论你的宿主操作系统是什么。
基础镜像可以按照以下方式运行:
docker run -v PATH_TO_THE_REPOSITORY:/code -ti tiagoantao/python-performance
第一次运行此代码时,它将下载镜像,这可能需要一些时间和带宽。你将有一个 shell,你将能够在/code目录中找到代码。对于具有特定软件要求的章节,我提供了定制的 Docker 镜像。
A.4 硬件考虑
我们在这本书中使用的软件相当标准,但我们处于一个设置可能很棘手,实际上可能会破坏一些优化水平的阶段,例如:
-
你编译和链接库的方式可能会对性能产生巨大影响(参见第四章关于 NumPy 的内容)。如果你安装了推荐的发行版,你很可能拥有一个高性能版本。如果你没有(尤其是如果你自己编译它),请务必阅读第四章的最后部分。
-
如果你使用提供的 Docker 镜像,并且你处于虚拟环境中,那么很难知道你是否完全控制着机器上正在发生的事情,这将使得性能分析和特别是 CPU 缓存变得不太可靠。
-
如果你在一台同时进行许多其他任务的台式机上,可以提出类似的论点。
-
除非你有访问整个物理机器的权限,否则云实例也会遇到相同的问题,这是可能的(尽管成本高昂)但不如在共享物理计算机上常见。
-
不同的硬件配置可能会有完全不同的性能特征。例如,固态硬盘(SSD)通常比带有物理旋转盘的硬盘具有更好的读取性能。这种性能的变异性适用于 CPU、CPU 缓存、内部总线、内存、硬盘和网络——尤其是在分析和缓存问题(CPU、硬盘或网络)方面。
-
从不同的角度来处理前面提到的问题,一个配置良好的裸机生产机器可能是观察这里展示的一些技术发挥最大作用的最佳场所。
-
第九章关于 GPU 的内容需要访问最新的 NVIDIA GPU——至少需要配备 Pascal 架构。
我们将在整本书中深入讨论所有这些问题。但更大的观点是,你将看到的实际例子,虽然本身很有趣,但应被视为发展关于性能问题基本洞察的途径,你可能需要根据具体情况对其进行调整。在更高级的水平上,深入理解才是真正的目标。实际例子是手段,但肯定不是终点。
附录 B. 使用 Numba 生成高效的底层代码
Numba 是一个将 Python 代码自动转换为原生代码(CPU 或 GPU)的框架。在 CPU 方面,它是 Cython 的替代品。我们为什么有一个关于 Cython 而不是 Numba 的整个章节,是因为在这本书中,我们感兴趣的是了解事物是如何工作的,而不仅仅是让它们工作。Numba 虽然在其他方面很出色,但从教学角度来看并不出色——因为它具有“魔法性”。
为了解决实际问题,Numba 与 Cython 相当,如果不是更好,因为它需要您更少的工作,并且产生类似的结果。从可用性的角度来看,我建议您将 Numba 作为 Cython 的替代品考虑。实际上,首先考虑 Numba 可能更实用。
当您尝试运行函数时,Numba 会将 Python 函数动态转换为优化的机器代码。换句话说,它是一个即时(JIT)编译器。
在本附录中,我们将为 CPU 开发一个示例。您可以将此内容作为 GPU 章节所需 Numba 介绍的入门,或者您可以将其作为独立的内容来学习 Numba 用于 CPU。
要运行此代码,您需要安装 Numba。如果您使用 conda,可以执行 conda install numba。如果您使用 Docker,镜像为 tiagoantao/python-performance-numba。

图 B.1 曼德布罗特集的灰度渲染
我们的例子将是计算曼德布罗特集,我们将运行一个原生 Python 版本和一个 Numba 版本来比较速度。你可能已经见过这个标志性的图像——一个变体在图 B.1 中显示。曼德布罗特集是在复数空间中计算的——我们将使用复数,并研究方程 z = z² + c 的迭代行为,其中 c 是空间中的一个点,而 z 从 (0, 0) 开始。这个计算比看起来要简单。让我们看看代码来了解细节:
def compute_point(c, max_iter=200): ①
num_iter = -1
z = complex(0, 0) ②
while abs(z) < 2:
num_iter += 1
if i == max_iter:
break
z = z**2 + c ③
return 255 - (255 * num_iter) // max_iter
① 我们需要指定最大迭代次数,因为这可能是无限的。
② Python 本身支持复数。
③ 曼德布罗特方程 z = z² + c
我们的输入是空间中的一个点,c。我们感兴趣的是方程 z = z² + c(z 从 (0, 0) 开始)的绝对值超过 2 的迭代次数。迭代次数决定了位置 c 的像素颜色。我们设定一个最大迭代次数,因为接近 0 的点可能永远不会超过 2,因此迭代次数将是无限的。
因此,当 z 到原点的距离大于 2 时,我们停止迭代。远离原点的点在第一次迭代时停止,而接近原点的点将无限迭代。为了避免无限次的计算,我们使用 max_iter 定义最大迭代次数。在边界附近,迭代次数以混沌的方式变化(图 B.1 中的灰度色调)。图像显示了在复平面 1.5 – 1.3i 和 0.5 + 1.3i 之间的 255 次最大迭代次数下,迭代次数到灰度色调的转换。
计算曼德布罗集的主要函数因此相当简单。这里的版本实际上比标准版本稍微复杂一些,因为我们把输出重缩放到 0 到 255 之间,不管迭代次数多少。重缩放将使绘制灰度 8 位图像变得简单。(我使用灰度图像是因为在打印的书中使用颜色的限制。)
B.1 使用 Numba 生成优化代码
我们现在将使用 @jit 装饰器来创建函数的 Numba 版本:
from numba import jit
compute_point_numba = jit()(compute_point)
不要被装饰器吓到:装饰器不过是一种语法糖。因为在这种情况下,我们想要比较原生和 Numba 版本的性能,使用装饰器更方便,因为 @ 语法只会暴露 Numba 版本。
由于 Numba 是一个 JIT(即时编译器),函数的第一次调用将把它编译成 LLVM 表示形式;这是一个一次性操作。我们将进行一个虚拟调用,这样后续的性能分析就不会因为这个一次性步骤而受到偏差:
compute_point_numba(complex(4,4))
你必须小心那些可能有副作用的功能:确保虚拟调用没有不希望的结果。在大多数生产场景中,如果你没有进行基准测试,你可以简单地忽略这一步。
现在我们有两个版本:原生版本(compute_point)和 Numba 优化版本(compute_point_numba)。我们需要为每个要绘制的点调用这些函数。我们将有一个起始角和结束角以及我们的分辨率:X 和 Y 坐标上我们将有相同的分辨率:
def do_all(size, start, end, img_array, compute_fun):
startx, starty = start
endx, endy = end
for xp in range(size):
x = (endx - startx)*(xp/size) + startx # precision issues
for yp in range(size):
y = (endy - starty)*(yp/size) + starty # precision issues
img_array[yp, xp] = compute_fun(complex(x,y))
这个简单的函数遍历所有点。如果你认为它可以接受向量化的方法,你是正确的;我们稍后会讨论这一点。"size" 是每个维度的像素数,"start" 和 "end" 是复空间中的位置,"img_array" 是输出数组,"compute_fun" 是我们将用于计算每个位置值的函数。
关于如何计算 x 和 y 坐标,这里有一个小的细节。理论上,我们可以像这样向当前位置添加一个增量:
x = startx
deltax = (endx - startx) / size
for xp in range(size):
....
x += deltax
这种方法的缺点,虽然会稍微快一些,是精度误差会从一次迭代累积到下一次迭代,到我们可能得到错误的结果。因此,我们将坚持使用更昂贵的 x = (endx - startx)*(xp/size) + startx。
生成图像的参数是:
size = 2000
start = -1.5, -1.3
end = 0.5, 1.3
img_array = np.empty((size, size), dtype=np.uint8)
我们还需要初始化作为输出的数组。
现在我们比较一下原生和 Numba 版本的运行时间。使用 IPython,我们可以这样做:
In [2]: %timeit do_all(size, start, end, img_array, compute_point_numba)
4.71 s ± 105 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [3]: %timeit do_all(size, start, end, img_array, compute_point)
50.4 s ± 2.94 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
在我的电脑上,我得到了性能的十倍提升。
从这个例子中我们可以看出,Numba 在自动优化 Python 代码方面做得相当好,但它也可能遇到我们在 Cython 中看到的问题:如果它无法摆脱 CPython 对象机制,性能会受到影响。让我们创建一个稍微有些人为的例子来展示这个问题。在我们之前的例子中,Numba 在自动转换方面做得令人钦佩,但我们可以强制 Numba 生成与 CPython 绑定的代码,看看它如何影响性能:
compute_point_numba_forceobj = jit(forceobj=True)(compute_point)
运行时间是:
In [2]: %timeit do_all(size, start, end,
img_array, compute_point_numba_forceobj)
1min 46s ± 2.46 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
现在,我们有 1 分 46 秒。请注意,这个例子是一个最坏的情况。有时 Numba 可以优化代码的一部分,即使它无法优化所有部分。
要强制 Python 编译,请在装饰器中添加nopython=True。如果 Numba 无法编译函数,请务必查看 Numba 的文档numba.readthedocs.io/en/stable/user/5minguide.html,以查看哪些 Python 功能受支持:我们不会在这里详细介绍,因为它们会随时间变化。
B.2 在 Numba 中显式编写并行函数
Numba 还允许你编写并行线程代码,因为有时你可以释放 GIL:
from numba import prange
@jit(nopython=True,parallel=True,nogil=True) ①
def pdo_all(size, start, end, img_array, compute_fun):
startx, starty = start
endx, endy = end
for xp in prange(size): ②
x = (endx - startx)*(xp/size) + startx
for yp in range(size):
y = (endy - starty)*(yp/size) + starty
b = compute_fun(complex(x, y))
img_array[yp, xp] = b
① 我们指定parallel=True,并通过nogil=True释放 GIL。
② 我们使用prange函数。
当我们使用prange时,我们是在要求 Numba 并行化该循环。因为代码是无 GIL 的(即不需要与 CPython 交互),所以可以并行化。因此,结果是:
In [3]: %timeit pdo_all(size, start, end, img_array, compute_point_numba)
1.41 s ± 35.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
在我的机器上,性能是串行版本的 3 倍多。这个结果显然更好,但并不与我的可用八核线性相关。在某些情况下(即当 Numba 可以完全绕过 Python 解释器时),Numba 函数可以生成真正的并行代码。
B.3 在 Numba 中编写 NumPy 感知代码
现在我们已经使用 Numba 将纯 Python 代码转换过来,让我们考虑一个使用 NumPy 通用函数的版本,因为将代码集成到 NumPy 对于数据科学应用是基本的。Numba 函数可以被转换为 NumPy 通用函数,这是数据科学中常见的用例。这个过程相当简单:
from numba import vectorize
compute_point_ufunc = vectorize(
["uint8(complex128,uint64)"],
target="parallel")(compute_point)
我们使用vectorize函数来包装compute_point。我们指定该函数可以运行且是并行的。我们还必须提供一个函数签名类型的列表。例如max_iter这样的可选参数在实践中变得强制。
我们使用此代码的方式不同:我们需要传递一个包含我们想要结果的位矩阵:
size = 2000
start = -1.5, -1.3
end = 0.5, 1.3
def prepare_pos_array(start, end, pos_array):
size = pos_array.shape[0]
startx, starty = start
endx, endy = end
for xp in range(size):
x = (endx - startx)*(xp/size) + startx
for yp in range(size):
y = (endy - starty)*(yp/size) + starty
pos_array[yp, xp] = complex(x, y)
pos_array = np.empty((size, size), dtype=np.complex128)
prepare_pos_array(start, end, pos_array)
prepare_pos_array只是准备输入数组,其中包含所有要计算的坐标位置。这种方法的不利之处在于我们需要内存来存储位置和结果数组。
让我们计时运行:
%timeit img_array = compute_point_ufunc(pos_array, 200)
在我的机器上的输出是:
In [2]: %timeit img_array = compute_point_ufunc(pos_array, 200)
539 ms ± 7.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
这个结果比非 NumPy 并行版本快近三倍,而且无需做很多工作。
本附录应能帮助您开始使用 Numba。如果您对使用 Numba 生成 GPU 代码感兴趣,请查看第九章。


浙公网安备 33010602011771号