Study Diary 2025.6.5
Large Language Models
Flash Attention
FlashAttention is an optimization built around HBM and SRAM: the goal is to reduce reads and writes to HBM, which speeds up the attention computation. The core technique is computing softmax in blocks (tiling).
HBM is large but slow to read/write; SRAM is small but fast.
A naive softmax easily overflows (e^x grows exponentially), so safe softmax is introduced: subtract the row maximum before exponentiating, which keeps every exponent at or below zero.
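Below is a tiny sketch I wrote to convince myself why safe softmax helps (just an illustration in PyTorch, not the FlashAttention kernel): subtracting the row max makes every exponent non-positive, so exp() stays in [0, 1] and cannot overflow.

import torch

def naive_softmax(x):
    # exp() overflows for large logits: exp(1000) is inf in float32, and inf/inf gives nan
    e = torch.exp(x)
    return e / e.sum(dim=-1, keepdim=True)

def safe_softmax(x):
    # subtract the row max first: exponents are <= 0, so exp() stays in [0, 1]
    m = x.max(dim=-1, keepdim=True).values
    e = torch.exp(x - m)
    return e / e.sum(dim=-1, keepdim=True)

x = torch.tensor([[1000.0, 999.0, 998.0]])
print(naive_softmax(x))  # nan everywhere: overflow
print(safe_softmax(x))   # finite and sums to 1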
How is the blocked computation actually done? I gave up on deriving it properly; a rough sketch of the online-softmax idea is below.
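Still, here is a toy single-query version of the online-softmax trick that FlashAttention builds on (my own sketch, not the real tiled kernel): walk over the keys/values in blocks while keeping a running max m, a running normalizer l, and an unnormalized accumulator of V, and rescale all three whenever a new block raises the max.

import torch

def online_softmax_attention_row(q, K, V, block=4):
    # One query row against all keys/values, processed block by block.
    # Running stats: m = max score seen so far, l = sum of exp(score - m),
    # acc = unnormalized weighted sum of V rows under the current m.
    d = q.shape[0]
    m = torch.tensor(float("-inf"))
    l = torch.tensor(0.0)
    acc = torch.zeros(d)
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        s = Kb @ q / d ** 0.5                     # scores for this block
        m_new = torch.maximum(m, s.max())
        scale = torch.exp(m - m_new)              # rescale old stats to the new max
        p = torch.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ Vb
        m = m_new
    return acc / l                                # == softmax(K @ q / sqrt(d)) @ V

# Check the blocked result against the direct computation.
torch.manual_seed(0)
K, V, q = torch.randn(10, 8), torch.randn(10, 8), torch.randn(8)
ref = torch.softmax(K @ q / 8 ** 0.5, dim=0) @ V
print(torch.allclose(online_softmax_attention_row(q, K, V), ref, atol=1e-6))

The rescaling factor exp(m_old - m_new) is what lets partial results from earlier blocks be corrected after the fact, so the full row of scores never has to sit in memory at once.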
Paged Attention
PagedAttention is inspired by the classic ideas of virtual memory and paging in operating systems: it allows logically contiguous KV tensors to be stored in non-contiguous memory. Concretely, PagedAttention partitions each sequence's KV cache into blocks, each holding a fixed number of tokens, and the attention computation can efficiently locate and fetch those blocks.
Each fixed-size block can be viewed as a page in virtual memory, a token as a byte, and a sequence as a process. A block table then maps contiguous logical blocks to non-contiguous physical blocks, and physical blocks are allocated on demand as new tokens are generated.
After a sequence is split into blocks this way, only its last block can waste memory (in practice less than 4% is wasted). The benefit of using memory efficiently is clear: the system can fit more sequences into a batch, raising GPU utilization and significantly improving throughput.
Another benefit of PagedAttention is efficient memory sharing. For example, in parallel sampling a single prompt produces multiple output sequences; in that case the computation and memory for the prompt can be shared across those output sequences.
Memory sharing falls out naturally from the block table. Just as processes share physical pages, different sequences in PagedAttention can share a block by mapping their logical blocks to the same physical block. To make sharing safe, PagedAttention tracks reference counts on physical blocks and implements a copy-on-write mechanism. Memory sharing reduces memory usage by 55%, greatly lowering the memory overhead of such sampling methods while improving throughput by up to 2.2x.
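To make the block-table idea concrete for myself, here is a toy sketch (my own simplification; the names and data structures are made up, not vLLM's actual API): each sequence's block table maps logical blocks to physical block ids, physical blocks carry a reference count, and writing into a shared block triggers copy-on-write.

# Toy block table: each sequence holds a list of physical block ids (logical -> physical).
# Physical blocks store up to BLOCK_SIZE tokens (standing in for KV vectors) plus a ref count.
BLOCK_SIZE = 4
physical_blocks = {}   # block_id -> {"tokens": [...], "ref": int}
next_id = 0

def alloc():
    global next_id
    physical_blocks[next_id] = {"tokens": [], "ref": 1}
    next_id += 1
    return next_id - 1

def fork(block_table):
    # Share all physical blocks with another sequence (e.g. parallel sampling from one prompt).
    for b in block_table:
        physical_blocks[b]["ref"] += 1
    return list(block_table)

def append_token(block_table, token):
    # Allocate a new physical block when the last one is full (or the table is empty).
    if not block_table or len(physical_blocks[block_table[-1]]["tokens"]) == BLOCK_SIZE:
        block_table.append(alloc())
    last = block_table[-1]
    # Copy-on-write: if the last block is shared, copy it before writing.
    if physical_blocks[last]["ref"] > 1:
        physical_blocks[last]["ref"] -= 1
        new = alloc()
        physical_blocks[new]["tokens"] = list(physical_blocks[last]["tokens"])
        block_table[-1] = new
        last = new
    physical_blocks[last]["tokens"].append(token)

# One prompt, two sampled continuations sharing the prompt's block.
seq_a = []
for t in ["I", "love"]:
    append_token(seq_a, t)
seq_b = fork(seq_a)          # shares the prompt block, ref count = 2
append_token(seq_b, "AI")    # copy-on-write: seq_b gets its own copy of the last block
print(seq_a, seq_b, physical_blocks)

After the fork, both sequences point at the same prompt block; only when seq_b appends a token does it get its own copy, which is exactly the copy-on-write behavior described above.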
Transformer (a minimal GPT implementation)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# ========== 1. Causal Self Attention ==========
class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, n_head):
        super().__init__()
        assert d_model % n_head == 0
        self.d_head = d_model // n_head
        self.n_head = n_head
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
    def forward(self, x):
        B, T, C = x.size()
        qkv = self.qkv(x).reshape(B, T, 3, self.n_head, self.d_head).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # [B, n_head, T, d_head]
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_head)
        mask = torch.tril(torch.ones(T, T, device=x.device))
        scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        out = attn @ v  # [B, n_head, T, d_head]
        out = out.transpose(1, 2).contiguous().reshape(B, T, C)
        return self.proj(out)
# ========== 2. FeedForward ==========
class FeedForward(nn.Module):
    def __init__(self, d_model, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, d_model)
        )
    def forward(self, x):
        return self.net(x)
# ========== 3. Transformer Block ==========
class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_head, ff_hidden):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_head)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = FeedForward(d_model, ff_hidden)
    def forward(self, x):
        x = x + self.attn(self.ln1(x))  # pre-LN
        x = x + self.ff(self.ln2(x))
        return x
# ========== 4. GPT Model ==========
class GPT(nn.Module):
    def __init__(self, vocab_size, block_size, n_layer=4, d_model=128, n_head=4, ff_hidden=512):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Parameter(torch.zeros(1, block_size, d_model))
        self.blocks = nn.Sequential(*[
            TransformerBlock(d_model, n_head, ff_hidden) for _ in range(n_layer)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)
        self.block_size = block_size
        self.apply(self._init_weights)
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
    def forward(self, idx):
        B, T = idx.shape
        assert T <= self.block_size, f"Cannot handle sequence length {T} > block size {self.block_size}"
        tok_emb = self.token_emb(idx)                  # [B, T, C]
        x = tok_emb + self.pos_emb[:, :T, :]           # [B, T, C]
        x = self.blocks(x)                             # Transformer Blocks
        x = self.ln_f(x)                               # Final LayerNorm
        logits = self.head(x)                          # [B, T, vocab_size]
        return logits
    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_crop = idx[:, -self.block_size:]
            logits = self(idx_crop)  # [B, T, vocab_size]
            next_token_logits = logits[:, -1, :]  # take only the prediction for the last token
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)  # greedy decode
            idx = torch.cat((idx, next_token), dim=1)  # append the new token to the sequence
        return idx
# Assume a toy vocab: {'I': 0, 'love': 1, 'AI': 2, '<pad>': 3, '<unk>': 4}
vocab = {"I": 0, "love": 1, "AI": 2, "<pad>": 3, "<unk>": 4}
ivocab = {v: k for k, v in vocab.items()}
# Initialize the model
model = GPT(vocab_size=len(vocab), block_size=8)
model.eval()
# Input sequence: "I love"
input_ids = torch.tensor([[vocab["I"], vocab["love"]]])  # shape: [1, 2]
# Generate 3 new tokens
output_ids = model.generate(input_ids, max_new_tokens=3)
print("Generated token ids:", output_ids)
# Decode ids back to words
generated_text = [ivocab.get(i.item(), "<unk>") for i in output_ids[0]]
print("Generated text:", " ".join(generated_text))