// Package migration applies versioned SQL migration files to a database and
// records each applied version in a schema_migrations table.
package migration
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"path/filepath"
|
|
"sort"
|
|
"strings"
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
)
|
|
|
|
// Migration is one schema change loaded from a "<version>_<name>.sql" file.
type Migration struct {
	// Version is the part of the filename before the first underscore. It is
	// the key written to schema_migrations once the migration is applied.
	Version string
	// Name is the rest of the filename with the ".sql" suffix removed.
	Name string
	// SQL is the full text of the migration file, executed as-is.
	SQL string
}
|
|
|
|
// Migrator handles database migrations
|
|
type Migrator struct {
|
|
db *sqlx.DB
|
|
migrationsDir string
|
|
}
|
|
|
|
// NewMigrator creates a new migrator instance
|
|
func NewMigrator(db *sqlx.DB, migrationsDir string) *Migrator {
|
|
return &Migrator{
|
|
db: db,
|
|
migrationsDir: migrationsDir,
|
|
}
|
|
}
|
|
|
|
// Run executes all pending migrations
|
|
func (m *Migrator) Run() error {
|
|
// Create migrations table if it doesn't exist
|
|
if err := m.createMigrationsTable(); err != nil {
|
|
return fmt.Errorf("failed to create migrations table: %w", err)
|
|
}
|
|
|
|
// Get all migration files
|
|
migrations, err := m.loadMigrations()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load migrations: %w", err)
|
|
}
|
|
|
|
// Get applied migrations
|
|
appliedMigrations, err := m.getAppliedMigrations()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get applied migrations: %w", err)
|
|
}
|
|
|
|
// Apply pending migrations
|
|
for _, migration := range migrations {
|
|
if !contains(appliedMigrations, migration.Version) {
|
|
if err := m.applyMigration(migration); err != nil {
|
|
return fmt.Errorf("failed to apply migration %s: %w", migration.Version, err)
|
|
}
|
|
fmt.Printf("Applied migration: %s - %s\n", migration.Version, migration.Name)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// createMigrationsTable creates the migrations tracking table
|
|
func (m *Migrator) createMigrationsTable() error {
|
|
query := `
|
|
CREATE TABLE IF NOT EXISTS schema_migrations (
|
|
version VARCHAR(255) PRIMARY KEY,
|
|
applied_at TIMESTAMP NOT NULL DEFAULT NOW()
|
|
)
|
|
`
|
|
_, err := m.db.Exec(query)
|
|
return err
|
|
}
|
|
|
|
// loadMigrations loads all migration files from the migrations directory
|
|
func (m *Migrator) loadMigrations() ([]Migration, error) {
|
|
files, err := filepath.Glob(filepath.Join(m.migrationsDir, "*.sql"))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var migrations []Migration
|
|
for _, file := range files {
|
|
content, err := ioutil.ReadFile(file)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
filename := filepath.Base(file)
|
|
parts := strings.SplitN(filename, "_", 2)
|
|
if len(parts) != 2 {
|
|
continue // Skip files that don't match the pattern
|
|
}
|
|
|
|
version := parts[0]
|
|
name := strings.TrimSuffix(parts[1], ".sql")
|
|
|
|
migrations = append(migrations, Migration{
|
|
Version: version,
|
|
Name: name,
|
|
SQL: string(content),
|
|
})
|
|
}
|
|
|
|
// Sort migrations by version
|
|
sort.Slice(migrations, func(i, j int) bool {
|
|
return migrations[i].Version < migrations[j].Version
|
|
})
|
|
|
|
return migrations, nil
|
|
}
|
|
|
|
// getAppliedMigrations returns a list of applied migration versions
|
|
func (m *Migrator) getAppliedMigrations() ([]string, error) {
|
|
var versions []string
|
|
err := m.db.Select(&versions, "SELECT version FROM schema_migrations ORDER BY version")
|
|
return versions, err
|
|
}
|
|
|
|
// applyMigration applies a single migration
|
|
func (m *Migrator) applyMigration(migration Migration) error {
|
|
tx, err := m.db.Beginx()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
// Execute the migration SQL
|
|
_, err = tx.Exec(migration.SQL)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Record the migration as applied
|
|
_, err = tx.Exec("INSERT INTO schema_migrations (version) VALUES ($1)", migration.Version)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
// contains reports whether item is equal to at least one element of slice.
// A nil or empty slice contains nothing.
func contains(slice []string, item string) bool {
	found := false
	for i := 0; i < len(slice) && !found; i++ {
		found = slice[i] == item
	}
	return found
}