// atlas/internal/httpapp/auth_middleware.go

package httpapp
import (
	"context"
	"errors"
	"net/http"
	"strings"

	"gitea.avt.data-center.id/othman.suseno/atlas/internal/auth"
	"gitea.avt.data-center.id/othman.suseno/atlas/internal/models"
)
// Context keys under which authMiddleware hands the authenticated identity to
// downstream handlers. Read them back with getUserFromContext and
// getRoleFromContext rather than querying the context directly.
const (
	// userCtxKey holds the authenticated user (read back as *models.User).
	userCtxKey ctxKey = "user"
	// roleCtxKey holds the authenticated user's role (read back as models.Role).
	roleCtxKey ctxKey = "role"
)
// authMiddleware validates JWT bearer tokens and attaches the authenticated
// user to the request context.
//
// Requests for which isPublicRequest reports true are passed to next
// untouched, and no user is placed in their context. Every other request must
// carry an "Authorization: Bearer <token>" header. The scheme name is compared
// case-insensitively (RFC 7235). The token is checked with
// authService.ValidateToken, the user named by the claims is loaded from
// userStore, and the account must be active. On success the user and its role
// are stored under userCtxKey and roleCtxKey.
//
// Failures are answered with a JSON {"error": "..."} body:
//   - 401 for a missing or malformed header, an expired or invalid token, or a
//     user that cannot be loaded;
//   - 403 for a disabled account.
func (a *App) authMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Public requests skip authentication. The read-only API endpoints
		// are only public for GET/HEAD, so mutations on them fall through to
		// the token check below.
		if a.isPublicRequest(r) {
			next.ServeHTTP(w, r)
			return
		}

		authHeader := r.Header.Get("Authorization")
		if authHeader == "" {
			writeJSON(w, http.StatusUnauthorized, map[string]string{"error": "missing authorization header"})
			return
		}

		// Expect exactly "<scheme> <token>". strings.Fields tolerates repeated
		// or surrounding whitespace. Auth scheme names are case-insensitive.
		parts := strings.Fields(authHeader)
		if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
			writeJSON(w, http.StatusUnauthorized, map[string]string{"error": "invalid authorization header format"})
			return
		}
		token := parts[1]

		claims, err := a.authService.ValidateToken(token)
		if err != nil {
			// errors.Is also recognises an expiry error that has been wrapped.
			msg := "invalid token"
			if errors.Is(err, auth.ErrExpiredToken) {
				msg = "token expired"
			}
			writeJSON(w, http.StatusUnauthorized, map[string]string{"error": msg})
			return
		}

		user, err := a.userStore.GetByID(claims.UserID)
		if err != nil {
			writeJSON(w, http.StatusUnauthorized, map[string]string{"error": "user not found"})
			return
		}
		if !user.Active {
			writeJSON(w, http.StatusForbidden, map[string]string{"error": "user account is disabled"})
			return
		}

		ctx := context.WithValue(r.Context(), userCtxKey, user)
		ctx = context.WithValue(ctx, roleCtxKey, user.Role)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// requireRole returns middleware that lets a request through only when the
// role stored in its context by authMiddleware is one of allowedRoles.
//
// It responds 401 when the context carries no role, i.e. the request was not
// authenticated (for example, it was served as public). It responds 403 when
// the role is not in allowedRoles. Called with no roles, it rejects every
// request.
func (a *App) requireRole(allowedRoles ...models.Role) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			role, ok := r.Context().Value(roleCtxKey).(models.Role)
			if !ok {
				writeJSON(w, http.StatusUnauthorized, map[string]string{"error": "unauthorized"})
				return
			}
			for _, allowed := range allowedRoles {
				if role == allowed {
					next.ServeHTTP(w, r)
					return
				}
			}
			writeJSON(w, http.StatusForbidden, map[string]string{"error": "insufficient permissions"})
		})
	}
}

// isPublicRequest reports whether r may be served without authentication.
//
// Paths accepted by isAlwaysPublicPath are public regardless of HTTP method.
// Paths accepted by isPublicReadOnlyPath are public only for GET and HEAD.
// This lets the web UI display data without logging in, while POST, PUT,
// PATCH and DELETE on the same paths still require a valid token. Other
// methods, including OPTIONS, also require a token on these paths.
func (a *App) isPublicRequest(r *http.Request) bool {
	path := r.URL.Path
	if a.isAlwaysPublicPath(path) {
		return true
	}
	switch r.Method {
	case http.MethodGet, http.MethodHead:
		return a.isPublicReadOnlyPath(path)
	default:
		return false
	}
}

// isPublicEndpoint reports whether path appears in any public list, ignoring
// the HTTP method. Because the read-only API endpoints must only be public for
// GET/HEAD, access decisions should use isPublicRequest instead. This
// path-only check is kept for existing callers.
func (a *App) isPublicEndpoint(path string) bool {
	return a.isAlwaysPublicPath(path) || a.isPublicReadOnlyPath(path)
}

// isAlwaysPublicPath reports whether path is public for every HTTP method. It
// covers:
//   - health/metrics probes;
//   - the auth login/logout endpoints;
//   - the web UI pages and their sub-pages;
//   - the API documentation;
//   - anything under /static/.
func (a *App) isAlwaysPublicPath(path string) bool {
	// The dashboard root must match exactly. Prefix-matching "/" would build
	// the prefix "//" and treat any "//..." path as a public sub-page.
	if path == "/" || strings.HasPrefix(path, "/static/") {
		return true
	}

	publicPaths := []string{
		// Health and monitoring probes.
		"/healthz",
		"/health",
		"/metrics",
		// Authentication endpoints.
		"/api/v1/auth/login",
		"/api/v1/auth/logout",
		// Web UI pages: login, storage, shares, iSCSI, data protection and
		// system management.
		"/login",
		"/storage",
		"/shares",
		"/iscsi",
		"/protection",
		"/management",
		// API documentation and OpenAPI spec.
		"/api/docs",
		"/api/openapi.yaml",
	}
	for _, p := range publicPaths {
		// Exact match, or a sub-page below the public path.
		if path == p || strings.HasPrefix(path, p+"/") {
			return true
		}
	}
	return false
}

// isPublicReadOnlyPath reports whether path is one of the read-only listing
// endpoints the web UI may query without logging in. Only exact matches count,
// so sub-paths of these endpoints still require authentication. Callers must
// additionally restrict this to safe methods (see isPublicRequest).
func (a *App) isPublicReadOnlyPath(path string) bool {
	switch path {
	case "/api/v1/dashboard",
		"/api/v1/disks",
		"/api/v1/pools",
		"/api/v1/pools/available",
		"/api/v1/datasets",
		"/api/v1/zvols",
		"/api/v1/shares/smb",
		"/api/v1/exports/nfs",
		"/api/v1/iscsi/targets",
		"/api/v1/snapshots",
		"/api/v1/snapshot-policies":
		return true
	default:
		return false
	}
}
// getUserFromContext returns the user that authMiddleware stored in r's
// context. The boolean is false when no user is present, for example on
// requests that were served as public and so never authenticated.
func getUserFromContext(r *http.Request) (*models.User, bool) {
	if u, found := r.Context().Value(userCtxKey).(*models.User); found {
		return u, true
	}
	return nil, false
}
// getRoleFromContext returns the role that authMiddleware stored in r's
// context. The boolean is false, and the role is its zero value, when no role
// is present, for example on requests that were served as public.
func getRoleFromContext(r *http.Request) (models.Role, bool) {
	if role, found := r.Context().Value(roleCtxKey).(models.Role); found {
		return role, true
	}
	var zero models.Role
	return zero, false
}