"""python3 tcp_server

基于python3的tcp_server (a multithreaded TCP server built on Python 3).
运行环境 (runtime environment): python3.10.13
"""
import socket
import traceback
import logging
import threading
import time
import json
from typing import Callable, Optional
from concurrent.futures import ThreadPoolExecutor
import queue
from queue import Queue
import os
# Logging configuration: DEBUG and above go to stderr and to tcp_server.log.
# The log file is opened as UTF-8 explicitly; otherwise it uses the locale's
# default encoding, and client text logged at DEBUG level (e.g. Chinese) can
# fail to encode on hosts whose default encoding is not UTF-8.
_console_handler = logging.StreamHandler()
_file_handler = logging.FileHandler('tcp_server.log', encoding='utf-8')
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[_console_handler, _file_handler],
)
logger = logging.getLogger(__name__)
class BoundedThreadPoolExecutor(ThreadPoolExecutor):
    """ThreadPoolExecutor that rejects new work instead of queueing it without bound.

    At most ``max_size`` submitted tasks may be waiting for a free worker
    thread at any time. Once that limit is reached, ``submit()`` logs a warning
    and raises ``queue.Full`` immediately instead of blocking the caller.
    A ``max_size`` of 0 or less means "unbounded" (same as ``queue.Queue``).

    Args:
        max_size: Maximum number of tasks allowed to wait for a worker.
        *args, **kwargs: Forwarded unchanged to ``ThreadPoolExecutor``.
    """

    def __init__(self, max_size=100, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._max_size = max_size
        # One permit per waiting task. A semaphore is used instead of swapping
        # the executor's private work queue for Queue(maxsize): with a full
        # Queue, the base submit() blocks in put() while holding the executor's
        # shutdown locks, and queue.Full is never raised.
        self._queue_slots = threading.BoundedSemaphore(max_size) if max_size > 0 else None

    def submit(self, fn, /, *args, **kwargs):
        """Schedule ``fn(*args, **kwargs)`` and return its Future.

        Raises:
            queue.Full: ``max_size`` tasks are already waiting for a worker.
            RuntimeError: the executor has been shut down (from the base class).
        """
        slots = self._queue_slots
        if slots is None:
            return super().submit(fn, *args, **kwargs)
        if not slots.acquire(blocking=False):
            logger.warning(f"Task queue full ({self._max_size})")
            raise queue.Full

        def run_and_free_slot():
            # The task stops "waiting" as soon as a worker starts it.
            slots.release()
            return fn(*args, **kwargs)

        try:
            future = super().submit(run_and_free_slot)
        except BaseException:
            slots.release()  # the task was never queued
            raise
        # A task cancelled before it started never runs the wrapper above, so
        # give its slot back here. A future can only be cancelled while still
        # pending, so each slot is released exactly once.
        future.add_done_callback(lambda f: slots.release() if f.cancelled() else None)
        return future
class TCPServer:
    """Multithreaded TCP server.

    Accepted clients are registered in ``self.connections`` and served on a
    bounded worker pool by ``handle_client_func`` (an echo handler by default).
    A background monitor thread closes clients that have been idle for more
    than 1.5 x ``timeout`` and periodically logs connection/pool statistics.
    """

    def __init__(self, host: str, port: int, backlog: int = 200,
                 buffer_size: int = 4096, timeout: float = 30.0,
                 handle_client_func: Optional[Callable] = None):
        """
        Initialise the server. No socket is opened until start() is called.

        Args:
            host: Address to listen on.
            port: Port to listen on.
            backlog: Length of the pending-connection queue passed to listen().
                This is not a limit on the number of clients served at once.
            buffer_size: Maximum number of bytes read per recv() call.
            timeout: Per-client socket timeout in seconds. The monitor thread
                also closes clients idle for more than 1.5x this value; activity
                is only recorded by the default handler.
            handle_client_func: Optional ``func(client_socket, client_address)``
                that serves one client; defaults to the echo handler. The server
                closes the socket after the function returns.
        """
        self.host = host
        self.port = port
        self.backlog = backlog
        self.buffer_size = buffer_size
        self.timeout = timeout
        self.handle_client_func = handle_client_func or self._default_handle_client
        self.server_socket = None
        self.is_running = False
        # Active clients keyed by "ip:port"; only touch it while holding connection_lock.
        self.connections = {}
        # Worker count scales with the CPU count, capped at 200.
        self.max_workers = min(200, (os.cpu_count() or 1) * 10)
        self.thread_pool = BoundedThreadPoolExecutor(
            # Bound on clients waiting for a free worker, to cap memory use.
            max_size=500,
            max_workers=self.max_workers,
            thread_name_prefix='TCP_Worker',
        )
        self.connection_lock = threading.Lock()

    def start(self) -> None:
        """Open the listening socket and serve clients until stop() is called.

        Blocks the calling thread. However the accept loop ends (stop(), an
        exception, KeyboardInterrupt), stop() runs on the way out.
        """
        try:
            listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.server_socket = listener
            listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            # Short accept() timeout so the loop re-checks is_running regularly.
            listener.settimeout(1.0)
            listener.bind((self.host, self.port))
            listener.listen(self.backlog)
            self.is_running = True
            logger.info(f"Server started on {self.host}:{self.port}")

            threading.Thread(target=self._monitor_connections, daemon=True).start()

            while self.is_running:
                try:
                    client_socket, client_address = listener.accept()
                except socket.timeout:
                    continue  # periodic wake-up to re-check is_running
                except OSError as e:
                    # Expected once stop() has closed the listener; only report it while running.
                    if self.is_running:
                        logger.error(f"Socket error: {e}")
                        time.sleep(0.1)  # avoid spinning if accept() keeps failing
                    continue
                self._dispatch_client(client_socket, client_address)
        except Exception as e:
            logger.exception(f"Failed to start server: {e}")
        finally:
            self.stop()

    def _dispatch_client(self, client_socket: socket.socket, client_address: tuple) -> None:
        """Register a freshly accepted client and hand it to the worker pool.

        If the hand-off fails (the pool raises queue.Full, has been shut down,
        ...), no worker will ever own the socket, so it is unregistered and
        closed here rather than leaked.
        """
        client_key = f"{client_address[0]}:{client_address[1]}"
        dispatched = False
        try:
            client_socket.settimeout(self.timeout)
            with self.connection_lock:
                self.connections[client_key] = {
                    'socket': client_socket,
                    'address': client_address,
                    'timestamp': time.time(),
                    'active': True,
                    # Per-connection buffer slot; the default handler uses its own local buffer.
                    'buffer': bytearray(),
                }
            logger.info(f"Client connected: {client_key}")
            self.thread_pool.submit(
                self._handle_client_connection,
                client_socket,
                client_address,
            )
            dispatched = True
        except queue.Full:
            logger.warning(f"Rejected client {client_key}: worker queue is full")
        except Exception as e:
            logger.exception(f"Unexpected error in accept loop: {e}")
            time.sleep(0.1)
        finally:
            if not dispatched:
                with self.connection_lock:
                    self.connections.pop(client_key, None)
                client_socket.close()

    @staticmethod
    def _shutdown_and_close(sock: socket.socket) -> None:
        """Shut down both directions of *sock*, then close it.

        shutdown() makes a handler blocked in recv() on this socket see EOF
        promptly; close() alone does not guarantee the connection is torn down.
        """
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass  # not connected / already closed; close() below still releases it
        sock.close()

    def stop(self) -> None:
        """Stop accepting, disconnect every client and shut down the worker pool.

        Safe to call more than once (start() calls it on exit, and callers may
        call it again). Blocks until the worker threads have finished; clients
        still waiting in the pool queue are cancelled.
        """
        self.is_running = False

        # Take ownership of every registered client under the lock, then close outside it.
        with self.connection_lock:
            clients = list(self.connections.items())
            self.connections.clear()
        for client_key, client_info in clients:
            if not client_info['active']:
                continue
            try:
                self._shutdown_and_close(client_info['socket'])
                logger.info(f"Closed client connection: {client_key}")
            except Exception as e:
                logger.error(f"Error closing client connection {client_key}: {e}")

        if self.server_socket:
            try:
                self.server_socket.close()
                logger.info("Server socket closed")
            except Exception as e:
                logger.error(f"Error closing server socket: {e}")
            self.server_socket = None

        self.thread_pool.shutdown(wait=True, cancel_futures=True)
        logger.info("Thread pool shutdown completed")

    def _handle_client_connection(self, client_socket: socket.socket, client_address: tuple) -> None:
        """Worker entry point: serve one client, then always unregister and close it."""
        client_key = f"{client_address[0]}:{client_address[1]}"
        try:
            self.handle_client_func(client_socket, client_address)
        except socket.timeout:
            logger.warning(f"Client timeout: {client_key}")
        except OSError as e:
            logger.error(f"Socket error with client {client_key}: {e}")
        except Exception as e:
            logger.exception(f"Unexpected error handling client {client_key}: {e}")
        finally:
            with self.connection_lock:
                self.connections.pop(client_key, None)
            try:
                client_socket.close()
                logger.info(f"Client disconnected: {client_key}")
            except Exception as e:
                logger.error(f"Error closing client socket {client_key}: {e}")

    def _default_handle_client(self, client_socket: socket.socket, client_address: tuple) -> None:
        """Default handler: read until the peer closes, echoing data via _process_message.

        No framing (e.g. length prefix) is applied: whatever has been received
        so far is treated as one message. Each read refreshes the connection's
        idle timestamp. A socket timeout is re-raised so the caller reports it
        as a client timeout; other errors are logged and end the session.
        """
        client_key = f"{client_address[0]}:{client_address[1]}"
        connection_buffer = bytearray()
        try:
            while True:
                data = client_socket.recv(self.buffer_size)
                if not data:
                    break  # peer closed the connection
                with self.connection_lock:
                    client_info = self.connections.get(client_key)
                    if client_info is not None:
                        client_info['timestamp'] = time.time()
                connection_buffer.extend(data)
                logger.debug(f"Received {len(data)} bytes from {client_key}, buffer size: {len(connection_buffer)}")
                # Handle every complete message currently in the buffer.
                while self._process_message(client_socket, client_key, connection_buffer):
                    pass
        except socket.timeout:
            raise
        except Exception as e:
            logger.error(f"Error handling data from {client_key}: {e}")

    @staticmethod
    def _complete_utf8_length(data) -> int:
        """Return how many leading bytes of *data* can be processed now.

        That is all of them, unless *data* ends in a truncated UTF-8 sequence
        (a multi-byte character split across recv() calls). In that case it is
        the offset where the truncated sequence starts. Data that is invalid for
        any other reason is reported as fully processable, so the handler
        rejects it rather than waiting for more bytes.
        """
        try:
            data.decode('utf-8')
        except UnicodeDecodeError as exc:
            if exc.reason == 'unexpected end of data':
                return exc.start
        return len(data)

    def _process_message(self, client_socket: socket.socket, client_key: str, buffer: bytearray) -> bool:
        """
        Take the next message out of *buffer* and pass it to _handle_message.
        Customise this to implement a real framing protocol.

        The processable bytes are removed from *buffer* before handling. A
        truncated trailing UTF-8 character is left in *buffer* to be completed
        by the next chunk.

        Returns:
            True if a message was handled successfully; False if there was
            nothing to process or handling failed.
        """
        try:
            end = self._complete_utf8_length(buffer)
            if end == 0:
                return False
            message = bytes(buffer[:end])
            del buffer[:end]
            return self._handle_message(client_socket, client_key, message)
        except Exception as e:
            logger.error(f"Message processing error: {e}")
            return False

    def _handle_message(self, client_socket: socket.socket, client_key: str, data_recv: bytes) -> bool:
        """Default message handler: require UTF-8 text and echo the raw bytes back.

        Returns False (after logging) for data that is not valid UTF-8. Errors
        from sendall() propagate to the caller.
        """
        try:
            data_str = data_recv.decode('utf-8')
        except UnicodeDecodeError:
            logger.error(f"Invalid UTF-8 data from {client_key}")
            return False
        logger.debug(f"Received text from {client_key}: {data_str}")
        client_socket.sendall(data_recv)
        return True

    def _monitor_connections(self) -> None:
        """Background loop (every ~5 s): evict idle clients and log stats about once a minute."""
        report_interval = 60.0
        last_report = 0.0
        while self.is_running:
            try:
                now = time.time()
                idle_limit = self.timeout * 1.5
                # Detach idle clients under the lock; close them after releasing it.
                with self.connection_lock:
                    idle_clients = [
                        (key, info) for key, info in self.connections.items()
                        if info['active'] and now - info['timestamp'] > idle_limit
                    ]
                    for key, _ in idle_clients:
                        del self.connections[key]
                    active_count = len(self.connections)

                for client_key, client_info in idle_clients:
                    client_info['active'] = False
                    idle_time = now - client_info['timestamp']
                    try:
                        self._shutdown_and_close(client_info['socket'])
                        logger.warning(f"Closed idle connection: {client_key} (idle for {idle_time:.1f}s)")
                    except Exception as e:
                        logger.error(f"Error closing idle connection {client_key}: {e}")

                # Track the last report time instead of sampling "now % 60", which a
                # 5 s polling loop only hits at irregular intervals.
                if active_count > 0 and now - last_report >= report_interval:
                    last_report = now
                    logger.info(f"Active connections: {active_count}")
                    logger.info(f"ThreadPool status: {self.thread_pool._work_queue.qsize()} queued, "
                                f"{len(self.thread_pool._threads)} worker threads")
                time.sleep(5)
            except Exception as e:
                logger.error(f"Error in connection monitor: {e}")
                time.sleep(1)
# Usage example: run the server directly as a script.
if __name__ == "__main__":
    # Listen on every interface, on port 12123.
    HOST, PORT = '0.0.0.0', 12123

    server = TCPServer(HOST, PORT)
    try:
        server.start()  # blocks until the server stops
    except KeyboardInterrupt:
        logger.info("Server shutting down due to keyboard interrupt")
        server.stop()