Golang WebSocket

  1 package info_websocket
  2 
  3 import (
  4     "crypto/sha1"
  5     "encoding/base64"
  6     "errors"
  7     "io"
  8     "log"
  9     "net"
 10     "strings"
 11 )
 12 
 13 func main() {
 14     ln, err := net.Listen("tcp", ":8000")//监听端口
 15     if err != nil {
 16         log.Panic(err)
 17     }
 18     for {
 19         log.Println("wss")
 20         conn, err := ln.Accept()//等待客户的连接
 21         if err != nil {
 22             log.Println("Accept err:", err)
 23         }
 24         for {
 25             handleConnection(conn)
 26         }
 27     }
 28 }
 29 
 30 func handleConnection(conn net.Conn) {
 31     content := make([]byte, 1024)
 32     _, err := conn.Read(content)
 33     log.Println(string(content))
 34     if err != nil {
 35         log.Println(err)
 36     }
 37     isHttp := false
 38     // 先暂时这么判断
 39     if string(content[0:3]) == "GET" {
 40         isHttp = true
 41     }
 42     log.Println("isHttp:", isHttp)
 43     if isHttp {
 44         headers := parseHandshake(string(content))
 45         log.Println("headers", headers)
 46         secWebsocketKey := headers["Sec-WebSocket-Key"]
 47         // NOTE:这里省略其他的验证
 48         guid := "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
 49         // 计算Sec-WebSocket-Accept
 50         h := sha1.New()
 51         log.Println("accept raw:", secWebsocketKey+guid)
 52         io.WriteString(h, secWebsocketKey+guid)
 53         accept := make([]byte, 28)
 54         base64.StdEncoding.Encode(accept, h.Sum(nil))
 55         log.Println(string(accept))
 56         response := "HTTP/1.1 101 Switching Protocols\r\n"
 57         response = response + "Sec-WebSocket-Accept: " + string(accept) + "\r\n"
 58         response = response + "Connection: Upgrade\r\n"
 59         response = response + "Upgrade: websocket\r\n\r\n"
 60         log.Println("response:", response)
 61         if lenth, err := conn.Write([]byte(response)); err != nil {
 62             log.Println(err)
 63         } else {
 64             log.Println("send len:", lenth)
 65         }
 66         wssocket := NewWsSocket(conn)
 67         for {
 68             data, err := wssocket.ReadIframe()
 69             if err != nil {
 70                 log.Println("readIframe err:", err)
 71             }
 72             log.Println("read data:", string(data))
 73             err = wssocket.SendIframe([]byte("good"))
 74             if err != nil {
 75                 log.Println("sendIframe err:", err)
 76             }
 77             log.Println("send data")
 78         }
 79     } else {
 80         log.Println(string(content))
 81         // 直接读取
 82     }
 83 }
 84 
 85 type WsSocket struct {
 86     MaskingKey []byte
 87     Conn       net.Conn
 88 }
 89 
 90 func NewWsSocket(conn net.Conn) *WsSocket {
 91     return &WsSocket{Conn: conn}
 92 }
 93 
 94 func (this *WsSocket) SendIframe(data []byte) error {
 95     // 这里只处理data长度<125的
 96     if len(data) >= 125 {
 97         return errors.New("send iframe data error")
 98     }
 99     lenth := len(data)
100     maskedData := make([]byte, lenth)
101     for i := 0; i < lenth; i++ {
102         if this.MaskingKey != nil {
103             maskedData[i] = data[i] ^ this.MaskingKey[i%4]
104         } else {
105             maskedData[i] = data[i]
106         }
107     }
108     this.Conn.Write([]byte{0x81})
109     var payLenByte byte
110     if this.MaskingKey != nil && len(this.MaskingKey) != 4 {
111         payLenByte = byte(0x80) | byte(lenth)
112         this.Conn.Write([]byte{payLenByte})
113         this.Conn.Write(this.MaskingKey)
114     } else {
115         payLenByte = byte(0x00) | byte(lenth)
116         this.Conn.Write([]byte{payLenByte})
117     }
118     this.Conn.Write(data)
119     return nil
120 }
121 
122 func (this *WsSocket) ReadIframe() (data []byte, err error) {
123     err = nil
124     //第一个字节:FIN + RSV1-3 + OPCODE
125     opcodeByte := make([]byte, 1)
126     this.Conn.Read(opcodeByte)
127     FIN := opcodeByte[0] >> 7
128     RSV1 := opcodeByte[0] >> 6 & 1
129     RSV2 := opcodeByte[0] >> 5 & 1
130     RSV3 := opcodeByte[0] >> 4 & 1
131     OPCODE := opcodeByte[0] & 15
132     log.Println(RSV1, RSV2, RSV3, OPCODE)
133 
134     payloadLenByte := make([]byte, 1)
135     this.Conn.Read(payloadLenByte)
136     payloadLen := int(payloadLenByte[0] & 0x7F)
137     mask := payloadLenByte[0] >> 7
138     if payloadLen == 127 {
139         extendedByte := make([]byte, 8)
140         this.Conn.Read(extendedByte)
141     }
142     maskingByte := make([]byte, 4)
143     if mask == 1 {
144         this.Conn.Read(maskingByte)
145         this.MaskingKey = maskingByte
146     }
147 
148     payloadDataByte := make([]byte, payloadLen)
149     this.Conn.Read(payloadDataByte)
150     log.Println("data:", payloadDataByte)
151     dataByte := make([]byte, payloadLen)
152     for i := 0; i < payloadLen; i++ {
153         if mask == 1 {
154             dataByte[i] = payloadDataByte[i] ^ maskingByte[i%4]
155         } else {
156             dataByte[i] = payloadDataByte[i]
157         }
158     }
159     if FIN == 1 {
160         data = dataByte
161         return
162     }
163     nextData, err := this.ReadIframe()
164     if err != nil {
165         return
166     }
167     data = append(data, nextData...)
168     return
169 }
170 
// parseHandshake extracts "Name: value" header fields from a raw HTTP request.
//
// Lines are split on CRLF. Each line is split at its first colon, so values
// that contain colons (e.g. "Host: localhost:8000") are kept intact. Names and
// values are trimmed of surrounding spaces and tabs.
//
// Lines are skipped in three cases: they have no colon, the name is empty, or
// the name contains whitespace. The last case skips the request line even when
// its URI contains a colon.
//
// Parsing stops at the first blank line after content has started, because
// that line ends the header section. Names are stored as sent (no case
// normalisation). For a repeated name, the last occurrence wins.
func parseHandshake(content string) map[string]string {
	headers := make(map[string]string, 10)
	started := false
	for _, line := range strings.Split(content, "\r\n") {
		if line == "" {
			if started {
				break // end of the header section
			}
			continue
		}
		started = true

		colon := strings.Index(line, ":")
		if colon < 0 {
			continue
		}
		name := strings.TrimSpace(line[:colon])
		if name == "" || strings.ContainsAny(name, " \t") {
			continue
		}
		headers[name] = strings.TrimSpace(line[colon+1:])
	}
	return headers
}

 

posted @ 2020-02-16 15:35  鸡儿er  阅读(162)  评论(0)  编辑  收藏  举报