Improve routing
This commit is contained in:
parent
7c860a85b1
commit
963f4020b7
3 changed files with 62 additions and 35 deletions
|
@ -69,6 +69,10 @@ func (fs FlowStore) register(f *Flow) {
|
|||
fs[f.ID] = f
|
||||
}
|
||||
|
||||
func (fs FlowStore) Remove(f *Flow) {
|
||||
delete(fs, f.ID)
|
||||
}
|
||||
|
||||
const cullAge = time.Minute * 30
|
||||
|
||||
func (fs FlowStore) cull() {
|
||||
|
@ -142,6 +146,8 @@ func (f *Flow) progress(a *Authenticator, c echo.Context) error {
|
|||
err = a.Check(f, rm)
|
||||
switch err {
|
||||
case nil:
|
||||
// TODO: setup the session. delete the flow.
|
||||
a.Flows.Remove(f)
|
||||
return c.String(http.StatusOK, "login success!")
|
||||
case ErrInvalidHandler:
|
||||
return c.String(http.StatusNotFound, err.Error())
|
||||
|
@ -170,16 +176,16 @@ func (a *Authenticator) LoginFlowDeleteHandler(c echo.Context) error {
|
|||
return c.String(http.StatusOK, "deleted")
|
||||
}
|
||||
|
||||
func (a *Authenticator) LoginFlowHandler(c echo.Context) error {
|
||||
func setJSON(c echo.Context) {
|
||||
if c.Request().Method == http.MethodPost && strings.HasPrefix(c.Request().Header.Get(echo.HeaderContentType), "text/plain") {
|
||||
// hack around the content-type, Context.JSON refuses to work otherwise
|
||||
c.Request().Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSONCharsetUTF8)
|
||||
}
|
||||
}
|
||||
|
||||
flowID := c.Param("flow_id")
|
||||
func (a *Authenticator) BeginLoginFlowHandler(c echo.Context) error {
|
||||
setJSON(c)
|
||||
|
||||
switch flowID {
|
||||
case "": // new
|
||||
var flowReq FlowRequest
|
||||
err := c.Bind(&flowReq)
|
||||
if err != nil {
|
||||
|
@ -193,7 +199,13 @@ func (a *Authenticator) LoginFlowHandler(c echo.Context) error {
|
|||
}
|
||||
|
||||
return c.JSON(http.StatusOK, resp)
|
||||
default:
|
||||
}
|
||||
|
||||
func (a *Authenticator) LoginFlowHandler(c echo.Context) error {
|
||||
setJSON(c)
|
||||
|
||||
flowID := c.Param("flow_id")
|
||||
|
||||
flow := a.Flows.Get(FlowID(flowID))
|
||||
if flow == nil {
|
||||
return c.String(http.StatusNotFound, "no such flow")
|
||||
|
@ -201,4 +213,16 @@ func (a *Authenticator) LoginFlowHandler(c echo.Context) error {
|
|||
|
||||
return flow.progress(a, c)
|
||||
}
|
||||
|
||||
func (a *Authenticator) InstallRoutes(e *echo.Echo) {
|
||||
authG := e.Group("/auth")
|
||||
authG.GET("/authorize", a.AuthorizeHandler)
|
||||
authG.GET("/providers", a.ProvidersHandler)
|
||||
|
||||
authG.POST("/login_flow", a.BeginLoginFlowHandler)
|
||||
|
||||
loginFlow := authG.Group("/login_flow") // TODO: add IP address affinity middleware
|
||||
loginFlow.POST("/:flow_id", a.LoginFlowHandler)
|
||||
loginFlow.DELETE("/:flow_id", a.LoginFlowDeleteHandler)
|
||||
|
||||
}
|
||||
|
|
|
@ -36,5 +36,6 @@ func (ips *IPSource) UnmarshalYAML(val *yaml.Node) error {
|
|||
type Config struct {
|
||||
Bind string `yaml:"bind"`
|
||||
IPSource IPSource `yaml:"ip_source"`
|
||||
LogRequestErrors bool `yaml:"log_req_errors"`
|
||||
TrustedProxies []string `yaml:"trusted_proxies"`
|
||||
}
|
||||
|
|
|
@ -2,7 +2,6 @@ package server
|
|||
|
||||
import (
|
||||
"context"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
|
@ -24,19 +23,14 @@ type Server struct {
|
|||
*blas.Blas
|
||||
*echo.Echo
|
||||
auth.Authenticator
|
||||
rootFS fs.FS
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func (s *Server) installRoutes() {
|
||||
s.GET("/", echo.WrapHandler(frontend.FSHandler))
|
||||
s.GET("/api/websocket", s.wsHandler)
|
||||
s.GET("/auth/authorize", s.AuthorizeHandler)
|
||||
s.GET("/auth/providers", s.ProvidersHandler)
|
||||
|
||||
s.POST("/auth/login_flow", s.LoginFlowHandler)
|
||||
s.POST("/auth/login_flow/:flow_id", s.LoginFlowHandler)
|
||||
s.DELETE("/auth/login_flow/:flow_id", s.LoginFlowDeleteHandler)
|
||||
s.Authenticator.InstallRoutes(s.Echo)
|
||||
}
|
||||
|
||||
func New(cfg *config.Config) (s *Server, err error) {
|
||||
|
@ -56,7 +50,15 @@ func New(cfg *config.Config) (s *Server, err error) {
|
|||
|
||||
s.Echo.Debug = true
|
||||
s.Echo.HideBanner = true
|
||||
s.Echo.Logger = lecho.From(log.Logger)
|
||||
logger := lecho.From(log.Logger)
|
||||
s.Echo.Logger = logger
|
||||
|
||||
if cfg.Server.LogRequestErrors {
|
||||
s.Echo.Use(lecho.Middleware(lecho.Config{
|
||||
Logger: logger,
|
||||
}))
|
||||
}
|
||||
|
||||
s.Echo.HidePort = true
|
||||
|
||||
ipext := map[conf.IPSource]echo.IPExtractor{
|
||||
|
|
Loading…
Reference in a new issue