[AI生成] github smux源码分析

https://github.com/xtaci/smux

特点

smux = 在一条 TCP 连接里,同时跑多条虚拟 “数据流”
普通 TCP:一条连接 = 一条路,只能发一股数据
smux 多路复用:一条连接 = 一条高速公路,里面划分多条车道,同时跑多组数据互不干扰
作用:省连接、省资源、更快、更稳定。

把 smux 想象成 社区快递中转站:
TCP 连接 = 一条大马路
smux 流(Stream)= 一条快递车道,流 ID(sid)= 车道编号
Frame 帧(最小数据单位) = 快递包裹

工作流程

image

其中,UPD是窗口更新。

image

你要发数据 → 拆成一个个小包裹(Frame)
打上编号(流 ID)→ 塞进同一条 TCP 马路
对方收到 → 按编号分拣 → 还原成完整数据

发送端
调用 OpenStream() → 发 SYN 指令
写数据 → 拆成帧 → 打上流 ID
交给 TCP 发送

接收端
收到 TCP 数据 → 解帧
读取流 ID → 分到对应缓冲区
应用层读取 → 回复 UPD 窗口

关闭
调用 Close() → 发 FIN 指令
双方确认 → 释放流

代码解析

一条 TCP 创建一个 Session
OpenStream 创建无数虚拟 Stream(不占新连接)
Write 发数据 → 窗口耗尽 → 阻塞等待 UPD
Read 读数据 → 消费缓冲区 → 自动发 UPD
收到 UPD → 恢复窗口 → 解除 Write 阻塞
循环实现高速、稳定、多路复用传输

1. 结构体定义

package smux

import (
	"encoding/binary"
	"io"
	"sync"
)

// 1. 帧指令类型(smux的所有命令)
const (
	cmdSYN = 0x01 // 新建流(建立虚拟连接)
	cmdFIN = 0x02 // 关闭流
	cmdPSH = 0x03 // 传输真实数据
	cmdNOP = 0x04 // 空指令(心跳)
	cmdUPD = 0x05 // 🔥 UPD:窗口更新(流量控制核心)
)

// 2. 帧结构:smux传输的最小数据单元
type Frame struct {
	ver  uint8   // 协议版本
	cmd  uint8   // 指令类型(SYN/FIN/PSH/UPD)
	sid  uint32  // 流ID(区分不同虚拟流)
	data []byte  // 数据内容(PSH存业务数据,UPD存窗口信息)
}

// 3. UPD指令固定头结构(8字节)
type updHeader struct {
	Cmd      uint8  // 固定为cmdUPD
	_        uint8  // 保留位
	Consumed uint32 // 接收端【已消费】数据量
	Window   uint32 // 接收端【剩余可用窗口】
}

// 4. Session:会话(管理一条TCP,承载所有虚拟流)
type Session struct {
	mux     sync.Mutex
	conn    io.ReadWriteCloser // 底层TCP连接
	streams map[uint32]*Stream // 所有虚拟流集合
	nextID  uint32             // 下一个流的ID
}

// 5. Stream:虚拟流(用户真正读写的对象)
type Stream struct {
	id         uint32        // 流唯一ID
	session    *Session      // 所属会话
	buffer     []byte        // 接收缓冲区
	window     uint32        // 接收窗口(我能收多少)
	sendWindow uint32        // 发送窗口(对方能收多少)
	consumed   uint32        // 已消费数据(触发UPD)
	writeCond  *sync.Cond    // 发送阻塞唤醒器
}

2. 核心流程

流程 1:创建会话(基于 TCP 建立 smux)

// NewSession 基于一条TCP连接创建smux会话
// 作用:把普通TCP包装成多路复用连接
func NewSession(conn io.ReadWriteCloser) *Session {
	s := &Session{
		conn:    conn,
		streams: make(map[uint32]*Stream),
		nextID:  1,
	}
	// 启动后台协程:循环读取TCP数据,解析帧
	go s.recvLoop()
	return s
}

// recvLoop 接收循环(后台一直运行)
// 作用:从TCP读数据 → 解析成smux帧 → 分发给对应流
func (s *Session) recvLoop() {
	for {
		// 1. 读取一个帧
		frame, err := s.readFrame()
		if err != nil {
			return
		}

		// 2. 根据帧指令类型处理
		switch frame.cmd {
			case cmdSYN: // 新建流
				s.handleSYN(frame)
			case cmdPSH: // 收到数据
				s.handlePSH(frame)
			case cmdUPD: // 🔥 收到窗口更新(核心)
				s.handleUPD(frame)
			case cmdFIN: // 关闭流
				s.handleFIN(frame)
		}
	}
}

流程 2:打开虚拟流(不新建 TCP!)

// OpenStream 创建一条虚拟流
// 重点:不会新建TCP,只在当前会话里创建逻辑流
func (s *Session) OpenStream() *Stream {
	s.mux.Lock()
	defer s.mux.Unlock()

	// 1. 生成新流ID
	id := s.nextID
	s.nextID++

	// 2. 创建流对象
	stream := &Stream{
		id:         id,
		session:    s,
		window:     256 * 1024, // 默认接收窗口 256KB
		sendWindow: 256 * 1024, // 默认发送窗口
		writeCond:  sync.NewCond(&sync.Mutex{}),
	}

	// 3. 保存到会话
	s.streams[id] = stream

	// 4. 发送SYN指令,通知对端创建对应流
	s.writeFrame(Frame{cmd: cmdSYN, sid: id})
	return stream
}

流程 3:发送数据(Write)+ 窗口阻塞机制

// Write 流发送数据(用户调用)
// 机制:发送窗口为0时阻塞,直到收到UPD才继续
func (s *Stream) Write(b []byte) (int, error) {
	total := 0
	n := len(b)

	for total < n {
		// ===== 🔥 关键:发送窗口为0,阻塞等待 UPD =====
		s.writeCond.L.Lock()
		for s.sendWindow == 0 {
			// 等待被唤醒(收到UPD后唤醒)
			s.writeCond.Wait()
		}
		s.writeCond.L.Unlock()

		// 可发送长度 = 取窗口大小和剩余数据较小值
		writeLen := min(int(s.sendWindow), n-total)
		if writeLen <= 0 {
			break
		}

		// 封装PSH数据帧发送
		s.session.writeFrame(Frame{
			cmd:  cmdPSH,
			sid:  s.id,
			data: b[total : total+writeLen],
		})

		// 减少发送窗口(每发一点,窗口变小)
		s.sendWindow -= uint32(writeLen)
		total += writeLen
	}

	return total, nil
}

流程 4:接收数据(Read)+ 自动发送 UPD

// Read 流读取数据(用户调用)
// 核心:读取后会自动发送 UPD 通知发送端
func (s *Stream) Read(buf []byte) (int, error) {
	// 1. 从缓冲区读取数据
	n := copy(buf, s.buffer)
	if n == 0 {
		return 0, io.EOF
	}

	// 2. 清空已读数据
	s.buffer = s.buffer[n:]

	// 3. 累加已消费字节数
	s.consumed += uint32(n)

	// ===== 🔥 关键:消费达到阈值,发送 UPD =====
	if s.consumed > s.window/2 {
		// 发送窗口更新
		s.sendUPD()
	}

	return n, nil
}

// sendUPD 发送 UPD 指令(流量控制核心)
func (s *Stream) sendUPD() {
	// 构造UPD帧(固定8字节)
	data := make([]byte, 8)
	// 写入:已消费数据量
	binary.LittleEndian.PutUint32(data[0:4], s.consumed)
	// 写入:当前可用接收窗口
	binary.LittleEndian.PutUint32(data[4:8], s.window)

	// 发送UPD帧
	s.session.writeFrame(Frame{
		cmd:  cmdUPD,
		sid:  s.id,
		data: data,
	})

	// 重置已消费计数
	s.consumed = 0
}

流程 5:收到 UPD → 解除发送阻塞

// handleUPD 处理收到的UPD指令
func (s *Session) handleUPD(f Frame) {
	// 1. 找到对应流
	stream, ok := s.streams[f.sid]
	if !ok {
		return
	}

	// 2. 解析UPD数据:已消费 + 窗口大小
	consumed := binary.LittleEndian.Uint32(f.data[0:4])
	window := binary.LittleEndian.Uint32(f.data[4:8])

	// 3. 🔥 恢复发送窗口(允许继续发送)
	stream.sendWindow = window

	// 4. 唤醒阻塞的Write()
	stream.writeCond.Signal()
}

流程 6:关闭流

// Close 关闭虚拟流
func (s *Stream) Close() {
	// 发送FIN指令
	s.session.writeFrame(Frame{cmd: cmdFIN, sid: s.id})
	// 从会话删除
	s.session.mux.Lock()
	delete(s.session.streams, s.id)
	s.session.mux.Unlock()
}

 

posted on 2026-04-06 14:46  王景迁  阅读(4)  评论(0)    收藏  举报

导航