Files
atlas/internal/httpapp/cache_middleware.go

248 lines
5.4 KiB
Go

package httpapp
import (
	"crypto/sha256"
	"encoding/hex"
	"net/http"
	"strings"
	"sync"
	"time"
)
// CacheEntry represents a cached response.
type CacheEntry struct {
	// Body is the stored response body.
	Body []byte
	// Headers holds one string value per header name.
	Headers map[string]string
	// StatusCode is the HTTP status replayed with the response.
	StatusCode int
	// ExpiresAt is stamped by ResponseCache.Set (insertion time plus the TTL);
	// Get treats the entry as absent once this instant has passed.
	ExpiresAt time.Time
	// ETag is the entity tag replayed alongside the response.
	ETag string
}

// ResponseCache is an in-memory, TTL-based store of HTTP responses keyed by
// string. All map access is guarded by mu, so it is safe for concurrent use.
// A cache built by NewResponseCache runs a background goroutine that purges
// expired entries once a minute until Close is called.
type ResponseCache struct {
	mu    sync.RWMutex
	cache map[string]*CacheEntry
	ttl   time.Duration

	// stop is closed by Close to end the cleanup goroutine; done is closed by
	// that goroutine as it exits so Close can wait for it. Both are nil for a
	// cache that was not built by NewResponseCache (no goroutine runs then).
	stop      chan struct{}
	done      chan struct{}
	closeOnce sync.Once
}

// NewResponseCache creates an empty cache whose entries live for ttl and
// starts its background cleanup goroutine. Call Close when the cache is no
// longer needed so that goroutine does not leak.
func NewResponseCache(ttl time.Duration) *ResponseCache {
	c := &ResponseCache{
		cache: make(map[string]*CacheEntry),
		ttl:   ttl,
		stop:  make(chan struct{}),
		done:  make(chan struct{}),
	}
	go c.cleanup()
	return c
}

// Close stops the background cleanup goroutine and blocks until it has
// exited. Cached entries remain readable afterwards; they are just no longer
// purged in the background. Close is safe to call more than once and is a
// no-op on a cache that was not created by NewResponseCache.
func (c *ResponseCache) Close() {
	c.closeOnce.Do(func() {
		if c.stop == nil {
			return
		}
		close(c.stop)
		<-c.done
	})
}

// cleanup periodically removes expired entries until Close is called.
func (c *ResponseCache) cleanup() {
	defer close(c.done)
	ticker := time.NewTicker(1 * time.Minute)
	defer ticker.Stop()
	for {
		select {
		case <-c.stop:
			return
		case <-ticker.C:
			c.deleteExpired(time.Now())
		}
	}
}

// deleteExpired removes every entry whose ExpiresAt lies before now.
func (c *ResponseCache) deleteExpired(now time.Time) {
	c.mu.Lock()
	defer c.mu.Unlock()
	for key, entry := range c.cache {
		if now.After(entry.ExpiresAt) {
			delete(c.cache, key)
		}
	}
}
// Get looks up key and returns its entry together with true, but only while
// the entry has not yet passed its ExpiresAt. A missing or expired entry
// yields (nil, false); expired entries are left in the map for the periodic
// cleanup to delete.
func (c *ResponseCache) Get(key string) (*CacheEntry, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	if hit, present := c.cache[key]; present && !time.Now().After(hit.ExpiresAt) {
		return hit, true
	}
	return nil, false
}
// Set stores entry under key, replacing anything already there. It stamps
// entry.ExpiresAt in place with the current time plus the cache TTL, so the
// caller's struct is modified.
func (c *ResponseCache) Set(key string, entry *CacheEntry) {
	c.mu.Lock()
	defer c.mu.Unlock()
	deadline := time.Now().Add(c.ttl)
	entry.ExpiresAt = deadline
	c.cache[key] = entry
}
// Invalidate drops the entry stored under exactly this key. Unknown keys are
// ignored.
func (c *ResponseCache) Invalidate(key string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if _, present := c.cache[key]; present {
		delete(c.cache, key)
	}
}
// InvalidatePattern drops every entry whose key begins or ends with pattern,
// as decided by containsPattern. An empty pattern matches every key and so
// clears the whole cache.
//
// NOTE(review): matching is done against the stored keys verbatim, so this is
// only useful when keys are stored in a readable (unhashed) form.
func (c *ResponseCache) InvalidatePattern(pattern string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	for k := range c.cache {
		if !containsPattern(k, pattern) {
			continue
		}
		// Deleting the current key while ranging over a map is permitted in Go.
		delete(c.cache, k)
	}
}
// containsPattern reports whether s starts with or ends with pattern.
// Despite the name, it does not match pattern in the middle of s. An empty
// pattern matches every string.
func containsPattern(s, pattern string) bool {
	return strings.HasPrefix(s, pattern) || strings.HasSuffix(s, pattern)
}
// generateCacheKey builds the cache key for r as "METHOD:path", followed by
// "?rawquery" when the URL carries a non-empty query string.
//
// The key is intentionally kept readable rather than hashed. Go map keys need
// no fixed length, and readable keys are what let
// ResponseCache.InvalidatePattern work. For example, the prefix pattern
// "GET:/api/v1/items" drops every cached query variant of that path (and of any
// path extending it). With hashed keys, no realistic pattern could ever match.
func generateCacheKey(r *http.Request) string {
	key := r.Method + ":" + r.URL.Path
	if r.URL.RawQuery != "" {
		key += "?" + r.URL.RawQuery
	}
	return key
}
// generateETag derives a strong entity tag from content. It takes the first
// half (16 bytes) of the SHA-256 digest, hex-encodes it, and wraps it in the
// double quotes that HTTP requires around an opaque tag.
func generateETag(content []byte) string {
	digest := sha256.Sum256(content)
	opaque := hex.EncodeToString(digest[:len(digest)/2])
	return "\"" + opaque + "\""
}
// cacheMiddleware serves GET responses from a.cache and fills the cache on a miss.
//
// The handler runs uncached (nothing is read or stored) when:
//   - the method is not GET;
//   - the path is not public per a.isPublicEndpoint and getUserFromContext
//     finds a user on the request, because such responses may be user-specific;
//   - a.shouldSkipCache reports the path as dynamic.
//
// On a hit, the stored headers, status and body are replayed with the stored
// ETag and "X-Cache: HIT". If the client's If-None-Match already names that
// ETag, a bodyless 304 is returned instead.
//
// On a miss, "X-Cache: MISS" is set before the handler runs, because headers
// changed after the handler has written never reach the client. The handler's
// output is then streamed through while being captured. Only 2xx responses are
// stored, and only when they neither set cookies nor opt out via Cache-Control.
// The ETag is computed from the captured body and is sent on later hits.
func (a *App) cacheMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodGet {
			next.ServeHTTP(w, r)
			return
		}
		// Authenticated requests to non-public endpoints may return
		// user-specific data; per-user caching would need the user ID in the key.
		if !a.isPublicEndpoint(r.URL.Path, r.Method) {
			if _, ok := getUserFromContext(r); ok {
				next.ServeHTTP(w, r)
				return
			}
		}
		if a.shouldSkipCache(r.URL.Path) {
			next.ServeHTTP(w, r)
			return
		}

		cacheKey := generateCacheKey(r)
		if entry, found := a.cache.Get(cacheKey); found {
			serveCachedEntry(w, r, entry)
			return
		}

		w.Header().Set("X-Cache", "MISS")
		rw := &cacheResponseWriter{
			ResponseWriter: w,
			statusCode:     http.StatusOK,
			body:           make([]byte, 0),
		}
		next.ServeHTTP(rw, r)

		if rw.statusCode < 200 || rw.statusCode >= 300 || !responseIsStorable(rw.Header()) {
			return
		}
		a.cache.Set(cacheKey, &CacheEntry{
			Body:       rw.body,
			Headers:    snapshotCacheHeaders(rw.Header()),
			StatusCode: rw.statusCode,
			ETag:       generateETag(rw.body),
		})
	})
}

// serveCachedEntry writes entry to w. It answers 304 Not Modified when the
// request's If-None-Match already names entry.ETag.
func serveCachedEntry(w http.ResponseWriter, r *http.Request, entry *CacheEntry) {
	h := w.Header()
	ifNoneMatch := strings.Join(r.Header.Values("If-None-Match"), ",")
	if ifNoneMatchSatisfied(ifNoneMatch, entry.ETag) {
		// A 304 must still carry the validator the 200 would have had (RFC 9110 §15.4.5).
		h.Set("ETag", entry.ETag)
		h.Set("X-Cache", "HIT")
		w.WriteHeader(http.StatusNotModified)
		return
	}
	for name, value := range entry.Headers {
		h.Set(name, value)
	}
	// Set after the replay loop so these win over any stored values.
	h.Set("ETag", entry.ETag)
	h.Set("X-Cache", "HIT")
	w.WriteHeader(entry.StatusCode)
	// A failed write means the client went away; there is nothing useful to do.
	_, _ = w.Write(entry.Body)
}

// ifNoneMatchSatisfied reports whether the If-None-Match field value header
// matches etag under weak comparison (RFC 9110 §13.1.2 and §13.1.3). "*"
// matches any tag. Otherwise each comma-separated tag matches when it equals
// etag once a leading "W/" has been stripped from both.
func ifNoneMatchSatisfied(header, etag string) bool {
	if header == "" || etag == "" {
		return false
	}
	want := strings.TrimPrefix(etag, "W/")
	for _, tag := range strings.Split(header, ",") {
		tag = strings.TrimSpace(tag)
		if tag == "*" || strings.TrimPrefix(tag, "W/") == want {
			return true
		}
	}
	return false
}

// responseIsStorable reports whether a response with header h may be shared
// with other clients. It returns false when the response sets cookies,
// because replaying Set-Cookie would hand one client's session to others. It
// also returns false when Cache-Control carries no-store or private.
func responseIsStorable(h http.Header) bool {
	if len(h.Values("Set-Cookie")) > 0 {
		return false
	}
	for _, value := range h.Values("Cache-Control") {
		for _, directive := range strings.Split(value, ",") {
			directive = strings.ToLower(strings.TrimSpace(directive))
			if directive == "no-store" || directive == "private" || strings.HasPrefix(directive, "private=") {
				return false
			}
		}
	}
	return true
}

// snapshotCacheHeaders flattens h into the one-string-per-name form that
// CacheEntry stores. Multiple values are joined with ", ", which is the
// equivalent combined form for list-valued fields (RFC 9110 §5.3). Set-Cookie
// is the usual exception, but it never reaches here because such responses
// are not stored.
func snapshotCacheHeaders(h http.Header) map[string]string {
	out := make(map[string]string, len(h))
	for name, values := range h {
		if len(values) > 0 {
			out[name] = strings.Join(values, ", ")
		}
	}
	return out
}
// cacheResponseWriter wraps an http.ResponseWriter and forwards everything to
// it. Along the way it records the final status code and a copy of every body
// chunk, so the response can be cached once the handler returns.
type cacheResponseWriter struct {
	http.ResponseWriter
	statusCode int
	body       []byte
	// wroteHeader is set once the final status is committed. That happens on
	// the first non-informational WriteHeader, or on the first Write, which
	// implies 200.
	wroteHeader bool
}

// WriteHeader forwards code and records it as the response status, but only
// for the first final status. net/http ignores later calls. Recording them
// would let a response sent as, say, 500 be cached as a 200. Informational 1xx
// codes (other than 101 Switching Protocols) may precede the final status, so
// they are forwarded without being recorded.
func (rw *cacheResponseWriter) WriteHeader(code int) {
	informational := code >= 100 && code < 200 && code != http.StatusSwitchingProtocols
	if !rw.wroteHeader && !informational {
		rw.statusCode = code
		rw.wroteHeader = true
	}
	rw.ResponseWriter.WriteHeader(code)
}

// Write captures b and forwards it. The first Write commits the status already
// held in statusCode (200 unless WriteHeader said otherwise). The whole chunk
// is captured even if the client write fails, so the cache holds the handler's
// complete output.
func (rw *cacheResponseWriter) Write(b []byte) (int, error) {
	rw.wroteHeader = true
	rw.body = append(rw.body, b...)
	return rw.ResponseWriter.Write(b)
}

// Unwrap returns the wrapped writer so that http.ResponseController can reach
// optional capabilities (Flush, deadlines) of the underlying connection.
func (rw *cacheResponseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}
// shouldSkipCache reports whether path is one of a fixed set of dynamic
// endpoints whose responses must never be cached. Matching is exact, so
// sub-paths and trailing-slash variants are not covered.
func (a *App) shouldSkipCache(path string) bool {
	switch path {
	case "/metrics",
		"/healthz",
		"/health",
		"/api/v1/system/info",
		"/api/v1/system/logs",
		"/api/v1/dashboard":
		return true
	default:
		return false
	}
}