计算梯度jax

# %%
import jax        
import jax.numpy as jnp

import numpy as np

def loss(params, r):
    """Return the largest hinge value ``max(r - lambda_a + lambda_s, 0)``.

    Args:
        params: pair ``(lambda_a, lambda_s)``; the two entries are combined
            elementwise with ``r``, so all three must broadcast together.
        r: scalar (or broadcastable array) offset.

    Returns:
        A scalar: the maximum over all entries of
        ``jnp.maximum(r - lambda_a + lambda_s, 0)``.
    """
    lambda_a, lambda_s = params
    return jnp.maximum(r - lambda_a + lambda_s, 0).max()


# Gradient of `loss` w.r.t. its first argument. Because that argument is the
# (lambda_a, lambda_s) tuple, the result is a matching (grad_a, grad_s) tuple.
# No call is made here: `params` and `r` are only defined in the test section
# further down, and evaluating it at this point raised NameError.
loss_grad = jax.grad(loss)

# %%

def jac(params, r):
    """Closed-form gradient of ``loss`` with respect to ``(lambda_a, lambda_s)``.

    With ``v = r - lambda_a + lambda_s``, the loss is ``max(max(v, 0))``. Its
    gradient flows only through the entries that attain a positive maximum of
    ``v``. Ties share the unit gradient equally. Since ``d v / d lambda_a = -1``
    and ``d v / d lambda_s = +1``, the two gradients differ only in sign.

    Args:
        params: pair ``(lambda_a, lambda_s)`` of arrays.
        r: scalar (or broadcastable array) offset.

    Returns:
        ``(grad_a, grad_s)`` with ``grad_a == -grad_s``. Both are all zeros
        when no entry of ``v`` is positive; the previous code divided 0 by 0
        there and returned NaN.
    """
    lambda_a, lambda_s = params
    v = np.asarray(r - lambda_a + lambda_s, dtype=float)
    # Mask of entries that are both positive and equal to the maximum.
    g_b = np.logical_and(v > 0, v == v.max()).astype(float)
    total = g_b.sum()
    # Only normalize when some entry is active, which avoids 0/0 -> NaN.
    # NOTE(review): when max(v) is exactly 0, jax.grad may assign a non-zero
    # gradient through its tie rule for jnp.maximum; this returns zeros there.
    if total > 0:
        g_b /= total
    return -g_b, g_b


# %% Test: compare jax.grad(loss) against the hand-written jac
print("data1")

# Seed the global NumPy RNG so repeated runs draw the same lambda_s and
# print the same output.
np.random.seed(0)
params = (np.array([-.50, .51, .51]), np.random.randn(3))
r = 1

grad_a, grad_s = loss_grad(params, r)
print(params)
print(grad_a, grad_s)
ga, gs = jac(params, r)
# jax.grad computes in float32 by default while jac works in float64, so an
# exact elementwise == can report spurious mismatches (e.g. a 1/3 tie split).
# Compare with a tolerance instead.
print(np.allclose(grad_a, ga), np.allclose(grad_s, gs))
posted @ 2023-05-13 13:26  bregman  阅读(23)  评论(0)  编辑  收藏  举报