connection.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. package ws
  2. import (
  3. "errors"
  4. "github.com/gorilla/websocket"
  5. "github.com/rs/zerolog/log"
  6. "io"
  7. "net"
  8. "net/http"
  9. "time"
  10. )
  11. type ConnectionOptions struct {
  12. WriteDeadline time.Duration
  13. ReadDeadline time.Duration
  14. PingPeriod time.Duration
  15. PingMessage []byte
  16. Headers http.Header
  17. Proxy *net.Dialer
  18. MessageListener chan []byte
  19. }
  20. type Connection struct {
  21. session *websocket.Conn
  22. opt ConnectionOptions
  23. buf chan []byte
  24. }
  25. func NewConnection(url string, opts ConnectionOptions) (*Connection, error) {
  26. var (
  27. dialer *websocket.Dialer
  28. err error
  29. )
  30. if opts.WriteDeadline == 0 {
  31. opts.WriteDeadline = time.Second * 10
  32. }
  33. if opts.PingPeriod == 0 {
  34. opts.PingPeriod = time.Second * 30
  35. }
  36. c := &Connection{
  37. buf: make(chan []byte, 256),
  38. opt: opts,
  39. }
  40. if opts.Proxy != nil {
  41. dialer = &websocket.Dialer{
  42. NetDial: opts.Proxy.Dial,
  43. }
  44. } else {
  45. dialer = websocket.DefaultDialer
  46. }
  47. if c.session, _, err = dialer.Dial(url, opts.Headers); err != nil {
  48. return nil, err
  49. }
  50. if opts.MessageListener != nil {
  51. go c.startReader()
  52. }
  53. go c.startWriter()
  54. return c, nil
  55. }
  56. func (c *Connection) startReader() {
  57. //goland:noinspection ALL
  58. defer c.session.Close()
  59. var (
  60. buf []byte
  61. err error
  62. )
  63. if c.opt.ReadDeadline > 0 {
  64. _ = c.session.SetReadDeadline(time.Now().Add(c.opt.ReadDeadline))
  65. }
  66. for {
  67. if _, buf, err = c.session.ReadMessage(); err != nil {
  68. if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) ||
  69. errors.Is(err, io.EOF) {
  70. log.Debug().Str("net", "websocket").Err(err).Msg("closed connection")
  71. break
  72. }
  73. }
  74. if buf == nil {
  75. continue
  76. }
  77. c.opt.MessageListener <- buf
  78. }
  79. }
  80. func (c *Connection) startWriter() {
  81. var (
  82. wr io.WriteCloser
  83. err error
  84. )
  85. ticker := time.NewTicker(c.opt.PingPeriod)
  86. defer func() {
  87. ticker.Stop()
  88. if c.session != nil {
  89. _ = c.session.Close()
  90. }
  91. }()
  92. for {
  93. select {
  94. case message, ok := <-c.buf:
  95. _ = c.session.SetWriteDeadline(time.Now().Add(c.opt.WriteDeadline))
  96. if !ok {
  97. if err = c.session.WriteMessage(websocket.CloseMessage, []byte{}); err != nil {
  98. log.Debug().Str("net", "websocket").Err(err).Msg("could not correctly close the channel")
  99. }
  100. return
  101. }
  102. if wr, err = c.session.NextWriter(websocket.BinaryMessage); err != nil {
  103. log.Debug().Str("net", "websocket").Err(err).Msg("could not open writer io")
  104. return
  105. }
  106. if _, err = wr.Write(message); err != nil {
  107. log.Debug().Str("net", "websocket").Err(err).Msg("could not write message")
  108. return
  109. }
  110. if err = wr.Close(); err != nil {
  111. log.Debug().Str("net", "websocket").Err(err).Msg("could not close writer io")
  112. return
  113. }
  114. case <-ticker.C:
  115. _ = c.session.SetWriteDeadline(time.Now().Add(c.opt.WriteDeadline))
  116. if c.opt.PingMessage == nil {
  117. if err = c.session.WriteMessage(websocket.PingMessage, nil); err != nil {
  118. return
  119. }
  120. } else {
  121. if err = c.session.WriteMessage(websocket.BinaryMessage, c.opt.PingMessage); err != nil {
  122. return
  123. }
  124. }
  125. }
  126. }
  127. }
  128. func (c *Connection) Send(data []byte) error {
  129. select {
  130. case c.buf <- data:
  131. default:
  132. close(c.buf)
  133. return errors.New("websocket: closed write channel")
  134. }
  135. return nil
  136. }