【代码】TF代码片段
- tf.switch_case：TF 1.15 及以后版本才可用，可按整数下标直接选择分支，可作为 fast_net 中 tf.case 写法的替代。
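- fast_net 中不能用 lambda（如 `lambda: branch_fn(i)`）构造分支函数，否则结果不一致：Python 闭包是延迟绑定的，各个 lambda 可能都取到循环中最后一个 i。应改用 functools.partial 在构造时绑定参数。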
def fast_net(branch_index, branch_names=branch_names):
    """Build a graph that runs only the branch selected by ``branch_index``.

    Each position k in ``branch_names`` is paired with the predicate
    ``branch_index == k``; ``tf.case`` then dispatches to the matching branch.

    Args:
        branch_index: scalar tensor holding the position of the branch to run
            (presumably an integer branch id taken from ``idx`` — TODO confirm).
        branch_names: sequence of branch identifiers; the k-th one is passed
            to ``branch_fn`` when ``branch_index == k``.

    Returns:
        The output of ``branch_fn`` for the selected branch, as produced by
        ``tf.case``.
    """
    from functools import partial

    pred_fn_pairs = []
    for position, name in enumerate(branch_names):
        # Bind ``name`` eagerly with partial. A ``lambda: branch_fn(name)``
        # would close over the loop variable (Python late binding), so every
        # branch could end up using the last name -- the likely reason a
        # lambda gives inconsistent results here.
        is_selected = tf.equal(branch_index, position)
        pred_fn_pairs.append((is_selected, partial(branch_fn, name)))
    # NOTE(review): tf.switch_case (TF >= 1.15) selects a branch by integer
    # index directly and could replace this equality chain.
    return tf.case(pred_fn_pairs)
def train_net(branch_names=branch_names):
    """Evaluate every branch and keep each row's output from its own branch.

    All branches are built and run on the whole batch. A one-hot mask derived
    from the module-level ``idx`` zeroes every branch except the selected one
    for each row, and the branch axis is then summed away.

    Args:
        branch_names: sequence of branch identifiers, each passed to
            ``branch_fn``; its length N is the one-hot depth.

    Returns:
        The masked branch outputs reduced over axis 1 (presumably [B, 1]
        when each branch returns [B, 1] — TODO confirm).
    """
    # One output per branch; presumably each is [B, 1] (TODO confirm).
    per_branch = [branch_fn(name) for name in branch_names]
    # With axis=1 and idx assumed to be [B, 1], the mask is [B, N, 1]
    # where N = len(branch_names).
    routing_mask = tf.one_hot(idx, len(branch_names), axis=1)
    # The list packs to [N, B, 1]; move the branch axis next to the batch
    # axis so it lines up with the mask: [B, N, 1].
    batch_major = tf.transpose(per_branch, perm=[1, 0, 2])
    selected = routing_mask * batch_major
    return tf.reduce_sum(selected, axis=1)  # [B, 1]
# Distinct branch ids present in this batch. idx is indexed as idx[:, 0] here
# and fed to tf.one_hot in train_net; presumably it is an integer [B, 1]
# tensor of branch ids -- TODO confirm against the caller.
J, _ = tf.unique(idx[:, 0])
# Take the cheap path (build/run only one branch via fast_net) when the whole
# batch maps to exactly one branch id; otherwise run every branch and mask
# (train_net). The predicate is "size != 1" rather than "size > 1" so that an
# empty batch (size 0) goes to train_net instead of evaluating J[0] on an
# empty tensor, which fails at runtime (assuming branch_fn tolerates an empty
# batch). tf.not_equal is used because `!=` on a TF1 graph-mode Tensor
# compares Python object identity instead of building a comparison op.
logits = tf.cond(tf.not_equal(tf.size(J), 1), train_net, lambda: fast_net(J[0]))
--- 她说，她是仙，她不是神