JAX-中文文档-十四-
JAX 中文文档(十四)
jax.scipy 模块
jax.scipy.cluster
| vq(obs, code_book[, check_finite]) | 将观测值分配给代码簿中的代码。 | ## jax.scipy.fft
dct(x[, type, n, axis, norm]) |
计算输入的离散余弦变换 |
|---|---|
dctn(x[, type, s, axes, norm]) |
计算输入的多维离散余弦变换 |
idct(x[, type, n, axis, norm]) |
计算输入的离散余弦变换的逆变换 |
| idctn(x[, type, s, axes, norm]) | 计算输入的多维离散余弦变换的逆变换 | ## jax.scipy.integrate
| trapezoid(y[, x, dx, axis]) | 使用复合梯形法则沿指定轴积分。 | ## jax.scipy.interpolate
| RegularGridInterpolator(points, values[, ...]) | 对正规矩形网格上的点进行插值。 | ## jax.scipy.linalg
block_diag(*arrs) |
从输入数组创建块对角矩阵。 |
|---|---|
cho_factor(a[, lower, overwrite_a, check_finite]) |
基于 Cholesky 的线性求解因式分解 |
cho_solve(c_and_lower, b[, overwrite_b, ...]) |
使用 Cholesky 分解解线性系统 |
cholesky(a[, lower, overwrite_a, check_finite]) |
计算矩阵的 Cholesky 分解。 |
det(a[, overwrite_a, check_finite]) |
计算矩阵的行列式 |
eigh() |
计算 Hermitian 矩阵的特征值和特征向量 |
eigh_tridiagonal(d, e, *[, eigvals_only, ...]) |
解对称实三对角矩阵的特征值问题 |
expm(A, *[, upper_triangular, max_squarings]) |
计算矩阵指数 |
expm_frechet() |
计算矩阵指数的 Frechet 导数 |
funm(A, func[, disp]) |
评估矩阵值函数 |
hessenberg() |
计算矩阵的 Hessenberg 形式 |
hilbert(n) |
创建阶数为 n 的 Hilbert 矩阵。 |
inv(a[, overwrite_a, check_finite]) |
返回方阵的逆矩阵 |
lu() |
计算 LU 分解 |
lu_factor(a[, overwrite_a, check_finite]) |
基于 LU 的线性求解因式分解 |
lu_solve(lu_and_piv, b[, trans, ...]) |
使用 LU 分解解线性系统 |
polar(a[, side, method, eps, max_iterations]) |
计算极分解 |
qr() |
计算数组的 QR 分解 |
rsf2csf(T, Z[, check_finite]) |
将实数舒尔形式转换为复数舒尔形式。 |
schur(a[, output]) |
计算舒尔分解 |
solve(a, b[, lower, overwrite_a, ...]) |
解线性方程组 |
solve_triangular(a, b[, trans, lower, ...]) |
解上(或下)三角线性方程组 |
sqrtm(A[, blocksize]) |
计算矩阵的平方根 |
svd() |
计算奇异值分解 |
| toeplitz(c[, r]) | 构造 Toeplitz 矩阵 | ## jax.scipy.ndimage
| map_coordinates(input, coordinates, order[, ...]) | 使用插值将输入数组映射到新坐标。 | ## jax.scipy.optimize
minimize(fun, x0[, args, tol, options]) |
最小化一个或多个变量的标量函数。 |
|---|
| OptimizeResults(x, success, status, fun, ...) | 优化结果对象。 | ## jax.scipy.signal
fftconvolve(in1, in2[, mode, axes]) |
使用快速傅里叶变换(FFT)卷积两个 N 维数组。 |
|---|---|
convolve(in1, in2[, mode, method, precision]) |
两个 N 维数组的卷积。 |
convolve2d(in1, in2[, mode, boundary, ...]) |
两个二维数组的卷积。 |
correlate(in1, in2[, mode, method, precision]) |
两个 N 维数组的互相关。 |
correlate2d(in1, in2[, mode, boundary, ...]) |
两个二维数组的互相关。 |
csd(x, y[, fs, window, nperseg, noverlap, ...]) |
使用 Welch 方法估计交叉功率谱密度(CSD)。 |
detrend(data[, axis, type, bp, overwrite_data]) |
从数据中移除线性或分段线性趋势。 |
istft(Zxx[, fs, window, nperseg, noverlap, ...]) |
执行逆短时傅里叶变换(ISTFT)。 |
stft(x[, fs, window, nperseg, noverlap, ...]) |
计算短时傅里叶变换(STFT)。 |
| welch(x[, fs, window, nperseg, noverlap, ...]) | 使用 Welch 方法估计功率谱密度(PSD)。 | ## jax.scipy.spatial.transform
Rotation(quat) |
三维旋转。 |
|---|
| Slerp(times, timedelta, rotations, rotvecs) | 球面线性插值旋转。 | ## jax.scipy.sparse.linalg
bicgstab(A, b[, x0, tol, atol, maxiter, M]) |
使用双共轭梯度稳定迭代解决 Ax = b。 |
|---|---|
cg(A, b[, x0, tol, atol, maxiter, M]) |
使用共轭梯度法解决 Ax = b。 |
| gmres(A, b[, x0, tol, atol, restart, ...]) | GMRES 解决线性系统 A x = b,给定 A 和 b。 | ## jax.scipy.special
bernoulli(n) |
生成前 N 个伯努利数。 |
|---|---|
beta() |
贝塔函数 |
betainc(a, b, x) |
正则化的不完全贝塔函数。 |
betaln(a, b) |
贝塔函数绝对值的自然对数 |
digamma(x) |
Digamma 函数 |
entr(x) |
熵函数 |
erf(x) |
误差函数 |
erfc(x) |
误差函数的补函数 |
erfinv(x) |
误差函数的反函数 |
exp1(x) |
指数积分函数。 |
expi |
指数积分函数。 |
expit(x) |
逻辑 sigmoid(expit)函数 |
expn |
广义指数积分函数。 |
factorial(n[, exact]) |
阶乘函数 |
gamma(x) |
伽马函数。 |
gammainc(a, x) |
正则化的下不完全伽马函数。 |
gammaincc(a, x) |
正则化的上不完全伽马函数。 |
gammaln(x) |
伽马函数绝对值的自然对数。 |
gammasgn(x) |
伽马函数的符号。 |
hyp1f1 |
1F1 超几何函数。 |
i0(x) |
修改贝塞尔函数零阶。 |
i0e(x) |
指数缩放的修改贝塞尔函数零阶。 |
i1(x) |
修改贝塞尔函数一阶。 |
i1e(x) |
指数缩放的修改贝塞尔函数一阶。 |
log_ndtr |
对数正态分布函数。 |
logit |
对数几率函数。 |
logsumexp() |
对数-总和-指数归约。 |
lpmn(m, n, z) |
第一类相关勒让德函数(ALFs)。 |
lpmn_values(m, n, z, is_normalized) |
第一类相关勒让德函数(ALFs)。 |
multigammaln(a, d) |
多变量伽马函数的自然对数。 |
ndtr(x) |
正态分布函数。 |
ndtri(p) |
正态分布函数的反函数。 |
poch |
Pochhammer 符号。 |
polygamma(n, x) |
多次伽马函数。 |
spence(x) |
斯宾斯函数,也称实数域下的二元对数函数。 |
sph_harm(m, n, theta, phi[, n_max]) |
计算球谐函数。 |
xlog1py |
计算 x*log(1 + y),当 x=0 时返回 0。 |
xlogy |
计算 x*log(y),当 x=0 时返回 0。 |
zeta |
赫维茨 ζ 函数。 |
kl_div(p, q) |
库尔巴克-莱布勒散度。 |
| rel_entr(p, q) | 相对熵函数。 | ## jax.scipy.stats
mode(a[, axis, nan_policy, keepdims]) |
计算数组沿轴的众数(最常见的值)。 |
|---|---|
rankdata(a[, method, axis, nan_policy]) |
计算数组沿轴的排名。 |
sem(a[, axis, ddof, nan_policy, keepdims]) |
计算均值的标准误差。 |
jax.scipy.stats.bernoulli
logpmf(k, p[, loc]) |
伯努利对数概率质量函数。 |
|---|---|
pmf(k, p[, loc]) |
伯努利概率质量函数。 |
cdf(k, p) |
伯努利累积分布函数。 |
| ppf(q, p) | 伯努利百分位点函数。 | ### jax.scipy.stats.beta
logpdf(x, a, b[, loc, scale]) |
Beta 对数概率分布函数。 |
|---|---|
pdf(x, a, b[, loc, scale]) |
Beta 概率分布函数。 |
cdf(x, a, b[, loc, scale]) |
Beta 累积分布函数。 |
logcdf(x, a, b[, loc, scale]) |
Beta 对数累积分布函数。 |
sf(x, a, b[, loc, scale]) |
Beta 分布生存函数。 |
| logsf(x, a, b[, loc, scale]) | Beta 分布对数生存函数。 | ### jax.scipy.stats.betabinom
logpmf(k, n, a, b[, loc]) |
Beta-二项式对数概率质量函数。 |
|---|
| pmf(k, n, a, b[, loc]) | Beta-二项式概率质量函数。 | ### jax.scipy.stats.binom
logpmf(k, n, p[, loc]) |
二项式对数概率质量函数。 |
|---|
| pmf(k, n, p[, loc]) | 二项式概率质量函数。 | ### jax.scipy.stats.cauchy
logpdf(x[, loc, scale]) |
柯西对数概率分布函数。 |
|---|---|
pdf(x[, loc, scale]) |
柯西概率分布函数。 |
cdf(x[, loc, scale]) |
柯西累积分布函数。 |
logcdf(x[, loc, scale]) |
柯西对数累积分布函数。 |
sf(x[, loc, scale]) |
柯西分布对数生存函数。 |
logsf(x[, loc, scale]) |
柯西对数生存函数。 |
isf(q[, loc, scale]) |
柯西分布逆生存函数。 |
| ppf(q[, loc, scale]) | 柯西分布分位点函数。 | ### jax.scipy.stats.chi2
logpdf(x, df[, loc, scale]) |
卡方分布对数概率分布函数。 |
|---|---|
pdf(x, df[, loc, scale]) |
卡方概率分布函数。 |
cdf(x, df[, loc, scale]) |
卡方累积分布函数。 |
logcdf(x, df[, loc, scale]) |
卡方对数累积分布函数。 |
sf(x, df[, loc, scale]) |
卡方生存函数。 |
| logsf(x, df[, loc, scale]) | 卡方对数生存函数。 | ### jax.scipy.stats.dirichlet
logpdf(x, alpha) |
狄利克雷对数概率分布函数。 |
|---|
| pdf(x, alpha) | 狄利克雷概率分布函数。 | ### jax.scipy.stats.expon
logpdf(x[, loc, scale]) |
指数对数概率分布函数。 |
|---|
| pdf(x[, loc, scale]) | 指数概率分布函数。 | ### jax.scipy.stats.gamma
logpdf(x, a[, loc, scale]) |
伽玛对数概率分布函数。 |
|---|---|
pdf(x, a[, loc, scale]) |
伽玛概率分布函数。 |
cdf(x, a[, loc, scale]) |
伽玛累积分布函数。 |
logcdf(x, a[, loc, scale]) |
伽玛对数累积分布函数。 |
sf(x, a[, loc, scale]) |
伽玛生存函数。 |
| logsf(x, a[, loc, scale]) | 伽玛对数生存函数。 | ### jax.scipy.stats.gennorm
cdf(x, beta) |
广义正态累积分布函数。 |
|---|---|
logpdf(x, beta) |
广义正态对数概率分布函数。 |
| pdf(x, beta) | 广义正态概率分布函数。 | ### jax.scipy.stats.geom
logpmf(k, p[, loc]) |
几何对数概率质量函数。 |
|---|
| pmf(k, p[, loc]) | 几何概率质量函数。 | ### jax.scipy.stats.laplace
cdf(x[, loc, scale]) |
拉普拉斯累积分布函数。 |
|---|---|
logpdf(x[, loc, scale]) |
拉普拉斯对数概率分布函数。 |
| pdf(x[, loc, scale]) | 拉普拉斯概率分布函数。 | ### jax.scipy.stats.logistic
cdf(x[, loc, scale]) |
Logistic 累积分布函数。 |
|---|---|
isf(x[, loc, scale]) |
Logistic 分布逆生存函数。 |
logpdf(x[, loc, scale]) |
Logistic 对数概率分布函数。 |
pdf(x[, loc, scale]) |
Logistic 概率分布函数。 |
ppf(x[, loc, scale]) |
Logistic 分位点函数。 |
| sf(x[, loc, scale]) | Logistic 分布生存函数。 | ### jax.scipy.stats.multinomial
logpmf(x, n, p) |
多项式对数概率质量函数。 |
|---|
| pmf(x, n, p) | 多项分布概率质量函数。 | ### jax.scipy.stats.multivariate_normal
logpdf(x, mean, cov[, allow_singular]) |
多元正态分布对数概率分布函数。 |
|---|
| pdf(x, mean, cov) | 多元正态分布概率分布函数。 | ### jax.scipy.stats.nbinom
logpmf(k, n, p[, loc]) |
负二项分布对数概率质量函数。 |
|---|
| pmf(k, n, p[, loc]) | 负二项分布概率质量函数。 | ### jax.scipy.stats.norm
logpdf(x[, loc, scale]) |
正态分布对数概率分布函数。 |
|---|---|
pdf(x[, loc, scale]) |
正态分布概率分布函数。 |
cdf(x[, loc, scale]) |
正态分布累积分布函数。 |
logcdf(x[, loc, scale]) |
正态分布对数累积分布函数。 |
ppf(q[, loc, scale]) |
正态分布百分点函数。 |
sf(x[, loc, scale]) |
正态分布生存函数。 |
logsf(x[, loc, scale]) |
正态分布对数生存函数。 |
| isf(q[, loc, scale]) | 正态分布逆生存函数。 | ### jax.scipy.stats.pareto
logpdf(x, b[, loc, scale]) |
帕累托对数概率分布函数。 |
|---|
| pdf(x, b[, loc, scale]) | 帕累托分布概率分布函数。 | ### jax.scipy.stats.poisson
logpmf(k, mu[, loc]) |
泊松分布对数概率质量函数。 |
|---|---|
pmf(k, mu[, loc]) |
泊松分布概率质量函数。 |
| cdf(k, mu[, loc]) | 泊松分布累积分布函数。 | ### jax.scipy.stats.t
logpdf(x, df[, loc, scale]) |
学生 t 分布对数概率分布函数。 |
|---|
| pdf(x, df[, loc, scale]) | 学生 t 分布概率分布函数。 | ### jax.scipy.stats.truncnorm
cdf(x, a, b[, loc, scale]) |
截断正态分布累积分布函数。 |
|---|---|
logcdf(x, a, b[, loc, scale]) |
截断正态分布对数累积分布函数。 |
logpdf(x, a, b[, loc, scale]) |
截断正态分布对数概率分布函数。 |
logsf(x, a, b[, loc, scale]) |
截断正态分布对数生存函数。 |
pdf(x, a, b[, loc, scale]) |
截断正态分布概率分布函数。 |
| sf(x, a, b[, loc, scale]) | 截断正态分布对数生存函数。 | ### jax.scipy.stats.uniform
logpdf(x[, loc, scale]) |
均匀分布对数概率分布函数。 |
|---|---|
pdf(x[, loc, scale]) |
均匀分布概率分布函数。 |
cdf(x[, loc, scale]) |
均匀分布累积分布函数。 |
ppf(q[, loc, scale]) |
均匀分布百分点函数。 |
jax.scipy.stats.gaussian_kde
gaussian_kde(dataset[, bw_method, weights]) |
高斯核密度估计器 |
|---|---|
gaussian_kde.evaluate(points) |
对给定点评估高斯核密度估计器。 |
gaussian_kde.integrate_gaussian(mean, cov) |
加权高斯积分分布。 |
gaussian_kde.integrate_box_1d(low, high) |
在给定限制下积分分布。 |
gaussian_kde.integrate_kde(other) |
集成两个高斯核密度估计分布的乘积。 |
gaussian_kde.resample(key[, shape]) |
从估计的概率密度函数中随机采样数据集 |
gaussian_kde.pdf(x) |
概率密度函数 |
gaussian_kde.logpdf(x) |
对数概率密度函数 |
jax.scipy.stats.vonmises
logpdf(x, kappa) |
von Mises 对数概率分布函数。 |
|---|
| pdf(x, kappa) | von Mises 概率分布函数。 | ### jax.scipy.stats.wrapcauchy
logpdf(x, c) |
Wrapped Cauchy 对数概率分布函数。 |
|---|---|
pdf(x, c) |
Wrapped Cauchy 概率分布函数。 |
jax.scipy.stats.bernoulli.logpmf
原文:
jax.readthedocs.io/en/latest/_autosummary/jax.scipy.stats.bernoulli.logpmf.html
jax.scipy.stats.bernoulli.logpmf(k, p, loc=0)
伯努利对数概率质量函数。
scipy.stats.bernoulli 的 JAX 实现 logpmf
伯努利概率质量函数定义如下
[\begin{split}f(k) = \begin{cases} 1 - p, & k = 0 \ p, & k = 1 \ 0, & \mathrm{otherwise} \end{cases}\end{split}]
参数:
-
k (Array | ndarray | bool | number | bool | int | float | complex) – arraylike,要评估 PMF 的值
-
p (Array | ndarray | bool | number | bool | int | float | complex) – arraylike,分布形状参数
-
loc (Array | ndarray | bool | number | bool | int | float | complex) – arraylike,分布偏移量
返回值:
logpmf 值的数组
返回类型:
Array
另请参阅
-
jax.scipy.stats.bernoulli.cdf() -
jax.scipy.stats.bernoulli.pmf() -
jax.scipy.stats.bernoulli.ppf()
jax.scipy.stats.bernoulli.pmf
原文:
jax.readthedocs.io/en/latest/_autosummary/jax.scipy.stats.bernoulli.pmf.html
jax.scipy.stats.bernoulli.pmf(k, p, loc=0)
伯努利概率质量函数。
scipy.stats.bernoulli pmf 的 JAX 实现
伯努利概率质量函数定义为
[\begin{split}f(k) = \begin{cases} 1 - p, & k = 0 \ p, & k = 1 \ 0, & \mathrm{otherwise} \end{cases}\end{split}]
参数:
-
k (数组 | ndarray | 布尔 | 数值 | 布尔 | 整数 | 浮点数 | 复数*) – 类似数组,要评估 PMF 的值
-
p (数组 | ndarray | 布尔 | 数值 | 布尔 | 整数 | 浮点数 | 复数*) – 类似数组,分布形状参数
-
loc (数组 | ndarray | 布尔 | 数值 | 布尔 | 整数 | 浮点数 | 复数*) – 类似数组,分布偏移
返回:
pmf 值数组
返回类型:
数组
参见
-
jax.scipy.stats.bernoulli.cdf() -
jax.scipy.stats.bernoulli.logpmf() -
jax.scipy.stats.bernoulli.ppf()
jax.scipy.stats.bernoulli.cdf
原文:
jax.readthedocs.io/en/latest/_autosummary/jax.scipy.stats.bernoulli.cdf.html
jax.scipy.stats.bernoulli.cdf(k, p)
伯努利累积分布函数。
scipy.stats.bernoulli 的 JAX 实现 cdf
伯努利累积分布函数被定义为:
[f_{cdf}(k, p) = \sum_{i=0}^k f_{pmf}(k, p)]
其中 (f_{pmf}(k, p)) 是伯努利概率质量函数 jax.scipy.stats.bernoulli.pmf()。
参数:
-
k (Array | ndarray | bool | number | bool | int | float | complex) – 数组,用于评估 CDF 的值
-
p (Array | ndarray | bool | number | bool | int | float | complex) – 数组,分布形状参数
-
loc – 数组,分布偏移
返回:
cdf 值的数组
返回类型:
Array
另请参见
-
jax.scipy.stats.bernoulli.logpmf() -
jax.scipy.stats.bernoulli.pmf() -
jax.scipy.stats.bernoulli.ppf()
jax.scipy.stats.bernoulli.ppf
原文:
jax.readthedocs.io/en/latest/_autosummary/jax.scipy.stats.bernoulli.ppf.html
jax.scipy.stats.bernoulli.ppf(q, p)
伯努利百分点函数。
JAX 实现的 scipy.stats.bernoulli ppf
百分点函数是累积分布函数的反函数,jax.scipy.stats.bernoulli.cdf()。
参数:
-
k – arraylike,评估 PPF 的值
-
p (Array | ndarray | bool | number | bool | int | float | complex) – arraylike,分布形状参数
-
loc – arraylike,分布偏移
-
q (Array | ndarray | bool | number | bool | int | float | complex)
返回:
ppf 值数组
返回类型:
Array
另见
-
jax.scipy.stats.bernoulli.cdf() -
jax.scipy.stats.bernoulli.logpmf() -
jax.scipy.stats.bernoulli.pmf()
jax.lax 模块
jax.lax 是支持诸如 jax.numpy 等库的基本操作的库。通常会定义转换规则,例如 JVP 和批处理规则,作为对 jax.lax 基元的转换。
许多基元都是等价于 XLA 操作的薄包装,详细描述请参阅XLA 操作语义文档。
在可能的情况下,优先使用诸如 jax.numpy 等库,而不是直接使用 jax.lax。jax.numpy API 遵循 NumPy,因此比 jax.lax API 更稳定,更不易更改。
Operators
abs(x) |
按元素绝对值:(|x|)。 |
|---|---|
acos(x) |
按元素求反余弦:(\mathrm{acos}(x))。 |
acosh(x) |
按元素求反双曲余弦:(\mathrm{acosh}(x))。 |
add(x, y) |
按元素加法:(x + y)。 |
after_all(*operands) |
合并一个或多个 XLA 令牌值。 |
approx_max_k(operand, k[, ...]) |
以近似方式返回 operand 的最大 k 值及其索引。 |
approx_min_k(operand, k[, ...]) |
以近似方式返回 operand 的最小 k 值及其索引。 |
argmax(operand, axis, index_dtype) |
计算沿着 axis 的最大元素的索引。 |
argmin(operand, axis, index_dtype) |
计算沿着 axis 的最小元素的索引。 |
asin(x) |
按元素求反正弦:(\mathrm{asin}(x))。 |
asinh(x) |
按元素求反双曲正弦:(\mathrm{asinh}(x))。 |
atan(x) |
按元素求反正切:(\mathrm{atan}(x))。 |
atan2(x, y) |
两个变量的按元素反正切:(\mathrm{atan}({x \over y}))。 |
atanh(x) |
按元素求反双曲正切:(\mathrm{atanh}(x))。 |
batch_matmul(lhs, rhs[, precision]) |
批量矩阵乘法。 |
bessel_i0e(x) |
指数缩放修正贝塞尔函数 (0) 阶:(\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)) |
bessel_i1e(x) |
指数缩放修正贝塞尔函数 (1) 阶:(\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x)) |
betainc(a, b, x) |
按元素的正则化不完全贝塔积分。 |
bitcast_convert_type(operand, new_dtype) |
按元素位转换。 |
bitwise_and(x, y) |
按位与运算:(x \wedge y)。 |
bitwise_not(x) |
按位取反:(\neg x)。 |
bitwise_or(x, y) |
按位或运算:(x \vee y)。 |
bitwise_xor(x, y) |
按位异或运算:(x \oplus y)。 |
population_count(x) |
按元素计算 popcount,即每个元素中设置的位数。 |
broadcast(operand, sizes) |
广播数组,添加新的前导维度。 |
broadcast_in_dim(operand, shape, ...) |
包装 XLA 的 BroadcastInDim 操作符。 |
broadcast_shapes() |
返回经过 NumPy 广播后的形状。 |
broadcast_to_rank(x, rank) |
添加 1 的前导维度,使 x 的等级为 rank。 |
broadcasted_iota(dtype, shape, dimension) |
iota的便捷封装器。 |
cbrt(x) |
元素级立方根:(\sqrt[3]{x})。 |
ceil(x) |
元素级向上取整:(\left\lceil x \right\rceil)。 |
clamp(min, x, max) |
元素级 clamp 函数。 |
clz(x) |
元素级计算前导零的个数。 |
collapse(operand, start_dimension[, ...]) |
将数组的维度折叠为单个维度。 |
complex(x, y) |
元素级构造复数:(x + jy)。 |
concatenate(operands, dimension) |
沿指定维度连接一系列数组。 |
conj(x) |
元素级复数的共轭函数:(\overline{x})。 |
conv(lhs, rhs, window_strides, padding[, ...]) |
conv_general_dilated的便捷封装器。 |
convert_element_type(operand, new_dtype) |
元素级类型转换。 |
conv_dimension_numbers(lhs_shape, rhs_shape, ...) |
将卷积维度编号转换为 ConvDimensionNumbers。 |
conv_general_dilated(lhs, rhs, ...[, ...]) |
带有可选扩展的通用 n 维卷积运算符。 |
conv_general_dilated_local(lhs, rhs, ...[, ...]) |
带有可选扩展的通用 n 维非共享卷积运算符。 |
conv_general_dilated_patches(lhs, ...[, ...]) |
提取符合 conv_general_dilated 接受域的补丁。 |
conv_transpose(lhs, rhs, strides, padding[, ...]) |
计算 N 维卷积的“转置”的便捷封装器。 |
conv_with_general_padding(lhs, rhs, ...[, ...]) |
conv_general_dilated的便捷封装器。 |
cos(x) |
元素级余弦函数:(\mathrm{cos}(x))。 |
cosh(x) |
元素级双曲余弦函数:(\mathrm{cosh}(x))。 |
cumlogsumexp(operand[, axis, reverse]) |
沿轴计算累积 logsumexp。 |
cummax(operand[, axis, reverse]) |
沿轴计算累积最大值。 |
cummin(operand[, axis, reverse]) |
沿轴计算累积最小值。 |
cumprod(operand[, axis, reverse]) |
沿轴计算累积乘积。 |
cumsum(operand[, axis, reverse]) |
沿轴计算累积和。 |
digamma(x) |
元素级 digamma 函数:(\psi(x))。 |
div(x, y) |
元素级除法:(x \over y)。 |
dot(lhs, rhs[, precision, ...]) |
向量/向量,矩阵/向量和矩阵/矩阵乘法。 |
dot_general(lhs, rhs, dimension_numbers[, ...]) |
通用的点积/收缩运算符。 |
dynamic_index_in_dim(operand, index[, axis, ...]) |
对 dynamic_slice 的便捷封装,用于执行整数索引。 |
dynamic_slice(operand, start_indices, ...) |
封装了 XLA 的 DynamicSlice 操作符。 |
dynamic_slice_in_dim(operand, start_index, ...) |
方便地封装了应用于单个维度的 lax.dynamic_slice()。 |
dynamic_update_index_in_dim(operand, update, ...) |
方便地封装了 dynamic_update_slice(),用于在单个 axis 中更新大小为 1 的切片。 |
dynamic_update_slice(operand, update, ...) |
封装了 XLA 的 DynamicUpdateSlice 操作符。 |
dynamic_update_slice_in_dim(operand, update, ...) |
方便地封装了 dynamic_update_slice(),用于在单个 axis 中更新一个切片。 |
eq(x, y) |
元素级相等:(x = y)。 |
erf(x) |
元素级误差函数:(\mathrm{erf}(x))。 |
erfc(x) |
元素级补充误差函数:(\mathrm{erfc}(x) = 1 - \mathrm{erf}(x))。 |
erf_inv(x) |
元素级反误差函数:(\mathrm{erf}^{-1}(x))。 |
exp(x) |
元素级指数函数:(e^x)。 |
expand_dims(array, dimensions) |
将任意数量的大小为 1 的维度插入到数组中。 |
expm1(x) |
元素级运算 (e^{x} - 1)。 |
fft(x, fft_type, fft_lengths) |
|
floor(x) |
元素级向下取整:(\left\lfloor x \right\rfloor)。 |
full(shape, fill_value[, dtype, sharding]) |
返回填充值为 fill_value 的形状数组。 |
full_like(x, fill_value[, dtype, shape, ...]) |
基于示例数组 x 创建类似于 np.full 的完整数组。 |
gather(operand, start_indices, ...[, ...]) |
Gather 操作符。 |
ge(x, y) |
元素级大于或等于:(x \geq y)。 |
gt(x, y) |
元素级大于:(x > y)。 |
igamma(a, x) |
元素级正则化不完全 gamma 函数。 |
igammac(a, x) |
元素级补充正则化不完全 gamma 函数。 |
imag(x) |
提取复数的虚部:(\mathrm{Im}(x))。 |
index_in_dim(operand, index[, axis, keepdims]) |
方便地封装了 lax.slice(),用于执行整数索引。 |
index_take(src, idxs, axes) |
|
integer_pow(x, y) |
元素级幂运算:(x^y),其中 (y) 是固定整数。 |
iota(dtype, size) |
封装了 XLA 的 Iota 操作符。 |
is_finite(x) |
元素级 (\mathrm{isfinite})。 |
le(x, y) |
元素级小于或等于:(x \leq y)。 |
lgamma(x) |
元素级对数 gamma 函数:(\mathrm{log}(\Gamma(x)))。 |
log(x) |
元素级自然对数:(\mathrm{log}(x))。 |
log1p(x) |
元素级 (\mathrm{log}(1 + x))。 |
logistic(x) |
元素级 logistic(sigmoid)函数:(\frac{1}{1 + e^{-x}})。 |
lt(x, y) |
元素级小于:(x < y)。 |
max(x, y) |
元素级最大值:(\mathrm{max}(x, y)) |
min(x, y) |
元素级最小值:(\mathrm{min}(x, y)) |
mul(x, y) |
元素级乘法:(x \times y)。 |
ne(x, y) |
按位不等于:(x \neq y)。 |
neg(x) |
按位取负:(-x)。 |
nextafter(x1, x2) |
返回 x1 在 x2 方向上的下一个可表示的值。 |
pad(operand, padding_value, padding_config) |
对数组应用低、高和/或内部填充。 |
polygamma(m, x) |
按位多次 gamma 函数:(\psi^{(m)}(x))。 |
population_count(x) |
按位人口统计,统计每个元素中设置的位数。 |
pow(x, y) |
按位幂运算:(x^y)。 |
random_gamma_grad(a, x) |
Gamma 分布导数的按位计算。 |
real(x) |
按位提取实部:(\mathrm{Re}(x))。 |
reciprocal(x) |
按位倒数:(1 \over x)。 |
reduce(operands, init_values, computation, ...) |
封装了 XLA 的 Reduce 运算符。 |
reduce_precision(operand, exponent_bits, ...) |
封装了 XLA 的 ReducePrecision 运算符。 |
reduce_window(operand, init_value, ...[, ...]) |
|
rem(x, y) |
按位取余:(x \bmod y)。 |
reshape(operand, new_sizes[, dimensions]) |
封装了 XLA 的 Reshape 运算符。 |
rev(operand, dimensions) |
封装了 XLA 的 Rev 运算符。 |
rng_bit_generator(key, shape[, dtype, algorithm]) |
无状态的伪随机数位生成器。 |
rng_uniform(a, b, shape) |
有状态的伪随机数生成器。 |
round(x[, rounding_method]) |
按位四舍五入。 |
rsqrt(x) |
按位倒数平方根:(1 \over \sqrt{x})。 |
scatter(operand, scatter_indices, updates, ...) |
Scatter-update 运算符。 |
scatter_add(operand, scatter_indices, ...[, ...]) |
Scatter-add 运算符。 |
scatter_apply(operand, scatter_indices, ...) |
Scatter-apply 运算符。 |
scatter_max(operand, scatter_indices, ...[, ...]) |
Scatter-max 运算符。 |
scatter_min(operand, scatter_indices, ...[, ...]) |
Scatter-min 运算符。 |
scatter_mul(operand, scatter_indices, ...[, ...]) |
Scatter-multiply 运算符。 |
shift_left(x, y) |
按位左移:(x \ll y)。 |
shift_right_arithmetic(x, y) |
按位算术右移:(x \gg y)。 |
shift_right_logical(x, y) |
按位逻辑右移:(x \gg y)。 |
sign(x) |
按位符号函数。 |
sin(x) |
按位正弦函数:(\mathrm{sin}(x))。 |
sinh(x) |
按位双曲正弦函数:(\mathrm{sinh}(x))。 |
slice(operand, start_indices, limit_indices) |
封装了 XLA 的 Slice 运算符。 |
slice_in_dim(operand, start_index, limit_index) |
lax.slice() 的单维度应用封装。 |
sort() |
封装了 XLA 的 Sort 运算符。 |
sort_key_val(keys, values[, dimension, ...]) |
沿着dimension排序keys并对values应用相同的置换。 |
sqrt(x) |
逐元素平方根:(\sqrt{x})。 |
square(x) |
逐元素平方:(x²)。 |
squeeze(array, dimensions) |
从数组中挤出任意数量的大小为 1 的维度。 |
sub(x, y) |
逐元素减法:(x - y)。 |
tan(x) |
逐元素正切:(\mathrm{tan}(x))。 |
tanh(x) |
逐元素双曲正切:(\mathrm{tanh}(x))。 |
top_k(operand, k) |
返回operand最后一轴上的前k个值及其索引。 |
transpose(operand, permutation) |
包装 XLA 的Transpose运算符。 |
zeros_like_array(x) |
|
zeta(x, q) |
逐元素 Hurwitz zeta 函数:(\zeta(x, q)) |
控制流操作符
associative_scan(fn, elems[, reverse, axis]) |
使用关联二元操作并行执行扫描。 |
|---|---|
cond(pred, true_fun, false_fun, *operands[, ...]) |
根据条件应用true_fun或false_fun。 |
fori_loop(lower, upper, body_fun, init_val, *) |
通过归约到jax.lax.while_loop()从lower到upper循环。 |
map(f, xs) |
在主要数组轴上映射函数。 |
scan(f, init[, xs, length, reverse, unroll, ...]) |
在主要数组轴上扫描函数并携带状态。 |
select(pred, on_true, on_false) |
根据布尔谓词在两个分支之间选择。 |
select_n(which, *cases) |
从多个情况中选择数组值。 |
switch(index, branches, *operands[, operand]) |
根据index应用恰好一个branches。 |
while_loop(cond_fun, body_fun, init_val) |
在cond_fun为 True 时重复调用body_fun。 |
自定义梯度操作符
stop_gradient(x) |
停止梯度计算。 |
|---|---|
custom_linear_solve(matvec, b, solve[, ...]) |
使用隐式定义的梯度执行无矩阵线性求解。 |
custom_root(f, initial_guess, solve, ...[, ...]) |
可微分求解函数的根。 |
并行操作符
all_gather(x, axis_name, *[, ...]) |
在所有副本中收集x的值。 |
|---|---|
all_to_all(x, axis_name, split_axis, ...[, ...]) |
映射轴的实例化和映射不同轴。 |
pdot(x, y, axis_name[, pos_contract, ...]) |
|
psum(x, axis_name, *[, axis_index_groups]) |
在映射的轴axis_name上进行全归约求和。 |
psum_scatter(x, axis_name, *[, ...]) |
像psum(x, axis_name),但每个设备仅保留部分结果。 |
pmax(x, axis_name, *[, axis_index_groups]) |
在映射的轴axis_name上计算全归约最大值。 |
pmin(x, axis_name, *[, axis_index_groups]) |
在映射的轴axis_name上计算全归约最小值。 |
pmean(x, axis_name, *[, axis_index_groups]) |
在映射的轴axis_name上计算全归约均值。 |
ppermute(x, axis_name, perm) |
根据置换 perm 执行集体置换。 |
pshuffle(x, axis_name, perm) |
使用替代置换编码的 jax.lax.ppermute 的便捷包装器 |
pswapaxes(x, axis_name, axis, *[, ...]) |
将 pmapped 轴 axis_name 与非映射轴 axis 交换。 |
axis_index(axis_name) |
返回沿映射轴 axis_name 的索引。 |
与分片相关的操作符
with_sharding_constraint(x, shardings) |
在 jitted 计算中约束数组的分片机制 |
|---|
线性代数操作符 (jax.lax.linalg)
cholesky(x, *[, symmetrize_input]) |
Cholesky 分解。 |
|---|---|
eig(x, *[, compute_left_eigenvectors, ...]) |
一般矩阵的特征分解。 |
eigh(x, *[, lower, symmetrize_input, ...]) |
Hermite 矩阵的特征分解。 |
hessenberg(a) |
将方阵约化为上 Hessenberg 形式。 |
lu(x) |
带有部分主元列主元分解。 |
householder_product(a, taus) |
单元 Householder 反射的乘积。 |
qdwh(x, *[, is_hermitian, max_iterations, ...]) |
基于 QR 的动态加权 Halley 迭代进行极分解。 |
qr(x, *[, full_matrices]) |
QR 分解。 |
schur(x, *[, compute_schur_vectors, ...]) |
|
svd() |
奇异值分解。 |
triangular_solve(a, b, *[, left_side, ...]) |
三角解法。 |
tridiagonal(a, *[, lower]) |
将对称/Hermitian 矩阵约化为三对角形式。 |
tridiagonal_solve(dl, d, du, b) |
计算三对角线性系统的解。 |
参数类
class jax.lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
描述卷积的批量、空间和特征维度。
参数:
-
lhs_spec (Sequence[int]) – 包含非负整数维度编号的元组,其中包括(批量维度,特征维度,空间维度…)。
-
rhs_spec (Sequence[int]) – 包含非负整数维度编号的元组,其中包括(输出特征维度,输入特征维度,空间维度…)。
-
out_spec (Sequence[int]) – 包含非负整数维度编号的元组,其中包括(批量维度,特征维度,空间维度…)。
jax.lax.ConvGeneralDilatedDimensionNumbers
alias of tuple[str, str, str] | ConvDimensionNumbers | None
class jax.lax.GatherDimensionNumbers(offset_dims, collapsed_slice_dims, start_index_map)
描述了传递给 XLA 的 Gather 运算符 的维度号参数。有关维度号含义的详细信息,请参阅 XLA 文档。
Parameters:
-
offset_dims (tuple[int, ...**]) – gather 输出中偏移到从操作数切片的数组中的维度的集合。必须是升序整数元组,每个代表输出的一个维度编号。
-
collapsed_slice_dims (tuple[int, ...**]) – operand 中具有 slice_sizes[i] == 1 的维度 i 的集合,这些维度不应在 gather 输出中具有对应维度。必须是一个升序整数元组。
-
start_index_map (tuple[int, ...**]) – 对于 start_indices 中的每个维度,给出应该被切片的操作数中对应的维度。必须是一个大小等于 start_indices.shape[-1] 的整数元组。
与 XLA 的 GatherDimensionNumbers 结构不同,index_vector_dim 是隐含的;总是存在一个索引向量维度,且它必须始终是最后一个维度。要收集标量索引,请添加大小为 1 的尾随维度。
class jax.lax.GatherScatterMode(value)
描述了如何处理 gather 或 scatter 中的越界索引。
可能的值包括:
CLIP:
索引将被夹在最近的范围内值上,即整个要收集的窗口都在范围内。
FILL_OR_DROP:
如果收集窗口的任何部分越界,则返回整个窗口,即使其他部分原本在界内的元素也将用常量填充。如果分散窗口的任何部分越界,则整个窗口将被丢弃。
PROMISE_IN_BOUNDS:
用户承诺索引在范围内。不会执行额外检查。实际上,根据当前的 XLA 实现,这意味着越界的 gather 将被夹在范围内,但越界的 scatter 将被丢弃。如果索引越界,则梯度将不正确。
class jax.lax.Precision(value)
lax 函数的精度枚举
JAX 函数的精度参数通常控制加速器后端(即 TPU 和 GPU)上的数组计算速度和精度之间的权衡。成员包括:
默认:
最快模式,但最不准确。在 bfloat16 中执行计算。别名:'default','fastest','bfloat16'。
高:
较慢但更准确。以 3 个 bfloat16 传递执行 float32 计算,或在可用时使用 tensorfloat32。别名:'high','bfloat16_3x','tensorfloat32'。
最高:
最慢但最准确。根据适用情况在 float32 或 float64 中执行计算。别名:'highest','float32'。
jax.lax.PrecisionLike
别名为 str | Precision | tuple[str, str] | tuple[Precision, Precision] | None
class jax.lax.RoundingMethod(value)
一个枚举。
class jax.lax.ScatterDimensionNumbers(update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims)
描述了对 XLA 的 Scatter 操作符 的维度编号参数。有关维度编号含义的更多详细信息,请参阅 XLA 文档。
参数:
-
update_window_dims (Sequence[int]) – 更新中作为窗口维度的维度集合。必须是整数元组,按升序排列,每个表示一个维度编号。
-
inserted_window_dims (Sequence[int]) – 必须插入更新形状的大小为 1 的窗口维度集合。必须是整数元组,按升序排列,每个表示输出的维度编号的镜像图。这些是 gather 情况下 collapsed_slice_dims 的镜像图。
-
scatter_dims_to_operand_dims (Sequence[int]) – 对于 scatter_indices 中的每个维度,给出 operand 中对应的维度。必须是整数序列,大小等于 scatter_indices.shape[-1]。
与 XLA 的 ScatterDimensionNumbers 结构不同,index_vector_dim 是隐式的;总是有一个索引向量维度,并且它必须始终是最后一个维度。要分散标量索引,添加一个尺寸为 1 的尾随维度。
jax.random 模块
伪随机数生成的实用程序。
jax.random 包提供了多种例程,用于确定性生成伪随机数序列。
基本用法
>>> seed = 1701
>>> num_steps = 100
>>> key = jax.random.key(seed)
>>> for i in range(num_steps):
... key, subkey = jax.random.split(key)
... params = compiled_update(subkey, params, next(batches))
PRNG keys
与 NumPy 和 SciPy 用户习惯的 有状态 伪随机数生成器(PRNGs)不同,JAX 随机函数都要求作为第一个参数传递一个显式的 PRNG 状态。随机状态由我们称之为 key 的特殊数组元素类型描述,通常由 jax.random.key() 函数生成:
>>> from jax import random
>>> key = random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
然后,可以在 JAX 的任何随机数生成例程中使用该 key:
>>> random.uniform(key)
Array(0.41845703, dtype=float32)
请注意,使用 key 不会修改它,因此重复使用相同的 key 将导致相同的结果:
>>> random.uniform(key)
Array(0.41845703, dtype=float32)
如果需要新的随机数,可以使用 jax.random.split() 生成新的子 key:
>>> key, subkey = random.split(key)
>>> random.uniform(subkey)
Array(0.10536897, dtype=float32)
注意
类型化的 key 数组,例如上述 key<fry>,在 JAX v0.4.16 中引入。在此之前,key 通常以 uint32 数组表示,其最终维度表示 key 的位级表示。
两种形式的 key 数组仍然可以通过 jax.random 模块创建和使用。新式的类型化 key 数组使用 jax.random.key() 创建。传统的 uint32 key 数组使用 jax.random.PRNGKey() 创建。
要在两者之间进行转换,使用 jax.random.key_data() 和 jax.random.wrap_key_data()。当与 JAX 外部系统(例如将数组导出为可序列化格式)交互或将 key 传递给基于 JAX 的库时,可能需要传统的 key 格式。
否则,建议使用类型化的 key。传统 key 相对于类型化 key 的注意事项包括:
-
它们有一个额外的尾维度。
-
它们具有数字数据类型 (
uint32),允许进行通常不用于 key 的操作,例如整数算术。 -
它们不包含有关 RNG 实现的信息。当传统 key 传递给
jax.random函数时,全局配置设置确定 RNG 实现(参见下文的“高级 RNG 配置”)。
要了解更多关于此升级以及 key 类型设计的信息,请参阅 JEP 9263。
高级
设计和背景
TLDR:JAX PRNG = Threefry counter PRNG + 一个功能数组导向的 分裂模型
更多详细信息,请参阅 docs/jep/263-prng.md。
总结一下,JAX PRNG 还包括但不限于以下要求:
-
确保可重现性,
-
良好的并行化,无论是向量化(生成数组值)还是多副本、多核计算。特别是它不应在随机函数调用之间使用顺序约束。
高级 RNG 配置
JAX 提供了几种 PRNG 实现。可以通过可选的 impl 关键字参数选择特定的实现。如果在密钥构造函数中没有传递 impl 选项,则实现由全局 jax_default_prng_impl 配置标志确定。
-
默认,“threefry2x32”: 基于 Threefry 哈希函数构建的基于计数器的 PRNG。
-
实验性 一种仅包装了 XLA 随机位生成器(RBG)算法的 PRNG。请参阅 TF 文档。
-
“rbg” 使用 ThreeFry 进行分割,并使用 XLA RBG 进行数据生成。
-
“unsafe_rbg” 仅用于演示目的,使用 RBG 进行分割(使用未经测试的虚构算法)和生成。
这些实验性实现生成的随机流尚未经过任何经验随机性测试(例如 Big Crush)。生成的随机比特可能会在 JAX 的不同版本之间变化。
-
不使用默认 RNG 的可能原因是:
-
可能编译速度较慢(特别是对于 Google Cloud TPU)
-
在 TPU 上执行速度较慢
-
不支持高效的自动分片/分区
这里是一个简短的总结:
| 属性 | Threefry | Threefry* | rbg | unsafe_rbg | rbg** | unsafe_rbg** |
|---|---|---|---|---|---|---|
| 在 TPU 上最快 | ✅ | ✅ | ✅ | ✅ | ||
| 可以高效分片(使用 pjit) | ✅ | ✅ | ✅ | |||
| 在分片中相同 | ✅ | ✅ | ✅ | ✅ | ||
| 在 CPU/GPU/TPU 上相同 | ✅ | ✅ | ||||
| 在 JAX/XLA 版本间相同 | ✅ | ✅ |
(*): 设置了jax_threefry_partitionable=1
(**): 设置了XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1
“rbg” 和 “unsafe_rbg” 之间的区别在于,“rbg” 用于生成随机值时使用了较不稳定/研究较少的哈希函数(但不用于 jax.random.split 或 jax.random.fold_in),而 “unsafe_rbg” 还额外在 jax.random.split 和 jax.random.fold_in 中使用了更不稳定的哈希函数。因此,在不同密钥生成的随机流质量方面不那么安全。
要了解有关 jax_threefry_partitionable 的更多信息,请参阅jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
API 参考
密钥创建与操作
PRNGKey(seed, *[, impl]) |
给定整数种子创建伪随机数生成器(PRNG)密钥。 |
|---|---|
key(seed, *[, impl]) |
给定整数种子创建伪随机数生成器(PRNG)密钥。 |
key_data(密钥) |
恢复 PRNG 密钥数组下的密钥数据位。 |
wrap_key_data(key_bits_array, *[, impl]) |
将密钥数据位数组包装成 PRNG 密钥数组。 |
fold_in(key, data) |
将数据折叠到 PRNG 密钥中,形成新的 PRNG 密钥。 |
split(key[, num]) |
将 PRNG 密钥按添加一个前导轴拆分为 num 个新密钥。 |
clone(key) |
克隆一个密钥以便重复使用。 |
随机抽样器
ball(key, d[, p, shape, dtype]) |
从单位 Lp 球中均匀采样。 |
|---|---|
bernoulli(key[, p, shape]) |
采样给定形状和均值的伯努利分布随机值。 |
beta(key, a, b[, shape, dtype]) |
采样给定形状和浮点数数据类型的贝塔分布随机值。 |
binomial(key, n, p[, shape, dtype]) |
采样给定形状和浮点数数据类型的二项分布随机值。 |
bits(key[, shape, dtype]) |
以无符号整数的形式采样均匀比特。 |
categorical(key, logits[, axis, shape]) |
从分类分布中采样随机值。 |
cauchy(key[, shape, dtype]) |
采样给定形状和浮点数数据类型的柯西分布随机值。 |
chisquare(key, df[, shape, dtype]) |
采样给定形状和浮点数数据类型的卡方分布随机值。 |
choice(key, a[, shape, replace, p, axis]) |
从给定数组中生成随机样本。 |
dirichlet(key, alpha[, shape, dtype]) |
采样给定形状和浮点数数据类型的狄利克雷分布随机值。 |
double_sided_maxwell(key, loc, scale[, ...]) |
从双边 Maxwell 分布中采样。 |
exponential(key[, shape, dtype]) |
采样给定形状和浮点数数据类型的指数分布随机值。 |
f(key, dfnum, dfden[, shape, dtype]) |
采样给定形状和浮点数数据类型的 F 分布随机值。 |
gamma(key, a[, shape, dtype]) |
采样给定形状和浮点数数据类型的伽马分布随机值。 |
generalized_normal(key, p[, shape, dtype]) |
从广义正态分布中采样。 |
geometric(key, p[, shape, dtype]) |
采样给定形状和浮点数数据类型的几何分布随机值。 |
gumbel(key[, shape, dtype]) |
采样给定形状和浮点数数据类型的 Gumbel 分布随机值。 |
laplace(key[, shape, dtype]) |
采样给定形状和浮点数数据类型的拉普拉斯分布随机值。 |
loggamma(key, a[, shape, dtype]) |
采样给定形状和浮点数数据类型的对数伽马分布随机值。 |
logistic(key[, shape, dtype]) |
采样给定形状和浮点数数据类型的 logistic 随机值。 |
lognormal(key[, sigma, shape, dtype]) |
采样给定形状和浮点数数据类型的对数正态分布随机值。 |
maxwell(key[, shape, dtype]) |
从单边 Maxwell 分布中采样。 |
multivariate_normal(key, mean, cov[, shape, ...]) |
采样给定均值和协方差的多变量正态分布随机值。 |
normal(key[, shape, dtype]) |
采样给定形状和浮点数数据类型的标准正态分布随机值。 |
orthogonal(key, n[, shape, dtype]) |
从正交群 O(n) 中均匀采样。 |
pareto(key, b[, shape, dtype]) |
采样给定形状和浮点数数据类型的帕累托分布随机值。 |
permutation(key, x[, axis, independent]) |
返回随机排列的数组或范围。 |
poisson(key, lam[, shape, dtype]) |
采样给定形状和整数数据类型的泊松分布随机值。 |
rademacher(key[, shape, dtype]) |
从 Rademacher 分布中采样。 |
randint(key, shape, minval, maxval[, dtype]) |
用给定的形状和数据类型在[minval, maxval)范围内示例均匀随机整数值。 |
[rayleigh(key, scale[, shape, dtype]) |
用给定的形状和浮点数数据类型示例瑞利随机值。 |
t(key, df[, shape, dtype]) |
用给定的形状和浮点数数据类型示例学生 t 分布随机值。 |
triangular(key, left, mode, right[, shape, ...]) |
用给定的形状和浮点数数据类型示例三角形随机值。 |
truncated_normal(key, lower, upper[, shape, ...]) |
用给定的形状和数据类型示例截断标准正态随机值。 |
uniform(key[, shape, dtype, minval, maxval]) |
用给定的形状和数据类型在[minval, maxval)范围内示例均匀随机值。 |
[wald(key, mean[, shape, dtype]) |
用给定的形状和浮点数数据类型示例瓦尔德随机值。 |
weibull_min(key, scale, concentration[, ...]) |
从威布尔分布中采样。 |
jax.sharding 模块
类
class jax.sharding.Sharding
描述了jax.Array如何跨设备布局。
property addressable_devices: set[Device]
Sharding中由当前进程可寻址的设备集合。
addressable_devices_indices_map(global_shape)
从可寻址设备到它们包含的数组数据切片的映射。
addressable_devices_indices_map 包含适用于可寻址设备的device_indices_map部分。
参数:
global_shape (tuple[int, ...**])
返回类型:
Mapping[Device, tuple[slice, …] | None]
property device_set: set[Device]
这个Sharding跨越的设备集合。
在多控制器 JAX 中,设备集合是全局的,即包括来自其他进程的不可寻址设备。
devices_indices_map(global_shape)
返回从设备到它们包含的数组切片的映射。
映射包括所有全局设备,即包括来自其他进程的不可寻址设备。
参数:
global_shape (tuple[int, ...**])
返回类型:
Mapping[Device, tuple[slice, …]]
is_equivalent_to(other, ndim)
如果两个分片等效,则返回True。
如果它们在相同设备上放置了相同的逻辑数组分片,则两个分片是等效的。
例如,如果NamedSharding和PositionalSharding都将数组的相同分片放置在相同的设备上,则它们可能是等效的。
参数:
-
self (Sharding)
-
other (Sharding)
-
ndim (int)
返回类型:
property is_fully_addressable: bool
此分片是否是完全可寻址的?
如果当前进程能够寻址Sharding中列出的所有设备,则分片是完全可寻址的。在多进程 JAX 中,is_fully_addressable 等效于 "is_local"。
property is_fully_replicated: bool
此分片是否完全复制?
如果每个设备都有整个数据的完整副本,则分片是完全复制的。
property memory_kind: str | None
返回分片的内存类型。
shard_shape(global_shape)
返回每个设备上数据的形状。
此函数返回的分片形状是从global_shape和分片属性计算得出的。
参数:
global_shape (tuple[int, ...**])
返回类型:
with_memory_kind(kind)
返回具有指定内存类型的新分片实例。
参数:
kind (str)
返回类型:
分片
class jax.sharding.SingleDeviceSharding
基类:分片
一个将其数据放置在单个设备上的分片。
参数:
device – 单个设备。
示例
>>> single_device_sharding = jax.sharding.SingleDeviceSharding(
... jax.devices()[0])
property device_set: set[Device]
此分片跨越的设备集。
在多控制器 JAX 中,设备集是全局的,即包括来自其他进程的非可寻址设备。
devices_indices_map(global_shape)
返回从设备到每个包含的数组片段的映射。
映射包括所有全局设备,即包括来自其他进程的非可寻址设备。
参数:
global_shape (tuple[int, ...**])
返回类型:
property is_fully_addressable: bool
此分片是否完全可寻址?
如果当前进程可以寻址分片中命名的所有设备,则称分片完全可寻址。is_fully_addressable在多进程 JAX 中等同于“is_local”。
property is_fully_replicated: bool
此分片是否完全复制?
如果每个设备都有整个数据的完整副本,则分片完全复制。
property memory_kind: str | None
返回分片的内存类型。
with_memory_kind(kind)
返回具有指定内存类型的新分片实例。
参数:
kind (str)
返回类型:
单设备分片
class jax.sharding.NamedSharding
基类:分片
一个NamedSharding使用命名轴来表示分片。
一个NamedSharding是设备Mesh和描述如何跨该网格对数组进行分片的PartitionSpec的组合。
一个Mesh是 JAX 设备的多维 NumPy 数组,其中网格的每个轴都有一个名称,例如 'x' 或 'y'。
一个PartitionSpec是一个元组,其元素可以是None、一个网格轴或一组网格轴的元组。每个元素描述如何在零个或多个网格维度上对输入维度进行分区。例如,PartitionSpec('x', 'y')表示数据的第一维在网格的 x 轴上进行分片,第二维在网格的 y 轴上进行分片。
分布式数组和自动并行化(jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names)教程详细讲解了如何使用Mesh和PartitionSpec,包括更多细节和图示。
参数:
-
mesh – 一个
jax.sharding.Mesh对象。 -
spec – 一个
jax.sharding.PartitionSpec对象。
示例
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> spec = P('x', 'y')
>>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
property addressable_devices: set[Device]
当前进程可以访问的Sharding中的设备集。
property device_set: set[Device]
该Sharding跨越的设备集。
在多控制器 JAX 中,设备集是全局的,即包括来自其他进程的不可寻址设备。
property is_fully_addressable: bool
此分片是否完全可寻址?
一个分片如果当前进程可以访问Sharding中列出的所有设备,则被视为完全可寻址。在多进程 JAX 中,is_fully_addressable等同于“is_local”。
property is_fully_replicated: bool
此分片是否完全复制?
如果每个设备都有整个数据的完整副本,则称分片为完全复制。
property memory_kind: str | None
返回分片的内存类型。
property mesh
(self) -> object
property spec
(self) -> object
with_memory_kind(kind)
返回具有指定内存类型的新Sharding实例。
参数:
kind (str)
返回类型:
NamedSharding
class jax.sharding.PositionalSharding(devices, *, memory_kind=None)
基类:Sharding
参数:
-
devices (Sequence[xc.Device**] | np.ndarray)
-
memory_kind (str | None)
property device_set: set[Device]
该Sharding跨越的设备集。
在多控制器 JAX 中,设备集是全局的,即包括来自其他进程的不可寻址设备。
property is_fully_addressable: bool
此分片是否完全可寻址?
一个分片如果当前进程可以访问Sharding中列出的所有设备,则被视为完全可寻址。在多进程 JAX 中,is_fully_addressable等同于“is_local”。
property is_fully_replicated: bool
此分片是否完全复制?
如果每个设备都有整个数据的完整副本,则称分片为完全复制。
property memory_kind: str | None
返回分片的内存类型。
with_memory_kind(kind)
返回具有指定内存类型的新Sharding实例。
参数:
kind (str)
返回类型:
PositionalSharding
class jax.sharding.PmapSharding
基类:Sharding
描述了jax.pmap()使用的分片。
classmethod default(shape, sharded_dim=0, devices=None)
创建一个PmapSharding,与jax.pmap()使用的默认放置方式匹配。
参数:
-
sharded_dim (int") – 输入数组进行分片的维度。默认为 0。
-
devices(Sequence[Device] | None) – 可选的设备序列。如果省略,隐含的
-
used(pmap 使用的设备顺序是) –
jax.local_devices()。 -
of(这是顺序) –
jax.local_devices()。
返回类型:
PmapSharding
property device_set: set[Device]
这个Sharding跨越的设备集合。
在多控制器 JAX 中,设备集合是全局的,即包括其他进程的非可寻址设备。
property devices
(self)-> ndarray
devices_indices_map(global_shape)
返回设备到每个包含的数组切片的映射。
映射包括所有全局设备,即包括其他进程的非可寻址设备。
参数:
返回类型:
is_equivalent_to(other, ndim)
如果两个分片等效,则返回True。
如果它们将相同的逻辑数组分片放置在相同的设备上,则两个分片是等效的。
例如,如果NamedSharding和PositionalSharding将数组的相同分片放置在相同的设备上,则它们可能是等效的。
参数:
-
self(PmapSharding)
-
other(PmapSharding)
-
ndim(int)
返回类型:
布尔("in Python v3.12")
property is_fully_addressable: bool
这个分片是否完全可寻址?
如果当前进程能够处理Sharding中命名的所有设备,则分片是完全可寻址的。在多进程 JAX 中,is_fully_addressable相当于“is_local”。
property is_fully_replicated: bool
这个分片是否完全复制?
如果每个设备都有完整数据的副本,则分片是完全复制的。
property memory_kind: str | None
返回分片的内存类型。
shard_shape(global_shape)
返回每个设备上数据的形状。
此函数返回的分片形状是从global_shape和分片属性计算而来的。
参数:
返回类型:
property sharding_spec
(self)-> jax::ShardingSpec
with_memory_kind(kind)
返回具有指定内存类型的新 Sharding 实例。
参数:
kind(str)
class jax.sharding.GSPMDSharding
基类:Sharding
property device_set: set[Device]
这个Sharding跨越的设备集合。
在多控制器 JAX 中,设备集是全局的,即包括来自其他进程的不可寻址设备。
property is_fully_addressable: bool
此分片是否完全可寻址?
如果当前进程可以访问Sharding中命名的所有设备,则分片是完全可寻址的。is_fully_addressable相当于多进程 JAX 中的“is_local”。
property is_fully_replicated: bool
此分片是否完全复制?
一个分片是完全复制的,如果每个设备都有整个数据的完整副本。
property memory_kind: str | None
返回分片的内存类型。
with_memory_kind(kind)
返回具有指定内存类型的新 Sharding 实例。
参数:
kind(str)
返回类型:
GSPMDSharding
class jax.sharding.PartitionSpec(*partitions)
元组描述如何在设备网格上对数组进行分区。
每个元素都可以是None、字符串或字符串元组。有关更多详细信息,请参阅jax.sharding.NamedSharding的文档。
此类存在,以便 JAX 的 pytree 实用程序可以区分分区规范和应视为 pytrees 的元组。
class jax.sharding.Mesh(devices, axis_names)
声明在此管理器范围内可用的硬件资源。
特别是,所有axis_names在管理块内都变成有效的资源名称,并且可以在jax.experimental.pjit.pjit()的in_axis_resources参数中使用,还请参阅 JAX 的多进程编程模型(jax.readthedocs.io/en/latest/multi_process.html)和分布式数组与自动并行化教程(jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)
如果您在多线程中编译,请确保with Mesh上下文管理器位于线程将执行的函数内部。
参数:
-
devices(ndarray) - 包含 JAX 设备对象(例如从
jax.devices()获得的对象)的 NumPy ndarray 对象。 -
axis_names(tuple[Any, ...**]) - 资源轴名称序列,用于分配给
devices参数的维度。其长度应与devices的秩匹配。
示例
>>> from jax.experimental.pjit import pjit
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> inp = np.arange(16).reshape((8, 2))
>>> devices = np.array(jax.devices()).reshape(4, 2)
...
>>> # Declare a 2D mesh with axes `x` and `y`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> # Use the mesh object directly as a context manager.
>>> with global_mesh:
... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # Initialize the Mesh and use the mesh as the context manager.
>>> with Mesh(devices, ('x', 'y')) as global_mesh:
... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # Also you can use it as `with ... as ...`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> with global_mesh as m:
... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # You can also use it as `with Mesh(...)`.
>>> with Mesh(devices, ('x', 'y')):
... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
jax.debug 模块
运行时值调试实用工具
jax.debug.print 和 jax.debug.breakpoint 描述了如何利用 JAX 的运行时值调试功能。
callback(callback, *args[, ordered]) |
调用可分阶段的 Python 回调函数。 |
|---|---|
print(fmt, *args[, ordered]) |
打印值,并在 JAX 函数中工作。 |
breakpoint(*[, backend, filter_frames, ...]) |
在程序中某一点设置断点。 |
调试分片实用工具
能够在分段函数内(和外部)检查和可视化数组分片的函数。
inspect_array_sharding(value, *, callback) |
在 JIT 编译函数内部启用检查数组分片。 |
|---|---|
visualize_array_sharding(arr, **kwargs) |
可视化数组的分片。 |
visualize_sharding(shape, sharding, *[, ...]) |
使用 rich 可视化 Sharding。 |
jax.dlpack 模块
from_dlpack(external_array[, device, copy]) |
返回一个 DLPack 张量的 Array 表示形式。 |
|---|---|
to_dlpack(x[, stream, src_device, ...]) |
返回一个封装了 Array x 的 DLPack 张量。 |
jax.distributed 模块
initialize([coordinator_address, ...]) |
初始化 JAX 分布式系统。 |
|---|---|
shutdown() |
关闭分布式系统。 |
jax.dtypes 模块
bfloat16 |
bfloat16 浮点数值 |
|---|---|
canonicalize_dtype(dtype[, allow_extended_dtype]) |
根据config.x64_enabled配置将 dtype 转换为规范的 dtype。 |
float0 |
对应于相同名称的标量类型和 dtype 的 DType 类。 |
issubdtype(a, b) |
如果第一个参数是类型代码在类型层次结构中较低/相等,则返回 True。 |
prng_key() |
PRNG Key dtypes 的标量类。 |
result_type(*args[, return_weak_type_flag]) |
方便函数,用于应用 JAX 参数 dtype 提升。 |
scalar_type_of(x) |
返回与 JAX 值关联的标量类型。 |
jax.flatten_util 模块
函数列表
| - | ravel_pytree(pytree) |
将一个数组的 pytree 展平(压缩)为一个 1D 数组。 |
|---|
jax.image 模块
图像操作函数。
更多的图像操作函数可以在建立在 JAX 之上的库中找到,例如 PIX。
图像操作函数
resize(image, shape, method[, antialias, ...]) |
图像调整大小。 |
|---|---|
scale_and_translate(image, shape, ...[, ...]) |
对图像应用缩放和平移。 |
参数类
class jax.image.ResizeMethod(value)
图像调整大小方法。
可能的取值包括:
NEAREST:
最近邻插值。
LINEAR:
线性插值。
LANCZOS3:
Lanczos 重采样,使用半径为 3 的核。
LANCZOS5:
Lanczos 重采样,使用半径为 5 的核。
CUBIC:
三次插值,使用 Keys 三次核。
jax.nn 模块
jax.nn.initializers模块
神经网络库常见函数。
激活函数
relu |
线性整流单元激活函数。 |
|---|---|
relu6 |
线性整流单元 6 激活函数。 |
sigmoid(x) |
Sigmoid 激活函数。 |
softplus(x) |
Softplus 激活函数。 |
sparse_plus(x) |
稀疏加法函数。 |
sparse_sigmoid(x) |
稀疏 Sigmoid 激活函数。 |
soft_sign(x) |
Soft-sign 激活函数。 |
silu(x) |
SiLU(又称 swish)激活函数。 |
swish(x) |
SiLU(又称 swish)激活函数。 |
log_sigmoid(x) |
对数 Sigmoid 激活函数。 |
leaky_relu(x[, negative_slope]) |
泄漏整流线性单元激活函数。 |
hard_sigmoid(x) |
硬 Sigmoid 激活函数。 |
hard_silu(x) |
硬 SiLU(swish)激活函数。 |
hard_swish(x) |
硬 SiLU(swish)激活函数。 |
hard_tanh(x) |
硬\tanh 激活函数。 |
elu(x[, alpha]) |
指数线性单元激活函数。 |
celu(x[, alpha]) |
连续可微的指数线性单元激活函数。 |
selu(x) |
缩放的指数线性单元激活函数。 |
gelu(x[, approximate]) |
高斯误差线性单元激活函数。 |
glu(x[, axis]) |
门控线性单元激活函数。 |
squareplus(x[, b]) |
Squareplus 激活函数。 |
mish(x) |
Mish 激活函数。 |
其他函数
softmax(x[, axis, where, initial]) |
Softmax 函数。 |
|---|---|
log_softmax(x[, axis, where, initial]) |
对数 Softmax 函数。 |
logsumexp() |
对数-总和-指数归约。 |
standardize(x[, axis, mean, variance, ...]) |
通过减去mean并除以(\sqrt{\mathrm{variance}})来标准化数组。 |
one_hot(x, num_classes, *[, dtype, axis]) |
对给定索引进行 One-hot 编码。 |
jax.nn.initializers 模块
与 Keras 和 Sonnet 中定义一致的常见神经网络层初始化器。
初始化器
该模块提供了与 Keras 和 Sonnet 中定义一致的常见神经网络层初始化器。
初始化器是一个函数,接受三个参数:(key, shape, dtype),并返回一个具有形状shape和数据类型dtype的数组。参数key是一个 PRNG 密钥(例如来自jax.random.key()),用于生成初始化数组的随机数。
constant(value[, dtype]) |
构建一个返回常数值数组的初始化器。 |
|---|---|
delta_orthogonal([scale, column_axis, dtype]) |
构建一个用于增量正交核的初始化器。 |
glorot_normal([in_axis, out_axis, ...]) |
构建一个 Glorot 正态初始化器(又称 Xavier 正态初始化器)。 |
glorot_uniform([in_axis, out_axis, ...]) |
构建一个 Glorot 均匀初始化器(又称 Xavier 均匀初始化器)。 |
he_normal([in_axis, out_axis, batch_axis, dtype]) |
构建一个 He 正态初始化器(又称 Kaiming 正态初始化器)。 |
he_uniform([in_axis, out_axis, batch_axis, ...]) |
构建一个 He 均匀初始化器(又称 Kaiming 均匀初始化器)。 |
lecun_normal([in_axis, out_axis, ...]) |
构建一个 Lecun 正态初始化器。 |
lecun_uniform([in_axis, out_axis, ...]) |
构建一个 Lecun 均匀初始化器。 |
normal([stddev, dtype]) |
构建一个返回实数正态分布随机数组的初始化器。 |
ones(key, shape[, dtype]) |
返回一个填充为一的常数数组的初始化器。 |
orthogonal([scale, column_axis, dtype]) |
构建一个返回均匀分布正交矩阵的初始化器。 |
truncated_normal([stddev, dtype, lower, upper]) |
构建一个返回截断正态分布随机数组的初始化器。 |
uniform([scale, dtype]) |
构建一个返回实数均匀分布随机数组的初始化器。 |
variance_scaling(scale, mode, distribution) |
初始化器,根据权重张量的形状调整其尺度。 |
zeros(key, shape[, dtype]) |
返回一个填充零的常数数组的初始化器。 |
jax.ops 模块
段落约简运算符
| segment_max(data, segment_ids[, ...]) | 计算数组段内的最大值。 |
函数 jax.ops.index_update、jax.ops.index_add 等已在 JAX 0.2.22 中弃用,并已移除。请改用 JAX 数组上的 jax.numpy.ndarray.at 属性。 |
|---|
segment_min(data, segment_ids[, ...]) |
segment_prod(data, segment_ids[, ...]) |
segment_sum(data, segment_ids[, ...]) |
jax.profiler 模块
跟踪和时间分析
描述了如何利用 JAX 的跟踪和时间分析功能进行程序性能分析。
start_server(port) |
在指定端口启动分析器服务器。 |
|---|---|
start_trace(log_dir[, create_perfetto_link, ...]) |
启动性能分析跟踪。 |
stop_trace() |
停止当前正在运行的性能分析跟踪。 |
trace(log_dir[, create_perfetto_link, ...]) |
上下文管理器,用于进行性能分析跟踪。 |
annotate_function(func[, name]) |
生成函数执行的跟踪事件的装饰器。 |
TraceAnnotation |
在分析器中生成跟踪事件的上下文管理器。 |
StepTraceAnnotation(name, **kwargs) |
在分析器中生成步骤跟踪事件的上下文管理器。 |
设备内存分析
请参阅设备内存分析,了解 JAX 的设备内存分析功能简介。
device_memory_profile([backend]) |
捕获 JAX 设备内存使用情况,格式为 pprof 协议缓冲区。 |
|---|---|
save_device_memory_profile(filename[, backend]) |
收集设备内存使用情况,并将其写入文件。 |
jax.stages 模块
接口到编译执行过程的各个阶段。
JAX 转换,例如jax.jit和jax.pmap,也支持一种通用的显式降阶和预编译执行 ahead of time 的方式。 该模块定义了代表这一过程各个阶段的类型。
有关更多信息,请参阅AOT walkthrough。
类
class jax.stages.Wrapped(*args, **kwargs)
一个准备好进行追踪、降阶和编译的函数。
此协议反映了诸如jax.jit之类的函数的输出。 调用它会导致 JIT(即时)降阶、编译和执行。 它也可以在编译之前明确降阶,并在执行之前编译结果。
__call__(*args, **kwargs)
执行包装的函数,根据需要进行降阶和编译。
lower(*args, **kwargs)
明确为给定的参数降阶此函数。
一个降阶函数被从 Python 阶段化,并翻译为编译器的输入语言,可能以依赖于后端的方式。 它已准备好进行编译,但尚未编译。
返回:
一个Lowered实例,表示降阶。
返回类型:
降阶
trace(*args, **kwargs)
明确为给定的参数追踪此函数。
一个追踪函数被从 Python 阶段化,并翻译为一个 jaxpr。 它已准备好进行降阶,但尚未降阶。
返回:
一个Traced实例,表示追踪。
返回类型:
追踪
class jax.stages.Lowered(lowering, args_info, out_tree, no_kwargs=False)
降阶一个根据参数类型和值特化的函数。
降阶是一种准备好进行编译的计算。 此类将降阶与稍后编译和执行所需的剩余信息一起携带。 它还提供了一个通用的 API,用于查询 JAX 各种降阶路径(jit()、pmap()等)中降阶计算的属性。
参数:
as_text(dialect=None)
此降阶的人类可读文本表示。
旨在可视化和调试目的。 这不必是有效的也不一定可靠的序列化。 它直接传递给外部调用者。
参数:
方言(str | 无) – 可选字符串,指定一个降阶方言(例如,“stablehlo”)
返回类型:
compile(compiler_options=None)
编译,并返回相应的Compiled实例。
参数:
compiler_options (dict[str, str | bool] | None)
返回类型:
Compiled
compiler_ir(dialect=None)
这种降低的任意对象表示。
旨在调试目的。这不是有效的也不是可靠的序列化。输出在不同调用之间没有一致性的保证。
如果不可用,则返回None,例如基于后端、编译器或运行时。
参数:
dialect (str | None) – 可选字符串,指定一个降低方言(例如“stablehlo”)
返回类型:
Any | None
cost_analysis()
执行成本估算的摘要。
旨在可视化和调试。此输出的对象是一些简单的数据结构,可以轻松打印或序列化(例如,带有数值叶的嵌套字典、列表和元组)。然而,它的结构可以是任意的:在 JAX 和 jaxlib 的不同版本甚至调用之间可能不一致。
如果不可用,则返回None,例如基于后端、编译器或运行时。
返回类型:
Any | None
property in_tree: PyTreeDef
一对(位置参数、关键字参数)的树结构。
class jax.stages.Compiled(executable, args_info, out_tree, no_kwargs=False)
编译后的函数专门针对类型/值进行了优化表示。
编译计算与可执行文件相关联,并提供执行所需的剩余信息。它还为查询 JAX 的各种编译路径和后端中编译计算属性提供了一个共同的 API。
参数:
-
args_info (Any)
-
out_tree (PyTreeDef)
__call__(*args, **kwargs)
将自身作为函数调用。
as_text()
这是可执行文件的人类可读文本表示。
旨在可视化和调试。这不是有效的也不是可靠的序列化。
如果不可用,则返回None,例如基于后端、编译器或运行时。
返回类型:
str | None
cost_analysis()
执行成本估算的摘要。
旨在可视化和调试。此输出的对象是一些简单的数据结构,可以轻松打印或序列化(例如,带有数值叶的嵌套字典、列表和元组)。然而,它的结构可以是任意的:在 JAX 和 jaxlib 的不同版本甚至调用之间可能不一致。
如果不可用,则返回None,例如基于后端、编译器或运行时。
返回类型:
Any | None
property in_tree: PyTreeDef
(位置参数,关键字参数) 的树结构。
memory_analysis()
估计内存需求的摘要。
用于可视化和调试目的。由此输出的对象是一些简单的数据结构,可以轻松打印或序列化(例如嵌套的字典、列表和具有数字叶子的元组)。然而,其结构可以是任意的:在 JAX 和 jaxlib 的不同版本之间,甚至在不同调用之间可能是不一致的。
返回 None 如果不可用,例如基于后端、编译器或运行时。
返回类型:
任意 | None
runtime_executable()
此可执行对象的任意对象表示。
用于调试目的。这不是有效也不是可靠的序列化。输出不能保证在不同调用之间的一致性。
返回 None 如果不可用,例如基于后端、编译器或运行时。
返回类型:
任意 | None


浙公网安备 33010602011771号