Skip to content

Commit

Permalink
Merge pull request #86 from lxzan/dev
Browse files Browse the repository at this point in the history
Fix: ReadMaxPayloadSize Limit
  • Loading branch information
lxzan authored Apr 23, 2024
2 parents be5b1fd + e743e93 commit bf09c37
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 9 deletions.
2 changes: 1 addition & 1 deletion benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func BenchmarkConn_ReadMessage(b *testing.B) {
config: config,
deflater: new(deflater),
}
conn1.deflater.initialize(false, conn1.pd)
conn1.deflater.initialize(false, conn1.pd, config.ReadMaxPayloadSize)
var buf, _ = conn1.genFrame(OpcodeText, internal.Bytes(githubData), false)

var reader = bytes.NewBuffer(buf.Bytes())
Expand Down
2 changes: 1 addition & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func (c *connector) handshake() (*Conn, *http.Response, error) {
readQueue: make(channel, c.option.ParallelGolimit),
}
if pd.Enabled {
socket.deflater.initialize(false, pd)
socket.deflater.initialize(false, pd, c.option.ReadMaxPayloadSize)
if pd.ServerContextTakeover {
socket.dpsWindow.initialize(nil, pd.ServerMaxWindowBits)
}
Expand Down
30 changes: 26 additions & 4 deletions compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ type deflaterPool struct {
pool []*deflater
}

func (c *deflaterPool) initialize(options PermessageDeflate) *deflaterPool {
func (c *deflaterPool) initialize(options PermessageDeflate, limit int) *deflaterPool {
c.num = uint64(options.PoolSize)
for i := uint64(0); i < c.num; i++ {
c.pool = append(c.pool, new(deflater).initialize(true, options))
c.pool = append(c.pool, new(deflater).initialize(true, options, limit))
}
return c
}
Expand All @@ -39,15 +39,19 @@ func (c *deflaterPool) Select() *deflater {

type deflater struct {
dpsLocker sync.Mutex
buf []byte
limit int
dpsBuffer *bytes.Buffer
dpsReader io.ReadCloser
cpsLocker sync.Mutex
cpsWriter *flate.Writer
}

func (c *deflater) initialize(isServer bool, options PermessageDeflate) *deflater {
func (c *deflater) initialize(isServer bool, options PermessageDeflate, limit int) *deflater {
c.dpsReader = flate.NewReader(nil)
c.dpsBuffer = bytes.NewBuffer(nil)
c.buf = make([]byte, 32*1024)
c.limit = limit
windowBits := internal.SelectValue(isServer, options.ServerMaxWindowBits, options.ClientMaxWindowBits)
if windowBits == 15 {
c.cpsWriter, _ = flate.NewWriter(nil, options.Level)
Expand All @@ -73,7 +77,8 @@ func (c *deflater) Decompress(src *bytes.Buffer, dict []byte) (*bytes.Buffer, er

_, _ = src.Write(flateTail)
c.resetFR(src, dict)
if _, err := c.dpsReader.(io.WriterTo).WriteTo(c.dpsBuffer); err != nil {
reader := limitReader(c.dpsReader, c.limit)
if _, err := io.CopyBuffer(c.dpsBuffer, reader, c.buf); err != nil {
return nil, err
}
var dst = binaryPool.Get(c.dpsBuffer.Len())
Expand Down Expand Up @@ -223,3 +228,20 @@ func permessageNegotiation(str string) PermessageDeflate {
options.ServerMaxWindowBits = internal.SelectValue(options.ServerMaxWindowBits < 8, 8, options.ServerMaxWindowBits)
return options
}

func limitReader(r io.Reader, limit int) io.Reader { return &limitedReader{R: r, M: limit} }

type limitedReader struct {
R io.Reader
N int
M int
}

func (c *limitedReader) Read(p []byte) (n int, err error) {
n, err = c.R.Read(p)
c.N += n
if c.N > c.M {
return n, internal.CloseMessageTooLarge
}
return
}
4 changes: 2 additions & 2 deletions task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ func serveWebSocket(
}
if compressEnabled {
if isServer {
socket.deflater = new(deflaterPool).initialize(pd).Select()
socket.deflater = new(deflaterPool).initialize(pd, config.ReadMaxPayloadSize).Select()
if pd.ServerContextTakeover {
socket.cpsWindow.initialize(config.cswPool, pd.ServerMaxWindowBits)
}
if pd.ClientContextTakeover {
socket.dpsWindow.initialize(config.dswPool, pd.ClientMaxWindowBits)
}
} else {
socket.deflater = new(deflater).initialize(false, pd)
socket.deflater = new(deflater).initialize(false, pd, config.ReadMaxPayloadSize)
}
}
return socket
Expand Down
2 changes: 1 addition & 1 deletion upgrader.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func NewUpgrader(eventHandler Event, option *ServerOption) *Upgrader {
deflaterPool: new(deflaterPool),
}
if u.option.PermessageDeflate.Enabled {
u.deflaterPool.initialize(u.option.PermessageDeflate)
u.deflaterPool.initialize(u.option.PermessageDeflate, option.ReadMaxPayloadSize)
}
return u
}
Expand Down
31 changes: 31 additions & 0 deletions writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gws
import (
"bufio"
"bytes"
"errors"
"io"
"net"
"net/http"
Expand Down Expand Up @@ -79,6 +80,36 @@ func TestWriteBigMessage(t *testing.T) {
var err = server.WriteMessage(OpcodeText, internal.AlphabetNumeric.Generate(128))
assert.Error(t, err)
})

t.Run("", func(t *testing.T) {
var wg = &sync.WaitGroup{}
wg.Add(1)
var serverHandler = new(webSocketMocker)
var clientHandler = new(webSocketMocker)
serverHandler.onClose = func(socket *Conn, err error) {
assert.True(t, errors.Is(err, internal.CloseMessageTooLarge))
wg.Done()
}
var serverOption = &ServerOption{
ReadMaxPayloadSize: 128,
PermessageDeflate: PermessageDeflate{Enabled: true, Threshold: 1},
}
var clientOption = &ClientOption{
ReadMaxPayloadSize: 128 * 1024,
PermessageDeflate: PermessageDeflate{Enabled: true, Threshold: 1},
}
server, client := newPeer(serverHandler, serverOption, clientHandler, clientOption)
go server.ReadLoop()
go client.ReadLoop()

var buf = bytes.NewBufferString("")
for i := 0; i < 64*1024; i++ {
buf.WriteString("a")
}
var err = client.WriteMessage(OpcodeText, buf.Bytes())
assert.NoError(t, err)
wg.Wait()
})
}

func TestWriteClose(t *testing.T) {
Expand Down

0 comments on commit bf09c37

Please sign in to comment.