faster-fifo:C++实现的Python多进程通信队列(源自强化学习PPO算法库sample-factory)—— 在Python 3.12下成功通过测试

项目地址:

https://github.com/alex-petrenko/faster-fifo




需要注意,该项目提供了两种安装方法:一种是通过pip从PyPI安装,另一种是从GitHub源码安装。经测试发现该项目维护不够及时,PyPI上发布的版本落后于GitHub上的源码,因此不建议从PyPI安装,而是推荐从源码编译安装。


给出源码编译安装方法:(经过测试,该项目可以在python3.12版本上成功编译,并通过unittest测试)

git clone https://github.com/alex-petrenko/faster-fifo
cd faster-fifo
pip install Cython
python setup.py build_ext --inplace
pip install -e .


测试命令:

python -m unittest

测试结果:

(此处原文为unittest测试结果截图,图片未能保留)



项目页的Demo代码有错误,下面给出改正的代码:

from faster_fifo import Queue
from queue import Full, Empty

import logging
log = logging.getLogger('rl')

q = Queue(1000 * 1000)  # specify the size of the circular buffer in the ctor

# any pickle-able Python object can be added to the queue
py_obj = dict(a=42, b=33, c=(1, 2, 3), d=[1, 2, 3], e='123', f=b'kkk')
q.put(py_obj)
assert q.qsize() == 1

retrieved = q.get()
assert q.empty()
assert py_obj == retrieved

for _ in range(100):
    try:
        q.put(py_obj, timeout=0.1)
    except Full:
        log.debug('Queue is full!')

num_received = 0
while num_received < 100:
    # get multiple messages at once, returns a list of messages for better performance in many-to-few scenarios
    # get_many does not guarantee that all max_messages_to_get will be received on the first call, in fact
    # no such guarantee can be made in multiprocessing systems.
    # get_many() will retrieve as many messages as there are available AND can fit in the pre-allocated memory
    # buffer. The size of the buffer is increased gradually to match demand.
    messages = q.get_many(max_messages_to_get=100)
    num_received += len(messages)

try:
    q.get(timeout=0.1)
    # The queue was fully drained above, so get() must time out with Empty.
    # `assert True` would silently pass here; raise explicitly so a broken queue is noticed.
    raise AssertionError('This won\'t be called')
except Empty:
    log.debug('Queue is empty')


该项目没有提供使用文档和API文档,因此可以参考项目中的测试代码来了解其用法和API:

import logging
import multiprocessing
from queue import Full, Empty
from time import time
from unittest import TestCase
import ctypes

from faster_fifo import Queue

# Configure the shared "rl" logger used by the benchmark helpers below.
log = logging.getLogger('rl')
log.setLevel(logging.DEBUG)
log.propagate = False  # workaround for duplicated logs in ipython

# Every line is prefixed with a timestamp and the PID of the emitting process.
fmt = logging.Formatter('[%(asctime)s][%(process)05d] %(message)s')
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
ch.setFormatter(fmt)

log.handlers = []  # drop handlers left by a previous import so messages are not duplicated
log.addHandler(ch)

MSG_SIZE = 5


def make_msg(msg_idx):
    """Build a message: a tuple holding ``msg_idx`` repeated ``MSG_SIZE`` times."""
    return tuple(msg_idx for _ in range(MSG_SIZE))


def produce_msgs(q, p_idx, num_messages):
    """Put ``num_messages`` messages built by ``make_msg`` into ``q``.

    Each put waits at most 10 ms. On ``Full`` the same message index is retried.
    Any other exception is logged and the same index is retried too. A progress
    line is logged every 50000 messages.
    """
    sent = 0
    while sent < num_messages:
        try:
            q.put(make_msg(sent), timeout=0.01)
            if not sent % 50000:
                log.info('Produce: %d %d', sent, p_idx)
            sent += 1
        except Full:
            continue
        except Exception as exc:
            log.exception(exc)


def consume_msgs(q, p_idx, all_msgs_sent, consume_many=1):
    """Read from ``q`` until a read times out after ``all_msgs_sent.value`` became true.

    With ``consume_many == 1`` messages are read one by one with ``get``. Otherwise
    ``get_many`` is used with ``consume_many`` as the per-call cap. Messages whose
    first element is a multiple of 50000 are logged. Exceptions other than
    ``Empty`` are logged and reading continues.
    """
    num_received = 0
    one_at_a_time = consume_many == 1
    while True:
        try:
            if one_at_a_time:
                batch = [q.get(timeout=0.01)]
            else:
                batch = q.get_many(timeout=0.01, max_messages_to_get=consume_many)

            for msg in batch:
                if msg[0] % 50000 == 0:
                    log.info('Consume: %r %d num_msgs: %d total received: %d', msg, p_idx, len(batch), num_received)
                num_received += 1
        except Empty:
            # A timed-out read only ends the loop once the parent says producers are done.
            if all_msgs_sent.value:
                break
        except Exception as exc:
            log.exception(exc)


def run_test(queue_cls, num_producers, num_consumers, msgs_per_prod, consume_many):
    """Time a full producer/consumer run over ``queue_cls(100000)``.

    Producers and consumers each run in their own ``multiprocessing.Process``.
    After every producer has joined, the shared ``all_msgs_sent`` flag is set so
    consumers stop at their next timed-out read. Returns the wall-clock seconds
    taken, including queue creation and process start-up.
    """
    start_time = time()
    q = queue_cls(100000)
    all_msgs_sent = multiprocessing.RawValue(ctypes.c_bool, False)

    producers = [
        multiprocessing.Process(target=produce_msgs, args=(q, j, msgs_per_prod))
        for j in range(num_producers)
    ]
    consumers = [
        multiprocessing.Process(target=consume_msgs, args=(q, j, all_msgs_sent, consume_many))
        for j in range(num_consumers)
    ]

    # Producers are started before consumers.
    for proc in producers + consumers:
        proc.start()
    for proc in producers:
        proc.join()
    all_msgs_sent.value = True
    for proc in consumers:
        proc.join()
    q.close()

    queue_name = queue_cls.__module__ + '.' + queue_cls.__name__
    log.info('Exiting queue type %s', queue_name)
    time_taken = time() - start_time
    log.info('Time taken by queue type %s is %.5f', queue_name, time_taken)
    return time_taken


class ComparisonTestCase(TestCase):
    """Benchmark faster_fifo.Queue (``get`` and ``get_many``) against multiprocessing.Queue.

    Nothing is asserted. Configurations where faster-fifo is slower produce
    warnings, and a timing table is logged at the end.
    """

    @staticmethod
    def comparison(n_prod, n_con, n_msgs):
        """Time the three queue variants on one configuration; return (ff, ff_many, mp) seconds."""
        n_msgs += 1  # +1 here to make sure the last log line will be printed
        shared = dict(num_producers=n_prod, num_consumers=n_con, msgs_per_prod=n_msgs)
        time_ff = run_test(Queue, consume_many=1, **shared)
        time_ff_many = run_test(Queue, consume_many=100, **shared)
        time_mp = run_test(multiprocessing.Queue, consume_many=1, **shared)

        # The variable names below appear verbatim in the messages via f-string `=` specifiers.
        if time_ff > time_mp:
            log.warning(f'faster-fifo took longer than mp.Queue ({time_ff=} vs {time_mp=}) on configuration ({n_prod}, {n_con}, {n_msgs})')
        if time_ff_many > time_mp:
            log.warning(f'faster-fifo get_many() took longer than mp.Queue ({time_ff_many=} vs {time_mp=}) on configuration ({n_prod}, {n_con}, {n_msgs})')

        return time_ff, time_ff_many, time_mp

    def test_all_configurations(self):
        # (num_producers, num_consumers, messages_per_producer)
        configurations = (
            (1, 1, 200000),
            (1, 10, 200000),
            (10, 1, 100000),
            (3, 20, 100000),
            (20, 3, 50000),
            (20, 20, 50000),
        )

        results = [self.comparison(*cfg) for cfg in configurations]

        log.info('\nResults:\n')
        for cfg, timings in zip(configurations, results):
            log.info('Configuration %r, timing [ff: %.2fs, ff_many: %.2fs, mp.queue: %.2fs]', cfg, *timings)


# i9-7900X (10-core CPU)
# [2020-05-16 03:24:26,548][30412] Configuration (1, 1, 200000), timing [ff: 0.92s, ff_many: 0.93s, mp.queue: 2.83s]
# [2020-05-16 03:24:26,548][30412] Configuration (1, 10, 200000), timing [ff: 1.43s, ff_many: 1.40s, mp.queue: 7.60s]
# [2020-05-16 03:24:26,548][30412] Configuration (10, 1, 100000), timing [ff: 4.95s, ff_many: 1.40s, mp.queue: 12.24s]
# [2020-05-16 03:24:26,548][30412] Configuration (3, 20, 100000), timing [ff: 2.29s, ff_many: 2.25s, mp.queue: 13.25s]
# [2020-05-16 03:24:26,548][30412] Configuration (20, 3, 50000), timing [ff: 3.19s, ff_many: 1.12s, mp.queue: 29.07s]
# [2020-05-16 03:24:26,548][30412] Configuration (20, 20, 50000), timing [ff: 1.65s, ff_many: 4.14s, mp.queue: 46.71s]

# With Erik's changes to prevent stale (version 1.1.2)
# [2021-05-14 00:51:46,237][25370] Configuration (1, 1, 200000), timing [ff: 1.05s, ff_many: 1.10s, mp.queue: 2.32s]
# [2021-05-14 00:51:46,237][25370] Configuration (1, 10, 200000), timing [ff: 1.49s, ff_many: 1.51s, mp.queue: 3.31s]
# [2021-05-14 00:51:46,237][25370] Configuration (10, 1, 100000), timing [ff: 6.07s, ff_many: 0.97s, mp.queue: 12.92s]
# [2021-05-14 00:51:46,237][25370] Configuration (3, 20, 100000), timing [ff: 2.27s, ff_many: 2.19s, mp.queue: 6.55s]
# [2021-05-14 00:51:46,237][25370] Configuration (20, 3, 50000), timing [ff: 7.65s, ff_many: 0.70s, mp.queue: 15.40s]
# [2021-05-14 00:51:46,237][25370] Configuration (20, 20, 50000), timing [ff: 1.82s, ff_many: 4.14s, mp.queue: 31.65s]
# Ran 1 test in 103.115s

# Ubuntu 20, Python 3.9, version 1.3.1
# For some reason (10, 1) configuration became 2x slower. This does not seem to be any kind of regression in the code
# because it also reproduces in older versions on my new system. Seems to be caused by Python environment/Linux version/compiler?
# [2022-03-31 02:28:18,986][549096] Configuration (1, 1, 200000), timing [ff: 1.06s, ff_many: 1.06s, mp.queue: 2.10s]
# [2022-03-31 02:28:18,986][549096] Configuration (1, 10, 200000), timing [ff: 1.51s, ff_many: 1.52s, mp.queue: 2.96s]
# [2022-03-31 02:28:18,986][549096] Configuration (10, 1, 100000), timing [ff: 13.10s, ff_many: 0.99s, mp.queue: 11.54s]
# [2022-03-31 02:28:18,987][549096] Configuration (3, 20, 100000), timing [ff: 3.04s, ff_many: 2.22s, mp.queue: 7.19s]
# [2022-03-31 02:28:18,987][549096] Configuration (20, 3, 50000), timing [ff: 14.80s, ff_many: 0.67s, mp.queue: 15.29s]
# [2022-03-31 02:28:18,987][549096] Configuration (20, 20, 50000), timing [ff: 1.33s, ff_many: 3.91s, mp.queue: 21.46s]
# Ran 1 test in 105.765s

# Ubuntu 20, Python 3.8, version 1.4.1
# [2022-07-21 01:32:09,831][2374340] Configuration (1, 1, 200000), timing [ff: 1.13s, ff_many: 1.10s, mp.queue: 2.20s]
# [2022-07-21 01:32:09,831][2374340] Configuration (1, 10, 200000), timing [ff: 1.72s, ff_many: 1.69s, mp.queue: 3.51s]
# [2022-07-21 01:32:09,831][2374340] Configuration (10, 1, 100000), timing [ff: 13.71s, ff_many: 1.28s, mp.queue: 12.16s]
# [2022-07-21 01:32:09,831][2374340] Configuration (3, 20, 100000), timing [ff: 3.18s, ff_many: 2.29s, mp.queue: 7.81s]
# [2022-07-21 01:32:09,831][2374340] Configuration (20, 3, 50000), timing [ff: 15.47s, ff_many: 0.75s, mp.queue: 17.52s]
# [2022-07-21 01:32:09,831][2374340] Configuration (20, 20, 50000), timing [ff: 1.26s, ff_many: 3.74s, mp.queue: 27.82s]
# Ran 1 test in 118.350s


# i5-4200U (dual-core CPU)
# [2020-05-22 18:03:55,061][09146] Configuration (1, 1, 200000), timing [ff: 2.09s, ff_many: 2.20s, mp.queue: 7.86s]
# [2020-05-22 18:03:55,061][09146] Configuration (1, 10, 200000), timing [ff: 4.01s, ff_many: 3.88s, mp.queue: 11.68s]
# [2020-05-22 18:03:55,061][09146] Configuration (10, 1, 100000), timing [ff: 16.68s, ff_many: 5.98s, mp.queue: 44.48s]
# [2020-05-22 18:03:55,061][09146] Configuration (3, 20, 100000), timing [ff: 7.83s, ff_many: 7.49s, mp.queue: 22.59s]
# [2020-05-22 18:03:55,061][09146] Configuration (20, 3, 50000), timing [ff: 22.30s, ff_many: 6.35s, mp.queue: 66.30s]
# [2020-05-22 18:03:55,061][09146] Configuration (20, 20, 50000), timing [ff: 14.39s, ff_many: 15.78s, mp.queue: 78.75s]


# Update (2023.08.11)
# Ubuntu 18, Intel(R) Xeon(R) CPU E5-2650, Python 3.9.16
# [2023-08-11 17:37:58,796][40648] Configuration (1, 1, 200000), timing [ff: 2.28s, ff_many: 2.46s, mp.queue: 3.48s]
# [2023-08-11 17:37:58,796][40648] Configuration (1, 10, 200000), timing [ff: 2.71s, ff_many: 2.82s, mp.queue: 11.65s]
# [2023-08-11 17:37:58,796][40648] Configuration (10, 1, 100000), timing [ff: 13.69s, ff_many: 1.95s, mp.queue: 18.39s]
# [2023-08-11 17:37:58,796][40648] Configuration (3, 20, 100000), timing [ff: 2.97s, ff_many: 2.30s, mp.queue: 21.10s]
# [2023-08-11 17:37:58,796][40648] Configuration (20, 3, 50000), timing [ff: 19.75s, ff_many: 1.08s, mp.queue: 23.24s]
# [2023-08-11 17:37:58,796][40648] Configuration (20, 20, 50000), timing [ff: 2.81s, ff_many: 3.73s, mp.queue: 70.49s]
# Ran 1 test in 206.923s

# Ubuntu 18, Intel(R) Xeon(R) CPU E5-2650, Python 3.11.4
# [2023-08-11 17:46:36,056][42634] Configuration (1, 1, 200000), timing [ff: 1.67s, ff_many: 1.77s, mp.queue: 2.45s]
# [2023-08-11 17:46:36,056][42634] Configuration (1, 10, 200000), timing [ff: 2.27s, ff_many: 2.31s, mp.queue: 5.61s]
# [2023-08-11 17:46:36,056][42634] Configuration (10, 1, 100000), timing [ff: 13.15s, ff_many: 1.86s, mp.queue: 14.05s]
# [2023-08-11 17:46:36,056][42634] Configuration (3, 20, 100000), timing [ff: 2.97s, ff_many: 2.23s, mp.queue: 12.71s]
# [2023-08-11 17:46:36,056][42634] Configuration (20, 3, 50000), timing [ff: 19.22s, ff_many: 0.99s, mp.queue: 17.28s]
# [2023-08-11 17:46:36,056][42634] Configuration (20, 20, 50000), timing [ff: 2.60s, ff_many: 3.52s, mp.queue: 43.30s]
# Ran 1 test in 149.972s


import logging
import multiprocessing
import threading
from queue import Full, Empty
from typing import Callable
from unittest import TestCase

import numpy as np

from faster_fifo import Queue


# Configure the shared "rl" logger used by the test helpers below.
log = logging.getLogger("rl")
log.setLevel(logging.DEBUG)
log.propagate = False  # workaround for duplicated logs in ipython

ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)

log.handlers = []  # drop handlers left by a previous import so messages are not duplicated
log.addHandler(ch)

MSG_SIZE = 5
BIG_MSG_MAX_SIZE = 10_000_000  # cap (exclusive) on the random length drawn by make_big_msg


def make_msg(msg_idx):
    """Return a ``MSG_SIZE``-tuple in which every element is ``msg_idx``."""
    return tuple([msg_idx] * MSG_SIZE)


def make_big_msg(msg_idx):
    """Return an uninitialized ``uint8`` array of random length in ``[1, cap)``.

    ``cap`` is ``1000 * (msg_idx + 1)`` clamped to ``BIG_MSG_MAX_SIZE``, so message
    sizes grow with the index until they hit the global limit.
    """
    cap = min(1000 * (msg_idx + 1), BIG_MSG_MAX_SIZE)
    length = np.random.randint(1, cap)
    return np.empty(length, dtype=np.uint8)


def produce(q, p_idx, num_messages, make_msg_fn: Callable = make_msg):
    """Put ``num_messages`` messages ``make_msg_fn(i)`` into ``q``, retrying on ``Full``.

    Each put waits at most 10 ms. When it raises ``Full``, the same index is
    retried immediately. Progress is logged every 50000 messages, and a final
    "Done!" line is logged when finished.
    """
    sent = 0
    while sent < num_messages:
        try:
            q.put(make_msg_fn(sent), timeout=0.01)
        except Full:
            continue
        if sent % 50000 == 0:
            log.info("Produce: %d %d", sent, p_idx)
        sent += 1
    log.info("Done! %d", p_idx)


def consume(q, p_idx, consume_many, total_num_messages=int(1e9)):
    """Read batches of up to ``consume_many`` messages from ``q``.

    Stops once ``total_num_messages`` messages have been received, or when a read
    times out (``Empty``) and ``q.is_closed()`` is true. Messages whose first
    element is a multiple of 50000 are logged.
    """
    received = 0
    while True:
        try:
            batch = q.get_many(timeout=0.01, max_messages_to_get=consume_many)
        except Empty:
            if q.is_closed():
                break
            continue
        for msg in batch:
            received += 1
            if msg[0] % 50000 == 0:
                log.info("Consume: %r %d num_msgs: %d", msg, p_idx, len(batch))
        if received >= total_num_messages:
            break
    log.info("Done! %d", p_idx)


class TestFastQueue(TestCase):
    """Functional tests for faster_fifo.Queue.

    Covers single-process use, producer/consumer runs over threads and processes,
    byte- and message-count limits, and the basic put/get/size API.
    """

    def test_singleproc(self):
        q = Queue()
        produce(q, 0, num_messages=20)
        consume(q, 0, consume_many=2, total_num_messages=20)
        q.close()
        self.assertIsNone(q.last_error)

    def _run_workers(self, q, n_producers, n_consumers, n_msg, execution_medium, make_msg_fn, consume_many):
        """Run producers/consumers on ``q`` until all producers finish, then close ``q`` and check for errors.

        Consumers are started first so they are already draining when producers begin.
        ``q`` is closed only after every producer has joined. ``consume`` exits after
        an ``Empty`` read on a closed queue, so consumers are joined after the close.
        """
        producers = [
            execution_medium(target=produce, args=(q, j, n_msg, make_msg_fn))
            for j in range(n_producers)
        ]
        consumers = [
            execution_medium(target=consume, args=(q, j, consume_many))
            for j in range(n_consumers)
        ]
        for c in consumers:
            c.start()
        for p in producers:
            p.start()
        for p in producers:
            p.join()
        q.close()
        for c in consumers:
            c.join()

        self.assertIsNone(q.last_error)
        log.info("Exit...")

    def run_producer_consumer(
        self,
        n_producers: int,
        n_consumers: int,
        n_msg: int,
        execution_medium: type,
        make_msg_fn: Callable,
    ):
        """Producer/consumer run on a default-size Queue; consumers read up to 1000 messages per call."""
        self._run_workers(
            Queue(), n_producers, n_consumers, n_msg, execution_medium, make_msg_fn, consume_many=1000
        )

    def test_multiprocessing(self):
        self.run_producer_consumer(
            20,
            3,
            100001,
            execution_medium=multiprocessing.Process,
            make_msg_fn=make_msg,
        )

    def test_multithreading(self):
        self.run_producer_consumer(
            20,
            3,
            100001,
            execution_medium=threading.Thread,
            make_msg_fn=make_msg,
        )

    def test_multiprocessing_big_msg(self):
        self.run_producer_consumer(
            20,
            3,
            1001,
            execution_medium=multiprocessing.Process,
            make_msg_fn=make_big_msg,
        )

    def test_multithreading_big_msg(self):
        self.run_producer_consumer(
            20,
            20,
            101,
            execution_medium=threading.Thread,
            make_msg_fn=make_big_msg,
        )

    def test_msg(self):
        q = Queue(max_size_bytes=1000)

        py_obj = dict(a=42, b=33, c=(1, 2, 3), d=[1, 2, 3], e="123", f=b"kkk")
        q.put_nowait(py_obj)
        res = q.get_nowait()
        log.debug("got object %r", res)
        self.assertEqual(py_obj, res)

    def test_msg_many(self):
        q = Queue(max_size_bytes=100000)

        py_objs = [
            dict(a=42, b=33, c=(1, 2, 3), d=[1, 2, 3], e="123", f=b"kkk")
            for _ in range(5)
        ]
        q.put_many_nowait(py_objs)
        res = q.get_many_nowait()

        while not q.empty():
            res.extend(q.get_many_nowait())

        log.debug("Got object %r", res)
        self.assertEqual(py_objs, res)

        # A whole list can also travel as a single message.
        q.put_nowait(py_objs)
        res = q.get_nowait()
        self.assertEqual(py_objs, res)

    def test_queue_size(self):
        q = Queue(max_size_bytes=1000)
        py_obj_1 = dict(a=10, b=20)
        py_obj_2 = dict(a=30, b=40)
        q.put_nowait(py_obj_1)
        q.put_nowait(py_obj_2)
        q_size_bef = q.qsize()
        log.debug("Queue size after put -  %d", q_size_bef)
        self.assertIsInstance(q_size_bef, int)
        self.assertEqual(q_size_bef, 2)  # qsize() counts messages, not bytes

        num_messages = 0
        want_to_read = 2
        while num_messages < want_to_read:
            msgs = q.get_many()
            log.debug("Got messages %r", msgs)
            num_messages += len(msgs)

        q_size_af = q.qsize()
        log.debug("Queue size after get -  %d", q_size_af)
        self.assertEqual(q_size_af, 0)

    def test_queue_empty(self):
        q = Queue(max_size_bytes=1000)
        self.assertTrue(q.empty())
        py_obj = dict(a=42, b=33, c=(1, 2, 3), d=[1, 2, 3], e="123", f=b"kkk")
        q.put_nowait(py_obj)
        q_empty = q.empty()
        self.assertFalse(q_empty)

    def test_queue_data_size(self):
        q = Queue(max_size_bytes=1000)
        py_obj = dict(a=10, b=20)
        q.put_nowait(py_obj)
        py_obj_size = q.data_size()
        log.debug("Queue data size after put -  %d", py_obj_size)
        q.put_nowait(py_obj)
        # assertTrue(x, y) would treat y as the failure message and never compare anything.
        self.assertEqual(q.data_size(), 2 * py_obj_size)

    def test_queue_full(self):
        q = Queue(max_size_bytes=60)
        self.assertFalse(q.full())
        py_obj = (1, 2)
        # Keep adding until the 60-byte buffer rejects a message.
        while True:
            try:
                q.put_nowait(py_obj)
            except Full:
                self.assertTrue(q.full())
                break

    def test_queue_usage(self):
        q = Queue(1000 * 1000)  # specify the size of the circular buffer in the ctor

        # any pickle-able Python object can be added to the queue
        py_obj = dict(a=42, b=33, c=(1, 2, 3), d=[1, 2, 3], e="123", f=b"kkk")
        q.put(py_obj)
        self.assertEqual(q.qsize(), 1)

        retrieved = q.get()
        self.assertTrue(q.empty())
        self.assertEqual(py_obj, retrieved)

        for _ in range(100):
            try:
                q.put(py_obj, timeout=0.1)
            except Full:
                log.debug("Queue is full!")

        num_received = 0
        while num_received < 100:
            # get multiple messages at once, returns a list of messages for better performance in many-to-few scenarios
            # get_many does not guarantee that all max_messages_to_get will be received on the first call, in fact
            # no such guarantee can be made in multiprocessing systems.
            # get_many() will retrieve as many messages as there are available AND can fit in the pre-allocated memory
            # buffer. The size of the buffer is increased gradually to match demand.
            messages = q.get_many(max_messages_to_get=100)
            num_received += len(messages)

        # Everything was drained above, so get() must time out. The original
        # `assert True` after get() could never fail.
        with self.assertRaises(Empty):
            q.get(timeout=0.1)

    def test_max_size(self):
        # Create a queue with a maximum of 5 messages
        q = Queue(max_size_bytes=1000, maxsize=5)

        for i in range(5):
            q.put_nowait(make_msg(i))  # Add 5 messages to the queue

        q_size_bef = q.qsize()
        log.debug("Queue size after put -  %d", q_size_bef)

        self.assertTrue(q.full())  # Check that the queue is full

        # Check that adding another message raises the Full exception
        with self.assertRaises(Full):
            q.put_nowait(make_msg(5))
        q.get_many()
        self.assertFalse(q.full())  # Check that the queue is not full

    def test_queue_full_msgs(self):
        q = Queue(maxsize=5)
        self.assertFalse(q.full())
        py_obj = (1, 2)
        for _ in range(5):
            q.put_nowait(py_obj)
        self.assertTrue(q.full())
        with self.assertRaises(Full):
            q.put_nowait(py_obj)

    def test_producer_consumer_msgs(self):
        self.run_producer_consumer_msgs(
            1,
            1,
            10,
            threading.Thread,
            make_msg,
        )

    def run_producer_consumer_msgs(
        self,
        n_producers: int,
        n_consumers: int,
        n_msg: int,
        execution_medium: type,
        make_msg_fn: Callable,
    ):
        """Producer/consumer run on a Queue whose ``maxsize`` equals the total message count; batches of 5."""
        q = Queue(maxsize=n_msg * n_producers)
        self._run_workers(
            q, n_producers, n_consumers, n_msg, execution_medium, make_msg_fn, consume_many=5
        )


def spawn_producer(data_q_):
    """Put ten small lists ``[1, 2, 3, k]``, k = 0..9, into ``data_q_`` in order."""
    for k in range(10):
        data_q_.put([1, 2, 3, k])


def spawn_consumer(data_q_):
    """Print every message from ``data_q_`` until a 0.5 s read times out, then print the count."""
    received = 0
    while True:
        try:
            item = data_q_.get(timeout=0.5)
        except Empty:
            break
        print(item)
        received += 1
    print("Read", received, "messages")


class TestSpawn(TestCase):
    def test_spawn_ctx(self):
        """A Queue whose buffers are already in use can be handed to "spawn" child processes."""
        ctx = multiprocessing.get_context("spawn")
        data_q = Queue(1000 * 1000)
        procs = [ctx.Process(target=spawn_producer, args=(data_q,)) for _ in range(2)]
        procs.append(ctx.Process(target=spawn_consumer, args=(data_q,)))

        # Write some data and read part of it back before the children start, so that
        # all buffers are initialized and must be pickled for the spawned processes.
        for _ in range(10):
            data_q.put(self.test_spawn_ctx.__name__)
        print(data_q.get_many(max_messages_to_get=2))

        for proc in procs:
            proc.start()
        for proc in procs:
            proc.join()


# A fixed-format (de)serializer pair like this one can replace pickle when only
# specific data types must be supported; it should be significantly faster.
def custom_int_deserializer(msg_bytes):
    """Decode ``msg_bytes`` as a big-endian unsigned integer."""
    return int.from_bytes(msg_bytes, byteorder="big")


def custom_int_serializer(x):
    """Encode ``x`` as exactly four big-endian, unsigned bytes.

    Only ``0 <= x < 2**32`` fits; other values raise ``OverflowError``.
    """
    return x.to_bytes(length=4, byteorder="big")


class TestCustomSerializer(TestCase):
    def test_custom_loads_dumps(self):
        """Round-trip ints through a Queue configured with the 4-byte custom serializer pair."""
        q = Queue(
            max_size_bytes=100000,
            loads=custom_int_deserializer,
            dumps=custom_int_serializer,
        )
        for i in range(32767):
            q.put(i)
            # self.assertEqual instead of a bare assert: bare asserts are stripped
            # under `python -O` and give no diff on failure.
            self.assertEqual(q.get(), i)


class SubQueue(Queue):
    """Empty Queue subclass; TestSubclass passes it to spawned pool workers."""


def worker_test_subclass(_x: Queue, _y: Queue):
    """Pool initializer that does nothing; the test only needs the queue arguments to reach the worker."""
    return None


class TestSubclass(TestCase):
    def test_subclass(self):
        """A Queue subclass can be passed to "spawn" pool workers, just like Queue itself."""
        from multiprocessing.pool import Pool as MpPool

        ctx = multiprocessing.get_context("spawn")

        # Use each queue once in this process before handing it to the pool.
        base_q = Queue()
        base_q.put(1)
        base_q.get()
        sub_q = SubQueue()
        sub_q.put(2)
        sub_q.get()

        pool = MpPool(2, initializer=worker_test_subclass, initargs=(base_q, sub_q), context=ctx)
        pool.close()
        pool.join()


posted on 2024-03-01 13:54  Angry_Panda  阅读(200)  评论(0)    收藏  举报

导航