Files
photography/backend/pkg/migration/migration.go
xujiang 84e778e033 feat: 完成数据库迁移系统开发
- 创建完整的迁移框架 (pkg/migration/)
- 版本管理系统,时间戳版本号 (YYYYMMDD_HHMMSS)
- 事务安全的上下迁移机制 (Up/Down)
- 迁移状态跟踪和记录 (migration_records 表)
- 命令行迁移工具 (cmd/migrate/main.go)
- 生产环境迁移脚本 (scripts/production-migrate.sh)
- 生产环境初始化脚本 (scripts/init-production-db.sh)
- 迁移测试脚本 (scripts/test-migration.sh)
- Makefile 集成 (migrate-up, migrate-down, migrate-status)
- 5个预定义迁移 (基础表、默认数据、元数据、收藏、用户资料)
- 自动备份机制、预览模式、详细日志
- 完整文档 (docs/DATABASE_MIGRATION.md)

任务13完成,项目完成率达到42.5%
2025-07-11 13:41:52 +08:00

361 lines
8.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package migration
import (
"fmt"
"log"
"sort"
"strings"
"time"
"gorm.io/gorm"
)
// Migration describes a single schema change: the SQL that applies it (UpSQL)
// and the SQL that reverts it (DownSQL). Migrations are ordered by comparing
// Version strings with '<', so versions must sort lexicographically in the
// intended order (the project uses YYYYMMDD_HHMMSS timestamps).
type Migration struct {
	Version, Description string    // Version orders migrations; Description appears in logs and Status
	UpSQL, DownSQL       string    // possibly several statements separated by ';'
	Timestamp            time.Time // not read by the Migrator code in this file
}
// MigrationRecord is the database row kept for every migration that has been
// applied at least once. Rolling a migration back does not delete its row; it
// only clears Applied, and applying the migration again sets it back to true.
type MigrationRecord struct {
	ID          uint   `gorm:"primaryKey"`
	Version     string `gorm:"uniqueIndex;size:255;not null"`
	Description string `gorm:"size:500"`
	Applied     bool   `gorm:"default:false"`

	AppliedAt, CreatedAt, UpdatedAt time.Time
}
// Migrator applies and rolls back a registered set of Migrations against a
// GORM database and tracks their state in MigrationRecord rows.
type Migrator struct {
	db         *gorm.DB    // database the migrations are executed against
	migrations []Migration // registered migrations

	// tableName is set to "schema_migrations" by NewMigrator but is not read
	// by any code visible here; the record table name comes from GORM's
	// naming of MigrationRecord (presumably migration_records — confirm).
	tableName string
}
// NewMigrator returns a Migrator bound to db with an empty (non-nil) list of
// registered migrations. Register migrations with AddMigration before use.
func NewMigrator(db *gorm.DB) *Migrator {
	m := new(Migrator)
	m.db = db
	m.migrations = make([]Migration, 0)
	m.tableName = "schema_migrations"
	return m
}
// AddMigration registers migration with the Migrator. Registration order does
// not matter to Up, which sorts pending migrations by Version. Duplicate
// versions are not detected here.
func (m *Migrator) AddMigration(migration Migration) {
	registered := append(m.migrations, migration)
	m.migrations = registered
}
// initMigrationTable runs GORM's AutoMigrate for MigrationRecord so that the
// table holding migration records exists before it is queried.
// The returned error wraps the GORM error (%w), so callers can still inspect
// it with errors.Is / errors.As.
func (m *Migrator) initMigrationTable() error {
	if err := m.db.AutoMigrate(&MigrationRecord{}); err != nil {
		return fmt.Errorf("failed to create migration table: %w", err)
	}
	return nil
}
// GetAppliedMigrations returns the versions of all records marked as applied,
// in the order produced by the database for ORDER BY version ASC. When no
// migration is applied it returns an empty, non-nil slice.
func (m *Migrator) GetAppliedMigrations() ([]string, error) {
	var applied []MigrationRecord
	query := m.db.Where("applied = ?", true).Order("version ASC")
	if err := query.Find(&applied).Error; err != nil {
		return nil, err
	}

	versions := make([]string, 0, len(applied))
	for _, rec := range applied {
		versions = append(versions, rec.Version)
	}
	return versions, nil
}
// GetPendingMigrations returns the registered migrations whose version has no
// applied record, sorted by ascending Version. It returns a nil slice when
// nothing is pending.
func (m *Migrator) GetPendingMigrations() ([]Migration, error) {
	applied, err := m.GetAppliedMigrations()
	if err != nil {
		return nil, err
	}

	done := make(map[string]struct{}, len(applied))
	for _, v := range applied {
		done[v] = struct{}{}
	}

	var pending []Migration
	for _, mig := range m.migrations {
		if _, ok := done[mig.Version]; ok {
			continue
		}
		pending = append(pending, mig)
	}

	sort.Slice(pending, func(a, b int) bool { return pending[a].Version < pending[b].Version })
	return pending, nil
}
// Up applies every pending migration (see GetPendingMigrations) in ascending
// Version order, creating the record table first if needed. Each migration's
// UpSQL and the update of its MigrationRecord run in one transaction. Up stops
// at the first failure: migrations applied before it stay applied and the
// failing one is rolled back.
//
// NOTE(review): engines that commit DDL implicitly (e.g. MySQL) cannot roll
// back schema changes a failing migration already executed; confirm the
// target database before relying on the rollback.
func (m *Migrator) Up() error {
	if err := m.initMigrationTable(); err != nil {
		return err
	}

	pendingMigrations, err := m.GetPendingMigrations()
	if err != nil {
		return err
	}
	if len(pendingMigrations) == 0 {
		log.Println("No pending migrations")
		return nil
	}

	for _, migration := range pendingMigrations {
		log.Printf("Applying migration %s: %s", migration.Version, migration.Description)
		if err := m.applyMigrationTx(migration); err != nil {
			return err
		}
		log.Printf("Successfully applied migration %s", migration.Version)
	}
	return nil
}

// applyMigrationTx runs migration.UpSQL and marks the migration as applied in
// a single transaction. The deferred rollback also fires when execution
// panics, so the transaction is never left open.
func (m *Migrator) applyMigrationTx(migration Migration) error {
	tx := m.db.Begin()
	if tx.Error != nil {
		return tx.Error
	}
	finished := false
	defer func() {
		if !finished {
			tx.Rollback()
		}
	}()

	if err := m.executeSQL(tx, migration.UpSQL); err != nil {
		return fmt.Errorf("failed to apply migration %s: %w", migration.Version, err)
	}
	if err := upsertAppliedRecord(tx, migration, time.Now()); err != nil {
		return err
	}

	// A transaction is finished after Commit whether or not Commit succeeds,
	// so no rollback is attempted afterwards.
	err := tx.Commit().Error
	finished = true
	if err != nil {
		return fmt.Errorf("failed to commit migration %s: %w", migration.Version, err)
	}
	return nil
}

// upsertAppliedRecord creates the MigrationRecord for migration, or updates
// the existing one (for example, left by an earlier Down), marking it as
// applied at now. Limit(1).Find is used instead of First so that a missing row
// shows up as RowsAffected == 0 rather than as a gorm.ErrRecordNotFound error.
func upsertAppliedRecord(tx *gorm.DB, migration Migration, now time.Time) error {
	var existing MigrationRecord
	result := tx.Where("version = ?", migration.Version).Limit(1).Find(&existing)
	if result.Error != nil {
		return fmt.Errorf("failed to check migration record %s: %w", migration.Version, result.Error)
	}

	if result.RowsAffected == 0 {
		record := MigrationRecord{
			Version:     migration.Version,
			Description: migration.Description,
			Applied:     true,
			AppliedAt:   now,
			UpdatedAt:   now,
		}
		if err := tx.Create(&record).Error; err != nil {
			return fmt.Errorf("failed to create migration record %s: %w", migration.Version, err)
		}
		return nil
	}

	updates := map[string]interface{}{
		"applied":    true,
		"applied_at": now,
		"updated_at": now,
	}
	if err := tx.Model(&existing).Updates(updates).Error; err != nil {
		return fmt.Errorf("failed to update migration record %s: %w", migration.Version, err)
	}
	return nil
}
// Down rolls back up to steps of the most recently applied migrations, newest
// first. For each one it runs DownSQL and clears Applied on its record. A steps
// value larger than the number of applied migrations is clamped, and
// steps <= 0 rolls back nothing.
//
// Every migration to be rolled back must still be registered with the
// Migrator. This is checked for all of them before anything is executed, so a
// missing definition no longer leaves the database partially rolled back.
func (m *Migrator) Down(steps int) error {
	if err := m.initMigrationTable(); err != nil {
		return err
	}

	appliedVersions, err := m.GetAppliedMigrations()
	if err != nil {
		return err
	}
	if len(appliedVersions) == 0 {
		log.Println("No applied migrations to rollback")
		return nil
	}

	rollbackCount := steps
	if rollbackCount > len(appliedVersions) {
		rollbackCount = len(appliedVersions)
	}
	if rollbackCount <= 0 {
		return nil
	}

	// Resolve all definitions up front, newest version first.
	targets := make([]*Migration, 0, rollbackCount)
	for i := len(appliedVersions) - 1; i >= len(appliedVersions)-rollbackCount; i-- {
		version := appliedVersions[i]
		migration := m.findMigrationByVersion(version)
		if migration == nil {
			return fmt.Errorf("migration %s not found in migration definitions", version)
		}
		targets = append(targets, migration)
	}

	for _, migration := range targets {
		log.Printf("Rolling back migration %s: %s", migration.Version, migration.Description)
		if err := m.rollbackMigrationTx(*migration); err != nil {
			return err
		}
		log.Printf("Successfully rolled back migration %s", migration.Version)
	}
	return nil
}

// rollbackMigrationTx runs migration.DownSQL and clears the applied flag of its
// record in a single transaction. The transaction is rolled back if any step
// fails or panics.
func (m *Migrator) rollbackMigrationTx(migration Migration) error {
	tx := m.db.Begin()
	if tx.Error != nil {
		return tx.Error
	}
	finished := false
	defer func() {
		if !finished {
			tx.Rollback()
		}
	}()

	if err := m.executeSQL(tx, migration.DownSQL); err != nil {
		return fmt.Errorf("failed to rollback migration %s: %w", migration.Version, err)
	}
	if err := tx.Model(&MigrationRecord{}).Where("version = ?", migration.Version).Update("applied", false).Error; err != nil {
		return fmt.Errorf("failed to update migration record %s: %w", migration.Version, err)
	}

	// A transaction is finished after Commit whether or not Commit succeeds.
	err := tx.Commit().Error
	finished = true
	if err != nil {
		return fmt.Errorf("failed to commit rollback %s: %w", migration.Version, err)
	}
	return nil
}
// Status prints a table of every registered migration, sorted by Version. The
// state is "Applied" when an applied record exists for the version and
// "Pending" otherwise. It creates the record table first if needed, and it
// always returns nil once the table has been printed.
func (m *Migrator) Status() error {
	if err := m.initMigrationTable(); err != nil {
		return err
	}

	appliedVersions, err := m.GetAppliedMigrations()
	if err != nil {
		return err
	}
	appliedMap := make(map[string]bool, len(appliedVersions))
	for _, version := range appliedVersions {
		appliedMap[version] = true
	}

	// Sort a copy so that printing a report does not reorder the Migrator's
	// registered migrations as a side effect.
	allMigrations := append([]Migration(nil), m.migrations...)
	sort.Slice(allMigrations, func(i, j int) bool {
		return allMigrations[i].Version < allMigrations[j].Version
	})

	// Widen the Version column to its longest entry. A fixed width of 14 is
	// one short of the 15-char YYYYMMDD_HHMMSS versions and misaligns rows.
	width := 14
	for _, migration := range allMigrations {
		if len(migration.Version) > width {
			width = len(migration.Version)
		}
	}
	const statusWidth = 7 // len("Applied") == len("Pending")

	fmt.Println("Migration Status:")
	fmt.Printf("%-*s | %-*s | %s\n", width, "Version", statusWidth, "Status", "Description")
	fmt.Println(strings.Repeat("-", width+1) + "|" + strings.Repeat("-", statusWidth+2) + "|" + strings.Repeat("-", 34))
	for _, migration := range allMigrations {
		status := "Pending"
		if appliedMap[migration.Version] {
			status = "Applied"
		}
		fmt.Printf("%-*s | %-*s | %s\n", width, migration.Version, statusWidth, status, migration.Description)
	}
	return nil
}
// executeSQL executes the SQL statements contained in sqlStr on tx, one Exec
// per statement and in order, stopping at the first error. The returned error
// names the 1-based index of the failing statement and wraps the driver error.
//
// Statements are separated by splitSQLStatements, which does not split on a
// ';' inside quoted strings, quoted identifiers, comments or PostgreSQL
// dollar-quoted bodies. For example, INSERT INTO t VALUES ('a;b') stays a
// single statement.
func (m *Migrator) executeSQL(tx *gorm.DB, sqlStr string) error {
	dialect := ""
	if tx.Dialector != nil {
		// The dialect selects lexing rules such as backslash escapes and
		// dollar quoting.
		dialect = tx.Dialector.Name()
	}
	for i, statement := range splitSQLStatements(sqlStr, dialect) {
		if err := tx.Exec(statement).Error; err != nil {
			return fmt.Errorf("statement %d: %w", i+1, err)
		}
	}
	return nil
}

// splitSQLStatements splits sqlStr on top-level ';' characters and returns the
// statements trimmed of surrounding whitespace.
//
// A ';' does not end a statement when it appears inside:
//   - a quoted string or identifier ('...', "..." or `...`); a doubled quote
//     character is an escaped quote;
//   - a comment ("-- ..." up to end of line, or "/* ... */");
//   - for PostgreSQL, a dollar-quoted body such as $$ ... $$ or $fn$ ... $fn$.
//
// dialect is the GORM dialector name.
//   - "mysql": backslash escapes apply inside '...' and "...", '#' starts a
//     line comment, and "--" starts a comment only when followed by
//     whitespace (or end of input).
//   - "postgres": backslash escapes apply inside E'...' strings, and dollar
//     quoting is recognised.
//
// Segments containing only whitespace and comments are dropped, since some
// drivers reject an empty query.
//
// Known limitation: a ';' inside a BEGIN ... END body that is not dollar-quoted
// (e.g. an SQLite trigger) still splits the statement.
func splitSQLStatements(sqlStr, dialect string) []string {
	isMySQL := dialect == "mysql"
	isPostgres := dialect == "postgres"

	var statements []string
	start := 0       // offset where the current statement begins
	hasCode := false // current statement has content besides whitespace/comments

	emit := func(end int) {
		if hasCode {
			if stmt := strings.TrimSpace(sqlStr[start:end]); stmt != "" {
				statements = append(statements, stmt)
			}
		}
		start = end + 1
		hasCode = false
	}

	n := len(sqlStr)
	for i := 0; i < n; i++ {
		c := sqlStr[i]
		switch {
		case c == ';':
			emit(i)
		case isSQLSpace(c):
			// Whitespace alone does not make a statement non-empty.
		case c == '-' && i+1 < n && sqlStr[i+1] == '-' &&
			(!isMySQL || i+2 == n || isSQLSpace(sqlStr[i+2])):
			i = sqlLineEnd(sqlStr, i)
		case c == '#' && isMySQL:
			i = sqlLineEnd(sqlStr, i)
		case c == '/' && i+1 < n && sqlStr[i+1] == '*':
			// MySQL executes the contents of /*! ... */ comments.
			if isMySQL && i+2 < n && sqlStr[i+2] == '!' {
				hasCode = true
			}
			if j := strings.Index(sqlStr[i+2:], "*/"); j >= 0 {
				i += 2 + j + 1
			} else {
				i = n - 1
			}
		case c == '\'' || c == '"' || c == '`':
			hasCode = true
			backslash := (isMySQL && c != '`') ||
				(isPostgres && c == '\'' && i > 0 && (sqlStr[i-1] == 'E' || sqlStr[i-1] == 'e') &&
					(i == 1 || !isSQLIdentByte(sqlStr[i-2])))
			i = skipSQLQuoted(sqlStr, i, backslash)
		case c == '$' && isPostgres && (i == 0 || !isSQLIdentByte(sqlStr[i-1])):
			hasCode = true
			if end, ok := sqlDollarQuoteEnd(sqlStr, i); ok {
				i = end
			}
		default:
			hasCode = true
		}
	}
	emit(n)
	return statements
}

// skipSQLQuoted returns the index of the quote closing the quoted token that
// opens at s[open] (the quote character is s[open]), or len(s)-1 if the token
// is unterminated. A doubled quote character is an escaped quote. When
// backslash is true, a backslash escapes the byte that follows it.
func skipSQLQuoted(s string, open int, backslash bool) int {
	quote := s[open]
	for j := open + 1; j < len(s); j++ {
		switch s[j] {
		case '\\':
			if backslash {
				j++ // skip the escaped byte
			}
		case quote:
			if j+1 < len(s) && s[j+1] == quote {
				j++ // doubled quote: literal quote character
				continue
			}
			return j
		}
	}
	return len(s) - 1
}

// sqlDollarQuoteEnd reports whether s[open:] starts a PostgreSQL dollar-quote
// opening delimiter ($$ or $tag$). If so, it returns the index of the last byte
// of the matching closing delimiter, or len(s)-1 if there is none.
func sqlDollarQuoteEnd(s string, open int) (int, bool) {
	j := open + 1
	for j < len(s) && isSQLIdentByte(s[j]) {
		j++
	}
	if j >= len(s) || s[j] != '$' {
		return 0, false
	}
	if j > open+1 && s[open+1] >= '0' && s[open+1] <= '9' {
		return 0, false // a tag cannot start with a digit ($1 is a parameter)
	}
	delim := s[open : j+1]
	if k := strings.Index(s[j+1:], delim); k >= 0 {
		return j + k + len(delim), true
	}
	return len(s) - 1, true
}

// sqlLineEnd returns the index of the first '\n' at or after from, or
// len(s)-1 if there is none.
func sqlLineEnd(s string, from int) int {
	if j := strings.IndexByte(s[from:], '\n'); j >= 0 {
		return from + j
	}
	return len(s) - 1
}

// isSQLSpace reports whether b is an ASCII whitespace byte.
func isSQLSpace(b byte) bool {
	switch b {
	case ' ', '\t', '\n', '\r', '\f', '\v':
		return true
	}
	return false
}

// isSQLIdentByte reports whether b can appear in an unquoted identifier or a
// dollar-quote tag (ASCII letters, digits, '_', or any non-ASCII byte).
func isSQLIdentByte(b byte) bool {
	return b == '_' || b >= 0x80 ||
		('a' <= b && b <= 'z') || ('A' <= b && b <= 'Z') || ('0' <= b && b <= '9')
}
// findMigrationByVersion returns a pointer to a copy of the first registered
// migration whose Version equals version, or nil when none matches.
func (m *Migrator) findMigrationByVersion(version string) *Migration {
	for i := range m.migrations {
		if m.migrations[i].Version != version {
			continue
		}
		found := m.migrations[i]
		return &found
	}
	return nil
}
// Reset rolls back every applied migration via Down (newest first, running
// each DownSQL) and then drops the MigrationRecord table. It is destructive
// and should be used with care. Only what the registered DownSQL statements
// remove is dropped, plus the record table itself.
//
// The record table is ensured before it is queried, so Reset also works on a
// database that has never been migrated.
func (m *Migrator) Reset() error {
	log.Println("WARNING: This will drop all tables and reset the database!")

	if err := m.initMigrationTable(); err != nil {
		return err
	}

	appliedVersions, err := m.GetAppliedMigrations()
	if err != nil {
		return err
	}
	if len(appliedVersions) > 0 {
		if err := m.Down(len(appliedVersions)); err != nil {
			return err
		}
	}

	if err := m.db.Migrator().DropTable(&MigrationRecord{}); err != nil {
		return fmt.Errorf("failed to drop migration table: %w", err)
	}
	log.Println("Database reset completed")
	return nil
}
// Migrate applies at most steps pending migrations, oldest version first.
// A steps value <= 0, or one larger than the number of pending migrations,
// applies all of them.
//
// It temporarily narrows m.migrations to the selected subset and calls Up, so
// it is not safe to call concurrently with other Migrator methods.
func (m *Migrator) Migrate(steps int) error {
	// The record table must exist before pending migrations can be computed;
	// on a fresh database the query below would otherwise fail.
	if err := m.initMigrationTable(); err != nil {
		return err
	}

	pendingMigrations, err := m.GetPendingMigrations()
	if err != nil {
		return err
	}
	if len(pendingMigrations) == 0 {
		log.Println("No pending migrations")
		return nil
	}

	migrateCount := steps
	if steps <= 0 || steps > len(pendingMigrations) {
		migrateCount = len(pendingMigrations)
	}

	// Up applies everything pending in m.migrations, so restrict the list to
	// the selected subset. The deferred restore also runs if Up panics.
	originalMigrations := m.migrations
	defer func() { m.migrations = originalMigrations }()
	m.migrations = pendingMigrations[:migrateCount]

	return m.Up()
}