Files
geek-life/migration/migrator.go

155 lines
3.5 KiB
Go

package migration
import (
"database/sql"
"fmt"
"io/ioutil"
"path/filepath"
"sort"
"strings"
"github.com/jmoiron/sqlx"
)
// Migration is a single schema change loaded from a "<version>_<name>.sql"
// file.
//
// Version is the part of the file name before the first underscore and is
// the key recorded in the schema_migrations table. Name is the remainder of
// the file name without the ".sql" extension. SQL is the file's full
// contents, which are executed as-is when the migration is applied.
type Migration struct {
	Version, Name, SQL string
}
// Migrator handles database migrations.
//
// It reads "*.sql" files from migrationsDir, applies the ones whose version
// is not yet recorded in the schema_migrations table, and records each
// applied version there (see Run). Construct it with NewMigrator.
type Migrator struct {
db *sqlx.DB // connection used for every migration query and transaction
migrationsDir string // directory scanned for "*.sql" migration files
}
// NewMigrator returns a Migrator that applies the "*.sql" migration files
// found in migrationsDir to db. No I/O happens until Run is called.
func NewMigrator(db *sqlx.DB, migrationsDir string) *Migrator {
	m := new(Migrator)
	m.db = db
	m.migrationsDir = migrationsDir
	return m
}
// Run brings the database schema up to date.
//
// It ensures the schema_migrations tracking table exists, loads the migration
// files, and applies every migration whose version is not yet recorded, in the
// order returned by loadMigrations. Each migration is applied through
// applyMigration, which uses its own transaction. Run stops at the first
// failure, so migrations applied before that point remain applied.
// A line is printed to stdout for each migration that is applied.
func (m *Migrator) Run() error {
	if err := m.createMigrationsTable(); err != nil {
		return fmt.Errorf("failed to create migrations table: %w", err)
	}

	available, err := m.loadMigrations()
	if err != nil {
		return fmt.Errorf("failed to load migrations: %w", err)
	}

	applied, err := m.getAppliedMigrations()
	if err != nil {
		return fmt.Errorf("failed to get applied migrations: %w", err)
	}

	for _, mig := range available {
		if contains(applied, mig.Version) {
			// Already recorded in schema_migrations; nothing to do.
			continue
		}
		if err := m.applyMigration(mig); err != nil {
			return fmt.Errorf("failed to apply migration %s: %w", mig.Version, err)
		}
		fmt.Printf("Applied migration: %s - %s\n", mig.Version, mig.Name)
	}
	return nil
}
// createMigrationsTable creates the schema_migrations tracking table if it
// does not already exist. The table stores one row per applied version, with
// applied_at defaulting to NOW() (PostgreSQL-style SQL; presumably the target
// database is PostgreSQL — confirm).
func (m *Migrator) createMigrationsTable() error {
	const ddl = `
CREATE TABLE IF NOT EXISTS schema_migrations (
version VARCHAR(255) PRIMARY KEY,
applied_at TIMESTAMP NOT NULL DEFAULT NOW()
)
`
	if _, err := m.db.Exec(ddl); err != nil {
		return err
	}
	return nil
}
// loadMigrations reads the "*.sql" files in m.migrationsDir and returns them
// as Migrations sorted by Version.
//
// A file counts as a migration only when its base name has the form
// "<version>_<name>.sql" with a non-empty <version>. Any other file is skipped
// without being read. Versions are ordered by plain string comparison, so
// numeric versions must be fixed-width (zero-padded or timestamps) for "2" to
// sort before "10".
//
// An error is returned when the glob pattern is invalid, when a migration file
// cannot be read, or when two files declare the same version. Without the
// duplicate check, one of the two would be skipped forever once the other is
// recorded in schema_migrations.
func (m *Migrator) loadMigrations() ([]Migration, error) {
	files, err := filepath.Glob(filepath.Join(m.migrationsDir, "*.sql"))
	if err != nil {
		return nil, err
	}

	migrations := make([]Migration, 0, len(files))
	// version -> file name that declared it, used to reject duplicates.
	seen := make(map[string]string, len(files))
	for _, file := range files {
		filename := filepath.Base(file)
		parts := strings.SplitN(filename, "_", 2)
		if len(parts) != 2 || parts[0] == "" {
			// Not "<version>_<name>.sql": skip it before paying for (or
			// failing on) reading its contents.
			continue
		}
		version := parts[0]
		name := strings.TrimSuffix(parts[1], ".sql")

		if prev, dup := seen[version]; dup {
			return nil, fmt.Errorf("duplicate migration version %q in %s and %s", version, prev, filename)
		}
		seen[version] = filename

		content, err := ioutil.ReadFile(file)
		if err != nil {
			return nil, err
		}
		migrations = append(migrations, Migration{
			Version: version,
			Name:    name,
			SQL:     string(content),
		})
	}

	// Versions are unique at this point, so this is a strict total order.
	sort.Slice(migrations, func(i, j int) bool {
		return migrations[i].Version < migrations[j].Version
	})
	return migrations, nil
}
// getAppliedMigrations returns every version recorded in schema_migrations,
// sorted by the database's ORDER BY on the version column. On error the
// returned slice should not be relied upon.
func (m *Migrator) getAppliedMigrations() ([]string, error) {
	const q = "SELECT version FROM schema_migrations ORDER BY version"
	var applied []string
	if err := m.db.Select(&applied, q); err != nil {
		return applied, err
	}
	return applied, nil
}
// applyMigration runs migration.SQL and records migration.Version in
// schema_migrations, both inside one transaction. Either both take effect or,
// on any error, the transaction is rolled back and the error is returned.
// The INSERT uses a "$1" placeholder (PostgreSQL-style; presumably the target
// database is PostgreSQL — confirm).
func (m *Migrator) applyMigration(migration Migration) error {
	tx, err := m.db.Beginx()
	if err != nil {
		return err
	}
	// Undoes everything on an early return. After a successful Commit the
	// transaction is already finished, and the Rollback error is
	// intentionally ignored.
	defer tx.Rollback()

	steps := []struct {
		query string
		args  []interface{}
	}{
		// The migration itself.
		{query: migration.SQL},
		// Mark it as applied.
		{
			query: "INSERT INTO schema_migrations (version) VALUES ($1)",
			args:  []interface{}{migration.Version},
		},
	}
	for _, step := range steps {
		if _, err := tx.Exec(step.query, step.args...); err != nil {
			return err
		}
	}
	return tx.Commit()
}
// contains reports whether item is equal to any element of slice.
// A nil or empty slice never contains anything.
func contains(slice []string, item string) bool {
	found := false
	for i := 0; i < len(slice) && !found; i++ {
		found = slice[i] == item
	}
	return found
}