This commit is contained in:
165
internal/httpapp/rate_limit.go
Normal file
165
internal/httpapp/rate_limit.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package httpapp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gitea.avt.data-center.id/othman.suseno/atlas/internal/errors"
|
||||
)
|
||||
|
||||
// RateLimiter implements token bucket rate limiting
|
||||
type RateLimiter struct {
|
||||
mu sync.RWMutex
|
||||
clients map[string]*clientLimiter
|
||||
rate int // requests per window
|
||||
window time.Duration // time window
|
||||
cleanupTick *time.Ticker
|
||||
stopCleanup chan struct{}
|
||||
}
|
||||
|
||||
type clientLimiter struct {
|
||||
tokens int
|
||||
lastUpdate time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new rate limiter
|
||||
func NewRateLimiter(rate int, window time.Duration) *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
clients: make(map[string]*clientLimiter),
|
||||
rate: rate,
|
||||
window: window,
|
||||
cleanupTick: time.NewTicker(5 * time.Minute),
|
||||
stopCleanup: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
go rl.cleanup()
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
// cleanup periodically removes old client limiters
|
||||
func (rl *RateLimiter) cleanup() {
|
||||
for {
|
||||
select {
|
||||
case <-rl.cleanupTick.C:
|
||||
rl.mu.Lock()
|
||||
now := time.Now()
|
||||
for key, limiter := range rl.clients {
|
||||
limiter.mu.Lock()
|
||||
// Remove if last update was more than 2 windows ago
|
||||
if now.Sub(limiter.lastUpdate) > rl.window*2 {
|
||||
delete(rl.clients, key)
|
||||
}
|
||||
limiter.mu.Unlock()
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
case <-rl.stopCleanup:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the cleanup goroutine
|
||||
func (rl *RateLimiter) Stop() {
|
||||
rl.cleanupTick.Stop()
|
||||
close(rl.stopCleanup)
|
||||
}
|
||||
|
||||
// Allow checks if a request from the given key should be allowed
|
||||
func (rl *RateLimiter) Allow(key string) bool {
|
||||
rl.mu.Lock()
|
||||
limiter, exists := rl.clients[key]
|
||||
if !exists {
|
||||
limiter = &clientLimiter{
|
||||
tokens: rl.rate,
|
||||
lastUpdate: time.Now(),
|
||||
}
|
||||
rl.clients[key] = limiter
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
|
||||
limiter.mu.Lock()
|
||||
defer limiter.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(limiter.lastUpdate)
|
||||
|
||||
// Refill tokens based on elapsed time
|
||||
if elapsed >= rl.window {
|
||||
// Full refill
|
||||
limiter.tokens = rl.rate
|
||||
} else {
|
||||
// Partial refill based on elapsed time
|
||||
tokensToAdd := int(float64(rl.rate) * elapsed.Seconds() / rl.window.Seconds())
|
||||
if tokensToAdd > 0 {
|
||||
limiter.tokens = min(limiter.tokens+tokensToAdd, rl.rate)
|
||||
}
|
||||
}
|
||||
|
||||
limiter.lastUpdate = now
|
||||
|
||||
// Check if we have tokens
|
||||
if limiter.tokens > 0 {
|
||||
limiter.tokens--
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// getClientKey extracts a key for rate limiting from the request
|
||||
func getClientKey(r *http.Request) string {
|
||||
// Try to get IP address
|
||||
ip := getClientIP(r)
|
||||
|
||||
// If authenticated, use user ID for more granular limiting
|
||||
if user, ok := getUserFromContext(r); ok {
|
||||
return "user:" + user.ID
|
||||
}
|
||||
|
||||
return "ip:" + ip
|
||||
}
|
||||
|
||||
// rateLimitMiddleware implements rate limiting
|
||||
func (a *App) rateLimitMiddleware(next http.Handler) http.Handler {
|
||||
// Default: 100 requests per minute per client
|
||||
rateLimiter := NewRateLimiter(100, time.Minute)
|
||||
|
||||
// Store limiter for cleanup on shutdown
|
||||
// Note: Cleanup will be handled by the limiter's own cleanup goroutine
|
||||
_ = rateLimiter
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Skip rate limiting for public endpoints
|
||||
if a.isPublicEndpoint(r.URL.Path) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
key := getClientKey(r)
|
||||
if !rateLimiter.Allow(key) {
|
||||
writeError(w, errors.NewAPIError(
|
||||
errors.ErrCodeServiceUnavailable,
|
||||
"rate limit exceeded",
|
||||
http.StatusTooManyRequests,
|
||||
).WithDetails("too many requests, please try again later"))
|
||||
return
|
||||
}
|
||||
|
||||
// Add rate limit headers
|
||||
w.Header().Set("X-RateLimit-Limit", "100")
|
||||
w.Header().Set("X-RateLimit-Window", "60")
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
Reference in New Issue
Block a user