123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167 |
- package ws
- import (
- "errors"
- "github.com/gorilla/websocket"
- "github.com/rs/zerolog/log"
- "io"
- "net"
- "net/http"
- "time"
- )
- type ConnectionOptions struct {
- WriteDeadline time.Duration
- ReadDeadline time.Duration
- PingPeriod time.Duration
- PingMessage []byte
- Headers http.Header
- Proxy *net.Dialer
- MessageListener chan []byte
- }
- type Connection struct {
- session *websocket.Conn
- opt ConnectionOptions
- buf chan []byte
- }
- func NewConnection(url string, opts ConnectionOptions) (*Connection, error) {
- var (
- dialer *websocket.Dialer
- err error
- )
- if opts.WriteDeadline == 0 {
- opts.WriteDeadline = time.Second * 10
- }
- if opts.PingPeriod == 0 {
- opts.PingPeriod = time.Second * 30
- }
- c := &Connection{
- buf: make(chan []byte, 256),
- opt: opts,
- }
- if opts.Proxy != nil {
- dialer = &websocket.Dialer{
- NetDial: opts.Proxy.Dial,
- }
- } else {
- dialer = websocket.DefaultDialer
- }
- if c.session, _, err = dialer.Dial(url, opts.Headers); err != nil {
- return nil, err
- }
- if opts.MessageListener != nil {
- go c.startReader()
- }
- go c.startWriter()
- return c, nil
- }
- func (c *Connection) startReader() {
- //goland:noinspection ALL
- defer c.session.Close()
- var (
- buf []byte
- err error
- )
- if c.opt.ReadDeadline > 0 {
- _ = c.session.SetReadDeadline(time.Now().Add(c.opt.ReadDeadline))
- }
- for {
- if _, buf, err = c.session.ReadMessage(); err != nil {
- if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) ||
- errors.Is(err, io.EOF) {
- log.Debug().Str("net", "websocket").Err(err).Msg("closed connection")
- break
- }
- }
- if buf == nil {
- continue
- }
- c.opt.MessageListener <- buf
- }
- }
- func (c *Connection) startWriter() {
- var (
- wr io.WriteCloser
- err error
- )
- ticker := time.NewTicker(c.opt.PingPeriod)
- defer func() {
- ticker.Stop()
- if c.session != nil {
- _ = c.session.Close()
- }
- }()
- for {
- select {
- case message, ok := <-c.buf:
- _ = c.session.SetWriteDeadline(time.Now().Add(c.opt.WriteDeadline))
- if !ok {
- if err = c.session.WriteMessage(websocket.CloseMessage, []byte{}); err != nil {
- log.Debug().Str("net", "websocket").Err(err).Msg("could not correctly close the channel")
- }
- return
- }
- if wr, err = c.session.NextWriter(websocket.BinaryMessage); err != nil {
- log.Debug().Str("net", "websocket").Err(err).Msg("could not open writer io")
- return
- }
- if _, err = wr.Write(message); err != nil {
- log.Debug().Str("net", "websocket").Err(err).Msg("could not write message")
- return
- }
- if err = wr.Close(); err != nil {
- log.Debug().Str("net", "websocket").Err(err).Msg("could not close writer io")
- return
- }
- case <-ticker.C:
- _ = c.session.SetWriteDeadline(time.Now().Add(c.opt.WriteDeadline))
- if c.opt.PingMessage == nil {
- if err = c.session.WriteMessage(websocket.PingMessage, nil); err != nil {
- return
- }
- } else {
- if err = c.session.WriteMessage(websocket.BinaryMessage, c.opt.PingMessage); err != nil {
- return
- }
- }
- }
- }
- }
- func (c *Connection) Send(data []byte) error {
- select {
- case c.buf <- data:
- default:
- close(c.buf)
- return errors.New("websocket: closed write channel")
- }
- return nil
- }
|