摘要: ```python # %% import jax import jax.numpy as jnp import numpy as np def loss(params, r): lambda_a, lambda_s = params return jnp.maximum(r - lambda_a 阅读全文
posted @ 2023-05-13 13:26 bregman 阅读(23) 评论(0) 推荐(0) 编辑