1 import redis
2 import time
3 from typing import Dict, List, Tuple, Any, Optional
4
5 from config.model import settings
6 from pydantic import BaseModel
7
8
class StreamMessage(BaseModel):
    """A single stream entry: its Redis entry ID plus its field/value mapping."""

    # Entry ID as returned by Redis, in "<ms>-<seq>" form.
    message_id: str
    # Field/value pairs stored in the entry.
    message_data: Dict[str, Any]
12
13
class StreamMessages(BaseModel):
    """A batch of entries read from one Redis stream."""

    # Key of the stream the entries were read from.
    stream_name: str
    # The entries, in the order Redis returned them.
    messages: List[StreamMessage]
17
18
class RedisStreamManager:
    """Thin wrapper around one Redis stream and its consumer groups.

    Args:
        stream_name: Key of the Redis stream to operate on.
        redis_url: Connection URL; falls back to ``settings.OMS_REDIS_URL`` when empty.
        timeout: Threshold in seconds used by the timeout helpers; it is
            converted to milliseconds for the XCLAIM min-idle-time.
        max_length: Maximum stream length. It is enforced exactly once on
            construction (XTRIM) and passed as MAXLEN to every XADD in ``add``.
    """

    def __init__(self, stream_name: str, redis_url: str = "", timeout: int = 2, max_length: int = 100000):
        self.redis_url = redis_url or settings.OMS_REDIS_URL
        self.redis = self._connect_to_redis()
        self.stream_name = stream_name
        self.timeout = timeout
        self.max_length = max_length

        # Trim the stream to at most ``max_length`` entries right away.
        # XTRIM does not create the key: a missing stream stays missing until
        # the first ``add`` or ``ensure_group`` (which uses MKSTREAM).
        self.redis.xtrim(self.stream_name, maxlen=self.max_length, approximate=False)

    def _connect_to_redis(self):
        """Create a client from ``self.redis_url`` and verify it with PING.

        Responses are decoded to ``str``.

        Raises:
            redis.ConnectionError: if the server cannot be reached.
        """
        client = redis.StrictRedis.from_url(self.redis_url, decode_responses=True)
        # Fail fast at construction time rather than on the first real command.
        client.ping()
        return client

    def add(self, message_data: Dict[str, str]) -> str:
        """Append ``message_data`` to the stream and return the new entry ID."""
        return self.redis.xadd(self.stream_name, message_data, maxlen=self.max_length)

    def ensure_group(self, group_name: str) -> None:
        """Create consumer group ``group_name`` if it does not exist yet.

        A newly created group starts at ID ``0``, so it sees every entry already
        in the stream. MKSTREAM creates the stream if it is missing.

        Raises:
            redis.exceptions.ResponseError: for any error other than the group
                already existing.
        """
        try:
            self.redis.xgroup_create(self.stream_name, group_name, id="0", mkstream=True)
        except redis.exceptions.ResponseError as e:
            # The group already existing is the normal case; anything else is a real error.
            if "BUSYGROUP Consumer Group name already exists" not in str(e):
                raise

    def consume(
        self,
        group_name: str,
        consumer_name: str,
        count: int = 10,
        block: Optional[int] = 5000,
    ) -> StreamMessages:
        """Read up to ``count`` never-delivered messages as ``consumer_name``.

        Messages read here stay in the group's pending list until ``ack`` is called.

        Args:
            group_name: Consumer group to read through (created if missing).
            consumer_name: Consumer recorded as the owner of the delivered messages.
            count: Maximum number of messages to return.
            block: Milliseconds to wait when no new message is available.
                ``None`` does not block at all.

        Returns:
            The messages read. The list may be empty if the wait timed out.
        """
        self.ensure_group(group_name)
        raw_messages = self.redis.xreadgroup(
            group_name, consumer_name, {self.stream_name: ">"}, count=count, block=block
        )

        # A timed-out blocking read yields an empty/None result; treat both as "no messages".
        stream_messages = [
            StreamMessage(message_id=message_id, message_data=dict(message_data))
            for _, message_list in raw_messages or []
            for message_id, message_data in message_list
        ]
        return StreamMessages(stream_name=self.stream_name, messages=stream_messages)

    def ack(self, group_name: str, message_id: str) -> None:
        """Acknowledge ``message_id`` so it leaves the group's pending list."""
        self.redis.xack(self.stream_name, group_name, message_id)

    def reassign(self, group_name: str, consumer_name: str, message_id: str) -> None:
        """Transfer ownership of pending ``message_id`` to ``consumer_name`` via XCLAIM.

        The claim only takes effect if the message has been idle for at least
        ``self.timeout`` seconds; otherwise it stays with its current owner.
        """
        # redis-py's ``xclaim`` takes a list/tuple of IDs via ``message_ids``;
        # it has no ``id`` keyword (passing one raises TypeError).
        self.redis.xclaim(
            self.stream_name,
            group_name,
            consumer_name,
            min_idle_time=self.timeout * 1000,
            message_ids=[message_id],
        )

    def _pending_ids(self, group_name: str) -> List[str]:
        """Return the ID of every entry currently in the group's pending list."""
        summary = self.redis.xpending(self.stream_name, group_name)
        pending_count = summary.get("pending") or 0
        if not pending_count:
            return []
        # Extended XPENDING form: one record per pending entry, not just min/max bounds.
        entries = self.redis.xpending_range(
            self.stream_name, group_name, min="-", max="+", count=pending_count
        )
        return [entry["message_id"] for entry in entries]

    def query_unconfirmed(self, group_name: str) -> StreamMessages:
        """Return the messages that were delivered to the group but not yet acknowledged.

        Only IDs actually in the pending list are returned. Acknowledged
        entries that happen to lie between the oldest and newest pending IDs
        are not included. Pending entries whose data has already been trimmed
        or deleted from the stream are skipped.
        """
        pending_ids = self._pending_ids(group_name)
        if not pending_ids:
            return StreamMessages(stream_name=self.stream_name, messages=[])

        # Fetch every pending entry's body in one round trip.
        # XRANGE with start == end returns at most that single entry.
        pipe = self.redis.pipeline(transaction=False)
        for message_id in pending_ids:
            pipe.xrange(self.stream_name, message_id, message_id, count=1)

        stream_messages = [
            StreamMessage(message_id=entry_id, message_data=dict(message_data))
            for entries in pipe.execute()
            for entry_id, message_data in entries
        ]
        return StreamMessages(stream_name=self.stream_name, messages=stream_messages)

    def check_timeout(self, message_id: str) -> bool:
        """Return True if the entry is older than ``self.timeout`` seconds.

        Age is derived from the millisecond timestamp in the ID (``<ms>-<seq>``).
        For auto-generated IDs this measures time since the entry was added,
        not time since it was last delivered.
        """
        message_time_ms = int(message_id.split("-")[0])
        return (time.time() - message_time_ms / 1000) > self.timeout

    def get_all_timeout_messages(self, group_name: str) -> List[str]:
        """Return the IDs of pending messages for which ``check_timeout`` is True."""
        return [
            message_id
            for message_id in self._pending_ids(group_name)
            if self.check_timeout(message_id)
        ]

    def handle_timeout(self, group_name: str, consumer_name: str) -> None:
        """Try to claim every timed-out pending message for ``consumer_name``."""
        for message_id in self.get_all_timeout_messages(group_name):
            self.reassign(group_name, consumer_name, message_id)

    def reassign_all_unconfirmed(self, group_name: str, consumer_name: str) -> None:
        """Try to claim every pending message of the group for ``consumer_name``.

        Each claim goes through ``reassign``. Messages idle for less than
        ``self.timeout`` seconds therefore keep their current owner.
        """
        for message_id in self._pending_ids(group_name):
            self.reassign(group_name, consumer_name, message_id)

    def query_all(self, start: str = "-", end: str = "+") -> Optional[List[Tuple[str, Dict[str, str]]]]:
        """Return the raw ``(id, fields)`` entries between ``start`` and ``end``, inclusive.

        On a Redis ResponseError (e.g. a malformed ID), the error is printed
        and None is returned.
        """
        try:
            return self.redis.xrange(self.stream_name, min=start, max=end)
        except redis.exceptions.ResponseError as e:
            print(f"Error querying all messages: {e}")
            return None

    def weixiaofei(self, group_name):
        """Return the raw XPENDING summary for ``group_name``.

        The name appears to be pinyin for "unconsumed".
        """
        return self.redis.xpending(self.stream_name, group_name)
135
136
def test4():
    """Manual smoke test against a live Redis server.

    Adds five messages and consumes and acknowledges one. It then consumes
    another without acknowledging it and prints the group's pending list.
    """
    stream_name = "my_stream"
    group_name = "my_group"
    consumer_name = "my_consumer"
    manager = RedisStreamManager(stream_name=stream_name)

    # Make sure the consumer group exists.
    manager.ensure_group(group_name)

    # Example: add some messages.
    print("Adding messages...")
    for i in range(5):
        message_id = manager.add({"key": f"value_{i}"})
        print(f"Added message with ID: {message_id}")

    data = manager.query_all()
    print(f"\ndata:{data}")

    # Consume one message and acknowledge it. The read can come back empty
    # (e.g. another consumer already took the messages), so guard the index.
    data = manager.consume(group_name, consumer_name, count=1)
    if data.messages:
        first_id = data.messages[0].message_id
        print(f"\n未消费的:{first_id}")
        manager.ack(group_name, first_id)

    # Consume one more message but leave it unacknowledged.
    data = manager.consume(group_name, consumer_name, count=1)
    print(f"\n消费未确认:{data}")

    data = manager.query_unconfirmed(group_name)
    print(f"\n未消费列表:{data}")
164
165
def test3():
    """Manual check: print the pending messages of ``my_group`` on ``my_stream``."""
    group = "my_group"
    manager = RedisStreamManager(stream_name="my_stream")

    # The group has to exist before its pending list can be inspected.
    manager.ensure_group(group)

    pending = manager.query_unconfirmed(group)
    print(f"\n未消费的:{pending}")
177
178
def test1():
    """Manual end-to-end demo against a live Redis server.

    Steps:
    1. Add five messages.
    2. Consume them, leaving the first unacknowledged.
    3. Sleep past the timeout before acknowledging the second.
    4. Print the stream, a follow-up read, the timed-out pending IDs and the
       remaining pending list.
    """
    stream_name = "my_stream"
    group_name = "my_group"
    consumer_name = "my_consumer"
    manager = RedisStreamManager(stream_name=stream_name)

    # Make sure the consumer group exists.
    manager.ensure_group(group_name)

    # Example: add some messages.
    print("Adding messages...")
    for i in range(5):
        message_id = manager.add({"key": f"value_{i}"})
        print(f"Added message with ID: {message_id}")

    data = manager.query_all()
    print(f"\ndata:{data}")

    # Example: consume messages.
    test_cus = False
    test_timeout = False
    print("\nConsuming messages...")
    # ``consume`` returns a StreamMessages model. Iterate its ``messages`` list:
    # iterating the model itself yields (field_name, value) pairs instead.
    for message in manager.consume(group_name, consumer_name).messages:
        message_id = message.message_id
        message_data = message.message_data
        print(f"Received message {message_id}: {message_data}")

        if not test_cus:
            # Leave the first message unacknowledged so it stays pending.
            test_cus = True
            continue
        elif not test_timeout:
            test_timeout = True
            # Sleep past the timeout before acknowledging the second message.
            time.sleep(2 * manager.timeout)

        # Acknowledge the message.
        manager.ack(group_name, message_id)
        print(f"Acknowledged message {message_id}")

    data = manager.query_all()
    print(f"\n所有数据:{data}")

    data = manager.consume(group_name, consumer_name)
    print(f"\n未消费的:{data}")

    # ``handle_timeout`` returns nothing, so collect the timed-out IDs first
    # and print those instead of None.
    data = manager.get_all_timeout_messages(group_name)
    manager.handle_timeout(group_name, consumer_name)
    print(f"\n超时的:{data}")

    data = manager.query_unconfirmed(group_name)
    print(f"\n待确认的:{data}")
231
232
def test2():
    """Manual demo against a live Redis server.

    Consumes messages, acknowledging each one only after sleeping past the
    timeout. It then exercises the timeout/reassignment helpers.
    """
    stream_name = "my_stream"
    group_name = "my_group"
    consumer_name = "my_consumer"
    manager = RedisStreamManager(stream_name=stream_name)

    # Make sure the consumer group exists.
    manager.ensure_group(group_name)

    # Example: consume messages.
    print("\nConsuming messages...")
    # ``consume`` returns a StreamMessages model. Iterate its ``messages`` list:
    # iterating the model itself yields (field_name, value) pairs instead.
    for message in manager.consume(group_name, consumer_name).messages:
        message_id = message.message_id
        message_data = message.message_data
        print(f"Received message {message_id}: {message_data}")

        # Wait past the timeout before acknowledging, to simulate slow processing.
        time.sleep(2 * manager.timeout)

        # Acknowledge the message.
        manager.ack(group_name, message_id)
        print(f"Acknowledged message {message_id}")

    # Handle timed-out messages.
    print("\nHandling timeout messages...")
    manager.handle_timeout(group_name, consumer_name)
    print("Handled timeout messages")

    # List all timed-out messages.
    print("\nGetting all timeout messages...")
    timeout_messages = manager.get_all_timeout_messages(group_name)
    print(f"All timeout messages: {timeout_messages}")

    # Try to claim every unacknowledged message for this consumer.
    print("\nReassigning all unconfirmed messages...")
    manager.reassign_all_unconfirmed(group_name, consumer_name)
    print("Reassigned all unconfirmed messages")