cross_entropy_error函数的理解,和代码解释。

首先原理如下:

 

然后代码如下:

def cross_entropy_error(y, t):
    if y.ndim == 1:
        t = t.reshape(1, t.size)
        y = y.reshape(1, y.size)
        
    if t.size == y.size:
        t = t.argmax(axis=1)
             
    batch_size = y.shape[0]

    return -np.sum(np.log(y[np.arange(batch_size), t] + 1e-7)) / batch_size

下面举例说明:

通过一个具体的例子,详细展开 cross_entropy_error 函数的计算过程,逐行解释代码如何处理输入,并展示每一步的数值计算。

这样可以更直观地理解代码的逻辑和交叉熵损失的计算。

 

posted @ 2025-05-19 15:54  AlphaGeek  阅读(40)  评论(0)    收藏  举报