-
Notifications
You must be signed in to change notification settings - Fork 94
/
types.go
287 lines (229 loc) · 6.62 KB
/
types.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
package gws
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"log"
"net"
"runtime"
"unsafe"
"github.com/lxzan/gws/internal"
)
// frameHeaderSize is the largest possible WebSocket frame header, and therefore
// the size of the frameHeader scratch buffer:
// 2 fixed bytes + up to 8 extended-length bytes + a 4-byte masking key.
const frameHeaderSize = 2 + 8 + 4
// Opcode is the 4-bit frame type carried in the low nibble of a frame's first byte.
type Opcode uint8

// Opcode values defined by RFC 6455 §5.2.
const (
	OpcodeContinuation    Opcode = 0x0
	OpcodeText            Opcode = 0x1
	OpcodeBinary          Opcode = 0x2
	OpcodeCloseConnection Opcode = 0x8
	OpcodePing            Opcode = 0x9
	OpcodePong            Opcode = 0xA
)

// isDataFrame reports whether c is one of the data opcodes:
// continuation, text or binary.
func (c Opcode) isDataFrame() bool {
	switch c {
	case OpcodeContinuation, OpcodeText, OpcodeBinary:
		return true
	default:
		return false
	}
}
// CloseError describes a connection closure: the close status code and the
// reason bytes that accompanied it.
type CloseError struct {
	// Code is the close status code.
	Code uint16
	// Reason is the raw close reason.
	Reason []byte
}

// Error renders the close code and reason; the reason bytes are printed as text.
func (c *CloseError) Error() string {
	const layout = "gws: connection closed, code=%d, reason=%s"
	return fmt.Sprintf(layout, c.Code, c.Reason)
}
var (
	// errEmpty is an error with an empty message.
	// NOTE(review): its purpose is not visible in this file — confirm at the use sites.
	errEmpty = errors.New("")

	// ErrUnauthorized reports that the request failed authentication.
	ErrUnauthorized = errors.New("unauthorized")

	// ErrHandshake reports a handshake error: the request headers failed validation.
	ErrHandshake = errors.New("handshake error")

	// ErrCompressionNegotiation reports that negotiating the compression
	// extension failed; try disabling compression.
	ErrCompressionNegotiation = errors.New("invalid compression negotiation")

	// ErrSubprotocolNegotiation reports that sub-protocol negotiation failed.
	ErrSubprotocolNegotiation = errors.New("sub-protocol negotiation failed")

	// ErrTextEncoding reports a text message whose payload is not valid UTF-8.
	ErrTextEncoding = errors.New("invalid text encoding")

	// ErrConnClosed reports that the connection is closed. It is the same value
	// as net.ErrClosed, so errors.Is(err, net.ErrClosed) also matches it.
	ErrConnClosed = net.ErrClosed

	// ErrUnsupportedProtocol reports an unsupported network protocol.
	ErrUnsupportedProtocol = errors.New("unsupported protocol")
)
// Event is the set of callbacks invoked over the lifetime of a WebSocket connection.
type Event interface {
	// OnOpen is called once the WebSocket connection has been established.
	OnOpen(socket *Conn)

	// OnClose is called when a close frame is received from the peer, or when
	// the connection is closed because of an error during I/O.
	// In the former case err can be asserted to *CloseError.
	OnClose(socket *Conn, err error)

	// OnPing is called when a ping frame (heartbeat probe) is received.
	OnPing(socket *Conn, payload []byte)

	// OnPong is called when a pong frame (heartbeat response) is received.
	OnPong(socket *Conn, payload []byte)

	// OnMessage is called for each received message.
	// If ParallelEnabled is set, OnMessage may be called concurrently;
	// panics raised inside it are not recovered.
	OnMessage(socket *Conn, message *Message)
}
// BuiltinEventHandler is a minimal Event implementation: every callback is a
// no-op except OnPing, which answers with a pong. It can be embedded in a
// handler type to provide defaults for callbacks that type does not override.
type BuiltinEventHandler struct{}

// OnOpen does nothing.
func (b BuiltinEventHandler) OnOpen(socket *Conn) {}

// OnClose does nothing.
func (b BuiltinEventHandler) OnClose(socket *Conn, err error) {}

// OnPing replies with a pong that echoes the ping's payload. RFC 6455 §5.5.3
// requires a pong sent in response to a ping to carry identical application
// data; the previous reply with an empty payload violated that.
// The write error is deliberately ignored: this default handler has no way to
// report it.
// NOTE(review): assumes WritePong does not retain payload after returning — confirm.
func (b BuiltinEventHandler) OnPing(socket *Conn, payload []byte) { _ = socket.WritePong(payload) }

// OnPong does nothing.
func (b BuiltinEventHandler) OnPong(socket *Conn, payload []byte) {}

// OnMessage does nothing.
func (b BuiltinEventHandler) OnMessage(socket *Conn, message *Message) {}
// frameHeader is a fixed-size scratch buffer holding one frame header:
// 2 fixed bytes, up to 8 extended-length bytes and a 4-byte masking key.
// Byte 0 carries FIN, RSV1-3 and the opcode; byte 1 carries MASK and the
// 7-bit length code.
type frameHeader [frameHeaderSize]byte

// GetFIN reports whether the FIN bit (0x80 of byte 0) is set.
func (c *frameHeader) GetFIN() bool {
	return c[0]&0x80 != 0
}

// GetRSV1 reports whether the RSV1 bit (0x40 of byte 0) is set.
func (c *frameHeader) GetRSV1() bool {
	return c[0]&0x40 != 0
}

// GetRSV2 reports whether the RSV2 bit (0x20 of byte 0) is set.
func (c *frameHeader) GetRSV2() bool {
	return c[0]&0x20 != 0
}

// GetRSV3 reports whether the RSV3 bit (0x10 of byte 0) is set.
func (c *frameHeader) GetRSV3() bool {
	return c[0]&0x10 != 0
}

// GetOpcode returns the opcode stored in the low nibble of byte 0.
func (c *frameHeader) GetOpcode() Opcode {
	return Opcode(c[0] & 0x0F)
}

// GetMask reports whether the MASK bit (0x80 of byte 1) is set.
func (c *frameHeader) GetMask() bool {
	return c[1]&0x80 != 0
}

// GetLengthCode returns the 7-bit payload length code from byte 1
// (126 and 127 announce a 16-bit or 64-bit extended length).
func (c *frameHeader) GetLengthCode() uint8 {
	return c[1] & 0x7F
}

// SetMask sets the MASK bit in byte 1, leaving the length code untouched.
func (c *frameHeader) SetMask() {
	c[1] |= 0x80
}
// SetLength encodes payload length n into the header and returns how many
// extended-length bytes (0, 2 or 8) follow the two fixed header bytes.
//
// Lengths up to internal.ThresholdV1 go directly into the 7-bit length code.
// Lengths up to internal.ThresholdV2 use code 126 plus a big-endian uint16 in
// bytes 2-3. Anything larger uses code 127 plus a big-endian uint64 in bytes 2-9.
// The code is added to byte 1, so the caller must leave byte 1's length bits
// zero beforehand; the MASK bit may already be set.
func (c *frameHeader) SetLength(n uint64) (offset int) {
	switch {
	case n <= internal.ThresholdV1:
		c[1] += uint8(n)
		return 0
	case n <= internal.ThresholdV2:
		c[1] += 126
		binary.BigEndian.PutUint16(c[2:4], uint16(n))
		return 2
	default:
		c[1] += 127
		binary.BigEndian.PutUint64(c[2:10], n)
		return 8
	}
}
// SetMaskKey writes the 4-byte masking key into the header at c[offset:offset+4].
func (c *frameHeader) SetMaskKey(offset int, key [4]byte) {
	dst := c[offset : offset+4]
	copy(dst, key[:])
}
// GenerateHeader fills c with the header of an outgoing frame. It returns the
// number of header bytes in use and, for client frames, the 4-byte masking key
// as a slice aliasing c (nil for server frames).
//
// fin sets the FIN bit and compress sets RSV1. Server frames are left
// unmasked. Client frames get the MASK bit and a key from
// internal.AlphabetNumeric, stored immediately after the length field.
//
// Byte 1 is cleared before the length is encoded because SetLength accumulates
// into it with +=. Without this, a header reused after Parse or after a
// previous GenerateHeader would carry stale MASK/length bits into the new frame.
//
// NOTE(review): consider giving each client connection its own random number
// generator instead of sharing internal.AlphabetNumeric.
func (c *frameHeader) GenerateHeader(isServer bool, fin bool, compress bool, opcode Opcode, length int) (headerLength int, maskBytes []byte) {
	headerLength = 2
	var b0 = uint8(opcode)
	if fin {
		b0 += 128
	}
	if compress {
		b0 += 64
	}
	(*c)[0] = b0
	(*c)[1] = 0
	headerLength += c.SetLength(uint64(length))
	if !isServer {
		c.SetMask()
		maskNum := internal.AlphabetNumeric.Uint32()
		binary.LittleEndian.PutUint32((*c)[headerLength:headerLength+4], maskNum)
		maskBytes = (*c)[headerLength : headerLength+4]
		headerLength += 4
	}
	return headerLength, maskBytes
}
// Parse reads one complete frame header (at most frameHeaderSize bytes) from
// reader into c and returns the payload length it declares.
//
// Layout written into c:
//   - bytes 0-1 hold the fixed header;
//   - bytes 2-3 (16-bit) or 2-9 (64-bit) hold the extended length when present;
//   - bytes 10-13 hold the masking key when the MASK bit is set, whatever the
//     length encoding (see GetMaskKey).
//
// Read errors are returned unchanged. A 64-bit length with its most
// significant bit set (forbidden by RFC 6455 §5.2), or one that does not fit in
// an int, is rejected with an error. Converting such a length directly would
// produce a negative or truncated value that could slip past later size checks.
func (c *frameHeader) Parse(reader io.Reader) (int, error) {
	// Largest value representable by int on this platform.
	const maxInt = int(^uint(0) >> 1)

	if err := internal.ReadN(reader, (*c)[0:2]); err != nil {
		return 0, err
	}

	var payloadLength int
	switch lengthCode := c.GetLengthCode(); lengthCode {
	case 126:
		if err := internal.ReadN(reader, (*c)[2:4]); err != nil {
			return 0, err
		}
		payloadLength = int(binary.BigEndian.Uint16((*c)[2:4]))
	case 127:
		if err := internal.ReadN(reader, (*c)[2:10]); err != nil {
			return 0, err
		}
		n := binary.BigEndian.Uint64((*c)[2:10])
		if n > uint64(maxInt) {
			return 0, errors.New("invalid payload length")
		}
		payloadLength = int(n)
	default:
		payloadLength = int(lengthCode)
	}

	if c.GetMask() {
		if err := internal.ReadN(reader, (*c)[10:14]); err != nil {
			return 0, err
		}
	}
	return payloadLength, nil
}
// GetMaskKey returns the masking key of a parsed frame. Parse always stores the
// key at the end of the header buffer (bytes 10-13), whatever the length
// encoding, and only writes it when the MASK bit is set. The returned slice
// aliases the header.
func (c *frameHeader) GetMaskKey() []byte {
	const keyStart, keyEnd = 10, 14
	return c[keyStart:keyEnd]
}
// Message is a WebSocket message as delivered to Event.OnMessage.
type Message struct {
	// compressed reports whether the payload was compressed.
	compressed bool

	// Opcode is the message's opcode.
	Opcode Opcode

	// Data holds the message payload.
	Data *bytes.Buffer
}

// Read implements io.Reader by consuming bytes from the payload buffer.
func (c *Message) Read(p []byte) (n int, err error) {
	n, err = c.Data.Read(p)
	return n, err
}

// Bytes returns the unread portion of the payload. The slice aliases the
// underlying buffer; no copy is made.
func (c *Message) Bytes() []byte {
	payload := c.Data
	return payload.Bytes()
}
// Close returns the payload buffer to binaryPool and clears Data.
//
// Close is idempotent: once Data has been released, later calls return nil
// without touching the pool, so a message can never hand a nil buffer back to
// binaryPool. Read and Bytes must not be called after Close, because Data is nil.
func (c *Message) Close() error {
	if c.Data == nil {
		return nil
	}
	binaryPool.Put(c.Data)
	c.Data = nil
	return nil
}
// continuationFrame holds state carried across frames of one message.
// NOTE(review): presumably used to assemble fragmented messages from
// continuation frames — confirm at the use sites.
type continuationFrame struct {
	initialized bool
	compressed  bool
	opcode      Opcode
	buffer      *bytes.Buffer
}

// reset returns every field to its zero value and drops the buffer reference
// (the buffer itself is not recycled here).
func (c *continuationFrame) reset() {
	*c = continuationFrame{}
}
// Logger is the logging abstraction used by this package; it only needs an
// error-level sink.
type Logger interface {
	// Error logs the values v.
	Error(v ...any)
}

// stdLogger is a Logger backed by the standard library's log package.
type stdLogger struct{}

// Compile-time check that *stdLogger satisfies Logger.
var _ Logger = (*stdLogger)(nil)

// Error formats v the way log.Println does and writes it to the standard logger.
func (l *stdLogger) Error(v ...any) {
	log.Println(v...)
}
// Recovery stops a panic in the calling goroutine and reports the panic value
// plus that goroutine's stack trace (capped at 64 KiB) to logger.
//
// It must be deferred directly, as in `defer Recovery(logger)`: recover only
// has an effect when called by the deferred function itself. When no panic is
// in flight it does nothing.
func Recovery(logger Logger) {
	e := recover()
	if e == nil {
		return
	}

	const size = 64 << 10
	stack := make([]byte, size)
	stack = stack[:runtime.Stack(stack, false)]

	// View the bytes as a string without copying: a string header
	// (data, len) is a prefix of a slice header (data, len, cap).
	trace := *(*string)(unsafe.Pointer(&stack))
	logger.Error("fatal error:", e, trace)
}