sortedcontainers - SortedList

sortedcontainers是python的第三方有序容器库,有SortedList、SortedKeyList、SortedSet、SortedDict四种容器。
SortedKeyList是可以指定比较函数的有序列表。接收一个可以传递给list.sort()中的key参数的函数。

SortedList

  1. 创建
from sortedcontainers import SortedList 
# SortedList(iterable=None)
sl1 = SortedList()
sl2 = SortedList([3,9,6])
# 可以传递一个可迭代对象或不传递任何参数来创建一个有序列表

2.添加元素

# SortedList.add(val)  添加单个元素
# SortedList.update(iterable)  把一个可迭代对象中的元素加入
sl1.add(7)
sl2.update([4,3,9,6])
  3. 删除元素
# SortedList.clear()  清空列表
# SortedList.discard(val)  删除val,如果val在列表中
# SortedList.remove(val)   删除val,如果val在列表中,如果不在则引发ValueError
# SortedList.pop(index=-1)  删除并返回index处的数据

sl1.discard(5)  # 5不在列表中,不做任何操作
sl2.remove(5)   # 5不在列表中,会引发ValueError
sl2.pop()       # 返回9
  4. 查找
# SortedList.bisect_left(val)  # 返回如果插入val,保持有序时val的下标,如果val已经存在于列表中,返回val下标
# SortedList.bisect_right(val)  # 返回如果插入val,保持有序时val的下标,如果val已经存在于列表中,返回val后一个位置下标
# SortedList.count(val)
# SortedList.index(value, start=None, stop=None)  
# sl2 = [3,3,4,6,6,9,9]
sl2.bisect_left(4)  # 2
sl2.bisect_right(5)  # 3
sl2.bisect_right(4)  # 3
  5. 其他
# SortedList.irange(minimum=None, maximum=None, inclusive=(True, True), reverse=False)  # 返回一个列表元素位于从minimum到maximum的迭代器
# SortedList.islice(start=None, stop=None, reverse=False)  # 切片,返回一个start到stop-1的迭代器,start和stop-1是下标
# sl2 = [3,3,4,6,6,9,9]
sl2.irange(3,7)  # iterator([3,3,4,6,6])
sl2.islice(1,4)  # iterator([3,4,6])

实现原理

SortedList的内部结构实际上是一个类B+树结构,只有两层。

    DEFAULT_LOAD_FACTOR = 1000  # Load factor for sublists: a sublist is split in two once its length exceeds twice this value
    def __init__(self, iterable=None, key=None):
        """Create an empty sorted list, then load ``iterable`` if one is given.

        :param iterable: optional iterable of initial values; when it is not
            ``None`` it is handed to ``self._update``
        :param key: must be ``None``; this is only asserted here

        """
        assert key is None

        # Sublist sizing parameter, taken from the class-level default.
        self._load = self.DEFAULT_LOAD_FACTOR

        # Total number of stored values.
        self._len = 0

        # Two-level storage: `_lists` holds the sorted sublists and `_maxes`
        # holds the largest value of each sublist, in the same order.
        self._lists = []
        self._maxes = []

        # Positional index: a binary tree flattened into a list whose leaves
        # are sublist lengths and whose inner nodes are the sums of their
        # children. `_offset` is the position of the first leaf in `_index`.
        self._index = []
        self._offset = 0

        if iterable is None:
            return

        self._update(iterable)

    def add(self, value):
        """Insert ``value`` into the sorted list, keeping it sorted.

        The target sublist is found by bisecting ``_maxes`` (the per-sublist
        maxima); the value is then placed inside that sublist and
        ``_expand`` is called so an oversized sublist can be split.

        :param value: value to insert

        """
        maxes = self._maxes
        sublists = self._lists

        if not maxes:
            # First value ever: it forms the only sublist and is its maximum.
            sublists.append([value])
            maxes.append(value)
        else:
            pos = bisect_right(maxes, value)

            if pos < len(maxes):
                # Ordinary case: sorted insert into the chosen sublist.
                insort(sublists[pos], value)
            else:
                # `value` is not smaller than any stored maximum, so it goes
                # at the very end of the last sublist and becomes its maximum.
                pos = len(maxes) - 1
                sublists[pos].append(value)
                maxes[pos] = value

            # Let `_expand` split the sublist or update the index as needed.
            self._expand(pos)

        self._len += 1

    def _expand(self, pos):
        """Split the sublist at ``pos`` if it grew too large, else bump the index.

        Called after one value has been added to ``self._lists[pos]``.

        * When the sublist is longer than twice ``self._load`` it is cut at
          ``self._load``: the upper part becomes a new sublist right after
          it, ``_maxes`` is patched for both halves and the positional index
          is discarded.
        * Otherwise, if an index exists, the leaf for ``pos`` and every
          ancestor up to the root are incremented by one.

        :param int pos: lists index of the sublist that just grew

        """
        load = self._load
        sublists = self._lists
        tree = self._index
        grown = sublists[pos]

        if len(grown) > (load << 1):
            maxes = self._maxes

            # Move everything from position `load` onwards into a new sublist.
            upper = grown[load:]
            del grown[load:]
            maxes[pos] = grown[-1]

            sublists.insert(pos + 1, upper)
            maxes.insert(pos + 1, upper[-1])

            # The number of sublists changed, so the index layout is stale.
            del tree[:]
        elif tree:
            # Walk from the leaf up to (but excluding) the root, counting the
            # new value at each level, then count it at the root as well.
            node = self._offset + pos
            while node:
                tree[node] += 1
                node = (node - 1) >> 1
            tree[0] += 1


    def update(self, iterable):
        """Insert every value from ``iterable`` into the sorted list.

        The incoming values are sorted first. Then:

        * if the list already holds data and the batch is small (fewer than
          a quarter of the current length), each value is inserted through
          ``add``;
        * otherwise all existing and incoming values are merged into one
          sorted list, the container is cleared, and the storage is rebuilt
          by cutting that list into chunks of ``self._load`` values (the
          last chunk may be shorter).

        :param iterable: iterable of values to insert

        """
        sublists = self._lists
        maxes = self._maxes
        incoming = sorted(iterable)

        if maxes:
            if len(incoming) * 4 < self._len:
                # Small batch relative to the stored data: insert one by one.
                insert_one = self.add
                for item in incoming:
                    insert_one(item)
                return

            # Large batch: flatten everything into one list, re-sort it and
            # rebuild from scratch.
            sublists.append(incoming)
            incoming = reduce(iadd, sublists, [])
            incoming.sort()
            self._clear()

        chunk = self._load
        sublists.extend(
            incoming[begin:(begin + chunk)]
            for begin in range(0, len(incoming), chunk)
        )
        # Record the maximum (last value) of every sublist.
        maxes.extend(part[-1] for part in sublists)
        self._len = len(incoming)
        # Any previously built positional index no longer matches the layout.
        del self._index[:]

    def _build_index(self):
        """Build the positional index used to map positions to sublists.

        The index is a binary tree stored in a dense array, similar to a
        binary heap: leaves are the lengths of the sublists and every inner
        node is the sum of its two children.

        For example, given a lists representation storing integers::

            0: [1, 2, 3]
            1: [4, 5]
            2: [6, 7, 8, 9]
            3: [10, 11, 12, 13, 14]

        The first transformation maps the sub-lists by their length. The
        first row of the index is the length of the sub-lists::

            0: [3, 2, 4, 5]

        Each row after that is the sum of consecutive pairs of the previous
        row::

            1: [5, 9]
            2: [14]

        Finally, the index is built by concatenating these rows, root row
        first::

            _index = [14, 5, 9, 3, 2, 4, 5]

        An offset storing the start of the first (leaf) row is also stored::

            _offset = 3

        When there are no sublists at all, the index is left empty and the
        offset is set to 0.

        In the general (three or more rows) case the rows are appended to
        ``self._index``, so the index is expected to be empty on entry.

        """
        # Leaf row: one entry per sublist, holding its length.
        row0 = list(map(len, self._lists))

        if len(row0) <= 1:
            # Zero or one sublist: the index is just the leaf row itself.
            self._index[:] = row0
            self._offset = 0
            return

        # Passing the *same* iterator to zip() twice makes it consume two
        # consecutive items per step, e.g. [1, 2, 3, 4, 5] is paired as
        # (1, 2), (3, 4); a trailing unpaired item is dropped by zip().
        pairs = iter(row0)
        row1 = list(starmap(add, zip(pairs, pairs)))

        if len(row0) & 1:
            # Odd number of leaves: the last one had no partner above, so it
            # is carried up unchanged.
            row1.append(row0[-1])

        if len(row1) == 1:
            # Exactly two sublists: root followed by its two leaves.
            self._index[:] = row1 + row0
            self._offset = 1
            return

        # Pad the first inner row with zeros up to a power of two so that
        # every row above it pairs up exactly. The size is computed with
        # integer arithmetic only: for n >= 2, 1 << (n - 1).bit_length() is
        # the smallest power of two that is >= n.
        size = 1 << (len(row1) - 1).bit_length()
        row1.extend(repeat(0, size - len(row1)))
        tree = [row0, row1]

        # Keep summing pairs of the top row until only the root remains.
        while len(tree[-1]) > 1:
            pairs = iter(tree[-1])
            tree.append(list(starmap(add, zip(pairs, pairs))))

        # Concatenate the rows from the root down to the leaves, in place.
        reduce(iadd, reversed(tree), self._index)
        # Rows above the leaves hold size + size/2 + ... + 1 = 2*size - 1
        # nodes, which is therefore where the leaf row starts.
        self._offset = size * 2 - 1

    def _loc(self, pos, idx):
        """Convert an index pair (lists index, sublist index) into a position.

        The result is the position of the value in the sorted list as a
        whole, i.e. the number of values stored in the sublists before
        ``pos`` plus ``idx``.

        The positional index is built on demand; its layout is described in
        ``SortedList._build_index``.

        The computation walks from the leaf for ``pos`` up to the root. The
        parent of the node at position ``p`` is at ``(p - 1) // 2``.
        Left-child nodes are always at odd positions and right-child nodes
        are always at even positions. Whenever the walk leaves a right-child
        node, the value of its left sibling is added to the running total.

        For example, with::

            _index = 14 5 9 3 2 4 5
            _offset = 3

        Tree::

                 14
              5      9
            3   2  4   5

        converting the index pair (2, 3) goes like this:

        1. The leaf is at position offset + pos = 3 + 2 = 5. That position is
           odd, so it is a left child: just move to the parent.

        2. The parent (value 9) is at position 2, which is even, so it is a
           right child: add its left sibling (5) to the total and move to
           the parent at position 0.

        3. Position 0 is the root, so the walk ends.

        The result is the total plus the sublist index: 5 + 3 = 8.

        :param int pos: lists index
        :param int idx: sublist index
        :return: index in sorted list

        """
        if not pos:
            # Nothing precedes the first sublist.
            return idx

        tree = self._index

        if not tree:
            self._build_index()

        # Position inside the index of the leaf that belongs to sublist `pos`.
        node = pos + self._offset
        preceding = 0

        while node:
            if node % 2 == 0:
                # Right child: everything under the left sibling comes first.
                preceding += tree[node - 1]

            node = (node - 1) // 2

        return preceding + idx


    def _pos(self, idx):
        """Convert a position into an index pair (lists index, sublist index).

        The returned pair can be used to reach the value directly as
        ``self._lists[pos][idx]``. Negative positions count from the end.

        The positional index is built on demand; its layout is described in
        ``SortedList._build_index``.

        The lookup descends from the root to a leaf. For the node at position
        ``p`` the left child is at ``p * 2 + 1`` and the right child at
        ``p * 2 + 2``. When the remaining index is smaller than the value of
        the left child, the descent goes left; otherwise the left child's
        value is subtracted from the index and the descent goes right. At
        the leaf, the lists index is the leaf position minus the offset, and
        the sublist index is whatever remains of the index.

        For example, with::

            _index = 14 5 9 3 2 4 5
            _offset = 3

        Tree::

                 14
              5      9
            3   2  4   5

        looking up position 8 goes like this:

        1. At the root, 8 is not smaller than the left child (5), so 5 is
           subtracted (leaving 3) and the descent moves to the right child.

        2. At node 9, the remaining 3 is smaller than its left child (4), so
           the descent moves to the left child.

        3. That node (value 4, position 5) has no children inside the index,
           so it is a leaf and the descent stops.

        4. The lists index is 5 - 3 = 2 and the sublist index is the
           remaining 3.

        The pair (2, 3) therefore corresponds to position 8.

        :param int idx: index in sorted list
        :return: (lists index, sublist index) pair
        :raises IndexError: if ``idx`` is out of range

        """
        if idx < 0:
            tail_len = len(self._lists[-1])

            # Shortcut: the position falls inside the last sublist.
            if -idx <= tail_len:
                return len(self._lists) - 1, tail_len + idx

            # Otherwise normalise to a non-negative position.
            idx += self._len

            if idx < 0:
                raise IndexError('list index out of range')
        elif idx >= self._len:
            raise IndexError('list index out of range')

        # Shortcut: the position falls inside the first sublist.
        if idx < len(self._lists[0]):
            return 0, idx

        tree = self._index

        if not tree:
            self._build_index()

        node = 0
        left = 1
        tree_size = len(tree)

        while left < tree_size:
            left_count = tree[left]

            if idx >= left_count:
                # Skip everything under the left child and go right.
                idx -= left_count
                node = left + 1
            else:
                node = left

            left = (node << 1) + 1

        return (node - self._offset, idx)

    def bisect_left(self, value):
        """Return the leftmost position at which `value` could be inserted.

        If `value` is already present, the returned position is before (to
        the left of) any existing equal values.

        Similar to ``bisect_left`` from the `bisect` module in the standard
        library.

        Runtime complexity: `O(log(n))` -- approximate.

        >>> sl = SortedList([10, 11, 12, 13, 14])
        >>> sl.bisect_left(12)
        2

        :param value: value whose insertion index is wanted
        :return: index

        """
        maxes = self._maxes

        if maxes:
            # First level: pick the sublist by bisecting the sublist maxima.
            pos = bisect_left(maxes, value)

            if pos < len(maxes):
                # Second level: bisect inside that sublist, then translate
                # the (sublist, offset) pair into a flat position.
                inner = bisect_left(self._lists[pos], value)
                return self._loc(pos, inner)

            # `value` is greater than every stored maximum.
            return self._len

        return 0


    def bisect_right(self, value):
        """Return the rightmost position at which `value` could be inserted.

        Like `bisect_left`, except that when `value` is already present the
        returned position is after (to the right of) any existing equal
        values.

        Similar to ``bisect_right`` from the `bisect` module in the standard
        library.

        Runtime complexity: `O(log(n))` -- approximate.

        >>> sl = SortedList([10, 11, 12, 13, 14])
        >>> sl.bisect_right(12)
        3

        :param value: value whose insertion index is wanted
        :return: index

        """
        maxes = self._maxes

        if maxes:
            # First level: pick the sublist by bisecting the sublist maxima.
            pos = bisect_right(maxes, value)

            if pos < len(maxes):
                # Second level: bisect inside that sublist, then translate
                # the (sublist, offset) pair into a flat position.
                inner = bisect_right(self._lists[pos], value)
                return self._loc(pos, inner)

            # `value` is not smaller than any stored maximum.
            return self._len

        return 0

    def __getitem__(self, index):
        """
        Lookup value at `index` in sorted list.

        ``sl.__getitem__(index)`` <==> ``sl[index]``

        Supports slicing. An integer index returns a single value; a slice
        returns a new list of values.

        Runtime complexity: `O(log(n))` -- approximate.

        >>> sl = SortedList('abcde')
        >>> sl[1]
        'b'
        >>> sl[-1]
        'e'
        >>> sl[2:5]
        ['c', 'd', 'e']

        :param index: integer or slice for indexing
        :return: value or list of values
        :raises IndexError: if index out of range

        """
        _lists = self._lists

        if isinstance(index, slice):
            # Normalise the slice against the current length.
            start, stop, step = index.indices(self._len)

            if step == 1 and start < stop:
                # Whole slice optimization: start to stop slices the whole
                # sorted list.

                if start == 0 and stop == self._len:
                    return reduce(iadd, self._lists, [])

                # Translate the flat start position into a
                # (lists index, sublist index) pair.
                start_pos, start_idx = self._pos(start)
                start_list = _lists[start_pos]
                stop_idx = start_idx + stop - start

                # Small slice optimization: start index and stop index are
                # within the start list.

                if len(start_list) >= stop_idx:
                    return start_list[start_idx:stop_idx]

                # The slice spans several sublists. A stop equal to the
                # length cannot be passed to `_pos` (it would be out of
                # range), so the end of the last sublist is used directly.
                if stop == self._len:
                    stop_pos = len(_lists) - 1
                    stop_idx = len(_lists[stop_pos])
                else:
                    stop_pos, stop_idx = self._pos(stop)

                # `prefix` is a fresh list produced by slicing, so extending
                # it in place below does not touch the stored sublists.
                prefix = _lists[start_pos][start_idx:]
                middle = _lists[(start_pos + 1):stop_pos]
                result = reduce(iadd, middle, prefix)
                result += _lists[stop_pos][:stop_idx]

                return result

            if step == -1 and start > stop:
                # Reverse contiguous slice: fetch the equivalent forward
                # slice and reverse it.
                # NOTE(review): `_getitem` is not defined in the visible
                # code; presumably an alias of `__getitem__` -- confirm.
                result = self._getitem(slice(stop + 1, start + 1))
                result.reverse()
                return result

            # Return a list because a negative step could
            # reverse the order of the items and this could
            # be the desired behavior.

            indices = range(start, stop, step)
            return list(self._getitem(index) for index in indices)
        else:
            # Integer index. Fast paths for the first and last value; an
            # empty list has no valid index at all.
            if self._len:
                if index == 0:
                    return _lists[0][0]
                elif index == -1:
                    return _lists[-1][-1]
            else:
                raise IndexError('list index out of range')

            # Fast path: the index falls inside the first sublist.
            if 0 <= index < len(_lists[0]):
                return _lists[0][index]

            len_last = len(_lists[-1])

            # Fast path: a negative index that falls inside the last sublist.
            if -len_last < index < 0:
                return _lists[-1][len_last + index]

            # General case: translate the flat index into a
            # (lists index, sublist index) pair.
            pos, idx = self._pos(index)
            return _lists[pos][idx]
posted @ 2022-08-01 21:07  店里最会撒谎白玉汤  阅读(483)  评论(0编辑  收藏  举报