TowardsDataScience-2023-博客中文翻译-三十三-

TowardsDataScience 2023 博客中文翻译(三十三)

原文:TowardsDataScience

协议:CC BY-NC-SA 4.0

嵌套字典 Python——Python 嵌套字典的完整指南

原文:towardsdatascience.com/nested-dictionary-python-a-complete-guide-to-python-nested-dictionaries-756a7822cb4f

如何在 Python 中使用嵌套字典?本文将教你关于 Python 嵌套字典的一切知识。

达里奥·拉德奇数据科学的前沿 达里奥·拉德奇

·发表于数据科学的前沿 ·12 分钟阅读·2023 年 4 月 18 日

--

图片由凯利·西克马提供,来源于Unsplash

什么是 Python 中的嵌套字典?

Python 中一种常见的数据结构是嵌套字典,或者说字典的值可以是其他字典。初学者不喜欢嵌套字典,因为它们需要更多的时间来处理和正确解析,但只要稍加练习,你就能掌握它。

刚接触 Python?首先学习基础字典。

今天你将学习什么是嵌套字典,为什么在 Python 中使用嵌套字典,如何在 Python 中遍历嵌套字典,等等。关于库的导入,将其放在你的脚本或笔记本的顶部:

import pprint
pp = pprint.PrettyPrinter(depth=4)

它将在打印嵌套字典时处理格式,使其更易于阅读。

如何在 Python 中创建嵌套字典

有许多方法可以创建嵌套字典,但如果你从头开始在 Python 中创建它们,你主要会使用两种方法。

使用常规 Python 符号

创建嵌套 Python 字典的第一种方法是利用常规 Python 符号。这意味着你不需要使用任何特定的函数或库来创建字典。只需将其分配给一个变量名,并将整个内容格式化为 JSON。

这是一个示例——以下代码片段创建了一个员工的嵌套字典,其中员工的电子邮件被用作字典的键,附加信息作为字典值。正如你所看到的,字典值本身也是一个字典:

employees = {
    "bdoe@email.com": {
        "first_name": "Bob", 
        "last_name": "Doe",
        "address": {
            "city": "New York",
            "street": "1st Street",
            "house_number": 1
        }
    },
    "mmarkson@email.com": {
        "first_name": "Mark", 
        "last_name": "Markson",
        "address": {
            "city": "San Diego",
            "street": "2nd Street",
            "house_number": 2
        }
    }
}

pp.pprint(employees)

这就是这个嵌套字典的样子:

图像 1 — 员工的嵌套字典(作者提供的图像)

总体而言,我们有一个包含两个键(电子邮件)的字典。每个键都有一个字典作为值,甚至还有一个第三个字典分配给address键。

字典是无序的,所以你看到的数据没有反映代码中指定的排序。对此不必担心。

使用 zip() 函数

在 Python 中创建嵌套字典的另一种方法是使用zip()函数。它用于同时迭代两个或多个迭代器。

为了演示,我们将声明两个列表:

  • employee_emails - 一个电子邮件列表,将作为字典的键

  • employee_details - 每个员工的详细信息列表,如名字、姓氏和地址。

如果你以这种方式声明数据,你可以将它们传递给zip()并将所有内容包装在dict()调用中。这将分配适当的键值对。代码如下:

employee_emails = ["bdoe@email.com", "mmarkson@email.com"]

employee_details = [
    {
        "first_name": "Bob", 
        "last_name": "Doe", 
        "address": {
            "city": "New York", 
            "street": "1st Street", 
            "house_number": 1
        }
    },
    {
        "first_name": "Mark", 
        "last_name": "Markson", 
        "address": {
            "city": "San Diego", 
            "street": "2nd Street", 
            "house_number": 2
        }
    }
]

employees = dict(zip(employee_emails, employee_details))
pp.pprint(employees)

结果数据看起来与之前的一样:

图像 2 — 员工的嵌套字典(2)(作者提供的图像)

实际上,没有理由使用这种方法来声明嵌套字典。这很混乱,编写起来也更费时间。只需坚持使用第一个方法,你就可以了。

接下来,让我们看看如何在 Python 中访问嵌套字典的元素。

如何访问嵌套字典的元素

你可以像访问普通字典一样访问嵌套字典的元素,唯一的例外是你现在需要添加额外的一组括号。

下面是你可以做的一些示例:

  • 访问单个元素

  • 访问一个也作为字典的单个元素

  • 连接多个嵌套字典值

或者在代码中:

# Access element that contains a string
print(employees["bdoe@email.com"]["first_name"])

# Access element that contains a string
print(employees["bdoe@email.com"]["last_name"])

# Access element that contains a dictionary
print(employees["bdoe@email.com"]["address"])

# Combine multiple elements
print(f"{employees['bdoe@email.com']['first_name']} {employees['bdoe@email.com']['last_name']} from {employees['bdoe@email.com']['address']['city']}")

这是你应该看到的输出:

图像 3 — 访问嵌套 Python 字典的元素(作者提供的图像)

总的来说,如果你想到达嵌套字典的最底层,你需要与嵌套字典的层级数相同的括号。例如,要获取bdoe@email.com的城市,你需要写employees["bdoe@email.com"]["address"]["email"]。很简单!

如何更改嵌套字典中的值

你现在知道如何访问嵌套字典中的元素,但如何更改这些值呢?这非常简单,你可以逐个更改值,也可以一次处理多个值。

更改嵌套字典中的单个值

你可以通过访问嵌套字典并分配一个新值来更改单个值。

以下示例展示了如何一次更改一个员工的完整address

# Change values one by one
employees["bdoe@email.com"]["address"]["city"] = "San Francisco"
employees["bdoe@email.com"]["address"]["street"] = "5th Street"
employees["bdoe@email.com"]["address"]["house_number"] = 5

pp.pprint(employees["bdoe@email.com"])

更新后的员工数据现在如下所示:

图像 4 — 更改嵌套字典中的单个值(作者提供的图像)

很好,但你能在一行中更改 address 吗?当然可以,接下来我们将探讨如何实现。

在嵌套字典中更改多个值

address 属性本身就是一个字典,这意味着你可以在一行 Python 代码中完全更改它:

# Change multiple values at once
employees["mmarkson@email.com"]["address"] = {"city": "Los Angeles", "street": "15th Street", "house_number": 15}

pp.pprint(employees["mmarkson@email.com"])

更新后的员工记录如下:

图 5 — 更改嵌套字典中的多个值(图片作者提供)

你现在知道如何访问和更改嵌套字典元素了,接下来我们将讨论如何在 Python 中向嵌套字典中添加新元素。

如何向嵌套字典中添加元素

在 Python 中向嵌套字典中添加新元素就是赋值一个新的键值对。就是这么简单!

以下代码片段声明了两个变量来存储新的嵌套字典项的键和值,然后使用 dictionary[key] = value 赋值运算符添加这个键值对:

new_employee_email = "jswift@email.com"
new_employee_details = {
    "first_name": "Jane", 
    "last_name": "Swift", 
    "address": {
        "city": "Boston", 
        "street": "10th Street", 
        "house_number": 10
    }
}

# dictionary[key] = value
employees[new_employee_email] = new_employee_details

pp.pprint(employees)

更新后的嵌套字典现在有 3 个记录:

图 6 — 在 Python 中向嵌套字典中添加元素(图片作者提供)

那是添加操作,所以接下来我们来讨论删除操作。

如何从嵌套字典中删除元素

你可以使用 Python 的 del 关键字,后跟字典的名称和你想删除的键。例如,运行 del d["name"] 来删除字典 d 中键为 name 的键值对。

在我们的示例中,让我们使用 del 删除刚刚添加的员工:

del employees["jswift@email.com"]

pp.pprint(employees)

现在我们只剩下两个记录了:

图 7 — 从嵌套字典中删除元素(图片作者提供)

接下来,让我们讨论如何合并两个或多个字典。

如何合并两个嵌套字典

字典合并就是将两个或多个字典合并成一个。字典的结构相同会有所帮助,但不是必须的,因为这是 Python。

为了演示,我们来声明一个新的嵌套字典的员工记录:

new_employees = {
    "jswift@email.com": {
        "first_name": "Jane", 
        "last_name": "Swift",
        "address": {
            "city": "Boston",
            "street": "25th Street",
            "house_number": 25
        }
    },
    "pjohnson@email.com": {
        "first_name": "Patrick", 
        "last_name": "Johnson",
        "address": {
            "city": "Miami",
            "street": "50th Street",
            "house_number": 50
        }
    }
}

pp.pprint(new_employees)

这就是它的样子:

图 8 — 两名新员工(图片作者提供)

这个想法是将这个字典添加到我们已有的字典中,有两种方法可以实现。

使用 update() 函数合并两个字典

update() 函数 更新 一个字典的内容到另一个字典。更新是就地进行的,这意味着你不需要重新赋值变量。

这是一个例子:

employees.update(new_employees)

pp.pprint(employees)

更新后的嵌套字典现在有四个记录:

图 9 — 合并嵌套字典(图片作者提供)

这个函数很容易使用,但缺点是 你一次只能添加一个字典。下一种方法则更灵活一些。

使用 kwargs 合并两个字典

**kwargs方法对新手来说可能看起来很奇怪,但它本质上只是解包字典。通过这样做,你可以将尽可能多的字典解包到一个新的字典中。

这里是一个例子:

emps_merged = {**employees, **new_employees}

pp.pprint(emps_merged)

合并后的嵌套字典与我们之前的字典完全相同:

图 10 — 合并嵌套字典(2)(作者提供的图片)

最终,决定最佳合并方法还是取决于你。我们推荐使用**kwargs,因为你可以在一行 Python 代码中合并数十个字典。

如何遍历嵌套字典

在 Python 中,处理嵌套字典时没有一种适用于所有情况的解决方案。字典项的结构会有所不同,这意味着你每次都需要定制代码。

为了演示,我们将介绍两个例子,一个较简单,另一个代码稍复杂一些。

第一个例子遍历所有字典项并打印键,然后也遍历相应的值并打印它。以下是代码:

# Keys and values
for emp_email, emp_info in employees.items():
    print(f"EMAIL: {emp_email}")

    # For each key that belongs to a dictionary at the given email
    for key in emp_info:
        # Print the corresponding key and value
        print(f"{key} = {emp_info[key]}")

    print()

你应该看到如下结果:

图 11 — 遍历嵌套字典(作者提供的图片)

如果你没有额外的嵌套层级,类似于我们所拥有的address,这种方法可能有效。处理这个问题会更加具有挑战性,但接下来我们可以尝试一下。

更高级的迭代示例

现在的想法是进入address字典,并打印它包含的所有元素。

为了让代码更健壮一点,我们将检查项目是否为字典,如果是,则遍历其项。如果项目不是字典,我们将简单地打印它:

# Keys and values
for emp_email, emp_info in employees.items():
    print(f"EMAIL: {emp_email}")

    # For every key in the inner dictionary
    for key in emp_info:

        # Check if a type is a dictionary
        if type(emp_info[key]) is dict:
            print(f"{key}:")

            # Print nested items
            for item in emp_info[key]:
                print(f"\t{item} = {emp_info[key][item]}")
        # Not a dictionary, print the value
        else:
            print(f"{key} = {emp_info[key]}")

    print()

你应该在屏幕上看到如下内容:

图 12 — 遍历嵌套字典(2)(作者提供的图片)

总体而言,我们已经成功解析了我们的嵌套字典结构,但再次强调,这并不适用于所有嵌套字典。你必须定制代码片段以适应你的用例,这可能会很棘手且耗时。

如何扁平化嵌套字典

扁平化嵌套字典意味着你希望获得一个不包含任何字典或列表的字典。

这是在将嵌套 JSON 文档解析为 Pandas DataFrame 时非常常见的数据预处理技术。如果你正在处理这样的数据,你会知道输入数据的结构会有很大不同。

大多数情况下,你会有一个字典列表。我们将在下方声明这样一个列表,并在每个员工对象内部添加email属性,而不是将其用作字典键:

employees = [
    {
        "first_name": "Bob", 
        "last_name": "Doe",
        "email": "bdoe@email.com",
        "address": {
            "city": "New York",
            "street": "1st Street",
            "house_number": 1
        }
    },
    {
        "first_name": "Mark", 
        "last_name": "Markson",
        "email": "mmarkson@email.com",
        "address": {
            "city": "San Diego",
            "street": "2nd Street",
            "house_number": 2
        }
    }
]

一旦你拥有这样的数据格式,就该进行扁平化处理了。以下递归函数将一个记录或列表中的一个字典元素进行扁平化。对于任何嵌套字典,它会将其扁平化,使键重命名为完整的路径

flatten_dict()函数必须应用于字典列表中的每一条记录,这意味着你可以使用 Python 循环或列表推导式。

这里有一个示例:

def flatten_dict(d: dict) -> dict:
    out = {}

    def flatten(x, name: str = ''):
        if type(x) is dict:
            for a in x:
                flatten(x[a], name + a + '_')
        elif type(x) is list:
            i = 0
            for a in x:
                flatten(a, name + str(i) + '_')
                i += 1
        else:
            out[name[:-1]] = x
    flatten(d)

    return out

# Apply the function to each row
employees_flat = [flatten_dict(emp) for emp in employees]
pp.pprint(employees_flat)

我们现在有了一个完全扁平的结构:

图 13 — 展开嵌套字典 (作者提供的图片)

请注意如何在其内部键值对之前添加了address,这样我们仍然能够理解它最初所属的位置。

现在,当你有了一个扁平化的字典列表时,你可以将其转换为 Pandas DataFrame。

嵌套字典 Python 转 Pandas DataFrame

如果你想将嵌套字典转换为 Pandas DataFrame,你必须先将其扁平化。否则,你会得到奇怪的索引,并且可能会得到单个单元格的字典作为值。

让我们首先展示一个不良的做法,以便你能理解为什么要扁平化数据。下面是与我们在文章中使用的员工字典相同的字典。然后我们在调用pd.DataFrame()时使用它:

import pandas as pd

employees = {
    "bdoe@email.com": {
        "first_name": "Bob", 
        "last_name": "Doe",
        "address": {
            "city": "New York",
            "street": "1st Street",
            "house_number": 1
        }
    },
    "mmarkson@email.com": {
        "first_name": "Mark", 
        "last_name": "Markson",
        "address": {
            "city": "San Diego",
            "street": "2nd Street",
            "house_number": 2
        }
    }
}

pd.DataFrame(employees)

这就是结果 DataFrame 的样子:

图 14 — 嵌套字典转 Pandas DataFrame (作者提供的图片)

糟糕且不可用。 你需要将一个扁平化的字典列表传递给 **pd.DataFrame()** 以将数据恢复为适当的格式。

你已经知道如何展开嵌套字典了,所以这应该感觉像是轻松的散步:

employees = [
    {
        "first_name": "Bob", 
        "last_name": "Doe",
        "email": "bdoe@email.com",
        "address": {
            "city": "New York",
            "street": "1st Street",
            "house_number": 1
        }
    },
    {
        "first_name": "Mark", 
        "last_name": "Markson",
        "email": "mmarkson@email.com",
        "address": {
            "city": "San Diego",
            "street": "2nd Street",
            "house_number": 2
        }
    }
]

# Flatten the records first
employees_flat = [flatten_dict(emp) for emp in employees]
pd.DataFrame(employees_flat)

现在,DataFrame 更容易理解和分析:

图 15 — 嵌套字典转 Pandas DataFrame (2) (作者提供的图片)

最后,让我们讲讲嵌套字典到 JSON 转换。

嵌套字典 Python 转 JSON 文件

JSON 和 Python 字典是相辅相成的,这意味着你可以轻松地在 JSON 文件和 Python 字典之间转换。

我们将展示如何将嵌套的 Python 字典转换为 JSON 文件。你需要导入json模块,并将字典传递给json.dumps()。可选的indent参数控制字典内部嵌套结构的缩进。

下面是代码:

import json

employees = {
    "bdoe@email.com": {
        "first_name": "Bob", 
        "last_name": "Doe",
        "address": {
            "city": "New York",
            "street": "1st Street",
            "house_number": 1
        }
    },
    "mmarkson@email.com": {
        "first_name": "Mark", 
        "last_name": "Markson",
        "address": {
            "city": "San Diego",
            "street": "2nd Street",
            "house_number": 2
        }
    }
}

json_object = json.dumps(employees, indent=4) 
print(json_object)

这就是你的 JSON 对象应该呈现的样子:

图 16 — 嵌套字典转 JSON 对象 (作者提供的图片)

你现在可以使用 Python 的上下文管理器语法将 JSON 对象写入文件。以下代码片段将其写入名为employees.json的文件:

with open("employees.json", "w") as f:
    json.dump(employees, f)

你可以在任何文本编辑器或 JupyterLab 中打开 JSON 文件。你会看到类似下面的内容:

图 17 — 嵌套字典转 JSON 文件 (作者提供的图片)

这就是你可以在 Python 中处理嵌套字典的方法。接下来,让我们简短回顾一下。

总结嵌套字典 Python

在 Python 中处理嵌套字典涉及很多内容。你可以访问单个值、修改它们、添加新行、删除旧行、合并多个字典、遍历它们,甚至将整个内容转换为 Pandas DataFrame 或 JSON 文件。

不幸的是,处理嵌套字典并没有一刀切的解决方案。每个项目的结构会有所不同,这意味着你需要定制代码以适应你的场景。特别是在遍历嵌套字典时,这一点尤其重要。

本文应该为你提供一个良好的起点,并涵盖 95%的场景,你可以随时自行深入探讨。

你对嵌套字典最喜欢的是什么?它们在日常数据科学任务中是否让你头疼? 请在评论区告诉我。

喜欢这篇文章?成为 Medium 会员 ,继续无限制地学习。如果你使用以下链接,我将获得你会员费用的一部分,而不会增加你的额外费用。

[## 使用我的推荐链接加入 Medium - Dario Radečić

阅读 Dario Radečić的每一个故事(以及 Medium 上成千上万其他作家的作品)。你的会员费用将直接支持…

medium.com](https://medium.com/@radecicdario/membership?source=post_page-----756a7822cb4f--------------------------------)

最初发表于 https://betterdatascience.com 于 2023 年 4 月 18 日。

解释性神经基础模型

原文:towardsdatascience.com/neural-basis-models-for-interpretability-fd04ac958ff2

解读 Meta AI 提出的新解释性模型

Nakul UpadhyaTowards Data Science Nakul Upadhya

·发表于 Towards Data Science ·6 分钟阅读·2023 年 10 月 11 日

--

机器学习和人工智能在各个领域的广泛应用带来了更高的风险和伦理评估挑战。如在 ProPublica 报道的刑事再犯模型 中所见,机器学习算法可能存在严重的偏见,因此需要强有力的解释机制,以确保这些模型在高风险领域的信任和安全。

那么,我们如何在解释性、准确性和模型表现力之间取得平衡呢?Meta AI 的研究人员提出了一种新的方法,称为神经基础模型(NBMs),这是一种广义加性模型的子家族,在基准数据集上实现了最先进的性能,同时保持了透明的解释性

在这篇文章中,我旨在解释 NBM 及其为何是一个有益的模型。像往常一样,我鼓励大家阅读原始论文。

如果你对解释性机器学习和其他伦理 AI 方面感兴趣,考虑查看我的其他文章并关注我!

Nakul Upadhya

Nakul Upadhya

解释性与伦理 AI

查看列表5 篇故事

背景:GAMs

NBM 被认为是广义加性模型(GAM)。GAM 本质上是可解释的模型,为每个特征学习一个形状函数,预测是通过“查询”形状函数来完成的。由于这些形状函数是独立的,通过可视化这些形状函数可以理解特征对预测的影响,使得模型高度可解释。变量之间的交互通过将多个变量传递到同一函数中并基于此构建形状函数来建模(通常将变量数量限制为 2 以提高互操作性),这种配置称为 GA2M。

GAMs 和 GA2Ms 的方程(图源自 Radenovic 等人 [1])

各种 GAM 和 GA2M 模型使用不同的机制来开发这些形状函数。可解释增强机(EBM)[2] 使用一组对每个特征进行训练的提升树,神经加性模型(NAMs)[3] 对每个特征使用深度神经网络,而 NODE-GAM [4] 使用 无意识神经树[6] 的集合。我推荐阅读以下文章以获得对这些模型的更详细解释:EBM 和 NODE-GAM/NAM

NBM 方法

神经基础模型(NBM)是一类新的广义加性模型(GAMs)子家族,利用形状函数的基础分解。

NBM 架构(图源自 Radenovic 等人 [1])

与其他 GAM 模型(如 NAM[3])不同,后者有效地为每个特征训练独立模型以构建形状函数,而 NBM 架构则依赖于少量的基础函数,这些函数在所有特征中共享,并为特定任务共同学习。这些函数是什么?它们是函数逼近的瑞士军刀:深度神经网络。

实际上,一个通用的 MLP 主干网络接受 1 个输入并输出 B 个值,这些值被训练并应用于每个输入特征。这些输出然后被线性组合以形成给定特征的最终预测,而线性组合的权重对每个特征不同。另一种思考这种架构的方法是通过编码器-解码器网络的视角。所有特征共享相同的 编码器(通用 MLP 主干),但每个特征都有其自己的 解码器(编码的线性变换)。每个特征的解码值然后被加总以生成最终预测。

这可以很容易地扩展到包括特征交互。如果我们想要建模配对交互,我们可以包括一个接受两个输入的 MLP,而不是一个。

NBM 和 NB2M 方程(图源自 Radenovic 等人 [1])

使用共享 MLP 主干而不是为每个特征使用不同的 MLP 的一个好处是模型的显著较小的尺寸。这使得 NBM 非常适合处理极高维度数据的任务。

性能与优势

为了测试他们的架构,Radenovic 等人(2022 年)将 NBM 与各种其他模型进行了比较,如线性回归、EBM [2]、NAM [3]、XGBoost[5] 和 MLP。他们的首次评估是在混合的表格和图像数据集上进行的。

基准性能比较(图来自 Radenovic 等人 [1])

总的来说,NBM 牢牢站稳脚跟,超越了其他可解释模型,甚至在某些数据集上超过了 MLP。

Radenovic 等人(2022 年)还在纯表格数据集上进行了另一项评估,重点是对比 SOTA GAM 模型。

与其他 GAMs 的性能比较:(图来自 Radenovic 等人 [1])

这个比较清楚地展示了 NBM 的强大,几乎在每个数据集上都击败了竞争对手。如前所述,NBM 的可扩展性也非常出色。如下所示,在高维数据任务中,NBM 的参数数量几乎是 NAM 的 70 分之一。

NAM/NBM 参数比较。X 轴是数据的维度(图来自 Radenovic 等人 [1])

结论

总体而言,NBM 是极其强大且轻量级的模型,由于其为 GAM,本质上是可解释的。然而,这并不意味着它是解决高风险机器学习问题的灵丹妙药。在使用这些模型时,仍然需要考虑许多因素。例如,一个本质上可解释的模型几乎没有意义,如果输入到模型中的特征不可解释

此外,虽然 NBM 的规模相较于 NAM 扩展得很好,但可解释性却没有。没有人能查看数千个特征归因图,特别是当这些归因图还包含成对交互时。这意味着在大参数空间下,仍然需要预处理方法,如特征选择,甚至作者也承认了这一点。然而,这并不贬低作者,因为这是一个仍然非常有用且相对容易实现和调整的模型。

该模型是 GAM 的事实对于机器学习应用(如移动设备和其他性能较低的设备)也非常有利,因为用户可以训练模型并部署生成的特征归因函数,而不是完整模型,从而实现极其快速且内存轻量的推理,而不会损失准确性。

资源与参考文献

  1. NBM 代码:github.com/facebookresearch/nbm-spam

  2. NBM 开放评论:openreview.net/forum?id=fpfDusqKZF

  3. 如果你对可解释的机器学习或时间序列预测感兴趣,可以考虑关注我:medium.com/@upadhyan

  4. 查看我关于可解释机器学习的其他文章:medium.com/@upadhyan/list/interpretable-and-ethical-ai-f6ee1f0b476d

参考文献

[1] Radenovic, F.、Dubey, A. 和 Mahajan, D.(2022)。用于可解释性的神经基础模型。神经信息处理系统进展35,8414–8426。

[2] Yin L.、Rich C.、Johannes G. 和 Giles H.(2013)准确可理解的模型与配对交互。在第 19 届 ACM SIGKDD 国际会议上的知识发现与数据挖掘论文集,623–631. 2013

[3] Agarwal, R.、Melnick, L.、Frosst, N.、Zhang, X.、Lengerich, B.、Caruana, R. 和 Hinton, G. E.(2021)。神经加性模型:使用神经网络的可解释机器学习。神经信息处理系统进展34,4699–4711。

[4] Chang, C.H.、Caruana, R. 和 Goldenberg, A.(2022)。NODE-GAM:用于可解释深度学习的神经广义加性模型。在国际学习表征会议

[5] Chen, T. 和 Guestrin, C.(2016 年 8 月)。Xgboost:一个可扩展的树提升系统。在第 22 届 ACM SIGKDD 国际会议上的知识发现与数据挖掘论文集(第 785–794 页)。

[6] Popov, S.、Morozov, S. 和 Babenko, A.(2019)。神经网络忽略决策集成用于表格数据的深度学习。在第八届国际学习表征会议

神经图数据库

原文:towardsdatascience.com/neural-graph-databases-cc35c9e1d04f?source=collection_archive---------0-----------------------#2023-03-28

图神经网络数据库的最新进展

图数据管理的新里程碑

Michael GalkinTowards Data Science Michael Galkin

·

关注 发表在 Towards Data Science ·14 分钟阅读·2023 年 3 月 28 日

--

我们引入了神经图数据库的概念,作为图数据库发展的下一步。神经图数据库专为大规模不完整图设计,并利用图表示学习进行缺失边的即时推理。神经推理保持了较高的表达能力,支持类似于标准图查询语言的复杂逻辑查询。

图片由作者提供,辅助工具为 Stable Diffusion。

本文由 Hongyu Ren, Michael Cochez* 和* Zhaocheng Zhu 共同撰写,基于我们最新的论文 Neural Graph Reasoning: Complex Logical Query Answering Meets Graph Databases。你也可以关注 , Hongyu, Michael* 和* Zhaocheng 在 Twitter 上的动态。查看我们的 项目网站 获取更多资料。

概述

  1. 神经图数据库:什么和为什么?

  2. NGDBs 的蓝图

  3. 神经图存储

  4. 神经查询引擎

  5. 查询引擎的神经图推理

  6. NGDBs 的开放挑战

  7. 了解更多

神经图数据库:什么和为什么?

🍨香草图数据库几乎随处可见,这要归功于不断增长的生产图、灵活的图数据模型和富有表现力的查询语言。经典的符号图数据库在一个重要假设下运行得又快又酷:

完整性。查询引擎假设经典图数据库中的图是完整的。

在完整性假设下,我们可以构建索引,以多种读写优化格式存储图,并期望数据库返回有什么

但这一假设在实际中往往不成立(我们会说,几乎总是不成立)。例如在一些突出的知识图谱(KGs)中:在 Freebase 中,93.8%的人没有出生地,78.5%没有国籍,约 68%的人没有任何职业,而在 Wikidata 中,约 50%的艺术家没有出生日期,只有0.4%的已知建筑有高度信息。这仅仅是由数百名爱好者公开编辑的最大 KG,100M 节点和 1B 语句并不是行业中最大的图,所以你可以想象其不完整性的程度。

显然,为了考虑不完整性,除了“有什么?”我们还必须问“缺少什么?”(或“可以有什么?”)。让我们来看一个例子:

(a) - 输入查询;(b) — 带有预测边(虚线)的不完整图;(c) — 通过图遍历返回一个答案(UofT)的 SPARQL 查询;(d) — 神经执行恢复缺失的边,并返回两个新答案(UdeM, NYU)。图片来源:作者。

在这里,给定一个不完整的图(缺失边 (Turing Award, win, Bengio)(Deep Learning, field, LeCun))以及一个查询 “在深度学习领域的图灵奖得主在哪些大学工作?”(以逻辑形式或类似 SPARQL 的语言表达),符号图数据库只会返回一个通过图遍历得到的答案 UofT。我们将这种答案称为 简单 答案,或现有答案。考虑到缺失的边,我们可以恢复两个更多的答案 UdeMNYU困难 答案,或推断答案)。

如何推断缺失的边?

  • 在经典数据库中,我们选择不多。基于 RDF 的数据库具有一些形式语义,可以由庞大的 OWL 本体支持,但根据图的大小和推理的复杂性,在 SPARQL 推理规则 中完成推理可能需要无限的时间。标记属性图(LPG)数据库完全没有内置的推断缺失边的手段。

  • 得益于图机器学习的进展,我们通常可以在潜在(嵌入)空间中以线性时间执行链接预测!然后,我们可以将这种机制扩展到在嵌入空间中执行复杂的、类似数据库的查询。

神经图数据库结合了传统图数据库和现代图机器学习的优势。

即,数据库原则如(1)图作为一等公民,(2)高效存储,以及(3)统一查询接口,现在由图 ML 技术支持,如(1)几何表示,(2)对噪声输入的鲁棒性,(3)大规模预训练和微调,以弥合不完整性差距并实现神经图推理和推断。

一般来说,NGDB 的设计原则包括:

  • 数据不完整性假设 — 潜在数据可能在节点、链接和图级别上缺少信息,我们希望推断并在查询回答中加以利用;

  • 归纳性和可更新性 — 类似于传统数据库,允许更新和即时查询,构建图潜变量的表示学习算法必须具有归纳性,并以零样本(或少样本)方式对未见数据(新实体和关系)进行泛化,以防止昂贵的再训练(例如,浅层节点嵌入);

  • 表达能力 — 潜在表示在数据中编码逻辑和语义关系的能力,类似于 FOL(或其片段),并在查询回答中加以利用。实际上,神经推理支持的逻辑操作符集应接近或等同于标准图数据库语言,如 SPARQL 或 Cypher;

  • 超越知识图谱的多模态性——任何可以作为节点或记录存储在经典数据库中的图结构数据(例如图像、文本、分子图或带时间戳的序列),并且可以赋予向量表示的,都是神经图存储和神经查询引擎的有效来源。

解决 NGDB 原则的关键方法是:

  • 向量表示作为原子元素——虽然传统的图数据库在许多索引中对邻接矩阵(或边列表)进行哈希处理,但不完全性假设意味着给定的边图潜在(向量表示)都成为真理的来源,在神经图存储中。

  • 在潜在空间中的神经查询执行——由于不完全性假设,基本操作如边遍历不能仅通过符号操作来执行。相反,神经查询引擎在邻接和图潜在空间上操作,以将可能缺失的数据纳入查询回答中;

实际上,通过在潜在空间中回答查询(且不牺牲遍历性能),我们可以完全抛弃符号数据库索引。

符号图数据库和神经图数据库之间的主要区别:传统的数据库通过边遍历回答“有什么?”的问题,而神经图数据库还会回答“缺少什么?”的问题。图像来源:作者。

NGDBs 的蓝图

在深入了解 NGDBs 之前,我们先来看一下神经数据库的一般情况——事实证明它们已经存在了一段时间,你可能已经注意到了。许多机器学习系统在数据被编码为模型参数时,已经在这一范式下运行,而查询相当于前向传播,可以为下游任务输出新的表示或预测。

神经数据库概述

神经数据库的现状如何?它的不同种类之间有什么区别,NGDBs(神经图数据库)有什么特别之处?

向量数据库、自然语言数据库和神经图数据库之间的区别。图像来源:作者

  1. 向量数据库属于存储导向的系统,这些系统通常基于近似最近邻库(ANN),如FaissScaNN(或定制解决方案)来回答基于距离的查询,使用最大内积搜索(MIPS)、L1、L2 或其他距离。由于向量数据库与编码器无关(即,任何生成向量表示的编码器,如 ResNet 或 BERT,都可以作为来源),它们速度很快,但缺乏复杂的查询回答能力。

  2. 最近,随着大规模预训练模型的崛起——或称为基础模型——我们见证了它们在自然语言处理和计算机视觉任务中的巨大成功。我们认为,这些基础模型也是神经数据库的一个重要例子。在这些模型中,存储模块可能直接以模型参数的形式呈现,或者外包给一个外部索引,这在检索增强模型中常常使用,因为将所有世界知识编码到即便是数十亿个模型参数中也是困难的。查询模块通过填充编码器模型(BERT 或 T5 风格)中的空白或通过解码器模型(GPT 风格)中的提示,进行上下文学习,这些提示可以跨越多种模式,例如视觉应用的可学习标记或甚至调用外部工具

  3. Thorne et al介绍的自然语言数据库 (NLDB)将原子元素建模为通过预训练语言模型(LM)编码为向量的文本事实。对 NLDB 的查询以自然语言表达的形式发送,这些查询被编码为向量,查询处理采用检索器-阅读器方法。

神经图数据库并不是一个新名词——许多图机器学习方法尝试将图嵌入与数据库索引结合起来,或许RDF2VecLPG2Vec是一些最显著的例子,展示了如何将嵌入插件到现有图数据库中,并在符号索引之上运行。

相比之下,我们认为 NGDB 可以在潜在空间中无需符号索引直接工作。如下面所示,存在能够模拟嵌入空间中精确边遍历行为的机器学习算法,以检索“那里有什么”,并进行神经推理以回答“缺少什么”。

神经图数据库:架构

神经图数据库的概念图。输入查询由神经查询引擎处理,其中规划器导出查询的计算图,执行器在潜在空间中执行查询。神经图存储使用图存储和特征存储在嵌入存储中获取潜在表示。执行器与嵌入存储通信,以检索和返回结果。图像来源于作者

在更高层次上,NGDB 包含两个主要组件:神经图存储神经查询引擎。查询回答流程从某些应用程序或下游任务发送的已结构化格式的查询开始(例如,通过语义解析将初始自然语言查询转换为结构化格式)。

查询首先到达神经查询引擎,特别是到查询规划器模块。查询规划器的任务是根据查询复杂性、预测任务和底层数据存储(如可能的图划分)生成一个高效的原子操作计算图。

生成的计划随后被送往查询执行器,该执行器将查询编码到潜在空间中,执行对底层图及其潜在表示的原子操作,并将原子操作的结果聚合成最终答案集。执行是通过与神经图存储通信的检索模块完成的。

存储层包括

1️⃣ 图存储 用于以空间和时间高效的方式保存多关系邻接矩阵(例如,以各种稀疏格式如 COO 和 CSR)。

2️⃣ 特征存储 用于保存与底层图相关的节点和边级多模态特征。

3️⃣ 嵌入存储 利用编码器模块生成基于底层邻接和相关特征的潜在空间中的图表示。

检索模块查询编码后的图表示,以构建潜在答案的分布。

神经图存储

在传统的图数据库(右侧),查询被优化为一个计划(通常是一个连接操作符的树),并执行于数据库索引的存储中。在神经图数据库(左侧)中,我们将查询(或其步骤)编码到一个潜在空间中,并在底层图的潜在空间中执行。图像由作者提供。

在传统的图数据库中,存储设计通常取决于图建模范式。

两种最流行的范式是资源描述框架(RDF)图和标记属性图(LPG)。然而,我们认为新的 RDF-star(及其伴随的 SPARQL-star)将统一这两种范式,将 RDF 图的逻辑表达性与 LPG 的属性特性融合起来。许多现有的知识图谱已经遵循了类似 RDF-star 的范式,如 超关系知识图谱 和 Wikidata Statement Model

如果我们展望未来几年的骨干图建模范式,我们会选择 RDF-star。

在神经图存储中,输入图及其向量表示都是事实来源。为了在潜在空间中回答查询,我们需要:

  • 查询编码器

  • 图编码器

  • 检索机制用于将查询表示与图表示进行匹配

图编码(嵌入)过程可以视为一个压缩步骤,但保留了实体/关系的语义和结构相似性。嵌入空间中实体/关系之间的距离应该与语义/结构相似性正相关。编码器的架构有很多选择——我们建议坚持使用归纳型的,以遵循 NGDB 设计原则。在我们最近的NeurIPS 2022 工作中,我们展示了两个这样的归纳模型。

查询编码通常与自然图编码相匹配,使得它们处于同一空间。一旦我们有了潜在表示,检索模块就会启动以提取相关答案。

检索过程可以被视为在嵌入空间中对输入向量的最近邻搜索,并且具有 3 个直接好处:

  1. 每个检索项的置信度评分——多亏了嵌入空间中预定义的距离函数。

  2. 潜在空间和距离函数的不同定义——针对不同的图,例如,树状图在双曲空间中更易于处理。

  3. 效率和可扩展性——检索可以扩展到包含数十亿节点和边的极大图。

神经查询引擎

NGDBs(左)和传统图数据库(右)的查询规划。NGDB 规划(假设图不完整)可以逐步自回归执行(1)或完全生成一个步骤(2)。传统数据库规划是基于成本的,并且依赖于元数据(假设图完整并从中提取),例如中间答案的数量来构建连接操作符的树。图片由作者提供

在传统数据库中,典型的查询引擎执行三个主要操作。(1) 查询解析以验证语法正确性(通常会进行更深层次的语义分析);(2) 查询规划和优化以得出一个有效的查询计划(通常是关系操作符的树),以最小化计算成本;(3) 查询执行根据查询计划扫描存储并处理中间结果。

将这些操作扩展到 NGDBs 是相当简单的。

1️⃣ 查询解析可以通过语义解析转化为结构化查询格式。我们故意将 NGDBs 的查询语言讨论留待未来的工作和热烈的公众讨论😉

2️⃣ 查询规划器得出原子操作(投影和逻辑操作符)的有效查询计划,最大化完整性(必须返回所有现有边上的答案)和推断(即时预测缺失边)同时考虑查询复杂性和底层图。

3️⃣ 一旦查询计划完成,查询执行器将查询(或其部分)编码到潜在空间,与图存储及其检索模块进行通信,并将中间结果聚合到最终答案集中。查询执行存在两种常见机制:

  • 原子,类似于传统数据库,当查询计划按顺序执行,通过编码原子模式、检索其答案和执行逻辑操作作为中间步骤;

  • 全局,当整个查询图在一个步骤中被编码并在潜在空间中执行。

神经查询执行的主要挑战是将查询表达能力与 SPARQL 或 Cypher 等符号语言匹配——迄今为止,神经方法可以执行接近一阶逻辑表达能力的查询,但在符号语言方面还差一半。

神经图推理的分类学用于查询引擎

自 2018 年以来,关于复杂逻辑查询回答的神经方法(即查询嵌入)的文献不断增加,特别是 Hamilton 等人图查询嵌入(GQE)方面的开创性 NeurIPS 工作。GQE 能够回答带有交集的联接查询,并实时预测缺失的链接。

GQE 可以被视为对 NGDBs 的神经查询引擎的第一次尝试。

GQE 开创了图机器学习的整个子领域,随后出现了一些著名的例子,如 Query2Box (ICLR 2020)Continuous Query Decomposition (ICLR 2021)。我们进行了一项重大工作,将所有这些(约 50 项)工作按 3 个主要方向进行了分类:

⚛️ ——我们回答查询的基础结构是什么;

🛠️ 建模——我们如何回答查询以及采用了哪些归纳偏差;

🗣️ 查询——我们回答什么,查询结构是什么,以及预期的答案是什么。

复杂逻辑查询回答的神经方法分类。有关更多详细信息,请参见。图像由作者提供

⚛️ 说到,我们进一步将其细分为模态(经典的三元组图、超关系图、超图等)、推理领域(离散实体或包括连续输出)和语义(神经编码器如何捕捉更高阶关系,如 OWL 本体)。

🛠️ 在建模中,我们遵循编码器-处理器-解码器范式,对现有模型的归纳偏差进行分类,例如,具有神经或神经符号处理器的传递性或归纳编码器。

🗣️ 在 查询 中,我们的目标是将神经方法能够回答的查询集与符号图查询语言的查询集进行映射。我们讨论查询操作符(超越标准的与/或/非),查询模式(从链状查询到 DAG 和循环模式),以及投影变量(你喜欢的关系代数)。

NGDB 的开放挑战

分析分类法时,我们发现目前没有银弹,例如,大多数处理器只能在离散模式下处理基于树的查询。但这也意味着未来有很大的工作空间——可能包括你的贡献!

更具体地说,以下是未来几年 NGDB 的主要挑战。

沿着 分支:

  • 模态:支持更多图的模态:从经典的仅三元组图到超关系图、超图以及结合图、文本、图像等的多模态源。

  • 推理领域:支持对时间和连续(文本和数值)数据进行逻辑推理和神经查询回答——字面量构成了图的大部分以及对字面量的相关查询。

  • 背景语义:支持复杂公理和形式语义,这些语义编码了(潜在的)实体类及其层次结构之间的高阶关系,例如,支持对描述逻辑和 OWL 片段进行神经推理。

建模 分支:

  • 编码器:支持在推理时处理未见过的关系——这是(1)可更新性的关键,无需重新训练即可更新神经数据库;(2)启用预训练-微调策略,将查询回答推广到具有自定义关系模式的自定义图。

  • 处理器:表达性处理器网络能够有效且高效地执行类似于 SPARQL 和 Cypher 操作符的复杂查询操作符。提高神经处理器的样本效率对于训练时间与质量权衡至关重要——在保持高预测质量的同时减少训练时间。

  • 解码器:迄今为止,所有神经查询回答解码器仅在离散节点上操作。扩展答案范围到连续输出对于回答现实世界的查询至关重要。

  • 复杂性:由于处理器网络的主要计算瓶颈是嵌入空间的维度(对于纯神经模型)和/或节点数(对于神经-符号模型),新型高效的神经逻辑操作符和检索方法是将 NGDB 扩展到数十亿节点和万亿边的关键。

查询 中:

  • 操作符:使更复杂的查询操作符具备与声明式图查询语言相匹配的表达能力,例如,支持克林星号和加号、属性路径、过滤器。

  • 模式:回答比树状查询更复杂的模式,包括 DAG 和循环图。

  • 投影变量:允许投影超出最终叶节点实体,即允许返回中间变量、关系以及组织在元组(绑定)中的多个变量。

  • 表达能力:回答超出简单 EPFO 和 EFO 片段的查询,并追求数据库语言的表达能力。

最后,在数据集评估方面:

  • 需要更大且多样化的基准,涵盖更多图模式、更具表现力的查询语义、更多查询操作符和查询模式。

  • 由于现有的评估协议似乎有限(仅关注推断答案),需要一个更有原则的评估框架和指标,涵盖查询回答工作流的各个方面。

关于神经图存储和 NGDB 的一般情况,我们识别出以下挑战:

  • 需要一个可扩展检索机制来将神经推理扩展到数十亿节点的图。检索与查询处理器及其建模先验紧密相关。现有的可扩展 ANN 库只能处理基本的 L1、L2 和余弦距离,这限制了神经查询引擎中可能的处理器空间。

  • 目前,所有复杂查询数据集提供了一个硬编码的查询执行计划,可能不是最优的。需要一个神经查询规划器,能够将输入查询转换为最优执行序列,考虑预测任务、查询复杂性、神经处理器类型和存储层配置。

由于编码器的归纳性和可更新性而无需重新训练,运行推理时在比训练图更大的图上存在需要缓解持续学习灾难性遗忘规模泛化的问题。

了解更多

NGDB 仍然是一个新兴概念,面临许多未来研究的挑战。如果你想了解更多关于 NGDB 的内容,可以查看

我们还将组织研讨会,请关注最新动态!

神经网络 — 初学者指南 (1.1)

原文:towardsdatascience.com/neural-networks-a-beginners-guide-7b374b66441a?source=collection_archive---------11-----------------------#2023-03-20

建立关于神经网络的直觉

ShwetaTowards Data Science Shweta

·

关注 发表在 Towards Data Science ·10 min read·2023 年 3 月 20 日

--

照片由 La-Rel Easter 拍摄,来源于 Unsplash

深度学习在过去十年中经历了巨大的增长。它在图像分类、语音识别、文本转语音、自驾车等方面都有应用,深度学习解决的问题列表非常重要。因此,理解神经网络的基本结构和工作原理对于欣赏这些进展是必要的。

让我们深入探讨学习。

1. 神经网络的构建模块

神经网络是一个计算学习系统,通过使用底层的非线性映射函数来将输入变量映射到输出变量。

它包含五个基本组件:

a. 节点和层

b. 激活函数

c. 损失函数

d. 优化器

我们将详细了解这些组件。

层:

简而言之,神经网络是一系列相互连接的层。神经网络中有三种层类型:输入层 — 接受输入数据,隐藏层 — 转换输入数据,输出层 — 在应用转换后为给定的输入生成预测。接近输入层的层称为下层,接近输出层的层称为上层

每一层由多个神经元组成,也称为节点。给定层中的每个节点与下一层中的每个节点相连。节点接收来自上一层的加权输入总和,应用非线性激活函数,并生成一个输出,该输出随后成为下一层节点的输入。

考虑一个常见的分类问题,例如预测贷款申请者是否会违约。输入变量包括申请者年龄、就业类型、赡养人数、居住地、贷款价值比等。这些变量将组成输入层。

输入层中的节点数量对应于数据中的独立变量数量。隐藏层的数量以及这些层中的节点数是超参数,通常是问题复杂性和可用数据的函数。

在复杂问题中,层的数量和每层中的节点数量将更多,每个隐藏层将学习在上一层未学到的表示。这些神经网络被称为‘深度神经网络’。

对于回归问题,输出层中的节点数量为 1;对于多分类问题,输出层中的节点数量等于标签/类别数量;对于二分类问题,输出层中的节点数量为 1。

神经网络的工作原理可以归结为给定层中的单个节点。

神经网络中单个节点的工作原理(图片由作者提供)

如上所示,单个节点接受以下输入 — 偏置 b 和输入变量 x1 及 x2。它还接受另一个参数作为输入 — 每个独立变量的权重。权重表示输入变量的重要性。

节点将处理加权输入总和,如下所示:

z = w1x1 + w2x2 + bias(公式 1)

然后在给定层中的每个节点上应用激活函数以生成输出。应用激活函数后由节点生成的输出是 a。

f(z) = a(公式 2)

这是神经网络中单层单节点的工作原理。具有多个层和节点的网络也按照相同的原则运行。

2 层神经网络(作者提供的图片)

除了加权输入外,我们还可以看到在上述公式 1 中有一个叫做偏置 ‘b’ 的项。偏置在神经网络中有什么作用?

偏置 是一个帮助激活节点的变量。偏置是激活节点所需的阈值的负值。在给定层中的所有节点中使用一个单独的偏置值。

批次数据通过输入层传递,输入层将其发送到第一个隐藏层。第一个隐藏层中的神经元将基于激活函数的输出进行激活,激活函数接收输入的加权和与偏置并计算特定范围内的一个数字。

这引出了下一个问题 — 什么是激活函数,我们为什么需要它?

激活函数

简单的术语来说:

激活函数用于将节点的输入转换为传递到下一个隐藏层节点的输出值。

技术术语来说:

激活函数,也称为传递函数,定义了如何将输入的加权和与偏置转换为给定层中的节点的输出。它将输出值映射到特定范围,即 0 到 1 或 -1 到 +1,具体取决于所用的函数类型。

神经网络中使用的激活函数有两种类型 — 线性和非线性。

  1. 线性激活函数:

公式为 f(x) = b + Sigma( wi * xi),对所有输入变量 (i) 进行索引。

该函数的范围是:— 无穷大到 + 无穷大。

线性激活函数用于神经网络的外层,以解决回归问题。在输入层或隐藏层中使用它不是一个好主意,因为网络将无法捕捉底层数据中的复杂关系。

2. 非线性激活函数:

非线性激活函数默认是深度学习中最常用的激活函数。这些包括 Sigmoid 或 Logistic 函数、修正线性激活函数(ReLU)和双曲正切函数(Tanh)。

让我们更详细地了解每个。

  1. Sigmoid 激活函数

也称为 Logistic 函数,它接受任何实值作为输入,并在 0 和 1 之间给出输出。

公式为 y = 1/(1+ e^-z),具有 S 形曲线。这里的 z = b + sigma(xi * wi),对 i 输入变量进行索引。

对于非常大的正数 z,e^-z 将为 0,函数的输出将为 1。对于非常大的负数 z,e^-z 将是一个大数,因此函数的输出将为 0。

2. 修正线性激活函数 (ReLU):

它是今天使用最广泛的激活函数。ReLU 具有对所有大于 0 的输入值是线性的,而对其他值则是非线性的属性。

它表示为 f(x) = max(0,x)

3. 双曲正切激活函数:

类似于逻辑函数,它接受任何实数作为输入,并输出范围在 -1 和 +1 之间的值。

它表示为:f(x) = (e^z — e^-z) / (ez+e-z)。其中 z = b + sigma(xi * wi),索引为 i 个输入变量。

Tanh 函数的形状也是 S 形的,但范围不同。

通常在所有层中使用一个激活函数,唯一的例外是输出层。输出层使用的激活函数取决于问题陈述是否要求我们预测一个连续值,即回归,或一个分类值,即二分类或多标签分类。

因此,神经元可以定义为一个包含两个部分的操作——线性组件和激活组件,即神经元 = 线性 + 激活。

上述所有函数及其变体都有一些限制,我将在下一篇文章中介绍。

那么神经网络是如何学习的?

所有参数的权重都以一些随机值进行初始化。加权和被传递到网络的第一个隐藏层。

第一个隐藏层将计算所有神经元的输出,并将其传递给下一个隐藏层中的神经元。请注意,每层的输入值都通过激活函数进行转换,然后发送到下一层。

这种流动会持续到达最后一层,然后计算最终的预测。这种从输入层到输出层的单向流动称为‘前向传播’或‘前向传递’。

我们的网络现在已经生成了最终输出。接下来发生什么?

损失函数

将预测值与实际值进行比较并计算误差。误差的大小由损失函数给出。

损失函数将估计预测值的分布与训练数据中实际目标变量的分布的接近程度。

最大似然估计(MLE)框架用于计算整个训练数据上的误差。它通过估计预测的分布与训练数据中目标变量的分布的匹配程度来完成这一点。

在 MLE 框架下,分类问题的损失函数是 交叉熵,回归问题的损失函数是 均方误差

交叉熵 量度了两个概率分布之间的差异。在神经网络的背景下,它表示预测概率分布与训练数据集中目标变量分布之间的差异 对于给定的一组权重或参数。

对于二分类问题,使用的损失函数是二分类交叉熵;对于多分类问题,使用的损失函数是类别交叉熵。

例如,考虑一个与客户贷款违约相关的二分类问题。假设训练数据包含 5 个客户。

神经网络在第一次前向传播中将计算客户违约的概率。网络为所有 5 个客户生成的输出分别是[0.65, 0.25, 0.9, 0.33, 0.45]。

训练数据中观测值的实际值为[1, 1, 1, 1, 1]。

交叉熵损失定义如下:

图片由作者提供

使用这个方程,上述问题的交叉熵损失(CEL)计算如下:

图片由作者提供

在这里,二分类交叉熵计算了一个分数,这个分数总结了实际概率分布和预测概率分布之间的平均差异,以预测类别 1。给定目标变量的实际值和预测值的损失为0.404。我们如何解释这个值?它有一个相对的解释。最终模型的损失值将远低于 0.404。第五个也是最后一个构建块将帮助我们达到那个最优值。它通过寻找最优的权重和偏置值来最小化损失函数,从而实现这一点。

在多分类问题中,其中目标变量编码为 1 到 n-1 类,类别交叉熵将计算一个分数,这个分数总结了所有类别的实际概率分布和预测概率分布之间的平均差异。

类似地,对于回归问题,均方误差(MSE)是最常用的回归损失函数。MSE 计算为目标变量的预测值与实际值之间的平方差的平均值。由于它是误差的平方,输出总是为正。

MSE 有一些变体,如均方对数误差损失(MSLE)和均值绝对误差(MAE)。选择取决于多个因素,如异常值的存在、目标变量的分布等。

网络在第一次前向传播中生成的输出是由初始化为某些随机值的权重决定的。损失函数比较实际值和预测值并计算误差。下一步是通过改变权重来最小化误差。网络如何实现这一点?

这将我们引入神经网络的最后一个构建块,即优化器。

5. 优化器

如前面部分所讨论的,在神经网络中,学习发生在权重中。训练神经网络涉及到学习所有层中所有神经元的正确权重。这通过使用随机梯度下降算法结合反向传播算法来实现。

鉴于这是一个比上述内容更复杂的概念,我们将在下一篇文章中详细探讨这一点。这里涉及的所有构建块也值得在后续文章中做更详细的解释。

本文的关键要点是 最终的神经网络模型是整体架构的函数,即节点数量、层数等,以及参数(也称为权重)的最佳值。一旦我们解决了这两个组件,就可以自信地预测目标变量。

这里是我找到的一些在理解这个概念方面非常有帮助的链接。

  1. youtu.be/PySo_6S4ZAg — 这是斯坦福大学 CS230 神经网络课程,由 Andrew Ng 主讲。

  2. amzn.eu/d/6U4c3GR — 《用 Python 进行深度学习(第 2 版)》。一本很棒的书。概念用非常简单的语言解释。

  3. machinelearningmastery.com/ — 这是一个涵盖深度学习和机器学习所有基础和中级问题的资源。

希望到现在你对神经网络有了一些理解,并且了解了各种构建块如何结合在一起解决深度学习问题。请告诉我你的想法。

神经网络作为决策树

原文:towardsdatascience.com/neural-networks-as-decision-trees-89cd9fdcdf6a?source=collection_archive---------5-----------------------#2023-04-03

图片由 Jens Lelie 提供,来源于 Unsplash

将神经网络的强大功能与决策树的可解释结构结合起来

Nakul Upadhya数据科学之路 Nakul Upadhya

·

关注 发布于 数据科学之路 · 9 分钟阅读 · 2023 年 4 月 3 日

--

人工智能的近期繁荣清楚地展示了深度神经网络在各种任务中的强大能力,尤其是在数据维度高且与目标变量之间存在复杂非线性关系的分类问题领域。然而,解释任何神经分类器的决策是一个非常困难的问题。虽然许多后置方法如 DeepLift [2] 和 Layer-Wise Relevance Propagation [3] 可以帮助解释单个决策,但解释全局决策机制(即模型通常寻找的内容)则更加困难。

因此,许多高风险领域的从业者更倾向于选择更具可解释性的模型,如基本的决策树,因为决策层级可以被利益相关者清晰地可视化和理解。然而,基本的决策树往往不能提供足够的准确性,通常会使用集成方法如 Bagging 或 Boosting 来提高模型的性能。不过,这又牺牲了一些可解释性,因为要理解一个单一的决策,从业者需要查看数百棵树。然而,这些方法仍然比深度网络更受欢迎,因为至少特征重要性(无论是局部还是全局)可以被轻松提取和展示。

因此,目前的问题是我们想要神经网络的区分能力,但又希望具备决策树的可解释性。那么,为什么不把网络结构化成一棵树呢?这就是 Fross 和 Hinton(2017)在他们的论文“将神经网络提炼成软决策树”[1]中采用的主要方法。在这篇文章中,我将深入探讨神经决策树背后的关键机制,并解释这种方法的一些优点以及在实际应用中可能需要考虑的一些因素。虽然我们主要讨论分类树,但详细的方法也可以应用于回归树,只需进行一些相对较小的调整。

方法论

软决策树与硬决策树

在深入了解如何将神经网络构建成软决策树之前,让我们首先定义什么是软决策树。

当人们想到决策树(例如 sklearn 中实现的决策树)时,他们想到的是每个决策都是确定性的硬决策树。

硬决策树的示例(图片由作者提供)

如果满足某个条件,我们将走向左分支,否则我们走向右分支。每个叶子节点都有一个类别,通过简单地遍历树并选择我们最终到达的类别来进行预测。我们允许树生长得越大,可以采取的路径就越多,从而实现最终决策。

软决策树有许多相似之处,但工作方式略有不同

软决策树的示例(图片由作者提供)

在硬决策树中,每个分支是确定性的,而软决策树定义了在满足条件的情况下进入某个分支的概率。因此,虽然硬决策树输出一个单一值,软决策树则输出所有可能类别的概率分布,其中类别的概率是我们通过到达叶子的概率的乘积。 例如,上面树的批准概率等于 P(b1|X)(1-P(b2|X)) + (1-P(b2|X))(P(b3|X))。 分类决策就是选择具有最高概率的类别。

这种结构有许多优点。首先,非确定性决策使用户了解给定分类中的不确定性。此外,从技术上讲,硬树只是软树的特殊变体,其中所有分支概率都等于 1。

这些树的一个缺点是解释性略有下降。从利益相关者的角度来看,“我们批准了一个贷款,因为个人年收入为 $100k,债务少于 $400k”比起:

如果收入是 $110k,我们有 0.7 的概率向右分支;如果债务低于 400k,我们有 0.8 的概率批准,这样结果就是 0.56 的概率加上左分支中发生的情况。

这并不意味着这些树不可解释(因为仍然可以确切看到模型关注的内容),只是需要模型开发者提供更多的帮助。

倾斜决策树

在了解神经决策树之前,第二个需要掌握的概念是“倾斜”决策树的概念。

传统的决策树被认为是“正交”树,因为它们的决策是相对于给定的轴进行的。简单来说,每次决策中只使用一个变量。另一方面,倾斜树在决策过程中使用多个变量,通常是线性组合的形式。

倾斜决策边界的示例(图源:Zhang et. al 2017 [4])

决策节点中的一些示例值可能是“收入 — 债务 > 0”。这可以导致更强的决策边界。一个缺点是,如果没有适当的正则化,这些边界可能会变得越来越复杂。

将它们结合起来

现在我们理解了软决策树和倾斜决策树,我们可以将它们结合起来理解神经公式。

第一个组成部分是决策节点。对于每个节点,我们需要基于输入值的一些概率。为了实现这一点,我们可以使用神经网络的基本工具:权重和激活。在每个决策节点中,我们首先对输入变量进行线性组合,然后对总和应用一个 sigmoid 函数,得到分支概率。

为了防止极软的决策(使决策树更像硬决策树),可以使用温和的 sigmoid(或在应用 sigmoid 之前对线性组合进行乘法运算)。

每个叶节点包含一个 N 维张量,其中 N 是类别的数量。这个张量表示样本属于某一类别的概率分布。

神经网络作为决策树(图像复制自 Frosst & Hinton 2017 [1])

与软决策树一样,这棵神经树的输出是类别的概率分布。输出分布等于分布的总和乘以到达该分布的路径概率。

训练树

神经树的一个好处是可以通过像梯度下降这样的连续优化算法进行训练,而不是像普通决策树需要构建的贪婪算法。我们需要做的就是定义损失函数:

神经树的损失函数(图像来自 Frosst & Hinton 2017 [1])

这棵树的损失函数类似于交叉熵损失。在这个方程中,P^l(x) 是在给定数据点 x 的情况下到达叶节点 l 的概率,T_k 是目标类别 k 的概率(1 或 0),而 Q_k^l 是叶节点 l 中与类别 k 对应的张量(概率分布)元素。

关于这一结构的一个重要说明是树形结构是固定的。与使用贪婪算法逐个拆分节点并生长树的普通决策树不同,使用这种软决策树时,我们首先设置树的大小,然后使用梯度下降同时更新所有参数。这种方法的一个好处是更容易在不损失太多判别能力的情况下约束树的大小。

在训练过程中可能遇到的一个潜在陷阱是模型可能过度偏向单个分支,而未能利用树的全部力量。为了避免陷入不良解决方案,建议在损失函数中引入惩罚,鼓励树同时利用左右子树。

惩罚是期望平均分布(左右树各 50/50)与实际平均分布(定义为 alpha)之间的交叉熵。

节点 i 的 alpha 定义(图像来自 Frosst & Hinton 2017 [1])

在这个方程中,P^i(x) 是从根节点到节点 i 的路径概率。我们然后对所有内部节点的惩罚进行求和。

惩罚的定义(图像来自 Frosst & Hinton 2017 [1])

在这个方程中,lambda 是一个超参数,决定了惩罚的强度。然而,这可能会导致一些问题,因为随着树的下降,数据分裂成 50/50 的机会减少,因此建议使用根据树的深度变化的自适应 lambda。这将修改惩罚为:

修改后的惩罚函数(图片由作者提供,摘自 Frosst & Hinton 2017 [1])

当我们深入树中时,建议根据 2^-d 比例衰减 lambda。

结果可视化

虽然将神经网络重新表述为树形结构很有趣,但追求这种方法的主要原因是提供更多的模型可解释性。

首先看看经典问题的解释——MNIST 中的数字分类:

MNIST 示例(图片来自 Frosst & Hinton 2017 [1])

在上图中,内部节点的图像是学习到的过滤器,叶节点的图像是学习到的类别概率分布的可视化。对于每个叶节点和节点,最可能的分类用蓝色标注。

从这棵树来看,我们可以看到一些有趣的特征。例如,如果我们查看最右边的内部节点,潜在的分类是 3 和 8。实际上,我们可以在决策节点可视化中看到 3 的轮廓。白色区域似乎表明模型寻找能够闭合 3 的内部循环的线条,从而将其转换为 8。我们还可以看到模型在左侧倒数第三个节点中寻找 0 的形状。

另一个有趣的例子是预测 Connect4 游戏中的胜利

可视化神经决策树前 2 层预测 Connect4 游戏胜者(图片来自 Frosst & Hinton 2017)

这个例子中的学习到的过滤器表明,游戏可以分为两种不同类型:一种是玩家主要集中在棋盘边缘的游戏,另一种是玩家在棋盘中心放置棋子的游戏。

结论

将神经网络构建为软决策树使我们能够利用神经网络的强大能力,同时仍保留一些可解释性。正如 MNIST 数据集上的结果所示,学习到的过滤器可以提供局部和全局的解释能力,这对于高风险任务也是一种受欢迎且有帮助的特性。此外,训练方法(一次优化和更新整个树)使我们能够在保持树的大小固定的情况下获得更多的区分能力,这是我们在正常决策树中无法实现的。

尽管如此,神经树仍然不完美。树的软性特征意味着使用这些树的数据科学家需要在向非技术利益相关者展示之前“预处理”树,而普通决策树可以直接展示(因为它们相对自解释)。此外,虽然树的斜向特性有助于准确性,但在给定节点中变量过多会使解释变得更加困难。这意味着正则化不仅是推荐的,而且在一定程度上是必要的。此外,无论模型多么可解释,仍然存在对利益相关者可理解的解释性特征的需求。

然而,这些缺点并未削弱这些模型在推动解释性与性能前沿方面的潜力。我强烈建议大家在下一个数据科学任务中尝试这些模型。我也推荐大家阅读原始论文

资源与参考文献

  1. 在 PyTorch 中实现软决策树

  2. 想了解更多关于 XAI 和时间序列预测的信息,请关注

参考文献

[1] N. Frosst, G. Hinton. 将神经网络蒸馏为软决策树 (2017). 2017 人工智能行动会议

[2] A. Shrikumar, P. Greenside, A. Jundjae. 通过传播激活差异学习重要特征 (2017). 国际机器学习会议 PMLR 2017。

[3] S.Bach, A. Binder, G. Montavon, F. Klauschen, K-R. Muller, W. Samek. 基于层级相关传播的非线性分类器像素级解释 (2015). PloS one, 10(7), e0130140

[4] L. Zhang, J. Varadarajan, P. N. Suganthan, N. Ahuja, P. Moulin. 使用斜向随机森林的鲁棒视觉跟踪 (2017). 2017 计算机视觉与模式识别会议。

具有多个数据源的神经网络

原文:towardsdatascience.com/neural-networks-with-multiple-data-sources-ef91d7b4ad5a?source=collection_archive---------4-----------------------#2023-01-06

如何使用 Tensorflow 设计一个具有多个数据源输入的神经网络

Morgan Lynch数据科学前沿 Morgan Lynch

·

关注 发表在 数据科学前沿 ·5 分钟阅读·2023 年 1 月 6 日

--

具有多个数据源的卷积神经网络。图片来源:作者。

在许多使用场景中,神经网络需要并行训练多个数据源。这些包括医学应用场景,其中可能会有一张或多张图像与结构化的患者数据一起使用,或者多图像应用场景,其中不同对象的图像贡献到单一输出。例如,使用个人房屋和汽车的独立照片来预测他们的收入。

集体数据不能一体处理,因为每个数据源都有其独特的属性和形状。为了成功设计一个网络,每个输入流需要单独处理和训练。

使用具有多个独立输入的 CNN 已被证明比单一图像输入提高了准确性。在一项研究中[1],处理了三个不同的图像输入分支,并将它们合并,结果比单独处理图像提高了 8% 的准确性。

此外,还显示出在 CNN 设计中,晚期合并网络分支也能产生更好的准确性[2]。这种晚期合并意味着在实际操作中,输入分支应该在合并到最终模型并生成预测之前,几乎完全作为独立网络进行处理。

我们将详细讨论如何设计这种类型的卷积神经网络(CNN),通过一个理论上的患者数据示例,其中包含一个数据 CSV 文件和一张图像。我们将只考虑一个图像输入,但这种方法也可以用于每个患者的多个图像。

首先,必须加载源文件并将其处理为 Pandas 数据框。下方示例中,加载了一个简单的数据集,包括患者 ID、患者年龄和一个标志,表示是否已诊断出癌症。

需要注意数据框的形状,因为这将影响后续网络的设计。

接下来,我们必须为每个患者加载一张图像。这是通过对患者数据框进行迭代来完成的,以保持记录的顺序。

图像数据也被转换为 numpy 数组,以保持与从文件中加载的患者数据的一致性。

加载数据的形状。图片由作者提供

我们现在需要考虑已加载数据的形状。对于图像,如果每个图像的尺寸为 512x512 像素,并且我们有n张图像,那么数据的形状为(n, 512, 512)。对于具有多个通道的图像,可能会添加进一步的维度,但我们将保持这个示例简单。

对于结构化的患者数据,我们的文件中有三列和n条记录。这将导致数据形状为(n, 3)。患者 ID 列在训练中是不需要的,因此这列可能会被删除,从而得到最终训练数据的形状为(n, 2)。

数据的进一步预处理,如缩放,不在本讨论范围之内。对于本示例,我们将直接使用原始数据。

然而,在设计神经网络之前,还需要一步。这一步是将数据分割成训练集和测试集。这需要在一个步骤中完成,以保持数据集的顺序和分割。下面的示例演示了如何使用 scikit-learn 来完成这一步骤:

一旦分割完成,我们就可以从两个数据集中提取目标特征作为我们的‘y’数据集。检查两个结果训练数据集的形状应该产生类似于以下的输出:

(1200, 512, 512)

(1200, 3)

记录数为 1,200。两个数据集需要具有相同数量的记录,以便它们可以在神经网络的输出中合并。

现在我们可以使用 Keras 函数式 API 开始设计神经网络。首先,我们将从结构化的患者数据开始:

网络的设计可以有所不同,但最好包括一个归一化层。归一化层仅适应于训练数据。

重要的是,输入层的形状设置为数据中的列数(在此示例中为 3)。

输出层的形状也很关键,因为这是将与图像处理分支合并的形状。这由最后的 Dense 层决定。在此示例中,输出层的形状将是:

(无, 64)

其中‘None’是 Keras 对记录数量的解释,未指定。

数据分支现在已经完成,我们可以查看图像处理分支。虽然可以设计自己的网络,但在实践中使用预设计的模型更为方便。在此示例中,我们将使用 Keras Applications 中的 Resnet-50。

如上所示,输入形状是每个图像的大小,加上一个额外的维度用于图像通道(在此案例中为 1)。

在 Resnet 模型的末尾添加一个全连接的 Dense 层,以使输出与数据分支的形状相同:

(无, 64)

因为我们在整个过程中都注意了数据的形状,所以现在能够合并两个分支的输出:

两个分支被连接在一起,最后添加一个全连接的 Dense 层以将模型减少到最终的预测。这里使用的激活函数可以有所不同。在此示例中,使用线性激活函数来输出实际的类别概率。

CNN 的最终设计总结如下:

最终 CNN 设计。图片由作者提供

如上所示,如果仔细考虑数据的形状,可以成功地合并多个分支。然后,可以使用这个合并后的模型从多个数据源生成单一的预测。

感谢阅读。

参考文献:

[1] Yu Sun, Lin Zhu, Guan Wang, Fang Zhao, “Multi-Input Convolutional Neural Network for Flower Grading”, Journal of Electrical and Computer Engineering, vol. 2017, Article ID 9240407, 8 pages, 2017. doi.org/10.1155/2017/9240407

[2] Seeland M, Mäder P (2021) Multi-view classification with convolutional neural networks. PLoS ONE 16(1): e0245230. doi.org/10.1371/journal.pone.0245230

神经原型树

原文:towardsdatascience.com/neural-prototype-trees-f7bac36437a9

通过模仿人类推理来实现可解释的图像分类。

Nakul UpadhyaTowards Data Science Nakul Upadhya

·发布于 Towards Data Science ·6 分钟阅读·2023 年 6 月 2 日

--

机器学习和人工智能现在被应用于大量领域,但随着使用的增加,模型面临着更多的风险和伦理测试。

让我们通过最近的新闻来激发思考,关于一辆特斯拉在自动驾驶模式下撞上树木的事件。根据当局的说法,司机表示车辆在她启用自动驾驶模式后向右偏移,驶离了道路,并撞上了一棵树。目前,这些说法正在调查中,但想象一下,识别汽车突然做出怪异决策的原因是多么困难。它是否发生了误分类?它看到了什么让它困惑的东西?对于传统的黑箱模型来说,调查模型内部非常困难且昂贵。

那么,替代方案是什么?是否有一种可解释的图像分类方法?是的,通过原型学习[2]和神经原型树[1]!借助这些架构,模型采用了一种非常直观的预测方法:识别看起来熟悉的部分。那只鸟有长长的喙吗?它有红色的喉咙吗?那一定是一只蜂鸟!

在本文中,我旨在提供有关这些模型如何工作的资讯,并讨论使用这些模型的一些优缺点。我将频繁引用的两篇主要论文如下,我强烈建议有兴趣的读者阅读这些论文:

[## 这看起来像那样:用于可解释图像识别的深度学习

当我们面临具有挑战性的图像分类任务时,我们通常通过解构图像来解释我们的推理……

arxiv.org openaccess.thecvf.com [## CVPR 2021 开放获取库

神经原型树用于可解释的细粒度图像识别 Meike Nauta、Ron van Bree、Christin Seifert 等

openaccess.thecvf.com

这是我撰写的关于神经决策树变体的第二篇文章。如果你还没有读过第一篇文章,我强烈建议你浏览一下,因为这里阐述的许多概念都是基于标准神经树构建的。

towardsdatascience.com ## 神经网络作为决策树

利用神经网络的强大能力和决策树的可解释结构

[towardsdatascience.com

什么是原型?

图像识别中的原型思想首次由 Chen & Li 等(2019)在他们的论文“这看起来像那样:用于可解释图像识别的深度学习” [2] 中提出。这是一种潜在表示,表示与给定类别相关的某些训练图像补丁。正如名字所示,该模型通过解剖输入图像并找到提供证据的原型部分来工作,以表明图像属于某一类别。网络简单地计算欧几里得距离,并将其反转以创建相似度分数。这些分数随后通过一个全连接层生成最终的类别概率:

图 1: ProtoPNet 架构(图源自 Chen & Li 等,2019 [2])

一旦模型训练完成,用户可以简单地将学到的原型与训练集中的补丁进行匹配,从而创建对任何预测的非常可解释的解释:

图 2: 鸟类预测的推理过程(图源自 Chen & Li 等,2019 [2])

神经原型树

使用原型包模型(例如来自原始原型论文 [2] 的 ProtoPNet)的一个问题是,原型匹配是同时进行的,但人类图像识别依赖于一系列步骤。如果某物没有爪子或耳朵但有尾巴,那它可能不是猫,因此网络不应该给它打上猫的标签。这就是神经原型树 [2] 发挥作用的地方。与原型包不同,Nauta 等人 [1] 选择将他们的模型设计为神经决策树。这种决策树提供了这种顺序决策,并提供了全局可解释性,而不仅仅是局部可解释性。

软决策树的回顾

神经原型树是软决策树,而不是硬决策树。虽然硬决策树强制执行确定性分支(你要么向左走,要么向右走),但软决策树使用的是概率分支(你有 p 的概率向左走和 1-p 的概率向右走)。此外,虽然硬决策树输出一个单一的值,但软决策树输出的是所有可能类别的概率分布,其中类别的概率是到达叶子节点所经过概率的乘积,分类决策则是概率最高的类别。

原型树中的决策制定

像标准神经树一样,每个叶子节点包含一个关于类别的概率分布。分支决策是通过计算图像补丁到节点中给定原型的距离来做出的。每张图像的得分是图像中补丁与原型之间找到的最小距离,然后转换为概率。简单来说:如果在图像中找到原型,就向右走;如果找不到该原型,就向左走!

图 3:ProtoTree 中的预测机制(图源自 Nauta 等人 2021 [1])

这个机制显然允许我们使用与可视化普通原型模型相同的机制来可视化出极其可解释的模型。

图 4:原型树的全局解释(图源自 Nauta 等人 2021 [1])

学习叶子分布

在普通决策树中,叶子的标签是通过查看最终到达该叶子的样本来学习的,但在软树中,叶子中的分布是全局学习问题的一部分。然而,作者注意到将叶子的学习与原型的学习结合在一起会导致分类结果不佳。为了解决这个问题,他们利用了一种无导数策略来获取叶子概率的更新方案:

图 5:更新方案。c_l^t 是第 t 轮中叶子 l 的叶子概率。y 是真实值,yhat 是预测值。pi 是到该叶子的路径概率。(图自 Nauta 等,2021 [1])

该更新方案与小批量梯度相结合,以学习原型和卷积参数,从而创建一个高效的学习过程。

剪枝

为了提高可解释性,作者还引入了剪枝机制。如果一个叶节点包含有效的均匀分布,它的区分能力不强,因此最好对其进行剪枝,因为较小的树更易于阅读和解释。从数学上讲,作者定义了一个阈值 t 并移除所有最高类别概率小于 tmax(c_l) ≤ t)的叶子。如果一个子树中的所有叶子都被移除,则可以移除该子树及其相关原型,从而使树变得更加紧凑。通常,t = 1/K + epsilon* 其中 K 是类别数,epsilon 是一个非常小的数值,表示容差。

图 5:剪枝可视化(图自 Nauta 等,2021 [1])

性能

图 6:平均准确率和标准差。ProtoTree ens. 是 3 棵或 5 棵原型树的集合。 (图自 Nauta 等,2021 [1])

作者使用 CARS 和 CUBS 数据集对他们的方法进行了基准测试,与其他可解释的图像识别方法(如基于注意力的可解释性方法)进行比较。他们发现,通过使用相对较小的树木集合(9 棵和 11 棵),他们能够接近 SOTA 准确率。

结论

可解释的深度学习图像分类器相对于黑箱模型提供了许多优势。它们可以帮助建立信任,改善调试,并解释预测。此外,它们还可以用于探索数据,了解不同特征之间的关系。

总的来说,神经原型树是一种有前途的新方法,用于以可信赖的方式进行图像识别。如果医生能够检查模型所观察到的图像的特征,他更可能相信癌症检测模型。这些原型树甚至可以通过添加注意力等措施进一步提高准确性!

资源和参考文献

  1. 神经原型树的 Github: github.com/M-Nauta/ProtoTree

  2. 如果你对可解释的机器学习和人工智能感兴趣,可以考虑关注我:medium.com/@upadhyan

参考文献

[1] M. Nauta, R.v. Bree, C. Seifert. 神经原型树用于可解释的细粒度图像识别 (2021). IEEE/CVF 计算机视觉与模式识别会议(CVPR),2021

[2] C. Chen, O. Li, C. Tao. A.J. Barnett, J. Su, C. Rudin. 这看起来像那样:可解释图像识别的深度学习 (2019). 第 33 届神经信息处理系统会议。

新的 ChatGPT 提示工程技术:程序模拟

原文:towardsdatascience.com/new-chatgpt-prompt-engineering-technique-program-simulation-56f49746aa7b?source=collection_archive---------0-----------------------#2023-09-03

Giuseppe ScalamognaTowards Data Science Giuseppe Scalamogna

·

关注 发表在 Towards Data Science · 9 分钟阅读 · 2023 年 9 月 3 日

--

来源:作者提供的图像,使用 MidJourney 生成

提示工程的世界在各个层面上都非常迷人,并且有很多巧妙的方法可以引导像 ChatGPT 这样的代理生成特定类型的响应。诸如链式思维(CoT)、基于指令、N-shot、Few-shot 甚至像奉承/角色分配这样的技巧都是灵感的来源,激发了许多满足各种需求的提示库。

在这篇文章中,我将深入探讨一种技术,根据我的研究,这种技术可能尚未被充分探索。虽然我会暂时将其标记为“新”,但我会避免称其为“创新”。鉴于提示工程中的创新速度之快以及新方法的易于开发,这种技术可能已经以某种形式存在。

该技术的本质在于使 ChatGPT 以模拟程序的方式运行。我们知道,程序由一系列指令组成,这些指令通常打包成函数以执行特定任务。从某种程度上说,这种技术是基于指令和基于角色的提示技术的结合。但与这些方法不同的是,它寻求利用一个可重复且静态的指令框架,使得一个函数的输出可以影响另一个函数,并且整个交互保持在程序的范围内。这种模式应该与像 ChatGPT 这样的代理中的提示-完成机制很好地契合。

来源:作者提供的图像

为了说明这项技术,让我们在 ChatGPT4 中指定一个迷你应用程序的参数,该应用程序旨在作为互动创新工作坊。我们的迷你应用程序将包含以下功能和特点:

  1. 处理新想法

  2. 扩展想法

  3. 总结想法

  4. 检索想法

  5. 继续处理先前的想法

  6. Token/“记忆”使用统计

需要明确的是,我们不会要求 ChatGPT 用任何特定编程语言编写迷你应用程序,我们将在我们的程序参数中反映这一点。

根据这个程序大纲,让我们开始编写启动提示,以在 ChatGPT 中实例化我们的互动创新工作坊迷你应用程序。

程序模拟启动提示

Innovator’s Interactive Workshop Program

I want you to simulate an Innovator’s Interactive Workshop application whose core features are defined as follows:

1\. Work on New Idea: Prompt user to work on new idea. At any point when a user is ready to work through a new idea the program will suggest that a date or some time reference be provided. Here is additional detail on the options:
  a. Start from Scratch: Asks the user for the idea they would like to work on.
  b. Get Inspired: The program assists user interactively to come up with an idea to work on. The program will ask if the user has a general sense of an area to focus on or whether the program should present options. At all times the user is given the option to go directly to working on an idea.
2\. Expand on Idea: Program interactively helps user expand  on an idea.
3\. Summarize Idea: Program proposes a summary of the idea regardless of whether or not it has been expanded upon and proposes a title. The user may choose to rewrite or edit the summary. Once the user is satisfied with the summary, the program will "save" the idea summary.
4\. Retrieve Ideas: Program retrieves the titles of the idea summaries that were generated during the session. User is given the option to show a summary of one of the ideas or Continue Working on a Previous Idea.
5\. Continue Working on Previous Idea: Program retrieves the titles of the idea summaries that were generated during the session. User is asked to choose an idea to continue working on.
6\. Token/Memory Usage: Program displays the current token count and its percentage relative to the token limit of 32,000 tokens.

Other program parameters and considerations:

1\. All output should be presented in the form of text and embedded windows with code or markdown should not be used.
2\. The user flow and user experience should emulate that of a real program but nevertheless be conversational just like ChatGPT is.
3\. The Program should use emojis in helping convey context around the output. But this should be employed sparingly and without getting too carried away. The menu should however always have emojis and they should remain consistent throughout the conversation.

Once this prompt is received, the program will start with Main Menu and a short inspirational welcome message the program devises. Functions are selected by typing the number corresponding to the function or text that approximates to the function in question.  "Help" or "Menu" can be typed  at any time to return to this menu. 

如果你想以更互动的方式跟随并自己测试,可以随意将提示加载到 ChatGPT4 中。

这是 ChatGPT 对提示的完成结果。

到目前为止,一切顺利。我们已经启动了我们的“迷你应用程序”,收到了振奋人心的欢迎消息,并且展示了一个与我们的程序参数一致的功能菜单。让我们通过提交“1”来测试我们的迷你应用程序的功能,以启动“处理新想法”功能。

对话继续很好地遵循我们设定的“程序”结构,适当地提供了符合参数的完成。让我们继续从零开始构建一个想法,并让程序与我们合作,开发一种用于生长建筑物而非建造它们的技术。

有趣的是,我们注意到“程序”在没有明确指示的情况下自动调用“Expand on Idea”功能。鉴于程序的目标,这种行为并不不当,可能受到我们最初上下文设置的影响,这些设置引导聊天代理像程序一样运行。让我们继续深入探讨一下增长建筑所需的技术。

现在让我们检查一下用于增长建筑的材料。

我继续沿着这些思路前进,现在,让我们看看是否可以返回到菜单。

菜单仍然完整。让我们尝试让程序执行 Summarize Idea 功能。

我对这个标题和摘要暂时满意,所以让我们“保存”它。

很快,我们将测试检索我们“保存”的想法,以检查我们在实现数据持久性方面的努力是否成功。另一方面,调整我们的“迷你应用”以省略保存后的重复摘要可能会有所帮助。

角色启动作为程序的结果是在输出中包含主菜单——这种行为在程序的背景下是合理的,即使它在我们的程序定义中没有被明确配置。

接下来,让我们测试我们的令牌计数功能。

为了核实准确性,我转向 OpenAI 的分词器工具。

令牌计数不准确,证据在于显著的差异——我们的程序报告大约有 1,200 个令牌,而分词器工具显示为 2,730。鉴于这种不匹配,明智的做法是从程序中移除此功能。我不会深入讨论为何这种任务通常对语言模型来说是个问题,以及功能损失相对较小。最终,我预计这样的功能将会原生集成到 ChatGPT 中,特别是考虑到令牌计数信息在后台不断传递。

接下来,让我们深入研究“Get Inspired”功能以生成新想法。为了简洁起见,我将进一步展示对话。正如你所见,我选择深入探讨一个我们的程序建议的废物转化为能源的无人机概念,概括了这个想法,并让我们的程序“保存”了它。

一切看起来都很好,系统甚至擅自给我们的想法命名为“SolarSky”。为了更有效地实现这一点,我们可能会在程序定义中为此任务加入一个独立的函数,或者在“工作在新想法”或“扩展新想法”函数中提供更具体的指示。同样,我们在完成中看到菜单,这从程序流的角度看是合乎逻辑的。

现在让我们看看是否可以“检索想法”。

这似乎符合我们的原始指示,仅提供了所请求的标题。它还提示我们继续工作一个想法,即使这并没有明确地编程到迷你应用程序中。接下来,让我们评估它是否保持了根菜单索引。为此,我将输入“5”,对应于“继续工作在之前的想法”功能,看看是否有效。

显然,索引在对话上下文中被维护,并且相应地调用了函数。这一观察值得注意,特别是在考虑到多个索引可能处于活动状态的情况下。这引发了有关“程序”在这种条件下如何表现的有趣问题。你可能错过了,但在我们互动的早期,程序在征求用户对想法扩展选择的输入时实际上采用了索引技术。

让我们继续工作在我们的建筑构想上。

再次看起来不错。“程序”的行为如预期,并且也跟踪了我们在想法扩展过程中暂停的确切点。

让我们在这里停止测试我们的提示,看看通过这种技术我们学到了什么。

结论与观察

坦白说,这次练习虽然在范围和功能上都有限,但超出了我的预期。我们本可以让 ChatGPT 用 Python 等语言编写这个迷你应用程序,然后利用代码解释器(现在称为高级数据分析)在持续的 Python 会话中运行它。然而,这种方法会引入一种刚性,使得启用我们迷你应用程序中固有的对话功能变得困难。更不用说,特别是在具有多个重叠功能的程序中,我们立即面临着代码无法正常工作的风险。

ChatGPT 的表现尤其令人印象深刻,因为它以高度逼真的方式模拟了程序行为。提示完成保持在程序定义的边界内,即使在函数行为没有明确规定的情况下,完成也在迷你应用程序目的的上下文中有逻辑性。

这种程序模拟技术可能与 ChatGPT 的“自定义指令”功能配合良好,尽管值得一提的是,这样做会将程序的行为应用于所有后续的互动中。

我的下一步包括对这种技术进行更深入的研究,以评估是否可以通过一个全面的测试框架来了解这种方法相对于其他提示工程技术的表现。这种练习也可能帮助确定这种技术最适合哪些特定任务(或任务类别)。敬请关注更多信息。

与此同时,希望你在互动中发现这种技术和提示有帮助。如果你想进一步讨论这种技术,请随时通过LinkedIn与我联系。

除非另有说明,本文中的所有图片均由作者提供。

新数据表明 2023 年是有史以来最热的夏天

原文:towardsdatascience.com/new-data-demonstrates-that-2023-was-the-hottest-summer-ever-d92d500a8f01

气候变化|数据可视化

我们在 Python 和 Plotly 中开发可视化,以分析 2023 年 6 月至 8 月期间记录的最高气温

艾伦·琼斯面向数据科学 艾伦·琼斯

·发布于面向数据科学 ·阅读时间 11 分钟·2023 年 9 月 28 日

--

图片来源:路易斯·格拉特罗Unsplash

今年夏天比 1880 年以来的任何时候都要热!

数据科学家如何帮助展示我们的气候正在迅速变化,并帮助传达情况的严重性?我们将探讨如何通过分析和可视化有效地呈现数据,并对数据的表示进行优化。

但首先,让我们简要探讨一下全球变暖的一些后果。然后,我们将考虑如何使用 Plotly 和 Python 有效地可视化数据,并展示这些数据如何与 CO₂排放相关。

这张地图展示了 2023 年气象夏季(6 月、7 月和 8 月)的全球温度异常情况。它显示了地球不同区域相对于 1951 年至 1980 年基准平均值的温暖或凉爽程度。来源:NASA 地球观测台/劳伦·多芬经授权使用

这张来自 NASA 的地图显示了与 1951 年至 1980 年平均值相比,今年夏天的全球温度异常情况,我并不感到惊讶,因为我所在的欧洲地区是全球温度变化最高的区域之一——西班牙的气温高达 40 摄氏度(约 104 华氏度)并不罕见。

影响

高温加剧了加拿大、夏威夷、欧洲部分地区以及其他地方的野火,并可能促成了全球范围内的强降雨事件。

“Drevenochoria” 火灾从阿提卡的伊利翁看到的图像,拍摄时间大约为 7 月 18 日凌晨 2 点。图片由Sthivaios提供 CC BY-SA 4.0

欧洲的野火在夏季并不罕见,但今年特别猛烈,尤其是在希腊,岛屿如罗德岛和科孚岛等地区进行了疏散,许多人死伤。而且,在夏威夷,一个小镇被完全摧毁了。

除了野火,欧洲还遭遇了大规模的暴雨(特别是来自丹尼尔风暴的暴雨),希腊再次受到严重影响——洪水造成了数百万欧元的损失。

2023 年 9 月 9 日,丹尼尔气旋(也称为丹尼尔风暴)在利比亚北部。来自 NOAA-20 卫星的 VIIRS 影像——worldview.earthdata.nasa.gov/,公共领域

由于丹尼尔风暴,利比亚的洪水造成了巨大的破坏,并造成了数千人遇难,当时两个大坝崩溃,摧毁了地中海沿岸的德尔纳大部分地区。电视新闻播报了幸存者从曾经的家园废墟中被救出的悲惨场景。

虽然不可能明确证明这些灾难与气候变化之间的联系,一项由 NASA 主导的研究 确认了随着气温升高,严重干旱和过度降水的发生频率增加。

Carbon Brief(一个总部位于英国的网站,涵盖气候科学、气候政策和能源政策的最新发展)建议,93% 的极端高温事件经科学家评估都因为气候变化而变得更可能发生或更严重。

这并不令人惊讶:较高的温度意味着森林变得更干燥,更易燃,而当下雨时,这些较高的温度意味着大气能容纳更多的水蒸气,从而有更多的 H₂O 可降雨。

统计数据展示

根据他们的新闻稿,今年 6 月、7 月和 8 月的综合温度比任何其他记录中的夏季高出 0.23 摄氏度,比 1951 年至 1980 年的平均夏季温度高出 1.2 摄氏度。

数据由 NASA 戈达德太空研究所(GISS)[1]的科学家在纽约记录,记录了 6 月、7 月和 8 月的温度异常——这些月份被认为是北半球的气象夏季。

数据覆盖了从 1880 年至当前年份,并记录了与 1951 年至 1980 年计算的平均值相比的夏季温度变化。

我们可以绘制一个简单的折线图,从中可以清晰地看到温度的逐渐上升。1980 年后的温度逐渐上升,而 1951 年之前的温度大多低于平均水平,而 1951 年至 1980 年之间的温度则趋于接近平均温度。

2023 年 6 月、7 月和 8 月的全球温度。数据来源于 NASA 的 GISS [1] — 作者图片

然而,由于数据取自每年三个月的时间段,其连续性不如代表一系列相邻月份的数据。因此,也许对于这种数据,柱状图可能是更好的选择。

下图更清楚地展示了自 1880 年以来全球温度在 6 月、7 月和 8 月的变化情况,柱状图更好地表现了数据。你可以更容易地看到,今年的温度明显高于任何近期年份,并且比 1951 年至 1980 年的平均夏季温度高出很多。

2023 年 6 月、7 月和 8 月的全球温度。数据来源于 NASA 的 GISS [1] — 作者图片

因此,柱状图可能比折线图更好地传达数据。但我们仍然可以通过使用颜色来使其更加清晰。

2023 年 6 月、7 月和 8 月的全球温度。数据来源于 NASA 的 GISS [1] — 作者图片

上面的图表比之前的图表更戏剧性地展示了温度变化。

图表中的颜色编码(尽管严格来说是多余的)强调了上升趋势,较高的温度由更暖的颜色表示,深色主题与浅色形成鲜明对比,并突出了近期更热的年份。

我使用 Plotly 创建了柱状图,采用了深色方案和 Plotly 的‘inferno’颜色范围——我们将在下面看到 Python 代码。

数据

数据[1]覆盖了 1880 年至 2023 年,包括每年每个月的温度异常以及例如我们上面看到的 6 月至 8 月的月均温度。有一个全球数据文件,还有两个分别针对北半球和南半球的文件。

我已将数据复制到一个 GitHub 仓库,并编写了一个 Jupyter Notebook,该 Notebook 读取这些数据并生成我在这里使用的所有图表(文章顶部的地图除外——那张图片是由 NASA/GISS 的好人们制作的)。你可以在那里下载任何或所有的材料,我将在本文末尾提供一个链接。

这是一个将全球数据读取到 Pandas 数据框中的代码副本。

title, df = readdata.read_GLB()

我创建了一些辅助函数来读取文件,这些函数将文件的标题和数据分开,你将在仓库中找到它们。

这是数据的截图:

2023 年 6 月、7 月和 8 月的全球气温作为数据框。数据来自 NASA 的 GISS [1] — 图片由作者提供

一些数据缺失——主要是尚未发生的月份——但我们不必担心这些,因为我们不会使用任何包含缺失数据的列。

创建条形图只需三行代码。

period = 'JJA'
scale = 'inferno'
px.bar(df, x='Year', y = period, color="JJA", title = f"{title} - {period}", 
       color_continuous_scale=scale, template='plotly_dark')

图表绘制了北半球夏季的数据,这些数据在 JJA 列中找到。由于我在尝试不同的颜色比例,所以我还使用了一个变量 scale 来定义这个比例。该时期和标题(在之前的代码中设置)用于为图表创建标题,除此之外,它只是一个简单的 Plotly Express 条形图。

线图的代码类似于:

px.line(df, x='Year', y = period,  title = f"{title} - {period}", 
            template='plotly_white')

我对默认的 Plotly 颜色方案不太感冒,所以在这里我使用了我更喜欢的 plotly_white

CO₂ 排放

我们能展示全球气温持续上升的原因吗?恐怕不能简单地做到这一点。但我们可以展示 CO₂ 排放量的上升与气温上升之间的相关性,并指出我们知道 CO₂ 排放量由于人为活动而增加,并且大气中的 CO₂ 增加会导致变暖。

我使用了一个 Our World in Data 的 GitHub 仓库 [3] 来创建我们接下来将使用的数据。同样,它被复制到我自己的一个仓库中,我将其处理成较小的文件。链接将在文章末尾出现。

我们将创建一个这样的全球 CO₂ 排放线图:

全球 CO2 排放。数据来自 OWID [3],图片由作者提供

首先,让我们获取数据:

f = "https://raw.githubusercontent.com/alanjones2/CO2/master/data/world_df.csv"
co2 = pd.read_csv(f)

它看起来像这样:

我们只对“年份”和“年度 CO₂ 排放量”列感兴趣,从中我们可以绘制图表。

该图表类似于温度变化图。它有相同的平坦开始,并且在图表的后半部分上升更陡。它们在这里一起展示。

全球 CO2 排放。数据来自 OWID [3],图片由作者提供

2023 年 6 月、7 月和 8 月的全球气温。数据来自 NASA 的 GISS [1] — 图片由作者提供

我们可以接受科学共识,即 CO₂排放增加了全球变暖,但同时我们也需要接受其他因素的影响。

温度变化不仅仅是由于人类将温室气体排放到大气中。正如美国环境保护署(EPA)所明确指出的,还有其他因素,例如太阳活动和由于例如森林砍伐造成的地球反射率变化。还有除了二氧化碳之外的其他温室气体,如甲烷和一氧化二氮。

美国环境保护署(EPA)还明确表示,除了人为排放的温室气体之外,没有其他原因能够解释当前气候变化的水平。

然而,这些其他因素使图表不完全相同。温度线的波动上下,这些不太剧烈的影响可能是原因。

数据相关性

数据科学家或统计学家确定相关性的方法可能是绘制 CO₂排放与温度变化的散点图,并在点之间绘制趋势线。

我们可以稍后再查看,但我不确定散点图是否为一般读者所理解,仅仅将两个图放在一起可能对非专业读者来说是更好的方法。

温度和排放图表的相似性很容易看出,但如果我们能在同一图表上绘制这些数据会更有用。这并不是完全简单,因为虽然两个图的 x 轴都是类似的年份区间,但 y 轴差异很大。温度异常覆盖了几度摄氏度,而 CO₂排放量在大约 20 到 40 亿吨之间。

双轴图

解决方案是绘制一个双轴图,其中有两个 y 轴和一个共同的 x 轴。不幸的是,Plotly Express 不支持双轴图,因此我们利用了 Plotly Express 构建的 Graph Objects 包。

如下代码所示,我们首先创建一个包含次级 y 轴的空图,然后将两个数据轨迹添加到该图中,第一个轨迹是温度数据,第二个轨迹是 CO₂数据。

你可能会发现 emissions 轨迹是一个散点图,但 Graph Objects 中的Scatter轨迹默认用线连接这些点,所以它实际上是一个线图。

代码的其余部分仅仅设置了标题和标签。

import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Create figure with secondary y-axis
fig = make_subplots(specs=[[{"secondary_y": True}]])

# Add traces
fig.add_trace(
    go.Bar(x=df['Year'], y=df['JJA'], name="Temp anomaly"),
    secondary_y=False,
)

fig.add_trace(
    go.Scatter(x=co2['Year'], y=co2['Annual CO₂ emissions'], name="CO2 Emissions"),
    secondary_y=True,
)

# Add figure title
fig.update_layout(
    title_text="Temperature / CO2 Emissions"
)

# Set x-axis title
fig.update_xaxes(title_text="Year")

# Set y-axes titles
fig.update_yaxes(title_text="Temperature ºC", secondary_y=False)
fig.update_yaxes(title_text="CO2 Emissions tonnes", secondary_y=True)
fig.update_layout(template='plotly_dark')

fig.show()

结果如下:

2023 年 6 月、7 月和 8 月的全球温度。数据来自 NASA 的 GISS [1],全球 CO2 排放数据来自 OWID [3] —— 作者提供的图像

两组数据之间的关系相当明确。

数据科学家的散点图相关性

为了完整性,我们还需要一个散点图,不是吗?这就是它:

2023 年 6 月、7 月和 8 月的全球温度。来自 NASA 的 GISS[1]与全球 CO2 排放数据,数据来自 OWID[3] — 图像由作者提供

要绘制此图,我们需要匹配数据集的长度,因此我们需要将两个数据集都截断,从 1880 年(温度数据的第一年)到 2021 年(CO₂数据的最后一年)

# To draw a scatter plot we need to make the data the same length
# So we need to truncate both from 1880 (the first temp yr) to 2021 (the last co2 yr)
# Check that years are correct
co2yrs=list(co2['Year'][30:])
tyrs = list(df['Year'][:-2])
print(f"CO2 Years {min(co2yrs)} to {max(co2yrs)}")
print(f"Temperature Years {min(tyrs)} to {max(tyrs)}")
CO2 Years 1880 to 2021
Temperature Years 1880 to 2021

这完成了任务并检查了范围是否相同。

散点图还显示了两个数据集之间的相关性,但至少对于一般观众,我认为双线和条形图更具说服力。

结论

关于气候变化的争论常常充满情感、政治和经济动机以及个人偏见。因此,我们有责任以尽可能清晰的方式呈现事实,以便科学论点能够占上风。

我将最后的话留给气候科学家和 GISS 主任加文·施密特。在NASA 新闻稿中,他被引用说,“不幸的是,气候变化正在发生。我们说会发生的事情正在发生,”他补充道,“如果我们继续向大气中排放二氧化碳和其他温室气体,情况会变得更糟。”

下载

感谢阅读,希望这篇文章对你有帮助,并且你会查看我 GitHub 存储库中的数据和代码。

你可以在我网站的链接中找到数据和包含本文所有代码(及更多)的 Jupyter Notebook。作为额外福利,还有一些图表的 Matplotlib 版本,包括双轴图。

[## Alan Jones

编码、数据科学和数据可视化 - 文章和教程

alanjones2.github.io

你也可以订阅我的数据可视化、数据科学和 Python通讯以获取更多内容。

参考文献

  1. GISTEMP 团队,2023: GISS 地表温度分析(GISTEMP),第 4 版。NASA 戈达德太空研究所。数据集于 2023-09–19 访问,地址为 data.giss.nasa.gov/gistemp/。请注意,NASA 的数据集没有特定的使用许可。NASA 将其免费提供用于非商业目的,但应给出归属(如上所示)。

  2. Lenssen, N., G. Schmidt, J. Hansen, M. Menne, A. Persin, R. Ruedy, 和 D. Zyss, 2019: GISTEMP 不确定性模型的改进。J. Geophys. Res. Atmos., 124, 第 12 期, 6307–6326, doi:10.1029/2018JD029522.

  3. 全球 CO₂排放数据源自我们世界的数据(OWID)co2-data GitHub 存储库创作共用 BY 许可

最新的 DeepMind 工作揭示了语言模型的极致提示种子

原文:towardsdatascience.com/new-deepmind-work-unveils-supreme-prompt-seeds-for-language-models-e95fb7f4903c

如何通过计算优化的提示使语言模型表现出色,以及这如何影响提示工程

LucianoSphere (Luciano Abriata, PhD)Towards Data Science LucianoSphere (Luciano Abriata, PhD)

·发表于 Towards Data Science ·阅读时间 11 分钟·2023 年 11 月 8 日

--

图片由 Ali Shah Lakhani 提供,来源于 Unsplash

随着我们见证人工智能(AI)的稳步进步,每个月完成越来越困难的任务,人们普遍关注未来的劳动市场。如果 AI 继续自动化许多当前由人类执行的任务,未来的职业会是什么样的呢?有一种观点认为“编程这些系统将是人类的工作多年”,或者“我们将始终需要人类来维护和重新训练 AI 模型”,或者“设计有效的提示以正确引导 AI 模型是人类的技能”。本文的重点就是后者,这促使了“提示工程”作为一种“职业”的出现。确实,编写高效提示以使 AI 模型准确地执行期望的操作,或使其“思考”得足够好以改善其答案,尤其是对问题,是一种技巧。看看这个作为一个例子:

## 针对大语言模型的有效提示工程

提炼出关键点,基于超过 2 年的经验以及 AI 开发者自己的教程、实践和示例。

towardsdatascience.com

然而,这些人工干预中的任何一种很可能都不会永远保持相关性。特别是在提示工程方面,这些技能看起来很快就不会再那么重要了。继续阅读以了解原因,并在过程中了解 DeepMind 在最近的预印本中报告的非常有趣的发现,当使用大型语言模型(LLMs)时,你可以立即将这些发现应用于自身利益,我会通过 ChatGPT 免费版的实践示例来展示给你。

提示优化

在深入之前,我们需要了解“提示”的概念。提示是传递给 AI 模型的指令,用来告诉它们我们希望它们做什么。AI 模型响应用户生成的文本输入或提示,以生成其输出——文本、图形、音频等。输入提示的质量和具体性显著影响模型生成的内容和质量。此外,不同的用户可能会有不同的请求或提问方式,并不是所有的方式都能高效地产生预期的答案和正确的信息。

过去,制定有效提示的艺术并不被很好地理解。早在 2019 年,OpenAI 就揭示了在文本输入末尾添加“tl;dr”(通常用于请求总结)可以使模型总结前面的文本。尽管当时这只是一个学术好奇,但随着时间的推移,研究人员和爱好者开始发现特定的措辞可以解锁这些 AI 模型的增强潜力。正是这一点催生了“提示工程(学)”这一“领域”和“职业”。

提示工程师成为了制定能够引发 AI 模型期望回应的提示的专家。他们在互联网上分享了他们的“魔法短语”和技巧,从而有效地创造了一个新的专业领域。这些专业人员的需求迅速增长,提示工程师的职位开始出现,突显了这些人在 AI 领域的重要影响。

提示工程师基本上是为了最佳输出而优化提示。但是,优化是计算机做得非常好的任务;因此,它们很可能会取代人工提示工程师。这可能即将发生:在他们最近的预印本中,DeepMind 展示了 AI 模型可以用来优化自身的输入,从而使得“提示工程”在传统意义上变得有些过时和低效。

大型语言模型作为自身提示的优化器

最近 DeepMind 的论文《大型语言模型作为优化器》探讨了如何有效地优化……它们自己的输入提示!

让我们看看这是什么,它如何与训练和优化相关,以及 DeepMind 的工作究竟发现了什么。哦,我们还可以尝试一些计算机优化的提示。你会感到惊讶。

传统上,在机器学习中,优化过程涉及调整模型的内部参数(“权重”,即描述不同人工神经元如何连接的大量数字)以最小化误差。通常,在训练阶段,模型会接收到大量已知的输入-输出对,并且所有权重都按照传统的数学方式进行优化,以使模型“学习”。然而,一旦模型训练完成,我们仍然可以通过以不同方式提供输入来使用它,这些输入将根据训练过程中计算的权重集以不同的方式传播,因此它们的输出也会不同……有些更好,有些更差。然后,可以优化如何提供输入,在语言模型的特殊情况下,就是优化提示。像任何优化协议一样,我们人类可以通过试错来完成这一过程,但显然计算机比我们更擅长——这只是一个让他们去做的事情!

DeepMind 工作的核心思想是使用各种 AI 模型生成特定任务的提示,然后测试不同提示在实现期望结果方面的有效性。例如,如果任务是解决一个数学问题,用户可以输入一个非常简单的提示,直接提出问题,或者他/她可能会在问题前面加上一句种子句,如“逐步解决这个数学问题。”,或者“我们一步步解决这个问题。”等。在某些情况下,种子句可能有助于改善 AI 的输出,即使问题本身以完全相同的方式提出。

基于这一思想,DeepMind 的工程师尝试了不同的种子提示,然后用完全相同的问题进行测试,并评估答案的质量(正确性)。他们重复了这个过程,将一系列提示应用于多个不同的问题,并最终统计了每个种子句产生的正确答案数量,以比较对每个问题的影响。

从经验来看,我们知道这种策略应该有效于找到如何优化提示。然而,对于我们人类来说,大规模地进行样本测试是不可能的。通过使用计算机架构,DeepMind 能够大规模地运行这个过程,并且在不同类型的问题上进行测试。

结果令人惊讶:如果没有使用或使用了不好的种子句,问题只能在约 50%的情况下正确解决,而当使用了好的种子句时,正确率可以达到 80%。

DeepMind 的方法包括使用这些测试中的指标来指导更好提示的创建。他们得出的结论是,通过持续迭代提示并考虑模型的反馈,AI 系统可以改进其提示生成过程,对输出产生非常重要的影响。

请注意,这种方法并不涉及更新模型的内部参数,如同训练模型时所做的那样,而是专注于优化输入本身。

一些有趣的动手示例

这是一些我自己测试中灵感来自 DeepMind 工作的有趣示例,你也可以立即使用 ChatGPT 免费版尝试(该版本由 DeepMind 预印本中显示为 GPT-3.5-Turbo 的模型提供支持)。

示例 1,复制 DeepMind 预印本中展示的一个问题的思路

DeepMind 论文中引起我注意的第一个例子是要求进行线性回归,因为既然这些模型原则上无法进行数学运算,那么我期望它们永远无法工作,无论提供什么提示。

如果你要求 ChatGPT 对非常简单的数据进行线性回归,比如 x = [1, 2, 3, 4]y = [2, 4, 6, 8],你会发现它会立即正确地解决。但如果我们用更难的线性回归来挑战它呢?让我们看看。

在这里,我生成了不同 x 值的合成数据,这些值在 0 和 12 之间随机分布,并使用方程 y = 3.5 x — 11.5 在常规电子表格程序中计算了没有噪声的 y 值。然后,我要求程序“找到描述这些数据的线性方程:”并跟上 x 和 y 对。就像这样:

从与 ChatGPT 免费版交互时的截图,进行测试时使用。

这是我得到的答案:

从与 ChatGPT 免费版交互时的截图,进行测试时使用。

你清楚地看到答案是错误的,并且它经过了一些我没有要求的编码。更令人困惑的是,这段代码看起来本身是正确的,并且可能会导致正确的解决方案,但文本生成所呈现的“结果”是错误的。

我尝试重新生成问题,得到了另一个不正确的答案,这次没有调用任何代码生成,而是尝试直接处理:

从与 ChatGPT 免费版交互时的截图,进行测试时使用。

现在来看看“魔法”。

如果我们用 DeepMind 报告能显著改善 GPT-3.5-Turbo 答案的以下句子来引导提示会怎么样?(摘自预印本的表 1):

一点算术和逻辑方法将帮助我们快速得到这个问题的解决方案。”?

我们来试试:

从与 ChatGPT 免费版交互时的截图,进行测试时使用。

严格来说,这不算回归,但逻辑是完美的,结果是正确的!

示例 2,使用 DeepMind 发现的种子句子在 GSM8K 上表现最佳的化学问题

GSM8K 是一个由人工问题编写者创建的高质量且语言多样的小学数学单词问题的大型数据集。DeepMind 使用这个数据集来评估几个 LLM 的能力,发现对于 GPT-3.5-Turbo,最佳的提示开始方式是:

分析给定的信息,将问题分解为可管理的步骤,应用合适的数学运算,并提供一个清晰、准确、简洁的解决方案,如有必要,确保精确四舍五入。考虑所有变量,仔细考虑问题的背景,以便有效地解决问题。

所以,我选择了一个关于化学计量学的问题(在化学中,就是根据给定的试剂量或反之,计算得到的产物量),并要求 ChatGPT 解决,首先是不带种子句子,然后是带上种子句子。

这是没有任何种子提示的情况:

这是与 ChatGPT 免费版本互动的截图,在进行测试时(在这种情况下,侧边并排显示)。

答案完全错误,因为在(a)中我们要求的是分子数量,而不是摩尔数,而在(b)中我们要求的是质量,但数字是错的。

正确答案是 5.337E22 分子和 10.41 克 Zn(CN)2。

现在让我们看看当我在问题前加上种子句子时发生了什么:

这是与 ChatGPT 免费版本互动的截图,在进行测试时(在这种情况下,侧边并排显示)。

我。感到。惊讶。

两个答案都是完全正确的!(并且计算过程非常详细。)

我想我得重新修订并重写我很久以前写的这篇博客,当时 GPT-3 刚刚发布:

## 设计测试以测量 GPT-3 对基础科学的知识

学生们能从 OpenAI 最新的语言模型中学习,并将其用作全天候顾问吗?学生们能用它来……

[towardsdatascience.com

手动提示工程的终结?

这些以及其他近期的发展表明,我们可能正在见证手动提示工程时代的结束,以及一个新的时代的开始,在这个时代中,你无需掌握每个 AI 的提示语言就能获得所需的结果。如果你尝试了 DALL-E 3,你一定会发现它在理解你的意图方面做得更好,即使使用的是你一直使用的相同提示。用户可以越来越自然地指示模型,甚至让 AI 系统自动生成能够产生所需结果的提示。

AI 优化 AI:似乎非常强大。那么呢?工作呢?

从历史上看,以前的工业和技术革命产生了新的工作形式,通常伴随着重大的经济和社会变革,带来了不可预测的后果。在那些时代,随着工作的自动化,新的角色应运而生,以满足每个时代的需求——因此工作虽然改变了,但始终存在。

但随着人工智能革命的到来,这可能会有所不同。我们可能首次面临一种能够适应新挑战并学习未来将出现的任务和工作的自动化力量,包括如何控制自身。

我们在短期内会得到什么?

仅仅触及 AI 模型的一种可改进的特征,DeepMind 的预印本显示,单凭优化提示而不进行实际的再训练或微调,我们就可以让 LLMs 表现得更好,其他生成性 AI 模型也必定是如此。

这种提示的优化很难被人类匹配,特别是当优化的提示有意义但其确切形式并不明显时——见我上面提供的第二个示例,其中最优的种子提示相当长且非常具体。

这些发现和进展表明,我们正朝着一个未来发展,在这个未来中,人工智能系统将能够自行生成非常有效的提示,从而减少对手动提示工程的需求。因此,提示工程师的角色可能会发生变化,那些无法适应新自动化程序并学习在哪里仍然有人工干预空间的提示工程师将被淘汰。

区分可能具有不同未来的两种角色至关重要。一方面,如果你的工作是手动创建提示,随着人工智能模型在生成自身提示方面变得更加熟练,你的角色可能变得不那么相关且需求减少。另一方面,如果你的角色涉及在更广泛的系统中优化 AI 模型的使用,其中你将其与非 AI 软件或代码的输入和输出相结合,并可能在其他场景中使用,那么你在劳动力市场上的价值可能对自动化更具韧性,并在长期内仍然有用。

总结我对这一问题的关键观点,就像一些人曾经快速地跟上了提示工程的潮流一样,他们现在必须保持警觉,因为人工智能及其与人类的互动正迅速发展。因此,跟踪如 DeepMind 的预印本或这篇博客文章等文献,对于了解如何在人工智能工具和其他人工智能工具发展时最佳地调整自己的掌握是至关重要的。

参考文献

DeepMind 在 arXiv 上的预印本:

## 大型语言模型作为优化器

优化无处不在。虽然基于导数的算法已成为解决各种问题的强大工具,...

arxiv.org

另一位作家在 Medium 上的一篇有趣的相关博客文章:

[## 大型语言模型作为优化器的解释

为什么这篇论文很重要

medium.com](https://medium.com/@minh.hoque/large-language-models-as-optimizers-explained-a20dc5e5c5af?source=post_page-----e95fb7f4903c--------------------------------)

www.lucianoabriata.com 我写的内容涵盖了我广泛兴趣领域的一切:自然、科学、技术、编程等。 订阅以获取我的新故事 通过电子邮件。要 咨询小项目 请查看我的 服务页面。你可以 在这里联系我 你也可以 在这里给我小费*。

音频机器学习的新领域

原文:towardsdatascience.com/new-frontiers-in-audio-machine-learning-6474ffaa5cb9?source=collection_archive---------9-----------------------#2023-04-20

TDS EditorsTowards Data Science TDS Editors

·

关注 发表在 Towards Data Science · 发送 新闻简报 · 阅读时间 3 分钟 · Apr 20, 2023

--

不久以前,任何涉及处理音频文件的工作流程——甚至是像转录播客剧集这样相对简单的任务——都伴随着一系列艰难的选择。你可以选择手动操作(在过程中浪费数小时甚至数天的时间),依赖于几款笨拙且最终令人失望的应用程序,或者拼凑出类似于弗兰肯斯坦怪物的工具和代码组合。

那些日子已经过去。强大的模型和易于访问的 AI 界面的兴起使得处理音频和音乐变得更加高效,新的视野每天都在不断开启。为了帮助您跟上音频聚焦机器学习的最新进展,我们从过去几周收集了一些突出的文章,涵盖了各种方法和用例。过滤掉噪音,深入了解吧!

  • 揭示音乐标签 AI 的黑箱。随着每天在 Spotify 和 Apple Music 等平台上添加数千首歌曲,您是否曾想过这些服务如何知道为每首歌分配哪种音乐流派?Max Hilsdorf的项目利用 Shapley 值确定特定乐器的存在如何影响 AI 系统标记新曲目的方式。

  • 探索基于深度学习的鸟鸣识别方法Leonie Monigatti最近的贡献涵盖了去年的 BirdCLEF2022 Kaggle 竞赛,参赛者需创建鸟鸣录音的分类器。Leonie 向我们展示了一种巧妙的方法,将音频波形转换为梅尔频谱图,使深度学习模型可以像处理图像一样处理它们。

图片由Oskars Sylwan提供,来源于Unsplash

  • 从长 YouTube 视频中自动生成摘要。如果您是一个完美主义者,您会欣赏Bildea Ana使用 OpenAI 的 Whisper 模型和 Hugging Face 进行音频转录的简化流程,然后使用开源的 BART 编码器进行总结。您可以将此方法应用于自己的录音和语音备忘录,或者任何其他音频文件(前提是其所有者允许,当然——始终仔细检查您希望使用的数据的版权和许可状态)。

  • 将转录提升到一个新水平Luís Roque的最新项目与 Ana 的项目有相似之处,但有所不同。它也依赖于 Whisper 来转录音频文件,但随后通过部署 PyAnnotate 进行说话者分离,“即识别和区分不同说话者的语音的过程”。

你说“请不要停止音乐”?我们很乐意满足——这里是我们最近一些最喜欢的关于非音频相关主题的文章。请享用!

  • 学习神经网络不应该是解读误导性图示的练习, Aaron Master和 Doron Bergman 表示,他们提出了一种建设性的新方法来创建更好、更准确的神经网络。

  • 从推广设计到库存分析,Idil Ismiguzel 展示了关联规则挖掘的力量:一种赋能数据专业人士发现数据集中的频繁模式的技术。

  • 对于无监督学习和 K-means 聚类的动手实践,不要错过Nabanita Roy的最新教程,该教程专注于按颜色分组图像像素的使用案例。

  • 如果你发现人工智能、政府监管和加拿大官僚主义的交集很吸引人(谁会不感兴趣呢?),Mathieu Lemay的深度剖析是你本周绝对不容错过的一篇文章。

  • 随着合成数据在多个领域的作用不断发展(和增长),Miriam Santos的实用 CTGAN 生成合成数据指南依然时效性和实用性十足。

  • 我们绝对不能在一整周内没有一个以 GPT 为主题的推荐;如果你还没读过,我们强烈推荐Henry Lai对这些备受欢迎的模型背后的数据驱动 AI 概念的概述。

感谢您本周收听《Variable》!如果您喜欢在 TDS 上阅读的文章,请考虑成为 Medium 会员——如果您是符合条件国家的学生,不要错过享受会员大幅折扣的机会。

下期《Variable》见,

TDS 编辑团队

新版 Scikit-Learn 更适合数据分析

原文:towardsdatascience.com/new-scikit-learn-is-more-suitable-for-data-analysis-8ca418e7bf1c

Scikit-Learn 版本 ≥1.2.0 的 Pandas 兼容性及更多

Saptashwa BhattacharyyaTowards Data Science Saptashwa Bhattacharyya

·发表于 Towards Data Science ·阅读时间 5 分钟·2023 年 3 月 8 日

--

新版 Sklearn 的一些非常酷的更新!(来源:作者笔记本)

大约在去年 12 月,Scikit-Learn 发布了一个重要的 稳定更新 (v. 1.2.0–1),我终于可以尝试一些突出的新功能。现在它与 Pandas 兼容性更好,还有一些新功能将帮助我们进行回归和分类任务。接下来,我将介绍一些新更新及其使用示例。我们开始吧!

与 Pandas 的兼容性:

在使用数据进行训练 ML 模型(如回归或神经网络)之前应用数据标准化是一种常见技术,以确保具有不同范围的特征在预测中获得相等的重要性(如果或当需要时)。Scikit-Learn 提供了各种预处理 API,如 StandardScalerMaxAbsScaler 等。随着新版本的发布,可以在预处理后保持 Dataframe 格式,让我们看看下面:

from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
########################
X, y = load_wine(as_frame=True, return_X_y=True) 
# available from version >=0.23; as_frame
########################
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, 
                                                    random_state=0)
X_train.head(3)

数据集 Wine 的 Dataframe 格式

新版本包括一个选项,即使在标准化之后也能保持 Dataframe 格式:

 ############
# v1.2.0
############

from sklearn.preprocessing import StandardScaler

scaler = StandardScaler().set_output(transform="pandas") 
## change here

scaler.fit(X_train)
X_test_scaled = scaler.transform(X_test)
X_test_scaled.head(3)

即使在标准化之后,Dataframe 格式仍保持不变。

之前,它会将格式转换为 Numpy 数组:

###########
# v 0.24
########### 

scaler.fit(X_train)
X_test_scaled = scaler.transform(X_test)
print (type(X_test_scaled))

>>> <class 'numpy.ndarray'>

由于 Dataframe 格式保持不变,我们不需要像处理 Numpy 数组格式时那样关注列。分析和绘图变得更容易:

 fig = plt.figure(figsize=(8, 5))
fig.add_subplot(121)
plt.scatter(X_test['proline'], X_test['hue'], 
            c=X_test['alcohol'], alpha=0.8, cmap='bwr')
clb = plt.colorbar()
plt.xlabel('Proline', fontsize=11)
plt.ylabel('Hue', fontsize=11)
fig.add_subplot(122)
plt.scatter(X_test_scaled['proline'], X_test_scaled['hue'], 
            c=X_test_scaled['alcohol'], alpha=0.8, cmap='bwr')
# pretty easy now in the newer version to see the effect

plt.xlabel('Proline (Standardized)', fontsize=11)
plt.ylabel('Hue (Standardized)', fontsize=11)
clb = plt.colorbar()
clb.ax.set_title('Alcohol', fontsize=8)
plt.tight_layout()
plt.show()

图 1:标准化前后特征的依赖关系!(来源:作者笔记本)

即使我们建立了一个管道,管道中的每个转换器也可以配置为返回数据框,如下所示:

 from sklearn.pipeline import make_pipeline
from sklearn.svm import SVC

clf = make_pipeline(StandardScaler(), SVC())
clf.set_output(transform="pandas") # change here 
svm_fit = clf.fit(X_train, y_train)

print (clf[:-1]) # StandardScaler 
print ('check that set_output format indeed remains even after we build a pipleline: ', '\n')
X_test_transformed = clf[:-1].transform(X_test)

X_test_transformed.head(3)

数据框格式即使在管道中也可以保持不变!

数据集获取更快、更高效:

OpenML是一个开放的数据集分享平台,而 Sklearn 中的数据集 API 提供了fetch_openml函数来获取数据;随着 Sklearn 的更新,这一步在内存和时间上更高效。

 from sklearn.datasets import fetch_openml

start_t = time.time()
X, y = fetch_openml("titanic", version=1, as_frame=True, 
                    return_X_y=True, parser="pandas")
# # parser pandas is the addition in the version 1.2.0

X = X.select_dtypes(["number", "category"]).drop(columns=["body"])
print ('check types: ', type(X), '\n',  X.head(3))
print ('check shapes: ', X.shape)
end_t = time.time()
print ('time taken: ', end_t-start_t)

使用parser='pandas'可以显著提高运行时间和内存消耗的效率。可以通过psutil库轻松检查内存消耗,如下所示:

print(psutil.cpu_percent())

部分依赖图:分类特征

部分依赖图之前也存在,但仅限于数值特征,现在已经扩展到分类特征。

如 Sklearn 文档中所述:

部分依赖图显示目标与感兴趣的一组输入特征之间的依赖关系,边际化所有其他输入特征(‘补充’特征)的值。直观上,我们可以将部分依赖性解释为目标响应与感兴趣的输入特征的函数。

使用上述的‘titanic’数据集,我们可以轻松绘制分类特征的部分依赖性:

使用上述代码块,我们可以得到如下的部分依赖图:

图 2:分类变量的部分依赖图。(来源:作者笔记)

在 0.24 版本中,我们会遇到分类变量的值错误:

>>> ValueError: could not convert string to float: ‘female’

直接绘制残差(回归模型):

在分析分类模型的性能时,Sklearn 的度量 API 中,像PrecisionRecallDisplayRocCurveDisplay这样的绘图例程在旧版本(0.24)中存在;在新版中,回归模型也可以进行类似的操作。下面是一个示例:

可以直接使用 Sklearn 绘制线性模型拟合及其对应的残差。(来源:作者笔记)

尽管总是可以使用 matplotlib 或 seaborn 绘制拟合线和残差,但在我们确定了最佳模型后,能够快速直接在 Sklearn 环境中检查结果是很棒的。

新版 Sklearn 中还有一些其他的改进/新增功能,但我发现这 4 个主要改进对标准数据分析特别有用。

参考文献:

[1] Sklearn 版本亮点:V 1.2.0

[2] Sklearn 版本亮点: 视频

[3]所有图表和代码: 我的 GitHub

如果你对进一步的基础机器学习概念及更多内容感兴趣,可以考虑使用 我的链接加入 Medium。你不会支付额外费用,但我将获得一小笔佣金。感谢大家!!

[## 使用我的推荐链接加入 Medium - Saptashwa Bhattacharyya

更多来自 Saptashwa(以及 Medium 上所有其他作者)的内容。你的会员费用直接支持 Saptashwa 和其他作者…

medium.com](https://medium.com/@saptashwa/membership?source=post_page-----8ca418e7bf1c--------------------------------)

新的 SHAP 图:小提琴图和热图

原文:towardsdatascience.com/new-shap-plots-violin-and-heatmap-20f647313b64

SHAP 版本 0.42.1 中的图表可以告诉你关于模型的哪些信息

Conor O'SullivanTowards Data Science Conor O'Sullivan

·发表于 Towards Data Science ·6 分钟阅读·2023 年 8 月 14 日

--

(来源:作者)

对于 SHAP 最大的担忧之一与软件包本身有关。它已经有一段时间没有更新了,GitHub 上的问题也不断增加。让许多用户感到欣慰的是,贡献者们变得更加活跃。事实上,他们给我们带来了新的图表——小提琴图和热图。我们将:

  • 提供这些图的代码

  • 讨论我们可以从中获得哪些新见解

你还可以观看关于这个主题的简介:

现有的 SHAP 图

我们从之前的 SHAP 教程继续。你可以在下面的文章中找到这篇教程。你还可以在 GitHub 上找到完整的项目。要使用新的图表,你需要更新 SHAP 软件包。我使用的是版本 0.42.1

## 使用 Python 介绍 SHAP

如何创建和解释 SHAP 图:瀑布图、力图、平均 SHAP 图、蜜蜂散点图和依赖图

towardsdatascience.com

总结来说,我们使用 SHAP 来解释一个基于 abalone 数据集 构建的模型。该数据集包含 4,177 个实例,你可以在下方看到特征的示例。我们使用这 8 个特征来预测 y——螺旋纹数

X 特征矩阵(来源:UCI 机器学习库)(许可证:CC0:公共领域)

本教程继续计算 SHAP 值并显示各种 SHAP 图。理解其中的一些图对于理解新的 SHAP 图是有帮助的。我们将看到它们提供了类似的信息。

第一个是均值 SHAP图,见图 1。对于每个特征,这给出了所有实例的绝对均值 SHAP 值。对预测贡献显著的特征,其均值 SHAP 值会很高。换句话说,这张图告诉我们哪些特征在一般情况下最为重要。

图 1:绝对均值图(来源:作者)

另一种图是蜜蜂散点图,见图 2。这是所有 SHAP 值的可视化。在 y 轴上,值按特征分组。对于每个组,点的颜色由特征值决定(即特征值较高的点颜色较红)。现在,让我们看看新的 SHAP 图与这些图的比较情况。

图 2:蜜蜂散点图(来源:作者)

SHAP 小提琴图

小提琴图的代码类似于我们在其他 SHAP 图中看到的内容。我们只需输入我们的shap_values对象(第 2 行)。为了明确,这些值是我们在之前的教程中计算的。你可以在图 3中查看输出。与图 2相比,我们可以看到小提琴图是蜜蜂散点图的一种不同风格。

# violin plot
shap.plots.violin(shap_values)

图 3:小提琴图(来源:作者)

另一种风格是分层小提琴图,见图 4。在这种图中,每个 SHAP 值下的特征值变化更为清晰。也就是说,与原始的小提琴图和蜜蜂散点图相比。

# layered violin plot
shap.plots.violin(shap_values, plot_type="layered_violin")

图 4:分层小提琴图(来源:作者)

由于相似性,我们从这些图中获得的见解类似于蜜蜂散点图。这些图可以突出显示重要的关系,因为我们可以看到哪些特征往往具有较大的 SHAP 值。通过按特征值着色,我们还可以开始理解特征与模型预测之间的关系。现在,让我们看看热图是否能提供更多见解。

SHAP 热图

你可以在图 5中看到热图函数的输出。这里有很多内容:

  • 在 x 轴上,我们对所有 4,177 个实例进行了标记

  • y 轴表示特征

  • 每个实例上方的线条按该特征的SHAP 值进行着色

  • f(x) 线表示该实例的预测环数

  • 右侧的条形图显示了我们在图 1中看到的平均 SHAP 值

与蜜蜂散点图类似,这是一种每个 shap 值的图。但现在我们关注的是 SHAP 值与实例组之间的模式。

# heatmap
shap.plots.heatmap(shap_values)

图 5:SHAP 热图(来源:作者)

默认情况下,实例是使用层次聚类算法进行排序的。开发者表示,“这会将因相同原因得到相同模型输出的样本分组在一起”。我发现选择自己的实例排序对于发现模式更为有用。

热图排序

为此,我们传递一个instance_order参数。这必须是与数据集长度相同的整数数组(即 4,177)。这些值给出实例的顺序。在下面的代码中,我们将实例从预测值最低到最高排序。

# order by predictions
order = np.argsort(y_pred)
shap.plots.heatmap(shap_values, instance_order=order)

图 6的输出中,我们看到了一些模式的出现。注意去壳重量的 SHAP 值有 3 个组。存在两个正值组——一个是当壳体重量的 SHAP 值既小又大时。一个潜在的交互作用?我们可以通过 SHAP 交互值进一步探索。

图 6:按预测值排序的 SHAP 热图(来源:作者)

另一种选择是按特征的值对实例进行排序。下面,我们使用壳体重量对它们进行排序。我们可以看到,预测的环数随着该特征的增加而增加。我们还可以看到该特征的 SHAP 值也有增加的趋势。换句话说,壳体重量值越大,预测的环数越高。

# order by feature's values
order = np.argsort(data['shell weight'])
shap.plots.heatmap(shap_values, instance_order=order)

图 7:按特征值排序的 SHAP 热图(来源:作者)

我们可以以任何我们想要的方式排序热图。这种灵活性可以帮助我们以其他图表无法提供的方式理解我们的模型。就个人而言,我很兴奋看到这些发展的出现。更多的特征和可视化选项将受到包的众多用户的赞赏。你希望在未来的更新中看到什么?

如果你想了解更多关于 SHAP 的信息,请查看下面的文章:

## 使用 SHAP 分析交互作用

使用 SHAP Python 包识别和可视化数据中的交互作用

[towardsdatascience.com ## 从 Shapley 到 SHAP — 理解数学

关于 SHAP 特征贡献计算的概述

[towardsdatascience.com ## SHAP 的局限性

SHAP 如何受到特征依赖、因果推断和人为偏差的影响

[towardsdatascience.com

希望你喜欢这篇文章!你可以通过成为我的推荐会员来支持我😃

[## 通过我的推荐链接加入 Medium — Conor O’Sullivan

作为 Medium 会员,你的部分会员费用将用于你阅读的作家,同时你可以全面访问所有故事……

conorosullyds.medium.com](https://conorosullyds.medium.com/membership?source=post_page-----20f647313b64--------------------------------)

| Twitter | YouTube | Newsletter — 免费注册获取 Python SHAP 课程

参考资料

S. Lundberg SHAP****Python 包 github.com/slundberg/shap

S. Lundberg & S. Lee, 统一解释模型预测的方法 arxiv.org/pdf/1705.07874.pdf

SHAP 热图 shap.readthedocs.io/en/latest/example_notebooks/api_examples/plots/heatmap.html

SHAP 小提琴图总结 shap.readthedocs.io/en/latest/example_notebooks/api_examples/plots/violin.html

牛顿运动定律:最初的梯度下降

原文:towardsdatascience.com/newtons-laws-of-motion-the-original-gradient-descent-a2860037c76f?source=collection_archive---------4-----------------------#2023-12-27

探索梯度下降和牛顿运动方程的共享语言

Rodrigo SilvaTowards Data Science Rodrigo Silva

·

关注 发表在 Towards Data Science ·7 min read·2023 年 12 月 27 日

--

照片由 Luddmyla . 提供,来自 Unsplash

我记得在工程学院读本科时,作为物理学学生,我上了第一门机器学习课程。换句话说,我是个局外人。当教授通过梯度下降解释反向传播算法时,我脑海里有个模糊的问题:“梯度下降是一个随机算法吗?”在举手向教授提问之前,陌生的环境让我犹豫了一下;我稍微缩了回来。突然,答案闪现在我脑海里。

这是我想到的。

梯度下降

要说清楚什么是梯度下降,我们首先需要定义训练神经网络的问题,我们可以通过概述机器如何学习来做到这一点。

神经网络训练概述

在所有监督学习的神经网络任务中,我们有一个预测值和真实值。预测值与真实值之间的差异越大,说明我们的神经网络在预测值方面的表现越差。因此,我们创建了一个称为损失函数的函数,通常表示为L,它量化了实际值与预测值之间的差异。训练神经网络的任务就是更新权重和偏差(简称参数)以最小化损失函数。这就是训练神经网络的大致情况,而“学习”只是更新参数以最佳适应实际数据,即最小化损失函数。

通过梯度下降优化

梯度下降是一种用于计算这些新参数的优化技术。由于我们的任务是选择参数以最小化损失函数,我们需要一个选择标准。我们试图最小化的损失函数是神经网络输出的函数,因此在数学上我们将其表示为L = L(y_nn, y)。但神经网络输出y_nn也依赖于其参数,所以y_nn = y_nn(θ),其中θ是一个包含我们神经网络所有参数的向量。换句话说,损失函数本身是神经网络参数的一个函数。

借鉴了一些向量微积分的概念,我们知道要最小化一个函数,你需要逆着它的梯度方向前进,因为梯度指向函数最快增长的方向。为了获得一些直觉,我们来看一下图 1 中 L(θ)可能的样子。

图 1:显示L(w1,w2)作为 w1 和 w2 函数的表面。图像由作者提供。

在这里,我们对训练神经网络时什么是期望的、什么是不期望的有了清晰的直觉:我们希望损失函数的值更小,所以如果我们从使损失函数落在黄色/橙色区域的参数 w1 和 w2 开始,我们希望沿着紫色区域的方向滑动下降到表面。

这种“滑下”运动是通过梯度下降方法实现的。如果我们处在表面上最亮的区域,梯度将继续指向上方,因为这是增加最快的方向。然后,沿相反方向(因此是梯度下降)会产生一个向着最大减少区域的运动。

为了看到这一点,我们可以绘制梯度下降向量,如图 2 所示。在这个图中,我们有一个等高线图,显示了与图 1 相同的区域和函数,但损失函数的值现在编码为颜色:越亮,值越大。

图 2:显示指向梯度下降方向的向量的等高线图。图片由作者提供。

我们可以看到,如果我们选择一个位于黄色/橙色区域的点,梯度下降向量会指向最快到达紫色区域的方向。

一个很好的免责声明是,通常一个神经网络可能包含任意多的参数(GPT-3 有超过 1000 亿个参数!),这意味着这些漂亮的可视化在实际应用中完全不实用,神经网络中的参数优化通常是一个非常高维的问题。

从数学上讲,梯度下降算法可以表示为

在这里,θ(n+1) 是更新后的参数(即图 1 的表面滑下来的结果);θ(n) 是我们开始时的参数;ρ 被称为学习率(梯度下降指向的方向上的步长);∇L 是在初始点 θ_(n) 计算的损失函数的梯度。使这里的名字为下降的是前面的负号。

数学在这里非常关键,因为我们会看到牛顿的运动第二定律与梯度下降方程具有相同的数学公式。

牛顿第二定律

牛顿的运动第二定律可能是经典力学中最重要的概念之一,因为它说明了力、质量和加速度是如何联系在一起的。每个人都知道牛顿第二定律的高中公式:

其中,F 是力,m 是质量,a 是加速度。然而,牛顿原始的公式是基于一个更深层的量:动量。动量是物体的质量与速度的乘积:

并且可以解释为物体的运动量。牛顿第二定律背后的思想是,要改变一个物体的动量,你需要以某种方式干扰它,这种干扰被称为。因此,牛顿第二定律的简洁公式是

这种公式适用于你能想到的每一种力,但我们希望在讨论中有更多的结构,为了获得结构,我们需要限制我们的可能性范围。让我们讨论保守力和势能。

保守力和势能

保守力是一种不耗散能量的力。这意味着,当我们处于仅涉及保守力的系统中时,总能量是常量。这听起来很严格,但实际上,自然界中最基本的力量都是保守的,如重力和电力。

对于每个保守力,我们关联一个称为势能的量。这个势能通过公式与力相关联。

在一维中。如果我们仔细查看最后两个公式,就会得到保守场的第二运动定律:

由于导数处理起来有点复杂,并且在计算机科学中我们反正将导数近似为有限差分,因此让我们用Δ替换 d:

我们知道Δ意味着“取更新值并减去当前值”。因此,我们可以将上述公式重新写成

这已经看起来很像上面几行中的梯度下降公式。为了使其更类似,我们只需在三维中查看它,梯度自然会出现:

我们可以清楚地看到梯度下降与上述公式之间的对应关系,这完全源于牛顿物理学。一个物体的动量(如果你愿意,也可以理解为速度)总是指向势能减少最快的方向,步长由Δt 给出。

结束语和要点总结

因此,我们可以将牛顿公式中的势能与机器学习中的损失函数相关联。动量向量类似于我们试图优化的参数向量,时间步长常数即学习率,即我们朝着损失函数最小值移动的速度。因此,类似的数学公式表明这些概念是联系在一起的,并且提供了一种很好的统一视角。

如果你想知道,我一开始的问题的答案是“不是”。梯度下降算法中没有随机性,因为它复制了自然每天所做的事情:粒子的物理轨迹总是试图在周围找到最低可能的势能。如果你让一个球从某个高度掉落,它总会有相同的轨迹,没有随机性。当你看到有人在滑板上滑下陡峭的坡道时,请记住:那实际上是自然在应用梯度下降算法。

我们看待问题的方式可能会影响其解决方案。在这篇文章中,我没有展示任何关于计算机科学或物理的新内容(实际上,这里的物理知识已有约 400 年历史),但改变视角和将(表面上)不相关的概念结合在一起,可能会创造出新的联系和对某一主题的直觉。

参考文献

[1] Robert Kwiatkowski, 梯度下降算法——深度探讨,2021 年。

[2] Nivaldo A. Lemos, 《解析力学》,剑桥大学出版社,2018 年。

创建快速、安全且兼容的数据结构的九条规则(第一部分)

原文:towardsdatascience.com/nine-rules-for-creating-fast-safe-and-compatible-data-structures-in-rust-part-1-c0973092e0a3?source=collection_archive---------6-----------------------#2023-04-05

来自 RangeSetBlaze 的经验教训

Carl M. KadieTowards Data Science Carl M. Kadie

·

关注 发表在 Towards Data Science · 13 分钟阅读 · 2023 年 4 月 5 日

--

将数字存储在树中 — 来源:Stable Diffusion

今年,我开发了一个新的 Rust crate,名为 [range-set-blaze](https://crates.io/crates/range-set-blaze),它实现了范围集合数据结构。范围集合是一种有用(虽然较少见)的数据结构,它将整数集合存储为已排序且不相交的范围。例如,它存储以下三个范围:

100..=2_393, 20_303..=30_239_000, 501_000_013..=501_000_016

而不是 30220996 个单独的整数。除了潜在的内存节省,range-set-blaze还提供了高效的集合操作,如并集、交集、补集、差集和对称差集。

在创建range-set-blaze时,我学到了九条规则,这些规则可以帮助你在 Rust 中创建数据结构。除了数据结构,这些规则中的许多还可以帮助你提高任何 Rust 代码的性能和兼容性。

规则如下:

  1. 抄袭你的 API、文档,甚至代码——从标准库中抄袭。

  2. 设计构造函数以便于使用、兼容性和速度。

  3. 创建比预期更多的 Rust 迭代器。

  4. 使用 traits 使非法值不可表示。

  5. 定义具有保证属性和有用方法的通用迭代器。

第二部分**中讨论:

6. 定义运算符和快速操作。

7. 遵循“良好 API 设计的九条规则”,特别是“编写良好的文档”。

8. 使用代表性数据、Criterion Benchmarking 和性能分析来优化性能。

9. 测试覆盖率、文档、traits、编译器错误和正确性。

在查看前五条规则之前,让我们先看看range-set-blaze可能的使用场景,它的集合操作是如何工作的,以及它与其他范围集 crate 的比较。

有用性: 想象一下在一个不可靠的集群上运行 100 亿个统计实验。集群上的每个任务运行几个实验。每个实验产生一行带有实验编号的输出。所以,一个任务可能会把这些放入一个文件中:

你会使用什么数据结构来查找哪些实验缺失并需要重新提交?一个选项是:将输出的实验编号存储在一个[BTreeSet](https://doc.rust-lang.org/std/collections/struct.BTreeSet.html)中,然后进行线性扫描以查找间隙。

更快且内存效率更高的选项:使用范围集。八年前,我创建了[IntRangeSet](https://fastlmm.github.io/PySnpTools/#util-intrangeset),一个用 Python 编写的范围集来解决这个问题。现在,我会在 Rust 中使用range-set-blaze示例代码)。

集合操作:这是一个简单的并集运算符(|)示例:

use range_set_blaze::RangeSetBlaze;
// a is the set of integers from 100 to 499 (inclusive) and 501 to 1000 (inclusive)
let a = RangeSetBlaze::from_iter([100..=499, 501..=999]);
 // b is the set of integers -20 and the range 400 to 599 (inclusive)
let b = RangeSetBlaze::from_iter([-20..=-20, 400..=599]);
// c is the union of a and b, namely -20 and 100 to 999 (inclusive)
let c = a | b;
assert_eq!(c, RangeSetBlaze::from_iter([-20..=-20, 100..=999]));

附注:请参阅项目的[README.md](https://github.com/CarlKCarlK/range-set-blaze)以获取来自生物学的另一个集合运算符示例。该示例使用RangeSetBlaze结构体从转录区域和外显子区域中查找基因的内含子区域。

与其他范围相关的 crate 的比较

好处: 尽管 Rust 的 crates.io 已经包含了几个范围集合的 crate,我希望我的版本能提供完整的集合操作,同时保持性能。通过各种优化措施,我相信它达到了这些目标(请参见 基准报告)。例如,它可以比最流行的范围集合 crate 快 75 倍来处理单个整数(因为其他 crate 没有对单个处理做特殊优化——但它可以轻松添加这种优化)。在另一个基准测试中,range-set-blaze——使用混合算法——在合并两个集合时比其他 crate 快 30% 到 600%。

不足: 与其他范围相关的 crate 相比,range-set-blaze 有两个重要不足。首先,它仅适用于 Rust 整数类型。大多数其他 crate 处理任何可以排序的元素(日期、浮点数、IP 地址、字符串等)。其次,它仅提供集合功能。许多其他 crate 还处理映射。随着兴趣(以及可能的帮助),这些不足可能会在未来得到解决。

创建数据结构需要做出许多决策。根据我在 range-set-blaze 上的经验,以下是我推荐的决策。为了避免优柔寡断,我将这些建议表述为规则。当然,每个数据结构都不同,因此并非每条规则都适用于每个数据结构。

本文涵盖规则 1 到 5。第二部分 涵盖规则 6 到 9。

规则 1:抄袭 API、文档甚至代码——来自标准库

查找标准库中的类似数据结构,并逐行研究其文档。我选择了 BTreeSet 作为我的模型。它可以在缓存高效的平衡树中存储整数集合。

附带说明:稍后,在基准测试(规则 8)中,我们将看到 range_set_blaze::*RangeSetBlaze* 在某些“块状”整数集合上的速度可能比 *BTreeSet* 快 1000 倍。

BTreeSet 提供了 28 个方法,例如,clearis_subset。它还实现了 18 个特性,例如,FromIterator<T>。这是 BTreeSetclear 文档和 RangeSetBlazeclear 文档:

你可以看到我主要是直接复制的。我将“元素”改为“整数元素”,以提醒用户 RangeSetBlaze 支持什么。我删除了 where A: Clone,因为所有整数必然是可克隆的。注意,Rust 文档包括一个“源”链接,这使得复制变得容易。

复制提供了这些优点:

  • 它告诉你需要提供哪些方法。换句话说,它为你的 API(应用程序编程接口)提供了一个起点。这节省了设计时间。此外,用户会理解并期望这些方法。你甚至可以使你的数据结构成为标准数据结构的直接替代品。

  • 几乎可以免费获得文档文本和文档测试。

  • 你甚至可以复制代码。例如,这里是 BTreeSetRangeSetBlazeis_superset 代码:

#[must_use]
#[stable(feature = "rust1", since = "1.0.0")]
pub fn is_superset(&self, other: &BTreeSet<T, A>) -> bool
where
    T: Ord,
{
    other.is_subset(self)
}
#[must_use]
pub fn is_superset(&self, other: &RangeSetBlaze<T>) -> bool {
    other.is_subset(self)
}

BTreeSet 代码让我想起了超集可以通过子集来定义,以及 #[must_use] 是一个存在且在这里适用的特性。

你可能 决定 不支持标准数据结构中的所有功能。例如,我跳过了 new_in,这是一个实验性特性。同样,标准库支持映射(不仅仅是集合)、任何可排序的元素(不仅仅是整数)和 Serde 序列化。对我而言,这些是可能的未来特性。

你也可以 决定 以不同的方式支持某些内容。例如,BTreeSet::first 返回 Option<&T>RangeSetBlaze::first 返回 Option<T>。我知道 T 是一个便于克隆的整数,所以不需要是一个引用。

顺便提一下:Rust 没有一个通用的 *Set* 特性来告诉所有集合应该实现哪些方法,甚至提供一些默认实现(例如,*is_superset**is_subset* 作为基础)吗?没有,但这个问题正在被 讨论

你也可能 决定 支持比标准数据结构更多的方法。例如,RangeSetBlaze::lenBTreeSet::len 一样,返回集合中的元素数量。然而,RangeSetBlaze 还提供 ranges_len,它返回集合中排序的、不相交的范围的数量。

规则 2:设计构造函数以提高易用性、兼容性和速度

如果有一个空版本的数据结构是有意义的,你会想定义一个 new 方法和一个 [Default::default](https://doc.rust-lang.org/std/default/trait.Default.html) 方法。

类似地,如果从迭代器填充数据结构是有意义的,你会想定义 [FromIterator::from_iter](https://doc.rust-lang.org/std/iter/trait.FromIterator.html) 方法。这些方法也会自动定义 collect 方法。像 BTreeSet 一样,RangeSetBlaze 接受整数的迭代器和对整数的引用。(接受引用很重要,因为许多 Rust 迭代器提供引用。)以下是 from_itercollect 使用的示例:

let a0 = RangeSetBlaze::from_iter([3, 2, 1, 100, 1]);
let a1: RangeSetBlaze<i32> = [3, 2, 1, 100, 1].into_iter().collect();
assert!(a0 == a1 && a0.to_string() == "1..=3, 100..=100");

RangeSetBlaze 也接受包含范围的迭代器以及对这些范围的引用。它对输入范围没有限制。这些范围可以是无序的、重叠的、空的或重复的。

#[allow(clippy::reversed_empty_ranges)]
let a0 = RangeSetBlaze::from_iter([1..=2, 2..=2, -10..=-5, 1..=0]);
#[allow(clippy::reversed_empty_ranges)]
let a1: RangeSetBlaze<i32> = [1..=2, 2..=2, -10..=-5, 1..=0].into_iter().collect();
assert!(a0 == a1 && a0.to_string() == "-10..=-5, 1..=2");

最后,考虑定义额外的 From::from 方法。这些方法会自动定义 into 方法。例如,为了兼容 BTreeSetRangeSetBlaze 在数组上定义了一个 From::from 方法。

let a0 = RangeSetBlaze::from([3, 2, 1, 100, 1]);
let a1: RangeSetBlaze<i32> = [3, 2, 1, 100, 1].into();
assert!(a0 == a1 && a0.to_string() == "1..=3, 100..=100")

RangeSetBlaze还定义了from_sorted_disjoint/into_range_set_blaze,用于保证已排序且不相交的区间的迭代器。(我们将在规则 5 中看到,如何通过特殊特性和 Rust 编译器来强制执行这一保证。)

let a0 = RangeSetBlaze::from_sorted_disjoint(CheckSortedDisjoint::from([-10..=-5, 1..=2]));
let a1: RangeSetBlaze<i32> = CheckSortedDisjoint::from([-10..=-5, 1..=2]).into_range_set_blaze();
assert!(a0 == a1 && a0.to_string() == "-10..=-5, 1..=2");

附言:为什么使用*from_sorted_disjoint*/*into_range_set_blaze*而不是*from_iter /collect**from/into*?请参见这个讨论这个讨论

对于你的每一个构造函数,考虑可能的加速和优化。RangeSetBlazefrom_iter中实现了这种优化:

  • 将相邻(可能无序)整数/区间合并成不相交的区间,O(n₁)

  • 按起始位置对不相交的区间进行排序,O(n₂ log n₂)

  • 合并相邻的区间,O(n₂)

  • 从现在排序且不相交的区间创建一个BTreeMap,O(n₃ log n₃)

其中 n₁ 是输入整数/区间的数量,n₂ 是不相交且无序的区间数量,n₃ 是最终排序且不相交的区间数量。

“块状”整数的影响是什么?如果 n₂ ≈ sqrt(n₁),则构建时间为 O(n₁)。(实际上,只要 n₂n₁/ln(n₁),构建时间为 O(n₁)。)在基准测试中,这在块状整数迭代器上变成了比HashSetBTreeSet快 700 倍。

规则 3:创建比你预期的更多的 Rust 迭代器

你猜测标准BTreeSet定义了多少种不同的迭代器类型?

答案是八种:IterIntoIterDrainFilterRangeDifferenceSymmetricDifferenceIntersection,和Union。许多非 Rust 编程语言可以将任何方法变成迭代器/生成器,只需几个“yield”语句。然而,Rust 并不提供这种功能(但正在讨论中)。因此,几乎每个与迭代相关的方法都需要你定义一个新的迭代器结构类型。这些结构至少会实现一个next方法,该方法返回Some()None

RangeSetBlaze及其相关类型定义了 13 个迭代器结构。让我们看两个。

首先,用户可以调用ranges并将整数作为一系列排序的不相交区间进行迭代。(请记住,RangeSetBlaze接受无序、重叠的区间,但存储排序的不相交区间。)

use range_set_blaze::RangeSetBlaze;
let set = RangeSetBlaze::from_iter([30..=40, 15..=25, 10..=20]);
let mut ranges = set.ranges();
assert_eq!(ranges.next(), Some(10..=25));
assert_eq!(ranges.next(), Some(30..=40));
assert_eq!(ranges.next(), None);

在内部,RangeSetBlaze使用标准BTreeMap来存储区间信息。因此,RangeSetBlaze::ranges方法构造一个包含BTreeMap::IterRangesIter结构。然后我们让RangesIter::next方法调用BTreeMap::Iternext方法,并将结果转换成所需类型。这里是代码:

impl<T: Integer> RangeSetBlaze<T> {
    pub fn ranges(&self) -> RangesIter<'_, T> {
        RangesIter {
            iter: self.btree_map.iter(),
        }
    }
}

#[derive(Clone, Debug)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct RangesIter<'a, T: Integer> {
    pub(crate) iter: btree_map::Iter<'a, T, T>,
}

impl<'a, T: Integer> Iterator for RangesIter<'a, T> {
    type Item = RangeInclusive<T>;
    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next().map(|(start, end)| *start..=*end)
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}

其次,用户可能希望调用iter并逐个以排序顺序遍历整数。在这种情况下,我们将返回一个名为Iter的结构体,它包含一个RangeIter,然后逐个遍历范围内的整数。以下是Iter::next的原始代码,之后是关注点的讨论。

impl<T: Integer, I> Iterator for Iter<T, I>
where
    I: Iterator<Item = RangeInclusive<T>> + SortedDisjoint,
{
    type Item = T;
    fn next(&mut self) -> Option<T> {
        loop {
            if let Some(range) = self.option_range.clone() {
                let (start, end) = range.into_inner();
                debug_assert!(start <= end && end <= T::safe_max_value());
                if start < end {
                    self.option_range = Some(start + T::one()..=end);
                } else {
                    self.option_range = None;
                }
                return Some(start);
            } else if let Some(range) = self.iter.next() {
                self.option_range = Some(range);
                continue;
            } else {
                return None;
            }
        }
    }

SortedDisjoint特征涉及到保证内部迭代器提供排序的、不相交的范围。我们将在规则 5 中讨论它。

option_range 字段保存我们当前返回整数的范围(如果有的话)。我们使用loopcontinue来填充空的option_range。这个循环最多只循环两次,因此我本可以使用递归。然而,其他一些迭代器的递归次数足以导致栈溢出。因此,…

尾递归优化在 Rust 中没有保证。我的政策是:在next函数中从不使用递归。

附注:感谢 Michael Roth,当前版本的Iter::next现在更简短了。他的拉取请求在这里

BTreeSetRangeSetBlaze除了iter方法外,还定义了一个into_iter迭代器方法。同样,RangeSetBlaze除了其ranges方法外,还定义了一个into_ranges迭代器方法。这些into_whatever方法获取RangeSetBlaze的所有权,这在某些情况下很有用。

规则 4:通过特征使非法值不可表示

我说过RangeSetBlaze只适用于整数,但有什么阻止你将它应用于字符呢?

use range_set_blaze::RangeSetBlaze;

fn _some_fn() {
    let _char_set = RangeSetBlaze::from_iter(['a', 'b', 'c', 'd']);
}

答案?编译器会阻止你。它返回这个错误消息:

let _char_set = RangeSetBlaze::from_iter(['a', 'b', 'c', 'd']);
  |                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  |                     the trait `Integer` is not implemented for `char`
  |
  = help: the following other types implement trait `Integer`:
            i128
            i16
            i32
            i64
            i8
            isize
            u128
            u16
          and $N others

为了实现这一点,RangeSetBlaze定义了一个它称之为Integer的特征。以下是该定义(以及我找到的所有超级特征):

pub trait Integer:
    num_integer::Integer
    + FromStr
    + fmt::Display
    + fmt::Debug
    + std::iter::Sum
    + num_traits::NumAssignOps
    + FromStr
    + Copy
    + num_traits::Bounded
    + num_traits::NumCast
    + Send
    + Sync
    + OverflowingSub
    + SampleUniform
{
    // associated type SafeLen definition not shown ...
    fn safe_len(range: &RangeInclusive<Self>) -> <Self as Integer>::SafeLen;
    fn safe_max_value() -> Self {        Self::max_value()    }
    fn f64_to_safe_len(f: f64) -> Self::SafeLen;
    fn safe_len_to_f64(len: Self::SafeLen) -> f64;
    fn add_len_less_one(a: Self, b: Self::SafeLen) -> Self;
    fn sub_len_less_one(a: Self, b: Self::SafeLen) -> Self;
}

接下来,我在所有感兴趣的整数类型(u8u128包括usizei8i128包括isize)上实现了Integer特征。例如,

impl Integer for i32 {
    #[cfg(target_pointer_width = "64")]
    type SafeLen = usize;
    fn safe_len(r: &RangeInclusive<Self>) -> <Self as Integer>::SafeLen {
        r.end().overflowing_sub(*r.start()).0 as u32 as <Self as Integer>::SafeLen + 1
    }
    fn safe_len_to_f64(len: Self::SafeLen) -> f64 {len as f64}
    fn f64_to_safe_len(f: f64) -> Self::SafeLen {f as Self::SafeLen}
    fn add_len_less_one(a: Self, b: Self::SafeLen) -> Self {a + (b - 1) as Self}
    fn sub_len_less_one(a: Self, b: Self::SafeLen) -> Self {a - (b - 1) as Self}
}

有了这个,我可以使代码泛型化为<T: Integer>,如规则 3 中的代码示例所示。

附注:为什么 Rust 没有提供一个标准的“整数”特征来做所有事情?这里是讨论

规则 5:定义具有保证属性和有用方法的泛型迭代器

RangeSetBlazefrom_sorted_disjoint构造函数假设输入是排序好的不相交范围。这让RangeSetBlaze避免了工作。但是如果这个假设错误了呢?例如,如果我们给它未排序的范围,会发生什么?

use range_set_blaze::RangeSetBlaze;

fn _some_fn() {
    let not_guaranteed = [5..=6, 1..=3, 3..=4].into_iter();
    let _range_set_int = RangeSetBlaze::from_sorted_disjoint(not_guaranteed);
}

与规则 4 一样,编译器会捕捉错误并返回有用的消息:

7 |     let _range_set_int = RangeSetBlaze::from_sorted_disjoint(not_guaranteed);
  |                          ----------------------------------- ^^^^^^^^^^^^^^
                              the trait `SortedDisjoint<_>` is not implemented for `std::array::IntoIter<RangeInclusive<{integer}>, 3>`
  |                          |
  |                          required by a bound introduced by this call
  |
  = help: the following other types implement trait `SortedDisjoint<T>`:
            CheckSortedDisjoint<T, I> ...

为了实现这一点,RangeSetBlaze定义了特征SortedDisjoint。以下是相关定义:

pub trait SortedStarts<T: Integer>: Iterator<Item = RangeInclusive<T>> {}
pub trait SortedDisjoint<T: Integer>: SortedStarts<T> {
// methods not shown, yet
}

这说明 SortedDisjoint 是对整数的泛型,并且每个 SortedDisjoint 必须是 SortedStarts。此外,所有 SortedStarts 都是整数范围的迭代器。

附注:我的项目需要两个新的特征,因为我需要保证两个不同的属性。需要保证一个属性的项目只需要一个新的特征。

那么,重点是什么呢?为什么要引入新的特征,而不是直接使用 Iterator<Item = RangeInclusive<T>?正如我从 Rüdiger Klaehn 的精彩 sorted-iter crate 中学到的,我们可以使用这些新特征来强制执行保证。例如,这个构造函数使用 where 子句只接受保证的(排序且不重叠的)整数迭代器:

impl<T: Integer> RangeSetBlaze<T> {
    pub fn from_sorted_disjoint<I>(iter: I) -> Self
    where
        I: SortedDisjoint<T>,
    {
        // ... code omitted ...
    }
}

那么,保证的迭代器如何获得所需的 SortedDisjoint 特征?它们实现了这个特征!例如,我们知道 RangeSetBlaze::ranges 方法返回一个 RangesIter 迭代器,它由排序且不重叠的范围组成,因此我们让 RangesIter 实现 SortedDisjoint 特征,如下所示:

impl<T: Integer> SortedStarts for RangesIter<'_, T> {}
impl<T: Integer> SortedDisjoint for RangesIter<'_, T> {}

就这样。我们已经将 RangesIter 标记为 SortedDisjoint。编译器会完成剩下的工作。

不保证到保证:我还标记了一个名为 CheckSortedDisjoint 的迭代器为 SortedDisjoint。有趣的是,它遍历一个 不保证 的内部迭代器。这怎么可能呢?实际上,当它迭代时,它也会检查——如果发现任何未排序或重叠的范围则会引发恐慌。结果是一个保证的迭代器。

有时保证外部迭代器:那么有时是SortedDisjoint而有时不是的迭代器呢?例如,流行的 [Itertools::tee](https://docs.rs/itertools/latest/itertools/trait.Itertools.html#method.tee) 方法将任何迭代器转换为两个具有相同内容的迭代器。如果其输入迭代器是SortedDisjoint,那么其输出迭代器也将是:

impl<T: Integer, I: SortedDisjoint<T>> SortedDisjoint<T> for Tee<I> {}

定义方法:几乎可以说是额外的好处,我们可以在泛型 SortedDisjoint 迭代器上定义方法。例如,在这里我们定义了 complement 方法,该方法生成当前迭代器中 包含的每个排序且不重叠的整数范围。

pub trait SortedDisjoint<T: Integer>: SortedStarts<T> {
    fn complement(self) -> NotIter<T, Self>
    where
        Self: Sized,
    {
        NotIter::new(self)
    }
}

这是来自 complement 文档的一个使用示例:

use range_set_blaze::prelude::*;

let a = CheckSortedDisjoint::from([-10i16..=0, 1000..=2000]);
let complement = a.complement();
assert_eq!(complement.to_string(), "-32768..=-11, 1..=999, 2001..=32767");

complement 方法使用 NotIter,另一个迭代器(见规则 3)。NotIter 也实现了 SortedDisjoint。这个示例还使用了 to_string,这是另一个 SortedDisjoint 方法。

要使用 complementto_string,用户必须

use range_set_blaze::SortedDisjoint;use range_set_blaze::prelude::*;。前导模块起作用是因为项目的 lib.rs 指定了

pub mod prelude;prelude.rs 文件包含这个 pub use 语句,它包括 SortedDisjoint

pub use crate::{
    intersection_dyn, union_dyn, CheckSortedDisjoint, DynSortedDisjoint, MultiwayRangeSetBlaze,
    MultiwayRangeSetBlazeRef, MultiwaySortedDisjoint, RangeSetBlaze, SortedDisjoint,
};

这些是创建 Rust 数据结构的前五条规则。请参见 第二部分 了解规则 6 到 9。

附言:如果你对未来的文章感兴趣,请关注我的 Medium。我写关于 Rust 和 Python 的科学编程、机器学习和统计学的文章。我通常每个月写一篇文章。

在 Rust 中创建快速、安全和兼容的数据结构的九条规则(第二部分)

原文:towardsdatascience.com/nine-rules-for-creating-fast-safe-and-compatible-data-structures-in-rust-part-2-da5e6961a0b7?source=collection_archive---------8-----------------------#2023-04-12

来自 RangeSetBlaze 的经验教训

Carl M. KadieTowards Data Science Carl M. Kadie

·

关注 发表在 Towards Data Science ·14 min read·Apr 12, 2023

--

将数字存储在树中 — 来源: 必应图像创作者

这是关于在 Rust 中创建数据结构的文章第二部分。我们将探讨第 6 到第 9 条规则:

  • 6. 定义操作符和快速操作。

  • 7. 遵循“九条良好 API 设计规则”,尤其是“编写良好的文档”。

  • 8. 使用代表性数据、Criterion 基准测试和分析工具来优化性能。

  • 9. 测试覆盖率、文档、特性、编译器错误和正确性。

查看第一部分以获取规则 1 至 5。

  1. 抄袭你的 API、文档,甚至代码——来自标准库。

  2. 设计易于使用、兼容且高效的构造函数。

  3. 创建比预期更多的 Rust 迭代器。

  4. 使用特性使非法值无法表示。

  5. 定义具有保证属性和有用方法的泛型迭代器。

这些规则是根据我创建的新范围集合 crate [range-set-blaze](https://crates.io/crates/range-set-blaze)的经验得出的。范围集合是一种数据结构,用于表示整数等集合,形式为排序且不重叠的范围。例如,0..=5, 10..=10。与其他范围集合 crate 相比,range-set-blaze提供了完整的集合操作,并且性能优越。

让我们首先看看这个 crate 如何提供集合操作。

规则 6:定义操作符和快速操作

标准的BTreeSet提供了集合操作符,例如,如果ab都是BTreeSet,那么a | b将是它们并集的BTreeSet。决定操作符是否适合你的数据结构。操作符可以是逻辑操作符,例如|&!,和/或算术操作符,例如+-*/^这个表格列出了 Rust 的“可重载”操作符及其方法名称。

range-set-blaze crate 在RangeSetBlaze结构上提供了与集合相关的操作符:

use range_set_blaze::RangeSetBlaze;
let a = RangeSetBlaze::from_iter([1..=4]);
let b = RangeSetBlaze::from_iter([0..=0, 3..=5, 10..=10]);
let union = a | b;
assert_eq!(union, RangeSetBlaze::from_iter([0..=5, 10..=10]));

以及任何实现了SortedDisjoint特性的迭代器(见第一部分中的规则 5):

use range_set_blaze::prelude::*;

let a = CheckSortedDisjoint::new(vec![1..=2, 5..=100].into_iter());
let b = CheckSortedDisjoint::from([2..=6]);
let union = a | b;
assert_eq!(union.to_string(), "1..=100");

BitOr是 Rust 中|操作符的方法名称(再次见表格)。下面是RangeSetBlaze上并集操作符的定义:

impl<T: Integer> BitOr<RangeSetBlaze<T>> for RangeSetBlaze<T> {
    type Output = RangeSetBlaze<T>;
    fn bitor(mut self, other: Self) -> RangeSetBlaze<T> {
      // code omitted for now
    }
}

这意味着对于所有的RangeSetBlaze,我们定义了一个|操作符,它接受一个RangeSetBlaze作为第二个输入,并返回一个RangeSetBlaze

子规则 #1:支持借用输入:上述定义允许用户进行a | b操作,但不支持&a | &b&a | ba | &b。这些“借用组合”需要三个额外的定义,如下所示:

impl<T: Integer> BitOr<&RangeSetBlaze<T>> for &RangeSetBlaze<T> {...}
impl<T: Integer> BitOr<RangeSetBlaze<T>> for &RangeSetBlaze<T> {...}
impl<T: Integer> BitOr<&RangeSetBlaze<T>> for RangeSetBlaze<T> {...}

或者,你可以使用[gen_ops](https://crates.io/crates/gen_ops) crate。它允许你在一个声明中定义多个操作符的所有借用组合。

子规则 #2:在特性上提供操作符(类似): 与集合相关的方法,如complementunion,可以在我们的泛型SortedDisjoint迭代器上高效地定义。这些方法在一次遍历中运行,并且需要的内存非常少。(在第一部分的规则 5 中,我们看到了complement的定义。)然而,我们希望使用!a来代替a.complement()。可以吗?可以,也不可以。

不,你不能在SortedDisjoint上实现特征ops::Not!运算符)——参见互联网讨论。(你可以,但不应该使ops::Not成为SortedDisjoint的超类型,因为这将使在外部类型上实现SortedDisjoint变得不可能,例如[Itertools::Tee](https://docs.rs/itertools/latest/itertools/trait.Itertools.html#method.tee)。)

然而,你可以并且应该在尽可能多的自定义类型上定义这些运算符。例如,这里是CheckSortedDisjointops::Not的实现:

impl<T: Integer, I> ops::Not for CheckSortedDisjoint<T, I>
where
    I: Iterator<Item = RangeInclusive<T>>,
{
    type Output = NotIter<T, Self>;

    fn not(self) -> Self::Output {
        self.complement()
    }
}

这让我们可以在这个函数中使用!

fn print_first_complement_gap() {
    let a = CheckSortedDisjoint::from([-10i16..=0, 1000..=2000]);
    println!("{:?}", (!a).next().unwrap()); // prints -32768..=-11
}

额外说明:这个示例函数运行在常量时间和内存中。

子规则 #3: 利用所有权: 我最初不支持|=BitOrAssign)运算符。我认为如果用户想将两个RangeSetBlaze对象进行并集并将结果放入第一个对象中,他们可以直接写a = a | b

然而,后来我发现了一个潜在的加速机会。具体来说,当b中的范围数量相对于a非常少时,我们应该一次一个地将b的范围添加到a中,而不是将它们的两个ranges迭代器合并在一起。因此,我实现了BitOrAssign。(我没有针对交集的加速方法,所以没有实现BitAndAssign。)

但等等,还有更多!当用户调用a | &b时,运算符拥有第一个输入。这意味着代码可以在原地修改第一个输入(如果愿意的话)并将其作为结果返回。这是该运算符的定义:

impl<T: Integer> BitOr<&RangeSetBlaze<T>> for RangeSetBlaze<T> {
    type Output = RangeSetBlaze<T>;
    fn bitor(mut self, other: &Self) -> RangeSetBlaze<T> {
        self |= other;
        self
    }
}

我最终为a | &ba | b以及&a | b创建了特殊的优化实现。Rust 甚至允许我优化&mut a |= b,无论是b更小(如之前所述)还是a更小(我们将a放入拥有的b中,一次一个范围,然后将a赋值给b)。

子规则 #4: 提供快速的多路操作: 我们可以通过定义两个RangeSetBlaze对象的交集(运算符&)来实现对其ranges的交集。

(a.ranges() & b.ranges()).into_range_set_blaze()

并且,得益于布尔代数的奇迹,我们可以通过补集和并集来定义两个RangesIter的交集:

!(self.complement() | other.into_iter().complement())

这为所有SortedDisjoint迭代器提供了一个常量内存的一次性交集。(同样的技巧也适用于集合差异。对称差异需要两次遍历,但仍然只需要常量内存。)

额外说明:BTreeSet不提供complement。为什么?因为例如[0, 3, 4, 5, 10]的补集将需要 40 亿个条目。然而,基于范围的表示法只存储 4 个条目:-2147483648..=-1, 1..=2, 6..=9, 11..=2147483647

那么多个RangeSetBlaze对象的交集呢?我们可以一次做两个,但那样会产生不必要的中间对象。相反,我们可以提供多路交集功能,这样用户就可以这样做:

use range_set_blaze::prelude::*;

let a = RangeSetBlaze::from_iter([1..=6, 8..=9, 11..=15]);
let b = RangeSetBlaze::from_iter([5..=13, 18..=29]);
let c = RangeSetBlaze::from_iter([-100..=100]);

let intersection = [a, b, c].intersection();

assert_eq!(intersection, RangeSetBlaze::from_iter([5..=6, 8..=9, 11..=13]));

多路交集被定义为对RangeSetBlaze对象的迭代器的一种方法:

pub trait MultiwayRangeSetBlaze<'a, T: Integer + 'a>:
    IntoIterator<Item = &'a RangeSetBlaze<T>> + Sized
{
    fn intersection(self) -> RangeSetBlaze<T> {
        self.into_iter()
            .map(RangeSetBlaze::ranges)
            .intersection()
            .into_range_set_blaze()
    }
}

这调用了新的多路SortedDisjoint intersectionunion

pub trait MultiwaySortedDisjoint<T: Integer, I>: IntoIterator<Item = I> + Sized
where
    I: SortedDisjoint<T>,
{
    fn intersection(self) -> BitAndKMerge<T, I> {
        self.into_iter()
            .map(|seq| seq.into_iter().complement())
            .union()
            .complement()
    }

    fn union(self) -> BitOrKMerge<T, I> {
        UnionIter::new(KMerge::new(self))
    }
}

将这些结合起来,我们能够在一次通过中交集任意数量的RangeSetBlaze对象,并且使用恒定的内存(加上最终RangeSetBlaze结果的内存)。

子规则 #5: 如有需要,使用 **Box<dyn** **>** 类型在不同类型之间提供多路操作: SortedDisjoint上的多路运算符可能会遇到问题。你可以对相同类型的输入进行并集,这里是RangesIter<T>

let _i0 = [a.ranges(), b.ranges(), c.ranges()].intersection();

但不能处理混合类型的输入,这里是NotIter<T, RangesIter<T>>RangesIter<T>

// doesn't compile: let _i1 = [!a.ranges(), b.ranges(), c.ranges()].intersection();

解决方案是将所有输入封装在一个新的通用类型中(详细信息见GitHub):

pub struct DynSortedDisjoint<'a, T: Integer> {
    iter: Box<dyn SortedDisjoint<T> + 'a>,
}

你可以显式地使用新类型或通过宏来使用它。

let _i2 = [
        DynSortedDisjoint::new(!a.ranges()),
        DynSortedDisjoint::new(b.ranges()),
        DynSortedDisjoint::new(c.ranges()),
    ]
    .intersection();
// or
let _i3 = intersection_dyn!(!a.ranges(), b.ranges(), c.ranges());

规则 7: 遵循九条良好 API 设计规则,特别是“编写良好文档”

文章优雅的 Rust 库 API 的九条规则中的所有规则都适用。这里讨论了最有趣的六条,首先是:

编写良好的文档以保持设计的诚实。

创建不会让你感到尴尬的示例。

你应该在你的lib.rs中加入#![warn(missing_docs)]。这将提醒你为每个公共类型、特性和方法编写文档。在(几乎)每一部分文档中包含一个示例。

在某个情况下,我厌倦了向用户警告某些迭代器应仅包含排序且不相交的范围。这使我回到规则 5,并让编译器强制执行所需的保证。

在另一个情况下,我发现自己在多个地方解释如何调用并集。具体来说,我告诉用户在短的第二输入和| (union)运算符的情况下使用extend方法。对此我感到解释很尴尬。这导致我提供了一个|=运算符,它总是做最快的事情。

使用Clippy

这个规则很重要。我现在认为这是理所当然的。所以,使用Clippy

接受所有类型的类型。

我主要尝试遵循这个规则,但有一个重大例外。虽然 Rust 提供了许多范围类型,但我只接受start..=end形式的范围。这不仅简化了代码,而且我认为它也简化了文档和示例。(根据用户的请求,我会重新考虑这个问题。)

定义并返回友好的错误信息。

出乎意料的是,这个规则在许多数据结构中并不常见。例如,标准的BTreeSet不会返回可能是错误的结果。同样,RangeSetBlaze在几个地方可能会引发恐慌——例如,如果CheckSortedDisjoint发现问题——但它从不返回错误结果。

了解用户的需求,最好通过使用自己的产品来了解。

我感到很遗憾没有遵循这个规则。我创建RangeSetBlaze是为了好玩。我的当前项目中没有需要它的地方。我希望其他人能够在他们的项目中使用它,并给我反馈他们需要什么。

规则 8:使用代表性数据、Criterion 基准测试和分析工具来优化性能

附注: 基准测试报告的最新版本显示了最新结果,包括与流行的 Roaring 压缩位图的比较。

子规则 #1:使用代表性数据: 在某些数据上,RangeSetBlaze会比标准的HashSetBTreeSet表现更差。例如,从随机整数(均匀且有替换)构造的速度大约比HashSet慢 2.5 倍。以下是构造时间与随机整数数量的关系图:

但如果整数是“聚集的”呢?例如,12、13、14、998、999、1000、1001、-333、-332 等。为了查看,我们将构造 100 万个聚集的整数。我们将使平均聚集大小从 1(没有聚集)变化到 100K(十个大聚集)。(更多细节,请参见(某些)范围相关的 Rust Crate 基准测试)。有了这些数据,当聚集大小达到 100 时,RangeSetBlazeHashTable快 30 倍,比BTreeSet快 15 倍。如果我们可以将聚集作为范围(而不是单独的整数)提供,则当聚集大小为 1000 时,RangeSetBlaze比标准数据类型快 700 倍。

这里的建议是对你认为有趣的数据(无论是实际数据还是合成数据)进行测试。对于这个项目,我创建了“随机聚集整数迭代器”,并控制了以下内容:

  • 整数的范围(例如,所有整数都在0..=9_999_999中)

  • 整数的数量(例如,100 万)

  • 平均聚集大小(例如,100)

  • 迭代器的数量(例如,1)

  • coverage,这是期望由迭代器(如果有多个则取并集或交集)覆盖的范围的比例——(例如,10% coverage)。

子规则 #2:使用 Criterion crate 在 Rust 中基准测试程序: 以下是一些提示:

  • 要获得良好的 HTML 报告和出色的图表,在你的Cargo.toml[dev-dependencies]部分中加入criterion = { version = "当前版本", features = ["html_reports"] }。(此说明目前在Criterion 文档中缺失)。

  • 在你的Cargo.toml中加入以下内容。它声明了一个名为“bench”的基准测试,使用第三方基准测试工具 Criterion。

[[bench]]
name = "bench"
harness = false
  • 在项目的顶层创建文件夹benches和文件bench.rs。文件的最后一行应为criterion_main!(benches);

  • 根据 Criterion 用户指南 中的描述创建基准测试。

  • 不要和 Criterion 抵触。它可能会抱怨你的基准测试太慢。如果可以,听取它的意见并加快速度。这将帮助它生成更好的统计数据。

附注:你可以在调试器下和作为集成测试运行 Criterion 实验。有关示例,请参见 [fn debug_k_play](https://github.com/CarlKCarlK/range-set-blaze/blob/main/tests/integration_test.rs#L634)[tests/integration.rs](https://github.com/CarlKCarlK/range-set-blaze/blob/main/tests/integration_test.rs)。关键是将公共代码放到自己的项目中(例如,[tests_common](https://github.com/CarlKCarlK/range-set-blaze/tree/main/tests_common)),然后在主 [Cargo.toml](https://github.com/CarlKCarlK/range-set-blaze/blob/main/Cargo.toml#LL16C1-L18C32) 文件中创建一个工作区。

子规则 #3: 使用基准测试来驱动设计决策: 八年前,当我创建这个数据结构的 Python 版本时,我将排序后的范围存储在一个向量中。对于新版本,我想知道使用 Rust 的标准 BTreeMap 是否会更快。为了验证,让我们将两个基于 BTreeMap 的库 — [rangemap](https://crates.io/crates/rangemap)RangeSetBlaze — 与两个基于 SmallVec 的库 — [Range-collections](https://crates.io/crates/range-collections)[range-set](https://crates.io/crates/range-set) 进行比较:

(有关详细信息,请参见 基准测试报告)。总结一下,最快的基于向量的方法比最慢的基于树的方法慢 14 倍。它还比 RangeSetBlaze 慢 50 倍。这并不令人惊讶,因为基于向量的方法并不是为我们关注的数据的大量插入设计的。

这是另一个例子。当我们将多个 RangeSet 对象相交(或并集)时,逐个处理可以吗,还是多路处理更快?基准测试给出了答案:

在两个集合上,所有方法都相似,但随着集合数量的增加,逐个处理的速度就会落后。到 100 个集合时,逐个处理必须创建大约 100 个中间集合,速度比多路处理慢约 14 倍。动态多路处理未被 RangeSetBlaze 使用,但 SortedDisjoint 迭代器有时需要。它比静态多路处理慢 5% 到 10%。(详细信息)。

子规则 #4: 对代码进行分析以查找瓶颈: 最受欢迎的 Rust 分析工具是 [flamegraph](https://github.com/flamegraph-rs/flamegraph)。我发现它很容易使用,但我错过了能够进行互动缩放和操作结果的功能。然而,实际上你可以使用几乎任何分析工具来配合 Rust。

例如,我订阅了 Visual Studio 2022 的付费版。我以前在 C++ 和 C# 代码上使用了它的全功能分析器。要在 Rust 上使用它,我只需将其(临时)添加到我的 Cargo.toml 文件中,并在发布模式下重新编译。

[profile.release]
debug = true

其他功能全面的分析器,虽然我没有使用过,但看起来可以与 Rust 配合使用的包括 AMD μProf 和在 Windows 上的 Superluminal

规则 9:测试覆盖率、文档、特性、编译器错误和正确性

数据结构通常支持多种方法,并被其他项目使用。因此,它们需要广泛的测试。

规则 1:结合集成测试和覆盖测试: 集成测试只能看到数据结构的公共接口。我将我的测试放在 tests/integration_test.rs 中。每个测试函数都以 #[test] 开头。

覆盖测试确保你的代码中的每一行至少被测试运行一次。使用强大且易用的 [llvm-cov](https://github.com/taiki-e/cargo-llvm-cov#) 来测量覆盖率。安装说明见 这里。用命令运行它:cargo llvm-cov --open.

要将覆盖率提升到 100%,你需要添加测试。我建议尽可能将这些覆盖测试纳入你的集成测试中。这将使你能够体验到公共接口的可用性。

要在你支持的类型之间创建覆盖(以及其他)测试,请使用神奇的 syntactic-for crate:

#[test]
fn lib_coverage_6() {
    syntactic_for! { ty in [i8, u8, isize, usize,  i16, u16, i32, u32, i64, u64, isize, usize, i128, u128] {
        $(
            let mut a = RangeSetBlaze::<$ty>::from_iter([1..=3, 5..=7, 9..=120]);
            a.ranges_insert(2..=100);
            assert_eq!(a, RangeSetBlaze::from_iter([1..=120]));

        )*
    }};
}

规则 2:在文档中添加多个示例并将其作为测试运行:以下是 SortedDisjoint::equals 文档中的一个示例:

 /// Given two [`SortedDisjoint`] iterators, efficiently tells if they
    /// are equal. Unlike most equality testing in Rust,
    /// this method takes ownership of the iterators and consumes them.
    ///
    /// # Examples
    ///
    /// ```

    /// use range_set_blaze::prelude::*;

    ///

    /// let a = CheckSortedDisjoint::from([1..=2]);

    /// let b = RangeSetBlaze::from_iter([1..=2]).into_ranges();

    /// assert!(a.equal(b));

    /// ```py

这可以通过 cargo testcargo test --doc 运行。(不过,在测量覆盖率时,它不会运行。)

规则 3:测试所有结构体和枚举的缺失特性: 我希望 RangeSetBlaze::Iter 一般支持与标准 BTreeSet::Iter 相同的特性。此外,所有类型应该在实际情况允许的情况下自动实现 SizedSendSyncUnpin。(参见这个 Let’s Get Rusty 视频)。以下是测试特性的方法:

fn is_sssu<T: Sized + Send + Sync + Unpin>() {}
fn is_like_btreeset_iter<T: Clone + std::fmt::Debug + FusedIterator + Iterator>() {}
// removed DoubleEndedIterator +ExactSizeIterator for now
#[test]
fn iter_traits() {
    type ARangesIter<'a> = RangesIter<'a, i32>;
    type AIter<'a> = Iter<i32, ARangesIter<'a>>;
    is_sssu::<AIter>();
    is_like_btreeset_iter::<AIter>();
}

当我第一次做这个时,我发现我缺少 DebugFusedIteratorDoubleEndedIteratorExactSizeIterator。我添加了前两个,并决定暂时不添加后两个。

类似地,一个将 RangeSetBlaze 的特性与标准 BTreeSet 的特性进行比较的测试提醒我实现 ClonePartialOrdHashEqOrdIntoIterator

子规则 #4: 测试非法值确实无法表示: 你希望你的数据结构在用户应用于错误类型时(例如,尝试创建一个字符集而不是整数集)引发编译器错误。同样,如果用户给出的是普通迭代器而不是像 SortedDisjoint 这样的迭代器,你的一些方法也应该引发编译器错误。我们如何测试这一点?

使用 [trybuild](https://docs.rs/trybuild/latest/trybuild/) 创建“ui”测试。首先,创建一个像这样的集成测试:

#[test]
fn ui() {
    let t = trybuild::TestCases::new();
    t.compile_fail("tests/ui/*.rs");
}

然后在 tests/ui 中,创建如 untrusted_pairs.rs 这样的文件,以测试编译器错误信息。

use range_set_blaze::RangeSetBlaze;

fn main() {
    let guaranteed = RangeSetBlaze::from_iter([1..=2, 3..=4, 5..=6]).into_ranges();
    let _range_set_int = RangeSetBlaze::from_sorted_disjoint(guaranteed); // yep
    let not_guaranteed = [5..=6, 1..=3, 3..=4].into_iter();
    let _range_set_int = RangeSetBlaze::from_sorted_disjoint(not_guaranteed); // nope
}

设置环境变量 TRYBUILD=overwrite 以记录预期的编译器错误信息。详细信息见 [trybuild](https://docs.rs/trybuild/latest/trybuild/)。如果你在本地无法得到与 CI(例如 GitHub Action)相同的结果,请参见 这个线程

子规则 #5: 使用自动化 QuickCheck 测试验证正确性: 根据 Rüdiger Klaehn 的推荐,我实现了自动化的 [QuickCheck](https://github.com/BurntSushi/quickcheck) 测试。这里是一个测试,它检查 RangeSetBlaze::is_disjoint 是否在 QuickCheck 生成的值上与 BTreeSet::is_disjoint 得到相同的答案。

type Element = i64;
type Reference = std::collections::BTreeSet<Element>;

#[quickcheck]
fn disjoint(a: Reference, b: Reference) -> bool {
    let a_r = RangeSetBlaze::from_iter(&a);
    let b_r = RangeSetBlaze::from_iter(&b);
    a.is_disjoint(&b) == a_r.is_disjoint(&b_r)
}

所以,你了解了,在 Rust 中创建数据结构的九条规则。创建 Rust 数据结构比我预期的时间要长。这一方面是因为我创建了一个完整的SortedDisjoint子库,另一方面则是因为 Rust 的要求,比如需要我定义 13 个公共迭代器结构体。然而,从 Rust 的速度和安全性来看,这段时间是值得的。按照这九条规则来创建你自己的强大 Rust 数据结构吧。

顺便提一下,如果你对未来的文章感兴趣,请 在 Medium 上关注我。我会写关于 Rust 和 Python 的科学编程、机器学习和统计学的文章。我通常每月写一篇文章。

在网络和嵌入式系统上运行 Rust 的九条规则

原文:towardsdatascience.com/nine-rules-for-running-rust-on-the-web-and-on-embedded-94462ef249a2?source=collection_archive---------3-----------------------#2023-07-05

从将 range-set-blaze 移植到 no_std 和 WASM 的实际经验

Carl M. KadieTowards Data Science Carl M. Kadie

·

关注 发表在 Towards Data Science ·17 分钟阅读·2023 年 7 月 5 日

--

微控制器上的螃蟹 — 来源:openai.com/dall-e-2/

我推荐使用 Rust,当你需要 C++ 的速度和 Python 的内存安全时。此外,使用 Rust 你可以构建在超过 100,000 个 软件库 上。除此之外,Rust 还提供了将你的代码运行在不仅是传统计算机上,还有网页甚至机器人上的潜力。

然而,“几乎所有地方”运行会带来复杂性。本文是为那些希望减轻这些复杂性的 Rust 程序员准备的。(它也可能对那些想了解 Rust 的“几乎所有地方”运行故事的人感兴趣。)

第一个复杂性:网页和机器人的嵌入式处理器不支持通用文件 IO。如果你的项目主要涉及读写文件,它不适合在机器人、其他嵌入式处理器或网页上运行。

第二个复杂性:将代码移植到几乎所有地方需要多个步骤和选择。导航这些选择可能会耗时。遗漏一步可能导致失败。本文旨在通过提供这九条规则来减少第二个复杂性,稍后我们将详细探讨这些规则:

  1. 将你的 lib.rs 或 main.rs 标记为 no_std。

  2. 如果可能的话,使用内置的“crate alloc”。

  3. 切换到“no std”依赖项。

  4. 创建 std 和 alloc 功能,并将你的 std-only 函数设为可选。

  5. 为 WASM 构建你的项目。使用 cargo tree 来使其正常工作。

  6. 创建 WASM 测试和 WASM 演示。

  7. [可选] 为嵌入式设备构建你的项目。

  8. [可选] 创建一个嵌入式测试和嵌入式演示。

  9. 完成 CI 测试、Cargo.toml 元数据和更新的 README.md。

遵循这些规则将帮助你创建运行在从 PC 到智能手机网页(demo)到机器人等所有地方的非常快速且内存安全的代码。代码可以非常小,并且可以利用庞大的 Rust crate 库。

为了说明这些规则,我们将把[range-set-blaze](https://github.com/CarlKCarlK/range-set-blaze) crate 移植到网页——WASM——和微控制器——嵌入式。 (这个 crate 操作“块状”整数的集合。这个 crate 的用户请求了这个移植。)

移植到 WASM 和嵌入式要求你避免使用 Rust 的标准库“std”。转换到“no std”比我预期的既容易又困难。容易是因为你仍然可以使用VecString。困难主要是因为测试。基于我在range-set-blaze上的经验,以下是我推荐的决策,逐一描述。为了避免优柔寡断,我将这些建议表达为规则。

规则 1:将你的 lib.rs 或 main.rs 标记为 no_std。

附注 1:首先,使用 Git 为你的项目创建一个新分支。这样,如果出现问题,你可以轻松地撤销所有更改。

附注 2:实际上WASM 部分支持 std。例如,它支持vecStringHashSet。它不支持文件 IO。如果你只需要 WASM 支持,你可能可以跳过本文中的所有“no_std”工作。

附注 3:来自Reddit的一个提示,我还没有测试过:“你真的不需要为 Wasm 设置任何no_std。如果你使用诸如wasm-optwasm-gc,甚至只是一个完整的集成管道与trunk,你不会看到二进制文件大小的差异,因为任何你没有使用的东西都会被剥离。无需设置no_std并寻找no_std依赖项。”

lib.rs的顶部标记为:

#![cfg_attr(not(test), no_std)]

这告诉 Rust 编译器除非在测试时,否则不要包含标准库。

附注 1:我的项目是一个具有lib.rs的库项目。我认为具有main.rs的二进制项目的步骤大致相同,但我还没有测试过。

附注 2:我们将在后面的规则中详细讨论代码测试。

range-set-blazelib.rs中添加“no_std”行,会导致 40 个编译器问题,大多数问题形式如下:

通过将主代码中的“std::”更改为“core::”来修复其中的一些问题(不包括测试代码)。对于range-set-blaze,这将问题数量从 40 个减少到 12 个。这个修复很有效,因为许多项,如std::cmp::max,在core::cmp::max中也可以找到。

可悲的是,像VecBox这样的项不能在core中,因为它们需要分配内存。幸运的是,如果你愿意支持内存分配,你仍然可以使用它们。

规则 2:如果可以的话,使用内置的“crate alloc”。

你是否应该允许你的 crate 分配内存?对于 WASM,你应该允许。对于许多嵌入式应用程序,你也应该允许。然而,对于一些嵌入式应用程序,你不应该允许。如果你决定允许内存分配,那么在lib.rs的顶部添加:

extern crate alloc;

你现在可以添加这样的行,以获取许多内存分配的项:

extern crate alloc;

use alloc::boxed::Box;
use alloc::collections::btree_map;
use alloc::collections::BTreeMap;
use alloc::vec::Vec;
use alloc::{format, string::String};
use alloc::vec;

使用range-set-blaze,这将问题数量从 12 个减少到两个。我们将在规则 3 中解决这些问题。

附注:如果你在编写一个不能使用内存分配的嵌入式环境,并且遇到了例如Vec的问题,你可以尝试重写。例如,你可以尝试用数组替代向量。如果这样不行,可以查看其他规则。如果都不行,你可能无法将你的 crate 移植到no_std

规则 3:切换到“no std”依赖项。

如果你的项目使用了将“std”函数引入你代码中的 crate,Rust 编译器会发出警告。有时,你可以搜索crates.io并找到替代的“no_std” crate。例如,流行的thiserror crate 会将“std”注入你的代码。然而,社区已经创建了不含“std”的替代品。

对于 range-set-blaze,剩下的两个问题与 crate [gen_ops](https://crates.io/crates/gen_ops) 相关——这是一个方便地定义操作符如 “+” 和 “&” 的绝妙 crate。gen_ops 的版本 0.3.0 未完全支持 “no std”。然而,版本 0.4.0 支持。我更新了 Cargo.toml 中的依赖项,并改进了我的“no std”兼容性。

我现在可以运行这些命令:

cargo check # check that compiles as no_std
cargo test # check that tests, using std, still pass

命令 cargo check 确认了我的 crate 并没有直接使用标准库。命令 cargo test 确认了我的测试(仍使用标准库)继续通过。如果你的 crate 仍然不能编译,请查看下一个规则。

规则 4:创建 std 和 alloc 特性,并使你的 std-only 函数可选。

嵌入式处理器通常不支持读取和写入文件。同样,WASM 也尚未完全支持文件。虽然你可以找到一些与文件相关的“no std” crates,但似乎没有一个是全面的。因此,如果文件 IO 是你的 crate 的核心,移植到 WASM 和嵌入式可能不切实际。

然而,如果文件 IO —— 或任何其他 std-only 函数 —— 只是你的 crate 的附带功能,你可以通过“std”特性使该函数可选。方法如下:

将以下部分添加到你的 Cargo.toml

[package]
#...
resolver = "2" # the default for Rust 2021+

[features]
default = ["std"]
std = []
alloc = []

这表示你的 crate 现在有两个特性,“std”和“alloc”。默认情况下,编译器应使用“std”。

在你的 lib.rs 顶部,替换:

#![cfg_attr(not(test), no_std)]

替换为:

#![cfg_attr(not(feature = "std"), no_std)]

这表示如果你不应用“std”特性,编译器应在没有标准库的情况下进行编译。

在任何 std-only 代码之前的行中,添加 #[cfg(feature = "std")]。例如,这里我们定义了一个基于文件内容创建 RangeSetBlaze 结构体的函数:

#[cfg(feature = "std")]
use std::{
    fs::File,
    io::{self, BufRead, BufReader},
    path::Path,
};

#[cfg(feature = "std")]
#[allow(missing_docs)]
pub fn demo_read_ranges_from_file<P, T>(path: P) -> io::Result<RangeSetBlaze<T>>
where
    P: AsRef<Path>,
    T: FromStr + Integer,
{
 //...code not shown
}

要检查 “std” 和 “alloc” 特性,请执行以下操作:

cargo check # std
cargo check --features alloc --no-default-features

我们可以用以下方式测试“std”:

cargo test

附注:令人惊讶的是,cargo test --features alloc --no-default-features 不会测试 "alloc"。这是因为测试 需要线程、分配和其他可能在 no_std 中不可用的东西,因此 cargo 总是将常规测试作为“std”运行。

在这个阶段,我们检查了“std”和“alloc”,所以我们可以假设我们的库将与 WASM 和嵌入式兼容吗?不!一般来说,没有经过测试的东西都无法正常工作。具体来说,我们可能依赖于内部使用“std”代码的 crates。为了发现这些问题,我们必须在 WASM 和嵌入式环境中进行测试。

规则 5:为 WASM 构建你的项目。使用 cargo tree 来使其正常工作。

安装 WASM 交叉编译器,并用以下命令检查你的项目:

rustup target add wasm32-unknown-unknown # only need to do this once
# may find issues
cargo check --target wasm32-unknown-unknown --features alloc --no-default-features

当我在 range-set-blaze 上执行此操作时,它抱怨 getrandom crate 与 WASM 不兼容。一方面,我不惊讶 WASM 不完全支持随机数。另一方面,我感到惊讶,因为我的项目并不直接依赖于 getrandom。为了找出间接依赖,我使用 cargo tree。我发现我的项目依赖于 crate rand,而 rand 依赖于 getrandom。以下是使用的 cargo tree 命令:

cargo tree --edges no-dev --format "{p} {f}" --features alloc --no-default-features

该命令输出所有的 crate 及其使用的特性:

range-set-blaze v0.1.6 (O:\Projects\Science\wasmetc\wasm3) alloc
├── gen_ops v0.4.0
├── itertools v0.10.5 default,use_alloc,use_std
│   └── either v1.8.1 use_std
├── num-integer v0.1.45 default,std
│   └── num-traits v0.2.15 default,std
│       [build-dependencies]
│       └── autocfg v1.1.0
│   [build-dependencies]
│   └── autocfg v1.1.0
├── num-traits v0.2.15 default,std (*)
├── rand v0.8.5 alloc,default,getrandom,libc,rand_chacha,std,std_rng
│   ├── rand_chacha v0.3.1 std
│   │   ├── ppv-lite86 v0.2.17 simd,std
│   │   └── rand_core v0.6.4 alloc,getrandom,std
│   │       └── getrandom v0.2.9 std
│   │           └── cfg-if v1.0.0
...

输出显示range-set-blaze依赖于rand。此外,它还显示rand依赖于带有“std”特性的getrandom

我阅读了 getrandom documentation 并了解到其“js”特性支持 WASM。那么,我们如何让 rand 使用 getrandom/js,但仅在我们编译时启用我们的“alloc”特性?我们这样更新我们的 Cargo.toml

[features]
default = ["std"]
std = ["getrandom/std"]
alloc = ["getrandom/js"]

[dependencies]
# ...
getrandom = "0.2.10"

这表示我们的“std”特性依赖于 getrandom 的“std”特性。然而,我们的“alloc”特性应使用 getrandomjs 特性。

现在可以正常工作:

cargo check --target wasm32-unknown-unknown --features alloc --no-default-features

所以,我们已经完成了 WASM 的编译,但测试 WASM 呢?

规则 6:创建 WASM 测试和 WASM 演示。

让我们先用测试然后用演示网页来运行 WASM 版本。

tests/wasm.rs 中创建 WASM 测试

你可以几乎像测试本地代码一样测试 WASM。我们通过让原始测试仅在本地运行,而几乎重复的测试集在 WASM 上运行来做到这一点。以下是基于 The [wasm-bindgen](https://rustwasm.github.io/wasm-bindgen/wasm-bindgen-test/index.html) Guide 的步骤:

  1. 执行 cargo install wasm-bindgen-cli

  2. 将当前的集成测试从,例如 tests/integration_tests.rs 复制到 tests/wasm.rs。 (回忆一下,在 Rust 中,集成测试是位于 src 目录外的测试,并且只能看到项目的公共方法。)

  3. tests/wasm.rs 顶部,删除 #![cfg(test)] 并添加

    #![cfg(target_arch = "wasm32")]

    use wasm_bindgen_test::*; wasm_bindgen_test_configure!(run_in_browser);

  4. wasm.rs 中,将所有的#[test]替换为#[wasm_bindgen_test]

  5. 在你所有的 #![cfg(test)] 处(通常在tests/integration_tests.rssrc/tests.rs),添加额外的行:#![cfg(not(target_arch = "wasm32"))]

  6. 在你的 Cargo.toml 中,将 [dev-dependencies](如果有的话)更改为 [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]

  7. 在你的 Cargo.toml 中,添加一个部分:

[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
wasm-bindgen-test = "0.3.37"

设置完成后,本地测试,即cargo test,应该仍然有效。如果没有安装 Chrome 浏览器,请安装它。现在尝试使用以下命令运行 WASM 测试:

wasm-pack test --chrome --headless --features alloc --no-default-features

可能会失败,因为你的 WASM 测试使用了尚未或无法放入 Cargo.toml 的依赖。逐个解决每个问题,或者:

  1. 将所需的依赖项添加到 Cargo.toml[target.'cfg(target_arch = "wasm32")'.dev-dependencies] 部分,或者

  2. tests/wasm.rs 中移除测试。

对于 range-set-blaze,我删除了所有与测试包的基准测试框架相关的 WASM 测试。这些测试仍将在本地运行。在 tests\wasm.rs 中的一些有用的测试需要 crate syntactic-for,所以我将其添加到 Cargo.toml 中的 [target.'cfg(target_arch = "wasm32")'.dev-dependencies] 下。修复后,所有 59 个 WASM 测试均已运行并通过。

旁注:如果你的项目包括一个 examples folder,你可能需要在你的示例中创建一个 native 模块和一个 wasm 模块。查看这个 range-set-blaze 文件中的“示例”示例来了解如何操作。

tests/wasm-demo 中创建一个 WASM 演示。

支持 WASM 的乐趣之一是你可以在网页中演示你的 Rust 代码。这是一个网页 演示 [range-set-blaze](http://carlkcarlk.github.io/range-set-blaze/wasm-demo/)

按照以下步骤创建你自己的网页演示:

在项目的主 Cargo.toml 文件中,定义一个工作区并将 tests/wasm-demo 添加到其中:

[workspace]
members = [".", "tests/wasm-demo"]

在你的测试文件夹中,创建一个 test/wasm-demo 子文件夹。

它应该包含一个新的 Cargo.toml 文件,类似于这样(将 range-set-blaze 更改为你项目的名称):

[package]
name = "wasm-demo"
version = "0.1.0"
edition = "2021"

[lib]
crate-type = ["cdylib"]

[dependencies]
wasm-bindgen = "0.2"
range-set-blaze = { path = "../..", features = ["alloc"], default-features = false}

另外,创建一个文件 tests/wasm-demo/src/lib.rs。这是我的示例:

#![no_std]
extern crate alloc;
use alloc::{string::ToString, vec::Vec};
use range_set_blaze::RangeSetBlaze;
use wasm_bindgen::prelude::*;

#[wasm_bindgen]
pub fn disjoint_intervals(input: Vec<i32>) -> JsValue {
    let set: RangeSetBlaze<_> = input.into_iter().collect();
    let s = set.to_string();
    JsValue::from_str(&s)
}

这个文件定义了一个名为 disjoint_intervals 的函数,它接受一个整数向量作为输入,例如,100,103,101,102,-3,-4。使用 range-set-blaze 包,函数返回一个字符串,表示这些整数按排序后、互不重叠的区间,例如,-4..=-3, 100..=103

作为最后一步,创建文件 tests/wasm-demo/index.html。我的使用了一些 JavaScript 代码来接收一个整数列表,然后调用 Rust WASM 函数 disjoint_intervals

<!DOCTYPE html>
<html>
<body>
    <h2>Rust WASM RangeSetBlaze Demo</h2>
    <p>Enter a list of comma-separated integers:</p>
    <input id="inputData" type="text" value="100,103,101,102,-3,-4" oninput="callWasmFunction()">
    <br><br>
    <p id="output"></p>
    <script type="module">
        import init, { disjoint_intervals } from './pkg/wasm_demo.js';

        function callWasmFunction() {
            let inputData = document.getElementById("inputData").value;
            let data = inputData.split(',').map(x => x.trim() === "" ? NaN : Number(x)).filter(n => !isNaN(n));
            const typedArray = Int32Array.from(data);
            let result = disjoint_intervals(typedArray);
            document.getElementById("output").innerHTML = result;
        }
        window.callWasmFunction = callWasmFunction;
        init().then(callWasmFunction);
    </script>
</body>
</html>

要在本地运行演示,首先将终端移动到 tests/wasm-demo。然后执行:

# from tests/wasm-demo
wasm-pack build --target web

接下来,启动本地网络服务器并查看页面。我使用了 Live Preview 扩展到 VS Code。许多人使用 python -m http.serverrange-set-blaze 演示看起来像这样(也可以在 GitHub 上实时查看):

我发现看到我的 Rust 项目在网页中运行非常令人满意。如果 WASM 兼容性是你所寻找的一切,你可以跳到规则 9。

规则 7:为嵌入式构建你的项目。

如果你想将你的项目推进到 WASM 之外,请遵循这个规则和接下来的规则。

确保将终端移动回项目的主目录。然后,安装 thumbv7m-none-eabi,这是一个流行的嵌入式处理器,并使用以下命令检查你的项目:

# from project's home directory
rustup target add thumbv7m-none-eabi # only need to do this once
# will likely find issues
cargo check --target thumbv7m-none-eabi --features alloc --no-default-features

当我在range-set-blaze上执行此操作时,我得到与四组依赖项相关的错误:

  • thiserror — 我的项目依赖于这个 crate,但实际上没有使用它。我删除了这个依赖。

  • randgetrandom — 我的项目只需要在本地测试中使用随机数,所以我将依赖项移动到了[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]。我还更新了我的主代码和测试代码。

  • itertoolsnum-traitsnum-integer — 这些 crate 为“std”和“alloc”提供了功能。我将Cargo.toml更新为如下:

...
[features]
default = ["std"]
std = ["itertools/use_std", "num-traits/std", "num-integer/std"]
alloc = ["itertools/use_alloc", "num-traits", "num-integer"]

[dependencies]
itertools = { version = "0.10.1", optional = true, default-features = false }
num-integer = { version = "0.1.44", optional = true, default-features = false }
num-traits = { version = "0.2.15", optional = true, default-features = false }
gen_ops = "0.4.0"

[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
#...
rand = "0.8.4"
#...

我怎么知道使用哪个依赖项的哪个功能?理解像itertools这样的 crate 的功能需要阅读其文档并(通常)访问其 GitHub 仓库并阅读其Cargo.toml。你还应该使用cargo tree来检查你是否从每个依赖项中获得了所需的功能。例如,这种使用cargo tree的方法显示了对于默认编译,我得到了range-set-blazenum-integernum-traits的“std”功能,以及itertoolseither的“use-std”功能:

cargo tree --edges no-dev --format "{p} {f}"
range-set-blaze v0.1.6 (O:\Projects\Science\wasmetc\wasm4) default,itertools,num-integer,num-traits,std
├── gen_ops v0.4.0
├── itertools v0.10.5 use_alloc,use_std
│   └── either v1.8.1 use_std
├── num-integer v0.1.45 std
│   └── num-traits v0.2.15 std
│       [build-dependencies]
│       └── autocfg v1.1.0
│   [build-dependencies]
│   └── autocfg v1.1.0
└── num-traits v0.2.15 std (*)

这表明,对于--features alloc --no-default-feature编译,我获得了itertools的“use_alloc”功能以及其他依赖项的“无默认”版本:

cargo tree --edges no-dev --format "{p} {f}" --features alloc --no-default-features
range-set-blaze v0.1.6 (O:\Projects\Science\wasmetc\wasm4) alloc,itertools,num-integer,num-traits
├── gen_ops v0.4.0
├── itertools v0.10.5 use_alloc
│   └── either v1.8.1
├── num-integer v0.1.45
│   └── num-traits v0.2.15
│       [build-dependencies]
│       └── autocfg v1.1.0
│   [build-dependencies]
│   └── autocfg v1.1.0
└── num-traits v0.2.15  (*)

当你认为一切正常时,使用这些命令来检查/测试本地、WASM 和嵌入式:

# test native
cargo test
cargo test --features alloc --no-default-features
# check and test WASM
cargo check --target wasm32-unknown-unknown --features alloc --no-default-features
wasm-pack test --chrome --headless --features alloc --no-default-features
# check embedded
cargo check --target thumbv7m-none-eabi --features alloc --no-default-features

这些检查嵌入,但测试嵌入呢?

规则 8:创建一个单一的嵌入式测试和一个嵌入式演示。

让我们通过创建一个综合测试和演示来发挥我们的嵌入式功能。我们将它运行在一个叫做 QEMU 的模拟器上。

测试本地 Rust 很简单。测试 WASM Rust 还可以。测试嵌入式 Rust 很困难。我们将只进行一次简单的测试。

附注 1:有关运行和模拟嵌入式 Rust 的更多信息,请参见:The Embedded Rust Book

附注 2:有关更完整的嵌入式 Rust 测试框架的想法,请参见defmt-test。遗憾的是,我无法搞清楚如何在模拟下运行它。cortex-m/testsuite项目使用了 defmt-test 的一个分支,并且可以在模拟下运行,但没有提供独立的测试 crate,并且需要三个额外的(子)项目。

附注 3:一个嵌入式测试要比没有测试好得多。我们将在本地和 WASM 级别进行其余的测试。

我们将在当前的tests文件夹内创建嵌入式测试和演示。文件将是:

tests/embedded
├── .cargo
│   └── config.toml
├── Cargo.toml
├── build.rs
├── memory.x
└── src
    └── main.rs

这里是创建文件和设置的步骤。

  1. 安装QEMU 模拟器。在 Windows 上,这涉及运行安装程序,然后手动将"C:\Program Files\qemu\"添加到你的路径中。

2. 创建一个依赖于本地项目的 tests/embedded/Cargo.toml,并包含“无默认功能”和“alloc”。这是我的:

[package]
edition = "2021"
name = "embedded"
version = "0.1.0"

[dependencies]
alloc-cortex-m = "0.4.4"
cortex-m = "0.6.0"
cortex-m-rt = "0.6.10"
cortex-m-semihosting = "0.3.3"
panic-halt = "0.2.0"# reference your local project here
range-set-blaze = { path = "../..", features = ["alloc"], default-features = false }

[[bin]]
name = "embedded"
test = false
bench = false

3. 创建一个文件 tests/embedded/src/main.rs。将你的测试代码放在“test goes here”注释之后。这是我的文件:

// based on https://github.com/rust-embedded/cortex-m-quickstart/blob/master/examples/allocator.rs
// and https://github.com/rust-lang/rust/issues/51540
#![feature(alloc_error_handler)]
#![no_main]
#![no_std]

extern crate alloc;
use alloc::string::ToString;
use alloc_cortex_m::CortexMHeap;
use core::{alloc::Layout, iter::FromIterator};
use cortex_m::asm;
use cortex_m_rt::entry;
use cortex_m_semihosting::{debug, hprintln};
use panic_halt as _;
use range_set_blaze::RangeSetBlaze;

#[global_allocator]
static ALLOCATOR: CortexMHeap = CortexMHeap::empty();
const HEAP_SIZE: usize = 1024; // in bytes

#[entry]
fn main() -> ! {
    unsafe { ALLOCATOR.init(cortex_m_rt::heap_start() as usize, HEAP_SIZE) }

    // test goes here
    let range_set_blaze = RangeSetBlaze::from_iter([100, 103, 101, 102, -3, -4]);
    assert!(range_set_blaze.to_string() == "-4..=-3, 100..=103");
    hprintln!("{:?}", range_set_blaze.to_string()).unwrap();

    // exit QEMU/ NOTE do not run this on hardware; it can corrupt OpenOCD state
    debug::exit(debug::EXIT_SUCCESS);
    loop {}
}

#[alloc_error_handler]
fn alloc_error(_layout: Layout) -> ! {
    asm::bkpt();
    loop {}
}

4. 从 [cortex-m-quickstart](https://github.com/rust-embedded/cortex-m-quickstart/tree/master) 的 GitHub 仓库复制 build.rsmemory.xtests/embedded/

5. 创建一个包含以下内容的 tests/embedded/.cargo/config.toml

[target.thumbv7m-none-eabi]
runner = "qemu-system-arm -cpu cortex-m3 -machine lm3s6965evb -nographic -semihosting-config enable=on,target=native -kernel"

[build]
target = "thumbv7m-none-eabi"

6. 通过将 tests/embedded 添加到你的工作区来更新项目的主 Cargo.toml

[workspace]
members = [".", "tests/wasm-demo", "tests/embedded"]

使用这个设置,你几乎准备好运行模拟的嵌入式了。接下来,准备好你的终端,并将编译器设置为 nightly:

# Be sure qemu is on path, e.g., set PATH="C:\Program Files\qemu\";%PATH%
cd tests/embedded
rustup override set nightly # to support #![feature(alloc_error_handler)]

现在你可以在演示应用上使用 cargo checkcargo buildcargo run。例如:

cargo run
 Finished dev [unoptimized + debuginfo] target(s) in 0.03s
     Running `qemu-system-arm -cpu cortex-m3 -machine lm3s6965evb -nographic -semihosting-config enable=on,target=native -kernel O:\Projects\Science\wasmetc\wasm4\target\thumbv7m-none-eabi\debug\embedded`
Timer with period zero, disabling
"-4..=-3, 100..=103"

当你完成这项工作时,你将成功在一个(模拟的)微控制器上运行你的项目!如果你在设置过程中遇到问题,请仔细检查这些说明。如果还是不行,可以查看 The Embedded Rust Book

完成后,请务必将编译器设置回稳定版本:

rustup override set stable

规则 9:完成 CI 测试、Cargo.toml 元数据和更新后的 README.md。

CI 测试

我们快完成了,但我们必须确保今天有效的内容明天也能有效。这就是 CI(持续集成)测试的工作。

我将我的 CI 测试设置为每次检查和每个月运行一次。如果在 GitHub 上,创建一个文件 .github/workflows/tests.yml

name: test

on:
  push:
  schedule: # run every month
    - cron: '0 0 1 * *'
  pull_request:
  workflow_dispatch:

jobs:
  test_rust:
    name: Test Rust
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v3
      - name: Setup Rust
        uses: dtolnay/rust-toolchain@master
        with:
          toolchain: stable
      - name: Setup WASM
        uses: jetli/wasm-pack-action@v0.4.0
      - name: Test Native & WASM
        run: |
          cargo clippy --verbose --all-targets --all-features -- -D warnings
          cargo test --verbose
          cargo test --features alloc --no-default-features --verbose
          wasm-pack test --chrome --headless --features alloc --no-default-features --verbose
      - name: Setup and check Embedded
        run: |
          rustup target add thumbv7m-none-eabi
          cargo check --target thumbv7m-none-eabi --features alloc --no-default-features
          rustup override set nightly
          rustup target add thumbv7m-none-eabi
          cargo check --target thumbv7m-none-eabi --features alloc --no-default-features
          sudo apt-get update && sudo apt-get install qemu qemu-system-arm
      - name: Test Embedded (in nightly)
        timeout-minutes: 3
        run: |
          cd tests/embedded
          cargo run

如果你仅仅使用 WASM,你可以省略与嵌入式相关的最后两个步骤。

附注:为什么最后一个测试会显示 timeout-minutes: 3?因为一个失败的嵌入式测试不会返回失败。相反,它会进入一个无限循环。我们通过超时来捕捉这一点。

元数据

Rust 允许你标记你的代码适用于特定的架构和环境。惯例是使用关键字和类别元数据。具体来说,根据需要将这些关键字和 类别 添加到你的 Cargo.toml 中:

[package]
#...
keywords = [
#...
    "wasm",
    "no_std",
]
categories = [
#...
    "wasm",
    "no-std",
] 

README.md

你还应该更新你的 README.md,告诉大家你支持 WASM 和嵌入式。这是我添加的内容:

The crate supports no_std, WASM, and embedded projects:

```toml

[dependencies]

range-set-blaze = { features = ["alloc"], default-features = false, version=VERSION }

```py

 *Relace VERSION with the current version.

所以,你现在有了:针对 Rust 中 WASM 和 no_std 端口的九条规则。Rust 是一个出色的语言,适用于本地、WASM 和嵌入式编程。它提供了速度、安全性,并访问到成千上万的有用 crate。遵循这九条规则,可以在几乎所有地方运行你自己的 Rust 代码。

附注:如果你对未来的文章感兴趣,请 关注我在 Medium。我写关于 Rust 和 Python 的科学编程、机器学习和统计。我通常每月写一篇文章。

Rust 代码 SIMD 加速的九条规则(第一部分)

原文:towardsdatascience.com/nine-rules-for-simd-acceleration-of-your-rust-code-part-1-c16fe639ce21?source=collection_archive---------2-----------------------#2023-12-12

通过将数据摄入在range-set-blaze库中提升 7 倍的一般经验教训。

Carl M. KadieTowards Data Science Carl M. Kadie

·

关注 发布于Towards Data Science ·17 min read·Dec 12, 2023

--

蟹通过小蟹委派进行计算 — 来源:openai.com/dall-e-2/。所有其他数据来自作者。

感谢 Ben Lichtman(B3NNY)在西雅图 Rust Meetup 中为我指明了 SIMD 的正确方向。

SIMD(单指令、多数据)操作自 2000 年代初以来一直是 Intel/AMD 和 ARM CPU 的一个特性。这些操作使你可以,例如,只用一个 CPU 操作 在单核 上将八个 i32 的数组加到另一个八个 i32 的数组上。使用 SIMD 操作大大加快了某些任务的速度。如果你没有使用 SIMD,你可能没有充分利用你 CPU 的能力。

这篇文章是“另一个 Rust 和 SIMD”文章吗?是的,也不是。是的,我确实将 SIMD 应用于一个编程问题,然后觉得有必要写一篇文章。不是,我希望这篇文章也能深入到足以指导你完成你的项目。它解释了 Rust nightly 中新提供的 SIMD 功能和设置。它包括一个 Rust SIMD 速查表。它展示了如何在不离开安全 Rust 的情况下使你的 SIMD 代码通用。它让你开始使用如 Godbolt 和 Criterion 等工具。最后,它介绍了简化过程的新 cargo 命令。

[range-set-blaze](https://crates.io/crates/range-set-blaze) crate 使用其 RangeSetBlaze::from_iter 方法来处理可能很长的整数序列。当整数是“clumpy”时,它可以比 Rust 的标准 HashSet::from_iter 快 30 倍。如果我们使用 SIMD 操作,能做到更好吗?是的!

查看 此文档 了解“clumpy”的定义。此外,当整数不不规则时会发生什么?RangeSetBlazeHashSet 慢 2 到 3 倍

对于不规则整数,RangeSetBlaze::from_slice — 基于 SIMD 操作的新方法 — 比 RangeSetBlaze::from_iter 快 7 倍。这使它比 HashSet::from_iter 快超过 200 倍。(当整数不不规则时,它仍然比 HashSet 慢 2 到 3 倍。)

在实现这一加速的过程中,我学到了九条规则,这些规则可以帮助你使用 SIMD 操作加速你的项目。

这些规则是:

  1. 使用 nightly Rust 和 core::simd,Rust 的实验性标准 SIMD 模块。

  2. CCC: 检查、控制并选择你计算机的 SIMD 能力。

  3. 学习 core::simd,但要有选择地。

  4. 头脑风暴候选算法。

  5. 使用 Godbolt 和 AI 来理解你代码的汇编,即使你不懂汇编语言。

  6. 使用内联泛型(当这不起作用时)宏,(当宏不起作用时)特性,将其推广到所有类型和 LANES。

查看 第二部分 以获取这些规则:

7. 使用 Criterion 基准测试来选择算法,并发现 LANES 应该(几乎)始终为 32 或 64。

8. 将您的最佳 SIMD 算法集成到您的项目中,并使用 *as_simd* 特别的代码处理 *i128* / *u128* ,并额外进行上下文基准测试。

9. 从项目中提取出您的最佳 SIMD 算法(目前)并选择一个可选的 cargo 特性。

旁注:为了避免含糊其辞,我称这些为“规则”,但它们当然只是建议。

规则 1:使用 nightly Rust 和 core::simd,Rust 的实验性标准 SIMD 模块。

Rust 可以通过稳定的 [core::arch](https://doc.rust-lang.org/core/arch/index.html) 模块或 nightly 的 [core::simd](https://doc.rust-lang.org/nightly/core/simd/struct.Simd.html) 模块访问 SIMD 操作。让我们比较一下它们:

**core::arch**

**core::simd**

  • Nightly

  • 令人愉快的简单和可移植。

  • 限制了向下游用户只能使用 nightly 版。

我决定选择“简单”。如果您决定选择更难的路线,首先从更简单的路径开始可能仍然是值得的。

无论哪种情况,在我们尝试在一个更大的项目中使用 SIMD 操作之前,让我们确保我们能够完全使用它们。以下是步骤:

首先,创建一个名为 simd_hello 的项目:

cargo new simd_hello
cd simd_hello

编辑 src/main.rs 以包含 (Rust playground):

// Tell nightly Rust to enable 'portable_simd'
#![feature(portable_simd)]
use core::simd::prelude::*;

// constant Simd structs
const LANES: usize = 32;
const THIRTEENS: Simd<u8, LANES> = Simd::<u8, LANES>::from_array([13; LANES]);
const TWENTYSIXS: Simd<u8, LANES> = Simd::<u8, LANES>::from_array([26; LANES]);
const ZEES: Simd<u8, LANES> = Simd::<u8, LANES>::from_array([b'Z'; LANES]);

fn main() {
    // create a Simd struct from a slice of LANES bytes
    let mut data = Simd::<u8, LANES>::from_slice(b"URYYBJBEYQVQBUBCRVGFNYYTBVATJRYY");

    data += THIRTEENS; // add 13 to each byte

    // compare each byte to 'Z', where the byte is greater than 'Z', subtract 26
    let mask = data.simd_gt(ZEES); // compare each byte to 'Z'
    data = mask.select(data - TWENTYSIXS, data);

    let output = String::from_utf8_lossy(data.as_array());
    assert_eq!(output, "HELLOWORLDIDOHOPEITSALLGOINGWELL");
    println!("{}", output);
}

接下来 —— 全面的 SIMD 功能需要 Rust 的 nightly 版本。假设您已安装了 Rust,请安装 nightly 版 (rustup install nightly)。确保您有最新的 nightly 版本 (rustup update nightly)。最后,设置此项目使用 nightly 版 (rustup override set nightly)。

您现在可以使用 cargo run 运行程序。该程序对 32 个大写字母的 ROT13 解密。通过 SIMD,程序可以同时解密所有 32 个字节。

让我们看看程序的每个部分是如何工作的。它从以下开始:

#![feature(portable_simd)]
use core::simd::prelude::*;

Rust nightly 仅在请求时提供其额外的功能(或“特性”)。 #![feature(portable_simd)] 语句请求 Rust nightly 可用新的实验性 core::simd 模块。然后,use 语句导入了模块的最重要的类型和特征。

在代码的下一部分中,我们定义了一些有用的常量:

const LANES: usize = 32;
const THIRTEENS: Simd<u8, LANES> = Simd::<u8, LANES>::from_array([13; LANES]);
const TWENTYSIXS: Simd<u8, LANES> = Simd::<u8, LANES>::from_array([26; LANES]);
const ZEES: Simd<u8, LANES> = Simd::<u8, LANES>::from_array([b'Z'; LANES]);

Simd结构体是一种特殊类型的 Rust 数组。(例如,它始终是内存对齐的。)常量LANES告诉了Simd数组的长度。from_array构造函数复制一个常规的 Rust 数组来创建一个Simd。在这种情况下,因为我们需要const Simd,所以我们构造的数组也必须是const

接下来的两行将我们加密的文本复制到data,然后对每个字母添加 13。

let mut data = Simd::<u8, LANES>::from_slice(b"URYYBJBEYQVQBUBCRVGFNYYTBVATJRYY");
data += THIRTEENS;

如果您出错了,您的加密文本长度不正好为LANES(32)怎么办?遗憾的是,编译器不会告诉您。相反,在运行程序时,from_slice将会崩溃。如果加密文本包含非大写字母怎么办?在本示例程序中,我们将忽略这种可能性。

+=操作符在Simd dataSimd THIRTEENS之间进行逐元素加法。它将结果放入data中。请记住,常规 Rust 加法的调试构建会检查溢出。但 SIMD 不会这样做。Rust 定义了 SIMD 算术运算符总是进行包装。类型为u8的值在 255 之后会包装。

巧合的是,Rot13 解密也需要包装,但是在‘Z’之后而不是在 255 之后。这里有一种编码所需 Rot13 包装的方法。它从任何值中减去 26,超出了‘Z’。

let mask = data.simd_gt(ZEES);
data = mask.select(data - TWENTYSIXS, data);

这里要求找到逐个元素的超过‘Z’的位置。然后,从所有值中减去 26。在感兴趣的位置,使用减去的值。在其他位置,使用原始值。从所有值中减去然后只使用一些看起来是不是浪费了?使用 SIMD,这不需要额外的计算机时间并且避免了跳转。因此,这种策略是高效且常见的。

程序以此方式结束:

let output = String::from_utf8_lossy(data.as_array());
assert_eq!(output, "HELLOWORLDIDOHOPEITSALLGOINGWELL");
println!("{}", output);

注意.as_array()方法。它安全地将Simd结构体转换为常规的 Rust 数组而不复制。

令我惊讶的是,这个程序在没有 SIMD 扩展的计算机上运行良好。Rust nightly 将代码编译成常规(非 SIMD)指令。但我们不仅仅想要运行“良好”,我们想要运行更快。这需要我们打开计算机的 SIMD 性能。

规则 2:CCC:检查,控制和选择您计算机的 SIMD 能力。

要使 SIMD 程序在您的计算机上运行得更快,您必须首先发现您的计算机支持哪些 SIMD 扩展。如果您有 Intel/AMD 计算机,可以使用我的[simd-detect](https://github.com/CarlKCarlK/cargo-simd-detect) cargo 命令。

运行:

rustup override set nightly
cargo install cargo-simd-detect --force
cargo simd-detect

在我的计算机上,输出如下:

extension       width                   available       enabled
sse2            128-bit/16-bytes        true            true
avx2            256-bit/32-bytes        true            false
avx512f         512-bit/64-bytes        true            false

这说明我的计算机支持sse2avx2avx512f SIMD 扩展。在其中,默认情况下,Rust 启用了普遍存在已有二十年历史的sse2扩展。

SIMD 扩展形成一个层次结构,avx512favx2之上,在sse2之上。启用更高级别的扩展也会启用较低级别的扩展。

大多数 Intel/AMD 计算机也支持十年历史的avx2扩展。您可以通过设置环境变量来启用它:

# For Windows Command Prompt
set RUSTFLAGS=-C target-feature=+avx2

# For Unix-like shells (like Bash)
export RUSTFLAGS="-C target-feature=+avx2"

“强制安装”并再次运行simd-detect,您应该看到启用了avx2

# Force install every time to see changes to 'enabled'
cargo install cargo-simd-detect --force
cargo simd-detect
extension         width                   available       enabled
sse2            128-bit/16-bytes        true            true
avx2            256-bit/32-bytes        true            true
avx512f         512-bit/64-bytes        true            false

或者,你可以打开你的机器支持的每一个 SIMD 扩展:

# For Windows Command Prompt
set RUSTFLAGS=-C target-cpu=native

# For Unix-like shells (like Bash)
export RUSTFLAGS="-C target-cpu=native"

在我的机器上,这启用了 avx512f,这是一种新的 SIMD 扩展,由一些英特尔计算机和少数 AMD 计算机支持。

你可以将 SIMD 扩展设置回它们的默认值(在英特尔/AMD 上是 sse2):

# For Windows Command Prompt
set RUSTFLAGS=

# For Unix-like shells (like Bash)
unset RUSTFLAGS

你可能会想知道为什么 target-cpu=native 不是 Rust 的默认值。问题在于使用 avx2avx512f 创建的二进制文件不能在缺少这些 SIMD 扩展的计算机上运行。因此,如果只为自己使用编译,请使用 target-cpu=native。然而,如果为其他人编译,请慎重选择 SIMD 扩展,并告知人们你所假设的 SIMD 扩展级别。

令人高兴的是,无论你选择哪种 SIMD 扩展级别,Rust 的 SIMD 支持都非常灵活,你可以轻松更改你的决策。接下来让我们详细了解在 Rust 中使用 SIMD 编程的细节。

规则 3:学习 core::simd,但要有选择性。

要使用 Rust 的新 [core::simd](https://doc.rust-lang.org/nightly/core/simd/index.html) 模块,你应该学习选择的构建模块。这里有一个速查表,包含我发现最有用的结构体、方法等。每个项目都包含到其文档的链接。

结构体

  • [Simd](https://doc.rust-lang.org/nightly/core/simd/struct.Simd.html) - 一个特殊的、对齐的、固定长度的数组,由[SimdElement](https://doc.rust-lang.org/std/simd/trait.SimdElement.html)组成。我们将数组中的位置及其存储的元素称为“lane”。默认情况下,我们复制 Simd 结构体而不是引用它们。

  • [Mask](https://doc.rust-lang.org/nightly/core/simd/struct.Mask.html) - 一种特殊的布尔数组,显示每个 lane 的包含/排除情况。

SimdElements

  • 浮点类型:f32f64

  • 整数类型:i8u8i16u16i32u32i64u64isizeusize

  • 但不包括 [*i128*](https://github.com/rust-lang/portable-simd/issues/108), [*u128*](https://github.com/rust-lang/portable-simd/issues/108)

**Simd** 构造函数

  • [Simd::from_array](https://doc.rust-lang.org/nightly/core/simd/struct.Simd.html#method.from_array) - 通过复制固定长度数组创建一个 Simd 结构体。

  • [Simd::from_slice](https://doc.rust-lang.org/nightly/core/simd/struct.Simd.html#method.from_slice) - 通过复制切片的前 LANE 个元素创建一个 Simd<T,LANE> 结构体。

  • [Simd::splat](https://doc.rust-lang.org/nightly/core/simd/struct.Simd.html#method.splat) - 将单个值复制到 Simd 结构的所有 lane 中。

  • [slice::as_simd](https://doc.rust-lang.org/nightly/core/simd/struct.Simd.html#method.to_simd) - 安全地将常规切片转换为对齐的 Simd 切片(加上不对齐的剩余部分),而不进行复制。

**Simd** 转换

  • [Simd::as_array](https://doc.rust-lang.org/nightly/core/simd/struct.Simd.html#method.as_array) - 在不复制的情况下,将Simd结构体安全地转换为普通数组引用。

Simd 方法和运算符

  • [simd[i]](https://doc.rust-lang.org/nightly/core/simd/struct.Simd.html#method.index) - 从Simd的一个通道中提取一个值。

  • [simd + simd](https://doc.rust-lang.org/core/simd/struct.Simd.html#impl-Add%3C%26'rhs+Simd%3CT,+LANES%3E%3E-for-%26'lhs+Simd%3CT,+LANES%3E) - 执行两个Simd结构体的元素级加法。同时支持-*/%、余数、按位与、按位或、异或、按位非、位移。

  • [simd += simd](https://doc.rust-lang.org/core/simd/struct.Simd.html#impl-AddAssign%3CU%3E-for-Simd%3CT,+LANES%3E) - 将另一个Simd结构体加到当前结构体上,进行就地操作。其他运算符也受支持。

  • [Simd::simd_gt](https://doc.rust-lang.org/nightly/core/simd/struct.Simd.html#method.simd_gt) - 比较两个Simd结构体,返回一个Mask,指示第一个结构体的哪些元素大于第二个结构体的元素。同时支持simd_ltsimd_lesimd_gesimd_ltsimd_eqsimd_ne

  • [Simd::rotate_elements_left](https://doc.rust-lang.org/nightly/core/simd/struct.Simd.html#method.rotate_elements_left) - 将Simd结构体的元素向左旋转指定的数量。同时支持rotate_elements_right

  • [simd_swizzle!(simd, indexes)](https://doc.rust-lang.org/std/simd/prelude/macro.simd_swizzle.html) - 根据指定的常量索引重新排列Simd结构体的元素。

  • [simd == simd](https://doc.rust-lang.org/nightly/core/simd/struct.Simd.html#impl-Eq-for-Simd%3CT,+N%3E) - 检查两个Simd结构体之间的相等性,返回一个普通的bool结果。

  • [Simd::reduce_and](https://doc.rust-lang.org/nightly/core/simd/struct.Simd.html#method.reduce_and) - 执行Simd结构体所有通道的按位与归约。同时支持:reduce_orreduce_xorreduce_maxreduce_minreduce_sum(但不支持reduce_eq)。

掩码 方法和运算符

  • [Mask::select](https://doc.rust-lang.org/nightly/core/simd/struct.Mask.html#method.select) - 根据掩码从两个Simd结构体中选择元素。

  • [Mask::all](https://doc.rust-lang.org/nightly/core/simd/struct.Mask.html#method.all) - 指示掩码是否全为true

  • [Mask::any](https://doc.rust-lang.org/nightly/core/simd/struct.Mask.html#method.all) - 指示掩码是否包含任何true

关于通道的一切

  • [Simd::LANES](https://doc.rust-lang.org/nightly/core/simd/struct.Simd.html#associatedconstant.LANES) - 一个常量,表示Simd结构体中的元素(通道)数量。

  • [SupportedLaneCount](https://doc.rust-lang.org/nightly/core/simd/trait.SupportedLaneCount.html) - 指示允许的LANES值。通过泛型使用。

  • [simd.lanes](https://doc.rust-lang.org/core/simd/struct.Simd.html#method.lanes) - 常量方法,告诉Simd结构体的通道数量。

低级对齐、偏移量等

尽可能使用 [*to_simd*](https://doc.rust-lang.org/nightly/core/simd/struct.Simd.html#method.to_simd) 代替。

  • [mem::size_of](https://doc.rust-lang.org/std/mem/fn.size_of.html)[mem::align_of](https://doc.rust-lang.org/std/mem/fn.align_of.html)[mem::align_to](https://doc.rust-lang.org/std/mem/fn.align_to.html)[intrinsics::offset](https://doc.rust-lang.org/std/intrinsics/fn.offset.html)[pointer::read_unaligned](https://doc.rust-lang.org/std/primitive.pointer.html#method.read_unaligned)(不安全),[pointer::write_unaligned](https://doc.rust-lang.org/std/primitive.pointer.html#method.write_unaligned)(不安全),[mem::transmute](https://doc.rust-lang.org/std/mem/fn.transmute.html)(不安全,const)

更多,也许感兴趣的

  • [deinterleave](https://doc.rust-lang.org/nightly/core/simd/struct.Simd.html#method.deinterleave)[gather_or](https://doc.rust-lang.org/nightly/core/simd/struct.Simd.html#method.gather_or)[reverse](https://doc.rust-lang.org/nightly/core/simd/struct.Simd.html#method.reverse)[scatter](https://doc.rust-lang.org/nightly/core/simd/struct.Simd.html#method.scatter)

有了这些构建模块,现在是时候创造一些东西了。

规则 4:头脑风暴候选算法。

你想加速什么?你事先不会知道哪种 SIMD 方法(如果有的话)最好。因此,你应该创建许多算法,然后分析(规则 5)和基准测试(规则 7)它们。

我希望加速 [range-set-blaze](https://crates.io/crates/range-set-blaze),一个用于操作“clumpy”整数集的 crate。我希望创建 is_consecutive,一个用于检测连续整数块的函数,会很有用。

背景: Crate *range-set-blaze* *用于处理“clumpy”整数。这里的“clumpy”意味着用于表示数据的范围数量与输入整数的数量相比较少。例如,这些 1002 个输入整数

100, 101, ..., 489, 499, 501, 502, ..., 998, 999, 999, 100, 0

最终变成三个 Rust 范围:

0..=0, 100..=499, 501..=999

(在内部,[*RangeSetBlaze*](https://docs.rs/range-set-blaze/latest/range_set_blaze/struct.RangeSetBlaze.html#) 结构将整数集表示为存储在高效缓存 BTreeMap 中的排序不相交范围列表。)

尽管允许输入整数是无序和冗余的,但我们期望它们通常是“好的”。RangeSetBlaze 的 from_iter 构造函数已经利用这一期望通过组合相邻整数来分组。例如,from_iter 首先将这 1002 个输入整数转换为四个范围

*100..=499, 501..=999, 100..=100, 0..=0.*

在最小的恒定内存使用下,独立于输入大小。然后,它对这些减少的范围进行排序和合并。

我想知道是否可以通过快速找到(一些)连续整数来加速从类似数组的输入构建的 from_slice 方法。例如,是否可以在最小的恒定内存下,将 1002 个输入整数 转换为五个 Rust 范围:

*100..=499, 501..=999, 999..=999, 100..=100, 0..=0.*

如果是这样, *from_iter* 可以快速完成处理。

让我们先用常规 Rust 编写 is_consecutive

pub const LANES: usize = 16;
pub fn is_consecutive_regular(chunk: &[u32; LANES]) -> bool {
    for i in 1..LANES {
        if chunk[i - 1].checked_add(1) != Some(chunk[i]) {
            return false;
        }
    }
    true
}

算法只是顺序遍历数组,检查每个值是否比前一个值多 1。它还避免了溢出。

遍历这些项似乎很简单,我不确定 SIMD 是否能做得更好。这是我的第一次尝试:

Splat0

use std::simd::prelude::*;

const COMPARISON_VALUE_SPLAT0: Simd<u32, LANES> =
    Simd::from_array([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]);

pub fn is_consecutive_splat0(chunk: Simd<u32, LANES>) -> bool {
    if chunk[0].overflowing_add(LANES as u32 - 1) != (chunk[LANES - 1], false) {
        return false;
    }
    let added = chunk + COMPARISON_VALUE_SPLAT0;
    Simd::splat(added[0]) == added
}

这里是它的计算概要:

来源:这张图及所有后续图片均由作者提供。

它首先(不必要地)检查第一个和最后一个项目是否相隔 15。然后,它通过将 15 加到第 0 项,将 14 加到下一个项,以此类推来创建 added。最后,为了查看 added 中的所有项是否相同,它基于 added 的第 0 项创建一个新的 Simd,然后进行比较。请记住,splat 从一个值创建一个 Simd 结构。

Splat1 & Splat2

当我向 Ben Lichtman 提到 is_consecutive 问题时,他独立地提出了这个,即 Splat1:

const COMPARISON_VALUE_SPLAT1: Simd<u32, LANES> =
    Simd::from_array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);

pub fn is_consecutive_splat1(chunk: Simd<u32, LANES>) -> bool {
    let subtracted = chunk - COMPARISON_VALUE_SPLAT1;
    Simd::splat(chunk[0]) == subtracted
}

Splat1 从 chunk 中减去比较值,并检查结果是否与 chunk 的第一个元素相同,经过 splat。

他还提出了一个变体,称为 Splat2,它 splat subtracted 的第一个元素,而不是 chunk。这似乎可以避免一次内存访问。

我相信你一定在想这些方法中哪一个最好,但在我们讨论这个问题之前,让我们再看两个候选者。

Swizzle

Swizzle 类似于 Splat2,但使用 simd_swizzle! 而不是 splat。宏 simd_swizzle! 通过根据索引数组重新排列旧 Simd 的通道来创建一个新的 Simd

pub fn is_consecutive_sizzle(chunk: Simd<u32, LANES>) -> bool {
    let subtracted = chunk - COMPARISON_VALUE_SPLAT1;
    simd_swizzle!(subtracted, [0; LANES]) == subtracted
}

Rotate

这个方法不同。我对它寄予厚望。

const COMPARISON_VALUE_ROTATE: Simd<u32, LANES> =
    Simd::from_array([4294967281, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]);

pub fn is_consecutive_rotate(chunk: Simd<u32, LANES>) -> bool {
    let rotated = chunk.rotate_elements_right::<1>();
    chunk - rotated == COMPARISON_VALUE_ROTATE
}

这个想法是将所有元素向右旋转一个位置。然后,我们从 rotated 中减去原始的 chunk。如果输入是连续的,结果应该是“−15”后跟所有 1。 (使用包装减法,-15 是 4294967281u32。)

现在我们有了候选者,让我们开始评估它们。

规则 5:使用 Godbolt 和 AI 来理解你的代码的汇编语言,即使你不知道汇编语言。

我们将通过两种方式评估这些候选者。首先,在这个规则中,我们将查看从代码生成的汇编语言。其次,在规则 7 中,我们将基准测试代码的速度。

如果你不知道汇编语言,也不要担心,你仍然可以从中获得一些信息。

查看生成的汇编语言的最简单方法是使用 Compiler Explorer, AKA Godbolt。它在不使用外部 crate 的简短代码片段上效果最佳。它看起来像这样:

参考上图中的数字,按照以下步骤使用 Godbolt:

  1. 使用你的网页浏览器打开 godbolt.org

  2. 添加一个新的源编辑器。

  3. 选择 Rust 作为你的语言。

  4. 粘贴感兴趣的代码。将感兴趣的函数设为公共(pub fn)。不包括主函数或不需要的函数。该工具不支持外部包(external crates)。

  5. 添加新的编译器。

  6. 将编译器版本设置为 nightly。

  7. 设置选项(暂时)为-C opt-level=3 -C target-feature=+avx512f.

  8. 如果有错误,请查看输出。

  9. 如果您想分享或保存工具的状态,请点击“分享”。

从上面的图像可以看出,Splat2 和 Sizzle 完全相同,因此我们可以将 Sizzle 从考虑中删除。如果您打开我的 Godbolt 会话的副本,您还会看到大多数函数编译为大致相同数量的汇编操作。例外是 Regular ——它更长——和 Splat0——它包括早期检查。

在汇编中,512 位寄存器以 ZMM 开头。256 位寄存器以 YMM 开头。128 位寄存器以 XMM 开头。如果您想更好地理解生成的汇编,请使用 AI 工具生成注释。例如,我在这里向Bing Chat询问关于 Splat2 的问题:

尝试不同的编译器设置,包括-C target-feature=+avx2,然后完全不使用target-feature

较少的汇编操作不一定意味着更快的速度。然而,查看汇编代码确实让我们确认编译器至少尝试使用 SIMD 操作、内联常量引用等。同样,像 Splat1 和 Swizzle 一样,有时它可以让我们知道两个候选项何时相同。

您可能需要比 Godbolt 提供的反汇编功能更多的功能,例如处理使用外部包的代码能力。B3NNY 推荐给我 cargo 工具 [cargo-show-asm](https://github.com/pacak/cargo-show-asm)。我试过了,发现使用起来相当容易。

range-set-blaze包必须处理超出u32的整数类型。此外,我们必须选择一定数量的 LANES,但我们没有理由认为 16 LANES 总是最好的。为了满足这些需求,在下一条规则中我们将概括代码。

规则 6:广义应用于所有类型和 LANES,包括内联泛型(in-lined generics),(当它不起作用时)宏(macros),以及(当它不起作用时)特性(traits)。

让我们首先用泛型概括 Splat1。

#[inline]
pub fn is_consecutive_splat1_gen<T, const N: usize>(
    chunk: Simd<T, N>,
    comparison_value: Simd<T, N>,
) -> bool
where
    T: SimdElement + PartialEq,
    Simd<T, N>: Sub<Simd<T, N>, Output = Simd<T, N>>,
    LaneCount<N>: SupportedLaneCount,
{
    let subtracted = chunk - comparison_value;
    Simd::splat(chunk[0]) == subtracted
}

首先注意#[inline]属性。对效率很重要,我们将几乎在所有这些小函数上使用它。

上面定义的函数is_consecutive_splat1_gen看起来很棒,除了它需要第二个输入,称为comparison_value,我们尚未定义。

如果您不需要通用常量comparison_value,我羡慕您。如果您愿意,您可以跳过下一条规则。同样地,如果您正在未来阅读此内容,并且创建通用常量comparison_value就像您个人机器人做家务一样轻松,那我就双倍羡慕您。

我们可以尝试创建一个comparison_value_splat_gen,它是通用的和 const 的。不幸的是,From<usize>和替代的T::One都不是 const,所以这个方法行不通:

// DOESN'T WORK BECAUSE From<usize> is not const
pub const fn comparison_value_splat_gen<T, const N: usize>() -> Simd<T, N>
where
    T: SimdElement + Default + From<usize> + AddAssign,
    LaneCount<N>: SupportedLaneCount,
{
    let mut arr: [T; N] = [T::from(0usize); N];
    let mut i_usize = 0;
    while i_usize < N {
        arr[i_usize] = T::from(i_usize);
        i_usize += 1;
    }
    Simd::from_array(arr)
}

宏是无赖的最后避难所。因此,让我们使用宏:

#[macro_export]
macro_rules! define_is_consecutive_splat1 {
    ($function:ident, $type:ty) => {
        #[inline]
        pub fn $function<const N: usize>(chunk: Simd<$type, N>) -> bool
        where
            LaneCount<N>: SupportedLaneCount,
        {
            define_comparison_value_splat!(comparison_value_splat, $type);

            let subtracted = chunk - comparison_value_splat();
            Simd::splat(chunk[0]) == subtracted
        }
    };
}
#[macro_export]
macro_rules! define_comparison_value_splat {
    ($function:ident, $type:ty) => {
        pub const fn $function<const N: usize>() -> Simd<$type, N>
        where
            LaneCount<N>: SupportedLaneCount,
        {
            let mut arr: [$type; N] = [0; N];
            let mut i = 0;
            while i < N {
                arr[i] = i as $type;
                i += 1;
            }
            Simd::from_array(arr)
        }
    };
}

这使我们能够在任何特定元素类型和所有 LANES 上运行(Rust Playground):

define_is_consecutive_splat1!(is_consecutive_splat1_i32, i32);

let a: Simd<i32, 16> = black_box(Simd::from_array(array::from_fn(|i| 100 + i as i32)));
let ninety_nines: Simd<i32, 16> = black_box(Simd::from_array([99; 16]));
assert!(is_consecutive_splat1_i32(a));
assert!(!is_consecutive_splat1_i32(ninety_nines));

遗憾的是,对于range-set-blaze来说还不够。它需要在所有元素类型(而不仅仅是一种)和(理想情况下)所有 LANES(而不仅仅是一个 LANE)上运行。

幸运的是,有一个解决方法,再次依赖于宏。它还利用了我们只需要支持有限类型列表的事实,即:i8i16i32i64isizeu8u16u32u64usize。如果您需要同时(或者替代地)支持f32f64,那也没问题。

另一方面,如果您需要支持i128u128,那可能就没有办法了。core::simd模块不支持它们。在第 8 条规则中,我们将看到range-set-blaze如何通过牺牲性能来解决这个问题。

这个解决方法定义了一个新的 trait,这里称为IsConsecutive。然后,我们使用一个宏(调用一个宏,再调用一个宏)来在这 10 种感兴趣的类型上实现这个 trait。

pub trait IsConsecutive {
    fn is_consecutive<const N: usize>(chunk: Simd<Self, N>) -> bool
    where
        Self: SimdElement,
        Simd<Self, N>: Sub<Simd<Self, N>, Output = Simd<Self, N>>,
        LaneCount<N>: SupportedLaneCount;
}

macro_rules! impl_is_consecutive {
    ($type:ty) => {
        impl IsConsecutive for $type {
            #[inline] // very important
            fn is_consecutive<const N: usize>(chunk: Simd<Self, N>) -> bool
            where
                Self: SimdElement,
                Simd<Self, N>: Sub<Simd<Self, N>, Output = Simd<Self, N>>,
                LaneCount<N>: SupportedLaneCount,
            {
                define_is_consecutive_splat1!(is_consecutive_splat1, $type);
                is_consecutive_splat1(chunk)
            }
        }
    };
}

impl_is_consecutive!(i8);
impl_is_consecutive!(i16);
impl_is_consecutive!(i32);
impl_is_consecutive!(i64);
impl_is_consecutive!(isize);
impl_is_consecutive!(u8);
impl_is_consecutive!(u16);
impl_is_consecutive!(u32);
impl_is_consecutive!(u64);
impl_is_consecutive!(usize);

现在我们可以调用完全通用的代码(Rust Playground):

// Works on i32 and 16 lanes
let a: Simd<i32, 16> = black_box(Simd::from_array(array::from_fn(|i| 100 + i as i32)));
let ninety_nines: Simd<i32, 16> = black_box(Simd::from_array([99; 16]));

assert!(IsConsecutive::is_consecutive(a));
assert!(!IsConsecutive::is_consecutive(ninety_nines));

// Works on i8 and 64 lanes
let a: Simd<i8, 64> = black_box(Simd::from_array(array::from_fn(|i| 10 + i as i8)));
let ninety_nines: Simd<i8, 64> = black_box(Simd::from_array([99; 64]));

assert!(IsConsecutive::is_consecutive(a));
assert!(!IsConsecutive::is_consecutive(ninety_nines));

使用这种技术,我们可以创建多个完全通用于类型和 LANES 的候选算法。接下来,是时候进行基准测试,看看哪些算法最快。

这些是向 Rust 添加 SIMD 代码的前六条规则。在第二部分中,我们将看到第 7 到第 9 条规则。这些规则将涵盖如何选择算法和设置 LANES,以及如何将 SIMD 操作集成到现有代码中(重要的是),如何使其可选。第二部分结束时将讨论何时/如果应该使用 SIMD 以及改进 Rust 的 SIMD 体验的想法。我希望能在那里见到你。

关注 Carl 在 Medium 上的文章。我写关于 Rust 和 Python 中的科学编程,机器学习和统计学的文章。我倾向于每个月写一篇文章。

你的 Rust 代码的 SIMD 加速九大规则(第二部分)

原文:towardsdatascience.com/nine-rules-for-simd-acceleration-of-your-rust-code-part-2-6a104b3be6f3?source=collection_archive---------7-----------------------#2023-12-15

从将 range-set-blaze Crate 数据摄取速度提升 7 倍的经验中得到的一般性教训

Carl M. KadieTowards Data Science Carl M. Kadie

·

关注 发表在 Towards Data Science ·9 分钟阅读·2023 年 12 月 15 日

--

一只螃蟹将计算任务委托给小螃蟹——来源:openai.com/dall-e-2/。所有其他图形来自作者。

感谢 Ben Lichtman (B3NNY) 在 Seattle Rust Meetup 上为我指引了正确的 SIMD 方向。

这是关于在 Rust 中创建 SIMD 代码的文章的第二部分。(参见 第一部分。)我们将查看第 7 到第 9 条规则:

  • 7. 使用 Criterion 基准测试来选择算法,并发现通道数(LANES)应该(几乎总是)为 32 或 64。

  • 8. 将你最佳的 SIMD 算法集成到你的项目中,使用 as_simd、专门的 i128/u128 代码和额外的上下文基准测试。

  • 9. 使用可选的 cargo 特性将你最佳的 SIMD 算法从项目中抽取出来(暂时)。

回顾规则 1 到 6:

  1. 使用 nightly Rust 和 core::simd,这是 Rust 的实验性标准 SIMD 模块。

  2. CCC:检查、控制和选择你计算机的 SIMD 能力。

  3. 学习 core::simd,但要有选择地学习。

  4. 头脑风暴候选算法。

  5. 使用 Godbolt 和 AI 来理解你代码的汇编,即使你不知道汇编语言。

  6. 通过内联泛型(如果这不起作用)宏(如果这不起作用)特质来通用化到所有类型和通道数。

这些规则基于我尝试加速 [range-set-blaze](https://crates.io/crates/range-set-blaze) 的经验,这是一个用于操作“密集”整数集合的 Rust crate。

请回忆规则 6,来自于 第一部分,展示了如何使 Rust SIMD 算法在类型和通道数上完全通用。接下来我们需要选择算法并设置通道数。

规则 7:使用 Criterion 基准测试来选择算法,并发现通道数(LANES)应该(几乎总是)为 32 或 64。

在这一规则中,我们将看到如何使用流行的 criterion crate 来基准测试和评估我们的算法及选项。在 range-set-blaze 的上下文中,我们将进行评估:

  • 5 种算法 — Regular、Splat0、Splat1、Splat2、Rotate

  • 3 种 SIMD 扩展级别 — sse2(128 位)、avx2(256 位)、avx512f(512 位)

  • 10 种元素类型 — i8u8i16u16i32u32i64u64isizeusize

  • 5 种通道数量 — 4、8、16、32、64

  • 4 种输入长度 — 1024;10,240;102,400;1,024,000

  • 2 个 CPU — AMD 7950X 配有 avx512f,Intel i5–8250U 配有 avx2

基准测试测量每种组合的平均运行时间。然后,我们计算每秒传输的兆字节(Mbytes/sec)。

请参阅这篇 新伴随文章,了解如何开始使用 Criterion。那篇文章还展示了如何推动(滥用?)Criterion 来测量编译器设置的效果,例如 SIMD 扩展级别。

运行基准测试会生成一个 5000 行的 *.csv 文件,内容如下:

Group,Id,Parameter,Mean(ns),StdErr(ns)
vector,regular,avx2,256,i16,16,16,1024,291.47,0.080141
vector,regular,avx2,256,i16,16,16,10240,2821.6,3.3949
vector,regular,avx2,256,i16,16,16,102400,28224,7.8341
vector,regular,avx2,256,i16,16,16,1024000,287220,67.067
vector,regular,avx2,256,i16,16,32,1024,285.89,0.59509
...

此文件适合通过 电子表格数据透视表 或数据框工具如 Polars 进行分析。

算法和通道

这里是一个 Excel 数据透视表,显示 — 对于每种算法 — 吞吐量(MBytes/sec)与 SIMD 通道数的关系。该表格对 SIMD 扩展级别、元素类型和输入长度进行了吞吐量的平均计算。

在我的 AMD 桌面机器上:

在一台 Intel 笔记本电脑上:

表格显示 Splat1 和 Splat2 表现最佳。它们还显示更多的 lanes 始终在 32 或 64 之间更好。

例如,*sse2*(128 位宽)如何处理 64 条*i64*(4096 位宽)的数据?Rust 的*core::simd*模块通过自动而高效地将 4096 位分成 32 个 128 位块,使这一切成为可能。将 32 个 128 位块一起处理(显然)能够进行超越独立处理 128 位块的优化。

SIMD 扩展级别

让我们将 LANES 设置为 64,并比较 AMD 机器上不同的 SIMD 扩展级别。表格展示了各元素类型和输入长度下的平均吞吐量。

在我的 AMD 机器上,当使用 64 条 lanes 时,sse2最慢。比较avx2avx512f时,结果各有不同。再一次,算法 Splat1 和 Splat2 表现最好。

元素类型

接下来,我们将 SIMD 扩展级别设置为avx512f并比较不同的元素类型。我们将LANES保持在 64,并对输入长度进行平均吞吐量测试。

我们看到按位、32 位和 64 位元素的处理速度最快。(然而,按元素来说,更小的类型更快。)Splat1 和 Splat2 是最快的算法,其中 Splat1 表现略好。

输入长度

最后,让我们将元素类型设置为i32,看看输入长度与吞吐量的关系。

我们看到所有的 SIMD 算法在 100 万个输入下表现差不多。Splat1 显然在短输入上表现优于其他算法。

看起来更短的处理速度比更长的更快。这可能是缓存的结果,或者可能是基准测试丢弃未对齐数据的结果。

基准测试结论

基于这些基准测试,我们将使用 Splat1 算法。现在,我们将 LANES 设置为 32 或 64,但请参见下一条规则以了解复杂情况。最后,我们建议用户将其 SIMD 扩展级别设置为至少avx2

规则 8:将您最好的 SIMD 算法集成到您的项目中,使用as_simd,为i128/u128编写特殊代码,并进行额外的上下文基准测试。

as_simd

在添加 SIMD 支持之前,RangeSetBlaze的主要构造函数是from_iter

let a = RangeSetBlaze::from_iter([1, 2, 3]);

然而,SIMD 操作在数组上效果最佳,而不是在迭代器上。此外,从数组构造RangeSetBlaze通常是一个自然的操作,因此我添加了一个新的from_slice构造函数:

#[inline]
    pub fn from_slice(slice: impl AsRef<[T]>) -> Self {
        T::from_slice(slice)
    }

新构造函数对每个整数的from_slice方法进行内联调用。对于所有整数类型,除了i128/u128,接下来的调用是:

let (prefix, middle, suffix) = slice.as_simd();

Rust 的夜间版as_simd方法安全快速地将切片转换为:

  1. 一个未对齐的prefix——我们像之前一样用from_iter处理它。

  2. middle,一个对齐的Simd结构块数组

  3. 一个未对齐的suffix——我们像之前一样用from_iter处理它。

middle视为将我们的输入整数分成大小为 16 的块(或者LANES设置的任何大小)。然后,我们通过is_consecutive函数迭代这些块,寻找true的连续段。每个连续段成为一个单一的范围。例如,从 1000 到 1159(含)的 160 个连续整数将被识别并替换为一个 Rust RangeInclusive 1000..=1159。然后,这个范围由from_iter处理,比from_iter处理 160 个单独整数要快得多。当is_consecutive返回false时,我们退回到使用from_iter处理块中的单独整数。

i128/u128

我们如何处理core::simd不处理的类型数组,即i128/u128?目前,我只是用较慢的from_iter来处理它们。

上下文基准测试

作为最后一步,在你的主要代码的上下文中对 SIMD 代码进行基准测试,理想情况下使用具有代表性的数据。

range-set-blaze库已经包括了基准测试。一个基准测试测量了在不同块大小下处理 1,000,000 个整数的性能。平均块大小范围从 1(无块)到 100,000 块。让我们在LANES设置为 4、8、16、32 和 64 时运行该基准测试。我们将使用算法 Splat1 和 SIMD 扩展级别avx512f

对于每种块大小,条形图显示了处理 1,000,000 个整数的相对速度。对于每种块大小,最快的LANES设置为 100%。

我们看到,对于大小为 10 和 100 的块,LANES=4 是最好的。然而,对于大小为 100,000 的块,LANES=4 的表现比最佳值差 4 倍。在另一个极端,LANES=64 在大小为 100,000 的块上表现良好,但在大小为 100 和 1000 的块上分别比最佳值差 1.8 倍和 1.5 倍。

我决定将LANES设置为 16。它对于大小为 1000 的块是最好的。此外,它的表现从未比最佳值差超过 1.25 倍。

使用此设置,我们可以运行其他基准测试。下图显示了各种范围集库(包括range-set-blaze)在相同任务上工作的情况——处理 1,000,000 个具有不同块大小的整数。y轴是毫秒,值越低越好。

对于大小为 1000 的块,现有的RangeSetBlaze::into_iter方法(红色)已经比 HashSet(橙色)快 30 倍。注意,刻度是对数的。使用avx512f,新的 SIMD 加速的RangeSetBlaze::into_slice算法(浅蓝色)比 HashSet 快 230 倍。使用sse2(深蓝色),快 220 倍。使用avx2(黄色),快 180 倍。在这个基准测试中,与RangeSetBlaze::into_iter相比,avx512fRangeSetBlaze::into_slice快 7 倍。

我们还应该考虑最坏的情况,即处理没有聚集的数据。我进行了基准测试。结果显示,现有的RangeSetBlaze::into_iter比 HashSet 慢约 2.2 倍。新的RangeSetBlaze::into_slice比 HashSet 慢 2.4 倍。

所以,总的来说,新 SIMD 代码为假定数据有聚集的情况提供了巨大的优势。如果假设是错误的,它会更慢,但不会灾难性地慢。

由于 SIMD 代码集成到我们的项目中,我们准备好发布了,对吗?可惜不是。因为我们的代码依赖于 Rust 夜间版本,我们应该使其可选。我们将在下一条规则中看到如何做到这一点。

规则 9:暂时将你最佳的 SIMD 算法从项目中分离(使用可选的 cargo 功能)。

我们美丽的新 SIMD 代码依赖于 Rust 夜间版本,它会发生变化。要求用户依赖 Rust 夜间版本是不合理的。(另外,当事情出错时收到抱怨也会很烦人。)解决方案是将 SIMD 代码隐藏在 cargo 功能之后。

功能、功能、功能——在处理 SIMD 和 Rust 的上下文中,“功能”一词有三种不同的含义。首先,“CPU/target features”——这些描述了 CPU 的能力,包括它支持的 SIMD 扩展。见[*target-feature*](https://doc.rust-lang.org/std/arch/index.html) [*is_x86_feature_detected!*](https://doc.rust-lang.org/std/arch/index.html)。第二,“nightly feature gates”——Rust 通过功能开关控制 Rust 夜间版本中新语言功能的可见性。例如,[*#![feature(portable_simd)]*](https://github.com/rust-lang/portable-simd)。第三,“cargo features”——这些允许任何 Rust crate 或库提供/限制其部分功能的访问。你在*Cargo.toml*中看到这些,例如,当你添加对*itertools/use_std*的依赖时。

以下是range-set-blaze crate 采取的步骤,以使夜间依赖的 SIMD 代码可选:

  • Cargo.toml中,定义一个与 SIMD 代码相关的 cargo 功能:
[features]
from_slice = []
  • 在顶部的lib.rs文件中,设置夜间portable_simd功能开关,依赖于from_slice的 cargo 功能:
#![cfg_attr(feature = "from_slice", feature(portable_simd))]
  • 使用条件编译属性,例如#[cfg(feature = “from_slice”)],来选择性地包含 SIMD 代码。这包括测试。
/// Creates a [`RangeSetBlaze`] from a collection of integers. It is typically many
/// times faster than [`from_iter`][1]/[`collect`][1].
/// On a representative benchmark, the speed up was 6×.
///
/// **Warning: Requires the nightly compiler. Also, you must enable the `from_slice`
/// feature in your `Cargo.toml`. For example, with the command:**
/// ```bash

/// cargo add range-set-blaze --features "from_slice"

/// ```py
///
/// **Caution**: Compiling with `-C target-cpu=native` optimizes the binary for your current CPU architecture,
/// which may lead to compatibility issues on other machines with different architectures.
/// This is particularly important for distributing the binary or running it in varied environments.
/// [1]: struct.RangeSetBlaze.html#impl-FromIterator<T>-for-RangeSetBlaze<T>
#[cfg(feature = "from_slice")]
#[inline]
pub fn from_slice(slice: impl AsRef<[T]>) -> Self {
    T::from_slice(slice)
}
  • 如上文档所示,向文档中添加警告和注意事项。

  • 使用--features from_slice来检查或测试你的 SIMD 代码。

cargo check --features from_slice
cargo test --features from_slice
  • 使用--all-features运行所有测试,生成所有文档,并发布所有 cargo 功能:
cargo test --all-features --doc
cargo doc --no-deps --all-features --open
cargo publish --all-features --dry-run

结论

所以,这就是了:向 Rust 代码中添加 SIMD 操作的九条规则。这个过程的简便性反映了core::simd库的卓越设计。你应该在适用的地方总是使用 SIMD 吗?最终,应该的,当库从 Rust nightly 版本转到稳定版时。目前,在 SIMD 的性能优势至关重要的地方使用它,或者让它的使用成为可选。

有关提升 Rust 中 SIMD 体验的想法?core::simd 的质量已经很高,主要需要的是将其稳定化。

感谢你和我一起探讨 SIMD 编程。我希望如果你有一个适合 SIMD 的问题,这些步骤能帮助你加速处理。

关注 Carl 的 Medium 文章。我在 Rust 和 Python 的科学编程、机器学习以及统计学方面写作。我倾向于每月撰写一篇文章。

用 Dafny 正式验证 Rust 算法的九个规则(第一部分)

原文:towardsdatascience.com/nine-rules-to-formally-validate-rust-algorithms-with-dafny-part-1-5cb8c8a0bb92?source=collection_archive---------3-----------------------#2023-10-04

验证 range-set-blaze 木板的经验教训

Carl M. KadieTowards Data Science Carl M. Kadie

·

Follow 发表于 Towards Data Science ·14 min 阅读·2023 年 10 月 4 日

--

由 Carl M. Kadie 和 Divyanshu Ranjan 撰写

蟹证明毕达哥拉斯定理 — 源自:openai.com/dall-e-2/ & CC BY-SA 3.0 File:Pythagorean.svg

我的 Rust crate [range-set-blaze](https://crates.io/crates/range-set-blaze) 依赖于一个名为internal_add的关键函数。该函数应当将一系列整数插入到 crate 的数据结构中。但它是否正确地完成了这个任务?当然,我会进行测试,但测试可能会漏掉错误。理想情况下,我希望获得数学上的正确性保障。

附注:作为 Rust 程序员,我们欣赏确定性。Rust 类型系统保证我们不会解引用空指针。Rust 借用检查器保证我们不会在内存被释放后继续使用它。像Kani Rust crate这样的工具在某些情况下保证算术不会溢出。但如果我们想要确定一个算法的正确性呢?

为了实现这种确定性,Divyanshu Ranjan 和我将internal_add的算法移植到 Dafny 语言中。然后我们验证了 Dafny 版本的算法。(我们选择 Dafny 是因为它的强大和易用性。稍后我们会多谈谈这个选择。)

在验证过程中,我们学习了九条规则,可以帮助你使用 Dafny 验证算法——无论是用 Rust 还是其他语言编写的。你也可能会发现这些规则作为使用现代工具验证的难易程度的参考非常有趣。

规则如下:

  1. 不要学习 Dafny。

  2. 学习 Dafny。

  3. 定义你算法的基本概念。

  4. 规范你的算法。

  5. 从 Dafny 社区获取帮助。

  6. 验证一个不同的、更简单的算法。

第二部分 了解这些规则:

7. 将你的实际算法移植到 Dafny。

8. 验证你算法的 Dafny 版本。

9. 重新工作你的验证以确保可靠性。

附注:为了避免模棱两可,我们称这些为“规则”,但它们当然只是建议。

internal_add函数试图将一个新的整数范围高效地插入到已排序且不重叠的整数范围列表中。例如,如果我们从[101..=102, 400..=402, 404..=405]开始,并添加402..=404,我们期望的结果是[101..=102, 400..=405]

来源:本文及所有后续图片均由作者提供。

理想情况下,我会使用 Rust 特定的工具[1,2]正式验证这个算法。然而,这些工具似乎难以使用。因此,我选择了Dafny。Dafny 是一种语言和验证系统。它在世界各地的大学本科课程中教授,也在工业界使用。我发现它具有令人上瘾的交互性和对程序员友好的特点。

附带说明:Dafny 的创始人 Rustan Leino 博士与 Rust 的联系不仅仅是名字的巧合。他帮助创建了 Spec#,这是第一个使用类型系统来避免空指针的语言。Rust 当然采纳了这个想法,并取得了巨大成功。

本文涵盖规则 1 到 6。 第二部分 涵盖规则 7 到 9。

规则 1:不要学习 Dafny。

在尝试证明算法的数学正确性之前,决定这种努力是否值得。

Dafny 不是 Rust。使用 Dafny 需要将感兴趣的算法从 Rust 移植到 Dafny。这种移植可能会遗漏细节并引入错误。鉴于这种风险, 是否应该使用 Dafny 来验证 Rust 算法?我大胆地声称“这要看情况”。

  • 你的算法的正确性有多重要?如果你正在打印报告且它看起来正确,那么它可能确实是正确的。internal_add 算法涉及一个数据结构,我希望其他人能够自信地使用它,这给了我额外的动机去验证它。

  • 也许所有形式验证在当前工具下都太难了。然而,我相信 Dafny 使形式验证变得尽可能简单。如果你已经熟悉类型(例如,来自 Rust)和递归/归纳(在 Rust 中不常用),你会发现形式验证代码更容易。你可以阅读本文并自行决定何时形式验证足够简单,值得对你有价值。

  • 也许模糊测试(例如 [cargo-fuzz](https://github.com/rust-fuzz/cargo-fuzz))和基于属性的测试(例如 [QuickCheck](https://github.com/BurntSushi/quickcheck))已经足够好。虽然这些方法不能提供数学上的确定性,但它们聪明、有用且易于使用。(range-set-blaze crate 已经使用了 QuickCheck。有关详细信息,请参见 之前文章中的规则 9.5。)

  • 也许形式验证注定是失败的,因为编写规范和编写代码一样困难。我不同意这个观点。想一想 重构。我通常通过编写简单的代码来开始编程。然后,我将这些简单的代码重构以提高效率。对于 internal_add,我发现规范比任何代码都要简单。(你可以在规则 4 中自行判断这一点。)

附带说明:验证变成了从简单规范到最终高效算法的计算机检查重构。

  • 也许形式化验证注定是失败的,因为停机问题正式告诉我们形式化通常是不可能的。停机问题并没有注定我们的失败。虽然我们不能总是理解任意代码,但我们不需要这样做。我们只需要理解我们自己的代码,而我们(希望)写的代码是易于理解的。从规则 2 开始,我们将看到 Dafny 如何轻松验证特定的循环和递归是否会停止。

  • 也许迁移到 Dafny 太困难了。这不是我的经历。像 Rust 一样,Dafny 混合了命令式和函数式编程。我发现将我的算法迁移到 Dafny 是简单的。

假设你仍然希望用 Dafny 验证你的算法,那么下一步是学习 Dafny。

步骤 2:学习 Dafny。

Dafny 是一种编程语言和交互式验证系统。我推荐你将其作为 VS Code 扩展安装

要学习 Dafny,从dafny.org/开始。特别值得关注的是在线教程参考手册。我还发现 YouTube 上的Verification Corner 视频很有帮助。(可能感兴趣的还有大学教材《程序证明》,Kindle 版售价 $49)。我发现 Dafny 的编程语言部分比 Rust 更容易学习,可能与 C# 的难度相当。

Dafny 和 Rust 一样是完全类型化的。Dafny 像 Python 一样进行垃圾回收。这里有一个“Hello World”示例:

// hello.dfy
method Main()
{
  var s := "Hello World";
  print s, "\n";
}

Dafny,像 Python 一样,提供任意大小的整数。这里有一个程序,它通过重复递增来可证明地加两个自然数。

一些关注点:

  • Dafny 编码规范遵循 C#,而不是 Rust。因此,我们将函数命名为 SlowAdd 而不是 slow_add(尽管两者都能运行)。

  • Dafny 支持子类型。例如,任何可以证明是非负的 int 也是一个 nat

  • 赋值用 :=,等式用 ==。(没有 =。)

  • 函数参数,例如上面的xy,是不可变的。

  • Dafny 使用 ensuresinvariant 语句在编译时验证代码。然后,它会移除这些语句以完成编译。

  • 绿色的对勾标记显示这段代码已通过验证。Dafny 的 VS Code 扩展默认会持续尝试验证每个方法。这为使用 Dafny 的工作增添了几乎像赌博一样的兴奋感。在上面的例子中,如果我将y改为int而不是nat,那么验证应该会失败。 (你能找出原因吗?)Dafny 会用红色 X 标记我的函数,并告诉我“这个后置条件可能不成立:r == x + y”。

  • Dafny 了解一些整数、数组、集合、映射、序列等的数学。这通常使它能够自行完成验证的最后细节。

现在你了解了 Dafny,你应该让它了解你的算法。

规则 3:定义算法的基本概念。

range-set-blaze crate 将整数集合表示为已排序的、不重叠的范围。例如,这个包含三个范围的列表:

100..=2_393, 20_303..=30_239_000, 501_000_013..=501_000_016

表示一个包含 30,220,996 个整数的集合。

在 Rust 中,RangeSetBlaze 结构体内部用标准的 [BTreeMap](https://doc.rust-lang.org/std/collections/struct.BTreeMap.html) 表示这个数据结构。请记住,BTreeMap 表示按键排序的键/值对列表。在这里,我们的键是范围的起始值(例如,10020_303501_000_013),值是范围的包含结束值(例如,2_39330_239_000501_000_016)。RangeSetBlaze 使用 BTreeMap 而非 vec 来存储列表,以使键查找更适合缓存。

RangeSetBlaze 依赖于 BTreeMap,那么我们必须在 Dafny 中实现 BTreeMap 吗?幸运的是,不需要。我们可以改用 Dafny 的类似 vecseq 数据类型。这个替代方案有效,因为 BTreeMapvecseq 都可以表示排序列表——只是效率不同。对于形式验证的目的,我们只关心正确性,可以忽略效率。

RangeSetBlaze 需要范围列表是已排序且互不重叠的。我们如何在 Dafny 中表示“已排序且互不重叠”?我们可以通过这个 幽灵谓词(及相关代码 来表示:

ghost predicate ValidSeq(sequence: seq<NeIntRange>) {
  (forall i:nat, j:nat | i < j < |sequence| :: sequence[i].0 < sequence[j].0)
  && (forall i:nat, j:nat | i < j < |sequence| :: !Touch(sequence[i], sequence[j]))
}

type IntRange = (int, int)
type NeIntRange = x: IntRange | !IsEmpty(x) witness (0,0)

function IsEmpty(r: IntRange): bool
{
  r.0 > r.1
}

谓词 是返回 bool 的方法的另一种说法。幽灵 方法(或谓词)是只能用于验证而不能用于运行最终代码的方法。

从高层次来看,ValidSeq 谓词以非空整数范围的序列作为输入。它随后测试起始值是否排序,并且范围是否不重叠。具体来说,

  • IntRange 是一个由两个 int 值组成的元组。

  • 当且仅当 IntRange 的起始值大于结束值时,它才是空的。(这遵循了 Rust 的约定)。

  • NeIntRange(非空整数范围)是一个非空的 IntRange,例如,(0,0)。 [我们所有的范围都是包含结束值的。]

  • 这个表达式测试起始值是否已排序:

forall i:nat, j:nat | i < j < |sequence| :: sequence[i].0 < sequence[j].0

它可以被解读为“对所有自然数 ij —— 使得 i 小于 jj 小于序列的长度 —— 测试索引 i 处的起始值是否小于索引 j 处的起始值”。

附注:注意 Rust 的 BTreeMap 不支持(随机访问)索引,但这里我们使用了这种索引。这是可以的,因为 *ValidSeq* 是一个幽灵谓词,因此仅用于验证。

  • 这个表达式测试范围是否互不重叠:
forall i:nat, j:nat | i < j < |sequence| :: !Touch(sequence[i], sequence[j])

它可以读作“对于所有自然数ij — 使得i小于jj小于序列的长度 — 测试索引i的范围是否不触及索引j的范围。但Touch是什么?

我们将Touch定义为两个层次。在数学层面,如果范围i中存在整数i0,范围j中存在整数j0,并且i0j0彼此距离为一,那么范围i被认为触及范围j。在高效编程层面,我们希望避免依赖“存在”的定义。这是一个 Dafny 谓词,它既符合数学定义又高效:

predicate Touch(i: NeIntRange, j: NeIntRange)
  ensures Touch(i, j) == exists i0, j0 ::
                    Contains(i, i0) && Contains(j, j0) && -1 <= i0 - j0 <= 1
{
  assert Contains(i, i.0) && Contains(i, i.1) && Contains(j, j.0) && Contains(j, j.1);
  if i.1 < j.0 then
    assert  (-1 <= i.1 - j.0 <= 1) == (i.1+1 == j.0);
    i.1+1 == j.0
  else if j.1 < i.0 then
    assert (-1 <= j.1 - i.0 <= 1) == (j.1+1 == i.0);
    j.1+1 == i.0
  else
    var k0 := Max(i.0, j.0);
    assert Contains(i, k0) && Contains(j, k0);
    true
}

function Contains(r: IntRange, i: int): bool
{
  r.0 <= i && i <= r.1
}

function Max(a: int, b: int): int
{
  if a < b then b else a
}

一些关注点:

  • Touch不是幽灵。换句话说,我们可以在常规代码和验证代码中使用它。

  • assert语句帮助 Dafny 证明常规代码符合数学ensures语句。

  • 为了提高效率,Dafny 证明器分别验证method的内部和外部。只有ensures(以及尚未出现的requires)语句跨越这个边界。与method不同,Dafnyfunction对验证器是透明的。(我认为它类似于在验证方面内联代码。)

在定义了ValidSeqTouch等概念后,我们接下来要指定我们的算法应该做什么。

规则 4:指定你的算法。

最终,我希望证明我的 Rust 算法在将新范围插入RangeSetBlaze中是正确的。然而,在此之前,我们先定义一下什么是“正确”的范围插入

method InternalAdd(xs: seq<NeIntRange>, a: IntRange) returns (rs: seq<NeIntRange>)
  requires ValidSeq(xs)
  ensures ValidSeq(rs)
  ensures SeqToSet(rs) == SeqToSet(xs) + RangeToSet(a)
{
  if IsEmpty(a)
  {
    rs := xs;
  }
  else
  {
    assume false; // cheat for now
  }
}

这表示InternalAdd是一个方法,它接受xs,一个非空整数范围序列,以及a,一个整数范围(可能为空)。该方法输出rs,一个新的非空整数范围序列。

我们需要说明xsrs必须是已排序且不重叠的。这可以通过ValidSeqrequires和第一个ensures中轻松完成。

我们还需要说明rs包含了正确的内容。这难吗?其实不难。我们只需说明rs中的整数集合必须等于xs中的整数集合并与a中的整数集合并集。

旁注:在 Dafny 中,“+”应用于集合时表示“并集”。

一个范围中的整数集合是:

ghost function RangeToSet(pair: IntRange): set<int>
{
  set i {:autotriggers false} | pair.0 <= i <= pair.1 :: i
}

非空范围序列中的整数集合可以递归定义(即递归):

ghost function SeqToSet(sequence: seq<NeIntRange>): set<int>
  decreases |sequence|
  requires ValidSeq(sequence)
{
  if |sequence| == 0 then {}
  else if |sequence| == 1 then RangeToSet(sequence[0])
  else RangeToSet(sequence[0]) + SeqToSet(sequence[1..])
}

一些关注点:

  • 这行代码:assume false; // cheat for now 使验证即使在实际上不应该工作时也能工作。我们将其用作临时占位符。

  • 我们将RangeToSetSeqToSet设为幽灵,以防止我们在常规代码中使用它们。我们将它们设为函数(而不是方法),以便在验证时内联它们。

  • 因为 Dafny 对创建和操作集合和序列了解颇多,我们经常通过在规格说明中使用集合和序列获益。

  • 即使我们常规代码使用循环而不是递归,我们的验证代码通常会使用递归类似的归纳法。

  • {:autotriggers false} 相关于避免警告信息。更多信息请参见 Prof. James Wilcox 的这个 Stack Overflow 回答

我们现在有了 InternalAdd 的正式规格说明。我发现这个规格简短且直观。但如果你需要帮助理解规格说明或其他 Dafny 代码呢?

规则 5:寻求 Dafny 社区的帮助。

Dafny 问题的主要论坛是 Stack Overflow。令我惊讶的是,我在这里实际获得了许多有用的帮助。

我建议在问题标题前加上“Dafny:”。同时,确保为你的问题添加 dafny 标签,可能还要加上 formal-verification 标签。

附带说明:在网站上,你可以看到我的 11 个问题Divyanshu Ranjan 的 48 个与 Dafny 相关的回答

作为一个在 GitHub 上的开源项目,Dafny 也托管了 GitHub Discussions 和 Issues。

Dafny 社区虽然小,但似乎热衷于帮助用户和改进项目。

在有帮助的情况下,我们接下来必须找到一个符合规格的算法。

规则 6:验证一个不同、更简单的算法。

作为正式验证的初学者,我决定推迟对我 Rust 代码中实际 internal_add 的工作。相反,我开始着手开发一个我希望更容易验证的 InternalAdd 算法。最后,我得到的是这个

method InternalAdd(xs: seq<NeIntRange>, a: IntRange) returns (rs: seq<NeIntRange>)
  requires ValidSeq(xs)
  ensures ValidSeq(rs)
  ensures SeqToSet(rs) == SeqToSet(xs) + RangeToSet(a)
{
  if IsEmpty(a)
  {
    rs := xs;
  }
  else
  {
    var notTouching, merged := PartitionAndMerge(xs, a);
    var indexAfter := NoTouchIndexAfter(notTouching, merged);
    rs := InsertAt(notTouching, [merged], indexAfter);
  }
}

这个想法是,如果范围 a 为空,我们返回未更改的输入序列。否则,我们将工作分成三个步骤,我们可以独立验证。第一步,PartitionAndMerge, 返回:

  • notTouching 是一个不触及范围 a 的范围序列,且

  • merged 是由 a 及其触及到的所有内容创建的单一范围。

这里是一个示例输入和输出:

InternalAdd 接着寻找插入 merged 的位置,并最终插入它。

这里是代码 [PartitionAndMerge](https://github.com/CarlKCarlK/range-set-blaze/tree/oct23/tests/formal):

method PartitionAndMerge(xs: seq<NeIntRange>, a: NeIntRange) returns (notTouching: seq<NeIntRange>, merged: NeIntRange)
  requires ValidSeq(xs)

  ensures ValidSeq(notTouching)
  ensures RangeToSet(merged) >= RangeToSet(a)
  ensures forall range | range in notTouching :: !Touch(range, merged)
  ensures SeqToSet(xs) + RangeToSet(a) == SeqToSet(notTouching) + RangeToSet(merged)
{
  // Split into touching and not touching seqs
  var touching: seq<NeIntRange>;
  touching, notTouching := Partition(a, xs);

  // Merge the touching seq into one range with our original range
  merged := UnionSeq(a, touching);
}

这说明 PartitionAndMerge 要求 xs 是一个有效的非空整数范围序列,并且 a 是一个非空整数范围。它确保 nonTouching 是另一个有效的非空整数范围序列。它确保 merged 范围中的整数是 a 范围中整数的超集。它确保 notTouching 中的任何范围都不接触 merged 范围。最后,它确保 xsa 中的整数与 notTouchingmerged 中的整数完全相同。

PartitionAndMerge 还将工作分为两个步骤(PartitionUnionSeq),这两个步骤可以独立验证。这些步骤继续将工作细分。它在哪里结束?让我们看一个例子。

方法 UnionSeq 调用 [UnionRange](https://github.com/CarlKCarlK/range-set-blaze/tree/oct23/tests/formal) 来合并两个范围:

function UnionRange(x: IntRange, y: IntRange): IntRange
  requires IsEmpty(x) || IsEmpty(y) || Touch(x, y)
  ensures RangeToSet(x) + RangeToSet(y) == RangeToSet(UnionRange(x,y))
{
  if IsEmpty(x) then y
  else if IsEmpty(y) then x
  else (Min(x.0, y.0), Max(x.1, y.1))
}

UnionRange 代码处理了空情况,然后返回最小的包围范围。(最小的包围范围是从两个开始中较小的那个到两个结束中较大的那个。)但这怎么可能正确呢?一般来说,两个范围的最小包围范围可能包含额外的整数。我们可能会得到比输入的并集更大的范围,如下所示:

代码是正确的,因为它要求两个输入范围相接触或为空。这确保了范围 x 中的整数与范围 y 中的整数的并集正好是输出范围中的整数。

在编译时,Dafny 证明了这个函数是正确的。除此之外,它证明了所有调用这个函数的地方都提供了空或相接触的输入。

我认为这可以看作是 Rust 借用检查器的一个概括。在编译时,Rust 检查我们是否安全,避免了许多内存错误。在编译时,验证系统,如 Dafny,可以证明几乎任意的属性。当然,正如我们所见,这种能力是以复杂性为代价的。

这个经过验证的算法的完整代码大约有 200 行,分成大约十几个方法和函数。

这个规则显示了我们可以验证一个 InternalAdd 算法,但这不是我在 Rust 中使用的算法。我们将接下来讨论那个算法。

这些是使用 Dafny 验证 Rust 算法的前六条规则。请参阅 第二部分 获取规则 7 到 9。

关注 Carl 的 Medium 账号。我撰写关于 Rust 和 Python 的科学编程、机器学习和统计学的文章。我通常每月写一篇文章。

阅读 Divyanshu Ranjan 的更多工作,访问 他的博客。除了形式化方法,博客还涉及几何学、统计学等内容。

使用 Dafny 正式验证 Rust 算法的九条规则(第二部分)

原文:towardsdatascience.com/nine-rules-to-formally-validate-rust-algorithms-with-dafny-part-2-f2a279686700?source=collection_archive---------5-----------------------#2023-10-21

验证 range-set-blaze Crate 的经验教训

Carl M. KadieTowards Data Science Carl M. Kadie

·

关注 发表在 Towards Data Science ·14 分钟阅读·2023 年 10 月 21 日

--

作者:Carl M. KadieDivyanshu Ranjan

蟹证明毕达哥拉斯定理 — 来源:openai.com/dall-e-3/ & CC BY-SA 3.0 文件:Pythagorean.svg

这是关于使用 Dafny 正式验证 Rust 算法的文章第二部分。我们查看第 7 至第 9 条规则:

  • 7. 将你的真实算法迁移到 Dafny。

  • 8. 验证你算法的 Dafny 版本。

  • 9. 重新审视你的验证以确保可靠性。

参见 第一部分 的规则 1 到 6:

  1. 不要学习 Dafny。

  2. 学习 Dafny。

  3. 定义你的算法的基本概念。

  4. 指定你的算法。

  5. 向 Dafny 社区寻求帮助。

  6. 验证不同的、更简单的算法。

这些规则来自我们验证一个来自 [range-set-blaze](https://crates.io/crates/range-set-blaze) 的算法的经验,这个 Rust crate 用于处理“分散”整数的集合。

记住第 6 规则,来自 第一部分,显示我们可以验证 一个 算法的 InternalAdd,但它不是 Rust crate 中使用的 那个 算法。接下来我们转向那个算法。

规则 7:将你的真实算法移植到 Dafny。

这里是感兴趣的 Rust 函数,部分代码暂时省略:

// https://stackoverflow.com/questions/49599833/how-to-find-next-smaller-key-in-btreemap-btreeset
// https://stackoverflow.com/questions/35663342/how-to-modify-partially-remove-a-range-from-a-btreemap
fn internal_add(&mut self, range: RangeInclusive<T>) {
    let (start, end) = range.clone().into_inner();
    assert!(
        end <= T::safe_max_value(),
        "end must be <= T::safe_max_value()"
    );
    if end < start {
        return;
    }
    //... code continues ...
}

这里是 Dafny 移植版的开始部分:

method InternalAdd(xs: seq<NeIntRange>, range: IntRange) returns (r: seq<NeIntRange>)
  requires ValidSeq(xs)
  ensures ValidSeq(r)
  ensures SeqToSet(r) == SeqToSet(xs) + RangeToSet(range)
{
  var (start, end) := range;
  if end < start {
    r := xs;
    return;
  }
 //... code continues ...
}

一些可能感兴趣的点:

  • Rust 代码使用 self 和类似面向对象的封装。Dafny 支持这种编码风格,但为了简便,我在这里没有使用。具体来说,Rust 代码会改变 self。我选择了用更函数式的方式编写 Dafny 代码——它接受一个不可变的序列并返回一个新的不可变序列。

  • Rust 代码通过借用检查器来管理内存。这导致了诸如 range.clone() 的表达式。Dafny 通过垃圾回收器来管理内存。在这两种情况下,内存安全都会得到保证。因此,我们在此验证中忽略它。

  • Rust 代码对 T 是泛型的,我在其他地方定义它包括所有标准的 Rust 整数类型,例如 u8isizei128。Dafny 代码定义在 int 上,这是一个表示任意大小整数的单一类型。这意味着这个 Dafny 移植版不需要检查整数溢出。 [参见 上一篇文章 了解如何用 Kani Rust 验证器正式证明溢出安全。]

  • Rust 代码包含一个运行时的 assert!,这是 Rust 中用于禁止一种特殊情况的:将 u128::max_value 插入到 RangeSetBlaze<u128> 中。由于 Dafny 使用任意大小的 int,它忽略了这种特殊情况。

附注:Rust 的 包含范围 *0..=u128::max_value** 的长度是多少?答案是* *u128::max_value*+1,这个值太大,无法用任何标准 Rust 整数类型表示。 *range-set-blaze* crate 将范围限制为 *0..=u128::max_value*-1,以便长度可以用 *u128* 表示。

接下来我们考虑 internal_add 算法的其余部分。记住,我们有一些排序好的不相交的范围和一些非空的新范围,我们想要插入。例如

致谢:此处和以下的图表由作者提供。

算法要求我们找出哪个(如果有的话)现有范围在新范围的开始之前(或正好在开始)。称之为“之前”范围。然后我们考虑四种情况:

  • 案例 1:新范围不触及之前的范围,因此我们在检查是否触及任何其他范围时插入新范围。

  • 案例 2:新范围触及之前的范围并超出它,因此在检查是否触及任何其他范围时扩展之前范围的末尾。(当没有其他范围被触及时,这将非常快。)

  • 案例 3:新范围触及之前的范围但没有超出它,因此不做任何操作。(这将总是非常非常快。)

  • 案例 4:新范围在任何范围之前开始,因此在检查是否触及任何其他范围时添加它。

这是Rust 中的算法

// code continued ...
// FUTURE: would be nice of BTreeMap to have a partition_point function that returns two iterators
    let mut before = self.btree_map.range_mut(..=start).rev();
    if let Some((start_before, end_before)) = before.next() {
        // Must check this in two parts to avoid overflow
        if match (*end_before).checked_add(&T::one()) {
            Some(end_before_succ) => end_before_succ < start,
            None => false,
        } {
            self.internal_add2(&range);
        } else if *end_before < end {
            self.len += T::safe_len(&(*end_before..=end - T::one()));
            *end_before = end;
            let start_before = *start_before;
            self.delete_extra(&(start_before..=end));
        } else {
            // completely contained, so do nothing
        }
    } else {
        self.internal_add2(&range);
    }
}

这里是 Dafny 中的算法:

// code continued ...
  var beforeHi := IndexAtOrBeforePlusOne(xs, start);
  if beforeHi > 0 { // does not go at front
    var (startBefore, endBefore) := xs[beforeHi-1];
    if endBefore+1 < start {
      r := InternalAdd2(xs, range);
    } else if endBefore < end {
      r := xs[..beforeHi-1] + [(startBefore, end)] + xs[beforeHi..];
      assume exists i: nat :: i < |r| && r[i] == (startBefore,end) && ValidSeq(r[..i+1]) && ValidSeq(r[i+1..]);
      r := DeleteExtra(r, (startBefore,end));
    } else{
      r := xs;
    }
  }
  else // goes at front
  {
    r := InternalAdd2(xs, range);
  }
}

可能感兴趣的一些要点:

  • Rust 代码通过键和值操作BTreeMap。Dafny 代码则通过(随机访问)索引操作排序后的seq。我让 Dafny 操作镜像 Rust 操作,尽管这使得 Dafny 代码不那么自然。

  • Rust 代码还更新了self.len,即 RangeSetBlaze 中的整数数量。Dafny 代码忽略了这一点。(更新len是一个将来可能会添加到 Dafny 代码中的功能。)

  • 和之前一样,Rust 版本包含了 Dafny 忽略的防止溢出的代码。

我继续通过编写internal_add2delete_extra的 Dafny 版本来完成移植,这两个函数是internal_add调用的。我通过编写这两个方法调用的其他方法等完成了移植。完整的移植代码约 185 行。你可以在这里查看。

它没有验证。接下来我们将处理验证。

规则 8:验证 Dafny 版本的算法。

在这一步,你将向代码中添加验证提示,例如,形式为assert语句。Dafny 使用这些提示来尝试验证你的代码。作为 Dafny 初学者,我(Carl)发现添加提示比编写代码更困难。部分原因是因为我不知道 Dafny 何时(或是否)会满足条件,然后我可以停止。

然而,我确实学会了如何开始。例如,上述InternalAdd的代码产生了两个验证错误。首先,Dafny 验证器报告说一个ensures可能不成立:

This postcondition might not hold: SeqToSet(r) == SeqToSet(xs) + RangeToSet(range)

附注:“后置条件”对应于*ensures*。“前置条件”对应于*requires*

其次,Dafny 验证器抱怨DeleteExtra的前置条件(即requires之一)无法证明。

我们将首先关注第一个问题,通过在方法底部添加assert。我们写它是为了反映ensures

// ... adding this as the last line in InternalAdd
assert SeqToSet(r) == SeqToSet(xs) + RangeToSet(range);
}

我们将明确忽略DeleteExtra问题,暂时用assume来处理。

// ...
      assume exists i: nat :: i < |r| && r[i] == (startBefore,end) && ValidSeq(r[..i+1]) && ValidSeq(r[i+1..]);
      r := DeleteExtra(r, (startBefore,end));
//...

Dafny 验证器现在仅对我们的最终assert提出抱怨。它说“断言可能不成立。”

记住,InternalAdd 代码使用嵌套的 if 语句将其工作分成五种情况。接下来,我们将 assert 从方法的末尾移动到每个情况的末尾。请在结果中寻找以 // case 注释结尾的行:

method InternalAdd(xs: seq<NeIntRange>, range: IntRange) returns (r: seq<NeIntRange>)
  requires ValidSeq(xs)
  ensures ValidSeq(r)
  ensures SeqToSet(r) == SeqToSet(xs) + RangeToSet(range)
{
  var (start, end) := range;
  if end < start {
    r := xs;
    assert SeqToSet(r) == SeqToSet(xs) + RangeToSet(range); // case 0 - validates
    return;
  }

  var beforeHi := IndexAtOrBeforePlusOne(xs, start);
  if beforeHi > 0 { // does not go at front
    var (startBefore, endBefore) := xs[beforeHi-1];
    if endBefore+1 < start {
      r := InternalAdd2(xs, range);
      assert SeqToSet(r) == SeqToSet(xs) + RangeToSet(range); // case 1 - validates
    } else if endBefore < end {
      r := xs[..beforeHi-1] + [(startBefore, end)] + xs[beforeHi..];
      assume exists i: nat :: i < |r| && r[i] == (startBefore,end) && ValidSeq(r[..i+1]) && ValidSeq(r[i+1..]);
      r := DeleteExtra(r, (startBefore,end));
      assert SeqToSet(r) == SeqToSet(xs) + RangeToSet(range); // case 2 - fails
    } else{
      r := xs;
      assert SeqToSet(r) == SeqToSet(xs) + RangeToSet(range); // case 3 - fails
    }
  }
  else // goes at front
  {
    r := InternalAdd2(xs, range);
    assert SeqToSet(r) == SeqToSet(xs) + RangeToSet(range); // case 4 - validates
  }
}

Dafny 现在告诉我们情况 0、1 和 4 已验证。情况 2 失败(并包含我们最终需要删除的 assume)。不过,现在让我们处理情况 3。

回忆一下这篇文章的规则 7,情况 3 是当我们添加一个新的范围(红色),该范围完全被现有范围(蓝色的“之前”范围)覆盖时,因此代码无需做任何操作。

那么,从逻辑上讲,我们知道什么?我们知道之前范围覆盖的整数是新范围覆盖的整数的超集。我们还知道之前范围是我们原始的排序且不重叠的范围列表(蓝色范围)的一部分。我们将通过 assert 语句将这两个提示添加到我们的代码中:

Dafny 同意这两个提示是正确的(绿色勾号),但它仍然不接受关注的 assert(红色标记)。

我们似乎需要再一个提示。具体来说,我们需要说服 Dafny 认为之前范围覆盖的整数是所有排序且不重叠范围列表中整数的子集。直观上,这是真的,因为之前范围是列表中的一个范围。

我们将这个提示写成一个没有主体的引理。Dafny 接受了它。

附录:为什么 Dafny 会接受这个主体为空的引理?我不知道,也没有很好的直觉。这只是有效。如果无效,我会尝试在其主体中添加断言。

使用引理,情况 3 现在已验证:

这意味着我们已验证情况 0、1、3 和 4。接下来我们将处理情况 2。此外,一些提到的方法,例如 DeleteExtra,尚未验证,我们需要对其进行处理。 [您可以看到截至目前的代码,在这里。]

有关验证调试的一般建议,请参阅 Dafny 用户指南的这一部分。我还推荐 这个 Stack Overflow 答案 和 James Wilcox 教授的迷你教程。

总的来说,想法是将验证算法的任务分解成许多更小的验证任务。我发现这比编程更困难,但并不太难,仍然很有趣。

我最终在 185 行常规代码中添加了大约 200 行验证代码(完整代码在这里)。当我最后验证完最后一个方法时,我错误地认为自己已经完成了。

令我惊讶(也是失望)的是,第一次所有内容验证通过并不意味着工作结束。你还必须确保你的项目将再次验证并且能为其他人验证。接下来,我们将讨论这一规则。

规则 9:重新审视你的验证以确保可靠性。

我以为我完成了。然后,我将数学 Min 函数的六行定义从Dafny 标准库移动到我的代码中。这导致我的验证失败,没有逻辑原因(字面上!)。后来,在我以为已经修复了之后,我删除了一个未使用的方法。再次,验证因没有逻辑原因而开始失败。

发生了什么?Dafny 通过随机搜索进行启发式工作。表面上更改代码(或更改随机种子)可以改变搜索所需的时间。有时,时间的变化非常剧烈。如果新的时间超过了用户设置的时间限制,验证将失败。[我们将在下面的提示 #3 中进一步讨论时间限制。]

你应该通过尝试不同的随机种子来测试验证的可靠性。以下是我在 Windows 上使用的命令来验证一个具有 10 个随机种子的文件。

@rem Find the location of Dafny and add it to your path
set path=C:\Users\carlk\.vscode-insiders\extensions\dafny-lang.ide-vscode-3.1.2\out\resources\4.2.0\github\dafny;%path%
dafny verify seq_of_sets_example7.dfy --verification-time-limit:30 --cores:20 --log-format csv --boogie -randomSeedIterations:10

结果是一个 *.csv 文件,你可以将其打开为电子表格,然后寻找失败:

附注:有关测量 Dafny 验证可靠性的更多想法,请参见这个关于分析 *.csv 文件的 Stack Overflow 答案以及这个推荐 dafny-reportgenerator 工具的 GitHub 讨论。

在找到问题点后,我请来合著者 Divyanshu Ranjan 帮忙。Divyanshu Ranjan 利用他在 Dafny 上的经验来修复项目的验证问题。

这是他的提示,以及来自项目的示例:

提示 #1:尽可能移除涉及“forall”和“exists”的 require 语句。

回顾规则 4,幽灵函数 SeqToSet 返回由排序且不相交的非空范围列表覆盖的整数集合。我们用函数 ValidSeq 定义“排序且不相交”,该函数内部使用了两个 forall 表达式。我们可以像这样移除列表必须排序和不相交的要求:

ghost function SeqToSet(sequence: seq<NeIntRange>): set<int>
  decreases |sequence|
  // removed: requires ValidSeq(sequence)
{
  if |sequence| == 0 then {}
  else if |sequence| == 1 then RangeToSet(sequence[0])
  else RangeToSet(sequence[0]) + SeqToSet(sequence[1..])
}

从我们的角度来看,我们有相同的有用函数。从 Dafny 的角度来看,该函数避免了两个 forall 表达式,并且更容易应用。

提示 #2 使用 calc 避免 Dafny 的猜测工作。

使用 Dafny calc 语句,你列出了得出结论所需的确切步骤。例如,这是 DeleteExtra 方法中的一个 calc

calc {
    SeqToSet(xs[..indexAfter+1]) + SeqToSet(xs[indexAfter+1..]);
  ==
    { SeqToSetConcatLemma(xs[..indexAfter+1], xs[indexAfter+1..]); }
    SeqToSet(xs[..indexAfter+1] + xs[indexAfter+1..]);
  ==
    { assert xs == xs[..indexAfter+1] + xs[indexAfter+1..]; }
    SeqToSet(xs);
  ==
    { SetsEqualLemma(xs, r[indexAfter], r2, indexAfter, indexDel); }
    SeqToSet(r2);
  }

在代码的这一点,xs 是一个范围序列,但它可能不是排序的或不相交的。calc 断言:

  1. xs 的两部分覆盖的整数等于

  2. 两部分拼接覆盖的整数等于

  3. xs 覆盖的整数等于

  4. rs

对于每一步,我们可以包含引理或断言来帮助证明这一步骤。例如,这个断言有助于证明从第 3 步到第 4 步的过渡:

{ assert xs == xs[..indexAfter+1] + xs[indexAfter+1..]; }

为了提高效率和控制,这些引理和断言在它们的步骤之外对验证器不可见。这使得 Dafny 集中注意力。

提示 #3:使用 timeLimit 提供所需的计算。

Dafny 在用户设置的 timeLimit 下停止尝试验证方法。10、15 或 30 秒的限制很常见,因为作为用户,我们通常希望那些不可能发生的验证能快速失败。然而,如果我们知道验证最终会发生,我们可以设置一个特定于方法的时间限制。例如,Divyanshu Ranjan 发现 DeleteExtra 通常会验证,但比其他方法花费更多时间,因此他添加了一个特定于方法的时间限制:

method {:timeLimit 30} DeleteExtra(xs: seq<NeIntRange>, internalRange: IntRange) returns (r: seq<NeIntRange>)
// ...

附注:*timeLimit* 并未考虑计算机之间速度的差异,因此请设置得稍微宽松些。

提示 #4:使用 split_here 将验证问题分为两部分。

Dafny 常见问题解答所述,有时一起验证一组断言更快,有时逐一验证更快。

使用 assert {:split_here} true; 语句将一系列断言拆分为两部分以进行验证。例如,即使有 timeLimitDeleteExtra 也会超时,直到 Divyanshu Ranjan 添加了这个:

// ...
else
  {
    r := (s[..i] + [pair]) + s[i..];
    assert r[..(i+1)] == s[..i] + [pair];
    assert r[(i+1)..] == s[i..];
    assert {:split_here} true; // split validation into two parts
    calc {
      SeqToSet(r[..(i+1)]) + SeqToSet(r[(i+1)..]);
// ...

提示 #5:保持引理简小。如有需要,跨引理拆分 ensures

有时引理会试图一次做太多事情。考虑一下 SetsEqualLemma。它与删除冗余范围相关。例如,如果我们将 a 插入到 xs 中,标记为“X”的范围将变得冗余。

SetsEqualLemma 的原始版本包含 12 个 requires 和 3 个 ensures。Divyanshu Ranjan 将其拆分为两个引理:RDoesntTouchLemma(11 个 requires 和 2 个 ensures)和 SetsEqualLemma(3 个 requires 和 1 个 ensures)。通过这一变化,项目的验证更加可靠。

应用这些技巧将提高我们证明的可靠性。我们能否使验证 100% 可靠?遗憾的是,不能。总有可能由于不幸的种子,Dafny 无法验证。因此,当你什么时候停止尝试改进验证?

在这个项目中,Divyanshu Ranjan 和我改进了验证代码,直到任何单次运行中验证错误的概率降到 33% 以下。因此,在 10 次随机运行中,我们看到的失败不超过 2 或 3 次。我们甚至尝试了 100 次随机运行。在 100 次运行中,我们看到了 30 次失败。

结论

所以,您看到了:九条规则来证明 Rust 算法的正确性。您可能会对这个过程不够简单或自动感到沮丧。然而,我反而感到鼓舞,因为这个过程完全是可能的。

附注:自高中几何课以来,我发现数学证明既迷人又令人沮丧。 “迷人” 是因为一旦证明的数学定理被认为永远正确。(欧几里得的几何仍被认为是正确的,而亚里士多德的物理学则不是。) “令人沮丧” 是因为我的数学课总是对我可以假设哪些公理以及我的证明可以迈出多大步伐感到模糊。Dafny 和类似的系统通过自动证明检查消除了这种模糊性。从我的角度来看,更好的是,它们帮助我们创建关于我深切关心的领域:算法的证明。

什么时候值得对算法进行正式证明?考虑到所涉及的工作,只有当算法在某种程度上复杂、重要或容易证明时,我才会再次进行这种证明。

未来这个过程可能如何改进?我希望看到:

  • 系统间的互换 — 一旦证明的几何定理就不需要再证明。我希望检查算法证明的系统能够互相使用对方的证明。

  • 一个像 Dafny 一样易于使用的全 Rust 系统 — 有关这方面的工作,请参见 [1,2]。

附注:你知道有一个易于使用的 Rust 验证系统吗?请考虑将其应用于 internal_add 的验证。这将使我们能够比较 Rust 系统的易用性和功能与 Dafny 的。

  • Rust 的证明类 **Cargo.lock** 文件 — 在 Rust 中,我们使用 Cargo.lock 来锁定项目依赖的已知良好组合。我希望当 Dafny 找到一种方法来证明,例如,一个方法时,它能锁定找到的证明步骤。这可以使验证更可靠。

  • 更好的 AI 验证 — 我的直觉是,经过一些改进的 ChatGPT 可能擅长创建 90% 需要的验证代码。我发现当前的 ChatGPT 4 在 Dafny 上表现较差,我认为是因为缺乏 Dafny 的训练示例。

  • 更好的 AI 验证 — 当 AI 生成代码时,我们担心代码的正确性。形式化验证可以通过证明正确性来提供帮助。(有关此的小示例,请参见我的文章 Check AI-Generated Code Perfectly and Automatically。)

感谢你加入我们对程序正确性的探索。我们希望如果你有一个需要证明的算法,这些步骤将帮助你找到该证明。

关注 Carl 在 Medium 上。我在 Rust 和 Python 的科学编程、机器学习和统计学方面写作。我倾向于每月写一篇文章。

阅读 Divyanshu Ranjan 更多的工作,见 他的博客。除了形式化方法,博客还涉及几何、统计学等主题。

2022 年 NLP 初创公司融资情况

原文:towardsdatascience.com/nlp-startup-funding-in-2022-caad77cb0f0?source=collection_archive---------8-----------------------#2023-01-09

Robert DaleTowards Data Science Robert Dale

·

关注 发表在 Towards Data Science ·32 分钟阅读·2023 年 1 月 9 日

--

照片由 Jason Leung 提供,发布在 Unsplash

NLP 技术的商业应用近年来急剧增长已不是什么秘密。从聊天机器人和虚拟助手到机器翻译和情感分析,NLP 技术现在被广泛应用于各种行业。随着对能够处理人类语言的技术需求的增加,投资者也迫不及待地想要参与其中。本文将回顾过去一年 NLP 初创公司的融资情况,识别出获得投资的应用和领域。

这篇文章的一个版本将会在 《自然语言工程杂志》 于 2023 年初刊登。

1. 引言

在策划 This Week in NLP 的内容过程中,这是一份关于 NLP 工具和技术商业应用的通讯(免费订阅请 点击这里!),我跟踪了自然语言处理领域的公司融资和收购情况。在 2022 年,我发现了 340 多个相关的融资事件,从种子轮融资到晚期的 E 轮和 F 轮融资。本文重点关注早期公司:具体而言,是那些报告了种子轮融资、种子轮融资或 A 轮融资的公司。这些公司尚未在市场上建立他们的产品或服务,可以明确地被描述为初创公司;它们对投资者构成了最高风险,同时我们也期望它们能成为创新思想的优秀来源。在我所掌握的数据中,略超过 50%的融资事件发生在种子轮、种子轮或 A 轮阶段;在本文中,我试图对这 173 家公司提供的产品进行一定的组织和结构化,以突显过去十二个月中被认为值得投资的技术和应用领域。

2. 这里的内容

一份包含 170 多家公司的平面列表以及它们的业务内容将会非常难以阅读且缺乏洞察力。因此,本文围绕 NLP 初创公司的分类组织,旨在为对这一领域感兴趣的读者提供帮助。这样一个领域可以有多种结构方式,我并不声称这里使用的分类是唯一的方式;这只是一个对我而言有意义的领域地图,也可能使你更容易确定哪些部分的文章对你最相关。

在理想的世界里,我们可能会选择通过应用类型或应用领域来组织技术产品,将这两者视为两个正交的详尽解决方案。但正如在 我对 2021 年 NLP 初创公司融资的回顾 中所述,我认为采用一种略显不舒服的混合方法更具信息量和实用性,这种方法提取了一种技术类型的层级,但保留了一部分我们审查的公司在应用领域方面的更好组织。因此,本文结构如下:

  • 在第三部分中,我们关注那些主要处理文本而非语音的产品和服务:这包括搜索、信息提取、内容审核、文本生成和机器翻译等子类别。

  • 在第四部分中,我们关注那些以某种方式涉及对话的应用程序:这里识别的子类别包括开发平台、定制开发提供商、对话智能、沟通技能反馈、销售支持和会议生产力工具。

  • 在第五部分中,我们讨论了视听处理,子类别包括语音处理、语音合成和视频合成。

  • 在第六部分中,我们考察了领域特定的解决方案,涵盖了法律科技、健康科技、教育科技以及一些只有一两个参与者的松散领域。

  • 最后,在第七部分中,我们对该领域的发展方向进行了总结性结论。

如果上面列出的类别都是相互排斥的,那就好了,但实际上并非如此;在各个地方都有交叉和边界模糊,因此你可能不同意我选择将某些公司放在特定类别的位置。特别是,第 3、4 和 5 节中讨论的一些公司本可以被视为领域特定应用,但为了识别活动集群,我认为将它们归类于技术类型下更为有用。

还需要进一步说明一些方法学的评论和一些警告事项。

  • 这里呈现的信息来自对约 150 个相关新闻来源的手动和自动处理的综合搜寻。我不太可能捕捉到每一个相关的融资事件,但我相信这些结果是比较全面的;如果你认为我遗漏了一个在 2022 年获得种子轮或 A 轮融资的 NLP 初创公司,请给我发邮件,以便我可以调查遗漏的原因。

  • 我认为一家公司的产品或服务提供 NLP 功能,如果语言处理技术在该产品或服务中扮演了重要角色。不可避免地,有些情况处于边缘地带;例如,许多基于网络的产品现在都集成了简单的聊天机器人功能,但我仅在聊天机器人是一个重要功能或具有有趣和创新的功能时,才将其纳入范围。

  • 我对公司业务的描述基于写作时对每个公司网站的简要审查。但情况可能迅速变化,公司也可能会显著转型,因此根据你阅读本文的时间,任何给定公司的网站现在可能讲述了不同的故事。此外,我尽量在短时间内提供尽可能多的信息,通常不超过一句话;但也有几个例子,我花了过多的时间和精力在一个网站上,试图明确公司提供的内容,结果未能成功,导致描述不够清晰。

  • 对于每家公司,我都标明了公司的成立年份,以及该公司获得了哪些轮次的融资、何时获得融资以及融资金额。所有金额均以美元计,尽管应注意到许多公司位于美国以外,并且获得了其他货币的融资;这里显示的美元等值金额基于撰写时的汇率,因此可能与融资时报告的金额略有不同。

不管怎样,废话不多说。我们开始吧。

3. 文档处理

在行业中使用的术语“文档 AI”通常指的是需要处理文档中的物理格式问题的方法(例如,从表格中提取信息),而“文本处理”通常更关注文档的语言内容,与其物理呈现方式抽象开来。然而,我的印象是,越来越多的解决方案正在将这两种范式结合在一起,因此我在这里将它们视为一个综合类别。

3.1 搜索

许多初创公司提供面向开发者的搜索引擎,这些开发者希望将搜索功能添加到他们的项目中。ZincSearch(成立于 2022 年;种子轮,360 万美元,2022 年 3 月)和Meilisearch(成立于 2018 年;A 轮,1500 万美元,2022 年 10 月)提供可下载的搜索引擎,Meilisearch 还提供完全托管的云版本。SeMI Technologies(成立于 2019 年;A 轮,1650 万美元,2022 年 2 月),Hebbia(成立于 2020 年;A 轮,3000 万美元,2022 年 6 月)和Vectara(成立于 2020 年;种子轮,2000 万美元,2022 年 10 月)强调他们使用向量搜索,也称为神经搜索或语义搜索,这与基于术语索引的旧方法形成对比;Pinecone Systems(成立于 2019 年;A 轮,2800 万美元,2022 年 3 月)提供一种可以作为搜索基础设施的向量数据库产品,Nuclia(成立于 2019 年;种子轮,540 万美元,2022 年 4 月)是一个端到端的 API,允许团队使用他们自己的向量化和标准化算法,同时提供存储、索引和查询功能。Deepset(成立于 2018 年;A 轮,1400 万美元,2022 年 4 月)提供一个开源 NLP 框架 Haystack,帮助开发者为各种搜索用例构建管道,而Opster(成立于 2019 年;A 轮,500 万美元,2022 年 7 月)提供一个自动化和管理企业搜索引擎及数据库的平台。

更大胆的策略是将你的搜索技术定位为对现有技术的替代方案:You.com(成立于 2020 年;A 轮,2500 万美元,2022 年 7 月),旨在成为一个开放的搜索平台,允许其他人基于其搜索技术进行开发,包括 AI 驱动的功能,如 YouCode,可以根据搜索查询生成代码,类似于 GitHub 的 Copilot,以及 YouWrite,由 OpenAI 的 GPT-3 驱动,可以生成文章、博客帖子和模板信件。但更常见的是针对特定用例:Ocean.io(成立于 2017 年;风险轮,630 万美元,2022 年 1 月)和Grata(成立于 2016 年;A 轮,2500 万美元,2022 年 2 月)都旨在帮助企业找到合适的业务目标;Vetted(成立于 2019 年;A 轮,1400 万美元,2022 年 8 月)是一个产品搜索引擎,旨在帮助消费者发现最符合需求的品牌和产品;Outmind(成立于 2019 年;种子轮,210 万美元,2022 年 9 月)专注于在各种工作场所应用中跨相关数据的聚合搜索;Mem(成立于 2021 年;A 轮,2350 万美元,2022 年 11 月)是一个生产力应用程序,能够在用户的笔记中进行搜索;Hypertype(成立于 2021 年;前种子轮,130 万美元,2022 年 5 月)搜索电子邮件档案,以自动化新邮件的撰写。

在文本搜索领域之外,Twelve Labs(成立于 2021 年;种子轮,500 万美元,2022 年 3 月)提供了一个视频搜索和理解平台,利用语义搜索在大规模视频档案中定位相关场景。

3.2 信息提取

提取和汇总信息是许多初创公司的关键关注点:KnowledgeNet.ai(成立于 2021 年;A 轮,940 万美元,2022 年 2 月)旨在通过整合电子邮件、客户关系管理系统、文件存储、职业网络和行业新闻源中的分散对话和数据,来支持交易者和高管;Ask-AI(成立于 2021 年;种子轮,900 万美元,2022 年 10 月)汇总了大量文本公司知识来源和客户沟通,通过问答界面使数据变得易于访问。

现在,通用工作流自动化产品中包含某种程度的文档 AI 能力已相当普遍。NanoNets(成立于 2017 年;A 轮融资,1000 万美元,2022 年 2 月)允许开发者创建可以从文档中提取数据并自动填充数据库的机器学习模型;Krista(成立于 2016 年;A 轮融资,1500 万美元,2022 年 2 月)强调其低代码自动化平台的对话性质;而Alkymi(成立于 2017 年;A 轮融资,2100 万美元,2022 年 10 月)提供一个统一的平台,从各种不同的非结构化数据源中提取数据,并提供大量“蓝图”用于常见文档类型。

一些信息提取产品的关注点更为狭窄:Neuron7.ai(成立于 2020 年;A 轮融资,1000 万美元,2022 年 6 月)将自己定位为服务智能平台,利用技术从组织中的数据和人员中提取信息,并利用这种“集体智能”帮助人们诊断和解决客户问题;Sensible(成立于 2020 年;种子轮融资,650 万美元,2022 年 11 月)提供一个文档编排平台,提供预先构建的模板,用于从 150 多种保险文档类型中提取数据;Stimulus(成立于 2017 年;种子轮融资,250 万美元,2022 年 8 月)是一个关系智能平台,利用数据和分析通过专有的评分机制帮助公司做出更好的采购决策;Prophia(成立于 2018 年;A 轮融资,1020 万美元,2022 年 12 月)的平台搜寻商业房地产合同,并提取关键条款,如平方英尺和租赁日期;以及theGist(成立于 2022 年;前种子轮融资,700 万美元,2022 年 11 月)的首款产品 theGist for Slack,提供 Slack 讨论的结构化、个性化摘要,过滤噪音,以免员工错过重要信息。

3.3 情感分析

情感分析仍然吸引着新的初创公司,通常是在评估和衡量消费者或用户反馈的背景下。一个常见的焦点是通过多渠道或来源聚合和分类反馈,现在使用 AI 模型:Viable(成立于 2020 年;种子轮,500 万美元,2022 年 5 月)在后台使用 GPT-3,提供综合反馈的书面分析服务,以及对反馈的自然语言查询;Idiomatic(成立于 2016 年;种子轮,400 万美元,2022 年 5 月)使用为每个特定业务案例量身定制的模型来分类反馈;Spiral(成立于 2018 年;种子轮,130 万美元,2022 年 11 月)将其反馈技术出售给中大型公司,涉及银行、金融科技、连接设备和保险行业。Lang(成立于 2018 年;A 轮,1050 万美元,2022 年 5 月)、Unwrap(成立于 2021 年;种子轮,320 万美元,2022 年 7 月)和Sturdy AI(成立于 2019 年;种子轮,310 万美元,2022 年 6 月)也类似地对检测到的问题和关注点进行分类,并提供各种形式的分析。

一个相关的重点是品牌管理:My Telescope(成立于 2018 年;前种子轮,260 万美元,2022 年 3 月)是一个市场情报和搜索平台,为营销人员和品牌提供市场趋势、品牌强度和活动效果的长期影响预测,Knit(成立于 2015 年;种子轮,360 万美元,2022 年 6 月)通过年轻消费者网络提供基于视频反馈和定量调查的详细消费者洞察;该公司的视频分析 AI 声称能在几分钟内分析数小时的视频反馈。

3.4 内容审核

尽管以上讨论的技术类别在一定程度上是传统和长期存在的,但“内容审核”是一个近年来才出现的类别,并且随着对虚假信息和有害语言使用问题的关注增加,预计将会增长。

Fairwords(成立于 2014 年;A 轮融资,530 万美元,2022 年 2 月)使用类似拼写检查的界面,提醒用户在输入时有害语言,并提供该语言可能被解释的信息;该软件还可以检测贿赂和腐败、串通及歧视的迹象;mpathic(成立于 2021 年;种子轮融资,400 万美元,2022 年 6 月)类似地帮助员工识别沟通中的潜在误解或曲解,并实时调整;Checkstep(成立于 2020 年;种子轮融资,500 万美元,2022 年 5 月)专注于虚假信息、仇恨言论、儿童性虐待材料(CSAM)、欺凌和垃圾邮件,同时具有版权侵权管理功能;Areto Labs(成立于 2020 年;前种子轮融资,73 万美元,2022 年 6 月)帮助公司识别和监控在线滥用,并通过自动化反制措施,如静音、封锁和举报负责的账户,处理这些问题;Modulate(成立于 2017 年;A 轮融资,3000 万美元,2022 年 8 月)是 ToxMod 的开发者,这是一种用于实时视频游戏语音聊天中检测和处理暴力或其他冒犯性言论的工具。Diversio(成立于 2018 年;A 轮融资,600 万美元,2022 年 1 月)衡量和跟踪关于多样性、公平性和包容性的语言,识别员工编写文本中的“包容性痛点”。

VineSight(成立于 2018 年;种子轮融资,400 万美元,2022 年 9 月)跟踪、分析并减轻针对品牌、活动和事业的在线虚假信息和毒性;Alethea(成立于 2019 年;A 轮融资,1000 万美元,2022 年 11 月)检测和减轻虚假信息和社交媒体操控的实例;Logically(成立于 2017 年;A 轮融资,2400 万美元,2022 年 3 月)将人工智能与专家分析师结合,发现、筛选和应对信息威胁;Pendulum(成立于 2021 年;种子轮融资,590 万美元,2022 年 1 月)的平台利用“叙事跟踪”在多种媒体中揭示叙事形成初期的威胁和机会,并跟踪其在网上传播的情况。

相关领域之一是隐私管理:Redactable(成立于 2018 年;种子轮,120 万美元,2022 年 5 月)和Private AI(成立于 2019 年;A 轮,800 万美元,2022 年 11 月)自动检测文档中的个人可识别信息(PII)并进行删除;Lightbeam.ai(成立于 2020 年;种子轮,450 万美元,2022 年 4 月)尝试识别信息所属的具体客户或身份,以便安全团队可以更有效地自动化保护这些数据;Protopia AI(成立于 2020 年;种子轮,200 万美元,2022 年 12 月)专注于数据在机器学习推理过程中使用时的风险,通过模糊化个人信息来避免信息被识别或泄露给未经授权的第三方。

另一个相关领域是风险管理。Shield(成立于 2018 年;A 轮,1500 万美元,2022 年 1 月)为合规团队提供了一个工作场所智能平台:其技术利用 NLP 来检测员工沟通渠道中的行为违规,如市场操控;Concentric AI(成立于 2018 年;A 轮,1450 万美元,2022 年 5 月)识别并分类敏感信息,通过一种叫做‘风险距离’的度量来评估风险并解决安全问题;VISO Trust(成立于 2020 年;A 轮,1100 万美元,2022 年 3 月)是一个安全尽职调查平台,通过使用文档启发式方法、NLP 和 ML 自动化编制第三方网络风险数据的过程。这些是公司内部关注的解决方案;另一方面,KYP(成立于 2021 年;种子轮,96 万美元,2022 年 10 月)是一个第三方风险情报平台,旨在提供业务所依赖的合作伙伴的完整情况。

3.5 文本生成

如果你错过了最近对生成性 AI 的强烈关注,尤其是大语言模型在文本预测中的应用,那你真是与世隔绝了。今年在这一领域最引人注目的初创公司是内容平台Jasper(成立于 2021 年;A 轮,1.25 亿美元,2022 年 10 月);这是我所知 2022 年最大的一次单笔 NLP 初创公司融资事件。此外,还有Regie.ai(成立于 2020 年;种子轮,480 万美元,2022 年 6 月),其 GPT-3 驱动的文案写作平台专注于销售和营销团队。

仍在吸引资金的解决方案中,有些似乎基于较旧的文本生成方法:Linguix(成立于 2018 年;Pre-seed 轮融资,100 万美元,2022 年 2 月)是一个写作助手,提供拼写和语法检查以及文本重写和各种评分指标;Magical(成立于 2020 年;A 轮融资,3500 万美元,2022 年 6 月)是一个类似文本扩展器的生产力工具,软件可以检测网页上的元素,并允许创建自定义缩写以移动相应的文本。QorusDocs(成立于 2012 年;风险投资轮融资,1000 万美元,2022 年 10 月)是一个基于云的提案管理软件,简化 RFP 响应并自动生成提案;该软件利用 NLP 技术通过从公司文档档案中选择最重要和相关的内容来简化 RFP 响应过程。

相关的还有Mintlify(成立于 2020 年;种子轮融资,280 万美元,2022 年 5 月),其平台读取代码并创建文档以解释代码,并检测用户如何与文档互动以提高其可读性;以及Findable(成立于 2020 年;种子轮融资,210 万美元,2022 年 6 月),其技术通过分析标题、图片和图纸来自动化建筑文档的组织。

3.6 机器翻译

Language I/O(成立于 2011 年;A 轮融资,650 万美元,2022 年 1 月)提供一个翻译平台,允许客户用超过 100 种语言提供实时客户支持;Viva Translate(成立于 2020 年;种子轮融资,400 万美元,2022 年 2 月)是一个跨语言翻译工具,专注于自由职业者与客户沟通中的翻译;Weglot(成立于 2016 年;A 轮融资,4800 万美元,2022 年 3 月)是一个无代码网站本地化技术提供商,其平台支持通过后期编辑功能进行人工优化;XL8(成立于 2019 年;Pre-Series A 轮融资,300 万美元,2022 年 7 月)提供优化的媒体内容机器翻译技术,包括合成配音或语音覆盖;以及WritePath(成立于 2009 年;种子轮融资,34 万美元,2022 年 12 月),是一个基于云的 B2B 翻译平台,针对商业、ESG 和投资者关系披露。

3.7 其他杂项应用

还有一些公司以不同方式处理自然语言文本输入,但这些公司并不完全符合上述已经扩展的类别。在文本到图像领域,有视觉艺术初创公司 Stability AI(成立于 2019 年;种子轮,1.07 亿美元,2022 年 10 月),该公司是 Stable Diffusion 的背后团队。Spiritt(成立于 2020 年;前种子轮,550 万美元,2022 年 7 月)将文本描述转化为应用,通过与聊天机器人的对话获取所需的信息。Zenlytic(成立于 2018 年;种子轮,540 万美元,2022 年 11 月)是一款无代码商业智能工具,提供自然语言界面。还有 Unlikely AI(成立于 2018 年;种子轮,2000 万美元,2022 年 9 月),他们以追求大型神经网络的替代方案为噱头,推出了他们的第一个产品——一个解决和解释隐晦填字谜的应用。

4. 对话 AI

4.1 开发平台

市场上似乎仍然有空间容纳新的自助式对话 AI 开发平台。其中一些平台专注于基于文本的聊天机器人开发:Druid(成立于 2018 年;A 轮,1500 万美元,2022 年 5 月)和 OpenDialog AI(成立于 2019 年;种子轮,480 万美元,2022 年 5 月)提供无代码聊天机器人创建平台;Zowie(成立于 2019 年;种子轮,500 万美元,2022 年 1 月)针对在线销售的企业,将无代码自动化能力与一套工具结合,允许客服人员提供个性化服务和产品推荐。

其他公司增加了语音功能:NLX(成立于 2018 年;种子轮,500 万美元,2022 年 1 月)和 Parloa(成立于 2017 年;种子轮,425 万美元,2022 年 5 月)提供无代码/低代码平台,用于自动化包括电话和聊天在内的全渠道客户服务,Flip CX(最初为 RedRoute;成立于 2017 年;种子轮,650 万美元,2022 年 2 月)强调能够处理语音电话的重要性,提供易于使用的配置工具,利用已经设计好的呼叫流程模式。

4.2 定制开发

也有不少新兴公司会利用他们自己的平台和工具集为你构建对话应用程序。Futr(成立于 2017 年;种子轮融资,250 万美元,2022 年 4 月)强调其平台支持所有社交渠道的多语言实时聊天;Tenyx(成立于 2021 年;种子轮融资,1500 万美元,2022 年 5 月)利用所谓的‘神经科学启发’人工智能构建基于语音的虚拟客服代理;Curious Thing(成立于 2018 年;种子轮融资,470 万美元,2022 年 5 月),其技术之前专注于人力资源相关的互动,现在转向提供更广泛的语音驱动对话人工智能解决方案,包括入站和出站电话;而Tymely(成立于 2020 年;种子轮融资,700 万美元,2022 年 9 月)则使用 AI-人类混合技术来自动化客户服务能力,每个机器生成的响应都由人工代理进行验证。

Chatdesk(成立于 2016 年;A 轮融资,700 万美元,2022 年 1 月)有一个有趣的模式:完全摒弃聊天机器人,它寻找、招聘和培训品牌的‘超级粉丝’成为‘Chatdesk 专家’,并在后台使用机器学习分析之前的支持消息,创建一个符合品牌的知识库,使这些超级粉丝能够以品牌的声音和政策回应客户问题。

4.3 对话智能

继续几年来一直可见的趋势,一些公司提供技术,对对话互动进行某种形式的分析,无论这些互动涉及虚拟代理还是人工代理。

Wiz.ai(成立于 2019 年;A 轮,2000 万美元,2022 年 1 月),专注于东南亚语言的对话 AI,使用前端对话机器人鼓励客户参与对话,同时后端实时筛选数据并将对话中的洞察存储到公司的现有 CRM 系统中以供后续分析;Talkmap(成立于 2017 年;A 轮,800 万美元,2022 年 2 月)对与客户的互动进行标记、结构化和分析,旨在提供接近实时的对话洞察;Affogata(成立于 2018 年;种子轮,950 万美元,2022 年 3 月)提供一个语音分析平台,允许企业识别异常模式,以简化实时响应并采取预防措施;Winn.AI(成立于 2021 年;种子轮,1700 万美元,2022 年 9 月)监控销售通话,自动跟踪、捕获和更新 CRM 条目,减少销售人员自行记笔记的需求;Operative Intelligence(成立于 2021 年;种子轮,350 万美元,2022 年 12 月)提供旨在帮助呼叫中心操作员克服对客户联系原因的误解的技术,通过识别真实原因来减少等待时间并改善问题解决。Jiminny(成立于 2016 年;A 轮,1700 万美元,2022 年 8 月)是一个对话智能平台,分析视频中的情绪,自动评分通话互动并生成实时洞察。

4.4 沟通技巧反馈

对人类代理对话贡献的分析,以提供关于沟通技巧的反馈,可以被视为一种特定形式的对话智能。

Abstrakt(成立于 2020 年;前种子轮,12 万美元,2022 年 3 月)提供实时电话辅导,监听通话并提出有用的建议;Klaus(成立于 2017 年;A 轮,1200 万美元,2022 年 9 月)通过跟踪各种沟通 KPI 来辅导代理,识别辅导机会并衡量支持质量。

Call Simulator(成立于 2021 年;种子轮,57.5 万美元,2022 年 1 月)是一个对话模拟平台,旨在为呼叫中心代理准备现实场景;Second Nature(成立于 2018 年;A 轮,1250 万美元,2022 年 1 月)提供一个模拟器,通过与销售代表对话的虚拟角色来测量代表们对关键话题的覆盖深度。

更广泛地说,Yoodli(成立于 2021 年;种子轮,600 万美元,2022 年 8 月)分析语音以提供改进沟通技能的建议:该平台为用户提供文字记录,并分析填充词的使用、非包容性语言、节奏、肢体语言和其他可操作的见解。该公司最近与 Toastmasters International 达成协议,为其提供演讲辅导,这是一个知名的公众演讲和领导力培训组织。

4.5 销售支持

还有许多公司提供各种形式的我们在这里称之为销售支持的服务。Tactic(成立于 2020 年;种子轮,450 万美元,2022 年 3 月)通过允许销售和营销人员用普通语言询问客户和市场数据,并应用过滤器以优先排序和排名结果来自动化客户和市场研究;Connectly.ai(成立于 2020 年;种子轮,金额未公开,2022 年 7 月)是一个无需编码的工具,允许企业通过 AI 驱动的“小型机器人”创建和发送互动和个性化的营销活动;Demoleap(成立于 2020 年;种子轮,440 万美元,2022 年 8 月)是一个现场演示助手和销售发现平台,指导销售人员在现场演示过程中遵循销售流程;Heyday(成立于 2021 年;种子轮,650 万美元,2022 年 6 月)是一个用于零售商的对话 AI 平台,自动化 FAQ;AdTonos(成立于 2016 年;种子轮,210 万美元,2022 年 8 月)通过其 YoursTruly 平台通过智能音响和移动设备播放互动广告来货币化音频流。

4.6 会议生产力工具

由于 Covid 驱动的虚拟会议平台(如 Zoom 和 Teams)的使用增加,出现了一个相对较新的工具类别市场,这些工具旨在支持会议生产力;在许多方面,这些工具是对在对话智能背景下开发的工具和技术的重新利用。

Sembly AI(成立于 2019 年;种子轮,金额未公开,2022 年 3 月)、Headroom(成立于 2020 年;种子轮,900 万美元,2022 年 8 月)、Xembly(成立于 2020 年;A 轮,1500 万美元,2022 年 10 月)、Fathom(成立于 2020 年;种子轮,470 万美元,2022 年 11 月)和 tl;dv(成立于 2020 年;种子轮,460 万美元,2022 年 6 月)都提供一些功能组合,用于转录和分析会议,提取主题和行动项目,以及生成摘要和会议记录。

Airgram(成立于 2020 年;A 轮融资,1000 万美元,2022 年 8 月)是一款音视频录制工具,可以设置为自动加入预定的 Zoom、Google Meet 或 Microsoft Teams 会议,在用户不在场时进行录制;该工具提供灵活的播放选项,并配有转录、话题检测和行动项识别功能;Amy(成立于 2019 年;种子轮融资,600 万美元,2022 年 6 月)是一个销售智能平台,旨在通过利用公开的潜在客户信息来简化会议准备,将这些数据转化为会议简报,提供对潜在客户的有价值洞察。

5.2 音视频处理

我们引入这一类别以涵盖语音技术除支持对话式人工智能之外的应用,同时也包括与视频结合使用的情况。

5.1 语音处理

NeuralSpace(成立于 2019 年;种子轮融资,170 万美元,2022 年 2 月)专注于低资源语言的语音技术开发,提供覆盖 90 多种语言的自助工具包,并包括自动语言检测;Ava(成立于 2014 年;A 轮融资,1000 万美元,2022 年 3 月)是一款实时字幕平台,能够在会议或视频中听取音频,为听障人士提供字幕,并标记每条字幕的发言者;Sounder(成立于 2019 年;A 轮融资,770 万美元,2022 年 2 月)是一款端到端的播客管理平台,涵盖品牌安全和品牌适宜性分析、话题分析、内容总结和动态分段;AssemblyAI(成立于 2017 年;A 轮融资,2800 万美元,2022 年 3 月)提供一组基于 LLM 的“音频智能”API,用于转录和理解音频数据,应用包括内容审核、情感检测、总结和个人信息遮盖。

Sanas(成立于 2020 年;A 轮融资,3200 万美元,2022 年 6 月)提供实时口音翻译,帮助多语言用户通过口音矫正实现清晰沟通;Namecoach(成立于 2014 年;A 轮融资,800 万美元,2022 年 11 月)提供嵌入上下文感知音频姓名发音按钮的软件,使用户能够自信地发音。

5.3 语音合成

Murf AI(成立于 2020 年;A 轮融资,1000 万美元,2022 年 9 月)是一家合成语音技术初创公司,开发逼真的 AI 语音用于播客、幻灯片演示和专业演讲,拥有 120 多种语言的精选语音库。在具体应用层面,ping(成立于 2016 年;种子轮融资,500 万美元,2022 年 6 月)允许商业司机听到他们的智能手机消息和电子邮件以超过 105 种语言朗读。

语音合成的一大用途是在其他语言中进行音频配音。Dubverse(成立于 2021 年;种子轮,80 万美元,2022 年 6 月)是一个自动化配音平台,允许用户几乎实时地将视频配音成多种语言,目前支持 10 种印度语言和 20 种“全球”语言;Dubdub(成立于 2021 年;种子轮,100 万美元,2022 年 9 月)使用人工智能和机器学习为企业创建多语言视频内容,覆盖 40 种语言;Deepdub(成立于 2019 年;A 轮,2000 万美元,2022 年 2 月)提供娱乐内容的配音服务,使用合成的原演员声音版本,使配音版本听起来更像原版;Papercup(成立于 2017 年;A 轮,2000 万美元,2022 年 6 月)类似地通过生成听起来像原讲者的声音来翻译视频。这些应用通常提供一个人工环节功能,专业翻译人员可以执行质量检查,编辑和修订翻译及语音,以提高质量。

5.3 视频合成

我们在此包含了那些专注于视频输出创作的公司,因为这些公司通常也涉及语音合成。

Pictory(成立于 2019 年;种子轮,210 万美元,2022 年 1 月)将长形式内容如网络研讨会、博客和白皮书转换为短社交视频;ShortTok(成立于 2021 年;前种子轮,未披露金额,2022 年 10 月)开发自动化视觉讲故事技术,从客户的视频和多模态内容库中创建短视频;Peech(成立于 2020 年;种子轮,830 万美元,2022 年 8 月)提供一个视频编辑工具,专为内容营销团队设计,可以自动合成与内容匹配的品牌视觉,并去除填充词;Rephrase.ai(成立于 2019 年;A 轮,1060 万美元,2022 年 9 月)也为营销和内容团队构建生成式人工智能工具,用于合成视频制作。

一种特定形式的视频合成是虚拟角色的生成。Metaphysic(成立于 2021 年;种子轮,750 万美元,2022 年 1 月),这家公司以其汤姆·克鲁斯深度伪造而闻名,开发用于创建可以融入元宇宙的数字头像的工具;Inworld AI(成立于 2021 年;种子轮,1250 万美元,2022 年 3 月)是另一个创建人工智能驱动的虚拟角色、沉浸式现实和元宇宙空间的平台,使非技术用户能够通过自然语言描述来创建角色个性;Deep Voodoo(成立于 2020 年;种子轮,2000 万美元,2022 年 12 月)是由《南方公园》创作者特雷·帕克和马特·斯通创办的深度伪造初创公司。

Speech Graphics(成立于 2010 年;A 轮,700 万美元,2022 年 2 月)提供基于音频的面部动画技术,使游戏和其他应用中的动画角色在讲话时能够正确地移动嘴巴;Hour One(成立于 2019 年;A 轮,2000 万美元,2022 年 4 月)的技术将人类转化为虚拟人类角色,这些角色可以以逼真的表现力被激活。Carter(成立于 2022 年;种子轮,200 万美元,2022 年 12 月)正在研发对话式 AI,以帮助游戏开发者使计算机化的游戏角色更具生动性。NeuralGarage(成立于 2021 年;种子轮,150 万美元,2022 年 11 月)是一个视频配音平台:给定音频输入和人脸,它会将人的唇部和下巴动作转化为匹配的语言,无论语言是什么。

6. 领域特定解决方案

6.1 法律技术

法律与语言密切相关,因此法律技术长期以来一直是语言处理技术应用的重要领域。

一个受欢迎的领域是文档分析和审查,其中 AI 支持的分析可以减少传统手动处理所需的大量时间。TermScout(成立于 2018 年;种子轮,500 万美元,2022 年 5 月)从合同中提取关键信息,以便于审查、评级和与行业标准的比较;Terzo(成立于 2020 年;A 轮,1630 万美元,2022 年 11 月)从合同中提取关键数据,帮助组织优化其供应商和客户关系中的支出和收入;Nammu21(成立于 2017 年;A 轮,1580 万美元,2022 年 10 月)将贷款文件拆解成结构化数据;Summize(成立于 2018 年;A 轮,600 万美元,2022 年 10 月)是一种合同审查解决方案,旨在通过与 Teams、MS Word 和 Slack 的集成来改善内部法律部门与业务用户之间的协作;以及 Della(成立于 2018 年;种子轮,250 万美元,2022 年 3 月)专注于复杂的单一文档,而不是大型文档审查项目。Zero(成立于 2014 年;A 轮,1200 万美元,2022 年 3 月)是基于 iOS 移动设备的生产力工具,集成了电子邮件收件箱和文档管理系统,提取关键信息,如可计费的交互,并自动将电子邮件归档到文件夹中。

另一个热门领域是提供法律文件起草支持。Henchman(成立于 2020 年;种子轮,320 万美元,2022 年 2 月)是一个合同起草初创公司,提供一个 Microsoft Word 插件,在你工作时从公司的数据库中建议条款;LexCheck(成立于 2015 年;种子轮,500 万美元,2022 年 3 月)提供一个合同谈判解决方案,分析合同以建立问题清单和合同语言修订;Harvey(成立于 2022 年;种子轮,500 万美元,2022 年 11 月)使用 GPT-3 根据任务描述为律师起草文档;该应用还可以回答法律问题。

许多公司将这些审查和起草功能与其他活动结合起来,以提供更全面的法律自动化平台。Uhura Solutions(成立于 2018 年;种子轮,180 万美元,2022 年 4 月)是一个低代码合同智能平台,使用 NLP 来简化合同和协议的分析及起草过程;Goodlegal(成立于 2021 年;前种子轮,130 万美元,2022 年 11 月)提供一套自动化工具,包括一个用于构建法律文本的拖放编辑器,并能够检查每个法律文本是否符合合法标准;PocketLaw(成立于 2018 年;A 轮,1060 万美元,2022 年 5 月)是一个主要面向中小企业的合同自动化 SaaS 法律技术平台;Klarity(成立于 2017 年;A 轮,1800 万美元,2022 年 1 月)为财务和会计团队提供自动化文档处理和管理平台;Josef(成立于 2017 年;种子轮,520 万美元,2022 年 11 月)是一个无代码软件平台,允许法律专业人士自动化重复任务,包括文档起草、提供法律指导和建议,并构建客户访谈的机器人。Legal OS(成立于 2018 年;种子轮,700 万美元,2022 年 1 月)是一个无代码法律自动化平台,将专家知识转化为数字知识图谱,之后可以用来构建各种法律产品和流程。

还有一些法律科技解决方案不完全符合上述类别。Alchemy Machines(成立于 2021 年;种子轮,40 万美元,2022 年 3 月)使用 NLP 和语音识别转录、分析和总结法律特定的网络会议和电话;Neur.on(成立于 2022 年;种子轮,170 万美元,2022 年 8 月)为法律专业人士提供定制的机器翻译解决方案;Ex Parte(成立于 2017 年;A 轮,750 万美元,2022 年 2 月)使用机器学习预测诉讼结果,推荐客户可以采取的行动以优化胜诉机会;而Proof Technology(成立于 2017 年;A 轮,550 万美元,2022 年 3 月)是一个相当独特的端到端解决方案,分析法院文件以提取案件标题信息,确定离被告或证人地址最近的送达人员,远程打印相关材料,并捕捉有关送达尝试的照片和描述数据。

6.2 健康科技

另一个与 NLP 有长期关系的领域是健康科技。这里的两个关键领域是医疗环境中对话式 AI 的使用以及医疗记录的处理。

在对话式 AI 方面,HeyRenee(成立于 2021 年;种子轮,440 万美元,2022 年 1 月)是一个以患者为中心的个人健康助理,可以提醒用户需要服用的药物,监测健康指标,处理处方续药,并安排与医生的虚拟或面对面访谈;Apowiser(成立于 2021 年;种子轮,150 万美元,2022 年 6 月)制作了 PharmAssist,一个基于聊天机器人的系统,用于支持客户在线购买非处方药,识别需要升级到医疗服务提供者的问题,以及检查对 OTC 药物成分的过敏和潜在敏感性;BirchAI(成立于 2020 年;种子轮,310 万美元,2022 年 1 月)旨在通过总结和分析客户与代表之间电话交谈的内容,简化医疗公司的客户支持;WhizAI(成立于 2017 年;A 轮,800 万美元,2022 年 9 月)提供了一个面向生命科学和医疗行业的分析平台的对话接口;而Kahun(成立于 2018 年;种子轮,800 万美元,2022 年 9 月)开发了一个临床评估聊天机器人,基于公司超过 3000 万条证据基础的医学见解地图。

关于医疗记录处理,DigitalOwl(成立于 2017 年;A 轮融资,2000 万美元,2022 年 1 月)提供一个医疗记录分析平台,能够从大量文档中提取相关信息;Dyania Health(成立于 2019 年;种子轮融资,530 万美元,2022 年 9 月)提供一个 NLP 平台,执行疾病专用的临床文本提取;Wisedocs(成立于 2018 年;种子轮融资,300 万美元,2022 年 3 月)使用智能字符识别技术读取和分析各种医疗记录审查相关文档;XpertDox(成立于 2015 年;种子轮融资,150 万美元,2022 年 8 月)开发了 XpertCoding 工具,该工具利用 AI 自动编码医疗索赔。DeepScribe(成立于 2017 年;A 轮融资,3000 万美元,2022 年 1 月)是一种环境医疗抄写员,记录医生与患者的对话,汇总并将其整合到健康记录系统中;Abridge(成立于 2018 年;A 轮融资,1250 万美元,2022 年 8 月)是一家对话 AI 初创公司,结构化并总结医生和患者的医疗对话,帮助填充健康记录中的相关信息;Eleos Health(成立于 2019 年;A 轮融资,2000 万美元,2022 年 4 月)构建了在行为健康临床医师与患者对话背景中环境运行的临床应用,生成会后临床进展记录和保险编码。

最后,一些不符合上述类别的健康科技初创公司:Kintsugi(成立于 2019 年;A 轮融资,2000 万美元,2022 年 2 月)利用机器学习和声音生物标志物检测临床抑郁症和焦虑的迹象;Marigold Health(成立于 2016 年;种子轮融资,600 万美元,2022 年 2 月)通过聊天支持小组帮助从物质使用或心理健康状况中恢复的个人,NLP 辅助同行管理他们的在线社区;以及 WeWalk(成立于 2019 年;资助,200 万美元,2022 年 7 月)是一家为视障人士开发智能手杖的初创公司:其语音助手可以回答有关用户位置、附近公共交通、识别附近建筑物和地标、预约 Uber,并提供适合视力有限或无视力者的实时步行导航。

6.3 其他领域

第三个领域是教育技术。FoondaMate(成立于 2020 年;种子轮融资,200 万美元,2022 年 5 月)是一个聊天机器人,通过 Facebook 和 WhatsApp 让发展中国家的学生可以提问,从而使教育变得更加可及; Prof Jim(成立于 2020 年;种子轮融资,110 万美元,2022 年 1 月)与教科书出版商和教育提供商合作,将教科书和其他文本学习材料转化为在线课程,包括自动生成的评估和头像讲师; Language Confidence(成立于 2016 年;种子轮融资,150 万美元,2022 年 3 月)提供一个 API,监听学生并评估和纠正他们的英语发音,提供视觉反馈; Copyleaks(成立于 2015 年;A 轮融资,600 万美元,2022 年 5 月)是一个抄袭检测解决方案,能够识别和跟踪 100 多种语言的在线抄袭内容。

还有一系列其他针对不同领域的应用:

  • 金融服务: Aviva(成立于 2022 年;种子轮融资,220 万美元,2022 年 12 月)使用自然语言处理将客户的口语与实时信用申请中的字段匹配; Webio(成立于 2016 年;A 轮融资,400 万美元,2022 年 6 月)为信用、催收和支付业务提供无代码对话式人工智能平台,允许客户提问、改变付款日期或安排新的还款计划。

  • 房地产:DOSS(成立于 2015 年;种子轮融资,金额未公开,2022 年 10 月)提供一个对话助手,允许客户咨询房地产建议和提示,搜索房源,并获取邻里信息和近期销售数据。

  • DevOps: Kubiya(成立于 2022 年;种子轮融资,600 万美元,2022 年 10 月)为 DevOps 团队提供对话式人工智能解决方案,允许用户用自然语言表达意图,并让虚拟助手自动化简单和繁琐的任务。

  • 无桌面工作环境: Datch(成立于 2018 年;A 轮融资,1000 万美元,2022 年 7 月)在工业环境中作为智能语音接口运作。

  • 零售:Evabot(成立于 2016 年;A 轮融资,830 万美元,2022 年 7 月)为企业礼品提供创意,通过聊天机器人管理的问卷建议最适合用户客户的礼物;它还利用 GPT-3 为每个礼物撰写个性化的‘手写’笔记。

  • 餐馆:Valyant AI(成立于 2017 年;种子轮,400 万美元,2022 年 4 月)开发了一个用于餐馆、零售和服务行业的专有对话 AI 平台;ConverseNow(成立于 2018 年;A 轮融资,1000 万美元,2022 年 8 月)提供了一个对话平台,用于餐馆自动化从高容量语音渠道中接单的过程。

  • 儿童:Snorble(成立于 2019 年;种子轮,1000 万美元,2022 年 4 月)制造了一款儿童就寝机器人,配备了可以讲故事、带孩子进行呼吸练习,并播放伴有灯光秀的舒缓音乐的语音驱动助手。

7. 总结

所以你可以看到:2022 年有 173 家公司获得了用于利用 NLP 技术的工具和应用程序的启动资金,总投资额刚刚超过 18 亿美元。表 1 展示了根据这里用来结构化领域的类别进行的投资细分,但请记住第二部分中关于覆盖范围和分类的各种警告。最好的情况是,我们可以将这一区分视为大致指示活动所在的地方。

对我来说,有几个点特别引人注目。以下是我从这次活动中总结出的十大要点。

  1. 尽管谷歌在搜索领域几乎垄断——通常报道显示其占有超过 80%的搜索引擎使用率——投资者仍然认为搜索领域值得投入资金。这里的大部分活动并不直接与谷歌竞争,而是更多地面向企业搜索和其他特定用途,但正当这篇文章撰写时,据报道谷歌的管理层已宣布‘红色警戒’,以应对 OpenAI 的 ChatGPT 问答聊天机器人的出现,以及担心这种方法可能会重新发明或取代传统的互联网搜索引擎。值得注意的是,You.com 迅速将类似 ChatGPT 的模型和界面集成到其搜索引擎中Perplexity.aiNeeva也在尝试将传统搜索与 LLM 结合起来。

  2. 文档 AI 似乎因深度学习技术的应用而经历了近期的复兴;越来越多的应用表明,文档处理领域的实际解决方案不能仅限于无实体的文本,而必须采取“整体文档”立场,以产生价值。我怀疑我们会看到智能字符识别与基于大型语言模型的文本处理的进一步进展,尽管目前还处于初期阶段:我今年早些时候有机会尝试了多种文档 AI 产品,但对它们在从被认为相当简单的表格中可靠提取信息的困难感到失望。

  3. 广义上的情感分析关注的是语言使用的实用方面。我们已经从早期仅仅确定电影评论或产品推荐的极性中走了很远,情感分析技术的变体在两个关键领域引起了显著关注,这两个领域在十年前无人察觉:沟通效果分析(无论是在销售电话还是会议场合中),以及内容审核。特别是后者在面对社交媒体平台的分裂性和恶意性,以及像欧洲数字服务法这样的法规影响时,似乎充满了增长的潜力。

  4. 文本生成的商业应用正处于剧烈的范式转变的边缘,过去十年中以模板为基础的技术相比于最近在大型语言模型基础上的文本生成技术显得如此平凡,以至于它们现在难以被认为是 AI 的一部分。这些早期的方法在可预见的未来仍将发挥作用,但它们正越来越远离前沿。当然,我们仍然需要克服大型语言模型不是可靠真相来源的问题。

  5. 机器翻译感觉上像是一个基本解决的问题,或者至少是一个在投资前值得再三考虑的问题。总是有改进的空间,但这些改进可能来自于谷歌和微软等大型公司之间的竞争;从投资者的角度来看,未来在于与其他技术的有趣配置功能的打包,或者解决有趣且创新的用例。

  6. 对话 AI 通常是一个令我困惑的领域:这是一个非常密集的空间,往往很难看出任何给定公司的独特卖点。我不愿意在这个领域寻找开发工具或应用程序开发人员,我也很难理解这里的投资决策动机。也许像 ChatGPT 这样的技术会在这个领域带来变化,尽管再次提到的关于需要说实话的棘手警告仍然存在。就目前而言,我认为这里的创新发展空间在于对话分析,以服务于目标,比如识别问题、沟通技能培训以及我们尚未想到的其他用途。

  7. 关于 Covid 没有什么积极的说法,但它确实推动了在线会议技术应用程序的发展,如 Zoom、Microsoft Teams 和 Google Meet;这反过来又开辟了一个全新的技术领域,即会议生产力工具。如上所述,这些工具通常是为对话智能开发的多方变体,但鉴于其处理的内容性质——我们可能称之为“长篇对话”——它提供了利用几乎所有类型的 NLP 技术的机会。我认为这是一个值得密切关注的领域。

  8. 与几年前相比,语音合成现在达到了非常高的标准,我认为这是一个我们只能期待逐步改进的领域。现在语音配音已经成为了一个重点,尤其是在与视频中的合成唇动相结合时;例如 Netflix 的整个目录中逼真的视听翻译潜力巨大。

  9. 法律科技和健康科技一直是 NLP 的关键应用领域,并且可能会继续如此。值得注意的是——从我来看,这有点令人担忧——LLMs 正在进入法律科技领域,健康科技也肯定会跟随。我不想听起来像是老生常谈,但如果要出现问题,这就是问题显现的地方:疲惫或不专注的人类编辑可能会在 LLM 起草的合同中漏掉严重的不实信息,导致昂贵的修正,或者——更糟糕——在健康报告中,导致伤害甚至生命丧失。我们可能会看到支持者争辩,类似于自动驾驶汽车的主张,这些 LLMs 的使用在这些背景下总体上减少了错误和损害。

  10. 教育部门将如何应对大型语言模型?ChatGPT 引发了关于“大学论文终结”的媒体广泛报道,担心代写服务将变得对即使是最贫困的学生也极为容易获得;就在这篇文章撰写时,媒体纷纷报道了一位南卡罗来纳州教授发现一名学生使用该应用写哲学论文的事件。这将值得关注:技术要么朝着解决传统教育关切的方向发展,要么传统教育评估实践将不得不进行根本性的改变。

总体而言,2022 年引入了一些引人入胜的技术和解决方案。随着 GPT-4 的到来,2023 年有望成为更有趣的一年。

如果你想跟上商业 NLP 世界的最新动态,可以考虑订阅免费的This Week in NLP 新闻通讯,网址是 www.languagetechnology.com/twin

使用 Python 进行 NLP:知识图谱

原文:towardsdatascience.com/nlp-with-python-knowledge-graph-12b93146a458

作者提供的图片

SpaCy、句子分割、词性标注、Dependency parsing、命名实体识别等……

Mauro Di PietroTowards Data Science Mauro Di Pietro

·发表于Towards Data Science ·14 分钟阅读·2023 年 4 月 19 日

--

摘要

在这篇文章中,我将展示如何使用 Python 和自然语言处理构建知识图谱。

照片由Moritz Kindler提供,Unsplash

网络图是一种数学结构,用于展示点之间的关系,可以通过无向图/有向图结构进行可视化。这是一种映射连接节点的数据库形式。

知识库是来自不同来源的信息的统一存储库,如维基百科

知识图谱是一种使用图结构数据模型的知识库。简单来说,它是一种网络图,展示了现实世界实体、事实、概念和事件之间的定性关系。“知识图谱”一词首次由谷歌在 2012 年使用,以介绍他们的模型

作者提供的图片

目前,大多数公司正在构建数据湖,这是一个中央数据库,用于存储从不同来源获取的各种原始数据(即结构化和非结构化数据)。因此,人们需要工具来理解这些不同信息片段。知识图谱变得越来越流行,因为它们可以简化对大型数据集的探索和洞察发现。换句话说,知识图谱将数据和相关的元数据连接起来,因此可以用来构建组织信息资产的全面表示。例如,知识图谱可以替代你需要浏览的所有文档堆,以便找到某一特定信息。

知识图谱被认为是自然语言处理领域的一部分,因为为了构建“知识”,必须经过一个叫做“语义丰富化”的过程。由于没有人愿意手动执行这个过程,我们需要机器和 NLP 算法来为我们完成这项任务。

我将展示一些有用的 Python 代码,这些代码可以轻松地应用于其他类似的情况(只需复制、粘贴、运行),并逐行讲解代码的注释,以便你可以复制这个示例(完整代码的链接如下)。

## DataScience_ArtificialIntelligence_Utils/example_knowledge_graph.ipynb at master ·…

你现在无法执行该操作。你在另一个标签页或窗口中登录了。在另一个标签页中注销了…

github.com

我将解析维基百科并提取一个页面,该页面将用作本教程的数据集(链接如下)。

## 俄乌战争 - 维基百科

俄乌战争是俄罗斯及其支持的分裂主义者之间正在进行的国际冲突…

维基百科

具体来说,我将介绍:

  • 设置:通过Wikipedia-API读取数据和数据包。

  • 使用SpaCy进行 NLP:句子分割、词性标注、依存句法分析、命名实体识别。

  • 使用Textacy提取实体及其关系。

  • 带有NetworkX的网络图构建。

  • 带有DateParser的时间线图。

设置

首先,我需要导入以下库:

## for data
import pandas as pd  #1.1.5
import numpy as np  #1.21.0

## for plotting
import matplotlib.pyplot as plt  #3.3.2

## for text
import wikipediaapi  #0.5.8
import nltk  #3.8.1
import re   

## for nlp
import spacy  #3.5.0
from spacy import displacy
import textacy  #0.12.0

## for graph
import networkx as nx  #3.0 (also pygraphviz==1.10)

## for timeline
import dateparser #1.1.7

Wikipedia-api 是一个 Python 包装器,可以轻松解析 Wikipedia 页面。我将提取我需要的页面,排除页面底部的所有“注释”和“参考书目”:

来源于 Wikipedia

我们可以简单地写出页面的名称:

topic = "Russo-Ukrainian War"

wiki = wikipediaapi.Wikipedia('en')
page = wiki.page(topic)
txt = page.text[:page.text.find("See also")]
txt[0:500] + " ..."

在这个用例中,我将尝试通过识别和提取文本中的主题-动作-对象(因此动作即为关系)来映射历史事件。

NLP

为了构建知识图谱,我们首先需要识别实体及其关系。因此,我们需要使用 NLP 技术处理文本数据集。

当前,用于此类任务的最常用库是 SpaCy,这是一个用于高级 NLP 的开源软件,利用 Cython (C+Python)。SpaCy 使用预训练的语言模型将文本分词,并将其转换为一个通常称为 “document” 的对象,基本上是一个包含模型预测的所有注释的类。

#python -m spacy download en_core_web_sm

nlp = spacy.load("en_core_web_sm")
doc = nlp(txt)

NLP 模型的第一个输出是 句子分割:确定一个句子开始和结束的位置的问题。通常,通过基于标点符号拆分段落来完成。让我们看看 SpaCy 将文本拆分成了多少个句子:

# from text to a list of sentences
lst_docs = [sent for sent in doc.sents]
print("tot sentences:", len(lst_docs))

作者提供的图片

现在,对于每个句子,我们将提取实体及其关系。为此,我们首先需要理解 词性标注 (POS tagging): 将句子中的每个单词标记上适当的语法标签的过程。以下是可能的标签的完整列表(截至今天):

  • ADJ: 形容词,例如 big, old, green, incomprehensible, first

  • ADP: 介词(前置词/后置词)例如 in, to, during

  • ADV: 副词,例如 very, tomorrow, down, where, there

  • AUX: 助动词,例如 is, has (done), will (do), should (do)

  • CONJ: 连词,例如 and, or, but

  • CCONJ: 并列连词,例如 and, or, but

  • DET: 限定词,例如 a, an, the

  • INTJ: 感叹词,例如 psst, ouch, bravo, hello

  • NOUN: 名词,例如 girl, cat, tree, air, beauty

  • NUM: 数字,例如 1, 2017, one, seventy-seven, IV, MMXIV

  • PART: 语气词,例如 ‘s, not

  • PRON: 代词,例如 I, you, he, she, myself, themselves, somebody

  • PROPN: 专有名词,例如 Mary, John, London, NATO, HBO

  • PUNCT: 标点符号,例如 ., (, ), ?

  • SCONJ: 从属连词,例如 if, while, that

  • SYM: 符号,例如 $, %, §, ©, +, −, ×, ÷, =, 😃, 表情符号

  • VERB: 动词,例如 run, runs, running, eat, ate, eating

  • X: 其他,例如 sfpksdpsxmsa

  • SPACE: 空格,例如

仅仅进行词性标注是不够的,模型还试图理解词对之间的关系。这项任务称为 依存句法分析。以下是所有可能的标记列表(截至今天):

  • 从句修饰名词: clausal modifier of noun

  • 形容词补足语: adjectival complement

  • 副词从句修饰语: adverbial clause modifier

  • 副词修饰语: adverbial modifier

  • 施事者: agent

  • 形容词修饰语: adjectival modifier

  • 同位修饰语: appositional modifier

  • 属性: attribute

  • 辅助词: auxiliary

  • 辅助动词(被动): auxiliary (passive)

  • 格标记: case marker

  • 并列连词: coordinating conjunction

  • 从句补足语: clausal complement

  • 复合修饰语: compound modifier

  • 连接词: conjunct

  • 从句主语: clausal subject

  • 从句主语(被动): clausal subject (passive)

  • 与格: dative

  • 依存词: unclassified dependent

  • 限定词: determiner

  • 直接宾语: direct object

  • 虚词: expletive

  • 感叹词: interjection

  • 标记: marker

  • 元修饰语: meta modifier

  • 否定修饰语: negation modifier

  • 名词修饰语: modifier of nominal

  • 名词短语作为副词修饰语: noun phrase as adverbial modifier

  • 名词主语: nominal subject

  • 名词主语(被动): nominal subject (passive)

  • 数量修饰语: number modifier

  • 对象谓词: object predicate

  • 并列语法: parataxis

  • 介词补足语: complement of preposition

  • 介词宾语: object of preposition

  • 所有格修饰语: possession modifier

  • 前关联连词: pre-correlative conjunction

  • 前限定词: pre-determiner

  • 介词修饰语: prepositional modifier

  • 词素: particle

  • 标点: punctuation

  • 数量词修饰语: modifier of quantifier

  • 关系从句修饰语: relative clause modifier

  • 根: root

  • 开放从句补足语: open clausal complement

让我们举个例子来理解词性标注和依存句法分析:

# take a sentence
i = 3
lst_docs[i]

让我们检查一下 NLP 模型预测的词性和依存标记:

for token in lst_docs[i]:
    print(token.text, "-->", "pos: "+token.pos_, "|", "dep: "+token.dep_, "")

作者提供的图片

SpaCy 还提供了一个 图形工具 来可视化这些标注:

from spacy import displacy

displacy.render(lst_docs[i], style="dep", options={"distance":100})

作者提供的图片

最重要的词素是动词(POS=VERB),因为它是句子意义的根源(DEP=ROOT)。

作者提供的图片

辅助粒子,例如副词和介词(POS=ADV/ADP),通常作为修饰语与动词相连(DEP=mod),因为它们可以修饰动词的意义。例如,“travel to” 和 “travel from” 尽管根词相同(“travel*”),却有不同的含义。

作者提供的图片

在与动词相关联的词中,必须有一些名词(POS=PROPN/NOUN),它们充当句子的主语和宾语(DEP=nsubj/obj*)。

作者提供的图片

名词通常靠近一个形容词(POS=ADJ),这个形容词充当它们意义的修饰词(DEP=amod)。例如,在“好人”和“坏人”中,形容词给名词“人”带来了相反的含义。

作者提供的图片

SpaCy 执行的另一个酷炫任务是 命名实体识别(NER)。命名实体是一个“现实世界的对象”(即人、国家、产品、日期),模型可以识别文档中的各种类型。以下是可能标签的完整列表(截至今日):

  • 人物: 人,包括虚构的。

  • 国家/宗教/政治组织: 国籍、宗教或政治团体。

  • 设施: 建筑物、机场、公路、桥梁等。

  • 组织: 公司、机构、组织等。

  • GPE: 国家、城市、州。

  • 地点: 非 GPE 位置、山脉、水体。

  • 产品: 物品、车辆、食品等。(不是服务。)

  • 事件: 命名的飓风、战役、战争、体育赛事等。

  • 艺术作品: 书籍、歌曲等的标题。

  • 法律: 成为法律的命名文档。

  • 语言: 任何命名的语言。

  • 日期: 绝对或相对的日期或时期。

  • 时间: 小于一天的时间。

  • 百分比: 包括“%”的百分数。

  • 货币: 包括单位的货币值。

  • 数量: 以重量或距离为单位的测量值。

  • 序数: “第一”、“第二”等。

  • 基数: 不属于其他类型的数字。

让我们看看我们的例子:

for tag in lst_docs[i].ents:
    print(tag.text, f"({tag.label_})") 

作者提供的图片

或者使用 SpaCy 的图形工具:

displacy.render(lst_docs[i], style="ent")

作者提供的图片

这在我们想要向知识图谱中添加几个属性时非常有用。

继续,使用 NLP 模型预测的标签,我们可以提取实体及其关系。

实体与关系提取

这个想法非常简单,但实现起来可能会很棘手。对于每个句子,我们将提取主语和宾语及其修饰词、复合词和它们之间的标点符号。

这可以通过两种方式完成:

  1. 手动地,你可以从基线代码开始,这些代码可能需要稍微修改和适应你的特定数据集/用例。
def extract_entities(doc):
    a, b, prev_dep, prev_txt, prefix, modifier = "", "", "", "", "", ""
    for token in doc:
        if token.dep_ != "punct":
            ## prexif --> prev_compound + compound
            if token.dep_ == "compound":
                prefix = prev_txt +" "+ token.text if prev_dep == "compound" else token.text

            ## modifier --> prev_compound + %mod
            if token.dep_.endswith("mod") == True:
                modifier = prev_txt +" "+ token.text if prev_dep == "compound" else token.text

            ## subject --> modifier + prefix + %subj
            if token.dep_.find("subj") == True:
                a = modifier +" "+ prefix + " "+ token.text
                prefix, modifier, prev_dep, prev_txt = "", "", "", ""

            ## if object --> modifier + prefix + %obj
            if token.dep_.find("obj") == True:
                b = modifier +" "+ prefix +" "+ token.text

            prev_dep, prev_txt = token.dep_, token.text

    # clean
    a = " ".join([i for i in a.split()])
    b = " ".join([i for i in b.split()])
    return (a.strip(), b.strip())

# The relation extraction requires the rule-based matching tool, 
# an improved version of regular expressions on raw text.
def extract_relation(doc, nlp):
    matcher = spacy.matcher.Matcher(nlp.vocab)
    p1 = [{'DEP':'ROOT'}, 
          {'DEP':'prep', 'OP':"?"},
          {'DEP':'agent', 'OP':"?"},
          {'POS':'ADJ', 'OP':"?"}] 
    matcher.add(key="matching_1", patterns=[p1]) 
    matches = matcher(doc)
    k = len(matches) - 1
    span = doc[matches[k][1]:matches[k][2]] 
    return span.text

让我们在这个数据集上试试,并查看常见的例子:

## extract entities
lst_entities = [extract_entities(i) for i in lst_docs]

## example
lst_entities[i]

## extract relations
lst_relations = [extract_relation(i,nlp) for i in lst_docs]

## example
lst_relations[i]

## extract attributes (NER)
lst_attr = []
for x in lst_docs:
    attr = ""
    for tag in x.ents:
        attr = attr+tag.text if tag.label_=="DATE" else attr+""
    lst_attr.append(attr)

## example
lst_attr[i]

2. 你也可以使用 Textacy,这是一个建立在 SpaCy 之上的库,用于扩展其核心功能。这更友好且通常更准确。

## extract entities and relations
dic = {"id":[], "text":[], "entity":[], "relation":[], "object":[]}

for n,sentence in enumerate(lst_docs):
    lst_generators = list(textacy.extract.subject_verb_object_triples(sentence))  
    for sent in lst_generators:
        subj = "_".join(map(str, sent.subject))
        obj  = "_".join(map(str, sent.object))
        relation = "_".join(map(str, sent.verb))
        dic["id"].append(n)
        dic["text"].append(sentence.text)
        dic["entity"].append(subj)
        dic["object"].append(obj)
        dic["relation"].append(relation)

## create dataframe
dtf = pd.DataFrame(dic)

## example
dtf[dtf["id"]==i]

作者提供的图片

让我们也使用 NER 标签(即日期)提取属性:

## extract attributes
attribute = "DATE"
dic = {"id":[], "text":[], attribute:[]}

for n,sentence in enumerate(lst_docs):
    lst = list(textacy.extract.entities(sentence, include_types={attribute}))
    if len(lst) > 0:
        for attr in lst:
            dic["id"].append(n)
            dic["text"].append(sentence.text)
            dic[attribute].append(str(attr))
    else:
        dic["id"].append(n)
        dic["text"].append(sentence.text)
        dic[attribute].append(np.nan)

dtf_att = pd.DataFrame(dic)
dtf_att = dtf_att[~dtf_att[attribute].isna()]

## example
dtf_att[dtf_att["id"]==i]

作者提供的图片

现在我们已经提取了“知识”,可以构建图谱。

网络图

标准的 Python 库用于创建和操作图网络是NetworkX。我们可以从整个数据集开始创建图,但如果节点太多,视觉效果会变得混乱:

## create full graph
G = nx.from_pandas_edgelist(dtf, source="entity", target="object", 
                            edge_attr="relation", 
                            create_using=nx.DiGraph())

## plot
plt.figure(figsize=(15,10))

pos = nx.spring_layout(G, k=1)
node_color = "skyblue"
edge_color = "black"

nx.draw(G, pos=pos, with_labels=True, node_color=node_color, 
        edge_color=edge_color, cmap=plt.cm.Dark2, 
        node_size=2000, connectionstyle='arc3,rad=0.1')

nx.draw_networkx_edge_labels(G, pos=pos, label_pos=0.5, 
                         edge_labels=nx.get_edge_attributes(G,'relation'),
                         font_size=12, font_color='black', alpha=0.6)
plt.show()

图片由作者提供

知识图谱使得从大局上看到事物之间的关系成为可能,但像这样是相当无用的……所以最好是根据我们要寻找的信息应用一些过滤器。对于这个例子,我将只取涉及最频繁实体(基本上是最连接的节点)的图的一部分:

dtf["entity"].value_counts().head()

图片由作者提供

## filter
f = "Russia"
tmp = dtf[(dtf["entity"]==f) | (dtf["object"]==f)]

## create small graph
G = nx.from_pandas_edgelist(tmp, source="entity", target="object", 
                            edge_attr="relation", 
                            create_using=nx.DiGraph())

## plot
plt.figure(figsize=(15,10))

pos = nx.nx_agraph.graphviz_layout(G, prog="neato")
node_color = ["red" if node==f else "skyblue" for node in G.nodes]
edge_color = ["red" if edge[0]==f else "black" for edge in G.edges]

nx.draw(G, pos=pos, with_labels=True, node_color=node_color, 
        edge_color=edge_color, cmap=plt.cm.Dark2, 
        node_size=2000, node_shape="o", connectionstyle='arc3,rad=0.1')

nx.draw_networkx_edge_labels(G, pos=pos, label_pos=0.5, 
                        edge_labels=nx.get_edge_attributes(G,'relation'),
                        font_size=12, font_color='black', alpha=0.6)
plt.show()

图片由作者提供

这样更好。如果你想将其做成 3D,使用以下代码:

from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure(figsize=(15,10))
ax = fig.add_subplot(111, projection="3d")
pos = nx.spring_layout(G, k=2.5, dim=3)

nodes = np.array([pos[v] for v in sorted(G) if v!=f])
center_node = np.array([pos[v] for v in sorted(G) if v==f])

edges = np.array([(pos[u],pos[v]) for u,v in G.edges() if v!=f])
center_edges = np.array([(pos[u],pos[v]) for u,v in G.edges() if v==f])

ax.scatter(*nodes.T, s=200, ec="w", c="skyblue", alpha=0.5)
ax.scatter(*center_node.T, s=200, c="red", alpha=0.5)

for link in edges:
    ax.plot(*link.T, color="grey", lw=0.5)
for link in center_edges:
    ax.plot(*link.T, color="red", lw=0.5)

for v in sorted(G):
    ax.text(*pos[v].T, s=v)
for u,v in G.edges():
    attr = nx.get_edge_attributes(G, "relation")[(u,v)]
    ax.text(*((pos[u]+pos[v])/2).T, s=attr)

ax.set(xlabel=None, ylabel=None, zlabel=None, 
       xticklabels=[], yticklabels=[], zticklabels=[])
ax.grid(False)
for dim in (ax.xaxis, ax.yaxis, ax.zaxis):
    dim.set_ticks([])
plt.show()

图片由作者提供

请注意,图表可能很有用且可视化效果很好,但这不是本教程的主要焦点。知识图谱最重要的部分是“知识”(文本处理),然后结果可以以数据框、图形或其他图表的形式展示。例如,我可以使用 NER 识别的日期来构建一个时间轴图。

时间轴图

首先,我必须将被识别为“日期”的字符串转换为 datetime 格式。库DateParser可以解析几乎所有在网页上常见的字符串格式的日期。

def utils_parsetime(txt):
    x = re.match(r'.*([1-3][0-9]{3})', txt) #<--check if there is a year
    if x is not None:
        try:
            dt = dateparser.parse(txt)
        except:
            dt = np.nan
    else:
        dt = np.nan
    return dt

让我们将其应用于属性的数据框:

dtf_att["dt"] = dtf_att["date"].apply(lambda x: utils_parsetime(x))

## example
dtf_att[dtf_att["id"]==i]

图片由作者提供

现在,我将其与主要的实体-关系数据框合并:

tmp = dtf.copy()
tmp["y"] = tmp["entity"]+" "+tmp["relation"]+" "+tmp["object"]

dtf_att = dtf_att.merge(tmp[["id","y"]], how="left", on="id")
dtf_att = dtf_att[~dtf_att["y"].isna()].sort_values("dt", 
                 ascending=True).drop_duplicates("y", keep='first')
dtf_att.head()

图片由作者提供

最后,我可以绘制时间轴。正如我们已经知道的,完整的图表可能不会很有用:

dates = dtf_att["dt"].values
names = dtf_att["y"].values
l = [10,-10, 8,-8, 6,-6, 4,-4, 2,-2]
levels = np.tile(l, int(np.ceil(len(dates)/len(l))))[:len(dates)]

fig, ax = plt.subplots(figsize=(20,10))
ax.set(title=topic, yticks=[], yticklabels=[])

ax.vlines(dates, ymin=0, ymax=levels, color="tab:red")
ax.plot(dates, np.zeros_like(dates), "-o", color="k", markerfacecolor="w")

for d,l,r in zip(dates,levels,names):
    ax.annotate(r, xy=(d,l), xytext=(-3, np.sign(l)*3), 
                textcoords="offset points",
                horizontalalignment="center",
                verticalalignment="bottom" if l>0 else "top")

plt.xticks(rotation=90) 
plt.show()

图片由作者提供

所以最好是过滤出特定时间段:

yyyy = "2022"
dates = dtf_att[dtf_att["dt"]>yyyy]["dt"].values
names = dtf_att[dtf_att["dt"]>yyyy]["y"].values
l = [10,-10, 8,-8, 6,-6, 4,-4, 2,-2]
levels = np.tile(l, int(np.ceil(len(dates)/len(l))))[:len(dates)]

fig, ax = plt.subplots(figsize=(20,10))
ax.set(title=topic, yticks=[], yticklabels=[])

ax.vlines(dates, ymin=0, ymax=levels, color="tab:red")
ax.plot(dates, np.zeros_like(dates), "-o", color="k", markerfacecolor="w")

for d,l,r in zip(dates,levels,names):
    ax.annotate(r, xy=(d,l), xytext=(-3, np.sign(l)*3), 
                textcoords="offset points",
                horizontalalignment="center",
                verticalalignment="bottom" if l>0 else "top")

plt.xticks(rotation=90) 
plt.show()

图片由作者提供

正如你所见,一旦“知识”被提取出来,你可以用任何方式绘制它。

结论

本文是关于如何用 Python 构建知识图谱的教程。我使用了几种 NLP 技术处理从维基百科解析的数据,以提取“知识”(即实体和关系),并将其存储在一个网络图对象中。

现在你明白了为什么公司正在利用 NLP 和知识图谱来映射来自多个来源的相关数据,并找到对业务有用的洞察。试想一下,通过在与单一实体(例如 Apple Inc)相关的所有文档(例如财报、新闻、推文)上应用这种模型,可以提取出多少价值。你可以快速了解所有与该实体直接相关的事实、人物和公司。然后,通过扩展网络,甚至可以获取与起始实体(A — > B — > C)不直接相关的信息。

希望你喜欢这个!如有任何问题或反馈,或想分享你的有趣项目,欢迎随时联系我。

👉 让我们联系 👈

本文是使用 Python 进行 NLP系列的一部分,还可以查看:

## 使用 NLP 进行文本总结:TextRank 与 Seq2Seq 与 BART

使用 Python 进行自然语言处理、Gensim、Tensorflow、Transformers

[towardsdatascience.com ## NLP 的文本分类:Tf-Idf 与 Word2Vec 与 BERT

预处理、模型设计、评估、Bag-of-Words、词嵌入、语言模型的可解释性

[towardsdatascience.com ## NLP 的文本分析与特征工程

语言检测、文本清理、长度、情感、命名实体识别、N-gram 频率、词向量、主题……

[towardsdatascience.com ## BERT 用于无模型训练的文本分类

当你没有标记训练集时,使用 BERT、词嵌入和向量相似性

[towardsdatascience.com ## 使用 NLP 构建 AI 聊天机器人:语音识别+Transformers

使用 Python 构建一个会话聊天机器人,与 AI 进行对话

[towardsdatascience.com

无代码机器学习平台:福音还是祸根?

原文:towardsdatascience.com/no-code-ml-platforms-boon-or-bane-eee27290245d

深入了解无代码平台如何促进加速的机器学习应用

Jojo John MoolayilTowards Data Science Jojo John Moolayil

·发表在Towards Data Science ·9 分钟阅读·2023 年 1 月 13 日

--

Scott Graham的照片,来源于Unsplash

近年来,我们看到一些大型企业和蓬勃发展的初创公司推出了多种无代码机器学习和数据科学平台。如今,大多数领先的云服务提供商至少提供一种无代码/低代码机器学习平台。微软的Azure ML Studio、亚马逊的Sagemaker Canvas和谷歌的AutoML是其中的一些例子。如果你更深入地观察它们,会发现它们的共同使命是民主化 AI/ML/DS。很长一段时间,我坚信无代码/低代码不会有效地民主化机器学习。然而,最近我改变了看法,原因可能不是你所猜测的。让我解释一下。

简短回顾

回到 2015 年,当我探索 Azure ML 工作室时,我确实感到很惊讶。那个时候的平台已经成熟,提供了丰富的功能来解决机器学习问题。数据导入、探索性数据分析、模型构建、超参数调优和部署的整个过程都可以通过拖放工具完成。这是我在这个类别中使用的第一个工具之一,我感到了一种完整感。这个工具让我实现了当时测试的目标——将一个模型部署到生产环境中,而不需要写一行代码(尽管是一个用于测试的简单模型)。然后,到 2016 年底,我确信这一类别的服务有巨大的市场,而且无代码工具很快会在机器学习问题上得到广泛采用。

然而,随着时间的推移,我几乎没有注意到这些工具在我主要参与的社区中的采用情况。这些工具确实有一些很炫的演示,但在大多数情况下,它们对我而言意义不大。渐渐地,我开始倾向于认为这些工具对于使人工智能普及来说是多余的。我的理由很简单;对于商业上重要的并最终部署到生产环境中的严肃机器学习用例而言,从不适合使用将控制权锁定在基于 UI 的工具中的工具构建。此外,对于严肃的机器学习用例,数据工程和数据处理是一个庞大的工作量。工程的庞大体量和复杂性永远无法适合过于简化的无代码工具。对我而言,无代码/低代码平台突然变成了一个被美化的工具,仅仅用于做出优秀的营销。

回到今天

最近,我开始从不同的角度看待这些工具。我觉得我可能对这些工具的看法存在偏见。这个可能性很大,因为我大多与已经熟悉某种形式的编码的数据显示科学家或在该领域经验丰富的专业人士互动。此外,我大多数时候在一个与软件工程师紧密合作的环境中工作,这些工程师帮助将研究原型转化为生产管道。因此,我们必须建立一个研究工作流程实践,以确保将研究原型和生产工件之间的转化努力降到最低。因此,我们通常选择了大数据工具支持的 Python 生态系统,这些工具运行在成熟的云平台上。在这种情况下,排除无代码解决方案是非常自然的。

为了用更广泛的视角和不同的用户群体理解现状,我开始联系我现有网络之外的人,以了解他们技术栈的变化以及无代码工具的采用情况。总体来说,在接触了相当多样化的受众后,我有了一些学习体会,这些体会最终改变了我的看法。

从新的视角看待问题

首先,我重新审视了组织在科学实践中的结构。尽管机器学习领域已经成熟,但在很多组织中仍然很少有科学职能。大多数组织在起步时都很艰难,通常团队人数也很不足。尽管这些组织中的科学问题潜力可能很大,但从一开始就很难明确大方向。从机器学习问题中发现价值并实现其业务影响的过程是缓慢且迭代的,需要有面对重大失败的心理准备。没有一种完美的科学路径能够帮助人们从识别问题到生成业务价值,像过于简化的从 A 点到 B 点的过程一样。这个过程通常是艰难且迭代的。这让我思考——不同成熟度的组织采用了什么工具?

事实上,并非所有组织都能负担得起或愿意从一开始就大规模投资于昂贵的科学技能。这个过程通常是一个未定义的路径。下图展示了从问题发现到解决以产品为驱动的科学解决方案的简化路径。[当然,每一步都有其自己的迭代,但你可以看到更大的图景。]

[作者提供的图像] - 机器学习使用案例的产品化路径示意图。

灰色区域表示给定里程碑的迭代频率。自然地,我们会有大量的想法在实现基本原型之前被淘汰,而这些原型在正式投入到严肃的原型之前还会进一步修剪,最后收窄为用于最终产品的关键优化版本。

长时间以来,我一直从不同的视角看待这些产品,并且不合理地批评了无代码平台的价值。我关键的问题是——这个解决方案对严肃的业务有多大价值? 在某些地方,这看起来对重要的用例来说有些多余。但后来我意识到,我是在从一个不缺乏机器学习技能和工程资源的工作环境的角度来进行比较。然而,并非所有地方都是这样。大多数组织没有足够的资源和团队来大规模支持科学用例验证,也可能没有成熟的科学职能来支持这一点。

下图展示了无代码平台在业务问题生命周期中的有效性的思考过程。

[作者提供的图像] — 无代码工具在问题生命周期阶段的有效性示意图

我的偏见是由于对问题更成熟阶段的倾斜。然而,这只是一个具体而狭窄的视角。每个组织根据其科学成熟度的位置,将拥有不同的工具。如果我们对大多数组织的解决问题过程进行概括,我们需要理解并非所有的想法都会变成生产产品。想法、原型、MVP 和最终产品的比例看起来像是多米诺骨牌倒退的顺序。因此,需要用不同的工具以不同的方式支持问题的每个生命周期阶段。下表将更深入地探讨上述问题生命周期阶段。

[图片来自作者]

如上所示,如果我们将问题的生命周期拆解成更小的里程碑,我们可以看到不同阶段对技能和资源的不同需求。专门的科学团队绝不是节省资源的团队,他们的成本通常与工程团队相当或更高。因此,小型组织中通常没有太多这样的团队。那么,那些可能没有专门科学团队的人员如何在不进行重大妥协的情况下更快地完成这个过程呢?

这时我开始看到无代码平台的新价值。

重新考虑无代码平台

在解决方案旅程中使用一刀切的解决方案是否合理?当然不!随着问题的进展会发生什么变化?在理想的世界中,为了使数据科学和机器学习变得普及,确实需要一个生态系统来促进在迭代频率非常高且失败率高的领域中更快地移动。为了支持创意阶段,我们已经拥有了最好的工具,比如白板、PPT、文档、写作等。对于基础和严肃的原型,我们是否有能够更快推进的工具?有人认为 Python 已经被充分普及,可以促进这一过程。这可能只是部分正确;不是所有分析师都精通 Python 和 SQL。因此,存在可以填补这一空白的东西。

这就是为什么我强烈认为无代码解决方案可以蓬勃发展的原因。

无代码解决方案提供了什么?

本质上,无代码机器学习平台显著降低了普通人接受数据科学的门槛。通过用模块化构建块整洁地抽象出关键复杂科学组件,这些平台支持从构思到实验和验证的过程,并留有额外的自定义空间。这些工具提供了稳健的默认设置,确保大多数任务可以在用户几乎不需自定义输入的情况下顺利进行。因此,这些工具通过简化数据工程和模型构建任务的过程,加快了验证创意的进程。此外,这些工具还简化了结果(成果)的消费过程,并支持更广泛的 go/no-go 决策,适用于大规模实验。对于首次接触机器学习的小型组织或新团队,这些工具以实惠且有效的价格点提供了巨大的价值,可以自信地加快初步步伐。

无代码机器学习平台不提供什么?

无代码工具绝不是大型严肃解决方案的替代品。它不是一个永久的工具集,无法应对从原型到生产的整个过程。当业务问题经过充分验证并开始扩展时,无代码工具的价值将开始减小,提示需要更精细的控制。无代码工具缺乏使大规模生产问题运行的复杂性。

那么它适合什么场景呢?

机器学习和数据科学用例的迭代和实验性特征确实使其成为一个资源密集型的倡议。不断增长的科技企业和/或最近刚刚采用机器学习进行业务的企业需要时间来验证创意,然后再做进一步投入。我们目前拥有的工具集可能并不是新团队开始数据科学工作的最友好和易于上手的手段。虽然它肯定是一个稳健的工具,但对于初学者来说可能不太理想。这正是民主化 AI/ML 工具开始发挥关键作用的地方。一个组织能否以仅一名员工和没有前期成本的低数据科学投资开始新的旅程?创意能否在没有严重工程努力和有限科学成熟度的情况下得到验证?一个有前途的创意能否逐步扩展,直到团队对大规模投资有信心?所有这些问题的明确答案并不总是容易获得,尤其是在现有的 Python 机器学习宇宙中;需要有更多的工具来提供支持。对于那些需要快速验证并有效地迭代到成熟的规模问题,无代码机器学习解决方案恰到好处。

当我们民主化 AI 和 ML 工具时,我们开始为生态系统提供合适的工具,以像抚养新生儿直到上幼儿园一样培养创意。一旦进入幼儿园,或许是时候寻求更好的工具了。但在此之前,无代码平台是你最好的朋友。

总结思考

一般来说,高质量的生产材料不建议通过过于简化的工具进行交付。但科学用例的迭代和实验性特征使得从一开始就进行资源密集型工程并不合适。问题的不同阶段以及组织的科学成熟度需要不同的工具来导航科学之旅。无代码/低代码解决方案提供了一个很好的起点,并有效降低了组织探索该领域是否对其业务有价值的门槛。当组织认真起来时,才有可能需要迁移到提供更多细化控制的工具和服务。在那之前,无代码工具将是您团队探索的好伙伴。

你好,感谢阅读!如果你想获取我即将发布的博客更新,请在 Twitter 上关注我,以便第一时间收到新帖通知。再次感谢!

TensorFlow 中不再出现 OOM 异常

原文:towardsdatascience.com/no-more-oom-exceptions-during-hyperparameter-searches-in-tensorflow-26e6e3069bc9

使用包装函数来避免 OOM 异常

Pascal JanetzkyTowards Data Science Pascal Janetzky

·发表于 Towards Data Science ·阅读时间 8 分钟·2023 年 4 月 1 日

--

现在是 2023 年。机器学习不再是炒作,而是日常产品的核心。越来越快的硬件使得训练更大规模的机器学习模型成为可能——而且时间更短。每天在 arXiv 提交的关于机器学习或相关领域的论文约有 100 篇,至少三分之一的论文利用了硬件的能力进行超参数搜索以优化其使用的模型。这不是很简单吗?只需选择一个框架——Optuna, wandb,随便哪个——接入你的常规训练循环,然后…

OOM 错误。

至少,这种情况在 TensorFlow 中经常发生。

图片由 İsmail Enes Ayhan 提供,来源于 Unsplash

当前状态

缺乏适当释放 GPU 内存的功能引发了许多讨论和问题,特别是在 StackOverflow 或 GitHub 等问答论坛中 (1, 2, 3, 4, 5, 6)。对于每个问题,都提出了一组类似的解决方法:

  • 限制 GPU 内存增长

  • 使用 numba 库清除 GPU

  • 使用本地 TF 函数,它们应该可以做到这一点

  • 切换到 PyTorch

本博文提出了我对这一长期存在且烦人的 OOM 异常问题的解决方案。在过去几年中进行了一些超参数优化运行后,我最近遇到了编程中最令人生畏的问题之一:

不易重复的异常,通常发生在百次运行中的某一次。在我的情况下,它偶尔发生在优化运行中选择了特别具有挑战性的参数组合时。比如,较大的批量大小和较多的卷积滤波器,这两者都对 GPU 内存造成压力。

有趣但更令人恼火的是,当我从新系统本地初始化这样的模型时——即之前没有运行其他 TF 代码——我可以成功运行模型。在检查了其他影响因素,如 GPU 大小、CUDA 版本和其他要求后,我没有发现此部分的错误。因此,必须是同一程序内神经网络的重复初始化导致了 OOM 错误。

在继续之前,我想澄清:OOM 错误可能有其他原因,迄今为止这些原因尚未明确指出。特别是,显然不可能将一个内存占用过大的模型放到一个内存过小的 GPU 上。

当模型在物理上过大时,解决方案是修改模型——相关关键词包括 混合精度训练逐层训练蒸馏剪枝——以及在多个加速器上运行训练——关键词:分布式训练,模型并行——或者切换到具有更多可用内存的计算设备。但这超出了本文的范围。

问题

返回到在超参数优化期间遇到的可怕 OOM 异常,我认为首先从概念上展示导致这种错误的原因是至关重要的。因此,请考虑以下可视化,其中我绘制了一个 GPU 及其内存块:

GPU 及其内存块的草图。图片由作者提供。

虽然这是一个简化的描述,但每次初始化新模型时,都会消耗另一块内存。

每个新模型都会占用 GPU 内存。图片由作者提供。

最终,加速器上的空间将不再剩余,导致 OOM 错误。模型越大,这种情况发生得越快。理想情况下,我们应在超参数试验结束时调用清理函数,为下一个模型释放内存。即使是可能适配到干净 GPU 上的网络,在没有进行垃圾回收时也会失败,因为宝贵的内存已经被之前模型的碎片占用。

我想称之为TensorFlow 在 GPU 上的作用域的部分在之前的草图中没有显示。默认情况下,TensorFlow 保留整个资源——这很聪明,因为后续请求增加配额将导致执行瓶颈。在下图中,TensorFlow 进程被勾画为“悬停”在整个 GPU 上:

TF 进程占用 GPU 以实现快速数据访问。图片由作者提供。

在其生命周期内,悬停的 TF 进程就像是即将到来的 TF 操作的占位符,这些操作在 GPU 及其内存上运行。

然后,一旦进程终止,TF 释放内存,使其可供其他程序使用。问题是:通常,作为超参数研究的一部分,所有网络初始化都是在同一个进程中(例如,在一个 for 循环中)进行的,这个进程悬停在 GPU 上。

解决方案,我希望,很明显:为每个试验/模型配置使用不同的进程。

这种方法适用于所有 TensorFlow 版本,尤其是旧版本(这些版本自然不会收到功能更新)。新的版本可能会增加适当的内存清理功能,但旧版本将缺乏这一可能性。所以,事不宜迟,这里是解决方案。

解决方案

为了在每个超参数试验中运行自己的进程,我们需要使用本地的 python multiprocessing 库。更新代码以使用这个包所需的努力出乎意料地少。

从全局视角来看,负责运行代码的过程——即主要驱动函数——需要修改以接受一个额外的参数,即队列。我们无需进一步深入,但这个队列对象作为调用函数(即调用 main()、run()、train() 或类似函数的函数)与主函数之间的桥梁。在主函数中,我们基本上可以保持现状。正如在参数搜索中常见的做法,改进训练/评估代码的返回值是优化实践的目标。

我们之前通过return语句返回这个值,而现在我们将这个目标值放入队列对象中。然后,我们从调用函数中提取它,并将其传递给超参数框架。

从优化框架的角度来看,变化不大。最显著的变化是它不再直接与训练函数“通信”,而是通过一个中介函数进行。概念上,这种更新的设置如下所示。

旧的默认代码(顶部)和更新后的代码(底部)的比较。在新代码中,优化框架通过包装函数与机器学习代码进行通信。

但除了这个变化之外,我们可以像往常一样进行参数搜索。从概念上讲,使用 python/伪代码的混合,让我向你展示修改后的代码的样子。

首先,我们必须将选择当前试验的参数组合的逻辑从主函数中移除(如果它曾被放在那儿)。这部分应该在我们添加进程管理之前完成。然后,我们使用多进程工具为 TF 相关代码生成一个进程,包装常用的 main()/train()/run()/等 函数:

def wrapper_function(): ← new
  hyperparameters = get_hyperparameters()
  …
  queue = Queue()
  process = Process(…)
  …
  results = queue.get()
  return results

与 TF 的通信,特别是结果的收集,通过队列对象完成,这就是为什么我们也将其传递给函数(稍后详细说明)。然后我们开始模型训练,等待其完成,并从队列对象中获取结果,将其传递给调用函数(通常是超参数框架)。这个值 —— 或者在多目标优化的情况下,这些值 —— 是超参数框架最终看到的;它对进程的细节一无所知。

在训练代码中,我们需要包含一种将参数优化目标传递到队列对象中的方法。在这里,我假设通常会返回模型在验证集上的表现,因为该指标(在该子集上)通常被用作优化目标。

(适应你的用例;概念是相同的)。

为此:

  1. 查找你退出训练/评估函数并将结果返回给调用者的地方。

  2. 在这里,将优化框架需要知道的一切信息收集到一个列表或其他数据集合中。

  3. 在下一步,将此列表传递给队列。

如前面段落所述,这些值可以从队列中查询,并从那里传递回优化框架:

def train(queue): ← modified
  …
  # do usual TF stuff
  load_model()
  load_data()
  …
  #collect results
  return_data = … ← new
  queue.put(return_data) ← new

关键点是:当我们获得评估结果时,TF 进程已经完成 —— 清除 GPU 内存。因此,训练和评估过程可以在下一次调用中访问干净的(内存方面的)GPU,并且使用新的超参数集。特别是,创建待评估的模型不会争夺剩余的 GPU 内存,因为之前的模型及其痕迹已经被移除。这样,我们避免了 OOM 问题。

一个简单的例子

此时,你可能会想知道如何将这些内容转化为代码。我听到了,虽然每个人的需求不同,但我构建了以下简单的设置来给你一个如何工作的概念。我们将使用 Optuna 来优化卷积神经网络的超参数。通过这些代码,你可以加载你喜欢的数据集,并在其上优化 CNN。

这里是 Optuna 代码的部分:

在这段代码中,请注意 Optuna 调用的函数。它不是实际的训练代码,而是一个中间函数,即包装函数。如前一节所述,包装函数会使用将要评估的超参数集调用底层训练代码。

至于训练代码,这段代码基本遵循标准设置:加载数据子集、初始化模型、训练并评估模型。新颖之处在于函数的最后几行。在这里,验证子集上的结果被传递到队列中。

这就是 OOM-free 超参数优化的核心代码**。完整代码可以在 这里 找到。

*我的经验表明,应该在调用函数中选择超参数,即在进程修改生效之前。

**唯一的例外是当模型实在太大而无法适应 GPU 内存时。

总结

在这篇博客文章中,我介绍了我对 TensorFlow 中长期存在的 OOM 异常挑战的解决方案。虽然自然地,我们不能将物理上过大的模型放入 GPU 中,但这种方法适用于研究关键的超参数搜索。从概念上讲,这个解决方案很简单:为每次试验(超参数组合)启动一个新进程。使用本地 Python 库,只需添加几行新代码和对现有设置进行一些更改即可实现。最后,将优化框架指向一个中间包装函数,而不是直接指向训练代码。

数据科学中没有“科学”?

原文:towardsdatascience.com/no-science-in-data-science-361ee914dc27?source=collection_archive---------2-----------------------#2023-11-01

我们是在分析数据,还是在工程现实?

Corné de RuijtTowards Data Science Corné de Ruijt

·

关注 发表在 Towards Data Science ·16 min read·2023 年 11 月 1 日

--

Vernagt Ferner — 提洛尔阿尔卑斯山。图片由作者提供。

引言

不用担心,我选择这个标题不是为了抱怨数据科学不是“真正的科学”(无论那意味着什么)。相反,我希望提供一些不同的视角来理解作为数据科学家的意义,希望能帮助在分析问题时考虑不同的视角。最初,我打算将这篇文章命名为“我从柏拉图和书呆子那里学到的作为数据科学家的经验”[1]。柏拉图和书呆子是爱德华·阿什福德·李最近写的一本书。我发现这本书的标题并没有真正传达书的内容。从某种程度上说,这符合李的写作风格。他走了许多有趣的旁路——他称之为“书呆子风暴”——使得书的主要信息仍然不清晰。在这篇文章中,我将专注于他书中的一个方面:科学和工程学之间的区别以及这如何与数据科学中的常见思维碰撞。虽然李写这本书是为了大众,但我发现这本书对我反思自己在计算机科学/信息技术领域的研究特别有用。

我猜这篇文章可以被看作是李描述工程学和科学之间差异的一个推论。特别是,通过李的论点,普遍的看法是“[数据科学是一个] 统一统计学、数据分析、信息学及其相关方法的概念 [以便于] 理解和分析实际现象 [通过数据]。”(强调是我加的)[2],目前在维基百科的数据科学页面上是第一条引用,似乎是一个误解。在这篇文章的其余部分,我将尝试解释为什么在李的论点下,这是这样。我将首先更仔细地查看我们最常指的“实际现象”是什么意思。其次,我将反思李关于工程学和科学的区分,因为既然我们谈论的是数据科学,这似乎是相关的。最后,我将对这可能意味着什么进行哲学上的思考,尤其是对数据科学家的工作。

Vernagthutte — 提罗尔阿尔卑斯山。图片作者

将现实付诸实践(以及为什么这很重要)

最近我经常听 BBC 播客《你已死于我》[3],这是一个历史喜剧播客。我特别喜欢的一集是关于时间测量的历史。在这一集中,他们讨论了我们现在理所当然的许多时间方面。尽管我经常使用时间的概念,但这集让我意识到我对这些概念的起源知之甚少(甚至现在也不多)。我只是接受一个小时包含 60 分钟,工作时间和休闲时间之间有差异,或者时间存在于一个连续体中。所有这些关于时间的概念都是人们的视角,仍然在我们的文化中回响而我们却未曾意识到。难怪李将这些称为“未知的已知”。未知的已知也被称为另一种名称:心理模型。

科学的一项重要任务是意识到这些未知的已知,并将其与现实进行对比。科学史上充满了这样的例子。在哥白尼之前,普遍的心理模型是太阳绕地球运转。在我们的历史上,放血或使用汞曾被认为是医学上合理的。统计学本身也充满了未知的已知。在 1750 年之前,结合观察以提高准确性,例如使用均值而不是选择最`可靠’的单一观察点,这种做法很少见且常常受到批评。当欧拉试图预测木星和土星的运动时,他使用了他认为最可能的数据点,而不是对所有数据进行平均以预测[4, p.p. 25–31]。另一个例子是,19 世纪的经济学家在将多个产品的价格合并成一个指数时不得不为自己辩护[4, Ch. 1],尽管当时统计方法在天文学家中已经广泛使用。

对于数据科学家来说,将心理模型与现实对比似乎是其活动的核心。这不仅适用于像 A/B 测试这样的琐碎例子,也适用于预测和优化。人们有很强的倾向将信息片段连接起来,形成一个连贯的故事,即使这些信息片段实际上是断裂的(所谓的事后偏见)。一旦这样的故事在某人的脑海中形成,就很难改变。事实上,确认偏误会暗示这个人主要会寻找支持该故事的证据。即使你可以证明预测模型更准确,或者规划算法比人类做得更好,你仍然会发现有人争辩说模型在这个或那个例外情况中不起作用。或者说,由于模型可能是一个黑箱,因此不能被信任。我个人在最近关于 ChatGPT 的发展中经历了这种情况。很难接受这种软件仅凭几句指令可能写出比我更好的博客。这只让我想到 ChatGPT 一定是不准确、不安全或不可靠的理由。你创建的模型会以开放的心态被接受的想法是一个幻想。你总是要与已有的模型竞争,无论是心理上还是实际的。

我最喜欢的关于将心理模型与现实对比的名言来自 Croll 和 Yozkovitz 在他们的书《精益分析》中[5]:

“我们都在自欺欺人——有些人比其他人更严重。企业家是最自欺欺人的。[…] 说谎甚至可能是成功的企业家的前提——毕竟,你需要说服别人相信某些事情,即使没有确凿的证据。你需要信徒和你一起冒险。你需要对自己说谎,但不能到危及你业务的地步。”

“这就是数据的作用”

数据被用作镜子来验证思维模型与现实的一致性。因此,我们称之为数据科学。正如李所说的那样,一位科学家“选择(或发明)一个忠实于目标的模型”。Croll 和 Yozkovitz 的引言也强调了基于先前信念创建思维模型并非坏事:自我欺骗是一条细线。一方面,欺骗自己能够做某事可能是重要的。以冒名顶替综合症为例,人们觉得自己取得的成功只是运气,并且随时可能被揭穿为骗子。你需要“假装直到成功”,以克服这种认知偏差。另一方面,许多人对自己的决策过于自信(例如,达宁-克鲁格效应)。你的自我要么阻止你做该做的事,要么妨碍你意识到自己不应该做某事。

自我欺骗特别困难的一点在于,它在某种程度上并不是撒谎,因为这些认知偏差使你真的相信你的思维模型就是现实,而你不会接受任何与此不同的论点[22, p. 490–492]。这就是数据科学难的原因:有时候你可能拥有所有证据,但如果这些证据与决策者的思维模型相悖,你将很难说服那个人。正如谚语所说:数据科学家的角色是告诉人们他们的宝宝很丑,人们通常不喜欢听到这个。

所以,这就是我对成为数据科学家的思维模型,直到柏拉图和书呆子将其拆解。

Mt Chaberton — Claviere, Piedmont. 作者提供的图片

现实是一个思维模型的堆叠

如果你再看看前一段,你会发现我刚才称之为“现实”的东西充满了另一层思维模型。精益分析中的“精益”指的是 Eric Ries 的《精益创业》[6]中的思维模型,它提倡一种快速创业试错的方法以实现快速增长。希望这将使公司变大到可以上市。随后,企业需要快速“扩张”的想法,很可能受到硅谷指数组织[7]中成功商业模型的启发(尽管只有 0.4%的初创公司实际上扩张[8])。

进一步追溯历史,公司需要成长的观念显然是资本主义的,起源于 18 世纪亚当·斯密的思想。公众公司这一概念是 17 世纪荷兰东印度公司(VOC)的发明[9],来自一个在当时才刚刚被发明的国家,经过了四代勃艮第统治者的锤炼。或许最有趣的是货币本身的概念,李提到的塞尔(Searle)[10 p.78]认为货币是“任何人们使用并认为是货币的东西”。

就像任何曾经与六岁孩子讨论过的人可能见证的那样,我们可以通过不断地问“为什么”来不断剥离心理模型,只是为了得到新的模型。直到某一时刻,你到达了某些看似是自己思维公理的心理模型。但这里有一个问题:所有这些模型都是,嗯,模型。它们不是现实。即使是像量子力学这样的基本模型,在大尺度上考虑时也是不准确的。

Mt Chaberton — Claviere, Piedmont. 作者图片。

为什么数据科学家是工程师

你被心理模型包围着,这也是工程师进入故事的地方。李谈到科学家和工程师之间的区别:

“我们可以选择(或发明)一个忠于目标的模型,或者我们可以选择(或发明)一个忠于模型的目标。前者是科学家的本质。后者是工程师的本质。” P.197。

举个例子,李提到了晶体管。晶体管是一个开/关开关的模型。也就是说,晶体管实际上并不是开/关开关:它们是放大器。晶体管被创造得如此贴近开/关开关。目标(晶体管)忠于模型(开/关开关)。按照上述工程师和科学家的定义,晶体管的创造是一项工程工作。而且,李还表明,像法拉第定律这样的电磁学定律,虽然描述了晶体管的“真实行为”,实际上也是模型。现实只有在某些情况下遵循这些模型;在其他情况下,它们可能完全错误。事实上,科学似乎是一堆模型,每个模型都建立在之前模型的近似基础上。直到某个抽象层级,现实才会变得与模型不符。正如李在第 2–6 章所示,计算机科学中的模型堆栈,从晶体管到网站,似乎几乎是无尽的。

经济模型也不例外。下次你遇到某个人时,问问他是否可以借用他们的笔。令人惊讶的是,大多数人不会向你收取费用,甚至会坚持让你保留这支笔。在某些人际交往的情境中,一种“日常共产主义”的形式似乎是主流,而不是资本主义[11, pp. 94–102]。

但李继续争论,实际上,我们称之为科学的大部分内容是工程学。

“但更根本的是,标题[柏拉图与书呆子]将知识及技术看作是存在于人类之外的柏拉图式观念,并由人类发现的观念,与另一种观点对立,即人类创造而非发现知识和技术。[……] 标题中的书呆子是一个创造性的力量,主观且甚至古怪,而不是客观地复述现有的真理。”

读完这段话后,我有一种感觉,或许我选择了错误的职业。对我而言,数据科学的概念正是‘发现’的部分。应该作为独立观察者来分析数据,提出帮助组织增加‘价值’的‘见解’。按照这个概念,数据科学家就像阿加莎·克里斯蒂笔下的波洛探员,他寻找证据以解开谜团,但并不是谋杀计划的一部分。当然,一旦谜团解开,它会呈现给观众(对数据科学家而言,可能是稍微不那么壮观的 PowerPoint 演示),观众对解决方案的精彩表现感到惊叹。确实,这有点夸张,但解决方案应该复杂并由戴眼镜的白人发现的观念无疑是另一种思维模式的一部分[12]。根据李对科学家的定义,我们的数据科学 Poirot 确实描述了一个发明忠实于现实的模型的人,因此是一个科学家。

但现在考虑以下问题:你得到了一些历史销售数据(见下图),这些数据来自一个相对较新的产品,你和许多人几乎都确定该产品的销售在不久的将来会增长。你会做出什么预测?你会只是“让数据说话”:拟合一个线性模型,并将其外推到未来吗?还是会利用这些信息定义一个在接下来几年内具有指数增长的模型?假设你选择了后者,做出了预测,并与业务所有者分享。业务所有者随后决定启动营销机制以实现这一预测。到现在为止,你的预测已成为商业目标,销售增长甚至超出你的预测。根据李氏对工程师和科学家的讨论:这些商业结果是工程出来的?还是一些不可避免的科学?

这个例子来自 Makridakis 的《21 世纪的规划与预测》[13]。在他的书中,他描述了一个实验,要求商业人士根据一些历史销售数据做出预测。除了数据,他们还被告知产品是旧的、新的还是成熟的。正如图 1 所示,预测结果差异巨大。这些预测没有哪个本质上是错误的;它们遵循了产品生命周期的心理模型。这只是表明,即使使用相同的数据,预测也可能完全不同,这会影响到行动和结果。我们的数据科学家是解决了谋杀案,还是实施了谋杀?

现实情况是,作为数据科学家,你在‘价值’(模型)的定义中扮演着重要角色。你与业务部门讨论如何通过数据增加价值。你参与了 KPI 的定义。如果无法找到财务利益,数据分析通常还有其他方面的好处。一个成功的数据科学家,往往是一个懂得如何改变行为(例如客户的行为),以忠实于某个模型(例如优化收入)的人。当然,人们只能对现实进行有限的改变。改变人类行为可能是困难的。但是,如果没有意识到作为数据科学家,你可以,也很可能已经在改变‘现实’,那将是一个错误。数据科学家是工程师。

图 1,基于相同数据的 3 种截然不同的预测,基于 Makridakis [13,第 19 页]。作者提供的图像。

工程现实

“好吧,如果游戏规则迫使我们采取不良策略,也许我们不应该尝试改变策略。也许我们应该尝试改变游戏。”

这段引文来自《算法的生活》[14,第 240 页]。在游戏理论的一章中,作者讨论了 Vickrey 拍卖[15]。Vickrey 拍卖是一种所谓的第二价格拍卖,赢家支付的是第二高竞标的金额。在 Vickrey 拍卖中,每个参与者的主导策略是根据他/她认为物品的实际价值进行出价,而不是例如,通过虚张声势使另一方支付高价。从这个意义上讲,Vickrey 拍卖是一个如何改变行为以忠于某个偏好模型的例子。

实践中,还有许多其他的“首选模型”。一个模型的常见特征是,比如说,具有较小的随机性。无论是在制造和物流中被称为“六西格玛”的模型,还是在客户服务中的“轻松体验”[16]。如果模型能够应对随机性,那么基于它分析和做出决策会更容易。这时,数据科学家的概率背景就显得非常重要。可以通过使用统计模型来解释部分随机性,使用模拟技术来确定模型应对不确定性的能力,或指出商业所有者可能过于关注波动的日常表现,而没有看到更稳定的长期模式。正如李氏对晶体管及其后续抽象层的例子所示,减少随机性对于构建后续的抽象层至关重要。当我在用 Python 编程时,我不需要考虑程序在我的机器上是如何精确执行的,我也理所当然地认为程序可能不如在更接近机器代码的语言中编写时那样快。

但减少随机性也是有风险的。当较低层次的抽象中的某些假设被违反时,往往会发生危机。以物流供应链为例。在过去十年中,“六西格玛”理念一直是管理供应链部分的一个重要策略。但是,较长的供应链更容易引入更多的随机性。小的干扰可以通过使用缓冲区来平滑,但大规模的干扰则不然,自新冠危机以来,这种情况频繁发生[17]。因此,研究越来越多地寻求为干扰做准备的方法,而不是试图减少它们[18]。在这里,数据科学家也扮演着工程师的角色。他/她会将观察到的方差解释为需要减少的东西,还是接受它作为不可避免的东西?解释一个而非另一个,将如同在预测示例中一样,导致不同的策略:严格流程与鲁棒流程。

数据科学家作为工程师的概念也在数据和人工智能的伦理讨论中频繁出现。虽然引导客户购买/点击/消费更多在许多情况下可能是无害的,但有时情况却不那么简单。施尔[19]描述了一个令人震惊的例子。施尔研究了拉斯维加斯老丨虎丨机的成瘾性,包括老丨虎丨机的工程设计如何促成这种成瘾,许多原则也适用于其他数字设备,如手机。在一个离奇的例子中,赌场里有人心脏病发作。在急救人员尝试救她时,其他赌场访客继续玩他们的老丨虎丨机。老丨虎丨机经过优化,使得用户即使在邻居倒地的情况下也能保持在“流动”状态中。

当然,人们可以争论是否应责怪那个进行 A/B 测试的工程师,使得老丨虎丨机如此上瘾,但我希望到目前为止我已经说服读者,我认为这至少部分是事实。从积极的一面来看,尽管老丨虎丨机仍然像以往一样上瘾,但对算法失误的媒体关注已增加了对算法公平性等问题的关注[20]。重要的公共算法(如搜索引擎)更经常接受算法审计[21]。秉持“你无法管理你无法测量的事物”这一原则,如果在分析中包括了公平性度量,决策者就必须解释为什么可能忽视了这一度量。如果没有报告公平性度量,决策者可能不会意识到首先存在公平性问题。

查贝顿山 — 克拉维耶,皮埃蒙特。图像由作者提供。

结论

在数据科学和分析的描述中,对分析的强调随处可见,但是否应该更多关注发明呢?在 Gartner 的分析成熟度曲线上,只有在完成描述性和预测性分析(这些仅仅是分析现象)阶段后,才会开始‘规范性分析’。然而,在许多应用中,我们可能会发现行为(实际现象)已经根据模型(如收入、公平等)进行了调整,而不是相反。虽然一些系统和技术比其他系统更用户友好,但人们已然适应了技术,这一点不容忽视。是的,我认为谷歌是一个用户友好的搜索引擎,但我也已经训练自己寻找可能带我到正确网页的合适关键词。我学会了如何开车以及过马路时要小心,这后一点在提高交通效率上贡献不亚于汽车的发明本身。

当然,作为数据科学家,你也会遇到只是分析现象的情况。在这种情况下,你仅仅扮演了波洛的角色,仅仅是调查谋杀。然而,知道行为往往跟随技术的发展,而不是相反,对于数据科学家来说是有用的。与其分析实际现象,不如思考一个理想的现象可能是什么样的,然后据此设计模型。

参考文献

[1] 李·爱德华·阿什福德,《柏拉图与书呆子:人类与技术的创意合作》(2017),MIT 出版社

[2] 林千江,《什么是数据科学?基础概念与启发式示例》。数据科学、分类及相关方法(1998),Springer,东京,40–51. 通过维基百科:en.wikipedia.org/wiki/Data_science. 方括号中的部分来自维基百科页面,其余内容引用自原始论文。最后访问时间:2023 年 2 月 4 日。

[3] BBC,《时间测量的历史》(2022 年),你对我已经死了,BBC,检索自:www.bbc.co.uk/programmes/p07mdbhg,最后访问日期:2023 年 2 月 4 日。

[4] Stigler, Stephen M. 统计学的历史:1900 年以前的不确定性测量。哈佛大学出版社,1986 年。

[5] Croll, Alistair 和 Benjamin Yoskovitz. 精益分析:利用数据更快地打造更好的初创公司。O’Reilly Media, Inc.,2013 年。

[6] Reis, Eric. 精益创业。纽约:Crown Business,2011 年,第 27 期:2016–2020 年。

[7] Ismail, Salim. 指数型组织:为什么新的组织比你的组织好、快、便宜十倍(以及该怎么做)。Diversion Books,2014 年。

[8] ScaleUpNation. 扩展的艺术。2020 年。scaleupnation.com/wp-content/uploads/2021/02/The-Art-of-Scaling-3.1.pdf。最后访问日期:2023 年 2 月 4 日。

[9] VOC。检索自:en.wikipedia.org/wiki/Dutch_East_India_Company,最后访问日期:2023 年 2 月 4 日。

[10] Searle, J. 心智、品牌与科学。哈佛大学出版社,剑桥,MA。

[11] Graeber, David. 债务:前 5000 年。Penguin UK,2012 年。

[12] Chang, E. 男性主导的科技行业。Penguin,2019 年。

[13] Makridakis, Spyros. 21 世纪的预测、规划和战略。Free Press,1990 年。

[14] Christian, Brian 和 Tom Griffiths. 活用算法:人类决策的计算机科学。Macmillan,2016 年。

[15] Vickrey 拍卖。检索自:en.wikipedia.org/wiki/Vickrey_auction,最后访问日期:2023 年 2 月 4 日。

[16] Dixon, Matthew, Nick Toman Rick DeLisi 和 N. Toman. 轻松体验。Penguin Random House,2020 年。

[17] 国会研究服务部。供应链中断与美国经济。2022 年。检索自:crsreports.congress.gov/product/pdf/IN/IN11926,最后访问日期:2023 年 2 月 4 日。

[18] Spieske, Alexander 和 Hendrik Birkel. 通过工业 4.0 提高供应链韧性:在 COVID-19 大流行的影响下的系统文献综述。计算机与工业工程 158,2021 年。

[19] Schüll, Natasha Dow. 设计中的成瘾。普林斯顿大学出版社,2012 年。

[20] Pitoura, Evaggelia, Kostas Stefanidis 和 Georgia Koutrika. 排名和推荐中的公平性:概述。VLDB 期刊,第 1–28 页。Springer,2022 年。

[21] Bandy, Jack. 有问题的机器行为:算法审计的系统文献综述。ACM 人机交互会议录 5,CSCW1,第 1–34 页。ACM,2021 年。

[22] Pinker, Steven. 我们天性的更好天使:历史中暴力的减少及其原因。Penguin UK,2011 年。

NODE:专注于表格数据的神经树

原文:towardsdatascience.com/node-tabular-focused-neural-trees-ee08c752fcd2

探索 NODE:一种用于表格数据的神经决策树架构

Nakul UpadhyaTowards Data Science Nakul Upadhya

·发表于 Towards Data Science ·7 分钟阅读·2023 年 7 月 4 日

--

近年来,机器学习迅猛发展,神经深度学习模型在图像和文本处理等复杂任务中已经超越了像 XGBoost [4] 这样的浅层模型。然而,深度模型在处理表格数据时往往不如这些浅层模型有效,目前尚未有一种通用的深度学习方法能够 consistently 超越梯度提升树。

为了解决这一差距,俄罗斯互联网服务公司 Yandex 的研究人员提出了一种新的架构:神经遗忘决策集成(NODE) [1]。该网络利用轻量级且可解释的神经决策树,并将其整合到神经网络框架中。这使得模型能够在保持可解释性的同时,捕捉表格数据中的复杂交互和依赖关系。

在这篇文章中,我旨在解释 NODE 的工作原理以及使其成为一个强大而可解释的预测模型的各种属性。像往常一样,我鼓励大家阅读原始论文。如果你想使用 NODE,请查看模型的 GitHub。

本文是关于神经决策树系列中的一部分,这些高度可解释的架构提供了与传统深度网络相当的预测能力。

Nakul Upadhya

Nakul Upadhya

软/神经决策树

查看列表3 个故事

NODE 决策树结构

神经决策树

本文假设你对神经决策树有一定的了解。如果你没有,我强烈建议阅读我之前关于它们的文章以获得深入的解释。然而,总结来说:神经决策树是既柔软又倾斜的决策树。

倾斜树是指在每个节点中使用多个变量来做出决策(通常以线性组合的形式排列)。例如,为了预测汽车事故,正交树可能使用规则“car_speed — speed_limit <10”来产生分支决策。这不同于像 CART(基本决策树)这样的“正交”树,后者在任何给定的节点只使用一个变量,并且需要更多的节点来近似相同的决策边界。

柔软树是指所有分支决策都是概率性的,每个节点的计算定义了进入特定分支的概率。这与像 CART 这样的普通“硬”决策树不同,后者的每个分支决策是确定性的。

由于树不限制每个节点使用的变量数量,并且分支决策是连续的,因此整个树是可微分的。由于整个树是可微分的,它可以集成到任何神经网络框架中,如 Pytorch 或 Tensorflow,并使用传统神经优化器(例如,随机梯度下降和 Adam)进行训练。

NODE 树

NODE 使用的决策树与传统神经树略有不同。让我们逐一分析这些差异。

NODE 树。F(*)表示分支函数,b 表示分支阈值。Sigma 表示概率转换函数。R 是叶节点结果(图来自 Popov 等人,2019 年[1])

无意识特性

第一个重大变化是树的特性是“无意识”的。这意味着树在相同深度的所有内部节点上使用相同的分裂权重和阈值。因此,无意识决策树(ODTs)可以表示为一个具有 2^d个条目的决策表(d为深度)。一个好处是,ODTs 比传统决策树更具可解释性,因为决策更少,更容易可视化和理解决策路径。然而,与传统决策树相比,ODTs 的学习能力显著较弱(再次由于分裂函数的受限特性)。

那么如果我们的目标是性能,为什么要使用 ODTs 呢?正如 CATBoost 的开发者[2]所展示的,ODTs 在集成在一起时效果非常好,并且不易过拟合数据。此外,ODTs 的推理非常高效,因为所有分裂可以并行计算,迅速找到表中的合适条目。

用于特征选择和分支的 Entmax

NODE 相对于传统神经决策树的第二个改进是其架构中使用了 alpha-entmax [3]而不是 sigmoid。Alpha-entmax 是 softmax 的一个广义版本,能够产生稀疏分布,其中大部分结果为零。这种稀疏性由一个参数(因此得名 alpha)控制,alpha 值越高,分布越稀疏。

图源于 Peters 等人 2019 年[3]

这种变换在两个关键地方使用。第一个使用场景是稀疏特征选择。NODE 包括一个可训练的特征选择权重矩阵 F(大小为 d x n,其中 n 是特征数量,d 是树的深度),通过 entmax 变换。由于 entmax 变换的大多数条目都等于零,这自然会导致在每个决策节点中使用的特征数量很少。

分支函数(图源于 Popov 等人 2019 年[1])

除了特征选择,entmax 还用于分支概率。这是通过传递分支函数的结果,减去一个学习的阈值,然后适当地缩放来完成的。然后将这个值与 0 串联,并传入 entmax 函数,以创建一个 2 类概率分布,这正是我们进行分支所需的。

分支方程见[1]。b_i 是分支阈值,tau_i 是缩放数据的学习值(图由作者提供)

使用这个,我们可以通过计算所有分支分布 c 的外积来定义一个“选择”张量 C。然后可以将其与叶子中的值相乘,以创建网络的结果。

集成

正如名字所示,这些神经遗忘决策树会被集成在一起。一个 NODE 层被定义为m棵单独树的串联,每棵树都有自己的分支决策和叶值。如前所述,这种集成与单个树的遗忘性质协同作用,有助于提高准确性,同时减少过拟合的可能性。

多层 NODE

NODE 是一个灵活的架构,可以单独训练(结果是单一的决策树集成)或使用复杂的多层结构,其中每组集成从前一层获取输入。

多层 NODE 架构(图源于 Popov 等人 2019 年[1])

NODE 的多层架构紧密跟随了流行的 DenseNet 架构。每个 NODE 层包含若干棵树,其输出被串联起来作为后续层的输入。最终输出是通过对所有层中所有树的输出进行平均获得的。由于每一层依赖于所有之前预测的链条,网络能够捕捉到复杂的依赖关系。

实验性能

为了测试他们的架构,Popov 等人(2019 年)将 NODE 与 CatBoost [2]、XGBoost [4]、全连接神经网络、mGBDT [5] 和 DeepForest [6]进行了比较。他们的方法涉及在六个不同的数据集上测试这些模型。具体来说,他们进行了使用每个模型默认参数的比较,以及另一项使用调整后的超参数的比较。

NODE 与其他模型的比较结果(图来源于 Popov 等人,2019 年)

NODE 的实验结果极为令人鼓舞。例如,NODE 架构在所有其他模型的默认参数下表现优于它们。即使在调整参数的情况下,NODE 在 6 个选定数据集中的 4 个数据集上仍优于大多数其他模型。

结论

通过将决策树的优势融入神经网络架构,NODE 为深度学习在结构化表格数据普遍存在的领域(如金融、医疗保健和客户分析)开辟了新的可能性。

不过,这并不是说 NODE 是完美的。例如,架构的集成意味着使用神经决策树所获得的许多局部可解释性收益被舍弃,模型中只能获得全局特征重要性。然而,这一架构确实提供了改进神经可解释性的基础构件,并且已提出了一个后续模型 (NODE-GAM [7])以弥合可解释性差距。

此外,虽然 NODE 在许多浅层模型中表现优异,但我使用它的经验表明,即使使用 GPU,训练时间也较长(这一结论得到了论文作者提供的实验结果的支持)。

总体而言,这是一种极具前景的方法,我计划在未来开发的深度学习模型中积极使用它作为一个组件。

资源与参考文献

  1. NODE 论文:arxiv.org/abs/1909.06312

  2. NODE 代码:github.com/Qwicen/node

  3. NODE 也可以在 Pytorch Tabular 包中找到:github.com/manujosephv/pytorch_tabular

  4. 如果你对可解释机器学习或时间序列预测感兴趣,可以考虑关注我:medium.com/@upadhyan

  5. 查看我关于神经决策树的其他文章:medium.com/@upadhyan/list/3b4a9cb97b84

参考文献

[1] Popov, S., Morozov, S., & Babenko, A. (2019). Neural oblivious decision ensembles for deep learning on tabular data. 第八届国际学习表征会议.

[2] Prokhorenkova, L., Gusev, G., Vorobev, A., Dorogush, A. V., & Gulin, A. (2018). CatBoost: unbiased boosting with categorical features. 神经信息处理系统进展, 31.

[3] Peters, B., Niculae, V., & Martins, A. (2019). 稀疏序列到序列模型。发表于第 57 届计算语言学协会年会论文集(第 1504–1519 页)。计算语言学协会。

[4] Chen, T., & Guestrin, C. (2016 年 8 月). Xgboost: 一个可扩展的树提升系统。发表于第 22 届 ACM SIGKDD 国际知识发现与数据挖掘会议论文集(第 785–794 页)。

[5] Feng, J., Yu, Y., & Zhou, Z. H. (2018). 多层梯度提升决策树。神经信息处理系统进展, 31

[6] Zhou, Z. H., & Feng, J. (2019). 深度森林。国家科学评论, 6(1), 74–86。

[7] Chang, C.H., Caruana, R., & Goldenberg, A. (2022). NODE-GAM: 神经广义加性模型用于可解释的深度学习。发表于国际学习表征会议

非负矩阵分解(NMF)用于图像数据的降维

原文:towardsdatascience.com/non-negative-matrix-factorization-nmf-for-dimensionality-reduction-in-image-data-8450f4cae8fa

使用 Python 和 Scikit-learn 讨论理论和实现

Rukshan PramodithaTowards Data Science Rukshan Pramoditha

·发表于 Towards Data Science ·9 分钟阅读·2023 年 5 月 6 日

--

原图来自 an_photosPixabay(作者稍作编辑)

我已经详细讨论了不同类型的降维技术。

主成分分析(PCA)因子分析(FA)线性判别分析(LDA)自编码器(AEs)核主成分分析是最受欢迎的几种。

非负矩阵分解(NMF 或 NNMF)也是一种线性降维技术,可以用于减少特征矩阵的维度。

所有降维技术都属于无监督机器学习的范畴,通过这种方法,我们可以揭示数据中隐藏的模式和重要的关系,而不需要标签。

因此,降维算法处理的是无标签的数据。在训练这种算法时,fit() 方法只需要特征矩阵 X 作为输入,不需要标签列 y

正如其名称所示,非负矩阵分解(NMF)需要特征矩阵为非负值。

由于这种非负性约束,NMF 的使用范围被限制在非负值的数据上,例如图像数据(像素值总是介于 0 和 255 之间,因此图像数据中没有负值!)。

**What you will learn:
----------------------------------------------------**
1\. Maths behind NMF
2\. NMF equation
3\. Feature matrix, V
4\. Transformed data matrix, W
5\. Factorization matrix, H
6\. Scikit-learn NMF() class
7\. Arguments, methods and attributes of NMF() class
8\. Load the MNIST with Scikit-learn
9\. Perform dimensionality reduction in image data

**Other matrix decomposition methods:
----------------------------------------------------** 1\. Eigendecomposition
2\. Singular value decomposition

非负矩阵分解(NMF)的数学原理

非负矩阵分解来源于线性代数。简单来说,它是将一个矩阵分解为两个小矩阵的乘积的过程。

更准确地说,

非负矩阵分解(NMF)是将一个非负特征矩阵 V (nxp) 分解为两个非负矩阵 W (nxd) 和 H (dxp) 的乘积的过程。这三个矩阵都应包含非负元素。

非负矩阵分解方程(作者图片)

WH 矩阵的乘积仅能给出矩阵 V 的近似值。因此,在应用 NMF 时应预期会有一些信息损失。

  • V (n x p): 表示特征矩阵,其中 n 是观察(样本)的数量,p 是特征(变量)的数量。这是我们要分解的数据矩阵。

  • W (n x d): 表示应用 NMF 后的转化数据矩阵。我们可以用这个转化后的矩阵代替原始特征矩阵V。因此,W 是 NMF 的最重要输出。它通过调用 Scikit-learn NMF 的 fit_transform() 方法获得。n 是观察(样本)的数量,d 是潜在因素或组件的数量。换句话说,d 描述了我们希望保留的维度量。实际上,这是一个超参数,我们需要在 Scikit-learn NMF 的 n_components 参数中指定。这个整数值应该小于特征数量 p,且大于 0。选择合适的 d 值是执行 NMF 时的一个真正挑战。我们需要考虑信息量与我们希望保留的组件数量之间的平衡。

from sklearn.decomposition import NMF

# W = transformed data matrix, V = original feature matrix
W = NMF(n_components=d).fit_transform(V)
  • H (d x p): 表示分解矩阵dp 的定义如上所述。这个矩阵不是特别重要。然而,可以通过调用 Scikit-learn NMF 的 components_ 属性来获得。
from sklearn.decomposition import NMF

# H = factorization matrix
H = NMF(n_components=d).fit(V).components_

非负矩阵分解(NMF)的 Python 实现

在 Python 中,NMF 通过使用 Scikit-learn 的 NMF() 类来实现。如你所知,Scikit-learn 是 Python 的机器学习库。

你只需导入 NMF() 类,并通过指定所需的参数来创建其实例。

# Import
from sklearn.decomposition import NMF

# Create an instance
nmf_model = NMF(n_components, init, random_state)

NMF() 类的重要参数

  • n_components: 定义组件或潜在因素的数量或我们希望保留的维度量的整数值。最重要的超参数!该值小于原始特征数量,并且大于 0。

  • init: 初始化过程的一种方法。NMF 模型返回的结果会因所选择的 init 方法而显著不同。

  • random_state: 在初始化方法为 ‘nndsvdar’‘random’ 时使用。使用一个整数以确保不同执行之间结果的一致性。

注意: NMF() 类中有许多参数。如果我们没有指定它们,调用NMF()函数时将采用默认值。要了解更多关于这些参数的信息,请参阅 Scikit-learn 文档。

NMF() 类的重要方法

  • fit(V): 从特征矩阵V中学习 NMF 模型。这里不应用任何转换。

  • fit_transform(V): 从特征矩阵V中学习 NMF 模型,并返回转换后的数据矩阵W

W = nmf_model.fit_transform(V)
  • transform(V): 返回经过拟合模型后的转换数据矩阵W
nmf_model.fit(V) # Fitted model
W = nmf_model.transform(V)
  • inverse_transform(W): 将数据矩阵W转换(恢复)回原始空间。对可视化非常有用!
recovered_data = nmf_model.inverse_transform(W)

NMF() 类的重要属性

  • components_: 返回分解矩阵H。这个矩阵不是非常重要。
H = nmf_model.components_
  • reconstruction_err_: 返回一个浮点数表示的 beta 散度,该值衡量VWH的乘积之间的距离。求解器在训练过程中尝试最小化此误差。通过设置不同的n_components值来分析此误差是选择正确的组件数量的一个好方法,d

使用非负矩阵分解(NMF)减少图像数据的维度

我们将使用 MNIST 数字数据集来完成这个任务。我们将对 MNIST 数据进行 NMF 处理,通过选择不同数量的组件来降低维度,然后将每个输出与原始数据进行比较。

第 1 步:使用 Scikit-learn 加载 MNIST 数据集

MNIST 数字数据集可以使用 Scikit-learn 按如下方式加载。

from sklearn.datasets import fetch_openml

mnist = fetch_openml('mnist_784', version=1)
image_data = mnist['data']

print("Shape:", image_data.shape)
print("Type:", type(image_data))

(图片由作者提供)

数据集被加载为 Pandas 数据框。形状为(70000, 784)。数据集中有 70000 个观测值(图像)。每个观测值有 784 个特征(像素值)。图像的大小为 28 x 28。以这种方式加载 MNIST 数据集时,每张图像被表示为一个包含 784(28 x 28)元素的一维数组。这是我们完成此任务所需的格式,数据集无需进一步修改。

或者,你也可以使用 Keras 加载 MNIST 数据集。那样的话,你将获得每张图像的 28 x 28 二维数组,而不是一维数组。你可以在这里了解更多信息。

第 2 步:可视化原始图像的样本

现在,我们将可视化 MNIST 数据集中前五张图像的样本。这个样本可以用来与 NMF 模型的输出进行比较。

import matplotlib.pyplot as plt

n = 5
plt.figure(figsize=(6.75, 1.5))
for i in range(n):
  ax = plt.subplot(1, n, i+1)
  plt.imshow(image_data.iloc[i].values.reshape(28, 28), cmap="binary")
  ax.axis('off')

plt.show()

原始 MNIST 数字的样本(图片由作者提供)

第 3 步:应用具有 9 个组件的 NMF(d = 9)

from sklearn.decomposition import NMF

nmf_model = NMF(n_components=9, init='random', random_state=0)
image_data_nmf = nmf_model.fit_transform(image_data)

print("Shape:", image_data_nmf.shape)
print("Type:", type(image_data_nmf))

(图片由作者提供)

现在,新的维度是 9。原始维度为 784。因此,维度已经显著降低!

要获取VWH矩阵的形状,我们可以运行以下代码。

print("V_shape:", image_data.shape)
print("W_shape:", image_data_nmf.shape)
print("H_shape", nmf_model.components_.shape)

(图片由作者提供)

要获取重建误差或VWH乘积之间的β散度,我们可以运行以下代码。

nmf_model.reconstruction_err_

(图片由作者提供)

重建误差非常高。这是因为我们只选择了 784 中的 9 个组件。我们可以通过可视化输出来验证这一点。

image_data_nmf_recovered = nmf_model.inverse_transform(image_data_nmf)

n = 5
plt.figure(figsize=(6.75, 1.5))
for i in range(n):
  ax = plt.subplot(1, n, i+1)
  plt.imshow(image_data_nmf_recovered[i, :].reshape(28, 28), cmap="binary")
  ax.axis('off')

plt.show()

NMF 输出:9 个组件或 d = 9

NMF 输出:9 个组件或 d = 9(图片由作者提供)

数字不清晰。你可以将此输出与原始图像的样本进行比较。

我运行了 NMF 算法,选择了 100、225 和 784 个组件。以下是结果。

NMF 输出:100 个组件或 d = 100

NMF 输出:100 个组件或 d = 100(图片由作者提供)

重建误差为 174524.20。

NMF 输出:225 个组件或 d = 225

NMF 输出:225 个组件或 d = 225(图片由作者提供)

重建误差为 104024.62。

NMF 输出:784 个组件或 d = 784(所有组件)

NMF 输出:784 个组件或 d = 784(图片由作者提供)

重建误差为 23349.67。

结论

当运行非负矩阵分解(NMF)时,组件数量增加,图像变得更清晰,重建误差变得更低。

通过查看输出和重建误差,可以选择一个合适的d值。为此,你需要多次运行 NMF 算法,这可能会根据你计算机的资源而耗时。

使用 d = 784(所有组件),你仍然会得到 23349.67 的重建误差,而不是零。

显然,W 和 H 矩阵的乘积仅仅给出了特征矩阵 V 的非负矩阵近似。

我们能在负矩阵上运行 NMF 吗?

答案是。如果你尝试使用含有负值的特征矩阵进行 NMF,你将得到以下ValueError!

import numpy as np
from sklearn.decomposition import NMF

V = np.array([[1, 1, -2, 1], [2, 1, -3, 2], [3, 1.2, -3.3, 5]])

nmf_model = NMF(n_components=2, init='random', random_state=0)
W = nmf_model.fit_transform(V)

print("V_shape:", V.shape)
print("W_shape:", W.shape)
print("Reconstruction error:", nmf_model.reconstruction_err_)

ValueError!(图片由作者提供)

在运行非负矩阵分解(NMF)时,不能打破非负性约束。特征矩阵应始终包含非负元素。

这是今天文章的结束。

如果你有任何问题或反馈,请告诉我。

你可能感兴趣的其他矩阵分解方法

阅读下一篇(推荐)

来一门 AI 课程怎么样?

加入我的私人邮件列表

永远不要再错过我的精彩故事。通过 订阅我的邮件列表,你将直接收到我发布的故事。

非常感谢你们的持续支持!下篇文章见。祝大家学习愉快!

MNIST 数据集信息

  • 引用: Deng, L., 2012. The mnist database of handwritten digit images for machine learning research. IEEE Signal Processing Magazine, 29(6), pp. 141–142.

  • 来源: yann.lecun.com/exdb/mnist/

  • 许可证: Yann LeCun(纽约大学 Courant 研究所)和Corinna Cortes(谷歌实验室,纽约)持有 MNIST 数据集的版权,该数据集在Creative Commons Attribution-ShareAlike 4.0 International LicenseCC BY-SA)下提供。你可以在这里了解更多不同的数据集许可证类型。

设计与撰写:

Rukshan Pramoditha

2023–05–06

非参数检验入门(第一部分:秩和符号检验)

原文:towardsdatascience.com/non-parametric-tests-for-beginners-part-1-rank-and-sign-tests-629704f27f2f

附有示例和 R 代码

Jae KimTowards Data Science Jae Kim

·发表于Towards Data Science ·阅读时长 9 分钟·2023 年 6 月 1 日

--

图片由Joshua Earle提供,来源于Unsplash

非参数检验是推断统计的一个重要分支。然而,许多数据科学家和分析师对它的使用还不广泛,也没有完全理解它。它是传统 t 检验等参数检验的自然替代方法,具有一系列优点,并在现代应用如 A/B 测试中具有很大的潜力。

非参数检验基于数据点的秩或符号,或使用如自助法(bootstrap)等重采样方法构建。在这篇文章中,讨论了基于秩和符号的检验,并提供了示例和 R 代码。自助法将在系列的第二部分中讨论。我想感谢 Venkat Raman,他最近的LinkedIn 帖子激发了这篇文章的写作。

1. 参数检验与非参数检验

推断统计或假设检验的关键元素如下:

  1. 零假设和备择假设(H0 和 H1)

  2. 检验统计量

  3. 在 H0 下的检验统计量的采样分布

  4. 决策规则(p 值或临界值,在给定的显著性水平下)

参数检验

包括诸如 t 检验、F 检验和卡方检验等知名检验。参数检验的一个典型特征是

  • 这需要估计未知参数,如均值和方差;以及

  • 它的采样分布遵循正态分布或由正态分布衍生出的其他分布(例如,F 分布或卡方分布)。

为确保采样分布的正态性,总体应遵循正态分布。如果总体不正态,则当样本量足够大时,采样分布可以通过正态分布进行近似。这称为渐近近似,其有效性基于在一系列参数假设下的中心极限定理。

非参数检验

非参数检验以不同于其参数对等方法的方式计算检验统计量及其采样分布:

  • 这些分布是通过完全依赖数据的方法获得的,例如数据点的秩和符号,而不需要估计总体参数。

  • 非参数检验具有精确的采样分布,即它可以在不依赖任何近似的情况下获得。该分布要么完全通过解析获得,要么可以通过蒙特卡罗模拟计算获得。

非参数检验的优点包括以下几点:

  • 它不需要强参数假设,如总体的正态性;

  • 它不需要对其采样分布进行渐近近似;

  • 由于采样分布是精确的,因此显著性水平(第一类错误的概率)在重复采样中始终是正确的(无规模失真);

  • 它的 p 值和临界值也是精确的;

  • 它的检验功效(拒绝虚假零假设的概率)通常高于其参数替代方法,尤其是在样本量较小的情况下。

它的主要缺点是当样本量较大或非常大时,计算精确的采样分布(以及精确的 p 值和临界值)可能会很耗时。然而,这在现代计算能力日益增强的情况下是一个小问题。此外,许多非参数检验采用解析公式或高效算法,当计算负担较重时,可以准确地近似其精确 p 值或临界值。

2. 简单的非参数检验

中位数的符号检验

考虑一个完全随机从其总体中生成的变量 X。使用其样本实现 (X1, …, Xn),研究者希望进行检验

H0: 中位数 = 0;H1: 中位数 ≠ 0。

在 H0 下,每个 X 值应为正(或负),概率为 0.5。或者,在 H0 下,X 的正案例的期望数为 n/2。

设测试统计量 T(X,n) 为 X > 0 的总案例数。假设 H0 下 T(X,n) 的采样分布服从一个具有 n 次试验的二项分布,每次试验的成功概率 (p) 等于 0.5,记作 B(n, p = 0.5)。分布 B(n=20, p = 0.5) 如下图所示:

T(X, n=20) 的精确采样分布,图片由作者创建

上述是 n = 20 时在 H0 下检验统计量 T(X,n) 的精确采样分布。如果 T(X,n) 的观测值接近 10,则不能拒绝原假设。检验的精确 p 值可以使用 R 中的 binom.test 函数计算。

作为一个例子,考虑以下的 X 和 Y 值,其中 n = 20。

表 1(如果 X > 0 则 Positive = 1;否则 Positive = 0),图像由作者创建

从上表 1 可知,T(X) = 12 和 T(Y) = 18,X 的中位数为 0.36,Y 的中位数为 1.67。显然,X 与 H0 高度兼容,而 Y 不兼容。X 的检验精确 p 值为 0.5034,Y 的检验精确 p 值为 0.0004,这些可以使用下面的 R 函数获得:

x = c(-0.63, 0.18,-0.84,1.60,0.33, -0.82,0.49,0.74,0.58,-0.31,
      1.51,0.39,-0.62,-2.21,1.12,-0.04,-0.02,0.94,0.82,0.59)

y=c(1.14,0.54,0.01,-0.02,1.26,-0.29,0.43,0.82,1.90,1.51,
    1.83,2.01,1.37,2.54,3.55, 3.99,5.28,5.41,3.69,2.85)

# Test statistics
Tx=sum(0.5*(sign(x)+1)); Ty=sum(0.5*(sign(y)+1))

# Sign test
binom.test(x=Tx,n=20,p=0.5); binom.test(x=Ty,n=20,p=0.5)

当 n 较大时,采样分布仍然精确地遵循 B(n, p = 0.5)。然而,该分布接近于均值为 0.5n 和方差为 0.25n 的正态分布。因此,当 n 较大时,正态分布可以作为精确分布 B(n, p=0.5) 的近似。

2. 随机性秩检验

可以使用数据的秩对一组时间序列观测进行简单的随机性检验。秩是样本观测值(X1, …, Xn)按升序排列的排名值。即,值 1 被分配给 X 的最小值;值 2 被分配给 X 的下一个最小值;依此类推,直到值 n 被分配给最大值。

原假设是时间序列是完全随机的,而备择假设是时间序列不是完全随机的。Bartels (1982) 提出了以下形式的检验统计量:

方程 (1)

其中 Ri 是第 i 个值 (Xi) 在 n 次观测序列中的秩。在原假设下,(R1, …, Rn)以相等的概率遵循(1, …., n)的任何排列。这是因为如果时间序列观测是完全随机的,其秩也应是完全随机的。基于这一点,RV 的精确分布可以使用以下 R 代码进行模拟:

nit=50000   # number of Monte Carlo iterations
n=20        # Sample size

# Calculating RV statistic
RV=matrix(NA,ncol=1,nrow=nit)
for (i in 1:nit) {
ranking <- sample(1:n, n, replace = FALSE)
RV[i,] = sum(diff(ranking)²)/(n*(n²-1)/12)
}

# Histogram
hist(RV)

# Critical Values replicating the values in Table 2 of Bartels (1982)
quantile(RV,probs = c(0.01,0.05,0.10))

1%       5%      10% 
1.013534 1.285714 1.439098 

上述 R 代码生成了在 H0 下 n = 20 时的 RV 精确采样分布,绘制在下方:

RV 的精确采样分布,图像由作者创建

精确的 p 值或临界值是按照通常的方式从上述分布中获得的。请注意,临界值(用上面的 R 代码给出)与 Bartels(1982)列出的值几乎相同。

如果方程 (1) 中给出的计算 RV 统计量小于显著性水平下的临界值,则拒绝完全随机的原假设。这是因为完全随机的序列的秩值也应是完全随机的,这会导致 RV 统计量的值较大(见下面的例子)。

当样本量较大或巨大的时候,上述模拟仍然可以进行,以生成确切的采样分布,而不会带来很大的计算负担。Bartels(1982)还提供了这些确切临界值的近似公式。

例如,考虑表 1 中的 X 和 Y,如下图所示:

X 和 Y 的时间图(图片由作者创建)

变量 X 围绕 0 随机变化,而 Y 则显示出上升趋势,这是非纯随机时间序列的特征。以下 R 代码绘制了 X 和 Y,并计算了 RV 统计量及其 p 值:

# plots
plot.ts(x,col="red",lwd=2,main="X"); abline(h=0)
plot.ts(y,col="red",lwd=2,main="Y"); abline(h=0)

# RV statistics and p-values
library(trend)
bartels.test(x); bartels.test(y)

X 的 RV 统计量为 2.21,p 值为 0.6844;而 Y 的为 0.32,p 值为 0.0000。这意味着在常规显著性水平下,不能拒绝 X 是纯随机的零假设,但 Y 的零假设被拒绝。

计算过程也在下表中说明:

X 和 Y 的 RV 统计量的说明(图片由作者创建)

作为一个纯随机序列,X 的秩值完全随机且高度变化(导致 RV 的分子值很大)。相反,Y 不是纯随机的,其秩值变化不大。由于这一特性,X 的 RV 统计量比 Y 的要大得多。

3. 威尔科克森检验

威尔科克森检验(McDonald, 2014)是 Welch 两样本 t 检验的非参数替代方法。零假设是两个总体的中位数值相等,对立假设是它们不相等。检验有两个版本:

  • 威尔科克森秩和检验(也称为 Mann–Whitney–Wilcoxon 检验),当 X 和 Y 独立时;以及

  • 威尔科克森符号秩检验,当 X 和 Y 配对时。

设(X1, …, Xn)和(Y1, …, Ym)为来自各自总体的随机样本。独立样本的检验统计量(威尔科克森秩和检验)为

其中 S(X,Y) = 1,如果 X > Y;S(X,Y) = 0.5,如果 X = Y;S(X,Y) = 0,如果 X < Y。

依赖样本情况下的统计量(威尔科克森符号秩检验)计算为

其中 Zi = Xi — Yi;如果 Zi > 0,则 sgn(Zi) = 1;否则 sgn(Zi) = —1;Ri 为|Zi|(Zi 的绝对值)的秩。注意有不同版本的 T 统计量,但它们都是等效的。

对于 U 和 T 统计量,H0 下的确切采样分布可以通过蒙特卡罗模拟获得,也可以进行近似。

对于表 1 中给出的 X 和 Y,威尔科克森检验的 R 代码为

# Wilcoxon rank-sum test (U)
wilcox.test(x,y,mu=0,paired = FALSE,exact=TRUE)

# Wilcoxon signed rank test (T)
wilcox.test(x,y,mu=0,paired = TRUE,exact=TRUE)

其中 H0: μ = 0 且μ = median(X) — median(Y)。U 检验统计量为 67.5,p 值为 0.0004;T 统计量为 11,p 值为 0.0001。因此,在 5%的显著性水平下,两者均拒绝了中位数相等的零假设。

本文回顾了三种基于秩和符号的简单非参数检验。非参数检验和参数检验之间的主要区别在于检验统计量及其在 H0 下的采样分布的计算方式。也就是说,

  • 非参数检验的检验统计量和采样分布是使用完全依赖数据的方法获得的,如秩和符号,而不需要估计未知的总体参数。

  • 它们是在没有依赖任何参数假设或基于中心极限定理的渐近近似的情况下获得的。

  • 非参数检验的采样分布是精确的。因此,检验在没有任何规模失真的情况下进行,其 p 值和临界值都是精确的。

  • 非参数检验通常比其参数检验对照显示出更好的统计性质(例如,更高的统计功效和没有规模失真),尤其是当样本量较小或参数检验的假设被违反时。

强烈建议研究人员在他们的应用中(例如 A/B 测试)采用这些非参数检验作为参数检验的替代方案。在这篇文章中,介绍了几个简单的非参数检验,附有示例和 R 代码。

参考文献:

Bartels, R. (1982). 《冯·诺伊曼随机性比率检验的秩版本》。美国统计学会杂志, 77(377), 40–46。

McDonald, J. H. (2014). 生物统计学手册. 纽约。 www.biostathandbook.com/wilcoxonsignedrank.html

非线性维度降低、核 PCA(kPCA)和多维尺度分析— Python 简单教程

原文:towardsdatascience.com/nonlinear-dimension-reduction-kernel-pca-kpca-and-multidimensional-scaling-an-easy-tutorial-63429ee9d0ae

如何在不破坏瑞士卷的情况下将其展平!!

Biman ChakrabortyTowards Data Science Biman Chakraborty

·发表于 Towards Data Science ·阅读时长 11 分钟·2023 年 12 月 11 日

--

瑞士卷数据(作者提供的图片)

在我的文章 主成分分析(PCA)— Python 简单教程 中,我讨论了如何使用 PCA 来减少数据的维度,同时尽可能保留点对点之间的距离。我用 MNIST 手写数据集举了一些例子,说明 PCA 如何将数据的维度从 784 降到 35,并且仍然能够使用高准确度的监督学习技术。

在这篇文章中,我们以一个简单的瑞士卷数据的三维示例开始,其中数据的真实流形具有 2 维,我们将从 PCA 开始。

示例:瑞士卷数据集

图 1 显示了使用sklearn库模拟的瑞士卷数据,包含𝑛=2000 个点。散点图显示了不同颜色的点分布在螺旋的不同部分。

#Load the libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

#Generate the Swiss Roll Dataset
from sklearn.datasets import make_swiss_roll

np.random.seed(42)
n_samples = 2000
X, t = make_swiss_roll(n_samples, noise=0.0)

fig = plt.figure(figsize=(10,8))
ax = fig.add_subplot(projection='3d')
ax.scatter(X[:,0], X[:,1], X[:,2], c=t, s=10, cmap='hot_r')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
#ax.set_zlim(-1,1)
ax.view_init( elev=7, azim=-80)
plt.show()

图 1:瑞士卷数据的三维视图(作者提供的图片)

我们首先对这个数据集应用 PCA,并在图 2 中可视化前两个主成分。我们观察到它仍然保留了数据的螺旋形状。螺旋的不同部分的点无法使用线性边界分离,大多数分类方法在降维后的数据上将会失败。

from sklearn.decomposition import PCA
pca_X = PCA(n_components=2)
prcomps_X= pca_X.fit_transform(X)

fig = plt.figure(figsize=(6,4))
plt.scatter(prcomps_X[:,0], prcomps_X[:,1],c=t, s=10, cmap='hot_r')
plt.xlabel('PCA Dimension 1')
plt.ylabel('PCA Dimension 2')
plt.show()

图 2:瑞士卷数据的前两个主成分维度(作者提供的图片)

它没有展开潜在的二维空间。为什么会这样?为了理解这一点,我们来看一下图 3,其中两点 A 和 B 之间的欧几里得距离用蓝色虚线表示。尽管这两点位于螺旋的完全不同部分,它们在欧几里得距离上却很接近。

u = np.linspace(0,1,100)
t1 = 1.5*np.pi*(1+2*u)
x1 = t1*np.cos(t1)
z1 = t1*np.sin(t1)
y1 = 10*np.ones((len(t1),))

fig = plt.figure(figsize=(10,8))
ax = fig.add_subplot(projection='3d')
ax.scatter(X[:,0], X[:,1], X[:,2], s=2,c='gray')
ax.plot(x1[20:90],y1[20:90],z1[20:90], c='red',linewidth=2.0)
ax.plot(x1[[20,89]],y1[[20,89]],z1[[20,89]], 'b--',linewidth=2.0)
ax.scatter(x1[[20,89]],y1[[20,89]],z1[[20,89]], 'o',s=50, alpha=1)
ax.text(x1[20], y1[20], z1[20]+1, s='A',c='k',fontweight='bold',size=12,alpha=1 )
ax.text(x1[89], y1[89], z1[89]+1, s='B',c='k',fontweight='bold',size=12,alpha=1 )
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.view_init(7,-80)
plt.show()

图 3:瑞士卷数据的测地距离与欧几里得距离(图片由作者提供)

在 PCA 中,欧几里得距离被保留。然而,这两点 A 和 B 沿螺旋流形的距离由红线显示,表明这两点在流形上相距较远。关键区别在于流形不是线性的。当我们处理线性流形时,欧几里得距离或 PCA 的效果非常好。但数据往往不在直线流形上,如这个示例数据集所示。其他图像数据,如手写数字数据,是高维数据中非线性流形的好例子。

我们需要以不同的方式定义距离以捕捉这种差异。但在此之前,我们先讨论一下如何利用距离构建主成分。

主成分:数学公式

给定一个 𝑛×𝑝 数据矩阵 𝐗,主成分方向 被定义为在这些方向上,𝐗 的样本方差依次被最大化的 𝑝 维正交向量。对于中心化的 𝐗,即 𝐗 的每一列之和为 0,第 𝑘 个主成分方向是

𝑛 维向量 𝐗𝑣_𝑘 称为 𝐗 的第 𝑘 个 主成分得分,且 𝑢_𝑘=(𝐗𝑣_𝑘)/𝑑_k 是归一化的第 𝑘 个主成分得分,其公式为

数量 d²_k/n 是 𝑣_𝑘 解释的方差量。

𝐗 的 奇异值分解 𝐗 = 𝑈𝐷𝑉^⊤ 描述了所有主成分得分和方差,其中 𝑈 是一个 𝑛×𝑝 维矩阵,列为 𝑢_1,𝑢_2,…,𝑢_𝑝,𝑉 是一个 𝑝×p 维矩阵,列为 𝑣_1,𝑣_2,…,𝑣_𝑝,𝐷 是一个 𝑝×𝑝 的对角矩阵,对角元素为 𝑑_1,𝑑_2,…,𝑑_𝑝。

让我们考虑前 𝑘 个主成分得分 𝐗𝑣_1=𝑑_1𝑢_1, …, 𝐗𝑣_𝑘=𝑑_𝑘𝑢_𝑘 作为新的特征向量。然后我们可以将其写成 𝐙=𝐗𝑉_𝑘=(𝑈𝐷)_𝑘,也就是矩阵 (𝑈𝐷) 的前 𝑘 列,并将 Z 视为 𝐗 的新低维表示。

𝐙 的行 𝑧_1,…,𝑧_𝑛 是这个新低维表示中的数据点。我们之前讨论过

在低维表示中,𝑖 和 𝑗 点之间的欧几里得距离大致等于这两点之间的原始欧几里得距离。

内积矩阵

𝑛×𝑛 维矩阵 𝐗𝐗^⊤ 被称为 内积矩阵,其 (𝑖,𝑗) 元素由 𝑥_i^⊤x_j 给出,即矩阵 𝐗 的第 𝑖 行和第 𝑗 行之间的内积。从上面我们可以得出

因此我们可以写道,

这称为 𝐗𝐗^⊤ 的 特征分解,因为 𝑈 的列是 𝐗𝐗^⊤ 的特征向量。从这个表示中,我们可以简单地计算特征分解或 分解 内积矩阵 𝐗𝐗^⊤,然后主成分得分由 𝑈𝐷 的列给出,即 𝑑_𝑗𝑢_𝑗,𝑗=1,…,𝑝。这表明,如果仅给出内积矩阵而不是原始数据点,则可以计算主成分得分。

仅从距离中获得的低维表示

假设我们只有数据点之间的距离,而没有原始数据。也就是说,我们有欧几里得距离。

或者所有 𝑖 和 𝑗。我们还能从这些距离中恢复主成分方向吗?

首先定义一个 𝑛×𝑛 维的距离矩阵 Δ,其中 (𝑖,𝑗) 元素由 Δ_𝑖𝑗 给出。我们可以从距离矩阵 Δ 恢复内积矩阵 𝐵=𝐗𝐗^⊤。

  1. 创建 𝑛×𝑛 矩阵 𝐴,其 (𝑖,𝑗) 元素由

2. 对 𝐴 进行双重中心化,即同时中心化 𝐴 的列和行,通过使用变换 B = (IM)A(IM) 来恢复矩阵 𝐵,其中

核主成分分析

核主成分分析通过用核矩阵 𝐾=((𝐾_𝑖𝑗)) 替换内积矩阵 𝐵 来简单模拟此过程,其中 𝐾_𝑖𝑗=<𝜙(𝑥_𝑖),𝜙(𝑥_𝑗)>,即特征向量 𝜙(𝑥_𝑖) 和 𝜙(𝑥_𝑗) 之间的内积。这里 𝜙 是从 ℝ^𝑝 → 𝐹 的非线性映射,𝐹 是任意维度的特征空间。这个想法类似于用于分类问题的支持向量机(SVM)中的核函数。我们将观察值投影到一个更高维的空间中,然后在该空间中获得主成分。我们可以简单地定义 𝐾_𝑖𝑗=Φ(𝑥_𝑖,𝑥_𝑗),对于径向基核,

对于多项式核,

其中 𝛾、𝑐 和 𝑑 是相应核函数的参数。

算法可以描述如下:

  1. 将 𝑛×𝑛 核内积矩阵 𝐾 定义为 𝐾=((Φ(𝑥_𝑖,𝑥_𝑗))。

  2. 使用 𝐾 的特征分解来提取 𝐾 的特征值和特征向量。

  3. 𝐾 的特征向量将给出主成分得分。

这是一种非线性降维,我们可以通过在前面示例中讨论的 瑞士卷 数据来说明核主成分分析的使用。

from sklearn.decomposition import KernelPCA
kpca_X = KernelPCA(n_components=2, kernel='rbf', gamma=0.002)
prcomps_kX= kpca_X.fit_transform(X)

fig = plt.figure(figsize=(6,4))
plt.scatter(prcomps_kX[:,0], prcomps_kX[:,1],c=t, s=10, cmap='hot_r' )
plt.xlabel('Kernel PCA Dimension 1')
plt.ylabel('Kernel PCA Dimension 2')
plt.show()

图 4:瑞士卷数据的核主成分分析维度(作者提供的图片)

在上述内容中,我们使用了一个rbf核,𝛾 = 0.002。尽管结果相比主成分分析有所改进,但它仍然没有展开瑞士卷数据,而是很好地捕捉了流形。

我们在下面用不同的模拟数据集展示了核主成分分析。

图 5:左侧模拟数据的核 PCA 维度。(作者提供的图片)

在左侧,我们有一个包含半径分别为 1.0、2.8 和 5.0 的 3 个同心圆的模拟数据,分布均匀。在右侧,我们绘制了使用rbf核和𝛾=0.3 的核 PCA 组件。我们观察到数据的三个簇之间有很好的分离。

多维缩放

我们在关于 PCA 的文章中讨论过,它试图在低维表示中保留观察之间的距离。换句话说,如果 𝑧_1,𝑧_2,…,𝑧_𝑛 是 𝑥_1,𝑥_2,…,𝑥_𝑛 的低维表示,那么 PCA 最小化

我们现在通过定义一个压力函数来推广这个想法,如下所示:

其中 𝑑_𝑖𝑗 是 𝑥_𝑖 和 𝑥_𝑗 之间的距离。通常,我们选择欧几里得距离,但也可以使用其他距离。

多维缩放寻求值 𝑧_1,𝑧_2,…,𝑧_𝑛∈ℝ^𝑘,以最小化压力函数 𝑆_𝑀(𝑧_1,𝑧_2,…,𝑧_𝑛)。

这被称为最小二乘Kruskal–Shephard 缩放。这个想法是找到一个低维的数据表示,尽可能保留成对的距离。请注意,这种近似是基于距离而不是平方距离的。

让我们看看它在瑞士卷数据上的实现。

from sklearn.manifold import MDS

embedding = MDS(n_components=2, normalized_stress='auto')
X_MDS = embedding.fit_transform(X)

fig = plt.figure(figsize=(6,4))
plt.scatter(X_MDS[:,0], X_MDS[:,1],c=t, s=10, cmap='hot_r' )
plt.xlabel('MDS Dimension 1')
plt.ylabel('MDS Dimension 2')
plt.show()

图 6:经典多维缩放的前两个维度。(作者提供的图片)

我们观察到结果与核主成分分析(kernel PCA)非常相似。

到目前为止,我们还没有超越欧几里得距离。但我们之前提到,在瑞士卷数据中,欧几里得距离并不理想。

有一类方法构造一个更复杂的距离 𝑑_𝑖𝑗 来度量高维点 𝑥_1,…,𝑥_𝑛∈ℝ^𝑝 之间的距离,然后将这些 𝑑_𝑖𝑗 通过多维缩放处理,以获得低维表示 𝑧_1,…,𝑧_𝑛∈ℝ^𝑘。这样,我们不仅得到主成分得分,我们的低维表示可能最终成为数据的非线性函数

切向距离

切向距离是一个我们可以通过多维缩放(虽然也用于其他地方)运行的更复杂的度量。

一个激励示例是我们之前使用的 手写数字数据。这里,我们有 16 \times 16 的图像,将其视为点 𝑥_𝑖∈ℝ²⁵⁶(即,它们被展开成向量)。例如,如果我们取一个“3”并 旋转 它一个小角度,我们希望旋转后的图像被认为接近原始图像。这在欧几里得距离中不一定成立。

图 7:原始“3”和旋转后的“3”图像(图片由作者提供)

我们可以定义 Δ_𝑖𝑗^rotation 为旋转后的 𝑥_𝑖 和旋转后的 𝑥_𝑗 之间的最短欧几里得距离。然而,你可以立即发现旋转数字“6”和“9”存在问题。

我们需要一些更容易计算的东西,并且将注意力限制在 小旋转 上。可以将图像的旋转集视为定义了 ℝ^𝑝 中的曲线——一个图像 𝑥_𝑖 是 ℝ^𝑝 中的一个点,当我们在任意方向上旋转它时,我们得到一条曲线。

切线距离 Δ_𝑖𝑗^tangent 通过首先计算每条曲线在观察到的图像处的切线,然后使用切线之间的最短欧几里得距离来定义。

等距特征映射(Isomap)

等距特征映射(Isomap)在更一般的设置中学习结构以定义距离。基本思想是构造一个图 𝐺=(𝑉,𝐸),即在顶点 𝑉={1,…,𝑛} 之间构造边 𝐸,基于 𝑥_1,…,𝑥_𝑛∈ℝ^𝑝 之间的结构。然后我们定义 𝑥_𝑖 和 𝑥_𝑗 之间的图距离 Δ_𝑖𝑗^Isomap,并使用多维缩放进行低维表示。

构造图:对于每对 𝑖,𝑗,如果满足以下任一条件,我们将 𝑖,𝑗 用边连接:

  • 𝑥_𝑖 是 𝑥_𝑗 的 𝑚 个最近邻之一,或者

  • 𝑥_𝑗 是 𝑥_𝑖 的 𝑚 个最近邻之一。这条边 𝑒 = {𝑖,𝑗} 的权重为 𝑤_𝑒=‖𝑥_𝑖−𝑥_𝑗‖。

定义图距离:现在我们已经构建了图,即我们已经构建了边集 𝐸,我们定义图距离 Δ_𝑖𝑗^Isomap 为从 𝑖 到 𝑗 的最短路径:

(这可以通过例如 迪杰斯特拉算法弗洛伊德算法 计算)

让我们现在深入研究其在 瑞士卷 数据上的实现。

from sklearn.manifold import Isomap

embedding = Isomap(n_components=2, n_neighbors=7)
X_iso = embedding.fit_transform(X)

fig = plt.figure(figsize=(6,4))
plt.scatter(X_iso[:,0], X_iso[:,1],c=t, s=10, cmap='hot_r' )
plt.xlabel('Isomap Dimension 1')
plt.ylabel('Isomap Dimension 2')
plt.show()

图 8:瑞士卷数据的二维表示。(图片由作者提供)

在邻居数量 𝑚=7 的情况下,多维缩放与 isomap 距离现在展开了 瑞士卷 数据。

局部线性嵌入

另一种非线性维度约减方法是 局部线性嵌入(LLE),它在精神上类似但细节却大相径庭。它不使用多维缩放。

基本思想分为两个步骤:

  1. 学习一组局部近似来描述 𝑥_1,…,𝑥_𝑛∈ℝ^𝑝 之间的结构

  2. 学习一个低维表示 𝑧_1,…,𝑧_𝑛∈ℝ^𝑘,最好与这些局部近似匹配

什么是局部近似?我们只是尝试用附近点𝑥_𝑗的线性函数来预测每个𝑥_𝑖(因此得名局部线性嵌入)。

对于每个𝑥_𝑖,我们首先找到它的𝑚个最近邻,并将它们的索引收集为 N(𝑖)。然后我们构建一个权重向量𝑤_𝑖∈ℝ^𝑛,设置𝑤_𝑖𝑗=0(当𝑗∉N(𝑖)时),并通过最小化来设置𝑤_𝑖𝑗(当𝑗∈N(𝑖)时)。

最后,我们取这些权重𝑤_1,…,𝑤_𝑛∈ℝ𝑛,并通过最小化来拟合低维表示𝑧_1,…,𝑧_𝑛∈ℝ𝑘。

我们再次使用局部线性嵌入(Local Linear Embedding)来说明如何处理瑞士卷数据。

from sklearn.manifold import LocallyLinearEmbedding

embedding = LocallyLinearEmbedding(n_components=2, n_neighbors=25)
X_lle = embedding.fit_transform(X)

fig = plt.figure(figsize=(6,4))
plt.scatter(X_lle[:,0], X_lle[:,1],c=t, s=10, cmap='hot_r' )
plt.xlabel('LLE Dimension 1')
plt.ylabel('LLEp Dimension 2')
plt.show()

图 9:瑞士卷数据的局部线性嵌入维度。(作者提供的图像)

局部线性嵌入的降维效果优于核 PCA 或经典 MDS,尽管不如Isomap

在这篇文章中,我们通过对 PCA 概念的推广学习了一些非线性降维技术。然而,没有一种单一的方法可以适用于所有类型的数据降维。根据数据的性质,我们应选择合适的降维技术。

希望你喜欢这篇文章!!

有关数据科学问题的咨询,请联系 biman.pph@gmail.com

不必 A/B 测试一切都是好的

原文:towardsdatascience.com/not-a-b_testing-everything-is-fine-7f67378428be?source=collection_archive---------2-----------------------#2023-12-20

实验领域的主流观点建议你测试一切。然而,一些关于 A/B 测试的不便真相表明,最好还是不要测试所有内容。

Yevhen KralychTowards Data Science Yevhen Kralych

·

关注 发表在 Towards Data Science ·17 分钟阅读·2023 年 12 月 20 日

--

图片由 OpenAI 的 DALL-E 创建

在在线和产品营销领域工作的人可能听说过 A/B 测试和在线实验。近年来出现了无数的 A/B 测试平台,它们鼓励你注册并利用实验的力量来提升你的产品。许多行业领导者和小型影响者都详细讲述了 A/B 测试的成功实施及其如何改变了某些业务。我相信实验的力量吗?是的,我相信。但同时,在提高统计学水平并经历了大量的试错后,我发现,就像生活和商业中的任何事情一样,有些问题有时会被忽视,这些通常是实验中不方便的缺陷,削弱了它们作为神奇工具的地位。

为了更好地理解问题的根源,我需要从在线 A/B 测试的起源开始讲起。早期,在线 A/B 测试并不存在,但一些以创新著称的公司决定将实验转移到在线领域。当然,到那时 A/B 测试已经是科学中用来发现真相的成熟方法。这些公司包括 Google(2000 年)、Amazon(2002 年)、以及一些其他大公司如 Booking.com(2004 年),微软也很快加入。我们不难发现这些公司有一个共同点,那就是它们拥有对任何业务至关重要的两个要素:资金和资源。资源不仅仅包括基础设施,还有具备专业知识和经验的人。而且他们已经拥有了数百万的用户。顺便提一句,A/B 测试的正确实施需要上述所有条件。

直到今天,他们仍然是在线实验领域最受认可的行业声音之一,和后来出现的公司,如 Netflix、Spotify、Airbnb 等相比也不遑多让。他们的想法和方法被广泛认可和讨论,他们在在线实验中的创新也同样受到关注。他们所做的事情被认为是最佳实践,虽然不可能将所有这些内容都放入一篇小文章中,但有些内容被提及得更多,它们基本可以归纳为:

  • 测试一切

  • 在测试之前绝不要发布更改

  • 即使是最小的变化也可能产生巨大影响

这些规则确实很有用,但并不适用于每个公司。事实上,对于许多产品和在线营销经理来说,盲目跟随这些规则可能会导致混乱甚至灾难。这是为什么呢?首先,盲目跟随任何东西都是一个坏主意,但有时我们必须依赖专家意见,因为我们在某些领域缺乏自己的专业知识和理解。我们通常忘记的是,并非所有专家意见都能很好地转化到我们自己的业务领域。这些成功的 A/B 测试基本原则的根本缺陷在于它们来源于多亿万公司,而你,读者,可能并不与其中任何一家相关联。

这篇文章将重点讨论统计功效这一已知概念及其扩展——实验的敏感性。这个概念是我在实验生活中每日决策的基础。

资源

“知识的幻觉比缺乏知识更糟” (某位聪明人)

如果你对 A/B 测试一无所知,这个想法可能看起来很简单——只需拿两个版本的东西进行比较。显示更高转化率(每用户收入、点击、注册等)的那个版本被认为更好。

如果你稍微了解一些,统计功效以及运行 A/B 测试所需样本量的计算,那么你会对检测所需效应大小的功效有所了解。如果你理解早期停止和窥探的警告——你就走在了正确的道路上。

当你进行一系列 A/A 测试时,对 A/B 测试简单性的误解会迅速被打破。在这些测试中,我们将两个完全相同的版本进行比较,并将结果展示给需要了解 A/B 测试的人。如果你有足够多的这些测试(例如 20–40 个),他们会发现有些测试显示处理组(也称为替代变体)比对照组(原始版本)有所改进,而有些测试则显示处理组实际上更差。当不断监控正在进行的实验时,我们可能会在大约 20%的时间看到显著结果。但如果我们比较的是两个相同的版本,这怎么可能呢?实际上,作者让公司的利益相关者进行了这个实验,并展示了这些误导性的结果,其中一位利益相关者回复说,这无疑是一个“错误”,如果一切设置得当,我们不会看到这样的情况。

这只是冰山一角,如果你已有一些经验,你会知道:

  • 实验远非简单

  • 测试不同的事物和不同的指标需要远超普通传统 A/B 测试的方法。一旦超出了简单的转化率测试,事情会变得成倍困难。你会开始关心方差及其减少,估计新奇效应和首因效应,评估分布的正态性等等。实际上,即使你知道如何处理问题,你也无法正确测试某些事物(稍后会详细说明)。

  • 你可能需要一位合格的数据科学家/统计学家。实际上,你肯定需要不止一位他们,以确定在你的具体情况下应该使用什么方法,以及需要考虑哪些注意事项。这包括确定要测试什么以及如何进行测试。

  • 你还需要一个合适的数据基础设施来收集分析数据并执行 A/B 测试。你选择的 A/B 测试平台的 JavaScript 库,最简单的解决方案,并不是最佳选择,因为它与已知的闪烁问题和增加的页面加载时间相关联。

  • 如果不完全理解背景和在各个方面偷工减料,很容易得到误导性结果。

以下是一个简化的流程图,说明了设置和分析实验过程中涉及的决策过程。实际上,事情变得更加复杂,因为我们必须处理不同的假设,如同质性、观察独立性、正态性等。如果你已经在这个领域待了一段时间,这些词汇你是熟悉的,你知道考虑所有因素可能有多么困难。如果你对实验还不熟悉,这些词汇对你来说可能毫无意义,但希望它们能暗示你,也许事情并不像看起来那么简单。

图片由Scribbr提供,已获许可

中小型公司可能会在分配设置适当 A/B 测试环境所需资源时遇到困难,每次启动 A/B 测试可能是一个耗时的任务。但这只是问题的一部分。希望在本文结束时,你能理解为什么在所有这些情况下,当经理给我发消息说“我们需要测试这个”时,我经常会回复“我们可以吗?”。真的,为什么我们不能?

用户和敏感性

在像微软和 Airbnb 这样的公司,成功实验的多数提升幅度低于 3%

那些熟悉统计功效概念的人知道,每组中的随机化单位(为简化起见,我们称之为“用户”)越多,你能够检测变体之间差异的机会就越高(其他条件相同),这也是像 Google 这样的巨大公司和你们这些普通在线业务之间的另一个关键区别——你的业务可能没有足够多的用户和流量来检测高达 3%的小差异,即使是检测 5%的提升,拥有足够统计功效(行业标准为 0.80)也是一种挑战。

在 alpha 0.05、功效 0.80、基准均值 10 和标准差 40、方差相等的情况下,不同样本大小的可检测提升。(作者提供的图片)

在上述敏感性分析中,我们可以看到,检测大约 7%的提升相对容易,只需每个变体 50000 名用户,但如果我们想要检测 3%的提升,则需要大约 275000 名用户每个变体。

温馨提示:G*Power是一个非常方便的软件,用于进行功效分析和各种功效计算,包括测试两个独立均值之间的差异的敏感性。尽管它以Cohen’s d的形式显示效应大小,但转换为提升是直接的。

在 G*Power 中执行的测试敏感性计算的屏幕截图。(作者提供的图片)

有了这些知识,我们可以采取两条路线:

  • 我们可以提出一个可接受的实验持续时间,计算 MDE,启动实验,如果未检测到差异,我们放弃更改,并假设如果存在差异,它不会高于 0.99 功效和给定显著性水平(0.05)的 MDE。

  • 我们可以决定实验的持续时间,计算 MDE,如果 MDE 对于给定的持续时间太高,我们可以选择不启动实验或在不测试的情况下发布更改(第二种选项是我通常的做法)。

事实上,第一种方法在 LinkedIn 上由 Ronny Kohavi提到过

第一种方法的缺点,尤其是对于资源有限的初创企业或小型企业,是你不断将资源投入到几乎没有机会提供可操作数据的领域。

进行不够敏感的实验可能导致参与实验的团队成员疲劳和士气低落

因此,如果你决定追求那个“圣杯”并测试所有推向生产的内容,你最终会得到的是:

  • 设计师花费数天,有时数周,设计改进版本的着陆页或产品部分

  • 开发人员通过你的 A/B 测试基础设施实施更改,这也需要时间

  • 数据分析师和数据工程师设置额外的数据跟踪(实验所需的额外指标和分段)

  • QA 团队测试最终结果(如果你足够幸运,一切正常,无需重新修改)

  • 测试被推送到生产环境中,在那里保持活动状态一到两个月

  • 你和相关利益方未能检测到显著差异(除非你运行实验的时间过长,从而危及其有效性)。

在经历了一系列这样的测试之后,包括公司的顶级增长声音在内的每个人都会失去动力,并因花费如此多的时间和精力进行测试而感到沮丧,最终却得出“变体之间没有差异”的结论。但在这里,措辞扮演了至关重要的角色。请看这里:

  • 变体之间没有显著差异

  • 我们未能检测到变体之间的差异。如果差异为 30%或更高,我们很可能(0.99)会检测到,如果差异为 20%或更高,则概率稍低(0.80)。

第二种措辞稍微复杂一些,但信息量更多。0.99 和 0.80 是不同的统计功效水平。

  • 这更符合已知的实验声明:“证据的缺乏并不是缺乏证据”。

  • 这揭示了我们的实验最初有多敏感,并可能暴露公司经常遇到的问题——进行充分实验的流量有限。

加上 Ronny Kohavi 在其白皮书中提供的知识,他声称他工作过的公司中大多数实验的提升不到 3%,这让我们感到困惑。实际上,他在其出版物中建议将 MDE 保持在 5%。

我在微软、Airbnb 和亚马逊见过成千上万的实验,极少看到关键指标的提升超过 10%。[source]

我推荐的大多数电子商务网站的默认 MDE 是 5%。[source]

在必应,每月的改进

多次实验的收入通常在个位数范围内。[source, section 4]

我仍然认为,对于那些产品优化不足且仅从 A/B 测试开始的小公司来说,可能会有更高的提升,但我觉得大多数情况下不会接近 30%。

问题

在制定 A/B 测试策略时,你需要从更大的角度来看待:可用资源、流量数量以及你手头的时间。

所以,我们最终得到的结果,以及我所说的“我们”指的是那些刚开始实验之旅的相当多的企业,就是大量资源用于设计、开发测试变体,资源用于设置测试本身(包括设置指标、细分等)——所有这些加起来实际上很难在合理的时间内检测到任何东西。我可能应该再强调一下,不应该过于相信平均测试的真实效果会有 30%的提升。

我经历过这个过程,我们在 SendPulse 尝试启动实验时有过许多失败的尝试,这些尝试总是显得徒劳,直到不久前,我意识到应该跳出 A/B 测试,看到更大的图景,而更大的图景就是这样。

  • 你拥有有限的资源

  • 你拥有有限的流量和用户

  • 你不会总是拥有进行适当实验的条件,事实上,如果你是一个较小的企业,这些条件会更为稀少。

  • 你应该在自己公司的背景下计划实验,仔细分配资源,并且要合理,避免将资源浪费在徒劳的任务上

  • 不对下一个变更进行实验是可以的,虽然不是理想的做法——企业在在线实验成为一种手段之前早已取得成功。你的一些变更会产生负面影响,一些则会产生积极影响,但只要积极影响大于负面影响,这也是可以接受的。

  • 如果你不小心,并且对实验作为唯一真实的方法过于热衷,你可能会把大部分资源投入到一个徒劳的任务中,使公司处于不利的位置。

以下是一个被称为“证据层级”的图示。虽然个人观点位于金字塔的底部,但它仍然有一定的参考价值,但更好的做法是接受这样的事实:有时这是唯一合理的选项,尽管它有缺陷。随机实验当然在金字塔的层级中更高。

科学中的证据层级。(图片由 CFCF 提供,通过 Wikimedia Commons,按 CC BY-SA 4.0 许可证授权)

解决方案

在更传统的设置中,启动 A/B 测试的流程大致如下:

  • 有人提出了某项变更的想法

  • 你估算实施变更所需的资源

  • 涉及到的人使变革成为现实(设计师、开发人员、产品经理)

  • 你设置最小可检测效应(MDE)和其他参数(alpha、beta、测试类型——双尾、单尾)

  • 你计算所需的样本量,并根据参数确定测试需要运行多久

  • 你启动测试

如上所述,这种方法是“实验优先”设计的核心——实验优先,无论代价如何,所需资源将被分配。完成实验所需的时间也不是问题。但如果你发现实施改变需要两周和三个人,而实验需要运行 8 到 12 个月才能敏感足够,你会怎么想?记住,利益相关者并不总是理解 A/B 测试的敏感性概念,因此为其持续一年进行合理化可能是一个挑战,而且世界变化迅速,这可能无法接受。更不用说技术问题会妨碍测试有效性,过期的 cookies 就是其中之一。

在资源、用户和时间有限的条件下,我们可以反转流程,将其改为“资源优先”设计,这可能在你的情况下是一个合理的解决方案。

假设:

  • 基于伪用户 ID(基于 cookies,这些 cookies 有时会过期和被删除)的 A/B 测试在较短的运行时间内更稳定,所以我们将其设置为最长 45 天。

  • 基于稳定标识符如 user-id 的 A/B 测试可以承受更长的运行时间(例如,基于转化指标为 3 个月,基于收入指标为 5 个月)。

我们接下来要做的是:

  • 查看我们在 45 天内为每个变体能收集多少单位,假设每个变体 30,000 个访问者。

  • 计算在可用样本量、alpha、功效和基础转化率下你的 A/B 测试的敏感性。

  • 如果效果足够合理(1%到 10%的提升),你可以考虑分配所需资源来实施改变和设置测试。

  • 如果效果高于 10%,特别是高于 20%,分配资源可能不是明智的决定,因为你改变后的实际提升可能会更低,而且你也无法可靠地检测到它。

我应该注意,最大实验长度和效果阈值由你决定,但我发现这些对我们来说效果很好。

  • 网站上 A/B 测试的最长时间为 45 天。

  • 基于产品中转化指标和持久标识符(如 user_id)的 A/B 测试最长时间为 60 天。

  • 基于产品收入指标的 A/B 测试的最长时间为 120 天。

决策的敏感性阈值:

  • 高达 5%——完美,启动是完全合理的,我们可以在这方面分配更多资源。

  • 5%-10%——不错,我们可以启动它,但我们应该小心投入多少资源。

  • 10%-15%——可接受,如果我们不需要花费太多资源——有限的开发时间、有限的设计时间、设置额外的指标和测试的细分不多,可以启动它。

  • 15%-20%——勉强可接受,但如果你需要更少的资源,并且你对成功有强烈的信念,启动可能是合理的。然而,你可能需要告知团队测试的敏感性较差。

  • 20% — 不可接受。进行如此低敏感性的测试仅在少数情况下是合理的,考虑一下你可以改变实验设计的哪些方面以提高敏感性(例如,可能将更改实施在多个着陆页而不是一个上,等等)。

基于敏感性的实验分类(图像由作者提供)

注意,在我的业务环境中,我们允许基于收入的实验运行更长时间,因为:

  • 收入增加是最高优先级

  • 基于收入的指标具有更高的方差,因此相比于转化率指标,敏感性较低,其他条件相同

一段时间后,我们对哪些测试足够敏感有了了解:

  • 跨整个网站或一组页面的更改(而不是单个页面)

  • “折叠线”上的更改(即着陆页的第一个屏幕上的更改)

  • 服务中的入门流程更改(因为这是用户旅程的开始,此处用户数量达到最大值)

  • 我们主要在新用户身上进行实验,忽略旧用户(以免处理可能的首因效应和新奇效应)。

更改的来源

我还应该介绍“更改的来源”这一术语,以进一步扩展我的思想和方法论。在 SendPulse,与其他公司一样,产品持续推向生产环境,包括涉及用户界面、可用性和其他外观的更改。这些更改在我们引入实验之前就已经发布了,因为,商业不能停滞不前。同时,还有一些我们特别希望测试的更改,例如有人提出了一个有趣但有风险的想法,我们不会在没有测试的情况下发布。

  • 在第一种情况下,无论如何都要分配资源,并且坚信必须实施更改。这意味着我们花费在测试上的资源只是为了设置测试本身,而不是开发/设计更改,我们称之为“自然更改”。

  • 在第二种情况下,所有用于测试的资源包括设计、开发更改和设置实验,我们称之为“实验性更改”。

为什么要进行这种分类?记住,我描述的哲学是从敏感性和资源的角度测试那些值得测试的内容,而不对公司现有流程造成太大干扰。我们不想让所有事情都依赖于实验,直到业务准备好为止。考虑到我们迄今所涵盖的一切,将实验逐步融入团队和公司的生活是有意义的。

上述分类允许我们在处理“自然更改”时使用以下方法:

  • 如果我们考虑测试“自然变化”,我们只看设置测试需要多少资源,即使敏感度超过 20%,但所需资源最小,我们也会进行测试。

  • 如果我们在指标上没有看到下降,我们坚持新的变体并将其推广到所有用户(记住,我们计划在决定测试之前就要发布它)。

  • 因此,即使测试不够敏感以检测变化,我们也只是为自己设定了一个“护栏”——以防变化真的大幅度下降了指标。我们不会通过寻找确凿证据来阻止推广变化——这只是一个预防措施。

另一方面,在处理“实验性变化”时,协议可能有所不同:

  • 我们需要基于“敏感度”来做决策,这在这里起着关键作用,因为我们要考虑分配多少资源来实施变化和测试本身,只有在我们有很好的机会检测到效果时,才应承诺进行工作。

  • 如果我们在指标上看不到提升,我们倾向于放弃变化并保留原始版本,因此,资源可能会浪费在后续会被舍弃的东西上——这些资源应该得到仔细管理。

结果(希望是积极的)

这种策略如何帮助一个成长中的企业适应实验心态?我觉得读者此时应该已经明白,但回顾一下也没坏处。

  • 你通过逐步引入 A/B 测试,给团队时间适应实验。

  • 你不会将有限的资源用于那些没有足够敏感度的实验,而资源对于成长中的初创公司来说是一个问题——你可能需要将它们用于其他地方。

  • 因此,你不会通过不断催促你的团队进行从未达到统计显著性的实验来迫使拒绝 A/B 测试,即使在启动它们时花费了大量时间——当你的大部分测试显示出显著的东西时,你会意识到这些努力并非徒劳。

  • 通过测试“自然变化”,即团队认为应该推出的即使没有实验的东西,只有在它们显示出统计学上显著下降时才拒绝,这样你不会造成太大干扰,但如果测试确实显示出下降,你会播下怀疑的种子,表明我们的决策并非全都完美。

重要的是要记住——A/B 测试并非微不足道,它们需要巨大的努力和资源来做到正确。像世界上的任何事物一样,我们应该了解自己的极限和在特定时间的能力。仅仅因为我们想攀登珠穆朗玛峰,并不意味着我们应该在不了解自己极限的情况下去做——有很多创业公司的尸体在比喻的珠穆朗玛峰上,他们超出了自己的能力范围。

祝你实验顺利!

并非全是彩虹和阳光:ChatGPT 的阴暗面

原文:towardsdatascience.com/not-all-rainbows-and-sunshine-the-darker-side-of-chatgpt-75917472b9c?source=collection_archive---------0-----------------------#2023-01-27

第一部分:大型语言模型的风险和伦理问题

Mary Reagan PhDTowards Data Science Mary Reagan PhD

·

关注 发表在 Towards Data Science · 9 分钟阅读 · 2023 年 1 月 27 日

--

图片 由 Fiddler AI 提供,已获许可

如果你还没听说过 ChatGPT,你一定是躲在一块非常大的石头下。这款病毒式聊天机器人被用于自然语言处理任务,如文本生成,已经引起了广泛关注。其背后的公司 OpenAI 最近在谈判以获得 290 亿美元的估值¹,微软可能很快会再投资 100 亿美元²。

ChatGPT 是一个自回归语言模型,使用深度学习生成文本。它通过在各种领域提供详细的回答让用户感到惊讶。它的回答非常有说服力,以至于很难判断这些回答是否由人类撰写。ChatGPT 建立在 OpenAI 的 GPT-3 系列大型语言模型(LLMs)基础上,于 2022 年 11 月 30 日推出。它是最大的 LLM 之一,可以写出优美的文章和诗歌,生成可用的代码,以及根据文本描述生成图表和网站,所有这些都不需要或几乎不需要监督。ChatGPT 的回答如此出色,它显示出有可能成为无处不在的 Google 搜索引擎的潜在对手。

大型语言模型是……嗯……庞大的。它们在大量的文本数据上进行训练,这些数据可以达到 PB 级,并且具有数十亿个参数。最终的多层神经网络通常有几个 TB 大。围绕 ChatGPT 和其他 LLM 的炒作和媒体关注是可以理解的——它们确实是人类智慧的杰出成果,有时会以新兴行为让这些模型的开发者感到惊讶。例如,通过在提示的开头使用某些“魔法”短语,如“让我们一步一步思考”,可以改善 GPT-3 的回答。这些新兴行为表明了模型的巨大复杂性以及当前的可解释性缺失,甚至让开发者思考这些模型是否具有意识。

大型语言模型的幽灵

尽管有许多积极的宣传和炒作,但负责任的人工智能社区中的一些人发出了强烈的警告。值得注意的是,在 2021 年,负责人工智能领域的著名研究员 Timit Gebru 发表了一篇论文,警告了与 LLM 相关的许多伦理问题,这也导致她被谷歌解雇。这些警告涉及广泛的问题:缺乏可解释性、抄袭、隐私、偏见、模型鲁棒性以及其环境影响。让我们稍微探讨一下这些话题。

信任与缺乏可解释性:

深度学习模型,尤其是 LLM,已经变得如此庞大和不透明,以至于即使是模型的开发者也常常无法理解为什么他们的模型会做出某些预测。这种缺乏可解释性是一个重大问题,特别是在用户希望了解模型生成特定输出的原因和方式的情况下。

在轻松的风格下,我们的首席执行官 Krishna Gade 使用了 ChatGPT 创作了一首关于可解释人工智能的诗歌,风格模仿了约翰·济慈,坦率地说,我认为效果相当不错。

克里希纳正确指出,关于模型如何得出输出的透明度不足。对于 LLMs 生成的作品,缺乏关于输出所依赖数据来源的透明度意味着 ChatGPT 提供的答案无法被正确引用,因此用户无法验证或信任其输出⁹。这导致 ChatGPT 创建的答案在像 Stack Overflow¹⁰这样的论坛上被禁止。

当使用像 OpenAI 的嵌入模型¹¹这样的工具时,透明度和理解模型如何得出输出变得尤为重要,因为这些模型本质上包含了一层模糊性,或者在模型用于高风险决策的情况下。例如,如果有人使用 ChatGPT 获取急救指示,用户需要知道回应是可靠的、准确的,并且来源可信。虽然存在各种事后方法来解释模型的选择,但这些解释在模型部署时往往被忽视。

这种缺乏透明度和可信度的后果在假新闻和虚假信息泛滥的时代尤其令人担忧,因为 LLMs 可能会被微调以传播虚假信息并威胁政治稳定。虽然 Open AI 正在研究各种方法来识别其模型的输出,并计划嵌入加密标签以对输出进行水印¹²,这些负责任的 AI 解决方案的速度仍然不足,可能也不够充分。

这引发了关于……的问题。

剽窃:

精心制作的 ChatGPT 文章的来源难以追踪自然引发了关于剽窃的讨论。但这是否真的是个问题?作者认为不是。在 ChatGPT 出现之前,学生们已经可以利用写作服务¹³,且一直有少数学生决心作弊。但对 ChatGPT 使所有孩子变成无脑的剽窃者的担忧已成为许多教育者的关注重点,并导致一些学区禁止使用 ChatGPT¹⁴。

关于剽窃的讨论掩盖了与 LLMs 相关的更大、更重要的伦理问题。鉴于这一话题的广泛关注,我不得不提及它。

隐私:

如果大型语言模型用于处理敏感数据,则面临数据隐私泄露的风险。训练集来源于各种数据,有时包括个人身份信息¹⁵ — 姓名、电子邮件地址¹⁶、电话号码、地址、医疗信息 — 因此,可能会出现在模型的输出中。虽然这是任何使用敏感数据训练的模型都会面临的问题,但考虑到 LLMs 的训练集规模庞大,这个问题可能影响到很多人。

潜在偏见:

如前所述,这些模型在大量的数据语料库上进行训练。当数据训练集如此庞大时,它们变得非常难以审计,因此固有地存在风险⁵。这个数据包含社会和历史偏见¹⁷,因此任何在这些数据上训练的模型都可能会再现这些偏见,除非采取适当的保护措施。许多流行的语言模型被发现含有偏见,这可能导致偏见思想的传播增加,并对某些群体造成持续伤害。GPT-3 被发现表现出常见的性别刻板印象¹⁸,将女性与家庭和外貌联系起来,并将她们描述为比男性角色更没有权力。令人遗憾的是,它还将穆斯林与暴力联系在一起¹⁹,其中三分之二对包含“穆斯林”一词的提示的回应中都包含了暴力的参考。很可能存在更多的偏见关联尚未被发现。

值得注意的是,微软的聊天机器人在 2016 年迅速变成了最糟糕的网络恶搞者的模仿者²⁰,吐露种族主义、性别歧视和其他辱骂性语言。尽管 ChatGPT 设有过滤器以尝试避免最糟糕的此类语言,但它可能并非万无一失。OpenAI 为人工标注员支付费用,以标记最具攻击性和令人不安的数据,但其合作公司因每小时仅支付 2 美元而受到批评,工人报告称遭受了深刻的心理伤害²¹。

模型的鲁棒性和安全性:

由于大语言模型(LLMs)是预先训练的,并随后根据特定任务进行微调,这会导致一系列问题和安全风险。值得注意的是,LLMs 缺乏提供不确定性估计的能力²²。没有了解模型的置信度(或不确定性),我们很难判断何时可以信任模型的输出,何时需要保持怀疑态度²³。这影响了它们在微调到新任务时的表现以及避免过拟合的能力。可解释的不确定性估计有潜力提高模型预测的鲁棒性。

模型安全性是一个迫在眉睫的问题,原因在于 LLM 的父模型在微调步骤之前的普遍性。因此,模型可能成为单点故障和攻击的主要目标,这会影响从原始模型派生的任何应用。此外,由于缺乏监督训练,LLMs 可能会受到数据投毒的威胁²⁵,这可能导致恶意言论被注入以针对特定公司、群体或个人。

LLM 的训练语料库是通过爬取互联网上各种语言和主题来源创建的,然而它们仅仅反映了最有可能接触并频繁使用互联网的人群。因此,AI 生成的语言趋于同质化,并且通常反映了最富有的社区和国家的实践⁶。对于不在训练数据中的语言,LLMs 更容易失败,需要更多的研究来解决与分布外数据相关的问题。

环境影响与可持续性:

一篇由 Strubell 和合作者于 2019 年发表的论文概述了 LLM 训练生命周期的巨大碳足迹²⁴ ²⁶,其中,训练一个拥有 2.13 亿参数的神经架构搜索模型被估算产生的碳排放量是普通汽车生命周期排放量的五倍以上。考虑到 GPT-3 拥有 175 亿 个参数,而下一代 GPT-4 传闻拥有 100 万亿 个参数,这在面临气候变化带来的日益严重的恐怖和破坏的世界中是一个重要的方面。

现在怎么办?

任何新技术都会带来优势和劣势。我已经概述了许多与大语言模型(LLMs)相关的问题,但我想强调的是,我也对这些模型为我们每个人带来的新可能性和承诺感到兴奋。社会有责任采取适当的保障措施,并明智地使用这项新技术。任何在公众中使用或公开的模型都需要进行监控、解释,并定期审计模型的偏差。在第二部分中,我将概述对 AI/ML 从业者、企业和政府机构的建议,说明如何解决特定于 LLMs 的一些问题。

参考文献:

  1. ChatGPT 创始人与投资者谈论以 290 亿美元估值出售股份,华尔街日报,2023 年。

  2. Todd Bishop, 微软计划通过潜在的 100 亿美元投资和新的整合来巩固与 OpenAI 的关系,GeekWire,2023 年。

  3. Parmy Olson, ChatGPT 应该让 Google 和 Alphabet 感到担忧。为什么要搜索,当你可以问 AI?,彭博社,2022 年。

  4. Subbarao Kambhampati, 改变人工智能研究的性质,ACM 通信,2022 年。

  5. Roger Montti, 什么是 Google LaMDA,为什么有人相信它具有意识?,搜索引擎期刊,2022 年。

  6. 艾米莉·M·本德,蒂姆尼特·戈布鲁,安吉丽娜·麦克米伦-梅杰,施玛尔格·施密切,关于随机鹦鹉的危险:语言模型会不会过大? FAccT 2021。

  7. 蒂姆尼特·戈布鲁,有效利他主义推动一种危险的“AI 安全”品牌,Wired,2022 年。

  8. 克里希纳·加德,www.linkedin.com/feed/update/urn:li:activity:7005573991251804160/

  9. 奇拉格·沙,艾米莉·本德,情境化搜索,CHIIR 2022。

  10. 为什么发布 GPT 和 ChatGPT 生成的回答目前不可接受,Stack Overflow,2022 年。

  11. openai.com/blog/new-and-improved-embedding-model/

  12. 凯尔·维格斯,OpenAI 尝试给 AI 文本加水印遇到的限制,TechCrunch,2022 年。

  13. 7 个最佳大学论文写作服务:评论和排名,GlobeNewswire,2021 年。

  14. 卡尔汉·罗森布拉特,ChatGPT 被禁止在纽约市公立学校的设备和网络上使用,NBC 新闻,2023 年。

  15. 尼古拉斯·卡尔尼,弗洛里安·特拉默,埃里克·沃勒斯,马修·贾吉尔斯基,阿里尔·赫伯特-沃斯,凯瑟琳·李,亚当·罗伯茨,汤姆·布朗,道恩·宋,乌尔法尔·厄尔林森,阿丽娜·奥普雷亚,科林·拉费尔,从大型语言模型中提取训练数据,USENIX 安全研讨会,2021 年。

  16. 马丁·安德森,从预训练自然语言模型中检索现实世界的电子邮件地址,Unite.AI,2022 年。

  17. 玛丽·里根,理解 AI 系统中的偏见和公平性,Fiddler AI 博客,2021 年。

  18. 李·露西,戴维·班曼,GPT-3 生成故事中的性别和表现偏见,ACL 叙事理解研讨会,2021 年。

  19. 安德鲁·迈尔斯,剔除流行语言模型 GPT-3 中的反穆斯林偏见,斯坦福 HAI 新闻,2021 年。

  20. 詹姆斯·文森特,推特让微软的 AI 聊天机器人在不到一天的时间里变成了一个种族主义者,The Verge,2016 年。

  21. 比利·佩里戈,OpenAI 在肯尼亚以每小时不到 2 美元的工资雇佣工人来减少 ChatGPT 的毒性,《时代》杂志,2023 年。

  22. Karthik Abinav Sankararaman, Sinong Wang, Han Fang, BayesFormer:具有不确定性估计的 Transformer,arxiv,2022 年。

  23. Andrew Ng, ChatGPT 疯狂!加密混乱削减了 AI 安全资金,Alexa 讲睡前故事,《The Batch — Deeplearning.ai 通讯》,2022 年。

  24. Emma Strubell, Ananya Ganesh, Andrew McCallum, 深度学习在 NLP 中的能源和政策考虑,ACL 2019 年。

  25. Eric Wallace, Tony Z. Zhao, Shi Feng, Sameer Singh, 对 NLP 模型的隐蔽数据污染攻击,NAACL 2021 年。

  26. Karen Hao, 训练单一 AI 模型可能排放相当于五辆汽车使用寿命的碳,《麻省理工学院技术评论》,2019 年。

不那么庞大的语言模型:优质数据打败巨人

原文:towardsdatascience.com/not-so-large-language-models-good-data-overthrows-the-goliath-a8226bd1ae61?source=collection_archive---------6-----------------------#2023-08-23

(图像由 DALL·E 生成)

如何制造一个百万级别的语言模型来超越十亿级别的模型

Gennaro S. RodriguesTowards Data Science Gennaro S. Rodrigues

·

关注 发表在 Towards Data Science · 6 min read · 2023 年 8 月 23 日

--

在这篇文章中,我们将探讨语言模型(LM)如何通过关注更好的数据和训练策略,而不仅仅依赖庞大的规模,来实现类似 LLM 的结果(有时甚至更好),以及人们如何已经成功且民主地做到这一点。

大型语言模型(LLMs)已经显著发展。它们带来了从生成类似人类的文本到理解复杂上下文的显著特性。虽然最初的兴奋主要集中在具有大量参数的模型上,但最近的发展表明,大小并不是唯一重要的因素。最近,一个新的概念“小型语言模型”(SLM)应运而生,致力于更智能地开发语言模型。

大模型的兴起

随着 LLMs 的出现,叙事变得简单明了——更大更好。具有更多参数的模型被期望能够更好地理解上下文,减少错误,提供更好的答案。但随着模型的增长,它们对计算资源的需求也增加了。训练这些巨型模型变得非常昂贵,这不是每个人都愿意(也不一定能)支付的。

对质量和效率的强调

认识到仅仅增加参数的不可持续性和递减回报,研究人员开始重新思考策略。与其只是将钱投入云端(增加更多的参数),一些研究人员转而利用更好的数据和更高效的训练策略。这个想法很优雅:一个训练良好的小模型可能会超越一个训练不良的大模型。但这可能吗?

Chinchilla 和 LLMs 训练的最佳点

“Chinchilla 论文” [1] 是对该领域的重要贡献,提供了对 LLMs 训练的有趣见解。实验似乎表明,在训练 LLMs 时存在一个“最佳点”。超过这个点,投入更多的资源(如更多参数)不一定会导致性能的成比例提高。论文强调,定义模型性能的不仅仅是模型的大小,而是数据的质量和使用的数据量。作者发现,为了实现计算最优训练,模型大小和训练令牌的数量应当等比缩放:每增加一倍的模型大小,训练令牌的数量也应增加一倍。

他们通过训练 Chinchilla(一个 70 亿参数的模型,训练于 1.4 万亿令牌)来测试这一点。尽管 Chinchilla 小得多,但在几乎所有评估中,包括语言建模、问答、常识任务等,Chinchilla 的表现都优于 Gopher。

Chinchilla 的大小和训练令牌与 SOTA LLMs 的比较。(来源:[1])

即使在其减少的规模下,Chinchilla 在各种任务上的表现也优于其 SOTA 对手:

大规模多任务语言理解(MMLU)。报告了 57 项任务中的平均 5-shot 准确率,并与来自[2]的模型和人类准确率比较,以及来自[3]的 73 名竞争性人类预测者在 2022/2023 年 6 月的 SOTA 准确率的平均预测。(来源:[1])

阅读理解和自动推理是语言模型通常会测试的标准任务。它测试模型理解文本更广泛背景的能力。在我们的案例中,可以通过预测那些仅在模型能够理解单词与之前上下文关系的情况下才会预期到的单词来进行示例。通常使用基准测试和数据集,如 RACE-h、RACE-m [4] 和 LAMBADA [5] 进行评估。即使在这种难以定义和测试的任务中,Chinchilla 也超越了更大的模型。

在阅读理解方面,Chinchilla 相比于 Gopher 显著提升了性能。(来源:[1])

Chinchilla 是许多尽管没有注重扩展规模但仍展现出有希望结果的语言模型之一。

LLaMA

LLaMA[6] 甚至更进一步。作者引入了从 7B 到 65B 参数的较小基础语言模型。它们在超过 1 万亿个标记的数据上进行训练,使用的仅是公开数据,使其兼容开源。

LLaMA-13B 在大多数基准测试中超过了参数多达 175B 的 GPT-3,而其体积小于 GPT-3 的 10 倍。作者认为,考虑到目标性能水平,训练时间更长的小型模型在给定计算预算下比大型模型更具优势,因为推理效率更高。

LLaMA 在常识推理任务中的零-shot 表现。(来源:[6])

一些项目甚至成功在预算有限的安卓智能手机上运行 LLaMA(或其版本),进一步证明我们正走在通过低计算资源实现语言模型民主化的正确道路上(LLaMA.c [7])。

LLaMA-65B(我知道,现在不算那么小,但仍然……)在与使用专有数据集的现有最先进模型如 PaLM-540B 的竞争中表现良好。这清楚地表明,优质数据不仅能提升模型的性能,还能使其变得民主化。机器学习工程师无需巨额预算就能在优质数据集上获得良好的模型训练。

优质数据胜过巨无霸

进一步巩固了语言模型不需要庞大才能表现良好的论点,TinyStories [8] 提供了一个合成数据集,其中包含仅供小孩子(最多四岁)理解的单词。它可以用来训练参数少于 1000 万的小型语言模型(SLMs),这些模型能够生成语法、推理和连贯性良好的多段故事。这与先前的研究形成对比,125M+ 参数的模型——如 GPT-Neo(小型)和 GPT-2(小型)——在生成连贯文本方面存在困难。

训练了 TinyStories 的模型能产生与参数大两个数量级的模型相当的输出。(来源:[8])

TinyStories 的一个令人兴奋的方面是数据集本身是由 GPT-3.5 和 GPT-4 创建的。作者们还引入了一种新的 SLM 评估范式,使用 GPT-4 对生成的故事在语法、情节和创意等维度上进行“评分”。这克服了标准基准测试要求受限输出的局限性。

结论

语言模型的发展展示了 AI 中的一个关键教训:更大并不总是更好。随着社区的持续进化和创新,人们意识到效率、数据质量和优化的训练策略是机器学习未来的关键。

关键要点

  • Chinchilla 证明了在训练语言模型时,令牌数量和训练数据质量之间存在一个最佳点。这一点与(或更重要于)模型参数的数量定义同样重要;

  • LLaMa 显示了使用仅公开数据就能达到类似 Chinchilla 的结果,证明了这一策略具有普遍可用性;

  • 像 TinyStories 这样的数据集可以用于训练小型语言模型(少于 1 亿),在特定任务上超越了十亿规模的模型。

参考文献

[1] Hoffmann, Jordan 等. “训练计算最优的大型语言模型。” arXiv 预印本 arXiv:2203.15556(2022 年)。

[2] D. Hendrycks 等. “测量大规模多任务语言理解。” arXiv 预印本 arXiv:2009.03300(2020 年)。

[3] J. Steinhardt. 来自 AI 预测的更新和经验教训,2021 年。URL https://bounded-regret.ghost.io/ai-forecasting/。

[4] Lai, Guokun 等. “RACE: 大规模阅读理解数据集来自考试。” 2017 年自然语言处理会议论文集,页码 785–794,哥本哈根,丹麦。计算语言学协会。

[5] Paperno 等,2016 “LAMBADA 数据集:需要广泛语篇背景的单词预测。” arXiv:1606.06031(2016 年)。

[6] Touvron, Hugo 等. “LLaMA: 开放且高效的基础语言模型。” ArXiv abs/2302.13971(2023 年)

[7] github.com/karpathy/llama2.c

[8] Eldan, Ronen 和 Yuan-Fang Li. “TinyStories:语言模型可以小到什么程度仍然能够说出连贯的英语?” ArXiv abs/2305.07759(2023 年)

那么,为什么我们应该关心推荐系统呢?特邀:对汤普森采样的简要介绍

原文:towardsdatascience.com/now-why-should-we-care-about-recommendation-systems-ft-a-soft-introduction-to-thompson-sampling-b9483b43f262

正在进行的推荐系统系列

Irene ChangTowards Data Science Irene Chang

·发布于 Towards Data Science ·阅读时间 12 分钟·2023 年 11 月 7 日

--

图片由 Myke Simon 提供,来源于 Unsplash

今天我发现自己再次陷入了相同的情境,连续第 100...01 天,一边浏览 Netflix 寻找观看的节目,一边拿着晚餐盒子吃饭。我的推荐内容中充斥着过多的亚洲浪漫和美国成长题材的建议,可能是基于我一个月或两个月前观看过的这些类别的某几部剧。 “这里没什么好看的…”–我一边阅读完所有简介一边叹了口气,自信地觉得自己能预测剧情的发展。我掏出了另一个备用的娱乐选项 Tiktok,同时潜意识里想着我可能需要不感兴趣一些视频,而喜欢保存其他视频,以便…推荐算法今天给我推送一些新的内容流。

推荐系统(RecSys)可以被认为是一个已经非常成熟的算法,它已经深深植入我们的日常生活中,以至于在 1 到 Chat-GPT 的尺度上,它在学术界和非学术界都感觉像是 80 年代的趋势。然而,它绝不是一个近乎完美的算法。操作推荐应用程序所面临的伦理、社会、技术和法律挑战从未成为研究的前沿(就像大多数其他技术产品一样……)。例如,选择性群体的不公平和隐私侵犯是围绕 RecSys 的热门担忧,但这些问题仍未得到实施公司充分解决。此外,还存在许多更微妙的问题,通常没有得到足够的深思,其中之一是个体决策过程中的自主权丧失。一种“强大的”RecSys 无疑可以将用户推向某个方向[2],使他们购买、观看、思考、相信他们如果不受到这种操控本不会做的事情。

因此,我想在我的研究生学习旅程中写一系列文章,随着我开始学习并深入探讨 RecSys 的优缺点……一切从零开始!我觉得可以从思考电影和……汤普森采样开始!

汤普森采样

汤普森采样 (TS) 是推荐系统文献和强化学习中的基础算法之一。正如 Samuele Mazzanti 在这篇精彩的文章中清楚解释的那样,它可以被认为是在在线学习环境中更好的 A/B 测试。简单来说,在电影推荐的背景下,TS 试图识别出最适合推荐给我的电影,以最大化我点击观看的机会。它可以通过相对较少的数据有效地做到这一点,因为它允许在每次观察到我是否点击电影时更新参数。粗略地说,这种动态特性使得 TS 能够在考虑我的观看历史和收藏的系列之外,还能实时考虑像浏览或在我当前正在使用的应用程序中的搜索结果等因素,以给我最合适的建议。然而,在这个适合初学者的教程中,我们仅仅看一下下面的简化分析。

让我们进一步分析吧!

考虑这 3 部电影,尽管它们都很棒,但我却有自己个人的排名。假设这 3 部电影中,有一部是如果出现在我的推荐中我将 100%重新观看,有一部是我极不可能重新观看的(5%),还有一部是我每次看到时有 70%的机会会点击观看。显然,TS 在事先并不了解这些关于我的信息,它的目标是学习我的行为,以便,如常识所说,推荐我它知道我一定会点击观看的电影。

作者提供的图片

在 TS 算法中,主要的工作流程如下:

  1. 行动:TS 建议我观看特定的电影,在数百部电影中选择

  2. 结果:我决定电影对我来说足够有趣并点击观看,或者我觉得无聊,在阅读了简介后点击退出页面。

  3. 奖励:可以被看作是 TS 在我点击观看某部电影时获得的“积分”数量,或者在我不点击时 TS 失去的积分。在基本的电影或广告推荐设置中,我们可以将奖励视为结果的等价物,因此 1 次点击电影=1 积分!

  4. 更新知识:TS 记录我的选择并更新其对我最喜欢电影的信念。

  5. 重复第 1 步(可以在我当前的浏览会话中,或者第二天晚餐时间),但现在有了关于我偏好的额外知识。

探索/利用

这是该文献中使用最多的术语,也是区分 TS 和其他相关算法的关键。上述第 5 步是这个逻辑开始发挥作用的地方。在 TS 的世界中,一切都存在某种程度的不确定性。我每周喝三次拿铁和五次抹茶并不一定意味着我比拿铁更喜欢抹茶,如果只是那一周(而我每周平均实际上喝的拿铁比抹茶多)呢?因此,TS 中的一切都由某种类型的分布表示,而不仅仅是单个数字。

图 1 在某一周,我喝了 5 杯抹茶和 3 杯拿铁(左),但平均每周我喝的拿铁比抹茶多(右)——作者提供的图片

起初,TS 显然对我对电影的偏好有很多不确定性,因此它的优先任务是探索这一点,通过给我提供许多不同的电影建议来观察我的反应。在经过几次点击和跳过后,TS 可以大致了解我倾向于点击的电影和没有效益的电影,从而对下一次给我推荐的电影有了更多的信心。这时,TS 开始利用高回报的选项,它会给我推荐我经常点击的电影,但仍然留有一些探索的空间。随着更多观察的积累,信心不断建立,简单情况下,探索的工作将变得非常少,因为 TS 已经对能够带来大量奖励的推荐有了很大的信心。

探索与利用通常被称为权衡或困境,因为过多的探索(即使在获得足够证据后仍然没有排除低价值选项)会导致大量损失,而过多的利用(即过快地排除太多选项)可能会错误地排除真正的最佳行动。

分布:Beta-Bernoulli

如上面的抹茶拿铁图所示,TS 使用不同类型的分布来理解我们对不同选项的偏好。在最基本的电影(和广告)情况下,我们通常使用 Beta-Bernoulli 组合。

伯努利分布.) 是一种离散分布,其中只有两种可能的结果:1 和 0。伯努利分布只有一个参数,表示某个变量,比如 Y,取值为 1 的概率。因此,如果我们说 Y~ Bern(p),比如 p = 0.7,这意味着 Y 有 0.7 的机会取值为 1,而 1–p = 1–0.7 = 0.3 的机会取值为 0。因此,伯努利分布适合用于建模奖励(在我们的例子中也是结果),因为我们的奖励只有两种结果:点击未点击

另一方面,Beta 分布用于建模 TS 对我电影兴趣的信念。Beta 分布有两个参数,alpha 和 beta,通常被认为是成功和失败的次数,两者都必须 ≥ 1。因此,使用 Beta 分布来建模我点击观看和跳过电影的次数是合适的。我们来看一个例子。这里有 3 个不同的 Beta 分布,代表 3 部电影,在 10 次观察中,所以所有 3 部电影的点击和跳过次数总和相同(10),但点击和跳过率不同。对于电影 1,我点击观看 2 次(alpha = 2)和跳过 8 次(beta = 8);对于电影 2,我点击观看 5 次和跳过 5 次;对于电影 3,我点击观看 8 次和跳过 2 次。

图 2. 图片由作者提供

根据图表,我们可以看到,我再次观看电影 2 的概率大约在 50% 处达到峰值,而电影 1 的这个概率要低得多。例如。我们可以将这些曲线视为观看电影的概率的概率,因此 Beta 分布非常适合表示 TS 对我电影偏好的信念。

算法

在本节中,我将帮助你清楚地理解算法的实现和方法论。首先,这是 Thompson Sampling 算法的一个片段,分别是伪代码和 Python 实现。伪代码摘自一本关于 TS 的精彩书籍,A tutorial on Thompson Sampling [Russo, 2017]。

图 3 Thompson Sampling,Python 实现(左)和伪代码(右)— 图片由作者提供

让我们来详细分析一下!

样本模型

算法的第一步涉及“猜测”我对每部电影的喜好。如前一节所述,我对每部电影的偏好可以使用 Beta 曲线表示,如图 2 所示,而 TS 对此没有先验知识,并且试图弄清楚这些 Beta 曲线的样子。在t = 1(第一轮)时,TS 可以假设我对所有 3 部电影的喜好相同,即点击和跳过的初始次数相等(我的 3 条 Beta 曲线将看起来相同)。

图 4 TS 对我对 3 部电影的偏好的首次猜测相同

这里的三种分布就是图 3 中伪代码中的p。从每个分布中,TS 将采样一个值,用 theta 表示,以帮助下一步的动作选择。

图 5 示例的 alpha-beta 值对,表示我们对每部 3 部电影的初始猜测的分布(也称为动作/手臂)

选择并应用动作

在此步骤中,TS 根据采样的 theta 值中最大的值选择要执行的动作(即选择推荐的电影)。以图 2 为例。假设我们只有 2 部电影——电影 1 和电影 3。使用最大的 theta 选择动作的想法是,如果真实分布几乎没有重叠,而我在我们的例子中几乎肯定喜欢一部电影多于另一部,那么电影 1 的采样 theta 很可能不会大于电影 3 的 theta。以类似的方式,如果我们只考虑电影 2 和 3,我们可以看到现在这些分布之间有更多重叠。然而,如果我们继续在足够多的轮次中采样更多的 theta 值,那么我们可以观察到电影 3 的 thetas > 电影 2 的 thetas 的比例大于反之,TS 将有足够的信息得出电影 3 是更好的“动作”的结论。一般来说,这也是为什么未知真实分布越明显,TS 找出哪个动作或手臂是最优的实验轮次就越少的原因。

在应用选择的动作后,TS 将收到我的反馈,即我是否点击观看电影。正如上面提到的,这一结果也被视为我们对相应动作的奖励。TS 将记录这一观察结果,并在下一步中用它来更新对我电影偏好的信念。

更新分布

在上面的 Beta 分布描述中,我们确定 Beta 分布的特征是成功次数和失败次数。我点击观看某部电影的次数越多,该电影的 Beta 分布的模式就越趋近于 1,而相反地,我跳过推荐的次数越多,模式就越趋近于 0。因此,在电影被推荐并记录响应后,对电影的信念更新是通过将电影的 Beta 分布的 alpha 或 beta 参数加 1 来完成的,具体取决于电影是被点击还是被跳过。

这种简单且易于解释的参数更新方法就是为什么 Beta Bernoulli 是一种非常常见的 TS 模型。

结果与讨论

回到文章开始时的情境。我们正在猜测 3 部电影中哪一部最适合推荐给我,假设有一部我会 100% 点击观看,一部我有 70% 的点击概率,另一部只有 5% 的点击概率(再次强调,这些信息 TS 并不知道)。第一行展示了两种不同的模拟起始点,这将使我们观察是否可以通过不同的初始先验信念达到相同的最终结果。

图 6 TS 模拟的不同轮数 T = 5, 10, 20。 Beta 分布代表了实验结束时 TS 对我电影偏好的信念。 左列:初始分布为 Beta(1, 1) 时的结果。右列:初始分布为 Beta(2, 3) 时的结果

从图 6 中,我们可以看到我最终最喜欢的电影是电影 1 — 《寄生虫》(对不起,漫威粉丝)!!

如我们所见,两种情况的探索过程不同,其中 Beta(1, 1) 的初始猜测导致更快地找到被认为是我最喜欢的电影。只需要 T=10 轮就可以看出 TS 明显在开发电影 1,这意味着 TS 已经推荐了电影 1 并得到了我的点击,因此其 Beta 分布向右拉动,因此从更新的分布中采样的 theta 超过了它的竞争对手,导致了开发。这种开发在 T=5 轮时已经出现,但根据相应的图表,电影 1 和电影 3 的 Beta 之间仍然存在较多重叠,它们的模式并不完全不同,这意味着 TS 仍然不完全确定电影 1 是最优的行动。

另一方面,Beta(2, 3) 的初始信念使得 TS 需要更多的轮次才能到达电影 1(T=20)。即使在 T=10 时,电影 1 和 电影 3 之间仍存在很大的不确定性,并且观察到由于 theta 采样的随机性,电影 3 可能被错误地当作最佳选项。这项实验表明,每个行动的初始先验知识在检测最佳臂的速度上起着作用,关于这个主题我们可以在未来的文章中进一步深入探讨。

需要注意的是,如果电影的实际分布几乎相同(比如电影 1 和电影 3 的点击率分别为 100% 和 98%),TS 很可能无法识别最佳行动,因为来自一个分布的样本 thetas 超过另一个分布的样本 thetas 的比例会被拆分。因此,如果由于偶然性,电影 3 的“较大 thetas”更多,TS 将更多地利用这个选项,导致其被错误地识别为最佳行动。

实验的另一个发现是,TS 仅能告诉我们最佳行动是什么,但不能提供关于其他选项的信息。这是因为在探索过程中,TS 会迅速淘汰那些被认为不是最优的选项,因此 TS 停止接收这些行动的进一步信息,从而不能提供最佳选项以外的行动的正确排序。

结论

在这篇文章中,我们探讨了汤普森采样算法,并通过电影推荐模拟进行了演练。汤普森采样在提供预测时涉及大量的分布和先验知识,这是贝叶斯模型的核心概念,我计划在即将到来的文章中与大家进一步讨论。如果你读到了这里,谢谢你的时间,希望这篇教程能给你提供对这个算法的技术和直观理解!如果你有任何进一步的问题,随时通过我的 LinkedIn 联系我,很高兴与您联系并回答问题!

参考文献:

[1] 推荐系统在不同领域和背景中的潜在风险和挑战是什么?

[2] 推荐系统及其伦理挑战

[3] 何时应该选择“汤普森采样”而不是 A/B 测试

现在你看到我 (CME): 基于概念的模型提取

原文:towardsdatascience.com/now-you-see-me-cme-concept-based-model-extraction-97231105f8fa?source=collection_archive---------5-----------------------#2023-09-22

一种标签高效的基于概念的模型方法

Dmitry KazhdanTowards Data Science Dmitry Kazhdan

·

关注 发表在 Towards Data Science ·6 分钟阅读·2023 年 9 月 22 日

--

来自 CIKM 会议上展示的 AIMLAI 研讨会论文:“Now You See Me (CME): 基于概念的模型提取” (GitHub)

视觉摘要。图片由作者提供。

总结

问题 — 深度神经网络模型是黑箱,无法直接解释。因此 — 很难建立对这些模型的信任。现有方法,如概念瓶颈模型,能够使这些模型更具可解释性,但需要高昂的标注成本来标注基础概念。

关键创新 — 一种以 弱监督方式 生成基于概念的模型的方法,从而显著减少注释需求

解决方案 — 我们的 基于概念的模型提取(CME)框架,能够以 半监督 方式从预训练的原始卷积神经网络(CNN)中提取基于概念的模型,同时保持最终任务性能。

原始 CNN 的端到端输入处理。作者提供的图像。

两阶段概念模型处理。作者提供的图像。

概念瓶颈模型(CBMs)

近年来,解释性人工智能(XAI)[1] 领域对概念瓶颈模型(CBM)方法 [2] 的兴趣激增。这些方法引入了一种创新的模型架构,其中输入图像分为两个不同的阶段处理:概念编码概念处理

在概念编码过程中,概念信息从高维输入数据中提取。随后,在概念处理阶段,提取的概念信息用于生成所需的输出任务标签。CBMs 的一个显著特点是它们依赖于具有语义意义的 概念表示,作为下游任务预测的中间、可解释的表示,如下所示:

概念瓶颈模型处理。作者提供的图像。

如上所示,CBM 模型通过结合 任务损失 确保准确的任务标签预测,以及 概念损失 确保准确的中间概念预测进行训练。重要的是,CBMs 增强了模型的透明度,因为底层概念表示提供了一种解释和更好理解模型行为的方法。

概念瓶颈模型提供了一种新型的设计可解释的 CNN,允许用户通过概念将现有领域知识编码到模型中。

总体而言,CBMs 是一项重要的创新,使我们更接近于更透明和可信的模型。

挑战:CBMs 具有高概念注释成本

不幸的是,CBMs 在训练期间需要大量的概念注释。

目前,CBM 方法要求对 所有 训练样本进行显式注释,同时 包括最终任务和概念注释。因此,对于一个包含 N 个样本和 C 个概念的数据集,注释成本从 N 个注释(每个样本一个任务标签),增加到 N(C+1)* 个注释(每个样本一个任务标签,且每个概念一个概念标签)。在实践中,这可能迅速变得难以管理,特别是对于具有大量概念和训练样本的数据集。

例如,对于一个包含 10,000 张图片和 50 个概念的数据集,注释成本将增加 50*10,000=500,000 个标签,即增加 半百万 个额外注释。

不幸的是,概念瓶颈模型需要大量的概念标注进行训练。

利用 CME 的半监督概念模型

CME 依赖于 [3] 中强调的类似观察,其中观察到原始 CNN 模型通常在其 隐藏空间 中保留大量有关概念的信息,这可以用于无额外标注成本的概念信息挖掘。重要的是,这项工作考虑了基础概念 未知 的场景,并且必须以无监督的方式从模型的隐藏空间中提取。

使用 CME,我们利用上述观察,考虑一个场景,在该场景中,我们 已经 了解基础概念,但每个概念只有少量样本标注。类似于 [3],CME 依赖于给定的预训练原始 CNN 和少量概念标注,以 半监督的方式 提取进一步的概念标注,如下所示:

CME 模型处理。图片来源于作者。

如上所示,CME 使用预训练模型的隐藏空间以 事后 方式提取概念表示。详细信息见下文。

概念编码器训练:与 CBMs 处理原始数据的概念编码器从零开始训练不同,我们以 半监督的方式 设置概念编码器模型训练,使用原始 CNN 的隐藏空间:

  • 我们首先预先指定一组层 L,从原始 CNN 中用于概念提取。这可以是 所有 层,也可以只是最后几层,具体取决于可用的计算能力。

  • 接下来,对于每个概念,我们在 L 中 每个 层的隐藏空间上训练一个独立的模型,以预测该概念的值。

  • 我们继续选择具有最佳模型准确度的模型和相应层作为“最佳”模型和层来预测该概念。

  • 因此,在为概念 i 做出预测时,我们首先检索该概念的最佳层的隐藏空间表示,然后将其通过相应的预测模型进行推断。

总体来说,概念编码器 功能可以总结如下(假设总共有 k 个概念):

CME 概念编码器方程。图片来源于作者。

  • 这里,LHS 上的 p-hat 代表概念编码器函数。

  • gᵢ 项代表在不同层隐藏空间上训练的隐藏空间到概念模型,i 代表概念索引,范围从 1 到 k。在实际应用中,这些模型可以非常简单,例如线性回归器或梯度提升分类器。

  • f(x) 项代表原始原始 CNN 的子模型,提取输入在特定层的隐藏表示。

  • 在以上两种情况下, 上标指定了这两种模型操作的“最佳”层

概念处理器训练:CME 中的概念处理器模型训练是通过使用任务标签作为输出、概念编码器 预测作为输入来设置的。重要的是,这些模型操作在更紧凑的输入表示上,因此可以通过可解释的模型(如决策树(DTs)或逻辑回归(LR)模型)直接表示。

CME 实验与结果

我们在合成数据集(dSpritesshapes3d)以及具有挑战性的真实数据集(CUB)上的实验表明,CME 模型:

  • 实现高概念预测准确度,在许多情况下可与 CBM 相媲美,即使在与最终任务无关的概念上:

CBM 和 CME 模型的概念准确度,绘制了三个不同预测任务中的所有概念。图像由作者提供。

  • 允许对概念进行人为干预 — 即允许人们通过修正少量选定概念来快速改善模型性能:

CME 和 CBM 模型性能在不同概念干预程度下的变化。图像由作者提供。

  • 从概念的角度解释模型决策, 允许实践者直接绘制概念处理器模型:

一个概念处理器模型直接可视化的示例,针对一个选定任务。图像由作者提供。

  • 通过分析模型层间的隐藏空间,帮助理解模型对概念的处理:

一个简单 CNN 的隐藏空间可视化示例。列代表不同的层,行代表不同的概念,每行的颜色对应于该概念的值。标有 * 的为“最佳” CME 层。图像由作者提供。

通过在弱监督领域定义基于概念的模型(CME),我们可以开发出显著更具标签效率的基于概念的模型

主要结论

通过利用预训练的普通深度神经网络,我们可以在极大降低注释成本的情况下获得概念注释和基于概念的模型,与标准 CBM 方法相比。

此外,这不仅严格适用于与最终任务高度相关的概念,在某些情况下,也适用于与最终任务独立的概念。

参考文献

[1] Chris Molnar. 解释性机器学习。 christophm.github.io/interpretable-ml-book/

[2] Pang Wei Koh, Thao Nguyen, Yew Siang Tang, Stephen Mussmann, Emma Pierson, Been Kim, 和 Percy Liang. 概念瓶颈模型。在国际机器学习会议,第 5338–5348 页。PMLR*(2020)。

[3] Amirata Ghorbani, James Wexler, James Zou, 和 Been Kim. 朝向自动化基于概念的解释。 神经信息处理系统的进展32

np.stack() — 如何在 Numpy 和 Python 中堆叠两个数组

原文:towardsdatascience.com/np-stack-how-to-stack-two-arrays-in-numpy-and-python-fc910dd2d57a

Numpy 中堆叠的初学者和高级示例——学习如何轻松地连接数组序列

Dario RadečićTowards Data Science Dario Radečić

·发表于Towards Data Science ·阅读时间 7 分钟·2023 年 1 月 10 日

--

图片由Brigitte Tohm提供,来源于Unsplash

Numpy 是一个在数据科学和机器学习中非常出色的库,因此如果你想成为数据专业人士,就必须掌握它。掌握这个包的方方面面是必要的,因为重新发明轮子是没有意义的——几乎你能想到的任何东西都已经实现了。

今天你将了解所有关于 np stack 的信息——即 Numpy 的stack()函数。简单来说,它允许你按行(默认)或按列连接数组,具体取决于你指定的参数值。我们将讨论基础知识和函数签名,然后进入 Python 中的示例。

什么是 np stack?

Numpy 的 np stack 函数用于在新轴上堆叠/连接数组。它将返回一个单一数组,作为堆叠多个形状相同的序列的结果。你也可以堆叠多维数组,稍后你将很快学到这一点。

但首先,让我们解释一下水平堆叠和垂直堆叠的区别。

Numpy 中的水平堆叠与垂直堆叠

水平堆叠数组意味着你将具有相同维度的数组堆叠在彼此之上。每个输入数组将在结果数组中成为一行。

查看下面的图像以更好地理解:

图 1 — 解释水平堆叠(图像由作者提供)

垂直堆叠则完全相反。两个垂直堆叠数组的一行包含来自两个数组的对应元素。

例如,垂直堆叠数组 Z 的第一行将包含输入数组 X 和 Y 的第一个元素。

也许你会发现视觉上更容易理解:

图 2 — 垂直堆叠解释(图片由作者提供)

说到这里,让我们看看 np.stack 函数的签名。

函数参数解释

np.stack 函数最多可以接受三个参数,其中只有第一个是必需的:

  • arrays - 数组的序列,或你想要堆叠的数组数组

  • axis - 整数,沿着你想要堆叠数组的轴(0 = 按行堆叠,1 = 对于一维数组按列堆叠,或使用 -1 使用最后一个轴)

  • out - 可选的结果存放位置。如果提供,输出数组的形状必须与堆叠结果的形状匹配

理论讲解够了!现在我们来看看一些实际的示例。

Numpy 堆叠实战 — 函数示例

我们讨论了很多关于水平和垂直堆叠的内容,所以让我们看看它在实践中的表现。

Numpy 水平堆叠(按行)

要水平堆叠两个 numpy 数组,只需调用 np.stack 函数并传入这些数组。无需其他参数:

import numpy as np

arr1 = np.array([1, 2, 3, 4])
arr2 = np.array([5, 6, 7, 8])

# Horizontal (row-wise) stacking #1
arr_stacked = np.stack([arr1, arr2])
print('Numpy horizontal stacking method #1')
print('-----------------------------------')
print(arr_stacked)

这是得到的结果:

图 3 — Numpy 中的水平堆叠 (1)(图片由作者提供)

如你所见,输出看起来很像 Pandas DataFrame 的 Numpy 版本,这意味着一个数组几乎等于矩阵的一行。

更明确地说,你可以通过将 axis=0 作为第二个参数来实现相同的结果:

import numpy as np

arr1 = np.array([1, 2, 3, 4])
arr2 = np.array([5, 6, 7, 8])

# Horizontal (row-wise) stacking #2
arr_stacked = np.stack([arr1, arr2], axis=0)
print('Numpy horizontal stacking method #2')
print('-----------------------------------')
print(arr_stacked)

结果是相同的:

图 4 — Numpy 中的水平堆叠 (2)(图片由作者提供)

接下来,让我们探索垂直堆叠。

Numpy 垂直堆叠(按列)

要垂直堆叠两个 numpy 数组,只需将 axis 参数的值更改为 1:

import numpy as np

arr1 = np.array([1, 2, 3, 4])
arr2 = np.array([5, 6, 7, 8])

# Vertical (column-wise) stacking #1
arr_stacked = np.stack([arr1, arr2], axis=1)
print('Numpy vertical stacking method #1')
print('---------------------------------')
print(arr_stacked)

现在,数组按列堆叠,这意味着你将有与提供的数组数量相等的列:

图 5 — Numpy 中的垂直堆叠 (1)(图片由作者提供)

使用简单的一维数组,你还可以设置 axis=-1 来垂直堆叠数组:

import numpy as np

arr1 = np.array([1, 2, 3, 4])
arr2 = np.array([5, 6, 7, 8])

# Vertical (column-wise) stacking #2
arr_stacked = np.stack([arr1, arr2], axis=-1)
print('Numpy vertical stacking method #2')
print('---------------------------------')
print(arr_stacked)

结果是相同的:

图 6 — Numpy 中的垂直堆叠 (2)(图片由作者提供)

接下来,让我们讨论一些关于堆叠 N 维数组的内容。

使用 stack() 合并一维数组

你已经看到如何堆叠一维数组了,下面是回顾:

import numpy as np

arr1 = np.array([1, 2, 3, 4])
arr2 = np.array([5, 6, 7, 8])

# Stacking 1D arrays
arr_stacked = np.stack([arr1, arr2])
print('Numpy stacking 1D arrays')
print('------------------------')
print(arr_stacked)

输出结果:

图 7 — 堆叠一维数组(图片由作者提供)

记住,如果你想按列堆叠数组,可以更改 axis 参数的值。

使用 stack() 合并二维数组

对于使用 np.stack 堆叠二维数组,过程是一样的。这是一个示例:

import numpy as np

arr1 = np.array([
    [1, 2, 3, 4],
    [5, 6, 7, 8]
])
arr2 = np.array([
    [9, 10, 11, 12],
    [13, 14, 15, 16]
])

# Stacking 2D arrays #1
arr_stacked = np.stack([arr1, arr2])
print('Numpy stacking 2D arrays method #1')
print('----------------------------------')
print(arr_stacked)

我们现在得到一个三维数组,每个元素是两个水平堆叠数组的二维数组:

图 8 — 堆叠二维数组 (1)(图片由作者提供)

一如既往,你可以垂直堆叠二维数组:

import numpy as np

arr1 = np.array([
    [1, 2, 3, 4],
    [5, 6, 7, 8]
])
arr2 = np.array([
    [9, 10, 11, 12],
    [13, 14, 15, 16]
])

# Stacking 2D arrays #2
arr_stacked = np.stack([arr1, arr2], axis=1)
print('Numpy stacking 2D arrays method #2')
print('----------------------------------')
print(arr_stacked)

以下是输出结果:

图像 9 — 堆叠 2D 数组(2)(作者提供的图像)

这就是 numpy 堆叠的基本知识了。接下来,我们将介绍一些高级用法示例和常见问题。

高级:循环中的 np stack

常见的问题之一是如何在循环中使用 np stack。这里有一个示例 —— 它将两个二维数组首先合并为一个三维数组:

import numpy as np

arr1 = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
arr2 = np.array([[13, 14, 15], [16, 17, 18]])
matrix = [arr1, arr2]

print('Numpy stacking in a loop - intermediary matrix')
print('----------------------------------------------')
print(matrix)

这是中间输出:

图像 10 — 循环中的 Numpy 堆叠(1)(作者提供的图像)

现在,要生成一个水平堆叠元素的二维数组,你可以使用循环:

arr3 = np.empty(shape=[0, matrix[0].shape[1]])

for m in matrix:
    arr3 = np.append(arr3, m, axis=0)

print('Numpy stacking in a loop')
print('------------------------')
print(arr3)

结果如下:

图像 11 — 循环中的 Numpy 堆叠(2)(作者提供的图像)

现在我们将讨论一些关于 Python 中 np stack 函数的常见问题。

常见问题

stack 和 concatenate 有什么区别?

简而言之,当传入两个一维数组时,np stack 函数将返回一个二维数组。而 np concatenate 函数则将所有输入数组的元素合并为一个一维数组。

什么是 numpy dstack?

numpy dstack 函数允许你按索引合并数组,并将结果存储为堆栈。这是一个示例:

import numpy as np

arr1 = np.array([1, 2, 3, 4])
arr2 = np.array([5, 6, 7, 8])

# Numpy depth stacking - dstack
arr_stacked = np.dstack([arr1, arr2])
print('Numpy depth stacking')
print('--------------------')
print(arr_stacked)
print()
print(f'Shape = {arr_stacked.shape}')

输出结果:

图像 12 — Numpy dstack(作者提供的图像)

因此,我们有两个 1x4 的数组进来,dstack 将它们垂直地合并成一个三维数组格式。这对于某些用例来说非常方便。

喜欢这篇文章吗?成为 Medium 会员 以继续无限学习。如果你使用以下链接,我将获得你的会员费的一部分,但对你没有额外费用。

medium.com/@radecicdario/membership

最初发表于 https://betterdatascience.com ,日期为 2023 年 1 月 10 日。

NP-什么?优化问题的复杂性类型解释

原文:towardsdatascience.com/np-what-complexity-types-of-optimization-problems-explained-558d43276044

复杂的建筑。图片由作者使用 Midjourney 创建。

计算机科学中的一个核心问题介绍

Hennie de HarderTowards Data Science Hennie de Harder

·发布于 Towards Data Science ·阅读时间 11 分钟·2023 年 8 月 17 日

--

为什么 最短路径问题 容易解决,而 旅行推销员问题 却不容易?这些问题的数学原理是什么?如何确定如果问题规模增加,是否会需要不可管理的步骤?在这篇文章中你将了解这个主题的基础知识。如果你想深入了解,我在文章末尾还附上了与这个主题相关的千年大奖难题的简要说明。

在我们开始讨论 NP 难度之前,你应该了解时间复杂度的基础。如果你熟悉时间复杂度、大 O 符号和最坏情况分析,你可以跳过以下部分。

时间复杂度

当我们使用计算机编程时,我们经常会遇到可以用不同方式解决的问题。我们需要考虑的一个重要方面是这些解决方案的效率。时间复杂度帮助我们理解当问题规模变大时,算法运行的速度如何。

大 O 符号 可以比作用一个简单的标签来标记算法,这个标签告诉我们算法完成所需的时间,基于我们处理的事物数量。这是一种描述算法步骤数量相对于问题输入规模增长的方式。

注意:时间复杂度本质上与步骤数量有关,而不是实际时间,因此这个名字不太准确。否则你可以使用更快的计算机和相同的算法。

给箱子(算法)贴上标签:你有多快?作者提供的图片。

我们通常关注最坏情况,因为我们希望确保无论我们给算法什么输入,它都不会花费超过一定的时间。这有助于确保我们的解决方案在情况变得困难时仍然可靠。

如果你正在寻找一本书中的特定页面,而你的算法从书的开头查到结尾,最坏的情况就是那一页是最后一页。作者提供的图片。

就像我们驾驶时选择最快路线一样,我们也希望为我们的问题选择最有效的算法。我们根据算法的时间复杂度来比较算法。一个运行更快(时间复杂度更低)的算法就像是选择更快的路线到达目的地。如前所述,一个更快的算法在最坏情况下需要的步骤更少。

现在,让我们通过一些实际的例子来探讨这些概念,使其更清晰。

常数时间:O(1)

想象一下你在计划去附近的公园旅行。你就住在公园旁边,所以无论你邀请多少朋友,步行到公园所需的时间始终不变。无论是你一个人还是一群 10 人,前往公园所需的时间都是恒定的——它不会因人数的不同而改变。

去公园的时间大致相同,不论人数多少。作者提供的图片。

常数时间的编程例子是使用一个键在字典中找到对应的值。

线性时间:O(n)

现在,想象一下计划一次野餐并召集所有朋友。当你邀请每个朋友时,你需要单独打电话或发消息。所以,如果你邀请 10 个朋友,你就打 10 个电话;如果你邀请 50 个朋友,你就打 50 个电话。联系朋友所需的时间随着朋友数量的增加而线性增长。

邀请一个朋友或 10 个朋友是有区别的:邀请 10 个朋友需要的时间大约是邀请 1 个朋友的 10 倍。作者提供的图片。

如果你遍历列表中的所有项目一次,这需要线性时间。

对数时间:O(log n)

在公园里,你将玩一个需要寻找隐藏宝藏的游戏。游戏会给你线索,帮助你缩小搜索区域。每个线索帮助你排除掉公园的一半。随着你找到更多的线索,搜索区域变得更小。另一种看法是:如果公园的面积是 100 平方米,那么找到宝藏最多需要 7 步(2⁷ = 128)。即使公园面积增加了 1 平方米,也没有关系,它仍然需要 7 步(101 < 128)。直到我们达到 128 之前,最坏情况下我们永远不需要额外的步骤(这里的最坏情况是什么?)。

每一步将搜索空间减少 50%。额外的一平方米不会有太大区别。图片由作者提供。

这类似于二分搜索的工作原理,其中每个线索都将可能的位置数量减少一半,导致对数时间复杂度。

二次时间:O(n²)

你正在尝试规划一个公路旅行,访问城市中的各种旅游景点。为了找出每对位置之间的最短路线,你需要将每个位置与其他每个位置进行比较。因此,如果你有 5 个位置,你需要进行 5 * 4 = 20 次比较;如果你有 10 个位置,你需要进行 10 * 9 = 90 次比较。随着位置数量的增加,比较次数会二次增长。

mylist = [1, 2, 3, 4, 5]

for n in mylist:
    for m in mylist:
        print(n*m) 

嵌套的 for 循环是一个例子,它需要二次时间。你要对所有元素进行两次循环,因此需要 n * n 次迭代。上面的代码示例将打印 5 * 5 = 25 个数字。

阶乘时间:O(n!)

最后但同样重要的是:阶乘时间。想象一下,你正在和朋友们组织一次盛大的旅行冒险。你想规划一个访问所有愿望列表上的国家的最佳路线。然而,找到最佳路线涉及考虑所有可能的国家排列。随着你在列表中添加更多国家,可能路线的数量会按阶乘增长。例如,如果你有 3 个国家,有 3! = 6 种可能的路线(ABC,ACB,BAC,BCA,CAB,CBA)。如果你有 4 个国家,则有 4! = 24 条路线,依此类推。

在这种情况下,随着你在列表中增加国家的数量,考虑所有可能路线所需的时间会因阶乘增长而急剧增加。这反映了阶乘时间复杂度,其中所需的时间随着输入规模的增加而极快增长。正如规划行程随着你添加更多国家而变得繁重一样,阶乘时间复杂度由于其快速增长在处理大问题时变得不可行。

我们可以在图表中可视化时间复杂度。绿色的是快速的,而橙色和红色的是在 n 增加时难以处理的。如果可能的话,你应该尽量避免这些。

与输入大小相关的不同 Big O 时间复杂度的操作数量。图片由作者提供。

让我们深入了解这与 NP-难度的关系。

NP-难度

常数、线性和二次时间是多项式时间的例子。多项式时间的形式是 O(nˣ)。如果我们将多项式时间与指数和阶乘时间进行比较,你会发现对于大值 n 有很大的差异:

看看粉色数字,它们展示了随着数据大小的增加,指数和阶乘时间复杂度函数如何增加。点击放大。图片由作者提供。

如表所示,多项式时间与指数时间和阶乘时间之间存在巨大差异。对于大值 n,多项式时间是相对快速的,对于 n < 100 不超过 1 秒,而指数时间和阶乘时间则无法管理。(尽管如果 n 的值非常大,多项式时间仍可能需要相当长的时间。)

NP-难度的概念有助于根据计算复杂度对问题进行分类。问题分为四个类别,其中最简单的区分是 P 问题和 NP 问题之间的区别。

P 问题

P 问题,其中 P 代表多项式,是可以在多项式时间内解决的。换句话说,它们的解决方案可以相对快速地找到,解决它们所需的时间最多是问题规模的多项式函数。

一个数学优化问题的例子,它属于 P 问题,是最短路径问题。如何在最小化距离的情况下从点 A 到点 B?迪杰斯特拉算法 是一个可以在多项式时间内解决这个问题的算法(或者通过优先队列更少的时间)。

从 A 到 B 的最短路径是什么?图像由作者提供。

NP 问题

另一方面,NP 问题(或非确定性多项式问题)涵盖了更广泛的挑战。这些问题的特点是提出的解决方案可以在多项式时间内有效验证。然而,找到解决方案本身可能需要指数级或甚至阶乘时间,使得它们比 P 问题更难解决。换句话说:如果你处理的是一个大型 NP 问题,尝试暴力破解是愚蠢的。

NP 问题是汉密尔顿路径问题:给定一个图,是否存在一条路径访问每个顶点一次并返回到起始顶点?如果有人声称他们找到了一个汉密尔顿路径,你可以通过检查它是否确实访问了每个顶点一次来验证。

验证黄色路径是否为汉密尔顿路径很简单。图像由作者提供。

在 NP 问题的范围内,我们遇到两个子类别:NP 完全问题和 NP 难问题。

NP 完全问题

在 NP 问题中,NP 完全问题是最具挑战性的。NP 完全问题是指属于 NP 且具有一个特殊属性的问题:如果你能找到一个多项式时间算法来解决它,你就能在多项式时间内解决所有 NP 问题。实质上,NP 完全问题是 NP 中“最难”的,因为它们至少与 NP 中的其他任何问题一样困难。

最著名的 NP-完全问题之一是旅行商问题(TSP),你需要找到一条最短的路线,同时访问所有给定的位置一次。可能的路线数量可以用 n!计算,其中 n 是要访问的位置数量。在之前的文章中,我使用了混合整数规划(第二个示例)和模拟退火(第一个示例)对 TSP 进行了编码。与之密切相关的是中国邮差问题,你需要至少访问图中的每条边一次。

TSP:访问每个节点。中国邮差问题:至少访问每条边一次。这两个都是 NP-完全问题。图片由作者提供。

NP-困难问题

NP-困难问题虽然相关,但与 NP-完全问题不同。NP-困难问题是指至少与 NP 中最难的问题一样困难的问题,不管它是否在 NP 中。换句话说,NP-困难问题不一定具有像 NP 中那样的高效验证过程。相反,它们作为极其困难的计算问题的基准。

停机问题询问的是,给定一个程序和输入,该程序是否会在该输入上停止(停止执行)或无限运行。它是不可判定的,这意味着没有算法可以在所有情况下解决它。停机问题是 NP-困难的,但不在 NP 中,因为它的解决方案无法高效验证。在下一个代码片段中,你会看到停机问题的两个简单示例,对于其他程序,确定其是否为停机问题可能是有问题的。

# Example 1\. Program with the following code will keep running
while True:
  continue

# Example 2\. Program with only a print statement will halt after printing
print('Halt')

总结来说,计算复杂度的范围包括从易于解决到极具挑战性的问题。虽然 P 问题可以高效解决,但 NP 问题引入了一层复杂性,其中 NP-完全问题代表了 NP 类中计算难度的巅峰。此外,NP-困难问题提供了对计算可行性边界的见解,即使它们不直接属于 NP 类。

如果我们假设 P ≠ NP(更多内容在最后一部分),这就是 P、NP、NP-完全和 NP-困难问题集合的欧拉图。图片由作者提供。

不幸的是,NP 问题在现实生活中无处不在。例如,优化配送卡车的路线、高效安排任务、设计电子电路,甚至蛋白质折叠都是 NP 问题的实例。这些问题的难处理性使得寻找最佳解决方案成为一项巨大的挑战:它们在大输入下 notoriously 难以解决。这种困难通常导致了近似算法、启发式方法和专门技术的发展,以寻找可能不是最优但在某些范围内可接受的解决方案。

如何书写历史?继续阅读以了解更多!照片由 Natalia Y. 提供,来源于 Unsplash

千年奖问题:P 对 NP

NP 硬度与其中一个未解的千年奖问题有关。这个问题很容易理解。如果你能证明 P = NP 或者 P ≠ NP,你就解决了它!这意味着什么?正如你现在所知道的,P 问题在大 n 下需要的步骤远少于 NP 完全问题。但从未证明 P 问题与 NP 问题确实不同,这意味着尚不确定是否存在多项式时间算法来解决 NP 问题。如何解决这个问题?有两种可能的结果:

  1. 如果你能找到一个解决 NP 完全问题(例如旅行商问题)的多项式时间算法,你就证明了 P = NP。

  2. 如果你能证明不存在多项式时间算法来解决特定的 NP 问题,你就证明了 P ≠ NP。你可能需要提出一个新的 NP 完全问题来实现这一点。

证明第一个观点将震撼世界,因为互联网安全是建立在 NP 硬度的基础上的。如果能够找到破解代码的多项式时间算法,那将是灾难性的。许多科学家认为第二种结果是正确的,即不存在能够解决 NP 完全问题的多项式时间算法。但这从未被证明。如果你想书写历史,这就是你的机会!

结论

深入探讨计算复杂性的细节揭示了计算机科学中问题解决的挑战。时间复杂性让我们能够在问题规模增长时评估算法效率。通过这种视角,我们探讨了常数、线性、对数、平方和阶乘复杂度等场景。

转向 NP-困难性,我们探索了各种复杂度的问题。P 与 NP 问题尤为突出——P 问题可以在多项式时间内解决,而 NP 问题则提供快速验证,但通常需要指数级或阶乘时间才能找到解决方案。NP 完全问题和 NP-困难问题成为计算挑战的巅峰。NP 完全问题涵盖了 NP 中的“最难”问题,提供了高效解决所有 NP 问题的捷径。NP-困难问题不局限于 NP,代表了计算复杂性的顶峰。

最后,P 与 NP 问题的谜团——一个难以捉摸的千年难题——有潜力重塑我们对计算复杂性的理解。这个问题的深远影响使得证明 P = NP 或 P ≠ NP 的探索者有可能改变历史的进程。

相关

## 为什么每位数据科学家都应该学习数学优化

数据科学课程目前关注数据可视化、特征工程、数据处理、(有/无)监督学习……

towardsdatascience.com ## 五种将数学优化与机器学习结合的方法

结合两种力量的实际例子。

towardsdatascience.com ## 精确算法还是启发式算法?

逐步指南,帮助你为数学优化问题做出正确选择

towardsdatascience.com

NT-Xent(归一化温度调节交叉熵)损失函数的解释及在 PyTorch 中的实现

原文:towardsdatascience.com/nt-xent-normalized-temperature-scaled-cross-entropy-loss-explained-and-implemented-in-pytorch-cc081f69848?source=collection_archive---------1-----------------------#2023-06-13

一种直观解释 NT-Xent 损失函数的方法,详细解释其操作,并在 PyTorch 中进行了实现

Dhruv MataniTowards Data Science Dhruv Matani

·

关注 发表在Towards Data Science ·14 min read·Jun 13, 2023

--

Naresh Singh合作撰写。

NT-Xent 损失函数公式。来源:Papers with code (CC-BY-SA)

介绍

最近在 自监督学习 和 对比学习 方面的进展激发了机器学习(ML)领域的研究人员和从业者重新关注这一领域。

尤其是,SimCLR 论文提出了一个简单的对比学习视觉表示框架,在自监督和对比学习领域获得了大量关注。

论文的核心思想非常简单——允许模型学习一对图像是否来自相同或不同的初始图像。

图 1:SimCLR 的高层次思路。来源:SimCLR 论文

SimCLR 方法将每个输入图像 i 编码为特征向量 zi。需要考虑两种情况:

  1. 正对:相同图像使用不同的增强集合进行增强,结果特征向量 zizj 进行比较。这些特征向量通过损失函数被强制保持相似。

  2. 负对:不同的图像使用不同的增强集合进行增强,结果特征向量 zizk 进行比较。这些特征向量通过损失函数被强制保持不相似。

本文其余部分将集中于解释和理解该损失函数及其使用 PyTorch 的高效实现。

NT-Xent 损失

从高层次看,对比学习模型接收 2N 张图像,来源于 N 个基础图像。每个 N 个基础图像都使用随机的图像增强集合进行增强,生成 2 张增强图像。这就是我们在单个训练批次中获得 2N 张图像的方式。

图 2:对比学习中的单个训练批次中的 6 张图像。每张图像下方的数字是该图像在输入批次中的索引,输入到对比学习模型中。图像来源:牛津视觉几何组(CC-SA)。

在接下来的章节中,我们将深入探讨 NT-Xent 损失的以下方面。

  1. 温度对 SoftMax 和 Sigmoid 的影响

  2. NT-Xent 损失的简单直观解释

  3. PyTorch 中 NT-Xent 的逐步实现

  4. 激发多标签损失函数需求(NT-BXent)

  5. PyTorch 中 NT-BXent 的逐步实现

步骤 2-5 的所有代码可以在 这个笔记本 中找到。步骤 1 的代码可以在 这个笔记本 中找到。

温度对 SoftMax 和 Sigmoid 的影响

为了理解本文中要研究的对比损失函数的所有活动部分,我们需要首先了解温度对 SoftMax 和 Sigmoid 激活函数的影响。

通常,温度缩放应用于 SoftMax 或 Sigmoid 的输入,以平滑或突出这些激活函数的输出。在传递到激活函数之前,输入 logits 被温度除以。你可以在这个笔记本中找到所有相关代码。

SoftMax:对于 SoftMax,高温度会降低输出分布的方差,从而使标签变得更加柔和。低温度则会增加输出分布的方差,使最大值相对于其他值更加突出。请参见下面的图表,了解输入张量[0.1081, 0.4376, 0.7697, 0.1929, 0.3626, 2.8451]的温度对 SoftMax 的影响。

图 3:温度对 SoftMax 的影响。来源:作者

Sigmoid:对于 Sigmoid,高温度会导致输出分布向 0.0 拉伸,而低温度则将输入扩展到更高的值,使输出更接近 0.0 或 1.0,具体取决于输入的未签名幅度。

图 4:温度对 Sigmoid 的影响。来源:作者

现在我们理解了不同温度值对 SoftMax 和 Sigmoid 函数的影响,让我们看看这些知识如何应用于理解 NT-Xent 损失。

解读 NT-Xent 损失

NT-Xent 损失通过理解损失名称中的各个术语来进行理解。

  1. 标准化:余弦相似度产生范围在[-1.0 到+1.0]之间的标准化分数

  2. 温度缩放:所有对的余弦相似度在计算交叉熵损失之前被温度缩放

  3. 交叉熵损失:底层损失是一个多类别(单标签)交叉熵损失

如上所述,我们假设对于大小为 2N 的批次,以下索引处的特征向量代表正对(0, 1)、(2, 3)、(4, 5)、(6, 7)等,其余组合代表负对。在解释 NT-Xent 损失时,这一点是与 SimCLR 相关的重要因素。

现在我们了解了 NT-Xent 损失的术语在上下文中的含义,让我们来看看计算特征向量批次上 NT-Xent 损失所需的机械步骤。

  1. 所有对的余弦相似度分数是针对 SimCLR 模型生成的每个 2N 向量计算的。这导致了(2N)²的相似度分数,表示为一个 2N x 2N 矩阵

  2. 相同值 (i, i) 之间的比较结果会被丢弃(因为一个分布与自身完全相似,不能让模型学到任何有用的东西)

  3. 每个值(余弦相似度)都由温度参数 𝜏(这是一个超参数)进行缩放

  4. 交叉熵损失应用于上述结果矩阵的每一行。以下段落将详细解释

  5. 通常,这些损失的均值(每批次一个损失)用于反向传播

这里使用交叉熵损失的方式在语义上与标准分类任务中的使用方式略有不同。在分类任务中,训练一个最终的“分类头”来为每个输入产生一个独热概率向量,我们在这个独热概率向量上计算交叉熵损失,因为我们实际上是在计算两个分布之间的差异。这个视频 美丽地解释了交叉熵损失的概念。在 NT-Xent 损失中,训练层和输出分布之间没有一一对应的关系。相反,每个输入都计算一个特征向量,然后计算每对特征向量之间的余弦相似度。这里的诀窍是,由于每张图片与输入批次中的恰好 1 张其他图片相似(正样本对)(如果我们忽略特征向量与自身的相似度),我们可以将其视为一种类似分类的设置,其中图像之间相似度概率的概率分布表示了一个分类任务,其中一个值接近 1.0,其余值接近 0.0。

既然我们对 NT-Xent 损失有了充分的理解,我们应该能很好地将这些思想实现到 PyTorch 中。我们开始吧!

NT-Xent 损失在 PyTorch 中的实现

本节中的所有代码可以在这个笔记本中找到。

代码重用:许多NT-Xent 损失的实现从头开始实现所有操作。此外,其中一些实现损失函数的方式效率不高,更喜欢使用for 循环而非 GPU 并行。相反,我们将使用不同的方法。我们将通过 PyTorch 已经提供的标准交叉熵损失来实现这种损失。为此,我们需要将预测和真实标签转换为交叉熵可以接受的格式。下面我们来看一下如何实现。

预测张量:首先,我们需要创建一个 PyTorch 张量,它将表示我们对比学习模型的输出。假设我们的批量大小是 8(2N=8),并且我们的特征向量有 2 个维度(2 个值)。我们将输入变量称为 “x”

x = torch.randn(8, 2)

余弦相似度:接下来,我们将计算此批次中每个特征向量之间的所有对的余弦相似度,并将结果存储在名为 “xcs” 的变量中。如果下面的代码看起来令人困惑,请阅读这个页面上的详细信息。这是“标准化”步骤。

xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1)

如上所述,我们需要忽略每个特征向量的自相似度分数,因为它不对模型的学习做出贡献,并且在我们想要计算交叉熵损失时会成为不必要的麻烦。为此,我们将定义一个变量 “eye”,这是一个矩阵,其中主对角线上的元素值为 1.0,其余元素值为 0.0。我们可以使用以下命令创建这样的矩阵。

eye = torch.eye(8)

现在让我们将其转换为布尔矩阵,以便可以使用这个掩码矩阵在 “xcs” 变量中进行索引。

eye = eye.bool()

让我们将张量 “xcs” 克隆到一个名为 “y” 的张量中,以便以后可以引用“xcs”张量。

y = xcs.clone()

现在,我们将所有对的余弦相似度矩阵的主对角线上的值设置为 -inf,这样当我们对每一行计算 softmax 时,这个值将不会产生任何贡献。

y[eye] = float("-inf")

张量 “y” 通过温度参数缩放后,将成为 PyTorch 中的交叉熵损失 API 的输入之一。接下来,我们需要计算要传递给交叉熵损失 API 的真实标签(目标)。

真实标签(目标张量):对于我们使用的示例(2N=8),这就是真实标签张量的样子。

tensor([1, 0, 3, 2, 5, 4, 7, 6])

这是因为张量 “y” 中的以下索引对包含正对。

(0, 1), (1, 0)

(2, 3), (3, 2)

(4, 5), (5, 4)

(6, 7), (7, 6)

要解释上述索引对,我们来看一个单一的例子。对 (4, 5) 来说,这意味着第 4 行第 5 列应该设置为 1.0(正对),这也是上述张量所表示的。太好了!

要创建上述张量,我们可以使用以下 PyTorch 代码,该代码将真实标签存储在变量 “target” 中。

target = torch.arange(8)
target[0::2] += 1
target[1::2] -= 1

交叉熵损失:我们已经具备了计算损失所需的所有成分!剩下的唯一任务就是调用 PyTorch 中的 cross_entropy API。

loss = F.cross_entropy(y / temperature, target, reduction="mean")

变量 “loss” 现在包含了计算出的 NT-Xent 损失。让我们把所有的代码封装到一个 Python 函数中。

def nt_xent_loss(x, temperature):
  assert len(x.size()) == 2

  # Cosine similarity
  xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1)
  xcs[torch.eye(x.size(0)).bool()] = float("-inf")

  # Ground truth labels
  target = torch.arange(8)
  target[0::2] += 1
  target[1::2] -= 1

  # Standard cross-entropy loss
  return F.cross_entropy(xcs / temperature, target, reduction="mean")

上述代码有效,只要每个特征向量在训练对比学习模型时批次中恰好有一个正对。让我们来看一下如何在对比学习任务中处理多个正对。

用于对比学习的多标签损失:NT-BXent

在 SimCLR 论文中,每个图像i在索引j处有恰好 1 个相似对。这使得交叉熵损失成为任务的完美选择,因为它类似于多类别问题。相反,如果我们将 M > 2 个相同图像的增强输入到对比学习模型的单个训练批次中,那么每个批次将包含图像i的 M-1 个相似对。这将使任务类似于多标签问题。

显而易见的选择是将交叉熵损失替换为二元交叉熵损失。因此命名为 NT-BXent 损失,代表归一化温度缩放的二元交叉熵损失。

下面的公式展示了元素i的损失Li。公式中的σ表示S 型函数

图 5:NT-BXent 损失的公式。图像来源:本文作者

为了避免类别不平衡问题,我们通过我们小批量中正负对的数量的倒数来加权正负对。在用于反向传播的小批量中的最终损失将是小批量中每个样本损失的平均值。

接下来,让我们将注意力集中在我们在 PyTorch 中对 NT-BXent 损失的实现上。

在 PyTorch 中实现 NT-BXent 损失

本节中的所有代码可以在这个笔记本中找到。

代码重用:类似于我们对 NT-Xent 损失的实现,我们将重用 PyTorch 提供的二元交叉熵(BCE)损失方法。我们的真实标签设置将类似于使用 BCE 损失的多标签分类问题。

预测张量:我们将使用与 NT-Xent 损失实现中相同的(8, 2)预测张量。

x = torch.randn(8, 2)

余弦相似度:由于输入张量x相同,所有对的余弦相似度张量xcs也将相同。有关下面这行代码的详细解释,请参见这一页

xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1)

为了确保位置(i, i)处的损失为0,我们需要进行一些操作,使得在对xcs张量应用 Sigmoid 后,它在每个索引(i, i)处的值为1。由于我们将使用 BCE 损失,我们会将每个特征向量的自相似性分数标记为张量xcs中的值为无穷大。这是因为在xcs张量上应用 Sigmoid 函数将无穷大转换为值1,我们将设置我们的真实标签,使得真实标签中的每个位置(i, i)的值为1

创建一个掩码张量,该张量在主对角线上具有值Truexcs在主对角线上具有自相似性分数),而其他地方为False

eye = torch.eye(8).bool()

将张量“xcs”克隆到一个名为“y”的张量中,以便我们可以稍后引用“xcs”张量。

y = xcs.clone()

现在,我们将所有对的余弦相似度矩阵的主对角线上的值设置为无穷大,以便在对每一行计算 Sigmoid 时,这些位置的值为 1。

y[eye] = float("inf")

张量“y”由温度参数缩放后,将作为 PyTorch 中BCE 损失 API的输入(预测)之一。接下来,我们需要计算要提供给 BCE 损失 API 的真实标签(目标)。

真实标签(目标张量):我们期望用户传递给我们包含正例的所有(x, y)索引对。这与我们对 NT-Xent 损失所做的有所不同,因为正对是隐式的,而这里正对是显式的。

除了用户提供的位置外,我们还将所有对角线元素设置为正对,如上所述。我们将使用 PyTorch 张量索引 API 提取这些位置的所有元素并将其设置为 1,而其他元素初始化为 0。

target = torch.zeros(8, 8)
pos_indices = torch.tensor([
  (0, 0), (0, 2), (0, 4),
  (1, 4), (1, 6), (1, 1),
  (2, 3),
  (3, 7),
  (4, 3),
  (7, 6),
])
# Add indexes of the principal diagonal as positive indexes.
# This will be useful since we will use the BCELoss in PyTorch,
# which will expect a value for the elements on the principal
# diagonal as well.
pos_indices = torch.cat([pos_indices, torch.arange(8).reshape(8, 1).expand(-1, 2)], dim=0)
# Set the values in the target vector to 1.
target[pos_indices[:,0], pos_indices[:,1]] = 1

二元交叉熵(BCE)损失:与 NT-Xent 损失不同,我们不能简单调用torch.nn.functional.binary_cross_entropy_function,因为我们需要根据当前小批量中索引 i 处的正负对数目来加权正负损失。

不过第一步是计算逐元素的 BCE 损失。

temperature = 0.1
loss = F.binary_cross_entropy((y / temperature).sigmoid(), target, reduction="none")

我们将创建一个正负对的二进制掩码,然后创建两个张量,loss_pos 和 loss_neg,只包含计算损失中对应于正对和负对的元素。

target_pos = target.bool()
target_neg = ~target_pos
# loss_pos and loss_neg below contain non-zero values only for those elements
# that are positive pairs and negative pairs respectively.
loss_pos = torch.zeros(x.size(0), x.size(0)).masked_scatter(target_pos, loss[target_pos])
loss_neg = torch.zeros(x.size(0), x.size(0)).masked_scatter(target_neg, loss[target_neg])

接下来,我们将分别对每个小批量中的元素 i 的正负对损失进行求和。

# loss_pos and loss_neg now contain the sum of positive and negative pair losses
# as computed relative to the i'th input.
loss_pos = loss_pos.sum(dim=1)
loss_neg = loss_neg.sum(dim=1)

为了进行加权,我们需要跟踪每个小批量中每个元素 i 对应的正负对的数量。张量“num_pos”“num_neg”将存储这些值。

# num_pos and num_neg below contain the number of positive and negative pairs
# computed relative to the i'th input. In an actual setting, this number should
# be the same for every input element, but we let it vary here for maximum
# flexibility.
num_pos = target.sum(dim=1)
num_neg = target.size(0) - num_pos

我们已经具备了计算损失所需的所有要素!我们唯一需要做的就是按正负对的数量对正负损失进行加权,然后在小批量中计算损失的平均值。

def nt_bxent_loss(x, pos_indices, temperature):
    assert len(x.size()) == 2

    # Add indexes of the principal diagonal elements to pos_indices
    pos_indices = torch.cat([
        pos_indices,
        torch.arange(x.size(0)).reshape(x.size(0), 1).expand(-1, 2),
    ], dim=0)

    # Ground truth labels
    target = torch.zeros(x.size(0), x.size(0))
    target[pos_indices[:,0], pos_indices[:,1]] = 1.0

    # Cosine similarity
    xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1)
    # Set logit of diagonal element to "inf" signifying complete
    # correlation. sigmoid(inf) = 1.0 so this will work out nicely
    # when computing the Binary cross-entropy Loss.
    xcs[torch.eye(x.size(0)).bool()] = float("inf")

    # Standard binary cross-entropy loss. We use binary_cross_entropy() here and not
    # binary_cross_entropy_with_logits() because of
    # https://github.com/pytorch/pytorch/issues/102894
    # The method *_with_logits() uses the log-sum-exp-trick, which causes inf and -inf values
    # to result in a NaN result.
    loss = F.binary_cross_entropy((xcs / temperature).sigmoid(), target, reduction="none")

    target_pos = target.bool()
    target_neg = ~target_pos

    loss_pos = torch.zeros(x.size(0), x.size(0)).masked_scatter(target_pos, loss[target_pos])
    loss_neg = torch.zeros(x.size(0), x.size(0)).masked_scatter(target_neg, loss[target_neg])
    loss_pos = loss_pos.sum(dim=1)
    loss_neg = loss_neg.sum(dim=1)
    num_pos = target.sum(dim=1)
    num_neg = x.size(0) - num_pos

    return ((loss_pos / num_pos) + (loss_neg / num_neg)).mean()

pos_indices = torch.tensor([
    (0, 0), (0, 2), (0, 4),
    (1, 4), (1, 6), (1, 1),
    (2, 3),
    (3, 7),
    (4, 3),
    (7, 6),
])
for t in (0.01, 0.1, 1.0, 10.0, 20.0):
    print(f"Temperature: {t:5.2f}, Loss: {nt_bxent_loss(x, pos_indices, temperature=t)}")

打印。

温度:0.01,损失:62.898780822753906

温度:0.10,损失:4.851151943206787

温度:1.00,损失:1.0727109909057617

温度:10.00,损失:0.9827173948287964

温度:20.00,损失:0.982099175453186

结论

自监督学习是深度学习中的一个新兴领域,它允许在未标记的数据上训练模型。这项技术让我们绕过了大规模标记数据的需求。

在这篇文章中,我们了解了对比学习的损失函数。第一个,称为 NT-Xent 损失,用于对每个输入在小批量中学习单个正对。我们介绍了 NT-BXent 损失,该损失用于在小批量中对每个输入学习多个(> 1)正对。我们学会了直观地解释这些损失,基于我们对交叉熵损失和二元交叉熵损失的理解。最后,我们在 PyTorch 中高效地实现了这两种损失函数。

NumPy 广播

原文:towardsdatascience.com/numpy-broadcasting-4c4cb9dff1e7?source=collection_archive---------20-----------------------#2023-01-04

定义、规则和示例

Pan CretanTowards Data Science Pan Cretan

·

关注 发表在 Towards Data Science ·9 分钟阅读·2023 年 1 月 4 日

--

摄影:Jean-Guy Nakars 摄于 Unsplash

介绍

NumPy 提供了通过矢量化进行快速计算的方法,这避免了使用较慢的 Python 循环。矢量化在使用二元 ufuncs(如加法或乘法)时也可用,附加的好处是数组不需要具有相同的形状。具有不同形状的数组之间的操作称为 广播,这可能会特别令人困惑,特别是对于多维数组,或当两个数组都需要扩展时。

有许多示例和教程,但我发现通过思考并实际记住广播规则来处理问题最有用。这样更容易考虑任何给定的使用案例,并编写代码,而无需依赖试错法。

广播规则

我强烈推荐两本数据分析和数据科学的书籍。它们都有关于广播的小节。

数据分析的 Python 由 Wes McKinney 编写,包含以下广播规则:

如果对于每个尾部维度(即从末尾开始)轴长度匹配,或者任一长度为 1,则两个数组可以进行广播。广播将对缺失或长度为 1 的维度进行。

Python 数据科学手册 由 Jake VanderPlas 编写,包含更详细的广播规则:

规则 1:如果两个数组在维度数量上不同,维度较少的那个数组的形状会在其前面(左侧)用 1 进行填充

规则 2:如果两个数组的形状在任何维度上不匹配,则形状为 1 的数组会被扩展以匹配另一个形状。

规则 3:如果在任何维度上大小不一致且都不等于 1,将会引发错误。

我发现第二组规则更容易遵循。下面的示例将使用这些规则。

示例

可能最简单的广播示例之一,以及一个典型模式,是从每一列中减去列均值。进行此操作后,列均值将变为数值上等于零。

a = np.arange(12).reshape(4, 3)
means_columns = a.mean(axis=0)
res = a - means_columns
print('original array', a, sep='\n')
print('.. column means', a.mean(axis=0), sep='\n')
print('demeaned array', res, sep='\n')
print('.. column means', res.mean(axis=0), sep='\n')

这会打印

original array
[[ 0  1  2]
 [ 3  4  5]
 [ 6  7  8]
 [ 9 10 11]]
.. column means
[4.5 5.5 6.5]
demeaned array
[[-4.5 -4.5 -4.5]
 [-1.5 -1.5 -1.5]
 [ 1.5  1.5  1.5]
 [ 4.5  4.5  4.5]]
.. column means
[0\. 0\. 0.]

让我们看看这里发生了什么。计算 a.mean(axis=0) 的列均值会生成一个形状为 (3,) 的一维数组。涉及减法的两个数组在形状上有所不同,因此 means_columns 会根据规则 1 在左侧填充 1 以匹配形状。因此,在幕后,means_columns 会被重塑为 (1, 3)。然后根据规则 2,means_columns 会沿着轴 0 扩展,使其形状变为 (3, 3) 以匹配 a 的形状。

除了使用规则预测维度较少的数组如何被扩展外,我们还可以使用 [np.broadcast_to](https://numpy.org/doc/stable/reference/generated/numpy.broadcast_to.html),它返回一个只读视图,该视图具有给定的形状。该视图可能是不连续的,且不同元素可能指向相同的内存地址。

means_columns_bc = np.broadcast_to(means_columns, a.shape)
print(means_columns_bc)
print('base', means_columns_bc.base, sep='\n')
print('strides', means_columns_bc.strides, sep='\n')

这会打印

[[4.5 5.5 6.5]
 [4.5 5.5 6.5]
 [4.5 5.5 6.5]
 [4.5 5.5 6.5]]
base
[4.5 5.5 6.5]
strides
(0, 8)

我们可以看到基础是原始的均值数组(因此它是一个视图),而沿第一个轴的步幅是 0,这意味着同一列的不同元素指向相同的内存位置(有关 NumPy 内部的介绍,请参见这里)。NumPy 确实在尽可能优化内存使用!

如果我们想对行进行去均值操作怎么办?可以通过a.mean(axis=1)快速计算行的均值,该操作将返回一个形状为(4,)的数组。将其形状左侧填充 1,意味着数组将变成(1,4)。根据规则 3,这两个数组的最后维度不一致且都不是 1。这意味着广播不会发生。我们也可以预见到这一点,因为np.broadcast_to(a.mean(axis=1), a.shape)引发异常,告知我们广播无法生成请求的形状(4, 3)。形状的不兼容性还可以通过执行np.broadcast_shapes(a.shape, a.mean(axis=1).shape)来观察,这也引发异常,解释了形状不匹配。通过将行均值重塑为(4,1)数组来进行行去均值操作,可以使用a.mean(axis=1).reshape(-1, 1)a.mean(axis=1)[:, np.newaxis]

means_rows = a.mean(axis=1)
res = a - means_rows.reshape(-1, 1) # or res = a - means_rows[:, np.newaxis]
print('original array', a, sep='\n')
print('.. row means', a.mean(axis=1), sep='\n')
print('demeaned array', res, sep='\n')
print('.. row means', res.mean(axis=1), sep='\n')

这将打印

original array
[[ 0  1  2]
 [ 3  4  5]
 [ 6  7  8]
 [ 9 10 11]]
.. row means
[ 1\.  4\.  7\. 10.]
demeaned array
[[-1\.  0\.  1.]
 [-1\.  0\.  1.]
 [-1\.  0\.  1.]
 [-1\.  0\.  1.]]
.. row means
[0\. 0\. 0\. 0.]

规则 2 清楚地解释了为什么这有效,因为形状为(4,1)的数组可以在列中扩展,使其形状变成(4,3)。

在第三个示例中,我们将演示广播如何在 ufunc 二元函数中扩展两个数组

a = np.arange(4)
b = np.arange(3)
res = a[:, np.newaxis] + b[np.newaxis, :]
print('result array', res, sep='\n')

这将得到

result array
[[0 1 2]
 [1 2 3]
 [2 3 4]
 [3 4 5]]

严格来说,重塑b并不是必要的,但它使事情更清晰。我们还可以使用[np.broadcast_arrays](https://numpy.org/doc/stable/reference/generated/numpy.broadcast_arrays.html)或相关且更灵活的[np.broadcast](https://numpy.org/doc/stable/reference/generated/numpy.broadcast.html#numpy.broadcast)来广播两个数组,而不应用 ufunc。为了完整起见,还有其他方法可以实现与广播相同的结果,一个例子是np.add.outer(a, b),它产生与

np.array_equal(np.add.outer(a, b), a[:, np.newaxis] + b)

返回 True。

对任何轴上的高维数组进行去均值操作可以被推广为

def demean_axis(arr, axis=1):
    means = arr.mean(axis)
    indexer = [slice(None)]*arr.ndim
    indexer[axis] = np.newaxis
    return arr - means[tuple(indexer)]

arr = np.linspace(1, 12, 24*3).reshape(6,4,3)
res = demean_axis(arr, axis=1)

我们可以确认np.abs(res.mean(axis=1)).max()在数值上等于零。上述函数取自 Wes McKinney 的书籍,但需要稍微修改以适配本文使用的 NumPy 版本(1.23.4)。

作为一个更实际的例子,我们可以使用广播将彩色图像转换为灰度图像。广播部分已用注释标出:

import matplotlib
matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
import io
import numpy as np
from PIL import Image
# read in the original image (png)
with open('landscape_water_lake_nature_trees.png', mode='rb') as f:
    image_orig = f.read()
f = io.BytesIO(image_orig)
im = Image.open(f)
image_orig = np.array(im)/255.
print('shape of original image', image_orig.shape, sep='\n')

# convert RGBA to RGB (pillow could be used for this)
background = (1., 1., 1.)
row = image_orig.shape[0]
col = image_orig.shape[1]
image_color = np.zeros( (row, col, 3), dtype='float32' )
r, g, b, a = image_orig[:,:,0], image_orig[:,:,1], image_orig[:,:,2], image_orig[:,:,3]
a = np.asarray( a, dtype='float32' )
R, G, B = background
image_color[:,:,0] = r * a + (1.0 - a) * R
image_color[:,:,1] = g * a + (1.0 - a) * G
image_color[:,:,2] = b * a + (1.0 - a) * B
print('shape of image after RGBA to RGB conversion', image_color.shape, sep='\n')

# convert to greyscale
conv = np.array([0.2126, 0.7152, 0.0722])
# --- broadcasting !!! ---
image_grey = (image_color[:,:,:3]*conv).sum(axis=2)
# --- broadcasting !!! ---
print('shape of image after conversion to greyscale', image_grey.shape, sep='\n')

# plot the image
fig = plt.figure(figsize=(8, 4))
axs = fig.subplots(1, 2)
axs[0].axis('off')
axs[0].set_title('RGB image')
axs[0].imshow(image_color)
axs[1].axis('off')
axs[1].set_title('greyscale image')
axs[1].imshow(image_grey, cmap='gray')
axs[0].annotate('',xy=(0.52,0.5),xytext=(0.50,0.5),arrowprops=dict(facecolor='black'),
             xycoords='figure fraction', textcoords='figure fraction')
fig.savefig('RGBA_to_greyscale.png')

上述代码生成

将透明 PNG 图像转换为灰度图像(照片由nextvoyagepixabay提供)

也许值得关注的是,原图是一个 RGBA 图像(红色、绿色、蓝色、透明度),即它还包含第四个 alpha 通道来指示每个像素的透明度。我们通过使用白色背景将 RGBA 图像转换为 RGB。这以及示例中的几乎所有内容,都可以使用Pillow来完成,但我们故意尽可能使用 NumPy(例如,Pillow 可以转换为灰度图像以保持透明度)。代码的最后部分是一些Matplotlib的花招,用于比较彩色图像和灰度图像。

转换为灰度图像是通过以下公式完成的。

来自维基百科的this page。执行乘法和求和的代码部分,将图像数组的形状从(1080,1920,3)更改为(1080,1920)的部分已被注释掉。你可能会想,为了一行广播代码展示如此长的示例是否值得。这正是要点!广播是简洁的,如果没有它,代码会更长且更慢。通常,算法背后的魔法是几行 NumPy 操作,通常包括广播。

设置值

大多数 NumPy 用户将广播与数组加法或乘法相关联。然而,当使用索引设置值时,广播同样适用。下面你可以找到一组示例,我们将不对其详细评论。

# set one row, same value to all columns
a = np.ones((4,3))
a[0] = -1
# array([[-1., -1., -1.],
#        [ 1.,  1.,  1.],
#        [ 1.,  1.,  1.],
#        [ 1.,  1.,  1.]])

# set one column, same value to all rows
a = np.ones((4,3))
a[:, 0] = -1
# array([[-1.,  1.,  1.],
#        [-1.,  1.,  1.],
#        [-1.,  1.,  1.],
#        [-1.,  1.,  1.]])

# set all rows, same value to all elements
a = np.ones((4,3))
a[:] = -1
# array([[-1., -1., -1.],
#        [-1., -1., -1.],
#        [-1., -1., -1.],
#        [-1., -1., -1.]])

# set all rows, different value to each column
a = np.ones((4,3))
b = np.array([-1, -2, -3])
a[:] = b
# array([[-1., -2., -3.],
#        [-1., -2., -3.],
#        [-1., -2., -3.],
#        [-1., -2., -3.]])

# set all rows, different value to each row
a = np.ones((4,3))
b = np.array([-1, -2, -3, -4])
b = b[:, np.newaxis]
a[:] = b
# array([[-1., -1., -1.],
#        [-2., -2., -2.],
#        [-3., -3., -3.],
#        [-4., -4., -4.]])

# set some rows, different value to each column
a = np.ones((4,3))
b = np.array([-1, -2, -3])
a[:3] = b
# array([[-1., -2., -3.],
#        [-1., -2., -3.],
#        [-1., -2., -3.],
#        [ 1.,  1.,  1.]])

# set some rows, different value to each row
a = np.ones((4,3))
b = np.array([-1, -2, -3])
a[:3] = b[:, np.newaxis]
# array([[-1., -1., -1.],
#        [-2., -2., -2.],
#        [-3., -3., -3.],
#        [ 1.,  1.,  1.]])

# set some columns, different value to each column
a = np.ones((4,3))
b = np.array([-1, -2])
a[:,:2] = b
# array([[-1., -2.,  1.],
#        [-1., -2.,  1.],
#        [-1., -2.,  1.],
#        [-1., -2.,  1.]])

# set some columns, different value to each row
a = np.ones((4,3))
b = np.array([-1, -2, -3, -4])
a[:,:2] = b[:, np.newaxis]
# array([[-1., -1.,  1.],
#        [-2., -2.,  1.],
#        [-3., -3.,  1.],
#        [-4., -4.,  1.]])

结论

广播可能看起来很复杂,但如果记住几个关键原则,它可以很容易掌握。最重要的原则是数组的形状从右开始对齐。缺失的维度用一填充,始终在左侧。通过扩展维度为一的形状,这两个形状变得相同。在两个(或更多!)数组可以进行广播之前,可能需要使用np.newaxis进行一些重新形状调整。广播不仅用于计算整个数组,还用于设置数组中的某些值。这就是要点。通过一些练习,NumPy 的广播可以导致令人惊讶的简洁和高效的代码。充分利用它的潜力!

探究字符级 RNN:基于 NumPy 的实现指南

原文:towardsdatascience.com/numpy-character-level-rnn-af1428bb10a8?source=collection_archive---------5-----------------------#2023-01-27

由于最近 LLMs 蓬勃发展,掌握语言建模的基础知识至关重要

Joe Sasson](https://sassonjoe66.medium.com/?source=post_page-----af1428bb10a8--------------------------------)Towards Data Science](https://towardsdatascience.com/?source=post_page-----af1428bb10a8--------------------------------) Joe Sasson

·

关注 发表在 Towards Data Science ·17 min read·Jan 27, 2023

--

图片来自 Markus Spiske on Unsplash

简介

循环神经网络(RNNs)是一种强大的神经网络类型,能够处理序列数据,如时间序列或自然语言。本文将通过使用 NumPy 从零开始构建一个 Vanilla RNN 的过程。我们将首先讨论 RNN 的理论和直觉,包括它们的架构和适合解决的问题类型。接下来,我们将深入代码,解释 RNN 的各个组件及其相互作用。最后,我们将通过将 RNN 应用于实际数据集来展示其有效性。

具体来说,我们将实现一个多对多的字符级 RNN,使用序列化的在线学习。这意味着网络一次处理一个字符的输入序列,并在每个字符后更新网络参数。这允许网络在实时学习,并随着遇到的新模式而适应数据。

字符级 RNN 意味着输入和输出是单个字符,而不是单词或句子。这使得网络能够学习文本中字符之间的潜在模式和依赖关系。多对多架构指的是网络接收一个字符序列作为输入,并生成一个字符序列作为输出。这与多对一架构不同,在多对一架构中,网络接收一个输入序列并生成一个输出,或者与一对多架构不同,在一对多架构中,网络接收一个输入并生成一个输出序列。

我使用了 Andrej Karpathy 的代码(见这里)作为我的实现基础,并做了几处修改以提高通用性和可靠性。我扩展了代码以支持多个层次,并重新构建了它以提高可读性和重用性。这个项目建立在我之前使用 NumPy 创建简单 ANN 的工作基础上。相关源代码可以在这里找到。

理论与直觉

RNNs 可以与传统的前馈神经网络(ANNs)对比,后者没有“记忆”机制,并且独立处理每个输入。ANNs 适用于输入和输出具有固定大小且输入不包含序列依赖关系的问题。相比之下,RNNs 能够处理变长的输入序列,并通过隐藏状态保持对过去输入的“记忆”。

隐藏状态使 RNN 能够捕捉时间依赖关系,并根据整个输入序列做出预测。总的来说,网络使用来自先前时间步的信息来指导其对当前输入的处理。此外,更复杂的 NLP 架构可以处理长期依赖关系(GPT-3 使用了 2048 的序列长度进行训练),其中输入序列开头的信息对于预测序列末尾的输出仍然相关。这种保持“记忆”的能力使 RNN 和变换器在处理序列数据时相较于 ANNs 具有显著优势。

最近,像GPT-3BERT这样的变换器架构在各种 NLP 任务中变得越来越流行。这些架构基于自注意力机制,使网络能够有选择地关注输入序列的不同部分。这使得网络能够捕捉长期依赖关系,而无需递归,从而比 RNN 更高效且更易于训练。变换器架构已在各种 NLP 任务中实现了最先进的结果,并被用于许多实际应用中。

尽管变换器架构比普通 RNN 更复杂且具有不同的特性,但普通 RNN 在深度学习领域仍然发挥着重要作用。它们易于理解,容易实现和调试,并可以作为其他更复杂架构的构建块。在本文中,我们将重点关注普通 RNN,并窥探其真正的工作原理。

普通 RNN 的三种主要类型是:

  • 一对多: 输入一张狗的图片并输出‘狗的图片’

  • 多对一: 输入一个句子并接收情感(情感分析)

  • 多对多: 输入一个句子并输出完整句子(见下文)

我们将实现如下所示的多对多架构。

来源: Kaivan Kamali,《深度学习(第二部分)——递归神经网络(RNN)》(银河培训材料)。training.galaxyproject.org/training-material/topics/statistics/tutorials/RNN/tutorial.html

从这一点开始,我们将使用h[t]表示时间步 t 的隐藏状态。在图中,这被表示为s[t]

如你所见,来自上一个时间步的隐藏状态h[t-1]与当前输入x[t]结合,这一过程在时间步数上重复。在 RNN 块内,我们正在更新当前时间步的隐藏状态。

为了澄清,时间步只是一个字符,如 ‘a’ 或 ‘d’。输入序列包含一个可变数量的字符或时间步,也称为序列长度,这是网络的一个超参数。

代码

目录

  1. 准备数据

  2. RNN 类

  3. 前向传播

  4. 反向传播

  5. 优化器

  6. 训练

准备数据

## start with data
data = open('path-to-data', 'r').read() # should be simple plain text file

chars = list(set(data))
data_size, vocab_size = len(data), len(chars)

print('data has {} characters, {} unique.'.format(data_size, vocab_size))

char_to_idx = { ch:i for i,ch in enumerate(chars) }
idx_to_char = { i:ch for i,ch in enumerate(chars) }

我们从一个纯文本文件中读取数据作为字符串,并对字符进行标记化。每个唯一字符(共 65 个)将映射到一个整数,反之亦然。

让我们为 RNN 采样一个输入和目标序列。

pointer, seq_length = 0, 8

x = [char_to_idx[ch] for ch in data[pointer:pointer+seq_length]]

y = [char_to_idx[ch] for ch in data[pointer+1:pointer+seq_length+1]]

print(x)
>> [2, 54, 53, 62, 13, 28, 20, 54] # our RNN input sequence

print(y)
>> [54, 53, 62, 13, 28, 20, 54, 13] # our RNN target sequence

for t in range(seq_length):
    context = x[:t+1]
    target = y[t]
    print(f"when input is {context} the target: {target}")

>>  when input is [2] the target: 54
    when input is [2, 54] the target: 53
    when input is [2, 54, 53] the target: 62
    when input is [2, 54, 53, 62] the target: 13
    when input is [2, 54, 53, 62, 13] the target: 28
    when input is [2, 54, 53, 62, 13, 28] the target: 20
    when input is [2, 54, 53, 62, 13, 28, 20] the target: 54
    when input is [2, 54, 53, 62, 13, 28, 20, 54] the target: 13

输入是一个标记化的序列,目标是输入偏移一个单位后的值。

RNN 类

class RNN:
    def __init__(self, hidden_size, vocab_size, seq_length, num_layers):
        pass 

    def __call__(self, *args: Any, **kwds: Any):
        """RNN Forward Pass"""

        pass 

    def backward(self, targets, cache):
        """RNN Backward Pass"""

        pass

    def update(self, grads, lr):
        """Perform Parameter Update w/ Adagrad"""

        pass

    def predict(self, hprev, seed_ix, n):
        """
        Make predictions using the trained RNN model.

        Parameters:
        hprev (numpy array): The previous hidden state.
        seed_ix (int): The seed letter index to start the prediction with.
        n (int): The number of characters to generate for the prediction.

        Returns:
        ixes (list): The list of predicted character indices.
        hs (numpy array): The final hidden state after making the predictions.
        """

        pass 

让我们开始讨论 RNN 的组件,与基本 ANN 相比。

在传统的前馈神经网络中,控制层间交互的参数由一个单一的权重矩阵表示,记作 W。然而,在递归神经网络(RNN)中,层间交互由多个矩阵表示。在我的代码中,这些矩阵特别是:WxhWhhWhy,分别表示输入层与隐藏层、隐藏层与隐藏层、以及隐藏层与输出层之间的权重。

Wxh 矩阵将输入层连接到隐藏层,并用于将每个时间步的输入转换为隐藏层的激活值。Whh 矩阵将时间步 t-1 的隐藏层连接到时间步 t 的隐藏层,并用于将隐藏状态从一个时间步传播到下一个时间步。Why 矩阵将隐藏层连接到输出层,并用于将隐藏状态转换为网络的最终输出。

总之,ANN 和 RNN 之间的主要区别在于 ANN 有一个权重矩阵,而 RNN 有多个权重矩阵,这些矩阵用于转换输入、传播隐藏状态并生成最终输出。这些多个权重矩阵使 RNN 能够保持对过去输入的记忆,并在时间上移动信息。

构造函数

def __init__(self, hidden_size, vocab_size, seq_length, num_layers):
    self.name = 'RNN'
    self.hidden_size = hidden_size
    self.vocab_size = vocab_size
    self.num_layers = num_layers

    # model parameters
    self.Wxh = [np.random.randn(hidden_size, vocab_size)*0.01 for _ in range(num_layers)] # input to hidden
    self.Whh = [np.random.randn(hidden_size, hidden_size)*0.01 for _ in range(num_layers)] # hidden to hidden
    self.Why = np.random.randn(vocab_size, hidden_size)*0.01 # hidden to output
    self.bh = [np.zeros((hidden_size, 1)) for _ in range(num_layers)] # hidden bias
    self.by = np.zeros((vocab_size, 1)) # output bias

    # memory variables for training (ada grad from karpathy's github)
    self.iteration, self.pointer = 0, 0
    self.mWxh = [np.zeros_like(w) for w in self.Wxh]
    self.mWhh = [np.zeros_like(w) for w in self.Whh] 
    self.mWhy = np.zeros_like(self.Why)
    self.mbh, self.mby = [np.zeros_like(b) for b in self.bh], np.zeros_like(self.by)
    self.loss = -np.log(1.0/vocab_size)*seq_length # loss at iteration 0

    self.running_loss = []

在这里,我们定义了如上所述的 RNN 参数。值得注意的是——参数 Whyby 表示一个线性层,甚至可以进一步抽象为一个独立的类,如 PyTorch 的 ‘nn.Linear’ 模块。然而,在此实现中,我们将它们作为 RNN 类的一部分保留。

前向传播

def __call__(self, *args: Any, **kwds: Any) -> Any:
    """RNN Forward Pass"""

    x, y, hprev = kwds['inputs'], kwds['targets'], kwds['hprev']

    loss = 0
    xs, hs, ys, ps = {}, {}, {}, {} # inputs, hidden state, output, probabilities
    hs[-1] = np.copy(hprev)

    # forward pass
    for t in range(len(x)):
        xs[t] = np.zeros((self.vocab_size,1)) # encode in 1-of-k representation
        xs[t][x[t]] = 1
        hs[t] = np.copy(hprev)

        if kwds.get('dropout', False): # use dropout layer (mask)

            for l in range(self.num_layers):
                dropout_mask = (np.random.rand(*hs[t-1][l].shape) < (1-0.5)).astype(float)
                hs[t-1][l] *= dropout_mask
                hs[t][l] = np.tanh(np.dot(self.Wxh[l], xs[t]) + np.dot(self.Whh[l], hs[t-1][l]) + self.bh[l]) # hidden state
                hs[t][l] = hs[t][l] / (1 - 0.5)

        else: # no dropout layer (mask)

            for l in range(self.num_layers):
                hs[t][l] = np.tanh(np.dot(self.Wxh[l], xs[t]) + np.dot(self.Whh[l], hs[t-1][l]) + self.bh[l]) # hidden state

        ys[t] = np.dot(self.Why, hs[t][-1]) + self.by # unnormalized log probabilities for next chars
        ps[t] = np.exp(ys[t]) / np.sum(np.exp(ys[t])) # probabilities for next chars
        loss += -np.log(ps[t][y[t],0]) # softmax (cross-entropy loss)

    self.running_loss.append(loss)

    return loss, hs[len(x)-1], {'xs':xs, 'hs':hs, 'ps':ps}

从顶部开始,逐步解析。这儿发生了什么?

循环外,每个序列一次

  • hprev 是我们的初始隐藏状态

  • 我们正在初始化字典以保存我们的输入、隐藏状态、logits 和概率。我们将在反向传播过程中需要这些。

  • 我们将损失初始化为零。

  • 我们将初始隐藏状态 hprev 设置为 hs[-1](表示时间步 t-1)。

循环内,每个时间步

  • 对我们的输入序列进行独热编码

  • 在时间步‘t’和层‘l’更新隐藏状态。

我们隐藏状态的数学表示。图片由作者提供。

  • 此外,你可能注意到有一些功能用于执行dropout

Dropout 是一种正则化技术,旨在通过在训练过程中随机“丢弃”(置为零)某些神经元来防止过拟合。在上述代码中,dropout 层在更新时间步‘t’、层‘l’的隐藏状态之前应用,通过将时间步‘t-1’的隐藏状态与 dropout 掩码相乘来实现。dropout 掩码是通过创建一个二值掩码生成的,其中每个元素以概率 p 为 1,否则为 0。通过这样做,我们在隐藏状态中随机“丢弃”一定数量的神经元,这有助于防止网络对任何单一神经元过于依赖。这使得网络更为鲁棒,并且不容易在训练数据上过拟合。在应用 dropout 之后,通过将隐藏状态除以(1-p)来缩放隐藏状态,以确保隐藏状态的期望值得到保持。

  • ys[t]为当前时间步提供线性层的输出。

  • ps[t]为当前时间步提供最终的 softmax 输出(概率)。

由于只有一层线性层,而不是任意数量的 RNN 层,因此ys[t]ps[t]的计算位于第二个循环之外。

  • 最后,我们返回损失,并且hs[len(x)-1]被用作下一个序列的hprev。我们使用缓存来获取反向传播过程中的梯度。

选择使用索引[t][l]来存储在时间步‘t’处第‘l’层的隐藏状态。这是因为模型一次处理一个时间步的输入序列,并且在每个时间步,它更新每一层的隐藏状态。通过使用索引[t][l],我们能够跟踪每一层在每个时间步的隐藏状态,从而方便地执行前向传递所需的计算。

此外,这种索引允许轻松访问最后一个时间步的隐藏状态,该状态作为hs[len(x)-1]返回,因为它是每层序列中最后一个时间步的隐藏状态。这个返回的隐藏状态在训练过程中被用作下一个序列的初始隐藏状态。

让我们进行前向传递。请记住,没有批量维度。

# Initialize RNN
num_layers = 3
hidden_size = 100
seq_length = 8

rnn = RNN(hidden_size=hidden_size, vocab_size=vocab_size, seq_length=seq_length, num_layers=num_layers)

x = [char_to_idx[ch] for ch in data[rnn.pointer:rnn.pointer+seq_length]]

y = [char_to_idx[ch] for ch in data[rnn.pointer+1:rnn.pointer+seq_length+1]]

# initialize hidden state with zeros
hprev = [np.zeros((hidden_size, 1)) for _ in range(num_layers)] 

## Call RNN
loss, hprev, cache = rnn(inputs=x, targets=y, hprev=hprev)

print(loss)
>> 33.38852380987117

反向传播

首先,了解 RNN 反向传播的一些直观感受。

基本 ANN 和 RNN 在反向传播中的关键区别在于错误在网络中的传播方式。虽然 ANN 和 RNN 都将错误从输出层传播到输入层,但 RNN 还会通过时间反向传播错误,在每个时间步调整权重和偏置。这使得 RNN 能够处理序列数据,并以隐藏状态的形式维持“记忆”。

BPTT(时间反向传播)算法通过展开 RNN 并为每个时间步创建计算图来工作。这个网络的计算图可以在这里查看。

然后为每个时间步计算梯度,并在整个序列上累积。

def backward(self, targets, cache):
    """RNN Backward Pass"""

    # unpack cache
    xs, hs, ps = cache['xs'], cache['hs'], cache['ps']

    # initialize gradients to zero
    dWxh, dWhh, dWhy = [np.zeros_like(w) for w in self.Wxh], [np.zeros_like(w) for w in self.Whh], np.zeros_like(self.Why)
    dbh, dby = [np.zeros_like(b) for b in self.bh], np.zeros_like(self.by)
    dhnext = [np.zeros_like(h) for h in hs[0]]

    for t in reversed(range(len(xs))):

        dy = np.copy(ps[t])

        # backprop into y. see http://cs231n.github.io/neural-networks-case-study/#grad if confused here
        dy[targets[t]] -= 1 

        dWhy += np.dot(dy, hs[t][-1].T)
        dby += dy

        for l in reversed(range(self.num_layers)):
            dh = np.dot(self.Why.T, dy) + dhnext[l]
            dhraw = (1 - hs[t][l] * hs[t][l]) * dh # backprop through tanh nonlinearity
            dbh[l] += dhraw
            dWxh[l] += np.dot(dhraw, xs[t].T)
            dWhh[l] += np.dot(dhraw, hs[t-1][l].T)
            dhnext[l] = np.dot(self.Whh[l].T, dhraw)

    return {'dWxh':dWxh, 'dWhh':dWhh, 'dWhy':dWhy, 'dbh':dbh, 'dby':dby}

与前向传递相同,让我们分解它。

这个函数的第一步是将权重和偏置的梯度初始化为零,这类似于前馈 ANN 中的情况。这是让我困惑的一点,因此我会进一步详细解释。

通过在每个序列之前将梯度重置为零,确保当前序列计算的梯度不会与先前序列计算的梯度累积或叠加。

这可以防止梯度变得过大,从而导致优化过程发散并对模型性能产生负面影响。此外,它允许对每个序列独立执行权重更新,这可以导致更稳定和一致的优化。

然后,它会反向遍历输入序列,对每个时间步 t 执行以下计算:

注意注释,回传到 y,这个链接将完美解释发生了什么。我也在之前的文章中深入探讨了这一点,你可以在这里查看。

  • 计算隐藏状态hs[t][l]相对于损失的梯度,用dh表示。

  • 计算原始隐藏状态的梯度,用dhraw表示

dh 和 dhraw 的区别是什么? 好问题。

dhdhraw的区别在于,dh是相对于损失的隐藏状态hs[t][l]的梯度,通过反向传播输出层 softmax 激活的概率ps[t]的梯度计算得出。dhraw是相同的梯度,但它进一步通过非线性 tanh 激活函数反向传播,通过元素级相乘隐藏状态dh的梯度与 tanh 函数的导数(1 - hs[t][l] * hs[t][l])得到。

  • 计算隐藏偏置bh[l]的梯度

  • 计算相对于损失的输入-隐藏权重Wxh[l]的梯度,用dWxh[l]表示。

  • 计算相对于损失的隐藏-隐藏权重Whh[l]的梯度,用dWhh[l]表示。

  • 计算下一个隐藏状态dhnext[l]的梯度。

反向传递计算的数学符号。作者提供的图像。

让我们进行反向传递。

# Initialize RNN
num_layers = 3
hidden_size = 100
seq_length = 8

rnn = RNN(hidden_size=hidden_size, vocab_size=vocab_size, seq_length=seq_length, num_layers=num_layers)

x = [char_to_idx[ch] for ch in data[rnn.pointer:rnn.pointer+seq_length]]

y = [char_to_idx[ch] for ch in data[rnn.pointer+1:rnn.pointer+seq_length+1]]

# initialize hidden state with zeros
hprev = [np.zeros((hidden_size, 1)) for _ in range(num_layers)] 

## Call RNN
loss, hprev, cache = rnn(inputs=x, targets=y, hprev=hprev)
grads = rnn.backward(targets=y, cache=cache)

最后,我们返回梯度,以便更新参数,这也是我下一个话题——优化的引子。

优化器

如 RNN 类中的‘update’方法所述,我们将使用 Adagrad 进行这个实现。

Adagrad 是一种优化算法,它根据历史梯度信息为神经网络中的每个参数单独调整学习率。

它特别适用于处理稀疏数据,通常用于自然语言处理任务。Adagrad 在每次迭代时调整学习率,确保模型尽可能快速和高效地从数据中学习。

def update(self, grads, lr):
    """Perform Parameter Update w/ Adagrad"""

    # unpack grads
    dWxh, dWhh, dWhy = grads['dWxh'], grads['dWhh'], grads['dWhy']
    dbh, dby = grads['dbh'], grads['dby']

    # loop through each layer
    for i in range(self.num_layers):

        # clip gradients to mitigate exploding gradients
        np.clip(dWxh[i], -5, 5, out=dWxh[i])
        np.clip(dWhh[i], -5, 5, out=dWhh[i])
        np.clip(dbh[i], -5, 5, out=dbh[i])

        # perform parameter update with Adagrad
        self.mWxh[i] += dWxh[i] * dWxh[i]
        self.Wxh[i] -= lr * dWxh[i] / np.sqrt(self.mWxh[i] + 1e-8)
        self.mWhh[i] += dWhh[i] * dWhh[i]
        self.Whh[i] -= lr * dWhh[i] / np.sqrt(self.mWhh[i] + 1e-8)
        self.mbh[i] += dbh[i] * dbh[i]
        self.bh[i] -= lr * dbh[i] / np.sqrt(self.mbh[i] + 1e-8)

    # clip gradients for Why and by
    np.clip(dWhy, -5, 5, out=dWhy)
    np.clip(dby, -5, 5, out=dby)

    # perform parameter update with Adagrad
    self.mWhy += dWhy * dWhy
    self.Why -= lr * dWhy / np.sqrt(self.mWhy + 1e-8)
    self.mby += dby * dby
    self.by -= lr * dby / np.sqrt(self.mby + 1e-8)

这段代码使用 Adagrad 优化算法更新 RNN 的参数。它跟踪参数的梯度平方和(mWxhmWhhmbhmWhymby),并用这个和的平方根加上一个小常数 1e-8 来除以学习率,以确保数值稳定,从而有效地调整每个参数的学习率。此外,它剪切梯度以防止 梯度爆炸

Adagrad 为每个参数调整学习率,对不频繁更新的参数执行较大的更新,对频繁更新的参数执行较小的更新。这意味着,对于不频繁更新的参数,学习率会较大,以便模型能对这些参数做出更大的调整。另一方面,对于频繁更新的参数,学习率会较小,从而使模型对这些参数进行小幅调整,以防过拟合。这与使用固定学习率形成对比,后者可能会导致参数校正不足或过度校正。

让我们进行参数更新。

# Initialize RNN
num_layers = 3
hidden_size = 100
seq_length = 8

rnn = RNN(hidden_size=hidden_size, vocab_size=vocab_size, seq_length=seq_length, num_layers=num_layers)

x = [char_to_idx[ch] for ch in data[rnn.pointer:rnn.pointer+seq_length]]

y = [char_to_idx[ch] for ch in data[rnn.pointer+1:rnn.pointer+seq_length+1]]

# initialize hidden state with zeros
hprev = [np.zeros((hidden_size, 1)) for _ in range(num_layers)] 

## Call RNN
loss, hprev, cache = rnn(inputs=x, targets=y, hprev=hprev)
grads = rnn.backward(targets=y, cache=cache)
rnn.update(grads=grads, lr=1e-1)

训练

最后一步实际上是训练网络,将输入序列输入网络中,计算错误,优化器更新权重和偏差。

def train(rnn, epochs, data, lr=1e-1, use_drop=False):

    for _ in range(epochs):

        # prepare inputs (we're sweeping from left to right in steps seq_length long)
        if rnn.pointer+seq_length+1 >= len(data) or rnn.iteration == 0:

            hprev = [np.zeros((hidden_size, 1)) for _ in range(rnn.num_layers)]  # reset RNN memory

            rnn.pointer = 0 # go from start of data

        x = [char_to_idx[ch] for ch in data[rnn.pointer:rnn.pointer+seq_length]]
        y = [char_to_idx[ch] for ch in data[rnn.pointer+1:rnn.pointer+seq_length+1]]

        if use_drop:
            loss, hprev, cache = rnn(inputs=x, targets=y, hprev=hprev, dropout=True)
        else:
            loss, hprev, cache = rnn(inputs=x, targets=y, hprev=hprev)

        grads = rnn.backward(targets=y, cache=cache)
        rnn.update(grads=grads, lr=lr)

        # update loss
        rnn.loss = rnn.loss * 0.999 + loss * 0.001

        ## show progress now and then
        if rnn.iteration % 1000 == 0: 
            print('iter {}, loss: {}'.format(rnn.iteration, rnn.loss))

            sample_ix = rnn.predict(hprev, x[0], 200)
            txt = ''.join(idx_to_char[ix] for ix in sample_ix)
            print('Sample')
            print ('----\n {} \n----'.format(txt))

        rnn.pointer += seq_length # move data pointer
        rnn.iteration += 1 # iteration counter

## hyper-params
num_layers = 2
hidden_size = 128
seq_length = 13

# Initialize RNN
rnn = RNN(hidden_size=hidden_size, 
          vocab_size=vocab_size, 
          seq_length=seq_length, 
          num_layers=num_layers)

train(rnn=rnn, epochs=15000, data=data)

这段代码非常简单。我们执行前向和反向传播,并在每个纪元更新模型参数。

我想指出的是——

损失通过当前损失和前一个损失的加权平均来更新。

当前损失乘以 0.001 并加到前一个损失上,前一个损失乘以 0.999。这意味着当前损失对总损失的影响较小,而前面的损失影响较大。这样,总损失波动不会那么大,并且会随着时间更加稳定。

通过使用 EMA(指数移动平均),更容易监控网络的性能,并检测它是否过拟合或欠拟合。

第零次迭代的损失与文本预测。图片由作者提供。

第 14,000 次迭代的损失与文本预测。图片由作者提供。

50,000 个纪元后的损失。

我们的 RNN 训练过程已经成功,我们可以看到损失的减少和生成样本质量的提高。然而,需要注意的是,生成原创莎士比亚文本是一个复杂的任务,而这一实现是一个简单的普通 RNN。因此,仍有进一步改进和尝试不同架构和技术的空间。

结论

总之,本文展示了如何使用 Numpy 实现和训练一个字符级 RNN。多对多架构和在线学习方法使得网络能够适应数据中的新模式,从而改进样本生成。虽然这个网络在生成原创莎士比亚文本方面具有一定能力,但需要注意的是,这只是一个简化版本,还有许多其他架构和技术可以探索,以获得更好的性能。

完整代码和仓库 在这里

随时与我们联系并提出问题,或对代码进行改进。

感谢阅读!

使用 RetinaNet 和 KerasCV 的目标检测

原文:towardsdatascience.com/object-detection-using-retinanet-and-kerascv-b07940327b6c?source=collection_archive---------3-----------------------#2023-12-06

使用 KerasCV 库的力量和简便性进行目标检测。

Ed IzaguirreTowards Data Science Ed Izaguirre

·

关注 发表在 Towards Data Science · 21 分钟阅读 · 2023 年 12 月 6 日

--

一张植物叶子的图像。创建于 DALL·E 2

目录

  1. 等等,什么是 KerasCV?

  2. 检查数据

  3. 图像预处理

  4. RetinaNet 模型背景

  5. 训练 RetinaNet

  6. 做出预测

  7. 结论

  8. 参考文献

相关链接

  • Kaggle 实验笔记: 随意复制笔记本,试验代码,并使用免费的 GPU。

  • PlantDoc 数据集:这是本笔记本中使用的数据集,托管在 Roboflow 上。该数据集在 CC BY 4.0 DEED 许可证下发布,这意味着你可以在任何媒介或格式中复制和重新分发该材料,甚至用于商业目的。

等等,什么是 KerasCV?

在完成基于图像分割的小项目后(参见这里),我准备转入计算机视觉领域下另一个常见任务:物体检测。物体检测指的是对图像进行处理,产生围绕感兴趣对象的框,并分类这些框中的对象。作为一个简单的例子,看看下面的图片:

图片

物体检测的示例。请注意边界框和类标签。图片由作者提供。

蓝色的框被称为边界框类名放置在其正上方。因此,物体检测可以分解为两个小问题:

  1. 一个回归问题,模型必须预测盒子左上角和右下角的xy坐标。

  2. 一个分类问题,模型必须预测盒子正在观察的物体类别。

在这个例子中,边界框是由人类创建和标记的。我们希望自动化这个过程,而一个训练良好的物体检测模型正可以做到这一点。

我坐下来回顾我关于物体检测的学习资料,很快就感到失望。不幸的是,大多数介绍性的资料几乎没有提到物体检测。François Chollet 在Python 深度学习 [1] 中提到:

请注意,我们不会涵盖物体检测,因为它对于介绍性书籍来说过于专业和复杂。

Aurélion Géron [2] 提供了许多关于物体检测背后思想的文本内容,但只提供了几行代码来处理带有虚拟边界框的物体检测任务,远未达到我所期望的端到端流水线。Andrew Ng [3] 的著名深度学习专项课程在物体检测方面涵盖最深入,但甚至他在编码实验室中也只是加载了一个预训练的物体检测模型进行推理。

想要更深入地研究,我开始勾勒出一个物体检测流水线的大纲。仅仅为了为 RetinaNet 模型进行预处理,你需要执行以下步骤(注:其他物体检测模型如 YOLO 需要不同的步骤):

  • 将输入图片都调整为相同的大小,并进行填充以防止长宽比混乱。哦,不要忘记边界框;这些也需要适当地重新调整形状,否则你会破坏你的数据。

  • 根据训练集中的真实边界框生成不同尺度和纵横比的锚框。这些锚框在训练过程中作为模型的参考点。

  • 根据与真实框的重叠情况为锚框分配标签。重叠度高的锚框标记为正例,而重叠度低的锚框标记为负例。

  • 描述相同的边界框有多种方法。你需要实现函数来在这些不同格式之间进行转换。稍后会详细介绍。

  • 实现数据增强时,不仅要增强图像,还要增强框。理论上你可以省略这一步,但在实践中这是必要的,以帮助我们的模型更好地泛化。

看看这个例子 在 Keras 网站上。哎呀。我们模型预测的后处理将需要更多工作。借用 Keras 团队的话:这是一个技术上复杂的问题。

当我开始绝望时,我开始急切地浏览互联网,偶然发现了一个我从未听说过的库:KerasCV。当我阅读文档时,我开始意识到这是TensorFlow/Keras 计算机视觉的未来。根据他们的介绍:

KerasCV 可以被理解为 Keras API 的横向扩展:这些组件是新的第一方 Keras 对象,过于专业化而无法添加到核心 Keras 中。它们与核心 Keras API 享有相同级别的打磨和向后兼容保证,并由 Keras 团队维护。

“但为什么我的学习材料中没有提到这个?” 我想。答案很简单:这是一个相当新的库。GitHub 上的第一次提交是在 2022 年 4 月 13 日,太新了,甚至还未出现在我教科书的最新版本中。事实上,该库的 1.0 版本尚未发布(截至 2023 年 11 月 10 日,它是 0.6.4)。我预计 KerasCV 会在我教科书的下一版和在线课程中详细讨论(公平地说,Gèron 确实提到过“新的 Keras NLP 项目”和 Keras CV 项目,读者可能会感兴趣)。

KerasCV 刚刚推出,除了 Keras 团队自己发布的教程外,还没有很多教程(见这里)。在本教程中,我将演示一个端到端的目标检测流程,使用受官方 Keras 指南启发但又不同于这些指南的技术来识别健康和病变叶片。有了 KerasCV,即使是初学者也可以利用标记数据集来构建有效的目标检测管道。

在我们开始之前需要注意几点。KerasCV 是一个快速变化的库,其代码库和文档会定期更新。这里展示的实现将适用于 KerasCV 版本 0.6.4。Keras 团队已声明:“在 KerasCV 达到 v1.0.0 之前,没有向后兼容的承诺。” 这意味着无法保证本教程中使用的方法在 KerasCV 更新时仍然有效。我已在链接的 Kaggle notebook 中硬编码了 KerasCV 版本号,以防止这些问题。

KerasCV 有很多已知的错误,可以在 GitHub 的问题标签页 中查看。此外,文档在一些领域也有所欠缺(我看着你,MultiClassNonMaxSuppression)。在使用 KerasCV 时,尽量不要被这些问题气馁。事实上,这是一个成为 KerasCV 代码库贡献者的绝佳机会!

本教程将重点介绍 KerasCV 的实现细节。我将简要回顾一些目标检测的高级概念,但假设读者对如 RetinaNet 架构等概念有一定背景知识。这里展示的代码已进行编辑和调整以提高清晰度,完整代码请参见上面链接的 Kaggle notebook。

最后,关于安全的提示。这里创建的模型并非最先进的技术;请将其视为一个高层次的教程。在将此植物疾病检测模型投入生产之前,需要进一步的微调和数据清理。最好将模型做出的任何预测交由人工专家确认诊断。

检查数据

PlantDoc 数据集包含 2,569 张图像,涵盖 13 种植物和 30 个类别。数据集的目标在 Singh 等人撰写的论文 PlantDoc: A Dataset for Visual Plant Disease Detection 的摘要中进行了阐述 [4]。

印度由于植物疾病每年损失 35% 的作物产量。由于缺乏实验室基础设施和专业知识,植物疾病的早期检测仍然很困难。本文探讨了计算机视觉方法在可扩展和早期植物疾病检测中的可能性。

这是一个崇高的目标,也是计算机视觉可以为农民做出很多贡献的领域。

Roboflow 允许我们以多种不同格式下载数据集。由于我们使用 TensorFlow,建议将数据集下载为 TFRecord 格式。TFRecord 是 TensorFlow 中一种特定格式,旨在高效地存储大量数据。数据由一系列记录表示,每个记录是一个键值对。每个键称为 feature。下载的压缩文件包含四个文件,其中两个用于训练,两个用于验证:

  • leaves_label_map.pbtxt : 这是一个 Protocol Buffers 文本格式文件,用于描述数据的结构。打开文件时,我看到有三十个类别。既有健康叶子如 Apple leaf,也有不健康叶子如 Apple Scab Leaf

  • leaves.tfrecord : 这是包含我们所有数据的 TFRecord 文件。

我们的第一步是检查 leaves.tfrecord。我们的记录包含哪些特征?不幸的是,Roboflow 并未指定这一点。

train_tfrecord_file = '/kaggle/input/plants-dataset/leaves.tfrecord'
val_tfrecord_file = '/kaggle/input/plants-dataset/test_leaves.tfrecord'

# Create a TFRecordDataset
train_dataset = tf.data.TFRecordDataset([train_tfrecord_file])
val_dataset = tf.data.TFRecordDataset([val_tfrecord_file])

# Iterate over a few entries and print their content. Uncomment this to look at the raw data
for record in train_dataset.take(1):
  example = tf.train.Example()
  example.ParseFromString(record.numpy())
  print(example)

我看到以下打印的特征:

  • image/encoded : 这是图像的编码二进制表示。在这个数据集中,图像是以 jpeg 格式编码的。

  • image/height : 这是每个图像的高度。

  • image/width : 这是每个图像的宽度。

  • image/object/bbox/xmin : 这是边界框左上角的 x 坐标。

  • image/object/bbox/xmax : 这是边界框右下角的 x 坐标。

  • image/object/bbox/ymin : 这是边界框左上角的 y 坐标。

  • image/object/bbox/ymax : 这是边界框右下角的 y 坐标。

  • image/object/class/label : 这些是与每个边界框关联的标签。

现在我们想把所有图像及其关联的边界框整合到一个 TensorFlow Dataset 对象中。Dataset 对象允许你存储大量数据而不会使系统内存超载。这是通过延迟加载批处理等功能实现的。延迟加载意味着数据不会被加载到内存中,直到它被显式请求(例如在执行转换或训练时)。批处理意味着一次只加载选择数量的图像(通常为 8、16、32 等)。简而言之,我建议你始终将数据转换为 Dataset 对象,特别是在处理大量数据时(在目标检测中很常见)。

要将 TFRecord 转换为 TensorFlow 中的 Dataset 对象,你可以使用 tf.data.TFRecordDataset 类从 TFRecord 文件创建数据集,然后使用 map 方法应用解析函数来提取和预处理特征。解析代码如下所示。

def parse_tfrecord_fn(example):
    feature_description = {
        'image/encoded': tf.io.FixedLenFeature([], tf.string),
        'image/height': tf.io.FixedLenFeature([], tf.int64),
        'image/width': tf.io.FixedLenFeature([], tf.int64),
        'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
        'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32),
        'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
        'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32),
        'image/object/class/label': tf.io.VarLenFeature(tf.int64),
    }

    parsed_example = tf.io.parse_single_example(example, feature_description)

    # Decode the JPEG image and normalize the pixel values to the [0, 255] range.
    img = tf.image.decode_jpeg(parsed_example['image/encoded'], channels=3) # Returned as uint8

    # Get the bounding box coordinates and class labels.
    xmin = tf.sparse.to_dense(parsed_example['image/object/bbox/xmin'])
    xmax = tf.sparse.to_dense(parsed_example['image/object/bbox/xmax'])
    ymin = tf.sparse.to_dense(parsed_example['image/object/bbox/ymin'])
    ymax = tf.sparse.to_dense(parsed_example['image/object/bbox/ymax'])
    labels = tf.sparse.to_dense(parsed_example['image/object/class/label'])

    # Stack the bounding box coordinates to create a [num_boxes, 4] tensor.
    rel_boxes = tf.stack([xmin, ymin, xmax, ymax], axis=-1)
    boxes = keras_cv.bounding_box.convert_format(rel_boxes, source='rel_xyxy', target='xyxy', images=img)

    # Create the final dictionary.
    image_dataset = {
        'images': img,
        'bounding_boxes': {
            'classes': labels,
            'boxes': boxes
        }
    }
    return image_dataset

让我们详细拆解一下:

  • feature_description : 这是一个描述每个特征预期格式的字典。当特征在数据集中所有示例中的长度是固定时,我们使用 tf.io.FixedLenFeature,当长度存在某些变动时,我们使用 tf.io.VarLenFeature。由于边界框的数量在数据集中并不固定(有些图像有更多框,有些则较少),因此我们对所有与边界框相关的内容使用 tf.io.VarLenFeature

  • 我们使用 tf.image.decode_jpeg 解码图像文件,因为我们的图像是以 JPEG 格式编码的。

  • 请注意用于边界框坐标和标签的 tf.sparse.to_dense 的使用。当我们使用 tf.io.VarLenFeature 时,信息会以稀疏矩阵的形式返回。稀疏矩阵是大多数元素为零的矩阵,结果是一个只有效存储非零值及其索引的数据结构。不幸的是,TensorFlow 中的许多预处理函数要求使用稠密矩阵。这包括 tf.stack,我们用来水平堆叠来自多个边界框的信息。为了解决这个问题,我们使用 tf.sparse.to_dense 将稀疏矩阵转换为稠密矩阵。

  • 在堆叠框之后,我们使用 KerasCV 的 keras_cv.bounding_box.convert_format 函数。检查数据时,我注意到边界框坐标被归一化在 0 和 1 之间。这意味着这些数字表示图像总宽度/高度的百分比。例如,值为 0.5 表示 50% * image_width。这是一种 相对格式,Keras 称之为 REL_XYXY,而不是 绝对格式 XYXY。理论上,转换为绝对格式不是必要的,但当我使用相对坐标训练模型时遇到了错误。有关其他支持的边界框格式,请参见 KerasCV 文档

  • 最后,我们将图像和边界框转换为 KerasCV 所需的格式:字典。Python 字典是一种包含键值对的数据类型。具体来说,KerasCV 期望以下格式:

image_dataset = {
  "images": [width, height, channels],
  bounding_boxes = {
    "classes": [num_boxes],
    "boxes": [num_boxes, 4]
  }
}

这实际上是一个“字典中的字典”,因为 bounding_boxes 也是一个字典。

最后使用 .map 函数将解析函数应用于我们的 TFRecord。然后可以检查 Dataset 对象。一切正常。

train_dataset = train_dataset.map(parse_tfrecord_fn)
val_dataset = val_dataset.map(parse_tfrecord_fn)

# Inspecting the data
for data in train_dataset.take(1):
    print(data)

恭喜,最困难的部分现在已经完成了。 在我看来,创建 KerasCV 所需的“字典中的字典”是最具挑战性的任务。其余部分更为直接。

图像预处理

我们的数据已经分为训练集和验证集。所以我们将开始对数据集进行批处理。

# Batching
BATCH_SIZE = 32
# Adding autotune for pre-fetching
AUTOTUNE = tf.data.experimental.AUTOTUNE

train_dataset = train_dataset.ragged_batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)
val_dataset = val_dataset.ragged_batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)

NUM_ROWS = 4
NUM_COLS = 8
IMG_SIZE = 416
BBOX_FORMAT = "xyxy"

一些说明:

  • 我们使用 ragged_batch 是因为我们不知道每个图像将有多少个边界框。如果所有图像都有相同数量的边界框,那么我们可以直接使用 batch

  • 我们设置了 BBOX_FORMAT=“xyxy” 。回忆一下,之前在加载数据时,我们将边界框格式从相对的 XYXY 格式转换为绝对的 XYXY 格式。

现在我们可以实现 数据增强。数据增强是计算机视觉问题中的一种常见技术。它对训练图像进行轻微的修改,例如轻微旋转、水平翻转图像等。这有助于解决数据不足的问题,并且有助于正则化。在这里,我们将引入以下增强方法:

  • KerasCV 的JitteredResize函数。这个函数旨在用于目标检测管道,实现了一种图像增强技术,涉及随机缩放、调整大小、裁剪和填充图像及相应的边界框。这一过程引入了尺度和局部特征的变异,提高了训练数据的多样性,从而改善了模型的泛化能力。

  • 然后我们添加了水平和垂直的RandomFlips以及RandomRotation。这里的factor是一个表示 2π分数的浮点数。我们使用 0.25,这意味着我们的增强器会将图像旋转-25%到 25%π之间的某个角度。以度数表示,这意味着旋转范围在-45°到 45°之间。

  • 最后,我们添加了RandomSaturationRandomHue。饱和度为 0.0 会留下灰度图像,而 1.0 则完全饱和。0.5 的因子不会造成任何变化,因此选择 0.4–0.6 的范围会产生细微的变化。色调因子为 0.0 不会产生变化。设置factor=0.2表示范围为 0.0–0.2,这是另一种细微变化。

augmenter = keras.Sequential(
    [
        keras_cv.layers.JitteredResize(
            target_size=(IMG_SIZE, IMG_SIZE), scale_factor=(0.8, 1.25), bounding_box_format=BBOX_FORMAT
        ),
        keras_cv.layers.RandomFlip(mode="horizontal_and_vertical", bounding_box_format=BBOX_FORMAT),
        keras_cv.layers.RandomRotation(factor=0.25, bounding_box_format=BBOX_FORMAT),
        keras_cv.layers.RandomSaturation(factor=(0.4, 0.6)),
        keras_cv.layers.RandomHue(factor=0.2, value_range=[0,255])
    ]
)

train_dataset = train_dataset.map(augmenter, num_parallel_calls=tf.data.AUTOTUNE)

我们通常只对训练集进行增强,因为我们希望模型避免“记忆”模式,而是确保模型学习到在现实世界中会遇到的通用模式。这增加了模型在训练过程中看到的多样性。

我们还希望将验证图像调整为相同的大小(带有填充)。这些图像将在不失真的情况下调整大小。边界框也必须相应地重新调整。KerasCV 可以轻松处理这一困难任务:

# Resize and pad images
inference_resizing = keras_cv.layers.Resizing(
    IMG_SIZE, IMG_SIZE, pad_to_aspect_ratio=True, bounding_box_format=BBOX_FORMAT
)

val_dataset = val_dataset.map(inference_resizing, num_parallel_calls=tf.data.AUTOTUNE)

最后,我们可以可视化我们的图像和包含预处理的边界框:

class_mapping = {
    1: 'Apple Scab Leaf',
    2: 'Apple leaf',
    3: 'Apple rust leaf',
    4: 'Bell_pepper leaf',
    5: 'Bell_pepper leaf spot',
    6: 'Blueberry leaf',
    7: 'Cherry leaf',
    8: 'Corn Gray leaf spot',
    9: 'Corn leaf blight',
    10: 'Corn rust leaf',
    11: 'Peach leaf',
    12: 'Potato leaf',
    13: 'Potato leaf early blight',
    14: 'Potato leaf late blight',
    15: 'Raspberry leaf',
    16: 'Soyabean leaf',
    17: 'Soybean leaf',
    18: 'Squash Powdery mildew leaf',
    19: 'Strawberry leaf',
    20: 'Tomato Early blight leaf',
    21: 'Tomato Septoria leaf spot',
    22: 'Tomato leaf',
    23: 'Tomato leaf bacterial spot',
    24: 'Tomato leaf late blight',
    25: 'Tomato leaf mosaic virus',
    26: 'Tomato leaf yellow virus',
    27: 'Tomato mold leaf',
    28: 'Tomato two spotted spider mites leaf',
    29: 'grape leaf',
    30: 'grape leaf black rot'
}

def visualize_dataset(inputs, value_range, rows, cols, bounding_box_format):
    inputs = next(iter(inputs.take(1)))
    images, bounding_boxes = inputs["images"], inputs["bounding_boxes"]
    visualization.plot_bounding_box_gallery(
        images,
        value_range=value_range,
        rows=rows,
        cols=cols,
        y_true=bounding_boxes,
        scale=5,
        font_scale=0.7,
        bounding_box_format=bounding_box_format,
        class_mapping=class_mapping,
    )

# Visualize training set
visualize_dataset(
    train_dataset, bounding_box_format=BBOX_FORMAT, value_range=(0, 255), rows=NUM_ROWS, cols=NUM_COLS
)

# Visualize validation set
visualize_dataset(
    val_dataset, bounding_box_format=BBOX_FORMAT, value_range=(0, 255), rows=NUM_ROWS, cols=NUM_COLS
)

这种类型的可视化函数在 KerasCV 中很常见。它绘制了一组图像和框,行和列由参数指定。我们看到我们的训练图像有些被轻微旋转,有些被水平或垂直翻转,可能还进行了放大或缩小,并且色调/饱和度的细微变化也可以看到。在 KerasCV 中,所有增强层也会在必要时增强边界框。 请注意,class_mapping是一个简单的字典。我从之前提到的leaves_label_map.pbtxt文本文件中获得了键和标签。

左侧是原始图像(验证集)的示例,右侧是增强图像(训练集)。图片由作者提供。

在查看 RetinaNet 模型之前最后要说的一件事是,之前我们需要创建“字典中的字典”以将数据转换为与 KerasCV 预处理兼容的格式,但现在我们需要将其转换为数字元组以供模型训练。这相当直接:

def dict_to_tuple(inputs):
    return inputs["images"], bounding_box.to_dense(
        inputs["bounding_boxes"], max_boxes=32
    )

train_dataset = train_dataset.map(dict_to_tuple, num_parallel_calls=tf.data.AUTOTUNE)
validation_dataset = val_dataset.map(dict_to_tuple, num_parallel_calls=tf.data.AUTOTUNE)

RetinaNet 模型背景

一个用于进行目标检测的流行模型叫做RetinaNet。该模型的详细描述超出了本文的范围。简而言之,RetinaNet 是一个单阶段检测器,意味着它在预测边界框和类别之前只查看一次图像。这类似于著名的 YOLO(You Only Look Once)模型,但有一些重要的不同之处。我在这里要强调的是使用的创新分类损失函数:focal loss。它解决了图像中的类别不平衡问题。

为了理解这点的重要性,可以考虑以下类比:假设你是一名教室里有 100 个学生的老师。95 个学生吵闹且喧哗,总是喊叫和举手。5 个学生安静,不怎么说话。作为老师,你需要平等关注每个人,但吵闹的学生正在挤走安静的学生。这里你遇到了类别不平衡的问题。为了解决这个问题,你开发了一种特殊的助听器,它增强了安静学生的声音并弱化了吵闹学生的声音。在这个类比中,吵闹的学生是我们图像中不包含叶子的背景像素的大多数,而安静的学生是那些包含叶子的少量区域。这个“助听器”就是 focal loss,它使我们可以将模型集中在包含叶子的像素上,而不会过多关注那些不包含叶子的像素。

RetinaNet 模型有三个重要组件:

  • 一个 骨干网络。这构成了模型的基础。我们也称之为特征提取器。顾名思义,它接收图像并扫描特征。低层提取低级特征(例如线条和曲线),而高层提取高级特征(例如嘴唇和眼睛)。在这个项目中,骨干网络将是一个在COCO 数据集上进行过预训练的 YOLOv8 模型。我们只将 YOLO 用作特征提取器,而不是作为目标检测器。

  • 特征金字塔网络(FPN)。这是一种模型架构,在不同的尺度上生成“金字塔”特征图,以检测各种大小的对象。它通过通过自上而下的路径和横向连接将低分辨率的语义强特征与高分辨率的语义弱特征结合起来。查看这个视频以获取详细解释,或查看这篇论文 [5],该论文介绍了 FPN。

  • 两个任务特定的子网络。 这些子网络处理金字塔的每一层,并检测每层中的对象。一个子网络用于识别类别(分类),另一个用于识别边界框(回归)。这些子网络尚未训练。

简化的 RetinaNet 架构。图片由作者提供。

之前我们将图像调整为 416x416 的大小。这是一个有点随意的选择,尽管你选择的目标检测模型通常会指定一个所需的最小大小。对于我们使用的 YOLOv8 主干,图像大小应该是 32 的倍数。这是因为主干的最大步幅是 32,而且它是一个完全卷积网络。对于你自己项目中使用的任何模型,请进行调研以找出这个因素。

训练 RetinaNet

让我们从设置一些基本参数开始,比如优化器和我们将使用的指标。这里我们将使用 Adam 作为优化器。请注意global_clip_norm参数。根据KerasCV 目标检测指南

在训练目标检测模型时,你总是希望包含global_clipnorm。这是为了修复在训练目标检测模型时经常出现的梯度爆炸问题。

base_lr = 0.0001
# including a global_clipnorm is extremely important in object detection tasks
optimizer_Adam = tf.keras.optimizers.Adam(
    learning_rate=base_lr,
    global_clipnorm=10.0
)

我们将遵循他们的建议。对于我们的指标,我们将使用BoxCOCOMetrics。这些是目标检测中流行的指标。它们基本上包括平均精度 (mAP)平均召回率 (mAR)。总的来说,mAP 通过测量正确对象检测的平均面积与模型预测覆盖的总面积的比率来量化模型定位和识别对象的有效性。mAR 是一个不同的分数,通过计算正确识别的对象区域与实际对象区域的平均比例来评估模型捕获对象全部范围的能力。有关指标的详细信息,请参见这篇文章这段视频 对精度和召回率的基本知识进行了很好的解释。

coco_metrics = keras_cv.metrics.BoxCOCOMetrics(
    bounding_box_format=BBOX_FORMAT, evaluate_freq=5
)

由于框的指标计算开销很大,我们传递evaluate_freq=5参数,以告知我们的模型在每五个批次后计算指标,而不是在训练期间每个批次后计算。我注意到,当数字设置得过高时,验证指标根本没有被打印出来。

让我们继续查看我们将使用的回调:

class VisualizeDetections(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs):
        if (epoch+1)%5==0:
            visualize_detections(
                self.model, bounding_box_format=BBOX_FORMAT, dataset=val_dataset, rows=NUM_ROWS, cols=NUM_COLS
            )

checkpoint_path="best-custom-model"

callbacks_list = [
    # Conducting early stopping to stop after 6 epochs of non-improving validation loss
    keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=6,
    ),

    # Saving the best model
    keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_path,
        monitor="val_loss",
        save_best_only=True,
        save_weights_only=True
    ),

    # Custom metrics printing after each epoch
    tf.keras.callbacks.LambdaCallback(
    on_epoch_end=lambda epoch, logs: 
        print(f"\nEpoch #{epoch+1} \n" +
              f"Loss: {logs['loss']:.4f} \n" + 
              f"mAP: {logs['MaP']:.4f} \n" + 
              f"Validation Loss: {logs['val_loss']:.4f} \n" + 
              f"Validation mAP: {logs['val_MaP']:.4f} \n") 
    ),

    # Visualizing results after each five epochs
    VisualizeDetections()
]
  • 早停。如果验证损失在六个周期后没有改善,我们将停止训练。

  • 模型检查点。我们将在每个周期后检查val_loss,如果它优于早期的周期,将保存模型权重。

  • Lambda 回调。Lambda 回调是一个自定义回调,允许你在训练过程中于每个周期的不同点定义并执行任意 Python 函数。在这种情况下,我们用它来在每个周期后打印自定义指标。如果直接打印 COCOMetrics,会是一堆杂乱的数字。为了简化,我们只打印训练和验证的损失和 mAP。

  • 检测的可视化。 这将在每五个周期后打印出一个 4x8 的图像网格以及预测的边界框。这将使我们洞察我们的模型有多好(或多糟)。如果一切顺利,这些可视化效果应该随着训练的进行而变得更好。

最终我们创建了我们的模型。回顾一下,主干是一个 YOLOv8 模型。我们必须传递我们将使用的 num_classes,以及 bounding_box_format

# Building a RetinaNet model with a backbone trained on coco datset
def create_model():        
    model = keras_cv.models.RetinaNet.from_preset(
        "yolo_v8_m_backbone_coco",
        num_classes=len(class_mapping),
        bounding_box_format=BBOX_FORMAT
    )
    return model

model = create_model()

我们还必须自定义模型的 非极大值抑制 参数。非极大值抑制用于目标检测中,以过滤掉多个重叠的预测边界框,这些框对应于同一对象。它只保留置信度分数最高的框,并删除冗余的框,确保每个对象只被检测一次。它包含两个参数:iou_thresholdconfidence_threshold

  1. IoU,或 交并比,是一个介于 0 和 1 之间的数字,衡量一个预测框与另一个预测框之间的重叠程度。如果重叠高于 iou_threshold,则置信度较低的预测框会被丢弃。

  2. 置信度分数反映了模型对其预测的边界框的信心。如果预测框的置信度分数低于 confidence_threshold,则该框会被丢弃。

尽管这些参数不会影响训练,但它们仍需根据您的特定应用进行调整以用于预测。设置 iou_threshold=0.5confidence_threshold=0.5 是一个好的起点。

在开始训练之前有一点需要注意:我们讨论了为什么将 分类损失 设置为焦点损失是有帮助的,但我们还没有讨论定义预测边界框坐标误差的合适 回归损失。一种流行的回归损失(或 box_loss)是 平滑 L1 损失。我认为平滑 L1 是一种“兼顾两全”的损失。它结合了 L1 损失(绝对误差)和 L2 损失(均方误差)。当误差值较小时,损失是二次的,当误差值较大时,损失是线性的(查看此链接)。KerasCV 为我们的便利提供了内置的平滑 L1 损失。训练期间显示的损失将是 box_lossclassification_loss 的总和。

# Using focal classification loss and smoothl1 box loss with coco metrics
model.compile(
    classification_loss="focal",
    box_loss="smoothl1",
    optimizer=optimizer_Adam,
    metrics=[coco_metrics]
)

history = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=40,
    callbacks=callbacks_list,
    verbose=0,
)

在 NVIDIA Tesla P100 GPU 上训练大约需要一个小时 12 分钟。

进行预测

# Create model with the weights of the best model
model = create_model()
model.load_weights(checkpoint_path)

# Customizing non-max supression of model prediction. I found these numbers to work fairly well
model.prediction_decoder = keras_cv.layers.MultiClassNonMaxSuppression(
    bounding_box_format=BBOX_FORMAT,
    from_logits=True,
    iou_threshold=0.2,
    confidence_threshold=0.6,
)

# Visuaize on validation set
visualize_detections(model, dataset=val_dataset, bounding_box_format=BBOX_FORMAT, rows=NUM_ROWS, cols=NUM_COLS)

现在我们可以加载在训练过程中看到的最佳模型,并用它对验证集进行一些预测:

验证集预测的样本视觉效果。图片由作者提供。

我们最佳模型的指标是:

  • 损失: 0.4185

  • mAP: 0.2182

  • 验证损失: 0.4584

  • 验证集 mAP: 0.2916

值得尊敬,但还有改进的空间。更多内容将在结论中讨论。(注意:我发现MultiClassNonMaxSuppression似乎没有正常工作。上面显示的左下角图像明显有超过 20%重叠的框,但较低置信度的框没有被抑制。这是我需要进一步研究的问题。)

这里是我们每个训练周期和验证周期的损失图。可以看到有些过拟合现象。此外,增加一个学习率调度器以逐渐降低学习率可能是明智的。这可能有助于解决在训练结束时出现的大幅跳跃问题。

每个训练周期和验证周期的损失图。我们看到了一些过拟合的迹象。图片由作者提供。

结论

如果你已经做到这一步,给自己一个赞美吧!目标检测是计算机视觉中较为困难的任务之一。幸运的是,我们有新的 KerasCV 库来简化我们的工作。总结一下创建目标检测管道的工作流程:

  • 开始时可视化你的数据集。问自己一些问题:“我的边界框格式是什么?是xyxyRelxyxy?我处理多少个类别?”确保创建一个类似于visualize_dataset的函数来查看你的图像和边界框。

  • 将你拥有的任何格式的数据转换为 KerasCV 所需的“字典中的字典”格式。使用 TensorFlow Dataset 对象来存储数据特别有帮助。

  • 进行一些基本的预处理,例如图像缩放和数据增强。KerasCV 使这些操作相对简单。请注意查阅你选择的模型的文献,以确保图像尺寸适当。

  • 将字典转换回元组以用于训练。

  • 选择一个优化器Adam是一个简单的选择),两个损失函数focal用于类别损失,L1 smooth用于框损失是简单的选择),以及指标COCO metrics是一个简单的选择)。

  • 在训练期间可视化你的检测结果可以帮助了解你的模型遗漏了哪些对象。

数据集中问题标签的示例。图片由作者提供。

主要的下一步之一是清理数据集。例如,查看上面的图像。标注者正确地识别了马铃薯叶晚疫病,但其他所有健康的马铃薯叶子呢?为什么这些没有标注为马铃薯叶?查看 Roboflow 网站上的健康检查标签,你可以看到某些类别在数据集中严重不足:

显示类别不平衡的图表。来自 Roboflow 的网站

在调整任何超参数之前,先尝试修复这些问题。祝你在目标检测任务中好运!

参考文献

[1] F. Chollet, 用 Python 进行深度学习(2021), Manning Publications Co.

[2] A. Géron, 动手实践机器学习:使用 Scikit-Learn、Keras 和 TensorFlow (2022), O’Reily Media Inc.

[3] A. Ng, 深度学习专项课程, DeepLearning.AI

[4] D. Singh, N. Jain, P. Jain, P. Kayal, S. Kumawat, N. Batra, PlantDoc:用于视觉植物疾病检测的数据集 (2019), CoDS COMAD 2020

[5] T. Lin, P. Dollár, R. Girshick, K. He, B. Hariharan, S. Belongie, 用于目标检测的特征金字塔网络(2017), CVPR 2017

[6] T. Lin, P. Goyal, R. Girshick, K. He, P. Dollar, 目标检测中的焦点损失(2020), IEEE 模式分析与机器智能学报

面向对象的数据科学:重构代码

原文:towardsdatascience.com/object-oriented-data-science-refactoring-code-5bcb4ae7ce72

提升机器学习模型和数据科学产品的效率代码和 Python 类。

Molly RubyTowards Data Science Molly Ruby

·发表于 Towards Data Science ·阅读时间 7 分钟·2023 年 8 月 24 日

--

图片由作者创建。

对于数据科学家来说,代码是分析和决策的核心。随着数据科学应用的复杂性增加,从嵌入在软件中的机器学习模型到协调大量信息的复杂数据管道,开发干净、组织良好且易于维护的代码变得至关重要。面向对象编程(OOP)解锁了灵活性和效率,使数据科学家能够敏捷地应对不断变化的需求。OOP 引入了类的概念,这些类作为创建对象的蓝图,这些对象封装了数据及其操作。这种范式转变使数据科学家能够超越传统的函数方法,促进模块化设计和代码重用。

在本文中,我们将探讨通过创建类和部署面向对象技术来重构数据科学代码的好处,以及这种方法如何增强模块化和可重用性。

数据科学中的类的力量

在传统的数据科学工作流中,函数是封装逻辑的方法。这通常足够,因为函数允许开发人员减少重复代码。然而,随着项目的发展,维护大量函数可能会导致代码难以导航、调试和扩展。

这时类发挥了作用。类是创建对象的蓝图,这些对象将数据和操作数据的函数(称为方法)捆绑在一起。通过将代码组织成类,开发人员可以实现以下几个优势:

  1. 模块化和封装:类通过将相关功能组合在一起来促进模块化。每个类封装了自己的属性(数据)和方法(函数),减少了全局变量污染的风险和命名冲突的可能性。这有助于保持关注点的清晰分离,使代码更容易理解和修改。

  2. 可重用性:类通过为项目的不同部分提供一致的接口来鼓励重用。一旦定义了一个类,它可以在需要时实例化,并且其方法可以用来实现一致的结果。

  3. 4. 继承和多态:继承允许开发人员创建子类,从父类继承属性和方法。这促进了代码重用,同时使得特定任务的定制成为可能。多态性,另一个面向对象编程的概念,让开发人员可以在不同类中使用相同的方法名称,根据具体实现调整行为。

  4. 5. 测试和调试:类促进了单元测试,因为测试用例可以针对类中的单独方法,这使得识别和修复问题变得更加容易,从而提高了代码库的整体健壮性。

将代码重构为类:一个理论例子

假设你正在进行一个涉及数据预处理、模型训练和评估的机器学习项目。最初,你可能会有一组用于每个步骤的函数:

# Example: Using functions for data preprocessing

def load_data(file_path):
    # Load and preprocess data
    ...

def preprocess_data(data):
    # Clean, transform, and encode data
    ...

def train_model(preprocessed_data):
    # Train a machine learning model
    ...

def evaluate_model(trained_model, test_data):
    # Evaluate model performance
    ...

尽管功能分解有效,但随着时间的推移,预处理、训练和评估中可能会有许多步骤。这可能会使管理这些函数变得具有挑战性。

将代码重构为类:

def load_data(self, file_path):
        # Load data
        ...

class DataPreprocessor:
    def __init__(self, data):
        self.raw_data = data
        self.cleaned_data = self.clean_data(data)

    def clean_data(self):
        # imputation, outlier treatment
        ...

    def transform_data(self):
        # transformations and encode data
        ...

class ModelTrainer:
    def __init__(self, preprocessed_data):
        self.model = self.train_model(preprocessed_data)

    def fit(self, preprocessed_data):
        # Train a machine learning model
        ...

    def predict(self, preprocessed_data):
        # Predict using the machine learning model
        ...

class ModelEvaluator:
    def __init__(self, predictions, actuals):
        self.performance_metrics = self.evaluate_model(predictions, actuals)

    def evaluate_model(self, predictions, actuals):
        # Evaluate model performance
        ...

    def calculate_rmse(self, predictions, actuals):
        # Evaluate root mean squared error

    def calculate_r_squared(self, predictions, actuals):
        # Evaluate r_squared of the model

通过将工作流程拆分为类,组织性得到了提升,结构也更易于阅读和维护。每个类处理过程的特定方面。它们可以被实例化为:

data_preprocessor = DataPreprocessor('data.csv')
model_trainer = ModelTrainer(data_preprocessor.preprocessed_data)
model_evaluator = ModelEvaluator(model_trainer.model, test_data)

在这种情况下,类的引入提供了额外的结构和灵活性,改善了代码的工作流程和可用性。通过利用类的强大功能,这个示例创建了一个更为健壮和可扩展的代码库。

将代码重构为类:一个实际的例子

作为一个实际的例子,我最近将这个代码库三年前开发的代码重构到一个新代码库中,以展示重构前后的代码差异。

在初始代码库中,许多函数涵盖了建模任务,因为训练和测试了多个不同的模型。在重构版本中,有一个名为 SalesForecasting 的模型类涵盖了所有建模任务。这更易于阅读,并且使得作为 SalesForecasting 部署包更为高效,并可以用不同的输入多次实例化。作为预览,这个类的样子如下:

class SalesForecasting:
    """
    SalesForecasting class to train and predict sales using a variety of models. 
    """

    def __init__(self, model_list):
        """
        Initialize the SalesForecasting class with a list of models to train and predict.

        Args:
            model_list (list): list of models to train and predict. Options include:
                - LinearRegression
                - RandomForest
                - XGBoost
                - LSTM
                - ARIMA

        Returns:
            None
        """

        ...

    def fit(self, X_train, y_train):
        """
        Fit the models in model_dict to the training data.

        Args:
            X_train (pd.DataFrame): training data exogonous features for the model
            y_train (pd.Series): training data target for the model

        Returns:
            None
        """

        ...

    def __fit_regression_model(self, model):
        """
        Fit a regression model to the training data.

        Args:
            model (sklearn model): sklearn model to fit to the training data

        Returns:
            model (sklearn model): fitted sklearn model
        """
        ...

    def __fit_lstm_model(self, model):
        """
        Fit an LSTM model to the training data.

        Args:
            model (keras model): keras model to fit to the training data

        Returns:
            model (keras model): fitted keras model
        """

        ...

    def __fit_arima_model(self, model_name):
        """
        Fit an ARIMA model to the training data.

        Args:
            model_name (str): name of the model to fit to the training data

        Returns:
            model (pmdarima model): fitted pmdarima model
        """
        ...

    def predict(self, x_values, y_values=None, scaler=None, print_scores=False):
        """
        Predict values using the models in model_dict.

        Args:
            x_values (pd.DataFrame): exogenous features to predict on
            y_values (pd.Series): target values to compare predictions against
            scaler (sklearn scaler): scaler used to scale the data
            print_scores (bool): whether to print the scores for each model

        Returns:
            self (SalesForecasting): self with updated predictions
        """

        ...

    def __predict_regression_model(self, model):
        """
        Predict values using a regression model.

        Args:
            model (sklearn model): sklearn model to predict with

        Returns:
            predictions (np.array): array of predictions
        """
        ...

    def __predict_lstm_model(self, model):
        """
        Predict values using an LSTM model.

        Args:
            model (keras model): keras model to predict with

        Returns:
            predictions (np.array): array of predictions
        """
        ...

    def __predict_arima_model(self, model):
        """
        Predict values using an ARIMA model.

        Args:
            model (pmdarima model): pmdarima model to predict with
        Returns: 
            predictions (np.array): array of predictions
        """
        ...

    def __undo_scaling(self, values, scaler):
        """
        Undo scaling on a set of values.

        Args:
            values (np.array): array of values to unscale
            scaler (sklearn scaler): scaler to use to unscale the values

        Returns:
            unscaled_values (np.array): array of unscaled values
        """
        ...

    def get_scores(self, y_pred, y_true, model_name=None, print_scores=False):
        """
        Get the scores for a model. Scores include RMSE, MAE, and R2.

        Args:
            y_pred (np.array): array of predicted values
            y_true (np.array): array of true values
            model_name (str): name of the model to get scores for
            print_scores (bool): whether to print the scores for the model

        Returns:
            rmse (float): root mean squared error
            mae (float): mean absolute error
            r2 (float): r squared
        """
        ...

    def plot_results(self, model_list=None, figsize=p.FIG_SIZE, xlabel="Date", ylabel="Sales", title="Sales Forecasting Predictions"):
        """
        Plot the results of the predictions against the actual values.
        Generates a timeseries for predictions from each model in model_dict.

        Args:
            model_list (list): list of models to plot. If None, plots all models in model_dict
            figsize (tuple): tuple of figure size
            xlabel (str): label for x axis
            ylabel (str): label for y axis
            title (str): title for the plot

        Returns:
            fig (matplotlib figure): figure with the plot
        """

        ...

    def plot_errs(self, figsize=(13,3)):
        """
        Plot the errors for each model in model_dict. Errors include RMSE, MAE, and R2.

        Args:
            figsize (tuple): tuple of figure size

        Returns:
            fig (matplotlib figure): figure with the plot
        """
        ...

“SalesForecasting”类作为一个全面的蓝图,帮助数据驱动的企业通过应用各种预测模型来预测未来的销售趋势。在这个类中,数据科学家可以利用不同的建模技术,包括线性回归、随机森林、XGBoost、LSTM(长短期记忆)和 ARIMA(自回归积分滑动平均)。通过将预测工作流封装在这个类中,模型拟合、预测和评估的过程变得更加简化和一致。通过“SalesForecasting”类,数据科学家可以高效地实验不同的算法,并轻松维护代码库。

面向对象编程是数据科学家用来构建反映其分析的真实世界系统复杂性的代码的工具,使他们能够在最大化灵活性的同时提取有价值的洞察。虽然 python 旨在将类用于实例化和继承,上面的例子展示了一个初步步骤,其中类被用于模块化代码。随着数据科学能力的扩展和团队的增长,维护高效的代码是至关重要的。

在这里查看完整的重构代码库。

无需 OCR 的文档数据提取与变换器 (1/2)

原文:towardsdatascience.com/ocr-free-document-data-extraction-with-transformers-1-2-b5a826bc2ac3?source=collection_archive---------3-----------------------#2023-04-28

Donut 与 Pix2Struct 在自定义数据上的对比

图标:Toon Beerten图标:Towards Data Science Toon Beerten

·

关注 发表在 Towards Data Science ·10 分钟阅读·2023 年 4 月 28 日

--

作者提供的图像 ()

DonutPix2Struct 是图像到文本模型,将纯像素输入的简单性与视觉语言理解任务相结合。简单来说:输入一张图像,提取的索引以 JSON 格式输出。

最近我发布了一个在发票上微调的 Donut 模型。我经常收到如何使用自定义数据集进行训练的问题。此外,还发布了一个类似的模型:Pix2Struct,它声称性能显著更好。但真的是这样吗?

该是卷起袖子的时候了。我将展示给你:

  • 如何为 Donut 和 Pix2Struct 微调准备数据

  • 两种模型的训练过程

  • 实际数据集上的比较结果

当然,我也会提供 colab 笔记本,以便于你的实验和/或复制。

数据集

要进行此比较,我需要一个公开的数据集。我想避免使用通常用于文档理解任务的数据集,例如CORD,浏览了一下,发现了Ghega 数据集。它相当小(约 250 个文档),由 2 种类型的文档组成:专利申请和数据表。通过不同类型,我们可以模拟一个分类问题。每种类型我们都有多个索引需要提取。这些索引对于每种类型都是唯一的。正是我所需要的。来自的 Trieste 大学机器学习实验室的Medvet教授慷慨批准了这些文章的使用。

数据集似乎比较旧,所以需要调查它是否仍然适合我们的目标。

初步探索

当你获得一组新的数据时,你首先需要熟悉其结构。幸运的是,网站的详细描述对我们很有帮助。这是数据集的文件结构:

ghega-dataset
    datasheets
        central-zener-1
        central-zener-2
        diodes-zener
            document-000-123542.blocks.csv
            document-000-123542.groundtruth.csv
            document-000-123542.in.000.png
            document-000-123542.out.000.png
            document-001-123663.blocks.csv
            document-001-123663.groundtruth.csv
            document-001-123663.in.000.png
            document-001-123663.out.000.png
            ...
        mcc-zener
        ...
    patents
        ...

我们可以看到两个主要的子文件夹对应两个文档类型:数据表专利。在更下一级,我们有一些子文件夹,这些子文件夹本身不重要,但它们包含以某个前缀开头的文件。我们可以看到一个唯一的标识符,例如document-000–123542。对于每个这些标识符,我们有 4 种数据:

  • blocks.csv 文件包含有关边界框的信息。由于 Donut 或 Pix2Struct 不使用这些信息,我们可以忽略这些文件。

  • out.000.png 文件是后处理(去倾斜)的图像文件。由于我更愿意测试未处理的文件,我也会忽略这些。

  • 原始的、未处理的文档图像有一个 in.000.png 后缀。这是我们感兴趣的。

  • 最后是相应的groundtruth.csv文件。这包含我们认为是实际标注的图像索引。

这里是一个示例 groundtruth csv 文件以及列描述:

Case,-1,0.0,0.0,0.0,0.0,,0,1.28,2.78,0.79,0.10,MELF CASE
StorageTemperature,0,0.35,3.40,2.03,0.11,Operating and Storage Temperature,0,4.13,3.41,0.63,0.09,-65 to +200
 1\. element type
 2\. page of the label block (-1 if absent)
 3\. x of the label block
 4\. y of the label block
 5\. w of the label block
 6\. h of the label block
 7\. text of the label block
 8\. page of the value block (never absent!)
 9\. x of the value block
10\. y of the value block
11\. w of the value block
12\. h of the value block
13\. text of the label block

这意味着我们只对第一列和最后一列感兴趣。第一列是,最后一列是。在这种情况下:

KEY                   VALUE
Case                  MELF CASE
StorageTemperature    -65 to +200

这意味着对于该文档,我们将微调模型以查找‘Case’的值为‘MELF CASE’,并且提取一个‘StorageTemperature’,其值为‘-65 to +200’。

索引

在 groundtruth 元数据中存在以下索引:

  • 数据表:型号、类型、外壳、功耗、储存温度、电压、重量、热阻

  • 专利:标题、申请人、发明人、代表、申请日期、出版日期、申请编号、出版编号、优先权、分类、摘要第一行

观察到地面真实值的质量和可行性,我选择保留以下索引:

elements_to_extract = ['FilingDate', 'RepresentiveFL', 'Classification', 'PublicationDate','ApplicationNumber','Model','Voltage','StorageTemperature']

质量

对于图像转换为文本,使用了 ocropus 版本 0.2。这意味着它大约在 2014 年底发布。在数据科学领域这已经很古老了,那么地面真实度的质量是否符合我们的任务要求呢?

为此,我查看了一些随机图像,并将地面真实值与实际在文档上写的内容进行了比较。以下是两个 OCR 不正确的示例:

来自 Ghega 数据集的 document-001–109381.in.000.png

Classification 被设置为 BGSD 81/00 作为地面真实值。它应该是 B65D 81/100

来自 Ghega 数据集的 document-003–112107.in.000.png

StorageTemperature 显示 I -65 {O + 150 作为地面真实值,而我们可以看到它应该是 -65 to + 150

数据集中有许多此类错误。一种方法是纠正这些错误。另一种是忽略这些错误。由于我将使用相同的数据来比较两个模型,我选择了后者。如果数据用于生产,你可能需要选择前一种方法以获得最佳结果。

(还要注意,这些特殊字符可能会搞乱 JSON 格式,稍后我会回到这个话题)

Donut 数据集结构

我们需要的数据格式是什么样的?

对于微调 Donut 模型,我们需要将数据组织在一个文件夹中,所有文档作为单独的图像文件和一个元数据文件,结构为 JSON lines 文件。

donut-dataset
    document-000-123542.in.000.png
    document-001-123663.in.000.png
    ...
    metadata.jsonl

JSONL 文件包含每个图像文件一行,格式如下:

{"file_name": "document-010-100333.in.000.png", "ground_truth": "{\"gt_parse\": { \"DocType\": \"patent\", \"FilingDate\": \"06.12.1999\", \"RepresentiveFL\": \"Manresa Val, Manuel\", \"Classification\": \"A47l. 5/28\", \"PublicationDate\": \"1139845\", \"ApplicationNumber\": \"99959528 .3\" } }"}

让我们分解这行 JSON。在上层我们有一个包含两个元素的字典:file_nameground_truth。在 ground_truth 键下,我们有一个包含 gt_parse 键的字典。其值本身是一个字典,包含我们在文档中知道的键值对。或者更好:assign。记住,文档中不一定会出现文档类型。术语 datasheet 并没有作为文本出现在这些文档中。

幸运的是,pix2struct 使用相同的格式进行微调,因此我们可以一举两得。一旦我们将其转换为这种结构,我们还可以用来微调 Pix2Struct。

转换

对于转换本身,我在 colab 上创建了一个 Jupyter notebook。我决定在这个阶段将数据拆分为训练集和验证集,而不是在微调之前。这种方式,两个模型将使用相同的验证图像,结果会更具可比性。五个文档中会有一个用于验证。

利用上述 Ghega 数据集的结构知识,我们可以将转换过程概括如下:

对于每个以 in.000.png 结尾的文件名,我们取对应的 groundtruth 文件并创建一个临时的数据框对象。

注意,groundtruth 可能为空或完全不存在。(例如,对于 datasheets/taiwan-switching

接下来,我们从子文件夹中扣除类:patentdatasheet 。现在我们需要构建 JSON 行。对于每个我们想提取的元素/索引,我们检查它是否在数据框中并进行收集。然后复制图像本身。

对所有图像执行此操作,最后我们就有一个 JSONL 文件可以写出。

在 Python 中,它看起来是这样的:

json_lines_train = ''
json_lines_val = ''

for dirpath, dirnames, filenames in os.walk('/content/ghega-dataset/'):
    for filename in filenames:
        if filename.endswith('in.000.png'):
          gt_filename = filename.replace('in.000.png','groundtruth.csv')
          gt_filename_path = os.path.join(dirpath, gt_filename)
          if not os.path.exists(gt_filename_path):    #ignore files in /ghega-dataset/datasheets/taiwan-switching/ because no groundtruth exists
            continue
          if os.path.getsize(gt_filename_path) == 0:  #ignore empty groundtruth files
            print(f'skipped {gt_filename_path} because no info in metadata')
            continue
          doc_df = pd.read_csv(gt_filename_path, header=None)
          #find the doctype, based on path
          if 'patent' in dirpath:
            type = 'patent'
          else:
            type = 'datasheet'
          #create json line
          #eg:
          #{"file_name": "document-034-127420.in.000.png", "ground_truth": "{\"gt_parse\": { \"DocType\": \"datasheet\", \"Model\": \"ZMM5221 B - ZMM5267B\", \"Voltage\": \"1.5\", \"StorageTemperature\": \"-65 to 175\" } }"}
          p2 = ''
          #add always first element: DocType
          p2 += '\\"' + 'DocType' + '\\": '
          p2 += '\\"' + type + '\\"'
          new_row = {'ImagePath': os.path.join(dirpath, filename), 'DocType' :type}
          ghega_df = pd.concat([ghega_df, pd.DataFrame([new_row])], ignore_index=True)
          #fill other elements if available
          for element in elements_to_extract:
            value = doc_df[doc_df[0] == element][12].tolist()
            if len(value) > 0:
              p2 += ', '
              p2 += '\\"' + element + '\\": '
              value = re.sub(r'[^A-Za-z0-9 ,.()/-]+', '', value[0])   #get rid of \ of ” and " in json
              p2 += '\\"' + value + '\\"'
              new_row = {'ImagePath': os.path.join(dirpath, filename), element :value}
              ghega_df = pd.concat([ghega_df, pd.DataFrame([new_row])], ignore_index=True)

          p3 = ' } }"}'

          json_line = p1 + p2 + p3
          print(json_line)

          #take ~20% to validation
          #copy image file and append json line
          if random.randint(1, 100) < 20:
            output_path = '/content/dataset/validation/'
            json_lines_val += json_line + '\r\n'
            shutil.copy(os.path.join(dirpath, filename), '/content/dataset/validation/')  
          else:
            output_path = '/content/dataset/train/'
            json_lines_train += json_line + '\r\n'
            shutil.copy(os.path.join(dirpath, filename), '/content/dataset/train/')  

#write jsonl files
text_file = open('/content/dataset/train/metadata.jsonl', "w")
text_file.write(json_lines_train)
text_file.close()
text_file = open('/content/dataset/validation/metadata.jsonl', "w")
text_file.write(json_lines_val)
text_file.close()

ghega_df 是一个数据框,用于进行一些合理性检查或统计分析(如有需要)。我用它来检查随机样本,验证我的转换数据是否正确。

问题

转换完成后,一切看起来都很顺利。但我想摆脱那种通常第一次尝试就能成功的想法。总是会有一些小的意外问题发生。谈论我遇到的错误并展示解决方案对任何模拟整个过程并使用自己数据集的人都是有用的。

例如,在转换数据集后,我想训练 Donut 模型。在此之前,我需要创建一个训练数据集,如下所示:

train_dataset = DonutDataset("/content/dataset", max_length=max_length,
                             split="train", task_start_token="<s_cord-v2>", prompt_end_token="<s_cord-v2>",
                             sort_json_key=False, # dataset is preprocessed, so no need for this
                             )

并且出现了这个错误:

---------------------------------------------------------------------------
ArrowInvalid                              Traceback (most recent call last)
<ipython-input-13-7726ec2b0341> in <cell line: 7>()
      5 processor.feature_extractor.do_align_long_axis = False
      6 
----> 7 train_dataset = DonutDataset("/content/dataset", max_length=max_length,
      8                              split="train", task_start_token="<s_cord-v2>", prompt_end_token="<s_cord-v2>",
      9                              sort_json_key=False, # cord dataset is preprocessed, so no need for this

ArrowInvalid: JSON parse error: Missing a comma or '}' after an object member. in row 7

看起来第 7 行的 JSON 格式有问题。我复制了那一行并将其粘贴到一个 在线 JSON 验证器 中:

作者提供的图像

作者提供的图像

作者提供的图像

然而,它表示这是一个有效的 JSON 行。让我们更深入地看看:

{
   "file_name":"document-012-108498.in.000.png",
   "ground_truth":"{\"gt_parse\": {\"DocType\": \"patent\"\"FilingDate\": \"15\. Januar 2004 (15.01.2004)\",\"Classification\": \"BOZC 18/08,\",\"PublicationDate\": \"5\. August 2004 (05.08.2004)\",\"ApplicationNumber\": \"PCT/AT2004/000006\"} }"
}

你发现错误了吗?经过一段时间,我注意到 DocTypeFilingDate 之间缺少逗号。然而,这在所有行中都是缺失的,所以我不清楚为什么第 7 行会出现问题。当我修复了这个问题后,我再次尝试,现在它声称第 17 行有问题:

ArrowInvalid: JSON parse error: Missing a comma or '}' after an object member. in row 17

这是第 17 行,你发现了问题吗?

{"file_name": "document-007-103668.in.000.png", "ground_truth": "{\"gt_parse\": {\"DocType\": \"patent\",\"FilingDate\": \"18.12.2008\",\"RepresentiveFL\": \"Schubert, Siegmar\",\"Classification\": \"A47J 31/42 (2""6·"')\",\"PublicationDate\": \"12.08.2009\",\"ApplicationNumber\": \"08021980.1\"} }"}

这是Classification 元素的未转义引号。为了解决这个问题,我决定所有值只能包含字母数字字符和一些特殊字符,并使用了这个正则表达式:

[^A-Za-z0-9 ,.()/-]+

这可能会严重影响真实性能,但从我所见,其他字符都是由于 OCR 错误引起的。我认为,对于模型之间的相对比较,忽略这些字符影响不大。

数据准备:完成

数据准备的重要性常被忽视且被低估。通过上述步骤,我展示了如何调整自己的数据,以便 Donut 和 Pix2Struct 用于文档的关键索引提取。常见的陷阱也得到了修正。包含所有步骤的 Jupyter 笔记本可以在这里找到。我们已经完成了一半。下一步是用这个数据集训练这两个模型。我非常好奇它们的表现如何,但比较和训练将留到下一篇文章中。

你可能还喜欢:

[## 实战:使用🍩变换器进行文档数据提取

我使用 Donut 变换器模型提取发票索引的经验。

toon-beerten.medium.com](https://toon-beerten.medium.com/hands-on-document-data-extraction-with-transformer-7130df3b6132?source=post_page-----b5a826bc2ac3--------------------------------)

参考文献:

[## 无 OCR 文档理解变换器

理解文档图像(例如,发票)是一项核心但具有挑战性的任务,因为它需要复杂的功能…

arxiv.org](https://arxiv.org/abs/2111.15664?source=post_page-----b5a826bc2ac3--------------------------------) [## Pix2Struct:作为视觉语言理解预训练的截图解析

视觉位置语言无处不在——来源包括带有图表的教科书到包含图像的网页…

arxiv.org](https://arxiv.org/abs/2210.03347?source=post_page-----b5a826bc2ac3--------------------------------) [## 机器学习实验室 - Ghega 数据集

Ghega 数据集:用于文档理解和分类的数据集,我们提供了一个标注数据集,可以…

machinelearning.inginf.units.it](https://machinelearning.inginf.units.it/data-and-tools/ghega-dataset?source=post_page-----b5a826bc2ac3--------------------------------) [## to-be/donut-base-finetuned-invoices · Hugging Face

编辑模型卡 基于 Donut 基础模型(在论文《无 OCR 文档理解变换器》中介绍)…

huggingface.co](https://huggingface.co/to-be/donut-base-finetuned-invoices?source=post_page-----b5a826bc2ac3--------------------------------)

无 OCR 文档数据提取与变换器(2/2)

原文:towardsdatascience.com/ocr-free-document-data-extraction-with-transformers-2-2-38ce26f41951?source=collection_archive---------1-----------------------#2023-08-10

Donut 与 Pix2Struct 在自定义数据上的对比

Toon BeertenTowards Data Science Toon Beerten

·

关注 发表在 Towards Data Science ·7 分钟阅读·2023 年 8 月 10 日

--

图片由作者提供(使用

这两种变换器模型对文档的理解如何?在第二部分中,我将展示如何训练它们并比较它们在关键索引提取任务中的结果。

调整 Donut 模型

所以让我们从 第一部分 开始,在那里我解释了如何准备自定义数据。我将数据集的两个文件夹打包并上传到一个新的 huggingface 数据集 这里。我使用的 Colab 笔记本可以在 这里 找到。它将下载数据集,设置环境,加载 Donut 模型并进行训练。

在微调了 75 分钟后,我在验证指标(即编辑距离)达到 0.116 时停止了:

作者提供的图像

在字段级别,我得到这些验证集结果:

作者提供的图像

当我们查看Doctype时,我们发现 Donut 总是正确地将文档识别为专利数据表。因此,我们可以说分类达到了 100% 的准确率。同样需要注意的是,即使我们有一个类别数据表,它也不需要文档上出现这个确切的词来进行分类。对于 Donut 来说,这并不重要,因为它经过微调以这样识别。

其他领域的得分也相当不错,但仅凭这张图表很难了解内部情况。我想看看模型在特定情况下的正确与错误之处。因此,我在我的笔记本中创建了一个例行程序来生成 HTML 格式的报告表。对于我的验证集中的每个文档,我都有这样的行条目:

作者提供的图像

左侧是识别(推断)数据及其真实值。右侧是图像。我还使用了颜色代码以便快速概览:

作者提供的图像

理想情况下,一切都应该用绿色突出显示。如果你想查看验证集的完整报告,可以在 这里 查看,或者本地下载这个 zip 文件

有了这些信息,我们可以发现常见的 OCR 错误,如Dczcmbci(应为December)或GL420(应为GL420,0 和 O 难以区分),这些错误会导致假阳性。

现在让我们关注表现最差的字段:电压。以下是推断数据、真实值和实际相关文档片段的一些样本。

作者提供的图像

问题在于真实值大多是错误的。是否包括单位(Volt 或 V)没有标准。有时会包含无关文本,有时只是一个(错误的!)数字。我现在明白为什么 Donut 会对此感到困难。

作者提供的图像

上面是一些 Donut 实际上给出最佳答案的样本,而实际情况是不完整或错误的。

作者提供的图像

上面是另一个糟糕训练数据混淆 Donut 的好例子。地面真实值中的‘I’字母是 OCR 读取信息前的垂直线的伪影。有时它存在,有时不存在。如果你对数据进行预处理,使其在这方面一致,Donut 将会学习并遵循这种结构。

微调 Pix2Struct

Donut 的结果保持稳定,Pix2Struct 的呢?我用来训练的 Colab 笔记本可以在这里找到。

经过 75 分钟的训练,我得到的编辑距离分数为 0.197,而 Donut 的为 0.116。这显然收敛速度较慢。

另一个观察结果是,到目前为止,每个返回的值都以一个空格开头。这可能是 ImageCaptioningDataset 类中的错误,但我没有进一步调查根本原因。不过,我在生成验证结果时会去掉这个空格。

Prediction: <s_DocType> datasheet</s_DocType></s_DocType> TSZU52C2 – TSZUZUZC39<s_DocType>
    Answer: <s_DocType>datasheet</s_DocType><s_Model>Tszuszcz</s_Model><s_Voltage>O9</s_Voltage>

在 2 小时后我停止了微调过程,因为验证指标再次上升:

作者提供的图像

那么这对验证集的字段级别意味着什么呢?

作者提供的图像

这看起来比 Donut 的结果差得多!如果你想查看完整的 HTML 报告,可以在这里查看,或者在本地下载这个 zip 文件

只有在数据表专利之间的分类似乎还不错(但不如 Donut)。其他字段则完全不佳。我们能推断发生了什么吗?

对于专利文档,我看到很多橙色线条,这意味着 Pix2Struct 根本没有返回这些字段。

作者提供的图像

即使在专利中返回字段,它们也完全是虚构的。而 Donut 的错误源于从文档的其他区域提取或有轻微的 OCR 错误,Pix2Struct 在这里则是出现了幻觉。

对 Pix2Struct 的表现感到失望,我尝试了几次新的训练以期获得更好的结果:

作者提供的图像

我尝试将 accumulate_grad_batches 逐渐从 8 降到 1。但这样学习率过高,会导致超调。将其降低到 1e-5 会使模型无法收敛。其他组合则导致模型崩溃。即使在一些特定的超参数下,验证指标看起来相当不错,但它给出了很多不正确或无法解析的行,例如:

<s_DocType> datasheet</s_DocType><s_Model> CMPZSM</s_Model><s_StorageTemperature> -0.9</s_Voltage><s_StorageTemperature> -051c 150</s_StorageTemperature>

这些尝试都没有给我带来实质性的更好结果,所以我就此停止了。

直到我看到 huggingface 实现中的交叉注意力 bug被修复。因此,我决定再试一次。训练了两个半小时,停在 0.1416 的验证指标上。

作者提供的图片

作者提供的图片

这显然比所有之前的结果都要好。查看 HTML 报告,现在似乎幻觉更少。总体来说,它的表现仍不如 Donut。

至于原因,我有两个理论。首先,Pix2Struct 主要在 HTML 网页图像上训练(预测掩码图像部分后面的内容),并且在切换到另一个领域,即原始文本时,遇到了困难。其次,使用的数据集非常具有挑战性。它包含了许多 OCR 错误和不一致(如包含单位、长度、负号)。在我的其他实验中,我真的发现数据集的质量和一致性比数量更重要。在这个数据集中,数据质量真的很差。也许这就是我无法复制论文中声称 Pix2Struct 超越 Donut 表现的原因。

推断速度

这两种模型在速度方面如何比较?所有训练都在相同的 T4 架构上进行,因此时间可以直接比较。我们已经看到 Pix2Struct 收敛所需的时间要长得多。那么推断时间呢?我们可以比较推断验证集所需的时间:

作者提供的图片

Donut 每个文档提取的平均时间为 1.3 秒,而 Pix2Struct 则超过两倍。

要点

  • 对我来说,最终的赢家是 Donut。在易用性、性能、训练稳定性和速度方面。

  • Pix2Struct 训练具有挑战性,因为它对训练超参数非常敏感。它收敛较慢,并且在这个数据集上没有达到 Donut 的结果。可能值得重新考虑使用更高质量的数据集来尝试 Pix2Struct。

  • 由于 Ghega 数据集包含太多不一致性,我将避免在进一步实验中使用它。

是否有其他替代模型?

  • Dessurt,似乎与 Donut 有相似的架构,应该表现相当。

  • DocParser,论文称其表现甚至更好。不幸的是,目前没有计划将该模型发布到未来。

  • mPLUG-DocOwl将很快发布,这是另一个有前景的无 OCR LLM 文档理解工具。

你可能还会喜欢:

[## 实战:使用🍩变压器进行文档数据提取

我使用甜甜圈转换器模型来提取发票索引的经验。

toon-beerten.medium.com

参考文献:

[## Pix2Struct: 截图解析作为视觉语言理解的预训练

视觉定位语言无处不在——来源从带有图表的教科书到带有图像的网页等。

arxiv.org [## 无 OCR 文档理解转换器

理解文档图像(例如发票)是一项核心但具有挑战性的任务,因为它需要复杂的功能,比如……

arxiv.org [## GitHub - Toon-nooT/notebooks

通过在 GitHub 上创建帐户来为 Toon-nooT/notebooks 的开发做贡献。

github.com

哦,你是说“管理变革”?

原文:towardsdatascience.com/oh-you-meant-manage-change-bc9639affab5?source=collection_archive---------7-----------------------#2023-10-20

数据组织中对变革的不同视角

Marc DelbaereTowards Data Science Marc Delbaere

·

关注 发表在 Towards Data Science ·7 分钟阅读·2023 年 10 月 20 日

--

作者在布鲁塞尔的 Menssa 餐厅拍摄

变革的不同酿制方式

[场景:现代办公室的休息室。咖啡机的嗡嗡声是唯一的声音,空气中弥漫着新鲜咖啡的香气。CDO Alex 站在咖啡机旁,倒了一杯咖啡。数据工程师 Jamie 走了进来,看起来有些疲惫。]

  • Jamie:“又一天,又一个挑战。你知道,Alex,管理变革开始让我感到疲惫不堪。”

  • Alex(点头):“绝对是这样,Jamie。变革管理现在是我最优先的任务。我们必须确保自己在适应并保持领先。”

  • Jamie(扬起眉毛):“保持领先?我只是尽力让每次变化时事情不会崩溃。”

  • 亚历克斯:“确切地说,这就是预测这些变化并保持领先的关键。我们必须保持团队的动力和一致性。”

  • 杰米(困惑,但尝试认同):“是的,对齐并且……不落后。明白了。”

[谈话流转到其他话题,但他们对“变革管理”的观点差异依然未被说出或承认。]

那么,让我们来分析一下亚历克斯和杰米之间刚刚发生了什么。他们都提到了变革管理这个词,但他们就像在说不同的语言。

我们的首席数据官亚历克斯有远大的目标。她在监测市场变化、新兴技术,并设想公司在未来几年的发展方向。然而,制定战略是简单的,复杂的是让每个人达成共识。

引入一个新工具?她得准备好接受翻白眼和“又一个要学习的软件”的抱怨。一个新流程?准备好接受“但我们一直这么做”的合唱。对亚历克斯来说,变革管理就像走钢丝——平衡公司需要走的方向,同时确保每个人都支持,并且不对他们的工作安全感到恐慌。

然后是杰米。他的变革管理并不是关于未来几年的,而是现在。那个刚刚坏掉的管道?这是他的问题。销售报告中的差异?他的责任。最难的往往不是技术细节,而是人际因素。比如有人忘记告诉他一个微小的“无关紧要”的数据变化,导致一切陷入混乱。或者当任务出现问题时,指责游戏开始。对杰米来说,变革管理就是让事情今天顺利进行,并处理任何突发的问题。

战略视角:首席数据官的愿景

我经常与首席数据官互动,这种对话的多样性是我工作中真正让我欣赏的一方面。每一次对话都是不同的,带来独特的视角。然而,不知为何——也许是因为这些话题我非常关注,或者也许这里才是真正的行动所在——某些共同的主题不可避免地浮现出来。

首先,强调的是推动业务价值。这不仅仅是收集数据或实施最新的技术;关键在于将数据转化为可操作的洞察。确保每一个数据驱动的举措都与公司的目标相一致,无论是增加销售、提升客户满意度还是优化运营。

接下来是对效率的追求。首席数据官(CDO)经常面临改善运营、消除重复工作以及确保数据及时到达所需地点的任务。这不是轻松的工作;它涉及拆除旧有障碍、鼓励团队合作以及跟上新的技术解决方案。

许多首席数据官(CDO)正在倾向于去中心化数据网格的概念。这是从传统的中心化数据团队转变为一个模型的重大变化,在这个模型中,领域团队拥有、生成并提供他们的数据作为产品。这里的思维过程既简单又具有革命性:那些对数据最了解的人应该负责打包和维护数据。这不仅能确保更好的数据质量,还能培养自我消费的文化,赋予组织不同部分更多的自主权。

达成这些目标非常困难。每一个战略目标都带来了变更管理的问题,像亚历克斯这样的 CDO 必须直接面对这些问题。

比如以业务价值为首的议程。对于那些已经在技术任务中工作了多年,甚至几十年的数据专业人士来说,转移关注到业务成果上可能会让人感到不适。他们已经被训练成以数据准确性、系统集成和代码优化为思考的方式。要求他们“以业务价值为思考方式”常常会遇到困惑的目光!

还有去中心化的趋势,这在纸面上无疑是一个好主意:赋予团队权力,让他们承担责任,组织变得更加敏捷和高效。实际上,这意味着大量的变化需要被管理。去中心化带来了明确角色和责任的挑战。当每个人都是利益相关者时,任务很容易被忽视。谁负责数据质量?谁确保数据对需要的人是可访问的?没有明确的界定,问题会被遗漏,责备游戏就会开始。

从本质上讲,对于每一个战略转变,都存在着一个隐含的变更管理复杂网络。这不仅仅是绘制路线图,更要确保每个人都理解自己的角色,具备执行任务的能力,并且致力于前进的旅程。

实际情况:日常挑战

虽然亚历克斯作为 CDO 的角色主要关注大局,驾驭广泛而不可预测的情境,但变更管理还有另一面。这体现在像杰米这样的数据工程师面临的日常详细挑战中。在他们的领域中,变更管理不是关于长期战略或总体业务目标。相反,它关注于确保数据在不断变化的背景下保持一致和可访问的持续、每时每刻的障碍。

首先,组织中的大部分数据是作为副产品产生的。随着各种业务活动的展开,数据自然地积累,就像机器的废气一样。然而,虽然这些数据对于生成它的人来说可能只是副产品,但对于数据团队及其内部和外部的下游客户来说,这些数据成为了他们日常运营的核心。讽刺的是,在源头,这种‘废气’往往被忽视,尽管它对于链条上的这些利益相关者来说是不可或缺的。

可以将其视为在不稳定的地面上为高楼大厦奠定基础。地球下方总在移动,但你的任务是确保上面的庞大结构保持稳定。这是许多数据工程师和商业智能(BI)分析师的世界。他们站在前线,每天处理数据的异常行为。

他们面临的一个重大问题是交织在数据世界中的复杂依赖网络。数据从一个平台移动到另一个平台,经历转化,与其他数据集合并,最后到达预期的位置。这个过程的每个阶段都可能出现故障。在一个平台上进行的小调整可能会产生连锁反应,导致其他地方的干扰。而最具挑战性的一点是?通常,做出这些更改的人对他们可能引发的连锁反应毫无察觉。

面对不断变化的数据链带来的日常问题,思想领袖们提出了新的概念。他们首先引入了数据产品,这些数据产品打包数据以便于消费,类似于商店老板将商品展示给顾客。但随着更多人开始使用这些数据产品,出现了对正式承诺的需求——一种确保数据产品所有者可靠服务其用户的方式。这一认识促使了数据合同的制定,以确保这些义务得到履行。

数据合同作为打包数据以供使用者使用的人员与他们服务的消费者之间的桥梁,记录和执行明确的承诺:数据模式的不可变性、质量和可用性的标准等。它们优雅地解决了在数据依赖链中管理变化的问题。

弥合分歧:统一的变革视角

战略转型的挑战和管理不稳定依赖的挑战乍一看似乎大相径庭。像 Alex 这样的首席数据官(CDO)负责指导整体战略,并使整个组织朝着共同愿景对齐。同时,Jamie 是处理日常数据挑战的“消防员”。然而,在这两种观点的核心,有一些统一的原则可以弥合分歧。

透明度至关重要。无论是 Alex 沟通战略举措的广泛目标,还是 Jamie 标记架构变更可能带来的下游影响,清晰和开放的沟通可以预防许多问题。

协作确保一致性。数据组织中的每个人都需要保持同步。在日常层面,这意味着有效沟通以防止意外的麻烦。在战略层面,则是确保每个人对广泛目标有清晰认识,确保日常任务和整体计划朝着相同的方向推进。

标准化提供了稳定性。引入像数据合同这样的实践不仅解决了 Jamie 面临的细节挑战,还巩固了 Alex 战略愿景建立的基础。通过建立明确的标准,我们消除了模糊性,使得既有大局观的思考者也有注重细节的执行者能够协同朝同一方向前进。

最终的讽刺在于,为了解决 Jamie 的日常问题(即不透明依赖链中的意外后果),你需要将这个话题提升为战略优先事项。

如果你想取得成功,鉴于过程、人力和技术的变革,你需要应用所有由 Alex 倡导的变革管理良好原则。当然,Jamie 在这里扮演着至关重要的角色,他是最接近问题及其后果的人,因此他可以成为变革推动者,让他的同事和管理层参与其中。

所以,最初的互惠互利实际上可能是战略转型的开始:明确的行动理由、路线图、合适的领导者和工具。

即使是管理变革,你也需要变革管理!

好的,你已经训练了最好的机器学习模型。接下来做什么?

原文:towardsdatascience.com/okay-youve-trained-the-best-machine-learning-model-what-s-next-e7b8f167006e

数据科学

一个超越 Jupyter Notebook 建模的 MLOps 项目

Albers UzilaTowards Data Science Albers Uzila

·发表于 Towards Data Science ·阅读时间 18 分钟·2023 年 6 月 4 日

--

图片来源:Elena MozhviloUnsplash

***Table of Contents***
**·** **Initialize a Repository**
**·** **Migrate Your Codebase**
  ∘ config/config.py
  ∘ config/args.json
  ∘ tagolym/utils.py
  ∘ tagolym/data.py
  ∘ tagolym/train.py
  ∘ tagolym/predict.py
  ∘ tagolym/evaluate.py
  ∘ tagolym/main.py
**·** **Package Your Codebase** **·** **Setup Data Source Credential** **·** **Run Your Pipeline** **·** **Miscellaneous** **·** **Push Your Project to GitHub** **·** **Wrapping Up**

假设你正在构建一个数据科学项目,可能是为了工作、大学、作品集、爱好或其他任何目的。你已经花费了很多时间来解决问题陈述,并在 Jupyter notebooks 中进行实验。现在,你在想,“我怎么将我的工作部署成一个有用的产品?”。

具体来说,假设你有一个托管论坛的网站。用户可以给论坛中的线程添加标签,以方便在不同主题的论坛之间导航。你希望通过建议预定义的标签来改善用户体验,从而为讨论提供背景。

论坛可以是任何形式的,因此让我们更具体一点;它通常以一个 帖子 开始,解释一个数学问题,接着是围绕这个问题的想法、问题、提示或答案。以下是一个线程的样子及其三个标签,即 inductioncombinatorics unsolvedcombinatorics

论坛中的一个帖子示例 | 图片由 author 提供

此时,你已经在你的 notebooks 中完成了所有工作,从理解问题陈述、定义指标、查询数据、清理数据、预处理、EDA、构建模型到评估和优化模型。

你会注意到有很多帖子有着大量的标签。为了简化,你只筛选了 10 个标签。你开发的模型是简单的线性分类器(SVM、逻辑回归等),前面经过 TF-IDF 向量化,并用随机梯度下降(SGD)进行训练。

前 30 个频繁标签计数 | 图片由作者提供

最终的标签分布。注意几何标签是最常见的 | 图片由作者提供

虽然笔记本非常好,并且可以帮助你非常快速地进行实验,但它们并不适合生产环境,并且有时很难维护。因此,你需要将代码迁移到独立的 Python 文件中,然后逐步添加其他工具,同时与团队成员合作。

这个故事将引导你通过简明的步骤完成这项工作。在此之前,你可能想要刷新一下关于线性模型、TF-IDF 和 SGD 的知识:

## 线性回归、逻辑回归和 SVM 在 10 分钟内

线性回归与逻辑回归和支持向量机有什么关系?

towardsdatascience.com ## 你需要了解的词袋模型和 Word2Vec — 文本特征提取

为什么 Word2Vec 更好,但为什么它还不够好

towardsdatascience.com ## 从头开始的完整梯度下降算法步骤

以及其对常数学习率和线性搜索的实现

towardsdatascience.com

初始化一个仓库

首先,让我们在GitHub上创建一个名为tagolym-ml的新仓库,并配有README.md.gitignoreLICENSE

创建新的 GitHub 仓库 | 图片来源 author

要使用这个代码库,请执行以下步骤:

  1. 克隆代码库,将创建一个名为tagolym-ml的文件夹。

  2. 将工作目录更改为此文件夹。

  3. 创建一个名为venv的虚拟环境。

  4. 激活环境。

  5. 升级pip

  6. 可选地,你可以使用pip list检查当前环境中已安装的包,其中会有pipsetuptools

  7. 创建一个名为code_migration的新 git 分支并切换到它。

  8. 创建一个setup.py文件。

  9. 创建一些名为configtagolymcredentials的新文件夹。

  10. config文件夹内创建config.pyargs.json文件。

  11. tagolym文件夹内创建main.pyutils.pydata.pytrain.pyevaluate.pypredict.py文件。

如果你不知道如何做这些,不用担心。这里是你可以在喜欢的终端上运行的所有命令:

$ git clone https://github.com/dwiuzila/tagolym-ml.git
$ cd tagolym-ml
$ python3 -m venv venv
$ source venv/bin/activate
$ python3 -m pip install --upgrade pip
$ pip list
Package    Version
---------- -------
pip        23.1.2
setuptools 58.0.4
$ git checkout -b code_migration
$ touch setup.py
$ mkdir config tagolym credentials
$ touch config/config.py config/args.json
$ cd tagolym
$ touch main.py utils.py data.py train.py evaluate.py predict.py
$ cd ..

你现在有一个本地 git 仓库,已连接到 GitHub 上的远程仓库。当地仓库的目录将如下所示。

config/
├── args.json        - preprocessing/training parameters
└── config.py        - configuration setup
credentials/         - keys and passwords
tagolym/
├── data.py          - data processing components
├── evaluate.py      - evaluation components
├── main.py          - training/optimization pipelines
├── predict.py       - inference components
├── train.py         - training components
└── utils.py         - supplementary utilities
venv/                - virtual environment
.gitignore           - files/folders that git will ignore
LICENSE              - project license
README.md            - longform description of the project
setup.py             - code packaging

目前几乎所有这些文件都是空的。你将一个一个地填写它们,从config文件夹开始。

迁移你的代码库

你的项目中有两个主要文件夹,即configtagolym。你需要将笔记本中的必要代码复制到这些文件夹中的文件中。我们来做吧。

config/config.py

在这里,你定义了与种子、目录、实验跟踪、预处理和标签名称相关的全局变量。

当这个文件在你的代码中被导入时,如果尚未创建,它将创建两个新文件夹:

  1. data,用于存储项目的标记数据,

  2. stores/model,用于存储模型注册,

然后将stores/model连接到用于实验跟踪的 MLflow 跟踪 URI。

你还在这里定义了停用词和额外的命令词。停用词将默认为nltk包中的词汇,而命令词为["prove", "let", "find", "show", "given"],这些词经常出现在帖子中,但不提供任何有用的信号。

正则表达式用于预处理。它们看起来很吓人,但你不需要理解它们。它们的基本功能是捕捉任何数学表达式渐近线语法的 LaTeX,这些在数学问题的帖子中是基础和核心。

最后,记住你只选择了 10 个入围标签进行处理?你在这个文件中列出了所有这些标签。一些标签与您的标签有类似的含义(例如“inequalities” → “inequality”),因此你也有 10 个部分标签来捕获这些标签并用适当的标签替换它们。请参见下面的tagolym/data.py,了解如何操作。

config/args.json

这是你存储整个过程的初始参数的地方。它们来自管道的不同部分。

它们是什么意思?

  1. nocommandstem —— 处理帖子时的布尔值,是否排除命令词和实现词干提取

  2. ngram_max_range —— 在TF-IDF 向量化过程中提取不同n-gram 的n值范围的上限。

  3. lossl1_ratioalphalearning_rateeta0power_t —— 用于SGD 分类器的模型的超参数。

tagolym/utils.py

流水线有些复杂,因此你需要一些实用函数和 Python 类来简化代码。这个文件包含了这些:

  1. load_dictsave_dict —— 从 JSON 文件中加载字典,或将字典转储到 JSON 文件中。

  2. NumpyEncoder —— 将包含 Numpy 实例的对象编码为 Python 内置实例,用于save_dict

  3. IterativeStratification —— 当你处理像这个项目这样的多标签分类时,普通的训练-测试划分方法对于数据并不理想。相反,你需要我们所说的迭代分层,它旨在提供在给定阶数下标签关系证据的良好平衡分布。在这个项目中,阶数设置为 2。

tagolym/data.py

与数据相关的所有函数都写在这个文件中,包括数据分割、预处理和转换。

  1. preprocess —— 从包含部分标签的标签创建映射到config/config.py中定义的 10 个标签之一,然后对所有帖子和标签进行文本处理。这个函数还会在文本处理后删除所有空帖子样本。

  2. binarize — 根据模型要求,如果你正在处理多标签分类问题,你可能需要对标签进行二值化。此函数将标签转换为一个大小为(#样本 × #标签)的二进制矩阵,指示标签中标签的存在。例如,包含两个标签["algebra", "inequality"]的标签将被转换为[1, 0, 0, 0, 0, 1, 0, 0, 0, 0]。除了返回转换后的标签,它还返回稍后使用的[MultiLabelBinarizer](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MultiLabelBinarizer.html)对象,特别是在将矩阵转换回标签时。

  3. split_data — 使用tagolym/utils.py中的IterativeStratification,此函数将帖子和标签拆分为 3 部分,比例为 70/15/15,分别用于模型训练、验证和测试。

tagolym/train.py

最佳实践是将模型训练、验证和测试放在不同的文件中。正如文件名所示,你在这里进行所有的训练。由于你希望用户能自信地使用模型的标签推荐,你需要降低假阳性率。

另一方面,目前假阴性并不是你的首要任务。为了说明这一点,我们来看看一个极端的例子:模型将所有 10 个标签预测为负值,因此没有推荐标签,你会有大量的假阴性。但用户可以毫不犹豫地创建自己的标签。这没什么大不了的。

所以,你的目标是拥有一个高精度的模型。

现在,让我们讨论一下这个文件的内容:

  1. train — 预处理数据,将标签二值化,并使用tagolym/data.py中的函数拆分数据。然后,初始化一个模型,训练它,使用训练好的模型对所有三个数据拆分进行标签预测,并评估预测结果。这个函数接受args,其中包含config/args.json中的所有参数,返回时可能会添加一个额外的参数threshold。基本上,threshold是通过tune_threshold计算出的每个标签的最佳阈值列表。

  2. objective — f1 分数是超参数调整中选择的优化指标。使用试验中选择的args,此函数训练模型并返回验证集的 f1 分数。它还为试验设置了额外的属性,包括所有三个数据拆分的精确度、召回率和 f1 分数。

  3. tune_threshold — 二分类问题的默认决策边界是 0.5,但这可能并不是最优的,具体取决于问题。因此,除了调整args,你还需要在优化 f1 分数时调整每个标签的阈值。它的作用是尝试从 0 到 1 的网格中所有可能的阈值,并选择具有最大 f1 分数的阈值。

tagolym/predict.py

模型训练之后该做什么?预测!这个文件中有两个函数:

  1. custom_predict — 如果模型具有 predict_proba 属性,则此函数将预测每个标签作为标签的概率。否则,使用 0.5 阈值直接预测标签。在前一种情况下,如果提供了真实标签,函数将使用 tagolym/train.py 中的 tune_threshold 来调整阈值。

  2. predict — 加载 args、标签二值化器和训练好的模型。然后,预处理给定的文本,并使用 custom_predict 对其进行预测。之后,将预测矩阵转换回标签。

tagolym/evaluate.py

给定预测和真实标签矩阵,本文件的目的是计算精度、召回率、F1 分数和样本数量。性能是根据总体样本、每类样本和每个切片样本计算的。你考虑了 8 个切片:

  1. 短帖,即经过预处理后少于 5 个单词的帖子,

  2. 六个切片,其中帖子被标记为子主题但未标记为覆盖子主题的更大主题,以及

  3. 不包含频繁出现的四字或更多字的帖子。

tagolym/main.py

这是运行所有任务的主要文件。这里有 5 个函数和你需要在其中编写的指令:

  1. elt_data — 查询标记数据并以 JSON 格式保存到data文件夹中。

  2. train_model — 从data文件夹加载标记数据并训练模型。不要忘记使用 MLflow 记录指标、工件和参数。还要将 MLflow run_id和指标保存到config文件夹中。

  3. optimize — 从data文件夹加载标记数据并优化给定的参数。为了提高搜索效率,优化分为两个步骤:a) 预处理、向量化和建模中的超参数;b) 学习算法中的超参数。还要根据目标将最佳参数保存到config文件夹中,命名为args_opt.json

  4. load_artifacts — 将特定 run_id 的工件加载到内存中,包括参数、指标、模型和标签二值化器。

  5. predict_tag — 给定特定的 run_id,使用预加载的工件预测接收到的每个文本的标签。

唷!你刚刚完成了所有迁移工作。现在,如何使用这些代码?

Jason Strull 提供的照片,来源于 Unsplash

打包你的代码库

当你使用笔记本时,你有一个预加载的包集合用于实验。为了在本地重现并部署到生产环境,你希望明确地定义你的环境。

你在这个项目中导入了许多开源包,但你的环境中只有pipsetuptools。因此,在运行管道之前,你需要安装这些包。

下面是一个方便的命令来实现这一点。注意,我在最后添加了 [pip-chill](https://pypi.org/project/pip-chill/) 以便于后续清理生成的包要求文件。

$ pip install mlflow nltk regex scikit-learn snorkel joblib optuna pandas google-cloud-bigquery google-auth numpy scipy pip-chill

pip-chill 的一个很酷的特点是,它可以生成一个不包含文件中依赖于其他包的包的要求文件,使得要求文件干净且准确。让我们运行一下。

$ pip-chill --no-chill > requirements.txt

这将创建一个 requirements.txt 文件,包含你实际需要的所有包。请注意,因为这些包是已经列在文件中的包的依赖项,所以文件中没有 pandasscikit-learnregex 以及其他几个包。

现在你将使用 setup.py 打包你的代码库,将所有依赖项打包在一起。在这个文件中,加载你在 requirements.txt 中的所有库,并使用 setuptools 中的 setup 函数定义你的包。

你的包名将是 tagolym。你可以在下面的代码中看到其他细节,如版本和描述。你从 requirements.txt 中加载的库将用于 install_requires 参数,并成为 tagolym 的依赖项。

然后你可以使用下面的命令安装 tagolym。这将创建一个名为 tagolym.egg-info 的新文件夹,包含项目的元数据。

$ python3 -m pip install -e .

请注意,-e--editable 标志会从本地项目路径以可编辑模式安装包。换句话说,如果你在当前工作目录中使用一些函数,例如使用 from tagolym import main,然后对 tagolym/main.py 进行一些更改,你将能够使用这个更新版本,而无需使用 pip install 重新安装你的包。

设置数据源凭证

有一个小问题。这些项目中使用的数据是我自己的数据,存储在我的 BigQuery 中。在创建并下载一个 服务账户密钥 后,我将其重命名为 bigquery-key.json,并将其放置在 credentials 文件夹中。

要访问数据,你需要我的凭证,但不幸的是,这些凭证不能共享。不过不用担心,我会提供样本供你使用。

创建服务账户密钥 | 图片由 作者 提供

你需要做的很简单:下载样本 labeled_data.json 在这里 并将文件保存在工作目录中名为 data 的文件夹里。

运行你的管道

现在你准备好了!在终端中输入 python3 命令,你就可以运行 Python 中的所有内容。你只需使用 tagolym/main.py 文件。

首先,我使用我的凭证和 elt_data 函数查询数据。当我看到 ✅ Saved data! 时,我知道过程顺利完成。如上所述,你可以跳过这一步,手动将我提供的样本放入 data 文件夹中。

然后,您可以使用 optimize 函数来优化模型,通过读取初始参数 config/args.json。我将试验次数设置为 10,但您可以尝试其他设置。由于您有一个两步优化过程,所以将创建一个新的 MLflow 研究,总共 20 次试验。找到的最佳验证 f1 分数是 0.7730。

使用一组优化后的参数 config/args_opt.json,您可以再次使用 train_model 函数训练模型,并使用 predict_tag 函数对文本列表进行推断。您可以看到下面的预测非常准确!

$ python3
>>> from pathlib import Path
>>> from config import config
>>> from tagolym import main
>>>
>>> # query data
>>> key_path = "credentials/bigquery-key.json"
>>> main.elt_data(key_path)
✅ Saved data!
>>>
>>> # optimize model
>>> args_fp = Path(config.CONFIG_DIR, "args.json")
>>> main.optimize(args_fp, study_name="optimization", num_trials=10)
2023/06/03 17:42:12 INFO mlflow.tracking.fluent: Experiment with name 'optimization' does not exist. Creating a new experiment.
[I 2023-06-03 17:41:45,657] A new study created in memory with name: optimization
[I 2023-06-03 17:42:12,343] Trial 0 finished with value: 0.7519199358796977 and parameters: {'nocommand': False, 'stem': True, 'ngram_max': 2, 'loss': 'modified_huber', 'l1_ratio': 0.6011150117432088, 'alpha': 0.001331121608073689}. Best is trial 0 with value: 0.7519199358796977.
[I 2023-06-03 17:42:38,441] Trial 1 finished with value: 0.7629559140596291 and parameters: {'nocommand': False, 'stem': True, 'ngram_max': 2, 'loss': 'modified_huber', 'l1_ratio': 0.43194501864211576, 'alpha': 7.476312062252303e-05}. Best is trial 1 with value: 0.7629559140596291.
[I 2023-06-03 17:42:57,713] Trial 2 finished with value: 0.7511576441724478 and parameters: {'nocommand': True, 'stem': False, 'ngram_max': 3, 'loss': 'hinge', 'l1_ratio': 0.5924145688620425, 'alpha': 1.3783237455007187e-05}. Best is trial 1 with value: 0.7629559140596291.
[I 2023-06-03 17:43:19,108] Trial 3 finished with value: 0.7106573336158825 and parameters: {'nocommand': True, 'stem': False, 'ngram_max': 4, 'loss': 'hinge', 'l1_ratio': 0.6842330265121569, 'alpha': 0.00020914981329035596}. Best is trial 1 with value: 0.7629559140596291.
[I 2023-06-03 17:43:37,349] Trial 4 finished with value: 0.741392879377292 and parameters: {'nocommand': False, 'stem': False, 'ngram_max': 2, 'loss': 'hinge', 'l1_ratio': 0.5467102793432796, 'alpha': 3.585612610345396e-05}. Best is trial 1 with value: 0.7629559140596291.
[I 2023-06-03 17:44:04,235] Trial 5 finished with value: 0.7426444422157734 and parameters: {'nocommand': True, 'stem': True, 'ngram_max': 3, 'loss': 'hinge', 'l1_ratio': 0.045227288910538066, 'alpha': 9.46217535646148e-05}. Best is trial 1 with value: 0.7629559140596291.
[I 2023-06-03 17:44:30,104] Trial 6 finished with value: 0.7337258988967691 and parameters: {'nocommand': True, 'stem': True, 'ngram_max': 2, 'loss': 'modified_huber', 'l1_ratio': 0.07455064367977082, 'alpha': 0.009133995846860976}. Best is trial 1 with value: 0.7629559140596291.
[I 2023-06-03 17:44:51,778] Trial 7 finished with value: 0.7700323704566581 and parameters: {'nocommand': True, 'stem': False, 'ngram_max': 4, 'loss': 'log_loss', 'l1_ratio': 0.3584657285442726, 'alpha': 2.2264204303769678e-05}. Best is trial 7 with value: 0.7700323704566581.
[I 2023-06-03 17:45:18,125] Trial 8 finished with value: 0.7559495178348377 and parameters: {'nocommand': True, 'stem': True, 'ngram_max': 2, 'loss': 'log_loss', 'l1_ratio': 0.8872127425763265, 'alpha': 0.00026100256506134784}. Best is trial 7 with value: 0.7700323704566581.
[I 2023-06-03 17:45:47,029] Trial 9 finished with value: 0.7730089901544794 and parameters: {'nocommand': False, 'stem': True, 'ngram_max': 4, 'loss': 'log_loss', 'l1_ratio': 0.02541912674409519, 'alpha': 2.1070472806578224e-05}. Best is trial 9 with value: 0.7730089901544794.
[I 2023-06-03 17:45:47,056] A new study created in memory with name: optimization
[I 2023-06-03 17:46:16,061] Trial 0 finished with value: 0.7730089901544794 and parameters: {'learning_rate': 'optimal'}. Best is trial 0 with value: 0.7730089901544794.
[I 2023-06-03 17:46:48,008] Trial 1 finished with value: 0.7701884982320516 and parameters: {'learning_rate': 'adaptive', 'eta0': 0.15930522616241014}. Best is trial 0 with value: 0.7730089901544794.
[I 2023-06-03 17:47:18,651] Trial 2 finished with value: 0.7331091235928242 and parameters: {'learning_rate': 'invscaling', 'eta0': 0.0265875439832727, 'power_t': 0.17272998688284025}. Best is trial 0 with value: 0.7730089901544794.
[I 2023-06-03 17:47:49,429] Trial 3 finished with value: 0.7196639813595901 and parameters: {'learning_rate': 'invscaling', 'eta0': 0.038234752246751866, 'power_t': 0.34474115788895177}. Best is trial 0 with value: 0.7730089901544794.
[I 2023-06-03 17:48:21,601] Trial 4 finished with value: 0.7727673901952036 and parameters: {'learning_rate': 'adaptive', 'eta0': 0.3718364180573207}. Best is trial 0 with value: 0.7730089901544794.
[I 2023-06-03 17:48:51,330] Trial 5 finished with value: 0.7576010292654753 and parameters: {'learning_rate': 'invscaling', 'eta0': 0.16409286730647918, 'power_t': 0.16820964947491662}. Best is trial 0 with value: 0.7730089901544794.
[I 2023-06-03 17:49:21,906] Trial 6 finished with value: 0.7428637006524251 and parameters: {'learning_rate': 'invscaling', 'eta0': 0.040665633135147955, 'power_t': 0.13906884560255356}. Best is trial 0 with value: 0.7730089901544794.
[I 2023-06-03 17:49:52,034] Trial 7 finished with value: 0.746701310091385 and parameters: {'learning_rate': 'constant', 'eta0': 0.011715937392307063}. Best is trial 0 with value: 0.7730089901544794.
[I 2023-06-03 17:50:21,383] Trial 8 finished with value: 0.7683160697730758 and parameters: {'learning_rate': 'constant', 'eta0': 0.10968217207529521}. Best is trial 0 with value: 0.7730089901544794.
[I 2023-06-03 17:50:51,373] Trial 9 finished with value: 0.7338062675694838 and parameters: {'learning_rate': 'invscaling', 'eta0': 0.7568292060167615, 'power_t': 0.4579309401710595}. Best is trial 0 with value: 0.7730089901544794.
Best value (f1): 0.7730089901544794
Best hyperparameters: {
  "nocommand": false,
  "stem": true,
  "ngram_max": 4,
  "loss": "log_loss",
  "l1_ratio": 0.02541912674409519,
  "alpha": 2.1070472806578224e-05,
  "learning_rate": "invscaling",
  "eta0": 0.7568292060167615,
  "power_t": 0.4579309401710595,
  "threshold": [
    0.59,
    0.79,
    0.55,
    0.7000000000000001,
    0.5,
    0.72,
    0.76,
    0.63,
    0.7000000000000001,
    0.77
  ]
}
>>>
>>> # train model
>>> args_fp = Path(config.CONFIG_DIR, "args_opt.json")
>>> main.train_model(args_fp, experiment_name="baselines", run_name="sgd")
2023/06/03 17:52:01 INFO mlflow.tracking.fluent: Experiment with name 'baselines' does not exist. Creating a new experiment.
Run ID: fbdba0c7cab640bc853611ba6cd75cee
>>> text = [
...     "Let $c,d \geq 2$ be naturals. Let $\{a_n\}$ be the sequence satisfying $a_1 = c, a_{n+1} = a_n^d + c$ for $n = 1,2,\cdots$.Prove that for any $n \geq 2$, there exists a prime number $p$ such that $p|a_n$ and $p \not | a_i$ for $i = 1,2,\cdots n-1$.",
...     "Let $ABC$ be a triangle with circumcircle $\Gamma$ and incenter $I$ and let $M$ be the midpoint of $\overline{BC}$. The points $D$, $E$, $F$ are selected on sides $\overline{BC}$, $\overline{CA}$, $\overline{AB}$ such that $\overline{ID} \perp \overline{BC}$, $\overline{IE}\perp \overline{AI}$, and $\overline{IF}\perp \overline{AI}$. Suppose that the circumcircle of $\triangle AEF$ intersects $\Gamma$ at a point $X$ other than $A$. Prove that lines $XD$ and $AM$ meet on $\Gamma$.",
...     "Find all functions $f:(0,\infty)\rightarrow (0,\infty)$ such that for any $x,y\in (0,\infty)$, $$xf(x²)f(f(y)) + f(yf(x)) = f(xy) \left(f(f(x²)) + f(f(y²))\right).$$",
...     "Let $n$ be an even positive integer. We say that two different cells of a $n \times n$ board are [b]neighboring[/b] if they have a common side. Find the minimal number of cells on the $n \times n$ board that must be marked so that any cell (marked or not marked) has a marked neighboring cell."
... ]
>>> main.predict_tag(text=text)
[
  {
    "input_text": "Let $c,d \\geq 2$ be naturals. Let $\\{a_n\\}$ be the sequence satisfying $a_1 = c, a_{n+1} = a_n^d + c$ for $n = 1,2,\\cdots$.Prove that for any $n \\geq 2$, there exists a prime number $p$ such that $p|a_n$ and $p \not | a_i$ for $i = 1,2,\\cdots n-1$.",
    "predicted_tags": [
      "number theory"
    ]
  },
  {
    "input_text": "Let $ABC$ be a triangle with circumcircle $\\Gamma$ and incenter $I$ and let $M$ be the midpoint of $\\overline{BC}$. The points $D$, $E$, $F$ are selected on sides $\\overline{BC}$, $\\overline{CA}$, $\\overline{AB}$ such that $\\overline{ID} \\perp \\overline{BC}$, $\\overline{IE}\\perp \\overline{AI}$, and $\\overline{IF}\\perp \\overline{AI}$. Suppose that the circumcircle of $\triangle AEF$ intersects $\\Gamma$ at a point $X$ other than $A$. Prove that lines $XD$ and $AM$ meet on $\\Gamma$.",
    "predicted_tags": [
      "geometry"
    ]
  },
  {
    "input_text": "Find all functions $f:(0,\\infty)\rightarrow (0,\\infty)$ such that for any $x,y\\in (0,\\infty)$, $$xf(x²)f(f(y)) + f(yf(x)) = f(xy) \\left(f(f(x²)) + f(f(y²))\right).$$",
    "predicted_tags": [
      "algebra",
      "function"
    ]
  },
  {
    "input_text": "Let $n$ be an even positive integer. We say that two different cells of a $n \times n$ board are [b]neighboring[/b] if they have a common side. Find the minimal number of cells on the $n \times n$ board that must be marked so that any cell (marked or not marked) has a marked neighboring cell.",
    "predicted_tags": [
      "combinatorics"
    ]
  }
]
>>> exit()

您可以在美观的 MLflow UI 中查看您的实验:

$ mlflow ui --backend-store-uri stores/model

MLflow 用户界面 | 图片由 作者 提供

这些过程中的一些在后台创建了新的文件,大多数是模型训练的输出。您可以在下一节中解释的 README.md 文件中查看当前的项目目录。

杂项

您项目的安全性至关重要。因此,凭据应仅存在于本地仓库中;您不希望将其推送到 GitHub。

作为预防措施,在 .gitignore 文件末尾添加 credentials/。这将忽略您在开发项目时对 credentials 文件夹所做的任何更改。

其他您可能想要添加到 .gitignore 的内容包括 data/stores/,因为它们可能包含敏感信息或占用大量磁盘空间。如果您使用的是 macOS,还需添加 .DS_Store。这是一个存储其所在文件夹自定义属性的文件,对您的项目没有用处。

完成所有这些之后,您可以选择更新 README.md 中的项目描述。只需输入您在这个故事中完成的高层次过程,以便每个人都可以轻松复制您的工作。这可能看起来是这样的。

将您的项目推送到 GitHub

您的项目很酷,但它对其他人有用吗?要回答这个问题,您可以开源您的项目,以便每个人都可以从中受益,提供反馈,甚至贡献。这样做非常简单。

您需要的是下面的三个命令:

  1. 将您所做的每一项更改添加到 Git 索引中。

  2. 将索引中的更改提交到本地仓库,并

  3. 将本地仓库推送到远程,这将创建一个新的分支 code_migration 在远程仓库中。

$ git add .
$ git commit -m "Initial commit"
$ git push origin code_migration

您可以在 这里 查看结果。

了解更多关于 Git 的信息:

## 作为数据科学家使用 Git 命令的真实案例研究

完成分支说明

towardsdatascience.com

总结

恭喜你!🍻 你已经阅读完了这个故事。你学会了如何将你的数据科学实验从 Jupyter Notebook 转化为一个干净且可维护的项目。除此之外,你还知道了如何打包你的项目,运行整个数据管道,并使用 GitHub 和 BigQuery。

不过,这只是你 MLOps 旅程的开始。还有很长的路要走。敬请关注!📻

图片由 Matese FieldsUnsplash 提供

🔥 你好!如果你喜欢这个故事并想支持我作为一个作家,可以考虑 成为会员。每月只需 $5,你就可以无限制访问 Medium 上的所有故事。如果你通过我的链接注册,我将获得一小笔佣金。

🔖 想了解更多关于经典机器学习模型如何运作以及如何优化其参数的信息?或者 MLOps 大型项目的示例?还有精选的顶尖文章?继续阅读:

Albers Uzila

Albers Uzila

MLOps 大型项目 - 第二部分

查看列表3 篇故事!Albers Uzila

Albers Uzila

从零开始的机器学习

查看列表8 篇故事!Albers Uzila

Albers Uzila

高级优化方法

查看列表7 篇故事!Albers Uzila

Albers Uzila

我的最佳故事

查看列表10 个故事Albers Uzila

Albers Uzila

R 中的数据科学

查看列表7 个故事

关于 A/B 测试和携带效应

原文:towardsdatascience.com/on-ab-tests-and-carryover-effect-43668dbd52e2?source=collection_archive---------11-----------------------#2023-05-23

图片由 Ron Hansen 提供,来源于 Unsplash

Denis VorotyntsevTowards Data Science Denis Vorotyntsev

·

关注 发布于 Towards Data Science · 7 分钟阅读 · 2023 年 5 月 23 日

--

在复杂的数据驱动决策世界中,A/B 测试脱颖而出,成为一个强大的工具,帮助企业优化策略和改善用户体验。但当一个测试的效果渗透到下一个测试中时,会发生什么情况呢?这会使结果变得模糊不清,扭曲结果。

这种现象被称为“滞后效应”,可能对理解测试中变更的真实影响构成重大挑战。在本文中,我们将深入探讨 A/B 测试和滞后效应的细微差别,讨论有效管理这种现象的策略。我们将探索用户分组的机制、分桶技术以及如何识别和解决滞后效应,以确保你的 A/B 测试提供可靠的、可操作的结果。

用户与桶

AB 测试是比较两个版本功能的基础方法,通常用于确定哪一个表现更好。为了执行这些测试,我们通常将用户分成两个组——对照组和处理组,基于用户 ID。为了简化,我们可以将所有“偶数”用户分配给对照组,将所有“奇数”用户分配给处理组。

直观的 AB 测试设置:所有用户被分成对照组和处理组

初步步骤涉及估算样本大小——根据我们选择的指标(如点击率或每用户平均收入)来确定需要收集的用户或事件数量。这些估算考虑了方差(基于这些指标的历史观察)和预期效果(基于建议模型的离线结果)。在收集了足够的数据后,我们深入进行统计分析,如使用 t 检验比较对照组和处理组的平均收入,以确定表现更好的模型。

然而,这种方法面临几个障碍:

  1. 同时多重测试:同时启动多个测试成为挑战。例如,如果一个新模型需要测试,没有剩余流量来容纳这个测试。解决方案是暂停当前的 AB 测试,并将用户分成三组。但如果我们不知道未来将运行多少个变体呢?

  2. 可扩展性问题:处理广泛的用户基础时,扩展分析成为一项艰巨的任务。即使在用户级别缓存结果,对最近几天的 AB 测试进行统计计算也可能非常费力,尤其是在处理大量用户时。

为了绕过这些问题,我们采用了一种称为“分桶”的技术。

分桶方法

分桶将几个用户组合成一个称为“桶”的单元。你可以把这个桶看作是一个“元用户”。为了确定桶的 ID,我们使用以下公式:

bucket_id = hash(user_id + salt) % number_of_buckets

在这里,salt是一个固定的随机字符串,而number_of_buckets是系统的预定义参数。根据系统设计,桶的 ID 可以实时计算(当用户访问网站时)或在用户访问网站时计算一次并存储在键值存储中。

分桶理念:用户根据上述公式被分配到不同的桶中。处理是在用户级别应用的,但分析是在桶级别进行的。

在启动 AB 测试之前,我们估算要分配给对照组和处理组的桶数量。例如,如果总桶数为 1000,我们将桶 0–99 的用户分配给对照模型,将桶 100–199 的用户分配给处理模型。这为两个模型提供了 10%的流量。

分桶允许我们在桶级别分析结果,而不是用户级别,这使我们可以在桶级别缓存指标,从而消除了繁重的重新计算需求。

了解延续效应

想象一个场景,你正在为一个大型电子商务网站设计一个新的推荐模型,使用高级神经网络。这个模型包括一个创新的功能,即 OpenAI GPT3 API 调用,用于生成商品标题的嵌入。该模型的离线结果显示出显著的性能提升,因此决定进行在线测试。

AB 测试的结构是跨越一周,将网站流量的 10%分配给对照组,10%分配给处理组。目标是比较两组之间的点击数量,以确定哪个模型表现更好。

然而,在 AB 测试上线几小时后,所有指标出现了令人担忧的下降。深入分析数据发现,由于复杂的模型,页面加载时间显著增加。这是一个意料之外的问题,在离线测试中没有遇到,也未在在线测试中考虑到。

针对加载时间缓慢的问题,分配到处理组的用户变得沮丧,导致一些用户减少使用或完全流失。这一不幸事件扰乱了用户在桶中的分布平衡,这是新 AB 测试中未考虑的一个方面。

这种持续的不平衡,即初始测试的直接后果,被称为“延续效应”。它发生在同一组用户在多个测试中不断经历变化时。本质上,由于分桶分配中使用了一致的盐或种子,桶“记住”了之前的 AB 测试,从而影响了后续测试的结果。

当用户行为因先前的处理而改变时,延续效应变得特别明显。例如,如果正在测试的新功能需要用户学习曲线,处理组的成员可能由于早期接触而更快适应,从而在其他组用户中获得优势。

在大规模和成熟的系统中,即使是 1%的微小变化也可能意味着数百万的收入,因此这一效应变得极为重要。数据科学家和机器学习工程师通常力求在指标上获得 0.1%的提升。然而,即使是轻微的延续效应也可能使多个 AB 测试失效,从而导致一个次优的模型被采纳或一个优质的模型因 AB 分组偏差而被弃用。

识别问题

为识别这一问题,数据科学家应定期进行 AA 测试。设计良好的 AB 测试系统应在 AA 测试中产生均匀的 p 值分布。AA 测试中 p 值直方图的不均匀性表明桶存在不平衡,可能由各种因素造成,包括带来的影响。

应对问题的策略

重排所有用户

避免桶内存储问题的最快且最简单的解决方案是定期更换盐值。这能确保用户在每次重排后在桶内随机分布,从而打破与先前分割的关联。然而,这种方法在进行 AB 测试期间并不实际,因为它扰乱了对照组和处理组之间的用户分布,破坏了 AB 测试的独立同分布(i.i.d.)前提,从而使结果无效。

简单重排:更改盐值将导致用户在桶之间重新分配

当同时进行多个 AB 测试时,这种方法也会带来挑战,协调所有测试的终止可能很困难,任何延误都可能代价高昂。此外,持续时间长的负面测试无法停止,否则会失去 AB 测试的进展。

重排非 AB 测试用户

另一种替代方案是对不涉及任何 AB 测试的用户进行重排。采用这种方法,在每次完成 AB 测试后,那些不参与任何测试的用户将被重新分配到可用的桶中。

重排未参与 AB 测试的用户:在测试 2 结束后,来自桶 2、3、4 和 5 的用户被重排。测试 1 中的用户保持不变。在测试 1 结束后,所有桶中的用户都被重排。

尽管这种方法不需要停止所有 AB 测试,但其实施更为复杂。我们需要跟踪实验中的用户并存储用户-ID 与桶的映射,频繁更新——这在大型系统中可能比较棘手。

结论

处理 AB 测试的复杂性,从用户设置和桶管理到处理带来的影响,需要精心规划和策略处理。理解这些复杂性可以帮助确保测试提供有价值的、可操作的见解,并对你正在进行的开发工作产生积极的贡献。通过采用有效的解决方案来克服潜在的障碍,你可以优化测试过程,提高用户体验,并最终优化产品的成功。

进一步阅读

为了深入理解带来的影响和 AB 测试的其他细微差别,这里有一些有价值的资源供进一步阅读:

网页上的对照实验:调查与实用指南 对桶化设计进行了详细的探讨,这在执行对照网页实验中是一个关键要素。

避免 A/B 测试中的三个陷阱的分离策略 对分组设计进行了详细的讲解,分组设计是进行受控网络实验中的一个关键要素。

如何解读 p 值直方图 对 p 值直方图的解读进行了深入探讨,这在 AA 测试中至关重要。它有助于检测延续效应,使读者对 AB 测试的统计方面有更深入的理解。

关于人工智能与推理的类型

原文:towardsdatascience.com/on-ai-and-types-of-reasoning-fc6980295158?source=collection_archive---------3-----------------------#2023-01-20

人工智能如何做决策?

Jazmia HenryTowards Data Science Jazmia Henry

·

关注 发表在 Towards Data Science · 5 分钟阅读 · 2023 年 1 月 20 日

--

嗨,各位数据领域的朋友们!

我工作在算法领域的时间越长,我就越确信,算法只是人类让机器模仿我们思维方式的一种方法。

在任何给定的时刻,我们会接收 1100 万比特的信息,但仅处理其中的 40 到 50 个。我们已经进化到只关注对生存最有价值的信息。

在构建算法时,我们使用数据来进行预测或协助决策,其中一些特征对我们的分析更有价值或更有用。

处理我们数据的算法与处理我们周围世界的思维之间的区别在于理解上下文的能力,并在归纳推理(当我感到热时,我会出汗。因此,当未来温度高时,我将会出汗)、演绎推理(如果 A = B,B = C,那么 A = C)和演绎推理(我把食物放在有狗的房间的台面上。我回来后发现我的食物不见了,而我的狗看起来很内疚。我的狗一定吃了我的食物)之间轻松切换的能力。算法可以被可靠地训练来执行所有这些类型的推理,但无法像人类一样同时可靠地进行所有这些推理。

归纳推理

归纳推理

归纳推理遵循特定的路径。它从特定的观察开始(观察到的树上的叶子是绿色的),注意到一个模式(我面前的这群树都有绿色的叶子),然后得出一个一般性的结论(所有树木的叶子都是绿色的)。分类算法如逻辑回归在归纳推理方面表现良好。它们有一个目标变量,并利用特定的特征来得出更大的结论。

这里有一个这种现象的例子。假设你正在执行一个逻辑回归算法,该算法能够识别苹果和橙子的区别。你的目标变量是一个二进制变量——1 代表苹果或 0 代表橙子。你的特征是颜色和皮肤质地的分类变量,以及是否有果梗的布尔变量。当运行模型时,算法得出结论:如果水果的颜色不是橙色,皮肤光滑,并且有果梗,那么该水果是苹果。特征是具体的,算法能够检测到一个模式,最终得到的结果是一般性的。

演绎推理

演绎推理

归纳推理从具体开始,得出一般性结论,而演绎推理则从一般性结论开始,得出具体结论。这就像是开车经过一片树木繁茂的森林,注意到所有树上的叶子都是绿色的,然后提出假设:森林中的任何一棵树也会有绿色的叶子。

基本的聚类算法在演绎推理方面表现良好。它们将模型中的特征用于识别围绕一个质心最近的数据点,并根据接近度对其进行分组,从而利用一般信息(所有数据点都在这个平面上)来得出具体结论(在欧几里得空间中最接近的数据点在某种有价值的方式上是最相似的)。

演绎推理

演绎推理

演绎推理发生在算法在注意到模式后用不完全的数据得出结论时。比如,你想仅通过观察人们穿的衣物来推断外面的温度。当人们感到寒冷时,他们通常会穿上外套。通过观察外面没有人穿外套,你得出结论外面一定很温暖。

强化学习算法在归纳推理方面表现良好。代理使用模拟环境在面对不完整观察时通过计算轨迹和优化奖励来得出结论。

让我们来看一个例子。假设你正在构建一个 Q-Learning 算法作为自驾车的基本模型,以将包裹送到社区中的人们那里。你希望确保你的自驾车能在一天结束前安全高效地送达所有包裹。为了训练你的自驾车,你创建了一个数字代理,每次车辆安全送达包裹时,你可以奖励它。

你的代理能够观察人类专家驾驶交付路径并沿途做出决策的行动。在每次观察后,代理会尝试行驶该路线,直到做出最佳决策以获得最佳奖励。代理可以做出的决策包括最佳驾驶速度、转动方向盘的方式、何时刹车以及何时加速。然而,当代理驾驶其典型路径时,它遇到了导致交通堵塞的施工。这可能看起来像这样:

哎呀,有交通了!

研究人员可以促使代理在遇到新的未曾遇到的情况时做出最佳选择。经过更多的训练,代理能够预测最佳行动,同时继续执行其主要目标:按时安全地送达所有包裹。在计算过程中,代理可能发现还有另一个包裹需要送达——一个可能传统上在其路线后期送达的包裹。通过绕道将包裹送到另一个家庭,然后再返回原来的路线,代理可能会获得最大奖励。这项决策可能看起来像这样:

这种推理方式是归纳推理的一个例子。尽管代理没有关于导致交通的施工的完整信息,但它能够决定最佳的行动路线。它能够意识到交通会增加交付时间,并且在未被告知的情况下决定尝试另一条路线。

结论

讨论的每种推理方式都有其优点和缺点,具体取决于应用任务。通过理解三种主要的 AI 推理方式,可以推动 AI 的可能性,使我们更接近更有用和强大的通用 AI。

然而,是否这确实是我们未来 AI 的最佳目标尚待观察,但如果是这样,具有这种复杂推理能力的 AI 将会改变游戏规则。

你觉得怎么样?在下面告诉我。

关注我,了解更多关于数据和人工智能的文章!

** 所有图片均由作者创作。

数据驱动的方程发现

原文:towardsdatascience.com/on-data-driven-equation-discovery-5069795d239d?source=collection_archive---------3-----------------------#2023-12-01

乔治·米洛舍维奇Towards Data Science 乔治·米洛舍维奇

·

关注 发表在 Towards Data Science ·14 分钟阅读·2023 年 12 月 1 日

--

图片由 ThisisEngineering RAEng 提供,来源于 Unsplash

借助实验验证的分析表达式来描述自然,一直以来是科学特别是物理学从万有引力定律到量子力学及更广泛领域成功的标志。随着气候变化、聚变和计算生物学等挑战使我们将关注点转向更多的计算,迫切需要在降低成本的同时保持物理一致性的简明而强大的减少模型。科学机器学习是一个新兴领域,它承诺提供这样的解决方案。本文是对近期数据驱动方程发现方法的简要回顾,面向对机器学习或统计学非常基础的科学家和工程师。

动机与历史视角

单纯地将数据拟合得很好已被证明是一种短视的努力,这一点通过托勒密的地心说模型得到了证明,该模型是直到开普勒的日心说之前最符合观测的模型。因此,将观察与基本物理原理相结合在科学中发挥了重要作用。然而,在物理学中,我们常常忽略了我们的世界模型已经是数据驱动的程度。以粒子标准模型为例,它有 19 个参数,其数值是通过实验确定的。用于气象和气候的地球系统模型虽然在基于流体动力学的物理一致核心上运行,但也需要对其许多敏感参数进行仔细的观测校准。最后,减少阶次建模在聚变和空间天气社区中正在获得关注,并且很可能在未来保持相关性。在生物学和社会科学等领域,第一性原理方法效果较差,统计系统识别已经发挥了重要作用。

机器学习中有多种方法可以直接从数据中预测系统的演变。最近,深度神经网络在天气预报领域取得了显著进展,这一点由Google’s DeepMind团队等证明。这在一定程度上归因于他们拥有的巨大资源,以及气象数据和物理数值天气预报模型的一般可用性,这些模型通过数据同化将这些数据插值到全球。然而,如果数据生成的条件发生变化(例如气候变化),这些完全基于数据驱动的模型可能会表现不佳。这意味着将这些黑箱方法应用于气候建模及其他数据不足的情况可能会存在疑问。因此,在本文中,我将强调从数据中提取方程的方法,因为方程更具可解释性,且不易过拟合。在机器学习术语中,我们可以将这些范式称为高偏差——低方差

首先值得一提的方法是Schmidt 和 Lipson的开创性工作,该工作使用了遗传编程(GP)进行符号回归,从简单动力系统(如双摆等)的轨迹数据中提取方程。该过程包括生成候选符号函数,推导这些表达式中涉及的偏导数,并将其与从数据中数值估算的导数进行比较。这个过程会重复进行,直到达到足够的准确性。重要的是,由于潜在候选表达式的数量非常庞大且相对准确,因此选择符合“简约原则”的表达式。简约原则通过表达式中的项数的倒数来衡量,而预测准确性则通过仅用于验证的保留实验数据上的误差来衡量。这个简约建模的原则构成了方程发现的基础。

遗传编程(GP)的思想是通过尝试一系列潜在的项来探索可能的分析表达式空间。这个表达式被编码在上面的树中,其结构可以表示为一种“基因”。通过突变这些基因的序列、选择和交叉最优候选项,可以获得新的树。例如,要获取右侧框中的方程,只需跟随右侧树的层级中的箭头即可。

这种方法的优点在于探索各种可能的解析表达式组合。它已在各种系统中尝试过,特别是,我将重点介绍 AI — 费曼,借助 GP 和神经网络,能够从数据中识别出费曼物理讲座中的 100 个方程。GP 的另一个有趣应用是发现 气候中的海洋参数化,其中实质上运行了一个高保真模型来提供训练数据,同时从训练数据中发现了较便宜的低保真模型的修正。然而,GP 并非没有缺陷,人工干预是不可或缺的,以确保参数化效果良好。此外,由于它遵循进化的过程:试错,因此可能非常低效。还有其他可能性吗?这将引导我们到近年来主导方程发现领域的方法。

稀疏系统识别

非线性动力学的稀疏识别(SINDy) 属于概念上简单但强大的方法家族。由 Steven L. Brunton 的团队介绍,以及 其他团队 ,并配有文档完善、支持良好的 代码库YouTube 教程。要获得一些实际操作经验,只需试用他们的 Jupyter notebooks。

我将根据 原始 SINDy 论文 描述该方法。通常,您拥有轨迹数据,其中包含 x(t)、y(t)、z(t) 等坐标。目标是从数据中重建一阶常微分方程(ODEs):

通常,x(t)(有时称为响应函数)是从观察或建模数据中获得的。目标是估计 f = f(x)(ODE 的右侧)的最佳选择。通常,会尝试一个单项式库,算法会继续寻找稀疏系数向量。系数向量的每个元素控制着这个单项式对整个表达式的重要性。

在这里,函数 f = f(x) 被表示为单项式库与稀疏向量的乘积。有关更多说明,请参见下面的图形。

有限差分法(例如)通常用于计算常微分方程左侧的导数。由于导数估计容易出错,这会在数据中产生噪声,这通常是不希望的。在某些情况下,过滤可能有助于处理这些问题。接下来,选择一个单项式(基函数)库来拟合常微分方程右侧,如图所示:

如[1]所示的非线性动力学的稀疏识别(SINDy)。其思想是提取一小部分基函数(例如单项式),即全基库的一个子集,当数据代入时,这些基函数能满足方程。在左侧写出时间导数(每列对应不同变量,每行对应数据样本,样本可能是时间),而右侧则是基库矩阵(其行跨度每个基函数)与稀疏向量相乘,稀疏向量是算法学习的对象。促进稀疏性意味着我们希望最终得到的大多数向量值为零,这符合节俭原则。

问题在于,除非我们拥有天文数字级的数据,否则这个任务将毫无希望,因为许多不同的多项式都会很好地工作,这将导致显著的过拟合。幸运的是,这正是稀疏回归的救援之处:重点是对右侧有太多活跃基函数进行惩罚。这可以通过多种方式实现。原始 SINDy 所依赖的一种方法叫做序列阈值最小二乘法(STLS),可以总结如下:

来自 SINDy 论文补充材料的 Matlab 稀疏表示代码。

换句话说,使用标准最小二乘法求解系数,然后在每次应用最小二乘法时逐步消除小系数。该过程依赖于一个超参数,该超参数控制系数的小值容忍度。这个参数看似任意,但可以进行所谓的帕累托分析:通过保留一些数据并测试学习模型在测试集上的表现来确定这个稀疏化超参数。这个系数的合理值对应于学习模型的准确性与复杂性曲线(复杂性 = 包含的项数)中的“肘部”,即所谓的帕累托前沿。或者,某些其他文献推荐使用信息准则来推广稀疏性,而不是执行上述的帕累托分析。

作为 SINDy 的最简单应用,考虑如何使用 STLS 成功识别Lorenz 63 模型

将 SINDy 应用于 Lorenz 63 模型识别的示例。系数(颜色)大致对应于用于生成训练数据的系数。这些数据是通过解决带有这些参数的相关 ODE 生成的。

STLS 在应用于自由度较大的系统(如偏微分方程(PDEs))时存在局限性,在这种情况下,可以考虑通过 主成分分析(PCA)或 非线性自编码器 等进行降维。后来,SINDy 算法通过 PDE-FIND 论文 得到了进一步改进,该论文引入了顺序阈值岭回归 (STRidge)。在后者中,岭回归 指的是带有 L2 惩罚的回归,而在 STRidge 中则交替进行小系数的淘汰,如同 STLS。这使得从仿真数据中发现各种标准 PDE 成为可能,例如 布尔戈斯方程科尔特韦格-德弗里斯方程(KdV)、纳维-斯托克斯方程、反应-扩散方程,甚至是科学机器学习中常遇到的一个相当特殊的方程,Kuramoto-Sivashinsky 方程,由于需要直接从数据中估计其四阶导数项,这通常较为棘手。

Kuramoto-Sivashinsky 方程描述了层流火焰流中的扩散-热不稳定性。

该方程的识别直接基于以下输入数据(这些数据是通过数值求解相同方程获得的):

Kuramoto-Sivashinsky 方程的解。右侧面板显示了场,而右侧面板则显示了其时间导数。

这并不是说该方法容易出错。事实上,将 SINDy 应用于现实观察数据的一个大挑战在于这些数据往往本身稀疏且噪声较大,通常在这种情况下识别效果较差。同样的问题也影响了基于符号回归的方法,如遗传编程(GP)。

弱 SINDy 是一种较新的发展,它显著提高了算法在噪声方面的鲁棒性。这种方法已由多位作者独立实施,尤其是 丹尼尔·梅森丹尼尔·R·古列维奇帕特里克·赖恩博德。其主要思想是,与发现 PDE 的微分形式相比,发现其 [弱] 积分形式,通过在一组域上对 PDE 进行积分,并乘以一些测试函数。这允许通过分部积分,从 PDE 的响应函数(未知解)中去除棘手的导数,而将这些导数应用于已知的测试函数。这种方法进一步应用于 Alves 和 Fiuza 进行的等离子体物理方程发现,其中恢复了 Vlasov 方程和等离子体流体模型。

SINDy 方法的另一个显而易见的局限性是,识别始终受到构成基的项库(例如单项式)的限制。虽然可以使用其他类型的基函数,如三角函数,但这仍然不够通用。假设 PDE 具有一个有理函数的形式,其中分子和分母都可以是多项式:

这种情况使得像 PDE-FIND 这样的算法应用变得复杂

这种情况当然可以通过遗传编程(GP)轻松处理。然而,SINDy 也扩展到了这样的情况,引入了 SINDy-PI(并行隐式),该方法成功用于识别描述 贝洛乌索夫-扎博廷斯基反应 的 PDE。

最后,其他稀疏促进方法,如稀疏贝叶斯回归,也称为相关向量机(RVM),也被用于使用完全相同的拟合术语库的方法从数据中识别方程,但受益于边际化和统计学家高度尊重的“奥卡姆剃刀”原则。我在这里不覆盖这些方法,但可以说,像张和林这样的作者声称对 ODEs 的系统识别更为稳健,并且这种方法甚至尝试用于学习简单条带气候模型的闭合,其中作者认为 RVM 比 STRidge 更稳健。此外,这些方法为识别方程的估计系数提供了自然的不确定性量化(UQ)。话虽如此,集成 SINDy的最新发展更加稳健,提供 UQ,但则依赖于统计方法自助聚合(bagging),这一方法也广泛应用于统计学和机器学习。

物理信息深度学习识别

解决和识别偏微分方程(PDE)系数的另一种方法是物理信息神经网络(PINNs),该方法在文献中引起了极大关注。主要思想是使用神经网络对 PDE 的解进行参数化,并将运动方程或其他类型的基于物理的归纳偏置引入损失函数。损失函数在预定义的一组所谓的“协同点”上进行评估。在执行梯度下降时,神经网络的权重会被调整,从而“学习”解决方案。所需提供的唯一数据包括初始条件和边界条件,这些条件也在一个单独的损失项中受到惩罚。该方法实际上借鉴了旧的非神经网络的协同方法。虽然神经网络提供了自然的自动微分方式,使这种方法非常有吸引力,但事实证明,PINNs 与标准数值方法如有限体积/有限元等通常不具竞争力。因此,作为解决前向问题(数值求解 PDE)的工具,PINNs并不那么有趣。

它们成为解决逆问题的有趣工具:通过数据估计模型,而不是通过已知模型生成数据。在原始 PINNs 论文中,两个 Navier-Stokes 方程的未知系数是通过数据进行估计的。

输入到 PINN 损失函数中的 Navier-Stokes 方程的假定形式。通过识别,获得了两个未知参数(位于黄色框内)。有关 PINNs 的 tensorflow 实现,请参阅

回顾起来,与 PDE-FIND 等算法相比,这似乎有些天真,因为方程的一般形式已经被假定。然而,这项工作的一个有趣方面是算法并没有输入压力数据,而是假设了不可压缩流动,并通过 PINN 直接恢复压力的解。

PINNs 已经在各种情况下应用,我想特别强调一个应用是空间天气,在这个应用中,展示了通过解决 Fokker-Planck 方程的逆问题来估计辐射带中的电子密度。这里,重新训练神经网络的集成方法在估计不确定性方面非常有用。最终,为了实现可解释性,进行学习的扩散系数的多项式扩展。将这种方法与直接使用类似 SINDy 的方法进行比较会非常有趣,后者也提供了多项式扩展。

“物理信息”这个术语已经被其他团队采纳,他们有时发明了自己将物理先验融入神经网络的版本,并称之为类似“基于物理”或“受物理启发”等引人注目的名称。这些方法有时可以被归类为软约束(惩罚不满足某些方程或对称性的损失)或硬约束(将约束实施到神经网络的架构中)。这种方法的例子可以在气候科学等其他学科中找到。

由于反向传播的神经网络提供了一种估计时间和空间导数的替代方法,因此稀疏回归(SR)或遗传编程(GP)与这些神经网络配合方法的结合似乎是不可避免的。虽然这样的研究有很多,但我将重点介绍一个相对文档齐全且支持良好的DeePyMoD以及代码库。了解这种方法的工作原理足以理解同时期或之后出现的所有其他研究,并在各种方面改进

DeePyMoD 框架:PDE 的解通过前馈神经网络(NN)进行参数化。在最新的论文中,损失函数由两个项组成:数据与 NN 预测之间的均方误差(MSE)项;正则化损失,它惩罚包括活跃库项在内的 PDE 函数形式。类似于 SINDy 的 STLS,当网络收敛到解时,稀疏性向量中的小项被消除,从而仅推广库中最大的系数。然后,NN 的训练会重复进行,直到满足收敛标准。

损失函数包括均方误差(MSE):

以及促进 PDE 函数形式的正则化

与较弱的 SINDy 相比,DeePyMoD 在噪声下显著更稳健,仅需要在时空域上很少的观测点,这对于从观测数据中发现方程是个好消息。例如,许多 PDE-FIND 能正确识别的标准 PDE 也可以由 DeePyMoD 识别,但只需在包含噪声数据的空间中采样几千个点。然而,使用神经网络进行这项任务的代价是更长的收敛时间。另一个问题是一些 PDE 对原始配合方法存在问题,例如由于高阶导数的 Kuramoto-Sivashinsky (KS) 方程。没有弱形式方法,从数据中识别 KS 通常很困难,尤其是在噪声存在的情况下。更多的近期发展涉及将弱 SINDy 方法与神经网络配合方法结合。另一个有趣且实际未探讨的问题是这些方法通常如何受到非高斯噪声的影响。

结论

总结来说,方程发现是基于物理的机器学习的自然候选者,正在全球多个团队积极开发。它已在流体动力学、等离子体物理、气候等多个领域找到应用。有关其他方法的更广泛概述,请参见综述文章。希望读者对该领域存在的不同方法有所了解,但我只是略微触及了表面,避免过于技术化。值得一提的是许多新的基于物理的机器学习方法,如神经常微分方程(ODEs)。

参考文献

  1. Camps-Valls, G. et al. 从数据中发现因果关系和方程。Physics Reports 1044,1–68 (2023)。

  2. Lam, R. et al. 学习高技能的中期全球天气预测。Science 0,eadi2336 (2023)。

  3. Mehta, P. et al. 物理学家机器学习的高偏差、低方差介绍。Physics Reports 810,1–124 (2019)。

  4. Schmidt, M. & Lipson, H. 从实验数据中提炼自由形式自然法则。Science 324,81–85 (2009)。

  5. Udrescu, S.-M. & Tegmark, M. AI Feynman: 一种受物理启发的符号回归方法。Sci Adv 6,eaay2631 (2020)。

  6. Ross, A., Li, Z., Perezhogin, P., Fernandez-Granda, C. & Zanna, L. 在理想化模型中对机器学习海洋子网格参数化的基准测试。Journal of Advances in Modeling Earth Systems 15,e2022MS003258 (2023)。

  7. Brunton, S. L., Proctor, J. L. & Kutz, J. N. 通过稀疏识别非线性动态系统从数据中发现主方程。Proceedings of the National Academy of Sciences 113,3932–3937 (2016)。

  8. Mangan, N. M., Kutz, J. N., Brunton, S. L. & Proctor, J. L. 通过稀疏回归和信息准则选择动态系统模型。Proceedings of the Royal Society A: Mathematical, Physical and Engineering Sciences 473,20170009 (2017)。

  9. Rudy, S. H., Brunton, S. L., Proctor, J. L. & Kutz, J. N. 数据驱动的偏微分方程发现。Science Advances 3,e1602614 (2017)。

  10. Messenger, D. A. & Bortz, D. M. 用于偏微分方程的弱 SINDy。Journal of Computational Physics 443,110525 (2021)。

  11. Gurevich, D. R., Reinbold, P. A. K. & Grigoriev, R. O. 对非线性 PDE 模型的鲁棒和最优稀疏回归。Chaos: An Interdisciplinary Journal of Nonlinear Science 29,103113 (2019)。

  12. Reinbold, P. A. K., Kageorge, L. M., Schatz, M. F. & Grigoriev, R. O. 通过物理约束的符号回归从噪声、不完整、高维实验数据中进行鲁棒学习。Nat Commun 12,3219 (2021)。

  13. Alves, E. P. & Fiuza, F. 从全动能模拟中数据驱动地发现简化的等离子体物理模型。Phys. Rev. Res. 4,033192 (2022)。

  14. Zhang, S. & Lin, G. 具有误差条的数据驱动的物理定律发现。皇家学会 A 卷:数学、物理和工程科学学报 474, 20180305 (2018)。

  15. Zanna, L. & Bolton, T. 数据驱动的海洋中尺度闭合方程发现。地球物理研究快报 47, e2020GL088376 (2020)。

  16. Fasel, U., Kutz, J. N., Brunton, B. W. & Brunton, S. L. Ensemble-SINDy:在低数据、高噪声极限下,通过主动学习和控制实现稳健的稀疏模型发现。皇家学会 A 卷:数学、物理和工程科学学报 478, 20210904 (2022)。

  17. Raissi, M., Perdikaris, P. & Karniadakis, G. E. 物理信息神经网络:解决涉及非线性偏微分方程的正向和逆向问题的深度学习框架。计算物理学杂志 378, 686–707 (2019)。

  18. Markidis, S. 旧与新:物理信息深度学习能否取代传统线性求解器?大数据前沿 4, (2021)。

  19. Camporeale, E., Wilkie, G. J., Drozdov, A. Y. & Bortnik, J. 数据驱动的 Fokker-Planck 方程发现:使用物理信息神经网络研究地球辐射带电子。地球物理研究杂志:空间物理学 127, e2022JA030377 (2022)。

  20. Beucler, T. 在模拟物理系统的神经网络中强制实施解析约束。物理评论快报 126, 098302 (2021)。

  21. Both, G.-J., Choudhury, S., Sens, P. & Kusters, R. DeepMoD:在噪声数据中进行模型发现的深度学习。计算物理学杂志 428, 109985 (2021)。

  22. Stephany, R. & Earls, C. PDE-READ:使用深度学习发现可读的偏微分方程。神经网络 154, 360–382 (2022)。

  23. Both, G.-J., Vermarien, G. & Kusters, R. 稀疏约束神经网络用于 PDE 模型发现。预印本于 doi.org/10.48550/arXiv.2011.04336 (2021)。

  24. Stephany, R. & Earls, C. Weak-PDE-LEARN:一种基于弱形式的 PDE 发现方法,适用于噪声大、数据有限的情况。预印本于 doi.org/10.48550/arXiv.2309.04699 (2023)。

在代表性不足的群体面前的学习

原文:towardsdatascience.com/on-learning-in-the-presence-of-underrepresented-groups-8937434d3c85?source=collection_archive---------11-----------------------#2023-07-11

改变是困难的:对亚群体偏移的更深入了解 (ICML 2023)

Yuzhe YangTowards Data Science Yuzhe Yang

·

关注 发表在 Towards Data Science ·8 min read·2023 年 7 月 11 日

--

让我向您介绍我们最新的工作,这项工作已被 ICML 2023 会议接受:改变是困难的:对亚群体偏移的更深入了解。机器学习模型在许多应用中表现出巨大的潜力,但它们在训练数据中代表性不足亚群体上往往表现较差。理解导致这种亚群体偏移的机制变异,以及算法在大规模不同偏移下的泛化能力仍然是一个挑战。在这项工作中,我们旨在通过提供对亚群体偏移及其对机器学习算法影响的细致分析来填补这一空白。

我们首先提出了一个统一的框架,剖析并解释了子群体中常见的变化。此外,我们引入了一个综合基准,包含 20 种最先进的算法,我们在 12 个现实世界的数据集上对其进行了评估,这些数据集涵盖了视觉语言医疗保健领域。通过我们的分析和基准测试,我们提供了关于子群体变化及机器学习算法在这些现实世界变化下如何泛化的有趣观察和理解。代码、数据和模型已经在 GitHub 上开源:github.com/YyzHarry/SubpopBench

背景与动机

机器学习模型在面对分布变化时通常表现出性能下降。这种变化发生在基础数据分布发生变化时(例如,训练分布与测试分布不同),导致模型部署时性能下降。构建对这些变化具有鲁棒性的机器学习模型对于在现实世界中安全部署这些模型至关重要。一种普遍存在的分布变化类型是子群体变化,其特征是在训练和部署之间某些子群体的比例发生变化。在这种情况下,模型可能在总体上表现良好,但在稀有子群体中表现较差。

:在牛与骆驼分类任务中,牛通常出现在绿色背景中,而骆驼则通常出现在黄色背景中。因此,模型在这些背景下表现良好,但无法泛化到背景颜色不同的图像中。 :在医学诊断任务中,机器学习模型在代表性不足的年龄或种族群体上表现往往较差。(图片由作者提供)

例如,在牛和骆驼分类任务中,牛通常出现在绿色草地区域,而骆驼则通常出现在黄色沙地背景区域。然而,这种关联是虚假的,因为牛或骆驼的存在与背景颜色无关。因此,训练好的模型在上述图像上表现良好,但无法泛化到训练数据中稀少的不同背景颜色的动物,例如沙地上的牛或草地上的骆驼。

此外,研究发现,在医学诊断方面,机器学习模型在代表性不足的年龄或种族群体上表现往往较差,这引发了重要的公平性问题。

所有这些变化通常被称为子群体变化,但对于导致子群体变化的机制变异及算法如何在大规模的不同变化下泛化的了解甚少。那么,如何建模子群体变化

子群体变化的统一框架

我们首先提供了一个统一的子群体转移建模框架。在经典分类设置中,我们有来自多个类别的训练数据(其中我们使用不同的颜色密度来表示每个类别中的样本数量)。然而,当涉及子群体转移时,除了类别之外还存在属性——例如在牛骆驼问题中的背景颜色。在这种情况下,我们可以根据属性标签定义离散的子群体,而且在同一类别中,不同属性的样本数量也可能有所不同(见下图)。自然地,为了测试模型,类似于我们在所有类别中评估性能的分类设置,在子群体转移中我们测试模型在所有子群体上的表现,以确保所有子群体中的最差性能也足够好,或确保所有组的性能都同样优秀

在子群体转移中,我们需要考虑属性,而不仅仅是类别标签。(图片由作者提供)

具体而言,为了提供一个通用的数学公式,我们首先使用贝叶斯定理重写分类模型。我们进一步将每个输入x视为由一组潜在核心特征(X_core)和一个属性列表(a)完全描述或生成。在这里,X_core表示与标签特定的、支持稳健分类的潜在不变成分,而属性a可能具有不一致的分布,并且不是标签特定的。因此,我们可以将这种建模整合回方程,并进一步分解为三项,如下所示:

一个通用的子群体转移建模框架。(图片由作者提供)

具体而言,第一项表示X_corey之间的点对点互信息(PMI),这是与潜在类别标签相关的稳健指标。第二项和第三项分别对应于属性分布和标签分布中可能出现的偏差。这种建模解释了属性和类别如何在子群体转移下影响结果。因此,给定训练和测试分布之间不变的X_core,我们可以忽略第一项的变化,关注属性类别在子群体转移下如何影响结果。

基于此框架,我们正式定义并描述了四种基本的子群体转移类型:虚假相关属性不平衡类别不平衡属性泛化。每种类型构成了子群体转移中可能出现的基本转移成分。

四种基本的子群体转移类型。(图片由作者提供)

首先,当某些属性在训练数据中与标签y存在虚假相关性,但在测试数据中没有时,这意味着虚假相关性。此外,当某些属性的采样概率远小于其他属性时,会引发属性不平衡。类似地,类别标签可能会表现出不平衡的分布,导致对少数标签的偏好较低,这将导致类别不平衡。最后,某些属性在训练中可能完全缺失,但在测试中对于某些类别却存在,这促使了属性泛化的需求。每种转移的属性/类别偏差来源及其对分类模型的影响总结在下面的表格中:

(图片由作者提供)

这四种情况构成了基本的转移组件,并且是解释真实数据中复杂子群体转移的重要元素。在实际应用中,数据集通常同时包含多种类型的转移,而不仅仅是一种。

SubpopBench:子群体转移基准测试

在建立了公式后,我们提出了SubpopBench,这是一个包括在 12 个真实世界数据集上评估的最先进算法的综合基准测试。特别是,这些数据集来自各种模态和任务,包括视觉、语言和医疗保健应用,数据模态范围从自然图像、文本、临床文本到胸部 X 光。这些数据集还展现了不同的转移组件。

SubpopBench 基准测试。(图片由作者提供)

SubpopBench 基准测试。(图片由作者提供)

关于此基准测试的详细信息,请参阅我们的论文。通过建立的基准测试和使用 20 种最先进算法训练的超过 10K 模型,我们揭示了未来研究中的一些有趣观察。

对子群体转移的细粒度分析

SOTA 算法仅改善某些类型的转移

首先,我们观察到 SOTA 算法仅在某些类型的转移上改善子群体鲁棒性,而在其他类型的转移上则没有。

(图片由作者提供)

我们在这里绘制了各种 SOTA 算法相对于 ERM 的最差组准确性改进。对于虚假相关性类别不平衡,现有算法可以提供相对于 ERM 的一致最差组增益,表明在解决这两种特定转移上已有进展。

然而,有趣的是,当涉及到属性不平衡时, across 数据集几乎没有观察到改进。此外,对于属性泛化,性能甚至变得更差。

这些发现强调了当前的进展仅针对特定的转移,而对于更具挑战性的转移,如 AG,没有进展

表示和分类器的作用

此外,我们受到启发去探讨表示分类器在子群体变化中的作用。具体来说,我们将整个网络分为两个部分:特征提取器f和分类器g,其中f从输入中提取潜在特征,而g输出最终预测。我们提出的问题是,表示和分类器如何影响子群体性能

(作者提供的图片)

首先,给定一个基础的 ERM 模型,当仅优化分类器学习并固定表示时,可以显著提高虚假相关类别不平衡的性能,这表明 ERM 学到的表示已经足够好。然而有趣的是,改进表示学习而非分类器可以带来显著的提升,特别是在属性不平衡方面,这表明我们可能需要更强大的特征来应对某些变化。最后,无分层学习的方式在属性泛化下没有性能提升。这突显了在面对现实中不同类型的变化时,需要考虑模型管道设计

关于模型选择与属性可用性

此外,我们观察到模型选择属性可用性对子群体变化评估有显著影响。

(作者提供的图片)

具体而言,当逐渐去除训练和/或验证数据中的属性注释时,所有算法的性能都出现了显著下降,特别是当训练和验证数据中没有属性可用时。

这表明获取属性仍在子群体变化中发挥重要作用,未来的算法应该考虑更现实的场景以进行模型选择和属性可用性。

超越最差组准确率的指标

最后,我们揭示了基本的 权衡在评估指标之间。最差组准确率,或WGA,被认为是子群体评估的金标准。然而,改善 WGA 是否总是提升其他有意义的指标

(作者提供的图片)

我们首先展示了改善 WGA 可能导致某些指标性能提升,例如这里显示的调整准确率。然而,如果我们进一步考虑最差情况精度,它却与 WGA 显示出非常强的负线性相关性。这揭示了使用 WGA 作为唯一指标来评估模型在子群体变化中的表现的基本限制:表现良好的模型具有高 WGA,但其最差类别精度可能很低,这在医疗诊断等关键应用中尤其令人担忧。

我们的观察强调了在子群体转移中需要更多现实且广泛的评估指标。我们还展示了许多在本文中与 WGA 呈负相关的其他指标。

结语

总结本文,我们系统地研究了子群体转移问题,形式化了一个统一的框架来定义和量化不同类型的子群体转移,并进一步建立了一个全面的基准,以便在真实世界数据中进行评估。我们的基准包括 20 种 SOTA 方法和 12 个来自不同领域的真实数据集。基于超过 10K 训练模型,我们揭示了子群体转移中的有趣特性,这些特性对未来的研究具有重要意义。我们希望我们的基准和发现能够促进现实和严格的评估,并激发子群体转移领域的新进展。最后,我附上了几篇相关论文的链接;感谢阅读!

代码: github.com/YyzHarry/SubpopBench

项目页面: subpopbench.csail.mit.edu/

演讲: www.youtube.com/watch?v=WiSrCWAAUNI

关于压缩大数据的重要性

原文:towardsdatascience.com/on-the-importance-of-compressing-big-data-edd4cc7441d2?source=collection_archive---------18-----------------------#2023-01-24

为什么以及如何最小化你的数据存储占用

Chaim RandTowards Data Science Chaim Rand

·

关注 发表在Towards Data Science ·15 分钟阅读·2023 年 1 月 24 日

--

图片来源:Joshua SortinoUnsplash

“数据是新的石油”,这是Clive Humby提出的一句话,用来描述现代许多公司在其发展和成功中对数据日益增长的依赖。公司们正收集大量数据,以至于像拍字节(petabyte)、艾字节(exabyte)和泽字节(zettabyte)这样的计量单位,已经在日常对话中取代了兆字节(megabyte)、千兆字节(gigabyte)和太字节(terabyte)。然而,无目的的数据收集是无用且浪费的。这个事实可以通过对 Humby 名言的以下扩展来最准确地总结:

数据就像原油。它有价值,但如果未经精炼,它实际上不能被有效使用。

迈克尔·帕尔默

为了使数据具有价值,收集数据的方式、使用目标及如何实现这些目标需要精心设计。这种设计的一个重要元素是如何以及在哪里存储收集的数据。鉴于今天收集的数据规模巨大,存储需要专门的解决方案。为了满足不断增长的需求,数据中心存储设施的规模不断扩大,基于云的对象存储服务如Amazon S3Google Cloud StorageAzure Blob Storage也越来越受欢迎。

数据存储的成本

数据存储涉及许多成本,有些成本比其他成本更为明显。如果管理不当,存储成本很容易成为你每月研发费用中的主导因素。以下是一些成本考虑因素。

直接存储成本

尽管单位存储空间的成本多年来一直稳步下降(由于技术进步),但这种下降已被需求的增加所盖过。不论你使用的是本地数据中心还是云存储服务,直接存储成本可能迅速上升

环境成本

在过去的几年里,我们见证了对数据中心碳足迹意识的提高,以及对计算科学中增加可持续性的呼吁(例如,见这里)。虽然数据中心服务器的计算密集型工作负荷占据了碳足迹的最大部分,但存储也需要在电力、冷却和生命周期更换上进行大量投资。预计大约 10%的数据中心碳足迹——相当于一个中型西方国家的碳足迹——来自数据存储(例如,见这里这里)。

数据流成本

虽然不是直接存储成本,但必须关注与数据推送和拉取相关的成本。这些可能是与云存储服务相关的显性数据传输成本,或是与设计支持数据应用所需通信带宽的基础设施相关的成本。

对数据应用的影响: 还需要注意数据存储的方式,特别是其大小,对数据应用的影响。许多数据应用依赖于从存储中持续流动的数据。在理想情况下,数据应以足够快的速度流经系统,以使应用主机的所有计算资源得到充分利用。然而,如果您的存储解决方案设计不当,您的应用可能会在等待输入数据时处于空闲状态。这将增加运行数据应用所需的整体时间,并且增加计算成本。我们在本文的附录中更详细地描述了这一情况。

降低数据存储成本

我们的数据收集和存储策略解决方案必须考虑到数据存储的潜在高成本。一些良好的做法包括以下几点:

  1. 将数据的收集和存储限制在实际需要的范围内。

  2. 从存储中删除不再需要的数据。

  3. 许多云服务提供商提供多种存储类别—以应对不同类型的数据访问模式(例如,请参见 Amazon S3 Intelligent-TieringGoogle Storage ClassesAzure Storage Access Tiers)。虽然标准存储选项(例如,Amazon S3Google Cloud StorageAzure Blob Storage)推荐用于频繁访问或需要低延迟的数据,将其他数据分配给更具归档性质的存储类别,可以降低成本。

  4. 压缩格式存储数据。

在本文中,我们将重点讨论数据压缩作为降低存储成本的一种手段。我们将讨论几种不同的压缩技术,并通过示例展示应用这些技术的潜在节省。

免责声明

在深入讨论之前,需要说明几点免责声明:

  • 在我们的讨论中,我们将提到几种压缩技术和工具。这些仅作为示例。我们的意图是要推广这些技术或工具,而是与许多其他替代技术或工具进行比较。适合您的最佳解决方案将极大地依赖于您的具体需求。

  • 在决定使用某种压缩算法(或任何已发布的算法)之前,请确保您已阅读并理解相关的使用条款。

  • 虽然我们的示例将重点放在图像数据上,但这些基本原则同样适用于其他领域。

数据压缩

在本节中,我们将回顾几种压缩技术并通过示例展示它们。我们将得出的主要见解是,通过将压缩算法适应于数据的特定属性,可以获得最佳结果。一个典型的例子就是有损无损压缩方案之间的区别。

有损与无损压缩

在无损压缩方案中不会丢失数据。解压缩数据时,所有信息都会被恢复:X=uncompress(compress(X))。在有损压缩方案中,由于压缩的原因会丢失一些信息:X≠uncompress(compress(X))。虽然这种数据丢失一开始可能让人担忧,但在许多情况下,这种影响会很小。一个经典的例子就是图像压缩。许多图像压缩方法,如流行的 JPEG 压缩,涉及到一定程度的数据丢失,但(前提是使用合理的压缩配置)得到的图像与原始图像几乎无法区分。毫无疑问,尽管视觉上相似,但某些算法可能对这种数据丢失特别敏感。然而,更多情况下,这种影响不会显著。而且潜在的存储空间节省可以非常可观。我们将在下面进一步扩展图像压缩的话题。

如何衡量压缩质量

评估压缩方案质量的方法有很多种,包括:

  1. 压缩比:这衡量了压缩过程中数据大小的减少。

  2. 信息丢失:仅在有损压缩的情况下相关,这衡量了压缩导致的信息丢失对数据质量的影响程度。根据数据类型、领域、数据的用途等,有许多不同的方法来衡量这种质量损失。

  3. 压缩开销:使用压缩方案意味着需要在管道的不同阶段进行压缩解压缩。这两项活动都需要一定的计算资源,并可能意味着一定程度的延迟。所需的计算量和延迟可以根据选择的压缩策略有所不同。

  4. 基础设施依赖:不同的压缩方案会因其基础设施依赖而有所不同。这些依赖可能是硬件依赖和/或软件依赖。

在接下来的示例中,我们将仅测量压缩比质量损失。在实际操作中,应应用其他指标,以便对压缩策略做出全面的决策。

示例

为了方便讨论,我们将考虑一个玩具示例,其中每个数据样本包括:一个 800x534 RGB 图像、两个具有 16 个类别的像素级分类图,以及一个像素级深度图,包含从图像平面到场景中每个像素捕获的 3D 位置的距离(以米为单位)。如果您想跟随示例,可以将下面的代码片段应用于Unsplash上这篇博客文章顶部的图像。

from PIL import Image
import numpy as np
np.random.seed(0)

im = Image.open('image.jpeg', mode='r')
image = np.array(im)
H,W,C = image.shape

# create artificial labels from image color channels
label1 = image[:,:,0].astype(np.int32)//16
label2 = image[:,:,1].astype(np.int32)//16
depth = (image[:,:,2]+np.random.normal(size=(H,W))).astype(np.float32)

# write all data sample elements to file
with open('image.bin','wb') as f:
    f.write(image.tobytes())
with open('label2.bin','wb') as f:
    f.write(label2.tobytes())
with open('label1.bin','wb') as f:
    f.write(label1.tobytes())
with open('depth.bin','wb') as f:
    f.write(depth.tobytes())

测量原始数据文件的存储足迹(例如,通过在 Linux 中运行ls -l),我们发现图像需要 1.3 MB 的存储空间,而 3 个标签图需要每个 1.7 MB。

选择文件格式

设计数据存储策略的一个重要步骤是选择存储数据的格式。这个决定可能会对数据应用程序的访问简便性和速度产生重大影响。特别是,您的设计应考虑不同应用程序的不同访问模式。在上一篇文章中,我们讨论了选择文件格式对机器学习训练的一些潜在影响。为了简化起见,我们将数据样本存储在标准 tarball 中。

import tarfile
with tarfile.open("base.tar", "w") as tar:
    for name in ["image.bin", "label1.bin", "label2.bin", "depth.bin"]:
        tar.add(name)

结果文件为 6.2 MB。

使用 ZIP 变体进行压缩

所谓“ZIP 变体”,是指各种流行的通用文件格式及其相关压缩方案,包括ZIPgzip7-zipbzip2Brotli等。请注意,虽然我们将这些格式归为一类,但底层算法可能存在显著差异。许多文件格式包含用于自动压缩数据样本的 ZIP 变体标志。例如,使用bzip2压缩,我们可以将 tarball 的大小减少到 2.7 MB,减少了 2.3 倍。

import tarfile
with tarfile.open("base.bz2", "w:bz2") as tar:
    for name in ["image.bin", "label1.bin", "label2.bin", "depth.bin"]:
        tar.add(name)

使用 ZIP 变体进行压缩特别吸引人,因为它的通用性。它可以普遍应用,而无需了解底层数据的具体类型或领域。在接下来的章节中,我们将查看是否可以使用考虑到原始数据类型详细信息的压缩方案来改进这一结果。

使用低精度数据类型

通过将数据转换为使用低位精度数据类型,可以节省大量存储空间。在我们的例子中有两个优化机会。

  1. 将 int32 替换为 uint8:通过使用最小的整数类型来满足需求,可以节省大量的存储空间。在我们的例子中,32 位整数表示显然对于表示我们的 16 类标签图来说是过多的。这些可以很容易地适应 uint8 矩阵而不会丢失任何信息。

  2. 将 float32 替换为 float16:与整数精度减少相反,此操作导致信息丢失(即,它是有损的)。此更改应仅在评估其对数据算法的潜在影响后进行。在下面的代码块中,我们演示了两种衡量数据质量变化的指标。这些可以用来预测对数据算法的影响。理想情况下,我们会发现这些指标与算法性能之间的某种关联,但这并不总是那么简单。

label1 = label1.astype(np.uint8)
label2 = label2.astype(np.uint8)
depth_new = depth.astype(np.float16)

# measure loss of quality
from numpy import linalg as LA
l_max = LA.norm((depth-depth_new.astype(np.float32)).flatten(),np.inf) # 0.12
l_2 = LA.norm((depth-depth_new.astype(np.float32)).flatten(),2) # 10.47

with open('label1.bin','wb') as f:
    f.write(label1.tobytes())
with open('label2.bin','wb') as f:
    f.write(label2.tobytes())
with open('depth.bin','wb') as f:
    f.write(depth_new.tobytes())

仅这些优化就产生了 2.9 MB 的 tarball,减少了 2.14 倍。

使用按位操作合并元素

目前,每个分类图都存储在 8 位uint8缓冲区中。然而,由于这些图像只包含 16 个类别,它们实际上每个只使用了四位。我们可以通过将两个图像合并成一个数据图来进一步压缩数据。

# compress
combined_label = (label2 * 16 + label1).astype(np.uint8)

# restore
label1 = combined_label % 16
label2 = combined_label // 16

结果的 tarball 大小为 2.5 MB,总体减少了 2.48 倍。

注意,我们也可以考虑将每个单独标签图中的相邻元素对合并,从而将其分辨率降至 400x534。在实践中,我们发现将单独的标签图合并有利于在后续处理阶段更好地压缩(如下一节讨论)。借用信息论中的术语,结果具有更低的熵。

图像压缩

各种图像压缩算法(有时称为编解码器)利用图像数据的独特统计特性来提高压缩率。虽然对图像压缩的全面概述超出了本文的范围,但我们将触及一些与当前问题相关的要点。

对图像编解码器进行初步搜索将返回各种选项,包括 PNGJPEGWebPJPEG XL 等。这些编解码器在几个属性上有所不同,包括以下几点:

  1. 无损与有损压缩支持:一些编解码器支持无损压缩,一些支持有损压缩,还有一些支持两者。在大多数情况下,有损压缩的压缩率比无损压缩更好。

  2. 支持的输入格式:不同的算法支持不同类型的输入。典型的限制包括支持的颜色通道数量、支持的每像素位数等。

  3. 压缩质量控制:编解码器在允许控制结果压缩质量的程度和性质上有所不同。例如,通过调整质量控制,我们可以管理压缩率与编码/解码速度和/或信息丢失之间的权衡。

  4. 底层压缩算法:编解码器底层的算法行为各异,优化的功能不同,表现出的伪影也不同。例如,一些算法可能比其他算法更容易去除你所依赖的特定图像频率。

结构相似性指数测量(SSIM 是一种常用的度量标准,用于评估图像压缩方案带来的图像质量退化。SSIM 值范围从 0 到 1,其中 1 表示与原始图像完全匹配。其他测量图像退化的指标包括 MSEPSNR。如前所述,你的目标应是选择能够预测信息丧失对数据算法性能影响的指标。

关于使用有损图像压缩方案,有几点需要注意:

  • 尽管大多数编解码器优化了视觉感知,你的算法可能对图像中那些不明显的元素较为敏感。特别是,依赖于高度视觉相似性来评估压缩方案应替代深入评估。

  • 当心那些给人以某个编解码器总是优于其他编解码器印象的网站。实际上,不同编解码器的相对性能在图像领域(例如,深空图像与医学图像)之间差异很大,甚至同一领域内的图像样本之间也会有所不同。强烈建议你对自己的图像数据集进行分析。

  • 对于视频序列数据,你可能会觉得有必要采用视频压缩格式。视频压缩利用相邻帧之间的相似性来进一步提高压缩率。然而,这通常会导致与图像压缩相比显著降低的质量(例如,通过 SSIM 测量)。在某些情况下,你可能会发现独立压缩每一帧效果更好。

我们的示例包括两个应用图像压缩的机会。首先,我们使用经典的有损 JPEG 编解码器压缩图像图,并将压缩质量设置为 95(有关设置质量值的详细信息,请参见 这里)。接下来,我们压缩标签图。由于我们期望机器学习算法高度依赖数据标签的准确性,因此我们选择无损 PNG 压缩格式,以避免丢失任何标签信息。请注意,尽管 JPEG 和 PNG 都是极受欢迎的格式,但它们并不以提供最佳压缩率而闻名。虽然足够用于我们的演示目的,但你可能会发现使用更现代的图像压缩算法能获得更好的结果。

下面的代码块演示了使用Pillow 包(版本 9.2.0)进行图像压缩。我们使用scikit-image 包(版本 0.19.3)应用 SSIM 评分。

from PIL import Image
from skimage.metrics import structural_similarity as SSIM

Image.fromarray(combined_label).save('label.png')
Image.fromarray(image).save('image.jpg',quality=95)

decoded = np.array(Image.open('image.jpg'))

ssim = SSIM(image, decoded, channel_axis=2)} # 0.996

我们示例中的 SSIM 分数为 0.996,表明 JPEG 编码导致的信息质量损失相对较低。自然,这一水平的图像降级是否可接受将取决于消耗数据的 ML 算法的敏感性。请注意,我们选择的压缩质量相对较高。较低的质量率可能会导致更好的压缩,但以更大的图像细节损失为代价。

下图展示了原始输入、解码输出以及它们之间的绝对差异(为了增强效果,差异被缩放了 20 倍)。

JPEG 对Joshua SortinoUnsplash上的照片的影响

结果

在这个阶段,图像、标签和深度图分别占用 341 KB、205 KB 和 835 KB 的存储空间。tarball 的大小为 1.2 MB。通过对 tarball 应用通用的 ZIP 算法,这一大小降至 1.1 MB。这比我们最初开始时的简单压缩结果小了一半。最终的压缩序列总结在下面的代码块中:

import tarfile
from PIL import Image

combined_label = (label2 * 16 + label1).astype(np.uint8)
Image.fromarray(combined_label).save('label.png')
Image.fromarray(image).save('image.jpg',quality=95)

depth_new = depth.astype(np.float16)
with open('depth.bin','wb') as f:
    f.write(depth_new.tobytes())

with tarfile.open("final.tar.bz2", "w:bz2") as tar:
    for name in ["image.jpg", "label.png", "depth.bin"]:
        tar.add(name)

通过这个相对简单的序列,我们成功地将数据的大小减少了 5.64 倍。这意味着存储空间节省了超过80%。将其应用于你的完整数据集可能对存储成本产生深远影响

我们可能还会发现额外的压缩机会。然而,每次额外操作的压缩率可能会降低。此外,使用不同的压缩序列可能会获得更好的压缩率。根据你现有的存储成本和潜在的节省,可能值得继续探索。

摘要

在这篇文章中,我们讨论了数据压缩对数据科学尤其是机器学习的重要性。我们演示了几个简单的压缩技术,并测量了它们对数据存储大小的影响。如上所述,我们选择的方法不一定适合你。找到一个好的解决方案的关键包括:对原始数据类型的深入了解、对数据消耗方式的深刻理解,以及对相关压缩方案的良好掌握。

随着 2023 年的开始,我们发现自己深陷于通常被称为大数据革命的过程中。正确的数据管理,包括使用数据压缩方案,仅仅是这一革命中许多重要组成部分之一。新年快乐。

附录:压缩对数据应用的影响

许多数据应用可以描述为在不同设备和设备组件之间连续流动的大量数据。例如,在深度学习训练负载中,原始训练数据从存储中流向 CPU 工作节点进行预处理和批处理,训练批次从 CPU 送入训练加速器,然后在前向计算图的不同阶段流动,梯度通过反向传播计算,而在分布式训练的情况下,数据在参与的加速器之间进行通信。

数据在典型深度学习训练步骤中的流动(作者提供)

在理想情况下,数据会在系统中足够快地流动,以保持所有计算资源的充分利用。然而,有时你可能会发现数据流受到通信通道带宽限制的约束。这可能导致应用中存在性能瓶颈,并导致计算资源的未充分利用。在这种不理想的情况下,昂贵的资源处于闲置状态,等待数据输入。这类问题可以通过多种方式解决,包括:增加通信带宽(例如,使用不同规格的实例类型)、改变应用架构和/或减少数据的大小

如果数据的大小很大,并且存储与应用主机之间的通信带宽有限,你可能特别容易遇到数据流瓶颈。将数据存储在压缩格式中可以减少存储位置与应用之间接口的瓶颈潜力

压缩的一种潜在权衡是压缩和/或解压数据所需的额外计算资源。如果你的数据应用已经计算密集,你可能会发现将数据存储在压缩格式中虽然在应用的某一部分释放了数据流瓶颈,却在其他地方引入了计算瓶颈。因此,数据应用管道中的数据压缩可能成为在不同资源的利用之间取得微妙平衡的艺术

在 DAX 度量中使用中间结果

原文:towardsdatascience.com/on-using-intermediary-results-in-dax-measures-9971efa72ae

我们在 DAX 中一直使用表变量。但当我们需要计算中间结果并在 DAX 度量中稍后重用它们时,该怎么办?这个挑战听起来简单,但实际上并不容易。

Salvatore CagliariTowards Data Science Salvatore Cagliari

·发表于 Towards Data Science ·8 分钟阅读·2023 年 3 月 13 日

--

Mika Baumeister 提供的照片,来源于 Unsplash

介绍

我们在 DAX 中一直使用中间表变量。

例如,查看以下度量:

[SalesYTD] =
  VAR YTDDates = DATESYTD('Date'[Date])

RETURN
  CALCULATE(
      [Sum Online Sales]
      ,YTDDates
      )

在这种情况下,我们生成一个基于实际筛选上下文的年初至今表,借助于DATESYTD()函数,其中包含了从实际筛选上下文中的年初(即 1 月 1 日)到当前日期的所有日期(当然是基于当前的筛选上下文)。

但有时,我们需要做更多的事情。

例如,我们需要根据当前的筛选上下文查询一个表,并对结果进行进一步计算。

在这种情况下,我们必须生成一个中间表,并将结果分配给度量中的一个变量,以进行所需的计算。

由于我们需要筛选上下文,我们不能预先创建一个包含这些中间结果的表,因为这个表会非常庞大,以容纳所有可能的筛选组合。

所以,让我们来看看我们可以如何做到这一点。

基础查询

基础查询如下:

EVALUATE
  ADDCOLUMNS(
        VALUES(Customer[CustomerKey])
          ,"AverageSales", CALCULATE( AVERAGEX( 'Online Sales'
          ,('Online Sales'[UnitPrice] * 'Online Sales'[SalesQuantity])-'Online Sales'[DiscountAmount]
          )
        )
      )
ORDER BY [AverageSales] DESC

这个查询的(截断)结果如下:

图 1 — 基础查询结果(图示作者提供)

但是如果这仅仅是一个起点,我们还需要在一个度量中基于这个结果进行进一步的计算,会发生什么呢?

例如,我们想要显示每个月所有行的总和。

在这种情况下,我们需要计算每个月的平均销售额,并对结果进行求和。

让我们看看我将如何直观地创建一个解决方案,它不起作用,然后,之后这个解决方案将如何有效。

不应该这样做——或者说它不工作的原因

解决上述需求的第一种直观方法是使用 ADDCOLUMNS() 函数生成一个变量,预计算表格中的平均销售额。然后对行进行聚合以获得所需的结果:

DEFINE
  MEASURE 'All Measures'[AverageSalePerCustomer] =
    VAR AverageSalesPerCust =
            ADDCOLUMNS (
                VALUES ( Customer[CustomerKey] ),
                "AverageSales",
                -- Calculate the Average Sales per Customer using the Filter Context provided by ADDCOLUMNS()
                CALCULATE (
                  AVERAGEX (
                    'Online Sales',
                    ( 'Online Sales'[UnitPrice] * 'Online Sales'[SalesQuantity] ) - 'Online Sales'[DiscountAmount]
                    )
                  )
                )
RETURN
  -- Calculate a sum of all the rows calculated in the previous step
  SUM ( AverageSalesPerCust[AverageSales] )

EVALUATE
  -- Get the list of all Months and call the Measures defined above for each month
  ADDCOLUMNS (
      SUMMARIZECOLUMNS (
            'Date'[Year Month Short Name] 
            ),
            "AverageSalePerCustomer", [AverageSalePerCustomer]
      )
ORDER BY 'Date'[Year Month Short Name]

DAX Studio中执行此查询后,我得到以下错误信息:

图 2 — 第一次尝试的错误信息(作者提供的图)

问题在于 SUM() 无法与表变量一起使用。

让我们用稍微不同的方式尝试一下。

让我们在 RETURN 语句后用SUMX() 替换 SUM():

SUMX(AverageSalesPerCust
    ,[AverageSales])

错误信息如下:

图 3 — 使用第二种方法的错误信息(作者提供的图)

所以,SUMX() 也无法访问表变量中的列。

我尝试了其他方法来获得所需的结果:

  • 使用CALCULATETABLE() 生成可以被 SUM() 或 SUMX() 使用的表

    但这个函数也无法访问表变量中的列。

  • 使用FILTER() 生成表并在 SUM() 中使用

    尽管 FILTER() 没有问题,但 SUM() 无法访问表变量。

但等等…… FILTER() 没有生成任何错误。

我可以将 FILTER() 与 SUMX() 一起使用吗?

让我们看看这种方法是否有效。

有效解决方案 — FILTER() 和 SUMX()

第一个有效的方法是以下查询:

DEFINE
  MEASURE 'All Measures'[AverageSalePerCustomer] =
    VAR AverageSalesPerCust =
      ADDCOLUMNS (
        VALUES ( Customer[CustomerKey] ),
        "AverageSales",
        CALCULATE (
            AVERAGEX (
              'Online Sales',
              ( 'Online Sales'[UnitPrice] * 'Online Sales'[SalesQuantity] ) - 'Online Sales'[DiscountAmount]
              )
            )
          )

    VAR AvgSalesOver0 =
      -- Wrap the intermediary table variable in FILTER() to make it usable by aggregation function
      FILTER (
          AverageSalesPerCust
          ,[AverageSales] > 0
          )

RETURN
  SUMX (
      AvgSalesOver0,
      [AverageSales]
      )

EVALUATE
  ADDCOLUMNS (
      SUMMARIZECOLUMNS (
            'Date'[Year Month Short Name]
            ),
            "AverageSalePerCustomer", [AverageSalePerCustomer]
        )
  -- Order by the calculated column
  ORDER BY [Year Month Short Name] DESC

第一个表变量 AverageSalesPerCust 仍然是一样的。

第二步是使用 FILTER() 函数来定义新的表变量 AvgSalesOver0。

更新 2024 年 1 月:正如AlexisOlson的评论中所述,FILTER() 并不是必需的。在另一个案例中,我使用了相同的技巧而不使用 FILTER(),效果良好。目前,我不知道为什么在这里需要 FILTER()。

在这种情况下,我使用 FILTER() 函数将表变量 AverageSalesPerCust 转换为可以被聚合函数使用的表变量。

由于 FILTER() 需要至少两个参数,我必须添加“Filter” [AverageSales] > 0 以确保 FILTER 函数能正常工作。

出乎意料的是,SUMX 在访问使用 FILTER() 函数构建的表变量时没有问题,我可以使用这个表变量来计算聚合。

查询的(截断)结果如下:

图 4 — 工作解决方案的结果(作者提供的图)

但执行需要一些时间。

当我查看性能指标时,我注意到一些问题:

图 5 — 第一个解决方案的性能指标(作者提供的图)

总执行时间超过三秒,其中超过两秒的时间花在了公式引擎(FE)上,因为它需要处理三百万行数据。

这需要更高效,我们必须尝试找到更高效的解决方案。

使用CALCULATETABLE()的优化尝试

我将CALCULATETABLE()FILTER()结合以获得以下解决方案:

DEFINE
  MEASURE 'All Measures'[AverageSalePerCustomer] =
        VAR AverageSalesPerCust =
          CALCULATETABLE (
            ADDCOLUMNS (
                VALUES ( Customer[CustomerKey] ),
                "AverageSales",
                CALCULATE (
                  AVERAGEX (
                    'Online Sales',
                      ( 'Online Sales'[UnitPrice] * 'Online Sales'[SalesQuantity] ) - 'Online Sales'[DiscountAmount]
                    )
                  )
                )
                -- Add a Dummy filter
                ,Customer[CustomerKey] <> 0
              )

VAR AvgSalesOver0 =
    FILTER (
        AverageSalesPerCust,
        [AverageSales] > 0
        )

RETURN
  SUMX (
      AvgSalesOver0
      ,[AverageSales]
      )

EVALUATE
  ADDCOLUMNS (
    SUMMARIZECOLUMNS (
            'Date'[Year Month Short Name]
            ),
            "AverageSalePerCustomer", [AverageSalePerCustomer]
          )
  -- Order by the calculated column is possible
  ORDER BY [Year Month Short Name] DESC

区别在于我用CALCULATETABLE()函数封装了ADDCOLUMNS()

我添加了一个虚拟筛选器,希望执行计划会有所改变。

但这个版本保持了一切不变。

让我们尝试其他方法。

使用虚拟筛选器的优化尝试

下一个想法是用虚拟筛选器替换FILTER函数中的谓词:

VAR AvgSalesOver0 =
    FILTER (
          AverageSalesPerCust
          -- Dummy-Filter 
          ,1 = 1
          )

假设没有列引用的谓词(1=1)可能会导致更高效的执行。

不幸的是,这并没有改变结果。

是否没有优化的可能或必要?

我尝试了其他表函数来构建第一个表变量,即每个客户的平均销售额,但一切保持不变。

经过一番思考,我意识到这个问题无法在 DAX 中解决。

存储引擎(SE)无法按照我们需要的方式工作,或者我未能找到正确的解决方案。

因此,DAX 引擎将依赖于公式引擎的能力,从而将三百万行数据加载到内存中并在其中汇总数据。

有时我们找到的解决方案已经是最好的了。

三秒是用户愿意等待结果的最大时间限制。

我们可以认为这个解决方案已经足够好。

但等一下。我使用包含十多年数据的整个数据集执行度量。这是一个现实的场景吗?

更现实的情况是用户只会分析一两年的数据。

使用一个筛选器将查询限制为仅一年,执行时间骤降至不到两分之一秒。

图 6 — 仅选择一年的性能指标(图由作者提供)

由于这将是报告中的常见场景,因此解决方案已准备好在报告中使用。

经验教训: 记得在测试复杂度量的性能时使用实际场景和用例。

Lucas Santos 拍摄,照片来源于 Unsplash

解决方案模板

根据这些结果,解决方案的模板如下:

DEFINE
  MEASURE 'All Measures'[AverageSalePerCustomer] =
    VAR <TableVariable> =
          <Definition of the Table>

    VAR <FilterOfTable> =
        FILTER(
            <TableVariable>
            ,1=1 - Dummy Filter
            )
RETURN
  <AggregationWithXFunction over <FilterOfTable> >

另外,我们可以使用这种形式:

DEFINE
MEASURE 'All Measures'[AverageSalePerCustomer] =
    VAR <TableVariable> =
      FILTER(
          <Definition of the Table>
          , 1=1 - Dummy Filter
          )
RETURN
  <AggregationWithXFunction over <TableVariable> >

第二种形式更短,并且一切都包含在一个变量中。

我鼓励你添加注释以解释为什么要在包含FILTER 1=1的度量中添加FILTER()

对于不了解这种技术的人来说,这毫无意义。

参考文献

如果你不知道如何在 DAX Studio 中收集和解释性能指标,或者对那里显示的指标解释不确定,请阅读这篇文章,我将在其中详细探讨这个功能:

如何使用 DAX Studio 从 Power BI 获取性能数据

有时我们会遇到报告运行缓慢的情况,我们需要找出原因。我们将看到如何收集性能数据以及…

towardsdatascience.com

我使用了 Contoso 示例数据集,就像我之前的文章中一样。你可以从 Microsoft 这里 免费下载 ContosoRetailDW 数据集。

Contoso 数据可以在 MIT 许可下自由使用,如 这里 所述。

我扩大了数据集,以使 DAX 引擎的工作负荷增加。

在线销售表包含 7100 万行(而不是 1260 万行),零售销售表包含 1850 万行(而不是 340 万行)。

加入 Medium 并使用我的推荐链接 - Salvatore Cagliari

阅读 Salvatore Cagliari(以及 Medium 上其他成千上万的作者)的每一个故事。您的会员费用直接…

medium.com

关于机器为何能够思考

原文:towardsdatascience.com/on-why-machines-can-think-40edafce293d?source=collection_archive---------2-----------------------#2023-12-06

我们如何以最简单的方式思考思维呢?

Niya StoimenovaTowards Data Science Niya Stoimenova

·

关注 发表在 Towards Data Science ·15 分钟阅读·2023 年 12 月 6 日

--

打开潘多拉的盒子(图片来源:作者)

在 17 世纪,勒内·笛卡尔提出了一个相对较新的思想——“我思故我在”。这一简单的表述成为了西方哲学的基础,并定义了几个世纪以来我们对人类本质的理解。

从那时起,我们对作为人类的意义的理解发生了变化。然而,实际上,许多人仍然认为思考的能力是人性的最重要标志之一。

因此,ChatGPT(及类似模型)发布的瞬间,我们开始被大量讨论“它是否能够思考”的文章轰炸。

例如,《纽约客》思考了“ChatGPT 有怎样的思维?”;《华盛顿邮报》宣称“ChatGPT 可以通过逻辑测试,但别指望它具有创造力”;《大西洋月刊》则得出结论称“ChatGPT 比你想象的更笨”。我个人最喜欢的是这个喜剧演员的视频,他试图向一位从事人力资源工作的人解释 ChatGPT 是什么。

与任何其他容易引发猜测的复杂话题一样,人们对于 AI 模型的思维能力既过分夸大又不足代表。因此,让我们深入探讨一下。

思考就是推理

思维是一个复杂的构造,已经代表了许多不同的事物。因此,为了简单起见,我们可以假设思维或多或少与推理同义。

推理是一个定义得更清晰的概念,巧合的是,它正被越来越多地用作AI 的未来。它也是笛卡尔(在很大程度上)谈论思维时的意思。

所以,不如问“AI 能思考吗?”,不如问“AI 能推理吗?”。

简短的回答是是的。长答案是——它可以推理,但仅限于某些方式。

推理不是一个单一的概念。根据她试图完成的任务类型,有多种推理方式。因此,在这篇文章中,我们将首先简要介绍三种关键的推理类型,并检查机器的表现。然后,我们将探讨为什么机器无法进行常识推理以及在它们能做到之前需要回答什么问题。

推理入门

通常,当我们“思考”时,会使用三种主要的推理类型:演绎归纳溯因

演绎

简而言之,演绎是从给定的规则和被假定为真的案例中得出结论的能力。

想象一下:你在锅里加水,打开炉子,然后放入一个温度计。由于你在学校学到的东西,你知道水(通常)在 100°C 时沸腾。因此,当有人告诉你温度已达到 100°C 时,你可以安全地推断出水正在沸腾(你不必亲眼看到它发生也能“相当确定”它确实发生了)。

这里有一个有用的结构需要记住。

1. 规则: 水在达到 100°C 时沸腾

2. 案例: 水的温度是 100°C

3. 结果: 锅里的水在煮沸

因此,你从规则案例推理到结果

推理:从规则和案例推理到结果(图片由作者提供)

推理对我们进行科学研究至关重要。它也是机器最容易重现的推理类型。

按设计,几乎每台机器都执行某种形式的推理。你的简单的、毫不起眼的计算器每次你问 3+5 是多少时都会推导出答案。而它其中没有任何人工智能。

如果我们把它放在和上述水的例子相同的结构中,我们得到:

规则: 计算器已经“提供”了规则1+1 = 2

案例: 你问了问题3+5 = ?

结果: 根据规则,它可以计算/推导出3+5 = 8

简单。

归纳

归纳是从给定的观察集合中概括规则的能力。它对我们进行科学研究至关重要,因为它使我们能够定量地识别新的模式/规则。

让我们坚持水的沸腾例子。假设你从未被告知水在 100°C 沸腾。因此,每次你将一锅水加热到沸腾时,你都放入一个温度计并测量温度——100 次,1,000 次,10,000 次。然后,你的朋友们也做同样的事——无论你做多少次,温度总是 100°C。因此,你可以归纳出规则:“水在 100°C 沸腾”。

1. 结果: 水在沸腾

2. 案例: 每当你放入温度计时,它总是显示 100°C。

3. 规则: 水在 100°C 沸腾。

归纳:从结果和案例推理到规则(图片由作者提供)

然后,你就根据你观察到的模式定量地识别出了一个新的规则。为此,你从结果案例推理到规则

这种推理类型当然并不总是正确的。著名的是,欧洲人曾认为所有天鹅都是白色的,直到他们航行到了澳大利亚。我们也知道水的沸点并不总是 100°C(大气压力也起作用)。

仅仅因为某件事发生了 10,000 次并不意味着它总是正确的。但 10,000 次通常是一个安全的选择。

归纳对机器来说要困难得多。你的计算器当然无法执行归纳。然而,机器学习模型可以。实际上,这正是它们的主要目标:从一组给定的结果中进行概括。

让我们举一个简单的例子。假设我们有一个监督分类模型,将用于垃圾邮件检测。首先,我们有标记的训练数据集——垃圾邮件非垃圾邮件(即结果)。在这个数据集中,我们为每个结果编制了多个案例。基于这些案例,模型会归纳出自己的规则,这些规则可以在以后应用于一个它从未见过的案例。

  1. 结果: 垃圾邮件或非垃圾邮件

2. 案例: 大样本,包括垃圾邮件和非垃圾邮件示例

  1. 规则: 包含“这些模式和词语”的邮件很可能是垃圾邮件(在一定的概率范围内)

同样,当处理无监督模型如推荐系统时,过程也类似。我们首先向模型提供一个数据集,关于人们在超市购物时的倾向(结果)。一旦开始模型训练,我们会期望它首先聚类重复的模式(案例),然后引导出自己的规则,这些规则可以在类似的背景中应用。

  1. 结果: 关于人们购买的未标记数据

  2. 案例: 模型在数据集中发现的类似购买(例如,每个人买鸡蛋的人也会买培根)。

  3. 规则: 买鸡蛋的人也会买培根(在一定的概率范围内)

在这两种情况下,这些规则不一定对人类是可理解的。也就是说,我们知道计算机视觉模型“关注”图像的某个部分,但我们很少知道为什么。实际上,模型越复杂,我们了解其使用规则的机会就越小。

所以,这里我们可以看到——机器可以同时执行归纳和演绎推理。

演绎推理和归纳推理——科学的基石

广泛认为,归纳演绎的结合是我们推理能力的驱动力。正如我们的例子所示,当代的机器学习模型,即使是简单的模型,也能执行这两种操作。

它们首先利用归纳推理从给定的数据集中生成规则。然后,它们将这些规则应用于新的案例。例如,一旦我们向模型提供一张以前未见过的照片,它会利用其规则来推断出具体的结果(例如,它可以告诉我们提供的照片是倒置的)。

尽管如此,大多数数据科学家会同意,即使是最先进的机器学习模型也无法进行推理。为什么?

水煮沸的例子可以简单地说明为什么仅依靠演绎推理和归纳推理并不足够。确实,我们需要它们来生成规则“水在 100°C 时沸腾”),然后在各种案例中进行验证。然而,这种结合不足以解释我们是如何猜测到煮沸结果温度有关的。

除此之外,归纳和演绎推理的额外局限性也变得显而易见——它们在特定的上下文中有所限制,缺乏在不同领域之间转移知识的能力。这正是演绎推理发挥作用的地方,它提供了一个更全面的视角,展示了使我们能够进行直觉跃迁并将洞察力连接到不同领域的认知过程。

演绎推理

演绎推理是从单一惊讶的观察(即结果)中生成新假设的能力。我们每次依赖经验来解释某些事物时,都会这样做。

我们出去看到一条湿街。我们用之前可能下过雨的猜测来解释。我们不需要看到 1 万条湿街就知道下雨时街道会变湿。从技术上讲,我们甚至不需要以前遇到过湿街——我们只需知道当水接触物体时,会使物体变湿。

这意味着,如果我们回到水煮沸的例子,我们将有不同的推理方式:

1. 结果:水在煮沸

2. 规则:水在 100°C 时煮沸

3. 案例:水的温度必须是 100°C

溯因推理:从规则和结果推断到一个案例(作者插图)

我们从结果开始(就像我们进行归纳推理时一样),但我们将其与我们已知的规则结合(基于我们的世界知识和经验)。这两者的结合使我们能够得出一个案例(即,水因温度变化而煮沸)。

溯因推理是所有推理类型中最不可靠的。通过溯因推理得出的假设很可能是不正确的。例如,“湿街”的结果可能与雨无关——也许某个地方的管道在夜间破裂,或者有人认真地用水喷洒了街道。然而,雨似乎是一个合理的解释。

因此,溯因推理允许我们在日常情况中顺利前行,而不会陷入困境。也就是说,我们不需要尝试 1 万次来做出简单的决策。

据我了解,目前没有任何 AI 模型/算法能够进行溯因推理。不是以我刚刚描述的方式。

那些对 1960 年代和 1970 年代基于规则的系统熟悉的人,当然可以提到MYCINXCONSHRDLU,并声称它们能够进行溯因推理。其他人可能会提到斯坦福 AI 指数在20222023中引用的溯因推理例子,认为这是未来研究中最有前景的领域之一(即,溯因自然语言推理)。

所以,如果机器在 1970 年代能够进行“溯因推理”,为什么它们仍然不能做我所称的溯因推理(即常识推理)?

为什么溯因推理仍然难以捉摸

即使是最先进的模型也无法进行溯因推理的原因有两个:混淆架构

混淆:溯因推理与最佳解释推理(IBE)不同

历史上,在计算机科学领域,许多人将 IBE 和推断这两个术语互换使用。即使是 ChatGPT 也会告诉你这两者是相同的,或者推断是 IBE 的一个子集(取决于你如何提问)。斯坦福哲学百科全书也反映了这种观点。实际上,你在计算机科学的相关领域阅读的几乎每一篇关于推断的论文,都告诉你它与 IBE 相同。

然而,这两者是非常 不同 的构建。

一般来说,推断涵盖了生成新案例的行为(将学习转移到不同的背景中)。另一方面,IBE 是一种非常特殊且更具背景特定性的归纳形式,它不一定要求你定量地识别模式(即,你不需要观察一个模式 10,000 次来制定规则)。这些之间的具体区别是一个相当复杂的哲学讨论。如果你想深入了解这一点,我推荐这篇论文

然而,就本文而言,帮助我们的是将其放在规则案例结果结构中进行思考,并使用像 MYCIN 和斯坦福 AI 指数引用的推断自然语言模型这样的具体例子。

MYCIN 是 20 世纪 70 年代在斯坦福开发的早期专家系统,旨在帮助医生诊断传染病。它依赖于一个知识库,其中每个规则都以条件(IF——即案例)和结论(THEN——即结果)的形式表达。然后它利用了逆向推理机制,使其能够从一组症状和病人数据(分别是结果案例)中,向后推理以识别和分配从 0 到 1 的启发式确定性评分给那些可能最好地解释情况的规则。即,它从结果和案例推理到规则(即,归纳推理遵循的模式)。

斯坦福 AI 指数引用的作为推断自然语言推理例子(无论是生成假设还是选择最合理的假设)有点棘手。但这仍然不是推断。事实上,我会说,它类似于 IBE,但它遵循与我们迄今讨论的其他机器学习模型相同的模式——归纳,接着是演绎。

背景:在 2020 年,Bhagavatula 及其同事*,对一个他们称之为 ART 的数据集(包含约 20K 由观察对(O1, O2)定义的叙事背景和 200K 解释性假设)训练了一个变换器模型。训练后,他们给模型提供了一组观察数据,并要求它生成一个合理的假设以匹配(见图 4)。

图 4:推测自然语言推理(该图取自 arXiv:1908.05739

如图所示,当一个变压器模型(GPT-2+COMeT 嵌入)面对 O1(例如,“Junior 是一只 20+ 岁的老海龟”)和 O2(例如,“Junior 仍然很强壮”)时,它可以生成一个合理的假设(例如,“Junior 一直在池塘里和她的朋友们一起游泳”),这可能解释了为什么我们认为 Junior 仍然很强壮。

为什么这是 IBE 而不是推测?

让我们暂时从基础的 ML 模型中抽象出来,考虑一下人类如何执行这样的推理任务。首先,我们得到一个结果Junior 仍然很强壮,并且我们被告知案例是什么(即,Junior 是一只相对年长的海龟)。然后,从这些中,我们会尝试找出一个潜在的(上下文相关的)规则,以解释这个案例和结果。例如,我们可以推导出一只年老却仍然强壮的海龟

  1. 趋向于和朋友们玩耍 或

  2. 有良好的食欲 或

  3. 具有良好的生命体征

依此类推。

然后我们可以选择最符合我们判断的规则,并将其应用于“一个老海龟”的情况。这将允许我们假设 Junior 可能一直在和她的朋友们一起游泳

如前所述,从有限的观察中识别潜在规则表明了 IBE,而从这些规则中得出的结论则倾向于是一种较弱的推演形式。

我们作为人类理解到,当一个生物变老(无论是海龟还是人类)时,它们的活力往往会下降(可以说是)。这使我们能够生成相对“充满意义”的规则。然而,变压器模型无法做到这一点。它能做的,是通过归纳和推演来改善对最可能的单词组合的预测。模型没有基本理解,当 Junior 玩得开心时,她仍然很强壮。

实际上,有人甚至可以说推测自然语言推理的工作类似于 链式思维 提示。尽管如此,指令是以不同的方式呈现给变压器的。

希望所有这些例子突显出计算机科学所称的推测其实并不是推测。相反,它似乎是一种特定上下文的归纳变体。

架构:当代 ML 模型受限于归纳

状态最先进模型无法进行推测的第二个原因在于它们的架构。根据定义,机器学习模型是一个生成归纳的机器。这种倾向被它们所谓的归纳偏差进一步加强。

归纳偏差是机器学习中的一个重要概念,指的是模型对其应学习的函数类型所持有的固有假设或偏好。偏差通过限制可能的假设集来指导学习过程,从而提高学习的效率和准确性。

例如,决策树关注于层级结构和简单的决策边界。支持向量机旨在找到类别之间的宽边距。卷积神经网络强调图像中的平移不变性和层级特征学习。递归神经网络偏向于序列模式,贝叶斯网络建模概率关系,正则化线性模型通过惩罚大系数来偏好简单模型,而通用的变换器如 GPT-4 则以捕捉数据中的序列依赖性和关系为特征。这些偏差塑造了模型的行为及其对不同任务的适用性。它们还使得将学习成果从一个情境转移到另一个情境变得困难。

我们仍需的

好的,到目前为止,我们讨论了推理的基础知识,我们看到机器确实可以进行推理。它们能够执行演绎推理和归纳推理。然而,我们直观上所称为“思考”的过程是由归纳推理促进的,由于混淆和架构问题,它依然难以捉摸。

那么,我们还需要什么?

我们如何构建能够执行归纳推理的系统?

首先,我们需要能够准确地定义什么是归纳推理,并描述它是如何工作的。遗憾的是,这方面的研究不多。尤其是当涉及到归纳推理如何与演绎推理和归纳推理相关联时,或者它如何被机器操作化时。学者们唯一一致的观点是,归纳推理发生在演绎推理和归纳推理之前。

那么,什么是归纳推理?

归纳推理并不是一个单一的构造。我个人遇到过大约 10 种不同类型,具体取决于它们所涉及的科学领域。即使是引入归纳推理概念的哲学家查尔斯·皮尔斯,也并没有以一致的方式提及它。

然而,有三种主要类型可以描述归纳推理所服务的基本功能。确切的功能及其形成过程过于复杂,无法在此文中详细讨论。所以,以下是简要说明。

首先,我们有最直接的归纳推理类型——解释性。就是我们迄今为止讨论的那种。使用它时,我们从一个观察(结果)和一个易于识别的规则开始。然后,这两者的组合使我们能够对情况做出猜测。这在水煮沸的例子中得到了很好的说明。

然后,我们有了创新性推断——一种允许我们从一个(期望的)结果推理到一对案例规则的推断。也就是说,我们只知道我们想要创造什么结果,然后我们需要逐步定义一个案例-规则配对,这样才能实现所述结果。这种类型的推断通常用于生成新颖的想法。

最后,我认为我们有了最有趣的一种推断——操控性推断。我们在唯一知道结果(期望的或其他)部分的情况下使用它。此外,这个结果“存在”的背景由多个隐藏的相互依赖关系定义。因此,不能立刻开始寻找/生成合适的案例-规则配对。相反,我们需要更好地理解结果以及它与环境的关系,以便减少不确定性。

这就是所谓的思维装置/认识中介发挥作用的地方。这可以采取例如基本草图、原型或 3D 模型的形式,用于增强我们对问题的理解。通过在目标环境中操控这个中介,我们能更深入地理解背景。因此,我们能够更好地探索规则案例的潜在组合。此外,它还使我们能够建立关联,帮助将知识从一个领域转移到另一个领域。这种思维的简化版本通常在立体几何中应用。

正如我所说,仍需要做大量工作来解释这些推断类型之间的关系及其与其他推理方法的相关性。然而,这项工作变得越来越关键,因为它有可能为不同领域之间的洞察力转移提供宝贵的见解。特别是在我们看到该领域对推理的新兴趣的背景下——无论是通过 IBE、“通过模拟和例子进行推理”,还是系统 1 和系统 2 思维。

在所有这些情况中,理解如何区分机器可以进行的不同类型的推理似乎尤为重要。因为,确实,机器是可以进行推理的。它们只是无法进行全方位的推理。

关于 IBE 的其他有趣工作可以在这篇论文中找到(他们确实将推断等同于 IBE)。

在生成式 AI 时代发展数据职业

原文:towardsdatascience.com/one-blogpost-comment-growing-the-data-career-in-the-generative-ai-era-9ef1242d3019

提高对学习三个基本数据概念的认识

Marina TosicTowards Data Science Marina Tosic

·发表于Towards Data Science ·阅读时间 5 分钟·2023 年 7 月 6 日

--

图片由Brett Jordan拍摄,来源于Unsplash

作为一名数据专业人士,我对生成式 AI 领域的所有最新发展感到惊讶。

虽然一些人称其为炒作,并愿意迅速将其视为另一个技术趋势,但其他人则坚信它是一个游戏规则的改变者。

无论你支持哪个观点,都很难忽视生成式 AI 可能为未来教育和工作场所带来的变革性可能性。

为了支持这一说法,只需提到哈佛大学将在今年秋季(2023 年秋季)将 AI 聊天机器人引入课堂,以接近一对一的师生比例。学生们将使用哈佛开发的聊天机器人来引导他们找到解决方案,而不是直接提供简单的答案。

对我来说,这明显表明哈佛大学正在引发一波改变新一代学习和工作方式的变革。

这意味着,生成式 AI 不仅仅是一个过时的趋势,我们需要开始寻找适应这一技术的新方法。

尽管我对生成式 AI 持积极看法,但我从未有过如此强烈的FOMO

换句话说,尽管我在过去的 12 年中经历了各种数据角色,并获得了机器学习概念的知识,但我无法跟上生成式 AI 领域的新发展。

新的术语、提示工程的概念、新的大型语言模型的开发、在其上构建的大量应用和解决方案、新的电子学习课程以及关于这个话题的大量帖子——这一切 都是 令人难以承受的。

此外,我无法摆脱这样的不安感——感觉我的一些数据技能现在已经,嗯,过时了。

我的业务同事用几个击键就取代我辛苦获得的查询技能的想法是令人害怕的。

然而,经过深思,我不得不承认,虽然有些(但仅有一些)技能会被取代,我并不介意。每周执行几次临时查询以回答相同的重复业务问题是我从不喜欢做的事情。

我意识到,“我”在数据仓库存储的数据和业务洞察生成之间只是拖慢了决策过程。

我还意识到,这种过渡,即我的替代,并不会一夜之间发生。

首先,当前的开发环境需要调整,即需要变得更加“业务用户友好”,而不是“开发者友好”。

其次,业务用户需要获得对“中心背后”是什么的技术理解。从自然文本条目生成分析洞察的自由也带来了同样的问题。

像慢速洞察生成、洞察生成不准确、在没有新输入(新数据源)的情况下丰富洞察,以及洞察质量检查的技术过程等问题仍将存在。

而且仍然需要有人来处理和“修复”这些问题,为业务用户服务。

换句话说,生成式 AI 无法轻易取代基础性的数据知识。

那么,我所说的“基础性”数据知识是什么意思?

为了支持我对上述问题的回答,归结为三个核心概念:

  1. 构建数据架构

论点: 技术知识和对如何在特定行业中设计适当数据架构的理解至关重要。

让我以金融科技行业为例。

在金融科技行业,构建数据平台时需要考虑严格的规定,即PCI 数据安全标准。在这些标准之上,有时还会有市场基础的标准。

例如,在瑞士,还有FINMA规定,需要考虑这些规定以使你的数据平台以及数据架构合规。

当然,规定可能会发生变化,这意味着数据架构需要跟随这些变化。这对生成式 AI 提出了真正的挑战。

生成式 AI 可以在一定程度上支持架构设计和开发。

但是,它无法在法规变化的行业中设计可定制的架构解决方案。

如果没有在类似的历史实例上进行训练,它不具备应用特定架构适应性的能力。

2. 数据质量管理

论点:垃圾进 — 垃圾出” 的说法将始终有效,所有在数据领域工作的人都清楚低质量数据输入的成本。

使用生成式 AI 解决方案,低质量输出的成本更高。

例如,我需要参考一下我在《卫报》阅读的 最近一篇文章。那是一篇关于一位律师使用 ChatGPT 提供类似的过往法律案件例子的文章。他想要支持他的论点,即客户对航空公司提起的诉讼不应被驳回。

我想你可以想象故事的经过:当航空公司的律师检查了所引用的裁决和法律引文时,他们发现这些引用都不存在。简而言之,ChatGPT 正在 幻觉

从这篇文章中得出的结论是,低质量的数据输出可能会使你失去业务,导致整个项目停摆,或者丧失客户和声誉。

因此,数据专业人员将更加忙碌于管理数据输入和输出的质量。

3. 数据隐私与安全

论点: 作为数据专业人员,你应该了解 SQL 注入数据库安全 的概念。

随着生成式 AI 的发展以及简单提示的使用,数据仓库攻击和数据泄露场景比以往任何时候都更容易发生。

提示注入 的危险——例如,某人可能通过一次文本输入就能够删除整个数据库或检索机密记录——需要成为数据安全的核心问题。

这意味着数据和 IT 专业人员将继续在保护和安全数据方面发挥关键作用。

总结: 具备基础数据概念知识的数据专业人员 将会作为“常量”留在职场,高效管理数据,识别问题,并优化解决方案,以确保合规、安全和可靠。

这是生成式 AI 无法轻易替代的部分。

所以,如果你是一位年轻的专业人士,寻求在生成式 AI 时代如何发展数据职业的建议,那么首先要学习上述核心概念。

相信我:投资时间和资源来获得基础数据知识,将在你的数据职业生涯中带来长期的回报。

生成式 AI 将提升你在这些领域的学习曲线和工作表现,但它只能帮助你达到一定程度。那些“重要”的工作仍然取决于你和你的知识。

One Hot 编码

原文:towardsdatascience.com/one-hot-encoding-scikit-vs-pandas-2133775567b8?source=collection_archive---------2-----------------------#2023-03-13

Scikit Learn 还是 Pandas?

安德拉斯·盖费斯数据科学趋势 安德拉斯·盖费斯

·

关注 发布于 数据科学趋势 ·8 分钟阅读·2023 年 3 月 13 日

--

One hot 编码是一种流行的表示分类数据的方法(所有图片均由作者提供)

摘要

sklearn.preprocessing.OneHotEncoderpandas.get_dummies都是流行的选择(实际上是唯一的选择,除非你想自己实现)来执行独热编码。大多数科学家推荐使用scikit,因为它使用其 fit/transform 范式,提供了一种内置机制来学习训练集中的所有可能类别,并将其应用于验证或实际输入数据。因此,这种方法将防止在验证或实际输入数据中不包含所有类别或类别顺序不同引发的错误。

在本文中,我将争论,这场竞争没有明确的赢家。对于使用pandas DataFrame 的数据科学家来说,使用原生的pandas get_dummies函数有明显的好处,而且有一种非常简单的方法可以避免上述提到的问题。

介绍

什么是独热编码

如果你已经知道这些内容,可以安全地跳过这一部分。

独热编码(以下简称OHE)是一种将分类数据编码为数值数据的技术。它主要用于机器学习应用。例如,假设你正在构建一个预测动物体重的模型。你的一个输入将是动物的类型,即猫/狗/鹦鹉。这是一个字符串值,因此像线性回归这样的模型无法处理它。

第一个想到的方法是给动物赋予整数标签,并用相应的整数表示替换每个字符串。但是,如果这样做,你会引入一些人为的排序(例如,鹦鹉对“动物”权重的影响是猫的三倍)。相反,OHE 为每种动物创建一个新的输入变量(即列),并根据动物是否是选定的那个,将该变量设置为 1 或 0。示例:

独热编码(所有图片由作者提供)

在这种分离之后,你的线性模型可以独立地为这些新列分配权重。实际上,你并不需要 3 列来表示这 3 种动物。你可以选择其中任何一列来丢弃。换句话说,如果它既不是狗也不是猫,那它就只能是鹦鹉。

ScikitPandas

scikit-learnpandas都提供了执行此操作的方法,数据科学家之间关于使用哪个方法的争论已经有很长的历史。如果你搜索一下,会找到很多相关文章。我重新讨论这个话题的原因是这两个库都在不断发展,有一些新的功能在决策时值得考虑。

文章范围

在编码时,可以指定几个选项,比如是否使用稀疏或密集的数据表示,或者是否保留所有新列或删除其中之一。这两个库都支持许多这样的功能,但在这篇文章中我不会关注它们。本文的重点是类别的处理,如下所述:

如果你进行训练/测试拆分(无论是手动还是使用sklearn.model_selection.train_test_split自动化),可能会出现训练数据集中不包含任何鹦鹉的情况。从理论上讲,这不一定是一个问题,如果某些类别缺失,你仍然可以进行预测,只是预测可能不够准确。但如果你的代码没有为这种差异做好准备,那么由于拟合数据中的列与用于预测的数据的列不一致,你的代码会出错。

在这篇文章中,我将关注以下几点:

  • 如何告诉 OHE 所有类别的集合,并确保编码一致地应用于训练/测试/验证/实际数据?

  • 如何将编码应用于 pandas DataFrame?

  • 如何在 scikit 管道中集成编码器?

Scikit-learn

通常的做法是使用sklearn.preprocessing.OneHotEncoder,因为通过其 fit/transform 范式,你可以使用训练数据集来“教会”类别,并将其应用于你的实际输入数据。

主要步骤如下:

其中 X_train 是你的训练输入数据,real_input 是你希望应用模型的真实输入数据(真是个惊喜!)。

如果你“幸运”,那么所有可能的类别都会出现在 X_train 中,编码器对象学习这些类别及其对应的映射,并会为真实输入生成正确的列和正确的列顺序。我们需要注意的是,sklearn.preprocessing.OneHotEncoder产生的是一个 numpy 数组,所以列的顺序很重要。

但你不应该假设自己总是会幸运。例如,如果你使用交叉验证来随机重复地将数据拆分为训练和测试部分,你可能会发现实际的训练数据缺少一些类别。这会导致错误,因为你无法转换测试集中的数据。

sklearn 为这种情况提供的解决方案是明确地向 OneHotEncoder 对象提供可能的类别,如下所示:

你需要在类别参数中提供一个列表的列表,以便为每个输入列指定类别。

使用 scikit 的另一个常见步骤是进行原始 numpy 数组和 pandas DataFrame 之间的转换。你可以使用sklearn.compose.make_column_transformer来实现,或者手动实现,使用 OneHotEncoder 的.get_feature_names_out()方法来获取新特征的列名。让我们来看一下这两种方法的示例。我将添加另一列,Color,以使示例更加信息丰富。

指定输入和编码器

列转换器方法

我们可以看到列转换器完成了部分工作,但如果我们想使用 DataFrames,还需要做额外的工作。我也不太喜欢这些列名,但除了手动后处理,没有其他方法可以调整它们。注意,列是为所有可能的类别创建的,而不仅仅是那些出现在输入中的类别。

手动方法

我称之为手动方法,因为我们直接使用 OneHotEncoder 对象,并自己处理选择和追加列的操作。

我们不得不做一点额外的手动工作,但列名更友好。此外,在较新的 scikit 版本(1.3 及以上)中,我们可以微调这些名称。

管道

一个 scikit 管道 是一种方便的方式来顺序应用一系列转换。你可以使用它来组装几个步骤,这些步骤可以一起进行交叉验证,同时设置不同的参数。

手动/原始方法通常不适合包含在管道中,因为需要额外的步骤来选择和添加列。而列转换器方法则适用于管道。我们所做的额外步骤仅仅是将 numpy 数组转换为 DataFrame,这对管道来说不是必需的。

Pandas

pandas.get_dummies 函数不遵循 fit/transform 模型,也没有明确的输入参数来指定可用的类别。因此,可以得出结论,它不适合这个任务。然而,这个结论并不正确。

Pandas 本身支持通过 pandas.CategoricalDtype 处理分类数据。你需要做好功课,并正确设置列的类别。一旦一致完成这些操作,你就不再需要拟合步骤了。

使用分类类型有额外的好处,例如减少存储空间或检查拼写错误。让我们看看这是如何做到的:

现在我们需要做的就是调用 get_dummies 函数。

正如我们所看到的,在类别正确设置之后,不需要额外的工作就可以获得一个漂亮的 DataFrame。实际上,我在上面有点作弊:默认情况下,get_dummies 会转换所有具有对象、字符串或类别数据类型的列。如果这不是我们想要的,我们可以通过使用 get_dummies 的 columns 参数显式指定要转换的列列表:

我们在上面提到了 scikit 管道。为了使变换器适用于管道,它必须实现 fit 和 transform 方法,而 get_dummies 函数显然没有做到这一点。幸运的是,为此任务创建一个自定义变换器非常简单:

现在我们可以像使用其他 scikit 变换器一样使用我们新的类,我们甚至可以将其嵌入到管道中。

在编写这个变换器时,我们假设相关列已经具有分类数据类型。但是,只需添加几行代码到 GetDummiesTransformer 中即可在 init 函数中允许指定列。

结论

正如我们所见,明确指定 scikit OneHotEncoder 和 pandas get_dummies 方法的可用类别是可能的,也非常推荐。(记住:明确优于隐含!)这意味着这两种方法都非常适合这个任务,所以选择哪个方法是个人偏好。对于 scikit,明确的类别设置是通过将参数传递给 OneHotEncoder 类的构造函数实现的,而对于 pandas,我们必须设置分类数据类型。

  • 使用“原始”版本的 OneHotEncoder(即没有列变换器)需要最多的手动调整,我在实际中仅在非常少见的情况下会使用这种方法。

  • 如果你的过程依赖于 scikit 管道(这有许多优点),那么使用 scikit OneHotEncoder 和列变换器似乎是最自然的选择。

  • 如果你喜欢逐步处理数据,从一个 DataFrame 到另一个 DataFrame(这在探索阶段可能是一个不错的选择),那么我一定会选择pandas.get_dummies方法。

就这样,希望你从我的帖子中学到了东西。像往常一样:点赞、订阅、分享、评论!

一步使决策树产生更好的结果

原文:towardsdatascience.com/one-step-to-make-decision-trees-produce-better-results-b0ccd6738200?source=collection_archive---------10-----------------------#2023-11-23

背景、实施和模型改进

Gabe VerzinoTowards Data Science Gabe Verzino

·

关注 发表在 Towards Data Science ·7 分钟阅读·2023 年 11 月 23 日

--

在树木中(作者提供的照片)

决策树(DT)被放弃得太快了。

就像这样发生的:

DT 已训练。自然过拟合出现。超参数调整(令人不满意)。最后,树被替换为随机森林。

尽管这可能是性能上的快速胜利,但这种替代更注重“黑匣子”算法。这并不理想。只有决策树能产生直观的结果,为业务领导提供比较权衡的能力,并在流程改进中起到关键作用。

如果你无法理解甚至解释某件事情,那么它就不会进入生产环节。在那些即使小的失败也会带来极端风险的行业中,这一点尤为真实,比如医疗保健领域。

(旁注:人们经常问“随机森林生成特征重要性,这难道不解释了哪些特征是重要的吗?”并不完全如此。特征重要性几乎立即被解释为 因果 驱动因素(例如,具有与目标的依赖性的特征),但它们只是 模型 驱动因素。虽然在这方面对技术人员有所帮助,但特征重要性通常是:(1)在弱模型中无用(2)在具有高基数的特征中膨胀,以及(3)偏向于相关特征。这是另一条完全不同的探索路径,但基本上就是这样。)

决策树如何做出决策

坚持使用决策树将保留您有效沟通结果的能力,但如何使它们性能卓越呢?超参数调整只能帮助到一定程度。无论如何,都应该进行深思熟虑的特征工程。

事实证明,特征数据的特定结构可能使其更好地适应底层的决策树算法,从而使决策树能够做出更好的决策。

在底层,决策树通过在您提供的所有数据中创建正交决策边界(垂直分割)来分离类别。它以一种贪婪的算法方式进行这一操作 —— 首先选择最佳分割的特征,然后转向其他特征中不那么优化的分割。

我们可以直观地检查我们的特征以寻找正交的决策边界。让我们查看以下公开可用的乳腺癌数据集中的特征。在下面的顶部图中,绘制“最差周长”和“平均周长”可以产生良好的正交决策边界,可以很好地分离恶性和良性类别。因此,这些特征将是 DT 模型中的很好的选择。

作者提供的图片

上图显示的底部显示了“平均面积”和“平均周长”,DT 生成了正交决策边界(因其固有性质),但这些是不必要复杂的。也许,对角线分隔在这里会更好,但这不是 DT 分类器的分割方式。此外,DT 对训练数据中甚至是小变化(如异常值)非常敏感,这些变化已知会产生完全不同的树结构。

为了适应决策树的这种独特和基础机制 —— 并最终改善性能和泛化能力 —— 可以应用主成分分析(PCA)。

PCA 在两个重要方面提升了 DT 的性能:

(1) 将关键特征定向在一起(解释最大方差的特征)

(2) 减少特征空间

实际上,PCA + DT 过程自然地展现了上述顶部图中您看到的“最差周长”和“平均周长”特征。这两个是最具预测性的变量,毫不奇怪地具有出色的正交决策边界。

实施过程

请记住,PCA 适用于连续数据。乳腺癌数据集完全由连续变量组成。(另一方面的注释:我看到 PCA 被用于分类变量,不建议这样做。名义级别没有隐含的距离,序数级别并不总是等距离的,强制在离散特征上进行距离表示通常将变量重构为毫无意义的东西。另一个时间的另一个切入点。)

让我们开始下载所需的软件包,并将我们的乳腺癌数据集转换为特征X和目标变量y

import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.decomposition import PCA
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score
import matplotlib.pyplot as plt
import seaborn as sns

# Load the Breast Cancer dataset
data = load_breast_cancer()
X = data.data
y = data.target

可以调用此数据集的数据框架头部进行检查。

cancer = load_breast_cancer()
df = pd.DataFrame(np.c_[cancer['data'], cancer['target']],
                  columns= np.append(cancer['feature_names'], ['target']))
df.head()

作者的图片

首先,在没有 PCA 的情况下训练 DecisionTreeClassifier,并收集这些预测(original_predictions)。

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Fit a Decision Tree Classifier on the non-PCA-transformed dataset
original_tree = DecisionTreeClassifier(random_state=42)
original_tree.fit(X_train, y_train)

# Predictions on the original dataset
original_predictions = original_tree.predict(X_test)

现在,应用 PCA 来选择能够解释训练集中大部分方差的最小维数。而不是任意选择这个维数,可以使用“拐点法”来确定能够解释 99%方差的维数(如下所示硬编码)。

# Finding the optimal number of PCA components using the elbow method
pca = PCA()
pca.fit(X_train)

explained_variance = pca.explained_variance_ratio_
cumulative_explained_variance = np.cumsum(explained_variance)

# Plot explained variance
plt.plot(range(1, len(cumulative_explained_variance) + 1), cumulative_explained_variance, marker='o')
plt.xlabel('Number of Components')
plt.ylabel('Cumulative Explained Variance')
plt.title('PCA Explained Variance')
plt.grid()
plt.show()

# Determine the optimal number of components (elbow point)
optimal_num_components = np.where(cumulative_explained_variance >= 0.99999)[0][0] + 1

print(f"Optimal number of PCA components: {optimal_num_components}")

基于图表形成“拐点”的视觉观察,发现 6 个 PCA 成分解释了训练集方差的 99%。

作者的图片

现在在训练集上应用 PCA 来捕获 6 个主成分。您可以使用奇异值分解(SVD)进行此操作,这是一种标准的矩阵分解技术(此处不涉及的过程)。与以前一样,在 PCA 转换的训练集上训练 DecisionTreeClassifier,并收集这些预测(pca_predictions)。

# Apply PCA with the optimal number of components
pca = PCA(n_components=optimal_num_components, svd_solver="full")
X_train_pca = pca.fit_transform(X_train)
X_test_pca = pca.transform(X_test)

# Fit a Decision Tree Classifier on the PCA-transformed dataset
pca_tree = DecisionTreeClassifier(random_state=42)
pca_tree.fit(X_train_pca, y_train)

# Predictions on the PCA-transformed dataset
pca_predictions = pca_tree.predict(X_test_pca)
# Confusion matrix
pca_cm = confusion_matrix(y_test, pca_predictions)

# Precision and Recall scores for the original dataset
original_precision = precision_score(y_test, original_predictions, average=’weighted’)
original_recall = recall_score(y_test, original_predictions, average='weighted')
original_accuracy = accuracy_score(y_test, original_predictions)

# Precision and Recall scores
pca_precision = precision_score(y_test, pca_predictions)
pca_recall = recall_score(y_test, pca_predictions)
pca_accuracy = accuracy_score(y_test, pca_predictions)

# Output precision and recall scores
print(f"Original Dataset - Precision: {original_precision:.4f}, Recall: {original_recall:.4f}, Accuracy: {original_accuracy:.4f}")
print(f"PCA-Transformed Dataset - Precision: {pca_precision:.4f}, Recall: {pca_recall:.4f}, Accuracy: {pca_accuracy:.4f}")

现在我们可以比较我们的原始预测(未经 PCA 转换)和 pca 预测(经 PCA 转换),观察我们评估指标(准确率、精确度和召回率)的任何相对改进。

与原始的决策树训练数据相比,当我们首先对数据集进行 PCA 转换,然后进行决策树训练时,我们在各方面都有所改进:

我们可以绘制混淆矩阵,显示两个决策树在恶性和良性肿瘤分类改进方面的相对改进。

# Plot the confusion matrices
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
sns.heatmap(original_cm, annot=True, fmt="d", cmap="Blues", xticklabels=data.target_names, yticklabels=data.target_names)
plt.title("Original Decision Tree Confusion Matrix\nPrecision: {:.2f}, Recall: {:.2f}".format(original_precision, original_recall))
plt.xlabel("Predicted")
plt.ylabel("True")

plt.subplot(1, 2, 2)
sns.heatmap(pca_cm, annot=True, fmt="d", cmap="Blues", xticklabels=data.target_names, yticklabels=data.target_names)
plt.title("Decision Tree with PCA Confusion Matrix\nPrecision: {:.2f}, Recall: {:.2f}".format(pca_precision, pca_recall))
plt.xlabel("Predicted")
plt.ylabel("True")

plt.tight_layout()
plt.show()

作者的图片

最后,识别生成 6 个主成分所使用的我们原始特征是很有价值的。从技术上讲,PCA 生成新的特征,这些特征是原始特征的线性组合。这些新特征彼此正交,并按解释的方差排序。但是,调用components_attribute可以识别用于创建这些组件的特征。

# Get the explained variance ratio of each principal component
explained_variance_ratio = pca.explained_variance_ratio_

# Get the PCA components
pca_components = pca.components_

# Create a DataFrame to display the contributions of original features to each principal component
df_components = pd.DataFrame(data=pca_components, columns=data.feature_names)

# Print the top features contributing to each principal component
for i in range(optimal_num_components):
    print(f"Top features for PC{i+1}:")
    sorted_features = df_components.iloc[i].abs().sort_values(ascending=False)
    print(sorted_features.head())
    print("\nExplained Variance Ratio:", explained_variance_ratio[i])
    print("=" * 50)

因此,对于我们选择的 6 个主成分,模型使用以下 5 个特征创建了这些主成分:

作者的图片

结论

决策树往往被过早放弃,转而使用更高性能的算法。虽然最高性能很重要,但这可能不是最好的选择——这个决定最终取决于你的利益相关者需求以及解释模型为什么会建议特定结果(参见“可解释人工智能”)。

与其寻求最先进的技术算法,不如通过深思熟虑的特征工程和主成分分析来优化数据准备,从而给决策树提供最佳机会,以展示其直观的决策能力。

感谢阅读。很高兴在LinkedIn上与任何人建立联系!如果你想分享你当前面临的有趣的数据科学挑战,请留下评论或发私信,我会很乐意尝试探讨/撰写相关内容。

我最近的文章:

调试逻辑回归错误——它们的含义及修复方法

使用贝叶斯网络预测医院辅助服务量

为什么平衡类别被过分炒作

只有在你知道如何独立完成任务时才使用 LLMs

原文:towardsdatascience.com/only-use-llms-if-you-know-how-to-do-the-task-on-your-own-0d56e0d07572

否则,你可能会遭遇无声的错误或严重的后果

Soner YıldırımTowards Data Science Soner Yıldırım

·发表于 Towards Data Science ·阅读时间 5 分钟·2023 年 10 月 25 日

--

(图片由作者使用 Midjourney 创建)

对于我们大多数人(或所有人)来说,LLMs 是神秘的盒子,能够令人惊讶地快速完成复杂的事情。只要它们给我们所需的,我们通常不关心“如何”部分。

ChatGPT 和其他大型语言模型无疑是生产力的提升者。它们可以轻松处理各种任务,否则这些任务将会枯燥且耗时。

然而,我们不能完全依赖它们。例如,在数据分析方面,我们如何确保 ChatGPT 对数据的见解是准确的?是的,它知道 Pandas 这个流行的数据分析库,但如果它犯了错误呢?或者,如果它部分完成了任务而未能完成剩余部分会怎样?

互补 ChatGPT 的最佳解决方案就是你自己。你需要知道如何独立完成任务,这样:

  1. 你可以确保 ChatGPT 的解决方案是正确的。

  2. 当 ChatGPT 无法执行或不知道如何执行时,你可以替代它。

在本文中,我将展示三个示例来支持我之前提到的两个主张。

示例 1:使用 Pandas 进行数据清理

我有一个包含二手车价格和其他一些属性的数据集。这是一个有点混乱的数据集,需要大量清理。我使用了 ChatGPT 高级数据分析(ADA)插件来完成这项任务。

大部分情况下,它表现良好。然而,有一个非常特定的操作 ChatGPT ADA 未能完成。因此,如果我自己不知道如何做,这项任务将无法完成。

让我首先解释 ChatGPT ADA 无法处理的部分。

可能由于数据输入错误,make(即汽车品牌或制造商)列中的一些行包含年份值。当我调查这些行时,我发现 make 值被记录在 model 列中。为了修复这一问题,需要从 model 列中提取 make 部分,并用来替换 make 列中的年份值。

这是如何修复它的:

(图片由作者提供)

我向 ChatGPT ADA 询问了几次,但它未能提供修复此问题的方法。它的最佳尝试是用“未知”替换这些年份值。

然后,我详细编写了以下提示,以便让 ChatGPT ADA 给我提供解决任务所需的 Python 代码:

对于实际是年份值的 make 列中的行,model 列包括正确的 make 值。对于这些行,从 model 列中提取第一个词,并用它替换 make 列中的年份值。

这是 ChatGPT ADA 为这个特定问题编写的代码:

# Identify rows where 'make' is a year value
is_year = data['make'].str.isnumeric() & data['make'].str.len() == 4

# For these rows, split the 'model' column and get the first word
correct_make = data.loc[is_year, 'model'].str.split().str.get(0)

# Update the 'make' column with the correct values
data.loc[is_year, 'make'] = correct_make

虽然这不是最优解,但它完成了任务。

示例 2:替换缺失值

我试图清理的数据集中有缺失值(用“其他”、“未知”或 None 表示)在一列中。我要求 ChatGPT 用该列中最频繁的值替换它们。

(图片由作者提供)

这是 ChatGPT 的回复:

most_common_cylinder = data['cylinders'].mode()[0]
data['cylinders'] = data['cylinders'].apply(lambda x: most_common_cylinder if "cylinders" not in x else x)

这个方法是正确的,因为它用最常见的值替换了不包含“cylinders”的值。然而,它包含了 apply 函数的使用,这在处理大数据集时是不建议的。apply 函数不是向量化操作,可能成为性能瓶颈。

更好的方法是使用以下向量化操作:

df.loc[~df["cylinders"].str.contains("cylinders"), "cylinders"] = df["cylinders"].mode()[0]

如果我不知道 Pandas,我可能无法意识到 apply 函数的使用可能会导致性能问题,并寻找替代解决方案。

示例 3:以更符合 Python 风格的方式编写单元测试

我想测试 ChatGPT 是否能改进单元测试或使其更符合 Python 风格。

我编写了以下单元测试,实际上非常简单:

def test_query(submission):
    query = submission.query
    assert query.lower().count("where") == 1

当我要求 ChatGPT 改进它时,我期望的更新如下:

def test_query(submission):
    assert submission.query.lower().count("where") == 1

第二个版本消除了创建不必要的中间变量 query

在第一次尝试中,ChatGPT 编写了如下单元测试:

# first solution
def test_query(submission):
    query = submission.query
    assert query.count("where", flags=re.IGNORECASE) == 1

这是错误的。count 方法没有 flags 参数。另外,这比我的第一次尝试更简单(或更符合 Python 风格)吗?

第二次尝试是正确的,但仍然没有更简单。

# second solution
import re

def test_query(submission):
    query = submission.query
    assert len(re.findall(r'where', query, flags=re.IGNORECASE)) == 1

然后,我告诉 ChatGPT 这不比我的解决方案更简单,并建议使用以下方法(这正是我所考虑的):

def test_query(submission):
    assert submission.query.lower().count("where") == 1

ChatGPT 批准了我的新建议,接受了它更简洁且符合 Python 风格。

最后的思考

我在这篇文章中展示的示例用例并没有降低 ChatGPT 或其他 LLM 的实用性。我已经用它完成了许多不同的任务,并得到了令人满意的结果。

我想强调的是,他们可能会犯错误。这些错误有些是明显的,有些则可能是隐性的。为了确保你获得准确的结果,留意 ChatGPT 的操作方式。我建议不要完全依赖你不了解的工具。你仍然可以用它来学习新工具,但在进行任何重要操作之前,一定要进行测试。

如果你喜欢这篇文章,请记得点赞和评论,以帮助我获得更多的支持。 关注我 以获取更多关于 Python、数据科学、机器学习和人工智能的内容。

感谢阅读。如果你有任何反馈,请告诉我。

ONNX:用于可互操作深度学习模型的标准

原文:towardsdatascience.com/onnx-the-standard-for-interoperable-deep-learning-models-a47dfbdf9a09

图片由Jonny CaspariUnsplash上提供

了解使用 ONNX 标准在框架和硬件平台之间部署模型的好处

Marcello PolitiTowards Data Science Marcello Politi

·发表在Towards Data Science ·5 分钟阅读·2023 年 1 月 24 日

--

我第一次听说 ONNX 是在 INRIA 实习期间。我当时在用 Julia 语言开发神经网络剪枝算法。那时还没有很多预训练模型可以使用,因此利用 ONNX 导入其他语言和框架开发的模型可能是一个解决方案。

在本文中,我想介绍 ONNX,并通过一个实际示例来解释其巨大的潜力。

ONNX 是什么?

ONNX,即开放神经网络交换,是一个用于表示深度学习模型的开源标准。它由 Facebook 和 Microsoft 开发,旨在使研究人员和工程师能够更轻松地在不同的深度学习框架和硬件平台之间迁移模型。

ONNX 的一个主要优势是它允许模型轻松地从一个框架(如 PyTorch)导出,并导入到另一个框架(如 TensorFlow)。这对于那些想尝试不同框架来训练和部署模型的研究人员,或者需要在不同硬件平台上部署模型的工程师尤其有用。

框架互操作性(图片由作者提供)

ONNX 还提供了一套工具,用于优化和量化模型,这有助于减少模型的内存和计算需求。这对于在边缘设备和其他资源受限环境中部署模型尤其有用。

另一个 ONNX 的重要特点是它得到了广泛的公司和组织的支持。这不仅包括 Facebook 和 Microsoft,还包括像 Amazon、NVIDIA 和 Intel 这样的公司。这种广泛的支持确保了 ONNX 将继续得到积极开发和维护,使其成为一个稳健和稳定的深度学习模型表示标准。

ONNX Runtime

ONNX Runtime 是一个开源推断引擎,用于执行 ONNX(开放神经网络交换)模型。它被设计为高性能且轻量级,使其非常适合在各种硬件平台上部署,包括边缘设备、服务器和云服务。

ONNX Runtime 提供了 C++ API、C# API 和 Python API 来执行 ONNX 模型。它还支持多种后端,包括 CUDA 和 OpenCL,这使得它可以在各种硬件平台上运行,如 NVIDIA GPUs 和 Intel CPUs。

ONNX Runtime 非常有用,因为你可以在任何硬件上使用模型进行推断,无论你使用的是 CPU、GPU、FPGA 还是其他设备,而无需实际重写代码!

ONNX Runtime(图片来源于作者)

ONNX Runtime 的主要优点之一是其性能。它使用多种技术,如即时编译(JIT)、内核融合和子图分区来优化模型性能。它还支持线程池和节点间通信进行分布式部署,使其成为大规模部署的合适选择。

我将在未来的文章中解释所有这些高级功能!

ONNX Runtime 还支持多种模型,包括传统的机器学习模型和深度学习模型。这使得它成为一个多功能的推断引擎,可以用于从计算机视觉和自然语言处理到语音识别和自动驾驶等各种应用。

让我们开始编码吧!

现在让我们来看一个示例,我们将使用经典的scikit-learn创建一个机器学习模型,然后这个模型转换为 ONNX 格式,以便我们可以与 ONNX Runtime 一起使用

首先,我们导入必要的库,将模型拉入 sklearn 并导出为经典的 pickle 格式。我们将使用鸢尾花数据集。

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
import joblib

#import data
iris = load_iris()
x,y = iris.data, iris.target
x_train, x_test, y_train, y_test = train_test_split(x, y)

#train and save model
clr = RandomForestClassifier()
clr.fit(x_train, y_train)
joblib.dump(clr, 'model.pkl', compress = 9)

现在我们已经训练并保存了模型,我们可以重新导入它并将其转换为 ONNX 模型。每个框架都有其自己的转换库。因此,如果你是在 PyTorch 或 TensorFlow 中开发的模型,你需要使用其他库。在这种情况下,库叫做skl2onnx

所以我们导入了必要的库。

%%capture
!pip install skl2onnx
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
import joblib

现在我们终于可以进行转换了。我们应该指定inital_type,然后可以创建一个名为model.onnx的文件,用于保存 onnx 模型。

clr = joblib.load('model.pkl')
initial_type = [('float_input', FloatTensorType([None, 4]))]
onx = convert_sklearn(clr, initial_types = initial_type)
with open('model.onnx' , 'wb') as f:
  f.write(onx.SerializeToString()) 

现在我们已经有了 ONNX 格式的模型,我们可以导入它,并在一些数据上使用它进行推理。

然后我们安装 ONNX Runtime。

%%capture
!pip install onnxruntime
import onnxruntime as rt
import numpy as np

现在我们创建数据,并导入模型,从而创建一个会话。我们指定输入和输出名称(标签),然后在数据上运行会话!

data = np.array([[5.4, 6.3, 2.6, 7.4], [3.4, 6.2, 7.4, 2.3],[5.2, 6.4, 4.2,5.6]])

sess = rt.InferenceSession('model.onnx')
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
pred_onx = sess.run([label_name], {input_name: data.astype(np.float32)})[0]
print(pred_onx)

好吧,你通过利用 ONNX Runtime 得到了结果。这只需要几个简单的命令!

这只是对 ONNX 的一个介绍,你当然可以做更多,但我希望你发现这个例子有用。

最后的想法

ONNX 是一个开源标准,它使得在不同框架和硬件平台之间移动深度学习模型变得容易。它提供了一套优化和量化模型的工具,并且得到了众多公司和组织的支持。因此,ONNX 正在成为深度学习的重要标准,使得共享模型和跨平台部署变得简单。

结束

Marcello Politi

LinkedinTwitterCV

OpenAI API — ChatGPT 背后的模型介绍与实现

原文:towardsdatascience.com/openai-api-intro-11-practical-implementation-examples-of-the-models-behind-chatgpt-18601f68b51b

使用 ChatGPT 背后的模型的编程方法。

Farzad MahmoodinobarTowards Data Science Farzad Mahmoodinobar

·发表于 Towards Data Science ·19 分钟阅读·2023 年 11 月 7 日

--

图片由 Freddy Castro 提供,来源于 Unsplash

现如今 ChatGPT 不需要进一步介绍,在这篇文章中,我们将更深入地探讨如何通过官方的 OpenAI API(OpenAI 是 ChatGPT 背后的公司)以编程方式与 ChatGPT(如 GPT-4、GPT-3.5、DALL·E 等)背后的模型和引擎互动。机器学习科学家和工程师通常更喜欢使用 API 而不是图形用户界面,例如 ChatGPT,因为 API 提供了更高的灵活性和定制性,正如我们将在实现示例中看到的,这在商业环境中是必需的。

为了使用 OpenAI 的 API,我们将设置并激活一个 Python 虚拟环境(这是一项推荐但可选的步骤),安装 OpenAI Python 库,并开始实现 11 个实际示例。这些示例是我在众多探索过的示例中最喜欢的,将包括以下内容:

  1. 解释代码

  2. 图像生成

  3. 表情符号翻译(即我们提供文本描述,模型返回描述该文本的表情符号!)

  4. 语法错误纠正

  5. 机场代码提取器

  6. 命名实体提取器

  7. 机器翻译

  8. 情感分析

  9. 文本摘要

  10. 解析非结构化数据

  11. 编写 SQL 查询

我会在逐步讲解每项任务时提供更多细节,但现在既然我们知道了大纲,让我们开始吧!

1. 设置 Python

这一步只是为了创建一个虚拟环境,以便你可以将本文中创建和使用的内容与其他 Python 工作隔离开来。正如我在文章中提到的,使用虚拟环境是可选的,但通常是机器学习从业者和程序员推荐的最佳实践之一。有多种方法可以创建虚拟环境,下面是我使用的一种方法。我们将创建虚拟环境,然后激活它,再安装 OpenAI 的 Python 库(即使你决定跳过虚拟环境步骤,安装 OpenAI 的 Python 库仍然是必需的步骤)。

Mac 用户打开你的终端,Windows 用户打开命令提示符(说明见下方,以防你对这一步不熟悉),并跟随操作!

提示: 如何打开“终端”(在 Mac 上)或“命令提示符”(在 Windows 上)如下:

- Mac 用户: 前往“应用程序”文件夹或使用“Spotlight”搜索“终端”(Command + Space 打开“Spotlight”)

- Windows 用户: 在开始菜单中搜索“cmd”以打开“命令提示符”

1.1. 虚拟环境

打开终端或命令提示符,我们可以使用以下命令创建名为openai-env的虚拟环境:

python -m venv openai-env

一旦虚拟环境创建完成,我们可以使用下面的命令激活它:

source openai-env/bin/activate

现在我们在新创建并激活的虚拟环境中。接下来,我们将安装 OpenAI 的 Python 库。

1.2. 安装 OpenAI Python 库

请注意,虽然虚拟环境的使用是可选的,但安装 OpenAI Python 库是实现的必需步骤。以下命令安装最新的 OpenAI Python 库:

pip install — upgrade openai

2. 设置 API 密钥

使用 OpenAI API 需要设置一个 OpenAI 账户并获取 OpenAI API 密钥——我将带你完成这两个步骤。

设置 OpenAI 账户可以通过 OpenAI 注册网站 完成。创建 OpenAI 账户后,你可以访问 API 密钥页面 并点击“创建新的密钥”。你需要将其保存在安全的地方,并且通常不想与他人分享你的 API 密钥。

一旦设置好 API 密钥,我们将按如下方式导入它,将 YOUR_API_KEY 替换为你最近创建的 OpenAI API 密钥:

 # Import libraries
import os
import openai

# Pass API Key
os.environ['OPENAI_API_KEY'] = 'YOUR_API_KEY'
openai.api_key = os.getenv("OPENAI_API_KEY")

在准备工作完成后,我们终于可以专注于创建一个函数来调用 OpenAI 的 API,并开始实现示例的有趣部分!

3. 调用函数

在本节中,我们将创建一个函数来调用 OpenAI 的 API,我将其命名为magicWand!调用 OpenAI 的 API 需要传递一组变量(如下所述)。创建此函数将简化过程,以便我们不需要为每个示例重复相同的步骤。

对于所有示例,除了图像生成外,我们将使用 OpenAI 的聊天完成,并在请求中使用以下变量。目前无需了解这些变量的详细信息。我们将在逐步示例中学习它们的工作原理,但我已经提供了一个概述以供参考。

  • engine:识别将要使用的模型,例如gpt-4gpt-3.5-turbo

  • system_prompt:用于提供任务高层次指导的系统级提示

  • user_prompt:用于提供任务更详细说明的用户级提示

  • temperature:这是介于 0 和 2 之间的采样温度。较高的数字(例如 0.8)将确保输出更随机生成,而较低的数字(例如 0.2)将使其更具确定性

  • max_tokens:模型将生成的最大令牌数量(这有助于限制响应长度)

我们将在示例中将这些变量作为配置字典的值,格式如下:

config = {
 ‘engine’: ‘ENGINE_NAME’,
 ‘system_prompt’: ‘SYSTEM_PROMPT’,
 ‘user_prompt’: ‘USER_PROMPT’,
 ‘temperature’: TEMPERATURE,
 ‘max_tokens’: MAX_TOKENS
}

让我们创建我们的函数!此时,不必完全理解函数的内容。我们将请求 GPT-4 解释代码,这是第一个示例!

# Import libraries
import openai

# Define the call function
def magicWand(config):
    # Extract variables from the config dictionary
    engine = config.get('engine', 'gpt-3.5-turbo')  # Default to 'gpt-3.5-turbo'
    system_prompt = config.get('system_prompt', '')
    user_prompt = config.get('user_prompt', '') 
    temperature = config.get('temperature', 0.8)
    max_tokens = config.get('max_tokens', 100)

    # Create an array of message objects
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt}
    ]

    # Make the API call
    response = openai.ChatCompletion.create(
        model=engine,
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature
    )

    # Extract and return the generated text
    return response.choices[0].message['content'].strip()

4. 任务实施

现在我们有了函数,让我们尝试第一个任务,即代码解释。

4.1. 解释代码

正如之前承诺的那样,让我们请 GPT-4 解释我们的代码,以便更好地理解该函数!

为了做到这一点,我们将使用刚刚创建的magicWand函数,并按如下方式传递值。请注意,主要指令是system_prompt,我们向 GPT-4 解释任务为:您将获得一段代码,您的任务是以简洁的方式解释它

让我们实施这个任务并查看结果。

# Create the config dictionary
config = {
    'system_prompt': 'You will be provided with a piece of code, and your task is to explain it in a concise way.',
    'engine': 'gpt-4',
    'temperature': 0,
    'max_tokens': 2000,
    'user_prompt': '''
import openai

def magicWand(config):
    engine = config.get('engine', 'gpt-3.5-turbo')
    system_prompt = config.get('system_prompt', '')
    user_prompt = config.get('user_prompt', '') 
    temperature = config.get('temperature', 0.8)
    max_tokens = config.get('max_tokens', 100)

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt}
    ]

    response = openai.ChatCompletion.create(
        model=engine,
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature
    )

    return response.choices[0].message['content'].strip()
    '''
}

# Use the config to get the result
result = magicWand(config)

# Print the result
print(result)

结果:

GPT-4 对magicWand函数的解释

我发现这些结果既令人印象深刻又迷人。它逐步解释了我们的magciWand函数变量是什么以及它们的目的是什么。阅读完结果后,我们接下来就生成一个图像吧!

4.2. 图像生成

如名字所示,这次我们将生成一张图像。这是唯一一个我们不使用 OpenAI 的聊天完成,而是使用 OpenAI 的 DALL·E 模型的示例。使用案例非常直接——我们只需在prompt中提供图像描述,n为图像数量,size为图像大小。让我们请求一只黑色的苏格兰折耳猫,带有浅金色眼睛,躺在白色床单上,并查看生成的图像!

# Import libraries
from IPython.display import display, Image
import requests

# Generate the response
response = openai.Image.create(
  prompt="A black Scottish fold cat with light golden eyes laying down on white sheets",
  n=1,
  size="512x512"
)

# Save the image URL
image_url = response['data'][0]['url']

# Fetch the image
response = requests.get(image_url)

# Display the image
img = Image(data=response.content)
display(img)

结果:

通过 OpenAI 的 API 生成的猫咪图像

这是一个相当不错的图像,并且与我们的提示一致。接下来,我们将处理一个有趣的请求——我们将要求 GPT-4 将自然语言输入(即文本)翻译成表情符号!

4.3. 表情符号翻译

这可能是我最喜欢的例子!我们将要求 GPT-4 使用我们自己的magicWand函数将文本翻译成表情符号。我们将提供给 GPT4 的总体指示作为system_prompt,即你将会收到文本,你的任务是将其翻译成表情符号。不要使用任何常规文本。仅使用表情符号尽力而为,然后提供需要从文本翻译成表情符号的user_prompt数据科学文章很有趣。让我们看看结果!

# Create the config dictionary
config = {
    'engine': 'gpt-4',
    'system_prompt': 'You will be provided with text, and your task is to translate it into emojis. Do not use any regular text. Do your best with emojis only.',
    'user_prompt': 'Data science articles are fun',
    'temperature': 0.8,
    'max_tokens': 128
}

# Use the config to get the result
result = magicWand(config)

# Print the results
print(result)

结果:

GPT-4 的文本到表情符号的翻译

这非常有趣!我能看到前半部分的表情符号与数据和科学相关,后半部分的表情符号则与乐趣相关。

接下来,我们将要求 GPT 模型纠正给定句子中的语法错误。

4.4. 语法错误纠正

机器学习模型的一个应用场景是纠正句子中的语法错误。这在商业环境中可以带来许多好处。例如,处理客户沟通的企业需要人类代表阅读、审查并回复这些客户沟通。人类代表的成本相当高,如果收到的消息难以理解,这样的沟通将需要额外的工作来让人类理解和回复。作为替代方案,企业可以依赖语法错误纠正模型来首先清理传入的沟通,然后将修正后的沟通版本发送给人类代表进行审查和处理。我以前写过关于另一种语法错误纠正模型的独立帖子(见下文链接),所以我决定使用相同的句子来看看 GPT 模型的表现!

## 机器学习中的语法错误纠正 — 概述与实现

使用语法错误纠正:标记,而不是重写(GECTor)

towardsdatascience.com

我们将指示gpt-4gpt-3.5-turbo纠正句子中的语法错误她昨天看天空时梳头,然后比较它们的表现。请注意,句子中有故意的拼写和语法错误供模型纠正。为此,我们将使用system_prompt你将会收到陈述,你的任务是将其转换为标准英语来给这两个模型,然后将句子提供为user_prompt

首先,让我们使用 GPT-3.5 实现并查看结果:

config = {
    'engine': 'gpt-3.5-turbo',
    'system_prompt':'You will be provided with statements, and your task is to convert them to standard English.',
    'user_prompt':'she looks at sky yesterday whil brushed her hair',
    'max_tokens':256
}

result = magicWand(config)

print(result)

结果:

GPT-3.5 修正后的句子

现在让我们用 GPT-4 来实现:

config = {
    'engine': 'gpt-4',
    'system_prompt':'You will be provided with statements, and your task is to convert them to standard English.',
    'user_prompt':'she looks at sky yesterday whil brushed her hair',
    'max_tokens':256
}

result = magicWand(config)

print(result)

结果:

GPT-4 修正后的句子

首先的观察是,两种模型在修正语法错误和提高原句可读性方面表现都非常好。第二个观察是,gpt-3.5-turbo的表现几乎与gpt-4一样,考虑到gpt-4的更高成本,也许我们可以在未来的语法错误修正中仅使用gpt-3.5-turbo

在下一个例子中,我们将要求 GPT 识别句子中的机场代码!

4.5. 机场代码提取器

我必须承认这个任务有点奇怪,是我个人无法立即完成的,但我们将要求 GPT 从文本中返回机场代码,system_prompt你将获得一段文本,你的任务是从中提取机场代码user_prompt我在八月从西雅图飞往波士顿。让我们使用gpt-4gpt-3.5-turbo并比较结果,从 GPT-3.5 开始。

config = {
    'engine': 'gpt-3.5-turbo',
    'system_prompt':'You will be provided with a text, and your task is to extract the airport codes from it.',
    'user_prompt':'I flew from Seattle to Boston in August.',
    'max_tokens':256
}

result = magicWand(config)

print(result)

结果:

GPT-3.5 提取的机场名称

然后让我们用 GPT-4 来实现:

config = {
    'engine': 'gpt-4',
    'system_prompt':'You will be provided with a text, and your task is to extract the airport codes from it.',
    'user_prompt':'I flew from Seattle to Boston in August.',
    'max_tokens':256
}

result = magicWand(config)

print(result)

结果:

GPT-4 尝试提取的机场名称

正如你所看到的,结果在“技术上”都是正确的,但也非常不同。gpt-3.5-turbo准确地返回了句子中西雅图和波士顿这两个城市名称的实际机场代码,尽管机场代码在提示中并未明确包含,正如gpt-4所说。我的假设是gpt-3.5-turbo经过微调,能够返回机场代码,而gpt-4则更字面地对待任务,没有检索出机场代码——这两者都很有趣。

接下来,让我们继续识别提供文本中的命名实体。

4.6. 命名实体提取器

命名实体识别是一个常见的自然语言处理(NLP)任务,其中识别命名实体,如名称、地点、地址、组织等。通过一个例子,这将变得更容易理解。我们将给gpt-4gpt-3.5-turbo提供system_prompt你将获得一段文本,你的任务是从中提取命名实体user_prompt我在八月从西雅图飞往波士顿。我记得我穿着崭新的耐克鞋,因为我对它们非常兴奋,以至于把我的 iPhone 忘在了黄色的凯美瑞出租车里。我们期望模型能识别出如西雅图、波士顿、八月、耐克、iPhone 和凯美瑞等命名实体,但让我们先看看模型的表现,从 GPT-3.5 开始。

config = {
    'engine': 'gpt-3.5-turbo',
    'system_prompt':'You will be provided with a text, and your task is to extract the named-entities from it.',
    'user_prompt':'I flew from Seattle to Boston in August. I remember I was wearing my brand new Nike shoes because I was so excited about them that I ended up leaving my iPhone in the yellow Camry cab.',
    'max_tokens':256
}

result = magicWand(config)

print(result)

结果:

GPT-3.5 提取的命名实体

接下来让我们实现并查看 GPT-4 的结果:

config = {
    'engine': 'gpt-4',
    'system_prompt':'You will be provided with a text, and your task is to extract the named-entities from it.',
    'user_prompt':'I flew from Seattle to Boston in August. I remember I was wearing my brand new Nike shoes because I was so excited about them that I ended up leaving my iPhone in the yellow Camry cab.',
    'max_tokens':256
}

result = magicWand(config)

print(result)

结果:

GPT-4 提取的命名实体

结果非常好!两个模型都能够识别所有命名实体,而与gpt-4不同,gpt-3.5-turbo还能够返回命名实体的类型(例如,Seattle 和 Boston 是地点等)。因此,如果我有命名实体识别的任务,我更可能使用gpt-3.5-turbo,因为它能够返回识别出的命名实体的类型,并且比gpt-4便宜。

让我们进入下一个任务,请模型为我们进行翻译!

4.7. 机器翻译

这个任务不言而喻。我们将包含一个system_promptYou will be provided with a text, and your task is to translate it to French,然后提供一个user_promptCan you help me with this task?,需要翻译成法语。我测试了两个 GPT 模型,结果相同,所以下面仅包含其中一个作为参考。

config = {
    'engine': 'gpt-4',
    'system_prompt':'You will be provided with a text, and your task is to translate it to French.',
    'user_prompt':'Can you help me with this task?',
    'max_tokens':128
}

result = magicWand(config)

print(result)

结果:

GPT-4 的英法翻译结果

这是一个预期中的好翻译!随意调整system_prompt,让模型将user_prompt翻译成其他语言吧!

接下来,我们将查看情感分析。

4.8. 情感分析

情感分析是另一项常见的 NLP 任务。在最基本的形式中,它告诉我们一个句子是带有积极、消极还是中立的情感。这对企业理解客户反馈非常有用。例如,一个模型可以处理所有的餐馆、产品或服务的客户评论,并返回正面、中性或负面的评论百分比,并用作该餐馆、产品或服务的总体评分!

让我们给gpt-4一个system_promptYou will be provided with a text, and your task is to provide a nuanced sentiment analysisuser_promptThat was such an exciting movie。请注意,这句话听起来非常积极,因此我们想看看情感分析结果是否与预期一致。

让我们看看结果吧!

config = {
    'engine': 'gpt-4',
    'system_prompt':'You will be provided with a text, and your task is to provide a nuanced sentiment analysis.',
    'user_prompt':'That was such an exciting movie!',
    'max_tokens':128
}

result = magicWand(config)

print(result)

结果:

GPT-4 的情感分析结果

正如预期的那样,gpt-4也认为这个句子是积极的,并添加了一个听起来像是模型决策背后理由的第二句。需要注意的是,“推理”可能真实也可能不真实,因为像 GPT-4 这样的 LLM 并不是完全确定性的——它仅仅表明模型将“激动人心”与享受、参与和积极情感关联起来——我们对于这些模型生成内容的方式知之甚少,但这是另一个话题。

接下来的任务,让我们请模型总结一个文本!

4.9. 文本摘要

文本摘要是另一项不言自明的任务。想象一下,当我们拨打客户支持热线时,代表希望阅读之前关于这个话题的客户支持与客户之间的通信。与其阅读整个通信记录,不如使用模型总结过去通信中的最重要部分,然后客户支持代表只需阅读模型提供的摘要。这为客户和服务提供商节省了宝贵的时间。

为了实现这一点,让我们向 GPT-4 提供 system_prompt你将会收到一段文本,你的任务是对其进行总结,不使用原文中的词语,并提供一段较长的文本作为 user_prompt,其内容为 文本摘要是自动总结文本输入的任务,同时仍传达主要观点和要点。需要这种总结模型的商业直觉之一是人们阅读收到的文本通信(例如客户电子邮件),并且使用总结模型可以节省人工时间

让我们看看结果吧!

config = {
    'engine': 'gpt-4',
    'system_prompt':'You will be provided with a text, and your task is to provide a summary of it, without using the original words.',
    'user_prompt':'Text summarization is the task of automatically summarizing textual input, while still conveying the main points and gist of the incoming text. One example of the business intuition behind the need for such summarization models is the situations where humans read incoming text communications (e.g. customer emails) and using a summarization model can save human time.',
    'max_tokens':256
}

result = magicWand(config)

print(result)

结果:

GPT-4 的文本摘要结果

这似乎是对提供文本的一个很好的总结!我认为我自己也未必能做得更好。

接下来的任务是处理数据分析的那些人,即解析非结构化数据。

4.10. 解析非结构化数据

对于那些需要处理大量非结构化数据的用户,这个结果非常有用。我们可以让模型处理文本,然后将数据组织成不同的组。让我们看一个例子来更好地理解这一点。

我们将向 GPT-4 提供 system_prompt你将会收到非结构化数据,你的任务是将其解析成一个 Pandas 数据框,然后提供非结构化数据作为 user_prompt,内容为 几天前我在火星上散步时,遇到了一群亚马逊员工。第一个人,杰克,来自波士顿,穿着黑色裤子、白色衬衫和黑色跑鞋。另一个人,吉尔,穿着长款靛蓝色连衣裙,脚踩浅蓝色凉鞋,来自西雅图。第三个人名叫约翰。我不记得他穿了什么,但我特别记得约翰来自新泽西州的纽瓦克。最后一个人是珍娜,她来自旧金山。珍娜穿着白色 T 恤和蓝色裤子,但我不记得她穿了什么鞋子

你可以看到 user_prompt 包括了个人的名字和他们穿的衣物。让我们看看模型如何组织这些信息。

config = {
    'engine': 'gpt-3.5-turbo',
    'system_prompt':'You will be provided with unstructured data, and your task is to parse it into a Pandas dataframe.',
    'user_prompt':'As I was walking around Mars a few days ago, I came across a group of Amazon employees. The first one, Jack, was originally from Boston and wore black pants with a white shirt and black running shoes. Another one, Jill, had a long indigo-colored dress on with light blue sandals and was originally from Seattle. The third one was named John. I cannot remember what he was wearing but I particularly recall that John was from Newark, New Jersey. The last individual was Jenna and she was from San Francisco. Jenna had a white t-shirt and blue pants on but I cannot recall her shoes.',
    'max_tokens':1024
}

result = magicWand(config)

print(result)

结果:

GPT-3.5 对非结构化输入的结构化输出

这非常令人印象深刻,做得很好!正如你所看到的,GPT-3.5 能够处理提供的文本,并将其组织(即解析)成每个人的相关列。这种解析在过去大多是手动完成的,当人们希望以表格格式分析数据时,这会非常有帮助。

作为最后的任务,我们将要求模型为我们编写一个 SQL 查询!

4.11. 编写 SQL 查询

我个人对 GPT 模型的 SQL 编写能力非常好奇,因为我在工作中频繁使用 SQL,并且在 Medium 上发布过以下 SQL 教程。

## SQL 教程 + 备忘单

介绍

medium.com

为了评估 GPT 模型的能力,我选择了我在 SQL 帖子中使用的一个示例,并要求 gpt-3.5-turbo 编写我自己准备的查询。为此,我们希望将要查询的表定义为 system_prompt 的一部分,然后像往常一样在 user_prompt 中定义任务。让我们看看这个示例的实现,然后是结果和与我自己编写的查询的比较。

config = {
    'engine': 'gpt-3.5-turbo',
    'user_prompt':'Write a SQL query which returns a rank of the salaries overall and also by gender from highest to the lowest salary.',
    'max_tokens':1024,
    'system_prompt':'''
Given the following SQL tables, your job is to write queries given a user’ s request.

DROP TABLE IF EXISTS salary;

CREATE TEMPORARY TABLE salary(city VARCHAR(30), average_salary int);

INSERT INTO
salary
VALUES
    ('san_francisco', '54500'),
    ('seattle', '54100'),
    ('new_york', '34400'),
    ('phoenix', '31800');

DROP TABLE IF EXISTS people;

CREATE TEMPORARY TABLE people(
    person_id int,
    name VARCHAR(30),
    gender VARCHAR(30),
    location VARCHAR(30),
    birth_year int,
    birth_month VARCHAR(30),
    birth_day int,
    job_title VARCHAR(30),
    salary int
);

INSERT INTO
people
VALUES
    (
        '1',
        'james',
        'male',
        'seattle',
        '1984',
        '9',
        '15',
        'software_developer',
        '115000'
    ),
    (
        '2',
        'mary',
        'female',
        'new_york',
        '1992',
        '1',
        '13',
        'financial_analyst',
        '183000'
    ),
    (
        '3',
        'john',
        'male',
        'san_francisco',
        '1971',
        '4',
        '22',
        'data_scientist',
        '165000'
    ),
    (
        '4',
        'patricia',
        'female',
        'phoenix',
        '1971',
        '8',
        '15',
        'physician',
        '215000'
    ),
    (
        '5',
        'michael',
        'male',
        'new_york',
        '1966',
        '1',
        '13',
        'retired',
        '25000'
    ),
    (
        '6',
        'jennifer',
        'female',
        'phoenix',
        '1994',
        '12',
        '12',
        'data_scientist',
        '165000'
    );
'''
}

result = magicWand(config)

print(result)

结果:

GPT-3.5 的查询

这非常令人印象深刻!这个查询涉及使用窗口函数,这些函数是 SQL 中较具挑战性的概念之一,但模型处理得相当好。以下是我在上面帖子中提供的解决方案作为参考,你可以看到模型的响应整体结构与我编写的查询非常相似!

我自己的查询

5. 结论

在这篇文章中,我们介绍了 OpenAI 的 API,它提供了对后台模型的访问,使 ChatGPT 能够执行各种任务。然后,我们使用 OpenAI 的 API 实现了 11 个聊天完成和图像生成的示例,并比较了 gpt-4gpt-3.5-turbo 在这些任务中的表现。总体而言,我发现这两个 GPT 模型都非常强大,并认为它们是我个人使用的有用工具,而 DALL·E 是一个令人印象深刻的图像生成器。

感谢阅读!

如果你觉得这篇文章对你有帮助,请 关注我在 Medium 上订阅 以接收我的最新帖子!

(所有图片,除非另有说明,均由作者提供。)

OpenAI 的网络爬虫和 FTC 失误

原文:towardsdatascience.com/openais-web-crawler-and-ftc-missteps-a14047f4ff69?source=collection_archive---------7-----------------------#2023-08-22

OpenAI 推出默认的自动同意爬虫以抓取互联网,而 FTC 进行了一项模糊的消费者欺骗调查

Viggy BalagopalakrishnanTowards Data Science Viggy Balagopalakrishnan

·

关注 发表在 Towards Data Science · 11 分钟阅读 · 2023 年 8 月 22 日

--

图片由 Giammarco Boscaro 提供,来源于 Unsplash

随着 AI 采用的急剧上升,数据专业人士越来越需要考虑数据来源。虽然最初一波高性能的 LLM 是使用一种普遍但有争议的数据抓取策略进行训练的,但这种有问题的做法最近受到关注,引发了诉讼和数据所有权的问题。本文提供了对这些法律概念的深入理解以及监管机构如何应对这一问题(剧透:效果并不显著)。

来自 Towards Data Science 编辑的说明: 虽然我们允许独立作者根据我们的 规则和指南 发布文章,但我们并不赞同每位作者的贡献。你不应在未寻求专业建议的情况下依赖某位作者的作品。详情请参见我们的 读者条款

上周,Open AI(ChatGPT 的制造商)正式宣布了他们的网络爬虫——这是一个从互联网上所有网站抓取内容的软件,这些内容随后用于 AI 模型训练。爬虫的存在并不令人惊讶,今天存在着几种合法的网络爬虫,包括索引整个互联网的 Google 爬虫。然而,这还是 OpenAI 首次明确宣布其存在,并提供了一个机制让网站选择退出被抓取。

请注意,爬虫程序默认是需要主动选择的,也就是说,你需要明确更改网站上的一段代码来要求爬虫程序不要抓取你的数据。主动选择/退出的默认设置是固定的,通常决定了大多数人的行为,因为大多数人不会费心去更改默认设置。这也是苹果 iOS14 隐私变更对数字广告行业产生重大影响的原因。

OpenAI 网络爬虫(来源:OpenAI

那么,为什么还要提供退出选项呢?这可能是 OpenAI 对最近的诉讼采取的预防性措施,诉讼指控内容拥有者的版权受到侵犯(如果你想进一步了解,更多关于数据抓取的深度文章可以阅读)。ChatGPT 的竞争对手 Google Bard 面临类似挑战,但 Google 尚未宣布相应的解决方案——他们确实提出了如何升级robots.txt来解决这个问题的征求意见(以一些巧妙的公关写作呈现)。

在本文中,我们将深入探讨:

  • OpenAI 爬虫对内容拥有者的影响

  • FTC 目前对 OpenAI 的调查

  • 我们目前所处的法律环境

  • 为什么 FTC 追究 OpenAI 的做法是(又一个)错误的步骤

OpenAI 的爬虫对内容拥有者的影响

尽管公告为广告商提供了一个选项,可以阻止 OpenAI 的爬虫抓取他们的数据,但仍有几个问题:

  1. 默认情况下是选择加入的,这意味着 OpenAI 可以继续抓取,直到网站明确告诉他们不要抓取

  2. 关于在未经同意的情况下抓取数据用于模型训练时内容拥有者的权利,尚未有明确的法律裁决(这基本上适用于所有被迫默认选择加入的情况)

目前,有两个法律构架决定了语言模型是否可以在未获同意的情况下获取所有这些数据——版权和公平使用。

版权(在美国版权法第 102 条中)为特定类型的内容提供保护,但也有例外:

版权保护根据本标题存在于任何有形的表达媒介中固定的原创作品中,无论现在已知还是以后开发,从中可以被感知、复制或以其他方式传达,无论是直接还是借助机器或设备。作品的类别包括以下几类:(1)文学作品;(2)音乐作品,包括任何附带的文字;(3)戏剧作品,包括任何附带的音乐;(4)哑剧和舞蹈作品;(5)图画、图形和雕塑作品;(6)电影和其他视听作品;(7)声音录音;(8)建筑作品。

(b)在任何情况下,原创作品的版权保护不会扩展到任何想法、程序、过程、系统、操作方法、概念、原则或发现,无论以何种形式描述、解释、插图或体现于该作品中

例如,版权保护大多数原创作品(例如,如果你写了一篇原创的博客文章或书籍),但不保护广泛的思想(例如,你不能声称你是第一个写关于 AI 如何影响数据权利的人,因此这个想法属于你)。

版权保护的另一个例外是公平使用(美国版权法第 107 条):

对于受版权保护的作品的公平使用,包括通过复制或其他该节指定的方式使用该作品用于批评、评论、新闻报道、教学(包括用于课堂使用的多份副本)、学术研究或研究,并不构成版权侵权。

在确定任何特定情况下对作品的使用是否属于合理使用时,需要考虑的因素包括(1)使用的目的和性质,包括这种使用是否具有商业性质或是否用于非营利性教育目的;(2)受版权保护作品的性质;(3)使用部分相对于受版权保护作品整体的数量和实质性;以及(4)使用对潜在市场或受版权保护作品价值的影响。

例如,如果你从一篇研究论文中提取了内容并写了评论,这是可以的,并且你不会侵犯内容所有者的版权。当我从此页面链接到另一篇文章并添加该文章的引用文本时,情况也是一样的。

这两个概念的创建旨在保护内容所有者的权利,同时也允许信息的自由流动,特别是在教育、研究和评论的背景下。

我不是法律专家,但根据我对上述语言的研究/理解,在 AI 模型抓取训练内容时,情况变得模糊

  • AI 公司通常会从内容所有者的网站上抓取全文(这些是受版权保护的),训练模型以学习“思想”/“概念”/“原理”(这些是不受版权保护的),然后模型最终会生成不同的文本。在这种情况下,内容所有者是否会获得版权保护?

  • 由于训练后的语言模型现在最终用于商业目的(例如,ChatGPT Plus 是一款付费产品),这是否违反了内容所有者的版权(因为公平使用例外不再适用)?

目前尚未有法院对此作出裁决,因此很难预测结果。我这个非律师的观点是,第二种情况可能更容易解决:OpenAI 抓取了数据并用它创建了商业产品,因此他们不适用公平使用的例外。我想象第一种情况(模型是训练在“思想”上还是仅仅是原始文本上)则无人能知晓。请注意,这两个条件都需要对内容所有者有利,内容所有者只有在上述两个例外(“思想”例外或公平使用例外)都不适用于 OpenAI 时才能获胜。

我提到这个细微差别是因为在人工智能风险的范围内(并不详尽)——从内容所有者的权利、放大欺诈、自动化工作到 AGI / 人类毁灭——最紧迫的短期问题是内容所有者的权利,这一点可以从大量的诉讼和对内容平台的影响(例如 StackOverflow 的故事)中看出。

虽然像 FTC 这样的监管机构可以考虑真正的长期问题,并提出假设性/创造性的方法来应对这些风险,但它们的真正短期潜力在于能够应对那些将在 5 到 10 年内影响我们的风险。例如版权侵权。这引出了 FTC 在做什么。

FTC 对 OpenAI 的当前调查

在七月中旬,FTC 宣布正在调查 OpenAI。令其有趣(和令人沮丧)的地方在于FTC 调查的原因。ChatGPT 的制造商正在被调查,以评估该公司是否违反了消费者保护法律将个人声誉和数据置于风险之中。这不合理?你并不孤单。让我们进一步了解一下这一情况的背景。

FTC 对 AI 监管的最明确立场在四月提出:“法律书中没有 AI 豁免条款,FTC 将积极执行法律,以打击不公平或欺骗性行为或不公平的竞争方式。” 随后出现了一些与诽谤相关的问题:电台主持人马克·沃尔特斯起诉 OpenAI因为 ChatGPT 指控他欺诈非营利组织,一位法学教授被ChatGPT 错误指控性骚扰

这两种情况对相关人员来说都很糟糕,我对此表示同情。然而,语言模型(如 GPT)和基于它们的产品(如 ChatGPT)会“产生幻觉”,并且经常不准确。FTC 对调查的第一个前提是—— ChatGPT 产生幻觉,从而造成声誉损害。

在一次激烈的国会听证会上,一位代表(有理)询问 FTC为什么要追究诽谤和中伤,这些通常由州法律处理。FTC 主席丽娜·汗给出了一个复杂的论点

韩表示,诽谤和中伤不是 FTC 执行的重点,但在 AI 训练中滥用个人私人信息可能是一种欺诈或欺骗行为,违反 FTC 法案。韩说:“我们关注的是,‘是否对人们造成了实质性的伤害?’伤害可以表现为各种形式。”

将整个论点总结起来——FTC 认为ChatGPT 的幻觉生成了不正确的信息(包括诽谤),这可能构成消费者欺骗。此外,敏感的用户私人信息可能被使用/泄露(基于一个漏洞 OpenAI 已迅速修复)。

作为调查的一部分,FTC 要求 OpenAI 提供一长串信息——包括他们的模型是如何训练的,使用了哪些数据来源,如何向客户展示他们的产品,以及因为识别到风险而暂停发布模型的情况。

问题是——在当前法律环境下,FTC 最好的做法是监管可能成为最大 AI 公司之一的 OpenAI 吗?

我们今天所处的法律环境

要批评 FTC 对 OpenAI 的战略,了解我们今天所处的法律环境是很有帮助的。我们不会深入细节,但可以简单地用反垄断历史作为例子:

  • 在 1900 年代,巨大的企业集团(“信托”)出现,公共和私人权力的平衡转移到了这些公司手中。

  • 为了应对这种情况,1890 年的《谢尔曼法》被通过,以对私人权力施加限制并维护竞争;这一法律被用于诉讼并打破从事反竞争行为(掠夺性定价、卡特尔交易、分销垄断)的“信托”。

  • 大约在 1960 年代,法官们因根据法律精神而非法律字面进行裁决而遭遇大量反对。例如,解释《谢尔曼法》以确定一组公司是否“过度限制贸易”涉及主观性,法官们被指责参与司法激进主义。

  • 为了引入客观性,芝加哥学派开创了消费者福利标准——“法院应 exclusively 以消费者福利为指导”(例如,垄断者明显提高价格是错误的,但对于其他行为,证明消费者伤害的责任在于监管者)。

  • 这种标准今天仍然适用,并且是 FTC 和 DOJ 难以对付大型科技公司的原因之一——例如,FTC 无法主张 Google 提高价格,因为他们的大多数产品是免费的,即使 Google 从事了其他反竞争行为。

从中可以得出的结论是——我们今天继续在一个案件 heavily 基于“法律字面”而非“法律精神”的环境中运作。这一点,加上今天美国最高法院的组成,导致了对法律的相当保守的解释。

对于 FTC 来说,这意味着要接受这一现状并找出赢得案件的办法。FTC 和 DOJ 的运营模式(是合理的)是追求少数几个大案件并实施严格的执法,以便让其他公司在违法之前三思而后行。要实现这一点,FTC 需要在一些关键问题上取得重大胜利,并且需要在当前法律环境的限制下制定一个胜利策略

为什么 FTC 对 OpenAI 的追击是(又一次)失误

FTC 在对抗大型科技公司方面屡次失败,我认为这些失败都可以归因于一种失败的“我们讨厌一切大型科技公司”的策略,类似于用锤子而非手术刀的方式来对付这些公司。

例如,FTC 采取了强硬手段阻止了$69B 的微软-动视收购案,但失败了(可以说是非常惨败)。FTC 辩称微软收购动视会杀死游戏市场的竞争。法官写了一份相当直白的判决,驳回了 FTC 的所有论点,这里有法官的评论之一

没有内部文件、电子邮件或聊天记录与微软声明的意图相矛盾,即不将《使命召唤》独占到 Xbox 主机上。尽管 FTC 行政程序中进行了大量的发现,包括生产了近 100 万份文件和 30 份证词,FTC 没有发现任何一份与微软公开承诺将《使命召唤》提供给 PlayStation(和 Nintendo Switch)相矛盾的文件。

另一个强硬手段的案例是 FTC 试图阻止 Meta 收购 VR 公司 Within,他们失败了。他们为什么要这样做?他们想测试一下是否有意阻止在特定市场变得庞大之前的收购,而鉴于当前的法律环境,这个尝试不出所料地被驳回了。

FTC 对 OpenAI 的调查问题类似:

  1. 他们在追究(在我看来)一个相当琐碎的问题以及语言模型的已知局限——幻觉;他们应该将精力集中在 5 至 10 年内真正重要的人工智能问题上,例如版权。

  2. 尽管当前法律环境中多种“创造性”的法律途径被否决,他们仍在尝试另一种创造性的论点:幻觉 → 诽谤 → 消费者欺诈

对他们行动的宽泛解释是,他们想为他们的“人工智能不免于现有法律”立下先例,而这场徒劳的追逐战使他们从 OpenAI 那里获得了大量自我报告的数据(FTC 发布了20 页的要求)。

然而,鉴于他们一再采用暴力手段/任何大科技公司的非竞争性方法,并结合那些在法院被一再驳回的创意论点,我认为 FTC 在这个案件中并未赢得应有的信任。

结论

我绝对认为 OpenAI 应该受到监管。这不是因为他们的 LLM 会出现幻觉(当然会),而是因为他们公然未经许可使用创作者的内容。这不是因为它会改变过去,而是因为它将有助于为创作者建立一个健康的未来,在那里他们的内容所有权得到保护(法院是否会将现状视为版权侵权还有待观察)。

如果 FTC 继续重复其错误,采用“铁锤而非手术刀”的方法,这种情况不会改变。针对大科技公司的手术刀方法有明确的成功先例,其中最著名的是英国竞争与市场管理局。他们赢得的两个大案件集中在特定的反竞争机制上:阻止谷歌对其广告技术堆栈中的自身产品给予优待,以及允许其他支付提供商进行应用内支付。

如果 FTC 继续走当前的道路,他们的败绩将激励科技公司继续为所欲为,因为他们知道他们可以赢得官司。FTC 是时候反思自己的失败,向其他监管机构学习成功经验,并进行调整。

🚀 如果你喜欢这篇文章,可以考虑订阅我的每周通讯 每周,我都会发布一篇深度分析 关于当前科技话题/产品策略的文章,阅读时间大约为 10 分钟。祝好,Viggy。

[## Unpacked | Viggy Balagopalakrishnan | Substack

针对当前科技和商业话题的深度分析,帮助你保持领先。每周送到你的邮箱…

thisisunpacked.substack.com](https://thisisunpacked.substack.com/?source=post_page-----a14047f4ff69--------------------------------)

openCypher* 针对任何关系数据库

原文:towardsdatascience.com/opencypher-against-any-relational-database-a3b2388579df

关系数据库作为图形数据库 = Mindful (openCypher-2-SQL)

Victor Morgante数据科学前沿 Victor Morgante

·发表于 数据科学前沿 ·8 分钟阅读·2023 年 7 月 25 日

--

图片由作者提供。阴阳月。由 Syed Ahmad 修改的免版税照片 Unsplash

在任何关系数据库上运行的 openCypher 图查询的有限子集是 Mindful 计划。此阶段的查询是只读的,且不包含元图查询。Mindful 是 微软的 openCypher 到 SQL 转换器 的封闭源代码修改版,遵循 MIT 许可证,其中 Mindful 生成 SQL 以操作任何关系/SQL 数据库。

考虑到这一点……让我们开始了解范围……

在 Mindful 的背景下,“任何关系数据库”意味着 openCypher 查询被转换为针对任何实际关系数据库的 SQL,而不是那些需要为图类型查询特别修改表格或将数据作为 JSON 注入字段并在该 JSON 数据上执行图形查询的关系数据库。

openCypher 查询被转换为 SQL,以便在任何标准关系数据库上运行。

解释视频:

对您业务的适用性 — 数据科学

您可能已经拥有一个现有的数据仓库、语义层或本质上是关系型的数据库,并且使用 SQL 作为主要查询语言……而您希望使用图形查询来查询您的数据资产。

相反,你可能急需从现有的图数据库迁移到关系/SQL 基于数据库,并需要数据迁移测试和实施工具。Mindful,一个 openCypher 到 SQL 的转译器,旨在成为实现你目标的工具。

现有的在关系数据库上进行图查询的实现需要特殊的表来有效表示节点类型和边类型(例如,具有单列主键的表)。Mindful 实现允许你在具有多列主键的表上运行 filter openCypher 查询。

在这篇文章中,我们展示了如何在不影响现有关系数据库栈的情况下实现这一点,通过采用一种数据科学策略,在 DDL 的评论部分使用 JSON 存储有关关系数据的同态图结构的元信息。例如,ORACLE、SQL Server、PostgreSQL 和大多数主流关系数据库都支持对元数据/DDL 进行注释。

实际上,通过使用现有数据库/数据仓库或语义层中提供的标准工具对数据库进行文档化,你将你的关系数据栈转换为图兼容的数据栈。

让我们深入探讨使这一切成为可能的数据科学和数据库理论……

关注细节

现代数据库管理系统显然具有以下特点,我们将对其进行探讨:

任何关系数据库都可以视为图数据库或关系数据库,取决于你的观点(如下所示);以及

一旦数据库投入使用,数据库架构随处可见。例如,甚至连以前提倡无架构的图数据库业务如Neo4j 也开始谈论“架构”

考虑到这一定位,我们可以想象,未来所有有价值的数据库都将提供图查询语言以及普通结构化查询语言(SQL)查询。

任何关系数据库都可以视作图数据库

图-关系范式、关系知识图谱 和多模型数据库并不新鲜,但为了设定背景,请注意实体-关系图与其对应的属性图模式及其反向映射:

同态映射——属性图模式到实体-关系图。图像作者提供。

我们的架构用于在电影院预订座位以观看电影。一个人可以为在电影院观看电影的一个或多个座位进行预订;其中座位位于该电影院的某个区域和行中。

Shutterstock 图片,2187621947. 授权给 Victor Morgante / FactEngine。 免版税标准许可证 2187621947在 Shutterstock 上。

数据库理论…

使 Mindful 成功的两个方面是:

  1. 外键关系多对多表在其他关系数据库中与图数据库的边类型是同义的;

  2. 图查询语言如 openCypher 支持可以用类似于 SQL 查询的方式编写的查询。

让我们单独查看这些内容…

外键和多对多表

外键 <-> 边类型

关系数据库中的外键关系与属性图模式中的对应边类型有 1:1 的映射关系。

在我们的模式中,每个名为 Row 的表中的 Row 记录都有一个指向‘Cinema’表中 Cinema 记录的外键引用,表示该行‘IS IN’的影院。

实体关系图中的外键引用。图片由作者提供。

…在我们的属性图模式中有一个对应的边类型。

属性图模式中的边类型。图片由作者提供。

…在我们的 DDL(数据库定义语言)中,这可以通过一个表示属性图模式的“IS_IN”标签的 JSON(JavaScript 对象表示法)轻松表示,嵌入为 CREATE TABLE DDL 语句中‘Row’表的对应外键参考定义的注释/文档。

CREATE TABLE "Row" (Cinema_Id INTEGER REFERENCES [Cinema] NOT NULL
,RowNr NUMBER NOT NULL
, CONSTRAINT [Row_PK] PRIMARY KEY ([Cinema_Id],[RowNr])
, FOREIGN KEY ([Cinema_Id]) REFERENCES [Cinema] ([Cinema_Id])
     ON DELETE CASCADE ON UPDATE CASCADE /* { Label:"IS_IN"} */
)

…我们可以将关系视图和图视图之间的同态视为一种变形动画:

外键关系与边类型之间的同态。图片由作者提供。

多对多表 <-> 边类型

类似地,关系数据库中的多对多连接表在传统图数据库的属性图模式中与其对应的边类型具有同态。

在我们的模式中,一个预订可能有多个座位,而一个座位(在其生命周期内)可能有多个预订。

注意,在我们的模式中,‘Booking’和‘Seat’表的主键是多列的,而我们在‘BookingHasSeat’多对多连接表上的主键有七列。

在 ER 图中的多对多连接表。图片由作者提供。

…在属性图模式视图中,Booking 和 Seat 节点类型之间有一个标记为‘HAS’的对应边类型。

属性图模式中的边类型。图片由作者提供。

…我们可以很容易地将‘HAS’标签作为 JSON 捕获在 DDL 中‘BookingHasSeat’表主键的注释中。

CREATE TABLE "BookingHasSeat" (CinemaId INTEGER
,Letter TEXT(1)
,DateTime DATETIME
,RowNr NUMBER
,Person_Id INTEGER
,Cinema_Id INTEGER
,Film_Id INTEGER
,CONSTRAINT [BookingHasSeat_PK] PRIMARY KEY 
     ([CinemaId],[Letter],[DateTime],[RowNr],[Person_Id],[Cinema_Id],[Film_Id])
     /* { Label:"HAS"} */
,FOREIGN KEY ([Person_Id],[Film_Id],[DateTime],[CinemaId]) REFERENCES 
     [Booking] ([Person_Id],[Film_Id],[DateTime],[Cinema_Id])
     ON DELETE CASCADE ON UPDATE CASCADE /* { Label:"HAS"} */
,FOREIGN KEY ([RowNr],[Cinema_Id],[Letter]) REFERENCES 
     [Seat] ([RowNr],[Cinema_Id],[Letter])
     ON DELETE CASCADE ON UPDATE CASCADE /* { Label:"HAS"} */
)

…我们也可以将关系视图和图视图之间的同态视为一种变形动画:

架构

Mindful 的实现变得相当简单易懂。通过将适当的同态从关系模式映射到图模式,并且我们可以将有关同态的元信息存储在关系数据库的 DDL 中,我们有一种非常简单的机制来编写 openCypher 查询……因为实际上,关系数据库的元模型与图数据库的元模型之间接近同构。即,从适当的视角来看,它们在概念上几乎是相同的。

同态 — openCypher 到 SQL

…两种查询类型的故事

图查询语言有两种查询类型:

  1. 过滤查询;以及

  2. 元图/图遍历查询。

元图查询是一种在属性图模式的“图”(或模型)中查找路径、关系或结构的查询类型。

另一方面,过滤查询从数据库的所有数据的超集中过滤数据。例如,用自然语言表达:“Peter 预订了哪些影院?”

在这一阶段,我们最感兴趣的是过滤查询,因为这是 Mindful 代码所处的阶段,但一个元图查询可能会提出这样的问题,其他形式的自然语言表述为:“从一个人 Peter 到他预订的影院的最短路径是什么?”。

从图像上看,我们可以看到,对于我们的模式,最短路径是从人到预订,再到会话,再到影院,如下所示(而不是人-预订-座位-行-影院):

最短路径 — 人到影院 — 元图/图遍历。图像作者提供。

对于我们的目的……我们需要过滤查询,因为 openCypher 过滤查询与 SQL 查询有同态,而 标准 SQL 不支持元图查询

让我们使用 openCypher 查询

假设我们想知道登录名为“Peter”的人在哪个日期和时间在什么影院观看了哪个电影,以及座位字母和行号……我们可以为我们的模式编写以下 openCypher 查询:

**MATCH (p:Person)<-[:IS_FOR]-(b:Booking)-[:HAS]->(seat:Seat)-[:IS_IN]->(:Row)-[:IS_IN]->(c:Cinema), (b)-[:IS_FOR]->(s:Session)-[:IS_FOR]->(f:Film)

WHERE p.LoginName = ‘Peter’

RETURN p.LoginName, f.Name, s.DateTime, seat.Letter, seat.RowNr, c.CinemaName

这是一种过滤类型的查询,因为我们正在将结果筛选到那些登录名为“Peter”的人。

…现在让我们在关系数据库上运行这个查询,使用连接到 Mindful 的软件:

openCypher 查询与关系数据库 — 结果。图像作者提供。

在这个阶段,微软的 openCypher 到 SQL 转译器(以及 Mindful)生成的 SQL 相当繁琐,我在这里不会展示,但足以说明,相比于 openCypher 查询,相应的 SQL 查询相当长。以下是我们模式的等效 SQL 查询:

SELECT Person.[LoginName],Film.[Name],Session.[DateTime],
       Seat.Letter, Seat.RowNr,Cinema.[CinemaName]
FROM Person,
Booking,
Session,
Film,
Seat,
Row,
Cinema
,BookingHasSeat
WHERE Booking.Person_Id = Person.Person_Id
AND Booking.Film_Id = Session.Film_Id
AND Booking.DateTime = Session.DateTime
AND Booking.Cinema_Id = Session.Cinema_Id
AND Session.Film_Id = Film.Film_Id
AND BookingHasSeat.Person_Id = Booking.Person_Id
AND BookingHasSeat.Film_Id = Booking.Film_Id
AND BookingHasSeat.DateTime = Booking.DateTime
AND BookingHasSeat.CinemaId = Booking.Cinema_Id
AND BookingHasSeat.RowNr = Seat.RowNr
AND BookingHasSeat.Cinema_Id = Seat.Cinema_Id
AND BookingHasSeat.Letter = Seat.Letter
AND Seat.Cinema_Id = Row.Cinema_Id
AND Seat.RowNr = Row.RowNr
AND Row.Cinema_Id = Cinema.Cinema_Id
AND Person.LoginName = 'Peter'

总结

所以,这就是如此简单。我们展示了关系模式和图模式之间的同态,使得可以编写软件,以便在任何适当配置的关系数据库上执行 openCypher 查询(过滤类型),并且 openCypher 被转换为 SQL。

感谢阅读。时间允许的话,我将写更多关于图数据库、关系数据库、图查询和同态的内容。

— — — — — — —

Mindful openCypher 查询目前为只读,并且在此阶段没有元图查询。

归属:本文中的模式源于最初在 ActiveFacts examples github 页面 下以 MIT 许可证 开源共享的内容。

— — — — — — — — — — — — — 结束 — — — — — — — — — — —

通过物理启发的 DeepONet 进行算子学习:从头开始实现

原文:towardsdatascience.com/operator-learning-via-physics-informed-deeponet-lets-implement-it-from-scratch-6659f3179887

深入探讨 DeepONets、物理启发的神经网络以及物理启发的 DeepONets

Shuai GuoTowards Data Science Shuai Guo

·发表于 Towards Data Science ·阅读时间 23 分钟·2023 年 7 月 7 日

--

图 1. ODE/PDEs 广泛用于描述系统过程。在许多情况下,这些 ODE/PDEs 接受一个函数(例如,强迫函数 u(t))作为输入,并输出另一个函数(例如,s(t))。传统上,数值求解器用于连接输入和输出。最近,神经算子的开发大大提高了处理效率。(图像由作者提供)

常微分方程(ODEs)和偏微分方程(PDEs)是许多科学和工程学科的基础,从物理学和生物学到经济学和气候科学。它们是描述物理系统和过程的基本工具,捕捉了数量随时间和空间的连续变化。

然而,这些方程的一个独特特点是它们不仅接受单一数值作为输入,还接受函数。例如,考虑预测建筑物因地震而产生的振动。地面的震动随时间变化,可以表示为一个函数,该函数作为描述建筑物运动的微分方程的输入。同样,在音乐厅中声波传播的情况下,乐器产生的声波可以是一个随时间变化的音量和音调的输入函数。这些变化的输入函数从根本上影响了结果输出函数——建筑物的振动和音乐厅的声学场。

传统上,这些 ODEs/PDEs 通常使用有限差分或有限元方法等数值求解器来解决。然而,这些方法存在一个瓶颈:每当有新的输入函数时,求解器必须重新运行一次。这个过程可能计算密集且缓慢,尤其是在复杂系统或高维输入情况下。

为了应对这一挑战,Deep Operator Network(简称DeepONet)的创新框架由Lu et al.于 2019 年提出。DeepONets 旨在学习将输入函数映射到输出函数的算子,本质上是学习预测这些 ODEs/PDEs 在任意给定输入函数下的输出,而不需要每次都重新运行数值求解器。

尽管 DeepONets 很强大,但它们继承了数据驱动方法面临的共同问题:我们如何确保网络的预测与包含在控制方程中的已知物理定律一致?

进入物理信息化学习领域。

物理信息化学习是一个迅速发展的机器学习分支,它将物理原理与数据科学结合起来,以增强对复杂物理系统的建模和理解。它涉及利用特定领域的知识和物理定律来指导学习过程,提高机器学习模型的准确性、泛化能力和可解释性。

在这个框架下,2021 年,Wang et al.提出了 DeepONets 的新变种:物理信息化 DeepONet。这种创新方法在 DeepONets 的基础上,通过将我们对物理定律的理解融入学习过程中,进行改进。我们不再只是让模型从数据中学习,而是用源于几个世纪科学探究的原理来指导它。

这看起来是一个非常有前景的方法!但是我们应该如何在实践中实现它?这正是我们今天要探讨的内容🤗

在这篇博客中,我们将讨论物理信息化 DeepONet 背后的理论,并逐步讲解如何从零开始实现它。我们还将把我们开发的模型付诸实践,通过实际案例展示其强大能力。

如果你也有兴趣使用物理信息化 DeepONet 解决逆问题,可以查看我的新博客:利用物理信息化 DeepONet 解决逆问题:带代码实现的实用指南

让我们开始吧!

内容表

· 1. 案例研究

· 2. 物理信息化 DeepONet

∘ 2.1 DeepONet:概述

∘ 2.2 物理信息化神经网络(PINNs)

∘ 2.3 物理信息化 DeepONet

· 3. 物理信息化 DeepONet 的实现

∘ 3.1 定义架构

∘ 3.2 定义 ODE 损失

∘ 3.3 定义梯度下降步骤

· 4. 数据生成与组织

∘ 4.1 u(·) 轮廓生成

∘ 4.2 数据集生成

∘ 4.3 数据集组织

· 5. 训练物理信息深度运算网络

· 6. 结果讨论

· 7. 重点总结

· 参考文献

1. 案例研究

让我们在一个具体的例子中扎根讨论。在这篇博客中,我们将重现 Wang et al. 原论文中考虑的第一个案例研究,即由以下常微分方程(ODE)描述的初值问题:

具有初始条件 s(0) = 0。

在这个方程中,u(t) 是随时间变化的输入函数,而 s(t) 是我们感兴趣的在时间 t 的系统状态。在物理场景中,u(t) 可能代表施加在系统上的力,而 s(t) 可能代表系统的响应,比如位移或速度,具体取决于上下文。我们这里的最终目标是学习强迫项 u(t) 与 ODE 解 s(t) 之间的映射关系。

传统的数值方法,如欧拉方法或龙格-库塔方法,可以有效地求解此方程。然而,请注意,强迫项 u(t) 可以采取各种轮廓,如下图所示:

图 2. u(t) 的示例轮廓。 (作者提供的图片)

因此,每当 u(t) 变化时,我们需要重新运行整个模拟以获取相应的 s(t)(如图 3 所示),这可能会计算密集且效率低下。

图 3. s(t) 的相应轮廓。它们是通过使用 RK45 算法求解 ODE 计算得出的。 (作者提供的图片)

那么,我们如何更高效地解决这种问题呢?

2. 物理信息深度运算网络

如介绍中所述,物理信息 DeepONet 是解决我们目标问题的有前途的解决方案。在这一部分,我们将详细解析其基本概念,使其更易于理解。

我们将首先讨论原始 DeepONet 的基础原则。接着,我们将探索物理信息神经网络的概念及其如何为问题解决提供额外的维度。最后,我们将展示如何将这两个想法无缝集成以构建物理信息 DeepONet。

2.1 DeepONet:概述

DeepONet,简而言之就是深度运算网络,代表了深度学习的新前沿。与传统的机器学习方法将一组输入值映射到输出值不同,DeepONet 旨在将整个函数映射到其他函数。这使得 DeepONet 在处理自然涉及函数输入和输出的问题时特别强大。那么它是如何实现这一目标的呢?

为了符号化我们想要实现的目标:

图 4. 我们的目标是训练一个神经网络,以近似将强迫项 u(·)映射到目标输出 s(·)的算子,这两者都是时间的函数。(图片由作者提供)

左边是将输入函数 u(·)映射到输出函数 s(·)的算子G。右边,我们希望使用神经网络来近似算子 G。一旦实现了这一点,我们可以利用训练好的神经网络在给定任何 u(·)的情况下快速计算 s(·)。

对于当前的案例研究,输入函数 u(·)和输出函数 s(·)都将时间坐标t作为唯一参数。因此,我们希望构建的神经网络的“输入”和“输出”应如下所示:

图 5. 我们希望训练的神经网络模型的输入和输出。(图片由作者提供)

实质上,我们的神经网络应接受 u(t)的整个轮廓作为第一个输入,以及一个特定时间点t作为第二个输入。随后,它应输出在时间点t评估的目标输出函数 s(·),即 s(t)。

为了更好地理解这一设置,我们认识到 s(t)的值首先依赖于 s(·)的轮廓,而 s(·)的轮廓又依赖于 u(·),其次依赖于 s(·)被评估的时间点。这也是时间坐标t需要作为输入之一的原因。

目前我们需要弄清楚两个问题:首先,我们应该如何将 u(·)的连续轮廓输入网络?其次,我们应该如何拼接这两个输入,即t和 u(·)。

1️⃣ 我们应该如何输入 u(·)的连续轮廓?

实际上,我们并不这样做。一种直接的解决方案是离散表示函数 u(·)。更具体地说,我们只是评估 u(·)在足够但有限的多个位置的值,然后将这些离散的 u(·)值输入到神经网络中:

图 6. 在被输入到神经网络之前,u(·)轮廓被离散化。(图片由作者提供)

这些位置在原始 DeepONet 论文中被称为传感器

2️⃣ 我们应该如何将输入t和 u(·)拼接在一起?

初看之下,我们可能会想直接在输入层将它们拼接起来。然而,事实证明,这种简单的方法不仅会限制我们可以使用的神经网络类型,而且在实践中会导致次优的预测准确度。不过,还有更好的方法。现在是介绍DeepONet的时候了。

总之,DeepONet 提出了一种用于算子学习的新网络架构:它由两个主要组件组成:分支网络主干网络。分支网络将离散函数值作为输入,并将其转换为特征向量。同时,主干网络将坐标(在我们当前的案例研究中,坐标仅为t。对于 PDE,将包括时间和空间坐标)作为输入,并将其也转换为相同维度的特征向量。这两个特征向量然后通过点积进行合并,最终结果用作在输入坐标处评估 s(·)的预测值。

图 7. DeepONet 包括一个分支网络来处理输入函数 u(·)和一个主干网络来处理时间/空间坐标。两个网络的输出具有相同的维度,并通过点积进行合并。可选地,可以在点积后添加一个偏置项以进一步提高模型的表达能力。(图片由作者提供)

在原始 DeepONet 论文中,作者指出,这种在“分支”和“主干”网络中体现的“分而治之”策略受到算子通用逼近定理的启发,旨在为算子学习引入强的归纳偏置。这也是作者声称使 DeepONet 成为有效解决方案的关键点。

如果你想了解更多关于 DeepONet 理论基础的内容,请参考原始论文的附录 A。

DeepONet 的主要优势之一是其效率。一旦训练完成,DeepONet 可以实时推断新的输入函数的输出函数,无需进一步训练,只要新的输入函数在其训练过的输入函数范围内。这使 DeepONet 成为需要实时推断的应用中的强大工具。

DeepONet 的另一个显著优势在于其灵活性和多功能性。虽然主干网络和分支网络最常见的选择是全连接层,但 DeepONet 框架允许高度的架构自定义。根据输入函数 u(·)和坐标的特征,可以采用各种神经网络架构,如 CNN、RNN 等。这种适应性使 DeepONet 成为一个高度多功能的工具。

然而,尽管存在这些优势,DeepONet 的局限性也很明显:作为一种纯数据驱动的方法,DeepONet 不能保证其预测结果会遵循描述所考虑物理系统的先验知识或控制方程。因此,DeepONet 可能无法很好地泛化,尤其是当面对位于训练数据分布之外的输入函数,即分布外(OOD)输入时。对此的一个常见解决方案是准备大量训练数据,但在实际中这可能并不总是可行,特别是在数据收集可能昂贵或耗时的科学和工程领域。

那么我们应该如何解决这些局限性呢?是时候讨论物理信息学习,更具体地说,是物理信息神经网络(PINNs)了。

2.2 物理信息神经网络(PINNs)

在传统的机器学习模型中,我们主要依赖数据来学习潜在的模式。然而,在许多科学和工程领域,捕捉我们对动态系统的先验知识的控制方程(ODE/PDE)是可用的,它们提供了除了观察数据之外的另一种信息来源。如果正确地将这一额外的知识源纳入模型中,它可能会改善模型的性能和泛化能力,特别是在处理有限或噪声数据时。这就是物理信息学习的作用所在。

当我们将物理信息学习与神经网络的概念结合时,我们将得到物理信息神经网络(PINNs)。

PINNs 是一种神经网络,其中网络不仅仅是拟合数据,还要尊重由微分方程描述的已知物理定律。这是通过引入ODE/PDE 损失来实现的,它测量了控制微分方程的违反程度。通过这种方式,我们将物理定律注入网络训练过程,使其物理信息化

图 8. 物理信息神经网络的损失函数包括 PDE 损失的贡献项,这有效地测量了预测解是否满足控制微分方程。注意,由于自动微分的存在,相对于输入的输出的导数可以很容易地计算出来。(图片来源:作者)

尽管 PINNs 在许多应用中已被证明有效,但它们也不是没有局限性。PINNs 通常是针对特定的输入参数(例如边界和初始条件、外部强迫等)进行训练的。因此,每当输入参数发生变化时,我们就需要重新训练 PINN。因此,它们在不同操作条件下的实时推断效果不是特别好。

还记得哪个方法是专门用于处理变化的输入参数的吗?没错,就是 DeepONet!现在是将物理信息学习的理念与 DeepONet 结合的时候了。

2.3 物理信息 DeepONet

物理信息 DeepONet的主要思想是结合 DeepONets 和 PINNs 的优点。就像 DeepONet 一样,物理信息 DeepONet 能够将一个函数作为输入,并产生一个函数作为输出。这使得它在实时推断新输入函数时非常高效,无需重新训练。

另一方面,像 PINN 一样,物理信息 DeepONet 在学习过程中融入了已知的物理定律。这些定律作为额外的约束引入到训练过程中的损失函数中。这种方法使得模型即使在处理有限或嘈杂数据时,也能做出物理一致的预测。

我们如何实现这种整合呢?类似于 PINNs,我们增加一个额外的损失项,以衡量模型的预测如何符合已知的微分方程。通过优化这个损失函数,模型学会进行数据一致(如果在训练过程中提供了测量数据)和物理一致的预测。

图 10. 物理信息 DeepONet 使用 DeepONet 作为骨干架构,同时利用物理信息学习的概念来训练模型。这样,训练后的物理信息 DeepONet 既数据一致又物理一致。(图像由作者提供)

总结来说,物理信息 DeepONet 是一个强大的工具,结合了两者的优势:DeepONet 的高效性和物理信息学习的准确性。它代表了一种有前景的方法,用于解决那些实时推断和物理一致性都至关重要的复杂问题。

在下一部分,我们将开始进行案例研究,并将理论转化为实际代码。

3. 物理信息 DeepONet 的实现

在这一部分,我们将详细讲解如何定义一个物理信息 DeepONet 模型,以解决我们的目标案例研究。我们将使用 TensorFlow 来实现它。让我们先导入必要的库:

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
tf.random.set_seed(42)

3.1 定义架构

如前所述,物理信息 DeepONet 与原始 DeepONet 具有相同的架构。以下函数定义了 DeepONet 的架构:

def create_model(mean, var, verbose=False):
    """Definition of a DeepONet with fully connected branch and trunk layers.

    Args:
    ----
    mean: dictionary, mean values of the inputs
    var: dictionary, variance values of the inputs
    verbose: boolean, indicate whether to show the model summary

    Outputs:
    --------
    model: the DeepONet model
    """

    # Branch net
    branch_input = tf.keras.Input(shape=(len(mean['forcing'])), name="forcing")
    branch = tf.keras.layers.Normalization(mean=mean['forcing'], variance=var['forcing'])(branch_input)
    for i in range(3):
        branch = tf.keras.layers.Dense(50, activation="tanh")(branch)

    # Trunk net
    trunk_input = tf.keras.Input(shape=(len(mean['time'])), name="time")
    trunk = tf.keras.layers.Normalization(mean=mean['time'], variance=var['time'])(trunk_input)   
    for i in range(3):
        trunk = tf.keras.layers.Dense(50, activation="tanh")(trunk)

    # Compute the dot product between branch and trunk net
    dot_product = tf.reduce_sum(tf.multiply(branch, trunk), axis=1, keepdims=True)

    # Add the bias
    output = BiasLayer()(dot_product)

    # Create the model
    model = tf.keras.models.Model(inputs=[branch_input, trunk_input], outputs=output)

    if verbose:
        model.summary()

    return model 

在上面的代码中:

  1. 我们假设主干网络和分支网络都是完全连接的网络,每个网络有 3 个隐藏层,每层包含 50 个神经元,并且使用 tanh 激活函数。这个架构是基于初步测试选择的,并且应该作为这个问题的一个良好的起点。然而,可以很容易地用其他架构(例如 CNN、RNN 等)和其他层超参数进行替换。

  2. 主干网络和分支网络的输出通过点积合并。正如原始 DeepONet 论文中建议的,我们添加了一个偏置项以提高预测准确性。BiasLayer()是一个自定义定义的类,用于实现这个目标:

class BiasLayer(tf.keras.layers.Layer):
    def build(self, input_shape):
        self.bias = self.add_weight(shape=(1,),
                                    initializer=tf.keras.initializers.Zeros(),
                                    trainable=True)
    def call(self, inputs):
        return inputs + self.bias

3.2 定义 ODE 损失

接下来,我们定义一个函数来计算 ODE 损失。回顾一下我们的目标 ODE 是:

因此,我们可以按如下方式定义该函数:

@tf.function
def ODE_residual_calculator(t, u, u_t, model):
    """ODE residual calculation.

    Args:
    ----
    t: temporal coordinate
    u: input function evaluated at discrete temporal coordinates
    u_t: input function evaluated at t
    model: DeepONet model

    Outputs:
    --------
    ODE_residual: residual of the governing ODE
    """

    with tf.GradientTape() as tape:
        tape.watch(t)
        s = model({"forcing": u, "time": t})

    # Calculate gradients
    ds_dt = tape.gradient(s, t)

    # ODE residual
    ODE_residual = ds_dt - u_t

    return ODE_residual

在上面的代码中:

  1. 我们使用tf.GradientTape()来计算 s(·)相对于t的梯度。请注意,在 TensorFlow 中,tf.GradientTape()作为上下文管理器使用,任何在 tape 上下文中执行的操作都会被 tape 记录。在这里,我们显式地观察变量t。因此,TensorFlow 会自动跟踪涉及t的所有操作,在这种情况下,它是 DeepONet 模型的前向传播。之后,我们使用 tape 的gradient()方法来计算 s(·)相对于t的梯度。

  2. 我们包括了一个额外的输入参数u_t,它表示在t时刻评估的输入函数 u(·)的值。这构成了我们目标 ODE 的右侧项,并且它是计算 ODE 残差损失所需的。

  3. 我们使用@tf.function装饰器将我们刚刚定义的常规 Python 函数转换为 TensorFlow 图。这是有用的,因为梯度计算可能非常昂贵,并且在图模式下执行可以显著加速计算。

3.3 定义梯度下降步骤

接下来,我们定义一个函数来编译总损失函数并计算总损失相对于网络模型参数的梯度:

@tf.function
def train_step(X, X_init, IC_weight, ODE_weight, model):
    """Calculate gradients of the total loss with respect to network model parameters.

    Args:
    ----
    X: training dataset for evaluating ODE residuals
    X_init: training dataset for evaluating initial conditions
    IC_weight: weight for initial condition loss
    ODE_weight: weight for ODE loss
    model: DeepONet model

    Outputs:
    --------
    ODE_loss: calculated ODE loss
    IC_loss: calculated initial condition loss
    total_loss: weighted sum of ODE loss and initial condition loss
    gradients: gradients of the total loss with respect to network model parameters.
    """
    with tf.GradientTape() as tape:
        tape.watch(model.trainable_weights)

        # Initial condition prediction
        y_pred_IC = model({"forcing": X_init[:, 1:-1], "time": X_init[:, :1]})

        # Equation residual
        ODE_residual = ODE_residual_calculator(t=X[:, :1], u=X[:, 1:-1], u_t=X[:, -1:], model=model)

        # Calculate loss
        IC_loss = tf.reduce_mean(keras.losses.mean_squared_error(0, y_pred_IC))
        ODE_loss = tf.reduce_mean(tf.square(ODE_residual))

        # Total loss
        total_loss = IC_loss*IC_weight + ODE_loss*ODE_weight

    gradients = tape.gradient(total_loss, model.trainable_variables)

    return ODE_loss, IC_loss, total_loss, gradients

在上面的代码中:

  1. 我们只考虑两个损失项:与初始条件相关的损失IC_loss和 ODE 残差损失ODE_lossIC_loss通过将模型预测的 s(t=0)与已知的初始值 0 进行比较来计算,ODE_loss通过调用我们之前定义的ODE_residual_calculator函数来计算。如果有可用的测量 s(t)值(在上面的代码中未实现),数据损失也可以计算并添加到总损失中。

  2. 通常,总损失是IC_lossODE_loss的加权和,其中权重控制在训练过程中对这些单独损失项的重视程度或优先级。在我们的案例研究中,将IC_weightODE_weight都设置为 1 就足够了。

  3. 类似于我们计算ODE_loss的方式,我们也采用了tf.GradientTape()作为上下文管理器来计算梯度。然而,这里我们计算的是总损失相对于网络模型参数的梯度,这对于执行梯度下降是必要的。

在继续之前,让我们快速总结一下我们到目前为止所开发的内容:

1️⃣ 我们可以使用create_model()函数初始化一个 DeepONet 模型。

2️⃣ 我们可以计算 ODE 残差,以评估模型预测与所治理 ODE 的契合程度。这是通过ODE_residual_calculator函数实现的。

3️⃣ 我们可以使用train_step计算总损失及其相对于网络模型参数的梯度。

现在准备工作完成了一半🚀 在下一节中,我们将讨论数据生成和数据组织的问题(上述代码中的奇怪X[:, :1]会在那时变得清晰)。之后,我们终于可以训练模型并查看其表现。

4. 数据生成和组织

在本节中,我们讨论合成数据的生成及其在训练 Physics-informed DeepONet 模型中的组织方式。

4.1 生成 u(·)特征

用于训练、验证和测试的数据将是合成生成的。这样做的理由有两个:不仅方便,而且可以完全控制数据的特征。

在我们的案例研究中,我们将使用零均值的高斯过程生成输入函数 u(·),并使用径向基函数(RBF)核。

高斯过程是一种强大的数学框架,常用于机器学习中建模函数。RBF 核是捕捉输入点之间相似性的热门选择。通过在高斯过程中使用 RBF 核,我们确保生成的合成数据表现出平滑和连续的模式,这在各种应用中通常是有利的。如需了解更多关于高斯过程的内容,请随时查看我之前的博客

在 scikit-learn 中,这可以通过几行代码实现:

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF

def create_samples(length_scale, sample_num):
    """Create synthetic data for u(·)

    Args:
    ----
    length_scale: float, length scale for RNF kernel
    sample_num: number of u(·) profiles to generate

    Outputs:
    --------
    u_sample: generated u(·) profiles
    """

    # Define kernel with given length scale
    kernel = RBF(length_scale)

    # Create Gaussian process regressor
    gp = GaussianProcessRegressor(kernel=kernel)

    # Collocation point locations
    X_sample = np.linspace(0, 1, 100).reshape(-1, 1) 

    # Create samples
    u_sample = np.zeros((sample_num, 100))
    for i in range(sample_num):
        # sampling from the prior directly
        n = np.random.randint(0, 10000)
        u_sample[i, :] = gp.sample_y(X_sample, random_state=n).flatten()  

    return u_sample

在上面的代码中:

  1. 我们使用length_scale来控制生成函数的形状。对于 RBF 核,图 11 展示了不同核长度尺度下的 u(·)特征。

  2. 请记住,我们需要在将 u(·)输入 DeepONet 之前对其进行离散化。这是通过指定X_sample变量来完成的,该变量在我们感兴趣的时间域内分配 100 个均匀分布的点。

  3. 在 scikit-learn 中,GaussianProcessRegressor对象提供了一个sample_y方法,用于从具有长度尺度指定核的高斯过程抽取随机样本。注意,我们在使用GaussianProcessRegressor对象之前并没有调用.fit(),这与我们通常对其他 scikit-learn 回归器的做法不同。这是故意的,因为我们希望GaussianProcessRegressor使用我们提供的精确length_scale。如果你调用.fit()length_scale将被优化为另一个值以更好地拟合给定的数据。

  4. 输出u_sample是一个维度为 sample_num * 100 的矩阵。u_sample的每一行表示一个 u(·)的特征,其中包含 100 个离散值。

图 11. 在不同核长度尺度下的合成 u(·) 轮廓。(图片来源:作者)

4.2 数据集生成

现在我们已经生成了 u(·) 轮廓,让我们关注如何组织数据集,以便它可以输入到 DeepONet 模型中。

请记住,我们在上一节中开发的 DeepONet 模型需要 3 个输入:

  1. 时间坐标 t,这是介于 0 和 1 之间的标量(暂时不考虑批量大小);

  2. u(·) 的轮廓,这是一个由在预定义的、固定的时间坐标(介于 0 和 1 之间)下评估的 u(·) 值组成的向量;

  3. u(t) 的值,这也是一个标量。这个 u(t) 值用于在时间坐标 t 下计算 ODE 损失。

因此,我们可以这样构建一个单一的样本:

(图片来源:作者)

当然,对于每个 u(·) 轮廓(在上图中标记为绿色),我们应考虑多个 t(及其对应的 u(t))来评估 ODE 损失,以更好地施加物理约束。理论上,t 可以取考虑的时间域内的任何值(即在我们案例研究中为 0 和 1 之间)。然而,为了简化,我们只考虑在 u(·) 轮廓离散化的相同时间位置的 t。因此,我们更新后的数据集将如下所示:

(图片来源:作者)

请注意,上述讨论仅考虑了单一的 u(·) 轮廓。如果我们考虑所有的 u(·) 轮廓,我们的最终数据集将如下所示:

(图片来源:作者)

其中 N 代表 u(·) 轮廓的数量。现在有了这个前提,让我们看看一些代码:

from tqdm import tqdm
from scipy.integrate import solve_ivp

def generate_dataset(N, length_scale, ODE_solve=False):
    """Generate dataset for Physics-informed DeepONet training.

    Args:
    ----
    N: int, number of u(·) profiles
    length_scale: float, length scale for RNF kernel
    ODE_solve: boolean, indicate whether to compute the corresponding s(·)

    Outputs:
    --------
    X: the dataset for t, u(·) profiles, and u(t)
    y: the dataset for the corresponding ODE solution s(·)
    """

    # Create random fields
    random_field = create_samples(length_scale, N)

    # Compile dataset
    X = np.zeros((N*100, 100+2))
    y = np.zeros((N*100, 1))

    for i in tqdm(range(N)):
        u = np.tile(random_field[i, :], (100, 1))
        t = np.linspace(0, 1, 100).reshape(-1, 1)

        # u(·) evaluated at t
        u_t = np.diag(u).reshape(-1, 1)

        # Update overall matrix
        X[i*100:(i+1)*100, :] = np.concatenate((t, u, u_t), axis=1)

        # Solve ODE
        if ODE_solve:
            sol = solve_ivp(lambda var_t, var_s: np.interp(var_t, t.flatten(), random_field[i, :]), 
                            t_span=[0, 1], y0=[0], t_eval=t.flatten(), method='RK45')
            y[i*100:(i+1)*100, :] = sol.y[0].reshape(-1, 1)

    return X, y

在上述代码中,我们添加了一个选项,用于计算给定 u(·) 轮廓的相应 s(·)。虽然我们在训练中不会使用 s(·) 值,但我们仍然需要它们来测试模型性能。 s(·) 的计算是通过使用 scipy.integrate.solve_ivp 实现的,这是一个来自 SciPy 的 ODE 求解器,专门设计用于解决初值问题。

现在我们可以生成训练、验证和测试数据集。请注意,对于本案例研究,我们将使用 0.4 的长度尺度生成 u(·) 轮廓,并训练物理信息 DeepONet。

# Create training dataset
N_train = 2000
length_scale_train = 0.4
X_train, y_train = generate_dataset(N_train, length_scale_train)

# Create validation dataset
N_val = 100
length_scale_test = 0.4
X_val, y_val = generate_dataset(N_val, length_scale_test)

# Create testing dataset
N_test = 100
length_scale_test = 0.4
X_test, y_test = generate_dataset(N_test, length_scale_test, ODE_solve=True)

4.3 数据集组织

最后,我们将 NumPy 数组转换为 TensorFlow 数据集对象,以便于数据输入。

# Determine batch size
ini_batch_size = int(2000/100)
col_batch_size = 2000

# Create dataset object (initial conditions)
X_train_ini = tf.convert_to_tensor(X_train[X_train[:, 0]==0], dtype=tf.float32)
ini_ds = tf.data.Dataset.from_tensor_slices((X_train_ini))
ini_ds = ini_ds.shuffle(5000).batch(ini_batch_size)

# Create dataset object (collocation points)
X_train = tf.convert_to_tensor(X_train, dtype=tf.float32)
train_ds = tf.data.Dataset.from_tensor_slices((X_train))
train_ds = train_ds.shuffle(100000).batch(col_batch_size)

# Scaling 
mean = {
    'forcing': np.mean(X_train[:, 1:-1], axis=0),
    'time': np.mean(X_train[:, :1], axis=0)
}

var = {
    'forcing': np.var(X_train[:, 1:-1], axis=0),
    'time': np.var(X_train[:, :1], axis=0)
}

在上面的代码中,我们创建了两个不同的数据集:一个用于评估 ODE 损失(train_ds),另一个用于评估初始条件损失(ini_ds)。我们还预先计算了 t 和 u(·) 的均值和方差。这些值将用于标准化输入。

数据生成和组织的部分就到这里。接下来,我们将启动模型训练并查看其表现。

5. 训练物理信息 DeepONet

作为第一步,让我们创建一个自定义类来跟踪损失演变:

from collections import defaultdict

class LossTracking:

    def __init__(self):
        self.mean_total_loss = keras.metrics.Mean()
        self.mean_IC_loss = keras.metrics.Mean()
        self.mean_ODE_loss = keras.metrics.Mean()
        self.loss_history = defaultdict(list)

    def update(self, total_loss, IC_loss, ODE_loss):
        self.mean_total_loss(total_loss)
        self.mean_IC_loss(IC_loss)
        self.mean_ODE_loss(ODE_loss)

    def reset(self):
        self.mean_total_loss.reset_states()
        self.mean_IC_loss.reset_states()
        self.mean_ODE_loss.reset_states()

    def print(self):
        print(f"IC={self.mean_IC_loss.result().numpy():.4e}, \
              ODE={self.mean_ODE_loss.result().numpy():.4e}, \
              total_loss={self.mean_total_loss.result().numpy():.4e}")

    def history(self):
        self.loss_history['total_loss'].append(self.mean_total_loss.result().numpy())
        self.loss_history['IC_loss'].append(self.mean_IC_loss.result().numpy())
        self.loss_history['ODE_loss'].append(self.mean_ODE_loss.result().numpy())

然后,我们定义了主要的训练/验证逻辑:

# Set up training configurations
n_epochs = 300
IC_weight= tf.constant(1.0, dtype=tf.float32)   
ODE_weight= tf.constant(1.0, dtype=tf.float32)
loss_tracker = LossTracking()
val_loss_hist = []

# Set up optimizer
optimizer = keras.optimizers.Adam(learning_rate=1e-3)

# Instantiate the PINN model
PI_DeepONet= create_model(mean, var)
PI_DeepONet.compile(optimizer=optimizer)

# Configure callbacks
_callbacks = [keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=30),
             tf.keras.callbacks.ModelCheckpoint('NN_model.h5', monitor='val_loss', save_best_only=True)]
callbacks = tf.keras.callbacks.CallbackList(
                _callbacks, add_history=False, model=PI_DeepONet)

# Start training process
for epoch in range(1, n_epochs + 1):  
    print(f"Epoch {epoch}:")

    for X_init, X in zip(ini_ds, train_ds):

        # Calculate gradients
        ODE_loss, IC_loss, total_loss, gradients = train_step(X, X_init, 
                                                            IC_weight, ODE_weight,
                                                            PI_DeepONet)
        # Gradient descent
        PI_DeepONet.optimizer.apply_gradients(zip(gradients, PI_DeepONet.trainable_variables))

        # Loss tracking
        loss_tracker.update(total_loss, IC_loss, ODE_loss)

    # Loss summary
    loss_tracker.history()
    loss_tracker.print()
    loss_tracker.reset()

    ####### Validation
    val_res = ODE_residual_calculator(X_val[:, :1], X_val[:, 1:-1], X_val[:, -1:], PI_DeepONet)
    val_ODE = tf.cast(tf.reduce_mean(tf.square(val_res)), tf.float32)

    X_val_ini = X_val[X_val[:, 0]==0]
    pred_ini_valid = PI_DeepONet.predict({"forcing": X_val_ini[:, 1:-1], "time": X_val_ini[:, :1]}, batch_size=12800)
    val_IC = tf.reduce_mean(keras.losses.mean_squared_error(0, pred_ini_valid))
    print(f"val_IC: {val_IC.numpy():.4e}, val_ODE: {val_ODE.numpy():.4e}, lr: {PI_DeepONet.optimizer.lr.numpy():.2e}")

    # Callback at the end of epoch
    callbacks.on_epoch_end(epoch, logs={'val_loss': val_IC+val_ODE})
    val_loss_hist.append(val_IC+val_ODE)

    # Re-shuffle dataset
    ini_ds = tf.data.Dataset.from_tensor_slices((X_train_ini))
    ini_ds = ini_ds.shuffle(5000).batch(ini_batch_size)

    train_ds = tf.data.Dataset.from_tensor_slices((X_train))
    train_ds = train_ds.shuffle(100000).batch(col_batch_size)

这是一段相当长的代码,但它应该是自解释的,因为我们已经覆盖了所有重要部分。

为了可视化训练性能,我们可以绘制损失收敛曲线:

# History
fig, ax = plt.subplots(1, 3, figsize=(12, 4))
ax[0].plot(range(n_epochs), loss_tracker.loss_history['IC_loss'])
ax[1].plot(range(n_epochs), loss_tracker.loss_history['ODE_loss'])
ax[2].plot(range(n_epochs), val_loss_hist)
ax[0].set_title('IC Loss')
ax[1].set_title('ODE Loss')
ax[2].set_title('Val Loss')
for axs in ax:
    axs.set_yscale('log')

训练结果如下所示:

图 12. 损失收敛图。(图像由作者提供)

此外,我们还可以看到在训练过程中某一特定目标 s(·)的预测准确性如何变化:

在训练开始时,我们可以看到模型预测与真实值之间存在明显的差异。然而,到训练结束时,预测的 s(·)已经收敛到真实值。这表明我们的物理信息深度网络学习得很充分。

6. 结果讨论

一旦训练完成,我们可以重新加载保存的权重并评估性能。

在这里,我们随机挑选了三个 u(·)轮廓从测试数据集中,并将其对应的 s(·)与我们的物理信息深度网络预测的结果以及由数值 ODE 求解器计算的真实值进行了比较。我们可以看到,预测结果与真实值几乎无法区分。

图 13. 从测试数据集中随机选择了三个 u(·)轮廓,这些轮廓显示在上排。下排显示了对应的 s(·)轮廓。我们可以看到,物理信息深度网络预测的结果与由数值 ODE 求解器计算的真实值几乎无法区分。(图像由作者提供)

这些结果相当令人惊讶,考虑到我们甚至没有使用任何 s(·)的观测数据(除了初始条件)来训练 DeepONet。这表明控制 ODE 本身为模型提供了足够的“监督”信号,以做出准确的预测。

另一个有趣的评估点是所谓的“分布外”预测能力。由于我们在训练 DeepONet 时强制执行了控制方程,我们可以预期训练得到的物理信息深度网络在 u(·)轮廓超出训练 u(·)分布时仍能做出不错的预测。

为了测试这一点,我们可以使用不同的长度尺度生成 u(·)轮廓。以下结果显示了使用 0.6 长度尺度生成的三个 u(·)轮廓,以及预测的 s(·)。这些结果相当不错,考虑到物理信息深度网络是用 0.4 的长度尺度训练的。

图 14. 训练得到的物理信息深度网络显示了一定程度的分布外预测能力。(图像由作者提供)

然而,如果我们继续将长度尺度减小到 0.2,我们会注意到明显的差异开始出现。这表明训练得到的物理信息深度网络(DeepONet)在泛化能力上存在限制。

图 15. 物理信息深度 ONet 的泛化能力是有限的。(作者提供的图像)

较小的长度尺度通常会导致更复杂的 u(·)轮廓,这与用于训练的 u(·)轮廓可能有很大不同。这可以解释为何训练后的模型在较小长度尺度区域预测准确性遇到挑战。

图 16. 我们训练的模型在泛化到较小长度尺度区域时面临挑战,因为 u(·)轮廓更复杂,与训练数据有较大区别。(作者提供的图像)

总的来说,我们可以说开发的物理信息深度 ONet 能够在仅有 ODE 约束的情况下正确学习系统动态和从输入函数到输出函数的映射。此外,物理信息深度 ONet 在处理“超分布”预测方面显示出一定的能力,这表明训练模型与控制 ODE 对齐可以提高模型的泛化能力。

7. 收获

我们在探索物理信息深度 ONet 的过程中走了很长一段路。从理解深度 ONet 和物理信息学习的基本概念,到通过代码实现看到它们的实际应用,我们已经详细讲解了这一强大方法在求解微分方程中的应用。

这里有几点关键的收获:

1️⃣ 深度 ONet是一个强大的操作符学习框架,得益于其创新的分支和主干网络架构。

2️⃣ 物理信息学习明确地将动态系统的控制微分方程纳入学习过程,从而具有提高模型解释性和泛化能力的潜力。

3️⃣ 物理信息深度 ONet结合了深度 ONet 和物理信息学习的优势,呈现出作为学习功能映射的有前景工具,同时遵循相关的控制方程。

希望你喜欢这次对物理信息深度 ONet的深入探讨。接下来,我们将转向使用物理信息深度 ONet 解决逆问题。敬请关注!

如果你觉得我的内容有用,可以在这里请我喝杯咖啡🤗 非常感谢你的支持!

你可以在这里找到包含完整代码的辅助笔记本 💻

如果你也对使用物理信息深度 ONet 解决逆问题感兴趣,请随时查看我的新博客:使用物理信息深度 ONet 解决逆问题:带代码实现的实用指南

如果你想了解最新的物理知识学习最佳实践,请查看我目前正在进行的设计模式系列:揭示物理知识驱动神经网络的设计模式

你也可以订阅我的通讯或者在Medium上关注我。

参考文献

[1] Lu 等人,DeepONet:基于算子通用近似定理的非线性算子学习,用于识别微分方程。arXiv, 2019。

[2] Wang 等人,学习参数偏微分方程的解算符,基于物理知识的 DeepOnets。arXiv, 2021。

Optical Flow with RAFT: 第一部分

原文:towardsdatascience.com/optical-flow-with-raft-part-1-f984b4a33993?source=collection_archive---------1-----------------------#2023-10-03

深入探讨光流的深度学习

Isaac BerriosTowards Data Science Isaac Berrios

·

关注 发布于 Towards Data Science ·14 分钟阅读·2023 年 10 月 3 日

--

照片由 Zdeněk Macháček 提供,发布于 Unsplash

在这篇文章中,我们将学习一种旗舰级深度学习方法,这种方法在 2020 年获得了ECCV最佳论文奖,并被引用超过 1000 次。它也是许多顶级模型在KITTI 基准测试中的基础。这个模型叫做RAFT: Recurrent All-Pairs Field Transforms for Optical Flow,可以在PyTorchGitHub上轻松获取。实现使其非常易于获取,但模型复杂,理解起来可能会令人困惑。在这篇文章中,我们将把 RAFT 分解为其基本组成部分,并详细了解它们。然后,我们将学习如何在 Python 中使用它来估计光流。在第二部分中,我们将探索隐秘的细节并可视化不同的模块,以便深入理解它们的工作原理。

  • 介绍

  • RAFT 的基础

  • 视觉相似性

  • 迭代更新

  • 如何使用 RAFT

  • 结论

介绍

光流

光流是图像序列中像素的表观运动。为了估计光流,场景中物体的运动必须有相应的亮度位移。这意味着图像中的一个移动的红球在下一张图像中应具有相同的亮度和颜色,这使我们能够确定它在像素上的移动量。图 1 展示了一个光流的例子,其中一个逆时针旋转的天花板风扇被一系列图像捕捉到。

图 1. 图像序列的光流估计。帧 1,帧 2,帧 1 和帧 2 之间计算的光流。来源:作者。

最右边的彩色图像包含了从帧 1 到帧 2 的每个像素的表观运动,它的颜色编码方式使得不同的颜色表示像素运动的不同水平和垂直方向。这是一个密集光流估计的例子。

稠密光流的估计为每个像素分配一个二维流向量,描述其在时间间隔内的水平和垂直位移。在稀疏光流中,这个向量仅分配给对应于强特征(如角点和边缘)的像素。为了使流向量存在,该像素在时间 t 的亮度必须与时间 t+1 时相同,这被称为 亮度一致性假设位置(x,y)* 在时间t 的图像强度或亮度由 I(x,y,t) 给出。下面的图 2 展示了已知像素位移的示例,其中 dxdy 是水平和垂直图像位移,dt 是帧之间的时间差。

图 2。像素从时间 t 到 t+dt 的位移。亮度一致性假设意味着该像素在两个帧中具有相同的颜色和强度。来源:作者。

亮度一致性假设意味着在 (x,y,t) 处的像素在 (x+dx, y+dy, t+dy) 处将具有相同的强度。因此:I(x, y, t) = I(x+dx, y+dy, t+dt)

从亮度一致性假设出发,我们可以通过在(x, y, t) 周围展开右侧的 1ˢᵗ阶泰勒近似来推导光流方程[1]。

光流方程的推导。来源:作者。

水平和垂直梯度 IₓIᵧ 可以通过 Sobel 算子 进行近似,时间梯度 Iₜ 是已知的,因为我们有时间 tt+1 的图像。流方程有两个未知数 uv,分别是时间 dt 内的水平和垂直位移。一个方程中的两个未知数使这个问题成为一个 欠定 问题,许多尝试都旨在解决 uv。RAFT 是一种深度学习方法,用于估计 uv,但实际上,它比仅仅基于两帧预测光流要复杂得多。它经过精心设计,以准确估计光流场,下一节我们将深入探讨它的复杂细节。

RAFT 的基础

RAFT 是一个深度神经网络,能够估计给定一对连续图像的密集光流I₁I₂。它估计一个流位移场(, ),将每个像素(u, v)I₁中映射到I₂中对应的像素(u’, v’),其中(u’, v’) = (u + (u), v + (v))。它通过提取特征、寻找其相关性,然后以模拟优化算法的方式迭代更新流。初始流要么初始化为全 0,要么可以使用向前投影的先前流估计来初始化,这被称为温启动。整体架构如下所示。

图 3. RAFT 的架构。修改自

请注意,它包含三个主要模块:特征编码器模块、视觉相似性模块和迭代更新模块。RAFT 架构有两个版本,一个大版本有 480 万参数,一个小版本有 100 万参数,在这篇文章中我们将重点关注大版本,但理解小版本在理解大版本后意义不大。

特征提取

RAFT 通过一个包含六个残差块的卷积神经网络(CNN)对两个输入图像进行特征提取,并将每个图像下采样到 1/8 分辨率,具有 D 个特征图。

图 4. RAFT 的编码块。修改自

特征编码器网络g在两个图像上使用共享权重进行操作,而上下文编码器网络f仅在I₁上操作,并提取作为流估计的主要参考的特征。除了细微的差异外,两个网络的整体架构几乎相同。上下文网络使用批归一化,而特征网络使用实例归一化,上下文网络提取C = c + h特征图,其中c是上下文特征图的数量,h是将初始化迭代更新模块隐藏状态的隐藏特征图数量。

特征网络 f 和上下文网络 g 的函数映射。来源:作者。

注意:原始论文中经常使用特征图大小 H/8xW/8 的简写符号:HxW。这可能会令人困惑,因此我们将遵循 H’ = H/8 的约定,使特征图大小为 H’xW’。我们也将提及从 I₁ as g¹中提取的特征图张量,I₂亦然。

视觉相似性

相关体积

视觉相似性是一个 4D H’xW’xH’xW’的全对关联体积C,通过计算特征图的点积得到。

4D 相关体积的计算。修改自

在相关体积中,来自特征图的每个像素与特征图中的每个像素都有一个计算得到的相关性,我们称这些相关性中的每一个为2D 响应映射 (见图 5)。想象在 4D 空间中可能有些挑战,所以可以将体积的前两个维度展平:(H’xW’)xH’xW’,现在我们得到一个 3D 体积,其中的每个像素都有自己的 2D 响应映射,显示其与的每个像素位置的相关性。由于特征是从图像中提取的,响应映射实际上指示了I₁的给定像素与I₂的每个像素的相关程度。

视觉相似性是一种全对 Correlation Volume,通过计算每个像素位置处的每个特征图的相关性,将I₁的像素与I₂的每个单一像素联系起来

相关金字塔

相关体积有效地提供了关于小像素位移的信息,但可能难以捕捉较大的位移。为了捕捉大和小的像素位移,需要多个级别的相关性。为此,我们构建了一个包含多个相关体积级别的相关金字塔,其中不同级别的相关体积通过对相关体积的最后两个维度进行平均池化来生成。平均池化操作在体积的最后两个维度产生了I₂的粗略相关特征,这使得I₁的精细特征能够与I₂的逐渐粗略的特征相关联。每个金字塔级别包含越来越小的 2D 响应映射。

图 5。左:I₁中单个像素与I₂所有像素的关系。右:相关金字塔中各种相关体积的 2D 响应映射。来源

图 5 显示了不同平均池化级别的不同 2D 响应映射。相应的相关体积的尺寸被堆叠到一个 5D 相关金字塔中,其中包含四个级别的核大小:1、2、4 和 8。该金字塔提供了关于大和小位移的强大信息,同时保持对I₁的高分辨率。

相关查找

相关查找运算符 L꜀ 通过在每个级别的相关金字塔中索引特征来生成新的特征图。给定当前的光流估计(, )I₁的每个像素:x = (u, v)映射到其在I₂中估计的对应关系x’ = (u + f¹(u) + v + f²(v))。我们定义了x’周围的局部邻域:

以像素x’ = (u’, v’)为中心的半径r的邻域。来源:作者。

对应关系是基于其流估计的I₂像素的新位置

所有金字塔层级上的常量半径意味着更大的上下文将被纳入到较低层级中。 即,半径为 4 对应于原始分辨率下的 256 像素。

实际上,这个邻域是围绕每个细分分辨率像素中心的正方形网格,r = 4 时,我们在每个像素周围得到一个 9x9 的网格,其中每个维度的长度为 (2r + 1)。我们通过 双线性重采样 在网格定义的位置周围对每个像素的相关性特征进行重采样(边缘位置使用零填充)。由于流偏移和平均池化,邻域网格值可能是浮点数,双线性重采样通过对附近像素的 2x2 子邻域进行加权平均来处理这一点。换句话说,重采样将提供 亚像素 精度。我们在金字塔的每一层的所有像素位置进行重采样,这可以通过 PyTorch 的 F.grid_sample() 高效完成。这些重采样后的特征被称为 相关性特征,并输入到更新块中。

高效的相关性查找(可选)

相关性查找的复杂度为 O(N²),其中 N 是像素数量,这可能会成为大图像的瓶颈,但存在一种等效操作,其复杂度为 O(NM),其中 M 是金字塔层数。该操作将相关性金字塔与查找相结合,利用了内积和平均池化的线性特性。下图显示了在 2ᵐx2ᵐ 网格上的平均相关性响应 Cᵐ(金字塔层级 m)。

等效的相关性实现。来源

对于给定的金字塔层级 m,我们不需要对特征图 进行求和,这意味着可以通过将特征图 与平均池化后的特征图 进行内积来计算相关性,这具有 O(N) 的复杂度。由于这仅适用于单个金字塔层级 m,我们必须为每一层计算这个内积,使其复杂度为 O(M),总复杂度为 O(NM)。我们不是预计算金字塔的相关性,而是只预计算池化特征图,并在查找发生时按需计算相关性值。

迭代更新

更新操作符估计一系列光流:{f₀, f,…, f}** 从初始起点 f₀ 开始,该起点可以是全 0 或向前投影的先前光流估计(热启动)。在每次迭代 k 中,它产生一个光流更新方向 Δf,该方向被加到当前估计中:fₖ₊₁ = fₖ + Δfₖ。更新操作符模仿优化算法,并经过训练以提供更新,使得估计的光流序列收敛到一个固定点:fₖ → f*。

更新块

更新块的输入包括:相关特征、当前光流估计、上下文特征和隐藏特征。其结构及突出显示的子块如下所示。

图 6. 大型架构的 RAFT 更新块,不同子块突出显示。蓝色-特征提取块,红色 — 递归更新块,绿色 — 光流头。改编自 来源

更新块中的子块包括:

  • 特征提取块 — 从相关性、光流和 I₁(上下文网络)中提取运动特征。

  • 递归更新块 — 递归计算光流更新

  • 光流头 — 最终卷积层,将光流估计重新调整为 H/8 x W/8 x 2

如图 6 所示,递归更新块的输入是光流、相关性和上下文特征的连接。潜在的隐藏状态由上下文网络中的隐藏特征初始化。(上下文网络提取了一堆 2D 特征图,然后将其分离为上下文特征图和隐藏特征图)。递归更新块由 2 个可分离的 ConvGRU 组成,这些 ConvGRU 可以在不显著增加网络规模的情况下增加感受野。在每次更新时,递归更新块中的隐藏状态被传递到光流头,以获得尺寸为 H/8 x W/8 x 2 的光流估计。该估计随后使用凸上采样进行上采样。

凸上采样

RAFT 的作者实验了 双线性 和凸上采样,并发现凸上采样提供了显著的性能提升。

图 7. 双线性 VS 凸上采样的比较。 来源

凸上采样 估计每个细像素为其相邻 3x3 粗像素的 凸组合

让我们分解一下凸上采样的工作原理,下面的图 8 提供了一个很好的视觉示例。

图 8. 单个全分辨率像素(紫色)的凸上采样示例。 来源

首先,我们假设一个细分分辨率的像素是其最近的粗分辨率邻居的凸组合。这一假设意味着粗分辨率像素的加权和必须等于真实的细分分辨率像素,且权重之和为一且非负。由于我们是按八倍因子上采样的,每个粗分辨率像素必须分解成 64 个(8x8)的细分像素(图 8 中的视觉效果不按比例)。我们还注意到 3x3 网格中心的每个 64 个像素都需要自己的权重集,总共需要的权重数为:(H/8 x W/8 x (8x8x9))。

实际上,权重由神经网络参数化,凸上采样块使用两个卷积层来预测一个(H/8 x W/8 x (8x8x9))的掩码,然后对九个邻居的权重进行 softmax,得到形状为(H/8 x W/8 x (8x8))的掩码。然后我们使用这个掩码来获得邻域的加权组合,并重新调整以得到 HxWx2 的流场。

训练

RAFT 的目标函数能够捕捉所有迭代的流预测。形式上,它是流预测和真实值之间加权的l1距离的总和,权重以指数形式增长。

RAFT 的损失,γ = 0.8。 来源

如何使用 RAFT

我们可以使用 RAFT 来估计我们自己图像上的密集光流。首先,我们需要克隆GitHub 仓库并下载模型。此教程的代码在GitHub上。

!git clone https://github.com/princeton-vl/RAFT.git

%cd RAFT
!./download_models.sh
%cd ..

预训练的 RAFT 模型有几种不同的版本,根据作者,它们是:

  • raft-chairs — 在 FlyingChairs 上训练

  • raft-things — 在 FlyingChairs + FlyingThings 上训练

  • raft-sintel — 在 FlyingChairs + FlyingThings + Sintel + KITTI 上训练(用于提交的模型)

  • raft-kitti — raft-sintel 在仅 KITTI 上微调

  • raft-small — 在 FlyingChairs + FlyingThings 上训练

接下来,我们将 RAFT 的核心添加到路径中

sys.path.append('RAFT/core')

现在,我们需要一些辅助函数来与 RAFT 类接口。注意:这些辅助函数仅为 CUDA 编写,但你可以通过Colab轻松访问 GPU。

import torch
from raft import RAFT
from utils import flow_viz
from utils.utils import InputPadder

def process_img(img, device='cuda'):
    return torch.from_numpy(img).permute(2, 0, 1).float()[None].to(device)

def load_model(weights_path, args):
    """ Loads model to CUDA only """
    model = RAFT(args)
    pretrained_weights = torch.load(weights_path, map_location=torch.device("cpu"))
    model = torch.nn.DataParallel(model)
    model.load_state_dict(pretrained_weights)
    model.to("cuda")
    return model

def inference(model, frame1, frame2, device='cuda', pad_mode='sintel',
              iters=12, flow_init=None, upsample=True, test_mode=True):

    model.eval()
    with torch.no_grad():
        # preprocess
        frame1 = process_img(frame1, device)
        frame2 = process_img(frame2, device)

        padder = InputPadder(frame1.shape, mode=pad_mode)
        frame1, frame2 = padder.pad(frame1, frame2)

        # predict flow
        if test_mode:
          flow_low, flow_up = model(frame1,
                                    frame2,
                                    iters=iters,
                                    flow_init=flow_init,
                                    upsample=upsample,
                                    test_mode=test_mode)
          return flow_low, flow_up

        else:
            flow_iters = model(frame1,
                               frame2,
                               iters=iters,
                               flow_init=flow_init,
                               upsample=upsample,
                               test_mode=test_mode)

            return flow_iters

def get_viz(flo):
    flo = flo[0].permute(1,2,0).cpu().numpy()
    return flow_viz.flow_to_image(flo)

注意到inference()中的输入填充,我们需要确保所有图像的尺寸都能被 8 整除。raft.py 代码可以从命令行轻松访问,但如果我们想要与之接口,我们需要重写部分代码,或者可以创建一个特殊的类来传递参数。

# class to interface with RAFT
class Args():
  def __init__(self, model='', path='', small=False, 
               mixed_precision=True, alternate_corr=False):
    self.model = model
    self.path = path
    self.small = small
    self.mixed_precision = mixed_precision
    self.alternate_corr = alternate_corr

  """ Sketchy hack to pretend to iterate through the class objects """
  def __iter__(self):
    return self

  def __next__(self):
    raise StopIteration

Args 类的默认初始化将直接与任何大型 RAFT 模型进行接口。为了演示 RAFT,我们将使用一个慢速旋转的天花板风扇视频的帧。现在我们可以加载一个模型并估算光流。

model = load_model("RAFT/models/raft-sintel.pth", args=Args())
flow_low, flow_up = inference(model, frame1, frame2, device='cuda', test_mode=True)

测试模式将返回 1/8 分辨率光流以及凸性上采样光流。

图 9. 上:RAFT 的输入图像序列。下:1/8 分辨率和上采样的光流估计。图片来源于作者。来源:作者。

再次,我们可以看到凸性上采样的显著优势,现在让我们来看看额外迭代的优势。图 10 展示了一个风扇图像上的 20 次迭代的 GIF 动画。

 flow_iters = inference(model, frame1, frame2, device='cuda', pad_mode=None, iters=20, test_mode=False)

图 10. 光流估计的渐进迭代。来源:作者。

我们可以看到前几次迭代的明显好处,在这种情况下,模型能够在约 10 次迭代中收敛。现在我们将尝试使用温初始化,将前一个 1/8 分辨率的流估计传递给推理函数。

# get previous estimate at 1/8 res
flow_lo, flow_up = inference(model, frame1, frame2, device='cuda', pad_mode=None, iters=20, test_mode=True)

# 0 initialization
flow_lo_cold, flow_up_cold = inference(model, frame2, frame3, device='cuda', pad_mode=None, flow_init=None, iters=20, test_mode=True)

# warm initialization
flow_lo_warm, flow_up_warm = inference(model, frame2, frame3, device='cuda', pad_mode=None, flow_init=flow_lo, iters=20, test_mode=True)

图 11. 在第 2 和第 3 帧之间进行 0 VS 温初始化的光流估计。来源:作者。

在这种情况下,我们并未看到任何改善,右侧的温初始化实际上看起来比初始化为 0 的流还要糟糕。对于这个视频序列来说,温启动的好处似乎微乎其微,但在不同的环境中可能会有用。

结论

在本文中,我们了解了 RAFT,一个能够估算准确流场的先进模型。RAFT 能够通过计算从提取的特征图中的所有像素的全对相关体积来捕捉所有像素之间的关系。建立相关金字塔以捕捉大和小的像素位移。查找运算符基于当前流估计从相关金字塔中提取新的相关特征。更新块使用相关特征和当前流估计提供迭代更新,收敛到最终的流估计,然后使用凸性上采样进行上采样。在第二部分,我们将详细探讨网络并了解一些关键块的工作方式。

参考资料

[1] Horn, B. K. P., & Schunck, B. G. (1981). 确定光流。人工智能, 17(1–3), 185–203. doi.org/10.1016/0004-3702(81)90024-2

[2] Teed, Z., & Deng, J. (2020). Raft: 循环全对场变换用于光流。计算机视觉 — ECCV 2020, 402–419. doi.org/10.1007/978-3-030-58536-5_24

RAFT 中的光流:第二部分

原文:towardsdatascience.com/optical-flow-with-raft-part-2-f0376a972c25?source=collection_archive---------8-----------------------#2023-10-03

解密深度光流模型

Isaac BerriosTowards Data Science Isaac Berrios

·

关注 发表在Towards Data Science ·10 分钟阅读·2023 年 10 月 3 日

--

图片由Kevin Hansen提供,来源于Unsplash

在这篇文章中,我们将以另一种方式来看待RAFT。第一部分中的直接方法能够详细解析网络的细节,但在这里我们将可视化这些细节并建立有价值的直觉。在第一部分中,我们的目标是理解 RAFT,以便我们可以直接使用它;而在第二部分,我们将旨在以一种方式理解 RAFT,使我们能够利用其架构的不同部分来构建自己的模型。

这里是本文的概览:

  • 动机

  • RAFT 架构

  • 查找操作符

  • 迭代更新

  • 结论

动机

RAFT 的概念被许多后续工作利用,理解 RAFT 是理解这些新方法的关键。你如何知道 RAFT 的哪些部分可以或应该被利用?为什么许多后续工作使用相关体?这些问题的答案来自于掌握 RAFT 的内部工作原理,仅仅了解论文表面内容通常是不够的,有时我们需要深入探讨,RAFT 也不例外。

RAFT 架构

首先,快速回顾一下 RAFT,其架构可以分解为三个基本模块,如下所示。

图 1. RAFT 的架构。修改自 来源

特征提取

特征编码器是一个卷积神经网络(CNN),利用共享权重从图像 I₁I₂ 中提取特征。上下文编码器提取上下文和隐藏特征,这些特征都被输入到迭代更新模块中。

特征网络 f 和上下文网络 g 的函数映射。来源:作者。

4D 相关体的计算。修改自 来源

视觉相似性

两个特征图的点积形成一个 4D 全对相关体,其中 的每个像素映射到 的所有像素,这些映射被称为 2D 响应图其中 g¹ 和 g² 分别是从 I₁ 和 I₂ 中提取的特征图张量。 在相关体的最后两个维度上执行平均池化,卷积核大小为:1、2、4、8。

图 2. 左:I₁ 中单个像素与 I₂ 中所有像素的关系。右:相关金字塔中各种相关体的 2D 响应图。来源

我们将这些相关体积堆叠成一个 5D 相关金字塔,每一层将 的精细像素与 的逐渐粗糙的像素特征相关联。这使我们能够捕获关于大规模和小规模像素位移的信息。查找操作符 从相关体积中提取相关特征。它获取 的一个精细特征像素及其对应的光流场,并计算其新的明显位置,这称为其 对应关系。然后,它在预定义半径 r 周围形成一个 2D 网格,并随后沿网格执行 亚像素 双线性重采样 以获得新的网格值。这个重采样的特征网格包含流(下降)方向信息,查找操作符对金字塔中每层的每个像素执行此操作。这些逐像素特征网格称为 相关特征,然后被重新塑形并输入到 迭代更新块

迭代更新

迭代更新块接受四个输入:上下文特征、相关特征、当前流估计和隐藏特征。流和相关特征被一起编码为运动特征,因为它们都描述了特征像素的相对运动。上下文特征保持不变,作为更新块的稳定参考。该块本身由一个 ConvGRU 组成,该网络计算一定数量的递归更新,随后由一个包含卷积层的流头将隐藏状态转换为原始输入分辨率的 1/8 的流估计。

图 3. RAFT 更新块用于大型架构,突出显示了不同的子块。蓝色—特征提取块,红色—递归更新块,绿色—流头。修改自 Source

更新操作符的功能类似于优化算法,这意味着它从初始流 f₀ 开始,并迭代计算新的流值 Δfₖ,这些值被添加到先前的流估计中:fₖ₊₁ = fₖ + Δfₖ,直到收敛到一个固定值:fₖ → f*。在执行迭代更新时,流估计和相关特征(不是相关金字塔)会持续被精炼。一旦迭代耗尽,流估计将从 1/8 上采样到原始分辨率。

凸上采样

凸上采样 通过将每个精细像素估计为其邻近的 3x3 粗像素网格的 凸组合。权重由一个神经网络参数化,该网络能够为每个精细像素学习最佳权重。下方显示了一个示例。

图 4. 上:双线性 VS 凸采样的比较。下:单个全分辨率像素(紫色)的凸采样示例。 来源

学习光流偏移量

重要的是要记住,RAFT 不一定估计光流,它估计的是从起始点的 光流偏移量,输出的是这些光流偏移量的累积。第一次估计是对 I₁ 像素位置上 0 的先前光流的更新。I₁ 的信息来自初始隐藏状态和上下文特征,这为更新块提供了持续的反馈,指导学习,使 RAFT 能够估计光流偏移量,而不是重新估计光流。

RAFT 不直接估计光流,它的更新块从起始点估计光流偏移量,模型输出的是光流偏移量的累积。

在训练过程中,网络的递归更新模拟了优化算法的步骤,其中每个新的光流估计 fₖ₊₁ = fₖ + Δfₖ 被目标函数越来越严格地审视,迫使网络随着迭代次数的增加学习更保守的 Δfₖ 估计。目标函数捕捉所有光流更新,是光流预测与真实值之间的加权 l1 距离的总和,权重呈指数增长。

RAFT 的损失,γ = 0.8。 来源

这种优化算法的模仿以及查找操作符的半径协同作用,限制了每次更新的搜索空间,从而减少了过拟合的风险,加快了训练速度,并且提高了泛化能力。

解读 RAFT

现在我们可以开始解读 RAFT,并深入了解它如何进行预测。这个教程的代码位于 GitHub,在这里 RAFT 代码已被修改,以在每个主要块处输出其中间特征图。测试图像是逆时针旋转的吊扇,这会产生几乎所有方向的光流。为了获得较大的光流位移,我们将跳过序列中的几个图像,这将使光流特征更加明显,便于研究。

图 5. 预测的光流。来源:作者。

现在让我们检查来自特征编码器块的图。

图 6. 来自特征编码器块的特征图(0–127)。来源:作者。

特征编码器生成了fmap 1fmap 2,它们看起来像噪声,但实际上包含了对流动估计至关重要的信息。隐藏的特征图类似于输入图像I₁,它们直接初始化了 ConvGRU 更新块的隐藏状态,提供了关于I₁中像素的宝贵信息。上下文特征图与输入图像有些相似,突出显示了诸如角落和边缘等强特征。原始论文建议,上下文特征提高了网络准确识别空间运动边界的能力。

相关性

相关金字塔对于计算准确的光流至关重要,因为它能够捕捉从I₁I₂的像素对应关系。正如我们将看到的,相关体积是 RAFT 的核心,我们将可视化其出色的像素位移捕捉能力。我们将通过检查几个测试像素来接近这一点,并查看它们的二维响应图如何捕捉相对位移。我们只能看到特征,但像素位移的信息仍然会显现。

相关金字塔捕捉了图像序列中像素在多个分辨率级别上的对应关系。

下图展示了一个测试图像的估计流动,并标注了一些测试像素。

图 7. 带注释测试像素的流动图像。来源:作者。

每个像素的水平和垂直流场分量的估计值是:

  • 像素 0: (-49.4, -4.3)

  • 像素 1: (-5.8, -26.4)

  • 像素 2: (23.5, -9.3)

访问相关特征

要访问给定测试像素的相关图,我们将以下函数添加到 corr.py 脚本中,以获得任何金字塔级别下给定测试像素的整数索引。

def get_corr_idx(loc, lvl, w=71, h=40):
    """ Obtains index of test pixel location in correlation volume.
        loc - test pixel location 
        lvl - Pyramid level
        w - 1/8 of padded horizontal image width
        h - 1/8 of padded vertical image height
        """
    u = np.clip(np.round(loc[2]/(lvl*8)), 0, (w-1)) 
    v = np.clip(np.round(loc[1]/(lvl*8)), 0, (h-1))
    return int(u + w*v)

一旦我们获得了测试像素索引,我们可以获取其二维响应图、重新采样的相关特征和每个金字塔级别的对应关系。

test_pixel_idx = get_corr_idx(test_pixel, lvl=(2**i))

# get the 2D response map
corr_response = corr[test_pixel_idx, 0, :, :].detach().cpu().numpy()

# get the resampled correlation feature grids
resampled_corr = bilinear_sampler(corr, coords_lvl)
resampled_corr_response = resampled_corr[test_pixel_idx, 0, :, :].detach().cpu().numpy()

# get the correspondence
pixel_loc = centroid_lvl[test_pixel_idx, :, :, :].cpu().numpy().squeeze()

可视化相关特征

对于每个像素,图 8-10 中展示了前 15 次更新的第一个金字塔级别的二维响应图 GIF。对应关系由红色方框标记。大的相关值(亮点)指示了I₂中的相对像素位置。注意,随着网络学习流动偏移,对应关系如何围绕高相关值收敛。尽管这些是大位移,但在第一个金字塔级别上的所有对相关性仍能捕捉到它们。

图 8. 金字塔级别 1,像素 0 的二维相关响应图。来源:作者。

图 9. 金字塔级别 1,像素 1 的二维相关响应图。来源:作者。

图 10. 金字塔级别 1,像素 2 的二维相关响应图。来源:作者。

相关金字塔能够捕捉所有层级的相关性,但这并不总是显而易见的。随着金字塔层级的提升,事物开始变得更加抽象,确定 RAFT 实际在做什么变得越来越困难,此外,我们看到的是提取特征的相关性,使得事情更加模糊。

检索下降方向

匹配点通过查找操作符来提出下降方向。相关查找操作符在新的匹配点位置周围放置一个半径为r的网格:x’ = (u + f¹(u) + v + f²(v)),其中(, )是当前的流场估计。围绕 x’的网格用于从相关体中进行双线性重采样。这些重采样网格是输入到更新操作符中的相关特征,以预测下一个流估计。下图显示了第一个金字塔层级 2D 相关响应在像素 1 的情况以及其对应的双线性重采样网格;上排是第一次迭代,流初始化为零;下排是第二次迭代。

图 11. 像素 1 放大后的相关响应和双线性重采样网格。上排:迭代 0,下排:迭代 1。匹配点在(行,列)。来源:作者。

请注意上排中,左侧的相关响应与右侧的重采样网格相同,这是由于零流初始化造成的。我们还注意到关于右上角的双线性重采样网格的一个非常重要的点,最大值直接位于中心的左侧。如果我们向右移动三像素并向上移动一像素,即(3, -1),那么我们就会落在这个大值上。这是从相关体中检索出的建议下降方向。在迭代更新块中,网络利用这些信息来制定实际的下降方向Δfₖ

在下排中,我们可以看到匹配点大致从(39, 34)移动到(41.93, 33.3),这是一个(2.93, -0.7)的位移,显示网络确实利用了建议的下降方向。在右下角的重采样网格中,我们看到最大值位于中心并与匹配点对齐,这表明网络已经有一个接近收敛的流预测。

运动特征

运动特征是相关特征和当前流估计的卷积编码。它们提供了需要由更新块细化的像素流信息。以下展示了每次迭代的一些运动特征图。

图 12. 不同迭代下的运动特征图。是的,我挑选了有趣的特征。来源:作者。

看起来,运动特征与具有大幅度移动的像素对应,不同的特征图似乎对应于不同的像素流,这在图 12 右侧的特征 126 和 127 中表现得尤为明显。它们都以类似的方式收敛到实际流动预测。

结论

在这篇文章中,我们学习了 RAFT 及其内部工作原理。我们看到提取的隐藏特征提供了关于 I₁ 的有用信息,而提取的上下文特征则提供了关于 I₁ 强特征的参考信息。我们已经可视化了相关体如何捕捉到小范围和大范围像素位移的信息。事实证明,RAFT 中的相关概念被许多后续工作所采用,从可视化中建立的直觉加强了这种模式。如果你已经读到这里,恭喜你!你现在对 RAFT 有了比表面层次更深入的理解。

参考文献

[1] Teed, Z., & Deng, J. (2020). Raft: Recurrent all-pairs field transforms for optical flow. 计算机视觉 — ECCV 2020, 402–419. doi.org/10.1007/978-3-030-58536-5_24

优化需求满足:行业方法

原文:towardsdatascience.com/optimal-demand-fulfillment-an-industryapproach-58746615d91e?source=collection_archive---------15-----------------------#2023-02-06

Saif Ali KherajTowards Data Science Saif Ali Kheraj

·

关注 发表在 Towards Data Science · 8 分钟阅读 · 2023 年 2 月 6 日

--

在电信领域,数据科学有许多应用,涵盖从站点规划到预算和管理会计决策、客户生命周期管理以及市场营销。在这篇文章中,我们将讨论一个这样的用例,即网络站点规划。我们将从两个场景开始,以便您可以理解它们。

场景 1: 最近,电信决策者发现由于未规划的站点规划导致了巨大的开销。网络站点当然有各种成本,包括固定建设成本、可变成本和其他可变开销。公司的管理者发现了导致低效的各种差异。现在,管理者想要重新规划一切,关闭不必要的站点,同时确保剩余的站点完全覆盖所有需求/区域。这就是一个用例。目标是减少基站的总数量,同时仍然覆盖所有区域。

场景 2: 另一个用例是当管理者想要限制用于覆盖最大区域数量的站点数量,这可能是由于财务不足或某个季度或年份的预算限制。因此,在这种情况下,我们希望在选择不超过 p 个基站(上限)的情况下最大化我们的覆盖区域。在这种情况下,可能无法覆盖所有区域。

概述

让我们看看如何使用优化技术来轻松解决这两个用例。需求并不总是指用户或区域的数量,但可以从整体上考虑多个因素,这超出了本文的范围。

场景 1

让我们从第一个用例开始,即在满足所有需求的情况下,尽可能少地选择基站。我们将使用一个简单的例子来演示这些概念;然而,在现实世界中,你可能会有不同类型的数据集,包含成千上万的基站,并且每个基站的邻域或覆盖范围可以通过一定的距离阈值来检查。

让我们从一个例子开始。我们有四个基站,区域被分为四个区域。大面积通常会在地图上分成不同的网格。我们不会深入讨论这一点。例如,基站 #1 可以覆盖区域 1 和 2,而基站 #2 可以覆盖区域 3。基站 #2 可以覆盖区域 1、2 和 3。基站 #3 可以覆盖区域 3 和 4。基站 #4 可以覆盖区域 2 和 4。下图展示了这一点。

图 1:区域与网络塔的连接性(作者图)

在这种情况下,我们想要找到满足所有要求的最少数量的基站。例如,我们可能不需要基站 #4,因为需求已经由基站 #1、#2 和 #3 满足。因此,我们可能要排除基站 #4。这正是我们想要完成的:找到最佳的基站组合,以满足所有需求,同时消除不必要的基站。

方程

让我们用 J 表示位置/基站,用 I 表示需求节点/区域。我们还定义 aij,如果需求节点 i 被位置节点 j 覆盖,则 aij 等于 1。aij 可以使用某个距离阈值来定义。

表 1:场景 1 的样本问题表述和变量符号(作者绘制)

每个基站位置由一个二进制变量 xj 表示,该变量指示该位置是否被选中。如果基站被选中,则 xj=1,否则 xj=0。

目标函数是通过最小化选定基站的总和来进行优化。

方程 1:目标函数——场景 1 中最小化选定基站的总和(作者绘制)

约束条件是确保每个区域至少被一个选定的基站覆盖。因此,我们希望每个需求节点 i(在我们案例中是区域)至少由一个位置 j(在我们案例中是基站)服务。

方程 2:约束条件——场景 1 中每个区域必须由至少一个基站/位置覆盖(作者绘制)

因此,对于每个区域 i,我们查看所有基站,看看这些基站是否覆盖区域 i。每个区域应该至少由一个基站覆盖。请注意,在这个特定模型中,区域可以由多个基站服务。

通过解决这个优化问题,电信公司可以在最小化所需基站数量的同时确定最佳的基站位置,以覆盖整个国家/城市/区域。

Code

from pyomo.environ import *

model = ConcreteModel()

# Define parameters regions and cell sites
model.I = Set(initialize=[1,2,3,4]) # regions/demands
model.J = Set(initialize=[1,2,3,4]) # cell sites
model.a = Param(model.I, model.J, initialize={(1,1):1, (1,2):1, (1,3):0, (1,4):0, (2,1):1,(2,2):1,(2,3):1, (2,4):0, (3,1):0,(3,2):0,(3,3):1,(3,4):1,(4,1):0,(4,2):1,(4,3):0, (4,4):1})

# Define variables
model.x = Var(model.J, within=Binary)

# Define objective function
def objective(model):
    return sum(model.x[j] for j in model.J)

model.obj = Objective(rule=objective, sense=minimize)

# Define constraints (sumproduct), for each Region i, we look at all cell sites, and see whether this cell site covers region i or not. sum must be greator than equal to 1
def demand_rule(model, i):
    return sum(model.a[i,j]*model.x[j] for j in model.J) >= 1

model.demand_constraint = Constraint(model.I, rule=demand_rule)

# Solve the model
solver = SolverFactory('glpk')
results = solver.solve(model)

# Print the solution
print("Optimal solution:")
for j in model.J:
    if model.x[j].value == 1:
        print(f"Location {j} is open")
print(f"Number of open locations: {model.obj()}")

图 1:场景 1 的结果(作者绘制)

正如我们所见,我们可以用两个基站覆盖所有需求点/区域,你可以使用上面的图表自行验证结果。

场景 2

现在我们进入第二个用例,讨论如何对可用的站点位置设置上限,以满足最大需求。考虑到可用的基站数量有限,我们希望最大化覆盖范围。例如,公司可能决定设置一个最大为两个基站的上限,并关闭其他基站以节省费用。现在的问题是我们能用两个基站覆盖多少个区域。当然,不同的组合是可能的;例如,如果我们只使用基站 1 和 2,并关闭基站 3 和 4,那么我们实际上会在区域 3 失去覆盖。也许有某些两个基站的组合可以覆盖整个区域,这将通过我们的优化方程来解决。可能存在不止一个可行的解决方案。

方程

表 2:场景 2 的样本问题表述和变量符号(作者绘制)

在这种情况下,我们希望基本上最大化我们的覆盖范围,我们用 y 变量定义覆盖的概念。在这种情况下, yi 表示我们的区域是否会被选中。如果被选中,则 yi 将为 1,否则为 0。我们基本上希望最大化这个总和,也就是说,我们希望每个区域都被覆盖。这可以通过以下方程表示。

方程 3:目标函数——最大化场景 1 的覆盖/区域(图由作者提供)

然而,我们有一些限制条件,因此可能无法覆盖所有区域。让我们从第一个限制条件开始,即覆盖区域的基站数量必须大于等于 yi。例如,如果某个区域被选中,则 yi=1,那么我们基本上希望覆盖该区域的基站数量大于等于 yi。

方程 4:限制条件——每个选定的区域(yi)需要至少由 1 个基站/位置覆盖,用于场景 2(图由作者提供)

方程与第一部分相同,只是将 1 替换为 yi。因此,对于每个区域 I,我们查看它是否被基站覆盖,并将其加起来以确保它大于等于 yi。

除此之外,最重要的限制条件是公司决定将我们限制在最多两个基站。我们基本上希望基站总数小于等于 p。(在我们的情况下是 2)。

方程 5:限制条件——场景 2 的 p 个基站上限(图由作者提供)

现在,让我们将所有方程组合在一起,再次写出完整的内容。

方程 6:结合上述所有内容的完整方程,用于场景 2(图由作者提供)

代码

from pyomo.environ import *

model = ConcreteModel()
# Define parameters regions and cell sites
model.I = Set(initialize=[1,2,3,4]) # regions/demands
model.J = Set(initialize=[1,2,3,4]) # cell sites
model.a = Param(model.I, model.J, initialize={(1,1):1, (1,2):1, (1,3):0, (1,4):0, (2,1):1,(2,2):1,(2,3):1, (2,4):0, (3,1):0,(3,2):0,(3,3):1,(3,4):1,(4,1):0,(4,2):1,(4,3):0, (4,4):1})

model.p = Param(initialize=1)
model.x = Var(model.J, within=Binary)
model.y = Var(model.I, within=Binary)

def objective(model):
    return sum(model.y[i] for i in model.I)

model.obj = Objective(rule=objective, sense=maximize)

##constraint
def coverage_rule(model, i):
    return sum(model.a[i,j]*model.x[j] for j in model.J) >= model.y[i]

model.coverage_constraint = Constraint(model.I, rule=coverage_rule)

def cellsitelocation_rule(model):
    return sum(model.x[j] for j in model.J) <= model.p

model.cellsite_constraint = Constraint(rule=cellsitelocation_rule)

opt = SolverFactory("glpk")
results = opt.solve(model)

print("Optimal solution:")
for j in model.J:
    if model.x[j].value == 1:
        print(f"Cellsite {j} is on")

print("Optimal solution:")
for i in model.I:
    if model.y[i].value == 1:
        print(f"Region {i} is covered")
print(f"Number of regions covered: {model.obj()}")

图 2:场景 2 的结果(图由作者提供)

我们仍然可以用有限数量的基站覆盖所有需求/区域。你现在可以尝试 p=1 来查看可能会遗漏哪些区域。重要的是要注意,使用上限时,你可能会遗漏一些区域,但仍能最大化覆盖范围。

结论

这些是我们在本文中探讨的非常简单的问题。然而,当涉及到这样的问题时,这可能是一个非常强大的工具。我们可以有一个更高级的方程,还带有加权部分,现在由你进一步思考。对于这些类型的问题还有许多其他变体,例如最小化从位置到需求点的运输成本。这些类型的问题有许多有趣的变体。

参考文献

[1] optimization.cbe.cornell.edu/index.php?title=Set_covering_problem

[2] www.im.ntu.edu.tw/~lckung/courses/OR15/slides/OR-Sp15_09_IPapplication.pdf

[3] www.pyomo.org/documentation

深度学习中的神经网络优化

原文:towardsdatascience.com/optimisation-algorithms-neural-networks-101-256e16a88412

如何超越“普通”梯度下降算法改进训练

Egor HowellTowards Data Science Egor Howell

·发表于 Towards Data Science ·8 分钟阅读·2023 年 11 月 24 日

--

www.flaticon.com/free-icons/neural-network.神经网络图标。神经网络图标由 andinur 创作 — Flaticon.

背景

在我上一篇文章中,我们讨论了如何通过超参数调整来提高神经网络的性能:

## 超参数调整:神经网络基础

如何通过调整超参数来改善神经网络的“学习”和“训练”

towardsdatascience.com

这是一个过程,通过对学习率和隐藏层数量等最佳超参数进行“调整”,以找到对我们的网络性能最优的参数。

不幸的是,对于大型深度神经网络(深度学习),这种调整过程极其缓慢。改进的一种方法是使用比传统“普通”梯度下降方法更快速的优化器。在这篇文章中,我们将深入探讨最流行的优化器和梯度下降的变体,这些可以提升训练速度以及收敛性,并在 PyTorch 中进行比较!

回顾:梯度下降

在深入之前,让我们快速复习一下梯度下降及其背后的理论。

梯度下降的目标是通过减去参数关于损失函数的梯度(偏导数)来更新模型的参数。学习率 α 用于调节此过程,以确保参数的更新在合理范围内进行,避免过度或不足地达到最优值。

梯度下降。方程由作者提供。

  • θ 是模型的参数。

  • J(θ) 是损失函数。

  • ∇J(θ) 是损失函数的梯度。 是梯度算子,也称为 nabla

  • α 是学习率。

我曾写过一篇关于梯度下降及其工作原理的文章,如果你想对它有更多了解,可以参考一下:

## 线性回归:梯度下降与解析解

解释了为什么梯度下降在数据科学中经常使用,并提供了 C 语言的实现

towardsdatascience.com

动量

梯度下降通常被可视化为一个球体在山坡上滚动。当它到达山谷底部时,它已经收敛到最小值,即最优值。一个持续向下滚动的球体会获得一定的 动量,然而,普通梯度下降是在每次迭代基础上进行的,并不了解之前的更新。

通过在梯度下降更新中包含动量,它为算法提供了关于先前计算的梯度的信息。

从数学上讲,我们得到的是:

动量梯度下降。方程由作者提供。

其中:

  • v_t 是当前的速度。

  • β 是动量系数,值在 0 和 1 之间。这有时也被解释为“摩擦力”。你需要找到最佳的 β 值,但通常 0.9 是一个不错的基准

  • t 是当前的时间步长或迭代次数。

  • v_{t−1}​* 是前一步的速度(上一次计算的值)。

其余术语与之前对普通梯度下降的定义相同!

注意,我们利用了先前梯度的信息来‘加速’当前梯度的方向。这提高了收敛速度,并减少了普通梯度下降中可能出现的任何振荡。

动量在 PyTorch 中也很容易实现。

optimizer = torch.optim.SGD([theta], lr=learning_rate, momentum=momentum)

Nesterov 加速梯度

Nesterov 加速梯度 (NAG),或称为 Nesterov 动量,是对动量算法的轻微修改,通常能导致更好的收敛效果。

NAG 测量相对于损失函数的梯度时略微超前于 θ. 这改善了收敛性,因为 动量 值通常会朝着最优点前进。因此,每次允许算法略微前进一步可以使其更快收敛。

Nesterov 加速梯度下降。方程由作者编写。

其中 ∇J(θ+βv_{t−1}​) 是在当前 θ 前稍微一点的损失函数的梯度。

上述方程中的所有术语与之前的优化器相同,因此我不会再次列出所有术语!

Nesterov 加速梯度也可以在 PyTorch 中轻松实现

optimizer = torch.optim.SGD([theta], lr=learning_rate, momentum=momentum, nesterov=True)

AdaGrad

自适应梯度算法(Adagrad) 是一种梯度下降算法,它使用自适应学习率,如果特征/参数更新得更频繁,学习率会变得更小。换句话说,它对更陡峭的梯度比对较浅的梯度衰减更多。

Adagrad。方程由作者编写。

这里:

  • G​ 是一个对角矩阵,积累了每个参数在时间步长内所有梯度的平方。

  • ϵ 是一个小的平滑项,用于避免当 G 非常小时的除零问题。

  • 表示逐元素乘法。这是 Hadamard 乘积

上述方程中的其余术语与之前的优化器相同,因此我不会再次列出所有术语!

元素级矩阵乘法的一个例子,假设 AB 都是 2x2

Hadamard 乘积的一个例子。方程由作者用 LaTeX 编写。

正如我们所看到的,G 的值越大,对参数的更新就越小。它基本上是平方梯度的移动平均。这确保了学习过程变慢,不会超过最优点。

Adagrad 的一个问题是,它有时会使学习率衰减得太多,导致神经网络过早停止学习并陷入停滞。因此,通常不推荐在训练神经网络时使用 Adagrad。

optimizer = torch.optim.Adagrad([theta], lr=learning_rate)

RMSProp

RMSProp(均方根传播) 通过只考虑最近的梯度来解决 Adagrad 过早结束训练的问题。它通过引入另一个超参数 β 来做到这一点,从而缩小对对角矩阵 G 内部值的影响:

RMSProp。方程由作者用 LaTeX 编写。

上述方程中的所有项都与之前优化器的相同,所以我不会再次列出它们!

像其他优化器一样,RMSProp 在 PyTorch 中实现起来很简单:

optimizer = torch.optim.RMSprop([theta], lr=learning_rate, alpha=alpha, eps=eps)

Adam

我们将要看的最终优化器是自适应矩估计,更为人知的是Adam。该算法结合了动量和 RMSProp,因此可以说是两者的最佳结合。不过,它有几个额外的步骤:

Adam 优化器。公式由作者以 LaTeX 呈现。

前两个步骤和最后一步几乎是我们之前展示的动量和 RMSProp 算法。第三步和第四步是修正v_tG_t的偏差,因为它们在开始时被初始化为 0。

Adam 是自适应学习率算法,类似于 RMSProp,因此使用此优化器时不一定需要调节学习率。

上述方程中的其他项与之前优化器相同,所以我不会再次列出它们!

以下是如何在 PyTorch 中应用 Adam:

optimizer = torch.optim.Adam([theta], lr=learning_rate)

其他优化器

这里有许多其他梯度下降优化器,我们考虑的只是一阶导数,这些被称为雅可比矩阵。还有一种二阶导数,称为赫西矩阵, 但其计算复杂度为,而一阶导数的计算复杂度仅为O

实际上,深度神经网络有数万到数百万行数据,因此赫西梯度下降方法很少使用。大多数情况下,金标准确实是 Adam 或 Nestorov。

还有批量、小批量和随机梯度下降,这些会影响网络的计算速度。我在之前的文章中写过这些算法。

其他一些常用优化器包括:

完整的综合列表可以在这里找到。

性能比较

下面的代码是对我们之前讨论的不同优化器的比较,针对J(θ) = θ²损失函数。最小值在θ = 0:

import torch
import plotly.graph_objects as go

# Function to perform optimisation and log theta
def optimize(optimizer_class, theta_init, lr, iterations, **kwargs):
    theta_values = []
    theta = torch.tensor([theta_init], requires_grad=True)
    optimizer = optimizer_class([theta], lr=lr, **kwargs)
    for _ in range(iterations):
        optimizer.zero_grad()
        loss = theta.pow(2)
        loss.backward()
        optimizer.step()
        theta_values.append(theta.item())
    return theta_values

# Initial values
theta_init = 3.0
learning_rate = 0.01
iterations = 1000

# Optimiser configurations
optim_configs = {
    "Momentum": {"optimizer_class": torch.optim.SGD, "lr": learning_rate, "momentum": 0.9},
    "Nesterov": {"optimizer_class": torch.optim.SGD, "lr": learning_rate, "momentum": 0.9, "nesterov": True},
    "Adagrad": {"optimizer_class": torch.optim.Adagrad, "lr": learning_rate},
    "Adam": {"optimizer_class": torch.optim.Adam, "lr": learning_rate},
    "RMSprop": {"optimizer_class": torch.optim.RMSprop, "lr": learning_rate}
}

# Run optimization for each optimizer and collect theta values
results = {}
for name, config in optim_configs.items():
    results[name] = optimize(**config, theta_init=theta_init, iterations=iterations)

# Plot the results
fig = go.Figure()

for optimiser, theta_values in results.items():
    fig.add_trace(go.Scatter(x=list(range(iterations)), y=theta_values, mode='lines', name=optimiser))

fig.update_layout(title="Optimiser Performance Comparison",
                  xaxis_title="Iteration Number",
                  yaxis_title="Value of Theta",
                  legend_title="Optimisers",
                  template="plotly_white",
                  width=900,
                  height=600,
                  font=dict(size=18),
                  xaxis=dict(tickfont=dict(size=16)),
                  yaxis=dict(tickfont=dict(size=16)),
                  title_font_size=24)

fig.show()

优化器比较。由作者使用 Python 绘制的图。

这个图很有趣,有几个关键点要指出:

  • 动量和 Nestorov 都超出了 θ. 的最优值。

  • Adagrad 非常慢。这与我们之前讨论的情况一致,即学习率迅速衰减,导致训练过早停止和学习停滞。

  • Adam 和 RMSProp 似乎是最好的,其中 RMSProp 更快地达到最优值。

当然,这只是一个简单的示例,在实际问题中,最佳的优化器可能会有所不同。因此,尝试各种不同的优化器并选择表现最佳的往往是非常值得的。

这段代码可以在我的 GitHub 上找到:

[## Medium-Articles/Neural Networks/optimisers.py at main · egorhowell/Medium-Articles

在我的中等博客/文章中使用的代码。通过创建帐户来贡献开发…

github.com](https://github.com/egorhowell/Medium-Articles/blob/main/Neural Networks/optimisers.py?source=post_page-----256e16a88412--------------------------------)

摘要与进一步的思考

在这篇文章中,我们看到了几种加速和提高普通梯度下降性能的方法。这两种方法类型是基于动量的,使用来自先前梯度的信息,以及基于自适应的,依据计算出的梯度调整学习率。在文献中,Adam 优化器通常是最推荐和最常用于研究的优化器。然而,尝试不同的优化器总是值得的,以确定哪种最适合你的模型。

另一个话题!

我有一个免费的新闻通讯,分析数据,我每周分享成为更好的数据科学家的技巧。没有“虚 fluff”或“点击诱饵”,只有来自实践数据科学家的纯粹可操作的见解。

[## 分析数据 | Egor Howell | Substack

如何成为更好的数据科学家。点击阅读《分析数据》,由 Egor Howell 编写,Substack 出版…

新闻通讯

与我联系!

参考文献与进一步阅读

这是我关于神经网络的其他一些可能感兴趣的博客:

## 激活函数与非线性:神经网络 101

解释神经网络为何能学习(几乎)任何事物和一切

[towardsdatascience.com ## 前向传播与反向传播:神经网络基础

通过手动和代码(使用 PyTorch)解释神经网络如何“训练”和“学习”数据中的模式

[towardsdatascience.com

优化:Python 中的容量限制设施选址问题

原文:towardsdatascience.com/optimization-capacitated-facility-location-problem-in-python-57c08f259fe0

查找最佳的仓库数量和位置以降低成本并满足需求

Nicolo Cosimo AlbaneseTowards Data Science Nicolo Cosimo Albanese

·发表于 Towards Data Science ·阅读时间 12 分钟·2023 年 2 月 28 日

--

图片由作者提供。

目录

  1. 简介

  2. 问题陈述

  3. 实现

    3.1. 数据集

    3.2. 客户、仓库和需求

    3.3. 供应和固定成本

    3.4. 运输成本

    3.5. 优化

  4. 探索结果

  5. 结论

1. 引言

设施选址问题(FLPs) 是经典的优化任务。它们的目标是确定仓库或工厂的最佳潜在位置。

仓库可能有或没有容量限制。这将 有容量(CFLP)无容量(UFLP) 问题变体区分开来。

业务目标是找到一组能够最小化成本的仓库位置。原始问题定义由 Balinski (1965) 提出了两个(年度)成本因素之和的最小化:

  • 运输成本

  • 仓库固定成本

运输成本指的是从仓库位置到客户的费用。仓库固定成本是特定于位置的。它可能包括如租金、税费、电费和维护等费用。

设施选址 是一个众所周知的主题,具有相当丰富的文献。因此,存在许多问题变体以及方法。这篇文章介绍了经典的 CFLP 公式,并分享了一个使用 PuLP 的实际 Python 示例。

2. 问题陈述

CFLP 的目标是确定能够满足客户需求的仓库数量和位置,同时降低固定和运输成本。因此,我们可以将问题表述为以下目标函数的最小化:

在前一个表达式中:

  • N是客户地点的集合。

  • M是候选仓库地点的集合。

  • fⱼ表示仓库j的年固定成本。

  • tᵢⱼ表示从仓库j到客户i的运输成本。

  • xᵢⱼ是从仓库j到客户i的单位数。

  • yⱼ是一个二进制变量yⱼ ∈ {0,1},表示是否在位置j建立仓库(yⱼ = 1)或不建立(yⱼ = 0)。

现在让我们考虑将约束添加到目标函数中。

由于我们正在建模一个有容量限制的问题,每个设施j可以供应的年最大容量为Cⱼ。因此,交付给客户的单位数xᵢⱼ不能超过此值:

从仓库j到客户i的年交付单位数必须在零到dᵢ之间,其中dᵢ是客户i的年需求:

最后但同样重要的是,我们必须满足客户的需求。在这个示例中,我们规定每个服务客户地点的仓库必须完全满足其需求:

总之,我们可以如下定义问题:

图片由作者提供。

3. 实施

让我们导入所需的库:

  • NumPyPandas用于数据处理。

  • math用于特定的数学函数。

  • GeoPandas用于地理空间表示。

  • Matplotlib用于数据可视化。

  • PuLP用于优化。

import numpy as np
import pandas as pd
import geopandas
from math import sin, cos, asin, acos, radians

from pulp import LpProblem, LpMinimize, LpVariable, LpBinary, lpSum, LpStatus, value

import matplotlib.pyplot as plt
plt.style.use('ggplot')

3.1. 数据集

我们在意大利设定优化问题。

起始数据集可以在 simplemaps.com 上获得。我们可以从这里下载输入 csv 文件,并在MIT 许可证个人和商业用途均可自由使用

# Load dataframe
df = pd.read_csv(
    './it.csv', 
    usecols = ['city', 'lat', 'lng', 'population', 'capital', 'admin_name'])

我们感兴趣的是以下列:

  • city: 城镇名称;

  • lat: 纬度;

  • lng: 经度;

  • population: 居民数量;

  • capital: 表示城市是否是主要城市或行政中心;

  • admin_name: 最高级别行政区域的名称。

3.2. 客户、仓库和需求

在创建客户、设施和需求时,我们假设:

  • 客户是输入城市的一部分(30%)。

  • 设施只能在行政中心建立。作为起始条件,我们假设我们可以在意大利 80%的主要城市建立仓库。

  • 需求是恒定且全年已知的。它等于客户城镇人口的一个部分(2%)加上一个误差项。

RANDOM_STATE = 2          # For reproducibility
FRACTION_CUSTOMERS = 0.3  # Fraction of cities we want to keep as customers
FRACTION_WAREHOUSES = 0.8 # Fraction of cities we want to keep as warehouse locations
FRACTION_DEMAND = 0.02    # Fraction of citizens of a city that may order a product  

# List of the 20 regions of Italy
REGION_LIST = [
    'Lombardy', 'Veneto', 'Emilia-Romagna', 'Sicilia', 'Campania', 'Piedmont', 'Puglia', 
    'Lazio', 'Calabria', 'Tuscany', 'Sardegna', 'Marche', 'Friuli-Venezia Giulia', 'Abruzzo',
    'Umbria', 'Trentino-Alto Adige', 'Liguria', 'Basilicata', 'Molise', 'Valle d’Aosta']

# Demand is composed of: 
#   1\. A fraction of the population
#   2\. An error term of uniform distribution
# Note: demand is approximated to the closest int 
# as its physical meaning denies decimals
df['demand'] = np.floor(
    FRACTION_DEMAND * df.population + np.random.uniform(-10, 10, size=(df.shape[0],)))

# Create the warehouses dataframe:
#   1\. Filter the 20 regions of Italy
#   2\. Filter capitals as candidate warehouse locations
#   3\. Sample a fraction of the original cities
facility_df = df.\
                loc[df.admin_name.isin(REGION_LIST)].\
                loc[df.capital.isin(['admin', 'minor'])].\
                sample(frac=FRACTION_WAREHOUSES, random_state=RANDOM_STATE, ignore_index=True)

# Create the customers dataframe:
#   1\. Filter the 20 regions of Italy
#   2\. Sample a fraction of the original cities
customer_df = df.\
                loc[df.admin_name.isin(REGION_LIST)].\
                sample(frac=FRACTION_CUSTOMERS, random_state=RANDOM_STATE, ignore_index=True)

# Customers IDs list
customer_df['customer_id'] = range(1, 1 + customer_df.shape[0])

注意:在在线数据集中,区域名称Valle d'Aosta包含的是排版(弯曲)撇号(U+2019),而不是打字机(直线)撇号(U+0027)。如果复制此代码,请考虑到这一点。

尽管这对于优化任务不是必需的,但我们可能希望在地图上观察我们的地点。geopandas简化了这一任务。可以使用points_from_xy方法轻松创建一个充满地理空间信息的GeoDataFrame

def add_geocoordinates(df, lat='lat', lng='lng'):
    '''
    Add column "geometry" with <shapely.geometry.point.Point> objects 
        built from latitude and longitude values in the input dataframe

    Args:
        - df: input dataframe
        - lat: name of the column containing the latitude (default: lat)
        - lng: name of the column containing the longitude (default: lng)
    Out:
        - df: same dataframe enriched with a geo-coordinate column
    '''
    assert pd.Series([lat, lng]).isin(df.columns).all(),\
        f'Cannot find columns "{lat}" and/or "{lng}" in the input dataframe.'
    return geopandas.GeoDataFrame(
        df, geometry=geopandas.points_from_xy(df.lng, df.lat))

customer_df = add_geocoordinates(customer_df)
facility_df = add_geocoordinates(facility_df)

我们可以通过geopandas访问意大利的地图,并绘制客户和潜在的仓库位置:

# Prepare country plot
world = geopandas.read_file(geopandas.datasets.get_path('naturalearth_lowres'))

# Extract and plot the shape of Italy
italy = world[world.name == 'Italy']
ax = italy.plot(color='white', edgecolor='black', figsize=(10, 10))

# Plot customers as points
customer_df.\
    plot(ax=ax, marker='X', color='red', markersize=30, alpha=0.5, label='Customer')

# Plot potential facility locations as points
facility_df.\
    plot(ax=ax, marker='D', color='blue', markersize=30, alpha=0.5, label='Potential warehouse')

# Add legend
plt.legend(facecolor='white', title='Location')

# Add title
plt.title('Customer and potential warehouses')

# Remove ticks from axis
plt.xticks([])
plt.yticks([])

# Show plot
plt.show()

图片由作者提供。

同样,我们可以观察到 20 个意大利地区的平均需求:

# Prepare region dataframe:
#   1\. Filter the 20 regions of Italy
#   2\. Group by region
#   3\. Calculate:
#      - Mean regional latitude
#      - Mean regional longitude
#      - Sum of regional demand
region_df = df.\
             loc[df.admin_name.isin(REGION_LIST)].\
             groupby(['admin_name']).\
             agg({'lat': 'mean', 'lng': 'mean', 'demand': 'sum'}).\
             reset_index()

# Add geo-coordinates
region_df = add_geocoordinates(region_df)

# Plot the shape of Italy
ax = italy.plot(color='white', edgecolor='black', figsize=(10, 10))

# Plot region area colored based on demand
region_df.\
    plot(ax=ax, column='demand', marker='o', c='demand', cmap='plasma', markersize=2500, alpha=0.6)

# Add region 'center' as red dots
region_df.\
    plot(ax=ax, marker='o', c='red', markersize=25, alpha=0.8, label='Customer location')

# Add region name above the center
for i, row in region_df.iterrows():
    plt.annotate(
        row.admin_name, xy=(row.lng, row.lat+0.2), horizontalalignment='center')

# Add color bar with demand scale
plt.colorbar(ax.get_children()[1], ax=ax, label='Annual Demand', fraction=0.04, pad=0.04) 

# Add title
plt.title('Annual demand by region')

# Remove ticks from axis
plt.xticks([])
plt.yticks([])

# Show plot
plt.show()

图片由作者提供。

为了方便以后使用PuLP,让我们将需求数据存储在customer-demand对的字典中:

# Dictionary of cutomer id (id) and demand (value)
demand_dict = { customer : customer_df['demand'][i] for i, customer in enumerate(customer_df['customer_id']) }

3.3. 供应和固定成本

为了建模供应和固定成本,我们假设:

  • 每个仓库可以满足的最大年度供应量等于平均区域需求的 3 倍。

  • 每个仓库都有一个固定的年成本为 100.000,00 €,与其位置无关。

与需求相同,我们将供应和固定成本存储在字典中:

# Assumptions: 
#    1\. Each warehouse has an annual cost of 100.000,00 euros: rent, electricity, ...
#    2\. Each warehouse can meet 3 times the regional average annual demand
COST_PER_WAREHOUSE = 100_000
SUPPLY_FACTOR_PER_WAREHOUSE = 3
SUPPLY_PER_WAREHOUSE = region_df.demand.mean() * SUPPLY_FACTOR_PER_WAREHOUSE

# Warehouses list
facility_df['warehouse_id'] = ['Warehouse ' + str(i) for i in range(1, 1 + facility_df.shape[0])]

# Dictionary of warehouse id (id) and max supply (value)
annual_supply_dict = { warehouse : SUPPLY_PER_WAREHOUSE for warehouse in facility_df['warehouse_id'] }

# Dictionary of warehouse id (id) and fixed costs (value)
annual_cost_dict = { warehouse : COST_PER_WAREHOUSE for warehouse in facility_df['warehouse_id'] }

3.4. 运输成本

运输成本的估算需要:

  • 不同位置之间的距离,以及

  • 每单位距离的成本函数。

我们可以使用哈弗森公式来近似两个位置之间的距离:

def haversine_distance(lat1, lon1, lat2, lon2):
    '''
    Calculate distance between two locations given latitude and longitude.

    Args:
       - lat1: latitude of the first location
       - lon1: longitude of the first location
       - lat2: latitude of the second location
       - lon2: longitude of the second location
    Out:
       - Distance in Km

    Ref: 
       - https://en.wikipedia.org/wiki/Haversine_formula
    '''
    return 6371.01 *\
            acos(sin(radians(lat1))*sin(radians(lat2)) +\
            cos(radians(lat1))*cos(radians(lat2))*cos(radians(lon1)-radians(lon2)))

让我们在两个城市上尝试一下:

  • 米兰(纬度:45.4654219,经度:9.18854)

  • 贝尔加莫(纬度:45.695000,经度:9.670000)

haversine_distance(45.4654219, 9.1859243, 45.695000, 9.670000)
45.508144765533906

我们得到的距离为 45.5 公里。不幸的是,这个度量与我们在汽车导航系统上看到的距离不一致,因为我们没有考虑路线:

米兰和贝尔加莫之间的哈弗森距离和路线距离。图片由作者提供。

尽管如此,我们可以使用我们的估算作为任务的合理近似。

最后,我们需要将距离转换为成本度量。目前在写作时,意大利的平均汽油价格为 1.87 €/L(来源)。一辆 EURO VI 卡车的平均耗油量约为 0.38 L/Km(来源)。通过一个简单但合理的近似,我们可以估算在意大利土地上每公里的平均成本为 0.71 €:

def traveling_cost(distance_in_km):
    '''
    Return traveling cost in euros given a distance in Km.

    Args:
      - distance_in_km: travel distance in Km
    Out:
      - cost of the trip in euros
    '''
    return 0.71 * distance_in_km

现在我们可以计算每个仓库-客户对的旅行成本,并将其存储在字典中:

# Dict to store the distances between all warehouses and customers
transport_costs_dict = {}

# For each warehouse location
for i in range(0, facility_df.shape[0]):

    # Dict to store the distances between the i-th warehouse and all customers
    warehouse_transport_costs_dict = {}

    # For each customer location
    for j in range(0, customer_df.shape[0]):

        # Distance in Km between warehouse i and customer j
        d = 0 if facility_df.city[i]==customer_df.city[j] else haversine_distance(
            facility_df.lat[i], facility_df.lng[i], customer_df.lat[j], customer_df.lng[j])

        # Update costs for warehouse i
        warehouse_transport_costs_dict.update({customer_df.customer_id[j]: traveling_cost(d)})

    # Final dictionary with all costs for all warehouses
    transport_costs_dict.update({facility_df.warehouse_id[i]: warehouse_transport_costs_dict})

3.5. 优化

让我们回顾一下优化问题:

图片由作者提供。

我们可以定义两个决策变量xᵢⱼyⱼ,目标函数和约束条件如下:

# Define linear problem
lp_problem = LpProblem('CFLP', LpMinimize)

# Variable: y_j (constraint: it is binary)
created_facility = LpVariable.dicts(
    'Create_facility', facility_df['warehouse_id'], 0, 1, LpBinary)

# Variable: x_ij
served_customer = LpVariable.dicts(
    'Link', [(i,j) for i in customer_df['customer_id'] for j in facility_df['warehouse_id']], 0)

# Objective function 
objective = lpSum(annual_cost_dict[j]*created_facility[j] for j in facility_df['warehouse_id']) +\
            lpSum(transport_costs_dict[j][i]*served_customer[(i,j)] \
                  for j in facility_df['warehouse_id'] for i in customer_df['customer_id'])

lp_problem += objective

# Costraint: the demand must be met
for i in customer_df['customer_id']:
    lp_problem += lpSum(served_customer[(i,j)] for j in facility_df['warehouse_id']) == demand_dict[i]

# Constraint: a warehouse cannot deliver more than its capacity limit
for j in facility_df['warehouse_id']:
    lp_problem += lpSum(served_customer[(i,j)] for i in customer_df['customer_id']) <= annual_supply_dict[j] * created_facility[j]

# Constraint: a warehouse cannot give a customer more than its demand
for i in customer_df['customer_id']:
    for j in facility_df['warehouse_id']:
        lp_problem += served_customer[(i,j)] <= demand_dict[i] * created_facility[j]

我们可以解决优化问题:

lp_problem.solve()

我们可以如下检查结果:

print('Solution: ', LpStatus[lp_problem.status])
Solution: Optimal

我们现在对探索决策变量感兴趣:我们需要多少个仓库?它们的位置在哪里?

4. 探索结果

首先,让我们考虑商业目标:最小化成本。我们可以检查目标函数的值:

value(lp_problem.objective)
8964323.323646087

这是在给定约束条件下我们可以实现的最低成本。任何其他仓库数量或位置的选择都会导致目标函数值更高。

我们可以通过varValue属性访问决策变量。例如,我们可以查看yⱼj = 仓库 1时的值:

created_facility['Warehouse 1'].varValue
1.0

由于yⱼ = 1,我们应在该位置建立一个仓库。我们可以轻松地操作变量并计算所需设施的数量:

# List of the values assumed by the binary variable created_facility
facility_values = [i.varValue for i in created_facility.values()]

# Count of each distinct value of the list
[[i, facility_values.count(i)] for i in set(facility_values)]
[[0.0, 59], [1.0, 32]]

只需建造最初预算的 91 个地点中的 32 个即可。35.1%(32 / 91)的潜在仓库足以满足在给定约束条件下的需求。

我们可以将决策变量保存到初始数据框中,并观察所选择的位置:

# Create dataframe column to store whether to build the warehouse or not 
facility_df['build_warehouse'] = ''

# Assign Yes/No to the dataframe column based on the optimization binary variable
for i in facility_df['warehouse_id']:
    if created_facility[i].varValue == 1:
        print('Build site at: ', i)
        facility_df.loc[facility_df['warehouse_id'] == i, 'build_warehouse'] = 'Yes'
    else:
        facility_df.loc[facility_df['warehouse_id'] == i, 'build_warehouse'] = 'No'
Build site at:  Warehouse 1
Build site at:  Warehouse 2
Build site at:  Warehouse 3
Build site at:  Warehouse 4
Build site at:  Warehouse 7
Build site at:  Warehouse 8
Build site at:  Warehouse 16
Build site at:  Warehouse 18
Build site at:  Warehouse 20
Build site at:  Warehouse 21
Build site at:  Warehouse 22
Build site at:  Warehouse 23
Build site at:  Warehouse 25
Build site at:  Warehouse 26
Build site at:  Warehouse 27
Build site at:  Warehouse 29
Build site at:  Warehouse 33
Build site at:  Warehouse 35
Build site at:  Warehouse 38
Build site at:  Warehouse 48
Build site at:  Warehouse 49
Build site at:  Warehouse 55
Build site at:  Warehouse 56
Build site at:  Warehouse 57
Build site at:  Warehouse 58
Build site at:  Warehouse 63
Build site at:  Warehouse 66
Build site at:  Warehouse 70
Build site at:  Warehouse 74
Build site at:  Warehouse 82
Build site at:  Warehouse 83
Build site at:  Warehouse 84
colors = ['#990000', '#0059b3']

facility_df.build_warehouse.value_counts().plot.barh(
  title='Warehouse sites to be established', xlabel='Number of sites', color=colors, ylabel='Establish', figsize=(7,6)) 

for i, v in enumerate(facility_df.build_warehouse.value_counts()):
    plt.text(v, i, ' '+str(round(v,3)), color=colors[i], va='center', fontweight='bold')

图片来源于作者。

# Plot the shape of Italy
ax = italy.plot(color='white', edgecolor='black', figsize=(10, 10))

# Plot sites to establish
facility_df.\
    loc[facility_df.build_warehouse =='Yes'].\
    plot(ax=ax, marker='o', c='#0059b3', markersize=50, label='Build')

# Plot sites to discard
facility_df.\
    loc[facility_df.build_warehouse =='No'].\
    plot(ax=ax, marker='X', c='#990000', markersize=40, label='Discard')

# Add title
plt.title('Optimized Warehouse Sites')

# Add legend
plt.legend(title='Warehouse Site', facecolor='white')

# Remove ticks from axis
plt.xticks([])
plt.yticks([])

# Show plot
plt.show()

图片来源于作者。

同样,我们可以遍历决策变量xᵢⱼ,找到优化解中每个仓库服务的客户:

def get_linked_customers(input_warehouse):
    '''
    Find customer ids that are served by the input warehouse.

    Args:
        - input_warehouse: string (example: <Warehouse 21>)
    Out:
        - List of customers ids connected to the warehouse
    '''
    # Initialize empty list
    linked_customers = []

    # Iterate through the xij decision variable
    for (k, v) in served_customer.items():

            # Filter the input warehouse and positive variable values
            if k[1]==input_warehouse and v.varValue>0:

                # Customer is served by the input warehouse
                linked_customers.append(k[0])

    return linked_customers

# Warehouses to establish
establish = facility_df.loc[facility_df.build_warehouse =='Yes']

# Plot the shape of Italy
ax = italy.plot(color='white', edgecolor='black', figsize=(30, 30))

# Plot sites to establish
establish.\
    plot(ax=ax, marker='o', c='#0059b3', markersize=100, label='Warehouse')

# Plot customers
customer_df.\
    plot(ax=ax, marker='X', color='#990000', markersize=80, alpha=0.8, label='Customer')

# For each warehouse to build
for w in establish.warehouse_id:

    # Extract list of customers served by the warehouse
    linked_customers = get_linked_customers(w)

    # For each served customer
    for c in linked_customers:

        # Plot connection between warehouse and the served customer
        ax.plot(
         [establish.loc[establish.warehouse_id==w].lng, customer_df.loc[customer_df.customer_id==c].lng],
         [establish.loc[establish.warehouse_id==w].lat, customer_df.loc[customer_df.customer_id==c].lat],
         linewidth=0.8, linestyle='--', color='#0059b3')

# Add title
plt.title('Optimized Customers Supply', fontsize = 35)

# Add legend
plt.legend(facecolor='white', fontsize=30)

# Remove ticks from axis
plt.xticks([])
plt.yticks([])

# Show plot
plt.show()

图片来源于作者。

5. 结论

在这篇文章中,我们介绍了一个经典的优化挑战:容量受限设施选址问题(CFLP)。我们描述了它的推导过程,并分享了一个实用的 Python 示例。特别地,由于我们从一个原始的地理位置数据集开始,我们覆盖了所有框架问题和求解问题所需的必要步骤和假设。

优化:牛顿-拉夫森方法的几何解释

原文:towardsdatascience.com/optimization-geometrical-interpretation-of-the-newton-raphson-method-d9f7acf1b5ae

探索一种基本的数值优化技术,重点关注其几何解释

Saupin GuillaumeTowards Data Science Saupin Guillaume

·发布于Towards Data Science ·阅读时间 8 分钟·2023 年 10 月 13 日

--

图片由Ansgar Scheffold拍摄,发布在Unsplash

梯度下降被广泛认为是基本的数值优化技术之一,而牛顿-拉夫森方法则在这一领域中具有显著地位。这种方法以其简单性、优雅性和计算能力而著称,值得深入探讨。

在本文中,我们的目标是阐明牛顿-拉夫森方法的几何原理。通过这种阐述,我们旨在为读者提供对其机制的直观理解,并消除与其数学基础相关的潜在复杂性。

随后,为了建立我们讨论的坚实数学框架,我们将深入探讨该方法的数学复杂性,并提供在 Python 编程语言中的实际实现。

在此演示之后,我们将区分牛顿-拉夫森方法的两个主要应用:根查找和优化。这一区分将明确方法在不同上下文中的实际应用。

最后,我们将对牛顿-拉夫森方法和梯度下降方法进行比较分析,提供对它们各自优缺点的见解。

如果你对数学概念感兴趣,并希望通过 Python 快速学习,请查看我的书籍:

[## 用 Python 揭示 70 个数学概念:通过 Python 探索数学的实用指南

购买《用 Python 揭示 70 个数学概念:通过 Python 探索数学的实用指南》请访问…

amzn.to

图形概述

查找根的迭代过程。图片由作者提供。

基本上,牛顿-拉夫森方法是一种迭代过程,旨在数值确定数学方程的根。其主要目标是确定一个值,记作 x_root,使方程 f(x_root) = 0 得以满足。需要注意的是,由于其数值性质,通过该方法获得的 x_root 值是一个近似值。

该方法的核心原理如下:从初始点 x_0 开始,目标是生成一系列值 x_i,这些值逐渐逼近所寻求的 x_root。为了实现这一点,在 (x_i, f(x_i)) 点附近通过线性化过程对函数进行局部近似。这种近似通过计算函数在点 (x_i, f(x_i)) 处的切线来完成。

对于这种局部近似,根的确定是直接的:它对应于该切线与 x 轴的交点。这个交点的横坐标 x_i+1 作为根 x_root 的更新近似值。

这个迭代过程会不断重复,直到 f(x) 的值足够接近零,表明已接近根。上面的图形作为该过程的示意图,展示了每次迭代时的切线及其与 x 轴的交点。

在这个具体的例子中,使用的基础函数是多项式 (x-6)(x-6)(x-6),它在 x=6 处具有三重根。显然,经过几次迭代,大约 10 次,该方法会收敛到接近理论根 x_root=6 的点。

绘制该图形的代码如下:

显示牛顿-拉夫森方法的连续迭代。代码由作者提供。

按照我通常的做法,我在上述代码中依赖自动微分,特别是使用 JAX 库来有效计算 f 的导数。

方法的数学推导

在让你感受到牛顿-拉夫森方法的工作方式后,是时候看看这个方法如何从数学上推导出来了。

给定一个函数f(x)和一个起始点x_i,目的是在该点(x_i, f(x_i))对函数进行线性化。这正是我们在前一节中所做的,通过在该点绘制切线。

从形式上讲,这可以表示为fx_i处的一阶泰勒展开:

一阶泰勒展开。作者提供的公式。

该公式给出了fx_i附近的线性化公式,其中delta是变量。

给定此公式,按照上述几何方法计算delta,更新x_i,使得该线性化与 x 轴相交,即:

计算x_i+1。作者提供的公式。

牛顿-拉夫森方法用于根的识别

正如引言中提到的,牛顿-拉夫森方法可以用来解决两个主要问题:根的查找和函数优化。在本节中,我们将特别关注根的查找。

首先,让我们重新审视根的查找概念。在几何背景下,这本质上是识别一个函数f与 x 轴交点的横坐标x。这与本文初始图中的基本概念相符。从代数角度来看,这等同于确定x使得f(x) = 0

因此,牛顿-拉夫森方法被设计用来解决这个具体问题:确定函数f的曲线与 x 轴交点的x值。

牛顿-拉夫森方法用于优化

如前所述,该方法主要解决的问题是根的查找。然而,在许多情况下,牛顿-拉夫森方法并不是用来寻找根,而是用于优化。换句话说,它用于识别x的值,使得函数f达到最大值或最小值点。

这怎么可能呢?最初,这可能看起来根的查找和优化是完全不同的挑战。然而,它们却是密切相关的。这种联系来源于一个事实,即函数在其斜率符号改变的点达到极值,这意味着其导数为零。因此,寻找函数的极值等同于识别其导数的根。

将牛顿-拉夫森方法应用于这种情况是相对简单的。该过程包括用其导数f’替换f,然后用其二阶导数f’’替换f’

以下代码片段展示了如何调整该方法以最大化高度非线性函数-(-(x-6)**2 / 3.0 + log(6+x) + 2*sin(x))

该代码通过应用牛顿-拉夫森方法来查找函数达到最大值的x,并绘制连续探索的点:

优化一个明显非线性的函数。作者提供的图像。

牛顿-拉夫森方法的优化几何解释

牛顿-拉夫森方法在优化中的适应可以从两种基本方式进行解释。第一种方法涉及将用于根查找的几何解释应用于函数 f 的导数 f',可以在本文的早期部分找到详细信息。

第二种更具吸引力的观点认为,在优化的背景下,函数 f 不再由直线(即一阶多项式)局部近似,而是由二阶多项式近似。从几何角度来看,这意味着 f 的曲线由抛物线近似。在这种情况下,目标不是确定局部近似与 x 轴的交点(如同在根查找中),而是找到该函数的最小值。

正如俗话所说,“一图胜千言”,下图的视觉表示有助于有效地说明这个概念。

通过拟合抛物线进行优化。图由作者提供。

基本上,算法的工作原理如下:

  • 选择一个点

  • 在这个点拟合一个抛物线,即一个多项式 p(x)=ax**2+bx+c,使其通过同一个点,并且与f具有相同的一阶和二阶导数。

  • 使用抛物线极值的横坐标 x 作为新的起始点

  • 迭代

从 Python 的角度来看,这给出:

牛顿-拉夫森方法的几何再解释。代码由作者提供。

请注意,这种牛顿-拉夫森方法的几何再解释给出的结果与之前的实现完全相同(在此情况下为7.4074383),并且收敛速度相同。只是稍微不够高效,因为我们需要计算多项式 p 的三个权重。

应用条件

牛顿-拉夫森方法是一个非常优雅且强大的方法。然而,在应用于优化时,它需要一些条件才能发挥作用。

首先,优化的函数必须是二次可微的。其次,需要从一个良好的近似开始,以避免局部最小值。

牛顿-拉夫森方法不同于梯度下降

另一个非常流行的优化方法是梯度下降。然而,它们在许多方面有所不同。

首先,牛顿-拉夫森方法在应用于优化时要求函数 f 是二次可微的。相对而言,梯度下降仅要求一阶导数。

其次,牛顿-拉夫森方法需要的调整较少,因为没有学习率需要配置。新点由抛物线的最小值自动给出。

第三,这种直接估计下一个点的方法,无需设置学习率,确保了更快的收敛。更准确地说,牛顿-拉夫森方法的收敛是二次的,而梯度下降方法的收敛仅是线性的。

www.buymeacoffee.com/guillaumes0

总结来说,本文探讨了牛顿-拉夫森方法这一基本的数值优化技术,重点是其几何解释。方法的优雅和计算能力得到了突出,文章旨在以直观的方式阐明其机制。

牛顿-拉夫森方法被描述为一种求根的迭代过程,其目标是近似一个数学方程的根。核心原则进行了讨论,强调了线性化和切线交点作为关键概念。

文章随后区分了牛顿-拉夫森方法的两个主要应用:求根和优化。求根被解释为寻找函数与 x 轴的交点,而优化被描述为寻找函数导数为零的极值点。方法在优化中的适应性,使用二次多项式来近似函数,被概述了。

应用牛顿-拉夫森优化方法的条件已被提到,包括函数需要是二次可微的要求,以及从一个好的初始近似值开始的重要性。

最后,文章比较了牛顿-拉夫森方法和梯度下降方法在其微分要求、调节和收敛速度方面的区别。文章总结了牛顿-拉夫森方法在优化中的独特优势。

在即将发布的后续文章中,我们将更详细地探讨牛顿-拉夫森方法在实际应用中的应用。

如果你对数学概念感兴趣,并且希望通过 Python 快速学习它们,可以看看我的书:

[## 揭示 70 个数学概念与 Python:通过 Python 探索数学的实用指南…

在…

amzn.to](https://amzn.to/3QaRkXZ?source=post_page-----d9f7acf1b5ae--------------------------------)

优化、牛顿法与利润最大化:第一部分 — 基本优化理论

原文:towardsdatascience.com/optimization-newtons-method-profit-maximization-part-1-basic-optimization-theory-ff7c5f966565?source=collection_archive---------10-----------------------#2023-01-10

所有图片由作者提供

学习如何解决和利用牛顿法解决多维优化问题

Jacob PieniazekTowards Data Science Jacob Pieniazek

·

关注 发表在 Towards Data Science ·14 分钟阅读·2023 年 1 月 10 日

--

本文是 3 部分系列中的第一部分。在第一部分中,我们将学习基本的优化理论。然后,在第二部分,我们将扩展这些理论到约束优化问题。最后,在第三部分中,我们将应用所涵盖的优化理论,以及计量经济学和经济理论,来解决一个利润最大化问题。

数学优化是一个极其强大的数学领域,支撑了我们数据科学家在日常工作中隐性或显性使用的许多工具——实际上,几乎所有的机器学习算法都利用优化理论来实现模型收敛。例如,在分类问题中,我们试图通过选择模型的最优参数或权重来最小化对数损失。一般来说,数学优化可以被视为机器学习的主要理论机制。对数学优化的深刻理解是数据科学家工具箱中非常有用的技能——它使数据科学家能够更深入地理解许多当前使用的算法,并且进一步解决各种独特的优化问题

许多读者可能对梯度下降或相关的优化算法,如随机梯度下降,已经有所了解。然而,本文将更深入地讨论经典的牛顿优化方法,有时称为牛顿-拉夫森方法。我们将从基础数学知识开始,逐步讲解优化理论,到梯度下降,然后深入探讨牛顿方法及其在 Python 中的实现。这将为我们进入第二部分的约束优化和本系列的第三部分中的计量经济学利润最大化问题提供必要的前期准备。

优化基础——一个简单的二次函数

数学优化可以被定义为“确定数学定义问题的最佳解决方案的科学。”[1] 在一些实际例子中,这可以被概念化为:选择参数以最小化机器学习算法的损失函数,选择价格和广告以最大化利润,选择股票以最大化风险调整后的财务回报等等。形式上,任何数学优化问题都可以抽象地表述如下:

(1)

这可以理解为:选择向量x的实值,以最小化目标函数 f(x)(或最大化-f(x)),并满足不等式约束 g(x)和等式约束 h(x)。我们将在本系列的第二部分中讨论如何求解约束优化问题——因为它们使优化问题变得特别复杂。现在,让我们来看一个无约束的单变量示例——考虑以下优化问题:

(2)

在这种情况下,我们想选择使上面的二次函数最小化的 x 值。我们可以采用多种方法——首先,一种简单的方法是对 x 值的大范围进行网格搜索,并选择使f(x)具有最低函数值的 x。然而,随着搜索空间的增加、函数变得更加复杂或维度增加,这种方法可能很快失去计算上的可行性。

如果存在封闭形式的解,我们可以直接使用微积分来求解。也就是说,我们可以通过微积分解析地求解 x 的值。通过取导数(或在高维中,如后面所述的梯度)并将其设置为 0——相对最小值的必要一阶条件——我们可以求解函数的相对极值。然后我们可以取第二导数(或在高维中,如后面所述的 Hessian 矩阵)来确定这个极值是最大值还是最小值。第二导数大于 0(或正定 Hessian 矩阵)——相对最小值的必要二阶条件——意味着是最小值,反之亦然。观察:

(3)

我们可以通过图形验证上述(2):

请注意,当一个函数存在多个极值点(即多个最小值或最大值)时,必须小心确定哪个是全局极值——我们将在本文中简要讨论这个问题。

上述分析方法可以扩展到更高维度,利用梯度和 Hessian 矩阵——然而,我们不会在高维中求解封闭形式的解——但直觉依然相同。我们仍然会利用迭代方案来解决更高维的问题。我说的迭代方案是什么意思?通常,可能不存在封闭形式(或解析)的解,并且为了存在最大值或最小值,封闭形式的解确实不一定存在。因此,我们需要一种数值方法来解决优化问题。这引导我们到更广泛的迭代方案,包括梯度下降法和牛顿方法。

迭代优化方案

一般来说,迭代优化方案主要分为三类,即 零阶一阶二阶,它们分别利用函数的零阶、一阶或二阶导数的局部信息。[1] 要使用每种迭代方案,函数 f(x) 必须是相应阶数上连续可微的函数。

零阶迭代方案

零阶迭代方案 与上述的网格搜索紧密相关——简单来说,你在一定范围内搜索可能的 x 值以获得最小的函数值。正如你可能猜到的,这些方法往往比使用高阶方法的计算开销大得多。不用说,它们可以是可靠的并且容易编程。市场上有一些方法可以改进简单的网格搜索,参见[1]了解更多信息;然而,我们将更多关注高阶方案。

一阶迭代方案

一阶迭代方案 是利用目标函数一阶导数的局部信息的迭代方案。最显著的例子是梯度下降方法。对于如上所述的单变量函数,梯度就是一阶导数。将此推广到 n 维度,对于一个函数 f(x),梯度是一阶偏导数的向量:

(4) 连续可微函数 f(x)

梯度下降从选择一个随机的起点开始,并在 f(x) 的负梯度方向上迭代进行——函数的最陡方向。每次迭代步骤可以表示如下:

(5) 梯度下降迭代方案

其中 γ 是相应的学习率,控制梯度下降算法在每次迭代中“学习”的快慢。学习率过大会导致我们的迭代不受控制地发散。学习率过小则迭代可能需要很长时间才能收敛。此方案会迭代进行,直到达到一个或多个收敛准则,如:

(6) 迭代优化方案的收敛准则

对于某个小的 epsilon 阈值。回到我们的二次例子,将初始猜测设置为 x = 3 和学习率设置为 0.1,步骤如下:

(7)

视觉化如下:

梯度下降和一阶迭代方案在性能上显著可靠。实际上,梯度下降算法主要用于神经网络和机器学习模型中的损失函数优化,许多发展已提高了这些算法的效能。然而,它们仍然仅使用关于函数的有限局部信息(仅一阶导数)。因此,在高维情况下,根据目标函数的性质和学习率,这些方案 1) 可能具有较慢的收敛速度,因为它们保持线性收敛率,并且 2) 可能完全无法收敛。由于这个原因,数据科学家扩大优化工具库是有益的!

二阶迭代方案

如你现在可能已经明白,二阶迭代方案 是利用目标函数的一阶导数和二阶导数的局部信息的迭代方案。最显著的是,我们有牛顿法 (NM),它使用目标函数的海森矩阵。对于单变量函数,海森矩阵仅仅是二阶导数。类似于梯度,将其推广到 n 维度,海森矩阵是一个 n x n 对称 矩阵,包含一个两次连续可微函数 f(x)的二阶偏导数。

(8) 二次连续可微函数 f(x)的海森矩阵

现在转到导出 NM,首先回忆最小值的一阶必要条件:

(9) x* 处相对最小值的一级必要条件

我们可以使用泰勒级数展开来近似 x*:

(10)

每次迭代的增量 Δ 是对 x* 的一个更好的预期近似。因此,使用 NM 的每次迭代步骤可以表示如下:

(11) 牛顿法迭代方案

回到我们的二次函数示例,将初始猜测设置为 x = 3,步骤如下:

(12)

我们优雅地在第一次迭代时就收敛到最优解。注意,无论方案如何,收敛标准都是相同的。

请注意,所有优化方案都可能陷入局部极值,而不是全局极值(即,考虑具有多个极值(最小值和/或最大值)的高阶多项式——我们可能会陷入一个局部极值,而实际上,另一个极值可能在全球范围内对我们的实际问题更为优化)。已有的方法,并且仍在不断开发,用于处理全局优化,我们将不会深入探讨。你可以利用函数形式的先验知识来设置对结果的预期(即,如果一个严格凸函数有一个临界点,则它必须是全局最小值)。然而,作为一般经验法则,通常明智的做法是对不同的初始值 x 迭代优化方案,然后研究结果的稳定性,通常选择具有最优函数值的结果。

多维示例——罗森布罗克的抛物线谷

现在让我们考虑以下两个变量的优化问题:

(13) 罗森布罗克的抛物线谷

我们将首先通过手动求解上述优化问题,然后在 Python 中进行求解,均使用牛顿法。

通过手动求解

要手动求解,我们需要求解梯度,求解 Hessian,选择我们的初始猜测 Γ = [x,y],然后迭代将这些信息输入到 NM 算法中,直到收敛为止。首先,求解梯度,我们得到:

(14)

求解 Hessian,我们得到:

(15)

将我们的初始猜测设置为 Γ = [-1.2,1],我们得到:

(16)

因此,我们成功地求解了目标函数的最优最小值为 Γ* = [1,1]。

通过 Python 求解

我们现在将转向用 Python 求解这个问题,并将其推广到任何函数,使用 SymPy —— 一个用于符号数学的 Python 库。首先,让我们定义罗森布罗克的抛物线谷,并计算该函数的梯度和 Hessian:

import sympy as sm
import numpy as np

# Define symbols & objective function
x, y = sm.symbols('x y')
Gamma = [x,y]
objective = 100*(y-x**2)**2 + (1-x)**2

def get_gradient(
    function: sm.core.expr.Expr,
    symbols: list[sm.core.symbol.Symbol],
) -> np.ndarray:
    """
    Calculate the gradient of a function.

    Args:
        function (sm.core.expr.Expr): The function to calculate the gradient of.
        symbols (list[sm.core.symbol.Symbol]): The symbols representing the variables in the function.

    Returns:
        numpy.ndarray: The gradient of the function.
    """
    d1 = {}
    gradient = np.array([])

    for i in symbols:
        d1[i] = sm.diff(function, i, 1)
        gradient = np.append(gradient, d1[i])

    return gradient

def get_hessian(
    function: sm.core.expr.Expr,
    symbols: list[sm.core.symbol.Symbol],
    x0: dict[sm.core.symbol.Symbol, float],
) -> np.ndarray:
    """
    Calculate the Hessian matrix of a function.

    Args:
    function (sm.core.expr.Expr): The function for which the Hessian matrix is calculated.
    symbols (list[sm.core.symbol.Symbol]): The list of symbols used in the function.

    Returns:
    numpy.ndarray: The Hessian matrix of the function.
    """
    d2 = {}
    hessian = np.array([])

    for i in symbols:
        for j in symbols:
            d2[f"{i}{j}"] = sm.diff(function, i, j)
            hessian = np.append(hessian, d2[f"{i}{j}"])

    hessian = np.array(np.array_split(hessian, len(symbols)))

    return hessian

SymPy 允许我们调查方程的符号表示。例如,如果我们调用 objective ,我们将看到相应的输出:

SymPy 的函数符号表示

此外,SymPy 允许我们利用 sm.diff() 命令对相关函数进行求导。如果我们运行定义的函数以获得梯度 get_gradient(objective,Gamma) ,我们得到一个表示梯度的 numpy 数组:

SymPy 求解的梯度

访问特定元素时,我们可以看到符号表示 get_gradient(objective, Gamma)[0]

SymPy 解出的 df(Γ)/dx

类似地,对于 Hessian 矩阵,我们可以调用 get_hessian(objective, Gamma)

SymPy 解出的 Hessian 矩阵

访问特定元素 get_hessian(objective,Gamma)[0][1]

SymPy 解出的 df(Γ)/dxdy

可以轻松验证梯度和 Hessian 矩阵与我们手动计算得到的结果是相同的。SymPy 允许对给定符号的指定值评估任何函数。例如,我们可以通过如下调整函数来评估初始猜测处的梯度:

import sympy as sm
import numpy as np

def get_gradient(
    function: sm.core.expr.Expr,
    symbols: list[sm.core.symbol.Symbol],
    x0: dict[sm.core.symbol.Symbol, float], # Add x0 as argument
) -> np.ndarray:
    """
    Calculate the gradient of a function at a given point.

    Args:
        function (sm.core.expr.Expr): The function to calculate the gradient of.
        symbols (list[sm.core.symbol.Symbol]): The symbols representing the variables in the function.
        x0 (dict[sm.core.symbol.Symbol, float]): The point at which to calculate the gradient.

    Returns:
        numpy.ndarray: The gradient of the function at the given point.
    """
    d1 = {}
    gradient = np.array([])

    for i in symbols:
        d1[i] = sm.diff(function, i, 1).evalf(subs=x0) # add evalf method
        gradient = np.append(gradient, d1[i])

    return gradient.astype(np.float64) # Change data type to float

def get_hessian(
    function: sm.core.expr.Expr,
    symbols: list[sm.core.symbol.Symbol],
    x0: dict[sm.core.symbol.Symbol, float],
) -> np.ndarray:
    """
    Calculate the Hessian matrix of a function at a given point.

    Args:
    function (sm.core.expr.Expr): The function for which the Hessian matrix is calculated.
    symbols (list[sm.core.symbol.Symbol]): The list of symbols used in the function.
    x0 (dict[sm.core.symbol.Symbol, float]): The point at which the Hessian matrix is evaluated.

    Returns:
    numpy.ndarray: The Hessian matrix of the function at the given point.
    """
    d2 = {}
    hessian = np.array([])

    for i in symbols:
        for j in symbols:
            d2[f"{i}{j}"] = sm.diff(function, i, j).evalf(subs=x0)
            hessian = np.append(hessian, d2[f"{i}{j}"])

    hessian = np.array(np.array_split(hessian, len(symbols)))

    return hessian.astype(np.float64)

我们现在可以通过调用 get_gradient(objective, Gamma, {x:-1.2,y:1.0}) 来计算给定起始点的梯度:

初始点处的梯度

类似地,对于 Hessian 矩阵 get_hessian(objective, Gamma, {x:-1.2,y:1.0})

初始点处的 Hessian 矩阵

同样,我们可以通过以上手动计算的工作来验证这些值是否正确。现在我们拥有了编写牛顿法所需的所有要素(梯度下降的代码在本文末尾也提供)

import sympy as sm
import numpy as np

def newton_method(
    function: sm.core.expr.Expr,
    symbols: list[sm.core.symbol.Symbol],
    x0: dict[sm.core.symbol.Symbol, float],
    iterations: int = 100,
) -> dict[sm.core.symbol.Symbol, float] or None:
    """
    Perform Newton's method to find the solution to the optimization problem.

    Args:
        function (sm.core.expr.Expr): The objective function to be optimized.
        symbols (list[sm.core.symbol.Symbol]): The symbols used in the objective function.
        x0 (dict[sm.core.symbol.Symbol, float]): The initial values for the symbols.
        iterations (int, optional): The maximum number of iterations. Defaults to 100.

    Returns:
        dict[sm.core.symbol.Symbol, float] or None: The solution to the optimization problem, or None if no solution is found.
    """

    x_star = {}
    x_star[0] = np.array(list(x0.values()))

    print(f"Starting Values: {x_star[0]}")

    for i in range(iterations):

        gradient = get_gradient(function, symbols, dict(zip(x0.keys(), x_star[i])))
        hessian = get_hessian(function, symbols, dict(zip(x0.keys(), x_star[i])))

        x_star[i + 1] = x_star[i].T - np.linalg.inv(hessian) @ gradient.T

        if np.linalg.norm(x_star[i + 1] - x_star[i]) < 10e-5:
            solution = dict(zip(x0.keys(), x_star[i + 1]))
            print(f"\nConvergence Achieved ({i+1} iterations): Solution = {solution}")
            break
        else:
            solution = None

        print(f"Step {i+1}: {x_star[i+1]}")

    return solution

我们现在可以通过 newton_method(objective,Gamma,{x:-1.2,y:1}) 来运行代码:

结论

就这样!如果你已经阅读到这一步,你现在对如何思考和抽象地制定无约束数学优化问题有了扎实的理解,包括基本的分析方法和更复杂的迭代方法。显然,我们在迭代方案中可以融入的信息越多(即更高阶的导数),收敛速度就越高效。请注意,我们只是触及了数学优化复杂世界的表面。 尽管如此,我们今天讨论的工具在实践中绝对可以使用,并可以扩展到更高维的优化问题。

请关注 第二部分 的系列文章,我们将在其中扩展我们在这里学到的内容,解决有约束的优化问题——这是无约束优化的一个极其实用的扩展。实际上,大多数实际优化问题都会对选择变量有某种形式的约束。然后我们将转向本系列的第三部分,在其中我们将应用学到的优化理论和额外的计量经济学与经济理论来解决一个简单的利润最大化问题。希望你和我一样喜欢阅读这篇文章!

奖励——牛顿法的陷阱

尽管牛顿法具有吸引力,但它也有自身的陷阱。尤其是,存在两个主要陷阱——1)即使在选择接近解的起始点时,NM 并不总是收敛;2)NM 在每一步都需要计算 Hessian 矩阵,这在高维情况下计算开销非常大。对于陷阱 #1),一种相应的解决方案是改进牛顿法(MNM),可以粗略地认为它是梯度下降法,其中搜索方向由牛顿步长Δ给出。对于陷阱 #2),已经提出了准牛顿方法,如 DFP 或 BFGS,这些方法通过近似每一步使用的逆 Hessian 矩阵来减轻计算负担。有关更多信息,请参见[1]。

梯度下降函数

import sympy as sm
import numpy as np

def gradient_descent(
    function: sm.core.expr.Expr,
    symbols: list[sm.core.symbol.Symbol],
    x0: dict[sm.core.symbol.Symbol, float],
    learning_rate: float = 0.1,
    iterations: int = 100,
) -> dict[sm.core.symbol.Symbol, float] or None:
    """
    Performs gradient descent optimization to find the minimum of a given function.

    Args:
        function (sm.core.expr.Expr): The function to be optimized.
        symbols (list[sm.core.symbol.Symbol]): The symbols used in the function.
        x0 (dict[sm.core.symbol.Symbol, float]): The initial values for the symbols.
        learning_rate (float, optional): The learning rate for the optimization. Defaults to 0.1.
        iterations (int, optional): The maximum number of iterations. Defaults to 100.

    Returns:
        dict[sm.core.symbol.Symbol, float] or None: The solution found by the optimization, or None if no solution is found.
    """
    x_star = {}
    x_star[0] = np.array(list(x0.values()))

    print(f"Starting Values: {x_star[0]}")

    for i in range(iterations):

        gradient = get_gradient(function, symbols, dict(zip(x0.keys(), x_star[i])))

        x_star[i + 1] = x_star[i].T - learning_rate * gradient.T

        if np.linalg.norm(x_star[i + 1] - x_star[i]) < 10e-5:
            solution = dict(zip(x0.keys(), x_star[i + 1]))
            print(f"\nConvergence Achieved ({i+1} iterations): Solution = {solution}")
            break
        else:
            solution = None

        print(f"Step {i+1}: {x_star[i+1]}")

    return solution

资源

[1] Snyman, J. A., & Wilke, D. N. (2019). 实用数学优化:基本优化理论与基于梯度的算法(第 2 版)。Springer。

[2] en.wikipedia.org/wiki/Gradient_descent

[3] en.wikipedia.org/wiki/Newton%27s_method#:~:text=In%20numerical%20analysis%2C%20Newton%27s%20method%2C%20also%20known%20as,roots%20%28or%20zeroes%29%20of%20a%20real%20-valued%20function.

通过此 GitHub Repo 访问所有代码: github.com/jakepenzak/Blog-Posts

感谢你阅读我的帖子!我在 Medium 上的帖子旨在探讨利用 计量经济学 统计/机器学习 技术的现实世界和理论应用。此外,我还希望通过理论和模拟提供某些方法论的理论基础。最重要的是,我写作是为了学习!我希望让复杂的话题对所有人稍微更易于理解。如果你喜欢这篇文章,请考虑 在 Medium 上关注我

优化、牛顿法与利润最大化:第二部分——约束优化理论

原文:towardsdatascience.com/optimization-newtons-method-profit-maximization-part-2-constrained-optimization-theory-dc18613c5770?source=collection_archive---------9-----------------------#2023-02-02

作者提供的所有图片

了解如何扩展牛顿法并解决约束优化问题

Jacob PieniazekTowards Data Science Jacob Pieniazek

·

关注 发表在 Towards Data Science ·13 分钟阅读·2023 年 2 月 2 日

--

这篇文章是一个三部分系列中的第二部分。在第一部分中,我们研究了基本的优化理论。现在,在第二部分中,我们将把这一理论扩展到受限优化问题。最后,在第三部分中,我们将应用所涵盖的优化理论,以及计量经济学和经济理论,来解决一个利润最大化问题。

考虑以下问题:你想确定在特定金融工具上投资多少资金以最大化你的投资回报。然而,单纯最大化投资回报的问题过于宽泛和简单。由于其简单性,解决方案就是将所有资金投入到回报潜力最高的金融工具中。显然,这不是一个好的投资策略;那么,我们如何改进呢?通过对投资决策施加约束,我们的选择变量。 例如,我们可以指定约束条件,比如 1) 限制我们愿意承担的金融风险量(见现代投资组合理论),或者 2) 指定我们投资组合中每类金融工具(如股票、债券、衍生品等)的分配比例——可能性无穷无尽。注意,当我们添加约束时,这个问题变得显著更具可处理性。尽管这是一个简单的例子,但它有助于捕捉受限优化的一个基本动机:

受限优化的本质在于为无约束优化问题提供一种适用性和解决复杂现实世界问题的能力。

受限优化被定义为“在变量受到约束的情况下,对目标函数进行优化的过程。”[1] 添加对变量的约束将一个无约束的、或许是难以处理的优化问题转化为一个有助于建模和解决现实世界问题的问题。然而,添加约束可能将一个简单的优化问题转变为一个不再是微不足道的问题。在这篇文章中,我们将深入探讨一些可以添加到我们工具箱中的技术,以扩展在第一部分中学习的无约束优化理论,从而解决受限优化问题。

第一部分中,我们介绍了基本的优化理论——包括 1) 解析设置和解决一个简单的单变量优化问题,2) 迭代优化方案——即梯度下降法和牛顿法,以及 3) 手动和使用 Python 实现牛顿法用于多维优化问题。本文旨在使那些已经熟悉 第一部分 中内容的读者能够轻松理解。

约束优化基础(& 第一部分 回顾)

一个数学优化问题可以抽象地表示如下:

(1)

在这里,我们选择 x 的实际值,以最小化 目标函数 f(x)(或最大化 -f(x)),同时满足 不等式约束 g(x) 和 等式约束 h(x)。在 第一部分 中,我们讨论了如何在没有 g(x) 和 h(x) 的情况下解决这些问题,现在我们将这些约束重新引入到我们的优化问题中。首先,让我们简明扼要地回顾如何实现牛顿法用于无约束问题。

回顾一下,我们可以使用泰勒级数展开来近似最小值的一阶必要条件:

(2)

其中 H(x) 和 f(x) 分别表示 f(x) 的 Hessian 矩阵和梯度。每次迭代增加的 delta, Δ, 是对最优值 x* 的预期更好近似。因此,每次使用牛顿法的迭代步骤可以表示如下:

(3) 牛顿法迭代方案

我们执行这个方案直到在以下一个或多个标准上达到收敛:

(4) 迭代优化方案的收敛标准

将其转化为 Python 代码,我们使用 SymPy —— 一个用于符号数学的 Python 库 —— 并创建可泛化的函数来计算梯度、计算 Hessian 矩阵,并实现牛顿法用于 n 维函数:

import sympy as sm
import numpy as np

def get_gradient(
    function: sm.core.expr.Expr,
    symbols: list[sm.core.symbol.Symbol],
    x0: dict[sm.core.symbol.Symbol, float],
) -> np.ndarray:
    """
    Calculate the gradient of a function at a given point.

    Args:
        function (sm.core.expr.Expr): The function to calculate the gradient of.
        symbols (list[sm.core.symbol.Symbol]): The symbols representing the variables in the function.
        x0 (dict[sm.core.symbol.Symbol, float]): The point at which to calculate the gradient.

    Returns:
        numpy.ndarray: The gradient of the function at the given point.
    """
    d1 = {}
    gradient = np.array([])

    for i in symbols:
        d1[i] = sm.diff(function, i, 1).evalf(subs=x0)
        gradient = np.append(gradient, d1[i])

    return gradient.astype(np.float64)

def get_hessian(
    function: sm.core.expr.Expr,
    symbols: list[sm.core.symbol.Symbol],
    x0: dict[sm.core.symbol.Symbol, float],
) -> np.ndarray:
    """
    Calculate the Hessian matrix of a function at a given point.

    Args:
    function (sm.core.expr.Expr): The function for which the Hessian matrix is calculated.
    symbols (list[sm.core.symbol.Symbol]): The list of symbols used in the function.
    x0 (dict[sm.core.symbol.Symbol, float]): The point at which the Hessian matrix is evaluated.

    Returns:
    numpy.ndarray: The Hessian matrix of the function at the given point.
    """
    d2 = {}
    hessian = np.array([])

    for i in symbols:
        for j in symbols:
            d2[f"{i}{j}"] = sm.diff(function, i, j).evalf(subs=x0)
            hessian = np.append(hessian, d2[f"{i}{j}"])

    hessian = np.array(np.array_split(hessian, len(symbols)))

    return hessian.astype(np.float64)

def newton_method(
    function: sm.core.expr.Expr,
    symbols: list[sm.core.symbol.Symbol],
    x0: dict[sm.core.symbol.Symbol, float],
    iterations: int = 100,
) -> dict[sm.core.symbol.Symbol, float] or None:
    """
    Perform Newton's method to find the solution to the optimization problem.

    Args:
        function (sm.core.expr.Expr): The objective function to be optimized.
        symbols (list[sm.core.symbol.Symbol]): The symbols used in the objective function.
        x0 (dict[sm.core.symbol.Symbol, float]): The initial values for the symbols.
        iterations (int, optional): The maximum number of iterations. Defaults to 100.

    Returns:
        dict[sm.core.symbol.Symbol, float] or None: The solution to the optimization problem, or None if no solution is found.
    """

    x_star = {}
    x_star[0] = np.array(list(x0.values()))

    # x = [] ## Return x for visual!

    print(f"Starting Values: {x_star[0]}")

    for i in range(iterations):
        # x.append(dict(zip(x0.keys(),x_star[i]))) ## Return x for visual!

        gradient = get_gradient(function, symbols, dict(zip(x0.keys(), x_star[i])))
        hessian = get_hessian(function, symbols, dict(zip(x0.keys(), x_star[i])))

        x_star[i + 1] = x_star[i].T - np.linalg.inv(hessian) @ gradient.T

        if np.linalg.norm(x_star[i + 1] - x_star[i]) < 10e-5:
            solution = dict(zip(x0.keys(), x_star[i + 1]))
            print(f"\nConvergence Achieved ({i+1} iterations): Solution = {solution}")
            break
        else:
            solution = None

        print(f"Step {i+1}: {x_star[i+1]}")

    return solution

为了解决无约束优化问题,我们可以运行以下代码:

import sympy as sm

# Define Symbols
x, y = sm.symbols('x y') 
Gamma = [x,y] 

# Define Objective Function (Rosenbrock's Parabolic Valley)
objective = 100*(y-x**2)**2 + (1-x)**2

# Specify starting values
Gamma0 = {x:-1.2,y:1}

# Call function
newton_method(objective, Gamma, Gamma0)

及其对应的输出:

如果上述所有材料感觉非常陌生,那么我建议查看part 1,它将更深入地探讨上述内容并帮助你跟上进度!话不多说,让我们深入实施优化问题中的约束。

注意:所有以下约束优化技术都可以且应该在适用时与梯度下降算法结合使用!

求解带约束的优化问题

如我们上面讨论的,目标函数可能有两种约束——等式约束和不等式约束。注意,对于每种类型的约束,都有不同的方法,具有不同的优缺点。有关不同方法的进一步讨论,请参见[2]。不过,我们将重点关注两种方法,一种用于等式约束,另一种用于不等式约束,这些方法在性能上可靠,对新手易于理解,并且可以很容易地集成到一个有机的问题中。

等式约束 — 拉格朗日

首先,我们将处理带有等式约束的优化问题。即,具有以下形式的优化问题:

(5)

假设我们正在处理 Rosenbrock 的抛物谷,如part 1中所述,但现在添加了等式约束 x² - y = 2:

(6) 带有等式约束的 Rosenbrock 的抛物谷 (问题 1)

注意,为了简化和一致性,等式约束应写成等于零的形式。现在我们的优化问题看起来像:

Rosenbrock 的抛物谷(紫黄色色图)和等式约束曲线(黑色)

在这里,可行区域的最优值位于等式约束曲线与我们上面的目标函数的交点之一。

约瑟夫-路易·拉格朗日开发了一种方法,将等式约束直接纳入目标函数中——创建拉格朗日函数——以便可以继续应用传统方法使用一阶和二阶导数。[2][3] 形式上,拉格朗日函数具有以下形式:

(7) 拉格朗日函数的正式定义

其中 f(x) 和 h(x) 分别是目标函数和等式约束。Λ 是与每个等式约束 j 对应的 拉格朗日乘子。拉格朗日乘子被视为拉格朗日函数中的新选择变量。正好,x* 作为等式约束问题的最小值的必要条件x* 对应于拉格朗日函数的驻点 (x, Λ). 即,

(8) 拉格朗日一阶条件

对于上述示例——等式约束的 Rosenbrock 抛物线谷(公式 1)——我们可以将拉格朗日函数写为:

(9)

然后我们可以使用牛顿法求解这个拉格朗日函数,但现在需要将拉格朗日乘子作为附加选择变量。

import sympy as sm

x, y, λ  = sm.symbols('x y λ')

Langrangian_objective = 100*(y-x**2)**2 + (1-x)**2 + λ*(x**2-y-2)
Gamma = [x,y,λ]
Gamma0 = {x:-1.2,y:1,λ:1}

newton_method(Langrangian_objective,Gamma,Gamma0)

对应的输出为:

可以很容易验证解满足我们的等式约束。就这样!这还不算太难,对吧?这种方法可以扩展以添加任何数量的等式约束——只需添加另一个拉格朗日乘子。现在我们继续讨论如何纳入不等式约束。

不等式约束—对数障碍函数

现在我们将处理带有不等式约束的优化问题。即,具有如下形式的优化问题:

(10)

再次假设,我们在处理 Rosenbrock 的抛物线谷,但现在有不等式约束 x ≤ 0 和 y ≥ 3:

(11) 带有不等式约束的 Rosenbrock 抛物线谷 (问题 2)

现在我们的优化问题变成了:

Rosenbrock 的抛物线谷(紫黄色色彩图)和不等式约束平面(黑色)

最优值的可行区域位于由红星标记的约束所界定的象限中。

由于这些约束没有严格的等式,我们无法将它们直接纳入目标函数。然而,我们可以动脑筋——我们可以做的是增强目标函数,在目标函数中加入一个“障碍”,对接近不等式约束边界的解值进行惩罚。这些方法被称为“内点法”或“障碍法”[4][5]。像拉格朗日函数一样,我们可以通过引入障碍函数(在我们这个案例中是对数障碍函数)将原来的约束优化问题转化为无约束优化问题,从而创建障碍函数。形式上,对数障碍函数的特点是:

(12) 对数障碍函数的正式定义

其中ρ是一个小的正标量——称为障碍参数。随着ρ → 0,障碍函数B(x,ρ)的解应收敛到我们原始约束优化函数的解。注意,c(x)表示,根据我们如何制定不等式约束(大于或小于零),将决定我们使用该约束的负值或正值。我们知道 y=log(x)在 x ≤ 0 时未定义,因此我们需要将约束制定为始终≥0。

你可能会问,对数障碍法究竟是如何工作的?首先,在使用障碍法时,我们必须选择位于可行区域的起始值。随着最优值接近由约束定义的“障碍”,该方法依赖于对数函数在值接近零时趋向负无穷的特性,从而惩罚目标函数值。随着ρ → 0,惩罚减小(见下图),我们逐渐收敛到解。然而,需要从足够大的ρ开始,以确保惩罚足够大,防止“跳出”障碍。因此,该算法比牛顿法多了一个额外的循环——即,我们选择一个起始值ρ,使用牛顿法优化障碍函数,然后通过缓慢减小ρρ → 0)来更新ρ,直到收敛。

不同值的ρ的对数障碍

回到我们之前的例子——不等式约束的罗森布罗克抛物谷(公式 2)——我们可以将障碍函数写为:

(13)

记住 log(a) + log(b) = log(ab),以及我们的一个约束 x ≤ 0 → -x ≥ 0。我们必须更新我们的代码以适应障碍法算法:

import sympy as sm
import numpy as np

def constrained_newton_method(
    function: sm.core.expr.Expr,
    symbols: list[sm.core.symbol.Symbol],
    x0: dict[sm.core.symbol.Symbol, float],
    iterations: int = 100,
) -> dict[sm.core.symbol.Symbol, float] or None:
    """
    Performs constrained Newton's method to find the optimal solution of a function subject to constraints.

    Parameters:
        function (sm.core.expr.Expr): The function to optimize.
        symbols (list[sm.core.symbol.Symbol]): The symbols used in the function.
        x0 (dict[sm.core.symbol.Symbol, float]): The initial values for the symbols.
        iterations (int, optional): The maximum number of iterations. Defaults to 100.

    Returns:
        dict[sm.core.symbol.Symbol, float] or None: The optimal solution if convergence is achieved, otherwise None.
    """
    x_star = {}
    x_star[0] = np.array(list(x0.values())[:-1])

    optimal_solutions = []
    optimal_solutions.append(dict(zip(list(x0.keys())[:-1], x_star[0])))

    for step in range(iterations):
        # Evaluate function at rho value
        if step == 0:  # starting rho
            rho_sub = list(x0.values())[-1]

        rho_sub_values = {list(x0.keys())[-1]: rho_sub}
        function_eval = function.evalf(subs=rho_sub_values)

        print(f"Step {step} w/ {rho_sub_values}")  # Barrier method step
        print(f"Starting Values: {x_star[0]}")

        # Newton's Method
        for i in range(iterations):
            gradient = get_gradient(
                function_eval, symbols[:-1], dict(zip(list(x0.keys())[:-1], x_star[i]))
            )
            hessian = get_hessian(
                function_eval, symbols[:-1], dict(zip(list(x0.keys())[:-1], x_star[i]))
            )

            x_star[i + 1] = x_star[i].T - np.linalg.inv(hessian) @ gradient.T

            if np.linalg.norm(x_star[i + 1] - x_star[i]) < 10e-5:
                solution = dict(zip(list(x0.keys())[:-1], x_star[i + 1]))
                print(
                    f"Convergence Achieved ({i+1} iterations): Solution = {solution}\n"
                )
                break

        # Record optimal solution & previous optimal solution for each barrier method iteration
        optimal_solution = x_star[i + 1]
        previous_optimal_solution = list(optimal_solutions[step - 1].values())
        optimal_solutions.append(dict(zip(list(x0.keys())[:-1], optimal_solution)))

        # Check for overall convergence
        if np.linalg.norm(optimal_solution - previous_optimal_solution) < 10e-5:
            print(
                f"\n Overall Convergence Achieved ({step} steps): Solution = {optimal_solutions[step]}\n"
            )
            overall_solution = optimal_solutions[step]
            break
        else:
            overall_solution = None

        # Set new starting point
        x_star = {}
        x_star[0] = optimal_solution

        # Update rho
        rho_sub = 0.9 * rho_sub

    return overall_solution

我们现在可以用上述代码求解障碍函数(注意:确保起始值在不等式约束的可行范围内,如果跳出不等式约束,可能需要增加ρ的起始值):

import sympy as sm

x, y, ρ = sm.symbols('x y ρ')

Barrier_objective = 100*(y-x**2)**2 + (1-x)**2 - ρ*sm.log((-x)*(y-3))
Gamma = [x,y,ρ] # Function requires last symbol to be ρ!
Gamma0 = {x:-15,y:15,ρ:10}

constrained_newton_method(Barrier_objective,Gamma,Gamma0)

其对应的输出为:

很明显,解满足指定的不等式约束。就是这样。我们现在已经处理了优化问题中的不等式约束。最后,让我们将一切整合起来,继续处理具有混合约束的约束优化问题——这只是我们上述工作组合的结果。

汇总

现在,让我们通过结合上述的等式和不等式约束来解决我们的优化问题。也就是说,我们想要解决一个形式的优化问题:

(14)

我们只需要将拉格朗日函数和障碍函数结合成一个函数。因此,我们可以创建一个通用函数,称之为 O,用于处理具有等式和不等式约束的优化问题:

(15) 可推广的约束优化问题的函数

其中,如前所述,Λ 是拉格朗日乘子向量,ρ 是障碍参数。因此,结合我们上面的受限(公式 6)和无约束问题(公式 11),我们可以将我们的混合受限优化问题表述如下:

(16)

在 Python 中,

import sympy as sm

x, y, λ, ρ = sm.symbols('x y λ ρ')

combined_objective = 100*(y-x**2)**2 + (1-x)**2 + λ*(x**2-y-2) - ρ*sm.log((-x)*(y-3))
Gamma = [x,y,λ,ρ] # Function requires last symbol to be ρ!
Gamma0 = {x:-15,y:15,λ:0,ρ:10}

constrained_newton_method(combined_objective,Gamma,Gamma0)

以及相应的输出:

我们可以验证这个解确实符合约束条件。具体来说,x ≤ 0,y ≥ 3,& x² - y = 2。令人满意,不是吗?

结论

哎呀。深呼吸一下——你值得拥有。希望到目前为止,你对将约束条件融入优化问题的技术有了更好的理解。我们仍然只是触及了数学优化中各种工具和技术的表面。

敬请关注系列的第三部分,即最后一部分,我们将应用到目前为止学到的优化材料,并结合计量经济学和经济理论来解决利润最大化问题。我的目标是第三部分将总结我们所涵盖的内容,并展示一个实际应用案例。像往常一样,我希望你能像我写作时一样享受阅读这篇文章!

资源

[1] en.wikipedia.org/wiki/Constrained_optimization

[2] Snyman, J. A., & Wilke, D. N. (2019). 实用数学优化:基本优化理论与基于梯度的算法(第 2 版)。Springer。

[3] en.wikipedia.org/wiki/Lagrange_multiplier

[4] en.wikipedia.org/wiki/Interior-point_method

[5] en.wikipedia.org/wiki/Barrier_function

通过这个 GitHub 仓库访问所有代码: github.com/jakepenzak/Blog-Posts

感谢你阅读我的文章!我在 Medium 上的文章旨在探索利用 计量经济学 统计/机器学习 技术的现实世界和理论应用。此外,我还希望通过理论和模拟提供各种方法论的理论基础。最重要的是,我写作是为了学习!我希望能让复杂的话题对所有人更易于理解。如果你喜欢这篇文章,请考虑 在 Medium 上关注我

posted @ 2024-10-12 19:55  绝不原创的飞龙  阅读(409)  评论(0)    收藏  举报