package httpapp

import (
	"fmt"
	"net"
	"net/http"
	"strings"
)

// auditMiddleware records an audit entry for every request whose method is
// not GET, HEAD or OPTIONS and whose path is not a public endpoint. The entry
// is written after the wrapped handler returns, so the outcome ("success" or
// "failure") can be derived from the response status code.
func (a *App) auditMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Read-only methods are passed through without auditing.
		switch r.Method {
		case http.MethodGet, http.MethodHead, http.MethodOptions:
			next.ServeHTTP(w, r)
			return
		}

		// Public endpoints are not audited either.
		if a.isPublicEndpoint(r.URL.Path) {
			next.ServeHTTP(w, r)
			return
		}

		// Requests without an authenticated user are attributed to "system".
		actor := "system"
		if user, ok := getUserFromContext(r); ok {
			actor = user.ID
		}

		// Capture request-derived fields before the handler runs.
		action := extractAction(r.Method, r.URL.Path)
		resource := extractResource(r.URL.Path)
		ip := getClientIP(r)
		userAgent := r.UserAgent()

		// Wrap the writer so the final status code is observable. Handlers
		// that never call WriteHeader implicitly respond with 200 OK.
		rw := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
		next.ServeHTTP(rw, r)

		result, message := "success", ""
		if rw.statusCode >= 400 {
			result = "failure"
			message = http.StatusText(rw.statusCode)
		}
		a.auditStore.Log(actor, action, resource, result, message, ip, userAgent)
	})
}

// responseWriter wraps an http.ResponseWriter to record the status code that
// was actually sent to the client.
type responseWriter struct {
	http.ResponseWriter
	statusCode int
	// wroteHeader is set once the final (non-informational) status has been
	// sent, either explicitly via WriteHeader or implicitly by Write.
	wroteHeader bool
}

// WriteHeader records the status code and forwards the call. Only the first
// final status is recorded. Later, superfluous calls are still forwarded
// (net/http ignores them), but they do not overwrite what the client received.
func (rw *responseWriter) WriteHeader(code int) {
	if !rw.wroteHeader {
		rw.statusCode = code
		// 1xx informational responses (other than 101 Switching Protocols)
		// may precede the final status, so they do not lock it in.
		rw.wroteHeader = code >= 200 || code == http.StatusSwitchingProtocols
	}
	rw.ResponseWriter.WriteHeader(code)
}

// Write forwards to the wrapped writer. A Write before any final WriteHeader
// implicitly sends 200 OK, so that status is recorded and locked in here.
func (rw *responseWriter) Write(b []byte) (int, error) {
	if !rw.wroteHeader {
		rw.statusCode = http.StatusOK
		rw.wroteHeader = true
	}
	return rw.ResponseWriter.Write(b)
}

// Unwrap returns the wrapped ResponseWriter. This lets http.ResponseController
// reach optional interfaces (Flusher, Hijacker, ...) that this wrapper hides.
func (rw *responseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}

// auditAPIPrefix is the versioned route prefix stripped from request paths
// before audit action and resource names are derived from them.
const auditAPIPrefix = "/api/v1"

// auditPathSegments strips auditAPIPrefix from path and splits the remainder
// on "/". The prefix is removed only when it forms whole path segments, so
// "/api/v1beta/x" is left alone. Leading and trailing slashes are also
// stripped. It returns nil when nothing remains.
func auditPathSegments(path string) []string {
	if path == auditAPIPrefix || strings.HasPrefix(path, auditAPIPrefix+"/") {
		path = path[len(auditAPIPrefix):]
	}
	path = strings.Trim(path, "/")
	if path == "" {
		return nil
	}
	return strings.Split(path, "/")
}

// extractAction derives an audit action name of the form "<resource>.<verb>"
// from the HTTP method and request path, e.g. POST /api/v1/users yields
// "users.create". Methods without a mapping use their lower-cased name as the
// verb. Paths with no resource segment use "unknown" as the resource.
func extractAction(method, path string) string {
	resource := "unknown"
	if parts := auditPathSegments(path); len(parts) > 0 {
		resource = parts[0]
	}

	var verb string
	switch method {
	case http.MethodPost:
		verb = "create"
	case http.MethodPut, http.MethodPatch:
		verb = "update"
	case http.MethodDelete:
		verb = "delete"
	default:
		verb = strings.ToLower(method)
	}
	return fmt.Sprintf("%s.%s", resource, verb)
}

// extractResource returns the audited resource identifier for a request path.
// It is the first segment after the API prefix, plus the second segment (the
// resource ID) when present, e.g. /api/v1/users/42/roles yields "users/42".
// Paths with no segments yield "unknown".
func extractResource(path string) string {
	parts := auditPathSegments(path)
	switch len(parts) {
	case 0:
		return "unknown"
	case 1:
		return parts[0]
	default:
		return fmt.Sprintf("%s/%s", parts[0], parts[1])
	}
}

// getClientIP returns a best-effort client IP for a request. It checks, in
// order: the first non-empty entry of X-Forwarded-For, then X-Real-IP, then
// the host part of r.RemoteAddr.
//
// SECURITY NOTE(review): both headers are client-controlled unless a trusted
// reverse proxy overwrites them, so the IP recorded in the audit log can be
// spoofed. Consider honoring them only when RemoteAddr is a known proxy.
func getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		// The left-most entry is the originating client as reported by proxies.
		if i := strings.IndexByte(xff, ','); i >= 0 {
			xff = xff[:i]
		}
		if ip := strings.TrimSpace(xff); ip != "" {
			return ip
		}
	}

	if xri := strings.TrimSpace(r.Header.Get("X-Real-IP")); xri != "" {
		return xri
	}

	// SplitHostPort handles bracketed IPv6 ("[::1]:8080" -> "::1"), which a
	// naive split on the last ':' gets wrong. If RemoteAddr carries no port,
	// it is returned unchanged.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}