1 package info_websocket
2
3 import (
4 "crypto/sha1"
5 "encoding/base64"
6 "errors"
7 "io"
8 "log"
9 "net"
10 "strings"
11 )
12
13 func main() {
14 ln, err := net.Listen("tcp", ":8000")//监听端口
15 if err != nil {
16 log.Panic(err)
17 }
18 for {
19 log.Println("wss")
20 conn, err := ln.Accept()//等待客户的连接
21 if err != nil {
22 log.Println("Accept err:", err)
23 }
24 for {
25 handleConnection(conn)
26 }
27 }
28 }
29
30 func handleConnection(conn net.Conn) {
31 content := make([]byte, 1024)
32 _, err := conn.Read(content)
33 log.Println(string(content))
34 if err != nil {
35 log.Println(err)
36 }
37 isHttp := false
38 // 先暂时这么判断
39 if string(content[0:3]) == "GET" {
40 isHttp = true
41 }
42 log.Println("isHttp:", isHttp)
43 if isHttp {
44 headers := parseHandshake(string(content))
45 log.Println("headers", headers)
46 secWebsocketKey := headers["Sec-WebSocket-Key"]
47 // NOTE:这里省略其他的验证
48 guid := "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
49 // 计算Sec-WebSocket-Accept
50 h := sha1.New()
51 log.Println("accept raw:", secWebsocketKey+guid)
52 io.WriteString(h, secWebsocketKey+guid)
53 accept := make([]byte, 28)
54 base64.StdEncoding.Encode(accept, h.Sum(nil))
55 log.Println(string(accept))
56 response := "HTTP/1.1 101 Switching Protocols\r\n"
57 response = response + "Sec-WebSocket-Accept: " + string(accept) + "\r\n"
58 response = response + "Connection: Upgrade\r\n"
59 response = response + "Upgrade: websocket\r\n\r\n"
60 log.Println("response:", response)
61 if lenth, err := conn.Write([]byte(response)); err != nil {
62 log.Println(err)
63 } else {
64 log.Println("send len:", lenth)
65 }
66 wssocket := NewWsSocket(conn)
67 for {
68 data, err := wssocket.ReadIframe()
69 if err != nil {
70 log.Println("readIframe err:", err)
71 }
72 log.Println("read data:", string(data))
73 err = wssocket.SendIframe([]byte("good"))
74 if err != nil {
75 log.Println("sendIframe err:", err)
76 }
77 log.Println("send data")
78 }
79 } else {
80 log.Println(string(content))
81 // 直接读取
82 }
83 }
84
// WsSocket is a minimal server-side WebSocket endpoint on top of a
// connection whose opening handshake has already been completed.
type WsSocket struct {
	// MaskingKey is the 4-byte masking key of the most recently read masked
	// frame, or nil if no masked frame has been read yet. It is recorded for
	// inspection only and is never applied to outgoing frames.
	MaskingKey []byte
	// Conn is the connection frames are read from and written to.
	Conn net.Conn
}

// NewWsSocket returns a WsSocket that reads from and writes to conn.
func NewWsSocket(conn net.Conn) *WsSocket {
	return &WsSocket{Conn: conn}
}

// SendIframe writes data as a single final, unmasked text frame (first
// byte 0x81).
//
// Only payloads that fit the 7-bit length field (at most 125 bytes) are
// supported; longer payloads are rejected with an error before anything is
// written. Any error from the underlying Write is returned.
func (ws *WsSocket) SendIframe(data []byte) error {
	if len(data) > 125 {
		return errors.New("send iframe data error")
	}
	// Frames sent by a server must not be masked, so the mask bit stays
	// clear and the payload goes out verbatim. The frame is assembled first
	// and written with one call so a failure cannot leave half a header
	// behind unnoticed.
	frame := make([]byte, 0, 2+len(data))
	frame = append(frame, 0x81, byte(len(data)))
	frame = append(frame, data...)
	_, err := ws.Conn.Write(frame)
	return err
}

// ReadIframe reads one complete message: it reads frames until one has the
// FIN bit set and returns the concatenated, unmasked payloads of all of
// them.
//
// Opcodes are logged but not interpreted, so control frames are returned
// like data frames. Messages larger than 1 MiB are rejected. On any error
// the returned data is nil; a connection closed on a frame boundary yields
// io.EOF.
func (ws *WsSocket) ReadIframe() (data []byte, err error) {
	// maxMessage bounds how much memory a peer can make this call allocate.
	const maxMessage = 1 << 20

	var message []byte
	for {
		payload, fin, err := ws.readFrame(maxMessage - len(message))
		if err != nil {
			return nil, err
		}
		message = append(message, payload...)
		if fin {
			return message, nil
		}
	}
}

// readFrame reads a single frame and returns its unmasked payload and
// whether the FIN bit was set. Payloads longer than limit bytes are
// rejected before any memory is allocated for them.
func (ws *WsSocket) readFrame(limit int) (payload []byte, fin bool, err error) {
	// First byte: FIN + RSV1-3 + OPCODE. Second byte: MASK + 7-bit length.
	var header [2]byte
	if _, err := io.ReadFull(ws.Conn, header[:]); err != nil {
		return nil, false, err
	}
	fin = header[0]&0x80 != 0
	rsv1 := header[0] >> 6 & 1
	rsv2 := header[0] >> 5 & 1
	rsv3 := header[0] >> 4 & 1
	opcode := header[0] & 0x0F
	log.Println(rsv1, rsv2, rsv3, opcode)

	masked := header[1]&0x80 != 0
	length := uint64(header[1] & 0x7F)

	// 126 and 127 are escape values: the real length follows as a 16-bit
	// or 64-bit big-endian integer.
	extended := 0
	switch length {
	case 126:
		extended = 2
	case 127:
		extended = 8
	}
	if extended > 0 {
		ext := make([]byte, extended)
		if _, err := io.ReadFull(ws.Conn, ext); err != nil {
			return nil, false, err
		}
		length = 0
		for _, b := range ext {
			length = length<<8 | uint64(b)
		}
	}
	if length > uint64(limit) {
		return nil, false, errors.New("read iframe payload too large")
	}

	var key []byte
	if masked {
		key = make([]byte, 4)
		if _, err := io.ReadFull(ws.Conn, key); err != nil {
			return nil, false, err
		}
		ws.MaskingKey = key
	}

	payload = make([]byte, length)
	if _, err := io.ReadFull(ws.Conn, payload); err != nil {
		return nil, false, err
	}
	log.Println("data:", payload)
	if masked {
		for i := range payload {
			payload[i] ^= key[i%4]
		}
	}
	return payload, fin, nil
}
170
// parseHandshake extracts the header fields from the raw text of an HTTP
// request.
//
// content is split on CRLF; every line containing a colon contributes one
// entry whose name is the text before the first colon and whose value is
// everything after it, both with surrounding whitespace removed. Lines
// without a colon (such as a plain request line) or with an empty name are
// skipped. Names are stored exactly as written, and a repeated name keeps
// its last value.
func parseHandshake(content string) map[string]string {
	headers := make(map[string]string, 10)
	for _, line := range strings.Split(content, "\r\n") {
		// Split on the first colon only: values such as "localhost:8000"
		// or "http://example.com" legitimately contain further colons.
		parts := strings.SplitN(line, ":", 2)
		if len(parts) != 2 {
			continue
		}
		name := strings.TrimSpace(parts[0])
		if name == "" {
			continue
		}
		headers[name] = strings.TrimSpace(parts[1])
	}
	return headers
}