原文地址
has_one_axis函数
-
has_one_axis(),原文是将其定义在plot内部,但是为了分析放便,将其抽离出来
-
def has_one_axis(X):
"""
isinstance(X, list) and not hasattr(X[0], "__len__")
这是一个条件表达式,用于检查变量X是否为列表类型并且其第一个元素是否为非序列类型(即单个值类型)。
如果X是单轴的,那么说明X对应于多条曲线的(函数)或者多条曲线将共用相同的自变量输入
具体来说,这个条件表达式包含两个部分,用and运算符连接起来:
isinstance(X, list):检查变量X是否为列表类型。如果是,则返回True;否则返回False。
not hasattr(X[0], "len"):检查变量X的第一个元素是否为序列类型(即是否具有__len__属性)。如果不是,则返回True;否则返回False。由于not运算符的作用,如果第一个元素不是序列类型,则整个表达式的结果为True;否则结果为False。
因此,这个条件表达式的作用是检查变量X是否为列表类型,并且其中的第一个元素是否为非序列类型。如果满足这个条件,则说明X是一个包含单个值的列表;否则说明X是一个多维列表或数组。
假设有以下两个变量:
X = [1, 2, 3, 4, 5]
Y = [[1, 2], [3, 4], [5, 6]]
对于变量X,它是一个包含单个值的列表,因此符合条件表达式中的两个条件,即X是列表类型,并且其第一个元素(即1)是非序列类型。因此,条件表达式的结果为True。
对于变量Y,它是一个二维列表,因此不符合条件表达式中的第二个条件,即其中的第一个元素(即[1, 2])是序列类型。因此,条件表达式的结果为False。
因此,可以使用这个条件表达式来检查变量X是否为包含单个值的列表,以便在绘图函数中统一处理数据类型。
不过,如果判断出X是列表类型,也可以将其封装为ndarray,再判断ndim属性np.array(X).ndim==1
"""
ndarray_1dim=hasattr(X, "ndim") and X.ndim == 1
list_1dim= isinstance(X, list) and X==[] or not hasattr(X[0], "__len__")
return ndarray_1dim or list_1dim
d2l.plot函数
def plot(X, Y=None, xlabel=None, ylabel=None, legend=[], xlim=None,
ylim=None, xscale='linear', yscale='linear',
fmts=('-', 'm--', 'g-.', 'r:'), figsize=(3.5, 2.5), axes=None):
"""绘制数据点"""
"""这是一个用Python编写的绘制数据点的函数。该函数可以接受多组数据作为输入,将它们绘制成图形,并可以设置各种参数,如坐标轴标签、图例、坐标轴范围等等。
具体来说,该函数的参数包括:
X:一个列表或数组,作为横坐标
Y:一个列表或数组,以X为横坐标(自变量数组)根据若干函数,计算出对应若干组函数值向量。如果不提供Y,则默认使用X作为y轴坐标值。
xlabel:x轴的标签。
ylabel:y轴的标签。
legend:一个列表,包含图例标签的字符串。默认值为[]。
xlim:一个元组,包含x轴范围的最小值和最大值。
ylim:一个元组,包含y轴范围的最小值和最大值。
xscale:x轴的缩放类型。默认值为'linear'。
yscale:y轴的缩放类型。默认值为'linear'。
fmts:一个元组,包含线条的样式。默认值为('-', 'm--', 'g-.', 'r:')。这将允许4条曲线有互不相同的样式
figsize:一个元组,包含图形的宽度和高度。默认值为(3.5, 2.5)。
axes:一个matplotlib.axes.Axes对象,表示绘图的坐标系。如果没有提供,则默认使用当前坐标系。
该函数的实现过程主要包括以下几个步骤:
设置图形的大小。
提取或创建要使用的坐标系对象。
检查输入数据的格式,并将它们统一为列表的形式。
清空坐标系,并绘制每组数据点。
设置坐标轴的标签、范围、缩放类型和图例。
这个函数可以方便地绘制多组数据点,并且可以通过修改参数来调整图形的样式和布局。"""
set_figsize(figsize)
axes = axes if axes else d2l.plt.gca()
if has_one_axis(X):
X = [X]
if Y is None:
X, Y = [[]] * len(X), X
elif has_one_axis(Y):
Y = [Y]
if len(X) != len(Y):
X = X * len(Y)
fmts=fmts*len(Y)
axes.cla()
for x, y, fmt in zip(X, Y, fmts):
if len(x):
axes.plot(x, y, fmt)
else:
axes.plot(y, fmt)
set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
demos
-
import numpy as np
x = np.linspace(0, 1, 10)
y1 = np.random.randn(10)
y2 = np.random.randn(10)
y3 = np.random.randn(10)
y4= np.random.randn(10)
y5= np.random.randn(10)
legend=["y"+str(i) for i in range(5)]
-
plot(x, [y1, y2, y3,y4,y5], xlabel='x', ylabel='y',legend=legend)
-
plot(X=x, Y=Y, xlabel='x', ylabel='f(x)', legend=['f(x)', 'Tangent line (x=1)'])
-
plot(X=X, Y=None, xlabel='x', ylabel='f(x)', legend=['f(x)', 'Tangent line (x=1)'])
-
plot(X=X, Y=Y, xlabel='x', ylabel='f(x)', legend=['f(x)', 'Tangent line (x=1)'])