168 lines
3.9 KiB
Go
168 lines
3.9 KiB
Go
package database
|
|
|
|
import (
|
|
"context"
|
|
"embed"
|
|
"fmt"
|
|
"io/fs"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/atlasos/calypso/internal/common/logger"
|
|
)
|
|
|
|
// migrationsFS holds the SQL migration scripts embedded into the binary at
// build time. Only files matching migrations/*.sql (relative to this
// package's directory) are included, so getMigrationFiles never sees other
// file types.
//
//go:embed migrations/*.sql
var migrationsFS embed.FS
|
|
|
|
// RunMigrations executes all pending database migrations
|
|
func RunMigrations(ctx context.Context, db *DB) error {
|
|
log := logger.NewLogger("migrations")
|
|
|
|
// Create migrations table if it doesn't exist
|
|
if err := createMigrationsTable(ctx, db); err != nil {
|
|
return fmt.Errorf("failed to create migrations table: %w", err)
|
|
}
|
|
|
|
// Get all migration files
|
|
migrations, err := getMigrationFiles()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read migration files: %w", err)
|
|
}
|
|
|
|
// Get applied migrations
|
|
applied, err := getAppliedMigrations(ctx, db)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get applied migrations: %w", err)
|
|
}
|
|
|
|
// Apply pending migrations
|
|
for _, migration := range migrations {
|
|
if applied[migration.Version] {
|
|
log.Debug("Migration already applied", "version", migration.Version)
|
|
continue
|
|
}
|
|
|
|
log.Info("Applying migration", "version", migration.Version, "name", migration.Name)
|
|
|
|
// Read migration SQL
|
|
sql, err := migrationsFS.ReadFile(fmt.Sprintf("migrations/%s", migration.Filename))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read migration file %s: %w", migration.Filename, err)
|
|
}
|
|
|
|
// Execute migration in a transaction
|
|
tx, err := db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to begin transaction: %w", err)
|
|
}
|
|
|
|
if _, err := tx.ExecContext(ctx, string(sql)); err != nil {
|
|
tx.Rollback()
|
|
return fmt.Errorf("failed to execute migration %d: %w", migration.Version, err)
|
|
}
|
|
|
|
// Record migration
|
|
if _, err := tx.ExecContext(ctx,
|
|
"INSERT INTO schema_migrations (version, applied_at) VALUES ($1, NOW())",
|
|
migration.Version,
|
|
); err != nil {
|
|
tx.Rollback()
|
|
return fmt.Errorf("failed to record migration %d: %w", migration.Version, err)
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return fmt.Errorf("failed to commit migration %d: %w", migration.Version, err)
|
|
}
|
|
|
|
log.Info("Migration applied successfully", "version", migration.Version)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Migration describes a single embedded SQL migration file.
type Migration struct {
	// Version is the integer parsed from the filename prefix before the
	// first "_" (e.g. 1 for "001_initial_schema.sql"). Migrations are
	// applied in ascending Version order and recorded by it.
	Version int

	// Name is the part of the filename after the first "_", minus ".sql".
	Name string

	// Filename is the base name of the file inside the migrations directory.
	Filename string
}
|
|
|
|
// getMigrationFiles returns all migration files sorted by version
|
|
func getMigrationFiles() ([]Migration, error) {
|
|
entries, err := fs.ReadDir(migrationsFS, "migrations")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var migrations []Migration
|
|
for _, entry := range entries {
|
|
if entry.IsDir() {
|
|
continue
|
|
}
|
|
|
|
filename := entry.Name()
|
|
if !strings.HasSuffix(filename, ".sql") {
|
|
continue
|
|
}
|
|
|
|
// Parse version from filename: 001_initial_schema.sql
|
|
parts := strings.SplitN(filename, "_", 2)
|
|
if len(parts) < 2 {
|
|
continue
|
|
}
|
|
|
|
version, err := strconv.Atoi(parts[0])
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
name := strings.TrimSuffix(parts[1], ".sql")
|
|
migrations = append(migrations, Migration{
|
|
Version: version,
|
|
Name: name,
|
|
Filename: filename,
|
|
})
|
|
}
|
|
|
|
// Sort by version
|
|
sort.Slice(migrations, func(i, j int) bool {
|
|
return migrations[i].Version < migrations[j].Version
|
|
})
|
|
|
|
return migrations, nil
|
|
}
|
|
|
|
// createMigrationsTable creates the schema_migrations table
|
|
func createMigrationsTable(ctx context.Context, db *DB) error {
|
|
query := `
|
|
CREATE TABLE IF NOT EXISTS schema_migrations (
|
|
version INTEGER PRIMARY KEY,
|
|
applied_at TIMESTAMP NOT NULL DEFAULT NOW()
|
|
)
|
|
`
|
|
_, err := db.ExecContext(ctx, query)
|
|
return err
|
|
}
|
|
|
|
// getAppliedMigrations returns a map of applied migration versions
|
|
func getAppliedMigrations(ctx context.Context, db *DB) (map[int]bool, error) {
|
|
rows, err := db.QueryContext(ctx, "SELECT version FROM schema_migrations ORDER BY version")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
applied := make(map[int]bool)
|
|
for rows.Next() {
|
|
var version int
|
|
if err := rows.Scan(&version); err != nil {
|
|
return nil, err
|
|
}
|
|
applied[version] = true
|
|
}
|
|
|
|
return applied, rows.Err()
|
|
}
|
|
|