// File: calypso/backend/internal/common/database/migrations.go
// Snapshot: 2025-12-29 02:44:52 +07:00 (168 lines, 3.9 KiB, Go)

package database
import (
"context"
"embed"
"fmt"
"io/fs"
"sort"
"strconv"
"strings"
"github.com/atlasos/calypso/internal/common/logger"
)
// migrationsFS holds every *.sql file from the migrations/ directory of this
// package, embedded into the binary at build time. Files are read back through
// paths of the form "migrations/<filename>".
//
//go:embed migrations/*.sql
var migrationsFS embed.FS
// RunMigrations applies every embedded migration that is not yet recorded in
// the schema_migrations table, in ascending version order.
//
// The bookkeeping table is created first if needed. Each pending migration is
// executed and recorded inside its own transaction (see applyMigration), so a
// failing migration leaves no bookkeeping row for its version. Processing
// stops at the first failure; migrations committed before it remain applied,
// and the returned error identifies the failing step.
func RunMigrations(ctx context.Context, db *DB) error {
	log := logger.NewLogger("migrations")

	// The bookkeeping table must exist before it can be queried below.
	if err := createMigrationsTable(ctx, db); err != nil {
		return fmt.Errorf("failed to create migrations table: %w", err)
	}

	// Embedded migration files, sorted by version.
	migrations, err := getMigrationFiles()
	if err != nil {
		return fmt.Errorf("failed to read migration files: %w", err)
	}

	// Versions already recorded in schema_migrations.
	applied, err := getAppliedMigrations(ctx, db)
	if err != nil {
		return fmt.Errorf("failed to get applied migrations: %w", err)
	}

	for _, migration := range migrations {
		if applied[migration.Version] {
			log.Debug("Migration already applied", "version", migration.Version)
			continue
		}

		log.Info("Applying migration", "version", migration.Version, "name", migration.Name)
		if err := applyMigration(ctx, db, migration); err != nil {
			return err
		}
		log.Info("Migration applied successfully", "version", migration.Version)
	}
	return nil
}

// applyMigration reads one embedded migration file, executes it, and records
// its version in schema_migrations, all inside a single transaction that is
// committed only if every step succeeds.
func applyMigration(ctx context.Context, db *DB, migration Migration) error {
	script, err := migrationsFS.ReadFile(fmt.Sprintf("migrations/%s", migration.Filename))
	if err != nil {
		return fmt.Errorf("failed to read migration file %s: %w", migration.Filename, err)
	}

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// The deferred rollback covers every early return and also a panic inside
	// the exec calls, so the transaction is never left open. With
	// database/sql, Rollback after a successful Commit is a harmless no-op
	// returning sql.ErrTxDone (assumes DB.BeginTx yields a *sql.Tx — TODO
	// confirm). Its error is ignored on purpose: the caller already receives
	// the error that caused the rollback.
	defer func() { _ = tx.Rollback() }()

	if _, err := tx.ExecContext(ctx, string(script)); err != nil {
		return fmt.Errorf("failed to execute migration %d: %w", migration.Version, err)
	}

	// Record the version in the same transaction so schema change and
	// bookkeeping are committed (or discarded) together.
	if _, err := tx.ExecContext(ctx,
		"INSERT INTO schema_migrations (version, applied_at) VALUES ($1, NOW())",
		migration.Version,
	); err != nil {
		return fmt.Errorf("failed to record migration %d: %w", migration.Version, err)
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit migration %d: %w", migration.Version, err)
	}
	return nil
}
// Migration describes one embedded SQL migration file whose name follows the
// "<version>_<name>.sql" convention, e.g. "001_initial_schema.sql".
type Migration struct {
	// Version is the integer parsed from the filename prefix; it orders
	// migrations and is the value stored in schema_migrations.version.
	Version int
	// Name is the filename part after the first underscore, without ".sql";
	// Filename is the base name of the file inside migrations/.
	Name, Filename string
}
// getMigrationFiles lists the embedded migration files and returns them sorted
// by ascending version.
//
// A file counts as a migration only if its name has the form
// "<version>_<name>.sql" where <version> parses as an int; for example
// "001_initial_schema.sql" yields Version 1 and Name "initial_schema".
// Directories and files without an underscore or with a non-numeric prefix
// are skipped silently.
//
// It returns an error if the migrations directory cannot be read or if two
// files share the same version. RunMigrations tracks applied migrations by
// version alone, so an undetected duplicate would be skipped forever once its
// twin had been applied.
func getMigrationFiles() ([]Migration, error) {
	entries, err := fs.ReadDir(migrationsFS, "migrations")
	if err != nil {
		return nil, err
	}

	migrations := make([]Migration, 0, len(entries))
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		filename := entry.Name()
		if !strings.HasSuffix(filename, ".sql") {
			continue
		}

		// Split "001_initial_schema.sql" into "001" and "initial_schema.sql".
		parts := strings.SplitN(filename, "_", 2)
		if len(parts) < 2 {
			continue
		}
		version, err := strconv.Atoi(parts[0])
		if err != nil {
			continue
		}

		migrations = append(migrations, Migration{
			Version:  version,
			Name:     strings.TrimSuffix(parts[1], ".sql"),
			Filename: filename,
		})
	}

	// fs.ReadDir returns entries sorted by filename; a stable sort keeps that
	// order among equal versions, so the duplicate error below is deterministic.
	sort.SliceStable(migrations, func(i, j int) bool {
		return migrations[i].Version < migrations[j].Version
	})

	// After sorting, any duplicate versions are adjacent.
	for i := 1; i < len(migrations); i++ {
		if prev, cur := migrations[i-1], migrations[i]; prev.Version == cur.Version {
			return nil, fmt.Errorf("duplicate migration version %d: %q and %q",
				cur.Version, prev.Filename, cur.Filename)
		}
	}
	return migrations, nil
}
// createMigrationsTable ensures the schema_migrations bookkeeping table exists,
// with an integer primary-key version column and an applied_at timestamp that
// defaults to NOW(). The DDL uses IF NOT EXISTS, so an already-present table
// is left as it is. Any error from the database is returned unchanged.
func createMigrationsTable(ctx context.Context, db *DB) error {
	const ddl = `
CREATE TABLE IF NOT EXISTS schema_migrations (
version INTEGER PRIMARY KEY,
applied_at TIMESTAMP NOT NULL DEFAULT NOW()
)
`
	if _, err := db.ExecContext(ctx, ddl); err != nil {
		return err
	}
	return nil
}
// getAppliedMigrations reads every version stored in schema_migrations and
// returns them as a set (version -> true). A query or scan failure returns a
// nil map; if row iteration itself reports an error, the versions collected so
// far are returned alongside that error.
func getAppliedMigrations(ctx context.Context, db *DB) (map[int]bool, error) {
	rows, err := db.QueryContext(ctx, "SELECT version FROM schema_migrations ORDER BY version")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	seen := map[int]bool{}
	for rows.Next() {
		var v int
		if scanErr := rows.Scan(&v); scanErr != nil {
			return nil, scanErr
		}
		seen[v] = true
	}

	if iterErr := rows.Err(); iterErr != nil {
		return seen, iterErr
	}
	return seen, nil
}