Rate limiting is configurable

This commit is contained in:
Daniel 2024-10-18 15:21:42 -04:00
parent fdabdba892
commit 943e3ae2ac
6 changed files with 51 additions and 24 deletions

View file

@ -15,8 +15,6 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
func main() { func main() {
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})

View file

@ -21,3 +21,7 @@ log:
- level: debug - level: debug
- level: error - level: error
file: error.log file: error.log
rateLimit:
enable: true
requests: 200
over: 2m

View file

@ -2,6 +2,8 @@ package config
import ( import (
"os" "os"
"sync"
"time"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -9,12 +11,13 @@ import (
) )
type Config struct { type Config struct {
DB DB `yaml:"db"` DB DB `yaml:"db"`
CORS CORS `yaml:"cors"` CORS CORS `yaml:"cors"`
Auth Auth `yaml:"auth"` Auth Auth `yaml:"auth"`
Log []Logger `yaml:"log"` Log []Logger `yaml:"log"`
Listen string `yaml:"listen"` Listen string `yaml:"listen"`
Public bool `yaml:"public"` Public bool `yaml:"public"`
RateLimit RateLimit `yaml:"rateLimit"`
configPath string configPath string
} }
@ -39,6 +42,28 @@ type Logger struct {
Level *string `yaml:"level"` Level *string `yaml:"level"`
} }
type RateLimit struct {
Enable bool `yaml:"enable"`
Requests int `yaml:"requests"`
Over time.Duration `yaml:"over"`
verifyError sync.Once
}
func (rl *RateLimit) Verify() bool {
if rl.Enable {
if rl.Requests > 0 && rl.Over > 0 {
return true
}
rl.verifyError.Do(func() {
log.Error().Int("requests", rl.Requests).Str("over", rl.Over.String()).Msg("rate limit config makes no sense, disabled")
})
}
return false
}
func (c *Config) PreRunE() func(*cobra.Command, []string) error { func (c *Config) PreRunE() func(*cobra.Command, []string) error {
return func(cmd *cobra.Command, args []string) error { return func(cmd *cobra.Command, args []string) error {
return c.ReadConfig() return c.ReadConfig()

View file

@ -6,8 +6,8 @@ import (
"runtime" "runtime"
"sync" "sync"
"dynatron.me/x/stillbox/pkg/gordio/version"
"dynatron.me/x/stillbox/pkg/calls" "dynatron.me/x/stillbox/pkg/calls"
"dynatron.me/x/stillbox/pkg/gordio/version"
"dynatron.me/x/stillbox/pkg/pb" "dynatron.me/x/stillbox/pkg/pb"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -65,9 +65,9 @@ func (c *client) HandleMessage(ctx context.Context, mesgBytes []byte) {
func pbVersion() *pb.Version { func pbVersion() *pb.Version {
return &pb.Version{ return &pb.Version{
ServerName: version.Name, ServerName: version.Name,
Version: version.Version, Version: version.Version,
Built: version.Built, Built: version.Built,
Platform: runtime.GOOS + "-" + runtime.GOARCH, Platform: runtime.GOOS + "-" + runtime.GOARCH,
} }
} }

View file

@ -4,9 +4,9 @@ import (
"io/fs" "io/fs"
"net/http" "net/http"
"strings" "strings"
"time"
"dynatron.me/x/stillbox/client" "dynatron.me/x/stillbox/client"
"dynatron.me/x/stillbox/pkg/gordio/config"
"dynatron.me/x/stillbox/pkg/gordio/database" "dynatron.me/x/stillbox/pkg/gordio/database"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware" "github.com/go-chi/chi/v5/middleware"
@ -31,7 +31,7 @@ func (s *Server) setupRoutes() {
}) })
r.Group(func(r chi.Router) { r.Group(func(r chi.Router) {
r.Use(rateLimiter()) s.rateLimit(r)
r.Use(render.SetContentType(render.ContentTypeJSON)) r.Use(render.SetContentType(render.ContentTypeJSON))
// public routes // public routes
s.auth.PublicRoutes(r) s.auth.PublicRoutes(r)
@ -39,7 +39,8 @@ func (s *Server) setupRoutes() {
}) })
r.Group(func(r chi.Router) { r.Group(func(r chi.Router) {
r.Use(rateLimiter(), s.auth.VerifyMiddleware()) s.rateLimit(r)
r.Use(s.auth.VerifyMiddleware())
// optional auth routes // optional auth routes
@ -47,8 +48,13 @@ func (s *Server) setupRoutes() {
}) })
} }
func rateLimiter() func(http.Handler) http.Handler { func (s *Server) rateLimit(r chi.Router) {
return httprate.LimitByRealIP(100, 1*time.Minute) if s.conf.RateLimit.Verify() {
r.Use(rateLimiter(&s.conf.RateLimit))
}
}
func rateLimiter(cfg *config.RateLimit) func(http.Handler) http.Handler {
return httprate.LimitByRealIP(cfg.Requests, cfg.Over)
} }
func (s *Server) clientRoute(r chi.Router, clientRoot fs.FS) { func (s *Server) clientRoute(r chi.Router, clientRoot fs.FS) {
@ -57,11 +63,5 @@ func (s *Server) clientRoute(r chi.Router, clientRoot fs.FS) {
pathPrefix := strings.TrimSuffix(rctx.RoutePattern(), "/*") pathPrefix := strings.TrimSuffix(rctx.RoutePattern(), "/*")
fs := http.StripPrefix(pathPrefix, http.FileServer(http.FS(clientRoot))) fs := http.StripPrefix(pathPrefix, http.FileServer(http.FS(clientRoot)))
fs.ServeHTTP(w, r) fs.ServeHTTP(w, r)
/*
if cl, authenticated := s.auth.Authenticated(r); authenticated {
w.Write([]byte("Hello " + cl["user"].(string) + "\n"))
}
w.Write([]byte("Welcome to gordio\n"))
*/
}) })
} }

View file

@ -6,7 +6,7 @@ import (
) )
var ( var (
Name = "gordio" Name = "gordio"
Version = "unset" Version = "unset"
Built = "unset" Built = "unset"
) )