jax框架:jax.grad

官方地址:

https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad


image



这里只给出几个样例代码:

  1. 设置 allow_int 参数,实现对整数类型求导:

未对整数类型求导:

import jax

def fun(x, y):
    print(x, y)
    return jax.numpy.sum(2*x[0] + y[0] + 2*x[1] + y[1])

fun_grad = jax.grad(fun, argnums=(0, ))

x = [jax.numpy.arange(0, 5).astype(jax.numpy.float32), jax.numpy.arange(1, 6).astype(jax.numpy.float32),]
y = [jax.numpy.arange(1, 6), jax.numpy.arange(2, 7),]

print( fun_grad(x, y) )

正常运行:

image


对整数类型求导:

import jax

def fun(x, y):
    print(x, y)
    return jax.numpy.sum(2*x[0] + y[0] + 2*x[1] + y[1])

fun_grad = jax.grad(fun, argnums=(0, 1))

x = [jax.numpy.arange(0, 5).astype(jax.numpy.float32), jax.numpy.arange(1, 6).astype(jax.numpy.float32),]
y = [jax.numpy.arange(1, 6), jax.numpy.arange(2, 7),]

print( fun_grad(x, y) )

报错:

image


通过设置 allow_int 实现对整数类型求导:

import jax

def fun(x, y):
    print(x, y)
    return jax.numpy.sum(2*x[0] + y[0] + 2*x[1] + y[1])

fun_grad = jax.grad(fun, argnums=(0, 1), allow_int=True)

x = [jax.numpy.arange(0, 5).astype(jax.numpy.float32), jax.numpy.arange(1, 6).astype(jax.numpy.float32),]
y = [jax.numpy.arange(1, 6), jax.numpy.arange(2, 7),]

print( fun_grad(x, y) )

未报错运行,但是没有获得争取结果:

image


应该这么说,在jax中不能对整数类型求导的,虽然这里设置了 allow_int 但是也不能得到正确的对整数类型的求导。



posted on 2024-01-19 21:12  Angry_Panda  阅读(50)  评论(0)    收藏  举报

导航