// Package migration provides a small SQL-based schema migration runner built
// on top of GORM. Migrations are registered in memory and their applied state
// is tracked in a database table mapped from MigrationRecord.
package migration

import (
	"fmt"
	"log"
	"sort"
	"strings"
	"time"

	"gorm.io/gorm"
)

// Migration describes a single schema migration.
type Migration struct {
	// Version identifies the migration. Pending migrations are applied in
	// ascending lexical (byte-wise string) order of Version, so a fixed-width
	// scheme such as zero-padded numbers or timestamps keeps the order intended.
	Version string
	// Description is a human-readable summary shown in logs and Status output.
	Description string
	// UpSQL is executed when the migration is applied. Several statements may
	// be separated by ';' (see executeSQL for the splitting rules).
	UpSQL string
	// DownSQL is executed when the migration is rolled back. Same format as UpSQL.
	DownSQL string
	// Timestamp is stored on the struct but not read by the Migrator in this file.
	Timestamp time.Time
}

// MigrationRecord is the tracking row stored for each migration that has been
// applied at least once. Rolling a migration back keeps its row and sets
// Applied to false; re-applying it updates that same row.
type MigrationRecord struct {
	ID          uint   `gorm:"primaryKey"`
	Version     string `gorm:"uniqueIndex;size:255;not null"`
	Description string `gorm:"size:500"`
	Applied     bool   `gorm:"default:false"`
	AppliedAt   time.Time
	CreatedAt   time.Time
	UpdatedAt   time.Time
}

// Migrator holds the registered migrations and applies or rolls them back
// against db. It performs no locking; do not share one Migrator between
// goroutines without external synchronization.
type Migrator struct {
	db         *gorm.DB
	migrations []Migration
	// tableName is currently not consulted by any method: queries go through
	// MigrationRecord, so GORM derives the tracking table's name from that type.
	tableName string
}

// NewMigrator returns a Migrator bound to db with no migrations registered.
func NewMigrator(db *gorm.DB) *Migrator {
	return &Migrator{
		db:         db,
		migrations: []Migration{},
		tableName:  "schema_migrations",
	}
}

// AddMigration registers migration. Registration order does not matter; the
// Migrator always orders migrations by Version.
func (m *Migrator) AddMigration(migration Migration) {
	m.migrations = append(m.migrations, migration)
}

// initMigrationTable makes sure the tracking table for MigrationRecord exists
// (via GORM AutoMigrate).
func (m *Migrator) initMigrationTable() error {
	if err := m.db.AutoMigrate(&MigrationRecord{}); err != nil {
		return fmt.Errorf("failed to create migration table: %w", err)
	}
	return nil
}

// GetAppliedMigrations returns the versions of all migrations whose tracking
// row is marked applied, sorted in ascending string order.
func (m *Migrator) GetAppliedMigrations() ([]string, error) {
	var records []MigrationRecord
	if err := m.db.Where("applied = ?", true).Order("version ASC").Find(&records).Error; err != nil {
		return nil, err
	}

	versions := make([]string, len(records))
	for i, record := range records {
		versions[i] = record.Version
	}
	// Re-sort in Go so the order matches the byte-wise ordering used by
	// GetPendingMigrations/Up, regardless of the database collation.
	sort.Strings(versions)
	return versions, nil
}

// GetPendingMigrations returns the registered migrations that are not marked
// applied in the tracking table, sorted by ascending Version.
func (m *Migrator) GetPendingMigrations() ([]Migration, error) {
	appliedVersions, err := m.GetAppliedMigrations()
	if err != nil {
		return nil, err
	}

	appliedMap := make(map[string]bool, len(appliedVersions))
	for _, version := range appliedVersions {
		appliedMap[version] = true
	}

	var pendingMigrations []Migration
	for _, migration := range m.migrations {
		if !appliedMap[migration.Version] {
			pendingMigrations = append(pendingMigrations, migration)
		}
	}

	sort.Slice(pendingMigrations, func(i, j int) bool {
		return pendingMigrations[i].Version < pendingMigrations[j].Version
	})
	return pendingMigrations, nil
}

// Up applies every pending migration in ascending Version order. Each
// migration runs in its own transaction; Up stops at the first failure, and
// migrations applied before the failure stay committed.
func (m *Migrator) Up() error {
	if err := m.initMigrationTable(); err != nil {
		return err
	}

	pendingMigrations, err := m.GetPendingMigrations()
	if err != nil {
		return err
	}
	if len(pendingMigrations) == 0 {
		log.Println("No pending migrations")
		return nil
	}
	return m.applyMigrations(pendingMigrations)
}

// applyMigrations applies migrations in slice order, stopping at the first error.
func (m *Migrator) applyMigrations(migrations []Migration) error {
	for _, migration := range migrations {
		if err := m.applyMigration(migration); err != nil {
			return err
		}
	}
	return nil
}

// applyMigration runs migration.UpSQL and marks the migration applied, both
// inside one transaction that is rolled back on any error or panic.
func (m *Migrator) applyMigration(migration Migration) error {
	log.Printf("Applying migration %s: %s", migration.Version, migration.Description)

	tx := m.db.Begin()
	if tx.Error != nil {
		return tx.Error
	}
	// Roll back on every exit path (including panics) until Commit is attempted.
	finished := false
	defer func() {
		if !finished {
			tx.Rollback()
		}
	}()

	if err := m.executeSQL(tx, migration.UpSQL); err != nil {
		return fmt.Errorf("failed to apply migration %s: %w", migration.Version, err)
	}
	if err := markApplied(tx, migration, time.Now()); err != nil {
		return err
	}

	finished = true
	if err := tx.Commit().Error; err != nil {
		return fmt.Errorf("failed to commit migration %s: %w", migration.Version, err)
	}

	log.Printf("Successfully applied migration %s", migration.Version)
	return nil
}

// markApplied creates or updates (within tx) the tracking row for migration,
// setting applied=true and applied_at/updated_at to now.
func markApplied(tx *gorm.DB, migration Migration, now time.Time) error {
	// Limit(1).Find reports "not found" via RowsAffected instead of returning
	// (and logging) gorm.ErrRecordNotFound as First would.
	var existingRecord MigrationRecord
	result := tx.Where("version = ?", migration.Version).Limit(1).Find(&existingRecord)

	switch {
	case result.Error != nil:
		return fmt.Errorf("failed to check migration record %s: %w", migration.Version, result.Error)

	case result.RowsAffected == 0:
		record := MigrationRecord{
			Version:     migration.Version,
			Description: migration.Description,
			Applied:     true,
			AppliedAt:   now,
			UpdatedAt:   now,
		}
		if err := tx.Create(&record).Error; err != nil {
			return fmt.Errorf("failed to create migration record %s: %w", migration.Version, err)
		}

	default:
		// The row exists from an earlier apply that was rolled back; flip it back on.
		updates := map[string]interface{}{
			"applied":    true,
			"applied_at": now,
			"updated_at": now,
		}
		if err := tx.Model(&existingRecord).Updates(updates).Error; err != nil {
			return fmt.Errorf("failed to update migration record %s: %w", migration.Version, err)
		}
	}
	return nil
}

// Down rolls back up to steps of the most recently applied migrations (highest
// Version first), each in its own transaction. steps larger than the number of
// applied migrations rolls back all of them; steps <= 0 rolls back nothing.
// It returns an error if an applied version has no registered Migration.
func (m *Migrator) Down(steps int) error {
	if err := m.initMigrationTable(); err != nil {
		return err
	}

	appliedVersions, err := m.GetAppliedMigrations()
	if err != nil {
		return err
	}
	if len(appliedVersions) == 0 {
		log.Println("No applied migrations to rollback")
		return nil
	}

	rollbackCount := steps
	if rollbackCount > len(appliedVersions) {
		rollbackCount = len(appliedVersions)
	}

	// Walk backwards from the newest applied version.
	for i := len(appliedVersions) - 1; i >= len(appliedVersions)-rollbackCount; i-- {
		version := appliedVersions[i]
		migration := m.findMigrationByVersion(version)
		if migration == nil {
			return fmt.Errorf("migration %s not found in migration definitions", version)
		}
		if err := m.rollbackMigration(*migration); err != nil {
			return err
		}
	}
	return nil
}

// rollbackMigration runs migration.DownSQL and marks the migration not applied,
// both inside one transaction that is rolled back on any error or panic.
func (m *Migrator) rollbackMigration(migration Migration) error {
	log.Printf("Rolling back migration %s: %s", migration.Version, migration.Description)

	tx := m.db.Begin()
	if tx.Error != nil {
		return tx.Error
	}
	finished := false
	defer func() {
		if !finished {
			tx.Rollback()
		}
	}()

	if err := m.executeSQL(tx, migration.DownSQL); err != nil {
		return fmt.Errorf("failed to rollback migration %s: %w", migration.Version, err)
	}
	if err := tx.Model(&MigrationRecord{}).Where("version = ?", migration.Version).Update("applied", false).Error; err != nil {
		return fmt.Errorf("failed to update migration record %s: %w", migration.Version, err)
	}

	finished = true
	if err := tx.Commit().Error; err != nil {
		return fmt.Errorf("failed to commit rollback %s: %w", migration.Version, err)
	}

	log.Printf("Successfully rolled back migration %s", migration.Version)
	return nil
}

// Status prints a table of every registered migration, sorted by Version, with
// an "Applied" or "Pending" state, to standard output. The registered
// migrations themselves are not reordered.
func (m *Migrator) Status() error {
	if err := m.initMigrationTable(); err != nil {
		return err
	}

	appliedVersions, err := m.GetAppliedMigrations()
	if err != nil {
		return err
	}
	appliedMap := make(map[string]bool, len(appliedVersions))
	for _, version := range appliedVersions {
		appliedMap[version] = true
	}

	// Sort a copy so that printing status has no side effect on m.migrations.
	allMigrations := make([]Migration, len(m.migrations))
	copy(allMigrations, m.migrations)
	sort.Slice(allMigrations, func(i, j int) bool {
		return allMigrations[i].Version < allMigrations[j].Version
	})

	const rowFormat = "%-14s | %-7s | %s\n"
	fmt.Println("Migration Status:")
	// Print the header with the row format so its columns line up with the rows.
	fmt.Printf(rowFormat, "Version", "Status", "Description")
	fmt.Println("---------------|---------|----------------------------------")
	for _, migration := range allMigrations {
		status := "Pending"
		if appliedMap[migration.Version] {
			status = "Applied"
		}
		fmt.Printf(rowFormat, migration.Version, status, migration.Description)
	}
	return nil
}

// executeSQL executes each statement of sqlStr on tx in order, stopping at
// the first error. Statements are split by splitSQLStatements.
func (m *Migrator) executeSQL(tx *gorm.DB, sqlStr string) error {
	for _, statement := range splitSQLStatements(sqlStr) {
		if err := tx.Exec(statement).Error; err != nil {
			return err
		}
	}
	return nil
}

// splitSQLStatements splits sqlStr at ';' separators, ignoring semicolons
// inside single-, double- or backtick-quoted text, `--` line comments and
// /* */ block comments. Each statement is whitespace-trimmed; empty ones are
// dropped.
//
// Limitations: a quote character inside quoted text is recognized only when
// doubled ('it''s'). Backslash escapes and PostgreSQL dollar-quoted bodies
// ($$ ... $$) are not understood, so semicolons inside those still split.
func splitSQLStatements(sqlStr string) []string {
	var (
		statements []string
		start      int
		quote      byte // active quote character; 0 when outside quoted text
	)
	flush := func(end int) {
		if s := strings.TrimSpace(sqlStr[start:end]); s != "" {
			statements = append(statements, s)
		}
	}

	for i := 0; i < len(sqlStr); i++ {
		c := sqlStr[i]
		if quote != 0 {
			if c == quote {
				if i+1 < len(sqlStr) && sqlStr[i+1] == quote {
					i++ // doubled quote: an escaped quote, still inside the literal
				} else {
					quote = 0
				}
			}
			continue
		}

		switch {
		case c == '\'' || c == '"' || c == '`':
			quote = c
		case c == '-' && i+1 < len(sqlStr) && sqlStr[i+1] == '-':
			// Line comment: skip to the newline (or end of input).
			if nl := strings.IndexByte(sqlStr[i:], '\n'); nl >= 0 {
				i += nl
			} else {
				i = len(sqlStr)
			}
		case c == '/' && i+1 < len(sqlStr) && sqlStr[i+1] == '*':
			// Block comment: skip past the closing "*/" (or to end of input).
			if end := strings.Index(sqlStr[i+2:], "*/"); end >= 0 {
				i += 2 + end + 1
			} else {
				i = len(sqlStr)
			}
		case c == ';':
			flush(i)
			start = i + 1
		}
	}
	flush(len(sqlStr))
	return statements
}

// findMigrationByVersion returns a copy of the registered migration with the
// given version, or nil if none is registered.
func (m *Migrator) findMigrationByVersion(version string) *Migration {
	for i := range m.migrations {
		if m.migrations[i].Version == version {
			found := m.migrations[i]
			return &found
		}
	}
	return nil
}

// Reset rolls back every applied migration (running their DownSQL) and then
// drops the tracking table. Use with care.
func (m *Migrator) Reset() error {
	log.Println("WARNING: This will drop all tables and reset the database!")

	// Ensure the tracking table exists so the lookup below works on a fresh database.
	if err := m.initMigrationTable(); err != nil {
		return err
	}

	appliedVersions, err := m.GetAppliedMigrations()
	if err != nil {
		return err
	}
	if len(appliedVersions) > 0 {
		if err := m.Down(len(appliedVersions)); err != nil {
			return err
		}
	}

	if err := m.db.Migrator().DropTable(&MigrationRecord{}); err != nil {
		return fmt.Errorf("failed to drop migration table: %w", err)
	}

	log.Println("Database reset completed")
	return nil
}

// Migrate applies at most steps pending migrations in ascending Version order.
// steps <= 0, or steps larger than the number of pending migrations, applies
// all of them. Registered migrations are never modified.
func (m *Migrator) Migrate(steps int) error {
	// Create the tracking table first; otherwise the pending lookup fails on a
	// fresh database.
	if err := m.initMigrationTable(); err != nil {
		return err
	}

	pendingMigrations, err := m.GetPendingMigrations()
	if err != nil {
		return err
	}
	if len(pendingMigrations) == 0 {
		log.Println("No pending migrations")
		return nil
	}

	if steps > 0 && steps < len(pendingMigrations) {
		pendingMigrations = pendingMigrations[:steps]
	}
	return m.applyMigrations(pendingMigrations)
}