python3 tcp_server

基于python3的tcp_server
运行环境: python3.10.13

import socket
import traceback
import logging
import threading
import time
import json
from typing import Callable, Optional
from concurrent.futures import ThreadPoolExecutor
import queue
from queue import Queue
import os

# Logging setup: DEBUG and above is written both to the console stream and
# to the file tcp_server.log in the current working directory.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
    handlers=[logging.StreamHandler(), logging.FileHandler('tcp_server.log')],
)
# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)


class BoundedThreadPoolExecutor(ThreadPoolExecutor):
    """ThreadPoolExecutor that rejects work instead of queueing without limit.

    At most ``max_workers + max_size`` tasks may be in flight (running or
    waiting) at any time; a ``submit()`` beyond that raises ``queue.Full``
    immediately rather than blocking the caller.

    The executor's internal work queue is deliberately left untouched:
    swapping it for a bounded ``queue.Queue`` makes ``submit()`` block on a
    full queue (``Queue.put`` blocks by default, so ``queue.Full`` is never
    raised) and can also block ``shutdown()``, which enqueues wake-up
    sentinels on the same queue.

    Args:
        max_size: Number of tasks allowed to wait for a free worker. A value
            of 0 or less means "unbounded", matching ``queue.Queue``.
        *args, **kwargs: Forwarded unchanged to ``ThreadPoolExecutor``.
    """

    def __init__(self, max_size=100, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._max_size = max_size
        # One slot per task in flight: those running on a worker plus those
        # waiting in the queue. None disables the bound entirely.
        if max_size > 0:
            self._slots = threading.Semaphore(self._max_workers + max_size)
        else:
            self._slots = None

    def submit(self, fn, /, *args, **kwargs):
        """Schedule ``fn(*args, **kwargs)`` and return its Future.

        Raises:
            queue.Full: If the executor already holds its maximum number of
                in-flight tasks.
            RuntimeError: If the executor has been shut down (raised by the
                base class).
        """
        slots = self._slots
        if slots is None:
            return super().submit(fn, *args, **kwargs)

        if not slots.acquire(blocking=False):
            # getLogger(__name__) returns the same logger object the rest of
            # this module uses.
            logging.getLogger(__name__).warning(
                f"Task queue full ({self._max_size})"
            )
            raise queue.Full

        try:
            future = super().submit(fn, *args, **kwargs)
        except BaseException:
            # Nothing was scheduled, so give the slot back.
            slots.release()
            raise

        # Fires exactly once when the task finishes, fails or is cancelled.
        future.add_done_callback(lambda _done: slots.release())
        return future

class TCPServer:
    """Multi-threaded TCP server that hands each accepted client to a worker pool.

    Every accepted client is recorded in ``self.connections`` under the key
    ``"ip:port"`` with the fields ``socket``, ``address``, ``timestamp``
    (last activity, ``time.time()``), ``active`` and ``buffer``. All access
    to that dictionary is guarded by ``self.connection_lock``.

    ``stop()`` shuts the worker pool down, so an instance is single-use.
    """

    def __init__(self, host: str, port: int, backlog: int = 200,
                 buffer_size: int = 4096, timeout: float = 30.0,
                 handle_client_func: Optional[Callable] = None):
        """
        Initialise the server; no socket is opened until ``start()``.

        Args:
            host: Address the listening socket binds to.
            port: Port the listening socket binds to.
            backlog: Value passed to ``socket.listen()`` (pending-connection
                queue length, not a cap on concurrent clients).
            buffer_size: Maximum number of bytes read per ``recv()`` call.
            timeout: Per-client socket timeout in seconds; the monitor thread
                additionally drops clients idle for 1.5x this value.
            handle_client_func: Callable ``(client_socket, client_address)``
                that serves one client; defaults to the built-in echo handler.
        """
        self.host = host
        self.port = port
        self.backlog = backlog
        self.buffer_size = buffer_size
        self.timeout = timeout
        self.handle_client_func = handle_client_func or self._default_handle_client
        self.server_socket = None
        self.is_running = False
        self.connections = {}  # active connections keyed by "ip:port"
        self.connection_lock = threading.Lock()  # guards self.connections
        self._stopped = False  # set once by stop() so repeat calls are no-ops

        # Worker pool sized from the CPU count (capped at 200) with a bounded
        # backlog of waiting tasks to keep memory use in check.
        self.max_workers = min(200, (os.cpu_count() or 1) * 10)
        self.thread_pool = BoundedThreadPoolExecutor(
            max_size=500,
            max_workers=self.max_workers,
            thread_name_prefix='TCP_Worker',
        )

    @staticmethod
    def _close_socket(sock: socket.socket) -> None:
        """Shut down both directions of ``sock`` and close it.

        ``shutdown()`` is what wakes a worker thread blocked in ``recv()`` on
        this socket; ``close()`` alone from another thread may leave it
        waiting until the socket timeout expires.
        """
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            # Already closed or never connected: nothing left to shut down.
            pass
        sock.close()

    def start(self) -> None:
        """Bind, listen and serve until ``stop()`` is called.

        Blocks the calling thread. Start-up failures are logged rather than
        raised, and ``stop()`` always runs on the way out.
        """
        try:
            # Create and configure the listening socket.
            self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            # Short accept() timeout so the loop can notice is_running changes.
            self.server_socket.settimeout(1.0)
            self.server_socket.bind((self.host, self.port))
            self.server_socket.listen(self.backlog)

            self.is_running = True
            logger.info(f"Server started on {self.host}:{self.port}")

            # Background thread that drops idle clients and reports status.
            monitor_thread = threading.Thread(target=self._monitor_connections, daemon=True)
            monitor_thread.start()

            # Main loop: accept clients and dispatch them to the pool.
            while self.is_running:
                try:
                    client_socket, client_address = self.server_socket.accept()
                    self._accept_client(client_socket, client_address)
                except socket.timeout:
                    # Only a wake-up to re-check is_running.
                    continue
                except OSError as e:
                    if self.is_running:
                        logger.error(f"Socket error: {e}")
                        time.sleep(0.1)  # avoid a busy loop on repeated errors
                except Exception as e:
                    logger.exception(f"Unexpected error in accept loop: {e}")
                    time.sleep(0.1)

        except Exception as e:
            logger.exception(f"Failed to start server: {e}")
        finally:
            self.stop()

    def _accept_client(self, client_socket: socket.socket, client_address: tuple) -> None:
        """Register a freshly accepted client and submit it to the worker pool.

        If the client cannot be dispatched (pool full or shut down, socket
        error) it is unregistered and its socket closed, so nothing leaks.
        """
        client_key = f"{client_address[0]}:{client_address[1]}"
        try:
            client_socket.settimeout(self.timeout)

            with self.connection_lock:
                registered = self.is_running
                if registered:
                    self.connections[client_key] = {
                        'socket': client_socket,
                        'address': client_address,
                        'timestamp': time.time(),
                        'active': True,
                        'buffer': bytearray(),  # per-connection receive buffer
                    }
            if not registered:
                # stop() ran while accept() was returning; drop the client.
                self._close_socket(client_socket)
                return

            logger.info(f"Client connected: {client_key}")

            self.thread_pool.submit(
                self._handle_client_connection,
                client_socket,
                client_address
            )
        except Exception as e:
            with self.connection_lock:
                self.connections.pop(client_key, None)
            try:
                self._close_socket(client_socket)
            except Exception as close_error:
                logger.error(f"Error closing client socket {client_key}: {close_error}")
            logger.error(f"Rejected client {client_key}: {e!r}")

    def stop(self) -> None:
        """Stop the server, close every client and shut the worker pool down.

        Safe to call more than once and from any thread other than a pool
        worker; only the first call does any work.
        """
        with self.connection_lock:
            if self._stopped:
                return
            self._stopped = True
            self.is_running = False
            # Take ownership of the entries so sockets are closed outside the lock.
            open_clients = list(self.connections.items())
            self.connections.clear()

        # Close all client connections.
        for client_key, client_info in open_clients:
            if client_info['active']:
                try:
                    self._close_socket(client_info['socket'])
                    logger.info(f"Closed client connection: {client_key}")
                except Exception as e:
                    logger.error(f"Error closing client connection {client_key}: {e}")

        # Close the listening socket.
        if self.server_socket:
            try:
                self.server_socket.close()
                logger.info("Server socket closed")
            except Exception as e:
                logger.error(f"Error closing server socket: {e}")

        # Cancel queued tasks and wait for running handlers to finish.
        self.thread_pool.shutdown(wait=True, cancel_futures=True)
        logger.info("Thread pool shutdown completed")

    def _handle_client_connection(self, client_socket: socket.socket, client_address: tuple) -> None:
        """Run the client handler for one connection and always clean up after it."""
        client_key = f"{client_address[0]}:{client_address[1]}"

        try:
            self.handle_client_func(client_socket, client_address)
        except socket.timeout:
            logger.warning(f"Client timeout: {client_key}")
        except socket.error as e:
            logger.error(f"Socket error with client {client_key}: {e}")
        except Exception as e:
            logger.exception(f"Unexpected error handling client {client_key}: {e}")
        finally:
            # The entry may already be gone if the monitor or stop() removed it.
            with self.connection_lock:
                self.connections.pop(client_key, None)

            try:
                self._close_socket(client_socket)
                logger.info(f"Client disconnected: {client_key}")
            except Exception as e:
                logger.error(f"Error closing client socket {client_key}: {e}")

    def _default_handle_client(self, client_socket: socket.socket, client_address: tuple) -> None:
        """Default handler: read until the peer closes, echoing each chunk back.

        Each received chunk is passed to ``_process_message`` as one message;
        no framing is applied. Socket errors and timeouts propagate to the
        caller so ``_handle_client_connection`` can log them by type.
        """
        client_key = f"{client_address[0]}:{client_address[1]}"

        # Receive buffer private to this connection.
        connection_buffer = bytearray()

        while True:
            data = client_socket.recv(self.buffer_size)

            if not data:
                break  # peer closed the connection

            # Refresh the activity timestamp used by the idle monitor.
            with self.connection_lock:
                client_info = self.connections.get(client_key)
                if client_info is not None:
                    client_info['timestamp'] = time.time()

            connection_buffer.extend(data)
            logger.debug(f"Received {len(data)} bytes from {client_key}, buffer size: {len(connection_buffer)}")

            # Drain the buffer; stops as soon as no message was consumed.
            while self._process_message(client_socket, client_key, connection_buffer):
                pass

    def _process_message(self, client_socket: socket.socket, client_key: str, buffer: bytearray) -> bool:
        """
        Consume the buffered data as a single message.

        This is the hook to customise when adding a real protocol header.

        Returns:
            True if a message was handled successfully, False if the buffer
            was empty or handling failed. The buffer is cleared after a
            handling attempt unless that attempt raised.
        """
        try:
            if not buffer:
                return False
            success = self._handle_message(client_socket, client_key, buffer)
            buffer.clear()
            return success
        except Exception as e:
            logger.error(f"Message processing error: {e}")
            return False

    def _handle_message(self, client_socket: socket.socket, client_key: str, data_recv: bytes) -> bool:
        """Decode one message as UTF-8, log it and echo the raw bytes back.

        Returns:
            True if the data decoded and was sent back, False otherwise.
        """
        try:
            data_str = data_recv.decode('utf-8')
            logger.debug(f"Valid JSON received: {data_str}")

            client_socket.sendall(data_recv)
            return True

        except UnicodeDecodeError:
            logger.error(f"Invalid UTF-8 data from {client_key}")
        except json.JSONDecodeError:
            # NOTE(review): unreachable today because nothing here parses
            # JSON; kept for when json.loads() is added to this hook.
            logger.error(f"Invalid JSON from {client_key}")

        return False

    def _monitor_connections(self) -> None:
        """Background loop: drop idle clients and log connection/pool status."""
        last_report = time.time()
        while self.is_running:
            try:
                current_time = time.time()
                idle_limit = self.timeout * 1.5

                # Pick out idle connections under the lock, close them after.
                expired = []
                with self.connection_lock:
                    for client_key, client_info in list(self.connections.items()):
                        if not client_info['active']:
                            continue

                        idle_time = current_time - client_info['timestamp']
                        if idle_time > idle_limit:
                            client_info['active'] = False
                            del self.connections[client_key]
                            expired.append((client_key, client_info, idle_time))
                    active_count = len(self.connections)

                for client_key, client_info, idle_time in expired:
                    try:
                        self._close_socket(client_info['socket'])
                        logger.warning(f"Closed idle connection: {client_key} (idle for {idle_time:.1f}s)")
                    except Exception as e:
                        logger.error(f"Error closing idle connection {client_key}: {e}")

                # Report the connection count at most once a minute.
                if active_count > 0 and current_time - last_report >= 60:
                    logger.info(f"Active connections: {active_count}")
                    last_report = current_time

                # Thread pool status.
                logger.info(f"ThreadPool status: {self.thread_pool._work_queue.qsize()} queued, "
                            f"{len(self.thread_pool._threads)} active threads")

                time.sleep(5)  # check every 5 seconds
            except Exception as e:
                logger.error(f"Error in connection monitor: {e}")
                time.sleep(1)

# Usage example: run an echo server until interrupted with Ctrl-C.
if __name__ == "__main__":
    # '0.0.0.0' listens on every available interface; 12123 is the server port.
    HOST, PORT = '0.0.0.0', 12123

    # Build the server, then block in start() until it is stopped.
    server = TCPServer(host=HOST, port=PORT)
    try:
        server.start()
    except KeyboardInterrupt:
        logger.info("Server shutting down due to keyboard interrupt")
        server.stop()


posted @ 2025-06-12 18:00  BrianSun  阅读(10)  评论(0)    收藏  举报