Skip to content

Commit

Permalink
optimize writev method
Browse files Browse the repository at this point in the history
  • Loading branch information
lxzan committed Jan 24, 2024
1 parent fe3d21b commit 8efe3c8
Show file tree
Hide file tree
Showing 13 changed files with 336 additions and 131 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ PASS
- [x] Broadcast
- [x] Dial via Proxy
- [x] Context-Takeover
- [x] Zero Allocs Read / Write
- [x] Passed Autobahn Test Cases [Server](https://lxzan.github.io/gws/reports/servers/) / [Client](https://lxzan.github.io/gws/reports/clients/)
- [x] Concurrent & Asynchronous Non-Blocking Write

Expand Down
1 change: 0 additions & 1 deletion README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ PASS
- [x] 广播
- [x] 代理拨号
- [x] 上下文接管
- [x] 读写过程零动态内存分配
- [x] 支持并发和异步非阻塞写入
- [x] 通过所有 Autobahn 测试用例 [Server](https://lxzan.github.io/gws/reports/servers/) / [Client](https://lxzan.github.io/gws/reports/clients/)

Expand Down
4 changes: 2 additions & 2 deletions benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func BenchmarkConn_ReadMessage(b *testing.B) {
conn: &benchConn{},
config: upgrader.option.getConfig(),
}
var buf, _ = conn1.genFrame(OpcodeText, githubData, false)
var buf, _ = conn1.genFrame(OpcodeText, internal.Bytes(githubData), false)

var reader = bytes.NewBuffer(buf.Bytes())
var conn2 = &Conn{
Expand Down Expand Up @@ -98,7 +98,7 @@ func BenchmarkConn_ReadMessage(b *testing.B) {
deflater: new(deflater),
}
conn1.deflater.initialize(false, conn1.pd)
var buf, _ = conn1.genFrame(OpcodeText, githubData, false)
var buf, _ = conn1.genFrame(OpcodeText, internal.Bytes(githubData), false)

var reader = bytes.NewBuffer(buf.Bytes())
var conn2 = &Conn{
Expand Down
19 changes: 12 additions & 7 deletions compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,15 @@ func (c *deflater) Decompress(src *bytes.Buffer, dict []byte) (*bytes.Buffer, er
}

// Compress 压缩
func (c *deflater) Compress(src []byte, dst *bytes.Buffer, dict []byte) error {
func (c *deflater) Compress(src internal.Payload, dst *bytes.Buffer, dict []byte) error {
c.cpsLocker.Lock()
defer c.cpsLocker.Unlock()

c.cpsWriter.ResetDict(dst, dict)
if err := internal.CheckErrors(internal.WriteN(c.cpsWriter, src), c.cpsWriter.Flush()); err != nil {
if _, err := src.WriteTo(c.cpsWriter); err != nil {
return err
}
if err := c.cpsWriter.Flush(); err != nil {
return err
}
if n := dst.Len(); n >= 4 {
Expand Down Expand Up @@ -116,16 +119,17 @@ func (c *slideWindow) initialize(pool *internal.Pool[[]byte], windowBits int) *s
return c
}

func (c *slideWindow) Write(p []byte) {
func (c *slideWindow) Write(p []byte) (int, error) {
if !c.enabled {
return
return 0, nil
}

var n = len(p)
var total = len(p)
var n = total
var length = len(c.dict)
if n+length <= c.size {
c.dict = append(c.dict, p...)
return
return total, nil
}

if m := c.size - length; m > 0 {
Expand All @@ -136,11 +140,12 @@ func (c *slideWindow) Write(p []byte) {

if n >= c.size {
copy(c.dict, p[n-c.size:])
return
return total, nil
}

copy(c.dict, c.dict[n:])
copy(c.dict[c.size-n:], p)
return total, nil
}

func (c *PermessageDeflate) genRequestHeader() string {
Expand Down
77 changes: 77 additions & 0 deletions compress_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package gws

import (
"errors"
"io"
"testing"
"time"

Expand Down Expand Up @@ -181,4 +183,79 @@ func TestPermessageNegotiation(t *testing.T) {
assert.NoError(t, err)
client.WriteMessage(OpcodeText, internal.AlphabetNumeric.Generate(1024))
})

t.Run("ok 5", func(t *testing.T) {
var addr = ":" + nextPort()
var serverHandler = &webSocketMocker{}
serverHandler.onMessage = func(socket *Conn, message *Message) {
println(message.Data.String())
}
var server = NewServer(serverHandler, &ServerOption{PermessageDeflate: PermessageDeflate{
Enabled: true,
ServerContextTakeover: true,
ClientContextTakeover: true,
ServerMaxWindowBits: 10,
ClientMaxWindowBits: 10,
}})
go server.Run(addr)

time.Sleep(100 * time.Millisecond)
client, _, err := NewClient(new(BuiltinEventHandler), &ClientOption{
Addr: "ws://localhost" + addr,
PermessageDeflate: PermessageDeflate{
Enabled: true,
ServerContextTakeover: true,
ClientContextTakeover: true,
Threshold: 1,
},
})
assert.NoError(t, err)
_ = client.WriteString("he")
assert.Equal(t, string(client.cpsWindow.dict), "he")
_ = client.WriteString("llo")
assert.Equal(t, string(client.cpsWindow.dict), "hello")
_ = client.WriteV(OpcodeText, []byte(", "), []byte("world!"))
assert.Equal(t, string(client.cpsWindow.dict), "hello, world!")
})

t.Run("fail", func(t *testing.T) {
var addr = ":" + nextPort()
var serverHandler = &webSocketMocker{}
var server = NewServer(serverHandler, &ServerOption{PermessageDeflate: PermessageDeflate{
Enabled: true,
ServerContextTakeover: true,
ClientContextTakeover: true,
ServerMaxWindowBits: 10,
ClientMaxWindowBits: 10,
}})
go server.Run(addr)

time.Sleep(100 * time.Millisecond)
client, _, err := NewClient(new(BuiltinEventHandler), &ClientOption{
Addr: "ws://localhost" + addr,
PermessageDeflate: PermessageDeflate{
Enabled: true,
ServerContextTakeover: true,
ClientContextTakeover: true,
Threshold: 1,
},
})
assert.NoError(t, err)
err = client.doWrite(OpcodeText, new(writerTo))
assert.Equal(t, err.Error(), "1")
})
}

type writerTo struct{}

func (c *writerTo) CheckEncoding(enabled bool, opcode uint8) bool {
return true
}

func (c *writerTo) Len() int {
return 10
}

func (c *writerTo) WriteTo(w io.Writer) (n int64, err error) {
return 0, errors.New("1")
}
17 changes: 5 additions & 12 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,11 @@ import (
"bytes"
"crypto/tls"
"encoding/binary"
"github.com/lxzan/gws/internal"
"net"
"sync"
"sync/atomic"
"time"
"unicode/utf8"

"github.com/lxzan/gws/internal"
)

type Conn struct {
Expand Down Expand Up @@ -91,22 +89,17 @@ func (c *Conn) getDpsDict() []byte {
}

func (c *Conn) isTextValid(opcode Opcode, payload []byte) bool {
if !c.config.CheckUtf8Enabled {
return true
}
switch opcode {
case OpcodeText, OpcodeCloseConnection:
return utf8.Valid(payload)
default:
return true
if c.config.CheckUtf8Enabled {
return internal.CheckEncoding(uint8(opcode), payload)
}
return true
}

func (c *Conn) isClosed() bool { return atomic.LoadUint32(&c.closed) == 1 }

func (c *Conn) close(reason []byte, err error) {
c.err.Store(err)
_ = c.doWrite(OpcodeCloseConnection, reason)
_ = c.doWrite(OpcodeCloseConnection, internal.Bytes(reason))
_ = c.conn.Close()
}

Expand Down
85 changes: 85 additions & 0 deletions internal/io.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package internal

import (
"io"
"unicode/utf8"
)

// ReadN 精准地读取len(data)个字节, 否则返回错误
func ReadN(reader io.Reader, data []byte) error {
_, err := io.ReadFull(reader, data)
return err
}

func WriteN(writer io.Writer, content []byte) error {
_, err := writer.Write(content)
return err
}

func CheckEncoding(opcode uint8, payload []byte) bool {
switch opcode {
case 1, 8:
return utf8.Valid(payload)
default:
return true
}
}

type Payload interface {
io.WriterTo
Len() int
CheckEncoding(enabled bool, opcode uint8) bool
}

type Buffers [][]byte

func (b Buffers) CheckEncoding(enabled bool, opcode uint8) bool {
if enabled {
for i, _ := range b {
if !CheckEncoding(opcode, b[i]) {
return false
}
}
}
return true
}

func (b Buffers) Len() int {
var sum = 0
for i, _ := range b {
sum += len(b[i])
}
return sum
}

// WriteTo 可重复写
func (b Buffers) WriteTo(w io.Writer) (int64, error) {
var n = 0
for i, _ := range b {
x, err := w.Write(b[i])
n += x
if err != nil {
return int64(n), err
}
}
return int64(n), nil
}

type Bytes []byte

func (b Bytes) CheckEncoding(enabled bool, opcode uint8) bool {
if enabled {
return CheckEncoding(opcode, b)
}
return true
}

func (b Bytes) Len() int {
return len(b)
}

// WriteTo 可重复写
func (b Bytes) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write(b)
return int64(n), err
}
90 changes: 90 additions & 0 deletions internal/io_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package internal

import (
"bytes"
"net"
"strings"
"testing"

"github.com/stretchr/testify/assert"
)

func TestIOUtil(t *testing.T) {
var as = assert.New(t)

t.Run("", func(t *testing.T) {
var reader = strings.NewReader("hello")
var p = make([]byte, 5)
var err = ReadN(reader, p)
as.Nil(err)
})

t.Run("", func(t *testing.T) {
var writer = bytes.NewBufferString("")
var err = WriteN(writer, nil)
as.NoError(err)
})

t.Run("", func(t *testing.T) {
var writer = bytes.NewBufferString("")
var p = []byte("hello")
var err = WriteN(writer, p)
as.NoError(err)
})
}

func TestBuffers_WriteTo(t *testing.T) {
t.Run("", func(t *testing.T) {
var b = Buffers{
[]byte("he"),
[]byte("llo"),
}
var w = bytes.NewBufferString("")
b.WriteTo(w)
n, _ := b.WriteTo(w)
assert.Equal(t, w.String(), "hellohello")
assert.Equal(t, n, int64(5))
assert.Equal(t, b.Len(), 5)
assert.True(t, b.CheckEncoding(true, 1))
})

t.Run("", func(t *testing.T) {
var conn, _ = net.Pipe()
_ = conn.Close()
var b = Buffers{
[]byte("he"),
[]byte("llo"),
}
_, err := b.WriteTo(conn)
assert.Error(t, err)
})

t.Run("", func(t *testing.T) {
var str = "你好"
var b = Buffers{
[]byte("he"),
[]byte(str[2:]),
}
assert.False(t, b.CheckEncoding(true, 1))
})
}

func TestBytes_WriteTo(t *testing.T) {
t.Run("", func(t *testing.T) {
var b = Bytes("hello")
var w = bytes.NewBufferString("")
b.WriteTo(w)
n, _ := b.WriteTo(w)
assert.Equal(t, w.String(), "hellohello")
assert.Equal(t, n, int64(5))
assert.Equal(t, b.Len(), 5)
})

t.Run("", func(t *testing.T) {
var str = "你好"
var b = Bytes(str[2:])
assert.False(t, b.CheckEncoding(true, 1))
assert.True(t, b.CheckEncoding(false, 1))
assert.True(t, b.CheckEncoding(true, 2))
})
}
Loading

0 comments on commit 8efe3c8

Please sign in to comment.