Skip to content

Commit

Permalink
Curing anonymous functions
Browse files Browse the repository at this point in the history
  • Loading branch information
lxzan committed Nov 10, 2023
1 parent 37aaa6c commit e5db458
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 50 deletions.
14 changes: 7 additions & 7 deletions internal/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ func (p *BufferPool) Put(b *bytes.Buffer) {
if b == nil || b.Cap() == 0 {
return
}
if i := p.getIndex(uint32(b.Cap())); i > 0 {
p.pools[i].Put(b)
if index := p.getIndex(uint32(b.Cap())); index > 0 {
p.pools[index].Put(b)
}
}

Expand All @@ -53,12 +53,12 @@ func (p *BufferPool) Get(n int) *bytes.Buffer {
return bytes.NewBuffer(make([]byte, 0, n))
}

buf := p.pools[index].Get().(*bytes.Buffer)
if buf.Cap() < n {
buf.Grow(p.limits[index])
b := p.pools[index].Get().(*bytes.Buffer)
if b.Cap() < n {
b.Grow(p.limits[index])
}
buf.Reset()
return buf
b.Reset()
return b
}

func (p *BufferPool) getIndex(v uint32) int {
Expand Down
105 changes: 87 additions & 18 deletions task.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
package gws

import (
"bytes"
"sync"
)

type (
workerQueue struct {
mu sync.Mutex // 锁
q []asyncJob // 任务队列
q heap // 任务队列
maxConcurrency int32 // 最大并发
curConcurrency int32 // 当前并发
}

asyncJob func()
asyncJob struct {
serial int
socket *Conn
frame *bytes.Buffer
execute func(conn *Conn, buffer *bytes.Buffer)
}
)

// newWorkerQueue 创建一个任务队列
Expand All @@ -25,28 +31,19 @@ func newWorkerQueue(maxConcurrency int32) *workerQueue {
return c
}

func (c *workerQueue) pop() asyncJob {
if len(c.q) == 0 {
return nil
}
var job = c.q[0]
c.q = c.q[1:]
return job
}

// 获取一个任务
func (c *workerQueue) getJob(newJob asyncJob, delta int32) asyncJob {
func (c *workerQueue) getJob(newJob *asyncJob, delta int32) *asyncJob {
c.mu.Lock()
defer c.mu.Unlock()

if newJob != nil {
c.q = append(c.q, newJob)
c.q.Push(newJob)
}
c.curConcurrency += delta
if c.curConcurrency >= c.maxConcurrency {
return nil
}
var job = c.pop()
var job = c.q.Pop()
if job == nil {
return nil
}
Expand All @@ -55,15 +52,15 @@ func (c *workerQueue) getJob(newJob asyncJob, delta int32) asyncJob {
}

// 循环执行任务
func (c *workerQueue) do(job asyncJob) {
func (c *workerQueue) do(job *asyncJob) {
for job != nil {
job()
job.execute(job.socket, job.frame)
job = c.getJob(nil, -1)
}
}

// Push 追加任务, 有资源空闲的话会立即执行
func (c *workerQueue) Push(job asyncJob) {
func (c *workerQueue) Push(job *asyncJob) {
if nextJob := c.getJob(job, 0); nextJob != nil {
go c.do(nextJob)
}
Expand All @@ -78,8 +75,80 @@ func (c channel) done() { <-c }
func (c channel) Go(m *Message, f func(*Message) error) error {
c.add()
go func() {
f(m)
_ = f(m)
c.done()
}()
return nil
}

type heap struct {
data []*asyncJob
serial int
}

func (c *heap) next() int {
c.serial++
return c.serial
}

func (c *heap) less(i, j int) bool {
return c.data[i].serial < c.data[j].serial
}

func (c *heap) Len() int {
return len(c.data)
}

func (c *heap) swap(i, j int) {
c.data[i], c.data[j] = c.data[j], c.data[i]
}

func (c *heap) Push(v *asyncJob) {
if v.serial == 0 {
v.serial = c.next()
}
c.data = append(c.data, v)
c.up(c.Len() - 1)
}

func (c *heap) up(i int) {
var j = (i - 1) / 2
if i >= 1 && c.less(i, j) {
c.swap(i, j)
c.up(j)
}
}

func (c *heap) Pop() *asyncJob {
n := c.Len()
switch n {
case 0:
return nil
case 1:
v := c.data[0]
c.data = c.data[:0]
return v
default:
v := c.data[0]
c.data[0] = c.data[n-1]
c.data = c.data[:n-1]
c.down(0, n-1)
return v
}
}

func (c *heap) down(i, n int) {
var j = 2*i + 1
var k = 2*i + 2
var x = -1
if j < n {
x = j
}
if k < n && c.less(k, j) {
x = k
}
if x != -1 && c.less(x, i) {
c.swap(i, x)
c.down(x, n)
}
}
44 changes: 36 additions & 8 deletions task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package gws

import (
"bufio"
"bytes"
"fmt"
"net"
"sort"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -233,14 +235,14 @@ func TestTaskQueue(t *testing.T) {
listA = append(listA, i)

v := i
q.Push(func() {
q.Push(&asyncJob{execute: func(conn *Conn, buffer *bytes.Buffer) {
defer wg.Done()
var latency = time.Duration(internal.AlphabetNumeric.Intn(100)) * time.Microsecond
time.Sleep(latency)
mu.Lock()
listB = append(listB, v)
mu.Unlock()
})
}})
}
wg.Wait()
as.ElementsMatch(listA, listB)
Expand All @@ -253,11 +255,11 @@ func TestTaskQueue(t *testing.T) {
wg.Add(1000)
for i := int64(1); i <= 1000; i++ {
var tmp = i
w.Push(func() {
w.Push(&asyncJob{execute: func(conn *Conn, buffer *bytes.Buffer) {
time.Sleep(time.Millisecond)
atomic.AddInt64(&sum, tmp)
wg.Done()
})
}})
}
wg.Wait()
as.Equal(sum, int64(500500))
Expand All @@ -270,11 +272,11 @@ func TestTaskQueue(t *testing.T) {
wg.Add(1000)
for i := int64(1); i <= 1000; i++ {
var tmp = i
w.Push(func() {
w.Push(&asyncJob{execute: func(conn *Conn, buffer *bytes.Buffer) {
time.Sleep(time.Millisecond)
atomic.AddInt64(&sum, tmp)
wg.Done()
})
}})
}
wg.Wait()
as.Equal(sum, int64(500500))
Expand Down Expand Up @@ -348,16 +350,42 @@ func TestRQueue(t *testing.T) {
var serial = int64(0)
var done = make(chan struct{})
for i := 0; i < total; i++ {
q.Push(func() {
q.Push(&asyncJob{execute: func(conn *Conn, buffer *bytes.Buffer) {
x := atomic.AddInt64(&concurrency, 1)
assert.LessOrEqual(t, x, int64(limit))
time.Sleep(10 * time.Millisecond)
atomic.AddInt64(&concurrency, -1)
if atomic.AddInt64(&serial, 1) == total {
done <- struct{}{}
}
})
}})
}
<-done
})
}

func TestHeap_Sort(t *testing.T) {
var count = 1000
var list0 []int
var list1 []int
var h heap
for i := 0; i < count; i++ {
var v = internal.Numeric.Intn(count) + 1
list0 = append(list0, v)
h.Push(&asyncJob{serial: v})
}

sort.Ints(list0)
for h.Len() > 0 {
list1 = append(list1, h.Pop().serial)
}
for i := 0; i < count; i++ {
assert.Equal(t, list0[i], list1[i])
}
assert.Zero(t, h.Len())
}

func TestHeap_Pop(t *testing.T) {
var h = heap{}
assert.Nil(t, h.Pop())
}
39 changes: 22 additions & 17 deletions writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@ func (c *Conn) WriteString(s string) error {
return c.WriteMessage(OpcodeText, internal.StringToBytes(s))
}

func writeAsync(socket *Conn, buffer *bytes.Buffer) {
if socket.isClosed() {
return
}
err := internal.WriteN(socket.conn, buffer.Bytes())
binaryPool.Put(buffer)
socket.emitError(err)
}

// WriteAsync 异步非阻塞地写入消息
// Write messages asynchronously and non-blocking
func (c *Conn) WriteAsync(opcode Opcode, payload []byte) error {
Expand All @@ -48,15 +57,8 @@ func (c *Conn) WriteAsync(opcode Opcode, payload []byte) error {
c.emitError(err)
return err
}

c.writeQueue.Push(func() {
if c.isClosed() {
return
}
err = internal.WriteN(c.conn, frame.Bytes())
binaryPool.Put(frame)
c.emitError(err)
})
job := &asyncJob{socket: c, frame: frame, execute: writeAsync}
c.writeQueue.Push(job)
return nil
}

Expand Down Expand Up @@ -162,6 +164,15 @@ func NewBroadcaster(opcode Opcode, payload []byte) *Broadcaster {
return c
}

func (c *Broadcaster) writeAsync(socket *Conn, buffer *bytes.Buffer) {
if !socket.isClosed() {
socket.emitError(internal.WriteN(socket.conn, buffer.Bytes()))
}
if atomic.AddInt64(&c.state, -1) == 0 {
c.doClose()
}
}

// Broadcast 广播
// 向客户端发送广播消息
// Send a broadcast message to a client.
Expand All @@ -174,14 +185,8 @@ func (c *Broadcaster) Broadcast(socket *Conn) error {
}

atomic.AddInt64(&c.state, 1)
socket.writeQueue.Push(func() {
if !socket.isClosed() {
socket.emitError(internal.WriteN(socket.conn, msg.frame.Bytes()))
}
if atomic.AddInt64(&c.state, -1) == 0 {
c.doClose()
}
})
var job = &asyncJob{socket: socket, frame: msg.frame, execute: c.writeAsync}
socket.writeQueue.Push(job)
return nil
}

Expand Down

0 comments on commit e5db458

Please sign in to comment.