[AI生成] 基于smux实现在单条TCP连接上收发http

工作流程

Client -->|1. 建立 TCP 连接| Server
Server -->|2. 连接建立成功| Client
    
Client -->|3. 创建 SMUX 会话| Server

# 这里只调用了GET /
Client -->|4. 请求 1: GET /| Server
Client -->|5. 请求 2: GET /status| Server
Client -->|6. 请求 3: GET /echo| Server
    
Server -->|7. 响应 1: ok| Client
Server -->|8. 响应 2: {"status":"ok"}| Client
Server -->|9. 响应 3: 回显内容| Client
    
Client -->|10. 继续发送请求| Server

核心特性

单条 TCP 长连接: 所有 HTTP 请求共享同一条底层 TCP 连接,减少连接开销
SMUX 多路复用: 通过 SMUX 协议在单连接上多路复用多个逻辑流 (Stream)
独立流: 每个 HTTP 请求/响应使用独立的流
TCP Keepalive: 保持连接稳定 (30秒探测间隔)
正确的 HTTP 响应: Content-Length 正确设置,避免 EOF 错误
连接复用: Client 自动检测并复用现有连接

server.go

package main

import (
	"bufio"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"time"

	"github.com/xtaci/smux"
)

// Server SMUX HTTP 服务器,通过单条 TCP 长连接 + SMUX 多路复用来接收所有 HTTP 请求
type Server struct {
	addr     string         // 监听地址
	mux      *http.ServeMux // HTTP 请求多路复用器
	listener net.Listener   // TCP 监听器
}

// NewServer 创建一个新的 SMUX HTTP 服务器
func NewServer(addr string) *Server {
	return &Server{
		addr: addr,
		mux:  http.NewServeMux(),
	}
}

// Handle 注册 HTTP 处理函数到指定路径
func (s *Server) Handle(path string, handler http.Handler) {
	s.mux.Handle(path, handler)
}

// HandleFunc 注册 HTTP 处理函数到指定路径
func (s *Server) HandleFunc(path string, handler func(http.ResponseWriter, *http.Request)) {
	s.mux.HandleFunc(path, handler)
}

// setTCPKeepalive 为 TCP 连接设置 keepalive 选项,保持长连接稳定
func setTCPKeepalive(conn net.Conn) error {
	tcpConn, ok := conn.(*net.TCPConn)
	if !ok {
		return nil
	}

	// 启用 TCP keepalive
	if err := tcpConn.SetKeepAlive(true); err != nil {
		return err
	}

	// 设置 keepalive 探测间隔为 30 秒
	if err := tcpConn.SetKeepAlivePeriod(30 * time.Second); err != nil {
		return err
	}

	return nil
}

// ListenAndServe 启动 SMUX HTTP 服务器
func (s *Server) ListenAndServe() error {
	// 监听 TCP 连接
	listener, err := net.Listen("tcp", s.addr)
	if err != nil {
		return err
	}
	defer listener.Close()

	log.Printf("SMUX HTTP server listening on %s with TCP keepalive enabled", s.addr)

	// 持续接受新的连接
	for {
		conn, err := listener.Accept()
		if err != nil {
			log.Printf("Accept error: %v", err)
			continue
		}

		// 设置 TCP keepalive
		if err := setTCPKeepalive(conn); err != nil {
			log.Printf("Failed to set TCP keepalive for %s: %v", conn.RemoteAddr(), err)
		}

		// 在新的 goroutine 中处理连接
		go s.handleConnection(conn)
	}
}

// handleConnection 处理单个 TCP 连接,为其创建 SMUX 会话
func (s *Server) handleConnection(conn net.Conn) {
	defer conn.Close()

	log.Printf("New connection from %s", conn.RemoteAddr())

	// 创建 SMUX 服务器端会话
	session, err := smux.Server(conn, nil)
	if err != nil {
		log.Printf("Create smux session error: %v", err)
		return
	}
	defer session.Close()

	// 持续接受新的 SMUX 流
	for {
		stream, err := session.AcceptStream()
		if err != nil {
			log.Printf("Accept stream error: %v", err)
			return
		}

		// 在新的 goroutine 中处理每个流
		go s.handleStream(stream)
	}
}

// handleStream 处理单个 SMUX 流,读取 HTTP 请求并返回响应
func (s *Server) handleStream(stream net.Conn) {
	defer stream.Close()

	// 从流中读取 HTTP 请求
	req, err := http.ReadRequest(bufio.NewReader(stream))
	if err != nil {
		if err != io.EOF {
			log.Printf("Read request error: %v", err)
			// 如果读取请求出错,返回错误响应
			w := &responseWriter{
				stream: stream,
				header: make(http.Header),
			}
			w.WriteHeader(http.StatusBadRequest)
			w.Write([]byte("err"))
			w.flush()
		}
		return
	}

	// 创建自定义的 responseWriter
	w := &responseWriter{
		stream: stream,
		header: make(http.Header),
	}

	// 调用 HTTP 处理函数
	s.mux.ServeHTTP(w, req)

	// 如果处理函数没有写入头,默认写入 200 OK
	if !w.wroteHeader {
		w.WriteHeader(http.StatusOK)
	}

	// 刷新响应到流中
	if err := w.flush(); err != nil {
		log.Printf("Flush response error: %v", err)
	}
}

// responseWriter 自定义 HTTP 响应写入器
// 缓存响应体,确保可以正确设置 Content-Length 头
type responseWriter struct {
	stream      net.Conn    // SMUX 流
	header      http.Header // HTTP 响应头
	wroteHeader bool        // 是否已写入响应头
	statusCode  int         // HTTP 状态码
	body        []byte      // 缓存的响应体
}

// Header 返回响应头
func (w *responseWriter) Header() http.Header {
	return w.header
}

// Write 写入响应体(先缓存,不直接发送)
func (w *responseWriter) Write(b []byte) (int, error) {
	if !w.wroteHeader {
		w.WriteHeader(http.StatusOK)
	}
	w.body = append(w.body, b...)
	return len(b), nil
}

// WriteHeader 写入响应头(只记录,不直接发送)
func (w *responseWriter) WriteHeader(statusCode int) {
	if w.wroteHeader {
		return
	}
	w.wroteHeader = true
	w.statusCode = statusCode
}

// flush 将完整的 HTTP 响应发送到流中
// 包括:状态行、响应头、空行、响应体
func (w *responseWriter) flush() error {
	if !w.wroteHeader {
		w.WriteHeader(http.StatusOK)
	}

	// 获取状态文本
	statusText := http.StatusText(w.statusCode)
	if statusText == "" {
		statusText = "Unknown"
	}

	// 写入状态行
	headerLine := fmt.Sprintf("HTTP/1.1 %d %s\r\n", w.statusCode, statusText)
	if _, err := w.stream.Write([]byte(headerLine)); err != nil {
		return err
	}

	// 设置 Content-Length 头
	w.header.Set("Content-Length", fmt.Sprintf("%d", len(w.body)))

	// 写入所有响应头
	for key, values := range w.header {
		for _, value := range values {
			if _, err := w.stream.Write([]byte(fmt.Sprintf("%s: %s\r\n", key, value))); err != nil {
				return err
			}
		}
	}

	// 写入空行分隔头部和正文
	if _, err := w.stream.Write([]byte("\r\n")); err != nil {
		return err
	}

	// 写入响应体
	if len(w.body) > 0 {
		if _, err := w.stream.Write(w.body); err != nil {
			return err
		}
	}

	return nil
}

// main 测试服务器
func main() {
	// 创建服务器,监听 8080 端口
	server := NewServer(":8080")

	// 注册测试路由
	// 根路径:返回 ok
	server.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("ok"))
	})

	// /echo:返回请求体内容
	server.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) {
		body, err := io.ReadAll(r.Body)
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			w.Write([]byte("err"))
			return
		}
		w.Write(body)
	})

	// /status:返回 JSON 状态
	server.HandleFunc("/status", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.Write([]byte(`{"status": "ok"}`))
	})

	// /error:测试错误响应
	server.HandleFunc("/error", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte("err"))
	})

	// 启动服务器
	if err := server.ListenAndServe(); err != nil {
		log.Fatalf("Server error: %v", err)
	}
}

client.go

package main

import (
	"bufio"
	"io"
	"log"
	"net"
	"net/http"
	"sync"
	"time"

	"github.com/xtaci/smux"
)

// Client SMUX HTTP 客户端,通过单条 TCP 长连接 + SMUX 多路复用来传输所有 HTTP 请求
type Client struct {
	addr    string        // 服务器地址
	conn    net.Conn      // 底层 TCP 连接
	session *smux.Session // SMUX 会话
	mu      sync.Mutex    // 互斥锁,保护连接和会话的并发访问
}

// NewClient 创建一个新的 SMUX 客户端
func NewClient(addr string) *Client {
	return &Client{
		addr: addr,
	}
}

// setTCPKeepalive 为 TCP 连接设置 keepalive 选项,保持长连接稳定
func setTCPKeepalive(conn net.Conn) error {
	tcpConn, ok := conn.(*net.TCPConn)
	if !ok {
		return nil
	}

	// 启用 TCP keepalive
	if err := tcpConn.SetKeepAlive(true); err != nil {
		return err
	}

	// 设置 keepalive 探测间隔为 30 秒
	if err := tcpConn.SetKeepAlivePeriod(30 * time.Second); err != nil {
		return err
	}

	return nil
}

// Connect 连接到 SMUX 服务器
// 如果已有有效的连接,则直接复用,避免重复建立连接
func (c *Client) Connect() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	// 检查是否已有有效的会话
	if c.session != nil && !c.session.IsClosed() {
		return nil
	}

	// 建立新的 TCP 连接
	conn, err := net.Dial("tcp", c.addr)
	if err != nil {
		return err
	}

	// 设置 TCP keepalive
	if err := setTCPKeepalive(conn); err != nil {
		log.Printf("Failed to set TCP keepalive: %v", err)
	}

	// 创建 SMUX 客户端会话
	session, err := smux.Client(conn, nil)
	if err != nil {
		conn.Close()
		return err
	}

	// 保存连接和会话
	c.conn = conn
	c.session = session
	log.Printf("Connected to SMUX server at %s with TCP keepalive enabled", c.addr)

	return nil
}

// streamReadCloser 包装 SMUX 流和 HTTP 响应体
// 确保在关闭响应体时同时关闭底层的 SMUX 流
type streamReadCloser struct {
	stream net.Conn      // SMUX 流
	rc     io.ReadCloser // 原始的 HTTP 响应体
}

// Read 从响应体读取数据
func (s *streamReadCloser) Read(p []byte) (int, error) {
	return s.rc.Read(p)
}

// Close 关闭响应体和 SMUX 流
func (s *streamReadCloser) Close() error {
	err1 := s.rc.Close()
	err2 := s.stream.Close()
	if err1 != nil {
		return err1
	}
	return err2
}

// Do 发送 HTTP 请求并返回响应
// 通过 SMUX 在单条 TCP 连接上多路复用多个请求
func (c *Client) Do(req *http.Request) (*http.Response, error) {
	// 确保已连接到服务器
	if err := c.Connect(); err != nil {
		return nil, err
	}

	// 打开一个新的 SMUX 流
	stream, err := c.session.OpenStream()
	if err != nil {
		return nil, err
	}

	// 将 HTTP 请求写入流
	if err := req.Write(stream); err != nil {
		stream.Close()
		return nil, err
	}

	// 从流中读取 HTTP 响应
	resp, err := http.ReadResponse(bufio.NewReader(stream), req)
	if err != nil {
		stream.Close()
		return nil, err
	}

	// 包装响应体,确保流正确关闭
	resp.Body = &streamReadCloser{
		stream: stream,
		rc:     resp.Body,
	}

	return resp, nil
}

// Get 发送 GET 请求
func (c *Client) Get(url string) (*http.Response, error) {
	req, err := http.NewRequest("GET", url, nil)
	if err != nil {
		return nil, err
	}
	return c.Do(req)
}

// Close 关闭客户端连接和会话
func (c *Client) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.session != nil {
		c.session.Close()
	}
	if c.conn != nil {
		c.conn.Close()
	}

	c.session = nil
	c.conn = nil
	return nil
}

// main 测试函数:持续不断地向服务器发送 GET 请求
func main() {
	// 创建客户端
	client := NewClient("localhost:8080")
	defer client.Close()

	log.Println("Starting continuous interaction with SMUX server...")
	log.Println("Press Ctrl+C to stop")

	counter := 0
	for {
		counter++

		// 发送 GET 请求
		resp, err := client.Get("http://localhost/")
		if err != nil {
			log.Printf("[%d] GET / error: %v", counter, err)
			time.Sleep(1 * time.Second)
			continue
		}

		// 读取响应体
		body, err := io.ReadAll(resp.Body)
		resp.Body.Close()
		if err != nil {
			log.Printf("[%d] Read response error: %v", counter, err)
			time.Sleep(1 * time.Second)
			continue
		}

		// 打印响应
		log.Printf("[%d] GET / response: %s", counter, body)

		// 等待 500 毫秒后再次请求
		time.Sleep(500 * time.Millisecond)
	}
}

运行测试

$ go run server.go
2026/04/06 10:22:02 SMUX HTTP server listening on :8080 with TCP keepalive enabled
2026/04/06 10:22:10 New connection from [::1]:51219
2026/04/06 10:22:14 Accept stream error: read tcp [::1]:8080->[::1]:51219: wsarecv: An existing connection was forcibly closed by the remote host.

$ go run client.go
2026/04/06 10:22:10 Starting continuous interaction with SMUX server...
2026/04/06 10:22:10 Press Ctrl+C to stop
2026/04/06 10:22:10 Connected to SMUX server at localhost:8080 with TCP keepalive enabled
2026/04/06 10:22:10 [1] GET / response: ok
2026/04/06 10:22:11 [2] GET / response: ok
2026/04/06 10:22:11 [3] GET / response: ok
2026/04/06 10:22:12 [4] GET / response: ok
2026/04/06 10:22:12 [5] GET / response: ok
2026/04/06 10:22:13 [6] GET / response: ok
2026/04/06 10:22:13 [7] GET / response: ok
2026/04/06 10:22:14 [8] GET / response: ok
2026/04/06 10:22:14 [9] GET / response: ok
exit status 0xc000013a

 

posted on 2026-04-06 10:26  王景迁  阅读(2)  评论(0)    收藏  举报

导航