package httpapp

import (
	"net/http"
	"strconv"
	"sync"
	"time"

	"gitea.avt.data-center.id/othman.suseno/atlas/internal/errors"
)

// RateLimiter implements per-client token-bucket rate limiting.
//
// Each client key gets its own bucket holding at most rate tokens. A bucket
// refills at rate tokens per window and every allowed request consumes one
// token. A background goroutine started by NewRateLimiter evicts buckets
// that have not been updated for more than two windows; call Stop to end it.
type RateLimiter struct {
	mu          sync.RWMutex
	clients     map[string]*clientLimiter
	rate        int           // requests per window (also the bucket capacity)
	window      time.Duration // time window
	cleanupTick *time.Ticker
	stopCleanup chan struct{}
	stopOnce    sync.Once // guards close(stopCleanup) so Stop is idempotent
}

// clientLimiter is the token bucket for a single client key. Its fields are
// guarded by mu.
type clientLimiter struct {
	tokens     int
	lastUpdate time.Time // reference point from which refill is measured
	mu         sync.Mutex
}

// NewRateLimiter creates a rate limiter that allows up to rate requests per
// window for each client key. It starts a cleanup goroutine that runs every
// 5 minutes until Stop is called.
func NewRateLimiter(rate int, window time.Duration) *RateLimiter {
	rl := &RateLimiter{
		clients:     make(map[string]*clientLimiter),
		rate:        rate,
		window:      window,
		cleanupTick: time.NewTicker(5 * time.Minute),
		stopCleanup: make(chan struct{}),
	}

	// Start cleanup goroutine; it exits when Stop closes stopCleanup.
	go rl.cleanup()

	return rl
}

// cleanup periodically removes client limiters whose lastUpdate is more than
// two windows old. It returns once Stop closes stopCleanup.
func (rl *RateLimiter) cleanup() {
	for {
		select {
		case <-rl.cleanupTick.C:
			rl.mu.Lock()
			now := time.Now()
			for key, limiter := range rl.clients {
				limiter.mu.Lock()
				// Remove if last update was more than 2 windows ago.
				if now.Sub(limiter.lastUpdate) > rl.window*2 {
					delete(rl.clients, key)
				}
				limiter.mu.Unlock()
			}
			rl.mu.Unlock()
		case <-rl.stopCleanup:
			return
		}
	}
}

// Stop stops the cleanup goroutine and its ticker. It is safe to call more
// than once.
func (rl *RateLimiter) Stop() {
	rl.stopOnce.Do(func() {
		rl.cleanupTick.Stop()
		close(rl.stopCleanup)
	})
}

// Allow reports whether a request from the given key should be allowed.
// When it returns true, one token has been consumed from the key's bucket.
func (rl *RateLimiter) Allow(key string) bool {
	limiter := rl.limiterFor(key)

	limiter.mu.Lock()
	defer limiter.mu.Unlock()

	rl.refill(limiter, time.Now())

	if limiter.tokens <= 0 {
		return false
	}
	limiter.tokens--
	return true
}

// limiterFor returns the bucket for key, creating a full one if none exists.
func (rl *RateLimiter) limiterFor(key string) *clientLimiter {
	// Fast path: existing keys only need the shared read lock.
	rl.mu.RLock()
	limiter, ok := rl.clients[key]
	rl.mu.RUnlock()
	if ok {
		return limiter
	}

	rl.mu.Lock()
	defer rl.mu.Unlock()
	// Re-check: another goroutine may have inserted the key between the
	// RUnlock above and the Lock.
	if limiter, ok = rl.clients[key]; !ok {
		limiter = &clientLimiter{
			tokens:     rl.rate,
			lastUpdate: time.Now(),
		}
		rl.clients[key] = limiter
	}
	return limiter
}

// refill credits l with the whole tokens earned since l.lastUpdate.
// The caller must hold l.mu.
func (rl *RateLimiter) refill(l *clientLimiter, now time.Time) {
	elapsed := now.Sub(l.lastUpdate)

	if elapsed >= rl.window {
		// Full refill.
		l.tokens = rl.rate
	} else if earned := int(float64(rl.rate) * elapsed.Seconds() / rl.window.Seconds()); earned > 0 {
		l.tokens = min(l.tokens+earned, rl.rate)
		// Advance lastUpdate only by the time that paid for the whole tokens
		// credited. The fractional remainder then carries over to the next
		// call instead of being discarded. Otherwise a client retrying faster
		// than one token interval would never be refilled. earned > 0 implies
		// rl.rate > 0, so the division is safe.
		l.lastUpdate = l.lastUpdate.Add(time.Duration(earned) * (rl.window / time.Duration(rl.rate)))
	}
	// When earned == 0, lastUpdate is deliberately left alone so partial
	// progress toward the next token is preserved.

	if l.tokens >= rl.rate {
		// A full bucket accrues nothing, so refill restarts from now.
		l.lastUpdate = now
	}
}

// getClientKey returns the rate-limiting key for r: "user:<id>" for
// authenticated requests, which gives per-user limits, and "ip:<client IP>"
// otherwise.
func getClientKey(r *http.Request) string {
	if user, ok := getUserFromContext(r); ok {
		return "user:" + user.ID
	}
	return "ip:" + getClientIP(r)
}

// rateLimitMiddleware limits each client, keyed by getClientKey, to 100
// requests per minute. Endpoints for which isPublicEndpoint returns true are
// not limited. Rejected requests get a 429 Too Many Requests error.
//
// NOTE(review): the limiter created here is never stopped. Each call to
// rateLimitMiddleware therefore leaves one cleanup goroutine and ticker
// running for the life of the process. Keeping the limiter on App and calling
// Stop on shutdown would release them.
func (a *App) rateLimitMiddleware(next http.Handler) http.Handler {
	const (
		requestsPerWindow = 100
		window            = time.Minute
	)
	rateLimiter := NewRateLimiter(requestsPerWindow, window)

	// Derive the advertised limits from the enforced configuration so the
	// two cannot drift apart.
	limitHeader := strconv.Itoa(requestsPerWindow)
	windowHeader := strconv.Itoa(int(window / time.Second))

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Skip rate limiting for public endpoints.
		if a.isPublicEndpoint(r.URL.Path, r.Method) {
			next.ServeHTTP(w, r)
			return
		}

		// Set before the check so rejected responses also carry the limit.
		w.Header().Set("X-RateLimit-Limit", limitHeader)
		w.Header().Set("X-RateLimit-Window", windowHeader)

		if !rateLimiter.Allow(getClientKey(r)) {
			// NOTE(review): a "service unavailable" error code is paired with
			// HTTP 429; use a dedicated rate-limit code if the errors package
			// defines one.
			writeError(w, errors.NewAPIError(
				errors.ErrCodeServiceUnavailable,
				"rate limit exceeded",
				http.StatusTooManyRequests,
			).WithDetails("too many requests, please try again later"))
			return
		}

		next.ServeHTTP(w, r)
	})
}

// min returns the smaller of a and b.
//
// NOTE(review): Go 1.21+ has a built-in min. This package-level helper
// shadows it and is kept only so existing callers keep compiling.
func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}