Files
storage-appliance/internal/http/middleware.go

208 lines
5.9 KiB
Go

package http
import (
	"context"
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"log"
	"net/http"
	"strings"
	"time"

	"github.com/example/storage-appliance/internal/auth"
)
// ContextKey is the type of the keys this package stores in a request
// context. A distinct named type keeps these keys from colliding with plain
// string keys set by other packages.
type ContextKey string

// Request-context keys populated by the middleware in this file.
const (
	// ContextKeyRequestID holds the request ID string assigned by RequestID.
	ContextKeyRequestID ContextKey = "request-id"
	// ContextKeyUser holds the authenticated user's username (set by AuthMiddleware).
	ContextKeyUser ContextKey = "user"
	// ContextKeyUserID holds the authenticated user's ID (set by AuthMiddleware).
	ContextKeyUserID ContextKey = "user.id"
	// ContextKeySession holds the session returned by the session store (set by AuthMiddleware).
	ContextKeySession ContextKey = "session"
)

// RequestID is middleware that assigns every request a fresh random ID. It
// sends the ID back in the X-Request-ID response header and stores it in the
// request context under ContextKeyRequestID for downstream handlers.
func RequestID(next http.Handler) http.Handler {
	// newID returns 16 bytes from crypto/rand as unpadded base64url
	// (22 characters). A request ID only correlates log lines, so a crypto/rand
	// failure falls back to a timestamp rather than failing the request.
	newID := func() string {
		var buf [16]byte
		if _, err := rand.Read(buf[:]); err != nil {
			return time.Now().UTC().Format("20060102T150405.000000000Z")
		}
		return base64.RawURLEncoding.EncodeToString(buf[:])
	}
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		id := newID()
		// Set the header before calling next so it is present even if next
		// writes the response immediately.
		w.Header().Set("X-Request-ID", id)
		ctx := context.WithValue(r.Context(), ContextKeyRequestID, id)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// Logging is middleware that, once next has finished handling the request,
// writes one line with the standard logger: the method, the URL path (no query
// string), and the elapsed wall-clock time.
func Logging(next http.Handler) http.Handler {
	logRequest := func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()
		next.ServeHTTP(w, r)
		elapsed := time.Since(began)
		log.Printf("%s %s in %v", r.Method, r.URL.Path, elapsed)
	}
	return http.HandlerFunc(logRequest)
}
// AuthMiddleware returns middleware that requires a valid session for every
// request except the public paths listed in isPublic below.
//
// The session token is read from the auth.SessionCookieName cookie and looked
// up with a session store built from app.DB; the session's user is then
// loaded by ID. On success, the username, user ID and session are stored in
// the request context under ContextKeyUser, ContextKeyUserID and
// ContextKeySession before next is called.
//
// A missing cookie or a failed session lookup sends the client to /login:
// HTMX requests (HX-Request: true) get 401 plus an HX-Redirect header, and
// other requests get a 302 redirect. A failed user lookup yields
// 401 "user not found".
func AuthMiddleware(app *App) func(http.Handler) http.Handler {
	// isPublic reports whether path is served without authentication. /login
	// and /static match only themselves or paths below them ("/static/app.css"),
	// not every path that merely starts with the same letters ("/staticX").
	isPublic := func(path string) bool {
		switch path {
		case "/login", "/static", "/healthz", "/metrics":
			return true
		}
		return strings.HasPrefix(path, "/login/") || strings.HasPrefix(path, "/static/")
	}

	// toLogin sends an unauthenticated client to the login page in the form
	// appropriate for HTMX or plain browser requests.
	toLogin := func(w http.ResponseWriter, r *http.Request) {
		if r.Header.Get("HX-Request") == "true" {
			w.Header().Set("HX-Redirect", "/login")
			w.WriteHeader(http.StatusUnauthorized)
			return
		}
		http.Redirect(w, r, "/login", http.StatusFound)
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if isPublic(r.URL.Path) {
				next.ServeHTTP(w, r)
				return
			}

			cookie, err := r.Cookie(auth.SessionCookieName)
			if err != nil {
				toLogin(w, r)
				return
			}

			// Stores are built per request, so app.DB is read at request time.
			sessionStore := auth.NewSessionStore(app.DB)
			session, err := sessionStore.GetSession(r.Context(), cookie.Value)
			if err != nil {
				toLogin(w, r)
				return
			}

			userStore := auth.NewUserStore(app.DB)
			user, err := userStore.GetUserByID(r.Context(), session.UserID)
			if err != nil {
				http.Error(w, "user not found", http.StatusUnauthorized)
				return
			}

			ctx := context.WithValue(r.Context(), ContextKeyUser, user.Username)
			ctx = context.WithValue(ctx, ContextKeyUserID, user.ID)
			ctx = context.WithValue(ctx, ContextKeySession, session)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// Auth is a legacy no-op kept so existing callers still compile. It returns
// next unchanged and performs NO authentication; wrapping a route with it
// leaves that route unprotected.
//
// Deprecated: use AuthMiddleware(app), which validates the session cookie.
func Auth(next http.Handler) http.Handler {
	passthrough := next // deliberately untouched; the router installs AuthMiddleware
	return passthrough
}
// RequireAuth is middleware that answers 401 "unauthorized" (rather than
// redirecting) when the request context carries no value under
// ContextKeyUserID, and otherwise calls next unchanged. It only inspects the
// context, so it is meaningful only after middleware that sets
// ContextKeyUserID, such as AuthMiddleware.
func RequireAuth(next http.Handler) http.Handler {
	guard := func(w http.ResponseWriter, r *http.Request) {
		if r.Context().Value(ContextKeyUserID) != nil {
			next.ServeHTTP(w, r)
			return
		}
		http.Error(w, "unauthorized", http.StatusUnauthorized)
	}
	return http.HandlerFunc(guard)
}
// CSRFMiddleware returns middleware that applies double-submit-cookie CSRF
// protection using a cookie named "csrf_token". The app parameter is unused.
//
// GET, HEAD and OPTIONS requests always pass through. If such a request lacks
// a non-empty csrf_token cookie, a fresh token is issued via Set-Cookie and is
// also placed into the request seen by downstream handlers. A handler that
// calls getCSRFToken therefore embeds the same token the browser will store.
//
// Every other method must present a token equal to the csrf_token cookie. The
// token is read from the X-CSRF-Token header (as sent by HTMX) or, if that is
// empty, from the csrf_token form field. Otherwise the request is rejected
// with 403 "invalid CSRF token".
func CSRFMiddleware(app *App) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			switch r.Method {
			case http.MethodGet, http.MethodHead, http.MethodOptions:
				if cookie, err := r.Cookie("csrf_token"); err != nil || cookie.Value == "" {
					token := generateCSRFToken()
					// NOTE(review): Secure is always false, so the cookie is also
					// sent over plain HTTP.
					http.SetCookie(w, &http.Cookie{
						Name:     "csrf_token",
						Value:    token,
						Path:     "/",
						HttpOnly: false, // Needed for HTMX to read it
						Secure:   false,
						SameSite: http.SameSiteStrictMode,
						MaxAge:   86400, // 24 hours
					})
					// Expose the new token to downstream handlers on this request.
					// Work on a clone (handlers should not mutate the incoming
					// request), and put the new cookie first so it wins over any
					// empty csrf_token cookie the client already sent.
					r = r.Clone(r.Context())
					previous := r.Header["Cookie"]
					delete(r.Header, "Cookie")
					r.AddCookie(&http.Cookie{Name: "csrf_token", Value: token})
					for _, line := range previous {
						r.Header.Add("Cookie", line)
					}
				}
				next.ServeHTTP(w, r)
				return
			}

			submitted := r.Header.Get("X-CSRF-Token")
			if submitted == "" {
				submitted = r.FormValue("csrf_token")
			}
			// A request without the cookie can never be valid. Compare in
			// constant time so response timing does not leak how much of the
			// token matched.
			cookie, err := r.Cookie("csrf_token")
			if err != nil || cookie.Value == "" || submitted == "" ||
				subtle.ConstantTimeCompare([]byte(submitted), []byte(cookie.Value)) != 1 {
				http.Error(w, "invalid CSRF token", http.StatusForbidden)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}

// getCSRFToken returns the non-empty csrf_token cookie value from r. If there
// is none, it returns a newly generated token that is NOT stored anywhere; the
// caller is responsible for setting it as a cookie if it is to be used.
func getCSRFToken(r *http.Request) string {
	if cookie, err := r.Cookie("csrf_token"); err == nil && cookie.Value != "" {
		return cookie.Value
	}
	return generateCSRFToken()
}

// generateCSRFToken returns 32 bytes from crypto/rand, encoded with padded
// base64url. It panics if crypto/rand fails: the previous behavior of ignoring
// the error would yield an all-zero, predictable token that defeats CSRF
// protection.
func generateCSRFToken() string {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		panic("http: generating CSRF token: " + err.Error())
	}
	return base64.URLEncoding.EncodeToString(b)
}
// RequirePermission returns middleware that calls next only if the user whose
// ID is stored under ContextKeyUserID has the given permission. The check is
// done through an RBAC store built from app.DB.
//
// Responses:
//   - 401 "unauthorized" when no user ID is in the context;
//   - 500 "internal error" when the user ID is not a string or the lookup fails
//     (both are logged);
//   - 403 "forbidden" when the user lacks the permission.
func RequirePermission(app *App, permission string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			raw := r.Context().Value(ContextKeyUserID)
			if raw == nil {
				http.Error(w, "unauthorized", http.StatusUnauthorized)
				return
			}
			// The comma-ok form avoids a panic if whatever set ContextKeyUserID
			// stored a non-string value.
			userID, ok := raw.(string)
			if !ok {
				log.Printf("permission check error: user ID in context has type %T, want string", raw)
				http.Error(w, "internal error", http.StatusInternalServerError)
				return
			}

			rbacStore := auth.NewRBACStore(app.DB)
			hasPermission, err := rbacStore.UserHasPermission(r.Context(), userID, permission)
			if err != nil {
				log.Printf("permission check error: %v", err)
				http.Error(w, "internal error", http.StatusInternalServerError)
				return
			}
			if !hasPermission {
				http.Error(w, "forbidden", http.StatusForbidden)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}
// RBAC is a legacy no-op kept so existing callers still compile. It ignores
// permission, and the middleware it returns hands back next unchanged, so no
// permission check is performed.
//
// Deprecated: use RequirePermission(app, permission), which enforces the check.
func RBAC(permission string) func(http.Handler) http.Handler {
	identity := func(next http.Handler) http.Handler {
		return next // deliberately untouched; the router installs RequirePermission
	}
	return identity
}