GPT

In 2018, OpenAI introduced GPT, a language model built on the Transformer decoder architecture. Five years later, it would take the world by storm...


Model Definition

import torch as th
import torch.nn as nn

# Character vocabulary and index mappings
vocab = list("abcdefghijklmnopqrstuvwxyz ,.!?")  # simple character-level vocabulary
stoi = {ch: i for i, ch in enumerate(vocab)}
itos = {i: ch for i, ch in enumerate(vocab)}

# Hyperparameters
vocab_size = len(vocab)
embed_dim = 32
num_heads = 2
num_layers = 2
block_size = 16  # maximum input (context) length

class SimpleGPT(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(block_size, embed_dim)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True),
            num_layers=num_layers
        )
        self.ln = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)

    def forward(self, idx):
        N = idx.shape[1]
        token_emb = self.token_emb(idx)  # (B, N, D)
        pos = th.arange(N, device=idx.device)
        pos_emb = self.pos_emb(pos)      # (N, D)
        x = token_emb + pos_emb          # (B, N, D)
        mask = nn.Transformer.generate_square_subsequent_mask(N).to(idx.device)  # causal mask on the same device as the input
        x = self.transformer(x, mask=mask)
        x = self.ln(x)
        logits = self.head(x)            # (B, N, vocab_size)
        return logits
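
Although the model uses nn.TransformerEncoder, the causal mask generated in forward is what makes it behave like a GPT-style decoder: each position can only attend to itself and earlier positions. A minimal sanity check (a throwaway instance, assuming the class above has been defined) shows the mask and the output shape:

# Throwaway sanity check: inspect the causal mask and the logits shape
check_model = SimpleGPT()
dummy = th.randint(0, vocab_size, (4, block_size))   # (B, N) batch of random token ids
print(nn.Transformer.generate_square_subsequent_mask(4))
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
print(check_model(dummy).shape)                       # torch.Size([4, 16, 31])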

Constructing the Data

# Hyperparameters
batch_size = 8
num_epochs = 500
lr = 1e-2

# Build the training data
text = "hello world. hello world. hello world."
data = [stoi[c] for c in text]

def get_batch():
    # Randomly sample a batch of (input, target) sequences
    ix = th.randint(0, len(data) - block_size - 1, (batch_size,))
    x = th.stack([th.tensor(data[i:i+block_size]) for i in ix])
    y = th.stack([th.tensor(data[i+1:i+block_size+1]) for i in ix])
    return x, y
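
Each target sequence y is just the input x shifted one character to the right, so the model learns to predict the next character at every position. Decoding one sampled pair back to text makes this easy to verify (the exact strings vary because the start index is random):

# Quick check: y is x shifted one character to the right
xb, yb = get_batch()
print(''.join(itos[i] for i in xb[0].tolist()))  # e.g. "lo world. hello "
print(''.join(itos[i] for i in yb[0].tolist()))  # e.g. "o world. hello w"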

Training

import torch.nn.functional as F

# Instantiate the model and optimizer
device = th.device("cuda" if th.cuda.is_available() else "cpu")
model = SimpleGPT().to(device)
optimizer = th.optim.Adam(model.parameters(), lr=lr)

# Training loop
for epoch in range(num_epochs):
    model.train()
    x, y = get_batch()
    x, y = x.to(device), y.to(device)
    logits = model(x)
    loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (epoch+1) % 50 == 0 or epoch == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}")

Generation

# Inference: given a prefix, generate the following characters
def generate(model, prefix, max_new_tokens=20):
    model.eval()
    idx = th.tensor([[stoi[c] for c in prefix]], dtype=th.long, device=device)
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -block_size:]  # keep only the last block_size tokens
        logits = model(idx_cond)
        next_id = th.argmax(logits[0, -1], dim=-1, keepdim=True)  # greedy: pick the most likely next character
        idx = th.cat([idx, next_id.view(1, 1)], dim=1)
    return ''.join([itos[i] for i in idx[0].tolist()])

output = generate(model, "hello", max_new_tokens=20)
print(output)
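
The generate function above is greedy: it always picks the single most likely character, so the same prefix always produces the same continuation. A common variant (sketched here as an addition, not part of the original code) samples from the softmax distribution with a temperature parameter, trading determinism for diversity:

# Hypothetical variant: sample from the distribution instead of taking argmax.
# temperature < 1 sharpens the distribution, temperature > 1 flattens it.
def generate_sampled(model, prefix, max_new_tokens=20, temperature=1.0):
    model.eval()
    idx = th.tensor([[stoi[c] for c in prefix]], dtype=th.long, device=device)
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -block_size:]
        logits = model(idx_cond)
        probs = F.softmax(logits[0, -1] / temperature, dim=-1)
        next_id = th.multinomial(probs, num_samples=1)
        idx = th.cat([idx, next_id.view(1, 1)], dim=1)
    return ''.join(itos[i] for i in idx[0].tolist())

print(generate_sampled(model, "hello", temperature=0.8))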

Run it:

$ python main.py
hello world. hello world.

As you can see, SimpleGPT has successfully learned to say "hello world"!

