Python:变长和定长序列拆分

1 导引

Python中的任何序列(可迭代的对象)都可以通过赋值操作进行拆分,包括但不限于元组、列表、字符串、文件、迭代器、生成器等。

2 元组拆分

元组拆分是最为常见的一种拆分,示例如下:

p = (4, 5)
x, y = p 
print(x, y) # 4 5

如果写成

x, y, z = p

那么就会抛出ValueError异常:"not enough values to unpack (expected 3, got 2)"
如果写成

p = (4, 5, 6)
x, y = p

那么就会抛出ValueError异常:"too many values to unpack (expected 2)"

元组拆包无处不在,比如我们知道Python的zip()函数相当于返回一个元组列表的迭代器,我们可以对该元组列表进行迭代拆分:

list_a = [1, 3, 5]
list_b = [2, 4, 6]
print(list(zip(list_a, list_b)))
#[(1, 2), (3, 4), (5, 6)]
for a, b in zip(list_a, list_b):
    print("%d - %d" % (a, b))
# 1 - 2
# 3 - 4
# 5 - 6

上面的迭代语句其实隐式等价于for (a, b) in zip(list_a, list_b)
接下来容易出错的地方来了,很多时候我们会将附加索引下标的上述迭代错误地这样写:

for idx, a, b in enumerate(zip(list_a, list_b)):
    print("idx%d: %d - %d" % (idx, a, b))

此时就会抛出ValueError异常:"not enough values to unpack (expected 3, got 2)"。我们经过上面的讨论知道,这是不正确地元组拆包所致。

原来,迭代enumerate(zip(list_a, list_b))实际等价于迭代[(0, (1, 2)), (1, (3, 4)), (2, (5, 6))]:

print(list(enumerate(zip(list_a, list_b))))
# [(0, (1, 2)), (1, (3, 4)), (2, (5, 6))]

对其迭代需要进行两次复合的元组拆包,即:

for idx, (a, b) in enumerate(zip(list_a, list_b)):
    print("idx%d: %d - %d" % (idx, a, b))
# idx0: 1 - 2
# idx1: 3 - 4
# idx2: 5 - 6

还是同样地,两次拆包有一次隐式省略,上述迭代语句隐式等价于for (idx, (a, b)) in enumerate(zip(list_a, list_b)):

这里值得一提的是,上面说了用zip()函数+list()函数可以让我们获得一个元组列表,该操作的在机器学习项目的场景下非常实用,因为我们已知一堆点的\(x\)坐标列表和\(y\)坐标列表,我们可以通过zip()函数+list()函数的形式获得\((x,y)\)坐标列表。然而,如果我们已知\((x,y)\)坐标列表,如何快速恢复出\(x\)坐标列表和\(y\)坐标列表呢?我们可以这样写:

points = [(1, 2), (3, 4), (5, 6)]
x, y = zip(*points)
print(x) # (1, 3, 5)
print(y) # (2, 4, 6)

这里有个*运算符读者可能感到陌生,这表示将points列表中的所有元素以位置参数的形式传入zip()函数(读者可以参见我的博客《Python:位置参数、关键字参数和接受任意数量的参数》),而zip(*points)实际上等价于

zip((1, 2), (3, 4), (5, 6))

而我们前面说过,迭代上述zip()函数返回的迭代器实质上等于迭代元组列表[(1, 3, 5), (2, 4, 6)]。因为该元组列表只有两个元素,故我们可以直接对该列表进行拆包,于是得到了拆包结果(1, 3, 5)(2, 4, 6)。从这个视角看,zip(*)操作可以理解将二维数据沿纵向拆分成列向量

这种写法有个巨大的应用场景就是处理机器学习的训练数据。比如,假设我们在做一个机器学习项目,有下列训练数据X和训练数据Y

import numpy as np
X = np.random.rand(5, 3)
Y = np.random.randint(0, 2, size=(5, 1))
print(X)
# [[0.20447277 0.85066912 0.3331559 ]
#  [0.78313617 0.78667579 0.17555529]
#  [0.67388656 0.75179676 0.58292836]
#  [0.12512522 0.5669724  0.45970325]
#  [0.61955282 0.64029496 0.93385069]]
print(Y)
# [[0]
#  [1]
#  [1]
#  [1]
#  [0]]

我们接下来想不借助scikit-learn库中的sklearn.utils.shuffle函数,仅仅使用numpy包和Python内置函数来优雅地完成对数据集的shuffle操作,那么该怎么做呢?首先,直接写

np.random.shuffle(X)
np.random.shuffle(Y)

是不行的,因为这样会丢失样本数据xy的一一对应关系。事实上,我们可以先试用zip函数将原始的XY数据转换成(x, y)二元组组成的坐标列表:

x_y_pair = list(zip(X, Y)) 
print(x_y_pair)
# [(array([0.36742827, 0.02156507, 0.07500242]), array([1])), 
# (array([0.6562936 , 0.7262091 , 0.50394983]), array([0])), 
# (array([0.02043896, 0.08081809, 0.5199801 ]), array([0])), 
# (array([0.87178023, 0.06728234, 0.54260044]), array([1])), 
# (array([0.81271828, 0.50946797, 0.02489041]), array([1]))]

然后在此基础上进行shuffle:

np.random.shuffle(x_y_pair)
print(x_y_pair)
# [(array([0.6562936 , 0.7262091 , 0.50394983]), array([0])), 
# (array([0.81271828, 0.50946797, 0.02489041]), array([1])), 
# (array([0.02043896, 0.08081809, 0.5199801 ]), array([0])), 
# (array([0.87178023, 0.06728234, 0.54260044]), array([1])), 
# (array([0.36742827, 0.02156507, 0.07500242]), array([1]))]

然后,我们再借用zip(*)np.stack()组合操作得到拼接完成的数据集:

X = np.stack(list(zip(*x_y_pair))[0])
Y = np.stack(list(zip(*x_y_pair))[1])
print(X)
# [[0.6562936  0.7262091  0.50394983]
#  [0.81271828 0.50946797 0.02489041]
#  [0.02043896 0.08081809 0.5199801 ]
#  [0.87178023 0.06728234 0.54260044]
#  [0.36742827 0.02156507 0.07500242]]
print(Y)
# [[0]
#  [1]
#  [0]
#  [1]
#  [1]]

正如我们前面所说的,这里list(zip(*x_y_pair))沿纵向将x_y_pair拆分成xy这两部分,得到了一个由x向量组成的元组(array([0.6562936,...]), array([0.81271828,...]), ..., array([0.36742827, ...))和一个由y构成的元组(array([0]), array([1]), ..., array([1])),然后我们再将由x向量构成的元组和y构成的元组进行stack操作,就还原了我们的XY数据

PS:这里np.stack()是对数据的堆叠(会增加一个额外维度),比如对一维的数据(shape为(n, ))就是堆叠得到一个新的二维数据;而np.concatenate()则是需要指定一个维度进行拼接(不会增加额外维度),对一维数据就是拼接得到一个新的一维数据:

import numpy as np

a = np.array([1, 2, 3])
b = np.array([3, 4, 5])

res1 = np.stack([a, b])
res2 = np.concatenate([a, b])
print(res1)
# [[1 2 3]
#  [3 4 5]]
print(res2)
# [1 2 3 3 4 5]

下面是在多维数据(shape为(n, 1)(n, m)下的情况:

c = np.array([[1, 2, 3]])
d = np.array([[4, 5, 6]])

res3 = np.stack([c, d])
res4 = np.concatenate([c, d])
print(res3)
# [[[1 2 3]]

#  [[4 5 6]]]
print(res4)
# [[1 2 3]
#  [4 5 6]]

e = np.array([[1, 2], [3, 4]])
f = np.array([[5, 6], [7, 8]])

res5 = np.stack([e, f])
res6 = np.concatenate([e, f])
print(res5)
# [[[1 2]
#   [3 4]]

#  [[5 6]
#   [7 8]]]
print(res6)
# [[1 2]
#  [3 4]
#  [5 6]
#  [7 8]]

可见和一维的情况一样,np.stack会增加额外的维度,np.concatenate()则不会。此二者都是默认axis=0,即沿着维度0的方向堆叠/拼接。

好了,现在言归正传,回到我们关于Python元组拆分的讨论。其实,Python中所谓函数能返回多个值,其实是返回的元组,如下面这种所示:

def func():
    return 1, 2, 3

实际上等同于返回(1, 2, 3)元组。我们可以选择直接接收该元组对象:

my_tuple = func()
print(my_tuple) # (1, 2, 3)

注意,上面这个代码中my_tuple为一个引用,引用在函数体内部创建的元组对象(如对此有疑问,可参见我的博客《Python对象模型与序列迭代陷阱 》)。

当然,也可以将元组拆包接收:

a, b, c = func()
print(a, b, c) # 1 2 3

但是注意,如果要拆包必须要保证拆包正确,像下面这种写法:

a, b = func()

无疑就会抛出ValueError异常:"too many values to unpack (expected 2)"了。

3 字符串拆分

字符串的拆分示意如下:

s = 'Hello'
a, b, c, d, e = s
print(a) # H

4 拆分时丢弃值

如果在拆分时想丢弃某些特定的值,可以用一个用不到的变量名来作为丢弃值的名称(常选_做为变量名),如下所示:

s = 'Hello'
a, b, _, d, _ = s
print(a) # H

5 嵌套序列拆分

Python也提供简洁的对嵌套序列进行拆分的语法。如下所示我们对一个比较复杂的异质列表进行拆分:

data = ['zhy', 50, 123.0, (2000, 12, 21)]
name, shares, price, (year, month, day) = data
print(year) # 2000

如果你想完整地得到(2000, 12, 21)这个表示时间戳的元组,那么你就得这样写:

data = ['zhy', 50, 123.0, (2000, 12, 21)]
name, shares, price, date = data
print(date) # (2000, 12, 21)

6 从任意长度的可迭代对象中拆分

之前我们说过,如果我们想从可迭代对象中分解出\(N\)个元素,但如果这个可迭代对象长度超过\(N\),则会抛出异常"too many values to unpack"。针对这个问题的解决方案是采用*表达式。
比如我们给定学生的分数,想去掉一个最高分和一个最低分,然后对剩下的学生求平均分,我们可以这样写:

def avg(data: list):
    return sum(data)/len(data)
# 去掉最高分,最低分然后做均分统计
def drop_first_last(grades):
    first, *middle, last = grades
    return avg(middle)
print(drop_first_last([1,2,3,4])) # 2.5

还有一种情况是有一些用户记录,记录由姓名+电子邮件+任意数量的电话号码组成,则我们可以这样分解用户记录:

record = ['zhy', 'zhy1056692290@qq.com', '773-556234', '774-223333']
name, email, *phone_numbers = record
print(phone_numbers) # ['773-556234', '774-223333']

事实上,如果电话号码为空也是合法的,此时phone_numbers为空列表。

record = ['zhy', 'zhy1056692290@qq.com']
name, email, *phone_numbers = record
print(phone_numbers) # []

还有一种使用情况则更为巧妙。如果我们需要遍历变长元组组成的列表,这些元组长度不一。那么此时*表达式可大大简化我们的代码。

records = [('foo', 1, 2), ('bar', 'hello'), ('foo', 3, 4)]
for tag, *args in records:
    if tag == 'bar':
        print(args)
# ['hello']

在对一些复杂的字符串进行拆分时,*表达式也显得特别有用。

line = "nobody:*:-2:-2:-2:Unprivileged User:/var/empty:/usr/bin/false"
uname, *fields, home_dir, sh = line.split(':')
print(home_dir) # /var/empty

*表达式也可以和我们前面说的嵌套拆分和变量丢弃一起结合使用。

record = ['ACME', 50, 123.45, (128, 18, 2012)]
name, *_, (*_, year) = record
print(year) # 2012

最后再介绍*表达式用于递归函数的一种黑魔法,比如与递归求和结合可以这样写:

items = [1, 10, 7, 4, 5, 9]
def sum(items):
    head, *tail = items
    return head + sum(tail) if tail else head
print(sum(items)) # 36

不过,Python由于自身递归栈的限制,并不擅长递归。我们最后一个递归的例子可以做为一种学术上的尝试,但不建议在实践中使用它。

参考

posted @ 2021-10-09 22:12  orion-orion  阅读(388)  评论(0编辑  收藏  举报