- 创建完整的迁移框架 (pkg/migration/) - 版本管理系统,时间戳版本号 (YYYYMMDD_HHMMSS) - 事务安全的上下迁移机制 (Up/Down) - 迁移状态跟踪和记录 (migration_records 表) - 命令行迁移工具 (cmd/migrate/main.go) - 生产环境迁移脚本 (scripts/production-migrate.sh) - 生产环境初始化脚本 (scripts/init-production-db.sh) - 迁移测试脚本 (scripts/test-migration.sh) - Makefile 集成 (migrate-up, migrate-down, migrate-status) - 5个预定义迁移 (基础表、默认数据、元数据、收藏、用户资料) - 自动备份机制、预览模式、详细日志 - 完整文档 (docs/DATABASE_MIGRATION.md) 任务13完成,项目完成率达到42.5%
361 lines
8.7 KiB
Go
361 lines
8.7 KiB
Go
package migration
|
||
|
||
import (
	"errors"
	"fmt"
	"log"
	"sort"
	"strings"
	"time"

	"gorm.io/gorm"
)
|
||
|
||
// Migration describes one schema change: UpSQL applies it and DownSQL reverts
// it. Version both identifies the migration and orders it — versions are
// compared as plain strings, so they must sort lexically in the intended order
// (for example zero-padded YYYYMMDD_HHMMSS timestamps). Timestamp is not read
// by the Migrator code in this file.
type Migration struct {
	Version, Description string
	UpSQL, DownSQL       string
	Timestamp            time.Time
}
|
||
|
||
// MigrationRecord is the persisted state of a single migration. Exactly one
// row exists per Version; rolling a migration back clears Applied rather than
// deleting the row. Field order and gorm tags define the table schema.
type MigrationRecord struct {
	// ID is the surrogate primary key.
	ID uint `gorm:"primaryKey"`
	// Version mirrors Migration.Version and is unique across records.
	Version string `gorm:"uniqueIndex;size:255;not null"`
	// Description is copied from Migration.Description when the row is created.
	Description string `gorm:"size:500"`
	// Applied reports whether the migration is currently in effect.
	Applied bool `gorm:"default:false"`
	// AppliedAt is when the migration was last applied by Up.
	AppliedAt time.Time
	CreatedAt time.Time
	UpdatedAt time.Time
}
|
||
|
||
// Migrator 迁移管理器
|
||
type Migrator struct {
|
||
db *gorm.DB
|
||
migrations []Migration
|
||
tableName string
|
||
}
|
||
|
||
// NewMigrator 创建新的迁移管理器
|
||
func NewMigrator(db *gorm.DB) *Migrator {
|
||
return &Migrator{
|
||
db: db,
|
||
migrations: []Migration{},
|
||
tableName: "schema_migrations",
|
||
}
|
||
}
|
||
|
||
// AddMigration 添加迁移
|
||
func (m *Migrator) AddMigration(migration Migration) {
|
||
m.migrations = append(m.migrations, migration)
|
||
}
|
||
|
||
// initMigrationTable 初始化迁移表
|
||
func (m *Migrator) initMigrationTable() error {
|
||
// 确保迁移表存在
|
||
if err := m.db.AutoMigrate(&MigrationRecord{}); err != nil {
|
||
return fmt.Errorf("failed to create migration table: %v", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// GetAppliedMigrations 获取已应用的迁移
|
||
func (m *Migrator) GetAppliedMigrations() ([]string, error) {
|
||
var records []MigrationRecord
|
||
if err := m.db.Where("applied = ?", true).Order("version ASC").Find(&records).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
versions := make([]string, len(records))
|
||
for i, record := range records {
|
||
versions[i] = record.Version
|
||
}
|
||
return versions, nil
|
||
}
|
||
|
||
// GetPendingMigrations 获取待应用的迁移
|
||
func (m *Migrator) GetPendingMigrations() ([]Migration, error) {
|
||
appliedVersions, err := m.GetAppliedMigrations()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
appliedMap := make(map[string]bool)
|
||
for _, version := range appliedVersions {
|
||
appliedMap[version] = true
|
||
}
|
||
|
||
var pendingMigrations []Migration
|
||
for _, migration := range m.migrations {
|
||
if !appliedMap[migration.Version] {
|
||
pendingMigrations = append(pendingMigrations, migration)
|
||
}
|
||
}
|
||
|
||
// 按版本号排序
|
||
sort.Slice(pendingMigrations, func(i, j int) bool {
|
||
return pendingMigrations[i].Version < pendingMigrations[j].Version
|
||
})
|
||
|
||
return pendingMigrations, nil
|
||
}
|
||
|
||
// Up 执行迁移(向上)
|
||
func (m *Migrator) Up() error {
|
||
if err := m.initMigrationTable(); err != nil {
|
||
return err
|
||
}
|
||
|
||
pendingMigrations, err := m.GetPendingMigrations()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
if len(pendingMigrations) == 0 {
|
||
log.Println("No pending migrations")
|
||
return nil
|
||
}
|
||
|
||
for _, migration := range pendingMigrations {
|
||
log.Printf("Applying migration %s: %s", migration.Version, migration.Description)
|
||
|
||
// 开始事务
|
||
tx := m.db.Begin()
|
||
if tx.Error != nil {
|
||
return tx.Error
|
||
}
|
||
|
||
// 执行迁移SQL
|
||
if err := m.executeSQL(tx, migration.UpSQL); err != nil {
|
||
tx.Rollback()
|
||
return fmt.Errorf("failed to apply migration %s: %v", migration.Version, err)
|
||
}
|
||
|
||
// 记录迁移状态 (使用UPSERT)
|
||
now := time.Now()
|
||
|
||
// 检查记录是否已存在
|
||
var existingRecord MigrationRecord
|
||
err := tx.Where("version = ?", migration.Version).First(&existingRecord).Error
|
||
|
||
if err == gorm.ErrRecordNotFound {
|
||
// 创建新记录
|
||
record := MigrationRecord{
|
||
Version: migration.Version,
|
||
Description: migration.Description,
|
||
Applied: true,
|
||
AppliedAt: now,
|
||
UpdatedAt: now,
|
||
}
|
||
if err := tx.Create(&record).Error; err != nil {
|
||
tx.Rollback()
|
||
return fmt.Errorf("failed to create migration record %s: %v", migration.Version, err)
|
||
}
|
||
} else if err == nil {
|
||
// 更新现有记录
|
||
updates := map[string]interface{}{
|
||
"applied": true,
|
||
"applied_at": now,
|
||
"updated_at": now,
|
||
}
|
||
if err := tx.Model(&existingRecord).Updates(updates).Error; err != nil {
|
||
tx.Rollback()
|
||
return fmt.Errorf("failed to update migration record %s: %v", migration.Version, err)
|
||
}
|
||
} else {
|
||
tx.Rollback()
|
||
return fmt.Errorf("failed to check migration record %s: %v", migration.Version, err)
|
||
}
|
||
|
||
// 提交事务
|
||
if err := tx.Commit().Error; err != nil {
|
||
return fmt.Errorf("failed to commit migration %s: %v", migration.Version, err)
|
||
}
|
||
|
||
log.Printf("Successfully applied migration %s", migration.Version)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// Down 回滚迁移(向下)
|
||
func (m *Migrator) Down(steps int) error {
|
||
if err := m.initMigrationTable(); err != nil {
|
||
return err
|
||
}
|
||
|
||
appliedVersions, err := m.GetAppliedMigrations()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
if len(appliedVersions) == 0 {
|
||
log.Println("No applied migrations to rollback")
|
||
return nil
|
||
}
|
||
|
||
// 获取要回滚的迁移(从最新开始)
|
||
rollbackCount := steps
|
||
if rollbackCount > len(appliedVersions) {
|
||
rollbackCount = len(appliedVersions)
|
||
}
|
||
|
||
for i := len(appliedVersions) - 1; i >= len(appliedVersions)-rollbackCount; i-- {
|
||
version := appliedVersions[i]
|
||
migration := m.findMigrationByVersion(version)
|
||
if migration == nil {
|
||
return fmt.Errorf("migration %s not found in migration definitions", version)
|
||
}
|
||
|
||
log.Printf("Rolling back migration %s: %s", migration.Version, migration.Description)
|
||
|
||
// 开始事务
|
||
tx := m.db.Begin()
|
||
if tx.Error != nil {
|
||
return tx.Error
|
||
}
|
||
|
||
// 执行回滚SQL
|
||
if err := m.executeSQL(tx, migration.DownSQL); err != nil {
|
||
tx.Rollback()
|
||
return fmt.Errorf("failed to rollback migration %s: %v", migration.Version, err)
|
||
}
|
||
|
||
// 更新迁移状态
|
||
if err := tx.Model(&MigrationRecord{}).Where("version = ?", version).Update("applied", false).Error; err != nil {
|
||
tx.Rollback()
|
||
return fmt.Errorf("failed to update migration record %s: %v", migration.Version, err)
|
||
}
|
||
|
||
// 提交事务
|
||
if err := tx.Commit().Error; err != nil {
|
||
return fmt.Errorf("failed to commit rollback %s: %v", migration.Version, err)
|
||
}
|
||
|
||
log.Printf("Successfully rolled back migration %s", migration.Version)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// Status 显示迁移状态
|
||
func (m *Migrator) Status() error {
|
||
if err := m.initMigrationTable(); err != nil {
|
||
return err
|
||
}
|
||
|
||
appliedVersions, err := m.GetAppliedMigrations()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
appliedMap := make(map[string]bool)
|
||
for _, version := range appliedVersions {
|
||
appliedMap[version] = true
|
||
}
|
||
|
||
// 排序所有迁移
|
||
allMigrations := m.migrations
|
||
sort.Slice(allMigrations, func(i, j int) bool {
|
||
return allMigrations[i].Version < allMigrations[j].Version
|
||
})
|
||
|
||
fmt.Println("Migration Status:")
|
||
fmt.Println("Version | Status | Description")
|
||
fmt.Println("---------------|---------|----------------------------------")
|
||
|
||
for _, migration := range allMigrations {
|
||
status := "Pending"
|
||
if appliedMap[migration.Version] {
|
||
status = "Applied"
|
||
}
|
||
fmt.Printf("%-14s | %-7s | %s\n", migration.Version, status, migration.Description)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// executeSQL 执行SQL语句
|
||
func (m *Migrator) executeSQL(tx *gorm.DB, sqlStr string) error {
|
||
// 分割SQL语句(按分号分割)
|
||
statements := strings.Split(sqlStr, ";")
|
||
|
||
for _, statement := range statements {
|
||
statement = strings.TrimSpace(statement)
|
||
if statement == "" {
|
||
continue
|
||
}
|
||
|
||
if err := tx.Exec(statement).Error; err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// findMigrationByVersion 根据版本号查找迁移
|
||
func (m *Migrator) findMigrationByVersion(version string) *Migration {
|
||
for _, migration := range m.migrations {
|
||
if migration.Version == version {
|
||
return &migration
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// Reset 重置数据库(谨慎使用)
|
||
func (m *Migrator) Reset() error {
|
||
log.Println("WARNING: This will drop all tables and reset the database!")
|
||
|
||
// 获取所有应用的迁移
|
||
appliedVersions, err := m.GetAppliedMigrations()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 回滚所有迁移
|
||
if len(appliedVersions) > 0 {
|
||
if err := m.Down(len(appliedVersions)); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
// 删除迁移表
|
||
if err := m.db.Migrator().DropTable(&MigrationRecord{}); err != nil {
|
||
return fmt.Errorf("failed to drop migration table: %v", err)
|
||
}
|
||
|
||
log.Println("Database reset completed")
|
||
return nil
|
||
}
|
||
|
||
// Migrate 执行指定数量的待处理迁移
|
||
func (m *Migrator) Migrate(steps int) error {
|
||
pendingMigrations, err := m.GetPendingMigrations()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
if len(pendingMigrations) == 0 {
|
||
log.Println("No pending migrations")
|
||
return nil
|
||
}
|
||
|
||
migrateCount := steps
|
||
if steps <= 0 || steps > len(pendingMigrations) {
|
||
migrateCount = len(pendingMigrations)
|
||
}
|
||
|
||
// 临时修改migrations列表,只包含要执行的迁移
|
||
originalMigrations := m.migrations
|
||
m.migrations = pendingMigrations[:migrateCount]
|
||
|
||
err = m.Up()
|
||
|
||
// 恢复原始migrations列表
|
||
m.migrations = originalMigrations
|
||
|
||
return err
|
||
} |