package database

import (
	"context"
	"embed"
	"fmt"
	"io/fs"
	"path"
	"sort"
	"strconv"
	"strings"

	"github.com/atlasos/calypso/internal/common/logger"
)

// migrationsFS holds the *.sql files from the migrations directory,
// compiled into the binary.
//
//go:embed migrations/*.sql
var migrationsFS embed.FS

// RunMigrations executes all pending database migrations.
//
// It ensures the schema_migrations table exists, then applies every embedded
// migration whose version is not yet recorded there, in ascending version
// order. Each migration's SQL and the row recording it run in one transaction.
// Processing stops at the first failure and the wrapped error is returned;
// migrations committed before the failure stay applied.
func RunMigrations(ctx context.Context, db *DB) error {
	log := logger.NewLogger("migrations")

	if err := createMigrationsTable(ctx, db); err != nil {
		return fmt.Errorf("failed to create migrations table: %w", err)
	}

	migrations, err := getMigrationFiles()
	if err != nil {
		return fmt.Errorf("failed to read migration files: %w", err)
	}

	applied, err := getAppliedMigrations(ctx, db)
	if err != nil {
		return fmt.Errorf("failed to get applied migrations: %w", err)
	}

	for _, migration := range migrations {
		if applied[migration.Version] {
			log.Debug("Migration already applied", "version", migration.Version)
			continue
		}

		log.Info("Applying migration", "version", migration.Version, "name", migration.Name)
		if err := applyMigration(ctx, db, migration); err != nil {
			return err
		}
		log.Info("Migration applied successfully", "version", migration.Version)
	}

	return nil
}

// applyMigration reads one embedded migration file, executes it and inserts
// its version into schema_migrations, all within a single transaction.
// Nothing is committed unless both statements succeed.
func applyMigration(ctx context.Context, db *DB, migration Migration) error {
	script, err := migrationsFS.ReadFile(path.Join("migrations", migration.Filename))
	if err != nil {
		return fmt.Errorf("failed to read migration file %s: %w", migration.Filename, err)
	}

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Roll back on every early return and on panic. Assuming DB wraps
	// *sql.DB, Rollback after a successful Commit is a no-op that returns
	// sql.ErrTxDone, so its error is deliberately ignored here.
	defer func() { _ = tx.Rollback() }()

	if _, err := tx.ExecContext(ctx, string(script)); err != nil {
		return fmt.Errorf("failed to execute migration %d: %w", migration.Version, err)
	}

	if _, err := tx.ExecContext(ctx,
		"INSERT INTO schema_migrations (version, applied_at) VALUES ($1, NOW())",
		migration.Version,
	); err != nil {
		return fmt.Errorf("failed to record migration %d: %w", migration.Version, err)
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit migration %d: %w", migration.Version, err)
	}
	return nil
}

// Migration represents a database migration.
type Migration struct {
	// Version is the integer prefix of the filename (e.g. 1 for 001_init.sql).
	Version int
	// Name is the part of the filename after the first "_", without ".sql".
	Name string
	// Filename is the base name of the file inside the migrations directory.
	Filename string
}

// getMigrationFiles returns all embedded migration files sorted by version.
//
// A file counts as a migration when its name has the form
// "<version>_<name>.sql" with an integer version, e.g. 001_initial_schema.sql.
// Directories and files not matching that form are skipped.
//
// Two files that parse to the same version are rejected with an error. Without
// this check their relative order would be arbitrary. Only one of them could
// ever be recorded in schema_migrations, so the other would fail on its first
// run and be silently skipped on every run after that.
func getMigrationFiles() ([]Migration, error) {
	entries, err := fs.ReadDir(migrationsFS, "migrations")
	if err != nil {
		return nil, err
	}

	var migrations []Migration
	seen := make(map[int]string, len(entries))
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}

		filename := entry.Name()
		if !strings.HasSuffix(filename, ".sql") {
			continue
		}

		// Parse version from filename: 001_initial_schema.sql
		parts := strings.SplitN(filename, "_", 2)
		if len(parts) < 2 {
			continue
		}
		version, err := strconv.Atoi(parts[0])
		if err != nil {
			continue
		}

		if other, dup := seen[version]; dup {
			return nil, fmt.Errorf("duplicate migration version %d: %q and %q", version, other, filename)
		}
		seen[version] = filename

		migrations = append(migrations, Migration{
			Version:  version,
			Name:     strings.TrimSuffix(parts[1], ".sql"),
			Filename: filename,
		})
	}

	// Versions are unique at this point, so an unstable sort is deterministic.
	sort.Slice(migrations, func(i, j int) bool {
		return migrations[i].Version < migrations[j].Version
	})

	return migrations, nil
}

// createMigrationsTable creates the schema_migrations bookkeeping table if it
// does not already exist.
func createMigrationsTable(ctx context.Context, db *DB) error {
	query := `
		CREATE TABLE IF NOT EXISTS schema_migrations (
			version INTEGER PRIMARY KEY,
			applied_at TIMESTAMP NOT NULL DEFAULT NOW()
		)
	`
	_, err := db.ExecContext(ctx, query)
	return err
}

// getAppliedMigrations returns the set of migration versions recorded in
// schema_migrations, as a map from version to true.
func getAppliedMigrations(ctx context.Context, db *DB) (map[int]bool, error) {
	rows, err := db.QueryContext(ctx, "SELECT version FROM schema_migrations ORDER BY version")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	applied := make(map[int]bool)
	for rows.Next() {
		var version int
		if err := rows.Scan(&version); err != nil {
			return nil, err
		}
		applied[version] = true
	}

	// rows.Err reports any error that ended iteration early.
	return applied, rows.Err()
}