Files
photography/backend/pkg/migration/migration.go
xujiang 5dd0bc19e4
Some checks failed
部署管理后台 / 🧪 测试和构建 (push) Failing after 1m5s
部署管理后台 / 🔒 安全扫描 (push) Has been skipped
部署后端服务 / 🧪 测试后端 (push) Failing after 3m13s
部署前端网站 / 🧪 测试和构建 (push) Failing after 2m10s
部署管理后台 / 🚀 部署到生产环境 (push) Has been skipped
部署后端服务 / 🚀 构建并部署 (push) Has been skipped
部署管理后台 / 🔄 回滚部署 (push) Has been skipped
部署前端网站 / 🚀 部署到生产环境 (push) Has been skipped
部署后端服务 / 🔄 回滚部署 (push) Has been skipped
style: 统一代码格式化 (go fmt + 配置更新)
- 后端:应用 go fmt 自动格式化,统一代码风格
- 前端:更新 API 配置,完善类型安全
- 所有代码符合项目规范,准备生产部署
2025-07-14 10:02:04 +08:00

362 lines
8.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package migration
import (
	"errors"
	"fmt"
	"log"
	"sort"
	"strings"
	"time"

	"gorm.io/gorm"
)
// Migration describes a single schema change that can be registered with a
// Migrator.
type Migration struct {
// Version identifies the migration. It is matched against stored records
// and is the string-compared sort key used when ordering migrations.
Version string
// Description is a human-readable summary; it is logged and stored with
// the migration record.
Description string
// UpSQL is the SQL executed by Up. It is split on ";" and run one
// statement at a time.
UpSQL string
// DownSQL is the SQL executed by Down to revert the migration; it is
// split and run the same way as UpSQL.
DownSQL string
// Timestamp is not read anywhere in this file.
// NOTE(review): presumably the authoring time — confirm with callers.
Timestamp time.Time
}
// MigrationRecord is the GORM model that persists the state of one migration.
// NOTE(review): no TableName method is visible in this file, so the table
// name is presumably derived by GORM's default naming rather than taken from
// Migrator.tableName — confirm.
type MigrationRecord struct {
// ID is the primary key.
ID uint `gorm:"primaryKey"`
// Version is the migration version: uniquely indexed, non-null, size 255.
Version string `gorm:"uniqueIndex;size:255;not null"`
// Description is the migration's summary text, size 500.
Description string `gorm:"size:500"`
// Applied is set to true by Up and back to false by Down; defaults to false.
Applied bool `gorm:"default:false"`
// AppliedAt is set by Up when the migration is applied. Down does not
// clear it.
AppliedAt time.Time
// CreatedAt is never assigned explicitly in this file.
// NOTE(review): presumably filled in by GORM's timestamp tracking — confirm.
CreatedAt time.Time
// UpdatedAt is set explicitly by Up whenever it creates or updates a record.
UpdatedAt time.Time
}
// Migrator holds a database handle together with the registered migrations
// and provides the operations that apply, roll back and report on them.
type Migrator struct {
// db is the GORM handle used for every query and transaction.
db *gorm.DB
// migrations holds the migrations registered through AddMigration.
migrations []Migration
// tableName is set to "schema_migrations" by NewMigrator; nothing visible
// in this file reads it.
tableName string
}
// NewMigrator returns a Migrator bound to db. The new migrator starts with an
// empty (non-nil) migration list and its tableName field set to
// "schema_migrations".
func NewMigrator(db *gorm.DB) *Migrator {
	migrator := new(Migrator)
	migrator.db = db
	migrator.migrations = make([]Migration, 0)
	migrator.tableName = "schema_migrations"
	return migrator
}
// AddMigration registers migration with the migrator by appending it to the
// end of the internal list. No validation or de-duplication is performed.
func (m *Migrator) AddMigration(migration Migration) {
	registered := append(m.migrations, migration)
	m.migrations = registered
}
// initMigrationTable makes sure the table backing MigrationRecord exists by
// running GORM's AutoMigrate for that model.
//
// It returns nil on success, or an error that wraps the AutoMigrate failure
// so callers can inspect the cause with errors.Is / errors.As.
func (m *Migrator) initMigrationTable() error {
	if err := m.db.AutoMigrate(&MigrationRecord{}); err != nil {
		return fmt.Errorf("failed to create migration table: %w", err)
	}
	return nil
}
// GetAppliedMigrations returns the versions of all records whose applied flag
// is true, in the ascending version order produced by the database query.
//
// The returned slice is non-nil even when no record matches. On a query
// failure it returns nil and the underlying error.
func (m *Migrator) GetAppliedMigrations() ([]string, error) {
	var applied []MigrationRecord
	query := m.db.Where("applied = ?", true).Order("version ASC")
	if result := query.Find(&applied); result.Error != nil {
		return nil, result.Error
	}

	versions := make([]string, 0, len(applied))
	for idx := range applied {
		versions = append(versions, applied[idx].Version)
	}
	return versions, nil
}
// GetPendingMigrations returns the registered migrations that have no applied
// record, sorted by Version using plain string comparison.
//
// The result is nil when nothing is pending. Any error from
// GetAppliedMigrations is returned unchanged.
func (m *Migrator) GetPendingMigrations() ([]Migration, error) {
	done, err := m.GetAppliedMigrations()
	if err != nil {
		return nil, err
	}

	// Set of versions already applied, for constant-time membership checks.
	seen := make(map[string]struct{})
	for _, v := range done {
		seen[v] = struct{}{}
	}

	var pending []Migration
	for idx := range m.migrations {
		if _, ok := seen[m.migrations[idx].Version]; ok {
			continue
		}
		pending = append(pending, m.migrations[idx])
	}

	// Order by version string.
	sort.Slice(pending, func(a, b int) bool {
		return pending[a].Version < pending[b].Version
	})
	return pending, nil
}
// Up applies every pending migration in ascending version order.
//
// It first ensures the migration record table exists. Then, for each pending
// migration, it opens a transaction, executes the migration's UpSQL, and
// creates or updates the matching MigrationRecord with Applied set to true
// before committing. Processing stops at the first failure; migrations whose
// transactions were already committed stay applied.
//
// It returns nil when there is nothing to apply or when every migration
// succeeds. Otherwise it returns an error that wraps the underlying cause, so
// callers can inspect it with errors.Is / errors.As.
func (m *Migrator) Up() error {
	if err := m.initMigrationTable(); err != nil {
		return err
	}

	pendingMigrations, err := m.GetPendingMigrations()
	if err != nil {
		return err
	}
	if len(pendingMigrations) == 0 {
		log.Println("No pending migrations")
		return nil
	}

	for _, migration := range pendingMigrations {
		log.Printf("Applying migration %s: %s", migration.Version, migration.Description)

		// Each migration runs in its own transaction.
		tx := m.db.Begin()
		if tx.Error != nil {
			return tx.Error
		}

		// Run the migration's SQL.
		if err := m.executeSQL(tx, migration.UpSQL); err != nil {
			tx.Rollback()
			return fmt.Errorf("failed to apply migration %s: %w", migration.Version, err)
		}

		// Record the migration as applied (manual upsert). A record may
		// already exist with applied=false after an earlier Down.
		now := time.Now()
		var existingRecord MigrationRecord
		lookupErr := tx.Where("version = ?", migration.Version).First(&existingRecord).Error
		switch {
		case errors.Is(lookupErr, gorm.ErrRecordNotFound):
			// No record yet: create one.
			record := MigrationRecord{
				Version:     migration.Version,
				Description: migration.Description,
				Applied:     true,
				AppliedAt:   now,
				UpdatedAt:   now,
			}
			if err := tx.Create(&record).Error; err != nil {
				tx.Rollback()
				return fmt.Errorf("failed to create migration record %s: %w", migration.Version, err)
			}
		case lookupErr == nil:
			// Record exists: flip it back to applied.
			updates := map[string]interface{}{
				"applied":    true,
				"applied_at": now,
				"updated_at": now,
			}
			if err := tx.Model(&existingRecord).Updates(updates).Error; err != nil {
				tx.Rollback()
				return fmt.Errorf("failed to update migration record %s: %w", migration.Version, err)
			}
		default:
			tx.Rollback()
			return fmt.Errorf("failed to check migration record %s: %w", migration.Version, lookupErr)
		}

		if err := tx.Commit().Error; err != nil {
			return fmt.Errorf("failed to commit migration %s: %w", migration.Version, err)
		}
		log.Printf("Successfully applied migration %s", migration.Version)
	}
	return nil
}
// Down rolls back up to steps of the most recently applied migrations,
// starting with the highest version.
//
// It first ensures the migration record table exists. For each migration to
// revert it opens a transaction, executes the migration's DownSQL, sets the
// record's applied flag to false, and commits. If steps exceeds the number of
// applied migrations, all of them are rolled back; a steps value of zero or
// less rolls back nothing. Processing stops at the first failure, leaving
// earlier rollbacks committed.
//
// It returns an error when an applied version has no registered Migration, or
// an error wrapping the underlying database failure (usable with errors.Is /
// errors.As).
func (m *Migrator) Down(steps int) error {
	if err := m.initMigrationTable(); err != nil {
		return err
	}

	appliedVersions, err := m.GetAppliedMigrations()
	if err != nil {
		return err
	}
	if len(appliedVersions) == 0 {
		log.Println("No applied migrations to rollback")
		return nil
	}

	// Never roll back more migrations than are applied.
	rollbackCount := steps
	if rollbackCount > len(appliedVersions) {
		rollbackCount = len(appliedVersions)
	}

	// appliedVersions is in ascending order, so walk it from the end.
	for i := len(appliedVersions) - 1; i >= len(appliedVersions)-rollbackCount; i-- {
		version := appliedVersions[i]
		migration := m.findMigrationByVersion(version)
		if migration == nil {
			return fmt.Errorf("migration %s not found in migration definitions", version)
		}
		log.Printf("Rolling back migration %s: %s", migration.Version, migration.Description)

		// Each rollback runs in its own transaction.
		tx := m.db.Begin()
		if tx.Error != nil {
			return tx.Error
		}

		// Run the rollback SQL.
		if err := m.executeSQL(tx, migration.DownSQL); err != nil {
			tx.Rollback()
			return fmt.Errorf("failed to rollback migration %s: %w", migration.Version, err)
		}

		// Mark the record as no longer applied; the row itself is kept.
		if err := tx.Model(&MigrationRecord{}).Where("version = ?", version).Update("applied", false).Error; err != nil {
			tx.Rollback()
			return fmt.Errorf("failed to update migration record %s: %w", migration.Version, err)
		}

		if err := tx.Commit().Error; err != nil {
			return fmt.Errorf("failed to commit rollback %s: %w", migration.Version, err)
		}
		log.Printf("Successfully rolled back migration %s", migration.Version)
	}
	return nil
}
// Status prints a table of every registered migration to standard output,
// ordered by version, marking each as "Applied" or "Pending".
//
// It first ensures the migration record table exists and then reads the
// applied versions. The migrator's own migration list is left untouched.
// Errors from either step are returned unchanged.
func (m *Migrator) Status() error {
	if err := m.initMigrationTable(); err != nil {
		return err
	}

	appliedVersions, err := m.GetAppliedMigrations()
	if err != nil {
		return err
	}
	appliedMap := make(map[string]bool, len(appliedVersions))
	for _, version := range appliedVersions {
		appliedMap[version] = true
	}

	// Sort a copy: sort.Slice reorders in place, and sorting m.migrations
	// directly would silently change the migrator's registration order.
	allMigrations := make([]Migration, len(m.migrations))
	copy(allMigrations, m.migrations)
	sort.Slice(allMigrations, func(i, j int) bool {
		return allMigrations[i].Version < allMigrations[j].Version
	})

	// Header column widths match the %-14s / %-7s row format below.
	fmt.Println("Migration Status:")
	fmt.Println("Version        | Status  | Description")
	fmt.Println("---------------|---------|----------------------------------")
	for _, migration := range allMigrations {
		status := "Pending"
		if appliedMap[migration.Version] {
			status = "Applied"
		}
		fmt.Printf("%-14s | %-7s | %s\n", migration.Version, status, migration.Description)
	}
	return nil
}
// executeSQL runs sqlStr on tx one statement at a time and stops at the first
// statement that fails, returning its error.
//
// Statements are obtained by splitting on every ";" and trimming whitespace;
// empty pieces are skipped. The split is purely textual, so a semicolon inside
// a string literal, comment or procedure body also ends a statement.
func (m *Migrator) executeSQL(tx *gorm.DB, sqlStr string) error {
	for _, raw := range strings.Split(sqlStr, ";") {
		stmt := strings.TrimSpace(raw)
		if stmt != "" {
			if result := tx.Exec(stmt); result.Error != nil {
				return result.Error
			}
		}
	}
	return nil
}
// findMigrationByVersion returns a pointer to a copy of the first registered
// migration whose Version equals version, or nil when there is none.
func (m *Migrator) findMigrationByVersion(version string) *Migration {
	for idx := range m.migrations {
		if m.migrations[idx].Version != version {
			continue
		}
		// Hand back a copy so the caller does not alias the internal slice.
		match := m.migrations[idx]
		return &match
	}
	return nil
}
// Reset rolls back every applied migration and then drops the migration
// record table. Use with care.
//
// This function itself only drops the record table; which application tables
// disappear depends entirely on the DownSQL of the applied migrations.
//
// It returns any error from preparing the record table, reading the applied
// versions, or rolling back, and an error wrapping the failure if the record
// table cannot be dropped.
func (m *Migrator) Reset() error {
	log.Println("WARNING: This will drop all tables and reset the database!")

	// Ensure the record table exists so the lookup below does not fail on a
	// database that has never been migrated (Up, Down and Status do the same).
	if err := m.initMigrationTable(); err != nil {
		return err
	}

	// Find out how many migrations are currently applied.
	appliedVersions, err := m.GetAppliedMigrations()
	if err != nil {
		return err
	}

	// Roll all of them back.
	if len(appliedVersions) > 0 {
		if err := m.Down(len(appliedVersions)); err != nil {
			return err
		}
	}

	// Finally remove the record table itself.
	if err := m.db.Migrator().DropTable(&MigrationRecord{}); err != nil {
		return fmt.Errorf("failed to drop migration table: %w", err)
	}
	log.Println("Database reset completed")
	return nil
}
// Migrate applies at most steps pending migrations, lowest version first.
//
// A steps value of zero or less, or one larger than the number of pending
// migrations, applies all of them. The work is delegated to Up by temporarily
// narrowing the migrator's migration list to the selected subset; the full
// list is always restored before Migrate returns.
//
// It returns any error from preparing the record table, computing the pending
// set, or from Up.
func (m *Migrator) Migrate(steps int) error {
	// Ensure the record table exists before querying it, so Migrate also
	// works on a database that has never been migrated.
	if err := m.initMigrationTable(); err != nil {
		return err
	}

	pendingMigrations, err := m.GetPendingMigrations()
	if err != nil {
		return err
	}
	if len(pendingMigrations) == 0 {
		log.Println("No pending migrations")
		return nil
	}

	migrateCount := steps
	if steps <= 0 || steps > len(pendingMigrations) {
		migrateCount = len(pendingMigrations)
	}

	// Up recomputes the pending set from m.migrations, so restrict that list
	// to the migrations we want to run. The deferred restore runs even if Up
	// panics, so the migrator is never left with a truncated list.
	originalMigrations := m.migrations
	m.migrations = pendingMigrations[:migrateCount]
	defer func() { m.migrations = originalMigrations }()

	return m.Up()
}