[Code] TF Code Snippets

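  • tf.switch_case is available from TensorFlow 1.15 onward.
  • Do not use lambda functions inside fast_net, otherwise the results are inconsistent. A lambda created in the list comprehension captures the loop variable i by reference, so every branch ends up calling branch_fn with the last element of branch_names. functools.partial binds the current value of i when it is created.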
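    import functools

    import tensorflow as tf

    # branch_names (list of branch names), branch_fn(name) -> [B, 1] tensor, and
    # idx (int tensor of shape [B, 1], each example's branch position) come from
    # the surrounding model code, which this snippet does not show.

    def fast_net(branch_index, branch_names=branch_names):
        # Run only the branch at position branch_index.
        # functools.partial binds the current i; a lambda here would late-bind it.
        branch_fns = [(tf.equal(branch_index, k), functools.partial(branch_fn, i))
                      for k, i in enumerate(branch_names)]
        return tf.case(branch_fns)

    def train_net(branch_names=branch_names):
        # Run every branch, then keep each example's own output via a one-hot mask.
        output_list = [branch_fn(i) for i in branch_names]           # N tensors of [B, 1]
        idx_one_hot = tf.one_hot(idx, len(branch_names), axis=1)    # [B, N, 1]
        output = idx_one_hot * tf.transpose(output_list, perm=[1, 0, 2])  # [B, N, 1]
        logits = tf.reduce_sum(output, axis=1)                       # [B, 1]
        return logits

    # Several branches in the batch: run all of them. Only one: run just that branch.
    J, _ = tf.unique(idx[:, 0])
    logits = tf.cond(tf.size(J) > 1, train_net, lambda: fast_net(J[0]))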
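The first note refers to tf.switch_case, which picks a branch by integer position instead of by a list of predicates as tf.case does. Below is a minimal sketch of fast_net rewritten with it, assuming TF 2.x eager execution (under TF 1.15 graph mode the result would be a symbolic tensor evaluated in a session). `branch_fn`, `fast_net_switch` and the four branch names are hypothetical placeholders; each branch returns a constant [B, 1] tensor so the chosen branch is visible in the output.

```python
import functools

import tensorflow as tf

# Hypothetical branch names; the real model defines its own.
branch_names = ["a", "b", "c", "d"]


def branch_fn(name, batch_size=3):
    # Hypothetical branch: a constant [batch_size, 1] tensor holding the
    # branch's position, so the output shows which branch ran.
    return tf.fill([batch_size, 1], float(branch_names.index(name)))


def fast_net_switch(branch_index, branch_names=branch_names):
    # tf.switch_case runs branch_fns[branch_index]; branch_index must be an
    # integer scalar tensor. partial binds each name now, avoiding the lambda issue.
    branch_fns = [functools.partial(branch_fn, i) for i in branch_names]
    return tf.switch_case(branch_index, branch_fns=branch_fns)


# Select the third branch ("c"); every row of the result holds 2.0.
out = fast_net_switch(tf.constant(2, dtype=tf.int32))
```

Because the index maps straight to a list position, the explicit tf.equal predicates and the enumerate over branch_names are no longer needed.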
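The lambda pitfall in the second note is ordinary Python late binding, so it can be reproduced without TensorFlow. In this minimal sketch `branch_fn` is a hypothetical stand-in that only reports which branch it was called for.

```python
import functools


def branch_fn(name):
    # Hypothetical stand-in for a real branch: report which branch ran.
    return "output of " + name


branch_names = ["a", "b", "c"]

# Each lambda closes over the variable i, not its current value; once the
# comprehension has finished, i refers to the last name for every lambda.
lambda_fns = [lambda: branch_fn(i) for i in branch_names]

# functools.partial stores the value of i at the moment it is created.
partial_fns = [functools.partial(branch_fn, i) for i in branch_names]

late_bound = [fn() for fn in lambda_fns]    # every entry is "output of c"
early_bound = [fn() for fn in partial_fns]  # "output of a", "output of b", "output of c"
```

Giving the lambda a default argument (`lambda i=i: branch_fn(i)`) also captures the value at creation time; the original snippet uses `functools.partial` instead.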
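train_net evaluates every branch and keeps each example's own output through a one-hot mask, while the tf.cond falls back to a single branch when the whole batch shares one index. This NumPy sketch mirrors both paths to check that they agree; the branch outputs are random placeholder arrays, and the sizes B and N and the helpers `masked_select` and `dispatch` are illustrative assumptions, not part of the original model.

```python
import numpy as np

rng = np.random.default_rng(0)
B, N = 5, 4  # illustrative batch size and number of branches

# Placeholder branch outputs: N arrays of shape [B, 1], playing the role of output_list.
output_list = [rng.normal(size=(B, 1)) for _ in range(N)]


def masked_select(idx):
    """Mirror of train_net: run all branches, keep each row's own branch."""
    one_hot = np.eye(N)[idx[:, 0]][:, :, None]                # [B, N, 1], like tf.one_hot(idx, N, axis=1)
    stacked = np.transpose(np.array(output_list), (1, 0, 2))  # [N, B, 1] -> [B, N, 1]
    return (one_hot * stacked).sum(axis=1)                    # [B, 1]


def dispatch(idx):
    """Mirror of the tf.cond: all branches if the batch is mixed, else one branch."""
    unique = np.unique(idx[:, 0])
    if unique.size > 1:
        return masked_select(idx)
    return output_list[unique[0]]                             # fast path: a single branch


mixed_idx = np.array([[0], [3], [1], [3], [2]])  # several branches in one batch
same_idx = np.full((B, 1), 2)                    # every example uses branch 2
```

In the TF graph the saving comes from tf.cond: only the branch that is taken executes at run time, so a single-branch batch skips the other N - 1 sub-networks.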
posted @ 2022-07-10 10:33 bregman