Prefix-Shared Dataset Generation

The sample method below generates synthetic benchmark requests whose prompts share common prefixes. A small pool of base prefixes is built up front; each request then either reuses one of them (with probability prefix_sharing_strength) or receives a fresh random prefix, which makes the dataset useful for exercising prefix caching in an inference engine.

import logging
import random
from dataclasses import dataclass

import numpy as np
from transformers import PreTrainedTokenizerBase

logger = logging.getLogger(__name__)

# Minimal stand-ins so this excerpt runs on its own; the surrounding module
# (omitted from this post) provides the real definitions, and the default
# values below are placeholders.
DEFAULT_RANGE_RATIO = 0.0
DEFAULT_INPUT_LEN = 1024
DEFAULT_OUTPUT_LEN = 128


@dataclass
class SampleRequest:
    prompt: str
    prompt_len: int
    expected_output_len: int

def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        range_ratio: float = DEFAULT_RANGE_RATIO,
        input_len: int = DEFAULT_INPUT_LEN,
        output_len: int = DEFAULT_OUTPUT_LEN,
        prefix_sharing_strength: float = 0.5,  # 0-1; higher means more sharing
        min_prefix_tokens: int = 2,  # minimum prefix length in tokens
        max_prefix_tokens: int = 8,  # maximum prefix length in tokens
        **kwargs,
    ) -> list[SampleRequest]:
        # Validate parameters
        assert range_ratio < 1.0, (
            "range_ratio must be < 1.0 to ensure a valid sampling range"
        )
        assert 0.0 <= prefix_sharing_strength <= 1.0, (
            "prefix_sharing_strength must be between 0.0 and 1.0"
        )
        assert min_prefix_tokens <= max_prefix_tokens, "min_prefix_tokens must be <= max_prefix_tokens"
        assert min_prefix_tokens > 0, "min_prefix_tokens must be positive"

        vocab_size = tokenizer.vocab_size
        num_special_tokens = tokenizer.num_special_tokens_to_add()
        real_input_len = input_len - num_special_tokens

        # Sampling bounds for input and output lengths
        input_low = int(real_input_len * (1 - range_ratio))
        input_high = int(real_input_len * (1 + range_ratio))
        output_low = int(output_len * (1 - range_ratio))
        output_high = int(output_len * (1 + range_ratio))

        # Log the sampling configuration
        logger.info("Sampling input_len from [%s, %s]", input_low, input_high)
        logger.info("Sampling output_len from [%s, %s]", output_low, output_high)
        logger.info("Prefix sharing strength: %s", prefix_sharing_strength)

        input_lens = np.random.randint(input_low, input_high + 1, size=num_requests)
        output_lens = np.random.randint(output_low, output_high + 1, size=num_requests)
        offsets = np.random.randint(0, vocab_size, size=num_requests)

        # Build the pool of shared base prefixes. Higher sharing strength means
        # fewer base prefixes and therefore more reuse: the formula yields 6
        # prefixes at strength 0.0 down to 1 prefix at strength 1.0.
        num_base_prefixes = max(1, int(5 * (1 - prefix_sharing_strength) + 1))
        base_prefixes = []
        
        for _ in range(num_base_prefixes):
            # Randomly choose a prefix length
            prefix_len = random.randint(min_prefix_tokens, max_prefix_tokens)
            # Randomly generate the prefix tokens
            prefix_tokens = np.random.randint(0, vocab_size, size=prefix_len).tolist()
            base_prefixes.append(prefix_tokens)

        requests = []
        for i in range(num_requests):
            # Decide how this request obtains its prefix
            if random.random() < prefix_sharing_strength:
                # High probability: reuse a shared base prefix (promotes sharing)
                selected_prefix = random.choice(base_prefixes)
            else:
                # Low probability: generate a brand-new random prefix (adds diversity)
                prefix_len = random.randint(min_prefix_tokens, max_prefix_tokens)
                selected_prefix = np.random.randint(0, vocab_size, size=prefix_len).tolist()

            # Generate the inner (body) token sequence, a deterministic ramp modulo vocab_size
            inner_seq = (
                (offsets[i] + i + np.arange(input_lens[i])) % vocab_size
            ).tolist()
            
            # Concatenate the shared prefix and the body into the full token sequence
            token_sequence = selected_prefix + inner_seq
            prompt = tokenizer.decode(token_sequence)
            
            # Re-encode the decoded prompt and truncate, so the reported length
            # matches the actual token count (decode/encode round-trips can
            # merge or split tokens)
            total_input_len = len(selected_prefix) + int(input_lens[i])
            re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[
                :total_input_len
            ]
            prompt = tokenizer.decode(re_encoded_sequence)
            total_input_len = len(re_encoded_sequence)
            
            requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=total_input_len,
                    expected_output_len=int(output_lens[i]),
                )
            )
        return requests

 
