// Package wsapi implements websocket sessions and the registry of websocket
// command handlers.
package wsapi

import (
	"context"
	"encoding/json"
	"errors"
	"strings"

	"dynatron.me/x/blasphem/pkg/auth"
	"dynatron.me/x/blasphem/pkg/blas/core"
	"github.com/gorilla/websocket"
	"github.com/labstack/echo/v4"
	"github.com/mitchellh/mapstructure"
	"github.com/rs/zerolog/log"
)

// Sentinel errors returned by the websocket API.
var (
	// NoSuchHandlerErr is returned when no handler is registered for a
	// command, or when a message carries no string "type".
	NoSuchHandlerErr = errors.New("bad websocket command")
	NoMessageIDErr   = errors.New("no message ID")
	AuthInvalidErr   = errors.New("invalid auth")
)

// Type is the value of a message's "type" field.
type Type string

// MsgBase holds the fields common to websocket messages: an optional ID and
// the message type.
type MsgBase struct {
	ID   *int `json:"id,omitempty"`
	Type Type `json:"type"`
}

type (
	// wsSession is the state of a single websocket connection.
	wsSession struct {
		conn *websocket.Conn
		b    core.Blas
		ec   echo.Context

		// write carries outgoing messages to the command loop in Go, which
		// is the goroutine that writes them to conn. It is nil until the
		// command phase starts.
		write chan<- interface{}

		user         *auth.User
		refreshToken *auth.RefreshToken
	}

	// wsEntry pairs a command handler with the constructor for its payload.
	wsEntry struct {
		dataNew core.NewData
		hnd     core.Handler
	}

	// wsRegistry maps command names to their entries.
	wsRegistry map[string]wsEntry

	// wsManager is the registry-backed core.WebSocketManager.
	wsManager struct {
		r wsRegistry
	}
)

// RegisterWSCommand registers hnd and its payload constructor dataNew under
// cmd, replacing any entry previously registered under the same name.
func (wsm *wsManager) RegisterWSCommand(cmd string, hnd core.Handler, dataNew core.NewData) {
	wsm.r[cmd] = wsEntry{
		dataNew: dataNew,
		hnd:     hnd,
	}
}

// WSCommandHandler looks up the payload constructor and handler for a command.
// It tries cmd first and then the first element of cmdSplit (if any). It
// returns NoSuchHandlerErr when neither is registered.
func (wsm *wsManager) WSCommandHandler(cmd string, cmdSplit []string) (core.NewData, core.Handler, error) {
	if wse, ok := wsm.r[cmd]; ok {
		return wse.dataNew, wse.hnd, nil
	}

	// Guard the index: an empty cmdSplit must yield NoSuchHandlerErr rather
	// than an index-out-of-range panic.
	if len(cmdSplit) > 0 {
		if wse, ok := wsm.r[cmdSplit[0]]; ok {
			return wse.dataNew, wse.hnd, nil
		}
	}

	return nil, nil, NoSuchHandlerErr
}

// NewManager returns a websocket manager with an empty command registry.
func NewManager() core.WebSocketManager {
	return &wsManager{
		r: make(wsRegistry),
	}
}

// NewSession returns a websocket session for conn, bound to s and the echo
// context c.
func NewSession(s core.Blas, c echo.Context, conn *websocket.Conn) core.WebSocketSession {
	ws := &wsSession{
		conn: conn,
		b:    s,
		ec:   c,
	}

	return ws
}

// Conn returns the session's underlying websocket connection.
func (ws *wsSession) Conn() *websocket.Conn {
	return ws.conn
}

// Blas returns the core.Blas the session was created with.
func (ws *wsSession) Blas() core.Blas {
	return ws.b
}

// logWriteFailure logs err, if it is non-nil, as a failed websocket write.
func logWriteFailure(err error) {
	if err != nil {
		log.Error().Err(err).Msg("writeMsg")
	}
}

// Go runs the session. It first performs the auth phase: it sends the
// auth-required message, reads one message and hands it to the auth phase
// handler, returning any error from those steps. It then enters the command
// phase, which runs until ctx is cancelled or the reader goroutine stops
// (connection closed or read failure); in both cases Go returns nil.
//
// During the command phase this goroutine is the only one that writes to the
// connection: it dispatches incoming messages, writes error replies and
// drains messages queued through Write.
func (ws *wsSession) Go(ctx context.Context) error {
	authP := &authPhase{ws}

	if err := ws.sendAuthRequired(); err != nil {
		return err
	}

	_, rdr, err := ws.conn.NextReader()
	if err != nil {
		return err
	}

	if err := authP.handleMsg(ctx, rdr); err != nil {
		return err
	}

	// command phase

	// inbound is one frame read from the connection: either a decoded
	// message or the error from decoding it.
	type inbound struct {
		msg map[string]interface{}
		err error
	}

	msgChan := make(chan inbound)
	write := make(chan interface{})
	ws.write = write
	defer close(write)

	cCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Reader goroutine. It is the sole sender on ch and therefore the only
	// one allowed to close it.
	go func(ctx context.Context, ch chan<- inbound) {
		defer close(ch)

		for {
			if ctx.Err() != nil {
				return
			}

			_, rdr, err := ws.conn.NextReader()
			if err != nil {
				if !websocket.IsCloseError(err, websocket.CloseGoingAway) {
					log.Error().Err(err).Str("remote", ws.ec.Request().RemoteAddr).Msg("websocket read fail")
				}
				return
			}

			// Decode failures are forwarded instead of being answered here,
			// so that only the command loop writes to the connection.
			var in inbound
			in.err = json.NewDecoder(rdr).Decode(&in.msg)

			select {
			case ch <- in:
			case <-ctx.Done():
				// The command loop is gone; do not block on the send.
				return
			}
		}
	}(cCtx, msgChan)

	for {
		select {
		case in, ok := <-msgChan:
			if !ok {
				return nil
			}

			if in.err != nil {
				logWriteFailure(ws.writeError(-1, Error{Code: "invalid_format", Message: in.err.Error()}))
				continue
			}

			if err := ws.handleMsg(ctx, in.msg); err != nil {
				log.Error().Err(err).Msg("handleMsg")
			}
		case <-ctx.Done():
			// msgChan is closed by the reader goroutine, not here.
			return nil
		case m := <-write:
			logWriteFailure(ws.conn.WriteJSON(m))
		}
	}
}

// Write queues msg for the command loop in Go, which writes it to the
// connection. It blocks until the loop receives the message.
//
// Write must not be called on the command loop's own goroutine (for example
// synchronously from a command handler): the loop is the only receiver, so
// the send would never complete. Before the command phase the channel is nil
// and Write blocks forever; after Go returns the channel is closed and Write
// panics.
func (ws *wsSession) Write(msg interface{}) {
	ws.write <- msg
}

// MsgType is the "type" value of an outgoing message.
type MsgType string

const (
	// ResultMsgType marks a message as the result of a command.
	ResultMsgType MsgType = "result"
)

// Error is the error detail carried by a failed result message.
type Error struct {
	Code    string `json:"code"`
	Message string `json:"message"`
}

// WSError is a result message reporting an error for the command with the
// given ID, if any.
type WSError struct {
	ID      *int    `json:"id,omitempty"`
	Type    MsgType `json:"type"`
	Success bool    `json:"success"`
	Error   Error   `json:"error"`
}

// writeError writes an unsuccessful result for command id directly to the
// connection and returns the write error. During the command phase it must
// only be called from the goroutine running Go.
func (ws *wsSession) writeError(id int, err Error) error {
	return ws.conn.WriteJSON(WSError{
		ID:      &id,
		Type:    ResultMsgType,
		Success: false,
		Error:   err,
	})
}

// handleMsg dispatches one decoded command message. It runs on the command
// loop goroutine, so replies are written to the connection directly rather
// than through Write.
//
// It returns NoSuchHandlerErr when the message has no string "type". Every
// other failure (missing ID, unknown command, undecodable payload, handler
// error) is reported to the client and/or logged, and nil is returned. It
// does nothing if ctx is already done.
func (ws *wsSession) handleMsg(ctx context.Context, msgMap map[string]interface{}) error {
	if ctx.Err() != nil {
		return nil
	}

	msgType, ok := msgMap["type"].(string)
	if !ok {
		return NoSuchHandlerErr
	}

	// JSON numbers decode to float64 in a map[string]interface{}.
	idFl, ok := msgMap["id"].(float64)
	if !ok {
		// Written directly: sending through Write here would block forever,
		// because this goroutine is the one that drains the write channel.
		logWriteFailure(ws.conn.WriteJSON(WSError{
			Type:    ResultMsgType,
			Success: false,
			Error: Error{
				Code:    "invalid_id",
				Message: "command has no ID",
			},
		}))
		return nil
	}

	id := int(idFl)

	cmd := strings.Split(msgType, "/")

	newData, hand, err := ws.b.WSCommandHandler(cmd[0], cmd)
	switch {
	case err == nil:
	case errors.Is(err, NoSuchHandlerErr):
		logWriteFailure(ws.writeError(id, Error{
			Code:    "invalid_type",
			Message: "no such command",
		}))
		return nil
	default:
		log.Error().Err(err).Msg("dispatch")
		return nil
	}

	// A payload that is itself a map[string]interface{} is handed to the
	// handler as newData returned it; anything else is populated from the
	// message.
	nd := newData(cmd)
	if _, ok := nd.(map[string]interface{}); !ok {
		if err := mapstructure.Decode(&msgMap, &nd); err != nil {
			logWriteFailure(ws.writeError(id, Error{
				Code:    "invalid_format",
				Message: err.Error(),
			}))
			return nil
		}
	}

	if err := hand(ctx, ws, id, cmd, nd); err != nil {
		log.Error().Err(err).Msg("dispatch")
	}

	return nil
}