"""提供一个 BPE(Byte Pair Encoding)算法的实现。

BPE 类实现了完整的训练、编码和解码功能:
- train 方法:通过迭代合并最频繁的字符对来构建词汇表
- encode 方法:将文本转换为 BPE 子词序列
- decode 方法:将 BPE 子词序列恢复为文本

核心原理:
- 训练时,算法不断寻找并合并(按词频加权)最频繁的字符对,形成新的子词
- 编码时,根据训练好的合并规则(按学习顺序)将文本分解为子词序列
- 解码时,连接子词并替换特殊结束标记 </w> 来恢复文本

恢复操作的关键:
- 编码过程中会为每个单词末尾添加 </w> 标记
- 解码时将 </w> 标记替换为空格来区分不同单词
- 最后通过简单的字符串处理去除多余空格,得到流畅的文本

注意:预处理会把文本转为小写,并且只保留由字母、数字和下划线组成的单词,
因此解码结果不包含标点,也无法恢复原始大小写。
"""
import re
from collections import defaultdict
class BPE:
    """Byte Pair Encoding tokenizer with train / encode / decode.

    Text is lowercased and split into words with the regex ``\\w+``; every
    word is terminated with the end-of-word marker ``</w>``. Because of this
    pre-processing, ``decode`` returns the normalized (lowercase,
    punctuation-free) text rather than the exact original.
    """

    def __init__(self, vocab_size=100):
        """Create an untrained model.

        Args:
            vocab_size: Target vocabulary size. Training performs at most
                ``vocab_size - <number of base symbols>`` merges.
        """
        self.vocab_size = vocab_size
        # Learned vocabulary (a set): base symbols plus every merged token.
        self.vocab = set()
        # Merge rules: (left, right) symbol pair -> rank (the order in which
        # it was learned). encode() applies merges in this order.
        self.merges = {}

    def get_pairs(self, words):
        """Count adjacent symbol pairs.

        Args:
            words: Either a mapping of space-separated symbol strings to
                their frequencies (each pair is weighted by that frequency),
                or a plain iterable of such strings (each string counts once).

        Returns:
            A ``defaultdict(int)`` mapping ``(left, right)`` pairs to counts.
        """
        pairs = defaultdict(int)
        if hasattr(words, "items"):
            weighted = words.items()
        else:
            weighted = ((word, 1) for word in words)
        for word, freq in weighted:
            symbols = word.split()
            for left, right in zip(symbols, symbols[1:]):
                pairs[(left, right)] += freq
        return pairs

    def merge_vocab(self, pair, v_in):
        """Merge every standalone occurrence of ``pair`` in the vocabulary.

        Args:
            pair: ``(left, right)`` symbols to merge into one symbol.
            v_in: Mapping of space-separated symbol strings to frequencies.

        Returns:
            A new dict of merged symbol strings to their frequencies.
        """
        v_out = {}
        bigram = re.escape(' '.join(pair))
        # The lookarounds restrict matches to whole symbols, e.g. merging
        # ('a', 'b') must not touch "ca b".
        p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
        merged = ''.join(pair)
        for word, freq in v_in.items():
            # A callable replacement keeps the symbol text literal (no
            # backslash-escape processing by re.sub).
            w_out = p.sub(lambda _match: merged, word)
            v_out[w_out] = v_out.get(w_out, 0) + freq
        return v_out

    def train(self, text):
        """Learn merge rules from ``text``.

        Previously learned merges and vocabulary are discarded. Training stops
        early when no adjacent pairs remain.

        Args:
            text: Training corpus.
        """
        # Pre-process: lowercase, split into words, separate the characters
        # with spaces and append the end-of-word marker.
        words = re.findall(r'\w+', text.lower())
        words = [' '.join(word) + ' </w>' for word in words]
        # Symbol string -> number of occurrences in the corpus.
        vocab = defaultdict(int)
        for word in words:
            vocab[word] += 1
        # Base symbols are counted as symbols (not characters) so that the
        # multi-character marker '</w>' counts exactly once.
        base_symbols = {symbol for word in vocab for symbol in word.split()}
        self.merges = {}
        num_merges = self.vocab_size - len(base_symbols)
        for rank in range(num_merges):
            # Pass the mapping itself so pairs are weighted by word frequency.
            pairs = self.get_pairs(vocab)
            if not pairs:
                break
            best_pair = max(pairs, key=pairs.get)
            self.merges[best_pair] = rank
            vocab = self.merge_vocab(best_pair, vocab)
        # Every symbol encode() can produce from characters seen in training.
        self.vocab = base_symbols | {''.join(pair) for pair in self.merges}

    def encode(self, text):
        """Encode ``text`` into a flat list of BPE sub-word tokens.

        Each word becomes its characters plus '</w>'; learned merges are then
        applied repeatedly, always picking the applicable pair with the lowest
        rank (learned earliest), until none applies.

        Returns:
            List of token strings; the last token of each word ends with
            '</w>'.
        """
        words = re.findall(r'\w+', text.lower())
        encoded = []
        for word in words:
            tokens = list(word) + ['</w>']
            while len(tokens) > 1:
                candidates = [pair for pair in zip(tokens, tokens[1:])
                              if pair in self.merges]
                if not candidates:
                    break
                best_pair = min(candidates, key=self.merges.get)
                # Merge all non-overlapping occurrences, left to right.
                new_tokens = []
                i = 0
                while i < len(tokens):
                    if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == best_pair:
                        new_tokens.append(''.join(best_pair))
                        i += 2
                    else:
                        new_tokens.append(tokens[i])
                        i += 1
                tokens = new_tokens
            encoded.extend(tokens)
        return encoded

    def decode(self, encoded_tokens):
        """Join BPE tokens back into text.

        Each '</w>' marker becomes a word separator and runs of whitespace are
        collapsed. Since encode() lowercases and drops non-word characters,
        the result is the normalized text, not necessarily the original.
        """
        text = ''.join(encoded_tokens).replace('</w>', ' ')
        text = re.sub(r'\s+', ' ', text).strip()
        return text
# Example usage: train on a small corpus, then round-trip a sample sentence.
if __name__ == "__main__":
    corpus = """
Byte Pair Encoding (BPE) is a compression algorithm that iteratively replaces the most frequent pair of consecutive bytes with a single byte that does not appear in the original data.
It was first introduced for use in data compression by Philip Gage in a 1994 paper.
Later, BPE was adapted for use in natural language processing tasks, specifically for tokenization.
"""
    model = BPE(vocab_size=150)
    model.train(corpus)

    sample = "BPE is useful for natural language processing tasks."
    print("原始文本:", sample)

    tokens = model.encode(sample)
    print("BPE编码结果:", tokens)

    # Note: the restored text is lowercased and has no punctuation.
    restored = model.decode(tokens)
    print("恢复后的文本:", restored)

浙公网安备 33010602011771号