diff --git a/cmd/gordio/main.go b/cmd/gordio/main.go index ebe1e11..207c6b1 100644 --- a/cmd/gordio/main.go +++ b/cmd/gordio/main.go @@ -15,8 +15,6 @@ import ( "github.com/spf13/cobra" ) - - func main() { log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) diff --git a/config.sample.yaml b/config.sample.yaml index dc4b2b1..c5d8392 100644 --- a/config.sample.yaml +++ b/config.sample.yaml @@ -21,3 +21,7 @@ log: - level: debug - level: error file: error.log +rateLimit: + enable: true + requests: 200 + over: 2m diff --git a/pkg/gordio/config/config.go b/pkg/gordio/config/config.go index e7c758e..f5a9dd6 100644 --- a/pkg/gordio/config/config.go +++ b/pkg/gordio/config/config.go @@ -2,6 +2,8 @@ package config import ( "os" + "sync" + "time" "github.com/rs/zerolog/log" "github.com/spf13/cobra" @@ -9,12 +11,13 @@ import ( ) type Config struct { - DB DB `yaml:"db"` - CORS CORS `yaml:"cors"` - Auth Auth `yaml:"auth"` - Log []Logger `yaml:"log"` - Listen string `yaml:"listen"` - Public bool `yaml:"public"` + DB DB `yaml:"db"` + CORS CORS `yaml:"cors"` + Auth Auth `yaml:"auth"` + Log []Logger `yaml:"log"` + Listen string `yaml:"listen"` + Public bool `yaml:"public"` + RateLimit RateLimit `yaml:"rateLimit"` configPath string } @@ -39,6 +42,28 @@ type Logger struct { Level *string `yaml:"level"` } +type RateLimit struct { + Enable bool `yaml:"enable"` + Requests int `yaml:"requests"` + Over time.Duration `yaml:"over"` + + verifyError sync.Once +} + +func (rl *RateLimit) Verify() bool { + if rl.Enable { + if rl.Requests > 0 && rl.Over > 0 { + return true + } + + rl.verifyError.Do(func() { + log.Error().Int("requests", rl.Requests).Str("over", rl.Over.String()).Msg("rate limit config makes no sense, disabled") + }) + } + + return false +} + func (c *Config) PreRunE() func(*cobra.Command, []string) error { return func(cmd *cobra.Command, args []string) error { return c.ReadConfig() diff --git a/pkg/gordio/nexus/client.go b/pkg/gordio/nexus/client.go index 89d92cc..4f16a8f 100644 --- a/pkg/gordio/nexus/client.go +++ b/pkg/gordio/nexus/client.go @@ -6,8 +6,8 @@ import ( "runtime" "sync" - "dynatron.me/x/stillbox/pkg/gordio/version" "dynatron.me/x/stillbox/pkg/calls" + "dynatron.me/x/stillbox/pkg/gordio/version" "dynatron.me/x/stillbox/pkg/pb" "github.com/rs/zerolog/log" @@ -65,9 +65,9 @@ func (c *client) HandleMessage(ctx context.Context, mesgBytes []byte) { func pbVersion() *pb.Version { return &pb.Version{ ServerName: version.Name, - Version: version.Version, - Built: version.Built, - Platform: runtime.GOOS + "-" + runtime.GOARCH, + Version: version.Version, + Built: version.Built, + Platform: runtime.GOOS + "-" + runtime.GOARCH, } } diff --git a/pkg/gordio/server/routes.go b/pkg/gordio/server/routes.go index 3b9d932..1481dba 100644 --- a/pkg/gordio/server/routes.go +++ b/pkg/gordio/server/routes.go @@ -4,9 +4,9 @@ import ( "io/fs" "net/http" "strings" - "time" "dynatron.me/x/stillbox/client" + "dynatron.me/x/stillbox/pkg/gordio/config" "dynatron.me/x/stillbox/pkg/gordio/database" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" @@ -31,7 +31,7 @@ func (s *Server) setupRoutes() { }) r.Group(func(r chi.Router) { - r.Use(rateLimiter()) + s.rateLimit(r) r.Use(render.SetContentType(render.ContentTypeJSON)) // public routes s.auth.PublicRoutes(r) @@ -39,7 +39,8 @@ func (s *Server) setupRoutes() { }) r.Group(func(r chi.Router) { - r.Use(rateLimiter(), s.auth.VerifyMiddleware()) + s.rateLimit(r) + r.Use(s.auth.VerifyMiddleware()) // optional auth routes @@ -47,8 +48,13 @@ func (s *Server) setupRoutes() { }) } -func rateLimiter() func(http.Handler) http.Handler { - return httprate.LimitByRealIP(100, 1*time.Minute) +func (s *Server) rateLimit(r chi.Router) { + if s.conf.RateLimit.Verify() { + r.Use(rateLimiter(&s.conf.RateLimit)) + } +} +func rateLimiter(cfg *config.RateLimit) func(http.Handler) http.Handler { + return httprate.LimitByRealIP(cfg.Requests, cfg.Over) } func (s *Server) clientRoute(r chi.Router, clientRoot fs.FS) { @@ -57,11 +63,5 @@ func (s *Server) clientRoute(r chi.Router, clientRoot fs.FS) { pathPrefix := strings.TrimSuffix(rctx.RoutePattern(), "/*") fs := http.StripPrefix(pathPrefix, http.FileServer(http.FS(clientRoot))) fs.ServeHTTP(w, r) - /* - if cl, authenticated := s.auth.Authenticated(r); authenticated { - w.Write([]byte("Hello " + cl["user"].(string) + "\n")) - } - w.Write([]byte("Welcome to gordio\n")) - */ }) } diff --git a/pkg/gordio/version/version.go b/pkg/gordio/version/version.go index d5468d1..cbd6533 100644 --- a/pkg/gordio/version/version.go +++ b/pkg/gordio/version/version.go @@ -6,7 +6,7 @@ import ( ) var ( - Name = "gordio" + Name = "gordio" Version = "unset" Built = "unset" )