Skip to content

Commit

Permalink
implement tests
Browse files Browse the repository at this point in the history
  • Loading branch information
josebalius authored Aug 20, 2021
1 parent a76c9b7 commit 9c9f07a
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 11 deletions.
30 changes: 19 additions & 11 deletions netconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,36 @@ import (
"github.com/gorilla/websocket"
)

var _ net.Conn = (*NetConn)(nil)

var (
// ErrUnexpectedMsgType is returned when a reader is returned by the websocket
// connection that does not match the message type NetConn was created with
ErrUnexpectedMsgType = errors.New("unexpected message type")
)

type netConn struct {
type NetConn struct {
wsConn *websocket.Conn
msgType int

reader io.Reader
}

func NewNetConn(wsConn *websocket.Conn, msgType int) *netConn {
return &netConn{wsConn, msgType, nil}
// NewNetConn returns a NetConn pointer. It takes a gorilla websocket
// connection and a message type. See: https://github.com/gorilla/websocket/blob/v1.4.2/conn.go#L62
// for possible values
func NewNetConn(wsConn *websocket.Conn, msgType int) *NetConn {
return &NetConn{wsConn, msgType, nil}
}

func (c *netConn) Read(b []byte) (int, error) {
func (c *NetConn) Read(b []byte) (int, error) {
if c.reader == nil {
msgType, reader, err := c.wsConn.NextReader()
if err != nil {
return 0, fmt.Errorf("next reader: %w", err)
}

// err if we receive an unsupported message type
if msgType != c.msgType {
return 0, ErrUnexpectedMsgType
}
Expand All @@ -51,7 +59,7 @@ func (c *netConn) Read(b []byte) (int, error) {
return bytesRead, err
}

func (c *netConn) Write(b []byte) (int, error) {
func (c *NetConn) Write(b []byte) (int, error) {
nextWriter, err := c.wsConn.NextWriter(c.msgType)
if err != nil {
return 0, fmt.Errorf("next writer: %w", err)
Expand All @@ -65,30 +73,30 @@ func (c *netConn) Write(b []byte) (int, error) {
return bytesWritten, nextWriter.Close()
}

func (c *netConn) Close() error {
func (c *NetConn) Close() error {
return c.wsConn.Close()
}

func (c *netConn) LocalAddr() net.Addr {
func (c *NetConn) LocalAddr() net.Addr {
return c.wsConn.LocalAddr()
}

func (c *netConn) RemoteAddr() net.Addr {
func (c *NetConn) RemoteAddr() net.Addr {
return c.wsConn.RemoteAddr()
}

func (c *netConn) SetDeadline(t time.Time) error {
func (c *NetConn) SetDeadline(t time.Time) error {
if err := c.SetReadDeadline(t); err != nil {
return fmt.Errorf("set read deadline: %w", err)
}

return c.SetWriteDeadline(t)
}

func (c *netConn) SetReadDeadline(t time.Time) error {
func (c *NetConn) SetReadDeadline(t time.Time) error {
return c.wsConn.SetReadDeadline(t)
}

func (c *netConn) SetWriteDeadline(t time.Time) error {
func (c *NetConn) SetWriteDeadline(t time.Time) error {
return c.wsConn.SetWriteDeadline(t)
}
99 changes: 99 additions & 0 deletions netconn_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package gwebsocket

import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/gorilla/websocket"
)

func TestNewConn(t *testing.T) {
wsConn := websocket.Conn{}
nc := NewNetConn(&wsConn, websocket.TextMessage)
if nc == nil {
t.Error("netConn is nil")
}
}

func TestConn(t *testing.T) {
upgrader := websocket.Upgrader{}
ping := []byte("ping")
pong := []byte("pong")

done := make(chan error)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
c, err := upgrader.Upgrade(w, req, nil)
if err != nil {
done <- fmt.Errorf("upgrade req: %w", err)
return
}
defer c.Close()

nc := NewNetConn(c, websocket.TextMessage)
b := make([]byte, len(ping))
n, err := nc.Read(b)
if err != nil {
done <- fmt.Errorf("server read message: %w", err)
return
}
if n != len(ping) {
t.Errorf("reading ping message len unexpected, got '%v'", n)
}

msg := string(b)
if msg != "ping" {
done <- fmt.Errorf("incoming message is not expected value, got: '%v'", msg)
return
}

n, err = nc.Write(pong)
if err != nil {
done <- fmt.Errorf("write pong: %w", err)
return
}
if n != len(pong) {
done <- fmt.Errorf("pong message len unexpected, got: '%v'", n)
return
}

done <- nil
}))
defer server.Close()

c, _, err := websocket.DefaultDialer.Dial(strings.Replace(server.URL, "http", "ws", -1), nil)
if err != nil {
t.Errorf("dial: %w", err)
}
defer c.Close()

nc := NewNetConn(c, websocket.TextMessage)

n, err := nc.Write(ping)
if err != nil {
t.Errorf("write message: %w", err)
}
if n != len(ping) {
t.Errorf("ping message len unexpected, got: '%v'", n)
}

b := make([]byte, len(pong))
n, err = nc.Read(b)
if err != nil {
t.Errorf("read message: %w", err)
}
if n != len(pong) {
t.Errorf("reading pong message len unexpected, got '%v'", n)
}

msg := string(b)
if msg != "pong" {
t.Errorf("client incoming message unexpectd, got: '%v'", msg)
}

if err := <-done; err != nil {
t.Error(err)
}
}

0 comments on commit 9c9f07a

Please sign in to comment.