Prefix-Sharing Dataset Generation
def sample(
    self,
    tokenizer: PreTrainedTokenizerBase,
    num_requests: int,
    range_ratio: float = DEFAULT_RANGE_RATIO,
    input_len: int = DEFAULT_INPUT_LEN,
    output_len: int = DEFAULT_OUTPUT_LEN,
    prefix_sharing_strength: float = 0.5,  # in [0, 1]; higher values mean more prefix sharing
    min_prefix_tokens: int = 2,  # minimum prefix length in tokens
    max_prefix_tokens: int = 8,  # maximum prefix length in tokens
    **kwargs,
) -> list[SampleRequest]:
    # Validate parameters.
    assert range_ratio < 1.0, (
        "random_range_ratio must be < 1.0 to ensure a valid sampling range"
    )
    assert 0.0 <= prefix_sharing_strength <= 1.0, (
        "prefix_sharing_strength must be between 0.0 and 1.0"
    )
    assert min_prefix_tokens <= max_prefix_tokens, (
        "min_prefix_tokens must be <= max_prefix_tokens"
    )
    assert min_prefix_tokens > 0, "min_prefix_tokens must be positive"

    vocab_size = tokenizer.vocab_size
    num_special_tokens = tokenizer.num_special_tokens_to_add()
    real_input_len = input_len - num_special_tokens

    # Sampling ranges for per-request input/output lengths.
    input_low = int(real_input_len * (1 - range_ratio))
    input_high = int(real_input_len * (1 + range_ratio))
    output_low = int(output_len * (1 - range_ratio))
    output_high = int(output_len * (1 + range_ratio))

    # Debug logging.
    logger.info("Sampling input_len from [%s, %s]", input_low, input_high)
    logger.info("Sampling output_len from [%s, %s]", output_low, output_high)
    logger.info("Prefix sharing strength: %s", prefix_sharing_strength)

    input_lens = np.random.randint(input_low, input_high + 1, size=num_requests)
    output_lens = np.random.randint(output_low, output_high + 1, size=num_requests)
    offsets = np.random.randint(0, vocab_size, size=num_requests)

    # Build the pool of shared base prefixes. The pool shrinks as the sharing
    # strength grows (6 prefixes at strength 0.0 down to 1 at strength 1.0),
    # so a higher strength concentrates requests onto fewer shared prefixes.
    num_base_prefixes = max(1, int(5 * (1 - prefix_sharing_strength) + 1))
    base_prefixes = []
    for _ in range(num_base_prefixes):
        # Random prefix length and random prefix tokens.
        prefix_len = random.randint(min_prefix_tokens, max_prefix_tokens)
        prefix_tokens = np.random.randint(0, vocab_size, size=prefix_len).tolist()
        base_prefixes.append(prefix_tokens)

    requests = []
    for i in range(num_requests):
        # Choose how this request's prefix is produced, based on the sharing strength.
        if random.random() < prefix_sharing_strength:
            # With high probability, reuse one of the shared base prefixes (promotes sharing).
            selected_prefix = random.choice(base_prefixes)
        else:
            # With low probability, generate a brand-new random prefix (adds diversity).
            prefix_len = random.randint(min_prefix_tokens, max_prefix_tokens)
            selected_prefix = np.random.randint(0, vocab_size, size=prefix_len).tolist()

        # Generate the request-specific inner token sequence.
        inner_seq = (
            (offsets[i] + i + np.arange(input_lens[i])) % vocab_size
        ).tolist()

        # Concatenate prefix and inner sequence into the full token sequence.
        token_sequence = selected_prefix + inner_seq
        prompt = tokenizer.decode(token_sequence)

        # Re-encode, truncate, and decode again so that the reported prompt_len
        # matches the actual tokenized length (decode/encode round-trips are not
        # always length-preserving).
        total_input_len = len(selected_prefix) + int(input_lens[i])
        re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[
            :total_input_len
        ]
        prompt = tokenizer.decode(re_encoded_sequence)
        total_input_len = len(re_encoded_sequence)

        requests.append(
            SampleRequest(
                prompt=prompt,
                prompt_len=total_input_len,
                expected_output_len=int(output_lens[i]),
            )
        )
    return requests
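The method above relies on several names that are assumed to come from its enclosing module or class: np, random, logger, SampleRequest, and the DEFAULT_* defaults in the signature. A minimal, purely illustrative scaffold that would make the listing self-contained could look like the following; the concrete default values and the SampleRequest fields are assumptions chosen to match how they are used above, not taken from the original.

    # Hypothetical module-level scaffolding; names mirror how they are used in
    # sample() above, but the concrete default values are illustrative only.
    import logging
    import random
    from dataclasses import dataclass

    import numpy as np
    from transformers import PreTrainedTokenizerBase

    logger = logging.getLogger(__name__)

    # Defaults referenced by the sample() signature (illustrative values).
    DEFAULT_RANGE_RATIO = 0.0
    DEFAULT_INPUT_LEN = 1024
    DEFAULT_OUTPUT_LEN = 128


    @dataclass
    class SampleRequest:
        """One benchmark request: the prompt text plus its length bookkeeping."""
        prompt: str
        prompt_len: int
        expected_output_len: int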
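With that scaffolding in place and sample() defined at module level, a quick smoke test can show that a higher prefix_sharing_strength really does concentrate requests onto a few shared prefixes. This is a sketch rather than part of the original code: the "gpt2" tokenizer and all parameter values are arbitrary choices, and None is passed for self because the method body never uses it.

    # Sketch of a smoke test; assumes sample() and the scaffolding above are in scope.
    from collections import Counter

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    requests = sample(
        None,                      # self is unused in the method body
        tokenizer,
        num_requests=64,
        range_ratio=0.2,
        input_len=128,
        output_len=32,
        prefix_sharing_strength=0.8,  # most requests should reuse a shared base prefix
        min_prefix_tokens=4,
        max_prefix_tokens=8,
    )

    # Bucket requests by their first few tokens for a rough picture of sharing:
    # with strength 0.8, a handful of buckets should dominate. (The decode/encode
    # round-trip can perturb token boundaries, so this is only an approximation.)
    buckets = Counter(
        tuple(tokenizer.encode(r.prompt, add_special_tokens=False)[:4])
        for r in requests
    )
    print(buckets.most_common(5))
    print("total requests:", len(requests), "distinct 4-token prefixes:", len(buckets))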