diff --git a/accept.go b/accept.go index f45fdd0b..6227a6a3 100644 --- a/accept.go +++ b/accept.go @@ -1,6 +1,3 @@ -//go:build !js -// +build !js - package websocket import ( @@ -167,7 +164,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con b, _ := brw.Reader.Peek(brw.Reader.Buffered()) brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) - return newConn(connConfig{ + return &Conn{newConn(connConfig{ subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), rwc: netConn, client: false, @@ -178,7 +175,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con br: brw.Reader, bw: brw.Writer, - }), nil + })}, nil } func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) { diff --git a/close.go b/close_std.go similarity index 94% rename from close.go rename to close_std.go index f94951dc..4219156c 100644 --- a/close.go +++ b/close_std.go @@ -1,6 +1,3 @@ -//go:build !js -// +build !js - package websocket import ( @@ -97,7 +94,7 @@ func CloseStatus(err error) StatusCode { // // Close will unblock all goroutines interacting with the connection once // complete. -func (c *Conn) Close(code StatusCode, reason string) (err error) { +func (c *StdConn) Close(code StatusCode, reason string) (err error) { defer errd.Wrap(&err, "failed to close WebSocket") if c.casClosing() { @@ -130,7 +127,7 @@ func (c *Conn) Close(code StatusCode, reason string) (err error) { // CloseNow closes the WebSocket connection without attempting a close handshake. // Use when you do not want the overhead of the close handshake. -func (c *Conn) CloseNow() (err error) { +func (c *StdConn) CloseNow() (err error) { defer errd.Wrap(&err, "failed to immediately close WebSocket") if c.casClosing() { @@ -155,7 +152,7 @@ func (c *Conn) CloseNow() (err error) { return err } -func (c *Conn) closeHandshake(code StatusCode, reason string) error { +func (c *StdConn) closeHandshake(code StatusCode, reason string) error { err := c.writeClose(code, reason) if err != nil { return err @@ -168,7 +165,7 @@ func (c *Conn) closeHandshake(code StatusCode, reason string) error { return nil } -func (c *Conn) writeClose(code StatusCode, reason string) error { +func (c *StdConn) writeClose(code StatusCode, reason string) error { ce := CloseError{ Code: code, Reason: reason, @@ -196,7 +193,7 @@ func (c *Conn) writeClose(code StatusCode, reason string) error { return nil } -func (c *Conn) waitCloseHandshake() error { +func (c *StdConn) waitCloseHandshake() error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() @@ -228,7 +225,7 @@ func (c *Conn) waitCloseHandshake() error { } } -func (c *Conn) waitGoroutines() error { +func (c *StdConn) waitGoroutines() error { t := time.NewTimer(time.Second * 15) defer t.Stop() @@ -328,11 +325,11 @@ func (ce CloseError) bytesErr() ([]byte, error) { return buf, nil } -func (c *Conn) casClosing() bool { +func (c *StdConn) casClosing() bool { return c.closing.Swap(true) } -func (c *Conn) isClosed() bool { +func (c *StdConn) isClosed() bool { select { case <-c.closed: return true diff --git a/compress.go b/compress.go index 1f3adcfb..fa4ff1f2 100644 --- a/compress.go +++ b/compress.go @@ -1,6 +1,3 @@ -//go:build !js -// +build !js - package websocket import ( diff --git a/conn.go b/conn.go index 42fe89fe..4220b572 100644 --- a/conn.go +++ b/conn.go @@ -1,307 +1,31 @@ -//go:build !js -// +build !js - package websocket import ( - "bufio" "context" - "fmt" "io" - "net" - "runtime" - "strconv" - "sync" - "sync/atomic" -) - -// MessageType represents the type of a WebSocket message. -// See https://tools.ietf.org/html/rfc6455#section-5.6 -type MessageType int - -// MessageType constants. -const ( - // MessageText is for UTF-8 encoded text messages like JSON. - MessageText MessageType = iota + 1 - // MessageBinary is for binary messages like protobufs. - MessageBinary ) -// Conn represents a WebSocket connection. -// All methods may be called concurrently except for Reader and Read. -// -// You must always read from the connection. Otherwise control -// frames will not be handled. See Reader and CloseRead. -// -// Be sure to call Close on the connection when you -// are finished with it to release associated resources. -// -// On any error from any method, the connection is closed -// with an appropriate reason. -// -// This applies to context expirations as well unfortunately. -// See https://github.com/nhooyr/websocket/issues/242#issuecomment-633182220 type Conn struct { - noCopy noCopy - - subprotocol string - rwc io.ReadWriteCloser - client bool - copts *compressionOptions - flateThreshold int - br *bufio.Reader - bw *bufio.Writer - - readTimeout chan context.Context - writeTimeout chan context.Context - timeoutLoopDone chan struct{} - - // Read state. - readMu *mu - readHeaderBuf [8]byte - readControlBuf [maxControlPayload]byte - msgReader *msgReader - - // Write state. - msgWriter *msgWriter - writeFrameMu *mu - writeBuf []byte - writeHeaderBuf [8]byte - writeHeader header - - // Close handshake state. - closeStateMu sync.RWMutex - closeReceivedErr error - closeSentErr error - - // CloseRead state. - closeReadMu sync.Mutex - closeReadCtx context.Context - closeReadDone chan struct{} - - closing atomic.Bool - closeMu sync.Mutex // Protects following. - closed chan struct{} - - pingCounter atomic.Int64 - activePingsMu sync.Mutex - activePings map[string]chan<- struct{} - onPingReceived func(context.Context, []byte) bool - onPongReceived func(context.Context, []byte) -} - -type connConfig struct { - subprotocol string - rwc io.ReadWriteCloser - client bool - copts *compressionOptions - flateThreshold int - onPingReceived func(context.Context, []byte) bool - onPongReceived func(context.Context, []byte) - - br *bufio.Reader - bw *bufio.Writer -} - -func newConn(cfg connConfig) *Conn { - c := &Conn{ - subprotocol: cfg.subprotocol, - rwc: cfg.rwc, - client: cfg.client, - copts: cfg.copts, - flateThreshold: cfg.flateThreshold, - - br: cfg.br, - bw: cfg.bw, - - readTimeout: make(chan context.Context), - writeTimeout: make(chan context.Context), - timeoutLoopDone: make(chan struct{}), - - closed: make(chan struct{}), - activePings: make(map[string]chan<- struct{}), - onPingReceived: cfg.onPingReceived, - onPongReceived: cfg.onPongReceived, - } - - c.readMu = newMu(c) - c.writeFrameMu = newMu(c) - - c.msgReader = newMsgReader(c) - - c.msgWriter = newMsgWriter(c) - if c.client { - c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) - } - - if c.flate() && c.flateThreshold == 0 { - c.flateThreshold = 128 - if !c.msgWriter.flateContextTakeover() { - c.flateThreshold = 512 - } - } - - runtime.SetFinalizer(c, func(c *Conn) { - c.close() - }) - - go c.timeoutLoop() - - return c -} - -// Subprotocol returns the negotiated subprotocol. -// An empty string means the default protocol. -func (c *Conn) Subprotocol() string { - return c.subprotocol -} - -func (c *Conn) close() error { - c.closeMu.Lock() - defer c.closeMu.Unlock() - - if c.isClosed() { - return net.ErrClosed - } - runtime.SetFinalizer(c, nil) - close(c.closed) - - // Have to close after c.closed is closed to ensure any goroutine that wakes up - // from the connection being closed also sees that c.closed is closed and returns - // closeErr. - err := c.rwc.Close() - // With the close of rwc, these become safe to close. - c.msgWriter.close() - c.msgReader.close() - return err -} - -func (c *Conn) timeoutLoop() { - defer close(c.timeoutLoopDone) - - readCtx := context.Background() - writeCtx := context.Background() - - for { - select { - case <-c.closed: - return - - case writeCtx = <-c.writeTimeout: - case readCtx = <-c.readTimeout: - - case <-readCtx.Done(): - c.close() - return - case <-writeCtx.Done(): - c.close() - return - } - } -} - -func (c *Conn) flate() bool { - return c.copts != nil -} - -// Ping sends a ping to the peer and waits for a pong. -// Use this to measure latency or ensure the peer is responsive. -// Ping must be called concurrently with Reader as it does -// not read from the connection but instead waits for a Reader call -// to read the pong. -// -// TCP Keepalives should suffice for most use cases. -func (c *Conn) Ping(ctx context.Context) error { - p := c.pingCounter.Add(1) - - err := c.ping(ctx, strconv.FormatInt(p, 10)) - if err != nil { - return fmt.Errorf("failed to ping: %w", err) - } - return nil -} - -func (c *Conn) ping(ctx context.Context, p string) error { - pong := make(chan struct{}, 1) - - c.activePingsMu.Lock() - c.activePings[p] = pong - c.activePingsMu.Unlock() - - defer func() { - c.activePingsMu.Lock() - delete(c.activePings, p) - c.activePingsMu.Unlock() - }() - - err := c.writeControl(ctx, opPing, []byte(p)) - if err != nil { - return err - } - - select { - case <-c.closed: - return net.ErrClosed - case <-ctx.Done(): - return fmt.Errorf("failed to wait for pong: %w", ctx.Err()) - case <-pong: - return nil - } -} - -type mu struct { - c *Conn - ch chan struct{} + Stream } -func newMu(c *Conn) *mu { - return &mu{ - c: c, - ch: make(chan struct{}, 1), - } +type Stream interface { + Close(code StatusCode, reason string) (err error) + CloseNow() (err error) + CloseRead(ctx context.Context) context.Context + Ping(ctx context.Context) error + Read(ctx context.Context) (MessageType, []byte, error) + Reader(ctx context.Context) (MessageType, io.Reader, error) + SetReadLimit(n int64) + Subprotocol() string + Write(ctx context.Context, typ MessageType, p []byte) error + Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) + conn() any + newMu() muLocker } -func (m *mu) forceLock() { - m.ch <- struct{}{} +type muLocker interface { + forceLock() + tryLock() bool + unlock() } - -func (m *mu) tryLock() bool { - select { - case m.ch <- struct{}{}: - return true - default: - return false - } -} - -func (m *mu) lock(ctx context.Context) error { - select { - case <-m.c.closed: - return net.ErrClosed - case <-ctx.Done(): - return fmt.Errorf("failed to acquire lock: %w", ctx.Err()) - case m.ch <- struct{}{}: - // To make sure the connection is certainly alive. - // As it's possible the send on m.ch was selected - // over the receive on closed. - select { - case <-m.c.closed: - // Make sure to release. - m.unlock() - return net.ErrClosed - default: - } - return nil - } -} - -func (m *mu) unlock() { - select { - case <-m.ch: - default: - } -} - -type noCopy struct{} - -func (*noCopy) Lock() {} diff --git a/conn_std.go b/conn_std.go new file mode 100644 index 00000000..80c62ec3 --- /dev/null +++ b/conn_std.go @@ -0,0 +1,314 @@ +package websocket + +import ( + "bufio" + "context" + "fmt" + "io" + "net" + "runtime" + "strconv" + "sync" + "sync/atomic" +) + +// MessageType represents the type of a WebSocket message. +// See https://tools.ietf.org/html/rfc6455#section-5.6 +type MessageType int + +// MessageType constants. +const ( + // MessageText is for UTF-8 encoded text messages like JSON. + MessageText MessageType = iota + 1 + // MessageBinary is for binary messages like protobufs. + MessageBinary +) + +// StdConn represents a WebSocket connection. +// All methods may be called concurrently except for Reader and Read. +// +// You must always read from the connection. Otherwise control +// frames will not be handled. See Reader and CloseRead. +// +// Be sure to call Close on the connection when you +// are finished with it to release associated resources. +// +// On any error from any method, the connection is closed +// with an appropriate reason. +// +// This applies to context expirations as well unfortunately. +// See https://github.com/nhooyr/websocket/issues/242#issuecomment-633182220 +type StdConn struct { + noCopy noCopy + + subprotocol string + rwc io.ReadWriteCloser + client bool + copts *compressionOptions + flateThreshold int + br *bufio.Reader + bw *bufio.Writer + + readTimeout chan context.Context + writeTimeout chan context.Context + timeoutLoopDone chan struct{} + + // Read state. + readMu *stdMu + readHeaderBuf [8]byte + readControlBuf [maxControlPayload]byte + msgReader *msgReader + + // Write state. + msgWriter *msgWriter + writeFrameMu *stdMu + writeBuf []byte + writeHeaderBuf [8]byte + writeHeader header + + // Close handshake state. + closeStateMu sync.RWMutex + closeReceivedErr error + closeSentErr error + + // CloseRead state. + closeReadMu sync.Mutex + closeReadCtx context.Context + closeReadDone chan struct{} + + closing atomic.Bool + closeMu sync.Mutex // Protects following. + closed chan struct{} + + pingCounter atomic.Int64 + activePingsMu sync.Mutex + activePings map[string]chan<- struct{} + onPingReceived func(context.Context, []byte) bool + onPongReceived func(context.Context, []byte) +} + +var _ Stream = (*StdConn)(nil) + +type connConfig struct { + subprotocol string + rwc io.ReadWriteCloser + client bool + copts *compressionOptions + flateThreshold int + onPingReceived func(context.Context, []byte) bool + onPongReceived func(context.Context, []byte) + + br *bufio.Reader + bw *bufio.Writer +} + +func newConn(cfg connConfig) *StdConn { + c := &StdConn{ + subprotocol: cfg.subprotocol, + rwc: cfg.rwc, + client: cfg.client, + copts: cfg.copts, + flateThreshold: cfg.flateThreshold, + + br: cfg.br, + bw: cfg.bw, + + readTimeout: make(chan context.Context), + writeTimeout: make(chan context.Context), + timeoutLoopDone: make(chan struct{}), + + closed: make(chan struct{}), + activePings: make(map[string]chan<- struct{}), + onPingReceived: cfg.onPingReceived, + onPongReceived: cfg.onPongReceived, + } + + c.readMu = newStdMu(c) + c.writeFrameMu = newStdMu(c) + + c.msgReader = newMsgReader(c) + + c.msgWriter = newMsgWriter(c) + if c.client { + c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) + } + + if c.flate() && c.flateThreshold == 0 { + c.flateThreshold = 128 + if !c.msgWriter.flateContextTakeover() { + c.flateThreshold = 512 + } + } + + runtime.SetFinalizer(c, func(c *StdConn) { + c.close() + }) + + go c.timeoutLoop() + + return c +} + +// Subprotocol returns the negotiated subprotocol. +// An empty string means the default protocol. +func (c *StdConn) Subprotocol() string { + return c.subprotocol +} + +func (c *StdConn) conn() any { + return c.rwc +} + +func (c *StdConn) close() error { + c.closeMu.Lock() + defer c.closeMu.Unlock() + + if c.isClosed() { + return net.ErrClosed + } + runtime.SetFinalizer(c, nil) + close(c.closed) + + // Have to close after c.closed is closed to ensure any goroutine that wakes up + // from the connection being closed also sees that c.closed is closed and returns + // closeErr. + err := c.rwc.Close() + // With the close of rwc, these become safe to close. + c.msgWriter.close() + c.msgReader.close() + return err +} + +func (c *StdConn) timeoutLoop() { + defer close(c.timeoutLoopDone) + + readCtx := context.Background() + writeCtx := context.Background() + + for { + select { + case <-c.closed: + return + + case writeCtx = <-c.writeTimeout: + case readCtx = <-c.readTimeout: + + case <-readCtx.Done(): + c.close() + return + case <-writeCtx.Done(): + c.close() + return + } + } +} + +func (c *StdConn) flate() bool { + return c.copts != nil +} + +// Ping sends a ping to the peer and waits for a pong. +// Use this to measure latency or ensure the peer is responsive. +// Ping must be called concurrently with Reader as it does +// not read from the connection but instead waits for a Reader call +// to read the pong. +// +// TCP Keepalives should suffice for most use cases. +func (c *StdConn) Ping(ctx context.Context) error { + p := c.pingCounter.Add(1) + + err := c.ping(ctx, strconv.FormatInt(p, 10)) + if err != nil { + return fmt.Errorf("failed to ping: %w", err) + } + return nil +} + +func (c *StdConn) ping(ctx context.Context, p string) error { + pong := make(chan struct{}, 1) + + c.activePingsMu.Lock() + c.activePings[p] = pong + c.activePingsMu.Unlock() + + defer func() { + c.activePingsMu.Lock() + delete(c.activePings, p) + c.activePingsMu.Unlock() + }() + + err := c.writeControl(ctx, opPing, []byte(p)) + if err != nil { + return err + } + + select { + case <-c.closed: + return net.ErrClosed + case <-ctx.Done(): + return fmt.Errorf("failed to wait for pong: %w", ctx.Err()) + case <-pong: + return nil + } +} + +type stdMu struct { + c *StdConn + ch chan struct{} +} + +func newStdMu(c *StdConn) *stdMu { + return &stdMu{ + c: c, + ch: make(chan struct{}, 1), + } +} + +func (c *StdConn) newMu() muLocker { + return newStdMu(c) +} + +func (m *stdMu) forceLock() { + m.ch <- struct{}{} +} + +func (m *stdMu) tryLock() bool { + select { + case m.ch <- struct{}{}: + return true + default: + return false + } +} + +func (m *stdMu) lock(ctx context.Context) error { + select { + case <-m.c.closed: + return net.ErrClosed + case <-ctx.Done(): + return fmt.Errorf("failed to acquire lock: %w", ctx.Err()) + case m.ch <- struct{}{}: + // To make sure the connection is certainly alive. + // As it's possible the send on m.ch was selected + // over the receive on closed. + select { + case <-m.c.closed: + // Make sure to release. + m.unlock() + return net.ErrClosed + default: + } + return nil + } +} + +func (m *stdMu) unlock() { + select { + case <-m.ch: + default: + } +} + +type noCopy struct{} + +func (*noCopy) Lock() {} diff --git a/conn_test.go b/conn_test.go index 45bb75be..561ca6fb 100644 --- a/conn_test.go +++ b/conn_test.go @@ -567,8 +567,8 @@ func BenchmarkConn(b *testing.B) { bb.goEchoLoop(c2) - bytesWritten := c1.RecordBytesWritten() - bytesRead := c1.RecordBytesRead() + bytesWritten := c1.Stream.(*websocket.StdConn).RecordBytesWritten() + bytesRead := c1.Stream.(*websocket.StdConn).RecordBytesRead() msg := []byte(strings.Repeat("1234", 128)) readBuf := make([]byte, len(msg)) diff --git a/dial.go b/dial.go index 0b11ecbb..33ed4d9b 100644 --- a/dial.go +++ b/dial.go @@ -1,6 +1,3 @@ -//go:build !js -// +build !js - package websocket import ( @@ -118,11 +115,8 @@ func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context // See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861 // // URLs with http/https schemes will work and are interpreted as ws/wss. -func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { - return dial(ctx, u, opts, nil) -} -func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) { +func dialStd(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) { defer errd.Wrap(&err, "failed to WebSocket dial") var cancel context.CancelFunc @@ -173,7 +167,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody) } - return newConn(connConfig{ + return &Conn{newConn(connConfig{ subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), rwc: rwc, client: true, @@ -183,7 +177,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( onPongReceived: opts.OnPongReceived, br: getBufioReader(rwc), bw: getBufioWriter(rwc), - }), resp, nil + })}, resp, nil } func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) { diff --git a/dial_std.go b/dial_std.go new file mode 100644 index 00000000..d6785c3d --- /dev/null +++ b/dial_std.go @@ -0,0 +1,13 @@ +//go:build !js +// +build !js + +package websocket + +import ( + "context" + "net/http" +) + +func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { + return dialStd(ctx, u, opts, nil) +} diff --git a/export_test.go b/export_test.go index d3443991..5cc26988 100644 --- a/export_test.go +++ b/export_test.go @@ -9,7 +9,7 @@ import ( "github.com/coder/websocket/internal/util" ) -func (c *Conn) RecordBytesWritten() *int { +func (c *StdConn) RecordBytesWritten() *int { var bytesWritten int c.bw.Reset(util.WriterFunc(func(p []byte) (int, error) { bytesWritten += len(p) @@ -18,7 +18,7 @@ func (c *Conn) RecordBytesWritten() *int { return &bytesWritten } -func (c *Conn) RecordBytesRead() *int { +func (c *StdConn) RecordBytesRead() *int { var bytesRead int c.br.Reset(util.ReaderFunc(func(p []byte) (int, error) { n, err := c.rwc.Read(p) @@ -30,7 +30,7 @@ func (c *Conn) RecordBytesRead() *int { var ErrClosed = net.ErrClosed -var ExportedDial = dial +var ExportedDial = dialStd var SecWebSocketAccept = secWebSocketAccept var SecWebSocketKey = secWebSocketKey var VerifyServerResponse = verifyServerResponse diff --git a/frame.go b/frame.go index e7ab76be..ab07ccff 100644 --- a/frame.go +++ b/frame.go @@ -1,5 +1,3 @@ -//go:build !js - package websocket import ( diff --git a/hijack.go b/hijack.go index 9cce45ca..79792d2a 100644 --- a/hijack.go +++ b/hijack.go @@ -1,5 +1,3 @@ -//go:build !js - package websocket import ( diff --git a/mask_go.go b/mask_go.go index b29435e9..b0b0a054 100644 --- a/mask_go.go +++ b/mask_go.go @@ -1,4 +1,4 @@ -//go:build !amd64 && !arm64 && !js +//go:build !amd64 && !arm64 package websocket diff --git a/netconn.go b/netconn.go index b118e4d3..d21844d2 100644 --- a/netconn.go +++ b/netconn.go @@ -51,8 +51,8 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { nc := &netConn{ c: c, msgType: msgType, - readMu: newMu(c), - writeMu: newMu(c), + readMu: c.newMu(), + writeMu: c.newMu(), } nc.writeCtx, nc.writeCancel = context.WithCancel(ctx) @@ -98,13 +98,13 @@ type netConn struct { msgType MessageType writeTimer *time.Timer - writeMu *mu + writeMu muLocker writeExpired atomic.Int64 writeCtx context.Context writeCancel context.CancelFunc readTimer *time.Timer - readMu *mu + readMu muLocker readExpired atomic.Int64 readCtx context.Context readCancel context.CancelFunc diff --git a/netconn_js.go b/netconn_js.go deleted file mode 100644 index ccc8c89f..00000000 --- a/netconn_js.go +++ /dev/null @@ -1,11 +0,0 @@ -package websocket - -import "net" - -func (nc *netConn) RemoteAddr() net.Addr { - return websocketAddr{} -} - -func (nc *netConn) LocalAddr() net.Addr { - return websocketAddr{} -} diff --git a/netconn_notjs.go b/netconn_std.go similarity index 67% rename from netconn_notjs.go rename to netconn_std.go index f3eb0d66..0f1bb94a 100644 --- a/netconn_notjs.go +++ b/netconn_std.go @@ -1,19 +1,16 @@ -//go:build !js -// +build !js - package websocket import "net" func (nc *netConn) RemoteAddr() net.Addr { - if unc, ok := nc.c.rwc.(net.Conn); ok { + if unc, ok := nc.c.conn().(net.Conn); ok { return unc.RemoteAddr() } return websocketAddr{} } func (nc *netConn) LocalAddr() net.Addr { - if unc, ok := nc.c.rwc.(net.Conn); ok { + if unc, ok := nc.c.conn().(net.Conn); ok { return unc.LocalAddr() } return websocketAddr{} diff --git a/read.go b/read_std.go similarity index 92% rename from read.go rename to read_std.go index 2db22435..4491fe16 100644 --- a/read.go +++ b/read_std.go @@ -1,6 +1,3 @@ -//go:build !js -// +build !js - package websocket import ( @@ -33,13 +30,13 @@ import ( // use time.AfterFunc to cancel the context passed in. // See https://github.com/nhooyr/websocket/issues/87#issue-451703332 // Most users should not need this. -func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { +func (c *StdConn) Reader(ctx context.Context) (MessageType, io.Reader, error) { return c.reader(ctx) } // Read is a convenience method around Reader to read a single message // from the connection. -func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { +func (c *StdConn) Read(ctx context.Context) (MessageType, []byte, error) { typ, r, err := c.Reader(ctx) if err != nil { return 0, nil, err @@ -62,7 +59,7 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { // frames are responded to. This means c.Ping and c.Close will still work as expected. // // This function is idempotent. -func (c *Conn) CloseRead(ctx context.Context) context.Context { +func (c *StdConn) CloseRead(ctx context.Context) context.Context { c.closeReadMu.Lock() ctx2 := c.closeReadCtx if ctx2 != nil { @@ -94,7 +91,7 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context { // When the limit is hit, the connection will be closed with StatusMessageTooBig. // // Set to -1 to disable. -func (c *Conn) SetReadLimit(n int64) { +func (c *StdConn) SetReadLimit(n int64) { if n >= 0 { // We read one more byte than the limit in case // there is a fin frame that needs to be read. @@ -106,7 +103,7 @@ func (c *Conn) SetReadLimit(n int64) { const defaultReadLimit = 32768 -func newMsgReader(c *Conn) *msgReader { +func newMsgReader(c *StdConn) *msgReader { mr := &msgReader{ c: c, fin: true, @@ -168,7 +165,7 @@ func (mr *msgReader) flateContextTakeover() bool { return !mr.c.copts.clientNoContextTakeover } -func (c *Conn) readRSV1Illegal(h header) bool { +func (c *StdConn) readRSV1Illegal(h header) bool { // If compression is disabled, rsv1 is illegal. if !c.flate() { return true @@ -180,7 +177,7 @@ func (c *Conn) readRSV1Illegal(h header) bool { return false } -func (c *Conn) readLoop(ctx context.Context) (header, error) { +func (c *StdConn) readLoop(ctx context.Context) (header, error) { for { h, err := c.readFrameHeader(ctx) if err != nil { @@ -223,7 +220,7 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) { // an error depending on if the connection closed or the context timed // out during use. Typically the referenced error is a named return // variable of the function calling this method. -func (c *Conn) prepareRead(ctx context.Context, err *error) (func(), error) { +func (c *StdConn) prepareRead(ctx context.Context, err *error) (func(), error) { select { case <-c.closed: return nil, net.ErrClosed @@ -254,7 +251,7 @@ func (c *Conn) prepareRead(ctx context.Context, err *error) (func(), error) { return done, nil } -func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) { +func (c *StdConn) readFrameHeader(ctx context.Context) (_ header, err error) { readDone, err := c.prepareRead(ctx, &err) if err != nil { return header{}, err @@ -269,7 +266,7 @@ func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) { return h, nil } -func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) { +func (c *StdConn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) { readDone, err := c.prepareRead(ctx, &err) if err != nil { return 0, err @@ -284,7 +281,7 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error return n, err } -func (c *Conn) handleControl(ctx context.Context, h header) (err error) { +func (c *StdConn) handleControl(ctx context.Context, h header) (err error) { if h.payloadLength < 0 || h.payloadLength > maxControlPayload { err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength) c.writeError(StatusProtocolError, err) @@ -363,7 +360,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { return err } -func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) { +func (c *StdConn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) { defer errd.Wrap(&err, "failed to get reader") err = c.readMu.lock(ctx) @@ -393,7 +390,7 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro } type msgReader struct { - c *Conn + c *StdConn ctx context.Context flate bool @@ -495,13 +492,13 @@ func (mr *msgReader) read(p []byte) (int, error) { } type limitReader struct { - c *Conn + c *StdConn r io.Reader limit atomic.Int64 n int64 } -func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader { +func newLimitReader(c *StdConn, r io.Reader, limit int64) *limitReader { lr := &limitReader{ c: c, } diff --git a/write.go b/write_std.go similarity index 89% rename from write.go rename to write_std.go index 7324de74..6d251917 100644 --- a/write.go +++ b/write_std.go @@ -1,6 +1,3 @@ -//go:build !js -// +build !js - package websocket import ( @@ -26,7 +23,7 @@ import ( // // Only one writer can be open at a time, multiple calls will block until the previous writer // is closed. -func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { +func (c *StdConn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { w, err := c.writer(ctx, typ) if err != nil { return nil, fmt.Errorf("failed to get writer: %w", err) @@ -40,7 +37,7 @@ func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, err // // If compression is disabled or the compression threshold is not met, then it // will write the message in a single frame. -func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { +func (c *StdConn) Write(ctx context.Context, typ MessageType, p []byte) error { _, err := c.write(ctx, typ, p) if err != nil { return fmt.Errorf("failed to write msg: %w", err) @@ -49,10 +46,10 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { } type msgWriter struct { - c *Conn + c *StdConn - mu *mu - writeMu *mu + mu *stdMu + writeMu *stdMu closed bool ctx context.Context @@ -63,11 +60,11 @@ type msgWriter struct { flateWriter *flate.Writer } -func newMsgWriter(c *Conn) *msgWriter { +func newMsgWriter(c *StdConn) *msgWriter { mw := &msgWriter{ c: c, - mu: newMu(c), - writeMu: newMu(c), + mu: newStdMu(c), + writeMu: newStdMu(c), } return mw } @@ -92,7 +89,7 @@ func (mw *msgWriter) flateContextTakeover() bool { return !mw.c.copts.serverNoContextTakeover } -func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { +func (c *StdConn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { err := c.msgWriter.reset(ctx, typ) if err != nil { return nil, err @@ -100,7 +97,7 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err return c.msgWriter, nil } -func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) { +func (c *StdConn) write(ctx context.Context, typ MessageType, p []byte) (int, error) { mw, err := c.writer(ctx, typ) if err != nil { return 0, err @@ -229,7 +226,7 @@ func (mw *msgWriter) close() { mw.putFlateWriter() } -func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { +func (c *StdConn) writeControl(ctx context.Context, opcode opcode, p []byte) error { ctx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel() @@ -241,7 +238,7 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error } // writeFrame handles all writes to the connection. -func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) { +func (c *StdConn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) { err = c.writeFrameMu.lock(ctx) if err != nil { return 0, err @@ -331,7 +328,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco return n, nil } -func (c *Conn) writeFramePayload(p []byte) (n int, err error) { +func (c *StdConn) writeFramePayload(p []byte) (n int, err error) { defer errd.Wrap(&err, "failed to write frame payload") if !c.writeHeader.masked { @@ -387,6 +384,6 @@ func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte { return writeBuf } -func (c *Conn) writeError(code StatusCode, err error) { +func (c *StdConn) writeError(code StatusCode, err error) { c.writeClose(code, err.Error()) } diff --git a/ws_js.go b/ws_js.go index 5e324c47..abbaa839 100644 --- a/ws_js.go +++ b/ws_js.go @@ -19,28 +19,8 @@ import ( "github.com/coder/websocket/internal/wsjs" ) -// opcode represents a WebSocket opcode. -type opcode int - -// https://tools.ietf.org/html/rfc6455#section-11.8. -const ( - opContinuation opcode = iota - opText - opBinary - // 3 - 7 are reserved for further non-control frames. - _ - _ - _ - _ - _ - opClose - opPing - opPong - // 11-16 are reserved for further control frames. -) - -// Conn provides a wrapper around the browser WebSocket API. -type Conn struct { +// BrowserConn provides a wrapper around the browser WebSocket API. +type BrowserConn struct { noCopy noCopy ws wsjs.WebSocket @@ -66,7 +46,9 @@ type Conn struct { readBuf []wsjs.MessageEvent } -func (c *Conn) close(err error, wasClean bool) { +var _ Stream = (*BrowserConn)(nil) + +func (c *BrowserConn) close(err error, wasClean bool) { c.closeOnce.Do(func() { runtime.SetFinalizer(c, nil) @@ -79,7 +61,7 @@ func (c *Conn) close(err error, wasClean bool) { }) } -func (c *Conn) init() { +func (c *BrowserConn) init() { c.closed = make(chan struct{}) c.readSignal = make(chan struct{}, 1) @@ -118,19 +100,19 @@ func (c *Conn) init() { } }) - runtime.SetFinalizer(c, func(c *Conn) { + runtime.SetFinalizer(c, func(c *BrowserConn) { c.setCloseErr(errors.New("connection garbage collected")) c.closeWithInternal() }) } -func (c *Conn) closeWithInternal() { +func (c *BrowserConn) closeWithInternal() { c.Close(StatusInternalError, "something went wrong") } // Read attempts to read a message from the connection. // The maximum time spent waiting is bounded by the context. -func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { +func (c *BrowserConn) Read(ctx context.Context) (MessageType, []byte, error) { c.closeReadMu.Lock() closedRead := c.closeReadCtx != nil c.closeReadMu.Unlock() @@ -151,7 +133,7 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { return typ, p, nil } -func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) { +func (c *BrowserConn) read(ctx context.Context) (MessageType, []byte, error) { select { case <-ctx.Done(): c.Close(StatusPolicyViolation, "read timed out") @@ -189,13 +171,13 @@ func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) { } // Ping is mocked out for Wasm. -func (c *Conn) Ping(ctx context.Context) error { +func (c *BrowserConn) Ping(ctx context.Context) error { return nil } // Write writes a message of the given type to the connection. // Always non blocking. -func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { +func (c *BrowserConn) Write(ctx context.Context, typ MessageType, p []byte) error { err := c.write(ctx, typ, p) if err != nil { // Have to ensure the WebSocket is closed after a write error @@ -210,7 +192,7 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { return nil } -func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error { +func (c *BrowserConn) write(ctx context.Context, typ MessageType, p []byte) error { if c.isClosed() { return net.ErrClosed } @@ -228,7 +210,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error { // It will wait until the peer responds with a close frame // or the connection is closed. // It thus performs the full WebSocket close handshake. -func (c *Conn) Close(code StatusCode, reason string) error { +func (c *BrowserConn) Close(code StatusCode, reason string) error { err := c.exportedClose(code, reason) if err != nil { return fmt.Errorf("failed to close WebSocket: %w", err) @@ -241,11 +223,11 @@ func (c *Conn) Close(code StatusCode, reason string) error { // // note: No different from Close(StatusGoingAway, "") in WASM as there is no way to close // a WebSocket without the close handshake. -func (c *Conn) CloseNow() error { +func (c *BrowserConn) CloseNow() error { return c.Close(StatusGoingAway, "") } -func (c *Conn) exportedClose(code StatusCode, reason string) error { +func (c *BrowserConn) exportedClose(code StatusCode, reason string) error { c.closingMu.Lock() defer c.closingMu.Unlock() @@ -273,14 +255,12 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error { // Subprotocol returns the negotiated subprotocol. // An empty string means the default protocol. -func (c *Conn) Subprotocol() string { +func (c *BrowserConn) Subprotocol() string { return c.ws.Subprotocol() } -// DialOptions represents the options available to pass to Dial. -type DialOptions struct { - // Subprotocols lists the subprotocols to negotiate with the server. - Subprotocols []string +func (c *BrowserConn) conn() any { + return c.ws } // Dial creates a new WebSocket connection to the given url with the given options. @@ -288,6 +268,9 @@ type DialOptions struct { // The returned *http.Response is always nil or a mock. It's only in the signature // to match the core API. func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) { + if opts != nil && opts.HTTPClient != nil { + return dialStd(ctx, url, opts, nil) + } c, resp, err := dial(ctx, url, opts) if err != nil { return nil, nil, fmt.Errorf("failed to WebSocket dial %q: %w", url, err) @@ -308,7 +291,7 @@ func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Resp return nil, nil, err } - c := &Conn{ + c := &BrowserConn{ ws: ws, } c.init() @@ -324,7 +307,7 @@ func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Resp c.Close(StatusPolicyViolation, "dial timed out") return nil, nil, ctx.Err() case <-opench: - return c, &http.Response{ + return &Conn{c}, &http.Response{ StatusCode: http.StatusSwitchingProtocols, }, nil case <-c.closed: @@ -334,7 +317,7 @@ func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Resp // Reader attempts to read a message from the connection. // The maximum time spent waiting is bounded by the context. -func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { +func (c *BrowserConn) Reader(ctx context.Context) (MessageType, io.Reader, error) { typ, p, err := c.Read(ctx) if err != nil { return 0, nil, err @@ -345,7 +328,7 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { // Writer returns a writer to write a WebSocket data message to the connection. // It buffers the entire message in memory and then sends it when the writer // is closed. -func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { +func (c *BrowserConn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { return &writer{ c: c, ctx: ctx, @@ -357,7 +340,7 @@ func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, err type writer struct { closed bool - c *Conn + c *BrowserConn ctx context.Context typ MessageType @@ -390,7 +373,7 @@ func (w *writer) Close() error { } // CloseRead implements *Conn.CloseRead for wasm. -func (c *Conn) CloseRead(ctx context.Context) context.Context { +func (c *BrowserConn) CloseRead(ctx context.Context) context.Context { c.closeReadMu.Lock() ctx2 := c.closeReadCtx if ctx2 != nil { @@ -413,17 +396,17 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context { } // SetReadLimit implements *Conn.SetReadLimit for wasm. -func (c *Conn) SetReadLimit(n int64) { +func (c *BrowserConn) SetReadLimit(n int64) { c.msgReadLimit.Store(n) } -func (c *Conn) setCloseErr(err error) { +func (c *BrowserConn) setCloseErr(err error) { c.closeErrOnce.Do(func() { c.closeErr = fmt.Errorf("WebSocket closed: %w", err) }) } -func (c *Conn) isClosed() bool { +func (c *BrowserConn) isClosed() bool { select { case <-c.closed: return true @@ -432,152 +415,27 @@ func (c *Conn) isClosed() bool { } } -// AcceptOptions represents Accept's options. -type AcceptOptions struct { - Subprotocols []string - InsecureSkipVerify bool - OriginPatterns []string - CompressionMode CompressionMode - CompressionThreshold int -} - -// Accept is stubbed out for Wasm. -func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { - return nil, errors.New("unimplemented") -} - -// StatusCode represents a WebSocket status code. -// https://tools.ietf.org/html/rfc6455#section-7.4 -type StatusCode int - -// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number -// -// These are only the status codes defined by the protocol. -// -// You can define custom codes in the 3000-4999 range. -// The 3000-3999 range is reserved for use by libraries, frameworks and applications. -// The 4000-4999 range is reserved for private use. -const ( - StatusNormalClosure StatusCode = 1000 - StatusGoingAway StatusCode = 1001 - StatusProtocolError StatusCode = 1002 - StatusUnsupportedData StatusCode = 1003 - - // 1004 is reserved and so unexported. - statusReserved StatusCode = 1004 - - // StatusNoStatusRcvd cannot be sent in a close message. - // It is reserved for when a close message is received without - // a status code. - StatusNoStatusRcvd StatusCode = 1005 - - // StatusAbnormalClosure is exported for use only with Wasm. - // In non Wasm Go, the returned error will indicate whether the - // connection was closed abnormally. - StatusAbnormalClosure StatusCode = 1006 - - StatusInvalidFramePayloadData StatusCode = 1007 - StatusPolicyViolation StatusCode = 1008 - StatusMessageTooBig StatusCode = 1009 - StatusMandatoryExtension StatusCode = 1010 - StatusInternalError StatusCode = 1011 - StatusServiceRestart StatusCode = 1012 - StatusTryAgainLater StatusCode = 1013 - StatusBadGateway StatusCode = 1014 - - // StatusTLSHandshake is only exported for use with Wasm. - // In non Wasm Go, the returned error will indicate whether there was - // a TLS handshake failure. - StatusTLSHandshake StatusCode = 1015 -) - -// CloseError is returned when the connection is closed with a status and reason. -// -// Use Go 1.13's errors.As to check for this error. -// Also see the CloseStatus helper. -type CloseError struct { - Code StatusCode - Reason string -} - -func (ce CloseError) Error() string { - return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason) -} - -// CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab -// the status code from a CloseError. -// -// -1 will be returned if the passed error is nil or not a CloseError. -func CloseStatus(err error) StatusCode { - var ce CloseError - if errors.As(err, &ce) { - return ce.Code - } - return -1 -} - -// CompressionMode represents the modes available to the deflate extension. -// See https://tools.ietf.org/html/rfc7692 -// Works in all browsers except Safari which does not implement the deflate extension. -type CompressionMode int - -const ( - // CompressionNoContextTakeover grabs a new flate.Reader and flate.Writer as needed - // for every message. This applies to both server and client side. - // - // This means less efficient compression as the sliding window from previous messages - // will not be used but the memory overhead will be lower if the connections - // are long lived and seldom used. - // - // The message will only be compressed if greater than 512 bytes. - CompressionNoContextTakeover CompressionMode = iota - - // CompressionContextTakeover uses a flate.Reader and flate.Writer per connection. - // This enables reusing the sliding window from previous messages. - // As most WebSocket protocols are repetitive, this can be very efficient. - // It carries an overhead of 8 kB for every connection compared to CompressionNoContextTakeover. - // - // If the peer negotiates NoContextTakeover on the client or server side, it will be - // used instead as this is required by the RFC. - CompressionContextTakeover - - // CompressionDisabled disables the deflate extension. - // - // Use this if you are using a predominantly binary protocol with very - // little duplication in between messages or CPU and memory are more - // important than bandwidth. - CompressionDisabled -) - -// MessageType represents the type of a WebSocket message. -// See https://tools.ietf.org/html/rfc6455#section-5.6 -type MessageType int - -// MessageType constants. -const ( - // MessageText is for UTF-8 encoded text messages like JSON. - MessageText MessageType = iota + 1 - // MessageBinary is for binary messages like protobufs. - MessageBinary -) - -type mu struct { - c *Conn +type jsMu struct { + c *BrowserConn ch chan struct{} } -func newMu(c *Conn) *mu { - return &mu{ +func newMu(c *BrowserConn) *jsMu { + return &jsMu{ c: c, ch: make(chan struct{}, 1), } } -func (m *mu) forceLock() { +func (c *BrowserConn) newMu() muLocker { + return newMu(c) +} + +func (m *jsMu) forceLock() { m.ch <- struct{}{} } -func (m *mu) tryLock() bool { +func (m *jsMu) tryLock() bool { select { case m.ch <- struct{}{}: return true @@ -586,13 +444,9 @@ func (m *mu) tryLock() bool { } } -func (m *mu) unlock() { +func (m *jsMu) unlock() { select { case <-m.ch: default: } } - -type noCopy struct{} - -func (*noCopy) Lock() {}