// Package migration applies versioned SQL files from a directory to a
// database and records which versions have already been applied.
package migration

import (
	"database/sql"
	"errors"
	"fmt"
	"io/ioutil"
	"path/filepath"
	"sort"
	"strings"

	"github.com/jmoiron/sqlx"
)

// Migration represents a database migration loaded from a file named
// <version>_<name>.sql.
type Migration struct {
	Version string // filename text before the first underscore
	Name    string // filename text after the first underscore, without ".sql"
	SQL     string // full contents of the migration file
}

// Migrator handles database migrations.
type Migrator struct {
	db            *sqlx.DB
	migrationsDir string
}

// NewMigrator creates a new migrator instance that reads *.sql files from
// migrationsDir and applies them through db.
func NewMigrator(db *sqlx.DB, migrationsDir string) *Migrator {
	return &Migrator{
		db:            db,
		migrationsDir: migrationsDir,
	}
}

// Run executes all pending migrations.
//
// It ensures the schema_migrations table exists, loads the migration files,
// reads the versions already recorded, and applies every migration whose
// version is not recorded yet, in sorted version order. It stops at the
// first failure and returns that error wrapped with context; migrations
// applied before the failure stay applied.
func (m *Migrator) Run() error {
	if err := m.createMigrationsTable(); err != nil {
		return fmt.Errorf("failed to create migrations table: %w", err)
	}

	migrations, err := m.loadMigrations()
	if err != nil {
		return fmt.Errorf("failed to load migrations: %w", err)
	}

	appliedMigrations, err := m.getAppliedMigrations()
	if err != nil {
		return fmt.Errorf("failed to get applied migrations: %w", err)
	}

	// Index applied versions once so each membership test is a map lookup
	// instead of a scan over the whole slice.
	applied := make(map[string]struct{}, len(appliedMigrations))
	for _, version := range appliedMigrations {
		applied[version] = struct{}{}
	}

	for _, migration := range migrations {
		if _, done := applied[migration.Version]; done {
			continue
		}
		if err := m.applyMigration(migration); err != nil {
			return fmt.Errorf("failed to apply migration %s: %w", migration.Version, err)
		}
		fmt.Printf("Applied migration: %s - %s\n", migration.Version, migration.Name)
	}

	return nil
}

// createMigrationsTable creates the migrations tracking table if it does not
// already exist.
func (m *Migrator) createMigrationsTable() error {
	query := `
		CREATE TABLE IF NOT EXISTS schema_migrations (
			version VARCHAR(255) PRIMARY KEY,
			applied_at TIMESTAMP NOT NULL DEFAULT NOW()
		)
	`
	_, err := m.db.Exec(query)
	return err
}

// loadMigrations loads all migration files from the migrations directory.
//
// Only files matching *.sql whose base name contains an underscore are
// treated as migrations; anything else is skipped without being read. The
// result is sorted by Version using plain string comparison, so versions
// should be fixed-width (zero-padded numbers or timestamps): "10" sorts
// before "9".
//
// It returns an error if the glob pattern is malformed, if a migration file
// cannot be read, or if two files share the same version.
func (m *Migrator) loadMigrations() ([]Migration, error) {
	files, err := filepath.Glob(filepath.Join(m.migrationsDir, "*.sql"))
	if err != nil {
		return nil, err
	}

	var migrations []Migration
	seen := make(map[string]string, len(files)) // version -> filename that claimed it
	for _, file := range files {
		// Check the name before touching the file so that unrelated *.sql
		// files cost no I/O and cannot abort the run.
		filename := filepath.Base(file)
		parts := strings.SplitN(filename, "_", 2)
		if len(parts) != 2 {
			continue // Skip files that don't match the pattern
		}
		version := parts[0]
		name := strings.TrimSuffix(parts[1], ".sql")

		// Version is the primary key of schema_migrations, so only one file
		// per version could ever be recorded. Reject the ambiguity up front
		// rather than failing mid-run or silently skipping one of the files.
		if prev, dup := seen[version]; dup {
			return nil, fmt.Errorf("duplicate migration version %q: %s and %s", version, prev, filename)
		}
		seen[version] = filename

		content, err := ioutil.ReadFile(file)
		if err != nil {
			return nil, err
		}

		migrations = append(migrations, Migration{
			Version: version,
			Name:    name,
			SQL:     string(content),
		})
	}

	sort.Slice(migrations, func(i, j int) bool {
		return migrations[i].Version < migrations[j].Version
	})

	return migrations, nil
}

// getAppliedMigrations returns the versions recorded in schema_migrations,
// ordered by version.
func (m *Migrator) getAppliedMigrations() ([]string, error) {
	var versions []string
	err := m.db.Select(&versions, "SELECT version FROM schema_migrations ORDER BY version")
	return versions, err
}

// applyMigration applies a single migration.
//
// The migration SQL and the schema_migrations insert run in one transaction
// that is committed only if both succeed. On any failure the transaction is
// rolled back; if the rollback itself fails, that is appended to the
// returned error. The original error remains reachable via errors.Is/As.
func (m *Migrator) applyMigration(migration Migration) (err error) {
	tx, err := m.db.Beginx()
	if err != nil {
		return fmt.Errorf("beginning transaction: %w", err)
	}
	defer func() {
		if err == nil {
			return
		}
		// sql.ErrTxDone only means the transaction was already finished
		// (for example, a failed Commit), so it is not worth reporting.
		if rbErr := tx.Rollback(); rbErr != nil && !errors.Is(rbErr, sql.ErrTxDone) {
			err = fmt.Errorf("%w (rollback also failed: %v)", err, rbErr)
		}
	}()

	if _, err = tx.Exec(migration.SQL); err != nil {
		return fmt.Errorf("executing migration SQL: %w", err)
	}

	// NOTE(review): the $1 placeholder and NOW() default look
	// PostgreSQL-specific; confirm the target driver.
	if _, err = tx.Exec("INSERT INTO schema_migrations (version) VALUES ($1)", migration.Version); err != nil {
		return fmt.Errorf("recording migration version: %w", err)
	}

	if err = tx.Commit(); err != nil {
		return fmt.Errorf("committing transaction: %w", err)
	}
	return nil
}

// contains checks if a slice contains a string.
//
// Run no longer calls it (it uses a set lookup instead); it is retained in
// case other files in the package use it.
func contains(slice []string, item string) bool {
	for _, s := range slice {
		if s == item {
			return true
		}
	}
	return false
}