166 lines
3.6 KiB
Go
166 lines
3.6 KiB
Go
package httpapp
|
|
|
|
import (
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"gitea.avt.data-center.id/othman.suseno/atlas/internal/errors"
|
|
)
|
|
|
|
// RateLimiter implements token bucket rate limiting
|
|
type RateLimiter struct {
|
|
mu sync.RWMutex
|
|
clients map[string]*clientLimiter
|
|
rate int // requests per window
|
|
window time.Duration // time window
|
|
cleanupTick *time.Ticker
|
|
stopCleanup chan struct{}
|
|
}
|
|
|
|
// clientLimiter is the token bucket kept for a single client key.
type clientLimiter struct {
	// mu guards tokens and lastUpdate.
	mu sync.Mutex

	tokens     int       // requests this client may still make right now
	lastUpdate time.Time // reference instant for refills; cleanup reads it to spot idle clients
}
|
|
|
|
// NewRateLimiter creates a new rate limiter
|
|
func NewRateLimiter(rate int, window time.Duration) *RateLimiter {
|
|
rl := &RateLimiter{
|
|
clients: make(map[string]*clientLimiter),
|
|
rate: rate,
|
|
window: window,
|
|
cleanupTick: time.NewTicker(5 * time.Minute),
|
|
stopCleanup: make(chan struct{}),
|
|
}
|
|
|
|
// Start cleanup goroutine
|
|
go rl.cleanup()
|
|
|
|
return rl
|
|
}
|
|
|
|
// cleanup periodically removes old client limiters
|
|
func (rl *RateLimiter) cleanup() {
|
|
for {
|
|
select {
|
|
case <-rl.cleanupTick.C:
|
|
rl.mu.Lock()
|
|
now := time.Now()
|
|
for key, limiter := range rl.clients {
|
|
limiter.mu.Lock()
|
|
// Remove if last update was more than 2 windows ago
|
|
if now.Sub(limiter.lastUpdate) > rl.window*2 {
|
|
delete(rl.clients, key)
|
|
}
|
|
limiter.mu.Unlock()
|
|
}
|
|
rl.mu.Unlock()
|
|
case <-rl.stopCleanup:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Stop stops the cleanup goroutine
|
|
func (rl *RateLimiter) Stop() {
|
|
rl.cleanupTick.Stop()
|
|
close(rl.stopCleanup)
|
|
}
|
|
|
|
// Allow checks if a request from the given key should be allowed
|
|
func (rl *RateLimiter) Allow(key string) bool {
|
|
rl.mu.Lock()
|
|
limiter, exists := rl.clients[key]
|
|
if !exists {
|
|
limiter = &clientLimiter{
|
|
tokens: rl.rate,
|
|
lastUpdate: time.Now(),
|
|
}
|
|
rl.clients[key] = limiter
|
|
}
|
|
rl.mu.Unlock()
|
|
|
|
limiter.mu.Lock()
|
|
defer limiter.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
elapsed := now.Sub(limiter.lastUpdate)
|
|
|
|
// Refill tokens based on elapsed time
|
|
if elapsed >= rl.window {
|
|
// Full refill
|
|
limiter.tokens = rl.rate
|
|
} else {
|
|
// Partial refill based on elapsed time
|
|
tokensToAdd := int(float64(rl.rate) * elapsed.Seconds() / rl.window.Seconds())
|
|
if tokensToAdd > 0 {
|
|
limiter.tokens = min(limiter.tokens+tokensToAdd, rl.rate)
|
|
}
|
|
}
|
|
|
|
limiter.lastUpdate = now
|
|
|
|
// Check if we have tokens
|
|
if limiter.tokens > 0 {
|
|
limiter.tokens--
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// getClientKey extracts a key for rate limiting from the request
|
|
func getClientKey(r *http.Request) string {
|
|
// Try to get IP address
|
|
ip := getClientIP(r)
|
|
|
|
// If authenticated, use user ID for more granular limiting
|
|
if user, ok := getUserFromContext(r); ok {
|
|
return "user:" + user.ID
|
|
}
|
|
|
|
return "ip:" + ip
|
|
}
|
|
|
|
// rateLimitMiddleware implements rate limiting
|
|
func (a *App) rateLimitMiddleware(next http.Handler) http.Handler {
|
|
// Default: 100 requests per minute per client
|
|
rateLimiter := NewRateLimiter(100, time.Minute)
|
|
|
|
// Store limiter for cleanup on shutdown
|
|
// Note: Cleanup will be handled by the limiter's own cleanup goroutine
|
|
_ = rateLimiter
|
|
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Skip rate limiting for public endpoints
|
|
if a.isPublicEndpoint(r.URL.Path) {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
key := getClientKey(r)
|
|
if !rateLimiter.Allow(key) {
|
|
writeError(w, errors.NewAPIError(
|
|
errors.ErrCodeServiceUnavailable,
|
|
"rate limit exceeded",
|
|
http.StatusTooManyRequests,
|
|
).WithDetails("too many requests, please try again later"))
|
|
return
|
|
}
|
|
|
|
// Add rate limit headers
|
|
w.Header().Set("X-RateLimit-Limit", "100")
|
|
w.Header().Set("X-RateLimit-Window", "60")
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// min returns the smaller of the two integers a and b.
//
// Being declared at package level, it takes precedence over the min builtin
// that Go 1.21 introduced.
func min(a, b int) int {
	if b <= a {
		return b
	}
	return a
}
|