[Linux C/C++ Development] Chapter 18: Network Programming Fundamentals - Socket Programming in Practice


Learning Objectives

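This week we build a solid understanding of the core concepts of network programming and the low-level principles of socket programming, and we implement high-performance network communication programs. Worked examples are used to analyze how network protocols behave and to present best practices for network programming in modern C++.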

Part 1: Core Concepts of the Network Protocol Stack

1.1 The Layered Model of the TCP/IP Protocol Stack

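A network protocol stack works much like a postal system. Each layer has its own responsibility and provides a service to the layer above it, and together the layers deliver data reliably from sender to receiver. The code below models three ideas from this layered picture: an abstract protocol layer, TCP congestion control, and packet queueing at a network node: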

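#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <functional>
#include <limits>
#include <map>
#include <memory>
#include <random>
#include <utility>
#include <vector>

// Core concepts of the network protocol stack
namespace NetworkConcepts {

// Abstract base class for one protocol layer
template<typename Packet>
class ProtocolLayer {
public:
    virtual ~ProtocolLayer() = default;
    // Encapsulation: prepend this layer's header, like putting a letter in an envelope
    virtual Packet encapsulate(const Packet& data) = 0;
    // Decapsulation: strip this layer's header and extract the payload, like opening the envelope
    virtual Packet decapsulate(const Packet& data) = 0;
    // Error detection: verify data integrity, like checking whether a letter arrived damaged
    virtual bool detectError(const Packet& data) = 0;
    // Throughput: data delivered per unit of time, like a post office's daily volume
    virtual double throughput(const std::chrono::steady_clock::time_point& t) = 0;
    // Latency: time from sending to receiving, like a letter's delivery time
    virtual double latency(const Packet& data) = 0;
};

// Core concepts of TCP congestion control (simplified Reno model, windows in MSS units)
class CongestionControlConcept {
private:
    double cwnd_;           // congestion window, in segments (MSS)
    double ssthresh_;       // slow-start threshold, in segments
    double rtt_;            // round-trip time, in seconds
    double bandwidth_;      // estimated bottleneck bandwidth, in bytes/second
    double mss_;            // maximum segment size, in bytes
    // Loss-probability model: P_loss = f(network_conditions)
    std::function<double()> loss_probability_;

public:
    // The default loss model (a constant 1%) is only an illustrative placeholder
    explicit CongestionControlConcept(double initial_cwnd = 1.0, double initial_ssthresh = 64.0,
                                      std::function<double()> loss_probability = [] { return 0.01; })
        : cwnd_(initial_cwnd), ssthresh_(initial_ssthresh), rtt_(0.1), bandwidth_(1e6),
          mss_(1460.0), loss_probability_(std::move(loss_probability)) {}

    // Slow start: cwnd(t+1) = cwnd(t) + 1 MSS per ACK (the window roughly doubles every RTT)
    void slowStart() {
        cwnd_ += 1.0;
    }
    // Congestion avoidance: cwnd(t+1) = cwnd(t) + MSS/cwnd(t) per ACK (about +1 MSS per RTT)
    void congestionAvoidance() {
        cwnd_ += 1.0 / cwnd_;
    }
    // Fast retransmit (3 duplicate ACKs): halve the threshold, inflate cwnd by the 3 dup ACKs
    void fastRetransmit() {
        ssthresh_ = std::max(cwnd_ / 2.0, 2.0);
        cwnd_ = ssthresh_ + 3.0;
    }
    // Leaving fast recovery (an ACK for new data arrives): deflate cwnd back to ssthresh
    void fastRecovery() {
        cwnd_ = ssthresh_;
    }
    // Retransmission timeout: halve the threshold and restart from slow start
    void timeout() {
        ssthresh_ = std::max(cwnd_ / 2.0, 2.0);
        cwnd_ = 1.0;
    }
    // Throughput estimate, in bytes/second
    double predictThroughput() const {
        // Bandwidth-delay product (bytes) = bandwidth * RTT
        double bdp = bandwidth_ * rtt_;
        // Throughput ~= min(cwnd in bytes, BDP) / RTT
        return std::min(cwnd_ * mss_, bdp) / rtt_;
    }
    // Target window in segments: the smaller of the BDP and the loss-limited window
    double optimalWindow() const {
        double bdp_segments = bandwidth_ * rtt_ / mss_;
        double p = loss_probability_();
        if (p <= 0.0) return bdp_segments;
        // Mathis et al. approximation for Reno: average window ~= sqrt(3 / (2p)) segments
        return std::min(bdp_segments, std::sqrt(1.5 / p));
    }
    double getCwnd() const { return cwnd_; }
    double getSsthresh() const { return ssthresh_; }
};

// Network queue model: packets waiting at a router or NIC (M/M/1)
class NetworkQueueModel {
private:
    double arrival_rate_;   // packet arrival rate lambda (packets/second)
    double service_rate_;   // service rate mu (packets/second)
public:
    NetworkQueueModel(double arrival_rate, double service_rate)
        : arrival_rate_(arrival_rate), service_rate_(service_rate) {}
    // Utilization rho = lambda / mu
    double utilization() const {
        return arrival_rate_ / service_rate_;
    }
    // Mean number of packets waiting (M/M/1): Lq = lambda^2 / (mu * (mu - lambda))
    double avgQueueLength() const {
        if (utilization() >= 1.0) return std::numeric_limits<double>::infinity();
        return (arrival_rate_ * arrival_rate_) / (service_rate_ * (service_rate_ - arrival_rate_));
    }
    // Mean waiting time, by Little's law: Wq = Lq / lambda = lambda / (mu * (mu - lambda))
    double avgWaitTime() const {
        if (utilization() >= 1.0) return std::numeric_limits<double>::infinity();
        return arrival_rate_ / (service_rate_ * (service_rate_ - arrival_rate_));
    }
    // Loss probability when the system holds at most K packets (M/M/1/K)
    double packetLossProbability(int max_queue_size) const {
        double rho = utilization();
        int k = max_queue_size;
        // rho == 1: all K+1 states are equally likely
        if (std::abs(rho - 1.0) < 1e-12) return 1.0 / (k + 1);
        // P_K = (1 - rho) * rho^K / (1 - rho^(K+1)); valid for rho > 1 too, since the buffer is finite
        return (1.0 - rho) * std::pow(rho, k) / (1.0 - std::pow(rho, k + 1));
    }
};

}  // namespace NetworkConcepts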
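To see what the slow-start and congestion-avoidance rules above do over time, the following sketch traces the window one round trip at a time. It assumes no losses. It also assumes that every full segment in the window is ACKed exactly once per RTT, and it keeps the window in MSS units as the model above does. `simulateCwnd` is a hypothetical helper written for this illustration, not part of any library.

```cpp
#include <cassert>
#include <vector>

// Trace Reno-style cwnd growth per RTT (assumption: no losses, one ACK per full
// segment in flight, windows measured in MSS). Element 0 is the initial window.
std::vector<double> simulateCwnd(double cwnd, double ssthresh, int rtts) {
    assert(cwnd >= 1.0 && rtts >= 0);
    std::vector<double> trace{cwnd};
    for (int r = 0; r < rtts; ++r) {
        int acks = static_cast<int>(cwnd);  // ACKs arriving during this RTT
        for (int i = 0; i < acks; ++i) {
            if (cwnd < ssthresh) {
                cwnd += 1.0;           // slow start: +1 MSS per ACK
            } else {
                cwnd += 1.0 / cwnd;    // congestion avoidance: +1/cwnd MSS per ACK
            }
        }
        trace.push_back(cwnd);
    }
    return trace;
}
```

With `simulateCwnd(1.0, 16.0, 6)` the window goes 1, 2, 4, 8, 16 and then about 16.97 and 17.89. The window doubles every RTT below `ssthresh` and then grows by roughly one segment per RTT.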

1.2 Core Concepts of the Socket Abstraction

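A socket is an endpoint of network communication, much like a telephone number in the phone system. A connection is uniquely identified by the protocol together with the IP address and port number of each end (the five-tuple):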

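#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Core socket concepts
namespace SocketConcepts {

// Socket five-tuple: (protocol, source IP, source port, destination IP, destination port)
struct SocketTuple {
    uint8_t protocol;     // protocol type, e.g. IPPROTO_TCP
    uint32_t src_ip;      // source IPv4 address
    uint16_t src_port;    // source port
    uint32_t dst_ip;      // destination IPv4 address
    uint16_t dst_port;    // destination port

    // Hash for fast lookup: pack all five fields into two 64-bit words and combine them,
    // so that no field (in particular dst_port) is left out of the hash
    size_t hash() const {
        uint64_t hi = (static_cast<uint64_t>(protocol) << 48) |
                      (static_cast<uint64_t>(src_port) << 32) |
                      static_cast<uint64_t>(src_ip);
        uint64_t lo = (static_cast<uint64_t>(dst_port) << 32) |
                      static_cast<uint64_t>(dst_ip);
        size_t h1 = std::hash<uint64_t>{}(hi);
        size_t h2 = std::hash<uint64_t>{}(lo);
        // Boost-style hash_combine
        return h1 ^ (h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2));
    }

    bool operator==(const SocketTuple& other) const {
        return protocol == other.protocol &&
               src_ip == other.src_ip &&
               src_port == other.src_port &&
               dst_ip == other.dst_ip &&
               dst_port == other.dst_port;
    }
};

// TCP connection states
enum class SocketState {
    CLOSED,         // no connection
    LISTEN,         // waiting for an incoming SYN
    SYN_SENT,       // SYN sent, waiting for SYN+ACK
    SYN_RECEIVED,   // SYN received, SYN+ACK sent
    ESTABLISHED,    // connection open, data can flow
    FIN_WAIT_1,     // we sent FIN, waiting for its ACK
    FIN_WAIT_2,     // our FIN was ACKed, waiting for the peer's FIN
    CLOSE_WAIT,     // peer sent FIN, waiting for the local application to close
    CLOSING,        // both sides sent FIN simultaneously
    LAST_ACK,       // we sent our FIN after the peer's, waiting for the final ACK
    TIME_WAIT       // waiting 2*MSL so stray segments die out
};

// TCP state-transition table (simplified from the RFC 793 diagram; events are plain strings)
class TCPStateMachine {
private:
    SocketState current_state_;
    std::map<std::pair<SocketState, std::string>, SocketState> transitions_;

public:
    TCPStateMachine() : current_state_(SocketState::CLOSED) {
        initializeTransitions();
    }

    void initializeTransitions() {
        // Connection establishment
        transitions_[{SocketState::CLOSED, "PASSIVE_OPEN"}] = SocketState::LISTEN;
        transitions_[{SocketState::CLOSED, "ACTIVE_OPEN"}] = SocketState::SYN_SENT;
        transitions_[{SocketState::LISTEN, "SEND_SYN"}] = SocketState::SYN_SENT;
        transitions_[{SocketState::LISTEN, "RECV_SYN"}] = SocketState::SYN_RECEIVED;
        transitions_[{SocketState::SYN_SENT, "RECV_SYN"}] = SocketState::SYN_RECEIVED;      // simultaneous open
        transitions_[{SocketState::SYN_SENT, "RECV_SYN_ACK"}] = SocketState::ESTABLISHED;
        transitions_[{SocketState::SYN_RECEIVED, "RECV_ACK"}] = SocketState::ESTABLISHED;
        // Connection teardown
        transitions_[{SocketState::ESTABLISHED, "SEND_FIN"}] = SocketState::FIN_WAIT_1;
        transitions_[{SocketState::ESTABLISHED, "RECV_FIN"}] = SocketState::CLOSE_WAIT;
        transitions_[{SocketState::FIN_WAIT_1, "RECV_ACK"}] = SocketState::FIN_WAIT_2;      // our FIN was ACKed
        transitions_[{SocketState::FIN_WAIT_1, "RECV_FIN"}] = SocketState::CLOSING;        // simultaneous close
        transitions_[{SocketState::FIN_WAIT_1, "RECV_FIN_ACK"}] = SocketState::TIME_WAIT;  // FIN and ACK together
        transitions_[{SocketState::FIN_WAIT_2, "RECV_FIN"}] = SocketState::TIME_WAIT;
        transitions_[{SocketState::CLOSE_WAIT, "SEND_FIN"}] = SocketState::LAST_ACK;
        transitions_[{SocketState::CLOSING, "RECV_ACK"}] = SocketState::TIME_WAIT;
        transitions_[{SocketState::LAST_ACK, "RECV_ACK"}] = SocketState::CLOSED;
        transitions_[{SocketState::TIME_WAIT, "TIMEOUT_2MSL"}] = SocketState::CLOSED;
    }

    // Apply an event; returns false and leaves the state unchanged if the event is invalid here
    bool transition(const std::string& event) {
        auto key = std::make_pair(current_state_, event);
        auto it = transitions_.find(key);
        if (it != transitions_.end()) {
            current_state_ = it->second;
            return true;
        }
        return false;
    }

    SocketState getCurrentState() const { return current_state_; }

    // Check whether the given event would trigger a transition from the current state
    bool canTransition(const std::string& event) const {
        return transitions_.count(std::make_pair(current_state_, event)) > 0;
    }
};

// Socket buffer model: a fixed-capacity ring buffer
template<typename T>
class SocketBuffer {
private:
    std::vector<T> buffer_;
    size_t capacity_;
    size_t read_pos_;
    size_t write_pos_;
    size_t count_;   // number of stored elements; needed to tell "full" apart from "empty"

    // Utilization: fraction of the capacity currently in use
    double utilization() const {
        return static_cast<double>(count_) / capacity_;
    }

public:
    explicit SocketBuffer(size_t capacity)
        : buffer_(capacity), capacity_(capacity), read_pos_(0), write_pos_(0), count_(0) {}

    // Write len elements; fails without writing anything if there is not enough free space
    bool write(const T* data, size_t len) {
        if (available() < len) return false;
        for (size_t i = 0; i < len; ++i) {
            buffer_[write_pos_] = data[i];
            write_pos_ = (write_pos_ + 1) % capacity_;
        }
        count_ += len;
        return true;
    }

    // Read len elements; fails without reading anything if fewer are stored
    bool read(T* data, size_t len) {
        if (size() < len) return false;
        for (size_t i = 0; i < len; ++i) {
            data[i] = buffer_[read_pos_];
            read_pos_ = (read_pos_ + 1) % capacity_;
        }
        count_ -= len;
        return true;
    }

    size_t size() const { return count_; }

    size_t available() const { return capacity_ - count_; }

    // Illustrative toy metric only: util * (1 - util) peaks at 50% utilization.
    // Real buffer performance depends on the workload; this is not an empirical law.
    double throughput() const {
        double util = utilization();
        return util * (1.0 - util);  // inverted-U shape
    }
};

}  // namespace SocketConcepts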
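Using the five-tuple as a lookup key shows why every field matters. The sketch below builds a small demultiplexing table with `std::unordered_map` that maps an incoming segment's tuple to the connection that owns it. `FiveTuple`, `FiveTupleHash`, `ConnectionTable`, and `demux` are hypothetical names chosen for this illustration. A real kernel uses its own connection hash tables, not this code.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>

// Hypothetical IPv4 five-tuple used as a hash-map key
struct FiveTuple {
    uint8_t protocol;
    uint32_t src_ip;
    uint16_t src_port;
    uint32_t dst_ip;
    uint16_t dst_port;
    bool operator==(const FiveTuple& o) const {
        return protocol == o.protocol && src_ip == o.src_ip && src_port == o.src_port &&
               dst_ip == o.dst_ip && dst_port == o.dst_port;
    }
};

// Hasher that mixes every field, so tuples differing only in one port hash differently
struct FiveTupleHash {
    size_t operator()(const FiveTuple& t) const {
        size_t h = std::hash<uint32_t>{}(t.src_ip);
        auto mix = [&h](size_t v) { h ^= v + 0x9e3779b9 + (h << 6) + (h >> 2); };
        mix(std::hash<uint32_t>{}(t.dst_ip));
        mix(std::hash<uint32_t>{}((static_cast<uint32_t>(t.src_port) << 16) | t.dst_port));
        mix(std::hash<uint8_t>{}(t.protocol));
        return h;
    }
};

// Connection table: five-tuple -> name of the owning connection
using ConnectionTable = std::unordered_map<FiveTuple, std::string, FiveTupleHash>;

// Find the connection an incoming segment belongs to; "" means no match
std::string demux(const ConnectionTable& table, const FiveTuple& key) {
    auto it = table.find(key);
    return it == table.end() ? std::string() : it->second;
}
```

Two connections from the same client host to the same server port differ only in their source port, so they map to different entries. A datagram with identical addresses but protocol 17 (UDP) instead of 6 (TCP) finds no entry.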

Part 2: Socket Programming in Modern C++

2.1 An RAII Socket Wrapper

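Modern C++ features let us build a type-safe, exception-safe socket wrapper. The wrapper owns its file descriptor and closes it in the destructor. It is movable but not copyable, and it reports failures as exceptions that carry errno: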

#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <fcntl.h>
#include <poll.h>
#include <cerrno>
#include <cstdint>
#include <cstring>
#include <memory>
#include <stdexcept>
#include <string>
#include <system_error>

namespace AdvancedSocket {

// Custom exception type that keeps the errno value
class SocketException : public std::runtime_error {
private:
    int error_code_;
public:
    SocketException(const std::string& message, int error_code)
        : std::runtime_error(message + ": " + std::strerror(error_code))
        , error_code_(error_code) {}
    int errorCode() const { return error_code_; }
};

// RAII socket wrapper: owns one file descriptor and closes it automatically
class Socket {
private:
    int fd_;
    bool non_blocking_;

    // Tag type used to wrap an existing descriptor, e.g. one returned by accept()
    struct AdoptFd {};
    Socket(AdoptFd, int fd) noexcept : fd_(fd), non_blocking_(false) {}

public:
    // Create a socket. SO_REUSEADDR lets a restarted server rebind a port that
    // still has connections in TIME_WAIT.
    explicit Socket(int domain = AF_INET, int type = SOCK_STREAM, int protocol = 0)
        : fd_(-1), non_blocking_(false) {
        fd_ = ::socket(domain, type, protocol);
        if (fd_ == -1) {
            throw SocketException("Failed to create socket", errno);
        }
        int reuse = 1;
        if (::setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) == -1) {
            int err = errno;  // save errno first: close() may overwrite it
            close();
            throw SocketException("Failed to set SO_REUSEADDR", err);
        }
    }

    // Not copyable: two owners would close the same descriptor twice
    Socket(const Socket&) = delete;
    Socket& operator=(const Socket&) = delete;

    // Move constructor: take over the other socket's descriptor
    Socket(Socket&& other) noexcept
        : fd_(other.fd_), non_blocking_(other.non_blocking_) {
        other.fd_ = -1;
    }

    // Move assignment: release our descriptor, then take over the other one
    Socket& operator=(Socket&& other) noexcept {
        if (this != &other) {
            close();
            fd_ = other.fd_;
            non_blocking_ = other.non_blocking_;
            other.fd_ = -1;
        }
        return *this;
    }

    ~Socket() {
        close();
    }

    // Underlying file descriptor (-1 once closed)
    int getFd() const { return fd_; }

    // Bind to a local address
    void bind(const struct sockaddr* addr, socklen_t addrlen) {
        if (::bind(fd_, addr, addrlen) == -1) {
            throw SocketException("Failed to bind socket", errno);
        }
    }

    // Start listening for incoming connections
    void listen(int backlog = SOMAXCONN) {
        if (::listen(fd_, backlog) == -1) {
            throw SocketException("Failed to listen on socket", errno);
        }
    }

    // Accept one pending connection. If the listening socket is non-blocking and no
    // connection is pending, an invalid Socket (isValid() == false) is returned.
    Socket accept(struct sockaddr* addr = nullptr, socklen_t* addrlen = nullptr) {
        int client_fd = ::accept(fd_, addr, addrlen);
        if (client_fd == -1) {
            if (errno == EAGAIN || errno == EWOULDBLOCK) {
                return Socket(AdoptFd{}, -1);
            }
            throw SocketException("Failed to accept connection", errno);
        }
        // On Linux the accepted socket does not inherit O_NONBLOCK from the listening
        // socket, so it starts in blocking mode (non_blocking_ == false)
        return Socket(AdoptFd{}, client_fd);
    }

    // Connect to a server. On a non-blocking socket EINPROGRESS means the handshake is
    // still under way: wait until the socket is writable, then check getError().
    void connect(const struct sockaddr* addr, socklen_t addrlen) {
        int result = ::connect(fd_, addr, addrlen);
        if (result == -1 && errno != EINPROGRESS) {
            throw SocketException("Failed to connect", errno);
        }
    }

    // Send data. Returns the number of bytes sent, or -1 if a non-blocking socket would
    // block. MSG_NOSIGNAL (Linux) turns a write to a closed peer into an EPIPE error
    // instead of a process-killing SIGPIPE.
    ssize_t send(const void* data, size_t len, int flags = 0) {
        ssize_t result = ::send(fd_, data, len, flags | MSG_NOSIGNAL);
        if (result == -1 && errno != EAGAIN && errno != EWOULDBLOCK) {
            throw SocketException("Failed to send data", errno);
        }
        return result;
    }

    // Receive data. Returns the number of bytes received, 0 if the peer closed the
    // connection, or -1 if a non-blocking socket has no data available.
    ssize_t recv(void* buffer, size_t len, int flags = 0) {
        ssize_t result = ::recv(fd_, buffer, len, flags);
        if (result == -1 && errno != EAGAIN && errno != EWOULDBLOCK) {
            throw SocketException("Failed to receive data", errno);
        }
        return result;
    }

    // Switch O_NONBLOCK on or off
    void setNonBlocking(bool non_blocking = true) {
        int flags = fcntl(fd_, F_GETFL, 0);
        if (flags == -1) {
            throw SocketException("Failed to get socket flags", errno);
        }
        if (non_blocking) {
            flags |= O_NONBLOCK;
        } else {
            flags &= ~O_NONBLOCK;
        }
        if (fcntl(fd_, F_SETFL, flags) == -1) {
            throw SocketException("Failed to set non-blocking mode", errno);
        }
        non_blocking_ = non_blocking;
    }

    // TCP_NODELAY: disable Nagle's algorithm so small writes are sent immediately
    void setTcpNoDelay(bool enable = true) {
        int value = enable ? 1 : 0;
        if (::setsockopt(fd_, IPPROTO_TCP, TCP_NODELAY, &value, sizeof(value)) == -1) {
            throw SocketException("Failed to set TCP_NODELAY", errno);
        }
    }

    // SO_KEEPALIVE: send keepalive probes on idle connections to detect dead peers
    void setKeepAlive(bool enable = true) {
        int value = enable ? 1 : 0;
        if (::setsockopt(fd_, SOL_SOCKET, SO_KEEPALIVE, &value, sizeof(value)) == -1) {
            throw SocketException("Failed to set SO_KEEPALIVE", errno);
        }
    }

    // Fetch (and clear) the pending socket error via SO_ERROR
    int getError() const {
        int error = 0;
        socklen_t len = sizeof(error);
        if (::getsockopt(fd_, SOL_SOCKET, SO_ERROR, &error, &len) == -1) {
            return errno;
        }
        return error;
    }

    // Close the descriptor (safe to call more than once)
    void close() {
        if (fd_ != -1) {
            ::close(fd_);
            fd_ = -1;
        }
    }

    // True while the socket owns an open descriptor
    bool isValid() const {
        return fd_ != -1;
    }
};

// IPv4 address wrapper around sockaddr_in
class InetAddress {
private:
    struct sockaddr_in addr_;
public:
    // Default: the wildcard address 0.0.0.0, port 0
    InetAddress() {
        std::memset(&addr_, 0, sizeof(addr_));
        addr_.sin_family = AF_INET;
        addr_.sin_addr.s_addr = htonl(INADDR_ANY);
        addr_.sin_port = htons(0);
    }

    // Dotted-decimal IP plus port; "" or "0.0.0.0" means the wildcard address
    InetAddress(const std::string& ip, uint16_t port) {
        std::memset(&addr_, 0, sizeof(addr_));
        addr_.sin_family = AF_INET;
        addr_.sin_port = htons(port);
        if (ip.empty() || ip == "0.0.0.0") {
            addr_.sin_addr.s_addr = htonl(INADDR_ANY);
        } else {
            if (inet_pton(AF_INET, ip.c_str(), &addr_.sin_addr) != 1) {
                throw std::invalid_argument("Invalid IP address: " + ip);
            }
        }
    }

    // Wildcard address with the given port
    InetAddress(uint16_t port) : InetAddress("0.0.0.0", port) {}

    // Access as a generic sockaddr for the socket API
    const struct sockaddr* getSockAddr() const {
        return reinterpret_cast<const struct sockaddr*>(&addr_);
    }
    struct sockaddr* getSockAddr() {
        return reinterpret_cast<struct sockaddr*>(&addr_);
    }

    // Address length
    socklen_t getSockLen() const {
        return sizeof(addr_);
    }

    // IP address as a dotted-decimal string
    std::string getIp() const {
        char buf[INET_ADDRSTRLEN];
        const char* result = inet_ntop(AF_INET, &addr_.sin_addr, buf, sizeof(buf));
        if (result == nullptr) {
            throw std::runtime_error("Failed to convert IP address");
        }
        return std::string(result);
    }

    // Port in host byte order
    uint16_t getPort() const {
        return ntohs(addr_.sin_port);
    }

    // "IP:port" string
    std::string toString() const {
        return getIp() + ":" + std::to_string(getPort());
    }
};

}  // namespace AdvancedSocket
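The wrapper above is a thin layer over the POSIX calls `socket`, `bind`, `listen`, `accept`, `connect`, `send`, and `recv`. The sketch below drives those same calls directly in a single-threaded loopback round trip, so it stands on its own. It binds the listener to port 0, which lets the kernel pick a free port, and reads the chosen port back with `getsockname`. A blocking `connect()` to a listening socket returns once the kernel finishes the handshake, before `accept()` is called, so one thread is enough here. `FdGuard`, `check`, and `loopbackRoundTrip` are hypothetical helpers for this illustration. `FdGuard` is a stripped-down stand-in for the `Socket` class.

```cpp
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>
#include <cassert>
#include <cerrno>
#include <cstddef>
#include <string>
#include <system_error>

// Minimal RAII owner of a descriptor; throws if the call that produced fd failed
class FdGuard {
public:
    FdGuard(int fd, const char* what) : fd_(fd) {
        if (fd_ == -1) throw std::system_error(errno, std::system_category(), what);
    }
    ~FdGuard() { if (fd_ != -1) ::close(fd_); }
    FdGuard(const FdGuard&) = delete;
    FdGuard& operator=(const FdGuard&) = delete;
    int get() const { return fd_; }
private:
    int fd_;
};

// Throw std::system_error if a POSIX call returned -1
static void check(ssize_t rc, const char* what) {
    if (rc == -1) throw std::system_error(errno, std::system_category(), what);
}

// Send msg from a client socket to a server socket over 127.0.0.1 and return what arrived
std::string loopbackRoundTrip(const std::string& msg) {
    FdGuard listener(::socket(AF_INET, SOCK_STREAM, 0), "socket");
    sockaddr_in addr{};
    addr.sin_family = AF_INET;
    addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
    addr.sin_port = htons(0);  // port 0: let the kernel choose
    check(::bind(listener.get(), reinterpret_cast<sockaddr*>(&addr), sizeof(addr)), "bind");
    check(::listen(listener.get(), 1), "listen");

    // Read back the port the kernel assigned
    socklen_t len = sizeof(addr);
    check(::getsockname(listener.get(), reinterpret_cast<sockaddr*>(&addr), &len), "getsockname");

    FdGuard client(::socket(AF_INET, SOCK_STREAM, 0), "socket");
    check(::connect(client.get(), reinterpret_cast<sockaddr*>(&addr), sizeof(addr)), "connect");
    FdGuard server(::accept(listener.get(), nullptr, nullptr), "accept");

    // send() may write only part of the buffer, so loop until everything is sent
    size_t off = 0;
    while (off < msg.size()) {
        ssize_t n = ::send(client.get(), msg.data() + off, msg.size() - off, 0);
        check(n, "send");
        off += static_cast<size_t>(n);
    }

    // TCP is a byte stream: keep reading until the whole message has arrived
    std::string received;
    char buf[256];
    while (received.size() < msg.size()) {
        ssize_t n = ::recv(server.get(), buf, sizeof(buf), 0);
        check(n, "recv");
        if (n == 0) break;  // peer closed early
        received.append(buf, static_cast<size_t>(n));
    }
    return received;
}
```

`loopbackRoundTrip("ping")` returns `"ping"`. The short messages used here fit in the kernel's socket buffers, which is why the blocking `send()` does not stall before the server reads. In a real server, `accept()` and `recv()` would be driven by an event loop such as the epoll server in the next section.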

2.2 A High-Performance TCP Server

Next we implement a high-performance TCP server based on epoll that can handle a large number of concurrent connections:

#include <sys/epoll.h>
#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

namespace AdvancedSocket {

// Connection state
enum class ConnectionState {
    CONNECTING,
    CONNECTED,
    DISCONNECTING,
    DISCONNECTED
};

// One TCP connection; always owned through std::shared_ptr
class TcpConnection : public std::enable_shared_from_this<TcpConnection> {
private:
    Socket socket_;
    InetAddress local_addr_;
    InetAddress peer_addr_;
    std::atomic<ConnectionState> state_;   // atomic because it may be read from several threads
    std::atomic<bool> reading_;
    std::atomic<bool> writing_;
    std::queue<std::vector<uint8_t>> write_queue_;
    std::mutex write_mutex_;

    // Statistics
    std::atomic<uint64_t> bytes_sent_;
    std::atomic<uint64_t> bytes_received_;
    std::atomic<uint64_t> messages_sent_;
    std::atomic<uint64_t> messages_received_;

public:
    using MessageCallback = std::function<void(const std::shared_ptr<TcpConnection>&,
                                               const std::vector<uint8_t>&)>;
    using CloseCallback = std::function<void(const std::shared_ptr<TcpConnection>&)>;
    using ErrorCallback = std::function<void(const std::shared_ptr<TcpConnection>&, int)>;

private:
    MessageCallback message_callback_;
    CloseCallback close_callback_;
    ErrorCallback error_callback_;

public:
    TcpConnection(Socket&& socket, const InetAddress& local_addr, const InetAddress& peer_addr)
        : socket_(std::move(socket))
        , local_addr_(local_addr)
        , peer_addr_(peer_addr)
        , state_(ConnectionState::CONNECTED)
        , reading_(false)
        , writing_(false)
        , bytes_sent_(0)
        , bytes_received_(0)
        , messages_sent_(0)
        , messages_received_(0) {
        socket_.setNonBlocking(true);   // the epoll loop expects non-blocking I/O
        socket_.setTcpNoDelay(true);    // disable Nagle's algorithm for lower latency
    }

    ~TcpConnection() {
        // Do not call shutdown() here: it calls shared_from_this(), which throws
        // std::bad_weak_ptr once the last shared_ptr has been released.
        // Socket's own destructor closes the descriptor.
        state_ = ConnectionState::DISCONNECTED;
    }

    // File descriptor
    int getFd() const {
        return socket_.getFd();
    }

    // Local address
    const InetAddress& getLocalAddress() const {
        return local_addr_;
    }

    // Peer address
    const InetAddress& getPeerAddress() const {
        return peer_addr_;
    }

    // Current state
    ConnectionState getState() const {
        return state_;
    }

    // Callback setters
    void setMessageCallback(MessageCallback callback) {
        message_callback_ = std::move(callback);
    }
    void setCloseCallback(CloseCallback callback) {
        close_callback_ = std::move(callback);
    }
    void setErrorCallback(ErrorCallback callback) {
        error_callback_ = std::move(callback);
    }

    // Queue data for sending; whoever finds the writer idle starts it
    void send(const void* data, size_t len) {
        if (state_ != ConnectionState::CONNECTED) {
            return;
        }
        std::vector<uint8_t> message(static_cast<const uint8_t*>(data),
                                     static_cast<const uint8_t*>(data) + len);
        {
            std::lock_guard<std::mutex> lock(write_mutex_);
            write_queue_.push(std::move(message));
        }
        // Raise the flag only after queuing: if it were raised first, an active writer
        // could drain the queue and clear the flag, leaving this message unsent
        if (!writing_.exchange(true)) {
            startWriting();
        }
    }

    // Send a string
    void send(const std::string& message) {
        send(message.data(), message.size());
    }

    // Start reading (no-op if already reading)
    void startReading() {
        if (reading_.exchange(true)) {
            return;
        }
        readData();
    }

    // Stop reading
    void stopReading() {
        reading_ = false;
    }

    // Close the connection and notify the owner.
    // Must be called while the object is still owned by a shared_ptr.
    void shutdown() {
        if (state_ == ConnectionState::DISCONNECTED) {
            return;
        }
        state_ = ConnectionState::DISCONNECTING;
        socket_.close();
        state_ = ConnectionState::DISCONNECTED;
        if (close_callback_) {
            close_callback_(shared_from_this());
        }
    }

    // Statistics
    uint64_t getBytesSent() const { return bytes_sent_; }
    uint64_t getBytesReceived() const { return bytes_received_; }
    uint64_t getMessagesSent() const { return messages_sent_; }
    uint64_t getMessagesReceived() const { return messages_received_; }

private:
    void readData() {
        if (!reading_) {
            return;
        }
        std::vector<uint8_t> buffer(65536);
        // Stop when reading is disabled or a callback has already closed the socket
        while (reading_ && socket_.isValid()) {
            ssize_t n = socket_.recv(buffer.data(), buffer.size());
            if (n > 0) {
                bytes_received_ += n;
                // Counts recv() calls, not application messages: TCP is a byte
                // stream and does not preserve message boundaries
                messages_received_++;
                if (message_callback_) {
                    std::vector<uint8_t> message(buffer.begin(), buffer.begin() + n);
                    message_callback_(shared_from_this(), message);
                }
            } else if (n == 0) {
                // The peer closed the connection
                shutdown();
                break;
            } else {
                // EAGAIN/EWOULDBLOCK: no more data for now
                break;
            }
        }
    }

    void startWriting() {
        if (!writing_) {
            return;
        }
        std::vector<uint8_t> message;
        {
            std::lock_guard<std::mutex> lock(write_mutex_);
            if (write_queue_.empty()) {
                writing_ = false;
                return;
            }
            message = std::move(write_queue_.front());
            write_queue_.pop();
        }
                                                          writeData(message);
                                                          }
                                                          void writeData(const std::vector<uint8_t>& message) {
                                                            ssize_t n = socket_.send(message.data(), message.size());
                                                            if (n > 0) {
                                                            bytes_sent_ += n;
                                                            messages_sent_++;
                                                            if (static_cast<size_t>(n) < message.size()) {
                                                              // 部分写入,需要继续
                                                              std::vector<uint8_t> remaining(message.begin() + n, message.end());
                                                                std::lock_guard<std::mutex> lock(write_mutex_);
                                                                  write_queue_.push(remaining);
                                                                  }
                                                                  } else if (n == -1 && errno == EAGAIN) {
                                                                  // 暂时无法写入,重新加入队列
                                                                  std::lock_guard<std::mutex> lock(write_mutex_);
                                                                    write_queue_.push(message);
                                                                    }
                                                                    // 继续处理队列
                                                                    startWriting();
                                                                    }
                                                                    };
                                                                    // Epoll事件循环
                                                                    class EventLoop {
                                                                    private:
                                                                    int epoll_fd_;
                                                                    std::atomic<bool> running_;
                                                                      static constexpr int MAX_EVENTS = 1024;
                                                                      public:
                                                                      EventLoop() : running_(false) {
                                                                      epoll_fd_ = epoll_create1(EPOLL_CLOEXEC);
                                                                      if (epoll_fd_ == -1) {
                                                                      throw SocketException("Failed to create epoll", errno);
                                                                      }
                                                                      }
                                                                      ~EventLoop() {
                                                                      stop();
                                                                      if (epoll_fd_ != -1) {
                                                                      close(epoll_fd_);
                                                                      }
                                                                      }
                                                                      // 添加文件描述符
                                                                      void addChannel(int fd, uint32_t events, void* data) {
                                                                      struct epoll_event ev;
                                                                      std::memset(&ev, 0, sizeof(ev));
                                                                      ev.events = events;
                                                                      ev.data.ptr = data;
                                                                      if (epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, fd, &ev) == -1) {
                                                                      throw SocketException("Failed to add channel to epoll", errno);
                                                                      }
                                                                      }
                                                                      // 修改文件描述符事件
                                                                      void modifyChannel(int fd, uint32_t events, void* data) {
                                                                      struct epoll_event ev;
                                                                      std::memset(&ev, 0, sizeof(ev));
                                                                      ev.events = events;
                                                                      ev.data.ptr = data;
                                                                      if (epoll_ctl(epoll_fd_, EPOLL_CTL_MOD, fd, &ev) == -1) {
                                                                      throw SocketException("Failed to modify channel in epoll", errno);
                                                                      }
                                                                      }
                                                                      // 删除文件描述符
                                                                      void removeChannel(int fd) {
                                                                      if (epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, nullptr) == -1) {
                                                                      throw SocketException("Failed to remove channel from epoll", errno);
                                                                      }
                                                                      }
                                                                      // 启动事件循环
                                                                      void start(std::function<void()> callback = nullptr) {
                                                                        running_ = true;
                                                                        struct epoll_event events[MAX_EVENTS];
                                                                        while (running_) {
                                                                        int num_events = epoll_wait(epoll_fd_, events, MAX_EVENTS, 1000);
                                                                        if (num_events == -1) {
                                                                        if (errno == EINTR) {
                                                                        continue;
                                                                        }
                                                                        throw SocketException("epoll_wait failed", errno);
                                                                        }
                                                                        for (int i = 0; i < num_events; ++i) {
                                                                        if (callback) {
                                                                        callback();
                                                                        }
                                                                        }
                                                                        }
                                                                        }
                                                                        // 停止事件循环
                                                                        void stop() {
                                                                        running_ = false;
                                                                        }
                                                                        };
                                                                        // TCP服务器
                                                                        class TcpServer {
                                                                        private:
                                                                        EventLoop loop_;
                                                                        Socket accept_socket_;
                                                                        InetAddress listen_addr_;
                                                                        std::unordered_map<int, std::shared_ptr<TcpConnection>> connections_;
                                                                          std::mutex connections_mutex_;
                                                                          TcpConnection::MessageCallback message_callback_;
                                                                          TcpConnection::CloseCallback close_callback_;
                                                                          TcpConnection::ErrorCallback error_callback_;
                                                                          std::atomic<bool> started_;
                                                                            std::thread server_thread_;
                                                                            public:
                                                                            TcpServer(const InetAddress& listen_addr)
                                                                            : listen_addr_(listen_addr)
                                                                            , started_(false) {
                                                                            accept_socket_.bind(listen_addr_.getSockAddr(), listen_addr_.getSockLen());
                                                                            accept_socket_.listen();
                                                                            accept_socket_.setNonBlocking(true);
                                                                            }
                                                                            ~TcpServer() {
                                                                            stop();
                                                                            }
                                                                            // 设置消息回调
                                                                            void setMessageCallback(TcpConnection::MessageCallback callback) {
                                                                            message_callback_ = std::move(callback);
                                                                            }
                                                                            // 设置关闭回调
                                                                            void setCloseCallback(TcpConnection::CloseCallback callback) {
                                                                            close_callback_ = std::move(callback);
                                                                            }
                                                                            // 设置错误回调
                                                                            void setErrorCallback(TcpConnection::ErrorCallback callback) {
                                                                            error_callback_ = std::move(callback);
                                                                            }
                                                                            // 启动服务器
                                                                            void start() {
                                                                            if (started_.exchange(true)) {
                                                                            return;
                                                                            }
                                                                            server_thread_ = std::thread([this]() {
                                                                            this->run();
                                                                            });
                                                                            }
                                                                            // 停止服务器
                                                                            void stop() {
                                                                            if (!started_.exchange(false)) {
                                                                            return;
                                                                            }
                                                                            loop_.stop();
                                                                            if (server_thread_.joinable()) {
                                                                            server_thread_.join();
                                                                            }
                                                                            // 关闭所有连接
                                                                            std::lock_guard<std::mutex> lock(connections_mutex_);
                                                                              for (auto& [fd, conn] : connections_) {
                                                                              conn->shutdown();
                                                                              }
                                                                              connections_.clear();
                                                                              }
                                                                              // 获取连接数
                                                                              size_t getConnectionCount() const {
                                                                              std::lock_guard<std::mutex> lock(connections_mutex_);
                                                                                return connections_.size();
                                                                                }
                                                                                private:
                                                                                void run() {
                                                                                // 添加监听socket到epoll
                                                                                loop_.addChannel(accept_socket_.getFd(), EPOLLIN | EPOLLET, this);
                                                                                // 启动事件循环
                                                                                loop_.start([this]() {
                                                                                handleAccept();
                                                                                });
                                                                                }
                                                                                void handleAccept() {
                                                                                while (true) {
                                                                                try {
                                                                                InetAddress peer_addr;
                                                                                struct sockaddr_in addr;
                                                                                socklen_t addrlen = sizeof(addr);
                                                                                Socket client_socket = accept_socket_.accept(
                                                                                reinterpret_cast<struct sockaddr*>(&addr), &addrlen);
                                                                                  InetAddress local_addr;
                                                                                  struct sockaddr_in local_addr_in;
                                                                                  socklen_t local_addrlen = sizeof(local_addr_in);
                                                                                  if (getsockname(client_socket.getFd(),
                                                                                  reinterpret_cast<struct sockaddr*>(&local_addr_in),
                                                                                    &local_addrlen) == 0) {
                                                                                    local_addr = InetAddress(inet_ntoa(local_addr_in.sin_addr),
                                                                                    ntohs(local_addr_in.sin_port));
                                                                                    }
                                                                                    peer_addr = InetAddress(inet_ntoa(addr.sin_addr), ntohs(addr.sin_port));
                                                                                    auto conn = std::make_shared<TcpConnection>(
                                                                                      std::move(client_socket), local_addr, peer_addr);
                                                                                      conn->setMessageCallback(message_callback_);
                                                                                      conn->setCloseCallback([this](const std::shared_ptr<TcpConnection>& connection) {
                                                                                        this->removeConnection(connection);
                                                                                        if (close_callback_) {
                                                                                        close_callback_(connection);
                                                                                        }
                                                                                        });
                                                                                        conn->setErrorCallback(error_callback_);
                                                                                        {
                                                                                        std::lock_guard<std::mutex> lock(connections_mutex_);
                                                                                          connections_[conn->getFd()] = conn;
                                                                                          }
                                                                                          // 添加到epoll
                                                                                          loop_.addChannel(conn->getFd(), EPOLLIN | EPOLLOUT | EPOLLET, conn.get());
                                                                                          // 开始读取
                                                                                          conn->startReading();
                                                                                          } catch (const SocketException& e) {
                                                                                          if (e.errorCode() == EAGAIN || e.errorCode() == EWOULDBLOCK) {
                                                                                          // 没有更多连接
                                                                                          break;
                                                                                          }
                                                                                          // 其他错误
                                                                                          if (error_callback_) {
                                                                                          error_callback_(nullptr, e.errorCode());
                                                                                          }
                                                                                          }
                                                                                          }
                                                                                          }
                                                                                          void removeConnection(const std::shared_ptr<TcpConnection>& conn) {
                                                                                            std::lock_guard<std::mutex> lock(connections_mutex_);
                                                                                              connections_.erase(conn->getFd());
                                                                                              loop_.removeChannel(conn->getFd());
                                                                                              }
                                                                                              };
                                                                                              }

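Part 3: Advanced Network Programming Techniques

3.1 Zero-Copy Transfer

These wrappers reduce copying between kernel and user space when sending data. sendfile() and splice() move bytes entirely inside the kernel. writev() sends several buffers in one system call. mmap() lets send() read file pages without an intermediate read() buffer.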

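#include <sys/sendfile.h>  // sendfile
#include <sys/uio.h>       // writev, struct iovec
#include <sys/mman.h>      // mmap, munmap
#include <sys/stat.h>      // fstat
#include <sys/socket.h>    // send, MSG_NOSIGNAL
#include <fcntl.h>         // open, splice (needs _GNU_SOURCE, which g++ defines)
#include <poll.h>          // poll
#include <unistd.h>        // close, pipe2, read, write
#include <cerrno>
#include <cstdint>
#include <string>

namespace AdvancedSocket {

// Zero-copy sender
class ZeroCopySender {
private:
    int socket_fd_;

public:
    explicit ZeroCopySender(int socket_fd) : socket_fd_(socket_fd) {}

    // sendfile(): the kernel copies file pages straight into the socket with no
    // user-space buffer. Reads from `offset` without moving the file position.
    // Returns the bytes sent (possibly fewer than `count`), or 0 if the socket is full.
    ssize_t sendFile(int file_fd, off_t offset, size_t count) {
        off_t current_offset = offset;
        ssize_t sent = ::sendfile(socket_fd_, file_fd, &current_offset, count);
        if (sent == -1) {
            if (errno == EAGAIN || errno == EWOULDBLOCK) {
                return 0;  // socket buffer full: retry after EPOLLOUT
            }
            throw SocketException("sendfile failed", errno);
        }
        return sent;
    }

    // splice(): move up to `len` bytes from the read end of a pipe into the
    // socket inside the kernel; one side of every splice() must be a pipe.
    // A return of 0 means either "would block" or "pipe empty with no writers".
    ssize_t spliceData(int pipe_fd, size_t len, unsigned int flags = 0) {
        ssize_t result = ::splice(pipe_fd, nullptr, socket_fd_, nullptr, len, flags);
        if (result == -1) {
            if (errno == EAGAIN || errno == EWOULDBLOCK) {
                return 0;
            }
            throw SocketException("splice failed", errno);
        }
        return result;
    }

    // writev(): gather write. Sends several buffers (e.g. header + body) in one
    // system call without first copying them into one contiguous buffer.
    ssize_t writeVector(const struct iovec* iov, int iovcnt) {
        ssize_t result = ::writev(socket_fd_, iov, iovcnt);
        if (result == -1) {
            if (errno == EAGAIN || errno == EWOULDBLOCK) {
                return 0;
            }
            throw SocketException("writev failed", errno);
        }
        return result;
    }
};

// Memory-mapped file sender
class MemoryMappedSender {
private:
    int socket_fd_;

    // Block until the socket is writable instead of spinning on EAGAIN
    void waitWritable() {
        struct pollfd pfd = {socket_fd_, POLLOUT, 0};
        while (::poll(&pfd, 1, -1) == -1) {
            if (errno != EINTR) {
                throw SocketException("poll failed", errno);
            }
        }
    }

public:
    explicit MemoryMappedSender(int socket_fd) : socket_fd_(socket_fd) {}

    // mmap the file and send() straight from the mapping. This skips the read()
    // copy into a user buffer; send() still copies from the page cache into the
    // socket buffer.
    void sendMappedFile(const std::string& filename) {
        int fd = ::open(filename.c_str(), O_RDONLY | O_CLOEXEC);
        if (fd == -1) {
            throw SocketException("Failed to open file", errno);
        }
        // Get the file size
        struct stat st;
        if (::fstat(fd, &st) == -1) {
            int err = errno;  // save before close() can overwrite it
            ::close(fd);
            throw SocketException("Failed to stat file", err);
        }
        size_t size = static_cast<size_t>(st.st_size);
        if (size == 0) {
            ::close(fd);
            return;  // mmap() rejects a zero length with EINVAL
        }
        // Map the file; the mapping stays valid after the fd is closed
        void* mapped = ::mmap(nullptr, size, PROT_READ, MAP_PRIVATE, fd, 0);
        int map_errno = errno;
        ::close(fd);
        if (mapped == MAP_FAILED) {
            throw SocketException("Failed to mmap file", map_errno);
        }
        // Send the data
        size_t total_sent = 0;
        const uint8_t* data = static_cast<const uint8_t*>(mapped);
        while (total_sent < size) {
            // MSG_NOSIGNAL: fail with EPIPE instead of raising SIGPIPE
            ssize_t sent = ::send(socket_fd_, data + total_sent,
                                  size - total_sent, MSG_NOSIGNAL);
            if (sent == -1) {
                int err = errno;
                if (err == EINTR) {
                    continue;
                }
                if (err == EAGAIN || err == EWOULDBLOCK) {
                    try {
                        waitWritable();
                    } catch (...) {
                        ::munmap(mapped, size);
                        throw;
                    }
                    continue;
                }
                ::munmap(mapped, size);
                throw SocketException("Failed to send mapped data", err);
            }
            total_sent += static_cast<size_t>(sent);
        }
        // Clean up
        ::munmap(mapped, size);
    }
};

// Pipe buffer: a non-blocking pipe used as the in-kernel buffer for splice()
class PipeBuffer {
private:
    int pipe_fds_[2];
    // Default Linux pipe capacity (64 KiB since 2.6.11); informational only,
    // it can be changed with fcntl(F_SETPIPE_SZ)
    static constexpr size_t PIPE_SIZE = 65536;

public:
    PipeBuffer() {
        if (::pipe2(pipe_fds_, O_NONBLOCK | O_CLOEXEC) == -1) {
            throw SocketException("Failed to create pipe", errno);
        }
    }
    ~PipeBuffer() {
        ::close(pipe_fds_[0]);
        ::close(pipe_fds_[1]);
    }
    // Owns two fds: a copy would close them twice
    PipeBuffer(const PipeBuffer&) = delete;
    PipeBuffer& operator=(const PipeBuffer&) = delete;

    // Read-end file descriptor
    int getReadFd() const {
        return pipe_fds_[0];
    }
    // Write-end file descriptor
    int getWriteFd() const {
        return pipe_fds_[1];
    }
    // Write data into the pipe
    ssize_t writeData(const void* data, size_t len) {
        return ::write(pipe_fds_[1], data, len);
    }
    // Read data from the pipe
    ssize_t readData(void* buffer, size_t len) {
        return ::read(pipe_fds_[0], buffer, len);
    }
};
}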

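3.2 Advanced I/O Multiplexing

A high-performance I/O multiplexer built on epoll. Edge-triggered mode is selected by adding EPOLLET to a channel's event mask. In that mode each callback must keep reading or writing until the call fails with EAGAIN.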

#include <sys/epoll.h>
#include <unistd.h>
#include <cerrno>
#include <cstdint>
#include <cstring>
#include <functional>
#include <memory>
#include <unordered_map>
#include <vector>

namespace AdvancedSocket {

// Channel: one fd, the events it wants, and the callbacks for them
class Channel {
private:
    int fd_;
    uint32_t events_;   // events requested from epoll
    uint32_t revents_;  // events returned by the last epoll_wait()
    std::function<void()> read_callback_;
    std::function<void()> write_callback_;
    std::function<void()> error_callback_;
    std::function<void()> close_callback_;

public:
    explicit Channel(int fd) : fd_(fd), events_(0), revents_(0) {}

    int getFd() const { return fd_; }
    uint32_t getEvents() const { return events_; }
    void setEvents(uint32_t events) { events_ = events; }
    uint32_t getRevents() const { return revents_; }
    void setRevents(uint32_t revents) { revents_ = revents; }

    void setReadCallback(std::function<void()> callback) {
        read_callback_ = std::move(callback);
    }
    void setWriteCallback(std::function<void()> callback) {
        write_callback_ = std::move(callback);
    }
    void setErrorCallback(std::function<void()> callback) {
        error_callback_ = std::move(callback);
    }
    void setCloseCallback(std::function<void()> callback) {
        close_callback_ = std::move(callback);
    }

    void enableReading() { events_ |= EPOLLIN; }
    void disableReading() { events_ &= ~EPOLLIN; }
    void enableWriting() { events_ |= EPOLLOUT; }
    void disableWriting() { events_ &= ~EPOLLOUT; }
    // Switch this channel to edge-triggered notification
    void enableEdgeTriggered() { events_ |= EPOLLET; }
    void disableAll() { events_ = 0; }
    bool isReading() const { return (events_ & EPOLLIN) != 0; }
    bool isWriting() const { return (events_ & EPOLLOUT) != 0; }

    // Dispatch the returned events to the matching callbacks
    void handleEvent() {
        // Hang-up with nothing left to read: the connection is gone
        if ((revents_ & EPOLLHUP) && !(revents_ & EPOLLIN)) {
            if (close_callback_) close_callback_();
            return;  // the close callback may have released this channel
        }
        if (revents_ & EPOLLERR) {
            if (error_callback_) error_callback_();
        }
        if (revents_ & (EPOLLIN | EPOLLPRI | EPOLLRDHUP)) {
            if (read_callback_) read_callback_();
        }
        if (revents_ & EPOLLOUT) {
            if (write_callback_) write_callback_();
        }
    }
};
// High-performance EpollPoller
class EpollPoller {
private:
    int epoll_fd_;
    std::vector<struct epoll_event> events_;
    std::unordered_map<int, std::shared_ptr<Channel>> channels_;
    static constexpr int kInitEventListSize = 16;
public:
    EpollPoller() : epoll_fd_(epoll_create1(EPOLL_CLOEXEC)),
                    events_(kInitEventListSize) {
        if (epoll_fd_ < 0) {
            throw SocketException("Failed to create epoll", errno);
        }
    }
    ~EpollPoller() {
        if (epoll_fd_ >= 0) {
            close(epoll_fd_);
        }
    }
    // The poller owns the epoll fd; a copy would close the same fd twice
    EpollPoller(const EpollPoller&) = delete;
    EpollPoller& operator=(const EpollPoller&) = delete;
    // Register a channel (fd plus its event mask) with epoll
    void addChannel(std::shared_ptr<Channel> channel) {
        int fd = channel->getFd();
        struct epoll_event ev;
        std::memset(&ev, 0, sizeof(ev));
        ev.events = channel->getEvents();
        ev.data.fd = fd;
        if (epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, fd, &ev) < 0) {
            throw SocketException("Failed to add channel to epoll", errno);
        }
        channels_[fd] = std::move(channel);
    }
    // Update the event mask of an already registered channel
    void modifyChannel(const std::shared_ptr<Channel>& channel) {
        int fd = channel->getFd();
        struct epoll_event ev;
        std::memset(&ev, 0, sizeof(ev));
        ev.events = channel->getEvents();
        ev.data.fd = fd;
        if (epoll_ctl(epoll_fd_, EPOLL_CTL_MOD, fd, &ev) < 0) {
            throw SocketException("Failed to modify channel in epoll", errno);
        }
    }
    // Unregister a channel
    void removeChannel(const std::shared_ptr<Channel>& channel) {
        int fd = channel->getFd();
        if (epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, nullptr) < 0) {
            throw SocketException("Failed to remove channel from epoll", errno);
        }
        channels_.erase(fd);
    }
    // Wait for events and dispatch them; returns the number of ready fds
    int poll(int timeout_ms = -1) {
        int num_events = epoll_wait(epoll_fd_, events_.data(),
                                    static_cast<int>(events_.size()), timeout_ms);
        if (num_events < 0) {
            if (errno == EINTR) {
                return 0;  // interrupted by a signal: not an error, just poll again
            }
            throw SocketException("epoll_wait failed", errno);
        }
        // Dispatch the ready events
        for (int i = 0; i < num_events; ++i) {
            int fd = events_[i].data.fd;
            auto it = channels_.find(fd);
            if (it != channels_.end()) {
                // Copy the shared_ptr: a callback may call removeChannel(), which erases
                // the map entry and would otherwise destroy the Channel mid-call
                std::shared_ptr<Channel> channel = it->second;
                channel->setRevents(events_[i].events);
                channel->handleEvent();
            }
        }
        // Every slot was filled, so more events may be pending: grow the array
        if (static_cast<size_t>(num_events) == events_.size()) {
            events_.resize(events_.size() * 2);
        }
        return num_events;
    }
};
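// Timer manager
class TimerManager {
private:
    struct Timer {
        std::chrono::steady_clock::time_point expiration;
        std::function<void()> callback;
        bool repeat;
        std::chrono::milliseconds interval;
        // std::priority_queue is a max-heap; inverting the comparison puts the
        // timer with the earliest expiration at top() (min-heap behavior)
        bool operator<(const Timer& other) const {
            return expiration > other.expiration;
        }
    };
    std::priority_queue<Timer> timers_;
    // mutable so that const members such as getNextTimeout() can lock it
    mutable std::mutex timers_mutex_;
public:
    // Add a timer that fires after `delay`; if `repeat` is true it then fires every `interval`
    void addTimer(std::chrono::milliseconds delay,
                  std::function<void()> callback,
                  bool repeat = false,
                  std::chrono::milliseconds interval = std::chrono::milliseconds(0)) {
        Timer timer;
        timer.expiration = std::chrono::steady_clock::now() + delay;
        timer.callback = std::move(callback);
        timer.repeat = repeat;
        timer.interval = interval;
        std::lock_guard<std::mutex> lock(timers_mutex_);
        timers_.push(std::move(timer));
    }
    // Run every timer that has expired
    void processTimers() {
        auto now = std::chrono::steady_clock::now();
        std::vector<Timer> expired;
        {
            // Only pop under the lock; callbacks run unlocked so they can call
            // addTimer() without deadlocking on the non-recursive mutex
            std::lock_guard<std::mutex> lock(timers_mutex_);
            while (!timers_.empty() && timers_.top().expiration <= now) {
                expired.push_back(timers_.top());
                timers_.pop();
            }
        }
        for (auto& timer : expired) {
            if (timer.callback) {
                timer.callback();
            }
            // Reschedule repeating timers. Re-inserting only after the expired set has
            // been collected also keeps a zero interval from looping forever in one call
            if (timer.repeat) {
                timer.expiration = now + timer.interval;
                std::lock_guard<std::mutex> lock(timers_mutex_);
                timers_.push(std::move(timer));
            }
        }
    }
    // Milliseconds until the next timer expires: -1 if there are none, 0 if one is already due
    int getNextTimeout() const {
        std::lock_guard<std::mutex> lock(timers_mutex_);
        if (timers_.empty()) {
            return -1;  // no timers: wait indefinitely
        }
        auto now = std::chrono::steady_clock::now();
        auto next_expiration = timers_.top().expiration;
        if (next_expiration <= now) {
            return 0;  // already due: process immediately
        }
        // Round up so the loop never wakes a fraction of a millisecond too early
        auto remaining = std::chrono::ceil<std::chrono::milliseconds>(next_expiration - now);
        return static_cast<int>(remaining.count());
    }
};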
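// Complete event-driven server: one background thread runs the timer + epoll loop
class EventDrivenServer {
private:
    EpollPoller poller_;
    TimerManager timer_manager_;
    std::atomic<bool> running_;
    std::thread server_thread_;
    // Upper bound on a single epoll_wait() so that stop() is noticed even when no fd
    // is active (a production loop would more typically wake epoll via an eventfd)
    static constexpr int kMaxPollTimeoutMs = 100;
public:
    EventDrivenServer() : running_(false) {}
    ~EventDrivenServer() {
        stop();
    }
    // Channel management. EpollPoller's channel map is not synchronized, so call these
    // before start() or from callbacks running on the event-loop thread
    void addChannel(std::shared_ptr<Channel> channel) {
        poller_.addChannel(std::move(channel));
    }
    void modifyChannel(const std::shared_ptr<Channel>& channel) {
        poller_.modifyChannel(channel);
    }
    void removeChannel(const std::shared_ptr<Channel>& channel) {
        poller_.removeChannel(channel);
    }
    // Add a timer (TimerManager is mutex-protected, so any thread may call this)
    void addTimer(std::chrono::milliseconds delay,
                  std::function<void()> callback,
                  bool repeat = false,
                  std::chrono::milliseconds interval = std::chrono::milliseconds(0)) {
        timer_manager_.addTimer(delay, std::move(callback), repeat, interval);
    }
    // Start the event loop on a background thread (no-op if already running)
    void start() {
        if (running_.exchange(true)) {
            return;
        }
        server_thread_ = std::thread([this]() {
            this->eventLoop();
        });
    }
    // Stop the loop and wait for the thread to exit (no-op if not running)
    void stop() {
        if (!running_.exchange(false)) {
            return;
        }
        if (server_thread_.joinable()) {
            server_thread_.join();
        }
    }
private:
    void eventLoop() {
        while (running_) {
            // Run expired timers
            timer_manager_.processTimers();
            // Sleep in epoll_wait until the next timer is due, capped so that a
            // timeout of -1 (no timers) cannot block stop() forever
            int timeout = timer_manager_.getNextTimeout();
            if (timeout < 0 || timeout > kMaxPollTimeoutMs) {
                timeout = kMaxPollTimeoutMs;
            }
            // Wait for and dispatch I/O events
            poller_.poll(timeout);
        }
    }
};
}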
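TimerManager relies on one subtle trick. `std::priority_queue` is a max-heap, and `Timer::operator<` is inverted so that the earliest deadline ends up at `top()`. The standalone sketch below isolates that ordering, along with the "pop everything that is due" and "how long may the poller sleep" logic. It is not the chapter's class. `SketchTimer`, `fireExpired`, and `nextTimeoutMs` are hypothetical names, and integer millisecond ticks stand in for `steady_clock` so that the ordering is deterministic.

```cpp
#include <cassert>
#include <cstdint>
#include <queue>
#include <string>
#include <vector>

// Hypothetical stand-in for TimerManager::Timer: expiration is an integer
// millisecond tick instead of a steady_clock time point.
struct SketchTimer {
    int64_t expiration_ms;
    std::string name;
    // Inverted comparison: std::priority_queue is a max-heap, so treating the
    // earlier expiration as "greater" keeps the next timer to fire at top().
    bool operator<(const SketchTimer& other) const {
        return expiration_ms > other.expiration_ms;
    }
};

// Pops every timer due at or before now_ms and returns their names in firing order.
std::vector<std::string> fireExpired(std::priority_queue<SketchTimer>& timers, int64_t now_ms) {
    std::vector<std::string> fired;
    while (!timers.empty() && timers.top().expiration_ms <= now_ms) {
        fired.push_back(timers.top().name);
        timers.pop();
    }
    return fired;
}

// Mirrors getNextTimeout(): -1 = no timers (block), 0 = overdue, otherwise ms to wait.
int64_t nextTimeoutMs(const std::priority_queue<SketchTimer>& timers, int64_t now_ms) {
    if (timers.empty()) {
        return -1;
    }
    int64_t remaining = timers.top().expiration_ms - now_ms;
    return remaining > 0 ? remaining : 0;
}
```

For example, suppose timers "c", "a", and "b" are pushed at ticks 30, 10, and 20. Then `fireExpired(q, 25)` returns `{"a", "b"}` in deadline order, regardless of insertion order, and `nextTimeoutMs(q, 25)` returns 5. That value is the time the event loop could pass to `epoll_wait` before "c" is due.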

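Part 4: Hands-On Project

4.1 Implementing a High-Performance HTTP Server

The following builds a complete HTTP server on top of the TcpServer class. It serves static files and dispatches specific paths to pluggable handlers that generate dynamic content: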

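#include <sstream>
#include <fstream>
#include <iterator>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <cstdint>
namespace AdvancedSocket {
// HTTP request parser
// Assumes `data` contains exactly one complete request (request line, headers, body)
class HttpRequest {
private:
    std::string method_;
    std::string path_;
    std::string version_;
    std::map<std::string, std::string> headers_;
    std::string body_;
    bool parsed_;
public:
    HttpRequest() : parsed_(false) {}
    // Parse an HTTP request; returns false if the request line is malformed
    bool parse(const std::string& data) {
        std::istringstream stream(data);
        std::string line;
        // Request line, e.g. "GET /index.html HTTP/1.1"
        if (!std::getline(stream, line)) {
            return false;
        }
        std::istringstream request_line(line);
        if (!(request_line >> method_ >> path_ >> version_)) {
            return false;
        }
        // Header lines continue until the first empty line
        while (std::getline(stream, line)) {
            // getline() splits on '\n' only; drop the '\r' of the CRLF line ending
            if (!line.empty() && line.back() == '\r') {
                line.pop_back();
            }
            if (line.empty()) {
                break;  // end of headers
            }
            size_t colon_pos = line.find(':');
            if (colon_pos != std::string::npos) {
                std::string key = line.substr(0, colon_pos);
                std::string value = line.substr(colon_pos + 1);
                // Trim leading and trailing whitespace from the value
                value.erase(0, value.find_first_not_of(" \t"));
                value.erase(value.find_last_not_of(" \t") + 1);
                headers_[key] = value;
            }
        }
        // Everything after the blank line is the body, copied byte-for-byte
        body_.assign(std::istreambuf_iterator<char>(stream), std::istreambuf_iterator<char>());
        parsed_ = true;
        return true;
    }
    // Request method, e.g. "GET"
    const std::string& getMethod() const { return method_; }
    // Request target, including any query string
    const std::string& getPath() const { return path_; }
    // Protocol version, e.g. "HTTP/1.1"
    const std::string& getVersion() const { return version_; }
    // Header lookup; returns "" if absent. HTTP header names are case-insensitive,
    // but this is an exact match on the spelling the client sent
    std::string getHeader(const std::string& key) const {
        auto it = headers_.find(key);
        return it != headers_.end() ? it->second : "";
    }
    // Message body
    const std::string& getBody() const { return body_; }
    // Whether parse() has succeeded
    bool isParsed() const { return parsed_; }
};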
// HTTP response builder
class HttpResponse {
private:
    int status_code_;
    std::map<std::string, std::string> headers_;
    std::string body_;
public:
    HttpResponse() : status_code_(200) {
        // Default headers; "Connection: close" tells the client not to reuse the connection
        headers_["Server"] = "AdvancedSocket/1.0";
        headers_["Connection"] = "close";
        headers_["Content-Type"] = "text/html; charset=utf-8";
    }
    // Set the status code
    void setStatusCode(int code) {
        status_code_ = code;
    }
    // Set (or overwrite) a header
    void setHeader(const std::string& key, const std::string& value) {
        headers_[key] = value;
    }
    // Set the body and keep Content-Length consistent with it
    void setBody(const std::string& body) {
        body_ = body;
        headers_["Content-Length"] = std::to_string(body.size());
    }
    // Serialize the response: status line, headers, blank line, body
    std::string build() const {
        std::ostringstream response;
        // Status line
        response << "HTTP/1.1 " << status_code_ << " "
                 << getStatusMessage(status_code_) << "\r\n";
        // Headers
        for (const auto& [key, value] : headers_) {
            response << key << ": " << value << "\r\n";
        }
        response << "\r\n";
        // Body
        response << body_;
        return response.str();
    }
private:
    // Reason phrase for a status code
    std::string getStatusMessage(int code) const {
        static const std::map<int, std::string> messages = {
            {200, "OK"},
            {201, "Created"},
            {204, "No Content"},
            {301, "Moved Permanently"},
            {302, "Found"},
            {304, "Not Modified"},
            {400, "Bad Request"},
            {401, "Unauthorized"},
            {403, "Forbidden"},
            {404, "Not Found"},
            {405, "Method Not Allowed"},
            {500, "Internal Server Error"},
            {501, "Not Implemented"},
            {502, "Bad Gateway"},
            {503, "Service Unavailable"}
        };
        auto it = messages.find(code);
        return it != messages.end() ? it->second : "Unknown";
    }
};
// HTTP handler interface
class HttpHandler {
public:
    virtual ~HttpHandler() = default;
    virtual void handleRequest(const HttpRequest& request, HttpResponse& response) = 0;
};
// Static file handler: maps the URL path onto a file under document_root_
class StaticFileHandler : public HttpHandler {
private:
    std::string document_root_;
public:
    explicit StaticFileHandler(const std::string& document_root)
        : document_root_(document_root) {}
    void handleRequest(const HttpRequest& request, HttpResponse& response) override {
        std::string path = request.getPath();
        // Drop any query string ("/a.html?x=1" -> "/a.html")
        path = path.substr(0, path.find('?'));
        // Security check: reject paths that could escape document_root_
        // (percent-encoded sequences such as "%2e%2e" are not decoded here)
        if (path.empty() || path[0] != '/' || path.find("..") != std::string::npos) {
            response.setStatusCode(403);
            response.setHeader("Content-Type", "text/plain; charset=utf-8");
            response.setBody("403 Forbidden: Invalid path");
            return;
        }
        // Serve index.html for the root
        if (path == "/") {
            path = "/index.html";
        }
        std::string full_path = document_root_ + path;
        // Open in binary mode so images and other binary files are sent unchanged
        std::ifstream file(full_path, std::ios::binary);
        if (!file) {
            response.setStatusCode(404);
            // text/plain so the echoed path cannot be interpreted as HTML
            response.setHeader("Content-Type", "text/plain; charset=utf-8");
            response.setBody("404 Not Found: " + path);
            return;
        }
        // Read the whole file into memory
        std::string content((std::istreambuf_iterator<char>(file)),
                            std::istreambuf_iterator<char>());
        // Set Content-Type from the file extension
        response.setHeader("Content-Type", getContentType(path));
        response.setBody(content);
    }
private:
    // MIME type by file extension (case-sensitive); unknown types are sent as binary
    std::string getContentType(const std::string& path) const {
        size_t dot_pos = path.find_last_of('.');
        if (dot_pos == std::string::npos) {
            return "application/octet-stream";
        }
        std::string ext = path.substr(dot_pos + 1);
        static const std::map<std::string, std::string> mime_types = {
            {"html", "text/html"},
            {"css", "text/css"},
            {"js", "application/javascript"},
            {"json", "application/json"},
            {"png", "image/png"},
            {"jpg", "image/jpeg"},
            {"jpeg", "image/jpeg"},
            {"gif", "image/gif"},
            {"svg", "image/svg+xml"},
            {"txt", "text/plain"}
        };
        auto it = mime_types.find(ext);
        return it != mime_types.end() ? it->second : "application/octet-stream";
    }
};
// HTTP server: parses messages delivered by TcpServer and dispatches them by path
class HttpServer {
private:
    TcpServer tcp_server_;
    std::map<std::string, std::shared_ptr<HttpHandler>> handlers_;
    std::shared_ptr<HttpHandler> default_handler_;
public:
    explicit HttpServer(const InetAddress& listen_addr)
        : tcp_server_(listen_addr) {
        // Wire up the TcpServer callbacks
        tcp_server_.setMessageCallback(
            [this](const std::shared_ptr<TcpConnection>& conn,
                   const std::vector<uint8_t>& data) {
                this->handleHttpRequest(conn, data);
            });
        tcp_server_.setCloseCallback(
            [](const std::shared_ptr<TcpConnection>& /*conn*/) {
                // Connection closed: nothing to clean up in this simple server
            });
        tcp_server_.setErrorCallback(
            [](const std::shared_ptr<TcpConnection>& /*conn*/, int /*error*/) {
                // Error handling (left empty here)
            });
    }
    // Register a handler for an exact path such as "/api/status"
    void registerHandler(const std::string& path, std::shared_ptr<HttpHandler> handler) {
        handlers_[path] = std::move(handler);
    }
    // Handler used when no registered path matches (e.g. a StaticFileHandler)
    void setDefaultHandler(std::shared_ptr<HttpHandler> handler) {
        default_handler_ = std::move(handler);
    }
    // Start the server
    void start() {
        tcp_server_.start();
    }
    // Stop the server
    void stop() {
        tcp_server_.stop();
    }
private:
    // Simplification: treats each received chunk as exactly one complete request.
    // TCP is a byte stream, so a production server must buffer per connection until
    // the headers (and Content-Length bytes of body) have fully arrived
    void handleHttpRequest(const std::shared_ptr<TcpConnection>& conn,
                           const std::vector<uint8_t>& data) {
        std::string request_data(data.begin(), data.end());
        HttpRequest request;
        HttpResponse response;
        if (!request.parse(request_data)) {
            response.setStatusCode(400);
            response.setBody("400 Bad Request");
            conn->send(response.build());
            return;
        }
        // Exact-match lookup on the path with any query string removed
        const std::string& target = request.getPath();
        auto it = handlers_.find(target.substr(0, target.find('?')));
        if (it != handlers_.end()) {
            it->second->handleRequest(request, response);
        } else if (default_handler_) {
            default_handler_->handleRequest(request, response);
        } else {
            response.setStatusCode(404);
            response.setBody("404 Not Found");
        }
        conn->send(response.build());
    }
};
}
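`HttpServer::handleHttpRequest` above parses each received chunk as if it were one whole request. TCP delivers a byte stream, though. A request can arrive split across several reads, and a client can send two requests in one read. The sketch below shows one common way to frame requests: buffer the bytes per connection, wait for the blank line (`\r\n\r\n`) that ends the headers, then wait for `Content-Length` more bytes of body. `extractRequest` is a hypothetical helper and is not part of the classes above. It assumes that bodies are framed by `Content-Length` only (chunked transfer encoding is not handled), and it applies no size limits.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <optional>
#include <string>

// Hypothetical framing helper: removes one complete HTTP/1.1 request from the
// front of `buffer` and returns it, or returns std::nullopt if more bytes are
// needed. Only Content-Length framing is supported; there is no overflow/size cap.
std::optional<std::string> extractRequest(std::string& buffer) {
    const std::size_t header_end = buffer.find("\r\n\r\n");
    if (header_end == std::string::npos) {
        return std::nullopt;  // headers not complete yet
    }
    const std::size_t body_start = header_end + 4;

    // Lower-cased copy of the header block so the Content-Length lookup is case-insensitive
    std::string headers = buffer.substr(0, header_end);
    for (char& c : headers) {
        c = static_cast<char>(std::tolower(static_cast<unsigned char>(c)));
    }

    std::size_t content_length = 0;
    const std::string key = "\r\ncontent-length:";  // a header always follows a CRLF
    const std::size_t pos = headers.find(key);
    if (pos != std::string::npos) {
        std::size_t i = pos + key.size();
        while (i < headers.size() && (headers[i] == ' ' || headers[i] == '\t')) ++i;
        while (i < headers.size() && std::isdigit(static_cast<unsigned char>(headers[i]))) {
            content_length = content_length * 10 + static_cast<std::size_t>(headers[i] - '0');
            ++i;
        }
    }

    if (buffer.size() < body_start + content_length) {
        return std::nullopt;  // body not complete yet
    }
    std::string request = buffer.substr(0, body_start + content_length);
    buffer.erase(0, body_start + content_length);  // keep any following request's bytes
    return request;
}
```

To apply this in the server, the message callback would append incoming bytes to a per-connection buffer. It would then call `extractRequest` in a loop until the function returns `std::nullopt`, passing each complete request to `HttpRequest::parse`.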

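Part 5: Performance Optimization and Monitoring

5.1 Network Performance Monitoring

Implement a monitoring system that collects per-connection traffic counters, latency percentiles, and aggregate throughput: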

#include <atomic>
#include <chrono>
#include <algorithm>
#include <numeric>
#include <cstdint>
#include <deque>
#include <mutex>
#include <unordered_map>
#include <vector>
namespace AdvancedSocket {
// Network performance monitor
class NetworkMonitor {
private:
    // Per-connection counters
    struct ConnectionStats {
        std::atomic<uint64_t> bytes_sent{0};
        std::atomic<uint64_t> bytes_received{0};
        std::atomic<uint64_t> packets_sent{0};
        std::atomic<uint64_t> packets_received{0};
        std::atomic<uint64_t> errors{0};
        std::atomic<uint64_t> retransmissions{0};
        std::chrono::steady_clock::time_point start_time;
        std::chrono::steady_clock::time_point last_activity;
        ConnectionStats() {
            auto now = std::chrono::steady_clock::now();
            start_time = now;
            last_activity = now;
        }
    };
    struct LatencySample {
        std::chrono::microseconds latency;
        std::chrono::steady_clock::time_point timestamp;
    };
    std::unordered_map<int, ConnectionStats> connection_stats_;
    // mutable so that getStats() const can lock the mutexes
    mutable std::mutex stats_mutex_;
    // Latency samples: a sliding window of the most recent MAX_LATENCY_SAMPLES
    std::deque<LatencySample> latency_samples_;
    mutable std::mutex latency_mutex_;
    static constexpr size_t MAX_LATENCY_SAMPLES = 10000;
    // Global counters
    std::atomic<uint64_t> total_connections_{0};
    std::atomic<uint64_t> active_connections_{0};
    std::atomic<uint64_t> total_errors_{0};
public:
    // Record a new connection
    void onConnectionEstablished(int fd) {
        std::lock_guard<std::mutex> lock(stats_mutex_);
        // try_emplace constructs ConnectionStats in place; emplace(fd, ConnectionStats{})
        // does not compile because the std::atomic members make the struct non-movable
        if (connection_stats_.try_emplace(fd).second) {
            total_connections_++;
            active_connections_++;
        }
    }
    // Record a closed connection
    void onConnectionClosed(int fd) {
        std::lock_guard<std::mutex> lock(stats_mutex_);
        // Decrement only for tracked connections so the unsigned counter cannot wrap around
        if (connection_stats_.erase(fd) > 0) {
            active_connections_--;
        }
    }
    // Record outgoing data
    void onDataSent(int fd, size_t bytes) {
        std::lock_guard<std::mutex> lock(stats_mutex_);
        auto it = connection_stats_.find(fd);
        if (it != connection_stats_.end()) {
            it->second.bytes_sent += bytes;
            it->second.packets_sent++;
            it->second.last_activity = std::chrono::steady_clock::now();
        }
    }
    // Record incoming data
    void onDataReceived(int fd, size_t bytes) {
        std::lock_guard<std::mutex> lock(stats_mutex_);
        auto it = connection_stats_.find(fd);
        if (it != connection_stats_.end()) {
            it->second.bytes_received += bytes;
            it->second.packets_received++;
            it->second.last_activity = std::chrono::steady_clock::now();
        }
    }
    // Record one latency sample
    void recordLatency(std::chrono::microseconds latency) {
        std::lock_guard<std::mutex> lock(latency_mutex_);
        latency_samples_.push_back({latency, std::chrono::steady_clock::now()});
        // Bound the window; deque::pop_front is O(1), unlike vector::erase(begin())
        if (latency_samples_.size() > MAX_LATENCY_SAMPLES) {
            latency_samples_.pop_front();
        }
    }
    // Snapshot of aggregate statistics
    struct NetworkStats {
        uint64_t total_connections;
        uint64_t active_connections;
        uint64_t total_errors;
        double avg_latency_us;
        double p50_latency_us;
        double p95_latency_us;
        double p99_latency_us;
        double throughput_mbps;  // aggregate rate of active connections, in 10^6 bit/s
    };
    NetworkStats getStats() const {
        NetworkStats stats{};
        stats.total_connections = total_connections_.load();
        stats.active_connections = active_connections_.load();
        stats.total_errors = total_errors_.load();
        // Latency statistics: copy the samples under the lock, sort outside it
        std::vector<std::chrono::microseconds> latencies;
        {
            std::lock_guard<std::mutex> lock(latency_mutex_);
            latencies.reserve(latency_samples_.size());
            for (const auto& sample : latency_samples_) {
                latencies.push_back(sample.latency);
            }
        }
        if (!latencies.empty()) {
            std::sort(latencies.begin(), latencies.end());
            size_t n = latencies.size();
            auto sum = std::accumulate(latencies.begin(), latencies.end(),
                                       std::chrono::microseconds(0));
            // Divide as double; integer division would truncate the average
            stats.avg_latency_us = static_cast<double>(sum.count()) / static_cast<double>(n);
            // Percentiles by index into the sorted samples (index < n whenever n >= 1)
            stats.p50_latency_us = static_cast<double>(latencies[n * 50 / 100].count());
            stats.p95_latency_us = static_cast<double>(latencies[n * 95 / 100].count());
            stats.p99_latency_us = static_cast<double>(latencies[n * 99 / 100].count());
        }
        // Throughput: sum over active connections of (bytes * 8) / connection lifetime
        double total_bits_per_sec = 0.0;
        auto now = std::chrono::steady_clock::now();
        std::lock_guard<std::mutex> lock(stats_mutex_);
        for (const auto& [fd, conn_stats] : connection_stats_) {
            double seconds = std::chrono::duration<double>(now - conn_stats.start_time).count();
            if (seconds > 0.0) {
                uint64_t bytes = conn_stats.bytes_sent.load() + conn_stats.bytes_received.load();
                total_bits_per_sec += static_cast<double>(bytes) * 8.0 / seconds;
            }
        }
        stats.throughput_mbps = total_bits_per_sec / 1e6;
        return stats;
    }
    // Record an error (error_code itself is not stored yet)
    void onError(int fd, int error_code) {
        (void)error_code;
        total_errors_++;
        std::lock_guard<std::mutex> lock(stats_mutex_);
        auto it = connection_stats_.find(fd);
        if (it != connection_stats_.end()) {
            it->second.errors++;
        }
    }
};
// Adaptive buffer manager
class AdaptiveBufferManager {
private:
    static constexpr size_t MIN_BUFFER_SIZE = 4096;
    static constexpr size_t MAX_BUFFER_SIZE = 1024 * 1024;  // 1 MiB
    static constexpr double TARGET_UTILIZATION = 0.8;
    struct BufferStats {
        // Start at the minimum size: a zero size would divide by zero below
        // and could never grow, because 0 * 2 == 0
        size_t current_size = MIN_BUFFER_SIZE;
        size_t peak_usage = 0;
        double utilization = 0.0;
        std::chrono::steady_clock::time_point last_adjustment{};
    };
    std::unordered_map<int, BufferStats> buffer_stats_;
    std::mutex stats_mutex_;
public:
    // Recommend a buffer size for a connection, given its current usage
    size_t getRecommendedBufferSize(int connection_id, size_t current_usage) {
        std::lock_guard<std::mutex> lock(stats_mutex_);
        auto& stats = buffer_stats_[connection_id];
        auto now = std::chrono::steady_clock::now();
        // Update the statistics
        stats.peak_usage = std::max(stats.peak_usage, current_usage);
        stats.utilization = static_cast<double>(current_usage) /
                            static_cast<double>(stats.current_size);
        // Compute the new buffer size
        size_t new_size = stats.current_size;
        // Utilization above target: double the buffer, up to MAX_BUFFER_SIZE
        if (stats.utilization > TARGET_UTILIZATION &&
            stats.current_size < MAX_BUFFER_SIZE) {
            new_size = std::min(stats.current_size * 2, MAX_BUFFER_SIZE);
        }
                                                        // 如果利用率过低,减少缓冲区
                                                        else if (stats.utilization < TARGET_UTILIZATION * 0.5 &&
                                                        stats.current_size > MIN_BUFFER_SIZE) {
                                                        new_size = std::max(stats.current_size / 2, MIN_BUFFER_SIZE);
                                                        }
                                                        // 限制调整频率
                                                        auto time_since_adjustment = std::chrono::duration_cast<std::chrono::seconds>(
                                                          now - stats.last_adjustment).count();
                                                          if (time_since_adjustment > 5) {  // 5秒调整一次
                                                          stats.current_size = new_size;
                                                          stats.last_adjustment = now;
                                                          }
                                                          return stats.current_size;
                                                          }
                                                          // 初始化缓冲区统计
                                                          void initializeBuffer(int connection_id, size_t initial_size) {
                                                          std::lock_guard<std::mutex> lock(stats_mutex_);
                                                            BufferStats stats{};
                                                            stats.current_size = std::max(initial_size, MIN_BUFFER_SIZE);
                                                            stats.peak_usage = 0;
                                                            stats.utilization = 0.0;
                                                            stats.last_adjustment = std::chrono::steady_clock::now();
                                                            buffer_stats_[connection_id] = stats;
                                                            }
                                                            };
                                                            }
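AdaptiveBufferManager only computes a size. Nothing above applies that size to a socket.

Suppose the goal is to size the kernel's receive buffer. One way is `setsockopt` with `SO_RCVBUF`, then reading the effective value back with `getsockopt`. On Linux the kernel doubles the requested value to leave room for bookkeeping overhead, and the request is limited by `net.core.rmem_max`.

The block below is a minimal sketch under that assumption. `applyRecvBufferSize` is a hypothetical helper, not part of the classes above.

```cpp
#include <sys/socket.h>
#include <cassert>
#include <cstddef>

// Hypothetical helper: ask the kernel for a receive buffer of `recommended`
// bytes and return the size it actually granted, or -1 on error.
int applyRecvBufferSize(int fd, std::size_t recommended) {
    int requested = static_cast<int>(recommended);
    if (setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &requested, sizeof(requested)) != 0) {
        return -1;
    }
    // Read back the effective value; on Linux it is the (limited) request doubled
    int effective = 0;
    socklen_t len = sizeof(effective);
    if (getsockopt(fd, SOL_SOCKET, SO_RCVBUF, &effective, &len) != 0) {
        return -1;
    }
    return effective;
}
```

For TCP, the buffer size should be set before `connect()` or `listen()` so that it can influence window scaling. On Linux, setting `SO_RCVBUF` also disables receive-buffer autotuning for that socket, so it is worth measuring before overriding the default.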

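Part 6: Hands-on Exercises

Exercise 1: TCP Client Implementation

Implement a complete TCP client that reconnects automatically. The listing implements reconnection; a heartbeat (periodic keep-alive messages that detect a silently dead peer) is left as an extension: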

#include <arpa/inet.h>
#include <sys/socket.h>
#include <atomic>
#include <chrono>
#include <cstdint>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <vector>
// Socket, InetAddress, TcpConnection and NetworkMonitor are the classes defined earlier in this chapter.
namespace AdvancedSocket {
class TcpClient {
private:
    std::string server_ip_;
    uint16_t server_port_;
    std::atomic<bool> connected_;
    std::shared_ptr<TcpConnection> connection_;
    std::mutex conn_mutex_;              // guards connection_ against the reconnect thread
    std::thread reconnect_thread_;
    std::atomic<bool> should_reconnect_;
    NetworkMonitor monitor_;

public:
    TcpClient(const std::string& server_ip, uint16_t server_port)
        : server_ip_(server_ip)
        , server_port_(server_port)
        , connected_(false)
        , should_reconnect_(true) {
    }
    ~TcpClient() {
        // Stop the reconnect thread first so it cannot reconnect after we disconnect
        should_reconnect_ = false;
        if (reconnect_thread_.joinable()) {
            reconnect_thread_.join();
        }
        disconnect();
    }
    // Connect to the server; returns true if a connection is (already) established
    bool connect() {
        std::lock_guard<std::mutex> lock(conn_mutex_);
        if (connected_) {
            return true;
        }
        try {
            Socket socket;
            InetAddress server_addr(server_ip_, server_port_);
            // Blocking connect first, then switch to non-blocking for the event loop
            socket.connect(server_addr.getSockAddr(), server_addr.getSockLen());
            socket.setNonBlocking(true);
            socket.setTcpNoDelay(true);   // disable Nagle's algorithm for small messages
            // Look up the local address and port the kernel assigned
            InetAddress local_addr;
            struct sockaddr_in local_addr_in{};
            socklen_t local_addrlen = sizeof(local_addr_in);
            if (getsockname(socket.getFd(),
                            reinterpret_cast<struct sockaddr*>(&local_addr_in),
                            &local_addrlen) == 0) {
                // inet_ntop writes into our buffer; inet_ntoa uses a shared static one
                char ip_buf[INET_ADDRSTRLEN];
                if (inet_ntop(AF_INET, &local_addr_in.sin_addr, ip_buf, sizeof(ip_buf)) != nullptr) {
                    local_addr = InetAddress(ip_buf, ntohs(local_addr_in.sin_port));
                }
            }
            connection_ = std::make_shared<TcpConnection>(
                std::move(socket), local_addr, server_addr);
            setupConnectionCallbacks();
            connection_->startReading();
            connected_ = true;
            monitor_.onConnectionEstablished(connection_->getFd());
            std::cout << "Connected to server: " << server_addr.toString() << std::endl;
            return true;
        } catch (const std::exception& e) {
            std::cerr << "Connection failed: " << e.what() << std::endl;
            return false;
        }
    }
    // Disconnect; exchange() makes sure the close is recorded only once,
    // even if the close callback fires as well
    void disconnect() {
        std::lock_guard<std::mutex> lock(conn_mutex_);
        if (connection_) {
            if (connected_.exchange(false)) {
                monitor_.onConnectionClosed(connection_->getFd());
            }
            connection_->shutdown();
        }
    }
    // Send a message (silently dropped while disconnected)
    void send(const std::string& message) {
        std::lock_guard<std::mutex> lock(conn_mutex_);
        if (connected_ && connection_) {
            connection_->send(message);
            monitor_.onDataSent(connection_->getFd(), message.size());
        }
    }
    // Start a background thread that retries the connection every 5 seconds
    void startAutoReconnect() {
        if (reconnect_thread_.joinable()) {
            return;   // already running
        }
        should_reconnect_ = true;
        reconnect_thread_ = std::thread([this]() {
            while (should_reconnect_) {
                if (!connected_) {
                    std::cout << "Attempting to reconnect..." << std::endl;
                    if (connect()) {
                        std::cout << "Reconnected successfully" << std::endl;
                    }
                }
                // Sleep in 100 ms steps so the destructor never waits 5 s to join
                for (int i = 0; i < 50 && should_reconnect_; ++i) {
                    std::this_thread::sleep_for(std::chrono::milliseconds(100));
                }
            }
        });
    }
    // Get network statistics
    NetworkMonitor::NetworkStats getStats() {
        return monitor_.getStats();
    }

private:
    void setupConnectionCallbacks() {
        // Print and count every incoming message
        connection_->setMessageCallback(
            [this](const std::shared_ptr<TcpConnection>& conn,
                   const std::vector<uint8_t>& data) {
                std::string message(data.begin(), data.end());
                std::cout << "Received: " << message << std::endl;
                monitor_.onDataReceived(conn->getFd(), data.size());
            });
        // Peer closed: mark the client disconnected so the reconnect thread takes over
        connection_->setCloseCallback(
            [this](const std::shared_ptr<TcpConnection>& conn) {
                std::cout << "Connection closed" << std::endl;
                if (connected_.exchange(false)) {
                    monitor_.onConnectionClosed(conn->getFd());
                }
            });
        connection_->setErrorCallback(
            [this](const std::shared_ptr<TcpConnection>& conn, int error) {
                std::cerr << "Connection error: " << error << std::endl;
                monitor_.onError(conn->getFd(), error);
            });
    }
};
}  // namespace AdvancedSocket

// Usage example
int main() {
    try {
        AdvancedSocket::TcpClient client("127.0.0.1", 8080);
        // Initial connection attempt
        if (client.connect()) {
            std::cout << "Connected to server" << std::endl;
            // Send a message
            client.send("Hello, Server!");
            // Wait for the response
            std::this_thread::sleep_for(std::chrono::seconds(2));
            // Print statistics
            auto stats = client.getStats();
            std::cout << "Network Stats:" << std::endl;
            std::cout << "  Active connections: " << stats.active_connections << std::endl;
            std::cout << "  Total connections: " << stats.total_connections << std::endl;
            std::cout << "  Throughput: " << stats.throughput_mbps << " Mbps" << std::endl;
        }
        // Start automatic reconnection after the first attempt, so the
        // reconnect thread does not duplicate it
        client.startAutoReconnect();
        // Keep running for a while
        std::this_thread::sleep_for(std::chrono::seconds(30));
    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }
    return 0;
}
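The client above notices a dead connection only when the close or error callback fires. That may never happen if the peer disappears silently: a crashed host or a broken network path leaves no FIN or RST to observe. A heartbeat covers that case.

Below is a minimal sketch of the bookkeeping side. It assumes the application sends a PING after a period with no outgoing traffic and treats any received message as proof of life. `HeartbeatTracker` and its methods are hypothetical names. Wiring it into TcpClient (a timer thread that sends PINGs and calls `disconnect()` when `isDead()` returns true) is not shown.

```cpp
#include <cassert>
#include <chrono>

// Hypothetical heartbeat bookkeeping, independent of any socket code.
// The current time is passed in explicitly so the logic is easy to test.
class HeartbeatTracker {
public:
    using Clock = std::chrono::steady_clock;

    HeartbeatTracker(std::chrono::milliseconds ping_interval,
                     std::chrono::milliseconds dead_timeout,
                     Clock::time_point now)
        : ping_interval_(ping_interval), dead_timeout_(dead_timeout),
          last_received_(now), last_sent_(now) {}

    // Any inbound message (data or PONG) proves the peer is alive
    void onMessageReceived(Clock::time_point now) { last_received_ = now; }

    // Record that we just sent something (data or PING)
    void onMessageSent(Clock::time_point now) { last_sent_ = now; }

    // True when nothing has been sent for at least ping_interval
    bool shouldSendPing(Clock::time_point now) const {
        return now - last_sent_ >= ping_interval_;
    }

    // True when nothing has arrived for at least dead_timeout
    bool isDead(Clock::time_point now) const {
        return now - last_received_ >= dead_timeout_;
    }

private:
    std::chrono::milliseconds ping_interval_;
    std::chrono::milliseconds dead_timeout_;
    Clock::time_point last_received_;
    Clock::time_point last_sent_;
};
```

The time is injected rather than read inside the class, which keeps the tracker free of threads and trivially testable. A timer loop like `reconnect_thread_` would call it with `std::chrono::steady_clock::now()`. The dead timeout is usually a few multiples of the ping interval, so a single lost PONG does not tear the connection down.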

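Exercise 2: Reliable Transport over UDP

Build reliable delivery on top of UDP using sequence numbers, acknowledgements and timeout-based retransmission: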

#include <arpa/inet.h>
#include <sys/socket.h>
#include <algorithm>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <cstring>
#include <deque>
#include <iostream>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>
// Socket, InetAddress and NetworkMonitor are the classes defined earlier in this chapter.
namespace AdvancedSocket {
// Reliable transport over UDP: sequence numbers, per-packet ACKs, timeout retransmission.
// Threading model: once connected, recv_thread_ is the only reader of the socket and
// hands in-order payload bytes to receive() through recv_buffer_.
class ReliableUDP {
private:
    struct Packet {
        uint32_t sequence_number = 0;
        uint32_t ack_number = 0;
        uint16_t flags = 0;          // 0 marks a data packet
        uint16_t window_size = 0;    // advertised, but not used for flow control yet
        std::vector<uint8_t> data;
        std::chrono::steady_clock::time_point timestamp{};  // time of the last transmission
        int retry_count = 0;
    };
    // Wire header: 16 bytes, no padding, multi-byte fields in network byte order
    struct PacketHeader {
        uint32_t sequence_number;
        uint32_t ack_number;
        uint16_t flags;
        uint16_t window_size;
        uint16_t data_length;
        uint16_t checksum;
    };
    static constexpr uint16_t FLAG_ACK = 0x01;
    static constexpr uint16_t FLAG_SYN = 0x02;
    static constexpr uint16_t FLAG_FIN = 0x04;
    static constexpr size_t MAX_PACKET_SIZE = 1400;   // header + payload, below a 1500-byte Ethernet MTU
    static constexpr int MAX_RETRIES = 3;
    static constexpr auto RETRY_TIMEOUT = std::chrono::milliseconds(100);

    Socket socket_;
    InetAddress peer_addr_;
    std::atomic<uint32_t> send_sequence_;   // next sequence number for outgoing data
    std::atomic<uint32_t> recv_sequence_;   // last in-order sequence number received
    std::queue<Packet> send_queue_;                                // guarded by send_mutex_
    std::unordered_map<uint32_t, Packet> unacknowledged_packets_;  // guarded by send_mutex_
    std::mutex send_mutex_;
    std::deque<uint8_t> recv_buffer_;       // in-order payload bytes, guarded by recv_mutex_
    std::mutex recv_mutex_;
    std::condition_variable recv_cv_;
    std::atomic<bool> connected_;
    std::thread send_thread_;
    std::thread recv_thread_;
    NetworkMonitor monitor_;

public:
    ReliableUDP()
        : socket_(AF_INET, SOCK_DGRAM, 0)
        , send_sequence_(0)
        , recv_sequence_(0)
        , connected_(false) {
        socket_.setNonBlocking(true);
    }
    ~ReliableUDP() {
        disconnect();
    }
    // Connect to the peer with a three-way handshake: SYN, SYN+ACK, ACK
    bool connect(const std::string& ip, uint16_t port) {
        peer_addr_ = InetAddress(ip, port);
        // connect() on a UDP socket exchanges no packets; it only fixes the default
        // destination so that plain send()/recv() work without sendto()/recvfrom()
        socket_.connect(peer_addr_.getSockAddr(), peer_addr_.getSockLen());
        Packet syn_packet;
        syn_packet.sequence_number = send_sequence_++;   // the SYN consumes one sequence number
        syn_packet.flags = FLAG_SYN;
        syn_packet.window_size = 1024;
        // Wait up to 5 s for SYN+ACK, resending the SYN every 500 ms in case it was lost
        auto start_time = std::chrono::steady_clock::now();
        auto last_syn = start_time - std::chrono::seconds(1);   // forces an immediate first send
        while (std::chrono::steady_clock::now() - start_time < std::chrono::seconds(5)) {
            if (std::chrono::steady_clock::now() - last_syn >= std::chrono::milliseconds(500)) {
                sendPacket(syn_packet);
                last_syn = std::chrono::steady_clock::now();
            }
            Packet response;
            if (receivePacket(response) && (response.flags & FLAG_SYN) && (response.flags & FLAG_ACK)) {
                recv_sequence_ = response.sequence_number;
                connected_ = true;
                // Final ACK of the handshake; pure ACKs do not consume a sequence number
                Packet ack_packet;
                ack_packet.sequence_number = send_sequence_;
                ack_packet.ack_number = response.sequence_number + 1;
                ack_packet.flags = FLAG_ACK;
                ack_packet.window_size = 1024;
                sendPacket(ack_packet);
                // Start the sender and receiver threads
                startThreads();
                monitor_.onConnectionEstablished(socket_.getFd());
                return true;
            }
            std::this_thread::sleep_for(std::chrono::milliseconds(10));
        }
        return false;
    }
    // Disconnect: stop both worker threads and wake any caller blocked in receive()
    void disconnect() {
        connected_ = false;
        { std::lock_guard<std::mutex> lock(recv_mutex_); }   // prevents a lost wakeup
        recv_cv_.notify_all();
        bool was_running = send_thread_.joinable() || recv_thread_.joinable();
        if (send_thread_.joinable()) {
            send_thread_.join();
        }
        if (recv_thread_.joinable()) {
            recv_thread_.join();
        }
        if (was_running) {
            monitor_.onConnectionClosed(socket_.getFd());
        }
    }
    // Queue data for reliable delivery, split into chunks that each fit in one datagram
    bool send(const void* data, size_t len) {
        if (!connected_) {
            return false;
        }
        const auto* bytes = static_cast<const uint8_t*>(data);
        // Number and enqueue all chunks under one lock so concurrent callers
        // cannot interleave their chunks out of sequence order
        std::lock_guard<std::mutex> lock(send_mutex_);
        size_t offset = 0;
        while (offset < len) {
            size_t chunk_size = std::min(len - offset, MAX_PACKET_SIZE - sizeof(PacketHeader));
            Packet packet;
            packet.sequence_number = send_sequence_++;
            packet.ack_number = recv_sequence_;
            packet.window_size = 1024;
            packet.data.assign(bytes + offset, bytes + offset + chunk_size);
            send_queue_.push(std::move(packet));
            offset += chunk_size;
        }
        return true;
    }
    // Receive up to len bytes of in-order data. Waits up to 5 s; on success,
    // len is set to the number of bytes copied and any remainder stays buffered
    bool receive(void* buffer, size_t& len) {
        std::unique_lock<std::mutex> lock(recv_mutex_);
        recv_cv_.wait_for(lock, std::chrono::seconds(5),
                          [this] { return !recv_buffer_.empty() || !connected_; });
        if (recv_buffer_.empty()) {
            return false;   // timed out or disconnected
        }
        size_t n = std::min(len, recv_buffer_.size());
        std::copy_n(recv_buffer_.begin(), n, static_cast<uint8_t*>(buffer));
        recv_buffer_.erase(recv_buffer_.begin(), recv_buffer_.begin() + n);
        len = n;
        return true;
    }
    // Get network statistics
    NetworkMonitor::NetworkStats getStats() {
        return monitor_.getStats();
    }

private:
    void startThreads() {
        send_thread_ = std::thread([this]() { sendThread(); });
        recv_thread_ = std::thread([this]() { recvThread(); });
    }
    void sendThread() {
        while (connected_) {
            processSendQueue();
            processRetransmissions();
            std::this_thread::sleep_for(std::chrono::milliseconds(10));
        }
    }
    void recvThread() {
        while (connected_) {
            processIncomingPackets();
            std::this_thread::sleep_for(std::chrono::milliseconds(1));
        }
    }
    // Serialize header + payload into one datagram and send it to the connected peer
    bool sendPacket(const Packet& packet) {
        PacketHeader header{};
        header.sequence_number = htonl(packet.sequence_number);
        header.ack_number = htonl(packet.ack_number);
        header.flags = htons(packet.flags);
        header.window_size = htons(packet.window_size);
        header.data_length = htons(static_cast<uint16_t>(packet.data.size()));
        header.checksum = htons(calculateChecksum(packet));
        std::vector<uint8_t> buffer(sizeof(PacketHeader) + packet.data.size());
        std::memcpy(buffer.data(), &header, sizeof(header));
        if (!packet.data.empty()) {
            std::memcpy(buffer.data() + sizeof(header), packet.data.data(), packet.data.size());
        }
        ssize_t sent = socket_.send(buffer.data(), buffer.size(), 0);
        if (sent > 0) {
            monitor_.onDataSent(socket_.getFd(), sent);
            return true;
        }
        return false;
    }
    // Read and parse one datagram; false if none is pending or it is malformed
    bool receivePacket(Packet& packet) {
        std::vector<uint8_t> buffer(MAX_PACKET_SIZE);
        ssize_t received = socket_.recv(buffer.data(), buffer.size(), 0);
        if (received < static_cast<ssize_t>(sizeof(PacketHeader))) {
            return false;   // nothing pending (EAGAIN) or a runt datagram
        }
        PacketHeader header;
        std::memcpy(&header, buffer.data(), sizeof(header));
        packet.sequence_number = ntohl(header.sequence_number);
        packet.ack_number = ntohl(header.ack_number);
        packet.flags = ntohs(header.flags);
        packet.window_size = ntohs(header.window_size);
        uint16_t data_length = ntohs(header.data_length);
        if (received < static_cast<ssize_t>(sizeof(header) + data_length)) {
            return false;   // truncated datagram
        }
        // assign() also clears payload left over from a previous packet
        packet.data.assign(buffer.begin() + sizeof(header),
                           buffer.begin() + sizeof(header) + data_length);
        if (ntohs(header.checksum) != calculateChecksum(packet)) {
            return false;   // corrupted: drop it and let the sender retransmit
        }
        monitor_.onDataReceived(socket_.getFd(), received);
        return true;
    }
    // Simplified 16-bit checksum: sum of header fields and payload bytes, carries folded back in
    uint16_t calculateChecksum(const Packet& packet) {
        uint32_t sum = packet.sequence_number + packet.ack_number +
                       packet.flags + packet.window_size;
        for (uint8_t byte : packet.data) {
            sum += byte;
        }
        while (sum >> 16) {
            sum = (sum & 0xFFFF) + (sum >> 16);
        }
        return static_cast<uint16_t>(sum);
    }
    // Acknowledge one data packet; a pure ACK carries the current sequence number
    // without consuming it, so the peer sees no gaps in our data sequence
    void sendAck(uint32_t sequence_number) {
        Packet ack_packet;
        ack_packet.sequence_number = send_sequence_;
        ack_packet.ack_number = sequence_number;
        ack_packet.flags = FLAG_ACK;
        ack_packet.window_size = 1024;
        sendPacket(ack_packet);
    }
    // Transmit queued packets in order. If a send fails (e.g. EAGAIN), the packet stays
    // at the front of the queue and is retried on the next tick, preserving order
    void processSendQueue() {
        std::lock_guard<std::mutex> lock(send_mutex_);
        while (connected_ && !send_queue_.empty()) {
            Packet& packet = send_queue_.front();
            if (!sendPacket(packet)) {
                break;
            }
            packet.timestamp = std::chrono::steady_clock::now();   // start the retransmission timer
            unacknowledged_packets_[packet.sequence_number] = std::move(packet);
            send_queue_.pop();
        }
    }
    // Retransmit packets whose ACK is overdue; give up after MAX_RETRIES
    void processRetransmissions() {
        auto now = std::chrono::steady_clock::now();
        std::lock_guard<std::mutex> lock(send_mutex_);
        for (auto& [seq_num, packet] : unacknowledged_packets_) {
            if (now - packet.timestamp > RETRY_TIMEOUT) {
                if (packet.retry_count < MAX_RETRIES) {
                    packet.timestamp = now;
                    packet.retry_count++;
                    sendPacket(packet);
                } else {
                    // Too many retries: treat the connection as failed
                    connected_ = false;
                    break;
                }
            }
        }
    }
    // Runs on recv_thread_, the only reader of the socket after the handshake
    void processIncomingPackets() {
        Packet packet;
        while (receivePacket(packet)) {
            // ACK: the acknowledged packet no longer needs retransmission
            if (packet.flags & FLAG_ACK) {
                std::lock_guard<std::mutex> lock(send_mutex_);
                unacknowledged_packets_.erase(packet.ack_number);
            }
            if (packet.flags != 0) {
                continue;   // control packet (ACK/SYN/FIN): no payload to deliver
            }
            if (packet.sequence_number == recv_sequence_ + 1) {
                // Next expected packet: hand its payload to receive() in order
                {
                    std::lock_guard<std::mutex> lock(recv_mutex_);
                    recv_buffer_.insert(recv_buffer_.end(), packet.data.begin(), packet.data.end());
                }
                recv_cv_.notify_one();
                recv_sequence_ = packet.sequence_number;
                sendAck(packet.sequence_number);
            } else if (packet.sequence_number <= recv_sequence_) {
                // Duplicate (our earlier ACK was probably lost): acknowledge it again
                sendAck(packet.sequence_number);
            }
            // Out-of-order packets are dropped; the sender retransmits them after a timeout
        }
    }
};
}  // namespace AdvancedSocket

// Usage example (expects a peer on 127.0.0.1:8888 that speaks this protocol)
int main() {
    try {
        AdvancedSocket::ReliableUDP client;
        // Connect to the server
        if (client.connect("127.0.0.1", 8888)) {
            std::cout << "Connected to server" << std::endl;
            // Send data
            std::string message = "Hello, Reliable UDP Server!";
            if (client.send(message.data(), message.size())) {
                std::cout << "Message sent successfully" << std::endl;
            }
            // Receive the response; buffer_size is the capacity in, the byte count out
            char buffer[1024];
            size_t buffer_size = sizeof(buffer);
            if (client.receive(buffer, buffer_size)) {
                std::cout << "Received response: " << std::string(buffer, buffer_size) << std::endl;
            }
            // Print statistics
            auto stats = client.getStats();
            std::cout << "Network Stats:" << std::endl;
            std::cout << "  Total connections: " << stats.total_connections << std::endl;
            std::cout << "  Active connections: " << stats.active_connections << std::endl;
            std::cout << "  Total errors: " << stats.total_errors << std::endl;
        }
    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }
    return 0;
}
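This exercise has two limitations worth knowing:

- The duplicate check `packet.sequence_number <= recv_sequence_` gives the wrong answer once the 32-bit counter wraps from 0xFFFFFFFF to 0.
- Out-of-order packets are simply dropped and left to retransmission.

Below is a sketch of both fixes, assuming 32-bit sequence numbers. `seqLess` compares sequence numbers with serial-number arithmetic in the spirit of RFC 1982, the same signed-difference trick the Linux TCP stack uses. The hypothetical `ReorderBuffer` holds early packets until the gap before them is filled. Neither name is part of the classes above.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <utility>
#include <vector>

// True if sequence number a comes before b on the 32-bit circle: the signed
// difference is negative when a is behind b by less than 2^31.
inline bool seqLess(uint32_t a, uint32_t b) {
    return static_cast<int32_t>(a - b) < 0;
}

// Orders map keys with seqLess; valid while buffered keys span less than 2^31
struct SeqCompare {
    bool operator()(uint32_t a, uint32_t b) const { return seqLess(a, b); }
};

// Hypothetical receiver-side reorder buffer: packets may arrive in any order,
// payload bytes are released strictly in sequence order.
class ReorderBuffer {
public:
    explicit ReorderBuffer(uint32_t last_delivered) : last_delivered_(last_delivered) {}

    // Store a packet; already-delivered and already-pending ones are ignored
    void insert(uint32_t seq, std::vector<uint8_t> payload) {
        if (!seqLess(last_delivered_, seq)) {
            return;   // seq is at or before last_delivered_: a duplicate
        }
        pending_.emplace(seq, std::move(payload));   // no-op if seq is already pending
    }

    // Remove every contiguous packet after last_delivered_ and return their bytes
    std::vector<uint8_t> drainInOrder() {
        std::vector<uint8_t> out;
        auto it = pending_.find(last_delivered_ + 1);   // unsigned +1 wraps to 0
        while (it != pending_.end()) {
            out.insert(out.end(), it->second.begin(), it->second.end());
            last_delivered_ = it->first;
            pending_.erase(it);
            it = pending_.find(last_delivered_ + 1);
        }
        return out;
    }

    uint32_t lastDelivered() const { return last_delivered_; }
    std::size_t pendingCount() const { return pending_.size(); }

private:
    uint32_t last_delivered_;
    std::map<uint32_t, std::vector<uint8_t>, SeqCompare> pending_;
};
```

In `processIncomingPackets()`, such a buffer would replace the "deliver only `recv_sequence_ + 1`" branch. Each data packet would be inserted and acknowledged, and whatever `drainInOrder()` returns would be appended to `recv_buffer_`. Memory is then bounded only by how far ahead the sender runs. A real protocol would cap that with the advertised `window_size`.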

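Exercise 3: Network Performance Test Tool

Implement a network performance test tool that measures bandwidth, latency (round-trip time) and packet loss rate: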

#include <iostream>
  #include <fstream>
    #include <vector>
      #include <random>
        #include <cmath>
          namespace AdvancedSocket {
          // 网络性能测试器
          class NetworkPerformanceTester {
          private:
          struct TestResult {
          double throughput_mbps;
          double avg_latency_ms;
          double min_latency_ms;
          double max_latency_ms;
          double jitter_ms;
          double packet_loss_rate;
          size_t packets_sent;
          size_t packets_received;
          };
          std::string server_ip_;
          uint16_t server_port_;
          NetworkMonitor monitor_;
          public:
          NetworkPerformanceTester(const std::string& server_ip, uint16_t server_port)
          : server_ip_(server_ip), server_port_(server_port) {}
          // 带宽测试
          TestResult bandwidthTest(size_t test_duration_seconds = 10,
          size_t packet_size = 1400,
          size_t packets_per_second = 1000) {
          TcpClient client(server_ip_, server_port_);
          TestResult result{};
          if (!client.connect()) {
          std::cerr << "Failed to connect to server" << std::endl;
          return result;
          }
          std::cout << "Starting bandwidth test..." << std::endl;
          std::vector<uint8_t> data(packet_size);
            std::generate(data.begin(), data.end(), []() { return rand() % 256; });
            auto start_time = std::chrono::steady_clock::now();
            auto next_packet_time = start_time;
            while (std::chrono::steady_clock::now() - start_time <
            std::chrono::seconds(test_duration_seconds)) {
            auto now = std::chrono::steady_clock::now();
            if (now >= next_packet_time) {
            client.send(data.data(), data.size());
            result.packets_sent++;
            next_packet_time += std::chrono::microseconds(1000000 / packets_per_second);
            }
            std::this_thread::sleep_for(std::chrono::microseconds(100));
            }
            // 获取统计信息
            auto stats = client.getStats();
            result.packets_received = stats.messages_received;
            result.packet_loss_rate = 1.0 - (static_cast<double>(result.packets_received) /
              result.packets_sent);
              // 计算吞吐量
              size_t total_bytes = result.packets_sent * packet_size;
              double test_duration = std::chrono::duration_cast<std::chrono::seconds>(
                std::chrono::steady_clock::now() - start_time).count();
                result.throughput_mbps = (total_bytes * 8.0) / (test_duration * 1000000.0);
                return result;
                }
                // 延迟测试
                TestResult latencyTest(size_t num_packets = 1000,
                size_t packet_size = 64,
                size_t interval_ms = 10) {
                ReliableUDP client;
                TestResult result{};
                std::vector<double> latencies;
                  if (!client.connect(server_ip_, server_port_ + 1)) {  // 使用不同端口
                  std::cerr << "Failed to connect to server" << std::endl;
                  return result;
                  }
                  std::cout << "Starting latency test..." << std::endl;
                  std::vector<uint8_t> data(packet_size);
                    std::generate(data.begin(), data.end(), []() { return rand() % 256; });
                    for (size_t i = 0; i < num_packets; ++i) {
                    auto send_time = std::chrono::steady_clock::now();
                    if (client.send(data.data(), data.size())) {
                    result.packets_sent++;
                    std::vector<uint8_t> response(packet_size);
                      size_t response_size = packet_size;
                      if (client.receive(response.data(), response_size)) {
                      auto recv_time = std::chrono::steady_clock::now();
                      double latency = std::chrono::duration_cast<std::chrono::microseconds>(
                        recv_time - send_time).count() / 1000.0;
                        latencies.push_back(latency);
                        result.packets_received++;
                        }
                        }
                        std::this_thread::sleep_for(std::chrono::milliseconds(interval_ms));
                        }
                        // 计算延迟统计
                        if (!latencies.empty()) {
                        std::sort(latencies.begin(), latencies.end());
                        result.min_latency_ms = latencies.front();
                        result.max_latency_ms = latencies.back();
                        result.avg_latency_ms = std::accumulate(latencies.begin(), latencies.end(), 0.0) /
                        latencies.size();
                        // 计算抖动 (标准差)
                        double variance = 0.0;
                        for (double latency : latencies) {
                        variance += std::pow(latency - result.avg_latency_ms, 2);
                        }
                        result.jitter_ms = std::sqrt(variance / latencies.size());
                        result.packet_loss_rate = 1.0 - (static_cast<double>(result.packets_received) /
                          result.packets_sent);
                          }
                          return result;
                          }
    // Full test: run the latency and bandwidth tests, then write a combined report
    void runFullTest(const std::string& output_file = "network_test_report.txt") {
        std::cout << "Running full network performance test..." << std::endl;
        // Latency test
        std::cout << "\n=== Latency Test ===" << std::endl;
        auto latency_result = latencyTest();
        printTestResult("Latency Test", latency_result);
        // Bandwidth test
        std::cout << "\n=== Bandwidth Test ===" << std::endl;
        auto bandwidth_result = bandwidthTest();
        printTestResult("Bandwidth Test", bandwidth_result);
        // Generate the test report
        generateTestReport("Full Network Test",
                           {latency_result, bandwidth_result},
                           output_file);
    }
private:
    void printTestResult(const std::string& test_name, const TestResult& result) {
        std::cout << test_name << " Results:" << std::endl;
        std::cout << "  Throughput: " << result.throughput_mbps << " Mbps" << std::endl;
        std::cout << "  Average Latency: " << result.avg_latency_ms << " ms" << std::endl;
        std::cout << "  Min/Max Latency: " << result.min_latency_ms << "/"
                  << result.max_latency_ms << " ms" << std::endl;
        std::cout << "  Jitter: " << result.jitter_ms << " ms" << std::endl;
        std::cout << "  Packet Loss: " << (result.packet_loss_rate * 100) << "%" << std::endl;
        std::cout << "  Packets: " << result.packets_received << "/" << result.packets_sent << std::endl;
    }

    void generateTestReport(const std::string& test_name,
                            const std::vector<TestResult>& results,
                            const std::string& filename) {
        std::ofstream file(filename);
        if (!file) {
            std::cerr << "Failed to create report file: " << filename << std::endl;
            return;
        }
        file << "Network Performance Test Report" << std::endl;
        file << "=================================" << std::endl;
        file << "Test Name: " << test_name << std::endl;
        file << "Server: " << server_ip_ << ":" << server_port_ << std::endl;
        // Timestamp written as seconds since the Unix epoch
        file << "Test Time: " << std::chrono::duration_cast<std::chrono::seconds>(
                    std::chrono::system_clock::now().time_since_epoch()).count() << std::endl;
        file << std::endl;
        for (size_t i = 0; i < results.size(); ++i) {
            file << "Test " << (i + 1) << ":" << std::endl;
            file << "  Throughput: " << results[i].throughput_mbps << " Mbps" << std::endl;
            file << "  Average Latency: " << results[i].avg_latency_ms << " ms" << std::endl;
            file << "  Jitter: " << results[i].jitter_ms << " ms" << std::endl;
            file << "  Packet Loss: " << (results[i].packet_loss_rate * 100) << "%" << std::endl;
            file << std::endl;
        }
        std::cout << "Test report saved to: " << filename << std::endl;
    }
};
}  // namespace AdvancedSocket
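// Usage example
// Requires a server on 127.0.0.1:8080 that replies to every packet (for example an
// echo server), since the latency test waits for a response after each send.
int main() {
    try {
        AdvancedSocket::NetworkPerformanceTester tester("127.0.0.1", 8080);
        // Run the full test suite
        tester.runFullTest("network_performance_report.txt");
    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }
    return 0;
}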
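The statistics step at the end of the latency test can be checked on its own, without a server. Below is a minimal standalone sketch, assuming latency samples in milliseconds; `LatencyStats` and `computeLatencyStats` are hypothetical names, not members of the tester class above. Like the tester, it reports jitter as the population standard deviation of the samples (RFC 3550 defines a different, smoothed inter-arrival jitter), and it reports packet loss even when no reply arrived.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical result type mirroring the latency fields of TestResult.
struct LatencyStats {
    double min_ms = 0.0;
    double max_ms = 0.0;
    double avg_ms = 0.0;
    double jitter_ms = 0.0;         // population standard deviation of latency
    double packet_loss_rate = 0.0;  // fraction in [0, 1]
};

// samples: one round-trip latency (ms) per reply received.
// packets_sent: number of requests that were sent successfully.
LatencyStats computeLatencyStats(std::vector<double> samples, std::size_t packets_sent) {
    assert(samples.size() <= packets_sent);  // cannot get more replies than requests
    LatencyStats s;
    if (packets_sent > 0) {
        // Computed even when samples is empty, so total loss reports 1.0.
        s.packet_loss_rate =
            1.0 - static_cast<double>(samples.size()) / static_cast<double>(packets_sent);
    }
    if (samples.empty()) {
        return s;
    }
    std::sort(samples.begin(), samples.end());
    s.min_ms = samples.front();
    s.max_ms = samples.back();
    s.avg_ms = std::accumulate(samples.begin(), samples.end(), 0.0) /
               static_cast<double>(samples.size());
    double variance = 0.0;
    for (double x : samples) {
        variance += (x - s.avg_ms) * (x - s.avg_ms);
    }
    s.jitter_ms = std::sqrt(variance / static_cast<double>(samples.size()));
    return s;
}
```

For example, samples of 3, 1 and 2 ms out of 4 packets sent give min 1 ms, max 3 ms, average 2 ms, jitter sqrt(2/3) ≈ 0.816 ms and a loss rate of 0.25.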

Summary

This week we studied the fundamentals of network programming in depth, covering:

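Core Concepts

  1. Network protocol stack concepts: how the TCP/IP stack works, including congestion control algorithms and queue management
  2. Modern C++ socket programming: applying modern C++ features such as RAII wrappers, exception safety, and move semantics
  3. High-performance network programming: epoll edge-triggered mode, zero-copy techniques, and event-driven architecture
  4. Network performance optimization: adaptive buffer management, performance monitoring, and collection of statistical metrics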
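Item 2 refers to RAII wrappers and move semantics for sockets. As a minimal sketch (the class name `UniqueFd` and the helpers `makeSocketPair` and `isOpen` are hypothetical, not the wrapper used earlier in this chapter), a descriptor can be owned by a move-only type that closes it exactly once in its destructor. `socketpair` is used only so the example runs without a network.

```cpp
#include <cassert>
#include <fcntl.h>
#include <stdexcept>
#include <sys/socket.h>
#include <unistd.h>
#include <utility>

// Move-only owner of a POSIX file descriptor: closes it exactly once.
class UniqueFd {
public:
    UniqueFd() noexcept = default;
    explicit UniqueFd(int fd) noexcept : fd_(fd) {}
    ~UniqueFd() { reset(); }

    UniqueFd(const UniqueFd&) = delete;             // a copy would close twice
    UniqueFd& operator=(const UniqueFd&) = delete;

    UniqueFd(UniqueFd&& other) noexcept : fd_(other.release()) {}
    UniqueFd& operator=(UniqueFd&& other) noexcept {
        if (this != &other) {
            reset(other.release());
        }
        return *this;
    }

    int get() const noexcept { return fd_; }
    bool valid() const noexcept { return fd_ >= 0; }

    // Give up ownership without closing; the caller becomes responsible.
    int release() noexcept { return std::exchange(fd_, -1); }

    // Close the current descriptor (if any) and take ownership of new_fd.
    void reset(int new_fd = -1) noexcept {
        if (fd_ >= 0) {
            ::close(fd_);
        }
        fd_ = new_fd;
    }

private:
    int fd_ = -1;
};

// Creates a connected pair of UNIX-domain stream sockets, each owned by a UniqueFd.
std::pair<UniqueFd, UniqueFd> makeSocketPair() {
    int fds[2];
    if (::socketpair(AF_UNIX, SOCK_STREAM, 0, fds) != 0) {
        throw std::runtime_error("socketpair failed");
    }
    return {UniqueFd(fds[0]), UniqueFd(fds[1])};
}

// True if fd currently refers to an open descriptor (used to observe closing).
bool isOpen(int fd) noexcept { return ::fcntl(fd, F_GETFD) != -1; }
```

Because copying is deleted, ownership can only be transferred with `std::move`, and a moved-from object holds -1, so its destructor does nothing.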

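Practical Skills

  1. TCP/IP protocol implementation: an in-depth understanding of the TCP state machine, congestion control, and flow control
  2. Concurrent network programming: concurrency models such as multithreading, event-driven designs, and asynchronous I/O
  3. Network performance tuning: reducing latency, maximizing throughput, and improving resource utilization
  4. Error handling and fault tolerance: connection management, timeouts and retries, and failure recovery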
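Item 4 mentions timeouts and retries. One common approach, assumed here rather than taken from the chapter's code, is retrying with capped exponential backoff. The names `RetryPolicy` and `retryWithBackoff` are hypothetical, and the sleep function is passed in as a parameter so the schedule can be checked without actually waiting.

```cpp
#include <algorithm>
#include <cassert>
#include <chrono>
#include <functional>
#include <thread>

// Hypothetical retry policy: the delay doubles after each failure, capped at max_delay.
struct RetryPolicy {
    int max_attempts = 5;
    std::chrono::milliseconds initial_delay{100};
    std::chrono::milliseconds max_delay{2000};
};

// Calls op() until it returns true or max_attempts is reached.
// Returns the number of attempts used on success, or 0 if every attempt failed.
int retryWithBackoff(const std::function<bool()>& op, const RetryPolicy& policy,
                     const std::function<void(std::chrono::milliseconds)>& sleep_fn =
                         [](std::chrono::milliseconds d) { std::this_thread::sleep_for(d); }) {
    assert(policy.max_attempts > 0);
    auto delay = policy.initial_delay;
    for (int attempt = 1; attempt <= policy.max_attempts; ++attempt) {
        if (op()) {
            return attempt;
        }
        if (attempt < policy.max_attempts) {
            sleep_fn(delay);                                 // wait before the next try
            delay = std::min(delay * 2, policy.max_delay);   // exponential growth, capped
        }
    }
    return 0;
}
```

With an initial delay of 100 ms and a cap of 500 ms, the waits between attempts are 100, 200, 400, 500, 500 ms. The cap keeps a long outage from producing unbounded waits; real clients often also add random jitter to the delay so that many clients do not retry in lockstep.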

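Theoretical Foundations

  1. Queueing theory basics: the fundamental principles of how packets are queued and processed
  2. Probability analysis: what packet loss rates and retransmission probabilities mean in practice
  3. Performance metrics: measuring latency, jitter, and throughput in practice
  4. Optimization algorithms: adaptive tuning and feedback control mechanisms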
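To make items 1 and 2 concrete, here is a small worked sketch under textbook assumptions that the chapter itself does not state. It assumes an M/M/1 queue (Poisson arrivals at rate λ, exponential service at rate μ, one server) and independent loss with probability p per transmission. Under those assumptions ρ = λ/μ, the mean number of packets in the system is L = ρ/(1 - ρ), the mean time in the system is W = 1/(μ - λ) (Little's law L = λW links the two), and the expected number of transmissions until success is 1/(1 - p). The function names are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <stdexcept>

// M/M/1 steady-state metrics (textbook model, assumed for illustration).
struct MM1Metrics {
    double utilization;      // rho = lambda / mu
    double avg_in_system;    // L = rho / (1 - rho)
    double avg_time_in_sys;  // W = 1 / (mu - lambda), in the time unit of the rates
};

MM1Metrics mm1(double lambda, double mu) {
    if (lambda <= 0.0 || mu <= 0.0 || lambda >= mu) {
        throw std::invalid_argument("M/M/1 is stable only for 0 < lambda < mu");
    }
    const double rho = lambda / mu;
    return {rho, rho / (1.0 - rho), 1.0 / (mu - lambda)};
}

// With independent loss probability p per transmission, the number of
// transmissions until success is geometric with mean 1 / (1 - p).
double expectedTransmissions(double p) {
    if (p < 0.0 || p >= 1.0) {
        throw std::invalid_argument("loss probability must be in [0, 1)");
    }
    return 1.0 / (1.0 - p);
}

// Probability that a packet is still undelivered after k transmissions: p^k.
double failureAfter(int k, double p) { return std::pow(p, k); }
```

For example, a link that serves 1000 packets/s and receives 800 packets/s runs at ρ = 0.8, holds on average L = 4 packets and delays each one W = 5 ms. Since L grows without bound as ρ approaches 1, queues are normally kept well below saturation. At 10% loss a packet needs on average about 1.11 transmissions, and the chance that three attempts in a row are lost is 0.001.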

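Best Practices

  1. Modern C++ features: smart pointers, RAII, move semantics, and template metaprogramming
  2. Design patterns: the Reactor, Proactor, and Observer patterns
  3. Performance optimization: zero-copy, memory pools, and lock-free data structures
  4. Scalable architecture: modular design, plugin mechanisms, and configuration-driven design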
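Item 3 lists lock-free data structures. A common minimal example is a single-producer/single-consumer (SPSC) ring buffer built on `std::atomic`. The sketch below is an assumed illustration (the name `SpscRing` is hypothetical); it is only correct with exactly one thread calling `push` and one thread calling `pop`, and it requires `T` to be default-constructible and copy-assignable.

```cpp
#include <array>
#include <atomic>
#include <cassert>
#include <cstddef>
#include <optional>

// Lock-free single-producer/single-consumer ring buffer.
// One slot is kept empty to tell "full" from "empty", so it holds N - 1 items.
template <typename T, std::size_t N>
class SpscRing {
    static_assert(N >= 2, "need at least two slots");

public:
    // Producer thread only. Returns false if the buffer is full.
    bool push(const T& value) {
        const std::size_t tail = tail_.load(std::memory_order_relaxed);  // own index
        const std::size_t next = (tail + 1) % N;
        if (next == head_.load(std::memory_order_acquire)) {
            return false;  // full: the consumer has not freed a slot yet
        }
        slots_[tail] = value;
        tail_.store(next, std::memory_order_release);  // publish the written slot
        return true;
    }

    // Consumer thread only. Returns std::nullopt if the buffer is empty.
    std::optional<T> pop() {
        const std::size_t head = head_.load(std::memory_order_relaxed);  // own index
        if (head == tail_.load(std::memory_order_acquire)) {
            return std::nullopt;  // empty
        }
        T value = slots_[head];
        head_.store((head + 1) % N, std::memory_order_release);  // hand the slot back
        return value;
    }

private:
    std::array<T, N> slots_{};
    std::atomic<std::size_t> head_{0};  // next slot to read (written by consumer only)
    std::atomic<std::size_t> tail_{0};  // next slot to write (written by producer only)
};
```

The release store on one side pairs with the acquire load on the other, so the consumer never reads a slot before its contents are written and the producer never overwrites a slot that is still being read. With more than one producer or consumer this reasoning no longer holds, and a different design (or a mutex) is needed.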

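These concepts and skills lay the groundwork for later work in advanced network programming and systems-level development. By pairing the theory with working code, we gain a real understanding of how network programming works and can build network applications that are both fast and reliable.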

posted @ 2025-12-18 08:46  yangykaifa