diff --git a/pkg/auth/jwt.go b/pkg/auth/jwt.go index 2eae750..a798a63 100644 --- a/pkg/auth/jwt.go +++ b/pkg/auth/jwt.go @@ -34,8 +34,8 @@ type jwtAuth interface { // InstallVerifyMiddleware installs the JWT verifier middleware to the provided chi Router. VerifyMiddleware() func(http.Handler) http.Handler - // InstallAuthMiddleware installs the JWT authenticator middleware to the provided chi Router. - AuthMiddleware() func(http.Handler) http.Handler + // SubjectMiddleware sets the request context subject from JWT or public. + SubjectMiddleware(requireAuth bool) func(http.Handler) http.Handler // PublicRoutes installs the auth route to the provided chi Router. PublicRoutes(chi.Router) @@ -84,12 +84,20 @@ func TokenFromCookie(r *http.Request) string { return cookie.Value } -func (a *Auth) AuthMiddleware() func(http.Handler) http.Handler { +func (a *Auth) PublicSubjectMiddleware() func(http.Handler) http.Handler { + return a.SubjectMiddleware(false) +} + +func (a *Auth) AuthorizedSubjectMiddleware() func(http.Handler) http.Handler { + return a.SubjectMiddleware(true) +} + +func (a *Auth) SubjectMiddleware(requireToken bool) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { hfn := func(w http.ResponseWriter, r *http.Request) { token, _, err := jwtauth.FromContext(r.Context()) - if err != nil { + if err != nil && requireToken { http.Error(w, err.Error(), http.StatusUnauthorized) return } diff --git a/pkg/rbac/entities/entities.go b/pkg/rbac/entities/entities.go index 89553c0..2885930 100644 --- a/pkg/rbac/entities/entities.go +++ b/pkg/rbac/entities/entities.go @@ -31,11 +31,11 @@ const ( func SubjectFrom(ctx context.Context) Subject { sub, ok := ctx.Value(SubjectCtxKey).(Subject) - if ok { - return sub + if !ok { + panic("no subject in context") } - return new(PublicSubject) + return sub } type Subject interface { diff --git a/pkg/server/routes.go b/pkg/server/routes.go index 9410baf..e0fbfff 100644 --- a/pkg/server/routes.go +++ b/pkg/server/routes.go @@ -29,10 +29,11 @@ func (s *Server) setupRoutes() { r.Use(s.WithCtxStores()) s.installPprof() + r.Use(s.auth.VerifyMiddleware()) r.Group(func(r chi.Router) { + r.Use(s.auth.SubjectMiddleware(true)) // authenticated routes - r.Use(s.auth.VerifyMiddleware(), s.auth.AuthMiddleware()) s.nex.PrivateRoutes(r) s.auth.PrivateRoutes(r) s.alerter.PrivateRoutes(r) @@ -41,6 +42,7 @@ func (s *Server) setupRoutes() { r.Group(func(r chi.Router) { s.rateLimit(r) + r.Use(s.auth.SubjectMiddleware(false)) r.Use(render.SetContentType(render.ContentTypeJSON)) // public routes s.sources.PublicRoutes(r) @@ -49,6 +51,7 @@ func (s *Server) setupRoutes() { r.Group(func(r chi.Router) { // auth/share routes get rate-limited heavily, but not using middleware s.rateLimit(r) + r.Use(s.auth.SubjectMiddleware(false)) r.Use(render.SetContentType(render.ContentTypeJSON)) s.auth.PublicRoutes(r) r.Mount("/share", s.rest.ShareRouter()) @@ -56,9 +59,8 @@ func (s *Server) setupRoutes() { r.Group(func(r chi.Router) { s.rateLimit(r) - r.Use(s.auth.VerifyMiddleware()) - // optional auth routes + r.Use(s.auth.SubjectMiddleware(false)) s.clientRoute(r, clientRoot) })