Live2D

Note -「Spiking Neural Network」SNN 光速入门

\[\mathscr{Lorain~wy~Lora~blea.} \newcommand{\DS}[0]{\displaystyle} % operators alias \newcommand{\opn}[1]{\operatorname{#1}} \newcommand{\card}[0]{\opn{card}} \newcommand{\lcm}[0]{\opn{lcm}} \newcommand{\char}[0]{\opn{char}} \newcommand{\Char}[0]{\opn{Char}} \newcommand{\Min}[0]{\opn{Min}} \newcommand{\rank}[0]{\opn{rank}} \newcommand{\Hom}[0]{\opn{Hom}} \newcommand{\End}[0]{\opn{End}} \newcommand{\im}[0]{\opn{im}} \newcommand{\tr}[0]{\opn{tr}} \newcommand{\diag}[0]{\opn{diag}} \newcommand{\coker}[0]{\opn{coker}} \newcommand{\id}[0]{\opn{id}} \newcommand{\sgn}[0]{\opn{sgn}} \newcommand{\Res}[0]{\opn{Res}} \newcommand{\Ad}[0]{\opn{Ad}} \newcommand{\ord}[0]{\opn{ord}} \newcommand{\Stab}[0]{\opn{Stab}} \newcommand{\conjeq}[0]{\sim_{\u{conj}}} \newcommand{\cent}[0]{\u{\degree C}} \newcommand{\Sym}[0]{\opn{Sym}} \newcommand{\Var}[0]{\opn{Var}} \newcommand{\wg}[0]{\wedge} \newcommand{\Wg}[0]{\bigwedge} \newcommand{\sq}[0]{\opn{\square}} % symbols alias \newcommand{\E}[0]{\exist} \newcommand{\A}[0]{\forall} \newcommand{\l}[0]{\left} \newcommand{\r}[0]{\right} \newcommand{\ox}[0]{\otimes} \newcommand{\lra}[0]{\leftrightarrow} \newcommand{\llra}[0]{\longleftrightarrow} \newcommand{\iso}[1]{\overset{\sim}{#1}} \newcommand{\eps}[0]{\varepsilon} \newcommand{\Ra}[0]{\Rightarrow} \newcommand{\Eq}[0]{\Leftrightarrow} \newcommand{\d}[0]{\mathrm{d}} \newcommand{\e}[0]{\mathrm{e}} \newcommand{\i}[0]{\mathrm{i}} \newcommand{\j}[0]{\mathrm{j}} \newcommand{\k}[0]{\mathrm{k}} \newcommand{\Ex}[0]{\mathbb{E}} \newcommand{\D}[0]{\mathbb{D}} \newcommand{\oo}[0]{\infty} \newcommand{\tto}[0]{\rightrightarrows} \newcommand{\mmap}[0]{\hookrightarrow} \newcommand{\emap}[0]{\twoheadrightarrow} \newcommand{\actl}[0]{\curvearrowright} \newcommand{\actr}[0]{\curvearrowleft} \newcommand{\nsubg}[0]{\triangleleft} \newcommand{\nsupg}[0]{\triangleright} \newcommand{\lin}[0]{\lim_{n\to\oo}} \newcommand{\linf}[0]{\liminf_{n\to\oo}} \newcommand{\lsup}[0]{\limsup_{n\to\oo}} \newcommand{\ser}[0]{\sum_{n=1}^\oo} \newcommand{\serz}[0]{\sum_{n=0}^\oo} \newcommand{\isoto}[0]{\overset\sim\to} \newcommand{\F}[0]{\mathbb F} \newcommand{\x}[0]{\times} \newcommand{\M}[0]{\mathbf{M}} \newcommand{\T}[0]{\intercal} \newcommand{\Co}[0]{\complement} \newcommand{\alp}[0]{\alpha} \newcommand{\lmd}[0]{\lambda} \newcommand{\mmid}[0]{\parallel} % symbols with parameters \newcommand{\der}[1]{\frac{\d}{\d #1}} \newcommand{\ul}[1]{\underline{#1}} \newcommand{\ol}[1]{\overline{#1}} \newcommand{\wt}[1]{\widetilde{#1}} \newcommand{\br}[1]{\l(#1\r)} \newcommand{\bk}[1]{\l[#1\r]} \newcommand{\ev}[1]{\l.#1\r|} \newcommand{\wh}[1]{\widehat{#1}} \newcommand{\eval}[1]{\l[\!\l[#1\r]\!\r]} \newcommand{\abs}[1]{\l|#1\r|} \newcommand{\bs}[1]{\boldsymbol{#1}} \newcommand{\dat}[1]{\bs{\mathrm{#1}}} \newcommand{\env}[2]{\begin{#1}#2\end{#1}} \newcommand{\ALI}[1]{\env{aligned}{#1}} \newcommand{\CAS}[1]{\env{cases}{#1}} \newcommand{\pmat}[1]{\env{pmatrix}{#1}} \newcommand{\dary}[2]{\l|\begin{array}{#1}#2\end{array}\r|} \newcommand{\pary}[2]{\l(\begin{array}{#1}#2\end{array}\r)} \newcommand{\pblk}[4]{\l(\begin{array}{c|c}{#1}&{#2}\\\hline{#3}&{#4}\end{array}\r)} \newcommand{\u}[1]{\mathrm{#1}} \newcommand{\t}[1]{\texttt{#1}} \newcommand{\lix}[1]{\lim_{x\to #1}} \newcommand{\ops}[1]{#1\cdots #1} \newcommand{\seq}[3]{{#1}_{#2}\ops,{#1}_{#3}} \newcommand{\dedu}[2]{\u{(#1)}\Ra\u{(#2)}} \newcommand{\prv}[3]{\DS{{\DS #1} \over {\DS #2}}~(#3)} \]

  本文直接参考自知乎专栏《白话脉冲神经网络》, 它是对 snntorch 教程文档 的机翻 (当然翻译得还行). 本文适合让你快速理解 SNN 是什么, 以及上手使用一些基础的轮子.

模拟神经元

  通过将磷脂双分子层视作电容 \(C\), 将离子通道视作电阻 \(R\), 我们可以把神经元建模为一个带有外部 (膜外) 电流输入 \(I_{\u{in}}\) 的 RC 振荡电路 (麻烦你脑补一下). 设 \(t\) 时刻外部电流为 \(I_{\u{in}}=I_{\u{in}}(t)\), 膜内外电势差为 \(U\) (它在大多数资料中被记为 \(U_{\u{mem}}\), 即 membrane 的电势差), 电容电量为 \(Q\), 那么使用一些高中物理知识易得

\[I_{\u{in}}=I_R+I_C,\\ U=I_R\cdot R,\\ Q=CU,\\ I_C=\frac{\d Q}{\d t}=C\frac{\d U}{\d t}. \]

以此刻画出电势差与输入电流的关系

\[I_{\u{in}}=\frac{U}{R}+C\frac{\d U}{\d t}\\ \Eq U+RC\cdot \dot U-RI_{\u{in}}=0. \]

我们知道在这个微分方程存在指数形式的解析解. 特别地, 如果保持 \(I_{\u{in}}=0\), 设电路的时间常数 \(\tau=RC\), 这时就有

\[U=U_0\e^{-\frac{t}{\tau}}. \]

如果 \(U_0=0\)\(I_{\u{in}}\)\(t_0\) 时刻从 \(0\) 阶跃到 \(I_0\), 这时就有

\[U=\CAS{ 0,&t\le t_0;\\ RI_0\br{1-\e^{-\frac{t-t_0}{\tau}}},&t>t_0. } \]

联系生活实际 (指高中生物做的一车题), 这很好地描述了膜电位先快后慢的回复过程.

  Anyway, 为了模拟神经元的行为, 我们更青睐对上述微分方程的离散近似. 设时间步为 \(\Delta t\), 使用前向 Euler 法模拟, 就有

\[\tau\frac{U(t+\Delta t)-U(t)}{\Delta t}=RI_{\u{in}}(t+\Delta t)-U(t)\\ \Ra U(t+\Delta t)=\frac{\Delta t}{\tau}RI_{\u{in}}(t+\Delta t)+\br{1-\frac{\Delta t}{\tau}}U(t). \]

(作为参考, 对于生物神经元, 一个可能的数据为 \(R=50\u{M\Omega}\), \(C=100\u{pF}\), 则 \(\tau=5\u{ms}\).)

  接下来, 我们需要建模神经元最关键的行为: 兴奋. 我们探测膜内外电势差 \(U\), 如果它超过了取定的阈值 \(\vartheta\), 神经元在当前时刻输出脉冲信号 \(S=1\), 否则输出 \(S=0\). 紧接着, 我们需要模拟神经元超极化的过程: 当神经元产生兴奋, 它会顷刻开启大量粒子通道, 使得膜电位迅速回落并最终恢复静息电位. 最暴力的办法自然是

\[S(t)=[U(t)>\vartheta],\\ U(t+\Delta t)=\frac{\Delta t}{\tau}RI_{\u{in}}(t+\Delta t)+\br{1-\frac{\Delta t}{\tau}}U(t)-\vartheta S(t). \]

这样, 当输入恒定的且足以使神经元兴奋的电流 \(I_0\) 时, 神经元会反复地累积电势差, 兴奋并释放脉冲信号, 回复, 如此循环.

  我们差不多就 sketch 出了 snntorch.Lapicque 这个轮子.

  作为补充, 我们还可以模拟二阶神经元, 它将 \(I_{\u{in}}\) 细致描述为另一个 leak & integration 过程. 即

\[I_{\u{syn}}(t+\Delta t)=\frac{\Delta t}{\tau'}I_{\u{in}}(t+\Delta t)+\br{1-\frac{\Delta t}{\tau'}}I_{\u{syn}}(t). \]

之后用 \(I_{\u{syn}}\) 替换 \(U\) 中的 \(I_{\u{in}}\) 即可. 经验上, 二阶神经元能够更好地编码时序或长时间尺度的信息.

神经元到网络

  轮子看上去很 "物理", 但用轮子的时候就可以开始砍掉冗余参数了. 设 \(R=1\), \(\Delta t=1\), 这样

\[\beta:=\ev{\frac{U(t+\Delta t)}{U(t)}}_{I_{\u{in}}=0}=1-\frac{\Delta t}{\tau}=1-\frac{1}{C}. \]

它被称为逆时间常数. 我们把 \(\beta\) 作为新的超参数改写 Euler 法的式子 (因为 \(\Delta t=1\) 已经充分离散, 我们此后不加说明地将 \(U\) 等物理量视为关于 \(t\in\Z_{\ge 0}\) 的列表):

\[U_{t+1}=\beta U_t+(1-\beta){I_{\u{in}~t+1}}-\vartheta S_{t}. \]

对于输入电流 \(I_{\u{in}}\) 的缩放系数 \(1-\beta\), 我们将它作为突触间的可学习权值. 设向量 \(\dat x_{t+1}\) 表示此神经元的前神经元输出, 向量 \(\dat w\) 描述该神经元接收前神经元的电流缩放参数, 我们得到

\[S_{t}=[U_t>\vartheta],\\ U_{t+1}=\beta U_t+\dat w^\T\dat x_{t+1}-\vartheta S_{t}. \]

其中 \(\beta\) 是超参数, \(\dat w\) 是可学习参数 (可以认为它吸收了 \(1-\beta\) 这个常数), \(\vartheta\) 一般也被设置为 \(1\).

  我们又得到了一个叫做 snntorch.Leaky 的轮子, 它确实只需要设置 \(\beta\), 并贴心地让 \(\vartheta\) 默认为 \(1\).

  如果把 \(\dat w^\T\dat x_{t+1}\) 丢到 FC 里, 这样的神经元就可以在神经网络中被抽象为一个与时间步 \(t\) 相关的激活函数 \(f_t(\cdot)\). 这样, 我们就可以搭一个前馈脉冲神经网络啦!

# 初始化
fc1 = nn.Linear(num_inputs, num_hidden)
lif1 = snn.Leaky(beta=beta)
fc2 = nn.Linear(num_hidden, num_outputs)
lif2 = snn.Leaky(beta=beta)

mem1 = lif1.init_leaky()
mem2 = lif2.init_leaky()

# 模拟
for step in range(num_steps):
    cur1 = fc1(spk_in[step])
    spk1, mem1 = lif1(cur1, mem1)
    cur2 = fc2(spk1)
    spk2, mem2 = lif2(cur2, mem2)

代理梯度

  聪明的小朋友就要问了, 你这 \(S_{t+1}=[U_t>\vartheta]\) 不是直接把梯度打烂了吗? 具体来说, 我们希望计算

\[\frac{\part\mathcal L}{\part U_t}=\frac{\part L}{\part S_{t}}\cdot\frac{\part S_{t}}{\part U_t}, \]

但后面这一项令人两眼一黑. 这该怎么办呢?

  我们肯定需要采取 代理梯度 的方案, 即, 正向传播时执行 \(S_{t}=[U_t>\vartheta]\), 但在方向传播时将后一项梯度代替为 \(\frac{\part \wt S_{t}}{\part U_t}\), 只要构造一个合理的可以传递梯度信息的 \(\wt S_{t}\), 我们就能进行反向传播了.

  如何构造呢? 我们直接令 \(\frac{\part \wt S_{t}}{\part U_t}=S_{t}\) (就像是认为 \(\wt S_{t}=\opn{ReLU}(U_t-\vartheta)\)), 这种简单的近似被称为脉冲运算符 (spike-operator) 方法, 它在物理意义上做了如下近似假设:

  • 如果神经元静默, \(S_{t}=0\), 自然就是 \(S_{t}\gets U_t\x 0\);
  • 如果神经元兴奋, \(S_{t}=1\), 假设 \(U_t\approx\vartheta\) (你看, 这听上去真的很合理), 而我们已经控制 \(\vartheta=1\), 这样 \(S_{t}\gets U_t\x 1\).

综合以上二者就得到 \(\frac{\part\wt S_{t}}{\part U_t}\), 合情合理. snntorch 的神经元都采用了代理梯度.

时间轴上反向传播 (BPTT)

  如果你会 BP 你就会 BPTT, 因为 BPTT 就是 BP T T.

  我们在 Leaky 神经元上操练一下 BPTT. 上游即时梯度 \(\frac{\part\mathcal L_t}{\part S_t}\) 对所有 \(t\) 已知, 我们希望最小化总损失 \(\mathcal L:=\sum_t\mathcal L_t\), 那么对于 Leaky 的可学习参数 \(\dat w\) 有:

\[\ALI{ \frac{\part\mathcal L}{\part \dat w} &= \sum_{t}\sum_{s\le t}\frac{\part\mathcal L_t}{\part\dat w_s}\underbrace{\frac{\part\dat w_s}{\part\dat w_t}}_{=1}\\ &= \sum_{t}\sum_{s\le t}\frac{\part\mathcal L_t}{\part\dat w_s}. } \]

\(\frac{\part\mathcal L_t}{\part \dat w_{t-1}}\) 为例, 它又能以链式法则展开为

\[\ALI{ \frac{\part\mathcal L_t}{\part\dat w_{t-1}} &\approx \frac{\part\mathcal L_t}{\part S_t}\x\frac{\part\wt S_t}{\part U_{t}}\x\frac{\part U_{t}}{\part U_{t-1}}\x\frac{\part U_{t-1}}{\part\dat w_{t-1}}\\ &= \frac{\part \mathcal L_t}{\part S_t}\x S_t\x\beta\x \dat x_{t-1}\\ &= \beta S_t\cdot\frac{\part\mathcal L_t}{\part S_t}\x\dat x_{t-1}. } \]

  Anyway, 伟大的自动梯度会解决这些计算.

MNIST

  对于输出编码和损失计算, 这里采用最简易的模式: 将放电频率最高的输出神经元作为预测类别, 并对输出神经元的膜电位应用 Softmax 和 Cross Entropy Loss.

import torch
from torchvision import datasets, transforms
import snntorch as snn

batch_size = 128
transform = transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0,), (1,))
])
mnist_train = datasets.MNIST(
    root='./data', train=True, transform=transform, download=True
)
mnist_test = datasets.MNIST(
    root='./data', train=False, transform=transform, download=True
)
train_loader = torch.utils.data.DataLoader(
    dataset=mnist_train, batch_size=batch_size, shuffle=True, drop_last=True
)
test_loader = torch.utils.data.DataLoader(
    dataset=mnist_test, batch_size=batch_size, shuffle=False, drop_last=True
)

num_inputs = 28 * 28
num_hidden = 1000
num_outputs = 10
num_steps = 25 # number of time steps
beta = 0.95 # inverse time constant
num_epochs = 5

class SNNModel(torch.nn.Module):
    def __init__(self):
        super(SNNModel, self).__init__()
        self.fc1 = torch.nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = torch.nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        spk2_rec, mem2_rec = [], []
        
        for step in range(num_steps):
            x_t = x.view(batch_size, -1)
            cur1 = self.fc1(x_t)
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)
        
        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SNNModel().to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.999))

for epoch in range(num_epochs):
    epoch_correct = 0
    epoch_total = 0
    for data, targets in train_loader:
        data, targets = data.to(device), targets.to(device)
        model.train()
        spk_rec, mem_rec = model(data.view(batch_size, -1))

        with torch.no_grad():
            epoch_total += targets.size(0)
            frequencies = spk_rec.sum(dim=0)
            _, predicted = torch.max(frequencies.data, 1)
            epoch_correct += (predicted == targets).sum().item()

        total_loss = torch.zeros(1).to(device)
        for step in range(num_steps):
            loss = criterion(mem_rec[step], targets)
            total_loss += loss
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
    print(f'Epoch [{epoch+1}/{num_epochs}], '
          f'Loss: {total_loss.item()/num_steps:.4f}, '
          f'Accuracy: {100 * epoch_correct / epoch_total:.2f}%')

with torch.no_grad():
    model.eval()
    correct = 0
    total = 0
    for data, targets in test_loader:
        data, targets = data.to(device), targets.to(device)
        spk_rec, mem_rec = model(data.view(batch_size, -1))
        outputs = mem_rec[-1]
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()
    print(f'Test Accuracy: {100 * correct / total:.2f}%')

  拿下 \(97.57\%\) 的测试集准确率. 当然, 更好的诸如种群编码的计数也值得一提, 还有 snntorch.backprop.BPTTsnntorch.Functional.* 等方便的轮子值得尝试, 不过它们是比较直观的, 留作课后读物叭.

posted @ 2026-02-04 16:19  Rainybunny  阅读(0)  评论(0)    收藏  举报