Skip to content

Commit

Permalink
async write
Browse files Browse the repository at this point in the history
async write

async write

async write

async write
  • Loading branch information
lixizan committed Feb 23, 2023
1 parent 8c5adca commit a2bb4a9
Show file tree
Hide file tree
Showing 10 changed files with 249 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ bench:
go test -benchmem -bench ^Benchmark github.com/lxzan/gws

build:
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/gws-server-linux-amd64 github.com/lxzan/gws/examples/benchmark
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/gws-server-linux-amd64 github.com/lxzan/gws/examples/testsuite

run-testsuite-server:
go run github.com/lxzan/gws/examples/testsuite
Expand Down
116 changes: 116 additions & 0 deletions aio.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package gws

import (
"sync"
)

type (
workerQueue struct {
mu *sync.Mutex // 锁
q []writeJob // 任务队列
maxConcurrency int64 // 最大并发
curConcurrency int64 // 当前并发
}

writeJob struct {
Args *Conn
Do func(args *Conn) error
}

messageWrapper struct {
opcode Opcode
payload []byte
}
)

// newWorkerQueue 创建一个工作队列
func newWorkerQueue(maxConcurrency int64) *workerQueue {
c := &workerQueue{
mu: &sync.Mutex{},
q: make([]writeJob, 0),
maxConcurrency: maxConcurrency,
curConcurrency: 0,
}
return c
}

func (c *workerQueue) getJob() interface{} {
c.mu.Lock()
defer c.mu.Unlock()

if c.curConcurrency >= c.maxConcurrency {
return nil
}
if n := len(c.q); n == 0 {
return nil
}
var result = c.q[0]
c.q = c.q[1:]
c.curConcurrency++
return result
}

func (c *workerQueue) decrease() {
c.mu.Lock()
c.curConcurrency--
c.mu.Unlock()
}

func (c *workerQueue) do(job writeJob) {
job.Args.emitError(job.Do(job.Args))
c.decrease()
if nextJob := c.getJob(); nextJob != nil {
c.do(nextJob.(writeJob))
}
}

// AddJob 追加任务, 有资源空闲的话会立即执行
func (c *workerQueue) AddJob(job writeJob) {
c.mu.Lock()
c.q = append(c.q, job)
c.mu.Unlock()
if item := c.getJob(); item != nil {
go c.do(item.(writeJob))
}
}

func newMessageQueue() messageQueue {
return messageQueue{
mu: &sync.RWMutex{},
data: []messageWrapper{},
}
}

type messageQueue struct {
mu *sync.RWMutex
data []messageWrapper
}

func (c *messageQueue) Len() int {
c.mu.RLock()
n := len(c.data)
c.mu.RUnlock()
return n
}

func (c *messageQueue) Push(conn *Conn, m messageWrapper) {
c.mu.Lock()
c.data = append(c.data, m)
if n := len(c.data); n == 1 {
_writeQueue.AddJob(writeJob{Args: conn, Do: doWriteAsync})
}
c.mu.Unlock()
}

func (c *messageQueue) Range(f func(msg messageWrapper) error) error {
c.mu.Lock()
defer c.mu.Unlock()

for i, _ := range c.data {
if err := f(c.data[i]); err != nil {
return err
}
}
c.data = c.data[:0]
return nil
}
61 changes: 61 additions & 0 deletions aio_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package gws

import (
"bufio"
"github.com/stretchr/testify/assert"
"net"
"sync"
"testing"
)

func newPeer(config *Upgrader) (server, client *Conn) {
size := 4096
s, c := net.Pipe()
{
brw := bufio.NewReadWriter(bufio.NewReaderSize(s, size), bufio.NewWriterSize(s, size))
server = serveWebSocket(config, &Request{}, s, brw, config.EventHandler, config.CompressEnabled)
}
{
brw := bufio.NewReadWriter(bufio.NewReaderSize(c, size), bufio.NewWriterSize(c, size))
client = serveWebSocket(config, &Request{}, c, brw, config.EventHandler, config.CompressEnabled)
}
return
}

func TestConn_WriteAsync(t *testing.T) {
var as = assert.New(t)
SetMaxConcurrencyForWriteQueue(8)
var handler = new(webSocketMocker)
var upgrader = NewUpgrader(func(c *Upgrader) {
c.EventHandler = handler
})
server, client := newPeer(upgrader)

var message = []byte("hello")
var count = 1000

go func() {
for i := 0; i < count; i++ {
server.WriteMessageAsync(OpcodeText, message)
}
}()

var wg sync.WaitGroup
wg.Add(count)
go func() {
for {
var header = frameHeader{}
_, err := client.conn.Read(header[:2])
if err != nil {
return
}
var payload = make([]byte, header.GetLengthCode())
if _, err := client.conn.Read(payload); err != nil {
return
}
as.Equal(string(message), string(payload))
wg.Done()
}
}()
wg.Wait()
}
4 changes: 2 additions & 2 deletions compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ func (c *decompressor) Decompress(payload *bytes.Buffer) (*bytes.Buffer, error)
return nil, err
}

var buf = bpool.Get(3 * payload.Len())
var buf = _bpool.Get(3 * payload.Len())
_, err := io.Copy(buf, c.fr)
bpool.Put(payload)
_bpool.Put(payload)
return buf, err
}
3 changes: 3 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ type Conn struct {
closed uint32
// write lock
wmu *sync.Mutex
// messages wait for sending
wmq messageQueue
}

func serveWebSocket(config *Upgrader, r *Request, netConn net.Conn, brw *bufio.ReadWriter, handler Event, compressEnabled bool) *Conn {
Expand All @@ -59,6 +61,7 @@ func serveWebSocket(config *Upgrader, r *Request, netConn net.Conn, brw *bufio.R
rbuf: brw.Reader,
fh: frameHeader{},
handler: handler,
wmq: newMessageQueue(),
}
if c.compressEnabled {
c.compressor = newCompressor(config.CompressLevel)
Expand Down
2 changes: 1 addition & 1 deletion frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func (c *Message) Read(p []byte) (n int, err error) {

// Close recycle buffer
func (c *Message) Close() {
bpool.Put(c.Data)
_bpool.Put(c.Data)
c.Data = nil
}

Expand Down
19 changes: 19 additions & 0 deletions init.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package gws

import "github.com/lxzan/gws/internal"

const defaultAsyncWriteConcurrency = 128

var (
// task queue for async write
_writeQueue = newWorkerQueue(defaultAsyncWriteConcurrency)

// buffer pool
_bpool = internal.NewBufferPool()
)

func SetMaxConcurrencyForWriteQueue(num int64) {
if num > 0 {
_writeQueue.maxConcurrency = num
}
}
6 changes: 3 additions & 3 deletions reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,16 @@ func (c *Conn) readMessage() error {
return err
}
contentLength = int(binary.BigEndian.Uint16(c.fh[2:4]))
buf = bpool.Get(contentLength)
buf = _bpool.Get(contentLength)
case 127:
err := internal.ReadN(c.rbuf, c.fh[2:10], 8)
if err != nil {
return err
}
contentLength = int(binary.BigEndian.Uint64(c.fh[2:10]))
buf = bpool.Get(contentLength)
buf = _bpool.Get(contentLength)
default:
buf = bpool.Get(int(lengthCode))
buf = _bpool.Get(int(lengthCode))
}

if contentLength > c.config.MaxContentLength {
Expand Down
4 changes: 0 additions & 4 deletions updrader.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@ import (
"strings"
)

var (
bpool = internal.NewBufferPool()
)

const (
defaultCompressLevel = flate.BestSpeed // Best Speed
defaultMaxContentLength = 16 * 1024 * 1024 // 16MiB
Expand Down
43 changes: 43 additions & 0 deletions writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,46 @@ func (c *Conn) writeFrame(opcode Opcode, payload []byte, enableCompress bool) er
}
return c.wbuf.Flush()
}

// WriteMessageAsync write message async
func (c *Conn) WriteMessageAsync(opcode Opcode, payload []byte) {
c.wmq.Push(c, messageWrapper{
opcode: opcode,
payload: payload,
})
}

// write and clear messages
func doWriteAsync(conn *Conn) error {
if conn.wmq.Len() == 0 {
return nil
}

conn.wmu.Lock()
defer conn.wmu.Unlock()
return conn.wmq.Range(func(msg messageWrapper) error {
if atomic.LoadUint32(&conn.closed) == 1 {
return internal.ErrConnClosed
}

var enableCompress = conn.compressEnabled && msg.opcode.IsDataFrame() && len(msg.payload) >= conn.config.CompressionThreshold
if enableCompress {
compressedContent, err := conn.compressor.Compress(bytes.NewBuffer(msg.payload))
if err != nil {
return internal.NewError(internal.CloseInternalServerErr, err)
}
msg.payload = compressedContent.Bytes()
}

var header = frameHeader{}
var n = len(msg.payload)
var headerLength = header.GenerateServerHeader(true, enableCompress, msg.opcode, n)
if err := internal.WriteN(conn.wbuf, header[:headerLength], headerLength); err != nil {
return err
}
if err := internal.WriteN(conn.wbuf, msg.payload, n); err != nil {
return err
}
return conn.wbuf.Flush()
})
}

0 comments on commit a2bb4a9

Please sign in to comment.