|
@@ -0,0 +1,167 @@
|
|
|
+package ws
|
|
|
+
|
|
|
+import (
|
|
|
+ "errors"
|
|
|
+ "github.com/gorilla/websocket"
|
|
|
+ "github.com/rs/zerolog/log"
|
|
|
+ "io"
|
|
|
+ "net"
|
|
|
+ "net/http"
|
|
|
+ "time"
|
|
|
+)
|
|
|
+
|
|
|
+type ConnectionOptions struct {
|
|
|
+ WriteDeadline time.Duration
|
|
|
+ ReadDeadline time.Duration
|
|
|
+ PingPeriod time.Duration
|
|
|
+ PingMessage []byte
|
|
|
+
|
|
|
+ Headers http.Header
|
|
|
+ Proxy *net.Dialer
|
|
|
+
|
|
|
+ MessageListener chan []byte
|
|
|
+}
|
|
|
+
|
|
|
+type Connection struct {
|
|
|
+ session *websocket.Conn
|
|
|
+ opt ConnectionOptions
|
|
|
+ buf chan []byte
|
|
|
+}
|
|
|
+
|
|
|
+func NewConnection(url string, opts ConnectionOptions) (*Connection, error) {
|
|
|
+ var (
|
|
|
+ dialer *websocket.Dialer
|
|
|
+ err error
|
|
|
+ )
|
|
|
+
|
|
|
+ if opts.WriteDeadline == 0 {
|
|
|
+ opts.WriteDeadline = time.Second * 10
|
|
|
+ }
|
|
|
+
|
|
|
+ if opts.PingPeriod == 0 {
|
|
|
+ opts.PingPeriod = time.Second * 30
|
|
|
+ }
|
|
|
+
|
|
|
+ c := &Connection{
|
|
|
+ buf: make(chan []byte, 256),
|
|
|
+ opt: opts,
|
|
|
+ }
|
|
|
+
|
|
|
+ if opts.Proxy != nil {
|
|
|
+ dialer = &websocket.Dialer{
|
|
|
+ NetDial: opts.Proxy.Dial,
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ dialer = websocket.DefaultDialer
|
|
|
+ }
|
|
|
+
|
|
|
+ if c.session, _, err = dialer.Dial(url, opts.Headers); err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ if opts.MessageListener != nil {
|
|
|
+ go c.startReader()
|
|
|
+ }
|
|
|
+
|
|
|
+ go c.startWriter()
|
|
|
+
|
|
|
+ return c, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (c *Connection) startReader() {
|
|
|
+ //goland:noinspection ALL
|
|
|
+ defer c.session.Close()
|
|
|
+
|
|
|
+ var (
|
|
|
+ buf []byte
|
|
|
+ err error
|
|
|
+ )
|
|
|
+
|
|
|
+ if c.opt.ReadDeadline > 0 {
|
|
|
+ _ = c.session.SetReadDeadline(time.Now().Add(c.opt.ReadDeadline))
|
|
|
+ }
|
|
|
+
|
|
|
+ for {
|
|
|
+ if _, buf, err = c.session.ReadMessage(); err != nil {
|
|
|
+ if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) ||
|
|
|
+ errors.Is(err, io.EOF) {
|
|
|
+ log.Debug().Str("net", "websocket").Err(err).Msg("closed connection")
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if buf == nil {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ c.opt.MessageListener <- buf
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (c *Connection) startWriter() {
|
|
|
+ var (
|
|
|
+ wr io.WriteCloser
|
|
|
+ err error
|
|
|
+ )
|
|
|
+
|
|
|
+ ticker := time.NewTicker(c.opt.PingPeriod)
|
|
|
+ defer func() {
|
|
|
+ ticker.Stop()
|
|
|
+
|
|
|
+ if c.session != nil {
|
|
|
+ _ = c.session.Close()
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case message, ok := <-c.buf:
|
|
|
+ _ = c.session.SetWriteDeadline(time.Now().Add(c.opt.WriteDeadline))
|
|
|
+
|
|
|
+ if !ok {
|
|
|
+ if err = c.session.WriteMessage(websocket.CloseMessage, []byte{}); err != nil {
|
|
|
+ log.Debug().Str("net", "websocket").Err(err).Msg("could not correctly close the channel")
|
|
|
+ }
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if wr, err = c.session.NextWriter(websocket.BinaryMessage); err != nil {
|
|
|
+ log.Debug().Str("net", "websocket").Err(err).Msg("could not open writer io")
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if _, err = wr.Write(message); err != nil {
|
|
|
+ log.Debug().Str("net", "websocket").Err(err).Msg("could not write message")
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if err = wr.Close(); err != nil {
|
|
|
+ log.Debug().Str("net", "websocket").Err(err).Msg("could not close writer io")
|
|
|
+ return
|
|
|
+ }
|
|
|
+ case <-ticker.C:
|
|
|
+ _ = c.session.SetWriteDeadline(time.Now().Add(c.opt.WriteDeadline))
|
|
|
+
|
|
|
+ if c.opt.PingMessage == nil {
|
|
|
+ if err = c.session.WriteMessage(websocket.PingMessage, nil); err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ if err = c.session.WriteMessage(websocket.BinaryMessage, c.opt.PingMessage); err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (c *Connection) Send(data []byte) error {
|
|
|
+ select {
|
|
|
+ case c.buf <- data:
|
|
|
+ default:
|
|
|
+ close(c.buf)
|
|
|
+ return errors.New("websocket: closed write channel")
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|