package http

import (
	"context"
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"log"
	"net/http"
	"strings"
	"time"

	"github.com/example/storage-appliance/internal/auth"
)

// ContextKey is the type of the keys this package stores in a request context.
type ContextKey string

const (
	// ContextKeyRequestID holds the per-request ID string set by RequestID.
	ContextKeyRequestID ContextKey = "request-id"
	// ContextKeyUser holds the authenticated user's Username, set by AuthMiddleware.
	ContextKeyUser ContextKey = "user"
	// ContextKeyUserID holds the authenticated user's ID, set by AuthMiddleware.
	// RequirePermission expects this value to be a string.
	ContextKeyUserID ContextKey = "user.id"
	// ContextKeySession holds the session returned by the session store, set by AuthMiddleware.
	ContextKeySession ContextKey = "session"

	// contextKeyCSRFToken carries a CSRF token minted by CSRFMiddleware during the
	// current request, before the browser has had a chance to send it back as a cookie.
	contextKeyCSRFToken ContextKey = "csrf-token"
)

// RequestID middleware assigns each request a random ID, returns it to the
// client in the X-Request-ID response header, and stores it in the request
// context under ContextKeyRequestID for downstream handlers.
func RequestID(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		id := base64.RawURLEncoding.EncodeToString(mustRandomBytes(12))
		w.Header().Set("X-Request-ID", id)
		ctx := context.WithValue(r.Context(), ContextKeyRequestID, id)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// Logging middleware logs the method, path, and handling duration of each
// request after the wrapped handler returns.
func Logging(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()
		next.ServeHTTP(w, r)
		log.Printf("%s %s in %v", r.Method, r.URL.Path, time.Since(start))
	})
}

// AuthMiddleware returns middleware that requires a valid session cookie on
// every request except the auth-exempt paths (see isAuthExemptPath).
//
// On success it stores the user's Username, ID, and session in the request
// context (ContextKeyUser, ContextKeyUserID, ContextKeySession). When the
// session cookie is missing or the session lookup fails, the client is sent
// to /login. When the session's user cannot be loaded, it responds with
// 401 "user not found".
func AuthMiddleware(app *App) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if isAuthExemptPath(r.URL.Path) {
				next.ServeHTTP(w, r)
				return
			}

			// A 302 would be followed transparently by an HTMX (XHR) request, so
			// HTMX requests get an HX-Redirect header plus 401 instead.
			redirectToLogin := func() {
				if r.Header.Get("HX-Request") == "true" {
					w.Header().Set("HX-Redirect", "/login")
					w.WriteHeader(http.StatusUnauthorized)
					return
				}
				http.Redirect(w, r, "/login", http.StatusFound)
			}

			cookie, err := r.Cookie(auth.SessionCookieName)
			if err != nil {
				redirectToLogin()
				return
			}

			sessionStore := auth.NewSessionStore(app.DB)
			session, err := sessionStore.GetSession(r.Context(), cookie.Value)
			if err != nil {
				redirectToLogin()
				return
			}

			userStore := auth.NewUserStore(app.DB)
			user, err := userStore.GetUserByID(r.Context(), session.UserID)
			if err != nil {
				http.Error(w, "user not found", http.StatusUnauthorized)
				return
			}

			ctx := context.WithValue(r.Context(), ContextKeyUser, user.Username)
			ctx = context.WithValue(ctx, ContextKeyUserID, user.ID)
			ctx = context.WithValue(ctx, ContextKeySession, session)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// isAuthExemptPath reports whether path may be served without a session: the
// login pages, static assets, and the /healthz and /metrics endpoints.
// Prefixes are matched on a path-segment boundary, so paths such as
// "/loginfoo" or "/static-admin" are NOT exempt.
func isAuthExemptPath(path string) bool {
	switch path {
	case "/login", "/static", "/healthz", "/metrics":
		return true
	}
	return strings.HasPrefix(path, "/login/") || strings.HasPrefix(path, "/static/")
}

// Auth is a legacy no-op wrapper kept for backward compatibility. It returns
// next unchanged and performs NO authentication.
//
// Deprecated: use AuthMiddleware.
func Auth(next http.Handler) http.Handler {
	return next
}

// RequireAuth middleware rejects requests with 401 "unauthorized" unless a
// user ID is present in the context (normally set by AuthMiddleware). Unlike
// AuthMiddleware it never redirects.
func RequireAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		userID := r.Context().Value(ContextKeyUserID)
		if userID == nil {
			http.Error(w, "unauthorized", http.StatusUnauthorized)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// CSRFMiddleware returns double-submit-cookie CSRF middleware.
//
// For GET, HEAD, and OPTIONS it ensures a "csrf_token" cookie exists, minting
// one if needed. A freshly minted token is also placed in the request context,
// so getCSRFToken returns the same value to handlers in that request.
//
// For all other methods, the token from the X-CSRF-Token header (or, failing
// that, the "csrf_token" form value) must equal the "csrf_token" cookie.
// Otherwise it responds with 403.
func CSRFMiddleware(app *App) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			switch r.Method {
			case http.MethodGet, http.MethodHead, http.MethodOptions:
				if cookie, err := r.Cookie("csrf_token"); err != nil || cookie.Value == "" {
					token := generateCSRFToken()
					http.SetCookie(w, &http.Cookie{
						Name:     "csrf_token",
						Value:    token,
						Path:     "/",
						HttpOnly: false, // Needed for HTMX to read it
						// NOTE(review): hard-coded; consider setting Secure when served over HTTPS.
						Secure:   false,
						SameSite: http.SameSiteStrictMode,
						MaxAge:   86400, // 24 hours
					})
					// The cookie only reaches the browser with this response, so make
					// the token visible to handlers rendering forms in this request.
					r = r.WithContext(context.WithValue(r.Context(), contextKeyCSRFToken, token))
				}
				next.ServeHTTP(w, r)
				return
			}

			// Header first (HTMX), then form field (plain HTML forms).
			token := r.Header.Get("X-CSRF-Token")
			if token == "" {
				token = r.FormValue("csrf_token")
			}

			// A missing cookie can never be matched, so reject it outright.
			cookie, err := r.Cookie("csrf_token")
			if err != nil || cookie.Value == "" || token == "" ||
				subtle.ConstantTimeCompare([]byte(token), []byte(cookie.Value)) != 1 {
				http.Error(w, "invalid CSRF token", http.StatusForbidden)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}

// getCSRFToken returns the CSRF token for r.
//
// It tries the "csrf_token" cookie first, then a token minted by
// CSRFMiddleware earlier in this same request. If neither exists it returns a
// fresh random token. That token is not persisted anywhere, so it will not
// validate unless the caller also sets it as the "csrf_token" cookie.
func getCSRFToken(r *http.Request) string {
	if cookie, err := r.Cookie("csrf_token"); err == nil && cookie.Value != "" {
		return cookie.Value
	}
	if token, ok := r.Context().Value(contextKeyCSRFToken).(string); ok && token != "" {
		return token
	}
	return generateCSRFToken()
}

// generateCSRFToken returns 32 bytes from crypto/rand, encoded as padded
// URL-safe base64 (44 characters).
func generateCSRFToken() string {
	return base64.URLEncoding.EncodeToString(mustRandomBytes(32))
}

// mustRandomBytes returns n bytes from crypto/rand. It panics if the system
// randomness source fails: continuing would hand out predictable (all-zero)
// security tokens.
func mustRandomBytes(n int) []byte {
	b := make([]byte, n)
	if _, err := rand.Read(b); err != nil {
		panic("http: crypto/rand read failed: " + err.Error())
	}
	return b
}

// RequirePermission returns middleware that allows the request only when the
// user identified by ContextKeyUserID holds the given permission in the RBAC
// store.
//
// Responses:
//   - 401 if no user ID is in the context.
//   - 500 if the user ID is not a string, or if the permission check errors.
//   - 403 if the user lacks the permission.
func RequirePermission(app *App, permission string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			raw := r.Context().Value(ContextKeyUserID)
			if raw == nil {
				http.Error(w, "unauthorized", http.StatusUnauthorized)
				return
			}
			// Checked assertion: an unexpected type is a programming error, not a
			// reason to panic the request goroutine.
			userID, ok := raw.(string)
			if !ok {
				log.Printf("permission check: user ID in context has type %T, want string", raw)
				http.Error(w, "internal error", http.StatusInternalServerError)
				return
			}

			rbacStore := auth.NewRBACStore(app.DB)
			hasPermission, err := rbacStore.UserHasPermission(r.Context(), userID, permission)
			if err != nil {
				log.Printf("permission check error: %v", err)
				http.Error(w, "internal error", http.StatusInternalServerError)
				return
			}
			if !hasPermission {
				http.Error(w, "forbidden", http.StatusForbidden)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}

// RBAC is a legacy no-op kept for backward compatibility. The returned
// middleware passes every request through and performs NO permission check.
//
// Deprecated: use RequirePermission.
func RBAC(permission string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return next
	}
}