def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
range_ratio: float = DEFAULT_RANGE_RATIO,
input_len: int = DEFAULT_INPUT_LEN,
output_len: int = DEFAULT_OUTPUT_LEN,
prefix_sharing_strength: float = 0.5,  # 0-1; higher values mean more prefix sharing
min_prefix_tokens: int = 2,  # minimum number of prefix tokens
max_prefix_tokens: int = 8,  # maximum number of prefix tokens
**kwargs,
) -> list[SampleRequest]:
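"""Sample synthetic requests whose prompts share randomly generated prefixes.

A small pool of base prefixes is drawn up front; each request either reuses
one of them (with probability prefix_sharing_strength) or receives a fresh
random prefix, so higher strength values yield more prefix sharing.
"""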
# Validate parameters.
assert range_ratio < 1.0, (
"range_ratio must be < 1.0 to ensure a valid sampling range"
)
assert 0.0 <= prefix_sharing_strength <= 1.0, (
"prefix_sharing_strength must be between 0.0 and 1.0"
)
assert min_prefix_tokens <= max_prefix_tokens, (
"min_prefix_tokens must be <= max_prefix_tokens"
)
assert min_prefix_tokens > 0, "min_prefix_tokens must be positive"
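# Tokenizer properties: the vocab to sample from, and the number of special
# tokens the tokenizer will add, which is subtracted so prompts stay near input_len.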
vocab_size = tokenizer.vocab_size
num_special_tokens = tokenizer.num_special_tokens_to_add()
real_input_len = input_len - num_special_tokens
# Sampling ranges for input and output lengths.
input_low = int(real_input_len * (1 - range_ratio))
input_high = int(real_input_len * (1 + range_ratio))
output_low = int(output_len * (1 - range_ratio))
output_high = int(output_len * (1 + range_ratio))
# Log the sampling configuration.
logger.info("Sampling input_len from [%s, %s]", input_low, input_high)
logger.info("Sampling output_len from [%s, %s]", output_low, output_high)
logger.info("Prefix sharing strength: %s", prefix_sharing_strength)
input_lens = np.random.randint(input_low, input_high + 1, size=num_requests)
output_lens = np.random.randint(output_low, output_high + 1, size=num_requests)
offsets = np.random.randint(0, vocab_size, size=num_requests)
# Generate the pool of shared base prefixes. Higher sharing strength means
# fewer base prefixes, so more requests end up reusing the same prefix.
num_base_prefixes = max(1, int(5 * (1 - prefix_sharing_strength) + 1))  # 1 to 6 base prefixes
base_prefixes = []
for _ in range(num_base_prefixes):
# Draw a random prefix length.
prefix_len = random.randint(min_prefix_tokens, max_prefix_tokens)
# Draw random prefix tokens.
prefix_tokens = np.random.randint(0, vocab_size, size=prefix_len).tolist()
base_prefixes.append(prefix_tokens)
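# Build the requests, mixing shared and fresh prefixes per the sharing strength.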
requests = []
for i in range(num_requests):
# Choose how to build the prefix based on the sharing strength.
if random.random() < prefix_sharing_strength:
# With probability prefix_sharing_strength, reuse a shared base prefix.
selected_prefix = random.choice(base_prefixes)
else:
# Otherwise, generate a brand-new random prefix to add diversity.
prefix_len = random.randint(min_prefix_tokens, max_prefix_tokens)
selected_prefix = np.random.randint(0, vocab_size, size=prefix_len).tolist()
# Generate the inner token sequence; the per-request offset keeps bodies distinct.
inner_seq = (
(offsets[i] + i + np.arange(input_lens[i])) % vocab_size
).tolist()
# Concatenate the prefix and the inner sequence into the full token sequence.
token_sequence = selected_prefix + inner_seq
prompt = tokenizer.decode(token_sequence)
# Re-encode the decoded prompt and truncate so the reported prompt length
# matches what the tokenizer actually produces.
total_input_len = len(selected_prefix) + int(input_lens[i])
re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[
:total_input_len
]
prompt = tokenizer.decode(re_encoded_sequence)
total_input_len = len(re_encoded_sequence)
requests.append(
SampleRequest(
prompt=prompt,
prompt_len=total_input_len,
expected_output_len=int(output_lens[i]),
)
)
return requests
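# Example (hypothetical caller; `dataset` and `tokenizer` are assumed to exist):
#   requests = dataset.sample(
#       tokenizer=tokenizer,
#       num_requests=100,
#       prefix_sharing_strength=0.9,  # most prompts reuse one of the shared prefixes
#   )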