Scala-应用机器学习-全-

Scala 应用机器学习(全)

原文:annas-archive.org/md5/91947947c2d1b3831a619c5773df4a8c

译者:飞龙

协议:CC BY-NC-SA 4.0

前言

许多人认为 Scala 是大数据领域的 Java 的继任者。它特别擅长在不显著影响性能的情况下分析大量数据,因此 Scala 被许多开发者和数据科学家采用。

本学习路径旨在将使用 Scala 的整个机器学习世界展现在您面前。我们将从向您介绍 Scala 中用于摄取、存储、操作、处理和可视化数据的库开始。接着,我们将介绍 Scala 中的机器学习,并深入探讨如何利用 Scala 构建和研习可以从数据中学习系统的技巧。最后,我们将全面掌握 Scala 机器学习,并传授您构建复杂机器学习项目的专业知识。

本学习路径涵盖的内容

模块 1,Scala 数据科学入门,为您提供了 Raspberry Pi 的介绍。它帮助您使用 PyGame 构建游戏,并使用 Raspberry Pi 创建实际应用。它进一步展示了 OpenCV 的高级概念中的 GPIO 和摄像头。本模块还深入探讨了设置 Web 服务器和创建网络工具。

模块 2,Scala 机器学习入门,通过图表、正式数学符号、源代码片段和实用技巧引导您构建 AI 应用。对 Akka 框架和 Apache Spark 集群的回顾结束了教程。

模块 3,Scala 机器学习精通,是本课程的最后一步。它将把您的知识提升到新的水平,并帮助您利用这些知识构建高级应用,如社交媒体挖掘、智能新闻门户等。在用 REPL 快速复习函数式编程概念后,您将看到一些设置开发环境和处理数据的实际示例。然后,我们将探讨使用 k-means 和决策树与 Spark 和 MLlib 一起工作。

您需要为本学习路径准备的内容

您需要以下设置来完成所有三个模块:

模块 1

本课程提供的示例要求您拥有一个可工作的 Scala 安装和 SBT,即简单构建工具,这是一个用于编译和运行 Scala 代码的命令行实用程序。我们将在下一节中向您介绍如何安装这些工具。

我们不要求使用特定的 IDE。代码示例可以编写在您喜欢的文本编辑器或 IDE 中。

安装 JDK

Scala 代码编译成 Java 字节码。要运行字节码,您必须安装 Java 虚拟机(JVM),它包含在 Java 开发工具包(JDK)中。有几种 JDK 实现,在本课程中,您选择哪一个并不重要。您可能已经在计算机上安装了 JDK。要检查这一点,请在终端中输入以下内容:

$ java -version
java version "1.8.0_66"
Java(TM) SE Runtime Environment (build 1.8.0_66-b17)
Java HotSpot(TM) 64-Bit Server VM (build 25.66-b17, mixed mode)

如果您没有安装 JDK,您将收到一个错误,表明 java 命令不存在。

如果您已经安装了 JDK,您仍然应该验证您正在运行一个足够新的版本。重要的是次要版本号:1.8.0_66 中的 8。Java 的 1.8.xx 版本通常被称为 Java 8。对于本课程的前十二章,Java 7 就足够了(您的版本号应该是 1.7.xx 或更新的版本)。然而,您将需要 Java 8 来完成最后两章,因为 Play 框架需要它。因此,我们建议您安装 Java 8。

在 Mac 上,安装 JDK 最简单的方法是使用 Homebrew:

$ brew install java

这将安装来自 Oracle 的 Java 8,特别是 Java 标准版开发工具包。

Homebrew 是 Mac OS X 的包管理器。如果您不熟悉 Homebrew,我强烈建议您使用它来安装开发工具。您可以在 brew.sh 上找到 Homebrew 的安装说明。

要在 Windows 上安装 JDK,请访问 www.oracle.com/technetwork/java/javase/downloads/index.html(或者,如果此 URL 不存在,请访问 Oracle 网站,然后点击“下载”并下载Java 平台,标准版)。选择 Windows x86 用于 32 位 Windows,或 Windows x64 用于 64 位。这将下载一个安装程序,您可以通过运行它来安装 JDK。

要在 Ubuntu 上安装 JDK,使用您发行版的包管理器安装 OpenJDK:

$ sudo apt-get install openjdk-8-jdk

如果您正在运行一个足够旧的 Ubuntu 版本(14.04 或更早),此包将不可用。在这种情况下,您可以选择回退到 openjdk-7-jdk,这将允许您运行前十二章的示例,或者通过 PPA(非标准包存档)安装来自 Oracle 的 Java 标准版开发工具包:

$ sudo add-apt-repository ppa:webupd8team/java
$ sudo apt-get update
$ sudo apt-get install oracle-java8-installer

您需要告诉 Ubuntu 优先使用 Java 8,方法如下:

$ sudo update-java-alternatives -s java-8-oracle

安装和使用 SBT

简单构建工具(SBT)是一个用于管理依赖项、构建和运行 Scala 代码的命令行工具。它是 Scala 的默认构建工具。要安装 SBT,请遵循 SBT 网站的说明 (www.scala-sbt.org/0.13/tutorial/Setup.html)。

当您启动一个新的 SBT 项目时,SBT 会为您下载一个特定的 Scala 版本。因此,您不需要直接在您的计算机上安装 Scala。从 SBT 管理整个依赖套件,包括 Scala 本身,是非常强大的:您不必担心在同一个项目上工作的开发者使用不同版本的 Scala 或库。

由于我们将在本课程中广泛使用 SBT,让我们创建一个简单的测试项目。如果您之前已经使用过 SBT,请跳过此部分。

创建一个名为sbt-example的新目录并导航到它。在这个目录内,创建一个名为build.sbt的文件。该文件编码了项目的所有依赖项。在build.sbt中写入以下内容:

// build.sbt

scalaVersion := "2.11.7"

这指定了我们想要为项目使用的 Scala 版本。在sbt-example目录中打开一个终端并输入:

$ sbt

这将启动一个交互式 shell。让我们打开一个 Scala 控制台:

> console

这将使你能够访问项目上下文中的 Scala 控制台:

scala> println("Scala is running!")
Scala is running!

除了在控制台运行代码,我们还将编写 Scala 程序。在sbt-example目录中打开一个编辑器,并输入一个基本的“hello, world”程序。将文件命名为HelloWorld.scala

// HelloWorld.scala

object HelloWorld extends App {
  println("Hello, world!")
}

返回 SBT 并输入:

> run

这将编译源文件并运行可执行文件,打印出"Hello, world!"

除了编译和运行你的 Scala 代码,SBT 还管理 Scala 依赖项。让我们指定对 Breeze 库的依赖,这是一个用于数值算法的库。按照以下方式修改build.sbt文件:

// build.sbt

scalaVersion := "2.11.7"

libraryDependencies ++= Seq(
  "org.scalanlp" %% "breeze" % "0.11.2",
  "org.scalanlp" %% "breeze-natives" % "0.11.2"
)

SBT 要求语句之间用空行分隔,所以请确保在scalaVersionlibraryDependencies之间留一个空行。在这个例子中,我们指定了对 Breeze 版本"0.11.2"的依赖。我们是如何知道使用这些坐标来指定 Breeze 的?大多数 Scala 包在其文档中都会引用确切的 SBT 字符串以获取最新版本。

如果不是这种情况,或者你正在指定对 Java 库的依赖,请访问 Maven Central 网站(mvnrepository.com)并搜索感兴趣的包,例如“Breeze”。该网站提供了一系列包,包括几个名为breeze_2.xx的包。下划线后面的数字表示该包为哪个 Scala 版本编译。点击"breeze_2.11"以获取不同 Breeze 版本的列表。选择"0.11.2"。你将看到一个包含包管理器的列表以供选择(Maven、Ivy、Leiningen 等)。选择 SBT。这将打印出类似以下的一行:

libraryDependencies += "org.scalanlp" % "breeze_2.11" % "0.11.2"

这些是你想要复制到build.sbt文件中的坐标。请注意,我们只是指定了"breeze",而不是"breeze_2.11"。通过在包名前加上两个百分号%%,SBT 会自动解析到正确的 Scala 版本。因此,指定%% "breeze"% "breeze_2.11"相同。

现在返回你的 SBT 控制台并运行:

> reload

这将从 Maven Central 获取 Breeze JAR 文件。你现在可以在控制台或脚本中(在 Scala 项目的上下文中)导入 Breeze。让我们在控制台中测试一下:

> console
scala> import breeze.linalg._
import breeze.linalg._

scala> import breeze.numerics._
import breeze.numerics._

scala> val vec = linspace(-2.0, 2.0, 100)
vec: breeze.linalg.DenseVector[Double] = DenseVector(-2.0, -1.9595959595959596, ...

scala> sigmoid(vec)
breeze.linalg.DenseVector[Double] = DenseVector(0.11920292202211755, 0.12351078065 ...

现在,你应该能够编译、运行并指定 Scala 脚本依赖项。

模块 2

掌握 Scala 编程语言是先决条件。阅读数学公式,方便地定义在信息框中,是可选的。然而,对数学和统计学的一些基本知识可能有助于理解某些算法的内部工作原理。

本课程使用以下库:

  • Scala 2.10.3 或更高版本

  • Java JDK 1.7.0_45 或 1.8.0_25

  • SBT 0.13 或更高版本

  • JFreeChart 1.0.1

  • Apache Commons Math 库 3.5 (第三章,数据预处理,第四章,无监督学习,和第六章,回归和正则化)

  • 印度理工学院孟买 CRF 0.2 (第七章,顺序数据模型)

  • LIBSVM 0.1.6 (第八章,核模型和支持向量机)

  • Akka 2.2.4 或更高版本(或 Typesafe activator 1.2.10 或更高版本)(第十二章,可扩展框架)

  • Apache Spark 1.3.0 或更高版本(第十二章,可扩展框架)

第 3 模块

本课程基于开源软件。首先,是 Java。您可以从 Oracle 的 Java 下载页面下载 Java。您必须接受许可协议并选择适合您平台的适当镜像。不要使用 OpenJDK——它与 Hadoop/Spark 存在一些问题。

其次,Scala。如果您使用 Mac,我建议安装 Homebrew:

$ ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)"

您还将获得多个开源软件包。要安装 Scala,请运行brew install scala。在 Linux 平台上安装需要从www.scala-lang.org/download/网站下载适当的 Debian 或 RPM 软件包。我们将使用当时最新的版本,即 2.11.7。

Spark 发行版可以从spark.apache.org/downloads.html下载。我们使用为 Hadoop 2.6 及更高版本预先构建的镜像。由于它是 Java,您只需解压软件包并从bin子目录中的脚本开始使用即可。

R 和 Python 软件包分别可在cran.r-project.org/binpython.org/ftp/python/$PYTHON_VERSION/Python-$PYTHON_VERSION.tar.xz网站上找到。文本中具体说明了如何配置它们。尽管我们使用的软件包应该是版本无关的,但我在这本书中使用了 R 版本 3.2.3 和 Python 版本 2.7.11。

本学习路径适合谁

本学习路径是为熟悉 Scala 并希望学习如何创建、验证和应用机器学习算法的工程师和科学家设计的。它也将对有 Scala 编程背景的软件开发人员有益,他们希望应用机器学习。

读者反馈

我们始终欢迎读者的反馈。请告诉我们您对这个课程的想法——您喜欢什么或不喜欢什么。读者反馈对我们很重要,因为它帮助我们开发出您真正能从中受益的课程。

要发送给我们一般反馈,请简单地发送电子邮件至<feedback@packtpub.com>,并在邮件主题中提及课程标题。

如果您在某个主题上具有专业知识,并且您有兴趣撰写或为课程做出贡献,请参阅我们的作者指南,网址为www.packtpub.com/authors

客户支持

现在您是课程书的自豪拥有者,我们有一些事情可以帮助您从您的购买中获得最大收益。

下载示例代码

您可以从您的账户中下载此课程的示例代码文件,网址为www.packtpub.com。如果您在其他地方购买了此课程,您可以访问www.packtpub.com/support并注册,以便将文件直接通过电子邮件发送给您。

您可以通过以下步骤下载代码文件:

  1. 使用您的电子邮件地址和密码登录或注册我们的网站。

  2. 将鼠标指针悬停在顶部的支持标签上。

  3. 点击代码下载与勘误

  4. 搜索框中输入课程的名称。

  5. 选择您想要下载代码文件的课程。

  6. 从下拉菜单中选择您购买此课程的来源。

  7. 点击代码下载

您还可以通过点击 Packt Publishing 网站上课程网页上的代码文件按钮来下载代码文件。您可以通过在搜索框中输入课程名称来访问此页面。请注意,您需要登录到您的 Packt 账户。

文件下载完成后,请确保您使用最新版本解压缩或提取文件夹:

  • WinRAR / 7-Zip for Windows

  • Zipeg / iZip / UnRarX for Mac

  • 7-Zip / PeaZip for Linux

该课程的代码包也托管在 GitHub 上,网址为github.com/PacktPublishing/Scala-Applied-Machine-Learning-Code

勘误

尽管我们已经尽一切努力确保内容的准确性,但错误仍然可能发生。如果您在我们的课程中发现错误——可能是文本或代码中的错误——如果您能向我们报告这一点,我们将不胜感激。通过这样做,您可以节省其他读者的挫败感,并帮助我们改进后续版本的课程。如果您发现任何勘误,请通过访问www.packtpub.com/submit-errata,选择您的课程,点击勘误提交表单链接,并输入您的勘误详情来报告。一旦您的勘误得到验证,您的提交将被接受,勘误将被上传到我们的网站或添加到该标题的勘误部分下的现有勘误列表中。

要查看之前提交的勘误表,请访问 www.packtpub.com/books/content/support,并在搜索字段中输入课程名称。所需信息将在勘误部分显示。

盗版

互联网上对版权材料的盗版是一个跨所有媒体的持续问题。在 Packt,我们非常重视我们版权和许可证的保护。如果您在互联网上发现我们作品的任何非法副本,请立即提供位置地址或网站名称,以便我们可以寻求补救措施。

请通过 <copyright@packtpub.com> 联系我们,并提供涉嫌盗版材料的链接。

我们感谢您在保护我们作者和我们为您提供有价值内容的能力方面的帮助。

询问

如果您在这门课程的任何方面遇到问题,您可以通过 <questions@packtpub.com> 联系我们,我们将尽力解决问题。

第一部分。模块 1

Scala 数据科学

利用 Scala 的不同工具构建可扩展、健壮的数据科学应用程序

第一章。Scala 和数据科学

20 世纪后半叶是硅的时代。在 50 年里,计算能力从极其稀缺到完全平凡。21 世纪前半叶是互联网的时代。在过去的 20 年里,谷歌、推特和 Facebook 等巨头崛起——这些巨头永远改变了我们看待知识的方式。

互联网是一个庞大的信息枢纽。人类产生的 90%的数据在过去 18 个月内已经生成。能够利用这些数据洪流以获得真正理解的程序员、统计学家和科学家将对商业、政府和慈善机构如何做出决策产生越来越大的影响。

本书旨在介绍一些您将需要的工具,以从数据洪流中综合提炼出真正的洞察力。

数据科学

数据科学是从数据中提取有用信息的过程。作为一个学科,它仍然有些模糊不清,定义的数量几乎与专家的数量一样多。而不是再添加另一个定义,我将遵循德鲁·康威的描述(drewconway.com/zia/2013/3/26/the-data-science-venn-diagram)。他描述数据科学为三个正交技能集的融合:

  • 数据科学家必须具备黑客技能。数据存储和传输通过计算机进行。计算机、编程语言和库是数据科学家的锤子和凿子;他们必须自信且准确地使用它们来塑造数据。这正是 Scala 发挥作用的地方:它是编程工具箱中一个强大的工具。

  • 数据科学家必须对统计学和数值算法有扎实的理解。优秀的数据科学家将理解机器学习算法的工作原理以及如何解读结果。他们不会被误导的指标、欺骗性的统计数据或误解的因果关系所迷惑。

  • 一个优秀的数据科学家必须对问题领域有扎实的理解。数据科学过程涉及以科学严谨的方式构建和发现关于问题领域知识。因此,数据科学家必须提出正确的问题,了解以往的结果,并理解数据科学努力如何融入更广泛的商业或研究背景。

德鲁·康威用一张维恩图优雅地总结了这一点,展示了数据科学位于黑客技能、数学和统计学知识以及实质性专业知识交汇处:

数据科学

当然,人们精通这些领域中的多个领域是罕见的。数据科学家通常在跨职能团队中工作,不同成员为不同领域提供专业知识。然而,为了有效运作,团队中的每个成员都必须对所有三个领域有一个一般的工作知识。

为了更具体地概述数据科学项目中的工作流程,让我们设想我们正在尝试编写一个分析公众对政治运动看法的应用程序。数据科学管道可能看起来是这样的:

  • 获取数据:这可能涉及从文本文件中提取信息、轮询传感器网络或查询 Web API。例如,我们可以查询 Twitter API 以获取包含相关标签的推文列表。

  • 数据摄取:数据通常来自许多不同的来源,可能是非结构化或半结构化的。数据摄取涉及将数据从数据源移动到处理它以提取结构化信息,并将这些信息存储在数据库中。例如,对于推文,我们可能会提取用户名、推文中提到的其他用户的名称、标签、推文文本以及推文是否包含某些关键词。

  • 探索数据:我们通常对想要从数据中提取的信息有一个明确的想法,但对如何做到这一点却知之甚少。例如,让我们设想我们已经摄取了包含与我们政治运动相关的标签的数千条推文。从我们的推文数据库到最终目标——了解公众对我们运动的总体看法——没有明确的路径。数据探索涉及规划我们如何到达那里。这一步骤通常会揭示新的问题或数据来源,这需要回到管道的第一步。例如,对于我们的推文数据库,我们可能会决定需要有人手动标记一千条或更多的推文,以表达对政治运动的“积极”或“消极”情绪。然后我们可以使用这些推文作为训练集来构建模型。

  • 构建特征:机器学习算法的好坏取决于进入它的特征。数据科学家的大部分时间都花在转换和组合现有特征以创建与我们要解决的问题更紧密相关的新特征上。例如,我们可能会构建一个新特征,对应于推文中“积极”语气单词或单词对的数量。

  • 模型构建和训练:在构建了进入模型的特征之后,数据科学家现在可以在他们的数据集上训练机器学习算法。这通常涉及尝试不同的算法和优化模型的超参数。例如,我们可能会决定使用随机森林算法来决定一条推文是对活动“正面”还是“负面”的看法。构建模型涉及选择合适的树的数量以及如何计算不纯度度量。对统计学和问题领域的良好理解将有助于这些决策。

  • 模型外推和预测:数据科学家现在可以使用他们新的模型来尝试推断关于先前未见数据点的信息。他们可能会将一条新的推文通过他们的模型来确认它是否对政治活动持正面或负面的看法。

  • 从模型中提取智能和洞察力:数据科学家将数据分析过程的结果与业务领域的知识相结合,以指导业务决策。他们可能会发现某些信息与目标受众或目标受众的特定部分产生更好的共鸣,从而实现更精确的目标。向利益相关者提供信息的关键部分涉及数据可视化和展示:数据科学家创建图表、可视化和报告,以帮助使得出的见解清晰且引人入胜。

这远非一个线性的管道。通常,在一个阶段获得的见解将要求数据科学家回溯到管道的先前阶段。确实,从原始数据生成业务见解通常是一个迭代的过程:数据科学家可能会进行快速的第一遍扫描以验证问题的前提,然后通过添加新的数据源或新特征或尝试新的机器学习算法来逐步完善方法。

在这本书中,你将学习如何在 Scala 中处理管道的每个步骤,利用现有库构建健壮的应用程序。

数据科学中的编程

这本书不是一本关于数据科学的书。它是一本关于如何使用 Scala 编程语言进行数据科学的书。那么,在处理数据时编程的作用在哪里呢?

计算机在数据科学管道的每个步骤中都发挥作用,但并不一定是同样的方式。如果我们只是编写临时脚本以探索数据或试图构建一个可扩展的应用程序,该应用程序通过一个被充分理解的管道推送数据以持续提供业务智能,那么我们构建的程序的风格将会有很大的不同。

让我们假设我们为一家制作手机游戏的公司工作,在这些游戏中你可以购买游戏内福利。大多数用户从不购买任何东西,但一小部分人可能会花很多钱。我们想要构建一个模型,根据他们的游戏模式识别大额消费者。

第一步是探索数据,找到正确的特征,并在数据的一个子集上构建模型。在这个探索阶段,我们有一个明确的目标,但几乎没有关于如何实现目标的想法。我们希望有一个轻量级、灵活的语言,以及强大的库,以便尽快得到一个工作模型。

一旦我们有了工作模型,我们需要将其部署到我们的游戏平台上,以分析所有当前用户的用法模式。这是一个非常不同的问题:我们对程序的目标和如何实现目标有相对清晰的理解。挑战在于设计能够扩展以处理所有用户并适应未来用法模式变化的软件。

实际上,我们编写的软件类型通常位于从单个一次性脚本到必须能够抵御未来扩展和负载增加的生产级代码的连续谱上。在编写任何代码之前,数据科学家必须了解他们的软件在这个谱上的位置。让我们称这个为持久性谱。

为什么选择 Scala?

你想编写一个处理数据的程序。你应该选择哪种语言?

有几种不同的选择。你可能会选择一种动态语言,如 Python 或 R,或者一种更传统的面向对象语言,如 Java。在本节中,我们将探讨 Scala 与这些语言的差异以及何时使用它可能是有意义的。

在选择语言时,架构师需要在可证明的正确性和开发速度之间进行权衡。你需要强调哪些方面将取决于应用程序的要求以及你的程序在持久性谱上的位置。这是一个短脚本,将被少数人使用,他们可以轻松修复任何出现的问题?如果是这样,你可能在很少使用的代码路径中允许一定数量的错误:当开发者遇到问题时,他们可以立即修复问题。相比之下,如果你正在开发一个计划发布给更广泛世界的数据库引擎,你很可能会更重视正确性而不是快速开发。例如,SQLite 数据库引擎以其广泛的测试套件而闻名,测试代码量是应用代码的 800 倍(www.sqlite.org/testing.html)。

在估计程序正确性时,重要的是不是感知到没有错误,而是你能够证明某些错误确实不存在的程度。

在代码运行之前,有几种方法可以证明不存在错误:

  • 在静态类型语言中,静态类型检查发生在编译时,但这也可以用于支持类型注解或类型提示的强类型动态语言。类型检查有助于验证我们是否按预期使用函数和类。

  • 静态分析器和代码检查器可以检查未定义的变量或可疑的行为(例如,代码中永远无法到达的部分)。

  • 在编译语言中声明某些属性为不可变或常量。

  • 单元测试以证明特定代码路径上没有 bug。

有几种方法可以检查运行时某些错误的缺失:

  • 在静态类型和动态语言中均支持动态类型检查

  • 断言验证假设的程序不变性或预期契约

在接下来的章节中,我们将探讨 Scala 在数据科学领域与其他语言的比较。

静态类型和类型推断

Scala 的静态类型系统非常灵活。程序行为的大量信息可以编码在类型中,使得编译器能够保证一定程度的正确性。这对于很少使用的代码路径特别有用。动态语言无法在特定执行分支运行之前捕获错误,因此错误可能长时间存在,直到程序遇到它。在静态类型语言中,任何编译器可以捕获的 bug 都会在程序开始运行之前在编译时被捕获。

静态类型面向对象语言常因冗余而被批评。以 Java 中Example类实例的初始化为例:

Example myInstance = new Example() ;

我们必须重复两次类名——一次是为了定义myInstance变量的编译时类型,另一次是为了构造实例本身。这感觉像是多余的工作:编译器知道myInstance的类型是Example(或Example的父类),因为我们绑定了一个Example类型的值。

Scala,像大多数函数式语言一样,使用类型推断来允许编译器从绑定到它们的实例中推断变量的类型。我们可以在 Scala 中这样写等效的行:

val myInstance = new Example()

Scala 编译器在编译时推断myInstance具有Example类型。很多时候,指定函数的参数和返回值的类型就足够了。然后编译器可以推断函数体中定义的所有变量的类型。Scala 代码通常比等效的 Java 代码更简洁、更易读,而不牺牲任何类型安全性。

Scala 鼓励不可变性

Scala 鼓励使用不可变对象。在 Scala 中,定义一个属性为不可变非常容易:

val amountSpent = 200

默认集合是不可变的:

val clientIds = List("123", "456") // List is immutable
clientIds(1) = "589" // Compile-time error

具有不可变对象消除了常见的错误来源。知道某些对象一旦实例化后就不能改变,可以减少错误可能潜入的地方。而不是考虑对象的生命周期,我们可以专注于构造函数。

Scala 和函数式程序

Scala 鼓励使用函数式代码。大量的 Scala 代码由使用高阶函数来转换集合组成。作为程序员,你不需要处理遍历集合的细节。让我们编写一个occurrencesOf函数,它返回元素在列表中出现的索引:

def occurrencesOfA:List[Int] = {
  for { 
    (currentElem, index) <- collection.zipWithIndex
    if (currentElem == elem)
  } yield index
}

这是如何工作的?我们首先声明一个新的列表,collection.zipWithIndex,其元素是(collection(0), 0)(collection(1), 1)等等:集合的元素及其索引的配对。

然后,我们告诉 Scala 我们想要遍历这个集合,将currentElem变量绑定到当前元素,将index绑定到索引。我们对迭代应用一个过滤器,只选择那些currentElem == elem的元素。然后我们告诉 Scala 只返回index变量。

在 Scala 中,我们不需要处理迭代过程的细节。语法非常声明式:我们告诉编译器我们想要集合中每个等于elem的元素的索引,然后让编译器去担心如何遍历集合。

考虑 Java 中的等效代码:

static <T> List<Integer> occurrencesOf(T elem, List<T> collection) {
  List<Integer> occurrences = new ArrayList<Integer>() ;
  for (int i=0; i<collection.size(); i++) {
    if (collection.get(i).equals(elem)) {
      occurrences.add(i) ;
    }
  }
  return occurrences ;
}

在 Java 中,你首先定义一个(可变的)列表,用于存放你找到的实例。然后通过定义一个计数器来遍历这个集合,依次考虑每个元素,并在需要时将其索引添加到实例列表中。为了使这个方法正常工作,我们需要正确处理许多其他部分。这些部分的存在是因为我们必须告诉 Java 如何遍历集合,并且它们是 bug 的常见来源。

此外,由于大量的代码被迭代机制占用,定义函数逻辑的行更难找到:

static <T> List<Integer> occurrencesOf(T elem, List<T> collection) {
  List<Integer> occurences = new ArrayList<Integer>() ;
  for (int i=0; i<collection.size(); i++) {
    if (collection.get(i).equals(elem)) { 
      occurrences.add(i) ;
    }
  }
  return occurrences ;
}

注意,这并不是对 Java 的攻击。事实上,Java 8 增加了一系列函数式构造,如 lambda 表达式、与 Scala 的Option相对应的Optional类型或流处理。相反,这是为了展示函数式方法在最小化错误潜力和最大化清晰度方面的好处。

空指针不确定性

我们经常需要表示值的可能不存在。例如,假设我们正在从 CSV 文件中读取用户名列表。CSV 文件包含姓名和电子邮件信息。然而,一些用户选择不将他们的电子邮件输入到系统中,因此这些信息不存在。在 Java 中,人们通常会使用字符串或Email类来表示电子邮件,并通过将那个引用设置为null来表示特定用户的电子邮件信息不存在。同样,在 Python 中,我们可能会使用None来表示值的缺失。

这种方法很危险,因为我们没有编码电子邮件信息的可能缺失。在任何非平凡程序中,决定实例属性是否可以为null需要考虑这个实例定义的每一个场合。这很快就会变得不切实际,因此程序员要么假设变量不是null,要么编写过于防御性的代码。

Scala(跟随其他函数式语言的趋势)引入了Option[T]类型来表示可能缺失的属性。然后我们可能会写出以下内容:

class User {
  ...
  val email:Option[Email]
  ...
}

我们现在已经在类型信息中编码了电子邮件可能不存在的情况。对于使用User类的任何程序员来说,电子邮件信息可能不存在是显而易见的。更好的是,编译器知道email字段可能不存在,这迫使我们处理这个问题,而不是鲁莽地忽略它,让应用程序在运行时因为空指针异常而崩溃。

所有这些都回到了实现一定程度的可证明正确性的目标。从不使用null,我们知道我们永远不会遇到空指针异常。在没有Option[T]的语言中实现相同级别的正确性需要编写单元测试来验证当电子邮件属性为null时客户端代码的行为是否正确。

注意,在 Java 中,可以使用例如 Google 的 Guava 库(code.google.com/p/guava-libraries/wiki/UsingAndAvoidingNullExplained)或 Java 8 中的Optional类来实现这一点。这更多是一个约定:在 Java 中使用null来表示值的缺失已经很长时间是规范了。

更容易的并行性

编写利用并行架构的程序具有挑战性。尽管如此,解决除了最简单的数据科学问题之外的所有问题仍然是必要的。

并行编程困难,因为我们作为程序员,往往倾向于按顺序思考。在并发程序中推理不同事件可能发生的顺序是非常具有挑战性的。

Scala 提供了几个抽象,这些抽象极大地简化了并行代码的编写。这些抽象通过限制实现并行性的方式来工作。例如,并行集合强制用户将计算表述为对集合的操作序列(如mapreducefilter)。演员系统要求开发者从封装应用程序状态并通过传递消息进行通信的演员的角度来思考。

限制程序员自由编写他们想要的并行代码似乎有些矛盾,但这可以避免与并发相关的大多数问题。然而,限制程序行为的方式有助于思考其行为。例如,如果一个演员表现不佳,我们知道问题要么在于这个演员的代码,要么在于演员收到的某个消息。

作为具有一致、限制性抽象的强大功能的例子,让我们使用并行集合来解决一个简单的概率问题。我们将计算在 100 次抛硬币中至少得到 60 次正面的概率。我们可以使用蒙特卡洛方法来估计这一点:通过抽取 100 个随机的布尔值来模拟 100 次抛硬币,并检查真值的数量是否至少为 60。我们重复这个过程,直到结果收敛到所需的精度,或者我们等得不耐烦了。

让我们在 Scala 控制台中演示这个过程:

scala> val nTosses = 100
nTosses: Int = 100

scala> def trial = (0 until nTosses).count { i =>
 util.Random.nextBoolean() // count the number of heads
}
trial: Int

trial函数运行一组 100 次投掷,返回正面朝上的次数:

scala> trial
Int = 51

为了得到我们的答案,我们只需要尽可能多地重复trial,并汇总结果。重复相同的操作集非常适合并行集合:

scala> val nTrials = 100000
nTrials: Int = 100000

scala> (0 until nTrials).par.count { i => trial >= 60 }
Int = 2745

因此,概率大约是 2.5%到 3%。我们只需使用par方法并行化范围(0 until nTrials),就可以将计算分布到我们计算机的每个 CPU 上。这证明了具有一致抽象的好处:并行集合使我们能够轻易地将任何可以用集合上的高阶函数表述的计算并行化。

显然,并非每个问题都像简单的蒙特卡洛问题那样容易并行化。然而,通过提供丰富的直观抽象,Scala 使得编写并行应用程序变得可行。

与 Java 的互操作性

Scala 运行在 Java 虚拟机上。Scala 编译器将程序编译成 Java 字节码。因此,Scala 开发者可以原生地访问 Java 库。鉴于用 Java 编写的应用程序数量庞大,无论是开源的还是作为组织中的遗留代码的一部分,Scala 和 Java 的互操作性有助于解释 Scala 的快速采用。

互操作性不仅仅是单向的:一些 Scala 库,如 Play 框架,在 Java 开发者中越来越受欢迎。

何时不使用 Scala

在前面的章节中,我们描述了 Scala 的强类型系统、对不可变性的偏好、函数能力和并行抽象如何使得编写可靠的程序变得容易,并最小化意外行为的风险。

你可能有哪些理由在下一个项目中避免使用 Scala?一个重要的原因是熟悉度。Scala 引入了许多概念,例如隐式参数、类型类和通过特质使用组合,这些可能对来自面向对象世界的程序员来说并不熟悉。Scala 的类型系统非常强大,但要充分了解它以发挥其全部功能需要时间,并需要适应新的编程范式。最后,对于来自 Java 或 Python 的程序员来说,处理不可变数据结构可能会感到不适应。

然而,这些都是可以通过时间克服的缺点。Scala 在库可用性方面确实不如其他数据科学语言。IPython Notebook 与 matplotlib 结合使用,是数据探索的无与伦比的资源。有持续的努力在 Scala 中提供类似的功能(例如 Spark Notebooks 或 Apache Zeppelin),但没有项目达到相同的成熟度。当探索数据或尝试不同的模型时,类型系统也可能成为轻微的障碍。

因此,在这个作者有偏见的观点中,Scala 在编写更“永久”的程序方面表现出色。如果你正在编写一个一次性脚本或探索数据,你可能会发现 Python 更适合。如果你正在编写需要重用并需要一定程度的可证明正确性的东西,你会发现 Scala 非常强大。

摘要

现在必要的介绍已经结束,是时候编写一些 Scala 代码了。在下一章中,你将学习如何利用 Breeze 在 Scala 中进行数值计算。在我们的第一次数据科学探索中,我们将使用逻辑回归来预测给定一个人的身高和体重来预测其性别。

参考文献

到目前为止,关于 Scala 的最佳书籍是马丁·奥德斯基(Martin Odersky)、莱克斯·斯波恩(Lex Spoon)和比尔·文纳(Bill Venners)合著的《Programming in Scala》。这本书不仅权威(马丁·奥德斯基是 Scala 的推动力),而且易于接近和阅读。

《Scala Puzzlers》由安德鲁·菲利普斯(Andrew Phillips)和内尔明·谢里福维奇(Nermin Šerifović)所著,提供了一种有趣的方式来学习更高级的 Scala。

《Scala for Machine Learning》由帕特里克·R·尼古拉斯(Patrick R. Nicholas)所著,提供了如何使用 Scala 编写机器学习算法的示例。

第二章:使用 Breeze 操作数据

数据科学在很大程度上是关于结构化数据的操作。大量结构化数据集可以被视为表格数据:每一行代表一个特定的实例,而列代表该实例的不同属性。表格表示的普遍性解释了像 Microsoft Excel 这样的电子表格程序或像 SQL 数据库这样的工具的成功。

要对数据科学家有用,一种语言必须支持数据列或表格的操作。例如,Python 通过 NumPy 和 pandas 来实现这一点。不幸的是,在 Scala 中,没有一个单一、连贯的数值计算生态系统可以与 Python 中的 SciPy 生态系统相媲美。

在本章中,我们将介绍 Breeze,这是一个用于快速线性代数和数据数组操作以及许多其他科学计算和数据科学所需特性的库。

代码示例

访问本书中的代码示例最简单的方法是克隆 GitHub 仓库:

$ git clone 'https://github.com/pbugnion/s4ds'

每章的代码示例都在一个单独的独立文件夹中。你还可以在 GitHub 上在线浏览代码。

安装 Breeze

如果你已经下载了本书的代码示例,使用 Breeze 最简单的方法是进入chap02目录,并在命令行中输入sbt console。这将打开一个 Scala 控制台,你可以在其中导入 Breeze。

如果你想要构建一个独立的项目,安装 Breeze(以及任何 Scala 模块)最常见的方式是通过 SBT。为了获取本章所需的依赖项,将以下行复制到名为build.sbt的文件中,注意在scalaVersion之后留一个空行:

scalaVersion := "2.11.7"

libraryDependencies ++= Seq(
  "org.scalanlp" %% "breeze" % "0.11.2",
  "org.scalanlp" %% "breeze-natives" % "0.11.2"
)

在终端中输入sbt console,在build.sbt文件相同的目录中打开 Scala 控制台。您可以通过从 Scala 提示符导入 Breeze 来检查 Breeze 是否正常工作:

scala> import breeze.linalg._
import breeze.linalg._

获取 Breeze 的帮助

本章对 Breeze 进行了相当详细的介绍,但并不旨在提供完整的 API 参考。

要获取 Breeze 功能的完整列表,请查阅 GitHub 上的 Breeze Wiki 页面github.com/scalanlp/breeze/wiki。对于某些模块来说,这是非常完整的,而对于其他模块来说则不那么完整。源代码([github.com/scalanlp/breeze/](https://github.com/scalanlp/breeze/))详细且提供了大量信息。要了解特定函数的预期用法,请查看该函数的单元测试。

基本的 Breeze 数据类型

Breeze 是一个功能丰富的库,提供了对数据数组进行快速和简单操作的功能,包括优化、插值、线性代数、信号处理和数值积分的例程。

Breeze 背后的基本线性代数操作依赖于netlib-java库,该库可以使用系统优化的BLASLAPACK库(如果存在)。因此,Breeze 中的线性代数操作通常非常快。Breeze 仍在快速发展中,因此可能有些不稳定。

向量

Breeze 使得操作一维和二维数据结构变得简单。首先,通过 SBT 打开 Scala 控制台并导入 Breeze:

$ sbt console
scala> import breeze.linalg._
import breeze.linalg._

让我们直接定义一个向量:

scala> val v = DenseVector(1.0, 2.0, 3.0)
breeze.linalg.DenseVector[Double] = DenseVector(1.0, 2.0, 3.0)

我们刚刚定义了一个包含三个元素的向量,v。向量只是一维数据数组,提供了针对数值使用的定制方法。它们可以像其他 Scala 集合一样进行索引:

scala> v(1)
Double = 2.0

它们还支持与标量的逐元素操作:

scala> v :* 2.0 // :* is 'element-wise multiplication'
breeze.linalg.DenseVector[Double] = DenseVector(2.0, 4.0, 6.0)

它们还支持与另一个向量的逐元素操作:

scala> v :+ DenseVector(4.0, 5.0, 6.0) // :+ is 'element-wise addition'
breeze.linalg.DenseVector[Double] = DenseVector(5.0, 7.0, 9.0)

Breeze 使得编写向量操作直观,并且比 Scala 原生等价物更易于阅读。

注意,Breeze 将在编译时拒绝将操作数强制转换为正确的类型:

scala> v :* 2 // element-wise multiplication by integer
<console>:15: error: could not find implicit value for parameter op:
...

它还会在运行时拒绝将不同长度的向量相加:

scala> v :+ DenseVector(8.0, 9.0)
java.lang.IllegalArgumentException: requirement failed: Vectors must have same length: 3 != 2
...

在 Breeze 中对向量的基本操作对习惯于使用 NumPy、MATLAB 或 R 的人来说将感觉自然。

到目前为止,我们只看了逐元素运算符。这些运算符都是以冒号开头的。所有常见的运算符都存在::+:*:-:/:%(余数)和:^(幂)以及布尔运算符。要查看运算符的完整列表,请查看DenseVectorDenseMatrix的 API 文档(github.com/scalanlp/breeze/wiki/Linear-Algebra-Cheat-Sheet)。

除了逐元素操作外,Breeze 向量还支持您可能期望的数学向量的操作,例如点积:

scala> val v2 = DenseVector(4.0, 5.0, 6.0)
breeze.linalg.DenseVector[Double] = DenseVector(4.0, 5.0, 6.0)

scala> v dot v2
Double = 32.0

小贴士

逐元素运算符的陷阱

除了我们迄今为止看到的用于逐元素加法和减法的 :+:- 运算符之外,我们还可以使用更传统的 +- 运算符:

scala> v + v2
breeze.linalg.DenseVector[Double] = DenseVector(5.0, 7.0, 9.0)

然而,在将 :+:*:+ 运算符混合使用时,必须非常小心运算符优先级规则。:+:* 运算符的运算优先级非常低,因此它们将被最后评估。这可能会导致一些不符合直觉的行为:

scala> 2.0 :* v + v2 // !! equivalent to 2.0 :* (v + v2)
breeze.linalg.DenseVector[Double] = DenseVector(10.0, 14.0, 18.0)

相比之下,如果我们使用 :+ 而不是 +,则运算符的数学优先级将被尊重:

scala> 2.0 :* v :+ v2 // equivalent to (2.0 :* v) :+ v2
breeze.linalg.DenseVector[Double] = DenseVector(6.0, 9.0, 12.0)

总结来说,应尽可能避免将 :+ 风格的运算符与 + 风格的运算符混合使用。

稠密和稀疏矢量和矢量特性

我们迄今为止查看的所有矢量都是稠密矢量。Breeze 也支持稀疏矢量。当处理主要由零组成的数字数组时,使用稀疏矢量可能更有效。矢量何时有足够的零以证明切换到稀疏表示取决于操作类型,因此你应该运行自己的基准测试以确定使用哪种类型。尽管如此,一个好的启发式方法是,如果你的矢量大约有 90% 是零,那么使用稀疏表示可能会有所裨益。

稀疏矢量在 Breeze 中以 SparseVectorHashVector 类的形式提供。这两种类型都支持与 DenseVector 相同的许多操作,但使用不同的内部实现。SparseVector 实例非常节省内存,但添加非零元素的速度较慢。HashVector 更灵活,但代价是内存占用和遍历非零元素的计算时间增加。除非你需要从应用程序中挤出最后一点内存,否则我建议使用 HashVector。本书中不会进一步讨论这些内容,但如果需要,读者应该会发现它们的使用非常直观。DenseVectorSparseVectorHashVector 都实现了 Vector 特性,从而提供了统一的接口。

小贴士

Breeze 仍然处于实验阶段,并且截至本文撰写时,有些不稳定。我发现处理 Vector 特性的特定实现(如 DenseVectorSparseVector)比直接处理 Vector 特性更可靠。在本章中,我们将明确地将每个矢量类型指定为 DenseVector

矩阵

Breeze 允许以类似的方式构建和操作二维数组:

scala> val m = DenseMatrix((1.0, 2.0, 3.0), (4.0, 5.0, 6.0))
breeze.linalg.DenseMatrix[Double] =
1.0  2.0  3.0
4.0  5.0  6.0

scala> 2.0 :* m
breeze.linalg.DenseMatrix[Double] =
2.0  4.0   6.0
8.0  10.0  12.0

构建矢量和矩阵

我们已经看到了如何通过将它们的值传递给构造函数(或者更确切地说,传递给伴随对象的 apply 方法)来显式构建矢量和矩阵:DenseVector(1.0, 2.0, 3.0)。Breeze 提供了构建矢量和矩阵的几种其他强大方法:

scala> DenseVector.onesDouble
breeze.linalg.DenseVector[Double] = DenseVector(1.0, 1.0, 1.0, 1.0, 1.0)

scala> DenseVector.zerosInt
breeze.linalg.DenseVector[Int] = DenseVector(0, 0, 0)

linspace 方法(在 breeze.linalg 包对象中可用)创建一个等间隔值的 Double 矢量。例如,要创建一个在 01 之间均匀分布的 10 个值的矢量,请执行以下操作:

scala> linspace(0.0, 1.0, 10)
breeze.linalg.DenseVector[Double] = DenseVector(0.0, 0.1111111111111111, ..., 1.0)

tabulate 方法允许我们通过函数构造向量和矩阵:

scala> DenseVector.tabulate(4) { i => 5.0 * i }
breeze.linalg.DenseVector[Double] = DenseVector(0.0, 5.0, 10.0, 15.0)

scala> DenseMatrix.tabulateInt { 
 (irow, icol) => irow*2 + icol 
}
breeze.linalg.DenseMatrix[Int] =
0  1  2
2  3  4

DenseVector.tabulate 的第一个参数是向量的长度,第二个参数是一个函数,它返回向量在特定位置的值。这有助于创建数据范围,以及其他用途。

rand 函数允许我们创建随机向量和矩阵:

scala> DenseVector.rand(2)
breeze.linalg.DenseVector[Double] = DenseVector(0.8072865137359484, 0.5566507203838562)

scala> DenseMatrix.rand(2, 3)
breeze.linalg.DenseMatrix[Double] =
0.5755491874682879   0.8142161471517582  0.9043780212739738
0.31530195124023974  0.2095094278911871  0.22069103504148346

最后,我们可以从 Scala 数组构造向量:

scala> DenseVector(Array(2, 3, 4))
breeze.linalg.DenseVector[Int] = DenseVector(2, 3, 4)

要从其他 Scala 集合构造向量,必须使用 splat 操作符,:_ *

scala> val l = Seq(2, 3, 4)
l: Seq[Int] = List(2, 3, 4)

scala> DenseVector(l :_ *)
breeze.linalg.DenseVector[Int] = DenseVector(2, 3, 4)

高级索引和切片

我们已经看到如何通过索引选择向量 v 中的特定元素,例如 v(2)。Breeze 还提供了几个强大的方法来选择向量的部分。

让我们先创建一个向量来玩耍:

scala> val v = DenseVector.tabulate(5) { _.toDouble }
breeze.linalg.DenseVector[Double] = DenseVector(0.0, 1.0, 2.0, 3.0, 4.0)

与原生 Scala 集合不同,Breeze 向量支持负索引:

scala> v(-1) // last element
Double = 4.0

Breeze 允许我们使用范围来切片向量:

scala> v(1 to 3)
breeze.linalg.DenseVector[Double] = DenseVector(1.0, 2.0, 3.0)

scala v(1 until 3) // equivalent to Python v[1:3]
breeze.linalg.DenseVector[Double] = DenseVector(1.0, 2.0)

scala> v(v.length-1 to 0 by -1) // reverse view of v
breeze.linalg.DenseVector[Double] = DenseVector(4.0, 3.0, 2.0, 1.0, 0.0)

小贴士

通过范围进行索引返回原始向量的一个 视图:当运行 val v2 = v(1 to 3) 时,不会复制任何数据。这意味着切片非常高效。从大向量中取切片不会增加内存占用。这也意味着在更新切片时应该小心,因为它也会更新原始向量。我们将在本章后续部分讨论向量和矩阵的修改。

Breeze 还允许我们从向量中选择任意一组元素:

scala> val vSlice = v(2, 4) // Select elements at index 2 and 4
breeze.linalg.SliceVector[Int,Double] = breeze.linalg.SliceVector@9c04d22

这创建了一个 SliceVector,它类似于 DenseVector(两者都实现了 Vector 接口),但实际上并没有为值分配内存:它只是知道如何将其索引映射到父向量的值。可以将 vSlice 视为 v 的一个特定视图。我们可以通过将其转换为 DenseVector 来具体化视图(给它自己的数据,而不是作为 v 的观察透镜):

scala> vSlice.toDenseVector
breeze.linalg.DenseVector[Double] = DenseVector(2.0, 4.0)

注意,如果切片中的某个元素超出了范围,只有在访问该元素时才会抛出异常:

scala> val vSlice = v(2, 7) // there is no v(7)
breeze.linalg.SliceVector[Int,Double] = breeze.linalg.SliceVector@2a83f9d1

scala> vSlice(0) // valid since v(2) is still valid
Double = 2.0

scala> vSlice(1) // invalid since v(7) is out of bounds
java.lang.IndexOutOfBoundsException: 7 not in [-5,5)
 ...

最后,可以使用布尔数组来索引向量。让我们先定义一个数组:

scala> val mask = DenseVector(true, false, false, true, true)
breeze.linalg.DenseVector[Boolean] = DenseVector(true, false, false, true, true)

然后,v(mask) 结果是一个包含 vmasktrue 的元素的视图:

scala> v(mask).toDenseVector
breeze.linalg.DenseVector[Double] = DenseVector(0.0, 3.0, 4.0)

这可以用作过滤向量中某些元素的一种方式。例如,要选择小于 3.0v 的元素:

scala> val filtered = v(v :< 3.0) // :< is element-wise "less than"
breeze.linalg.SliceVector[Int,Double] = breeze.linalg.SliceVector@2b1edef3

scala> filtered.toDenseVector
breeze.linalg.DenseVector[Double] = DenseVector(0.0, 1.0, 2.0)

矩阵的索引方式与向量非常相似。矩阵索引函数接受两个参数——第一个参数选择行(s),第二个参数切片列(s):

scala> val m = DenseMatrix((1.0, 2.0, 3.0), (5.0, 6.0, 7.0))
m: breeze.linalg.DenseMatrix[Double] =
1.0  2.0  3.0
5.0  6.0  7.0

scala> m(1, 2)
Double = 7.0

scala> m(1, -1)
Double = 7.0

scala> m(0 until 2, 0 until 2)
breeze.linalg.DenseMatrix[Double] =
1.0  2.0
5.0  6.0

您还可以混合不同类型的行和列切片:

scala> m(0 until 2, 0)
breeze.linalg.DenseVector[Double] = DenseVector(1.0, 5.0)

注意,在这种情况下,Breeze 返回一个向量。一般来说,切片操作返回以下对象:

  • 当传递单个索引作为行和列参数时,返回一个标量

  • 当行参数是一个范围,列参数是一个单个索引时,返回一个向量

  • 当列参数是一个范围,行参数是一个单个索引时,返回一个向量的转置

  • 否则返回一个矩阵

符号::可以用来表示沿特定方向的每个元素。例如,我们可以选择m的第二列:

scala> m(::, 1)
breeze.linalg.DenseVector[Double] = DenseVector(2.0, 6.0)

修改向量和矩阵

Breeze 中的向量和矩阵是可变的。上述描述的大多数切片操作也可以用来设置向量和矩阵的元素:

scala> val v = DenseVector(1.0, 2.0, 3.0)
v: breeze.linalg.DenseVector[Double] = DenseVector(1.0, 2.0, 3.0)

scala> v(1) = 22.0 // v is now DenseVector(1.0, 22.0, 3.0)

我们不仅限于修改单个元素。实际上,上述概述的所有索引操作都可以用来设置向量和矩阵的元素。在修改向量和矩阵的切片时,使用逐元素赋值运算符:=

scala> v(0 until 2) := DenseVector(50.0, 51.0) // set elements at position 0 and 1
breeze.linalg.DenseVector[Double] = DenseVector(50.0, 51.0)

scala> v
breeze.linalg.DenseVector[Double] = DenseVector(50.0, 51.0, 3.0)

赋值运算符:=在 Breeze 中像其他逐元素运算符一样工作。如果右侧是一个标量,它将自动广播到给定形状的向量:

scala> v(0 until 2) := 0.0 // equivalent to v(0 until 2) := DenseVector(0.0, 0.0)
breeze.linalg.DenseVector[Double] = DenseVector(0.0, 0.0)

scala> v
breeze.linalg.DenseVector[Double] = DenseVector(0.0, 0.0, 3.0)

所有逐元素运算符都有一个更新对应的运算符。例如,:+=运算符类似于逐元素加法运算符:+,但也会更新其左侧的操作数:

scala> val v = DenseVector(1.0, 2.0, 3.0)
v: breeze.linalg.DenseVector[Double] = DenseVector(1.0, 2.0, 3.0)

scala> v :+= 4.0
breeze.linalg.DenseVector[Double] = DenseVector(5.0, 6.0, 7.0)

scala> v
breeze.linalg.DenseVector[Double] = DenseVector(5.0, 6.0, 7.0)

注意更新操作是如何就地更新向量并返回它的。

我们已经学会了如何在 Breeze 中切片向量和矩阵来创建原始数据的新的视图。这些视图并不独立于它们创建的向量——更新视图将更新底层向量,反之亦然。这最好用一个例子来说明:

scala> val v = DenseVector.tabulate(6) { _.toDouble }
breeze.linalg.DenseVector[Double] = DenseVector(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)

scala> val viewEvens = v(0 until v.length by 2)
breeze.linalg.DenseVector[Double] = DenseVector(0.0, 2.0, 4.0)

scala> viewEvens := 10.0 // mutate viewEvens
breeze.linalg.DenseVector[Double] = DenseVector(10.0, 10.0, 10.0)

scala> viewEvens
breeze.linalg.DenseVector[Double] = DenseVector(10.0, 10.0, 10.0)

scala> v  // v has also been mutated!
breeze.linalg.DenseVector[Double] = DenseVector(10.0, 1.0, 10.0, 3.0, 10.0, 5.0)

如果我们记住,当我们创建一个向量和矩阵时,我们实际上是在创建一个底层数据数组的视图,而不是创建数据本身,这会很快变得直观:

修改向量和矩阵

v向量的切片v(0 to 6 by 2)只是v底层数组的另一种视图。这个视图本身不包含数据。它只包含指向原始数组中数据的指针。内部,视图只是存储为指向底层数据的指针和一个遍历该数据的配方:在这个切片的情况下,配方就是“从底层数据的第一元素开始,以两步的间隔到达底层数据的第七元素”。

当我们想要创建数据的独立副本时,Breeze 提供了一个copy函数。在之前的例子中,我们可以构建viewEvens的副本如下:

scala> val copyEvens = v(0 until v.length by 2).copy
breeze.linalg.DenseVector[Double] = DenseVector(10.0, 10.0, 10.0)

现在,我们可以独立于v更新copyEvens

矩阵乘法、转置和向量的方向

到目前为止,我们主要关注了向量和矩阵的逐元素操作。现在让我们看看矩阵乘法和相关操作。

矩阵乘法运算符是*

scala> val m1 = DenseMatrix((2.0, 3.0), (5.0, 6.0), (8.0, 9.0))
breeze.linalg.DenseMatrix[Double] =
2.0  3.0
5.0  6.0
8.0  9.0

scala> val m2 = DenseMatrix((10.0, 11.0), (12.0, 13.0))
breeze.linalg.DenseMatrix[Double] 
10.0  11.0
12.0  13.0

scala> m1 * m2
56.0   61.0
122.0  133.0
188.0  205.0

除了矩阵-矩阵乘法之外,我们还可以在矩阵和向量之间使用矩阵乘法运算符。Breeze 中的所有向量都是列向量。这意味着,当矩阵和向量相乘时,向量应被视为一个(n * 1)矩阵。让我们通过一个矩阵-向量乘法的例子来了解一下。我们想要以下操作:

矩阵乘法、转置和向量的方向

scala> val v = DenseVector(1.0, 2.0)
breeze.linalg.DenseVector[Double] = DenseVector(1.0, 2.0)

scala> m1 * v 
breeze.linalg.DenseVector[Double] = DenseVector(8.0, 17.0, 26.0)

相比之下,如果我们想要:

矩阵乘法、转置和向量的方向

我们必须将 v 转换为一个行向量。我们可以使用转置操作来完成此操作:

scala> val vt = v.t
breeze.linalg.Transpose[breeze.linalg.DenseVector[Double]] = Transpose(DenseVector(1.0, 2.0))

scala> vt * m2
breeze.linalg.Transpose[breeze.linalg.DenseVector[Double]] = Transpose(DenseVector(34.0, 37.0))

注意,v.t 的类型是 Transpose[DenseVector[_]]。在逐元素操作方面,Transpose[DenseVector[_]]DenseVector 几乎以相同的方式表现,但它不支持修改或切片。

数据预处理和特征工程

我们现在已经发现了 Breeze 的基本组件。在接下来的几节中,我们将将它们应用于实际例子,以了解它们如何组合在一起形成数据科学的一个强大基础。

数据科学的一个重要部分涉及预处理数据集以构建有用的特征。让我们通过一个例子来了解这个过程。为了跟随这个例子并访问数据,您需要下载本书的代码示例(www.github.com/pbugnion/s4ds)。

您将在本书附带的代码的 chap02/data/ 目录中找到一个 CSV 文件,其中包含 181 名男性和女性的真实身高和体重,以及自我报告的身高和体重。原始数据集是作为一项关于身体形象的研究的一部分收集的。有关更多信息,请参阅以下链接:vincentarelbundock.github.io/Rdatasets/doc/car/Davis.html

本书提供的包中有一个辅助函数,用于将数据加载到 Breeze 数组中:

scala> val data = HWData.load
HWData [ 181 rows ]

scala> data.genders
breeze.linalg.Vector[Char] = DenseVector(M, F, F, M, ... )

data 对象包含五个向量,每个向量长度为 181:

  • data.genders:一个描述参与者性别的 Char 向量

  • data.heights:一个包含参与者真实身高的 Double 向量

  • data.weights:一个包含参与者真实体重的 Double 向量

  • data.reportedHeights:一个包含参与者自我报告身高的 Double 向量

  • data.reportedWeights:一个包含参与者自我报告体重的 Double 向量

让我们先计算研究中男性和女性的数量。我们将定义一个只包含 'M' 的数组,并与 data.genders 进行逐元素比较:

scala> val maleVector = DenseVector.fill(data.genders.length)('M')
breeze.linalg.DenseVector[Char] = DenseVector(M, M, M, M, M, M,... )

scala> val isMale = (data.genders :== maleVector)
breeze.linalg.DenseVector[Boolean] = DenseVector(true, false, false, true ...)

isMale 向量与 data.genders 的长度相同。当参与者为男性时,它为 true,否则为 false。我们可以使用这个布尔数组作为数据集中其他数组的掩码(记住 vector(mask) 选择 vector 中掩码为 true 的元素)。让我们获取我们数据集中男性的身高:

scala> val maleHeights = data.heights(isMale)
breeze.linalg.SliceVector[Int,Double] = breeze.linalg.SliceVector@61717d42

scala> maleHeights.toDenseVector
breeze.linalg.DenseVector[Double] = DenseVector(182.0, 177.0, 170.0, ...

要计算我们数据集中男性的数量,我们可以使用指示函数。这个函数将布尔数组转换为一个双精度浮点数数组,将 false 映射到 0.0,将 true 映射到 1.0

scala> import breeze.numerics._
import breeze.numerics._

scala> sum(I(isMale))
Double: 82.0

让我们计算实验中男性和女性的平均身高。我们可以使用 mean(v) 计算向量的平均值,通过导入 breeze.stats._ 来访问它:

scala> import breeze.stats._
import breeze.stats._

scala> mean(data.heights)
Double = 170.75690607734808

要计算男性的平均身高,我们可以使用我们的isMale数组来切片data.heightsdata.heights(isMale)data.heights数组的一个视图,其中包含所有男性的身高值:

scala> mean(data.heights(isMale)) // mean male height
Double = 178.0121951219512

scala> mean(data.heights(!isMale)) // mean female height
Double = 164.74747474747474

作为稍微复杂一点的例子,让我们看看在这个实验中男性和女性实际体重与报告体重之间的差异。我们可以得到一个报告体重与真实体重之间百分比差异的数组:

scala> val discrepancy = (data.weights - data.reportedWeights) / data.weights
breeze.linalg.Vector[Double] = DenseVector(0.0, 0.1206896551724138, -0.018867924528301886, -0.029411764705882353, ... )

注意 Breeze 对数学运算符的重载如何使我们能够轻松优雅地操作数据数组。

我们现在可以计算这个数组中男性的平均值和标准差:

scala> mean(discrepancy(isMale))
res6: Double = -0.008451852933123775

scala> stddev(discrepancy(isMale))
res8: Double = 0.031901519634244195

我们还可以计算高估自己身高的男性的比例:

scala> val overReportMask = (data.reportedHeights :> data.heights).toDenseVector
breeze.linalg.DenseVector[Boolean] = DenseVector(false, false, false, false...

scala> sum(I(overReportMask :& isMale))
Double: 10.0

因此,有十名男性认为自己比实际更高。逐元素 AND 运算符:&返回一个向量,对于其两个参数都为真的所有索引,该向量是真实的。因此,向量overReportMask :& isMale对于所有报告身高超过实际身高的男性参与者都是真实的。

Breeze – 函数优化

在研究了特征工程之后,我们现在来看看数据科学管道的另一端。通常,机器学习算法定义了一个损失函数,该函数是一组参数的函数。损失函数的值表示模型拟合数据的程度。然后,参数被优化以最小化(或最大化)损失函数。

在第十二章《使用 MLlib 的分布式机器学习》中,我们将探讨MLlib,这是一个包含许多知名算法的机器学习库。通常,我们不需要担心直接优化损失函数,因为我们可以依赖 MLlib 提供的机器学习算法。然而,了解优化基础知识仍然很有用。

Breeze 有一个optimize模块,其中包含用于寻找局部最小值的函数:

scala> import breeze.optimize._
import breeze.optimize._

让我们创建一个我们想要优化的玩具函数:

Breeze – 函数优化

我们可以用以下方式在 Scala 中表示这个函数:

scala> def f(xs:DenseVector[Double]) = sum(xs :^ 2.0)
f: (xs: breeze.linalg.DenseVector[Double])Double

大多数局部优化器也要求提供正在优化的函数的梯度。梯度是与函数的参数相同维度的向量。在我们的例子中,梯度是:

Breeze – 函数优化

我们可以用一个接受向量参数并返回相同长度向量的函数来表示梯度:

scala> def gradf(xs:DenseVector[Double]) = 2.0 :* xs
gradf: (xs:breeze.linalg.DenseVector[Double])breeze.linalg.DenseVector[Double]

例如,在点(1, 1, 1)处,我们有:

scala> val xs = DenseVector.onesDouble
breeze.linalg.DenseVector[Double] = DenseVector(1.0, 1.0, 1.0)

scala> f(xs)
Double = 3.0

scala> gradf(xs)
breeze.linalg.DenseVector[Double] = DenseVector(2.0, 2.0, 2.0)

让我们设置优化问题。Breeze 的优化方法要求我们传递一个实现了DiffFunction特质的实现,该特质只有一个方法,即calculate。此方法必须返回一个包含函数及其梯度的元组:

scala> val optTrait = new DiffFunction[DenseVector[Double]] {
 def calculate(xs:DenseVector[Double]) = (f(xs), gradf(xs))
}
breeze.optimize.DiffFunction[breeze.linalg.DenseVector[Double]] = <function1>

我们现在可以运行优化了。优化模块提供了一个minimize函数,它正好符合我们的需求。我们传递给它optTrait和优化的起始点:

scala> val minimum = minimize(optTrait, DenseVector(1.0, 1.0, 1.0))
breeze.linalg.DenseVector[Double] = DenseVector(0.0, 0.0, 0.0)

真实最小值在(0.0, 0.0, 0.0)。因此,优化器正确地找到了最小值。

minimize函数默认使用L-BFGS方法运行优化。它接受几个额外的参数来控制优化。我们将在下一节中探讨这些参数。

数值微分

在前面的例子中,我们明确指定了f的梯度。虽然这通常是好的做法,但计算函数的梯度往往很繁琐。Breeze 提供了一个使用有限差分的梯度近似函数。重用与上一节相同的目标准函数def f(xs:DenseVector[Double]) = sum(xs :^ 2.0)

scala> val approxOptTrait = new ApproximateGradientFunction(f)
breeze.optimize.ApproximateGradientFunction[Int,breeze.linalg.DenseVector[Double]] = <function1>

特性approxOptTrait有一个gradientAt方法,它返回在一点处的梯度近似值:

scala> approxOptTrait.gradientAt(DenseVector.ones(3))
breeze.linalg.DenseVector[Double] = DenseVector(2.00001000001393, 2.00001000001393, 2.00001000001393)

注意,这可能会相当不准确。ApproximateGradientFunction构造函数接受一个可选的epsilon参数,该参数控制计算有限差分时步长的大小。改变epsilon的值可以提高有限差分算法的准确性。

ApproximateGradientFunction实例实现了DiffFunction特性。因此,它可以直接传递给minimize

scala> minimize(approxOptTrait, DenseVector.onesDouble)
breeze.linalg.DenseVector[Double] = DenseVector(-5.000001063126813E-6, -5.000001063126813E-6, -5.000001063126813E-6)

这再次给出了接近零的结果,但比我们明确指定梯度时稍远一些。一般来说,通过解析计算函数的梯度将比依赖 Breeze 的数值梯度更有效、更准确。可能最好只在数据探索期间或检查解析梯度时使用数值梯度。

正则化

minimize函数接受许多与机器学习算法相关的可选参数。特别是,我们可以指示优化器在执行优化时使用正则化参数。正则化在损失函数中引入惩罚,以防止参数任意增长。这有助于避免过拟合。我们将在第十二章中更详细地讨论正则化,使用 MLlib 的分布式机器学习

例如,要使用具有超参数0.5L2Regularization

scala> minimize(optTrait, DenseVector(1.0, 1.0, 1.0), L2Regularization(0.5))
breeze.linalg.DenseVector[Double] = DenseVector(0.0, 0.0, 0.0)

在这个例子中,正则化没有区别,因为参数在最小值处为零。

要查看可以传递给minimize的可选参数列表,请查阅在线的 Breeze 文档。

一个例子——逻辑回归

现在让我们想象我们想要构建一个分类器,它接受一个人的身高体重,并为他们被分配为男性女性的概率。我们将重用本章前面引入的身高和体重数据。让我们先绘制数据集:

一个例子——逻辑回归

181 名男性和女性的身高与体重数据

分类算法有很多种。初步观察数据表明,我们可以通过在图上画一条直线来大致区分男性和女性。因此,线性方法是对分类的合理初步尝试。在本节中,我们将使用逻辑回归来构建分类器。

逻辑回归的详细解释超出了本书的范围。对逻辑回归不熟悉的读者可参考 HastieTibshiraniFriedman 所著的 《统计学习的要素》。我们在这里只做简要概述。

逻辑回归使用以下 sigmoid 函数估计给定身高和体重属于男性的概率:

一个示例 – 逻辑回归

在这里,f 是一个线性函数:

一个示例 – 逻辑回归

在这里,一个示例 – 逻辑回归 是我们需要使用训练集确定的参数数组。如果我们将身高和体重视为 features = (height, weight) 矩阵,我们可以将 sigmoid 核 f 重新写为 features 矩阵与 params 向量的矩阵乘法:

一个示例 – 逻辑回归

为了进一步简化这个表达式,通常会在 features 矩阵中添加一个值始终为 1 的虚拟特征。然后我们可以将 params(0) 乘以这个特征,这样我们就可以将整个 sigmoid 核 f 写作一个单一的矩阵-向量乘法:

一个示例 – 逻辑回归

特征矩阵,features,现在是一个 (181 * 3) 矩阵,其中每一行代表一个特定参与者的 (1, height, weight)

为了找到参数的最优值,我们可以最大化似然函数,L(params|features)。似然函数以一组给定的参数值作为输入,并返回这些特定参数导致训练集的概率。对于一组参数和相关的概率函数 P(male|features[i]),似然函数为:

一个示例 – 逻辑回归

如果我们神奇地提前知道人口中每个人的性别,我们可以将男性的概率分配为 P(male)=1,女性的概率分配为 P(male)=0。此时,似然函数将是 1。相反,任何不确定性都会导致似然函数的降低。如果我们选择一组参数,这些参数始终导致分类错误(男性低 P(male) 或女性高 P(male)),则似然函数将降至 0

最大似然对应于最有可能描述观察数据的参数值。因此,为了找到最能描述我们的训练集的参数,我们只需要找到最大化 L(params|features) 的参数。然而,最大化似然函数本身很少被做,因为它涉及到将许多小的值相乘,这很快就会导致浮点下溢。最好最大化似然的对数,它与似然有相同的最大值。最后,由于大多数优化算法都是针对最小化函数而不是最大化函数而设计的,因此我们将最小化 一个示例 – 逻辑回归

对于逻辑回归,这相当于最小化:

一个示例 – 逻辑回归

在这里,求和遍历训练数据中的所有参与者,一个示例 – 逻辑回归 是训练集中第 i- 次观察的向量 一个示例 – 逻辑回归,而 一个示例 – 逻辑回归 如果该人是男性则为 1,如果是女性则为 0

为了最小化 成本 函数,我们还必须知道其相对于参数的梯度。这是:

一个示例 – 逻辑回归

我们将首先通过它们的平均值和标准差来缩放身高和体重。虽然这对逻辑回归来说不是严格必要的,但通常是一个好的实践。它有助于优化,如果我们想使用正则化方法或构建超线性特征(允许将男性和女性分开的边界是曲线而不是直线),则成为必要。

对于这个示例,我们将离开 Scala shell 并编写一个独立的 Scala 脚本。以下是完整的代码列表。不要担心这看起来令人畏惧。我们将在下一分钟将其分解成可管理的块:

import breeze.linalg._
import breeze.numerics._
import breeze.optimize._
import breeze.stats._

object LogisticRegressionHWData extends App {

  val data = HWData.load

  // Rescale the features to have mean of 0.0 and s.d. of 1.0
  def rescaled(v:DenseVector[Double]) =
    (v - mean(v)) / stddev(v)

  val rescaledHeights = rescaled(data.heights)
  val rescaledWeights = rescaled(data.weights)

  // Build the feature matrix as a matrix with 
  //181 rows and 3 columns.
  val rescaledHeightsAsMatrix = rescaledHeights.toDenseMatrix.t
  val rescaledWeightsAsMatrix = rescaledWeights.toDenseMatrix.t

  val featureMatrix = DenseMatrix.horzcat(
    DenseMatrix.onesDouble,
    rescaledHeightsAsMatrix,
    rescaledWeightsAsMatrix
  )

  println(s"Feature matrix size: ${featureMatrix.rows} x " +s"${featureMatrix.cols}")

  // Build the target variable to be 1.0 where a participant
  // is male, and 0.0 where the participant is female.
  val target = data.genders.values.map {
    gender => if(gender == 'M') 1.0 else 0.0
  }

  // Build the loss function ready for optimization.
  // We will worry about refactoring this to be more 
  // efficient later.
  def costFunction(parameters:DenseVector[Double]):Double = {
    val xBeta = featureMatrix * parameters
    val expXBeta = exp(xBeta)
    - sum((target :* xBeta) - log1p(expXBeta))
  }

  def costFunctionGradient(parameters:DenseVector[Double])
  :DenseVector[Double] = {
    val xBeta = featureMatrix * parameters
    val probs = sigmoid(xBeta)
    featureMatrix.t * (probs - target)
  }

  val f = new DiffFunction[DenseVector[Double]] {
    def calculate(parameters:DenseVector[Double]) =
      (costFunction(parameters), costFunctionGradient(parameters))
  }

  val optimalParameters = minimize(f, DenseVector(0.0, 0.0, 0.0))

  println(optimalParameters)
  // => DenseVector(-0.0751454743, 2.476293647, 2.23054540)
}

这听起来很复杂!让我们一步一步来。在明显的导入之后,我们开始:

object LogisticRegressionHWData extends App {

通过扩展内置的 App 特质,我们告诉 Scala 将整个对象视为一个 main 函数。这仅仅消除了 def main(args:Array[String]) 的样板代码。然后我们加载数据并将身高和体重缩放到具有 mean 为零和标准差为之一的值:

def rescaled(v:DenseVector[Double]) =
  (v - mean(v)) / stddev(v)

val rescaledHeights = rescaled(data.heights)
val rescaledWeights = rescaled(data.weights)

rescaledHeightsrescaledWeights 向量将是我们的模型特征。现在我们可以为这个模型构建训练集矩阵。这是一个 (181 * 3) 矩阵,其中第 i- 行是 (1, height(i), weight(i)),对应于第 i 个参与者的身高和体重值。我们首先将 rescaledHeightsrescaledWeights 从向量转换为 (181 * 1) 矩阵

val rescaledHeightsAsMatrix = rescaledHeights.toDenseMatrix.t
val rescaledWeightsAsMatrix = rescaledWeights.toDenseMatrix.t

我们还必须创建一个只包含 1 的 (181 * 1) 矩阵作为虚拟特征。我们可以使用以下方法来完成:

DenseMatrix.onesDouble

现在,我们需要将我们的三个 (181 * 1) 矩阵组合成一个形状为 (181 * 3) 的单个特征矩阵。我们可以使用 horzcat 方法将三个矩阵连接起来:

val featureMatrix = DenseMatrix.horzcat(
  DenseMatrix.onesDouble,
  rescaledHeightsAsMatrix,
  rescaledWeightsAsMatrix
)

数据预处理阶段的最后一步是创建目标变量。我们需要将 data.genders 向量转换为包含一和零的向量。我们将男性赋值为 1,女性赋值为 0。因此,我们的分类器将预测任何给定的人是男性的概率。我们将使用 .values.map 方法,这是 Scala 集合上的 .map 方法的等价方法:

val target = data.genders.values.map {
  gender => if(gender == 'M') 1.0 else 0.0
}

注意,我们也可以使用我们之前发现的指示函数:

val maleVector = DenseVector.fill(data.genders.size)('M')
val target = I(data.genders :== maleVector)

这会导致分配一个临时数组,maleVector,因此如果实验中有许多参与者,这可能会增加程序的内存占用。

现在我们有一个表示训练集的矩阵和一个表示目标变量的向量。我们可以写出我们想要最小化的损失函数。如前所述,我们将最小化 一个示例 - 逻辑回归。损失函数接受一组线性系数的值作为输入,并返回一个数字,表示这些线性系数的值如何拟合训练数据:

def costFunction(parameters:DenseVector[Double]):Double = {
  val xBeta = featureMatrix * parameters 
  val expXBeta = exp(xBeta)
  - sum((target :* xBeta) - log1p(expXBeta))
}

注意,我们使用 log1p(x) 来计算 log(1+x)。这对于 x 的较小值是稳健的,不会下溢。

让我们探索成本函数:

costFunction(DenseVector(0.0, 0.0, 0.0)) // 125.45963968135031
costFunction(DenseVector(0.0, 0.1, 0.1)) // 113.33336518036882
costFunction(DenseVector(0.0, -0.1, -0.1)) // 139.17134594294433

我们可以看到,对于稍微正的身高和体重参数值,成本函数略低。这表明对于稍微正的身高和体重值,似然函数更大。这反过来又意味着(正如我们根据图所期望的),身高和体重高于平均值的人更有可能是男性。

我们还需要一个函数来计算损失函数的梯度,因为这有助于优化:

def costFunctionGradient(parameters:DenseVector[Double])
:DenseVector[Double] = {
  val xBeta = featureMatrix * parameters 
  val probs = sigmoid(xBeta)
  featureMatrix.t * (probs - target)
}

在定义了损失函数和梯度之后,我们现在可以设置优化:

 val f = new DiffFunction[DenseVector[Double]] {
   def calculate(parameters:DenseVector[Double]) = 
     (costFunction(parameters), costFunctionGradient(parameters))
 }

现在剩下的就是运行优化。逻辑回归的成本函数是凸函数(它有一个唯一的极小值),所以在原则上优化起点是不相关的。在实践中,通常从一个系数向量开始,该向量在所有地方都是零(相当于将每个参与者的男性概率赋值为 0.5):

  val optimalParameters = minimize(f, DenseVector(0.0, 0.0, 0.0))

这返回最优参数的向量:

DenseVector(-0.0751454743, 2.476293647, 2.23054540)

我们如何解释最优参数的值?身高和体重的系数都是正的,这表明身高和体重较高的人更有可能是男性。

我们还可以直接从系数中得到决策边界(分隔更可能属于女性的(身高,体重)对和更可能属于男性的(身高,体重)对的线)。决策边界是:

一个示例 - 逻辑回归一个示例 - 逻辑回归

身高和体重数据(通过均值偏移并按标准差缩放)。橙色线是逻辑回归决策边界。逻辑回归预测边界上方的个体是男性。

向可重用代码迈进

在上一节中,我们在单个脚本中执行了所有计算。虽然这对于数据探索来说是不错的,但这意味着我们无法重用我们已经构建的逻辑回归代码。在本节中,我们将开始构建一个机器学习库,你可以在不同的项目中重用它。

我们将逻辑回归算法提取到它自己的类中。我们构建一个LogisticRegression类:

import breeze.linalg._
import breeze.numerics._
import breeze.optimize._

class LogisticRegression(
    val training:DenseMatrix[Double], 
    val target:DenseVector[Double])
{

该类接受一个表示训练集的矩阵和一个表示目标变量的向量作为输入。注意我们如何将它们分配给vals,这意味着它们在类创建时设置,并且将在类销毁之前保持不变。当然,DenseMatrixDenseVector对象是可变的,所以trainingtarget指向的值可能会改变。由于编程最佳实践规定可变状态会使程序行为难以推理,我们将避免利用这种可变性。

让我们添加一个计算成本函数及其梯度的方法:

  def costFunctionAndGradient(coefficients:DenseVector[Double])
  :(Double, DenseVector[Double]) = {
    val xBeta = training * coefficients
    val expXBeta = exp(xBeta)
    val cost = - sum((target :* xBeta) - log1p(expXBeta))
    val probs = sigmoid(xBeta)
    val grad = training.t * (probs - target)
    (cost, grad)
  }

我们现在已经准备好运行优化过程,以计算最佳系数,这些系数能够最好地重现训练集。在传统的面向对象语言中,我们可能会定义一个getOptimalCoefficients方法,该方法返回一个包含系数的DenseVector。然而,Scala 却更加优雅。由于我们已经将trainingtarget属性定义为vals,因此最优系数的可能值集只有一个。因此,我们可以定义一个val optimalCoefficients = ???类属性来保存最优系数。问题是这会强制所有计算都在实例构造时发生。这对用户来说可能是意外的,也可能造成浪费:如果用户只对访问成本函数感兴趣,例如,用于最小化的时间将会被浪费。解决方案是使用lazy val。这个值只有在客户端代码请求时才会被评估:

lazy val optimalCoefficients = ???

为了帮助计算系数,我们将定义一个私有辅助方法:

private def calculateOptimalCoefficients
:DenseVector[Double] = {
  val f = new DiffFunction[DenseVector[Double]] {
    def calculate(parameters:DenseVector[Double]) = 
      costFunctionAndGradient(parameters)
  }

  minimize(f, DenseVector.zerosDouble)
}

lazy val optimalCoefficients = calculateOptimalCoefficients

我们已经将逻辑回归重构为它自己的类,我们可以在不同的项目中重用它。

如果我们打算重用身高体重数据,我们可以将其重构为一个自己的类,该类便于数据加载、特征缩放以及任何我们经常重用的其他功能。

Breeze 的替代方案

Breeze 是线性代数和数值计算方面功能最丰富且易于使用的 Scala 框架。然而,不要仅凭我的话就下结论:尝试其他表格数据库。特别是,我推荐尝试Saddle,它提供了一个类似于 pandas 或 R 中的数据框的Frame对象。在 Java 领域,Apache Commons Maths 库提供了一个非常丰富的数值计算工具包。在第十章(part0097.xhtml#aid-2SG6I1 "第十章. 使用 Spark 进行分布式批量处理")、第十一章(part0106.xhtml#aid-352RK2 "第十一章. Spark SQL 和 DataFrames")和第十二章(part0117.xhtml#aid-3FIHQ2 "第十二章. 使用 MLlib 进行分布式机器学习")中,我们将探讨SparkMLlib,它们允许用户运行分布式机器学习算法。

摘要

这就结束了我们对 Breeze 的简要概述。我们学习了如何操作基本的 Breeze 数据类型,如何使用它们进行线性代数,以及如何执行凸优化。然后我们运用我们的知识清理了一个真实的数据集,并在其上执行了逻辑回归。

在下一章中,我们将讨论 breeze-viz,这是一个用于 Scala 的绘图库。

参考文献

统计学习的要素》,由HastieTibshiraniFriedman著,对机器学习的数学基础进行了清晰、实用的描述。任何希望不仅仅盲目地将机器学习算法作为黑盒应用的人,都应该拥有一本翻阅得破旧的这本书。

Scala for Machine Learning》,由Patrick R. Nicholas著,描述了许多有用的机器学习算法在 Scala 中的实际应用。

Breeze 的文档(github.com/scalanlp/breeze/wiki/Quickstart)、API 文档(www.scalanlp.org/api/breeze/#package)和源代码(github.com/scalanlp/breeze)提供了关于 Breeze 的最新文档资源。

第三章. 使用 breeze-viz 进行绘图

数据可视化是数据科学的一个基本组成部分。可视化需求可以分为两大类:在开发和新模型验证期间,以及在管道末尾,从数据和模型中提取意义,为外部利益相关者提供洞察。

这两种可视化类型相当不同。在数据探索和模型开发阶段,可视化库最重要的特性是它的易用性。它应该尽可能少地执行步骤,从拥有数字数组(或 CSV 文件或在数据库中)的数据到在屏幕上显示数据。图表的寿命也相当短:一旦数据科学家从图表或可视化中学到了所有他能学到的知识,它通常就会被丢弃。相比之下,当为外部利益相关者开发可视化小部件时,人们愿意为了更大的灵活性而容忍增加的开发时间。可视化可以具有相当长的寿命,特别是如果底层数据随时间变化的话。

在 Scala 中,用于第一种可视化的首选工具是 breeze-viz。当为外部利益相关者开发可视化时,基于 Web 的可视化(如 D3)和 Tableau 往往更受欢迎。

在本章中,我们将探索 breeze-viz。在第十四章,使用 D3 和 Play Framework 进行可视化中,我们将学习如何为 JavaScript 可视化构建 Scala 后端。

Breeze-viz 是(无需猜测)Breeze 的可视化库。它封装了JFreeChart,一个非常流行的 Java 图表库。Breeze-viz 仍然处于实验阶段。特别是,它比 Python 中的 matplotlib、R 或 MATLAB 的功能要少得多。尽管如此,breeze-viz 允许访问底层的 JFreeChart 对象,因此用户可以始终回退到直接编辑这些对象。breeze-viz 的语法受到了 MATLAB 和 matplotlib 的启发。

深入了解 Breeze

让我们开始吧。我们将在 Scala 控制台中工作,但与这个例子类似的程序可以在本章对应的示例中的BreezeDemo.scala文件中找到。创建一个包含以下行的build.sbt文件:

scalaVersion := "2.11.7"

libraryDependencies ++= Seq(
  "org.scalanlp" %% "breeze" % "0.11.2",
  "org.scalanlp" %% "breeze-viz" % "0.11.2",
  "org.scalanlp" %% "breeze-natives" % "0.11.2"
)

启动sbt控制台:

$ sbt console

scala> import breeze.linalg._
import breeze.linalg._

scala> import breeze.plot._
import breeze.plot._

scala> import breeze.numerics._
import breeze.numerics._

让我们先绘制一个 sigmoid 曲线,深入 Breeze。我们将首先使用 Breeze 生成数据。回想一下,linspace方法创建了一个在两个值之间均匀分布的双精度浮点数向量:

scala> val x = linspace(-4.0, 4.0, 200)
x: DenseVector[Double] = DenseVector(-4.0, -3.959798...

scala> val fx = sigmoid(x)
fx: DenseVector[Double] = DenseVector(0.0179862099620915,...

现在我们已经准备好了用于绘制的数据。第一步是创建一个图形:

scala> val fig = Figure()
fig: breeze.plot.Figure = breeze.plot.Figure@37e36de9

这创建了一个空的 Java Swing 窗口(它可能出现在你的任务栏或等效位置)。一个图形可以包含一个或多个绘图。让我们向我们的图形添加一个绘图:

scala> val plt = fig.subplot(0)
plt: breeze.plot.Plot = breeze.plot.Plot@171c2840

现在,让我们忽略传递给.subplot0作为参数。我们可以向我们的plot添加数据点:

scala> plt += plot(x, fx)
breeze.plot.Plot = breeze.plot.Plot@63d6a0f8

plot函数接受两个参数,对应于要绘制的数据系列的xy值。要查看更改,您需要刷新图形:

scala> fig.refresh()

现在看看 Swing 窗口。你应该看到一个漂亮的 sigmoid 曲线,类似于下面的一个。在窗口上右键单击可以让你与绘图交互并将图像保存为 PNG:

深入 Breeze

您还可以按以下方式程序化地保存图像:

scala> fig.saveas("sigmoid.png")

Breeze-viz 目前仅支持导出为 PNG 格式。

自定义绘图

现在我们图表上有一条曲线。让我们再添加几条:

scala> val f2x = sigmoid(2.0*x)
f2x: breeze.linalg.DenseVector[Double] = DenseVector(3.353501304664E-4...

scala> val f10x = sigmoid(10.0*x)
f10x: breeze.linalg.DenseVector[Double] = DenseVector(4.24835425529E-18...

scala> plt += plot(x, f2x, name="S(2x)")
breeze.plot.Plot = breeze.plot.Plot@63d6a0f8

scala> plt += plot(x, f10x, name="S(10x)")
breeze.plot.Plot = breeze.plot.Plot@63d6a0f8

scala> fig.refresh()

现在查看这个图,你应该看到三种不同颜色的所有三条曲线。注意,我们在添加到绘图时命名了数据系列,使用 name="" 关键字参数。要查看名称,我们必须设置 legend 属性:

scala> plt.legend = true

自定义绘图

我们的绘图还有很多可以改进的地方。让我们首先将 x 轴的范围限制,以移除绘图两侧的空白带:

scala> plt.xlim = (-4.0, 4.0)
plt.xlim: (Double, Double) = (-4.0,4.0)

现在,请注意,虽然 x 刻度合理地分布,但只有两个 y 刻度:在 01 处。每增加 0.1 就有一个刻度将很有用。Breeze 没有提供直接设置此功能的方法。相反,它暴露了当前绘图所属的底层 JFreeChart 轴对象:

scala> plt.yaxis
org.jfree.chart.axis.NumberAxis = org.jfree.chart.axis.NumberAxis@0

Axis 对象支持一个 .setTickUnit 方法,允许我们设置刻度间距:

scala> import org.jfree.chart.axis.NumberTickUnit
import org.jfree.chart.axis.NumberTickUnit

scala> plt.yaxis.setTickUnit(new NumberTickUnit(0.1))

JFreeChart 允许对 Axis 对象进行广泛的定制。有关可用方法的完整列表,请参阅 JFreeChart 文档 (www.jfree.org/jfreechart/api/javadoc/org/jfree/chart/axis/Axis.html)。

让我们在 x=0 处添加一条垂直线,在 f(x)=1 处添加一条水平线。我们需要访问底层 JFreeChart 绘图来添加这些线条。这在我们的 Breeze Plot 对象中作为 .plot 属性(有些令人困惑)可用:

scala> plt.plot
org.jfree.chart.plot.XYPlot = org.jfree.chart.plot.XYPlot@17e4db6c

我们可以使用 .addDomainMarker.addRangeMarker 方法向 JFreeChart XYPlot 对象添加垂直和水平线条:

scala> import org.jfree.chart.plot.ValueMarker
import org.jfree.chart.plot.ValueMarker

scala> plt.plot.addDomainMarker(new ValueMarker(0.0))

scala> plt.plot.addRangeMarker(new ValueMarker(1.0))

让我们也给坐标轴添加标签:

scala> plt.xlabel = "x"
plt.xlabel: String = x

scala> plt.ylabel = "f(x)"
plt.ylabel: String = f(x)

如果你已经运行了所有这些命令,你应该有一个看起来像这样的图形:

自定义绘图

我们现在知道如何自定义图形的基本构建块。下一步是学习如何更改曲线的绘制方式。

自定义线条类型

到目前为止,我们只是使用默认设置绘制了线条。Breeze 允许我们自定义线条的绘制方式,至少在一定程度上。

对于这个例子,我们将使用在 第二章 中讨论的高度-体重数据,使用 Breeze 操作数据。我们将在这里使用 Scala shell 进行演示,但你将在 BreezeDemo.scala 程序中找到一个遵循示例 shell 会话的程序。

本章的代码示例附带一个用于加载数据的模块,HWData.scala,它从 CSV 文件中加载数据:

scala> val data = HWData.load
data: HWData = HWData [ 181 rows ]

scala> data.heights
breeze.linalg.DenseVector[Double] = DenseVector(182.0, ...

scala> data.weights
breeze.linalg.DenseVector[Double] = DenseVector(77.0, 58.0...

让我们创建一个身高与体重的散点图:

scala> val fig = Figure("height vs. weight")
fig: breeze.plot.Figure = breeze.plot.Figure@743f2558

scala> val plt = fig.subplot(0)
plt: breeze.plot.Plot = breeze.plot.Plot@501ea274

scala> plt += plot(data.heights, data.weights, '+',         colorcode="black")
breeze.plot.Plot = breeze.plot.Plot@501ea274

这产生了身高-体重数据的散点图:

自定义线条类型

注意,我们传递了第三个参数给plot方法,'+'。这控制了绘图样式。截至本文写作时,有三种可用的样式:'-'(默认),'+''.'。尝试这些样式以查看它们的作用。最后,我们传递一个colorcode="black"参数来控制线的颜色。这可以是颜色名称或 RGB 三元组,以字符串形式编写。因此,要绘制红色点,我们可以传递colorcode="[255,0,0]"

观察身高体重图,身高和体重之间显然存在趋势。让我们尝试通过数据点拟合一条直线。我们将拟合以下函数:

自定义线型

注意

科学文献表明,拟合类似自定义线型的东西会更好。如果你愿意,应该可以轻松地将二次线拟合到数据上。

我们将使用 Breeze 的最小二乘函数来找到ab的值。leastSquares方法期望一个特征矩阵的输入和一个目标向量,就像我们在上一章中定义的LogisticRegression类一样。回想一下,在第二章中,使用 Breeze 操作数据,当我们为逻辑回归分类准备训练集时,我们引入了一个虚拟特征,每个参与者都是 1,以提供y截距的自由度。我们在这里将使用相同的方法。因此,我们的特征矩阵包含两列——一列在所有地方都是1,另一列是身高:

scala> val features = DenseMatrix.horzcat(
 DenseMatrix.onesDouble,
 data.heights.toDenseMatrix.t
)
features: breeze.linalg.DenseMatrix[Double] =
1.0  182.0
1.0  161.0
1.0  161.0
1.0  177.0
1.0  157.0
...

scala> import breeze.stats.regression._
import breeze.stats.regression._

scala> val leastSquaresResult = leastSquares(features, data.weights)
leastSquaresResult: breeze.stats.regression.LeastSquaresRegressionResult = <function1>

leastSquares方法返回一个LeastSquareRegressionResult实例,它包含一个coefficients属性,包含最佳拟合数据的系数:

scala> leastSquaresResult.coefficients
breeze.linalg.DenseVector[Double] = DenseVector(-131.042322, 1.1521875)

因此,最佳拟合线是:

自定义线型

让我们提取系数。一种优雅的方法是使用 Scala 的模式匹配功能:

scala> val Array(a, b) = leastSquaresResult.coefficients.toArray
a: Double = -131.04232269750622
b: Double = 1.1521875435418725

通过编写val Array(a, b) = ...,我们告诉 Scala 表达式右侧是一个包含两个元素的数组,并将该数组的第一个元素绑定到值a,第二个元素绑定到b。参见附录,模式匹配和提取器,以了解模式匹配的讨论。

我们现在可以将最佳拟合线添加到我们的图中。我们首先生成均匀间隔的虚拟身高值:

scala> val dummyHeights = linspace(min(data.heights), max(data.heights), 200)
dummyHeights: breeze.linalg.DenseVector[Double] = DenseVector(148.0, ...

scala> val fittedWeights = a :+ (b :* dummyHeights)
fittedWeights: breeze.linalg.DenseVector[Double] = DenseVector(39.4814...

scala> plt += plot(dummyHeights, fittedWeights, colorcode="red")
breeze.plot.Plot = breeze.plot.Plot@501ea274

让我们也将最佳拟合线的方程添加到图中作为注释。我们首先生成标签:

scala> val label = f"weight = $a%.4f + $b%.4f * height"
label: String = weight = -131.0423 + 1.1522 * height

要添加注释,我们必须访问底层的 JFreeChart 绘图:

scala> import org.jfree.chart.annotations.XYTextAnnotation
import org.jfree.chart.annotations.XYTextAnnotation

scala> plt.plot.addAnnotation(new XYTextAnnotation(label, 175.0, 105.0))

XYTextAnnotation 构造函数接受三个参数:注释字符串和定义注释在图上中心的 (x, y) 坐标对。注释的坐标以数据的坐标系表示。因此,调用 new XYTextAnnotation(label, 175.0, 105.0) 将生成一个中心在对应 175 厘米高度和 105 公斤重量的点的注释:

自定义线型

更高级的散点图

Breeze-viz 提供了一个 scatter 函数,为散点图添加了显著的定制程度。特别是,我们可以使用标记点的大小和颜色向图中添加额外的信息维度。

scatter 函数接受其前两个参数为 xy 点的集合。第三个参数是一个函数,它将整数 i 映射到表示 第 i 个 点大小的 Double 值。点的大小以 x 轴的单位来衡量。如果你有 Scala 集合或 Breeze 向量的尺寸,你可以使用该集合的 apply 方法作为该函数。让我们看看这在实践中是如何工作的。

与之前的示例一样,我们将使用 REPL,但你可以在 BreezeDemo.scala 中找到一个示例程序:

scala> val fig = new Figure("Advanced scatter example")
fig: breeze.plot.Figure = breeze.plot.Figure@220821bc

scala> val plt = fig.subplot(0)
plt: breeze.plot.Plot = breeze.plot.Plot@668f8ae0

scala> val xs = linspace(0.0, 1.0, 100)
xs: breeze.linalg.DenseVector[Double] = DenseVector(0.0, 0.010101010101010102, 0.0202 ...

scala> val sizes = 0.025 * DenseVector.rand(100) // random sizes
sizes: breeze.linalg.DenseVector[Double] = DenseVector(0.014879265631723166, 0.00219551...

scala> plt += scatter(xs, xs :^ 2.0, sizes.apply)
breeze.plot.Plot = breeze.plot.Plot@668f8ae0

选择自定义颜色的工作方式类似:我们传递一个 colors 参数,它将整数索引映射到 java.awt.Paint 对象。直接使用这些可能比较麻烦,所以 Breeze 提供了一些默认调色板。例如,GradientPaintScale 将给定域内的双精度值映射到均匀的颜色渐变。让我们将范围 0.01.0 的双精度值映射到红色和绿色之间的颜色:

scala> val palette = new GradientPaintScale(
 0.0, 1.0, PaintScale.RedToGreen)
palette: breeze.plot.GradientPaintScale[Double] = <function1>

scala> palette(0.5) // half-way between red and green
java.awt.Paint = java.awt.Color[r=127,g=127,b=0]

scala> palette(1.0) // green
java.awt.Paint = java.awt.Color[r=0,g=254,b=0]

除了 GradientPaintScale,breeze-viz 还提供了一个 CategoricalPaintScale 类用于分类调色板。有关不同调色板的概述,请参阅 PaintScale.scala 源文件,位于 scalagithub.com/scalanlp/breeze/blob/master/viz/src/main/scala/breeze/plot/PaintScale.scala

让我们利用我们新获得的知识来绘制一个多彩散点图。我们将假设与上一个示例相同的初始化。我们将为每个点分配一个随机颜色:

scala> val palette = new GradientPaintScale(0.0, 1.0, PaintScale.MaroonToGold)
palette: breeze.plot.GradientPaintScale[Double] = <function1>

scala> val colors = DenseVector.rand(100).mapValues(palette)
colors: breeze.linalg.DenseVector[java.awt.Paint] = DenseVector(java.awt.Color[r=162,g=5,b=0], ...

scala> plt += scatter(xs, xs :^ 2.0, sizes.apply, colors.apply)
breeze.plot.Plot = breeze.plot.Plot@8ff7e27

更高级的散点图

多图示例 – 散点矩阵图

在本节中,我们将学习如何在同一图中拥有多个图。

允许在同一图中绘制多个图的关键新方法是 fig.subplot(nrows, ncols, plotIndex)。这个方法是我们迄今为止一直在使用的 fig.subplot 方法的重载版本,它既设置了图中的行数和列数,又返回一个特定的子图。它接受三个参数:

  • nrows:图中子图行数

  • ncols:图中子图列数

  • plotIndex:要返回的图的索引

熟悉 MATLAB 或 matplotlib 的用户会注意到.subplot方法与这些框架中的同名方法相同。这可能会显得有点复杂,所以让我们看一个例子(你可以在BreezeDemo.scala中找到这个代码):

import breeze.plot._

def subplotExample {
  val data = HWData.load
  val fig = new Figure("Subplot example")

  // upper subplot: plot index '0' refers to the first plot
  var plt = fig.subplot(2, 1, 0)
  plt += plot(data.heights, data.weights, '.')

  // lower subplot: plot index '1' refers to the second plot
  plt = fig.subplot(2, 1, 1)
  plt += plot(data.heights, data.reportedHeights, '.', colorcode="black")

  fig.refresh
}

运行此示例将生成以下图表:

多图示例 – 散点图矩阵图表

现在我们已经基本掌握了如何在同一张图上添加多个子图的方法,让我们做一些更有趣的事情。我们将编写一个类来绘制散点图矩阵。这些对于探索不同特征之间的相关性非常有用。

如果你不太熟悉散点图矩阵,请查看本节末尾的图表,以了解我们正在构建的内容。想法是为每对特征构建一个散点图的正方形矩阵。矩阵中的元素(ij)是特征i与特征j的散点图。由于一个变量与其自身的散点图用途有限,通常会在对角线上绘制每个特征的直方图。最后,由于特征i与特征j的散点图包含与特征j与特征i的散点图相同的信息,通常只绘制矩阵的上三角或下三角。

让我们先编写用于单个图表的函数。这些函数将接受一个Plot对象,它引用正确的子图,以及要绘制的数据的向量:

import breeze.plot._
import breeze.linalg._

class ScatterplotMatrix(val fig:Figure) {

  /** Draw the histograms on the diagonal */
  private def plotHistogram(plt:Plot)(
  data:DenseVector[Double], label:String) {
     plt += hist(data)
     plt.xlabel = label
  }

  /** Draw the off-diagonal scatter plots */
  private def plotScatter(plt:Plot)(
    xdata:DenseVector[Double],
    ydata:DenseVector[Double],
    xlabel:String,
    ylabel:String) {
      plt += plot(xdata, ydata, '.')
      plt.xlabel = xlabel
      plt.ylabel = ylabel
  }

...

注意到使用了hist(data)来绘制直方图。hist函数的参数必须是一个数据点的向量。hist方法会将这些数据点分箱,并以直方图的形式表示它们。

现在我们已经有了绘制单个图表的机制,我们只需要将所有东西连接起来。难点是知道如何根据矩阵中的给定行和列位置选择正确的子图。我们可以通过调用fig.subplot(nrows, ncolumns, plotIndex)来选择单个图表,但将()索引对转换为单个plotIndex并不明显。图表按顺序编号,首先是左到右,然后是上到下:

0 1 2 3
4 5 6 7
...

让我们编写一个简短的功能来选择一个在()索引对上的图表:

  private def selectPlot(ncols:Int)(irow:Int, icol:Int):Plot = {
    fig.subplot(ncols, ncols, (irow)*ncols + icol)
  }

现在我们已经可以绘制矩阵图本身了:

  /** Draw a scatterplot matrix.
    *
    * This function draws a scatterplot matrix of the correlation
    * between each pair of columns in `featureMatrix`.
    *
    * @param featureMatrix A matrix of features, with each column
    *   representing a feature.
    * @param labels Names of the features.
    */
  def plotFeatures(featureMatrix:DenseMatrix[Double], labels:List[String]) {
    val ncols = featureMatrix.cols
    require(ncols == labels.size,
      "Number of columns in feature matrix "+ "must match length of labels"
    )
    fig.clear
    fig.subplot(ncols, ncols, 0)

    (0 until ncols) foreach { irow =>
      val p = selectPlot(ncols)(irow, irow)
      plotHistogram(p)(featureMatrix(::, irow), labels(irow))

      (0 until irow) foreach { icol =>
        val p = selectPlot(ncols)(irow, icol)
        plotScatter(p)(
          featureMatrix(::, irow),
          featureMatrix(::, icol),
          labels(irow),
          labels(icol)
        )
      }
    }
  }
}

让我们为我们的类编写一个示例。我们将再次使用身高体重数据:

import breeze.linalg._
import breeze.numerics._
import breeze.plot._

object ScatterplotMatrixDemo extends App {

  val data = HWData.load
  val m = new ScatterplotMatrix(Figure("Scatterplot matrix demo"))

  // Make a matrix with three columns: the height, weight and
  // reported weight data.
  val featureMatrix = DenseMatrix.horzcat(
    data.heights.toDenseMatrix.t,
    data.weights.toDenseMatrix.t,
    data.reportedWeights.toDenseMatrix.t
  )
  m.plotFeatures(featureMatrix,List("height", "weight", "reportedWeights"))

}

在 SBT 中运行此代码将生成以下图表:

多图示例 – 散点图矩阵图表

无文档管理

Breeze-viz 的文档不幸地相当不完善。这可能会使学习曲线变得有些陡峭。幸运的是,它仍然是一个非常小的项目:在撰写本文时,只有十个源文件(github.com/scalanlp/breeze/tree/master/viz/src/main/scala/breeze/plot)。了解 breeze-viz 确切做什么的一个好方法是阅读源代码。例如,要查看 Plot 对象上可用的方法,请阅读源文件 Plot.scala。如果您需要 Breeze 提供的功能之外的功能,请查阅 JFreeChart 的文档,以了解您是否可以通过访问底层的 JFreeChart 对象来实现所需的功能。

Breeze-viz 参考文档

在编程书中编写参考是一项危险的活动:您很快就会过时。尽管如此,鉴于 breeze-viz 的文档很少,这一部分变得更加相关——与不存在的东西竞争更容易。请带着一点盐来接受这一部分,如果本节中的某个命令不起作用,请查看源代码:

命令 描述
plt += plot(xs, ys) 这将一系列 (xs, ys) 值绘制出来。xsys 值必须是类似集合的对象(例如 Breeze 向量、Scala 数组或列表)。
plt += scatter(xs, ys, size) plt += scatter(xs, ys, size, color)
plt += hist(xs) plt += hist(xs, bins=10)
plt += image(mat) 这将图像或矩阵绘制出来。mat 参数应该是 Matrix[Double]。有关详细信息,请阅读 breeze.plot 中的 package.scala 源文件(github.com/scalanlp/breeze/blob/master/viz/src/main/scala/breeze/plot/package.scala)。

总结 plot 对象上可用的选项也是有用的:

属性 描述
plt.xlabel = "x-label" plt.ylabel = "y-label"
plt.xlim = (0.0, 1.0) plt.ylim = (0.0, 1.0)
plt.logScaleX = true plt.logScaleY = true
plt.title = "title" 这设置绘图标题

Breeze-viz 之外的数据可视化

Scala 中用于数据可视化的其他工具正在出现:基于 IPython 笔记本的 Spark 笔记本(github.com/andypetrella/spark-notebook#description)和 Apache Zeppelin(zeppelin.incubator.apache.org)。这两个都依赖于 Apache Spark,我们将在本书的后面部分探讨。

摘要

在本章中,我们学习了如何使用 breeze-viz 绘制简单的图表。在本书的最后一章中,我们将学习如何使用 JavaScript 库构建交互式可视化。

接下来,我们将学习关于基本 Scala 并发构造的知识——特别是并行集合。

第四章。并行集合和 Futures

数据科学通常涉及处理中等或大量数据。由于之前单个 CPU 速度的指数增长已经放缓,而数据量仍在增加,因此有效地利用计算机必须涉及并行计算。

在本章中,我们将探讨在单台计算机上并行化计算和数据处理的方法。几乎所有的计算机都有多个处理单元,将这些核心的计算分布在这些核心上可以是一种加快中等规模计算的有效方式。

在单个芯片上并行化计算适合涉及千兆字节或几太字节数据的计算。对于更大的数据流,我们必须求助于将计算并行分布在多台计算机上。我们将在第十章,使用 Spark 的分布式批量处理中讨论 Apache Spark,这是一个并行数据处理框架。

在这本书中,我们将探讨在单台机器上利用并行架构的三个常见方法:并行集合、futures 和 actors。我们将在本章中考虑前两种,并将 actors 的研究留给第九章,使用 Akka 的并发

并行集合

并行集合提供了一种极其简单的方式来并行化独立任务。由于读者熟悉 Scala,他们将知道许多任务可以表述为集合上的操作,如mapreducefiltergroupBy。并行集合是 Scala 集合的一种实现,它将这些操作并行化以在多个线程上运行。

让我们从例子开始。我们想要计算句子中每个字母出现的频率:

scala> val sentence = "The quick brown fox jumped over the lazy dog"
sentence: String = The quick brown fox jumped ...

首先,让我们将我们的句子从字符串转换为字符向量:

scala> val characters = sentence.toVector
Vector[Char] = Vector(T, h, e,  , q, u, i, c, k, ...)

现在,我们可以将characters转换为并行向量,即ParVector。为此,我们使用par方法:

scala> val charactersPar = characters.par
ParVector[Char] = ParVector(T, h, e,  , q, u, i, c, k,  , ...)

ParVector集合支持与常规向量相同的操作,但它们的方法是在多个线程上并行执行的。

首先,让我们从charactersPar中过滤掉空格:

scala> val lettersPar = charactersPar.filter { _ != ' ' }
ParVector[Char] = ParVector(T, h, e, q, u, i, c, k, ...)

注意 Scala 如何隐藏执行细节。filter操作使用了多个线程,而你几乎都没有注意到!并行向量的接口和行为与其串行对应物相同,除了我们将在下一节中探讨的一些细节。

现在我们使用toLower函数将字母转换为小写:

scala> val lowerLettersPar = lettersPar.map { _.toLower }
ParVector[Char] = ParVector(t, h, e, q, u, i, c, k, ...)

与之前一样,map方法是在并行中应用的。为了找到每个字母的出现频率,我们使用groupBy方法将字符分组到包含该字符所有出现的向量中:

scala> val intermediateMap = lowerLettersPar.groupBy(identity)
ParMap[Char,ParVector[Char]] = ParMap(e -> ParVector(e, e, e, e), ...)

注意groupBy方法创建了一个ParMap实例,它是不可变映射的并行等价物。为了得到每个字母出现的次数,我们在intermediateMap上执行mapValues调用,将每个向量替换为其长度:

scala> val occurenceNumber = intermediateMap.mapValues { _.length }
ParMap[Char,Int] = ParMap(e -> 4, x -> 1, n -> 1, j -> 1, ...)

恭喜!我们用几行代码编写了一个多线程算法,用于查找每个字母在文本中出现的频率。你应该会发现将其修改为查找文档中每个单词的出现频率非常简单,这是分析文本数据时常见的预处理问题。

并行集合使得并行化某些操作流程变得非常容易:我们只需在characters向量上调用.par。所有后续操作都被并行化。这使得从串行实现切换到并行实现变得非常容易。

并行集合的限制

并行集合的部分力量和吸引力在于它们提供了与它们的串行对应物相同的接口:它们有map方法、foreach方法、filter方法等等。总的来说,这些方法在并行集合上的工作方式与在串行集合上相同。然而,也有一些值得注意的注意事项。其中最重要的一点与副作用有关。如果在并行集合上的操作有副作用,这可能会导致竞争条件:一种最终结果取决于线程执行操作顺序的情况。

集合中的副作用最常见的情况是我们更新集合外部定义的变量。为了给出一个意外的行为的不平凡例子,让我们定义一个count变量,并使用并行范围将其增加一千次:

scala> var count = 0
count: Int = 0

scala> (0 until 1000).par.foreach { i => count += 1 }

scala> count
count: Int = 874 // not 1000!

这里发生了什么?传递给foreach的函数有一个副作用:它增加count,这是一个在函数作用域之外的变量。这是一个问题,因为+=运算符是一系列两个操作:

  • 获取count的值并将其加一

  • 将结果赋值回count

要理解这为什么会引起意外的行为,让我们想象foreach循环已经被并行化到两个线程上。线程 A可能在变量为832时读取count变量,并将其加一得到833。在它有时间将833重新赋值给count之前,线程 B读取count,仍然为832,然后加一得到833。然后线程 A833赋值给count。接着线程 B也将833赋值给count。我们进行了两次更新,但只增加了计数一次。问题出现是因为+=可以被分解成两个指令:它不是原子性的。这为线程提供了交错操作的空间:

并行集合的限制

竞态条件的解剖:线程 A 和线程 B 都在尝试并发更新count,导致其中一个更新被覆盖。count的最终值是 833 而不是 834。

为了给出一个稍微更现实的例子,说明非原子性引起的问题,让我们看看一种不同的方法来计算句子中每个字母出现的频率。我们在循环外部定义一个可变的Char -> Int哈希表。每次我们遇到一个字母时,我们就在映射中增加相应的整数:

scala> import scala.collection.mutable
import scala.collection.mutable

scala> val occurenceNumber = mutable.Map.empty[Char, Int]
occurenceNumber: mutable.Map[Char,Int] = Map()

scala> lowerLettersPar.foreach { c => 
 occurenceNumber(c) = occurenceNumber.getOrElse(c, 0) + 1
}

scala> occurenceNumber('e') // Should be 4
Int = 2

这种差异是由于foreach循环中操作的非原子性造成的。

通常,避免在集合的高阶函数中引入副作用是一个好习惯。它们使得代码更难以理解,并阻止从串行集合切换到并行集合。避免暴露可变状态也是一个好习惯:不可变对象可以在线程之间自由共享,并且不会受到副作用的影响。

并行集合的另一个限制出现在归约(或折叠)操作中。用于组合项的功能必须是结合律的。例如:

scala> (0 until 1000).par.reduce {_ - _ } // should be -499500
Int = 63620

减号运算符,,不是结合律的。连续操作应用的顺序很重要:(a – b) – c不等于a – (b – c)。用于归约并行集合的功能必须是结合律的,因为归约发生的顺序与集合的顺序无关。

错误处理

在单线程程序中,异常处理相对简单:如果发生异常,函数可以处理它或将其升级。当引入并行性时,这并不那么明显:一个线程可能会失败,但其他线程可能仍然成功返回。

如果并行集合方法在任何一个元素上失败,它们将抛出异常,就像它们的串行对应物一样:

scala> Vector(2, 0, 5).par.map { 10 / _ }
java.lang.ArithmeticException: / by zero
...

有时候我们并不希望这种行为。例如,我们可能正在使用并行集合并行检索大量网页。我们可能不介意如果有一些页面无法获取。

Scala 的Try类型是为了沙盒化可能会抛出异常的代码而设计的。它与Option类似,因为它是一个单元素容器:

scala> import scala.util._
import scala.util._

scala> Try { 2 + 2 }
Try[Int] = Success(4)

与表示表达式是否有有用值的 Option 类型不同,Try 类型表示表达式是否可以在不抛出异常的情况下执行。它具有以下两个值:

  • 尝试 { 2 + 2 } == Success(4) 如果 Try 语句中的表达式成功评估

  • 尝试 { 2 / 0 } == Failure(java.lang.ArithmeticException: / by zero) 如果 Try 块中的表达式导致异常

通过一个例子来解释会更清晰。为了看到 Try 类型的实际应用,我们将尝试以容错的方式获取网页。我们将使用内置的 Source.fromURL 方法来获取网页并打开页面内容的迭代器。如果它无法获取网页,它将抛出一个错误:

scala> import scala.io.Source
import scala.io.Source

scala> val html = Source.fromURL("http://www.google.com")
scala.io.BufferedSource = non-empty iterator

scala> val html = Source.fromURL("garbage")
java.net.MalformedURLException: no protocol: garbage
...

而不是让表达式传播出去并使我们的代码崩溃,我们可以将 Source.fromURL 的调用包裹在 Try 中:

scala> Try { Source.fromURL("http://www.google.com") }
Try[BufferedSource] = Success(non-empty iterator)

scala> Try { Source.fromURL("garbage") }
Try[BufferedSource] = Failure(java.net.MalformedURLException: no protocol: garbage)

为了看到我们的 Try 语句的力量,现在让我们以容错的方式并行检索 URL 列表:

scala> val URLs = Vector("http://www.google.com", 
 "http://www.bbc.co.uk",
 "not-a-url"
)
URLs: Vector[String] = Vector(http://www.google.com, http://www.bbc.co.uk, not-a-url)

scala> val pages = URLs.par.map { url =>
 url -> Try { Source.fromURL(url) } 
}
pages: ParVector[(String, Try[BufferedSource])] = ParVector((http://www.google.com,Success(non-empty iterator)), (http://www.bbc.co.uk,Success(non-empty iterator)), (not-a-url,Failure(java.net.MalformedURLException: no protocol: not-a-url)))

然后,我们可以使用 collect 语句对成功获取的页面进行操作。例如,获取每个页面的字符数:

scala> pages.collect { case(url, Success(it)) => url -> it.size }
ParVector[(String, Int)] = ParVector((http://www.google.com,18976), (http://www.bbc.co.uk,132893))

通过充分利用 Scala 的内置 Try 类和并行集合,我们只用几行代码就构建了一个容错、多线程的 URL 检索器。(与那些在代码示例前加上 '为了清晰起见省略错误处理' 的 Java/C++ 书籍相比。)

小贴士

Try 类型与 try/catch 语句的比较

具有命令式或面向对象背景的程序员会更熟悉用于处理异常的 try/catch 块。我们可以通过将获取 URL 的代码包裹在 try 块中,在调用引发异常时返回 null 来实现类似的功能。

然而,除了更冗长之外,返回 null 的效果也不太令人满意:我们失去了有关异常的所有信息,而 null 的表达性不如 Failure(exception)。此外,返回 Try[T] 类型迫使调用者考虑函数可能失败的可能性,通过将这种可能性编码在返回值的类型中。相比之下,只返回 T 并用 null 值编码失败,允许调用者忽略失败,从而增加了在程序中完全不同的点上抛出令人困惑的 NullPointerException 的可能性。

简而言之,Try[T] 只是一种更高阶的类型,就像 Option[T]List[T] 一样。将失败的可能性以与代码中其他部分相同的方式处理,增加了程序的一致性,并鼓励程序员明确处理异常的可能性。

设置并行级别

到目前为止,我们将并行集合视为黑盒:将 par 添加到普通集合中,所有操作都会并行执行。通常,我们可能希望对任务执行的细节有更多的控制。

内部,并行集合通过在多个线程上分配操作来工作。由于线程共享内存,并行集合不需要复制任何数据。更改并行集合可用的线程数将更改用于执行任务的 CPU 数量。

并行集合有一个tasksupport属性,用于控制任务执行:

scala> val parRange = (0 to 100).par
parRange: ParRange = ParRange(0, 1, 2, 3, 4, 5,...

scala> parRange.tasksupport
TaskSupport = scala.collection.parallel.ExecutionContextTaskSupport@311a0b3e

scala> parRange.tasksupport.parallelismLevel
Int = 8 // Number of threads to be used

集合的任务支持对象是一个执行上下文,这是一个能够在一个单独的线程中执行 Scala 表达式的抽象。默认情况下,Scala 2.11 中的执行上下文是一个工作窃取线程池。当一个并行集合提交任务时,上下文将这些任务分配给其线程。如果一个线程发现它已经完成了其队列中的任务,它将尝试从其他线程中窃取未完成的任务。默认执行上下文维护一个线程池,线程数等于 CPU 的数量。

并行集合分配工作的线程数可以通过更改任务支持来改变。例如,为了通过四个线程并行化由范围执行的操作:

scala> import scala.collection.parallel._
import scala.collection.parallel._

scala> parRange.tasksupport = new ForkJoinTaskSupport(
 new scala.concurrent.forkjoin.ForkJoinPool(4)
)
parRange.tasksupport: scala.collection.parallel.TaskSupport = scala.collection.parallel.ForkJoinTaskSupport@6e1134e1

scala> parRange.tasksupport.parallelismLevel
Int: 4

例子 - 使用并行集合的交叉验证

让我们将你到目前为止学到的知识应用到解决数据科学问题中。机器学习管道的许多部分可以轻易并行化。其中之一就是交叉验证。

我们在这里将简要介绍交叉验证,但你可以参考《统计学习的要素》,由哈斯蒂蒂布斯哈尼弗里德曼撰写,以获得更深入的讨论。

通常,一个监督式机器学习问题涉及在一个训练集上训练一个算法。例如,当我们构建一个基于身高和体重计算一个人性别概率的模型时,训练集是每个参与者的(身高,体重)数据,以及每行的男/女标签。一旦算法在训练集上训练完成,我们就可以用它来分类新的数据。这个过程只有在训练集能够代表我们可能遇到的新数据时才有意义。

训练集有有限数量的条目。因此,不可避免地,它将具有一些不是代表整个人群的独特性,仅仅是因为其有限性。这些独特性将在预测新人是男性还是女性时导致预测错误,这超出了算法在训练集本身上的预测错误。交叉验证是估计由不反映整个人群的训练集独特性引起的错误的工具。

交叉验证通过将训练集分成两部分:一个较小的、新的训练集和一个交叉验证集来工作。算法在减少的训练集上训练。然后我们看看算法对交叉验证集的建模效果如何。由于我们知道交叉验证集的正确答案,我们可以衡量当展示新信息时我们的算法表现得多好。我们重复这个程序多次,使用不同的交叉验证集。

有几种不同的交叉验证类型,它们在如何选择交叉验证集方面有所不同。在本章中,我们将探讨重复随机子采样:我们从训练数据中随机选择k行来形成交叉验证集。我们这样做很多次,计算每个子样本的交叉验证误差。由于每个迭代都是相互独立的,我们可以简单地并行化这个过程。因此,它是一个并行集合的良好候选者。我们将在第十二章中查看交叉验证的另一种形式,即k折交叉验证使用 MLlib 进行分布式机器学习*。

我们将构建一个执行并行交叉验证的类。我鼓励你在编写代码的同时进行,但你将在 GitHub 上找到与这些示例对应的源代码(github.com/pbugnion/s4ds)。我们将使用并行集合来处理并行性,并在内部循环中使用 Breeze 数据类型。build.sbt文件与我们在第二章中使用的相同,即使用 Breeze 操作数据

scalaVersion := "2.11.7"

libraryDependencies ++= Seq(
 "org.scalanlp" %% "breeze" % "0.11.2",
 "org.scalanlp" %% "breeze-natives" % "0.11.2"
)

我们将构建一个RandomSubsample类。该类公开了一个类型别名CVFunction,用于一个接受两个索引列表的函数——第一个对应于减少的训练集,第二个对应于验证集——并返回一个与交叉验证误差相对应的Double值:

type CVFunction = (Seq[Int], Seq[Int]) => Double

RandomSubsample类将公开一个名为mapSamples的单个方法,它接受一个CVFunction,重复传递不同的索引分区,并返回一个错误向量。这个类看起来是这样的:

// RandomSubsample.scala

import breeze.linalg._
import breeze.numerics._

/** Random subsample cross-validation
  * 
  * @param nElems Total number of elements in the training set.
  * @param nCrossValidation Number of elements to leave out of training set.
*/
class RandomSubsample(val nElems:Int, val nCrossValidation:Int) {

  type CVFunction = (Seq[Int], Seq[Int]) => Double

  require(nElems > nCrossValidation,
    "nCrossValidation, the number of elements " +
    "withheld, must be < nElems")

  private val indexList = DenseVector.range(0, nElems)

  /** Perform multiple random sub-sample CV runs on f
    *
    * @param nShuffles Number of random sub-sample runs.
    * @param f user-defined function mapping from a list of
    *   indices in the training set and a list of indices in the
    *   test-set to a double indicating the out-of sample score
    *   for this split.
    * @returns DenseVector of the CV error for each random split.
    */
  def mapSamples(nShuffles:Int)(f:CVFunction)
  :DenseVector[Double] = {
    val cvResults = (0 to nShuffles).par.map { i =>

      // Randomly split indices between test and training
      val shuffledIndices = breeze.linalg.shuffle(indexList)
      val Seq(testIndices, trainingIndices) =
        split(shuffledIndices, Seq(nCrossValidation))

 // Apply f for this split
      f(trainingIndices.toScalaVector, 
        testIndices.toScalaVector)
    }
    DenseVector(cvResults.toArray)
  }
}

让我们更详细地看看发生了什么,从传递给构造函数的参数开始:

class RandomSubsample(val nElems:Int, val nCrossValidation:Int)

我们在类构造函数中传递训练集中元素的总数和为交叉验证留出的元素数。因此,将 100 传递给nElems,将 20 传递给nCrossValidation意味着我们的训练集将包含 80 个随机元素的总数据,而测试集将包含 20 个元素。

然后我们构建一个介于0nElems之间的所有整数的列表:

private val indexList = DenseVector.range(0, nElems)

对于交叉验证的每一次迭代,我们将对这个列表进行洗牌,并取前nCrossValidation个元素作为测试集行索引,剩余的作为训练集行索引。

我们公开了一个单一的方法mapSamples,它接受两个柯里化参数:nShuffles,表示执行随机子采样的次数,以及f,一个CVFunction

  def mapSamples(nShuffles:Int)(f:CVFunction):DenseVector[Double] 

在设置好所有这些之后,进行交叉验证的代码看起来很简单。我们生成一个从0nShuffles的并行范围,并对范围中的每个项目,生成一个新的训练-测试分割并计算交叉验证错误:

    val cvResults = (0 to nShuffles).par.map { i =>
      val shuffledIndices = breeze.linalg.shuffle(indexList)
      val Seq(testIndices, trainingIndices) = 
        split(shuffledIndices, Seq(nCrossValidation))
      f(trainingIndices.toScalaVector, testIndices.toScalaVector)
    }

这个函数的唯一难点是将打乱后的索引列表分割成训练集索引列表和测试集索引列表。我们使用 Breeze 的split方法。这个方法接受一个向量作为其第一个参数,一个分割点列表作为其第二个参数,并返回原始向量的片段列表。然后我们使用模式匹配来提取各个部分。

最后,mapSamplescvResults转换为 Breeze 向量:

DenseVector(cvResults.toArray) 

让我们看看它是如何工作的。我们可以通过在第二章中开发的逻辑回归示例上运行交叉验证来测试我们的类,使用 Breeze 操作数据。在该章中,我们开发了一个LogisticRegression类,它在构造时接受一个训练集(以DenseMatrix的形式)和一个目标(以DenseVector的形式)。然后该类计算最能代表训练集的参数。我们首先将向LogisticRegression类添加两个方法,以使用训练好的模型对先前未见过的示例进行分类:

  • predictProbabilitiesMany方法使用训练好的模型来计算目标变量被设置为 1 的概率。在我们的示例中,这是给定身高和体重为男性的概率。

  • classifyMany方法将分类标签(一个或零)分配给测试集的成员。如果predictProbabilitiesMany返回的值大于0.5,我们将分配一个一。

使用这两个函数,我们的LogisticRegression类变为:

// Logistic Regression.scala

class LogisticRegression(
  val training:DenseMatrix[Double],
  val target:DenseVector[Double]
) {
  ...
  /** Probability of classification for each row
    * in test set.
    */
  def predictProbabilitiesMany(test:DenseMatrix[Double])
  :DenseVector[Double] = {
    val xBeta = test * optimalCoefficients
    sigmoid(xBeta)
  }

  /** Predict the value of the target variable 
    * for each row in test set.
    */
  def classifyMany(test:DenseMatrix[Double])
  :DenseVector[Double] = {
    val probabilities = predictProbabilitiesMany(test)
    I((probabilities :> 0.5).toDenseVector)
  }
  ...
}

现在,我们可以为我们的RandomSubsample类编写一个示例程序。我们将使用与第二章中相同的身高-体重数据,使用 Breeze 操作数据。数据预处理将类似。本章的代码示例提供了一个辅助模块HWData,用于将身高-体重数据加载到 Breeze 向量中。数据本身位于本章代码示例的data/目录中(可在 GitHub 上找到github.com/pbugnion/s4ds/tree/master/chap04)。

对于每个新的子样本,我们创建一个新的LogisticRegression实例,在训练集的子集上对其进行训练,以获得此训练-测试分割的最佳系数,并使用classifyMany在此分割的交叉验证集上生成预测。然后我们计算分类错误并报告每个训练-测试分割的平均分类错误:

// RandomSubsampleDemo.scala

import breeze.linalg._
import breeze.linalg.functions.manhattanDistance
import breeze.numerics._
import breeze.stats._

object RandomSubsampleDemo extends App {

  /* Load and pre-process data */
  val data = HWData.load

  val rescaledHeights:DenseVector[Double] =
    (data.heights - mean(data.heights)) / stddev(data.heights)

  val rescaledWeights:DenseVector[Double] =
    (data.weights - mean(data.weights)) / stddev(data.weights)

  val featureMatrix:DenseMatrix[Double] =
    DenseMatrix.horzcat(
      DenseMatrix.onesDouble,
      rescaledHeights.toDenseMatrix.t,
      rescaledWeights.toDenseMatrix.t
    )

  val target:DenseVector[Double] = data.genders.values.map { 
    gender => if(gender == 'M') 1.0 else 0.0 
  }

  /* Cross-validation */
  val testSize = 20
  val cvCalculator = new RandomSubsample(data.npoints, testSize)

  // Start parallel CV loop
  val cvErrors = cvCalculator.mapSamples(1000) { 
    (trainingIndices, testIndices) =>

    val regressor = new LogisticRegression(
      data.featureMatrix(trainingIndices, ::).toDenseMatrix,
      data.target(trainingIndices).toDenseVector
    )
    // Predictions on test-set
    val genderPredictions = regressor.classifyMany(
      data.featureMatrix(testIndices, ::).toDenseMatrix
    )
    // Calculate number of mis-classified examples
    val dist = manhattanDistance(
      genderPredictions, data.target(testIndices)
    )
    // Calculate mis-classification rate
    dist / testSize.toDouble
  }

  println(s"Mean classification error: ${mean(cvErrors)}")
}

在高度-体重数据上运行此程序给出 10%的分类错误。

现在我们有一个完全工作、并行化的交叉验证类。Scala 的并行范围使得在不同的线程中重复计算相同的函数变得简单。

未来

并行集合提供了一个简单而强大的并行操作框架。然而,在一方面它们是有限的:必须预先知道总工作量,并且每个线程必须执行相同的函数(可能在不同的输入上)。

假设我们想要编写一个程序,每隔几秒从网页(或查询一个 Web API)中获取数据,并从该网页中提取数据以进行进一步处理。一个典型的例子可能涉及查询 Web API 以维护特定股票价格的最新值。从外部网页获取数据通常需要几百毫秒。如果我们在这个主线程上执行此操作,它将无谓地浪费 CPU 周期等待 Web 服务器回复。

解决方案是将获取网页的代码包装在一个未来中。未来是一个包含计算未来结果的单一元素容器。当你创建一个未来时,其中的计算会被卸载到不同的线程,以避免阻塞主线程。当计算完成时,结果会被写入未来,从而使其对主线程可访问。

作为例子,我们将编写一个程序,查询“Markit on demand”API 以获取给定股票的价格。例如,谷歌当前股价的 URL 是dev.markitondemand.com/MODApis/Api/v2/Quote?symbol=GOOG。请将此粘贴到您网络浏览器的地址框中。您将看到一个 XML 字符串出现,其中包含当前股价等信息。让我们先不使用未来来以编程方式获取这个信息:

scala> import scala.io._
import scala.io_

scala> val url = "http://dev.markitondemand.com/MODApis/Api/v2/Quote?symbol=GOOG"
url: String = http://dev.markitondemand.com/MODApis/Api/v2/Quote?symbol=GOOG

scala> val response = Source.fromURL(url).mkString
response: String = <StockQuote><Status>SUCCESS</Status>
...

注意查询 API 需要一点时间。现在让我们做同样的事情,但使用一个未来(现在不用担心导入,我们将在后面详细讨论它们的含义):

scala> import scala.concurrent._
import scala.concurrent._

scala> import scala.util._
import scala.util._

scala> import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.ExecutionContext.Implicits.global

scala> val response = Future { Source.fromURL(url).mkString }
response: Future[String] = Promise$DefaultPromise@3301801b

如果你运行这个程序,你会在 API 有机会响应之前立即返回控制台。为了使这一点更加明显,让我们通过添加Thread.sleep调用来模拟一个慢速连接:

scala> val response = Future { 
 Thread.sleep(10000) // sleep for 10s
 Source.fromURL(url).mkString
}
response: Future[String] = Promise$DefaultPromise@231f98ef

当你运行这个程序时,你不需要等待十秒钟才能出现下一个提示符:你立即重新获得控制台的控制权。未来中的代码是异步执行的:它的执行与主程序流程独立。

我们如何检索计算的成果?我们注意到response的类型是Future[String]。我们可以通过查询未来的isCompleted属性来检查未来中包装的计算是否完成:

scala> response.isCompleted
Boolean = true

未来暴露了一个包含计算结果的value属性:

scala> response.value
Option[Try[String]] = Some(Success(<StockQuote><Status>SUCCESS</Status>
...

Future 的 value 属性类型为 Option[Try[T]]。我们已经在并行集合的上下文中看到了如何使用 Try 类型优雅地处理异常。在这里也是同样的用法。一个 future 的 value 属性在 future 完成之前是 None,然后它被设置为 Some(Success(value)) 如果 future 成功运行,或者设置为 Some(Failure(error)) 如果抛出了异常。

在 shell 中重复调用 f.value 直到 future 完成,效果很好,但这并不能推广到更复杂的程序中。相反,我们希望告诉计算机一旦 future 完成,就执行某些操作:我们希望将 回调 函数绑定到 future。我们可以通过设置 future 的 onComplete 属性来实现这一点。让我们告诉 future 在完成时打印 API 响应:

scala> response.onComplete {
 case Success(s) => println(s)
 case Failure(e) => println(s"Error fetching page: $e")
}

scala> 
// Wait for response to complete, then prints:
<StockQuote><Status>SUCCESS</Status><Name>Alphabet Inc</Name><Symbol>GOOGL</Symbol><LastPrice>695.22</LastPrice><Chan...

传递给 onComplete 的函数在 future 完成时运行。它接受一个类型为 Try[T] 的单个参数,包含 future 的结果。

小贴士

失败是正常的:如何构建弹性应用程序

通过将运行代码的输出包装在 Try 类型中,futures 强制客户端代码考虑代码可能失败的可能性。客户端可以隔离失败的影响,避免整个应用程序崩溃。例如,他们可能会记录异常。在 Web API 查询的情况下,他们可能会将受影响的 URL 添加到稍后再次查询的列表中。在数据库失败的情况下,他们可能会回滚事务。

通过将失败视为一等公民,而不是通过在末尾附加异常控制流来处理,我们可以构建出更加弹性的应用程序。

Future 组合 – 使用 future 的结果

在上一节中,你学习了关于 onComplete 方法将回调绑定到 future 的内容。这在 future 完成时引发副作用非常有用。然而,它并不能让我们轻松地转换 future 的返回值。

为了继续我们的股票示例,让我们想象我们想要将查询响应从字符串转换为 XML 对象。让我们首先在 build.sbt 中将 scala-xml 库作为依赖项包含进来:

libraryDependencies += "org.scala-lang" % "scala-xml" % "2.11.0-M4"

让我们重新启动控制台并重新导入 scala.concurrent._scala.concurrent.ExecutionContext.Implicits.globalscala.io._ 这三个依赖。我们还想导入 XML 库:

scala> import scala.xml.XML
import scala.xml.XML

我们将使用与上一节相同的 URL:

dev.markitondemand.com/MODApis/Api/v2/Quote?symbol=GOOG

有时将 future 视为一个集合是有用的,如果计算成功,则包含一个元素;如果失败,则包含零个元素。例如,如果 Web API 已成功查询,我们的 future 包含响应的字符串表示。像 Scala 中的其他容器类型一样,futures 支持一个 map 方法,该方法将函数应用于 future 中包含的元素,返回一个新的 future,如果 future 中的计算失败则不执行任何操作。但在计算可能尚未完成的情况下,这意味着什么呢?map 方法在 future 完成时立即应用,就像 onComplete 方法一样。

我们可以使用 futuremap 方法异步地对 future 的结果应用转换。让我们再次轮询 "Markit on demand" API。这次,我们不会打印结果,而是将其解析为 XML。

scala> val strResponse = Future { 
 Thread.sleep(20000) // Sleep for 20s
 val res = Source.fromURL(url).mkString
 println("finished fetching url")
 res
}
strResponse: Future[String] = Promise$DefaultPromise@1dda9bc8

scala> val xmlResponse = strResponse.map { s =>
 println("applying string to xml transformation")
 XML.loadString(s) 
}
xmlResponse: Future[xml.Elem] = Promise$DefaultPromise@25d1262a

// wait while the remainder of the 20s elapses
finished fetching url
applying string to xml transformation

scala> xmlResponse.value
Option[Try[xml.Elem]] = Some(Success(<StockQuote><Status>SUCCESS</Status>...

通过在 future 上注册后续的映射,我们为运行 future 的执行器提供了一个路线图,说明要执行的操作。

如果任何步骤失败,包含异常的失败 Try 实例将被传播:

scala> val strResponse = Future { 
 Source.fromURL("empty").mkString 
}

scala> val xmlResponse = strResponse.map { 
 s => XML.loadString(s) 
}

scala> xmlResponse.value 
Option[Try[xml.Elem]] = Some(Failure(MalformedURLException: no protocol: empty))

如果你将失败的 future 视为一个空容器,这种行为是有意义的。当将映射应用于空列表时,它返回相同的空列表。同样,当将映射应用于空(失败)的 future 时,返回空 future

阻塞直到完成

用于获取股票价格的代码在 shell 中运行良好。然而,如果你将其粘贴到独立程序中,你会注意到没有任何内容被打印出来,程序立即结束。让我们看看这个简单示例:

// BlockDemo.scala
import scala.concurrent._
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

object BlockDemo extends App {
  val f = Future { Thread.sleep(10000) }
  f.onComplete { _ => println("future completed") }
  // "future completed" is not printed
}

程序在主线程完成其任务后立即停止运行,在这个例子中,这仅仅涉及创建 future。特别是,"future completed" 这一行永远不会打印出来。如果我们想让主线程等待 future 执行,我们必须明确告诉它阻塞执行,直到 future 完成。这是通过使用 Await.readyAwait.result 方法来完成的。这两种方法都会阻塞主线程的执行,直到 future 完成。我们可以通过添加以下行来使上述程序按预期工作:

Await.ready(f, 1 minute)

Await 方法将 future 作为第一个参数,将 Duration 对象作为第二个参数。如果 future 完成所需时间超过指定的持续时间,则抛出 TimeoutException。传递 Duration.Inf 来设置无限超时。

Await.readyAwait.result 之间的区别在于后者返回 future 内部的值。特别是,如果 future 导致异常,则该异常将被抛出。相比之下,Await.ready 返回 future 本身。

通常,应尽可能避免阻塞:future的整个目的就是在后台线程中运行代码,以保持执行主线程的响应性。然而,在程序结束时进行阻塞是一个常见且合理的用例。如果我们正在运行一个大规模的集成过程,我们可能会调度几个future来查询 Web API、从文本文件中读取或向数据库中插入数据。将代码嵌入到future中比按顺序执行这些操作更具有可扩展性。然而,由于大部分密集型工作都在后台线程中运行,当主线程完成时,我们会留下许多未完成的future。在这个阶段,阻塞直到所有future完成是有意义的。

使用执行上下文控制并行执行

现在我们已经知道了如何定义future,让我们看看如何控制它们的运行。特别是,你可能想要控制运行大量future时使用的线程数量。

当定义一个future时,它会传递一个execution context,无论是直接还是隐式。执行上下文是一个对象,它公开了一个execute方法,该方法接受一段代码并运行它,可能是异步的。通过更改执行上下文,我们可以更改运行future的“后端”。我们已经看到了如何使用执行上下文来控制并行集合的执行。

到目前为止,我们只是通过导入scala.concurrent.ExecutionContext.Implicits.global来使用默认的执行上下文。这是一个具有与底层 CPU 数量相同的线程的 fork/join 线程池。

现在让我们定义一个新的使用十六个线程的执行上下文:

scala> import java.util.concurrent.Executors
import java.util.concurrent.Executors

scala> val ec = ExecutionContext.fromExecutorService(
 Executors.newFixedThreadPool(16)
)
ec: ExecutionContextExecutorService = ExecutionContextImpl$$anon$1@1351ce60

定义了执行上下文后,我们可以显式地将它传递给正在定义的future

scala> val f = Future { Thread.sleep(1000) } (ec)
f: Future[Unit] = Promise$DefaultPromise@458b456

或者,我们可以隐式地定义执行上下文:

scala> implicit val context = ec
context: ExecutionContextExecutorService = ExecutionContextImpl$$anon$1@1351ce60

然后将它作为隐式参数传递给所有新构建的future

scala> val f = Future { Thread.sleep(1000) }
f: Future[Unit] = Promise$DefaultPromise@3c4b7755

你可以关闭执行上下文来销毁线程池:

scala> ec.shutdown()

当执行上下文收到关闭命令时,它将完成当前任务的执行,但会拒绝任何新的任务。

期货示例 - 股票价格获取器

让我们将本节中介绍的一些概念结合起来,构建一个命令行应用程序,该程序提示用户输入股票名称并获取该股票的价值。难点在于,为了保持用户界面响应,我们将使用future来获取股票信息:

// StockPriceDemo.scala

import scala.concurrent._
import scala.concurrent.ExecutionContext.Implicits.global
import scala.io._
import scala.xml.XML
import scala.util._

object StockPriceDemo extends App {

 /* Construct URL for a stock symbol */
  def urlFor(stockSymbol:String) =
    ("http://dev.markitondemand.com/MODApis/Api/v2/Quote?" + 
     s"symbol=${stockSymbol}")

  /* Build a future that fetches the stock price */
  def fetchStockPrice(stockSymbol:String):Future[BigDecimal] = {
    val url = urlFor(stockSymbol)
    val strResponse = Future { Source.fromURL(url).mkString }
    val xmlResponse = strResponse.map { s => XML.loadString(s) }
    val price = xmlResponse.map { 
      r => BigDecimal((r \ "LastPrice").text) 
    }
    price
  }

  /* Command line interface */
  println("Enter symbol at prompt.")
  while (true) {
    val symbol = readLine("> ") // Wait for user input
    // When user puts in symbol, fetch data in background
    // thread and print to screen when complete
    fetchStockPrice(symbol).onComplete { res =>
      println()
      res match {
        case Success(price) => println(s"$symbol: USD $price")
        case Failure(e) => println(s"Error fetching  $symbol: $e")
      }
      print("> ") // Simulate the appearance of a new prompt
    }
  }

}

尝试运行程序并输入一些股票的代码:

[info] Running StockPriceDemo
Enter symbol at prompt:
> GOOG
> MSFT
>
GOOG: USD 695.22
>
MSFT: USD 47.48
> AAPL
> 
AAPL: USD 111.01

让我们总结一下代码的工作原理。当你输入一个股票时,主线程构建一个期货,从 API 获取股票信息,将其转换为 XML,并提取价格。我们使用(r \ "LastPrice").text从 XML 节点r中提取LastPrice标签内的文本。然后我们将值转换为大数据。当转换完成后,结果通过onComplete绑定回调打印到屏幕。异常处理通过我们使用.map方法处理转换来自然处理。

通过将获取股票价格的代码包装在期货中,我们释放了主线程,使其仅用于响应用户。这意味着如果我们的互联网连接很慢,用户界面不会阻塞。

这个例子有些人为,但你很容易将其包装得更加复杂:股票价格可以写入数据库,我们还可以添加额外的命令来绘制股票价格随时间的变化,例如。

在本节中,我们只是刚刚触及了期货所能提供的表面内容。当我们查看第七章(part0059.xhtml#aid-1O8H61 "第七章。Web API")、Web API和第九章(part0077.xhtml#aid-29DRA1 "第九章。使用 Akka 的并发")、使用 Akka 的并发时,我们将更详细地回顾期货。

期货是数据科学家构建可扩展系统工具箱的关键部分。将昂贵的计算(无论是 CPU 时间还是墙时间)移动到后台线程可以极大地提高可扩展性。因此,期货是许多 Scala 库(如AkkaPlay框架)的重要组成部分。

摘要

通过提供高级并发抽象,Scala 使得编写并行代码直观且简单。并行集合和期货是数据科学家工具箱中不可或缺的部分,它们允许他们以最小的努力并行化代码。然而,尽管这些高级抽象消除了直接处理线程的需求,但理解 Scala 并发模型内部机制是必要的,以避免竞争条件。

在下一章中,我们将暂停并发的研究,学习如何与 SQL 数据库交互。然而,这只是一个临时的举措:期货将在本书剩余的许多章节中扮演重要的角色。

参考文献

Aleksandar Prokopec在 Scala 中学习并发编程。这是对 Scala 并发编程基础的一个详细介绍。特别是,它比本章更详细地探讨了并行集合和期货。

Daniel Westheide 的博客为许多 Scala 概念提供了出色的介绍,特别是:

关于交叉验证的讨论,请参阅 Hastie、Tibshirani 和 Friedman 合著的《统计学习基础》。

第五章:Scala 和 SQL 通过 JDBC

数据科学的一个基本目标是处理大型数据集的困难。许多公司或研究组感兴趣的数据都无法方便地存储在单个计算机的 RAM 中。因此,以易于查询的方式存储数据是一个复杂的问题。

关系型数据库在解决数据存储问题上取得了成功。最初于 1970 年提出(www.seas.upenn.edu/~zives/03f/cis550/codd.pdf),今天仍在使用的绝大多数数据库仍然是关系型。在这段时间里,每兆字节 RAM 的价格下降了十亿倍。同样,硬盘容量也从几十或几百兆字节增加到太字节。值得注意的是,尽管数据存储容量呈指数增长,但关系型模型仍然占据主导地位。

几乎所有关系型数据库都使用 SQL(结构化查询语言)的变体进行描述和查询。随着分布式计算的出现,SQL 数据库作为事实上的数据存储标准的地位正受到其他类型数据库的挑战,这些数据库通常被统称为 NoSQL。许多 NoSQL 数据库比 SQL 数据库更具有分区容错性:它们可以被分割成几个部分,分别存储在不同的计算机上。尽管作者预计 NoSQL 数据库将越来越受欢迎,但 SQL 数据库作为数据持久化机制可能仍然普遍存在;因此,本书的很大一部分内容将致力于从 Scala 与 SQL 交互。

虽然 SQL 是标准化的,但大多数实现并不完全遵循该标准。此外,大多数实现都提供了对标准的扩展。这意味着,尽管本书中的许多概念将适用于所有 SQL 后端,但确切的语法可能需要调整。在这里,我们将仅考虑 MySQL 实现。

在本章中,您将学习如何使用 JDBC(一个基本的 Java API)从 Scala 与 SQL 数据库交互。在下一章中,我们将考虑 Slick,这是一个对象关系映射器(ORM),它为与 SQL 的交互提供了更符合 Scala 风格的感觉。

本章大致分为两部分:我们首先将讨论连接和与 SQL 数据库交互的基本功能,然后讨论可以用来创建优雅、松散耦合和一致的数据访问层的有用功能模式。

本章假设你具备基本的 SQL 工作知识。如果你不具备,你最好先阅读本章末尾提到的参考书籍之一。

与 JDBC 交互

JDBC 是 Java 连接 SQL 数据库的 API。它仍然是 Scala 连接 SQL 数据库的最简单方式。此外,大多数用于与数据库交互的高级抽象仍然使用 JDBC 作为后端。

JDBC 本身不是一个库。相反,它暴露了一组用于与数据库交互的接口。关系数据库供应商随后提供这些接口的特定实现。

让我们先创建一个build.sbt文件。我们将声明对 MySQL JDBC 连接器的依赖:

scalaVersion := "2.11.7"

libraryDependencies += "mysql" % "mysql-connector-java" % "5.1.36"

JDBC 的第一步

让我们从命令行连接到 JDBC 开始。为了跟随示例,你需要访问一个正在运行的 MySQL 服务器。如果你已经将 MySQL 连接器添加到依赖列表中,可以通过输入以下命令打开 Scala 控制台:

$ sbt console

让我们导入 JDBC:

scala> import java.sql._
import java.sql._

我们接下来需要告诉 JDBC 使用特定的连接器。这通常是通过反射完成的,在运行时加载驱动:

scala> Class.forName("com.mysql.jdbc.Driver")
Class[_] = class com.mysql.jdbc.Driver

这个操作将在运行时将适当的驱动加载到命名空间中。如果你觉得这有点神奇,可能不值得担心它具体是如何工作的。这是本书中我们将考虑的唯一反射示例,而且它并不是特别符合 Scala 的惯例。

连接到数据库服务器

已经指定了 SQL 连接器,我们现在可以连接到数据库。假设我们有一个名为test的数据库,它在127.0.0.1主机上监听端口3306。我们创建连接如下:

scala> val connection = DriverManager.getConnection(
 "jdbc:mysql://127.0.0.1:3306/test",
 "root", // username when connecting
 "" // password
)
java.sql.Connection = com.mysql.jdbc.JDBC4Connection@12e78a69

getConnection的第一个参数是一个类似 URL 的字符串,格式为jdbc:mysql://host[:port]/database。第二个和第三个参数是用户名和密码。如果你无需密码即可连接,请传入空字符串。

创建表

现在我们已经建立了数据库连接,让我们与服务器交互。对于这些示例,你将发现打开一个 MySQL 壳(或 MySQL GUI,如MySQLWorkbench)以及 Scala 控制台很有用。你可以在终端中输入以下命令打开 MySQL 壳:

$ mysql

例如,我们将创建一个小表来跟踪著名物理学家。在mysql壳中,我们会运行以下命令:

mysql> USE test;
mysql> CREATE TABLE physicists (
 id INT(11) AUTO_INCREMENT PRIMARY KEY,
 name VARCHAR(32) NOT NULL
);

要在 Scala 中实现相同的功能,我们需要向连接发送一个 JDBC 语句:

scala> val statementString = """
CREATE TABLE physicists (
 id INT(11) AUTO_INCREMENT PRIMARY KEY,
 name VARCHAR(32) NOT NULL
)
"""

scala> val statement = connection.prepareStatement(statementString)
PreparedStatement = JDBC4PreparedStatement@c983201: CREATE TABLE ...

scala> statement.executeUpdate()
results: Int = 0

现在,我们先忽略executeUpdate方法的返回值。

插入数据

现在我们已经创建了一个表,让我们向其中插入一些数据。我们可以使用 SQL INSERT语句来完成这个操作:

scala> val statement = connection.prepareStatement("""
 INSERT INTO physicists (name) VALUES ('Isaac Newton')
""")

scala> statement.executeUpdate()
Int = 1

在这种情况下,executeUpdate返回1。当插入行时,它返回插入的行数。同样,如果我们使用了一个SQL UPDATE语句,这将返回更新的行数。对于不直接操作行的语句(如上一节中的CREATE TABLE语句),executeUpdate仅返回0

让我们直接跳入mysql shell 来验证插入是否正确执行:

mysql> select * from physicists ;
+----+--------------+
| id | name         |
+----+--------------+
|  1 | Isaac Newton |
+----+--------------+
1 row in set (0.00 sec)

让我们快速总结一下到目前为止我们所看到的:要执行不返回结果的 SQL 语句,请使用以下方法:

val statement = connection.prepareStatement("SQL statement string")
statement.executeUpdate()

在数据科学的背景下,我们经常需要一次插入或更新多行。例如,我们可能有一个物理学家列表:

scala> val physicistNames = List("Marie Curie", "Albert Einstein", "Paul Dirac")

我们希望将这些全部插入到数据库中。虽然我们可以为每个物理学家创建一个语句并将其发送到数据库,但这非常低效。更好的解决方案是创建一个批量语句并将它们一起发送到数据库。我们首先创建一个语句模板:

scala> val statement = connection.prepareStatement("""
 INSERT INTO physicists (name) VALUES (?)
""")
PreparedStatement = JDBC4PreparedStatement@621a8225: INSERT INTO physicists (name) VALUES (** NOT SPECIFIED **)

这与之前的prepareStatement调用相同,只是我们将物理学家名字替换为?占位符。我们可以使用statement.setString方法设置占位符的值:

scala> statement.setString(1, "Richard Feynman")

这将用字符串Richard Feynman替换语句中的第一个占位符:

scala> statement
com.mysql.jdbc.JDBC4PreparedStatement@5fdd16c3:
INSERT INTO physicists (name) VALUES ('Richard Feynman')

注意,JDBC 有些反直觉地,从 1 而不是 0 开始计算占位符位置。

我们现在已创建了更新批量的第一个语句。运行以下命令:

scala> statement.addBatch()

通过运行前面的命令,我们启动了一个批量插入:该语句被添加到一个临时缓冲区中,当我们运行executeBatch方法时将执行。让我们将列表中的所有物理学家添加进去:

scala> physicistNames.foreach { name => 
 statement.setString(1, name)
 statement.addBatch()
}

我们现在可以执行批量中的所有语句:

scala> statement.executeBatch
Array[Int] = Array(1, 1, 1, 1)

executeBatch的返回值是一个数组,表示批量中每个项目更改或插入的行数。

注意,我们使用了statement.setString来用特定的名字填充模板。PreparedStatement对象为所有基本类型都有setXXX方法。要获取完整的列表,请阅读PreparedStatement API 文档(docs.oracle.com/javase/7/docs/api/java/sql/PreparedStatement.html)。

读取数据

现在我们知道了如何将数据插入数据库,让我们看看相反的情况:读取数据。我们使用 SQL SELECT语句来查询数据库。让我们首先在 MySQL shell 中这样做:

mysql> SELECT * FROM physicists;
+----+-----------------+
| id | name            |
+----+-----------------+
|  1 | Isaac Newton    |
|  2 | Richard Feynman |
|  3 | Marie Curie     |
|  4 | Albert Einstein |
|  5 | Paul Dirac      |
+----+-----------------+
5 rows in set (0.01 sec)

要在 Scala 中提取此信息,我们定义一个PreparedStatement

scala> val statement = connection.prepareStatement("""
 SELECT name FROM physicists
""")
PreparedStatement = JDBC4PreparedStatement@3c577c9d:
SELECT name FROM physicists

我们通过运行以下命令来执行此语句:

scala> val results = statement.executeQuery()
results: java.sql.ResultSet = com.mysql.jdbc.JDBC4ResultSet@74a2e158

这返回一个 JDBC ResultSet实例。ResultSet是一个表示数据库中一组行的抽象。请注意,我们使用了statement.executeQuery而不是statement.executeUpdate。一般来说,应该使用executeQuery来执行返回数据(以ResultSet形式)的语句。修改数据库而不返回数据(插入、创建、更改或更新语句等)的语句使用executeUpdate执行。

ResultSet对象的行为有点像迭代器。它暴露了一个next方法,该方法将其自身推进到下一个记录,如果ResultSet中还有记录,则返回true

scala> results.next // Advance to the first record
Boolean = true

ResultSet实例指向一个记录时,我们可以通过传递字段名来提取该记录中的字段:

scala> results.getString("name")
String = Isaac Newton

我们还可以使用位置参数提取字段。字段从一开始索引:

scala> results.getString(1) // first positional argument
String = Isaac Newton

当我们完成特定记录的处理时,我们调用 next 方法将 ResultSet 移动到下一个记录:

scala> results.next // advances the ResultSet by one record
Boolean = true

scala> results.getString("name")
String = Richard Feynman

读取数据

ResultSet 对象支持 getXXX(fieldName) 方法来访问记录的字段,以及一个 next 方法来移动到结果集中的下一个记录。

可以使用 while 循环遍历结果集:

scala> while(results.next) { println(results.getString("name")) }
Marie Curie
Albert Einstein
Paul Dirac

小贴士

关于读取可空字段有一个警告。虽然当面对 null SQL 字段时,人们可能期望 JDBC 返回 null,但返回类型取决于使用的 getXXX 命令。例如,getIntgetLong 将对任何 null 字段返回 0。同样,getDoublegetFloat 返回 0.0。这可能导致代码中的一些微妙错误。一般来说,应该小心使用返回 Java 值类型(intlong)而不是对象的 getter。要确定数据库中的值是否为 null,首先使用 getInt(或 getLonggetDouble,视情况而定)查询它,然后使用返回布尔值的 wasNull 方法:

scala> rs.getInt("field")
0
scala> rs.wasNull // was the last item read null?
true

这种(令人惊讶的)行为使得从 ResultSet 实例中读取变得容易出错。本章第二部分的一个目标是为您提供构建在 ResultSet 接口之上的抽象层的工具,以避免直接调用如 getInt 这样的方法。

在 Scala 中直接从 ResultSet 对象读取值感觉很不自然。在本章的后续部分,我们将探讨通过构建一个层来访问结果集,您可以通过类型类来访问这个层。

现在我们已经知道如何读取和写入数据库。现在我们已经完成了对数据库的操作,我们将关闭结果集、预处理语句和连接:

scala> results.close

scala> statement.close

scala> connection.close

虽然在 Scala shell 中关闭语句和连接并不重要(它们将在你退出时关闭),但在运行程序时很重要;否则,对象将持续存在,导致“内存不足异常”。在下一节中,我们将探讨使用 贷款模式 建立连接和语句,这是一种设计模式,在完成使用资源后自动关闭它。

JDBC 概述

现在我们对 JDBC 有了一个概述。本章的其余部分将专注于编写位于 JDBC 之上的抽象,使数据库访问感觉更自然。在我们这样做之前,让我们总结一下到目前为止我们所看到的。

我们已经使用了三个 JDBC 类:

  • Connection 类表示与特定 SQL 数据库的连接。以下是如何实例化连接的示例:

    import java.sql._
    Class.forName("com.mysql.jdbc.Driver")val connection = DriverManager.getConnection(
      "jdbc:mysql://127.0.0.1:3306/test",
      "root", // username when connecting
      "" // password
    )
    

    我们主要使用 Connection 实例来生成 PreparedStatement 对象:

    connection.prepareStatement("SELECT * FROM physicists")
    
  • PreparedStatement 实例代表即将发送到数据库的 SQL 语句。它还代表一个 SQL 语句的模板,其中包含尚未填充的值占位符。该类公开以下方法:

    statement.executeUpdate 这将语句发送到数据库。用于修改数据库且不返回任何数据的 SQL 语句,例如 INSERTUPDATEDELETECREATE 语句。
    val results = statement.executeQuery 这将语句发送到数据库。用于返回数据的 SQL 语句(主要是 SELECT 语句)。这返回一个 ResultSet 实例。
    statement.addBatch statement.executeBatch addBatch 方法将当前语句添加到语句批处理中,而 executeBatch 将语句批处理发送到数据库。
    statement.setString(1, "Scala") statement.setInt(1, 42) statement.setBoolean(1, true) PreparedStatement 中填写占位符值。第一个参数是语句中的位置(从 1 开始计数)。第二个参数是值。这些方法的常见用例包括批量更新或插入:我们可能有一个想要插入数据库的 Scala 对象列表。我们使用 .setXXX 方法为列表中的每个对象填写占位符,然后使用 .addBatch 将此语句添加到批处理中。然后我们可以使用 .executeBatch 将整个批处理发送到数据库。
    statement.setNull(1, java.sql.Types.BOOLEAN) 这将语句中的特定项设置为 NULL。第二个参数指定 NULL 类型。如果我们正在设置布尔列中的单元格,例如,这应该是 Types.BOOLEANjava.sql.Types 包的 API 文档中提供了类型列表(docs.oracle.com/javase/7/docs/api/java/sql/Types.html)。
  • ResultSet 实例表示由 SELECTSHOW 语句返回的一组行。ResultSet 提供了访问当前行字段的方法:

    rs.getString(i) rs.getInt(i) 这些方法获取当前行中第 i 个字段的值;i 从 1 开始计算。
    rs.getString("name") rs.getInt("age") 这些方法获取特定字段的值,字段通过列名索引。
    rs.wasNull 这返回最后读取的列是否为 NULL。当读取 Java 值类型(如 getIntgetBooleangetDouble)时,这尤其重要,因为这些在读取 NULL 值时返回默认值。

ResultSet 实例公开了 .next 方法以移动到下一行;.next 返回 true,直到 ResultSet 前进到最后一行之后。

JDBC 的函数式包装

我们现在对 JDBC 提供的工具有一个基本的概述。到目前为止,我们与之交互的所有对象在 Scala 中都显得有些笨拙和不合适。它们并不鼓励函数式编程风格。

当然,优雅本身可能不是目标(或者至少,你可能很难说服你的 CEO 因为代码缺乏优雅而推迟产品的发布)。然而,它通常是一个症状:要么代码不可扩展或耦合度过高,要么容易引入错误。对于 JDBC 来说,后者尤其如此。忘记检查wasNull?这会反过来咬你。忘记关闭你的连接?你会得到一个“内存不足异常”(希望不是在生产环境中)。

在接下来的几节中,我们将探讨我们可以用来包装 JDBC 类型以减轻许多这些风险的模式。我们在这里引入的模式在 Scala 库和应用程序中非常常见。因此,除了编写健壮的类与 JDBC 交互之外,了解这些模式,我希望,将使你对 Scala 编程有更深入的理解。

使用借款模式实现更安全的 JDBC 连接

我们已经看到了如何连接到 JDBC 数据库并向数据库发送执行语句。然而,这种技术有些容易出错:你必须记得关闭语句;否则,你会很快耗尽内存。在更传统的命令式风格中,我们在每个连接周围编写以下 try-finally 块:

// WARNING: poor Scala code
val connection = DriverManager.getConnection(url, user, password)
try {
  // do something with connection
}
finally {
  connection.close()
}

Scala,凭借一等函数,为我们提供了一个替代方案:借款模式。我们编写一个负责打开连接、将连接借给客户端代码以执行一些有趣的操作,并在客户端代码完成后关闭连接的函数。因此,客户端代码不再负责关闭连接。

让我们创建一个新的SqlUtils对象,并使用usingConnection方法利用借款模式:

// SqlUtils.scala

import java.sql._

object SqlUtils {

  /** Create an auto-closing connection using 
    * the loan pattern */
  def usingConnectionT(f:Connection => T):T = {

    // Create the connection
    val Url = s"jdbc:mysql://$host:$port/$db"
    Class.forName("com.mysql.jdbc.Driver")
    val connection = DriverManager.getConnection(
      Url, user, password)

    // give the connection to the client, through the callable 
    // `f` passed in as argument
    try {
      f(connection)
    }
    finally {
      // When client is done, close the connection
      connection.close()
    }
  }
}

让我们看看这个函数的实际应用:

scala> SqlUtils.usingConnection("test") {
 connection => println(connection)
}
com.mysql.jdbc.JDBC4Connection@46fd3d66

因此,客户端不需要记住关闭连接,对于客户端来说,代码感觉更像是 Scala。

我们的usingConnection函数是如何工作的?函数定义是def usingConnection( ... )(f : Connection => T ):T。它接受第二组参数,即作用于Connection对象的功能。usingConnection的主体创建连接,然后将其传递给f,最后关闭连接。这种语法与 Ruby 中的代码块或 Python 中的with语句有些相似。

小贴士

当混合借款模式与延迟操作时要小心。这尤其适用于从f返回迭代器、流和未来。一旦执行线程离开f,连接就会被关闭。在此点之前未实例化的任何数据结构将无法继续访问连接。

借款模式当然不仅仅局限于数据库连接。当你遇到以下模式时,它非常有用,以下为伪代码:

open resource (eg. database connection, file ...)
use resource somehow // loan resource to client for this part.
close resource

使用“pimp my library”模式丰富 JDBC 语句

在上一节中,我们看到了如何使用贷款模式创建自关闭的连接。这允许我们打开数据库连接,而无需记住关闭它们。然而,我们仍然需要记住关闭我们打开的任何ResultSetPreparedStatement

// WARNING: Poor Scala code
SqlUtils.usingConnection("test") { connection =>
  val statement = connection.prepareStatement(
    "SELECT * FROM physicists")
  val results = statement.executeQuery
  // do something useful with the results
  results.close
  statement.close
}

需要打开和关闭语句有些丑陋且容易出错。这也是贷款模式的另一个自然用例。理想情况下,我们希望编写以下内容:

usingConnection("test") { connection =>
  connection.withQuery("SELECT * FROM physicists") {
    resultSet => // process results
  }
}

我们如何能在Connection类上定义一个.withQuery方法?我们并不控制Connection类的定义,因为它属于 JDBC API 的一部分。我们希望能够以某种方式重新打开Connection类的定义来添加withQuery方法。

Scala 不允许我们重新打开类来添加新方法(这种做法被称为猴子补丁)。然而,我们仍然可以使用pimp my library模式通过隐式转换来增强现有库(www.artima.com/weblogs/viewpost.jsp?thread=179766)。我们首先定义一个包含withQuery方法的RichConnection类。这个RichConnection类是由现有的Connection实例创建的。

// RichConnection.scala

import java.sql.{Connection, ResultSet}

class RichConnection(val underlying:Connection) {

  /** Execute a SQL query and process the ResultSet */
  def withQueryT(f:ResultSet => T):T = {
    val statement = underlying.prepareStatement(query)
    val results = statement.executeQuery
    try {
      f(results) // loan the ResultSet to the client
    }
    finally {
      // Ensure all the resources get freed.
      results.close
      statement.close
    }
  }
}

我们可以通过将每个Connection实例包装在一个RichConnection实例中来使用这个类:

// Warning: poor Scala code
SqlUtils.usingConnection("test") { connection =>
  val richConnection = new RichConnection(connection)
  richConnection.withQuery("SELECT * FROM physicists") {
    resultSet => // process resultSet
  }
}

这增加了不必要的样板代码:我们必须记住将每个连接实例转换为RichConnection才能使用withQuery。幸运的是,Scala 提供了一个更简单的方法,即隐式转换:我们告诉 Scala 如何从Connection转换为RichConnection,反之亦然,并告诉它如果需要则自动(隐式)执行此转换:

// Implicits.scala
import java.sql.Connection

// Implicit conversion methods are often put in 
// an object called Implicits.
object Implicits {
  implicit def pimpConnection(conn:Connection) = 
    new RichConnection(conn)
  implicit def depimpConnection(conn:RichConnection) =  
    conn.underlying
}

现在,每当pimpConnectiondepimpConnection在当前作用域内时,Scala 将自动使用它们将Connection实例转换为RichConnection,并在需要时将其转换回Connection

现在,我们可以编写以下内容(我添加了类型信息以强调):

// Bring the conversion functions into the current scope
import Implicits._ 

SqlUtils.usingConnection("test") { (connection:Connection) =>
  connection.withQuery("SELECT * FROM physicists") {
    // Wow! It's like we have just added 
    // .withQuery to the JDBC Connection class!
    resultSet => // process results
  }
}

这可能看起来像魔法,所以让我们退后一步,看看当我们对一个Connection实例调用withQuery时会发生什么。Scala 编译器首先会查看Connection类的定义是否包含withQuery方法。当它发现没有时,它会寻找将Connection实例转换为定义withQuery的类的隐式方法。它会发现pimpConnection方法允许从Connection转换为定义withQueryRichConnection。Scala 编译器会自动使用pimpConnectionConnection实例转换为RichConnection

注意,我们使用了pimpConnectiondepimpConnection这样的名称作为转换函数,但它们可以是任何名称。我们从未明确调用这些方法。

让我们总结一下如何使用pimp my library模式向现有类添加方法:

  1. 编写一个包装你想要增强的类的类:class RichConnection(val underlying:Connection)。添加你希望原始类拥有的所有方法。

  2. 编写一个方法将你的原始类转换为你的增强类,作为名为(传统上)Implicits 的对象的一部分。确保你告诉 Scala 使用 implicit 关键字自动使用这个转换:implicit def pimpConnection(conn:Connection):RichConnection。你还可以告诉 Scala 通过添加反向转换方法自动将增强类转换回原始类。

  3. 通过导入隐式转换方法允许隐式转换:import Implicits._

在流中封装结果集

JDBC ResultSet 对象与 Scala 集合配合得非常糟糕。真正能够用它做些有用的事情的唯一方法就是直接使用 while 循环遍历它。例如,为了获取我们数据库中物理学家的名字列表,我们可以编写以下代码:

// WARNING: poor Scala code
import Implicits._ // import implicit conversions

SqlUtils.usingConnection("test") { connection =>
  connection.withQuery("SELECT * FROM physicists") { resultSet =>
    var names = List.empty[String]
    while(resultSet.next) {
      val name = resultSet.getString("name")
      names = name :: names
    }
 names
  }
}
//=> List[String] = List(Paul Dirac, Albert Einstein, Marie Curie, Richard Feynman, Isaac Newton)

ResultSet 接口感觉很不自然,因为它与 Scala 集合的行为非常不同。特别是,它不支持我们在 Scala 中视为理所当然的高阶函数:没有 mapfilterfoldfor 语句。幸运的是,编写一个封装 ResultSetstream 非常简单。Scala 流是一个延迟评估的列表:它在需要时评估集合中的下一个元素,并在不再使用时忘记之前的元素。

我们可以定义一个 stream 方法,如下封装 ResultSet

// SqlUtils.scala
object SqlUtils {   
  ... 
  def stream(results:ResultSet):Stream[ResultSet] = 
    if (results.next) { results #:: stream(results) }
    else { Stream.empty[ResultSet] }
}

这可能看起来相当令人困惑,所以让我们慢慢来。我们定义一个 stream 方法来封装 ResultSet,返回一个 Stream[ResultSet]。当客户端在空结果集上调用 stream 时,这只会返回一个空流。当客户端在非空 ResultSet 上调用 stream 时,ResultSet 实例会向前推进一行,客户端会得到 results #:: stream(results)。流上的 #:: 操作符类似于列表上的 :: 操作符:它将 results 预先添加到现有的 Stream 中。关键的区别是,与列表不同,stream(results) 不会在必要时进行评估。因此,这避免了在内存中重复整个 ResultSet

让我们使用我们全新的 stream 函数来获取我们数据库中所有物理学家的名字:

import Implicits._

SqlUtils.usingConnection("test") { connection =>
  connection.withQuery("SELECT * FROM physicists") { results =>
    val resultsStream = SqlUtils.stream(results)
    resultsStream.map { _.getString("name") }.toVector
  }
}
//=> Vector(Richard Feynman, Albert Einstein, Marie Curie, Paul Dirac)

相比直接使用结果集,流式处理结果允许我们以更自然的方式与数据交互,因为我们现在处理的是 Scala 集合。

当你在 withQuery 块(或者更一般地,在自动关闭结果集的块)中使用 stream 时,你必须在函数内部始终将流具体化,因此调用了 toVector。否则,流将等待其元素被需要时才具体化它们,而那时,ResultSet 实例将被关闭。

通过类型类实现更松散的耦合

到目前为止,我们一直在数据库中读取和写入简单类型。让我们假设我们想要向我们的数据库添加一个 gender 列。我们将把性别作为枚举存储在我们的物理学家数据库中。我们的表现在如下所示:

mysql> CREATE TABLE physicists (
 id INT(11) AUTO_INCREMENT PRIMARY KEY,
 name VARCHAR(32) NOT NULL,
 gender ENUM("Female", "Male") NOT NULL
);

我们如何在 Scala 中表示性别?一种好的方法是使用枚举:

// Gender.scala

object Gender extends Enumeration {
  val Male = Value
  val Female = Value
}

然而,我们现在在从数据库反序列化对象时遇到了问题:JDBC 没有内置机制将 SQL ENUM类型转换为 Scala Gender类型。我们可以通过每次需要读取性别信息时手动转换来实现这一点:

resultsStream.map { 
  rs => Gender.withName(rs.getString("gender")) 
}.toVector

然而,我们需要在所有想要读取gender字段的地方都写下这些代码。这违反了 DRY(不要重复自己)原则,导致代码难以维护。如果我们决定更改数据库中存储性别的方式,我们就需要找到代码中所有读取gender字段的地方并对其进行更改。

一个稍微好一些的解决方案是在ResultSet类中添加一个getGender方法,使用我们在本章中广泛使用的 pimp my library 习语。这个解决方案仍然不是最优的。我们正在向ResultSet添加不必要的特定性:它现在与我们的数据库结构耦合。

我们可以通过继承ResultSet来创建一个子类,例如PhysicistResultSet,这样就可以读取特定表中的字段。然而,这种方法是不可组合的:如果我们还有另一个表,它跟踪宠物,包括名称、种类和性别字段,我们就必须要么在新的PetResultSet中重新实现读取性别的代码,要么提取一个GenderedResultSet超类。随着表数量的增加,继承层次结构将变得难以管理。更好的方法可以让我们组合所需的功能。特别是,我们希望将从一个结果集中提取 Scala 对象的过程与遍历结果集的代码解耦。

类型类

Scala 提供了一个优雅的解决方案,使用类型类。类型类是 Scala 架构师箭袋中的一个非常强大的箭头。然而,它们可能有一定的学习曲线,尤其是在面向对象编程中没有直接等效物。

我不会提供抽象的解释,而是直接进入一个例子:我将描述如何利用类型类将ResultSet中的字段转换为 Scala 类型。目标是定义一个readT方法在ResultSet上,该方法知道如何精确地将对象反序列化为类型T。此方法将替换并扩展ResultSet中的getXXX方法:

// results is a ResultSet instance
val name = results.readString
val gender = results.readGender.Value

我们首先定义一个抽象的SqlReader[T]特质,它公开一个read方法,用于从ResultSet中读取特定字段并返回类型为T的实例:

// SqlReader.scala

import java.sql._

trait SqlReader[T] {
  def read(results:ResultSet, field:String):T
}

现在,我们需要为每个我们想要读取的T类型提供一个SqlReader[T]的具体实现。让我们为GenderString字段提供具体实现。我们将实现放在SqlReader伴生对象中:

// SqlReader.scala

object SqlReader {
  implicit object StringReader extends SqlReader[String] {
    def read(results:ResultSet, field:String):String =
      results.getString(field)
  }

  implicit object GenderReader extends SqlReader[Gender.Value] {
    def read(results:ResultSet, field:String):Gender.Value =
      Gender.withName(StringReader.read(results, field))
  }
}

我们现在可以使用我们的ReadableXXX对象从结果集中读取:

import SqlReader._
val name = StringReader.read(results, "name")
val gender = GenderReader.read(results, "gender")

这已经比使用以下方法好一些:

Gender.withName(results.getString("gender"))

这是因为将 ResultSet 字段映射到 Gender.Value 的代码集中在一个单独的位置:ReadableGender。然而,如果我们能告诉 Scala 在需要读取 Gender.Value 时使用 ReadableGender,在需要读取字符串值时使用 ReadableString,那就太好了。这正是类型类的作用。

面向类型类的编码

我们定义了一个 Readable[T] 接口,它抽象化了如何从 ResultSet 字段中读取类型为 T 的对象。我们如何告诉 Scala 需要使用这个 Readable 对象将 ResultSet 字段转换为适当的 Scala 类型?

关键是我们在 GenderReaderStringReader 对象定义前使用的 implicit 关键字。它允许我们编写:

implicitly[SqlReader[Gender.Value]].read(results, "gender")
implicitly[SqlReader[String]].read(results, "name")

通过编写 implicitly[SqlReader[T]],我们是在告诉 Scala 编译器找到一个扩展 SqlReader[T] 并标记为隐式使用的类(或对象)。你可以通过在命令行粘贴以下内容来尝试,例如:

scala> :paste

import Implicits._ // Connection to RichConnection conversion
SqlUtils.usingConnection("test") {
 _.withQuery("select * from physicists") {
 rs => {
 rs.next() // advance to first record
 implicitly[SqlReader[Gender.Value]].read(rs, "gender")
 }
 }
}

当然,在所有地方使用 implicitly[SqlReader[T]] 并不是特别优雅。让我们使用“pimp my library”惯用法向 ResultSet 添加一个 read[T] 方法。我们首先定义一个 RichResultSet 类,我们可以用它来“pimp” ResultSet 类:

// RichResultSet.scala

import java.sql.ResultSet

class RichResultSet(val underlying:ResultSet) {
  def readT : SqlReader:T = {
    implicitly[SqlReader[T]].read(underlying, field)
  }
}

这里的唯一不熟悉的部分应该是 read[T : SqlReader] 泛型定义。我们在这里声明,如果存在 SqlReader[T] 的实例,read 将接受任何 T 类型。这被称为上下文限制

我们还必须在 Implicits 对象中添加隐式方法,以将 ResultSet 转换为 RichResultSet。你现在应该熟悉这个了,所以我就不会详细说明了。你现在可以为任何具有 SqlReader[T] 隐式对象的 T 调用 results.readT

import Implicits._

SqlUtils.usingConnection("test") { connection =>
  connection.withQuery("SELECT * FROM physicists") {
    results =>
      val resultStream = SqlUtils.stream(results)
      resultStream.map { row => 
 val name = row.readString
 val gender = row.readGender.Value
        (name, gender)
      }.toVector
  }
}
//=> Vector[(String, Gender.Value)] = Vector((Albert Einstein,Male), (Marie Curie,Female))

让我们总结一下使类型类正常工作所需的步骤。我们将在从 SQL 反序列化的上下文中进行此操作,但你将能够将这些步骤适应以解决其他问题:

  • 定义一个抽象的泛型特质,为类型类提供接口,例如,SqlReader[T]。任何与 T 无关的功能都可以添加到这个基本特质中。

  • 为基本特质创建伴随对象,并为每个 T 添加扩展特质的隐式对象,例如,

    implicit object StringReader extends SqlReader[T].
    
  • 类型类始终用于泛型方法。一个依赖于类型类存在的参数的方法必须在泛型定义中包含上下文限制,例如,def readT : SqlReader:T。要访问此方法中的类型类,使用 implicitly 关键字:implicitly[SqlReader[T]]

何时使用类型类

当你需要为许多不同类型实现特定行为,但此行为的具体实现在这类之间有所不同时,类型类很有用。例如,我们需要能够从 ResultSet 中读取几种不同的类型,但每种类型的读取方式各不相同:对于字符串,我们必须使用 getStringResultSet 中读取,而对于整数,我们必须使用 getInt 后跟 wasNull

一个很好的经验法则是当你开始想“哦,我完全可以写一个泛型方法来做这件事。啊,但是等等,我必须将 Int 实现作为一个特定的边缘情况来写,因为它的行为不同。哦,还有 Gender 实现。我想知道是否有更好的方法?”时,类型类可能就很有用了。

类型类的优势

数据科学家经常必须处理新的输入流、变化的需求和新数据类型。因此,拥有一个易于扩展或更改的对象关系映射层对于有效地应对变化至关重要。最小化代码实体之间的耦合和关注点的分离是确保代码能够根据新数据更改的唯一方法。

使用类型类,我们保持了访问数据库中的记录(通过 ResultSet 类)和将单个字段转换为 Scala 对象的方式之间的正交性:这两者可以独立变化。这两个关注点之间的唯一耦合是通过 SqlReader[T] 接口。

这意味着这两个关注点可以独立进化:要读取新的数据类型,我们只需实现一个 SqlReader[T] 对象。相反,我们可以在不重新实现字段转换方式的情况下向 ResultSet 添加功能。例如,我们可以添加一个 getColumn 方法,它返回 ResultSet 实例中一个字段的 Vector[T] 所有值:

def getColumnT : SqlReader:Vector[T] = {
  val resultStream = SqlUtils.stream(results)
  resultStream.map { _.readT }.toVector
}

注意我们如何在不增加对单个字段读取方式耦合的情况下完成这件事。

创建数据访问层

让我们汇总我们所看到的一切,并为从数据库中检索 Physicist 对象构建一个 数据映射器 类。这些类(也称为 数据访问对象)有助于将对象的内部表示与其在数据库中的表示解耦。

我们首先定义了 Physicist 类:

// Physicist.scala
case class Physicist(
  val name:String,
  val gender:Gender.Value
)

数据访问对象将公开一个单一的方法,readAll,它返回包含我们数据库中所有物理学家的 Vector[Physicist]

// PhysicistDao.scala

import java.sql.{ ResultSet, Connection }
import Implicits._ // implicit conversions

object PhysicistDao {

  /* Helper method for reading a single row */
  private def readFromResultSet(results:ResultSet):Physicist = {
    Physicist(
      results.readString,
      results.readGender.Value
    )
  }

  /* Read the entire 'physicists' table. */
  def readAll(connection:Connection):Vector[Physicist] = {
    connection.withQuery("SELECT * FROM physicists") {
      results =>
        val resultStream = SqlUtils.stream(results)
        resultStream.map(readFromResultSet).toVector
    }
  }
}

客户端代码可以使用数据访问层,如下例所示:

object PhysicistDaoDemo extends App {

  val physicists = SqlUtils.usingConnection("test") {
    connection => PhysicistDao.readAll(connection)
  }

  // physicists is a Vector[Physicist] instance.
  physicists.foreach { println }
  //=> Physicist(Albert Einstein,Male)
  //=> Physicist(Marie Curie,Female)
}

摘要

在本章中,我们学习了如何使用 JDBC 与 SQL 数据库进行交互。我们编写了一个库来封装原生 JDBC 对象,目的是为它们提供一个更功能化的接口。

在下一章中,你将了解 Slick,这是一个 Scala 库,它提供了与关系数据库交互的功能包装器。

参考资料

JDBC 的 API 文档非常完整:docs.oracle.com/javase/7/docs/api/java/sql/package-summary.html

ResultSet 接口的 API 文档(docs.oracle.com/javase/7/docs/api/java/sql/ResultSet.html)、PreparedStatement 类(docs.oracle.com/javase/7/docs/api/java/sql/PreparedStatement.html)和 Connection 类(docs.oracle.com/javase/7/docs/api/java/sql/Connection.html)的文档尤其相关。

数据映射模式在 Martin Fowler 的《企业应用架构模式》中描述得非常详细。在他的网站上也有简要的描述(martinfowler.com/eaaCatalog/dataMapper.html)。

对于 SQL 的入门,我建议阅读 Alan Beaulieu 的《Learning SQL》(O'Reilly)。

对于类型类的另一篇讨论,请阅读 danielwestheide.com/blog/2013/02/06/the-neophytes-guide-to-scala-part-12-type-classes.html

本文描述了如何使用类型类在 Scala 中更优雅地重新实现一些常见的面向对象设计模式:

staticallytyped.wordpress.com/2013/03/24/gang-of-four-patterns-with-type-classes-and-implicits-in-scala-part-2/

这篇由 Martin Odersky 撰写的文章详细介绍了 Pimp my Library 模式:

www.artima.com/weblogs/viewpost.jsp?thread=179766

第六章。Slick – SQL 的函数式接口

在 第五章 中,我们探讨了如何使用 JDBC 访问 SQL 数据库。由于与 JDBC 交互感觉有些不自然,我们通过自定义封装扩展了 JDBC。这些封装是为了提供一个函数式接口,以隐藏 JDBC 的命令式本质。

由于直接从 Scala 与 JDBC 交互的难度以及 SQL 数据库的普遍性,你可能会期待存在现有的 Scala 库来封装 JDBC。Slick 就是这样一个库。

Slick 自称为 函数式-关系映射 库,这是对更传统的 对象-关系映射 名称的一种戏谑,后者用来表示从关系数据库构建对象的库。它提供了一个函数式接口来访问 SQL 数据库,允许客户端以类似于原生 Scala 集合的方式与之交互。

FEC 数据

在本章中,我们将使用一个稍微复杂一些的示例数据集。美国联邦选举委员会(Federal Electoral Commission of the United States,简称 FEC)记录了所有超过 200 美元的总统候选人捐款。这些记录是公开可用的。我们将查看导致巴拉克·奥巴马连任的 2012 年大选前的捐款情况。数据包括对两位总统候选人奥巴马和罗姆尼的捐款,以及共和党初选中其他竞争者的捐款(没有民主党初选)。

在本章中,我们将使用 FEC 提供的交易数据,将其存储在表中,并学习如何查询和分析它。

第一步是获取数据。如果你已经从 Packt 网站下载了代码示例,你应该已经在本章代码示例的 data 目录中有两个 CSV 文件。如果没有,你可以使用以下链接下载文件:

  • data.scala4datascience.com/fec/ohio.csv.gz(或 ohio.csv.zip

  • data.scala4datascience.com/fec/us.csv.gz(或 us.csv.zip

解压缩这两个文件,并将它们放置在本章源代码示例相同的 data/ 目录中。数据文件对应以下内容:

  • ohio.csv 文件是俄亥俄州所有捐赠者捐赠的 CSV 文件。

  • us.csv 文件是全国范围内所有捐赠者捐赠的 CSV 文件。这是一个相当大的文件,有六百万行。

这两个 CSV 文件包含相同的列。使用俄亥俄州数据集可以获得更快的响应,或者如果你想处理更大的数据集,可以使用全国数据文件。数据集是从 www.fec.gov/disclosurep/PDownload.do 下载的贡献列表中改编的。

让我们先创建一个 Scala case class 来表示一笔交易。在本章的上下文中,一笔交易是一个个人向候选人捐赠的单笔捐款:

// Transaction.scala
import java.sql.Date

case class Transaction(
  id:Option[Int], // unique identifier
  candidate:String, // candidate receiving the donation
  contributor:String, // name of the contributor
  contributorState:String, // contributor state
  contributorOccupation:Option[String], // contributor job
  amount:Long, // amount in cents
  date:Date // date of the donation
)

本章的代码仓库包含一个 FECData 单例对象中的辅助函数,用于从 CSV 文件中加载数据:

scala> val ohioData = FECData.loadOhio
s4ds.FECData = s4ds.FECData@718454de

调用 FECData.loadOhioFECData.loadAll 将创建一个具有单个属性 transactionsFECData 对象,该属性是一个遍历来自俄亥俄州或整个美国的所有捐赠的迭代器:

scala> val ohioTransactions = ohioData.transactions
Iterator[Transaction] = non-empty iterator

scala> ohioTransactions.take(5).foreach(println)
Transaction(None,Paul, Ron,BROWN, TODD W MR.,OH,Some(ENGINEER),5000,2011-01-03)
Transaction(None,Paul, Ron,DIEHL, MARGO SONJA,OH,Some(RETIRED),2500,2011-01-03)
Transaction(None,Paul, Ron,KIRCHMEYER, BENJAMIN,OH,Some(COMPUTER PROGRAMMER),20120,2011-01-03)
Transaction(None,Obama, Barack,KEYES, STEPHEN,OH,Some(HR EXECUTIVE / ATTORNEY),10000,2011-01-03)
Transaction(None,Obama, Barack,MURPHY, MIKE W,OH,Some(MANAGER),5000,2011-01-03)

现在我们有一些数据可以操作,让我们尝试将其放入数据库中,以便我们可以运行一些有用的查询。

导入 Slick

要将 Slick 添加到依赖项列表中,你需要在 build.sbt 文件中的依赖项列表中添加 "com.typesafe.slick" %% "slick" % "2.1.0"。你还需要确保 Slick 有权访问 JDBC 驱动程序。在本章中,我们将连接到 MySQL 数据库,因此必须将 MySQL 连接器 "mysql" % "mysql-connector-java" % "5.1.37" 添加到依赖项列表中。

通过导入特定的数据库驱动程序来导入 Slick。由于我们使用 MySQL,我们必须导入以下内容:

scala> import slick.driver.MySQLDriver.simple._
import slick.driver.MySQLDriver.simple._

要连接到不同类型的 SQL 数据库,导入相关的驱动程序。查看可用的驱动程序的最简单方法是查阅slick.driver包的 API 文档,该文档可在slick.typesafe.com/doc/2.1.0/api/#scala.slick.driver.package找到。所有常见的 SQL 类型都受支持(包括H2PostgreSQLMS SQL ServerSQLite)。

定义模式

让我们创建一个表来表示我们的交易。我们将使用以下模式:

CREATE TABLE transactions(
    id INT(11) AUTO_INCREMENT PRIMARY KEY,
    candidate VARCHAR(254) NOT NULL,
    contributor VARCHAR(254) NOT NULL,
    contributor_state VARCHAR(2) NOT NULL,
    contributor_occupation VARCHAR(254),
    amount BIGINT(20) NOT NULL,
    date DATE 
);

注意,捐赠金额是以表示的。这允许我们使用整数字段(而不是定点小数,或者更糟糕的是浮点数)。

注意

你永远不应该使用浮点格式来表示金钱,实际上,任何离散量,因为浮点数无法精确表示大多数分数:

scala> 0.1 + 0.2
Double = 0.30000000000000004

这种看似荒谬的结果发生是因为在双精度浮点数中无法精确存储 0.3。

这篇文章广泛讨论了浮点格式限制:

docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html

要在数据库中的表上使用 Slick,我们首先需要告诉 Slick 有关数据库模式。我们通过创建一个扩展Table抽象类的类来实现这一点。定义模式的方式相当直接,所以让我们直接进入代码。我们将我们的模式存储在Tables单例中。我们定义一个Transactions类,它提供了从Transaction实例的集合到类似transactions表的 SQL 表的映射:

// Tables.scala

import java.sql.Date
import slick.driver.MySQLDriver.simple._

/** Singleton object for table definitions */
object Tables {

  // Transactions table definition
  class Transactions(tag:Tag)
  extends TableTransaction {
    def id = columnInt
    def candidate = columnString
    def contributor = columnString
    def contributorState = columnString"))
    def contributorOccupation = column[Option[String]](
      "contributor_occupation")
    def amount = columnLong
    def date = columnDate

    def * = (id.?, candidate, contributor, 
      contributorState, contributorOccupation, amount, date) <> (
      Transaction.tupled, Transaction.unapply)
  }

  val transactions = TableQuery[Transactions]

}

让我们逐行分析。我们首先定义一个Transactions类,它必须将其第一个参数作为 Slick Tag对象。Tag对象由 Slick 内部用于构建 SQL 语句。Transactions类扩展了一个Table对象,传递给它标签和数据库中表的名称。我们可以选择性地通过扩展TableTransaction, "transactions")而不是仅仅TableTransaction来添加数据库名称。Table类型由Transaction参数化。这意味着在数据库上运行SELECT语句返回Transaction对象。同样,我们将通过传递一个事务或事务列表到相关的 Slick 方法来在数据库中插入数据。

让我们更详细地看看Transactions类的定义。类的主体首先列出数据库列。例如,id列定义如下:

def id = columnInt

我们告诉 Slick 它应该读取名为id的列,并将其转换为 Scala 整数。此外,我们还告诉 Slick 该列是主键,并且它是自动增长的。Slick 文档中包含column可用选项的列表。

candidatecontributor 列很简单:我们告诉 Slick 从数据库中读取这些作为 Stringcontributor_state 列有点更有趣。除了指定它应该作为 String 从数据库中读取外,我们还告诉 Slick 它应该以 VARCHAR(2) 类型存储在数据库中。

我们表中的 contributor_occupation 列可以包含 NULL 值。在定义模式时,我们将 Option[String] 类型传递给列方法:

def contributorOccupation = column[Option[String]]("contributor_occupation")

当从数据库读取时,对于指定为 Option[T] 的列,NULL 字段将被转换为 None。相反,如果字段有值,它将以 Some(value) 返回。

类体的最后一行是最有趣的部分:它指定了如何将读取的原始数据转换为 Transaction 对象以及如何将 Transaction 对象转换为准备插入的原始字段:

def * = (id.?, candidate, contributor, 
contributorState, contributorOccupation, amount, date) <> (
Transaction.tupled, Transaction.unapply)

第一部分只是从数据库中读取的字段元组:(id.?, candidate, contributor, contributorState, contributorOccupation, amount, date),附带少量元数据。第二部分是一对函数,描述了如何将此元组转换为 Transaction 对象以及如何反向转换。在这种情况下,由于 Transaction 是一个案例类,我们可以利用为案例类自动提供的 Transaction.tupledTransaction.unapply 方法。

注意我们是如何在 id 条目后面跟着 .? 的。在我们的 Transaction 类中,捐赠 id 具有类型 Option[Int],但数据库中的列具有类型 INT,并附加了 O.AutoInc 选项。.? 后缀告诉 Slick 如果 idNone,则使用数据库提供的默认值(在这种情况下,数据库的自动递增)。

最后,我们定义了值:

val transactions = TableQuery[Transactions]

这是实际与数据库交互的句柄。例如,正如我们稍后将会看到的,要获取巴拉克·奥巴马的捐赠列表,我们将运行以下查询(现在不用担心查询的细节):

Tables.transactions.filter {_.candidate === "Obama, Barack"}.list

让我们总结一下我们的 Transactions 映射器类的部分:

  • Transactions 类必须扩展由我们想要返回的类型参数化的 Table 抽象类:Table[Transaction]

  • 我们使用 column 显式定义从数据库中读取的列,例如,def contributorState = columnString"))[String] 类型参数定义了此列读取为的 Scala 类型。第一个参数是 SQL 列名。请参阅 Slick 文档以获取附加参数的完整列表(slick.typesafe.com/doc/2.1.0/schemas.html)。

  • 我们描述了如何使用 def * = (id.?, candidate, ...) <> (Transaction.tupled, Transaction.unapply) 将列值的元组转换为 Scala 对象,反之亦然。

连接到数据库

到目前为止,你已经学习了如何定义Table类,这些类将 SQL 表中的行转换为 Scala 案例类。为了超越表定义并开始与数据库服务器交互,我们必须连接到数据库。与上一章类似,我们将假设有一个 MySQL 服务器在本地主机的3306端口上运行。

我们将使用控制台来演示本章的功能,但你可以在SlickDemo.scala中找到一个等效的示例程序。让我们打开一个 Scala 控制台并连接到在3306端口上运行的数据库:

scala> import slick.driver.MySQLDriver.simple._
import slick.driver.MySQLDriver.simple._

scala> val db = Database.forURL(
 "jdbc:mysql://127.0.0.1:3306/test",
 driver="com.mysql.jdbc.Driver"
)
db: slick.driver.MySQLDriver.backend.DatabaseDef = slick.jdbc.JdbcBackend$DatabaseDef@3632d1dd

如果你已经阅读了上一章,你会认出第一个参数是一个 JDBC 风格的 URL。URL 首先定义了一个协议,在本例中是jdbc:mysql,然后是数据库服务器的 IP 地址和端口,最后是数据库名称(这里为test)。

forURL的第二个参数是 JDBC 驱动程序的类名。此驱动程序在运行时使用反射导入。请注意,此处指定的驱动程序必须与静态导入的 Slick 驱动程序匹配。

定义了数据库后,我们现在可以使用它来创建一个连接:

scala> db.withSession { implicit session =>
 // do something useful with the database
 println(session)
}
scala.slick.jdbc.JdbcBackend$BaseSession@af5a276

需要访问数据库的流畅函数隐式地接受一个Session参数:如果作用域内可用一个标记为隐式的Session实例,它们将使用它。因此,在session前加上implicit关键字可以节省我们每次在数据库上运行操作时显式传递session

如果你已经阅读了上一章,你会认出 Slick 处理关闭连接的需求使用的是借款模式:数据库连接以session对象的形式创建,并临时传递给客户端。当客户端代码返回时,会话被关闭,确保所有打开的连接都被关闭。因此,客户端代码免除了关闭连接的责任。

借款模式在生产代码中非常有用,但在 shell 中可能会有些繁琐。Slick 允许我们显式地创建一个会话,如下所示:

scala> implicit val session = db.createSession
session: slick.driver.MySQLDriver.backend.Session = scala.slick.jdbc.JdbcBackend$BaseSession@2b775b49

scala> session.close

创建表格

让我们使用我们新的连接在数据库中创建事务表。我们可以通过TableQuery[Transactions]实例上的ddl属性访问创建和删除表的方法:

scala> db.withSession { implicit session =>
 Tables.transactions.ddl.create
}

如果你跳入mysql shell,你会看到已经创建了一个transactions表:

mysql> describe transactions ;
+------------------------+--------------+------+-----+
| Field                  | Type         | Null | Key |
+------------------------+--------------+------+-----+
| id                     | int(11)      | NO   | PRI |
| candidate              | varchar(254) | NO   |     |
| contributor            | varchar(254) | NO   |     |
| contributor_state      | varchar(2)   | NO   |     |
| contributor_occupation | varchar(254) | YES  |     |
| amount                 | bigint(20)   | NO   |     |
| date                   | date         | NO   |     |
+------------------------+--------------+------+-----+
7 rows in set (0.01 sec)

ddl属性还包括一个drop方法来删除表。顺便提一下,ddl代表“数据定义语言”,通常用来指代与模式定义和约束定义相关的 SQL 部分。

插入数据

Slick 的TableQuery实例让我们可以通过与 Scala 集合类似的接口与 SQL 表交互。

让我们首先创建一个交易。我们将假装在 2010 年 6 月 22 日发生了一次捐赠。不幸的是,在 Scala 中创建日期并将其传递给 JDBC 的代码特别繁琐。我们首先创建一个java.util.Date实例,然后我们必须将其转换为java.sql.Date以用于我们新创建的交易:

scala> import java.text.SimpleDateFormat
import java.text.SimpleDateFormat

scala> val date = new SimpleDateFormat("dd-MM-yyyy").parse("22-06-2010")
date: java.util.Date = Tue Jun 22 00:00:00 BST 2010

scala> val sqlDate = new java.sql.Date(date.getTime())
sqlDate: java.sql.Date = 2010-06-22

scala> val transaction = Transaction(
 None, "Obama, Barack", "Doe, John", "TX", None, 200, sqlDate
)
transaction: Transaction = Transaction(None,Obama, Barack,Doe, John,TX,None,200,2010-06-22)

TableQuery实例提供的界面大部分与可变列表相似。为了在事务表中插入单行,我们可以使用+=运算符:

scala> db.withSession {
 implicit session => Tables.transactions += transaction
}
Int = 1

在底层,这将创建一个 JDBC 预处理语句并运行该语句的executeUpdate方法。

如果你一次提交多行,你应该使用 Slick 的批量插入运算符:++=。它接受一个List[Transaction]作为输入,并利用 JDBC 的addBatchexecuteBatch功能将所有交易批量插入:

让我们插入所有 FEC 交易,这样我们就有一些数据在下一节运行查询时可以操作。我们可以通过调用以下代码来加载俄亥俄州的交易迭代器:

scala> val transactions = FECData.loadOhio.transactions
transactions: Iterator[Transaction] = non-empty iterator

我们也可以加载整个美国的交易:

scala> val transactions = FECData.loadAll.transactions
transactions: Iterator[Transaction] = non-empty iterator

为了避免一次性将所有交易实体化——从而可能超出我们计算机的可用内存——我们将从迭代器中取出交易批次并插入它们:

scala> val batchSize = 100000
batchSize: Int = 100000

scala> val transactionBatches = transactions.grouped(batchSize)
transactionBatches: transactions.GroupedIterator[Transaction] = non-empty iterator

迭代器的grouped方法将迭代器分割成批次。将长集合或迭代器分割成可管理的批次,以便可以逐个处理,这在集成或处理大型数据集时非常重要。

我们现在要做的就是遍历我们的批次,在插入的同时将它们存入数据库:

scala> db.withSession { implicit session =>
 transactionBatches.foreach { 
 batch => Tables.transactions ++= batch.toList
 }
}

虽然这样可行,但在进行长时间运行的集成过程时,有时查看进度报告是有用的。由于我们将集成分成了批次,我们知道(到最近的批次为止)我们集成到了哪里。让我们在每一个批次的开始处打印进度信息:

scala> db.withSession { implicit session =>
 transactionBatches.zipWithIndex.foreach { 
 case (batch, batchNumber) =>
 println(s"Processing row ${batchNumber*batchSize}")
 Tables.transactions ++= batch.toList
 }
}
Processing row 0
Processing row 100000
...

我们使用.zipWithIndex方法将我们的批次迭代器转换为(批次, 当前索引)对的迭代器。在一个完整规模的应用中,进度信息可能会被写入日志文件而不是屏幕。

Slick 精心设计的界面使得插入数据非常直观,与原生 Scala 类型很好地集成。

查询数据

在上一节中,我们使用 Slick 将捐赠数据插入到我们的数据库中。现在让我们探索这些数据。

在定义Transactions类时,我们定义了一个TableQuery对象transactions,它作为访问交易表的句柄。它提供了一个类似于 Scala 迭代器的接口。例如,要查看数据库中的前五个元素,我们可以调用take(5)

scala> db.withSession { implicit session =>
 Tables.transactions.take(5).list
}
List[Tables.Transactions#TableElementType] = List(Transaction(Some(1),Obama, Barack,Doe, ...

内部,Slick 使用 SQL LIMIT来实现.take方法。实际上,我们可以通过查询上的.selectStatement方法来获取 SQL 语句:

scala> db.withSession { implicit session =>
 println(Tables.transactions.take(5).selectStatement)
}
select x2.`id`, x2.`candidate`, x2.`contributor`, x2.`contributor_state`, x2.`contributor_occupation`, x2.`amount`, x2.`date` from (select x3.`date` as `date`, x3.`contributor` as `contributor`, x3.`amount` as `amount`, x3.`id` as `id`, x3.`candidate` as `candidate`, x3.`contributor_state` as `contributor_state`, x3.`contributor_occupation` as `contributor_occupation` from `transactions` x3 limit 5) x2

我们的流畅查询由以下两个部分组成:

  • .take(n): 这一部分称为 调用器。调用器构建 SQL 语句,但不会实际将其发送到数据库。你可以将多个调用器链接在一起来构建复杂的 SQL 语句。

  • .list: 这一部分将调用器准备好的语句发送到数据库,并将结果转换为 Scala 对象。这需要一个 session 参数,可能是隐式的。

调用器

调用器是 Slick 查询的组成部分,用于构建 SQL 选择语句。Slick 提供了各种调用器,允许构建复杂的查询。让我们看看其中的一些调用器:

  • map 调用器用于选择单个列或对列应用操作:

    scala> db.withSession { implicit session =>
     Tables.transactions.map {
     _.candidate 
     }.take(5).list 
    }
    List[String] = List(Obama, Barack, Paul, Ron, Paul, Ron, Paul, Ron, Obama, Barack)
    
    
  • filter 调用器等同于 SQL 中的 WHERE 语句。请注意,Slick 字段必须使用 === 进行比较:

    scala> db.withSession { implicit session => 
     Tables.transactions.filter {
     _.candidate === "Obama, Barack"
     }.take(5).list
    }
    List[Tables.Transactions#TableElementType] = List(Transaction(Some(1),Obama, Barack,Doe, John,TX,None,200,2010-06-22), ...
    
    

    同样,要过滤掉对巴拉克·奥巴马的捐款,使用 =!= 运算符:

    scala> db.withSession { implicit session => 
     Tables.transactions.filter { 
     _.candidate =!= "Obama, Barack"
     }.take(5).list
    }
    List[Tables.Transactions#TableElementType] = List(Transaction(Some(2),Paul, Ron,BROWN, TODD W MR.,OH,...
    
    
  • sortBy 调用器等同于 SQL 中的 ORDER BY 语句:

    scala> db.withSession { implicit session => 
     Tables.transactions.sortBy { 
     _.date.desc 
     }.take(5).list
    }
    List[Tables.Transactions#TableElementType] = List(Transaction(Some(65536),Obama, Barack,COPELAND, THOMAS,OH,Some(COLLEGE TEACHING),10000,2012-01-02)
    
    
  • leftJoinrightJoininnerJoinouterJoin 调用器用于连接表。由于本教程不涉及多表之间的交互,我们无法演示连接操作。请参阅 Slick 文档(slick.typesafe.com/doc/2.1.0/queries.html#joining-and-zipping)以了解这些操作的示例。

  • lengthminmaxsumavg 等聚合调用器可用于计算汇总统计信息。这些操作必须使用 .run 而不是 .list 执行,因为它们返回单个数字。例如,要获取巴拉克·奥巴马的总捐款:

    scala> db.withSession { implicit session => 
     Tables.transactions.filter {
     _.candidate === "Obama, Barack"
     }.map { _.amount  }.sum.run
    }
    Option[Int] = Some(849636799) // (in cents)
    
    

列上的操作

在上一节中,你学习了不同的调用器以及它们如何映射到 SQL 语句。然而,我们只是简要地提到了列本身支持的方法:我们可以使用 === 进行相等比较,但 Slick 列还支持哪些其他操作?

大多数 SQL 函数都得到了支持。例如,要获取以 "O" 开头的候选人的总捐款,我们可以运行以下命令:

scala> db.withSession { implicit session =>
 Tables.transactions.filter { 
 _.candidate.startsWith("O") 
 }.take(5).list 
}
List[Tables.Transactions#TableElementType] = List(Transaction(Some(1594098)...

同样,要计算在 2011 年 1 月 1 日至 2011 年 2 月 1 日之间发生的捐款,我们可以在 date 列上使用 .between 方法:

scala> val dateParser = new SimpleDateFormat("dd-MM-yyyy")
dateParser: java.text.SimpleDateFormat = SimpleDateFormat

scala> val startDate = new java.sql.Date(dateParser.parse("01-01-2011").getTime())
startDate: java.sql.Date = 2011-01-01

scala> val endDate = new java.sql.Date(dateParser.parse("01-02-2011").getTime())
endDate: java.sql.Date = 2011-02-01

scala> db.withSession { implicit session =>
 Tables.transactions.filter { 
 _.date.between(startDate, endDate)
 }.length.run 
}
Int = 9772

等同于 SQL 中的 IN (...) 操作符,用于选择特定集合中的值的是 inSet。例如,要选择所有对巴拉克·奥巴马和米特·罗姆尼的交易,我们可以使用以下命令:

scala> val candidateList = List("Obama, Barack", "Romney, Mitt")
candidateList: List[String] = List(Obama, Barack, Romney, Mitt)

scala> val donationCents = db.withSession { implicit session =>
 Tables.transactions.filter {
 _.candidate.inSet(candidateList)
 }.map { _.amount }.sum.run
}
donationCents: Option[Long] = Some(2874484657)

scala> val donationDollars = donationCents.map { _ / 100 }
donationDollars: Option[Long] = Some(28744846)

因此, Mitt Romney 和 Barack Obama 之间,他们共收到了超过 2800 万美元的注册捐款。

我们还可以使用 ! 运算符否定布尔列。例如,要计算除巴拉克·奥巴马和米特·罗姆尼之外所有候选人的总捐款金额:

scala> db.withSession { implicit session =>
 Tables.transactions.filter { 
 ! _.candidate.inSet(candidateList) 
 }.map { _.amount }.sum.run
}.map { _ / 100 }
Option[Long] = Some(1930747)

列操作是通过在基Column实例上隐式转换来添加的。有关字符串列上可用的所有方法的完整列表,请参阅StringColumnExtensionMethods类的 API 文档(slick.typesafe.com/doc/2.1.0/api/#scala.slick.lifted.StringColumnExtensionMethods)。对于布尔列上可用的方法,请参阅BooleanColumnExtensionMethods类的 API 文档(slick.typesafe.com/doc/2.1.0/api/#scala.slick.lifted.BooleanColumnExtensionMethods)。对于数值列上可用的方法,请参阅NumericColumnExtensionMethods的 API 文档(slick.typesafe.com/doc/2.1.0/api/#scala.slick.lifted.NumericColumnExtensionMethods)。

使用“按组分组”进行聚合

Slick 还提供了一个groupBy方法,其行为类似于原生 Scala 集合的groupBy方法。让我们获取每个候选人的所有捐款的候选人列表:

scala> val grouped = Tables.transactions.groupBy { _.candidate }
grouped: scala.slick.lifted.Query[(scala.slick.lifted.Column[...

scala> val aggregated = grouped.map {
 case (candidate, group) =>
 (candidate -> group.map { _.amount }.sum)
}
aggregated: scala.slick.lifted.Query[(scala.slick.lifted.Column[...

scala> val groupedDonations = db.withSession { 
 implicit session => aggregated.list 
}
groupedDonations: List[(String, Option[Long])] = List((Bachmann, Michele,Some(7439272)),...

让我们分解一下。第一条语句transactions.groupBy { _.candidate }指定了分组的关键字。你可以将其视为构建一个中间列表,其中包含(String, List[Transaction])元组,将组关键字映射到满足此关键字的表的所有行列表。这种行为与在 Scala 集合上调用groupBy相同。

groupBy调用必须后跟一个map来聚合组。传递给map的函数必须将groupBy调用创建的(String, List[Transaction])元组对作为其唯一参数。map调用负责聚合List[Transaction]对象。我们选择首先提取每个交易的amount字段,然后对这些字段进行求和。最后,我们在整个管道上调用.list来实际运行查询。这仅仅返回一个 Scala 列表。让我们将总捐款从分转换为美元:

scala> val groupedDonationDollars = groupedDonations.map {
 case (candidate, donationCentsOption) =>
 candidate -> (donationCentsOption.getOrElse(0L) / 100)
}
groupedDonationDollars: List[(String, Long)] = List((Bachmann, Michele,74392),...

scala> groupedDonationDollars.sortBy { 
 _._2 
}.reverse.foreach { println }
(Romney, Mitt,20248496)
(Obama, Barack,8496347)
(Paul, Ron,565060)
(Santorum, Rick,334926)
(Perry, Rick,301780)
(Gingrich, Newt,277079)
(Cain, Herman,210768)
(Johnson, Gary Earl,83610)
(Bachmann, Michele,74392)
(Pawlenty, Timothy,42500)
(Huntsman, Jon,23571)
(Roemer, Charles E. 'Buddy' III,8579)
(Stein, Jill,5270)
(McCotter, Thaddeus G,3210)

访问数据库元数据

通常,尤其是在开发期间,你可能从删除(如果存在)表并重新创建它开始脚本。我们可以通过通过MTable对象访问数据库元数据来检查表是否已定义。要获取与特定模式匹配的表列表,我们可以运行MTable.getTables(pattern)

scala> import slick.jdbc.meta.MTable
import slick.jdbc.meta.MTable

scala> db.withSession { implicit session =>
 MTable.getTables("transactions").list
}
List[scala.slick.jdbc.meta.MTable] = List(MTable(MQName(fec.transactions),TABLE,,None,None,None) ...)

因此,为了删除(如果存在)交易表,我们可以运行以下操作:

scala> db.withSession { implicit session =>
 if(MTable.getTables("transactions").list.nonEmpty) {
 Tables.transactions.ddl.drop
 }
}

MTable实例包含关于表的大量元数据。如果你在前一个示例中删除了它,现在就重新创建transactions表。然后,为了找到关于表的主键的信息:

scala> db.withSession { implicit session =>
 val tableMeta = MTable.getTables("transactions").first
 tableMeta.getPrimaryKeys.list
}
List[MPrimaryKey] = List(MPrimaryKey(MQName(test.transactions),id,1,Some(PRIMARY)))

要获取MTable实例上可用的方法完整列表,请参阅 Slick 文档(slick.typesafe.com/doc/2.1.0/api/index.html#scala.slick.jdbc.meta.MTable)。

Slick 与 JDBC 的比较

本章和上一章介绍了两种不同的与 SQL 交互的方式。在上一章中,我们描述了如何使用 JDBC 并在其之上构建扩展以使其更易于使用。在本章中,我们介绍了 Slick,这是一个在 JDBC 之上提供函数式接口的库。

你应该选择哪种方法?如果你正在启动一个新项目,你应该考虑使用 Slick。即使你花费相当多的时间编写位于 JDBC 之上的包装器,你也不太可能达到 Slick 提供的流畅性。

如果你正在对一个大量使用 JDBC 的现有项目进行工作,我希望上一章的示例表明,只需一点时间和努力,你就可以编写 JDBC 包装器,以减少 JDBC 的命令式风格和 Scala 的函数式方法之间的阻抗。

摘要

在前两章中,我们广泛地探讨了如何从 Scala 查询关系型数据库。在本章中,你学习了如何使用 Slick,这是一个“函数式关系型”映射器,允许像与 Scala 集合交互一样与 SQL 数据库交互。

在下一章中,你将学习如何通过查询 Web API 来摄取数据。

参考文献

要了解更多关于 Slick 的信息,你可以参考 Slick 文档(slick.typesafe.com/doc/2.1.0/)及其 API 文档(slick.typesafe.com/doc/2.1.0/api/#package)。

第七章。Web API

数据科学家和数据工程师从各种不同的来源获取数据。通常,数据可能以 CSV 文件或数据库转储的形式出现。有时,我们必须通过 Web API 获取数据。

个人或组织建立 Web API 以通过互联网(或内部网络)向程序分发数据。与数据旨在由 Web 浏览器消费并显示给用户的网站不同,Web API 提供的数据对查询它的程序类型是中立的。服务于 HTML 的 Web 服务器和支撑 API 的 Web 服务器基本上以相同的方式进行查询:通过 HTTP 请求。

我们已经在第四章中看到了一个 Web API 的例子,并行集合和未来,其中我们查询了“Markit on demand”API 以获取当前的股票价格。在本章中,我们将更详细地探讨如何与 Web API 交互;具体来说,如何将 API 返回的数据转换为 Scala 对象,以及如何通过 HTTP 头(例如,用于身份验证)向请求添加额外信息。

“按需 Markit”API 返回的数据格式化为 XML 对象,但越来越多的新 Web API 返回的数据格式化为 JSON。因此,本章我们将重点关注 JSON,但这些概念很容易应用到 XML 上。

JSON 是一种用于格式化结构化数据的语言。许多读者在过去可能已经遇到过 JSON,但如果没有,本章后面将简要介绍其语法和概念。你会发现它非常直观。

在本章中,我们将轮询 GitHub API。在过去的几年里,GitHub 已经成为开源软件协作的事实上工具。它提供了一个功能强大、特性丰富的 API,可以以编程方式访问网站上的几乎所有数据。

让我们来尝尝我们能做什么。在您的网络浏览器地址栏中输入api.github.com/users/odersky。这将返回 API 提供的特定用户(在这种情况下是 Martin Odersky)的数据:

{
  "login": "odersky",
  "id": 795990,
  ...
  "public_repos": 8,
  "public_gists": 3,
  "followers": 707,
  "following": 0,
  "created_at": "2011-05-18T14:51:21Z",
  "updated_at": "2015-09-15T15:14:33Z"
}

数据以 JSON 对象的形式返回。本章致力于学习如何以编程方式访问和解析这些数据。在第十三章《使用 Play 的 Web API》中,你将学习如何构建你自己的 Web API。

小贴士

GitHub API 非常广泛且文档齐全。在本章中,我们将探索 API 的一些功能。要查看 API 的完整范围,请访问文档(developer.github.com/v3/)。

JSON 快速浏览

JSON 是一种用于传输结构化数据的格式。它灵活,易于计算机生成和解析,对于人类来说相对易于阅读。它已成为持久化程序数据结构和在程序之间传输数据的一种非常常见的方式。

JSON 有四种基本类型:数字字符串布尔值null,以及两种复合类型:数组对象。对象是无序的键值对集合,其中键始终是字符串,值可以是任何简单或复合类型。我们已经看到了一个 JSON 对象:API 调用api.github.com/users/odersky返回的数据。

数组是有序的简单或复合类型列表。例如,在您的浏览器中输入api.github.com/users/odersky/repos,以获取一个对象数组,每个对象代表一个 GitHub 仓库:

[
  {
    "id": 17335228,
    "name": "dotty",
    "full_name": "odersky/dotty",
    ...
  },
  {
    "id": 15053153,
    "name": "frontend",
    "full_name": "odersky/frontend",
    ...
  },
  ...
]

我们可以通过在对象或数组内部嵌套其他对象来构建复杂结构。然而,大多数 Web API 返回的 JSON 结构最多只有一两个嵌套层级。如果你不熟悉 JSON,我鼓励你通过你的网络浏览器探索 GitHub API。

查询 Web API

从 Scala 中查询 Web API 最简单的方法是使用 Source.fromURL。我们已经在 第四章,并行集合和未来 中使用过它,当时我们查询了 "Markit on demand" API。Source.fromURL 提供了一个类似于 Source.fromFile 的接口:

scala> import scala.io._
import scala.io._

scala> val response = Source.fromURL(
 "https://api.github.com/users/odersky"
).mkString
response: String = {"login":"odersky","id":795990, ...

Source.fromURL 返回响应字符的迭代器。我们使用其 .mkString 方法将迭代器实体化为一个字符串。现在我们有了作为 Scala 字符串的响应。下一步是使用 JSON 解析器解析字符串。

Scala 中的 JSON – 一个模式匹配练习

Scala 中有几个用于操作 JSON 的库。我们更喜欢 json4s,但如果你是另一个 JSON 库的死忠粉丝,你应该能够轻松地适应本章中的示例。让我们创建一个包含对 json4s 依赖的 build.sbt 文件:

// build.sbt
scalaVersion := "2.11.7"

libraryDependencies += "org.json4s" %% "json4s-native" % "3.2.11"

我们可以导入 json4s 到 SBT 控制台会话中:

scala> import org.json4s._
import org.json4s._

scala> import org.json4s.native.JsonMethods._
import org.json4s.native.JsonMethods._

让我们使用 json4s 解析 GitHub API 查询的响应:

scala> val jsonResponse = parse(response)
jsonResponse: org.json4s.JValue = JObject(List((login,JString(odersky)),(id,JInt(795990)),...

parse 方法接受一个字符串(包含格式良好的 JSON),并将其转换为 JValue,这是所有 json4s 对象的超类型。此特定查询的响应运行时类型为 JObject,它是表示 JSON 对象的 json4s 类型。

JObject 是一个围绕 List[JField] 的包装器,JField 表示对象中的单个键值对。我们可以使用 提取器 来访问这个列表:

scala> val JObject(fields) = jsonResponse
fields: List[JField] = List((login,Jstring(odersky)),...

这里发生了什么?通过编写 val JObject(fields) = ...,我们告诉 Scala:

  • 右侧的运行时类型为 JObject

  • 进入 JObject 实例并将字段列表绑定到常量 fields

熟悉 Python 的读者可能会注意到与元组解包的相似之处,尽管 Scala 提取器要强大得多,也更加灵活。提取器被广泛用于从 json4s 类型中提取 Scala 类型。

小贴士

使用案例类进行模式匹配

Scala 编译器是如何知道如何处理像这样的提取器的:

val JObject(fields) = ...

JObject 是一个具有以下构造函数的案例类:

case class JObject(obj:List[JField])

所有案例类都附带一个提取器,它可以精确地反转构造函数。因此,编写 val JObject(fields) 将将 fields 绑定到 JObjectobj 属性。有关提取器如何工作的更多详细信息,请参阅 附录,模式匹配和提取器

现在我们已经从 JObject 中提取了 fields,这是一个(普通的旧 Scala)字段列表。JField 是一个键值对,键是一个字符串,值是 JValue 的子类型。同样,我们可以使用提取器来提取字段中的值:

scala> val firstField = fields.head
firstField: JField = (login,JString(odersky))

scala> val JField(key, JString(value)) = firstField
key: String = login
value: String = odersky

我们将右侧与模式 JField(_, JString(_)) 匹配,将第一个元素绑定到 key,第二个绑定到 value。如果右侧不匹配模式会发生什么?

scala> val JField(key, JInt(value)) = firstField
scala.MatchError: (login,JString(odersky)) (of class scala.Tuple2)
...

代码在运行时抛出MatchError。以下示例展示了嵌套模式匹配的强大功能:在一行代码中,我们成功验证了firstField的类型,确认其值为JString类型,并将键和值分别绑定到keyvalue变量。作为另一个例子,如果我们知道第一个字段是登录字段,我们既可以验证这一点,也可以提取其值:

scala> val JField("login", JString(loginName)) = firstField
loginName: String = odersky

注意这种编程风格是声明式而不是命令式:我们在右侧声明我们想要一个JField("login", JString(_))变量。然后让语言找出如何检查变量类型。模式匹配是函数式语言中的一个常见主题。

我们还可以在遍历字段时使用模式匹配。当在 for 循环中使用时,模式匹配定义了一个部分函数:只有与模式匹配的元素才会通过循环。这让我们能够过滤出匹配模式的元素集合,并对这些元素应用转换。例如,我们可以从我们的fields列表中提取每个字符串字段:

scala> for {
 JField(key, JString(value)) <- fields
} yield (key -> value)
List[(String, String)] = List((login,odersky), (avatar_url,https://avatars.githubusercontent.com/...

我们可以使用它来搜索特定字段。例如,提取"followers"字段:

scala> val followersList = for {
 JField("followers", JInt(followers)) <- fields
} yield followers
followersList: List[Int] = List(707)

scala> val followers = followersList.headOption
blogURL: Option[Int] = Some(707)

我们首先提取所有匹配模式JField("follower", JInt(_))的字段,返回JInt内部的整数。由于源集合fields是一个列表,这返回一个整数列表。然后我们使用headOption从该列表中提取第一个值,它如果列表至少有一个元素,则返回列表的头部,如果列表为空,则返回None

我们不仅限于一次提取一个字段。例如,要一起提取"id""login"字段:

scala> {
 for {
 JField("login", JString(loginName)) <- fields
 JField("id", JInt(id)) <- fields
 } yield (id -> loginName)
}.headOption 
Option[(BigInt, String)] = Some((795990,odersky))

Scala 的模式匹配和提取器为你提供了一种极其强大的方法来遍历json4s树,提取我们需要的字段。

JSON4S 类型

我们已经发现了json4s类型层次结构的一部分:字符串被包裹在JString对象中,整数(或大整数)被包裹在JInt中,依此类推。在本节中,我们将退后一步,正式化类型结构和它们提取到的 Scala 类型。这些都是json4s的运行时类型:

  • val JString(s) // => 提取为 String

  • val JDouble(d) // => 提取为 Double

  • val JDecimal(d) // => 提取为 BigDecimal

  • val JInt(i) // => 提取为 BigInt

  • val JBool(b) // => 提取为布尔值

  • val JObject(l) // => 提取为[JField]列表

  • val JArray(l) // => 提取为[JValue]列表

  • JNull // => 表示 JSON null

所有这些类型都是JValue的子类。parse的编译时结果是JValue,你通常需要使用提取器将其转换为具体类型。

层次结构中的最后一个类型是JField,它表示键值对。JField只是(String, JValue)元组的类型别名。因此,它不是JValue的子类型。我们可以使用以下提取器提取键和值:

val JField(key, JInt(value)) = ...

使用 XPath 提取字段

在前面的章节中,你学习了如何使用提取器遍历 JSON 对象。在本节中,我们将探讨另一种遍历 JSON 对象和提取特定字段的方法:XPath DSL(领域特定语言)。XPath 是一种用于遍历树状结构的查询语言。它最初是为在 XML 文档中定位特定节点而设计的,但它同样适用于 JSON。当我们从“Markit on demand”API 返回的 XML 文档中提取股票价格时,我们已经看到了 XPath 语法的示例,这在第四章中,并行集合和未来。我们使用 r \ "LastPrice" 提取了标签为 "LastPrice" 的节点。\ 操作符是由 scala.xml 包定义的。

json4s 包提供了一个类似的 DSL 来从 JObject 实例中提取字段。例如,我们可以从 JSON 对象 jsonResponse 中提取 "login" 字段:

scala> jsonResponse \ "login"
org.json4s.JValue = JString(odersky)

这返回了一个 JValue,我们可以使用提取器将其转换为 Scala 字符串:

scala> val JString(loginName) = jsonResponse \ "login"
loginName: String = odersky

注意 XPath DSL 和遍历文件系统的相似性:我们可以将 JObject 实例视为目录。字段名对应文件名,字段值对应文件内容。这在嵌套结构中更为明显。GitHub API 的 users 端点没有嵌套文档,所以让我们尝试另一个端点。我们将查询与这本书对应的仓库的 API:"api.github.com/repos/pbugnion/s4ds"。响应具有以下结构:

{
  "id": 42269470,
  "name": "s4ds",
  ...
  "owner": { "login": "pbugnion", "id": 1392879 ... }
  ...
}

让我们获取这个文档,并使用 XPath 语法提取仓库所有者的登录名:

scala> val jsonResponse = parse(Source.fromURL(
 "https://api.github.com/repos/pbugnion/s4ds"
).mkString)
jsonResponse: JValue = JObject(List((id,JInt(42269470)), (name,JString(s4ds))...

scala> val JString(ownerLogin) = jsonResponse \ "owner" \ "login"
ownerLogin: String = pbugnion

再次,这类似于遍历文件系统:jsonResponse \ "owner" 返回与 "owner" 对象相对应的 JObject。这个 JObject 可以进一步查询 "login" 字段,返回与该键关联的值 JString(pbugnion)

如果 API 响应是一个数组呢?文件系统类比就有些不适用了。让我们查询列出马丁·奥德斯基仓库的 API 端点:api.github.com/users/odersky/repos。响应是一个包含 JSON 对象的数组,每个对象代表一个仓库:

[
  {
    "id": 17335228,
    "name": "dotty",
    "size": 14699,
    ...
  },
  {
    "id": 15053153,
    "name": "frontend",
    "size": 392
    ...
  },
  {
    "id": 2890092,
    "name": "scala",
    "size": 76133,
    ...
  },
  ...
]

让我们获取这个文档并将其解析为 JSON:

scala> val jsonResponse = parse(Source.fromURL(
 "https://api.github.com/users/odersky/repos"
).mkString)
jsonResponse: JValue = JArray(List(JObject(List((id,JInt(17335228)), (name,Jstring(dotty)), ...

这返回了一个 JArray。XPath DSL 在 JArray 上的工作方式与在 JObject 上相同,但现在,它返回的是一个与数组中每个对象的路径匹配的字段数组。让我们获取所有马丁·奥德斯基的仓库的大小:

scala> jsonResponse \ "size"
JValue = JArray(List(JInt(14699), JInt(392), ...

现在,我们有一个包含每个仓库中 "size" 字段值的 JArray。我们可以使用 for 理解遍历这个数组,并使用提取器将元素转换为 Scala 对象:

scala> for {
 JInt(size) <- (jsonResponse \ "size")
} yield size
List[BigInt] = List(14699, 392, 76133, 32010, 98166, 1358, 144, 273)

因此,结合提取器和 XPath DSL,我们得到了从 JSON 对象中提取信息的有力、互补的工具。

XPath 语法远比这里能涵盖的要多,包括从当前根的任何深度提取嵌套字段的能力,或者匹配谓词或特定类型的字段。我们发现,设计良好的 API 可以消除许多这些更强大功能的需求,但请查阅文档(json4s.org)以了解您可以做什么的概述。

在下一节中,我们将探讨如何直接将 JSON 提取到案例类中。

使用案例类进行提取

在前面的章节中,我们使用 Scala 提取器从 JSON 响应中提取了特定字段。我们可以做得更好,提取完整的案例类。

当我们超出 REPL 时,编程最佳实践规定我们应该尽快从 json4s 类型移动到 Scala 对象,而不是在程序中传递 json4s 类型。从 json4s 类型转换为 Scala 类型(或表示域对象的案例类)是良好的实践,因为:

  • 它将程序与从 API 收到的数据结构解耦,我们对这些结构几乎没有控制权。

  • 它提高了类型安全性:从编译器的角度来看,JObject 总是 JObject,无论它包含哪些字段。相比之下,编译器永远不会将 User 错误地认为是 Repository

Json4s 允许我们直接从 JObject 实例中提取案例类,这使得将 JObject 实例转换为自定义类型层的编写变得简单。

让我们定义一个表示 GitHub 用户的案例类:

scala> case class User(id:Long, login:String)
defined class User

要从 JObject 中提取案例类,我们首先必须定义一个隐式 Formats 值,该值定义了简单类型应该如何序列化和反序列化。我们将使用 json4s 提供的默认 DefaultFormats

scala> implicit val formats = DefaultFormats
formats: DefaultFormats.type = DefaultFormats$@750e685a

我们现在可以提取 User 的实例。让我们为马丁·奥德斯基(Martin Odersky)做这个操作:

scala> val url = "https://api.github.com/users/odersky"
url: String = https://api.github.com/users/odersky

scala> val jsonResponse = parse(Source.fromURL(url).mkString)
jsonResponse: JValue = JObject(List((login,JString(odersky)), ...

scala> jsonResponse.extract[User]
User = User(795990,odersky)

只要对象格式良好,这种方法就有效。extract 方法在 JObject 中寻找与 User 属性匹配的字段。在这种情况下,extract 会注意到 JObject 包含 "login": "odersky" 字段,并且 JString("odersky") 可以转换为 Scala 字符串,因此它将 "odersky" 绑定到 User 中的 login 属性。

如果属性名称与 JSON 对象中的字段名称不同怎么办?我们必须首先将对象转换为具有正确字段的形式。例如,让我们将 User 类中的 login 属性重命名为 userName

scala> case class User(id:Long, userName:String)
defined class User

如果我们尝试在 jsonResponse 上使用 extract[User],我们将得到一个映射错误,因为反序列化器在响应中缺少 login 字段。我们可以通过在 jsonResponse 上使用 transformField 方法来重命名 login 字段来修复这个问题:

scala> jsonResponse.transformField { 
 case("login", n) => "userName" -> n 
}.extract[User]
User = User(795990,odersky)

关于可选字段怎么办?假设 GitHub API 返回的 JSON 对象并不总是包含 login 字段。我们可以在对象模型中通过将 login 参数的类型指定为 Option[String] 而不是 String 来表示这一点:

scala> case class User(id:Long, login:Option[String])
defined class User

这正如你所期望的那样工作。当响应包含非空的login字段时,调用extract[User]会将其反序列化为Some(value),如果它缺失或为JNull,则会产生None

scala> jsonResponse.extract[User]
User = User(795990,Some(odersky))

scala> jsonResponse.removeField { 
 case(k, _) => k == "login" // remove the "login" field
}.extract[User]
User = User(795990,None)

让我们将这个功能封装在一个小的程序中。该程序将接受一个命令行参数,即用户的登录名,提取一个User实例,并将其打印到屏幕上:

// GitHubUser.scala

import scala.io._
import org.json4s._
import org.json4s.native.JsonMethods._

object GitHubUser {

  implicit val formats = DefaultFormats

  case class User(id:Long, userName:String)

  /** Query the GitHub API corresponding to `url` 
    * and convert the response to a User.
    */
  def fetchUserFromUrl(url:String):User = {
    val response = Source.fromURL(url).mkString
    val jsonResponse = parse(response)
    extractUser(jsonResponse)
  }

  /** Helper method for transforming the response to a User */
  def extractUser(obj:JValue):User = {
    val transformedObject = obj.transformField {
      case ("login", name) => ("userName", name)
    }
    transformedObject.extract[User]
  }

  def main(args:Array[String]) {
    // Extract username from argument list
    val name = args.headOption.getOrElse { 
      throw new IllegalArgumentException(
        "Missing command line argument for user.")
    }

    val user = fetchUserFromUrl(
      s"https://api.github.com/users/$name")

    println(s"** Extracted for $name:")
    println()
    println(user)

  }

}

我们可以从 SBT 控制台按照以下方式运行此程序:

$ sbt
> runMain GitHubUser pbugnion
** Extracted for pbugnion:
User(1392879,pbugnion)

使用未来(futures)进行并发和异常处理

尽管我们在上一节中编写的程序可以工作,但它非常脆弱。如果我们输入一个不存在的用户名,或者 GitHub API 发生变化或返回格式错误的响应,它将会崩溃。我们需要使其具有容错性。

如果我们还想获取多个用户呢?按照目前的程序编写方式,它是完全单线程的。fetchUserFromUrl方法会向 API 发起调用并阻塞,直到 API 返回数据。一个更好的解决方案是并行地获取多个用户。

正如你在第四章,并行集合和未来中学到的,实现容错性和并行执行有两种简单的方法:我们可以将所有用户名放入并行集合中,并将获取和提取用户的代码封装在Try块中,或者我们可以将每个查询封装在未来的框架中。

当查询网络 API 时,有时请求可能会异常地长时间。为了避免这阻碍其他线程,最好依赖于未来(futures)而不是并行集合(parallel collections)来实现并发,正如我们在第四章末尾的并行集合或未来?部分所看到的,并行集合和未来

让我们重写上一节中的代码,以并行且容错地获取多个用户。我们将把fetchUserFromUrl方法改为异步查询 API。这与第四章中的内容没有太大区别,我们在其中查询了"Markit on demand" API:

// GitHubUserConcurrent.scala

import scala.io._
import scala.concurrent._
import scala.concurrent.duration._
import ExecutionContext.Implicits.global
import scala.util._

import org.json4s._
import org.json4s.native.JsonMethods._

object GitHubUserConcurrent {

  implicit val formats = DefaultFormats

  case class User(id:Long, userName:String)

  // Fetch and extract the `User` corresponding to `url`
  def fetchUserFromUrl(url:String):Future[User] = {
    val response = Future { Source.fromURL(url).mkString }
    val parsedResponse = response.map { r => parse(r) }
    parsedResponse.map { extractUser }
  }

  // Helper method for extracting a user from a JObject
  def extractUser(jsonResponse:JValue):User = {
    val o = jsonResponse.transformField {
      case ("login", name) => ("userName", name)
    }
    o.extract[User]
  }

  def main(args:Array[String]) {
    val names = args.toList

    // Loop over each username and send a request to the API 
    // for that user 
    val name2User = for {
      name <- names
      url = s"https://api.github.com/users/$name"
      user = fetchUserFromUrl(url)
    } yield name -> user

    // callback function
    name2User.foreach { case(name, user) =>
      user.onComplete {
        case Success(u) => println(s" ** Extracted for $name: $u")
        case Failure(e) => println(s" ** Error fetching $name:$e")
      }
    }

    // Block until all the calls have finished.
    Await.ready(Future.sequence(name2User.map { _._2 }), 1 minute)
  }
}

让我们通过sbt运行这段代码:

$ sbt
> runMain GitHubUserConcurrent odersky derekwyatt not-a-user-675
 ** Error fetching user not-a-user-675: java.io.FileNotFoundException: https://api.github.com/users/not-a-user-675
 ** Extracted for odersky: User(795990,odersky)
 ** Extracted for derekwyatt: User(62324,derekwyatt)

代码本身应该是直截了当的。这里使用的所有概念都已在本章或第四章中探讨过,除了最后一行:

Await.ready(Future.sequence(name2User.map { _._2 }), 1 minute)

这条语句指示程序等待我们列表中的所有未来都已完成。Await.ready(..., 1 minute)将一个未来作为其第一个参数,并在该未来返回之前阻塞执行。第二个参数是对这个未来的超时。唯一的缺点是我们需要将单个未来传递给Await,而不是未来列表。我们可以使用Future.sequence将一组未来合并为一个单一的未来。这个未来将在序列中的所有未来都完成后完成。

身份验证 – 添加 HTTP 头部

到目前为止,我们一直在使用未经身份验证的 GitHub API。这限制了我们每小时只能进行六十次请求。现在我们可以在并行查询 API 的情况下,几秒钟内就能超过这个限制。

幸运的是,如果您在查询 API 时进行身份验证,GitHub 会慷慨得多。限制增加到每小时 5,000 次请求。您必须有一个 GitHub 用户账户才能进行身份验证,所以如果您需要,请现在就创建一个账户。创建账户后,导航到github.com/settings/tokens并点击生成新令牌按钮。接受默认设置,并在屏幕上输入令牌描述,应该会出现一个长的十六进制数字。现在先复制令牌。

HTTP – 快速概述

在使用我们新生成的令牌之前,让我们花几分钟时间回顾一下 HTTP 是如何工作的。

HTTP 是在不同计算机之间传输信息的协议。这是我们本章一直在使用的协议,尽管 Scala 在Source.fromURL调用中隐藏了这些细节。它也是您在将网络浏览器指向网站时使用的协议,例如。

在 HTTP 中,一台计算机通常会向远程服务器发送一个请求,服务器会返回一个响应。请求包含一个动词,它定义了请求的类型,以及一个标识资源的 URL。例如,当我们在我们浏览器中键入api.github.com/users/pbugnion时,这被转换为一个针对users/pbugnion资源的 GET(动词)请求。我们迄今为止所做的一切调用都是 GET 请求。您可能会使用不同类型的请求,例如 POST 请求,来修改(而不仅仅是查看)GitHub 上的某些内容。

除了动词和资源之外,HTTP 请求还有两个其他部分:

  • 头部包含了关于请求的元数据,例如响应的预期格式和字符集或身份验证凭据。头部只是一个键值对的列表。我们将使用Authorization头部将我们刚刚生成的 OAuth 令牌传递给 API。这篇维基百科文章列出了常用的头部字段:en.wikipedia.org/wiki/List_of_HTTP_header_fields

  • 请求体在 GET 请求中不被使用,但在修改它们查询的资源时变得重要。例如,如果我想通过编程方式在 GitHub 上创建一个新的仓库,我会向/pbugnion/repos发送 POST 请求。POST 体将是一个描述新仓库的 JSON 对象。在本章中,我们不会使用请求体。

在 Scala 中向 HTTP 请求添加头信息

我们将在 HTTP 请求中传递 OAuth 令牌作为头信息。不幸的是,Source.fromURL方法在创建 GET 请求时添加头信息并不特别适合。我们将改用库,scalaj-http

让我们在build.sbt的依赖项中添加scalaj-http

libraryDependencies += "org.scalaj" %% "scalaj-http" % "1.1.6"

我们现在可以导入scalaj-http

scala> import scalaj.http._
import scalaj.http._

我们首先创建一个HttpRequest对象:

scala> val request = Http("https://api.github.com/users/pbugnion")
request:scalaj.http.HttpRequest = HttpRequest(api.github.com/users/pbugnion,GET,...

我们现在可以向请求添加授权头(在此处添加你自己的令牌字符串):

scala> val authorizedRequest = request.header("Authorization", "token e836389ce ...")
authorizedRequest:scalaj.http.HttpRequest = HttpRequest(api.github.com/users/pbugnion,GET,...

提示

.header方法返回一个新的HttpRequest实例。它不会就地修改请求。因此,仅仅调用request.header(...)实际上并没有将头信息添加到请求本身,这可能会引起混淆。

让我们发起请求。我们通过请求的asString方法来完成,该方法查询 API,获取响应,并将其解析为 Scala String

scala> val response = authorizedRequest.asString
response:scalaj.http.HttpResponse[String] = HttpResponse({"login":"pbugnion",...

响应由三个部分组成:

  • 状态码,对于成功的请求应该是200

    scala> response.code 
    Int = 200
    
    
  • 响应体,这是我们感兴趣的部分:

    scala> response.body 
    String = {"login":"pbugnion","id":1392879,...
    
    
  • 响应头(关于响应的元数据):

    scala> response.headers 
    Map[String,String] = Map(Access-Control-Allow-Credentials -> true, ...
    
    

要验证授权是否成功,查询X-RateLimit-Limit头:

scala> response.headers("X-RateLimit-Limit")
String = 5000

这个值是你可以从单个 IP 地址向 GitHub API 发起的最大每小时请求数量。

现在我们已经对如何向 GET 请求添加认证有了些了解,让我们修改我们的用户获取脚本,使用 OAuth 令牌进行认证。我们首先需要导入scalaj-http

import scalaj.http._

将令牌的值注入到代码中可能有些棘手。你可能想将其硬编码,但这会阻止你共享代码。更好的解决方案是使用环境变量。环境变量是在你的终端会话中存在的一组变量,该会话中的所有进程都可以访问这些变量。要获取当前环境变量的列表,请在 Linux 或 Mac OS 上输入以下内容:

$ env
HOME=/Users/pascal
SHELL=/bin/zsh
...

在 Windows 上,等效的命令是SET。让我们将 GitHub 令牌添加到环境变量中。在 Mac OS 或 Linux 上使用以下命令:

$ export GHTOKEN="e83638..." # enter your token here

在 Windows 上,使用以下命令:

$ SET GHTOKEN="e83638..."

如果你打算在许多项目中重用这个环境变量,那么在每次会话中输入export GHTOKEN=...会很快变得令人厌烦。一个更持久的解决方案是将export GHTOKEN="e83638…"添加到你的 shell 配置文件中(如果你使用 Bash,则是你的.bashrc文件)。只要你的.bashrc只能被用户读取,这将是安全的。任何新的 shell 会话都将能够访问GHTOKEN环境变量。

我们可以使用sys.env从 Scala 程序中访问环境变量,它返回一个包含变量的Map[String, String]。让我们在我们的类中添加一个lazy val token,包含token值:

lazy val token:Option[String] = sys.env.get("GHTOKEN") orElse {
  println("No token found: continuing without authentication")
  None
}

现在我们有了令牌,唯一需要更改代码的部分,以添加身份验证,就是fetchUserFromUrl方法:

def fetchUserFromUrl(url:String):Future[User] = {
  val baseRequest = Http(url)
  val request = token match {
    case Some(t) => baseRequest.header(
      "Authorization", s"token $t")
    case None => baseRequest
  }
  val response = Future { 
    request.asString.body 
  }
  val parsedResponse = response.map { r => parse(r) }
  parsedResponse.map(extractUser)
}

此外,为了获得更清晰的错误信息,我们可以检查响应的状态码是否为 200。由于这很简单,所以留作练习。

概述

在本章中,你学习了如何查询 GitHub API,将响应转换为 Scala 对象。当然,仅仅将结果打印到屏幕上并不十分有趣。在下一章中,我们将探讨数据摄取过程的下一步:将数据存储在数据库中。我们将查询 GitHub API 并将结果存储在 MongoDB 数据库中。

在第十三章,使用 Play 构建 Web API中,我们将探讨构建我们自己的简单 Web API。

参考文献

GitHub API,凭借其详尽的文档,是探索如何构建丰富 API 的好地方。它有一个入门部分值得一读:

developer.github.com/guides/getting-started/

当然,这不仅仅针对 Scala:它使用 cURL 查询 API。

请阅读json4s的文档(json4s.org)和源代码(github.com/json4s/json4s)以获取完整的参考。这个包的许多部分我们还没有探索,特别是如何从 Scala 构建 JSON。

第八章. Scala 和 MongoDB

在第五章,通过 JDBC 的 Scala 和 SQL和第六章,Slick – SQL 的功能接口中,你学习了如何在 SQL 数据库中插入、转换和读取数据。这些数据库在数据科学中仍然(并且可能仍然)非常受欢迎,但 NoSQL 数据库正在成为强劲的竞争者。

数据存储的需求正在迅速增长。公司正在生产和存储更多的数据点,希望获得更好的商业智能。他们也在组建越来越大的数据科学家团队,所有这些人都需要访问数据存储。随着数据负载的增加,保持恒定的访问时间需要利用并行架构:我们需要将数据库分布在几台计算机上,这样当服务器负载增加时,我们只需添加更多机器来提高吞吐量。

在 MySQL 数据库中,数据自然地分布在不同的表中。复杂的查询需要跨多个表进行连接。这使得在多台计算机上分区数据库变得困难。NoSQL 数据库的出现填补了这一空白。

在本章中,你将学习如何与 MongoDB 交互,这是一个提供高性能且易于分布的开源数据库。MongoDB 是更受欢迎的 NoSQL 数据库之一,拥有强大的社区。它提供了速度和灵活性的合理平衡,使其成为存储具有不确定查询要求的大型数据集(如数据科学中可能发生的情况)的 SQL 的自然替代品。本章中的许多概念和食谱也适用于其他 NoSQL 数据库。

MongoDB

MongoDB 是一个面向文档的数据库。它包含文档集合。每个文档都是一个类似 JSON 的对象:

{
    _id: ObjectId("558e846730044ede70743be9"),
    name: "Gandalf",
    age: 2000,
    pseudonyms: [ "Mithrandir", "Olorin", "Greyhame" ],
    possessions: [ 
        { name: "Glamdring", type: "sword" }, 
        { name: "Narya", type: "ring" }
    ]
}

正如 JSON 一样,文档是一组键值对,其中值可以是字符串、数字、布尔值、日期、数组或子文档。文档被分组在集合中,集合被分组在数据库中。

你可能会想,这与 SQL 并没有太大的不同:一个文档类似于一行,一个集合对应一个表。但存在两个重要的区别:

  • 文档中的值可以是简单值、数组、子文档或子文档数组。这使得我们可以在单个集合中编码一对一和多对多关系。例如,考虑巫师集合。在 SQL 中,如果我们想为每个巫师存储化名,我们必须使用一个单独的wizard2pseudonym表,并为每个巫师-化名对创建一行。在 MongoDB 中,我们只需使用一个数组。在实践中,这意味着我们可以通常用一个文档来表示一个实体(例如客户、交易或巫师)。在 SQL 中,我们通常需要跨多个表进行连接,以检索特定实体的所有信息。

  • MongoDB 是无模式的。集合中的文档可以具有不同的字段集,不同文档中同一字段的类型也可以不同。在实践中,MongoDB 集合有一个松散的模式,由客户端或约定强制执行:大多数文档将具有相同字段的子集,字段通常将包含相同的数据类型。具有灵活的模式使得调整数据结构变得容易,因为不需要耗时的ALTER TABLE语句。缺点是,在数据库端没有简单的方法来强制执行我们的灵活模式。

注意到_id字段:这是一个唯一键。如果我们插入一个没有_id字段的文档,MongoDB 将自动生成一个。

本章提供了从 Scala 与 MongoDB 数据库交互的配方,包括维护类型安全和最佳实践。我们不会涵盖高级 MongoDB 功能(如聚合或数据库的分布式)。我们假设您已在您的计算机上安装了 MongoDB(docs.mongodb.org/manual/installation/)。对 MongoDB 有非常基本的了解也会有所帮助(我们将在本章末尾讨论一些参考资料,但任何在线可用的基本教程都足以满足本章的需求)。

使用 Casbah 连接到 MongoDB

Scala 的官方 MongoDB 驱动程序称为Casbah。Casbah 不是完整的驱动程序,而是包装了 Java Mongo 驱动程序,提供了一个更函数式的接口。还有其他 Scala 的 MongoDB 驱动程序,我们将在本章末尾简要讨论。现在,我们将坚持使用 Casbah。

让我们从向我们的build.sbt文件添加 Casbah 开始:

scalaVersion := "2.11.7"

libraryDependencies += "org.mongodb" %% "casbah" % "3.0.0"

Casbah 还期望slf4j绑定(一个 Scala 日志框架)可用,因此让我们也添加slf4j-nop

libraryDependencies += "org.slf4j" % "slf4j-nop" % "1.7.12"

我们现在可以启动 SBT 控制台并在 Scala shell 中导入 Casbah:

$ sbt console
scala> import com.mongodb.casbah.Imports._
import com.mongodb.casbah.Imports._

scala> val client = MongoClient()
client: com.mongodb.casbah.MongoClient = com.mongodb.casbah.MongoClient@4ac17318

这将连接到默认主机(localhost)和默认端口(27017)上的 MongoDB 服务器。要连接到不同的服务器,将主机和端口作为参数传递给MongoClient

scala> val client = MongoClient("192.168.1.1", 27017)
client: com.mongodb.casbah.MongoClient = com.mongodb.casbah.MongoClient@584c6b02

注意,创建客户端是一个延迟操作:它不会在需要之前尝试连接到服务器。这意味着如果您输入了错误的 URL 或密码,您直到尝试访问服务器上的文档时才会知道。

一旦我们与服务器建立了连接,访问数据库就像使用客户端的apply方法一样简单。例如,要访问github数据库:

scala> val db = client("github")
db: com.mongodb.casbah.MongoDB = DB{name='github'}

然后,我们可以访问"users"集合:

scala> val coll = db("users")
coll: com.mongodb.casbah.MongoCollection = users

使用身份验证连接

MongoDB 支持多种不同的身份验证机制。在本节中,我们假设您的服务器正在使用SCRAM-SHA-1机制,但您应该会发现将代码适应不同类型的身份验证很简单。

最简单的身份验证方式是在连接时通过 URI 传递usernamepassword

scala> val username = "USER"
username: String = USER

scala> val password = "PASSWORD"
password: String = PASSWORD

scala> val uri = MongoClientURI(
 s"mongodb://$username:$password@localhost/?authMechanism=SCRAM-SHA-1"
)
uri: MongoClientURI = mongodb://USER:PASSWORD@localhost/?authMechanism=SCRAM-SHA-1

scala> val mongoClient = MongoClient(uri)
client: com.mongodb.casbah.MongoClient = com.mongodb.casbah.MongoClient@4ac17318

通常,您不希望在代码中以明文形式放置密码。您可以在命令行中提示输入密码或通过环境变量传递,就像我们在第七章中处理 GitHub OAuth 令牌那样。以下代码片段演示了如何通过环境传递凭据:

// Credentials.scala

import com.mongodb.casbah.Imports._

object Credentials extends App {

  val username = sys.env.getOrElse("MONGOUSER",
    throw new IllegalStateException(
      "Need a MONGOUSER variable in the environment")
  )
  val password = sys.env.getOrElse("MONGOPASSWORD",
    throw new IllegalStateException(
      "Need a MONGOPASSWORD variable in the environment")
  )

  val host = "127.0.0.1"
  val port = 27017

  val uri = s"mongodb://$username:$password@$host:$port/?authMechanism=SCRAM-SHA-1"

  val client = MongoClient(MongoClientURI(uri))
}

您可以通过以下方式在 SBT 中运行它:

$ MONGOUSER="pascal" MONGOPASSWORD="scalarulez" sbt
> runMain Credentials

插入文档

让我们在新创建的数据库中插入一些文档。我们希望存储有关 GitHub 用户的信息,使用以下文档结构:

{
    id: <mongodb object id>,
    login: "pbugnion",
    github_id: 1392879,
    repos: [ 
        {
            name: "scikit-monaco",
            id: 14821551,
            language: "Python"
        },
        {
            name: "contactpp",
            id: 20448325,
            language: "Python"
        }
    ]
}

Casbah 提供了一个DBObject类来表示 Scala 中的 MongoDB 文档(和子文档)。让我们首先为每个存储库子文档创建一个DBObject实例:

scala> val repo1 = DBObject("name" -> "scikit-monaco", "id" -> 14821551, "language" -> "Python")
repo1: DBObject = { "name" : "scikit-monaco" , "id" : 14821551, "language" : "Python"}

如您所见,DBObject 只是一个键值对列表,其中键是字符串。值具有编译时类型 AnyRef,但如果您尝试添加无法序列化的值,Casbah 将在运行时失败。

我们还可以直接从键值对列表创建 DBObject 实例。这在将 Scala 映射转换为 DBObject 时特别有用:

scala> val fields:Map[String, Any] = Map(
 "name" -> "contactpp",
 "id" -> 20448325,
 "language" -> "Python"
)
Map[String, Any] = Map(name -> contactpp, id -> 20448325, language -> Python)

scala> val repo2 = DBObject(fields.toList)
repo2: dDBObject = { "name" : "contactpp" , "id" : 20448325, "language" : "Python"}

DBObject 类提供了与映射相同的大多数方法。例如,我们可以访问单个字段:

scala> repo1("name")
AnyRef = scikit-monaco

我们可以通过向现有对象添加字段来构造一个新对象:

scala> repo1 + ("fork" -> true)
mutable.Map[String,Any] = { "name" : "scikit-monaco" , "id" : 14821551, "language" : "python", "fork" : true}

注意返回类型:mutable.Map[String,Any]。Casbah 不是直接实现如 + 之类的方法,而是通过提供到和从 mutable.Map 的隐式转换将它们添加到 DBObject 中。

新的 DBObject 实例也可以通过连接两个现有实例来创建:

scala> repo1 ++ DBObject(
 "locs" -> 6342, 
 "description" -> "Python library for Monte Carlo integration"
)
DBObject = { "name" : "scikit-monaco" , "id" : 14821551, "language" : "Python", "locs" : 6342 , "description" : "Python library for Monte Carlo integration"}

DBObject 实例可以使用 += 操作符插入到集合中。让我们将我们的第一个文档插入到 user 集合中:

scala> val userDocument = DBObject(
 "login" -> "pbugnion", 
 "github_id" -> 1392879, 
 "repos" -> List(repo1, repo2)
)
userDocument: DBObject = { "login" : "pbugnion" , ... }

scala> val coll = MongoClient()("github")("users")
coll: com.mongodb.casbah.MongoCollection = users

scala> coll += userDocument
com.mongodb.casbah.TypeImports.WriteResult = WriteResult{, n=0, updateOfExisting=false, upsertedId=null}

包含单个文档的数据库有点无聊,所以让我们添加一些直接从 GitHub API 查询的更多文档。您在上一章中学习了如何查询 GitHub API,所以我们不会在这里详细说明如何进行此操作。

在本章的代码示例中,我们提供了一个名为 GitHubUserIterator 的类,该类查询 GitHub API(特别是 /users 端点)以获取用户文档,将它们转换为案例类,并将它们作为迭代器提供。您可以在本章的代码示例(可在 GitHub 上找到 github.com/pbugnion/s4ds/tree/master/chap08)中的 GitHubUserIterator.scala 文件中找到该类。访问该类最简单的方法是在本章代码示例的目录中打开一个 SBT 控制台。API 随后按登录 ID 递增的顺序获取用户:

scala> val it = new GitHubUserIterator
it: GitHubUserIterator = non-empty iterator

scala> it.next // Fetch the first user
User = User(mojombo,1,List(Repo(...

GitHubUserIterator 返回 User 案例类的实例,该类定义如下:

// User.scala
case class User(login:String, id:Long, repos:List[Repo])

// Repo.scala
case class Repo(name:String, id:Long, language:String)

让我们编写一个简短的程序来获取 500 个用户并将它们插入到 MongoDB 数据库中。我们需要通过 GitHub API 进行身份验证来检索这些用户。GitHubUserIterator 构造函数接受 GitHub OAuth 令牌作为可选参数。我们将通过环境注入令牌,就像我们在上一章中所做的那样。

在分解代码之前,我们首先给出整个代码列表——如果您正在手动输入,您需要将 GitHubUserIterator.scala 从本章的代码示例复制到您运行此操作的目录中,以便访问 GitHubUserIterator 类。该类依赖于 scalaj-httpjson4s,因此您可以选择复制代码示例中的 build.sbt 文件,或者在您的 build.sbt 文件中指定这些包作为依赖项。

// InsertUsers.scala

import com.mongodb.casbah.Imports._

object InsertUsers {

  /** Function for reading GitHub token from environment. */
  lazy val token:Option[String] = sys.env.get("GHTOKEN") orElse {
    println("No token found: continuing without authentication")
    None
  }

  /** Transform a Repo instance to a DBObject */
  def repoToDBObject(repo:Repo):DBObject = DBObject(
    "github_id" -> repo.id,
    "name" -> repo.name,
    "language" -> repo.language
  )

  /** Transform a User instance to a DBObject */
  def userToDBObject(user:User):DBObject = DBObject(
    "github_id" -> user.id,
    "login" -> user.login,
    "repos" -> user.repos.map(repoToDBObject)
  )

  /** Insert a list of users into a collection. */
  def insertUsers(coll:MongoCollection)(users:Iterable[User]) {
    users.foreach { user => coll += userToDBObject(user) }
  }

  /**  Fetch users from GitHub and passes them to `inserter` */
  def ingestUsers(nusers:Int)(inserter:Iterable[User] => Unit) {
    val it = new GitHubUserIterator(token)
    val users = it.take(nusers).toList
    inserter(users)
  }

  def main(args:Array[String]) {
    val coll = MongoClient()("github")("users")
    val nusers = 500
    coll.dropCollection()
    val inserter = insertUsers(coll)_
    ingestUsers(inserter)(nusers)
  }

}

在深入了解程序的工作原理之前,让我们通过 SBT 运行它。您将想要使用身份验证查询 API 以避免达到速率限制。回想一下,我们需要设置 GHTOKEN 环境变量:

$ GHTOKEN="e83638..." sbt
$ runMain InsertUsers

程序运行大约需要五分钟(取决于您的互联网连接)。为了验证程序是否工作,我们可以查询 github 数据库中 users 集合中的文档数量:

$ mongo github --quiet --eval "db.users.count()"
500

让我们分解一下代码。我们首先加载 OAuth 令牌以验证 GitHub API。令牌存储为环境变量 GHTOKENtoken 变量是一个 lazy val,因此令牌只在形成对 API 的第一个请求时加载。我们已经在 第七章,Web APIs 中使用了这种模式。

然后我们定义两个方法,将领域模型中的类转换成 DBObject 实例:

def repoToDBObject(repo:Repo):DBObject = ...
def userToDBObject(user:User):DBObject = ...

带着这两个方法,我们可以轻松地将用户添加到我们的 MongoDB 集合中:

def insertUsers(coll:MongoCollection)(users:Iterable[User]) {
  users.foreach { user => coll += userToDBObject(user) }
}

我们使用柯里化来拆分 insertUsers 的参数。这使得我们可以将 insertUsers 作为函数工厂使用:

val inserter = insertUsers(coll)_

这创建了一个新的方法 inserter,其签名是 Iterable[User] => Unit,用于将用户插入到 coll 中。为了了解这如何有用,让我们编写一个函数来包装整个数据导入过程。这是这个函数的第一个尝试可能看起来像这样:

def ingestUsers(nusers:Int)(inserter:Iterable[User] => Unit) {
  val it = new GitHubUserIterator(token)
  val users = it.take(nusers).toList
  inserter(users)
}

注意 ingestUsers 方法如何将其第二个参数作为一个指定用户列表如何插入到数据库中的方法。这个函数封装了插入到 MongoDB 集合的整个特定代码。如果我们决定,在未来的某个时间点,我们讨厌 MongoDB 并且必须将文档插入到 SQL 数据库或写入平面文件,我们只需要将不同的 inserter 函数传递给 ingestUsers。其余的代码保持不变。这展示了使用高阶函数带来的更高灵活性:我们可以轻松构建一个框架,并让客户端代码插入它需要的组件。

如前所述定义的 ingestUsers 方法有一个问题:如果 nusers 值很大,它将在构建整个用户列表时消耗大量内存。更好的解决方案是将它分解成批次:我们从 API 获取一批用户,将它们插入到数据库中,然后继续处理下一批。这样我们可以通过改变批次大小来控制内存使用。它也更加容错:如果程序崩溃,我们只需从最后一个成功插入的批次重新启动。

.grouped 方法,适用于所有可迭代对象,用于批量处理。它返回一个遍历原始可迭代对象片段的迭代器:

scala> val it = (0 to 10)
it: Range.Inclusive = Range(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

scala> it.grouped(3).foreach { println } // In batches of 3
Vector(0, 1, 2)
Vector(3, 4, 5)
Vector(6, 7, 8)
Vector(9, 10)

让我们重写我们的 ingestUsers 方法以使用批次。我们还会在每个批次后添加一个进度报告,以便给用户一些反馈:

/**  Fetch users from GitHub and pass them to `inserter` */
def ingestUsers(nusers:Int)(inserter:Iterable[User] => Unit) {
  val batchSize = 100
  val it = new GitHubUserIterator(token)
  print("Inserted #users: ")
  it.take(nusers).grouped(batchSize).zipWithIndex.foreach {
    case (users, batchNumber) =>
      print(s"${batchNumber*batchSize} ")
      inserter(users)
  }
  println()
}

让我们更仔细地看看高亮行。我们从用户迭代器 it 开始。然后我们取前 nusers 个用户。这返回一个 Iterator[User],它不会在 GitHub 数据库中的每个用户上愉快地运行,而是在 nusers 后终止。然后我们将这个迭代器分组为 100 个用户的批次。.grouped 方法返回 Iterator[Iterator[User]]。然后我们将每个批次与其索引进行连接,这样我们就可以知道我们目前正在处理哪个批次(我们在 print 语句中使用这个)。.zipWithIndex 方法返回 Iterator[(Iterator[User], Int)]。我们在循环中使用一个 case 语句来解包这个元组,将 users 绑定到 Iterator[User],将 batchNumber 绑定到索引。让我们通过 SBT 运行这个例子:

$ GHTOKEN="2502761..." sbt 
> runMain InsertUsers
[info] Running InsertUsers
Inserted #users: 0 100 200 300 400
[success] Total time: 215 s, completed 01-Nov-2015 18:44:30

从数据库中提取对象

现在我们有一个包含一些用户的数据库。让我们从 REPL 中查询这个数据库:

scala> import com.mongodb.casbah.Imports._
import com.mongodb.casbah.Imports._

scala> val collection = MongoClient()("github")("users")
MongoCollection = users

scala> val maybeUser = collection.findOne
Option[collection.T] = Some({ "_id" : { "$oid" : "562e922546f953739c43df02"} , "github_id" : 1 , "login" : "mojombo" , "repos" : ...

findOne 方法返回一个包含在选项中的单个 DBObject 对象,除非集合为空,在这种情况下它返回 None。因此,我们必须使用 get 方法来提取对象:

scala> val user = maybeUser.get
collection.T = { "_id" : { "$oid" : "562e922546f953739c43df02"} , "github_id" : 1 , "login" : "mojombo" , "repos" : ...

如您在本章早期所学的,DBObject 是一个类似于映射的对象,其键的类型为 String,值类型为 AnyRef

scala> user("login")
AnyRef = mojombo

通常,我们在从数据库导入对象时,希望尽可能早地恢复编译时类型信息:我们不希望在可以更具体的情况下传递 AnyRef。我们可以使用 getAs 方法提取字段并将其转换为特定类型:

scala> user.getAsString
Option[String] = Some(mojombo)

如果文档中缺少字段或值无法转换,getAs 将返回 None

scala> user.getAsInt
Option[Int] = None

聪明的读者可能会注意到,getAs[T] 提供的接口与我们定义在 第五章 中的 JDBC 结果集上的 read[T] 方法类似,通过 JDBC 的 Scala 和 SQL

如果 getAs 失败(例如,因为字段缺失),我们可以使用 orElse 部分函数来恢复:

scala> val loginName = user.getAsString orElse { 
 println("No login field found. Falling back to 'name'")
 user.getAsString
}
loginName: Option[String] = Some(mojombo)

getAsOrElse 方法允许我们在转换失败时替换默认值:

scala> user.getAsOrElseInt
Int = 1392879

注意,我们还可以使用 getAsOrElse 抛出异常:

scala> user.getAsOrElseString
)
java.lang.IllegalArgumentException: Missing value for name
...

文档中嵌入的数组可以转换为 List[T] 对象,其中 T 是数组中元素的类型:

scala> user.getAsOrElse[List[DBObject]]("repos",
 List.empty[DBObject])
List[DBObject] = List({ "github_id" : 26899533 , "name" : "30daysoflaptops.github.io" ...

一次检索一个文档并不很有用。要检索集合中的所有文档,请使用 .find 方法:

scala> val userIterator = collection.find()
userIterator: collection.CursorType = non-empty iterator

这将返回一个 DBObject 迭代器。要实际从数据库中获取文档,您需要通过将其转换为集合来具体化迭代器,例如使用 .toList

scala> val userList = userIterator.toList
List[DBObject] = List({ "_id" : { "$oid": ...

让我们把所有这些放在一起。我们将编写一个玩具程序,该程序打印我们集合中每个用户的平均存储库数量。代码通过获取集合中的每个文档,从每个文档中提取存储库数量,然后对这些数量进行平均来实现:

// RepoNumber.scala

import com.mongodb.casbah.Imports._

object RepoNumber {

  /** Extract the number of repos from a DBObject
    * representing a user.
    */   
  def extractNumber(obj:DBObject):Option[Int] = {
    val repos = obj.getAs[List[DBObject]]("repos") orElse {
      println("Could not find or parse 'repos' field")
      None
    }
    repos.map { _.size }
  }

  val collection = MongoClient()("github")("users")

  def main(args:Array[String]) {    
    val userIterator = collection.find()

    // Convert from documents to Option[Int]
    val repoNumbers = userIterator.map { extractNumber }

    // Convert from Option[Int] to Int
    val wellFormattedNumbers = repoNumbers.collect { 
      case Some(v) => v 
    }.toList

    // Calculate summary statistics
    val sum = wellFormattedNumbers.reduce { _ + _ }
    val count = wellFormattedNumbers.size

    if (count == 0) {
      println("No repos found")
    }
    else {
      val mean = sum.toDouble / count.toDouble
      println(s"Total number of users with repos: $count")
      println(s"Total number of repos: $sum")
      println(s"Mean number of repos: $mean")
    }
  }
}

让我们通过 SBT 运行这个例子:

> runMain RepoNumber
Total number of users with repos: 500
Total number of repos: 9649
Mean number of repos: 19.298

代码从 extractNumber 函数开始,该函数从每个 DBObject 中提取存储库数量。如果文档不包含 repos 字段,则返回值是 None

代码的主体部分首先创建一个遍历集合中 DBObject 的迭代器。然后,这个迭代器通过 extractNumber 函数进行映射,将其转换为 Option[Int] 的迭代器。然后我们对这个迭代器运行 .collect,收集所有不是 None 的值,在这个过程中将 Option[Int] 转换为 Int。然后我们才使用 .toList 将迭代器实体化为列表。得到的列表 wellFormattedNumbers 具有类型 List[Int]。然后我们只取这个列表的平均值并将其打印到屏幕上。

注意,除了 extractNumber 函数外,这个程序没有处理任何与 Casbah 特定的类型相关的事务:.find() 返回的迭代器只是一个 Scala 迭代器。这使得 Casbah 的使用变得简单:你需要熟悉的唯一数据类型是 DBObject(与 JDBC 的 ResultSet 进行比较,我们必须显式地将其包装在流中,例如)。

复杂查询

我们现在知道如何将 DBObject 实例转换为自定义 Scala 类。在本节中,你将学习如何构建只返回集合中部分文档的查询。

在上一节中,你学习了如何如下检索集合中的所有文档:

scala> val objs = collection.find().toList
List[DBObject] = List({ "_id" : { "$oid" : "56365cec46f9534fae8ffd7f"} ,...

collection.find() 方法返回一个遍历集合中所有文档的迭代器。通过在这个迭代器上调用 .toList,我们将其实体化为列表。

我们可以通过传递一个查询文档到 .find 方法来自定义返回哪些文档。例如,我们可以检索特定登录名的文档:

scala> val query = DBObject("login" -> "mojombo")
query: DBObject = { "login" : "mojombo"}

scala> val objs = collection.find(query).toList
List[DBObject] = List({ "_id" : { "$oid" : "562e922546f953739c43df02"} , "login" : "mojombo",...

MongoDB 查询以 DBObject 实例的形式表达。DBObject 中的键对应于集合文档中的字段,而值是控制该字段允许值的表达式。因此,DBObject("login" -> "mojombo") 将选择所有 login 字段为 mojombo 的文档。使用 DBObject 实例表示查询可能看起来有些晦涩,但如果你阅读 MongoDB 文档(docs.mongodb.org/manual/core/crud-introduction/),你会很快明白:查询在 MongoDB 中本身就是 JSON 对象。因此,Casbah 中的查询表示为 DBObject 与其他 MongoDB 客户端实现保持一致。它还允许熟悉 MongoDB 的人迅速开始编写 Casbah 查询。

MongoDB 支持更复杂的查询。例如,要查询 "github_id"2030 之间的所有人,我们可以编写以下查询:

scala> val query = DBObject("github_id" -> 
 DBObject("$gte" -> 20, "$lt" -> 30))
query: DBObject = { "github_id" : { "$gte" : 20 , "$lt" : 30}}

scala> collection.find(query).toList
List[com.mongodb.casbah.Imports.DBObject] = List({ "_id" : { "$oid" : "562e922546f953739c43df0f"} , "github_id" : 23 , "login" : "takeo" , ...

我们使用DBObject("$gte" -> 20, "$lt" -> 30)限制了github_id可以取的值的范围。"$gte"字符串表示github_id必须大于或等于20。同样,"$lt"表示小于操作符。要获取查询时可以使用的所有操作符的完整列表,请查阅 MongoDB 参考文档(docs.mongodb.org/manual/reference/operator/query/)。

到目前为止,我们只看了顶级字段的查询。Casbah 还允许我们使用点符号查询子文档和数组中的字段。在数组值的上下文中,这将返回数组中至少有一个值与查询匹配的所有文档。例如,要检索所有在 Scala 中拥有主要语言为 Scala 的仓库的用户:

scala> val query = DBObject("repos.language" -> "Scala")
query: DBObject = { "repos.language" : "Scala"}

scala> collection.find(query).toList
List[DBObject] = List({ "_id" : { "$oid" : "5635da4446f953234ca634df"}, "login" : "kevinclark"...

Casbah 查询 DSL

使用DBObject实例来表示查询可能非常冗长且难以阅读。Casbah 提供了一个 DSL 来更简洁地表示查询。例如,要获取所有github_id字段在2030之间的文档,我们会编写以下代码:

scala> collection.find("github_id" $gte 20 $lt 30).toList
List[com.mongodb.casbah.Imports.DBObject] = List({ "_id" : { "$oid" : "562e922546f953739c43df0f"} , "github_id" : 23 , "login" : "takeo" , "repos" : ...

DSL 提供的运算符将自动构造DBObject实例。尽可能多地使用 DSL 运算符通常会导致代码更易读、更易于维护。

进入查询 DSL 的详细内容超出了本章的范围。您会发现使用它相当简单。要获取 DSL 支持的运算符的完整列表,请参阅 Casbah 文档mongodb.github.io/casbah/3.0/reference/query_dsl/。我们在此总结了最重要的运算符:

运算符 描述
"login" $eq "mojombo" 这将选择login字段正好是mojombo的文档
"login" $ne "mojombo" 这将选择login字段不是mojombo的文档
"github_id" $gt 1 $lt 20 这将选择github_id大于1且小于20的文档
"github_id" $gte 1 $lte 20 这将选择github_id大于或等于1且小于或等于20的文档
"login" $in ("mojombo", "defunkt") login字段是mojombodefunkt
"login" $nin ("mojombo", "defunkt") login字段不是mojombodefunkt
"login" $regex "^moj.*" login字段匹配特定的正则表达式
"login" $exists true login字段存在
$or("login" $eq "mojombo", "github_id" $gte 22) login字段是mojombogithub_id字段大于或等于22
$and("login" $eq "mojombo", "github_id" $gte 22) login字段是mojombogithub_id字段大于或等于22

我们还可以使用点符号来查询数组和子文档。例如,以下查询将计算所有在 Scala 中拥有仓库的用户:

scala> collection.find("repos.language" $eq "Scala").size
Int = 30

自定义类型序列化

到目前为止,我们只尝试了序列化和反序列化简单类型。如果我们想将存储在仓库数组中的语言字段解码为枚举而不是字符串呢?例如,我们可以定义以下枚举:

scala> object Language extends Enumeration {
 val Scala, Java, JavaScript = Value
}
defined object Language

Casbah 允许我们定义与特定 Scala 类型相关的自定义序列化器:我们可以通知 Casbah,每当它在 DBObject 中遇到 Language.Value 类型的实例时,该实例应通过一个自定义转换器进行转换,例如转换为字符串,然后再将其写入数据库。

要定义一个自定义序列化器,我们需要定义一个扩展 Transformer 特质的类。这个特质暴露了一个方法,transform(o:AnyRef):AnyRef。让我们定义一个 LanguageTransformer 特质,它将 Language.Value 转换为 String

scala> import org.bson.{BSON, Transformer}
import org.bson.{BSON, Transformer}

scala> trait LanguageTransformer extends Transformer {
 def transform(o:AnyRef):AnyRef = o match {
 case l:Language.Value => l.toString
 case _ => o
 }
}
defined trait LanguageTransformer

我们现在需要注册特质,以便在需要解码类型 Language.Value 的实例时使用。我们可以使用 addEncodingHook 方法来完成此操作:

scala> BSON.addEncodingHook(
 classOf[Language.Value], new LanguageTransformer {})

我们现在可以构建包含 Language 枚举值的 DBObject 实例:

scala> val repoObj = DBObject(
 "github_id" -> 1234L,
 "language" -> Language.Scala
)
repoObj: DBObject = { "github_id" : 1234 , "language" : "Scala"}

反过来呢?我们如何告诉 Casbah 将 "language" 字段读取为 Language.Value?这不可能通过自定义反序列化器实现:"Scala" 现在作为字符串存储在数据库中。因此,在反序列化时,"Scala""mojombo" 没有区别。因此,当 "Scala" 被序列化时,我们失去了类型信息。

因此,虽然自定义编码钩子在序列化时很有用,但在反序列化时则不那么有用。一种更干净、更一致的替代方案是使用 类型类 来自定义序列化和反序列化。我们已经在 第五章,通过 JDBC 的 Scala 和 SQL 的上下文中广泛介绍了如何使用这些类型类,用于将数据序列化和反序列化到 SQL。这里的程序将非常相似:

  1. 定义一个具有 read(v:Any):T 方法的 MongoReader[T] 类型类。

  2. MongoReader 伴生对象中为所有感兴趣的类型定义具体的 MongoReader 实现,例如 StringLanguage.Value

  3. 使用 pimp my library 模式,为 DBObject 增强一个 read[T:MongoReader] 方法。

例如,Language.ValueMongoReader 实现如下:

implicit object LanguageReader extends MongoReader[Language.Value] {
  def read(v:Any):Language.Value = v match {
    case s:String => Language.withName(s)
  }
}

我们可以用同样的方式使用 MongoWriter 类型类。使用类型类是自定义序列化和反序列化的惯用和可扩展的方法。

我们在本章相关的代码示例(在 typeclass 目录中)提供了一个类型类的完整示例。

除此之外

本章我们只考虑了 Casbah。然而,MongoDB 还有其他驱动程序。

ReactiveMongo 是一个专注于数据库异步读写操作的驱动程序。所有查询都返回一个未来对象,强制执行异步行为。这非常适合数据流或 Web 应用程序。

Salat位于 Casbah 之上,旨在提供对 case 类进行简单序列化和反序列化的功能。

完整的驱动程序列表可在docs.mongodb.org/ecosystem/drivers/scala/找到。

摘要

在本章中,你学习了如何与 MongoDB 数据库交互。通过将上一章学到的结构(从 Web API 中提取信息)与本章学到的结构结合起来,我们现在可以构建一个用于数据摄取的并发、响应式程序。

在下一章中,你将学习如何使用 Akka 演员构建更灵活的分布式、并发结构。

参考文献

MongoDB: The Definitive Guide》,由Kristina Chodorow所著,是 MongoDB 的良好入门指南。它完全不涉及使用 Scala 与 MongoDB 交互,但对于熟悉 MongoDB 的人来说,Casbah 足够直观。

类似地,MongoDB 文档(docs.mongodb.org/manual/)提供了对 MongoDB 的深入讨论。

Casbah 本身有很好的文档(mongodb.github.io/casbah/3.0/)。有一个入门指南,与本章有些类似,还有一个完整的参考指南,将填补本章留下的空白。

这个片段,gist.github.com/switzer/4218526,实现了类型类,用于将领域模型中的对象序列化和反序列化为DBObject。前提与本章中建议的类型类用法略有不同:我们将 Scala 类型转换为AnyRef,以便在DBObject中使用作为值。然而,这两种方法互为补充:可以想象有一组类型类将UserRepo转换为DBObject,另一组将Language.Value转换为AnyRef

第九章。使用 Akka 进行并发

本书的大部分内容都专注于利用多核和分布式架构。在第四章,并行集合和未来中,你学习了如何使用并行集合将批处理问题分布到多个线程上,以及如何使用未来执行异步计算。在第七章,Web API中,我们将这些知识应用于使用多个并发线程查询 GitHub API。

并发抽象,如未来和并行集合,通过限制你可以做的事情来简化并发编程的巨大复杂性。例如,并行集合强制你将并行化问题表述为集合上的纯函数序列。

演员提供了一种不同的并发思考方式。演员在封装 状态 方面非常出色。管理不同执行线程之间共享的状态可能是开发并发应用程序最具挑战性的部分,正如我们将在本章中发现的那样,演员使其变得可管理。

GitHub 关注者图

在前两章中,我们探讨了 GitHub API,学习了如何使用 json-4s 查询 API 并解析结果。

让我们想象一下,我们想要提取 GitHub 关注者图:我们想要一个程序,它将从特定用户开始,提取该用户关注者,然后提取他们的关注者,直到我们告诉它停止。问题是,我们事先不知道需要获取哪些 URL:当我们下载特定用户关注者的登录名时,我们需要验证我们是否已经获取了这些用户。如果没有,我们将它们添加到需要获取其关注者的用户队列中。算法爱好者可能会认出这作为 广度优先搜索

让我们概述一下我们如何以单线程方式编写它。核心组件是一组已访问用户和未来用户队列:

val seedUser = "odersky" // the origin of the network

// Users whose URLs need to be fetched 
val queue = mutable.Queue(seedUser) 

// set of users that we have already fetched 
// (to avoid re-fetching them)
val fetchedUsers = mutable.Set.empty[String] 

while (queue.nonEmpty) {
  val user = queue.dequeue
  if (!fetchedUsers(user)) {
    val followers = fetchFollowersForUser(user)
    followers foreach { follower =>           
      // add the follower to queue of people whose 
      // followers we want to find.
      queue += follower
    }
    fetchedUsers += user
  }
}

在这里,fetchFollowersForUser 方法的签名是 String => Iterable[String],它负责接受一个登录名,将其转换为 GitHub API 中的 URL,查询 API,并从响应中提取关注者列表。我们在这里不会实现它,但您可以在本书代码示例的 chap09/single_threaded 目录中找到一个完整的示例(github.com/pbugnion/s4ds)。如果您已经阅读了 第七章

假设 Chris 接到了一个订单。他会查看订单,决定是否可以自己处理,如果不能,他会将订单转发给 Mark 或 Sally。让我们假设订单要求一个小程序,所以 Bob 将订单转发给了 Sally。Sally 非常忙,正在处理一批积压的订单,因此她不能立即处理订单信息,它将在她的邮箱中短暂停留。当她最终开始处理订单时,她可能会决定将订单分成几个部分,其中一些部分她会分配给 Kevin 和 Bob。

当 Bob 和 Kevin 完成任务时,他们会向 Sally 发送消息以通知她。当订单的每个部分都得到满足时,Sally 会将这些部分汇总起来,直接向客户或向 Chris 发送结果的消息。

跟踪哪些工作必须完成以完成订单的任务落在 Sally 身上。当她收到 Bob 和 Kevin 的消息时,她必须更新她正在进行的任务列表,并检查与这个订单相关的每个任务是否完成。这种协调在传统的synchronize块中会更具有挑战性:对正在进行的任务列表和已完成任务列表的每次访问都需要同步。通过将这种逻辑嵌入只能一次处理一个消息的 Sally 中,我们可以确保不会出现竞态条件。

我们的初创公司运作良好,因为每个人只负责做一件事:Chris 要么委托给 Mark 或 Sally,Sally 将订单拆分成几个部分并分配给 Bob 和 Kevin,而 Bob 和 Kevin 完成每个部分。你可能会想,“等等,所有的逻辑都嵌入在 Bob 和 Kevin 中,他们是底层的员工,做所有实际的工作”。与员工不同,演员成本低,所以如果嵌入在演员中的逻辑变得过于复杂,很容易引入额外的委托层,直到任务足够简单。

我们初创公司的员工拒绝进行多任务处理。当他们得到一份工作时,他们会完全处理它,然后转到下一个任务。这意味着他们不会因为多任务处理的复杂性而变得混乱。通过一次处理一个消息,演员大大减少了引入并发错误(如竞态条件)的范围。

更重要的是,通过提供一个程序员可以直观理解的抽象——即人类工作者——Akka 使得关于并发的推理变得更加容易。

使用 Akka 的 Hello World

让我们安装 Akka。我们将它添加到我们的 build.sbt 文件中:

scalaVersion := "2.11.7"

libraryDependencies += "com.typesafe.akka" %% "akka-actor" % "2.4.0"

我们现在可以按照以下方式导入 Akka:

import akka.actor._

对于我们第一次进入演员的世界,我们将构建一个接收并回显每个接收到的消息的演员。本节中的代码示例位于本书提供的示例代码目录 chap09/hello_akka 中(github.com/pbugnion/s4ds):

// EchoActor.scala
import akka.actor._

class EchoActor extends Actor with ActorLogging {
  def receive = {
    case msg:String => 
      Thread.sleep(500)
      log.info(s"Received '$msg'") 
  }
}

让我们分析这个例子,从构造函数开始。我们的演员类必须扩展 Actor。我们还添加了 ActorLogging,这是一个实用特性,它添加了 log 属性。

Echo 演员公开一个单一的方法,receive。这是演员与外部世界通信的唯一方式。为了有用,所有演员都必须公开一个 receive 方法。receive 方法是一个部分函数,通常使用多个 case 语句实现。当演员开始处理消息时,它将匹配每个 case 语句,直到找到匹配的一个。然后执行相应的代码块。

我们的 echo 演员接受一种类型的消息,一个普通的字符串。当这个消息被处理时,演员会等待半秒钟,然后将消息回显到日志文件中。

让我们实例化几个 Echo 演员,并发送它们消息:

// HelloAkka.scala

import akka.actor._

object HelloAkka extends App {

  // We need an actor system before we can 
  // instantiate actors
  val system = ActorSystem("HelloActors")

  // instantiate our two actors
  val echo1 = system.actorOf(Props[EchoActor], name="echo1")
  val echo2 = system.actorOf(Props[EchoActor], name="echo2")

  // Send them messages. We do this using the "!" operator
  echo1 ! "hello echo1"
  echo2 ! "hello echo2"
  echo1 ! "bye bye"

  // Give the actors time to process their messages, 
  // then shut the system down to terminate the program
  Thread.sleep(500)
  system.shutdown
}

运行此命令会得到以下输出:

[INFO] [07/19/2015 17:15:23.954] [HelloActor-akka.actor.default-dispatcher-2] [akka://HelloActor/user/echo1] Received 'hello echo1'
[INFO] [07/19/2015 17:15:23.954] [HelloActor-akka.actor.default-dispatcher-3] [akka://HelloActor/user/echo2] Received 'hello echo2'
[INFO] [07/19/2015 17:15:24.955] [HelloActor-akka.actor.default-dispatcher-2] [akka://HelloActor/user/echo1] Received 'bye bye'

注意,echo1echo2 演员显然是并发执行的:hello echo1hello echo2 同时被记录。传递给 echo1 的第二个消息在演员完成处理 hello echo1 后才被处理。

有几点需要注意:

  • 要开始实例化演员,我们首先必须创建一个演员系统。通常每个应用程序只有一个演员系统。

  • 我们实例化演员的方式看起来有点奇怪。我们不是调用构造函数,而是创建一个演员属性对象,Props[T]。然后我们要求演员系统使用这些属性创建一个演员。实际上,我们从不使用 new 实例化演员:它们要么是通过调用演员系统中的 actorOf 方法或另一个演员内的类似方法(稍后详细介绍)创建的。

我们从不从外部调用演员的方法。与演员交互的唯一方式是向其发送消息。我们使用 tell 操作符,! 来这样做。因此,从外部无法干扰演员的内部结构(或者至少,Akka 使得干扰演员的内部结构变得困难)。

作为消息的案例类

在我们的 "hello world" 示例中,我们构建了一个预期接收字符串消息的演员。任何不可变的对象都可以作为消息传递。使用案例类来表示消息非常常见。这比使用字符串更好,因为增加了额外的类型安全性:编译器会在案例类中捕获错误,而不会在字符串中。

让我们重写我们的 EchoActor 以接受案例类的实例作为消息。我们将使其接受两种不同的消息:EchoMessage(message)EchoHello,后者只是回显默认消息。本节和下一节的示例位于本书提供的示例代码中的 chap09/hello_akka_case_classes 目录(github.com/pbugnion/s4ds)。

定义演员可以接收的消息是 Akka 的一个常见模式:

// EchoActor.scala

object EchoActor { 
  case object EchoHello
  case class EchoMessage(msg:String)
}

让我们更改演员定义以接受这些消息:

class EchoActor extends Actor with ActorLogging {
  import EchoActor._ // import the message definitions
  def receive = {
    case EchoHello => log.info("hello")
    case EchoMessage(s) => log.info(s)  
  }
}

我们现在可以向我们的演员发送 EchoHelloEchoMessage

echo1 ! EchoActor.EchoHello
echo2 ! EchoActor.EchoMessage("We're learning Akka.")

演员构建

演员构建是 Akka 新手常见的难题来源。与(大多数)普通对象不同,你永远不会显式实例化演员。例如,你永远不会写 val echo = new EchoActor。实际上,如果你这样做,Akka 会抛出异常。

在 Akka 中创建演员是一个两步过程:首先创建一个 Props 对象,它封装了构建演员所需的属性。构建 Props 对象的方式取决于演员是否接受构造函数参数。如果构造函数不接受参数,我们只需将演员类作为类型参数传递给 Props

val echoProps = Props[EchoActor]

如果演员的构造函数接受参数,我们必须在定义 Props 对象时将这些参数作为额外的参数传递。让我们考虑以下演员,例如:

class TestActor(a:String, b:Int) extends Actor { ... }

我们如下将构造函数参数传递给 Props 对象:

val testProps = Props(classOf[TestActor], "hello", 2)

Props 实例只是封装了创建演员的配置。它实际上并没有创建任何东西。要创建演员,我们将 Props 实例传递给定义在 ActorSystem 实例上的 system.actorOf 方法:

val system = ActorSystem("HelloActors")
val echo1 = system.actorOf(echoProps, name="hello-1")

name 参数是可选的,但用于日志和错误消息很有用。.actorOf 返回的值不是演员本身:它是对演员的 引用(可以将其视为演员居住的地址),具有 ActorRef 类型。ActorRef 是不可变的,但它可以被序列化和复制,而不会影响底层演员。

除了在演员系统中调用 actorOf 之外,还有另一种创建演员的方法:每个演员都公开了一个 context.actorOf 方法,该方法接受一个 Props 实例作为参数。上下文仅可以从演员内部访问:

class TestParentActor extends Actor {
  val echoChild = context.actorOf(echoProps, name="hello-child")
  ...
}

从演员系统创建的演员与从另一个演员的上下文创建的演员之间的区别在于演员层次结构:每个演员都有一个父演员。在另一个演员的上下文中创建的任何演员都将具有该演员作为其父演员。由演员系统创建的演员有一个预定义的演员,称为 用户守护者,作为其父演员。当我们在本章末尾研究演员生命周期时,我们将了解演员层次结构的重要性。

一个非常常见的习惯用法是在演员的伴生对象中定义一个 props 方法,它作为该演员 Props 实例的工厂方法。让我们修改 EchoActor 伴生对象:

object EchoActor {
  def props:Props = Props[EchoActor]

  // message case class definitions here
}

然后,我们可以按照以下方式实例化演员:

val echoActor = system.actorOf(EchoActor.props)

演员的解剖结构

在深入一个完整的应用程序之前,让我们看看演员框架的不同组件以及它们是如何协同工作的:

  • 邮箱: 邮箱基本上是一个队列。每个演员都有自己的邮箱。当你向一个演员发送消息时,消息会落在它的邮箱中,直到演员从队列中取出并通过其 receive 方法处理它。

  • 消息: 消息使得演员之间的同步成为可能。消息可以具有任何类型,唯一的要求是它应该是不可变的。通常,最好使用案例类或案例对象来获得编译器的帮助,以检查消息类型。

  • 演员引用: 当我们使用 val echo1 = system.actorOf(Props[EchoActor]) 创建一个演员时,echo1 具有类型 ActorRefActorRef 是一个代理,用于表示演员,并且是其他世界与之交互的方式:当你发送一个消息时,你是将它发送到 ActorRef,而不是直接发送给演员。实际上,在 Akka 中,你永远无法直接获取到演员的句柄。演员可以使用 .self 方法为自己获取一个 ActorRef

  • 演员上下文: 每个演员都有一个 context 属性,通过它可以访问创建或访问其他演员的方法,以及获取有关外部世界的信息。我们已经看到了如何使用 context.actorOf(props) 创建新的演员。我们还可以通过 context.parent 获取演员的父引用。演员还可以使用 context.stop(actorRef) 停止另一个演员,其中 actorRef 是我们想要停止的演员的引用。

  • 调度器: 调度器是实际执行演员中代码的机器。默认调度器使用 fork/join 线程池。Akka 允许我们为不同的演员使用不同的调度器。调整调度器可能有助于优化性能并给某些演员赋予优先权。演员运行的调度器可以通过 context.dispatcher 访问。调度器实现了 ExecutionContext 接口,因此它们可以用来运行未来。

追随者网络爬虫

本章的最终目标是构建一个爬虫来探索 GitHub 的粉丝图谱。我们已经在本章前面概述了如何以单线程方式完成这项工作。现在让我们设计一个演员系统来并发地完成这项任务。

代码中的动态部分是管理哪些用户已被获取或正在被获取的数据结构。这些需要封装在一个演员中,以避免多个演员尝试同时更改它们时产生的竞争条件。因此,我们将创建一个fetcher 管理器演员,其任务是跟踪哪些用户已被获取,以及我们接下来要获取哪些用户。

代码中可能成为瓶颈的部分是查询 GitHub API。因此,我们希望能够扩展同时执行此操作的工作者数量。我们将创建一个fetcher池,这些演员负责查询特定用户的 API 以获取粉丝。最后,我们将创建一个演员,其责任是解释 API 的响应。这个演员将把其对响应的解释转发给另一个演员,该演员将提取粉丝并将他们交给 fetcher 管理器。

这就是程序架构的样貌:

粉丝网络爬虫

GitHub API 爬虫的演员系统

我们程序中的每个演员都执行单一的任务:fetcher 只查询 GitHub API,而队列管理器只将工作分配给 fetcher。Akka 的最佳实践是尽可能给演员分配狭窄的责任范围。这有助于在扩展时获得更好的粒度(例如,通过添加更多的 fetcher 演员,我们只是并行化瓶颈)和更好的弹性:如果一个演员失败,它只会影响其责任范围。我们将在本章后面探讨演员的失败。

我们将分几个步骤构建应用程序,在编写程序的同时探索 Akka 工具包。让我们从build.sbt文件开始。除了 Akka,我们还将scalaj-httpjson4s标记为依赖项:

// build.sbt
scalaVersion := "2.11.7"

libraryDependencies ++= Seq(
  "org.json4s" %% "json4s-native" % "3.2.10",
  "org.scalaj" %% "scalaj-http" % "1.1.4",
  "com.typesafe.akka" %% "akka-actor" % "2.3.12"
)

Fetcher 演员

我们应用程序的核心是 fetcher 演员,它负责从 GitHub 获取粉丝详情。最初,我们的演员将接受一个单一的消息,Fetch(user)。它将获取与user对应的粉丝,并将响应记录到屏幕上。我们将使用在第七章中开发的配方,即Web APIs,使用 OAuth 令牌查询 GitHub API。我们将通过演员构造函数注入令牌。

让我们从伴随对象开始。这将包含Fetch(user)消息的定义和两个工厂方法来创建Props实例。您可以在本书提供的示例代码中的chap09/fetchers_alone目录中找到本节的代码示例(github.com/pbugnion/s4ds):

// Fetcher.scala
import akka.actor._
import scalaj.http._
import scala.concurrent.Future

object Fetcher {
  // message definitions
  case class Fetch(login:String)

  // Props factory definitions
  def props(token:Option[String]):Props = 
    Props(classOf[Fetcher], token)
  def props():Props = Props(classOf[Fetcher], None)
}

现在我们来定义 fetcher 本身。我们将把对 GitHub API 的调用封装在一个 future 中。这避免了单个缓慢的请求阻塞演员。当我们的演员收到一个 Fetch 请求时,它会将这个请求封装在一个 future 中,发送出去,然后可以处理下一个消息。让我们继续实现我们的演员:

// Fetcher.scala
class Fetcher(val token:Option[String])
extends Actor with ActorLogging {
  import Fetcher._ // import message definition

  // We will need an execution context for the future.
  // Recall that the dispatcher doubles up as execution
  // context.
  import context.dispatcher

  def receive = {
    case Fetch(login) => fetchUrl(login)
  }

  private def fetchUrl(login:String) {
    val unauthorizedRequest = Http(
      s"https://api.github.com/users/$login/followers")
    val authorizedRequest = token.map { t =>
      unauthorizedRequest.header("Authorization", s"token $t")
    }

    // Prepare the request: try to use the authorized request
    // if a token was given, and fall back on an unauthorized 
    // request
    val request = authorizedRequest.getOrElse(unauthorizedRequest)

    // Fetch from github
    val response = Future { request.asString }
    response.onComplete { r =>
      log.info(s"Response from $login: $r")
    }
  }

}

让我们实例化一个演员系统并创建四个 fetcher 来检查我们的演员是否按预期工作。我们将从环境变量中读取 GitHub 令牌,如第七章 Web APIs 中所述,然后创建四个演员并要求每个演员获取特定 GitHub 用户的关注者。我们等待五秒钟以完成请求,然后关闭系统:

// FetcherDemo.scala
import akka.actor._

object FetcherDemo extends App {
  import Fetcher._ // Import the messages

  val system = ActorSystem("fetchers")

  // Read the github token if present.
  val token = sys.env.get("GHTOKEN")

  val fetchers = (0 until 4).map { i =>
    system.actorOf(Fetcher.props(token))
  }

  fetchers(0) ! Fetch("odersky")
  fetchers(1) ! Fetch("derekwyatt")
  fetchers(2) ! Fetch("rkuhn")
  fetchers(3) ! Fetch("tototoshi")

  Thread.sleep(5000) // Wait for API calls to finish
  system.shutdown // Shut system down

}

让我们通过 SBT 运行代码:

$ GHTOKEN="2502761..." sbt run
[INFO] [11/08/2015 16:28:06.500] [fetchers-akka.actor.default-dispatcher-2] [akka://fetchers/user/$d] Response from tototoshi: Success(HttpResponse([{"login":"akr4","id":10892,"avatar_url":"https://avatars.githubusercontent.com/u/10892?v=3","gravatar_id":""...

注意我们如何明确地需要使用 system.shutdown 来关闭演员系统。程序会挂起,直到系统关闭。然而,关闭系统将停止所有演员,因此我们需要确保他们已经完成工作。我们通过插入对 Thread.sleep 的调用来实现这一点。

使用 Thread.sleep 等待直到 API 调用完成以关闭演员系统是一种比较粗糙的方法。更好的方法可能是让演员向系统发出信号,表明他们已经完成了任务。当我们实现 fetcher manager 演员时,我们将看到这种模式的示例。

Akka 包含一个功能丰富的 scheduler 来安排事件。我们可以使用调度器来替换对 Thread.sleep 的调用,并安排在五秒后关闭系统。这比 Thread.sleep 更好,因为调度器不会阻塞调用线程。要使用调度器,我们需要导入全局执行上下文和持续时间模块:

// FetcherDemoWithScheduler.scala

import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

然后,我们可以通过用以下代码替换对 Thread.sleep 的调用来安排系统关闭:

system.scheduler.scheduleOnce(5.seconds) { system.shutdown }

除了 scheduleOnce,调度器还公开了一个 schedule 方法,允许您定期安排事件的发生(例如,每两秒一次)。这对于心跳检查或监控系统非常有用。有关更多信息,请参阅位于 doc.akka.io/docs/akka/snapshot/scala/scheduler.html 的调度器 API 文档。

注意,我们在这里实际上有点作弊,因为我们没有获取每个关注者的信息。关注者查询的响应实际上是分页的,因此我们需要获取几页才能获取所有关注者。在演员中添加逻辑来完成这项工作并不复杂。我们目前将忽略这一点,并假设每个用户的关注者数量上限为 100。

路由

在上一个示例中,我们创建了四个 fetchers 并将消息依次派发给了它们。我们在其中分配任务的一组相同 actor 中有一个池。手动将消息路由到正确的 actor 以最大化我们池的利用率是痛苦且容易出错的。幸运的是,Akka 为我们提供了几个路由策略,我们可以使用这些策略在我们的 actor 池中分配工作。让我们用自动路由重写之前的示例。您可以在本书提供的示例代码中的chap09/fetchers_routing目录下找到本节的代码示例(github.com/pbugnion/s4ds)。我们将重用与之前章节相同的Fetchers及其伴随对象的定义。

让我们先导入路由包:

// FetcherDemo.scala
import akka.routing._

一个router是一个 actor,它将其接收到的消息转发给其子 actor。定义 actor 池的最简单方法就是告诉 Akka 创建一个 router 并传递一个为其子 actor 的Props对象。然后,router 将直接管理工作者的创建。在我们的示例中(我们将在文本中仅注释与上一个示例不同的部分,但您可以在本章的示例代码的fetchers_routing目录中找到完整的代码),我们用以下代码替换了自定义的Fetcher创建代码:

// FetcherDemo.scala

// Create a router with 4 workers of props Fetcher.props()
val router = system.actorOf(
  RoundRobinPool(4).props(Fetcher.props(token))
)

然后,我们可以直接将 fetch 消息发送到路由器。路由器将以轮询的方式将消息路由到其子 actor:

List("odersky", "derekwyatt", "rkuhn", "tototoshi").foreach { 
  login => router ! Fetch(login)
}

在这个示例中,我们使用了轮询路由器。Akka 提供了许多不同类型的路由器,包括具有动态池大小的路由器,以满足不同类型的负载均衡。前往 Akka 文档查看所有可用的路由器列表,请参阅doc.akka.io/docs/akka/snapshot/scala/routing.html

Actor 之间的消息传递

仅记录 API 响应并不很有用。为了遍历跟随者图,我们必须执行以下操作:

  • 检查响应的返回码以确保 GitHub API 对我们的请求满意

  • 将响应解析为 JSON

  • 提取跟随者的登录名,如果我们还没有获取它们,将它们推入队列

您在第七章中学习了如何做所有这些事情,Web APIs,但不是在 actor 的上下文中。

我们可以将额外的处理步骤添加到我们的Fetcher actor 的receive方法中:我们可以通过 future composition 添加进一步的转换到 API 响应。然而,让 actor 执行多个不同的操作,并且可能以多种方式失败,是一种反模式:当我们学习管理 actor 生命周期时,我们会看到如果 actor 包含多个逻辑片段,我们的 actor 系统就变得难以推理。

因此,我们将使用三个不同 actor 的管道:

  • 我们已经遇到的 fetchers 仅负责从 GitHub 获取 URL。如果 URL 格式不正确或无法访问 GitHub API,它们将失败。

  • 响应解释器负责从 GitHub API 获取响应并将其解析为 JSON。如果在任何步骤中失败,它将仅记录错误(在实际应用中,我们可能会根据失败类型采取不同的纠正措施)。如果它成功提取 JSON,它将把 JSON 数组传递给跟随者提取器。

  • 跟随者提取器将从 JSON 数组中提取跟随者,并将它们传递给需要获取其跟随者的用户队列。

我们已经构建了 fetchers,尽管我们需要修改它们,以便将 API 响应转发给响应解释器,而不仅仅是记录日志。

你可以在本书提供的示例代码中的chap09/all_workers目录中找到本节的代码示例(github.com/pbugnion/s4ds)。第一步是修改 fetchers,使其在记录响应而不是转发响应。为了能够将响应转发给响应解释器,fetchers 将需要一个对这个演员的引用。我们将通过 fetcher 构造函数传递这个引用,现在的构造函数是:

// Fetcher.scala
class Fetcher(
  val token:Option[String], 
  val responseInterpreter:ActorRef) 
extends Actor with ActorLogging {
  ...
}

我们还必须修改伴随对象中的Props工厂方法:

// Fetcher.scala
def props(
  token:Option[String], responseInterpreter:ActorRef
):Props = Props(classOf[Fetcher], token, responseInterpreter)

我们还必须修改receive方法,将 HTTP 响应转发给解释器,而不仅仅是记录日志:

// Fetcher.scala
class Fetcher(...) extends Actor with ActorLogging {
  ...
  def receive = {
    case Fetch(login) => fetchFollowers(login)
  }

  private def fetchFollowers(login:String) {
    val unauthorizedRequest = Http(
      s"https://api.github.com/users/$login/followers")
    val authorizedRequest = token.map { t =>
      unauthorizedRequest.header("Authorization", s"token $t")
    }

    val request = authorizedRequest.getOrElse(unauthorizedRequest)
    val response = Future { request.asString }

    // Wrap the response in an InterpretResponse message and
    // forward it to the interpreter.
    response.onComplete { r =>
      responseInterpreter !
        ResponseInterpreter.InterpretResponse(login, r)
    }
  }
}

响应解释器获取响应,判断其是否有效,将其解析为 JSON,并将其转发给跟随者提取器。响应解释器将需要一个指向跟随者提取器的引用,我们将通过构造函数传递这个引用。

让我们从定义ResponseInterpreter伴随对象开始。它将只包含响应解释器可以接收的消息定义以及一个用于创建Props对象的工厂,以帮助进行实例化:

// ResponseInterpreter.scala
import akka.actor._
import scala.util._

import scalaj.http._
import org.json4s._
import org.json4s.native.JsonMethods._

object ResponseInterpreter {

  // Messages
  case class InterpretResponse(
    login:String, response:Try[HttpResponse[String]]
  )

  // Props factory
  def props(followerExtractor:ActorRef) = 
    Props(classOf[ResponseInterpreter], followerExtractor)
}

ResponseInterpreter的主体应该感觉熟悉:当演员收到一个提供要解释的响应的消息时,它使用你在第七章,“Web APIs”中学到的技术将其解析为 JSON。如果我们成功解析响应,我们将把解析后的 JSON 转发给跟随者提取器。如果我们无法解析响应(可能是由于格式不正确),我们只需记录错误。我们可以以其他方式从中恢复,例如,通过将此登录重新添加到队列管理器中以便再次获取:

// ResponseInterpreter.scala
class ResponseInterpreter(followerExtractor:ActorRef) 
extends Actor with ActorLogging {
  // Import the message definitions
  import ResponseInterpreter._

  def receive = {
    case InterpretResponse(login, r) => interpret(login, r)
  }

  // If the query was successful, extract the JSON response
  // and pass it onto the follower extractor.
  // If the query failed, or is badly formatted, throw an error
  // We should also be checking error codes here.
  private def interpret(
    login:String, response:Try[HttpResponse[String]]
  ) = response match {
    case Success(r) => responseToJson(r.body) match {
      case Success(jsonResponse) => 
        followerExtractor ! FollowerExtractor.Extract(
          login, jsonResponse)
      case Failure(e) => 
        log.error(
          s"Error parsing response to JSON for $login: $e")
    }
    case Failure(e) => log.error(
      s"Error fetching URL for $login: $e")
  }

  // Try and parse the response body as JSON. 
  // If successful, coerce the `JValue` to a `JArray`.
  private def responseToJson(responseBody:String):Try[JArray] = {
    val jvalue = Try { parse(responseBody) }
    jvalue.flatMap {
      case a:JArray => Success(a)
      case _ => Failure(new IllegalStateException(
        "Incorrectly formatted JSON: not an array"))
    }
  }
}

现在我们已经有了三分之二的工人演员。最后一个链接是跟随者提取器。这个演员的工作很简单:它接受响应解释器传递给它的JArray,并将其转换为跟随者列表。目前,我们只是记录这个列表,但当我们构建获取器管理者时,跟随者提取器将发送消息要求管理者将其添加到要获取的登录队列中。

如前所述,伴随者仅定义了此演员可以接收的消息以及一个 Props 工厂方法:

// FollowerExtractor.scala
import akka.actor._

import org.json4s._
import org.json4s.native.JsonMethods._

object FollowerExtractor {

  // Messages
  case class Extract(login:String, jsonResponse:JArray)

  // Props factory method
  def props = Props[FollowerExtractor]
}

FollowerExtractor类接收包含表示跟随者的JArray信息的Extract消息。它提取login字段并记录它:

class FollowerExtractor extends Actor with ActorLogging {
  import FollowerExtractor._
  def receive = {
    case Extract(login, followerArray) => {
      val followers = extractFollowers(followerArray)
      log.info(s"$login -> ${followers.mkString(", ")}")
    }
  }

  def extractFollowers(followerArray:JArray) = for {
    JObject(follower) <- followerArray
    JField("login", JString(login)) <- follower
  } yield login
}

让我们编写一个新的main方法来测试所有我们的演员:

// FetchNetwork.scala

import akka.actor._
import akka.routing._
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

object FetchNetwork extends App {

  import Fetcher._ // Import messages and factory method

  // Get token if exists
  val token = sys.env.get("GHTOKEN")

  val system = ActorSystem("fetchers")

  // Instantiate actors
  val followerExtractor = system.actorOf(FollowerExtractor.props)
  val responseInterpreter =   
    system.actorOf(ResponseInterpreter.props(followerExtractor))

  val router = system.actorOf(RoundRobinPool(4).props(
    Fetcher.props(token, responseInterpreter))
  )

  List("odersky", "derekwyatt", "rkuhn", "tototoshi") foreach {
    login => router ! Fetch(login)
  }

  // schedule a shutdown
  system.scheduler.scheduleOnce(5.seconds) { system.shutdown }

}

让我们通过 SBT 运行这个例子:

$ GHTOKEN="2502761d..." sbt run
[INFO] [11/05/2015 20:09:37.048] [fetchers-akka.actor.default-dispatcher-3] [akka://fetchers/user/$a] derekwyatt -> adulteratedjedi, joonas, Psycojoker, trapd00r, tyru, ...
[INFO] [11/05/2015 20:09:37.050] [fetchers-akka.actor.default-dispatcher-3] [akka://fetchers/user/$a] tototoshi -> akr4, yuroyoro, seratch, yyuu, ...
[INFO] [11/05/2015 20:09:37.051] [fetchers-akka.actor.default-dispatcher-3] [akka://fetchers/user/$a] odersky -> misto, gkossakowski, mushtaq, ...
[INFO] [11/05/2015 20:09:37.052] [fetchers-akka.actor.default-dispatcher-3] [akka://fetchers/user/$a] rkuhn -> arnbak, uzoice, jond3k, TimothyKlim, relrod, ...

队列控制和拉取模式

我们现在已经在我们的爬虫应用程序中定义了三个工作演员。下一步是定义管理者。获取器管理者负责维护一个要获取的登录队列以及一组我们已经看到的登录名称,以避免多次获取相同的登录。

第一个尝试可能涉及构建一个演员,它维护一组我们已经看到的用户,并在给定一个新用户获取时将其调度到轮询路由器。这种方法的问题在于获取器邮箱中的消息数量会迅速积累:对于每个 API 查询,我们可能会得到数十个跟随者,每个跟随者都可能回到获取器的收件箱。这使我们很难控制堆积的工作量。

这个问题可能导致的第一个问题是 GitHub API 速率限制:即使有认证,我们每小时也限制在 5,000 个请求。当我们达到这个阈值时,停止查询会有用。如果每个获取器都有数百个需要获取的用户积压,我们就无法做出响应。

一个更好的替代方案是使用拉取系统:当获取器发现自己空闲时,它们会从中央队列请求工作。在 Akka 中,当我们有一个生产者比消费者处理得更快时,拉取系统很常见(参考www.michaelpollmeier.com/akka-work-pulling-pattern/)。

管理者和获取器之间的对话将如下进行:

  • 如果管理者从没有工作状态变为有工作状态,它会向所有获取器发送WorkAvailable消息。

  • 当获取器收到WorkAvailable消息或完成一项工作时,它会向队列管理者发送GiveMeWork消息。

  • 当队列管理者收到GiveMeWork消息时,如果没有工作可用或它被限制,它会忽略请求。如果有工作,它向演员发送Fetch(user)消息。

让我们从修改我们的 fetcher 开始。您可以在本书提供的示例代码中的chap09/ghub_crawler目录中找到本节的代码示例(github.com/pbugnion/s4ds)。我们将通过构造函数传递 fetcher 管理器的引用。我们需要更改伴随对象以添加WorkAvailable消息,并将props工厂包括到管理器的引用中:

// Fecther.scala
object Fetcher {
  case class Fetch(url:String)
  case object WorkAvailable

  def props(
    token:Option[String], 
    fetcherManager:ActorRef, 
    responseInterpreter:ActorRef):Props =
      Props(classOf[Fetcher], 
        token, fetcherManager, responseInterpreter)
}

我们还需要更改receive方法,以便在处理完一个请求或接收到WorkAvailable消息后,它能够查询FetcherManager以获取更多工作。

这是 fetchers 的最终版本:

class Fetcher(
  val token:Option[String], 
  val fetcherManager:ActorRef,
  val responseInterpreter:ActorRef) 
extends Actor with ActorLogging {
  import Fetcher._
  import context.dispatcher

  def receive = {
    case Fetch(login) => fetchFollowers(login)
    case WorkAvailable => 
      fetcherManager ! FetcherManager.GiveMeWork
  }

  private def fetchFollowers(login:String) {
    val unauthorizedRequest = Http(
      s"https://api.github.com/users/$login/followers")
    val authorizedRequest = token.map { t =>
      unauthorizedRequest.header("Authorization", s"token $t")
    }
    val request = authorizedRequest.getOrElse(unauthorizedRequest)
    val response = Future { request.asString }

    response.onComplete { r =>
      responseInterpreter ! 
        ResponseInterpreter.InterpretResponse(login, r)
      fetcherManager ! FetcherManager.GiveMeWork
    }
  }

}

现在我们已经有一个工作定义的 fetchers,让我们构建FetcherManager。这是我们迄今为止构建的最复杂的 actor,在我们深入构建它之前,我们需要更多地了解 Akka 工具包的组件。

访问消息的发送者

当我们的 fetcher 管理器收到GiveMeWork请求时,我们需要将工作发送回正确的 fetcher。我们可以使用sender方法访问发送消息的 actor,这是Actor的一个方法,它返回正在处理的消息对应的ActorRef。因此,fetcher 管理器中对应于GiveMeWorkcase语句是:

def receive = {
  case GiveMeWork =>
    login = // get next login to fetch
    sender ! Fetcher.Fetch(login)
  ...
}

由于sender是一个方法,它的返回值会随着每个新传入的消息而改变。因此,它应该仅与receive方法同步使用。特别是,在 future 中使用它是危险的:

def receive = {
  case DoSomeWork =>
    val work = Future { Thread.sleep(20000) ; 5 }
    work.onComplete { result => 
      sender ! Complete(result) // NO!
    }
}

问题在于,当 future 在消息处理后的 20 秒完成时,actor 很可能会处理不同的消息,因此sender的返回值将改变。因此,我们将Complete消息发送给一个完全不同的 actor。

如果您需要在receive方法之外回复消息,例如当 future 完成时,您应该将当前发送者的值绑定到一个变量上:

def receive = {
  case DoSomeWork =>
    // bind the current value of sender to a val
    val requestor = sender
    val work = Future { Thread.sleep(20000) ; 5 }
    work.onComplete { result => requestor ! Complete(result) }
}

有状态演员

fetcher 管理器的行为取决于它是否有工作要分配给 fetchers:

  • 如果它有工作要提供,它需要用Fetcher.Fetch消息响应GiveMeWork消息

  • 如果没有工作,它必须忽略GiveMeWork消息,并且如果添加了工作,它必须向 fetchers 发送一个WorkAvailable消息。

在 Akka 中编码状态的概念很简单。我们指定不同的receive方法,并根据状态从一个切换到另一个。我们将为我们的 fetcher 管理器定义以下receive方法,对应于每个状态:

// receive method when the queue is empty
def receiveWhileEmpty: Receive = { 
    ... 
}

// receive method when the queue is not empty
def receiveWhileNotEmpty: Receive = {
    ...
}

注意,我们必须将接收方法的返回类型定义为Receive。为了将演员从一个方法切换到另一个方法,我们可以使用context.become(methodName)。例如,当最后一个登录名从队列中弹出时,我们可以通过context.become(receiveWhileEmpty)过渡到使用receiveWhileEmpty方法。我们通过将receiveWhileEmpty分配给receive方法来设置初始状态:

def receive = receiveWhileEmpty

跟随者网络爬虫

我们现在准备编写网络爬虫剩余部分的代码。最大的缺失部分是获取器管理器。让我们从伴随对象开始。与工作演员类似,这仅包含演员可以接收的消息定义以及创建Props实例的工厂:

// FetcherManager.scala
import scala.collection.mutable
import akka.actor._

object FetcherManager {
  case class AddToQueue(login:String)
  case object GiveMeWork

  def props(token:Option[String], nFetchers:Int) = 
    Props(classOf[FetcherManager], token, nFetchers)
}

管理器可以接收两种消息:AddToQueue,它告诉管理器将用户名添加到需要获取跟随者的用户队列中,以及由获取器在失业时发出的GiveMeWork

管理器将负责启动获取器、响应解释器和跟随者提取器,以及维护一个用户名内部队列和一组我们已看到的用户名:

// FetcherManager.scala

class FetcherManager(val token:Option[String], val nFetchers:Int) 
extends Actor with ActorLogging {

  import FetcherManager._

  // queue of usernames whose followers we need to fetch
  val fetchQueue = mutable.Queue.empty[String]

  // set of users we have already fetched. 
  val fetchedUsers = mutable.Set.empty[String]

  // Instantiate worker actors
  val followerExtractor = context.actorOf(
    FollowerExtractor.props(self))
  val responseInterpreter = context.actorOf(
    ResponseInterpreter.props(followerExtractor))
  val fetchers = (0 until nFetchers).map { i =>
    context.actorOf(
      Fetcher.props(token, self, responseInterpreter))
  }

  // receive method when the actor has work:
  // If we receive additional work, we just push it onto the
  // queue.
  // If we receive a request for work from a Fetcher,
  // we pop an item off the queue. If that leaves the 
  // queue empty, we transition to the 'receiveWhileEmpty'
  // method.
  def receiveWhileNotEmpty:Receive = {
    case AddToQueue(login) => queueIfNotFetched(login)
    case GiveMeWork =>
      val login = fetchQueue.dequeue
      // send a Fetch message back to the sender.
      // we can use the `sender` method to reply to a message
      sender ! Fetcher.Fetch(login)
      if (fetchQueue.isEmpty) { 
        context.become(receiveWhileEmpty) 
      }
  }

  // receive method when the actor has no work:
  // if we receive work, we add it onto the queue, transition
  // to a state where we have work, and notify the fetchers
  // that work is available.
  def receiveWhileEmpty:Receive = {
    case AddToQueue(login) =>
      queueIfNotFetched(login)
      context.become(receiveWhileNotEmpty)
      fetchers.foreach { _ ! Fetcher.WorkAvailable }
    case GiveMeWork => // do nothing
  }

  // Start with an empty queue.
  def receive = receiveWhileEmpty

  def queueIfNotFetched(login:String) {
    if (! fetchedUsers(login)) {
      log.info(s"Pushing $login onto queue") 
      // or do something useful...
      fetchQueue += login
      fetchedUsers += login
    }
  }
}

我们现在有了获取器管理器。除了跟随者提取器之外,其余的代码可以保持不变。而不是记录跟随者名称,它必须向管理器发送AddToQueue消息。我们将在构造时传递管理器的引用:

// FollowerExtractor.scala
import akka.actor._
import org.json4s._
import org.json4s.native.JsonMethods._

object FollowerExtractor {

  // messages
  case class Extract(login:String, jsonResponse:JArray)

  // props factory method
  def props(manager:ActorRef) = 
    Props(classOf[FollowerExtractor], manager)
}

class FollowerExtractor(manager:ActorRef)
extends Actor with ActorLogging {
  import FollowerExtractor._

  def receive = {
    case Extract(login, followerArray) =>
      val followers = extractFollowers(followerArray)
      followers foreach { f => 
        manager ! FetcherManager.AddToQueue(f) 
      }
  }

  def extractFollowers(followerArray:JArray) = for {
    JObject(follower) <- followerArray
    JField("login", JString(login)) <- follower
  } yield login

}

运行所有这些的main方法非常简单,因为所有实例化演员的代码都已移动到FetcherManager。我们只需要实例化管理器,并给它网络中的第一个节点,然后它会完成其余的工作:

// FetchNetwork.scala
import akka.actor._

object FetchNetwork extends App {

  // Get token if exists
  val token = sys.env.get("GHTOKEN")

  val system = ActorSystem("GithubFetcher")
  val manager = system.actorOf(FetcherManager.props(token, 2))
  manager ! FetcherManager.AddToQueue("odersky")

}

注意我们不再尝试关闭演员系统了。我们将让它运行,爬取网络,直到我们停止它或达到认证限制。让我们通过 SBT 运行这个程序:

$ GHTOKEN="2502761d..." sbt "runMain FetchNetwork"
[INFO] [11/06/2015 06:31:04.614] [GithubFetcher-akka.actor.default-dispatcher-2] [akka://GithubFetcher/user/$a] Pushing odersky onto queue
[INFO] [11/06/2015 06:31:05.563] [GithubFetcher-akka.actor.default-dispatcher-4] [akka://GithubFetcher/user/$a] Pushing misto onto queueINFO] [11/06/2015 06:31:05.563] [GithubFetcher-akka.actor.default-dispatcher-4] [akka://GithubFetcher/user/$a] Pushing gkossakowski onto queue
^C

我们的程序实际上并没有对检索到的跟随者做任何有用的操作,除了记录它们。我们可以将log.info调用替换为,例如,将节点存储在数据库中或在屏幕上绘制图形。

容错性

真实程序会失败,并且以不可预测的方式失败。Akka 以及整个 Scala 社区都倾向于明确规划失败,而不是试图编写不可失败的应用程序。一个容错系统是指当其一个或多个组件失败时仍能继续运行的系统。单个子系统的失败并不一定意味着应用程序的失败。这如何应用于 Akka?

演员模型提供了一个自然的单元来封装失败:演员。当一个演员在处理消息时抛出异常,默认行为是演员重启,但异常不会泄露并影响系统的其余部分。例如,让我们在响应解释器中引入一个任意的失败。我们将修改receive方法,当它被要求解释misto(Martin Odersky 的一个关注者)的响应时抛出异常:

// ResponseInterpreter.scala
def receive = {
  case InterpretResponse("misto", r) => 
    throw new IllegalStateException("custom error")
  case InterpretResponse(login, r) => interpret(login, r)
}

如果你通过 SBT 重新运行代码,你会注意到记录了一个错误。然而,程序并没有崩溃。它只是继续正常运行:

[ERROR] [11/07/2015 12:05:58.938] [GithubFetcher-akka.actor.default-dispatcher-2] [akka://GithubFetcher/user/$a/$b] custom error
java.lang.IllegalStateException: custom error
 at ResponseInterpreter$
 ...
[INFO] [11/07/2015 12:05:59.117] [GithubFetcher-akka.actor.default-dispatcher-2] [akka://GithubFetcher/user/$a] Pushing samfoo onto queue

misto的任何关注者都不会被添加到队列中:他从未通过ResponseInterpreter阶段。让我们逐步了解当异常被抛出时会发生什么:

  • 解释器接收到InterpretResponse("misto", ...)消息。这导致它抛出异常并死亡。其他演员不受异常的影响。

  • 使用与最近去世的演员相同的 Props 实例创建响应解释器的新实例。

  • 当响应解释器完成初始化后,它被绑定到与去世演员相同的ActorRef。这意味着,对于系统的其余部分来说,没有任何变化。

  • 邮箱绑定到ActorRef而不是演员,因此新的响应解释器将与其前任具有相同的邮箱,而不包括有问题的消息。

因此,无论出于什么原因,我们的爬虫在抓取或解析用户响应时崩溃,应用程序的影响将最小——我们只是不会抓取此用户的关注者。

当演员重启时,它携带的任何内部状态都会丢失。因此,例如,如果 fetcher 管理器死亡,我们将丢失队列的当前值和已访问用户。可以通过以下方式减轻丢失内部状态的风险:

  • 采用不同的失败策略:例如,在失败的情况下,我们可以继续处理消息而不重启演员。当然,如果演员死亡是因为其内部状态不一致,这几乎没有什么用处。在下一节中,我们将讨论如何更改失败恢复策略。

  • 通过定期将其写入磁盘并从备份中重新启动时加载来备份内部状态。

  • 通过确保所有“风险”操作都委派给其他演员来保护携带关键状态的演员。在我们的爬虫示例中,所有与外部服务的交互,例如查询 GitHub API 和解析响应,都是通过不带内部状态的演员进行的。正如我们在前面的示例中看到的,如果这些演员中的任何一个死亡,应用程序的影响将最小。相比之下,宝贵的 fetcher 管理器只允许与经过清理的输入进行交互。这被称为错误内核模式:可能引起错误的代码被委派给自杀式演员。

自定义监督策略

在失败时重新启动演员的默认策略并不总是我们想要的。特别是对于携带大量数据的演员,我们可能希望在异常后恢复处理而不是重新启动演员。Akka 允许我们通过在演员的监督者中设置监督策略来自定义此行为。

请记住,所有演员都有父母,包括顶级演员,它们是称为用户守护者的特殊演员的孩子。默认情况下,演员的监督者是它的父母,监督者决定在失败时对演员做什么。

因此,要改变演员对失败的反应方式,你必须设置其父母的监督策略。你通过设置supervisorStrategy属性来完成此操作。默认策略等同于以下内容:

val supervisorStrategy = OneForOneStrategy() {
  case _:ActorInitializationException => Stop
  case _:ActorKilledException => Stop
  case _:DeathPactException => Stop
  case _:Exception => Restart
}

一个监督策略有两个组成部分:

  • OneForOneStrategy确定策略仅适用于失败的演员。相比之下,我们可以使用AllForOneStrategy,它将相同的策略应用于所有被监督者。如果一个子进程失败,所有子进程都将被重新启动(或停止或恢复)。

  • 一个部分函数映射ThrowablesDirective,这是一个关于在失败时如何操作的指令。例如,默认策略将ActorInitializationException(如果构造函数失败会发生)映射到Stop指令,以及(几乎所有)其他异常映射到Restart

有四个指令:

  • Restart:这会销毁有缺陷的演员并重新启动它,将新生成的演员绑定到旧的ActorRef。这将清除演员的内部状态,这可能是一件好事(演员可能因为某些内部不一致性而失败)。

  • Resume:演员只需继续处理其收件箱中的下一个消息。

  • Stop:演员停止并不会被重新启动。这在用于完成单个操作的丢弃演员中很有用:如果这个操作失败,演员就不再需要了。

  • Escalate:监督者本身重新抛出异常,希望它的监督者知道如何处理它。

监督者无法访问其子进程失败的情况。因此,如果一个演员的子进程可能需要不同的恢复策略,最好创建一组中间监督演员来监督不同的子进程组。

作为设置监督策略的示例,让我们调整FetcherManager的监督策略以采用全为一策略,并在其中一个子进程失败时停止其子进程。我们首先进行相关导入:

import akka.actor.SupervisorStrategy._

然后,我们只需在FetcherManager定义中设置supervisorStrategy属性:

class FetcherManager(...) extends Actor with ActorLogging {

  ...

  override val supervisorStrategy = AllForOneStrategy() {
    case _:ActorInitializationException => Stop
    case _:ActorKilledException => Stop
    case _:Exception => Stop
  }

  ...
}

如果你通过 SBT 运行它,你会注意到当代码遇到响应解释器抛出的自定义异常时,系统会停止。这是因为除了检索管理器之外的所有演员现在都已失效。

生命周期钩子

Akka 允许我们通过生命周期钩子指定在演员的生命周期中响应特定事件的代码。Akka 定义了以下钩子:

  • preStart():在演员构造函数完成后但开始处理消息之前运行。这对于运行依赖于演员完全构建的初始化代码很有用。

  • postStop():当演员停止处理消息后死亡时运行。在终止演员之前运行清理代码很有用。

  • preRestart(reason: Throwable, message: Option[Any]):在演员收到重启命令后立即调用。preRestart方法可以访问抛出的异常和有问题的消息,允许进行纠正操作。preRestart的默认行为是停止每个子演员,然后调用postStop

  • postRestart(reason:Throwable):在演员重启后调用。默认行为是调用preStart()

让我们使用系统钩子来在程序运行之间持久化FetcherManager的状态。您可以在本书提供的示例代码中的chap09/ghub_crawler_fault_tolerant目录中找到本节的代码示例(github.com/pbugnion/s4ds)。这将使 fetcher 管理器具有容错性。我们将使用postStop将当前队列和已访问用户集合写入文本文件,并使用preStart从磁盘读取这些文本文件。让我们首先导入读取和写入文件所需的库:

// FetcherManager.scala

import scala.io.Source 
import scala.util._
import java.io._

我们将在FetcherManager伴随对象(更好的方法是将它们存储在配置文件中)中存储两个文本文件的名称:

// FetcherManager.scala
object FetcherManager {
  ...
  val fetchedUsersFileName = "fetched-users.txt"
  val fetchQueueFileName = "fetch-queue.txt"
}

preStart方法中,我们从文本文件中加载已获取用户集合和要获取用户的后备队列,并在postStop方法中,我们用这些数据结构的新值覆盖这些文件:

class FetcherManager(
  val token:Option[String], val nFetchers:Int
) extends Actor with ActorLogging {

  ...

  /** pre-start method: load saved state from text files */
  override def preStart {
    log.info("Running pre-start on fetcher manager")

    loadFetchedUsers
    log.info(
      s"Read ${fetchedUsers.size} visited users from source"
    )

    loadFetchQueue
    log.info(
      s"Read ${fetchQueue.size} users in queue from source"
    )

    // If the saved state contains a non-empty queue, 
    // alert the fetchers so they can start working.
    if (fetchQueue.nonEmpty) {
      context.become(receiveWhileNotEmpty)
      fetchers.foreach { _ ! Fetcher.WorkAvailable }
    }

  }

  /** Dump the current state of the manager */
  override def postStop {
    log.info("Running post-stop on fetcher manager")
    saveFetchedUsers
    saveFetchQueue
  }

     /* Helper methods to load from and write to files */
  def loadFetchedUsers {
    val fetchedUsersSource = Try { 
      Source.fromFile(fetchedUsersFileName) 
    }
    fetchedUsersSource.foreach { s =>
      try s.getLines.foreach { l => fetchedUsers += l }
      finally s.close
    }
  }

  def loadFetchQueue {
    val fetchQueueSource = Try { 
      Source.fromFile(fetchQueueFileName) 
    }
    fetchQueueSource.foreach { s =>
      try s.getLines.foreach { l => fetchQueue += l }
      finally s.close
    }
  }

  def saveFetchedUsers {
    val fetchedUsersFile = new File(fetchedUsersFileName)
    val writer = new BufferedWriter(
      new FileWriter(fetchedUsersFile))
    fetchedUsers.foreach { user => writer.write(user + "\n") }
    writer.close()
  }

  def saveFetchQueue {
    val queueUsersFile = new File(fetchQueueFileName)
    val writer = new BufferedWriter(
      new FileWriter(queueUsersFile))
    fetchQueue.foreach { user => writer.write(user + "\n") }
    writer.close()
  }

...
}

现在我们保存了爬虫在关闭时的状态,我们可以为程序设置比仅仅在无聊时中断程序更好的终止条件。在生产中,例如,当我们在数据库中有足够的名字时,我们可能会停止爬虫。在这个例子中,我们将简单地让爬虫运行 30 秒然后关闭它。

让我们修改main方法:

// FetchNetwork.scala
import akka.actor._
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

object FetchNetwork extends App {

  // Get token if exists
  val token = sys.env.get("GHTOKEN")

  val system = ActorSystem("GithubFetcher")
  val manager = system.actorOf(FetcherManager.props(token, 2))

  manager ! FetcherManager.AddToQueue("odersky")

  system.scheduler.scheduleOnce(30.seconds) { system.shutdown }

}

30 秒后,我们只需调用system.shutdown,这将递归地停止所有演员。这将停止 fetcher 管理器,并调用postStop生命周期钩子。程序运行一次后,我在fetched-users.txt文件中有 2,164 个名字。再次运行它将用户数量增加到 3,728。

我们可以通过在代码运行时让 fetcher manager 定期转储数据结构来进一步提高容错性。由于写入磁盘(或数据库)存在一定的风险(如果数据库服务器宕机或磁盘空间不足怎么办?),将数据结构的写入委托给一个自定义演员而不是危及管理者会更好。

我们的爬虫有一个小问题:当 fetcher manager 停止时,它会停止 fetcher actors、响应解释器和 follower extractor。然而,目前通过这些演员的用户信息都没有被存储。这也导致代码末尾有少量未投递的消息:如果响应解释器在 fetcher 之前停止,fetcher 将尝试向一个不存在的演员投递。这仅涉及少数用户。为了恢复这些登录名,我们可以创建一个收割者 actor,其任务是协调正确顺序地杀死所有工作 actor 并收集它们的内部状态。这种模式已在 Derek Wyatt 的博客文章中记录(letitcrash.com/post/30165507578/shutdown-patterns-in-akka-2)。

我们还没有讨论的内容

Akka 是一个非常丰富的生态系统,远远超出了单章所能公正地描述的范围。有一些重要的工具包部分是你需要的,但我们在这里没有涵盖。我们将给出简要描述,但你可以在 Akka 文档中查找更多详细信息:

摘要

在本章中,你学习了如何将 actor 编织在一起以解决一个困难的并发问题。更重要的是,我们看到了 Akka 的 actor 框架如何鼓励我们以许多独立的封装可变数据块的形式来思考并发问题,这些数据块通过消息传递进行同步。Akka 使得并发编程更容易推理,并且更有趣。

参考文献

德里克·惠特尼的书籍《Akka 并发》是 Akka 的绝佳入门指南。对于想要进行严肃 Akka 编程的人来说,这绝对是一个必看的起点。

“让它崩溃”博客(letitcrash.com)是官方 Akka 博客,其中包含了许多用于解决常见问题的惯用语句和模式的示例。

第十章. 使用 Spark 进行分布式批处理

在第四章《并行集合与未来》中,我们发现了如何使用并行集合来解决“令人尴尬”的并行问题:这些问题可以被分解成一系列不需要(或非常少)任务间通信的任务。

Apache Spark 提供了类似于 Scala 并行集合(以及更多)的行为,但它不是在相同计算机的不同 CPU 上分发任务,而是允许任务在计算机集群中分发。这提供了任意水平的横向可伸缩性,因为我们只需简单地添加更多计算机到集群中。

在本章中,我们将学习 Apache Spark 的基础知识,并使用它来探索一组电子邮件,提取特征,以构建垃圾邮件过滤器为目标。我们将在第十二章《使用 MLlib 进行分布式机器学习》中探索实际构建垃圾邮件过滤器的几种方法。

安装 Spark

在前面的章节中,我们通过在build.sbt文件中指定依赖项,并依赖 SBT 从 Maven Central 仓库中获取它们来包含依赖项。对于 Apache Spark,显式下载源代码或预构建的二进制文件更为常见,因为 Spark 附带了许多命令行脚本,这些脚本极大地简化了作业的启动和与集群的交互。

访问spark.apache.org/downloads.html下载 Spark 1.5.2 版本,选择“为 Hadoop 2.6 或更高版本预构建”的包。如果您需要定制,也可以从源代码构建 Spark,但我们将坚持使用预构建版本,因为它不需要配置。

点击下载将下载一个 tar 包,您可以使用以下命令解包:

$ tar xzf spark-1.5.2-bin-hadoop2.6.tgz

这将创建一个spark-1.5.2-bin-hadoop2.6目录。要验证 Spark 是否正确工作,请导航到spark-1.5.2-bin-hadoop2.6/bin并使用./spark-shell启动 Spark shell。这只是一个加载了 Spark 库的 Scala shell。

您可能希望将bin/目录添加到系统路径中。这样,您就可以从系统中的任何位置调用该目录中的脚本,而无需引用完整路径。在 Linux 或 Mac OS 上,您可以通过在 shell 配置文件(Mac OS 上的.bash_profile,Linux 上的.bashrc.bash_profile)中输入以下行来将变量添加到系统路径:

export PATH=/path/to/spark/bin:$PATH

这些更改将在新的 shell 会话中生效。在 Windows(如果您使用 PowerShell)上,您需要在Documents文件夹中WindowsPowerShell目录下的profile.ps1文件中输入此行:

$env:Path += ";C:\Program Files\GnuWin32\bin"

如果操作正确,您应该能够在系统中的任何目录下通过在终端中输入spark-shell来打开 Spark shell。

获取示例数据

在本章中,我们将探索 Ling-Spam 电子邮件数据集(原始数据集的描述见csmining.org/index.php/ling-spam-datasets.html)。从data.scala4datascience.com/ling-spam.tar.gz(或ling-spam.zip,取决于您首选的压缩方式)下载数据集,并将内容解压缩到包含本章代码示例的目录中。该存档包含两个目录,spam/ham/,分别包含垃圾邮件和合法邮件。

弹性分布式数据集

Spark 将所有计算表达为对分布式集合的转换和操作的序列,称为弹性分布式数据集RDD)。让我们通过 Spark shell 来探索 RDD 是如何工作的。导航到示例目录,并按照以下方式打开 Spark shell:

$ spark-shell
scala> 

让我们从加载一个 RDD 中的电子邮件开始:

scala> val email = sc.textFile("ham/9-463msg1.txt")
email: rdd.RDD[String] = MapPartitionsRDD[1] at textFile

email是一个 RDD,每个元素对应输入文件中的一行。注意我们是如何通过在名为sc的对象上调用textFile方法来创建 RDD 的:

scala> sc
spark.SparkContext = org.apache.spark.SparkContext@459bf87c

sc是一个SparkContext实例,代表 Spark 集群(目前是本地机器)的入口点(现在只需我们的本地机器)。当我们启动 Spark shell 时,会创建一个上下文并将其自动绑定到变量sc

让我们使用flatMap将电子邮件拆分为单词:

scala> val words = email.flatMap { line => line.split("\\s") }
words: rdd.RDD[String] = MapPartitionsRDD[2] at flatMap

如果您熟悉 Scala 中的集合,这会感觉很自然:email RDD 的行为就像一个字符串列表。在这里,我们使用表示空白字符的正则表达式 \s 进行拆分。我们不仅可以使用 flatMap 显式地操作 RDD,还可以使用 Scala 的语法糖来操作 RDD:

scala> val words = for { 
 line <- email
 word <- line.split("\\s") 
} yield word
words: rdd.RDD[String] = MapPartitionsRDD[3] at flatMap

让我们检查结果。我们可以使用 .take(n) 来提取 RDD 的前 n 个元素:

scala> words.take(5)
Array[String] = Array(Subject:, tsd98, workshop, -, -)

我们也可以使用 .count 来获取 RDD 中的元素数量:

scala> words.count
Long = 939

RDD 支持集合支持的大多数操作。让我们使用 filter 从我们的电子邮件中删除标点符号。我们将删除包含任何非字母数字字符的所有单词。我们可以通过过滤掉匹配此 正则表达式 的元素来实现:[^a-zA-Z0-9]

scala> val nonAlphaNumericPattern = "[^a-zA-Z0-9]".r
nonAlphaNumericPattern: Regex = [^a-zA-Z0-9]

scala> val filteredWords = words.filter { 
 word => nonAlphaNumericPattern.findFirstIn(word) == None 
}
filteredWords: rdd.RDD[String] = MapPartitionsRDD[4] at filter

scala> filteredWords.take(5)
Array[String] = Array(tsd98, workshop, 2nd, call, paper)

scala> filteredWords.count
Long = 627

在这个例子中,我们从一个文本文件创建了一个 RDD。我们还可以使用 Spark 上下文上可用的 sc.parallelize 方法从 Scala 可迭代对象创建 RDD:

scala> val words = "the quick brown fox jumped over the dog".split(" ") 
words: Array[String] = Array(the, quick, brown, fox, ...)

scala> val wordsRDD = sc.parallelize(words)
wordsRDD: RDD[String] = ParallelCollectionRDD[1] at parallelize at <console>:23

这对于调试和在 shell 中试验行为很有用。与并行化相对应的是 .collect 方法,它将 RDD 转换为 Scala 数组:

scala> val wordLengths = wordsRDD.map { _.length }
wordLengths: RDD[Int] = MapPartitionsRDD[2] at map at <console>:25

scala> wordLengths.collect
Array[Int] = Array(3, 5, 5, 3, 6, 4, 3, 3)

.collect 方法需要整个 RDD 都能在主节点上适应内存。因此,它要么用于调试较小的数据集,要么用于管道的末尾,以缩减数据集。

如您所见,RDD 提供了一个类似于 Scala 可迭代对象的 API。关键的区别是 RDD 是 分布式容错的。让我们探讨这在实践中意味着什么。

RDD 是不可变的

一旦创建 RDD,就不能更改它。对 RDD 的所有操作要么创建新的 RDD,要么创建其他 Scala 对象。

RDD 是惰性的

当你在交互式 shell 中对 Scala 集合执行 map 和 filter 等操作时,REPL 会将新集合的值打印到屏幕上。这并不适用于 Spark RDD。这是因为 RDD 上的操作是惰性的:只有在需要时才会进行评估。

因此,当我们写:

val email = sc.textFile(...)
val words = email.flatMap { line => line.split("\\s") }

我们正在创建一个 RDD,words,它知道如何从其父 RDD,email,构建自己,而 email 又知道它需要读取一个文本文件并将其拆分为行。然而,直到我们通过调用一个返回 Scala 对象的 动作 来强制评估 RDD,实际上没有任何命令发生。如果我们尝试从一个不存在的文本文件中读取,这一点尤为明显:

scala> val inp = sc.textFile("nonexistent")
inp: rdd.RDD[String] = MapPartitionsRDD[5] at textFile

我们可以无障碍地创建 RDD。我们甚至可以在 RDD 上定义更多的转换。程序只有在这些转换最终评估时才会崩溃:

scala> inp.count // number of lines
org.apache.hadoop.mapred.InvalidInputException: Input path does not exist: file:/Users/pascal/...

.count 动作预期返回我们 RDD 中元素的数量作为一个整数。Spark 除了评估 inp 没有其他选择,这会导致异常。

因此,可能更合适的是将 RDD 视为一个操作管道,而不是一个更传统的集合。

RDD 知道它们的血缘关系

RDD 只能从稳定存储(例如,通过从 Spark 集群中每个节点都存在的文件加载数据)或通过基于其他 RDD 的转换集构建。由于 RDD 是惰性的,它们需要在需要时知道如何构建自身。它们通过知道自己的父 RDD 是谁以及需要应用于父 RDD 的操作来实现这一点。由于父 RDD 是不可变的,这是一个定义良好的过程。

toDebugString 方法提供了一个 RDD 构建过程的图示:

scala> filteredWords.toDebugString
(2) MapPartitionsRDD[6] at filter at <console>:27 []
 |  MapPartitionsRDD[3] at flatMap at <console>:23 []
 |  MapPartitionsRDD[1] at textFile at <console>:21 []
 |  ham/9-463msg1.txt HadoopRDD[0] at textFile at <console>:21 []

RDD 具有弹性

如果你在一个单机上运行应用程序,你通常不需要担心应用程序中的硬件故障:如果计算机失败,你的应用程序无论如何都是注定要失败的。

相比之下,分布式架构应该是容错的:单个机器的故障不应该导致整个应用程序崩溃。Spark RDD 是以容错性为设计理念的。让我们假设其中一个工作节点失败,导致与 RDD 相关的一些数据被破坏。由于 Spark RDD 知道如何从其父 RDD 构建自身,因此不会永久丢失数据:丢失的元素可以在需要时在另一台计算机上重新计算。

RDD 是分布式的

当你从一个文本文件等构建 RDD 时,Spark 会将 RDD 拆分为多个分区。每个分区将完全本地化在单个机器上(尽管通常每台机器有多个分区)。

RDD 上的许多转换可以在每个分区独立执行。例如,当执行.map操作时,输出 RDD 中的给定元素依赖于父 RDD 中的单个元素:数据不需要在分区之间移动。.flatMap.filter操作也是如此。这意味着由这些操作之一产生的 RDD 中的分区依赖于父 RDD 中的单个分区。

另一方面,.distinct转换,它从 RDD 中删除所有重复元素,需要将给定分区中的数据与每个其他分区中的数据进行比较。这需要在节点之间进行洗牌。洗牌,特别是对于大型数据集,是一个昂贵的操作,如果可能的话应该避免。

RDD 上的转换和操作

RDD 支持的操作集可以分为两类:

  • 转换从当前 RDD 创建一个新的 RDD。转换是惰性的:它们不会立即被评估。

  • 操作强制评估 RDD,通常返回一个 Scala 对象,而不是 RDD,或者有一些形式的副作用。操作会立即评估,触发构成此 RDD 的所有转换的执行。

在下面的表格中,我们给出了一些有用的转换和操作的示例。对于完整和最新的列表,请参阅 Spark 文档(spark.apache.org/docs/latest/programming-guide.html#rdd-operations)。

对于这些表格中的示例,我们假设你已经创建了一个包含以下内容的 RDD:

scala> val rdd = sc.parallelize(List("quick", "brown", "quick", "dog"))

以下表格列出了 RDD 上的常见转换。请记住,转换始终生成一个新的 RDD,并且它们是惰性操作:

转换 注意事项 示例(假设 rdd { "quick", "brown", "quick", "dog" }
rdd.map(func) rdd.map { _.size } // => { 5, 5, 5, 3 }
rdd.filter(pred) rdd.filter { _.length < 4 } // => { "dog" }
rdd.flatMap(func) rdd.flatMap { _.toCharArray } // => { 'q', 'u', 'i', 'c', 'k', 'b', 'r', 'o' … }
rdd.distinct() 从 RDD 中移除重复元素。 rdd.distinct // => { "dog", "brown", "quick" }
rdd.pipe(command, [envVars]) 通过外部程序进行管道传输。RDD 元素逐行写入进程的stdin。从stdout读取输出。 rdd.pipe("tr a-z A-Z") // => { "QUICK", "BROWN", "QUICK", "DOG" }

以下表格描述了 RDD 上的常见动作。请记住,动作始终生成 Scala 类型或引起副作用,而不是创建一个新的 RDD。动作强制评估 RDD,触发 RDD 下支撑的转换的执行。

动作 节点 示例(假设 rdd { "quick", "brown", "quick", "dog" }
rdd.first RDD 中的第一个元素。 rdd.first // => quick
rdd.collect 将 RDD 转换为数组(数组必须在主节点上能够适应内存)。 rdd.collect // => ArrayString
rdd.count RDD 中的元素数量。 rdd.count // => 4
rdd.countByValue 元素到该元素出现次数的映射。映射必须在主节点上适应。 rdd.countByValue // => Map(quick -> 2, brown -> 1, dog -> 1)
rdd.take(n) 返回 RDD 中前n个元素的数组。 rdd.take(2) // => Array(quick, brown)
rdd.takeOrdered(n:Int)(implicit ordering: Ordering[T]) 根据元素的默认排序或作为第二个参数传递的排序,按顺序获取 RDD 中的前n个元素。有关如何定义自定义比较函数的说明,请参阅 Scala 文档中的Orderingwww.scala-lang.org/api/current/index.html#scala.math.Ordering)。 rdd.takeOrdered(2) // => Array(brown, dog)``rdd.takeOrdered(2) (Ordering.by { _.size }) // => Array[String] = Array(dog, quick)
rdd.reduce(func) 根据指定的函数减少 RDD。使用 RDD 中的第一个元素作为基数。func应该是交换律和结合律的。 rdd.map { _.size }.reduce { _ + _ } // => 18
rdd.aggregate(zeroValue)(seqOp, combOp) 用于返回类型与 RDD 类型不同的值的减少情况。在这种情况下,我们需要提供一个用于单个分区内减少的函数(seqOp)和一个用于合并两个分区值的函数(combOp)。 rdd.aggregate(0) ( _ + _.size, _ + _ ) // => 18

持久化 RDDs

我们了解到 RDDs 只保留构建元素所需的操作序列,而不是元素本身的值。这当然大大减少了内存使用,因为我们不需要在内存中保留 RDDs 的中间版本。例如,假设我们想要遍历事务日志以识别特定账户上发生的所有交易:

val allTransactions = sc.textFile("transaction.log")
val interestingTransactions = allTransactions.filter { 
  _.contains("Account: 123456")
}

所有交易的集合将很大,而感兴趣账户上的交易集合将小得多。Spark 记住如何构建数据集,而不是数据集本身的政策意味着我们任何时候都不需要在内存中保留所有输入文件的行。

有两种情况我们可能希望避免每次使用 RDD 时重新计算其元素:

  • 对于交互式使用:我们可能已经检测到账户“123456”上的欺诈行为,并希望调查这种情况是如何发生的。我们可能需要在这个 RDD 上执行许多不同的探索性计算,而不必每次都重新读取整个日志文件。因此,持久化interestingTransactions是有意义的。

  • 当算法重新使用中间结果或数据集时。一个典型的例子是逻辑回归。在逻辑回归中,我们通常使用迭代算法来找到最小化损失函数的“最优”系数。在我们迭代算法的每一步中,我们必须从训练集中计算损失函数及其梯度。如果可能的话,我们应该避免重新计算训练集(或从输入文件中重新加载它)。

Spark 在 RDD 上提供了一个.persist方法来实现这一点。通过在 RDD 上调用.persist,我们告诉 Spark 在下次计算时将数据集保留在内存中。

scala> words.persist
rdd.RDD[String] = MapPartitionsRDD[3] at filter

Spark 支持不同的持久化级别,您可以通过传递参数给.persist来调整:

scala> import org.apache.spark.storage.StorageLevel
import org.apache.spark.storage.StorageLevel

scala> interestingTransactions.persist(
 StorageLevel.MEMORY_AND_DISK)
rdd.RDD[String] = MapPartitionsRDD[3] at filter

Spark 提供了几个持久化级别,包括:

  • MEMORY_ONLY: 默认存储级别。RDD 存储在 RAM 中。如果 RDD 太大而无法适应内存,则其部分将不会持久化,并且需要即时重新计算。

  • MEMORY_AND_DISK: 尽可能多地存储 RDD 在内存中。如果 RDD 太大,它将溢出到磁盘。这只有在 RDD 计算成本很高的情况下才有意义。否则,重新计算它可能比从磁盘读取更快。

如果你持续持久化多个 RDD 并且内存不足,Spark 将会清除最近最少使用的 RDD(根据选择的持久化级别,可能是丢弃它们或将它们保存到磁盘)。RDD 还提供了一个unpersist方法,可以显式地告诉 Spark 一个 RDD 不再需要。

持久化 RDD 可能会对性能产生重大影响。因此,在调整 Spark 应用程序时,什么和如何持久化变得非常重要。找到最佳持久化级别通常需要一些调整、基准测试和实验。Spark 文档提供了何时使用哪种持久化级别的指南(spark.apache.org/docs/latest/programming-guide.html#rdd-persistence),以及调整内存使用的通用技巧(spark.apache.org/docs/latest/tuning.html)。

重要的是,persist方法不会强制 RDD 进行评估。它只是通知 Spark 引擎,下次计算此 RDD 中的值时,应该保存而不是丢弃。

键值 RDD

到目前为止,我们只考虑了 Scala 值类型的 RDD。支持更复杂数据类型的 RDD 支持额外的操作。Spark 为键值 RDD添加了许多操作:类型参数为元组(K, V)的 RDD,对于任何类型KV

让我们回到我们的示例电子邮件:

scala> val email = sc.textFile("ham/9-463msg1.txt")
email: rdd.RDD[String] = MapPartitionsRDD[1] at textFile

scala> val words = email.flatMap { line => line.split("\\s") }
words: rdd.RDD[String] = MapPartitionsRDD[2] at flatMap

让我们将words RDD 持久化到内存中,以避免反复从磁盘重新读取email文件:

scala> words.persist

要访问键值操作,我们只需要对我们的 RDD 应用一个转换,创建键值对。现在,我们将使用单词作为键。对于每个值,我们将使用 1:

scala> val wordsKeyValue = words.map { _ -> 1 }
wordsKeyValue: rdd.RDD[(String, Int)] = MapPartitionsRDD[32] at map 

scala> wordsKeyValue.first
(String, Int) = (Subject:,1)

键值 RDD 除了核心 RDD 操作外,还支持几个操作。这些操作通过隐式转换添加,使用我们在第五章中探讨的“pimp my library”模式,即通过 JDBC 使用 Scala 和 SQL。这些额外的转换分为两大类:按键转换和 RDD 之间的连接

按键转换是聚合相同键对应值的操作。例如,我们可以使用reduceByKey来计算每个单词在电子邮件中出现的次数。此方法接受属于同一键的所有值,并使用用户提供的函数将它们组合起来:

scala> val wordCounts = wordsKeyValue.reduceByKey { _ + _ }
wordCounts: rdd.RDD[(String, Int)] = ShuffledRDD[35] at reduceByKey

scala> wordCounts.take(5).foreach { println }
(university,6)
(under,1)
(call,3)
(paper,2)
(chasm,2)

注意,reduceByKey通常需要将 RDD 进行洗牌,因为给定键的每个出现可能不在同一个分区:

scala> wordCounts.toDebugString
(2) ShuffledRDD[36] at reduceByKey at <console>:30 []
 +-(2) MapPartitionsRDD[32] at map at <console>:28 []
 |  MapPartitionsRDD[7] at flatMap at <console>:23 []
 |      CachedPartitions: 2; MemorySize: 50.3 KB; ExternalBlockStoreSize: 0.0 B; DiskSize: 0.0 B
 |  MapPartitionsRDD[3] at textFile at <console>:21 []
 |      CachedPartitions: 2; MemorySize: 5.1 KB; ExternalBlockStoreSize: 0.0 B; DiskSize: 0.0 B
 |  ham/9-463msg1.txt HadoopRDD[2] at textFile at <console>:21 []

注意,键值 RDD 与 Scala Maps 不同:相同的键可以出现多次,并且它们不支持O(1)查找。可以使用.collectAsMap操作将键值 RDD 转换成 Scala Map:

scala> wordCounts.collectAsMap
scala.collection.Map[String,Int] = Map(follow -> 2, famous -> 1...

这需要将整个 RDD 拉到主 Spark 节点上。因此,您需要在主节点上拥有足够的内存来存放映射。这通常是管道中的最后一个阶段,用于过滤大型 RDD,只保留我们所需的信息。

下面表格中描述了许多按键操作。对于表格中的示例,我们假设 rdd 是以下方式创建的:

scala> val words = sc.parallelize(List("quick", "brown","quick", "dog"))
words: RDD[String] = ParallelCollectionRDD[25] at parallelize at <console>:21

scala> val rdd = words.map { word => (word -> word.size) }
rdd: RDD[(String, Int)] = MapPartitionsRDD[26] at map at <console>:23

scala> rdd.collect
Array[(String, Int)] = Array((quick,5), (brown,5), (quick,5), (dog,3))

转换 备注 示例(假设 rdd { quick -> 5, brown -> 5, quick -> 5, dog -> 3 }
rdd.mapValues 对值应用一个操作。 rdd.mapValues { _ * 2 } // => { quick -> 10, brown -> 10, quick -> 10, dog ->6 }
rdd.groupByKey 返回一个键值 RDD,其中对应相同键的值被分组到可迭代对象中。 rdd.groupByKey // => { quick -> Iterable(5, 5), brown -> Iterable(5), dog -> Iterable(3) }
rdd.reduceByKey(func) 返回一个键值 RDD,其中对应相同键的值使用用户提供的函数进行组合。 rdd.reduceByKey { _ + _ } // => { quick -> 10, brown -> 5, dog -> 3 }
rdd.keys 返回一个键的 RDD。 rdd.keys // => { quick, brown, quick, dog }
rdd.values 返回一个值的 RDD。 rdd.values // => { 5, 5, 5, 3 }

键值 RDD 上的第二种操作类型涉及通过键将不同的 RDD 连接在一起。这与 SQL 连接有些相似,其中键是要连接的列。让我们加载一封垃圾邮件,并应用我们对我们的正常邮件应用相同的转换:

scala> val spamEmail = sc.textFile("spam/spmsgb17.txt")
spamEmail: org.apache.spark.rdd.RDD[String] = MapPartitionsRDD[52] at textFile at <console>:24

scala> val spamWords = spamEmail.flatMap { _.split("\\s") }
spamWords: org.apache.spark.rdd.RDD[String] = MapPartitionsRDD[53] at flatMap at <console>:26

scala> val spamWordCounts = spamWords.map { _ -> 1 }.reduceByKey { _ + _ }
spamWordsCount: org.apache.spark.rdd.RDD[(String, Int)] = ShuffledRDD[55] at reduceByKey at <console>:30

scala> spamWordCounts.take(5).foreach { println }
(banner,3)
(package,14)
(call,1)
(country,2)
(offer,1)

spamWordCountswordCounts 都是键值 RDD,其中键对应于消息中的唯一单词,值是该单词出现的次数。由于电子邮件将共享许多相同的单词,因此 spamWordCountswordCounts 之间的键将存在一些重叠。让我们在这两个 RDD 之间进行一个 内部连接,以获取两个电子邮件中都出现的单词:

scala> val commonWordCounts = wordCounts.join(spamWordCounts)
res93: rdd.RDD[(String, (Int, Int))] = MapPartitionsRDD[58] at join at <console>:41

scala> commonWordCounts.take(5).foreach { println }
(call,(3,1))
(include,(6,2))
(minute,(2,1))
(form,(1,7))
((,(36,5))

内部连接产生的 RDD 中的值将是成对出现的。成对中的第一个元素是第一个 RDD 中该键的值,第二个元素是第二个 RDD 中该键的值。因此,单词 call 在合法电子邮件中出现了三次,在垃圾邮件中出现了两次。

Spark 支持所有四种连接类型。例如,让我们执行一个左连接:

scala> val leftWordCounts = wordCounts.leftOuterJoin(spamWordCounts)
leftWordCounts: rdd.RDD[(String, (Int, Option[Int]))] = MapPartitionsRDD[64] at leftOuterJoin at <console>:40

scala> leftWordCounts.take(5).foreach { println }
(call,(3,Some(1)))
(paper,(2,None))
(chasm,(2,None))
(antonio,(1,None))
(event,(3,None))

注意,我们成对中的第二个元素具有 Option[Int] 类型,以适应 spamWordCounts 中缺失的键。例如,单词 paper 在合法电子邮件中出现了两次,在垃圾邮件中从未出现。在这种情况下,用零表示缺失比用 None 更有用。使用 getOrElse 替换 None 为默认值很简单:

scala> val defaultWordCounts = leftWordCounts.mapValues { 
 case(leftValue, rightValue) => (leftValue, rightValue.getOrElse(0))
}
org.apache.spark.rdd.RDD[(String, (Int, Option[Int]))] = MapPartitionsRDD[64] at leftOuterJoin at <console>:40

scala> defaultwordCounts.take(5).foreach { println }
(call,(3,1))
(paper,(2,0))
(chasm,(2,0))
(antonio,(1,0))
(event,(3,0))

下表列出了键值 RDD 上最常见的连接:

转换 结果(假设 rdd1{ quick -> 1, brown -> 2, quick -> 3, dog -> 4 } rdd2 { quick -> 78, brown -> 79, fox -> 80 }
rdd1.join(rdd2) { quick -> (1, 78), quick -> (3, 78), brown -> (2, 79) }
rdd1.leftOuterJoin(rdd2) { dog -> (4, None), quick -> (1, Some(78)), quick -> (3, Some(78)), brown -> (2, Some(79)) }
rdd1.rightOuterJoin(rdd2) { quick -> (Some(1), 78), quick -> (Some(3), 78), brown -> (Some(2), 79), fox -> (None, 80) }
rdd1.fullOuterJoin(rdd2) { dog -> (Some(4), None), quick -> (Some(1), Some(78)), quick -> (Some(3), Some(78)), brown -> (Some(2), Some(79)), fox -> (None, Some(80)) }

要获取完整的转换列表,请查阅 PairRDDFunctions 的 API 文档,spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions

双精度浮点 RDD

在前面的部分中,我们看到了 Spark 通过隐式转换向键值 RDD 添加了功能。同样,Spark 向 doubles 的 RDD 添加了统计功能。让我们提取火腿消息的单词频率,并将值从整数转换为双精度浮点数:

scala> val counts = wordCounts.values.map { _.toDouble }
counts: rdd.RDD[Double] = MapPartitionsRDD[9] at map

然后,我们可以使用 .stats 动作来获取摘要统计信息:

scala> counts.stats
org.apache.spark.util.StatCounter = (count: 397, mean: 2.365239, stdev: 5.740843, max: 72.000000, min: 1.000000)

因此,最常见的单词出现了 72 次。我们还可以使用 .histogram 动作来了解值的分布:

scala> counts.histogram(5)
(Array(1.0, 15.2, 29.4, 43.6, 57.8, 72.0),Array(391, 1, 3, 1, 1))

.histogram 方法返回一个数组的对。第一个数组表示直方图桶的界限,第二个数组表示该桶中元素的数量。因此,有 391 个单词出现次数少于 15.2 次。单词的分布非常偏斜,以至于使用常规大小的桶并不合适。我们可以通过传递一个桶边界的数组到 histogram 方法来传递自定义的桶。例如,我们可能以对数方式分布桶:

scala> counts.histogram(Array(1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0, 128.0))
res13: Array[Long] = Array(264, 94, 22, 11, 1, 4, 1)

构建和运行独立程序

到目前为止,我们主要通过 Spark shell 与 Spark 进行交互。在接下来的部分中,我们将构建一个独立的应用程序,并在本地或 EC2 集群上启动 Spark 程序。

本地运行 Spark 应用程序

第一步是编写 build.sbt 文件,就像运行标准的 Scala 脚本一样。我们下载的 Spark 二进制文件需要针对 Scala 2.10 运行(你需要从源代码编译 Spark 以运行针对 Scala 2.11 的版本。这并不困难,只需遵循 spark.apache.org/docs/latest/building-spark.html#building-for-scala-211 上的说明即可)。

// build.sbt file

name := "spam_mi"

scalaVersion := "2.10.5"

libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-core" % "1.4.1"
)

然后,我们运行 sbt package 来编译和构建程序的 jar 包。jar 包将在 target/scala-2.10/ 目录下构建,并命名为 spam_mi_2.10-0.1-SNAPSHOT.jar。你可以使用本章提供的示例代码尝试这个操作。

然后,我们可以使用位于 Spark 安装目录 bin/ 文件夹中的 spark-submit 脚本在本地上运行 jar 包:

$ spark-submit target/scala-2.10/spam_mi_2.10-0.1-SNAPSHOT.jar
... runs the program

可以通过传递参数到 spark-submit 来控制分配给 Spark 的资源。使用 spark-submit --help 来查看完整的参数列表。

如果 Spark 程序有依赖项(例如,对其他 Maven 包的依赖),最简单的方法是使用 SBT 打包 插件将它们捆绑到应用程序 jar 中。让我们假设我们的应用程序依赖于 breeze-viz。现在的 build.sbt 文件如下所示:

// build.sbt

name := "spam_mi"

scalaVersion := "2.10.5"

libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-core" % "1.5.2" % "provided",
  "org.scalanlp" %% "breeze" % "0.11.2",
  "org.scalanlp" %% "breeze-viz" % "0.11.2",
  "org.scalanlp" %% "breeze-natives" % "0.11.2"
)

SBT 打包是一个 SBT 插件,它构建 jar:包含程序本身以及程序的所有依赖项的 jar。

注意,我们在依赖列表中将 Spark 标记为“提供”,这意味着 Spark 本身将不会包含在 jar 文件中(无论如何它都是由 Spark 环境提供的)。要包含 SBT 打包插件,请在 project/ 目录下创建一个名为 assembly.sbt 的文件,并包含以下行:

addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.0")

您需要重新启动 SBT 以使更改生效。然后您可以使用 SBT 中的 assembly 命令创建打包 jar。这将在 target/scala-2.10 目录中创建一个名为 spam_mi-assembly-0.1-SNAPSHOT.jar 的 jar。您可以使用 spark-submit 运行此 jar。

减少日志输出和 Spark 配置

Spark 默认情况下非常详细。默认日志级别设置为 INFO。为了避免错过重要信息,将日志设置更改为 WARN 是有用的。要全局更改默认日志级别,请进入您安装 Spark 的目录中的 conf 目录。您应该会找到一个名为 log4j.properties.template 的文件。将此文件重命名为 log4j.properties 并查找以下行:

log4j.rootCategory=INFO, console

将此行更改为:

log4j.rootCategory=WARN, console

在该目录中还有其他几个配置文件,您可以使用它们来更改 Spark 的默认行为。有关配置选项的完整列表,请访问 spark.apache.org/docs/latest/configuration.html

在 EC2 上运行 Spark 应用程序

在本地运行 Spark 对于测试很有用,但使用分布式框架的全部意义在于运行能够利用多台不同计算机能力的程序。我们可以在任何能够通过 HTTP 互相通信的计算机集上设置 Spark。通常,我们还需要设置一个分布式文件系统,如 HDFS,这样我们就可以在集群间共享输入文件。为了本例的目的,我们将在一个 Amazon EC2 集群上设置 Spark。

Spark 附带一个名为 ec2/spark-ec2 的 shell 脚本,用于设置 EC2 集群并安装 Spark。它还会安装 HDFS。您需要 Amazon Web Services (AWS) 的账户才能遵循这些示例 (aws.amazon.com)。您需要 AWS 访问密钥和秘密密钥,您可以通过 AWS 网络控制台的 账户 / 安全凭证 / 访问凭证 菜单访问它们。您需要通过环境变量将这些密钥提供给 spark-ec2 脚本。如下注入到您的当前会话中:

$ export AWS_ACCESS_KEY_ID=ABCDEF...
$ export AWS_SECRET_ACCESS_KEY=2dEf...

您也可以将这些行写入您的 shell 配置脚本(例如.bashrc文件或等效文件),以避免每次运行setup-ec2脚本时都需要重新输入。我们已在第六章中讨论了环境变量,Slick – SQL 的函数式接口

您还需要通过在 EC2 网页控制台中点击密钥对来创建一个密钥对,创建一个新的密钥对并下载证书文件。我将假设您将密钥对命名为test_ec2,证书文件为test_ec2.pem。请确保密钥对是在N. Virginia 区域创建的(通过在 EC2 管理控制台右上角选择正确的区域),以避免在本章的其余部分中显式指定区域。您需要将证书文件的访问权限设置为仅用户可读:

$ chmod 400 test_ec2.pem

我们现在可以启动集群了。导航到ec2目录并运行:

$ ./spark-ec2 -k test_ec2 -i ~/path/to/certificate/test_ec2.pem -s 2 launch test_cluster

这将创建一个名为test_cluster的集群,包含一个主节点和两个从节点。从节点的数量通过-s命令行参数设置。集群启动需要一段时间,但您可以通过查看 EC2 管理控制台中的实例窗口来验证实例是否正在正确启动。

设置脚本支持许多选项,用于自定义实例类型、硬盘数量等。您可以通过将--help命令行选项传递给spark-ec2来探索这些选项。

通过向spark-ec2脚本传递不同的命令,可以控制集群的生命周期,例如:

# shut down 'test_cluster'
$ ./spark-ec2 stop test_cluster

# start 'test_cluster'
$ ./spark-ec2 -i test_ec2.pem start test_cluster

# destroy 'test_cluster'
$ ./spark-ec2 destroy test_cluster

关于在 EC2 上使用 Spark 的更多详细信息,请参阅官方文档:spark.apache.org/docs/latest/ec2-scripts.html#running-applications

垃圾邮件过滤

让我们将所学的一切用于实际,并为我们的垃圾邮件过滤器进行一些数据探索。我们将使用 Ling-Spam 电子邮件数据集:csmining.org/index.php/ling-spam-datasets.html。该数据集包含 2412 封正常邮件和 481 封垃圾邮件,所有邮件都是通过语言学邮件列表收到的。我们将提取最能说明邮件是垃圾邮件还是正常邮件的单词。

任何自然语言处理工作流程的第一步是去除停用词和词形还原。去除停用词包括过滤掉非常常见的词,如thethis等等。词形还原涉及将同一词的不同形式替换为规范形式:colorscolor都会映射到color,而organizeorganizingorganizes都会映射到organize。去除停用词和词形还原非常具有挑战性,超出了本书的范围(如果你需要去除停用词并对数据集进行词形还原,你应该使用斯坦福 NLP 工具包:nlp.stanford.edu/software/corenlp.shtml)。幸运的是,Ling-Spam 电子邮件数据集已经被清理和词形还原了(这就是为什么电子邮件中的文本看起来很奇怪)。

当我们构建垃圾邮件过滤器时,我们将使用电子邮件中特定词的存在作为我们模型的特征。我们将使用词袋模型方法:我们考虑电子邮件中哪些词出现,但不考虑词序。

直观上,在决定一封电子邮件是否为垃圾邮件时,一些词会比其他词更重要。例如,包含language的电子邮件很可能是 ham,因为邮件列表是用于语言学讨论的,而language是一个不太可能被垃圾邮件发送者使用的词。相反,那些两种消息类型都常见的词,例如hello,不太可能有很大的作用。

量化一个词在确定一条消息是否为垃圾邮件中的重要性的一种方法是通过互信息MI)。互信息是在我们知道一条消息包含特定词的情况下,关于该消息是 ham 还是 spam 的信息增益。例如,特定电子邮件中存在language这一事实对于判断该电子邮件是垃圾邮件还是 ham 非常有信息量。同样,dollar这个词的存在也是有信息量的,因为它经常出现在垃圾邮件中,而很少出现在 ham 邮件中。相比之下,morning这个词的存在是无信息量的,因为它在垃圾邮件和 ham 邮件中出现的频率大致相同。电子邮件中特定词的存在与该电子邮件是垃圾邮件还是 ham 之间的互信息公式是:

垃圾邮件过滤

其中垃圾邮件过滤是电子邮件包含特定词和属于该类别(ham 或 spam)的联合概率,垃圾邮件过滤是特定词出现在电子邮件中的概率,垃圾邮件过滤是任何电子邮件属于该类别的概率。互信息通常用于决策树。

注意

互信息表达式的推导超出了本书的范围。感兴趣的读者可以参考David MacKay的杰出著作信息论、推理和学习算法,特别是依赖随机变量这一章。

我们计算互信息的关键组成部分是评估一个单词出现在垃圾邮件或垃圾邮件中的概率。给定我们的数据集,这个概率的最佳近似值是该单词出现的消息比例。因此,例如,如果 language 出现在 40% 的消息中,我们将假设该语言出现在任何消息中的概率 垃圾邮件过滤 为 0.4。同样,如果 40% 的消息是垃圾邮件,而 language 出现在这些垃圾邮件的 50% 中,我们将假设该语言出现在电子邮件中的概率,以及该电子邮件是垃圾邮件的概率 垃圾邮件过滤

让我们编写一个 wordFractionInFiles 函数来计算给定语料库中每个单词出现的消息比例。我们的函数将接受一个参数,即一个路径,该路径使用 shell 通配符标识一组文件,例如 ham/*,并且它将返回一个键值 RDD,其中键是单词,值是该单词出现在这些文件中的概率。我们将该函数放入一个名为 MutualInformation 的对象中。

我们首先给出该函数的整个代码列表。如果一开始看不懂没关系:我们将在代码之后详细解释这些难点。你可能发现将这些命令在 shell 中输入是有用的,例如将 fileGlob 替换为 "ham/*"

// MutualInformation.scala
import org.apache.spark.{ SparkConf, SparkContext }
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD

object MutualInformation extends App {

  def wordFractionInFiles(sc:SparkContext)(fileGlob:String)
  :(RDD[(String, Double)], Long) = {

    // A set of punctuation words that need to be filtered out.
    val wordsToOmit = SetString", 
      "(", "@", "/", "Subject:"
    )

    val messages = sc.wholeTextFiles(fileGlob)
    // wholeTextFiles generates a key-value RDD of 
    // file name -> file content

    val nMessages = messages.count()

    // Split the content of each message into a Set of unique
    // words in that message, and generate a new RDD mapping:
    // message -> word
    val message2Word = messages.flatMapValues {
      mailBody => mailBody.split("\\s").toSet
    }

    val message2FilteredWords = message2Word.filter { 
      case(email, word) => ! wordsToOmit(word) 
    }

    val word2Message = message2FilteredWords.map { _.swap }

    // word -> number of messages it appears in.
    val word2NumberMessages = word2Message.mapValues { 
      _ => 1 
    }.reduceByKey { _ + _ }

    // word -> fraction of messages it appears in
    val pPresent = word2NumberMessages.mapValues { 
      _ / nMessages.toDouble 
    }

    (pPresent, nMessages)
  }
}

让我们在 Spark shell 中玩这个函数。为了能够从 shell 中访问此函数,我们需要创建一个包含 MutualInformation 对象的 jar。编写一个类似于上一节中展示的 build.sbt 文件,并使用 sbt package 将代码打包成 jar。然后,使用以下命令打开 Spark shell:

$ spark-shell --jars=target/scala-2.10/spam_mi_2.10-0.1-SNAPSHOT.jar

这将在类路径上打开一个带有我们新创建的 jar 的 Spark shell。让我们在 ham 邮件上运行我们的 wordFractionInFiles 方法:

scala> import MutualInformation._
import MutualInformation._

scala> val (fractions, nMessages) = wordFractionInFiles(sc)("ham/*")
fractions: org.apache.spark.rdd.RDD[(String, Double)] = MapPartitionsRDD[13] at mapValues
nMessages: Long = 2412

让我们获取 fractions RDD 的快照:

scala> fractions.take(5)
Array[(String, Double)] = Array((rule-base,0.002902155887230514), (reunion,4.1459369817578774E-4), (embarrasingly,4.1459369817578774E-4), (mller,8.291873963515755E-4), (sapore,4.1459369817578774E-4))

很想看到在垃圾邮件中出现频率最高的单词。我们可以使用 .takeOrdered 动作来获取 RDD 的顶部值,并使用自定义排序。.takeOrdered 的第二个参数期望是一个类型类 Ordering[T] 的实例,其中 T 是我们 RDD 的类型参数:在这种情况下是 (String, Double)Ordering[T] 是一个具有单个 compare(a:T, b:T) 方法的特质,它描述了如何比较 ab。创建 Ordering[T] 的最简单方法是通过伴随对象的 by 方法,该方法定义了一个用于比较 RDD 元素的关键字。

我们希望按值对我们的键值 RDD 进行排序,并且由于我们想要最常见的单词,而不是最不常见的,我们需要反转这种排序:

scala> fractions.takeOrdered(5)(Ordering.by { - _._2 })
res0: Array[(String, Double)] = Array((language,0.6737147595356551), (university,0.6048922056384743), (linguistic,0.5149253731343284), (information,0.45480928689883915), ('s,0.4369817578772803))

毫不奇怪,language 出现在 67% 的垃圾邮件中,university 出现在 60% 的垃圾邮件中,等等。对垃圾邮件的类似调查显示感叹号字符 ! 出现在 83% 的垃圾邮件中,our 出现在 61%,free 出现在 57%。

我们现在可以开始编写应用程序的主体,以计算每个单词与消息是否为垃圾邮件或正常邮件之间的互信息。我们将代码的主体放入MutualInformation对象中,该对象已经包含了wordFractionInFiles方法。

第一步是创建一个 Spark 上下文:

// MutualInformation.scala
import org.apache.spark.{ SparkConf, SparkContext }
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD

object MutualInformation extends App {

  def wordFractionInFiles(sc:SparkContext)(fileGlob:String)
  :(RDD[(String, Double)], Long) = {
    ...
  }

  val conf = new SparkConf().setAppName("lingSpam")
  val sc = new SparkContext(conf)

注意,当我们使用 Spark shell 时,我们不需要这样做,因为 shell 自带一个预构建的上下文绑定到变量sc

我们现在可以计算在给定消息是垃圾邮件的情况下,包含特定单词的条件概率垃圾邮件过滤。这仅仅是包含该单词的垃圾邮件语料库中的消息比例。这反过来又让我们可以推断包含特定单词且为垃圾邮件的联合概率垃圾邮件过滤。我们将对所有四个类别的组合进行此操作:任何给定单词是否存在于消息中,以及该消息是否为垃圾邮件或正常邮件:

    /* Conditional probabilities RDD:
       word -> P(present | spam) 
    */
    val (pPresentGivenSpam, nSpam) = wordFractionInFiles(sc)("spam/*")
    val pAbsentGivenSpam = pPresentGivenSpam.mapValues { 1.0 - _ }
    val (pPresentGivenHam, nHam) = wordFractionInFiles(sc)("ham/*")
    val pAbsentGivenHam = pPresentGivenHam.mapValues { 1.0 - _ }

    // pSpam is the fraction of spam messages
    val nMessages = nSpam + nHam
    val pSpam = nSpam / nMessages.toDouble

    // pHam is the fraction of ham messages
    val pHam = 1.0 - pSpam 

    /* pPresentAndSpam is a key-value RDD of joint probabilities
       word -> P(word present, spam) 
    */
    val pPresentAndSpam = pPresentGivenSpam.mapValues { 
      _ * pSpam 
    }
    val pPresentAndHam = pPresentGivenHam.mapValues { _ * pHam }
    val pAbsentAndSpam = pAbsentGivenSpam.mapValues { _ * pSpam }
    val pAbsentAndHam = pAbsentGivenHam.mapValues { _ * pHam }

我们将在计算中的几个地方重用这些 RDD,所以让我们告诉 Spark 将它们保存在内存中,以避免需要重新计算:

    pPresentAndSpam.persist
    pPresentAndHam.persist
    pAbsentAndSpam.persist
    pAbsentAndHam.persist

我们现在需要计算单词存在的概率垃圾邮件过滤。这仅仅是pPresentAndSpampPresentAndHam的和,对于每个单词。棘手的部分是并非所有单词都存在于正常邮件和垃圾邮件中。因此,我们必须对这些 RDD 执行全外连接。这将给出一个 RDD,将每个单词映射到一个Option[Double]值的对。对于在正常邮件或垃圾邮件中不存在的单词,我们必须使用默认值。一个合理的默认值是垃圾邮件过滤对于垃圾邮件(更严格的方法是使用加性平滑)。这意味着如果语料库是两倍大,该单词将出现一次。

    val pJoined = pPresentAndSpam.fullOuterJoin(pPresentAndHam)
    val pJoinedDefault = pJoined.mapValues {
      case (presentAndSpam, presentAndHam) => 
        (presentAndSpam.getOrElse(0.5/nSpam * pSpam), 
        presentAndHam.getOrElse(0.5/nHam * pHam))
    }

注意,我们也可以选择 0 作为默认值。这会使信息增益的计算变得有些复杂,因为我们不能对零值取对数,并且似乎不太可能一个特定的单词在电子邮件中出现的概率恰好为零。

我们现在可以构建一个 RDD,将单词映射到垃圾邮件过滤,即单词存在于垃圾邮件或正常邮件中的概率:

    val pPresent = pJoinedDefault.mapValues { 
      case(presentAndHam, presentAndSpam) => 
        presentAndHam + presentAndSpam 
    }
    pPresent.persist

    val pAbsent = pPresent.mapValues { 1.0 - _ }
    pAbsent.persist

我们现在拥有所有需要的 RDD 来计算单词在消息中存在与否与消息是否为正常邮件或垃圾邮件之间的互信息。我们需要使用前面概述的互信息方程将它们全部结合起来。

我们将首先定义一个辅助方法,该方法给定联合概率 RDD P(X, Y) 和边缘概率 P(X)P(Y),计算垃圾邮件过滤。在这里,P(X) 例如可以是单词存在于消息中的概率垃圾邮件过滤,而 P(Y) 将会是该消息是垃圾邮件的概率垃圾邮件过滤

    def miTerm(
      pXYs:RDD[(String, Double)], 
      pXs:RDD[(String, Double)], 
      pY: Double,
      default: Double // for words absent in PXY
    ):RDD[(String, Double)] = 
      pXs.leftOuterJoin(pXYs).mapValues { 
        case (pX, Some(pXY)) => pXY * math.log(pXY/(pX*pY)) 
        case (pX, None) => default * math.log(default/(pX*pY))
    }

我们可以使用我们的函数来计算互信息总和中的四个项:

    val miTerms = List(
      miTerm(pPresentAndSpam, pPresent, pSpam, 0.5/nSpam * pSpam),
      miTerm(pPresentAndHam, pPresent, pHam, 0.5/nHam * pHam),
      miTerm(pAbsentAndSpam, pAbsent, pSpam, 0.5/nSpam * pSpam),
      miTerm(pAbsentAndHam, pAbsent, pHam, 0.5/nHam * pHam)
    )

最后,我们只需要将这些四个项相加:

    val mutualInformation = miTerms.reduce { 
      (term1, term2) => term1.join(term2).mapValues { 
         case (l, r) => l + r 
      } 
    }

RDD mutualInformation 是一个键值 RDD,将每个词映射到衡量该词在区分邮件是否为垃圾邮件或正常邮件时的信息量的度量。让我们打印出最能够表明邮件是否为正常邮件或垃圾邮件的二十个词:

    mutualInformation.takeOrdered(20)(Ordering.by { - _._2 })
      .foreach { println }

让我们使用 spark-submit 来运行这个示例:

$ sbt package
$ spark-submit target/scala-2.10/spam_mi_2.10-0.1-SNAPSHOT.jar
(!,0.1479941771292119)
(language,0.14574624861510874)
(remove,0.11380645864246142)
(free,0.1073496947123657)
(university,0.10695975885487692)
(money,0.07531772498093084)
(click,0.06887598051593441)
(our,0.058950906866052394)
(today,0.05485248095680509)
(sell,0.05385519653184113)
(english,0.053509319455430575)
(business,0.05299311289740539)
(market,0.05248394151802276)
(product,0.05096229706182162)
(million,0.050233193237964546)
(linguistics,0.04990172586630499)
(internet,0.04974101556655623)
(company,0.04941817269989519)
(%,0.04890193809823071)
(save,0.04861393414892205)

因此,我们发现像 languagefree! 这样的词包含最多的信息,因为它们几乎只存在于垃圾邮件或正常邮件中。一个非常简单的分类算法只需取前 10 个(按互信息排序)垃圾邮件词和前 10 个正常邮件词,然后查看一条消息是否包含更多垃圾邮件词或正常邮件词。我们将在第十二章分布式机器学习与 MLlib 中更深入地探讨用于分类的机器学习算法。

揭开盖子

在本章的最后部分,我们将非常简要地讨论 Spark 的内部工作原理。对于更详细的讨论,请参阅本章末尾的 参考文献 部分。

当你打开一个 Spark 上下文,无论是显式地还是通过启动 Spark shell,Spark 会启动一个包含当前任务和过去任务执行详情的 Web UI。让我们看看我们在上一节中编写的示例互信息程序的实际操作。为了防止程序完成后上下文关闭,你可以在 main 方法的最后(在调用 takeOrdered 之后)插入一个对 readLine 的调用。这会期望用户输入,因此程序执行将暂停,直到你按下 enter 键。

要访问 UI,将你的浏览器指向 127.0.0.1:4040。如果你有其他正在运行的 Spark shell 实例,端口可能是 40414042 等等。

揭开盖子

UI 的第一页告诉我们,我们的应用程序包含三个 作业。作业是动作的结果。实际上,我们的应用程序中确实有三个动作:前两个是在 wordFractionInFiles 函数中调用的:

val nMessages = messages.count()

最后一个作业是由对 takeOrdered 的调用产生的,它强制执行计算互信息的整个 RDD 转换管道。

Web UI 允许我们深入了解每个作业。点击作业表中的 takeOrdered 作业。你将被带到一页,其中更详细地描述了该作业:

揭开盖子

特别值得注意的是 DAG 可视化 项。这是一个执行计划的图,用于满足动作,并提供了对 Spark 内部工作原理的洞察。

当你通过在 RDD 上调用一个动作来定义一个作业时,Spark 会查看 RDD 的 lineage 并构建一个映射依赖关系的图: lineage 中的每个 RDD 都由一个节点表示,从该 RDD 的父节点到自身的有向边。这种图称为有向无环图(DAG),是一种用于依赖关系解析的有用数据结构。让我们使用 web UI 来探索我们程序中takeOrdered作业的 DAG。这个图相当复杂,因此很容易迷路,所以这里有一个简化的复制品,它只列出了程序中绑定到变量名的 RDD。

揭开盖子

如您所见,在图的底部,我们有mutualInformation RDD。这是我们为我们的动作需要构建的 RDD。这个 RDD 依赖于求和中的一些中间元素,例如igFragment1igFragment2等。我们可以通过依赖关系列表回溯,直到达到图的另一端:不依赖于其他 RDD,只依赖于外部源的 RDD。

一旦构建了图,Spark 引擎就会制定一个执行作业的计划。计划从只有外部依赖(例如从磁盘加载文件或从数据库中检索而构建的 RDD)或已经缓存了数据的 RDD 开始。图上的每个箭头都被转换为一组任务,每个任务将一个转换应用于数据的一个分区。

任务被分组到阶段中。一个阶段由一组可以在不需要中间洗牌的情况下执行的任务组成。

数据洗牌和分区

要理解 Spark 中的数据洗牌,我们首先需要了解 RDD 中数据是如何分区的。当我们通过例如从 HDFS 加载文件或读取本地存储中的文件来创建一个 RDD 时,Spark 无法控制哪些数据位被分布在哪些分区中。这对于键值 RDD 来说是一个问题:这些 RDD 通常需要知道特定键的出现位置,例如执行连接操作。如果键可以在 RDD 的任何位置出现,我们必须查看每个分区以找到该键。

为了防止这种情况,Spark 允许在键值 RDD 上定义一个分区器。分区器是 RDD 的一个属性,它决定了特定键落在哪个分区。当一个 RDD 设置了分区器时,键的位置完全由分区器决定,而不是由 RDD 的历史或键的数量决定。具有相同分区器的两个不同的 RDD 将把相同的键映射到相同的分区。

分区通过它们对转换的影响来影响性能。在键值 RDD 上有两种类型的转换:

  • 窄变换,例如mapValues。在窄变换中,用于计算子 RDD 中分区的数据位于父分区的一个分区。因此,窄变换的数据处理可以完全本地执行,无需在节点之间通信数据。

  • 广泛变换,例如reduceByKey。在广泛变换中,用于计算任何单个分区的数据可以位于父分区中的所有分区。一般来说,由广泛变换产生的 RDD 将设置一个分区器。例如,reduceByKey变换的输出默认是哈希分区:特定键最终所在的分区由hash(key) % numPartitions确定。

因此,在我们的互信息示例中,pPresentAndSpampPresentAndHam将具有相同的分区结构,因为它们都有默认的哈希分区器。所有子 RDD 都保留相同的键,一直到mutualInformation。例如,单词language将在每个 RDD 中位于相同的分区。

为什么这些都很重要?如果一个 RDD 设置了分区器,那么这个分区器将保留在所有后续的窄变换中,这些变换源自该 RDD。让我们回到我们的互信息示例。pPresentGivenHampPresentGivenSpam这两个 RDD 都源自reduceByKey操作,并且它们都有字符串键。因此,它们都将有相同的哈希分区器(除非我们明确设置不同的分区器)。当我们构建pPresentAndSpampPresentAndHam时,这个分区器将被保留。当我们构建pPresent时,我们执行pPresentAndSpampPresentAndHam的完全外连接。由于这两个 RDD 有相同的分区器,子 RDD pPresent有窄依赖:我们只需将pPresentAndSpam的第一个分区与pPresentAndHam的第一个分区连接起来,将pPresentAndSpam的第二个分区与pPresentAndHam的第二个分区连接起来,依此类推,因为任何字符串键都会在两个 RDD 中被哈希到相同的分区。相比之下,如果没有分区器,我们就必须将pPresentAndSpam的每个分区的数据与pPresentAndSpam的每个分区连接起来。这将需要将数据发送到所有持有pPresentAndSpam的节点,这是一个耗时的操作。

由于广泛依赖关系,需要将数据发送到网络中以构建子 RDD 的过程称为洗牌。优化 Spark 程序的大部分艺术在于减少洗牌,并在必要时减少洗牌量。

摘要

在本章中,我们探讨了 Spark 的基础知识,并学习了如何构建和操作 RDD。在下一章中,我们将学习关于 Spark SQL 和 DataFrame 的知识,这是一组隐式转换,允许我们以类似于 pandas DataFrame 的方式操作 RDD,以及如何使用 Spark 与不同的数据源进行交互。

参考资料

第十一章。Spark SQL 和 DataFrame

在上一章中,我们学习了如何使用 Spark 构建一个简单的分布式应用程序。我们使用的数据是以文本文件形式存储的一组电子邮件。

我们了解到 Spark 是围绕 弹性分布式数据集RDDs)的概念构建的。我们探索了几种类型的 RDD:简单的字符串 RDD、键值 RDD 和双精度浮点数 RDD。在键值 RDD 和双精度浮点数 RDD 的情况下,Spark 通过隐式转换增加了比简单 RDD 更多的功能。还有一种重要的 RDD 类型我们尚未探索:DataFrame(之前称为 SchemaRDD)。DataFrame 允许操作比我们迄今为止探索的更复杂的对象。

DataFrame 是一种分布式表格数据结构,因此非常适合表示和操作结构化数据。在本章中,我们将首先通过 Spark shell 研究 DataFrame,然后使用上一章中介绍的 Ling-spam 电子邮件数据集,看看 DataFrame 如何集成到机器学习管道中。

DataFrames – 快速入门

让我们从打开 Spark shell 开始:

$ spark-shell

让我们假设我们对在患者群体上运行分析以估计他们的整体健康状况感兴趣。我们已经为每位患者测量了他们的身高、体重、年龄以及他们是否吸烟。

我们可能将每位患者的读数表示为一个案例类(你可能希望将其中一些内容写入文本编辑器,然后使用 :paste 命令将其粘贴到 Scala shell 中):

scala> case class PatientReadings(
 val patientId: Int,
 val heightCm: Int,
 val weightKg: Int,
 val age:Int, 
 val isSmoker:Boolean 
)
defined class PatientReadings

通常,我们会有成千上万的病人,可能存储在数据库或 CSV 文件中。我们将在本章后面讨论如何与外部源交互。现在,让我们直接在 shell 中硬编码一些读取值:

scala> val readings = List(
 PatientReadings(1, 175, 72, 43, false),
 PatientReadings(2, 182, 78, 28, true),
 PatientReadings(3, 164, 61, 41, false),
 PatientReadings(4, 161, 62, 43, true)
)
List[PatientReadings] = List(...

我们可以通过使用 sc.parallelizereadings 转换为 RDD:

scala> val readingsRDD = sc.parallelize(readings)
readingsRDD: RDD[PatientReadings] = ParallelCollectionRDD[0] at parallelize at <console>:25

注意,我们 RDD 的类型参数是 PatientReadings。让我们使用 .toDF 方法将 RDD 转换为 DataFrame:

scala> val readingsDF = readingsRDD.toDF
readingsDF: sql.DataFrame = [patientId: int, heightCm: int, weightKg: int, age: int, isSmoker: boolean]

我们已经创建了一个 DataFrame,其中每一行对应于特定病人的读取值,列对应于不同的特征:

scala> readingsDF.show
+---------+--------+--------+---+--------+
|patientId|heightCm|weightKg|age|isSmoker|
+---------+--------+--------+---+--------+
|        1|     175|      72| 43|   false|
|        2|     182|      78| 28|    true|
|        3|     164|      61| 41|   false|
|        4|     161|      62| 43|    true|
+---------+--------+--------+---+--------+

创建 DataFrame 最简单的方法是使用 RDD 上的 toDF 方法。我们可以将任何 RDD[T](其中 T 是一个 case class 或一个元组)转换为 DataFrame。Spark 将将 case class 的每个属性映射到 DataFrame 中适当类型的列。它使用反射来发现属性的名字和类型。还有几种其他方法可以构建 DataFrame,无论是从 RDD 还是外部来源,我们将在本章后面探讨。

DataFrames 支持许多操作来操作行和列。例如,让我们添加一个用于 体质指数BMI)的列。BMI 是一种常见的将 身高体重 聚合起来以判断某人是否超重或体重不足的方法。BMI 的公式是:

DataFrames – 快速入门

让我们先创建一个以米为单位的身高列:

scala> val heightM = readingsDF("heightCm") / 100.0 
heightM: sql.Column = (heightCm / 100.0)

heightM 具有数据类型 Column,表示 DataFrame 中的数据列。列支持许多算术和比较运算符,这些运算符按元素方式应用于列(类似于在 第二章中遇到的 Breeze 向量,使用 Breeze 操作数据)。列上的操作是惰性的:当定义时,heightM 列实际上并没有计算。现在让我们定义一个 BMI 列:

scala> val bmi = readingsDF("weightKg") / (heightM*heightM)
bmi: sql.Column = (weightKg / ((heightCm / 100.0) * (heightCm / 100.0)))

在我们的读取 DataFrame 中添加 bmi 列将很有用。由于 DataFrames,就像 RDDs 一样,是不可变的,我们必须定义一个新的 DataFrame,它与 readingsDF 完全相同,但增加了一个用于 BMI 的列。我们可以使用 withColumn 方法来实现,该方法接受新列的名称和一个 Column 实例作为参数:

scala> val readingsWithBmiDF = readingsDF.withColumn("BMI", bmi)
readingsWithBmiDF: sql.DataFrame = [heightCm: int, weightKg: int, age: int, isSmoker: boolean, BMI: double]

我们迄今为止看到的所有操作都是 转换:它们定义了一个操作管道,创建新的 DataFrames。这些转换在我们调用 动作(如 show)时执行:

scala> readingsWithBmiDF.show
+---------+--------+--------+---+--------+------------------+
|patientId|heightCm|weightKg|age|isSmoker|               BMI|
+---------+--------+--------+---+--------+------------------+
|        1|     175|      72| 43|   false|23.510204081632654|
|        2|     182|      78| 28|    true| 23.54788069073783|
|        3|     164|      61| 41|   false|22.679952409280194|
|        4|     161|      62| 43|    true|  23.9188302920412|
+---------+--------+--------+---+--------+------------------+

除了创建额外的列,DataFrames 还支持过滤满足特定谓词的行。例如,我们可以选择所有吸烟者:

scala> readingsWithBmiDF.filter {
 readingsWithBmiDF("isSmoker") 
}.show
+---------+--------+--------+---+--------+-----------------+
|patientId|heightCm|weightKg|age|isSmoker|              BMI|
+---------+--------+--------+---+--------+-----------------+
|        2|     182|      78| 28|    true|23.54788069073783|
|        4|     161|      62| 43|    true| 23.9188302920412|
+---------+--------+--------+---+--------+-----------------+

或者,为了选择体重超过 70 公斤的人:

scala> readingsWithBmiDF.filter { 
 readingsWithBmiDF("weightKg") > 70 
}.show
+---------+--------+--------+---+--------+------------------+
|patientId|heightCm|weightKg|age|isSmoker|               BMI|
+---------+--------+--------+---+--------+------------------+
|        1|     175|      72| 43|   false|23.510204081632654|
|        2|     182|      78| 28|    true| 23.54788069073783|
+---------+--------+--------+---+--------+------------------+

在表达式中重复 DataFrame 名称可能会变得繁琐。Spark 定义了操作符 $ 来引用当前 DataFrame 中的列。因此,上面的 filter 表达式可以更简洁地写成:

scala> readingsWithBmiDF.filter { $"weightKg" > 70 }.show
+---------+--------+--------+---+--------+------------------+
|patientId|heightCm|weightKg|age|isSmoker|               BMI|
+---------+--------+--------+---+--------+------------------+
|        1|     175|      72| 43|   false|23.510204081632654|
|        2|     182|      78| 28|    true| 23.54788069073783|
+---------+--------+--------+---+--------+------------------+

.filter 方法是重载的。它接受一个布尔值列,如上所述,或者一个标识当前 DataFrame 中布尔列的字符串。因此,为了过滤我们的 readingsWithBmiDF DataFrame 以子选择吸烟者,我们也可以使用以下方法:

scala> readingsWithBmiDF.filter("isSmoker").show
+---------+--------+--------+---+--------+-----------------+
|patientId|heightCm|weightKg|age|isSmoker|              BMI|
+---------+--------+--------+---+--------+-----------------+
|        2|     182|      78| 28|    true|23.54788069073783|
|        4|     161|      62| 43|    true| 23.9188302920412|
+---------+--------+--------+---+--------+-----------------+

当比较相等时,你必须使用特殊的 三重等于 操作符来比较列:

scala> readingsWithBmiDF.filter { $"age" === 28 }.show
+---------+--------+--------+---+--------+-----------------+
|patientId|heightCm|weightKg|age|isSmoker|              BMI|
+---------+--------+--------+---+--------+-----------------+
|        2|     182|      78| 28|    true|23.54788069073783|
+---------+--------+--------+---+--------+-----------------+

类似地,你必须使用 !== 来选择不等于某个值的行:

scala> readingsWithBmiDF.filter { $"age" !== 28 }.show
+---------+--------+--------+---+--------+------------------+
|patientId|heightCm|weightKg|age|isSmoker|               BMI|
+---------+--------+--------+---+--------+------------------+
|        1|     175|      72| 43|   false|23.510204081632654|
|        3|     164|      61| 41|   false|22.679952409280194|
|        4|     161|      62| 43|    true|  23.9188302920412|
+---------+--------+--------+---+--------+------------------+

聚合操作

我们已经看到了如何将操作应用于 DataFrame 中的每一行以创建新列,以及如何使用过滤器从原始 DataFrame 中选择子集行来构建新的 DataFrame。DataFrame 的最后一系列操作是分组操作,相当于 SQL 中的 GROUP BY 语句。让我们计算吸烟者和非吸烟者的平均 BMI。我们必须首先告诉 Spark 按列(在这种情况下是 isSmoker 列)对 DataFrame 进行分组,然后应用聚合操作(在这种情况下是平均)以减少每个组:

scala> val smokingDF = readingsWithBmiDF.groupBy(
 "isSmoker").agg(avg("BMI"))
smokingDF: org.apache.spark.sql.DataFrame = [isSmoker: boolean, AVG(BMI): double]

这已经创建了一个包含两列的新 DataFrame:分组列和我们要对其聚合的列。让我们展示这个 DataFrame:

scala> smokingDF.show
+--------+------------------+
|isSmoker|          AVG(BMI)|
+--------+------------------+
|    true|23.733355491389517|
|   false|23.095078245456424|
+--------+------------------+

除了平均之外,还有几个操作符可以用于对每个组进行聚合。以下表格中概述了一些更重要的一些,但要获取完整列表,请参阅spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.sql.functions$聚合函数 部分:

操作符 备注
avg(column) 指定列的组平均值。
count(column) 在指定列中每个组中的元素数量。
countDistinct(column, ... ) 每个组中不同元素的数量。这也可以接受多个列以返回跨多个列的唯一元素计数。
first(column), last(column) 每个组中的第一个/最后一个元素
max(column), min(column) 每个组中的最大/最小元素
sum(column) 每个组中值的总和

每个聚合操作符都接受列名,作为字符串,或者类型为 Column 的表达式。后者允许对复合表达式进行聚合。如果我们想得到样本中吸烟者和非吸烟者的平均身高(以米为单位),我们可以使用:

scala> readingsDF.groupBy("isSmoker").agg { 
 avg($"heightCm"/100.0) 
}.show
+--------+-----------------------+
|isSmoker|AVG((heightCm / 100.0))|
+--------+-----------------------+
|    true|                  1.715|
|   false|     1.6949999999999998|
+--------+-----------------------+

我们还可以使用复合表达式来定义要分组的列。例如,为了计算每个 age 组中患者的数量,按十年递增,我们可以使用:

scala> readingsDF.groupBy(floor($"age"/10)).agg(count("*")).show
+-----------------+--------+
|FLOOR((age / 10))|count(1)|
+-----------------+--------+
|              4.0|       3|
|              2.0|       1|
+-----------------+--------+

我们使用短横线 "*" 来表示对每一列的计数。

将 DataFrames 合并

到目前为止,我们只考虑了单个 DataFrame 上的操作。Spark 还提供了类似 SQL 的连接来组合 DataFrame。假设我们还有一个将患者 ID 映射到(收缩压)血压测量的 DataFrame。我们将假设我们有一个将患者 ID 映射到血压的列表对:

scala> val bloodPressures = List((1 -> 110), (3 -> 100), (4 -> 125))
bloodPressures: List[(Int, Int)] = List((1,110), (3,100), (4,125))

scala> val bloodPressureRDD = sc.parallelize(bloodPressures)
res16: rdd.RDD[(Int, Int)] = ParallelCollectionRDD[74] at parallelize at <console>:24

我们可以从这个元组 RDD 构建一个 DataFrame。然而,与从案例类 RDD 构建 DataFrame 不同,Spark 无法推断列名。因此,我们必须将这些列名显式传递给 .toDF

scala> val bloodPressureDF = bloodPressureRDD.toDF(
 "patientId", "bloodPressure")
bloodPressureDF: DataFrame = [patientId: int, bloodPressure: int]

scala> bloodPressureDF.show
+---------+-------------+
|patientId|bloodPressure|
+---------+-------------+
|        1|          110|
|        3|          100|
|        4|          125|
+---------+-------------+

让我们将 bloodPressureDFreadingsDF 通过患者 ID 作为连接键进行连接:

scala> readingsDF.join(bloodPressureDF, 
 readingsDF("patientId") === bloodPressureDF("patientId")
).show
+---------+--------+--------+---+--------+---------+-------------+
|patientId|heightCm|weightKg|age|isSmoker|patientId|bloodPressure|
+---------+--------+--------+---+--------+---------+-------------+
|        1|     175|      72| 43|   false|        1|          110|
|        3|     164|      61| 41|   false|        3|          100|
|        4|     161|      62| 43|    true|        4|          125|
+---------+--------+--------+---+--------+---------+-------------+

这执行了一个内连接:只有同时存在于两个 DataFrame 中的患者 ID 被包含在结果中。连接类型可以作为额外的参数传递给 join。例如,我们可以执行一个左连接

scala> readingsDF.join(bloodPressureDF,
 readingsDF("patientId") === bloodPressureDF("patientId"),
 "leftouter"
).show
+---------+--------+--------+---+--------+---------+-------------+
|patientId|heightCm|weightKg|age|isSmoker|patientId|bloodPressure|
+---------+--------+--------+---+--------+---------+-------------+
|        1|     175|      72| 43|   false|        1|          110|
|        2|     182|      78| 28|    true|     null|         null|
|        3|     164|      61| 41|   false|        3|          100|
|        4|     161|      62| 43|    true|        4|          125|
+---------+--------+--------+---+--------+---------+-------------+

可能的连接类型有 innerouterleftouterrightouterleftsemi。这些都应该很熟悉,除了 leftsemi,它对应于左半连接。这与内连接相同,但在连接后只保留左侧的列。因此,这是一种过滤 DataFrame 以找到存在于另一个 DataFrame 中的行的方法。

自定义 DataFrame 上的函数

到目前为止,我们只使用内置函数来操作 DataFrame 的列。虽然这些通常足够用,但我们有时需要更大的灵活性。Spark 允许我们通过用户定义函数UDFs)将自定义转换应用于每一行。假设我们想使用我们在第二章中推导出的方程,即使用 Breeze 操作数据,来计算给定身高和体重的男性概率。我们计算出决策边界如下:

自定义 DataFrame 上的函数

任何具有 f > 0 的人,在给定他们的身高、体重以及用于第二章,使用 Breeze 操作数据(该数据基于学生,因此不太可能代表整个人群)的训练集的情况下,比女性更有可能是男性。要将厘米单位的身高转换为归一化身高 rescaledHeight,我们可以使用以下公式:

自定义 DataFrame 上的函数

同样,要将体重(以千克为单位)转换为归一化体重 rescaledWeight,我们可以使用以下公式:

自定义 DataFrame 上的函数

从训练集中计算了 heightweight 的平均值和标准差。让我们编写一个 Scala 函数,该函数返回给定身高和体重的人更有可能是男性:

scala> def likelyMale(height:Int, weight:Int):Boolean = {
 val rescaledHeight = (height - 171.0)/8.95
 val rescaledWeight = (weight - 65.7)/13.4
 -0.75 + 2.48*rescaledHeight + 2.23*rescaledWeight > 0
}

要在 Spark DataFrame 上使用此函数,我们需要将其注册为用户定义函数UDF)。这将我们的函数,该函数接受整数参数,转换为接受列参数的函数:

scala> val likelyMaleUdf = sqlContext.udf.register(
 "likelyMaleUdf", likelyMale _)
likelyMaleUdf: org.apache.spark.sql.UserDefinedFunction = UserDefinedFunction(<function2>,BooleanType,List())

要注册一个 UDF,我们必须能够访问一个 sqlContext 实例。SQL 上下文提供了 DataFrame 操作的入口点。Spark shell 在启动时创建一个 SQL 上下文,绑定到变量 sqlContext,并在 shell 会话关闭时销毁它。

传递给 register 函数的第一个参数是 UDF 的名称(我们将在以后编写 DataFrame 上的 SQL 语句时使用 UDF 名称,但现在你可以忽略它)。然后我们可以像使用 Spark 中包含的内置转换一样使用 UDF:

scala> val likelyMaleColumn = likelyMaleUdf(
 readingsDF("heightCm"), readingsDF("weightKg"))
likelyMaleColumn: org.apache.spark.sql.Column = UDF(heightCm,weightKg)

scala> readingsDF.withColumn("likelyMale", likelyMaleColumn).show
+---------+--------+--------+---+--------+----------+
|patientId|heightCm|weightKg|age|isSmoker|likelyMale|
+---------+--------+--------+---+--------+----------+
|        1|     175|      72| 43|   false|      true|
|        2|     182|      78| 28|    true|      true|
|        3|     164|      61| 41|   false|     false|
|        4|     161|      62| 43|    true|     false|
+---------+--------+--------+---+--------+----------+

正如你所见,Spark 将 UDF 的底层函数应用于 DataFrame 中的每一行。我们不仅限于使用 UDF 创建新列。我们还可以在 filter 表达式中使用它们。例如,为了选择可能对应女性的行:

scala> readingsDF.filter(
 ! likelyMaleUdf($"heightCm", $"weightKg")
).show
+---------+--------+--------+---+--------+
|patientId|heightCm|weightKg|age|isSmoker|
+---------+--------+--------+---+--------+
|        3|     164|      61| 41|   false|
|        4|     161|      62| 43|    true|
+---------+--------+--------+---+--------+

使用 UDF 允许我们定义任意的 Scala 函数来转换行,为数据处理提供了巨大的额外功能。

DataFrame 的不可变性和持久性

与 RDD 一样,DataFrame 是不可变的。当你对一个 DataFrame 定义一个转换时,这总是创建一个新的 DataFrame。原始 DataFrame 不能就地修改(这与 pandas DataFrame 明显不同,例如)。

DataFrame 上的操作可以分为两类:转换,它导致创建一个新的 DataFrame,和动作,它通常返回一个 Scala 类型或有一个副作用。例如 filterwithColumn 是转换,而 showhead 是动作。

转换是惰性的,就像 RDD 上的转换一样。当你通过转换现有的 DataFrame 生成一个新的 DataFrame 时,这会导致创建新 DataFrame 的执行计划的详细阐述,但数据本身并不会立即被转换。你可以使用 queryExecution 方法访问执行计划。

当你在 DataFrame 上调用一个动作时,Spark 会像处理一个常规 RDD 一样处理该动作:它隐式地构建一个无环图来解析依赖关系,处理构建被调用动作的 DataFrame 所需要的转换。

与 RDD 类似,我们可以在内存或磁盘上持久化 DataFrame:

scala> readingsDF.persist
readingsDF.type = [patientId: int, heightCm: int,...]

这与持久化 RDD 的方式相同:下次计算 RDD 时,它将被保留在内存中(前提是有足够的空间),而不是被丢弃。持久化级别也可以设置:

scala> import org.apache.spark.storage.StorageLevel
import org.apache.spark.storage.StorageLevel

scala> readingsDF.persist(StorageLevel.MEMORY_AND_DISK)
readingsDF.type = [patientId: int, heightCm: int, ...]

DataFrame 上的 SQL 语句

到现在为止,你可能已经注意到 DataFrame 上的许多操作都是受 SQL 操作启发的。此外,Spark 允许我们将 DataFrame 注册为表,并直接使用 SQL 语句查询它们。因此,我们可以将临时数据库作为程序流程的一部分构建。

让我们将 readingsDF 注册为临时表:

scala> readingsDF.registerTempTable("readings")

这注册了一个临时表,该表可以在 SQL 查询中使用。注册临时表依赖于 SQL 上下文的存在。当 SQL 上下文被销毁时(例如,当我们关闭 shell 时),临时表将被销毁。

让我们探索我们可以使用我们的临时表和 SQL 上下文做什么。我们首先可以获取上下文中当前注册的所有表的列表:

scala> sqlContext.tables
DataFrame = [tableName: string, isTemporary: boolean]

这将返回一个 DataFrame。一般来说,所有在 SQL 上下文中返回数据的操作都会返回 DataFrames:

scala> sqlContext.tables.show
+---------+-----------+
|tableName|isTemporary|
+---------+-----------+
| readings|       true|
+---------+-----------+

我们可以通过向 SQL 上下文中传递 SQL 语句来查询这个表:

scala> sqlContext.sql("SELECT * FROM readings").show
+---------+--------+--------+---+--------+
|patientId|heightCm|weightKg|age|isSmoker|
+---------+--------+--------+---+--------+
|        1|     175|      72| 43|   false|
|        2|     182|      78| 28|    true|
|        3|     164|      61| 41|   false|
|        4|     161|      62| 43|    true|
+---------+--------+--------+---+--------+

sqlContext 中注册的任何 UDF 都可以通过它们注册时给出的名称访问。因此,我们可以在 SQL 查询中使用它们:

scala> sqlContext.sql("""
SELECT 
 patientId, 
 likelyMaleUdf(heightCm, weightKg) AS likelyMale
FROM readings
""").show
+---------+----------+
|patientId|likelyMale|
+---------+----------+
|        1|      true|
|        2|      true|
|        3|     false|
|        4|     false|
+---------+----------+

你可能会想知道为什么有人想要将 DataFrames 注册为临时表并在这些表上运行 SQL 查询,当同样的功能可以直接在 DataFrames 上使用时。主要原因是为了与外部工具交互。Spark 可以运行一个 SQL 引擎,该引擎公开 JDBC 接口,这意味着知道如何与 SQL 数据库交互的程序将能够使用临时表。

我们没有足够的空间在这本书中介绍如何设置分布式 SQL 引擎,但您可以在 Spark 文档中找到详细信息(spark.apache.org/docs/latest/sql-programming-guide.html#distributed-sql-engine)。

复杂数据类型 - 数组、映射和 struct

到目前为止,我们 DataFrame 中的所有元素都是简单类型。DataFrames 支持三种额外的集合类型:数组、映射和 struct。

Structs

我们将要查看的第一个复合类型是 struct。一个 struct 类似于 case class:它存储一组键值对,具有一组固定的键。如果我们将包含嵌套 case class 的 case class RDD 转换为 DataFrame,Spark 将将嵌套对象转换为 struct。

让我们想象一下,我们想要序列化《指环王》中的角色。我们可能会使用以下对象模型:

case class Weapon(name:String, weaponType:String)
case class LotrCharacter(name:String, val weapon:Weapon)

我们想要创建一个 LotrCharacter 实例的 DataFrame。让我们创建一些虚拟数据:

scala> val characters = List(
 LotrCharacter("Gandalf", Weapon("Glamdring", "sword")),
 LotrCharacter("Frodo", Weapon("Sting", "dagger")),
 LotrCharacter("Aragorn", Weapon("Anduril", "sword"))
)
characters: List[LotrCharacter] = List(LotrCharacter...

scala> val charactersDF = sc.parallelize(characters).toDF
charactersDF: DataFrame = [name: string, weapon: struct<name:string,weaponType:string>]

scala> charactersDF.printSchema
root
 |-- name: string (nullable = true)
 |-- weapon: struct (nullable = true)
 |    |-- name: string (nullable = true)
 |    |-- weaponType: string (nullable = true)

scala> charactersDF.show
+-------+-----------------+
|   name|           weapon|
+-------+-----------------+
|Gandalf|[Glamdring,sword]|
|  Frodo|   [Sting,dagger]|
|Aragorn|  [Anduril,sword]|
+-------+-----------------+

在 case class 中的 weapon 属性在 DataFrame 中被转换为 struct 列。要从 struct 中提取子字段,我们可以将字段名传递给列的 .apply 方法:

scala> val weaponTypeColumn = charactersDF("weapon")("weaponType")
weaponTypeColumn: org.apache.spark.sql.Column = weapon[weaponType]

我们可以使用这个派生列就像我们使用任何其他列一样。例如,让我们过滤我们的 DataFrame,只包含挥舞着剑的角色:

scala> charactersDF.filter { weaponTypeColumn === "sword" }.show
+-------+-----------------+
|   name|           weapon|
+-------+-----------------+
|Gandalf|[Glamdring,sword]|
|Aragorn|  [Anduril,sword]|
+-------+-----------------+

Arrays

让我们回到之前的例子,并假设除了身高、体重和年龄测量值之外,我们还有我们患者的电话号码。每个患者可能有零个、一个或多个电话号码。我们将定义一个新的 case class 和新的虚拟数据:

scala> case class PatientNumbers(
 patientId:Int, phoneNumbers:List[String])
defined class PatientNumbers

scala> val numbers = List(
 PatientNumbers(1, List("07929123456")),
 PatientNumbers(2, List("07929432167", "07929234578")),
 PatientNumbers(3, List.empty),
 PatientNumbers(4, List("07927357862"))
)

scala> val numbersDF = sc.parallelize(numbers).toDF
numbersDF: org.apache.spark.sql.DataFrame = [patientId: int, phoneNumbers: array<string>]

在我们的 case class 中,List[String] 数组被转换为 array<string> 数据类型:

scala> numbersDF.printSchema
root
 |-- patientId: integer (nullable = false)
 |-- phoneNumbers: array (nullable = true)
 |    |-- element: string (containsNull = true)

与 structs 类似,我们可以为数组中的特定索引构造一个列。例如,我们可以选择每个数组中的第一个元素:

scala> val bestNumberColumn = numbersDF("phoneNumbers")(0)
bestNumberColumn: org.apache.spark.sql.Column = phoneNumbers[0]

scala> numbersDF.withColumn("bestNumber", bestNumberColumn).show
+---------+--------------------+-----------+
|patientId|        phoneNumbers| bestNumber|
+---------+--------------------+-----------+
|        1|   List(07929123456)|07929123456|
|        2|List(07929432167,...|07929432167|
|        3|              List()|       null|
|        4|   List(07927357862)|07927357862|
+---------+--------------------+-----------+

Maps

最后的复合数据类型是映射。映射在存储键值对方面与 structs 类似,但 DataFrame 创建时键的集合不是固定的。因此,它们可以存储任意键值对。

当构建 DataFrame 时,Scala 映射将被转换为 DataFrame 映射。然后可以以类似结构体的方式查询它们。

与数据源交互

在数据科学或工程中,一个主要挑战是处理用于持久化数据的丰富输入和输出格式。我们可能以 CSV 文件、JSON 文件或通过 SQL 数据库的形式接收或发送数据,仅举几例。

Spark 提供了一个统一的 API,用于将 DataFrame 序列化和反序列化到不同的数据源。

JSON 文件

Spark 支持从 JSON 文件加载数据,前提是 JSON 文件中的每一行都对应一个 JSON 对象。每个对象将被映射到 DataFrame 行。JSON 数组被映射到数组,嵌套对象被映射到结构体。

如果没有一些数据,本节可能会显得有些枯燥,所以让我们从 GitHub API 生成一些数据。不幸的是,GitHub API 并不返回每行一个 JSON 格式的对象。本章的代码库包含一个名为 FetchData.scala 的脚本,该脚本将下载并格式化 Martin Odersky 的存储库的 JSON 条目,并将对象保存到名为 odersky_repos.json 的文件中(如果你想的话,请更改 FetchData.scala 中的 GitHub 用户)。你也可以从 data.scala4datascience.com/odersky_repos.json 下载预先构建的数据文件。

让我们进入 Spark shell 并将此数据加载到 DataFrame 中。从 JSON 文件读取就像将文件名传递给 sqlContext.read.json 方法一样简单:

scala> val df = sqlContext.read.json("odersky_repos.json")
df: DataFrame = [archive_url: string, assignees_url: ...]

从 JSON 文件读取数据时,数据被加载为 DataFrame。Spark 会自动从 JSON 文档中推断模式。我们的 DataFrame 中有许多列。让我们子选择一些列以获得更易于管理的 DataFrame:

scala> val reposDF = df.select("name", "language", "fork", "owner")
reposDF: DataFrame = [name: string, language: string, ...] 

scala> reposDF.show
+----------------+----------+-----+--------------------+
|            name|  language| fork|               owner|
+----------------+----------+-----+--------------------+
|           dotty|     Scala| true|[https://avatars....|
|        frontend|JavaScript| true|[https://avatars....|
|           scala|     Scala| true|[https://avatars....|
|      scala-dist|     Scala| true|[https://avatars....|
|scala.github.com|JavaScript| true|[https://avatars....|
|          scalax|     Scala|false|[https://avatars....|
|            sips|       CSS|false|[https://avatars....|
+----------------+----------+-----+--------------------+

让我们将 DataFrame 保存回 JSON:

scala> reposDF.write.json("repos_short.json")

如果你查看运行 Spark shell 的目录中的文件,你会注意到一个 repos_short.json 目录。在里面,你会看到名为 part-000000part-000001 等的文件。当序列化 JSON 时,DataFrame 的每个分区都是独立序列化的。如果你在多台机器上运行此操作,你将在每台计算机上找到序列化输出的部分。

你可以选择性地传递一个 mode 参数来控制 Spark 如何处理现有的 repos_short.json 文件:

scala> import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.SaveMode

scala> reposDF.write.mode(
 SaveMode.Overwrite).json("repos_short.json")

可用的保存模式有 ErrorIfExistsAppend(仅适用于 Parquet 文件)、OverwriteIgnore(如果文件已存在则不保存)。

Parquet 文件

Apache Parquet 是一种流行的文件格式,非常适合存储表格数据。它常用于 Hadoop 生态系统中的序列化,因为它允许在不读取整个文件的情况下高效地提取特定列和行。

Parquet 文件的序列化和反序列化与 JSON 相同,只需将 json 替换为 parquet

scala> reposDF.write.parquet("repos_short.parquet")

scala> val newDF = sqlContext.read.parquet("repos_short.parquet")
newDF: DataFrame = [name: string, language: string, fo...]

scala> newDF.show
+----------------+----------+-----+--------------------+
|            name|  language| fork|               owner|
+----------------+----------+-----+--------------------+
|           dotty|     Scala| true|[`avatars....|
|        frontend|JavaScript| true|[https://avatars....|
|           scala|     Scala| true|[https://avatars....|
|      scala-dist|     Scala| true|[https://avatars....|
|scala.github.com|JavaScript| true|[https://avatars....|
|          scalax|     Scala|false|[https://avatars....|
|            sips|       CSS|false|[https://avatars....|
+----------------+----------+-----+--------------------+

通常,Parquet 在存储大量对象集合时比 JSON 更节省空间。如果可以从行中推断出分区,Parquet 在检索特定列或行时也更为高效。因此,除非您需要输出可由外部程序读取或反序列化,否则 Parquet 相对于 JSON 具有优势。

独立程序

到目前为止,我们一直通过 Spark shell 使用 Spark SQL 和 DataFrame。要在独立程序中使用它,您需要从 Spark 上下文中显式创建它:

val conf = new SparkConf().setAppName("applicationName")
val sc = new SparkContext(conf)
val sqlContext = new org.apache.spark.sql.SQLContext(sc)

此外,导入嵌套在sqlContext中的implicits对象允许将 RDD 转换为 DataFrame:

import sqlContext.implicits._

在下一章中,我们将广泛使用 DataFrame 来操纵数据,使其准备好与 MLlib 一起使用。

摘要

在本章中,我们探讨了 Spark SQL 和 DataFrame。DataFrame 在 Spark 的核心引擎之上增加了一层丰富的抽象层,极大地简化了表格数据的操作。此外,源 API 允许从丰富的数据文件中序列化和反序列化 DataFrame。

在下一章中,我们将基于我们对 Spark 和 DataFrame 的知识来构建一个使用 MLlib 的垃圾邮件过滤器。

参考文献

DataFrame 是 Spark 相对较新的功能。因此,相关的文献和文档仍然很少。首先应该查阅 Scala 文档,可在以下网址找到:http://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.sql.DataFrame`

DataFrame Column类型上可用的操作 Scaladocs 可在以下网址找到:spark.apache.org/docs/latest/api/scala/#org.apache.spark.sql.Column

关于 Parquet 文件格式的详细文档也可在以下网址找到:parquet.apache.org

第十二章。使用 MLlib 进行分布式机器学习

机器学习描述了构建从数据中进行预测的算法。它是大多数数据科学流程的核心组件,通常被认为是增加最大价值的组件:机器学习算法的准确性决定了数据科学项目的成功。它还可能是数据科学流程中需要从软件工程以外的领域获取最多知识的部分:机器学习专家不仅熟悉算法,还熟悉统计学和业务领域。

选择和调整机器学习算法来解决特定问题涉及大量的探索性分析,以尝试确定哪些特征是相关的,特征之间的相关性如何,数据集中是否存在异常值等等。设计合适的机器学习管道是困难的。再加上数据集的大小和可扩展性需求带来的额外复杂性,您就面临了一个真正的挑战。

MLlib 有助于缓解这种困难。MLlib 是 Spark 的一个组件,它提供了核心 Spark 库之上的机器学习算法。它提供了一套学习算法,这些算法在分布式数据集上并行化效果很好。

MLlib 已经发展成两个独立的层。MLlib 本身包含核心算法,而 ml,也称为 pipeline API,定义了一个用于将算法粘合在一起的 API,并提供了一个更高层次的抽象。这两个库在它们操作的数据类型上有所不同:原始的 MLlib 在 DataFrame 引入之前就已经存在,主要作用于特征向量的 RDD。pipeline API 在 DataFrame 上操作。

在本章中,我们将研究较新的 pipeline API,只有在管道 API 缺少功能时才会深入研究 MLlib。

本章并不试图教授我们所展示算法背后的机器学习基础。我们假设读者对机器学习工具和技术有足够的了解,至少能够表面上理解这里展示的算法做什么,我们将深入解释统计学习机制的部分工作留给更好的作者(我们在本章末尾提供了一些参考文献)。

MLlib 是一个快速发展的丰富库。本章的目标不是提供一个完整的库概述。我们将通过构建一个用于训练垃圾邮件过滤器的机器学习管道来工作,在这个过程中了解我们需要的 MLlib 的各个部分。阅读完本章后,您将了解库的不同部分是如何结合在一起的,并且可以使用在线文档或更专业的书籍(请参阅本章末尾的参考文献)来了解这里未涵盖的 MLlib 部分。

介绍 MLlib – 垃圾邮件分类

让我们用一个具体的例子来介绍 MLlib。我们将查看使用我们在第十章分布式批量处理与 Spark 中使用的 Ling-Spam 数据集进行的垃圾邮件分类,我们将创建一个使用逻辑回归来估计给定消息是否为垃圾邮件的概率的垃圾邮件过滤器。

我们将通过 Spark shell 运行示例,但你将在本章的示例中找到类似程序,在LogisticRegressionDemo.scala中。如果你还没有安装 Spark,请参考第十章分布式批量处理 Spark,获取安装说明。

让我们从加载 Ling-Spam 数据集中的电子邮件开始。如果你在第十章分布式批量处理 Spark 中没有这样做,请从data.scala4datascience.com/ling-spam.tar.gzdata.scala4datascience.com/ling-spam.zip下载数据,根据你想要tar.gz文件还是zip文件来选择,然后解压存档。这将创建一个spam目录和一个ham目录,分别包含垃圾邮件和正常邮件。

让我们使用wholeTextFiles方法来加载垃圾邮件和正常邮件:

scala> val spamText = sc.wholeTextFiles("spam/*")
spamText: RDD[(String, String)] = spam/...

scala> val hamText = sc.wholeTextFiles("ham/*")
hamText: RDD[(String, String)] = ham/...

wholeTextFiles方法创建一个键值 RDD,其中键是文件名,值是文件内容:

scala> spamText.first
(String, String) =
(file:spam/spmsga1.txt,"Subject: great part-time summer job! ...")

scala> spamText.count
Long = 481

管道 API 中的算法在 DataFrame 上工作。因此,我们必须将我们的键值 RDD 转换为 DataFrame。我们定义一个新的 case class,LabelledDocument,它包含一个消息文本和一个类别标签,用于标识消息是spam还是ham

scala> case class LabelledDocument(
 fileName:String, 
 text:String, 
 category:String
)
defined class LabelledDocument

scala> val spamDocuments = spamText.map {
 case (fileName, text) => 
 LabelledDocument(fileName, text, "spam")
}
spamDocuments: RDD[LabelledDocument] = MapPartitionsRDD[2] at map

scala> val hamDocuments = hamText.map {
 case (fileName, text) => 
 LabelledDocument(fileName, text, "ham")
}
hamDocuments: RDD[LabelledDocument] = MapPartitionsRDD[3] at map

要创建模型,我们需要将所有文档放入一个 DataFrame 中。因此,我们将两个LabelledDocument RDD 合并,并将其转换为 DataFrame。union方法将 RDD 连接起来:

scala> val allDocuments = spamDocuments.union(hamDocuments)
allDocuments: RDD[LabelledDocument] = UnionRDD[4] at union

scala> val documentsDF = allDocuments.toDF
documentsDF: DataFrame = [fileName: string, text: string, category: string]

让我们做一些基本的检查来验证我们已经加载了所有文档。我们首先将 DataFrame 保存在内存中,以避免需要从原始文本文件中重新创建它。

scala> documentsDF.persist
documentsDF.type = [fileName: string, text: string, category: string]

scala> documentsDF.show
+--------------------+--------------------+--------+
|            fileName|                text|category|
+--------------------+--------------------+--------+
|file:/Users/pasca...|Subject: great pa...|    spam|
|file:/Users/pasca...|Subject: auto ins...|    spam|
|file:/Users/pasca...|Subject: want bes...|    spam|
|file:/Users/pasca...|Subject: email 57...|    spam|
|file:/Users/pasca...|Subject: n't miss...|    spam|
|file:/Users/pasca...|Subject: amaze wo...|    spam|
|file:/Users/pasca...|Subject: help loa...|    spam|
|file:/Users/pasca...|Subject: beat irs...|    spam|
|file:/Users/pasca...|Subject: email 57...|    spam|
|file:/Users/pasca...|Subject: best , b...|    spam|
|...                                               |
+--------------------+--------------------+--------+

scala> documentsDF.groupBy("category").agg(count("*")).show
+--------+--------+
|category|COUNT(1)|
+--------+--------+
|    spam|     481|
|     ham|    2412|
+--------+--------+

现在让我们将 DataFrame 分割成训练集和测试集。我们将使用测试集来验证我们构建的模型。现在,我们将只使用一个分割,用 70%的数据训练模型,用剩余的 30%进行测试。在下一节中,我们将探讨交叉验证,它提供了一种更严格的方式来检查我们模型的准确性。

我们可以使用 DataFrame 的.randomSplit方法实现 70-30 的分割:

scala> val Array(trainDF, testDF) = documentsDF.randomSplit(
 Array(0.7, 0.3))
trainDF: DataFrame = [fileName: string, text: string, category: string]
testDF: DataFrame = [fileName: string, text: string, category: string]

.randomSplit方法接受一个权重数组,并返回一个 DataFrame 数组,其大小大约由权重指定。例如,我们传递了权重0.70.3,表示任何给定行有 70%的机会最终进入trainDF,有 30%的机会进入testDF。请注意,这意味着分割的 DataFrame 大小不是固定的:trainDF大约是documentsDF的 70%,但不是正好 70%:

scala> trainDF.count / documentsDF.count.toDouble
Double = 0.7013480815762184

如果你需要一个固定大小的样本,请使用 DataFrame 的.sample方法来获取trainDF,并过滤documentDF以排除trainDF中的行。

我们现在可以开始使用 MLlib 了。我们的分类尝试将涉及在词频向量上执行逻辑回归:我们将计算每个单词在每个消息中出现的频率,并使用发生频率作为特征。在深入代码之前,让我们退一步来讨论机器学习管道的结构。

管道组件

管道由一系列组件组成,这些组件连接在一起,使得一个组件生成的 DataFrame 被用作下一个组件的输入。可用的组件分为两类:转换器估计器

转换器

转换器将一个 DataFrame 转换成另一个,通常是通过添加一个或多个列。

我们垃圾邮件分类算法的第一步是将每个消息分割成一个单词数组。这被称为分词。我们可以使用 MLlib 提供的Tokenizer转换器:

scala> import org.apache.spark.ml.feature._
import org.apache.spark.ml.feature._

scala> val tokenizer = new Tokenizer()
tokenizer: org.apache.spark.ml.feature.Tokenizer = tok_75559f60e8cf 

可以通过获取器和设置器来定制转换器的行为。获取可用参数列表的最简单方法是通过调用.explainParams方法:

scala> println(tokenizer.explainParams)
inputCol: input column name (undefined)
outputCol: output column name (default: tok_75559f60e8cf__output)

我们看到,可以使用两个参数来定制Tokenizer实例的行为:inputColoutputCol,分别描述包含输入(要分词的字符串)和输出(单词数组)的列的标题。我们可以使用setInputColsetOutputCol方法设置这些参数。

我们将inputCol设置为"text",因为在我们的训练和测试 DataFrame 中,该列被命名为text。我们将outputCol设置为"words"

scala> tokenizer.setInputCol("text").setOutputCol("words")
org.apache.spark.ml.feature.Tokenizer = tok_75559f60e8cf

在适当的时候,我们将tokenizer集成到管道中,但现在,让我们只使用它来转换训练 DataFrame,以验证其是否正确工作。

scala> val tokenizedDF = tokenizer.transform(trainDF)
tokenizedDF: DataFrame = [fileName: string, text: string, category: string, words: array<string>]

scala> tokenizedDF.show
+--------------+----------------+--------+--------------------+
|      fileName|            text|category|               words|
+--------------+----------------+--------+--------------------+
|file:/Users...|Subject: auto...|    spam|[subject:, auto, ...|
|file:/Users...|Subject: want...|    spam|[subject:, want, ...|
|file:/Users...|Subject: n't ...|    spam|[subject:, n't, m...|
|file:/Users...|Subject: amaz...|    spam|[subject:, amaze,...|
|file:/Users...|Subject: help...|    spam|[subject:, help, ...|
|file:/Users...|Subject: beat...|    spam|[subject:, beat, ...|
|...                                                          |
+--------------+----------------+--------+--------------------+

tokenizer转换器生成一个新的 DataFrame,其中包含一个额外的列words,包含text列中的单词数组。

显然,我们可以使用我们的tokenizer来转换具有正确模式的任何 DataFrame。例如,我们可以将其用于测试集。机器学习的许多方面都涉及到在不同的数据集上调用相同的(或非常相似的)管道。通过提供管道抽象,MLlib 简化了由许多清理、转换和建模组件组成的复杂机器学习算法的推理。

我们管道中的下一步是计算每个消息中每个单词的出现频率。我们最终将使用这些频率作为算法中的特征。我们将使用HashingTF转换器将单词数组转换为每个消息的词频向量。

HashingTF转换器从输入可迭代对象构建一个词频的稀疏向量。单词数组中的每个元素都被转换为一个哈希码。这个哈希码被截断为一个介于0和输出向量中元素总数大数 n之间的值。词频向量仅仅是截断哈希的出现次数。

让我们手动运行一个示例来理解它是如何工作的。我们将计算Array("the", "dog", "jumped", "over", "the")的词频向量。在这个例子中,我们将稀疏输出向量中的元素数量n设为 16。第一步是计算数组中每个元素的哈希码。我们可以使用内置的##方法,该方法为任何对象计算哈希码:

scala> val words = Array("the", "dog", "jumped", "over", "the")
words: Array[String] = Array(the, dog, jumped, over, the)

scala> val hashCodes = words.map { _.## }
hashCodes: Array[Int] = Array(114801, 99644, -1148867251, 3423444, 114801)

为了将哈希码转换为有效的向量索引,我们将每个哈希码对向量的大小(在这种情况下为16)取模:

scala> val indices = hashCodes.map { code => Math.abs(code % 16) }
indices: Array[Int] = Array(1, 12, 3, 4, 1)

然后,我们可以创建一个从索引到该索引出现次数的映射:

scala> val indexFrequency = indices.groupBy(identity).mapValues {
 _.size.toDouble
}
indexFrequency: Map[Int,Double] = Map(4 -> 1.0, 1 -> 2.0, 3 -> 1.0, 12 -> 1.0)

最后,我们可以将此映射转换为稀疏向量,其中向量的每个元素的值是此特定索引出现的频率:

scala> import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.linalg._

scala> val termFrequencies = Vectors.sparse(16, indexFrequency.toSeq)
termFrequencies: linalg.Vector = (16,[1,3,4,12],[2.0,1.0,1.0,1.0])

注意,稀疏向量的.toString输出由三个元素组成:向量的总大小,后跟两个列表:第一个是索引系列,第二个是这些索引处的值系列。

使用稀疏向量提供了一种紧凑且高效的方式来表示消息中单词出现的频率,这正是HashingTF在底层的工作方式。缺点是,从单词到索引的映射不一定唯一:通过向量的长度截断哈希码将不同的字符串映射到相同的索引。这被称为碰撞。解决方案是使n足够大,以最小化碰撞的频率。

小贴士

HashingTF类似于构建一个哈希表(例如,Scala 映射),其键是单词,值是单词在消息中出现的次数,但有一个重要区别:它不试图处理哈希冲突。因此,如果两个单词映射到相同的哈希,它们将具有错误的频率。使用此算法而不是仅构建哈希表有两个优点:

  • 我们不需要在内存中维护一个不同单词的列表。

  • 每封电子邮件都可以独立于其他所有电子邮件转换为向量:我们不需要在不同的分区上执行降维操作来获取映射中的键集。这极大地简化了将此算法应用于分布式环境中的每封电子邮件,因为我们可以在每个分区上独立应用HashingTF转换。

主要缺点是我们必须使用能够高效利用稀疏表示的机器学习算法。这种情况适用于逻辑回归,我们将在下面使用。

如您所预期的那样,HashingTF转换器接受输入和输出列作为参数。它还接受一个参数,定义向量中不同的哈希桶的数量。增加桶的数量会减少冲突的数量。在实践中,建议的值在TransformersTransformers之间。

scala> val hashingTF = (new HashingTF()
 .setInputCol("words")
 .setOutputCol("features")
 .setNumFeatures(1048576))
hashingTF: org.apache.spark.ml.feature.HashingTF = hashingTF_3b78eca9595c

scala> val hashedDF = hashingTF.transform(tokenizedDF)
hashedDF: DataFrame = [fileName: string, text: string, category: string, words: array<string>, features: vector]

scala> hashedDF.select("features").show
+--------------------+
|            features|
+--------------------+
|(1048576,[0,33,36...|
|(1048576,[0,36,40...|
|(1048576,[0,33,34...|
|(1048576,[0,33,36...|
|(1048576,[0,33,34...|
|(1048576,[0,33,34...|
+--------------------+

features列中的每个元素都是一个稀疏向量:

scala> import org.apache.spark.sql.Row
import org.apache.spark.sql.Row

scala> val firstRow = hashedDF.select("features").first
firstRow: org.apache.spark.sql.Row = ...

scala> val Row(v:Vector) = firstRow
v: Vector = (1048576,[0,33,36,37,...],[1.0,3.0,4.0,1.0,...])

因此,我们可以将我们的向量解释为:哈希到元素33的单词出现三次,哈希到元素36的单词出现四次等等。

估计器

现在,我们已经为逻辑回归准备好了特征。在运行逻辑回归之前,最后一步是创建目标变量。我们将 DataFrame 中的category列转换为二进制 0/1 目标列。Spark 提供了一个StringIndexer类,该类将列中的字符串集替换为双精度浮点数。StringIndexer不是一个转换器:它必须首先与一组类别拟合以计算从字符串到数值值的映射。这引入了管道 API 中的第二类组件:估计器

与“开箱即用”的转换器不同,估计器必须与 DataFrame 拟合。对于我们的字符串索引器,拟合过程包括获取唯一字符串列表("spam""ham")并将每个这些映射到双精度浮点数。拟合过程输出一个转换器,该转换器可以用于后续的 DataFrames。

scala> val indexer = (new StringIndexer()
 .setInputCol("category")
 .setOutputCol("label"))
indexer: org.apache.spark.ml.feature.StringIndexer = strIdx_16db03fd0546

scala> val indexTransform = indexer.fit(trainDF)
indexTransform: StringIndexerModel = strIdx_16db03fd0546

拟合过程产生的转换器有一个labels属性,描述了它应用的映射:

scala> indexTransform.labels
Array[String] = Array(ham, spam)

每个标签都将映射到数组中的索引:因此,我们的转换器将ham映射到0,将spam映射到1

scala> val labelledDF = indexTransform.transform(hashedDF)
labelledDF: org.apache.spark.sql.DataFrame = [fileName: string, text: string, category: string, words: array<string>, features: vector, label: double]

scala> labelledDF.select("category", "label").distinct.show
+--------+-----+
|category|label|
+--------+-----+
|     ham|  0.0|
|    spam|  1.0|
+--------+-----+

现在我们有了适合逻辑回归的正确格式的特征向量和分类标签。执行逻辑回归的组件是一个估计器:它被拟合到一个训练 DataFrame 中,以创建一个训练好的模型。然后,可以使用该模型来转换测试 DataFrame。

scala> import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.classification.LogisticRegression

scala> val classifier = new LogisticRegression().setMaxIter(50)
classifier: LogisticRegression = logreg_a5e921e7c1a1 

LogisticRegression估计器默认期望特征列命名为"features",标签列(目标)命名为"label"。没有必要明确设置这些,因为它们与hashingTFindexer设置的列名匹配。有几个参数可以设置以控制逻辑回归的工作方式:

scala> println(classifier.explainParams)
elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. (default: 0.0)
fitIntercept: whether to fit an intercept term (default: true)
labelCol: label column name (default: label)
maxIter: maximum number of iterations (>= 0) (default: 100, current: 50)
regParam: regularization parameter (>= 0) (default: 0.0)
threshold: threshold in binary classification prediction, in range [0, 1] (default: 0.5)
tol: the convergence tolerance for iterative algorithms (default: 1.0E-6)
...

现在,我们只设置maxIter参数。稍后我们将研究其他参数的影响,例如正则化。现在,让我们将分类器拟合到labelledDF

scala> val trainedClassifier = classifier.fit(labelledDF)
trainedClassifier: LogisticRegressionModel = logreg_353d18f6a5f0

这产生了一个转换器,我们可以将其应用于具有features列的 DataFrame。转换器附加了一个prediction列和一个probability列。例如,我们可以使用trainedClassifierlabelledDF(训练集本身)转换:

scala> val labelledDFWithPredictions = trainedClassifier.transform(
 labelledDF)
labelledDFWithPredictions: DataFrame = fileName: string, ...

scala> labelledDFWithPredictions.select($"label", $"prediction").show
+-----+----------+
|label|prediction|
+-----+----------+
|  1.0|       1.0|
|  1.0|       1.0|
|  1.0|       1.0|
|  1.0|       1.0|
|  1.0|       1.0|
|  1.0|       1.0|
|  1.0|       1.0|
|  1.0|       1.0|
+-----+----------+

检查我们模型性能的一个快速方法是仅计算误分类消息的数量:

scala> labelledDFWithPredictions.filter { 
 $"label" !== $"prediction" 
}.count
Long = 1

在这种情况下,逻辑回归成功地将训练集中除了一条消息外的所有消息正确分类。考虑到特征数量众多,以及垃圾邮件和合法电子邮件中使用的单词之间的相对清晰界限,这也许并不令人惊讶。

当然,对模型的真实测试不是它在训练集上的表现,而是在测试集上的表现。为了测试这一点,我们可以将测试 DataFrame 通过与我们用于训练模型的相同阶段,用它们产生的拟合转换器替换估计器。MLlib 提供了管道抽象来简化这个过程:我们将有序列表的转换器和估计器包装在管道中。然后,这个管道被拟合到一个对应于训练集的 DataFrame 中。拟合产生一个PipelineModel实例,相当于管道,但估计器被转换器替换,如图所示:

![估计器

让我们构建我们的逻辑回归垃圾邮件过滤器的管道:

scala> import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.Pipeline

scala> val pipeline = new Pipeline().setStages(
 Array(indexer, tokenizer, hashingTF, classifier)
)
pipeline: Pipeline = pipeline_7488113e284d

一旦定义了管道,我们就将其拟合到包含训练集的 DataFrame 中:

scala> val fittedPipeline = pipeline.fit(trainDF)
fittedPipeline: org.apache.spark.ml.PipelineModel = pipeline_089525c6f100

当将管道拟合到 DataFrame 时,估计器和转换器被处理得不同:

  • 转换器被应用到 DataFrame 中,并直接复制到管道模型中。

  • 估计器被拟合到 DataFrame 中,生成一个转换器。然后,转换器被应用到 DataFrame 上,并附加到管道模型中。

我们现在可以将管道模型应用到测试集上:

scala> val testDFWithPredictions = fittedPipeline.transform(testDF)
testDFWithPredictions: DataFrame = fileName: string, ...

这在 DataFrame 中添加了一个prediction列,其中包含我们的逻辑回归模型的预测结果。为了衡量我们算法的性能,我们在测试集上计算分类错误:

scala> testDFWithPredictions.filter { 
 $"label" !== $"prediction" 
}.count
Long = 20

因此,我们的朴素逻辑回归算法,没有模型选择或正则化,将 2.3%的电子邮件误分类。当然,由于训练集和测试集的划分是随机的,你可能会得到略微不同的结果。

让我们将包含预测的培训和测试 DataFrame 保存为parquet文件:

scala> import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.SaveMode

scala> (labelledDFWithPredictions
 .select("fileName", "label", "prediction", "probability")
 .write.mode(SaveMode.Overwrite)
 .parquet("transformedTrain.parquet"))

scala> (testDFWithPredictions
 .select("fileName", "label", "prediction", "probability")
 .write.mode(SaveMode.Overwrite)
 .parquet("transformedTest.parquet"))

小贴士

在垃圾邮件分类中,误报比漏报要严重得多:将合法邮件误判为垃圾邮件,比让垃圾邮件通过要糟糕得多。为了解决这个问题,我们可以提高分类的阈值:只有得分达到 0.7 或以上的邮件才会被分类为垃圾邮件。这引发了选择正确阈值的问题。一种方法是对不同阈值在测试集上产生的误报率进行调查,并选择最低的阈值以获得可接受的误报率。可视化这一点的良好方法是使用 ROC 曲线,我们将在下一节中探讨。

评估

不幸的是,截至版本 1.5.2,管道 API 中评估模型质量的功能仍然有限。逻辑回归确实输出一个包含多个评估指标(通过训练模型的summary属性可用)的摘要,但这些是在训练集上计算的。通常,我们希望在训练集和单独的测试集上评估模型性能。因此,我们将深入到底层的 MLlib 层以访问评估指标。

MLlib 提供了一个模块,org.apache.spark.mllib.evaluation,其中包含一系列用于评估模型质量的类。在这里,我们将使用BinaryClassificationMetrics类,因为垃圾邮件分类是一个二分类问题。其他评估类为多分类模型、回归模型和排序模型提供指标。

如前所述,我们将在 shell 中阐述这些概念,但您将在本章代码示例中的ROC.scala脚本中找到类似代码。我们将使用breeze-viz来绘制曲线,因此,在启动 shell 时,我们必须确保相关的库在类路径上。我们将使用 SBT assembly,如第十章[分布式批处理与 Spark 中所述,分布式批处理与 Spark(特别是构建和运行独立程序部分),来创建一个包含所需依赖项的 JAR 文件。然后我们将这个 JAR 文件传递给 Spark shell,这样我们就可以导入 breeze-viz。让我们编写一个build.sbt文件,声明对 breeze-viz 的依赖:

// build.sbt
name := "spam_filter"

scalaVersion := "2.10.5"

libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-core" % "1.5.2" % "provided",
  "org.apache.spark" %% "spark-mllib" % "1.5.2" % "provided",
  "org.scalanlp" %% "breeze" % "0.11.2",
  "org.scalanlp" %% "breeze-viz" % "0.11.2",
  "org.scalanlp" %% "breeze-natives" % "0.11.2"
)

使用以下命令将依赖项打包到 jar 中:

$ sbt assembly

这将在target/scala-2.10目录下创建一个名为spam_filter-assembly-0.1-SNAPSHOT.jar的 JAR 文件。要将此 JAR 文件包含在 Spark shell 中,请使用--jars命令行参数重新启动 shell:

$ spark-shell --jars=target/scala-2.10/spam_filter-assembly-0.1-SNAPSHOT.jar

为了验证打包是否成功,尝试导入breeze.plot

scala> import breeze.plot._
import breeze.plot._

让我们加载测试集,包括预测,我们在上一节中创建并保存为parquet文件:

scala> val testDFWithPredictions = sqlContext.read.parquet(
 "transformedTest.parquet")
testDFWithPredictions: org.apache.spark.sql.DataFrame = [fileName: string, label: double, prediction: double, probability: vector]

BinaryClassificationMetrics对象期望一个RDD[(Double, Double)]对象,其中包含一对分数(分类器分配给特定电子邮件是垃圾邮件的概率)和标签(电子邮件是否实际上是垃圾邮件)。我们可以从我们的 DataFrame 中提取这个 RDD:

scala> import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.linalg.Vector

scala> import org.apache.spark.sql.Row
import org.apache.spark.sql.Row

scala> val scoresLabels = testDFWithPredictions.select(
 "probability", "label").map {
 case Row(probability:Vector, label:Double) => 
 (probability(1), label)
}
org.apache.spark.rdd.RDD[(Double, Double)] = MapPartitionsRDD[3] at map at <console>:23

scala> scoresLabels.take(5).foreach(println)
(0.9999999967713409,1.0)
(0.9999983827108793,1.0)
(0.9982059900606365,1.0)
(0.9999790713978142,1.0)
(0.9999999999999272,1.0)

我们现在可以构建BinaryClassificationMetrics实例:

scala> import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import mllib.evaluation.BinaryClassificationMetrics

scala> val bm = new BinaryClassificationMetrics(scoresLabels)
bm: BinaryClassificationMetrics = mllib.evaluation.BinaryClassificationMetrics@254ed9ba

BinaryClassificationMetrics 对象包含许多用于评估分类模型性能的有用指标。我们将探讨接收者操作****特征ROC)曲线。

提示

ROC 曲线

想象逐渐降低,从 1.0 开始,我们假设特定电子邮件是垃圾邮件的概率阈值。显然,当阈值设置为 1.0 时,没有电子邮件会被分类为垃圾邮件。这意味着不会有假阳性(我们错误地将正常邮件分类为垃圾邮件),但也意味着不会有真阳性(我们正确地将垃圾邮件识别为垃圾邮件):所有垃圾邮件都会被错误地识别为正常邮件。

随着我们逐渐降低我们假设特定电子邮件是垃圾邮件的概率阈值,我们的垃圾邮件过滤器,希望如此,将开始识别大量电子邮件为垃圾邮件。其中绝大多数,如果我们的算法设计得很好,将是真正的垃圾邮件。因此,我们的真正例率增加。随着我们逐渐降低阈值,我们开始将我们不太确定的邮件分类为垃圾邮件。这将增加正确识别为垃圾邮件的邮件数量,但也会增加误报的数量。

ROC 曲线图对于每个阈值值,绘制真正例率与假正例率的比率。在最佳情况下,曲线始终为 1:这发生在所有垃圾邮件消息都被赋予 1.0 分,而所有正常邮件都被赋予 0.0 分时。相比之下,最坏的情况发生在曲线为对角线 P(真正例) = P(假正例) 时,这发生在我们的算法不如随机时。通常,ROC 曲线位于两者之间,形成一个位于对角线之上的凸壳。这个壳越深,我们的算法就越好。

评估

(左)对于一个明显优于随机性能的模型的 ROC 曲线:曲线在低误报率下达到非常高的真正例率。

(中间)对于一个显著优于随机性能的模型的 ROC 曲线。

(右)对于一个仅略优于随机性能的模型:对于任何给定的阈值,真正例率仅略高于假正例率,这意味着近一半的示例被错误分类。

我们可以使用 BinaryClassificationMetrics 实例上的 .roc 方法计算 ROC 曲线上的点数组。这返回一个 RDD[(Double, Double)],包含每个阈值值的 (假正例真正例) 比率。我们可以将其收集为数组:

scala> val rocArray = bm.roc.collect
rocArray: Array[(Double, Double)] = Array((0.0,0.0), (0.0,0.16793893129770993), ...

当然,一个数字数组并不很有启发性,所以让我们用 breeze-viz 绘制 ROC 曲线。我们首先将我们的配对数组转换为两个数组,一个为假正例,一个为真正例:

scala> val falsePositives = rocArray.map { _._1 }
falsePositives: Array[Double] = Array(0.0, 0.0, 0.0, 0.0, 0.0, ...

scala> val truePositives = rocArray.map { _._2 }
truePositives: Array[Double] = Array(0.0, 0.16793893129770993, 0.19083969465...

让我们绘制这两个数组:

scala> import breeze.plot._
import breeze.plot.

scala> val f = Figure()
f: breeze.plot.Figure = breeze.plot.Figure@3aa746cd

scala> val p = f.subplot(0)
p: breeze.plot.Plot = breeze.plot.Plot@5ed1438a

scala> p += plot(falsePositives, truePositives)
p += plot(falsePositives, truePositives)

scala> p.xlabel = "false positives"
p.xlabel: String = false positives

scala> p.ylabel = "true positives"
p.ylabel: String = true positives

scala> p.title = "ROC"
p.title: String = ROC

scala> f.refresh

ROC 曲线在 x 的一个较小值时达到 1.0:也就是说,我们以相对较少的误报为代价检索到所有真正例。为了更准确地可视化曲线,限制 x 轴的范围从 00.1 是有益的。

scala> p.xlim = (0.0, 0.1)
p.xlim: (Double, Double) = (0.0,0.1)

我们还需要告诉 breeze-viz 使用适当的刻度间隔,这需要深入到 breeze-viz 之下的 JFreeChart 层:

scala> import org.jfree.chart.axis.NumberTickUnit
import org.jfree.chart.axis.NumberTickUnit

scala> p.xaxis.setTickUnit(new NumberTickUnit(0.01))

scala> p.yaxis.setTickUnit(new NumberTickUnit(0.1))

我们现在可以保存这个图表:

scala> f.saveas("roc.png")

这将生成以下图表,存储在 roc.png 中:

评估

使用逻辑回归进行垃圾邮件分类的 ROC 曲线。注意,我们已经将假正例轴限制在 0.1

通过观察图表,我们看到我们可以过滤掉 85% 的垃圾邮件而没有单个 误报。当然,我们需要一个更大的测试集来真正验证这个假设。

图形有助于真正理解模型的行为。有时,我们只是想有一个衡量模型质量的单一指标。ROC 曲线下的面积可以是一个很好的这样的指标:

scala> bm.areaUnderROC
res21: Double = 0.9983061235861147

这可以解释如下:给定从测试集中随机抽取的两个消息,其中一个为垃圾邮件,另一个为正常邮件,模型将垃圾邮件分配给垃圾邮件消息的似然性大于正常邮件消息的似然性的概率为 99.8%。

其他衡量模型质量的指标包括特定阈值下的精确度和召回率,或者 F1 分数。所有这些都可以通过BinaryClassificationMetrics实例提供。API 文档列出了可用的方法:spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

逻辑回归中的正则化

机器学习的一个危险是过拟合:算法不仅捕获了训练集中的信号,而且还捕获了由训练集有限大小产生的统计噪声。

在逻辑回归中减轻过拟合的一种方法是使用正则化:我们在优化时对参数的大值施加惩罚。我们可以通过向代价函数添加与参数幅度成比例的惩罚来实现这一点。形式上,我们将逻辑回归代价函数(在第二章,使用 Breeze 操作数据中描述)重新写为:

逻辑回归中的正则化

其中 逻辑回归中的正则化 是标准的逻辑回归代价函数:

逻辑回归中的正则化

在这里,params 是参数向量,逻辑回归中的正则化 是第 i 个训练示例的特征向量,而 逻辑回归中的正则化 是当第 i 个训练示例是垃圾邮件时为 1,否则为 0。这与第二章,使用 Breeze 操作数据中引入的逻辑回归代价函数相同,除了添加了正则化项 逻辑回归中的正则化 和参数向量的 逻辑回归中的正则化 范数。n 的最常见值是 2,在这种情况下 逻辑回归中的正则化 只是参数向量的幅度:

逻辑回归中的正则化

额外的正则化项驱动算法减小参数向量的幅度。在使用正则化时,特征必须具有可比的幅度。这通常通过归一化特征来实现。MLlib 提供的逻辑回归估计器默认情况下归一化所有特征。这可以通过setStandardization参数关闭。

Spark 有两个可以调整的超参数来控制正则化:

  • 正则化的类型,通过elasticNetParam参数设置。0 值表示逻辑回归中的正则化正则化。

  • 正则化的程度(成本函数中的逻辑回归中的正则化),通过regParam参数设置。正则化参数的高值表示强烈的正则化。一般来说,过拟合的危险越大,正则化参数应该越大。

让我们创建一个新的逻辑回归实例,该实例使用正则化:

scala> val lrWithRegularization = (new LogisticRegression()
 .setMaxIter(50))
lrWithRegularization: LogisticRegression = logreg_16b65b325526

scala> lrWithRegularization.setElasticNetParam(0) lrWithRegularization.type = logreg_1e3584a59b3a

为了选择逻辑回归中的正则化的适当值,我们将管道拟合到训练集,并计算测试集上逻辑回归中的正则化的几个值的分类误差。在章节的后面,我们将学习 MLlib 中的交叉验证,它提供了一种更严格的方法来选择超参数。

scala> val lambdas = Array(0.0, 1.0E-12, 1.0E-10, 1.0E-8)
lambdas: Array[Double] = Array(0.0, 1.0E-12, 1.0E-10, 1.0E-8)

scala> lambdas foreach { lambda =>
 lrWithRegularization.setRegParam(lambda)
 val pipeline = new Pipeline().setStages(
 Array(indexer, tokenizer, hashingTF, lrWithRegularization))
 val model = pipeline.fit(trainDF)
 val transformedTest = model.transform(testDF)
 val classificationError = transformedTest.filter { 
 $"prediction" !== $"label"
 }.count
 println(s"$lambda => $classificationError")
}
0 => 20
1.0E-12 => 20
1.0E-10 => 20
1.0E-8 => 23

对于我们的例子,我们看到任何尝试添加 L[2]正则化的尝试都会导致分类精度的下降。

交叉验证和模型选择

在上一个例子中,我们通过在训练时保留 30%的数据,并在该子集上进行测试来验证我们的方法。这种方法并不特别严格:确切的结果取决于随机的训练-测试分割。此外,如果我们想测试几个不同的超参数(或不同的模型)以选择最佳模型,我们可能会无意中选择的模型最能反映测试集中特定行,而不是整体人群。

这可以通过交叉验证来克服。我们已经在第四章中遇到了交叉验证,并行集合和未来。在那个章节中,我们使用了随机子样本交叉验证,其中我们随机创建训练-测试分割。

在本章中,我们将使用k 折交叉验证:我们将训练集分成k部分(其中,通常k103),使用k-1部分作为训练集,最后的部分作为测试集。重复k次训练/测试周期,每次保持不同的部分作为测试集。

交叉验证通常用于选择模型的最佳超参数集。为了说明选择合适的超参数,我们将回到我们的正则化逻辑回归示例。我们不会自己直觉超参数,而是选择给我们最佳交叉验证分数的超参数。

我们将探讨设置正则化类型(通过elasticNetParam)和正则化程度(通过regParam)。找到一个好的参数值的一个粗略但有效的方法是执行网格搜索:我们计算正则化参数感兴趣值对的交叉验证分数。

我们可以使用 MLlib 的ParamGridBuilder构建参数网格。

scala> import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}

scala> val paramGridBuilder = new ParamGridBuilder()
paramGridBuilder: ParamGridBuilder = ParamGridBuilder@1dd694d0

要将优化超参数添加到网格中,我们使用addGrid方法:

scala> val lambdas = Array(0.0, 1.0E-12, 1.0E-10, 1.0E-8)
Array[Double] = Array(0.0, 1.0E-12, 1.0E-10, 1.0E-8)

scala> val elasticNetParams = Array(0.0, 1.0)
elasticNetParams: Array[Double] = Array(0.0, 1.0)

scala> paramGridBuilder.addGrid(
 lrWithRegularization.regParam, lambdas).addGrid(
 lrWithRegularization.elasticNetParam, elasticNetParams)
paramGridBuilder.type = ParamGridBuilder@1dd694d0

一旦添加了所有维度,我们只需在构建器上调用build方法来构建网格:

scala> val paramGrid = paramGridBuilder.build
paramGrid: Array[org.apache.spark.ml.param.ParamMap] =
Array({
 logreg_f7dfb27bed7d-elasticNetParam: 0.0,
 logreg_f7dfb27bed7d-regParam: 0.0
}, {
 logreg_f7dfb27bed7d-elasticNetParam: 1.0,
 logreg_f7dfb27bed7d-regParam: 0.0
} ...)

scala> paramGrid.length
Int = 8

如我们所见,网格只是一个参数集的一维数组,在拟合逻辑回归模型之前传递给模型。

设置交叉验证管道的下一步是定义一个用于比较模型性能的指标。在本章的早期,我们看到了如何使用BinaryClassificationMetrics来估计模型的质量。不幸的是,BinaryClassificationMetrics类是核心 MLLib API 的一部分,而不是新的管道 API,因此它(不容易)兼容。管道 API 提供了一个BinaryClassificationEvaluator类。这个类直接在 DataFrame 上工作,因此非常适合管道 API 流程:

scala> import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

scala> val evaluator = new BinaryClassificationEvaluator()
evaluator: BinaryClassificationEvaluator = binEval_64b08538f1a2

scala> println(evaluator.explainParams)
labelCol: label column name (default: label)
metricName: metric name in evaluation (areaUnderROC|areaUnderPR) (default: areaUnderROC)
rawPredictionCol: raw prediction (a.k.a. confidence) column name (default: rawPrediction)

从参数列表中,我们看到BinaryClassificationEvaluator类支持两个指标:ROC 曲线下的面积和精确率-召回率曲线下的面积。它期望输入一个包含label列(模型真实值)和rawPrediction列(包含电子邮件是垃圾邮件或正常邮件的概率的列)的 DataFrame。

我们现在拥有了运行交叉验证所需的所有参数。我们首先构建管道,然后将管道、评估器和要运行交叉验证的参数数组传递给CrossValidator的一个实例:

scala> val pipeline = new Pipeline().setStages(Array(indexer, tokenizer, hashingTF, lrWithRegularization))
pipeline: Pipeline = pipeline_3ed29f72a4cc

scala> val crossval = (new CrossValidator()
 .setEstimator(pipeline)
 .setEvaluator(evaluator)
 .setEstimatorParamMaps(paramGrid)
 .setNumFolds(3))
crossval: CrossValidator = cv_5ebfa1143a9d 

我们现在将crossval拟合到trainDF

scala> val cvModel = crossval.fit(trainDF)
cvModel: CrossValidatorModel = cv_5ebfa1143a9d

这一步可能需要相当长的时间(在单台机器上可能超过一小时)。这创建了一个对应于具有最佳参数表示trainDF的逻辑回归对象的 transformer,cvModel。我们可以用它来预测测试 DataFrame 上的分类错误:

scala> cvModel.transform(testDF).filter { 
 $"prediction" !== $"label" 
}.count
Long = 20

因此,交叉验证产生了一个与原始、无超参数的朴素逻辑回归模型表现相同的模型。cvModel还包含参数网格中每组的评估分数列表:

scala> cvModel.avgMetrics
Array[Double] = Array(0.996427805316161, ...)

将其与超参数相关联的最简单方法是将它与cvModel.getEstimatorParamMaps一起压缩。这给我们一个(超参数值,交叉验证分数)对的列表:

scala> val params2score = cvModel.getEstimatorParamMaps.zip(
 cvModel.avgMetrics)
Array[(ml.param.ParamMap,Double)] = Array(({
 logreg_8f107aabb304-elasticNetParam: 0.0,
 logreg_8f107aabb304-regParam: 0.0
},0.996427805316161),...

scala> params2score.foreach {
 case (params, score) => 
 val lambda = params(lrWithRegularization.regParam)
 val elasticNetParam = params(
 lrWithRegularization.elasticNetParam)
 val l2Orl1 = if(elasticNetParam == 0.0) "L2" else "L1"
 println(s"$l2Orl1, $lambda => $score")
}
L2, 0.0 => 0.996427805316161
L1, 0.0 => 0.996427805316161
L2, 1.0E-12 => 0.9964278053175655
L1, 1.0E-12 => 0.9961429402772803
L2, 1.0E-10 => 0.9964382546369551
L1, 1.0E-10 => 0.9962223090037103
L2, 1.0E-8 => 0.9964159754613495
L1, 1.0E-8 => 0.9891008277659763

最佳的超参数集对应于 L[2]正则化,正则化参数为1E-10,尽管这仅对应于 AUC 的微小提升。

这完成了我们的垃圾邮件过滤器示例。我们已经成功地为这个特定的 Ling-Spam 数据集训练了一个垃圾邮件过滤器。为了获得更好的结果,可以尝试更好的特征提取:我们可以移除停用词或使用 TF-IDF 向量,而不是仅使用词频向量作为特征,我们还可以添加额外的特征,如消息长度,甚至n-grams。我们还可以尝试非线性算法,如随机森林。所有这些步骤都很容易添加到管道中。

逻辑回归之外

我们在本章中专注于逻辑回归,但 MLlib 提供了许多其他算法,这些算法可以更有效地捕捉数据中的非线性。管道 API 的一致性使得尝试不同的算法并查看它们的性能变得容易。管道 API 提供了用于分类的决策树、随机森林和梯度提升树,以及一个简单的前馈神经网络,这仍然是实验性的。它还提供了 Lasso 和岭回归以及用于回归的决策树,以及用于降维的 PCA。

更低级别的 MLlib API 还提供了降维的主成分分析,包括k-means 和潜在狄利克雷分配在内的几种聚类方法,以及使用交替最小二乘法的推荐系统。

摘要

MLlib 直面设计可扩展机器学习算法的挑战。在本章中,我们用它来训练一个简单的可扩展垃圾邮件过滤器。MLlib 是一个庞大且快速发展的库。了解它能提供什么最好的方式是尝试将你使用其他库(如 scikit-learn)编写的代码移植过来。

在下一章中,我们将探讨如何构建 Web API 和交互式可视化,以便与世界分享我们的结果。

参考文献

最佳的参考资料是在线文档,包括:

《Spark 高级分析》,由Sandy RyzaUri LasersonSean OwenJosh Wills所著,提供了 Spark 机器学习的详细和最新介绍。

有几本书比我们在这里介绍的更详细地介绍了机器学习。我们在这本书中多次提到了《统计学习的要素》,由FriedmanTibshiraniHastie所著。这是目前可用的最完整的机器学习数学基础介绍之一。

安德鲁·纳格(Andrew Ng)的机器学习课程在www.coursera.org/提供了机器学习的良好介绍。它使用 Octave/MATLAB 作为编程语言,但应该可以轻松地适应 Breeze 和 Scala。

第十三章。使用 Play 的网络 API

在本书的前 12 章中,我们介绍了任何想要构建数据科学应用程序的人的基本工具和库:我们学习了如何与 SQL 和 MongoDB 数据库交互,如何使用 Spark 构建快速批处理应用程序,如何使用 MLlib 应用最先进的机器学习算法,以及如何在 Akka 中构建模块化并发应用程序。

在本书的最后几章中,我们将扩展讨论,探讨一个网络框架:Play。你可能会 wonder 为什么网络框架会出现在一本数据科学书中;当然,这样的主题最好留给软件工程师或网络开发者。然而,数据科学家很少存在于真空之中。他们经常需要将结果或见解传达给利益相关者。对于一个精通统计学的某人来说,ROC 曲线可能很有说服力,但对于技术不那么精通的人来说,它可能没有那么大的分量。事实上,当见解伴随着引人入胜的视觉呈现时,销售见解可能会更容易。

许多现代交互式数据可视化应用程序都是运行在网页浏览器中的网络应用程序。通常,这些应用程序涉及D3.js,这是一个用于构建数据驱动网页的 JavaScript 库。在本章和下一章中,我们将探讨如何将 D3 与 Scala 集成。

编写一个网络应用程序是一项复杂的任务。我们将把这个任务分成本章和下一章来讨论。在本章中,我们将学习如何编写一个 REST API,我们可以将其用作应用程序的后端,或者直接查询。在下一章中,我们将探讨如何将前端代码与 Play 集成,以查询后端暴露的 API 并使用 D3 进行展示。在本章中,我们假设你对 HTTP 有至少基本的了解:你应该至少阅读过第七章,网络 API

许多数据科学家或希望成为数据科学家的人可能不太熟悉网络技术的内部工作原理。学习如何构建复杂的网站或网络 API 可能会令人望而却步。因此,本章从对动态网站和网络应用程序架构的一般讨论开始。如果你已经熟悉服务器端编程和网络框架,你可以轻松地跳过前几节。

客户端-服务器应用程序

网站通过两台计算机之间的交互来工作:客户端和服务器。如果你在网页浏览器中输入www.github.com/pbugnion/s4ds/graphs,你的浏览器会查询 GitHub 服务器之一。服务器会在其数据库中查找有关你感兴趣的仓库的信息。它将以 HTML、CSS 和 JavaScript 的形式将此信息提供给你的电脑。然后,你的浏览器负责以正确的方式解释这个响应。

如果你查看相关的 URL,你会注意到该页面上有几个图表。即使断开互联网连接,你仍然可以与这些图表进行交互。所有与图表交互所需的信息,在加载该网页时,都已经以 JavaScript 的形式传输。当你与图表互动时,使这些变化发生的 CPU 周期是在你的电脑上消耗的,而不是 GitHub 服务器。代码是在客户端执行的。相反,当你请求有关新仓库的信息时,该请求由 GitHub 服务器处理。这被称为服务器端处理。

可以在服务器上使用像 Play 这样的网络框架。对于客户端代码,我们只能使用客户端浏览器能理解的编程语言:HTML 用于布局,CSS 用于样式,JavaScript 用于逻辑,或者可以编译成 JavaScript 的语言。

网络框架简介

本节简要介绍了现代网络应用的设计方式。如果你已经熟悉编写后端代码,可以跳过这一部分。

大体上,网络框架是一套用于构建网络应用的工具和代码库。为了理解网络框架提供的内容,让我们退一步思考,如果你没有网络框架,你需要做什么。

你想要编写一个程序,监听 80 端口,并将 HTML(或 JSON 或 XML)发送回请求它的客户端。如果你要向每个客户端发送相同的文件,这很简单:只需在启动服务器时从文件中加载 HTML,并将其发送给请求它的客户端。

到目前为止,一切顺利。但如果你现在想根据客户端请求自定义 HTML,你会怎么做?你可能选择根据客户端在其浏览器中输入的 URL 的一部分,或者根据 HTTP 请求中的特定元素来做出不同的响应。例如,amazon.com上的产品页面与支付页面不同。你需要编写代码来解析 URL 和请求,然后将请求路由到相应的处理器。

现在,您可能希望根据请求的具体元素动态定制返回的 HTML。每个产品在amazon.com的页面都有相同的轮廓,但具体元素是不同的。为每个产品存储整个 HTML 内容将是浪费的。更好的方法是存储每个产品的详细信息到数据库中,并在客户端请求该产品的信息时将其注入到 HTML 模板中。您可以使用模板处理器来完成这项工作。当然,编写一个好的模板处理器是困难的。

您可能部署了您的 Web 框架,并意识到它无法处理指向它的流量。您决定响应客户端请求的处理程序应该异步运行。现在您必须处理并发。

Web 框架本质上提供了将一切连接在一起的“电线”。除了捆绑 HTTP 服务器外,大多数框架还会有一个路由器,它会根据 URL 自动将请求路由到正确的处理器。在大多数情况下,处理器将异步运行,这为您提供了更好的可扩展性。许多框架都有一个模板处理器,允许您直观地编写 HTML(有时是 JSON 或 XML)模板。一些 Web 框架还提供访问数据库、解析 JSON 或 XML、制定 HTTP 请求以及本地化和国际化的功能。

模型-视图-控制器架构

许多 Web 框架强加了程序架构:在不做出一些关于这些组件是什么的假设的情况下,很难提供将不同组件连接在一起的“电线”。模型-视图-控制器MVC)架构在 Web 上特别受欢迎,并且是 Play 框架所假设的架构。让我们依次看看每个组件:

  • 模型是应用程序背后的数据。例如,我期望 GitHub 背后的应用程序有用户、存储库、组织、拉取请求等模型。在 Play 框架中,模型通常是案例类的实例。模型的核心责任是记住应用程序的当前状态。

  • 视图是模型或一组模型在屏幕上的表示。

  • 控制器处理客户端交互,可能会更改模型。例如,如果您在 GitHub 上为项目添加星标,控制器将更新相关的模型。控制器通常携带很少的应用程序状态:记住事情是模型的工作。模型-视图-控制器架构

    MVC 架构:应用程序的状态由模型提供。视图向用户提供模型的视觉表示,控制器处理逻辑:当用户按下按钮或提交表单时应该做什么。

MVC 框架工作得很好,因为它将用户界面与底层数据和结构分离,并结构化了动作的流程:控制器可以更新模型状态或视图,模型可以向视图发送信号,告诉它更新,而视图只是显示这些信息。模型不携带与用户界面相关的任何信息。这种关注点的分离导致了对信息流的更容易的心理模型、更好的封装和更高的可测试性。

单页应用程序

客户端-服务器二元性给优雅的 MVC 架构增加了一层复杂性。模型应该放在哪里?控制器呢?传统上,模型和控制器几乎完全运行在服务器上,服务器只是将相关的 HTML 视图推送到客户端。

客户端 JavaScript 框架的增长,如 AngularJS,导致了将更多代码放入客户端的逐渐转变。控制器和模型的临时版本通常都在客户端运行。服务器仅作为 Web API 运行:例如,如果用户更新了模型,控制器将向服务器发送一个 HTTP 请求,通知它变化。

然后,将运行在服务器端和客户端的程序视为两个独立的应用程序是有意义的:服务器在数据库中持久化数据,例如,并提供一个程序接口来访问这些数据,通常是通过返回 JSON 或 XML 数据的 Web 服务。客户端程序维护自己的模型和控制器,并在需要新的模型或需要通知服务器模型持久视图应该更改时轮询服务器。

极端情况下,这导致了单页应用程序。在单页应用程序中,客户端第一次从服务器请求页面时,他会收到构建整个应用程序框架所需的 HTML 和 JavaScript。如果客户端需要从服务器获取更多数据,他将轮询服务器的 API。这些数据以 JSON 或 XML 的形式返回。

这在抽象上可能看起来有些复杂,所以让我们思考一下亚马逊网站可能作为单页应用程序的结构。我们在这里只关注产品页面,因为那已经足够复杂了。让我们想象一下,你正在主页上,点击了一个特定产品的链接。运行在你电脑上的应用程序知道如何显示产品,例如通过 HTML 模板。JavaScript 也有一个模型的原型,例如:

{
    product_id: undefined,
    product_name: undefined,
    product_price: undefined,
    ...
}

目前缺少的是关于你刚刚选择的产品应该在这些字段中放入哪些数据的知识:当网站加载时,信息不可能发送到你的电脑,因为没有方法知道你可能点击的产品(发送关于每个产品的信息将成本过高)。因此,Amazon 客户端向服务器发送关于该产品的信息请求。Amazon 服务器以 JSON 对象(或可能是 XML)的形式回复。然后客户端使用该信息更新其模型。当更新完成后,将触发一个事件来更新视图:

单页应用程序

单页应用程序中的客户端-服务器通信:当客户端首次访问网站时,它会接收到包含应用程序全部逻辑的 HTML、CSS 和 JavaScript 文件。从那时起,客户端只有在请求额外数据时才将服务器用作 API。运行在用户浏览器中的应用程序和运行在服务器上的应用程序几乎是独立的。唯一的耦合是通过服务器暴露的 API 结构。

构建应用程序

在本章和下一章中,我们将构建一个依赖于用 Play 编写的 API 的单页应用程序。我们将构建一个看起来像这样的网页:

构建应用程序

用户输入 GitHub 上某人的名字,可以查看他们的仓库列表以及一个总结他们使用语言的图表。您可以在 app.scala4datascience.com 上找到已部署的应用程序。不妨试一试。

要一窥其内部结构,请输入 app.scala4datascience.com/api/repos/odersky。这将返回一个类似以下的 JSON 对象:

[{"name":"dotty","language":"Scala","is_fork":true,"size":14653},
{"name":"frontend","language":"JavaScript","is_fork":true,"size":392},
{"name":"legacy-svn-scala","language":"Scala","is_fork":true,"size":296706},
...

我们将在本章构建 API,并在下一章编写前端代码。

Play 框架

Play 框架是在 Akka 之上构建的 Web 框架。它在业界有着可靠的记录,因此是构建可扩展 Web 应用程序的一个可靠选择。

Play 是一个有明确立场的 Web 框架:它期望你遵循 MVC 架构,并且对应该使用哪些工具有着强烈的看法。它自带 JSON 和 XML 解析器,自带访问外部 API 的工具,以及关于如何访问数据库的建议。

与我们在本书中开发的命令行脚本相比,Web 应用程序要复杂得多,因为它们包含更多的组件:后端代码、路由信息、HTML 模板、JavaScript 文件、图片等等。Play 框架对你的项目目录结构有很强的假设。从头开始构建这个结构既无聊又容易出错。幸运的是,我们可以使用Typesafe activators来启动项目(你也可以从github.com/pbugnion/s4ds的 Git 仓库下载代码,但我鼓励你从基本的 activator 结构开始项目,并使用完成的版本作为示例,边学边做)。

Typesafe activator 是 SBT 的一个定制版本,它包含模板,可以帮助 Scala 程序员快速启动。要安装 activator,你可以从www.typesafe.com/activator/download下载一个 JAR 文件,或者在 Mac OS 上通过 homebrew:

$ brew install typesafe-activator

你可以从终端启动 activator 控制台。如果你下载了 activator:

$ ./path/to/activator/activator new

或者,如果你是通过 Homebrew 安装的:

$ activator new

这将在当前目录中启动一个新的项目。它首先会询问你想要从哪个模板开始。选择 play-scala。然后它会询问你的应用程序的名称。我选择了 ghub-display,但你可以发挥创意!

让我们探索新创建的项目结构(我只保留了最重要的文件):

├── app
│   ├── controllers
│   │   └── Application.scala
│   └── views
│       ├── main.scala.html
│       └── index.scala.html
├── build.sbt
├── conf
│   ├── application.conf
│   └── routes
├── project
│   ├── build.properties
│   └── plugins.sbt
├── public
│   ├── images
│   │   └── favicon.png
│   ├── javascripts
│   │   └── hello.js
│   └── stylesheets
│       └── main.css
└── test
 ├── ApplicationSpec.scala
 └── IntegrationSpec.scala

让我们运行应用程序:

$ ./activator
[ghub-display] $ run

打开浏览器并导航到 URL 127.0.0.1:9000/。页面可能需要几秒钟才能加载。一旦加载完成,你应该看到一个默认页面,上面写着您的应用程序已准备就绪

在我们修改任何内容之前,让我们了解一下这个过程。当你要求浏览器带你到 127.0.0.1:9000/ 时,你的浏览器会向监听该地址的服务器发送一个 HTTP 请求(在这个例子中,是 Play 框架捆绑的 Netty 服务器)。请求是一个针对路由 / 的 GET 请求。Play 框架会在 conf/routes 中查找是否有满足 / 的路由:

$ cat conf/routes
# Home page
GET     /                           controllers.Application.index
...

我们可以看到,conf/routes 文件确实包含了针对 GET 请求的路由 /。该行的第二部分,controllers.Application.index,是处理该路由的 Scala 函数的名称(稍后我们会详细讨论)。让我们进行实验。将路由端点更改为 /hello。刷新浏览器而不更改 URL。这将触发应用程序的重新编译。你现在应该看到一个错误页面:

Play 框架

错误页面告诉你,应用程序不再有针对路由 / 的操作。如果你导航到 127.0.0.1:9000/hello,你应该再次看到着陆页面。

除了学习一点路由的工作原理外,我们还了解了关于开发 Play 应用程序的两大要点:

  • 在开发模式下,当你刷新浏览器并且有代码变更时,代码会被重新编译

  • 编译时和运行时错误会传播到网页

让我们将路由改回 /。关于路由还有很多要说的,但我们可以等到我们开始构建应用程序时再说。

conf/routes 文件告诉 Play 框架使用 controllers.Application.index 方法来处理对 / 的请求。让我们看看 app/controllers 中的 Application.scala 文件,其中定义了 index 方法:

// app/controllers/Application.scala
package controllers

import play.api._
import play.api.mvc._

class Application extends Controller {

  def index = Action {
    Ok(views.html.index("Your new application is ready."))
  }

}

我们可以看到 controllers.Application.index 指的是 Application 类中的 index 方法。这个方法返回类型为 ActionAction 只是一个将 HTTP 请求映射到响应的函数。在详细解释之前,让我们将操作更改为:

def index = Action {
  Ok("hello, world")
}

刷新浏览器,你应该会看到登录页面被替换为 "hello world"。通过让我们的操作返回 Ok("hello, world"),我们是在请求 Play 返回一个状态码为 200 的 HTTP 响应(表示请求成功)和正文 "hello world"

让我们回到 index 的原始内容:

Action {
  Ok(views.html.index("Your new application is ready.")) 
}

我们可以看到这调用了 views.html.index 方法。这看起来可能有些奇怪,因为任何地方都没有 views 包。但是,如果你查看 app/views 目录,你会注意到两个文件:index.scala.htmlmain.scala.html。这些是模板,在编译时,它们被转换成 Scala 函数。让我们看看 main.scala.html

// app/views/main.scala.html
@(title: String)(content: Html)

<!DOCTYPE html>

<html lang="en">
    <head>
        <title>@title</title>
        <!-- not so important stuff -->
    </head>
    <body>
        @content
    </body>
</html>

在编译时,这个模板被编译为 views.html 包中的 main(title: String)(content: Html) 函数。请注意,函数的包和名称来自模板文件名,而函数参数来自模板的第一行。模板包含嵌入的 @title@content 值,这些值由传递给函数的参数填充。让我们在 Scala 控制台中实验一下:

$ activator console
scala> import views.html._
import views.html._

scala> val title = "hello"
title: String = hello

scala> val content = new play.twirl.api.Html("<b>World</b>")
content: play.twirl.api.Html = <b>World</b>

scala> main(title)(content)
res8: play.twirl.api.HtmlFormat.Appendable =
<!DOCTYPE html>

<html lang="en">
 <head>
 <title>hello</title>
 <!-- not so important stuff -->
 </head>
 <body>
 <b>World</b>
 </body>
</html>

我们可以调用 views.html.main,就像我们调用一个普通的 Scala 函数一样。我们传递的参数被嵌入到由 views/main.scala.html 中的模板定义的正确位置。

这就结束了我们对 Play 的入门之旅。让我们简要回顾一下我们学到了什么:当一个请求到达 Play 服务器时,服务器读取 URL 和 HTTP 动词,并检查这些是否存在于其 conf/routes 文件中。然后,它将请求传递给为该路由定义的控制器中的 Action。这个 Action 返回一个 HTTP 响应,该响应被反馈到浏览器。在构建响应时,Action 可能会使用模板,对于它来说,模板只是一个 (arguments list) => String(arguments list) => HTML 的函数。

动态路由

如我们所见,路由是将 HTTP 请求映射到 Scala 处理器。路由存储在 conf/routes 中。一个路由由一个 HTTP 动词、端点和一个 Scala 函数定义:

// verb   // end-point              // Scala handler
GET       /                         controllers.Application.index

我们学会了通过向routes文件中添加行来添加新路由。然而,我们并不局限于静态路由。Play 框架允许我们在路由中包含通配符。通配符的值可以作为参数传递给控制器。为了了解这是如何工作的,让我们创建一个以人的名字作为参数的控制器。在app.controllers中的Application对象中添加:

// app/controllers/Application.scala

class Application extends Controller {

  ...

  def hello(name:String) = Action {
    Ok(s"hello, $name")
  }
}

我们现在可以定义由该控制器处理的路由:

// conf/routes
GET  /hello/:name             controllers.Application.hello(name)

如果你现在将浏览器指向127.0.0.1:9000/hello/Jim,你将在屏幕上看到hello, Jim

任何在:和随后的/之间的字符串都被视为通配符:它将匹配任何字符组合。通配符的值可以传递给控制器。请注意,通配符可以出现在 URL 的任何位置,并且可以有多个通配符。以下都是有效的路由定义,例如:

GET /hello/person-:name        controllers.Application.hello(name)
// ... matches /hello/person-Jim

GET /hello/:name/picture  controllers.Application.pictureFor(name)
// ... matches /hello/Jim/picture

GET /hello/:first/:last controllers.Application.hello(first, last)
// ... matches /hello/john/doe

选择路由并将参数传递给控制器有许多其他选项。请参阅 Play 框架的文档,以全面讨论路由的可能性:www.playframework.com/documentation/2.4.x/ScalaRouting

提示

URL 设计

通常认为,将 URL 尽可能简化是最佳实践。URL 应反映网站信息的层次结构,而不是底层实现。GitHub 就是很好的例子:它的 URL 直观易懂。例如,这本书的仓库 URL 是:

github.com/pbugnion/s4ds

要访问该仓库的问题页面,请在路由中添加/issues。要访问第一个问题,请在该路由中添加/1。这些被称为语义 URL(en.wikipedia.org/wiki/Semantic_URL)。

动作

我们已经讨论了路由以及如何向控制器传递参数。现在让我们谈谈我们可以用控制器做什么。

路由中定义的方法必须返回一个play.api.mvc.Action实例。Action类型是Request[A] => Result类型的薄包装,其中Request[A]标识一个 HTTP 请求,Result是 HTTP 响应。

组合响应

如我们在第七章中看到的,HTTP 响应由以下组成:

  • 状态码(例如,成功响应的 200 或缺失页面的 404)

  • 响应头,一个表示与响应相关的元数据的键值列表

  • 响应正文。这可以是网页的 HTML,或 JSON、XML 或纯文本(或许多其他格式)。这通常是真正感兴趣的部分。

Play 框架定义了一个play.api.mvc.Result对象,它表示一个响应。该对象包含一个header属性,包含状态码和头信息,以及一个包含正文的body属性。

生成 Result 的最简单方法就是使用 play.api.mvc.Results 中的工厂方法之一。我们已经看到了 Ok 方法,它生成状态码为 200 的响应:

def hello(name:String) = Action {
  Ok("hello, $name")
}

让我们退一步,打开一个 Scala 控制台,以便我们理解它是如何工作的:

$ activator console
scala> import play.api.mvc._
import play.api.mvc._

scala> val res = Results.Ok("hello, world")
res: play.api.mvc.Result = Result(200, Map(Content-Type -> text/plain; charset=utf-8))

scala> res.header.status
Int = 200

scala> res.header.headers
Map[String,String] = Map(Content-Type -> text/plain; charset=utf-8)

scala> res.body
play.api.libs.iteratee.Enumerator[Array[Byte]] = play.api.libs.iteratee.Enumerator$$anon$18@5fb83873

我们可以看到 Results.Ok(...) 是如何创建一个状态为 200Result 对象,在这个例子中,它包含一个表示内容类型的单个头信息。体部分稍微复杂一些:它是一个枚举器,当需要时可以推送到输出流。枚举器包含传递给 Ok 的参数:在这个例子中是 "hello, world"

Results 中有许多用于返回不同状态码的工厂方法。其中一些更相关的如下:

  • Action { Results.NotFound }

  • Action { Results.BadRequest("bad request") }

  • Action { Results.InternalServerError("error") }

  • Action { Results.Forbidden }

  • Action { Results.Redirect("/home") }

要获取 Result 工厂方法的完整列表,请参阅 Results 的 API 文档 (www.playframework.com/documentation/2.4.x/api/scala/index.html#play.api.mvc.Results)。

到目前为止,我们一直限制自己只将字符串作为 Ok 结果的内容传递:Ok("hello, world")。然而,我们并不局限于传递字符串。我们可以传递一个 JSON 对象:

scala> import play.api.libs.json._
import play.api.libs.json._

scala> val jsonObj = Json.obj("hello" -> "world")
jsonObj: play.api.libs.json.JsObject = {"hello":"world"}

scala> Results.Ok(jsonObj)
play.api.mvc.Result = Result(200, Map(Content-Type -> application/json; charset=utf-8))

当我们开始构建 API 时,我们将更详细地介绍与 JSON 的交互。我们也可以传递 HTML 作为内容。这通常是在返回视图时的情况。

scala> val htmlObj = views.html.index("hello")
htmlObj: play.twirl.api.HtmlFormat.Appendable =

<!DOCTYPE html>

<html lang="en">
 <head>
...

scala> Results.Ok(htmlObj)
play.api.mvc.Result = Result(200, Map(Content-Type -> text/html; charset=utf-8))

注意 Content-Type 头是如何根据传递给 Ok 的内容类型设置的。Ok 工厂使用 Writeable 类型类将它的参数转换为响应体。因此,对于任何存在 Writeable 类型类的类型,都可以用作 Ok 的参数。如果你对类型类不熟悉,你可能想阅读第五章 使用类型类进行松耦合 的部分,Scala 和 SQL 通过 JDBC

理解和解析请求

我们现在知道了如何制定(基本的)响应。等式的另一半是 HTTP 请求。回想一下,Action 只是一个将 Request => Result 映射的函数。我们可以使用以下方式访问请求:

def hello(name:String) = Action { request => 
  ...
}

需要引用请求的一个原因是访问查询字符串中的参数。让我们修改我们之前写的 Hello, <name> 示例,使其可选地包含查询字符串中的标题。因此,一个 URL 可以格式化为 /hello/Jim?title=Drrequest 实例公开了 getQueryString 方法,用于访问查询字符串中的特定键。如果键存在于查询中,该方法返回 Some[String],否则返回 None。我们可以将我们的 hello 控制器重写为:

def hello(name:String) = Action { request =>
  val title = request.getQueryString("title")
  val titleString = title.map { _ + " " }.getOrElse("")
  Ok(s"Hello, $titleString$name")
}

通过在浏览器中访问 URL 127.0.0.1:9000/hello/Odersky?title=Dr 来尝试这个示例。浏览器应该显示 Hello, Dr Odersky

到目前为止,我们一直专注于 GET 请求。这些请求没有正文。其他类型的 HTTP 请求,最常见的是 POST 请求,确实包含正文。Play 允许用户在定义操作时传递 正文解析器。请求正文将通过正文解析器传递,它将将其从字节流转换为 Scala 类型。作为一个非常简单的例子,让我们定义一个新的路由,它接受 POST 请求:

POST      /hello            controllers.Application.helloPost

我们将预定义的 parse.text 正文解析器应用于传入的请求正文。这会将请求正文转换为字符串。helloPost 控制器看起来像:

def helloPost = Action(parse.text) { request =>
  Ok("Hello. You told me: " + request.body)
}

小贴士

在浏览器中轻松测试 POST 请求是不可能的。您可以使用 cURL。cURL 是一个用于发送 HTTP 请求的命令行实用程序。它在 Mac OS 上默认安装,并且应该可以通过 Linux 发行版的包管理器获得。以下示例将发送一个正文为 "I think that Scala is great" 的 POST 请求:

$ curl --data "I think that Scala is great" --header "Content-type:text/plain"  127.0.0.1:9000/hello

这将在终端打印以下行:

Hello. You told me: I think that Scala is great

有几种内置的正文解析器:

  • parse.file(new File("filename.txt")) 将正文保存到文件中。

  • parse.json 将正文解析为 JSON(我们将在下一节中了解更多关于与 JSON 交互的内容)。

  • parse.xml 将正文解析为 XML。

  • parse.urlFormEncoded 将解析由提交 HTML 表单返回的正文。request.body 属性是一个从 StringSeq[String] 的 Scala 映射,将每个表单元素映射到其值(们)。

要获取正文解析器的完整列表,最佳来源是 play.api.mvc.BodyParsers.parse 的 Scala API 文档,可在以下网址找到:www.playframework.com/documentation/2.5.x/api/scala/index.html#play.api.mvc.BodyParsers$parse$

与 JSON 交互

如我们在前面的章节中所发现的,JSON 正在成为通过 HTTP 通信结构化数据的默认语言。如果您开发一个 Web 应用程序或 Web API,您可能需要消费或发射 JSON,或者两者都要。

在 第七章,Web APIs 中,我们学习了如何通过 json4s 解析 JSON。Play 框架包括它自己的 JSON 解析器和发射器。幸运的是,它的行为与 json4s 非常相似。

让我们假设我们正在构建一个总结 GitHub 仓库信息的 API。当查询特定用户时,我们的 API 将输出一个 JSON 数组,列出该用户的仓库(类似于 GitHub API,但只包含部分字段)。

让我们先定义一个仓库的模型。在 Play 应用程序中,模型通常存储在 app/models 文件夹中的 models 包下:

// app/models/Repo.scala

package models

case class Repo (
  val name:String,
  val language:String,
  val isFork: Boolean,
  val size: Long
)

让我们在应用程序中添加一个路由,为特定用户提供仓库数组。在 conf/routes 中添加以下行:

// conf/routes
GET   /api/repos/:username       controllers.Api.repos(username)

现在我们来实现控制器的框架。我们将为我们的 API 创建一个新的控制器,暂时命名为 Api。目前,我们只是让控制器返回一些示例数据。代码如下(我们将在稍后解释细节):

// app/controllers/Api.scala
package controllers
import play.api._
import play.api.mvc._
import play.api.libs.json._

import models.Repo

class Api extends Controller {

  // Some dummy data.
  val data = ListRepo,
    Repo("frontend", "JavaScript", true, 392)
  )

  // Typeclass for converting Repo -> JSON
  implicit val writesRepos = new Writes[Repo] {
    def writes(repo:Repo) = Json.obj(
      "name" -> repo.name,
      "language" -> repo.language,
      "is_fork" -> repo.isFork,
      "size" -> repo.size
    )
  }

  // The controller
  def repos(username:String) = Action {

    val repoArray = Json.toJson(data) 
    // toJson(data) relies on existence of 
    // `Writes[List[Repo]]` type class in scope

    Ok(repoArray)
  }
}

如果你将你的网络浏览器指向 127.0.0.1:9000/api/repos/odersky,你现在应该看到以下 JSON 对象:

[{"name":"dotty","language":"Scala","is_fork":true,"size":14315},{"name":"frontend","language":"JavaScript","is_fork":true,"size":392}]

这段代码中唯一棘手的部分是将 Repo 转换为 JSON。我们在 data 上调用 Json.toJson,它是 List[Repo] 类型的实例。toJson 方法依赖于传递给它的类型 T 存在的类型类 Writes[T]

Play 框架广泛使用类型类来定义如何将模型转换为特定格式。回想一下,我们学习了如何在 SQL 和 MongoDB 的上下文中编写类型类。Play 框架的期望非常相似:为了使 Json.toJson 方法能够在 Repo 类型的实例上工作,必须有 Writes[Repo] 的实现可用,该实现指定了如何将 Repo 对象转换为 JSON。

在 Play 框架中,Writes[T] 类型类定义了一个单一的方法:

trait Writes[T] {
  def writes(obj:T):Json
}

Writes 方法为内置简单类型和集合已经内置到 Play 框架中,因此我们不需要担心定义 Writes[Boolean] 等。

Writes[Repo] 实例通常直接在控制器中定义,如果它仅用于该控制器,或者在 Repo 伴生对象中定义,这样它就可以在多个控制器中使用。为了简单起见,我们只是将其嵌入到控制器中。

注意类型类如何实现关注点的分离。模型仅定义了 Repo 类型,而没有附加任何行为。Writes[Repo] 类型类只知道如何将 Repo 实例转换为 JSON,但不知道它在什么上下文中被使用。最后,控制器只知道如何创建 JSON HTTP 响应。

恭喜你,你刚刚定义了一个返回 JSON 的 Web API!在下一节中,我们将学习如何从 GitHub Web API 获取数据,以避免不断返回相同的数组。

查询外部 API 和消费 JSON

到目前为止,我们已经学会了如何响应 /api/repos/:username 的请求,向用户提供一个示例 JSON 数组形式的仓库。在本节中,我们将用从 GitHub 下载的实际仓库数据替换示例数据。

在 第七章,Web APIs 中,我们学习了如何使用 Scala 的 Source.fromURL 方法以及 scalaj-http 查询 GitHub API。Play 框架实现自己的库以与外部 Web 服务交互应该不会让人感到惊讶。

让我们编辑 Api 控制器以从 GitHub 获取有关用户仓库的信息,而不是使用示例数据。当以用户名作为参数调用时,控制器将:

  1. 向 GitHub API 发送 GET 请求以获取该用户的仓库信息。

  2. 解释响应,将体从 JSON 对象转换为 List[Repo]

  3. List[Repo] 转换为 JSON 数组,形成响应。

我们首先给出完整的代码列表,然后再详细解释更复杂的部分:

// app/controllers/Api.scala

package controllers

import play.api._
import play.api.mvc._
import play.api.libs.ws.WS // query external APIs
import play.api.Play.current
import play.api.libs.json._ // parsing JSON
import play.api.libs.functional.syntax._
import play.api.libs.concurrent.Execution.Implicits.defaultContext

import models.Repo

class Api extends Controller {

  // type class for Repo -> Json conversion
  implicit val writesRepo = new Writes[Repo] {
    def writes(repo:Repo) = Json.obj(
      "name" -> repo.name,
      "language" -> repo.language,
      "is_fork" -> repo.isFork,
      "size" -> repo.size
    )
  }

  // type class for Github Json -> Repo conversion
  implicit val readsRepoFromGithub:Reads[Repo] = (
    (JsPath \ "name").read[String] and
    (JsPath \ "language").read[String] and
    (JsPath \ "fork").read[Boolean] and
    (JsPath \ "size").read[Long]
  )(Repo.apply _)

  // controller
  def repos(username:String) = Action.async {

    // GitHub URL
    val url = s"https://api.github.com/users/$username/repos"
    val response = WS.url(url).get() // compose get request

    // "response" is a Future
    response.map { r =>
      // executed when the request completes
      if (r.status == 200) {

        // extract a list of repos from the response body
        val reposOpt = Json.parse(r.body).validate[List[Repo]]
        reposOpt match {
          // if the extraction was successful:
          case JsSuccess(repos, _) => Ok(Json.toJson(repos))

          // If there was an error during the extraction
          case _ => InternalServerError
        }
      }
      else {
        // GitHub returned something other than 200
        NotFound
      }

    }
  }

}

如果您已经编写了所有这些,请将浏览器指向例如 127.0.0.1:9000/api/repos/odersky 来查看 Martin Odersky 拥有的仓库列表:

[{"name":"dotty","language":"Scala","is_fork":true,"size":14653},{"name":"frontend","language":"JavaScript","is_fork":true,"size":392},...

这个代码示例内容很多,所以让我们将其分解。

调用外部网络服务

查询外部 API 的第一步是导入 WS 对象,它定义了创建 HTTP 请求的工厂方法。这些工厂方法依赖于命名空间中隐含的 Play 应用程序的引用。确保这种情况的最简单方法是导入 play.api.Play.current,这是对当前应用程序的引用。

现在让我们忽略 readsRepoFromGithub 类型类,直接跳到控制器主体。我们想要通过 GET 请求访问的 URL 是 "https://api.github.com/users/$username/repos",其中 $username 是适当的值。我们使用 WS.url(url).get() 创建一个 GET 请求。我们还可以向现有请求添加头信息。例如,为了指定内容类型,我们可以这样写:

WS.url(url).withHeaders("Content-Type" -> "application/json").get()

我们可以使用头信息通过以下方式传递 GitHub OAuth 令牌:

val token = "2502761d..."
WS.url(url).withHeaders("Authorization" -> s"token $token").get()

要形成 POST 请求而不是 GET 请求,将最后的 .get() 替换为 .post(data)。在这里,data 可以是 JSON、XML 或字符串。

添加 .get.post 触发请求,返回一个 Future[WSResponse]。到现在为止,您应该熟悉 futures。通过编写 response.map { r => ... },我们指定在 future 结果返回时要执行的可转换操作,该操作验证响应的状态,如果响应的状态码不是 200,则返回 NotFound

解析 JSON

如果状态码是 200,回调将解析响应体为 JSON,并将解析后的 JSON 转换为 List[Repo] 实例。我们已经知道如何使用 Writes[Repo] 类型类将 Repo 对象转换为 JSON。反过来,从 JSON 到 Repo 对象的转换要复杂一些,因为我们必须考虑格式不正确的 JSON。为此,Play 框架在 JSON 对象上提供了 .validate[T] 方法。此方法尝试将 JSON 转换为类型 T 的实例,如果 JSON 格式良好,则返回 JsSuccess,否则返回 JsError(类似于 Scala 的 Try 对象)。.validate 方法依赖于类型类 Reads[Repo] 的存在。让我们在 Scala 控制台中实验一下:

$ activator console

scala> import play.api.libs.json._
import play.api.libs.json._

scala> val s = """ 
 { "name": "dotty", "size": 150, "language": "Scala", "fork": true }
"""
s: String = "
 { "name": "dotty", "size": 150, "language": "Scala", "fork": true }
"

scala> val parsedJson = Json.parse(s)
parsedJson: play.api.libs.json.JsValue = {"name":"dotty","size":150,"language":"Scala","fork":true}

使用 Json.parse 将字符串转换为 JsValue 实例,它是 JSON 实例的超类型。我们可以使用类似 XPath 的语法访问 parsedJson 中的特定字段(如果您不熟悉类似 XPath 的语法,您可能想阅读第六章,Slick – 一个 SQL 的函数式接口):

scala> parsedJson \ "name"
play.api.libs.json.JsLookupResult = JsDefined("dotty")

XPath 类似的查找返回一个类型为 JsLookupResult 的实例。它包含两个值:如果路径有效,则为 JsDefined,如果无效,则为 JsUndefined

scala> parsedJson \ "age"
play.api.libs.json.JsLookupResult = JsUndefined('age' is undefined on object: {"name":"dotty","size":150,"language":"Scala","fork":true})

要以类型安全的方式将 JsLookupResult 实例转换为 String,我们可以使用 .validate[String] 方法:

scala> (parsedJson \ "name").validate[String]
play.api.libs.json.JsResult[String] = JsSuccess(dotty,) 

.validate[T] 方法返回 JsSuccess,如果 JsDefined 实例可以被成功转换为 T,否则返回 JsError。为了说明后者,让我们尝试将其验证为 Int

scala> (parsedJson \ "name").validate[Int]
dplay.api.libs.json.JsResult[Int] = JsError(List((,List(ValidationError(List(error.expected.jsnumber),WrappedArray())))))

JsUndefined 实例上调用 .validate 也会返回 JsError

scala> (parsedJson \ "age").validate[Int]
play.api.libs.json.JsResult[Int] = JsError(List((,List(ValidationError(List('age' is undefined on object: {"name":"dotty","size":150,"language":"Scala","fork":true}),WrappedArray())))))

要将 JsResult[T] 实例转换为类型 T 的实例,我们可以使用模式匹配:

scala> val name = (parsedJson \ "name").validate[String] match {
 case JsSuccess(n, _) => n
 case JsError(e) => throw new IllegalStateException(
 s"Error extracting name: $e")
}
name: String = dotty

我们现在可以使用 .validate 以类型安全的方式将 JSON 转换为简单类型。但在代码示例中,我们使用了 .validate[Repo]。这只有在命名空间中隐式可用 Reads[Repo] 类型类时才有效。

定义 Reads[T] 类型类最常见的方式是通过在 import play.api.libs.functional.syntax._ 中提供的 DSL(领域特定语言)。该 DSL 通过链式操作返回 JsSuccessJsError 来工作。具体讨论这个 DSL 的工作原理超出了本章的范围(例如,可以参考 Play 框架关于 JSON 组合器的文档页面:www.playframework.com/documentation/2.4.x/ScalaJsonCombinators)。我们将坚持讨论语法。

scala> import play.api.libs.functional.syntax._
import play.api.libs.functional.syntax._

scala> import models.Repo
import models.Repo

scala> implicit val readsRepoFromGithub:Reads[Repo] = (
 (JsPath \ "name").read[String] and
 (JsPath \ "language").read[String] and
 (JsPath \ "fork").read[Boolean] and
 (JsPath \ "size").read[Long]
)(Repo.apply _)
readsRepoFromGithub: play.api.libs.json.Reads[models.Repo] = play.api.libs.json.Reads$$anon$8@a198ddb

Reads 类型类分为两个阶段定义。第一阶段通过 read[T] 方法与 and 链接起来,结合成功和错误。第二阶段使用案例类(或 Tuple 实例)的伴生对象的 apply 方法来构建对象,前提是第一阶段成功完成。现在我们已经定义了类型类,我们可以在 JsValue 对象上调用 validate[Repo]

scala> val repoOpt = parsedJson.validate[Repo]
play.api.libs.json.JsResult[models.Repo] = JsSuccess(Repo(dotty,Scala,true,150),)

然后,我们可以使用模式匹配从 JsSuccess 实例中提取 Repo 对象:

scala> val JsSuccess(repo, _) = repoOpt
repo: models.Repo = Repo(dotty,Scala,true,150)

到目前为止,我们只讨论了验证单个仓库。Play 框架为集合类型定义了类型类,因此,如果 Reads[Repo] 被定义,Reads[List[Repo]] 也会被定义。

现在我们已经了解了如何从 JSON 中提取 Scala 对象,让我们回到代码。如果我们能够成功将仓库转换为 List[Repo],我们再次将其作为 JSON 发射。当然,将 GitHub 的仓库 JSON 表示转换为 Scala 对象,然后直接转换为我们的对象 JSON 表示可能看起来很复杂。然而,如果这是一个真实的应用程序,我们会有额外的逻辑。例如,我们可以将仓库存储在缓存中,并尝试从缓存中获取而不是查询 GitHub API。尽早将 JSON 转换为 Scala 对象可以解耦我们编写的代码与 GitHub 返回仓库的方式。

异步操作

代码示例中新增的最后部分是调用 Action.async,而不是仅仅 Action。回想一下,Action 实例是 Request => Result 方法的薄包装。然而,我们的代码返回一个 Future[Result],而不是 Result。在这种情况下,使用 Action.async 来构建动作,而不是直接使用 Action。使用 Action.async 告诉 Play 框架,创建 Action 的代码是异步的。

使用 Play 创建 API:总结

在最后一节中,我们部署了一个响应 GET 请求的 API。由于这需要很多理解,让我们总结一下如何进行 API 创建:

  1. /conf/routes 中定义适当的路由,根据需要使用 URL 中的通配符。

  2. /app/models 中创建 Scala 的案例类来表示 API 使用的模型。

  3. 创建 Write[T] 方法,以便将模型写入 JSON 或 XML,这样它们就可以通过 API 返回。

  4. 将路由绑定到控制器。如果控制器需要做更多的工作,将工作包装在 future 中以避免阻塞服务器。

Play 框架中还有许多你可能需要的更有用的组件,例如,例如如何使用 Slick 访问 SQL 数据库。不幸的是,我们没有时间在这篇介绍中涵盖这些内容。Play 框架有详尽、写得很好的文档,将填补这篇教程中的空白。

Rest API:最佳实践

随着互联网的成熟,REST(表示状态转移)API 正在成为网络 API 最可靠的设计模式。如果一个 API 遵循以下指导原则,则称为 RESTful

  • API 被设计为一组资源。例如,GitHub API 提供了关于用户、仓库、关注者等信息。每个用户或仓库都是一个特定的资源。每个资源都可以通过不同的 HTTP 端点进行访问。

  • 网址应当简洁明了,并能清楚地标识资源。例如,api.github.com/users/odersky 就很简单,清楚地告诉我们应该期待关于用户马丁·奥德斯基的信息。

  • 没有一个包含系统所有信息的 全局资源。相反,顶级资源包含指向更专业资源的链接。例如,GitHub API 中的用户资源包含指向该用户仓库和该用户关注者的链接,而不是直接在用户资源中嵌入所有这些信息。

  • API 应该是可发现的。对特定资源的请求的响应应包含相关资源的 URL。当你查询 GitHub 上的用户资源时,响应包含访问该用户关注者、仓库等的 URL。客户端应使用 API 提供的 URL,而不是尝试在客户端构建它们。这使客户端对 API 的变化更加稳健。

  • 应尽可能在服务器上保持最少的状态。例如,在查询 GitHub API 时,我们必须在每次请求中传递认证令牌,而不是期望我们的认证状态在服务器上被记住。使每次交互独立于历史记录提供了更好的可扩展性:如果任何交互可以由任何服务器处理,负载均衡就更容易实现。

摘要

在本章中,我们介绍了 Play 框架作为构建 Web API 的工具。我们构建了一个返回用户 GitHub 仓库 JSON 数组的 API。在下一章中,我们将在此基础上构建 API,并构建一个单页应用程序来图形化表示这些数据。

参考资料

第十四章. 使用 D3 和 Play 框架进行可视化

在上一章中,我们学习了 Play 框架,一个 Scala 的 Web 框架。我们构建了一个返回描述用户 GitHub 仓库的 JSON 数组的 API。

在本章中,我们将构建一个完整的 Web 应用程序,显示一个表格和一个图表,描述用户的仓库。我们将学习如何将D3.js,一个用于构建数据驱动 Web 页面的 JavaScript 库,与 Play 框架集成。这将使你走上构建引人注目的交互式可视化之路,展示使用机器学习获得的结果。

本章假设你已经熟悉 HTML、CSS 和 JavaScript。我们将在本章末尾提供参考资料。你还应该阅读上一章。

GitHub 用户数据

我们将构建一个单页应用程序,其后端使用上一章开发的 API。该应用程序包含一个表单,用户可以输入 GitHub 账户的登录名。应用程序查询 API 以获取该用户的仓库列表,并在屏幕上以表格和饼图的形式显示,总结该用户使用的编程语言:

GitHub 用户数据

要查看应用程序的实时版本,请访问 app.scala4datascience.com

我需要后端吗?

在上一章中,我们学习了支撑互联网工作方式的客户端-服务器模型:当你在你浏览器中输入一个网站 URL 时,服务器会向你的浏览器提供 HTML、CSS 和 JavaScript,然后浏览器以适当的方式渲染它们。

这对你意味着什么?当构建 Web 应用时,你应考虑的第二个问题可能是你是否需要进行任何服务器端处理(在“这真的值得付出努力吗?”之后)。你能否仅仅创建一个带有一些 JavaScript 的 HTML 网页?

如果构建整个应用所需的数据足够小,你就可以不使用后端:通常只有几兆字节。如果你的应用更大,你需要一个后端来传输客户端当前需要的仅有的数据。令人惊讶的是,你通常可以在没有后端的情况下构建可视化:虽然数据科学通常习惯于处理 TB 级别的数据,但数据科学流程的目标通常是将这些庞大的数据集压缩成几个有意义的数字。

拥有后端还可以让你包含对客户端不可见的逻辑。如果你需要验证密码,显然你不能将执行该操作的代码发送到客户端计算机:它需要在看不见的地方,在服务器上完成。

如果你的应用足够小,且不需要进行任何服务器端处理,就停止阅读这一章,如果你需要的话,复习一下 JavaScript,现在先别考虑 Scala。不必担心构建后端会使你的生活更轻松。

然而,显然我们并没有为我们要构建的应用拥有这样的自由度:用户可以输入 GitHub 上任何人的名字。查找关于该用户的信息需要后端访问巨大的存储和查询能力(我们通过仅将请求转发到 GitHub API 并重新解释响应来模拟)。

通过 web-jars 使用 JavaScript 依赖项

开发 Web 应用的一个挑战是,我们需要编写两个近乎独立的程序:服务器端程序和客户端程序。这些程序通常需要不同的技术。特别是,对于任何非最简单应用,我们必须跟踪 JavaScript 库,并在构建过程中集成处理 JavaScript 代码(例如,进行压缩)。

Play 框架通过web-jars管理 JavaScript 依赖项。这些只是打包成 jar 文件的 JavaScript 库。它们部署在 Maven Central 上,这意味着我们只需将它们添加到我们的build.sbt文件中作为依赖项。对于这个应用,我们需要以下 JavaScript 库:

  • Require.js,一个用于编写模块化 JavaScript 的库

  • JQuery

  • Bootstrap

  • Underscore.js,一个添加了许多功能构造和客户端模板的库。

  • D3,一个图形绘图库

  • NVD3,一个基于 D3 构建的图形库

如果您计划编写本章提供的示例,最简单的方法是从上一章的代码开始(您可以从 GitHub 下载第十三章的代码,使用 Play 的 Web APIgithub.com/pbugnion/s4ds/tree/master/chap13)。从现在起,我们将假设这是一个起点。

让我们在build.sbt文件中包含对 web-jars 的依赖项:

libraryDependencies ++= Seq(
  "org.webjars" % "requirejs" % "2.1.22",
  "org.webjars" % "jquery" % "2.1.4",
  "org.webjars" % "underscorejs" % "1.8.3",
  "org.webjars" % "nvd3" % "1.8.1",
  "org.webjars" % "d3js" % "3.5.6",
  "org.webjars" % "bootstrap" % "3.3.6"
)

通过运行activator update来获取模块。一旦完成此操作,您将注意到target/web/public/main/lib中的 JavaScript 库。

迈向网络应用:HTML 模板

在上一章中,我们简要介绍了如何在 HTML 文件中交错 Scala 片段来构建 HTML 模板的方法。我们了解到模板被编译为 Scala 函数,并学习了如何从控制器中调用这些函数。

在单页应用中,控制浏览器实际显示的大多数逻辑位于客户端 JavaScript 中,而不是服务器端。服务器提供的页面包含基本的 HTML 框架。

让我们为我们的应用程序创建 HTML 布局。我们将将其保存在views/index.scala.html中。该模板将仅包含应用程序的布局,但不会包含任何关于任何用户仓库的信息。要获取这些信息,应用程序必须查询上一章开发的 API。该模板不接收任何参数,因为所有动态 HTML 生成都将发生在客户端。

我们使用 Bootstrap 网格布局来控制 HTML 布局。如果您不熟悉 Bootstrap 布局,请查阅getbootstrap.com/css/#grid-example-basic的文档。

// app/views/index.scala.html
<!DOCTYPE html>

<html lang="en">
  <head>
    <title>Github User display</title>
    <link rel="stylesheet" media="screen" 
      href="@routes.Assets.versioned("stylesheets/main.css")">
    <link rel="shortcut icon" type="image/png"
      href="@routes.Assets.versioned("images/favicon.png")">
    <link rel="stylesheet" media="screen" 
      href=@routes.Assets.versioned("lib/nvd3/nv.d3.css") >
    <link rel="stylesheet" media="screen"
      href=@routes.Assets.versioned(
      "lib/bootstrap/css/bootstrap.css")>
  </head>

  <body>
    <div class="container">

      <!-- Title row -->
      <div class="row">
        <h1>Github user search</h1>
      </div>

      <!-- User search row -->
      <div class="row">
        <label>Github user: </label>
        <input type="text" id="user-selection">
        <span id="searching-span"></span> <hr />
      </div>

      <!-- Results row -->
      <div id="response" class="row"></div>
    </div>
  </body>
</html>

在 HTML 头部,我们链接应用程序所需的 CSS 样式表。我们不是明确指定路径,而是使用@routes.Assets.versioned(...)函数。这解析为编译后资产存储位置的 URI。传递给函数的参数应该是从target/web/public/main到所需资产的路径。

当用户访问我们的服务器上的路由/时,我们希望提供此视图的编译版本。因此,我们需要将此路由添加到conf/routes中:

# conf/routes
GET   /      controllers.Application.index

路由由Application控制器中的index函数提供。此控制器需要做的只是提供index视图:

// app/controllers/Application.scala
package controllers

import play.api._
import play.api.mvc._

class Application extends Controller {

  def index = Action {
    Ok(views.html.index())
  }
}

通过在应用程序根目录中运行activator run来启动 Play 框架,并将您的网络浏览器指向127.0.0.1:9000/。您应该看到我们的网络应用程序框架。当然,应用程序目前还没有做任何事情,因为我们还没有编写任何 JavaScript 逻辑。

迈向网络应用:HTML 模板

通过 RequireJS 模块化 JavaScript

将 JavaScript 库注入命名空间的最简单方法是将它们添加到 HTML 框架中,通过 HTML 头部的<script>...</script>标签。例如,要添加 JQuery,我们会在文档的头部添加以下行:

<script src=@routes.Assets.versioned("lib/jquery/jquery.js") type="text/javascript"></script>

虽然这可行,但它不适合大型应用程序的扩展,因为每个库都被导入到全局命名空间中。现代客户端 JavaScript 框架,如 AngularJS,提供了一种定义和加载模块的替代方法,这有助于保持封装性。

我们将使用 RequireJS。简而言之,RequireJS 让我们可以通过函数封装 JavaScript 模块。例如,如果我们想编写一个包含用于隐藏div的函数的模块example,我们会按照以下方式定义该模块:

// example.js
define(["jquery", "underscore"], function($, _) {

  // hide a div
  function hide(div_name) {
    $(div_name).hide() ;
  }

  // what the module exports.
  return { "hide": hide }

}) ;

我们将我们的模块封装在一个名为define的函数的回调中。define函数接受两个参数:一个依赖项列表和一个函数定义。define函数将依赖项绑定到回调的参数列表:在这种情况下,JQuery 中的函数将被绑定到$,而 Underscore 中的函数将被绑定到_。这创建了一个模块,它暴露了回调函数返回的任何内容。在这种情况下,我们导出hide函数,并将其绑定到名称"hide"。因此,我们的示例模块暴露了hide函数。

要加载此模块,我们将它作为依赖项传递给我们要在其中使用它的模块:

define(["example"], function(example) {

  function hide_all() {
 example.hide("#top") ;
 example.hide("#bottom") ;
  }

  return { "hide_all": hide_all } ;
});

注意example中的函数是如何被封装的,而不是存在于全局命名空间中。我们通过example.<function-name>来调用它们。此外,在example模块内部定义的任何函数或变量都保持私有。

有时,我们希望 JavaScript 代码存在于模块之外。对于启动应用程序的脚本来说,这通常是情况。对于这些脚本,将define替换为require

require(["jquery", "example"], function($, example) {
  $(document).ready(function() {
    example.hide("#header") ;
  });
}) ;

现在我们已经对 RequireJS 有了概述,那么我们如何在 Play 框架中使用它?第一步是添加对 RequireJS web jar 的依赖,我们已经这样做了。Play 框架还添加了一个 RequireJS SBT 插件(github.com/sbt/sbt-rjs),如果您使用了play-scala激活器,则默认应该已安装。如果缺少此插件,可以在plugins.sbt文件中添加以下行:

// project/plugins.sbt

addSbtPlugin("com.typesafe.sbt" % "sbt-rjs" % "1.0.7")

我们还需要将插件添加到阶段列表中。这允许插件在打包应用程序为 jar 时操作 JavaScript 资源。将以下行添加到build.sbt文件中:

pipelineStages := Seq(rjs)

您需要重新启动激活器以使更改生效。

我们现在已准备好在我们的应用程序中使用 RequireJS。我们可以在视图的头部部分添加以下行来使用它:

// index.scala.html

<html>
  <head>
...

    <script
      type="text/javascript"
      src=@routes.Assets.versioned("lib/requirejs/require.js").url
      data-main=@routes.Assets.versioned("javascripts/main.js").url>
    </script>

  </head>
...
</html>

当视图被编译时,这会被解析为类似以下标签:

<script type="text/javascript" 
  data-main="/assets/javascripts/main.js" 
  src="img/require.min.js">
</script>

传递给 data-main 的参数是应用程序的入口点。当 RequireJS 加载时,它将执行 main.js。因此,该脚本必须引导我们的应用程序。特别是,它应该包含一个用于 RequireJS 的配置对象,使其知道所有库的位置。

引导应用程序

当我们将 require.js 链接到我们的应用程序时,我们告诉它使用 main.js 作为我们的入口点。为了测试这是否工作,让我们首先输入一个虚拟的 main.js。Play 应用程序中的 JavaScript 文件位于 /public/javascripts

// public/javascripts/main.js

require([], function() {
  console.log("hello, JavaScript"); 
});

为了验证这已经工作,转到 127.0.0.1:9000 并打开浏览器控制台。你应该在控制台中看到 "hello, JavaScript"

现在我们来编写一个更有用的 main.js。我们首先配置 RequireJS,给它提供我们将在应用程序中使用的模块的位置。不幸的是,我们使用的图形库 NVD3 与 RequireJS 不是很兼容,所以我们不得不使用一个丑陋的技巧来使其工作。这使我们的 main.js 文件变得有些复杂:

// public/javascripts/main.js

(function (requirejs) {
  'use strict';

  // -- RequireJS config --
  requirejs.config({
    // path to the web jars. These definitions allow us 
    // to use "jquery", rather than "../lib/jquery/jquery",
    // when defining module dependencies.
    paths: {
      "jquery": "../lib/jquery/jquery",
      "underscore": "../lib/underscorejs/underscore",
      "d3": "../lib/d3js/d3",
      "nvd3": "../lib/nvd3/nv.d3",
      "bootstrap": "../lib/bootstrap/js/bootstrap"
    },

    shim: {
      // hack to get nvd3 to work with requirejs.
      // see this so question:
      // http://stackoverflow.com/questions/13157704/how-to-integrate-d3-with-require-js#comment32647365_13171592        
      nvd3: {
        deps: ["d3.global"],
        exports: "nv"
      },
      bootstrap : { deps :['jquery'] }
    }

  }) ;
})(requirejs) ;

// hack to get nvd3 to work with requirejs.
// see this so question on Stack Overflow:
// http://stackoverflow.com/questions/13157704/how-to-integrate-d3-with-require-js#comment32647365_13171592
define("d3.global", ["d3"], function(d3global) {
  d3 = d3global;
});

require([], function() {
  // Our application
  console.log("hello, JavaScript");
}) ;

现在我们已经设置了配置,我们可以深入到应用程序的 JavaScript 部分。

客户端程序架构

基本思想很简单:用户在输入框中搜索 GitHub 上某人的名字。当他输入名字时,我们向本章 earlier 设计的 API 发送请求。当 API 的响应返回时,程序将响应绑定到模型,并发出一个事件通知模型已更改。视图监听此事件并从模型中刷新。

设计模型

让我们先定义客户端模型。模型保存有关当前显示的用户仓库的信息。它会在第一次搜索后填充。

// public/javascripts/model.js

define([], function(){
   return {
    ghubUser: "", // last name that was searched for
    exists: true, // does that person exist on github?
    repos: [] // list of repos
  } ;
});

要查看模型的填充值,请转到 app.scala4datascience.com 上的完整应用程序示例,在浏览器中打开一个 JavaScript 控制台,搜索应用程序中的用户(例如,odersky),然后在控制台中输入以下内容:

> require(["model"], function(model) { console.log(model) ; }) 
{ghubUser: "odersky", exists: true, repos: Array}

> require(["model"], function(model) { 
 console.log(model.repos[0]); 
})
{name: "dotty", language: "Scala", is_fork: true, size: 14653}

这些导入 "model" 模块,将其绑定到变量 model,然后打印信息到控制台。

事件总线

我们需要一个机制来通知视图当模型更新时,因为视图需要从新的模型中刷新。这通常通过 Web 应用程序中的 事件 来处理。jQuery 允许我们将回调绑定到特定事件。当该事件发生时,回调将被执行。

例如,要将回调绑定到事件 "custom-event",在 JavaScript 控制台中输入以下内容:

> $(window).on("custom-event", function() { 
 console.log("custom event received") ; 
});

我们可以使用以下方式触发事件:

> $(window).trigger("custom-event"); 
custom event received

在 jQuery 中,事件需要注册在 事件总线 上的 DOM 元素。在这个例子中,我们使用了 window DOM 元素作为我们的事件总线,但任何 jQuery 元素都可以。将事件定义集中到单个模块中是有帮助的。因此,我们将创建一个包含两个函数的 events 模块:trigger,用于触发一个事件(由一个字符串指定),和 on,用于将回调绑定到特定事件:

// public/javascripts/events.js

define(["jquery"], function($) {

  var bus = $(window) ; // widget to use as an event bus

  function trigger(eventType) {
    $(bus).trigger(eventType) ;
  }

  function on(eventType, f) {
    $(bus).on(eventType, f) ;
  }

  return {
    "trigger": trigger,
    "on": on
  } ;
});

我们现在可以使用 events 模块发出和接收事件。你可以在应用程序的实时版本(在 app.scala4datascience.com)的 JavaScript 控制台中测试这一点。让我们首先注册一个监听器:

> require(["events"], function(events) {
  // register event listener
  events.on("hello_event", function() {
    console.log("Received event") ;
  }) ;
}); 

如果我们现在触发事件 "hello_event",监听器将打印 "Received event"

> require(["events"], function(events) {
  // trigger the event
  events.trigger("hello_event") ;
}) ;

使用事件使我们能够将控制器与视图解耦。控制器不需要了解任何关于视图的信息,反之亦然。控制器只需要在模型更新时发出一个 "model_updated" 事件,而视图在接收到该事件时需要从模型刷新。

通过 jQuery 进行 AJAX 调用

现在我们可以编写我们应用程序的控制器。当用户在文本输入框中输入名称时,我们查询 API,更新模型并触发一个 model_updated 事件。

我们使用 jQuery 的 $.getJSON 函数来查询我们的 API。这个函数将其第一个参数作为 URL,第二个参数作为回调函数。API 调用是异步的:$.getJSON 在执行后立即返回。因此,所有请求处理都必须在回调中进行。如果请求成功,将调用回调,但我们可以定义始终调用或失败时调用的额外处理程序。让我们在浏览器控制台(无论是你自己的,如果你正在运行上一章开发的 API,还是 app.scala4datascience.com 上的)中尝试一下。回想一下,API 正在监听 /api/repos/:user 的端点:

> $.getJSON("/api/repos/odersky", function(data) { 
 console.log("API response:");
 console.log(data);
 console.log(data[0]); 
}) ;
{readyState: 1, getResponseHeader: function, ...}

API response:
[Object, Object, Object, Object, Object, ...]
{name: "dotty", language: "Scala", is_fork: true, size: 14653}

getJSON 立即返回。几秒钟后,API 响应,此时响应将通过回调传递。

回调仅在成功时执行。它接受 API 返回的 JSON 对象作为其参数。要绑定在 API 请求失败时执行的回调,请调用 getJSON 返回值的 .fail 方法:

> $.getJSON("/api/repos/junk123456", function(data) { 
 console.log("called on success"); 
}).fail(function() { 
 console.log("called on failure") ; 
}) ;
{readyState: 1, getResponseHeader: function, ...}

called on failure

我们还可以使用 getJSON 返回值的 .always 方法来指定无论 API 查询是否成功都执行的回调。

现在我们知道了如何使用 $.getJSON 查询我们的 API,我们可以编写控制器。控制器监听 #user-selection 输入字段的更改。当发生更改时,它向 API 发送 AJAX 请求以获取该用户的信息。它绑定一个回调,当 API 回复包含存储库列表时更新模型。我们将定义一个 controller 模块,该模块导出一个名为 initialize 的单个函数,该函数创建事件监听器:

// public/javascripts/controller.js
define(["jquery", "events", "model"], function($, events, model) {

  function initialize() {
    $("#user-selection").change(function() {

      var user = $("#user-selection").val() ;
      console.log("Fetching information for " + user) ;

      // Change cursor to a 'wait' symbol 
      // while we wait for the API to respond
      $("*").css({"cursor": "wait"}) ; 

      $.getJSON("/api/repos/" + user, function(data) {
        // Executed on success
        model.exists = true ;
        model.repos = data ;
      }).fail(function() {
        // Executed on failure
        model.exists = false ;
        model.repos = [] ;
      }).always(function() {
        // Always executed
        model.ghubUser = user ;

        // Restore cursor
        $("*").css({"cursor": "initial"}) ;

        // Tell the rest of the application 
        // that the model has been updated.
        events.trigger("model_updated") ;
      });
    }) ;
  } ;

  return { "initialize": initialize };

});

我们的控制模块仅公开 initialize 方法。一旦初始化完成,控制器将通过事件监听器与应用程序的其余部分交互。我们将在 main.js 中调用控制器的 initialize 方法。目前,该文件的最后一行只是一个空的 require 块。让我们导入我们的控制器并初始化它:

// public/javascripts/main.js

require(["controller"], function(controller) {
  controller.initialize();
});

为了测试这一点,我们可以将一个虚拟监听器绑定到 "model_updated" 事件。例如,我们可以使用以下片段将当前模型记录到浏览器 JavaScript 控制台中(您可以直接在 JavaScript 控制台中编写此片段):

> require(["events", "model"], 
function(events, model) {
  events.on("model_updated", function () { 
    console.log("model_updated event received"); 
    console.log(model); 
  });
}); 

如果你搜索一个用户,模型将被打印到控制台。我们现在有了控制器。最后一步是编写视图。

响应视图

如果请求失败,我们只需在响应 div 中显示 未找到。这部分代码编写起来最简单,所以让我们先做这个。我们定义一个 initialize 方法来生成视图。然后视图监听 "model_updated" 事件,该事件在控制器更新模型后触发。一旦初始化完成,与响应视图交互的唯一方式是通过 "model_updated" 事件:

// public/javascripts/responseView.js

define(["jquery", "model", "events"],
function($, model, events) {

  var failedResponseHtml = 
    "<div class='col-md-12'>Not found</div>" ;

  function initialize() {
    events.on("model_updated", function() {
      if (model.exists) {
        // success – we will fill this in later.
        console.log("model exists")
      }
      else {
        // failure – the user entered
        // is not a valid GitHub login 
        $("#response").html(failedResponseHtml) ;
      }
    }) ;
  }

  return { "initialize": initialize } ;

});

为了启动视图,我们必须从 main.js 中调用 initialize 函数。只需在 require 块中添加对 responseView 的依赖,并调用 responseView.initialize()。经过这些修改,main.js 中的最终 require 块如下:

// public/javascripts/main.js

require(["controller", "responseView"],
function(controller, responseView) {
  controller.initialize();
  responseView.initialize() ;
}) ;

你可以通过在用户输入中输入垃圾信息来故意导致 API 请求失败,以检查这一切是否正常工作。

当用户输入有效的 GitHub 登录名并且 API 返回一个仓库列表时,我们必须在屏幕上显示这些信息。我们显示一个表格和一个饼图,该饼图按语言聚合仓库大小。我们将定义饼图和表格在两个单独的模块中,分别称为 repoGraph.jsrepoTable.js。现在让我们假设这些模块存在,并且它们公开了一个接受 model 和要在其中出现的 div 名称的 build 方法。

让我们更新 responseView 的代码,以适应用户输入有效的 GitHub 用户名:

// public/javascripts/responseView.js

define(["jquery", "model", "events", "repoTable", "repoGraph"],
function($, model, events, repoTable, repoGraph) {

  // HTHML to inject when the model represents a valid user 
 var successfulResponseHtml = 
 "<div class='col-md-6' id='response-table'></div>" +
 "<div class='col-md-6' id='response-graph'></div>" ;

  // HTML to inject when the model is for a non-existent user
  var failedResponseHtml = 
    "<div class='col-md-12'>Not found</div>" ;

  function initialize() {
    events.on("model_updated", function() {
      if (model.exists) {
 $("#response").html(successfulResponseHtml) ;
 repoTable.build(model, "#response-table") ;
 repoGraph.build(model, "#response-graph") ;
      }
      else {
        $("#response").html(failedResponseHtml) ;
      }
    }) ;
  }

  return { "initialize": initialize } ;

});

让我们回顾一下在 API 调用成功时会发生什么。我们在 #response div 中注入以下 HTML 片段:

var successfulResponseHtml = 
  "<div class='col-md-6' id='response-table'></div>" +
  "<div class='col-md-6' id='response-graph'></div>" ;

这添加了两个 HTML div,一个用于仓库表,另一个用于图表。我们使用 Bootstrap 类来垂直分割响应 div。

现在我们将注意力转向表格视图,它需要暴露一个单独的 build 方法,正如前一小节所描述的。我们只需在 HTML 表格中显示仓库。我们将使用 Underscore 模板 来动态构建表格。Underscore 模板的工作方式与 Scala 中的字符串插值类似:我们定义一个带有占位符的模板。让我们在浏览器控制台中试一试:

> require(["underscore"], function(_) {
  var myTemplate = _.template(
    "Hello, <%= title %> <%= name %>!"
  ) ;
});

这创建了一个 myTemplate 函数,该函数接受具有 titlename 属性的对象:

> require(["underscore"], function(_) {
  var myTemplate = _.template( ... ); 
  var person = { title: "Dr.", name: "Odersky" } ;
  console.log(myTemplate(person)) ;
});

Underscore 模板因此提供了一种方便的机制,可以将对象格式化为字符串。我们将为表格中的每一行创建一个模板,并将每个仓库的模型传递给模板:

// public/javascripts/repoTable.js

define(["underscore", "jquery"], function(_, $) {

  // Underscore template for each row
  var rowTemplate = _.template("<tr>" +
    "<td><%= name %></td>" +
    "<td><%= language %></td>" +
    "<td><%= size %></td>" +
    "</tr>") ;

  // template for the table
  var repoTable = _.template(
    "<table id='repo-table' class='table'>" +
      "<thead>" +
        "<tr>" +
          "<th>Name</th><th>Language</th><th>Size</th>" +
        "</tr>" +
      "</thead>" +
      "<tbody>" +
        "<%= tbody %>" +
      "</tbody>" +
    "</table>") ;

  // Builds a table for a model
  function build(model, divName) {
    var tbody = "" ;
    _.each(model.repos, function(repo) {
      tbody += rowTemplate(repo) ;
    }) ;
    var table = repoTable({tbody: tbody}) ;
    $(divName).html(table) ;
  }

  return { "build": build } ;
}) ;

使用 NVD3 绘制图表

D3 是一个提供用于在 JavaScript 中构建交互式可视化的低级组件的库。通过提供低级组件,它为开发者提供了极大的灵活性。然而,学习曲线可能相当陡峭。在这个例子中,我们将使用 NVD3,这是一个为 D3 提供预制图表的库。这可以大大加快初始开发速度。我们将代码放置在文件 repoGraph.js 中,并公开一个名为 build 的单一方法,该方法接受一个模型和一个 div 作为参数,并在该 div 中绘制饼图。饼图将汇总所有用户仓库中的语言使用情况。

生成饼图的代码几乎与 NVD3 文档中给出的示例相同,该文档可在 nvd3.org/examples/pie.html 找到。传递给图表的数据必须作为一个对象数组可用。每个对象必须包含一个 label 字段和一个 size 字段。label 字段标识语言,而 size 字段是该用户用该语言编写的所有仓库的总大小。以下是一个有效的数据数组:

[ 
  { label: "Scala", size: 1234 },
  { label: "Python", size: 4567 }
]

要以这种格式获取数据,我们必须在我们的模型中聚合特定语言编写的仓库的大小。我们编写了 generateDataFromModel 函数来将模型中的 repos 数组转换为适合 NVD3 的数组。聚合的核心操作是通过调用 Underscore 的 groupBy 方法来按语言分组仓库。此方法与 Scala 的 groupBy 方法完全相同。考虑到这一点,generateDataFromModel 函数如下:

// public/javascripts/repoGraph.js

define(["underscore", "d3", "nvd3"], 
function(_, d3, nv) {

  // Aggregate the repo size by language.
  // Returns an array of objects like:
  // [ { label: "Scala", size: 1245}, 
  //   { label: "Python", size: 432 } ]
  function generateDataFromModel(model) {

    // Build an initial object mapping each
    // language to the repositories written in it
    var language2Repos = _.groupBy(model.repos, 
      function(repo) { return repo.language ; }) ;

    // Map each { "language":  [ list of repos ], ...} 
    // pairs to a single document { "language": totalSize }
    // where totalSize is the sum of the individual repos.
    var plotObjects = _.map(language2Repos, 
      function(repos, language) {
        var sizes = _.map(repos, function(repo) { 
          return repo.size; 
        });
        // Sum over the sizes using 'reduce'
        var totalSize = _.reduce(sizes, 
          function(memo, size) { return memo + size; },
        0) ;
        return { label: language, size: totalSize } ;
      }) ;

     return plotObjects;
  }

我们现在可以使用 NVD3 的 addGraph 方法来构建饼图:

  // Build the chart.
  function build(model, divName) {
    var transformedModel = generateDataFromModel(model) ;
    nv.addGraph(function() {

      var height = 350;
      var width = 350; 

      var chart = nv.models.pieChart()
        .x(function (d) { return d.label ; })
        .y(function (d) { return d.size ;})
        .width(width)
        .height(height) ;

      d3.select(divName).append("svg")
        .datum(transformedModel)
        .transition()
        .duration(350)
        .attr('width', width)
        .attr('height', height)
        .call(chart) ;

      return chart ;
    });
  }

  return { "build" : build } ;

});

这是我们应用程序的最后一个组件。将您的浏览器指向 127.0.0.1:9000,您应该会看到应用程序正在运行。

恭喜!我们已经构建了一个功能齐全的单页 Web 应用程序。

摘要

在本章中,我们学习了如何使用 Play 框架编写一个功能齐全的 Web 应用程序。恭喜你走到了这一步。构建 Web 应用程序可能会让许多数据科学家超出他们的舒适区,但了解足够的 Web 知识来构建基本的应用程序将允许你以引人入胜、吸引人的方式分享你的结果,同时也有助于与软件工程师和 Web 开发者进行沟通。

这就结束了我们对 Scala 库的快速浏览。在这本书的过程中,我们学习了如何使用 Breeze 高效地解决线性代数和优化问题,如何以函数式的方式在 SQL 数据库中插入和查询数据,以及如何与 Web API 交互以及如何创建它们。我们回顾了一些数据科学家用于编写并发或并行应用程序的工具,从并行集合和未来到 Spark 通过 Akka。我们看到了这些结构在 Scala 库中的普遍性,从 Play 框架中的未来到 Spark 的骨干 Akka。如果你已经读到这儿,给自己点个赞。

这本书对其涵盖的库进行了最简短的介绍,希望这足以让你尝到每个工具的用途,你可以用它完成什么,以及它在更广泛的 Scala 生态系统中的位置。如果你决定在你的数据科学流程中使用这些工具,你需要更详细地阅读文档,或者一本更完整的参考书。每章末尾列出的参考应该是一个良好的起点。

Scala 和数据科学都在快速发展。不要对某个特定的工具包或概念过于执着。保持对当前发展的关注,最重要的是,保持务实:找到适合工作的正确工具。Scala 和这里讨论的库通常会是那个工具,但并不总是:有时,一个 shell 命令或简短的 Python 脚本会更有效。记住,编程技能只是数据科学家知识体系的一个方面。即使你想在数据科学的工程方面专长,也要了解问题领域和机器学习数学基础。

最重要的是,如果你花时间阅读这本书,你很可能认为编程和数据科学不仅仅是工作。在 Scala 中编码可以令人满意和有成就感,所以享受乐趣,做得很棒!

参考文献

网上到处都是 HTML 和 CSS 教程。简单的 Google 搜索就能给你一个比我提供的任何参考列表都要好的资源了解。

Mike Bostock 的网站上有丰富的美丽 D3 可视化:bost.ocks.org/mike/。为了更好地了解 D3,我推荐Scott Murray 的《Web 交互数据可视化》

你也可以参考前一章中给出的关于 Play 框架和设计 REST API 的参考书籍。

附录 A. 模式匹配和提取器

模式匹配是 Scala 中控制流的一个强大工具。对于从命令式语言转向 Scala 的人来说,它往往被低估和未充分利用。

在深入理论之前,我们先来看几个模式匹配的例子。我们首先定义一个元组:

scala> val names = ("Pascal", "Bugnion")
names: (String, String) = (Pascal,Bugnion)

我们可以使用模式匹配来提取这个元组的元素并将它们绑定到变量上:

scala> val (firstName, lastName) = names
firstName: String = Pascal
lastName: String = Bugnion

我们刚刚提取了names元组的两个元素,并将它们绑定到变量firstNamelastName。注意左侧定义了一个模式,右侧必须匹配:我们声明变量names必须是一个包含两个元素的元组。为了使模式更具体,我们还可以指定元组中元素的预期类型:

scala> val (firstName:String, lastName:String) = names
firstName: String = Pascal
lastName: String = Bugnion

如果左侧的模式与右侧不匹配会发生什么?

scala> val (firstName, middleName, lastName) = names
<console>:13: error: constructor cannot be instantiated to expected type;
found   : (T1, T2, T3)
required: (String, String)
 val (firstName, middleName, lastName) = names

这会导致编译错误。其他类型的模式匹配失败会导致运行时错误。

模式匹配非常具有表现力。要实现没有模式匹配的相同行为,你必须明确地做以下操作:

  • 验证变量names是一个包含两个元素的元组

  • 提取第一个元素并将其绑定到firstName

  • 提取第二个元素并将其绑定到lastName

如果我们期望元组中的某些元素具有特定的值,我们可以在模式匹配过程中验证这一点。例如,我们可以验证names元组的第一个元素匹配"Pascal"

scala> val ("Pascal", lastName) = names
lastName: String = Bugnion

除了元组之外,我们还可以在 Scala 集合上进行匹配:

scala> val point = Array(1, 2, 3)
point: Array[Int] = Array(1, 2, 3)

scala> val Array(x, y, z) = point
x: Int = 1
y: Int = 2
z: Int = 3

注意这种模式匹配与数组构造的相似性:

scala> val point = Array(x, y, z)
point: Array[Int] = Array(1, 2, 3)

语法上,Scala 将模式匹配表达为实例构造过程的逆过程。我们可以将模式匹配视为对象的解构,将对象的组成部分绑定到变量。

当匹配集合时,有时人们只对匹配第一个元素或前几个元素感兴趣,并丢弃集合的其余部分,无论其长度如何。运算符_*将匹配任意数量的元素:

scala> val Array(x, _*) = point
x: Int = 1

默认情况下,由_*运算符匹配的模式部分不会被绑定到变量。我们可以如下捕获它:

scala> val Array(x, xs @ _*) = point
x: Int = 1
xs: Seq[Int] = Vector(2, 3)

除了元组和集合之外,我们还可以匹配 case 类。让我们首先定义一个表示名称的 case:

scala> case class Name(first: String, last: String)
defined class Name

scala> val name = Name("Martin", "Odersky")
name: Name = Name(Martin,Odersky)

我们可以以与匹配元组相同的方式匹配Name的实例:

scala> val Name(firstName, lastName) = name
firstName: String = Martin
lastName: String = Odersky

所有这些模式也可以用在match语句中:

scala> def greet(name:Name) = name match {
 case Name("Martin", "Odersky") => "An honor to meet you"
 case Name(first, "Bugnion") => "Wow! A family member!"
 case Name(first, last) => s"Hello, $first"
}
greet: (name: Name)String

在 for 推导式中的模式匹配

模式匹配在for推导式中非常有用,可以提取与特定模式匹配的集合中的项目。让我们构建一个Name实例的集合:

scala> val names = List(Name("Martin", "Odersky"), 
 Name("Derek", "Wyatt"))
names: List[Name] = List(Name(Martin,Odersky), Name(Derek,Wyatt))

我们可以在 for 推导式中使用模式匹配来提取类的内部结构:

scala> for { Name(first, last) <- names } yield first
List[String] = List(Martin, Derek)

到目前为止,没有什么特别突破性的。但如果我们想提取所有名字为"Martin"的人的姓氏呢?

scala> for { Name("Martin", last) <- names } yield last
List[String] = List(Odersky)

写作Name("Martin", last) <- names提取与模式匹配的names元素。你可能认为这是一个人为的例子,确实如此,但第七章,Web APIs中的例子展示了这种语言模式的有用性和多功能性,例如,用于从 JSON 对象中提取特定字段。

模式匹配内部机制

如果你定义了一个案例类,就像我们用 Name 看到的那样,你将免费获得对构造函数的模式匹配。你应该尽可能使用案例类来表示你的数据,从而减少实现自己的模式匹配的需求。尽管如此,了解模式匹配的工作原理仍然是有用的。

当你创建一个案例类时,Scala 会自动构建一个伴随对象:

scala> case class Name(first: String, last: String)
defined class Name

scala> Name.<tab>
apply   asInstanceOf   curried   isInstanceOf   toString   tupled   unapply

用于(内部)模式匹配的方法是 unapply。此方法接受一个对象作为参数,并返回 `Option[T],其中 T 是案例类的值元组的元组。

scala> val name = Name("Martin", "Odersky")
name: Name = Name(Martin,Odersky)

scala> Name.unapply(name)
Option[(String, String)] = Some((Martin,Odersky))

unapply 方法是一个 提取器。它扮演着与构造函数相反的角色:它接受一个对象,并提取构建该对象所需的参数列表。当你写 val Name(firstName, lastName),或者当你使用 Name 作为匹配语句中的案例时,Scala 会调用 Name.unapply 来匹配你正在匹配的对象。Some[(String, String)] 的值表示模式匹配,而 None 的值表示模式失败。

要编写自定义提取器,你只需要一个具有 unapply 方法的对象。虽然 unapply 通常位于你正在解构的类的伴随对象中,但这并不一定是这种情况。实际上,它根本不需要对应于现有的类。例如,让我们定义一个匹配任何非零双精度浮点数的 NonZeroDouble 提取器:

scala> object NonZeroDouble { 
 def unapply(d:Double):Option[Double] = {
 if (d == 0.0) { None } else { Some(d) } 
 }
}
defined object NonZeroDouble

scala> val NonZeroDouble(denominator) = 5.5
denominator: Double = 5.5

scala> val NonZeroDouble(denominator) = 0.0
scala.MatchError: 0.0 (of class java.lang.Double)
 ... 43 elided

尽管没有相应的 NonZeroDouble 类,我们仍然定义了一个 NonZeroDouble 提取器。

这个 NonZeroDouble 提取器在匹配对象中很有用。例如,让我们定义一个 safeDivision 函数,当分母为零时返回默认值:

scala> def safeDivision(numerator:Double, 
 denominator:Double, fallBack:Double) =
 denominator match {
 case NonZeroDouble(d) => numerator / d
 case _ => fallBack
 }
safeDivision: (numerator: Double, denominator: Double, fallBack: Double)Double

scala> safeDivision(5.0, 2.0, 100.0)
Double = 2.5

scala> safeDivision(5.0, 0.0, 100.0)
Double = 100.0

这是一个简单的例子,因为 NonZeroDouble.unapply 方法非常简单,但如果你定义一个更复杂的测试,你可能会看到它的有用性和表达性。定义自定义提取器让你能够定义强大的控制流结构来利用 match 语句。更重要的是,它们使使用提取器的客户端能够以声明性的方式考虑控制流:客户端可以声明他们需要一个 NonZeroDouble,而不是指示编译器检查值是否为零。

提取序列

前一节解释了从案例类中提取数据以及如何编写自定义提取器,但没有解释如何在序列上提取:

scala> val Array(a, b) = Array(1, 2)
a: Int = 1
b: Int = 2

与依赖 unapply 方法不同,序列依赖于在伴随对象中定义的 unapplySeq 方法。这应该返回 Option[Seq[A]]

scala> Array.unapplySeq(Array(1, 2))
Option[IndexedSeq[Int]] = Some(Vector(1, 2))

让我们写一个例子。我们将编写一个用于 Breeze 向量的提取器(目前 Breeze 向量不支持模式匹配)。为了避免与DenseVector伴生对象冲突,我们将我们的unapplySeq写在单独的对象中,称为DV。我们的unapplySeq方法需要做的只是将其参数转换为 Scala Vector实例。为了避免泛型混淆概念,我们将仅为此[Double]向量编写此实现:

scala> import breeze.linalg._
import breeze.linalg._

scala> object DV {
 // Just need to convert to a Scala vector.
 def unapplySeq(v:DenseVector[Double]) = Some(v.toScalaVector)
}
defined object DV

让我们尝试我们的新提取器实现:

scala> val vec = DenseVector(1.0, 2.0, 3.0)
vec: breeze.linalg.DenseVector[Double] = DenseVector(1.0, 2.0, 3.0)

scala> val DV(x, y, z) = vec
x: Double = 1.0
y: Double = 2.0
z: Double = 3.0

摘要

模式匹配是控制流的一个强大工具。它鼓励程序员以声明式的方式思考:声明你期望一个变量匹配某个模式,而不是明确告诉计算机如何检查它是否匹配这个模式。这可以节省许多代码行并提高清晰度。

参考

对于 Scala 中模式匹配的概述,没有比Programming in Scala更好的参考书籍了,作者是Martin OderskyBill VennersLex Spoon。第一版的在线版本可在以下网址找到:www.artima.com/pins1ed/case-classes-and-pattern-matching.html

Daniel Westheide的博客涵盖了稍微高级一点的 Scala 结构,并且是非常有用的阅读材料:danielwestheide.com/blog/2012/11/21/the-neophytes-guide-to-scala-part-1-extractors.html

第二部分。模块 2

Scala 机器学习

利用 Scala 和机器学习构建和研究可以从数据中学习系统的系统

第一章。入门

对于任何计算机科学家来说,理解不同类别的机器学习算法并能够选择与其专业领域和数据集相关的算法至关重要。然而,这些算法的应用只是从输入数据中提取准确和性能良好的模型所需整体努力的一小部分。常见的数据挖掘工作流程包括以下顺序步骤:

  1. 定义要解决的问题。

  2. 加载数据。

  3. 预处理、分析和过滤输入数据。

  4. 发现模式、亲和力、簇和类(如果需要)。

  5. 选择模型特征和合适的机器学习算法(们)。

  6. 精炼和验证模型。

  7. 提高实现的计算性能。

在这本书中,过程的每个阶段对于构建 正确 的模型都是至关重要的。

提示

在一本书中详细描述关键机器学习算法及其实现是不可能的。信息量和 Scala 代码的量甚至会让最热心的读者感到压倒。每一章都专注于理解该主题绝对必要的数学和代码。鼓励开发者浏览以下内容:

  • 书中使用的 Scala 编码约定和标准在附录 A 中,基本概念

  • API Scala 文档

  • 可在线获取的完整文档源代码

这第一章向您介绍了机器学习算法的分类、本书中使用的工具和框架,以及一个简单的逻辑回归应用,让您入门。

数学符号的趣味

每一章都包含一个小节,专门针对对机器学习科学与艺术背后的数学概念感兴趣的人,阐述算法的公式。这些部分是可选的,并定义在提示框中。例如,提示框中提到的变量 X 的均值和方差的数学表达式如下:

提示

约定和符号

本书在数学公式中使用数据集的零基索引。

M1:一组 N 个观测值表示为 {xi} = x[0], x[1], … , x[N-1],以 xi 为值的随机值的算术平均值定义为:

数学符号的趣味

为什么是机器学习?

数字设备的数量激增,产生了越来越多的数据。我能找到的最佳类比来描述从大型数据集中提取知识的需求、愿望和紧迫性,就是从矿山中提取贵金属的过程,在某些情况下,是从石头中提取血液。

知识通常被定义为一种可以随着新数据的到来而不断更新或调整的模型。模型显然是特定领域的,从信用风险评估、人脸识别、服务质量最大化、疾病病理症状分类、计算机网络优化、安全入侵检测,到客户的在线行为和购买历史。

机器学习问题被分为分类、预测、优化和回归。

分类

分类的目的是从历史数据中提取知识。例如,可以构建一个分类器来从一组症状中识别疾病。科学家收集有关体温(连续变量)、拥堵(离散变量 HIGHMEDIUMLOW)和实际诊断(流感)的信息。这个数据集用于创建一个模型,例如 IF temperature > 102 AND congestion = HIGH THEN patient has the flu (probability 0.72),医生可以使用这个模型进行诊断。

预测

一旦模型使用历史观察数据训练并经过历史观察数据验证后,它就可以用来预测某些结果。医生收集患者的症状,如体温和鼻塞,并预测其健康状况。

优化

一些全局优化问题使用传统的线性和非线性优化方法难以解决。机器学习技术提高了优化方法收敛到解决方案(智能搜索)的机会。你可以想象,对抗新病毒的传播需要优化一个可能随着更多症状和病例的发现而演变的流程。

回归

回归是一种特别适合连续模型的分类技术。线性(最小二乘法)、多项式和逻辑回归是最常用的技术之一,用于拟合参数模型或函数 y= f (x), x={x[i]} 到数据集。回归有时被视为分类的特殊情况,其中输出变量是连续的而不是分类的。

为什么选择 Scala?

与大多数函数式语言一样,Scala 为开发人员和科学家提供了一个工具箱,用于实现可以轻松编织成连贯数据流的迭代计算。在一定程度上,Scala 可以被视为流行的 MapReduce 模型在分布式计算大量数据方面的扩展。在语言的功能中,以下特性被认为是机器学习和统计分析中的基本特性。

抽象

函子单子是函数式编程中的重要概念。单子是从范畴论和群论中派生出来的,允许开发者创建如 Scalaz、Twitter 的 Algebird 或 Google 的 Breeze Scala 库中所示的高级抽象。更多关于这些库的信息可以在以下链接中找到:

在数学中,一个范畴 M 是一个由以下定义的结构:

  • 某些类型的对象:{x ϵ X, y ϵ Y, z ϵ Z, …}

  • 应用到这些对象上的态射或映射:x ϵ X, y ϵ Y, f: x -› y

  • 模态的复合:f: x -› y, g: y -› z => g o f: x -› z

协变函子逆变函子双函子是代数拓扑中理解良好的概念,与流形和向量丛相关。它们在微分几何和非线性模型的数据生成中常用。

高阶投影

科学家将观测定义为特征集或向量。分类问题依赖于观测向量之间相似性的估计。一种技术是通过计算归一化内积来比较两个向量。协变向量被定义为将向量映射到内积(域)的线性映射 α。

提示

内积

M1:内积和 α 协变向量的定义如下:

高阶投影

让我们定义一个向量为一个从任何 _ => Vector[_] 域(或 Function1[_, Vector])构造函数。协变向量随后被定义为将向量映射到其 Vector[_] => _ 域(或 Function1[Vector, _])的映射函数。

让我们定义一个二维(两种类型或域)的高阶结构 Hom,它可以被定义为向量或协变向量,通过固定两种类型之一:

type Hom[T] = {
  type Right[X] = Function1[X,T] // Co-vector
  type Left[X] = Function1[T,X]   // Vector
 }

注意

张量和流形

向量和协变向量是张量(逆变和协变)的类。张量(域)用于非线性模型的学习和核函数的生成。流形在 第四章 的 降维 部分的 流形 简要介绍,无监督学习。张量场和流形学习的话题超出了本书的范围。

高阶投影的投影,HomRightLeft 单参数类型被称为函子,如下所示:

  • 一个针对 right 投影的协变函子

  • 一个针对 left 投影的逆变函子。

向量的协变函子

变量上的一个协变函子是一个映射 F: C => C,使得:

  • 如果 f: x -› yC 上的态射,那么 F(x) -› F(y) 也是 C 上的态射

  • 如果 id: x -› xC 上的恒等态射,那么 F(id) 也是 C 上的恒等态射

  • 如果 g: y -› z 也是 C 上的一个形态,那么 F(g o f) = F(g) o F(f)

Scala 中F[U => V] := F[U] => F[V]协变函子的定义如下:

trait Functor[M[_]] {
  def mapU,V(f: U =>V): M[V]
}

例如,让我们考虑一个定义为T类型的n维向量的观察,Obs[T]。观察的构造函数可以表示为Function1[T,Obs]。它的ObsFunctor函子实现如下:

trait ObsFunctor[T] extends Functor[(Hom[T])#Left] { self =>
  override def mapU,V(f: U =>V): 
    Function1[T,V] = f.compose(vu)
}

由于形态应用于Obs元素的返回类型Function1[T, Obs],该函子被指定为协变函子。两个参数类型的Hom投影实现为(Hom[T])#Left

逆变函子用于协变向量

一个一变量的逆变函子是一个映射 F: C => C,使得:

  • 如果 f: x -› yC 上的一个形态,那么 F(y) -> F(x) 也是 C 上的一个形态

  • 如果 id: x -› xC 上的恒等形态,那么 F(id) 也是 C 上的恒等形态

  • 如果 g: y -› z 也是 C 上的一个形态,那么 F(g o f) = F(f) o F(g)

Scala 中F[U => V] := F[V] => F[U]逆变函子的定义如下:

trait CoFunctor[M[_]] {
  def mapU,V(f: V =>U): M[V]
}

注意,在f形态中的输入和输出类型与协变函子的定义相反。协变向量的构造函数可以表示为Function1[Obs,T]。它的CoObsFunctor函子实现如下:

trait CoObsFunctor[T] extends CoFunctor[(Hom[T])#Right] {
  self =>
    override def mapU,V(f: V =>U): 
       Function1[V,T] = f.andThen(vu)
}

Monads

Monads 是代数拓扑中的结构,与范畴论相关。Monads 扩展了函子的概念,允许在单个类型上执行称为单子组合的形态组合。它们使计算能够链式或编织成一系列步骤或管道。Scala 标准库中捆绑的集合(ListMap等)被构建为 Monads [1:1]。

Monads 提供了以下功能的能力:

  • 创建集合

  • 转换集合的元素

  • 展平嵌套集合

以下是一个例子:

trait Monad[M[_]] {
  def unitT: M[T]
  def mapU,V(f U =>V): M[V]
  def flatMapU,V(f: U =>M[V]): M[V]
}

因此,在机器学习中,Monads 至关重要,因为它们允许你将多个数据转换函数组合成一个序列或工作流程。这种属性适用于任何类型的复杂科学计算 [1:2]。

注意

核函数的单子组合

在第八章的“核函数”部分的“核单调组合”部分中,Monads 用于核函数的组合,核模型和支持向量机

可伸缩性

如前所述,通过利用 Scala 的高阶方法,functors 和 monads 可以实现数据处理函数的并行化和链式操作。在实现方面,演员是使 Scala 可扩展的核心元素之一。演员为 Scala 开发者提供了构建可扩展、分布式和并发应用程序的高级抽象。演员隐藏了并发和底层线程池管理的繁琐实现细节。演员通过异步不可变消息进行通信。例如 AkkaApache Spark 这样的分布式计算 Scala 框架扩展了 Scala 标准库的功能,以支持在非常大的数据集上进行计算。Akka 和 Apache Spark 在本书的最后一章中进行了详细描述 [1:3]。

简而言之,工作流被实现为一串活动或计算任务。这些任务包括 flatMapmapfoldreducecollectjoinfilter 等高阶 Scala 方法,它们应用于大量观测数据。Scala 为开发者提供了将数据集分区并通过演员集群执行任务的工具。Scala 还支持在本地和远程演员之间进行消息调度和路由。开发者可以决定以很少的代码更改将工作流部署在本地或多个 CPU 核心和服务器上。

可扩展性

将模型训练工作流作为分布式计算进行部署

在前面的图中,一个控制器,即主节点,管理着任务14的顺序,类似于调度器。这些任务实际上是在多个工作节点上执行的,这些节点由演员实现。主节点或演员与工作者交换消息以管理工作流程执行的状 态以及其可靠性,如第十二章的使用演员的可扩展性部分所示,可扩展框架。通过监督演员的层次结构实现了这些任务的高可用性。

可配置性

Scala 支持使用抽象变量、自引用组合和可堆叠特质组合来依赖注入。最常用的依赖注入模式之一,即蛋糕模式,在第二章的组合混入构建工作流部分进行了描述,Hello World!

可维护性

Scala 本地嵌入领域特定语言DSL)。DSL 是建立在 Scala 本地库之上的语法层。DSL 允许软件开发者用科学家容易理解的方式来抽象计算。最著名的 DSL 应用是定义 MATLAB 程序中使用的语法的仿真,这是数据科学家所熟悉的。

按需计算

懒惰的方法和价值观允许开发者根据需要执行函数和分配计算资源。Spark 框架依赖于懒惰变量和方法来链式处理弹性分布式数据集RDD)。

模型分类

一个模型可以是预测性的、描述性的或自适应的。

预测性模型在历史数据中发现模式,并提取因素(或特征)之间的基本趋势和关系。它们用于预测和分类未来的事件或观察。预测分析在包括市场营销、保险和制药在内的各种领域中使用。预测模型通过使用预选的训练集进行监督学习来创建。

描述性模型试图通过将观察结果分组到具有相似属性的聚类中,在数据中找到异常模式或亲和力。这些模型定义了知识发现的第一步和重要步骤。它们通过无监督学习生成。

第三类模型,称为自适应建模,是通过强化学习创建的。强化学习包括一个或多个决策代理,它们在尝试解决问题、优化目标函数或解决约束条件的过程中推荐并可能执行动作。

机器学习算法分类法

机器学习的目的是教会计算机在没有人类干预的情况下执行任务。越来越多的应用,如基因组学、社交网络、广告或风险评估,产生了大量可以分析或挖掘以提取知识或洞察过程、客户或组织的数据。最终,机器学习算法包括通过使用历史、现在和未来的数据来识别和验证模型,以优化性能标准[1:4]。

数据挖掘是从数据集中提取或识别模式的过程。

无监督学习

无监督学习的目标是在一组观察中发现规律性和不规则性的模式。这个过程在统计学中被称为密度估计,可以分为两大类:数据聚类发现和潜在因素发现。该方法包括处理输入数据以理解与婴儿或动物自然学习过程相似的规律。无监督学习不需要标记数据(或预期值),因此易于实现和执行,因为不需要专业知识来验证输出。然而,可以对聚类算法的输出进行标记,并将其用于未来的分类。

聚类

数据聚类的目的是将一组数据划分为若干个簇或数据段。实际上,聚类算法通过最小化簇内观测值之间的距离和最大化簇间观测值之间的距离来组织观测值到簇中。聚类算法包括以下步骤:

  1. 通过对输入数据做出假设来创建模型。

  2. 选择聚类的目标函数或目标。

  3. 评估一个或多个算法以优化目标函数。

数据聚类也称为数据分段数据划分

维度缩减

维度缩减技术旨在找到构建可靠模型所需的最小但最相关的特征集。在模型中减少特征或参数的数量有很多原因,从避免过拟合到降低计算成本。

使用无监督学习从数据中提取知识的不同技术有很多种分类方法。以下分类法根据其目的对这些技术进行了细分,尽管这个列表远非详尽,如下面的图所示:

维度缩减

无监督学习算法分类法

监督学习

监督学习的最佳类比是函数逼近曲线拟合。在其最简单形式中,监督学习试图使用训练集 {x, y} 找到一个关系或函数 f: x → y。只要输入(标记数据)可用且可靠,监督学习比任何其他学习策略都要准确。缺点是可能需要一个领域专家来标记(或标记)数据作为训练集。

监督机器学习算法可以分为两大类:

  • 生成模型

  • 判别模型

生成模型

为了简化统计公式的描述,我们采用以下简化:事件 X 的概率与离散随机变量 X 取值 x 的概率相同:p(X) = p(X=x)

联合概率的表示为 p(X,Y) = p(X=x,Y=y)

条件概率的表示为 p(X|Y) = p(X=x|Y=y)

生成模型试图拟合两个 XY 事件(或随机变量)的联合概率分布,p(X,Y),代表两组观察到的和隐藏的 xy 变量。判别模型计算隐藏变量 y 的一个事件或随机变量 Y 的条件概率,p(Y|X),给定观察变量 x 的一个事件或随机变量 X。生成模型通常通过贝叶斯定理引入。给定 X 事件的 Y 事件的条件概率是通过 X 事件在 Y 事件给定的条件概率与 X 事件概率通过 Y 事件概率归一化的乘积来计算的 [1:5]。

注意

贝叶斯定理

独立随机变量 X=xY=y 的联合概率如下:

生成模型

随机变量 Y = y 在给定 X = x 的条件概率如下:

生成模型

贝叶斯公式如下:

生成模型

贝叶斯定理是朴素贝叶斯分类器的基础,如第五章中“介绍多项式朴素贝叶斯”部分所述,朴素贝叶斯分类器

判别模型

与生成模型相反,判别模型直接计算条件概率 p(Y|X),使用相同的算法进行训练和分类。

生成模型和判别模型各有其优势和劣势。新手数据科学家通过实验学习将适当的算法匹配到每个问题。以下是一个简要的指南,描述了根据项目的目标或标准,哪种类型的模型是有意义的:

目标 生成模型 判别模型
准确性 高度依赖于训练集。 这取决于训练集和算法配置(即核函数)
建模需求 需要建模观察到的和隐藏的变量,这需要大量的训练。 训练集的质量不必像生成模型那样严格。
计算成本 这通常很低。例如,任何从贝叶斯定理导出的图形方法都有很低的开销。 大多数算法依赖于优化具有显著性能开销的凸函数。
约束 这些模型假设模型特征之间存在某种程度的独立性。 大多数判别算法可以适应特征之间的依赖关系。

我们可以通过将生成模型的顺序和随机变量任意分离,并将判别方法分解为应用于连续过程(回归)和离散过程(分类)来进一步细化监督学习算法的分类:

判别模型

监督学习算法的分类

半监督学习

半监督学习用于从带有不完整标签的数据集构建模型。流形学习和信息几何算法通常应用于部分标记的大数据集。半监督学习技术的描述超出了本书的范围。

强化学习

强化学习在机器人或游戏策略领域之外,不如监督学习和无监督学习那样被充分理解。然而,自 90 年代以来,基于遗传算法的分类器在解决需要与领域专家协作的问题上变得越来越受欢迎。对于某些类型的应用,强化学习算法输出一系列推荐的动作供自适应系统执行。在最简单的形式中,这些算法估计最佳的行动方案。大多数基于强化学习的复杂系统都会建立和更新政策,如果需要,专家可以否决。强化学习系统的开发者面临的最主要挑战是,推荐的动作或政策可能取决于部分可观察的状态。

遗传算法通常不被认为是强化学习工具箱的一部分。然而,一些高级模型,如学习分类系统,使用遗传算法来分类和奖励表现最佳规则和政策。

与前两种学习策略一样,强化学习模型可以分为马尔可夫或进化模型:

强化学习

强化学习算法的分类

这是对机器学习算法的简要概述,提供了一个建议的、近似的分类法。介绍机器学习的方法几乎和数据和计算机科学家一样多。我们鼓励您浏览本书末尾的参考文献列表,以找到适合您兴趣和理解水平的文档。

不要重新发明轮子!

有许多稳健、准确和高效的 Java 库用于数学、线性代数或优化,这些库多年来被广泛使用:

完全没有必要在 Scala 中重新编写、调试和测试这些组件。开发者应考虑创建一个包装器或接口,以便使用他们最喜欢的和可靠的 Java 库。本书利用 Apache Commons Math 库来处理一些特定的线性代数算法。

工具和框架

在动手实践之前,您需要下载和部署一组最小化的工具和库;毕竟,没有必要重新发明轮子。为了编译和运行本书中描述的源代码,必须安装一些关键组件。我们专注于开源和常见库,尽管您被邀请尝试您选择的等效工具。这里描述的框架的学习曲线非常平缓。

Java

本书描述的代码已在 Windows x64 和 Mac OS X x64 上使用 JDK 1.7.0_45 和 JDK 1.8.0_25 进行了测试。如果您尚未安装,则需要安装 Java 开发工具包。最后,必须相应地更新JAVA_HOMEPATHCLASSPATH环境变量。

Scala

代码已在 Scala 2.10.4 和 2.11.4 上进行了测试。我们建议您使用 Scala 版本 2.10.4 或更高版本与 SBT 0.13 或更高版本一起使用。假设 Scala 运行时(REPL)和库已经正确安装,并且SCALA_HOMEPATH环境变量已经更新。

Eclipse 的 Scala 插件(版本 4.0 或更高)的描述和安装说明可在scala-ide.org/docs/user/gettingstarted.html找到。您还可以从 JetBrains 网站下载IntelliJ IDEA 的 Scala 插件(版本 13 或更高)。

无处不在的简单构建工具SBT)将是我们的主要构建引擎。构建文件sbt/build.sbt的语法符合版本 0.13,并用于编译和组装本书中展示的源代码。SBT 可以作为 Typesafe activator 的一部分下载,或者直接从www.scala-sbt.org/download.html下载。

Apache Commons Math

Apache Commons Math 是一个用于数值处理、代数、统计学和优化的 Java 库[1:6]。

描述

这是一个轻量级库,为开发者提供了一组小型、现成的 Java 类,这些类可以轻松地编织到机器学习问题中。本书中使用的示例需要版本 3.5 或更高。

数学库支持以下内容:

  • 函数、微分和积分及常微分方程

  • 统计分布

  • 线性和非线性优化

  • 稠密和稀疏向量和矩阵

  • 曲线拟合、相关性分析和回归

更多信息,请访问commons.apache.org/proper/commons-math

许可证

我们需要 Apache 公共许可证 2.0;条款可在www.apache.org/licenses/LICENSE-2.0找到。

安装

Apache Commons Math 库的安装和部署相当简单。步骤如下:

  1. 请访问下载页面

  2. 将最新的.jar文件下载到二进制部分,例如commons-math3-3.5-bin.zip(针对 3.5 版本)。

  3. 解压并安装.jar文件。

  4. 按以下步骤将commons-math3-3.5.jar添加到类路径中:

    • 对于 Mac OS Xexport CLASSPATH=$CLASSPATH:/Commons_Math_path/commons-math3-3.5.jar

    • 对于 Windows 系统:转到系统属性 | 高级系统设置 | 高级 | 环境变量,然后编辑CLASSPATH变量

  5. 如有必要,将commons-math3-3.5.jar文件添加到您的 IDE 环境(例如,对于 Eclipse,请转到项目 | 属性 | Java 构建路径 | | 添加外部 JARs,对于 IntelliJ IDEA,请转到文件 | 项目结构 | 项目设置 | )。

您还可以从部分下载commons-math3-3.5-src.zip

JFreeChart

JFreeChart 是一个开源的图表和绘图 Java 库,在 Java 程序员社区中广泛使用。它最初由 David Gilbert [1:7]创建。

描述

该库支持多种可配置的图表和图形(散点图、仪表盘、饼图、面积图、条形图、箱线图、堆叠图和 3D 图)。我们使用 JFreeChart 在整本书中显示数据处理和算法的输出,但鼓励您在时间允许的情况下自行探索这个优秀的库。

许可证

它根据 GNU 较小通用公共许可证LGPL)的条款分发,允许其在专有应用程序中使用。

安装

要安装和部署 JFreeChart,请执行以下步骤:

  1. 访问www.jfree.org/jfreechart/

  2. 从 Source Forge 下载最新版本sourceforge.net/projects/jfreechart/files

  3. 解压并部署.jar文件。

  4. 按以下步骤将jfreechart-1.0.17.jar(针对 1.0.17 版本)添加到类路径中:

    • 对于 Mac OS Xexport CLASSPATH=$CLASSPATH:/JFreeChart_path/jfreechart-1.0.17.jar

    • 对于 Windows 系统:转到系统属性 | 高级系统设置 | 高级 | 环境变量,然后编辑CLASSPATH变量

  5. 如有必要,将jfreechart-1.0.17.jar文件添加到您的 IDE 环境

其他库和框架

每章特有的库和工具将与主题一起介绍。可扩展的框架在最后一章中介绍,并附有下载说明。与条件随机字段和支持向量机相关的库在各自的章节中描述。

注意

为什么不使用 Scala 代数和数值库呢?

如 Breeze、ScalaNLP 和 Algebird 等库是线性代数、数值分析和机器学习的有趣 Scala 框架。它们为经验丰富的 Scala 程序员提供了高质量的抽象层。然而,本书设计为一个教程,允许开发者使用现有的或遗留的 Java 库从头开始编写算法 [1:8]。

源代码

本书使用 Scala 编程语言来实现和评估《Scala for Machine Learning》中涵盖的机器学习技术。然而,源代码片段被缩减到理解书中讨论的机器学习算法所必需的最小范围。这些算法的正式实现可以在 Packt Publishing 的网站上找到(www.packtpub.com)。

小贴士

下载示例代码

您可以从您在 www.packtpub.com 的账户中下载您购买的所有 Packt 书籍的示例代码文件。如果您在其他地方购买了这本书,您可以访问 www.packtpub.com/support 并注册,以便将文件直接通过电子邮件发送给您。

上下文与视图边界

书中讨论的大多数 Scala 类都是使用与离散/分类值(Int)或连续值(Double)关联的类型参数化的。上下文边界要求客户端代码中使用的任何类型都必须有 IntDouble 作为上界:

class AT <: Int
class BT <: Double

这样的设计对客户端施加了从简单类型继承并处理容器类型的协变和逆协变的约束 [1:9]。

对于本书,视图边界被用于代替上下文边界,因为它们只需要定义参数化类型的隐式转换:

class AT <: AnyVal(implicit f: T => Int)
class CT < : AnyVal(implicit f: T => Float)

注意

视图边界弃用

视图边界的表示法 T <% Double 在 Scala 2.11 及更高版本中被弃用。class A[T <% Float] 声明是 class AT 的简写形式。

展示

为了提高算法实现的可读性,所有非必需代码,如错误检查、注释、异常或导入都被省略。书中展示的代码片段中省略了以下代码元素:

  • 代码文档:

    // …..
    /* … */
    
  • 类参数和方法参数的验证:

    require( Math.abs(x) < EPS, " …")
    
  • 类限定符和作用域声明:

    final protected class SVM { … }
    private[this] val lsError = …
    
  • 方法限定符:

    final protected def dot: = …
    
  • 异常:

    try {
       correlate …
    } catch {
       case e: MathException => ….
    }
    Try {    .. } match {
      case Success(res) =>
      case Failure(e => ..
    }
    
  • 日志和调试代码:

    private val logger = Logger.getLogger("..")
    logger.info( … )
    
  • 非必需注解:

    @inline def main = ….
    @throw(classOf[IllegalStateException])
    
  • 非必需方法

本书代码片段中省略的 Scala 代码元素完整列表可以在 附录 A 的 代码片段格式 部分找到,基本概念

原始类型和隐式转换

本书所展示的算法共享相同的原始类型、泛型操作符和隐式转换。

原始类型

为了提高代码的可读性,以下原始类型将被使用:

type DblPair = (Double, Double)
type DblArray = Array[Double]
type DblMatrix = Array[DblArray]
type DblVector = Vector[Double]
type XSeries[T] = Vector[T]         // One dimensional vector
type XVSeries[T] = Vector[Array[T]] // multi-dimensional vector

在第三章“Scala 中的时间序列”部分介绍的时间序列作为参数化类型TXSeries[T]XVSeries[T]实现。

注意

记住这六种类型;它们在本书中都有使用。

类型转换

隐式转换是 Scala 编程语言的一个重要特性。它允许开发者在单个位置为整个库指定类型转换。以下是本书中使用的几个隐式类型转换:

object Types {
  Object ScalaMl {  
   implicit def double2Array(x: Double): DblArray = 
      ArrayDouble
   implicit def dblPair2Vector(x: DblPair): Vector[DblPair] = 
      VectorDblPair
   ...
  }
}

注意

库特定转换

这里列出的原始类型与特定库(如 Apache Commons Math 库)中引入的类型之间的转换在相关章节中描述。

不可变性

通常,减少对象的状态数量是一个好主意。方法调用将对象从一个状态转换到另一个状态。方法或状态的数量越多,测试过程就越繁琐。

创建未定义(训练)的模型是没有意义的。因此,将模型的训练作为其实现的类的构造函数的一部分是非常有意义的。因此,机器学习算法的唯一公共方法如下:

  • 分类或预测

  • 验证

  • 如果需要,检索模型参数(权重、潜在变量、隐藏状态等)

Scala 迭代器的性能

Scala 高阶迭代方法性能的评估超出了本书的范围。然而,了解每种方法的权衡是很重要的。

for构造不应作为计数迭代器使用。它旨在实现 for-comprehensive monad(mapflatMap)。本书中展示的源代码使用高阶foreach方法。

让我们试试水

本节最后介绍了训练和分类工作流程的关键元素。使用简单的逻辑回归作为测试案例来展示计算工作流程的每一步。

计算工作流程概述

在最简单的情况下,执行数据集运行时处理的计算工作流程由以下阶段组成:

  1. 从文件、数据库或任何流式设备中加载数据集。

  2. 将数据集分割以进行并行数据处理。

  3. 使用过滤技术、方差分析和在必要时应用惩罚及归一化函数进行数据预处理。

  4. 应用模型——无论是聚类集还是类别集——以对新数据进行分类。

  5. 评估模型的品质。

使用类似的任务序列从训练数据集中提取模型:

  1. 从文件、数据库或任何流式设备中加载数据集。

  2. 将数据集分割以进行并行数据处理。

  3. 在必要时,将过滤技术、方差分析和惩罚及归一化函数应用于原始数据集。

  4. 从清洗后的输入数据中选择训练集、测试集和验证集。

  5. 使用聚类技术或监督学习算法提取关键特征,并在一组相似观测之间建立亲和力。

  6. 将特征数量减少到可管理的属性集,以避免过度拟合训练集。

  7. 通过迭代步骤 5、6 和 7 直到错误满足预定义的收敛标准来验证模型并调整模型。

  8. 将模型存储在文件或数据库中,以便可以应用于未来的观测。

数据聚类和数据分类可以独立进行,或者作为使用聚类技术在监督学习算法训练阶段预处理阶段的工作流程的一部分进行。数据聚类不需要从训练集中提取模型,而分类只能在从训练集构建了模型后才能执行。以下图像给出了训练、分类和验证的概述:

计算工作流程概述

训练和运行模型的通用数据流

上述图表是典型数据挖掘处理流程的概述。第一阶段包括通过聚类或监督学习算法的训练提取模型。然后,该模型与测试数据进行验证,其来源与训练集相同,但观测不同。一旦模型创建并验证,就可以用于分类实时数据或预测未来行为。现实世界的工作流程更为复杂,需要动态配置以允许不同模型的实验。可以使用几种不同的分类器来执行回归,并根据原始数据中的潜在噪声应用不同的过滤算法。

编写一个简单的工作流程

本书依赖于金融数据来实验不同的学习策略。练习的目标是构建一个模型,能够区分股票或商品的波动和非波动交易时段。对于第一个例子,我们选择简化版的二项式逻辑回归作为我们的分类器,因为我们把股价-成交量行为视为连续或准连续过程。

注意

逻辑回归简介

逻辑回归在第六章的逻辑回归部分进行了深入解释,回归和正则化。本例中处理的是针对二维观测的简单二项式逻辑回归分类器。

根据交易时段的波动性和成交量对交易时段进行分类的步骤如下:

  1. 确定问题范围

  2. 加载数据

  3. 预处理原始数据

  4. 在可能的情况下发现模式

  5. 实现分类器

  6. 评估模型

第 1 步 – 确定问题范围

目标是创建一个使用其每日交易量和波动性的股票价格模型。在整本书中,我们将依靠财务数据来评估和讨论不同数据处理和机器学习方法的优点。在这个例子中,数据使用 CSV 格式从Yahoo Finances中提取,以下字段:

  • 日期

  • 开盘价

  • 会话中的最高价

  • 会话中的最低价

  • 会话结束时的价格

  • 成交量

  • 在会话结束时调整价格

YahooFinancials 枚举器从雅虎财经网站提取历史每日交易信息:

type Fields = Array[String]
object YahooFinancials extends Enumeration {
   type YahooFinancials = Value
   val DATE, OPEN, HIGH, LOW, CLOSE, VOLUME, ADJ_CLOSE = Value

   def toDouble(v: Value): Fields => Double =   //1
   (s: Fields) => s(v.id).toDouble
   def toDblArray(vs: Array[Value]): Fields => DblArray = //2
       (s: Fields) => vs.map(v => s(v.id).toDouble)
  …
}

toDouble 方法将字符串数组转换为单个值(行1),而toDblArray将字符串数组转换为值数组(行2)。YahooFinancials 枚举器在附录 A 的数据来源部分以及基本概念中详细描述。

让我们创建一个简单的程序,该程序加载文件内容,执行一些简单的预处理函数,并创建一个简单的模型。我们选择了 2012 年 1 月 1 日至 2013 年 12 月 1 日之间的 CSCO 股票价格作为我们的数据输入。

让我们考虑以下截图中的两个变量,价格成交量。顶部图表显示了思科股票价格随时间的变化,底部柱状图表示思科股票随时间的每日交易量:

步骤 1 – 确定问题范围

2012-2013 年思科股票的价格-成交量动作

步骤 2 – 加载数据

第二步是从本地或远程数据存储加载数据集。通常,大型数据集是从数据库或如Hadoop 分布式文件系统HDFS)这样的分布式文件系统中加载的。load 方法接受绝对路径名,extract,并将输入数据从文件转换为Vector[DblPair]类型的时间序列:

def load(fileName: String): Try[Vector[DblPair]] = Try {
   val src =  Source.fromFile(fileName)  //3
   val data = extract(src.getLines.map(_.split(",")).drop(1)) //4
   src.close //5
   data
 }

数据文件通过调用Source.fromFile静态方法(行3)提取,然后在移除标题(文件中的第一行)之前通过drop提取字段。必须关闭文件以避免文件句柄泄漏(行5)。

备注

数据提取

Source.fromFile.getLines.map 调用管道方法返回一个只能遍历一次的迭代器。

extract 方法的目的是生成两个变量(相对股票波动性相对股票每日交易量)的时间序列:

def extract(cols: Iterator[Array[String]]): XVSeries[Double]= {
  val features = ArrayYahooFinancials //6
  val conversion = YahooFinancials.toDblArray(features)  //7
  cols.map(c => conversion(c)).toVector   
      .map(x => ArrayDouble/x(1), x(2)))  //8
}

extract 方法的唯一目的是将原始文本数据转换为二维时间序列。第一步包括选择要提取的三个特征 LOW(会话中的最低股价)、HIGH(会话中的最高价)和 VOLUME(会话的交易量)(行 6)。这个特征集用于将字段行转换为相应的三个值集(行 7)。最后,特征集被缩减为以下两个变量(行 8):

  • 会话中股价的相对波动:1.0 – LOW/HIGH

  • 会话中该股票的交易量:VOLUME

注意

代码可读性

一系列长的 Scala 高阶方法使得代码及其底层代码难以阅读。建议您分解长链式的方法调用,例如以下内容:

val cols = Source.fromFile.getLines.map(_.split(",")).toArray.drop(1)

我们可以将方法调用分解为以下几个步骤:

val lines = Source.fromFile.getLines
val fields = lines.map(_.split(",")).toArray
val cols = fields.drop(1)

我们强烈建议您查阅由 Twitter 的 Marius Eriksen 编写的优秀指南 Effective Scala。这绝对是对任何 Scala 开发者必读的[1:10]。

第 3 步 - 数据预处理

下一步是将数据归一化到[0.0, 1.0]范围,以便由二项式逻辑回归进行训练。是时候引入一个不可变且灵活的归一化类了。

不可变归一化

逻辑回归依赖于 sigmoid 曲线或逻辑函数,这在第六章的逻辑函数部分有描述,回归和正则化。逻辑函数用于将训练数据分类。逻辑函数的输出值范围从 x = - INFINITY 的 0 到 x = + INFINITY 的 1。因此,对输入数据或观测进行[0, 1]归一化是有意义的。

注意

归一化还是不归一化?

数据归一化的目的是为所有特征施加一个单一的范围,这样模型就不会偏袒任何特定的特征。归一化技术包括线性归一化和 Z 分数。归一化是一个昂贵的操作,并不总是需要的。

归一化是原始数据的线性变换,可以推广到任何范围 [l, h]

注意

线性归一化

M2: [0, 1] 归一化特征 {x[i]} 的最小值 x[min] 和最大值 x[max]:

不可变归一化

M3: [l, h] 归一化特征 {xi}:

不可变归一化

监督学习中输入数据的归一化有特定的要求:对新观测的分类和预测必须使用从训练集中提取的归一化参数(minmax),因此所有观测共享相同的缩放因子。

让我们定义MinMax归一化类。该类是不可变的:最小值min和最大值max在构造函数中计算。该类接受参数化类型T的时间序列和值作为参数(第 8 行)。归一化过程的步骤定义如下:

  1. 在实例化时初始化给定时间序列的最小值(第 9 行)。

  2. 计算归一化参数(第 10 行)并归一化输入数据(第 11 行)。

  3. 重新使用归一化参数归一化任何新的数据点(第 14 行):

    class MinMaxT <: AnyVal (f : T => Double) { //8
      val zero = (Double.MaxValue, -Double.MaxValue)
      val minMax = values./:(zero)((mM, x) => { //9
        val min = mM._1
        val max = mM._2
       (if(x < min) x else min, if(x > max) x else max)
      })
      case class ScaleFactors(low:Double ,high:Double, ratio: Double)
      var scaleFactors: Option[ScaleFactors] = None //10
    
      def min = minMax._1
      def max = minMax._2
      def normalize(low: Double, high: Double): DblVector //11
      def normalize(value: Double): Double
    }
    

类构造函数使用折叠(第 9 行)计算最小值和最大值的元组minMaxscaleFactors缩放参数在时间序列归一化期间计算(第 11 行),具体描述如下。normalize方法在归一化输入数据之前初始化缩放因子参数(第 12 行):

def normalize(low: Double, high: Double): DblVector = 
  setScaleFactors(low, high).map( scale => { //12
    values.map(x =>(x - min)*scale.ratio + scale.low) //13
  }).getOrElse(/* … */)

def setScaleFactors(l: Double, h: Double): Option[ScaleFactors]={
    // .. error handling code
   Some(ScaleFactors(l, h, (h - l)/(max - min))
}

后续观察使用从normalize中提取的相同缩放因子(第 14 行):

def normalize(value: Double):Double = setScaleFactors.map(scale => 
   if(value < min) scale.low
   else if (value > max) scale.high
   else (value - min)* scale.high + scale.low
).getOrElse( /* … */)

MinMax类对单变量观测值进行归一化。

注意

统计学类

数据概览部分介绍的Stats数据集中提取基本统计信息的类,在第二章的Hello World!中继承自MinMax类。

使用多项式逻辑回归的测试用例通过MinMaxVector类实现了多变量归一化,该类将XVSeries[Double]类型的观测值作为输入:

class MinMaxVector(series: XVSeries[Double]) {
  val minMaxVector: Vector[MinMax[Double]] = //15
      series.transpose.map(new MinMaxDouble)
  def normalize(low: Double, high: Double): XVSeries[Double]
}

MinMaxVector类的构造函数将观测值的数组向量转置,以便计算每个维度的最小值和最大值(第 15 行)。

第 4 步 – 发现模式

价格行为图表有一个非常有趣的特性。

分析数据

仔细观察,价格突然变动和交易量增加大约每三个月发生一次。经验丰富的投资者无疑会认识到这些价格-交易量模式与思科公司季度收益的发布有关。这种规律但不可预测的模式,如果风险得到适当管理,可能成为担忧或机会的来源。股票价格对公司收益发布的强烈反应可能会吓到一些长期投资者,同时吸引日交易者。

下面的图表展示了突然的价格变动(波动性)和大量交易量之间的潜在相关性:

分析数据

思科股票 2012-2013 年的价格-交易量相关性

下一节对于理解测试用例不是必需的。它展示了 JFreeChart 作为一个简单的可视化和绘图库的能力。

数据绘图

虽然图表绘制不是本书的主要目标,但我们认为您将受益于对 JFreeChart 的简要介绍。

注意

绘图类

本节展示了 JFreeChart Java 类的一个简单 Scala 接口。阅读本节内容对于理解机器学习不是必需的。计算结果的可视化超出了本书的范围。

在可视化中使用的某些类在附录 A 中进行了描述,基本概念

将数据集(波动性和成交量)转换为 JFreeChart 内部数据结构。ScatterPlot类实现了一个简单的可配置散点图,具有以下参数:

  • config:这包括图表的信息、标签、字体等。

  • theme:这是图表的预定义主题(黑色、白色背景等)。

代码如下:

class ScatterPlot(config: PlotInfo, theme: PlotTheme) { //16
  def display(xy: Vector[DblPair], width: Int, height) //17
  def display(xt: XVSeries[Double], width: Int, height)
  // ….
}

PlotTheme类定义了图表的特定主题或预配置(第 16 行)。该类提供了一套display方法,以适应广泛的数据结构和配置(第 17 行)。

注意

可视化

JFreeChart 库被介绍为一个健壮的图表工具。为了使代码片段简洁并专注于机器学习,本书省略了与绘图和图表相关的代码。在少数情况下,输出数据格式化为 CSV 文件,以便导入电子表格。

ScatterPlot.display方法用于显示在二项式逻辑回归中使用的标准化输入数据,如下所示:

val plot = new ScatterPlot(("CSCO 2012-2013", 
   "Session High - Low", "Session Volume"), new BlackPlotTheme)
plot.display(volatility_vol, 250, 340)

绘制数据

2012-2013 年思科股票的波动性和成交量散点图

散点图显示了时段交易量和时段波动性之间的相关性水平,并证实了股票价格和成交量图表中的初步发现。我们可以利用这些信息根据波动性和成交量对交易时段进行分类。下一步是通过将训练集、观察值和期望值加载到我们的逻辑回归算法中,创建一个双分类模型。类别由散点图上绘制的决策边界(也称为超平面)分隔。

可视化标签——选择交易时段开盘价和收盘价之间股票价格的标准化变化作为此分类器的标签。

第 5 步 – 实现分类器

本次训练的目的是构建一个能够区分波动性和非波动性交易时段的模型。为了练习,将时段波动性定义为时段最高价与最低价之间的相对差异。时段内的总交易量构成模型的第二个参数。交易时段内的相对价格变动(即收盘价/开盘价 - 1)是我们期望的值或标签。

逻辑回归在统计学推断中常用。

提示

M4:逻辑回归模型

第 5 步 - 实现分类器

第一个权重 w[0] 被称为截距。二项式逻辑回归在 第六章 的 逻辑回归 部分中详细描述,回归和正则化

以下二项式逻辑回归分类器的实现公开了一个单一的 classify 方法,以满足我们减少复杂性和对象生命周期的方法。模型 weights 参数在实例化 LogBinRegression 类/模型时计算。如前所述,省略了代码中与理解算法无关的部分。

LogBinRegression 构造函数有五个参数(第 18 行):

  • obsSet:这些是表示体积和波动的向量观测值

  • expected:这是一个预期值的向量

  • maxIters:这是优化器在训练期间提取回归权重允许的最大迭代次数

  • eta:这是学习或训练速率

  • eps:这是模型有效的最大误差值(预测值—预期值

代码如下:

class LogBinRegression(
     obsSet: Vector[DblArray], 
     expected: Vector[Int],
     maxIters: Int, 
     eta: Double, 
     eps: Double) {  //18

   val model: LogBinRegressionModel = train  //19
   def classify(obs: DblArray): Try[(Int, Double)]   //20
   def train: LogBinRegressionModel
   def intercept(weights: DblArray): Double
   …
}

LogBinRegressionModel 模型是在实例化 LogBinRegression 逻辑回归类(第 19 行)时通过训练生成的:

case class LogBinRegressionModel(val weights: DblArray)

模型完全由其权重定义,如数学公式 M3 所述。weights(0) 截距表示变量为零的观测值的预测平均值。对于大多数情况,截距没有特定的含义,并且不一定可计算。

注意

是否包含截距?

截距对应于观测值为空值时的权重值。在可能的情况下,独立于模型斜率的最小化误差函数来估计二项式线性或逻辑回归的截距是一种常见做法。多项式回归模型将截距或权重 w[0] 视为回归模型的一部分,如 第六章 的 普通最小二乘回归 部分所述,回归和正则化

代码如下:

def intercept(weights: DblArray): Double = {
  val zeroObs = obsSet.filter(!_.exists( _ > 0.01))
  if( zeroObs.size > 0)
    zeroObs.aggregate(0.0)((s,z) => s + dot(z, weights), 
       _ + _ )/zeroObs.size
  else 0.0
}

classify 方法接受新的观测值作为输入,并计算观测值所属的类别索引(0 或 1)以及实际的似然值(第 20 行)。

选择优化器

使用预期值训练模型的目标是计算最优权重,以最小化 误差成本函数。我们选择 批量梯度下降 算法来最小化预测值和预期值之间所有观测值的累积误差。尽管有相当多的替代优化器,但梯度下降对于本章来说足够稳健且简单。该算法通过最小化成本来更新回归模型的权重 w[i]

注意

代价函数

M5:代价(或 复合误差 = 预测值 - 预期值):

选择优化器

M6:更新模型权重 w[i] 的批量梯度下降方法如下:

选择优化器

对于那些对学习优化技术感兴趣的人来说,附录 A 中的“优化技术总结”部分,基本概念部分概述了最常用的优化器。批量下降梯度法也用于多层感知器的训练(参见第九章下的训练周期部分,人工神经网络)。

批量梯度下降算法的执行遵循以下步骤:

  1. 初始化回归模型的权重。

  2. 打乱观测值和预期值的顺序。

  3. 聚合整个观测集的代价或误差。

  4. 使用代价作为目标函数来更新模型权重。

  5. 从步骤 2 重复,直到达到最大迭代次数或代价的增量更新接近零。

在迭代之间打乱观测顺序的目的是为了避免代价最小化达到局部最小值。

小贴士

批量和随机梯度下降

随机梯度下降是梯度下降的一种变体,它在计算每个观测的误差后更新模型权重。尽管随机梯度下降需要更高的计算努力来处理每个观测,但在经过少量迭代后,它相当快地收敛到权重的最优值。然而,随机梯度下降对权重的初始值和学习率的选取非常敏感,学习率通常由自适应公式定义。

训练模型

train 方法通过简单的梯度下降法迭代计算权重。该方法计算 weights 并返回一个 LogBinRegressionModel 模型实例:

def train: LogBinRegressionModel = {
  val nWeights = obsSet.head.length + 1  //21
  val init = Array.fill(nWeights)(Random.nextDouble )  //22
  val weights = gradientDescent(obsSet.zip(expected),0.0,0,init)
  new LogBinRegressionModel(weights)   //23
}

train 方法提取回归模型的权重数量 nWeights 作为 每个观测变量的数量 + 1(行 21)。该方法使用 [0, 1] 范围内的随机值初始化 weights(行 22)。权重通过尾递归的 gradientDescent 方法计算,该方法返回二元逻辑回归的新模型(行 23)。

小贴士

从 Try 中解包值

通常不建议对 Try 值调用 get 方法,除非它被包含在 Try 语句中。最佳做法是执行以下操作:

1. 使用 match{ case Success(m) => ..case Failure(e) =>} 捕获失败

2. 安全地提取 getOrElse( /* … */ ) 的结果

3. 将结果作为 Try 类型的 map( _.m) 传播

让我们看看 gradientDescent 方法中通过最小化成本函数来计算 weights 的过程:

type LabelObs = Vector[(DblArray, Int)]

@tailrec
def gradientDescent(
      obsAndLbl: LabelObs, 
      cost: Double, 
      nIters: Int, 
      weights: DblArray): DblArray = {  //24

  if(nIters >= maxIters) 
       throw new IllegalStateException("..")//25
  val shuffled = shuffle(obsAndLbl)   //26
  val errorGrad = shuffled.map{ case(x, y) => {  //27
      val error = sigmoid(dot(x, weights)) - y
      (error, x.map( _ * error))  //28
   }}.unzip

   val scale = 0.5/obsAndLbl.size
   val newCost = errorGrad._1   //29
.aggregate(0.0)((s,c) =>s + c*c, _ + _ )*scale
   val relativeError = cost/newCost - 1.0

   if( Math.abs(relativeError) < eps)  weights  //30
   else {
     val derivatives = VectorDouble ++ 
                 errorGrad._2.transpose.map(_.sum) //31
     val newWeights = weights.zip(derivatives)
                       .map{ case (w, df) => w - eta*df)  //32
     newWeights.copyToArray(weights)
     gradientDescent(shuffled, newCost, nIters+1, newWeights)//33
   }
}

gradientDescent 方法递归于包含观察和预期值的向量 obsAndLblcost 和模型 weights(第 24 行)。如果达到允许的优化最大迭代次数,则抛出异常(第 25 行)。在计算每个权重的成本 errorGrad 导数之前(第 27 行),它会打乱观察的顺序(第 26 行)。成本导数(或 error = 预测值 - 预期值)的计算(公式 M5)使用公式返回累积成本和导数值(第 28 行)。

接下来,该方法使用公式 M4 (第 29 行) 计算整体复合成本,将其转换为相对增量 relativeError 成本,并与 eps 收敛标准(第 30 行)进行比较。该方法通过转置误差矩阵来提取成本关于权重的 derivatives,然后将偏置 1.0 值添加到数组中,以匹配权重数组(第 31 行)。

注意

偏置值

偏置值的目的是在向量中添加 1.0,以便可以直接与权重(例如,压缩和点积)进行处理。例如,对于二维观察(x,y)的回归模型(w[0],w[1],w[2])具有三个权重。偏置值 +1 添加到观察中,以计算预测值 1.0:w[0] + x.w[1],+ y.w[2]

这种技术在多层感知器的激活函数计算中得到了应用,如 第九章 中 多层感知器 部分所述,人工神经网络

公式 M6 在调用带有新权重、成本和迭代计数的函数之前(第 33 行)更新下一次迭代的权重(第 32 行)。

让我们通过随机序列生成器来查看观察顺序的打乱。以下实现是 Scala 标准库方法 scala.util.Random.shuffle 对集合元素进行打乱的替代方案。目的是在迭代之间改变观察和标签的顺序,以防止优化器达到局部最小值。shuffle 方法通过将 labelObs 观察向量划分为随机大小的段并反转其他段的顺序来重新排列观察的 labelObs 向量:

val SPAN = 5
def shuffle(labelObs: LabelObs): LabelObs = { 
  shuffle(new ArrayBuffer[Int],0,0).map(labelObs( _ )) //34
}

一旦更新了观察的顺序,通过映射(第 34 行)可以轻松构建(观察,标签)对的向量。实际的索引打乱是在下面的 shuffle 递归函数中执行的:

val maxChunkSize = Random.nextInt(SPAN)+2  //35

@tailrec
def shuffle(indices: ArrayBuffer[Int], count: Int, start: Int): 
      Array[Int] = {
  val end = start + Random.nextInt(maxChunkSize) //36
  val isOdd = ((count & 0x01) != 0x01)
  if(end >= sz) 
    indices.toArray ++ slice(isOdd, start, sz) //37
  else 
    shuffle(indices ++ slice(isOdd, start, end), count+1, end)
}

maxChunkSize向量观察值的最大分区大小是随机计算的(第 35 行)。该方法提取下一个切片(startend)(第 36 行)。该切片要么在所有观察值都打乱后添加到现有的索引向量中并返回,要么传递给下一次调用。

slice方法返回一个索引数组,该数组覆盖范围(startend),如果处理的段数是奇数,则按正确顺序返回,如果是偶数,则按相反顺序返回:

def slice(isOdd: Boolean, start: Int, end: Int): Array[Int] = {
  val r = Range(start, end).toArray
  (if(isOdd) r else r.reverse)
}

注意

迭代与尾递归计算

Scala 中的尾递归是迭代算法的一个非常高效的替代方案。尾递归避免了为方法每次调用创建新的栈帧的需要。它应用于本书中提出的许多机器学习算法的实现。

为了训练模型,我们需要标记输入数据。标记过程包括将一个会话期间的相对价格变动(收盘价/开盘价 - 1)与以下两种配置之一相关联:

  • 交易量高、波动性高的交易时段

  • 交易量低、波动性低的交易时段

在前一个部分中绘制的散点图上的决策边界将两类训练观察值隔离开。标记过程通常相当繁琐,应尽可能自动化。

注意

自动标记

虽然非常方便,但自动创建训练标签并非没有风险,因为它可能会错误地标记单个观察值。这种技术在本次测试中出于方便而使用,但除非领域专家手动审查标签,否则不建议使用。

对观察进行分类

一旦通过训练成功创建模型,它就可以用于对新观察值进行分类。通过classify方法实现的二项逻辑回归的观察值运行时分类如下:

def classify(obs: DblArray): Try[(Int, Double)] = 
  val linear = dot(obs, model.weights)  //37
  val prediction = sigmoid(linear)
  (if(linear > 0.0) 1 else 0, prediction) //38
})

该方法将逻辑函数应用于模型的新obsweights观察值的线性内积linear(第 37 行)。该方法返回一个元组(观察的预测类别{0, 1},预测值),其中类别是通过将预测值与边界值0.0(第 38 行)进行比较来定义的。

权重和观察值的点积计算使用偏差值如下:

def dot(obs: DblArray, weights: DblArray): Double =
   weights.zip(ArrayDouble ++ obs)
          .aggregate(0.0){case (s, (w,x)) => s + w*x, _ + _ }

权重和观察值的点积的替代实现是提取第一个w.head权重:

def dot(x: DblArray, w: DblArray): Double = 
  x.zip(w.drop(1)).map {case (_x,_w) => _x*_w}.sum + w.head

classify方法中使用dot方法。

第 6 步 – 评估模型

第一步是定义测试的配置参数:最大迭代次数NITERS,收敛标准EPS,学习率ETA,用于标记BOUNDARY训练观察值的决策边界,以及训练集和测试集的路径:

val NITERS = 800; val EPS = 0.02; val ETA = 0.0001
val path_training = "resources/data/chap1/CSCO.csv"
val path_test = "resources/data/chap1/CSCO2.csv"

创建和测试模型的各种活动,包括加载数据、归一化数据、训练模型、加载数据和分类测试数据,是通过使用Try类的单子组合作为一个工作流程组织的:

for {
  volatilityVol <- load(path_training)    //39
  minMaxVec <- Try(new MinMaxVector(volatilityVol))    //40
  normVolatilityVol <- Try(minMaxVec.normalize(0.0,1.0))//41
  classifier <- logRegr(normVolatilityVol)    //42
  testValues <- load(path_test)    //43
  normTestValue0 <- minMaxVec.normalize(testValues(0))  //44
  class0 <- classifier.classify(normTestValue0)   //45
  normTestValue1 <- minMaxVec.normalize(testValues(1))    
  class1 <- classifier.classify(normTestValues1)
} yield {
   val modelStr = model.toString
   …
}

首先,从文件中加载volatilityVol股票价格的每日交易波动性和成交量(第39行)。工作流程初始化多维MinMaxVec归一化器(第40行),并使用它来归一化训练集(第41行)。logRegr方法实例化了二项式classifier逻辑回归(第42行)。从文件中加载testValues测试数据(第43行),使用已应用于训练数据的MinMaxVec进行归一化(第44行),并进行分类(第45行)。

load方法从文件中提取XVSeries[Double]类型的data(观测值)。繁重的工作由extract方法(第46行)完成,然后在返回原始观测值向量之前关闭文件句柄(第47行):

def load(fileName: String): Try[XVSeries[Double], XSeries[Double]] =  {
  val src =  Source.fromFile(fileName)
  val data = extract(src.getLines.map( _.split(",")).drop(1)) //46
  src.close; data //47
}

私有的logRegr方法有两个目的:

  • 自动标记obs观测值以生成expected值(第48行)

  • 初始化(实例化和训练)二项式逻辑回归(第49行)

代码如下:

def logRegr(obs: XVSeries[Double]): Try[LogBinRegression] = Try {
    val expected = normalize(labels._2).get  //48
    new LogBinRegression(obs, expected, NITERS, ETA, EPS)  //49
}

该方法通过评估观测值是否属于由BOUNDARY条件定义的两个类别之一来标记观测值,如图中前一个部分的散点图所示。

注意

验证

在这个测试案例中提供的简单分类是为了说明模型的运行时应用。这无论如何都不能构成对模型的验证。下一章将深入研究验证方法(请参阅第二章中的评估模型部分,Hello World!

训练运行使用三个不同的学习率值进行。以下图表说明了不同学习率值下批量梯度下降在成本最小化过程中的收敛:

第 6 步 – 评估模型

学习率对批量梯度下降在成本(误差)收敛中的影响

如预期的那样,使用较高学习率的优化器在成本函数中产生了最陡的下降。

测试执行产生了以下模型:

迭代次数 = 495

权重:0.859-3.6177923,-64.927832

输入(0.0088,4.10E7)归一化(0.063,0.061)类别 1 预测 0.515

输入(0.0694,3.68E8)归一化(0.517,0.641)类别 0 预测 0.001

注意

了解更多关于回归模型

二项式逻辑回归仅用于说明训练和预测的概念。它在第六章的逻辑回归部分,回归和正则化中进行了详细描述。

摘要

希望你喜欢这篇机器学习的介绍。你学习了如何利用 Scala 编程技能创建一个简单的逻辑回归程序来预测股价/量价走势。以下是本章的要点:

  • 从单调组合和高阶集合方法用于并行化到可配置性和重用模式,Scala 是实施大规模项目中的数据挖掘和机器学习算法的完美选择。

  • 创建和部署机器学习模型有许多逻辑步骤。

  • 作为测试用例一部分呈现的二项逻辑回归分类器的实现足够简单,足以鼓励你学习如何编写和应用更高级的机器学习算法。

让 Scala 编程爱好者感到高兴的是,下一章将深入探讨通过利用单调数据转换和可堆叠特性来构建灵活的工作流程。

第二章。你好,世界!

在第一章中,你熟悉了一些关于数据处理、聚类和分类的基本概念。本章致力于创建和维护一个灵活的端到端工作流程来训练和分类数据。本章的第一节介绍了一种以数据为中心(函数式)的方法来创建数据处理应用。

你将学习如何:

  • 将单调设计概念应用于创建动态工作流程

  • 利用 Scala 的一些高级模式,如蛋糕模式,来构建可移植的计算工作流程

  • 在选择模型时考虑偏差-方差权衡

  • 克服建模中的过拟合问题

  • 将数据分解为训练集、测试集和验证集

  • 使用精确度、召回率和 F 分数在 Scala 中实现模型验证

建模

数据是任何科学家的生命线,选择数据提供者对于开发或评估任何统计推断或机器学习算法至关重要。

任何名称的模型

我们在第一章的模型分类部分简要介绍了模型的概念,入门

什么是模型?维基百科为科学家理解的模型提供了一个相当好的定义 [2:1]:

科学模型旨在以逻辑和客观的方式表示经验对象、现象和物理过程。

在软件中呈现的模型允许科学家利用计算能力来模拟、可视化、操作并获取对所表示实体、现象或过程的直观理解。

在统计学和概率论中,模型描述了从系统可能观察到的数据,以表达任何形式的不确定性和噪声。模型使我们能够推断规则、做出预测并从数据中学习。

一个模型由特征组成,也称为属性变量,以及这些特征之间的一组关系。例如,由函数 f(x, y) = x.sin(2y) 表示的模型有两个特征,xy,以及一个关系,f。这两个特征被认为是独立的。如果模型受到如 f(x, y) < 20 这样的约束,那么条件独立性就不再有效。

一个敏锐的 Scala 程序员会将一个模型与一个幺半群关联起来,其中集合是一组观测值,运算符是实现模型的函数。

模型有多种形状和形式:

  • 参数化:这包括函数和方程(例如,y = sin(2t + w)

  • 微分:这包括常微分方程和偏微分方程(例如,dy = 2x.dx

  • 概率论:这包括概率分布(例如,p(x|c) = exp (k.logx – x)/x!)

  • 图形表示:这包括抽象变量之间条件独立性的图(例如,p(x,y|c) = p(x|c).p(y|c)

  • 有向图:这包括时间和空间关系(例如,一个调度器)

  • 数值方法:这包括有限差分、有限元或牛顿-拉夫森等计算方法

  • 化学:这包括公式和成分(例如,H[2]O, Fe + C[12] = FeC[13],等等)

  • 分类学:这包括概念的含义和关系(例如,APG/Eudicots/Rosids/Huaceae/Malvales

  • 语法和词汇:这包括文档的句法表示(例如,Scala 编程语言)

  • 推理逻辑:这包括规则(例如,IF (stock vol > 1.5 * average) AND rsi > 80 THEN …

模型与设计

在计算机科学中,模型和设计的混淆相当常见,原因在于这些术语对不同的人来说有不同的含义,这取决于主题。以下隐喻应该有助于您理解这两个概念:

  • 建模:这描述了您所知道的东西。一个模型做出假设,如果得到证实,则成为断言(例如,美国人口,p,每年增加 1.2%,dp/dt = 1.012)。

  • 设计:这操作未知事物的表示。设计可以被视为建模的探索阶段(例如,哪些特征有助于美国人口的增长?出生率?移民?经济条件?社会政策?)。

选择特征

选择模型特征的过程是发现和记录构建模型所需的最小变量集。科学家们假设数据包含许多冗余或不相关的特征。冗余特征不提供已由所选特征提供的信息,而不相关的特征不提供任何有用的信息。

特征选择包括两个连续的步骤:

  1. 寻找新的特征子集。

  2. 使用评分机制评估这些特征子集。

对于大型数据集,评估每个可能的特征子集以找到最大化目标函数或最小化错误率的进程在计算上是不可行的。具有n个特征的模型需要2^n-1次评估。

提取特征

观测是一组对隐藏的、也称为潜在变量的间接测量,这些变量可能是有噪声的或包含高度的相关性和冗余。在分类任务中使用原始观测可能会导致不准确的结果。在每个观测中使用所有特征也会产生很高的计算成本。

特征提取的目的是通过消除冗余或不相关的特征来减少模型中的变量或维度数量。特征提取是通过将原始观测集转换为更小的集合来进行的,这可能会丢失原始集合中嵌入的一些重要信息。

定义一个方法

数据科学家在选择和实现分类或聚类算法时有很多选择。

首先,需要选择一个数学或统计模型来从原始输入数据或数据上游转换的输出中提取知识。模型的选择受到以下参数的限制:

  • 商业需求,例如结果准确性或计算时间

  • 训练数据、算法和库的可用性

  • 如果需要,访问领域或主题专家

其次,工程师必须选择一个适合处理的数据量的计算和部署框架。计算环境将由以下参数定义:

  • 可用资源,例如机器、CPU、内存或 I/O 带宽

  • 一种实现策略,例如迭代计算或缓存

  • 对整体过程响应性的要求,例如计算持续时间或显示中间结果

第三,领域专家必须标记或标注观测,以便生成准确的分类器。

最后,模型必须与可靠的测试数据集进行验证。

以下图表说明了创建工作流程的选择过程:

定义一个方法

机器学习应用中的统计和计算建模

注意

领域专业知识、数据科学和软件工程

领域或主题专家是在特定领域或主题上具有权威或认可的专业知识的人。化学家是化学领域的专家,可能是相关领域。

数据科学家在生物科学、医疗保健、营销或金融等众多领域解决与数据相关的问题。数据挖掘、信号处理、统计分析以及使用机器学习算法进行建模是数据科学家执行的一些活动。

软件开发者执行与创建软件应用程序相关的所有任务,包括分析、设计、编码、测试和部署。

数据转换的参数可能需要根据上游数据转换的输出重新配置。Scala 的高阶函数特别适合实现可配置的数据转换。

单子数据转换

第一步是定义一个特性和方法,描述工作流程的计算单元对数据的转换。数据转换是任何处理和分类数据集、训练和验证模型以及显示结果的工作流程的基础。

定义数据处理或数据转换时使用了两种符号模型:

  • 显式模型:开发者从一组配置参数显式创建模型。大多数确定性算法和无监督学习技术使用显式模型。

  • 隐式模型:开发者提供一组标记观察(具有预期结果的观察)的训练集。分类器通过训练集提取模型。监督学习技术依赖于从标记数据隐式生成的模型。

错误处理

数据转换的最简单形式是两种类型UV之间的同态。数据转换强制执行一个契约,用于验证输入并返回值或错误。从现在开始,我们使用以下约定:

  • 输入值:验证通过返回数据转换的PartialFunction类型部分函数实现。如果输入值不符合所需条件(契约),则抛出MatchErr错误。

  • 输出值:返回值的类型为Try[V],在发生错误时返回异常。

注意

部分函数的可重用性

可重用性是部分函数的另一个好处,以下代码片段展示了这一点:

class F { 
  def f: PartialFunction[Int, Try[Double]] { case n: Int … 
  }
}
val pfn = (new F).f
pfn(4)
pfn(10)

部分函数允许开发者实现针对最常见(主要)用例的方法,这些用例的输入值已经过测试。所有其他非平凡用例(或输入值)都会生成MatchErr异常。在开发周期的后期,开发者可以实施代码来处理较少见的用例。

注意

部分函数的运行时验证

验证一个部分函数是否为特定参数值定义是一个好的实践:

for {
  pfn.isDefinedAt(input)
  value <- pfn(input)
} yield { … }

这种先发制人的方法允许开发者选择一个替代方法或完整功能。这是捕获MathErr异常的有效替代方案。为了清晰起见,本书中省略了对部分函数的验证。

因此,数据转换的签名被定义为如下:

def |> : PartialFunction[U, Try[V]]

注意

F#语言参考

作为转换签名的|>符号是从 F#语言[2:2]借用的。

显式模型

目标是定义不同类型数据的符号表示,而不暴露实现数据转换的算法的内部状态。数据集上的转换是通过用户完全定义的模型或配置来执行的,如下面的图示所示:

显式模型

显式模型的可视化

显式配置或模型config的转换被定义为ETransform抽象类的一个参数化类型T

abstract class ETransformT { //explicit model
  type U   // type of input
  type V   // type of output
  def |> : PartialFunction[U, Try[V]]  // data transformation
}

输入U类型和输出V类型必须在ETransform的子类中定义。|>转换运算符返回一个部分函数,可以用于不同的输入值。

创建一个实现特定转换的显式配置的类相当简单:你所需要的就是定义一个输入/输出U/V类型以及实现|>转换方法。

让我们考虑从金融数据源DataSource中提取数据,该数据源接受一个将某些文本字段Fields转换为Double值的函数列表作为输入,并产生一个XSeries[Double]类型的观察列表。提取参数在DataSourceConfig类中定义:

class DataSource(
  config: DataSourceConfig,   //1
  srcFilter: Option[Fields => Boolean]= None)
        extends ETransformDataSourceConfig { //2
  type U = List[Fields => Double]   //3
  type V = List[XSeries[Double]]     //4
  override def |> : PartialFunction[U, Try[V]] = { //5
    case u: U if(!u.isEmpty) => … 
  }
}

DataSourceConfig配置作为DataSource构造函数的参数显式提供(行1)。构造函数实现了与显式模型相关的基本类型和数据转换(行2)。该类定义了输入值的U类型(行3),输出值的V类型(行4),以及返回部分函数的|>转换方法(行5)。

注意

DataSource

附录 A 中的数据提取部分,基本概念描述了DataSource类的功能。本书中使用了DataSource类。

使用显式模型或配置进行数据转换构成一个具有单调运算的类别。与ETransform类子类关联的单子继承自高阶单子的定义_Monad

private val eTransformMonad = new _Monad[ETransform] {
  override def unitT = eTransform(t)   //6
  override def mapT,U     //7
      (f: T => U): ETransform[U] = eTransform( f(m.config) )
  override def flatMapT,U  //8
      (f: T =>ETransform[U]): ETransform[U] = f(m.config)
}

单例eTransformMonad实现了在第一章中抽象部分的Monads节下引入的以下基本单调运算符,入门

  • unit方法用于实例化ETransform(行6

  • 使用map通过变形其元素(第7行)来转换一个ETransform对象。

  • 使用flatMap通过实例化其元素(第8行)来转换一个ETransform对象。

为了实际应用,创建了一个隐式类,将ETransform对象转换为相关的单子,允许透明访问unitmapflatMap方法:

implicit class eTransform2MonadT {
  def unit(t: T) = eTransformMonad.unit(t)
  final def mapU: ETransform[U] = 
      eTransformMonad.map(fct)(f)
  final def flatMapU: ETransform[U] =
      eTransformMonad.flatMap(fct)(f)
}

隐式模型

监督学习模型是从训练集中提取的。如分类或回归等转换使用隐式模型来处理输入数据,如下面的图所示:

隐式模型

隐式模型的可视化

从训练数据隐式提取的模型的转换定义为由观察类型Txt参数化的抽象ITransform类:

abstract class ITransformT { //Model input
   type V   // type of output
   def |> : PartialFunction[T, Try[V]]  // data transformation
}

数据集合的类型是Vector,它是一个不可变且有效的容器。通过定义观察的T类型、数据转换的V输出以及实现转换的|>方法(通常是一个分类或回归)来创建ITransform类型。让我们以支持向量机算法SVM为例,说明使用隐式模型实现数据转换的实现方式:

class SVMT <: AnyVal(implicit f: T => Double)
  extends ITransform[Array[T]](xt) {//10

 type V = Double  //11
 override def |> : PartialFunction[Array[T], Try[V]] = { //12
     case x: Array[T] if(x.length == data.size) => ...
  }

支持向量机是一种在第八章《核模型与支持向量机》中描述的判别式监督学习算法。支持向量机(SVM)通过配置和训练集实例化:xt观察数据和expected数据(第9行)。与显式模型相反,config配置不定义用于数据转换的模型;模型是从xt输入数据的训练集和expected值中隐式生成的。通过指定V输出类型(第11行)并覆盖|>转换方法(第12行),创建了一个ITransform实例。

|>分类方法产生一个部分函数,它接受一个x观察值作为输入,并返回一个Double类型的预测值。

与显式转换类似,我们通过覆盖unit(第13行)、map(第14行)和flatMap(第15行)方法来定义ITransform的单子操作:

private val iTransformMonad = new _Monad[ITransform] {
  override def unitT = iTransform(VectorT)  //13

  override def mapT,U(f: T => U): 
ITransform[U] = iTransform( m.xt.map(f) )   //14

  override def flatMapT,U  
    (f: T=>ITransform[U]): ITransform[U] = 
 iTransform(m.xt.flatMap(t => f(t).xt)) //15
}

最后,让我们创建一个隐式类,自动将ITransform对象转换为相关的单子,以便它可以透明地访问unitmapflatMap单子方法:

implicit class iTransform2MonadT {
   def unit(t: T) = iTransformMonad.unit(t)

   final def mapU: ITransform[U] = 
      iTransformMonad.map(fct)(f)
   final def flatMapU: ITransform[U] = 
      iTransformMonad.flatMap(fct)(f)
   def filter(p: T =>Boolean): ITransform[T] =  //16
      iTransform(fct.xt.filter(p))
}

filter 方法严格来说不是单子的运算符(行 16)。然而,它通常被包含以约束(或保护)一系列变换(例如,用于理解闭包)。如 第一章 的 源代码 部分的 演示 部分所述,入门,与异常、错误检查和参数验证相关的代码被省略。

注意

不可变变换

数据变换(或处理单元或分类器)类的模型应该是不可变的。任何修改都将改变模型或用于处理数据的参数的完整性。为了确保在整个变换的生命周期中始终使用相同的模型来处理输入数据,我们执行以下操作:

  • ETransform 的模型定义为构造函数的参数。

  • ITransform 的构造函数从给定的训练集中生成模型。如果模型提供不正确的结果或预测,则必须从训练集中重新构建模型(而不是修改)。

模型由分类器或数据变换类的构造函数创建,以确保其不可变性。不可变变换的设计在 附录 A 的 Scala 编程 部分的 不可变分类器设计模板 部分中描述,基本概念

工作流计算模型

单子对于使用隐式配置或显式模型操作和链式数据变换非常有用。然而,它们被限制为单个形态 T => U 类型。更复杂和灵活的工作流程需要使用通用工厂模式编织不同类型变换。

传统的工厂模式依赖于组合和继承的组合,并且不向开发者提供与可堆叠特质相同级别的灵活性。

在本节中,我们向您介绍使用混入和蛋糕模式变体进行建模的概念,以提供一个具有三个配置级别的流程。

支持数学抽象

可堆叠特质使开发者能够在 Scala 中实现模型时遵循严格的数学形式主义。科学家使用一个普遍接受的模板来解决数学问题:

  1. 声明与问题相关的变量。

  2. 定义一个模型(方程、算法、公式等)作为问题的解决方案。

  3. 实例化变量并执行模型以解决问题。

让我们考虑核函数的概念示例(在第八章的 核函数 部分中描述,第八章,核模型和支持向量机),这是一个由两个数学函数的组合及其在 Scala 中的潜在实现的模型。

第 1 步 – 变量声明

实现包括将两个函数包装(作用域)到特质中,并将这些函数定义为抽象值。

数学正式性如下:

步骤 1 – 变量声明

Scala 实现如下:

type V = Vector[Double]
trait F { val f: V => V}
trait G { val g: V => Double }

步骤 2 – 模型定义

模型定义为两个函数的组合。GF特质栈描述了可以使用自引用的self: G with F约束组合的兼容函数类型。

正式性将是 h = f o g

Scala 实现如下:

class H {self: G with F => def apply(v:V): Double =g(f(v))}

步骤 3 – 实例化

模型在fg变量实例化后执行。

正式性如下:

步骤 3 – 实例化

Scala 实现如下:

val h = new H with G with F {
  val f: V => V = (v: V) => v.map(Math.exp(_))
  val g: V => Double = (v: V) => v.sum
}

注意

懒值触发器

在前面的例子中,h(v) = g(f(v)) 的值可以在 gf 初始化后自动计算,通过将 h 声明为懒值。

显然,Scala 保留了数学模型的正式性,这使得科学家和开发者更容易将用科学导向语言编写的现有项目迁移到 Scala。

注意

R 的模拟

大多数数据科学家使用 R 语言创建模型并应用学习策略。在某些情况下,他们可能会将 Scala 视为 R 的替代品,因为 Scala 保留了在 R 中实现的模型所使用的数学正式性。

让我们将数学正式性的概念扩展到使用特质动态创建工作流程。下一节中描述的设计模式有时被称为 Cake 模式

通过组合 mixins 构建工作流程

本节介绍了 Cake 模式背后的关键构造。由可配置数据转换组成的流程需要动态模块化(替换)工作流程的不同阶段。

注意

特质和 mixins

Mixins 是堆叠在类上的特质。本节中描述的 mixin 组合和 Cake 模式对于定义数据转换的序列很重要,但这个主题与机器学习没有直接关系,因此您可以跳过这一节。

Cake 模式是一种高级类组合模式,它使用 mixin 特质来满足可配置计算工作流程的需求。它也被称为可堆叠修改特质 [2:4]。

这不是对 Scala 中的 可堆叠特质注入自引用 的深入分析。有一些关于依赖注入的有趣文章值得一看 [2:5]。

Java 依赖于与目录结构紧密耦合的包,并使用前缀来模块化代码库。Scala 为开发者提供了一种灵活且可重用的方法来创建和组织模块:特质。特质可以嵌套、与类混合、堆叠和继承。

理解问题

依赖注入是一个用于反向查找和绑定依赖关系的花哨名称。让我们考虑一个需要数据预处理、分类和验证的简单应用程序。使用特质的简单实现如下:

val app = new Classification with Validation with PreProcessing { 
   val filter = .. 
}

如果在后续阶段需要使用无监督聚类算法而不是分类器,那么应用程序必须重新布线:

val app = new Clustering with Validation with PreProcessing { 
    val filter = ..  
}

这种方法会导致代码重复和缺乏灵活性。此外,filter 类成员需要为应用程序组合中的每个新类重新定义。当组合中使用的特性之间存在依赖关系时,问题就出现了。让我们考虑这样一个案例,其中 filter 依赖于 validation 方法。

注意

混合函数的线性化 [2:6]

混合函数之间的线性化或方法调用遵循从右到左和基类到子类的模式:

  • 特性 B 扩展 A

  • 特性 C 扩展 A

  • M 扩展 NCB

Scala 编译器将线性化实现为 A => B => C => N

尽管你可以将 filter 定义为一个抽象值,但它仍然需要在引入新的验证类型时重新定义。解决方案是在新组成的 PreProcessingWithValidation 特质的定义中使用 self 类型:

trait PreProcessiongWithValidation extends PreProcessing {
   self: Validation => val filter = ..
}

应用程序是通过将 PreProcessingWithValidation 混合函数堆叠到 Classification 类上来构建的:

val app = new Classification with PreProcessingWithValidation {
   val validation: Validation
}

注意

用 val 覆盖 def

用具有相同签名的值声明覆盖方法声明是有利的。与在实例化期间一次性分配给所有值的值不同,方法可以为每次调用返回不同的值。一个 def 是一个可以重新定义为 defvallazy valproc。因此,你不应该用具有相同签名的值声明覆盖方法:

trait Validator { val g = (n: Int) =>  }trait MyValidator extends Validator { def g(n: Int) = …} //WRONG 

让我们调整并推广这个模式,构建一个样板模板,以便创建动态计算工作流程。

定义模块

第一步是生成不同的模块来封装不同类型的数据转换。

注意

描述蛋糕模式的用例

使用书中后面介绍过的类和算法构建一个真实世界工作流程的示例很困难。以下简单的示例足以说明蛋糕模式的不同组件:

让我们定义一个由三个参数化模块组成的序列,每个模块都使用 Etransform 类型的显式配置来定义特定的数据转换:

  • 采样:这用于从原始数据中提取样本

  • 归一化:这用于将样本数据归一化到 [0, 1] 范围内

  • 聚合:这用于聚合或减少数据

代码如下:

trait Sampling[T,A,B] { 
  val sampler: ETransform[T] { type U = A; type V = B }
}
trait Normalization[T,A,B] { 
  val normalizer: ETransform[T] { type U = A; type V = B }
  }
trait Aggregation[T,A,B] { 
  val aggregator: ETransform[T] { type U = A; type V = B }
}

模块包含一个抽象值。蛋糕模式的一个特点是通过对模块中封装的类型初始化抽象值来强制执行严格的模块化。构建框架的一个目标是在不依赖任何工作流程的情况下允许开发者独立创建数据转换(从ETransform继承)。

注意

Scala 特性和 Java 包

在模块化方面,Scala 和 Java 之间存在重大差异。Java 包将开发者约束在遵循严格的语法中,例如,源文件必须与包含的类同名。基于可堆叠特质的 Scala 模块要灵活得多。

实例化工作流程

下一步是将不同的模块写入工作流程。这是通过使用前一个部分中定义的三个特质的self引用到栈中实现的:

class Workflow[T,U,V,W,Z] {
  self: Sampling[T,U,V] with 
         Normalization[T,V,W] with 
           Aggregation[T,W,Z] =>
    def |> (u: U): Try[Z] = for {
      v <- sampler |> u
      w <- normalizer |> v
      z <- aggregator |> w
    } yield z
}

一图胜千言;以下 UML 类图说明了工作流程工厂(或蛋糕)设计模式:

实例化工作流程

工作流程工厂的 UML 类图

最后,通过动态初始化转换的samplernormalizeraggregator抽象值,只要签名(输入和输出类型)与每个模块中定义的参数化类型匹配(行1)来实例化工作流程:

type Dbl_F = Function1[Double, Double]
val samples = 100; val normRatio = 10; val splits = 4

val workflow = new Workflow[Int, Dbl_F, DblVector, DblVector,Int] 
      with Sampling[Int, Dbl_F, DblVector] 
         with Normalization[Int, DblVector, DblVector] 
            with Aggregation[Int, DblVector, Int] {
    val sampler = new ETransformInt { /* .. */} //1
    val normalizer = new ETransformInt { /*  .. */}
    val aggregator = new ETransformInt {/*  .. */}
}

让我们通过为抽象值分配转换来实现每个三个模块/特性的数据转换函数。

第一个转换,sampler,在区间[0, 1]上以频率1/samples采样f函数。第二个转换,normalizer,使用下一章中引入的Stats类在范围[0, 1]内归一化数据。最后一个转换,aggregator,提取大样本的索引(值 1.0):

val sampler = new ETransformInt { //2
  type U = Dbl_F  //3
  type V = DblVector  //4
  override def |> : PartialFunction[U, Try[V]] = { 
    case f: U => 
     Try(Vector.tabulate(samples)(n =>f(1.0*n/samples))) //5
  }
}

sampler转换使用单个模型或配置参数,sample(行2)。输入的U类型定义为Double => Double(行3),输出的V类型定义为浮点值向量,DblVector(行4)。在这种情况下,转换包括将输入的f函数应用于递增归一化值的向量(行5)。

normalizeraggregator转换遵循与sampler相同的模式:

val normalizer = new ETransformInt {
  type U = DblVector;  type V = DblVector
  override def |> : PartialFunction[U, Try[V]] = { case x: U 
    if(x.size >0) => Try((StatsDouble).normalize)
  }
}
val aggregator = new ETransformInt {
  type U = DblVector; type V = Int
  override def |> : PartialFunction[U, Try[V]] = case x: U 
    if(x.size > 0) => Try(Range(0,x.size).find(x(_)==1.0).get)
  }
}

转换函数的实例化遵循本章中“显式模型”部分中描述的模板。

工作流程现在可以处理任何函数作为输入:

val g = (x: Double) => Math.log(x+1.0) + Random.nextDouble
Try( workflow |> g )  //6

工作流程通过向第一个sampler混合提供输入g函数来执行(行6)。

Scala 的强类型检查在编译时捕捉任何不一致的数据类型。它减少了开发周期,因为运行时错误更难追踪。

注意

ITransform 的混合组成

我们任意选择了一个使用显式 ETransform 配置的数据转换来展示混入(mixins)组合的概念。相同的模式也适用于隐式 ITransform 数据转换。

模块化

最后一步是工作流程的模块化。对于复杂的科学计算,你需要能够做到以下几步:

  1. 根据执行目标(回归、分类、聚类等)选择适当的 工作流程 作为模块或任务的序列。

  2. 根据数据(噪声数据、不完整的训练集等)选择完成任务的适当 算法

  3. 根据环境(具有高延迟网络的分布式、单个主机等)选择算法的适当 实现模块化

    从模块/特性动态创建工作流程的示例

让我们考虑一个在 PreprocessingModule 模块中定义的简单预处理任务。该模块(或任务)被声明为一个特性,以隐藏其内部工作原理对其他模块的可见性。预处理任务由 Preprocessor 类型的预处理程序执行。我们任意列出两个算法:ExpMovingAverage 类型的指数移动平均和 DFTFilter 类型的离散傅里叶变换低通滤波器作为潜在的预处理程序:

trait PreprocessingModule[T] {
  trait Preprocessor[T] { //7
    def execute(x: Vector[T]): Try[DblVector] 
  } 
  val preprocessor: Preprocessor[T]//8

  class ExpMovingAverageT <: AnyVal
      (implicit num: Numeric[T], f: T =>Double) 
    extends Preprocessor[T] {

    val expMovingAvg = filtering.ExpMovingAverageT //10
    val pfn = expMovingAvg |>  //11
    override def execute(x: Vector[T]): Try[DblVector] = 
      pfn(x).map(_.toVector)
  }

   class DFTFilterT <: AnyVal
    (g: (Double,Double) =>Double) 
     (implicit f : T => Double)
   extends Preprocessor[T] { //12

     val filter = filtering.DFTFirT
     val pfn = filter |>
     override def execute(x: Vector[T]): Try[DblVector]=
        pfn(x).map(_.toVector)
   }
}

通用预处理特性 Preprocessor 声明了一个单一的 execute 方法,其目的是过滤 x 输入向量中 T 类型的元素以去除噪声(第 7 行)。预处理器的实例被声明为一个抽象类,以便作为过滤算法之一进行实例化(第 8 行)。

ExpMovingAverage 类型的第一个过滤算法实现了 Preprocessor 特性并覆盖了 execute 方法(第 9 行)。该类声明了算法,但其实现委托给具有相同 org.scalaml.filtering.ExpMovingAverage 签名的类(第 10 行)。从 |> 方法返回的偏函数被实例化为 pfn 值,因此它可以被多次应用(第 11 行)。相同的设计模式也用于离散傅里叶变换滤波器(第 12 行)。

根据输入数据的配置或特性选择过滤算法(ExpMovingAverageDFTFir)。其在 org.scalaml.filtering 包中的实现取决于环境(单个主机、集群、Apache spark 等)。

注意

过滤算法

在 Cake 模式背景下,用于展示模块化概念的过滤算法在 第三章 数据预处理 中详细描述。

配置文件数据

预处理、聚类或分类算法的选择高度依赖于输入数据(观察值和预期值,如果可用)的质量和特征。在第一章中,“让我们试试轮胎”下的“第 3 步 - 预处理数据”部分介绍了MinMax类,用于使用最小值和最大值归一化数据集。

不可变统计

均值和标准差是最常用的统计量。

注意

均值和方差

算术平均数定义为:

不可变统计

方差定义为:

不可变统计

考虑到抽样偏差的方差调整定义为:

不可变统计

让我们使用Stats扩展MinMax类以获得一些基本的统计功能:

class StatsT < : AnyVal(implicit f ; T => Double)
  extends MinMaxT {

  val zero = (0.0\. 0.0)
  val sums = values./:(zero)((s,x) =>(s._1 +x, s._2 + x*x)) //1

  lazy val mean = sums._1/values.size  //2
  lazy val variance = 
         (sums._2 - mean*mean*values.size)/(values.size-1)
  lazy val stdDev = Math.sqrt(variance)
  …
}

Stats类实现了不可变统计。其构造函数计算values的总和以及平方值的总和sums(行1)。如meanvariance之类的统计量在需要时通过将这些值声明为懒加载(行2)来一次性计算。Stats类继承了MinMax的归一化函数。

Z 分数与高斯

输入数据的高斯分布是通过Stats类的gauss方法实现的。

注意

高斯分布

M1:对于均值μ和标准差σ的变换,高斯定义为:

Z 分数与高斯

代码如下:

def gauss(mu: Double, sigma: Double, x: Double): Double = {
   val y = (x - mu)/sigma
   INV_SQRT_2PI*Math.exp(-0.5*y*y)/sigma
}
val normal = gauss(1.0, 0.0, _: Double)

正态分布的计算作为一个部分应用函数。Z 分数是通过考虑标准差对原始数据进行归一化计算得出的。

注意

Z 分数归一化

M2:对于均值μ和标准差σ的 Z 分数定义为:

Z 分数与高斯

Z 分数的计算是通过Stats类的zScore方法实现的:

def zScore: DblVector = values.map(x => (x - mean)/stdDev )

以下图表说明了zScore归一化和正态变换的相对行为:

Z 分数与高斯

线性、高斯和 Z 分数归一化的比较分析

评估模型

评估模型是工作流程的一个基本部分。如果你没有评估其质量的工具,创建最复杂的模型也没有意义。验证过程包括定义一些定量可靠性标准,设置策略,如K 折交叉验证方案,并选择适当的有标签数据。

验证

本节的目的在于创建一个可重用的 Scala 类来验证模型。首先,验证过程依赖于一组指标来量化通过训练生成的模型的适用性。

关键质量指标

让我们考虑一个简单的具有两个类别的分类模型,正类(相对于负类)用黑色(相对于白色)在以下图中表示。数据科学家使用以下术语:

  • 真阳性TP):这些是被正确标记为属于正类的观察值(在深色背景上的白色点)

  • 真阴性TN):这些是被正确标记为属于负类的观察值(浅色背景上的黑色点)

  • 假阳性FP):这些是不正确标记为属于正类的观察值(深色背景上的白色点)

  • 假阴性FN):这些是不正确标记为属于负类的观察值(浅色背景上的深色点)关键质量指标

    验证结果分类

这种简单的表示可以扩展到涉及两个以上类别的分类问题。例如,假阳性被定义为不正确标记为属于任何其他类别的观察值。这四个因素用于评估准确度、精确度、召回率和 F 及 G 度量,如下所示:

  • 准确度:这是正确分类的观察值的百分比,用 ac 表示。

  • 精确度:这是在分类器声明为正的组中正确分类为正的观察值的百分比。用 p 表示。

  • 回忆率:这是被标记为正的观察值中正确分类的比例,用 r 表示。

  • F[1]度量或 F[1]分数:这个度量在精确度和召回率之间取得平衡。它是精确度和召回率的调和平均数,分数范围在 0(最差分数)到 1(最佳分数)之间。用 F[1] 表示。

  • F[n]分数:这是具有任意 n 次方的通用 F 分数方法。用 F[n] 表示。

  • G 度量:这与 F 度量类似,但它是精确度 p 和召回率 r 的几何平均数。用 G 表示。

注意

验证指标

M3:准确度 ac、精确度 p、召回率 rF[1]F[n]G 分数定义如下:

关键质量指标

精确度、召回率和 F[1] 分数的计算取决于分类器中使用的类别数量。我们将考虑以下实现:

  • 二项式(两个类别)分类的 F 分数验证(即正负结果)

  • 多项式(多于两个类别)分类的 F 分数验证

二项式分类的 F 分数

二项式 F 验证计算正类的精确度、召回率和 F 分数。

让我们实现 F 分数或 F 度量作为以下特殊验证:

trait Validation { def score: Double }

BinFValidation类封装了F[n]分数以及精确度和召回率的计算,通过计数TPTNFPFN值来实现。它实现了M3公式。在 Scala 编程的传统中,该类是不可变的;它在类实例化时计算TPTNFPFN的计数器。该类接受以下三个参数:

  • 对于负结果为0值和正结果为1值的expected

  • 观测值集xt用于验证模型

  • 预测的predict函数对观测值进行分类(行1

代码如下:

class BinFValidationT <: AnyVal
     (predict: Array[T] => Int)(implicit f: T => Double) 
  extends Validation { //1

  val counters = {
    val predicted = xt.map( predict(_))
    expected.zip(predicted)
      .aggregate(new Counter[Label])((cnt, ap) => 
         cnt + classify(ap._1, ap._2), _ ++ _) //2
  }

  override def score: Double = f1   //3
  lazy val f1 = 2.0*precision*recall/(precision + recall)
  lazy val precision = compute(FP)  //4
  lazy val recall = compute(FN) 

  def compute(n: Label): Double = {
    val denom = counters(TP) + counters(n)
    counters(TP).toDouble/denom
  }
  def classify(predicted: Int, expected: Int): Label = //5
    if(expected == predicted) if(expected == POSITIVE) TP else TN
    else if(expected == POSITIVE) FN else FP 
}

构造函数计算每个四个结果(TPTNFPFN)的出现次数(行2)。precisionrecallf1值被定义为懒值,因此它们仅在直接访问或调用score方法时计算(行4)。F[1]度量是验证分类器最常用的评分值,因此它是默认评分(行3)。classify私有方法从预期值和预测值中提取限定符(行5)。

BinFValidation类与分类器类型、训练、标记过程和观测类型无关。

与 Java 不同,Java 将枚举定义为类型类,Scala 要求枚举必须是单例。枚举扩展了scala.Enumeration抽象类:

object Label extends Enumeration {
  type Label = Value
  val TP, TN, FP, FN = Value
}

具有更高基数(F[n])的 F 分数公式(n > 1)优先考虑精确度而不是召回率,这在以下图表中显示:

二项式分类的 F 分数

对给定召回率下精确度对 F1、F2 和 F3 分数影响的比较分析

注意

多类评分

我们实现的二项式验证计算仅对正类进行精确度、召回率和 F[1]分数。下一节中介绍的通用多项式验证类计算正类和负类的这些质量指标。

多项式分类的 F 分数

验证指标由M3公式定义。其思想非常简单:计算所有类别的精确度和召回率值,然后取平均值以产生整个模型的单个精确度和召回率值。整个模型的精确度和召回率利用了上一节中引入的TPFPFNTN的计数。

有两组常用的公式用于计算模型的精确度和召回率:

  • :此方法计算每个类的精确度和召回率,然后求和并取平均值。

  • :此方法在计算精确度和召回率之前,对所有类别的精确度和召回率的分子和分母进行求和。

从现在开始,我们将使用宏公式。

注意

多项式精确度和召回率的宏公式

M4:对于c个类别的模型,精确率p和召回率r的宏版本计算如下:

多项式分类的 F 分数

对于具有两个以上类别的分类器,计算精确率和召回率因子需要提取和操作混淆矩阵。我们使用以下约定:预期值定义为列,预测值定义为行

多项式分类的 F 分数

六类分类的混淆矩阵

MultiFValidation多项式验证类接受以下四个参数:

  • expected类别索引,对于负结果值为0,对于正结果值为1

  • 使用观测集xt来验证模型

  • 模型中的classes数量

  • predict预测函数对观测值进行分类(第 7 行)

代码如下:

class MultiFValidationT <: AnyVal
    (predict: Array[T] => Int)(implicit f : T => Double)
  extends Validation { //7

  val confusionMatrix: Matrix[Int] = //8
  labeled./:(MatrixInt){case (m, (x,n)) => 
    m + (n, predict(x), 1)}  //9

 val macroStats: DblPair = { //10
   val pr= Range(0, classes)./:(0.0,0.0)((s, n) => {
     val tp = confusionMatrix(n, n)   //11
     val fn = confusionMatrix.col(n).sum – tp  //12
     val fp = confusionMatrix.row(n).sum – tp  //13
     (s._1 + tp.toDouble/(tp + fp), s._2 +tp.toDouble/(tp + fn))
   })
   (pr._1/classes, pr._2/classes)
 }
 lazy val precision: Double = macroStats._1
 lazy val recall: Double = macroStats._1
 def score: Double = 2.0*precision*recall/(precision+recall)
 }

多类验证的核心元素是混淆矩阵,confusionMatrix(第 8 行)。其元素在索引(i, j) = (观测值的预期类别索引,相同观测值的预测类别索引)处计算,使用每个类别的预期和预测结果(第 9 行)。

如本节引言所述,我们使用宏定义的精确率和召回率(第 10 行)。每个类别的真阳性tp计数对应于混淆矩阵的对角元素(第 11 行)。一个类别的假阴性fn计数是所有预测类别计数之和(列值),给定一个预期类别(除了真阳性类别)(第 12 行)。一个类别的假阳性fp计数是所有预期类别计数之和(行值),给定一个预测类别(除了真阳性类别)(第 13 行)。

计算 F[1]分数的公式与二项式验证中使用的公式相同。

交叉验证

对于科学家可用的标记数据集(观测值加上预期结果),通常不是很大。解决方案是将原始标记数据集分成K组数据。

单折交叉验证

单折交叉验证是从标记数据集中提取训练集和验证集所使用的最简单方案,如下图中所述:

单折交叉验证

单折验证集生成的示意图

单折交叉验证方法包括以下三个步骤:

  1. 选择训练集大小与验证集大小的比例。

  2. 随机选择用于验证阶段的标记观测值。

  3. 将剩余的标记观测值创建为训练集。

单折交叉验证是通过OneFoldXValidation类实现的。它接受以下三个参数:观测值的xt向量,预期类别的expected向量,以及训练集大小与验证集大小的ratio(第14行):

type ValidationType[T] = Vector[(Array[T], Int)]
class OneFoldXValidationT <: AnyVal(implicit f : T => Double) {  //14
  val datasSet: (ValidationType[T], ValidationType[T]) //15
  def trainingSet: ValidationType[T] = datasSet._1
  def validationSet: ValidationType[T] = datasSet._1
}

OneFoldXValidation类的构造函数从观测值和预期类别的集合中生成分离的训练集和验证集(第15行):

val datasSet: (Vector[LabeledData[T]],Vector[LabeledData[T]]) = { 
  val labeledData = xt.drop(1).zip(expected)  //16
  val trainingSize = (ratio*expected.size).floor.toInt //17

  val valSz = labeledData.size - trainingSize
  val adjSz = if(valSz < 2) 1 
          else if(valSz >= labeledData.size)  labeledData.size -1 
          else valSz  //18
  val iter = labeledData.grouped(adjSz )  //18
  val ordLabeledData = labeledData
      .map( (_, Random.nextDouble) )  //19
      .sortWith( _._2 < _._2).unzip._1 //20

  (ordlabeledData.takeRight(adjValSz),   
   ordlabeledData.dropRight(adjValSz))  //21
}

OneFoldXValidation类的初始化通过将观测值和预期结果进行压缩来创建标记观测值的labeledData向量(第16行)。训练ratio值用于计算训练集(第17行)和验证集(第18行)的相应大小。

为了随机创建训练集和验证集,我们将标记的数据集与随机生成器进行压缩(第19行),然后通过排序随机值对标记的数据集进行重新排序(第20行)。最后,该方法返回训练集和验证集的成对(第21行)。

K 折交叉验证

数据科学家通过选择一个组作为验证集,然后将所有剩余的组合并成一个训练集来创建K个训练-验证数据集,如图所示。这个过程被称为K 折交叉验证 [2:7]。

K 折交叉验证

以下图示了 K 折交叉验证集的生成

第三部分用作验证数据,除了 S3 之外的所有数据集部分合并成一个单独的训练集。这个过程应用于原始标记数据集的每个部分。

偏差-方差分解

挑战在于创建一个模型,该模型既能拟合训练集,也能在验证阶段正确分类后续的观测值。

如果模型紧密地拟合了用于训练的观测值,那么新观测值可能无法被正确分类。这通常发生在模型复杂的情况下。这种模型的特点是具有低偏差和高方差。这种场景可以归因于科学家过度自信地认为她/他选择的用于训练的观测值代表了现实世界。

当选定的模型对训练集的拟合较为宽松时,新观测值被分类为正类别的概率会增加。在这种情况下,该模型的特点是具有高偏差和低方差。

分布的偏差方差均方误差(MSE)的数学定义由以下公式给出:

注意

M5: 对于真实模型θ,其方差和偏差定义为:

偏差-方差分解

M6: 均方误差定义为:

偏差-方差分解

让我们通过一个例子来说明偏差、方差和均方误差的概念。在这个阶段,你还没有接触到大多数机器学习技术。因此,我们创建了一个模拟器来展示分类器的偏差和方差之间的关系。模拟的组成部分如下:

  • 训练集,training

  • 从训练集中提取的target: Double => Double类型的模拟target模型

  • 一组可能的models用于评估

完全匹配训练数据的模型会过拟合目标模型。近似目标模型的模型很可能会欠拟合。本例中的模型由单变量函数定义。

这些模型与验证数据集进行比较。BiasVariance类接受目标模型target和验证测试的nValues大小作为参数(行22)。它仅实现了计算每个模型的偏差和方差的公式:

type Dbl_F = Double => Double 
class BiasVarianceT
     (implicit f: T => Double) {//22
  def fit(models: List[Dbl_F]): List[DblPair] = { //23
    models.map(accumulate(_, models.size)) //24
  }
}

fit方法计算每个models模型相对于target模型的方差和偏差(行23)。它在accumulate方法中计算均值、方差和偏差(行24):

def accumulate(f: Dbl_F, y:Double, numModels: Int): DblPair = 
  Range(0, nValues)./:(0.0, 0.0){ case((s,t) x) => { 
    val diff = (f(x) - y)/numModels
    (s + diff*diff, t + Math.abs(f(x)-target(x))) //25
  }}

训练数据由具有噪声成分r[1]r[2]的单变量函数生成:

偏差-方差分解

accumulate方法为每个模型返回一个元组(方差,偏差)f(行25)。模型候选者由以下单变量函数族定义,对于值n = 1, 2, 和 4

偏差-方差分解

target模型(行26)和models(行27)属于同一类单变量函数:

val template = (x: Double, n : Int) => 
                        0.2*x*(1.0 + Math.sin(x*0.1)/n) 
val training = (x: Double) => {
  val r1 = 0.45*(Random.nextDouble-0.5)
  val r2 = 38.0*(Random.nextDouble - 0.5) + Math.sin(x*0.3)
  0.2*x*(1.0 + Math.sin(x*0.1 + r1)) + r2
}
Val target = (x: Double) => template(x, 1) //26
val models = List[(Dbl_F, String)] (  //27
  ((x: Double) => template(x, 4), "Underfit1"),  
  ((x: Double) => template(x, 2), "Underfit2"),
  ((x : Double) => training(x), "Overfit")
  (target, "target"),
)
val evaluator = new BiasVarianceDouble
evaluator.fit(models.map( _._1)) match { /* … */ }

使用JFreeChart库显示训练数据集和模型:

偏差-方差分解

将模型拟合到数据集

复制训练数据的模型会过拟合。对template函数的正弦分量使用较低振幅进行平滑的模型会欠拟合。不同模型和训练数据的偏差-方差权衡在以下散点图中展示:

偏差-方差分解

四个模型的偏差-方差权衡的散点图,其中一个复制了训练集

每个平滑或近似模型的方差都低于训练集的方差。正如预期的那样,目标模型0.2.x.(1+sin(x/10))没有偏差和方差。训练集具有非常高的方差,因为它对任何目标模型都过拟合。最后一张图表比较了每个模型、训练集和目标模型之间的均方误差:

偏差-方差分解

四个模型的比较均方误差

注意

评估偏差和方差

该节使用一个虚构的目标模型和训练集来说明模型的偏差和方差的概念。机器学习模型的偏差和方差实际上是通过验证数据来估计的。

过度拟合

你可以将示例中提出的方法应用于任何分类和回归模型。具有低方差模型的列表包括常数函数和与训练集无关的模型。高次多项式、复杂函数和深度神经网络具有高方差。应用于线性数据的线性回归具有低偏差,而应用于非线性数据的线性回归具有更高的偏差 [2:8]。

过度拟合对建模过程的各个方面都有负面影响,例如:

  • 它使得调试变得困难

  • 它使模型过于依赖微小的波动(长尾)和噪声数据

  • 它可能会发现观测特征和潜在特征之间不相关的关系

  • 它导致预测性能不佳

然而,有一些经过充分验证的解决方案可以减少过度拟合 [2:9]:

  • 在可能的情况下增加训练集的大小

  • 使用平滑和过滤技术减少标记观察中的噪声

  • 使用主成分分析等技术减少特征数量,如第四章主成分分析中所述,无监督学习

  • 使用卡尔曼或自回归模型对可观测和潜在噪声数据进行建模,如第三章数据预处理中所述,数据预处理

  • 通过应用交叉验证减少训练集中的归纳偏差

  • 使用正则化技术惩罚模型的一些特征中的极端值,如第六章正则化中所述,回归和正则化

摘要

在本章中,我们为本书中将要介绍的不同数据处理单元建立了框架。在本书中早期探讨模型验证和过度拟合的话题有很好的理由。如果我们没有一种方法来评估它们的相对优点,那么构建模型和选择算法就没有意义。

在本章中,你被介绍了以下内容:

  • 隐式和显式模型的单子变换概念

  • Scala 中 Cake 模式及其 mixins 组合的灵活性和简洁性,作为数据处理的有效脚手架工具

  • 一种稳健的方法来验证机器学习模型

  • 将模型拟合到训练数据和现实世界数据中的挑战

下一章将通过识别异常值和减少数据中的噪声来解决这个问题,即过度拟合。

第三章:数据预处理

实际数据通常是有噪声的,并且与缺失观察不一致。没有任何分类、回归或聚类模型可以从原始数据中提取相关信息。

数据预处理包括使用统计方法对原始观察数据进行清理、过滤、转换和归一化,以便关联特征或特征组,识别趋势和模型,并过滤掉噪声。清理原始数据的目的如下:

  • 从原始数据集中提取一些基本知识

  • 评估数据质量并生成用于无监督或监督学习的干净数据集

你不应低估传统统计分析方法从文本或非结构化数据中推断和分类信息的能力。

在本章中,你将学习如何:

  • 将常用的移动平均技术应用于时间序列以检测长期趋势

  • 使用离散傅里叶级数识别市场和行业周期

  • 利用离散卡尔曼滤波器从不完全和有噪声的观察中提取线性动态系统的状态

Scala 中的时间序列

本书用于说明不同机器算法的大多数示例都涉及时间序列或按时间顺序排列的观察数据集。

类型和方法

在第一章中“源代码”部分的“原始类型”部分介绍了单个XSeries[T]变量和多个XVSeries[T]变量时间序列的类型。

观察的时间序列是以下类型的观察元素向量(一个Vector类型):

  • 对于单个变量/特征观察,使用T类型。

  • 对于具有多个变量/特征的观察,使用Array[T]类型

标签或预期值的时间序列是一个单一变量向量,其元素可以是用于分类的原始Int类型和用于回归的Double类型。

标记观察的时间序列是一对观察向量与标签向量:

类型和操作

单个特征和多特征观察的可视化

从现在起,将使用两个通用的XSeriesXVSeries类型作为输入数据的主要类,用于时间序列。

注意

标记观察的结构

在整本书中,标记观察被定义为观察向量和标签/预期值向量的对,或者是一对{观察,标签/预期值}的向量。

在第二章中“数据概要”部分介绍的Stats类,实现了对单个变量观察的一些基本统计和归一化。让我们创建一个XTSeries单例来计算统计信息和归一化多维观察:

object XTSeries { 
  def zipWithShiftT: Vector[(T, T)] = 
     xv.drop(n).zip(xv.view.dropRight(n))  //1

  def zipWithShift1T: Vector[(T, T)] = 
     xv.zip(xv.view.drop(n))

  def statisticsT <: AnyVal
       (implicit f: T =>: Double): Vector[Stats[T]] = 
    xt.transpose.map( StatsT)  //2

  def normalizeT <: AnyVal 
      (implicit ordering: Ordering[T], 
          f: T => Double): Try[DblVector] = 
    Try (StatsT.normalize(low, high) )
   ...
}

XTSeries单例的第一个方法通过将时间序列的最后size – n个元素与它的第一个size – n个元素进行连接来生成一对元素的向量(行1)。statistics(行2)和normalize(行3)方法作用于单变量和多变量观测。这三个方法是XTSeries中实现的功能的子集。

通过将两个xy向量进行连接并转换成数组来创建一个XVSeries[T]类型的时间序列:

def zipToSeriesT: ClassTag: XVSeries[T]

将单维或多维时间序列xv在索引n处分割成两个时间序列:

def splitAtT: (XSeries[T], XSeries[T])

对单维时间序列应用zScore转换:

def zScoreT <: AnyVal
    (implicit f: T => Double): Try[DblVector]

对多维时间序列应用zScore转换:

def zScoresT <: AnyVal
    (implicit f: T => Double): Try[XVSeries[Double]] 

将单维时间序列x转换成一个新的时间序列,其元素为x(n) – x(n-1)

def delta(x: DblVector): DblVector

将单维时间序列x转换成一个新的时间序列,其元素为(x(n) – x(n-1) > 0.0) 1 else 0

def binaryDelta(x: DblVector): Vector[Int]

计算两个xz数组之间平方误差的和:

def sseT <: AnyVal
   (implicit f: T => Double): Double

计算两个xz数组之间的均方误差:

def mseT <: AnyVal
    (implicit f: T => Double): Double

计算两个xz向量之间的均方误差:

def mse(x: DblVector, z: DblVector): Double

计算多维时间序列每个特征的统计数据:

def statisticsT <: AnyVal
    (implicit f: T => Double): Vector[Stats[T]]

XVSeries类型的两个多维向量对应用f函数:

def zipToVectorT
  (f: (Array[T], Array[T]) =>Double): XSeries[Double] = 
  x.zip(y.view).map{ case (x, y) => f(x,y)}

磁场模式

实现为XTSeries方法的某些时间序列操作可能有多种输入和输出类型。Scala 和 Java 支持方法重载,但有以下限制:

  • 它不能防止由 JVM 中的擦除类型引起的类型冲突

  • 它不允许提升到单个通用函数

  • 它不能完全减少代码冗余

转置算子

让我们考虑任何类型的多维时间序列的转置算子。转置算子可以对象化为Transpose特质:

sealed trait Transpose {
  type Result   //4
  def apply(): Result  //5
}

特质有一个抽象的Result类型(行4)和一个抽象的apply()构造函数(行5),这允许我们创建一个具有任何输入和输出类型组合的通用transpose方法。transpose方法的输入和输出类型的转换类型定义为implicit

implicit def xvSeries2MatrixT: ClassTag = 
  new Transpose { type Result = Array[Array[T]]  //6
    def apply(): Result =  from.toArray.transpose
}
implicit def list2MatrixT: ClassTag = 
  new Transpose { type Result = Array[Array[T]]  //7
   def apply(): Result =  from.toArray.transpose
}
…

第一个xvSeries2Matrix隐式地将XVSeries[T]类型的时间序列转换为元素类型为T的矩阵(行6)。list2Matrix隐式地将List[Array[T]]类型的时间序列转换为元素类型为T的矩阵(行7)。

通用transpose方法编写如下:

def transpose(tpose: Transpose): tpose.Result = tpose()

微分算子

磁场模式的第二个候选是计算时间序列的微分。目的是从一个时间序列{x[t]}生成时间序列{x[t+1] – x[t]}

sealed trait Difference[T] {
  type Result
  def apply(f: (Double, Double) => T): Result
}

Difference特性允许我们计算任意元素类型的时间序列的差分。例如,Double类型一维向量的差分由以下隐式转换定义:

implicit def vector2DoubleT = new Difference[T] {
  type Result = Vector[T]
  def apply(f: (Double, Double) => T): Result =  //8
    zipWithShift(x, 1).collect{case(next,prev) =>f(prev,next)}
}

apply()构造函数接受一个参数:用户定义的f函数,该函数计算时间序列中连续两个元素的差(第8行)。通用的差分方法如下:

def differenceT => T): diff.Result = diff(f)

这里是一些时间序列预定义的差分算子,其输出类型为Double(第9行),Int(第10行)和Boolean(第11行):

val diffDouble = (x: Double,y: Double) => y –x //9
val diffInt = (x: Double,y: Double) => if(y > x) 1 else 0 //10
val diffBoolean = (x: Double,y: Double) => (y > x) //11

差分算子用于实现labeledData方法,从具有两个特征和目标(标签)数据集的观测值生成标签数据:

def differentialDataT =>T): Try[(XVSeries[Double],Vector[T])] = 
  Try((zipToSeries(x,y), difference(target, f)))

标签数据的结构是观测值和目标值差分的配对。

懒惰视图

Scala 中的视图是一个代理集合,它代表一个集合,但以懒加载的方式实现数据转换或高阶方法。视图的元素被定义为懒值,它们在需要时才被实例化。

视图相较于一个严格(或完全分配)的集合的一个重要优点是减少了内存消耗。

让我们看看在第二章中“工作流计算模型”下的“实例化工作流”部分引入的aggregator数据转换,Hello World!。不需要分配整个x.size的元素集合:高阶find方法可能只读取了几个元素后就会退出(第12行):

val aggregator = new ETransformInt { 
  override def |> : PartialFunction[U, Try[V]] = { 
    case x: U if(!x.isEmpty) => 
      Try( Range(0, x.size).view.find(x(_) == 1.0).get) //12
   }
}

注意

视图、迭代器和流

视图、迭代器和流具有相同的构建按需元素的目标。然而,它们之间有一些主要区别:

  • 迭代器不会持久化集合的元素(只读一次)

  • 流允许对具有未定义大小的集合执行操作

移动平均

移动平均为数据分析师和科学家提供了一个基本的预测模型。尽管其简单,移动平均法在各个领域都得到了广泛应用,如市场调查、消费者行为或体育统计。交易者使用移动平均法来识别给定证券价格的不同支撑和阻力水平。

注意

平均减少函数

让我们考虑时间序列x[t] = x(t)和函数f(x[t-p-1],…, x[t]),它将最后p个观测值减少到一个值或平均值。在t处的观测值估计由以下公式定义:

移动平均

这里,f是来自之前p个数据点的平均减少函数。

简单移动平均

简单移动平均是移动平均算法的最简单形式 [3:1]。周期为 p 的简单移动平均通过以下公式估计时间 t 的值:

注意

简单移动平均

M1:时间序列 {x[t]} 的周期为 p 的简单移动平均是通过计算最后 p 个观测值的平均值来得到的:

简单移动平均

M2:计算通过以下公式迭代实现:

简单移动平均

在这里,简单移动平均 是时间 t 的估计值或简单移动平均值。

让我们构建一个移动平均算法的类层次结构,其中参数化的 MovingAverage 特性作为其根:

trait MovingAverage[T]

我们使用通用的 XSeries[T] 类型以及 ETransform 显式配置的数据转换,该配置在 第二章 下 Monadic 数据转换显式模型 部分中介绍,用于实现简单移动平均,SimpleMovingAverage

class SimpleMovingAverageT <: AnyVal
     (implicit num: Numeric[T], f: T => Double) //1
  extends EtransformInt with MovingAverage[T] {

  type U = XSeries[T]  //2
  type V = DblVector   //3

  val zeros = Vector.fill(0.0)(period-1) 
  override def |> : PartialFunction[U, Try[V]] = {
    case xt: U if( xt.size >= period ) => {
      val splits = xt.splitAt(period)
      val slider = xt.take(xt.size - period).zip(splits._2)  //4

      val zero = splits._1.sum/period //5
      Try( zeros ++ slider.scanLeft(zero) {
         case (s, (x,y)) => s + (x - y)/period }) //7
  }
}

该类针对输入时间序列的 T 类型元素进行参数化;我们无法对输入数据的类型做出任何假设。输出时间序列的元素类型为 Doublesum/ 算术运算符需要 Numeric[T] 类的隐式实例化(行 1)。简单移动平均通过定义输入的抽象 U 类型(行 2)和输出为 V 的类型(行 3)作为时间序列 DblVector 来实现 ETransform

实现有几个有趣的元素。首先,观测值集被复制,并在与原始数据合并到包含一对 slider 值的数组之前,将索引在结果克隆实例中移动 p 个观测值(行 4):

简单移动平均

计算移动平均的滑动算法

平均值初始化为第一个 period 数据点的平均值(行 5)。趋势的第一个 period 值初始化为零(行 6)。该方法通过连接初始空值和计算出的平均值来实现 M2 公式(行 7)。

加权移动平均

加权移动平均方法是通过计算最后 p 个观测值的加权平均值来扩展简单移动平均的 [3:2]。权重 α[j] 被分配给最后 p 个数据点 x[j],并通过权重的总和进行归一化。

注意

加权移动平均

M3:具有周期 p 和归一化权重分布 {α[j]} 的序列 {x[t]} 的加权移动平均由以下公式给出:

加权移动平均

在这里,x[t] 是时间 t 的估计值或简单移动平均值。

WeightedMovingAverage 类的实现需要计算最后 p (weights.size) 个数据点。没有简单的迭代公式可以用来在时间 t + 1 时使用时间 t 时刻的移动平均来计算加权移动平均:

class WeightedMovingAverage@specialized(Double) T <: AnyVal
    (implicit num: Numeric[T], f: T => Double) 
  extends SimpleMovingAverageT {  //8

  override def |> : PartialFunction[U, Try[V]] = {
    case xt: U if(xt.size >= weights.length ) => {
      val smoothed =  (config to xt.size).map( i => 
       xt.slice(i- config, i).zip(weights) //9
             .map { case(x, w) => x*w).sum  //10
      )
      Try(zeros ++ smoothed) //11
    }
  }
}

加权移动平均的计算比简单移动平均复杂一些。因此,我们使用专门的注解指定为 Double 类型生成专用字节码。加权移动平均继承自 SimpleMovingAverage 类,因此,实现了配置权重时的 ETransform 显式转换,输入观测类型为 XSeries[T],输出观测类型为 DblVectorM3 公式的实现通过切片(第 9 行)输入时间序列并计算权重与时间序列切片的内积(第 10 行)来生成平滑的时间序列。

与简单移动平均一样,输出是初始 weights.size 个空值、zerossmoothed 数据(第 11 行)的连接。

指数移动平均

指数移动平均在金融分析和市场调查中被广泛使用,因为它倾向于最新的值。值越老,对时间 t 时刻的移动平均值的冲击就越小 [3:3]。

注意

指数移动平均

M4:计算序列 {x[t]} 的指数移动平均和平滑因子 α 的迭代公式如下:

指数移动平均

这里,指数移动平均 是在 t 时刻的指数平均值。

ExpMovingAverage 类的实现相当简单。构造函数有一个单一的 α 参数(衰减率)(第 12 行):

class ExpMovingAverage@specialized(Double) T <: AnyVal    //12
    (implicit f: T => Double) 
  extends ETransformDouble with MovingAverage[T]{ //13

  type U = XSeries[T]    //14
  type V = DblVector    //15

  override def |> : PartialFunction[U, Try[V]] = {
    case xt: U if( xt.size > 0) => {
      val alpha_1 = 1-alpha
      var y: Double = data(0)
      Try( xt.view.map(x => {
        val z = x*alpha + y*alpha_1; y = z; z})) //16
    }
}
}

指数移动平均通过定义输入(第 14 行)的抽象 U 类型为时间序列 XSeries[T] 和输出(第 15 行)的 V 类型为时间序列 DblVector 来实现 ETransform(第 13 行)。|> 方法在 map(第 16 行)中对时间序列的所有观测值应用 M4 公式。

使用 p 期来计算 alpha = 1/(p+1) 的构造函数版本是通过 Scala 的 apply 方法实现的:

def applyT <: AnyVal
     (implicit f: T => Double): ExpMovingAverage[T] = 
  new ExpMovingAverageT)

让我们比较从这三种移动平均方法生成的结果与原始价格。我们使用一个数据源,DataSource,从美国银行(BAC)的历史每日收盘价中加载和提取值,这些数据可在雅虎财经页面找到。DataSink 类负责格式化和将结果保存到 CSV 文件中,以供进一步分析。DataSourceDataSink 类在 附录 A 的 数据提取 部分中详细描述,基本概念

import YahooFinancials._
val hp = p >>1
val w = Array.tabulate(p)(n => 
       if(n == hp) 1.0 else 1.0/(Math.abs(n - hp)+1)) //17
val sum = w.sum
val weights = w.map { _ / sum }                          //18

val dataSrc = DataSource(s"$RESOURCE_PATH$symbol.csv", false)//19
val pfnSMvAve = SimpleMovingAverageDouble |>         //20   
val pfnWMvAve = WeightedMovingAverageDouble |>  
val pfnEMvAve = ExpMovingAverageDouble |>

for {
   price <- dataSrc.get(adjClose)   //21
   if(pfnSMvSve.isDefinedAt(price) )
   sMvOut <- pfnSMvAve(price)         //22
   if(pfnWMvSve.isDefinedAt(price)
   eMvOut <- pfnWMvAve(price)
  if(pfnEMvSve.isDefinedAt(price)
   wMvOut <- pfnEMvAve(price)
} yield {
  val dataSink = DataSinkDouble
  val results = ListDblSeries
  dataSink |> results  //23
}

注意

isDefinedAt

每个部分函数通过调用isDefinedAt进行验证。从现在起,为了清晰起见,本书中将省略部分函数的验证。

加权移动平均的系数是在(行17)生成的,并在(行18)进行归一化。关于股票代码 BAC 的交易数据是从 Yahoo Finances CSV 文件(行19),YahooFinancials,使用adjClose提取器(行20)提取的。下一步是初始化与每个移动平均相关的pfnSMvAvepfnWMvAvepfnEMvAve部分函数(行21)。以price作为参数调用部分函数生成三个平滑的时间序列(行22)。

最后,DataSink实例格式化并将结果输出到文件(行23)。

注意

隐式后缀操作

实例化filter |>部分函数需要通过导入scala.language.postfixOps使后缀操作postfixOps可见。

加权移动平均方法依赖于通过将函数作为参数传递给通用tabulate方法计算出的归一化权重的对称分布。请注意,如果无法计算特定的移动平均之一,则显示原始的价格时间序列。以下图形是加权移动平均的对称滤波器示例:

指数移动平均

加权移动平均的对称滤波器示例

三种移动平均技术应用于美国银行(BAC)股票在 200 个交易日内的价格。简单移动平均和加权移动平均都使用 11 个交易日的周期。指数移动平均方法使用缩放因子2/(11+1) = 0.1667

指数移动平均

美国银行历史股票价格 11 天的移动平均

这三种技术从原始的历史价格时间序列中滤除了噪声。尽管平滑因子较低,指数移动平均对突然的价格波动反应敏感。如果您将周期增加到 51 个交易日(相当于两个日历月),简单移动平均和加权移动平均产生的时间序列比具有alpha = 2/(p+1) = 0.038的指数移动平均更平滑:

指数移动平均

美国银行历史股票价格 51 天的移动平均

欢迎您进一步实验不同的平滑因子和权重分布。您将能够确认以下基本规则:随着移动平均周期的增加,频率逐渐降低的噪声被消除。换句话说,允许的频率窗口正在缩小。移动平均充当一个低通滤波器,仅保留低频。

微调平滑因子的周期是耗时的。频谱分析,或更具体地说,傅里叶分析将时间序列转换为一个频率序列,为统计学家提供了一个更强大的频谱分析工具。

注意

多维时间序列上的移动平均

为了简化起见,移动平均技术被用于单个特征或变量时间序列。多维时间序列上的移动平均是通过使用XTSeriestransform方法对每个特征执行单个变量的移动平均来计算的,这在第一节中已介绍。例如,应用于多维时间序列xt的简单移动平均,其平滑值计算如下:

   val pfnMv = SimpleMovingAverageDouble |>
   val smoothed = transform(xt, pfnMv)

傅里叶分析

谱密度估计的目的是根据其频率测量信号或时间序列的幅度 [3:4]。目标是通过对数据集中的周期性进行检测来估计谱密度。通过分析其谐波,科学家可以更好地理解信号或时间序列。

注意

谱理论

时间序列的频谱分析不应与谱理论混淆,谱理论是线性代数的一个子集,研究希尔伯特和巴拿赫空间上的特征函数。事实上,调和分析和傅里叶分析被视为谱理论的一部分。

让我们探讨离散傅里叶级数的概念以及其在金融市场中的应用优势。傅里叶分析将任何通用函数近似为三角函数(正弦和余弦函数)的和。

注意

复数傅里叶变换

本节重点介绍实值离散傅里叶级数。通用的傅里叶变换适用于复数值 [3:5]。

在基本三角函数分解过程中,分解称为 傅里叶变换 [3:6]。

离散傅里叶变换

时间序列 {x[k]} 可以表示为一个离散的实时域函数 f, x = f(t)。在 18 世纪,让-巴蒂斯特-约瑟夫·傅里叶证明了任何连续周期函数 f 都可以表示为正弦和余弦函数的线性组合。离散傅里叶变换DFT)是一种线性变换,它将时间序列转换为一个有限组合的复数或实数三角函数的系数列表,按其频率排序。

每个三角函数的频率 ω 定义了信号的谐波之一。表示信号幅度与频率关系的空间被称为 频域。通用的快速傅里叶变换(DFT)将时间序列转换为一个由复数 a + j.φ (j² = -1) 定义的频率序列,其中 a 是频率的幅度,φ 是相位。

本节专门介绍将时间序列转换为具有实数值的有序频率序列的实数 DFT。

注意

实数离散傅里叶变换

M5:周期函数 f 可以表示为正弦和余弦函数的无限组合:

离散傅里叶变换

M6:函数 f 的傅里叶余弦变换定义为:

离散傅里叶变换

M7:函数 f(-x) = f(x) 的离散实数余弦级数定义为:

离散傅里叶变换

M8:函数的傅里叶正弦变换定义为:

离散傅里叶变换

M9:函数 f(-x) = f(x) 的离散实数正弦级数定义为:

离散傅里叶变换

傅里叶三角级数的计算耗时,其渐近时间复杂度为 O(n²)。科学家和数学家一直在努力使计算尽可能有效。用于计算傅里叶级数最常用的数值算法是由 J.W. 库利和 J. 图基创建的 快速傅里叶变换FFT)[3:7]。

Radix-2 版本的算法递归地将时间序列的傅里叶变换分解为任何组合的 N[1]N[2] 大小的段,例如 N = N[1] N[2]。最终,离散傅里叶变换应用于更深层次的段。

提示

库利-图基算法

我鼓励你使用 Scala 的尾递归实现 Radix-2 库利-图基算法。

Radix-2 实现要求数据点的数量为 N=2^n,对于偶函数(正弦函数)和 N=2^n+1,对于余弦函数。有两种方法来满足这个约束:

  • 将实际点数减少到下一个较低的基数,2^n < N

  • 通过用 0 填充到下一个更高的基数,扩展原始时间序列,N < 2^n**+1

填充时间序列是首选选项,因为它不会影响原始的观测集。

让我们为任何离散傅里叶变换的变体定义一个 DTransform 特性。第一步是将 Apache Commons Math 库中使用的默认配置参数封装到一个 Config 单例中:

trait DTransform {
  object Config {
     final val FORWARD = TransformType.FORWARD
     final val INVERSE = TransformType.INVERSE
     final val SINE = DstNormalization.STANDARD_DST_I
     final val COSINE = DctNormalization.STANDARD_DCT_I
   }
   …
}

DTransform 特性的主要目的是用零值填充 vec 时间序列:

def pad(vec: DblVector, 
    even: Boolean = true)(implicit f: T =>Double): DblArray = {
  val newSize = padSize(vec.size, even)   //1
  val arr: DblVector = vec.map(_.toDouble)
  if( newSize > 0) arr ++ Array.fill(newSize)(0.0) else arr //2
}

def padSize(xtSz: Int, even: Boolean= true): Int = {
  val sz = if( even ) xtSz else xtSz-1  //3
  if( (sz & (sz-1)) == 0) 0
  else {
    var bitPos = 0
    do { bitPos += 1 } while( (sz >> bitPos) > 0) //4
    (if(even) (1<<bitPos) else (1<<bitPos)+1) - xtSz
  }
}

pad 方法通过调用 padSize 方法(行 1)计算频率向量的最佳大小为 2^N。然后,它将填充与原始时间序列或观测向量连接(行 2)。padSize 方法根据时间序列最初是否有偶数个观测值调整数据的大小(行 3)。它依赖于位操作来找到下一个基数,N(行 4)。

注意

while 循环

Scala 开发者更喜欢使用 Scala 高阶方法来对集合进行迭代计算。然而,如果可读性或性能是一个问题,你仍然可以使用传统的 whiledo {…} while 循环。

填充方法 pad 的快速实现包括检测 N 个观测值作为 2 的幂(下一个最高的基数)。该方法在将值 N 中的位数移动后评估 N(N-1) 是否为零。代码展示了在 pad 方法中有效使用隐式转换以使代码可读的示例:

 val arr: DblVector = vec.map(_.toDouble)

下一步是编写 DFT 类,用于实正弦和余弦离散变换,通过继承 DTransform 实现。该类在必要时依赖于 DTransform 中实现的填充机制:

class DFT@specialized(Double) T <: AnyVal(implicit f: T => Double)
    extends ETransformDouble with DTransform { //5
  type U = XSeries[T]   //6
  type V = DblVector

  override def |> : PartialFunction[U, Try[V]] = { //7
    case xv: U if(xv.size >= 2) => fwrd(xv).map(_._2.toVector) 
  }
}

我们将离散傅里叶变换视为使用显式 ETransform 配置对时间序列进行变换(行 5)。输入的 U 数据类型和输出的 V 类型必须定义(行 6)。|> 变换函数将计算委托给 fwrd 方法(行 7):

def fwrd(xv: U): Try[(RealTransformer, DblArray)] = {
  val rdt = if(Math.abs(xv.head) < config)  //8
      new FastSineTransformer(SINE)  //9
  else  new FastCosineTransformer(COSINE)  //10

  val padded = pad(xv.map(_.toDouble), xv.head == 0.0).toArray
  Try( (rdt, rdt.transform(padded, FORWARD)) )
}

如果时间序列的第一个值是 0.0,则 fwrd 方法选择离散傅里叶正弦级数;否则,它选择离散余弦级数。此实现通过评估 xt.head(行 8)来自动选择适当的序列。变换调用 Apache Commons Math 库中的 FastSineTransformer(行 9)和 FastCosineTransformer(行 10)类,这些类在第一章中介绍 [3:8]。

此示例使用由 COSINE 参数定义的余弦和正弦变换的标准公式。通过将频率通过一个因子 1/sqrt(2(N-1)) 进行归一化,其中 N 是时间序列的大小,生成一个更干净的频率谱,但计算成本更高。

注意

@specialized 注解

@specialized(Double) 注解用于指示 Scala 编译器为 Double 类型生成一个专门且更高效的类版本。专化的缺点是字节码的重复,因为专门版本与参数化类共存 [3:9]。

为了说明 DFT 背后的不同概念,让我们考虑由正弦函数的 h 序列生成的时间序列的情况:

val F = ArrayDouble
val A = ArrayDouble

def harmonic(x: Double, n: Int): Double =  
      A(n)*Math.cos(Math.PI*F(n)*x)
val h = (x: Double) => 
    Range(0, A.size).aggregate(0.0)((s, i) => 
          s + harmonic(x, i), _ + _)

由于信号是合成的,我们可以选择时间序列的大小以避免填充。时间序列的第一个值不为空,因此观测值的数量是 2^n+1。由 h 函数生成的数据如下所示:

离散傅里叶变换

正弦时间序列的示例

让我们提取由 h 函数生成的时间序列的频率谱。数据点是通过表格化 h 函数创建的。频率谱是通过简单调用 DFT 类的显式 |> 数据变换来计算的:

val OUTPUT1 = "output/chap3/simulated.csv"
val OUTPUT2 = "output/chap3/smoothed.csv"
val FREQ_SIZE = 1025; val INV_FREQ = 1.0/FREQ_SIZE

val pfnDFT = DFT[Double] |> //11
for {
  values <- Try(Vector.tabulate(FREQ_SIZE)
               (n => h(n*INV_FREQ))) //12
  output1 <- DataSinkDouble.write(values)
  spectrum <- pfnDFT(values)
  output2 <- DataSinkDouble.write(spectrum) //13
} yield {
  val results = format(spectrum.take(DISPLAY_SIZE), 
"x/1025", SHORT)
  show(s"$DISPLAY_SIZE frequencies: ${results}")
}

数据模拟器的执行遵循以下步骤:

  1. 使用 3 阶谐波的 h 函数生成原始数据(行 12)。

  2. 实例化由变换生成的部分函数(行 11)。

  3. 将结果频率存储在数据接收器(文件系统)中(行 13)。

注意

数据接收器和电子表格

在这个特定的情况下,离散傅里叶变换的结果被输出到一个 CSV 文件中,以便它可以被加载到电子表格中。一些电子表格支持一组过滤技术,可以用来验证示例的结果。一个更简单的替代方案是使用 JFreeChart。

时间序列的频率谱,对于前 32 个点进行绘图,清楚地显示了在 k = 2515 处的三个频率。这个结果是预期的,因为原始信号由三个正弦函数组成。这些频率的幅度分别是 1024/1、1024/2 和 1024/6。以下图表表示时间序列的前 32 个谐波:

离散傅里叶变换

三频率正弦波的频谱

下一步是使用频率谱通过 DFT 创建一个低通滤波器。在时域中实现低通或带通滤波器的算法有很多,从自回归模型到巴特沃斯算法。然而,离散傅里叶变换仍然是一种非常流行的技术,用于平滑信号和识别趋势。

注意

大数据

对大型时间序列进行 DFT 可能非常计算密集。一个选项是将时间序列视为连续信号,并使用 奈奎斯特 频率对其进行采样。奈奎斯特频率是连续信号采样率的一半。

基于 DFT 的过滤

本节的目的在于介绍、描述和实现一个利用离散傅里叶变换的噪声过滤机制。这个想法相当简单:正向和逆向傅里叶级数依次使用,将原始数据从时域转换为频域再转换回时域。您唯一需要提供的是函数 g,该函数修改频率序列。这种操作称为滤波器 g 与频率谱的卷积。

在频域中,卷积类似于两个时间序列的内积。数学上,卷积定义为以下内容:

注意

卷积

M10:两个函数 fg 的卷积定义为:

基于 DFT 的过滤

M11:时间序列 x = (x[i]} 与频率谱 ω^x 和频域中的滤波器 f 的卷积 F 定义为:

基于 DFT 的过滤

让我们将卷积应用于我们的过滤问题。使用离散傅里叶变换的过滤算法包括五个步骤:

  1. 填充时间序列以启用离散正弦或余弦变换。

  2. 使用正向变换 F 生成频率的有序序列。

  3. 在频域中选择滤波函数 G 和截止频率。

  4. 将频率序列与滤波函数 G 进行卷积。

  5. 通过对卷积频率应用逆 DFT 变换,在时域中生成滤波后的信号。基于 DFT 的滤波

    离散傅里叶滤波器的示意图

最常用的低通滤波函数被称为 sincsinc2 函数,分别定义为矩形函数和三角形函数。这些函数是从通用 convol 方法派生出的部分应用函数。最简单的 sinc 函数在截止频率 fC 以下返回 1,如果频率更高则返回 0

val convol = (n: Int, f: Double, fC: Double) => 
     if( Math.pow(f, n) < fC) 1.0 else 0.0
val sinc = convol(1, _: Double, _:Double)
val sinc2 = convol(2, _: Double, _:Double)
val sinc4 = convol(4, _: Double, _:Double)

注意

部分应用函数与部分函数

部分函数和部分应用函数实际上并不相关。

部分函数 f' 是应用于输入空间 X 的子集 X' 的函数。它不会执行所有可能的输入值:

基于 DFT 的滤波

部分应用函数 f" 是用户为其中一个或多个参数提供值的函数值。投影减少了输入空间的维度 (X, Z)

基于 DFT 的滤波

DFTFilter 类从 DFT 类继承,以便重用 fwrd 正向变换函数。g 频率域函数是滤波器的一个属性。g 函数将 fC 频率截止值作为第二个参数(第 14 行)。上一节中定义的两个 sincsinc2 滤波器是滤波函数的例子:

class DFTFilter@specialized(Double) T <: AnyVal
    (g: (Double, Double) =>Double)(implicit f: T => Double)
  extends DFTT { //14

  override def |> : PartialFunction[U, Try[V]] = {
    case xt: U if( xt.size >= 2 ) => {
      fwrd(xt).map{ case(trf, freq) => {  //15
        val cutOff = fC*freq.size
        val filtered = freq.zipWithIndex
                     .map{ case(x, n) => x*g(n, cutOff) } //16
        trf.transform(filtered, INVERSE).toVector }) //17
    }
  }
}

滤波过程分为三个步骤:

  1. 计算 fwrd 离散傅里叶正向变换(正弦或余弦)(第 15 行)。

  2. 通过 Scala 的 map 方法(第 16 行)应用滤波函数(公式 M11)。

  3. 对频率应用逆变换(第 17 行)。

让我们评估截止值对滤波数据的影响。测试程序的实现包括从文件中加载数据(第 19 行),然后调用 pfnDFTfilter 部分函数的 DFTFilter(第 19 行):

import YahooFinancials._

val inputFile = s"$RESOURCE_PATH$symbol.csv"
val src = DataSource(input, false, true, 1)
val CUTOFF = 0.005
val pfnDFTfilter = DFTFilterDouble(sinc) |>
for {
  price <- src.get(adjClose)  //18
  filtered <- pfnDFTfilter(price)  //19
} 
yield { /* ... */ }

通过选择介于三个谐波之间的截止值来过滤噪声,这三个谐波的频率分别为 2、5 和 15。原始数据和两个滤波后的时间序列将在以下图表中展示:

基于 DFT 的滤波

基于离散傅里叶滤波器的平滑绘图

如您所预期,截止值为 12 的低通滤波器消除了最高频率的噪声。截止值为 4 的滤波器消除了第二个谐波(低频噪声),只留下了主要趋势周期。

市场周期的检测

使用离散傅里叶变换生成周期性时间序列的频率谱是容易的。然而,对于现实世界中的信号,例如代表股票历史价格的时序,该怎么办呢?

下一个练习的目的是通过将 2009 年 1 月 1 日至 2013 年 12 月 31 日标准普尔 500 指数的报价应用离散傅里叶变换,检测整体股市是否存在长期周期(s),如图所示:

市场周期检测

历史标准普尔 500 指数价格

第一步是应用 DFT 以提取标准普尔 500 历史价格的频率谱,如图所示,包括前 32 个谐波:

市场周期检测

历史标准普尔指数的频率谱

频率域图突出了关于标准普尔 500 历史价格的一些有趣特征:

  • 正负振幅都存在,正如你在具有复数值的时间序列中预期的那样。余弦级数对正振幅有贡献,而正弦级数影响正负振幅,(cos(x+π) = sin(x))

  • 沿频率的振幅衰减足够陡峭,足以保证对第一谐波(代表历史股价的主要趋势)之外进行进一步分析。下一步是对标准普尔 500 历史数据进行带通滤波器技术处理,以识别具有较低周期的短期趋势。

限带通滤波器用于减少或消除原始数据中的噪声。在这种情况下,使用频率范围或窗口的带通滤波器是合适的,以隔离表征特定周期的频率或频率组。在上一节中引入的sinc函数,用于实现低通滤波器,被修改为在窗口[w[1], w[2]]内强制执行带通滤波器,如下所示:

def sinc(f: Double, w: (Double, Double)): Double = 
    if(f > w._1 && f < w._2) 1.0 else 0.0

让我们定义一个基于 DFT 的带通滤波器,窗口宽度为 4,w=(i, i+4),其中i的范围在 2 到 20 之间。应用窗口[4, 8]可以隔离第二个谐波对价格的影响。随着我们消除小于 4 的频率的主上升趋势,所有过滤后的数据相对于主趋势都在一个相对较短的范围内变化。以下图形显示了该滤波器的输出:

市场周期检测

历史标准普尔指数上的带通 DFT 滤波器输出范围 4-8

在这种情况下,我们使用第三组谐波(频率范围从 18 到 22)对标准普尔 500 指数进行滤波;信号被转换成熟悉的正弦函数,如图所示:

市场周期检测

历史标准普尔指数的带通 DFT 滤波器输出范围 18-22

对于通过频率为 20 的带通滤波器过滤的标准普尔 500 数据,存在一个可能的合理解释,如图所示。标准普尔 500 历史数据图显示,上升趋势(交易时段 620 至 770)中间的波动频率显著增加。

这种现象可以通过以下事实来解释:当现有的上升趋势被打破时,标准普尔 500 指数在交易时段 545 左右达到阻力位。多头和空头之间开始了一场拉锯战,多头押注市场会小幅上涨,而空头则预期会有修正。当标准普尔 500 指数突破阻力位并恢复到具有高振幅低频率的强劲上升趋势时,交易者的拉锯战结束,如下面的图表所示:

市场周期检测

历史标准普尔 500 指数价格支撑和阻力水平的示意图

使用基于离散傅立叶的滤波器清理数据的一个局限性是,它要求数据科学家定期提取频率谱并修改滤波器,因为他或她永远无法确定最新的数据批次不会引入不同频率的噪声。卡尔曼滤波器解决了这一局限性。

离散卡尔曼滤波器

卡尔曼滤波器是一个数学模型,它提供了一种准确且递归的计算方法来估计过程的先前状态并预测未来状态,其中某些变量可能未知。R.E. 卡尔曼在 20 世纪 60 年代初引入它来模拟动态系统并在航空航天领域预测轨迹[3:10]。今天,卡尔曼滤波器用于发现两个可能或可能不与其他隐藏变量相关联的观测变量之间的关系。在这方面,卡尔曼滤波器与第七章中描述的隐藏马尔可夫模型部分中所述的顺序数据模型有一些相似之处[3:11]。

卡尔曼滤波器被用作:

  • 从当前观测值预测下一个数据点

  • 通过处理最后两个观测值来消除噪声的滤波器

  • 从观测历史中识别趋势的平滑模型

注意

平滑与滤波

平滑是一种从时间序列或信号中去除高频波动的操作。滤波包括选择一系列频率来处理数据。在这方面,平滑与低通滤波有些相似。唯一的区别是低通滤波通常通过线性方法实现。

从概念上讲,卡尔曼滤波器通过噪声观测估计系统的状态。卡尔曼滤波器有两个特点:

  • 递归:使用先前状态输入预测并校正新状态

  • 最优:这是一个最优估计器,因为它最小化了估计参数的均方误差(与实际值相比)

卡尔曼滤波器是自适应控制中使用的随机模型之一[3:12]。

注意

卡尔曼与非线性系统

卡尔曼滤波器估计线性动态系统的内部状态。然而,它可以通过使用线性或二次近似函数扩展到非线性状态空间模型。这些滤波器被称为,你可能已经猜到了,扩展卡尔曼滤波器EKF),其理论超出了本书的范围。

下一节专门介绍用于金融工程的线性系统的离散卡尔曼滤波器。连续信号可以使用奈奎斯特频率转换为时间序列。

状态空间估计

卡尔曼滤波器模型由动态系统的两个核心元素组成:生成数据的过程和收集数据的测量。这些元素被称为状态空间模型。从数学上讲,状态空间模型由两个方程组成:

  • 转换方程:这描述了系统的动力学,包括未观察到的变量

  • 测量方程:这描述了观察到的和未观察到的变量之间的关系

转换方程

让我们考虑一个具有线性状态 x[t]n 个变量的系统和控制输入向量 u[t]。在时间 t 的状态预测是通过一个线性随机方程(M12)计算的:

转换方程

  • A^t 是一个 n 维度的方阵,它表示从时间 t-1 的状态 x[t-1] 到时间 t 的状态 x[t] 的转换。这个矩阵是考虑的动态系统固有的。

  • B[t] 是一个 n by n 矩阵,它描述了控制输入模型(对系统或模型的外部作用)。它应用于控制向量,u[t]

  • w[t] 代表系统产生的噪声,或者从概率的角度来看,它代表模型的不确定性。它被称为过程白噪声。

控制输入向量表示系统状态的外部输入(或控制)。大多数系统,包括本章后面的金融示例,都没有外部输入到模型状态。

注意

白噪声和高斯噪声

白噪声是高斯噪声,遵循均值为零的正态分布。

测量方程

系统状态的 m 个值 z[t] 的测量由以下方程定义(M13):

测量方程

  • H[t] 是一个 m by n 矩阵,它模拟了测量与系统状态之间的依赖关系。

  • v[t] 是测量设备引入的白噪声。与过程噪声类似,v 遵循均值为零、方差为 R 的高斯分布,称为测量噪声协方差

注意

时间依赖性模型

我们不能假设广义离散卡尔曼滤波器的参数,如状态转换 A[t]、控制输入 B[t] 和观测矩阵(或测量依赖性)H[t] 与时间无关。然而,在大多数实际应用中,这些参数是恒定的。

递归算法

离散卡尔曼滤波器的方程组被实现为具有两个不同步骤的递归计算:

  • 算法使用转换方程来估计下一个观测值

  • 估计是通过此观测的实际测量值创建的

递归在以下图中得到可视化:

递归算法

递归卡尔曼算法的概述图

让我们在过滤金融数据的情况下,以类似于移动平均和傅里叶变换的方式说明预测和校正阶段。目标是提取 10 年期国债收益的趋势和暂时性成分。卡尔曼滤波器特别适合分析利率,原因有两个:

  • 收益是多个因素的结果,其中一些因素是不可直接观察的。

  • 收益受美联储政策的影响,该政策可以通过控制矩阵轻松建模。

10 年期国债的交易量高于期限更长的债券,这使得利率的趋势更加可靠 [3:13]。

将卡尔曼滤波器应用于清洗原始数据需要您定义一个包含观测和非观测状态的模型。在趋势分析的情况下,我们可以安全地创建我们的模型,具有两个变量状态:当前收益 x[t] 和前一个收益 x[t-1]

注意

动态系统的状态

“状态”一词指的是所考虑的动态系统的状态,而不是算法执行的状态。

此卡尔曼滤波器的实现使用了 Apache Commons Math 库。因此,我们需要指定从我们在第一章的“原始和隐式”部分中引入的原始类型到 RealMatrixRealVectorArray2DRowRealMatrixArrayRealVector Apache Commons Math 类型的隐式转换:

implicit def double2RealMatrix(x: DblMatrix): RealMatrix = 
    new Array2DRowRealMatrix(x)
implicit def double2RealRow(x: DblVector): RealMatrix = 
    new Array2DRowRealMatrix(x)
implicit def double2RealVector(x: DblVector): RealVector = 
    new ArrayRealVector(x)

客户端代码必须在其作用域内导入隐式转换函数。

卡尔曼模型假设过程噪声和测量噪声遵循高斯分布,也称为白噪声。为了便于维护,白噪声的生成被封装在具有以下参数的 QRNoise 类中(行 1):

  • qr:这是过程噪声矩阵 Q 和测量噪声 R 的尺度因子元组的表示

  • profile:这是默认为正态分布的噪声配置文件

两个 noiseQnoiseR 方法生成两个独立的白噪声元素数组(行 2):

val normal = Stats.normal(_)
class QRNoise(qr: DblPair,profile: Double=>Double = normal){ //1 
  def q = profile(qr._1)
  def r = profile(qr._2)
  lazy val noiseQ = ArrayDouble   //2
  lazy val noiseR = ArrayDouble
}

注意

尝试不同的噪声分布

虽然离散卡尔曼滤波假设噪声分布遵循正态分布,但QRNoise类允许用户尝试不同的噪声分布。

管理递归中使用的矩阵和向量的最简单方法是将它们定义为kalmanConfig配置类的参数。配置的参数遵循数学公式中定义的命名约定:A是状态转换矩阵,B是控制矩阵,H是定义测量与系统状态之间依赖关系的观测矩阵,P是协方差误差矩阵:

case class KalmanConfig(A: DblMatrix, B: DblMatrix, 
    H: DblMatrix, P: DblMatrix)

让我们在具有预定义KalmanConfig配置的时间序列上实现卡尔曼滤波作为ETransform类型的DKalman转换:

class DKalman(config: KalmanConfig)(implicit qrNoise: QRNoise) 
            extends ETransformKalmanConfig {
  type U = Vector[DblPair]  //3
  type V = Vector[DblPair]  //4
  type KRState = (KalmanFilter, RealVector)  //5
  override def |> : PartialFunction[U, Try[V]]
   ...
}

与任何明确的数据转换一样,我们需要指定UV的类型(第34行),它们是相同的。卡尔曼滤波不改变数据的结构,它只改变值。我们通过创建一个包含两个KalmanFilterRealVector(第5行)Apache Commons Math 类型的元组来定义KRState卡尔曼计算的内部状态。

现在滤波器的关键元素已经就位,是时候实现卡尔曼算法的预测-校正循环部分了。

预测

预测阶段包括使用转换方程估计x状态(国库券的收益)。我们假设美联储对利率没有实质性影响,使得控制输入矩阵B为空。转换方程可以通过对矩阵进行简单运算轻松解决:

预测

卡尔曼滤波转换方程的可视化

本练习的目的是评估转换矩阵A的不同参数对平滑度的影响。

注意

控制输入矩阵 B

在本例中,控制矩阵B为空,因为没有已知的外部行动对 10 年期国库券收益产生影响。然而,收益可能受到我们表示为隐藏变量的未知参数的影响。例如,矩阵B可以用来模拟美联储关于资产购买和联邦基金利率的决定。

卡尔曼滤波背后的数学作为 Scala 实现参考,使用相同的矩阵和向量符号。这绝对不是理解卡尔曼滤波及其在下一节中的实现的前提条件。如果你对线性代数有自然的倾向,以下描述了预测步骤的两个方程。

注意

预测步骤

M14:在时间t的状态预测是通过外推状态估计来计算的:

预测

  • A是维度为n的平方矩阵,代表从时间t-1的状态x到时间t的状态x的转换

  • x'[t] 是基于当前状态和模型 A 预测的系统状态

  • B 是一个描述状态输入的 n 维向量

M15:要最小化的均方误差矩阵 P,使用以下公式更新:

预测

  • A^T 是状态转换矩阵的转置

  • Q 是过程白噪声,描述为零均值和方差 Q 的高斯分布,称为噪声协方差

状态转换矩阵使用 Apache Commons Math 库中包含的矩阵和向量类实现。矩阵和向量的类型自动转换为 RealMatrixRealVector 类。

方程 M14 的实现如下:

x = A.operate(x).add(qrNoise.create(0.03, 0.1))

新状态被预测(或估计),然后用作校正步骤的输入。

校正

递归卡尔曼算法的第二步是校正 10 年期国债收益的实际收益。在这个例子中,测量白噪声可以忽略不计。测量方程很简单,因为状态由当前和前一个收益及其测量值 z 表示:

更正

卡尔曼滤波器测量方程的可视化

校正阶段的数学方程序列包括使用实际值 z 更新状态 x 的估计,并计算卡尔曼增益 K

注意

校正步骤

M16:系统状态 x 使用以下公式从实际测量 z 中估计:

更正

  • r[t] 是预测测量值和实际测量值之间的残差

  • K[t] 是校正因子的卡尔曼增益

M17:卡尔曼增益的计算如下:

更正

这里,H^TH 的矩阵转置,P[t]' 是误差协方差的估计。

卡尔曼平滑

是时候检验我们对状态和测量方程的知识了。Apache Commons Math 库定义了两个 DefaultProcessModelDefaultMeasurementModel 类来封装矩阵和向量的组件。10 年期国债收益的历史值通过 DataSource 方法加载,并映射到滤波器的输出平滑序列:

  override def |> : PartialFunction[U, Try[V]] = {
    case xt: U if( !xt.isEmpty) => Try( 
      xt.map { case(current, prev) => {
        val models = initialize(current, prev) //6
        val nState = newState(models) //7
        (nState(0), nState(1))  //8
      }}
    ) 
   }

卡尔曼滤波器的数据转换在私有 initialize 方法(行 6)中初始化每个数据点的过程和测量模型,在 newState 方法(行 7)中迭代地使用转换和校正方程更新状态,并返回成对的值滤波序列(行 8)。

注意

异常处理

书中省略了捕获和处理由 Apache Commons Math 库抛出的异常的代码,这是标准做法。就卡尔曼滤波的执行而言,必须处理以下异常:

  • NonSquareMatrixException

  • DimensionMismatchException

  • MatrixDimensionMismatchException

initialize方法封装了pModel过程模型(第 9 行)和mModel测量(观测)模型(第 10 行)的初始化,这些模型在 Apache Commons Math 库中定义:

def initialize(current: Double, prev: Double): KRState = {  
  val pModel = new DefaultProcessModel(config.A, config.B, 
                Q, input, config.P) //9
  val mModel = new DefaultMeasurementModel(config.H, R) //10
  val in = ArrayDouble
  (new KalmanFilter(pModel,mModel), new ArrayRealVector(in))
}

Apache Commons Math API 抛出的异常被捕获并通过Try单子进行处理。通过newState方法实现 10 年期国债平滑收益的迭代预测和校正。该方法通过以下步骤迭代:

  1. 通过调用实现M14公式(第 11 行)的 Apache Commons Math KalmanFilter.predict方法来估计新的状态值。

  2. M12公式应用于时间 t 的新状态x(第 12 行)。

  3. 使用M13公式(第 13 行)计算时间 t 的测量值z

  4. 调用 Apache Commons Math 的KalmanFilter.correct方法来实现M16公式(第 14 行)。

  5. 通过调用 Apache Commons Math 的KalmanFilter.getStateEstimation方法(第 15 行)返回状态x的估计值。

代码如下:

def newState(state: KRState): DblArray = {
  state._1.predict  //11
  val x = config.A.operate(state._2).add(qrNoise.noisyQ) //12
  val z = config.H.operate(x).add(qrNoise.noisyR) //13
  state._1.correct(z)  //14
  state._1.getStateEstimation  //15
}

注意

退出条件

newState方法的代码片段中,当达到最大迭代次数时,特定数据点的迭代会退出。更详细实现包括在每个迭代中评估矩阵P或估计收敛到预定义范围内。

固定滞后平滑

到目前为止,我们已经学习了卡尔曼滤波算法。我们需要将其适应于时间序列的平滑处理。固定滞后平滑技术包括向后纠正先前数据点,并考虑最新的实际值。

N 滞后平滑器定义输入为向量X = {x[t-N-1], x[t-N-2], …, x[t]},其中x[t-N-j]的值在考虑x[t]的当前值时被纠正。

该策略与隐藏马尔可夫模型的前向和后向传递相当相似(参见第七章下隐藏马尔可夫模型中的评估 – CF-1部分,顺序数据模型)。

注意

滞后平滑的复杂策略

有许多公式或方法可以实现精确的固定滞后平滑策略并纠正预测观测值。这些策略超出了本书的范围。

实验

目标是使用两步滞后平滑算法来平滑 10 年期国债的收益。

注意

两步滞后平滑

M18:使用单个平滑因子α对状态St定义的两步滞后平滑算法如下:

实验

状态方程使用先前的状态[x[t], x[t-1]]更新状态值,其中x[t]代表时间t的 10 年期国债收益率。这是通过使用 drop 方法将原始时间序列{x[0] … x[n-1]}的值向右移动 1 来实现的,X[1] ={x[1], …, x[n-1]},创建一个不包含最后一个元素的原始时间序列的副本X[2]={x[0], …, x[n-2]},然后将X[1]X[2]进行 zip 操作。这个过程是通过zipWithShift方法实现的,该方法在第一部分中介绍。

状态向量序列S[k] = [x[k], x[k-1]]^T由卡尔曼算法处理,如下面的代码所示:

Import YahooFinancials._ 
val RESOURCE_DIR = "resources/data/chap3/"
implicit val qrNoise = new QRNoise((0.7, 0.3)) //16

val H: DblMatrix = ((0.9, 0.0), (0.0, 0.1))    //17
val P0: DblMatrix = ((0.4, 0.3), (0.5, 0.4))   //18
val ALPHA1 = 0.5; val ALPHA2 = 0.8
val src = DataSource(s"${RESOURCE_DIR}${symbol}.csv", false)

(src.get(adjClose)).map(zt => {  //19
   twoStepLagSmoother(zt, ALPHA1)     //20
   twoStepLagSmoother(zt, ALPHA2)
})

注意

隐式噪声实例

对于过程和测量的噪声被定义为DKalman卡尔曼滤波的隐式参数,以下两个原因:

  • 噪声的轮廓特定于评估的过程或系统及其测量;它与ABH卡尔曼配置参数无关。因此,它不能是KalmanConfig类的一个成员。

  • 如果需要,应与其他替代滤波技术共享相同的噪声特性。

过程和测量的白噪声被隐式地初始化为qrNoise值(第 16 行)。代码初始化了测量对状态的依赖矩阵H(第 17 行)和包含初始协方差错误的P0(第 18 行)。输入数据是从包含每日 Yahoo 金融数据的 CSV 文件中提取的(第 19 行)。最后,该方法使用两个不同的ALPHA1ALPHA2 alpha 参数值执行twoStepLagSmoother两步滞后平滑算法(第 20 行)。

让我们看看twoStepLagSmoother方法:

def twoStepLagSmoother(zSeries: DblVector,alpha: Double): Int = { 
  val A: DblMatrix = ((alpha, 1.0-alpha), (1.0, 0.0))  //21
  val xt = zipWithShift(1)  //22
  val pfnKalman = DKalman(A, H, P0) |>    //23
  pfnKalman(xt).map(filtered =>          //24
    display(zSeries, filtered.map(_._1), alpha) )
}

twoStepLagSmoother方法接受两个参数:

  • 一个zSeries单变量时间序列

  • 一个alpha状态转换参数

它使用alpha指数移动平均衰减参数(第 21 行)初始化状态转换矩阵A。它使用zipWithShift方法(第 22 行)创建了两个步骤的滞后时间序列xt。它提取了pfnKalman部分函数(第 23 行),进行处理,最后显示两个步骤的滞后时间序列(第 24 行)。

注意

建模状态转换和噪声

状态转换和与过程相关的噪声必须仔细选择。状态方程的分辨率依赖于Cholesky(QR)分解,它需要一个非负定矩阵。如果违反了这一原则,Apache Commons Math 库的实现将抛出NonPositiveDefiniteMatrixException异常。

平滑后的收益率如下绘制在原始数据上:

实验

对于 10 年期国债历史价格的卡尔曼滤波输出

卡尔曼滤波器能够平滑 10 年期国债的历史收益率,同时保留尖峰和低频噪声。让我们分析一个较短的时间段内的数据,在这个时间段内噪声最强,即在第 190 天和第 275 天的交易日之间:

实验

卡尔曼滤波器对 10 年期国债价格的输出为 0.8-.02

在不取消实际尖峰的情况下,高频噪声已显著降低。分布(0.8,0.2)考虑了先前状态并倾向于预测值。相反,一个使用状态转移矩阵 A [0.2,0.8,0.0,1.0] 且倾向于最新测量的运行将保留噪声,如下图中所示:

实验

卡尔曼滤波器对 10 年期国债价格的输出为 0.2-0.8

优点和缺点

卡尔曼滤波器是一个非常有用且强大的工具,用于帮助您理解过程和观测之间的噪声分布。与基于离散傅里叶变换的低通或带通滤波器不同,卡尔曼滤波器不需要计算频率谱或假设噪声的频率范围。

然而,线性离散卡尔曼滤波器有其局限性,如下所述:

  • 由过程和测量产生的噪声必须是高斯分布的。具有非高斯噪声的过程可以使用高斯和滤波器或自适应高斯混合[3:14]等技术进行建模。

  • 它要求底层过程是线性的。然而,研究人员已经能够制定卡尔曼滤波器的扩展,称为扩展卡尔曼滤波器EKF),以从非线性动态系统中过滤信号,但这会带来显著的计算复杂性。

    注意

    连续时间卡尔曼滤波器

    卡尔曼滤波器不仅限于具有离散状态 x 的动态系统。对于连续状态-时间的情况,通过修改状态转移方程来处理,因此估计的状态是计算为导数 dx/dt

不同的预处理技术

为了节省空间和您的宝贵时间,本章介绍了并应用了三种滤波和光滑算法类别。移动平均、傅里叶级数和卡尔曼滤波器远非清理原始数据所使用的唯一技术。其他技术可以归类为以下类别:

  • 自回归模型包括自回归移动平均ARMA)、自回归积分移动平均ARIMA)、广义自回归条件异方差GARCH)以及依赖于某种自相关函数的 Box-Jenkins。

  • 曲线拟合算法包括多项式和几何拟合使用普通最小二乘法,非线性最小二乘使用Levenberg-Marquardt优化器,以及概率分布拟合。

  • 具有高斯噪声的非线性动态系统,例如粒子滤波器

  • 如第七章的隐藏马尔可夫模型部分所述的隐藏马尔可夫模型序列数据模型

摘要

这完成了对最常用数据过滤和平滑技术的概述。还有其他类型的数据预处理算法,如标准化、分析和方差减少;识别缺失值对于避免许多使用机器学习进行回归或分类的项目所面临的垃圾输入垃圾输出难题也是至关重要的。

Scala 可以有效地用于使代码易于理解,并避免在方法中使用不必要的参数。

本章中介绍的三种技术,从最简单的移动平均和傅里叶变换到更复杂的卡尔曼滤波,对于为下一章介绍的概念设置数据大有裨益:无监督学习和更具体地说,聚类。

第四章 无监督学习

为分类或回归标记一组观测值可能是一项艰巨的任务,尤其是在特征集很大的情况下。在某些情况下,标记的观测值可能不可用或无法创建。为了尝试从观测值中提取一些隐藏的关联或结构,数据科学家依赖于无监督学习技术来检测数据中的模式或相似性。

无监督学习的目标是发现一组观测值中的规律性和不规则性模式。这些技术也应用于减少解或特征空间。

有许多无监督算法;一些更适合处理相关特征,而另一些则在隐藏特征的情况下生成亲和组[4:1]。在本章中,你将学习三种最常见的无监督学习算法:

  • K-means:这是用于聚类观测特征

  • 期望最大化EM):这是用于聚类观测和潜在特征

  • 主成分分析PCA):这是用来降低模型维度

这些算法中的任何一个都可以应用于技术分析或基本面分析。在附录 A 的Finances 101下的技术分析部分讨论了财务比率的基本面分析和价格走势的技术分析。K-means 算法在 Scala 中完全实现,而期望最大化(expectation-maximization)和主成分分析(Principal Components Analysis)则利用了 Apache Commons Math 库。

本章以非线性模型降维技术的简要概述结束。

聚类

对于大型数据集,涉及大量特征的问题很快就会变得难以处理,并且很难评估特征之间的独立性。任何需要一定程度的优化和至少计算一阶导数的计算都需要大量的计算能力来操作高维矩阵。与许多工程领域一样,将非常大的数据集分类的分割和征服方法非常有效。目标是把非常大的观测集减少到一小组具有一些共同属性的观测。

聚类

数据聚类的可视化

这种方法被称为向量量化。向量量化是一种将一组观测分为相似大小组的方法。向量量化的主要好处是,使用每个组的代表进行的分析远比分析整个数据集简单得多 [4:2]。

聚类,也称为聚类分析,是一种基于距离或相似性的概念来生成称为聚类的组的向量量化形式。

注意

学习向量量化(LVQ

向量化不应与学习向量量化混淆;学习向量量化是人工神经网络的一个特殊情况,它依赖于赢家通吃的学习策略来压缩信号、图像或视频。

本章介绍了两种最常用的聚类算法:

  • K-means:这是用于定量类型,给定聚类数量和距离公式,最小化总误差(称为重建误差)。

  • 期望最大化EM):这是一种两步概率方法,最大化一组参数的似然估计。EM 特别适合处理缺失数据。

K-means 聚类

K-means 是一种流行的聚类算法,它可以迭代或递归地实现。每个聚类的代表是计算为该聚类的中心,称为质心。单个聚类内观测之间的相似性依赖于观测之间的距离(或相似性)概念。

测量相似性

测量观测之间的相似性的方法有很多。最合适的度量必须直观且避免计算复杂性。本节回顾了三种相似性度量:

  • 曼哈顿距离

  • 欧几里得距离

  • 归一化内积或点积

曼哈顿距离定义为两个相同大小的变量或向量({x[i]}{y[i]})之间的绝对距离(M1):

测量相似性

实现足够通用,可以计算不同类型元素的两个向量的距离,只要已经定义了这些类型之间的隐式转换为 Double 值,如下所示:

def manhattanT <: AnyVal, U <: AnyVal(implicit f: T => Double): Double = 
 (x,y).zipped.map{case (u,v) => Math.abs(u-v)}.sum

同大小两个向量,{x[i]}{y[i]},之间的普遍欧几里得距离由以下公式定义(M2):

测量相似度

代码将如下所示:

def euclideanT <: AnyVal, U <: AnyVal(implicit f: T => Double): Double = 
  Math.sqrt((x,y).zipped.map{case (u,v) =>u-v}.map(sqr(_)).sum)

两个向量,{x[i]}{y[i]},之间的归一化内积或余弦距离由以下公式定义(M3):

测量相似度

在此实现中,使用fold方法中的元组同时计算每个数据集的点积和范数:

def cosine[T <: AnyVal, U < : AnyVal] (
     x: Array[T], 
     y: Array[U])(implicit f : T => Double): Double = {
  val norms = (x,y).zipped
          .map{ case (u,v) => ArrayDouble}
          ./:(Array.fill(3)(0.0))((s, t) => s ++ t) 
  norms(0)/Math.sqrt(norms(1)*norms(2))
}

注意

zip 和 zipped 的性能

两个向量的标量积是最常见的操作之一。使用通用的zip方法实现点积很有诱惑力:

 def dot(x:Array[Double], y:Array[Double]): Array[Double] = x.zip(y).map{case(x, y) => f(x,y) )

一种功能替代方法是使用Tuple2.zipped方法:

 def dot(x:Array[Double], y:Array[Double]): Array[Double] = (x, y).zipped map ( _ * _)

如果可读性不是主要问题,你总是可以使用while循环实现dot方法,这可以防止你使用普遍的while循环。

定义算法

K-means 算法的主要优势(以及其受欢迎的原因)是其简单性[4:3]。

注意

K-means 目标

M4:让我们考虑 K 个簇,{C[k]},具有均值或质心,{m[k]}。K-means 算法实际上是一个优化问题,其目标是使重建或总误差最小化,定义为距离的总和:

定义算法

K-means 算法的四个步骤如下:

  1. 簇配置(初始化 K 个簇的质心或均值 m[k])。

  2. 簇分配(根据质心 m[k] 将观测值分配给最近的簇)。

  3. 错误最小化(计算总重建误差):

    1. 计算质心 m[k] 以最小化当前分配的总重建误差。

    2. 根据新的质心 m[k] 重新分配观测值。

    3. 重复计算总重建误差,直到没有观测值被重新分配。

  4. 通过将观测值分配给最近的簇来对新观测值进行分类。

在实现算法之前,我们需要在 Scala 中定义 K-means 的组件。

步骤 1 – 簇配置

让我们创建 K-means 算法的两个主要组件:观测值的簇和 K-means 算法的实现。

定义簇

第一步是定义一个簇。簇由以下参数定义:

  • 质心(center)(行1

  • 属于此簇的观测值的索引(成员)(行2

代码将如下所示:

class ClusterT <: AnyVal
     (implicit f: T => Double) {  //1
  type DistanceFunc[T] = (DblArray, Array[T])=> Double
  val members = new ListBuffer[Int]    //2
  def moveCenter(xt: XVSeries[T]): Cluster[T] 
  ...
}

簇负责在任何迭代计算 K-means 算法的点中管理其成员(数据点)。假设簇永远不会包含相同的数据点两次。Cluster类中的两个关键方法如下:

  • moveCenter:这重新计算簇的质心

  • stdDev:这计算所有观测成员与质心之间的距离的标准差

Cluster 类的构造函数是通过伴随对象的 apply 方法实现的。为了方便,请参考 附录 A 中的 类构造函数模板 部分,基本概念

object Cluster {
  def applyT <: AnyVal
      (implicit f: T => Double): Cluster[T] = 
     new ClusterT
}

让我们看看 moveCenter 方法。它通过现有成员和新的中心点创建一个新的聚类。计算中心点值需要将观测矩阵按特征转置为按观测特征矩阵(第 3 行)。新的中心点通过将所有观测中每个特征的求和除以数据点的数量进行归一化计算(第 4 行):

def moveCenter(xt: XVSeries[T])
       (implicit m: Manifest[T], num: Numeric[T])
       : Cluster[T] = {  
  val sum = transpose(members.map( xt(_)).toList)
             .map(_.sum)  //3
  ClusterT.toArray) //4
}

stdDev 方法计算相对于其中心的聚类中所有观测的标准差。通过映射调用提取每个成员与中心点之间的 distance 值(第 5 行)。然后将其加载到统计实例中以计算标准差(第 6 行)。计算中心点与观测之间距离的函数是方法的一个参数。默认距离是 euclidean

def stdDev(xt: XVSeries[T], distance: DistanceFunc): Double = {
  val ts = members.map(xt( _)).map(distance(center,_)) //5
  StatsDouble.stdDev  //6
}

备注

聚类选择

在重新分配观测(更新其成员资格)时,选择最合适的聚类有不同的方法。在本实现中,我们将选择具有较大分散度或最低密度的聚类。另一种选择是选择具有最大成员资格的聚类。

初始化聚类

确保 K-means 算法快速收敛,聚类中心的初始化非常重要。解决方案范围从简单的随机生成中心点到应用遗传算法来评估中心点候选者的适应性。我们选择了一个由 M. Agha 和 W. Ashour 开发的既高效又快速的初始化算法 [4:4]。

初始化步骤如下:

  1. 计算观测集的标准差。

  2. 计算具有最大标准差的特征索引 {x[k,0], x[k,1] … x[k,n]}

  3. 按照维度 k 的标准差递增值对观测进行排序。

  4. 将排序后的观测集平均分成 K 个集合 {S[m]}

  5. 找到中值大小的 size(S[m])/2

  6. 使用得到的观测作为初始中心点。

让我们分解 initialize 方法中 Agha-Ashour 算法的实现:

def initialize(xt: U): V = {
  val stats = statistics(xt)  //7
  val maxSDevVar = Range(0,stats.size)  //8
                   .maxBy(stats( _ ).stdDev )

  val rankedObs = xt.zipWithIndex
                .map{case (x, n) => (x(maxSDevVar), n)}
                .sortWith( _._1  < _._1)  //9

  val halfSegSize = ((rankedObs.size>>1)/_config.K)
                  .floor.toInt //10
  val centroids = rankedObs
     .filter(isContained( _, halfSegSize, rankedObs.size))
     .map{ case(x, n) => xt(n)} //11
  centroids.aggregate(List[Cluster[T]]())((xs, c) => 
        ClusterT :: xs, _ ::: _)  //12
}

XVSeries 类型的时间序列的 statistics 方法在 第三章 的 时间序列在 Scala 中 部分定义,数据预处理(第 7 行)。使用 maxBy 方法在 Stats 实例上计算具有 maxSDevVar 最大方差或标准差的特征维度(或特征)(第 8 行)。然后,根据 rankedObs 标准差的递增值对观测进行排序(第 9 行)。

然后将有序的观测值序列分割成xt.size/_config.K个段(第 10 行),并使用isContained过滤条件(第 11 行)选择这些段的索引作为中点(或中位数)观测值:

def isContained(t: (T,Int), hSz: Int, dim: Int): Boolean = 
    (t._2 % hSz == 0) && (t._2 %(hSz<<1) != 0)

最后,通过在质心集合上调用aggregate来生成簇列表(第 12 行)。

第 2 步 – 簇分配

K-means 算法的第二步是在第 1 步初始化质心后的簇中分配观测值。这一壮举是通过私有的assignToClusters方法完成的:

def assignToClusters(xt: U, clusters: V, 
     members: Array[Int]): Int =  {

  xt.zipWithIndex.filter{ case(x, n) => {  //13
    val nearestCluster = getNearestCluster(clusters, x) //14
    val reassigned = nearestCluster != members(n) 

    clusters(nearestCluster) += n //15
    members(n) = nearestCluster //16
    reassigned
  }}.size
}

观测值分配到每个簇的核心是对时间序列的过滤(第 13 行)。过滤计算最近簇的索引,并检查观测值是否需要重新分配(第 14 行)。在n索引处的观测值被添加到最近的簇clusters(nearestCluster)(第 15 行)。然后更新观测值的当前成员资格(第 16 行)。

通过私有的getNearestCluster方法计算与观测值数据最近的簇,如下所示:

def getNearestCluster(clusters: V, x: Array[T]): Int = 
  clusters.zipWithIndex./:((Double.MaxValue, 0)){
    case (p, (c, n) ) => { 
     val measure = distance(c.center, x) //17
     if( measure < p._1) (measure, n)  else p
  }}._2

使用 K-means 构造函数中定义的距离度量(第 17 行),通过折叠提取与x观测值最近的簇列表中的簇。

与其他数据处理单元一样,K-means 簇的提取封装在数据转换中,以便可以将聚类集成到工作流程中,使用第二章中组合 mixins 以构建工作流程部分描述的 mixins 组合。

注意

K-means 算法退出条件

在一些罕见的情况下,算法可能会在簇之间重新分配相同的一些观测值,这会阻止其在合理的时间内收敛到解决方案。因此,建议您添加一个最大迭代次数作为退出条件。如果 K-means 在最大迭代次数后没有收敛,那么需要重新初始化簇质心,并重新启动迭代(或递归)执行。

|>转换要求在stdDev方法中计算与质心c相关的观测值距离的标准差:

type KMeansModel[T} = List[Cluster[T]] 
def stdDevT: DblVector = 
 clusters.map( _.stdDev(xt, distance)).toVector

注意

质心与均值

中心点和均值指的是同一实体:簇的中心。本章使用这两个术语互换。

第 3 步 – 重建/误差最小化

簇的成员初始化为预定义的观测值集合。算法通过最小化总重建误差来更新每个簇的成员资格。执行 K-means 算法有两种有效策略:

  • 尾递归执行

  • 迭代执行

创建 K-means 组件

让我们声明 K-means 算法类,KMeans,并定义其公共方法。KMeans通过从训练集中提取的隐式模型实现一个ITransform数据转换,并在第二章的单态数据转换部分进行描述,Hello World!(行18)。KMeansConfig类型的配置由元组(KmaxIters)组成,其中K是簇的数量,maxIters是算法收敛允许的最大迭代次数:

case class KMeansConfig(val K: Int, maxIters: Int)

KMeans类接受以下三个参数:

  • config:这是用于算法执行的配置

  • distance:这是用于计算任何观测值与簇质心之间距离的函数

  • xt:这是训练集

T类型到Double的隐式转换实现为一个视图边界。KMeans类的实例化初始化了一个从 K-means 输出的V类型,作为Cluster[T](行20)。Numeric类的num实例必须作为类参数隐式传递,因为它在initialize中的sortWith调用、maxBy方法和Cluster.moveCenter方法(行19)中是必需的。Manifest用于在 JVM 中保留Array[T]的擦除类型:

class KMeansT <: AnyVal
    (implicit m: Manifest[T], num: Numeric[T], f: T=>Double) //19
  extends ITransform[Array[T]](xt) with Monitor[T] { 

  type V = Cluster[T]     //20
  val model: Option[KMeansModel[T]] = train
  def train: Option[KMeansModel[T]]
  override def |> : PartialFunction[U, Try[V]]
  ...
}

KMeansModel模型定义为通过训练提取的簇列表。

尾递归实现

通过train训练方法实现转换或聚类函数,该方法创建一个以XVSeries[T]作为输入和KMeansModel[T]作为输出的部分函数:

def train: Option[KMeansModel[T]] = Try {
         // STEP 1
  val clusters =initialize(xt) //21
  if( clusters.isEmpty)  /* ...  */
  else  {
         // STEP 2
    val members = Array.fill(xt.size)(0)
    assignToClusters(xt, clusters, members) //22
    var iters = 0

    // Declaration of the tail recursion def update      
    if( iters >= _config.maxIters )
      throw new IllegalStateException( /* .. */)
         // STEP 3
    update(clusters, xt, members)  //23
  } 
} match {
   case Success(clusters) => Some(clusters)
   case Failure(e) => /* … */
}

K-means 训练算法通过以下三个步骤实现:

  1. 使用initialize方法初始化簇的质心(行21)。

  2. 使用assignToClusters方法(行22)将观测值分配给每个簇。

  3. 使用update递归方法(行23)重新计算总误差重建。

总误差重建的计算实现为一个尾递归方法,update,如下所示:

@tailrec
def update(clusters: KMeansModel[T], xt: U, 
       members: Array[Int]): KMeansModel[T] = {  //24

  val newClusters = clusters.map( c => {         
      if( c.size > 0) c.moveCenter(xt) //25
    else clusters.filter( _.size >0)
          .maxBy(_.stdDev(xt, distance)) //26
  }) 
  iters += 1
  if(iters >= config.maxIters ||       //27
      assignToClusters(xt, newClusters, members) ==0) 
    newClusters
  else 
    update(newClusters, xt, membership)   //28
}

递归接受以下三个参数(行24):

  • 在递归过程中更新的当前clusters列表

  • xt输入时间序列

  • 簇的成员索引,members

通过重新计算每个簇不为空时的每个质心(行25)或评估每个观测值相对于每个质心的距离的标准差(行26)来计算新的簇列表,newClusters。当达到maxIters的最大递归调用次数或没有更多的观测值被重新分配到不同的簇时(行27),执行退出。否则,方法使用更新的簇列表调用自身(行28)。

迭代实现

为了信息目的,展示了迭代执行的实现。它遵循与递归实现相同的调用序列。新的聚类被计算(第29行),当达到允许的最大迭代次数(第30行)或没有更多的观测值被重新分配到不同的聚类(第31行)时,执行退出:

val members = Array.fill(xt.size)(0)
assignToClusters(xt, clusters, members) 
var newClusters: KMeansModel[T] = List.empty[Cluster[T]]
Range(0,  maxIters).find( _ => {  //29
  newClusters = clusters.map( c => {   //30
    if( c.size > 0)  c.moveCenter(xt) 
    else clusters.filter( _.size > 0)
           .maxBy(_.stdDev(xt, distance))
  }) 
  assignToClusters(xt, newClusters, members) > 0  //31
}).map(_ => newClusters)

聚类密度在KMeans类中按以下方式计算:

def density: Option[DblVector] = 
  model.map( _.map( c => 
    c.getMembers.map(xt(_)).map( distance(c.center, _)).sum)

第 4 步 – 分类

分类目标是把一个观测值分配到与最近质心最接近的聚类:

override def |> : PartialFunction[Array[T], Try[V]] = {
  case x: Array[T] if( x.length == dimension(xt) 
                     && model != None) => 
  Try (model.map( _.minBy(c => distance(c.center,x))).get )
}

最合适的聚类是通过选择c聚类,其中心x观测值最接近的聚类来计算的,使用minBy高阶方法。

维度诅咒

具有大量特征(高维度)的模型需要更多的观测值来提取相关和可靠的聚类。具有非常小数据集(< 50)的 K-means 聚类会产生具有高偏差和有限聚类数量的模型,这些聚类受到观测值顺序的影响[4:5]。我一直在使用以下简单的经验法则来为大小为n的训练集、预期的K个聚类和N个特征:n < K.N

注意

维度和训练集的大小

在给定模型维度的情况下确定训练集大小的问题并不仅限于无监督学习算法。所有监督学习技术都面临着设置可行的训练计划的相同挑战。

无论遵循哪种经验规则,这种限制对于使用历史报价分析股票尤其成问题。让我们考虑我们的例子,使用技术分析根据股票在 1 年(或大约 250 个交易日)内的价格行为对股票进行分类。问题维度是 250(250 个日收盘价)。股票数量(观测值)将超过数百个!

维度诅咒

K-means 聚类的价格模型

有选项可以绕过这种限制并减少观测数的数量,如下所示:

  • 在不丢失原始数据中大量信息的情况下采样交易数据,假设观测值的分布遵循已知的概率密度函数。

  • 对数据进行平滑以去除第三章中看到的噪声,假设噪声是高斯分布的。在我们的测试中,平滑技术将去除每只股票的价格异常值,从而减少特征数(交易会话)。这种方法与第一种(采样)技术不同,因为它不需要假设数据集遵循已知的密度函数。另一方面,特征的减少将不那么显著。

这些方法充其量是权宜之计,仅用于本教程。在将这些技术应用于实际商业应用之前,您需要考虑数据的质量。本章最后一段介绍的原理成分分析是最可靠的降维技术之一。

设置评估

目标是从 2013 年 1 月 1 日至 12 月 31 日之间的某个时间段内的一组股票价格行为中提取簇作为特征。为此测试,从标准普尔 500 指数中随机选择了 127 只股票。以下图表可视化了一组这些 127 只股票中正常化价格的行为:

设置评估

用于 K-means 聚类的股票篮子的价格行为

关键是在聚类之前选择合适的特征以及操作的时间窗口。考虑 252 个交易日的整个历史价格作为特征是有意义的。然而,观测值(股票)的数量太少,无法使用整个价格范围。观测值是第 80^(th)和第 130^(th)个交易日之间的每个交易日的股票收盘价。调整后的每日收盘价使用它们各自的最小值和最大值进行归一化。

首先,让我们创建一个简单的方法来计算簇的密度:

val MAX_ITERS = 150
def density(K: Int, obs: XVSeries[Double]): DblVector = 
  KMeansDouble).density.get //32

density方法调用了第 3 步中描述的KMeans.density。让我们使用DataSource类从 CSV 文件加载数据,如附录 A 中数据提取部分所述附录 A,基本概念

import YahooFinancials.) 
val START_INDEX = 80; val NUM_SAMPLES = 50  //33
val PATH = "resources/data/chap4/"

type INPUT = Array[String] => Double
val extractor = adjClose :: List[INPUT]() //34
val symbolFiles = DataSource.listSymbolFiles(PATH) //35

for {
  prices <- getPrices  //36
  values <- Try(getPricesRange(prices))  //37
  stdDev <- Try(ks.map( density(_, values.toVector))) //38
  pfnKmeans <- Try { 
    KMeansDouble,values.toVector) |> 
  }   //39
  predict <- pfnKmeans(values.head)   //40
} yield {
  val results = s"""Daily prices ${prices.size} stocks")
     | \nClusters density ${stdDev.mkString(", ")}"""
  .stripMargin
  show(results)
}

如前所述,聚类分析应用于第 80^(th)和第 130^(th)个交易日之间的收盘价范围(第33行)。extractor函数从YahooFinancials中检索股票的调整后收盘价(第34行)。股票的标记(或符号)列表被提取为位于path中的 CSV 文件名列表(第35行)。例如,通用电气公司的标记符号是 GE,交易数据位于GE.csv

执行提取了 50 个每日价格,然后使用filter(第36行)过滤掉格式不正确的数据:

type XVSeriesSet = Array[XVSeries[Double]]
def getPrices: Try[XVSeriesSet] = Try {
   symbolFiles.map( DataSource(_, path) |> extractor )
   .filter( _.isSuccess ).map( _.get)
}

第 80^(th)和第 130^(th)天之间的交易日的历史股票价格由getPricesRange闭包生成(第37行):

def getPricesRange(prices: XVSeriesSet) = 
   prices.view.map(_.head.toArray)
    .map( _.drop(START_INDEX).take(NUM_SAMPLES))

它通过调用每个簇数ks值的density方法来计算簇的密度(第38行)。

创建了一个用于 5 簇的pfnKmeans部分分类函数,它是基于KMeans(第39行),然后用于对其中一个观测值进行分类(第40行)。

评估结果

第一次测试运行使用K=3簇。每个簇的均值(或质心)向量如下绘制:

评估结果

使用 K-means K=3 的簇均值图表

三个聚类的均值向量非常独特。图表中顶部和底部的均值 12 分别具有 0.34 和 0.27 的标准差,并且具有非常相似的图案。12 聚类均值向量之间的元素差异几乎恒定:0.37。具有均值向量 3 的聚类代表在时间周期开始时像聚类 2 中的股票,而在时间周期结束时像聚类 1 中的股票的股票群体。

这种行为可以通过以下事实轻松解释:时间窗口或交易期,第 80 天至第 130 天的交易日,对应于联邦储备关于量化宽松计划的货币政策转变。以下是每个聚类(其质心值在图表上显示)的股票部分列表:

聚类 1 AET, AHS, BBBY, BRCM, C, CB, CL, CLX, COH, CVX, CYH, DE, …
聚类 2 AA, AAPL, ADBE, ADSK, AFAM, AMZN, AU, BHI, BTU, CAT, CCL, …
聚类 3 ADM, ADP, AXP, BA, BBT, BEN, BK, BSX, CA, CBS, CCE, CELG, CHK, …

让我们评估聚类数量 K 对每个聚类特征的影响。

调整聚类数量

我们重复之前的测试,在 127 只股票和相同的时间窗口上,聚类数量从 2 到 15 变化。

对于 K = 2 的每个聚类的均值(或质心)向量如下所示:

调整聚类数量

使用 K-means K=2 的聚类均值图表

2 个聚类的 K-means 算法结果图表显示,标记为 2 的聚类均值向量与 K = 5 图表中标记为 3 的均值向量相似。然而,具有均值向量 1 的聚类在一定程度上反映了图表 K = 5 中聚类 13 均值向量的聚合或求和。聚合效应解释了为什么聚类 1(0.55)的标准差是聚类 2(0.28)标准差的两倍。

对于 K = 5 的每个聚类的均值(或质心)向量如下所示:

调整聚类数量

使用 K-means K=5 的聚类均值图表

在此图表中,我们可以评估聚类 1(具有最高的均值)、2(具有最低的均值)和 3K = 3 图表中具有相同标签的聚类非常相似。具有均值向量 4 的聚类包含行为与聚类 3 非常相似的股票,但方向相反。换句话说,聚类 34 的股票在货币政策变化公告后反应相反。

K 值较高的测试中,不同聚类之间的区别变得模糊,如下图中 K = 10 所示:

调整聚类数量

使用 K-means K=10 的簇均值的图表

在第一个图表中,对于 K = 2 的情况,可以看到簇 123 的均值。有理由假设这些可能是最可靠的簇。这些簇恰好具有低标准差或高密度。

让我们定义具有质心 c[j] 的簇的密度,C[j],为每个簇的所有成员与其均值(或质心)之间的欧几里得距离的倒数(M6):

调整簇的数量

簇的密度与簇的数量(从 K = 1K = 13)的关系图被绘制出来:

调整簇的数量

K = 1 到 13 的平均簇密度的条形图

如预期的那样,随着 K 的增加,每个簇的平均密度增加。从这个实验中,我们可以得出一个简单的结论:在 K = 5 及以上的测试运行中,每个簇的密度并没有显著增加。你可能观察到,随着簇数量的增加,密度并不总是增加(K = 6K = 11)。这种异常可以通过以下三个因素来解释:

  • 原始数据有噪声

  • 模型在一定程度上依赖于质心的初始化

  • 退出条件过于宽松

验证

有几种方法可以验证 K-means 算法的输出,从纯度到互信息 [4:6]。验证聚类算法输出的一个有效方法是为每个簇标记,并将这些簇通过一批新的标记观测值运行。例如,如果在这些测试之一中,你发现其中一个簇,CC,包含大部分与商品相关的股票,那么你可以选择另一个与商品相关的股票,SC,它不是第一批的一部分,并再次运行整个聚类算法。如果 SCCC 的子集,那么 K-means 已经按预期执行。如果是这种情况,你应该运行一个新的股票集,其中一些与商品相关,并测量真正阳性、真正阴性、假阳性、假阴性的数量。在 第二章 中 评估模型 部分引入的精确度、召回率和 F[1] 分数的值,Hello World! 确认你为簇选择的调整参数和标签是否确实正确。

注意

K-means 的 F1 验证

簇的质量,通过 F[1] 分数来衡量,取决于用于标记观测值(即,将簇标记为包含簇中股票相对百分比最高的行业)的规则、策略或公式。这个过程相当主观。验证一个方法的唯一可靠方式是评估几个标记方案,并选择产生最高 F[1] 分数的方案。

一种测量观测值分布在不同簇之间同质性的替代方法是计算统计熵。低熵值表明簇的杂质水平较低。熵可以用来找到最佳簇数 K

我们回顾了一些影响 K-means 聚类结果质量的调整参数,如下所示:

  • 初始选择质心

  • K 簇的数量

在某些情况下,相似性标准(即欧几里得距离或余弦距离)可能会影响聚类簇的纯净度或密度。

最后和重要的考虑因素是 K-means 算法的计算复杂性。本章的前几节描述了 K-means 的一些性能问题及其可能的补救措施。

尽管 K-means 算法有许多优点,但它并不擅长处理缺失数据或未观测到的特征。实际上,相互依赖的特征可能依赖于一个共同的隐藏(也称为潜在)特征。下一节中描述的期望最大化算法解决了这些局限性之一。

期望最大化算法

期望最大化算法最初被引入来估计不完整数据情况下的最大似然[4:7]。它是一种迭代方法,用于计算最大化观测值可能估计的模型特征,同时考虑未观测值。

迭代算法包括计算以下内容:

  • 通过推断潜在值来估计观测数据的最大似然期望 E(E 步)

  • 模型特征最大化期望 E(M 步)

期望最大化算法通过假设每个潜在变量遵循正态或高斯分布来解决聚类问题。这与 K-means 算法类似,其中每个数据点到每个簇中心的距离遵循高斯分布[4:8]。因此,一组潜在变量是高斯分布的混合。

高斯混合模型

潜在变量,Z[i],可以被视为模型(观测)X的行为(或症状),其中 Z 是行为的原因:

高斯混合模型

观测和潜在特征的可视化

潜在值,Z,遵循高斯分布。对于我们中的统计学家,混合模型的数学描述如下:

注意

最大化对数似然

M7:如果 x = {x[i]} 是与潜在特征 z = {z[i]} 相关的观测特征集,则给定模型参数 θ,观测 x 中特征 x[i] 的概率定义为:

高斯混合模型

M8:目标是最大化似然,L(θ),如下所示:

高斯混合模型

EM 概述

在实现方面,期望最大化算法可以分为三个阶段:

  1. 在给定一些潜在变量(初始步骤)的情况下,计算模型特征的日志似然。

  2. 在迭代 t 时计算对数似然期望(E 步)。

  3. 在迭代 t 时最大化期望(M 步)。

注意

E 步

M9: 模型参数 θ[n] 在迭代 n 时,完整数据日志似然期望值 Q 的计算使用潜在变量 z 的后验分布 p(z|x, θ) 和观测值与潜在变量的联合概率:

EM 概述

M 步

M10: 为了计算下一次迭代的模型参数 θ[n+1],期望函数 Q 对模型特征 θ 进行最大化:

EM 概述

S. Borman 的教程 [4:9] 中可以找到 EM 算法的正式、详细但简短的数学公式。

实现

让我们在 Scala 中实现三个步骤(初始步骤、E 步和 M 步)。EM 算法的内部计算相当复杂且令人困惑。您可能不会从特定实现的细节中获得太多好处,例如计算期望对数似然的协方差矩阵的特征值。此实现使用 Apache Commons Math 库包隐藏了一些复杂性 [4:10]。

注意

EM 算法的内部工作原理

如果您需要了解抛出异常的条件,可以下载 Apache Commons Math 库中 EM 算法实现的源代码。

MultivariateEM 类型的期望最大化算法实现为一个 ITransform 类型的数据转换,如 第二章 中 Monadic 数据转换 部分所述,Hello World!。构造函数的两个参数是 K 个聚类(或高斯分布)的数量和 xt 训练集(第 1 行)。构造函数将输出 V 类型的初始化为 EMCluster(第 2 行):

class MultivariateEMT <: AnyVal(implicit f: T => Double)
  extends ITransform[Array[T]](xt) with Monitor[T] {  //1
  type V = EMCluster  //2
  val model: Option[EMModel ] = train //3
  override def |> : PartialFunction[U, Try[V]]
}

多变量期望最大化类有一个由 EMCluster 类型的 EM 聚类列表组成的模型。Monitor 特性用于在训练过程中收集配置文件信息(请参阅 附录 A 中 Utility 类 下的 Monitor 部分,基本概念)。

EM 聚类 EMCluster 的信息由 key(质心或 means 值)和聚类的 density(密度)定义,即所有数据点到均值的距离的标准差(第 4 行):

case class EMCluster(key: Double, val means: DblArray, 
      val density: DblArray)   //4
type EMModel = List[EMCluster]

train 方法中实现 EM 算法时,使用 Apache Commons Math 的 MultivariateNormalMixture 用于高斯混合模型和 MultivariateNormalMixtureExpectationMaximization 用于 EM 算法:

  def train: Option[EMModel ] = Try {
  val data: DblMatrix = xt  //5
  val multivariateEM = new EM(data)
  multivariateEM.fit( estimate(data, K) ) //6

  val newMixture = multivariateEM.getFittedModel  //7
  val components = newMixture.getComponents.toList  //8
  components.map(p =>EMCluster(p.getKey, p.getValue.getMeans, 
                      p.getValue.getStandardDeviations)) //9
} match {/* … */}

让我们看看MultivariateEM包装类的主要train方法。第一步是将时间序列转换为以观测/历史报价为行、股票符号为列的原始Double矩阵。

XVSeries[T]类型的xt时间序列通过诱导的隐式转换(行5)转换为DblMatrix

高斯分布的初始混合可以通过用户提供或从estimate数据集中提取(行6)。调用getFittedModel触发 M 步骤(行7)。

注意

Java 和 Scala 集合的转换

使用import scala.collection.JavaConversions包将 Java 原语转换为 Scala 类型。例如,通过调用WrapAsScala类的asScalaIterator方法,将java.util.List转换为scala.collection.immutable.List,这是JavaConversions的基本特质之一。

Apache Commons Math 的getComponents方法返回一个java.util.List,通过调用toList方法(行8)将其转换为scala.collection.immutable.List。最后,数据转换返回一个EMCluster类型的聚类信息列表(行9)。

注意

第三方库异常

Scala 不强制在方法签名中声明异常。因此,不能保证会捕获所有类型的异常。当第三方库在两种情况下抛出异常时,会出现这个问题:

  • API 文档没有列出所有异常类型

  • 库已更新,并在一个方法中添加了新的异常类型

一个简单的解决方案是利用 Scala 的异常处理机制:

Try {
     ..
} match {
    case Success(results) => …
    case Failure(exception)  => ...
}

分类

新观测或数据点的分类是通过|>方法实现的:

override def |> : PartialFunction[Array[T], Try[V]] = {
  case x: Array[T] 
    if(isModel && x.length == dimension(xt)) => 
  Try( model.map(_.minBy(c => euclidean(c.means,x))).get)
}

|>方法类似于KMeans.|>分类器。

测试

让我们将MultivariateEM类应用于评估 K-means 算法时使用的相同 127 只股票的聚类。

如在维度诅咒部分所述,要分析的股票数量(127)限制了 EM 算法可用的观察数量。一个简单的选项是过滤掉股票价格的一些噪声,并应用简单的采样方法。最大采样率受历史价格中不同类型噪声频谱的限制。

注意

过滤和采样

在聚类之前,使用简单移动平均和固定间隔采样对数据进行预处理,在这个例子中非常基础。例如,我们不能假设所有股票的历史价格具有相同的噪声特征。动量股票和交易活跃的股票的报价噪声模式肯定与具有强大所有权的蓝筹股不同,而这些股票由大型共同基金持有。

采样率应考虑噪声频率的范围。它应设置为至少是最低频率噪声频率的两倍。

测试的目标是评估采样率samplingRate和 EM 算法中使用的K簇数量的影响:

val K = 4; val period = 8
val smAve = SimpleMovingAverageDouble  //10
val pfnSmAve = smAve |>    //11

val obs = symbolFiles.map(sym => (
  for {
    xs <- DataSource(sym, path, true, 1) |>extractor //12
    values <- pfnSmAve(xs.head)  //13
    y <- Try {   
        values.view.zipWithIndex.drop(period+1).toVector
          .filter( _._2 % samplingRate == 0)
          .map( _._1).toArray //14
    }
  } yield y).get) 

em(K, obs)  //15

第一步是创建一个具有预定义周期的简单移动平均,如第三章中“简单移动平均”部分所述,在[数据预处理]中描述。测试代码实例化了实现移动平均计算的pfnSmAve部分函数(第 11 行)。考虑的股票的符号是从路径目录中文件的名称中提取的。历史数据包含在名为path/STOCK_NAME.csv的 CSV 文件中(第 12 行)。

移动平均(第 13 行)的执行生成了一组平滑值,该值根据采样率samplingRate(第 14 行)进行采样。最后,在em方法中实例化期望最大化算法以聚类采样数据(第 15 行):

def em(K: Int, obs: DblMatrix): Int = {
  val em = MultivariateEMDouble //16
  show(s"${em.toString}")  //17
}

em方法实例化了具有特定簇数K的 EM 算法(第 16 行)。通过调用MultivariateEM.toString显示模型的内容。结果被汇总,然后在标准输出上以文本格式显示(第 17 行)。

第一次测试是执行具有 3 个簇和 10 个采样周期的 EM 算法,数据是通过 8 个周期的简单移动平均平滑的。2013 年 1 月 1 日至 2013 年 12 月 31 日之间 127 只股票的历史价格以 0.1 赫兹的频率采样,产生 24 个数据点。以下图表显示了每个簇的平均值:

Testing

使用 EM K=3 的每个簇标准化均值图表

簇 2 和簇 3 的均值向量具有相似的图案,这可能表明一组三个簇足以提供对股票群体内部相似性的初步了解。以下是一个使用 EM 方法且 K=3 时每个簇标准化标准差的图表:

Testing

使用 EM K=3 的每个簇标准化标准差图表

每个簇的标准差分布以及均值向量可以由以下事实解释:几个行业的股票价格协同下降,而其他股票则作为一个半同质群体随着美联储宣布,作为量化宽松计划的一部分,未来几个月购买的债券数量将减少而上涨。

注意

与 K-means 的关系

你可能会想知道 EM 和 K-means 之间有什么关系,因为这两种技术都解决了相同的问题。K-means 算法将每个观测值唯一地分配给一个且仅一个簇。EM 算法根据后验概率分配观测值。K-means 是 EM 在高斯混合模型中的特例[4:11]。

在线 EM 算法

在处理非常大的数据集时,在线学习是训练聚类模型的一个强大策略。最近,这种方法又引起了科学家的兴趣。在线 EM 算法的描述超出了本教程的范围。然而,如果你必须处理大型数据集,你可能需要知道有几个在线 EM 算法可供选择,例如批量 EM、逐步 EM、增量 EM 和蒙特卡洛 EM [4:12]。

减维

在没有对问题领域有先验知识的情况下,数据科学家会在他们的第一次尝试中包括所有可能的特征来创建分类、预测或回归模型。毕竟,做出假设是减少搜索空间的一个糟糕且危险的方法。一个模型使用数百个特征并不罕见,这增加了构建和验证模型的复杂性和显著的计算成本。

噪声过滤技术降低了模型对与偶发行为相关的特征的敏感性。然而,这些与噪声相关的特征在训练阶段之前是未知的,因此不能完全丢弃。因此,模型的训练变得非常繁琐且耗时。

过拟合是另一个可能由大量特征集引起的障碍。有限大小的训练集不允许你使用大量特征创建一个准确的模型。

减维技术通过检测对整体模型行为影响较小的特征来缓解这些问题。

降低模型中特征数量的三种方法:

  • 对于较小的特征集,可以使用如方差分析等统计分析解决方案

  • 在第六章的“正则化”部分中介绍的正规化和收缩技术,回归和正则化

  • 通过变换协方差矩阵来最大化数据集方差的算法

下一节介绍了第三类中最常用的算法之一:主成分分析。

主成分分析

主成分分析的目的是通过降低方差顺序将原始特征集转换成一个新的有序特征集。原始观测值被转换成一组相关性较低的新变量。让我们考虑一个具有两个特征{x, y}和一组观测值{x[i], y[i]}的模型,这些观测值在以下图表中绘制:

主成分分析

二维模型的主成分分析可视化

xy 特征转换为两个 XY 变量(即旋转),以适当地匹配观察值的分布。具有最高方差的变量称为第一个主成分。具有第 n 个最高方差的变量称为第 n 个主成分。

算法

我强烈推荐您阅读 Lindsay Smith 的教程[4:13],该教程以非常具体和简单的方式使用二维模型描述了 PCA 算法。

注意

PCA 和协方差矩阵

M11:两个 XY 特征与观察集 {x[i],y[i]} 及其各自平均值之间的协方差定义为:

算法

在这里,算法算法 分别是观察值 xy 的各自平均值。

M12:协方差是从每个观察值的 Z 分数计算的:

算法

M13:对于具有 n 个特征的模型,x[i],协方差矩阵定义为:

算法

M14:将 x 转换为 X 的转换包括计算协方差矩阵的特征值:

算法

M15:特征值按其方差递减的顺序排序。最后,累积方差超过预定义阈值(矩阵迹的百分比)的前 m 个特征值是主成分:

算法

算法分为五个步骤:

  1. 通过标准化均值和标准差计算观察值的 Z 分数。

  2. 计算原始观察集的协方差矩阵 Σ

  3. 通过提取特征值和特征向量,为具有转换特征的观察值计算新的协方差矩阵 Σ'

  4. 通过降低方差顺序将矩阵转换为秩特征值。有序的特征值是主成分。

  5. 通过将新协方差矩阵迹的百分比作为阈值,选择总方差超过阈值的特征值。

通过对协方差矩阵 Σ 进行对角化提取主成分的过程在以下图中进行了可视化。用于表示协方差值的灰色阴影从白色(最低值)到黑色(最高值)变化:

算法

PCA 中提取特征值的可视化

特征值(X 的方差)按其值递减的顺序排序。当最后一个特征值的累积值(对角矩阵的右下角部分)变得不显著时,PCA 算法成功。

实现

可以使用 Apache Commons Math 库中计算特征值和特征向量的方法轻松实现主成分分析。PCA类定义为ITransform类型的数据转换,如第二章中单调数据转换部分所述,第二章,Hello World!

PCA类有一个单一参数:xt训练集(第 1 行)。输出类型为Double,用于投影观测值和特征向量(第 2 行)。构造函数定义了 Z 分数norm函数(第 3 行):

class PCA@specialized(Double) T <: AnyVal(implicit f: T => Double) 
  extends ITransform[Array[T]](xt) with Monitor[T] { //1
  type V = Double    //2

  val norm = (xv: XVSeries[T]) =>  zScores(xv) //3
  val model: Option[PCAModel] = train //4
  override def |> : PartialFunction[U, Try[V]]
}

PCA 算法的模型由PCAModel案例类定义(第 4 行)。

注意

触发隐式转换

可以通过向完全声明的变量赋值来调用隐式转换。例如,通过声明目标变量的类型来调用从XVSeries[T]XVseries[Double]的转换:val z: XVSeries[Double] = xv(第 4 行)。

PCA 算法的模型PCAModel由公式M11中定义的协方差矩阵covariance和公式M16中计算的特征值数组组成:

case class PCAModel(val covariance: DblMatrix, 
   val eigenvalues: DblArray)

|>转换方法实现了主成分的计算(即特征向量和特征值):

def train: Option[PCAModel] = zScores(xt).map(x => { //5
  val obs: DblMatrix = x.toArray
  val cov = new Covariance(obs).getCovarianceMatrix //6

  val transform = new EigenDecomposition(cov) //7
  val eigenVectors = transform.getV  //8
    val eigenValues = 
           new ArrayRealVector(transform.getRealEigenvalues)

  val covariance = obs.multiply(eigenVectors).getData  //9
  PCAModel(covariance, eigenValues.toArray)   //10
}) match {/* … */}

标准化函数zScores执行 Z 分数转换(公式M12)(第 5 行)。接下来,该方法从标准化数据中计算协方差矩阵(第 6 行)。计算了特征向量eigenVectors(第 7 行),然后使用 Apache Commons Math EigenDecomposition类的getV方法检索(第 8 行)。该方法从特征向量计算对角线转换后的协方差矩阵(第 9 行)。最后,数据转换返回 PCA 模型的一个实例(第 10 行)。

|>预测方法包括将观测值投影到主成分上:

override def |> : PartialFunction[Array[T], Try[V]] = {
  case x: Array[T] 
     if(isModel && x.length == dimension(xt)) => 
        Try( inner(x, model.get.eigenvalues) )
}

XTSeries对象的inner方法计算值x和特征值模型的点积。

测试用例

让我们将 PCA 算法应用于提取代表 34 家标准普尔 500 公司部分财务指标比率的特征子集。考虑的指标如下:

  • 跟踪市盈率(PE)

  • 市销率(PS)

  • 市净率(PB)

  • 净资产收益率(ROE)

  • 运营利润率(OM)

财务指标在附录 A 中金融 101术语部分进行描述,基本概念

输入数据以以下格式指定为一个元组(一个股票代码和一个包含五个财务比率(市盈率 PE、市销率 PS、市净率 PB、净资产收益率 ROE 和运营利润率 OM)的数组):

val data = Array[(String, DblVector)] (
  // Ticker              PE     PS     PB   ROE    OM
  ("QCOM", ArrayDouble),
  ("IBM",  ArrayDouble), 
   …
)

执行 PCA 算法的客户端代码定义如下:

val dim = data.head._2.size
val input = data.map( _._2.take(dim)) 
val pca = new PCADouble //11
show(s"PCA model: ${pca.toString}")  //12

使用input数据实例化 PCA,然后以文本格式显示(第 12 行)。

评估

对 34 个财务比率的第一项测试使用了一个具有五个维度的模型。正如预期的那样,算法产生了一个五个有序的特征值列表:

2.5321, 1.0350, 0.7438, 0.5218, 0.3284

让我们在下面的柱状图中绘制特征值的相对值(即每个特征的相对重要性):

评估

PCA 中 5 维度的特征值分布

图显示,五个特征中有三个特征占到了总方差的 85%(变换后协方差矩阵的迹)。我邀请你尝试这些特征的不同组合。选择现有特征子集的操作就像应用 Scala 的takedrop方法一样简单:

data.map( _._2.take(dim))

让我们绘制三个不同模型配置的累积特征值:

  • 五个特征:PE、PS、PB、ROE 和 OM

  • 四个特征:PE、PS、PB 和 ROE

  • 三个特征:PE、PS 和 PB

图将如下所示:

评估

PCA 中 3、4 和 5 个特征的特征值分布

图显示了变换后特征的方差的特征值的累积值。如果我们对累积方差应用 90%的阈值,那么每个测试模型的特征值数量如下:

  • {PE, PS, PB}: 2

  • {PE, PS, PB, ROE}: 3

  • {PE, PS, PB, ROE, OM}: 3

总之,PCA 算法将三特征模型降低了 33%,四特征模型降低了 25%,五特征模型降低了 40%,阈值为 90%。

注意

PCA 交叉验证

像任何其他无监督学习技术一样,得到的特征值必须通过使用回归估计量(如偏最小二乘回归(PLSR)或预测残差误差平方和(PRESS))进行一或 K 折交叉验证的方法来验证。对于那些不怕统计的人来说,我建议你阅读 S. Engelen 和 M. Hubert 的《快速鲁棒 PCA 交叉验证》 [4:14]。你需要意识到这些回归估计量的实现并不简单。PCA 的验证超出了本书的范围。

主成分分析是更一般性因子分析的一个特例。后者类别的算法不需要将协方差矩阵转换成正交矩阵。

非线性模型

主成分分析技术要求模型是线性的。尽管这类算法的研究超出了本书的范围,但值得提及两种扩展 PCA 以用于非线性模型的方法:

  • 核 PCA

  • 流形学习

核 PCA

主成分分析(PCA)从一组相关值数组中提取一组正交线性投影,X = {x[i]}。核 PCA 算法包括从内积矩阵X^TX中提取相似的正交投影集。通过应用核函数到内积,支持非线性。核函数在第八章的核函数部分中描述,核模型和支持向量机。核 PCA 是尝试从原始观测空间中提取一个低维特征集(或流形)。线性 PCA 是流形切空间上的投影。

流形

流形的概念借鉴自微分几何。流形将二维空间中的曲线或三维空间中的曲面推广到更高维度。非线性模型与度量是切空间上的内积X^TX的黎曼流形相关联。流形代表一个低维特征空间,它嵌入到原始观测空间中。想法是使用指数映射将线性切空间的主成分投影到流形上。这一壮举是通过使用从局部线性嵌入和密度保持映射到拉普拉斯特征映射等众多技术实现的 [4:15]。

观测向量不能直接用于流形上,因为诸如范数或内积等度量依赖于向量应用到的流形的定位。在流形上的计算依赖于诸如逆变和协变向量之类的张量。张量代数由协变和逆变函子支持,这在第一章的抽象部分中介绍,入门

使用可微流形的技巧被称为谱降维。

注意

替代降维技术

这里有一些更多的替代技术,列作参考文献:因子分析、主因子分析、最大似然因子分析、独立成分分析、随机投影、非线性 ICA、Kohonen 的自组织映射、神经网络和多维尺度,仅举几例 [4:16]。

流形学习方法,如分类器和降维技术,与半监督学习相关联。

性能考虑

这三种无监督学习技术具有相同的局限性——高计算复杂度。

K-means

K-means 的计算复杂度为O(iKnm),其中i是迭代次数(或递归次数),K是簇的数量,n是观测的数量,m是特征的数量。以下是针对 K-means 算法性能不佳的一些补救措施:

  • 通过使用如本章开头所述的通过对初始簇的方差进行排序的初始化技术来播种质心,以减少平均迭代次数

  • 利用 Hadoop 或 Spark 等大规模框架并行实现 K-means

  • 通过使用离散傅里叶变换或卡尔曼滤波等平滑算法过滤噪声来减少异常值和特征的数量

  • 通过两步过程降低模型的维度:

    1. 使用较少的聚类数K和/或关于数据点重新分配的宽松退出条件执行第一次遍历。接近每个质心的数据点被汇总为一个单独的观察结果。

    2. 对汇总观察结果执行第二次遍历。

EM

期望最大化算法每轮迭代(E + M 步骤)的计算复杂度为O(m²n),其中m是隐藏或潜在变量的数量,n是观察数。

建议的性能改进的部分列表包括:

  • 对原始数据进行过滤以去除噪声和异常值

  • 在大型特征集上使用稀疏矩阵以降低协方差矩阵的复杂性,如果可能的话

  • 在可能的情况下应用高斯混合模型——高斯分布的假设简化了似然对数的计算

  • 使用 Apache Hadoop 或 Spark 等并行数据处理框架,如第十二章中Apache Spark部分所述,可扩展框架

  • 在 E 步骤中使用核函数来降低协方差的估计

PCA

提取主成分的计算复杂度为O(m²n + n³),其中m是特征数,n是观察数。第一项表示计算协方差矩阵的计算复杂度。第二项反映了特征值分解的计算复杂度。

PCA 的潜在性能改进或替代解决方案的列表包括以下内容:

  • 假设方差是高斯分布

  • 使用稀疏矩阵计算具有大型特征集和缺失数据的特征值

  • 调查 PCA 的替代方案以降低模型维度,例如离散傅里叶变换DFT)或奇异值分解SVD)[4:17]

  • 在研究阶段结合 PCA 和 EM(期望最大化算法)

  • 在 Apache Hadoop 或 Spark 等并行数据处理框架上部署数据集,如第十二章中Apache Spark部分所述,可扩展框架

摘要

这完成了对三种最常用的无监督学习技术的概述:

  • K-means 用于聚类具有合理维度的模型的全观测特征

  • 预测聚类中观察到的和潜在特征的期望最大化

  • 主成分分析用于转换和提取线性模型中关于方差的最重要的特征

非线性模型的多维学习是一个技术上有挑战性但具有巨大潜力的领域,特别是在动态对象识别方面[4:18]。

要记住的关键点是,无监督学习技术被用于:

  • 通过自身从未标记的观察中提取结构和关联

  • 作为监督学习在训练阶段之前减少特征数量的预处理阶段

无监督学习和监督学习之间的区别并不像你想象的那样严格。例如,K-means 算法可以被增强以支持分类。

在下一章中,我们将解决第二个用例,并从生成模型开始介绍监督学习技术。

第五章。朴素贝叶斯分类器

本章介绍了最常见和最简单的生成分类器——朴素贝叶斯。如前所述,生成分类器是监督学习算法,试图拟合两个事件 X 和 Y 的联合概率分布 p(X,Y),代表两组观察到的和隐藏(或潜在)变量 x 和 y。

在本章中,你将学习,并希望你能欣赏,通过一个具体的例子来展示朴素贝叶斯技术的简单性。然后,你将学习如何构建一个朴素贝叶斯分类器,根据金融市场分析中的一些先验技术指标来预测股价走势。

最后,你将学习如何通过使用金融新闻和新闻稿来预测股价,将朴素贝叶斯应用于文本挖掘。

概率图模型

让我们从基础统计学复习开始。

给定两个事件或观察 X 和 Y,X 和 Y 的联合概率定义为p(X,Y) = p(X∩Y)。如果观察 X 和 Y 不相关,那么有一个被称为条件独立性的假设,则p(X,Y) = p(X).p(Y)。事件 Y 在 X 条件下的条件概率定义为p(Y|X) = p(X,Y)/p(X)

这两个定义相当简单。然而,在大量变量和条件概率序列的情况下,概率推理可能难以阅读。正如一图胜千言,研究人员引入了图模型来描述随机变量之间的概率关系[5:1]。

有两种类型的图,因此,图模型,如下所示:

  • 如贝叶斯网络这样的有向图

  • 如条件随机字段(参考第七章中的条件随机字段部分,序列数据模型),无向图

有向图模型是有向无环图,已被引入用于:

  • 提供一种简单的方式来可视化概率模型

  • 描述变量之间的条件依赖关系

  • 用图形对象之间的连接来表示统计推理

贝叶斯网络是一种有向图模型,它定义了一组变量上的联合概率[5:2]。

两个联合概率 p(X,Y)p(X,Y,Z) 可以使用贝叶斯网络进行图形建模,如下所示:

概率图形模型

概率图形模型的例子

条件概率 p(Y|X) 由从输出(或症状)Y 到输入(或原因)X 的箭头表示。详细模型可以描述为变量之间的大型有向图。

图形模型的隐喻

从软件工程的角度来看,图形模型以与 UML 类图以相同的方式可视化概率方程。

这里是一个真实世界贝叶斯网络的例子;烟雾探测器的功能:

  1. 火灾可能会产生烟雾。

  2. 烟雾可能会触发警报。

  3. 空电池可能会触发警报。

  4. 警报可能会通知房主。

  5. 警报可能会通知消防部门。

流程图如下:

概率图形模型

烟雾探测器的贝叶斯网络

这种表示可能有点反直觉,因为顶点是从症状(或输出)指向原因(或输入)。除了贝叶斯网络外,有向图模型还用于许多不同的模型[5:3]。

板模型

除了有向无环图之外,概率模型还有几种不同的表示方式,例如常用的潜在狄利克雷分配LDA)的板模型[5:4]。

朴素贝叶斯模型是基于贝叶斯定理的概率模型,在特征独立性的假设下,如第一章中“监督学习”下的“生成模型”部分所述,入门

朴素贝叶斯分类器

X 特征之间的条件独立性是朴素贝叶斯分类器的一个基本要求。它也限制了其适用性。通过简单和具体的例子可以更好地理解朴素贝叶斯分类[5:5]。

介绍多项式朴素贝叶斯

让我们考虑如何预测利率变化的问题。

第一步是列出可能触发或导致利率增加或减少的因素。为了说明朴素贝叶斯,我们将选择消费者价格指数CPI)和改变联邦基金利率FDF)以及国内生产总值GDP),作为第一组特征。术语在附录 A 中“金融 101”下的“术语”部分进行了描述,基本概念

该用例用于预测 1 年期国债收益率(1yTB)的变化方向,同时考虑当前 CPI、FDF 和 GDP 的变化。因此,目标是创建一个预测模型,结合这三个特征。

假设没有可用的金融投资专家能够提供规则或政策来预测利率。因此,该模型高度依赖于历史数据。直观上,如果一个特征在 1 年期国债收益率增加时总是增加,那么我们可以得出结论,这些特征与利率输出变化之间存在强烈的因果关系。

介绍多项式朴素贝叶斯

预测 1 年期 T-bill 收益率变化的朴素贝叶斯模型

相关性(或因果关系)是从历史数据中得出的。该方法包括计算每个特征增加(上升)或减少(下降)的次数,并记录相应的预期结果,如下表所示:

ID GDP FDF CPI 1y-TB
1 上升 下降 上升 上升
2 上升 上升 上升 上升
3 下降 上升 下降 下降
4 上升 下降 下降 下降
256 下降 下降 上升 下降

首先,让我们列出三个特征和输出值(1 年期国债收益率变化的方向)的每个变化(上升下降)的次数:

数量 GDP FDF CPI 1yTB
上升 169 184 175 159
下降 97 72 81 97
总计 256 256 256 256
总/上升 0.66 0.72 0.68 0.625

接下来,让我们计算当 1 年期国债收益率增加时(共发生 159 次)每个特征的正面方向数量:

数量 GDP 联邦基金 CPI
上升 110 136 127
下降 49 23 32
总计 159 159 159
总/上升 0.69 0.85 0.80

从前表可以得出结论,当 GDP 增加时(69%的时间),联邦基金利率增加(85%的时间),以及 CPI 增加(80%的时间),1 年期国债收益率会增加。

在将这些发现转化为概率模型之前,让我们将朴素贝叶斯模型形式化。

形式化

让我们先明确贝叶斯模型中使用的术语:

  • 类先验概率或类先验:这是一个类的概率

  • 似然:这是在给定一个类的情况下观察到一个值或事件的概率,也称为给定一个类的预测概率

  • 证据:这是发生观察到的概率,也称为预测者的先验概率

  • 后验概率:这是观测 x 在给定类别中的概率

没有模型可以比这更简单!通常使用对数似然率 log p(x[i]|C[j]) 而不是似然率,以减少具有低似然率的特征 x[i] 的影响。

新观测的朴素贝叶斯分类的目标是计算具有最高对数似然率的类别。朴素贝叶斯模型的数学符号也很简单。

注意

朴素贝叶斯分类

M1:后验概率 p(C[j]|x) 定义为:

形式化

在这里,x = {x[i]} (0, n-1) 是一组 n 个特征。{C[j]} 是一组具有其类别先验 p(C[j]) 的类别。x = {x[i]} (0, n-1) 是一组具有 n 个特征的。p(x|C[j]) 是每个特征的概率

M2:通过假设特征的条件独立性,后验概率 p(C[j]| x) 的计算被简化。

形式化

在这里,x[i] 是独立的,并且概率已归一化,证据 p(x) = 1

M3:最大似然估计MLE)定义为:

形式化

M4:类别 C[m] 的观测 x 的朴素贝叶斯分类定义为:

形式化

这个特定的用例有一个主要的缺点——GDP 统计数据是按季度提供的,而 CPI 数据每月提供一次,FDF 利率的变化相对较少。

经验主义视角

计算后验概率的能力取决于使用历史数据对似然率的公式化。一个简单的解决方案是计算每个类别的观测发生次数并计算频率。

让我们考虑第一个例子,预测 1 年期国债收益率的变动方向,给定 GDP、FDF 和 CPI 的变化。

结果用简单的概率公式和有向图模型表示:

P(GDP=UP|1yTB=UP) = 110/159
P(1yTB=UP) = num occurrences (1yTB=UP)/total num of occurrences=159/256
p(1yTB=UP|GDP=UP,FDF=UP,CPI=UP) = p(GDP=UP|1yTB=UP) x
                                  p(FDF=UP|1yTB=UP) x
                                  p(CPI=UP|1yTB=UP) x
                                  p(1yTB=UP) = 0.69 x 0.85 x 0.80 x 0.625

经验主义视角

预测 1 年期国债收益率变化的贝叶斯网络

注意

过拟合

如果观测的数量相对于特征的数量不足,朴素贝叶斯模型不会免疫过拟合。解决这个问题的方法之一是使用互信息排除法进行特征选择[5:6]。

这个问题不适合作为贝叶斯分类的候选者,以下有两个原因:

  • 训练集不够大,无法计算准确的先验概率并生成稳定的模型。需要数十年的季度 GDP 数据来训练和验证模型。

  • 特征有不同的变化率,这主要有利于频率最高的特征;在这种情况下,是 CPI。

让我们选择另一个具有大量可用历史数据并可自动标记的用例。

预测模型

预测模型是第二个用例,它包括预测股票收盘价变化方向,pr[t+1] = {UP, DOWN},在交易日 t + 1,给定其价格、成交量、波动率的历史记录,pr[i] 对于 i = 0i = t。体积和波动率特征已在 第一章 的 编写简单工作流 部分中使用,入门

因此,考虑的三个特征如下:

  • 最后一个交易会话,t,的收盘价 pr[t] 高于或低于过去 n 个交易日平均收盘价,[t-n, t]

  • 最后一个交易日的成交量 vl[t] 高于或低于过去 n 个交易日平均成交量。

  • 最后一个交易日的波动率 vt[t] 高于或低于过去 n 个交易日平均波动率。

有向图形模型可以使用一个输出变量(第 t + 1 个会话的价格高于第 t 个会话的价格)和三个特征来表示:价格条件(1)、成交量条件(2)和波动率条件(3)。

预测模型

预测股价未来方向的贝叶斯模型

该模型基于以下假设:每个特征和期望值至少有一个观测值,理想情况下观测值很少。

零频率问题

可能训练集不包含任何特定标签或类别的特征的实际观测数据。在这种情况下,平均值为 0/N = 0,因此,似然值为空,使得分类不可行。对于给定类别中特征观测值很少的情况也是一个问题,因为它会扭曲似然值。

对于未观测特征或出现次数较少的特征,存在一些校正或平滑公式来解决这个问题,例如拉普拉斯利德斯顿平滑公式。

注意

计数器的平滑因子

M5:维度为 n 的特征 N 个观测值中平均 k/N 的拉普拉斯平滑公式定义为:

零频率问题

M6:具有因子 α 的利德斯顿平滑公式定义为:

零频率问题

这两个公式在自然语言处理应用中常用,其中特定单词或标签的出现是一个特征 [5:7]。

实现

我认为现在是时候写一些 Scala 代码,并尝试使用朴素贝叶斯。让我们从软件组件的概述开始。

设计

我们实现的朴素贝叶斯分类器使用了以下组件:

  • Model 类型的通用模型,NaiveBayesModel,在类实例化过程中通过训练初始化。

  • BinNaiveBayesModel二项式分类模型的模型,它是NaiveBayesModel的子类。该模型由一对正负Likelihood类实例组成。

  • MultiNaiveBayesModel多项式分类的模型。

  • NaiveBayes分类器类有四个参数:一个平滑函数,如拉普拉斯,以及一组XVSeries类型的观察值,一组DblVector类型的标签,一个LogDensity类型的对数密度函数,以及类的数量。

应用于分类器实现的软件架构原则在附录 A 的不可变分类器设计模板部分中描述,基本概念

Naïve Bayes 分类器的关键软件组件在以下 UML 类图中描述:

设计

Naïve Bayes 分类器的 UML 类图

UML 图省略了Monitor或 Apache Commons Math 组件等辅助特性或类。

训练

训练阶段的目标是构建一个模型,该模型由每个特征的似然和类先验组成。一个特征的似然被识别如下:

  • 对于二进制特征或计数器,在N > k个观察值中此特征的出现次数k

  • 对于数值或连续特征,此特征的所有观察值的平均值

为了这个测试案例,假设特征,即技术分析指标,如价格、体积和波动性是条件独立的。这个假设实际上并不正确。

备注

条件依赖

近期模型,被称为隐式朴素贝叶斯HNB),放宽了特征之间独立性的限制。HNB 算法使用条件互信息来描述某些特征之间的相互依赖性[5:8]。

让我们编写代码来训练二项式和多项式朴素贝叶斯。

类似然

第一步是使用历史数据为每个特征定义类似然。Likelihood类具有以下属性(行1):

  • label观察值的标签

  • 元组拉普拉斯或 Lidstone 平滑均值和标准差的数组,muSigma

  • prior类的先验概率

与本书中展示的任何代码片段一样,为了保持代码的可读性,省略了类参数和方法参数的验证:

class LikelihoodT <: AnyVal(implicit f: T => Double)  {  //1

  def score(obs: Array[T], logDensity: LogDensity): Double = //2
    (obs, muSigma).zipped
     .map{ case(x, (mu,sig)) => (x, mu, sig)}
     ./:(0.0)((prob, entry) => {
         val x = entry._1
         val mean = entry._2
         val stdDev = entry._3
         val logLikelihood = logDensity(mean, stdDev, x) //3
         val adjLogLikelihood = if(logLikelihood <MINLOGARG)
                     MINLOGVALUE else logLikelihood)
         prob + Math.log(adjLogLikelihood) //4
   }) + Math.log(prior)
}

参数化的Likelihood类有两个目的:

  • 定义关于类C[k]的统计信息:其标签、其均值和标准差,以及先验概率p(C[k])

  • 计算新观察值的分数以进行其运行时分类(第 2 行)。计算似然的对数使用 LogDensity 类的 logDensity 方法(第 3 行)。如下一节所示,对数密度可以是高斯分布或伯努利分布。score 方法使用 Scala 的 zipped 方法将观察值与标记值合并,并实现了 M3 公式(第 4 行)。

高斯混合模型特别适合于建模特征具有大量离散值或连续变量的数据集。特征 x 的条件概率由正态概率密度函数 [5:9] 描述。

注意

使用高斯密度的对数似然

M7:对于 Lidstone 或 Laplace 平滑的均值 µ 和标准差 σ,高斯分布的后验概率的对数似然定义为:

类别似然

高斯的对数 logGauss 和正态分布的对数 logNormalstats 类中定义,该类在 第二章 的 数据概要 节中介绍,Hello World!

def logGauss(mean: Double, stdDev: Double, x: Double): Double ={
  val y = (x - mean)/stdDev
  -LOG_2PI - Math.log(stdDev) - 0.5*y*y
}
val logNormal = logGauss(0.0, 1.0, _: Double)

logNormal 的计算实现为一个部分应用函数。

LogDensity 类的函数计算每个特征的概率密度(第 5 行):

type LogDensity = (Double*) => Double

二项式模型

下一步是定义用于双类分类方案的 BinNaiveBayesModel 模型。双类模型由两个 Likelihood 实例组成:positives 用于标签 UP(值 1)和 negatives 用于标签 DOWN(值 0)。

为了使模型通用,我们创建了一个 NaiveBayesModel 特质,可以根据需要扩展以支持二项式和多项式朴素贝叶斯模型,如下所示:

trait NaiveBayesModel[T]  {
  def classify(x: Array[T], logDensity: LogDensity): Int //5
}

classify 方法使用训练好的模型对给定 logDensity 概率密度函数的 Array[T] 类型的多变量观察值 x 进行分类(第 5 行)。该方法返回观察值所属的类别。

让我们从实现二项式朴素贝叶斯 BinNaiveBayesModel 类的定义开始:

class BinNaiveBayesModelT <: AnyVal(implicit f: T => Double)
  extends NaiveBayesModel[T] { //6

  override def classify(x: Array[T], logDensity: logDensity): Int = //7
   if(pos.score(x,density) > neg.score(x,density)) 1 else 0
  ...
}

BinNaiveBayesModel 类的构造函数接受两个参数:

  • pos: 正面结果的观察值的类别似然

  • neg: 负面结果的观察值的类别似然(第 6 行)

classify 方法在朴素贝叶斯分类器中通过 |> 操作符被调用。如果观察值 x 与包含正例的 Likelihood 类匹配,则返回 1,否则返回 0(第 7 行)。

注意

模型验证

Naïve Bayes 模型(似然)的参数通过训练计算得出,无论在此示例中模型是否实际经过验证,model值都会实例化。商业应用需要使用如 K 折验证和 F1 度量等方法验证模型,如附录 A 中“不可变分类器设计模板”部分所述,基本概念

多项式模型

MultiNaiveBayesModel类定义的多项式朴素贝叶斯模型与BinNaiveBayesModel非常相似:

class MultiNaiveBayesModel[T <: AnyVal(   //8
   likelihoodSet: Seq[Likelihood[T]])(implicit f: T => Double)
  extends NaiveBayesModel[T]{

  override def classify(x: Array[T], logDensity: LogDensity): Int = {
    val <<< = (p1: Likelihood[T], p2: Likelihood[T]) => 
              p1.score(x, density) > p1.score(x, density) //9
    likelihoodSet.sortWith(<<<).head.label  //10
  }
  ...
}

多项式朴素贝叶斯模型与其二项式对应模型的不同之处如下:

  • 其构造函数需要一个类似然序列likelihoodSet(第 8 行)。

  • classify运行时分类方法使用<<<函数按分数(后验概率)对类似然进行排序(第 9 行)。该方法返回具有最高对数似然值的类 ID(第 10 行)。

分类器组件

朴素贝叶斯算法作为数据转换实现,使用从ITransform类型训练集隐式提取的模型,如第二章中“单态数据转换”部分所述,Hello World!

多项式朴素贝叶斯属性如下:

  • 平滑公式(拉普拉斯、Lidstone 等),smoothing

  • 定义为xt的多变量观测值集合

  • 与一组观测值相关联的期望值(或标签),expected

  • 概率密度函数的对数,logDensity

  • 类的数量——对于二项式朴素贝叶斯和BinNaiveBayesModel类型为两个,对于多项式朴素贝叶斯和MultiNaiveBayesModel类型则更多(第 11 行)

代码如下:

class NaiveBayesT <: AnyVal(implicit f: T => Double) 
  extends ITransform[Array[T]](xt) 
    with Supervised[T, Array[T]] with Monitor[Double] { //11
  type V = Int  //12
  val model: Option[NaiveBayesModel[T]]  //13
  def train(expected: Int): Likelihood[T]
  …
}

Monitor特质定义了各种日志和显示函数。

ITransform类型的数据转换需要指定输出类型V(第 12 行)。朴素贝叶斯输出是观测值所属类的索引。模型的model类型可以是BinNaiveBayesModel(用于两个类)或MultiNaiveBayesModel(用于多项式模型)(第 13 行):

val model: Option[NaiveBayesModel[T]] = Try {
  if(classes == 2) 
    BinNaiveBayesModelT, train(0))
  else 
    MultiNaiveBayesModelT( train(_)))
} match {
  case Success(_model) => Some(_model)
  case Failure(e) => /* … */
}

注意

训练和类实例化

允许在训练时仅实例化一次 Naïve Bayes 模型具有几个优点。它防止客户端代码在未训练或部分训练的模型上调用算法,并减少了模型的状态数量(未训练、部分训练、训练、验证等)。这是一种优雅的方式,将模型训练的细节隐藏给用户。

train方法应用于每个类。它接受类的索引或标签,并生成其对数似然数据(第 14 行):

def train(index: Int): Likelihood[T] = {   //14
  val xv: XVSeries[Double] = xt
  val values = xv.zip(expected) //15
               .filter( _._2 == index).map(_._1) //16
  if( values.isEmpty )
     throw new IllegalStateException( /* ... */)

  val dim = dimension(xt)
  val meanStd = statistics(values).map(stat => 
         (stat.lidstoneMean(smoothing, dim), stat.stdDev)) //17
  Likelihood(index, meanStd, values.size.toDouble/xv.size) //18
}

训练集通过将观测的 xt 向量与预期类别 expected 进行压缩生成(行 15)。该方法过滤掉标签与该类别不对应的观测(行 16)。使用 Lidstone 平滑因子计算 meanStd 对(均值和标准差)(行 17)。最后,训练方法返回与索引 label 对应的类别似然(行 18)。

NaiveBayes 类还定义了 |> 运行时分类方法和 F[1] 验证方法。这两种方法将在下一节中描述。

注意

处理缺失数据

朴素贝叶斯在处理缺失数据方面采用无废话的方法。你只需忽略观测中值缺失的属性。在这种情况下,这些观测的特定属性的先验不进行计算。这种解决方案显然是由于特征之间的条件独立性而成为可能的。

NaiveBayesapply 构造函数返回 NaiveBayes 类型:

object NaiveBayes {
  def applyT <: AnyVal (implicit f: T => Double): NaiveBayes[T] = 
    new NaiveBayesT
   …
}

分类

通过训练计算出的似然和类先验用于验证模型和分类新观测。

该分数代表似然估计的对数(或后验概率),它是通过使用从训练阶段提取的均值和标准差计算的高斯分布的对数之和以及类似然的对数。

使用高斯分布的朴素贝叶斯分类通过两个 C[1]C[2] 类别以及具有两个特征 (xy) 的模型进行说明:

分类

使用二维模型说明高斯朴素贝叶斯

|> 方法返回部分函数,该函数使用其中一个朴素贝叶斯模型对新的 x 观测进行运行时分类。modellogDensity 函数用于将 x 观测分配到适当的类别(行 19):

override def |> : PartialFunction[Array[T], Try[V]] = {
  case x: Array[T] if(x.length >0 && model != None) => 
    Try( model.map(_.classify(x, logDensity)).get)  //19
}

F1 验证

最后,朴素贝叶斯分类器通过 NaiveBayes 类实现。它使用朴素贝叶斯公式实现训练和运行时分类。为了强制开发者为任何新的监督学习技术定义验证,该类从声明 validate 验证方法的 Supervised 特性继承:

trait Supervised[T, V] {
  self: ITransform[V] =>  //20
    def validate(xt: XVSeries[T], 
      expected: Vector[V]): Try[Double]  //21
}

模型的验证仅适用于 ITransform 类型的数据转换(行 20)。

validate 方法接受以下参数(行 21):

  • 多维观测的 xt 时间序列

  • 预期类值的 expected 向量

默认情况下,validate 方法返回模型的 F[1] 分数,如 第二章 中 评估模型 一节所述,Hello World!

让我们实现朴素贝叶斯分类器的 Supervised 特性的关键功能:

override def validate(
    xt: XVSeries[T], 
    expected: Vector[V]): Try[Double] =  Try {   //22
  val predict = model.get.classify(_:Array[Int],logDensity) //23
  MultiFValidation(expected, xt, classes)(predict).score  //24
}

通过将预测类别分配给新的x观测值(第23行),创建预测predict部分应用函数,然后将预测和类别的索引加载到MultiFValidation类中,以计算 F[1]分数(第24行)。

特征提取

监督学习算法训练中最关键的因素是创建标记数据。幸运的是,在这种情况下,标签(或预期类别)可以自动生成。目标是预测下一个交易日股票价格的方向,考虑到过去n天的移动平均价格、成交量和波动性。

特征提取遵循以下六个步骤:

  1. 提取每个特征的历 史交易数据(即价格、成交量波动性)。

  2. 计算每个特征的简单移动平均。

  3. 计算每个特征的值和移动平均之间的差异。

  4. 通过将正值赋值为 1,负值赋值为 0 来归一化差异。

  5. 生成股票收盘价与前一交易日收盘价之间的差异的时间序列。

  6. 通过将正值赋值为 1,负值赋值为 0 来归一化差异。

以下图表说明了步骤 1 到 4 的特征提取:

特征提取

差异值的二进制量化——移动平均

第一步是提取每个股票在 2000 年 1 月 1 日至 2014 年 12 月 31 日期间的平均价格、成交量波动性(即1 – 低/高),使用每日和每周收盘价。让我们使用简单移动平均来计算[t-n, t]窗口的平均值。

extractor变量定义了从金融数据源中提取的特征列表,如附录 A 中数据提取数据源部分所述,基本概念

val extractor = toDouble(CLOSE)  // stock closing price
               :: ratio(HIGH, LOW) //volatility(HIGH-LOW)/HIGH
               :: toDouble(VOLUME)  // daily stock trading volume
               :: List[Array[String] =>Double]()

交易数据和度量名称约定在附录 A 中技术分析下的交易数据部分描述,基本概念

使用单子组合实现二项朴素贝叶斯的训练和验证:

val trainRatio = 0.8  //25
val period = 4
val symbol ="IBM"
val path = "resources/chap5"
val pfnMv = SimpleMovingAverageDouble |> //26
val pfnSrc = DataSource(symbol, path, true, 1) |>  //27

for {
  obs <- pfnSrc(extractor)  //28
  (x, delta) <- computeDeltas(obs) //29
  expected <- Try{difference(x.head.toVector, diffInt)}//30
  features <- Try { transpose(delta) } //31
  labeled <- //32
     OneFoldXValidationInt 
  nb <- NaiveBayesInt //33
  f1Score <- nb.validate(labeled.validationSet) //34
}
yield {
  val labels = ArrayString
  show(s"\nModel: ${nb.toString(labels)}")
}

第一步是将观测值分布在训练集和验证集之间。trainRatio值(第25行)定义了要包含在训练集中的原始观测集的比例。简单移动平均值由pfnMv部分函数(第26行)生成。用于生成三个交易时间序列(价格、波动性和成交量)的提取pfnSrc部分函数(第27行)。

下一步是应用pfnMv简单移动平均到obs多维时间序列(第29行),使用computeRatios方法:

type LabeledPairs = (XVSeries[Double], Vector[Array[Int]])

def computeDeltas(obs: XVSeries[Double]): Try[LabeledPairs] =
Try{
  val sm = obs.map(_.toVector).map( pfnMv(_).get.toArray) //35
  val x = obs.map(_.drop(period-1) )
  (x, x.zip(sm).map{ case(x,y) => x.zip(y).map(delta(_)) })//36
}

computeDeltas方法计算了使用简单移动平均平滑的sm观察值的时序(第 35 行)。该方法为xs观察集和sm平滑数据集中的每个三个特征生成一个 0 和 1 的时序(第 36 行)。

接下来,调用difference微分计算生成标签(0 和 1),代表在连续两个交易日之间证券价格的变动方向:如果价格下降则为 0,如果价格上升则为 1(第 30 行)(参考第三章中Scala 中的时间序列部分的微分算子数据预处理)。

Naïve Bayes 模型的训练特征是从这些比率中提取的,通过在XTSeries单例的transpose方法中转置比率-时间序列矩阵(第 31 行)。

接下来,使用OneFoldXValidation类从features集中提取训练集和验证集,该类在第二章中交叉验证部分的单折交叉验证中介绍,Hello World!(第 32 行)。

备注

选择训练数据

在我们的例子中,训练集是简单地第一个trainRatio乘以数据集观察的大小。实际应用使用 K 折交叉验证技术来验证模型,如第二章中评估模型部分的K 折交叉验证中所述,Hello World!。一个更简单的替代方案是通过随机选择观察值来创建训练集,并使用剩余的数据进行验证。

工作流程的最后两个阶段包括通过实例化NaiveBayes类(第 33 行)并计算应用于股票价格、波动性和成交量的简单移动平均平滑系数的不同值的 F[1]分数来训练 Naive Bayes 模型(第 34 行)。

备注

隐式转换

NaiveBayes类操作IntDouble类型的元素,因此假设存在IntDouble之间的转换(查看有界)。Scala 编译器可能会生成警告,因为IntDouble的转换尚未定义。尽管 Scala 依赖于其自己的转换函数,但我建议您显式定义并控制您的转换函数:

implicit def intToDouble(n: Int): Double = n.toDouble

测试

下一个图表绘制了使用价格、成交量波动以及过去 n 个交易日的数据,来预测 IBM 股票方向的预测器 F[1]值的图表,其中 n 的范围从 1 到 12 个交易日:

测试

Naive Bayes 模型的验证 F1 度量图

前面的图表说明了平均周期(交易天数)的值对多项式朴素贝叶斯预测质量的影响,使用的是股票价格、波动性和成交量相对于平均周期的值。

从这个实验中,我们得出以下结论:

  • 使用平均价格、成交量波动预测股票走势并不很好。使用每周(相对于每日)收盘价的模型 F[1]分数在 0.68 到 0.74 之间变化(相对于 0.56 和 0.66)。

  • 使用每周收盘价进行的预测比使用每日收盘价进行的预测更准确。在这个特定的例子中,每周收盘价的分布比每日价格的分布更能反映中期趋势。

  • 预测与用于平均特征的周期多少有些独立。

多变量伯努利分类

之前的例子使用高斯分布来表示本质上为二元的特征(UP = 1DOWN = 0),以表示价值的变动。均值计算为x[i] = UP的观察次数与总观察次数的比率。

如第一部分所述,高斯分布对于非常大的标记数据集的连续特征或二元特征更为合适。该示例是伯努利模型的理想候选。

模型

伯努利模型与朴素贝叶斯分类器不同,它对没有观察到的特征x进行惩罚;朴素贝叶斯分类器忽略它[5:10]。

注意

伯努利混合模型

M8:对于特征函数f[k],当f[k]等于 1 时,如果特征被观察到,否则为 0,观察到特征x[k]属于类别C[j]的概率p,后验概率计算如下:

模型

实现

伯努利模型的实现包括在Likelihood类中使用伯努利密度方法bernoulli修改score函数,该方法定义在Stats对象中:

object Stats {
  def bernoulli(mean: Double, p: Int): Double = 
     mean*p + (1-mean)*(1-p)
def bernoulli(x: Double*): Double = bernoulli(x(0), x(1).toInt)
…

伯努利算法的第一个版本是直接实现M8数学公式。第二个版本使用Density (Double*) => Double类型的签名。

均值与高斯密度函数中的相同。二元特征以Int类型实现,值为UP = 1(相对于DOWN = 0),表示向上(相对于向下)的金融技术指标方向。

朴素贝叶斯与文本挖掘

多项式朴素贝叶斯分类器特别适合于文本挖掘。朴素贝叶斯公式在分类以下实体时非常有效:

  • 电子邮件垃圾邮件

  • 商业新闻故事

  • 电影评论

  • 按专业领域划分的技术论文

这第三个用例包括根据财务新闻预测股票的方向。有两种类型的新闻会影响特定公司的股票:

  • 宏观趋势:如冲突、经济趋势或劳动力市场统计等经济或社会新闻

  • 微观更新:与特定公司相关的财务或市场新闻,如收益、所有权变更或新闻稿

与特定公司相关的宏观经济新闻有可能影响投资者对公司的情绪,并可能导致其股价突然变动。另一个重要特征是投资者对新闻做出反应并影响股价的平均时间。

  • 长期投资者可能在几天或几周内做出反应

  • 短期交易者会在几小时内调整他们的头寸,有时甚至在同一交易时段内

市场对一家公司重大财务新闻的平均反应时间在以下图表中展示:

朴素贝叶斯和文本挖掘

新闻发布后投资者对股票价格的反应示意图

市场反应的延迟只有在反应时间的方差显著时才是一个相关特征。对于任何有关 TSLA 的新闻文章,市场反应延迟的频率分布相当恒定。这表明在 82%的情况下,股价在同一天内做出反应,如下面的条形图所示:

朴素贝叶斯和文本挖掘

投资者对股票价格的反应频率分布,在新闻发布后

市场反应延迟 1.75 天的频率峰值可以解释为,有些新闻在周末发布,投资者必须等到下周一才能推动股价上涨或下跌。另一个挑战是在考虑某些新闻可能重复、混淆或同时发生的情况下,将任何股价变动归因于特定的新闻发布。

因此,预测股票价格 pr[t+1] 的模型特征是某个术语 T[i] 在时间窗口 [t-n, t] 内出现的相对频率 f[i],其中 tn 是交易日。

以下图形模型正式描述了两个连续交易日 tt + 1 之间股票价格相对变化的因果关系或条件依赖性,给定媒体中某些关键术语出现的相对频率:

朴素贝叶斯和文本挖掘

根据财务新闻预测股票走势的贝叶斯模型

对于这个练习,观测集是最著名的金融新闻机构发布的新闻和文章的语料库,例如彭博社或 CNBC。第一步是制定一种方法来提取和选择与特定股票相关的最相关术语。

信息检索基础

信息检索和文本挖掘的全面讨论超出了本书的范围 [5:11]。为了简化,模型将依赖于一个非常简单的方案来提取相关术语并计算它们的相对频率。以下 10 个步骤的序列描述了从语料库中提取最相关术语的多种方法之一:

  1. 为每篇新闻文章创建或提取时间戳。

  2. 使用马尔可夫分类器提取每篇文章的标题、段落和句子。

  3. 使用正则表达式从每个句子中提取术语。

  4. 使用字典和度量标准(如 Levenstein 距离)更正错别字。

  5. 移除停用词。

  6. 进行 词干提取词形还原

  7. 提取词袋并生成 n-gramn 个术语的序列)列表。

  8. 应用使用最大熵或条件随机字段构建的 标记模型 来提取名词和形容词(例如,NNNNP 等)。

  9. 将术语与支持词义、下位词和同义词的字典相匹配,例如 WordNet

  10. 使用维基百科的存储库 DBpedia [5:12] 来消除词义歧义。

注意

从网络中提取文本

本节讨论的方法不包括从网络中搜索和提取新闻和文章的过程,这需要额外的步骤,如搜索、爬取和抓取 [5:13]。

实现

让我们将文本挖掘方法论模板应用于预测基于金融新闻的股票走势。该算法依赖于一系列七个简单的步骤:

  1. 搜索和加载与给定公司和其股票相关的新闻文章,作为 Document 类的 Ɗ[t] 文档。

  2. 使用正则表达式提取文章的 date: T 时间戳。

  3. 按时间戳对 Ɗ[t] 文档进行排序。

  4. 从每个 Ɗ**t 文档的内容中提取 {T[i,D]} 项。

  5. 聚合所有具有相同发布日期 tƊ[t] 文档的 {T[t,D]} 项。

  6. 计算每个 {T[i,D]} 项在日期 t 上的 rtf 相对频率,即其在所有在 t 日期发布的文章中出现的次数与整个语料库中该术语出现总次数的比率。

  7. 将相对频率标准化为平均每天文章数,nrtf

注意

文本分析指标

M9:在文章 a 中,术语(或关键词)t[i] 的相对频率定义为 n[i]^a 次出现与文章中所有术语出现的总次数的比率。

实现

M10:术语 t[i] 的出现频率相对于每天平均文章数进行归一化,其中 N[a] 是文章总数,N[d] 是调查的天数,其定义为以下:

实现

新闻文章是具有时间戳、标题和内容的 极简主义 文档,由 Document 类实现:

case class DocumentT <: AnyVal
(implicit f: T => Double)   

date时间戳的类型限制为Long类型,因此T可以转换为 JVM 的当前时间(行1)的毫秒数。

分析文档

本节专门用于实现简单的文本分析器。其目的是将一组Document类型的文档;在我们的案例中,是新闻文章,转换为关键词的相对频率分布。

TextAnalyzer类实现了ETransform类型的数据转换,如第二章中单调数据转换部分所述,Hello World!。它将文档序列转换为相对频率分布序列。

TextAnalyzer类有以下两个参数(行4):

  • 一个简单的文本解析器parser,从每篇新闻文章的标题和内容中提取关键词数组(行2)。

  • 一个lexicon类型,列出了用于监控与公司相关的新闻及其同义词的关键词。与每个关键词语义相似的同义词或术语定义在一个不可变映射中。

代码将如下所示:

type TermsRF = Map[String, Double]  
type TextParser = String => Array[String] //2
type Lexicon = immutable.Map[String, String]  //3
type Corpus[T] = Seq[Document[T]]

class TextAnalyzerT <: AnyVal(implicit f: T => Double)
  extends ETransformLexicon {

  type U = Corpus[T]    //5
  type V = Seq[TermsRF] //6

  override def |> : PartialFunction[U, Try[V]] = {
     case docs: U => Try( score(docs) )
  }

  def score(corpus: Corpus[T]): Seq[TermsRF]  //7
  def quantize(termsRFSeq: Seq[TermsRF]): //8
          Try[(Array[String], XVSeries[Double])]
  def count(term: String): Counter[String] //9
}

输入到数据转换|>中的U类型是语料库或新闻文章的序列(行5)。数据转换输出的V类型是TermsRF类型的相对频率分布序列(行6)。

score私有方法为类执行繁重的工作(行7)。quantize方法创建一个同质化的观察特征集(行8),而count方法计算具有相同发布日期的文档或新闻文章中术语或关键词的出现次数(行9)。

下面的图描述了文本挖掘过程的不同组件:

分析文档

文本挖掘过程组件的说明

提取相对术语的频率

让我们深入了解score方法:

def score(corpus: Corpus[T]): Seq[TermsRF] = {
  val termsCount = corpus.map(doc =>  //10
      (doc.date, count(doc.content))) //Seq[(T, Counter[String])]

  val termsCountMap = termsCount.groupBy( _._1).map{ 
     case (t, seq) => (t, seq.aggregate(new Counter[String])
                         ((s, cnt) => s ++ cnt._2, _ ++ _)) //11
  }
  val termsCountPerDate = termsCountMap.toSeq
         .sortWith( _._1 < _._1).unzip._2  //12
  val allTermsCounts = termsCountPerDate
          .aggregate(new Counter[String])((s, cnt) => 
                               s ++ cnt, _ ++ _) //13

  termsCountPerDate.map( _ /allTermsCounts).map(_.toMap) //14
}

执行score方法的第一个步骤是计算每个文档/新闻文章中lexicon类型关键词的出现次数(行10)。出现次数的计算是通过count方法实现的:

def count(term: String): Counter[String] = 
   parser(term)./:(new Counter[String])((cnt, w) =>   //16
   if(lexicon.contains(w)) cnt + lexicon(w) else cnt)

该方法依赖于Counter计数类,它继承自mutable.Map[String, Int],如附录 A 中Scala 编程下的Counter部分所述,基本概念。它使用折叠来更新与关键词相关的每个术语的计数(行16)。整个语料库的count术语是通过聚合所有文档的术语计数来计算的(行11)。

下一步是对每个时间戳在整个文档中关键词的计数进行汇总。通过调用groupBy高阶方法(第 11 行)生成以日期为键、关键词计数器为值的termsCountMap映射。接下来,score方法提取关键词计数的排序序列,termsCountPerDate(第 12 行)。使用整个语料库allTermsCounts中每个关键词的总计数来计算相对或归一化的关键词频率(公式M9M10)(第 14 行)。

生成特征

没有保证所有与特定出版日期相关的新闻文章都被用于模型中。quantize方法为缺失的新闻文章中的关键词分配相对频率0.0,如下表所示:

生成特征

按出版日期列出的关键词相对频率表

quantize方法将术语相对频率序列转换为关键词和观察值的对:

def quantize(termsRFSeq: Seq[TermsRF]): 
             Try[(Array[String], XVSeries[Double])] = Try {
  val keywords = lexicon.values.toArray.distinct //15
  val relFrequencies = 
      termsRFSeq.map( tf =>  //16
          keywords.map(key => 
              if(tf.contains(key)) tf.get(key).get else 0.0))
  (keywords, relFrequencies.toVector) //17
}

quantize方法从词典中提取关键词数组(第 15 行)。通过为在特定日期发布的新闻文章中未检测到的关键词分配相对0.0关键词频率,生成特征向量relFrequencies(第 16 行)。最后,在行 17 中生成(关键词和相对关键词频率)键值对。

注意

稀疏相对频率向量

文本分析和自然语言处理处理非常大的特征集,可能有数十万个特征或关键词。如果不是因为大多数关键词在每个文档中都不存在,这样的计算几乎是不可能的。使用稀疏向量和稀疏矩阵来减少训练期间的内存消耗是一种常见的做法。

测试

为了测试目的,让我们选择在 2 个月期间提及特斯拉汽车及其股票代码 TSLA 的新闻文章。

获取文本信息

让我们从实现和定义TextAnalyzer的两个组件开始:parsing函数和lexicon变量:

val pathLexicon = "resources/text/lexicon.txt"
val LEXICON = loadLexicon  //18

def parse(content: String): Array[String] = {
  val regExpr = "['|,|.|?|!|:|\"]"
  content.trim.toLowerCase.replace(regExpr," ") //19
  .split(" ") //20
  .filter( _.length > 2) //21
}

词典从文件中加载(第 18 行)。parse方法使用简单的regExpr正则表达式将任何标点符号替换为空格字符(第 19 行),这用作单词分隔符(第 20 行)。所有长度小于三个字符的单词都被忽略(第 21 行)。

让我们描述加载、解析和分析与公司特斯拉公司及其股票(股票代码 TSLA)相关的新闻文章的工作流程。

第一步是加载和清理在pathCorpus目录中定义的所有文章(语料库)(第 22 行)。这项任务由DocumentsSource类执行,如附录 A 中在Scala 编程部分下的数据提取部分所述,基本概念

val pathCorpus = "resources/text/chap5/"   //22
val dateFormat = new SimpleDateFormat("MM.dd.yyyy")
val pfnDocs = DocumentsSource(dateFormat, pathCorpus) |>  //23

val textAnalyzer = TextAnalyzerLong
val pfnText = textAnalyzer |>   //24

for {
  corpus <- pfnDocs(None)  //25
  termsFreq <- pfnText(corpus)  //26
  featuresSet <- textAnalyzer.quantize(termsFreq) //27
  expected <- Try(difference(TSLA_QUOTES, diffInt)) //28
  nb <- NaiveBayesDouble)//29
} yield {
  show(s"Naive Bayes model${nb.toString(quantized._1)}")
   …
}

文档源由数据输入文件路径和用于时间戳的格式完全定义(第23行)。文本分析器和其显式的pfnText数据转换被实例化(第24行)。文本处理管道由以下步骤定义:

  1. 使用pfnDoc部分函数(第25行)将输入源文件转换为语料库(一系列新闻文章)。

  2. 使用pfnText部分函数(第26行)将语料库转换为一系列termsFreq相对关键词频率向量。

  3. 使用quantize(第27行)将一系列相对关键词频率向量转换为featuresSet(参考第三章下时间序列在 Scala中的微分算子部分,数据预处理)。

  4. 使用对(featuresSet._2expected)作为训练数据创建二项式NaiveBayes模型(第29行)。

从特斯拉汽车的每日股票价格TSLA_QUOTES中提取预期的类别值(0,1)。

val TSLA_QUOTES = ArrayDouble

注意

语义分析

本例使用一个非常原始的语义图(词典)来展示多项式朴素贝叶斯算法的益处和内部工作原理。涉及情感分析或主题分析的商用应用需要更深入地理解语义关联,并使用高级生成模型,如潜在狄利克雷分配来提取主题。

评估文本挖掘分类器

下图描述了与特斯拉汽车或其股票代码 TSLA 相关的某些关键词出现的频率:

评估文本挖掘分类器

一部分与股票相关的术语相对频率图

下图展示了发布(或新闻文章)后的交易日股票价格预期变化方向:

评估文本挖掘分类器

特斯拉汽车股票价格和走势图

前面的图表显示了股票 TSLA 的历史价格及其方向(上升和下降)。为验证分类器而选择的 15%标记数据的分类中,F[1]得分为 0.71。需要记住的是,没有进行预处理或聚类来隔离最相关的特征/关键词。我们最初根据金融新闻中关键词出现的频率选择关键词。

有理由假设某些关键词对股票价格方向的影响比其他关键词更大。一个简单但有趣的练习是记录只使用具有特定关键词高出现次数的观察值进行验证的 F[1]分数,如下面的图表所示:

评估文本挖掘分类器

表示预测 TSLA 股票走势的主要关键词的条形图

前面的条形图显示,代表中国所有特斯拉汽车在中国活动提及的中国,以及涵盖所有充电站引用的充电器,对股票方向有显著的积极影响,概率平均达到 75%。风险类别下的术语对股票方向的负面影响概率为 68%,或对股票方向有 32%的积极影响。在剩余的八个类别中,有 72%的它们无法作为股票价格方向的预测指标。

这种方法可以用作选择特征,作为使用更复杂的分类器的替代方案,而不是互信息的特征选择。然而,它不应被视为选择特征的主要方法,而应被视为将朴素贝叶斯公式应用于具有非常少量相关特征的模型时的副产品。例如,在第四章中“降维”部分描述的主成分分析等技术,可用于降低问题的维度,使朴素贝叶斯成为一个可行的分类器。

优缺点

本章所选的示例并没有充分体现朴素贝叶斯分类器家族的多样性和准确性。

朴素贝叶斯算法是一个简单且健壮的生成分类器,它依赖于先验条件概率从训练数据集中提取模型。朴素贝叶斯模型具有其优点,如以下所述:

  • 它易于实现和并行化

  • 它的计算复杂度非常低:O((n+c)m),其中m是特征的数量,C是类别的数量,而n*是观察的数量

  • 它处理缺失数据

  • 它支持增量更新、插入和删除

然而,朴素贝叶斯并不是万能的。它有以下缺点:

  • 它需要一个大的训练集才能达到合理的准确性

  • 在现实世界中,特征独立性的假设并不实用

  • 它需要处理计数器的零频率问题

摘要

有一个原因使得朴素贝叶斯模型是机器学习课程中最早教授的有监督学习技术之一:它简单且健壮。实际上,这是当你考虑从标记数据集中创建模型时,首先应该想到的技术,只要特征是条件独立的。

本章还介绍了朴素贝叶斯作为文本挖掘应用的基础知识。

尽管它有诸多优点,但朴素贝叶斯分类器假设特征是条件独立的,这是一个无法总是克服的限制。在文档或新闻发布的分类中,朴素贝叶斯错误地假设术语在语义上是独立的:两个实体的年龄和出生日期高度相关。下一章中描述的判别分类器解决了朴素贝叶斯的一些限制[5:14]。

本章不处理时间依赖性、事件序列或观察到的特征和隐藏特征之间的条件依赖性。这些类型的依赖性需要不同的建模方法,这是第七章“序列数据模型”的主题。

第六章。回归与正则化

在第一章中,我们简要介绍了二元逻辑回归(单变量的二项逻辑回归)作为我们的第一个测试案例。其目的是说明判别分类的概念。还有许多更多的回归模型,从无处不在的普通最小二乘线性回归和逻辑回归[6:1]开始。

回归的目的是最小化损失函数,其中残差平方和RSS)是常用的一个。在第二章中“评估模型”部分下的“过拟合”章节中描述的过拟合问题可以通过向损失函数中添加惩罚项来解决。惩罚项是正则化这一更大概念的一个元素。

本章的第一部分将描述并实现线性最小二乘回归。第二部分将介绍正则化的概念,并通过岭回归的实现来展示。最后,将从分类模型的角度详细回顾逻辑回归。

线性回归

线性回归到目前为止是最广泛使用,或者至少是最常见的回归方法。术语通常与将模型拟合到数据的概念相关联,并通过计算平方误差和、残差平方和或最小二乘误差来最小化预期值和预测值之间的误差。

最小二乘问题分为以下两类:

  • 普通最小二乘

  • 非线性最小二乘

单变量线性回归

让我们从最简单的线性回归形式开始,即单变量回归,以便介绍线性回归背后的术语和概念。在最简单的解释中,单变量线性回归包括拟合一条线到一组数据点{x, y}

注记

M1:对于模型f,其特征x[j]的权重为w[j],标签(或预期值)为y[j]的单变量线性回归如下公式所示:

一元线性回归

在这里,w[1] 是斜率,w[0] 是截距,f 是最小化均方误差(RSS)的线性函数,而 (x[j], y[j]) 是一组 n 个观测值。

均方误差(RSS)也称为平方误差之和SSE)。对于 n 个观测值的均方误差MSE)定义为 RSS/n 的比率。

注意

术语

科学文献中关于回归的术语有时会有些混乱。回归权重也被称为回归系数或回归参数。在整个章节中,权重被称为 w,尽管参考书中也使用了 β

实现

让我们创建一个 SingleLinearRegression 参数化类来实现 M1 公式。线性回归是一种使用从数据隐式导出或构建的模型进行数据转换的方法。因此,简单线性回归实现了在第二章中“单调数据转换”部分所描述的 ITransform 特性,Hello World!

SingleLinearRegression 类接受以下两个参数:

  • 单变量观测值的 xt 向量

  • 预期值或标签的向量(行 1

代码如下:

class SingleLinearRegressionT <: AnyVal(implicit f: T => Double)
  extends ITransformT with Monitor[Double] {   //1
  type V = Double  //2

  val model: Option[DblPair] = train //3
  def train: Option[DblPair]
  override def |> : PartialFunction[T, Try[V]]
  def slope: Option[Double] = model.map(_._1)
  def intercept: Option[Double] = model.map(_._2)
}

Monitor 特性用于在训练期间收集配置文件信息(请参阅附录 A 中“实用类”下的Monitor部分,基本概念)。

该类必须定义 |> 预测方法的输出类型,它是一个 Double(行 2)。

注意

模型实例化

模型参数是通过训练计算得出的,无论模型是否实际经过验证,模型都会被实例化。商业应用需要使用如 K 折验证等方法来验证模型,正如在附录 A 中“不可变分类器设计模板”部分所描述的,基本概念

训练生成定义为回归权重(斜率和截距)的模型(行 3)。如果在训练期间抛出异常,则模型设置为 None

def train: Option[DblPair] = {
   val regr = new SimpleRegression(true) //4
   regr.addData(zipToSeries(xt, expected).toArray)  //5
   Some((regr.getSlope, regr.getIntercept))  //6
}

回归权重或系数,即模型元组,使用 Apache Commons Math 库中的stats.regression包的SimpleRegression类计算得出,使用true参数触发截距的计算(行 4)。输入时间序列和标签(或预期值)被压缩生成一个包含两个值(输入和预期)的数组(行 5)。model 使用训练期间计算的斜率和截距进行初始化(行 6)。

XTSeries 对象的 zipToSeries 方法在第三章中“Scala 中的时间序列”部分进行了描述,数据预处理

注意

private 与 private[this]

一个 private 值或变量只能被类的所有实例访问。声明为 private[this] 的值只能由 this 实例操作。例如,model 值只能由 SingleLinearRegressionthis 实例访问。

测试用例

对于我们的第一个测试用例,我们计算了铜 ETF(股票代码:CU)在 6 个月(2013 年 1 月 1 日至 2013 年 6 月 30 日)期间的单一变量线性回归。

val path = "resources/data/chap6/CU.csv"
for {
  price <- DataSource(path, false, true, 1) get adjClose //7
  days <- Try(Vector.tabulate(price.size)(_.toDouble)) //8
  linRegr <- SingleLinearRegressionDouble //9
} yield {
  if( linRegr.isModel ) {
    val slope = linRegr.slope.get
    val intercept = linRegr.intercept.get
    val error = mse(days, price, slope, intercept)//10
  }
  …
}

ETF CU 的每日收盘价从 CSV 文件中提取(第 7 行)作为预期值,使用 DataSource 实例,如 附录 A 中 数据提取和数据源 部分所述,基本概念。x 值 days 自动生成为一个线性函数(第 8 行)。预期值(price)和会话(days)是简单线性回归实例化的输入(第 9 行)。

一旦模型创建成功,测试代码计算预测值和预期值的 mse 均方误差(第 10 行):

def mse(
    predicted: DblVector, 
    expected: DblVector, 
    slope: Double, 
    intercept: Double): Double = {
  val predicted = xt.map( slope*_ + intercept)
  XTSeries.mse(predicted, expected)  //11
}

使用 XTSeriesmse 方法计算平均最小二乘误差(第 11 行)。原始股价和线性回归方程在下面的图表中绘制:

测试用例

总最小二乘误差为 0.926。

尽管单变量线性回归很方便,但它仅限于标量时间序列。让我们考虑多变量情况。

普通最小二乘回归

普通最小二乘回归通过最小化残差平方和来计算线性函数 y = f(x[0], x[2] … x[d]) 的参数 w。优化问题通过执行向量和矩阵运算(转置、求逆和替换)来解决。

注意

M2:损失函数的最小化由以下公式给出:

普通最小二乘回归

在这里,wj 是回归的权重或参数,(x[i], y[i])[i:0, n-1] 是向量 x 和预期输出值 yn 个观测值,f 是线性多元函数,y = f (x0, x1, …,xd)

对于线性回归,有几种方法可以最小化残差平方和:

  • 使用 nd 列矩阵的 QR 分解来解决具有 d 个变量(权重)的 n 个方程组,该矩阵表示具有 n >= dd 维向量 n 个观测值的时间序列[6:2]

  • 奇异值分解在观测-特征矩阵上,在维度 d 超过观测数 n 的情况下[6:3]

  • 梯度下降 [6:4]

  • 随机 梯度下降 [6:5]

这些矩阵分解和优化技术的概述可以在 附录 A 的 线性代数优化技术总结 部分找到,基本概念

QR 分解为最常见的最小二乘问题生成最小的相对误差 MSE。该技术在我们的最小二乘回归实现中得到了应用。

设计

最小二乘回归的实现利用了 Apache Commons Math 库对普通最小二乘回归的实现 [6:6]。

本章描述了几种回归算法。定义一个通用的 Regression 特质,以定义回归算法的关键元素组件是有意义的。

  • 一个 RegressionModel 类型的模型(第 1 行)

  • 两种访问回归模型组件的方法:weightsrss(第 2 行和第 3 行)

  • 一个 train 多态方法,实现了此特定回归算法的训练(第 4 行)

  • 一个将 train 包装到 Try 单子中的 training 受保护方法

代码如下:

trait Regression {
  val model: Option[RegressionModel] = training //1

  def weights: Option[DblArray] = model.map( _.weights)//2
  def rss: Option[Double] = model.map(_.rss) //3
  def isModel: Boolean = model != None

  protected def train: RegressionModel  //4
  def training: Option[RegressionModel] = Try(train) match {
    case Success(_model) => Some(_model)
    case Failure(e) => e match {
      case err: MatchError => { … }        case _ => { … }
    }
  }
}

模型简单地由其 weights 和残差平方和定义(第 5 行):

case class RegressionModel(  //5
   val weights: DblArray, val rss: Double) extends Model

object RegressionModel {
  def dotT <: AnyVal(implicit f: T => Double): Double = 
     x.zip(weights.drop(1))  
      .map{ case(_x, w) => _x*w}.sum + weights.head //6
}

RegressionModel 伴生对象实现了 weights 回归和观察值 xdot 内积的计算,dot 方法在本章中得到了广泛应用。

让我们创建一个 MultiLinearRegression 类,作为一个数据转换,其模型隐式地从输入数据(训练集)推导出来,如 第二章 中 单调数据转换 部分所述,Hello World!

class MultiLinearRegressionT <: AnyVal(implicit f: T => Double)
  extends ITransform[Array[T]](xt) with Regression 
       with Monitor[Double] { //7
  type V = Double  //8

  override def train: Option[RegressionModel] //9
  override def |> : PartialFunction[Array[T], Try[V]] //10
}

MultiLinearRegression 类接受两个参数:xt 观察值的多元时间序列和 expected 值的向量(第 7 行)。该类实现了 ITransform 特质,并需要定义预测或回归的输出值类型,V 作为 Double(第 8 行)。MultiLinearRegression 的构造函数通过训练创建 model(第 9 行)。ITransform 特质的 |> 方法实现了多线性回归的运行时预测(第 10 行)。

Monitor 特质用于在训练过程中收集配置文件信息(参见 附录 A 下 实用类 部分的 Monitor 部分,基本概念)。

注意

回归模型

RSS 包含在模型中,因为它向客户端代码提供了有关用于最小化损失函数的底层技术准确性的重要信息。

多线性回归的不同组件之间的关系在以下 UML 类图中描述:

设计

多线性(OLS)回归的 UML 类图

UML 图省略了Monitor或 Apache Commons Math 组件等辅助特性和类。

实现

训练是在MultiLinearRegression类的实例化过程中进行的(参考附录 A 中的不可变分类器设计模板部分,基本概念):

def train: RegressionModel = {
  val olsMlr = new MultiLinearRAdapter   //11
  olsMlr.createModel(expected, data)  //12
  RegressionModel(olsMlr.weights, olsMlr.rss) //13
}

通过对MultiLinearRAdapter适配器类的olsMlr引用访问 Apache Commons Math 库中普通最小二乘回归的功能(第11行)。

train方法通过调用 Apache Commons Math 类的OLSMultipleLinearRegression(第12行)创建模型,并返回回归模型(第13行)。该类的方法通过MultiLinearRAdapter适配器类访问:

class MultiLinearRAdapter extends OLSMultipleLinearRegression {
  def createModel(y: DblVector, x: Vector[DblArray]): Unit = 
     super.newSampleData(y.toArray, x.toArray)

  def weights: DblArray = estimateRegressionParameters
  def rss: Double = calculateResidualSumOfSquares
}

createModelweightsrss方法将请求路由到OLSMultipleLinearRegression中的相应方法。

为了捕获 Apache Commons Math 库抛出的不同类型的异常,如MathIllegalArgumentExceptionMathRuntimeExceptionOutOfRangeExceptiontrain方法的返回类型使用了 Scala 异常处理单子Try{}

注意

异常处理

使用Try {} Scala 异常处理单子包装第三方方法调用的调用对于几个原因很重要:

  • 通过将代码与第三方库分离,这使得调试变得更加容易。

  • 当可能时,它允许你的代码通过重新执行具有替代第三方库方法的相同函数来从异常中恢复。

普通最小二乘回归的预测算法通过|>数据转换实现。该方法根据模型和输入值x预测输出值:

def |> : PartialFunction[Array[T], Try[V]] = {
  case x: Array[T] if isModel && 
       x.length == model.get.size-1  
         => Try( dot(x, model.get) ) //14
}

使用在前面本节中引入的RegressionModel单例中定义的dot方法计算预测值(第14行)。

测试用例 1 – 趋势

趋势包括从时间序列中提取长期运动。使用多元最小二乘回归检测趋势线。这个第一个测试的目的是评估普通最小二乘回归的滤波能力。

回归是在铜 ETF(股票代码:CU)的相对价格变动上进行的。选定的特征是波动性成交量,标签或目标变量是两个连续y交易会话之间的价格变动。

交易数据和指标命名规范在附录 A 的技术分析部分下的交易数据章节中描述,基本概念

在以下图表中绘制了 2013 年 1 月 1 日至 2013 年 6 月 30 日之间 CU 的成交量、波动性和价格变动:

测试用例 1 – 趋势

铜 ETF 的价格变动、波动性和交易量图表

让我们编写客户端代码来计算多元线性回归,价格变化 = w[0] + 波动性.w[1] + 体积.w[2]

import  YahooFinancials._
val path = "resources/data/chap6/CU.csv"  //15
val src = DataSource(path, true, true, 1)  //16

for {
  price <- src.get(adjClose)  //17
  volatility <- src.get(volatility)  //18
  volume <- src.get(volume)  //19
  (features, expected) <- differentialData(volatility, 
                     volume, price, diffDouble)  //20
  regression <- MultiLinearRegressionDouble  //21
} yield {
  if( regression.isModel ) {
    val trend = features.map(dot(_,regression.weights.get))
    display(expected, trend)  //22
  }
}

让我们看看执行测试所需的步骤:它包括收集数据、提取特征和期望值,以及训练多线性回归模型:

  1. 定位 CSV 格式的数据源文件(行15)。

  2. 为交易会话收盘pricevolatility会话和volume会话创建一个数据源提取器,DataSource,用于 ETF CU(行16)。

  3. 使用DataSource转换提取 ETF 的价格(行17)、交易会话内的波动性(行18)和会话期间的交易量(行19)。

  4. 生成标记数据作为特征对(ETF 的相对波动性和相对成交量)和期望结果(0,1)的配对,用于训练模型,其中1代表价格上涨,0代表价格下跌(行20)。XTSeries单例的differentialData通用方法在第三章的Scala 中的时间序列部分中描述,数据预处理

  5. 使用features集和每日 ETF 价格预期的expected变化(行21)实例化多线性回归。

  6. 使用 JFreeChart 显示期望值和趋势值(行22)。

期望值的时间序列和回归预测的数据绘制在以下图表中:

测试案例 1 – 趋势

根据波动性和成交量对铜 ETF 的价格变动和最小二乘回归

最小二乘回归模型通过以下线性函数定义来估计价格变动:

价格(t+1)-价格(t) = -0.01 + 0.014 波动性 – 0.0042 体积

估计的价格变动(前一个图表中的虚线)代表了过滤掉噪声的长期趋势。换句话说,最小二乘回归作为一种简单的低通滤波器,作为一些过滤技术(如离散傅里叶变换或卡尔曼滤波器)的替代,如第三章的数据预处理部分所述 [6:7]。

虽然趋势检测是最小二乘回归的一个有趣应用,但该方法对于时间序列的过滤能力有限 [6:8]:

  • 它对异常值敏感

  • 需要丢弃第一个和最后几个观测值

  • 作为一种确定性方法,它不支持噪声分析(分布、频率等)

测试案例 2 – 特征选择

第二个测试案例与特征选择相关。目标是发现哪个初始特征子集能生成最准确的回归模型,即训练集上残差平方和最小的模型。

让我们考虑一个初始特征集{x[i]}。目标是使用最小二乘回归估计与观测集最相关的特征子集{x[id]}。每个特征子集都与一个fj模型相关联:

测试案例 2 – 特征选择

特征集选择图

当特征集较小时,使用普通最小二乘回归来选择模型参数w。对大型原始特征集的每个子集进行回归并不实际。

备注

M3:特征选择可以用以下数学公式表示:

测试案例 2 – 特征选择

在这里,w[jk]是回归函数/模型的权重,(x[i], y[i])[i:0,n-1]是向量xn个观测值和预期输出值yf是线性多元函数,y = f (x[0], x[1], …,x[d])

让我们考虑从 2009 年 1 月 1 日到 2013 年 12 月 31 日这段时间内的以下四个金融时间序列:

  • 人民币兑美元汇率

  • 标普 500 指数

  • 黄金现货价格

  • 10 年期国债价格

问题是要估计哪个 S&P 500 指数、金价和 10 年期国债价格变量的组合与人民币汇率的相关性最高。出于实际考虑,我们使用交易所交易基金 CYN 作为人民币/美元汇率(类似地,SPY、GLD 和 TLT 分别代表标普 500 指数、现货金价和 10 年期国债价格)的代理。

备注

特征提取自动化

本节中的代码实现了一个使用任意固定模型集的特征提取。该过程可以使用优化器(梯度下降、遗传算法等)通过将1/RSS作为目标函数来轻松自动化。

要评估的模型数量相对较小,因此对每个组合计算 RSS 的临时方法是可以接受的。让我们看一下以下图表:

测试案例 2 – 特征选择

人民币汇率、金价、10 年期国债价格和标普 500 指数的图表

getRss方法实现了给定一组xt观测值、y预期(平滑)值和featureLabels特征标签的 RSS 值的计算,然后返回一个文本结果:

def getRss(
     xt: Vector[DblArray], 
     expected: DblVector, 
     featureLabels: Array[String]): String = {

  val regression = 
         new MultiLinearRAdapterDouble //23
  val modelStr = regression.weights.get.view
        .zipWithIndex.map{ case( w, n) => {
    val weights_str = format(w, emptyString, SHORT)
    if(n == 0 ) s"${featureLabels(n)} = $weights_str"
    else s"${weights_str}.${featureLabels(n)}"
  }}.mkString(" + ")
  s"model: $modelStr\nRSS =${regression.get.rss}" //24
}

getRss方法仅通过实例化多线性回归类(第 23 行)来训练模型。一旦在MultiLinearRegression类的实例化过程中训练了回归模型,回归权重的系数和 RSS 值就会被字符串化(第 24 行)。getRss方法被调用于 ETF、GLD、SPY 和 TLT 变量对 CNY 标签的任何组合。

让我们看一下下面的测试代码:

val SMOOTHING_PERIOD: Int = 16  //25
val path = "resources/data/chap6/"
val symbols = ArrayString //26
val movAvg = SimpleMovingAverageDouble //27

for {
  pfnMovAve <- Try(movAvg |>)  //28
  smoothed <- filter(pfnMovAve)  //29
  models <- createModels(smoothed)  //30
  rsses <- Try(getModelsRss(models, smoothed)) //31
  (mses, tss) <- totalSquaresError(models,smoothed.head) //32
} yield {
   s"""${rsses.mkString("\n")}\n${mses.mkString("\n")}
      | \nResidual error= $tss".stripMargin
}

数据集很大(1,260 个交易日)且噪声足够,需要使用 16 个交易日的简单移动平均进行过滤(第 25 行)。测试的目的是评估四个 ETF(CNY、GLD、SPY 和 TLT)之间可能的相关性。执行测试时实例化了简单移动平均(第 27 行),如第三章中 简单移动平均 部分所述,第三章,数据预处理

工作流程执行以下步骤:

  1. 实例化一个简单的移动平均 pfnMovAve 部分函数(第 28 行)。

  2. 使用 filter 函数(第 29 行)为 CNY、GLD、SPY 和 TLT ETFs 生成平滑的历史价格:

    Type PFNMOVAVE = PartialFunction[DblVector, Try[DblVector]]
    
    def filter(pfnMovAve: PFNMOVEAVE): Try[Array[DblVector]] = Try {
       symbols.map(s => DataSource(s"$path$s.csv", true, true, 1))
          .map( _.get(adjClose) )
          .map( pfnMovAve(_)).map(_.get)
    
  3. 使用 createModels 方法(第 30 行)为每个模型生成特征列表:

    type Models = List[(Array[String], DblMatrix)]
    
    def createModels(smoothed: Array[DblVector]): Try[Models] = 
    Try {
      val features = smoothed.drop(1).map(_.toArray)  //33
      List[(Array[String], DblMatrix)](   //34
       (ArrayString, features.transpose),
       (ArrayString,features.drop(1).transpose),
       (ArrayString,features.take(2).transpose),
       (ArrayString, features.zipWithIndex
                          .filter( _._2 != 1).map( _._1).transpose),
       (ArrayString, features.slice(1,2).transpose)
       )
    }
    

    使用 CNY 的平滑值作为预期值。因此,它们被从特征列表中移除(第 33 行)。通过向特征列表中添加或删除元素来评估五个模型(第 34 行)。

  4. 接下来,工作流程使用 getModelsRss(第 31 行)计算所有模型的残差平方和。该方法对每个模型调用 getRss,该函数在本节前面已介绍(第 35 行):

    def getModelsRss(
        models: Models, 
        y: Array[DblVector]): List[String] = 
      models.map{ case (labels, m) => 
             s"${getRss(m.toVector, y.head, labels)}" }  //35
    
    
  5. 最后,工作流程的最后一个步骤是计算每个模型的 mses 平均平方误差和总平方误差(第 33 行):

    def totalSquaresError(
        models: Models, 
        expected: DblVector): Try[(List[String], Double)] = Try {
    
      val errors = models.map{case (labels,m) => 
    rssSum(m, expected)._1}//36
      val mses = models.zip(errors)
               .map{case(f,e) => s"MSE: ${f._1.mkString(" ")} = $e"}
      (mses, Math.sqrt(errors.sum)/models.size)  //37
    }
    

totalSquaresError 方法通过为每个模型累加 RSS 值 rssSum 来计算每个模型的误差(第 36 行)。该方法返回一个包含每个模型的平均平方误差数组和总平方误差的数组(第 37 行)。

RSS 并不总是能准确展示回归模型的适用性。回归模型的适用性通常通过 r² 统计量 来评估。r² 值是一个数字,表示数据与统计模型拟合的程度。

注意

M4:RSS 和 r² 统计量由以下公式定义:

测试案例 2 – 特征选择

r² 统计量的计算实现很简单。对于每个模型 f[j]rssSum 方法计算由 M4 公式定义的元组(rss 和最小二乘误差):

def rssSum(xt: DblMatrix, expected: DblVector): DblPair = {
  val regression = 
         MultiLinearRegressionDouble //38
  val pfnRegr = regression |>  //39
  val results = sse(expected.toArray, xt.map(pfnRegr(_).get))
  (regression.rss, results) //40
}

rssSum 方法实例化 MultiLinearRegression 类(第 38 行),检索 RSS 值,然后验证 pfnRegr 回归模型(第 39 行)与预期值(第 40 行)的匹配度。测试的输出结果如下截图所示:

测试案例 2 – 特征选择

输出结果清楚地显示,三个变量回归,CNY = f (SPY, GLD, TLT),是最准确或最合适的模型,其次是 CNY = f (SPY, TLT)。因此,特征选择过程生成了特征集,{SPY, GLD, TLT}

让我们绘制模型与原始数据的关系图:

测试案例 2 – 特征选择

对人民币 ETF(CNY)的普通最小二乘回归

回归模型对原始人民币时间序列进行了平滑处理。它去除了除了最显著的价格变动之外的所有变动。绘制每个模型 r²值的图表证实,三个特征模型CNY=f (SPY, GLD, TLT)是最准确的:

测试案例 2 – 特征选择

注意

一般线性回归

线性回归的概念不仅限于多项式拟合模型,如y = w[0] + w[1].x + w[2].x² + …+ w[n]x^n。回归模型也可以定义为基函数的线性组合,如ϕ[j]: y = w[0] + w[1].ϕ1 + w[2]ϕ2 + … + w[n].ϕn [6:9]。

正则化

寻找回归参数的普通最小二乘法是最大似然的一个特例。因此,回归模型在过拟合方面面临着与其他任何判别模型相同的挑战。您已经知道,正则化用于减少模型复杂度并避免过拟合,正如在第二章中的过拟合部分所述,Hello World!

L[n]粗糙度惩罚

正则化包括向损失函数(或回归分类中的均方误差)中添加一个J(w)惩罚函数,以防止模型参数(也称为权重)达到高值。一个非常适合训练集的模型往往具有许多具有相对较大权重的特征变量。这个过程被称为收缩。实际上,收缩涉及将一个以模型参数为参数的函数添加到损失函数中(M5):

L[n]粗糙度惩罚

惩罚函数与训练集x,y完全独立。惩罚项通常表示为模型参数(或权重)w[d]的范数的函数的幂。对于一个D维度的模型,通用的L[p] -范数定义为以下(M6):

L[n]粗糙度惩罚

注意

符号

正则化适用于与观测值相关的参数或权重。为了与我们的符号一致,w[0]是截距值,正则化适用于参数w[1]…w[d]

最常用的两种正则化惩罚函数是 L[1]和 L[2]。

注意

机器学习中的正则化

正则化技术不仅限于线性或逻辑回归。任何最小化残差平方和的算法,如支持向量机或前馈神经网络,都可以通过向均方误差添加粗糙度惩罚函数来进行正则化。

应用到线性回归的 L[1]正则化被称为lasso 正则化岭回归是一种使用 L[2]正则化惩罚的线性回归。

你可能会想知道哪种正则化对给定的训练集有意义。简而言之,L[2] 和 L[1] 正则化在计算效率、估计和特征选择方面有所不同 [6:10] [6:11]:

  • 模型估计:L[1] 生成的回归参数估计比 L[2] 更稀疏。对于大型非稀疏数据集,L[2] 的估计误差比 L[1] 小。

  • 特征选择:L[1] 在减少具有较高值的特征回归权重方面比 L[2] 更有效。因此,L1 是一个可靠的特性选择工具。

  • 过拟合:L[1] 和 L[2] 都能减少过拟合的影响。然而,L[1] 在克服过拟合(或模型过度复杂性)方面具有显著优势;因此,L[1] 更适合于特征选择。

  • 计算:L[2] 有助于更高效的计算模型。损失函数和 L[2] 惩罚项 的和是一个连续且可微的函数,其第一和第二导数可以计算(凸最小化)。L1 项是 |w[i]| 的和,因此不可微。

备注

术语

岭回归有时被称为 惩罚最小二乘回归。L[2] 正则化也称为 权重衰减

让我们实现岭回归,然后评估 L[2]-范数惩罚因子的影响。

岭回归

岭回归是一个具有 L[2]-范数惩罚项的多变量线性回归(M7):

岭回归

岭回归参数的计算需要解决与线性回归相似的线性方程组。

备注

M8:对于输入数据集 X、正则化因子 λ 和期望值向量 y 的岭回归闭式矩阵表示定义为以下(I 是单位矩阵):

岭回归

M9:矩阵方程通过以下 QR 分解求解:

岭回归

设计

岭回归的实现将 L[2] 正则化项添加到 Apache Commons Math 库的多元线性回归计算中。RidgeRegression 方法与它们的普通最小二乘对应方法具有相同的签名,除了 lambda L[2] 惩罚项(行 1):

class RidgeRegressionT <: AnyVal(implicit f: T => Double)
   extends ITransform[Array[T]](xt) with Regression 
       with Monitor[Double] { //2

  type V = Double //3
  override def train: Option[RegressionModel]  //4
  override def |> : PartialFunction[Array[T], Try[V]]
}

RidgeRegression 类被实现为一个 ITransform 数据转换,其模型隐式地从输入数据(训练集)中导出,如第二章中 Monadic 数据转换 部分所述,Hello World!(行 2)。|> 预测函数的输出 V 类型是 Double(行 3)。模型在类的实例化过程中通过训练创建(行 4)。

Ridge 回归的不同组件之间的关系在以下 UML 类图中描述:

设计

Ridge 回归的 UML 类图

UML 图省略了如Monitor或 Apache Commons Math 组件等辅助特性和类。

实现

让我们看看训练方法train

def train: RegressionModel = {
  val mlr = new RidgeRAdapter(lambda, xt.head.size) //5
  mlr.createModel(data, expected) //6
  RegressionModel(mlr.getWeights, mlr.getRss)  //7
}

它相当简单;它初始化并执行了RidgeRAdapter类(行5)中实现的回归算法,该类作为适配器,适配org.apache.commons.math3.stat.regression包中内部 Apache Commons Math 库的AbstractMultipleLinearRegression类(行6)。该方法返回一个完全初始化的回归模型,类似于普通最小二乘回归(行7)。

让我们看看RidgeRAdapter适配器类:

class RidgeRAdapter(
    lambda: Double, 
    dim: Int) extends AbstractMultipleLinearRegression {
  var qr: QRDecomposition = _  //8

  def createModel(x: DblMatrix, y: DblVector): Unit ={ //9
    this.newXSampleData(x) //10
    super.newYSampleData(y.toArray)
  }
  def getWeights: DblArray = calculateBeta.toArray //11
  def getRss: Double = rss
}

RidgeRAdapter类的构造函数接受两个参数:lambda L[2]惩罚参数和观察中的特征数量dimAbstractMultipleLinearRegression基类中的 QR 分解不处理惩罚项(行8)。因此,必须在createModel方法(行9)中重新定义模型的创建,这需要覆盖newXSampleData方法(行10):

override protected def newXSampleData(x: DblMatrix): Unit =  {
  super.newXSampleData(x)    //12
  val r: RealMatrix = getX
  Range(0, dim).foreach(i => 
        r.setEntry(i, i, r.getEntry(i,i) + lambda) ) //13
  qr = new QRDecomposition(r) //14
}

newXSampleData方法通过向其对角线元素添加lambda系数(行13)来覆盖默认的观察-特征r矩阵(行12),然后更新 QR 分解组件(行14)。

Ridge 回归模型的权重通过在calculateBeta覆盖方法(行15)中实现M6公式(行11)来计算:

override protected def calculateBeta: RealVector =
   qr.getSolver().solve(getY()) //15

普通最小二乘回归的预测算法通过|>数据转换实现。该方法根据模型和输入的x值(行16)预测输出值:

def |> : PartialFunction[Array[T], Try[V]] = {
  case x: Array[T] if(isModel && 
      x.length == model.get.size-1) => 
        Try( dot(x, model.get) ) //16
}

测试案例

测试案例的目标是确定 L[2]惩罚对 RSS 值的影响,然后比较预测值与原始值。

让我们考虑与铜 ETF(符号:CU)每日价格变动回归相关的第一个测试案例,使用股票每日波动性和成交量作为特征。观察提取的实现与上一节中描述的最小二乘回归相同:

val LAMBDA: Double = 0.5
val src = DataSource(path, true, true, 1)  //17

for {
  price <- src.get(adjClose)   //18
  volatility <- src.get(volatility) //19
  volume <- src.get(volume)  //20
  (features, expected) <- differentialData(volatility, 
              volume, price, diffDouble) //21
  regression <- RidgeRegressionDouble  //22
} yield {
  if( regression.isModel ) {
    val trend = features
               .map( dot(_, regression.weights.get) )  //23

    val y1 = predict(0.2, expected, volatility, volume) //24
    val y2 = predict(5.0, expected, volatility, volume)
    val output = (2 until 10 by 2).map( n => 
          predict(n*0.1, expected, volatility, volume) )
  }
}

让我们看看执行测试所需的步骤。这些步骤包括收集数据、提取特征和期望值,以及训练 Ridge 回归模型:

  1. 使用DataSource转换创建 ETF CU 的价格交易时段收盘、波动性时段和成交量时段的数据源提取器(行17)。

  2. 提取 ETF(交易代码:CU)的收盘价格(行18),交易时段内的波动性(行19),以及同一时段内的成交量(行20)。

  3. 生成标记数据作为特征对(ETF 的相对波动性和相对成交量)以及用于训练模型的预期结果 {0, 1},其中1代表价格上涨,0代表价格下跌(第 21 行)。XTSeries单例的differentialData通用方法在第三章的Scala 中的时间序列部分中描述,数据预处理

  4. 使用features集和每日股票价格预期的变化(第 22 行)来实例化岭回归。

  5. 使用RegressionModel单例的dot函数计算trend值(第 23 行)。

  6. 通过predict方法实现岭回归的执行(第 24 行)。

代码如下:

def predict(
    lambda: Double, 
    deltaPrice: DblVector, 
    volatility: DblVector, 
    volume: DblVector): DblVector = {

  val observations = zipToSeries(volatility, volume)//25
  val regression = new RidgeRegressionDouble
  val fnRegr = regression |> //26
  observations.map( fnRegr(_).get)  //27
}

观察值是从volatilityvolume时间序列中提取的(第 25 行)。将fnRegr岭回归的预测方法(第 26 行)应用于每个观察值(第 27 行)。RSS 值,rss,以不同λ值绘制,如下面的图表所示:

测试案例

铜 ETF 的 RSS 与λ的图形

随着λ的增加,残差平方和减少。曲线似乎在λ ≈ 1 附近达到最小值。λ = 0 的情况对应于最小二乘回归。

接下来,让我们绘制λ在 1 到 100 之间变化的 RSS 值图:

测试案例

铜 ETF 的 RSS 与一个大的 Lambda 值的图形

这次,RSS 值随着λ的增加而增加,在λ > 60 时达到最大值。这种行为与其他发现[6:12]一致。随着λ的增加,过拟合的成本更高,因此 RSS 值增加。

让我们使用不同值的 lambda (λ) 来绘制铜 ETF 的预测价格变化图:

测试案例

岭回归在铜 ETF 价格变化上的图形,λ值可变

铜 ETF 的原价变化 Δ = price(t + 1) - price(t),以λ = 0 绘制。让我们分析不同λ值的预测模型的行为:

  • 预测值 λ = 0.8 与原始数据非常相似。

  • 预测值 λ = 2 沿着原始数据模式,但大幅度变化(峰值和低谷)减少。

  • 预测值 λ = 5 对应于一个平滑的数据集。原始数据模式得以保留,但价格变化的幅度显著降低。

在第一章的让我们试试看部分简要介绍的逻辑回归是下一个要讨论的逻辑回归模型。逻辑回归依赖于优化方法。在深入研究逻辑回归之前,让我们先快速复习一下优化课程。

数值优化

本节简要介绍了可以应用于最小化损失函数的不同的优化算法,无论是否有惩罚项。这些算法在附录 A 的优化技术总结部分有更详细的描述,基本概念

首先,让我们定义最小二乘问题。损失函数的最小化包括消除一阶导数,这反过来又生成一个包含 D 个方程的系统(也称为梯度方程),其中 D 是回归权重(参数)的数量。权重通过使用数值优化算法求解该方程组进行迭代计算。

注意

M10:基于最小二乘法的残差 r[i]、权重 w、模型 f、输入数据 x[i] 和期望值 y[i] 的损失函数定义如下:

数值优化

M10:在损失函数 L 最小化后,生成具有雅可比矩阵 J 的梯度方程(参见附录 A 中的数学部分,基本概念)的定义如下:

数值优化

M11:使用泰勒级数在模型 f 上对权重 w 进行 k 次迭代的迭代近似定义为以下内容:

数值优化

逻辑回归是一个非线性函数。因此,它需要最小化平方和的非线性最小化。非线性最小二乘问题的优化算法可以分为两类:

  • 牛顿(或二阶技术):这些算法通过计算二阶导数(Hessian 矩阵)来计算消除梯度的回归权重。这一类别中最常见的两种算法是高斯-牛顿法和 Levenberg-Marquardt 方法(参见附录 A 中的非线性最小二乘最小化部分,基本概念)。这两种算法都包含在 Apache Commons Math 库中。

  • 拟牛顿(或一阶技术):一阶算法不计算而是估计从雅可比矩阵中得到的平方残差的一阶导数。这些方法可以最小化任何实值函数,而不仅仅是平方和。这一类算法包括 Davidon-Fletcher-Powell 和 Broyden-Fletcher-Goldfarb-Shannon 方法(参见附录 A 中的拟牛顿算法部分,基本概念)。

逻辑回归

尽管名为“逻辑回归”,但它实际上是一个分类器。事实上,由于它的简单性和能够利用大量优化算法的能力,逻辑回归是最常用的判别学习技术之一。该技术用于量化观察到的目标(或预期)变量 y 与它所依赖的一组变量 x 之间的关系。一旦模型创建(训练)完成,它就可以用于对实时数据进行分类。

逻辑回归可以是二项分类(两个类别)或多项分类(三个或更多类别)。在二项分类中,观察到的结果被定义为 {true, false}, {0, 1}, 或 {-1, +1}

逻辑函数

线性回归模型中的条件概率是其权重的线性函数[6:13]。逻辑回归模型通过将条件概率的对数定义为参数的线性函数来解决非线性回归问题。

首先,让我们介绍逻辑函数及其导数,它们被定义为以下(M12):

逻辑函数

逻辑函数及其导数在以下图中展示:

逻辑函数

逻辑函数及其导数的图形

本节的剩余部分致力于将多元逻辑回归应用于二项分类。

二项分类

逻辑回归因其几个原因而广受欢迎;以下是一些原因:

  • 它在大多数统计软件包和开源库中都是可用的

  • 它的 S 形描述了多个解释变量的综合效应

  • 它的值域 [0, 1] 从概率的角度来看是直观的

让我们考虑使用两个类别的分类问题。如第二章中“验证”部分所述[part0165.xhtml#aid-4TBCQ2 "第二章。Hello World!"],“Hello World!”,即使是最好的分类器也会产生假阳性和假阴性。二项分类的训练过程在以下图中展示:

二项分类

一个二维数据集的二项分类示意图

训练的目的是计算将观察值分为两类或类别的 超平面。从数学上讲,n 维空间(特征数量)中的超平面是 n - 1 维的子空间,如第四章 第 4. 无监督学习中的 流形 部分所述,无监督学习。三维空间中的分隔超平面是一个曲面。二维问题(平面)的分隔超平面是一条线。在我们前面的例子中,超平面将训练集分割成两个非常不同的类别(或组),类别 1类别 2,试图减少重叠(假阳性和假阴性)。

超平面的方程定义为回归参数(或权重)与特征的点积的逻辑函数。

逻辑函数强调了由超平面分隔的两个训练观察组之间的差异。它 将观察值推向 分隔超平面的任一类别。

在两个类别 c1c2 及其相应概率的情况下,p(C=c1| X=x[i]|w) = p(x[i]|w)p(C=c2 |X= x[i]|w) = 1- p(x[i]|w),其中 w 是模型参数集或权重,在逻辑回归的情况下。

注意

逻辑回归

M13:给定回归权重 wN 个观察值 x[i] 的对数似然定义为:

二项式分类

M14:使用逻辑函数对 N 个观察值 x[i]d 个特征 {x[ij]} [j=0;d-1] 的条件概率 p(x|w) 定义如下:

二项式分类

M15:对于具有权重 w、输入值 x [i] 和预期二元结果 y 的二项式逻辑回归,其平方误差和 sse 如下:

二项式分类

M16:通过最大化给定输入数据 x[i] 和预期结果(标签)y[i] 的对数似然来计算逻辑回归的权重 w 定义如下:

二项式分类

让我们使用 Apache Commons Math 库来实现不带正则化的逻辑回归。该库包含几个最小二乘优化器,允许您为逻辑回归类 LogisticRegression 中的损失函数指定 optimizer 最小化算法。

LogisticRegression 类的构造函数遵循一个非常熟悉的模式:它定义了一个 ITransform 数据转换,其模型隐式地从输入数据(训练集)中导出,如第二章 第 2. Hello World!中的 单子数据转换 部分所述,Hello World!(行 2)。|> 预测器的输出是一个类别 ID,因此输出 V 的类型是 Int(行 3):

class LogisticRegressionT <: AnyVal  //1
    (implicit f: T => Double)
   extends ITransform[Array[T]](xt) with Regression 
       with Monitor[Double] { //2

  type V = Int //3
    override def train: RegressionModel  //4
  def |> : PartialFunction[Array[T], Try[V]]
}

逻辑回归类的参数是多变量 xt 时间序列(特征),目标或预期类别 expected,以及用于最小化损失函数或残差平方和的 optimizer(行 1)。在二项式逻辑回归的情况下,expected 被分配给一个类为 1,另一个类为 0 的值。

Monitor 特性用于在训练过程中收集配置文件信息(参见 附录 A 中 实用类 下的 Monitor 部分,基本概念)。

训练的目的是确定回归权重,以最小化损失函数,如 M14 公式定义以及残差平方和(行 4)。

注意

目标值

对于二项式逻辑回归,没有将两个值分配给观测数据的特定规则:{-1, +1}{0, 1},或 {false, true}. 值对 {0, 1} 很方便,因为它允许开发者重用用于多项式逻辑回归的代码,并使用归一化类别值。

为了方便,优化器的定义和配置被封装在 LogisticRegressionOptimizer 类中。

设计

逻辑回归的实现使用了以下组件:

  • RegressionModel 模型是 Model 类型的,在分类器的实例化过程中通过训练进行初始化。我们重用了在 线性回归 部分介绍的 RegressionModel 类型。

  • 实现了对未来观测预测的 ITransformLogisticRegression

  • 一个名为 RegressionJacobian 的适配器类,用于计算雅可比矩阵

  • 一个名为 RegressionConvergence 的适配器类,用于管理平方误差和最小化的收敛标准和退出条件

逻辑回归的关键软件组件在以下 UML 类图中描述:

设计

逻辑回归的 UML 类图

UML 图省略了如 Monitor 或 Apache Commons Math 组件之类的辅助特性和类。

训练工作流程

我们对逻辑回归模型的训练实现利用了高斯-牛顿或 Levenberg-Marquardt 非线性最小二乘优化器,这些优化器包含在 Apache Commons Math 库中(参见 附录 A 中的 非线性最小二乘最小化 部分,基本概念)。

逻辑回归的训练是通过 train 方法执行的。

注意

处理 Apache Commons Math 库的异常

使用 Apache Commons Math 库训练逻辑回归需要处理 ConvergenceExceptionDimensionMismatchExceptionTooManyEvaluationsExceptionTooManyIterationsExceptionMathRuntimeException 异常。通过理解这些异常在 Apache 库源代码中的上下文,调试过程将大大简化。

训练方法 train 的实现依赖于以下五个步骤:

  1. 选择并配置最小二乘优化器。

  2. 定义逻辑函数及其雅可比矩阵。

  3. 指定收敛和退出准则。

  4. 使用最小二乘问题构建器计算残差。

  5. 运行优化器。

下面的流程图可视化了用于训练逻辑回归的工作流程和 Apache Commons Math 类:

训练工作流程

使用 Apache Commons Math 库训练逻辑回归的工作流程

前四个步骤是 Apache Commons Math 库在损失函数最小化之前初始化逻辑回归配置所必需的。让我们从最小二乘优化器的配置开始:

def train: RegressionModel = {
  val weights0 = Array.fill(data.head.length +1)(INITIAL_WEIGHT)
  val lrJacobian = new RegressionJacobian(data, weights0) //5
  val exitCheck = new RegressionConvergence(optimizer) //6

  def createBuilder: LeastSquaresProblem  //7
  val optimum = optimizer.optimize(createBuilder) //8
  RegressionModel(optimum.getPoint.toArray, optimum.getRMS)
}

train 方法实现了回归模型计算的最后四个步骤:

  • 计算逻辑值和雅可比矩阵(第 5 行)。

  • 收敛准则的初始化(第 6 行)。

  • 定义最小二乘问题(第 7 行)。

  • 最小化平方误差和(第 8 行)。这是由优化器在 LogisticRegression 构造函数中作为损失函数最小化的一部分来执行的。

第 1 步 – 配置优化器

在这一步,您必须指定用于最小化平方和残差的算法。LogisticRegressionOptimizer 类负责配置优化器。该类具有以下两个目的:

  • 封装优化器的配置参数

  • 调用 Apache Commons Math 库中定义的 LeastSquaresOptimizer 接口

代码如下:

class LogisticRegressionOptimizer(
     maxIters: Int, 
     maxEvals: Int,
     eps: Double, 
     lsOptimizer: LeastSquaresOptimizer) {  //9
  def optimize(lsProblem: LeastSquaresProblem): Optimum = 
       lsOptimizer.optimize(lsProblem)
}

逻辑回归优化器的配置定义为最大迭代次数 maxIters、最大评估次数 maxEval(针对逻辑函数及其导数)、残差平方和的收敛标准 eps 以及最小二乘问题的实例(第 9 行)。

第 2 步 – 计算雅可比矩阵

下一步包括通过覆盖 fitting.leastsquares.MultivariateJacobianFunction Apache Commons Math 接口的 value 方法来计算逻辑函数及其关于权重的第一阶偏导数的值:

class RegressionJacobianT <: AnyVal(implicit f: T => Double) 
  extends MultivariateJacobianFunction {

  type GradientJacobian = Pair[RealVector, RealMatrix]
  override def value(w: RealVector): GradientJacobian = { //11
    val gradient = xv.map( g => { //12
      val f = logistic(dot(g, w))//13
     (f, f*(1.0-f))  //14
   })
   xv.zipWithIndex   //15
    ./:(Array.ofDimDouble) {
     case (j, (x,i)) => {   
       val df = gradient(i)._2
       Range(0, x.size).foreach(n => j(i)(n+1) = x(n)*df)
       j(i)(0) = 1.0; j //16
     }
   }
   (new ArrayRealVector(gradient.map(_._1).toArray), 
      new Array2DRowRealMatrix(jacobian))  //17
  }
}

RegressionJacobian 类的构造函数需要以下两个参数(第 10 行):

  • 观察的 xv 时间序列

  • weights0 初始回归权重

value方法使用在org.apache.commons.math3.linear Apache Commons Math 包中定义的原始类型RealVectorRealMatrixArrayRealVectorArray2DRowRealMatrix(第11行)。它将w回归权重作为参数,计算每个数据点的逻辑函数的gradient(第12行),并返回值及其导数(第14行)。

雅可比矩阵被填充为逻辑函数的导数值(第15行)。雅可比矩阵每一列的第一个元素被设置为1.0以考虑截距(第16行)。最后,value函数使用符合 Apache Commons Math 库中value方法签名的类型返回梯度值对和雅可比矩阵(第17行)。

第 3 步 – 管理优化器的收敛

第 3 步定义了优化器的退出条件。这是通过在org.apache.commons.math3.optim Java 包中覆盖参数化ConvergenceChecker接口的converged方法来实现的:

val exitCheck = new ConvergenceChecker[PointVectorValuePair] {
  override def converged(
      iters: Int, 
      prev: PointVectorValuePair, 
      current:PointVectorValuePair): Boolean =  
   sse(prev.getValue, current.geValue) < optimizer.eps 
           && iters >= optimizer.maxIters //18
}

此实现计算收敛或退出条件如下:

  • 两次连续迭代权重之间的sse平方误差之和小于eps收敛标准

  • iters值超过了允许的最大迭代次数maxIters(第18行)

第 4 步 – 定义最小二乘问题

Apache Commons Math 最小二乘优化器包要求所有输入到非线性最小二乘求解器的都定义为由LeastSquareBuilder工厂类生成的LeastSquareProblem实例:

def createBuilder: LeastSquaresProblem = 
   (new LeastSquaresBuilder).model(lrJacobian)    //19
   .weight(MatrixUtils.createRealDiagonalMatrix(
             Array.fill(xt.size)(1.0)))  //20
   .target(expected.toArray) //21
   .checkerPair(exitCheck)  //22
   .maxEvaluations(optimizer.maxEvals)  //23
   .start(weights0)  //24
   .maxIterations(optimizer.maxIters) //25
   .build

权重矩阵的对角元素被初始化为1.0(第20行)。除了使用lrJacobian雅可比矩阵初始化模型(第19行)之外,方法调用的序列设置了最大评估次数(第23行)、最大迭代次数(第25行)和退出条件(第22行)。

回归权重使用LogisticRegression构造函数的weights0权重作为参数进行初始化(第24行)。最后,预期的或目标值被初始化(第21行)。

第 5 步 – 最小化平方误差之和

训练通过简单调用lsp最小二乘法求解器执行:

val optimum = optimizer.optimize(lsp)
(optimum.getPoint.toArray, optimum.getRMS)

回归系数(或权重)以及残差均方RMS)是通过在 Apache Commons Math 库的Optimum类上调用getPoint方法返回的。

测试

让我们通过使用前两节中提到的铜 ETF 的价格变动与波动性和成交量示例来测试我们的二项式多元逻辑回归实现。唯一的区别是我们需要将目标值定义为如果 ETF 价格在连续两个交易日之间下降为 0,否则为 1:

import YahooFinancials._ 
val maxIters = 250
val maxEvals = 4500
val eps = 1e-7

val src = DataSource(path, true, true, 1)  //26
val optimizer = new LevenbergMarquardtOptimizer  //27

for {
  price <- src.get(adjClose) //28
  volatility <- src.get(volatility)  //29
  volume <- src.get(volume)  //30
  (features, expected) <- differentialData(volatility, 
volume, price, diffInt) //31
  lsOpt <- LogisticRegressionOptimizer(maxIters, maxEvals, 
                        eps, optimizer) //32
  regr <- LogisticRegressionDouble      
  pfnRegr <- Try(regr |>) //33
} 
yield {
   show(s"${LogisticRegressionEval.toString(regr)}")
   val predicted = features.map(pfnRegr(_))
   val delta = predicted.view.zip(expected.view)
            .map{case(p, e) => if(p.get == e) 1 else 0}.sum
   show(s"Accuracy: ${delta.toDouble/expected.size}")
}

让我们看看执行测试所需的步骤,该测试包括收集数据、初始化平方误差最小化参数、训练逻辑回归模型和运行预测:

  1. 创建一个src数据源以提取市场和交易数据(第26行)。

  2. 选择LevenbergMarquardtOptimizer Levenberg-Marquardt 算法作为optimizer(第27行)。

  3. 加载 ETF 的每日收盘价格(第28行),交易时段内的波动性(第29行),以及 ETF 的成交量每日交易(第30行)。

  4. 生成标记数据作为特征对(ETF 的相对波动性和相对成交量)以及expected结果{0, 1},用于训练模型,其中1表示价格上涨,0表示价格下跌(第31行)。XTSeries单例的differentialData通用方法在第三章的时间序列在 Scala 中部分进行描述,数据预处理)。

  5. 实例化lsOpt优化器以在训练过程中最小化平方误差之和(第32行)。

  6. 训练regr模型并返回pfnRegr预测部分函数(第33行)。

有许多可用的替代优化器可以最小化平方误差优化器(请参阅附录 A 中的非线性最小二乘法部分,基本概念)。

注意

Levenberg-Marquardt 参数

驱动代码使用默认调整参数配置的LevenbergMarquardtOptimizer以保持实现简单。然而,该算法有几个重要的参数,如成本和矩阵求逆的相对容差,对于商业应用来说值得调整(请参阅附录 A 下的非线性最小二乘法中的Levenberg-Marquardt部分,基本概念)。

测试执行产生了以下结果:

  • 残差均方为 0.497

  • 权重为截距-0.124,ETF 波动性 0.453,ETF 成交量-0.121

最后一步是对实时数据的分类。

分类

如前所述,尽管名为二项式逻辑回归,但实际上它是一个二元分类器。分类方法实现为一个隐式数据转换|>

val HYPERPLANE = - Math.log(1.0/INITIAL_WEIGHT -1)
def |> : PartialFunction[Array[T], Try[V]] = {
  case x: Array[T] if(isModel && 
      model.size-1 == x.length && isModel)  => 
       Try (if(dot(x, model) > HYPERPLANE) 1 else 0 ) //34
}

观测xweights模型的点积(或内积)与超平面进行比较。如果产生的值超过HYPERPLANE,则预测类别为1,否则为0(第34行)。

注意

类别识别

新数据x所属的类别由dot(x, weights) > 0.5测试确定,其中dot是特征与回归权重(w[0]+w[1].volatility + w[2].volume)的乘积。你可能在科学文献中找到不同的分类方案。

铜 ETF 价格变动的方向,CU price(t+1) – price(t),与逻辑回归预测的方向进行比较。如果正确预测了正向或负向,则结果以成功值绘制;否则,以失败值绘制:

分类

使用逻辑回归预测铜 ETF 价格变动的方向

逻辑回归能够对 121 个交易时段中的 78 个进行分类(准确率达到 65%)。

现在,让我们使用逻辑回归来预测铜 ETF 的正向价格变动,考虑到其波动性和交易量。这种交易或投资策略被称为在市场上多头。这个特定的用例忽略了价格要么持平要么下跌的交易时段:

分类

使用逻辑回归预测铜 ETF 价格变动的方向

逻辑回归能够正确预测 64 个交易时段中的 58 个时段的正向价格变动(准确率达到 90.6%)。第一次和第二次测试案例之间的区别是什么?

在第一种情况下,w[0] + w[1].volatility + w[2].volume分离超平面方程用于将产生正向或负向价格变动的特征分开。因此,分类的整体准确率受到两个类别特征重叠的负面影响。

在第二种情况下,分类器只需考虑超平面方程的正侧的观察结果,而不考虑假阴性。

注意

舍入误差的影响

在某些情况下,计算雅可比矩阵时产生的舍入误差会影响w[0] + w[1].volatility + w[2].volume分离超平面方程的准确性。这降低了预测正向和负向价格变动的准确性。

通过考虑价格的正向变动并使用边缘误差 EPS(price(t+1) – price(t) > EPS)来进一步提高二元分类器的准确度。

注意

验证方法

验证集是通过从原始标记数据集中随机选择观察结果生成的。正式验证需要你使用 K 折验证方法来计算逻辑回归模型的召回率、精确率和 F1 度量。

摘要

这就完成了线性回归和逻辑回归以及正则化概念的描述和实现,以减少过拟合。你的第一个使用机器学习的分析项目(或已经涉及)可能涉及某种类型的回归模型。回归模型,连同朴素贝叶斯分类,是那些没有深入了解统计学或机器学习的人最理解的技巧。

在完成本章之后,你可能会希望掌握以下主题:

  • 基于线性和非线性最小二乘优化的概念

  • 普通最小二乘回归以及逻辑回归的实现

  • 正则化的影响,以及岭回归的实现

逻辑回归也是条件随机字段的基础,如第七章中条件随机字段部分所述,序列数据模型,以及多层感知器,这在第九章中多层感知器部分介绍。

与朴素贝叶斯模型(参见第五章,朴素贝叶斯分类器)相反,最小二乘法或逻辑回归并不强加条件,即特征必须是独立的。然而,回归模型并没有考虑到时间序列(如资产定价)的序列性质。下一章,即专门针对序列数据模型的章节,介绍了两种考虑时间序列时间依赖性的分类器:隐藏马尔可夫模型和条件随机字段。

第七章。序列数据模型

马尔可夫模型的世界是广阔的,包括马尔可夫决策过程、离散马尔可夫、贝叶斯网络的马尔可夫链蒙特卡洛以及隐藏马尔可夫模型等计算概念。

马尔可夫过程,更具体地说,隐藏马尔可夫模型HMM),在语音识别、语言翻译、文本分类、文档标记、数据压缩和解码中通常被使用。

本章的第一部分介绍了并描述了使用 Scala 实现隐藏马尔可夫模型的三种经典形式的完整实现。本节涵盖了在隐藏马尔可夫模型的评估、解码和训练中使用的不同动态规划技术。分类器的设计遵循与逻辑回归和线性回归相同的模式,如第六章中所述,回归和正则化

本章的第二部分和最后一部分致力于隐马尔可夫模型的判别性(标签对观察条件)替代方案:条件随机字段。印度理工学院孟买分校的 Sunita Sarawagi 编写的开源 CRF Java 库用于创建使用条件随机字段的预测模型 [7:1]。

马尔可夫决策过程

本节还描述了您需要了解的基本概念,以便理解、开发和应用隐马尔可夫模型。马尔可夫宇宙的基础是被称为马尔可夫性质的概念。

马尔可夫性质

马尔可夫性质是随机过程的特征,其中未来状态的条件概率分布取决于当前状态,而不是其过去状态。在这种情况下,状态之间的转换发生在离散时间,马尔可夫性质被称为离散马尔可夫链

第一阶离散马尔可夫链

以下示例摘自 《机器学习导论》,E. Alpaydin [7:2]。

让我们考虑以下用例。N 个不同颜色的球隐藏在 N 个盒子中(每个盒子一个)。球只能有三种颜色(蓝色、红色和绿色)。实验者逐个抽取球。发现过程的状态由从其中一个盒子中抽取的最新球的颜色定义:S[0] = 蓝色,S[1] = 红色,S[2] = 绿色

{π[0], π[1], π[2]} 为每个盒子中初始颜色集合的初始概率。

q[t] 表示在时间 t 抽取的球的颜色。在时间 t 抽取颜色为 S j 的球之后,在时间 k 抽取颜色为 S[k] 的球的概率定义为 p(q[t] = S[k] q[t-1] = S[j]) = a[jk]。第一次尝试抽取红色球的概率为 p(q[t0] = S[1]) = π[1]。第二次尝试抽取蓝色球的概率为 p(q[0] = S[1]) p(q[1] = S[0]|q[0] = S[1]) = π[1] a[10]。该过程重复进行,以创建以下概率的状态序列 {S[t]} = {红色,蓝色,蓝色,绿色,…}p(q[[0]] = S[1]).p(q[1] = S[0]|q[0] = S[1]).p(q[2] = S[0]|q[1] = S[0]).p(q[3] = S[2]|q[2] = S[0])… = π[1].a[10].a[00].a2…。状态/颜色的序列可以表示如下:

第一阶离散马尔可夫链

球和盒子示例的插图

让我们使用历史数据(学习阶段)估计概率 p

  • 在第一次尝试中抽取红色球(S[1])的概率估计为 π[1],其计算方式为以 S[1] (红色)开头的序列数除以球的总数。

  • 在第二次尝试中检索蓝色球的概率估计为 a[10],即在红色球之后抽取蓝色球的序列数除以序列总数,依此类推。

注意

N 阶马尔可夫

马尔可夫属性之所以流行,主要是因为其简单性。当你学习隐藏马尔可夫模型时,你会发现状态仅依赖于前一个状态,这使我们能够应用有效的动态规划技术。然而,一些问题需要超过两个状态之间的依赖关系。这些模型被称为马尔可夫随机场。

虽然离散马尔可夫过程可以应用于试错类型的应用,但其适用性仅限于解决那些观测不依赖于隐藏状态的问题。隐藏马尔可夫模型是解决此类挑战的常用技术。

隐藏马尔可夫模型

隐藏马尔可夫模型在语音识别、人脸识别(生物识别)以及图片和视频中的模式识别等方面有众多应用[7:3]。

隐藏马尔可夫模型由一个具有离散时间的观测马尔可夫过程(也称为马尔可夫链)组成。与马尔可夫过程的主要区别在于状态是不可观测的。每次系统或模型的状态改变时,都会以称为发射概率的概率发射一个新的观测。

存在两种随机性的来源,如下所述:

  • 状态之间的转换

  • 当给定状态时观测的发射

让我们重用盒子与球体的例子。如果盒子是隐藏状态(不可观测的),那么用户抽取的球的颜色是不可见的。发射概率是检索到颜色为k的球从隐藏盒子I的概率b[ik] =p(o[t] = colork|q[t] =S[i]),如下图所示:

隐藏马尔可夫模型

球和盒子示例的隐藏马尔可夫模型

在这个例子中,我们不假设所有盒子都包含不同颜色的球。我们不能对由转换a[ij]定义的顺序做出任何假设。HMM 不假设颜色(观测)的数量与盒子(状态)的数量相同。

注意

时间不变性

与卡尔曼滤波器等不同,隐藏马尔可夫模型要求转换元素a[ji]与时间无关。这一属性被称为平稳性或齐次性限制。

请记住,在这个例子中,观测(球的颜色)是实验者唯一可用的有形数据。从这个例子中,我们可以得出结论,一个正式的 HMM 有三个组件:

  • 一组观测

  • 一系列隐藏状态

  • 最大化观测和隐藏状态联合概率的模型,称为 Lambda 模型

Lambda 模型,λ,由初始概率π、由矩阵A定义的状态转换概率以及状态发射一个或多个观测的概率组成,如下图所示:

隐藏马尔可夫模型

HMM 关键组件的可视化

上述图示说明了,给定一个观察值序列,HMM 处理以下三个被称为规范形式的问题:

  • CF1 (评估):这评估了给定观察序列 O**t,在模型 λ = (π, A, B) 下的概率

  • CF2 (训练):这识别(或学习)给定观察值集合 O 的模型 λ = (π, A, B)

  • CF3 (解码):这估计了生成给定观察值集合 O 和模型 λ 的最高概率的状态序列 Q

解决这三个问题的解决方案使用了动态规划技术。然而,在深入隐藏马尔可夫模型的数学基础之前,我们需要明确符号。

符号

描述隐藏马尔可夫模型的一个挑战是数学符号有时会因作者而异。从现在开始,我们将使用以下符号:

描述 公式
N 隐藏状态的数量
S 一个包含 N 个隐藏状态的有限集合 S = {S[0], S[1], … S[N-1]}
M 观察符号的数量
qt 时间或步骤 t 的状态
Q 状态的时间序列 Q = {q[0], q[1], … q[n-1]} = Q[0:n-1]
T 观察值的数量
ot 时间 t 的观察值
O 一个由 T 个观察值组成的有限序列 O = {o[0], o[1], … o[T-1]} = O[0:T-1]
A 状态转移概率矩阵 *a[ji] = p(q[t+1]=S[i]
B 发射概率矩阵 *b[jk] = p(o[t]=O[k]
π 初始状态概率向量 π[i] = p(q[0]=S[j])
λ 隐藏马尔可夫模型 λ = (π, A, B)

注意

符号的方差

一些作者使用符号 z 来表示隐藏状态,而不是 q,使用 x 来表示观察值 O

为了方便起见,让我们使用简化的形式来表示观察值和状态的序列:p(O[0:T], q[t]| λ) = p(O[0], O[1], … O[T], q[t]| λ)。用网格状的状态和观察值来可视化隐藏马尔可夫模型是很常见的,这与我们描述的箱子和球例子类似,如下所示:

符号

正式的 HMM 有向图

状态 Si 在时间 t 被观察到 O[k],然后过渡到状态 S[j],在时间 t+1 被观察到 O[m]。创建我们的 HMM 的第一步是定义实现 lambda 模型 λ = (π, A, B) [7:4] 的类。

Lambda 模型

隐藏马尔可夫模型的三个规范形式在很大程度上依赖于矩阵和向量的操作。为了方便起见,让我们定义一个 HMMConfig 类,它包含 HMM 中使用的维度:

class HMMConfig(val numObs: Int, val numStates: Int, 
    val numSymbols: Int, val maxIters: Int, val eps: Double) 
    extends Config

类的输入参数如下:

  • numObs:这是观察值的数量

  • numStates:这是隐藏状态的数量

  • numSymbols:这是观测符号或特征的数量

  • maxIters:这是 HMM 训练所需的最大迭代次数

  • eps:这是 HMM 训练的收敛标准

注意

与数学符号的一致性

实现使用 numObs(相对于 numStatesnumSymbols)来程序化地表示观测数量 T(相对于 N 个隐藏状态和 M 个特征)。一般来说,实现尽可能地重用数学符号。

HMMConfig 伴随对象定义了对矩阵行和列索引范围的运算。foreach(行 1),foldLeft (/:)(行 2),和 maxBy(行 3)方法在每个三个规范形式中经常使用:

object HMMConfig {
  def foreach(i: Int, f: Int => Unit): Unit =
      Range(0, i).foreach(f)  //1
  def /:(i: Int, f: (Double, Int) => Double, zero: Double) = 
        Range(0, i)./:(zero)(f) //2
  def maxBy(i: Int, f: Int => Double): Int = 
      Range(0,i).maxBy(f)   //3
   … 
}

注意

λ 符号

HMM 中的 λ 模型不应与 第六章 中 L[n] 粗糙度惩罚 部分讨论的正则化因子混淆,回归和正则化

如前所述,λ 模型定义为转移概率矩阵 A,发射概率矩阵 B 和初始概率 π 的元组。它可以通过在 附录 A 中定义的 DMatrix 类轻松实现为 HMMModel 类,如 实用类 部分所述,基本概念。当状态转移概率矩阵、发射概率矩阵和初始状态已知时,将调用 HMMModel 类的最简单构造函数,如下面的代码所示:

class HMMModel( val A: DMatrix, val B: DMatrix, var pi: DblArray, 
    val numObs: Int) {   //4
  val numStates = A.nRows
  val numSymbols = B.nCols

  def setAlpha(obsSeqNum: Vector[Int]): DMatrix
  def getAlphaVal(a: Double, i: Int, obsId: Int): Double
  def getBetaVal(b: Double, i: Int, obsId: Int): Double
  def update(gamma: Gamma, diGamma: DiGamma, 
      obsSeq: Vector[Int])
  def normalize: Unit
}

HMMModel 类的构造函数有以下四个参数(行 4):

  • A:这是状态转移概率矩阵

  • B:这是缺失概率矩阵

  • pi:这是状态的初始概率

  • numObs:这是观测的数量

从矩阵 AB 的维度中提取了状态和符号的数量。

HMMModel 类有几个方法,当它们在执行模型时需要时将详细描述。对于 pi 初始状态的概率是未知的,因此,它们使用值 [0, 1] 的随机生成器进行初始化。

注意

归一化

在我们初始化 AB 矩阵之前,输入状态和观测数据可能需要归一化并转换为概率。

HMM 的其他两个组成部分是观测序列和隐藏状态序列。

设计

HMM 的规范形式通过动态规划技术实现。这些技术依赖于定义 HMM 任何规范形式的执行状态的变量:

  • Alpha(正向传递):在特定状态 S[i] 下,观察第一个 t < T 观测的概率是 αt = p(O[0:t], q[t]=S[i]|λ)

  • Beta(反向传递):对于特定状态,观察序列 q**t 的剩余部分的概率是 βt =p(O[t+1:T-1]|q[t]=S[i],λ)

  • Gamma: 在给定观测序列和模型的情况下,处于特定状态的概率是 γt =p(q[t]=S[i]|O[0:T-1], λ)

  • Delta: 这是对于特定测试定义的第一个 i 个观测的最高概率路径序列,定义为 δt

  • Qstar: 这是最优状态序列 q* 的 Q[0:T-1]

  • DiGamma: 在给定观测序列和模型的情况下,处于特定状态 t 和另一个定义状态 t + 1 的概率是 γt = p(q[t]=S[i],q[t+1]=S[j]|O[0:T-1], λ)

每个参数都在与每个特定标准形式相关的部分以数学和程序方式描述。在评估标准形式中使用了GammaDiGamma类,并进行了描述。DiGamma单例作为维特比算法的一部分进行描述,用于根据给定的 λ 模型和一组观测值提取具有最高概率的状态序列。

在任何三种标准形式中使用的与动态规划相关的算法列表通过我们实现 HMM 的类层次结构进行可视化:

设计

HMM 的 Scala 类层次结构(UML 类图)

UML 图省略了Monitor或 Apache Commons Math 组件等实用特性类。

λ 模型、HMM 状态和观测序列都是实现三种标准情况所需的元素。每个类都在 HMM 三种标准形式的描述中按需进行描述。现在是时候深入研究每种标准形式的具体实现细节了,从评估开始。

任何三种标准形式的执行都依赖于动态规划技术(请参阅附录 A 中的动态规划概述部分,基本概念)[7:5]。动态规划技术中最简单的是对观测/状态链的单次遍历。

评估 – CF-1

目标是计算给定 λ 模型观察序列 O[t] 的概率(或似然)。使用动态规划技术将观测序列的概率分解为两个概率(M1):

评估 – CF-1

似然是通过对所有隐藏状态 {S[i]} [7:6] 进行边缘化来计算的(M2):

评估 – CF-1

如果我们使用前一章中引入的 alpha 和 beta 变量的符号,给定 λ 模型观察序列 O[t] 的概率可以表示如下(M3):

评估 – CF-1

αβ 概率的乘积可能会下溢。因此,建议您使用概率的对数而不是概率本身。

Alpha – 前向传递

给定一个隐藏状态序列和一个 λ 模型,计算观察特定序列的概率依赖于一个两遍算法。alpha 算法包括以下步骤:

  1. 计算初始 alpha 值 [M4]。然后通过所有隐藏状态的 alpha 值之和进行归一化 [M5]。

  2. 依次计算从时间 0 到时间 t 的 alpha 值,然后通过所有状态的 alpha 值之和进行归一化 [M6]。

  3. 最终一步是计算观察序列 [M7] 的概率的对数。

备注

性能考虑

直接计算观察特定序列的概率需要 2TN[2] 次乘法。迭代的 alpha 和 beta 类将乘法次数减少到 N[2]T

对于那些对数学有些倾向的人来说,alpha 矩阵的计算定义在以下信息框中。

备注

Alpha(前向传递)

M4:初始化定义为:

Alpha – 前向传递

M5:初始值归一化 N – 1 定义为:

Alpha – 前向传递

M6:归一化求和定义为:

Alpha – 前向传递

M7:给定一个 lambda 模型和状态,观察一个序列的概率定义为:

Alpha – 前向传递

让我们看看 Scala 中 alpha 类的实现,使用 alpha 类的数学表达式引用编号。alpha 和 beta 值必须归一化 [M3],因此我们定义了一个 HMMTreillis 基类,用于 alpha 和 beta 算法,并实现了归一化:

class HMMTreillis(numObs: Int, numStates: Int){ //5
  var treillis: DMatrix = _   //6
  val ct = Array.fill(numObs)(0.0) 

   def normalize(t: Int): Unit = { //7
     ct.update(t, /:(numStates, (s, n) => s + treillis(t, n)))
     treillis /= (t, ct(t))
   }
   def getTreillis: DMatrix = treillis
}

HMMTreillis 类有两个配置参数:观察数 numObs 和状态数 numStates(第 5 行)。treillis 变量代表在 alpha(或前向)和 beta(或后向)传递中使用的缩放矩阵(第 6 行)。

归一化方法 normalize 通过重新计算 ct 缩放因子(第 7 行)实现了 M6 公式。

备注

计算效率

Scala 的 reducefoldforeach 方法比 for 循环更高效的迭代器。你需要记住,Scala 中 for 循环的主要目的是 mapflatMap 操作的单一组合。

Alpha 类中计算 alpha 变量的计算流程与在 M4M5M6 数学表达式中定义的计算流程相同:

class Alpha(lambda: HMMModel, obsSeq: Vector[Int]) //8
    extends HMMTreillis(lambda.numObs, lambda.numStates) {

  val alpha: Double = Try { 
    treillis = lambda.setAlpha(obsSeq) //9
    normalize(0)  //10
    sumUp  //11
  }.getOrElse(Double.NaN)

  override def isInitialized: Boolean = alpha != Double.NaN

  val last = lambda.numObs-1
  def sumUp: Double = {
    foreach(1, lambda.numObs, t => {
      updateAlpha(t) //12
      normalize(t)  //13
    })
    /:(lambda.numStates, (s,k) => s + treillis(last, k))
  }

  def updateAlpha(t: Int): Unit = 
    foreach(lambda.numStates, i => { //14
      val newAlpha = lambda.getAlphaVal(treillis(t-1, i)
      treillis += (t, i, newAlpha, i, obsSeq(t))) 
    })

  def logProb: Double = /:(lambda.numObs, (s,t) => //15
    s + Math.log(ct(t)), Math.log(alpha))
}

Alpha类有两个参数:lambda模型和观测值序列obsSeq(行8)。缩放因子alpha的定义初始化treillis缩放矩阵使用HMMModel.setAlpha方法(行9),通过调用HMMTreillis.normalize方法对第一个观测值进行归一化以初始化矩阵的初始值(行10),并通过调用sumUp来返回缩放因子(行11)。

setAlpha方法实现了数学表达式M4,如下所示:

def setAlpha(obsSeq: Array[Int]): DMatrix = 
  Range(0,numStates)./:(DMatrix(numObs, numStates))((m,j) => 
      m += (0, j, pi(j)*B(j, obsSeq.head)))
}

折叠生成了一个DMatrix类的实例,正如在附录 A 中实用类部分所描述的,基本概念

sumUp方法实现了数学表达式M6,如下所示:

  • updateAlpha方法中更新缩放因子的treillis矩阵(行12

  • 标准化所有剩余观测值的缩放因子(行13

updateAlpha方法通过计算所有状态的alpha因子来更新treillis缩放矩阵(行14)。logProb方法实现了数学表达式M7。它计算在给定状态序列和预定义的λ模型下观察特定序列的概率的对数(行15)。

注意

对数概率

logProb方法计算概率的对数而不是概率本身。概率对数的和比概率的乘积更不容易导致下溢。

Beta – 反向传播

beta 值的计算类似于Alpha类,除了迭代在状态序列上反向执行。

Beta的实现与 alpha 类相似:

  1. 计算(M5)并标准化(M6)时间t = 0的 beta 值在各个状态之间。

  2. 递归地计算并标准化时间T - 1t的 beta 值,该值从t + 1的值更新(M7)。

注意

Beta(反向传播

M8: 初始化 beta βT-1 = 1

M9: 初始 beta 值的归一化定义为:

Beta – 反向传播

M10: 归一化 beta 值的加和定义为:

Beta – 反向传播

Beta类的定义与Alpha类非常相似:

class Beta(lambda: HMMModel, obsSeq: Vector[Int]) 
     extends HMMTreillis(lambda.numObs, lambda.numStates) {

  val initialized: Boolean  //16

  override def isInitialized: Boolean = initialized
  def sumUp: Unit =   //17
    (lambda.numObs-2 to 0 by -1).foreach(t => { //18
      updateBeta(t)  //19
      normalize(t) 
    })

   def updateBeta(t: Int): Unit =
     foreach(lambda.numStates, i => { 
       val newBeta = lambda.getBetaVal(treillis(t+1, i)
       treillis += (t, i, newBeta, i, obsSeq(t+1))) //20
     })
}

Alpha类相反,Beta类不生成输出值。Beta类有一个initialized布尔属性,用来指示构造函数是否成功执行(行16)。构造函数通过遍历从最后一个观测值之前到第一个观测值的观测序列来更新和归一化 beta 矩阵:

val initialized: Boolean = Try {
  treillis = DMatrix(lambda.numObs, lambda.numStates)
  treillis += (lambda.numObs-1, 1.0) //21
  normalize(lambda.numObs-1)  //22
  sumUp  //23
}._toBoolean("Beta initialization failed")

DMatrix 类的 treillis beta 缩放矩阵的初始化将值 1.0 分配给最后一个观察值(第 21 行),并按 M8(第 22 行)中定义的归一化 beta 值。它通过调用 sumUp 方法实现数学表达式 M9M10(第 23 行)。

sumUp 方法与 Alpha.sumUp(第 17 行)类似。它从观察序列的末尾开始遍历(第 18 行),并更新定义在数学表达式 M9(第 19 行)中的 beta 缩放矩阵。updateBeta 方法中数学表达式 M10 的实现与 alpha 过程类似:它使用在 lambda 模型中计算的 newBeta 值更新 treillis 缩放矩阵(第 20 行)。

注意

构造函数和初始化

alpha 和 beta 值在各自类的构造函数中计算。客户端代码必须通过调用 isInitialized 验证这些实例。

如果一个模型无法创建,它的值是多少?下一个规范形式 CF2 利用动态规划和递归函数提取 λ 模型。

训练 – CF-2

该规范形式的目标是在给定一组观察值和状态序列的情况下提取 λ 模型。它与分类器的训练类似。当前状态对先前状态的简单依赖性使得可以使用迭代过程实现,称为 Baum-Welch 估计器期望最大化EM)。

Baum-Welch 估计器(EM)

在其核心,该算法由三个步骤和一个迭代方法组成,类似于评估规范形式:

  1. 计算概率 πt = 0 时的 gamma 值)(M11)。

  2. 计算并归一化状态的过渡概率矩阵 AM12)。

  3. 计算并归一化发射概率矩阵 BM13)。

  4. 重复步骤 2 和 3,直到似然度的变化不显著。

该算法使用 digamma 和求和 gamma 类。

注意

Baum-Welch 算法

M11:状态 q[i]t 时刻和状态 q[j]t+1 时刻的联合概率(digamma)定义为:

Baum-Welch 估计器(EM)Baum-Welch 估计器(EM)

M12:初始概率向量 N−1 和所有状态的联合概率之和(gamma)定义为:

Baum-Welch 估计器(EM)

M13:过渡概率矩阵的更新定义为:

Baum-Welch 估计器(EM)

M14:发射概率矩阵的更新定义为:

Baum-Welch 估计器(EM)

Baum-Welch 算法在 BaumWelchEM 类中实现,并需要以下两个输入(第 24 行):

  • config 配置计算得到的 λ 模型

  • 观察序列 obsSeq(向量)

代码如下:

class BaumWelchEM(config: HMMConfig, obsSeq: Vector[Int]) { //24
  val lambda = HMMModel(config)
  val diGamma = new DiGamma(lambda.numObs,lambda.numStates)//25
  val gamma = new Gamma(lambda.numObs, lambda.numStates) //26
  val maxLikelihood: Option[Double] //27
}

DiGamma类定义了任何连续状态之间的联合概率(行25):

class DiGamma(numObs: Int, numStates: Int) {
  val diGamma = Array.fill(numObs-1)(DMatrix(numStates))
  def update(alpha: DMatrix, beta: DMatrix, A: DMatrix, 
  B: DMatrix, obsSeq: Array[Int]): Try[Int]
}

diGamma变量是一个矩阵数组,表示两个连续状态的联合概率。它通过调用update方法初始化,该方法实现了数学表达式M11

Gamma类计算所有状态之间的联合概率总和(行26):

class Gamma(numObs: Int, numStates: Int) {
  val gamma = DMatrix(numObs, numStates)
  def update(alpha: DMatrix, beta: DMatrix): Unit
}

Gamma类的update方法实现了数学表达式M12

注意

Gamma 和 DiGamma 的源代码

GammaDiGamma类实现了 Baum-Welch 算法的数学表达式。update方法使用简单的线性代数,此处未描述;请参阅文档化的源代码以获取详细信息。

给定现有的 lambda 模型和观察序列(行27),使用getLikelihood尾递归方法计算状态序列的最大似然maxLikelihood,如下所示:

val maxLikelihood: Option[Double] = Try {

  @tailrec
  def getLikelihood(likelihood: Double, index: Int): Double ={
    lambda.update(gamma, diGamma, obsSeq) //28
    val _likelihood = frwrdBckwrdLattice   //29
    val diff = likelihood - _likelihood

    if( diff < config.eps ) _likelihood    //30
    else if (index >= config.maxIters)  //31
      throw new IllegalStateException(" … ")
    else getLikelihood(_likelihood, index+1) 
  }

  val max = getLikelihood(frwrdBckwrdLattice, 0)
  lambda.normalize   //32
  max
}._toOption("BaumWelchEM not initialized", logger)

maxLikelihood值实现了数学表达式M13M14getLikelihood递归方法更新 lambda 模型矩阵AB以及初始状态概率pi(行28)。使用frwrBckwrdLattice方法实现的向前向后格算法重新计算状态序列的似然(行29)。

注意

lambda 模型的更新

HMMModel对象的update方法使用简单的线性代数,此处未描述;请参阅文档化的源代码以获取详细信息。

Baum-Welch 期望最大化算法的核心是迭代地更新时间tt + 1之间状态和观察值的格。基于格的迭代计算在以下图中展示:

Baum-Welch 估计器(EM)

Baum-Welch 算法的 HMM 图格可视化

代码如下:

def frwrdBckwrdLattice: Double  = {
  val _alpha = Alpha(lambda, obsSeq) //33
  val beta = Beta(lambda, obsSeq).getTreillis //34
  val alphas = _alpha.getTreillis
  gamma.update(alphas, beta) //35
  diGamma.update(alphas, beta, lambda.A, lambda.B, obsSeq)
  _alpha.alpha
}

向前向后算法在正向传递中使用Alpha类进行lambda模型的计算/更新(行33),在反向传递中使用Beta类更新lambda(行34)。与联合概率相关的gammadiGamma矩阵在每个递归时更新(行35),反映了数学表达式M11M14的迭代。

如果算法收敛(行30),则存在maxLikelihood的递归计算。如果超过最大递归次数,则抛出异常(行31)。

解码 – CF-3

最后这个规范形式包括在给定一组观察值O[t]λ模型的情况下,提取最可能的状态序列{q[t]}。解决此问题需要再次使用递归算法。

Viterbi 算法

提取最佳状态序列(具有最高概率的状态序列)非常耗时。一种替代方法是应用动态规划技术通过迭代找到最佳序列{q[t]}。这个算法被称为维特比算法。给定状态序列{q[t]}和观察序列{o[j]},对于状态S[i],定义了任何序列在第一个T观察中的最高概率路径的概率δt [7:7]。

注意

维特比算法

M12:delta 函数的定义如下:

维特比算法

M13:delta 的初始化定义为:

维特比算法

M14:delta 的递归计算定义为:

维特比算法

M15:最优状态序列{q}的计算定义为:

维特比算法

ViterbiPath类实现了维特比算法,其目的是在给定一组观察和λ模型的情况下计算最优状态序列(或路径)。最优状态序列或路径是通过最大化 delta 函数来计算的。

ViterbiPath类的构造函数与 forward、backward 和 Baum-Welch 算法的参数相同:lambda模型和观察集obsSeq

class ViterbiPath(lambda: HMMModel, obsSeq: Vector[Int]) {
  val nObs = lambda.numObs
  val nStates = lambda.numStates
  val psi = Array.fill(nObs)(Array.fill(nStates)(0)) //35
  val qStar = new QStar(nObs, nStates) //36

  val delta = { //37
    Range(0, nStates)./:(DMatrix(nObs, nStates))((m,n) => {
     psi(0)(n) = 0
     m += (0, n, lambda.pi(n) * lambda.B(n,obsSeq.head))
    })
  val path = HMMPrediction(viterbi(1), qStar()) //38
}

如前所述,包含维特比算法数学表达式的信息框中,必须定义以下矩阵:

  • psi:这是由nObs个观察的索引和nStates个状态的索引组成的矩阵(第35行)。

  • qStar:这是维特比算法每次递归的最优状态序列(第36行)。

  • delta:这是具有最高概率路径的前n个观察序列。它还设置了第一个观察的psi值为 0(第37行)。

ViterbiPath类的所有成员都是私有的,除了定义给定obsSeq观察的最优状态序列或路径的path(第38行)。

使用数学表达式M13(第37行)初始化定义了给定lambda模型和obsSeq观察的最大概率delta矩阵。预测模型返回路径或最优状态序列作为HMMPrediction实例:

case class HMMPrediction(likelihood: Double, states: Array[Int])

likelihood的第一个参数是通过viterbi递归方法计算的。states最优序列中状态的下标是通过QStar类(第38行)计算的。

让我们来看看维特比递归方法的内部结构:

@tailrec
def viterbi(t: Int): Double = {
  Range(0, numStates).foreach( updateMaxDelta(t, _)) //39

  if( t == obsSeq.size-1) {  //40
    val idxMaxDelta = Range(0, numStates)
                .map(i => (i, delta(t, i))).maxBy(_._2) //41
    qStar.update(t+1, idxMaxDelta._1)  //42
    idxMaxDelta._2
  }
  else viterbi(t+1)  //43
}

由于qStarpsidelta参数已经在构造函数中初始化,递归实现从第二个观察开始。递归实现调用updateMaxDelta方法来更新psi索引矩阵和任何状态的最高概率,如下所示:

def updateMaxDelta(t: Int, j: Int): Unit = {.
   val idxDelta = Range(0, nStates)
        .map(i => (i, delta(t-1, i)*lambda.A(i, j)))
        .maxBy(_._2)   //44
   psi(t)(j) = idxDelta._1
   delta += (t, j, idxDelta._2)  //45
}

updateMaxDelta 方法实现了数学表达式 M14,该表达式提取使 psi 最大的状态索引(行 44)。相应地更新 delta 概率矩阵和 psi 索引矩阵(行 45)。

viterbi 方法对剩余的观察值(除了最后一个)进行递归调用(行 43)。在 obsSeq.size-1 索引的最后观察值,算法执行 QStar 类中实现的数学表达式 M15(行 42)。

注意

QStar 类

QStar 类及其 update 方法使用线性代数,此处未进行描述;请参阅文档化的源代码和 Scaladocs 文件以获取详细信息。

此实现完成了对隐藏马尔可夫模型解码形式的描述及其在 Scala 中的实现。现在,让我们将此知识付诸实践。

将所有内容整合在一起

HMM 类实现了三种规范形式。一个绑定到整数数组的视图用于参数化 HMM 类。我们假设连续或准连续值的时间序列被量化为离散符号值。

@specialized 注解确保为 Array[Int] 原始类型生成字节码,而不执行隐式声明的绑定视图的转换。

有两种模式可以执行隐藏马尔可夫模型的任何三种规范形式:

  • ViterbiPath 类:构造函数初始化/训练一个类似于任何其他学习算法的模型,如 附录 A 中 不可变分类器设计模板 小节所述,基本概念。构造函数通过执行 Baum-Welch 算法生成模型。一旦模型成功创建,就可以用于解码或评估。

  • ViterbiPath 对象:伴随对象提供了使用 HMM 对观察序列进行解码和评估的 decodeevaluate 方法。

以下图表描述了两种操作模式:

将所有内容整合在一起

隐藏马尔可夫模型的计算流程

让我们通过定义其类来完善我们对 HMM 的实现。HMM 类被定义为使用从 xt 训练集中隐式生成的模型进行数据转换,如第二章 中 单态数据转换 小节所述,Hello World!(行 46):

class HMM@specialized(Double) T <: AnyVal
    (implicit quantize: Array[T] => Int, f: T => Double) 
  extends ITransform[Array[T]](xt) with Monitor[Double] {//46

  type V = HMMPrediction  //47
  val obsSeq: Vector[Int] = xt.map(quantize(_)) //48

  val model: Option[HMMModel] = train  //49
  override def |> : PartialFunction[U, Try[V]] //50
}

HMM 构造函数接受以下四个参数(行 46):

  • config:这是 lambda 模型的维度和执行参数的 HMM 配置

  • xt:这是具有 T 类型的观察值的多元时间序列

  • form:这是在模型生成后(评估或解码)要使用的规范形式

  • quantize:这是一个量化函数,将 Array[T] 类型的观察转换为 Int 类型

  • f:这是从 T 类型到 Double 的隐式转换

构造函数必须覆盖在 ITransform 抽象类中声明的输出数据 V 类型(HMMPrediction)(第 47 行)。HMMPrediction 类的结构已在上一节中定义。

使用 Monitor 特性在训练期间收集配置文件信息(请参阅 附录 A 下的 Monitor 部分,在 实用类 中,基本概念)。

xt 观察的时间序列通过应用每个观察的 quantize 量化函数转换为 obsSeq 观察状态的向量(第 48 行)。

与任何监督学习技术一样,模型是通过训练创建的(第 49 行)。最后,|> 多态预测器调用 decode 方法或 evaluate 方法(第 50 行)。

train 方法包括执行 Baum-Welch 算法,并返回 lambda 模型:

def train: Option[HMMModel] = Try {
  BaumWelchEM(config, obsSeq).lambda }.toOption

最后,|> 预测器是一个简单的包装器,包装了评估形式(evaluate)和解码形式(decode):

override def |> : PartialFunction[U, Try[V]] = {
  case x: Array[T] if(isModel && x.length > 1) => 
  form match {
    case _: EVALUATION => 
      evaluation(model.get, VectorInt)
    case _: DECODING => 
       decoding(model.get, VectorInt)
   }
}

HMM 伴随对象的受保护 evaluation 方法是围绕 Alpha 计算的包装器:

def evaluation(model: HMMModel, 
    obsSeq: Vector[Int]): Try[HMMPrediction] = Try {
  HMMPrediction(-Alpha(model,obsSeq).logProb, obsSeq.toArray) 
}

HMM 对象的 evaluate 方法公开了解释规范形式:

def evaluateT <: AnyVal(implicit quantize: Array[T] => Int, 
      f: T => Double): Option[HMMPrediction] =  
  evaluation(model, xt.map(quantize(_))).toOption

decoding 方法包装了 Viterbi 算法以提取最佳状态序列:

def decoding( model: HMMModel, obsSeq: Vector[Int]): 
     Try[HMMPrediction] = Try { 
  ViterbiPath(model, obsSeq).path
}

HMM 对象的 decode 方法公开了解码规范形式:

def decodeT <: AnyVal(implicit quantize: Array[T] => Int,
    f: T => Double): Option[HMMPrediction] =
  decoding(model, xt.map(quantize(_))).toOption

注意

归一化概率输入

您需要确保用于评估和解码规范形式的 λ 模型的输入概率已归一化——对于 π 向量和 AB 矩阵中所有状态的概率之和等于 1。此验证代码在示例代码中省略。

测试案例 1 – 训练

我们的第一个测试案例是训练一个 HMM(隐马尔可夫模型)来预测投资者情绪,该情绪是通过美国个人投资者协会AAII)每周情绪调查的成员来衡量的[7:8]。目标是计算给定观察和隐藏状态(训练规范形式)的转移概率矩阵 A、发射概率矩阵 B 和稳态概率分布 π

我们假设投资者情绪的变化与时间无关,这是隐藏马尔可夫模型所要求的。

AAII 情绪调查按百分比评估市场的看涨情绪:

测试案例 1 – 训练

每周 AAII 市场情绪(由 AAII 借鉴)

投资者情绪被认为是股票市场未来方向的逆指标。请参阅 附录 A 中的 术语 部分,基本概念

让我们选择投资者上涨百分比与下跌百分比的比例。然后对这个比例进行归一化。以下表格列出了这个比例:

时间 上涨 下跌 中性 比率 标准化比率
t0 0.38 0.15 0.47 2.53 1.0
t1 0.41 0.25 0.34 1.68 0.53
t2 0.25 0.35 0.40 0.71 0.0
….

非归一化观察序列(上涨情绪与下跌情绪的比率)定义在 CSV 文件中如下:

val OBS_PATH = "resources/data/chap7/obsprob.csv"
val NUM_SYMBOLS = 6
val NUM_STATES = 5
val EPS = 1e-4
val MAX_ITERS = 150
val observations = VectorDouble

val quantize = (x: DblArray) => 
      (x.head* (NUM_STATES+1)).floor.toInt  //51
val xt = observations.map(ArrayDouble)

val config = HMMConfig(xt.size, NUM_STATES, NUM_SYMBOLS, 
    MAX_ITERS, EPS)
val hmm = HMM[Array[Int]](config,  xt) //52
show(s"Training):\n${hmm.model.toString}")

HMM类的构造函数需要一个T => Array[Int]隐式转换,该转换由quantize函数实现(第51行)。hmm.model模型是通过使用预定义配置和观察状态序列obsSeq实例化HMM类创建的(第52行)。

HMM 的培训生成以下状态转移概率矩阵:

A 1 2 3 4 5
1 0.090 0.026 0.056 0.046 0.150
2 0.094 0.123 0.074 0.058 0.0
3 0.093 0.169 0.087 0.061 0.056
4 0.033 0.342 0.017 0.031 0.147
5 0.386 0.47 0.314 0.541 0.271

发射矩阵如下:

| B | 1 | 2 | 3 | 4 | 5 | 6 |
| --- | --- | --- | --- | --- | --- |
| 1 | 0.203 | 0.313 | 0.511 | 0.722 | 0.264 | 0.307 |
| 2 | 0.149 | 0.729 | 0.258 | 0.389 | 0.324 | 0.471 |
| 3 | 0.305 | 0.617 | 0.427 | 0.596 | 0.189 | 0.186 |
| 4 | 0.207 | 0.312 | 0.351 | 0.653 | 0.358 | 0.442 |
| 5 | 0.674 | 0.520 | 0.248 | 0.294 | 0.259 | 0.03 |

测试用例 2 – 评估

评估的目标是计算给定λ模型(A0B0PI0)的xt观察数据的概率:

val A0 = Array[Array[Double]](
  ArrayDouble,
  ArrayDouble,
  …. 
) 
val B0 =  Array[Array[Double]](
  ArrayDouble,
  ArrayDouble,
  …  
)
val PI0 = ArrayDouble

val xt = VectorDouble.map(ArrayDouble)
val max = data.max
val min = data.min
implicit val quantize = (x: DblArray) => 
  ((x.head/(max - min) + min)*(B0.head.length-1)).toInt   //55
val lambda = HMMModel(
  DMatrix(A0), DMatrix(B0), PI0, xt.length) //53
evaluation(lambda, xt).map( _.toString).map(show(_)) //54

模型通过将A0状态转移概率和B0发射概率作为DMatrix类型的矩阵直接转换创建(第53行)。评估方法生成一个HMMPrediction对象,将其转换为字符串,然后显示在标准输出中(第54行)。

quantization方法包括在lambda模型关联的符号(或范围)的数量上归一化输入数据。符号的数量是发射概率矩阵B的行的大小。在这种情况下,输入数据的范围是[0.0, 3.0]。使用线性变换f(x) = x/(max – min) + min进行归一化,然后根据符号的数量(或状态值)进行调整(第55行)。

在调用评估方法之前,必须显式定义quantize量化函数。

备注

解码测试用例

请参考源代码和 API 文档中与解码形式相关的测试用例。

HMM 作为过滤技术

隐藏马尔可夫模型(HMM)的评估形式非常适合过滤离散状态的数据。与第三章“数据预处理”中介绍的卡尔曼滤波器等时间序列滤波器相反,HMM 要求数据在创建可靠模型时保持平稳。然而,隐藏马尔可夫模型克服了分析时间序列分析的一些局限性。滤波器和平滑技术假设噪声(频率均值、方差和协方差)是已知的,并且通常遵循高斯分布。

隐藏马尔可夫模型没有这样的限制。滤波技术,如移动平均技术、离散傅里叶变换和卡尔曼滤波器,适用于离散和连续状态,而 HMM 则不适用。此外,扩展卡尔曼滤波器可以估计非线性状态。

条件随机场

条件随机场CRF)是由 John Lafferty、Andrew McCallum 和 Fernando Pereira 在世纪之交引入的一种判别式机器学习算法,作为 HMM 的替代方案。该算法最初是为了为一系列观测序列分配标签而开发的。

让我们考虑一个具体的例子,以理解观测值和标签数据之间的条件关系。

CRF 简介

让我们考虑在足球比赛中使用视频和音频的组合来检测犯规的问题。目标是协助裁判并分析球员的行为,以确定场上的动作是否危险(红牌)、不适当(黄牌)、有疑问需要回放,或者合法。以下图像是图像处理中视频帧分割的示例:

CRF 简介

一个需要机器学习的图像处理问题示例

视频分析包括对每个视频帧进行分割并提取图像特征,如颜色或边缘[7:10]。一个简单的分割方案是将每个视频帧分解成瓦片或像素组,这些像素组按其在屏幕上的坐标索引。然后为每个瓦片 S[ij] 创建一个时间序列,如下面的图像所示:

CRF 简介

视频帧序列中像素的学习策略

图像分割段 S[ij] 是与多个观测值相关联的标签之一。相同的特征提取过程也适用于视频相关的音频。以下模型图展示了视频/图像分割段与足球运动员之间争执的隐藏状态之间的关系:

CRF 简介

足球违规检测中 CRF 的无向图表示

CRFs 是判别模型,可以被视为逻辑回归的结构化输出扩展。CRFs 解决对数据序列进行标记的问题,例如为句子中的每个单词分配一个标签。目标是估计输出(观察)值 Y 与输入值(特征)X 之间的相关性。

输出值与输入值之间的相关性描述为图(也称为图结构 CRF)。图结构 CRF 的一个很好的例子是团。团是图中的一组连接节点,其中每个顶点都有一个边连接到团中其他所有顶点。

这样的模型很复杂,其实现具有挑战性。大多数与时间序列或有序数据序列相关的实际问题都可以作为观察线性序列与输入数据线性序列之间的相关性来解决,这与 HMM 类似。这样的模型被称为线性链结构图 CRF或简称线性链 CRF

CRF 简介

非线性链和线性链 CRF 的示意图

线性链 CRF 的一个主要优点是,可以使用动态规划技术(如 HMM 中使用的 Viterbi 算法)估计最大似然 p(Y|X, w)。从现在开始,本节将专门关注线性链 CRF,以保持与上一节中描述的 HMM 的一致性。

线性链 CRF

考虑一个随机变量 X={x[i]}[0:n-1],它代表 n 个观察值,以及一个随机变量 Y,它代表相应的标签序列 Y={y[j]}[0:n-1]。隐藏马尔可夫模型估计联合概率 p(X,Y),因为任何生成模型都需要枚举所有观察值序列。

如果 Y 的每个元素 y[j] 都遵循马尔可夫属性的零阶,那么 (Y, X) 是一个 CRF。似然被定义为条件概率 p(Y|X, w),其中 w 是模型参数向量。

注意

观察依赖性

CRF 模型的目的在于估计 p(Y|X, w) 的最大似然。因此,不需要 X 观察值之间的独立性。

图形模型是一种概率模型,其中图表示随机变量(顶点)之间的条件独立性。随机变量的条件概率和联合概率用边表示。通用条件随机场的图确实可能很复杂。最常见的最简单图是线性链 CRF。

一阶线性链条件随机场可以可视化为一个无向图模型,它说明了在给定一组观察值 X 的条件下,标签 Y[j] 的条件概率:

线性链 CRF

一个线性、条件、无向随机场图

马尔可夫属性通过仅考虑相邻标签来简化给定XY的条件概率,即p(Y[1]|X, Y[j] j ≠1) = p(Y[1]|X, Y[0], Y[2]) 和 p(Y[i]|X, Y[j] j ≠i) = p(Y[i]|X, Y[i-1], Y[i+1])

条件随机域引入了一组新的实体和新的术语:

  • 潜在函数 (f**[i]): 这些严格为正的实值函数代表了一组对随机变量配置的约束。它们没有明显的概率解释。

  • 恒等潜在函数:这些是I(x, t)潜在函数,如果时间t的特征x的条件为真,则取值为 1,否则为 0。

  • 转移特征函数:简称为特征函数,t[i],是接受一系列特征{X[i]}、前一个标签Y[t-1]、当前标签Y[t]和索引i的潜在函数。转移特征函数输出一个实值函数。在文本分析中,转移特征函数可以通过一个句子作为观察特征序列、前一个单词、当前单词和句子中单词的位置来定义。每个转移特征函数分配一个权重,类似于逻辑回归中的权重或参数。转移特征函数在 HMM 中的状态转移因子a[ij]扮演着类似的角色,但没有直接的概率解释。

  • 状态特征函数 (s[j]): 这些是接受特征序列{X[i]}、当前标签Y[i]和索引i的潜在函数。它们在 HMM 中的发射因子扮演着类似的角色。

CRF 定义了在给定观察序列X的情况下,特定标签序列Y的对数概率,作为转移特征函数和状态特征函数的归一化乘积。换句话说,给定观察特征X的特定序列Y的似然性是一个逻辑回归。

针对一阶线性链 CRF 计算条件概率的数学符号在以下信息框中描述:

注意

CRF 条件分布

M1:给定观察x,标签序列y的对数概率定义为:

线性链式 CRF

M2:转移特征函数定义为:

线性链式 CRF

M3:使用以下符号:

线性链式 CRF

M4:使用马尔可夫属性定义给定x的标签y的条件分布:

线性链式 CRF

科学论文中有时将权重w[j]称为λ,这可能会让读者感到困惑;w用于避免与λ正则化因子混淆。

现在,让我们熟悉条件随机域算法及其由 Sunita Sarawagi 实现的实现。

正则化 CRFs 和文本分析

大多数用于展示条件随机字段能力的示例都与文本挖掘、入侵检测或生物信息学相关。尽管这些应用具有很大的商业价值,但它们不适合作为入门测试案例,因为它们通常需要详细描述模型和训练过程。

特征函数模型

对于我们的示例,我们将选择一个简单的问题:如何从不同来源和不同格式的不同来源收集和汇总分析师对任何给定股票的推荐。

证券经纪公司和投资基金的分析师通常会发布任何股票的推荐或评级列表。这些分析师使用不同的评级方案,从买入/持有/卖出、A/B/C 评级和星级评级,到市场表现/中性/市场表现不佳评级。对于此示例,评级按以下方式归一化:

  • 0 表示强烈卖出(F 或 1 星评级)

  • 1 表示卖出(D,2 星,或标记为表现不佳)

  • 2 表示中性(C,持有,3 星,市场表现等)

  • 3 表示买入(B,4 星,市场表现良好等)

  • 4 表示强烈买入(A,5 星,强烈推荐等)

下面是股票分析师推荐的一些示例:

  • 摩根士丹利将 AUY 从中性评级上调为优于市场评级

  • 雷蒙德·詹姆斯将艾恩斯沃斯木材评为优于市场

  • BMO 资本市场将 Bear Creek Mining 升级为优于市场

  • 高盛将其对 IBM 的信心列入其信心名单

目标是从发布推荐或评级的金融机构、被评级的股票、如果有的话,前一次评级以及新评级中提取名称。输出可以插入数据库以进行进一步的趋势分析、预测,或者简单地创建报告。

备注

应用范围

分析师的评级每天通过不同的协议(馈送、电子邮件、博客、网页等)更新。在处理之前,必须从 HTML、JSON、纯文本或 XML 格式中提取数据。在这个练习中,我们假设输入已经使用正则表达式或其他分类器转换为纯文本(ASCII)。

第一步是定义代表评级类别或语义的标签 Y。一个段或序列被定义为推荐句子。在审查了不同的推荐后,我们能够指定以下七个标签:

  • 推荐来源(高盛等)

  • 行动(升级、启动等)

  • 股票(公司名称或股票代码)

  • 来自(可选的关键词)

  • 评级(可选的前一次评级)

  • 收件人

  • 评级(股票的新评级)

训练集是通过从原始数据中通过标记推荐的各个组成部分生成的。股票的第一个(或初始)评级没有前述列表中定义的 4 和 5 的标签。

考虑以下示例:

Citigroup // Y(0) = 1 
upgraded // Y(1) 
Macys // Y(2) 
from // Y(3) 
Buy // Y(4) 
to // Y(5)
Strong Buy //Y(6) = 7

备注

标记

标记作为一个词,其含义可能因上下文而异。在自然语言处理NLP)中,标记是指将属性(形容词、代词、动词、专有名词等)分配给句子中一个词的过程[7:11]。

可以在以下无向图中可视化训练序列:

特征函数模型

作为 CRF 训练序列的推荐示例

你可能会想知道为什么我们需要在创建训练集时标记FromTo标签。原因是这些关键字可能并不总是被明确指出,并且/或者它们在推荐中的位置可能因来源而异。

设计

条件随机字段的实现遵循分类器的模板设计,这在附录 A 中“源代码考虑”部分的不可变分类器设计模板一节有所描述,基本概念

其关键组件如下:

  • 在分类器的实例化过程中,通过训练初始化了一个Model类型的CrfModel模型。模型是一个weights的数组。

  • 预测或分类例程作为ITransform类型的隐式数据转换实现。

  • Crf条件随机字段分类器有四个参数:标签数(或特征数),nLabelsCrfConfig类型的配置,CrfSeqDelimiter类型的分隔符序列,以及包含标记观察值的文件名xt的向量。

  • CrfAdapter类与 IITB CRF 库接口。

  • CrfTagger类从标记的文件中提取特征。

条件随机字段的以下 UML 类图描述了其关键软件组件:

设计

条件随机字段的 UML 类图

UML 图省略了Monitor或 Apache Commons Math 组件等实用特性类。

实现

测试案例使用了印度理工学院 Bombay 的 Sunita Sarawagi 提供的 CRF Java 实现。JAR 文件可以从 SourceForge 下载(sourceforge.net/projects/crf/)。

库以 JAR 文件和源代码的形式提供。一些功能,如选择训练算法,无法通过 API 访问。库的组件(JAR 文件)如下:

  • 实现 CRF 算法的 CRF

  • LBFGS 用于凸函数的有限内存 Broyden-Fletcher-Goldfarb-Shanno 非线性优化(用于训练)。

  • CERN Colt 库用于矩阵操作

  • GNU 通用的哈希容器用于索引

配置 CRF 分类器

让我们看看实现条件随机字段分类器的Crf类。Crf类被定义为ITransform类型的数据转换,如第二章中的单调数据转换部分所述,Hello World!(第2行):

class Crf(nLabels: Int, config: CrfConfig, 
    delims: CrfSeqDelimiter, xt: Vector[String])//1
  extends ITransformString with Monitor[Double]{//2

  type V = Double  //3
  val tagsGen = new CrfTagger(nLabels) //4
  val crf = CrfAdapter(nLabels, tagsGen, config.params) //5
  val model: Option[CrfModel] = train //6
  weights: Option[DblArray] = model.map( _.weights)

  override def |> : PartialFunction[String, Try[V]] //7
}

Crf 构造函数有以下四个参数(第1行):

  • nLabels:这些是用于分类的标签数量

  • config:这是用于训练Crf的配置参数

  • delims:这些是在原始和标记文件中使用的分隔符

  • xt:这是一个包含原始和标记数据的文件名的向量

注意

原始和标记数据的文件名

为了简化,原始观测和标记观测的文件名相同,但扩展名不同:filename.rawfilename.tagged

Monitor 特性用于在训练期间收集配置文件信息(请参阅附录 A 下的Monitor部分,实用类),基本概念)。

|> 预测器的输出类型 V 被定义为 Double(第3行)。

CRF 算法的执行由封装在CrfConfig配置类中的各种配置参数控制:

class CrfConfig(w0: Double, maxIters: Int, lambda: Double, 
     eps: Double) extends Config { //8
  val params = s"""initValue $w0 maxIters $maxIters
     | lambda $lambda scale true eps $eps""".stripMargin
}

为了简化,我们使用默认的CrfConfig配置参数来控制学习算法的执行,除了以下四个变量(第8行):

  • w0 权重的初始化使用预定义值或 0 到 1 之间的随机值(默认为 0)

  • 在学习阶段计算权重时使用的最大迭代次数,maxIters(默认 50)

  • 用于减少高值观测的 L2 惩罚函数的lamdba缩放因子(默认 1.0)

  • 用于计算wj权重最优值的eps收敛标准(默认 1e-4)

注意

L[2]正则化

本实现的条件随机字段支持 L[2]正则化,如第六章中的正则化部分所述,回归和正则化。通过将λ = 0来关闭正则化。

CrfSeqDelimiter 情况类指定以下正则表达式:

  • obsDelim 用于解析原始文件中的每个观测

  • labelsDelim 用于解析标记文件中的每个标记记录

  • seqDelim 用于从原始和标记文件中提取记录

代码如下:

case class CrfSeqDelimiter(obsDelim: String, 
    labelsDelim: String, seqDelim: String)

DEFAULT_SEQ_DELIMITER 实例是本实现中使用的默认序列分隔符:

val DEFAULT_SEQ_DELIMITER = 
   new CrfSeqDelimiter(",\t/ -():.;'?#`&_", "//", "\n") 

CrfTagger标签或标记生成器遍历标记文件,并应用CrfSeqDelimiter的相关正则表达式来提取用于训练的符号(第4行)。

CrfAdapter 对象定义了访问 IITB CRF 库的不同接口(第 5 行)。CRF 实例的工厂是通过 apply 构造函数实现的,如下所示:

object CrfAdapter {
  import iitb.CRF.CRF
    def apply(nLabels: Int, tagger: CrfTagger, 
     config: String): CRF = new CRF(nLabels, tagger, config)
  …
}

注意

适配器类到 IITB CRF 库

对序列进行条件随机字段训练需要定义几个关键接口:

  • DataSequence 用于指定访问训练和测试数据中的观测和标签的机制

  • DataIter 用于遍历使用 DataSequence 接口创建的数据序列

  • FeatureGenerator 用于聚合所有特征类型

这些接口在 CRF Java 库 [7:12] 中包含默认实现。每个接口都必须作为适配器类实现:

class CrfTagger(nLabels: Int) extends FeatureGenerator
class CrfDataSeq(nLabels: Int, tags: Vector[String], delim: String) extends DataSequence
class CrfSeqIter(nLabels: Int, input: String, delim: CrfSeqDelimiter) extends DataIter

请参阅文档化的源代码和 Scaladocs 文件,以了解这些适配器类的描述和实现。

训练 CRF 模型

训练的目标是计算最大化条件对数似然函数的权重 w[j],而不包含 L[2] 惩罚函数。最大化对数似然函数等同于最小化带有 L[2] 惩罚的损失。该函数是凸函数,因此可以迭代地应用任何变体梯度下降(贪婪)算法。

注意

M5: 线性链 CRF 训练集 D = {xi, yi} 的条件对数似然如下所示:

训练 CRF 模型

M6: 损失函数和 L2 惩罚的最大化如下所示:

训练 CRF 模型

训练文件由一对文件组成:

  • 原始数据集:推荐(例如 Raymond James 将 Gentiva Health Services 从表现不佳提升至市场表现

  • 标记数据集:标记推荐(例如 Raymond James [1] 提升 [2] Gentiva Health Services [3],从 [4] 表现不佳 [5] 提升至 [6] 市场表现 [7]

注意

标签类型

在此实现中,标签具有 Int 类型。然而,可以使用其他类型,例如枚举或甚至是连续值(即 Double)。

训练或计算权重可能相当昂贵。强烈建议您将观测数据和标记观测数据集分散到多个文件中,以便它们可以并行处理:

训练 CRF 模型

CRF 权重的计算分布

train 方法通过计算 CRF 的 weights 创建模型。它由 Crf 的构造函数调用:

def train: Option[CrfModel] = Try {
  val weights = if(xt.size == 1)  //9
    computeWeights(xt.head) 
  else {
    val weightsSeries = xt.map( computeWeights(_) )
    statistics(weightsSeries).map(_.mean).toArray //10
  }
  new CrfModel(weights) //11
}._toOption("Crf training failed", logger)

我们不能假设只有一个标记的数据集(即单一对*.raw*.tagged文件)(第9行)。如果只有一对原始和标记文件,用于计算 CRF 权重的computeWeights方法将应用于第一个数据集。在存在多个数据集的情况下,train方法计算从每个标记数据集中提取的所有权重的平均值(第10行)。权重的平均值使用XTSeries对象的statistics方法计算,该方法在第三章的Scala 中的时间序列部分介绍,数据预处理。如果成功,train方法返回CrfModel,否则返回None(第11行)。

为了提高效率,应该使用ParVector类将映射并行化,如下所示:

val parXt = xt.par
val pool = new ForkJoinPool(nTasks)
v.tasksupport = new ForkJoinTaskSupport(pool)
parXt.map(computeWeights(_) )

并行集合在第十二章的Scala部分的并行集合部分有详细描述,可扩展框架

注意

CRF 权重计算

假设输入标记文件共享相同的标签或符号列表,因此每个数据集产生相同的权重数组。

computeWeights方法从每个观测和标记观测文件对中提取权重。它调用CrfTagger标记生成器的train方法(第12行)来准备、归一化和设置训练集,然后在 IITB CRF类上调用训练过程(第13行):

def computeWeights(tagsFile: String): DblArray = {
  val seqIter = CrfSeqIter(nLabels, tagsFile, delims)
  tagsGen.train(seqIter)  //12
  crf.train(seqIter)  //13
}

注意

IITB CRF Java 库评估范围

CRF 库已经通过三个简单的文本分析测试案例进行了评估。尽管该库肯定足够健壮,可以说明 CRF 的内部工作原理,但我不能保证其在其他领域(如生物信息学或过程控制)的可扩展性或适用性。

应用 CRF 模型

预测方法实现了|>数据转换操作符。它接受一个新的观测(分析师对股票的建议)并返回最大似然,如下所示:

override def |> : PartialFunction[String, Try[V]] = {
   case obs: String if( !obs.isEmpty && isModel) => {
     val dataSeq = new CrfDataSeq(nLabels,obs,delims.obsDelim)
     Try (crf.apply(dataSeq)) //14
   }
}

|>方法仅创建一个dataSeq数据序列并调用 IITB CRF类的构造函数(第14行)。对部分函数的obs输入参数的条件相当基础。应该使用正则表达式实现更详细的观测条件。为了可读性,省略了验证类和方法参数的代码以及异常处理程序。

注意

高级 CRF 配置

IITB库的 CRF 模型高度可配置。它允许开发者指定具有任何组合的平坦和嵌套依赖关系的状态-标签无向图。源代码包括几种训练算法,如指数梯度。

测试

执行测试的客户端代码包括定义标签数量,NLABELS(即推荐的标签数量),LAMBDA L2 惩罚因子,允许在损失函数最小化中的最大迭代次数,MAX_ITERS,以及EPS收敛标准:

val LAMBDA = 0.5
val NLABELS = 9
val MAX_ITERS = 100
val W0 = 0.7
val EPS = 1e-3
val PATH = "resources/data/chap7/rating"
val OBS_DELIM = ",\t/ -():.;'?#`&_"

val config = CrfConfig(W0 , MAX_ITERS, LAMBDA, EPS) //15
val delims = CrfSeqDelimiter(DELIM,"//","\n") //16
val crf = Crf(NLABELS, config, delims, PATH) //17
crf.weights.map( display(_) )

三个简单步骤如下:

  1. 实例化 CRF 的config配置(行15

  2. 定义三个delims分隔符以提取标记数据(行16

  3. 实例化和训练 CRF 分类器,crf(行17

对于这些测试,权重的初始值(相对于最大迭代次数以最大化对数似然和收敛标准)设置为 0.7(相对于 100 和 1e-3)。标签序列、观察特征序列和训练集的分隔符针对rating.rawrating.tagged输入数据文件的格式进行了定制。

训练收敛曲线

第一次训练运行从 34 位分析师的股票推荐中发现了 136 个特征。算法在 21 次迭代后收敛。每个迭代的对数似然值被绘制出来,以说明向最优 w 解决方案的收敛:

训练收敛曲线

训练过程中 CRF 的日志条件概率的可视化

训练阶段快速收敛到解决方案。这可以解释为分析师推荐六字段格式变化很小。松散或自由风格的格式在训练期间需要更多的迭代才能收敛。

训练集大小的影响

第二次测试评估了训练集大小对训练算法收敛的影响。它包括计算两个连续迭代之间的模型参数(权重)差异 Δw{w[i]}[t+1]{w[i]}[t]

训练集大小的影响

测试在 163 个随机选择的推荐中使用相同的模型,但使用两个不同的训练集进行:

  • 34 位分析师的股票推荐

  • 55 个股票推荐

较大的训练集是 34 个推荐集的超集。以下图表展示了使用 34 和 55 个 CRF 训练序列生成的特征比较:

训练集大小的影响

使用不同大小的训练集的 CRF 权重收敛

使用两种不同大小的训练集进行测试运行之间的差异非常小。这可以很容易地解释为分析师推荐格式之间的小差异。

L[2]正则化因子的影响

第三次测试评估了 L[2]正则化惩罚对向最优权重/特征收敛的影响。该测试与第一次测试类似,但λ的值不同。以下图表绘制了不同λ = 1/σ2(0.2、0.5 和 0.8)的log [p(Y|X, w)]

L2 正则化因子的影响

L2 惩罚对 CRF 训练算法收敛的影响

条件概率的对数随着迭代次数的增加而减少或增加。L[2]正则化因子越低,条件概率越高。

训练集中分析师建议的变化很小,这限制了过拟合的风险。自由式建议格式对过拟合会更敏感。

比较 CRF 和 HMM

与生成模型相比,判别模型的成本/收益分析适用于条件随机字段与隐马尔可夫模型的比较。

与隐马尔可夫模型不同,条件随机字段不需要观测值相互独立(条件概率)。条件随机字段可以通过扩展转移概率到任意特征函数(这些特征函数可以依赖于输入序列)来被视为 HMM 的推广。HMM 假设转移概率矩阵是常数。

HMM 通过处理更多的训练数据自行学习转移概率a[ij]。HMM 可以被视为 CRF 的一个特例,其中状态转移中使用的概率是常数。

性能考虑

N个状态和T个观测值的隐马尔可夫模型的标准形式的解码和评估时间复杂度为O(N[2]T)。使用 Baum-Welch 算法训练 HMM 的时间复杂度为O(N[2]TM),其中M是迭代次数的数量。

有几种方法可以提高 HMM 的性能:

  • 通过使用稀疏矩阵或跟踪空项来避免在发射概率矩阵中不必要的乘以 0。

  • 在训练数据的最相关子集上训练 HMM。在自然语言处理中标记单词或词袋的情况下,这项技术可能特别有效。

线性链条件随机字段的训练使用与 HMM 实现相同的动态规划技术(维特比算法、前向-后向遍历等)。其训练时间复杂度为O(MTN[2]),其中T是数据序列的数量,N是标签(或预期结果)的数量,M是权重/特征λ的数量。

通过使用如 Akka 或 Apache Spark 等框架,可以将对数似然和梯度的计算分布到多个节点上,从而降低 CRF 训练的时间复杂度,如第十二章所述,可扩展框架 [7:13]。

摘要

在本章中,我们更详细地探讨了使用两种常用算法对具有隐藏(或潜在)状态的观测序列进行建模:

  • 生成式隐马尔可夫模型以最大化p(X,Y)

  • 用于最大化log p(Y|X)的判别性条件随机场

HMM(隐马尔可夫模型)是贝叶斯网络的一种特殊形式。它要求观测值是独立的。尽管有局限性,但条件独立性前提使得 HMM 相对容易理解和验证,而 CRF(条件随机场)则不然。

你学习了如何在 Scala 中实现三种动态规划技术:Viterbi、Baum-Welch 和 alpha/beta 算法。这些算法用于解决各种类型的优化问题。它们应该是你的算法工具箱中的基本组成部分。

条件随机场依赖于逻辑回归来估计模型的最佳权重。这种技术也用于多层感知器,这在第九章(第九章,人工神经网络)中介绍过。下一章将介绍两种重要的逻辑回归替代方案,用于区分观测值:非线性模型的核函数和观测值类之间的边缘最大化。

第八章. 核模型与支持向量机

本章介绍了核函数、二元支持向量机分类器、用于异常检测的单类支持向量机以及支持向量回归。

在第六章的二项分类部分,回归与正则化中,你学习了超平面分割训练集观测值并估计线性决策边界的概念。逻辑回归至少有一个局限性:它要求使用定义好的函数(sigmoid)将数据集线性分离。对于高维问题(大量高度非线性相关的特征),这个问题尤为突出。支持向量机SVMs)通过使用核函数估计最佳分离超平面来克服这一局限性。

本章我们将涵盖以下主题:

  • 一些 SVM 配置参数和核方法对分类准确性的影响

  • 如何将二元支持向量机分类器应用于估计上市公司削减或消除股息的风险

  • 如何使用单类支持向量机检测异常值

  • 支持向量回归与线性回归的比较

支持向量机被表述为一个凸优化问题。本章描述了相关算法的数学基础,以供参考。

核函数

本书至今介绍的所有机器学习模型都假设观测值由一个固定大小的特征向量表示。然而,一些现实世界的应用,如文本挖掘或基因组学,并不适合这种限制。分类过程的关键是定义两个观测值之间的相似性或距离。核函数允许开发者计算观测值之间的相似性,而无需将它们编码在特征向量中 [8:1]。

概述

核方法的概念可能一开始对新手来说有点奇怪。让我们以蛋白质分类的例子来考虑。蛋白质有不同的长度和组成,但这并不妨碍科学家对它们进行分类 [8:2]。

注意

蛋白质

蛋白质是由肽键连接在一起的氨基酸聚合物。它们由一个碳原子与一个氢原子、另一个氨基酸或羧基键合而成。

蛋白质使用生物化学家熟悉的传统分子符号来表示。遗传学家用被称为蛋白质序列注释的字符序列来描述蛋白质。序列注释编码了蛋白质的结构和组成。以下图像展示了蛋白质的分子(左侧)和编码(右侧)表示:

概述

蛋白质的序列注释

对一组蛋白质的分类和聚类需要定义一个相似性因子或距离,用于评估和比较蛋白质。例如,三个蛋白质之间的相似性可以定义为它们序列注释的归一化点积:

概述

三个蛋白质序列注释之间的相似性

为了确定蛋白质属于同一类,您不需要将整个序列注释表示为特征向量。您只需要逐个比较每个序列的每个元素,并计算相似性。同样,相似性的估计不需要两个蛋白质具有相同的长度。

在这个例子中,我们不需要给注释的每个元素分配一个数值。让我们考虑蛋白质注释的一个元素作为其字符 c 和位置 p(例如,K,4)。两个蛋白质注释 xx' 的点积,分别对应长度 nn',定义为两个注释之间相同元素(字符和位置)的数量除以两个注释之间的最大长度(M1):

概述

对三个蛋白质的相似性计算结果为 sim(x,x')=6/12 = 0.50sim(x,x'')=3/13 =0.23,和 sim(x',x'')= 4/13= 0.31

另一个相似之处是,两个相同注释的相似性为 1.0,两个完全不同注释的相似性为 0.0。

注意

相似性的可视化

通常使用径向表示来可视化特征之间的相似性更方便,例如在蛋白质注释的例子中。距离 d(x,x') = 1/sim(x,x') 被可视化成两个特征之间的角度或余弦值。余弦度量在文本挖掘中常用。

在这个例子中,相似性被称为蛋白质序列注释空间中的核函数。

常见的判别核函数

虽然相似度的度量对于理解核函数的概念非常有用,但核函数有更广泛的定义。核 K(x, x') 是一个对称的、非负的实值函数,它接受两个实数参数(两个特征的值)。有许多不同类型的核函数,其中最常见的是以下几种:

  • 线性核函数点积):这在非常高维度的数据中很有用,其中问题可以表示为原始特征的线性组合。

  • 多项式核函数:这扩展了线性核,用于组合非完全线性的特征。

  • 径向基函数RBF):这是最常用的核。它用于标签或目标数据有噪声且需要一定程度的正则化的场合。

  • Sigmoid 核函数:这通常与神经网络一起使用。

  • 拉普拉斯核函数:这是对 RBF 的一种变体,对训练数据有更高的正则化影响。

  • 对数核函数:这在图像处理中使用。

注意

RBF 术语

在这个演示及其实现中使用的库中,径向基函数是高斯核函数的同义词。然而,RBF 也指包括高斯、拉普拉斯和指数函数在内的指数核函数族。

线性回归的简单模型由回归参数(权重)与输入数据的点积组成(参见第六章中的普通最小二乘回归部分,回归和正则化)。

模型实际上是权重和输入的线性组合。通过定义一个通用的回归模型为非线性函数的线性组合,即基函数(M2)的概念可以扩展:

常见的判别核函数

最常用的基函数是幂函数和高斯函数。核函数被描述为两个特征向量 xx' 的基函数向量 φ(x).φ(x') 的点积。以下是一些核方法的列表:

注意

M3:通用核函数定义为:

常见的判别核函数

M4:线性核定义为:

常见的判别核函数

M5:具有斜率 γ、度数 n 和常数 c 的多项式核函数定义为:

常见的判别核函数

M6:具有斜率 γ 和常数 c 的 sigmoid 核定义为:

常见的判别核函数

M7:具有斜率 γ 的径向基函数核定义为:

常见的判别核函数

M8:具有斜率 γ 的拉普拉斯核定义为:

常见的判别核函数

M9:具有度数 n 的对数核定义为:

常见的判别核函数

之前描述的判别核函数列表只是核方法宇宙的一个子集。其他类型的核包括以下几种:

  • 概率核:这些是从生成模型中导出的核。例如,高斯过程这样的概率模型可以用作核函数[8:3]。

  • 平滑核:这是非参数公式,通过最近邻观测值平均密度[8:4]。

  • 可复现核希尔伯特空间:这是有限或无限基函数的点积[8:5]。

在非线性问题中,核函数在支持向量机中扮演着非常重要的角色。

核单子组合

核函数的概念实际上是从微分几何中推导出来的,更具体地说,是从流形中推导出来的,这在第四章的“降维”部分下的“非线性模型”中介绍过,属于“无监督学习”。

流形是嵌入在更高维观测空间中的低维特征空间。两个观测之间的点积(或内积),称为黎曼度量,是在欧几里得切空间上计算的。

注意

热核函数

流形上的核函数实际上是通过求解使用拉普拉斯-贝尔特拉米算子的热方程来计算的。热核是热微分方程的解。它将点积与指数映射关联起来。

核函数是使用指数映射在流形上投影切空间上的点积的复合,如下图所示:

核单子组合

流形、黎曼度量以及内积投影的可视化

核函数是两个函数的复合 g o f

  • 一个实现两个向量 vw 之间的黎曼度量或相似度的函数 h

  • 一个实现相似度 h(v, w) 投影到流形(指数映射)的函数 g

KF 类实现了核函数作为函数 gh 的复合:

type F1 = Double => Double
type F2 = (Double, Double) => Double

case class KFG {
  def metric(v: DblVector, w: DblVector)
      (implicit gf: G => F1): Double =  //1
    g(v.zip(w).map{ case(_v, _w) => h(_v, _w)}.sum) //2
}

KF 类使用可以转换为 Function1[Double, Double]G 类型进行参数化。因此,计算 metric(点积)需要从 GFunction1 的隐式转换(第 1 行)。metric 通过将两个向量配对,映射 h 相似度函数,并求和得到的向量来计算(第 2 行)。

让我们为 KF 类定义单子组合:

val kfMonad = new _Monad[KF] {
  override def mapG,H(f: G =>H): KF[H] = 
     KFH, kf.h) //3
  override def flatMapG,H(f: G =>KF[H]): KF[H] =
     KFH.g, kf.h)
}

创建 kfMonad 实例覆盖了在 第一章 的 Monads 部分中描述的通用 _Monad 特质中定义的 mapflatMap 方法。unit 方法的实现对于单子组合不是必要的,因此被省略。

mapflatMap 方法的函数参数仅适用于指数映射函数 g(第 3 行)。两个核函数 kf1 = g1 o hkf2 = g2 o h 的组合产生核函数 kf3 = g2 o (g1 o h) = (g2 o g1) o h = g3 o h

注意

核函数单子组合的解释

在流形上核函数单子组合的可视化非常直观。两个核函数的组合包括组合它们各自的投影或指数映射函数 g。函数 g 与计算度量时数据点周围的流形曲率直接相关。核函数的单子组合试图调整指数映射以适应流形的曲率。

下一步是定义一个隐式类,将 KF 类型的核函数转换为它的单子表示,以便它可以访问 mapflatMap 方法(第 4 行):

implicit class kF2MonadG {  //4
  def mapH: KF[H] = kfMonad.map(kf)(f)
  def flatMapH: KF[H] =kfMonad.flatMap(kf)(f)
}

让我们通过定义各自的 gh 函数来实现 RBF 径向基函数和多项式核函数 Polynomial。核函数的参数化类型简单为 Function1[Double, Double]

class RBF(s2: Double) extends KFF1 => Math.exp(-0.5*x*x/s2), 
    (x: Double, y: Double) => x -y)
class Polynomial(d: Int) extends KFF1 => Math.pow(1.0+x, d), 
    (x: Double, y: Double) => x*y)

下面是两个核函数组合的例子:一个 kf1 核 RBF,标准差为 0.6(第 5 行),以及一个 kf2 多项式核,次数为 3(第 6 行):

val v = VectorDouble
val w = VectorDouble
val composed = for {
  kf1 <- new RBF(0.6)  //5
  kf2 <- new Polynomial(3)  //6
} yield kf2
composed.metric(v, w) //7

最后,在 composed 核函数上计算 metric(第 7 行)。

注意

SVM 中的核函数

我们对支持向量机的实现使用了包含在 LIBSVM 库中的核函数。

支持向量机

支持向量机是一种线性判别分类器,它在训练过程中试图最大化类别之间的间隔。这种方法与通过逻辑回归训练超平面的定义类似(参考第六章 回归和正则化 中的 二项式分类 部分),回归和正则化)。主要区别在于支持向量机计算观察值组或类别之间的最优分隔超平面。超平面确实是代表通过训练生成的模型的方程。

支持向量机的质量取决于不同类别观察值之间的距离,即间隔。随着间隔的增加,分类器的准确性提高。

线性支持向量机

首先,让我们将支持向量机应用于提取一组标记观察值的线性模型(分类器或回归)。定义线性模型有两种情况。标记观察值如下:

  • 它们在特征空间中自然地被分隔开(可分情况)

  • 它们相互交织并重叠(不可分情况)

当观察值自然分隔时,很容易理解最优分离超平面的概念。

可分情况 – 硬间隔

使用二维 (x, y) 观察值集和两个类别 C[1]C[2] 的二维集,以更好地解释使用超平面分隔训练集观察值的概念。标签 y 的值为 -1 或 +1。

分隔超平面的方程由线性方程 y=w.x[T] +w[0] 定义,它位于类别 C[1] (H[1]: w.x ^T + w[0] + 1=0) 和类别 C[2] (H[2]: w.x^T + w[0] - 1) 的边界数据点之间。平面 H[1]H[2] 是支持向量:

可分情况 – 硬间隔

支持向量机中硬间隔的可视化

在可分情况下,支持向量将观察值完全分隔成两个不同的类别。两个支持向量之间的间隔对所有观察值都是相同的,称为 硬间隔

注意

可分情况

M1:支持向量方程 w 表示为:

可分情况 – 硬间隔

M2:硬间隔优化问题如下所示:

可分情况 – 硬间隔

不可分情况 – 软间隔

在不可分情况下,支持向量无法通过训练完全分隔观察值。它们仅仅变成了线性函数,惩罚位于其各自支持向量 H[1]H[2] 之外(或超出)的少数观察值或异常值。如果异常值离支持向量更远,则惩罚变量 ξ,也称为松弛变量,会增加:

不可分情况 – 软间隔

支持向量机中硬间隔的可视化

属于适当(或自身)类的观测值不需要被惩罚。条件与硬间隔相似,这意味着松弛ξ为零。这种技术惩罚属于该类但位于其支持向量之外的观测值;随着观测值接近另一类的支持向量以及更远,松弛ξ增加。因此,间隔被称为软间隔,因为分离超平面是通过松弛变量强制执行的。

备注

不可分情况

M3:具有C公式的线性支持向量机软间隔的优化定义为:

不可分情况 – 软间隔

在这里,C是惩罚(或逆正则化)因子。

你可能会想知道边缘误差最小化如何与损失函数和为岭回归引入的惩罚因子相关(参见第六章中的数值优化部分,回归和正则化)。公式中的第二个因子对应于普遍存在的损失函数。你肯定能认出第一个项是 L2 正则化惩罚,其中λ = 1/2C

问题可以重新表述为最小化一个称为原问题的函数[8:6]。

备注

M4:使用 L[2]正则化的支持向量机原问题表述如下:

不可分情况 – 软间隔

C惩罚因子是 L2 正则化因子的倒数。损失函数L被称为铰链损失。使用C惩罚(或成本)参数的间隔表述称为C-SVM表述。C-SVM 有时被称为非可分情况的C-εSVM表述。

υ-SVM(或 Nu-SVM)是 C-SVM 的另一种表述。该表述比 C-SVM 更具描述性;υ代表训练观测值被错误分类的上限和位于支持向量上的观测值的下限[8:7]。

备注

M5:使用 L2 正则化的线性 SVM 的ν-SVM表述定义为:

不可分情况 – 软间隔

这里,ρ是一个作为优化变量的边缘因子。

C-SVM 表述在章节中用于二元、单类支持向量分类器以及支持向量回归。

备注

顺序最小优化

优化问题包括在 N 个线性约束下最小化二次目标函数 (),其中 N 是观测数。算法的时间复杂度是 O(N³)。为了将时间复杂度降低到 O(N²),已经引入了一种更有效的算法,称为 顺序最小优化SMO)。

非线性 SVM

到目前为止,我们假设分离超平面以及支持向量是线性函数。不幸的是,在现实世界中,这样的假设并不总是正确的。

最大间隔分类

支持向量机被称为大或 最大间隔分类器。目标是最大化支持向量之间的间隔,对于可分情况使用硬约束(类似地,对于不可分情况使用软约束和松弛变量)。

模型参数 {w[i]} 在优化过程中进行缩放,以确保间隔至少为 1。这类算法被称为最大(或大)间隔分类器。

使用支持向量将非线性模型拟合到标记观测值的问题并不容易。一个更好的替代方案是将问题映射到一个新的更高维空间,使用非线性变换。非线性分离超平面在新空间中成为线性平面,如下图所示:

最大间隔分类

SVM 中核技巧的示意图

非线性 SVM 是通过基函数 ϕ(x) 实现的。非线性 C-SVM 的公式与线性情况非常相似。唯一的区别是约束条件和支持向量,使用基函数 φM6):

最大间隔分类

在前一个方程中,最小化 w^T.ϕ(x) 需要计算内积 ϕ(x)^T.ϕ(x)。基函数的内积是通过在第一节中介绍的一种核函数实现的。前一个凸问题的优化计算了最优超平面 w,它是训练样本 y 的核化线性组合,以及 拉格朗日乘子。这个优化问题的表述被称为 SVM 对偶问题。对偶问题的描述作为参考提及,并且超出了本书的范围 [8:8]。

注意

M7:SVM 对偶问题的最优超平面定义为:

最大间隔分类

M8:SVM 对偶问题的硬间隔公式定义为:

最大间隔分类

核技巧

变换 (x,x') => K(x,x') 将非线性问题映射到更高维空间中的线性问题。它被称为 核技巧

让我们以第一部分定义的多项式核为例,该核在二维空间中具有度数 d = 2 和系数 C0 = 1。两个向量 x = [x[1], x[2]]z = [x'[1], x'[2]] 的多项式核函数在 6 维空间中被分解为一个线性函数:

内核技巧

支持向量分类器 – SVC

支持向量机可以应用于分类、异常检测和回归问题。让我们首先深入了解支持向量分类器。

二进制支持向量机(Binary SVC)

首先要评估的分类器是二元(2 类)支持向量分类器。该实现使用了由台湾大学 Chih-Chung Chang 和 Chih-Jen Lin 创建的 LIBSVM 库[8:9]。

LIBSVM

该库最初是用 C 语言编写的,后来被移植到 Java。可以从www.csie.ntu.edu.tw/~cjlin/libsvm下载为.ziptar.gzip文件。该库包括以下分类器模式:

  • 支持向量分类器(C-SVC、υ-SVC 和单类 SVC)

  • 支持向量回归(υ-SVR 和ε-SVR)

  • RBF 核、线性核、Sigmoid 核、多项式核和预计算核

LIBSVM 具有使用顺序最小优化SMO)的独特优势,这降低了训练n个观察值的时间复杂度至O(n²)。LIBSVM 文档涵盖了硬边缘和软边缘的理论和实现,可在www.csie.ntu.edu.tw/~cjlin/papers/guide/guide.pdf找到。

注意

为什么选择 LIBSVM?

学习和实验支持向量机(SVM)的 LIBSVM 库有其他替代方案。加州大学伯克利分校的 David Soergel 对 Java 版本进行了重构和优化[8:10]。Thorsten Joachims 的 SVMLight [8:11] 在 Spark/MLlib 1.0 中包含了两个使用弹性分布式数据集(resilient distributed datasets)的 Scala 实现的 SVM(请参考第十二章中的Apache Spark部分,可扩展框架)。然而,LIBSVM 是最常用的 SVM 库。

LIBSVM 中不同支持向量分类器和支持向量回归的实现被分解为以下五个 Java 类:

  • svm_model: 这定义了训练过程中创建的模型参数

  • svm_node:此模型表示稀疏矩阵 Q 的元素,该矩阵用于边缘最大化的过程

  • svm_parameters:这包含了支持向量分类器和回归的不同模型,LIBSVM 支持的五种核及其参数,以及用于交叉验证的权重向量

  • svm_problem: 这配置了任何 SVM 算法的输入(观测数量、输入向量数据 x 作为矩阵,以及标签向量 y

  • svm: 这实现了用于训练、分类和回归的算法

该库还包括用于训练、预测和归一化数据集的模板程序。

注意

LIBSVM Java 代码

LIBSVM 的 Java 版本是原始 C 代码的直接移植。它不支持泛型类型,并且不易配置(代码使用 switch 语句而不是多态)。尽管有其局限性,但 LIBSVM 是一个相当经过测试且健壮的 SVM Java 库。

让我们创建一个 Scala 包装器来提高 LIBSVM 库的灵活性和易用性。

设计

支持向量机算法的实现使用了分类器的设计模板(参见 附录 A 中的 设计模板 for classifier 部分,基本概念)。

SVM 实现的关键组件如下:

  • 在分类器实例化过程中通过训练初始化 SVMModel 类型的模型。该模型类是 LIBSVM 中定义的 svm_model 结构的适配器。

  • SVMAdapter 对象与内部 LIBSVM 数据结构和方法进行接口。

  • SVM 支持向量机类被实现为 ITransform 类型的隐式数据转换。它有三个参数:SVMConfig 类型的配置包装器、XVSeries 类型的特征/时间序列以及目标或标记值,DblVector

  • 配置(SVMConfig 类型)由三个不同的元素组成:SVMExecution 定义执行参数,例如最大迭代次数或收敛标准,SVMKernel 指定训练过程中使用的核函数,以及 SVMFormulation 定义用于计算支持向量分类器和回归的非可分情况的公式(Cepsilonnu)。

支持向量机的主要软件组件在以下 UML 类图中描述:

设计

支持向量机的 UML 类图

UML 图省略了 Monitor 或 Apache Commons Math 组件等辅助特性和类。

配置参数

LIBSVM 为配置和执行任何 SVM 算法公开了大量参数。任何 SVM 算法都配置了三类参数,如下所示:

  • 使用 SVMFormulation 类的 SVM 算法(多类分类器、单类分类器、回归等)的公式(或类型)

  • 算法中使用的核函数(RBF 核、Sigmoid 核等)使用 SVMKernel

  • 使用 SVMExecution 类的训练和执行参数(收敛标准、交叉验证的折叠数等)

SVM 公式

配置的实例化包括通过用户选择的 SVM 类型、核函数和执行上下文初始化 param LIBSVM 参数。

每个 SVM 参数的案例类都扩展了通用 SVMConfigItem 特性:

trait SVMConfigItem { def update(param: svm_parameter): Unit }

SVMConfigItem 继承的类负责更新 LIBSVM 中定义的 SVM 参数列表 svm_parameterupdate 方法封装了 LIBSVM 的配置。

SVMFormulation 作为基特性,通过类层次结构对 SVM 算法进行公式的定义如下:

sealed trait SVMFormulation extends SVMConfigItem {   
  def update(param: svm_parameter): Unit 
}

SVM 公式的列表(对于回归,Cnueps)完全定义且已知。因此,层次结构不应更改,并且必须声明 SVMFormulation 特性为密封的。以下是一个 SVM CSVCFormulation 公式类的示例,它定义了 C-SVM 模型:

class CSVCFormulation (c: Double) extends SVMFormulation {   
   override def update(param: svm_parameter): Unit = {      
     param.svm_type = svm_parameter.C_SVC
     param.C = c
  }
}

其他 SVM NuSVCFormulationOneSVCFormulationSVRFormulation 公式类分别实现了υ-SVM、1-SVM 和ε-SVM,用于回归模型。

SVM 核函数

接下来,您需要通过定义和实现 SVMKernel 特性来指定核函数:

sealed trait SVMKernel extends SVMConfigItem {
  override def update(param: svm_parameter): Unit 
}

再次强调,LIBSVM 支持的核函数数量有限。因此,核函数的层次结构是封闭的。以下代码片段以配置半径基函数核 RbfKernel 为例,展示了核定义类的定义:

class RbfKernel(gamma: Double) extends SVMKernel {
  override def update(param: svm_parameter): Unit = {
    param.kernel_type = svm_parameter.RBF
    param.gamma = gamma
}

由于 LIBSVM Java 字节码库的可扩展性不高,但这并不妨碍您在 LIBSVM 源代码中定义一个新的核函数。例如,可以通过以下步骤添加拉普拉斯核:执行以下步骤:

  1. svm_parameter 中创建一个新的核类型,例如 svm_parameter.LAPLACE = 5

  2. svm 类中将核函数名称添加到 kernel_type_table 中。

  3. svm_check_ 参数方法中添加 kernel_type != svm_parameter.LAPLACE

  4. svm 中添加核函数的实现到两个值:kernel_function(Java 代码):

    case svm_parameter.LAPLACE:
       double sum = 0.0;
       for(int k = 0; k < x[i].length; k++) { 
         final double diff = x[i][k].value - x[j][k].value; 
         sum += diff*diff;
        }    
        return Math.exp(-gamma*Math.sqrt(sum));
    
  5. 通过修改现有的 RBF (distanceSqr) 实现添加拉普拉斯核函数到 svm.k_function 方法中。

  6. 重新构建 libsvm.jar 文件

SVM 执行

SVMExecution 类定义了模型训练执行的配置参数,即优化器的 eps 收敛因子(第 2 行),缓存大小 cacheSize(第 1 行),以及在交叉验证期间使用的折数 nFolds

class SVMExecution(cacheSize: Int, eps: Double, nFolds: Int) 
     extends SVMConfigItem {
  override def update(param: svm_parameter): Unit = { 
    param.cache_size = cacheSize //1
    param.eps = eps //2
  }
}

只有当 nFolds 值大于 1 时,才会执行交叉验证。

我们最终准备好创建 SVMConfig 配置类,该类隐藏和管理所有不同的配置参数:

class SVMConfig(formula: SVMFormulation, kernel: SVMKernel,
     exec: SVMExecution) {
  val param = new svm_parameter
  formula.update(param) //3
  kernel.update(param)  //4
  exec.update(param)  //5
}

SVMConfig 类将公式的选择委托给 SVMFormulation 类(第 3 行),核函数的选择委托给 SVMKernel 类(第 4 行),参数的执行委托给 SVMExecution 类(第 5 行)。更新调用序列初始化了 LIBSVM 配置参数列表。

LIBSVM 接口

我们需要创建一个适配器对象来封装对 LIBSVM 的调用。SVMAdapter对象隐藏了 LIBSVM 内部数据结构:svm_modelsvm_node

object SVMAdapter {
  type SVMNodes = Array[Array[svm_node]]
  class SVMProblem(numObs: Int, expected: DblArray) //6

  def createSVMNode(dim: Int, x: DblArray): Array[svm_node] //7
  def predictSVM(model: SVMModel, x: DblArray): Double //8
  def crossValidateSVM(problem: SVMProblem, //9
     param: svm_parameter, nFolds: Int, expected: DblArray) 
  def trainSVM(problem: SVMProblem,  //10
     param: svm_parameter): svm_model 
}

SVMAdapter对象是训练、验证 SVM 模型和执行预测到 LIBSVM 的单个入口点:

  • SVMProblem将 LIBSVM 中训练目标或问题的定义包装起来,使用标签或expected值(行6

  • createSVMNode为每个观察值x创建一个新的计算节点(行7

  • predictSVM根据通过训练生成的模型svm_model预测新观察值x的结果(行8

  • crossValidateSVM使用nFold训练-验证集验证模型svm_model(行9

  • trainSVM执行problem训练配置(行10

注意

svm_node

LIBSVM 的svm_nodeJava 类被定义为观察数组中特征的索引及其值的对:

public class svm_node implements java.io.Serializable {   
  public int index;    
  public double value;
}

SVMAdapter方法将在下一节中描述。

训练

SVM 模型由以下两个组件定义:

  • svm_model:这是在 LIBSVM 中定义的 SVM 模型参数

  • accuracy:这是在交叉验证期间计算的模型准确率

代码如下:

case class SVMModel(val svmmodel: svm_model, 
     val accuracy: Double) extends Model {
  lazy val residuals: DblArray = svmmodel.sv_coef(0)
}

residuals,即r = y – f(x),在 LIBSVM 库中计算。

注意

SVM 模型中的准确率

你可能会想知道为什么准确率的值是模型的一个组件。模型的准确率组件为客户端代码提供了一个与模型相关的质量指标。将准确率集成到模型中,使用户能够在接受或拒绝模型时做出明智的决定。准确率存储在模型文件中,以供后续分析。

接下来,让我们为双分类问题创建第一个支持向量机分类器。SVM 类实现了ITransform单调数据转换,该转换隐式地从训练集中生成一个模型,正如在第二章中“单调数据转换”部分所描述的,Hello World!(行11)。

SVM 的构造函数遵循附录 A 中“不可变分类器设计模板”部分所描述的模板,基本概念

class SVMT <% Double extends ITransform[Array[T]](xt) {//11

  type V = Double   //12
  val normEPS = config.eps*1e-7  //13
  val model: Option[SVMModel] = train  //14

  def accuracy: Option[Double] = model.map( _.accuracy) //15
  def mse: Option[Double]  //16
  def margin: Option[Double]  //17
}

ITransform抽象类的实现需要将预测器的输出值定义为Double(行12)。normEPS用于计算边界的舍入误差(行13)。SVMModel类型的模型通过SVM构造函数通过训练生成(行14)。最后四个方法用于计算accuracy模型的参数(行15)、均方误差mse(行16)和margin(行17)。

让我们看看训练方法,train

def train: Option[SVMModel] = Try {
  val problem = new SVMProblem(xt.size, expected.toArray) //18
  val dim = dimension(xt)

  xt.zipWithIndex.foreach{ case (_x, n) =>  //19
      problem.update(n, createSVMNode(dim, _x))
  }
  new SVMModel(trainSVM(problem, config.param), accuracy(problem))   //20
}._toOption("SVM training failed", logger)

train方法创建SVMProblem,为 LIBSVM 提供训练组件(第18行)。SVMProblem类的作用是管理 LIBSVM 中实现的训练参数的定义,如下所示:

class SVMProblem(numObs: Int, expected: DblArray) {
  val problem = new svm_problem  //21
  problem.l = numObs
  problem.y = expected 
  problem.x = new SVMNodes(numObs)

  def update(n: Int, node: Array[svm_node]): Unit = 
    problem.x(n) = node  //22
}

SVMProblem构造函数的参数、观察数和标签或期望值用于在 LIBSVM 中初始化相应的svm_problem数据结构(第21行)。update方法将每个观察值(定义为svm_node数组)映射到问题(第22行)。

createSVMNode方法从一个观察值创建一个svm_node数组。在 LIBSVM 中,svm_node是一个观察值中特征j的索引(第23行)及其值,y(第24行)的配对:

def createSVMNode(dim: Int, x: DblArray): Array[svm_node] = {
   val newNode = new Arraysvm_node
   x.zipWithIndex.foreach{ case (y, j) =>  {
      val node = new svm_node
      node.index= j  //23
      node.value = y  //24
      newNode(j) = node 
   }}
   newNode

观察与 LIBSVM 节点之间的映射在以下图中说明:

训练

使用 LIBSVM 对观察值进行索引

trainSVM方法通过调用svm_train方法,将具有明确定义的问题和配置参数的训练请求推送到 LIBSVM(第26行):

def trainSVM(problem: SVMProblem, 
     param: svm_parameter): svm_model =
   svm.svm_train(problem.problem, param) //26

准确率是真实正例加上真实负例与测试样本大小的比率(参考第二章中的关键质量指标部分,Hello World!)。只有在SVMExecution配置类中初始化的折数大于 1 时,才会通过交叉验证来计算。实际上,准确率是通过调用 LIBSVM 包中的交叉验证方法svm_cross_validation来计算的,然后计算预测值与标签匹配的数量与观察总数之比:

def accuracy(problem: SVMProblem): Double = { 
  if( config.isCrossValidation ) {
    val target = new ArrayDouble
    crossValidateSVM(problem, config.param,  //27
        config.nFolds, target)

    target.zip(expected)
       .filter{case(x, y) =>Math.abs(x- y) < config.eps}  //28
       .size.toDouble/expected.size
  }
  else 0.0
}

调用SVMAdaptercrossValidateSVM方法将配置和执行交叉验证的config.nFolds(第27行):

def crossValidateSVM(problem: SVMProblem, param: svm_parameter, 
    nFolds: Int, expected: DblArray) {
  svm.svm_cross_validation(problem.problem, param, 
    nFolds, expected)
}

Scala 的filter过滤掉预测不良的观察值(第28行)。这种最小化实现足以开始探索支持向量分类器。

分类

SVM类的|>分类方法实现遵循与其他分类器相同的模式。它调用SVMAdapter中的predictSVM方法,将请求转发给 LIBSVM(第29行):

override def |> : PartialFunction[Array[T], Try[V]] =  {
   case x: Array[T] if(x.size == dimension(xt) && isModel) =>
      Try( predictSVM(model.get.svmmodel, x) )  //29
}

C 惩罚和边缘

第一次评估包括理解惩罚因子C对生成类边缘的影响。让我们实现边缘的计算。边缘定义为2/|w|,并在SVM类中实现为一个方法,如下所示:-

def margin: Option[Double] = 
  if(isModel) {
    val wNorm = model.get.residuals./:(0.0)((s,r) => s + r*r)
    if(wNorm < normEPS) None else Some(2.0/Math.sqrt(wNorm))
  }
  else None

第一条指令计算残差r = y – f(x|w)的平方和,wNorm。如果平方和足够大以避免舍入误差,则最终计算边缘。

使用人工生成的时间序列和标记数据评估边缘。首先,我们定义评估特定惩罚(逆正则化系数)因子C的边缘的方法:

val GAMMA = 0.8
val CACHE_SIZE = 1<<8
val NFOLDS = 1
val EPS = 1e-5

def evalMargin(features: Vector[DblArray], 
    expected: DblVector, c: Double): Int = {
  val execEnv = SVMExecution(CACHE_SIZE, EPS, NFOLDS)
  val config = SVMConfig(new CSVCFormulation(c), 
     new RbfKernel(GAMMA), execEnv)
  val svc = SVMDouble
  svc.margin.map(_.toString)     //30
}

evalMargin 方法使用 CACHE_SIZEEPSNFOLDS 执行参数。执行显示不同 C 值的边缘值(行 30)。该方法通过迭代调用以评估惩罚因子对从模型训练中提取的边缘的影响。测试使用合成时间序列来突出 C 与边缘之间的关系。由 generate 方法创建的合成时间序列由两个大小相等的训练集 N 组成:

  • 为标签 1 生成的数据点 y = x(1 + r/5),其中 r 是在 [0,1] 范围内随机生成的数字(行 31

  • 为标签 -1 随机生成的数据点 y = r(行 32

考虑以下代码:

def generate: (Vector[DblArray], DblArray) = {
  val z  = Vector.tabulate(N)(i => {
    val ri = i*(1.0 + 0.2*Random.nextDouble)
    ArrayDouble  //31
  }) ++
  Vector.tabulate(N)(i =>ArrayDouble)
  (z, Array.fill(N)(1) ++ Array.fill(N)(-1))  //32
}

C 从 0 到 5 的不同值执行 evalMargin 方法:

generate.map(y => 
  (0.1 until 5.0 by 0.1)
    .flatMap(evalMargin(y._1, y._2, _)).mkString("\n") 
)

注意

值与最终值

最终值和非最终值之间有区别。非最终值可以在子类中被覆盖。覆盖最终值会产生编译错误,如下所示:

class A { val x = 5;  final val y = 8 } 
class B extends A { 
  override val x = 9 // OK    
  override val y = 10 // Error 
}

以下图表说明了惩罚或成本因子 C 与边缘之间的关系:

C-penalty and margin

支持向量分类器的边缘值与 C 惩罚因子之间的关系

如预期,随着惩罚项 C 的增加,边缘值会减小。C 惩罚因子与 L [2] 正则化因子 λ 相关,关系为 C ~ 1/λ。具有较大 C 值的模型具有高方差和低偏差,而较小的 C 值将产生较低的方差和较高的偏差。

注意

优化 C 惩罚

C 的最佳值通常通过交叉验证来评估,通过将 C 以 2 的增量幂变化:2n,2n+1,… [8:12]。

核函数评估

下一个测试是对比核函数对预测准确性的影响。再次,生成一个合成时间序列来突出每个核的贡献。测试代码使用运行时预测或分类方法 |> 来评估不同的核函数。让我们创建一个评估和比较这些核函数的方法。我们需要的只是以下内容(行 33):

  • Vector[DblArray] 类型的 xt 训练集

  • Vector[DblArray] 类型的测试集 test

  • 训练集的一组 labels,其值为 0 或 1

  • kF 核函数

考虑以下代码:

val C = 1.0
def evalKernel(xt: Vector[DblArray],  test: Vector[DblArray], 
     labels: DblVector, kF: SVMKernel): Double = { //33

  val config = SVMConfig(new CSVCFormulation(C), kF) //34
  val svc = SVMDouble
  val pfnSvc = svc |>  //35
  test.zip(labels).count{case(x, y) =>pfnSvc(x).get == y}
    .toDouble/test.size  //36
}

SVM 的 config 配置使用 C 惩罚因子 1,C 公式,以及默认执行环境(行 34)。预测的 pfnSvc 部分函数(行 35)用于计算测试集的预测值。最后,evalKernel 方法计算预测值与标记或预期值匹配的成功次数。准确率是成功预测与测试样本大小的比率(行 36)。

为了比较不同的核函数,让我们使用伪随机 genData 数据生成方法生成三个大小为 2N 的二项分类数据集:

def genData(variance: Double, mean: Double): Vector[DblArray] = {
  val rGen = new Random(System.currentTimeMillis)
  Vector.tabulate(N)( _ => { 
    rGen.setSeed(rGen.nextLong)
    ArrayDouble
      .map(variance*_ - mean)  //37
  })
}

随机值通过变换 f(x) = variancex = mean* (行 37) 来计算。训练集和测试集由两类数据点的总和组成:

  • 与标签 0.0 相关的具有方差 a 和均值 b 的随机数据点

  • 与标签 1.0 相关的具有方差 a 和均值 1-b 的随机数据点

考虑以下代码用于训练集:

val trainSet = genData(a, b) ++ genData(a, 1-b)
val testSet = genData(a, b) ++ genData(a, 1-b)

ab 参数是从具有不同分离程度的两组训练数据点中选择,以说明分离超平面。

下图描述了高边缘;使用参数 a = 0.6b = 0.3 生成的第一个训练集展示了高度可分离的类别,具有干净且明显的超平面:

内核评估

当 a = 0.6 和 b = 0.3 时的训练集和测试集散点图

下图描述了中等边缘;参数 a = 0.8b = 0.3 生成两组具有一些重叠的观察值:

内核评估

当 a = 0.8 和 b = 0.3 时的训练集和测试集散点图

下图描述了低边缘;这个最后训练集中的两组观察值是用 a = 1.4b = 0.3 生成的,并显示出显著的重叠:

内核评估

当 a = 1.4 和 b = 0.3 时的训练集和测试集散点图

测试集以与训练集类似的方式生成,因为它们是从相同的数据源提取的:

val GAMMA = 0.8; val COEF0 = 0.5; val DEGREE = 2 //38
val N = 100

def compareKernel(a: Double, b: Double) {
  val labels = Vector.fill(N)(0.0) ++ Vector.fill(N)(1.0)
  evalKernel(trainSet, testSet,labels,new RbfKernel(GAMMA)) 
  evalKernel(trainSet, testSet, labels, 
      new SigmoidKernel(GAMMA)) 
  evalKernel(trainSet, testSet, labels, LinearKernel) 
  evalKernel(trainSet, testSet, labels, 
      new PolynomialKernel(GAMMA, COEF0, DEGREE))
}

每个四个核函数的参数都是从教科书中任意选择的(行 38)。之前定义的 evalKernel 方法应用于三个训练集:高边缘 (a = 1.4)、中等边缘 (a = 0.8) 和低边缘 (a = 0.6),每个都使用四个核(RBF、sigmoid、线性和多项式)。通过计算预测器每次调用中对所有类别的正确分类的观察数来评估准确性,|>

内核评估

使用合成数据的核函数比较图

尽管不同的核函数在影响分类器准确性的方面没有差异,但你可以观察到 RBF 和多项式核产生的结果略为准确。正如预期的那样,随着边缘的减小,准确性降低。减小的边缘表示案例不易分离,这会影响分类器的准确性:

内核评估

RBF 和 Sigmoid 核函数的边缘值对准确性的影响

注意

测试用例设计

比较不同核方法的测试高度依赖于训练集和测试集中数据的分布或混合。在这个测试案例中,使用合成数据来展示观察值类之间的边界。现实世界的数据集可能会产生不同的结果。

总结来说,创建基于 SVC 的模型需要四个步骤:

  1. 选择一个特征集。

  2. 选择 C-惩罚(逆正则化)。

  3. 选择核函数。

  4. 调整核参数。

如前所述,这个测试案例依赖于合成数据来展示边界概念和比较核方法。让我们使用支持向量分类器来展示现实世界的金融应用。

风险分析中的应用

测试案例的目的是评估公司削减或消除其季度或年度股息的风险。所选特征是反映公司长期产生现金流和支付股息能力的财务指标。

我们需要选择以下财务技术分析指标中的任何子集(参考附录 A,基本概念):

  • 过去 12 个月股票价格的相对变化

  • 长期债务权益比率

  • 股息覆盖率比率

  • 年度股息收益率

  • 营业利润率

  • 空头(已售出股份与流通股份的比例)

  • 每股现金-股价比率

  • 每股收益趋势

收益趋势有以下值:

  • 如果过去 12 个月内每股收益下降超过 15%,则-2

  • -1 如果每股收益下降在 5%到 15%之间。

  • 如果每股收益保持在 5%以内,则 0

  • 如果每股收益在 5%到 15%之间增加,则+1。

  • 如果每股收益增加超过 15%,则+2。这些值用 0 和 1 进行归一化。

标签或预期输出(股息变化)按以下方式分类:

  • 如果股息削减超过 5%,则-1

  • 如果股息保持在 5%以内,则 0

  • 如果股息增加超过 5%,则+1

让我们将这三个标签中的两个 {-1, 0, 1} 结合起来,为二元 SVC 生成两个类别:

  • C1 类 = 稳定或下降的股息和 C2 类 = 增加的股息——训练集 A

  • C1 类 = 下降的股息和 C2 类 = 稳定或增加的股息——训练集 B

使用一组固定的 CGAMMA 配置参数以及 2 折验证配置进行不同的测试:

val path = "resources/data/chap8/dividends2.csv"
val C = 1.0
val GAMMA = 0.5
val EPS = 1e-2
val NFOLDS = 2

val extractor = relPriceChange :: debtToEquity :: 
    dividendCoverage :: cashPerShareToPrice :: epsTrend :: 
    shortInterest :: dividendTrend :: 
    List[Array[String] =>Double]()  //39

val pfnSrc = DataSource(path, true, false,1) |> //40
val config = SVMConfig(new CSVCFormulation(C), 
     new RbfKernel(GAMMA), SVMExecution(EPS, NFOLDS))

for {
  input <- pfnSrc(extractor) //41
  obs <- getObservations(input)  //42
  svc <- SVMDouble
} yield {
  show(s"${svc.toString}\naccuracy ${svc.accuracy.get}")
}

第一步是定义 extractor(即从 dividends2.csv 文件中检索的字段列表)(行 39)。由 DataSource 转换类生成的 pfnSrc 部分函数(行 40)将输入文件转换为一系列类型字段(行 41)。一个观察值是一个字段数组。通过转置矩阵观察值 x 特征生成 obs 观察值序列(行 42):

def getObservations(input: Vector[DblArray]):
     Try[Vector[DblArray]] = Try {
  transpose( input.dropRight(1).map(_.toArray) ).toVector
}

测试在 SVM 实例化的过程中计算模型参数和交叉验证的准确性。

注意

LIBSVM 缩放

LIBSVM 支持在训练之前进行特征归一化,称为缩放。缩放的主要优势是避免更大数值范围的属性支配那些数值范围较小的属性。另一个优势是避免计算过程中的数值困难。在我们的示例中,我们使用normalize时间序列的归一化方法。因此,LIBSVM 中的缩放标志被禁用。

测试重复使用不同的特征集,并包括比较不同特征集的支持向量分类器的准确性。特征集是通过使用不同配置的提取器从.csv文件的内容中选择的,如下所示:

val extractor =  … :: dividendTrend :: …

让我们看一下以下图表:

风险分析应用

使用二元 SVC 进行交易策略比较研究

测试表明,选择合适的特征集是应用支持向量机(以及任何其他模型)到分类问题中最关键的步骤。在这个特定案例中,准确性也受到训练集规模较小的 影响。特征数量的增加也减少了每个特定特征对损失函数的贡献。

注意

N 折交叉验证

在这个测试示例中,交叉验证只使用两个折,因为观察的数量很小,并且你想要确保任何类别至少包含几个观察值。

对于测试 B,重复相同的过程,其目的是对减少分配的公司和稳定或增加分配的公司进行分类,如下面的图所示:

风险分析应用

使用二元 SVC 进行交易策略比较研究

在前一个图中,第一个三个特征集和最后两个特征集在预测准确性方面的差异在测试 A 中比测试 B 中更为明显。在这两个测试中,eps特征(每股收益)的趋势提高了分类的准确性。它是对增加分配的公司特别好的预测器。

预测(或不)分配的分布问题可以重新表述为评估公司大幅减少其分配的风险。

如果一家公司完全取消其分配,风险是什么?这种情况很少见,这些案例实际上是异常值。可以使用单类支持向量分类器来检测异常值或异常[8:13]。

使用单类 SVC 进行异常检测

单类支持向量机(SVC)的设计是二类 SVC 的扩展。主要区别在于单个类别包含了大部分基线(或正常)观测值。一个称为 SVC 原点的参考点取代了第二类。异常(或异常)观测值位于单个类别的支持向量之外(或之外):

使用单类 SVC 进行异常检测

单类 SVC 的可视化

异常观测值有一个标记值为-1,而剩余的训练集被标记为+1。为了创建一个相关的测试,我们添加了四家大幅削减股息的公司(股票代码 WLT、RGS、MDC、NOK 和 GM)。数据集包括在削减股息之前的股票价格和财务指标。

此测试案例的实现与二类 SVC 驱动代码非常相似,除了以下方面:

  • 分类器使用 Nu-SVM 公式,OneSVFormulation

  • 标记数据是通过将-1 分配给取消股息的公司,将+1 分配给所有其他公司生成的

测试是在resources/data/chap8/dividends2.csv数据集上执行的。首先,我们需要定义单类 SVM 的公式:

class OneSVCFormulation(nu: Double) extends SVMFormulation {
  override def update(param: svm_parameter): Unit = {
    param.svm_type = svm_parameter.ONE_CLASS
    param.nu = nu
  }
}

测试代码与二类 SVC 执行代码类似。唯一的区别是输出标签的定义;取消股息的公司为-1,其他所有公司为+1:

val NU = 0.2
val GAMMA = 0.5
val EPS = 1e-3
val NFOLDS = 2

val extractor = relPriceChange :: debtToEquity ::
   dividendCoverage :: cashPerShareToPrice :: epsTrend ::
   dividendTrend :: List[Array[String] =>Double]()

val filter = (x: Double) => if(x == 0) -1.0 else 1.0  //43
val pfnSrc = DataSource(path, true, false, 1) |>
val config = SVMConfig(new OneSVCFormulation(NU),  //44
    new RbfKernel(GAMMA), SVMExecution(EPS, NFOLDS))

for {
  input <- pfnSrc(extractor)
  obs <- getObservations(input)
  svc <- SVMDouble).toVector)
} yield {
  show(s"${svc.toString}\naccuracy ${svc.accuracy.get}")'
}

标签或预期数据是通过应用二进制过滤器到最后的dividendTrend字段(行43)生成的。配置中的公式具有OneSVCFormulation类型(行44)。

模型以 0.821 的准确率生成。这种准确率水平不应令人惊讶;异常值(取消股息的公司)被添加到原始股息.csv文件中。这些异常值与原始输入文件中的基线观测值(减少、维持或增加股息的公司)有显著差异。

在有标记观测值的情况下,单类支持向量机是聚类技术的优秀替代品。

注意

异常的定义

单类支持向量分类器生成的结果在很大程度上取决于对异常的主观定义。测试案例假设取消股息的公司具有独特的特征,使其与众不同,甚至与削减、维持或增加股息的公司也不同。不能保证这个假设确实总是有效的。

支持向量回归

大多数使用支持向量机的应用都与分类相关。然而,同样的技术也可以应用于回归问题。幸运的是,与分类一样,LIBSVM 支持两种支持向量回归公式:

  • ∈-VR(有时称为 C-SVR)

  • υ-SVR

为了与前两个案例保持一致性,以下测试使用支持向量回归的 ∈(或 C)公式。

概述

SVR 引入了 误差不敏感区 和不敏感误差 ε 的概念。不敏感区定义了预测值 y(x) 附近的值范围。惩罚成分 C 不影响属于不敏感区 [8:14] 的数据点 {x[i],y[i]}

下图使用单个变量特征 x 和输出 y 阐述了误差不敏感区的概念。在单变量特征的情况下,误差不敏感区是一个宽度为 的带(ε 被称为不敏感误差)。不敏感误差在 SVC 中扮演着与边缘相似的角色。

概述

支持向量回归和不敏感误差的可视化

对于数学倾向者,非线性模型的边缘最大化引入了一对松弛变量。如您所记得,C-支持向量分类器使用单个松弛变量。前面的图展示了最小化公式。

注意

M9:ε-SVR 公式定义为:

概述

这里,ε 是不敏感误差函数。

M10:ε-SVR 回归方程如下:

概述

让我们重用 SVM 类来评估 SVR 的能力,与线性回归相比(参考第六章 普通最小二乘回归中的 普通最小二乘回归 部分,回归和正则化)。

SVR 与线性回归的比较

本测试包括重用单变量线性回归的示例(参考第六章 单变量线性回归中的 单变量线性回归 部分,回归和正则化)。目的是比较线性回归的输出与 SVR 的输出,以预测股价或指数的值。我们选择了标准普尔 500 交易所交易基金,SPY,它是标准普尔 500 指数的代理。

模型由以下内容组成:

  • 一个标记的输出:SPY 调整后的每日收盘价

  • 单个变量特征集:交易时段的索引(或 SPY 的值索引)

实现遵循熟悉的模式:

  1. 定义 SVR 的配置参数(C 成本/惩罚函数,RBF 内核的 GAMMA 系数,EPS 用于收敛标准,以及 EPSILON 用于回归不敏感误差)。

  2. 从数据源(DataSource)中提取标记数据(SPY 的 price),数据源是 Yahoo 财经 CSV 格式的数据文件。

  3. 创建线性回归,SingleLinearRegression,以交易时段的索引作为单一变量特征,以 SPY 调整后的收盘价作为标记的输出。

  4. 将观测值创建为时间序列索引,xt

  5. 使用交易会话的索引作为特征,将 SPY 调整后的收盘价作为标记输出实例化 SVR。

  6. 运行 SVR 和线性回归的预测方法,并比较线性回归和 SVR 的结果,collect

代码如下:

val path = "resources/data/chap8/SPY.csv"
val C = 12
val GAMMA = 0.3
val EPSILON = 2.5

val config = SVMConfig(new SVRFormulation(C, EPSILON), 
    new RbfKernel(GAMMA)) //45
for {
  price <-  DataSource(path, false, true, 1) get close
  (xt, y) <- getLabeledData(price.size)  //46
  linRg <- SingleLinearRegressionDouble //47
  svr <- SVMDouble
} yield {
  collect(svr, linRg, price)
}

配置中的公式具有SVRFormulation类型(第45行)。DataSource类提取 SPY ETF 的价格。getLabeledData方法生成xt输入特征和y标签(或预期值)(第46行):

type LabeledData = (Vector[DblArray], DblVector)
def getLabeledData(numObs: Int): Try[LabeledData ] = Try {
    val y = Vector.tabulate(numObs)(_.toDouble)
    val xt = Vector.tabulate(numObs)(ArrayDouble)
    (xt, y)
}

单变量线性回归SingleLinearRegression使用price输入和y标签作为输入(第47行)进行实例化。

最后,collect方法执行了两个pfSvrpfLinr回归部分函数:

def collect(svr: SVM[Double], 
   linr: SingleLinearRegression[Double], price: DblVector){

  val pfSvr = svr |>
  val pfLinr = linr |>
  for {
    if( pfSvr.isDefinedAt(n.toDouble))
    x <- pfSvr(n.toDouble) 
    if( pfLin.isDefinedAt(n))
    y <- pfLinr(n)
  } yield  {  ... }
}

注意

isDefinedAt

验证一个部分函数是否为特定参数值定义是一个好的实践。这种预防性方法允许开发者选择一个替代方法或完整函数。这是捕获MathErr异常的有效替代方案。

结果显示在下图中,这些图是用 JFreeChart 库生成的。绘制数据的代码被省略,因为它对于理解应用程序不是必要的。

SVR 与线性回归对比

线性回归和 SVR 的比较图

支持向量回归比线性回归模型提供更准确的预测。你还可以观察到 SVR 的 L[2]正则化项对价格均值有较大偏差的数据点(SPY 价格)进行惩罚。C的值越低,L[2]-范数惩罚因子λ =1/C就越大。

注意

SVR 和 L[2] 正则化

欢迎您尝试使用不同的C值来量化 L[2]正则化对 SVR 预测值的影响。

没有必要将 SVR 与逻辑回归进行比较,因为逻辑回归是一个分类器。然而,SVM 与逻辑回归相关;SVM 中的 hinge 损失与逻辑回归中的损失相似[8:15]。

性能考虑

您可能已经观察到,在大型数据集上对支持向量回归模型进行训练是耗时的。支持向量机的性能取决于在训练期间选择的优化器类型(例如,序列最小优化)以最大化边缘:

  • 一个线性模型(没有核的 SVM)在训练N个标记观察值时具有渐近时间复杂度O(N)

  • 非线性模型依赖于核方法,这些方法被表述为具有渐近时间复杂度O(N³)的二次规划问题。

  • 使用序列最小优化技术(如索引缓存或消除空值,如 LIBSVM)的算法具有渐近时间复杂度O(N²),最坏情况(二次优化)为O(N³)

  • 对于非常大的训练集(N > 10,000)的稀疏问题也有O(N²)的渐近时间复杂度

核化支持向量机的时间和空间复杂度已经引起了极大的关注[8:16] [8:17]。

概述

这标志着我们对核和支持向量机的调查结束。支持向量机已成为从大型训练集中提取判别模型的一种稳健替代方案,优于逻辑回归和神经网络。

除了不可避免地引用最大边缘分类器的数学基础,如 SVMs 之外,你还应该对 SVMs 不同变体的调整和配置参数的强大功能和复杂性有一个基本的了解。

与其他判别模型一样,SVMs 的优化方法选择不仅对模型的质量有重大影响,而且对训练和交叉验证过程的表现(时间复杂度)也有重大影响。

下一章将描述第三种最常用的判别监督模型——人工神经网络。

第九章:人工神经网络

20 世纪 90 年代,神经网络的人气激增。它们被视为解决大量问题的银弹。从核心来看,神经网络是一个非线性统计模型,它利用逻辑回归来创建非线性分布式模型。人工神经网络的概念源于生物学,目的是模拟大脑的关键功能,并在神经元、激活和突触方面复制其结构。

在本章中,你将超越炒作,学习以下主题:

  • 多层感知器MLP)的概念和元素

  • 如何使用误差反向传播训练神经网络

  • 评估和调整 MLP 配置参数

  • MLP 分类器的完整 Scala 实现

  • 如何将 MLP 应用于提取货币汇率的相关模型

  • 卷积神经网络CNN)简介

前馈神经网络

人工神经网络背后的想法是构建大脑中自然神经网络的数学和计算模型。毕竟,大脑是一个非常强大的信息处理引擎,在诸如学习、归纳推理、预测和视觉、语音识别等领域超越了计算机。

生物背景

在生物学中,神经网络由通过突触相互连接的神经元群组成[9:1],如下面的图所示:

生物背景

生物神经元和突触的可视化

神经科学家特别感兴趣的是了解大脑中数十亿个神经元如何相互作用,为人类提供并行处理能力。20 世纪 60 年代出现了一个新的研究领域,称为联结主义。联结主义将认知心理学、人工智能和神经科学结合起来。目标是创建一个心理现象的模型。尽管联结主义有多种形式,但神经网络模型已成为所有联结主义模型中最受欢迎和最常教授的 [9:2]。

生物神经元通过称为刺激的电能进行交流。这个神经元网络可以用以下简单的示意图表示:

生物背景

神经层、连接和突触的表示

这种表示法将神经元群分类为层。用于描述自然神经网络的术语在人工神经网络中也有相应的命名法。

生物神经网络 人工神经网络
轴突 连接
树突 连接
突触 权重
电势 加权求和
阈值 偏置权重
信号、刺激 激活
神经元群 神经元层

在生物世界中,刺激在神经元之间不沿任何特定方向传播。人工神经网络可以具有相同的自由度。数据科学家最常用的人工神经网络有一个预定义的方向:从输入层到输出层。这些神经网络被称为前馈神经网络FFNN)。

数学背景

在上一章中,你了解到支持向量机有能力将模型的训练表述为非线性优化,其目标函数是凸的。凸目标函数相对容易实现。缺点是 SVM 的核化可能会导致大量的基函数(或模型维度)。请参阅第八章下的核技巧部分,核模型和支持向量机。一种解决方案是通过参数化减少基函数的数量,这样这些函数可以适应不同的训练集。这种方法可以建模为 FFNN,称为多层感知器 [9:3]。

线性回归可以可视化为一个简单的使用神经元和突触的连接模型,如下所示:

数学背景

二层神经网络

特征 x[0]=+1 被称为偏置输入(或偏置元素),它对应于经典线性回归中的截距。

与支持向量机一样,线性回归适用于可以线性分离的观测值。现实世界通常由非线性现象驱动。因此,逻辑回归自然被用来计算感知器的输出。对于一组输入变量 x = {x[i]}[0,n] 和权重 w={w[i]}[1,n],输出 y 的计算如下(M1):

数学背景

FFNN 可以被视为堆叠了具有线性回归输出层的逻辑回归层。

每个隐藏层中变量的值是通过连接权重和前一层的输出的点积的 Sigmoid 函数来计算的。尽管很有趣,但人工神经网络背后的理论超出了本书的范围[9:4]。

多层感知器

感知器是一个基本的处理单元,通过将标量或向量映射到二元(或XOR)值 {true, false}{-1, +1} 来执行二元分类。原始的感知器算法被定义为单层神经元,其中特征向量的每个值 x[i] 都并行处理并生成单个输出 y。感知器后来扩展到包括激活函数的概念。

单层感知器仅限于处理单个权重的线性组合和输入值。科学家发现,在输入层和输出层之间添加中间层使他们能够解决更复杂的分类问题。这些中间层被称为隐藏层,因为它们只与其他感知器接口。隐藏节点只能通过输入层访问。

从现在开始,我们将使用三层感知器来研究和说明神经网络的特征,如下所示:

多层感知器

三层感知器

三层感知器需要两组权重:w[ij] 用于处理输入层到隐藏层的输出,v[ij] 在隐藏层和输出层之间。在线性回归和逻辑回归中,截距值 w[0] 在神经网络的可视化中表示为 +1w[0].1+ w[1].x[1]+w[2].x[2]+ …)。

备注

无隐藏层的 FFNN

无隐藏层的 FFNN 类似于线性统计模型。输入层和输出层之间唯一的转换或连接实际上是一个线性回归。线性回归是无隐藏层 FFNN 的一个更有效的替代方案。

MLP 组件及其实现的描述依赖于以下阶段:

  1. 软件设计的概述。

  2. MLP 模型组件的描述。

  3. 四步训练周期的实现。

  4. 训练策略的定义和实现以及由此产生的分类器。

备注

术语

人工神经网络包含大量学习算法,多层感知器是其中之一。感知器确实是按照输入、输出和隐藏层组织起来的神经网络组件。本章专门讨论具有隐藏层的多层感知器。术语“神经网络”和“多层感知器”可以互换使用。

激活函数

感知器被表示为权重w[i]和输入值x[i]通过输出单元激活函数h的线性组合,如下所示(M2):

激活函数

输出激活函数h必须对于权重的一定范围内的值是连续且可微分的。它根据要解决的问题采取不同的形式,如下所述:

  • 回归模式的输出层(线性公式)的恒等式

  • 二项式分类器的隐藏层和输出层的 sigmoid σ

  • 多项式分类的 softmax

  • 双曲正切tanh用于使用零均值的分类

软 max 公式在训练周期下的步骤 1 - 输入前向传播中描述。

网络拓扑

输出层和隐藏层具有计算能力(权重、输入和激活函数的点积)。输入层不转换数据。一个 n 层神经网络是一个具有n计算层的网络。其架构由以下组件组成:

  • 一个输入层

  • n-1隐藏层

  • 一个输出层

一个全连接神经网络的所有输入节点都连接到隐藏层神经元。如果一个或多个输入变量没有被处理,网络被定义为部分连接神经网络。本章讨论的是全连接神经网络。

备注

部分连接网络

部分连接网络并不像看起来那么复杂。它们可以通过将一些权重设置为零从全连接网络生成。

输出层的结构高度依赖于需要解决的问题的类型(回归或分类),也称为多层感知器的操作模式。当前问题的类型定义了输出节点的数量[9:5]。考虑以下示例:

  • 一个一元回归有一个输出节点,其值是一个实数[0, 1]

  • 一个具有n个变量的多元回归有n个实数输出节点

  • 二元分类有一个二元输出节点{0, 1}{-1, +1}

  • 多项式或 K 类分类有K个二元输出节点

设计

MLP 分类器的实现遵循与先前分类器相同的模式(参考附录 A 中的不可变分类器设计模板部分,基本概念):

  • 一个MLPNetwork连接主义网络由MLPLayer类型的神经元层组成,这些神经元通过MLPConnection类型的连接器中的MLPSynapse类型的突触连接。

  • 所有配置参数都被封装到一个单一的MLPConfig配置类中。

  • 一个模型,MLPModel,由一系列连接突触组成。

  • MLP多层感知器类被实现为一个数据转换ITransform,模型会自动从带有标签的训练集中提取。

  • MLP多层感知器类接受四个参数:一个配置、一个XVSeries类型的特征集或时间序列、一个带有标签的XVSeries类型的数据集以及一个Function1[Double, Double]类型的激活函数。

多层感知器的软件组件在以下 UML 类图中描述:

.

设计

多层感知器的 UML 类图

类图是一个方便的导航图,用于理解构建 MLP 所使用的 Scala 类的角色和关系。让我们从 MLP 网络及其组件的实现开始。UML 图省略了Monitor或 Apache Commons Math 组件等辅助特性和类。

配置

多层感知器的MLPConfig配置包括定义网络配置及其隐藏层、学习和训练参数以及激活函数:

case class MLPConfig(
    val alpha: Double,  //1
    val eta: Double, 
    val numEpochs: Int, 
    val eps: Double, 
    val activation: Double => Double) extends Config {  //1
}

为了提高可读性,配置参数的名称与数学公式中定义的符号相匹配(行1):

  • alpha: 这是用于在线训练中平滑权重梯度计算的动量因子α。动量因子在第二步 - 错误反向传播下的训练周期中的数学表达式M10中使用。

  • eta: 这是梯度下降中使用的学习率η。梯度下降通过数量eta.(predicted – expected).input更新模型的权重或参数,如数学公式M9第二步 - 错误反向传播部分下训练周期中所述。梯度下降在第一章的Let's kick the tires中介绍,入门

  • numEpochs: 这是允许训练神经网络的最多周期数(或循环或剧集)。一个周期是在整个观察集上执行错误反向传播的执行。

  • eps: 这是在神经网络训练过程中用作退出条件的收敛标准,当error < eps时。

  • 激活函数: 这是用于非线性回归并应用于隐藏层的激活函数。默认函数是用于逻辑回归的 sigmoid 函数(或双曲正切函数)(参考第六章中的逻辑函数部分,回归和正则化)。

网络组件

MLP 模型的训练和分类依赖于网络架构。MLPNetwork类负责创建和管理网络的不同组件及其拓扑,即层、突触和连接。

网络拓扑

实例化MLPNetwork类需要一个最小参数集,包括一个模型实例,以及一个可选的第三个参数(行2):

  • 在上一节中引入的 MLP 执行配置,config

  • 定义为每层节点数的数组:输入层、隐藏层和输出层。

  • 如果已经通过训练生成,则具有Option[MLPModel]类型的model,否则为None

  • 对 MLP 操作模式的隐式引用

代码如下:

class MLPNetwork(config: MLPConfig, 
     topology: Array[Int], 
     model: Option[MLPModel] = None)
     (implicit mode: MLPMode){ //2

  val layers = topology.zipWithIndex.map { case(t, n) => 
    if(topology.size != n+1) 
       MLPLayer(n, t+1, config.activation) 
   else MLPOutLayer(n, t) 
  }  //3
  val connections = zipWithShift1(layers,1).map{case(src,dst) => 
     new MLPConnection(config, src, dst,  model)} //4

  def trainEpoch(x: DblArray, y: DblArray): Double //5
  def getModel: MLPModel  //6
  def predict(x: DblArray): DblArray  //7
}

MLP 网络具有以下组件,这些组件来自拓扑数组:

  • MLPLayers类的多个(行3

  • MLPConnection类的多个连接(行4

拓扑定义为从输入节点开始的每层的节点数数组。数组索引遵循网络中的前向路径。输入层的大小自动从特征向量的大小生成。输出层的大小自动从输出向量的大小提取(行3)。

MLPNetwork的构造函数通过将MLPLayer实例分配并排序到拓扑中的每个条目来创建一系列层(行3)。构造函数创建了层数 - 1MLPConnection类型的层间连接(行4)。XTSeries对象的zipWithShift1方法将时间序列与其重复移位一个元素的时间序列进行连接。

trainEpoch方法(行5)实现了对整个观察集的单次遍历的网络训练(参考训练周期下的整合一切部分)。getModel方法检索通过训练 MLP 生成的模型(突触)(行6)。predict方法使用前向传播算法计算网络生成的输出值(行7)。

以下图表展示了模型不同组件之间的交互:MLPLayerMLPConnectionMLPSynapse

网络拓扑

MLP 网络的核心组件

输入层和隐藏层

首先,让我们从MLPLayer层类的定义开始,该类完全由其在网络中的位置(或排名)id以及它包含的节点数,numNodes来指定:

class MLPLayer(val id: Int, val numNodes: Int, 
    val activation: Double => Double)  //8
    (implicit mode: MLPMode){  //9

  val output = Array.fill(numNodes)(1.0)  //10

  def setOutput(xt: DblArray): Unit =
      xt.copyToArray(output, 1) //11
  def activate(x: Double): Double = activation(x) .//12
  def delta(loss: DblArray, srcOut: DblArray, 
      synapses: MLPConnSynapses): Delta //13
  def setInput(_x: DblArray): Unit  //14
}

id参数是层在网络中的顺序(输入层为 0,第一个隐藏层为 1,输出层为n – 1)。numNodes值是此层中元素或节点的数量,包括偏置元素。activation函数是层给出的用户定义模式或目标的最后一个参数(行8)。操作mode必须在实例化层之前隐式提供(行9)。

层的output向量是一个未初始化的值数组,在正向传播过程中更新。它使用值 1.0 初始化偏置值(行9)。与输出向量相关的权重差矩阵deltaMatrix(行10)使用错误反向传播算法更新,如训练周期下的步骤 2 – 错误反向传播部分所述。setOutput方法在反向传播错误到网络的输出(预期 – 预测)值时初始化输出和隐藏层的输出值(行11)。

activate方法调用配置中定义的激活方法(tanhsigmoid等)(行12)。

delta方法计算应用于每个权重或突触的校正,如训练周期下的步骤 2 – 错误反向传播部分所述(行13)。

setInput方法使用值x初始化输入和隐藏层的节点output值,除了偏置元素。该方法在输入值的正向传播过程中被调用(行14)。

def setInput(x: DblVector): Unit = 
  x.copyToArray(output, output.length -x.length)

MLPLayer类的方法对于输入和隐藏层被重写以适用于MLPOutLayer类型的输出层。

输出层

与隐藏层相反,输出层既没有激活函数也没有偏置元素。MLPOutLayer类有以下参数:网络中的顺序id(作为网络的最后一层)和输出或节点的数量numNodes(行15):

class MLPOutLayer(id: Int, numNodes: Int) 
    (implicit mode: MLP.MLPMode)  //15
  extends MLPLayer(id, numNodes, (x: Double) => x) {

  override def numNonBias: Int = numNodes
  override def setOutput(xt: DblArray): Unit = 
    obj(xt).copyToArray(output)
  override def delta(loss: DblArray, srcOut: DblArray, 
     synapses: MLPConnSynapses): Delta 
 …
}

numNonBias方法返回网络的实际输出值数量。delta方法的实现如训练周期下的步骤 2 – 错误反向传播部分所述。

突触

突触被定义为两个实数(浮点数)值的对:

  • 从前一层的神经元i到神经元j的连接的权重w[ij]

  • 权重调整(或权重梯度)∆w[ij]

它的类型被定义为MLPSynapse,如下所示:

type MLPSynapse = (Double, Double)

连接

连接通过选择索引n的两个连续层(相对于n + 1)作为源(相对于目标)来实例化。两个连续层之间的连接实现了突触矩阵,作为(w[ij] , ∆w[ij])对。MLPConnection实例使用以下参数创建(行16):

  • 配置参数,config

  • 源层,有时也称为入口层,src

  • dst目标(或出口)层

  • 如果模型已经通过训练生成,则引用model,如果没有训练,则为None

  • 隐式定义的操作模式或目标模式

MLPConnection类定义如下:

type MLPConnSynapses = Array[Array[MLPSynapse]]

class MLPConnection(config: MLPConfig, 
    src: MLPLayer, 
    dst: MLPLayer,
    model: Option[MLPModel]) //16
    (implicit mode: MLP.MLPMode) {

  var synapses: MLPConnSynapses  //17
  def connectionForwardPropagation: Unit //18
  def connectionBackpropagation(delta: Delta): Delta  //19
    …
}

MLP 算法初始化的最后一步是选择权重(突触)的初始值(通常是随机的)(行17)。

MLPConnection方法实现了connectionForwardPropagation连接的权重计算的前向传播(行18)和训练期间的 delta 误差的反向传播connectionBackpropagation(行19)。这些方法在下一节中描述,该节与 MLP 模型的训练相关。

权重的初始化

权重的初始化值取决于特定领域。一些问题需要一个非常小的范围,小于1e-3,而其他问题则使用概率空间[0, 1]。初始值会影响收敛到最优权重集所需的 epoch 数量[9:6]。

我们的实现依赖于 sigmoid 激活函数,并使用范围[0, BETA/sqrt(numOutputs + 1)](行20)。然而,用户可以选择不同的随机值范围,例如tanh激活函数的[-r, +r]。偏置的权重显然定义为w[0] =+1,其权重调整初始化为∆w[0] = 0,如下所示(行20):

var synapses: MLPConnSynapses = if(model == None) {
  val max = BETA/Math.sqrt(src.output.length+1.0) //20
  Array.fill(dst.numNonBias)(
    Array.fill(src.numNodes)((Random.nextDouble*max,0.00))
  )
} else model.get.synapses(src.id)  //21

如果连接已经通过训练创建,则其权重或突触来自模型(行21)。

模型

MLPNetwork类定义了多层感知器的拓扑模型。权重或突触是多层感知器模型MLPModel类型的属性,通过训练生成:

case class MLPModel(
   val synapses: Vector[MLPConnSynapses]) extends Model

模型可以存储在简单的键值对 JSON、CVS 或序列文件中。

备注

封装和模型工厂

网络组件:连接、层和突触被实现为清晰起见的顶级类。然而,模型不需要将其内部工作暴露给客户端代码。这些组件应声明为模型的内部类。工厂设计模式非常适合动态实例化MLPNetwork实例[9:7]。

一旦初始化,MLP 模型就准备好使用前向传播、输出误差反向传播和权重及权重梯度的迭代调整组合进行训练。

问题类型(模式)

与多层感知器相关联的有三种不同类型的问题或操作模式:

  • 二项式分类(二元)具有两个类别和一个输出

  • 多项式分类(多类别)具有n个类别和输出

  • 回归

每种操作模式都有独特的错误、隐藏层和输出层激活函数,如下表所示:

操作模式 错误函数 隐藏层激活函数 输出层激活函数
二项式分类 交叉熵 Sigmoid Sigmoid
多项式分类 平方误差和或均方误差 Sigmoid Softmax
回归 平方误差和或均方误差 Sigmoid 线性

多层感知器的操作模式表

交叉熵由数学表达式M6M7描述,而 softmax 在训练周期下的步骤 1 – 输入前向传播中使用公式M8

在线训练与批量训练

一个重要的问题是找到一种策略来按有序数据序列进行时间序列的训练。有两种策略可以创建用于时间序列的 MLP 模型:

  • 批量训练:整个时间序列一次作为单个输入处理到神经网络中。在每个周期使用时间序列输出的平方误差总和更新权重(突触)。一旦平方误差总和满足收敛标准,训练就会退出。

  • 在线训练:一次将观测值输入到神经网络中。一旦处理完时间序列,就会计算所有观测值的时间序列平方误差总和(sse)。如果未满足退出条件,观测值将被网络重新处理。在线训练与批量训练

    在线和批量训练的示意图

在线训练比批量训练快,因为每个数据点都必须满足收敛标准,这可能导致更少的周期数 [9:12]。如前所述的动量因子或任何自适应学习方案可以提高在线训练方法的表现和准确性。

本章的所有测试案例都应用了在线训练策略。

训练周期

模型的训练通过迭代多次处理训练观测值。一个训练周期或迭代被称为周期。每个周期观测值的顺序都会被打乱。训练周期的三个步骤如下:

  1. 将输入值的前向传播推进到特定周期。

  2. 输出错误的计算和反向传播。

  3. 评估收敛标准,如果满足标准则退出。

训练过程中网络权重的计算可以使用每个层的标签数据和实际输出的差异。但这种方法不可行,因为隐藏层的输出实际上是未知的。解决方案是将输出值(预测值)上的错误反向传播到输入层,通过隐藏层,如果定义了错误。

训练周期或训练周期的三个步骤总结如下图所示:

训练周期

MLP 训练的迭代实现

让我们在 MLPNetwork 类的 trainEpoch 方法中使用简单的 Scala 高阶函数 foreach 应用训练周期的三个步骤,如下所示:

def trainEpoch(x: DblArray, y: DblArray): Double = {
  layers.head.setInput(x)  //22
  connections.foreach( _.connectionForwardPropagation) //23

  val err = mode.error(y, layers.last.output) 
  val bckIterator = connections.reverseIterator 

  var delta = Delta(zipToArray(y, layers.last.output)(diff)) //24
  bckIterator.foreach( iter => 
     delta = iter.connectionBackpropagation(delta))  //25
  err  //26
}

你当然可以识别出训练周期的前两个阶段:输入的前向传播和单个周期的在线训练错误的反向传播。

网络在一个周期 trainEpoch 中的训练执行,初始化输入层为观察值 x(第 22 行)。通过调用每个连接的 connectionForwardPropagation 方法(第 23 行),将输入值通过网络传播。delta 错误从输出层和期望值 y(第 24 行)的值初始化。

训练方法通过在反向迭代器 bckIterator 上调用 connectionBackpropagation 方法反向遍历连接,以通过每个连接传播错误。最后,根据操作模式(第 26 行),训练方法返回累积错误、均方误差或交叉熵。

这种方法与隐藏马尔可夫模型中的 beta(或反向)传递没有太大区别,这在第七章 Beta – 反向传递 节中有所介绍,该节在 序列数据模型 中。

让我们看看每种类型连接的前向和反向传播算法的实现:

  • 一个输入层或隐藏层到隐藏层

  • 一个隐藏层到输出层

步骤 1 – 输入前向传播

如前所述,隐藏层的输出值是计算为权重 w[ij] 和输入值 x[i] 的点积的 sigmoid 或双曲正切。

在以下图中,MLP 算法计算隐藏层的权重 w[ij] 和输入 x[i] 的线性乘积。然后,该乘积通过激活函数 σ(sigmoid 或双曲正切)进行处理。输出值 z[j] 然后与输出层没有激活函数的权重 v[ij] 结合:

步骤 1 – 输入前向传播

MLP 隐藏层和输出层中权重的分布

神经元 j 输出的数学公式定义为激活函数和权重 w[ij] 与输入值 x[i] 的点积的组合。

注意

M3:从先前隐藏层的输出值 z[j] 和权重 vkj 计算输出层的计算(或预测)定义为:

步骤 1 – 输入前向传播

M4:对于具有激活函数 σ 的二分类输出值的估计定义为:

步骤 1 – 输入前向传播

如网络架构部分所示,对于具有两个以上类别的多项式(或多类)分类,输出值使用指数函数进行归一化,如以下 Softmax 部分所述。

计算流程

从输入 x 计算输出值 y 的计算称为输入前向传播。为了简化,我们用以下框图表示层之间的前向传播:

计算流程

输入前向传播的计算模型

前面的图方便地说明了输入前向传播的计算模型,如图中源层和目标层之间的程序关系及其连通性。输入 x 通过每个连接向前传播。

connectionForwardPropagation 方法计算权重和输入值的点积,并应用激活函数,在隐藏层的情况下,对每个连接应用。因此,它是 MLPConnection 类的成员。

整个网络中输入值的正向传播由 MLP 算法本身管理。输入值的正向传播用于分类或预测 y = f(x)。它取决于需要通过训练估计的值权重 w[ij]v[ij]。正如你可能猜到的,权重定义了类似于回归模型的神经网络模型。让我们看看 MLPConnection 类的 connectionForwardPropagation 方法:

def connectionForwardPropagation: Unit = {
  val _output = synapses.map(x => {
    val dot = inner(src.output, x.map(_._1) ) //27
    dst.activate(dot)  //28
  })
  dst.setOutput(_output) //29
}

第一步是计算当前源层对此连接的输出 _output 和突触(权重)(行 27)的线性内积(或点积)(参考第三章中 "Scala 中的时间序列" 部分,数据预处理)。通过应用目标层的 activate 方法到点积(行 28)来计算激活函数。最后,使用计算出的值 _output 来初始化目标层的输出(行 29)。

错误函数

问题类型(模式) 部分所述,有两种方法来计算输出值上的误差或损失:

  • 预期输出值和预测输出值之间平方误差之和,如 M5 数学表达式所定义

  • M6M7 数学公式中描述的预期值和预测值的交叉熵

注意

M5:预测值 ~y 和期望值 y 的平方误差和均方误差定义为:

误差函数

M6:对于单个输出值 y 的交叉熵定义为:

误差函数

M7:多变量输出向量 y 的交叉熵定义为:

误差函数

平方误差和均方误差函数已在第三章的时间序列在 Scala部分中描述,数据预处理

XTSeries对象的单变量crossEntropy方法实现如下:

def crossEntropy(x: Double, y: Double): Double = 
  -(x*Math.log(y) + (1.0 - x)*Math.log(1.0 - y))

多变量特征作为签名的交叉熵计算与单变量情况类似:

def crossEntropy(xt: DblArray, yt: DblArray): Double = 
  yt.zip(xt).aggregate(0.0)({ case (s, (y, x)) => 
    s - y*Math.log(x)}, _ + _)

操作模式

网络架构部分,你了解到输出层的结构取决于需要解决的问题的类型,也称为操作模式。让我们将不同的操作模式(二项式、多项式分类和回归)封装成一个类层次结构,实现MLPMode特质。MLPMode特质有两个针对问题类型特定的方法:

  • apply:这是应用于输出值的转换

  • error:这是整个观察集累积误差的计算

代码如下:

trait MLPMode { 
  def apply(output: DblArray): DblArray   //30
  def error(labels: DblArray, output: DblArray): Double = 
    mse(labels, output)  //31
}

apply方法将转换应用于输出层,如操作模式表(行30)的最后一列所述。error函数计算输出层所有观察值的累积误差或损失,如操作模式表(行31)的第一列所述。

MLPBinClassifier二项式(双类)分类器的输出层中的转换包括将sigmoid函数应用于每个output值(行32)。累积误差计算为期望输出、标签和预测输出的交叉熵(行33):

class MLPBinClassifier extends MLPMode {   
  override def apply(output: DblArray): DblArray = 
    output.map(sigmoid(_))  //32
  override def error(labels: DblArray,  
      output: DblArray): Double = 
    crossEntropy(labels.head, output.head)  //33
}

多层感知器的回归模式根据问题类型(模式)部分的操作模式表定义:

class MLPRegression extends MLPMode  {
  override def apply(output: DblArray): DblArray = output
}

多项式分类器模式由MLPMultiClassifier类定义。它使用softmax方法提升output的最高值,如下面的代码所示:

class MLPMultiClassifier extends MLPMode {
  override def apply(output: DblArray):DblArray = softmax(output)
}

softmax方法应用于实际的output值,而不是偏差。因此,在应用softmax归一化之前,必须删除第一个节点 y(0) = +1

Softmax

在具有 K 个类别(K > 2)的分类问题中,输出必须转换为概率 [0, 1]。对于需要大量类别的难题,需要提升具有最高值(或概率)的输出 y[k]。这个过程被称为指数归一化或 softmax [9:8]。

注意

M8:多项式(K > 2)分类的 softmax 公式如下:

Softmax

这是MLPMultiClassifier类的softmax方法的简单实现:

def softmax(y: DblArray): DblArray = {
  val softmaxValues = new DblArray(y.size)
  val expY = y.map( Math.exp(_))  //34
  val expYSum = expY.sum  //35

  expY.map( _ /expYSum).copyToArray(softmaxValues, 1) //36
  softmaxValues
}

softmax方法实现了M8数学表达式。首先,该方法计算输出值的expY指数值(第34行)。然后,通过它们的总和expYSum(第35行)对指数变换后的输出进行归一化,生成softmaxValues输出数组(第36行)。再次强调,无需更新偏置元素y(0)

训练阶段的第二步是定义并初始化要反向传播到各层的 delta 误差值矩阵,从输出层反向传播到输入层。

第 2 步 – 错误反向传播

错误反向传播是一种算法,用于估计隐藏层的误差,以便计算网络权重的变化。它以输出平方误差的总和作为输入。

注意

计算累积误差的约定

一些作者将反向传播称为 MLP 的训练方法,它将梯度下降应用于定义为平方误差总和或多项式分类或回归的平均平方误差的输出误差。在本章中,我们保持反向传播的较窄定义,即平方误差总和的反向计算。

权重调整

通过计算误差的导数之和,并乘以学习因子来调整连接权重∆v∆w。然后,使用权重的梯度来计算源层输出的误差[9:9]。

更新权重的最简单算法是梯度下降[9:10]。批量梯度下降在第一章的“Let's kick the tires”中介绍,入门

梯度下降是一个非常简单且稳健的算法。然而,它在收敛到全局最小值方面可能比共轭梯度法或拟牛顿法慢(请参阅附录 A 中的优化技术总结部分,基本概念)。

有几种方法可以加快梯度下降收敛到最小值,例如动量因子和自适应学习系数[9:11]。

训练过程中权重的较大变化会增加模型(连接权重)收敛所需的 epoch 数。这在称为在线训练的训练策略中尤其如此。训练策略将在下一节中讨论。本章剩余部分使用动量因子α。

注意

M9:学习率

使用梯度下降计算神经网络权重的方法如下:

权重调整

M10:学习率和动量因子

使用动量系数 α 的梯度下降法计算神经网络权重如下:

权重调整

通过将动量因子 α 设置为零,在通用(M10)数学表达式中选择了梯度下降算法(M9)的最简单版本。

误差传播

感知器训练的目标是使所有输入观察到的损失或累积误差最小化,无论是输出层的平方误差之和还是交叉熵。对于每个输出神经元 y[k] 的误差 ε[k] 是预测输出值与标签输出值之间的差异。由于那些层的标签值未知,因此无法计算隐藏层 z[j] 的输出值误差:

误差传播

反向传播算法的示意图

在平方误差和的情况下,计算累积误差对输出层每个权重的偏导数,作为平方函数的导数和权重与输入 z 的点积的导数的组合。

如前所述,计算误差对隐藏层权重的偏导数有点棘手。幸运的是,偏导数的数学表达式可以写成三个偏导数的乘积:

  • 输出值 y[k] 上累积误差 ε 的导数

  • 输出值 yk 对隐藏值 z[j] 的导数,已知 sigmoid σ 的导数是 σ(1 - σ)

  • 隐藏层输出 z[j] 对权重 w[ij] 的导数

偏导数的分解产生了以下公式,通过传播误差(或损失)ε 来更新输出和隐藏神经元的突触权重。

注意

输出权重调整

M11:对于具有预测值 ~y 和期望值 y 以及隐藏层输出 z 的输出层,计算 delta δ 和权重调整 ∆v 如下:

误差传播

隐藏层权重调整

M12:对于具有预测值 ~y 和期望值 y、隐藏层输出 z 和输入值 x 的隐藏层,计算 delta δ 和权重调整 ∆w 如下:

误差传播

矩阵 δ[ij]Delta 类中的 delta 矩阵定义。它包含在层之间传递的基本参数,从输出层遍历网络到输入层。参数如下:

  • 在输出层计算出的初始 loss 或误差

  • 当前连接的 delta 值矩阵

  • 下游连接(或目标层与下一层之间的连接)的权重或 synapses

代码如下:

case class Delta(val loss: DblArray, 
  val delta: DblMatrix = Array.empty[DblArray],
  val synapses: MLPConnSynapses = Array.empty[Array[MLPSynapse]] )

使用期望值y为输出层生成Delta类的第一个实例,然后在MLPNetwork.trainEpoch方法中传播到前一个隐藏层(行24):

val diff = (x: Double, y: Double) => x - y
Delta(zipToArray(y, layers.last.output)(diff))

M11数学表达式由MLPOutLayer类的delta方法实现:

def delta(error: DblArray, srcOut: DblArray, 
     synapses: MLPConnSynapses): Delta = {

  val deltaMatrix = new ArrayBuffer[DblArray] //34
  val deltaValues = error./:(deltaMatrix)( (m, l) => {
    m.append( srcOut.map( _*l) )
    m
  })   //35
  new Delta(error, deltaValues.toArray, synapses)  //36
}

该方法生成了与输出层相关的 delta 值矩阵(行34)。M11公式实际上是通过折叠srcOut输出值(行35)来实现的。新的 delta 实例被返回到MLPNetworktrainEpoch方法,并反向传播到前一个隐藏层(行36)。

MLPLayer类的delta方法实现了M12数学表达式:

def delta(oldDelta: DblArray, srcOut: DblArray, 
     synapses: MLPConnSynapses): Delta = {

  val deltaMatrix = new ArrayBuffer[(Double, DblArray)]
  val weights = synapses.map(_.map(_._1))
       .transpose.drop(1) //37

  val deltaValues = output.drop(1)
   .zipWithIndex./:(deltaMatrix){  // 38
     case (m, (zh, n)) => {
       val newDelta = inner(oldDelta, weights(n))*zh*(1.0 - zh)
       m.append((newDelta, srcOut.map( _ * newdelta) )
       m
     } 
  }.unzip
  new Delta(deltaValues._1.toArray, deltaValues._2.toArray)//39
}

delta方法的实现与MLPOutLayer.delta方法类似。它通过转置从输出层提取权重v(行37)。通过应用M12公式(行38)计算隐藏连接中 delta 矩阵的值。新的 delta 实例被返回到trainEpoch方法(行39),以便传播到前一个隐藏层(如果存在)。

计算模型

错误反向传播算法的计算模型与输入的前向传播非常相似。主要区别在于δ(delta)的传播是从输出层到输入层进行的。以下图表展示了具有两个隐藏层z[s]z[t]的反向传播的计算模型:

计算模型

delta 误差反向传播的示意图

connectionBackPropagation方法将错误从输出层或隐藏层之一反向传播到前一层。它是MLPConnection类的一个成员。整个网络中输出错误的反向传播由MLP类管理。

它实现了两套方程,其中synapses(j)(i)._1是权重w[ji]dst.delta是目标层中误差导数的向量,src.delta是源层输出的误差导数,如此处所示:

def connectionBackpropagation(delta: Delta): Delta = {  //40
  val inSynapses =  //41
    if( delta.synapses.length > 0) delta.synapses 
    else synapses 

  val delta = dst.delta(delta.loss, src.output,inSynapses) //42
  synapse = synapses.zipWithIndex.map{ //43
    case (synapsesj, j) => synapsesj.zipWithIndex.map{
      case ((w, dw), i) => { 
        val ndw = config.eta*connectionDelta.delta(j)(i)
        (w + ndw - config.alpha*dw, ndw)
      } 
    }
  }
  new Delta(connectionDelta.loss, 
       connectionDelta.delta, synapses)
}

connectionBackPropagation方法将关联到目标(输出)层的delta作为参数(行40)。输出层是网络的最后一层,因此,后续连接的突触被定义为长度为零的空矩阵(行41)。该方法使用源层的输出src.outputdelta.loss错误计算隐藏层的新delta矩阵(行42)。权重(突触)使用具有动量因子的梯度下降更新,如M10数学表达式(行43)。

备注

可调学习率

通过调整学习率,可以进一步改进每个新纪元连接新权重的计算。

第 3 步 – 退出条件

收敛标准包括评估与操作模式(或问题)相关的累积误差(或损失)与预定义的eps收敛值。累积误差使用平方和误差公式(M5)或交叉熵公式(M6M7)计算。另一种方法是计算两个连续时代之间累积误差的差异,并将eps收敛标准作为退出条件。

将所有这些放在一起

MLP类被定义为使用从训练集xt隐式生成的模型的数据转换,ITransform类型,如第二章中单调数据转换部分所述,Hello World!(第44行)。

MLP 算法采用以下参数:

  • config: 这是算法的配置

  • hidden: 如果有的话,这是隐藏层的大小数组

  • xt: 这是用于训练模型的特征时间序列

  • expected: 这是用于训练目的的标记输出值

  • mode: 这是算法的隐式操作模式或目标

  • f: 这是将特征从类型T转换为Double的隐式转换

预测或分类方法|>的隐式转换的输出V类型为DblArray(第45行):

class MLPT <: AnyVal
    (implicit mode: MLPMode, f: T => Double) 
  extends ITransform[Array[T]](xt) with Monitor[Double] {  //44

  type V = DblArray  //45

  lazy val topology = if(hidden.length ==0) 
    ArrayInt 
  else  ArrayInt ++ hidden ++ 
         ArrayInt  //46

  val model: Option[MLPModel] = train
  def train: Option[MLPModel]   //47
  override def |> : PartialFunction[Array[T], Try[V]] 
}

拓扑是从xt输入变量、expected值以及如果有,隐藏层的配置创建的(第46行)。以下图表展示了从MLPNetwork类的参数生成拓扑的过程:

将所有这些放在一起

多层感知器的拓扑编码

例如,具有三个输入变量、一个输出变量和两个各含三个神经元的隐藏层的神经网络topology被指定为ArrayInt。模型通过调用train方法(第47行)进行训练生成。最后,根据所选的操作模式(第48行),使用ITransform特质的|>运算符进行分类、预测或回归。

训练和分类

一旦定义了训练周期或时代,就仅仅是定义和实施一个策略,使用一系列数据或时间序列来创建模型。

正则化

对于给定的分类或回归问题,有两种方法来找到最合适的网络架构,如下所述:

  • 毁坏性调整:从一个大的网络开始,然后移除对平方误差和没有影响的节点、突触和隐藏层

  • 构造性调整:从一个小的网络开始,然后逐步添加节点、突触和隐藏层,以减少输出误差

毁坏性调整策略通过将权重置零来移除突触。这通常通过正则化来完成。

你已经看到,正则化是解决线性回归和逻辑回归中过拟合的有力技术,这在第六章的岭回归部分中有所描述,回归和正则化。神经网络可以从向平方误差和添加正则化项中受益。正则化因子越大,某些权重被减少到零的可能性就越大,从而减小网络的规模[9:13]。

模型生成

在多层感知器的实例化过程中创建了(训练了)MLPModel实例。构造函数遍历所有xt时间序列数据点的训练周期(或 epoch),直到累积值小于eps收敛标准,如下面的代码所示:

def train: Option[MLPModel] = {
  val network = new MLPNetwork(config, topology) //48
  val zi =  xt.toVector.zip(expected.view)   // 49

  Range(0, config.numEpochs).find( n => {  //50
    val cumulErr = fisherYates(xt.size)
       .map(zi(_))
       .map{ case(x, e) => network.trainEpoch(x, e)}
       .sum/st.size   //51
     cumulErr  < config.eps  //52
  }).map(_ => network.getModel)
}

train方法使用配置和topology作为输入实例化一个 MLP 网络(第 48 行)。该方法执行多个 epoch,直到动量梯度下降收敛或达到允许的最大迭代次数(第 50 行)。在每个 epoch 中,该方法使用 Fisher-Yates 算法对输入值和标签进行洗牌,调用MLPNetwork.trainEpoch方法,并计算cumulErr累积误差(第 51 行)。这个特定的实现将累积误差的值与eps收敛标准进行比较,作为退出条件(第 52 行)。

注意

MLP 的尾递归训练

多层感知器的训练被实现为一个迭代过程。它可以很容易地用权重和累积误差作为递归参数的尾递归来替换。

使用懒视图来减少不必要的对象创建(第 49 行)。

注意

退出条件

在这个实现中,如果模型在达到最大 epoch 数之前没有收敛,则将其初始化为None。另一种选择是在非收敛的情况下生成模型,并将准确度指标添加到模型中,就像我们在支持向量机(SVM)的实现中那样(参见第八章下的训练部分,支持向量分类器 – SVC),核模型和支持向量机)。

一旦在多层感知器的实例化过程中创建了模型,它就可以用于预测或分类新观察值的类别。

快速 Fisher-Yates 洗牌

在第一章的让我们试试看部分下的第 5 步 – 实现分类器入门中描述了一个自制的洗牌算法,作为 Scala 标准库中scala.util.Random.shuffle方法的替代。本节描述了一种称为 Fisher-Yates 洗牌算法的替代洗牌机制:

def fisherYates(n: Int): IndexedSeq[Int] = {

   def fisherYates(seq: Seq[Int]): IndexedSeq[Int] = {
     Random.setSeed(System.currentTimeMillis)
    (0 until seq.size).map(i => {
       var randomIdx: Int = i + Random.nextInt(seq.size-i) //53
       seq(randomIdx) ^= seq(i)    //54
       seq(i) = seq(randomIdx) ^ seq(i) 
       seq(randomIdx) ^= (seq(i))
       seq(i)
    })
  }

  if( n <= 0)  Array.empty[Int]
  else 
    fisherYates(ArrayBuffer.tabulate(n)(n => n)) //55
}

鱼丁混合洗牌算法创建一个有序的整数序列(第 55 行),并将每个整数与从初始序列剩余部分随机选择的另一个整数交换(第 52 行)。这种实现特别快速,因为整数是通过位操作就地交换的,也称为位交换(第 54 行)。

注意

鱼丁混合洗牌的尾递归实现

鱼丁混合洗牌算法可以使用尾递归而不是迭代来实现。

预测

|> 数据转换实现了运行时分类/预测。如果模型成功训练,则返回归一化为概率的预测值,否则返回 None。这些方法调用 MLPNetwork 的前向预测函数(第 53 行):

override def |> : PartialFunction[Array[T],Try[V]] ={
  case x: Array[T] if(isModel && x.size == dimension(xt)) => 
   Try(MLPNetwork(config, topology, model).predict(x)) //56
}

MLPNetworkpredict 方法通过以下方式使用前向传播从输入 x 计算输出值:

def predict(x: DblArray): DblArray = {
   layers.head.set(x)
   connections.foreach( _.connectionForwardPropagation)
   layers.last.output
}

模型适应性

模型的适应性衡量模型拟合训练集的程度。具有高度适应性的模型可能会过度拟合。fit 适应性方法计算预测值与训练集标签(或预期值)的均方误差。该方法使用高级 count 方法返回预测值正确的观测值的百分比:

def fit(threshold: Double): Option[Double] = model.map(m => 
  xt.map( MLPNetwork(config, topology, Some(m)).predict(_) )
    .zip(expected)
    .count{case (y, e) =>mse(y, e.map(_.toDouble))< threshold }
    /xt.size.toDouble
)

注意

模型适应性对比准确度

模型相对于训练集的适应性反映了模型拟合训练集的程度。适应性的计算不涉及验证集。质量参数,如准确度、精确度或召回率,衡量模型相对于验证集的可靠性和质量。

我们的 MLP 类现在已准备好应对一些分类挑战。

评估

在将我们的多层感知器应用于理解货币市场交易所的波动之前,让我们熟悉一下第一部分中介绍的一些关键学习参数。

执行配置文件

让我们看看多层感知器训练的收敛性。监控特性(参考附录 A 下实用类中的监控部分,基本概念),收集并显示一些执行参数。我们选择使用连续两个时期(或纪元)之间反向传播错误的差异来提取多层感知器收敛的配置文件。

测试使用学习率 η = 0.03 和动量因子 α = 0.3 对具有两个输入值的多层感知器进行收敛性测试:一个包含三个节点的隐藏层和一个输出值。测试依赖于合成的随机值:

执行配置文件

MLP 累积错误的执行配置文件

学习率的影响

第一项练习的目的是评估学习率η对训练时期收敛的影响,这是通过所有输出变量的累积误差来衡量的。观察xt(相对于标记的输出yt)是通过使用几个噪声模式如f1(第 57 行)和f2函数(第 58 行)合成的,如下所示:

def f1(x: Double): DblArray = ArrayDouble 
def f2(x: Double): DblArray =  ArrayDouble

val HALF_TEST_SIZE = (TEST_SIZE>>1)
val xt = Vector.tabulate(TEST_SIZE)(n =>   //59
  if( n <HALF_TEST_SIZE) f1(n) else f2(n -HALF_TEST_SIZE))
val yt = Vector.tabulate(TEST_SIZE)(n => 
  if( n < HALF_TEST_SIZE) ArrayDouble 
  else ArrayDouble )  //60

输入值xt是通过f1函数为数据集的一半合成的,而另一半是通过f2函数合成的(第 59 行)。期望值的数据生成器yt将使用f1函数生成的输入值分配标签 0.0,而将使用f2函数创建的输入值分配标签 1.0(第 60 行)。

测试在TEST_SIZE个数据点的样本上运行,最大NUM_EPOCHS个时期,一个没有softmax变换的单个隐藏层HIDDENS.head神经元,以及以下 MLP 参数:

val ALPHA = 0.3
val ETA = 0.03
val HIDDEN = ArrayInt
val NUM_EPOCHS = 200
val TEST_SIZE = 12000
val EPS = 1e-7

def testEta(eta: Double, 
    xt: XVSeries[Double], 
    yt: XVSeries[Double]): 
    Option[(ArrayBuffer[Double], String)] = {

  implicit val mode = new MLPBinClassifier //61
  val config = MLPConfig(ALPHA, eta, NUM_EPOCHS, EPS)
  MLPDouble  
    .counters("err").map( (_, s"eta=$eta")) //62
}

testEta方法根据不同的eta值生成轮廓或误差。

MLP类的实例化之前,必须隐式定义操作mode(第 61 行)。它被设置为MLPBinClassifier类型的二项式分类器。执行配置文件数据是通过Monitor特征的counters方法收集的(第 62 行)(请参阅附录 A 下实用类部分的Monitor部分,基本概念)。

评估学习率对多层感知器收敛影响的驱动代码相当简单:

val etaValues = ListDouble
val data = etaValues.flatMap( testEta(_, xt, yt))
    .map{ case(x, s) => (x.toVector, s) }

val legend = new Legend("Err", 
   "MLP [2-3-1] training - learning rate", "Epochs", "Error")
LinePlot.display(data, legend, new LightPlotTheme)

轮廓是用 JFreeChart 库创建的,并在以下图表中显示:

学习率的影响

学习率对 MLP 训练的影响

图表表明,MLP 模型训练在较大的学习率值下收敛得更快。然而,你需要记住,非常陡峭的学习率可能会将训练过程锁定在累积误差的局部最小值,生成精度较低的权重。相同的配置参数用于评估动量因子对梯度下降算法收敛的影响。

动量因子的影响

让我们量化动量因子α对训练过程收敛到最优模型(突触权重)的影响。测试代码与评估学习率影响的方法非常相似。

整个时间序列的累积误差在以下图表中绘制:

动量因子的影响

动量因子对 MLP 训练的影响

前面的图表显示,随着动量因子的增加,均方误差的速率降低。换句话说,动量因子对梯度下降的收敛有积极但有限的影响。

隐藏层数量的影响

让我们考虑一个具有两个隐藏层(7 个和 3 个神经元)的多层感知器。训练的执行配置文件显示,在几个 epoch 之后,输出累积误差突然收敛,因为下降梯度未能找到方向:

隐藏层数量的影响

具有两个隐藏层的 MLP 训练的执行配置文件

让我们应用我们新学到的关于神经网络和影响某种货币汇率变量分类的知识。

测试用例

神经网络已被用于金融应用,从抵押贷款申请中的风险管理到商品定价的对冲策略,再到金融市场预测建模[9:14]。

测试用例的目标是理解某些货币汇率、黄金现货价格和标准普尔 500 指数之间的相关性因素。为此练习,我们将使用以下交易所交易基金ETFs)作为货币汇率的代理:

  • FXA:这是澳大利亚元对美元的汇率

  • FXB:这是英镑对美元的汇率

  • FXE:这是欧元对美元的汇率

  • FXC:这是加拿大元对美元的汇率

  • FXF:这是瑞士法郎对美元的汇率

  • FXY:这是日元对美元的汇率

  • CYB:这是人民币对美元的汇率

  • SPY:这是标准普尔 500 指数

  • GLD:这是黄金的美元价格

实际上,要解决的问题是从一个或多个回归模型中提取一个,该模型将一个 ETF 的y与一组其他 ETF 的y={x[i]} y=f(x[i])相关联。例如,日元汇率(FXY)与黄金现货价格(GLD)、欧元对美元汇率(FXE)、澳大利亚元对美元汇率(FXA)等的组合之间是否存在关系?如果是这样,回归f将被定义为FXY = f (GLD, FXE, FXA)

下面的两个图表显示了两种货币在两年半期间的波动情况。第一个图表显示了一组可能相关的 ETF:

测试用例

一个相关货币 ETF 的例子

第二个图表显示了一组具有类似价格行为的相关货币 ETF。神经网络不提供其内部推理的任何分析表示;因此,对新手工程师来说,视觉相关性可以非常有助于验证他们的模型:

测试用例

一个相关货币 ETF 的例子

找到货币汇率变动与黄金现货价格之间任何相关性的一个非常简单的方法是选择一个作为目标的股票代码,以及一组基于其他货币的 ETF 子集作为特征。

让我们考虑以下问题:寻找 FXE 价格与一系列货币 FXB、CYB、FXA 和 FXC 之间的相关性,如下面的图所示:

测试案例

从股票代码生成特征机制

实现

第一步是定义 MLP 分类器的配置参数,如下所示:

val path = "resources/data/chap9/"
val ALPHA = 0.8; 
val ETA = 0.01
val NUM_EPOCHS = 250
val EPS = 1e-3
val THRESHOLD = 0.12
val hiddens = ArrayInt //59

除了学习参数外,网络还初始化了多个拓扑配置(行 59)。

接下来,让我们创建分析中使用的所有 ETF 价格的搜索空间:

val symbols = ArrayString
val STUDIES = List[Array[String]](   //60
  ArrayString,
  ArrayString,
  ArrayString,
  ArrayString,
  ArrayString,
  symbols
)

测试的目的是评估和比较七个不同的投资组合或研究(行 60)。使用 GoogleFinancials 提取器从 Google 财务表中提取了所有 ETF 在 3 年期间的开盘价(行 61):

val prices = symbols.map(s => DataSource(s"$path$s.csv"))
      .flatMap(_.get(close).toOption) //61

下一步是实施从上一段中介绍的 ETF 篮子中提取目标和特征的机制。让我们考虑以下研究作为 ETF 股票代码列表:

val study = ArrayString

研究的第一个元素 FXE 是标记的输出;其余三个元素是观察到的特征。对于这项研究,网络架构有三个输入变量(FXFFXBCYB)和一个输出变量 FXE

val obs = symbols.flatMap(index.get(_))
              .map( prices(_).toArray )  //62
val xv = obs.drop(1).transpose  //63
val expected = ArrayDblArray.transpose //64

使用索引(行 62)构建了观察集 obs。按照惯例,第一个观察值被选为标签数据,其余的研究作为训练特征。由于观察值作为时间序列数组加载,因此使用 transpose(行 63)计算序列的时间特征。单个 target 输出变量在转置之前必须转换为矩阵(行 64)。

最终,通过实例化 MLP 类来构建模型:

implicit val mode = new MLPBinClassifier  //65
val classifier = MLPDouble
classifier.fit(THRESHOLD)

目标或操作 模式 隐式定义为 MLP 二元分类器,MLPBinClassifier(行 65)。MLP.fit 方法在 训练和分类 部分定义。

模型评估

测试包括评估六个不同的模型,以确定哪些模型提供最可靠的相关性。确保结果在一定程度上独立于神经网络架构至关重要。不同的架构作为测试的一部分进行评估。

下面的图表比较了两种架构的模型:

  • 每个隐藏层有四个节点的两个隐藏层

  • 有八个节点(相对于五个和六个)的三个隐藏层

第一张图表展示了由可变数量的输入(2,7)组成的六个回归模型的适应性:一个输出变量和两个每个有四个节点的隐藏层。特征(ETF 符号)列在箭头 => 沿 y 轴的左侧。箭头右侧的符号是预期的输出值:

模型评估

MLP 每个隐藏层有四个节点的准确度

下图显示了具有三个隐藏层(分别为八个、五个和六个节点)的架构的六个回归模型的适应性:

模型评估

具有八个、五个和六个节点的三个隐藏层的 MLP 准确率

这两种网络架构有很多相似之处;在两种情况下,最合适的回归模型如下:

  • FXE = f (FXA, SPY, GLD, FXB, FXF, FXD, FXY, CYB)

  • FXE = g (FXC, GLD, FXA, FXY, FXB)

  • FXE = h (FXF, FXB, CYB)

另一方面,使用日元(FXY)和澳大利亚元(FXA)的汇率预测加拿大元对美元(FXC)的汇率(FXC)的预测在这两种配置下都很差。

注意

经验评估

这些经验测试使用了一个简单的准确率指标。回归模型的形式比较将系统地分析每个输入和输出变量的所有组合。评估还将计算每个模型的精确度、召回率和 F1 分数(请参阅第二章中“评估模型”部分下的“验证”部分下的关键质量指标)。

隐藏层架构的影响

下一个测试包括评估配置的隐藏层对三个模型准确率的影响:FXF, FXB, CYB => FXEFCX, GLD, FXA => FXY,和FXC, GLD, FXA, FXY, FXB => FXE。为了方便起见,这个测试通过选择训练数据的一个子集作为测试样本来计算准确率。测试的目标是使用某些指标比较不同的网络架构,而不是估计每个模型的绝对准确率。

四种网络配置如下:

  • 具有四个节点的单个隐藏层

  • 每个隐藏层有四个节点的两个隐藏层

  • 每个隐藏层有七个节点的两个隐藏层

  • 具有八个、五个和六个节点的三个隐藏层

让我们看一下以下图表:

隐藏层架构的影响

隐藏层架构对 MLP 准确率的影响

具有两个或更多隐藏层的复杂神经网络架构生成的权重具有相似的准确率。四节点单隐藏层架构生成的准确率最高。使用正式的交叉验证技术计算准确率将生成一个较低的准确率数值。

最后,我们来看看网络复杂性对训练持续时间的影响,如下所示图:

隐藏层架构的影响

隐藏层架构对训练持续时间的影响

毫不奇怪,随着隐藏层和节点数量的增加,时间复杂度显著增加。

卷积神经网络

本节提供了对卷积神经网络的一个简要介绍,没有 Scala 实现。

到目前为止,感知器的层组织成完全连接的网络。很明显,随着隐藏层数量和规模的增加,突触或权重的数量会显著增加。例如,一个具有 6 维特征集的网络,3 个隐藏层,每个层有 64 个节点,以及一个输出值,需要764 + 26564 + 651 = 8833个权重!

如图像或字符识别等应用需要非常大的特征集,这使得训练一个全连接的层感知器非常计算密集。此外,这些应用需要将像素的邻近性等空间信息作为特征向量的一部分进行传递。

一种最近的方法,称为卷积神经网络,其核心思想是限制输入节点连接到隐藏层节点的数量。换句话说,该方法利用空间定位来减少输入层和隐藏层之间的连接复杂性[9:15]。连接到隐藏层中单个神经元的输入节点集合被称为局部感受野

局部感受野

隐藏层的神经元从局部感受野或子图像(n by n像素)中学习,每个像素都是一个输入值。下一个局部感受野,在任何方向上平移一个像素,连接到第一隐藏层中的下一个神经元。第一隐藏层被称为卷积层。输入(图像)与第一隐藏层(卷积层)之间的映射图如下:

局部感受野

从图像生成卷积层

每个局部感受野(n by n)都有一个偏置元素(+1),它连接到隐藏神经元。然而,额外的复杂性并不会导致更精确的模型,因此,偏置在卷积层的神经元之间共享。

权重共享

通过将字段平移一个像素(上、下、左或右),生成表示图像一小部分的局部感受野。因此,与局部感受野相关的权重也在隐藏层中的神经元之间共享。实际上,图像中的许多像素可以检测到相同的特征,如颜色或边缘。输入特征和隐藏层中的神经元之间的映射,称为特征图,在卷积层中共享权重。输出是通过激活函数计算得出的。

注意

双曲正切与 Sigmoid 激活函数的比较

Sigmoid 函数主要在多层感知器相关的例子中作为隐藏层的激活函数。双曲正切函数通常用于卷积网络。

卷积层

从特征图计算得出的输出被表示为一个类似于在基于离散傅里叶变换的滤波器中使用的卷积(参考第三章下傅里叶分析部分的基于 DFT 的滤波部分的M11数学表达式),在隐藏层中计算输出的激活函数必须修改以考虑局部感受野。

注意

卷积神经网络的激活

M13:对于共享偏置w[0],激活函数σn像素的局部感受野,输入值x[ij]和与特征图相关的权重wuv,输出值z[j]由以下公式给出:

卷积层

构建神经网络的下一步将是使用卷积层的输出到一个全连接的隐藏层。然而,卷积层中的特征图通常相似,因此可以使用一个称为下采样层的中间层将它们减少到更小的输出集[9:16]。

下采样层

卷积层中的每个特征图都被减少或压缩成一个更小的特征图。由这些较小的特征图组成的层被称为下采样层。采样的目的是减少权重对相邻像素之间图像的任何微小变化的敏感性。权重的共享减少了图像中任何非显著变化的敏感性:

下采样层

从卷积到下采样层的特征图之间的连接性

下采样层有时被称为池化层

整合所有内容

卷积神经网络的最后一层是全连接的隐藏层和输出层,受到与传统多层感知器相同的转换公式的约束。输出值可以使用线性乘积或softmax函数来计算:

整合所有内容

卷积神经网络概述

步骤 2 - 误差反向传播部分描述的误差反向传播算法必须修改以支持特征图[9:17]。

注意

卷积网络的架构

深度卷积神经网络有多个卷积层和下采样层的序列,并且可能有一个以上的全连接隐藏层。

优点和局限性

神经网络的优缺点取决于它们与哪些其他机器学习方法进行比较。然而,基于神经网络的分类器,特别是使用误差反向传播的多层感知器,有一些明显的优点,如下所示:

  • 神经网络的数学基础不需要动态规划或线性代数的专业知识,除了基本的梯度下降算法。

  • 神经网络可以执行线性算法无法执行的任务。

  • MLP 通常适用于高度动态和非线性过程。与支持向量机不同,它们不需要我们通过核化来增加问题维度。

  • MLP 不对线性、变量独立性或正态性做出任何假设。

  • MLP 的训练执行非常适合在线训练的并发处理。在大多数架构中,算法可以在网络中的某个节点失败的情况下继续进行(参考第十二章中的Apache Spark部分,可扩展框架)。

然而,与任何机器学习算法一样,神经网络也有其批评者。最常记录的限制如下:

  • MLP 模型是黑盒,其中特征与类之间的关联可能难以描述和理解。

  • MLP 需要一个漫长的训练过程,特别是使用批量训练策略。例如,一个两层网络对于n个输入变量、m个隐藏神经元、p个输出值、N个观察值和e个周期的时间复杂度(乘法次数)为O(n.m.p.N.e)。在数千个周期后出现解决方案并不罕见。使用动量因子的在线训练策略往往收敛更快,并且比批量过程需要更少的周期。

  • 调整配置参数,例如优化学习率和动量因子、选择最合适的激活方法以及累积误差公式可能变成一个漫长的过程。

  • 估计生成准确模型所需的最小训练集大小以及限制计算时间并不明显。

  • 神经网络不能增量重新训练。任何新的标记数据都需要执行几个训练周期。

注意

其他类型的神经网络

本章涵盖了多层感知器并介绍了卷积神经网络的概念。还有许多其他类型的神经网络,例如循环网络和混合密度网络。

摘要

这不仅结束了多层感知器内部的旅程,也介绍了监督学习算法的介绍。在本章中,你学习了:

  • 人工神经网络的组件和架构

  • 反向传播多层感知器的训练周期(或时代)的阶段

  • 如何从零开始用 Scala 实现 MLP

  • 创建 MLP 分类或回归模型时,可用的众多配置参数和选项。

  • 为了评估学习率和梯度下降动量因子对训练过程收敛的影响。

  • 如何将多层感知器应用于货币波动的金融分析

  • 卷积神经网络概述

下一章将介绍遗传算法的概念,并使用 Scala 进行完整实现。尽管严格来说,遗传算法不属于机器学习算法家族,但它们在非线性、不可微问题的优化以及集成中的强分类器选择中发挥着至关重要的作用。

第十章. 遗传算法

本章介绍了进化计算的概念。源自进化论的理论算法在解决大型组合或NP 问题方面特别有效。进化计算是由约翰·霍兰德[10:1]和大卫·戈德堡[10:2]开创的。他们的发现应该对任何渴望了解遗传算法GA)和人工生命基础的人感兴趣。

本章涵盖了以下主题:

  • 进化计算的开端

  • 遗传算法的理论基础

  • 遗传算法的优点和局限性

从实际的角度来看,你将学习如何:

  • 应用遗传算法,利用市场价格和成交量变动的技术分析来预测未来回报

  • 评估或估计搜索空间

  • 使用分层或扁平寻址将解决方案编码为二进制格式

  • 调整一些遗传算子

  • 创建和评估适应度函数

进化

进化论理论,由查尔斯·达尔文阐述,描述了生物体的形态适应性[10:3]。

起源

达尔文过程包括优化生物体的形态以适应最恶劣的环境——对鱼类是流体动力学优化,对鸟类是空气动力学优化,对捕食者则是隐匿技能。以下图表显示了一个基因:

起源

生物体的种群随时间变化。种群中个体的数量会发生变化,有时变化非常显著。这些变化通常与捕食者和猎物的丰富程度或缺乏以及环境的变化有关。只有种群中最适应的个体才能通过快速适应生活环境中的突然变化和新约束而生存下来。

NP 问题

NP 代表非确定性多项式时间。NP 问题的概念与计算理论相关,更确切地说,与时间和空间复杂度相关。NP 问题的类别如下:

  • P 问题(或 P 决策问题):对于这些问题,在确定性图灵机(计算机)上的解决方案需要确定性的多项式时间。

  • NP 问题:这些问题可以在非确定性机器上以多项式时间解决。

  • NP 完全问题:这些是 NP 难问题,它们被简化为 NP 问题,其解决方案需要确定性的多项式时间。这类问题可能难以解决,但它们的解决方案可以验证。

  • NP 难问题:这些问题可能有解,但这些解可能无法在多项式时间内找到。![NP 问题]

    使用计算复杂性对 NP 问题的分类

如旅行商问题、楼层商店调度、计算图 K 最小生成树、地图着色或循环排序等问题,其搜索执行时间是非确定性多项式时间,对于包含 n 个元素的种群,时间范围从n!2[n] [10:4]。

由于计算开销,NP 问题不能总是使用分析方法来解决——即使在模型的情况下,它也依赖于可微函数。基因算法由约翰·霍兰德在 20 世纪 70 年代发明,它们从达尔文的进化论理论中获得了属性,以解决 NP 和 NP 完全问题。

进化计算

一个生物体由包含相同染色体的细胞组成。染色体DNA的链,作为整个生物体的模型。染色体由基因组成,这些基因是 DNA 的块,并编码特定的蛋白质。

重组(或交叉)是繁殖的第一个阶段。父母的基因生成整个新的染色体(后代),这些染色体可以发生变异。在变异过程中,一个或多个元素,也称为 DNA 链或染色体的个体碱基,发生改变。这些变化主要是由于父母基因传递给后代时发生的错误所引起的。一个生物在其生命中的成功衡量其适应性[10:5]。

基因算法通过繁殖来进化一个可能解决方案的种群,以解决一个问题。

基因算法与机器学习

基因算法作为一种优化技术,其实际目的是通过在一系列或一组解决方案中找到最相关或最适应的解决方案来解决问题。基因算法在机器学习中有许多应用,具体如下:

  • 离散模型参数:基因算法在寻找最大化对数似然度的离散参数集合方面特别有效。例如,黑白电影的着色依赖于从灰色阴影到 RGB 色彩方案的有限但庞大的转换集。搜索空间由不同的转换组成,目标函数是着色电影的质量。

  • 强化学习:选择最合适的规则或策略以匹配给定数据集的系统依赖于基因算法随着时间的推移来进化规则集。搜索空间或种群是候选规则的集合,目标函数是由这些规则触发的动作的信用或奖励(参见第十一章[part0220.xhtml#aid-6HPRO2 "第十一章。强化学习"],强化学习)。

  • 神经网络架构:遗传算法驱动对网络不同配置的评估。搜索空间包括隐藏层的不同组合以及这些层的大小。适应度或目标函数是平方误差的总和。

  • 集成学习 [10:6]:遗传算法可以从一组分类器中剔除弱学习者,以提高预测质量。

遗传算法组件

遗传算法有以下三个组件:

  • 遗传编码和解码):这是将解决方案候选及其组件转换为二进制格式(位或字符串01的数组)

  • 遗传操作:这是应用一组算子以提取最佳(最适应遗传的)候选者(染色体)

  • 遗传适应度函数:这是使用目标函数评估最合适的候选者

编码和适应度函数是问题相关的。遗传算子不是。

编码

让我们考虑机器学习中的优化问题,该问题包括最大化对数似然或最小化损失函数。目标是计算最小化或最大化函数 f(w) 的参数或权重,w={w[i]}。在非线性模型的情况下,变量可能依赖于其他变量,这使得优化问题特别具有挑战性。

值编码

遗传算法将变量作为位或位字符串操作。将变量转换为位字符串的过程称为编码。在变量是连续的情况下,转换称为量化离散化。每种类型的变量都有独特的编码方案,如下所示:

  • 布尔值可以用 1 位轻松编码:0 表示假,1 表示真。

  • 连续变量以类似于将模拟信号转换为数字信号的方式量化或离散化。让我们考虑一个函数,它在值范围内有最大值 max(类似地,min 表示最小值),使用 n = 16 位进行编码:值编码

    连续变量 y = f(x) 量化的示意图

离散化的步长计算如下(M1):

值编码

在 16 位中,正弦 y = sin(x) 量化的步长为 1.524e-5。

离散或分类变量编码为位更具有挑战性。至少,所有离散值都必须被考虑。然而,没有保证变量的数量将与位边界一致:

值编码

为值的基础 2 表示形式添加填充

在这种情况下,下一个指数,n+1,定义了表示值集所需的最小位数:n = log2(m).toInt + 1。具有 19 个值的离散变量需要 5 位。剩余的位根据问题设置为任意值(0,NaN 等)。这个过程被称为填充

编码既是艺术也是科学。对于每个编码函数,你需要一个解码函数来将位表示转换回实际值。

谓词编码

变量 x 的谓词是一个定义为 x 操作符 [目标] 的关系;例如,单位成本 < [9$]温度 = [82F],或 电影评分是 [3 星]

谓词的最简单编码方案如下:

  • 变量被编码为类别或类型(例如,温度,气压等),因为任何模型中都有有限数量的变量

  • 操作符被编码为离散类型

  • 被编码为离散或连续值

注意

谓词编码格式

在位串中对谓词进行编码有许多方法。例如,格式 {操作符,左操作数,右操作数} 是有用的,因为它允许你编码一个二叉树。整个规则,IF 谓词 THEN 行动,可以用表示为离散或分类值的行动来编码。

解决编码

解决编码方法将问题的解描述为谓词的无序序列。让我们考虑以下规则:

IF {Gold price rises to [1316$/ounce]} AND 
   {US$/Yen rate is [104]}).
THEN {S&P 500 index is [UP]}

在这个例子中,搜索空间由两个级别定义:

  • 布尔运算符(例如,AND)和谓词

  • 每个谓词被定义为元组(一个变量,操作符,目标值)

搜索空间的树表示如下图所示:

解决编码

编码规则的图表示

位串表示被解码回其原始格式以进行进一步计算:

解决编码

谓词的编码、修改和解码

编码方案

编码这样的候选解或谓词链有两种方法:

  • 染色体的平面编码

  • 将染色体作为基因组合进行层次编码

平面编码

平面编码方法包括将谓词集编码成一个单一的染色体(位串),代表优化问题的特定解候选者。谓词的身份没有被保留:

平面编码

染色体的平面寻址方案

遗传操作符操纵染色体的位,而不考虑这些位是否指向特定的谓词:

平面编码

使用平面寻址的染色体编码

层次编码

在此配置中,每个谓词的特征在编码过程中得到保留。每个谓词被转换为由比特串表示的基因。基因被聚合形成染色体。在比特串或染色体中添加一个额外字段以选择基因。这个额外字段由基因的索引或地址组成:

分层编码

染色体的分层寻址方案

一个通用算子选择它需要首先操作的谓词。一旦选择了目标基因,算子将更新与基因相关的比特串,如下所示:

分层编码

使用扁平寻址的染色体编码

下一步是定义操纵或更新表示染色体或单个基因的比特串的遗传算子。

遗传算子

繁殖周期的实现试图复制自然繁殖过程[10:7]。控制染色体群体的繁殖周期由三个遗传算子组成:

  • 选择:此算子根据适应度函数或标准对染色体进行排序。它消除了最弱或不适应的染色体,并控制群体增长。

  • 交叉:此算子配对染色体以生成后代染色体。这些后代染色体及其父代染色体一起添加到群体中。

  • 变异:此算子对遗传代码(比特串表示)进行微小修改,以防止连续的繁殖周期选择相同的最佳染色体。在优化术语中,此算子降低了遗传算法快速收敛到局部最大值或最小值的风险。

注意

转置算子

一些遗传算法的实现使用第四个算子,即遗传转置,以防适应度函数无法很好地定义,并且初始群体非常大。尽管额外的遗传算子可能有助于降低找到局部最大值或最小值的风险,但无法描述适应度标准或搜索空间是遗传算法可能不是最合适的工具的明确迹象。

下图概述了遗传算法的工作流程:

遗传算子

遗传算法执行的基本工作流程

注意

初始化

在任何优化过程中,搜索空间(问题的一组潜在解决方案)的初始化都是一项挑战,遗传算法也不例外。在没有偏差或启发式方法的情况下,繁殖通过随机生成的染色体初始化群体。然而,提取群体特征是值得努力的。在初始化过程中引入的任何合理的偏差都有助于繁殖过程的收敛。

这些遗传算子中每个至少有一个可配置的参数,需要估计和/或调整。此外,你很可能需要尝试不同的适应度函数和编码方案,以提高找到最适应解(或染色体)的机会。

选择

基因选择阶段的目的是评估、排序和淘汰那些不适合该问题的染色体(即解决方案候选者)。选择过程依赖于一个适应度函数,通过候选解的染色体表示来评分和排序。限制染色体种群的增长,通常通过设定种群大小的上限来实现,这是一种常见的做法。

有几种方法可以实现选择过程,从比例适应度、荷兰轮盘赌和锦标赛选择到基于排名的选择[10:8]。

注意

相对适应度退化

随着染色体初始种群的进化,染色体之间变得越来越相似。这种现象是种群实际上正在收敛的健康迹象。然而,对于某些问题,你可能需要缩放或放大相对适应度,以保持染色体之间适应度分数的有意义差异[10:9]。

以下实现依赖于基于排名的选择。

选择过程包括以下步骤:

  1. 将适应度函数应用于种群中的每个染色体 j,得到 f[j]

  2. 计算整个种群的总体适应度分数 ∑f[j]

  3. 通过所有染色体的适应度分数之和来归一化每个染色体的适应度分数 f[j] = f[i]/Σf[j]

  4. 按照降序适应度分数对染色体进行排序 f[j] < f[j-1]

  5. 计算每个染色体 j 的累积适应度分数 f[j] = f[j] + ∑f[k]

  6. 生成选择概率(对于基于排名的公式)作为一个随机值 p ε [0,1]

  7. 消除具有低不适应度分数 f[k] < p 或高适应度成本 f[k] > p 的染色体 k

  8. 如果种群超过允许的最大染色体数量,则减小种群大小。

注意

自然选择

你可能不会对需要控制染色体种群大小感到惊讶。毕竟,自然界不会让任何物种无限制地增长,以避免耗尽自然资源。由洛特卡-沃尔泰拉方程[10:10]模拟的捕食者-猎物过程保持每种物种的种群在可控范围内。

交叉

遗传交叉的目的是为了扩大当前染色体的种群,以增强解决方案候选者之间的竞争。交叉阶段包括从一代到下一代的染色体重编程。存在许多不同的交叉技术变体。染色体种群进化的算法与交叉技术无关。因此,案例研究使用了更简单的单点交叉。交叉交换两个父染色体的一部分以产生两个后代染色体,如下所示:

交叉

染色体的交叉操作

交叉阶段的一个重要元素是选择和配对父染色体。有不同方法来选择和配对最适合繁殖的父染色体:

  • 只选择* n*个最适应的染色体进行繁殖

  • 按照它们的适应度(或不适度)值对染色体进行配对

  • 将最适应的染色体与最不适应的染色体配对,将第二适应的染色体与第二不适应的染色体配对,依此类推

依赖于特定的优化问题来选择最合适的选择方法是一种常见做法,因为它高度依赖于特定领域。

使用分层寻址作为编码方案的交叉阶段包括以下步骤:

  1. 从种群中提取染色体对。

  2. 生成一个随机概率 p ϵ [0,1]

  3. 计算交叉应用的基因索引 r[i]r[i] = p.num_genes,其中 num_genes 是染色体中的基因数。

  4. 计算所选基因中交叉应用的位索引为 x[i] = p.gene_length,其中 gene_length 是基因中的位数。

  5. 通过在父母之间交换链来生成两个后代染色体。

  6. 将两个后代染色体添加到种群中。

注意

保留父染色体

你可能会想知道为什么一旦创建了后代染色体,父母染色体就不会从种群中移除。这是因为没有保证任何后代染色体是更好的适应者。

突变

遗传突变的目的是通过向遗传材料中引入伪随机变化来防止繁殖周期收敛到局部最优。突变过程在染色体中插入小的变化,以保持代与代之间的一定程度的多样性。该方法包括翻转染色体位串表示中的一个位,如下所示:

突变

染色体的突变操作

突变是繁殖过程中最简单的三个阶段之一。在分层寻址的情况下,步骤如下:

  1. 选择要突变的染色体。

  2. 生成一个随机概率 p ϵ**[0,1]

  3. 使用公式 m[i] = p.num_genes 计算要突变的基因的索引 m[i]

  4. 计算要突变的基因的位索引 x[i] = p.genes_length

  5. 对选定的位执行翻转 XOR 操作。

注意

调整问题

调整遗传算法可能是一项艰巨的任务。一个包括系统设计实验的计划,用于测量编码、适应度函数、交叉和突变比的影响,是必要的,以避免漫长的评估和自我怀疑。

适应度分数

适应度函数是选择过程的核心。适应度函数分为三类,如下所示:

  • 固定的适应度函数:在此函数中,适应度值的计算在繁殖过程中不发生变化

  • 进化适应度函数:在此函数中,适应度值的计算在每个选择之间根据预定义的标准变化

  • 近似适应度函数:在此函数中,适应度值不能直接使用解析公式计算 [10:11]

我们对遗传算法的实现使用固定的适应度函数。

实现

如前所述,遗传算子与要解决的问题无关。让我们实现繁殖周期的所有组件。适应度函数和编码方案具有高度的专业性。

根据面向对象编程的原则,软件架构采用自顶向下的方法定义遗传算子:从种群开始,然后是每个染色体,最后到每个基因。

软件设计

遗传算法的实现使用的设计类似于分类器的模板(参考附录 A 中的 分类器设计模板 部分,基本概念)。

遗传算法实现的要点如下:

  • Population 类定义了当前解决方案候选集或染色体的集合。

  • GASolver 类实现了 GA 求解器,并具有两个组件:GAConfig 类型的配置对象和初始种群。此类实现了 ETransform 类型的显式单子数据转换。

  • GAConfig 配置类包含 GA 执行和繁殖配置参数。

  • 繁殖(Reproduction 类型)通过 mate 方法控制连续染色体代之间的繁殖周期。

  • GAMonitor 监控特性跟踪优化的进度并评估每个繁殖周期的退出条件。

以下 UML 类图描述了遗传算法不同组件之间的关系:

软件设计

遗传算法组件的 UML 类图

让我们从定义控制遗传算法的关键类开始。

关键组件

种群参数化类(带有基因子类型)包含染色体集合或池。一个种群包含从基因继承的类型的序列或列表元素。是一个可变数组,用于避免与不可变集合关联的染色体实例的过度复制。

注意

可变性的案例

避免使用可变集合是良好的 Scala 编程实践。然而,在这种情况下,染色体的数量可能非常大。大多数遗传算法的实现在每个繁殖周期中可能更新种群三次,生成大量对象,并给 Java 垃圾收集器带来压力。

种群

种群类接受两个参数:

  • limit:这是种群的最大大小

  • chromosomes:这是定义当前种群的染色体池

繁殖周期在一个种群上执行以下三个遗传操作的序列:select用于在种群的所有染色体之间进行选择(行1),+-用于所有染色体的交叉(行2),以及^用于每个染色体的突变(行3)。考虑以下代码:

type Pool[T <: Gene] = mutable.ArrayBuffer[Chromosome[T]]

class PopulationT <: Gene {    
  def select(score: Chromosome[T]=>Unit, cutOff: Double) //1
  def +- (xOver: Double) //2
  def ^ (mu: Double) //3
  …
}

limit值指定了优化过程中种群的最大大小。它定义了种群增长的硬限制或约束。

染色体

染色体是基因型层次结构中的第二层包含。染色体类接受一个基因列表作为参数(代码)。交叉和突变方法的签名+-^种群类中的实现类似,除了交叉和可变参数是相对于基因列表和每个基因的索引传递的。关于遗传交叉的章节描述了遗传索引类:

class ChromosomeT <: Gene {   
  var cost: Double = Random.nextDouble //4
  def +- (that: Chromosome[T], idx: GeneticIndices): 
        (Chromosome[T], Chromosome[T]) 
  def ^ (idx: GeneticIndices): Chromosome[T]
   …
}

算法为每个染色体分配(不)适应度分数或成本值,以便对种群中的染色体进行排序,并最终选择最适应的染色体(行4)。

注意

适应度与成本

机器学习算法使用损失函数或其变体作为要最小化的目标函数。此 GA 实现使用成本分数,以便与最小化成本、损失或惩罚函数的概念保持一致。

基因

最后,繁殖过程对每个基因执行遗传操作:

class Gene(val id: String, 
    val target: Double, 
    op: Operator)
    (implicit quantize: Quantization, encoding: Encoding){//5
  lazy val bits: BitSet = apply(target, op)

  def apply(value: Double, op: Operator): BitSet //6
  def unapply(bitSet: BitSet): (Double, Operator) //7
   …
  def +- (index: Int, that: Gene): Gene //8
  def ^ (index: Int): Unit //9
  …
}

基因类接受三个参数和两个隐式参数,如下所示:

  • id:这是基因的标识符。它通常是基因表示的变量的名称。

  • target:这是要转换或离散化为位字符串的目标值或阈值。

  • op:这是应用于目标值的运算符。

  • quantize:这是量化离散化类,它将双精度值转换为整数以转换为位,反之亦然(行5)。

  • encoding:这是基因的编码或位布局,作为一对值和操作符。

apply方法将值和操作符的对编码为位集(第6行)。unapply方法是apply的逆操作。在这种情况下,它将位集解码为值和操作符的对(第7行)。

注意

unapply()

unapply方法反转了apply方法执行的状态转换。例如,如果apply方法填充了一个集合,则unapply方法会从其元素中清除集合。

在基因上实现交叉(第8行)和变异(第9行)操作符的操作与容器染色体的操作类似。

量化被实现为一个案例类:

case class Quantization(toInt: Double => Int,
     toDouble: Int => Double) {
   def this(R: Int) =  this((x: Double) => 
         (x*R).floor.toInt, (n: Int) => n/R) 
}

第一个toInt函数将实数值转换为整数,而toDouble函数将整数转换回实数值。discretizationinverse函数被封装到一个类中,以减少两个相反转换函数之间不一致的风险。

基因的实例化将谓词表示转换为位字符串(java.util.BitSet类型的位),使用量化函数Quantization.toInt

基因的布局由Encoding类如下定义:

class Encoding(nValueBits: Int, nOpBits: Int) {
  val rValue = Range(0, nValueBits)  
  val length = nValueBits + nOpBits
  val rOp = Range(nValueBits, length)
}

Encoding类指定了基因的位布局,作为一个数字nValueBits来编码值,以及一个数字nOpBits来编码操作符。该类定义了值的rValue范围和操作符的rOp范围。客户端代码必须提供给Encoding类的隐式实例。

基因(编码)的位集bitset通过使用apply方法实现:

def apply(value: Double, op: Operator): BitSet = {
  val bitset = new BitSet(encoding.length)  
  encoding.rOp foreach(i =>  //10
    if(((op.id>>i) & 0x01)==0x01) bitset.set(i))
  encoding.rValue foreach(i => //11
    if( ((quant.toInt(value)>>i) & 0x01)==0x01) bitset.set(i))
  bitset
}

基因的位布局使用java.util.BitSet创建。op操作符首先通过其标识符id(第10行)进行编码。value通过调用toInt方法进行量化,然后进行编码(第11行)。

unapply方法将基因从位集或位字符串解码为值和操作符的对。该方法使用量化实例将位覆盖到值,并使用一个辅助函数convert,该函数与其在源代码中的实现以及伴随书籍一起描述(第12行):

def unapply(bits: BitSet): (Double, Operator) = 
  (quant.toDouble(convert(encoding.rValue, bits)), 
   op(convert(encoding.rOp, bits))) //12

Operator特质定义了任何操作符的签名。每个特定领域的问题都需要一组独特的操作:布尔、数值或字符串操作:

trait Operator {
   def id: Int
   def apply(id: Int): Operator
}

前一个操作符有两个方法:一个标识符id和一个将索引转换为操作符的apply方法。

选择

生命周期中的第一个遗传操作是选择过程。Population类的select方法以最有效的方式实现了选择阶段到染色体群体的步骤,如下所示:

def select(score: Chromosome[T]=> Unit, cutOff: Double): Unit = {
  val cumul = chromosomes.map( _.cost).sum/SCALING_FACTOR //13
  chromosomes foreach( _ /= cumul) //14

  val _chromosomes = chromosomes.sortWith(_.cost < _.cost)//15
  val cutOffSize = (cutOff*_chromosomes.size).floor.toInt //16
  val popSize = if(limit < cutOffSize) limit else cutOffSize

  chromosomes.clear //17
  chromosomes ++= _chromosomes.take(popSize) //18
}

select方法计算整个种群的cost(第13行)的累积和。它将每个染色体的成本进行归一化(第14行),按值递减对种群进行排序(第15行),并在种群增长上应用cutOff软限制函数(第16行)。下一步将种群大小减少到两个限制中的最低值:硬限制limit或软限制cutOffSize。最后,清除现有的染色体(第17行)并用下一代更新(第18行)。

注意

种群大小为偶数

繁殖周期中的下一阶段是交叉,这需要配对父染色体。使种群大小为偶数整数是有意义的。

score评分函数接受一个染色体作为参数,并返回该染色体的cost值。

控制种群增长

自然选择过程控制或管理物种种群的增长。遗传算法使用以下两种机制:

  • 种群的最大绝对大小(硬限制)。

  • 随着优化过程的进行,减少种群的动力(软限制)。这种对种群增长的激励(或惩罚)由选择过程中使用的cutOff值定义(select方法)。

cutoff值是通过使用用户定义的Int => Double类型的softLimit函数来计算的,该函数作为配置参数提供(softLimit(cycle: Int) => a.cycle +b)。

GA 配置

遗传算法所需的四个配置和调整参数如下:

  • xOver:这是交叉比率(或概率),其值在区间[0, 1]内

  • mu:这是突变率

  • maxCycles:这是最大繁殖周期数

  • softLimit:这是对种群增长的软约束

考虑以下代码:

class GAConfig(val xover: Double, 
     val mu: Double, 
     val maxCycles: Int, 
     val softLimit: Int => Double) extends Config {
   val mutation = (cycle : Int) => softLimit(cycle)
}

交叉

如前所述,遗传交叉算子将两个染色体配对以生成两个后代染色体,这些后代染色体在下一繁殖周期的选择阶段与种群中的所有其他染色体(包括它们的父母)竞争。

种群

我们使用+-符号作为 Scala 中交叉算子的实现。有几种选择交叉染色体对的选项。这种实现按染色体的适应度(或cost的倒数)值对染色体进行排序,然后将种群分成两半。最后,它将来自每个半数的相同排名的染色体配对,如下面的图所示:

种群

在交叉之前,种群内染色体的配对

交叉实现+-使用前面描述的配对方案选择交叉的父染色体候选者。考虑以下代码:

def +- (xOver: Double): Unit = 
  if( size > 1) {
    val mid = size>>1
    val bottom = chromosomes.slice(mid, size) //19
    val gIdx = geneticIndices(xOver)  //20

    val offSprings = chromosomes.take(mid).zip(bottom)
          .map{ case (t, b) => t +- (b, gIdx) }.unzip //21
    chromosomes ++= offSprings._1 ++ offSprings._2 //22
  }

此方法将种群分成两个大小相等的子种群(行19),并应用 Scala 的zipunzip方法生成后代染色体对的集合(行20)。对每个染色体对应用+-交叉操作符,以产生offSprings对的数组(行21)。最后,crossover方法将后代染色体添加到现有种群中(行22)。xOver交叉值是在[config.xOver, 1]`区间内随机生成的概率。

当发生交叉或突变时,GeneticIndices案例类定义了位的两个索引。第一个chOpIdx索引是染色体中受遗传操作影响的位的绝对索引(行23)。第二个geneOpIdx索引是基因内位交叉或突变的索引(行24):

case class GeneticIndices(val chOpIdx: Int, //23
     val geneOpIdx: Int)  //24

geneticIndices方法计算染色体和基因中交叉位的相对索引:

def geneticIndices(prob: Double): GeneticIndices = {
  var idx = (prob*chromosomeSize).floor.toInt //25
  val chIdx = if(idx == chromosomeSize) chromosomeSize-1 
       else idx //25

  idx = (prob*geneSize).floor.toInt  
  val gIdx = if(idx == geneSize) geneSize-1 else idx //26
  GeneticIndices(chIdx, gIdx)
}

第一个chIdx索引器是基因在受遗传操作影响的染色体中的索引或排名(行25)。第二个gIdx索引器是基因内位的相对索引(行26)。

让我们考虑一个由 2 个基因组成,每个基因有 63 位/元素的染色体,如下面的图所示:

种群

geneticIndices方法计算以下内容:

  • 基因在染色体中的chIdx索引和基因内位的gIdx索引

  • 遗传操作选择chIdx索引的基因进行改变(即第二个基因)

  • 遗传操作在gIdx索引的位上改变染色体(即chIdx64 + gIdx*)

染色体

首先,我们需要定义Chromosome类,它将基因列表code(遗传代码)作为参数:

val QUANT = 500
class ChromosomeT <: Gene {  
  var cost: Double = QUANT*(1.0 + Random.nextDouble) //27

  def +- (that: Chromosome[T], indices: GeneticIndices): //28
     (Chromosome[T], Chromosome[T]) 
  def ^ (indices: GeneticIndices): Chromosome[T] //29
  def /= (normalizeFactor: Double): Unit =   //30
      cost /= normalizeFactor
  def decode(implicit d: Gene=>T): List[T] =  //31
      code.map( d(_)) 
  …
}

染色体的成本(或不适度)初始化为QUANT2*QUANT之间的随机值(行27)。遗传+-交叉操作符生成一对两个后代染色体(行28)。遗传^突变操作符创建一个略微修改的(1 或 2 位)染色体克隆(行29)。/=方法将染色体的成本归一化(行30)。decode方法使用基因与其子类之间的隐式转换d将基因转换为逻辑谓词或规则(行31)。

注意

成本初始化

从初始种群初始化染色体的成本没有绝对规则。然而,建议使用具有大范围的非零随机值来区分染色体。

使用分层编码实现一对染色体的交叉操作遵循两个步骤:

  1. 在每个染色体上找到对应于indices.chOpIdx交叉索引的基因,然后交换剩余的基因。

  2. xoverIdx处分割和拼接基因交叉。

考虑以下代码:

def +- (that: Chromosome[T], indices: GeneticIndices): 
    (Chromosome[T], Chromosome[T]) = {
  val xoverIdx = indices.chOpIdx //32
  val xGenes =  spliceGene(indices, that.code(xoverIdx)) //33

  val offSprng1 = code.slice(0, xoverIdx) ::: 
       xGenes._1 :: that.code.drop(xoverIdx+1) //34
  val offSprng2 = that.code.slice(0, xoverIdx) ::: 
      xGenes._2 :: code.drop(xoverIdx+1)
  (ChromosomeT, ChromosomeT) //35
}

交叉方法计算每个父染色体中定义交叉的位索引xoverIdx(行32)。this.code(xoverIdx)that.code(xoverIdx)基因通过spliceGene方法交换并拼接,以生成拼接基因(行33):

def spliceGene(indices: GeneticIndices, thatCode: T): (T,T) ={
  ((this.code(indices.chOpIdx) +- (thatCode,indices)), 
   (thatCode +- (code(indices.chOpIdx),indices)) )
}

后代染色体的收集是通过整理父染色体前xOverIdx个基因、交叉基因以及另一个父染色体的剩余基因(行34)来完成的。该方法返回一对后代染色体(行35)。

基因

交叉操作应用于基因使用Gene类的+-方法。thisthat基因之间的位交换使用BitSet Java 类来重新排列排列后的位:

def +- (that: Gene, indices: GeneticIndices): Gene = {
  val clonedBits = cloneBits(bits) //36

  Range(indices.geneOpIdx, bits.size).foreach(n => 
    if( that.bits.get(n) ) clonedBits.set(n) 
    else clonedBits.clear(n) //37
   )
   val valOp = decode(clonedBits) //38
   new Gene(id, valOp._1, valOp._2)
}

基因的位被克隆(行36),然后通过交换它们的位以及indices.geneOpIdx交叉点来拼接。cloneBits函数复制一个位字符串,然后使用decode方法将其转换为(目标值,操作符)元组(行37)。我们省略这两个方法,因为它们对于理解算法不是关键的。

变异

种群变异使用与交叉操作相同的算法方法。

种群

^变异操作员对种群中的所有染色体调用相同的操作符,然后将变异后的染色体添加到现有种群中,以便它们可以与原始染色体竞争。我们使用^符号来定义变异操作符,以提醒您变异是通过翻转一个位来实现的:

def ^ (prob: Double): Unit = 
  chromosomes ++= chromosomes.map(_ ^ geneticIndices(prob))

prob变异参数用于计算变异基因的绝对索引,geneticIndices(prob)

染色体

在染色体上实现^变异操作员的方法包括变异indices.chOpIdx索引处的基因(行39),然后更新染色体中的基因列表(行40)。该方法返回一个新的染色体(行41),该染色体将与原始染色体竞争:

def ^ (indices: GeneticIndices): Chromosome[T] = { //39 
  val mutated = code(indices.chOpIdx) ^ indices 
  val xs = Range(0, code.size).map(i =>
    if(i== indices.chOpIdx) mutated 
    else code(i)).toList //40
  ChromosomeT //41
}

基因

最后,变异操作员翻转(XOR)indices.geneOpIdx索引处的位:

def ^ (indices: GeneticIndices): Gene = { 
  val idx = indices.geneOpIdx
  val clonedBits = cloneBits(bits) //42

  clonedBits.flip(idx)  //43
  val valOp = decode(clonedBits)  //44
  new Gene(id, valOp._1, valOp._2) //45
}

^方法通过在indices.geneOpIdx索引处翻转位(行42)来变异复制的位字符串clonedBits。它通过将其转换为(目标值,操作符)元组来解码和转换变异后的位字符串(行43)。最后一步是从目标-操作符元组创建一个新的基因(行44)。

繁殖

让我们将繁殖周期封装到一个Reproduction类中,该类使用评分函数score

class ReproductionT <: Gene

mate繁殖函数实现了三个遗传操作符的序列或工作流程:select用于选择,+-(交叉)用于交叉,^(mu)用于变异:

def mate(population: Population[T], config: GAConfig, 
    cycle: Int): Boolean = (population.size: @switch) match {
  case 0 | 1 | 2 => false   //46
  case _ => {
    rand.setSeed(rand.nextInt + System.currentTimeMillis)
    population.select(score, config.softLimit(cycle)) //47
    population +- rand.nextDouble*config.xover //48
    population ^ rand.nextDouble*config.mu  //49
    true
  }
}

如果种群大小小于 3(第46行),则mate方法返回 false(即,繁殖周期终止)。当前种群中的染色体按成本增加排序。具有高成本或低适应度的染色体被丢弃,以符合对种群增长的软限制softLimit(第47行)。随机生成的概率被用作对整个剩余种群进行交叉操作的输入(第48行),以及作为对剩余种群进行突变的输入(第49行):

繁殖

人口增长的线性与二次软限制的示意图

求解器

GASolver类管理繁殖周期和染色体种群。求解器被定义为使用GAConfig类型的显式配置的ETransform类型的数据转换,如第二章中单调数据转换部分所述,Hello World!(第50行)。

GASolver类实现了GAMonitor特性,以监控种群多样性、管理繁殖周期和控制优化器的收敛性(第51行)。

基于遗传算法的求解器有以下三个参数:

  • config:这是遗传算法执行的配置

  • score:这是染色体的评分函数

  • tracker:这是一个可选的跟踪函数,用于初始化GAMonitor的监控功能

代码如下:

class GASolverT <: Gene 
     extends ETransformGAConfig //50
        with GAMonitor[T] { //51

  type U = Population[T]  //52
  type V = Population[T]  //53

  val monitor: Option[Population[T] => Unit] = tracker
  def |>(initialize: => Population[T]): Try[Population[T]] = 
      this.|> (initialize()) //54
  override def |> : PartialFunction[U, Try[V]] //55
}

这种显式数据转换必须初始化输入元素的U类型(第52行)和输出元素的V类型(第53行),用于预测或优化方法|>。优化器以初始种群作为输入,并生成一个非常小的适应度染色体种群,从中提取最佳解决方案(第55行)。

种群是通过|>方法( => Population[T])生成的,该方法以Population类的构造函数作为参数(第54行)。

让我们简要地看一下分配给遗传算法的GAMonitor监控特性。该特性具有以下两个属性:

  • monitor:这是一个抽象值,由实现此特性的类初始化(第55行)。

  • state:这是遗传算法执行的当前状态。遗传算法的初始状态是GA_NOT_RUNNING(第56行)。

代码如下:

trait GAMonitor[T <: Gene] extends Monitor {
  self: { 
    def |> :PartialFunction[Population[T],Try[Population[T]]] 
  } => //55
    val monitor: Option[Population[T] => Unit] //56
    var state: GAState = GA_NOT_RUNNING //57

    def isReady: Boolean = state == GA_NOT_RUNNING
    def start: Unit = state = GA_RUNNING
    def isComplete(population: Population[T], 
        remainingCycles: Int): Boolean  = { … } //58
}

遗传算法的状态只能通过GAMonitor类的实例在|>方法中更新。(第55行)

这里是遗传算法执行可能状态的子集:

sealed abstract class GAState(description: String)
case class GA_FAILED(description: String) 
   extends GAState(description)
object GA_RUNNING extends GAState("Running")

求解器在每次繁殖周期调用isComplete方法来测试优化器的收敛性(第58行)。

有两种方法可以估计繁殖周期正在收敛:

  • 贪婪算法:在此方法中,目标是检查是否在最后 m 代繁殖周期中,n 个最适应的染色体没有发生变化

  • 损失函数:这种方法与监督学习训练的收敛标准类似

让我们考虑以下遗传算法求解器的实现:

override def |> : PartialFunction[U, Try[V]] = {
  case population: U if(population.size > 1 && isReady) => {
    start //59
    val reproduction = ReproductionT  //60

    @tailrec
    def reproduce(population: Population[T], 
          n:Int): Population[T] = { //61
      if( !reproduction.mate(population, config, n) || 
         isComplete(population, config.maxCycles -n) )
       population
      else
       reproduce(population, n+1)
    }
    reproduce(population, 0)
    population.select(score, 1.0) //62
    Try(population)
  }
}

优化方法初始化执行状态(第 59 行)和繁殖周期(或一个时代)的组件(第 60 行)。繁殖周期(或一个时代)被实现为一个尾递归,它测试最后一个繁殖周期是否失败或优化是否收敛到一个解决方案(第 61 行)。最后,通过调用 Population 类的 select 方法(第 62 行)对剩余的最适应染色体进行重新排序。

交易策略的遗传算法

让我们将我们在遗传算法方面的专业知识应用于评估使用交易信号进行证券交易的不同策略。了解交易策略的知识不是理解遗传算法实现所必需的。然而,你可能希望熟悉附录 A 中简要描述的证券和金融市场技术分析的基础和术语,如技术分析部分所述。

问题是要找到最佳交易策略来预测给定一组交易信号的安全价格的增加或减少。交易策略被定义为当变量 x = {x[j]}(从安全价格或每日或每周交易量等金融指标中衍生而来)超过或等于或低于预定义的目标值 α[j] 时被触发或触发的交易信号集 ts[j](参见附录 A 中的交易信号和策略部分,基本概念)。

从价格和成交量可以导出的变量数量可能非常大。即使是经验最丰富的金融专业人士也面临着两个挑战,如下所述:

  • 选择一组与给定数据集相关的最小交易信号(最小化成本或不适度函数)

  • 通过从个人经验和专业知识中得出的启发式方法来转换那些交易信号

备注

遗传算法的替代方案

之前描述的问题当然可以使用之前章节中介绍的一种机器学习算法来解决。这只是一个定义训练集并将问题表述为最小化预测器和训练分数之间损失函数的问题。

以下表格列出了交易类及其在遗传世界中的对应物:

通用类 对应的证券交易类
操作符 SOperator
基因 Signal
染色体 Strategy
种群 StrategiesFactory

交易策略的定义

染色体是交易策略的遗传编码。工厂类 StrategyFactory 组装交易策略的组件:算子、不适度函数和信号

交易算子

让我们将 Operator 特性通过 SOperator 类扩展来定义触发信号所需的操作。SOperator 实例有一个单一参数:其标识符 _id。该类重写了 id() 方法以检索 ID(类似地,该类还重写了 apply 方法以将 ID 转换为 SOperator 实例):

class SOperator(_id: Int) extends Operator {
  override def id: Int = _id
  override def apply(idx: Int): SOperator = new SOperator(idx) 
}

交易信号使用的算子是逻辑算子:< (LESS_THAN), > (GREATER_THAN), 和 = (EQUAL),如下所示:

object LESS_THAN extends SOperator(1) 
object GREATER_THAN extends SOperator(2)
object EQUAL extends SOperator(3)

SOperator 类型的每个算子通过 operatorFuncMap 映射与一个评分函数相关联。评分函数计算信号相对于实值或时间序列的成本(或不适度):

val operatorFuncMap = MapOperator, (Double,Double) =>Double => target - x),
  … )

Populationselect 方法通过量化谓词的真实性来计算信号的 cost 值。例如,对于交易信号 x > 10,当 x = 5 时,不适度值为 5 – 10 = -5,如果 x = 14,则获得信用 14 – 10 = 4。在这种情况下,不适度值类似于判别性机器学习算法中的成本或损失。

成本函数

让我们考虑以下定义为一组两个信号来预测证券价格突然相对下降 Δp 的交易策略:

  • 相对成交量 v[m],条件为 v[m] < α

  • 相对波动率 v[l],条件为 v[l] > β

让我们看一下以下图表:

成本函数

证券的价格、相对成交量、相对波动率的图表

由于目标是模拟股价的突然暴跌,我们应该奖励那些预测股价大幅下跌的交易策略,并惩罚那些仅在股价小幅下跌或上涨时表现良好的策略。在具有两个信号(相对成交量 v[m] 和相对波动率 v[l])、n 个交易时段、成本或不适度函数 C、给定的股价相对变化和惩罚 w = -Δp (M2) 的交易策略的情况下:

成本函数

交易信号

让我们通过以下方式将 Gene 类子类化来定义 Signal 类型的交易信号:

class Signal(id: String, target: Double, op: Operator,
   xt: DblVector, weights: Option[DblVector] = None)
   (implicit quantize: Quantization, encoding: Encoding) 
  extends Gene(id, target, op) 

Signal 类需要以下参数:

  • 特征的标识符 id

  • 一个 target

  • 一个 op 算子

  • DblVector 类型的 xt 时间序列

  • 与时间序列 xt 的每个数据点相关联的可选 weights

  • 一个隐式量化实例,quantize

  • 一个隐式 encoding 方案

Signal类的主要目的是计算其score作为染色体。染色体通过累加其包含的信号的分数或加权分数来更新其cost。交易信号的分数简单地是时间序列ts中每个条目的信号惩罚或真实性的总和:

override def score: Double = 
  if(!operatorFuncMap.contains(op)) Double.MaxValue
  else {
    val f = operatorFuncMap.get(op).get
    if( weights != None ) xt.zip(weights.get)
        .map{case(x, w) => w*f(x,target)}.sum
    else xt.map( f(_, target)).sum   
  }

交易策略

交易策略是一个无序的交易信号列表。创建一个用于生成交易策略的工厂类是有意义的。StrategyFactory类从现有的子类型Gene的信号池中创建List[Signal]类型的策略:

交易策略

交易信号的工厂模式

StrategyFactory类有两个参数:交易策略中的信号数量nSignals以及隐式的QuantizationEncoding实例(第63行):

class StrategyFactory(nSignals: Int)  //63
    (implicit quantize: Quantization, encoding: Encoding){
  val signals = new ListBuffer[Signal]   
  lazy val strategies: Pool[Signal] //64
  def += (id: String, target: Double, op: SOperator, 
       xt: DblVector, weights: DblVector)
  …
 }

+=方法接受五个参数:标识符idtarget值、将类指定为Geneop操作、用于评分信号的xt时间序列以及与整体成本函数相关的weightsStrategyFactory类以懒值的形式生成所有可能的信号序列作为交易策略,以避免在需要时无谓地重新生成池(第64行),如下所示:

lazy val strategies: Pool[Signal] = {
  implicit val ordered = Signal.orderedSignals //70

  val xss = new Pool[Signal] //65
  val treeSet = new TreeSet[Signal] ++= signals.toList //66
  val subsetsIterator = treeSet.subsets(nSignals) //67

  while( subsetsIterator.hasNext) {
    val signalList = subsetsIterator.next.toList  //68
    xss.append(ChromosomeSignal) //69
  } 
  xss
}

strategies值的实现通过将信号列表转换为treeset(第66行)创建了一个信号池Pool(第65行)。它将树集分解为每个具有nSignals个节点的唯一子树。它实例化了一个subsetsIterator迭代器来遍历子树序列(第67行),并将它们转换为列表(第68行),作为新染色体(交易策略)的参数(第69行)。在树集中对信号进行排序的orderedSignals过程必须隐式定义(第70行)为val orderedSignals = Ordering.by((signal: Signal) => signal.id)

交易信号编码

交易断言的编码是遗传算法中最关键的部分。在我们的例子中,我们将一个断言编码为一个元组(目标值,操作符)。让我们考虑简单的断言波动率 > 0.62。离散化将值 0.62 转换为实例的 32 位和操作符的 2 位表示:

交易信号编码

交易信号的编码:波动率 > 0.62

注意

IEEE-732 编码

断言的阈值值被转换为一个整数(Int类型或Long)。浮点值的 IEEE-732 二进制表示使得应用遗传操作所需的位寻址相当具有挑战性。一个简单的转换包括以下内容:

  • encoding e: (x: Double) => (x*100000).toInt

  • decoding d: (x: Int) => x*1e-5

所有值都被归一化,因此不存在溢出 32 位表示的风险。

一个测试用例

目标是评估在 2008 年秋季股市崩溃期间哪个交易策略最相关(最适应)。让我们以一家金融机构,高盛的股票价格为代理,来考虑市场突然下跌:

一个测试案例

2008 年 9 月至 11 月高盛股票价格的突然下跌

除了两个连续交易时段之间股票价格的变化(dPrice)之外,该模型还使用以下参数(或交易信号):

  • dVolume:这是两个连续交易时段之间成交量的相对变化

  • dVolatility:这是两个连续交易时段之间波动的相对变化

  • volatility:这是交易时段内的相对波动性

  • vPrice:这是股票开盘价和收盘价之间的相对差异

交易数据和度量名称约定在附录 A 的技术分析下的交易数据部分中描述,基本概念

执行遗传算法需要以下步骤:

  1. 提取模型参数或变量。

  2. 生成交易策略的初始种群。

  3. 设置 GA 配置参数,包括允许的最大繁殖周期数、交叉和突变比率以及种群增长的软限制函数。

  4. 使用评分/不适应度函数实例化 GA 算法。

  5. 提取最能解释高盛股票价格急剧下跌的最适应交易策略。

创建交易策略

遗传算法的输入是交易策略种群。每个策略由三个交易信号的组合组成,每个交易信号是一个元组(信号 ID、操作符和目标值)。

第一步是提取模型参数,如图所示,包括股票价格量、波动性和两个连续交易时段之间的相对波动性(行71):

Import YahooFinancials._
val NUM_SIGNALS = 3

def createStrategies: Try[Pool[Signal]] = {
  val src = DataSource(path, false, true, 1) //71
  for {  //72
    price <- src.get(adjClose)
    dPrice <- delta(price, -1.0) 
    volume <- src.get(volume)
    dVolume <- delta(volume, 1.0)
    volatility <- src.get(volatility)
    dVolatility <- delta(volatility, 1.0)
    vPrice = src.get(vPrice)
  } yield { //72
    val factory = new StrategyFactory(NUM_SIGNALS) //73

    val weights = dPrice  //74
    factory += ("dvolume", 1.1, GREATER_THAN, dVolume, weights)
    factory += ("volatility", 1.3, GREATER_THAN, 
       volatility.drop(1), weights)
    factory += ("vPrice", 0.8, LESS_THAN, 
       vPrice.drop(1), weights)
    factory += ("dVolatility", 0.9, GREATER_THAN, 
       dVolatility, weights)
    factory.strategies
   }
}

目的是生成初始策略种群,这些策略竞争以成为与高盛股票价格下跌相关。初始交易策略种群是通过从四个按股票价格变化加权的交易信号中创建组合来生成的:∆(volume) > 1.1∆(volatility) > 1.3∆(close-open) < 0.8,和volatility > 0.9

delta方法计算连续交易时段之间交易变量的变化。它调用XTSeries.zipWithShift方法,该方法在第三章的时间序列在 Scala部分中介绍,数据预处理

def delta(xt: DblVector, a: Double): Try[DblVector] = Try {
  zipWithShift(xt, 1).map{case(x, y) => a*(y/x - 1.0)}
}

交易策略是由上一节中介绍的StrategyFactory类生成的(行73)。交易策略的权重是通过计算两个连续交易会话之间股票价格的dPrice差异来计算的(行74)。通过将权重替换为平均价格变化来选择无权重的交易策略,如下所示:

val avWeights = dPrice.sum/dPrice.size
val weights = Vector.fill(dPrice.size)(avWeights)

以下图表说明了交易策略初始种群的生成:

创建交易策略

生成交易策略初始种群的方案

配置优化器

遗传算法执行配置参数的分类如下:

  • 调优参数,如交叉、突变比率或种群增长的软限制

  • 数据表示参数,如量化编码

  • 评分方案

GA 的四个配置参数是执行中允许的最大繁殖周期数(MAX_CYCLES)、交叉(XOVER)、突变比率(MU)以及控制种群增长的软限制函数(softLimit)。软限制作为周期数(n)的线性递减函数实现,以重新训练种群的增长,随着遗传算法的执行而进行:

val XOVER = 0.8  //Probability(ratio) for cross-over
val MU = 0.4  //Probability(ratio) for mutation
val MAX_CYCLES = 400  //Max. number of optimization cycles 

val CUTOFF_SLOPE = -0.003  //Slope linear soft limit
val CUTOFF_INTERCEPT = 1.003  //Intercept linear soft limit
val softLimit = (n: Int) => CUTOFF_SLOPE*n +CUTOFF_INTERCEPT

交易策略通过编码(行75)转换为染色体。为了将每个交易信号(基因)中的目标值编码,必须隐式定义一个digitize量化方案(行76):

implicit val encoding = defaultEncoding //75
val R = 1024  //Quantization ratio
implicit val digitize = new Quantization(R) //76

评分函数通过将评分函数应用于包含的三个交易信号(基因)中的每一个来计算交易策略(染色体)的成本或不适度(行77):

val scoring = (chr: Chromosome[Signal]) => {
  val signals: List[Gene] = chr.code
  chr.cost = signals.map(_.score).sum //77
}

寻找最佳交易策略

工厂在createStrategies方法中生成的交易策略作为初始种群输入遗传算法(行79)。种群增长的上限设置为初始种群大小的八倍(行78):

createStrategies.map(strategies => {
  val limit = strategies.size <<3 //78
  val initial = PopulationSignal //79

  val config = GAConfig(XOVER, MU, MAX_CYCLES,softLimit) //80
  val solver = GASolverSignal) 
                                                //81
  (solver |> initial)
    .map(_.fittest.map(_.symbolic).getOrElse("NA")) match {
      case Success(results) => show(results)
      case Failure(e) => error("GAEval: ", e)
    } //82
})

配置(config,行80),评分函数以及可选的跟踪函数,这些都是创建和执行solver遗传算法(行81)所必需的。由|>运算符生成的部分函数将交易策略的初始种群转换为两个最适应的策略(行82)。

监控函数、跟踪器和其他方法的文档化源代码可在网上找到。

测试

每个交易策略的成本函数C(或不适应度)评分根据高盛股票价格下降的比率进行加权。让我们运行以下两个测试:

  • 使用价格变化加权评分评估遗传算法的配置

  • 使用无权评分函数评估遗传算法

加权评分

分数是根据股票 GS 的价格变动进行加权的。测试使用了三组不同的交叉和变异比率:(0.6, 0.2),(0.3, 0.1),和(0.2, 0.6)。每种情景的最佳交易策略如下:

  • 0.6-0.2: 变化 < 0.82 dVolume > 1.17 波动性 > 1.35 成本= 0.0 适应度: 1.0E10

  • 0.3-0.1: 变化 < 0.42 dVolume > 1.61 波动性 > 1.08 成本= 59.18 适应度: 0.016

  • 0.2-0.6: 变化 < 0.87 dVolume < 8.17 波动性 > 3.91 成本= 301.3 适应度: 0.003

对于每种情况,最佳交易策略与初始种群在以下一个或几个原因上没有太大差异:

  • 交易信号的初始猜测是好的

  • 初始种群的大小太小,无法产生遗传多样性

  • 测试没有考虑股价的下降率

使用 交叉 = 0.2变异 = 0.6 的遗传算法执行产生了一种与前两种情况不一致的交易策略。一个可能的解释是,交叉总是应用于三个基因中的第一个,迫使优化器收敛到一个局部最小值。

让我们检查遗传算法在执行过程中的行为。我们特别感兴趣的是平均染色体不适应度分数的收敛。平均染色体不适应度是种群总不适应度分数与种群大小的比率。让我们看一下以下图表:

加权分数

对于交叉比率为 0.2 和变异率为 0.6 的加权分数遗传算法的收敛

遗传算法收敛得相当快,然后稳定下来。种群大小通过交叉和变异操作增加,直到达到 256 个交易策略的最大值。在 23 个交易周期后,种群大小达到软限制或约束。测试再次使用不同的交叉和变异比率进行,如下面的图表所示:

加权分数

交叉和变异比率对加权分数遗传算法收敛的影响

遗传算法执行的配置文件不太受交叉和变异比率不同值的影响。在高交叉比率(0.6)的情况下,染色体不适应度分数随着执行过程而波动。在某些情况下,染色体之间的不适应度分数非常小,以至于遗传算法重复使用相同的几个交易策略。

染色体不适应性的快速下降与一些最佳策略是初始种群的一部分的事实一致。然而,这也应该引起一些担忧,即遗传算法早期就锁定在局部最小值上。

未加权分数

执行与之前类似的测试,使用无权重交易策略(使用平均价格变动的交易策略)评分公式产生了一些有趣的结果,如下面的图表所示:

无权重评分

遗传算法在交叉比 0.4 和突变 0.4 以及无权重评分下的收敛

种群大小的配置文件与使用加权评分的测试相似。然而,染色体平均成本模式略呈线性。无权重(或平均)将股价下降率加到评分(成本)上。

注意

评分函数的复杂性

评分(或成本计算)公式的复杂性增加了遗传算法无法正确收敛的可能性。收敛问题的可能解决方案如下:

  • 使加权函数加性(更简单)

  • 增加初始种群的大小和多样性

遗传算法的优点和风险

现在,应该很清楚,遗传算法为科学家提供了一个强大的工具箱,可以用来优化以下问题:

  • 理解得不好。

  • 可能存在多个足够好的解决方案。

  • 具有离散、不连续和非可微分的函数。

  • 可以轻松与规则引擎和知识库(例如,学习分类器系统)集成。

  • 不需要深厚的领域知识。遗传算法通过遗传算子生成新的解决方案候选者。初始种群不必包含最适应的解决方案。

  • 不需要了解如牛顿-拉夫森共轭梯度BFGS等数值方法作为优化技术,这些方法会让那些对数学不感兴趣的人感到害怕。

然而,进化计算不适合以下问题:

  • 适应度函数无法明确定义

  • 寻找全局(绝对)最小值或最大值对于问题至关重要

  • 执行时间必须可预测

  • 解决方案必须实时或准实时(流数据)提供。

摘要

你是否对进化计算,特别是遗传算法及其优点、局限以及一些常见陷阱上瘾?如果答案是肯定的,那么你可能发现下一章中介绍的分类学习系统非常有趣。本章讨论了以下主题:

  • 进化计算中的关键概念

  • 遗传算子的关键组件和算子

  • 使用金融交易策略作为背景定义适应度或非适应度评分的陷阱

  • 交易策略中编码谓词的挑战

  • 遗传算法的优点和风险

  • 从底层构建遗传算法预测工具的过程

遗传算法是强化学习特殊类别中的一个重要元素,将在下一章的学习分类系统部分介绍。

第十一章:强化学习

本章介绍了强化学习的概念,该概念在游戏和机器人领域得到广泛应用。本章的第二部分致力于学习分类系统,它将强化学习技术与上一章中引入的进化计算技术相结合。学习分类器是一类有趣的算法,通常不包括在专门针对机器学习的文献中。如果您对强化学习的起源、目的和科学基础感兴趣,我强烈建议您阅读 R. Sutton 和 A. Barto 合著的关于强化学习的开创性书籍[11:1]。

在本章中,您将学习以下主题:

  • 强化学习背后的基本概念

  • Q 学习算法的详细实现

  • 使用强化学习管理并平衡投资组合的一个简单方法

  • 学习分类系统的介绍

  • 扩展学习分类器的一个简单实现

学习分类系统LCS)这一节主要是信息性的,不包括测试案例。

强化学习

随着第一个自主系统的设计,对传统学习技术的替代需求产生了。

问题

自主系统是半独立系统,以高度自主的方式执行任务。自主系统触及我们生活的各个方面,从机器人、自动驾驶汽车到无人机。自主设备对其操作环境做出反应。这种反应或行动需要了解环境的当前状态以及之前的状态。

自主系统具有以下特定特征,这些特征挑战了传统的机器学习方法:

  • 由于可能的状态组合数量巨大,自主系统具有定义不明确的领域知识。

  • 由于以下原因,传统的非顺序监督学习不是一个实际的选择:

    • 训练消耗了大量的计算资源,这些资源并不总是小型自主设备所拥有的

    • 一些学习算法不适合实时预测

    • 模型没有捕捉到数据流的顺序性

  • 隐藏马尔可夫模型等顺序数据模型需要训练集来计算发射和状态转移矩阵(如第七章中“隐藏马尔可夫模型”部分所述,顺序数据模型),这些矩阵并不总是可用。然而,如果某些状态未知,强化学习算法将从隐藏马尔可夫模型中受益。这些算法被称为行为隐藏马尔可夫模型[11:2]。

  • 如果搜索空间可以通过启发式方法进行约束,则遗传算法是一个选项。然而,遗传算法的响应时间不可预测,这使得它们在实际实时处理中不实用。

一种解决方案 – Q 学习

强化学习是一种算法方法,用于理解和最终自动化基于目标的决策制定。强化学习也被称为控制学习。从知识获取的角度来看,它与监督学习和无监督学习技术不同:自主、自动化系统或设备通过与环境的直接实时交互来学习。强化学习在机器人、导航智能体、无人机、自适应过程控制、游戏和在线学习、调度和路由问题等方面有众多实际应用。

术语

强化学习引入了新的术语,如以下所示,这些术语与较老的学习技术大不相同:

  • 环境:这是任何具有状态和机制在状态之间转换的系统。例如,机器人的环境是它操作的地形或设施。

  • 智能体:这是一个与环境交互的自动化系统。

  • 状态:环境或系统的状态是描述环境的变量或特征的集合。

  • 目标或吸收状态或终止状态:这是提供比任何其他状态更高折现累积奖励的状态。高累积奖励防止最佳策略在训练期间依赖于初始状态。

  • 动作:这定义了状态之间的转换。智能体负责执行或至少推荐一个动作。在执行动作后,智能体从环境中收集奖励(或惩罚)。

  • 策略:这定义了环境任何状态应选择和执行的动作。

  • 最佳策略:这是通过训练生成的策略。它定义了 Q 学习中的模型,并且会随着任何新剧集的更新而不断更新。

  • 奖励:这量化了智能体与环境之间的积极或消极交互。奖励基本上是学习引擎的训练集。

  • 剧集:这定义了从初始状态达到目标状态所需的步骤数量。剧集也称为试验。

  • 视野:这是用于最大化奖励的未来步骤或动作的数量。视野可以是无限的,在这种情况下,未来奖励会被折现,以便策略的价值收敛。

概念

强化学习的关键组成部分是一个决策代理,它通过选择和执行最佳行动方案来对其环境做出反应,并因该行动方案而获得奖励或惩罚[11:3]。你可以将这些代理想象成在未知地形或迷宫中导航的机器人。毕竟,机器人将强化学习作为其推理过程的一部分。以下图表展示了强化学习代理的概要架构:

概念

强化学习的四个状态转换

代理收集环境的状态,选择,然后执行最合适的行动。环境通过改变其状态并对代理的行动进行奖励或惩罚来响应该行动。

一个场景或学习周期的四个步骤如下:

  1. 学习代理检索或通知环境的新状态。

  2. 代理评估并选择可能提供最高奖励的行动。

  3. 代理执行行动。

  4. 代理收集奖励或惩罚并将其应用于校准学习算法。

注意

强化与监督

强化学习的训练过程奖励那些最大化价值或回报的特征。监督学习奖励那些满足预定义标签值的特征。监督学习可以被视为强制学习。

代理的行动修改了系统的状态,反过来又通知代理新的操作条件。尽管并非每个行动都会触发环境状态的改变,但代理仍然收集奖励或惩罚。在本质上,代理必须设计和执行一系列行动以达到其目标。这一系列行动是通过无处不在的马尔可夫决策过程(参见第七章中“马尔可夫决策过程”部分,序列数据模型)来建模的。

注意

虚拟行动

设计代理时,重要的是要确保行动不会自动触发环境的新状态。很容易想象一个场景,其中代理触发一个行动只是为了评估其奖励,而不显著影响环境。

对于这种情况,一个很好的隐喻是行动的回滚。然而,并非所有环境都支持这种虚拟行动,代理可能不得不运行蒙特卡洛模拟来尝试一个行动。

政策的价值

强化学习特别适合于长期奖励可以与短期奖励相平衡的问题。一项策略通过将环境状态映射到其行动来指导代理的行为。每个策略都通过一个称为政策价值的变量来评估。

直观地说,策略的值是智能体采取的连续动作序列所收集的所有奖励的总和。在实践中,策略中更远的动作显然比从状态 S[t] 到状态 S[t+1] 的下一个动作影响小。换句话说,未来动作对当前状态的影响必须通过一个称为未来奖励折现系数的因子进行折现,该系数小于 1。

注意

转移和奖励矩阵

在第七章的隐藏马尔可夫模型部分介绍了转移和发射矩阵,序列数据模型

最优策略 π* 是使未来奖励折现到当前时间最大化的智能体动作序列。

下表介绍了强化学习中每个组件的数学符号:

符号 描述
S = {s[i]} 这些是环境的状态
A = {a[i]} 这些是环境上的动作
Π[t] = p(a[t] | s[t]) 这是智能体的策略(或策略)
V^π(s[t]) 这是策略在状态 s[t] 的值
pt =p(s[t+1] | s[t],a[t]) 这些是从状态 st 到状态 s[t+1] 的状态转移概率
r[t]= p(r[t+1] | s[t],s[t+1],a[t]) 这是动作 a[t] 在状态 s[t] 上的奖励
R[t] 这是期望折现长期回报
γ 这是折现未来奖励的系数

目的是计算从任何起始状态 s[k] 出发到当前状态 s[t] 的最大期望奖励 R[t],即所有折现奖励的总和。策略 π 在状态 s[t] 的值 V^π 是给定状态 s[t] 的最大期望奖励 R[t]

注意

M1: 给定策略 π 和折现率 γ,对于状态 st 的累积奖励 R[t] 和值函数 V^π(st) 定义如下:

策略的值

贝尔曼最优方程

寻找最优策略的问题实际上是一个非线性优化问题,其解是迭代的(动态规划)。策略 π 的值函数 V^π 的表达式可以使用马尔可夫状态转移概率 p[t] 来表述。

注意

M2: 使用转移概率 p[k],对于状态 st 和未来状态 s[k] 以及奖励 r[k],给定策略 π 和折现率 γ,值函数 V^π(s[t]) 定义如下:

贝尔曼最优方程

V(s[t])* 是所有策略中状态 st 的最优值。这些方程被称为贝尔曼最优方程。

注意

维度灾难

对于高维问题(大特征向量)的状态数量会迅速变得无法解决。一种解决方案是通过采样 近似值函数并减少状态数量。应用测试案例引入了一个非常简单的近似函数。

如果环境模型、状态、动作、奖励以及状态之间的转换都被完全定义,那么这种强化学习技术被称为基于模型的 学习。在这种情况下,没有必要探索新的动作序列或状态转换。基于模型的学习类似于玩棋类游 戏,其中所有必要的步骤组合以赢得比赛都是完全已知的。

然而,大多数使用序列数据的实际应用都没有一个完整、确定的模型。不依赖于完全定义和可用模型的 学习技术被称为无模型技术。这些技术需要探索以找到任何给定状态的 最佳策略。本章剩余部分将讨论无模型学习技术,特别是时序差分算法。

无模型学习时的时序差分

时序差分是一种无模型学习技术,它采样环境。它是一种常用的迭代求解贝尔曼方程的方法。没有模型的存在需要发现或探索环境。最简单的探索形式是使用下一个状态的价值和从动作中定义的奖励来更新当前状态的价值,如以下图所示:

无模型学习时的时序差分

时序差分算法的示意图

用于调整状态上价值动作的迭代反馈循环在人工神经网络中的错误反向传播或监督学习中的损失函数 最小化中起着类似的作用。调整算法必须:

  • 使用折扣率 γ 折扣下一个状态的估计价值

  • 在使用学习率 α 更新时间 t 的价值时,在当前状态和下一个状态的影响之间取得平衡

首个贝尔曼方程的迭代公式预测 V^π(st),即状态 st 的值函数,从下一个状态 s[t+1] 的值函数中得出。预测值与实际值之间的差被称为时序差分误差,简称 δt

注意

M3:对于状态 s[t] 的值函数 V(s[t])、学习率 α、奖励 r[t] 和折扣率 γ 的表格时序差分 δ[t] 的公式定义为:

无模型学习时的时序差分

评估策略的一种替代方法,即使用状态 V^π(s[t]) 的值,是使用在状态 s[t] 上采取动作的值,称为动作值(或动作-值)Q^π(s[t], a[t])

注意

M4:动作在状态 st 的价值 Q 定义为在状态 s[t] 上对动作 a[t] 的奖励 R[t] 的期望是定义为:

无模型学习的时间差分

实现时间差分算法有两种方法:

  • 按策略:这是使用策略的下一个最佳动作的价值

  • 离策略:这是不使用策略的下一个最佳动作的价值

让我们考虑使用离策略方法的时间差分算法及其最常用的实现:Q-learning。

动作价值迭代更新

Q-learning 是一种无模型学习技术,使用离策略方法。它通过学习动作价值函数来优化动作选择策略。像任何依赖凸优化的机器学习技术一样,Q-learning 算法通过质量函数迭代地遍历动作和状态,如下面的数学公式所述。

该算法预测并折现当前状态 st 和动作 at 在环境中的最优动作值 max{Q[t]},以过渡到状态 s[t+1]

与像遗传算法那样在上一轮繁殖周期中重复使用染色体种群来产生后代类似,Q-learning 技术通过学习率 α 在质量函数的新值 Q[t+1] 和旧值 Q[t] 之间取得平衡。Q-learning 将时间差分技术应用于离策略方法的 Bellman 方程。

注意

M5:给定策略 π、状态集合 {s[t]}、与每个状态 s[t] 相关的动作集合 {a[t]}、学习率 α 和折扣率 γ 的 Q-learning 动作价值更新公式如下:

动作价值迭代更新

  • 学习率 α 的值为 1 会丢弃先前的状态,而值为 0 会丢弃学习

  • 折扣率 γ 的值为 1 只使用长期奖励,而值为 0 只使用短期奖励

Q-learning 估计未来动作的累积奖励的折现值。

注意

Q-learning 作为强化学习

Q-learning 可以被归类为强化学习技术,因为它不严格需要标记数据和训练。此外,Q 值不必是连续的、可微分的函数。

让我们将我们辛苦学到的强化学习知识应用到交易所基金组合的管理和优化中。

实施方案

让我们在 Scala 中实现 Q-learning 算法。

软件设计

Q-learning 算法实现的要点如下定义:

  • QLearning 类实现了训练和预测方法。它使用 QLConfig 类的显式配置定义了 ETransform 类型的数据转换。

  • QLSpace 类有两个组件:一个 QLState 类型状态的序列和一个或多个目标状态的标识符 id,这些目标状态位于序列中。

  • 一个状态,QLState,包含用于其转换到另一个状态的 QLAction 实例序列以及要评估和预测的状态的对象或 instance 的引用。

  • 一个索引状态,QLIndexedState,索引搜索中的状态以指向目标状态。

  • 一个可选的 constraint 函数,它限制从当前状态搜索下一个最有奖励动作的范围。

  • QLModel 类型的模型通过训练生成。它包含最佳策略和模型的准确性。

下图显示了 Q-learning 算法的关键组件:

软件设计

Q-learning 算法的 UML 组件图

状态和动作

QLAction 类指定了一个状态从 from 标识符到另一个状态 to 标识符的转换,如下所示:

class QLAction(val from: Int, val to: Int)

动作具有 Q 值(或动作值)、奖励和概率。实现将这些三个值定义在三个独立的矩阵中:Q 用于动作值,R 用于奖励,P 用于概率,以保持与数学公式的连贯性。

QLState 类型的状态完全由其标识符 id、转换到其他状态的 actions 列表以及参数化类型的 prop 属性定义,如下面的代码所示:

class QLStateT

状态可能没有任何动作。这通常是目标或吸收状态的情况。在这种情况下,列表为空。参数化的 instance 是对计算状态的对象的引用。

下一步是创建图或搜索空间。

搜索空间

搜索空间是负责任何状态序列的容器。QLSpace 类接受以下参数:

  • 所有可能的 states 的序列

  • 被选为 goals 的一个或几个状态的 ID

注意

为什么有多个目标?

完全没有要求状态空间必须有一个单一的目标。你可以将问题的解决方案描述为达到一个阈值或满足几个条件之一。每个条件都可以定义为状态目标。

QLSpace 类的实现如下:

class QLSpaceT {
  val statesMap = states.map(st =>(st.id, st)) //1
  val goalStates = new HashSet[Int]() ++ goals //2

  def maxQ(state: QLState[T], 
      policy: QLPolicy): Double   //3
  def init(state0: Int)  //4
  def nextStates(st: QLState[T]): Seq[QLState[T]]  //5
   …
}

QLSpace 类的构造函数生成一个映射,statesMap。它使用其 id 值(第 1 行)和目标状态数组 goalStates(第 2 行)检索状态。此外,maxQ 方法计算给定策略的状态的最大动作值 maxQ(第 3 行)。maxQ 方法的实现将在下一节中描述。

init 方法为训练周期选择一个初始状态(第 4 行)。如果 state0 参数无效,则状态是随机选择的:

def init(state0: Int): QLState[T] = 
  if(state0 < 0) {
    val seed = System.currentTimeMillis+Random.nextLong
      states((new Random(seed)).nextInt(states.size-1))
  }
  else states(state0)

最后,nextStates 方法检索由与 st 状态相关联的所有动作执行产生的状态列表(行 5)。

QLSpace 搜索空间实际上是由 QLSpace 伴随对象中定义的 apply 工厂方法创建的,如下所示:

def applyT: QLSpace[T] ={ //6

  val r = Range(n, instances.size)  

  val states = instances.zipWithIndex.map{ case(x, n) => {
    val validStates = constraints.map( _(n)).getOrElse(r)
    val actions = validStates.view
          .map(new QLAction(n, _)).filter(n != _.to) //7
    QLStateT
  }}
  new QLSpaceT 
}

apply 方法使用 instances 集合、goalsconstraints 约束函数作为输入创建一个状态列表(行 6)。每个状态创建其动作列表。动作从这个状态生成到任何其他状态(行 7)。

状态的搜索空间如下图所示:

搜索空间

带有 QLData(Q 值、奖励和概率)的状态转换矩阵

constraints 函数限制了从任何给定状态可以触发的动作范围,如前图所示。

策略和动作值

每个动作都有一个动作值、奖励和一个潜在的几率。几率变量被引入以简单地模拟动作执行的阻碍或不利条件。如果动作没有外部约束,几率是 1。如果动作不允许,几率是 0。

备注

将策略与状态分离

动作和状态是搜索空间或搜索图中的边和顶点。由动作值、奖励和概率定义的策略与图完全分离。Q 学习算法独立于图的结构初始化奖励矩阵并更新动作值矩阵。

QLData 类是一个包含三个值的容器:rewardprobability 和用于 Q 值的 value 变量,如下所示:

class QLData(val reward: Double, val probability: Double) {
  var value: Double = 0.0
  def estimate: Double = value*probability
}

备注

奖励和惩罚

QLData 类中的几率代表从一个状态到达另一个状态的阻碍或难度。没有任何东西阻止你将几率用作负奖励或惩罚。然而,其适当的定义是为状态转换创建一个对状态子集的软约束。对于大多数应用,绝大多数状态转换的几率是 1.0,因此依赖于奖励来引导搜索向目标前进。

estimate 方法通过概率调整 Q 值,value,以反映任何可能阻碍动作的外部条件。

备注

可变数据

你可能会想知道为什么 QLData 类将值定义为变量而不是像最佳 Scala 编码实践[11:4]所建议的那样定义为值。原因是对于每个需要更新 value 变量的动作或状态转换,可以创建一个不可变类的实例。

Q-learning 模型的训练涉及遍历多个场景,每个场景被定义为多次迭代。例如,对一个具有 400 个状态的模型进行 10 个场景、每个场景 100 次迭代的训练可能会创建 1.6 亿个QLData实例。虽然并不十分优雅,但可变性减少了 JVM 垃圾收集器的负担。

接下来,让我们创建一个简单的模式或类,QLInput,以初始化与每个动作相关的奖励和概率,如下所示:

class QLInput(val from: Int, val to: Int, 
    val reward: Double, val probability: Double = 1.0)

前两个参数是此特定动作的from源状态和to目标状态的标识符。最后两个参数是动作完成时收集的reward及其probability。无需提供整个矩阵。默认情况下,动作具有 1 的奖励和 1 的概率。您只需要为具有更高奖励或更低概率的动作创建输入。

状态数和输入序列定义了QLPolicy类型的策略。它仅仅是一个数据容器,如下所示:

class QLPolicy(input: Seq[QLInput]) { 
  val numStates = Math.sqrt(input.size).toInt  //8

  val qlData = input.map(qlIn => 
new QLData(qlIn.reward, qlIn.prob))  //9

  def setQ(from: Int, to: Int, value: Double): Unit =
     qlData(from*numStates + to).value = value //10

  def Q(from: Int, to: Int): Double = 
    qlData(from*numStates + to).value  //11
  def EQ(from: Int, to: Int): Double = 
    qlData(from*numStates + to).estimate //12
  def R(from: Int, to: Int): Double = 
    qlData(from*numStates + to).reward //13
  def P(from: Int, to: Int): Double = 
    qlData(from*numStates + to).probability //14
}

状态数numStates是初始输入矩阵input元素数的平方根(第8行)。构造函数使用输入数据、rewardprobability初始化QLData类型的qlData矩阵(第9行)。QLPolicy类定义了更新(第10行)和检索(第11行)valueestimate(第12行)、reward(第13行)和probability(第14行)的快捷方法。

Q-learning 组件

QLearning类封装了 Q-learning 算法,更具体地说,是动作值更新方程。它是一个ETransform类型的数据转换,具有显式的QLConfig类型配置(第16行)(请参阅第二章中的单子数据转换部分,Hello World!):

class QLearningT //15
    extends ETransformQLConfig { //16

  type U = QLState[T] //17
  type V = QLState[T] //18

  val model: Option[QLModel] = train //19
  def train: Option[QLModel]
  def nextState(iSt: QLIndexedState[T]): QLIndexedState[T]

  override def |> : PartialFunction[U, Try[V]]
  …
}

构造函数接受以下参数(第15行):

  • config:这是算法的配置

  • qlSpace:这是搜索空间

  • qlPolicy:这是策略

model在类的实例化过程中生成或训练(请参阅附录 A 中的分类器设计模板部分,基本概念)(第19行)。Q-learning 算法被实现为一个显式的数据转换;因此,|>预测器的输入元素类型U和输出元素类型V被初始化为QLState(第1718行)。

Q-learning 算法的配置,QLConfig,指定了学习率,alpha,折扣率,gamma,一个场景的最大状态数(或长度),episodeLength,用于训练的场景数(或时代),numEpisodes,以及选择最佳策略所需的最低覆盖率,minCoverage,如下所示:

case class QLConfig(val alpha: Double, 
  val gamma: Double, 
  val episodeLength: Int, 
  val numEpisodes: Int, 
  val minCoverage: Double) extends Config

QLearning 类在其伴生对象中定义了两个构造函数,用于从状态输入矩阵或计算奖励和概率的函数初始化策略:

  • 客户端代码指定 input 函数以从输入数据初始化 Q-learning 算法的状态

  • 客户端代码指定生成每个动作或状态转换的 rewardprobability 的函数

QLearning 类的第一个构造函数除了配置和目标(第 20 行)外,还传递了状态的初始化 => Seq[QLInput]、与状态关联的 instances 引用序列以及 constraints 范围约束函数作为参数:

def applyT: QLearning[T]={

   new QLearningT,
     new QLPolicy(input))
}

第二个构造函数将输入数据 xt(第 21 行)、reward 函数(第 22 行)、probability 函数(第 23 行)以及与状态和 constraints 范围约束函数关联的 instances 引用序列作为参数传递:

def applyT => Double, //22
   probability: (Double, Double) => Double, //23
   instances: Seq[T].
   constraints: Option[Int =>List[Int]] =None): QLearning[T] ={

  val r = Range(0, xt.size)
  val input = new ArrayBuffer[QLInput] //24
  r.foreach(i => 
    r.foreach(j => 
     input.append(QLInput(i, j, reward(xt(i), xt(j)), 
          probability(xt(i), xt(j))))
    )
  )
  new QLearningT, 
    new QLPolicy(input))
}

奖励和概率矩阵用于初始化 input 状态(第 24 行)。

Q-learning 训练

让我们看看训练期间最佳策略的计算。首先,我们需要定义一个带有 bestPolicy 最佳策略(或路径)及其 coverage 参数的 QLModel 模型类:

class QLModel(val bestPolicy: QLPolicy, 
    val coverage: Double) extends Model

model 的创建包括执行多个剧集以提取最佳策略。训练在 train 方法中执行:每个剧集从随机选择的状态开始,如下面的代码所示:

def train: Option[QLModel] = Try {
  val completions = Range(0, config.numEpisodes)
     .map(epoch => if(train(-1)) 1 else 0).sum  //25
  completions.toDouble/config.numEpisodes //26
}
.map( coverage => {
  if(coverage > config.minCoverage) 
    Some(new QLModel(qlPolicy, coverage))  //27
  else None
}).get

train 方法遍历从随机选择的状态开始生成最佳策略的过程,共 config.numEpisodes 次(第 25 行)。状态 coverage 被计算为搜索结束于目标状态的百分比(第 26 行)。只有当覆盖率超过配置中指定的阈值值 config.minAccuracy 时,训练才成功。

注意

模型的质量

实现使用准确性来衡量模型或最佳策略的质量。由于没有假阳性,F1 测量值(参考第二章的 评估模型 部分[part0165.xhtml#aid-4TBCQ2 "Chapter 2. Hello World!"],Hello World!)不适用。

train(state0: Int) 方法在每个剧集(或时代)中执行繁重的工作。如果 state0 小于 0,它将通过选择 state0 初始状态或一个新的随机生成器 r 来触发搜索,如下面的代码所示:

case class QLIndexedStateT

QLIndexedState 工具类跟踪在剧集或时代中特定迭代 iterstate

def train(state0: Int): Boolean = {
  @tailrec
  def search(iSt: QLIndexedState[T]): QLIndexedState[T]

    val finalState = search(
       QLIndexedState(qlSpace.init(state0), 0)
    )
  if( finalState.index == -1) false //28
  else qlSpace.isGoal(finalState.state) //29
}

从预定义或随机的 state0 搜索目标状态(s)的实现是 Scala 尾递归的教科书式实现。递归搜索要么在没有更多状态要考虑时结束(第 28 行),要么达到目标状态(第 29 行)。

尾递归救命

尾递归是将操作应用于集合中每个项目的非常有效的结构 [11:5]。它优化了递归期间函数栈帧的管理。注解触发编译器优化函数调用的条件验证,如下所示:

@tailrec
def search(iSt: QLIndexedState[T]): QLIndexedState[T] = {
  val states = qlSpace.nextStates(iSt.state) //30 

  if( states.isEmpty || iSt.iter >= config.episodeLength) //31 
    QLIndexedState(iSt.state, -1)

  else {
    val state = states.maxBy(s => 
       qlPolicy.R(iSt.state.id, s.id)) //32
    if( qlSpace.isGoal(state) ) 
       QLIndexedState(state, iSt.iter)  //33

    else {
      val fromId = iSt.state.id
      val r = qlPolicy.R(fromId, state.id)   
      val q = qlPolicy.Q(fromId, state.id) //34

      val nq = q + config.alpha*(r + 
         config.gamma * qlSpace.maxQ(state, qlPolicy)-q) //35
      qlPolicy.setQ(fromId, state.id,  nq) //36
      search(QLIndexedState(state, iSt.iter+1))
    }
  }
}

让我们深入探讨 Q 动作值更新方程的实现。search 方法实现了每个递归的 M5 数学表达式。

递归使用 QLIndexedState 实用类(状态、场景中的迭代次数)作为参数。首先,递归调用 QLSpacenextStates 方法(第 30 行)以检索通过其动作与 st 当前状态相关联的所有状态,如下所示:

def nextStates(st: QLState[T]): Seq[QLState[T]] = 
  if( st.actions.isEmpty )
    Seq.empty[QLState[T]]
  else
    st.actions.map(ac => statesMap.get(ac.to) )

如果达到场景长度(访问过的最大状态数)、达到 goal 或没有进一步的状态可以转换到(第 31 行),则搜索完成并返回当前 state。否则,递归计算从当前策略生成更高奖励 R 的状态(第 32 行)。如果它是目标状态之一,递归返回具有最高奖励的状态(第 33 行)。该方法从策略中检索当前的 q 动作值和 r 奖励矩阵(第 34 行),然后应用方程更新动作值(第 35 行)。该方法使用新的值 nq 更新动作值 Q(第 36 行)。

动作值更新方程需要计算与当前状态关联的最大动作值,这由 QLSpace 类的 maxQ 方法执行:

def maxQ(state: QLState[T], policy: QLPolicy): Double = {
   val best = states.filter( _ != state)  //37
      .maxBy(st => policy.EQ(state.id, st.id))  //38
   policy.EQ(state.id, best.id)
}

maxQ 方法过滤掉当前状态(第 37 行),然后提取最佳状态,该状态最大化策略(第 38 行)。

注意

可达目标

算法不需要每个场景都达到目标状态。毕竟,没有任何保证可以从任何随机选择的状态达到目标。算法的一个约束是在场景内转换状态时遵循奖励的正梯度。训练的目标是从任何给定的初始状态计算最佳策略或状态序列。您负责验证从训练集中提取的模型或最佳策略,无论每个场景是否达到目标状态。

验证

商业应用可能需要关于状态转换、奖励、概率和 Q 值矩阵的多种验证机制。

一个关键验证是验证用户定义的 constraints 函数不会在 Q-learning 的搜索或训练中造成死胡同。constraints 函数建立从给定状态通过动作可以访问的状态列表。如果约束过于严格,一些可能的搜索路径可能无法达到目标状态。以下是 constraints 函数的简单验证:

def validateConstraints(numStates: Int, 
    constraints: Int => List[Int]): Boolean = 
  Range(0, numStates).exists( constraints(_).isEmpty )

预测

QLearning类的最后一个功能是使用训练期间创建的模型进行预测。|>方法从给定的状态state0预测最佳状态转换(或动作):

override def |> : PartialFunction[U, Try[V]] = {
  case state0: U if(isModel) =>  Try {
    if(state0.isGoal) state0  //39
    else nextState(QLIndexedStateT).state) //40
  }
}

|>数据转换如果输入状态state0是目标(第 39 行),则返回自身;否则,使用另一个尾递归计算最佳结果nextState,如下所示:

@tailrec
def nextState(iSt: QLIndexedState[T]): QLIndexedState[T] =  {
  val states = qlSpace.nextStates(iSt.state) //41

  if( states.isEmpty || iSt.iter >=config.episodeLength) 
     iSt  //42
  else {
    val fromId = iSt.state.id
    val qState = states.maxBy(s =>  //43
       model.map(_.bestPolicy.R(fromId, s.id)).getOrElse(-1.0))
    nextState(QLIndexedStateT) //44
  }
}

nextState方法执行以下调用序列:

  1. 从当前状态iSt.state(第 41 行)检索可以过渡到的合格状态。

  2. 如果没有更多状态或方法在最大允许迭代次数config.episodeLength(第 42 行)内收敛,则返回状态。

  3. 提取具有最有利策略的状态qState(第 43 行)。

  4. 增加迭代计数器iSt.iter(第 44 行)。

注意

退出条件

当没有更多状态可用或超过场景中的最大迭代次数时,预测结束。您可以定义一个更复杂的退出条件。挑战在于除了时间差分误差外,没有明确的错误或损失变量/函数可以使用。

如果在训练期间无法创建模型,|>预测方法返回最佳可能状态或None

使用 Q 学习的期权交易

Q 学习算法在许多金融和市场监管应用中得到了使用[11:6]。让我们考虑在给定市场条件和交易数据的情况下,计算某些类型期权最佳交易策略的问题。

芝加哥期权交易所CBOE)提供了一个优秀的在线教程,关于期权[11:7]。期权是一种合同,赋予买方在特定价格上或之前在特定日期购买或出售基础资产的权利,但没有义务(参考附录 A 下Finances 101中的期权交易部分,基本概念)。存在几种期权定价模型,其中 Black-Scholes 随机偏微分方程是最为人所知的[11:8]。

本练习的目的是根据从到期时间、证券价格和波动性中得出的当前观察到的特征,预测 N 天后的证券期权价格。让我们专注于给定证券 IBM 的看涨期权。以下图表显示了 2014 年 5 月的 IBM 股票及其衍生看涨期权(行权价为 190 美元)的每日价格:

使用 Q 学习的期权交易

IBM 股票和 2013 年 5 月至 10 月期间的 190 美元行权价期权定价

期权的价格取决于以下参数:

  • 期权的到期时间(时间衰减)

  • 基础证券的价格

  • 基础资产的回报波动性

定价模型通常不考虑基础证券交易量的变化。因此,将其包含在我们的模型中会非常有趣。让我们使用以下四个归一化特征来定义期权的状态:

  • 时间衰减 (timeToExp): 这是在[0, 1]范围内归一化后的到期时间。

  • 相对波动率 (volatility): 这是交易期间基础证券价格的相对变化。它与 Black-Scholes 模型中定义的更复杂的波动率不同,例如。

  • 相对于成交量的波动率 (vltyByVol): 这是调整了其交易量的证券价格的相对波动率。

  • 当前价格与执行价格之间的相对差异 (priceToStrike): 这衡量的是价格与执行价格之间差异与执行价格的比率。

以下图表显示了 IBM 期权策略的四个标准化特征:

使用 Q-learning 进行期权交易

标准化相对股票价格波动率、相对于成交量的波动率,以及相对于执行价格的 IBM 股票价格

使用 Q-learning 实现期权交易策略的步骤如下:

  1. 描述期权的属性

  2. 定义函数逼近

  3. 指定状态转移的约束

OptionProperty

让我们选择N = 2作为预测未来的天数。任何长期预测都相当不可靠,因为它超出了离散马尔可夫模型的约束。因此,两天后的期权价格是奖励——利润或损失的值。

OptionProperty类封装了期权的四个属性(第45行)如下:

class OptionProperty(timeToExp: Double, volatility: Double, 
     vltyByVol: Double, priceToStrike: Double) { //45

  val toArray = ArrayDouble
}

注意

模块化设计

实现避免了通过子类化QLState类来定义我们的期权定价模型的特征。期权的状态是状态类的一个参数化prop参数。

OptionModel

OptionModel类是一个容器和工厂,用于期权属性。它通过访问前面引入的四个特征的数据源来创建propsList期权属性列表。它接受以下参数:

  • 证券的符号。

  • strikePrice期权的执行价格。

  • 数据来源,src

  • 最小时间衰减或到期时间,minTDecay。价外期权在到期时毫无价值,而价内期权在接近到期日时具有非常不同的价格行为(参见附录 A 中的期权交易部分,基本概念)。因此,在到期日前最后minTDecay个交易时段不用于模型的训练。

  • 步数(或桶数),nSteps。它用于逼近每个特征值。例如,四个步骤的逼近创建四个桶 [0, 25]、[25, 50]、[50, 75] 和 [75, 100]。

OptionModel 类的实现如下:

class OptionModel(symbol: String,  strikePrice: Double, 
   src: DataSource, minExpT: Int, nSteps: Int) {

  val propsList = (for {
    price <- src.get(adjClose)
    volatility <- src.get(volatility)
    nVolatility <- normalize(volatility)
    vltyByVol <- src.get(volatilityByVol)
    nVltyByVol <- normalize(vltyByVol)
    priceToStrike <- normalize(price.map(p => 
       (1.0 - strikePrice/p)))
  } yield {
    nVolatility.zipWithIndex./:(List[OptionProperty]()){ //46
      case(xs, (v,n)) => {
         val normDecay = (n + minExpT).toDouble/
            (price.size + minExpT) //47
         new OptionProperty(normDecay, v, nVltyByVol(n), 
           priceToStrike(n)) :: xs
      }
    }.drop(2).reverse  //48
   })
  .getOrElse(List.empty[OptionProperty].)

  def quantize(o: DblArray): Map[Array[Int], Double] 
}

工厂使用 zipWithIndex Scala 方法来表示交易会的索引(第 46 行)。所有特征值都在区间 [0, 1] 上归一化,包括 normDecay 期权的衰减时间(或到期时间)(第 47 行)。如果构造函数成功,则 OptionModel 类的实例化生成一个 OptionProperty 元素列表(第 48 行),否则生成一个空列表。

量化

期权的四个属性是连续值,归一化为概率 [0, 1]。Q-learning 算法中的状态是离散的,需要一种称为 函数逼近 的量化或分类;尽管函数逼近方案可能相当复杂 [11:9]。让我们满足于一个简单的线性分类,如下面的图所示:

量化

交易期权的状态量化

函数逼近定义了状态的数量。在这个例子中,将归一化值转换为三个区间或桶的函数逼近生成 3⁴ = 81 个状态或潜在的 3⁸-3⁴ = 6480 个动作!对于 l 个桶的函数逼近和 n 个特征,最大状态数为 l^n,最大动作数为 l(2n)-ln

注意

量化或函数逼近指南

设计用于逼近期权状态的函数必须解决以下两个相互冲突的要求:

  • 准确性要求精细的逼近

  • 有限的计算资源限制了状态的数量,因此逼近的级别

OptionModel 类的 quantize 方法将每个期权属性的特征归一化值转换为桶索引数组。它返回一个以桶索引数组为键的利润和损失映射,如下代码所示:

def quantize(o: DblArray): Map[Array[Int], Double] = {
  val mapper = new mutable.HashMap[Int, Array[Int]] //49
  val _acc = new NumericAccumulator[Int] //50

  val acc = propsList.view.map( _.toArray)
       .map( toArrayInt( _ )) //51
       .map(ar => { 
          val enc = encode(ar)  //52
          mapper.put(enc, ar)
          enc
       }).zip(o)
       ./:(_acc){ 
          case (acc, (t,y)) => { //53
          acc += (t, y)
          acc 
       }}
   acc.map {case (k, (v,w)) => (k, v/w) }  //54
     .map {case( k,v) => (mapper(k), v) }.toMap
}

该方法创建一个 mapper 实例来索引桶数组(第 49 行)。acc 累加器类型为 NumericAccumulator,扩展 Map[Int, (Int, Double)] 并计算每个桶中特征的出现次数和期权价格的增减总和(第 50 行)。toArrayInt 方法将每个期权属性的值(如 timeToExpvolatility 等)转换为适当的桶索引(第 51 行)。然后,索引数组被编码(第 52 行)以生成状态的 id 或索引。该方法通过更新累加器中的出现次数和期权交易会的总利润和损失来更新(第 53 行)。它最后通过平均每个桶的利润和损失来计算每个动作的奖励(第 54 行)。

使用视图生成OptionProperty列表以避免不必要的对象创建。

toArrayIntencode方法以及NumericAccumulator的源代码已记录并在网上提供。

将所有内容整合在一起

最后一个拼图是配置和执行 Q 学习算法在一种或多种证券(IBM)上的代码:

val STOCK_PRICES = "resources/data/chap11/IBM.csv"
val OPTION_PRICES = "resources/data/chap11/IBM_O.csv"
val QUANTIZER = 4
val src = DataSource(STOCK_PRICES, false, false, 1) //55

val model = for {
  option <- Try(createOptionModel(src)) //56
  oPrices <- DataSource(OPTION_PRICES, false).extract //57
   _model <- createModel(option, oPrices) //58
} yield _model

上述实现通过以下步骤创建 Q 学习模型:

  1. 通过实例化数据源src(第55行)提取 IBM 股票的历史价格。

  2. 创建一个option模型(第56行)。

  3. 提取期权调用$190 May 2014 的历史价格oPrices(第57行)。

  4. 创建具有预定义目标goalStr的模型_model(第58行)。

代码如下:

val STRIKE_PRICE = 190.0
val MIN_TIME_EXPIRATION = 6

def createOptionModel(src: DataSource): OptionModel = 
   new OptionModel("IBM", STRIKE_PRICE, src, 
       MIN_TIME_EXPIRATION, QUANTIZER)

让我们看看createModel方法,该方法接受期权定价模型optionoPrice期权的历史价格作为参数:

val LEARNING_RATE = 0.2
val DISCOUNT_RATE = 0.7
val MAX_EPISODE_LEN = 128
val NUM_EPISODES = 80

def createModel(option: OptionModel, oPrices: DblArray,
     alpha: Double, gamma: Double): Try[QLModel] = Try {

  val qPriceMap = option.quantize(oPrices) //59
  val numStates = qPriceMap.size

  val qPrice = qPriceMap.values.toVector //60
  val profit= zipWithShift(qPrice,1).map{case(x,y) => y -x}//61
  val maxProfitIndex = profit.zipWithIndex.maxBy(_._1)._2 //62

  val reward = (x: Double, y: Double) 
            => Math.exp(30.0*(y – x)) //63
  val probability = (x: Double, y: Double) => 
         if(y < 0.3*x) 0.0 else 1.0  //64

  if( !validateConstraints(profit.size, neighbors)) //65 
      throw new IllegalStateException(" ... ")

  val config = QLConfig(alpha, gamma, 
     MAX_EPISODE_LEN, NUM_EPISODES) //66
  val instances = qPriceMap.keySet.toSeq.drop(1)
  QLearning[Array[Int]](config, ArrayInt, 
     profit, reward, probability, 
     instances, Some(neighbors)).getModel //67
}

该方法量化了期权价格映射oPrices(第59行),提取了历史期权价格qPrice(第60行),计算了利润作为两个连续交易会话中期权价格的差异(第61行),并计算了具有最高利润的会话的索引maxProfitIndex(第62行)。具有maxProfitIndex索引的状态被选为目标。

输入矩阵通过使用rewardprobability函数自动生成。reward函数按利润比例奖励状态转换(第63行)。probability函数通过将概率值设置为0来惩罚损失y – x大于0.3x*的状态转换(第64行)。

注意

奖励和概率的初始化

在我们的示例中,奖励和概率矩阵是通过两个函数自动生成的。另一种方法是通过使用历史数据或合理的猜测来初始化这两个矩阵。

QLearning伴生对象的validateConstraints方法验证了neighbors约束函数,如验证部分所述(第65行)。

最后两个步骤包括为 Q 学习算法创建一个配置config(第66行),并通过实例化QLearning类并使用适当的参数(包括定义任何给定状态的邻近状态的neighbors方法)来训练模型,包括第67行的neighbors方法。neighbors方法在在线文档源代码中有描述。

注意

反目标状态

目标状态是具有最高分配奖励的状态。这是一种奖励良好表现的策略的启发式方法。然而,可以定义一个具有最高分配惩罚或最低分配奖励的反目标状态,以引导搜索远离某些条件。

评估

除了函数逼近之外,训练集的大小也会影响状态的数量。一个分布良好或大的训练集为每个由逼近创建的桶提供至少一个值。在这种情况下,训练集相当小,只有 34 个桶中有实际值。因此,状态的数量是 34。Q 学习模型的初始化生成了以下奖励矩阵:

评估

期权定价 Q 学习策略的奖励矩阵

图形化了从期权的盈亏计算出的奖励分布。xy 平面表示状态之间的动作。状态的 ID 列在 xy 轴上。z 轴衡量与每个动作相关的实际奖励值。

奖励反映了期权价格的波动。期权的价格比基础证券的价格波动性更高。

xy 奖励矩阵 R 分布相当不均匀。因此,我们选择一个较小的学习率 0.2 以减少先前状态对新状态的影响。折扣率 0.7 的值适应了状态数量有限的事实。没有必要使用长序列的状态来计算未来的折现奖励。策略的训练生成了以下 34x34 状态的动作-价值矩阵 Q,在第一个剧集之后:

评估

第一个剧集(时期)的 Q 动作-价值矩阵

在第一个剧集结束时,状态之间的动作-价值分布反映了状态到状态动作的奖励分布。第一个剧集由从初始随机选择的状态到目标状态的九个状态序列组成。动作-价值图与以下图中 20 个剧集之后的地图进行比较:

评估

最后一个剧集(时期)的 Q 动作-价值矩阵

在最后一个剧集结束时,动作-价值图显示了某些明显的模式。大多数奖励动作从大量状态(X 轴)过渡到较少的状态(Y 轴)。图表说明了以下关于小训练样本的问题:

  • 训练集的小规模迫使我们必须使用每个特征的近似表示。目的是增加大多数桶至少有一个数据点的概率。

  • 然而,松散的函数逼近或量化往往将相当不同的状态分组到同一个桶中。

  • 数值非常低的桶可能会错误地描述一个状态的一个属性或特征。

下一个测试是显示 Q 值日志(QLData.value)的配置文件,随着不同剧集或时期的递归搜索(或训练)进度。该测试使用学习率 α = 0.1 和折扣率 γ = 0.9

评估

Q 学习训练过程中不同纪元的日志(Q 值)配置文件。

前面的图表说明了这样一个事实:在训练过程中,每个配置文件的 Q 值与纪元的顺序无关。然而,配置文件的长度(或达到目标状态所需的迭代次数)取决于随机选择的初始状态,在本例中。

最后的测试包括评估学习率和折扣率对训练覆盖率的影响:

评估

训练覆盖率与学习率和折扣率的关系

随着学习率的增加,覆盖率(达到目标状态的时间段或纪元的百分比)会降低。这一结果证实了使用学习率小于 0.2 的一般规则。用于评估折扣率对覆盖率影响的类似测试结果并不明确。

强化学习的优缺点

强化学习算法非常适合以下问题:

  • 在线学习

  • 训练数据量小或不存在

  • 模型不存在或定义不明确。

  • 计算资源有限

然而,这些技术在以下情况下表现不佳:

  • 搜索空间(可能动作的数量)很大,因为维护状态、动作图和奖励矩阵变得具有挑战性。

  • 在可扩展性和性能方面,执行并不总是可预测的。

学习分类系统

J. Holland 在 30 多年前引入了学习分类系统LCS)的概念,作为进化计算的一个扩展[11:10]。

学习分类系统是一种基于规则的系统,具有并行处理规则、自适应生成新规则和测试新规则有效性的通用机制。

然而,这个概念仅在几年前开始引起计算机科学家的注意,随着原始概念的几个变体(包括扩展学习分类器XCS))的引入。学习分类系统之所以有趣,是因为它们结合了规则、强化学习和遗传算法。

备注

免责声明

仅为了信息目的,展示了扩展学习分类器的实现。验证 XCS 对已知和标记的规则群体是非常有意义的。仅展示源代码片段是为了说明 XCS 算法的不同组件。

LCS 简介

学习分类系统融合了强化学习、基于规则的策略和进化计算的概念。这一独特的学习算法类别代表了以下研究领域的融合[11:11]:

  • 强化学习

  • 遗传算法和进化计算

  • 监督学习

  • 基于规则的知识编码

让我们看看以下图示:

LCS 简介

学习分类系统所需科学学科的图示

学习分类系统是复杂自适应系统的一个例子。一个学习分类系统具有以下四个组件:

  • 一个分类器或规则的种群:这个种群随时间演变。在某些情况下,领域专家创建了一组原始规则(核心知识)。在其他情况下,规则在学习分类系统执行之前随机生成。

  • 基于遗传算法的发现引擎:这个组件从现有种群中生成新的分类器或规则。该组件也被称为规则发现模块。这些规则依赖于前一章中引入的同一生物进化模式。规则被编码为字符串或位字符串,以表示条件(谓词)和动作。

  • 一个性能或评估函数:这个函数衡量了最适应分类器或策略的动作的积极或消极影响。

  • 一个强化学习组件:这个组件奖励或惩罚对动作有贡献的分类器,正如前节所述。对提高系统性能的动作有贡献的规则被奖励,而对降低系统性能的规则被惩罚。这个组件也被称为信用分配模块。

为什么选择 LCS?

学习分类系统特别适合于环境不断变化的问题,它们是学习策略和构建、维护知识库的进化方法的组合[11:12]。

仅使用监督学习方法在大数据集上可能有效,但它们需要大量的标记数据或减少特征集以避免过拟合。在环境不断变化的情况下,这些限制可能不切实际。

在过去的 20 年里,引入了许多学习分类系统的变体,它们属于以下两个类别:

  • 从正确预测中计算准确度并将发现应用于这些正确类别的子集的系统。它们结合了监督学习的元素来约束分类器的种群。这些系统众所周知遵循匹兹堡方法

  • 探索所有分类器并应用规则准确度到规则遗传选择的系统。每个个体分类器都是一个规则。这些系统众所周知遵循密歇根方法

本节的其余部分致力于第二种类型的学习分类器——更具体地说,是扩展学习分类系统。在 LCS 的背景下,术语分类器指的是系统生成的谓词或规则。从现在开始,术语规则将取代分类器,以避免与更常见的分类定义混淆。

术语

每个研究领域都有自己的术语,LCS 也不例外。LCS 的术语包括以下术语:

  • 环境:这是强化学习上下文中的环境变量。

  • 代理:在强化学习中使用的代理。

  • 谓词:使用格式 变量-运算符-值 的子句或事实,通常实现为 (运算符,变量值);例如,Temperature-exceeds-87F('Temperature', 87F)Hard drive–failed('Status hard drive', FAILED),等等。它被编码为基因以便由遗传算法处理。

  • 复合谓词:这是由几个谓词和布尔逻辑运算符组成的组合,通常实现为一个逻辑树(例如,((谓词 1 AND 谓词 2) OR 谓词 3) 实现为 OR (AND (谓词 1, 谓词 2), 谓词 3)。它使用染色体表示。

  • 动作:这是一种通过修改一个或多个参数的值来改变环境的机制,使用格式 (动作类型,目标);例如,更改恒温器设置更换硬盘,等等。

  • 规则:这是一种使用格式 IF 复合谓词 THEN 动作序列 的形式一阶逻辑公式;例如,IF 黄金价格 < $1140 THEN 卖出石油和天然气生产公司的股票

  • 分类器:这是在 LCS(学习分类系统)上下文中的一个规则。

  • 规则适应度或分数:这与遗传算法中适应度或分数的定义相同。在 LCS 的上下文中,它是规则在环境变化时被调用和触发的概率。

  • 传感器:这些是代理监控的环境变量;例如,温度和硬盘状态。

  • 输入数据流:这是由传感器生成数据的流动。它通常与在线训练相关联。

  • 规则匹配:这是一种将谓词或复合谓词与传感器匹配的机制。

  • 覆盖:这是创建新规则以匹配环境中新的条件(传感器)的过程。它通过使用随机生成器或变异现有规则来生成规则。

  • 预测器:这是一种算法,用于在匹配规则集中找到出现次数最多的动作。

扩展学习分类系统

与强化学习类似,XCS 算法有一个探索阶段和一个利用阶段。利用过程包括利用现有规则以有利或奖励的方式影响目标环境:

扩展学习分类系统

XCS 算法的利用组件

以下列表描述了每个编号块:

  1. 传感器从系统中获取新的数据或事件。

  2. 从当前种群中提取与条件匹配输入事件的规则。

  3. 如果在现有种群中没有找到匹配项,则会创建一条新规则。这个过程被称为覆盖。

  4. 根据其适应性值对选定的规则进行排名,并使用预测结果最高的规则来触发行动。

探索组件的目的是增加规则库,作为编码这些规则的染色体种群。

扩展学习分类器系统

XCS 算法的探索组件

以下列表描述了框图中的每个编号块:

  1. 一旦执行了行动,系统会奖励执行了该行动的规则。强化学习模块将这些规则赋予信用。

  2. 奖励用于更新规则适应性,将进化约束应用于现有种群。

  3. 遗传算法通过交叉和变异等操作更新现有的分类器/规则种群。

XCS 组件

本节描述了 XCS 的关键类。该实现利用了遗传算法和强化学习现有的设计。通过具体应用,更容易理解 XCS 算法的内部工作原理。

投资组合管理的应用

投资组合管理和交易从扩展学习分类器的应用中受益[11:13]。用例是管理在不断变化的金融环境中交易所交易基金的投资组合。与股票不同,交易所交易基金代表特定行业的一组股票或整个金融市场。因此,这些 ETF 的价格受以下宏观经济变化的影响:

  • 国内生产总值

  • 通货膨胀

  • 地缘政治事件

  • 利率

让我们选择 10 年期国债收益率为宏观经济状况的代理值,以简化问题。

投资组合必须不断调整,以应对任何影响投资组合总价值的环境或市场条件的变化,这可以通过以下表格来完成:

XCS 组件 投资组合管理
环境 这是根据其组成、总价值和 10 年期国债收益率定义的证券投资组合
行动 这是投资组合组成的改变
奖励 这是投资组合总价值的盈亏
输入数据流 这是股票和债券价格报价的流
传感器 这是关于投资组合中证券的交易信息,如价格、成交量、波动率、收益率和-10 年国债收益率
谓词 这是投资组合组成的改变
行动 通过买卖证券来重新平衡投资组合
规则 这是将交易数据与投资组合重新平衡相关联

第一步是创建一组关于投资组合的初始规则。这个初始集可以是随机创建的,就像遗传算法的初始种群一样,或者由领域专家定义。

注意

XCS 初始种群

规则或分类器通过进化来定义和/或细化。因此,没有绝对的要求需要领域专家设置一个全面的知识库。实际上,规则可以在训练阶段的开始时随机生成。然而,用几个相关规则初始化 XCS 初始种群可以提高算法快速收敛的概率。

欢迎您尽可能多地用相关且财务稳健的交易规则初始化规则种群。随着时间的推移,XCS 算法的执行将证实初始规则是否确实合适。以下图表描述了 XCS 算法应用于 ETF 投资组合(如 VWO、TLT、IWC 等)组成的实例,以下为以下组件:

  • 交易规则种群

  • 匹配规则并计算预测的算法

  • 提取动作集的算法

  • Q 学习模块为选定的规则分配信用或奖励

  • 用于进化规则种群的遗传算法

让我们看看以下图表:

应用于投资组合管理

XCS 算法优化投资组合分配的概述

代理通过匹配现有规则之一来响应投资组合中 ETF 分配的变化。

让我们从地面开始构建 XCS 代理。

XCS 核心数据

XCS 代理操作三种类型的数据:

  • Signal:这是交易信号。

  • XcsAction:这是对环境的作用。它继承自遗传算法中定义的Gene

  • XcsSensor:这是传感器或环境中的数据。

在第十章的交易信号部分介绍了Gene类,用于评估遗传算法。代理创建、修改和删除动作。将这些动作定义为可变的基因是有意义的,如下所示:

class XcsAction(sensorId: String, target: Double)
    (implicit quantize: Quantization, encoding: Encoding) //1
    extends Gene(sensorId, target, EQUAL)

XCSAction量化并编码为Gene必须显式声明(行1)。XcsAction类具有sensorId传感器的标识符和目标值作为参数。例如,将投资组合中 ETF VWO 的持股数量增加到 80 的动作定义如下:

val vwoTo80 = new XcsAction("VWO", 80.0)

在此方案中,唯一允许的操作类型是使用EQUAL运算符设置值。你可以创建支持其他运算符的动作,例如用于增加现有值的+=运算符。这些运算符需要实现运算符特性,如第十章遗传算法中的交易运算符部分所述,遗传算法

最后,XcsSensor类封装了变量的sensorId标识符和传感器的value,如下所示:

case class XcsSensor(val sensorId: String, val value: Double) 
val new10ytb = XcsSensor("10yTBYield", 2.76)

注意

设置器和获取器

在这个简化的场景中,传感器从环境变量中检索新的值。动作将新值设置到环境变量中。你可以将传感器视为环境类的获取方法,将动作视为带有变量/传感器 ID 和值的参数的设置方法。

XCS 规则

下一步是定义一个XcsRule类型的规则,作为一对基因:一个signal和一个action,如下面的代码所示:

class XcsRule(val signal: Signal, val action: XcsAction)

规则:r1:IF(10 年 TB 收益率 > 2.84%) THEN 将 VWO 股份减少至 240 的实现方式如下:

val signal = new Signal("10ytb", 2.84, GREATER_THAN) 
val action = new XcsAction("vwo", 240) 
val r1 = new XcsRule(signal, action)

智能体使用 2 位来表示操作符和 32 位来表示值,将规则编码为染色体,如下面的图所示:

XCS 规则

在此实现中,无需对动作类型进行编码,因为智能体仅使用一种类型的动作——设置。复杂动作需要对其类型进行编码。

注意

知识编码

此示例使用非常简单的规则,其中只有一个谓词作为条件。现实世界领域的知识通常使用具有多个子句的复杂规则进行编码。强烈建议将复杂规则分解为多个基本分类器的规则。

将规则与新的传感器匹配包括将传感器与信号匹配。算法将新的new10ytb传感器(第2行)与当前种群中s10ytb1(第3行)和s10ytb2(第4行)规则使用的信号进行匹配,这些规则使用相同的传感器或10ytb变量,如下所示:

val new10ytb = new XcsSensor("10ytb", 2.76) //2
val s10ytb1 = Signal("10ytb", 2.5, GREATER_THAN)  //3
val s10ytb2 = Signal("10ytb", 2.2, LESS_THAN) //4

在这种情况下,智能体在现有种群中选择r23规则而不是r34。然后智能体将act12动作添加到可能的动作列表中。智能体列出所有与r23r11r46传感器匹配的规则,如下面的代码所示:

val r23: XcsRule(s10yTB1, act12) //5
val r11: XcsRule(s10yTB6, act6) 
val r46: XcsRule(s10yTB7, act12) //6

具有最多引用的动作act12(第5行和第6行)被执行。Q 学习算法计算执行所选r23r46规则后产生的利润或损失奖励。智能体使用奖励来调整r23r46的适应性,在下一个繁殖周期之前的遗传选择中。这两个规则将达到并保持在种群规则的前列,直到通过交叉和变异修改的新遗传规则或通过覆盖创建的规则触发了对环境更有利的动作。

覆盖

覆盖阶段的目的是在没有规则与输入或传感器匹配的情况下生成新规则。XcsCover单例的cover方法根据传感器和现有动作集生成一个新的XcsRule实例,如下所示:

val MAX_NUM_ACTIONS = 2048

def cover(sensor: XcsSensor, actions: List[XcsAction])
   (implicit quant: Quantization, encoding: Encoding): 
   List[XcsRule] = 

  actions./:(List[XcsRule]()) ((xs, act) => {
    val signal = Signal(sensor.id, sensor.value, 
       new SOperator(Random.nextInt(Signal.numOperators)))
    new XcsRule(signal, XcsAction(act, Random)) :: xs
  })
}

你可能会想知道为什么 cover 方法使用一组动作作为参数,因为覆盖包括创建新的动作。该方法通过突变(^ 操作符)现有动作来创建一个新动作,而不是使用随机生成器。这是将动作定义为基因的一个优点。XcsAction 的一个构造函数执行突变,如下所示:

def apply(action: XcsAction, r: Random): XcsAction = 
   (action ^ r.nextInt(XCSACTION_SIZE))

操作符 r 类型的索引是一个在区间 [0, 3] 内的随机值,因为一个信号使用四种类型的操作符:None><=

一个实现示例

Xcs 类有以下用途:

  • gaSolver:这是选择和生成基因修改规则的工具

  • qlLearner:这是奖励和评分规则

  • Xcs:这些是匹配、覆盖和生成动作的规则

扩展学习分类器是一个 ETransform 类型的数据转换,具有显式配置的 XcsConfig 类型(第 8 行)(参考第 2 章的 单调数据转换 部分,第 2 章处理数据的方法,你好,世界!):

class Xcs(config: XcsConfig, 
    population: Population[Signal], 
    score: Chromosome[Signal]=> Unit, 
    input: Array[QLInput])  //7
      extends ETransformXcsConfig { //8

  type U = XcsSensor   //9
  type V = List[XcsAction]   //10

   val solver = GASolverSignal 
   val features = population.chromosomes.toSeq
   val qLearner = QLearning[Chromosome[Signal]]( //11
      config.qlConfig, extractGoals(input), input, features)
   override def |> : PartialFunction[U, Try[V]]
   ...
}

XCS 算法使用配置 config、初始规则集 population、适应度函数 score 和 Q-learning 策略的 input 来初始化,为 qlLearner 生成奖励矩阵(第 7 行)。作为一个显式数据转换,输入元素的 U 类型以及输出到 |> 预测器的 V 类型被初始化为 XcsSensor(第 9 行)和 List[XcsAction](第 10 行)。

目标和状态数量是从 Q-learning 算法策略的输入中提取的。

在这个实现中,solver 泛型算法是可变的。它与 Xcs 容器类一起实例化。Q-learning 算法采用相同的设计,作为任何分类器,是不可变的。Q-learning 的模型是奖励规则的最好可能策略。状态数量或奖励方案的变化需要学习器的新实例。

学习分类系统的益处和局限性

学习分类系统,特别是 XCS,拥有许多承诺,如下列所示:

  • 它们允许非科学家和领域专家使用熟悉的布尔构造和推理,如谓词和规则来描述知识

  • 它们通过区分对知识库的探索和利用需求,为分析师提供知识库及其覆盖范围的概述

然而,科学界在认识到这些技术的优点方面进展缓慢。学习分类系统的更广泛采用受到以下因素的影响:

  • 探索和利用阶段使用的参数数量众多,增加了算法的纯粹复杂性。

  • 学习分类系统的竞争变体太多

  • 没有一个明确的统一理论来验证进化策略或规则的概念。毕竟,这些算法是独立技术的融合。许多学习分类系统变体的准确性和性能取决于每个组件以及组件之间的交互。

  • 可扩展性和性能方面不一定可预测的执行。

总结

软件工程社区有时会忽视强化学习算法。让我们希望这一章能够为以下问题提供充分的答案:

  • 什么是强化学习?

  • 哪些不同类型的算法可以被认为是强化学习?

  • 我们如何在 Scala 中实现 Q 学习算法?

  • 我们如何将 Q 学习应用于期权交易的优化?

  • 使用强化学习的优缺点是什么?

  • 学习分类系统是什么?

  • XCS 算法的关键组件是什么?

  • 学习分类系统的潜力和局限性是什么?

这标志着最后一类学习技术的介绍结束。我们周围的数据量不断增加,需要数据处理和机器学习算法具有高度的可扩展性。这是下一章和最后一章的主题。

第十二章:可扩展框架

社交网络、互动媒体和深度分析的出现导致每天处理的数据量激增。对于数据科学家来说,这不再仅仅是找到最合适和最准确的算法来挖掘数据的问题;它还涉及到利用多核 CPU 架构和分布式计算框架及时解决问题。毕竟,如果模型不可扩展,数据挖掘应用的价值有多大?

Scala 开发者有许多选项可用于为非常大的数据集构建分类和回归应用。本章涵盖了 Scala 并行集合、Actor 模型、Akka 框架和 Apache Spark 内存集群。本章涵盖了以下主题:

  • Scala 并行集合简介

  • 在多核 CPU 上并行集合的性能评估

  • Actor 模型和反应式系统

  • 使用 Akka 进行集群和可靠的分布式计算

  • 使用 Akka 路由器设计的计算工作流程

  • Apache Spark 聚类及其设计原则简介

  • 使用 Spark MLlib 进行聚类

  • Spark 的相对性能调整和评估

  • Apache Spark 框架的优缺点

概述

不同堆叠的框架和库提供了分布式和并发处理的支持。Scala 并发和并行集合类利用 Java 虚拟机的线程能力。Akka.io 实现了一个可靠的动作模型,最初作为 Scala 标准库的一部分引入。Akka 框架支持远程 Actor、路由、负载均衡协议、调度器、集群、事件和可配置的邮箱管理。此框架还提供了对不同传输模式、监督策略和类型化 Actor 的支持。Apache Spark 的弹性分布式数据集利用了高级序列化、缓存和分区能力,利用 Scala 和 Akka 库。

以下栈表示说明了框架之间的相互依赖关系:

概述

使用 Scala 的可扩展框架的栈表示

每一层都为前一层添加新的功能,以增加可扩展性。Java 虚拟机作为一个进程在单个主机上运行。Scala 并发类通过利用多核 CPU 能力,无需编写多线程应用程序,有效地部署应用程序。Akka 将 Actor 模式扩展到具有高级消息和路由选项的集群。最后,Apache Spark 通过其弹性分布式数据集和内存持久性,利用 Scala 的高阶集合方法和 Akka 的 Actor 模型实现,提供具有更好性能和可靠性的大规模数据处理系统。

Scala

Scala 标准库提供了一套丰富的工具,例如并行集合和并发类,用于扩展数值计算应用。尽管这些工具在处理中等规模的数据集时非常有效,但遗憾的是,开发者往往更倾向于使用更复杂的框架。

对象创建

尽管代码优化和内存管理超出了本章的范围,但记住采取一些简单的步骤可以提高应用程序的可扩展性是值得的。使用 Scala 处理大型数据集时最令人沮丧的挑战之一是创建大量对象和垃圾收集器的负载。

以下是一些补救措施的清单:

  • 在迭代函数中使用可变实例限制不必要的对象重复

  • 使用延迟值和 Stream 类按需创建对象

  • 利用高效的集合,如布隆过滤器跳表

  • 运行 javap 解析 JVM 生成的字节码

一些问题需要预处理和训练非常大的数据集,导致 JVM 的内存消耗显著增加。流是类似列表的集合,其中元素是延迟实例化或计算的。流与视图具有相同的延迟计算和内存分配的目标。

让我们考虑机器学习中损失函数的计算。DataPoint类型的观察定义为特征向量x和标记或预期值y

case class DataPoint(x: DblVector, y: Double)

我们可以创建一个损失函数,LossFunction,在内存有限的平台上处理一个非常大的数据集。负责损失或误差最小化的优化器在每次迭代或递归时调用损失函数,如下面的图示所示:

Streams

Scala 流分配和释放的说明

LossFunction类的构造函数有三个参数(第2行):

  • 损失函数对每个数据点的计算f

  • 模型的weights

  • 整个流的尺寸dataSize

代码如下:

type StreamLike = WeakReference[Stream[DataPoint]] //1
class LossFunction(
    f: (DblVector, DblVector) => Double,
    weights: DblVector, 
    dataSize: Int) {  //2

  var nElements = 0
  def compute(stream: () => StreamLike): Double = 
      compute(stream().get, 0.0)  //3

  def _loss(xs: List[DataPoint]): Double = xs.map(
    dp => dp.y - f(weights, dp.x)).map( sqr(_)).sum //4
}

流的损失函数实现为compute尾递归(第3行)。递归方法更新流的引用。流的引用类型是WeakReference(第1行),因此垃圾收集器可以回收与已计算损失的切片相关的内存。在这个例子中,损失函数被计算为平方误差之和(第4行)。

compute方法管理流切片的分配和释放:

@tailrec
def compute(stream: Stream[DataPoint], loss: Double): Double = {
  if( nElements >= dataSize)  loss
  else {
    val step = if(nElements + STEP > dataSize) 
             dataSize - nElements else STEP
    nElements += step
    val newLoss = _loss(stream.take(step).toList) //5
    compute( stream.drop(STEP), loss + newLoss ) //6
  }
 }

数据集处理分为两个步骤:

  • 驱动程序分配(即take)观察流的一个切片,然后计算切片中所有观察的累积损失(第5行)

  • 一旦完成切片的损失计算,分配给弱引用的内存被释放(即drop)(第6行)

注意

弱引用的替代方案

为了使流强制垃圾收集器回收与每个观察切片相关的内存块,有如下替代弱引用的方法:

  • 将流引用定义为def

  • 将引用包装到方法中;当包装方法返回时,引用对垃圾收集器是可访问的

  • 使用List迭代器

在整个流执行损失函数的平均分配内存是分配单个切片所需的内存。

并行集合

Scala 标准库包括并行集合,其目的是屏蔽开发人员对并发线程执行和竞态条件的复杂性。并行集合是将并发结构封装到更高抽象级别的一种非常方便的方法[12:1]。

在 Scala 中有两种创建并行集合的方法,如下所示:

  • 使用par方法将现有集合转换为具有相同语义的并行集合;例如,List[T].par: ParSeq[T]Array[T].par: ParArray[T]Map[K,V].par: ParMap[K,V]等等

  • 使用来自collection.parallelparallelimmutableparallel.mutable包的集合类;例如,ParArrayParMapParSeqParVector等等

处理并行集合

并行集合本身并不适合并发处理,直到分配给它一个线程池和任务调度器。幸运的是,Scala 的并行和并发包为开发者提供了一个强大的工具箱,可以将集合的分区或段映射到运行在不同 CPU 核心上的任务。组件如下:

  • TaskSupport:这个特质继承了通用的Tasks特质。它负责在并行集合上调度操作。有三种具体的实现。

  • ThreadPoolTaskSupport:它使用 JVM 旧版本中的线程池。

  • ExecutionContextTaskSupport:它使用ExecutorService,将任务的管理委托给线程池或ForkJoinTasks池。

  • ForkJoinTaskSupport:它使用 Java SDK 1.6 中引入的java.util.concurrent.FortJoinPool类型的 fork-join 池。在 Java 中,fork-join 池ExecutorService的一个实例,它试图运行当前任务及其任何子任务。它执行轻量级线程的ForkJoinTask实例。

以下示例实现了使用并行向量和ForkJoinTaskSupport生成随机指数值。

val rand = new ParVector[Float]
Range(0,MAX).foreach(n => rand.updated(n, n*Random.nextFloat))//1
rand.tasksupport = new ForkJoinTaskSupport(new ForkJoinPool(16)) 
val randExp = vec.map( Math.exp(_) ) //2

随机概率的rand并行向量由主任务(第1行)创建和初始化,但将其转换为randExp指数值的向量是由 16 个并发任务池(第2行)执行的。

注意

保持元素顺序

使用迭代器遍历并行集合的操作会保留集合元素的原始顺序。没有迭代器的方法,如foreachmap,不能保证处理元素的顺序会被保留。

基准框架

并行集合的主要目的是通过并发来提高执行性能。第一步是选择一个现有的基准或创建我们自己的基准。

注意

Scala 库基准

Scala 标准库中有一个用于通过命令行测试的testing.Benchmark特质[12:2]。你需要做的只是将你的函数或代码插入到run方法中:

object test with Benchmark { def run { /* … /* }

让我们创建一个参数化的ParBenchmark类来评估并行集合上操作的性能:

abstract class ParBenchmarkU { 
  def map(f: U => U)(nTasks: Int): Double  //1
  def filter(f: U => Boolean)(nTasks: Int): Double //2
  def timing(g: Int => Unit ): Long
}

用户必须为并行集合的map(第1行)和filter(第2行)操作提供数据转换f,以及并发任务的数量nTaskstiming方法收集给定操作g在并行集合上times次执行的时间:

def timing(g: Int => Unit ): Long = {   
  var startTime = System.currentTimeMillis
  Range(0, times).foreach(g)
  System.currentTimeMillis - startTime
}

让我们定义用于并行数组的映射和归约操作,其基准定义如下:

class ParArrayBenchmarkU extends ParBenchmarkT

基准构造函数的第一个参数是 Scala 标准库的默认数组(第3行)。第二个参数是与数组关联的并行数据结构(或类)(第4行)。

让我们比较 ParArrayBenchmarkmapreduce 方法在并行化数组和默认数组上的行为,如下所示:

def map(f: U => U)(nTasks: Int): Unit = {
  val pool = new ForkJoinPool(nTasks)
  v.tasksupport = new ForkJoinTaskSupport(pool)
  val duration = timing(_ => u.map(f)).toDouble  //5
  val ratio = timing( _ => v.map(f))/duration  //6
  show(s"$numTasks, $ratio")
}

用户必须定义映射函数 f 和可用于在数组 u(第 5 行)及其并行对应数组 v(第 6 行)上执行映射转换的并发任务数 nTasksreduce 方法遵循相同的设计,如下面的代码所示:

def reduce(f: (U,U) => U)(nTasks: Int): Unit = { 
  val pool = new ForkJoinPool(nTasks)
  v.tasksupport = new ForkJoinTaskSuppor(pool)   
  val duration = timing(_ => u.reduceLeft(f)).toDouble //7
  val ratio = timing( _ => v.reduceLeft(f) )/duration  //8
  show(s"$numTasks, $ratio")
}

用户定义的函数 f 用于在数组 u(第 7 行)及其并行对应数组 v(第 8 行)上执行减少操作。

同样的模板可以用于其他高级 Scala 方法,例如 filter

每个操作的绝对时间完全取决于环境。记录并行化数组上操作的执行时间与单线程数组上执行时间的比率要更有用得多。

用于评估 ParHashMap 的基准类 ParMapBenchmarkParArray 的基准类似,如下面的代码所示:

class ParMapBenchmarkU extends ParBenchmarkT

例如,ParMapBenchmarkfilter 方法评估了并行映射 v 相对于单线程映射 u 的性能。它将过滤条件应用于每个映射的值,如下所示:

def filter(f: U => Boolean)(nTasks: Int): Unit = {
  val pool = new ForkJoinPool(nTasks)
  v.tasksupport = new ForkJoinTaskSupport(pool)   
  val duration = timing(_ => u.filter(e => f(e._2))).toDouble 
  val ratio = timing( _ => v.filter(e => f(e._2)))/duration
  show(s"$nTasks, $ratio")
}

性能评估

第一次性能测试包括创建一个单线程并行数组,并使用递增的任务数量执行 mapreduce 评估方法,如下所示:

val sz = 1000000; val NTASKS = 16
val data = Array.fill(sz)(Random.nextDouble) 
val pData = ParArray.fill(sz)(Random.nextDouble) 
val times: Int = 50

val bench = new ParArrayBenchmarkDouble 
val mapper = (x: Double) => Math.sin(x*0.01) + Math.exp(-x)
Range(1, NTASKS).foreach(bench.map(mapper)(_)) 
val reducer = (x: Double, y: Double) => x+y 
Range(1, NTASKS).foreach(bench.reduce(reducer)(_)) 

注意

性能测量

代码必须在循环中执行,并且必须在大量执行的平均持续时间上计算,以避免瞬态操作,例如 JVM 进程的初始化或收集未使用的内存(GC)。

以下图表显示了性能测试的输出:

性能评估

并发任务对 Scala 并行化 mapreduce 性能的影响

测试在具有 8 个核心 CPU 和 8 GB 可用内存的 JVM 上执行了 1 百万次映射和归约函数。

以下结果在以下方面并不令人惊讶:

  • 归约器没有利用数组的并行性。ParArray 在单任务场景下有轻微的开销,然后与 Array 的性能相匹配。

  • map 函数的性能得益于数组的并行化。当分配的任务数量等于或超过 CPU 核心数时,性能趋于平稳。

第二次测试包括比较 ParArrayParHashMap 并行集合在 mapfilter 方法上的行为,使用与第一次测试相同的配置,如下所示:

val sz = 10000000
val mData = new HashMap[Int, Double]
Range(0, sz).foreach( mData.put(_, Random.nextDouble)) //9
val mParData = new ParHashMap[Int, Double]
Range(0, sz).foreach( mParData.put(_, Random.nextDouble))

val bench = new ParMapBenchmarkDouble
Range(1, NTASKS).foreach(bench.map(mapper)(_)) //10
val filterer = (x: Double) => (x > 0.8) 
Range(1, NTASKS).foreach( bench.filter(filterer)(_)) //11

测试初始化一个HashMap实例及其ParHashMap并行计数器,包含一百万个随机值(第9行)。基准bench使用第一个测试中引入的mapper实例(第10行)和过滤函数filterer(第11行)处理这些哈希表的所有元素,NTASKS等于 6。结果如下面的图表所示:

性能评估

并发任务对 Scala 并行化数组和哈希表性能的影响

集合并行化的影响在方法和集合之间非常相似。重要的是要注意,对于五个并发任务以上的并行集合,其性能会达到单线程集合的四倍左右。核心停用部分负责这种行为。核心停用通过禁用一些 CPU 核心来节约电力,在单个应用程序的情况下,它几乎消耗了所有的 CPU 周期。

备注

进一步性能评估

性能测试的目的是突出使用 Scala 并行集合的好处。您应该进一步实验除ParArrayParHashMap之外的其他集合以及其他高阶方法,以确认这种模式。

显然,性能提高四倍并不是什么值得抱怨的事情。话虽如此,并行集合仅限于单主机部署。如果您无法忍受这种限制,但仍需要可扩展的解决方案,Actor 模型为高度分布式应用程序提供了一个蓝图。

使用 Actors 的可扩展性

传统的多线程应用程序依赖于访问位于共享内存中的数据。该机制依赖于同步监视器,如锁、互斥锁或信号量,以避免死锁和不一致的可变状态。即使是经验最丰富的软件工程师,调试多线程应用程序也不是一项简单的任务。

Java 中共享内存线程的第二个问题是由于连续的上下文切换引起的高计算开销。上下文切换包括将当前由基指针和栈指针定义的栈帧保存到堆内存中,并加载另一个栈帧。

可以通过依赖以下关键原则的并发模型来避免这些限制和复杂性:

  • 不可变数据结构

  • 异步通信

Actor 模型

Actor 模型最初在Erlang编程语言中引入,解决了这些问题[12:3]。使用 Actor 模型的目的有两个,如下所述:

  • 它将计算分布在尽可能多的核心和服务器上

  • 它减少了或消除了在 Java 开发中非常普遍的竞争条件和死锁。

该模型由以下组件组成:

  • 独立的称为 Actors 的处理单元。它们通过异步交换消息而不是共享状态来进行通信。

  • 在每个代理依次处理之前,不可变的消息被发送到队列,称为邮箱。

让我们看看以下图表:

代理模型

代理之间的消息表示

有两种消息传递机制,如下所示:

  • 发送-忽略或告知:这会将不可变消息异步发送到目标或接收代理,并立即返回而不阻塞。语法是targetActorRef ! message

  • 发送-接收或询问:这会异步发送消息,但返回一个Future实例,该实例定义了目标代理的预期回复val future = targetActorRef ? message

代理消息处理器的通用结构在某种程度上类似于 Java 中的Runnable.run()方法,如下所示:

while( true ){
  receive { case msg1: MsgType => handler } 
}

receive关键字实际上是PartialFunction[Any, Unit]类型的部分函数[12:4]。其目的是避免强迫开发者处理所有可能的消息类型。消费消息的代理可能运行在单独的组件或甚至应用程序中,这些消息由代理产生。在应用程序的未来版本中,预测代理必须处理的消息类型并不总是容易。

类型不匹配的消息只是被忽略。在代理的例程中不需要抛出异常。代理模型的实现努力避免上下文切换和线程创建的开销[12:5]。

注意

I/O 阻塞操作

虽然强烈建议不要使用代理来阻塞操作,如 I/O,但在某些情况下,发送者需要等待响应。需要记住,阻塞底层线程可能会使其他代理从 CPU 周期中饿死。建议您配置运行时系统使用大线程池,或者通过设置actors.enableForkJoin属性为false允许线程池调整大小。

分区

数据集被定义为 Scala 集合,例如ListMap等。并发处理需要以下步骤:

  1. 将数据集分解成多个子数据集。

  2. 独立并发地处理每个数据集。

  3. 聚合所有生成的数据集。

这些步骤通过在“为什么 Scala?”部分下的“抽象”节中与集合关联的 monad 定义,在第一章的“入门”中。

  1. apply方法创建第一步的子集合或分区,例如,def applyT: List[T]

  2. 类似于映射的操作定义了第二阶段。最后一步依赖于 Scala 集合的 monoidal 结合性,例如,def ++ (a: List[T], b: List[T]): List[T] = a ++ b

  3. 聚合,如reducefoldsum等,包括将所有子结果展平为单个输出,例如val xs: List(…) = List(List(..), List(..)).flatten

可以并行化的方法有mapflatMapfilterfindfilterNot。不能完全并行化的方法有reducefoldsumcombineaggregategroupBysortWith

超越演员——响应式编程

Actor 模型是响应式编程范式的例子。其概念是函数和方法在响应事件或异常时执行。响应式编程将并发与基于事件的系统相结合[12:6]。

高级函数式响应式编程结构依赖于可组合的 future 和传值调用风格CPS)。一个 Scala 响应式库的例子可以在github.com/ingoem/scala-react找到。

Akka

Akka 框架通过添加提取能力扩展了 Scala 中原始的 Actor 模型,例如支持类型化 Actor、消息分发、路由、负载均衡和分区,以及监督和可配置性[12:7]。

Akka 框架可以从akka.io/网站或通过www.typesafe.com/platform的 Typesafe Activator 下载。

Akka 通过封装 Scala Actor 的一些细节来简化 Actor 模型的实现,这些细节包含在akka.actor.Actorakka.actor.ActorSystem类中。

您想要重写的三个方法如下:

  • prestart: 这是一个可选的方法,在 Actor 执行之前被调用以初始化所有必要的资源,例如文件或数据库连接

  • receive: 这个方法定义了 Actor 的行为,并返回一个PartialFunction[Any, Unit]类型的部分函数

  • postStop: 这是一个可选的方法,用于清理资源,例如释放内存、关闭数据库连接、套接字或文件句柄

注意

类型化和无类型演员

无类型演员可以处理任何类型的消息。如果接收 Actor 没有匹配消息类型,则该消息被丢弃。无类型演员可以被视为无契约演员。它们是 Scala 中的默认演员。

类型化演员类似于 Java 远程接口。它们响应方法调用。调用被公开声明,但执行异步委派给目标演员的私有实例[12:8]。

Akka 提供了一系列功能来部署并发应用程序。让我们为使用从显式或隐式单子数据转换继承的任何预处理或分类算法转换数据集的 master Actor 和 worker Actors 创建一个通用模板,如第二章中“单子数据转换”部分所述[part0165.xhtml#aid-4TBCQ2 "第二章。Hello World!"],“Hello World!”

主 Actor 以以下一种方式管理工作 Actor:

  • 单个演员

  • 通过路由器调度器进行集群

路由器是 Actor 监督的一个非常简单的例子。Akka 中的监督策略是使应用程序容错的关键组件[12:9]。一个监督 Actor 管理其子操作、可用性和生命周期,这些子操作被称为下属。演员之间的监督组织成层次结构。监督策略分为以下几类:

  • 一对一策略:这是默认策略。如果某个下属出现故障,主管将仅对该下属执行恢复、重启或恢复操作。

  • 全对一策略:如果其中一个演员失败,主管将对所有下属执行恢复或补救操作。

主-工作

首先要评估的模型是传统的主从主工作设计,用于计算工作流程。在这个设计中,工作 Actor 由主 Actor 初始化和管理,主 Actor 负责控制算法的迭代过程、状态和终止条件。分布式任务的编排是通过消息传递来完成的。

注意

设计原则

强烈建议您将计算或特定领域逻辑的实现与工作 Actor 和主 Actor 的实际实现分开。

消息交换

实现主-工作设计的第一步是定义主 Actor 和每个工作 Actor 之间交换的不同消息类,以控制迭代过程的执行。主-工作设计的实现如下:

sealed abstract class Message(val i: Int)
case class Terminate(i: Int) extends Message(i)
case class Start(i: Int =0) extends Message(i)  //1
case class Activate(i: Int, x: DblVector) extends Message(i) //2
case class Completed(i: Int, x: DblVector) extends Message(i)//3

让我们定义控制算法执行的消息。我们需要至少以下消息类型或案例类:

  • Start:这是客户端代码发送给主节点以启动计算(行1)的消息。

  • Activate:这是主 Actor 发送给工作 Actor 以激活计算的消息。此消息包含要由工作 Actor 处理的时间序列x,还包含对sender(主 actor)的引用。(行2

  • Completed:这是每个工作 Actor 发送回sender的消息。它包含组内数据的方差(行3)。

主 Actor 使用PoisonPill消息停止工作 Actor。终止 actor 的不同方法在主 actor部分中描述。

Message类的层次结构被封闭,以防止第三方开发者添加另一个消息类型。工作 Actor 通过执行ITransform类型的数据转换来响应激活消息。主 Actor 和工作 Actor 之间交换的消息如下图中所示:

消息交换

在 actor 框架中主从通信的草图设计

注意

消息作为案例类

演员通过管理每个消息实例(复制、匹配等)检索其邮箱中的消息。因此,必须将消息类型定义为 case 类。否则,开发者将不得不重写equalshashCode方法。

工作进程演员

工作进程演员负责将主 Actor 创建的每个分区数据集进行转换,如下所示:

type PfnTransform =  PartialFunction[DblVector, Try[DblVector]]

class Worker(id: Int, 
     fct: PfnTransform) extends Actor  {  //1
  override def receive = {
    case msg: Activate =>  //2
      sender ! Completed(msg.id+id,  fct(msg.xt).get)
   }
}

Worker类构造函数接受fct(部分函数作为参数)(行1)。工作进程在接收到Activate消息时启动对msg.xt数据的处理或转换(行2)。一旦fct数据转换完成,它将返回Completed消息给主进程。

工作流程控制器

在第一章的可伸缩性部分,我们介绍了工作流程和控制器概念,以将训练和分类过程作为时间序列上的转换序列来管理。让我们定义一个用于所有控制器演员的抽象类Controller,具有以下三个关键参数:

  • 要处理的时间序列xt

  • 一个作为部分函数实现的fct数据转换

  • 分区数量nPartitions用于将时间序列分解为并行处理

Controller类可以定义为以下内容:

abstract class Controller (
  val xt: DblVector, 
   val fct: PfnTransform, 
   val nPartitions: Int) extends Actor with Monitor { //3

   def partition: Iterator[DblVector] = { //4
      val sz = (xt.size.toDouble/nPartitions).ceil.toInt
      xt.grouped(sz)
   }
}

控制器负责将时间序列分割成几个分区,并将每个分区分配给一个专用的工作进程(行4)。

主演员

让我们定义一个主演员类Master。需要重写的三个方法如下:

  • prestart: 这是一个在演员执行之前调用以初始化所有必要资源(如文件或数据库连接)的方法(行9

  • receive: 这是一个部分函数,用于从邮箱中出队并处理消息

  • postStop: 这会清理资源,例如释放内存和关闭数据库连接、套接字或文件句柄(行10

Master类可以定义为以下内容:

abstract class Master(  //5
    xt: DblVector, 
    fct: PfnTransform, 
    nPartitions: Int) extends Controller(xt, fct, nPartitions) {

  val aggregator = new Aggregator(nPartitions)  //6
  val workers = List.tabulate(nPartitions)(n => 
        context.actorOf(Props(new Worker(n, fct)), 
               name = s"worker_$n"))  //7
  workers.foreach( context.watch ( _ ) )  //8

  override def preStart: Unit = /* ... */  //9
  override def postStop: Unit = /* ... */  //10
  override def receive 
}

Master类有以下参数(行5):

  • xt: 这是需要转换的时间序列

  • fct: 这是一个转换函数

  • nPartitions: 这是分区的数量

聚合类aggregator收集并减少每个工作进程的结果(行6):

class Aggregator(partitions: Int) {
  val state = new ListBuffer[DblVector]

  def += (x: DblVector): Boolean = {
    state.append(x)
    state.size == partitions
  }

  def clear: Unit = state.clear
  def completed: Boolean = state.size == partitions
}

工作进程演员通过ActorSystem上下文的actorOf工厂方法创建(行7)。工作进程演员附加到主演员的上下文中,因此当工作进程终止时可以通知它(行8)。

receive消息处理程序仅处理两种类型的消息:来自客户端代码的Start和来自工作进程的Completed,如下所示:

override def receive = {
  case s: Start => start  //11

  case msg: Completed =>   //12
    if( aggregator +=  msg.xt) //13
      workers.foreach( context.stop(_) )   //14

  case Terminated(sender) => //15
    if( aggregator.completed ) {  
      context.stop(self)   //16
      context.system.shutdown
    }
}

Start消息触发将输入时间序列分割成分区(行11):

def start: Unit = workers.zip(partition.toVector)
             .foreach {case (w, s) => w ! Activate(0,s)} //16

然后将分区通过带有Activate消息(行16)发送到每个工作进程。

每个工人完成任务后(第 12 行)都会向主节点发送一个Completed消息。主节点会汇总来自每个工人的结果(第 13 行)。一旦所有工人完成任务,它们就会被从主节点的上下文中移除(第 14 行)。主节点通过发送一个Terminated消息来终止所有工人(第 15 行),最后,通过请求其context停止它自己来终止(第 16 行)。

之前的代码片段使用了两种不同的方法来终止一个 Actor。正如这里提到的,有四种不同的方法可以关闭一个 Actor:

  • actorSystem.shutdown: 这个方法由客户端用来关闭父 Actor 系统

  • actor ! PoisonPill: 这个方法由客户端用来向 Actor 发送毒药丸消息

  • context.stop(self): 这个方法由 Actor 在其上下文中用来关闭自己

  • context.stop(childActorRef): 这个方法由 Actor 通过其引用来关闭自己

带路由的主节点

之前的设计只有在每个工人都有独特的特性,需要与主节点直接通信时才有意义。在大多数应用中并非如此。工人之间的通信和内部管理可以委托给一个路由器。主节点路由功能的实现与之前的设计非常相似,如下面的代码所示:

class MasterWithRouter(
    xt: DblVector, 
    fct: PfnTransform, 
    nPartitions: Int) extends Controller(xt, fct, nPartitions)  {

  val aggregator = new Aggregator(nPartitions)
  val router = {   //17
    val routerConfig = RoundRobinRouter(nPartitions, //18
           supervisorStrategy = this.supervisorStrategy)
    context.actorOf(
       Props(new Worker(0,fct)).withRouter(routerConfig) )
   }
   context.watch(router)

   override def receive
}

唯一的区别是context.actorOf工厂在创建工人时还会创建一个额外的 Actor,即路由器(第 17 行)。这个特定的实现依赖于路由器按轮询方式将消息分配给每个工人(第 18 行)。Akka 支持多种路由机制,可以选择随机 Actor、邮箱最小的 Actor、第一个响应广播的 Actor 等等。

注意

路由器监督

路由器 Actor 是工人 Actor 的父 Actor。按照设计,它是工人 Actor 的监督者,即它的子 Actor。因此,路由器负责工人 Actor 的生命周期,包括它们的创建、重启和终止。

receive消息处理器的实现几乎与不带路由功能的主节点的消息处理器相同,唯一的区别是通过路由器来终止工人(第 19 行):

override def receive = {
  case Start => start
  case msg: Completed => 
    if( aggregator += msg.xt) context.stop(router)  //19
   ...
}

start消息处理器需要修改为通过路由器向所有工人广播Activate消息:

def start: Unit = 
  partition.toVector.foreach {router ! Activate(0, _)}

分布式离散傅里叶变换

让我们在时间序列xt上选择离散傅里叶变换DFT)作为我们的数据转换。我们已经在第三章的离散傅里叶变换部分讨论过这个问题,数据预处理。测试代码在主节点是否有路由功能的情况下都是一样的。

首先,让我们定义一个专门用于执行分布式离散傅里叶变换的 master 控制器DFTMaster,如下所示:

type Reducer = List[DblVector] => immutable.Seq[Double]
class DFTMaster(
    xt: DblVector, 
    nPartitions: Int, 
    reducer: Reducer)   //20
      extends Master(xt, DFT[Double].|>, nPartitions)

reducer方法聚合或减少每个工人(第 20 行)的离散傅里叶变换(频率分布)的结果。在离散傅里叶变换的情况下,fReduce聚合方法将频率分布列表转置,然后对每个频率求和振幅(第 21 行):

def fReduce(buf: List[DblVector]): immutable.Seq[Double] = 
   buf.transpose.map( _.sum).toSeq  //21

让我们看看测试代码:

val NUM_WORKERS = 4 
val NUM_DATAPOINTS = 1000000
val h = (x: Double) =>2.0*Math.cos(Math.PI*0.005*x) + 
    Math.cos(Math.PI*0.05*x) + 0.5*Math.cos(Math.PI*0.2*x) +
    0.3* Random.nextDouble   //22

val actorSystem = ActorSystem("System")  //23
val xt = Vector.tabulate(NUM_DATA_POINTS)(h(_))
val controller = actorSystem.actorOf(
         Props(new DFTMasterWithRouter(xt, NUM_WORKERS, 
                    fReduce)), "MasterWithRouter")  //24
controller ! Start(1) //25

输入的时间序列是通过噪声正弦函数h(第 22 行)合成的。函数h有三个不同的谐波:0.0050.050.2,因此变换的结果可以很容易地验证。ActorSystem实例化(第 23 行),并通过 Akka 的ActorSystem.actorOf工厂生成主 Actor(第 24 行)。主程序向主 Actor 发送Start消息以触发分布式离散傅里叶变换的计算(第 25 行)。

注意

动作实例化

虽然可以使用构造函数实例化scala.actor.Actor类,但akka.actor.Actor是通过ActorSystem上下文、actorOf工厂和一个Props配置对象实例化的。这种第二种方法有几个优点,包括将 Actor 的部署与其功能解耦,并强制执行默认的监督者或父 Actor;在这种情况下,是ActorSystem

以下序列图展示了主程序、主 Actor 和工人 Actor 之间的消息交换:

分布式离散傅里叶变换

用于交叉验证组归一化的序列图

测试的目的是评估使用 Akka 框架计算离散傅里叶变换相对于原始实现(无 Actor)的性能。与 Scala 并行集合一样,转换的绝对时间取决于主机和配置,如下面的图表所示:

分布式离散傅里叶变换

工人(奴隶)Actor 的数量对离散傅里叶变换性能的影响

离散傅里叶变换的单线程版本比使用单个工人 Actor 的 Akka 主-工人模型实现要快得多。分区和聚合(或减少)结果的成本给傅里叶变换的执行增加了显著的开销。然而,当有三个或更多工人 Actor 时,主-工人模型效率更高。

局限性

主-工人实现存在一些问题,如下所述:

  • 在主 Actor 的消息处理程序中,无法保证在主 Actor 停止之前所有工人都会消耗掉毒药丸。

  • 主程序必须休眠一段时间,足够让主 Actor 和工人 Actor 完成他们的任务。当主程序醒来时,无法保证计算已经完成。

  • 没有机制来处理在传递或处理消息时的失败。

犯罪者是只使用“发射后不管”机制在主节点和工作者之间交换数据。发送和接收协议以及未来是解决这些问题的方法。

未来

未来是一个对象,更具体地说是一个单子,用于以非阻塞方式检索并发操作的结果。这个概念与提供给工作者的回调非常相似,它在任务完成时调用它。未来持有可能在任务完成时(无论成功与否)可用的值 [12:10]。

从未来中检索结果有两种选择:

  • 使用scala.concurrent.Await阻塞执行

  • onComplete, onSuccess, 和 onFailure 回调函数

注意

哪个未来?

Scala 环境为开发者提供了两个不同的Future类:scala.actor.Futurescala.concurrent.Future

actor.Future类用于编写传递风格的工作流程,其中当前演员在未来的值可用之前被阻塞。本章中使用的scala.concurrent.Future类型的实例与 Scala 中的java.concurrent.Future等效。

演员生命周期

让我们重新实现上一节中介绍的通过方差对交叉验证组进行归一化的方法,使用未来支持并发。第一步是导入执行主演员和未来的适当类,如下所示:

import akka.actor.{Actor, ActorSystem, ActorRef, Props} //26
import akka.util.Timeout   //27
import scala.concurrent.{Await, Future}  //28

由于 Akka 扩展的演员模型(行26),演员类由akka.actor包提供,而不是scala.actor._包。与java.concurrent包类似,未来相关的类FutureAwait是从scala.concurrent包导入的(行28)。akka.util.Timeout类用于指定演员等待未来完成的最大持续时间(行27)。

父演员或主程序管理它创建的未来有两种选择,如下所示:

  • 阻塞:父演员或主程序停止执行,直到所有未来都完成了它们的任务。

  • 回调:父演员或主程序在执行期间启动未来。未来任务与父演员并行执行,并在每个未来任务完成时通知。

使用scala.concurrent.Await阻塞未来

以下设计包括阻塞启动未来的演员,直到所有未来都已完成,无论是返回结果还是抛出异常。让我们将主演员修改为TransformFutures类,该类管理未来而不是工作者或路由演员,如下所示:

abstract class TransformFutures(
    xt: DblVector, 
    fct: PfnTransform, 
    nPartitions: Int)
    (implicit timeout: Timeout) //29
         extends Controller(xt, fct, nPartitions) {
  override def receive = {
    case s: Start => compute(transform) //30
  }
}

TransformFutures 类需要与 Master 演员相同的参数:时间序列 xt、数据转换 fct 和分区数 nPartitionstimeout 参数是 Await.result 方法的隐含参数,因此需要将其声明为参数(第 29 行)。唯一的信息 Start 触发每个 futures 的数据转换计算,然后聚合结果(第 30 行)。transformcompute 方法与主从设计中的语义相同。

注意

通用消息处理器

你可能已经阅读过,甚至可能编写过用于调试目的的具有通用情况 _ => 处理器的演员示例。消息循环将一个部分函数作为参数。因此,如果消息类型未被识别,则不会抛出错误或异常。除了用于调试目的的处理程序外,不需要这样的处理程序。消息类型应该从密封的抽象类或密封的特质继承,以防止错误地添加新的消息类型。

让我们来看看 transform 方法。其主要目的是实例化、启动并返回一个负责分区转换的 futures 数组,如下面的代码所示:

def transform: Array[Future[DblVector]] = {
  val futures = new Array[Future[DblVector]](nPartitions) //31

  partition.zipWithIndex.foreach { case (x, n) => { //32
    futures(n) = Future[DblVector] { fct(x).get } //33
  }}
  futures
}

创建了一个 futures 数组(每个分区一个 futures)(第 31 行)。transform 方法调用分区方法 partition(第 32 行),然后使用 fct 部分函数初始化 futures(第 33 行):

def compute(futures: Array[Future[DblVector]]): Seq[Double] = 
  reduce(futures.map(Await.result(_, timeout.duration))) //34

compute 方法在 futures 上调用用户定义的 reduce 函数。演员的执行会阻塞,直到 Await 类的 scala.concurrent.Await.result 方法(第 34 行)返回每个 futures 计算的结果。在离散傅里叶变换的情况下,在将每个频率的幅度求和之前,频率列表被转置(第 35 行),如下所示:

def reduce(data: Array[DblVector]): Seq[Double] = 
    data.view.map(_.toArray)
        .transpose.map(_.sum)   //35
            .take(SPECTRUM_WIDTH).toSeq

下面的顺序图说明了阻塞设计以及演员和 futures 执行的活动:

在 futures 上阻塞

演员在 future 结果上阻塞的顺序图

处理 future 回调

回调是替代演员在 futures 上阻塞的绝佳选择,因为它们可以在未来执行的同时并发执行其他函数。

实现回调函数有两种简单的方法,如下所示:

  • Future.onComplete

  • Future.onSuccessFuture.onFailure

onComplete 回调函数接受一个 Try[T] => U 类型的函数作为参数,并具有对执行上下文的隐含引用,如下面的代码所示:

val f: Future[T] = future { execute task } f onComplete {   
  case Success(s) => { … }   
  case Failure(e) => { … }
}

你当然可以识别 {Try, Success, Failure} 模纳。

另一种实现方式是调用使用部分函数作为参数的 onSuccessonFailure 方法来实现回调,如下所示:

f onFailure { case e: Exception => { … } } 
f onSuccess { case t => { … } }

阻塞一个期货数据转换和处理回调之间的唯一区别是compute方法或 reducer 的实现。类定义、消息处理程序和期货的初始化与以下代码中显示的相同:

def compute(futures: Array[Future[DblVector]]): Seq[Double] = {
  val buffer = new ArrayBuffer[DblVector]

  futures.foreach( f => {
    f onSuccess {   //36
      case data: DblVector => buffer.append(data)
    }
    f onFailure { case e: Exception =>  /* .. */ } //37
  })
   buffer.find( _.isEmpty).map( _ => reduce(buffer)) //38
}

每个期货都会通过主演员回调,携带数据转换的结果、onSuccess消息(第36行)或异常、OnFailure消息(第37行)。如果所有期货都成功,所有分区的所有频率的值都会相加(第38行)。以下顺序图说明了主演员中回调的处理:

处理未来回调

用于处理演员未来结果的回调的顺序图

注意

执行上下文

使用期货(futures)的应用要求执行上下文(execution context)必须由开发者隐式提供。定义执行上下文有三种不同的方式:

  • 导入上下文:import ExecutionContext.Implicits.global

  • 在演员(或演员上下文)内创建上下文实例:implicit val ec = ExecutionContext.fromExecutorService( … )

  • 在实例化期货时定义上下文:val f= Future[T] ={ } (ec)

将所有内容整合

让我们重用离散傅里叶变换。客户端代码使用与主从测试模型中相同的合成时间序列。第一步是创建离散傅里叶变换的转换期货DFTTransformFuture,如下所示:

class DFTTransformFutures(
    xt: DblVector, 
    partitions: Int)(implicit timeout: Timeout) 
    extends TransformFutures(xt, DFT[Double].|> , partitions)  {

  override def reduce(data: Array[DblVector]): Seq[Double] = 
    data.map(_.toArray).transpose
        .map(_.sum).take(SPECTRUM_WIDTH).toSeq
}

DFTTransformFuture类的唯一目的是定义离散傅里叶变换的reduce聚合方法。让我们重用主从部分下的分布式离散傅里叶变换节中的相同测试用例:

import akka.pattern.ask

val duration = Duration(8000, "millis")
implicit val timeout = new Timeout(duration)
val master = actorSystem.actorOf(   //39
       Props(new DFTTransformFutures(xt, NUM_WORKERS)), 
                        "DFTTransform")
val future = master ? Start(0)  //40
Await.result(future, timeout.duration)   //41
actorSystem.shutdown  //42

主演员(master actor)以TransformFutures类型初始化,输入时间序列xt,离散傅里叶变换DFT以及工作进程或分区数nPartitions作为参数(第39行)。程序通过向主演员发送(askStart消息来创建一个期货实例(第40行)。程序会阻塞直到期货完成(第41行),然后关闭 Akka 演员系统(第42行)。

Apache Spark

Apache Spark 是一个快速且通用的集群计算系统,最初作为 AMPLab/加州大学伯克利分校的一部分开发,作为伯克利数据分析堆栈BDAS)的一部分(en.wikipedia.org/wiki/UC_Berkeley)。它为以下编程语言提供了高级 API,使得编写和部署大型和并发并行作业变得容易 [12:11]:

注意

最新信息的链接

任何关于 Apache Spark 的 URL 在未来版本中可能会发生变化。

Spark 的核心元素是一个弹性分布式数据集RDD),它是一组跨集群节点和/或服务器 CPU 核心分区元素集合。RDD 可以从本地数据结构(如列表、数组或哈希表)、本地文件系统或Hadoop 分布式文件系统HDFS)创建。

Spark 中对 RDD 的操作非常类似于 Scala 的高阶方法。这些操作是在每个分区上并发执行的。RDD 的操作可以分为以下几类:

  • Transformation:这个操作在每个分区上转换、操作和过滤 RDD 的元素

  • Action:这个操作从所有分区聚合、收集或减少 RDD 的元素

RDD 可以被持久化、序列化和缓存以供未来的计算使用。

Spark 是用 Scala 编写的,并建立在 Akka 库之上。Spark 依赖于以下机制来分发和分区 RDD:

  • Hadoop/HDFS 用于分布式和复制的文件系统

  • Mesos 或 Yarn 用于集群管理和共享数据节点池

Spark 生态系统可以用技术栈和框架的堆叠来表示,如下面的图所示:

Apache Spark

Apache Spark 框架生态系统

Spark 生态系统已经发展到支持一些开箱即用的机器学习算法,例如MLlib,这是一个类似 SQL 的接口,使用关系运算符操作数据集,SparkSQL,这是一个用于分布式图的库,GraphX,以及一个流库 [12:12]。

为什么选择 Spark?

Spark 的作者试图通过实现内存迭代计算来解决 Hadoop 在性能和实时处理方面的限制,这对于大多数判别性机器学习算法至关重要。已经进行了许多基准测试并发表了相关论文,以评估 Spark 相对于 Hadoop 的性能提升。在迭代算法的情况下,每次迭代的耗时可以减少到 1:10 或更多。

Spark 提供了一系列预构建的转换和动作,这些动作远远超出了基本的 map-reduce 范式。这些 RDD 上的方法是对 Scala 集合的自然扩展,使得 Scala 开发者迁移代码变得无缝。

最后,Apache Spark 通过允许 RDD 在内存和文件系统中持久化来支持容错操作。持久性使得从节点故障中自动恢复成为可能。Spark 的弹性依赖于底层 Akka 演员的监督策略、他们邮箱的持久性以及 HDFS 的复制方案。

设计原则

Spark 的性能依赖于以下五个核心设计原则 [12:13]:

  • 内存持久化

  • 调度任务的惰性

  • 应用到 RDD 上的转换和操作

  • 共享变量的实现

  • 支持数据框(SQL 感知 RDD)

内存持久化

开发者可以决定持久化(persist)和/或缓存(cache)一个 RDD 以供将来使用。一个 RDD 可能只保存在内存中或只保存在磁盘上——如果可用,则在内存中,否则作为反序列化或序列化的 Java 对象保存在磁盘上。例如,可以通过以下简单语句通过序列化来缓存 RDD rdd,如下所示:

rdd.persist(StorageLevel.MEMORY_ONLY_SER).cache

注意

Kryo 序列化

通过 Serializable 接口进行 Java 序列化是出了名的慢。幸运的是,Spark 框架允许开发者指定更有效的序列化机制,例如 Kryo 库。

惰性

Scala 原生支持惰性值。赋值语句的左侧,可以是值、对象引用或方法,只执行一次,即第一次调用时,如下所示:

class Pipeline {
  lazy val x = { println("x"); 1.5}   
  lazy val m = { println("m"); 3}   
  val n = { println("n"); 6}   
  def f = (m <<1)
  def g(j: Int) = Math.pow(x, j)
}
val pipeline = new Pipeline  //1
pipeline.g(pipeline.f)  //2

打印变量的顺序是 nm,然后 xPipeline 类的实例化初始化了 n,但没有初始化 mx(第 1 行)。在稍后的阶段,调用 g 方法,该方法反过来调用 f 方法。f 方法初始化它需要的 m 值,然后 gx 初始化为计算其 m << 1 次幂(第 2 行)。

Spark 通过仅在执行操作时执行转换来将相同的原理应用于 RDD。换句话说,Spark 将内存分配、并行化和计算推迟到驱动代码通过执行操作获取结果时。通过直接无环图调度器执行所有这些转换的回溯级联效应。

转换和操作

由于 Spark 是用 Scala 实现的,所以您可能不会对 Spark 支持 Scala 集合中最相关的更高方法感到惊讶。第一个表描述了 Spark 使用的转换方法,以及它们在 Scala 标准库中的对应方法。我们使用 (K, V) 表示法表示(键,值)对:

Spark Scala 描述
map(f) map(f) 这通过在集合的每个元素上执行 f 函数来转换 RDD
filter(f) filter(f) 这通过选择 f 函数返回 true 的元素来转换 RDD
flatMap(f) flatMap(f) 这通过将每个元素映射到输出项的序列来转换 RDD
mapPartitions(f) 这在每个分区上单独执行 map 方法
sample 这使用随机生成器有放回或无放回地采样数据的一部分
groupByKey groupBy 这是在 (K,V) 上调用的,用于生成一个新的 (K, Seq(V)) RDD
union union 这将创建一个新的 RDD,作为当前 RDD 和参数的并集
distinct distinct 这从这个 RDD 中消除重复元素
reduceByKey(f) reduce 这使用f函数聚合或减少与每个键对应的值
sortByKey sortWith 这通过键,K,的升序、降序或其他指定顺序重新组织 RDD 中的(K,V)
join 这将 RDD (K,V) 与 RDD (K,W) 连接以生成一个新的 RDD (K, (V,W))
coGroup 这实现了一个连接操作但生成一个 RDD (K, Seq(V), Seq(W))

触发所有分区数据集收集或减少的操作方法如下:

Spark Scala 描述
reduce(f) reduce(f) 这聚合所有分区中的 RDD 元素并返回一个 Scala 对象到驱动器
collect collect 这收集并返回 RDD 在所有分区中的所有元素作为驱动器中的列表
count count 这将 RDD 中的元素数量返回到驱动器
first head 这将 RDD 的第一个元素返回到驱动器
take(n) take(n) 这返回 RDD 的前n个元素到驱动器
takeSample 这将返回一个从 RDD 中随机元素到驱动器的数组
saveAsTextFile 这将 RDD 的元素作为文本文件写入本地文件系统或 HDFS
countByKey 这生成一个带有原始键,K,以及每个键的值计数的(K, Int) RDD
foreach foreach 这对 RDD 的每个元素执行T=> Unit函数

Scala 方法,如foldfinddropflattenminmaxsum目前在 Spark 中尚未实现。其他 Scala 方法,如zip必须谨慎使用,因为在zip中两个集合的顺序在分区之间没有保证。

共享变量

在一个完美的世界中,变量是不可变的并且局部于每个分区以避免竞争条件。然而,在某些情况下,变量必须共享而不破坏 Spark 提供的不可变性。在这方面,Spark 复制共享变量并将它们复制到数据集的每个分区。Spark 支持以下类型的共享变量:

  • 广播值:这些值封装并转发数据到所有分区

  • 累加器变量:这些变量充当求和或引用计数器

四个设计原则可以总结如下图所示:

共享变量

Spark 驱动器和 RDD 之间的交互

以下图说明了 Spark 驱动器和其工作者之间最常见的交互,如下列步骤所示:

  1. 输入数据,无论是作为 Scala 集合存储在内存中还是作为文本文件存储在 HDFS 中,都被并行化并分区为一个 RDD。

  2. 将转换函数应用于所有分区的数据集的每个元素。

  3. 执行一个操作以减少并收集数据回驱动程序。

  4. 数据在驱动程序内部本地处理。

  5. 执行第二次并行化以通过 RDD 分发计算。

  6. 一个变量作为最后一个 RDD 转换的外部参数广播到所有分区。

  7. 最后,最后一个操作在驱动程序中聚合并收集最终结果。

如果您仔细观察,Spark 驱动程序对数据集和 RDD 的管理与 Akka 主和工作者 actor 对未来的管理并没有太大区别。

尝试使用 Spark

Spark 针对迭代计算进行的内存计算使其成为机器学习模型训练(使用动态规划或优化算法实现)的理想选择。Spark 运行在 Windows、Linux 和 Mac OS 操作系统上。它可以以单主机本地模式或分布式环境的主机模式部署。使用的 Spark 框架版本是 1.3。

注意

JVM 和 Scala 兼容版本

在撰写本文时,Spark 1.3.0 版本需要 Java 1.7 或更高版本,Scala 2.10.2 或更高版本。Spark 1.5.0 支持 Scala 2.11,但需要使用带有标志 D-scala2.11 重新组装框架。

部署 Spark

学习 Spark 最简单的方法是在独立模式下部署本地主机。您可以从网站上部署 Spark 的预编译版本,或者使用简单的构建工具sbt)或 Maven [12:14]构建 JAR 文件,如下所示:

  1. 前往spark.apache.org/downloads.html的下载页面。

  2. 选择一个包类型(Hadoop 发行版)。Spark 框架依赖于 HDFS 以集群模式运行;因此,您需要选择 Hadoop 的一个发行版或如 MapR 或 Cloudera 之类的开源发行版。

  3. 下载并解压缩包。

  4. 如果您对框架中添加的最新功能感兴趣,请查看github.com/apache/spark.git上的最新源代码。

  5. 接下来,您需要使用 Maven 或 sbt 从顶级目录构建或组装 Apache Spark 库:

    • Maven:设置以下 Maven 选项以支持构建、部署和执行:

      MAVEN_OPTS="-Xmx4g -XX:MaxPermSize=512M 
                -XX:ReservedCodeCacheSize=512m"
      mvn [args] –DskipTests clean package
      
      

      以下是一些示例:

      • 基于 2.4 版本的 Hadoop 使用 Yarn 集群管理器和 Scala 2.10(默认)构建:

        mvn -Pyarn –Phadoop-2.4 –Dhadoop.version-2.4.0 –DskipTests 
         clean package
        
        
      • 基于 2.6 版本的 Hadoop 使用 Yarn 集群管理器和 Scala 2.11 构建:

        mvn -Pyarn –Phadoop-2.6 –Dhadoop.version-2.6.0 –Dscala-2.11 
         –DskipTests clean package
        
        
      • 简单的构建工具:使用以下命令:

        sbt/sbt [args] assembly
        
        

      以下是一些示例:

      • 基于 2.4 版本的 Hadoop 使用 Yarn 集群管理器和 Scala 2.10(默认)构建:

         sbt -Pyarn –pHadoop 2.4 assembly
        
        
      • 基于 2.6 版本的 Hadoop 使用 Yarn 集群管理器和 Scala 2.11 构建:

         sbt -Pyarn –pHadoop 2.6 –Dscala-2.11 assembly
        
        

注意

安装说明

Spark 中使用的目录和工件名称无疑会随着时间的推移而改变。您可以参考文档和安装指南以获取 Spark 的最新版本。

Apache 支持多种部署模式:

  • 独立模式:驱动程序和执行器作为主节点和从节点 Akka 演员运行,捆绑在默认的 spark 分布 JAR 文件中。

  • 本地模式:这是一种在单个主机上运行的独立模式。从节点演员部署在同一主机内的多个核心上。

  • Yarn 集群管理器:Spark 依赖于运行在 Hadoop 2 及以上版本的 Yarn 资源管理器。Spark 驱动程序可以在与客户端应用程序相同的 JVM 上运行(客户端模式)或在与主节点相同的 JVM 上运行(集群模式)。

  • Apache Mesos 资源管理器:这种部署允许动态和可伸缩的分区。Apache Mesos 是一个开源的通用集群管理器,需要单独安装(请参阅mesos.apache.org/)。Mesos 管理抽象的硬件资源,如内存或存储。

主节点(或驱动程序)、集群管理器和一组从节点(或工作节点)之间的通信在以下图中说明:

部署 Spark

主节点(或驱动程序)、从节点(或工作节点)和集群管理器之间的通信

注意

Windows 下的安装

Hadoop 依赖于一些 UNIX/Linux 实用程序,当在 Windows 上运行时需要添加到开发环境中。必须安装winutils.exe文件并将其添加到HADOOP_PATH环境变量中。

使用 Spark shell

使用以下任何一种方法来使用 Spark shell:

  • Shell 是轻松接触 Spark 弹性分布式数据集的一种方式。要在本地启动 shell,执行./bin/spark-shell –master local[8]以在 8 核心的本地主机上运行 shell。

  • 要在本地启动 Spark 应用程序,连接到 shell 并执行以下命令行:

    ./bin/spark-submit --class application_class --master local[4] 
       --executor-memory 12G  --jars myApplication.jar 
       –class myApp.class
    

    命令启动了应用程序myApplication,在具有 4 核心 CPU 的本地主机和 12GB 内存上运行myApp.main主方法。

  • 要远程启动相同的 Spark 应用程序,连接到 shell 并执行以下命令行:

    ./bin/spark-submit --class application_class 
       --master spark://162.198.11.201:7077 
       –-total-executor-cores 80  
       --executor-memory 12G  
       --jars myApplication.jar –class myApp.class
    

输出将如下所示:

使用 Spark shell

Spark shell 命令行输出的部分截图

注意

Spark shell 的潜在陷阱

根据您的环境,您可能需要通过重新配置conf/log4j.properties来禁用将日志信息记录到控制台。Spark shell 也可能与配置文件或环境变量列表中的类路径声明冲突。在这种情况下,它必须通过ADD_JARS作为环境变量来替换,例如ADD_JARS = path1/jar1, path2/jar2

MLlib

MLlib 是一个建立在 Spark 之上的可扩展机器学习库。截至版本 1.0,该库仍在开发中。

库的主要组件如下:

  • 包括逻辑回归、朴素贝叶斯和支持向量机在内的分类算法

  • 版本 1.0 中聚类限制为 K-means

  • L1 和 L1 正则化

  • 优化技术,如梯度下降、逻辑梯度下降和随机梯度下降,以及 L-BFGS

  • 线性代数,如奇异值分解

  • K-means、逻辑回归和支持向量机的数据生成器

机器学习字节码方便地包含在用简单构建工具构建的 Spark 组装 JAR 文件中。

RDD 生成

转换和操作是在 RDD 上执行的。因此,第一步是创建一个机制,以便从时间序列生成 RDD。让我们创建一个具有convert方法的RDDSource单例,该方法将时间序列xt转换为 RDD,如下所示:

def convert(
    xt: immutable.Vector[DblArray], 
    rddConfig: RDDConfig) 
    (implicit sc: SparkContext): RDD[Vector] = {

  val rdd: RDD[Vector] = 
     sc.parallelize(xt.toVector.map(new DenseVector(_))) //3
  rdd.persist(rddConfig.persist) //4
  if( rddConfig.cache) rdd.cache  //5
  rdd
}

convert方法的最后一个rddConfig参数指定了 RDD 的配置。在本例中,RDD 的配置包括启用/禁用缓存和选择持久化模型,如下所示:

case class RDDConfig(val cache: Boolean, 
    val persist: StorageLevel)

有理由假设SparkContext已经在类似于 Akka 框架中的ActorSystem的方式下隐式定义。

RDD 的生成按照以下步骤进行:

  1. 使用上下文的parallelize方法创建 RDD 并将其转换为向量(SparseVectorDenseVector)(行3)。

  2. 如果需要覆盖默认级别,指定持久化模型或存储级别(行3)。

  3. 指定 RDD 是否需要在内存中持久化(行5)。

注意

创建 RDD 的替代方法

可以使用SparkContext.textFile方法从本地文件系统或 HDFS 加载数据生成 RDD,该方法返回一个字符串 RDD。

一旦创建了 RDD,它就可以作为任何定义为一系列转换和操作的算法的输入使用。让我们尝试在 Spark/MLlib 中实现 K-means 算法。

使用 Spark 的 K-means

第一步是创建一个SparkKMeansConfig类来定义 Apache Spark K-means 算法的配置,如下所示:

class SparkKMeansConfig(K: Int, maxIters: Int, 
     numRuns: Int = 1) {   
  val kmeans: KMeans = {      
    (new KMeans).setK(K) //6
      .setMaxIterations(maxIters)  //7
      .setRuns(numRuns) //8
  }
}

MLlib K-means 算法的最小初始化参数集如下:

  • 聚类数量,K(行6

  • 总误差重建的最大迭代次数,maxIters(行7

  • 训练运行次数,numRuns(行8

SparkKMeans类将 Spark 的KMeans封装为ITransform类型的数据转换,如第二章中“单子数据转换”部分所述,第二章,“Hello World!”该类遵循分类器的设计模板,如附录 A 中“不可变分类器设计模板”部分所述,附录 A,“基本概念”:

class SparkKMeans(    //9
    kMeansConfig: SparkKMeansConfig, 
    rddConfig: RDDConfig, 
    xt: Vector[DblArray])
   (implicit sc: SparkContext) extends ITransformDblArray{

  type V = Int   //10
  val model: Option[KMeansModel] = train  //11

  override def |> : PartialFunction[DblArray, Try[V]] //12
  def train: Option[KMeansModel] 
}

构造函数接受三个参数:Apache Spark KMeans配置kMeansConfig、RDD 配置rddConfig以及用于聚类的输入时间序列xt(第9行)。ITransform特质的偏函数|>的返回类型定义为Int(第10行)。

model的生成仅包括使用rddConfig将时间序列xt转换为 RDD 并调用 MLlib KMeans.run(第11行)。一旦创建,聚类模型(KMeansModel)就可用于预测新的观测值x(第12行),如下所示:

override def |> : PartialFunction[DblArray, Try[V]] = {
  case x: DblArray if(x.length > 0 && model != None) => 
     TryV))
}

|>预测方法返回观测值的聚类索引。

最后,让我们编写一个简单的客户端程序来使用股票价格波动和其每日交易量来练习SparkKMeans模型。目标是提取具有特征(波动性和成交量)的聚类,每个聚类代表股票的特定行为:

val K = 8
val RUNS = 16
val MAXITERS = 200
val PATH = "resources/data/chap12/CSCO.csv"
val CACHE = true

val sparkConf = new SparkConf().setMaster("local[8]")
   .setAppName("SparkKMeans")
   .set("spark.executor.memory", "2048m") //13
implicit val sc = new SparkContext(sparkConf) //14

extract.map { case (vty,vol)  => {  //15
  val vtyVol = zipToSeries(vty, vol)  
  val conf = SparkKMeansConfig(K,MAXITERS,RUNS) //16
  val rddConf = RDDConfig(CACHE, 
                    StorageLevel.MEMORY_ONLY) //17

  val pfnSparkKMeans = SparkKMeans(conf,rddConf,vtyVol) |> //18
  val obs = ArrayDouble
  val clusterId = pfnSparkKMeans(obs)
}

第一步是定义sc上下文的最低配置(第13行)并初始化它(第14行)。vtyvol波动性变量被用作 K-means 的特征,并从 CSV 文件中提取(第15行):

def extract: Option[(DblVector, DblVector)] = {
  val extractors = List[Array[String] => Double](
    YahooFinancials.volatility, YahooFinancials.volume 
  )
  val pfnSrc = DataSource(PATH, true) |>
  pfnSrc( extractors ) match {
    case Success(x) => Some((x(0).toVector, x(1).toVector))
    case Failure(e) => { error(e.toString); None }
  }
}

执行创建 K-means 的配置config(第16行)和另一个 Spark RDD 的配置rddConfig(第17行)。使用 K-means、RDD 配置和输入数据vtyVol创建实现 K-means 算法的pfnSparkKMeans偏函数(第18行)。

性能评估

让我们在具有 32 GB RAM 的 8 核心 CPU 机器上执行交叉验证组的归一化。数据以每个 CPU 核心两个分区的比例分区。

注意

有意义的性能测试

可扩展性测试应使用大量数据点(归一化波动性、归一化成交量)进行,超过 100 万,以估计渐近时间复杂度。

调优参数

Spark 应用程序的性能在很大程度上取决于配置参数。在 Spark 中选择这些配置参数的适当值可能会令人不知所措——据最后一次统计,有 54 个配置参数。幸运的是,其中大多数参数都有相关的默认值。然而,有一些参数值得您关注,包括以下内容:

  • 可用于在 RDD 上执行转换和操作的 CPU 核心数(config.cores.max)。

  • 用于执行转换和操作的内存(spark.executor.memory)。将值设置为最大 JVM 堆栈的 60%通常是一个好的折衷方案。

  • 在所有分区上用于 shuffle 相关操作的可同时使用的并发任务数;它们使用一个键,如reduceByKeyspark.default.parallelism)。推荐的公式是并行度 = 总核心数 x 2。可以通过spark.reduceby.partitions参数覆盖该参数的值,以针对特定的 RDD 减少器。

  • 一个用于压缩序列化 RDD 分区的标志 MEMORY_ONLY_SER (spark.rdd.compress)。其目的是在增加额外 CPU 周期的代价下减少内存占用。

  • 包含操作结果的最多消息大小被发送到 spark.akka.frameSize 驱动程序。如果可能生成大型数组,则需要增加此值。

  • 一个用于压缩大型广播变量的标志 spark.broadcast.compress。这通常是被推荐的。

测试

测试的目的是评估执行时间与训练集大小之间的关系。测试在以下时间段内对美银(BAC)股票的波动性和交易量执行了 MLlib 库中的 K-means:3 个月、6 个月、12 个月、24 个月、48 个月、60 个月、72 个月、96 个月和 120 个月。

以下配置用于执行 K-means 的训练:10 个集群,最大迭代次数为 30 次,运行 3 次。测试在一个具有 8 个 CPU 核心和 32GB RAM 的单个主机上运行。测试使用了以下参数值:

  • StorageLevel = MEMORY_ONLY

  • spark.executor.memory = 12G

  • spark.default.parallelism = 48

  • spark.akka.frameSize = 20

  • spark.broadcast.compress = true

  • 无序列化

执行特定数据集的测试后的第一步是登录到 Spark 监控控制台 http://host_name:4040/stages

测试

K-means 聚类与每月交易数据大小的平均持续时间

显然,每个环境都会产生不同的性能结果,但确认 Spark K-means 的时间复杂度是训练集大小的线性函数。

注意

分布式环境中的性能评估

在多个主机上部署 Spark 将会增加 TCP 通信的整体执行时间延迟。这种延迟与将聚类结果收集回 Spark 驱动程序有关,这是可以忽略的,并且与训练集的大小无关。

性能考虑

这项测试仅仅触及了 Apache Spark 功能的表面。以下是从个人经验中总结出的教训,以避免在部署 Spark 1.3+ 时最常见的性能陷阱:

  • 熟悉与分区、存储级别和序列化相关的最常见的 Spark 配置参数。

  • 除非你使用像 Kryo 这样的有效 Java 序列化库,否则请避免序列化复杂或嵌套的对象。

  • 考虑定义自己的分区函数以减少大型键值对数据集。reduceByKey 的便利性是有代价的。分区数与核心数之比会影响使用键的减少器的性能。

  • 避免不必要的操作,如collectcountlookup。一个操作会减少 RDD 分区中驻留的数据,并将其转发到 Spark 驱动程序。Spark 驱动程序(或主程序)在一个有限的资源 JVM 上运行。

  • 在必要时始终依赖共享或广播变量。例如,广播变量可以改善对大小非常不同的多个数据集的操作性能。让我们考虑一个常见的情况,即连接两个大小非常不同的数据集。将较小的数据集广播到较大数据集 RDD 的每个分区,比将较小的数据集转换为 RDD 并在两个数据集之间执行连接操作要高效得多。

  • 使用累加器变量进行求和,因为它比在 RDD 上使用 reduce 操作要快。

优缺点

越来越多的组织正在采用 Spark 作为他们实时或准实时操作的分发数据处理平台。Spark 快速被采用的原因有以下几点:

  • 它得到了一个庞大且专门的开发者社区的支持[12:15]

  • 内存持久性对于机器学习和统计推断算法中发现的迭代计算来说非常理想

  • 优秀的性能和可扩展性,可以通过流模块进行扩展

  • Apache Spark 利用 Scala 函数式能力以及大量的开源 Java 库

  • Spark 可以利用 Mesos 或 Yarn 集群管理器,这降低了在工作节点之间定义容错和负载平衡的复杂性

  • Spark 需要与商业 Hadoop 供应商如 Cloudera 集成

然而,没有平台是完美的,Spark 也不例外。关于 Spark 最常见的问题或担忧如下:

  • 对于没有函数式编程知识的前开发者来说,创建 Spark 应用程序可能会令人望而却步。

  • 与数据库的集成一直有些滞后,严重依赖 Hive。Spark 开发团队已经开始通过引入 SparkSQL 和数据框 RDD 来解决这些限制。

0xdata Sparkling Water

Sparkling Water是一个将0xdata H2O与 Spark 集成并补充 MLlib 的倡议[12:16]。0xdata 的 H2O 是一个非常快速的开源内存平台,用于处理非常大的数据集的机器学习(0xdata.com/product/)。该框架值得提及的原因如下:

  • 它拥有 Scala API

  • 它完全致力于机器学习和预测分析

  • 它利用了 H2O 的框架数据表示和 Spark 的内存聚类

H2O 在其他优点中拥有广泛的广义线性模型和梯度提升分类的实现。其数据表示由分层 数据框 组成。数据框是可能与其他框共享向量的容器。每个向量由 数据块 组成,而数据块本身是 数据元素 的容器 [12:17]。在撰写本文时,Sparkling Water 处于测试版本。

摘要

这完成了使用 Scala 构建的常见可扩展框架的介绍。在几页纸上描述框架,如 Akka 和 Spark,以及新的计算模型如 Actors、futures 和 RDDs,是非常具有挑战性的。本章应被视为进一步探索这些框架在单主机和大型部署环境中的能力的邀请。

在最后一章中,我们学习了:

  • 异步并发的好处

  • Actor 模型的基本原理以及使用阻塞或回调模式组合 futures

  • 如何实现一个简单的 Akka 集群以提升分布式应用程序的性能

  • Spark 的弹性分布式数据集的易用性和卓越性能以及内存持久化方法

附录 A. 基本概念

机器学习算法在大量使用线性代数和优化技术。详细描述线性代数、微积分和优化算法的概念和实现,将会给本书增加显著复杂性,并使读者从机器学习的本质中分心。

附录列出了书中提到的线性代数和优化的一些基本元素。它还总结了编码实践,并使读者熟悉金融分析的基本知识。

Scala 编程

这里是本书中使用的部分编码实践和设计技术的列表。

库和工具列表

预编译的 Scala for Machine Learning 代码是 ScalaMl-2.11-0.99.jar,位于 $ROOT/project/target/scala-2.11 目录。并非所有章节都需要所有库。以下是列表:

  • 所有章节都需要 Java JDK 1.7 或 1.8

  • 所有章节都需要 Scala 2.10.4 或更高版本

  • Scala IDE for Eclipse 4.0 或更高版本

  • IntelliJ IDEA Scala 插件 13.0 或更高版本

  • sbt 0.13 或更高版本

  • 对于 第三章,数据预处理,第四章,无监督学习,和 第六章,回归和正则化,需要 Apache Commons Math 3.5+。

  • JFChart 1.0.7 是第一章, 入门,第二章,Hello World!,第五章,朴素贝叶斯分类器,和第九章,人工神经网络所必需的

  • Iitb CRF 0.2(包括 LBGFS 和 Colt 库)是第七章,序列数据模型所必需的

  • LIBSVM 0.1.6 是第八章,核模型和支持向量机所必需的

  • Akka 框架 2.2 或更高版本是第十二章,可扩展框架所必需的

  • Apache Spark/MLlib 1.3 或更高版本是第十二章,可扩展框架所必需的

  • Apache Maven 3.3 或更高版本(Apache Spark 1.4 或更高版本所必需)

注意

Spark 开发者的注意事项

与 Apache Spark 的 assembly JAR 文件捆绑的 Scala 库和编译器 JAR 文件包含 Scala 标准库和编译器 JAR 文件的版本,可能与现有的 Scala 库冲突(即 Eclipse 默认 ScalaIDE 库)。

lib 目录包含以下与书中使用的第三方库或框架相关的 JAR 文件:colt、CRF、LBFGS 和 LIBSVM。

代码片段格式

为了提高算法实现的可读性,所有非必需代码,如错误检查、注释、异常或导入,都已省略。书中提供的代码片段中丢弃以下代码元素:

  • 注释:

    /**
    This class is defined as …
    */
    // The MathRuntime exception has to be caught here!
    
  • 类参数和方法参数的验证:

    class BaumWelchEM(val lambda: HMMLambda ...) {
    require( lambda != null, "Lambda model is undefined")
    
  • 类限定符,如 finalprivate

    final protected class MLP[T <% Double] …
    
  • 方法限定符和访问控制(finalprivate 等):

    final def inputLayer: MLPLayer
    private def recurse: Unit =
    
  • 序列化:

    class Config extends Serializable { … }
    
  • 部分函数验证:

    val pfn: PartialFunction[U, V]
    pfn.isDefinedAt(u)
    
  • 中间状态验证:

    assert( p != None, " … ")
    
  • Java 风格异常:

    try { … }
    catch { case e: ArrayIndexOutOfBoundsException  => … }
    if (y < EPS)
       throw new IllegalStateException( … )
    
  • Scala 风格异常:

    Try(process(args)) match {
       case Success(results) => …
       case Failure(e) => …
    }
    
  • 非必需注解:

    @inline def mean = { … }
    @implicitNotFound("Conversion $T to Array[Int] undefined")
    @throws(classOfIllegalStateException)
    
  • 日志和调试代码:

    m_logger.debug( …)
    Console.println( … )
    
  • 辅助和非必需方法

最佳实践

封装

在创建 API 时,一个重要的目标是减少对支持辅助类的访问。有两种封装辅助类的方法,如下所示:

  • 包作用域:支持类是具有保护访问权限的一级类

  • 类或对象作用域:支持类嵌套在主类中

本书中的算法遵循第一种封装模式。

类构造函数模板

类的构造函数在伴随对象中使用 apply 定义,并且类具有包作用域(protected):

protected class A[T { … } 
object A {
  def applyT: A[T] = new A(x, y,…)
  def applyT: A[T] = new A(x, y0, …)
}

例如,实现支持向量机的 SVM 类定义如下:

final protected class SVMT <: AnyVal(implicit f: T => Double) 
  extends ITransform[Array[T]](xt) {

SVM伴随对象负责定义与SVM受保护类相关的所有构造函数(实例工厂):

def applyT <: AnyVal(implicit f: T => Double): SVM[T] = 
  new SVMT

伴随对象与案例类的比较

在前面的例子中,构造函数在伴随对象中显式定义。尽管构造函数的调用与案例类的实例化非常相似,但有一个主要区别;Scala 编译器为实例操作生成几个方法,如 equals、copy、hash 等。

案例类应保留用于单状态数据对象(无方法)。

枚举与案例类的比较

在 Scala 中,关于枚举与案例类模式匹配相对优点的讨论相当普遍 [A:1]。作为一个非常一般的指导原则,枚举值可以被视为轻量级的案例类,或者案例类可以被视为重量级的枚举值。

让我们以一个 Scala 枚举为例,该枚举评估scala.util.Random库的均匀分布:

object A extends Enumeration {
  type TA = Value
  val A, B, C = Value
}

import A._
val counters = Array.fill(A.maxId+1)(0)
Range(0, 1000).foreach( _ => Random.nextInt(10) match {
  case 3 => counters(A.id) += 1
  …
  case _ => { }
})

模式匹配与 Java 的switch语句非常相似。

让我们考虑以下使用案例类进行模式匹配的例子,根据输入选择数学公式:

package AA {
  sealed abstract class A(val level: Int)
  case class AA extends A(3) { def f =(x:Double) => 23*x}
  …
}

import AA._
def compute(a: A, x: Double): Double = a match {
   case a: A => a.f(x)
   …
}

模式匹配使用默认的 equals 方法执行,其字节码为每个案例类自动设置。这种方法比简单的枚举更灵活,但代价是额外的计算周期。

使用枚举而非案例类的优点如下:

  • 枚举在单个属性比较时涉及更少的代码

  • 枚举对于 Java 开发者来说更易读。

使用案例类的优点如下:

  • 案例类是数据对象,支持的属性比枚举 ID 更多

  • 模式匹配针对密封类进行了优化,因为 Scala 编译器知道案例的数量

简而言之,你应该使用枚举来表示单值常量,使用案例类来匹配数据对象。

赋值运算符

与 C++不同,Scala 实际上并没有重载运算符。以下是代码片段中使用的少量运算符的定义:

  • +=:这会将一个元素添加到集合或容器中

  • +:这是对相同类型的两个元素求和

不变分类器的设计模板

本书描述的机器学习算法使用了以下设计模式和组件:

  • 分类器的配置和调整参数集在继承自Config的类(即SVMConfig)中定义。

  • 分类器实现了一个ITransform类型的单调数据转换,模型隐式地从训练集(即SVM[T])生成。分类器至少需要三个参数,如下所示:

    • 执行训练和分类任务的配置

    • Vector[T]类型的输入数据集,xt

    • 标签或expected值的向量

  • Model 继承的类型模型。构造函数负责通过训练创建模型(即,SVMModel)。

让我们看看以下图示:

不可变分类器的设计模板

分类器的通用 UML 类图

例如,支持向量机包的关键组件是分类器 SVMs:

final protected class SVMT <: AnyVal(implicit f: T => Double)
  extends ITransform[Array[T]](xt) with Monitor[Double] {

  type V = 
  val model: Option[SVMModel] = { … }
  override def |> PartialFunction[Array[T], V]
  …
}

训练集是通过将输入数据集 xt 与标签或预期值 expected 结合或压缩来创建的。一旦训练和验证,模型就可以用于预测或分类。

这种设计的主要优势是减少了分类器的生命周期:一个模型要么被定义,可供分类使用,要么没有被创建。

配置和模型类实现如下:

final class SVMConfig(val formulation: SVMFormulation, 
    val kernel: SVMKernel, 
    val svmExec: SVMExecution) extends Config

class SVMModel(val svmmodel: svm_model) extends Model

注意

实施注意事项

为了可读性,本书中的大多数实际示例都省略了验证阶段。

工具类

数据提取

CSV 文件是最常用的格式,用于存储历史财务数据。它是本书中导入数据的默认格式。数据源依赖于 DataSourceConfig 配置类,如下所示:

case class DataSourceConfig(pathName: String, normalize: Boolean, 
     reverseOrder: Boolean, headerLines: Int = 1)

DataSourceConfig 类的参数如下:

  • pathName:如果参数是文件或包含多个输入数据文件的目录,这是要加载的数据文件的相对路径名。大多数文件是 CSV 文件。

  • normalize:这是一个标志,用于指定数据是否需要标准化到 [0, 1]。

  • reverseOrder:这是一个标志,用于指定文件中的数据顺序是否需要反转(例如,时间序列),如果其值为 true

  • headerLines:这指定了列标题和注释的行数。

数据源 DataSource 使用显式配置 DataSourceConfig 实现了 ETransform 类型的数据转换,如第二章中 单调数据转换 部分所述,第二章, Hello World!

final class DataSource(config: DataSourceConfig,
    srcFilter: Option[Fields => Boolean]= None)
  extends ETransformDataSourceConfig {

  type Fields = Array[String]
  type U = List[Fields => Double]
  type V = XVSeries[Double]
  override def |> : PartialFunction[U, Try[V]] 
  ...
}

srcFilter 参数指定了某些行字段的过滤器或条件,以跳过数据集(即,缺失数据或不正确的格式)。作为一个明确的数据转换,DataSource 类的构造函数必须初始化 |> 提取方法的 U 输入类型和 V 输出类型。该方法从字面值行提取提取器到双精度浮点值:

override def |> : PartialFunction[U, Try[V]] = {
  case fields: U if(!fields.isEmpty) =>load.map(data =>{ //1
    val convert = (f: Fields =>Double) => data._2.map(f(_))
    if( config.normalize)  //2
      fields.map(t => new MinMaxDouble) //3
           .normalize(0.0, 1.0).toArray ).toVector //4
    else fields.map(convert(_)).toVector
  })
}

数据通过 load 辅助方法(行 1)从文件中加载。如果需要(行 2),数据将被标准化,通过使用 MinMax 类的实例将每个字面量转换为浮点值(行 3)。最后,MinMax 实例将浮点值序列进行标准化(行 4)。

DataSource 类实现了一系列重要的方法,这些方法在在线源代码中有文档记录。

数据源

书中的示例依赖于使用 CSV 格式的三个不同的财务数据源:

  • YahooFinancials:这是用于历史股票和 ETF 价格的雅虎模式

  • GoogleFinancials:这是用于历史股票和 ETF 价格的谷歌模式

  • Fundamentals:这是用于基本财务分析比率(CSV 文件)

让我们用一个例子来说明使用 YahooFinancials 从数据源提取的过程:

object YahooFinancials extends Enumeration {
   type YahooFinancials = Value
   val DATE, OPEN, HIGH, LOW, CLOSE, VOLUME, ADJ_CLOSE = Value
   val adjClose = ((s:Array[String]) =>
        s(ADJ_CLOSE.id).toDouble)  //5
   val volume =  (s: Fields) => s(VOLUME.id).toDouble
   …
   def toDouble(value: Value): Array[String] => Double = 
       (s: Array[String]) => s(value.id).toDouble
}

让我们看看一个 DataSource 转换的应用示例:从雅虎财经网站加载历史股票数据。数据以 CSV 格式下载。每一列都与一个提取函数相关联(行 5):

val symbols = ArrayString  //6
val prices = symbols
       .map(s => DataSource(s"$path$s.csv",true,true,1))//7
       .map( _ |> adjClose ) //8

需要下载历史数据的股票列表被定义为符号数组(行 6)。每个符号都与一个 CSV 文件相关联(即 CSCO => resources/CSCO.csv)(行 7)。最后,调用 YahooFinancials 提取器的 adjClose 价格(行 8)。

从谷歌财经页面提取的财务数据格式与雅虎财经页面使用的格式相似:

object GoogleFinancials extends Enumeration {
   type GoogleFinancials = Value
   val DATE, OPEN, HIGH, LOW, CLOSE, VOLUME = Value
   val close = ((s:Array[String]) =>s(CLOSE.id).toDouble)//5
   …
}

YahooFinancialsYahooFinancialsFundamentals 类在在线可用的源代码中实现了许多方法。

文档提取

DocumentsSource 类负责提取文本文档或文本文件列表的日期、标题和内容。该类不支持 HTML 文档。DocumentsSource 类实现了 ETransform 类型的单子数据转换,并显式配置了 SimpleDataFormat 类型:

class DocumentsSource(dateFormat: SimpleDateFormat,
    val pathName: String) 
  extends ETransformSimpleDateFormat {

 type U = Option[Long] //2
 type V = Corpus[Long]  //3

 override def |> : PartialFunction[U, Try[V]] = { //4
    case date: U if (filesList != None) => 
      Try( if(date == None ) getAll else get(date) )
 }
 def get(t: U): V = getAll.filter( _.date == t.get)
 def getAll: V  //5
 ...
}

DocumentsSource 类接受两个参数:与文档关联的日期格式以及文档所在的路径名称(行 1)。作为一个显式数据转换,DocumentsSource 类的构造函数必须初始化 U 输入类型(行 2)为日期,并将其转换为 Long 输出类型,并将 V 输出类型转换为 Corpus 以提取 |> 方法。

|> 提取器生成与特定日期关联的语料库,并将其转换为 Long 类型(行 4)。getAll 方法负责提取或排序文档(行 5)。

getAll 方法的实现以及其他 DocumentsSource 类的方法在在线可用的文档源代码中有描述。

DMatrix 类

一些判别性学习模型需要在矩阵的行和列上执行操作。DMatrix 类简化了对列和行的读写操作:

class DMatrix(val nRows: Int, val nCols: Int, 
     val data: DblArray) {
 def apply(i: Int, j: Int): Double = data(i*nCols+j)
 def row(iRow: Int): DblArray = { 
   val idx = iRow*nCols
   data.slice(idx, idx + nCols)
 }
 def col(iCol: Int): IndexedSeq[Double] =
   (iCol until data.size by nCols).map( data(_) )
 def diagonal: IndexedSeq[Double] = 
    (0 until data.size by nCols+1).map( data(_))
 def trace: Double = diagonal.sum
  …
}

apply 方法返回矩阵的一个元素。row 方法返回一个行数组,col 方法返回列元素的索引序列。diagonal 方法返回对角元素的索引序列,trace 方法对对角元素求和。

DMatrix 类支持元素、行和列的归一化;转置;以及元素、列和行的更新。DMatrix 类实现了许多方法,这些方法在在线源代码中有详细文档。

计数器

Counter 类实现了一个通用的可变计数器,其中键是一个参数化类型。键的出现次数由一个可变哈希映射管理:

class Counter[T] extends mutable.HashMap[T, Int] {
  def += (t: T): type.Counter = super.put(t, getOrElse(t, 0)+1) 
  def + (t: T): Counter[T] = { 
   super.put(t, getOrElse(t, 0)+1); this 
  }
  def ++ (cnt: Counter[T]): type.Counter = { 
    cnt./:(this)((c, t) => c + t._1); this
  }
  def / (cnt: Counter[T]): mutable.HashMap[T, Double] = map { 
    case(str, n) => (str, if( !cnt.contains(str) ) 
      throw new IllegalStateException(" ... ")
        else n.toDouble/cnt.get(str).get )
  }
  …
}

+= 运算符更新 t 键的计数器并返回自身。+ 运算符更新并复制更新的计数器。++ 运算符使用另一个计数器更新此计数器。/ 运算符将每个键的计数除以另一个计数器的计数。

Counter 类实现了一组重要的方法,这些方法在在线源代码中有详细文档。

监控器

Monitor 类有两个目的:

  • 它使用 showerror 方法存储日志信息和错误消息。

  • 它收集并显示与算法的递归或迭代执行相关的变量。

数据在每个迭代或递归时收集,然后以迭代作为 x 轴值的时间序列形式显示:

trait Monitor[T] {
  protected val logger: Logger
  lazy val _counters = 
      new mutable.HashMap[String, mutable.ArrayBuffer[T]]()

  def counters(key: String): Option[mutable.ArrayBuffer[T]]
  def count(key: String, value: T): Unit 
  def display(key: String, legend: Legend)
      (implicit f: T => Double): Boolean
  def show(msg: String): Int = DisplayUtils.show(msg, logger)
  def error(msg: String): Int = DisplayUtils.error(msg, logger)
  ...
}

counters 方法返回与特定键相关联的数组。count 方法更新与键关联的数据。display 方法绘制时间序列。最后,showerror 方法将信息和错误消息发送到标准输出。

Monitor 类的实现源代码的文档可在网上找到。

数学

本节简要描述了本书中使用的一些数学概念。

线性代数

机器学习中使用的许多算法,如凸损失函数的最小化、主成分分析或最小二乘回归,不可避免地涉及矩阵的操作和变换。关于这个主题有许多优秀的书籍,从价格低廉的 [A:2] 到复杂的 [A:3]。

QR 分解

QR 分解(或 QR 分解)是将矩阵 A 分解为正交矩阵 Q 和上三角矩阵 R 的乘积。因此,A=QRQ^T Q=I [A:4]。

如果 A 是一个实数、平方且可逆的矩阵,则分解是唯一的。在矩形矩阵 A 的情况下,mn 列,且 m > n,分解实现为两个矩阵向量的点积:A = [Q[1], Q[2]].[R[1], R[2]]^T,其中 Q[1] 是一个 mn 列的矩阵,Q[2] 是一个 mn 列的矩阵,R[1] 是一个 nn 列的上三角矩阵,R[2] 是一个 mn 列的零矩阵。

QR 分解是一种可靠的方法,用于解决方程数量(行)超过变量数量(列)的大规模线性方程组。对于具有 m 维度和 n 个观察值的训练集,其渐近计算时间复杂度为 O(mn²-n³/3)

它用于最小化普通最小二乘回归的损失函数(参见第六章中的普通最小二乘回归部分,回归和正则化)。

LU 分解

LU 分解是一种用于求解矩阵方程A.x = b的技术,其中A是一个非奇异矩阵,而xb是两个向量。该技术包括将原始矩阵A分解为简单矩阵的乘积A= A[1]A[2]…A[n]

  • 基本 LU 分解:这定义矩阵A为下三角单位矩阵L和上三角矩阵U的乘积。因此,A=LU

  • 带置换的 LU 分解定义矩阵A为置换矩阵P、下三角单位矩阵L和上三角矩阵U的乘积。因此,A=PLU

LDL 分解

实矩阵的 LDL 分解定义一个实正矩阵A为下三角单位矩阵L、对角矩阵DL的转置矩阵LT*的乘积,即*LT。因此,A=LDL^T

Cholesky 分解

实矩阵的Cholesky 分解(或Cholesky 分解)是 LU 分解的特殊情况[A:4]。它将正定矩阵A分解为下三角矩阵L和其共轭转置LT*的乘积。因此,*A=LLT

Cholesky 分解的计算时间复杂度为O(mn²),其中m是特征数(模型参数)的数量,n是观测数的数量。Cholesky 分解用于线性最小二乘卡尔曼滤波(参见第三章中的递归算法部分,数据预处理)和非线性拟牛顿优化器。

奇异值分解

实矩阵的奇异值分解SVD)定义一个mn列的实矩阵Am个实方单位矩阵Umn个矩形对角矩阵Σ和实矩阵V的转置矩阵VT*的乘积。因此,*A=UΣVT*。

UV矩阵的列是正交基,对角矩阵Σ的值是奇异值[A:4]。对于n个观测数和m个特征,奇异值分解的计算时间复杂度为O(mn²-n³)。奇异值分解用于最小化总最小二乘和求解齐次线性方程。

特征值分解

实方阵A的特征分解是标准分解,Ax = λx*。

λ 是与向量 x 对应的特征值(标量)。然后定义 n×n 矩阵 AA = QDQ^TQ 是包含特征向量的方阵,D 是对角矩阵,其对角线元素是与特征向量关联的特征值 [A:5] 和 [A:6]。特征值分解用于主成分分析(参考第四章“无监督学习”中的主成分分析部分,Chapter 4,无监督学习)。

代数和数值库

除了在 第三章 的 数据预处理、第五章 的 朴素贝叶斯分类器、第六章 的 回归和正则化 以及 第十二章 的 可扩展框架 中使用的 Apache Commons Math 之外,还有许多开源代数库可供开发者作为 API 使用。它们如下所示:

  • jBlas 1.2.3 (Java) 由 Mikio Braun 创建,遵循 BSD 修订版许可。这个库为 Java 和 Scala 开发者提供了一个高级的 Java 接口,用于 BLASLAPACK (github.com/mikiobraun/jblas).

  • Colt 1.2.0 (Java) 是一个高性能科学库,在 CERN 开发,遵循欧洲核研究组织许可 (acs.lbl.gov/ACSSoftware/colt/).

  • AlgeBird 2.10 (Scala) 是在 Twitter 开发的,遵循 Apache Public License 2.0 许可。它使用单例和单子定义了抽象线性代数概念。这个库是使用 Scala 进行高级函数式编程的杰出例子 (github.com/twitter/algebird).

  • Breeze 0.8 (Scala) 是一个使用 Apache Public License 2.0 许可最初由 David Hall 创建的数值处理库。它是 ScalaNLP 套件中机器学习和数值计算库的组成部分 (www.scalanlp.org/).

Apache Spark/MLlib 框架捆绑了 jBlas、Colt 和 Breeze。Iitb 框架用于条件随机字段,使用了 Colt 线性代数组件。

注意

Java/Scala 库的替代方案

如果您的应用程序或项目在有限的资源(CPU 和 RAM 内存)下需要高性能数值处理工具,那么如果便携性不是约束条件,使用 C/C++ 编译的库是一个极好的替代方案。二进制函数通过 Java 本地接口 (JNI) 访问。

一阶谓词逻辑

命题逻辑公理或命题的表述。命题有几种形式化的表示方法:

  • 名词-动词-形容词:例如,股票价格方差 超过 0.76损失函数最小化 不收敛

  • 实体值 = 布尔值:例如,股票价格方差 大于 0.76 = true损失函数最小化 收敛 = false

  • 变量操作值:例如,股票价格方差 > 0.76最小化损失函数 != 收敛

命题逻辑受布尔代数规则的约束。让我们考虑三个命题:PQR以及三个布尔算子NOTANDOR

  • NOT (NOT P) = P

  • P AND false = falseP AND true = PP or false = P,和P or true = P

  • P AND Q = Q AND PP OR Q = Q OR P

  • P AND (Q AND R) = (P AND Q) AND R

一阶谓词逻辑,也称为一阶谓词演算,是命题逻辑的量化[A:7]。一阶逻辑最常见的形式如下:

  • 规则(例如,IF P THEN action

  • 存在性算子

一阶逻辑用于描述学习分类系统中的分类器(参见第十一章的XCS 规则部分,强化学习)。

雅可比和海森矩阵

让我们考虑一个具有n个变量x[i]m个输出y[j]的函数,使得f: {x[i]} -> {y[j] =fj}

雅可比矩阵[A:8]是一个连续、微分函数输出值的首次偏导数的矩阵:

雅可比和海森矩阵

海森矩阵是一个连续、二阶可微函数的二阶偏导数的方阵:

雅可比和海森矩阵

以下是一个例子:

雅可比和海森矩阵

优化技术概述

与线性代数算法相关的相同评论适用于优化。对这些技术进行深入研究会使本书变得不切实际。然而,优化对于机器学习算法的效率和,在一定程度上,准确性至关重要。在这个领域的一些基本知识对于构建适用于大数据集的实用解决方案大有裨益。

梯度下降法

最速下降

最速下降法(或梯度下降法)是用于寻找任何连续、可微函数F的局部最小值或任何定义、可微、凸函数的全局最小值的最简单技术之一[A:9]。在迭代t+1时,向量或数据点x[t+1]的值是从前一个值x[t]使用函数F的梯度 F和斜率γ计算得出的:

最速下降

最陡下降算法用于解决非线性方程组和逻辑回归中的损失函数最小化问题(参见第六章的数值优化部分[第六章. 回归和正则化],回归和正则化),在支持向量机分类器中(参见第八章的不可分情况 – 软间隔部分[第八章. 核模型和支持向量机],核模型和支持向量机),以及在多层感知器中(参见第九章的训练和分类部分[第九章. 人工神经网络],人工神经网络)。

共轭梯度

共轭梯度用于解决无约束优化问题和线性方程组。它是正定、对称的平方矩阵 LU 分解的替代方案。方程 Ax = b 的解 x 被扩展为 n 个基正交方向 p[i](或共轭方向)的加权求和:

共轭梯度

通过计算第 i 个共轭向量 p[i] 然后计算系数 α[i] 来提取解 x

随机梯度下降

随机梯度方法是最陡下降的一种变体,通过将目标函数 F 定义为可微分的基函数 f[i] 的和来最小化凸函数:

随机梯度下降

在迭代 t+1 时,解 x[t+1] 是从迭代 t 时的值 x[t],步长(或学习率) α,以及基函数梯度的和 [A:10] 计算得出的。随机梯度下降通常比其他梯度下降或拟牛顿方法在收敛到凸函数的解时更快。随机梯度下降用于逻辑回归、支持向量机和反向传播神经网络。

随机梯度对于具有大量数据集的判别模型特别适合 [A:11]。Spark/MLlib 广泛使用随机梯度方法。

注意

批量梯度下降

批量梯度下降在第一章的让我们试试车部分下的第 5 步 – 实现分类器中引入并实现,入门

近似牛顿算法

拟牛顿算法是牛顿法寻找最大化或最小化函数 F(一阶导数为零)的向量或数据点的值的变体 [A:12]。

牛顿法是一种众所周知且简单的优化方法,用于求解方程F(x) = 0的解,其中F是连续且二阶可微的。它依赖于泰勒级数展开来近似函数F,使用变量∆x = x[t+1]-x[t]的二次近似来计算下一次迭代的值,使用一阶F'和二阶F"导数:

拟牛顿算法

与牛顿法不同,拟牛顿法不需要计算目标函数的二阶导数,即 Hessian 矩阵;只需对其进行近似[A:13]。有几种方法可以近似计算 Hessian 矩阵。

BFGS

Broyden-Fletcher-Goldfarb-ShannoBGFS)是一种用于解决无约束非线性问题的拟牛顿迭代数值方法。在迭代t时,使用前一次迭代t的值来近似 Hessian 矩阵H[t+1],即H[t+1]=H[t] + U[t] + V[t],应用于方向p[t]的牛顿方程:

BFGS

BFGS 用于条件随机场和 L[1]和 L[2]回归的成本函数的最小化。

L-BFGS

BFGS 算法的性能与在内存(UV)中缓存 Hessian 矩阵近似的成本相关,这会导致高内存消耗。

有限内存 Broyden-Fletcher-Goldfarb-ShannoL-BFGS)算法是 BFGS 的一种变体,它使用最少的计算机 RAM。该算法在迭代t时维护∆x[t]和梯度∆G[t]的最后m个增量更新,然后计算这些值用于下一个步骤t+1

L-BFGS

它由 Apache Commons Math 3.3+、Apache Spark/MLlib 1.0+、Colt 1.0+和 Iiitb CRF 库支持。L-BFGS 用于条件随机场中损失函数的最小化(参见第七章中的条件随机场部分,序列数据模型)。

非线性最小二乘法

让我们考虑非线性函数y = F(x, w)的最小二乘法的经典最小化,其中w[i]是观测值y, x[i]的参数。目标是使残差平方和r[i]最小化,如下所示:

非线性最小二乘法最小化

高斯-牛顿

高斯-牛顿技术是牛顿法的一种推广。该技术通过在迭代t+1时更新参数w[t+1]来解决非线性最小二乘法,使用一阶导数(或雅可比矩阵):

高斯-牛顿

高斯-牛顿算法在逻辑回归中使用(参见第六章中的逻辑回归部分,回归和正则化)。

Levenberg-Marquardt

Levenberg-Marquardt 算法是解决非线性最小二乘和曲线拟合问题的 Gauss-Newton 技术的替代方案。该方法包括将梯度(雅可比)项添加到残差r[i]中,以近似最小二乘误差:

Levenberg-Marquardt

Levenberg-Marquardt 算法用于逻辑回归的训练(参考第六章中的逻辑回归部分,回归和正则化)。

拉格朗日乘数

拉格朗日乘数方法是一种优化技术,用于在满足等式约束的情况下找到多元函数的局部最优值[A:14]。问题表述为在 g(x) = c(c 是常数,x 是变量或特征向量)的约束下最大化 f(x)

这种方法引入了一个新变量λ,将约束g整合到一个称为拉格朗日函数 (x, λ)的函数中。让我们记∇ℒ,它是关于变量x[i]λ的梯度。拉格朗日乘数通过最大化来计算:

Lagrange multipliers

以下是一个示例:

Lagrange multipliers

拉格朗日乘数用于在不可分情况下最小化线性支持向量机的损失函数(参考第八章中的不可分情况 – 软间隔情况部分,核模型和支持向量机)。

动态规划概述

动态规划的目的是将一个优化问题分解成一系列称为子结构的步骤[A:15]。动态规划适用于两种类型的问题。

全局优化问题的解可以分解为其子问题的最优解。子问题的解称为最优子结构。贪婪算法或计算图的最小跨度是分解为最优子结构的例子。这些算法可以递归或迭代地实现。

如果子问题的数量较少,则将全局问题的解递归地应用于子问题。这种方法被称为使用重叠子结构的动态规划。在隐藏马尔可夫模型上的前向-后向传递、维特比算法(参考第七章中的维特比算法部分,顺序数据模型),或在多层感知器中的误差反向传播(参考第九章中的步骤 2 – 误差反向传播部分,人工神经网络)是重叠子结构的良好示例。

动态规划解决方案的数学公式是针对它试图解决的问题而特定的。动态规划技术也常用于诸如汉诺塔等数学谜题。

财务 101

本书中的练习与历史财务数据相关,需要读者对金融市场和报告有一定的基本理解。

基本分析

基本分析是一套用于评估证券(股票、债券、货币或商品)的技术,它涉及通过审查宏观和微观的金融和经济报告来尝试衡量其内在价值。基本分析通常用于通过使用各种财务比率来估计股票的最佳价格。

本书使用了大量的财务指标。以下是常用指标的定义[A:16]:

  • 每股收益(EPS):这是净利润与流通在外股票数量的比率。

  • 市盈率(PE):这是每股市场价格与每股收益的比率。

  • 市销率(PS):这是每股市场价格与毛销售额(或收入)的比率。

  • 市净率(PB):这是每股市场价格与每股资产负债表价值的比率。

  • 市盈率/增长率(PEG):这是每股价格/市盈率(PE)与每股收益年增长率的比率。

  • 营业利润:这是营业收入与营业费用的差额。

  • 净销售额:这是收入或毛销售额与商品成本或销售成本的差额。

  • 营业利润率:这是营业利润与净销售额的比率。

  • 净利率:这是净利润与净销售额(或净收入)的比率。

  • 空头头寸:这是已售出但尚未平仓的股票数量。

  • 空头头寸比率:这是空头头寸与流通在外股票总数的比率。

  • 每股现金:这是每股现金价值与每股市场价格的比率。

  • 派息比率:这是扣除非常项目后,以现金股息形式支付给普通股东的初级/基本每股收益的百分比。

  • 年度股息收益率:这是过去 12 个月滚动期间支付的股息总额与当前股价的比率。包括定期和额外股息。

  • 股息保障倍数:这是最近 12 个月连续 12 个月中,扣除非常项目后可用于普通股股东的收入的比率,与支付给普通股东的股息总额的比率,以百分比表示。

  • 国内生产总值(GDP):这是衡量一个国家经济产出的综合指标。它实际上衡量了商品生产和服务的交付所增加的价值总和。

  • 消费者价格指数CPI):这是一个衡量消费者价格变动情况的指标,由劳工统计局使用任意商品和服务篮子来评估通货膨胀趋势。

  • 联邦基金利率:这是银行在联邦储备持有的余额所交易的利率。这些余额被称为联邦基金。

技术分析

技术分析是一种通过研究从价格和成交量中得出的过去市场信息来预测任何给定证券价格走势的方法论。用更简单的话说,它是研究价格活动和价格模式,以识别交易机会的方法 [A:17]。股票、商品、债券或金融期货的价格反映了市场参与者处理过的关于该资产的公开信息

术语

  • 熊市或熊市头寸:这是通过赌证券价格将下跌来尝试获利的。

  • 牛市或牛市头寸:这是通过赌证券价格将上涨来尝试获利的。

  • 多头头寸:这与牛市相同。

  • 中性头寸:这是通过赌证券价格不会显著变化来尝试获利的。

  • 振荡器:这是一种使用某些统计公式来衡量证券价格动量的技术指标。

  • 超买:当证券价格因一个或多个交易信号或指标显示上升过快时,该证券被认为是超买的。

  • 超卖:当证券价格因一个或多个交易信号或指标显示下降过快时,该证券被认为是超卖的。

  • 相对强弱指数RSI):这是一个计算在平均交易时段中收盘价高于开盘价的交易次数的平均值,与在平均交易时段中收盘价低于开盘价的交易次数的平均值。该值在[0, 1]或[0, 100%]范围内归一化。

  • 阻力位:这是证券价格范围的最高点。价格一旦达到阻力位就会回落。

  • 空头头寸:这与熊市相同。

  • 支撑位:这是在一定时期内证券价格范围的最低点。价格一旦达到支撑位就会反弹。

  • 技术指标:这是从证券价格(可能还有其交易量)中派生出的变量。

  • 交易区间:在一定时期内,证券的交易区间是该时期最高价和最低价之间的差值。

  • 交易信号:这是一个当技术指标达到预定义的值(向上或向下)时触发的信号。

  • 波动性:这是指在一定时期内,证券价格的方差或标准差。

交易数据

从谷歌或雅虎财经页面提取的原始交易数据包括以下内容:

  • adjClose(或收盘价):这是交易结束时的证券调整后或非调整后的价格

  • 开盘价:这是交易开始时证券的价格

  • 高点:这是交易期间证券的最高价格

  • 低点:这是交易期间证券的最低价格

让我们看一下以下图表:

交易数据

我们可以从原始交易数据中推导出以下指标:

  • 价格波动性:volatility = 1.0 – high/low

  • 价格变动:vPrice = adjClose – open

  • 连续两个交易日的价格差异(或变化):dPrice = adjClose – prevClose = adjClose(t) – adjClose(t-1)

  • 连续两个交易日的成交量差异:dVolume = volume(t)/volume(t-1) – 1.0

  • 连续两个交易日的波动性差异:dVolatility = volatility(t)/volatility(t-1) – 1.0

  • 过去 T 个交易日相对价格变化:rPrice = price(t)/average(price over T) – 1.0

  • 过去 T 个交易日相对成交量变化:rVolume = volume(t)/average(volume over T) – 1.0

  • 过去 T 个交易日相对波动性变化:rVolatility = volatility(t)/average(volatility over T) – 1.0

交易信号和策略

目的是创建一个由价格和成交量导出的变量 x,即 x= f (price, volume),然后为 op 是比较布尔运算符(如 >=)的 x op c 生成谓词,该运算符将 x 的值与预定的阈值 c 进行比较。

让我们考虑从价格衍生出的最常见的技术指标之一:相对强弱指数 RSI 或归一化 RSI nRSI,其公式在此提供作为参考:

注意

相对强弱指数

定义为 T 个交易期的 RSI,其开盘价为 p[o],收盘价为 p[c]

交易信号和策略

交易信号是一个使用技术指标 nRSIT < 0.2 的谓词。在交易术语中,对于谓词为真的任何时间周期 t,都会发出信号:

交易信号和策略

使用相对强弱指数可视化超卖和超买位置

交易者通常不会依赖单一的交易信号来做出合理的决策。

例如,如果 G 是黄金的价格,I[10] 是 10 年期国债的当前利率,而 RSI[sp500] 是标准普尔 500 指数的相对强弱指数,那么我们可以得出结论,美元兑日元的汇率增加在以下交易策略下最大化:{G < $1170 and I[10] > 3.9% and RSI[sp500] > 0.6 and RSI[sp500] < 0.8}

价格模式

技术分析假设历史价格包含一些虽然嘈杂但重复出现的模式,可以使用统计方法发现。本书中最常用的模式是趋势、支撑和阻力水平[A:18],如下图表所示:

价格模式

技术分析中趋势、支撑和阻力水平的示意图

期权交易

期权是一种合约,赋予买方在特定价格和特定日期或之前购买或出售证券的权利,但没有义务[A:19]。

如下所述,有两种期权类型:看涨期权和看跌期权:

  • 看涨期权赋予持有者在特定时间段内以特定价格购买证券的权利。看涨期权的买方预期在期权到期前,证券的价格将显著高于行权价格。

  • 看跌期权赋予持有者在特定时间段内以特定价格出售证券的权利。看跌期权的买方预期在期权到期前,股票的价格将低于行权价格。

让我们考虑一个以$23 的行权价格购买 100 股的看涨期权合约,总成本为$270(每份期权$2.7)。看涨期权持有者的最大损失是损失溢价或$270,当期权到期时。然而,利润可能是几乎无限的。如果看涨期权到期时证券的价格达到$36,所有者将获得利润($36 - $23)*100 - $270 = $1030。投资的回报率是 1030/270 = 380%。购买然后出售股票将产生 50%的投资回报率,即 36/24 - 1 = 50%。这个例子很简单,没有考虑交易费或保证金成本[A:20]:

让我们看一下以下图表:

期权交易

看涨期权定价的示意图

金融数据来源

有许多金融数据来源可用于实验机器学习和验证模型[A:21]:

推荐的在线课程

参考文献

第三部分. 模块 3

掌握 Scala 机器学习

利用 Scala、Spark 和 Hadoop 的强大工具提高您在高效数据分析与数据处理方面的技能

第一章. 探索性数据分析

在我深入探讨本书后面更复杂的数据分析方法之前,我想停下来谈谈基本的数据探索性任务,几乎所有数据科学家至少花费 80-90% 的生产力时间在这些任务上。仅数据准备、清洗、转换和合并数据本身就是一个价值 440 亿美元/年的产业(《大数据时代的数据准备》,作者:Federico CastanedoBest Practices for Data IntegrationO'Reilly Media2015)。鉴于这一事实,人们最近才开始在开发最佳实践的科学、建立良好的习惯、文档和整个数据准备过程的教学材料上投入更多时间,这确实令人惊讶(《美丽的数据:优雅数据解决方案背后的故事》,由 Toby SegaranJeff Hammerbacher 编著,O'Reilly Media2009 以及 Sandy Ryza 等人所著的 Advanced Analytics with Spark: Patterns for Learning from Data at ScaleO'Reilly Media2015*)。

很少有数据科学家会就特定的工具和技术达成一致意见——进行探索性数据分析有多种方法,从 Unix 命令行到使用非常流行的开源和商业 ETL 和可视化工具。本章的重点是如何使用 Scala 和基于笔记本电脑的环境来利用通常被称为编程的函数式范式的技术。正如我将讨论的,这些技术可以转移到使用 Hadoop/Spark 的分布式机器系统上的探索性分析中。

函数式编程与这有什么关系?Spark 是用 Scala 开发的,这并非没有原因。许多位于函数式编程基础之上的基本原则,如惰性评估、不可变性、无副作用、列表推导和单子,非常适合在分布式环境中处理数据,特别是在对大数据进行数据准备和转换任务时。得益于抽象,这些技术在本地工作站或笔记本电脑上也能很好地工作。如前所述,这并不妨碍我们在连接到分布式存储/处理节点集群的现代笔记本电脑上处理数十 TB 的非常大的数据集。我们可以一次处理一个主题或关注领域,但通常我们甚至不需要对数据集进行适当的分区采样或过滤。我们将使用 Scala 作为我们的主要工具,但在需要时也会求助于其他工具。

虽然 Scala 在某种意义上是完整的,即其他语言可以实现的任何内容都可以在 Scala 中实现,但 Scala 本质上是一种高级语言,甚至是一种脚本语言。你不必处理数据结构和算法实现中的低级细节,这些细节在 Java 或 C++等语言中已经由大量的应用程序和时间测试过——尽管 Scala 今天有自己的集合和一些基本的算法实现。具体来说,在本章中,我将专注于使用 Scala/Spark 仅进行高级任务。

在本章中,我们将涵盖以下主题:

  • 安装 Scala

  • 学习简单的数据探索技术

  • 学习如何对原始数据集进行下采样以加快周转速度

  • 讨论在 Scala 中实现基本数据转换和聚合的实现

  • 熟悉大数据处理工具,如 Spark 和 Spark Notebook

  • 获取一些基本数据集可视化的代码

Scala 入门

如果你已经安装了 Scala,你可以跳过这一段。你可以从www.scala-lang.org/download/获取最新的 Scala 下载。我在 Mac OS X El Capitan 10.11.5 上使用了 Scala 版本 2.11.7。你可以使用你喜欢的任何其他版本,但可能与其他包(如 Spark)存在一些兼容性问题,这是开源软件中常见的问题,因为技术的采用通常落后于几个发布版本。

小贴士

在大多数情况下,你应该尝试保持推荐版本之间的精确匹配,因为版本之间的差异可能导致模糊的错误和漫长的调试过程。

如果你正确安装了 Scala,在输入scala后,你应该看到以下类似的内容:

[akozlov@Alexanders-MacBook-Pro ~]$ scala
Welcome to Scala version 2.11.7 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_40).
Type in expressions to have them evaluated.
Type :help for more information.

scala>

这是一个 Scala 读取-评估-打印循环REPL)提示。虽然 Scala 程序可以编译,但本章的内容将在 REPL 中进行,因为我们专注于与交互,可能会有一些例外。:help命令提供了 REPL 中可用的某些实用命令(注意开头的冒号):

Scala 入门

分类的字段的不同值

现在,你有一个数据集和一台计算机。为了方便,我为你提供了一个小型的匿名和混淆的点击流数据样本,你可以从github.com/alexvk/ml-in-scala.git获取这个样本。chapter01/data/clickstream目录中的文件包含时间戳、会话 ID 以及一些额外的调用时的事件信息,如 URL、分类信息等。首先要做的事情是对数据进行转换,以找出数据集中不同列的值分布。

图 01-1 展示的截图显示了 gzcat chapter01/data/clickstream/clickstream_sample.tsv.gz | less –U 命令在终端窗口中的数据集输出。列由制表符(^I)分隔。可以注意到,正如许多现实世界的大数据数据集一样,许多值是缺失的。数据集的第一列可识别为时间戳。该文件包含复杂的数据,如数组、结构体和映射,这是大数据数据集的另一个特征。

Unix 提供了一些工具来剖析数据集。可能,lesscutsortuniq 是最常用于文本文件操作的工具。Awksedperltr 可以执行更复杂的转换和替换。幸运的是,Scala 允许你在 Scala REPL 中透明地使用命令行工具,如下面的截图所示:

分类字段的唯一值

图 01-1. less -U Unix 命令的输出作为点击流文件

幸运的是,Scala 允许你在 Scala REPL 中透明地使用命令行工具:

[akozlov@Alexanders-MacBook-Pro]$ scala
…
scala> import scala.sys.process._
import scala.sys.process._
scala> val histogram = ( "gzcat chapter01/data/clickstream/clickstream_sample.tsv.gz"  #|  "cut -f 10" #| "sort" #|  "uniq -c" #| "sort -k1nr" ).lineStream
histogram: Stream[String] = Stream(7731 http://www.mycompany.com/us/en_us/, ?)
scala> histogram take(10) foreach println 
7731 http://www.mycompany.com/us/en_us/
3843 http://mycompanyplus.mycompany.com/plus/
2734 http://store.mycompany.com/us/en_us/?l=shop,men_shoes
2400 http://m.mycompany.com/us/en_us/
1750 http://store.mycompany.com/us/en_us/?l=shop,men_mycompanyid
1556 http://www.mycompany.com/us/en_us/c/mycompanyid?sitesrc=id_redir
1530 http://store.mycompany.com/us/en_us/
1393 http://www.mycompany.com/us/en_us/?cp=USNS_KW_0611081618
1379 http://m.mycompany.com/us/en_us/?ref=http%3A%2F%2Fwww.mycompany.com%2F
1230 http://www.mycompany.com/us/en_us/c/running

我使用了 scala.sys.process 包从 Scala REPL 调用熟悉的 Unix 命令。从输出中,我们可以立即看到我们网店的主要客户对男鞋和跑步感兴趣,并且大多数访客正在使用推荐代码,KW_0611081618

小贴士

当我们开始使用复杂的 Scala 类型和方法时,可能会有人想知道。请稍等,在 Scala 之前已经创建了大量的高度优化的工具,它们对于探索性数据分析来说效率更高。在初始阶段,最大的瓶颈通常是磁盘 I/O 和缓慢的交互性。稍后,我们将讨论更多迭代算法,这些算法通常更占用内存。还要注意,UNIX 管道操作可以在现代多核计算机架构上隐式并行化,就像在 Spark 中一样(我们将在后面的章节中展示)。

已经证明,在输入数据文件上使用压缩,无论是隐式还是显式,实际上可以节省 I/O 时间。这对于(大多数)现代半结构化数据集尤其如此,这些数据集具有重复的值和稀疏的内容。在现代快速的多核计算机架构上,解压缩也可以隐式并行化,从而消除计算瓶颈,除非,可能在硬件中隐式实现压缩的情况下(SSD,在这种情况下我们不需要显式压缩文件)。我们还建议使用目录而不是文件作为数据集的模式,其中插入操作简化为将数据文件放入目录中。这就是大数据 Hadoop 工具(如 Hive 和 Impala)展示数据集的方式。

数值字段的摘要

让我们来看看数值数据,尽管数据集中的大多数列都是分类的或复杂的。总结数值数据的传统方法是五数摘要,它表示中位数或平均值、四分位数范围以及最小值和最大值。我将把中位数和四分位数范围的计算留到 Spark DataFrame 介绍时,因为它使这些计算变得极其简单;但我们可以通过应用相应的运算符在 Scala 中计算平均值、最小值和最大值:

scala> import scala.sys.process._
import scala.sys.process._
scala> val nums = ( "gzcat chapter01/data/clickstream/clickstream_sample.tsv.gz"  #|  "cut -f 6" ).lineStream
nums: Stream[String] = Stream(0, ?) 
scala> val m = nums.map(_.toDouble).min
m: Double = 0.0
scala> val m = nums.map(_.toDouble).sum/nums.size
m: Double = 3.6883642764024662
scala> val m = nums.map(_.toDouble).max
m: Double = 33.0

在多个字段中进行 grep 搜索

有时需要了解某个值在多个字段中的外观——最常见的是 IP/MAC 地址、日期和格式化消息。例如,如果我想查看文件或文档中提到的所有 IP 地址,我需要将上一个示例中的cut命令替换为grep -o -E [1-9][0-9]{0,2}(?:\\.[1-9][0-9]{0,2}){3},其中-o选项指示grep只打印匹配的部分——一个更精确的 IP 地址正则表达式应该是grep –o –E (?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?),但在我和原始的笔记本电脑上大约慢 50%,但原始的正则表达式在大多数实际情况下都适用。我将把它作为一个练习,在书中提供的样本文件上运行这个命令。

基本抽样、分层抽样和一致抽样

我遇到过很多数据从业者,他们轻视抽样。理想情况下,如果一个人能够处理整个数据集,模型只会得到改进。在实践中,这种权衡要复杂得多。首先,一个人可以在抽样集上构建更复杂的模型,尤其是如果模型构建的时间复杂度是非线性的——在大多数情况下,至少是N log(N)*。更快的模型构建周期允许你更快地迭代模型并收敛到最佳方法。在许多情况下,“行动时间”会打败基于完整数据集构建的模型在预测精度上的潜在改进。

抽样可以与适当的过滤相结合——在许多实际情况下,一次关注一个子问题可以更好地理解整个问题域。在许多情况下,这种划分是算法的基础,例如在稍后讨论的决策树中。通常,问题的性质要求你关注原始数据的一个子集。例如,网络安全分析通常关注一组特定的 IP 地址,而不是整个网络,因为它允许更快地迭代假设。如果不在正确的轨道上,将网络中所有 IP 地址的集合包括在内可能会在最初使事情复杂化。

当处理罕见事件,例如 ADTECH 中的点击次数时,以不同的概率对正负案例进行抽样,这有时也被称为过采样,通常能在短时间内带来更好的预测结果。

基本上,采样等同于对每一行数据抛硬币——或者调用随机数生成器。因此,它非常类似于流过滤操作,这里的过滤是在随机数增强列上进行的。让我们考虑以下示例:

import scala.util.Random
import util.Properties

val threshold = 0.05

val lines = scala.io.Source.fromFile("chapter01/data/iris/in.txt").getLines
val newLines = lines.filter(_ =>
    Random.nextDouble() <= threshold
)

val w = new java.io.FileWriter(new java.io.File("out.txt"))
newLines.foreach { s =>
    w.write(s + Properties.lineSeparator)
}
w.close

这一切都很好,但它有以下缺点:

  • 结果文件的行数事先是未知的——尽管平均来说应该是原始文件的 5%

  • 采样的结果是非确定的——很难重新运行此过程进行测试或验证

为了解决第一个问题,我们需要传递一个更复杂的对象给函数,因为我们需要在原始列表遍历期间保持状态,这使得原始算法功能性和并行性降低(这将在稍后讨论):

import scala.reflect.ClassTag
import scala.util.Random
import util.Properties

def reservoirSampleT: ClassTag: Array[T] = {
  val reservoir = new ArrayT
  // Put the first k elements in the reservoir.
  var i = 0
  while (i < k && input.hasNext) {
    val item = input.next()
    reservoir(i) = item
    i += 1
  }

  if (i < k) {
    // If input size < k, trim the array size
    reservoir.take(i)
  } else {
    // If input size > k, continue the sampling process.
    while (input.hasNext) {
      val item = input.next
      val replacementIndex = Random.nextInt(i)
      if (replacementIndex < k) {
        reservoir(replacementIndex) = item
      }
      i += 1
    }
    reservoir
  }
}

val numLines=15
val w = new java.io.FileWriter(new java.io.File("out.txt"))
val lines = io.Source.fromFile("chapter01/data/iris/in.txt").getLines
reservoirSample(lines, numLines).foreach { s =>
    w.write(s + scala.util.Properties.lineSeparator)
}
w.close

这将输出numLines行。类似于蓄水池采样,分层采样保证为所有由另一个属性的级别定义的层提供相同的输入/输出行比例。我们可以通过将原始数据集分割成与级别相对应的N个子集,执行蓄水池采样,然后合并结果来实现这一点。然而,将在第三章中介绍的 MLlib 库,使用 Spark 和 MLlib,已经实现了分层采样的实现:

val origLinesRdd = sc.textFile("file://...")
val keyedRdd = origLines.keyBy(r => r.split(",")(0))
val fractions = keyedRdd.countByKey.keys.map(r => (r, 0.1)).toMap
val sampledWithKey = keyedRdd.sampleByKeyExact(fractions)
val sampled = sampledWithKey.map(_._2).collect

另一个要点更为微妙;有时我们希望在多个数据集中保持值的一致子集,无论是为了可重复性还是为了与其他采样数据集连接。一般来说,如果我们采样两个数据集,结果将包含随机子集的 ID,这些 ID 可能几乎没有交集或完全没有交集。密码学哈希函数在这里提供了帮助。应用 MD5 或 SHA1 等哈希函数的结果是一系列在理论上至少是统计上不相关的位序列。我们将使用MurmurHash函数,它是scala.util.hashing包的一部分:

import scala.util.hashing.MurmurHash3._

val markLow = 0
val markHigh = 4096
val seed = 12345

def consistentFilter(s: String): Boolean = {
  val hash = stringHash(s.split(" ")(0), seed) >>> 16
  hash >= markLow && hash < markHigh
}

val w = new java.io.FileWriter(new java.io.File("out.txt"))
val lines = io.Source.fromFile("chapter01/data/iris/in.txt").getLines
lines.filter(consistentFilter).foreach { s =>
     w.write(s + Properties.lineSeparator)
}
w.close

此函数保证基于第一个字段的值返回完全相同的记录子集——要么是第一个字段等于某个特定值的所有记录,要么是没有任何记录——并且将产生大约原始样本的六分之一;hash的范围是065,535

注意

MurmurHash?它不是一个密码学哈希!

与 MD5 和 SHA1 等密码学哈希函数不同,MurmurHash 并不是专门设计成难以找到哈希的逆函数。然而,它确实非常快且高效。在我们的用例中,这真正是重要的。

使用 Scala 和 Spark 笔记本工作

通常,最频繁的值或五数摘要不足以获得对数据的初步理解。术语 描述性统计 非常通用,可能指代描述数据非常复杂的方法。分位数、帕累托图,或者当分析多个属性时,相关性也是描述性统计的例子。当分享所有这些查看数据聚合的方法时,在许多情况下,分享得到这些计算的具体计算也很重要。

Scala 或 Spark Notebook github.com/Bridgewater/scala-notebook, github.com/andypetrella/spark-notebook 记录整个转换路径,结果可以作为基于 JSON 的 *.snb 文件共享。Spark Notebook 项目可以从 spark-notebook.io 下载,我将提供与本书一起的示例 Chapter01.snb 文件。我将使用 Spark,我将在第三章(part0249.xhtml#aid-7DES21 "第三章. 使用 Spark 和 MLlib")中更详细地介绍,使用 Spark 和 MLlib

对于这个特定的例子,Spark 将在本地模式下运行。即使在本地模式下,Spark 也可以在您的工作站上利用并行性,但受限于可以在您的笔记本电脑或工作站上运行的内核和超线程数量。然而,通过简单的配置更改,Spark 可以指向一组分布式机器,并使用分布式节点集的资源。

这里是下载 Spark Notebook 并从代码仓库复制必要文件的命令集:

[akozlov@Alexanders-MacBook-Pro]$ wget http://s3.eu-central-1.amazonaws.com/spark-notebook/zip/spark-notebook-0.6.3-scala-2.11.7-spark-1.6.1-hadoop-2.6.4-with-hive-with-parquet.zip
...
[akozlov@Alexanders-MacBook-Pro]$ unzip -d ~/ spark-notebook-0.6.3-scala-2.11.7-spark-1.6.1-hadoop-2.6.4-with-hive-with-parquet.zip
...
[akozlov@Alexanders-MacBook-Pro]$ ln -sf ~/ spark-notebook-0.6.3-scala-2.11.7-spark-1.6.1-hadoop-2.6.4-with-hive-with-parquet ~/spark-notebook
[akozlov@Alexanders-MacBook-Pro]$ cp chapter01/notebook/Chapter01.snb ~/spark-notebook/notebooks
[akozlov@Alexanders-MacBook-Pro]$ cp chapter01/ data/kddcup/kddcup.parquet ~/spark-notebook
[akozlov@Alexanders-MacBook-Pro]$ cd ~/spark-notebook
[akozlov@Alexanders-MacBook-Pro]$ bin/spark-notebook 
Play server process ID is 2703
16/04/14 10:43:35 INFO play: Application started (Prod)
16/04/14 10:43:35 INFO play: Listening for HTTP on /0:0:0:0:0:0:0:0:9000
...

现在,您可以在浏览器中打开 http://localhost:9000 上的笔记本,如下面的截图所示:

使用 Scala 和 Spark 笔记本

图 01-2. Spark 笔记本的第一页,列出了笔记本列表

通过点击它打开 Chapter01 笔记本。语句被组织成单元格,可以通过点击顶部的较小右箭头来执行,如下面的截图所示,或者通过导航到 单元格 | 运行所有 来一次性运行所有单元格:

使用 Scala 和 Spark 笔记本

图 01-3. 执行笔记本中的前几个单元格

首先,我们将查看所有或某些离散变量的值。例如,要获取标签的分布,请执行以下代码:

val labelCount = df.groupBy("lbl").count().collect
labelCount.toList.map(row => (row.getString(0), row.getLong(1)))

我第一次读取数据集时,在 MacBook Pro 上大约花费了一分钟,但 Spark 将数据缓存到内存中,后续的聚合运行只需大约一秒钟。Spark Notebook 提供了值的分布,如下面的截图所示:

使用 Scala 和 Spark 笔记本

图 01-4. 计算分类字段的值分布

我还可以查看离散变量的交叉表计数,这让我对变量之间的相互依赖性有了了解,请参阅spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.sql.DataFrameStatFunctions——该对象尚不支持计算如卡方检验之类的相关度量:

使用 Scala 和 Spark 笔记本

图 01-5. 列联表或交叉表

然而,我们可以看到最受欢迎的服务是私有的,并且它与 SF 标志的相关性很好。另一种分析依赖关系的方法是查看 0 条目。例如,S2S3 标志显然与 SMTP 和 FTP 流量相关,因为所有其他条目都是 0

当然,最有趣的关联是与目标变量相关,但这些最好通过我在第三章(part0249.xhtml#aid-7DES21 "第三章。使用 Spark 和 MLlib")和第五章(part0260.xhtml#aid-7NUI81 "第五章。回归和分类")中将要介绍的监督学习算法来发现,使用 Spark 和 MLlib回归和分类

使用 Scala 和 Spark 笔记本

图 01-6. 使用 org.apache.spark.sql.DataFrameStatFunctions 计算简单聚合。

类似地,我们可以使用 dataFrame.stat.corr()dataFrame.stat.cov() 函数计算数值变量的相关性(参见图 01-6)。在这种情况下,该类支持皮尔逊相关系数。或者,我们可以直接在 parquet 文件上使用标准的 SQL 语法:

sqlContext.sql("SELECT lbl, protocol_type, min(duration), avg(duration), stddev(duration), max(duration) FROM parquet.`kddcup.parquet` group by lbl, protocol_type")

最后,我答应您计算百分位数。计算百分位数通常涉及对整个数据集进行排序,这是昂贵的;然而,如果分块是前几个或最后几个之一,通常可以优化计算:

val pct = sqlContext.sql("SELECT duration FROM parquet.`kddcup.parquet` where protocol_type = 'udp'").rdd.map(_.getLong(0)).cache
pct.top((0.05*pct.count).toInt).last

对于更通用的案例,计算确切的百分位数计算成本更高,它是 Spark 笔记本示例代码的一部分。

基本相关性

您可能已经注意到,从列联表中检测相关性是困难的。检测模式需要练习,但许多人更擅长直观地识别模式。检测可操作的模式是机器学习的主要目标之一。虽然将在第四章(part0256.xhtml#aid-7K4G02 "第四章。监督学习和无监督学习")和第五章(part0260.xhtml#aid-7NUI81 "第五章。回归和分类")中介绍的高级监督机器学习技术,监督学习和无监督学习以及回归和分类存在,但变量之间相互依赖性的初步分析有助于正确转换变量或选择最佳推理技术。

存在多个成熟的可视化工具,并且有多个网站,例如www.kdnuggets.com,它们专注于对数据分析、数据探索和可视化软件进行排名和提供推荐。在这本书中,我不会质疑这些排名的有效性和准确性,实际上很少有网站提到 Scala 作为可视化数据的具体方式,即使使用D3.js包也是可能的。一个好的可视化是向更广泛的受众传达发现的好方法。一看胜千言。

为了本章的目的,我将使用每个 Mac OS 笔记本上都有的Grapher。要打开Grapher,请转到实用工具(在 Finder 中按shift + command + U)并单击Grapher图标(或按command + space按名称搜索)。Grapher 提供了许多选项,包括以下对数-对数极坐标

基本相关性

图 01-7. Grapher 窗口

从根本上讲,通过可视化可以传达的信息量受屏幕上像素数量的限制,对于大多数现代计算机来说,这是数百万,以及颜色变化,这也可以说是数百万(JuddDeane B.WyszeckiGünter(1975)。商业、科学和工业中的颜色纯与应用光学系列(第 3 版)。纽约)。如果我正在处理一个多维 TB 数据集,该数据集首先需要被总结、处理,并减少到可以在计算机屏幕上查看的大小。

为了说明目的,我将使用可以在archive.ics.uci.edu/ml/datasets/Iris找到的 Iris UCI 数据集。要将数据集引入工具,请输入以下代码(在 Mac OS 上):

[akozlov@Alexanders-MacBook-Pro]$ pbcopy < chapter01/data/iris/in.txt

Grapher中打开新的点集command + alt + P),按编辑点…并按command + V粘贴数据。该工具具有基本的线性、多项式和指数族等线拟合能力,并提供流行的卡方指标来估计拟合的优良程度,相对于自由参数的数量:

基本相关性

图 01-8. 在 Mac OS X 上使用 Grapher 拟合 Iris 数据集

我们将在接下来的章节中介绍如何估计模型拟合的优良程度。

摘要

我试图在书中建立一个共同的基础,以便在后面的章节中进行更复杂的数据科学。不要期望这些是完整的探索性技术集,因为探索性技术可以扩展到运行非常复杂模式。然而,我们已经涵盖了简单的聚合、抽样、读写等文件操作,以及使用笔记本和 Spark DataFrames 等工具,这些工具将熟悉的 SQL 结构引入了与 Spark/Scala 一起工作的分析师的武器库中。

下一章将完全转变方向,将数据管道视为数据驱动企业的组成部分,并从业务角度涵盖数据发现过程:我们通过数据分析试图实现哪些最终目标。在这之后,我将介绍一些传统的机器学习主题,如监督学习和无监督学习,然后再深入研究更复杂的数据表示,Scala 在这里真正显示出其相对于 SQL 的优势。

第二章:数据管道和建模

在上一章中,我们探讨了探索数据的基本动手工具,因此我们现在可以深入研究更复杂的主题,如统计模型构建、最优控制或科学驱动工具和问题。我将继续说,我们只会触及一些最优控制的话题,因为这本书实际上只是关于 Scala 中的机器学习,而不是数据驱动业务管理的理论,这可能是一个单独成书的激动人心的主题。

在本章中,我将避免具体实现 Scala,并从高层次讨论构建数据驱动企业的相关问题。后面的章节将解决这些难题的解决方案。特别强调处理不确定性。不确定性通常有几种形式:首先,我们提供的信息中可能存在噪声。其次,信息可能不完整。系统在填补缺失部分时可能有一定的自由度,这导致不确定性。最后,模型解释和结果指标可能存在差异。最后一个观点很微妙,因为大多数经典教科书都假设我们可以直接测量事物。不仅测量可能存在噪声,而且测量的定义可能随时间变化——尝试测量满意度或幸福感。当然,我们可以通过说我们只能优化可测量的指标来避免这种歧义,就像人们通常做的那样,但这将显著限制实际应用的范围。没有任何东西阻止科学机制在处理解释不确定性时将其考虑在内。

预测模型通常只是为了数据理解而构建。从语言学的推导来看,模型是对实际复杂建筑或过程的简化表示,其目的正是为了阐明观点和说服人们,无论通过何种方式。预测建模的最终目标,也就是我在本书和本章中关注的目标,是通过考虑最重要的因素来优化业务流程,以便让世界变得更加美好。这当然是一个充满不确定性的句子,但至少它看起来比优化点击率要好得多。

让我们看看传统的商业决策过程:一家传统的企业可能涉及一组 C 级高管根据通常从一组包含一个或多个数据库中数据图形表示的仪表板中获得的信息做出决策。自动化数据驱动型企业的承诺是,在消除人类偏见的情况下,能够自动做出大多数决策。这并不是说我们不再需要 C 级高管,但 C 级高管将忙于帮助机器做出决策,而不是相反。

在本章中,我们将涵盖以下主题:

  • 探索影响图作为决策工具的基本原理

  • 在自适应马尔可夫决策过程凯利准则的背景下查看纯决策优化变体

  • 熟悉至少三种不同的探索-利用权衡策略

  • 描述数据驱动型企业的架构

  • 讨论决策管道的主要架构组件

  • 熟悉构建数据管道的标准工具

影响图

虽然决策过程可能具有多个方面,但一本关于不确定条件下决策的书如果没有提到影响图(团队决策分析中的影响图,《决策分析》第 2 卷(4):207–228),就会显得不完整。影响图有助于分析和理解决策过程。决策可能像选择在个性化环境中向用户展示的下一条新闻文章这样平凡,也可能像在企业网络中检测恶意软件或选择下一个研究项目这样复杂。

根据天气情况,她可以尝试进行一次乘船之旅。我们可以将决策过程表示为图表。让我们决定在她在俄勒冈州波特兰逗留期间是否乘坐游船:

影响图

图 02-1。一个简单的假期影响图,用于表示简单的决策过程。该图包含决策节点,如度假活动,可观察和不可观察的信息节点,如天气预报和天气,以及最终的价值节点,如满意度

上述图表代表了这种情况。是否参加活动的决策明显是由获得一定满意度的可能性驱动的,这是决策本身和活动时的天气的函数。虽然旅行计划时的实际天气条件是未知的,但我们相信天气预报和旅行期间实际经历的天气之间存在某种相关性,这由天气天气预报节点之间的边表示。度假活动节点是决策节点,它只有一个父节点,因为决策完全基于天气预报。DAG 中的最后一个节点是满意度,它是实际天气和我们在旅行计划期间所做的决策的函数——显然,“是 + 好天气”和“否 + 坏天气”可能得分最高。而“是 + 坏天气”和“否 + 好天气”将是一个不良的结果——后者可能是错过机会,但并不一定是一个糟糕的决策,前提是天气预报不准确。

边缘的缺失包含了一个独立性假设。例如,我们相信满意度不应该依赖于天气预报,因为后者在我们上船后就变得无关紧要了。一旦度假计划确定,实际天气在划船活动期间就不再影响决策,该决策完全基于天气预报;至少在我们的简化模型中,我们排除了购买旅行保险的选项。

图表展示了决策的不同阶段和信息流(我们将在第七章 Chapter 7,使用图算法)中提供实际的 Scala 实现)。在我们的简化图中,做出决策只需要一条信息:天气预报。一旦做出决策,我们就无法更改它,即使我们有关于旅行时实际天气的信息。天气和决策数据可以用来模拟她对所做决策的满意度。

让我们将这种方法映射到一个广告问题作为说明:最终目标是获得用户对目标广告的满意度,这将为广告商带来额外的收入。满意度是用户特定环境状态的函数,在决策时是未知的。然而,使用机器学习算法,我们可以根据用户的最近网页访问历史和其他我们可以收集的信息(如地理位置、浏览器代理字符串、一天中的时间、广告类别等)来预测这种状态(参见图 2-2)。

虽然我们不太可能测量用户大脑中的多巴胺水平,这肯定会落入可测量指标的范畴,并可能减少不确定性,但我们可以通过用户的行为间接测量用户满意度,无论是他们是否对广告做出了回应,还是用户在点击浏览相关信息之间花费的时间,这可以用来估计我们建模和算法的有效性。以下是一个影响图,类似于“假期”的影响图,调整用于广告决策过程:

影响图

图 02-2. 调整后的在线广告决策案例的假期影响图。在线广告的决策可以每秒进行数千次

实际过程可能更复杂,代表一系列决策,每个决策都依赖于几个先前的时间切片。例如,所谓的马尔可夫链决策过程。在这种情况下,图表可能需要在多个时间切片上重复。

另一个例子可能是企业网络互联网恶意软件分析系统。在这种情况下,我们试图根据对企业交换机流经的网络数据包的分析来检测指示命令和控制(C2)、横向移动或数据泄露的网络连接。目标是最大限度地减少爆发对系统运行的最小影响。

我们可能做出的一个决定是重新映像一部分节点,或者至少将它们隔离。我们收集的数据可能包含不确定性——许多良性软件包可能会以可疑的方式发送流量,而模型需要根据风险和潜在影响来区分它们。在这个特定案例中,可能的一个决定是收集更多信息。

我将把这个以及其他潜在的商业案例映射到相应的图表上作为练习留给读者。现在让我们考虑一个更复杂的优化问题。

顺序试验和风险处理

如果我为了多赚几美元而牺牲同样风险的偏好是什么?我将在本节稍后停止讨论为什么一个人的偏好可能是不对称的,并且有科学证据表明这种不对称性是受进化原因根植于我们心中的,但你是对的,我现在必须优化参数化效用函数的不对称函数的期望值,如下所示:

顺序试验和风险处理

为什么分析中会出现非对称函数?一个例子是重复投注或再投资,也称为凯利公式问题。虽然最初,凯利公式是为赌博机等二元结果的具体情况开发的,用于优化每轮投注的资金比例(《信息率的新解释》,贝尔系统技术期刊 35 (4): 917–926,1956),但作为一个更通用的再投资问题,它涉及到可能回报的概率分布。

多次投注的回报是每次投注的个别回报率的乘积——回报率是投注后的资金与每次个别投注前的原始资金的比率,如下所示:

顺序试验和风险管理

这对我们优化总回报帮助不大,因为我们不知道如何优化独立同分布(i.i.d)随机变量的乘积。然而,我们可以通过对数变换将乘积转换为和,并应用中心极限定理CLT)来近似独立同分布变量的和(假设r的分布满足 CLT 条件,例如,具有有限的均值和方差),如下所示:

顺序试验和风险管理

因此,进行N次投注的累积结果将类似于进行N次投注,期望回报为顺序试验和风险管理,而不是顺序试验和风险管理

正如我之前提到的,这个问题最常应用于二元竞标的情况,尽管它可以很容易地推广,在这种情况下,还有一个额外的参数:x,即每轮投注的金额。假设我以概率p获得W的利润,或者以概率(1-p)完全输掉投注。优化与以下额外参数相关的期望回报:

顺序试验和风险管理顺序试验和风险管理顺序试验和风险管理

最后一个方程是凯利公式比率,它给出了最佳投注金额。

一个人可能会投注少于总金额的原因是,即使平均回报是正的,仍然有可能输掉全部资金,尤其是在高度偏斜的情况下。例如,即使你投注获得10 x的概率是0.105W = 10,期望回报是5%),组合分析表明,即使经过60次投注,整体回报为负的概率大约为50%,而且有11%的几率,特别是会输掉(57 - 10 x 3) = 27倍或更多的投注:

akozlov@Alexanders-MacBook-Pro$ scala
Welcome to Scala version 2.11.7 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_40).
Type in expressions to have them evaluated.
Type :help for more information.27 

scala> def logFactorial(n: Int) = { (1 to n).map(Math.log(_)).sum }
logFactorial: (n: Int)Double

scala> def cmnp(m: Int, n: Int, p: Double) = {
 |   Math.exp(logFactorial(n) -
 |   logFactorial(m) +
 |   m*Math.log(p) -
 |   logFactorial(n-m) +
 |   (n-m)*Math.log(1-p))
 | }
cmnp: (m: Int, n: Int, p: Double)Double

scala> val p = 0.105
p: Double = 0.105

scala> val n = 60
n: Int = 60

scala> var cumulative = 0.0
cumulative: Double = 0.0

scala> for(i <- 0 to 14) {
 |   val prob = cmnp(i,n,p)
 |   cumulative += prob
 |   println(f"We expect $i wins with $prob%.6f probability $cumulative%.3f cumulative (n = $n, p = $p).")
 | }
We expect 0 wins with 0.001286 probability 0.001 cumulative (n = 60, p = 0.105).
We expect 1 wins with 0.009055 probability 0.010 cumulative (n = 60, p = 0.105).
We expect 2 wins with 0.031339 probability 0.042 cumulative (n = 60, p = 0.105).
We expect 3 wins with 0.071082 probability 0.113 cumulative (n = 60, p = 0.105).
We expect 4 wins with 0.118834 probability 0.232 cumulative (n = 60, p = 0.105).
We expect 5 wins with 0.156144 probability 0.388 cumulative (n = 60, p = 0.105).
We expect 6 wins with 0.167921 probability 0.556 cumulative (n = 60, p = 0.105).
We expect 7 wins with 0.151973 probability 0.708 cumulative (n = 60, p = 0.105).
We expect 8 wins with 0.118119 probability 0.826 cumulative (n = 60, p = 0.105).
We expect 9 wins with 0.080065 probability 0.906 cumulative (n = 60, p = 0.105).
We expect 10 wins with 0.047905 probability 0.954 cumulative (n = 60, p = 0.105).
We expect 11 wins with 0.025546 probability 0.979 cumulative (n = 60, p = 0.105).
We expect 12 wins with 0.012238 probability 0.992 cumulative (n = 60, p = 0.105).
We expect 13 wins with 0.005301 probability 0.997 cumulative (n = 60, p = 0.105).
We expect 14 wins with 0.002088 probability 0.999 cumulative (n = 60, p = 0.105).

注意,为了恢复 27 x 的金额,平均只需要再玩 顺序试验和风险处理 轮额外的游戏,但必须先有赌注才能开始。凯利公式提供的是,最佳策略是只投注我们赌注的 1.55%。注意,如果我投注全部赌注,我将在第一轮(获胜的概率仅为 0.105)以 89.5% 的确定性输掉所有钱。如果我只投注赌注的一小部分,留在游戏中的机会将无限好,但整体回报较小。期望对数收益的图示如 图 02-3 所示,作为投注赌注的份额 x 和我刚刚计算的 60 轮可能结果分布的函数。在 24% 的游戏中,我们的表现将不如下方的曲线,在 39% 的游戏中不如下一条曲线,大约一半——44%——的赌徒的表现将与中间的黑曲线相同或更好,而在 30% 的情况下表现将优于最上方的一条。对于 x 的最佳凯利公式值为 0.0155,这将最终在无限多轮中优化整体回报:

顺序试验和风险处理

图 02-3. 作为投注金额和 60 轮可能结果函数的期望对数收益(见方程(2.2))

凯利公式因其过于激进(赌徒倾向于高估他们的获胜潜力和比例,同时低估破产的概率)以及过于保守(风险价值应该是总可用资本,而不仅仅是赌注)而受到批评,但它展示了我们需要用一些额外的转换来补偿我们对“收益”的直观理解的一个例子。

从金融角度来看,凯利公式比标准定义(如收益的波动性或方差)更好地描述了风险。对于一个通用的参数化收益分布 y(z),其概率分布函数为 f(z),方程(2.3)可以重新表述如下。在替换 r(x) = 1 + x y(z) 后,其中 x 仍然是投注的金额:

顺序试验和风险处理顺序试验和风险处理

在离散情况下,它也可以写成以下形式:

顺序试验和风险处理

在这里,分母强调了负收益区域的贡献。具体来说,失去所有赌注的可能性正好是分母 顺序试验和风险处理 为零的地方。

正如我之前提到的,有趣的是,风险规避植根于我们的直觉,似乎在人类和灵长类动物中编码了一个自然的风险规避偏好系统(参见 Laurie Santos 的《像我们一样不理性的猴子经济》,TED 演讲,2010 年)。现在关于猴子和风险的话题就足够了,让我们进入另一个相当有争议的主题——探索与利用的权衡,在这个权衡中,一个人甚至可能一开始都不知道收益权衡。

探索与利用

探索与利用的权衡是另一个问题,尽管其真实应用范围从研究项目的资金分配到自动驾驶汽车,但其明显的起源在赌博中。传统的表述是多臂老丨虎丨机问题,它指的是一个或多个臂的虚拟老丨虎丨机。每个臂的连续操作产生i.i.d的回报,每个臂的回报概率未知;在简化模型中,连续操作是独立的。假设奖励在臂之间是独立的。目标是最大化奖励——例如,赢得的金额,并最小化学习损失,即花费在获胜率低于最优的臂上的金额,前提是有一个商定的臂选择策略。明显的权衡是在寻找产生最佳回报的臂的探索和利用已知最佳回报的利用之间:

探索与利用

然后,伪后悔是以下差值:

探索与利用

在这里,探索与利用是从N次试验中选出的i臂。多臂老丨虎丨机问题在 20 世纪 30 年代和 21 世纪初都得到了广泛的研究,其应用领域包括金融和 ADTECH。尽管由于问题的随机性,通常无法提供一个比N的平方根更好的期望后悔的上界,但可以通过控制伪后悔来使其受到N的对数的约束(参见 Sebastien Bubeck 和 Nicolo Cesa-Bianchi 的论文《随机和非随机多臂老丨虎丨机问题的后悔分析》,arxiv.org/pdf/1204.5721.pdf)。

实践中最常用的策略之一是ε策略,其中最优臂以探索与利用的概率被选中,而其他臂则以剩余的概率被选中。这种方法的缺点是我们可能会在永远不会提供任何奖励的臂上花费大量的探索资源。UCB 策略通过选择具有最大回报估计值的臂,以及回报估计值的标准差的某个倍数或分数来改进ε策略。这种方法需要在每一轮重新计算最佳臂,并且由于对均值和标准差的估计所做的近似而受到影响。此外,UCB 需要为每次连续抽取重新计算估计值,这可能会成为可扩展性问题。

最后,Thompson 抽样策略使用 Beta-Bernoulli 后验估计的固定随机样本,并将下一个臂分配给给出最小预期后悔的臂,为此可以使用实际数据来避免参数重新计算。尽管具体数字可能取决于假设,但以下图表提供了这些模型性能的一个可用比较:

探索与利用

图 02-3。对于 K = 5,单臂老丨虎丨机和不同策略的不同探索/利用策略的模拟结果。

图 02-3显示了不同策略的模拟结果(摘自 Rich Relevance 网站engineering.richrelevance.com/recommendations-thompson-sampling)。随机策略随机分配臂,对应于纯探索。天真策略在某个阈值内是随机的,然后切换到纯利用模式。上置信界UCB)以 95%的置信水平。UCB1 是 UCB 的一种修改,以考虑分布的对数正态性。最后,Thompson 抽样策略从实际后验分布中随机抽取样本以优化后悔。

探索/利用模型众所周知对初始条件和异常值非常敏感,尤其是在低响应方面。人们可能会在实际上已经无望的臂上花费大量的试验。

通过基于额外信息(如位置)估计更好的先验,或者由于这种额外信息而限制要探索的臂的集合—K—,可以对这些策略进行其他改进,但这些方面更具有领域特定性(如个性化或在线广告)。

未知之未知

未知未知的事物在很大程度上是因为美国国防部长唐纳德·拉姆斯菲尔德在 2002 年 2 月 12 日的一次美国国防部(DoD)新闻发布会上对关于伊拉克政府与向恐怖组织供应大规模杀伤性武器缺乏证据的提问的回答而闻名,以及纳西姆·尼古拉斯·塔勒布的书籍(《黑天鹅:几乎不可能发生的事件的影响》由纳西姆·尼古拉斯·塔勒布著,Random House 出版社,2007 年)。

注意

火鸡悖论

争议性地,未知未知的事物可以通过火鸡悖论更好地解释。假设你有一群火鸡在后院玩耍,享受保护和免费食物。栅栏的另一边,还有另一群火鸡。这一切日复一日,月复一月,直到感恩节到来——感恩节是加拿大和美国庆祝的全国性假日,在这一天,人们习惯于在烤箱里烤火鸡。火鸡很可能在这个时候被收割并消费,尽管从火鸡的角度来看,没有任何可辨别的信号表明在加拿大的 10 月第二个星期一或美国的 11 月第四个星期四会发生任何事情。除了额外的年度信息之外,没有任何模型可以在年内数据的基础上解决这个问题。

未知未知的事物是模型中没有的,并且无法预测它们会在模型中。实际上,唯一真正感兴趣的未知未知的事物是那些对模型影响如此之大,以至于之前几乎不可能或几乎不可能发生的结果现在变成了现实。鉴于大多数实际分布都属于指数家族,尾部非常薄,因此,与正态分布的偏差不必超过几个标准差,就会对标准模型假设产生破坏性的影响。尽管人们仍然需要想出一个可行的策略来如何在模型中包含未知因素——已经提出了几种方法,包括分形,但几乎没有可行的——从业者必须意识到风险,而这里的定义风险正是模型无用的可能性。当然,已知未知和未知未知之间的区别正是我们理解风险和需要探索的内容。

当我们审视决策系统面临的基本问题范围时,让我们来看看数据管道,这些软件系统为做出决策提供信息,以及设计数据驱动系统数据管道的更实际方面。

数据驱动系统的基本组件

简而言之,数据驱动架构包含以下组件——至少我所见到的所有系统都有这些组件——或者可以简化为这些组件:

  • 数据采集:我们需要从系统和设备中收集数据。大多数系统都有日志,或者至少有一个将文件写入本地文件系统的选项。一些系统可能具有将信息报告给基于网络的接口(如 syslog)的能力,但通常没有持久化层意味着可能存在数据丢失的风险,如果不是审计信息缺失的话。

  • 数据转换层:它也被称为提取、转换和加载ETL)。今天,数据转换层也可以用于实时处理,其中聚合是在最新数据上计算的。数据转换层也传统上用于重新格式化和索引数据,以便高效地由管道下游的算法的 UI 组件访问。

  • 数据分析与机器学习引擎:这个层次不是标准数据转换层的一部分的原因通常是因为这个层次需要相当不同的技能。构建合理统计模型的人的心态通常与那些使数以 TB 计的数据快速移动的人不同,尽管偶尔我也能找到具备这两种技能的人。通常,这些“独角兽”被称为数据科学家,但任何特定领域的技能通常都不如专门从事该领域的人。尽管如此,我们仍然需要更多这样的人。另一个原因是,机器学习,以及在某种程度上数据分析,需要多次对相同数据进行聚合和遍历,这与更流式的 ETL 转换不同,需要不同的引擎。

  • UI 组件:是的,UI 代表用户界面,它通常是一组组件,允许您通过浏览器与系统通信(它曾经是一个本地的 GUI,但如今基于 Web 的 JavaScript 或 Scala 框架要强大得多,并且更易于移植)。从数据管道和建模的角度来看,该组件提供了一个 API 来访问数据和模型的内部表示。

  • 动作引擎:这通常是一个可配置的规则引擎,根据洞察力优化提供的指标。动作可以是实时的,例如在线广告中的情况,在这种情况下,引擎应该能够提供实时评分信息,或者为用户动作提供推荐,这可能采取电子邮件警报的形式。

  • 关联引擎:这是一个新兴的组件,它可能分析数据分析与机器学习引擎的输出,以推断数据或模型行为方面的额外见解。这些动作也可能由该层的输出触发。

  • 监控:这是一个复杂的系统,如果没有日志、监控以及某种方式来更改系统参数,它将是不完整的。监控的目的是拥有一个嵌套的决策系统,关于系统的最佳健康状况,要么自动减轻问题(s),要么向系统管理员发出关于问题(s)的警报。

让我们在以下各节中详细讨论每个组件。

数据摄取

随着智能设备的普及,信息收集不再是问题,而是任何从事除打字文本之外业务的企业的一种必要需求。为了本章的目的,我将假设设备或设备已连接到互联网或以某种方式通过家庭拨号或直接网络连接传递此信息。

此组件的主要目的是收集所有可能对后续数据驱动决策相关的相关信息。以下表格提供了关于数据摄取最常见实现的详细信息:

框架 当使用 评论
Syslog Syslog 是 Unix 机器之间传递消息的最常见标准之一。Syslog 通常监听端口 514,传输协议可以配置为 UDP(不可靠)或 TCP。在 CentOS 和 Red Hat Linux 上的最新增强实现是 rsyslog,它包括许多高级选项,如基于正则表达式的过滤,这对于系统性能调整和调试非常有用。除了略微低效的原始消息表示——纯文本,对于重复字符串的长消息可能效率不高——syslog 系统可以每秒支持数万条消息。Syslog 是由 Eric Allman 在 1980 年代作为 Sendmail 的一部分开发的最早协议之一。虽然它不保证交付或持久性,尤其是对于分布式系统,但它是最广泛的消息传递协议之一。一些后来的框架,如 Flume 和 Kafka,也有 syslog 接口。
Rsync Rsync 是一个在 1990 年代开发的较新的框架。如果数据被放置在本地文件系统上的平面文件中,rsync 可能是一个选择。虽然 rsync 传统上用于同步两个目录,但它也可以定期运行以批量传输日志数据。Rsync 使用由澳大利亚计算机程序员 Andrew Tridgell 发明的递归算法,在接收计算机已经有一个类似但不完全相同的相同结构版本的情况下,高效地检测差异并跨通信链路传输结构(如文件)。虽然它会产生额外的通信,但从持久性的角度来看,它更好,因为原始副本始终可以检索。如果已知日志数据最初是以批量形式到达的(例如上传或下载),则特别适用。Rsync 已知会受到网络瓶颈的限制,因为它在比较目录结构时最终会在网络上传递更多信息。然而,传输的文件在网络传输时可能会被压缩。网络带宽可以通过命令行标志进行限制。
Flume Flume 是 Cloudera 在 2009-2011 年间开发的一个较年轻的框架,并已开源。Flume——我们指的是更流行的 flume-ng 实现,称为 Flume,而不是较老的常规 Flume——由源、管道和可能配置在多个节点上的汇组成,以实现高可用性和冗余。Flume 被设计为在可靠性的代价下尽可能避免数据重复。Flume 以Avro格式传递消息,该格式也是开源的,传输协议以及消息都可以进行编码和压缩。 虽然 Flume 最初是为了从文件或一组文件中传输记录而开发的,但它也可以配置为监听端口,甚至从数据库中抓取记录。Flume 有多个适配器,包括前面的 syslog。
Kafka Kafka 是 LinkedIn 开发的日志处理框架的最新补充,并已开源。与之前的框架相比,Kafka 更像是一个分布式可靠的消息队列。Kafka 保持分区,可能分布在多个分布式机器上;缓冲区,并且可以订阅或取消订阅特定主题的消息。Kafka 在设计时考虑了强大的可靠性保证,这是通过复制和共识协议实现的。 Kafka 可能不适合小型系统(小于五个节点),因为完全分布式系统的优势可能只有在更大规模时才明显。Kafka 由 Confluent 提供商业支持。

信息传输通常以批量或微批量的形式进行,如果需求接近实时,则可能为微批量。通常,信息首先存储在设备本地文件系统中的一个文件中,传统上称为日志文件,然后传输到中央位置。最近开发的 Kafka 和 Flume 常用于管理这些传输,同时还有更传统的 syslog、rsync 或 netcat。最后,数据可以存储在本地或分布式存储中,如 HDFS、Cassandra 或 Amazon S3。

数据转换层

数据最终存储在 HDFS 或其他存储中后,需要使数据可用于处理。传统上,数据按计划处理,并最终按基于时间的桶进行分区。处理可以按日、按小时,甚至在新型的 Scala 流框架的基础上按亚分钟级进行,具体取决于延迟要求。处理可能涉及一些初步的特征构造或矢量化,尽管它传统上被认为是机器学习任务。以下表格总结了一些可用的框架:

框架 使用情况 备注
Oozie 这是雅虎(Yahoo)开发的最古老的开放源代码框架之一。它与大数据 Hadoop 工具具有良好的集成。它具有有限的用户界面,列出了作业历史。 整个工作流被放入一个大的 XML 文件中,这可能从模块化的角度来看被认为是一个缺点。
Azkaban 这是由领英(LinkedIn)开发的一个替代开源工作流调度框架。与 Oozie 相比,它可能具有更好的用户界面。缺点是所有高级任务都在本地执行,这可能会带来可扩展性问题。Azkaban 背后的理念是创建一个完全模块化的即插即用架构,其中新作业/任务可以尽可能少地修改后添加。
StreamSets StreamSets 是由前 Informix 和 Cloudera 的开发者最新构建的。它具有非常发达的用户界面,并支持更丰富的输入源和输出目标。 这是一个完全由用户界面驱动的工具,强调数据管理,例如,持续监控数据流中的问题和异常。

应当特别关注流处理框架,其中延迟需求降低到每次一个或几个记录。首先,流处理通常需要更多资源用于处理,因为与处理记录批次相比,每次处理单个记录的成本更高,即使只有几十或几百条记录也是如此。因此,架构师需要根据更近期结果的价值来证明额外成本是合理的,而这种价值并不总是有保证的。其次,流处理需要对架构进行一些调整,因为处理更近期数据成为优先事项;例如,最近在像Druiddruid.io)这样的系统中,一个处理更近期数据的独立子流或节点集的 delta 架构变得非常流行。|

数据分析和机器学习

为了本章的目的,机器学习ML)是指任何可以计算可操作聚合或摘要的算法。我们将从第三章 使用 Spark 和 MLlib 到第六章 使用非结构化数据,涵盖更复杂的算法,但在某些情况下,一个简单的滑动窗口平均值和平均值偏差可能就足够作为采取行动的信号。在过去几年中,A/B 测试中的“它有效”某种程度上成为模型构建和部署的有力论据。我并不是在猜测是否可能有或可能没有坚实的科学原理适用,但许多基本假设,如i.i.d、平衡设计和尾部稀薄性,在许多大数据情况下都未能成立。更简单的模型往往速度更快,性能和稳定性更好。|

例如,在在线广告中,人们可能会跟踪一组广告在一段时间内某些相似属性的平均性能,以决定是否显示该广告。关于异常或行为偏离的信息可能表明一个新未知的新情况,这表明旧数据不再适用,在这种情况下,系统别无选择,只能开始新的探索周期。

我将在第六章、处理非结构化数据、第八章、Scala 与 R 和 Python 的集成和第九章、Scala 中的 NLP中更晚些时候讨论更复杂的非结构化、图和模式挖掘。

UI 组件

嗯,UI 是给弱者的!只是开玩笑...也许有点严厉,但现实中,UI 通常提供一种必要的语法糖,以说服数据科学家之外的人群。一个好的分析师可能只需通过查看数字表格就能找出 t 检验的概率。

然而,可能应该应用我们在本章开头使用的相同方法,评估不同组件的有用性和投入其中的周期数量。良好的用户界面通常是有理由的,但它取决于目标受众。

首先,存在许多现有的 UI 和报告框架。不幸的是,其中大多数与函数式编程方法不一致。此外,复杂/半结构化数据的存在,我将在第六章、处理非结构化数据中更详细地描述,为许多框架带来了新的挑战,它们在没有实现某种类型的领域特定语言(DSL)的情况下无法应对。以下是我认为特别有价值的几个用于在 Scala 项目中构建 UI 的框架:

框架 当使用时 备注
Scala Swing 如果你使用了 Java 中的 Swing 组件并且熟练掌握它们,Scala Swing 是一个不错的选择。Swing 组件可以说是 Java 中最不便携的组件,所以你在不同平台上的表现可能会有所不同。 Scala.swing包在底层使用标准的 Java Swing 库,但它有一些很好的补充。最值得注意的是,由于它是为 Scala 设计的,它可以比标准的 Swing 以更简洁的方式使用。
Lift Lift 是一个安全、以开发者为中心、可扩展和交互式的框架,用 Scala 编写。Lift 在 Apache 2.0 许可下开源。 开源 Lift 框架于 2007 年由 David Polak 启动,他对 Ruby on Rails 框架的某些方面感到不满意。任何现有的 Java 库和 Web 容器都可以用于运行 Lift 应用程序。因此,Lift Web 应用程序被打包为 WAR 文件,并部署在任何 servlet 2.4 引擎上(例如,Tomcat 5.5.xx、Jetty 6.0 等)。Lift 程序员可以使用标准的 Scala/Java 开发工具链,包括 Eclipse、NetBeans 和 IDEA 等 IDE。动态 Web 内容通过模板使用标准的 HTML5 或 XHTML 编辑器编写。Lift 应用程序还受益于对高级 Web 开发技术(如 Comet 和 Ajax)的原生支持。
Play Play 框架可以说是比任何其他平台都更符合 Scala 作为函数式语言的特点——它由 Scala 背后的商业公司 Typesafe 官方支持。Play 框架 2.0 建立在 Scala、Akka 和 sbt 之上,提供卓越的异步请求处理、快速和可靠的性能。Typesafe 模板以及一个功能强大的构建系统,具有灵活的部署选项。Play 在 Apache 2.0 许可下开源。 开源 Play 框架于 2007 年由 Guillaume Bort 创建,他希望为长期受苦的 Java Web 开发社区带来一个受现代 Web 框架如 Ruby on Rails 启发的全新 Web 开发体验。Play 遵循熟悉的无状态模型-视图-控制器(MVC)架构模式,强调约定优于配置和开发者生产力。与传统的 Java Web 框架相比,它们有繁琐的编译-打包-部署-重启周期,Play 应用程序的更新只需简单的浏览器刷新即可立即可见。
Dropwizard Dropwizard 项目试图在 Java 和 Scala 中构建一个通用的 RESTful 框架,尽管最终可能使用 Java 多于 Scala。这个框架的优点在于它足够灵活,可以用于任意复杂的数据(包括半结构化数据)。此框架的许可协议为 Apache License 2.0。 RESTful API 假设状态,而函数式语言则避免使用状态。除非你足够灵活,能够偏离纯函数式方法,否则这个框架可能不适合你。
Slick 虽然 Slick 不是一个 UI 组件,但它是由 Typesafe 开发的 Scala 的现代数据库查询和访问库,可以作为 UI 后端使用。它允许您几乎像使用 Scala 集合一样处理存储的数据,同时同时,让您完全控制数据库访问发生的时间和传输的数据。您还可以直接使用 SQL。如果您的所有数据都是纯关系型数据,请使用它。该项目采用 BSD-Style 许可证开源。Slick 由 Stefan Zeiger 于 2012 年启动,主要由 Typesafe 维护。它主要用于关系型数据。
NodeJS Node.js 是一个基于 Chrome 的 V8 JavaScript 引擎构建的 JavaScript 运行时。Node.js 使用事件驱动、非阻塞 I/O 模型,使其轻量级且高效。Node.js 的包生态系统 npm 是世界上最大的开源库生态系统。该项目采用 MIT 许可证开源。Node.js 首次由 Ryan Dahl 和在 Joyent 工作的其他开发者于 2009 年推出。最初 Node.js 只支持 Linux,但现在它可以在 OS X 和 Windows 上运行。
AngularJS AngularJS (angularjs.org) 是一个前端开发框架,旨在简化单页网络应用程序的开发。该项目采用 MIT 许可证开源。AngularJS 最初于 2009 年由 Brat Tech LLC 的 Misko Hevery 开发。AngularJS 主要由 Google 和一群个人开发者及企业维护,因此特别适用于 Android 平台(从 1.3 版本开始不再支持 IE8)。

动作引擎

虽然这是面向数据系统管道的核心,但也有人认为它是最容易的一个。一旦知道了指标和值的系统,系统就会根据已知的方程式,根据提供的信息,决定是否采取某些行动。虽然基于阈值的触发器是最常见的实现方式,但向用户提供一系列可能性和相关概率的概率方法的重要性正在显现——或者就像搜索引擎那样,向用户提供最相关的 N 个选择。

规则的管理可能会变得相当复杂。过去,使用规则引擎(如 Drools (www.drools.org))管理规则是足够的。然而,管理复杂的规则成为一个问题,通常需要开发一个 DSL(马丁·福勒的《领域特定语言》,Addison-Wesley,2010 年)。Scala 特别适合开发这样的动作引擎。

相关引擎

决策系统越复杂,就越需要一个二级决策系统来优化其管理。DevOps 正在转变为 DataOps(Michael Stonebraker 等人所著的《Getting Data Right》,Tamr,2015)。关于数据驱动型系统性能收集的数据用于检测异常和半自动化维护。

模型通常会受到时间漂移的影响,性能可能会因为数据收集层的变化或人群行为的变化而下降(我将在第十章高级模型监控中讨论模型漂移)。模型管理的另一个方面是跟踪模型性能,在某些情况下,通过各种共识方案使用模型的“集体智慧”。

监控

监控一个系统涉及收集有关系统性能的信息,无论是为了审计、诊断还是性能调整。虽然它与前面章节中提出的问题有关,但监控解决方案通常包含诊断和历史存储解决方案,以及关键数据的持久化,就像飞机上的黑匣子。在 Java 和 Scala 的世界里,一个流行的选择工具是 Java 性能 bean,可以在 Java 控制台中监控。虽然 Java 原生支持 MBean 通过 JMX 暴露 JVM 信息,但Kamon(kamon.io)是一个开源库,它使用这种机制专门暴露 Scala 和 Akka 指标。

一些其他流行的开源监控解决方案包括Ganglia(ganglia.sourceforge.net/)和Graphite(graphite.wikidot.com)。

我在这里就不再继续了,因为我在第十章高级模型监控中会更详细地讨论系统和模型监控。

优化和交互性

虽然收集到的数据可以仅用于理解业务,但任何数据驱动型企业的最终目标是通过自动做出基于数据和模型的决定来优化业务行为。我们希望将人为干预降到最低。以下简化的图可以描述为一个循环:

优化和交互性

图 02-4. 预测模型生命周期

当系统中有新信息进入时,这个循环会一次又一次地重复。系统参数可能需要调整以提高整体系统性能。

反馈循环

虽然人类可能仍然会在大多数系统中保持参与,但最近几年出现了能够独立管理完整反馈循环的系统——从广告系统到自动驾驶汽车。

这个问题的经典表述是最优控制理论,它也是一个优化问题,旨在最小化成本泛函,给定一组描述系统的微分方程。最优控制是一组控制策略,用于在给定约束条件下最小化成本泛函。例如,问题可能是在不超过某个时间限制的情况下,找到一种驾驶汽车以最小化其燃油消耗的方法。另一个控制问题是在满足库存和时间约束的条件下,最大化在网站上展示广告的利润。大多数最优控制软件包是用其他语言编写的,如 C 或 MATLAB(PROPT、SNOPT、RIOTS、DIDO、DIRECT 和 GPOPS),但可以与 Scala 接口。

然而,在许多情况下,优化的参数或状态转移,或微分方程,并不确定。马尔可夫决策过程(MDPs)提供了一个数学框架来模拟在结果部分随机且部分受决策者控制的情况下的决策。在 MDPs 中,我们处理一组可能的状态和一组动作。奖励和状态转移既取决于状态也取决于动作。MDPs 对于通过动态规划和强化学习解决广泛的优化问题非常有用。

摘要

在本章中,我描述了一个设计数据驱动企业的整体架构和途径。我还向您介绍了影响图,这是一种理解在传统和数据驱动企业中如何做出决策的工具。我简要介绍了几个关键模型,如凯利准则和多臂老丨虎丨机,这些模型对于从数学角度展示问题至关重要。在此基础上,我介绍了基于先前决策和观察结果的处理决策策略的马尔可夫决策过程方法。我深入探讨了构建决策数据管道的更多实际方面,描述了可以用来构建它们的主要组件和框架。我还讨论了在不同阶段和节点之间传达数据和建模结果的问题,包括向用户展示结果、反馈循环和监控。

在下一章中,我将描述 MLlib,这是一个用于在 Scala 编写的分布式节点集上进行机器学习的库。

第三章:使用 Spark 和 MLlib

现在我们已经掌握了统计和机器学习在全球数据驱动企业架构中如何定位以及如何应用的知识,让我们专注于 Spark 和 MLlib 的具体实现。MLlib 是 Spark 之上的一个机器学习库。Spark 是大数据生态系统中的相对较新成员,它优化了内存使用而不是磁盘。数据在必要时仍然可以溢出到磁盘,但 Spark 只有在被明确指令或活动数据集不适合内存时才会进行溢出。Spark 存储 lineage 信息,以便在节点故障或由于其他原因信息从内存中删除时重新计算活动数据集。这与传统的 MapReduce 方法形成对比,在每次 map 或 reduce 任务之后,数据都会持久化到磁盘。

Spark 特别适合在分布式节点上执行迭代或统计机器学习算法,并且可以扩展到核心之外。唯一的限制是所有 Spark 节点上可用的总内存和磁盘空间以及网络速度。我将在本章中介绍 Spark 架构和实现的基础知识。

可以通过简单地更改配置参数,将 Spark 直接指向在单个节点或一组节点上执行数据管道。当然,这种灵活性是以稍微更重的框架和更长的设置时间为代价的,但框架非常易于并行化,并且由于大多数现代笔记本电脑已经多线程且足够强大,这通常不会成为一个大问题。

在本章中,我们将涵盖以下主题:

  • 如果您还没有这样做,安装和配置 Spark

  • 学习 Spark 架构的基础以及为什么它与 Scala 语言紧密相连

  • 学习为什么 Spark 是继顺序实现和 Hadoop MapReduce 之后的下一代技术

  • 学习 Spark 组件

  • 查看 Scala 和 Spark 中单词计数的简单实现

  • 查看流式单词计数实现

  • 了解如何从分布式文件或分布式数据库创建 Spark DataFrame

  • 学习 Spark 性能调优

设置 Spark

如果您还没有这样做,您可以从 spark.apache.org/downloads.html 下载预构建的 Spark 包。写作时的最新版本是 1.6.1

设置 Spark

图 03-1. 该章节推荐的下载网站 http://spark.apache.org,以及本章的推荐选择

或者,您可以通过从 github.com/apache/spark 下载完整的源代码分布来构建 Spark:

$ git clone https://github.com/apache/spark.git
Cloning into 'spark'...
remote: Counting objects: 301864, done.
...
$ cd spark
$sh ./ dev/change-scala-version.sh 2.11
...
$./make-distribution.sh --name alex-build-2.6-yarn --skip-java-test --tgz -Pyarn -Phive -Phive-thriftserver -Pscala-2.11 -Phadoop-2.6
...

命令将下载必要的依赖并创建位于 Spark 目录下的 spark-2.0.0-SNAPSHOT-bin-alex-spark-build-2.6-yarn.tgz 文件;版本号为 2.0.0,因为它是写作时的下一个发布版本。通常情况下,除非你对最新特性感兴趣,否则你不想从主干分支构建。如果你想获取发布版本,可以检出相应的标签。可以通过 git branch –r 命令查看可用的完整版本列表。spark*.tgz 文件是你在任何安装了 Java JRE 的机器上运行 Spark 所需要的全部。

该发行版附带 docs/building-spark.md 文档,该文档描述了构建 Spark 的其他选项及其描述,包括增量 Scala 编译器 zinc。下一个 Spark 2.0.0 版本的完整 Scala 2.11 支持正在开发中。

理解 Spark 架构

并行执行涉及将工作负载拆分为在不同线程或不同节点上执行的子任务。让我们看看 Spark 如何做到这一点,以及它是如何管理子任务之间的执行和通信的。

任务调度

Spark 工作负载拆分由 Resilient Distributed DatasetRDD)的分区数量决定,RDD 是 Spark 中的基本抽象,以及管道结构。RDD 代表了一个不可变、分区元素集合,这些元素可以并行操作。虽然具体细节可能取决于 Spark 运行的模式,但以下图表捕捉了 Spark 任务/资源调度的过程:

任务调度

图 03-2. 一个通用的 Spark 任务调度图。虽然图中没有明确显示,但 Spark Context 会打开一个 HTTP UI,通常在端口 4040(并发上下文将打开 4041、4042 等端口),在任务执行期间存在。Spark Master UI 通常为 8080(虽然在 CDH 中被改为 18080),Worker UI 通常为 7078。每个节点可以运行多个执行器,每个执行器可以运行多个任务。

小贴士

你会发现 Spark,以及 Hadoop,有很多参数。其中一些被指定为环境变量(参见图 $SPARK_HOME/conf/spark-env.sh 文件),还有一些可以作为命令行参数提供。此外,一些具有预定义名称的文件可以包含将改变 Spark 行为的参数,例如 core-site.xml。这可能会让人困惑,我将在本章和下一章尽可能多地涵盖这些内容。如果你正在使用 Hadoop 分布式文件系统HDFS),那么 core-site.xmlhdfs-site.xml 文件将包含 HDFS 主节点的指针和规范。选择此文件的要求是它必须位于 CLASSPATH Java 进程中,这可以通过指定 HADOOP_CONF_DIRSPARK_CLASSPATH 环境变量来设置。与开源项目一样,有时你需要 grep 代码来理解各种参数的工作方式,因此在你笔记本电脑上保留源代码树副本是个好主意。

集群中的每个节点可以运行一个或多个执行器,每个执行器可以调度一系列任务以执行 Spark 操作。Spark 驱动程序负责调度执行,并与集群调度器(如 Mesos 或 YARN)一起调度可用资源。Spark 驱动程序通常在客户端机器上运行,但在最新版本中,它也可以在集群管理器下运行。YARN 和 Mesos 有能力动态管理每个节点上并发运行的执行器数量,前提是满足资源限制。

在 Standalone 模式下,Spark Master 执行集群调度器的工作——在资源分配方面可能不太高效,但在没有预配置 Mesos 或 YARN 的情况下,总比没有好。Spark 标准发行版包含在 sbin 目录下启动 Spark 的 shell 脚本。Spark Master 和驱动程序直接与运行在各个节点上的一个或多个 Spark 工作节点通信。一旦主节点运行,你可以使用以下命令启动 Spark Shell:

$ bin/spark-shell --master spark://<master-address>:7077

小贴士

注意,你始终可以在本地模式下运行 Spark,这意味着所有任务都将在一个 JVM 中执行,通过指定 --master local[2],其中 2 是至少需要 2 个线程的数量。实际上,我们将在本书中非常频繁地使用本地模式来运行小型示例。

Spark Shell 是从 Spark 视角的一个应用程序。一旦你启动一个 Spark 应用程序,你将在 Spark Master UI(或相应的集群管理器)中的 运行中的应用程序 下看到它,这可以让你重定向到 Spark 应用程序 HTTP UI,端口为 4040,在那里可以看到子任务执行时间线以及其他重要属性,例如环境设置、类路径、传递给 JVM 的参数以及资源使用信息(参见图 3-3):

任务调度

图 03-3. 独立模式下的 Spark 驱动器 UI 与时间分解

如我们所见,使用 Spark,可以通过提供 --master 命令行选项、设置 MASTER 环境变量或修改 spark-defaults.conf(应在执行期间位于类路径上)来轻松地在本地模式和集群模式之间切换,或者甚至可以直接在 Scala 中使用 SparkConf 对象的 setters 方法显式设置,这将在后面介绍:

集群管理器 MASTER 环境变量 注释
本地模式(单节点,多线程) local[n] n 是要使用的线程数,应大于或等于 2. 如果您想让 Spark 与其他 Hadoop 工具(如 Hive)通信,您仍然需要通过设置 HADOOP_CONF_DIR 环境变量或将 Hadoop *-site.xml 配置文件复制到 conf 子目录来指向集群。
独立模式(在节点上运行的守护进程) spark:// master-address>:7077 此模式在 $SPARK_HOME/sbin 目录下有一组启动/停止脚本。此模式也支持高可用性模式。更多详情请参阅 spark.apache.org/docs/latest/spark-standalone.html
Mesos mesos://host:5050mesos://zk://host:2181(多主) 在这里,您需要设置 MESOS_NATIVE_JAVA_LIBRARY=<libmesos.so 路径>SPARK_EXECUTOR_URI=<spark-1.5.0.tar.gz 的 URL>. 默认为细粒度模式,其中每个 Spark 任务作为一个独立的 Mesos 任务运行。用户还可以指定粗粒度模式,其中 Mesos 任务持续整个应用程序的运行时间。这种模式的优点是总启动成本较低。在粗粒度模式下,可以使用动态分配(参考以下 URL)。更多详情请参阅 spark.apache.org/docs/latest/running-on-mesos.html
YARN yarn Spark 驱动器可以在集群或客户端节点上运行,由 --deploy-mode 参数(集群或客户端,shell 只能在客户端模式下运行)管理。设置 HADOOP_CONF_DIRYARN_CONF_DIR 以指向 YARN 配置文件。使用 --num-executors 标志或 spark.executor.instances 属性设置固定数量的执行器(默认)。将 spark.dynamicAllocation.enabled 设置为 true 以根据应用程序需求动态创建/销毁执行器。更多详情请参阅 spark.apache.org/docs/latest/running-on-yarn.html

最常见的端口是 8080(主 UI)和 4040(应用程序 UI)。其他 Spark 端口总结在下表中:

独立模式端口
默认端口 用途
--- --- --- ---
浏览器 独立模式主节点 8080 Web UI
浏览器 独立工作者 8081 Web UI
驾驶员/独立工作者 独立主节点 7077 将作业提交到集群/加入集群
独立主节点 独立工作者 (随机) 调度执行器
执行器/独立主节点 驱动器 (随机) 连接到应用程序/通知执行器状态变化
其他端口
默认端口 用途
浏览器 应用程序 4040 Web UI
浏览器 历史服务器 18080 Web UI
驱动器 执行器 (随机) 调度任务
执行器 驱动器 (随机) 文件服务器(用于文件和 jar 文件)
执行器 驱动器 (随机) HTTP 广播

此外,一些文档还可在源分布的docs子目录中找到,但可能已过时。

Spark 组件

自从 Spark 出现以来,已经编写了多个利用 Spark 缓存 RDD 能力的好处应用程序:Shark、Spork(Spark 上的 Pig)、图库(GraphX、GraphFrames)、流处理、MLlib 等;其中一些将在本章和后续章节中介绍。

在本节中,我将介绍 Spark 中收集、存储和分析数据的主要架构组件。虽然我将在第二章中介绍更完整的数据生命周期架构,数据管道和建模,但以下是 Spark 特定的组件:

Spark 组件

图 03-4。Spark 架构和组件。

MQTT、ZeroMQ、Flume 和 Kafka

所有这些都是在不丢失和重复的情况下,可靠地将数据从一个地方移动到另一个地方的不同方式。它们通常实现一个发布-订阅模型,其中多个编写者和读者可以从同一个队列中写入和读取,并具有不同的保证。Flume 作为一个第一个分布式日志和事件管理实现脱颖而出,但它正逐渐被 LinkedIn 开发的具有完全功能的发布-订阅分布式消息队列 Kafka 所取代,Kafka 可以选择性地在分布式节点集上持久化。我们在上一章中简要介绍了 Flume 和 Kafka。Flume 配置是基于文件的,传统上用于将消息从 Flume 源传递到 Flume 的一个或多个接收器。其中一种流行的源是netcat——监听端口的原始数据。例如,以下配置描述了一个每 30 秒(默认)接收数据并将其写入 HDFS 的代理:

# Name the components on this agent
a1.sources = r1
a1.sinks = k1
a1.channels = c1

# Describe/configure the source
a1.sources.r1.type = netcat
a1.sources.r1.bind = localhost
a1.sources.r1.port = 4987

# Describe the sink (the instructions to configure and start HDFS are provided in the Appendix)
a1.sinks.k1.type=hdfs
a1.sinks.k1.hdfs.path=hdfs://localhost:8020/flume/netcat/data
a1.sinks.k1.hdfs.filePrefix=chapter03.example
a1.sinks.k1.channel=c1
a1.sinks.k1.hdfs.writeFormat = Text

# Use a channel which buffers events in memory
a1.channels.c1.type = memory
a1.channels.c1.capacity = 1000
a1.channels.c1.transactionCapacity = 100

# Bind the source and sink to the channel
a1.sources.r1.channels = c1
a1.sinks.k1.channel = c1

此文件作为本书提供的代码的一部分包含在chapter03/conf目录中。让我们下载并启动 Flume 代理(使用flume.apache.org/download.html提供的 MD5 校验和进行检查):

$ wget http://mirrors.ocf.berkeley.edu/apache/flume/1.6.0/apache-flume-1.6.0-bin.tar.gz
$ md5sum apache-flume-1.6.0-bin.tar.gz
MD5 (apache-flume-1.6.0-bin.tar.gz) = defd21ad8d2b6f28cc0a16b96f652099
$ tar xf apache-flume-1.6.0-bin.tar.gz
$ cd apache-flume-1.6.0-bin
$ ./bin/flume-ng agent -Dlog.dir=. -Dflume.log.level=DEBUG,console -n a1 -f ../chapter03/conf/flume.conf
Info: Including Hadoop libraries found via (/Users/akozlov/hadoop-2.6.4/bin/hadoop) for HDFS access
Info: Excluding /Users/akozlov/hadoop-2.6.4/share/hadoop/common/lib/slf4j-api-1.7.5.jar from classpath
Info: Excluding /Users/akozlov/hadoop-2.6.4/share/hadoop/common/lib/slf4j-log4j12-1.7.5.jar from classpath
...

现在,在另一个窗口中,你可以输入一个 netcat 命令将文本发送到 Flume 代理:

$ nc localhost 4987
Hello
OK
World
OK

...

Flume 代理首先创建一个 *.tmp 文件,然后将其重命名为没有扩展名的文件(文件扩展名可以用来过滤正在写入的文件):

$ bin/hdfs dfs -text /flume/netcat/data/chapter03.example.1463052301372
16/05/12 04:27:25 WARN util.NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
1463052302380  Hello
1463052304307  World

在这里,每一行都是一个以毫秒为单位的 Unix 时间和接收到的数据。在这种情况下,我们将数据放入 HDFS,然后 Spark/Scala 程序可以从那里进行分析,我们可以排除以 *.tmp 文件名模式正在写入的文件。然而,如果你真的对最后一分钟的数据感兴趣,Spark 以及一些其他平台支持流式处理,我将在接下来的几节中介绍。

HDFS、Cassandra、S3 和 Tachyon

HDFS、Cassandra、S3 和 Tachyon 是将数据以不同保证的方式放入持久存储和计算节点的方法。HDFS 是作为 Hadoop 的一部分实现的分布式存储,为 Hadoop 生态系统中的许多产品提供后端服务。HDFS 将每个文件划分为块,默认块大小为 128 MB,并将每个块存储在至少三个节点上。尽管 HDFS 是可靠的并支持高可用性,但关于 HDFS 存储的一般抱怨是它速度较慢,尤其是在机器学习方面。Cassandra 是一种通用键/值存储,也存储行的多个副本,并可以配置为支持不同级别的数据一致性以优化读写速度。Cassandra 相比于 HDFS 模型的优势在于它没有中央主节点;读取和写入是基于共识算法完成的。然而,这有时可能会反映在 Cassandra 的稳定性上。S3 是亚马逊存储:数据存储在集群之外,这影响了 I/O 速度。最后,最近开发的 Tachyon 声称利用节点的内存来优化节点间工作集的访问。

此外,新的后端正在不断开发中,例如,来自 Cloudera 的 Kudu (getkudu.io/kudu.pdf) 和来自 GridGain 的 Ignite 文件系统 (IGFS) (apacheignite.gridgain.org/v1.0/docs/igfs))。两者都是开源的,并拥有 Apache 许可证。

Mesos、YARN 和 Standalone

正如我们之前提到的,Spark 可以在不同的集群资源调度器下运行。这些是用于在集群上调度 Spark 容器和任务的多种实现。调度器可以被视为集群内核,执行类似于操作系统内核的功能:资源分配、调度、I/O 优化、应用程序服务和用户界面。

Mesos 是最早的集群管理器之一,它使用与 Linux 内核相同的原理构建,只是在不同的抽象级别上。Mesos 从节点在每个机器上运行并提供跨整个数据中心和云环境的资源管理和调度 API。Mesos 使用 C++ 编写。

YARN 是由 Yahoo 开发的一个较新的集群管理器。YARN 中的每个节点运行一个节点管理器,它与可能运行在单独节点上的资源管理器通信。资源管理器调度任务以满足内存和 CPU 限制。Spark 驱动程序本身可以在集群中运行,这被称为 YARN 的集群模式。否则,在客户端模式下,只有 Spark 执行器在集群中运行,而调度 Spark 管道的驱动程序运行在运行 Spark shell 或提交程序的同一台机器上。在这种情况下,Spark 执行器将通过一个随机开放的端口与本地主机通信。YARN 是用 Java 编写的,其后果是 GC 暂停不可预测,这可能会使延迟的长尾更宽。

最后,如果这些资源调度器都不可用,独立部署模式将在每个节点上启动一个org.apache.spark.deploy.worker.Worker进程,该进程与作为org.apache.spark.deploy.master.Master运行的 Spark Master 进程通信。工作进程完全由主进程管理,可以运行多个执行器和任务(参见图 3-2)。

在实际实施中,建议通过驱动程序的 UI 跟踪程序的并行性和所需资源,并根据需要调整并行性和可用内存,如果需要的话增加并行性。在下一节中,我们将开始探讨 Scala 和 Spark 中的 Scala 如何解决不同的问题。

应用程序

让我们考虑一些 Spark/Scala 中的实际示例和库,从一个非常传统的单词计数问题开始。

单词计数

大多数现代机器学习算法需要对数据进行多次遍历。如果数据适合单个机器的内存,数据就可以随时可用,这不会成为性能瓶颈。然而,如果数据变得太大而无法放入 RAM 中,可以选择将数据的一部分(或数据库)写入磁盘,这大约慢 100 倍,但容量要大得多,或者在网络中的多台机器之间分割数据集并传输结果。尽管仍有持续的争论,但对于大多数实际系统,分析表明,在一系列网络连接的节点上存储数据,与在单个节点上反复从硬盘存储和读取数据相比,略有优势,尤其是如果我们能够有效地在多个 CPU 之间分配工作负载。

小贴士

一块普通磁盘的带宽大约为 100 MB/sec,传输延迟仅为几毫秒,这取决于旋转速度和缓存。这比从内存中读取数据慢大约 100 倍,具体取决于数据大小和缓存实现。现代数据总线可以以超过 10 GB/sec 的速度传输数据。尽管网络速度仍然落后于直接内存访问,尤其是在标准 TCP/IP 内核网络层开销的情况下,但专用硬件可以达到数十 GB/sec,如果并行运行,其速度可能接近从内存中读取。实际上,网络传输速度在 1 到 10 GB/sec 之间,但在大多数实际系统中仍然比磁盘快。因此,我们有可能将数据放入所有集群节点的组合内存中,并在它们组成的系统中执行迭代机器学习算法。

然而,内存的一个问题是它无法在节点故障和重启后持久化。一个流行的大数据框架 Hadoop,在原始的 Dean/Ghemawat 论文(Jeff Dean 和 Sanjay Ghemawat,MapReduce: Simplified Data Processing on Large Clusters,OSDI,2004 年)的帮助下成为可能,正是使用磁盘层持久化来保证容错性和存储中间结果。一个 Hadoop MapReduce 程序首先在数据集的每一行上运行一个map函数,输出一个或多个键值对。然后,这些键值对将被排序、分组和按键聚合,以便具有相同键的记录最终会在同一个 reducer 上一起处理,这个 reducer 可能运行在同一个或另一个节点上。reducer 应用一个reduce函数,遍历为相同键发出的所有值,并相应地聚合它们。中间结果的持久化将保证如果 reducer 由于一个或另一个原因失败,可以丢弃部分计算,并从检查点保存的结果重新启动 reduce 计算。许多简单的 ETL-like 应用程序仅遍历数据集一次,并且从一条记录到另一条记录只保留很少的信息作为状态。

例如,MapReduce 的一个传统应用是词频统计。程序需要统计文档中每行文本中每个单词的出现次数。在 Scala 中,词频统计可以很容易地表示为对排序单词列表应用foldLeft方法:

val lines = scala.io.Source.fromFile("...").getLines.toSeq
val counts = lines.flatMap(line => line.split("\\W+")).sorted.
  foldLeft(List[(String,Int)]()){ (r,c) =>
    r match {
      case (key, count) :: tail =>
        if (key == c) (c, count+1) :: tail
        else (c, 1) :: r
        case Nil =>
          List((c, 1))
  }
}

如果我运行这个程序,输出将是一个包含(word, count)元组的列表。程序将行分割成单词,对单词进行排序,然后将每个单词与(word, count)元组列表中的最新条目进行匹配。在 MapReduce 中,同样的计算可以表示如下:

val linesRdd = sc.textFile("hdfs://...")
val counts = linesRdd.flatMap(line => line.split("\\W+"))
    .map(_.toLowerCase)
    .map(word => (word, 1)).
    .reduceByKey(_+_)
counts.collect

首先,我们需要通过将行拆分为单词和生成(word, 1)对来处理文本的每一行。这个任务很容易并行化。然后,为了并行化全局计数,我们需要通过为单词的子集分配一个任务来拆分计数部分。在 Hadoop 中,我们计算单词的哈希值并根据哈希值来划分工作。

一旦 map 任务找到了给定哈希的所有条目,它就可以将键值对发送给 reducer,这部分发送通常在 MapReduce 术语中称为 shuffle。reducer 会等待从所有 mapper 那里接收到所有的键值对,合并值——如果可能的话,在 mapper 上也可以进行部分合并——并计算整体汇总,在这种情况下就是求和。单个 reducer 将看到给定单词的所有值。

让我们看看 Spark 中单词计数操作的日志输出(Spark 默认非常详细,你可以通过修改conf/log4j.properties文件来管理详细程度,将INFO替换为ERRORFATAL):

$ wget http://mirrors.sonic.net/apache/spark/spark-1.6.1/spark-1.6.1-bin-hadoop2.6.tgz
$ tar xvf spark-1.6.1-bin-hadoop2.6.tgz
$ cd spark-1.6.1-bin-hadoop2.6
$ mkdir leotolstoy
$ (cd leotolstoy; wget http://www.gutenberg.org/files/1399/1399-0.txt)
$ bin/spark-shell 
Welcome to
 ____              __
 / __/__  ___ _____/ /__
 _\ \/ _ \/ _ `/ __/  '_/
 /___/ .__/\_,_/_/ /_/\_\   version 1.6.1
 /_/

Using Scala version 2.11.7 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_40)
Type in expressions to have them evaluated.
Type :help for more information.
Spark context available as sc.
SQL context available as sqlContext.
scala> val linesRdd = sc.textFile("leotolstoy", minPartitions=10)
linesRdd: org.apache.spark.rdd.RDD[String] = leotolstoy MapPartitionsRDD[3] at textFile at <console>:27

在这个阶段,唯一发生的事情是元数据操作,Spark 还没有触及数据本身。Spark 估计数据集的大小和分区数。默认情况下,这是 HDFS 块的数量,但我们可以使用minPartitions参数显式指定最小分区数:

scala> val countsRdd = linesRdd.flatMap(line => line.split("\\W+")).
 | map(_.toLowerCase).
 | map(word => (word, 1)).
 | reduceByKey(_+_)
countsRdd: org.apache.spark.rdd.RDD[(String, Int)] = ShuffledRDD[5] at reduceByKey at <console>:31

我们刚刚定义了另一个由原始linesRdd派生出的 RDD:

scala> countsRdd.collect.filter(_._2 > 99)
res3: Array[(String, Int)] = Array((been,1061), (them,841), (found,141), (my,794), (often,105), (table,185), (this,1410), (here,364), (asked,320), (standing,132), ("",13514), (we,592), (myself,140), (is,1454), (carriage,181), (got,277), (won,153), (girl,117), (she,4403), (moment,201), (down,467), (me,1134), (even,355), (come,667), (new,319), (now,872), (upon,207), (sister,115), (veslovsky,110), (letter,125), (women,134), (between,138), (will,461), (almost,124), (thinking,159), (have,1277), (answer,146), (better,231), (men,199), (after,501), (only,654), (suddenly,173), (since,124), (own,359), (best,101), (their,703), (get,304), (end,110), (most,249), (but,3167), (was,5309), (do,846), (keep,107), (having,153), (betsy,111), (had,3857), (before,508), (saw,421), (once,334), (side,163), (ough...

对超过 2 GB 的文本数据进行单词计数——40,291 行和 353,087 个单词——读取、拆分和按单词分组耗时不到一秒。

使用扩展日志,你可以看到以下内容:

  • Spark 打开了一些端口以与 executors 和用户通信

  • Spark UI 在http://localhost:4040的 4040 端口上运行

  • 你可以从本地或分布式存储(HDFS、Cassandra 和 S3)读取文件

  • 如果 Spark 带有 Hive 支持构建,Spark 将连接到 Hive

  • Spark 使用懒加载评估,仅在必要时或需要输出时才执行管道

  • Spark 使用内部调度器将作业拆分为任务,优化执行并执行任务

  • 结果存储到 RDD 中,可以使用save方法保存或使用collect方法将其带入执行 shell 的节点的 RAM 中

并行性能调优的艺术在于在不同节点或线程之间分配工作负载,以便开销相对较小且工作负载平衡。

流式单词计数

Spark 支持监听传入的流,对其进行分区,并接近实时地计算汇总。目前支持的资源包括 Kafka、Flume、HDFS/S3、Kinesis、Twitter,以及传统的 MQs,如 ZeroMQ 和 MQTT。在 Spark 中,流式处理作为微批处理实现的。内部,Spark 将输入数据划分为微批处理,通常从亚秒到分钟大小,并对这些微批处理执行 RDD 聚合操作。

例如,让我们扩展一下我们之前提到的 Flume 示例。我们需要修改 Flume 配置文件以创建一个 Spark 轮询接收器。用 HDFS 代替接收器部分:

# The sink is Spark
a1.sinks.k1.type=org.apache.spark.streaming.flume.sink.SparkSink
a1.sinks.k1.hostname=localhost
a1.sinks.k1.port=4989

现在,Flume 将等待 Spark 轮询数据,而不是写入 HDFS:

object FlumeWordCount {
  def main(args: Array[String]) {
    // Create the context with a 2 second batch size
    val sparkConf = new SparkConf().setMaster("local[2]").setAppName("FlumeWordCount")
    val ssc = new StreamingContext(sparkConf, Seconds(2))
    ssc.checkpoint("/tmp/flume_check")
    val hostPort=args(0).split(":")
    System.out.println("Opening a sink at host: [" + hostPort(0) + "] port: [" + hostPort(1).toInt + "]")
    val lines = FlumeUtils.createPollingStream(ssc, hostPort(0), hostPort(1).toInt, StorageLevel.MEMORY_ONLY)
    val words = lines
      .map(e => new String(e.event.getBody.array)).map(_.toLowerCase).flatMap(_.split("\\W+"))
      .map(word => (word, 1L))
      .reduceByKeyAndWindow(_+_, _-_, Seconds(6), Seconds(2)).print
    ssc.start()
    ssc.awaitTermination()
  }
}

要运行程序,在一个窗口中启动 Flume 代理:

$ ./bin/flume-ng agent -Dflume.log.level=DEBUG,console -n a1 –f ../chapter03/conf/flume-spark.conf
...

然后在另一个窗口中运行 FlumeWordCount 对象:

$ cd ../chapter03
$ sbt "run-main org.akozlov.chapter03.FlumeWordCount localhost:4989
...

现在,任何输入到 netcat 连接的文本将被分割成单词,并在每两秒为一个六秒滑动窗口内进行计数:

$ echo "Happy families are all alike; every unhappy family is unhappy in its own way" | nc localhost 4987
...
-------------------------------------------
Time: 1464161488000 ms
-------------------------------------------
(are,1)
(is,1)
(its,1)
(family,1)
(families,1)
(alike,1)
(own,1)
(happy,1)
(unhappy,2)
(every,1)
...

-------------------------------------------
Time: 1464161490000 ms
-------------------------------------------
(are,1)
(is,1)
(its,1)
(family,1)
(families,1)
(alike,1)
(own,1)
(happy,1)
(unhappy,2)
(every,1)
...

Spark/Scala 允许无缝地在流源之间切换。例如,用于 Kafka 发布/订阅主题模型的相同程序看起来类似于以下内容:

object KafkaWordCount {
  def main(args: Array[String]) {
    // Create the context with a 2 second batch size
    val sparkConf = new SparkConf().setMaster("local[2]").setAppName("KafkaWordCount")
    val ssc = new StreamingContext(sparkConf, Seconds(2))
    ssc.checkpoint("/tmp/kafka_check")
    System.out.println("Opening a Kafka consumer at zk:[" + args(0) + "] for group group-1 and topic example")
    val lines = KafkaUtils.createStream(ssc, args(0), "group-1", Map("example" -> 1), StorageLevel.MEMORY_ONLY)
    val words = lines
      .flatMap(_._2.toLowerCase.split("\\W+"))
      .map(word => (word, 1L))
      .reduceByKeyAndWindow(_+_, _-_, Seconds(6), Seconds(2)).print
    ssc.start()
    ssc.awaitTermination()
  }
}

要启动 Kafka 代理,首先下载最新的二进制发行版并启动 ZooKeeper。ZooKeeper 是分布式服务协调器,即使在单节点部署中 Kafka 也需要它:

$ wget http://apache.cs.utah.edu/kafka/0.9.0.1/kafka_2.11-0.9.0.1.tgz
...
$ tar xf kafka_2.11-0.9.0.1.tgz
$ bin/zookeeper-server-start.sh config/zookeeper.properties
...

在另一个窗口中,启动 Kafka 服务器:

$ bin/kafka-server-start.sh config/server.properties
...

运行 KafkaWordCount 对象:

$ sbt "run-main org.akozlov.chapter03.KafkaWordCount localhost:2181"
...

现在,将单词流发布到 Kafka 主题将生成窗口计数:

$ echo "Happy families are all alike; every unhappy family is unhappy in its own way" | ./bin/kafka-console-producer.sh --broker-list localhost:9092 --topic example
...

$ sbt "run-main org.akozlov.chapter03.FlumeWordCount localhost:4989
...
-------------------------------------------
Time: 1464162712000 ms
-------------------------------------------
(are,1)
(is,1)
(its,1)
(family,1)
(families,1)
(alike,1)
(own,1)
(happy,1)
(unhappy,2)
(every,1)

如您所见,程序每两秒输出一次。Spark 流处理有时被称为 微批处理。流处理有其他许多应用(和框架),但这个话题太大,不能完全在这里讨论,需要单独介绍。我将在 第五章,回归和分类 中介绍一些关于数据流中的机器学习。现在,让我们回到更传统的类似 SQL 的接口。

Spark SQL 和 DataFrame

DataFrame 是 Spark 中相对较新的功能,自 1.3 版本引入,允许使用标准 SQL 语言进行数据分析。我们已经在 第一章,探索性数据分析 中使用了一些 SQL 命令,用于探索性数据分析。SQL 对于简单的探索性分析和数据聚合来说非常好用。

根据最新的调查结果,大约 70% 的 Spark 用户使用 DataFrame。尽管 DataFrame 最近已成为处理表格数据最流行的框架,但它是一个相对较重的对象。使用 DataFrame 的管道可能比基于 Scala 的向量或 LabeledPoint 的管道执行得慢得多,这些将在下一章中讨论。来自不同开发者的证据表明,响应时间可以根据查询从几十毫秒到几百毫秒不等,对于更简单的对象,甚至可以低于毫秒。

Spark 实现了自己的 SQL shell,除了标准的 Scala REPL shell 之外还可以调用:./bin/spark-sql 可以用来访问现有的 Hive/Impala 或关系型数据库表:

$ ./bin/spark-sql
…
spark-sql> select min(duration), max(duration), avg(duration) from kddcup;
…
0  58329  48.34243046395876
Time taken: 11.073 seconds, Fetched 1 row(s)

在标准的 Spark REPL 中,可以通过运行以下命令执行相同的查询:

$ ./bin/spark-shell
…
scala> val df = sqlContext.sql("select min(duration), max(duration), avg(duration) from kddcup"
16/05/12 13:35:34 INFO parse.ParseDriver: Parsing command: select min(duration), max(duration), avg(duration) from alex.kddcup_parquet
16/05/12 13:35:34 INFO parse.ParseDriver: Parse Completed
df: org.apache.spark.sql.DataFrame = [_c0: bigint, _c1: bigint, _c2: double]
scala> df.collect.foreach(println)
…
16/05/12 13:36:32 INFO scheduler.DAGScheduler: Job 2 finished: collect at <console>:22, took 4.593210 s
[0,58329,48.34243046395876]

机器学习库

Spark,尤其是与基于内存的存储系统结合使用,声称可以显著提高节点内和节点间的数据访问速度。机器学习似乎是一个自然的选择,因为许多算法需要对数据进行多次遍历或重新分区。MLlib 是首选的开源库,尽管私有公司也在追赶,推出了自己的专有实现。

正如我将在第五章中阐述的,回归与分类,大多数标准机器学习算法都可以表示为一个优化问题。例如,经典线性回归最小化回归线与实际值 y 之间 y 距离的平方和:

机器学习库

在这里,机器学习库是根据线性表达式预测的值:

机器学习库

A 通常被称为斜率,而 B 被称为截距。在更广义的表述中,线性优化问题是最小化一个加性函数:

机器学习库

在这里,机器学习库是一个损失函数,机器学习库是一个正则化函数。正则化函数是模型复杂度的递增函数,例如参数的数量(或其自然对数)。以下表格给出了最常见的损失函数:

损失函数 L 梯度
线性 机器学习库 机器学习库
逻辑回归 机器学习库 机器学习库
拉链损失 机器学习库 机器学习库

正则化器的目的是惩罚更复杂的模型,以避免过拟合并提高泛化误差:目前 MLlib 支持以下正则化器:

正则化器 R 梯度
L2 机器学习库 机器学习库
L1 机器学习库 机器学习库
弹性网络 机器学习库 机器学习库

在这里,sign(w)w 所有元素的符号向量。

目前,MLlib 包括以下算法的实现:

  • 基本统计学:

    • 概率统计

    • 相关系数

    • 分层抽样

    • 假设检验

    • 流式显著性检验

    • 随机数据生成

  • 分类与回归:

    • 线性模型(SVM、逻辑回归和线性回归)

    • 朴素贝叶斯

    • 决策树

    • 树的集成(随机森林和梯度提升树)

    • 等距回归

  • 协同过滤:

    • 交替最小二乘法 (ALS)
  • 聚类:

    • k-means

    • 高斯混合

    • 幂迭代聚类 (PIC)

    • 潜在狄利克雷分配 (LDA)

    • 二分 k-means

    • 流式 k-means

  • 维度降低:

    • 奇异值分解SVD

    • 主成分分析PCA

  • 特征提取和转换

  • 频繁模式挖掘:

    • FP-growth?关联规则

    • PrefixSpan

  • 优化:

    • 随机梯度下降SGD

    • 有限内存 BFGSL-BFGS

我将在第五章中介绍一些算法,回归和分类。更复杂的非结构化机器学习方法将在第六章中考虑,处理非结构化数据

SparkR

R 是由约翰·查默斯在贝尔实验室工作时创建的流行 S 编程语言的实现。R 目前由R 统计计算基金会支持。根据调查,R 的普及率近年来有所增加。SparkR 提供了从 R 使用 Apache Spark 的轻量级前端。从 Spark 1.6.0 开始,SparkR 提供了支持选择、过滤、聚合等操作的分布式 DataFrame 实现,这与 R DataFrames、dplyr 类似,但适用于非常大的数据集。SparkR 还支持使用 MLlib 的分布式机器学习。

SparkR 需要 R 版本 3 或更高版本,可以通过./bin/sparkR shell 调用。我将在第八章中介绍 SparkR,集成 Scala 与 R 和 Python

图算法 – GraphX 和 GraphFrames

图算法是节点间正确分配中最困难的一类,除非图本身是自然划分的,也就是说,它可以表示为一组不连接的子图。由于 Facebook、Google 和 LinkedIn 等公司使得在数百万节点规模上的社交网络分析变得流行,研究人员一直在提出新的方法来形式化图表示、算法和提出的问题类型。

GraphX 是一个现代的图计算框架,由 2013 年的一篇论文描述(Reynold Xin、Joseph Gonzalez、Michael Franklin 和 Ion Stoica 的《GraphX: A Resilient Distributed Graph System on Spark》,GRADES(SIGMOD workshop),2013)。它有 Pregel 和 PowerGraph 这样的图并行框架作为前身。图由两个 RDD 表示:一个用于顶点,另一个用于边。一旦 RDDs 被连接,GraphX 支持 Pregel-like API 或 MapReduce-like API,其中 map 函数应用于节点的邻居,reduce 是在 map 结果之上的聚合步骤。

在撰写本文时,GraphX 包括以下图算法的实现:

  • PageRank

  • 连通分量

  • 三角计数

  • 标签传播

  • SVD++(协同过滤)

  • 强连通分量

由于 GraphX 是一个开源库,预计列表会有所变化。GraphFrames 是 Databricks 的新实现,完全支持以下三种语言:Scala、Java 和 Python,并且建立在 DataFrames 之上。我将在第七章使用图算法中讨论具体的实现。

Spark 性能调优

虽然数据管道的高效执行是任务调度器的特权,它是 Spark 驱动程序的一部分,但有时 Spark 需要提示。Spark 调度主要受两个参数驱动:CPU 和内存。当然,其他资源,如磁盘和网络 I/O,在 Spark 性能中也扮演着重要角色,但 Spark、Mesos 或 YARN 目前无法主动管理它们。

需要关注的第一个参数是 RDD 分区数,在从文件读取 RDD 时可以显式指定。Spark 通常倾向于过多的分区,因为它提供了更多的并行性,而且在许多情况下确实有效,因为任务设置/拆除时间相对较小。然而,人们可以尝试减少分区数,尤其是在进行聚合时。

默认情况下,每个 RDD 的分区数和并行级别由spark.default.parallelism参数决定,该参数定义在$SPARK_HOME/conf/spark-defaults.conf配置文件中。特定 RDD 的分区数也可以通过coalesce()repartition()方法显式更改。

总核心数和可用内存通常是死锁的原因,因为任务无法进一步执行。在通过命令行调用 spark-submit、spark-shell 或 PySpark 时,可以使用--executor-cores标志指定每个 executor 的核心数。或者,也可以在前面讨论的spark-defaults.conf文件中设置相应的参数。如果核心数设置得太高,调度器将无法在节点上分配资源,从而导致死锁。

以类似的方式,--executor-memory(或spark.executor.memory属性)指定了所有任务请求的堆大小(默认为 1g)。如果 executor 内存设置得太高,同样,调度器可能会死锁,或者只能在节点上调度有限数量的 executor。

在独立模式下,在计算核心数和内存时隐含的假设是 Spark 是唯一运行的应用程序——这可能是也可能不是真的。当在 Mesos 或 YARN 下运行时,配置集群调度器以使其具有 Spark 驱动器请求的执行器资源非常重要。相关的 YARN 属性是:yarn.nodemanager.resource.cpu-vcoresyarn.nodemanager.resource.memory-mb。YARN 可能会将请求的内存向上取整一点。YARN 的yarn.scheduler.minimum-allocation-mbyarn.scheduler.increment-allocation-mb属性分别控制最小和增量请求值。

JVM 也可以使用一些堆外内存,例如,用于内部字符串和直接字节数据缓冲区。spark.yarn.executor.memoryOverhead属性的值被添加到执行器内存中,以确定每个执行器对 YARN 的完整内存请求。默认值为最大值(384 + 0.07 * spark.executor.memory)。

由于 Spark 可以在执行器和客户端节点之间内部传输数据,因此高效的序列化非常重要。我将在第六章中考虑不同的序列化框架,处理非结构化数据,但 Spark 默认使用 Kryo 序列化,这要求类必须显式地在静态方法中注册。如果你在代码中看到序列化错误,很可能是因为相应的类尚未注册或 Kryo 不支持它,就像在过于嵌套和复杂的数据类型中发生的那样。一般来说,建议避免在执行器之间传递复杂对象,除非对象序列化可以非常高效地进行。

驱动器有类似的参数:spark.driver.coresspark.driver.memoryspark.driver.maxResultSize。后者设置了使用collect方法收集的所有执行器的结果限制。保护驱动器进程免受内存不足异常非常重要。避免内存不足异常及其后续问题的另一种方法是修改管道以返回聚合或过滤后的结果,或者使用take方法代替。

运行 Hadoop HDFS

一个分布式处理框架如果没有分布式存储就不会完整。其中之一是 HDFS。即使 Spark 在本地模式下运行,它仍然可以使用后端分布式文件系统。就像 Spark 将计算分解为子任务一样,HDFS 将文件分解为块并在多台机器上存储它们。对于高可用性,HDFS 为每个块存储多个副本,副本的数量称为复制级别,默认为三个(参见图 3-5)。

NameNode 通过记住块位置和其他元数据(如所有者、文件权限和块大小,这些是文件特定的)来管理 HDFS 存储。Secondary Namenode 是一个轻微的错误名称:其功能是将元数据修改、编辑合并到 fsimage 中,或者是一个充当元数据库的文件。合并是必要的,因为将 fsimage 的修改写入单独的文件比直接将每个修改应用于 fsimage 的磁盘镜像更实际(除了在内存中应用相应的更改)。二级 Namenode 不能作为 Namenode 的第二个副本。运行 Balancer 以将块移动到服务器之间保持大约相等的磁盘使用率——节点上的初始块分配应该是随机的,如果空间足够且客户端不在集群内运行。最后,ClientNamenode 通信以获取元数据和块位置,但之后,要么直接读取或写入数据到节点,其中包含块的副本。客户端是唯一可以在 HDFS 集群外运行的组件,但它需要与集群中所有节点建立网络连接。

如果任何节点死亡或从网络断开连接,Namenode 会注意到这种变化,因为它通过心跳不断与节点保持联系。如果节点在 10 分钟内(默认值)未能重新连接到 Namenode,则 Namenode 将开始复制块,以实现节点上丢失的块所需的复制级别。Namenode 中的单独块扫描线程将扫描块以查找可能的位错——每个块都维护一个校验和——并将损坏的和孤立的块删除:

运行 Hadoop HDFS

图 03-5. 这是 HDFS 架构。每个块存储在三个不同的位置(复制级别)。

  1. 在您的机器上启动 HDFS(复制级别为 1)时,请下载一个 Hadoop 发行版,例如,从 hadoop.apache.org

    $ wget ftp://apache.cs.utah.edu/apache.org/hadoop/common/h/hadoop-2.6.4.tar.gz
    --2016-05-12 00:10:55--  ftp://apache.cs.utah.edu/apache.org/hadoop/common/hadoop-2.6.4/hadoop-2.6.4.tar.gz
     => 'hadoop-2.6.4.tar.gz.1'
    Resolving apache.cs.utah.edu... 155.98.64.87
    Connecting to apache.cs.utah.edu|155.98.64.87|:21... connected.
    Logging in as anonymous ... Logged in!
    ==> SYST ... done.    ==> PWD ... done.
    ==> TYPE I ... done.  ==> CWD (1) /apache.org/hadoop/common/hadoop-2.6.4 ... done.
    ==> SIZE hadoop-2.6.4.tar.gz ... 196015975
    ==> PASV ... done.    ==> RETR hadoop-2.6.4.tar.gz ... done.
    ...
    $ wget ftp://apache.cs.utah.edu/apache.org/hadoop/common/hadoop-2.6.4/hadoop-2.6.4.tar.gz.mds
    --2016-05-12 00:13:58--  ftp://apache.cs.utah.edu/apache.org/hadoop/common/hadoop-2.6.4/hadoop-2.6.4.tar.gz.mds
     => 'hadoop-2.6.4.tar.gz.mds'
    Resolving apache.cs.utah.edu... 155.98.64.87
    Connecting to apache.cs.utah.edu|155.98.64.87|:21... connected.
    Logging in as anonymous ... Logged in!
    ==> SYST ... done.    ==> PWD ... done.
    ==> TYPE I ... done.  ==> CWD (1) /apache.org/hadoop/common/hadoop-2.6.4 ... done.
    ==> SIZE hadoop-2.6.4.tar.gz.mds ... 958
    ==> PASV ... done.    ==> RETR hadoop-2.6.4.tar.gz.mds ... done.
    ...
    $ shasum -a 512 hadoop-2.6.4.tar.gz
    493cc1a3e8ed0f7edee506d99bfabbe2aa71a4776e4bff5b852c6279b4c828a0505d4ee5b63a0de0dcfecf70b4bb0ef801c767a068eaeac938b8c58d8f21beec  hadoop-2.6.4.tar.gz
    $ cat !$.mds
    hadoop-2.6.4.tar.gz:    MD5 = 37 01 9F 13 D7 DC D8 19  72 7B E1 58 44 0B 94 42
    hadoop-2.6.4.tar.gz:   SHA1 = 1E02 FAAC 94F3 35DF A826  73AC BA3E 7498 751A 3174
    hadoop-2.6.4.tar.gz: RMD160 = 2AA5 63AF 7E40 5DCD 9D6C  D00E EBB0 750B D401 2B1F
    hadoop-2.6.4.tar.gz: SHA224 = F4FDFF12 5C8E754B DAF5BCFC 6735FCD2 C6064D58
     36CB9D80 2C12FC4D
    hadoop-2.6.4.tar.gz: SHA256 = C58F08D2 E0B13035 F86F8B0B 8B65765A B9F47913
     81F74D02 C48F8D9C EF5E7D8E
    hadoop-2.6.4.tar.gz: SHA384 = 87539A46 B696C98E 5C7E352E 997B0AF8 0602D239
     5591BF07 F3926E78 2D2EF790 BCBB6B3C EAF5B3CF
     ADA7B6D1 35D4B952
    hadoop-2.6.4.tar.gz: SHA512 = 493CC1A3 E8ED0F7E DEE506D9 9BFABBE2 AA71A477
     6E4BFF5B 852C6279 B4C828A0 505D4EE5 B63A0DE0
     DCFECF70 B4BB0EF8 01C767A0 68EAEAC9 38B8C58D
     8F21BEEC
    
    $ tar xf hadoop-2.6.4.tar.gz
    $ cd hadoop-2.6.4
    
    
  2. 要获取最小的 HDFS 配置,修改 core-site.xmlhdfs-site.xml 文件,如下所示:

    $ cat << EOF > etc/hadoop/core-site.xml
    <configuration>
     <property>
     <name>fs.defaultFS</name>
     <value>hdfs://localhost:8020</value>
     </property>
    </configuration>
    EOF
    $ cat << EOF > etc/hadoop/hdfs-site.xml
    <configuration>
     <property>
     <name>dfs.replication</name>
     <value>1</value>
     </property>
    </configuration>
    EOF
    
    

    这将把 Hadoop HDFS 元数据和数据目录放在 /tmp/hadoop-$USER 目录下。为了使其更加持久,我们可以添加 dfs.namenode.name.dirdfs.namenode.edits.dirdfs.datanode.data.dir 参数,但现在我们将省略这些。对于更定制的发行版,可以从 archive.cloudera.com/cdh 下载 Cloudera 版本。

  3. 首先,我们需要编写一个空元数据:

    $ bin/hdfs namenode -format
    16/05/12 00:55:40 INFO namenode.NameNode: STARTUP_MSG: 
    /************************************************************
    STARTUP_MSG: Starting NameNode
    STARTUP_MSG:   host = alexanders-macbook-pro.local/192.168.1.68
    STARTUP_MSG:   args = [-format]
    STARTUP_MSG:   version = 2.6.4
    STARTUP_MSG:   classpath =
    ...
    
    
  4. 然后启动 namenodesecondarynamenodedatanode Java 进程(我通常打开三个不同的命令行窗口来查看日志,但在生产环境中,这些通常被作为守护进程运行):

    $ bin/hdfs namenode &
    ...
    $ bin/hdfs secondarynamenode &
    ...
    $ bin/hdfs datanode &
    ...
    
    
  5. 现在我们已经准备好创建第一个 HDFS 文件:

    $ date | bin/hdfs dfs –put – date.txt
    ...
    $ bin/hdfs dfs –ls
    Found 1 items
    -rw-r--r-- 1 akozlov supergroup 29 2016-05-12 01:02 date.txt
    $ bin/hdfs dfs -text date.txt
    Thu May 12 01:02:36 PDT 2016
    
    
  6. 当然,在这种情况下,实际文件只存储在一个节点上,这个节点就是我们运行 datanode 的节点(localhost)。在我的情况下,如下所示:

    $ cat /tmp/hadoop-akozlov/dfs/data/current/BP-1133284427-192.168.1.68-1463039756191/current/finalized/subdir0/subdir0/blk_1073741827
    Thu May 12 01:02:36 PDT 2016
    
    
  7. Namenode UI 可在 http://localhost:50070 找到,并显示大量信息,包括 HDFS 使用情况和 DataNodes 列表,即 HDFS 主节点的奴隶,如下所示:运行 Hadoop HDFS

    图 03-6. HDFS NameNode UI 的快照。

前面的图显示了单节点部署中的 HDFS Namenode HTTP UI(通常为 http://<namenode-address>:50070)。实用工具 | 浏览文件系统选项卡允许您浏览和下载 HDFS 中的文件。可以通过在另一个节点上启动 DataNodes 并使用 fs.defaultFS=<namenode-address>:8020 参数指向 Namenode 来添加节点。辅助 Namenode HTTP UI 通常位于 http:<secondarynamenode-address>:50090

Scala/Spark 默认将使用本地文件系统。然而,如果 core-site/xml 文件位于类路径中或放置在 $SPARK_HOME/conf 目录中,Spark 将使用 HDFS 作为默认。

摘要

在本章中,我以非常高的水平介绍了 Spark/Hadoop 以及它们与 Scala 和函数式编程的关系。我考虑了一个经典的词频统计示例及其在 Scala 和 Spark 中的实现。我还提供了 Spark 生态系统的高级组件,包括词频统计和流处理的特定示例。我现在有了所有组件来开始查看 Scala/Spark 中经典机器学习算法的具体实现。在下一章中,我将首先介绍监督学习和无监督学习——这是结构化数据学习算法的传统划分。

第四章 监督学习和无监督学习

我在上一章中介绍了 MLlib 库的基础知识,但至少在撰写本书时,MLlib 更像是一个快速移动的目标,它正在取得领先地位,而不是一个结构良好的实现,每个人都用于生产中,甚至没有一个一致且经过测试的文档。在这种情况下,正如人们所说,与其给你鱼,我更愿意专注于库背后的成熟概念,并在本书中教授钓鱼的过程,以避免每次新的 MLlib 发布都需要大幅修改章节。不管好坏,这越来越像是一个数据科学家需要掌握的技能。

统计学和机器学习本质上处理不确定性,由于我们在第二章中讨论的一个或另一个原因,数据管道和建模。虽然一些数据集可能是完全随机的,但这里的目的是找到趋势、结构和模式,这些是随机数生成器无法提供的。机器学习的基本价值在于我们可以泛化这些模式并在至少一些指标上取得改进。让我们看看 Scala/Spark 中可用的基本工具。

本章,我将介绍监督学习和无监督学习,这两种历史上不同的方法。监督学习传统上用于当我们有一个特定的目标来预测一个标签,或数据集的特定属性时。无监督学习可以用来理解数据集中任何属性之间的内部结构和依赖关系,并且通常用于将记录或属性分组到有意义的聚类中。在实践中,这两种方法都可以用来补充和辅助对方。

本章将涵盖以下主题:

  • 学习监督学习的标准模型——决策树和逻辑回归

  • 讨论无监督学习的基础——k-均值聚类及其衍生

  • 理解评估上述算法有效性的指标和方法

  • 略窥上述方法在流数据、稀疏数据和非结构化数据特殊情况的扩展

记录和监督学习

为了本章的目的,记录是对一个或多个属性的一个观察或测量。我们假设这些观察可能包含噪声 记录和监督学习(或者由于某种原因不准确):

记录和监督学习

虽然我们相信属性之间存在某种模式或相关性,但我们追求并希望揭示的那个,噪声在属性或记录之间是不相关的。在统计术语中,我们说每个记录的值来自相同的分布,并且是独立的(或统计术语中的 i.i.d)。记录的顺序并不重要。其中一个属性,通常是第一个,可能被指定为标签。

监督学习的目标是预测标签 yi

记录和监督学习

在这里,N 是剩余属性的数量。换句话说,目标是泛化模式,以便我们只需知道其他属性就可以预测标签,无论是由于我们无法物理获取测量值,还是只想探索数据集的结构而不具有立即预测标签的目标。

无监督学习是在我们不使用标签的情况下进行的——我们只是尝试探索结构和相关性,以理解数据集,从而可能更好地预测标签。随着无结构数据学习和流的学习的出现,这一类问题最近数量有所增加,我将在本书的单独章节中分别介绍。

Iris 数据集

我将通过机器学习中最著名的数据库之一,Iris 数据集,来演示记录和标签的概念(archive.ics.uci.edu/ml/datasets/Iris)。Iris 数据集包含三种 Iris 花类型各 50 条记录,总共 150 行,五个字段。每一行是对以下内容的测量:

  • 萼片长度(厘米)

  • 萼片宽度(厘米)

  • 花瓣长度(厘米)

  • 花瓣宽度(厘米)

最后一个字段是花的类型(setosaversicolorvirginica)。经典问题是要预测标签,在这种情况下,这是一个具有三个可能值的分类属性,这些值是前四个属性的功能:

Iris 数据集

一种选择是在四维空间中绘制一个平面,该平面可以分隔所有四个标签。不幸的是,正如人们可以发现的,虽然其中一个类别可以清楚地分离,但剩下的两个类别则不行,如下面的多维散点图所示(我们使用了 Data Desk 软件创建它):

Iris 数据集

图 04-1. Iris 数据集的三维图。Iris setosa 记录,用交叉表示,可以根据花瓣长度和宽度与其他两种类型分开。

颜色和形状是根据以下表格分配的:

标签 颜色 形状
Iris setosa 蓝色 x
Iris versicolor 绿色 竖条
Iris virginica 紫色 水平条形

Iris setosa是可分离的,因为它与其他两种类型相比,花瓣长度和宽度非常短。

让我们看看如何使用 MLlib 找到那个分隔多维平面的方法。

标记点

之前使用的标记数据集在机器学习(ML)中占有非常重要的位置——我们将在本章后面讨论无监督学习,在那里我们不需要标签,因此 MLlib 有一个特殊的数据类型来表示带有org.apache.spark.mllib.regression.LabeledPoint标签的记录(请参阅spark.apache.org/docs/latest/mllib-data-types.html#labeled-point)。要从文本文件中读取 Iris 数据集,我们需要将原始 UCI 仓库文件转换为所谓的 LIBSVM 文本格式。虽然有很多从 CSV 到 LIBSVM 格式的转换器,但我希望使用一个简单的 AWK 脚本来完成这项工作:

awk -F, '/setosa/ {print "0 1:"$1" 2:"$2" 3:"$3" 4:"$4;}; /versicolor/ {print "1 1:"$1" 2:"$2" 3:"$3" 4:"$4;}; /virginica/ {print "1 1:"$1" 2:"$2" 3:"$3" 4:"$4;};' iris.csv > iris-libsvm.txt

注意

为什么我们需要 LIBSVM 格式?

LIBSVM 是许多库使用的格式。首先,LIBSVM 只接受连续属性。虽然现实世界中的许多数据集包含离散或分类属性,但出于效率原因,它们在内部始终转换为数值表示,即使结果数值属性的 L1 或 L2 度量在无序的离散值上没有太多意义。其次,LIBSVM 格式允许高效地表示稀疏数据。虽然 Iris 数据集不是稀疏的,但几乎所有现代大数据源都是稀疏的,该格式通过仅存储提供的值来实现高效存储。许多现代大数据键值和传统关系型数据库管理系统实际上出于效率原因也这样做。

对于缺失值,代码可能更复杂,但我们知道 Iris 数据集不是稀疏的——否则我们会用一堆 if 语句来补充我们的代码。现在,我们将最后两个标签映射到 1。

SVMWithSGD

现在,让我们运行 MLlib 中的线性支持向量机SVM)SVMWithSGD 代码:

$ bin/spark-shell 
Welcome to
 ____              __
 / __/__  ___ _____/ /__
 _\ \/ _ \/ _ `/ __/  '_/
 /___/ .__/\_,_/_/ /_/\_\   version 1.6.1
 /_/

Using Scala version 2.10.5 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_40)
Type in expressions to have them evaluated.
Type :help for more information.
Spark context available as sc.
SQL context available as sqlContext.

scala> import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD}
import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD}
scala> import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
scala> import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.util.MLUtils
scala> val data = MLUtils.loadLibSVMFile(sc, "iris-libsvm.txt")
data: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint] = MapPartitionsRDD[6] at map at MLUtils.scala:112
scala> val splits = data.randomSplit(Array(0.6, 0.4), seed = 123L)
splits: Array[org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint]] = Array(MapPartitionsRDD[7] at randomSplit at <console>:26, MapPartitionsRDD[8] at randomSplit at <console>:26)
scala> val training = splits(0).cache()
training: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint] = MapPartitionsRDD[7] at randomSplit at <console>:26
scala> val test = splits(1)
test: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint] = MapPartitionsRDD[8] at randomSplit at <console>:26
scala> val numIterations = 100
numIterations: Int = 100
scala> val model = SVMWithSGD.train(training, numIterations)
model: org.apache.spark.mllib.classification.SVMModel = org.apache.spark.mllib.classification.SVMModel: intercept = 0.0, numFeatures = 4, numClasses = 2, threshold = 0.0
scala> model.clearThreshold()
res0: model.type = org.apache.spark.mllib.classification.SVMModel: intercept = 0.0, numFeatures = 4, numClasses = 2, threshold = None
scala> val scoreAndLabels = test.map { point =>
 |   val score = model.predict(point.features)
 |   (score, point.label)
 | }
scoreAndLabels: org.apache.spark.rdd.RDD[(Double, Double)] = MapPartitionsRDD[212] at map at <console>:36
scala> val metrics = new BinaryClassificationMetrics(scoreAndLabels)
metrics: org.apache.spark.mllib.evaluation.BinaryClassificationMetrics = org.apache.spark.mllib.evaluation.BinaryClassificationMetrics@692e4a35
scala> val auROC = metrics.areaUnderROC()
auROC: Double = 1.0

scala> println("Area under ROC = " + auROC)
Area under ROC = 1.0
scala> model.save(sc, "model")
SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.

因此,你只需运行机器学习工具箱中最复杂的算法之一:SVM。结果是区分Iris setosa花与其他两种类型的分离平面。在这种情况下,模型正是最佳分离标签的截距和平面系数:

scala> model.intercept
res5: Double = 0.0

scala> model.weights
res6: org.apache.spark.mllib.linalg.Vector = [-0.2469448809675877,-1.0692729424287566,1.7500423423258127,0.8105712661836376]

如果深入了解,模型存储在一个parquet文件中,可以使用parquet-tool进行转储:

$ parquet-tools dump model/data/part-r-00000-7a86b825-569d-4c80-8796-8ee6972fd3b1.gz.parquet
…
DOUBLE weights.values.array 
----------------------------------------------------------------------------------------------------------------------------------------------
*** row group 1 of 1, values 1 to 4 *** 
value 1: R:0 D:3 V:-0.2469448809675877
value 2: R:1 D:3 V:-1.0692729424287566
value 3: R:1 D:3 V:1.7500423423258127
value 4: R:1 D:3 V:0.8105712661836376

DOUBLE intercept 
----------------------------------------------------------------------------------------------------------------------------------------------
*** row group 1 of 1, values 1 to 1 *** 
value 1: R:0 D:1 V:0.0
…

受试者工作特征ROC)是评估分类器能否根据其数值标签正确排序记录的常用指标。我们将在第九章,Scala 中的 NLP中更详细地考虑精确度指标。

小贴士

ROC 是什么?

ROC 最初出现在信号处理领域,首次应用是测量模拟雷达的准确性。准确性的常用指标是 ROC 曲线下的面积,简而言之,是随机选择两个点按其标签正确排序的概率(0标签应始终具有比1标签低的排名)。AUROC 具有许多吸引人的特性:

  • 该值,至少从理论上讲,不依赖于过采样率,即我们看到0标签而不是1标签的比率。

  • 该值不依赖于样本大小,排除了由于样本量有限而产生的预期方差。

  • 在最终得分中添加一个常数不会改变 ROC,因此截距可以始终设置为0。计算 ROC 需要对生成的得分进行排序。

当然,分离剩余的两个标签是一个更难的问题,因为将Iris versicolorIris virginica分开的平面不存在:AUROC 分数将小于1.0。然而,SVM 方法将找到最佳区分后两个类别的平面。

逻辑回归

逻辑回归是最古老的分类方法之一。逻辑回归的结果也是一组权重,这些权重定义了超平面,但损失函数是逻辑损失而不是 L2

逻辑回归

当标签是二元时(如上式中的 y = +/- 1),对数函数是一个常见的选择:

$ bin/spark-shell 
Welcome to
 ____              __
 / __/__  ___ _____/ /__
 _\ \/ _ \/ _ `/ __/  '_/
 /___/ .__/\_,_/_/ /_/\_\   version 1.6.1
 /_/

Using Scala version 2.10.5 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_40)
Type in expressions to have them evaluated.
Type :help for more information.
Spark context available as sc.
SQL context available as sqlContext.

scala> import org.apache.spark.SparkContext
import org.apache.spark.SparkContext
scala> import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, LogisticRegressionModel}
import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, LogisticRegressionModel}
scala> import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.evaluation.MulticlassMetrics
scala> import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.regression.LabeledPoint
scala> import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.Vectors
scala> import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.util.MLUtils
scala> val data = MLUtils.loadLibSVMFile(sc, "iris-libsvm-3.txt")
data: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint] = MapPartitionsRDD[6] at map at MLUtils.scala:112
scala> val splits = data.randomSplit(Array(0.6, 0.4))
splits: Array[org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint]] = Array(MapPartitionsRDD[7] at randomSplit at <console>:29, MapPartitionsRDD[8] at randomSplit at <console>:29)
scala> val training = splits(0).cache()
training: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint] = MapPartitionsRDD[7] at randomSplit at <console>:29
scala> val test = splits(1)
test: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint] = MapPartitionsRDD[8] at randomSplit at <console>:29
scala> val model = new LogisticRegressionWithLBFGS().setNumClasses(3).run(training)
model: org.apache.spark.mllib.classification.LogisticRegressionModel = org.apache.spark.mllib.classification.LogisticRegressionModel: intercept = 0.0, numFeatures = 8, numClasses = 3, threshold = 0.5
scala> val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
 |   val prediction = model.predict(features)
 |   (prediction, label)
 | }
predictionAndLabels: org.apache.spark.rdd.RDD[(Double, Double)] = MapPartitionsRDD[67] at map at <console>:37
scala> val metrics = new MulticlassMetrics(predictionAndLabels)
metrics: org.apache.spark.mllib.evaluation.MulticlassMetrics = org.apache.spark.mllib.evaluation.MulticlassMetrics@6d5254f3
scala> val precision = metrics.precision
precision: Double = 0.9516129032258065
scala> println("Precision = " + precision)
Precision = 0.9516129032258065
scala> model.intercept
res5: Double = 0.0
scala> model.weights
res7: org.apache.spark.mllib.linalg.Vector = [10.644978886788556,-26.850171485157578,3.852594349297618,8.74629386938248,4.288703063075211,-31.029289381858273,9.790312529377474,22.058196856491996]

在这种情况下,标签可以是范围 [0, k) 中的任何整数,其中 k 是类的总数(正确的类别将通过针对基准类(在这种情况下,是带有 0 标签的类别)构建多个二元逻辑回归模型来确定)(《统计学习的要素》,作者:Trevor HastieRobert TibshiraniJerome FriedmanSpringer Series in Statistics)。

准确性指标是精确度,即正确预测的记录百分比(在我们的案例中为 95%)。

决策树

前两种方法描述了线性模型。不幸的是,线性方法并不总是适用于属性之间的复杂交互。假设标签看起来像这样的独热编码:如果 X ? Y 则为 0,如果 X = Y 则为 1

X Y 标签
1 0 0
0 1 0
1 1 1
0 0 1

XY 空间中,没有超平面可以区分这两个标签。在这种情况下,递归分割解决方案,其中每个级别的分割仅基于一个变量或其线性组合,可能会稍微好一些。决策树也已知与稀疏和交互丰富的数据集配合得很好:

$ bin/spark-shell 
Welcome to
 ____              __
 / __/__  ___ _____/ /__
 _\ \/ _ \/ _ `/ __/  '_/
 /___/ .__/\_,_/_/ /_/\_\   version 1.6.1
 /_/

Using Scala version 2.10.5 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_40)
Type in expressions to have them evaluated.
Type :help for more information.
Spark context available as sc.
SQL context available as sqlContext.

scala> import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.DecisionTree
scala> import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.tree.model.DecisionTreeModel
scala> import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.util.MLUtils
scala> import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Strategy
scala> import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.configuration.Algo.Classification
scala> import org.apache.spark.mllib.tree.impurity.{Entropy, Gini}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini}
scala> val data = MLUtils.loadLibSVMFile(sc, "iris-libsvm-3.txt")
data: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint] = MapPartitionsRDD[6] at map at MLUtils.scala:112

scala> val splits = data.randomSplit(Array(0.7, 0.3), 11L)
splits: Array[org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint]] = Array(MapPartitionsRDD[7] at randomSplit at <console>:30, MapPartitionsRDD[8] at randomSplit at <console>:30)
scala> val (trainingData, testData) = (splits(0), splits(1))
trainingData: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint] = MapPartitionsRDD[7] at randomSplit at <console>:30
testData: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint] = MapPartitionsRDD[8] at randomSplit at <console>:30
scala> val strategy = new Strategy(Classification, Gini, 10, 3, 10)
strategy: org.apache.spark.mllib.tree.configuration.Strategy = org.apache.spark.mllib.tree.configuration.Strategy@4110e631
scala> val dt = new DecisionTree(strategy)
dt: org.apache.spark.mllib.tree.DecisionTree = org.apache.spark.mllib.tree.DecisionTree@33d89052
scala> val model = dt.run(trainingData)
model: org.apache.spark.mllib.tree.model.DecisionTreeModel = DecisionTreeModel classifier of depth 6 with 21 nodes
scala> val labelAndPreds = testData.map { point =>
 |   val prediction = model.predict(point.features)
 |   (point.label, prediction)
 | }
labelAndPreds: org.apache.spark.rdd.RDD[(Double, Double)] = MapPartitionsRDD[32] at map at <console>:36
scala> val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()
testErr: Double = 0.02631578947368421
scala> println("Test Error = " + testErr)
Test Error = 0.02631578947368421

scala> println("Learned classification tree model:\n" + model.toDebugString)
Learned classification tree model:
DecisionTreeModel classifier of depth 6 with 21 nodes
 If (feature 3 <= 0.4)
 Predict: 0.0
 Else (feature 3 > 0.4)
 If (feature 3 <= 1.7)
 If (feature 2 <= 4.9)
 If (feature 0 <= 5.3)
 If (feature 1 <= 2.8)
 If (feature 2 <= 3.9)
 Predict: 1.0
 Else (feature 2 > 3.9)
 Predict: 2.0
 Else (feature 1 > 2.8)
 Predict: 0.0
 Else (feature 0 > 5.3)
 Predict: 1.0
 Else (feature 2 > 4.9)
 If (feature 0 <= 6.0)
 If (feature 1 <= 2.4)
 Predict: 2.0
 Else (feature 1 > 2.4)
 Predict: 1.0
 Else (feature 0 > 6.0)
 Predict: 2.0
 Else (feature 3 > 1.7)
 If (feature 2 <= 4.9)
 If (feature 1 <= 3.0)
 Predict: 2.0
 Else (feature 1 > 3.0)
 Predict: 1.0
 Else (feature 2 > 4.9)
 Predict: 2.0
scala> model.save(sc, "dt-model")
SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.

如您所见,在保留 30%样本的情况下,错误(误预测)率仅为 2.6%。150 个样本中的 30%只有 45 条记录,这意味着我们只从整个测试集中遗漏了 1 条记录。当然,结果可能会随着不同的种子而改变,并且我们需要更严格的交叉验证技术来证明模型的准确性,但这已经足够对模型性能进行粗略估计。

决策树在回归案例中进行了泛化,即当标签在本质上连续时。在这种情况下,分割标准是最小化加权方差,而不是分类情况中的熵增益或基尼指数。我将在第五章中更多地讨论回归和分类之间的差异。

有许多参数可以调整以提高性能:

参数 描述 推荐值
maxDepth 这是树的最大深度。深度树成本较高,通常更容易过拟合。浅层树更高效,更适合像 AdaBoost 这样的 bagging/boosting 算法。 这取决于原始数据集的大小。值得实验并绘制结果的树准确性与参数之间的关系图,以找出最佳值。
minInstancesPerNode 这也限制了树的大小:一旦实例数量低于此阈值,就不会再进行进一步分割。 该值通常为 10-100,具体取决于原始数据集的复杂性和潜在标签的数量。
maxBins 这仅用于连续属性:分割原始范围的箱数。 大量的箱会增加计算和通信成本。也可以考虑根据领域知识对属性进行预离散化的选项。
minInfoGain 这是分割节点所需的信息增益(熵)、不纯度(基尼)或方差(回归)增益。 默认值为0,但你可以增加默认值以限制树的大小并降低过拟合的风险。
maxMemoryInMB 这是用于收集足够统计信息的内存量。 默认值保守地选择为 256 MB,以便决策算法在大多数场景下都能工作。增加maxMemoryInMB可以通过允许对数据进行更少的遍历来加快训练速度(如果内存可用)。然而,随着maxMemoryInMB的增加,每次迭代的通信量可能成比例增加,这可能导致收益递减。
subsamplingRate 这是用于学习决策树的训练数据的一部分。 此参数对于训练树集合(使用RandomForestGradientBoostedTrees)最为相关,其中对原始数据进行子采样可能很有用。对于训练单个决策树,此参数不太有用,因为训练实例的数量通常不是主要限制因素。
useNodeIdCache 如果设置为 true,算法将避免在每次迭代中将当前模型(树或树集合)传递给执行器。 这对于深度树(加快工作节点的计算速度)和大型随机森林(减少每次迭代的通信量)很有用。
checkpointDir: 这是用于检查点化节点 ID 缓存 RDD 的目录。 这是一个优化,用于将中间结果保存下来,以避免节点故障时的重新计算。在大集群或不稳定的节点上设置。
checkpointInterval 这是检查点化节点 ID 缓存 RDD 的频率。 设置得太低会导致写入 HDFS 的额外开销,设置得太高则可能在执行器失败且需要重新计算 RDD 时引起问题。

Bagging 和 boosting – 集成学习方法

由于股票组合的特性优于单个股票,可以将模型结合起来产生更好的分类器。通常,这些方法与决策树作为训练技术结合得很好,因为训练技术可以被修改以产生具有较大变异的模型。一种方法是在原始数据的随机子集或属性的随机子集上训练模型,这被称为随机森林。另一种方法是通过生成一系列模型,将误分类实例重新加权,以便在后续迭代中获得更大的权重。已经证明这种方法与模型参数空间中的梯度下降方法有关。虽然这些是有效且有趣的技术,但它们通常需要更多的模型存储空间,并且与裸决策树模型相比,可解释性较差。对于 Spark,集成模型目前正处于开发中——主要问题为 SPARK-3703 (issues.apache.org/jira/browse/SPARK-3703)。

无监督学习

如果我们在 Iris 数据集中去掉标签,如果某些算法能够恢复原始分组,即使没有确切的标签名称——setosaversicolorvirginica——那将很理想。无监督学习在压缩和编码、客户关系管理(CRM)、推荐引擎和安全领域有多个应用,可以在不实际拥有确切标签的情况下揭示内部结构。标签有时可以根据属性值分布的奇异性给出。例如,Iris setosa 可以描述为 小叶花

虽然可以通过忽略标签将监督学习问题视为无监督问题,但反之亦然。可以将聚类算法视为密度估计问题,通过将所有向量分配标签 1 并生成带有标签 0 的随机向量(参见 Trevor HastieRobert TibshiraniJerome Friedman《统计学习基础》Springer 统计系列)。这两种方法之间的区别是正式的,对于非结构化和嵌套数据来说甚至更加模糊。通常,在标记数据集上运行无监督算法可以更好地理解依赖关系,从而更好地选择和表现监督算法。

聚类和无监督学习中最受欢迎的算法之一是 k-means(以及其变体,k-median 和 k-center,将在后面描述):

$ bin/spark-shell
Welcome to
 ____              __
 / __/__  ___ _____/ /__
 _\ \/ _ \/ _ `/ __/  '_/
 /___/ .__/\_,_/_/ /_/\_\   version 1.6.1
 /_/
Using Scala version 2.10.5 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_40)
Type in expressions to have them evaluated.
Type :help for more information.
Spark context available as sc.
SQL context available as sqlContext.

scala> import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
scala> import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.Vectors
scala> val iris = sc.textFile("iris.txt")
iris: org.apache.spark.rdd.RDD[String] = MapPartitionsRDD[4] at textFile at <console>:23

scala> val vectors = data.map(s => Vectors.dense(s.split('\t').map(_.toDouble))).cache()
vectors: org.apache.spark.rdd.RDD[org.apache.spark.mllib.linalg.Vector] = MapPartitionsRDD[5] at map at <console>:25

scala> val numClusters = 3
numClusters: Int = 3
scala> val numIterations = 20
numIterations: Int = 20
scala> val clusters = KMeans.train(vectors, numClusters, numIterations)
clusters: org.apache.spark.mllib.clustering.KMeansModel = org.apache.spark.mllib.clustering.KMeansModel@5dc9cb99
scala> val centers = clusters.clusterCenters
centers: Array[org.apache.spark.mllib.linalg.Vector] = Array([5.005999999999999,3.4180000000000006,1.4640000000000002,0.2439999999999999], [6.8538461538461535,3.076923076923076,5.715384615384614,2.0538461538461537], [5.883606557377049,2.740983606557377,4.388524590163936,1.4344262295081966])
scala> val SSE = clusters.computeCost(vectors)
WSSSE: Double = 78.94506582597859
scala> vectors.collect.map(x => clusters.predict(x))
res18: Array[Int] = Array(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 2, 1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 2)
scala> println("Sum of Squared Errors = " + SSE)
Sum of Squared Errors = 78.94506582597859
scala> clusters.save(sc, "model")
SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.

可以看到第一个中心,即索引为 0 的中心,花瓣长度和宽度为 1.4640.244,这比其他两个——5.7152.0544.3891.434——要短得多。预测完全符合第一个聚类,对应于 Iris setosa,但对于其他两个有一些误判。

如果我们想要实现一个期望的分类结果,簇质量的度量可能取决于(期望的)标签,但鉴于算法没有关于标签的信息,一个更常见的度量是每个簇中从质心到点的距离之和。以下是一个关于WSSSE的图表,它取决于簇的数量:

scala> 1.to(10).foreach(i => println("i: " + i + " SSE: " + KMeans.train(vectors, i, numIterations).computeCost(vectors)))
i: 1 WSSSE: 680.8244
i: 2 WSSSE: 152.3687064773393
i: 3 WSSSE: 78.94506582597859
i: 4 WSSSE: 57.47327326549501
i: 5 WSSSE: 46.53558205128235
i: 6 WSSSE: 38.9647878510374
i: 7 WSSSE: 34.311167589868646
i: 8 WSSSE: 32.607859500805034
i: 9 WSSSE: 28.231729411088438
i: 10 WSSSE: 29.435054384424078

如预期的那样,随着配置的簇越来越多,平均距离在减小。确定最佳簇数量的一个常见方法是添加一个惩罚函数。一个常见的惩罚是簇数量的对数,因为我们期望一个凸函数。对数前面的系数是多少?如果每个向量都与自己的簇相关联,所有距离的总和将为零,因此如果我们想要一个在可能值集合的两端都达到大约相同值的度量,1150,系数应该是680.8244/log(150)

scala> for (i <- 1.to(10)) println(i + " -> " + ((KMeans.train(vectors, i, numIterations).computeCost(vectors)) + 680 * scala.math.log(i) / scala.math.log(150)))
1 -> 680.8244
2 -> 246.436635016484
3 -> 228.03498068120865
4 -> 245.48126639400738
5 -> 264.9805962616268
6 -> 285.48857890531764
7 -> 301.56808340425164
8 -> 315.321639004243
9 -> 326.47262191671723
10 -> 344.87130979355675

这是带有惩罚的平方距离之和的图形表示:

无监督学习

图 04-2. 随簇数量变化的聚类质量度量

除了 k-means 聚类外,MLlib 还实现了以下功能:

  • 高斯混合

  • 幂迭代聚类PIC

  • 潜在狄利克雷分配LDA

  • 流式 k-means

高斯混合是另一种经典机制,尤其以频谱分析而闻名。当属性是连续的,并且我们知道它们可能来自一组高斯分布时,高斯混合分解是合适的。例如,当对应于簇的点的潜在组可能具有所有属性的均值,比如Var1Var2时,点可能围绕两个相交的超平面中心,如下面的图所示:

无监督学习

图 04-3. 无法用 k-means 聚类正确描述的两个高斯混合

这使得 k-means 算法无效,因为它无法区分这两者(当然,一个简单的非线性变换,如到其中一个超平面的距离,可以解决这个问题,但这就是领域知识和数据科学家专业知识派上用场的地方)。

PIC 使用图中的聚类顶点,并提供了作为边属性的成对相似度度量。它通过幂迭代计算图的归一化亲和矩阵的伪特征向量,并使用它来聚类顶点。MLlib 包含了一个使用 GraphX 作为其后端的 PIC 实现。它接受一个包含 (srcId, dstId, similarity) 元组的 RDD,并输出一个具有聚类分配的模型。相似度必须是非负的。PIC 假设相似度度量是对称的。无论顺序如何,一对 (srcId, dstId) 在输入数据中最多只能出现一次。如果一对在输入中缺失,它们的相似度被视为零。

LDA 可以用于基于关键词频率对文档进行聚类。LDA 不是使用传统的距离来估计聚类,而是使用基于文本文档生成统计模型的函数。

最后,流式 k-means 是 k-means 算法的一种改进,其中聚类可以根据新的数据批次进行调整。对于每个数据批次,我们将所有点分配到最近的聚类,根据分配计算新的聚类中心,然后使用以下方程更新每个聚类的参数:

无监督学习无监督学习

这里,c tc' t 是旧模型和为新批次计算的中心的坐标,而 n tn' t 是旧模型和新批次中的向量数量。通过改变 a 参数,我们可以控制旧运行的信息对聚类的影响程度——0 表示新的聚类中心完全基于新批次中的点,而 1 表示我们考虑到目前为止看到的所有点。

k-means 聚类算法有许多改进版本。例如,k-medians 计算聚类中心是属性值的中位数,而不是平均值,这对于某些分布和与 L1 目标距离度量(差值的绝对值)相比 L2(平方和)来说效果更好。K-medians 的中心不一定是数据集中具体的一个点。K-medoids 是同一家族中的另一个算法,其中结果聚类中心必须是输入集中的一个实际实例,我们实际上不需要全局排序,只需要点之间的成对距离。关于如何选择原始种子聚类中心和收敛到最佳聚类数量(除了我展示的简单对数技巧之外)的技术有许多变体。

另一大类聚类算法是层次聚类。层次聚类可以是自顶向下进行的——类似于决策树算法——或者自底向上;我们首先找到最近的邻居,将它们配对,然后继续配对过程,直到所有记录都被合并。层次聚类的优点是它可以被设计成确定性的,并且相对较快,尽管 k-means 的一次迭代的成本可能更好。然而,如前所述,无监督问题实际上可以被转换为一个密度估计的监督问题,可以使用所有可用的监督学习技术。所以,享受理解数据吧!

问题维度

属性空间越大或维度数越多,通常预测给定属性值组合的标签就越困难。这主要是因为属性空间中可能的不同属性组合的总数随着属性空间维度的增加而指数增长——至少在离散变量的情况下(对于连续变量,情况更复杂,取决于使用的度量),并且泛化的难度也在增加。

问题的有效维度可能与输入空间的维度不同。例如,如果标签仅依赖于(连续的)输入属性的线性组合,则该问题称为线性可分,其内部维度为一——尽管我们仍然需要像逻辑回归那样找到这个线性组合的系数。

这种想法有时也被称为问题的、模型的或算法的Vapnik–ChervonenkisVC)维度——模型的表达能力取决于它能够解决或分解的依赖关系的复杂性。更复杂的问题需要具有更高 VC 维度的算法和更大的训练集。然而,在简单问题上使用具有更高 VC 维度的算法可能会导致过拟合,并且对新数据的泛化更差。

如果输入属性的单元是可比的,比如说它们都是米或时间的单位,则可以使用 PCA,或者更一般地说,使用核方法,来降低输入空间的维度:

$ bin/spark-shell 
Welcome to
 ____              __
 / __/__  ___ _____/ /__
 _\ \/ _ \/ _ `/ __/  '_/
 /___/ .__/\_,_/_/ /_/\_\   version 1.6.1
 /_/

Using Scala version 2.10.5 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_40)
Type in expressions to have them evaluated.
Type :help for more information.
Spark context available as sc.
SQL context available as sqlContext.

scala> import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.regression.LabeledPoint
scala> import org.apache.spark.mllib.feature.PCA
import org.apache.spark.mllib.feature.PCA
scala> import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.util.MLUtils
scala> val pca = new PCA(2).fit(data.map(_.features))
pca: org.apache.spark.mllib.feature.PCAModel = org.apache.spark.mllib.feature.PCAModel@4eee0b1a

scala> val reduced = data.map(p => p.copy(features = pca.transform(p.features)))
reduced: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint] = MapPartitionsRDD[311] at map at <console>:39
scala> reduced.collect().take(10)
res4: Array[org.apache.spark.mllib.regression.LabeledPoint] = Array((0.0,[-2.827135972679021,-5.641331045573367]), (0.0,[-2.7959524821488393,-5.145166883252959]), (0.0,[-2.621523558165053,-5.177378121203953]), (0.0,[-2.764905900474235,-5.0035994150569865]), (0.0,[-2.7827501159516546,-5.6486482943774305]), (0.0,[-3.231445736773371,-6.062506444034109]), (0.0,[-2.6904524156023393,-5.232619219784292]), (0.0,[-2.8848611044591506,-5.485129079769268]), (0.0,[-2.6233845324473357,-4.743925704477387]), (0.0,[-2.8374984110638493,-5.208032027056245]))

scala> import scala.language.postfixOps
import scala.language.postfixOps

scala> pca pc
res24: org.apache.spark.mllib.linalg.DenseMatrix = 
-0.36158967738145065  -0.6565398832858496 
0.08226888989221656   -0.7297123713264776 
-0.856572105290527    0.17576740342866465 
-0.35884392624821626  0.07470647013502865

scala> import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD}
import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD}
scala> import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
scala> val splits = reduced.randomSplit(Array(0.6, 0.4), seed = 1L)
splits: Array[org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint]] = Array(MapPartitionsRDD[312] at randomSplit at <console>:44, MapPartitionsRDD[313] at randomSplit at <console>:44)
scala> val training = splits(0).cache()
training: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint] = MapPartitionsRDD[312] at randomSplit at <console>:44
scala> val test = splits(1)
test: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint] = MapPartitionsRDD[313] at randomSplit at <console>:44
scala> val numIterations = 100
numIterations: Int = 100
scala> val model = SVMWithSGD.train(training, numIterations)
model: org.apache.spark.mllib.classification.SVMModel = org.apache.spark.mllib.classification.SVMModel: intercept = 0.0, numFeatures = 2, numClasses = 2, threshold = 0.0
scala> model.clearThreshold()
res30: model.type = org.apache.spark.mllib.classification.SVMModel: intercept = 0.0, numFeatures = 2, numClasses = 2, threshold = None
scala> val scoreAndLabels = test.map { point =>
 |   val score = model.predict(point.features)
 |   (score, point.label)
 | }
scoreAndLabels: org.apache.spark.rdd.RDD[(Double, Double)] = MapPartitionsRDD[517] at map at <console>:54
scala> val metrics = new BinaryClassificationMetrics(scoreAndLabels)
metrics: org.apache.spark.mllib.evaluation.BinaryClassificationMetrics = org.apache.spark.mllib.evaluation.BinaryClassificationMetrics@27f49b8c

scala> val auROC = metrics.areaUnderROC()
auROC: Double = 1.0
scala> println("Area under ROC = " + auROC)
Area under ROC = 1.0

在这里,我们将原始的四维问题降低到二维。像平均一样,计算输入属性的线性组合并仅选择描述大部分方差的那些属性有助于减少噪声。

摘要

在本章中,我们探讨了监督学习和无监督学习,以及如何在 Spark/Scala 中运行它们的几个示例。在 UCI Iris 数据集的示例中,我们考虑了 SVM、逻辑回归、决策树和 k-means。这绝对不是一份完整的指南,而且现在存在或正在制作许多其他库,但我敢打赌,你只需使用这些工具就能解决 99%的即时数据分析问题。

这将为你提供一个快速捷径,了解如何开始使用新的数据集进行高效工作。查看数据集的方法有很多,但在我们深入更高级的主题之前,让我们在下一章讨论回归和分类,即如何预测连续和离散标签。

第五章。回归和分类

在上一章中,我们熟悉了监督学习和无监督学习。机器学习方法的另一种标准分类是基于标签来自连续空间还是离散空间。即使离散标签是有序的,也存在显著差异,尤其是在评估拟合优度指标方面。

在本章中,我们将涵盖以下主题:

  • 学习回归这个词的起源

  • 学习在连续和离散空间中评估拟合优度指标的方法

  • 讨论如何用 Scala 编写简单的线性回归和逻辑回归代码

  • 学习关于正则化、多类预测和异方差性等高级概念

  • 讨论 MLlib 应用回归树分析的示例

  • 学习评估分类模型的不同方法

回归代表什么?

虽然分类这个词直观上很清楚,但回归这个词似乎并不暗示连续标签的预测器。根据韦伯斯特词典,回归是:

“回到以前或较少发展的状态。”

它也提到了统计学的一个特殊定义,即一个变量(例如,输出)的平均值与相应变量的值(例如,时间和成本)之间的关系度量,这在当今实际上是正确的。然而,从历史上看,回归系数原本是用来表示某些特征(如体重和大小)从一代传到另一代的遗传性,暗示着有计划的基因选择,包括人类(www.amstat.org/publications/jse/v9n3/stanton.html)。更具体地说,在 1875 年,查尔斯·达尔文的表亲、一位杰出的 19 世纪科学家高尔顿,也因为推广优生学而受到广泛批评,他向七个朋友分发了甜豌豆种子。每个朋友都收到了重量均匀的种子,但七个包裹之间有显著的差异。高尔顿的朋友们应该收获下一代种子并将它们寄回给他。高尔顿随后分析了每个群体中种子的统计特性,其中一项分析就是绘制回归线,这条线似乎总是具有小于 1 的斜率——具体引用的数字是 0.33(高尔顿,F. (1894),《自然遗传》(第 5 版),纽约:麦克米伦公司),与没有相关性且没有遗传的情况下的0相反;或者与父母特征在后代中完全复制的1相反。我们将在有噪声数据的情况下讨论为什么回归线的系数应该始终小于1,即使相关性是完美的。然而,除了讨论和细节之外,回归这个术语的起源部分是由于植物和人类的计划育种。当然,高尔顿当时没有访问 PCA、Scala 或其他任何计算设备,这些设备可能会更多地阐明相关性和回归线斜率之间的差异。

连续空间和度量

由于本章的大部分内容将涉及尝试预测或优化连续变量,让我们首先了解如何在连续空间中测量差异。除非很快有重大发现,我们所处的空间是一个三维欧几里得空间。无论我们是否喜欢,这是我们今天大多数人都比较适应的世界。我们可以用三个连续的数字完全指定我们的位置。位置之间的差异通常通过距离或度量来衡量,度量是一个关于两个参数的函数,它返回一个正实数。自然地,XY 之间的距离,连续空间和度量,应该始终等于或小于 XZ 以及 YZ 之间距离之和:

连续空间和度量

对于任何 XYZ,这也就是三角不等式。度量的另外两个性质是对称性:

连续空间和度量

距离的非负性:

连续空间和度量连续空间和度量

在这里,如果且仅当 X=Y 时,度量是 0连续空间和度量 距离是我们对日常生活中的距离的理解,即每个维度上平方差的平方根。我们物理距离的推广是 p-范数(对于 连续空间和度量 距离,p = 2):

连续空间和度量

在这里,总和是 XY 向量的整体分量。如果 p=1,1-范数是绝对差分的总和,或曼哈顿距离,就像从点 X 到点 Y 的唯一路径是只沿一个分量移动一样。这种距离也常被称为 连续空间和度量 距离:

连续空间和度量

图 05-1. 二维空间中的 连续空间和度量 圆(距离原点 (0, 0) 精确为一单位的点的集合)

这里是二维空间中圆的一个表示:

连续空间和度量

图 05-2. 连续空间和度量 圆在二维空间中(距离原点 (0, 0) 等距离的点的集合),在我们对距离的日常理解中,它实际上看起来像一个圆。

另一个常用的特殊情况是 连续空间和度量,当 连续空间和度量 时,即沿任何组件的最大偏差,如下所示:

连续空间和度量

对于 连续空间和度量 距离的等距圆,请参见 图 05-3

连续空间和度量

图 05-3. 连续空间和度量 圆在二维空间中(距离原点 (0, 0) 等距离的点的集合)。这是一个正方形,因为 连续空间和度量 度量是沿任何组件的最大距离。

当我谈到分类时,我会稍后考虑 Kullback-LeiblerKL) 距离,它衡量两个概率分布之间的差异,但它是一个不对称的距离的例子,因此它不是一个度量。

度量性质使得问题分解更容易。由于三角不等式,可以通过分别优化问题的多个维度分量来替换一个困难的目标优化问题。

线性回归

如 第二章 所述,数据管道和建模,大多数复杂的机器学习问题都可以归结为优化,因为我们的最终目标是优化整个流程,其中机器作为中介或完整解决方案。指标可以是明确的,如错误率,或者更间接的,如 月活跃用户 (MAU),但算法的有效性最终是通过它如何改善我们生活中的某些指标和流程来评判的。有时,目标可能包括多个子目标,或者维护性和稳定性等指标最终也可能被考虑,但本质上,我们需要以某种方式最大化或最小化一个连续指标。

为了流程的严谨性,让我们展示如何将线性回归表述为一个优化问题。经典的线性回归需要优化累积 线性回归 错误率:

线性回归

在这里,线性回归是模型给出的估计值,在线性回归的情况下,如下所示:

线性回归

(第三章 中已列举了其他潜在的 损失函数)。由于 线性回归 指标是 ab 的可微凸函数,可以通过将累积错误率的导数等于 0 来找到极值:

线性回归

在这种情况下,计算导数是直接的,并导致以下方程:

线性回归线性回归

这可以通过以下方式解决:

线性回归线性回归

这里,avg() 表示整体输入记录的平均值。注意,如果 avg(x)=0,则前面的方程简化为以下形式:

线性回归线性回归

因此,我们可以快速使用基本的 Scala 操作符计算线性回归系数(我们可以通过执行 线性回归 使 avg(x) 为零):

akozlov@Alexanders-MacBook-Pro$ scala

Welcome to Scala version 2.11.6 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_40).
Type in expressions to have them evaluated.
Type :help for more information.

scala> import scala.util.Random
import scala.util.Random

scala> val x = -5 to 5
x: scala.collection.immutable.Range.Inclusive = Range(-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5)

scala> val y = x.map(_ * 2 + 4 + Random.nextGaussian)
y: scala.collection.immutable.IndexedSeq[Double] = Vector(-4.317116812989753, -4.4056031270948015, -2.0376543660274713, 0.0184679796245639, 1.8356532746253016, 3.2322795591658644, 6.821999810895798, 7.7977904139852035, 10.288549406814154, 12.424126535332453, 13.611442206874917)

scala> val a = (x, y).zipped.map(_ * _).sum / x.map(x => x * x).sum
a: Double = 1.9498665133868092

scala> val b = y.sum / y.size
b: Double = 4.115448625564203

我之前没有告诉你 Scala 是一种非常简洁的语言吗?我们只用五行代码就完成了线性回归,其中三行只是数据生成语句。

虽然有使用 Scala 编写的用于执行(多元)线性回归的库,例如 Breeze (github.com/scalanlp/breeze),它提供了更广泛的功能,但能够使用纯 Scala 功能来获取一些简单的统计结果是非常好的。

让我们看看加尔顿先生的问题,他发现回归线总是小于一的斜率,这意味着我们应该始终回归到某个预定义的均值。我将生成与之前相同的点,但它们将分布在水平线上,并带有一些预定义的噪声。然后,我将通过在 xy-空间中进行线性旋转变换将线旋转 45 度。直观上,应该很明显,如果 yx 强烈相关且不存在,那么 y 的噪声就只能是 x

[akozlov@Alexanders-MacBook-Pro]$ scala
Welcome to Scala version 2.11.7 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_40).
Type in expressions to have them evaluated.
Type :help for more information.

scala> import scala.util.Random.nextGaussian
import scala.util.Random.nextGaussian

scala> val x0 = Vector.fill(201)(100 * nextGaussian)
x0: scala.collection.immutable.IndexedSeq[Double] = Vector(168.28831870102465, -40.56031270948016, -3.7654366027471324, 1.84679796245639, -16.43467253746984, -76.77204408341358, 82.19998108957988, -20.22095860147962, 28.854940681415442, 42.41265353324536, -38.85577931250823, -17.320873680820082, 64.19368427702135, -8.173507833084892, -198.6064655461397, 40.73700995880357, 32.36849515282444, 0.07758364225363915, -101.74032407199553, 34.789280276495646, 46.29624756866302, 35.54024768650289, 24.7867839701828, -11.931948933554782, 72.12437623460166, 30.51440227306552, -80.20756177356768, 134.2380548346385, 96.14401034937691, -205.48142161773896, -73.48186022765427, 2.7861465340245215, 39.49041527572774, 12.262899592863906, -118.30408039749234, -62.727048950163855, -40.58557796128219, -23.42...
scala> val y0 = Vector.fill(201)(30 * nextGaussian)
y0: scala.collection.immutable.IndexedSeq[Double] = Vector(-51.675658534203876, 20.230770706186128, 32.47396891906855, -29.35028743620815, 26.7392929946199, 49.85681312583139, 24.226102932450917, 31.19021547086266, 26.169544117916704, -4.51435617676279, 5.6334117227063985, -59.641661744341775, -48.83082934374863, 29.655750956280304, 26.000847703123497, -17.43319605936741, 0.8354318740518344, 11.44787080976254, -26.26312164695179, 88.63863939038357, 45.795968719043785, 88.12442528090506, -29.829048945601635, -1.0417034396751037, -27.119245702417494, -14.055969115249258, 6.120344305721601, 6.102779172838027, -6.342516875566529, 0.06774080659895702, 46.364626315486014, -38.473161588561, -43.25262339890197, 19.77322736359687, -33.78364440355726, -29.085765762613683, 22.87698648100551, 30.53...
scala> val x1 = (x0, y0).zipped.map((a,b) => 0.5 * (a + b) )
x1: scala.collection.immutable.IndexedSeq[Double] = Vector(58.30633008341039, -10.164771001647015, 14.354266158160707, -13.75174473687588, 5.152310228575029, -13.457615478791094, 53.213042011015396, 5.484628434691521, 27.51224239966607, 18.949148678241286, -16.611183794900917, -38.48126771258093, 7.681427466636357, 10.741121561597705, -86.3028089215081, 11.651906949718079, 16.601963513438136, 5.7627272260080895, -64.00172285947366, 61.71395983343961, 46.0461081438534, 61.83233648370397, -2.5211324877094174, -6.486826186614943, 22.50256526609208, 8.229216578908131, -37.04360873392304, 70.17041700373827, 44.90074673690519, -102.70684040557, -13.558616956084126, -17.843507527268237, -1.8811040615871129, 16.01806347823039, -76.0438624005248, -45.90640735638877, -8.85429574013834, 3.55536787...
scala> val y1 = (x0, y0).zipped.map((a,b) => 0.5 * (a - b) )
y1: scala.collection.immutable.IndexedSeq[Double] = Vector(109.98198861761426, -30.395541707833143, -18.11970276090784, 15.598542699332269, -21.58698276604487, -63.31442860462248, 28.986939078564482, -25.70558703617114, 1.3426982817493691, 23.463504855004075, -22.244595517607316, 21.160394031760845, 56.51225681038499, -18.9146293946826, -112.3036566246316, 29.08510300908549, 15.7665316393863, -5.68514358375445, -37.73860121252187, -26.924679556943964, 0.2501394248096176, -26.292088797201085, 27.30791645789222, -5.445122746939839, 49.62181096850958, 22.28518569415739, -43.16395303964464, 64.06763783090022, 51.24326361247172, -102.77458121216895, -59.92324327157014, 20.62965406129276, 41.37151933731485, -3.755163885366482, -42.26021799696754, -16.820641593775086, -31.73128222114385, -26.9...
scala> val a = (x1, y1).zipped.map(_ * _).sum / x1.map(x => x * x).sum
a: Double = 0.8119662470457414

斜率仅为 0.81!请注意,如果对 x1y1 数据运行 PCA,第一个主成分将正确地沿着对角线。

为了完整性,我给出了 (x1, y1) 的一个绘图,如下所示:

线性回归

图 05-4. 一个看似完美相关的数据集的回归曲线斜率小于一。这与回归问题优化的度量(y 距离)有关。

我将留给读者去找到斜率小于一的原因,但这与回归问题应该回答的具体问题和它优化的度量有关。

逻辑回归

逻辑回归优化了关于 w 的对数损失函数:

逻辑回归

在这里,y 是二进制的(在这种情况下是正负一)。虽然与之前线性回归的情况不同,误差最小化问题没有封闭形式的解,但逻辑函数是可微的,并允许快速收敛的迭代算法。

梯度如下:

逻辑回归

再次,我们可以快速编写一个 Scala 程序,该程序使用梯度收敛到值,如 逻辑回归(我们仅使用 MLlib LabeledPoint 数据结构是为了读取数据的方便):

$ bin/spark-shell 
Welcome to
 ____              __
 / __/__  ___ _____/ /__
 _\ \/ _ \/ _ `/ __/  '_/
 /___/ .__/\_,_/_/ /_/\_\   version 1.6.1-SNAPSHOT
 /_/

Using Scala version 2.10.5 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_40)
Type in expressions to have them evaluated.
Type :help for more information.
Spark context available as sc.
SQL context available as sqlContext.

scala> import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.linalg.Vector

scala> import org.apache.spark.util._
import org.apache.spark.util._

scala> import org.apache.spark.mllib.util._
import org.apache.spark.mllib.util._

scala> val data = MLUtils.loadLibSVMFile(sc, "data/iris/iris-libsvm.txt")
data: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint] = MapPartitionsRDD[291] at map at MLUtils.scala:112

scala> var w = Vector.random(4)
w: org.apache.spark.util.Vector = (0.9515155226069267, 0.4901713461728122, 0.4308861351586426, 0.8030814804136821)

scala> for (i <- 1.to(10)) println { val gradient = data.map(p => ( - p.label / (1+scala.math.exp(p.label*(Vector(p.features.toDense.values) dot w))) * Vector(p.features.toDense.values) )).reduce(_+_); w -= 0.1 * gradient; w }
(-24.056553839570114, -16.585585503253142, -6.881629923278653, -0.4154730884796032)
(38.56344616042987, 12.134414496746864, 42.178370076721365, 16.344526911520397)
(13.533446160429868, -4.95558550325314, 34.858370076721364, 15.124526911520398)
(-11.496553839570133, -22.045585503253143, 27.538370076721364, 13.9045269115204)
(-4.002010810020908, -18.501520148476196, 32.506256310962314, 15.455945245916512)
(-4.002011353029471, -18.501520429824225, 32.50625615219947, 15.455945209971787)
(-4.002011896036225, -18.501520711171313, 32.50625599343715, 15.455945174027184)
(-4.002012439041171, -18.501520992517463, 32.506255834675365, 15.455945138082699)
(-4.002012982044308, -18.50152127386267, 32.50625567591411, 15.455945102138333)
(-4.002013525045636, -18.501521555206942, 32.506255517153384, 15.455945066194088)

scala> w *= 0.24 / 4
w: org.apache.spark.util.Vector = (-0.24012081150273815, -1.1100912933124165, 1.950375331029203, 0.9273567039716453)

逻辑回归被简化为只有一行 Scala 代码!最后一行是为了归一化权重——只有相对值对于定义分离平面很重要——以便与之前章节中 MLlib 获得的值进行比较。

在实际实现中使用的 随机梯度下降SGD)算法本质上与梯度下降相同,但以下方面进行了优化:

  • 实际梯度是在记录的子样本上计算的,这可能会由于减少了舍入噪声而加快转换速度,并避免局部最小值。

  • 步长——在我们的例子中是固定的 0.1——是迭代的单调递减函数,如 逻辑回归,这也可能导致更好的转换。

  • 它包含了正则化;不是仅仅最小化损失函数,而是最小化损失函数的总和,加上一些惩罚度量,这是一个关于模型复杂度的函数。我将在下一节讨论这个问题。

正则化

正则化最初是为了应对病态问题而开发的,其中问题是不受约束的——给定数据允许有多个解,或者数据和包含过多噪声的解(A.N. TikhonovA.S. LeonovA.G. Yagola. 非线性病态问题Chapman and HallLondonWeinhe)。添加额外的惩罚函数,如果解没有期望的特性,如曲线拟合或频谱分析中的平滑性,通常可以解决问题。

惩罚函数的选择在一定程度上是任意的,但它应该反映对解的期望偏斜。如果惩罚函数是可微分的,它可以被纳入梯度下降过程;岭回归就是一个例子,其中惩罚是权重或系数平方和的正则化度量。

MLlib 目前实现了正则化正则化,以及称为弹性网络的混合形式,如第三章所示,使用 Spark 和 MLlib正则化正则化有效地惩罚了回归权重中非零项的数量,但已知其收敛速度较慢。最小绝对收缩和选择算子LASSO)使用了正则化正则化。

另一种减少受约束问题不确定性的方法是考虑可能来自领域专家的先验信息。这可以通过贝叶斯分析实现,并在后验概率中引入额外的因素——概率规则通常用乘法而不是加法表示。然而,由于目标通常是最小化对数似然,贝叶斯校正通常也可以表示为标准正则化器。

多元回归

同时最小化多个度量是可能的。虽然 Spark 只有少数多元分析工具,但其他更传统且已建立的包包含了多元方差分析MANOVA),它是方差分析ANOVA)方法的一种推广。我将在第七章,处理图算法中介绍 ANOVA 和 MANOVA。

对于实际分析,我们首先需要了解目标变量是否相关,我们可以使用第三章中介绍的 PCA Spark 实现来做到这一点,使用 Spark 和 MLlib。如果因变量高度相关,最大化一个会导致最大化另一个,我们只需最大化第一个主成分(并且可能基于第二个成分构建回归模型来理解驱动差异的因素)。

如果目标不相关,可以为每个目标构建一个单独的模型,以确定驱动它们的变量以及这两个集合是否互斥。在后一种情况下,我们可以构建两个单独的模型来独立预测每个目标。

异方差性

回归方法中的一个基本假设是目标方差与独立(属性)或依赖(目标)变量不相关。一个可能违反此假设的例子是计数数据,它通常由泊松分布描述。对于泊松分布,方差与期望值成正比,高值可以更多地贡献到权重的最终方差。

虽然异方差性可能会或可能不会显著地扭曲结果权重,但一种补偿异方差性的实际方法是进行对数变换,这在泊松分布的情况下可以补偿:

异方差性异方差性

一些其他(参数化)的变换是Box-Cox 变换

异方差性

这里,异方差性是一个参数(对数变换是部分情况,其中异方差性)和 Tuckey 的 lambda 变换(对于介于01之间的属性):

异方差性

这些变换用于补偿泊松二项分布的属性或一系列试验中成功概率的估计,这些试验可能包含混合的n个伯努利分布。

异方差性是逻辑函数最小化在二元预测问题中比带有回归树最小化的线性回归表现更好的主要原因之一。让我们更详细地考虑离散标签。

回归树

我们在上一章中看到了分类树。可以为回归问题构建一个递归的分割和合并结构,其中分割是为了最小化剩余的方差。回归树不如决策树或经典 ANOVA 分析流行;然而,让我们在这里提供一个回归树的例子,作为 MLlib 的一部分:

akozlov@Alexanders-MacBook-Pro$ bin/spark-shell 
Welcome to
 ____              __
 / __/__  ___ _____/ /__
 _\ \/ _ \/ _ `/ __/  '_/
 /___/ .__/\_,_/_/ /_/\_\   version 1.6.1-SNAPSHOT
 /_/

Using Scala version 2.10.5 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_40)
Type in expressions to have them evaluated.
Type :help for more information.
Spark context available as sc.
SQL context available as sqlContext.

scala> import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.DecisionTree

scala> import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.tree.model.DecisionTreeModel

scala> import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.util.MLUtils

scala> // Load and parse the data file.

scala> val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
data: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint] = MapPartitionsRDD[6] at map at MLUtils.scala:112

scala> // Split the data into training and test sets (30% held out for testing)

scala> val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
trainingData: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint] = MapPartitionsRDD[7] at randomSplit at <console>:26
testData: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint] = MapPartitionsRDD[8] at randomSplit at <console>:26

scala> val categoricalFeaturesInfo = Map[Int, Int]()
categoricalFeaturesInfo: scala.collection.immutable.Map[Int,Int] = Map()

scala> val impurity = "variance"
impurity: String = variance

scala> val maxDepth = 5
maxDepth: Int = 5

scala> val maxBins = 32
maxBins: Int = 32

scala> val model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity, maxDepth, maxBins)
model: org.apache.spark.mllib.tree.model.DecisionTreeModel = DecisionTreeModel regressor of depth 2 with 5 nodes

scala> val labelsAndPredictions = testData.map { point =>
 |   val prediction = model.predict(point.features)
 |   (point.label, prediction)
 | }
labelsAndPredictions: org.apache.spark.rdd.RDD[(Double, Double)] = MapPartitionsRDD[20] at map at <console>:36

scala> val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean()
testMSE: Double = 0.07407407407407407

scala> println(s"Test Mean Squared Error = $testMSE")
Test Mean Squared Error = 0.07407407407407407

scala> println("Learned regression tree model:\n" + model.toDebugString)
Learned regression tree model:
DecisionTreeModel regressor of depth 2 with 5 nodes
 If (feature 378 <= 71.0)
 If (feature 100 <= 165.0)
 Predict: 0.0
 Else (feature 100 > 165.0)
 Predict: 1.0
 Else (feature 378 > 71.0)
 Predict: 1.0

每一层的分割都是为了最小化方差,如下所示:

回归树

这等价于最小化标签值与其在每个叶节点内的平均值之间的回归树距离,并求和所有叶节点的总和。

分类度量

如果标签是离散的,预测问题称为分类。通常,每个记录的目标只能取一个值(尽管可能存在多值目标,尤其是在第六章中考虑的文本分类问题,处理非结构化数据)。

如果离散值是有序的,并且排序有意义,例如 BadWorseGood,则可以将离散标签转换为整数或双精度浮点数,问题就简化为回归(我们相信如果你在 BadWorse 之间,你肯定比 Worse 更远离 Good)。

要优化的通用度量是误分类率,如下所示:

分类度量

然而,如果算法可以预测目标可能值的分布,可以使用更通用的度量,如 KL 散度或曼哈顿距离。

KL 散度是当使用概率分布 分类度量 来近似概率分布 分类度量 时信息损失的一个度量:

分类度量

它与决策树归纳中使用的熵增益分割标准密切相关,因为后者是所有叶节点上节点概率分布到叶概率分布的 KL 散度的总和。

多类问题

如果目标可能的结果数量超过两个,通常,我们必须预测目标值的期望概率分布,或者至少是有序值的列表——最好是通过一个排名变量来增强,该变量可以用于额外的分析。

虽然一些算法,如决策树,可以原生地预测多值属性。一种常见的技术是通过选择一个值作为基准,将一个 K 个目标值的预测减少到 (K-1) 个二元分类问题,构建 (K-1) 个二元分类器。通常选择最密集的级别作为基准是一个好主意。

感知器

在机器学习的早期,研究人员试图模仿人脑的功能。20 世纪初,人们认为人脑完全由称为神经元的细胞组成——具有长突起的细胞称为轴突,能够通过电脉冲传递信号。AI 研究人员试图通过感知器来复制神经元的功能,感知器是一个基于其输入值的线性加权和的激活函数:

感知器

这是对人脑中过程的一种非常简单的表示——自那时起,生物学家已经发现了除了电脉冲之外的其他信息传递方式,例如化学脉冲。此外,他们已经发现了 300 多种可能被归类为神经元的细胞类型(neurolex.org/wiki/Category:Neuron)。此外,神经元放电的过程比仅仅电压的线性传输要复杂得多,因为它还涉及到复杂的时间模式。尽管如此,这个概念证明是非常有成效的,为神经网络或层间相互连接的感知集开发了许多算法和技术。具体来说,可以证明,通过某些修改,在放电方程中将步函数替换为逻辑函数,神经网络可以以任何所需的精度逼近任意可微函数。

MLlib 实现了多层感知器分类器MLCP)作为一个org.apache.spark.ml.classification.MultilayerPerceptronClassifier类:

$ bin/spark-shell 
Welcome to
 ____              __
 / __/__  ___ _____/ /__
 _\ \/ _ \/ _ `/ __/  '_/
 /___/ .__/\_,_/_/ /_/\_\   version 1.6.1-SNAPSHOT
 /_/

Using Scala version 2.10.5 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_40)
Type in expressions to have them evaluated.
Type :help for more information.
Spark context available as sc.
SQL context available as sqlContext.

scala> import org.apache.spark.ml.classification.MultilayerPerceptronClassifier
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier

scala> import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

scala> import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.util.MLUtils

scala> 

scala> val data = MLUtils.loadLibSVMFile(sc, "iris-libsvm-3.txt").toDF()
data: org.apache.spark.sql.DataFrame = [label: double, features: vector] 

scala> 

scala> val Array(train, test) = data.randomSplit(Array(0.6, 0.4), seed = 13L)
train: org.apache.spark.sql.DataFrame = [label: double, features: vector]
test: org.apache.spark.sql.DataFrame = [label: double, features: vector]

scala> // specify layers for the neural network: 

scala> // input layer of size 4 (features), two intermediate of size 5 and 4 and output of size 3 (classes)

scala> val layers = Array(4, 5, 4, 3)
layers: Array[Int] = Array(4, 5, 4, 3)

scala> // create the trainer and set its parameters

scala> val trainer = new MultilayerPerceptronClassifier().setLayers(layers).setBlockSize(128).setSeed(13L).setMaxIter(100)
trainer: org.apache.spark.ml.classification.MultilayerPerceptronClassifier = mlpc_b5f2c25196f9

scala> // train the model

scala> val model = trainer.fit(train)
model: org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel = mlpc_b5f2c25196f9

scala> // compute precision on the test set

scala> val result = model.transform(test)
result: org.apache.spark.sql.DataFrame = [label: double, features: vector, prediction: double]

scala> val predictionAndLabels = result.select("prediction", "label")
predictionAndLabels: org.apache.spark.sql.DataFrame = [prediction: double, label: double]

scala> val evaluator = new MulticlassClassificationEvaluator().setMetricName("precision")
evaluator: org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator = mcEval_55757d35e3b0

scala> println("Precision = " + evaluator.evaluate(predictionAndLabels))
Precision = 0.9375

泛化误差和过拟合

那么,我们如何知道我们讨论的模型是好的呢?一个明显且最终的标准是其实践中的表现。

一个常见的难题困扰着更复杂的模型,例如决策树和神经网络,那就是过拟合问题。模型可以在提供的数据上最小化期望的指标,但在实际部署中,对稍微不同的数据集却表现得很差。即使是一个标准的技巧,当我们把数据集分成训练集和测试集,用于推导模型的训练和验证模型在保留数据集上表现良好的测试,也可能无法捕捉到部署中所有的变化。例如,线性模型如方差分析、逻辑回归和线性回归通常相对稳定,不太容易过拟合。然而,你可能会发现,任何特定的技术对于你的特定领域要么有效要么无效。

另一个可能导致泛化失败的情况是时间漂移。数据可能会随着时间的推移发生显著变化,以至于在旧数据上训练的模型在部署中的新数据上不再泛化。在实践中,始终拥有几个生产中的模型并持续监控它们的相对性能总是一个好主意。

我将在第七章《使用图算法》中考虑避免过拟合的标准方法,如保留数据集和交叉验证,以及在第九章《Scala 中的 NLP》中的模型监控。

摘要

我们现在拥有了所有必要的工具来查看更复杂的问题,这些更常见的问题通常被称为大数据问题。装备了标准的统计算法——我明白我没有涵盖很多细节,我完全准备好接受批评——有一个全新的领域可以探索,在那里我们没有明确定义的记录,数据集中的变量可能是稀疏和嵌套的,我们必须覆盖大量领域并做大量准备工作才能达到可以应用标准统计模型的地步。这正是 Scala 发挥最佳作用的地方。

在下一章中,我们将更深入地探讨如何处理非结构化数据。

第六章:处理非结构化数据

我非常激动地向大家介绍这一章。非结构化数据是现实中使大数据与旧数据不同的东西,它也使 Scala 成为处理数据的新范式。首先,非结构化数据乍一看似乎是一个贬义词。尽管如此,这本书中的每一句话都是非结构化数据:它没有传统的记录/行/列语义。然而,对大多数人来说,这比将书籍呈现为表格或电子表格要容易阅读得多。

在实践中,非结构化数据意味着嵌套和复杂的数据。一个 XML 文档或一张照片都是非结构化数据的良好例子,它们具有非常丰富的结构。我的猜测是,这个术语的创造者意味着新的数据,工程师在像 Google、Facebook 和 Twitter 这样的社交互动公司看到的数据,与传统大家习惯看到的传统平面表结构不同。这些确实不符合传统的 RDBMS 范式。其中一些可以被展平,但底层存储将过于低效,因为 RDBMS 没有优化来处理它们,而且不仅对人类,对机器来说也难以解析。

本章中介绍的大多数技术都是作为应急的 Band-Aid 来应对仅仅处理数据的需要。

在本章中,我们将涵盖以下主题:

  • 学习关于序列化、流行的序列化框架以及机器之间交流的语言

  • 学习关于嵌套数据的 Avro-Parquet 编码

  • 学习 RDBMs 如何尝试在现代类似 SQL 的语言中融入嵌套结构以与之交互

  • 学习如何在 Scala 中开始使用嵌套结构

  • 看一个会话化的实际例子——这是非结构化数据最常用的用例之一

  • 看看 Scala 特性和 match/case 语句如何简化路径分析

  • 学习嵌套结构如何使你的分析受益

嵌套数据

在前面的章节中,你已经看到了非结构化数据,数据是一个LabeledPoint数组的集合,其中LabeledPoint是一个元组(label: Double, features: Vector)。标签只是一个Double类型的数字。Vector是一个密封的特质,有两个子类:SparseVectorDenseVector。类图如下:

嵌套数据

图 1:LabeledPoint 类结构是一个标签和特征的元组,其中特征是一个具有两个继承子类{Dense,Sparse}Vector 的特质。DenseVector 是一个 double 数组,而 SparseVector 通过索引和值存储大小和非默认元素。

每个观测值是一个标签和特征的元组,特征可以是稀疏的。当然,如果没有缺失值,整个行可以表示为一个向量。密集向量表示需要(8 x size + 8)字节。如果大多数元素是缺失的——或者等于某个默认值——我们只能存储非默认元素。在这种情况下,我们需要(12 x non_missing_size + 20)字节,具体取决于 JVM 实现的小幅变化。因此,从存储的角度来看,在大小大于1.5 x (non_missing_size + 1)或大约至少 30%的元素是非默认值时,我们需要在一种或另一种表示之间切换。虽然计算机语言擅长通过指针表示复杂结构,但我们还需要一种方便的形式来在 JVM 或机器之间交换这些数据。首先,让我们看看 Spark/Scala 是如何做的,特别是如何将数据持久化在 Parquet 格式中:

akozlov@Alexanders-MacBook-Pro$ bin/spark-shell 
Welcome to
 ____              __
 / __/__  ___ _____/ /__
 _\ \/ _ \/ _ `/ __/  '_/
 /___/ .__/\_,_/_/ /_/\_\   version 1.6.1-SNAPSHOT
 /_/

Using Scala version 2.11.7 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_40)
Type in expressions to have them evaluated.
Type :help for more information.
Spark context available as sc.
SQL context available as sqlContext.

scala> import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.regression.LabeledPoint

scala> import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.Vectors
Wha
scala> 

scala> val points = Array(
 |    LabeledPoint(0.0, Vectors.sparse(3, Array(1), Array(1.0))),
 |    LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 0.0)),
 |    LabeledPoint(2.0, Vectors.sparse(3, Array((1, 3.0)))),
 |    LabeledPoint.parse("(3.0,[0.0,4.0,0.0])"));
pts: Array[org.apache.spark.mllib.regression.LabeledPoint] = Array((0.0,(3,[1],[1.0])), (1.0,[0.0,2.0,0.0]), (2.0,(3,[1],[3.0])), (3.0,[0.0,4.0,0.0]))
scala> 

scala> val rdd = sc.parallelize(points)
rdd: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint] = ParallelCollectionRDD[0] at parallelize at <console>:25

scala> 

scala> val df = rdd.repartition(1).toDF
df: org.apache.spark.sql.DataFrame = [label: double, features: vector]

scala> df.write.parquet("points")

我们所做的是从命令行创建一个新的 RDD 数据集,或者我们可以使用org.apache.spark.mllib.util.MLUtils来加载一个文本文件,将其转换为 DataFrames,并在points目录下创建其序列化表示的 Parquet 文件。

注意

Parquet 是什么意思?

Apache Parquet 是一种列式存储格式,由 Cloudera 和 Twitter 共同开发,用于大数据。列式存储允许对数据集中的值进行更好的压缩,并且在只需要从磁盘检索部分列时更为高效。Parquet 是从头开始构建的,考虑到复杂嵌套数据结构,并使用了 Dremel 论文中描述的记录切割和组装算法(blog.twitter.com/2013/dremel-made-simple-with-parquet)。Dremel/Parquet 编码使用定义/重复字段来表示数据来自层次结构中的哪个级别,这覆盖了大多数直接的编码需求,因为它足以存储可选字段、嵌套数组和映射。Parquet 通过块存储数据,因此可能得名 Parquet,其意为由按几何图案排列的木块组成的地面。Parquet 可以优化为只从磁盘读取部分块,这取决于要读取的列子集和使用的索引(尽管这很大程度上取决于特定实现是否了解这些功能)。列中的值可以使用字典和运行长度编码RLE),这对于具有许多重复条目的列提供了非常好的压缩效果,这在大数据中是一个常见的用例。

Parquet 文件是一种二进制格式,但您可能可以使用parquet-tools来查看其中的信息,这些工具可以从archive.cloudera.com/cdh5/cdh/5下载:

akozlov@Alexanders-MacBook-Pro$ wget -O - http://archive.cloudera.com/cdh5/cdh/5/parquet-1.5.0-cdh5.5.0.tar.gz | tar xzvf -

akozlov@Alexanders-MacBook-Pro$ cd parquet-1.5.0-cdh5.5.0/parquet-tools

akozlov@Alexanders-MacBook-Pro$ tar xvf xvf parquet-1.5.0-cdh5.5.0/parquet-tools/target/parquet-tools-1.5.0-cdh5.5.0-bin.tar.gz

akozlov@Alexanders-MacBook-Pro$ cd parquet-tools-1.5.0-cdh5.5.0

akozlov@Alexanders-MacBook-Pro $ ./parquet-schema ~/points/*.parquet 
message spark_schema {
 optional double label;
 optional group features {
 required int32 type (INT_8);
 optional int32 size;
 optional group indices (LIST) {
 repeated group list {
 required int32 element;
 }
 }
 optional group values (LIST) {
 repeated group list {
 required double element;
 }
 }
 }
}

让我们看看模式,它与图 1中描述的结构非常接近:第一个成员是类型为 double 的标签,第二个和最后一个成员是复合类型的特征。关键字optional是另一种表示值可以在记录中为空(缺失)的方式。列表或数组被编码为重复字段。由于整个数组可能不存在(所有特征都可能不存在),它被包裹在可选组(索引和值)中。最后,类型编码表示它是一个稀疏或密集表示:

akozlov@Alexanders-MacBook-Pro $ ./parquet-dump ~/points/*.parquet 
row group 0 
----------------------------------------------------------------------------------------------------------------------------------------------------------------------
label:       DOUBLE GZIP DO:0 FPO:4 SZ:78/79/1.01 VC:4 ENC:BIT_PACKED,PLAIN,RLE
features: 
.type:       INT32 GZIP DO:0 FPO:82 SZ:101/63/0.62 VC:4 ENC:BIT_PACKED,PLAIN_DICTIONARY,RLE
.size:       INT32 GZIP DO:0 FPO:183 SZ:97/59/0.61 VC:4 ENC:BIT_PACKED,PLAIN_DICTIONARY,RLE
.indices: 
..list: 
...element:  INT32 GZIP DO:0 FPO:280 SZ:100/65/0.65 VC:4 ENC:PLAIN_DICTIONARY,RLE
.values: 
..list: 
...element:  DOUBLE GZIP DO:0 FPO:380 SZ:125/111/0.89 VC:8 ENC:PLAIN_DICTIONARY,RLE

 label TV=4 RL=0 DL=1
 ------------------------------------------------------------------------------------------------------------------------------------------------------------------
 page 0:                                           DLE:RLE RLE:BIT_PACKED VLE:PLAIN SZ:38 VC:4

 features.type TV=4 RL=0 DL=1 DS:                 2 DE:PLAIN_DICTIONARY
 ------------------------------------------------------------------------------------------------------------------------------------------------------------------
 page 0:                                           DLE:RLE RLE:BIT_PACKED VLE:PLAIN_DICTIONARY SZ:9 VC:4

 features.size TV=4 RL=0 DL=2 DS:                 1 DE:PLAIN_DICTIONARY
 ------------------------------------------------------------------------------------------------------------------------------------------------------------------
 page 0:                                           DLE:RLE RLE:BIT_PACKED VLE:PLAIN_DICTIONARY SZ:9 VC:4

 features.indices.list.element TV=4 RL=1 DL=3 DS: 1 DE:PLAIN_DICTIONARY
 ------------------------------------------------------------------------------------------------------------------------------------------------------------------
 page 0:                                           DLE:RLE RLE:RLE VLE:PLAIN_DICTIONARY SZ:15 VC:4

 features.values.list.element TV=8 RL=1 DL=3 DS:  5 DE:PLAIN_DICTIONARY
 ------------------------------------------------------------------------------------------------------------------------------------------------------------------
 page 0:                                           DLE:RLE RLE:RLE VLE:PLAIN_DICTIONARY SZ:17 VC:8

DOUBLE label 
----------------------------------------------------------------------------------------------------------------------------------------------------------------------
*** row group 1 of 1, values 1 to 4 *** 
value 1: R:0 D:1 V:0.0
value 2: R:0 D:1 V:1.0
value 3: R:0 D:1 V:2.0
value 4: R:0 D:1 V:3.0

INT32 features.type 
----------------------------------------------------------------------------------------------------------------------------------------------------------------------
*** row group 1 of 1, values 1 to 4 *** 
value 1: R:0 D:1 V:0
value 2: R:0 D:1 V:1
value 3: R:0 D:1 V:0
value 4: R:0 D:1 V:1

INT32 features.size 
----------------------------------------------------------------------------------------------------------------------------------------------------------------------
*** row group 1 of 1, values 1 to 4 *** 
value 1: R:0 D:2 V:3
value 2: R:0 D:1 V:<null>
value 3: R:0 D:2 V:3
value 4: R:0 D:1 V:<null>

INT32 features.indices.list.element 
----------------------------------------------------------------------------------------------------------------------------------------------------------------------
*** row group 1 of 1, values 1 to 4 *** 
value 1: R:0 D:3 V:1
value 2: R:0 D:1 V:<null>
value 3: R:0 D:3 V:1
value 4: R:0 D:1 V:<null>

DOUBLE features.values.list.element 
----------------------------------------------------------------------------------------------------------------------------------------------------------------------
*** row group 1 of 1, values 1 to 8 *** 
value 1: R:0 D:3 V:1.0
value 2: R:0 D:3 V:0.0
value 3: R:1 D:3 V:2.0
value 4: R:1 D:3 V:0.0
value 5: R:0 D:3 V:3.0
value 6: R:0 D:3 V:0.0
value 7: R:1 D:3 V:4.0
value 8: R:1 D:3 V:0.0

您可能对输出中的RD感到有些困惑。这些是 Dremel 论文中描述的重复和定义级别,并且对于有效地编码嵌套结构中的值是必要的。只有重复字段会增加重复级别,只有非必需字段会增加定义级别。R的下降表示列表(数组)的结束。对于层次结构树中的每个非必需级别,都需要一个新的定义级别。重复和定义级别值设计得较小,可以有效地以序列化形式存储。

如果有大量重复条目,它们都将放在一起。这是压缩算法(默认为 gzip)优化的情况。Parquet 还实现了其他利用重复值的算法,例如字典编码或 RLE 压缩。

这是一个简单且高效的默认序列化。我们已经能够将一组复杂对象写入文件,每个列存储在一个单独的块中,代表记录和嵌套结构中的所有值。

现在我们来读取文件并恢复 RDD。Parquet 格式对LabeledPoint类一无所知,因此我们在这里需要进行一些类型转换和技巧。当我们读取文件时,我们会看到一个org.apache.spark.sql.Row的集合:

akozlov@Alexanders-MacBook-Pro$ bin/spark-shell 
Welcome to
 ____              __
 / __/__  ___ _____/ /__
 _\ \/ _ \/ _ `/ __/  '_/
 /___/ .__/\_,_/_/ /_/\_\   version 1.6.1-SNAPSHOT
 /_/

Using Scala version 2.11.7 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_40)
Type in expressions to have them evaluated.
Type :help for more information.
Spark context available as sc.
SQL context available as sqlContext.

scala> val df = sqlContext.read.parquet("points")
df: org.apache.spark.sql.DataFrame = [label: double, features: vector]

scala> val df = sqlContext.read.parquet("points").collect
df: Array[org.apache.spark.sql.Row] = Array([0.0,(3,[1],[1.0])], [1.0,[0.0,2.0,0.0]], [2.0,(3,[1],[3.0])], [3.0,[0.0,4.0,0.0]])

scala> val rdd = df.map(x => LabeledPoint(x(0).asInstanceOf[scala.Double], x(1).asInstanceOf[org.apache.spark.mllib.linalg.Vector]))
rdd: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint] = MapPartitionsRDD[16] at map at <console>:25

scala> rdd.collect
res12: Array[org.apache.spark.mllib.regression.LabeledPoint] = Array((0.0,(3,[1],[1.0])), (1.0,[0.0,2.0,0.0]), (2.0,(3,[1],[3.0])), (3.0,[0.0,4.0,0.0]))

scala> rdd.filter(_.features(1) <= 2).collect
res13: Array[org.apache.spark.mllib.regression.LabeledPoint] = Array((0.0,(3,[1],[1.0])), (1.0,[0.0,2.0,0.0]))

个人认为,这相当酷:无需任何编译,我们就可以编码和决定复杂对象。在 REPL 中,人们可以轻松创建自己的对象。让我们考虑我们想要跟踪用户在网上的行为:

akozlov@Alexanders-MacBook-Pro$ bin/spark-shell 
Welcome to
 ____              __
 / __/__  ___ _____/ /__
 _\ \/ _ \/ _ `/ __/  '_/
 /___/ .__/\_,_/_/ /_/\_\   version 1.6.1-SNAPSHOT
 /_/

Using Scala version 2.11.7 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_40)
Type in expressions to have them evaluated.
Type :help for more information.
Spark context available as sc.
SQL context available as sqlContext.

scala> case class Person(id: String, visits: Array[String]) { override def toString: String = { val vsts = visits.mkString(","); s"($id -> $vsts)" } }
defined class Person

scala> val p1 = Person("Phil", Array("http://www.google.com", "http://www.facebook.com", "http://www.linkedin.com", "http://www.homedepot.com"))
p1: Person = (Phil -> http://www.google.com,http://www.facebook.com,http://www.linkedin.com,http://www.homedepot.com)

scala> val p2 = Person("Emily", Array("http://www.victoriassecret.com", "http://www.pacsun.com", "http://www.abercrombie.com/shop/us", "http://www.orvis.com"))
p2: Person = (Emily -> http://www.victoriassecret.com,http://www.pacsun.com,http://www.abercrombie.com/shop/us,http://www.orvis.com)

scala> sc.parallelize(Array(p1,p2)).repartition(1).toDF.write.parquet("history")

scala> import scala.collection.mutable.WrappedArray
import scala.collection.mutable.WrappedArray

scala> val df = sqlContext.read.parquet("history")
df: org.apache.spark.sql.DataFrame = [id: string, visits: array<string>]

scala> val rdd = df.map(x => Person(x(0).asInstanceOf[String], x(1).asInstanceOf[WrappedArray[String]].toArray[String]))
rdd: org.apache.spark.rdd.RDD[Person] = MapPartitionsRDD[27] at map at <console>:28

scala> rdd.collect
res9: Array[Person] = Array((Phil -> http://www.google.com,http://www.facebook.com,http://www.linkedin.com,http://www.homedepot.com), (Emily -> http://www.victoriassecret.com,http://www.pacsun.com,http://www.abercrombie.com/shop/us,http://www.orvis.com))

作为良好的实践,我们需要将新创建的类注册到Kryo 序列化器中——Spark 将使用另一种序列化机制在任务和执行器之间传递对象。如果类未注册,Spark 将使用默认的 Java 序列化,这可能会慢上10 倍

scala> :paste
// Entering paste mode (ctrl-D to finish)

import com.esotericsoftware.kryo.Kryo
import org.apache.spark.serializer.{KryoSerializer, KryoRegistrator}

class MyKryoRegistrator extends KryoRegistrator {
 override def registerClasses(kryo: Kryo) {
 kryo.register(classOf[Person])
 }
}

object MyKryoRegistrator {
 def register(conf: org.apache.spark.SparkConf) {
 conf.set("spark.serializer", classOf[KryoSerializer].getName)
 conf.set("spark.kryo.registrator", classOf[MyKryoRegistrator].getName)
 }
}
^D

// Exiting paste mode, now interpreting.

import com.esotericsoftware.kryo.Kryo
import org.apache.spark.serializer.{KryoSerializer, KryoRegistrator}
defined class MyKryoRegistrator
defined module MyKryoRegistrator

scala>

如果你正在将代码部署到集群上,建议将此代码放在类路径上的 jar 文件中。

我确实在生产中看到了多达 10 层嵌套的例子。尽管这可能在性能上可能有些过度,但在越来越多的生产业务用例中,嵌套是必需的。在我们深入到构建嵌套对象的特定示例(例如会话化)之前,让我们先对序列化的一般情况有一个概述。

其他序列化格式

我确实推荐使用 Parquet 格式来存储数据。然而,为了完整性,我至少需要提及其他序列化格式,其中一些,如 Kryo,将在 Spark 计算过程中不为人知地隐式使用,并且显然存在默认的 Java 序列化。

小贴士

面向对象方法与函数式方法

在面向对象的方法中,对象以其状态和行为为特征。对象是面向对象编程的基石。一个类是具有表示状态的字段和可能表示行为的方法的对象的模板。抽象方法实现可能依赖于类的实例。在函数式方法中,状态通常是不受欢迎的;在纯编程语言中,不应该有状态,没有副作用,并且每次调用都应该返回相同的结果。行为可以通过额外的函数参数和高级函数(如柯里化函数)来表示,但应该像抽象方法一样明确。由于 Scala 是面向对象和函数式语言的混合,一些先前的约束被违反了,但这并不意味着你必须在绝对必要时才使用它们。最佳实践是在存储数据的同时将代码存储在 jar 包中,尤其是对于大数据,应将数据文件(以序列化形式)与代码分开;但再次强调,人们经常将数据/配置存储在 jar 文件中,而将代码存储在数据文件中则较少见,但也是可能的。

序列化问题自从需要在磁盘上持久化数据或通过网络将对象从一个 JVM 或机器传输到另一个机器以来一直存在。实际上,序列化的目的是将复杂的嵌套对象表示为一系列机器可理解的字节,正如你可以想象的那样,这可能是语言相关的。幸运的是,序列化框架在它们可以处理的一组常见数据结构上达成了一致。

以下是最受欢迎的序列化机制之一,但不是最有效的,即在一个 ASCII 文件中转储对象:CSV、XML、JSON、YAML 等。它们对于更复杂的嵌套数据结构,如结构、数组和映射,是有效的,但从存储空间的角度来看效率低下。例如,一个 Double 表示一个具有 15-17 位有效数字的连续数字,在没有舍入或简单比率的情况下,将需要 15-17 个字节来表示,而二进制表示只需要 8 个字节。整数可能存储得更有效率,尤其是如果它们很小,因为我们可以压缩/删除零。

文本编码的一个优点是它们使用简单的命令行工具更容易可视化,但现在任何高级序列化框架都附带了一套用于处理原始记录(如avro -parquet-tools)的工具。

以下表格提供了大多数常见序列化框架的概述:

序列化格式 开发时间 评论
XML, JSON, YAML 这是对编码嵌套结构和在机器之间交换数据的必要性的直接回应。 虽然效率低下,但它们仍然被许多地方使用,尤其是在网络服务中。唯一的优点是它们相对容易解析,无需机器。
Protobuf 由谷歌在 2000 年代初开发。该协议实现了 Dremel 编码方案,并支持多种语言(Scala 尚未官方支持,尽管存在一些代码)。主要优势是 Protobuf 可以在许多语言中生成本地类。C++、Java 和 Python 是官方支持的语言。C、C#、Haskell、Perl、Ruby、Scala 和其他语言正在进行中的项目。运行时可以调用本地代码来检查/序列化/反序列化对象和二进制表示。
Avro Avro 是由 Doug Cutting 在 Cloudera 工作时开发的。主要目标是使编码与特定实现和语言分离,从而实现更好的模式演变。 虽然关于 Protobuf 或 Avro 哪个更高效的争论仍在继续,但与 Protobuf 相比,Avro 支持更多的复杂结构,例如开箱即用的联合和映射。Scala 的支持仍需加强以达到生产水平。Avro 文件包含每个文件的编码模式,这既有优点也有缺点。
Thrift Apache Thrift 是在 Facebook 开发的,目的是与 Protobuf 相同。它可能支持的语言种类最广泛:C++、Java、Python、PHP、Ruby、Erlang、Perl、Haskell、C#、Cocoa、JavaScript、Node.js、Smalltalk、OCaml、Delphi 和其他语言。再次,Twitter 正在努力为 Scala 的 Thrift 代码生成提供支持(twitter.github.io/scrooge/)。 Apache Thrift 通常被描述为跨语言服务开发的框架,并且最常用于远程过程调用RPC)。尽管它可以直接用于序列化/反序列化,但其他框架却更受欢迎。
Parquet Parquet 是由 Twitter 和 Cloudera 共同开发的。与以行为导向的 Avro 格式相比,Parquet 是列式存储,如果只选择少量列,则可以提供更好的压缩和性能。区间编码基于 Dremel 或 Protobuf,尽管记录以 Avro 记录的形式呈现;因此,它通常被称为AvroParquet 索引、字典编码和 RLE 压缩等高级功能可能使其对于纯磁盘存储非常高效。由于 Parquet 需要在提交到磁盘之前进行一些预处理和索引构建,因此写入文件可能会更慢。
Kryo 这是一个用于在 Java 中编码任意类的框架。然而,并非所有内置的 Java 集合类都可以序列化。如果避免非序列化异常,例如优先队列,Kryo 可以非常高效。Scala 的直接支持也在进行中。

当然,Java 有一个内置的序列化框架,但由于它必须支持所有 Java 情况,因此过于通用,Java 序列化比任何先前的序列化方法都要低效得多。我确实看到其他公司更早地实现了它们自己的专有序列化,这会优于先前的任何序列化方法。如今,这已不再必要,因为维护成本肯定超过了现有框架的收敛低效。

Hive 和 Impala

新框架的设计考虑之一总是与旧框架的兼容性。不论好坏,大多数数据分析师仍在使用 SQL。SQL 的根源可以追溯到一篇有影响力的关系建模论文(Codd, Edgar F. (1970 年 6 月). 《大型共享数据银行的数据关系模型》. 《ACM 通讯》(计算机机械协会)13(6):377–87)。所有现代数据库都实现了 SQL 的一个或多个版本。

虽然关系模型对提高数据库性能有影响,尤其是对于在线事务处理OLTP)的竞争力水平,但对于需要执行聚合操作的分析工作负载,以及对于关系本身发生变化并受到分析的情况,规范化的重要性较低。本节将涵盖用于大数据分析的传统分析引擎的标准 SQL 语言的扩展:Hive 和 Impala。它们目前都是 Apache 许可项目。以下表格总结了复杂类型:

类型 Hive 支持版本 Impala 支持版本 备注
ARRAY 自 0.1.0 版本起支持,但非常量索引表达式的使用仅限于 0.14 版本之后。 自 2.3.0 版本起支持(仅限 Parquet 表)。 可以是任何类型的数组,包括复杂类型。在 Hive 中索引为int(在 Impala 中为bigint),访问通过数组表示法,例如,在 Hive 中为element[1](在 Impala 中为array.positem伪列)。
MAP 自 0.1.0 版本起支持,但非常量索引表达式的使用仅限于 0.14 版本之后。 自 2.3.0 版本起支持(仅限 Parquet 表)。 键应为原始类型。一些库仅支持字符串类型的键。字段使用数组表示法访问,例如,在 Hive 中为map["key"](在 Impala 中为 map 键和值的伪列)。
STRUCT 自 0.5.0 版本起支持。 自 2.3.0 版本起支持(仅限 Parquet 表)。 使用点表示法访问,例如,struct.element
UNIONTYPE 自 0.7.0 以来支持。 在 Impala 中不支持。 支持不完整:引用UNIONTYPE字段的JOIN(HIVE-2508)、WHEREGROUP BY子句的查询将失败,并且 Hive 没有定义提取UNIONTYPE的标签或值字段的语法。这意味着UNIONTYPEs实际上只能查看。

虽然 Hive/Impala 表可以建立在许多底层文件格式(文本、序列、ORC、Avro、Parquet 以及甚至自定义格式)和多种序列化之上,但在大多数实际情况下,Hive 用于读取 ASCII 文件中的文本行。底层的序列化/反序列化格式是 LazySimpleSerDe序列化/反序列化SerDe))。该格式定义了多个分隔符级别,如下所示:

row_format
  : DELIMITED [FIELDS TERMINATED BY char [ESCAPED BY char]] [COLLECTION ITEMS TERMINATED BY char]
    [MAP KEYS TERMINATED BY char] [LINES TERMINATED BY char]
    [NULL DEFINED AS char]

分隔符的默认值是'\001'^A'\002'^B,以及'\003'^B。换句话说,它是在层次结构的每一级使用新的分隔符,而不是在 Dremel 编码中使用定义/重复指示符。例如,为了对之前使用的LabeledPoint表进行编码,我们需要创建一个文件,如下所示:

$ cat data
0^A1^B1^D1.0$
2^A1^B1^D3.0$
1^A0^B0.0^C2.0^C0.0$
3^A0^B0.0^C4.0^C0.0$

archive.cloudera.com/cdh5/cdh/5/hive-1.1.0-cdh5.5.0.tar.gz下载 Hive 并执行以下操作:

$ tar xf hive-1.1.0-cdh5.5.0.tar.gz 
$ cd hive-1.1.0-cdh5.5.0
$ bin/hive
…
hive> CREATE TABLE LABELED_POINT ( LABEL INT, VECTOR UNIONTYPE<ARRAY<DOUBLE>, MAP<INT,DOUBLE>> ) STORED AS TEXTFILE;
OK
Time taken: 0.453 seconds
hive> LOAD DATA LOCAL INPATH './data' OVERWRITE INTO TABLE LABELED_POINT;
Loading data to table alexdb.labeled_point
Table labeled_point stats: [numFiles=1, numRows=0, totalSize=52, rawDataSize=0]
OK
Time taken: 0.808 seconds
hive> select * from labeled_point;
OK
0  {1:{1:1.0}}
2  {1:{1:3.0}}
1  {0:[0.0,2.0,0.0]}
3  {0:[0.0,4.0,0.0]}
Time taken: 0.569 seconds, Fetched: 4 row(s)
hive>

在 Spark 中,可以通过sqlContext.sql方法从关系型表中选择,但遗憾的是,截至 Spark 1.6.1,Hive 联合类型并不直接支持;尽管如此,它支持映射和数组。在其他 BI 和数据分析工具中支持复杂对象仍然是它们被采用的最大障碍之一。将所有内容作为丰富的数据结构在 Scala 中支持是收敛到嵌套数据表示的一种选项。

会话化

我将以会话化为例,演示复杂或嵌套结构的用法。在会话化中,我们想要找到某个 ID 在一段时间内实体的行为。虽然原始记录可能以任何顺序到来,但我们想要总结随时间推移的行为以推导出趋势。

我们已经在第一章中分析了网络服务器日志,探索性数据分析。我们发现了在一段时间内不同网页被访问的频率。我们可以对这一信息进行切块和切片,但如果没有分析页面访问的顺序,就很难理解每个用户与网站的交互。在这一章中,我想通过跟踪用户在整个网站中的导航来给这种分析增加更多的个性化特点。会话化是网站个性化、广告、物联网跟踪、遥测和企业安全等与实体行为相关的常见工具。

假设数据以三个元素的元组形式出现(原始数据集中的字段1511,见第一章,探索性数据分析):

(id, timestamp, path)

在这里,id是一个唯一的实体 ID,时间戳是一个事件的timestamp(任何可排序的格式:Unix 时间戳或 ISO8601 日期格式),而path是关于 Web 服务器页面层次结构的某种指示。

对于熟悉 SQL 或至少熟悉其子集的人来说,会话化(sessionization)更知名为窗口分析函数:

SELECT id, timestamp, path 
  ANALYTIC_FUNCTION(path) OVER (PARTITION BY id ORDER BY timestamp) AS agg
FROM log_table;

在这里,ANALYTIC_FUNCTION是对给定id的路径序列进行的一些转换。虽然这种方法对于相对简单的函数,如第一个、最后一个、滞后、平均来说有效,但在路径序列上表达复杂函数通常非常复杂(例如,来自 Aster Data 的 nPath (www.nersc.gov/assets/Uploads/AnalyticsFoundation5.0previewfor4.6.x-Guide.pdf)))。此外,在没有额外的预处理和分区的情况下,这些方法通常会导致在分布式环境中多个节点之间的大数据传输。

在纯函数式方法中,人们只需要设计一个函数或一系列函数应用来从原始的元组集中生成所需的答案,我将创建两个辅助对象,这将帮助我们简化与用户会话概念的工作。作为额外的优势,新的嵌套结构可以持久化到磁盘上,以加快对额外问题的答案获取速度。

让我们看看在 Spark/Scala 中使用案例类是如何实现的:

akozlov@Alexanders-MacBook-Pro$ bin/spark-shell
Welcome to
 ____              __
 / __/__  ___ _____/ /__
 _\ \/ _ \/ _ `/ __/  '_/
 /___/ .__/\_,_/_/ /_/\_\   version 1.6.1-SNAPSHOT
 /_/

Using Scala version 2.11.7 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_40)
Type in expressions to have them evaluated.
Type :help for more information.
Spark context available as sc.
SQL context available as sqlContext.

scala> :paste
// Entering paste mode (ctrl-D to finish)

import java.io._

// a basic page view structure
@SerialVersionUID(123L)
case class PageView(ts: String, path: String) extends Serializable with Ordered[PageView] {
 override def toString: String = {
 s"($ts :$path)"
 }
 def compare(other: PageView) = ts compare other.ts
}

// represent a session
@SerialVersionUID(456L)
case class SessionA  <: PageView extends Serializable {
 override def toString: String = {
 val vsts = visits.mkString("[", ",", "]")
 s"($id -> $vsts)"
 }
}^D
// Exiting paste mode, now interpreting.

import java.io._
defined class PageView
defined class Session

第一个类将代表一个带有时间戳的单页浏览,在这种情况下,是一个 ISO8601 String,而第二个是一个页浏览序列。我们能通过将这两个成员编码为一个带有对象分隔符的String来实现吗?当然可以,但将字段表示为类的成员提供了良好的访问语义,同时将一些需要在编译器上执行的工作卸载掉,这总是件好事。

让我们读取之前描述的日志文件并构建对象:

scala> val rdd = sc.textFile("log.csv").map(x => { val z = x.split(",",3); (z(1), new PageView(z(0), z(2))) } ).groupByKey.map( x => { new Session(x._1, x._2.toSeq.sorted) } ).persist
rdd: org.apache.spark.rdd.RDD[Session] = MapPartitionsRDD[14] at map at <console>:31

scala> rdd.take(3).foreach(println)
(189.248.74.238 -> [(2015-08-23 23:09:16 :mycompanycom>homepage),(2015-08-23 23:11:00 :mycompanycom>homepage),(2015-08-23 23:11:02 :mycompanycom>running:slp),(2015-08-23 23:12:01 :mycompanycom>running:slp),(2015-08-23 23:12:03 :mycompanycom>running>stories>2013>04>themycompanyfreestore:cdp),(2015-08-23 23:12:08 :mycompanycom>running>stories>2013>04>themycompanyfreestore:cdp),(2015-08-23 23:12:08 :mycompanycom>running>stories>2013>04>themycompanyfreestore:cdp),(2015-08-23 23:12:42 :mycompanycom>running:slp),(2015-08-23 23:13:25 :mycompanycom>homepage),(2015-08-23 23:14:00 :mycompanycom>homepage),(2015-08-23 23:14:06 :mycompanycom:mobile>mycompany photoid>landing),(2015-08-23 23:14:56 :mycompanycom>men>shoes:segmentedgrid),(2015-08-23 23:15:10 :mycompanycom>homepage)])
(82.166.130.148 -> [(2015-08-23 23:14:27 :mycompanycom>homepage)])
(88.234.248.111 -> [(2015-08-23 22:36:10 :mycompanycom>plus>home),(2015-08-23 22:36:20 :mycompanycom>plus>home),(2015-08-23 22:36:28 :mycompanycom>plus>home),(2015-08-23 22:36:30 :mycompanycom>plus>onepluspdp>sport band),(2015-08-23 22:36:52 :mycompanycom>onsite search>results found),(2015-08-23 22:37:19 :mycompanycom>plus>onepluspdp>sport band),(2015-08-23 22:37:21 :mycompanycom>plus>home),(2015-08-23 22:37:39 :mycompanycom>plus>home),(2015-08-23 22:37:43 :mycompanycom>plus>home),(2015-08-23 22:37:46 :mycompanycom>plus>onepluspdp>sport watch),(2015-08-23 22:37:50 :mycompanycom>gear>mycompany+ sportwatch:standardgrid),(2015-08-23 22:38:14 :mycompanycom>homepage),(2015-08-23 22:38:35 :mycompanycom>homepage),(2015-08-23 22:38:37 :mycompanycom>plus>products landing),(2015-08-23 22:39:01 :mycompanycom>homepage),(2015-08-23 22:39:24 :mycompanycom>homepage),(2015-08-23 22:39:26 :mycompanycom>plus>whatismycompanyfuel)])

哈哈!我们得到了一个包含会话的 RDD,每个唯一的 IP 地址对应一个会话。IP 189.248.74.238有一个从23:09:1623:15:10的会话,看起来是在浏览男士鞋子后结束的。IP 82.166.130.148的会话只有一个点击。最后一个会话专注于运动手表,从2015-08-23 22:36:102015-08-23 22:39:26持续了超过三分钟。现在,我们可以轻松地询问涉及特定导航路径模式的问题。例如,我们想要分析所有导致结账(路径中包含checkout)的会话,并查看主页上最后点击后的点击次数和分布:

scala> import java.time.ZoneOffset
import java.time.ZoneOffset

scala> import java.time.LocalDateTime
import java.time.LocalDateTime

scala> import java.time.format.DateTimeFormatter
import java.time.format.DateTimeFormatter

scala> 
scala> def toEpochSeconds(str: String) : Long = { LocalDateTime.parse(str, DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")).toEpochSecond(ZoneOffset.UTC) }
toEpochSeconds: (str: String)Long

scala> val checkoutPattern = ".*>checkout.*".r.pattern
checkoutPattern: java.util.regex.Pattern = .*>checkout.*

scala> val lengths = rdd.map(x => { val pths = x.visits.map(y => y.path); val pchs = pths.indexWhere(checkoutPattern.matcher(_).matches); (x.id, x.visits.map(y => y.ts).min, x.visits.map(y => y.ts).max, x.visits.lastIndexWhere(_ match { case PageView(ts, "mycompanycom>homepage") => true; case _ => false }, pchs), pchs, x.visits) } ).filter(_._4>0).filter(t => t._5>t._4).map(t => (t._5 - t._4, toEpochSeconds(t._6(t._5).ts) - toEpochSeconds(t._6(t._4).ts)))

scala> lengths.toDF("cnt", "sec").agg(avg($"cnt"),min($"cnt"),max($"cnt"),avg($"sec"),min($"sec"),max($"sec")).show
+-----------------+--------+--------+------------------+--------+--------+

|         avg(cnt)|min(cnt)|max(cnt)|          avg(sec)|min(sec)|max(sec)|
+-----------------+--------+--------+------------------+--------+--------+
|19.77570093457944|       1|     121|366.06542056074767|      15|    2635|
+-----------------+--------+--------+------------------+--------+--------+

scala> lengths.map(x => (x._1,1)).reduceByKey(_+_).sortByKey().collect
res18: Array[(Int, Int)] = Array((1,1), (2,8), (3,2), (5,6), (6,7), (7,9), (8,10), (9,4), (10,6), (11,4), (12,4), (13,2), (14,3), (15,2), (17,4), (18,6), (19,1), (20,1), (21,1), (22,2), (26,1), (27,1), (30,2), (31,2), (35,1), (38,1), (39,2), (41,1), (43,2), (47,1), (48,1), (49,1), (65,1), (66,1), (73,1), (87,1), (91,1), (103,1), (109,1), (121,1))

这些会话持续时间为 1 到 121 次点击,平均点击次数为 8 次,以及从 15 到 2653 秒(或大约 45 分钟)。你为什么会对这个信息感兴趣?长时间的会话可能表明会话中间出现了问题:长时间的延迟或无响应的电话。这并不一定意味着:这个人可能只是吃了个漫长的午餐休息或打电话讨论他的潜在购买,但这里可能有一些有趣的东西。至少应该同意这一点是一个异常值,需要仔细分析。

让我们谈谈将数据持久化到磁盘。正如你所看到的,我们的转换被编写为一个长管道,所以结果中没有任何东西是不能从原始数据计算出来的。这是一个函数式方法,数据是不可变的。此外,如果我们的处理过程中出现错误,比如说我想将主页更改为其他锚定页面,我总是可以修改函数而不是数据。你可能对这个事实感到满意或不满意,但结果中绝对没有额外的信息——转换只会增加无序和熵。它们可能使人类更容易接受,但这仅仅是因为人类是一个非常低效的数据处理装置。

小贴士

为什么重新排列数据会使分析更快?

会话化似乎只是简单的数据重新排列——我们只是将按顺序访问的页面放在一起。然而,在许多情况下,它使实际数据分析的速度提高了 10 到 100 倍。原因是数据局部性。分析,如过滤或路径匹配,通常倾向于一次在单个会话的页面上进行。推导用户特征需要用户的全部页面浏览或交互都在磁盘和内存的一个地方。这通常比其他低效性更好,比如编码/解码嵌套结构的开销,因为这可以在本地 L1/L2 缓存中发生,而不是从 RAM 或磁盘进行数据传输,这在现代多线程 CPU 中要昂贵得多。当然,这非常取决于分析复杂性。

有理由将新数据持久化到磁盘上,我们可以使用 CSV、Avro 或 Parquet 格式来完成。原因是我们不希望在再次查看数据时重新处理数据。新的表示可能更紧凑,更高效地检索和展示给我的经理。实际上,人类喜欢副作用,幸运的是,Scala/Spark 允许你像前一个部分描述的那样做。

嗯,嗯,嗯……熟悉会话化的人会这么说。这只是故事的一部分。我们想要将路径序列分割成多个会话,运行路径分析,计算页面转换的条件概率等等。这正是函数式范式大放异彩的地方。编写以下函数:

def splitSession(session: Session[PageView]) : Seq[Session[PageView]] = { … }

然后运行以下代码:

val newRdd = rdd.flatMap(splitSession)

哈哈!结果是会话的拆分。我故意省略了实现;实现是用户依赖的,而不是数据,每个分析师可能有自己将页面访问序列拆分成会话的方式。

将函数应用于特征生成以应用机器学习……这已经暗示了副作用:我们希望修改世界的状态,使其更加个性化和用户友好。我想最终是无法避免的。

与特性一起工作

正如我们所见,case 类显著简化了我们构造的新嵌套数据结构的处理。case 类定义可能是从 Java(和 SQL)迁移到 Scala 的最有说服力的理由。现在,关于方法呢?我们如何快速向类中添加方法而不需要昂贵的重新编译?Scala 允许你通过特性透明地做到这一点!

函数式编程的一个基本特征是函数与对象一样是第一类公民。在前一节中,我们定义了两个将 ISO8601 格式转换为秒的纪元时间的EpochSeconds函数。我们还建议了splitSession函数,它为给定的 IP 提供多会话视图。我们如何将这种行为或其他行为与给定的类关联起来?

首先,让我们定义一个期望的行为:

scala> trait Epoch {
 |   this: PageView =>
 |   def epoch() : Long = { LocalDateTime.parse(ts, DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")).toEpochSecond(ZoneOffset.UTC) }
 | }
defined trait Epoch

这基本上创建了一个针对PageView的特定函数,该函数将日期时间的字符串表示转换为秒的纪元时间。现在,如果我们只是进行以下转换:

scala> val rddEpoch = rdd.map(x => new Session(x.id, x.visits.map(x => new PageView(x.ts, x.path) with Epoch)))
rddEpoch: org.apache.spark.rdd.RDD[Session[PageView with Epoch]] = MapPartitionsRDD[20] at map at <console>:31

现在我们有一个包含额外行为的页面浏览 RDD。例如,如果我们想找出会话中每个单独页面的时间花费是多少,我们将运行以下管道:

scala> rddEpoch.map(x => (x.id, x.visits.zip(x.visits.tail).map(x => (x._2.path, x._2.epoch - x._1.epoch)).mkString("[", ",", "]"))).take(3).foreach(println)
(189.248.74.238,[(mycompanycom>homepage,104),(mycompanycom>running:slp,2),(mycompanycom>running:slp,59),(mycompanycom>running>stories>2013>04>themycompanyfreestore:cdp,2),(mycompanycom>running>stories>2013>04>themycompanyfreestore:cdp,5),(mycompanycom>running>stories>2013>04>themycompanyfreestore:cdp,0),(mycompanycom>running:slp,34),(mycompanycom>homepage,43),(mycompanycom>homepage,35),(mycompanycom:mobile>mycompany photoid>landing,6),(mycompanycom>men>shoes:segmentedgrid,50),(mycompanycom>homepage,14)])
(82.166.130.148,[])
(88.234.248.111,[(mycompanycom>plus>home,10),(mycompanycom>plus>home,8),(mycompanycom>plus>onepluspdp>sport band,2),(mycompanycom>onsite search>results found,22),(mycompanycom>plus>onepluspdp>sport band,27),(mycompanycom>plus>home,2),(mycompanycom>plus>home,18),(mycompanycom>plus>home,4),(mycompanycom>plus>onepluspdp>sport watch,3),(mycompanycom>gear>mycompany+ sportwatch:standardgrid,4),(mycompanycom>homepage,24),(mycompanycom>homepage,21),(mycompanycom>plus>products landing,2),(mycompanycom>homepage,24),(mycompanycom>homepage,23),(mycompanycom>plus>whatismycompanyfuel,2)])

可以同时添加多个特性,而不会影响原始类定义或原始数据。不需要重新编译。

与模式匹配一起工作

没有一本 Scala 书籍会不提及 match/case 语句。Scala 有一个非常丰富的模式匹配机制。例如,假设我们想找出所有以主页开始,然后是产品页面的页面浏览序列的实例——我们真正想过滤掉确定的买家。这可以通过以下新函数实现:

scala> def findAllMatchedSessions(h: Seq[Session[PageView]], s: Session[PageView]) : Seq[Session[PageView]] = {
 |     def matchSessions(h: Seq[Session[PageView]], id: String, p: Seq[PageView]) : Seq[Session[PageView]] = {
 |       p match {
 |         case Nil => Nil
 |         case PageView(ts1, "mycompanycom>homepage") :: PageView(ts2, "mycompanycom>plus>products landing") :: tail =>
 |           matchSessions(h, id, tail).+:(new Session(id, p))
 |         case _ => matchSessions(h, id, p.tail)
 |       }
 |     }
 |    matchSessions(h, s.id, s.visits)
 | }
findAllSessions: (h: Seq[Session[PageView]], s: Session[PageView])Seq[Session[PageView]]

注意,我们在 case 语句中明确放置了PageView构造函数!Scala 将遍历visits序列,并生成与指定的两个PageViews匹配的新会话,如下所示:

scala> rdd.flatMap(x => findAllMatchedSessions(Nil, x)).take(10).foreach(println)
(88.234.248.111 -> [(2015-08-23 22:38:35 :mycompanycom>homepage),(2015-08-23 22:38:37 :mycompanycom>plus>products landing),(2015-08-23 22:39:01 :mycompanycom>homepage),(2015-08-23 22:39:24 :mycompanycom>homepage),(2015-08-23 22:39:26 :mycompanycom>plus>whatismycompanyfuel)])
(148.246.218.251 -> [(2015-08-23 22:52:09 :mycompanycom>homepage),(2015-08-23 22:52:16 :mycompanycom>plus>products landing),(2015-08-23 22:52:23 :mycompanycom>homepage),(2015-08-23 22:52:32 :mycompanycom>homepage),(2015-08-23 22:52:39 :mycompanycom>running:slp)])
(86.30.116.229 -> [(2015-08-23 23:15:00 :mycompanycom>homepage),(2015-08-23 23:15:02 :mycompanycom>plus>products landing),(2015-08-23 23:15:12 :mycompanycom>plus>products landing),(2015-08-23 23:15:18 :mycompanycom>language tunnel>load),(2015-08-23 23:15:23 :mycompanycom>language tunnel>geo selected),(2015-08-23 23:15:24 :mycompanycom>homepage),(2015-08-23 23:15:27 :mycompanycom>homepage),(2015-08-23 23:15:30 :mycompanycom>basketball:slp),(2015-08-23 23:15:38 :mycompanycom>basketball>lebron-10:cdp),(2015-08-23 23:15:50 :mycompanycom>basketball>lebron-10:cdp),(2015-08-23 23:16:05 :mycompanycom>homepage),(2015-08-23 23:16:09 :mycompanycom>homepage),(2015-08-23 23:16:11 :mycompanycom>basketball:slp),(2015-08-23 23:16:29 :mycompanycom>onsite search>results found),(2015-08-23 23:16:39 :mycompanycom>onsite search>no results)])
(204.237.0.130 -> [(2015-08-23 23:26:23 :mycompanycom>homepage),(2015-08-23 23:26:27 :mycompanycom>plus>products landing),(2015-08-23 23:26:35 :mycompanycom>plus>fuelband activity>summary>wk)])
(97.82.221.34 -> [(2015-08-23 22:36:24 :mycompanycom>homepage),(2015-08-23 22:36:32 :mycompanycom>plus>products landing),(2015-08-23 22:37:09 :mycompanycom>plus>plus activity>summary>wk),(2015-08-23 22:37:39 :mycompanycom>plus>products landing),(2015-08-23 22:44:17 :mycompanycom>plus>home),(2015-08-23 22:44:33 :mycompanycom>plus>home),(2015-08-23 22:44:34 :mycompanycom>plus>home),(2015-08-23 22:44:36 :mycompanycom>plus>home),(2015-08-23 22:44:43 :mycompanycom>plus>home)])
(24.230.204.72 -> [(2015-08-23 22:49:58 :mycompanycom>homepage),(2015-08-23 22:50:00 :mycompanycom>plus>products landing),(2015-08-23 22:50:30 :mycompanycom>homepage),(2015-08-23 22:50:38 :mycompanycom>homepage),(2015-08-23 22:50:41 :mycompanycom>training:cdp),(2015-08-23 22:51:56 :mycompanycom>training:cdp),(2015-08-23 22:51:59 :mycompanycom>store locator>start),(2015-08-23 22:52:28 :mycompanycom>store locator>landing)])
(62.248.72.18 -> [(2015-08-23 23:14:27 :mycompanycom>homepage),(2015-08-23 23:14:30 :mycompanycom>plus>products landing),(2015-08-23 23:14:33 :mycompanycom>plus>products landing),(2015-08-23 23:14:40 :mycompanycom>plus>products landing),(2015-08-23 23:14:47 :mycompanycom>store homepage),(2015-08-23 23:14:50 :mycompanycom>store homepage),(2015-08-23 23:14:55 :mycompanycom>men:clp),(2015-08-23 23:15:08 :mycompanycom>men:clp),(2015-08-23 23:15:15 :mycompanycom>men:clp),(2015-08-23 23:15:16 :mycompanycom>men:clp),(2015-08-23 23:15:24 :mycompanycom>men>sportswear:standardgrid),(2015-08-23 23:15:41 :mycompanycom>pdp>mycompany blazer low premium vintage suede men's shoe),(2015-08-23 23:15:45 :mycompanycom>pdp>mycompany blazer low premium vintage suede men's shoe),(2015-08-23 23:15:45 :mycompanycom>pdp>mycompany blazer low premium vintage suede men's shoe),(2015-08-23 23:15:49 :mycompanycom>pdp>mycompany blazer low premium vintage suede men's shoe),(2015-08-23 23:15:50 :mycompanycom>pdp>mycompany blazer low premium vintage suede men's shoe),(2015-08-23 23:15:56 :mycompanycom>men>sportswear:standardgrid),(2015-08-23 23:18:41 :mycompanycom>pdp>mycompany bruin low men's shoe),(2015-08-23 23:18:42 :mycompanycom>pdp>mycompany bruin low men's shoe),(2015-08-23 23:18:53 :mycompanycom>pdp>mycompany bruin low men's shoe),(2015-08-23 23:18:55 :mycompanycom>pdp>mycompany bruin low men's shoe),(2015-08-23 23:18:57 :mycompanycom>pdp>mycompany bruin low men's shoe),(2015-08-23 23:19:04 :mycompanycom>men>sportswear:standardgrid),(2015-08-23 23:20:12 :mycompanycom>men>sportswear>silver:standardgrid),(2015-08-23 23:28:20 :mycompanycom>onsite search>no results),(2015-08-23 23:28:33 :mycompanycom>onsite search>no results),(2015-08-23 23:28:36 :mycompanycom>pdp>mycompany blazer low premium vintage suede men's shoe),(2015-08-23 23:28:40 :mycompanycom>pdp>mycompany blazer low premium vintage suede men's shoe),(2015-08-23 23:28:41 :mycompanycom>pdp>mycompany blazer low premium vintage suede men's shoe),(2015-08-23 23:28:43 :mycompanycom>pdp>mycompany blazer low premium vintage suede men's shoe),(2015-08-23 23:28:43 :mycompanycom>pdp>mycompany blazer low premium vintage suede men's shoe),(2015-08-23 23:29:00 :mycompanycom>pdp:mycompanyid>mycompany blazer low id shoe)])
(46.5.127.21 -> [(2015-08-23 22:58:00 :mycompanycom>homepage),(2015-08-23 22:58:01 :mycompanycom>plus>products landing)])
(200.45.228.1 -> [(2015-08-23 23:07:33 :mycompanycom>homepage),(2015-08-23 23:07:39 :mycompanycom>plus>products landing),(2015-08-23 23:07:42 :mycompanycom>plus>products landing),(2015-08-23 23:07:45 :mycompanycom>language tunnel>load),(2015-08-23 23:07:59 :mycompanycom>homepage),(2015-08-23 23:08:15 :mycompanycom>homepage),(2015-08-23 23:08:26 :mycompanycom>onsite search>results found),(2015-08-23 23:08:43 :mycompanycom>onsite search>no results),(2015-08-23 23:08:49 :mycompanycom>onsite search>results found),(2015-08-23 23:08:53 :mycompanycom>language tunnel>load),(2015-08-23 23:08:55 :mycompanycom>plus>products landing),(2015-08-23 23:09:04 :mycompanycom>homepage),(2015-08-23 23:11:34 :mycompanycom>running:slp)])
(37.78.203.213 -> [(2015-08-23 23:18:10 :mycompanycom>homepage),(2015-08-23 23:18:12 :mycompanycom>plus>products landing),(2015-08-23 23:18:14 :mycompanycom>plus>products landing),(2015-08-23 23:18:22 :mycompanycom>plus>products landing),(2015-08-23 23:18:25 :mycompanycom>store homepage),(2015-08-23 23:18:31 :mycompanycom>store homepage),(2015-08-23 23:18:34 :mycompanycom>men:clp),(2015-08-23 23:18:50 :mycompanycom>store homepage),(2015-08-23 23:18:51 :mycompanycom>footwear:segmentedgrid),(2015-08-23 23:19:12 :mycompanycom>men>footwear:segmentedgrid),(2015-08-23 23:19:12 :mycompanycom>men>footwear:segmentedgrid),(2015-08-23 23:19:26 :mycompanycom>men>footwear>new releases:standardgrid),(2015-08-23 23:19:26 :mycompanycom>men>footwear>new releases:standardgrid),(2015-08-23 23:19:35 :mycompanycom>pdp>mycompany cheyenne 2015 men's shoe),(2015-08-23 23:19:40 :mycompanycom>men>footwear>new releases:standardgrid)])

我留给读者编写一个函数,该函数仅过滤那些用户在转到产品页面之前花费不到 10 秒的会话。纪元特性或先前定义的EpochSeconds函数可能很有用。

match/case 函数也可以用于特征生成,并返回一个会话上的特征向量。

非结构化数据的其他用途

个性化设备和诊断显然不是非结构化数据的唯一用途。前面的例子是一个很好的例子,因为我们从结构化记录开始,很快转向了构建非结构化数据结构以简化分析的需求。

事实上,非结构化数据比结构化数据要多得多;只是传统统计分析中扁平结构的便利性让我们将数据呈现为记录集。文本、图像和音乐是半结构化数据的例子。

非结构化数据的一个例子是非规范化数据。传统上,记录数据主要是为了性能原因而规范化的,因为关系型数据库管理系统已经针对结构化数据进行了优化。这导致了外键和查找表,但如果维度发生变化,这些表就非常难以维护。非规范化数据没有这个问题,因为查找表可以与每条记录一起存储——它只是一个与行关联的附加表对象,但可能不太节省存储空间。

概率结构

另一个用例是概率结构。通常人们认为回答问题是有确定性的。正如我在第二章,数据管道和建模中所示,在许多情况下,真正的答案与一些不确定性相关。编码不确定性的最流行的方法之一是概率,这是一种频率论方法,意味着当答案确实发生时,简单的计数除以总尝试次数——概率也可以编码我们的信念。我将在以下章节中涉及到概率分析和模型,但概率分析需要存储每个可能的输出及其概率度量,这恰好是一个嵌套结构。

投影

处理高维性的一个方法是在低维空间上进行投影。投影可能有效的基本依据是 Johnson-Lindenstrauss 引理。该引理指出,高维空间中的一小部分点可以嵌入到一个维度低得多的空间中,使得点之间的距离几乎得到保留。当我们谈到第第九章,Scala 中的 NLP中的 NLP 时,我们将涉及到随机和其他投影,但随机投影对于嵌套结构和函数式编程语言来说效果很好,因为在许多情况下,生成随机投影是应用一个函数到紧凑编码的数据上,而不是显式地展平数据。换句话说,Scala 中随机投影的定义可能看起来像是函数式范式闪耀。编写以下函数:

def randomeProjecton(data: NestedStructure) : Vector = { … }

在这里,向量处于低维空间中。

用于嵌入的映射至少是 Lipschitz 连续的,甚至可以被视为一个正交投影。

摘要

在本章中,我们看到了如何在 Scala 中表示和操作复杂和嵌套数据的示例。显然,要涵盖所有情况是困难的,因为非结构化数据的世界比现实世界中结构化数据按行逐行简化的美好领域要大得多,而且仍在建设中。图片、音乐以及口语和书面语言有很多细微差别,很难在平面表示中捕捉到。

在进行最终数据分析时,我们最终会将数据集转换为面向记录的平面表示,但至少在收集数据时,需要小心地存储数据原样,不要丢弃可能包含在数据或元数据中的有用信息。通过扩展数据库和存储,以记录这些有用信息的方式是第一步。接下来的一步是使用能够有效分析这些信息的语言;这无疑是 Scala。

在下一章中,我们将探讨与图相关的话题,这是非结构化数据的一个特定示例。

第七章。使用图算法

在本章中,我将深入探讨 Scala 中的图库和算法实现。特别是,我将介绍 Graph for Scala(www.scala-graph.org),这是一个始于 2011 年 EPFL Scala 孵化器的开源项目。Graph for Scala 目前不支持分布式计算——流行图算法的分布式计算方面可在 GraphX 中找到,它是 Spark 项目(spark.apache.org/docs/latest/mllib-guide.html)的一部分 MLlib 库。Spark 和 MLlib 都始于 2009 年左右或之后在加州大学伯克利分校的课堂项目。我在第三章中考虑了 Spark,并在使用 Spark 和 MLlib中介绍了 RDD。在 GraphX 中,图是一对 RDD,每个 RDD 都在执行器和任务之间分区,代表图中的顶点和边。

在本章中,我们将涵盖以下主题:

  • 配置简单构建工具SBT)以交互式地使用本章中的材料

  • 学习 Graph for Scala 支持的图的基本操作

  • 学习如何强制执行图约束

  • 学习如何导入/导出 JSON 中的图

  • 在 Enron 电子邮件数据上执行连通分量、三角形计数和强连通分量的计算

  • 在 Enron 电子邮件数据上执行 PageRank 计算

  • 学习如何使用 SVD++

图的快速介绍

什么是图?图是一组顶点,其中一些顶点对通过相互连接。如果每个顶点都与每个其他顶点相连,我们说这个图是一个完全图。相反,如果没有边,我们说这个图是空的。当然,这些是实践中很少遇到的情况,因为图具有不同的密度;边的数量与顶点数量的比例越高,我们说它越密集。

根据我们打算在图上运行哪些算法以及预期的密度如何,我们可以选择如何适当地在内存中表示图。如果图非常密集,将其存储为方形的N x N矩阵是有益的,其中第n行和m列的0表示第n个顶点没有与第m个顶点相连。对角线条目表示节点与其自身的连接。这种表示称为邻接矩阵。

如果边不多,并且我们需要无区别地遍历整个边集,通常将其存储为对偶对的简单容器是有益的。这种结构称为边表

在实践中,我们可以将许多现实生活中的情况和事件建模为图。我们可以想象城市作为顶点,平面航线作为边。如果两个城市之间没有航班,它们之间就没有边。此外,如果我们将机票的数值成本添加到边中,我们说这个图是加权的。如果有些边只存在单向旅行,我们可以通过使图有向而不是无向图来表示这一点。因此,对于一个无向图,它是对称的,即如果AB相连,那么B也与A相连——这不一定适用于有向图。

没有环的图称为无环图。多重图可以在节点之间包含多条边,这些边可能是不同类型的。超边可以连接任意数量的节点。

在无向图中,最流行的算法可能是连通分量,或者将图划分为子图,其中任何两个顶点都通过路径相互连接。划分对于并行化图上的操作很重要。

Google 和其他搜索引擎使 PageRank 变得流行。根据 Google 的说法,PageRank 通过计算指向一个页面的链接数量和质量来估计网站的重要性。其基本假设是,更重要的网站更有可能从其他网站(尤其是排名更高的网站)那里获得更多链接。PageRank 可以应用于网站排名以外的许多问题,并且相当于寻找连通矩阵的特征向量和最重要的特征值。

最基本的、非平凡的子图由三个节点组成。三角形计数找出所有可能的完全连接(或完整)的节点三元组,这是在社区检测和 CAD 中使用的另一个众所周知的算法。

是一个完全连接的子图。强连通分量是针对有向图的类似概念:子图中的每个顶点都可以从其他每个顶点访问。GraphX 提供了这两种实现的实现。

最后,推荐图是连接两种类型节点的图:用户和项目。边可以包含推荐的强度或满意度的度量。推荐系统的目标是预测可能缺失的边的满意度。已经为推荐引擎开发了多种算法,例如 SVD 和 SVD++,这些算法在本章的末尾进行讨论。

SBT

每个人都喜欢 Scala REPL。REPL 是 Scala 的命令行。它允许你输入立即评估的 Scala 表达式,尝试和探索事物。正如你在前面的章节中看到的,你可以在命令提示符中简单地输入 scala 并开始开发复杂的数据管道。更方便的是,你可以按 tab 键进行自动完成,这是任何成熟现代 IDE(如 Eclipse 或 IntelliJ,Ctrl +. 或 Ctrl + Space)的必需功能,通过跟踪命名空间和使用反射机制。我们为什么需要一个额外的构建工具或框架呢,尤其是 Ant、Maven 和 Gradle 等其他构建管理框架已经存在于 IDE 之外?正如 SBT 的作者所争论的,尽管一个人可能使用前面的工具编译 Scala,但所有这些工具在交互性和 Scala 构建的再现性方面都有效率低下的问题(Joshua SuerethMatthew FarwellSBT in Action,2015 年 11 月)。

对于我来说,SBT 的一个主要特性是交互性和能够无缝地与多个版本的 Scala 和依赖库一起工作。最终,对于软件开发来说,关键在于能够多快地原型设计和测试新想法。我曾经在大型机使用穿孔卡片工作,程序员们等待执行他们的程序和想法,有时需要几个小时甚至几天。计算机的效率更为重要,因为这是瓶颈。那些日子已经过去了,现在个人笔记本电脑的计算能力可能比几十年前满屋的服务器还要强大。为了利用这种效率,我们需要通过加快程序开发周期来更有效地利用人的时间,这也意味着交互性和仓库中的更多版本。

除了处理多个版本和 REPL 的能力之外,SBT 的主要特性如下:

  • 原生支持编译 Scala 代码以及与多个测试框架集成,包括 JUnit、ScalaTest 和 Selenium

  • 使用 DSL 编写的 Scala 构建描述

  • 使用 Ivy 进行依赖管理(同时支持 Maven 格式仓库)

  • 持续执行、编译、测试和部署

  • 与 Scala 解释器集成以实现快速迭代和调试

  • 支持混合 Java/Scala 项目

  • 支持测试和部署框架

  • 能够通过自定义插件来补充工具

  • 任务并行执行

SBT 是用 Scala 编写的,并使用 SBT 来构建自身(自举或自用)。SBT 已成为 Scala 社区的默认构建工具,并被LiftPlay框架所使用。

虽然您可以从www.scala-sbt.org/download直接下载 SBT,但要在 Mac 上安装 SBT 最简单的方法是运行 MacPorts:

$ port install sbt

您还可以运行 Homebrew:

$ brew install sbt

虽然存在其他工具来创建 SBT 项目,但最直接的方法是在为每个章节提供的 GitHub 项目存储库中运行bin/create_project.sh脚本:

$ bin/create_project.sh

这将创建主和测试源子目录(但不包括代码)。项目目录包含项目范围内的设置(请参阅project/build.properties)。目标将包含编译后的类和构建包(目录将包含不同版本的 Scala 的不同子目录,例如 2.10 和 2.11)。最后,任何放入lib目录的 jar 或库都将在整个项目中可用(我建议在build.sbt文件中使用libraryDependencies机制,但并非所有库都可通过集中式存储库获得)。这是最小设置,目录结构可能包含多个子项目。Scalastyle 插件甚至会为您检查语法(www.scalastyle.org/sbt.html)。只需添加project/plugin.sbt

$ cat >> project.plugin.sbt << EOF
addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.8.0")
EOF

最后,SBT 使用sdbt doc命令创建 Scaladoc 文档。

注意

build.sbt 中的空白行和其他设置

可能大多数的build.sbt文件都是双倍间距:这是旧版本的遗留。您不再需要它们。从版本 0.13.7 开始,定义不需要额外的行。

build.sbtbuild.properties 中,你可以使用许多其他设置,最新的文档可在 www.scala-sbt.org/documentation.html 查找。

当从命令行运行时,工具将自动下载并使用依赖项,在这种情况下,是 graph-{core,constrained,json}lift-json。为了运行项目,只需输入 sbt run

在连续模式下,SBT 将自动检测源文件的变化并重新运行命令。为了在启动 REPL 后连续编译和运行代码,请在 REPL 启动后输入~~ run

要获取命令的帮助,请运行以下命令:

$ sbt
 [info] Loading global plugins from /Users/akozlov/.sbt/0.13/plugins
[info] Set current project to My Graph Project (in build file:/Users/akozlov/Scala/graph/)
> help

 help                                    Displays this help message or prints detailed help on requested commands (run 'help <command>').
For example, `sbt package` will build a Java jar, as follows:
$  sbt package
[info] Loading global plugins from /Users/akozlov/.sbt/0.13/plugins
[info] Loading project definition from /Users/akozlov/Scala/graph/project
[info] Set current project to My Graph Project (in build file:/Users/akozlov/Scala/graph/)
[info] Updating {file:/Users/akozlov/Scala/graph/}graph...
[info] Resolving jline#jline;2.12.1 ...
[info] Done updating.
$ ls -1 target/scala-2.11/
classes
my-graph-project_2.11-1.0.jar

即使使用简单的编辑器如viEmacs,SBT 也足以满足我们的需求,但sbteclipse项目在github.com/typesafehub/sbteclipse中会创建必要的项目文件,以便与您的 Eclipse IDE 一起使用。

Scala 的 Graph

对于这个项目,我将创建一个src/main/scala/InfluenceDiagram.scala文件。为了演示目的,我将仅重新创建来自第二章 数据管道和建模的图:

import scalax.collection.Graph
import scalax.collection.edge._
import scalax.collection.GraphPredef._
import scalax.collection.GraphEdge._

import scalax.collection.edge.Implicits._

object InfluenceDiagram extends App {
  var g = GraphString, LDiEdge("Forecast"), ("'Weather Forecast'"~+>"'Vacation Activity'")("Decision"), ("'Vacation Activity'"~+>"'Satisfaction'")("Deterministic"), ("'Weather'"~+>"'Satisfaction'")("Deterministic"))
  println(g.mkString(";"))
  println(g.isDirected)
  println(g.isAcyclic)
}

~+>运算符用于在scalax/collection/edge/Implicits.scala中定义的两个节点之间创建一个有向标签边,在我们的案例中,这些节点是String类型。其他边类型和运算符的列表如下表所示:

以下表格显示了来自scalax.collection.edge.Implicits(来自www.scala-graph.org/guides/core-initializing.html)的图边

边类 速记/运算符 描述
超边
HyperEdge ~ 超边
WHyperEdge ~% 加权超边
WkHyperEdge ~%# 关键加权超边
LHyperEdge ~+ 标签超边
LkHyperEdge ~+# 关键标签超边
WLHyperEdge ~%+ 加权标签超边
WkLHyperEdge ~%#+ 关键加权标签超边
WLkHyperEdge ~%+# 加权关键标签超边
WkLkHyperEdge ~%#+# 关键加权关键标签超边
有向超边
DiHyperEdge ~> 有向超边
WDiHyperEdge ~%> 加权有向超边
WkDiHyperEdge ~%#> 关键加权有向超边
LDiHyperEdge ~+> 标签有向超边
LkDiHyperEdge ~+#> 关键标签有向超边
WLDiHyperEdge ~%+> 加权标签有向超边
WkLDiHyperEdge ~%#+> 关键加权标签有向超边
WLkDiHyperEdge ~%+#> 加权关键标签有向超边
WkLkDiHyperEdge ~%#+#> 关键加权关键标签有向超边
无向边
UnDiEdge ~ 无向边
WUnDiEdge ~% 加权无向边
WkUnDiEdge ~%# 关键加权无向边
LUnDiEdge ~+ 标签无向边
LkUnDiEdge ~+# 关键标签无向边
WLUnDiEdge ~%+ 加权标签无向边
WkLUnDiEdge ~%#+ 关键加权标签无向边
WLkUnDiEdge ~%+# 加权关键标签无向边
WkLkUnDiEdge ~%#+# 关键加权关键标签无向边
有向边
DiEdge ~> 有向边
WDiEdge ~%> 加权有向边
WkDiEdge ~%#> 关键加权有向边
LDiEdge ~+> 标签有向边
LkDiEdge ~+#> 关键标签有向边
WLDiEdge ~%+> 加权标签有向边
WkLDiEdge ~%#+> 关键加权标签有向边
WLkDiEdge ~%+#> 加权关键标签有向边
WkLkDiEdge ~%#+#> 关键加权关键标签有向边

你已经看到了 Scala 中图的力量:边可以有权重,我们可能构建一个多图(键标签边允许一对源节点和目标节点有多个边)。

如果你使用 SBT 在src/main/scala目录中的 Scala 文件运行前面的项目,输出将如下所示:

[akozlov@Alexanders-MacBook-Pro chapter07(master)]$ sbt
[info] Loading project definition from /Users/akozlov/Src/Book/ml-in-scala/chapter07/project
[info] Set current project to Working with Graph Algorithms (in build file:/Users/akozlov/Src/Book/ml-in-scala/chapter07/)
> run
[warn] Multiple main classes detected.  Run 'show discoveredMainClasses' to see the list

Multiple main classes detected, select one to run:

 [1] org.akozlov.chapter07.ConstranedDAG
 [2] org.akozlov.chapter07.EnronEmail
 [3] org.akozlov.chapter07.InfluenceDiagram
 [4] org.akozlov.chapter07.InfluenceDiagramToJson

Enter number: 3

[info] Running org.akozlov.chapter07.InfluenceDiagram 
'Weather';'Vacation Activity';'Satisfaction';'Weather Forecast';'Weather'~>'Weather Forecast' 'Forecast;'Weather'~>'Satisfaction' 'Deterministic;'Vacation Activity'~>'Satisfaction' 'Deterministic;'Weather Forecast'~>'Vacation Activity' 'Decision
Directed: true
Acyclic: true
'Weather';'Vacation Activity';'Satisfaction';'Recommend to a Friend';'Weather Forecast';'Weather'~>'Weather Forecast' 'Forecast;'Weather'~>'Satisfaction' 'Deterministic;'Vacation Activity'~>'Satisfaction' 'Deterministic;'Satisfaction'~>'Recommend to a Friend' 'Probabilistic;'Weather Forecast'~>'Vacation Activity' 'Decision
Directed: true
Acyclic: true

如果启用了连续编译,主方法将在 SBT 检测到文件已更改时立即运行(如果有多个类具有主方法,SBT 将询问你想要运行哪一个,这对交互性来说不是很好;因此,你可能想要限制可执行类的数量)。

我将在短时间内介绍不同的输出格式,但首先让我们看看如何在图上执行简单操作。

添加节点和边

首先,我们已经知道图是有向和无环的,这是所有决策图所需的一个属性,这样我们知道我们没有犯错误。假设我想使图更复杂,并添加一个节点来表示我向另一个人推荐俄勒冈州波特兰度假的可能性。我需要添加的只是以下这一行:

g += ("'Satisfaction'" ~+> "'Recommend to a Friend'")("Probabilistic")

如果你启用了连续编译/运行,按下保存文件按钮后你将立即看到变化:

'Weather';'Vacation Activity';'Satisfaction';'Recommend to a Friend';'Weather Forecast';'Weather'~>'Weather Forecast' 'Forecast;'Weather'~>'Satisfaction' 'Deterministic;'Vacation Activity'~>'Satisfaction' 'Deterministic;'Satisfaction'~>'Recommend to a Friend' 'Probabilistic;'Weather Forecast'~>'Vacation Activity' 'Decision
Directed: true
Acyclic: true

现在,如果我们想知道新引入的节点的父节点,我们可以简单地运行以下代码:

println((g get "'Recommend to a Friend'").incoming)

Set('Satisfaction'~>'Recommend to a Friend' 'Probabilistic)

这将为我们提供一个特定节点的父节点集——从而驱动决策过程。如果我们添加一个循环,无环方法将自动检测到:

g += ("'Satisfaction'" ~+> "'Weather'")("Cyclic")
println(g.mkString(";")) println("Directed: " + g.isDirected)
println("Acyclic: " + g.isAcyclic)

'Weather';'Vacation Activity';'Satisfaction';'Recommend to a Friend';'Weather Forecast';'Weather'~>'Weather Forecast' 'Forecast;'Weather'~>'Satisfaction' 'Deterministic;'Vacation Activity'~>'Satisfaction' 'Deterministic;'Satisfaction'~>'Recommend to a Friend' 'Probabilistic;'Satisfaction'~>'Weather' 'Cyclic;'Weather Forecast'~>'Vacation Activity' 'Decision
Directed: true
Acyclic: false

注意,你可以完全通过编程创建图:

 var n, m = 0; val f = Graph.fill(45){ m = if (m < 9) m + 1 else { n = if (n < 8) n + 1 else 8; n + 1 }; m ~ n }

  println(f.nodes)
  println(f.edges)
  println(f)

  println("Directed: " + f.isDirected)
  println("Acyclic: " + f.isAcyclic)

NodeSet(0, 9, 1, 5, 2, 6, 3, 7, 4, 8)
EdgeSet(9~0, 9~1, 9~2, 9~3, 9~4, 9~5, 9~6, 9~7, 9~8, 1~0, 5~0, 5~1, 5~2, 5~3, 5~4, 2~0, 2~1, 6~0, 6~1, 6~2, 6~3, 6~4, 6~5, 3~0, 3~1, 3~2, 7~0, 7~1, 7~2, 7~3, 7~4, 7~5, 7~6, 4~0, 4~1, 4~2, 4~3, 8~0, 8~1, 8~2, 8~3, 8~4, 8~5, 8~6, 8~7)
Graph(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1~0, 2~0, 2~1, 3~0, 3~1, 3~2, 4~0, 4~1, 4~2, 4~3, 5~0, 5~1, 5~2, 5~3, 5~4, 6~0, 6~1, 6~2, 6~3, 6~4, 6~5, 7~0, 7~1, 7~2, 7~3, 7~4, 7~5, 7~6, 8~0, 8~1, 8~2, 8~3, 8~4, 8~5, 8~6, 8~7, 9~0, 9~1, 9~2, 9~3, 9~4, 9~5, 9~6, 9~7, 9~8)
Directed: false
Acyclic: false

在这里,作为填充方法第二个参数提供的元素计算被重复45次(第一个参数)。图将每个节点连接到其所有前驱节点,这在图论中也称为团。

图约束

Graph for Scala 使我们能够设置任何未来的图更新都不能违反的约束。当我们想要保留图结构中的某些细节时,这很有用。例如,有向无环图(DAG)不应包含循环。目前有两个约束作为scalax.collection.constrained.constraints包的一部分实现——连通和无环,如下所示:

package org.akozlov.chapter07

import scalax.collection.GraphPredef._, scalax.collection.GraphEdge._
import scalax.collection.constrained.{Config, ConstraintCompanion, Graph => DAG}
import scalax.collection.constrained.constraints.{Connected, Acyclic}

object AcyclicWithSideEffect extends ConstraintCompanion[Acyclic] {
  def apply [N, E[X] <: EdgeLikeIn[X]] (self: DAG[N,E]) =
    new Acyclic[N,E] (self) {
      override def onAdditionRefused(refusedNodes: Iterable[N],
        refusedEdges: Iterable[E[N]],
        graph:        DAG[N,E]) = {
          println("Addition refused: " + "nodes = " + refusedNodes + ", edges = " + refusedEdges)
          true
        }
    }
}

object ConnectedWithSideEffect extends ConstraintCompanion[Connected] {
  def apply [N, E[X] <: EdgeLikeIn[X]] (self: DAG[N,E]) =
    new Connected[N,E] (self) {
      override def onSubtractionRefused(refusedNodes: Iterable[DAG[N,E]#NodeT],
        refusedEdges: Iterable[DAG[N,E]#EdgeT],
        graph:        DAG[N,E]) = {
          println("Subtraction refused: " + "nodes = " + refusedNodes + ", edges = " + refusedEdges)
        true
      }
    }
}

class CycleException(msg: String) extends IllegalArgumentException(msg)
object ConstranedDAG extends App {
  implicit val conf: Config = ConnectedWithSideEffect && AcyclicWithSideEffect
  val g = DAG(1~>2, 1~>3, 2~>3, 3~>4) // Graph()
  println(g ++ List(1~>4, 3~>1))
  println(g - 2~>3)
  println(g - 2)
  println((g + 4~>5) - 3)
}

这是运行尝试添加或删除违反约束的节点的程序的命令:

[akozlov@Alexanders-MacBook-Pro chapter07(master)]$ sbt "run-main org.akozlov.chapter07.ConstranedDAG"
[info] Loading project definition from /Users/akozlov/Src/Book/ml-in-scala/chapter07/project
[info] Set current project to Working with Graph Algorithms (in build file:/Users/akozlov/Src/Book/ml-in-scala/chapter07/)
[info] Running org.akozlov.chapter07.ConstranedDAG 
Addition refused: nodes = List(), edges = List(1~>4, 3~>1)
Graph(1, 2, 3, 4, 1~>2, 1~>3, 2~>3, 3~>4)
Subtraction refused: nodes = Set(), edges = Set(2~>3)
Graph(1, 2, 3, 4, 1~>2, 1~>3, 2~>3, 3~>4)
Graph(1, 3, 4, 1~>3, 3~>4)
Subtraction refused: nodes = Set(3), edges = Set()
Graph(1, 2, 3, 4, 5, 1~>2, 1~>3, 2~>3, 3~>4, 4~>5)
[success] Total time: 1 s, completed May 1, 2016 1:53:42 PM 

添加或减去违反约束之一的节点将被拒绝。如果尝试添加或减去违反条件的节点,程序员还可以指定副作用。

JSON

Graph for Scala 支持将图导入/导出到 JSON,如下所示:

object InfluenceDiagramToJson extends App {

  val g = GraphString,LDiEdge("Forecast"), ("'Weather Forecast'" ~+> "'Vacation Activity'")("Decision"), ("'Vacation Activity'" ~+> "'Satisfaction'")("Deterministic"), ("'Weather'" ~+> "'Satisfaction'")("Deterministic"), ("'Satisfaction'" ~+> "'Recommend to a Friend'")("Probabilistic"))

  import scalax.collection.io.json.descriptor.predefined.{LDi}
  import scalax.collection.io.json.descriptor.StringNodeDescriptor
  import scalax.collection.io.json._

  val descriptor = new DescriptorString
  )

  val n = g.toJson(descriptor)
  println(n)
  import net.liftweb.json._
  println(Printer.pretty(JsonAST.render(JsonParser.parse(n))))
}

要为示例图生成 JSON 表示,请运行:

[kozlov@Alexanders-MacBook-Pro chapter07(master)]$ sbt "run-main org.akozlov.chapter07.InfluenceDiagramToJson"
[info] Loading project definition from /Users/akozlov/Src/Book/ml-in-scala/chapter07/project
[info] Set current project to Working with Graph Algorithms (in build file:/Users/akozlov/Src/Book/ml-in-scala/chapter07/)
[info] Running org.akozlov.chapter07.InfluenceDiagramToJson 
{
 "nodes":[["'Recommend to a Friend'"],["'Satisfaction'"],["'Vacation Activity'"],["'Weather Forecast'"],["'Weather'"]],
 "edges":[{
 "n1":"'Weather'",
 "n2":"'Weather Forecast'",
 "label":"Forecast"
 },{
 "n1":"'Vacation Activity'",
 "n2":"'Satisfaction'",
 "label":"Deterministic"
 },{
 "n1":"'Weather'",
 "n2":"'Satisfaction'",
 "label":"Deterministic"
 },{
 "n1":"'Weather Forecast'",
 "n2":"'Vacation Activity'",
 "label":"Decision"
 },{
 "n1":"'Satisfaction'",
 "n2":"'Recommend to a Friend'",
 "label":"Probabilistic"
 }]
}
[success] Total time: 1 s, completed May 1, 2016 1:55:30 PM

对于更复杂的结构,可能需要编写自定义描述符、序列化和反序列化程序(参考www.scala-graph.org/api/json/api/#scalax.collection.io.json.package)。

GraphX

虽然 Graph for Scala 可能被认为是图操作和查询的 DSL,但应该转向 GraphX 以实现可扩展性。GraphX 建立在强大的 Spark 框架之上。作为一个 Spark/GraphX 操作的例子,我将使用 CMU Enron 电子邮件数据集(约 2 GB)。实际上,对电子邮件内容的语义分析对我们来说直到下一章才重要。数据集可以从 CMU 网站下载。它包含 150 个用户(主要是 Enron 经理)的电子邮件,以及他们之间大约 517,401 封电子邮件。这些电子邮件可以被视为两个人之间关系(边)的指示:每封电子邮件都是一个源(From:)和目标(To:)顶点的边。

由于 GraphX 需要 RDD 格式的数据,我必须做一些预处理。幸运的是,使用 Scala 来做这一点非常简单——这就是为什么 Scala 是半结构化数据的完美语言。以下是代码:

package org.akozlov.chapter07

import scala.io.Source

import scala.util.hashing.{MurmurHash3 => Hash}
import scala.util.matching.Regex

import java.util.{Date => javaDateTime}

import java.io.File
import net.liftweb.json._
import Extraction._
import Serialization.{read, write}

object EnronEmail {

  val emailRe = """[a-zA-Z0-9_.+\-]+@enron.com""".r.unanchored

  def emails(s: String) = {
    for (email <- emailRe findAllIn s) yield email
  }

  def hash(s: String) = {
    java.lang.Integer.MAX_VALUE.toLong + Hash.stringHash(s)
  }

  val messageRe =
    """(?:Message-ID:\s+)(<[A-Za-z0-9_.+\-@]+>)(?s)(?:.*?)(?m)
      |(?:Date:\s+)(.*?)$(?:.*?)
      |(?:From:\s+)([a-zA-Z0-9_.+\-]+@enron.com)(?:.*?)
      |(?:Subject: )(.*?)$""".stripMargin.r.unanchored

  case class Relation(from: String, fromId: Long, to: String, toId: Long, source: String, messageId: String, date: javaDateTime, subject: String)

  implicit val formats = Serialization.formats(NoTypeHints)

  def getFileTree(f: File): Stream[File] =
    f #:: (if (f.isDirectory) f.listFiles().toStream.flatMap(getFileTree) else Stream.empty)

  def main(args: Array[String]) {
    getFileTree(new File(args(0))).par.map {
      file => {
        "\\.$".r findFirstIn file.getName match {
          case Some(x) =>
          try {
            val src = Source.fromFile(file, "us-ascii")
            val message = try src.mkString finally src.close()
            message match {
              case messageRe(messageId, date, from , subject) =>
              val fromLower = from.toLowerCase
              for (to <- emails(message).filter(_ != fromLower).toList.distinct)
              println(write(Relation(fromLower, hash(fromLower), to, hash(to), file.toString, messageId, new javaDateTime(date), subject)))
                case _ =>
            }
          } catch {
            case e: Exception => System.err.println(e)
          }
          case _ =>
        }
      }
    }
  }
}

首先,我们使用MurmurHash3类生成节点 ID,它们是Long类型,因为它们在 GraphX 中的每个节点都是必需的。emailRemessageRe用于将文件内容与所需内容匹配。Scala 允许您在不费太多功夫的情况下并行化程序。

注意第 50 行的par调用,getFileTree(new File(args(0))).par.map。这将使循环并行化。即使是在 3 GHz 处理器上处理整个 Enron 数据集也可能需要长达一个小时,但添加并行化可以在一个 32 核心的 Intel Xeon E5-2630 2.4 GHz CPU Linux 机器上减少大约 8 分钟(在 2.3 GHz Intel Core i7 的 Apple MacBook Pro 上需要 15 分钟)。

运行代码将生成一组 JSON 记录,这些记录可以加载到 Spark 中(要运行它,您需要在类路径上放置joda-timelift-json库 jar 文件),如下所示:

# (mkdir Enron; cd Enron; wget -O - http://www.cs.cmu.edu/~./enron/enron_mail_20150507.tgz | tar xzvf -)
...
# sbt --error "run-main org.akozlov.chapter07.EnronEmail Enron/maildir" > graph.json

# spark --driver-memory 2g --executor-memory 2g
...
scala> val df = sqlContext.read.json("graph.json")
df: org.apache.spark.sql.DataFrame = [[date: string, from: string, fromId: bigint, messageId: string, source: string, subject: string, to: string, toId: bigint]

很好!Spark 能够自己确定字段和类型。如果 Spark 无法解析所有记录,则会有一个_corrupt_record字段包含未解析的记录(其中一个是数据集末尾的[success]行,可以使用grep -Fv [success]过滤掉)。您可以使用以下命令查看它们:

scala> df.select("_corrupt_record").collect.foreach(println)
...

节点(人)和边(关系)数据集可以通过以下命令提取:

scala> import org.apache.spark._
...
scala> import org.apache.spark.graphx._
...
scala> import org.apache.spark.rdd.RDD
...
scala> val people: RDD[(VertexId, String)] = df.select("fromId", "from").unionAll(df.select("toId", "to")).na.drop.distinct.map( x => (x.get(0).toString.toLong, x.get(1).toString))
people: org.apache.spark.rdd.RDD[(org.apache.spark.graphx.VertexId, String)] = MapPartitionsRDD[146] at map at <console>:28

scala> val relationships = df.select("fromId", "toId", "messageId", "subject").na.drop.distinct.map( x => Edge(x.get(0).toString.toLong, x.get(1).toString.toLong, (x.get(2).toString, x.get(3).toString)))
relationships: org.apache.spark.rdd.RDD[org.apache.spark.graphx.Edge[(String, String)]] = MapPartitionsRDD[156] at map at <console>:28

scala> val graph = Graph(people, relationships).cache
graph: org.apache.spark.graphx.Graph[String,(String, String)] = org.apache.spark.graphx.impl.GraphImpl@7b59aa7b

注意

GraphX 中的节点 ID

正如我们在 Graph for Scala 中看到的,指定边就足以定义节点和图。在 Spark/GraphX 中,节点需要显式提取,并且每个节点都需要与一个Long类型的n ID 相关联。虽然这可能会限制灵活性和唯一节点的数量,但它提高了效率。在这个特定的例子中,将电子邮件字符串的哈希值作为节点 ID 是足够的,因为没有检测到冲突,但生成唯一 ID 通常是一个难以并行化的难题。

第一个 GraphX 图已经准备好了!它比 Scala 的 Graph 多花了一些时间,但现在它完全准备好进行分布式处理了。需要注意几点:首先,我们需要明确地将字段转换为LongString,因为Edge构造函数需要帮助确定类型。其次,Spark 可能需要优化分区数量(很可能创建了太多的分区):

scala> graph.vertices.getNumPartitions
res1: Int = 200

scala> graph.edges.getNumPartitions
res2: Int = 200

要重新分区,有两个调用:repartition 和 coalesce。后者试图避免 shuffle,如下所示:

scala> val graph = Graph(people.coalesce(6), relationships.coalesce(6))
graph: org.apache.spark.graphx.Graph[String,(String, String)] = org.apache.spark.graphx.impl.GraphImpl@5dc7d016

scala> graph.vertices.getNumPartitions
res10: Int = 6

scala> graph.edges.getNumPartitions
res11: Int = 6

然而,如果在大集群上执行计算,这可能会限制并行性。最后,使用cache方法将数据结构固定在内存中是个好主意:

scala> graph.cache
res12: org.apache.spark.graphx.Graph[String,(String, String)] = org.apache.spark.graphx.impl.GraphImpl@5dc7d016

在 Spark 中构建图需要更多的命令,但四个并不是太多。让我们计算一些统计数据(并展示 Spark/GraphX 的强大功能,如下表所示:

在 Enron 电子邮件图上计算基本统计数据。

统计信息 Spark 命令 Enron 的值
总关系数(成对沟通) graph.numEdges 3,035,021
电子邮件(消息 ID)数量 graph.edges.map( e => e.attr._1 ).distinct.count 371,135
连接对数 graph.edges.flatMap( e => List((e.srcId, e.dstId), (e.dstId, e.srcId))).distinct.count / 2 217,867
单向沟通数 graph.edges.flatMap( e => List((e.srcId, e.dstId), (e.dstId, e.srcId))).distinct.count - graph.edges.map( e => (e.srcId, e.dstId)).distinct.count 193,183
不同的主题行数 graph.edges.map( e => e.attr._2 ).distinct.count 110,273
节点总数 graph.numVertices 23,607
仅目标节点数 graph. numVertices - graph.edges.map( e => e.srcId).distinct.count 17,264
仅源节点数 graph. numVertices - graph.edges.map( e => e.dstId).distinct.count 611

谁在接收电子邮件?

评估一个人在组织中的重要性最直接的方法之一是查看连接数或进出沟通的数量。GraphX 图内置了inDegreesoutDegrees方法。要按收到的邮件数量对电子邮件进行排名,请运行:

scala> people.join(graph.inDegrees).sortBy(_._2._2, ascending=false).take(10).foreach(println)
(268746271,(richard.shapiro@enron.com,18523))
(1608171805,(steven.kean@enron.com,15867))
(1578042212,(jeff.dasovich@enron.com,13878))
(960683221,(tana.jones@enron.com,13717))
(3784547591,(james.steffes@enron.com,12980))
(1403062842,(sara.shackleton@enron.com,12082))
(2319161027,(mark.taylor@enron.com,12018))
(969899621,(mark.guzman@enron.com,10777))
(1362498694,(geir.solberg@enron.com,10296))
(4151996958,(ryan.slinger@enron.com,10160))

要按发出的电子邮件数量对电子邮件进行排名,请运行:

scala> people.join(graph.outDegrees).sortBy(_._2._2, ascending=false).take(10).foreach(println)
(1578042212,(jeff.dasovich@enron.com,139786))
(2822677534,(veronica.espinoza@enron.com,106442))
(3035779314,(pete.davis@enron.com,94666))
(2346362132,(rhonda.denton@enron.com,90570))
(861605621,(cheryl.johnson@enron.com,74319))
(14078526,(susan.mara@enron.com,58797))
(2058972224,(jae.black@enron.com,58718))
(871077839,(ginger.dernehl@enron.com,57559))
(3852770211,(lorna.brennan@enron.com,50106))
(241175230,(mary.hain@enron.com,40425))
…

让我们在 Enron 数据集上应用一些更复杂的算法。

连接组件

连接组件确定图是否自然地分为几个部分。在 Enron 关系图中,这意味着两个或多个群体主要相互沟通:

scala> val groups = org.apache.spark.graphx.lib.ConnectedComponents.run(graph).vertices.map(_._2).distinct.cache
groups: org.apache.spark.rdd.RDD[org.apache.spark.graphx.VertexId] = MapPartitionsRDD[2404] at distinct at <console>:34

scala> groups.count
res106: Long = 18

scala> people.join(groups.map( x => (x, x))).map(x => (x._1, x._2._1)).sortBy(_._1).collect.foreach(println)
(332133,laura.beneville@enron.com)
(81833994,gpg.me-q@enron.com)
(115247730,dl-ga-enron_debtor@enron.com)
(299810291,gina.peters@enron.com)
(718200627,techsupport.notices@enron.com)
(847455579,paul.de@enron.com)
(919241773,etc.survey@enron.com)
(1139366119,enron.global.services.-.us@enron.com)
(1156539970,shelley.ariel@enron.com)
(1265773423,dl-ga-all_ews_employees@enron.com)
(1493879606,chairman.ees@enron.com)
(1511379835,gary.allen.-.safety.specialist@enron.com)
(2114016426,executive.robert@enron.com)
(2200225669,ken.board@enron.com)
(2914568776,ge.americas@enron.com)
(2934799198,yowman@enron.com)
(2975592118,tech.notices@enron.com)
(3678996795,mail.user@enron.com)

我们看到了 18 个群体。每个群体都可以通过过滤 ID 进行计数和提取。例如,与<etc.survey@enron.com>关联的群体可以通过在 DataFrame 上运行 SQL 查询来找到:

scala> df.filter("fromId = 919241773 or toId = 919241773").select("date","from","to","subject","source").collect.foreach(println)
[2000-09-19T18:40:00.000Z,survey.test@enron.com,etc.survey@enron.com,NO ACTION REQUIRED - TEST,Enron/maildir/dasovich-j/all_documents/1567.]
[2000-09-19T18:40:00.000Z,survey.test@enron.com,etc.survey@enron.com,NO ACTION REQUIRED - TEST,Enron/maildir/dasovich-j/notes_inbox/504.]

这个群体基于 2000 年 9 月 19 日发送的一封电子邮件,从 <survey.test@enron.com> 发送到 <etc.survey@enron>。这封电子邮件被列出了两次,仅仅是因为它最终落入了两个不同的文件夹(并且有两个不同的消息 ID)。只有第一个群体,最大的子图,包含组织中的超过两个电子邮件地址。

三角形计数

三角形计数算法相对简单,可以按以下三个步骤计算:

  1. 计算每个顶点的邻居集合。

  2. 对于每条边,计算集合的交集并将计数发送到两个顶点。

  3. 在每个顶点计算总和,然后除以二,因为每个三角形被计算了两次。

我们需要将多重图转换为具有 srcId < dstId 的无向图,这是算法的一个先决条件:

scala> val unedges = graph.edges.map(e => if (e.srcId < e.dstId) (e.srcId, e.dstId) else (e.dstId, e.srcId)).map( x => Edge(x._1, x._2, 1)).cache
unedges: org.apache.spark.rdd.RDD[org.apache.spark.graphx.Edge[Int]] = MapPartitionsRDD[87] at map at <console>:48

scala> val ungraph = Graph(people, unedges).partitionBy(org.apache.spark.graphx.PartitionStrategy.EdgePartition1D, 10).cache
ungraph: org.apache.spark.graphx.Graph[String,Int] = org.apache.spark.graphx.impl.GraphImpl@77274fff

scala> val triangles = org.apache.spark.graphx.lib.TriangleCount.run(ungraph).cache
triangles: org.apache.spark.graphx.Graph[Int,Int] = org.apache.spark.graphx.impl.GraphImpl@6aec6da1

scala> people.join(triangles.vertices).map(t => (t._2._2,t._2._1)).sortBy(_._1, ascending=false).take(10).foreach(println)
(31761,sally.beck@enron.com)
(24101,louise.kitchen@enron.com)
(23522,david.forster@enron.com)
(21694,kenneth.lay@enron.com)
(20847,john.lavorato@enron.com)
(18460,david.oxley@enron.com)
(17951,tammie.schoppe@enron.com)
(16929,steven.kean@enron.com)
(16390,tana.jones@enron.com)
(16197,julie.clyatt@enron.com)

尽管三角形计数与组织中人们的重要性之间没有直接关系,但具有更高三角形计数的那些人可能更具社交性——尽管 clique 或强连通分量计数可能是一个更好的衡量标准。

强连通分量

在有向图的数学理论中,如果一个子图中的每个顶点都可以从另一个顶点到达,那么这个子图被称为强连通。可能整个图只是一个强连通分量,但另一方面,每个顶点可能就是它自己的连通分量。

如果将每个连通分量收缩为一个单一点,你会得到一个新的有向图,它具有无环的性质——无环。

SCC 检测算法已经内置到 GraphX 中:

scala> val components = org.apache.spark.graphx.lib.StronglyConnectedComponents.run(graph, 100).cache
components: org.apache.spark.graphx.Graph[org.apache.spark.graphx.VertexId,(String, String)] = org.apache.spark.graphx.impl.GraphImpl@55913bc7

scala> components.vertices.map(_._2).distinct.count
res2: Long = 17980

scala> people.join(components.vertices.map(_._2).distinct.map( x => (x, x))).map(x => (x._1, x._2._1)).sortBy(_._1).collect.foreach(println)
(332133,laura.beneville@enron.com) 
(466265,medmonds@enron.com)
(471258,.jane@enron.com)
(497810,.kimberly@enron.com)
(507806,aleck.dadson@enron.com)
(639614,j..bonin@enron.com)
(896860,imceanotes-hbcamp+40aep+2ecom+40enron@enron.com)
(1196652,enron.legal@enron.com)
(1240743,thi.ly@enron.com)
(1480469,ofdb12a77a.a6162183-on86256988.005b6308@enron.com)
(1818533,fran.i.mayes@enron.com)
(2337461,michael.marryott@enron.com)
(2918577,houston.resolution.center@enron.com)

有 18,200 个强连通分量,每个组平均只有 23,787/18,200 = 1.3 个用户。

PageRank

PageRank 算法通过分析链接(在这种情况下是电子邮件)来估计一个人的重要性。例如,让我们在 Enron 电子邮件图中运行 PageRank:

scala> val ranks = graph.pageRank(0.001).vertices
ranks: org.apache.spark.graphx.VertexRDD[Double] = VertexRDDImpl[955] at RDD at VertexRDD.scala:57

scala> people.join(ranks).map(t => (t._2._2,t._2._1)).sortBy(_._1, ascending=false).take(10).foreach(println)

scala> val ranks = graph.pageRank(0.001).vertices
ranks: org.apache.spark.graphx.VertexRDD[Double] = VertexRDDImpl[955] at RDD at VertexRDD.scala:57

scala> people.join(ranks).map(t => (t._2._2,t._2._1)).sortBy(_._1, ascending=false).take(10).foreach(println)
(32.073722548483325,tana.jones@enron.com)
(29.086568868043248,sara.shackleton@enron.com)
(28.14656912897315,louise.kitchen@enron.com)
(26.57894933459292,vince.kaminski@enron.com)
(25.865486865014493,sally.beck@enron.com)
(23.86746232662471,john.lavorato@enron.com)
(22.489814482022275,jeff.skilling@enron.com)
(21.968039409295585,mark.taylor@enron.com)
(20.903053536275547,kenneth.lay@enron.com)
(20.39124651779771,gerald.nemec@enron.com)

表面上,这些人就是目标人物。PageRank 倾向于强调入边,Tana Jones 相比于三角形计数的第 9 位,回到了列表的顶端。

SVD++

SVD++ 是一种推荐引擎算法,由 Yahuda Koren 和团队在 2008 年专门为 Netflix 竞赛开发——原始论文仍在公共领域,可以通过 Google 搜索 kdd08koren.pdf 获得。具体的实现来自 ZenoGarther 的 .NET MyMediaLite 库(github.com/zenogantner/MyMediaLite),他已将 Apache 2 许可证授予 Apache 基金会。假设我有一组用户(在左侧)和物品(在右侧):

SVD++

图 07-1. 将推荐问题作为二分图的一个图形表示。

上述图表是推荐问题的图形表示。左侧的节点代表用户。右侧的节点代表物品。用户1推荐物品AC,而用户23只推荐单个物品A。其余的边缺失。常见的问题是找到其余物品的推荐排名,边也可能附有权重或推荐强度。该图通常是稀疏的。这种图也常被称为二分图,因为边只从一个节点集到另一个节点集(用户不会推荐其他用户)。

对于推荐引擎,我们通常需要两种类型的节点——用户和物品。推荐基于(用户、物品和评分)元组的评分矩阵。推荐算法的一种实现是基于前面矩阵的奇异值分解SVD)。最终的评分有四个组成部分:基线,即整个矩阵的平均值、用户平均和物品平均,如下所示:

SVD++

在这里,SVD++SVD++SVD++可以理解为整个群体的平均值、用户(在所有用户推荐中)和物品(在所有用户中)。最后一部分是两行的笛卡尔积:

SVD++

该问题被表述为一个最小化问题(参见第四章):

SVD++

在这里,SVD++ 是在 第四章 中讨论的正则化系数,监督学习和无监督学习。因此,每个用户都与一组数字 (SVD++,以及每个项目与 SVD++SVD++ 相关联。在这个特定的实现中,最优系数是通过梯度下降法找到的。这是 SVD 优化的基础。在线性代数中,SVD 将一个任意的 SVD++ 矩阵 A 表示为一个正交 SVD++ 矩阵 U、一个对角 SVD++ 矩阵 SVD++ 和一个 SVD++ 单位矩阵 V 的乘积,例如,列是相互正交的。可以说,如果取 SVD++ 矩阵中最大的 SVD++ 个条目,乘积就简化为一个非常高 SVD++ 的矩阵和一个很宽 SVD++ 的矩阵的乘积,其中 SVD++ 被称为分解的秩。如果剩余的值很小,新的 SVD++ 数字就近似于原始 SVD++ 数字,对于关系 A。如果 mn 起初就很大,在实践中的在线购物场景中,m 是商品,可能有数十万,而 n 是用户,可能有数亿,这种节省可能是巨大的。例如,对于 r=10m=100,000,和 n=100,000,000,节省如下:

SVD++

SVD 也可以看作是针对 SVD++ 矩阵的 PCA。在 Enron 案例中,我们可以将发件人视为用户,收件人视为商品(我们需要重新分配节点 ID),如下所示:

scala> val rgraph = graph.partitionBy(org.apache.spark.graphx.PartitionStrategy.EdgePartition1D, 10).mapEdges(e => 1).groupEdges(_+_).cache
rgraph: org.apache.spark.graphx.Graph[String,Int] = org.apache.spark.graphx.impl.GraphImpl@2c1a48d6

scala> val redges = rgraph.edges.map( e => Edge(-e.srcId, e.dstId, Math.log(e.attr.toDouble)) ).cache
redges: org.apache.spark.rdd.RDD[org.apache.spark.graphx.Edge[Double]] = MapPartitionsRDD[57] at map at <console>:36

scala> import org.apache.spark.graphx.lib.SVDPlusPlus
import org.apache.spark.graphx.lib.SVDPlusPlus

scala> implicit val conf = new SVDPlusPlus.Conf(10, 50, 0.0, 10.0, 0.007, 0.007, 0.005, 0.015)
conf: org.apache.spark.graphx.lib.SVDPlusPlus.Conf = org.apache.spark.graphx.lib.SVDPlusPlus$Conf@15cdc117

scala> val (svd, mu) = SVDPlusPlus.run(redges, conf)
svd: org.apache.spark.graphx.Graph[(Array[Double], Array[Double], Double, Double),Double] = org.apache.spark.graphx.impl.GraphImpl@3050363d
mu: Double = 1.3773578970633769

scala> val svdRanks = svd.vertices.filter(_._1 > 0).map(x => (x._2._3, x._1))
svdRanks: org.apache.spark.rdd.RDD[(Double, org.apache.spark.graphx.VertexId)] = MapPartitionsRDD[1517] at map at <console>:31

scala> val svdRanks = svd.vertices.filter(_._1 > 0).map(x => (x._1, x._2._3))
svdRanks: org.apache.spark.rdd.RDD[(org.apache.spark.graphx.VertexId, Double)] = MapPartitionsRDD[1520] at map at <console>:31

scala> people.join(svdRanks).sortBy(_._2._2, ascending=false).map(x => (x._2._2, x._2._1)).take(10).foreach(println)
(8.864218804309887,jbryson@enron.com)
(5.935146713012661,dl-ga-all_enron_worldwide2@enron.com)
(5.740242927715701,houston.report@enron.com)
(5.441934324464593,a478079f-55e1f3b0-862566fa-612229@enron.com)
(4.910272928389445,pchoi2@enron.com)
(4.701529779800544,dl-ga-all_enron_worldwide1@enron.com)
(4.4046392452058045,eligible.employees@enron.com)
(4.374738019256556,all_ena_egm_eim@enron.com)
(4.303078586979311,dl-ga-all_enron_north_america@enron.com)
(3.8295412053860867,the.mailout@enron.com)

svdRanksSVD++ 预测的用户部分。分布列表具有优先级,因为这通常用于群发电子邮件。要获取用户特定的部分,我们需要提供用户 ID:

scala> import com.github.fommil.netlib.BLAS.{getInstance => blas}

scala> def topN(uid: Long, num: Int) = {
 |    val usr = svd.vertices.filter(uid == -_._1).collect()(0)._2
 |    val recs = svd.vertices.filter(_._1 > 0).map( v => (v._1, mu + usr._3 + v._2._3 + blas.ddot(usr._2.length, v._2._1, 1, usr._2, 1)))
 |    people.join(recs).sortBy(_._2._2, ascending=false).map(x => (x._2._2, x._2._1)).take(num)
 | }
topN: (uid: Long, num: Int)Array[(Double, String)]

scala> def top5(x: Long) : Array[(Double, String)] = topN(x, 5)
top5: (x: Long)Array[(Double, String)]

scala> people.join(graph.inDegrees).sortBy(_._2._2, ascending=false).map(x => (x._1, x._2._1)).take(10).toList.map(t => (t._2, top5(t._1).toList)).foreach(println)
(richard.shapiro@enron.com,List((4.866184418005094E66,anne.bertino@enron.com), (3.9246829664352734E66,kgustafs@enron.com), (3.9246829664352734E66,gweiss@enron.com), (3.871029763863491E66,hill@enron.com), (3.743135924382312E66,fraser@enron.com)))
(steven.kean@enron.com,List((2.445163626935533E66,anne.bertino@enron.com), (1.9584692804232504E66,hill@enron.com), (1.9105427465629028E66,kgustafs@enron.com), (1.9105427465629028E66,gweiss@enron.com), (1.8931872324048717E66,fraser@enron.com)))
(jeff.dasovich@enron.com,List((2.8924566115596135E66,anne.bertino@enron.com), (2.3157345904446663E66,hill@enron.com), (2.2646318970030287E66,gweiss@enron.com), (2.2646318970030287E66,kgustafs@enron.com), (2.2385865127706285E66,fraser@enron.com)))
(tana.jones@enron.com,List((6.1758464471309754E66,elizabeth.sager@enron.com), (5.279291610047078E66,tana.jones@enron.com), (4.967589820856654E66,tim.belden@enron.com), (4.909283344915057E66,jeff.dasovich@enron.com), (4.869177440115682E66,mark.taylor@enron.com)))
(james.steffes@enron.com,List((5.7702834706832735E66,anne.bertino@enron.com), (4.703038082326939E66,gweiss@enron.com), (4.703038082326939E66,kgustafs@enron.com), (4.579565962089777E66,hill@enron.com), (4.4298763869135494E66,george@enron.com)))
(sara.shackleton@enron.com,List((9.198688613290757E67,louise.kitchen@enron.com), (8.078107057848099E67,john.lavorato@enron.com), (6.922806078209984E67,greg.whalley@enron.com), (6.787266892881456E67,elizabeth.sager@enron.com), (6.420473603137515E67,sally.beck@enron.com)))
(mark.taylor@enron.com,List((1.302856119148208E66,anne.bertino@enron.com), (1.0678968544568682E66,hill@enron.com), (1.031255083546722E66,fraser@enron.com), (1.009319696608474E66,george@enron.com), (9.901391892701356E65,brad@enron.com)))
(mark.guzman@enron.com,List((9.770393472845669E65,anne.bertino@enron.com), (7.97370292724488E65,kgustafs@enron.com), (7.97370292724488E65,gweiss@enron.com), (7.751983820970696E65,hill@enron.com), (7.500175024539423E65,george@enron.com)))
(geir.solberg@enron.com,List((6.856103529420811E65,anne.bertino@enron.com), (5.611272903720188E65,gweiss@enron.com), (5.611272903720188E65,kgustafs@enron.com), (5.436280144720843E65,hill@enron.com), (5.2621103015001885E65,george@enron.com)))
(ryan.slinger@enron.com,List((5.0579114162531735E65,anne.bertino@enron.com), (4.136838933824579E65,kgustafs@enron.com), (4.136838933824579E65,gweiss@enron.com), (4.0110663808847004E65,hill@enron.com), (3.8821438267917902E65,george@enron.com)))

scala> people.join(graph.outDegrees).sortBy(_._2._2, ascending=false).map(x => (x._1, x._2._1)).take(10).toList.map(t => (t._2, top5(t._1).toList)).foreach(println)
(jeff.dasovich@enron.com,List((2.8924566115596135E66,anne.bertino@enron.com), (2.3157345904446663E66,hill@enron.com), (2.2646318970030287E66,gweiss@enron.com), (2.2646318970030287E66,kgustafs@enron.com), (2.2385865127706285E66,fraser@enron.com)))
(veronica.espinoza@enron.com,List((3.135142195254243E65,gweiss@enron.com), (3.135142195254243E65,kgustafs@enron.com), (2.773512892785554E65,anne.bertino@enron.com), (2.350799070225962E65,marcia.a.linton@enron.com), (2.2055288158758267E65,robert@enron.com)))
(pete.davis@enron.com,List((5.773492048248794E66,louise.kitchen@enron.com), (5.067434612038159E66,john.lavorato@enron.com), (4.389028076992449E66,greg.whalley@enron.com), (4.1791711984241975E66,sally.beck@enron.com), (4.009544764149938E66,elizabeth.sager@enron.com)))
(rhonda.denton@enron.com,List((2.834710591578977E68,louise.kitchen@enron.com), (2.488253676819922E68,john.lavorato@enron.com), (2.1516048969715738E68,greg.whalley@enron.com), (2.0405329247770104E68,sally.beck@enron.com), (1.9877213034021861E68,elizabeth.sager@enron.com)))
(cheryl.johnson@enron.com,List((3.453167402163105E64,mary.dix@enron.com), (3.208849221485621E64,theresa.byrne@enron.com), (3.208849221485621E64,sandy.olofson@enron.com), (3.0374270093157086E64,hill@enron.com), (2.886581252384442E64,fraser@enron.com)))
(susan.mara@enron.com,List((5.1729089729525785E66,anne.bertino@enron.com), (4.220843848723133E66,kgustafs@enron.com), (4.220843848723133E66,gweiss@enron.com), (4.1044435240204605E66,hill@enron.com), (3.9709951893268635E66,george@enron.com)))
(jae.black@enron.com,List((2.513139130001457E65,anne.bertino@enron.com), (2.1037756300035247E65,hill@enron.com), (2.0297519350719265E65,fraser@enron.com), (1.9587139280519927E65,george@enron.com), (1.947164483486155E65,brad@enron.com)))
(ginger.dernehl@enron.com,List((4.516267307013845E66,anne.bertino@enron.com), (3.653408921875843E66,gweiss@enron.com), (3.653408921875843E66,kgustafs@enron.com), (3.590298037045689E66,hill@enron.com), (3.471781765250177E66,fraser@enron.com)))
(lorna.brennan@enron.com,List((2.0719309635087482E66,anne.bertino@enron.com), (1.732651408857978E66,kgustafs@enron.com), (1.732651408857978E66,gweiss@enron.com), (1.6348480059915056E66,hill@enron.com), (1.5880693846486309E66,george@enron.com)))
(mary.hain@enron.com,List((5.596589595417286E66,anne.bertino@enron.com), (4.559474243930487E66,kgustafs@enron.com), (4.559474243930487E66,gweiss@enron.com), (4.4421474044331

在这里,我们计算了度数最高的前五个推荐电子邮件列表。

SVD 在 Scala 中只有 159 行代码,可以成为一些进一步改进的基础。SVD++ 包括基于隐式用户反馈和商品相似性信息的一部分。最后,Netflix 获胜方案也考虑到了用户偏好随时间变化的事实,但这一部分尚未在 GraphX 中实现。

摘要

虽然人们可以轻松地为图问题创建自己的数据结构,但 Scala 对图的支持既来自语义层——对于 Scala 来说,Graph 实际上是一种方便、交互式且表达性强的语言,用于处理图——也来自通过 Spark 和分布式计算的可扩展性。我希望本章中暴露的一些材料将对在 Scala、Spark 和 GraphX 之上实现算法有所帮助。值得一提的是,这两个库仍在积极开发中。

在下一章中,我们将从天空中我们的飞行中降下来,看看 Scala 与传统的数据分析框架的集成,如统计语言 R 和 Python,这些语言通常用于数据处理。稍后,在第九章中,我将探讨Scala 中的 NLP。我会查看 NLP Scala 工具,这些工具广泛利用复杂的数据结构。

第八章. 将 Scala 与 R 和 Python 集成

虽然 Spark 提供了 MLlib 作为机器学习库,但在许多实际情况下,R 或 Python 提供了更熟悉且经过时间考验的统计计算接口。特别是,R 的广泛统计库包括非常流行的方差和变量依赖/独立分析方法(ANOVA/MANOVA)、一系列统计测试和随机数生成器,这些目前尚未出现在 MLlib 中。R 到 Spark 的接口可在 SparkR 项目中找到。最后,数据分析师知道 Python 的 NumPy 和 SciPy 线性代数实现因其效率以及其他时间序列、优化和信号处理包而闻名。通过 R/Python 集成,所有这些熟悉的功能都可以暴露给 Scala/Spark 用户,直到 Spark/MLlib 接口稳定,并且库进入新框架,同时利用 Spark 在多台机器上以分布式方式执行工作流的能力为用户带来好处。

当人们在 R 或 Python 中编程,或者使用任何统计或线性代数包时,他们通常不会特别关注函数式编程方面。正如我在第一章中提到的,探索性数据分析,Scala 应该被视为一种高级语言,这正是它的亮点。与高效且免费可用的基本线性代数子程序BLAS)、线性代数包LAPACK)和Arnoldi 包ARPACK)的 C 和 Fortran 实现集成,已知其可以进入 Java 和 Scala(www.netlib.orggithub.com/fommil/netlib-java)。我希望将 Scala 留在它最擅长的地方。然而,在本章中,我将专注于如何使用这些语言与 Scala/Spark 一起使用。

我将使用公开可用的美国交通部航班数据集来介绍本章内容(www.transtats.bts.gov)。

在本章中,我们将涵盖以下主题:

  • 如果您还没有这样做,请安装 R 和配置 SparkR

  • 了解 R(和 Spark)DataFrame

  • 使用 R 进行线性回归和方差分析

  • 使用 SparkR 进行广义线性模型GLM)建模

  • 如果您还没有这样做,请安装 Python

  • 学习如何使用 PySpark 并从 Scala 调用 Python

与 R 集成

就像许多高级且精心设计的科技一样,人们对 R 语言通常要么爱得要命,要么恨之入骨。其中一个原因在于,R 语言是首批尝试操作复杂对象的编程语言之一,尽管大多数情况下它们最终只是列表,而不是像更成熟的现代实现那样是结构体或映射。R 语言最初由罗素·伊哈卡和罗伯特·甘特曼于 1993 年左右在奥克兰大学创建,其根源可以追溯到 1976 年左右在贝尔实验室开发的 S 语言,当时大多数商业编程仍在 Fortran 语言中进行。虽然 R 语言包含一些功能特性,如将函数作为参数传递和 map/apply,但它明显缺少一些其他特性,如惰性评估和列表推导。尽管如此,R 语言有一个非常好的帮助系统,如果有人说他们从未需要回到help(…)命令来了解如何更好地运行某个数据转换或模型,那么他们要么在撒谎,要么是刚开始使用 R。

设置 R 和 SparkR

要运行 SparkR,您需要 R 版本 3.0 或更高版本。根据您的操作系统,遵循给定的安装说明。

Linux

在 Linux 系统上,详细的安装文档可在cran.r-project.org/bin/linux找到。然而,例如,在 Debian 系统上,您可以通过运行以下命令来安装它:

# apt-get update
...
# apt-get install r-base r-base-dev
...

要列出 Linux 存储库站点上安装的/可用的包,请执行以下命令:

# apt-cache search "^r-.*" | sort
...

R 包,它们是r-baser-recommended的一部分,被安装到/usr/lib/R/library目录中。这些可以使用常规的包维护工具进行更新,例如apt-get或 aptitude。其他作为预编译的 Debian 包提供的 R 包,r-cran-*r-bioc-*,被安装到/usr/lib/R/site-library。以下命令显示了所有依赖于r-base-core的包:

# apt-cache rdepends r-base-core

这包括来自 CRAN 和其他存储库的大量贡献包。如果您想安装作为包未提供的 R 包,或者如果您想使用更新的版本,您需要从源代码构建它们,这需要安装r-base-dev开发包,可以通过以下命令安装:

# apt-get install r-base-dev

这将引入编译 R 包的基本要求,例如开发工具组的安装。然后,本地用户/管理员可以从 CRAN 源代码包中安装 R 包,通常在 R 中使用 R> install.packages() 函数或 R CMD INSTALL 来安装。例如,要安装 R 的 ggplot2 包,请运行以下命令:

> install.packages("ggplot2")
--- Please select a CRAN mirror for use in this session ---
also installing the dependencies 'stringi', 'magrittr', 'colorspace', 'Rcpp', 'stringr', 'RColorBrewer', 'dichromat', 'munsell', 'labeling', 'digest', 'gtable', 'plyr', 'reshape2', 'scales'

这将从可用的站点之一下载并可选地编译该包及其依赖项。有时 R 会弄混存储库;在这种情况下,我建议在主目录中创建一个 ~/.Rprofile 文件,指向最近的 CRAN 存储库:

$ cat >> ~/.Rprofile << EOF
r = getOption("repos") # hard code the Berkeley repo for CRAN
r["CRAN"] = "http://cran.cnr.berkeley.edu"
options(repos = r)
rm(r)

EOF

~/.Rprofile 包含用于自定义会话的命令。我建议在其中放入的命令之一是 options (prompt="R> "),以便能够通过提示符区分您正在使用的 shell,遵循本书中大多数工具的传统。已知镜像列表可在 cran.r-project.org/mirrors.html 查找。

此外,指定通过以下命令安装 system/site/user 包的目录是一种良好的做法,除非您的操作系统设置已经通过将这些命令放入 ~/.bashrc 或系统 /etc/profile 来完成:

$ export R_LIBS_SITE=${R_LIBS_SITE:-/usr/local/lib/R/site-library:/usr/lib/R/site-library:/usr/lib/R/library}
$ export R_LIBS_USER=${R_LIBS_USER:-$HOME/R/$(uname -i)-library/$( R --version | grep -o -E [0-9]+\.[
0-9]+ | head -1)}

Mac OS

R for Mac OS 可以从 cran.r-project.org/bin/macosx 下载。写作时的最新版本是 3.2.3。始终检查下载的包的一致性。为此,请运行以下命令:

$ pkgutil --check-signature R-3.2.3.pkg
Package "R-3.2.3.pkg":
 Status: signed by a certificate trusted by Mac OS X
 Certificate Chain:
 1\. Developer ID Installer: Simon Urbanek
 SHA1 fingerprint: B7 EB 39 5E 03 CF 1E 20 D1 A6 2E 9F D3 17 90 26 D8 D6 3B EF
 -----------------------------------------------------------------------------
 2\. Developer ID Certification Authority
 SHA1 fingerprint: 3B 16 6C 3B 7D C4 B7 51 C9 FE 2A FA B9 13 56 41 E3 88 E1 86
 -----------------------------------------------------------------------------
 3\. Apple Root CA
 SHA1 fingerprint: 61 1E 5B 66 2C 59 3A 08 FF 58 D1 4A E2 24 52 D1 98 DF 6C 60

前一小节中的环境设置也适用于 Mac OS 设置。

Windows

R for Windows 可以从 cran.r-project.org/bin/windows/ 下载为一个 exe 安装程序。以管理员身份运行此可执行文件来安装 R。

通常可以通过按照 Windows 菜单中的 控制面板 | 系统和安全 | 系统 | 高级系统设置 | 环境变量 路径来编辑 系统/用户 的环境设置。

通过脚本运行 SparkR

要运行 SparkR,需要运行 Spark git 树中包含的 R/install-dev.sh 脚本。实际上,只需要 shell 脚本和 R/pkg 目录的内容,这些内容并不总是包含在编译好的 Spark 发行版中:

$ git clone https://github.com/apache/spark.git
Cloning into 'spark'...
remote: Counting objects: 301864, done.
...
$ cp –r R/{install-dev.sh,pkg) $SPARK_HOME/R
...
$ cd $SPARK_HOME
$ ./R/install-dev.sh
* installing *source* package 'SparkR' ...
** R
** inst
** preparing package for lazy loading
Creating a new generic function for 'colnames' in package 'SparkR'
...
$ bin/sparkR

R version 3.2.3 (2015-12-10) -- "Wooden Christmas-Tree"
Copyright (C) 2015 The R Foundation for Statistical Computing
Platform: x86_64-redhat-linux-gnu (64-bit)

R is free software and comes with ABSOLUTELY NO WARRANTY.
You are welcome to redistribute it under certain conditions.
Type 'license()' or 'licence()' for distribution details.

 Natural language support but running in an English locale

R is a collaborative project with many contributors.
Type 'contributors()' for more information and
'citation()' on how to cite R or R packages in publications.

Type 'demo()' for some demos, 'help()' for on-line help, or
'help.start()' for an HTML browser interface to help.
Type 'q()' to quit R.

Launching java with spark-submit command /home/alex/spark-1.6.1-bin-hadoop2.6/bin/spark-submit   "sparkr-shell" /tmp/RtmpgdTfmU/backend_port22446d0391e8 

 Welcome to
 ____              __ 
 / __/__  ___ _____/ /__ 
 _\ \/ _ \/ _ `/ __/  '_/ 
 /___/ .__/\_,_/_/ /_/\_\   version  1.6.1 
 /_/ 

 Spark context is available as sc, SQL context is available as sqlContext>

通过 R 的命令行运行 Spark

或者,我们也可以直接从 R 命令行(或从 RStudio rstudio.org/)使用以下命令初始化 Spark:

R> library(SparkR, lib.loc = c(file.path(Sys.getenv("SPARK_HOME"), "R", "lib")))
...
R> sc <- sparkR.init(master = Sys.getenv("SPARK_MASTER"), sparkEnvir = list(spark.driver.memory="1g"))
...
R> sqlContext <- sparkRSQL.init(sc)

如前文第三章 Chapter 3 中所述,在 Working with Spark and MLlib 中,SPARK_HOME 环境变量需要指向您的本地 Spark 安装目录,SPARK_MASTERYARN_CONF_DIR 指向所需的集群管理器(本地、独立、mesos 和 YARN)以及 YARN 配置目录,如果使用 Spark 与 YARN 集群管理器一起使用的话。

虽然大多数分布都带有 UI,但根据本书的传统和本章的目的,我将使用命令行。

DataFrames

DataFrames 最初来自 R 和 Python,所以在 SparkR 中看到它们是很自然的。

注意

请注意,SparkR 中 DataFrame 的实现是在 RDD 之上,因此它们的工作方式与 R DataFrame 不同。

最近,关于何时何地存储和应用模式以及类型等元数据的问题一直是活跃讨论的课题。一方面,在数据早期提供模式可以实现对数据的彻底验证和潜在的优化。另一方面,对于原始数据摄取来说可能过于限制,其目标只是尽可能多地捕获数据,并在之后进行数据格式化和清洗,这种方法通常被称为“读取时模式”。后者最近由于有了处理演变模式的工具(如 Avro)和自动模式发现工具而获得了更多支持,但为了本章的目的,我将假设我们已经完成了模式发现的部分,并可以开始使用 DataFrames。

让我们首先从美国交通部下载并提取一个航班延误数据集,如下所示:

$ wget http://www.transtats.bts.gov/Download/On_Time_On_Time_Performance_2015_7.zip
--2016-01-23 15:40:02--  http://www.transtats.bts.gov/Download/On_Time_On_Time_Performance_2015_7.zip
Resolving www.transtats.bts.gov... 204.68.194.70
Connecting to www.transtats.bts.gov|204.68.194.70|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 26204213 (25M) [application/x-zip-compressed]
Saving to: "On_Time_On_Time_Performance_2015_7.zip"

100%[====================================================================================================================================================================================>] 26,204,213   966K/s   in 27s 

2016-01-23 15:40:29 (956 KB/s) - "On_Time_On_Time_Performance_2015_7.zip" saved [26204213/26204213]

$ unzip -d flights On_Time_On_Time_Performance_2015_7.zip
Archive:  On_Time_On_Time_Performance_2015_7.zip
 inflating: flights/On_Time_On_Time_Performance_2015_7.csv 
 inflating: flights/readme.html

如果你已经在集群上运行了 Spark,你想要将文件复制到 HDFS:

$ hadoop fs –put flights .

flights/readme.html文件提供了详细的元数据信息,如下所示:

DataFrames

图 08-1:美国交通部发布的准时性能数据集提供的元数据(仅用于演示目的)

现在,我希望你分析SFO返程航班的延误情况,并可能找出导致延误的因素。让我们从 R 的data.frame开始:

$ bin/sparkR --master local[8]

R version 3.2.3 (2015-12-10) -- "Wooden Christmas-Tree"
Copyright (C) 2015 The R Foundation for Statistical Computing
Platform: x86_64-apple-darwin13.4.0 (64-bit)

R is free software and comes with ABSOLUTELY NO WARRANTY.
You are welcome to redistribute it under certain conditions.
Type 'license()' or 'licence()' for distribution details.

 Natural language support but running in an English locale

R is a collaborative project with many contributors.
Type 'contributors()' for more information and
'citation()' on how to cite R or R packages in publications.

Type 'demo()' for some demos, 'help()' for on-line help, or
'help.start()' for an HTML browser interface to help.
Type 'q()' to quit R.

[Previously saved workspace restored]

Launching java with spark-submit command /Users/akozlov/spark-1.6.1-bin-hadoop2.6/bin/spark-submit   "--master" "local[8]" "sparkr-shell" /var/folders/p1/y7ygx_4507q34vhd60q115p80000gn/T//RtmpD42eTz/backend_port682e58e2c5db 

 Welcome to
 ____              __ 
 / __/__  ___ _____/ /__ 
 _\ \/ _ \/ _ `/ __/  '_/ 
 /___/ .__/\_,_/_/ /_/\_\   version  1.6.1 
 /_/ 

 Spark context is available as sc, SQL context is available as sqlContext
> flights <- read.table(unz("On_Time_On_Time_Performance_2015_7.zip", "On_Time_On_Time_Performance_2015_7.csv"), nrows=1000000, header=T, quote="\"", sep=",")
> sfoFlights <- flights[flights$Dest == "SFO", ]
> attach(sfoFlights)
> delays <- aggregate(ArrDelayMinutes ~ DayOfWeek + Origin + UniqueCarrier, FUN=mean, na.rm=TRUE)
> tail(delays[order(delays$ArrDelayMinutes), ])
 DayOfWeek Origin UniqueCarrier ArrDelayMinutes
220         4    ABQ            OO           67.60
489         4    TUS            OO           71.80
186         5    IAH            F9           77.60
696         3    RNO            UA           79.50
491         6    TUS            OO          168.25
84          7    SLC            AS          203.25

如果你是在 2015 年 7 月周日的阿拉斯加航空从盐湖城起飞,那么你觉得自己很不幸(我们到目前为止只进行了简单的分析,所以人们不应该过分重视这个结果)。可能有多个其他随机因素导致延误。

尽管我们在 SparkR 中运行了示例,但我们仍然使用了 R 的data.frame。如果我们想分析跨多个月份的数据,我们需要在多个节点之间分配负载。这就是 SparkR 分布式 DataFrame 发挥作用的地方,因为它可以在单个节点上的多个线程之间进行分布。有一个直接的方法可以将 R DataFrame 转换为 SparkR DataFrame(以及因此转换为 RDD):

> sparkDf <- createDataFrame(sqlContext, flights)

如果我在笔记本电脑上运行它,我会耗尽内存。由于我需要在多个线程/节点之间传输数据,开销很大,我们希望尽可能早地进行过滤:

sparkDf <- createDataFrame(sqlContext, subset(flights, select = c("ArrDelayMinutes", "DayOfWeek", "Origin", "Dest", "UniqueCarrier")))

这甚至可以在我的笔记本电脑上运行。当然,从 Spark 的 DataFrame 到 R 的data.frame也有反向转换:

> rDf <- as.data.frame(sparkDf)

或者,我可以使用 spark-csv 包从 .csv 文件中读取它,如果原始的 .csv 文件位于 HDFS 等分布式文件系统中,这将避免在集群设置中通过网络在集群中移动数据。目前唯一的缺点是 Spark 不能直接从 .zip 文件中读取:

> $ ./bin/sparkR --packages com.databricks:spark-csv_2.10:1.3.0 --master local[8]

R version 3.2.3 (2015-12-10) -- "Wooden Christmas-Tree"
Copyright (C) 2015 The R Foundation for Statistical Computing
Platform: x86_64-redhat-linux-gnu (64-bit)

R is free software and comes with ABSOLUTELY NO WARRANTY.
You are welcome to redistribute it under certain conditions.
Type 'license()' or 'licence()' for distribution details.

 Natural language support but running in an English locale

R is a collaborative project with many contributors.
Type 'contributors()' for more information and
'citation()' on how to cite R or R packages in publications.

Type 'demo()' for some demos, 'help()' for on-line help, or
'help.start()' for an HTML browser interface to help.
Type 'q()' to quit R.

Warning: namespace 'SparkR' is not available and has been replaced
by .GlobalEnv when processing object 'sparkDf'
[Previously saved workspace restored]

Launching java with spark-submit command /home/alex/spark-1.6.1-bin-hadoop2.6/bin/spark-submit   "--master" "local[8]" "--packages" "com.databricks:spark-csv_2.10:1.3.0" "sparkr-shell" /tmp/RtmpfhcUXX/backend_port1b066bea5a03 
Ivy Default Cache set to: /home/alex/.ivy2/cache
The jars for the packages stored in: /home/alex/.ivy2/jars
:: loading settings :: url = jar:file:/home/alex/spark-1.6.1-bin-hadoop2.6/lib/spark-assembly-1.6.1-hadoop2.6.0.jar!/org/apache/ivy/core/settings/ivysettings.xml
com.databricks#spark-csv_2.10 added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent;1.0
 confs: [default]
 found com.databricks#spark-csv_2.10;1.3.0 in central
 found org.apache.commons#commons-csv;1.1 in central
 found com.univocity#univocity-parsers;1.5.1 in central
:: resolution report :: resolve 189ms :: artifacts dl 4ms
 :: modules in use:
 com.databricks#spark-csv_2.10;1.3.0 from central in [default]
 com.univocity#univocity-parsers;1.5.1 from central in [default]
 org.apache.commons#commons-csv;1.1 from central in [default]
 ---------------------------------------------------------------------
 |                  |            modules            ||   artifacts   |
 |       conf       | number| search|dwnlded|evicted|| number|dwnlded|
 ---------------------------------------------------------------------
 |      default     |   3   |   0   |   0   |   0   ||   3   |   0   |
 ---------------------------------------------------------------------
:: retrieving :: org.apache.spark#spark-submit-parent
 confs: [default]
 0 artifacts copied, 3 already retrieved (0kB/7ms)

 Welcome to
 ____              __ 
 / __/__  ___ _____/ /__ 
 _\ \/ _ \/ _ `/ __/  '_/ 
 /___/ .__/\_,_/_/ /_/\_\   version  1.6.1 
 /_/ 

 Spark context is available as sc, SQL context is available as sqlContext
> sparkDf <- read.df(sqlContext, "./flights", "com.databricks.spark.csv", header="true", inferSchema = "false")
> sfoFlights <- select(filter(sparkDf, sparkDf$Dest == "SFO"), "DayOfWeek", "Origin", "UniqueCarrier", "ArrDelayMinutes")
> aggs <- agg(group_by(sfoFlights, "DayOfWeek", "Origin", "UniqueCarrier"), count(sparkDf$ArrDelayMinutes), avg(sparkDf$ArrDelayMinutes))
> head(arrange(aggs, c('avg(ArrDelayMinutes)'), decreasing = TRUE), 10)
 DayOfWeek Origin UniqueCarrier count(ArrDelayMinutes) avg(ArrDelayMinutes) 
1          7    SLC            AS                      4               203.25
2          6    TUS            OO                      4               168.25
3          3    RNO            UA                      8                79.50
4          5    IAH            F9                      5                77.60
5          4    TUS            OO                      5                71.80
6          4    ABQ            OO                      5                67.60
7          2    ABQ            OO                      4                66.25
8          1    IAH            F9                      4                61.25
9          4    DAL            WN                      5                59.20
10         3    SUN            OO                      5                59.00

注意,我们通过在命令行上提供 --package 标志加载了额外的 com.databricks:spark-csv_2.10:1.3.0 包;我们可以通过在节点集群上使用 Spark 实例轻松地进行分布式处理,甚至分析更大的数据集:

$ for i in $(seq 1 6); do wget http://www.transtats.bts.gov/Download/On_Time_On_Time_Performance_2015_$i.zip; unzip -d flights On_Time_On_Time_Performance_2015_$i.zip; hadoop fs -put -f flights/On_Time_On_Time_Performance_2015_$i.csv flights; done

$ hadoop fs -ls flights
Found 7 items
-rw-r--r--   3 alex eng  211633432 2016-02-16 03:28 flights/On_Time_On_Time_Performance_2015_1.csv
-rw-r--r--   3 alex eng  192791767 2016-02-16 03:28 flights/On_Time_On_Time_Performance_2015_2.csv
-rw-r--r--   3 alex eng  227016932 2016-02-16 03:28 flights/On_Time_On_Time_Performance_2015_3.csv
-rw-r--r--   3 alex eng  218600030 2016-02-16 03:28 flights/On_Time_On_Time_Performance_2015_4.csv
-rw-r--r--   3 alex eng  224003544 2016-02-16 03:29 flights/On_Time_On_Time_Performance_2015_5.csv
-rw-r--r--   3 alex eng  227418780 2016-02-16 03:29 flights/On_Time_On_Time_Performance_2015_6.csv
-rw-r--r--   3 alex eng  235037955 2016-02-15 21:56 flights/On_Time_On_Time_Performance_2015_7.csv

这将下载并将准点率数据放在航班的目录中(记住,正如我们在第一章中讨论的,探索性数据分析,我们希望将目录视为大数据数据集)。现在我们可以对 2015 年(可用数据)的整个时期进行相同的分析:

> sparkDf <- read.df(sqlContext, "./flights", "com.databricks.spark.csv", header="true")
> sfoFlights <- select(filter(sparkDf, sparkDf$Dest == "SFO"), "DayOfWeek", "Origin", "UniqueCarrier", "ArrDelayMinutes")
> aggs <- cache(agg(group_by(sfoFlights, "DayOfWeek", "Origin", "UniqueCarrier"), count(sparkDf$ArrDelayMinutes), avg(sparkDf$ArrDelayMinutes)))
> head(arrange(aggs, c('avg(ArrDelayMinutes)'), decreasing = TRUE), 10)
 DayOfWeek Origin UniqueCarrier count(ArrDelayMinutes) avg(ArrDelayMinutes) 
1          6    MSP            UA                      1            122.00000
2          3    RNO            UA                      8             79.50000
3          1    MSP            UA                     13             68.53846
4          7    SAT            UA                      1             65.00000
5          7    STL            UA                      9             64.55556
6          1    ORD            F9                     13             55.92308
7          1    MSO            OO                      4             50.00000
8          2    MSO            OO                      4             48.50000
9          5    CEC            OO                     28             45.86957
10         3    STL            UA                     13             43.46154

注意,我们使用了一个 cache() 调用来将数据集固定在内存中,因为我们稍后会再次使用它。这次是周六的明尼阿波利斯/联合!然而,你可能已经知道原因:对于这种 DayOfWeekOriginUniqueCarrier 组合,只有一个记录;这很可能是异常值。之前异常值的平均大约 30 个航班现在减少到 30 分钟:

> head(arrange(filter(filter(aggs, aggs$Origin == "SLC"), aggs$UniqueCarrier == "AS"), c('avg(ArrDelayMinutes)'), decreasing = TRUE), 100)
 DayOfWeek Origin UniqueCarrier count(ArrDelayMinutes) avg(ArrDelayMinutes)
1         7    SLC            AS                     30            32.600000
2         2    SLC            AS                     30            10.200000
3         4    SLC            AS                     31             9.774194
4         1    SLC            AS                     30             9.433333
5         3    SLC            AS                     30             5.866667
6         5    SLC            AS                     31             5.516129
7         6    SLC            AS                     30             2.133333

周日仍然是一个延迟问题。我们现在能分析的数据量限制仅是笔记本电脑上的核心数和集群中的节点数。现在让我们看看更复杂的机器学习模型。

线性模型

线性方法在统计建模中扮演着重要角色。正如其名所示,线性模型假设因变量是自变量的加权组合。在 R 中,lm 函数执行线性回归并报告系数,如下所示:

R> attach(iris)
R> lm(Sepal.Length ~ Sepal.Width)

Call:
lm(formula = Sepal.Length ~ Sepal.Width)

Coefficients:
(Intercept)  Sepal.Width
 6.5262      -0.2234

summary 函数提供了更多信息:

R> model <- lm(Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width)
R> summary(model)

Call:
lm(formula = Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width)

Residuals:
 Min       1Q   Median       3Q      Max 
-0.82816 -0.21989  0.01875  0.19709  0.84570 

Coefficients:
 Estimate Std. Error t value Pr(>|t|) 
(Intercept)   1.85600    0.25078   7.401 9.85e-12 ***
Sepal.Width   0.65084    0.06665   9.765  < 2e-16 ***
Petal.Length  0.70913    0.05672  12.502  < 2e-16 ***
Petal.Width  -0.55648    0.12755  -4.363 2.41e-05 ***
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Residual standard error: 0.3145 on 146 degrees of freedom
Multiple R-squared:  0.8586,  Adjusted R-squared:  0.8557 
F-statistic: 295.5 on 3 and 146 DF,  p-value: < 2.2e-16

虽然我们在第三章中考虑了广义线性模型,使用 Spark 和 MLlib,我们也将很快考虑 R 和 SparkR 中的 glm 实现,但线性模型通常提供更多信息,并且是处理噪声数据和选择相关属性进行进一步分析的出色工具。

注意

数据分析生命周期

虽然大多数统计书籍都关注分析和最佳使用现有数据,但统计分析的结果通常也应该影响对新信息来源的搜索。在完整的数据生命周期中,如第三章所述,使用 Spark 和 MLlib,数据科学家应始终将最新的变量重要性结果转化为如何收集数据的理论。例如,如果家用打印机的墨水使用分析表明照片墨水使用量增加,那么可以收集更多关于图片格式、数字图像来源以及用户偏好的纸张类型的信息。这种方法在实际的商业环境中证明非常有效,尽管它并没有完全自动化。

具体来说,以下是线性模型提供的输出结果的简要描述:

  • 残差:这些是实际值与预测值之间差异的统计数据。存在许多技术可以检测残差分布模式中的模型问题,但这超出了本书的范围。可以通过resid(model)函数获得详细的残差表。

  • 系数:这些是实际的线性组合系数;t 值表示系数值与标准误差估计值的比率:数值越高意味着这个系数对因变量的非平凡影响的可能性越高。这些系数也可以通过coef(model)函数获得。

  • 残差标准误差:这报告了标准均方误差,这是简单线性回归中优化的目标。

  • 多重 R 平方:这是由模型解释的因变量方差的比例。调整后的值考虑了模型中的参数数量,被认为是避免过度拟合的更好指标,如果观察数量不足以证明模型的复杂性,这种情况甚至在大数据问题中也会发生。

  • F 统计量:模型质量的衡量标准。简单来说,它衡量模型中所有参数解释因变量的程度。p 值提供了模型仅因随机机会解释因变量的概率。一般来说,小于 0.05(或 5%)的值被认为是令人满意的。虽然一般来说,高值可能意味着模型可能没有统计上的有效性,“其他因素都不重要”,但低 F 统计量并不总是意味着模型在实际中会表现良好,因此不能直接将其作为模型接受标准。

一旦应用了线性模型,通常还会应用更复杂的glm或递归模型,如决策树和rpart函数,以寻找有趣的变量交互。线性模型对于在可以改进的其他模型上建立基线是很好的。

最后,方差分析(ANOVA)是研究独立变量离散时的标准技术:

R> aov <- aov(Sepal.Length ~ Species)
R> summary(aov)
 Df Sum Sq Mean Sq F value Pr(>F) 
Species       2  63.21  31.606   119.3 <2e-16 ***
Residuals   147  38.96   0.265 
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

模型质量的衡量标准是 F 统计量。虽然可以使用Rscript通过管道机制运行 RDD 上的 R 算法,但我将在稍后部分部分介绍与Java 规范请求JSR)223 Python 集成相关的此功能。在本节中,我想特别探讨一个在 R 和 SparkR 中本地实现的广义线性回归glm函数。

广义线性模型

再次强调,你可以运行 R glm或 SparkR glm。以下表格提供了 R 实现可能的链接和优化函数列表:

以下列表显示了 R glm实现的可能选项:

Family Variance Link
gaussian gaussian identity
binomial binomial logit, probit 或 cloglog
poisson poisson log, identity 或 sqrt
Gamma Gamma inverse, identity 或 log
inverse.gaussian inverse.gaussian 1/mu²
quasi 用户定义 用户定义

我将使用二元目标ArrDel15,它表示飞机是否晚于 15 分钟到达。自变量将是DepDel15DayOfWeekMonthUniqueCarrierOriginDest

R> flights <- read.table(unz("On_Time_On_Time_Performance_2015_7.zip", "On_Time_On_Time_Performance_2015_7.csv"), nrows=1000000, header=T, quote="\"", sep=",")
R> flights$DoW_ <- factor(flights$DayOfWeek,levels=c(1,2,3,4,5,6,7), labels=c("Mon","Tue","Wed","Thu","Fri","Sat","Sun"))
R> attach(flights)
R> system.time(model <- glm(ArrDel15 ~ UniqueCarrier + DoW_ + Origin + Dest, flights, family="binomial"))

当你等待结果时,打开另一个 shell 并在 SparkR 模式下对完整七个月的数据运行glm

sparkR> cache(sparkDf <- read.df(sqlContext, "./flights", "com.databricks.spark.csv", header="true", inferSchema="true"))
DataFrame[Year:int, Quarter:int, Month:int, DayofMonth:int, DayOfWeek:int, FlightDate:string, UniqueCarrier:string, AirlineID:int, Carrier:string, TailNum:string, FlightNum:int, OriginAirportID:int, OriginAirportSeqID:int, OriginCityMarketID:int, Origin:string, OriginCityName:string, OriginState:string, OriginStateFips:int, OriginStateName:string, OriginWac:int, DestAirportID:int, DestAirportSeqID:int, DestCityMarketID:int, Dest:string, DestCityName:string, DestState:string, DestStateFips:int, DestStateName:string, DestWac:int, CRSDepTime:int, DepTime:int, DepDelay:double, DepDelayMinutes:double, DepDel15:double, DepartureDelayGroups:int, DepTimeBlk:string, TaxiOut:double, WheelsOff:int, WheelsOn:int, TaxiIn:double, CRSArrTime:int, ArrTime:int, ArrDelay:double, ArrDelayMinutes:double, ArrDel15:double, ArrivalDelayGroups:int, ArrTimeBlk:string, Cancelled:double, CancellationCode:string, Diverted:double, CRSElapsedTime:double, ActualElapsedTime:double, AirTime:double, Flights:double, Distance:double, DistanceGroup:int, CarrierDelay:double, WeatherDelay:double, NASDelay:double, SecurityDelay:double, LateAircraftDelay:double, FirstDepTime:int, TotalAddGTime:double, LongestAddGTime:double, DivAirportLandings:int, DivReachedDest:double, DivActualElapsedTime:double, DivArrDelay:double, DivDistance:double, Div1Airport:string, Div1AirportID:int, Div1AirportSeqID:int, Div1WheelsOn:int, Div1TotalGTime:double, Div1LongestGTime:double, Div1WheelsOff:int, Div1TailNum:string, Div2Airport:string, Div2AirportID:int, Div2AirportSeqID:int, Div2WheelsOn:int, Div2TotalGTime:double, Div2LongestGTime:double, Div2WheelsOff:string, Div2TailNum:string, Div3Airport:string, Div3AirportID:string, Div3AirportSeqID:string, Div3WheelsOn:string, Div3TotalGTime:string, Div3LongestGTime:string, Div3WheelsOff:string, Div3TailNum:string, Div4Airport:string, Div4AirportID:string, Div4AirportSeqID:string, Div4WheelsOn:string, Div4TotalGTime:string, Div4LongestGTime:string, Div4WheelsOff:string, Div4TailNum:string, Div5Airport:string, Div5AirportID:string, Div5AirportSeqID:string, Div5WheelsOn:string, Div5TotalGTime:string, Div5LongestGTime:string, Div5WheelsOff:string, Div5TailNum:string, :string]
sparkR> noNulls <- cache(dropna(selectExpr(filter(sparkDf, sparkDf$Cancelled == 0), "ArrDel15", "UniqueCarrier", "format_string('%d', DayOfWeek) as DayOfWeek", "Origin", "Dest"), "any"))
sparkR> sparkModel = glm(ArrDel15 ~ UniqueCarrier + DayOfWeek + Origin + Dest, noNulls, family="binomial")

在这里,我们试图构建一个模型,解释延误作为承运人、星期几和出发机场对目的地机场的影响,这由公式构造ArrDel15 ~ UniqueCarrier + DayOfWeek + Origin + Dest来捕捉。

注意

nulls,大数据和 Scala

注意,在 SparkR 的glm情况下,我必须明确过滤掉未取消的航班并移除 NA——或者用 C/Java 术语来说,是 nulls。虽然 R 默认会为你做这件事,但大数据中的 NA 非常常见,因为数据集通常是稀疏的,不应轻视。我们必须在 MLlib 中明确处理 nulls 的事实提醒我们数据集中有额外的信息,这绝对是一个受欢迎的功能。NA 的存在可以携带有关数据收集方式的信息。理想情况下,每个 NA 都应该有一个小的get_na_info方法来解释为什么这个特定的值不可用或未收集,这使我们想到了 Scala 中的Either类型。

尽管空值是从 Java 继承的,也是 Scala 的一部分,但 OptionEither 类型是新的且更健壮的机制,用于处理传统上使用空值的情况。具体来说,Either 可以提供一个值或异常信息,说明为什么没有计算;而 Option 可以提供一个值或为 None,这可以很容易地被 Scala 的模式匹配框架捕获。

你会注意到 SparkR 会运行多个线程,甚至在单个节点上,它也会消耗多个核心的 CPU 时间,并且即使数据量较大,返回速度也很快。在我的 32 核机器上的实验中,它能在不到一分钟内完成(相比之下,R glm 需要 35 分钟)。要获取结果,就像 R 模型的情况一样,我们需要运行 summary() 方法:

> summary(sparkModel)
$coefficients
 Estimate
(Intercept)      -1.518542340
UniqueCarrier_WN  0.382722232
UniqueCarrier_DL -0.047997652
UniqueCarrier_OO  0.367031995
UniqueCarrier_AA  0.046737727
UniqueCarrier_EV  0.344539788
UniqueCarrier_UA  0.299290120
UniqueCarrier_US  0.069837542
UniqueCarrier_MQ  0.467597761
UniqueCarrier_B6  0.326240578
UniqueCarrier_AS -0.210762769
UniqueCarrier_NK  0.841185903
UniqueCarrier_F9  0.788720078
UniqueCarrier_HA -0.094638586
DayOfWeek_5       0.232234937
DayOfWeek_4       0.274016179
DayOfWeek_3       0.147645473
DayOfWeek_1       0.347349366
DayOfWeek_2       0.190157420
DayOfWeek_7       0.199774806
Origin_ATL       -0.180512251
...

表现最差的航空公司是 NK(精神航空公司)。SparkR 内部使用的是有限内存的 BFGS,这是一种类似于 R glm 在 7 月数据上获得结果的有限内存拟牛顿优化方法:

R> summary(model)

Call:
glm(formula = ArrDel15 ~ UniqueCarrier + DoW + Origin + Dest, 
 family = "binomial", data = dow)

Deviance Residuals: 
 Min       1Q   Median       3Q      Max 
-1.4205  -0.7274  -0.6132  -0.4510   2.9414 

Coefficients:
 Estimate Std. Error z value Pr(>|z|) 
(Intercept)     -1.817e+00  2.402e-01  -7.563 3.95e-14 ***
UniqueCarrierAS -3.296e-01  3.413e-02  -9.658  < 2e-16 ***
UniqueCarrierB6  3.932e-01  2.358e-02  16.676  < 2e-16 ***
UniqueCarrierDL -6.602e-02  1.850e-02  -3.568 0.000359 ***
UniqueCarrierEV  3.174e-01  2.155e-02  14.728  < 2e-16 ***
UniqueCarrierF9  6.754e-01  2.979e-02  22.668  < 2e-16 ***
UniqueCarrierHA  7.883e-02  7.058e-02   1.117 0.264066 
UniqueCarrierMQ  2.175e-01  2.393e-02   9.090  < 2e-16 ***
UniqueCarrierNK  7.928e-01  2.702e-02  29.343  < 2e-16 ***
UniqueCarrierOO  4.001e-01  2.019e-02  19.817  < 2e-16 ***
UniqueCarrierUA  3.982e-01  1.827e-02  21.795  < 2e-16 ***
UniqueCarrierVX  9.723e-02  3.690e-02   2.635 0.008423 ** 
UniqueCarrierWN  6.358e-01  1.700e-02  37.406  < 2e-16 ***
dowTue           1.365e-01  1.313e-02  10.395  < 2e-16 ***
dowWed           1.724e-01  1.242e-02  13.877  < 2e-16 ***
dowThu           4.593e-02  1.256e-02   3.656 0.000256 ***
dowFri          -2.338e-01  1.311e-02 -17.837  < 2e-16 ***
dowSat          -2.413e-01  1.458e-02 -16.556  < 2e-16 ***
dowSun          -3.028e-01  1.408e-02 -21.511  < 2e-16 ***
OriginABI       -3.355e-01  2.554e-01  -1.314 0.188965 
...

SparkR glm 实现的其他参数如下表所示:

以下表格显示了 SparkR glm 实现的参数列表:

参数 可能的值 备注
formula R 中的符号描述 目前仅支持公式运算符的子集:"~", ".", ":", "+", 和 "-"
family gaussian or binomial 需要加引号:gaussian -> 线性回归,binomial -> 逻辑回归
data DataFrame 需要是 SparkR DataFrame,而不是 data.frame
lambda positive 正则化系数
alpha positive 弹性网络混合参数(详细信息请参阅 glmnet 的文档)
standardize TRUE or FALSE 用户定义
solver l-bfgs, normal or auto auto 会自动选择算法,l-bfgs 表示有限内存 BFGS,normal 表示使用正规方程作为线性回归问题的解析解

在 SparkR 中读取 JSON 文件

读取文本文件中的 JSON 记录的 Schema on Read 是大数据的一个便利特性。DataFrame 类能够确定每行包含一个 JSON 记录的文本文件的架构:

[akozlov@Alexanders-MacBook-Pro spark-1.6.1-bin-hadoop2.6]$ cat examples/src/main/resources/people.json 
{"name":"Michael"}
{"name":"Andy", "age":30}
{"name":"Justin", "age":19}

[akozlov@Alexanders-MacBook-Pro spark-1.6.1-bin-hadoop2.6]$ bin/sparkR
...

> people = read.json(sqlContext, "examples/src/main/resources/people.json")
> dtypes(people)
[[1]]
[1] "age"    "bigint"

[[2]]
[1] "name"   "string"

> schema(people)
StructType
|-name = "age", type = "LongType", nullable = TRUE
|-name = "name", type = "StringType", nullable = TRUE
> showDF(people)
+----+-------+
| age|   name|
+----+-------+
|null|Michael|
|  30|   Andy|
|  19| Justin|
+----+-------+

在 SparkR 中写入 Parquet 文件

正如我们在上一章中提到的,Parquet 格式是一种高效的存储格式,尤其是对于低基数列。Parquet 文件可以直接从 R 中读取/写入:

> write.parquet(sparkDf, "parquet")

你可以看到,新的 Parquet 文件比从 DoT 下载的原始 zip 文件小 66 倍:

[akozlov@Alexanders-MacBook-Pro spark-1.6.1-bin-hadoop2.6]$ ls –l On_Time_On_Time_Performance_2015_7.zip parquet/ flights/
-rw-r--r--  1 akozlov  staff  26204213 Sep  9 12:21 /Users/akozlov/spark/On_Time_On_Time_Performance_2015_7.zip

flights/:
total 459088
-rw-r--r--  1 akozlov  staff  235037955 Sep  9 12:20 On_Time_On_Time_Performance_2015_7.csv
-rw-r--r--  1 akozlov  staff      12054 Sep  9 12:20 readme.html

parquet/:
total 848
-rw-r--r--  1 akozlov  staff       0 Jan 24 22:50 _SUCCESS
-rw-r--r--  1 akozlov  staff   10000 Jan 24 22:50 _common_metadata
-rw-r--r--  1 akozlov  staff   23498 Jan 24 22:50 _metadata
-rw-r--r--  1 akozlov  staff  394418 Jan 24 22:50 part-r-00000-9e2d0004-c71f-4bf5-aafe-90822f9d7223.gz.parquet

从 R 调用 Scala

假设我们有一个在 Scala 中实现的数值方法的异常实现,我们想从 R 中调用它。一种方法是通过 R 的 system() 函数调用 Unix-like 系统上的 /bin/sh。然而,rscala 包是一个更有效的方法,它启动 Scala 解释器并维护 TCP/IP 网络连接的通信。

在这里,Scala 解释器在调用之间维护状态(记忆化)。同样,可以定义函数,如下所示:

R> scala <- scalaInterpreter()
R> scala %~% 'def pri(i: Stream[Int]): Stream[Int] = i.head #:: pri(i.tail filter  { x => { println("Evaluating " + x + "%" + i.head); x % i.head != 0 } } )'
ScalaInterpreterReference... engine: javax.script.ScriptEngine
R> scala %~% 'val primes = pri(Stream.from(2))'
ScalaInterpreterReference... primes: Stream[Int]
R> scala %~% 'primes take 5 foreach println'
2
Evaluating 3%2
3
Evaluating 4%2
Evaluating 5%2
Evaluating 5%3
5
Evaluating 6%2
Evaluating 7%2
Evaluating 7%3
Evaluating 7%5
7
Evaluating 8%2
Evaluating 9%2
Evaluating 9%3
Evaluating 10%2
Evaluating 11%2
Evaluating 11%3
Evaluating 11%5
Evaluating 11%7
11
R> scala %~% 'primes take 5 foreach println'
2
3
5
7
11
R> scala %~% 'primes take 7 foreach println'
2
3
5
7
11
Evaluating 12%2
Evaluating 13%2
Evaluating 13%3
Evaluating 13%5
Evaluating 13%7
Evaluating 13%11
13
Evaluating 14%2
Evaluating 15%2
Evaluating 15%3
Evaluating 16%2
Evaluating 17%2
Evaluating 17%3
Evaluating 17%5
Evaluating 17%7
Evaluating 17%11
Evaluating 17%13
17
R> 

Scala 可以使用!!!操作符和Rscript命令调用 R:

[akozlov@Alexanders-MacBook-Pro ~]$ cat << EOF > rdate.R
> #!/usr/local/bin/Rscript
> 
> write(date(), stdout())
> EOF
[akozlov@Alexanders-MacBook-Pro ~]$ chmod a+x rdate.R
[akozlov@Alexanders-MacBook-Pro ~]$ scala
Welcome to Scala version 2.11.7 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_40).
Type in expressions to have them evaluated.
Type :help for more information.

scala> import sys.process._
import sys.process._

scala> val date = Process(Seq("./rdate.R")).!!
date: String =
"Wed Feb 24 02:20:09 2016
"

使用 Rserve

一种更有效的方法是使用类似的 TCP/IP 二进制传输协议,通过Rsclient/Rserve与 R 通信(www.rforge.net/Rserve)。要在已安装 R 的节点上启动Rserve,执行以下操作:

[akozlov@Alexanders-MacBook-Pro ~]$ wget http://www.rforge.net/Rserve/snapshot/Rserve_1.8-5.tar.gz

[akozlov@Alexanders-MacBook-Pro ~]$ R CMD INSTALL Rserve_1.8-5.tar
.gz
...
[akozlov@Alexanders-MacBook-Pro ~]$ R CMD INSTALL Rserve_1.8-5.tar.gz

[akozlov@Alexanders-MacBook-Pro ~]$ $ R -q CMD Rserve

R version 3.2.3 (2015-12-10) -- "Wooden Christmas-Tree"
Copyright (C) 2015 The R Foundation for Statistical Computing
Platform: x86_64-apple-darwin13.4.0 (64-bit)

R is free software and comes with ABSOLUTELY NO WARRANTY.
You are welcome to redistribute it under certain conditions.
Type 'license()' or 'licence()' for distribution details.

 Natural language support but running in an English locale

R is a collaborative project with many contributors.
Type 'contributors()' for more information and
'citation()' on how to cite R or R packages in publications.

Type 'demo()' for some demos, 'help()' for on-line help, or
'help.start()' for an HTML browser interface to help.
Type 'q()' to quit R.

Rserv started in daemon mode.

默认情况下,Rservlocalhost:6311上打开一个连接。二进制网络协议的优点是它具有平台无关性,多个客户端可以与服务器通信。客户端可以连接到Rserve

注意,虽然将结果作为二进制对象传递有其优点,但你必须小心 R 和 Scala 之间的类型映射。Rserve支持其他客户端,包括 Python,但我还会在本章末尾介绍 JSR 223 兼容的脚本。

与 Python 集成

Python 逐渐成为数据科学事实上的工具。它有一个命令行界面,以及通过 matplotlib 和 ggplot(基于 R 的 ggplot2)实现的不错的可视化。最近,Pandas 时间序列数据分析包的创造者 Wes McKinney 加入了 Cloudera,为 Python 在大数据领域铺平道路。

设置 Python

Python 通常是默认安装的一部分。Spark 需要版本 2.7.0+。

如果你没有在 Mac OS 上安装 Python,我建议安装 Homebrew 包管理器,请访问brew.sh

[akozlov@Alexanders-MacBook-Pro spark(master)]$ ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)"
==> This script will install:
/usr/local/bin/brew
/usr/local/Library/...
/usr/local/share/man/man1/brew.1
…
[akozlov@Alexanders-MacBook-Pro spark(master)]$ brew install python
…

否则,在类 Unix 系统上,可以从源分发版编译 Python:

$ export PYTHON_VERSION=2.7.11
$ wget -O - https://www.python.org/ftp/python/$PYTHON_VERSION/Python-$PYTHON_VERSION.tgz | tar xzvf -
$ cd $HOME/Python-$PYTHON_VERSION
$ ./configure--prefix=/usr/local --enable-unicode=ucs4--enable-shared LDFLAGS="-Wl,-rpath /usr/local/lib"
$ make; sudo make altinstall
$ sudo ln -sf /usr/local/bin/python2.7 /usr/local/bin/python

将其放置在不同于默认 Python 安装的目录中是一种良好的实践。在单个系统上拥有多个 Python 版本是正常的,通常不会导致问题,因为 Python 会分离安装目录。为了本章的目的,就像许多机器学习应用一样,我还需要一些包。这些包和具体版本可能因安装而异:

$ wget https://bootstrap.pypa.io/ez_setup.py
$ sudo /usr/local/bin/python ez_setup.py
$ sudo /usr/local/bin/easy_install-2.7 pip
$ sudo /usr/local/bin/pip install --upgrade avro nose numpy scipy pandas statsmodels scikit-learn iso8601 python-dateutil python-snappy

如果一切编译成功——SciPy 使用 Fortran 编译器和线性代数库——我们就准备好使用 Python 2.7.11 了!

注意

注意,如果想在分布式环境中使用pipe命令与 Python 一起使用,Python 需要安装在网络中的每个节点上。

PySpark

由于bin/sparkR使用预加载的 Spark 上下文启动 R,bin/pyspark使用预加载的 Spark 上下文和 Spark 驱动程序启动 Python shell。可以使用PYSPARK_PYTHON环境变量指向特定的 Python 版本:

[akozlov@Alexanders-MacBook-Pro spark-1.6.1-bin-hadoop2.6]$ export PYSPARK_PYTHON=/usr/local/bin/python
[akozlov@Alexanders-MacBook-Pro spark-1.6.1-bin-hadoop2.6]$ bin/pyspark 
Python 2.7.11 (default, Jan 23 2016, 20:14:24) 
[GCC 4.2.1 Compatible Apple LLVM 7.0.2 (clang-700.1.81)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
Welcome to
 ____              __
 / __/__  ___ _____/ /__
 _\ \/ _ \/ _ `/ __/  '_/
 /__ / .__/\_,_/_/ /_/\_\   version 1.6.1
 /_/

Using Python version 2.7.11 (default, Jan 23 2016 20:14:24)
SparkContext available as sc, HiveContext available as sqlContext.
>>>

PySpark 直接支持 Spark RDDs 上的大多数 MLlib 功能(spark.apache.org/docs/latest/api/python),但它已知落后 Scala API 几个版本(spark.apache.org/docs/latest/api/python)。截至 1.6.0+ 版本,它还支持 DataFrames(spark.apache.org/docs/latest/sql-programming-guide.html):

>>> sfoFlights = sqlContext.sql("SELECT Dest, UniqueCarrier, ArrDelayMinutes FROM parquet.parquet")
>>> sfoFlights.groupBy(["Dest", "UniqueCarrier"]).agg(func.avg("ArrDelayMinutes"), func.count("ArrDelayMinutes")).sort("avg(ArrDelayMinutes)", ascending=False).head(5)
[Row(Dest=u'HNL', UniqueCarrier=u'HA', avg(ArrDelayMinutes)=53.70967741935484, count(ArrDelayMinutes)=31), Row(Dest=u'IAH', UniqueCarrier=u'F9', avg(ArrDelayMinutes)=43.064516129032256, count(ArrDelayMinutes)=31), Row(Dest=u'LAX', UniqueCarrier=u'DL', avg(ArrDelayMinutes)=39.68691588785047, count(ArrDelayMinutes)=214), Row(Dest=u'LAX', UniqueCarrier=u'WN', avg(ArrDelayMinutes)=29.704453441295545, count(ArrDelayMinutes)=247), Row(Dest=u'MSO', UniqueCarrier=u'OO', avg(ArrDelayMinutes)=29.551724137931036, count(ArrDelayMinutes)=29)]

从 Java/Scala 调用 Python

由于这本书实际上是关于 Scala 的,我们也应该提到,可以直接从 Scala(或 Java)调用 Python 代码及其解释器。本章将讨论一些可用的选项。

使用 sys.process._

Scala,以及 Java,可以通过启动一个单独的线程来调用操作系统进程,这我们在第一章探索性数据分析中已经使用过,.! 方法将启动进程并返回退出代码,而 .!! 将返回包含输出的字符串:

scala> import sys.process._
import sys.process._

scala> val retCode = Process(Seq("/usr/local/bin/python", "-c", "import socket; print(socket.gethostname())")).!
Alexanders-MacBook-Pro.local
retCode: Int = 0

scala> val lines = Process(Seq("/usr/local/bin/python", "-c", """from datetime import datetime, timedelta; print("Yesterday was {}".format(datetime.now()-timedelta(days=1)))""")).!!
lines: String =
"Yesterday was 2016-02-12 16:24:53.161853
"

让我们尝试一个更复杂的 SVD 计算(类似于我们在 SVD++ 推荐引擎中使用的,但这次,它在后端调用 BLAS C 库)。我创建了一个 Python 可执行文件,它接受一个表示矩阵的字符串和所需的秩作为输入,并输出具有提供秩的 SVD 近似:

#!/usr/bin/env python

import sys
import os
import re

import numpy as np
from scipy import linalg
from scipy.linalg import svd

np.set_printoptions(linewidth=10000)

def process_line(input):
    inp = input.rstrip("\r\n")
    if len(inp) > 1:
        try:
            (mat, rank) = inp.split("|")
            a = np.matrix(mat)
            r = int(rank)
        except:
            a = np.matrix(inp)
            r = 1
        U, s, Vh = linalg.svd(a, full_matrices=False)
        for i in xrange(r, s.size):
            s[i] = 0
        S = linalg.diagsvd(s, s.size, s.size)
        print(str(np.dot(U, np.dot(S, Vh))).replace(os.linesep, ";"))

if __name__ == '__main__':
    map(process_line, sys.stdin)

让我们将其命名为 svd.py 并将其放在当前目录中。给定一个矩阵和秩作为输入,它会产生给定秩的近似:

$ echo -e "1,2,3;2,1,2;3,2,1;7,8,9|3" | ./svd.py
[[ 1\.  2\.  3.]; [ 2\.  1\.  2.]; [ 3\.  2\.  1.]; [ 7\.  8\.  9.]]

要从 Scala 中调用它,让我们在我们的 DSL 中定义以下 #<<< 方法:

scala> implicit class RunCommand(command: String) {
 |   def #<<< (input: String)(implicit buffer: StringBuilder) =  {
 |     val process = Process(command)
 |     val io = new ProcessIO (
 |       in  => { in.write(input getBytes "UTF-8"); in.close},
 |       out => { buffer append scala.io.Source.fromInputStream(out).getLines.mkString("\n"); buffer.append("\n"); out.close() },
 |       err => { scala.io.Source.fromInputStream(err).getLines().foreach(System.err.println) })
 |     (process run io).exitValue
 |   }
 | }
defined class RunCommand

现在,我们可以使用 #<<< 操作符来调用 Python 的 SVD 方法:

scala> implicit val buffer = new StringBuilder()
buffer: StringBuilder =

scala> if ("./svd.py" #<<< "1,2,3;2,1,2;3,2,1;7,8,9|1" == 0)  Some(buffer.toString) else None
res77: Option[String] = Some([[ 1.84716691  2.02576751  2.29557674]; [ 1.48971176  1.63375041  1.85134741]; [ 1.71759947  1.88367234  2.13455611]; [ 7.19431647  7.88992728  8.94077601]])

注意,由于我们要求结果的矩阵秩为 1,所有行和列都是线性相关的。我们甚至可以一次传递多行输入,如下所示:

scala> if ("./svd.py" #<<< """
 | 1,2,3;2,1,2;3,2,1;7,8,9|0
 | 1,2,3;2,1,2;3,2,1;7,8,9|1
 | 1,2,3;2,1,2;3,2,1;7,8,9|2
 | 1,2,3;2,1,2;3,2,1;7,8,9|3""" == 0) Some(buffer.toString) else None
res80: Option[String] =
Some([[ 0\.  0\.  0.]; [ 0\.  0\.  0.]; [ 0\.  0\.  0.]; [ 0\.  0\.  0.]]
[[ 1.84716691  2.02576751  2.29557674]; [ 1.48971176  1.63375041  1.85134741]; [ 1.71759947  1.88367234  2.13455611]; [ 7.19431647  7.88992728  8.94077601]]
[[ 0.9905897   2.02161614  2.98849663]; [ 1.72361156  1.63488399  1.66213642]; [ 3.04783513  1.89011928  1.05847477]; [ 7.04822694  7.88921926  9.05895373]]
[[ 1\.  2\.  3.]; [ 2\.  1\.  2.]; [ 3\.  2\.  1.]; [ 7\.  8\.  9.]])

Spark 管道

SVD 分解通常是一个相当耗时的操作,因此在这种情况下调用 Python 的相对开销很小。如果我们保持进程运行并一次提供多行,就像我们在上一个例子中所做的那样,我们可以避免这种开销。Hadoop MR 和 Spark 都实现了这种方法。例如,在 Spark 中,整个计算只需一行,如下所示:

scala> sc.parallelize(List("1,2,3;2,1,2;3,2,1;7,8,9|0", "1,2,3;2,1,2;3,2,1;7,8,9|1", "1,2,3;2,1,2;3,2,1;7,8,9|2", "1,2,3;2,1,2;3,2,1;7,8,9|3"),4).pipe("./svd.py").collect.foreach(println)
[[ 0\.  0\.  0.]; [ 0\.  0\.  0.]; [ 0\.  0\.  0.]; [ 0\.  0\.  0.]]
[[ 1.84716691  2.02576751  2.29557674]; [ 1.48971176  1.63375041  1.85134741]; [ 1.71759947  1.88367234  2.13455611]; [ 7.19431647  7.88992728  8.94077601]]
[[ 0.9905897   2.02161614  2.98849663]; [ 1.72361156  1.63488399  1.66213642]; [ 3.04783513  1.89011928  1.05847477]; [ 7.04822694  7.88921926  9.05895373]]
[[ 1\.  2\.  3.]; [ 2\.  1\.  2.]; [ 3\.  2\.  1.]; [ 7\.  8\.  9.]]

整个管道已准备好在多核工作站集群中分发!我想你已经爱上 Scala/Spark 了。

注意,调试管道化执行可能很棘手,因为数据是通过操作系统管道从一个进程传递到另一个进程的。

Jython 和 JSR 223

为了完整性,我们需要提及 Jython,这是 Python 的 Java 实现(与更熟悉的 C 实现不同,也称为 CPython)。Jython 通过允许用户将 Python 源代码编译成 Java 字节码,并在任何 Java 虚拟机上运行这些字节码,避免了通过操作系统管道传递输入/输出的问题。由于 Scala 也在 Java 虚拟机上运行,它可以直接使用 Jython 类,尽管通常情况并非如此;Scala 类有时与 Java/Jython 不兼容。

注意

JSR 223

在这个特定案例中,请求的是“JavaTM 平台的脚本编程”功能,最初于 2004 年 11 月 15 日提出 (www.jcp.org/en/jsr/detail?id=223)。最初,它针对的是 Java servlet 与多种脚本语言协同工作的能力。规范要求脚本语言维护者提供包含相应实现的 Java JAR。可移植性问题阻碍了实际的应用,尤其是在需要与操作系统进行复杂交互的平台,如 C 或 Fortran 中的动态链接。目前,只有少数语言受到支持,R 和 Python 被支持,但形式并不完整。

自 Java 6 以来,JSR 223:Java 的脚本编程添加了 javax.script 包,它允许通过相同的 API 调用多种脚本语言,只要该语言提供脚本引擎。要添加 Jython 脚本语言,请从 Jython 网站下载最新的 Jython JAR 文件,网址为 www.jython.org/downloads.html

$ wget -O jython-standalone-2.7.0.jar http://search.maven.org/remotecontent?filepath=org/python/jython-standalone/2.7.0/jython-standalone-2.7.0.jar

[akozlov@Alexanders-MacBook-Pro Scala]$ scala -cp jython-standalone-2.7.0.jar 
Welcome to Scala version 2.11.7 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_40).
Type in expressions to have them evaluated.
Type :help for more information.

scala> import javax.script.ScriptEngine;
...
scala> import javax.script.ScriptEngineManager;
...
scala> import javax.script.ScriptException;
...
scala> val manager = new ScriptEngineManager();
manager: javax.script.ScriptEngineManager = javax.script.ScriptEngineManager@3a03464

scala> val engines = manager.getEngineFactories();
engines: java.util.List[javax.script.ScriptEngineFactory] = [org.python.jsr223.PyScriptEngineFactory@4909b8da, jdk.nashorn.api.scripting.NashornScriptEngineFactory@68837a77, scala.tools.nsc.interpreter.IMain$Factory@1324409e]

现在,我可以使用 Jython/Python 脚本引擎:

scala> val engine = new ScriptEngineManager().getEngineByName("jython");
engine: javax.script.ScriptEngine = org.python.jsr223.PyScriptEngine@6094de13

scala> engine.eval("from datetime import datetime, timedelta; yesterday = str(datetime.now()-timedelta(days=1))")
res15: Object = null

scala> engine.get("yesterday")
res16: Object = 2016-02-12 23:26:38.012000

在这里值得声明的是,并非所有 Python 模块都在 Jython 中可用。需要 C/Fortran 动态链接库的模块,而这些库在 Java 中不存在,在 Jython 中可能无法工作。具体来说,NumPy 和 SciPy 在 Jython 中不受支持,因为它们依赖于 C/Fortran。如果你发现其他缺失的模块,可以尝试将 Python 发行版中的 .py 文件复制到 Jython 的 sys.path 目录中——如果这样做有效,那么你很幸运。

Jython 的优点是无需在每个调用时启动 Python 运行时即可访问丰富的 Python 模块,这可能会带来显著的性能提升:

scala> val startTime = System.nanoTime
startTime: Long = 54384084381087

scala> for (i <- 1 to 100) {
 |   engine.eval("from datetime import datetime, timedelta; yesterday = str(datetime.now()-timedelta(days=1))")
 |   val yesterday = engine.get("yesterday")
 | }

scala> val elapsed = 1e-9 * (System.nanoTime - startTime)
elapsed: Double = 0.270837934

scala> val startTime = System.nanoTime
startTime: Long = 54391560460133

scala> for (i <- 1 to 100) {
 |   val yesterday = Process(Seq("/usr/local/bin/python", "-c", """from datetime import datetime, timedelta; print(datetime.now()-timedelta(days=1))""")).!!
 | }

scala> val elapsed = 1e-9 * (System.nanoTime - startTime)
elapsed: Double = 2.221937263

Jython JSR 223 调用速度快 10 倍!

摘要

R 和 Python 对于数据科学家来说就像面包和黄油一样。现代框架往往具有互操作性,并互相借鉴彼此的优势。在本章中,我介绍了 R 和 Python 之间互操作性的底层结构。它们都有流行的包(R)和模块(Python),这些扩展了当前的 Scala/Spark 功能。许多人认为 R 和 Python 的现有库对于它们的实现至关重要。

本章演示了几种集成这些包的方法,并提供了使用这些集成所带来的权衡,以便我们能够进入下一章,探讨 NLP,在 NLP 中,函数式编程从一开始就被传统地使用。

第九章:Scala 中的 NLP

本章描述了几种常见的自然语言处理NLP)技术,特别是那些可以从 Scala 中受益的技术。开源领域中有一些 NLP 包。其中最著名的是 NLTK (www.nltk.org),它是用 Python 编写的,并且可能还有更多强调 NLP 不同方面的专有软件解决方案。值得提及的是 Wolf (github.com/wolfe-pack)、FACTORIE (factorie.cs.umass.edu)、ScalaNLP (www.scalanlp.org)和 skymind (www.skymind.io),其中 skymind 部分是专有的。然而,由于一个或多个原因,这个领域的许多开源项目在一段时间内都保持活跃。大多数项目正被 Spark 和 MLlib 的能力所取代,尤其是在可扩展性方面。

我不会详细描述每个 NLP 项目,这些项目可能包括语音转文本、文本转语音和语言翻译,而是将在本章提供一些基本技术,专注于利用 Spark MLlib。这一章作为本书的最后一个分析章节,显得非常自然。Scala 是一种看起来非常自然语言的计算机语言,本章将利用我之前开发的技术。

NLP 可以说是 AI 的核心。最初,AI 被创造出来是为了模仿人类,而自然语言解析和理解是其不可或缺的一部分。大数据技术已经开始渗透 NLP,尽管传统上 NLP 非常计算密集,被视为小数据问题。NLP 通常需要广泛的深度学习技术,而所有书面文本的数据量似乎与今天所有机器生成的日志量以及大数据机器分析的数据量相比并不大。

尽管国会图书馆拥有数百万份文件,但其中大部分可以以 PB(实际数字数据量)为单位进行数字化,这是一个任何社交网站都能在几秒钟内收集、存储和分析的量级。大多数多产作者的完整作品可以存储在几 MB 的文件中(参考 表 09-1)。然而,社交网络和 ADTECH 公司每天都会从数百万用户和数百个上下文中解析文本。

完整作品 生活时期 大小
柏拉图 428/427 (或 424/423) - 348/347 BC 2.1 MB
威廉·莎士比亚 1564 年 4 月 26 日(洗礼)- 1616 年 4 月 23 日 3.8 MB
费奥多尔·陀思妥耶夫斯基 1821 年 11 月 11 日 - 1881 年 2 月 9 日 5.9 MB
列夫·托尔斯泰 1828 年 9 月 9 日 - 1910 年 11 月 20 日 6.9 MB
马克·吐温 1835 年 11 月 30 日 - 1910 年 4 月 21 日 13 MB

表 09-1. 一些著名作家的全集(大多数现在可以在 Amazon.com 上以几美元的价格购买,后来的作者,尽管已经数字化,但价格更高)

自然语言是一个动态的概念,随着时间的推移、技术和几代人的变化而变化。我们看到了表情符号、三字母缩写等的出现。外语往往相互借鉴;描述这个动态生态系统本身就是一项挑战。

如前几章所述,我将专注于如何使用 Scala 作为工具来编排语言分析,而不是在 Scala 中重写工具。由于这个主题非常广泛,我无法声称在这里涵盖 NLP 的所有方面。

在本章中,我们将涵盖以下主题:

  • 以文本处理流程和阶段为例,讨论自然语言处理(NLP)

  • 从词袋的角度学习简单的文本分析方法

  • 了解词频逆文档频率TF-IDF)技术,它超越了简单的词袋分析,并且在信息检索IR)中实际上是标准技术

  • 潜在狄利克雷分配LDA)方法为例,了解文档聚类

  • 使用基于 word2vec n-gram 算法进行语义分析

文本分析流程

在我们继续详细算法之前,让我们看看图 9-1中描述的通用文本处理流程。在文本分析中,输入通常以字符流的形式呈现(具体取决于特定的语言)。

词汇分析涉及将这个流分解成一系列单词(或语言学分析中的词素)。通常它也被称为分词(而单词被称为标记)。ANother Tool for Language RecognitionANTLR)(www.antlr.org/)和 Flex (flex.sourceforge.net)可能是开源社区中最著名的。词汇歧义的一个经典例子是词汇歧义。例如,在短语I saw a bat.中,bat可以指动物或棒球棒。我们通常需要上下文来弄清楚这一点,我们将在下一节讨论:

文本分析流程

图 9-1. NLP 过程的典型阶段。

句法分析,或称为解析,传统上处理的是将文本结构与语法规则相匹配。这对于不允许任何歧义的计算机语言来说相对更重要。在自然语言中,这个过程通常被称为分块和标记。在许多情况下,人类语言中单词的意义可能受语境、语调,甚至肢体语言或面部表情的影响。与大数据方法相比,大数据方法中数据的量胜过复杂性,这种分析的价值仍然是一个有争议的话题——后者之一是 word2vec 方法,稍后将进行描述。

语义分析是从句法结构中提取语言无关意义的过程。在尽可能的范围内,它还涉及去除特定于特定文化和语言背景的特征,到这种项目可能实现的程度。这一阶段的歧义来源包括:短语附着、连词、名词组结构、语义歧义、指代非字面言语等。再次强调,word2vec 部分解决了这些问题。

揭示整合部分解决了上下文的问题:一个句子或成语的意义可能取决于之前的句子或段落。句法分析和文化背景在这里起着重要作用。

最后,实用分析是试图根据意图重新解释所说内容的另一层复杂性。这如何改变世界的状态?它是否可行?

简单文本分析

文档的直接表示是一个词袋。Scala 和 Spark 提供了一个出色的范例来对词分布进行分析。首先,我们读取整个文本集合,然后计算独特的单词数量:

$ bin/spark-shell 
Welcome to
 ____              __
 / __/__  ___ _____/ /__
 _\ \/ _ \/ _ `/ __/  ''_/
 /___/ .__/\_,_/_/ /_/\_\   version 1.6.1
 /_/

Using Scala version 2.10.5 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_40)
Type in expressions to have them evaluated.
Type :help for more information.
Spark context available as sc.
SQL context available as sqlContext.

scala> val leotolstoy = sc.textFile("leotolstoy").cache
leotolstoy: org.apache.spark.rdd.RDD[String] = leotolstoy MapPartitionsRDD[1] at textFile at <console>:27

scala> leotolstoy.flatMap(_.split("\\W+")).count
res1: Long = 1318234 

scala> val shakespeare = sc.textFile("shakespeare").cache
shakespeare: org.apache.spark.rdd.RDD[String] = shakespeare MapPartitionsRDD[7] at textFile at <console>:27

scala> shakespeare.flatMap(_.split("\\W+")).count
res2: Long = 1051958

这仅仅为我们提供了一个关于不同作者词汇库中不同单词数量的估计。找到两个语料库之间交集的最简单方法就是找到共同的词汇(由于列夫·托尔斯泰用俄语和法语写作,而莎士比亚是一位英语作家,所以这将非常不同):

scala> :silent

scala> val shakespeareBag = shakespeare.flatMap(_.split("\\W+")).map(_.toLowerCase).distinct

scala> val leotolstoyBag = leotolstoy.flatMap(_.split("\\W+")).map(_.toLowerCase).distinct
leotolstoyBag: org.apache.spark.rdd.RDD[String] = MapPartitionsRDD[27] at map at <console>:29

scala> println("The bags intersection is " + leotolstoyBag.intersection(shakespeareBag).count)
The bags intersection is 11552

几千个单词索引在当前实现中是可管理的。对于任何新的故事,我们可以确定它更有可能是由列夫·托尔斯泰还是威廉·莎士比亚所写。让我们看看《圣经的詹姆斯国王版》,它也可以从古腾堡计划下载(www.gutenberg.org/files/10/10-h/10-h.htm):

$ (mkdir bible; cd bible; wget http://www.gutenberg.org/cache/epub/10/pg10.txt)

scala> val bible = sc.textFile("bible").cache

scala> val bibleBag = bible.flatMap(_.split("\\W+")).map(_.toLowerCase).distinct

scala>:silent

scala> bibleBag.intersection(shakespeareBag).count
res5: Long = 7250

scala> bibleBag.intersection(leotolstoyBag).count
res24: Long = 6611

这似乎是有道理的,因为在莎士比亚时代,宗教语言很流行。另一方面,安东·契诃夫的戏剧与列夫·托尔斯泰的词汇有更大的交集:

$ (mkdir chekhov; cd chekhov;
 wget http://www.gutenberg.org/cache/epub/7986/pg7986.txt
 wget http://www.gutenberg.org/cache/epub/1756/pg1756.txt
 wget http://www.gutenberg.org/cache/epub/1754/1754.txt
 wget http://www.gutenberg.org/cache/epub/13415/pg13415.txt)

scala> val chekhov = sc.textFile("chekhov").cache
chekhov: org.apache.spark.rdd.RDD[String] = chekhov MapPartitionsRDD[61] at textFile at <console>:27

scala> val chekhovBag = chekhov.flatMap(_.split("\\W+")).map(_.toLowerCase).distinct
chekhovBag: org.apache.spark.rdd.RDD[String] = MapPartitionsRDD[66] at distinct at <console>:29

scala> chekhovBag.intersection(leotolstoyBag).count
res8: Long = 8263

scala> chekhovBag.intersection(shakespeareBag).count
res9: Long = 6457 

这是一个非常简单但有效的方法,但我们还可以进行一些常见的改进。首先,一个常见的技巧是对单词进行词干提取。在许多语言中,单词有一个共同的组成部分,通常称为词根,以及一个可变的前缀或后缀,这可能会根据上下文、性别、时间等因素而变化。词干提取是通过将这种灵活的词形近似到词根、词干或一般形式来提高独特计数和交集的过程。词干形式不需要与单词的形态学词根完全相同,通常只要相关单词映射到相同的词干就足够了,即使这个词干本身不是一个有效的语法词根。其次,我们可能应该考虑单词的频率——虽然我们将在下一节中描述更复杂的方法,但为了这个练习的目的,我们将排除计数非常高的单词,这些单词通常在任何文档中都存在,如文章和所有格代词,这些通常被称为停用词,以及计数非常低的单词。具体来说,我将使用我详细描述在章节末尾的优化版Porter 词干提取器实现。

注意

tartarus.org/martin/PorterStemmer/网站包含了一些 Porter 词干提取器的 Scala 和其他语言的实现,包括高度优化的 ANSI C,这可能更有效率,但在这里我将提供另一个优化的 Scala 版本,它可以立即与 Spark 一起使用。

词干提取示例将对单词进行词干提取,并计算它们之间的相对交集,同时移除停用词:

def main(args: Array[String]) {

    val stemmer = new Stemmer

    val conf = new SparkConf().
      setAppName("Stemmer").
      setMaster(args(0))

    val sc = new SparkContext(conf)

    val stopwords = scala.collection.immutable.TreeSet(
      "", "i", "a", "an", "and", "are", "as", "at", "be", "but", "by", "for", "from", "had", "has", "he", "her", "him", "his", "in", "is", "it", "its", "my", "not", "of", "on", "she", "that", "the", "to", "was", "were", "will", "with", "you"
    ) map { stemmer.stem(_) }

    val bags = for (name <- args.slice(1, args.length)) yield {
      val rdd = sc.textFile(name).map(_.toLowerCase)
      if (name == "nytimes" || name == "nips" || name == "enron")
        rdd.filter(!_.startsWith("zzz_")).flatMap(_.split("_")).map(stemmer.stem(_)).distinct.filter(!stopwords.contains(_)).cache
      else {
        val withCounts = rdd.flatMap(_.split("\\W+")).map(stemmer.stem(_)).filter(!stopwords.contains(_)).map((_, 1)).reduceByKey(_+_)
        val minCount = scala.math.max(1L, 0.0001 * withCounts.count.toLong)
        withCounts.filter(_._2 > minCount).map(_._1).cache
      }
    }

    val cntRoots = (0 until { args.length - 1 }).map(i => Math.sqrt(bags(i).count.toDouble))

    for(l <- 0 until { args.length - 1 }; r <- l until { args.length - 1 }) {
      val cnt = bags(l).intersection(bags(r)).count
      println("The intersect " + args(l+1) + " x " + args(r+1) + " is: " + cnt + " (" + (cnt.toDouble / cntRoots(l) / cntRoots(r)) + ")")
    }

    sc.stop
    }
}

当从命令行运行主类示例时,它将输出指定为参数的数据集的词干包大小和交集(这些是主文件系统中的目录,包含文档):

$ sbt "run-main org.akozlov.examples.Stemmer local[2] shakespeare leotolstoy chekhov nytimes nips enron bible"
[info] Loading project definition from /Users/akozlov/Src/Book/ml-in-scala/chapter09/project
[info] Set current project to NLP in Scala (in build file:/Users/akozlov/Src/Book/ml-in-scala/chapter09/)
[info] Running org.akozlov.examples.Stemmer local[2] shakespeare leotolstoy chekhov nytimes nips enron bible
The intersect shakespeare x shakespeare is: 10533 (1.0)
The intersect shakespeare x leotolstoy is: 5834 (0.5293670391596142)
The intersect shakespeare x chekhov is: 3295 (0.4715281914492153)
The intersect shakespeare x nytimes is: 7207 (0.4163369701270161)
The intersect shakespeare x nips is: 2726 (0.27457329089479504)
The intersect shakespeare x enron is: 5217 (0.34431535832271265)
The intersect shakespeare x bible is: 3826 (0.45171392986714726)
The intersect leotolstoy x leotolstoy is: 11531 (0.9999999999999999)
The intersect leotolstoy x chekhov is: 4099 (0.5606253333241973)
The intersect leotolstoy x nytimes is: 8657 (0.47796976891152176)
The intersect leotolstoy x nips is: 3231 (0.3110369262979765)
The intersect leotolstoy x enron is: 6076 (0.38326210407266764)
The intersect leotolstoy x bible is: 3455 (0.3898604013063757)
The intersect chekhov x chekhov is: 4636 (1.0)
The intersect chekhov x nytimes is: 3843 (0.33463022711780555)
The intersect chekhov x nips is: 1889 (0.28679311682962116)
The intersect chekhov x enron is: 3213 (0.31963226496874225)
The intersect chekhov x bible is: 2282 (0.40610513998395287)
The intersect nytimes x nytimes is: 28449 (1.0)
The intersect nytimes x nips is: 4954 (0.30362042173997206)
The intersect nytimes x enron is: 11273 (0.45270741164576034)
The intersect nytimes x bible is: 3655 (0.2625720159205085)
The intersect nips x nips is: 9358 (1.0000000000000002)
The intersect nips x enron is: 4888 (0.3422561629856124)
The intersect nips x bible is: 1615 (0.20229053645165143)
The intersect enron x enron is: 21796 (1.0)
The intersect enron x bible is: 2895 (0.23760453654690084)
The intersect bible x bible is: 6811 (1.0)
[success] Total time: 12 s, completed May 17, 2016 11:00:38 PM

在这个例子中,这仅仅证实了圣经的词汇比列奥·托尔斯泰和其他来源更接近威廉·莎士比亚的假设。有趣的是,现代的《纽约时报》文章和上一章中提到的安然电子邮件的词汇与列奥·托尔斯泰的词汇非常接近,这可能是翻译质量的更好指示。

另一点要注意的是,这个相当复杂的分析大约需要 40 行 Scala 代码(不包括库,特别是 Porter 词干提取器,大约有 100 行)和大约 12 秒。Scala 的强大之处在于它可以非常有效地利用其他库来编写简洁的代码。

注意

序列化

我们已经在第六章处理非结构化数据中讨论了序列化。由于 Spark 的任务在不同的线程和潜在的 JVM 中执行,Spark 在传递对象时进行了大量的序列化和反序列化。潜在地,我可以用map { val stemmer = new Stemmer; stemmer.stem(_) }代替map { stemmer.stem(_) },但后者在多次迭代中重用对象,在语言上似乎更吸引人。一种建议的性能优化是使用Kryo 序列化器,它比 Java 序列化器灵活性更低,但性能更好。然而,为了集成目的,只需使管道中的每个对象可序列化并使用默认的 Java 序列化就更容易了。

作为另一个例子,让我们计算单词频率的分布,如下所示:

scala> val bags = for (name <- List("shakespeare", "leotolstoy", "chekhov", "nytimes", "enron", "bible")) yield {
 |     sc textFile(name) flatMap { _.split("\\W+") } map { _.toLowerCase } map { stemmer.stem(_) } filter { ! stopwords.contains(_) } cache()
 | }
bags: List[org.apache.spark.rdd.RDD[String]] = List(MapPartitionsRDD[93] at filter at <console>:36, MapPartitionsRDD[98] at filter at <console>:36, MapPartitionsRDD[103] at filter at <console>:36, MapPartitionsRDD[108] at filter at <console>:36, MapPartitionsRDD[113] at filter at <console>:36, MapPartitionsRDD[118] at filter at <console>:36)

scala> bags reduceLeft { (a, b) => a.union(b) } map { (_, 1) } reduceByKey { _+_ } collect() sortBy(- _._2) map { x => scala.math.log(x._2) }
res18: Array[Double] = Array(10.27759958298627, 10.1152465449837, 10.058652004037477, 10.046635061754612, 9.999615579630348, 9.855399641729074, 9.834405391348684, 9.801233318497372, 9.792667717430884, 9.76347807952779, 9.742496866444002, 9.655474810542554, 9.630365631415676, 9.623244409181346, 9.593355351246755, 9.517604459155686, 9.515837804297965, 9.47231994707559, 9.45930760329985, 9.441531454869693, 9.435561763085358, 9.426257878198653, 9.378985497953893, 9.355997944398545, 9.34862295977619, 9.300820725104558, 9.25569607369698, 9.25320827220336, 9.229162126216771, 9.20391980417326, 9.19917830726999, 9.167224080902555, 9.153875834995056, 9.137877200242468, 9.129889247578555, 9.090430075303626, 9.090091799380007, 9.083075020930307, 9.077722847361343, 9.070273383079064, 9.0542711863262...
...

在以下图表中展示了对数-对数尺度上的相对频率分布。除了前几个标记之外,频率对排名的依赖性几乎是线性的:

简单的文本分析

图 9-2. 对数-对数尺度上单词相对频率的典型分布(Zipf 定律)

Spark 中的 MLlib 算法

让我们在 MLlib 上停下来,MLlib 是 Scala 编写的其他 NLP 库的补充。MLlib 之所以重要,主要是因为其可扩展性,因此支持一些数据准备和文本处理算法,尤其是在特征构造领域(spark.apache.org/docs/latest/ml-features.html)。

TF-IDF

尽管前面的分析已经可以提供强大的洞察力,但分析中缺失的信息是术语频率信息。在信息检索中,术语频率相对更重要,因为需要根据一些术语对文档集合进行搜索和排序。通常将顶级文档返回给用户。

TF-IDF 是一种标准技术,其中术语频率被语料库中术语的频率所抵消。Spark 实现了 TF-IDF。Spark 使用哈希函数来识别术语。这种方法避免了计算全局术语到索引映射的需要,但可能会受到潜在的哈希冲突的影响,其概率由哈希表的桶数决定。默认特征维度是2²⁰=1,048,576

在 Spark 实现中,数据集中的每一行都是一个文档。我们可以将其转换为可迭代的 RDD,并使用以下代码进行哈希计算:

scala> import org.apache.spark.mllib.feature.HashingTF
import org.apache.spark.mllib.feature.HashingTF

scala> import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.linalg.Vector

scala> val hashingTF = new HashingTF
hashingTF: org.apache.spark.mllib.feature.HashingTF = org.apache.spark.mllib.feature.HashingTF@61b975f7

scala> val documents: RDD[Seq[String]] = sc.textFile("shakepeare").map(_.split("\\W+").toSeq)
documents: org.apache.spark.rdd.RDD[Seq[String]] = MapPartitionsRDD[263] at map at <console>:34

scala> val tf = hashingTF transform documents
tf: org.apache.spark.rdd.RDD[org.apache.spark.mllib.linalg.Vector] = MapPartitionsRDD[264] at map at HashingTF.scala:76

在计算hashingTF时,我们只需要对数据进行单次遍历,应用 IDF 需要两次遍历:首先计算 IDF 向量,然后使用 IDF 缩放术语频率:

scala> tf.cache
res26: tf.type = MapPartitionsRDD[268] at map at HashingTF.scala:76

scala> import org.apache.spark.mllib.feature.IDF
import org.apache.spark.mllib.feature.IDF

scala> val idf = new IDF(minDocFreq = 2) fit tf
idf: org.apache.spark.mllib.feature.IDFModel = org.apache.spark.mllib.feature.IDFModel@514bda2d

scala> val tfidf = idf transform tf
tfidf: org.apache.spark.rdd.RDD[org.apache.spark.mllib.linalg.Vector] = MapPartitionsRDD[272] at mapPartitions at IDF.scala:178

scala> tfidf take(10) foreach println
(1048576,[3159,3543,84049,582393,787662,838279,928610,961626,1021219,1021273],[3.9626355004005083,4.556357737874695,8.380602528651274,8.157736974683708,11.513471982269106,9.316247404932888,10.666174121881904,11.513471982269106,8.07948477778396,11.002646358503116])
(1048576,[267794,1021219],[8.783442874448122,8.07948477778396])
(1048576,[0],[0.5688129477150906])
(1048576,[3123,3370,3521,3543,96727,101577,114801,116103,497275,504006,508606,843002,962509,980206],[4.207164322003765,2.9674322162952897,4.125144122691999,2.2781788689373474,2.132236195047438,3.2951341639027754,1.9204575904855747,6.318664992090735,11.002646358503116,3.1043838099579815,5.451238364272918,11.002646358503116,8.43769700104158,10.30949917794317])
(1048576,[0,3371,3521,3555,27409,89087,104545,107877,552624,735790,910062,943655,962421],[0.5688129477150906,3.442878442319589,4.125144122691999,4.462482535201062,5.023254392629403,5.160262034409286,5.646060083831103,4.712188947797486,11.002646358503116,7.006282204641219,6.216822672821767,11.513471982269106,8.898512204232908])
(1048576,[3371,3543,82108,114801,149895,279256,582393,597025,838279,915181],[3.442878442319589,2.2781788689373474,6.017670811187438,3.8409151809711495,7.893585399642122,6.625632265652778,8.157736974683708,10.414859693600997,9.316247404932888,11.513471982269106])
(1048576,[3123,3555,413342,504006,690950,702035,980206],[4.207164322003765,4.462482535201062,3.4399651117812313,3.1043838099579815,11.513471982269106,11.002646358503116,10.30949917794317])
(1048576,[0],[0.5688129477150906])
(1048576,[97,1344,3370,100898,105489,508606,582393,736902,838279,1026302],[2.533299776544098,23.026943964538212,2.9674322162952897,0.0,11.225789909817326,5.451238364272918,8.157736974683708,10.30949917794317,9.316247404932888,11.513471982269106])
(1048576,[0,1344,3365,114801,327690,357319,413342,692611,867249,965170],[4.550503581720725,23.026943964538212,2.7455719545259836,1.9204575904855747,8.268278849083533,9.521041817578901,3.4399651117812313,0.0,6.661441718349489,0.0])

在这里,我们看到每个文档由一组术语及其分数表示。

LDA

Spark MLlib 中的 LDA 是一种聚类机制,其中特征向量表示文档中单词的计数。该模型最大化观察到的单词计数的概率,假设每个文档是主题的混合,文档中的单词是基于Dirichlet 分布(多项式情况下的 beta 分布的推广)独立地为每个主题生成的。目标是推导出(潜在)主题分布和单词生成统计模型的参数。

MLlib 的实现基于 2009 年的 LDA 论文(www.jmlr.org/papers/volume10/newman09a/newman09a.pdf),并使用 GraphX 实现一个分布式期望最大化(EM)算法,用于将主题分配给文档。

让我们以第七章中讨论的安然电子邮件语料库为例,使用图算法,在那里我们试图分析通信图。对于电子邮件聚类,我们需要提取电子邮件正文并将其作为单行放置在训练文件中:

$ mkdir enron
$ cat /dev/null > enron/all.txt
$ for f in $(find maildir -name \*\. -print); do cat $f | sed '1,/^$/d;/^$/d' | tr "\n\r" "  " >> enron/all.txt; echo "" >> enron/all.txt; done
$

现在,让我们使用 Scala/Spark 构建一个包含文档 ID 的语料库数据集,后面跟着一个密集的单词计数数组:

$ spark-shell --driver-memory 8g --executor-memory 8g --packages com.github.fommil.netlib:all:1.1.2
Ivy Default Cache set to: /home/alex/.ivy2/cache
The jars for the packages stored in: /home/alex/.ivy2/jars
:: loading settings :: url = jar:file:/opt/cloudera/parcels/CDH-5.5.2-1.cdh5.5.2.p0.4/jars/spark-assembly-1.5.0-cdh5.5.2-hadoop2.6.0-cdh5.5.2.jar!/org/apache/ivy/core/settings/ivysettings.xml
com.github.fommil.netlib#all added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent;1.0
 confs: [default]
 found com.github.fommil.netlib#all;1.1.2 in central
 found net.sourceforge.f2j#arpack_combined_all;0.1 in central
 found com.github.fommil.netlib#core;1.1.2 in central
 found com.github.fommil.netlib#netlib-native_ref-osx-x86_64;1.1 in central
 found com.github.fommil.netlib#native_ref-java;1.1 in central
 found com.github.fommil#jniloader;1.1 in central
 found com.github.fommil.netlib#netlib-native_ref-linux-x86_64;1.1 in central
 found com.github.fommil.netlib#netlib-native_ref-linux-i686;1.1 in central
 found com.github.fommil.netlib#netlib-native_ref-win-x86_64;1.1 in central
 found com.github.fommil.netlib#netlib-native_ref-win-i686;1.1 in central
 found com.github.fommil.netlib#netlib-native_ref-linux-armhf;1.1 in central
 found com.github.fommil.netlib#netlib-native_system-osx-x86_64;1.1 in central
 found com.github.fommil.netlib#native_system-java;1.1 in central
 found com.github.fommil.netlib#netlib-native_system-linux-x86_64;1.1 in central
 found com.github.fommil.netlib#netlib-native_system-linux-i686;1.1 in central
 found com.github.fommil.netlib#netlib-native_system-linux-armhf;1.1 in central
 found com.github.fommil.netlib#netlib-native_system-win-x86_64;1.1 in central
 found com.github.fommil.netlib#netlib-native_system-win-i686;1.1 in central
downloading https://repo1.maven.org/maven2/net/sourceforge/f2j/arpack_combined_all/0.1/arpack_combined_all-0.1-javadoc.jar ...
 [SUCCESSFUL ] net.sourceforge.f2j#arpack_combined_all;0.1!arpack_combined_all.jar (513ms)
downloading https://repo1.maven.org/maven2/com/github/fommil/netlib/core/1.1.2/core-1.1.2.jar ...
 [SUCCESSFUL ] com.github.fommil.netlib#core;1.1.2!core.jar (18ms)
downloading https://repo1.maven.org/maven2/com/github/fommil/netlib/netlib-native_ref-osx-x86_64/1.1/netlib-native_ref-osx-x86_64-1.1-natives.jar ...
 [SUCCESSFUL ] com.github.fommil.netlib#netlib-native_ref-osx-x86_64;1.1!netlib-native_ref-osx-x86_64.jar (167ms)
downloading https://repo1.maven.org/maven2/com/github/fommil/netlib/netlib-native_ref-linux-x86_64/1.1/netlib-native_ref-linux-x86_64-1.1-natives.jar ...
 [SUCCESSFUL ] com.github.fommil.netlib#netlib-native_ref-linux-x86_64;1.1!netlib-native_ref-linux-x86_64.jar (159ms)
downloading https://repo1.maven.org/maven2/com/github/fommil/netlib/netlib-native_ref-linux-i686/1.1/netlib-native_ref-linux-i686-1.1-natives.jar ...
 [SUCCESSFUL ] com.github.fommil.netlib#netlib-native_ref-linux-i686;1.1!netlib-native_ref-linux-i686.jar (131ms)
downloading https://repo1.maven.org/maven2/com/github/fommil/netlib/netlib-native_ref-win-x86_64/1.1/netlib-native_ref-win-x86_64-1.1-natives.jar ...
 [SUCCESSFUL ] com.github.fommil.netlib#netlib-native_ref-win-x86_64;1.1!netlib-native_ref-win-x86_64.jar (210ms)
downloading https://repo1.maven.org/maven2/com/github/fommil/netlib/netlib-native_ref-win-i686/1.1/netlib-native_ref-win-i686-1.1-natives.jar ...
 [SUCCESSFUL ] com.github.fommil.netlib#netlib-native_ref-win-i686;1.1!netlib-native_ref-win-i686.jar (167ms)
downloading https://repo1.maven.org/maven2/com/github/fommil/netlib/netlib-native_ref-linux-armhf/1.1/netlib-native_ref-linux-armhf-1.1-natives.jar ...
 [SUCCESSFUL ] com.github.fommil.netlib#netlib-native_ref-linux-armhf;1.1!netlib-native_ref-linux-armhf.jar (110ms)
downloading https://repo1.maven.org/maven2/com/github/fommil/netlib/netlib-native_system-osx-x86_64/1.1/netlib-native_system-osx-x86_64-1.1-natives.jar ...
 [SUCCESSFUL ] com.github.fommil.netlib#netlib-native_system-osx-x86_64;1.1!netlib-native_system-osx-x86_64.jar (54ms)
downloading https://repo1.maven.org/maven2/com/github/fommil/netlib/netlib-native_system-linux-x86_64/1.1/netlib-native_system-linux-x86_64-1.1-natives.jar ...
 [SUCCESSFUL ] com.github.fommil.netlib#netlib-native_system-linux-x86_64;1.1!netlib-native_system-linux-x86_64.jar (47ms)
downloading https://repo1.maven.org/maven2/com/github/fommil/netlib/netlib-native_system-linux-i686/1.1/netlib-native_system-linux-i686-1.1-natives.jar ...
 [SUCCESSFUL ] com.github.fommil.netlib#netlib-native_system-linux-i686;1.1!netlib-native_system-linux-i686.jar (44ms)
downloading https://repo1.maven.org/maven2/com/github/fommil/netlib/netlib-native_system-linux-armhf/1.1/netlib-native_system-linux-armhf-1.1-natives.jar ...
[SUCCESSFUL ] com.github.fommil.netlib#netlib-native_system-linux-armhf;1.1!netlib-native_system-linux-armhf.jar (35ms)
downloading https://repo1.maven.org/maven2/com/github/fommil/netlib/netlib-native_system-win-x86_64/1.1/netlib-native_system-win-x86_64-1.1-natives.jar ...
 [SUCCESSFUL ] com.github.fommil.netlib#netlib-native_system-win-x86_64;1.1!netlib-native_system-win-x86_64.jar (62ms)
downloading https://repo1.maven.org/maven2/com/github/fommil/netlib/netlib-native_system-win-i686/1.1/netlib-native_system-win-i686-1.1-natives.jar ...
 [SUCCESSFUL ] com.github.fommil.netlib#netlib-native_system-win-i686;1.1!netlib-native_system-win-i686.jar (55ms)
downloading https://repo1.maven.org/maven2/com/github/fommil/netlib/native_ref-java/1.1/native_ref-java-1.1.jar ...
 [SUCCESSFUL ] com.github.fommil.netlib#native_ref-java;1.1!native_ref-java.jar (24ms)
downloading https://repo1.maven.org/maven2/com/github/fommil/jniloader/1.1/jniloader-1.1.jar ...
 [SUCCESSFUL ] com.github.fommil#jniloader;1.1!jniloader.jar (3ms)
downloading https://repo1.maven.org/maven2/com/github/fommil/netlib/native_system-java/1.1/native_system-java-1.1.jar ...
 [SUCCESSFUL ] com.github.fommil.netlib#native_system-java;1.1!native_system-java.jar (7ms)
:: resolution report :: resolve 3366ms :: artifacts dl 1821ms
 :: modules in use:
 com.github.fommil#jniloader;1.1 from central in [default]
 com.github.fommil.netlib#all;1.1.2 from central in [default]
 com.github.fommil.netlib#core;1.1.2 from central in [default]
 com.github.fommil.netlib#native_ref-java;1.1 from central in [default]
 com.github.fommil.netlib#native_system-java;1.1 from central in [default]
 com.github.fommil.netlib#netlib-native_ref-linux-armhf;1.1 from central in [default]
 com.github.fommil.netlib#netlib-native_ref-linux-i686;1.1 from central in [default]
 com.github.fommil.netlib#netlib-native_ref-linux-x86_64;1.1 from central in [default]
 com.github.fommil.netlib#netlib-native_ref-osx-x86_64;1.1 from central in [default]
 com.github.fommil.netlib#netlib-native_ref-win-i686;1.1 from central in [default]
 com.github.fommil.netlib#netlib-native_ref-win-x86_64;1.1 from central in [default]
 com.github.fommil.netlib#netlib-native_system-linux-armhf;1.1 from central in [default]
 com.github.fommil.netlib#netlib-native_system-linux-i686;1.1 from central in [default]
 com.github.fommil.netlib#netlib-native_system-linux-x86_64;1.1 from central in [default]
 com.github.fommil.netlib#netlib-native_system-osx-x86_64;1.1 from central in [default]
 com.github.fommil.netlib#netlib-native_system-win-i686;1.1 from central in [default]
 com.github.fommil.netlib#netlib-native_system-win-x86_64;1.1 from central in [default]
 net.sourceforge.f2j#arpack_combined_all;0.1 from central in [default]
 :: evicted modules:
 com.github.fommil.netlib#core;1.1 by [com.github.fommil.netlib#core;1.1.2] in [default]
 --------------------------------------------------------------------
 |                  |            modules            ||   artifacts   |
 |       conf       | number| search|dwnlded|evicted|| number|dwnlded|
 ---------------------------------------------------------------------
 |      default     |   19  |   18  |   18  |   1   ||   17  |   17  |
 ---------------------------------------------------------------------
...
scala> val enron = sc textFile("enron")
enron: org.apache.spark.rdd.RDD[String] = MapPartitionsRDD[1] at textFile at <console>:21

scala> enron.flatMap(_.split("\\W+")).map(_.toLowerCase).distinct.count
res0: Long = 529199 

scala> val stopwords = scala.collection.immutable.TreeSet ("", "i", "a", "an", "and", "are", "as", "at", "be", "but", "by", "for", "from", "had", "has", "he", "her", "him", "his", "in", "is", "it", "its", "not", "of", "on", "she", "that", "the", "to", "was", "were", "will", "with", "you")
stopwords: scala.collection.immutable.TreeSet[String] = TreeSet(, a, an, and, are, as, at, be, but, by, for, from, had, has, he, her, him, his, i, in, is, it, its, not, of, on, she, that, the, to, was, were, will, with, you)
scala> 

scala> val terms = enron.flatMap(x => if (x.length < 8192) x.toLowerCase.split("\\W+") else Nil).filterNot(stopwords).map(_,1).reduceByKey(_+_).collect.sortBy(- _._2).slice(0, 1000).map(_._1)
terms: Array[String] = Array(enron, ect, com, this, hou, we, s, have, subject, or, 2001, if, your, pm, am, please, cc, 2000, e, any, me, 00, message, 1, corp, would, can, 10, our, all, sent, 2, mail, 11, re, thanks, original, know, 12, 713, http, may, t, do, 3, time, 01, ees, m, new, my, they, no, up, information, energy, us, gas, so, get, 5, about, there, need, what, call, out, 4, let, power, should, na, which, one, 02, also, been, www, other, 30, email, more, john, like, these, 03, mark, 04, attached, d, enron_development, their, see, 05, j, forwarded, market, some, agreement, 09, day, questions, meeting, 08, when, houston, doc, contact, company, 6, just, jeff, only, who, 8, fax, how, deal, could, 20, business, use, them, date, price, 06, week, here, net, 15, 9, 07, group, california,...
scala> def getBagCounts(bag: Seq[String]) = { for(term <- terms) yield { bag.count(_==term) } }
getBagCounts: (bag: Seq[String])Array[Int]

scala> import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.Vectors

scala> val corpus = enron.map(x => { if (x.length < 8192) Some(x.toLowerCase.split("\\W+").toSeq) else None } ).map(x => { Vectors.dense(getBagCounts(x.getOrElse(Nil)).map(_.toDouble).toArray) }).zipWithIndex.map(_.swap).cache
corpus: org.apache.spark.rdd.RDD[(Long, org.apache.spark.mllib.linalg.Vector)] = MapPartitionsRDD[14] at map at <console>:30

scala> import org.apache.spark.mllib.clustering.{LDA, DistributedLDAModel}
import org.apache.spark.mllib.clustering.{LDA, DistributedLDAModel}

scala> import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.Vectors

scala> val ldaModel = new LDA().setK(10).run(corpus)
...
scala> ldaModel.topicsMatrix.transpose
res2: org.apache.spark.mllib.linalg.Matrix = 
207683.78495933366  79745.88417942637   92118.63972404732   ... (1000 total)
35853.48027575886   4725.178508682296   111214.8860582083   ...
135755.75666585402  54736.471356209106  93289.65563593085   ...
39445.796099155996  6272.534431534215   34764.02707696523   ...
329786.21570967307  602782.9591026317   42212.22143362559   ...
62235.09960154089   12191.826543794878  59343.24100019015   ...
210049.59592560542  160538.9650732507   40034.69756641789   ...
53818.14660186875   6351.853448001488   125354.26708575874  ...
44133.150537842856  4342.697652158682   154382.95646078113  ...
90072.97362336674   21132.629704311104  93683.40795807641   ...

我们还可以按主题的相对重要性降序列出单词及其相关重要性:

scala> ldaModel.describeTopics foreach { x : (Array[Int], Array[Double]) => { print(x._1.slice(0,10).map(terms(_)).mkString(":")); print("-> "); print(x._2.slice(0,10).map(_.toFloat).mkString(":")); println } }
com:this:ect:or:if:s:hou:2001:00:we->0.054606363:0.024220783:0.02096761:0.013669214:0.0132700335:0.012969772:0.012623918:0.011363528:0.010114557:0.009587474
s:this:hou:your:2001:or:please:am:com:new->0.029883621:0.027119286:0.013396418:0.012856948:0.01218803:0.01124849:0.010425644:0.009812181:0.008742722:0.0070441025
com:this:s:ect:hou:or:2001:if:your:am->0.035424445:0.024343235:0.015182628:0.014283071:0.013619815:0.012251413:0.012221165:0.011411696:0.010284024:0.009559739
would:pm:cc:3:thanks:e:my:all:there:11->0.047611523:0.034175437:0.022914853:0.019933242:0.017208714:0.015393614:0.015366959:0.01393391:0.012577525:0.011743208
ect:com:we:can:they:03:if:also:00:this->0.13815293:0.0755843:0.065043546:0.015290086:0.0121941045:0.011561104:0.011326733:0.010967959:0.010653805:0.009674695
com:this:s:hou:or:2001:pm:your:if:cc->0.016605735:0.015834121:0.01289918:0.012708308:0.0125788655:0.011726159:0.011477625:0.010578845:0.010555539:0.009609056
com:ect:we:if:they:hou:s:00:2001:or->0.05537054:0.04231919:0.023271963:0.012856676:0.012689817:0.012186356:0.011350313:0.010887237:0.010778923:0.010662295
this:s:hou:com:your:2001:or:please:am:if->0.030830953:0.016557815:0.014236835:0.013236604:0.013107091:0.0126846135:0.012257128:0.010862533:0.01027849:0.008893094
this:s:or:pm:com:your:please:new:hou:2001->0.03981197:0.013273305:0.012872894:0.011672661:0.011380969:0.010689667:0.009650983:0.009605533:0.009535899:0.009165275
this:com:hou:s:or:2001:if:your:am:please->0.024562683:0.02361607:0.013770585:0.013601272:0.01269994:0.012360005:0.011348433:0.010228578:0.009619628:0.009347991

要找出每个主题的前置文档或每个文档的前置主题,我们需要将此模型转换为DistributedLDALocalLDAModel,它们扩展了LDAModel

scala> ldaModel.save(sc, "ldamodel")

scala> val sameModel = DistributedLDAModel.load(sc, "ldamode2l")

scala> sameModel.topDocumentsPerTopic(10) foreach { x : (Array[Long], Array[Double]) => { print(x._1.mkString(":")); print("-> "); print(x._2.map(_.toFloat).mkString(":")); println } }
59784:50745:52479:60441:58399:49202:64836:52490:67936:67938-> 0.97146696:0.9713364:0.9661418:0.9661132:0.95249915:0.9519995:0.94945914:0.94944507:0.8977366:0.8791358
233009:233844:233007:235307:233842:235306:235302:235293:233020:233857-> 0.9962034:0.9962034:0.9962034:0.9962034:0.9962034:0.99620336:0.9954057:0.9954057:0.9954057:0.9954057
14909:115602:14776:39025:115522:288507:4499:38955:15754:200876-> 0.83963907:0.83415157:0.8319566:0.8303818:0.8291597:0.8281472:0.82739806:0.8272517:0.82579833:0.8243338
237004:71818:124587:278308:278764:278950:233672:234490:126637:123664-> 0.99929106:0.9968135:0.9964454:0.99644524:0.996445:0.99644494:0.99644476:0.9964447:0.99644464:0.99644417
156466:82237:82252:82242:341376:82501:341367:340197:82212:82243-> 0.99716955:0.94635135:0.9431836:0.94241136:0.9421047:0.9410431:0.94075173:0.9406304:0.9402021:0.94014835
335708:336413:334075:419613:417327:418484:334157:335795:337573:334160-> 0.987011:0.98687994:0.9865438:0.96953565:0.96953565:0.96953565:0.9588571:0.95852506:0.95832515:0.9581657
243971:244119:228538:226696:224833:207609:144009:209548:143066:195299-> 0.7546907:0.7546907:0.59146744:0.59095955:0.59090924:0.45532238:0.45064417:0.44945204:0.4487876:0.44833568
242260:214359:126325:234126:123362:233304:235006:124195:107996:334829-> 0.89615464:0.8961442:0.8106028:0.8106027:0.8106023:0.8106023:0.8106021:0.8106019:0.76834095:0.7570231
209751:195546:201477:191758:211002:202325:197542:193691:199705:329052-> 0.913124:0.9130985:0.9130918:0.9130672:0.5525752:0.5524637:0.5524494:0.552405:0.55240136:0.5026157
153326:407544:407682:408098:157881:351230:343651:127848:98884:129351-> 0.97206575:0.97206575:0.97206575:0.97206575:0.97206575:0.9689198:0.968068:0.9659192:0.9657442:0.96553063

分割、标注和分块

当文本以数字形式呈现时,找到单词相对容易,因为我们可以在非单词字符上分割流。在口语语言分析中,这变得更加复杂。在这种情况下,分词器试图优化一个指标,例如,最小化词典中不同单词的数量以及短语的长短或复杂性(《Python 自然语言处理》,作者Steven Bird 等人O'Reilly Media Inc,2009 年)。

标注通常指的是词性标注。在英语中,这些是名词、代词、动词、形容词、副词、冠词、介词、连词和感叹词。例如,在短语我们看到了黄色的狗中,我们是代词,看到了是动词,the是冠词,yellow是形容词,dog是名词。

在某些语言中,分块和标注取决于上下文。例如,在中文中,爱国者字面意思是爱国家的人,可以指爱国家的人爱国家的人。在俄语中,??????? ?????? ??????????字面意思是执行不赦免,可以指执行,不赦免不执行,赦免。虽然在书面语言中,这可以通过逗号来消除歧义,但在口语中,这通常很难识别差异,尽管有时语调可以帮助正确地划分短语。

对于基于词袋中词频的技术,一些极其常见的单词,它们在帮助选择文档时价值不大,被明确地从词汇表中排除。这些单词被称为停用词。没有好的通用策略来确定停用词表,但在许多情况下,这是排除几乎出现在每份文档中的非常频繁的单词,这些单词对于分类或信息检索目的没有帮助来区分它们。

词性标注

词性标注以概率方式将每个词标注为它的语法功能——名词、动词、形容词等等。通常,词性标注作为句法和语义分析输入。让我们在 FACTORIE 工具包示例中演示词性标注,这是一个用 Scala 编写的软件库(factorie.cs.umass.edu)。首先,你需要从github.com/factorie/factorie.git下载二进制镜像或源文件并构建它:

$ git clone https://github.com/factorie/factorie.git
...
$ cd factorie
$ git checkout factorie_2.11-1.2
...
$ mvn package -Pnlp-jar-with-dependencies

在构建过程中,这还包括模型训练,以下命令将在port 3228上启动一个网络服务器:

$ $ bin/fac nlp --wsj-forward-pos --conll-chain-ner
java -Xmx6g -ea -Djava.awt.headless=true -Dfile.encoding=UTF-8 -server -classpath ./src/main/resources:./target/classes:./target/factorie_2.11-1.2-nlp-jar-with-dependencies.jar
found model
18232
Listening on port 3228
...

现在,所有流向port 3228的流量都将被解释(作为文本),输出将被分词和标注:

$ telnet localhost 3228
Trying ::1...
Connected to localhost.
Escape character is '^]'.
But I warn you, if you don't tell me that this means war, if you still try to defend the infamies and horrors perpetrated by that Antichrist--I really believe he is Antichrist--I will have nothing more to do with you and you are no longer my friend, no longer my 'faithful slave,' as you call yourself! But how do you do? I see I have frightened you--sit down and tell me all the news.

1  1  But  CC  O
2  2  I    PRP  O
3  3  warn    VBP  O
4  4  you    PRP  O
5  5  ,      O
6  6  if    IN  O
7  7  you    PRP  O
8  8  do    VBP  O
9  9  n't    RB  O
10  10  tell    VB  O
11  11  me    PRP  O
12  12  that    IN  O
13  13  this    DT  O
14  14  means    VBZ  O
15  15  war    NN  O
16  16  ,    ,  O
17  17  if    IN  O
18  18  you  PRP  O
19  19  still    RB  O
20  20  try    VBP  O
21  21  to    TO  O
22  22  defend    VB  O
23  23  the    DT  O
24  24  infamies    NNS  O
25  25  and    CC  O
26  26  horrors    NNS  O
27  27  perpetrated    VBN  O
28  28  by    IN  O
29  29  that    DT  O
30  30  Antichrist    NNP  O
31  31  --    :  O
32  1  I  PRP  O
33  2  really    RB  O
34  3  believe    VBP  O
35  4  he    PRP  O
36  5  is    VBZ  O
37  6  Antichrist    NNP  U-MISC
38  7  --    :  O
39  1  I    PRP  O
40  2  will    MD  O
41  3  have    VB  O
42  4  nothing    NN  O
43  5  more    JJR  O
44  6  to    TO  O
45  7  do    VB  O
46  8  with    IN  O
47  9  you    PRP  O
48  10  and    CC  O
49  11  you    PRP  O
50  12  are    VBP  O
51  13  no    RB  O
52  14  longer    RBR  O
53  15  my    PRP$  O
54  16  friend    NN  O
55  17  ,    ,  O
56  18  no    RB  O
57  19  longer    RB  O
58  20  my  PRP$  O
59  21  '    POS  O
60  22  faithful    NN  O
61  23  slave    NN  O
62  24  ,    ,  O
63  25  '    ''  O
64  26  as    IN  O
65  27  you    PRP  O
66  28  call    VBP  O
67  29  yourself    PRP  O
68  30  !    .  O
69  1  But    CC  O
70  2  how    WRB  O
71  3  do    VBP  O
72  4  you    PRP  O
73  5  do    VB  O
74  6  ?    .  O
75  1  I    PRP  O
76  2  see    VBP  O
77  3  I    PRP  O
78  4  have    VBP  O
79  5  frightened    VBN  O
80  6  you    PRP  O
81  7  --    :  O
82  8  sit    VB  O
83  9  down    RB  O
84  10  and    CC  O
85  11  tell    VB  O
86  12  me    PRP  O
87  13  all    DT  O
88  14  the    DT  O
89  15  news    NN  O
90  16  .    .  O

这种词性标注是一个单一路径的左右标注器,可以像流一样处理文本。内部,算法使用概率技术来找到最可能的分配。让我们也看看其他不使用语法分析但已被证明对语言理解和解释非常有用的技术。

使用 word2vec 查找词关系

Word2vec 是由谷歌的托马斯·米科尔洛夫(Tomas Mikolov)在 2012 年左右开发的。word2vec 背后的原始想法是通过以效率换取模型的复杂性来提高效率。word2vec 不是将文档表示为词袋,而是通过尝试分析 n-gram 或 skip-gram(一个包含潜在问题标记的周围标记集)来考虑每个词的上下文。单词及其上下文本身由浮点数/双精度浮点数数组表示使用 word2vec 查找词关系。目标函数是最大化对数似然:

使用 word2vec 查找词关系

其中:

使用 word2vec 查找词关系

通过选择最优的 使用 word2vec 找到单词关系 并获得全面的单词表示(也称为 映射优化)。基于余弦相似度度量(点积)找到相似单词 使用 word2vec 找到单词关系。Spark 实现使用层次 softmax,将计算条件概率的复杂性降低到 使用 word2vec 找到单词关系,或词汇大小 V 的对数,而不是 使用 word2vec 找到单词关系,或与 V 成正比。训练仍然是数据集大小的线性,但适用于大数据并行化技术。

Word2vec 通常是用来根据上下文预测最可能的单词,或者找到具有相似意义的相似单词(同义词)。以下代码在 列夫·托尔斯泰的《战争与和平》 上训练 word2vec 模型,并找到单词 circle 的同义词。我不得不通过运行 cat 2600.txt | tr "\n\r" " " > warandpeace.txt 命令将古腾堡的 《战争与和平》 表现转换为单行格式:

scala> val word2vec = new Word2Vec
word2vec: org.apache.spark.mllib.feature.Word2Vec = org.apache.spark.mllib.feature.Word2Vec@58bb4dd

scala> val model = word2vec.fit(sc.textFile("warandpeace").map(_.split("\\W+").toSeq)
model: org.apache.spark.mllib.feature.Word2VecModel = org.apache.spark.mllib.feature.Word2VecModel@6f61b9d7

scala> val synonyms = model.findSynonyms("life", 10)
synonyms: Array[(String, Double)] = Array((freedom,1.704344822168997), (universal,1.682276637692245), (conception,1.6776193389148586), (relation,1.6760497906519414), (humanity,1.67601036253831), (consists,1.6637604144872544), (recognition,1.6526169382380496), (subjection,1.6496559771230317), (activity,1.646671198014248), (astronomy,1.6444424059160712))

scala> synonyms foreach println
(freedom,1.704344822168997)
(universal,1.682276637692245)
(conception,1.6776193389148586)
(relation,1.6760497906519414)
(humanity,1.67601036253831)
(consists,1.6637604144872544)
(recognition,1.6526169382380496)
(subjection,1.6496559771230317)
(activity,1.646671198014248)
(astronomy,1.6444424059160712)

虽然在一般情况下,找到一个客观函数是困难的,并且 freedom 在英语同义词词典中没有被列为 life 的同义词,但结果确实是有意义的。

在 word2vec 模型中,每个单词都表示为一个双精度浮点数数组。另一个有趣的应用是找到关联 a 到 b 与 c 到 ? 相同 通过执行减法 vector(a) - vector(b) + vector(c)

scala> val a = model.getVectors.filter(_._1 == "monarchs").map(_._2).head
a: Array[Float] = Array(-0.0044642715, -0.0013227836, -0.011506443, 0.03691717, 0.020431392, 0.013427449, -0.0036369907, -0.013460356, -3.8938568E-4, 0.02432113, 0.014533845, 0.004130258, 0.00671316, -0.009344602, 0.006229065, -0.005442078, -0.0045390734, -0.0038824948, -6.5973646E-4, 0.021729799, -0.011289608, -0.0030690092, -0.011423801, 0.009100784, 0.011765533, 0.0069619063, 0.017540144, 0.011198071, 0.026103685, -0.017285397, 0.0045515243, -0.0044477824, -0.0074411617, -0.023975836, 0.011371289, -0.022625357, -2.6478301E-5, -0.010510282, 0.010622139, -0.009597833, 0.014937023, -0.01298345, 0.0016747514, 0.01172987, -0.001512275, 0.022340108, -0.009758578, -0.014942565, 0.0040697413, 0.0015349758, 0.010246878, 0.0021413323, 0.008739062, 0.007845526, 0.006857361, 0.01160148, 0.008595...
scala> val b = model.getVectors.filter(_._1 == "princess").map(_._2).head
b: Array[Float] = Array(0.13265875, -0.04882792, -0.08409957, -0.04067986, 0.009084379, 0.121674284, -0.11963971, 0.06699862, -0.20277102, 0.26296946, -0.058114383, 0.076021515, 0.06751665, -0.17419271, -0.089830205, 0.2463593, 0.062816426, -0.10538805, 0.062085453, -0.2483566, 0.03468293, 0.20642486, 0.3129267, -0.12418643, -0.12557726, 0.06725172, -0.03703333, -0.10810595, 0.06692443, -0.046484336, 0.2433963, -0.12762263, -0.18473054, -0.084376186, 0.0037174677, -0.0040220995, -0.3419341, -0.25928706, -0.054454487, 0.09521076, -0.041567303, -0.13727514, -0.04826158, 0.13326299, 0.16228828, 0.08495835, -0.18073058, -0.018380836, -0.15691829, 0.056539804, 0.13673553, -0.027935665, 0.081865616, 0.07029694, -0.041142456, 0.041359138, -0.2304657, -0.17088272, -0.14424285, -0.0030700471, -0...
scala> val c = model.getVectors.filter(_._1 == "individual").map(_._2).head
c: Array[Float] = Array(-0.0013353615, -0.01820516, 0.007949033, 0.05430816, -0.029520465, -0.030641818, -6.607431E-4, 0.026548808, 0.04784935, -0.006470232, 0.041406438, 0.06599842, 0.0074243015, 0.041538745, 0.0030222891, -0.003932073, -0.03154199, -0.028486902, 0.022139633, 0.05738223, -0.03890591, -0.06761177, 0.0055152955, -0.02480924, -0.053222697, -0.028698998, -0.005315235, 0.0582403, -0.0024816995, 0.031634405, -0.027884213, 6.0290704E-4, 1.9750209E-4, -0.05563172, 0.023785716, -0.037577976, 0.04134448, 0.0026664822, -0.019832063, -0.0011898747, 0.03160933, 0.031184288, 0.0025268437, -0.02718441, -0.07729341, -0.009460656, 0.005344515, -0.05110715, 0.018468754, 0.008984449, -0.0053139487, 0.0053904117, -0.01322933, -0.015247412, 0.009819351, 0.038043085, 0.044905875, 0.00402788...
scala> model.findSynonyms(new DenseVector((for(i <- 0 until 100) yield (a(i) - b(i) + c(i)).toDouble).toArray), 10) foreach println
(achievement,0.9432423663884002)
(uncertainty,0.9187759184842362)
(leader,0.9163721499105207)
(individual,0.9048367510621271)
(instead,0.8992079672038455)
(cannon,0.8947818781378154)
(arguments,0.8883634101905679)
(aims,0.8725107984356915)
(ants,0.8593842583047755)
(War,0.8530727227924755)

这可以用来在语言中找到关系。

Porter Stemmer 的代码实现

Porter Stemmer 首次在 20 世纪 80 年代开发,有许多实现方式。详细的步骤和原始参考可以在 tartarus.org/martin/PorterStemmer/def.txt 找到。它大致包括 6-9 步的词尾/结尾替换,其中一些取决于前缀或词根。我将提供一个与书籍代码仓库优化的 Scala 版本。例如,步骤 1 覆盖了大多数词干化情况,并包括 12 个替换:最后 8 个取决于音节数和词根中的元音存在:

  def step1(s: String) = {
    b = s
    // step 1a
    processSubList(List(("sses", "ss"), ("ies","i"),("ss","ss"), ("s", "")), _>=0)
    // step 1b
    if (!(replacer("eed", "ee", _>0)))
    {
      if ((vowelInStem("ed") && replacer("ed", "", _>=0)) || (vowelInStem("ing") && replacer("ing", "", _>=0)))
      {
        if (!processSubList(List(("at", "ate"), ("bl","ble"), ("iz","ize")), _>=0 ) )
        {
          // if this isn't done, then it gets more confusing.
          if (doublec() && b.last != 'l' && b.last != 's' && b.last != 'z') { b = b.substring(0, b.length - 1) }
          else
            if (calcM(b.length) == 1 && cvc("")) { b = b + "e" }
        }
      }
    }
    // step 1c
    (vowelInStem("y") && replacer("y", "i", _>=0))
    this
  }

完整的代码可以在 github.com/alexvk/ml-in-scala/blob/master/chapter09/src/main/scala/Stemmer.scala 找到。

摘要

在本章中,我描述了基本的 NLP 概念,并演示了一些基本技术。我希望展示,相当复杂的 NLP 概念可以用几行 Scala 代码表达和测试。这无疑是冰山一角,因为现在正在开发许多 NLP 技术,包括基于 GPU 的 CPU 内并行化的技术。(例如,参考github.com/dlwh/puck中的Puck)。我还介绍了主要的 Spark MLlib NLP 实现。

在下一章,也就是本书的最后一章,我将介绍系统和模型监控。

第十章:高级模型监控

尽管这是本书的最后一章,但在实际情况下,监控通常被视为一个事后考虑的问题,这实在是不幸。监控是任何长时间执行周期组件的重要部署组件,因此它是最终产品的一部分。监控可以显著提升产品体验,并定义未来的成功,因为它改善了问题诊断,并对于确定改进路径至关重要。

成功软件工程的一个基本原则是,在可能的情况下,将系统设计成针对个人使用,这一点完全适用于监控、诊断和调试——对于修复软件产品中现有问题的名称来说,这实在是一个相当不幸的名称。诊断和调试复杂系统,尤其是分布式系统,是困难的,因为事件往往可以任意交织,程序执行可能受到竞态条件的影响。尽管在分布式系统 devops 和可维护性领域有很多研究正在进行,但本章将探讨这一领域,并提供设计可维护的复杂分布式系统的指导原则。

首先,一个纯函数式方法,Scala 声称遵循这种方法,花费了大量时间避免副作用。虽然这个想法在许多方面都是有用的,但很难想象一个对现实世界没有影响的程序是有用的,数据驱动应用程序的整个理念就是要对业务运营产生积极影响,这是一个定义良好的副作用。

监控明显属于副作用类别。执行过程需要留下痕迹,以便用户日后可以解析,从而了解设计或实现出现偏差的地方。执行痕迹可以通过在控制台或文件(通常称为日志文件)中写入内容,或者返回一个包含程序执行痕迹和中间结果的对象来实现。后者实际上更符合函数式编程和单子哲学,对于分布式编程来说更为合适,但往往被忽视。这本来可以是一个有趣的研究课题,但遗憾的是空间有限,我不得不讨论当代系统中监控的实用方面,这些监控几乎总是通过日志来完成的。在每个调用中携带带有执行痕迹的对象的单子方法,无疑会增加进程间或机器间通信的开销,但可以节省大量时间来拼接不同的信息。

让我们列出每个人在需要找到代码中的错误时尝试的简单调试方法:

  • 分析程序输出,尤其是由简单的打印语句或内置的 logback、java.util.logging、log4j 或 slf4j 门面产生的日志

  • 连接(远程)调试器

  • 监控 CPU、磁盘 I/O、内存(以解决高级资源利用率问题)

大体上,所有这些方法在存在多线程或分布式系统的情况下都会失败——Scala 本身就是多线程的,而 Spark 本身就是分布式的。在多个节点上收集日志不可扩展(尽管存在一些成功的商业系统这样做)。由于安全和网络限制,远程调试并不总是可能的。远程调试也可能引起大量的开销,并干扰程序执行,尤其是对于使用同步的程序。将调试级别设置为 DEBUGTRACE 级别有时有帮助,但将你置于开发者手中,开发者可能或可能没有考虑到你当前正在处理的特定角落情况。本书采取的方法是打开一个包含足够信息的 servlet,以便实时了解程序执行和应用方法,尽可能多地在 Scala 和 Scalatra 的当前状态下做到这一点。

关于程序执行调试的总体问题已经说得够多了。监控有所不同,因为它只关注高级问题识别。与问题调查或解决相交,但通常发生在监控之外。在本章中,我们将涵盖以下主题:

  • 理解监控的主要领域和监控目标

  • 学习用于 Scala/Java 监控的操作系统工具,以支持问题识别和调试

  • 了解 MBeans 和 MXBeans

  • 理解模型性能漂移

  • 理解 A/B 测试

系统监控

尽管存在其他类型的监控,专门针对机器学习任务,例如监控模型的性能,但让我从基本的系统监控开始。传统上,系统监控是操作系统维护的一个主题,但它正成为任何复杂应用的一个关键组成部分,尤其是在多个分布式工作站上运行的应用。操作系统的核心组件包括 CPU、磁盘、内存、网络以及电池供电机器上的能源。以下表格中提供了传统的类似操作系统的监控工具。我们将它们限制为 Linux 工具,因为这是大多数 Scala 应用的平台,尽管其他操作系统供应商提供了诸如活动监视器之类的操作系统监控工具。由于 Scala 运行在 Java JVM 上,我还添加了针对 JVM 的特定 Java 监控工具:

区域 项目 备注
CPU htop, top, sar-u top一直是使用最频繁的性能诊断工具,因为 CPU 和内存一直是受限制的资源。随着分布式编程的出现,网络和磁盘往往成为受限制的资源。
磁盘 iostat, sar -d, lsof lsof提供的打开文件数量通常是一个限制性资源,因为许多大数据应用和守护进程倾向于保持多个文件打开。
内存 top, free, vmstat, sar -r 内存以多种方式被操作系统使用,例如维护磁盘 I/O 缓冲区,因此额外的缓冲和缓存内存有助于性能。
网络 ifconfig, netstat, tcpdump, nettop, iftop, nmap 网络是分布式系统之间通信的方式,是重要的操作系统组件。从应用的角度来看,关注错误、冲突和丢失的数据包,作为问题的指示器。
能源 powerstat 虽然传统上功耗不是操作系统监控的一部分,但它毕竟是一个共享资源,最近已成为维护工作系统的主要成本之一。
Java jconsole, jinfo, jcmd, jmc 所有这些工具都允许您检查应用程序的配置和运行时属性。Java Mission ControlJMC)从 JDK 7u40 版本开始随 JDK 一起提供。

表 10.1. 常见的 Linux 操作系统监控工具

在许多情况下,工具是多余的。例如,可以使用topsarjmc命令获取 CPU 和内存信息。

有一些工具可以跨分布式节点收集这些信息。Ganglia 是一个 BSD 许可的可伸缩分布式监控系统(ganglia.info)。它基于分层设计,非常注重数据结构和算法设计。它已知可以扩展到 10,000 多个节点。它由一个 gmetad 守护进程组成,该守护进程从多个主机收集信息并在 Web 界面中展示,以及在每个单独的主机上运行的 gmond 守护进程。默认情况下,通信发生在 8649 端口,这代表着 Unix。默认情况下,gmond 发送有关 CPU、内存和网络的信息,但存在多个插件用于其他指标(或可以创建)。Gmetad 可以聚合信息并将其传递到层次链中的另一个 gmetad 守护进程。最后,数据在 Ganglia Web 界面中展示。

Graphite 是另一个监控工具,它存储数值时间序列数据,并在需要时渲染这些数据的图表。该 Web 应用程序提供了一个/render 端点来生成图表并通过 RESTful API 检索原始数据。Graphite 有一个可插拔的后端(尽管它有自己的默认实现)。大多数现代指标实现,包括本章中使用的 scala-metrics,都支持将数据发送到 Graphite。

进程监控

在上一节中描述的工具不是特定于应用的。对于长时间运行的过程,通常需要向监控或图形解决方案提供有关内部状态的信息,例如 Ganglia 或 Graphite,或者只是在一个 servlet 中显示它。大多数这些解决方案是只读的,但在某些情况下,命令会赋予用户修改状态的控制权,例如日志级别,或者触发垃圾回收。

监控通常应该执行以下操作:

  • 提供关于程序执行和特定于应用的指标的高级信息

  • 可能执行对关键组件的健康检查

  • 可能会包含对一些关键指标的警报和阈值设置

我还看到监控包括更新操作,以更新日志参数或测试组件,例如使用预定义参数触发模型评分。后者可以被视为参数化健康检查的一部分。

让我们以一个简单的Hello World Web 应用程序为例来了解它是如何工作的,该应用程序接受类似 REST 的请求并为不同的用户分配唯一的 ID,该应用程序是用 Scala 框架 Scalatra 编写的(scalatra.org),这是一个 Scala 中的轻量级 Web 应用程序开发框架。该应用程序应该响应 CRUD HTTP 请求为用户创建一个唯一的数字 ID。要在 Scalatra 中实现此服务,我们只需要提供一个Scalate模板。完整的文档可以在scalatra.org/2.4/guides/views/scalate.html找到,源代码与本书一起提供,可以在chapter10子目录中找到:

class SimpleServlet extends Servlet {
  val logger = LoggerFactory.getLogger(getClass)
  var hwCounter: Long = 0L
  val hwLookup: scala.collection.mutable.Map[String,Long] = scala.collection.mutable.Map() 
  val defaultName = "Stranger"
  def response(name: String, id: Long) = { "Hello %s! Your id should be %d.".format(if (name.length > 0) name else defaultName, id) }
  get("/hw/:name") {
    val name = params("name")
    val startTime = System.nanoTime
    val retVal = response(name, synchronized { hwLookup.get(name) match { case Some(id) => id; case _ => hwLookup += name -> { hwCounter += 1; hwCounter } ; hwCounter } } )
    logger.info("It took [" + name + "] " + (System.nanoTime - startTime) + " " + TimeUnit.NANOSECONDS)
    retVal
  }
}

首先,代码从请求中获取name参数(也支持类似 REST 的参数解析)。然后,它检查内部 HashMap 中是否存在条目,如果不存在条目,它将使用对hwCounter的同步调用创建一个新的索引(在实际应用中,此类信息应持久存储在数据库中,如 HBase,但在此部分中为了简化,我将跳过这一层)。要运行应用程序,需要下载代码,启动sbt,并输入~;jetty:stop;jetty:start以启用连续运行/编译,如第七章中所述,使用图算法。对文件的修改将立即被构建工具捕获,jetty 服务器将重新启动:

[akozlov@Alexanders-MacBook-Pro chapter10]$ sbt
[info] Loading project definition from /Users/akozlov/Src/Book/ml-in-scala/chapter10/project
[info] Compiling 1 Scala source to /Users/akozlov/Src/Book/ml-in-scala/chapter10/project/target/scala-2.10/sbt-0.13/classes...
[info] Set current project to Advanced Model Monitoring (in build file:/Users/akozlov/Src/Book/ml-in-scala/chapter10/)
> ~;jetty:stop;jetty:start
[success] Total time: 0 s, completed May 15, 2016 12:08:31 PM
[info] Compiling Templates in Template Directory: /Users/akozlov/Src/Book/ml-in-scala/chapter10/src/main/webapp/WEB-INF/templates
SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
[info] starting server ...
[success] Total time: 1 s, completed May 15, 2016 12:08:32 PM
1\. Waiting for source changes... (press enter to interrupt)
2016-05-15 12:08:32.578:INFO::main: Logging initialized @119ms
2016-05-15 12:08:32.586:INFO:oejr.Runner:main: Runner
2016-05-15 12:08:32.666:INFO:oejs.Server:main: jetty-9.2.1.v20140609
2016-05-15 12:08:34.650:WARN:oeja.AnnotationConfiguration:main: ServletContainerInitializers: detected. Class hierarchy: empty
2016-15-05 12:08:34.921: [main] INFO  o.scalatra.servlet.ScalatraListener - The cycle class name from the config: ScalatraBootstrap
2016-15-05 12:08:34.973: [main] INFO  o.scalatra.servlet.ScalatraListener - Initializing life cycle class: ScalatraBootstrap
2016-15-05 12:08:35.213: [main] INFO  o.f.s.servlet.ServletTemplateEngine - Scalate template engine using working directory: /var/folders/p1/y7ygx_4507q34vhd60q115p80000gn/T/scalate-6339535024071976693-workdir
2016-05-15 12:08:35.216:INFO:oejsh.ContextHandler:main: Started o.e.j.w.WebAppContext@1ef7fe8e{/,file:/Users/akozlov/Src/Book/ml-in-scala/chapter10/target/webapp/,AVAILABLE}{file:/Users/akozlov/Src/Book/ml-in-scala/chapter10/target/webapp/}
2016-05-15 12:08:35.216:WARN:oejsh.RequestLogHandler:main: !RequestLog
2016-05-15 12:08:35.237:INFO:oejs.ServerConnector:main: Started ServerConnector@68df9280{HTTP/1.1}{0.0.0.0:8080}
2016-05-15 12:08:35.237:INFO:oejs.Server:main: Started @2795ms2016-15-05 12:03:52.385: [main] INFO  o.f.s.servlet.ServletTemplateEngine - Scalate template engine using working directory: /var/folders/p1/y7ygx_4507q34vhd60q115p80000gn/T/scalate-3504767079718792844-workdir
2016-05-15 12:03:52.387:INFO:oejsh.ContextHandler:main: Started o.e.j.w.WebAppContext@1ef7fe8e{/,file:/Users/akozlov/Src/Book/ml-in-scala/chapter10/target/webapp/,AVAILABLE}{file:/Users/akozlov/Src/Book/ml-in-scala/chapter10/target/webapp/}
2016-05-15 12:03:52.388:WARN:oejsh.RequestLogHandler:main: !RequestLog
2016-05-15 12:03:52.408:INFO:oejs.ServerConnector:main: Started ServerConnector@68df9280{HTTP/1.1}{0.0.0.0:8080}
2016-05-15 12:03:52.408:INFO:oejs.Server:main: Started @2796mss

当 servlet 在 8080 端口启动时,发出浏览器请求:

提示

我为这本书预先创建了项目,但如果你想要从头开始创建 Scalatra 项目,chapter10/bin/create_project.sh中有一个gitter命令。Gitter 将创建一个project/build.scala文件,其中包含一个 Scala 对象,扩展了构建,这将设置项目参数并启用 SBT 的 Jetty 插件。

http://localhost:8080/hw/Joe.

输出应类似于以下截图:

进程监控

图 10-1:servlet 网页。

如果你使用不同的名称调用 servlet,它将分配一个唯一的 ID,该 ID 将在应用程序的生命周期内保持持久。

由于我们也启用了控制台日志记录,你将在控制台上看到类似以下命令的内容:

2016-15-05 13:10:06.240: [qtp1747585824-26] INFO  o.a.examples.ServletWithMetrics - It took [Joe] 133225 NANOSECONDS

在检索和分析日志时,可以将日志重定向到文件,并且有多个系统可以收集、搜索和分析来自一组分布式服务器的日志,但通常还需要一种简单的方法来检查运行中的代码。实现这一目标的一种方法是为指标创建一个单独的模板,然而,Scalatra 提供了指标和健康支持,以实现计数、直方图、速率等基本实现。

我将使用 Scalatra 指标支持。ScalatraBootstrap类必须实现MetricsBootstrap特质。org.scalatra.metrics.MetricsSupportorg.scalatra.metrics.HealthChecksSupport特质提供了类似于 Scalate 模板的模板,如下面的代码所示。

以下为ScalatraTemplate.scala文件的內容:

import org.akozlov.examples._

import javax.servlet.ServletContext
import org.scalatra.LifeCycle
import org.scalatra.metrics.MetricsSupportExtensions._
import org.scalatra.metrics._

class ScalatraBootstrap extends LifeCycle with MetricsBootstrap {
  override def init(context: ServletContext) = {
    context.mount(new ServletWithMetrics, "/")
    context.mountMetricsAdminServlet("/admin")
    context.mountHealthCheckServlet("/health")
    context.installInstrumentedFilter("/*")
  }
}

以下为ServletWithMetrics.scala文件的內容:

package org.akozlov.examples

import org.scalatra._
import scalate.ScalateSupport
import org.scalatra.ScalatraServlet
import org.scalatra.metrics.{MetricsSupport, HealthChecksSupport}
import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.TimeUnit
import org.slf4j.{Logger, LoggerFactory}

class ServletWithMetrics extends Servlet with MetricsSupport with HealthChecksSupport {
  val logger = LoggerFactory.getLogger(getClass)
  val defaultName = "Stranger"
  var hwCounter: Long = 0L
  val hwLookup: scala.collection.mutable.Map[String,Long] = scala.collection.mutable.Map()  val hist = histogram("histogram")
  val cnt =  counter("counter")
  val m = meter("meter")
  healthCheck("response", unhealthyMessage = "Ouch!") { response("Alex", 2) contains "Alex" }
  def response(name: String, id: Long) = { "Hello %s! Your id should be %d.".format(if (name.length > 0) name else defaultName, id) }

  get("/hw/:name") {
    cnt += 1
    val name = params("name")
    hist += name.length
    val startTime = System.nanoTime
    val retVal = response(name, synchronized { hwLookup.get(name) match { case Some(id) => id; case _ => hwLookup += name -> { hwCounter += 1; hwCounter } ; hwCounter } } )s
    val elapsedTime = System.nanoTime - startTime
    logger.info("It took [" + name + "] " + elapsedTime + " " + TimeUnit.NANOSECONDS)
    m.mark(1)
    retVal
  }

如果你再次运行服务器,http://localhost:8080/admin页面将显示一组操作信息的链接,如下面的截图所示:

进程监控

图 10-2:管理 servlet 网页

指标链接将引导到图 10-3中所示的指标 servlet。org.akozlov.exampes.ServletWithMetrics.counter将具有请求的全局计数,而org.akozlov.exampes.ServletWithMetrics.histogram将显示累积值的分布,在这种情况下,是名称长度。更重要的是,它将计算507595989999.9分位数。计量计数器将显示过去1515分钟内的速率:

流程监控

图 10-3:指标 servlet 网页

最后,可以编写健康检查。在这种情况下,我将只检查响应函数的结果是否包含作为参数传递的字符串。请参考以下图 10.4

流程监控

图 10-4:健康检查 servlet 网页。

可以配置指标以向 Ganglia 或 Graphite 数据收集服务器报告,或者定期将信息记录到日志文件中。

端点不必是只读的。预配置的组件之一是计时器,它测量完成任务所需的时间——可以用来测量评分性能。让我们将代码放入ServletWithMetrics类中:

  get("/time") {
    val sleepTime = scala.util.Random.nextInt(1000)
    val startTime = System.nanoTime
    timer("timer") {
      Thread.sleep(sleepTime)
      Thread.sleep(sleepTime)
      Thread.sleep(sleepTime)
    }
    logger.info("It took [" + sleepTime + "] " + (System.nanoTime - startTime) + " " + TimeUnit.NANOSECONDS)
    m.mark(1)
  }

访问http://localhost:8080/time将触发代码执行,这将通过指标中的计时器进行计时。

类似地,可以使用put()模板创建的 put 操作,可以用来调整运行时参数或就地执行代码——这取决于代码,可能需要在生产环境中进行安全加固。

注意

JSR 110

JSR 110 是另一个Java 规范请求JSR),通常称为Java 管理扩展JMX)。JSR 110 指定了一系列 API 和协议,以便能够远程监控 JVM 执行。访问 JMX 服务的一种常见方式是通过默认连接到本地进程之一的jconsole命令。要连接到远程主机,您需要在 Java 命令行上提供-Dcom.sun.management.jmxremote.port=portNum属性。还建议启用安全性(SSL 或基于密码的认证)。在实践中,其他监控工具使用 JMX 进行监控,以及管理 JVM,因为 JMX 允许回调来管理系统状态。

您可以通过 JMX 公开自己的指标。虽然 Scala 在 JVM 中运行,但 JMX(通过 MBeans)的实现非常特定于 Java,不清楚该机制与 Scala 的兼容性如何。尽管如此,JMX Beans 可以在 Scala 中作为 servlet 公开。

JMX MBeans 通常可以在 JConsole 中检查,但我们也可以将其公开为/jmx servlet,书中代码库中提供的代码(github.com/alexvk/ml-in-scala)。

模型监控

我们已经涵盖了基本的系统和应用指标。最近,一个新的方向是利用监控组件来监控统计模型性能。统计模型性能包括以下内容:

  • 模型性能随时间的变化

  • 何时退役模型

  • 模型健康检查

随时间推移的性能

机器学习模型会随着时间的推移而退化,或者说“老化”:尽管这个过程还没有被充分理解,但模型性能往往会随时间变化,即使是因为概念漂移,即属性的定义发生变化,或底层依赖关系的变化。不幸的是,模型性能很少提高,至少在我的实践中是这样。因此,跟踪模型至关重要。一种方法是监控模型旨在优化的指标,因为在许多情况下,我们没有现成的标签数据集。

在许多情况下,模型性能的下降并不是直接与统计建模的质量相关,尽管像线性回归和逻辑回归这样的简单模型通常比决策树等更复杂的模型更稳定。模式演变或未注意到的属性重命名可能导致模型表现不佳。

模型监控的一部分应该是运行健康检查,其中模型定期对一些记录或已知评分的数据集进行评分。

模型退役标准

在实际部署中一个非常常见的案例是,数据科学家每隔几周就会带来更好的模型集。然而,如果这种情况没有发生,就需要制定一套标准来退役模型。由于现实世界的流量很少带有评分数据,例如,已经评分的数据,衡量模型性能的通常方式是通过代理,即模型应该改进的指标。

A/B 测试

A/B 测试是电子商务环境中控制实验的一个特定案例。A/B 测试通常应用于网页的不同版本,我们将完全独立的用户子集引导到每个版本。要测试的因变量通常是响应率。除非有关于用户的特定信息,并且在许多情况下,除非在计算机上放置了 cookie,否则这种分割通常是随机的。通常分割是基于唯一的 userID,但已知这并不适用于多个设备。A/B 测试受到与控制实验相同的假设:测试应该是完全独立的,因变量的分布应该是i.i.d.。尽管很难想象所有人都是真正的i.i.d.,但 A/B 测试已被证明适用于实际问题。

在建模中,我们将要评分的流量分成两个或多个通道,由两个或多个模型进行评分。进一步地,我们需要测量每个通道的累积性能指标以及估计的方差。通常,其中一个模型被视为基线,并与零假设相关联,而对于其他模型,我们运行 t 检验,比较差异与标准差的比例。

摘要

本章描述了系统、应用和模型监控目标,以及 Scala 和 Scalatra 现有的监控解决方案。许多指标与标准操作系统或 Java 监控重叠,但我们还讨论了如何创建特定于应用的指标和健康检查。我们讨论了机器学习应用中新兴的模型监控领域,其中统计模型受到退化、健康和性能监控的影响。我还简要提到了监控分布式系统,这是一个真正值得更多篇幅讨论的话题,但遗憾的是,我没有足够的空间来展开。

这本书的结尾,但绝对不是旅程的终点。我相信,当我们说话的时候,新的框架和应用正在被编写。在我的实践中,Scala 已经是一个非常出色且简洁的开发工具,我能够用几个小时而不是几天就实现结果,这是更传统工具的情况,但它尚未赢得广泛的认可,我对此非常确信。我们只需要强调它在现代交互式分析、复杂数据和分布式处理世界中的优势。

附录 A. 参考文献

这条学习路径是根据您的旅程精心打包的。它包括以下 Packt 产品的内容:

  • 《Scala 数据科学》,作者:Pascal Bugnion

  • 《Scala 机器学习》,作者:Patrick R. Nicolas

  • 《精通 Scala 机器学习》,作者:Alex Kozlov

posted @ 2025-09-03 10:24  绝不原创的飞龙  阅读(6)  评论(0)    收藏  举报