528. 按权重随机选择
labuladong 题解
难度中等
给你一个 下标从 0 开始 的正整数数组 w ,其中 w[i] 代表第 i 个下标的权重。
请你实现一个函数 pickIndex ,它可以 随机地 从范围 [0, w.length - 1] 内(含 0 和 w.length - 1)选出并返回一个下标。选取下标 i 的 概率 为 w[i] / sum(w) 。
- 例如,对于
w = [1, 3],挑选下标0的概率为1 / (1 + 3) = 0.25(即,25%),而选取下标1的概率为3 / (1 + 3) = 0.75(即,75%)。
示例 1:
输入: ["Solution","pickIndex"] [[[1]],[]] 输出: [null,0] 解释: Solution solution = new Solution([1]); solution.pickIndex(); // 返回 0,因为数组中只有一个元素,所以唯一的选择是返回下标 0。
示例 2:
输入: ["Solution","pickIndex","pickIndex","pickIndex","pickIndex","pickIndex"] [[[1,3]],[],[],[],[],[]] 输出: [null,1,1,1,1,0] 解释: Solution solution = new Solution([1, 3]); solution.pickIndex(); // 返回 1,返回下标 1,返回该下标概率为 3/4 。 solution.pickIndex(); // 返回 1 solution.pickIndex(); // 返回 1 solution.pickIndex(); // 返回 1 solution.pickIndex(); // 返回 0,返回下标 0,返回该下标概率为 1/4 。 由于这是一个随机问题,允许多个答案,因此下列输出都可以被认为是正确的: [null,1,1,1,1,0] [null,1,1,1,1,1] [null,1,1,1,0,0] [null,1,1,1,0,1] [null,1,0,1,0,0] ...... 诸若此类。
class Solution: def __init__(self, w: List[int]): self.presum = [0] for n in w: self.presum.append(self.presum[-1]+n) def find(self,nums,target): lo = 0 hi = len(nums) - 1 while lo <= hi: mid = lo + (hi-lo)//2 if nums[mid] < target: lo = mid + 1 elif nums[mid] > target: hi = mid - 1 else: lo = mid +1 return lo def pickIndex(self) -> int: target = random.randint(0,self.presum[-1]-1) return self.find(self.presum,target) -1 # Your Solution object will be instantiated and called as such: # obj = Solution(w) # param_1 = obj.pickIndex()
class Solution { public: vector<int> presum; Solution(vector<int>& w) { int n = w.size(); presum = vector<int>(n+1,0); for(int i = 0; i < n;i++) { presum[i+1] = presum[i] + w[i]; } } int find_left_bound(int target) { if (presum.size()==0) return -1; int low = 0,high = presum.size(); while(low < high) { int mid = low + (high-low)/2; if (presum[mid]<target) { low = mid +1; } else { high = mid; } } return low; } int pickIndex() { // rand() % (b-a+1)+ a ; 就表示 a~b 之间的一个随机整数 int target = (rand()%(presum.back() - 1 + 1)) + 1;//生成闭区间[1,total]范围内的一个随机数 cout << target << endl; return find_left_bound(target)-1; } }; /** * Your Solution object will be instantiated and called as such: * Solution* obj = new Solution(w); * int param_1 = obj->pickIndex(); */
import math import random import bisect class WeightedRandomSampler: def __init__(self, w: list[int]): self.n = len(w) self.block_size = max(1, int(math.isqrt(self.n))) # 块大小 = √n self.num_blocks = (self.n + self.block_size - 1) // self.block_size self.blocks = [] # 存储块信息 self.total_sum = 0 # 总权重和 # 初始化分块 for i in range(self.num_blocks): start = i * self.block_size end = min((i + 1) * self.block_size, self.n) block_weights = w[start:end] # 块内权重 # 计算块内前缀和及总和 prefix = [] block_total = 0 for weight in block_weights: block_total += weight prefix.append(block_total) self.blocks.append({ 'start_index': start, # 块起始下标 'weights': block_weights, # 块内权重数组 'prefix': prefix, # 块内前缀和 'total': block_total # 块权重总和 }) self.total_sum += block_total def update(self, index: int, new_value: int) -> None: # 定位块和块内偏移 block_idx = index // self.block_size offset = index % self.block_size block = self.blocks[block_idx] # 更新总权重 old_value = block['weights'][offset] delta = new_value - old_value self.total_sum += delta block['total'] += delta block['weights'][offset] = new_value # 更新块内前缀和(O(√n)时间) block['prefix'] = [] current = 0 for weight in block['weights']: current += weight block['prefix'].append(current) def pickIndex(self) -> int: if self.total_sum <= 0: return -1 # 无有效权重 # 生成随机目标值 target = random.randint(1, self.total_sum) current_sum = 0 # 当前累计和 # 遍历块定位目标块(O(√n)时间) for block in self.blocks: if target <= current_sum + block['total']: # 在块内二分查找(O(log √n)时间) t = target - current_sum idx_in_block = bisect.bisect_left(block['prefix'], t) return block['start_index'] + idx_in_block current_sum += block['total'] return self.n - 1 # 保底返回

浙公网安备 33010602011771号