package main import ( "flag" "fmt" "log" "os" "photography-backend/internal/config" "photography-backend/pkg/migration" "photography-backend/pkg/utils/database" "github.com/zeromicro/go-zero/core/conf" ) func main() { // 命令行参数 var ( configFile = flag.String("f", "/etc/photography-api.yaml", "配置文件路径") command = flag.String("c", "status", "迁移命令: up, down, status, reset, migrate") steps = flag.Int("s", 0, "步数(用于 down 和 migrate 命令)") version = flag.String("v", "", "迁移版本(用于特定版本操作)") help = flag.Bool("h", false, "显示帮助信息") ) flag.Parse() // 显示帮助信息 if *help { showHelp() return } // 加载配置 var c config.Config conf.MustLoad(*configFile, &c) // 创建数据库连接 db, err := database.NewDB(c.Database) if err != nil { log.Fatalf("Failed to connect to database: %v", err) } // 创建迁移器 migrator := migration.NewMigrator(db) // 加载所有迁移 migrations := migration.GetAllMigrations() for _, m := range migrations { migrator.AddMigration(m) } // 执行命令 switch *command { case "up": if err := migrator.Up(); err != nil { log.Fatalf("Migration up failed: %v", err) } fmt.Println("All migrations applied successfully!") case "down": if *steps <= 0 { fmt.Println("Please specify the number of steps to rollback with -s flag") os.Exit(1) } if err := migrator.Down(*steps); err != nil { log.Fatalf("Migration down failed: %v", err) } fmt.Printf("Successfully rolled back %d migrations\n", *steps) case "status": if err := migrator.Status(); err != nil { log.Fatalf("Failed to get migration status: %v", err) } case "reset": fmt.Print("Are you sure you want to reset the database? This will DROP ALL TABLES! (y/N): ") var response string fmt.Scanln(&response) if response == "y" || response == "Y" { if err := migrator.Reset(); err != nil { log.Fatalf("Database reset failed: %v", err) } fmt.Println("Database reset successfully!") } else { fmt.Println("Database reset cancelled.") } case "migrate": migrateSteps := *steps if migrateSteps <= 0 { // 如果没有指定步数,默认执行所有待处理迁移 pendingMigrations, err := migrator.GetPendingMigrations() if err != nil { log.Fatalf("Failed to get pending migrations: %v", err) } migrateSteps = len(pendingMigrations) } if err := migrator.Migrate(migrateSteps); err != nil { log.Fatalf("Migration failed: %v", err) } fmt.Printf("Successfully applied %d migrations\n", migrateSteps) case "version": if *version != "" { migration := migration.GetMigrationByVersion(*version) if migration == nil { fmt.Printf("Migration version %s not found\n", *version) os.Exit(1) } fmt.Printf("Version: %s\n", migration.Version) fmt.Printf("Description: %s\n", migration.Description) fmt.Printf("Timestamp: %s\n", migration.Timestamp.Format("2006-01-02 15:04:05")) } else { latest := migration.GetLatestMigrationVersion() if latest == "" { fmt.Println("No migrations found") } else { fmt.Printf("Latest migration version: %s\n", latest) } } case "create": if len(flag.Args()) < 1 { fmt.Println("Please provide a migration name: go run cmd/migrate/main.go -c create \"migration_name\"") os.Exit(1) } migrationName := flag.Args()[0] createMigrationTemplate(migrationName) default: fmt.Printf("Unknown command: %s\n", *command) showHelp() os.Exit(1) } } func showHelp() { fmt.Println("数据库迁移工具使用说明:") fmt.Println() fmt.Println("命令格式:") fmt.Println(" go run cmd/migrate/main.go [选项] [参数]") fmt.Println() fmt.Println("选项:") fmt.Println(" -f string 配置文件路径 (默认: /etc/photography-api.yaml)") fmt.Println(" -c string 迁移命令 (默认: status)") fmt.Println(" -s int 步数,用于 down 和 migrate 命令") fmt.Println(" -v string 迁移版本,用于特定版本操作") fmt.Println(" -h 显示帮助信息") fmt.Println() fmt.Println("命令:") fmt.Println(" status 显示所有迁移的状态") fmt.Println(" up 应用所有待处理的迁移") fmt.Println(" down 回滚指定数量的迁移 (需要 -s 参数)") fmt.Println(" migrate 应用指定数量的迁移 (可选 -s 参数)") fmt.Println(" reset 重置数据库 (删除所有表)") fmt.Println(" version 显示最新迁移版本 (可选 -v 查看特定版本)") fmt.Println(" create 创建新的迁移模板") fmt.Println() fmt.Println("示例:") fmt.Println(" go run cmd/migrate/main.go -c status") fmt.Println(" go run cmd/migrate/main.go -c up") fmt.Println(" go run cmd/migrate/main.go -c down -s 1") fmt.Println(" go run cmd/migrate/main.go -c migrate -s 2") fmt.Println(" go run cmd/migrate/main.go -c version") fmt.Println(" go run cmd/migrate/main.go -c create \"add_user_avatar_field\"") } func createMigrationTemplate(name string) { // 生成版本号(基于当前时间) version := fmt.Sprintf("%d_%06d", getCurrentTimestamp(), getCurrentMicroseconds()) template := fmt.Sprintf(`// Migration: %s // Description: %s // Version: %s package migration import "time" // Add this migration to GetAllMigrations() in migrations.go var migration_%s = Migration{ Version: "%s", Description: "%s", Timestamp: time.Now(), UpSQL: %s -- Add your UP migration SQL here -- Example: -- ALTER TABLE user ADD COLUMN new_field VARCHAR(255) DEFAULT ''; -- CREATE INDEX IF NOT EXISTS idx_user_new_field ON user(new_field); %s, DownSQL: %s -- Add your DOWN migration SQL here (rollback changes) -- Example: -- DROP INDEX IF EXISTS idx_user_new_field; -- ALTER TABLE user DROP COLUMN new_field; -- Note: SQLite doesn't support DROP COLUMN %s, }`, name, name, version, version, version, name, "`", "`", "`", "`") filename := fmt.Sprintf("migrations/%s_%s.go", version, name) // 创建 migrations 目录 if err := os.MkdirAll("migrations", 0755); err != nil { log.Fatalf("Failed to create migrations directory: %v", err) } // 写入模板文件 if err := os.WriteFile(filename, []byte(template), 0644); err != nil { log.Fatalf("Failed to create migration template: %v", err) } fmt.Printf("Created migration template: %s\n", filename) fmt.Println("Please:") fmt.Println("1. Edit the migration file to add your SQL") fmt.Println("2. Add the migration to GetAllMigrations() in pkg/migration/migrations.go") fmt.Println("3. Run 'go run cmd/migrate/main.go -c status' to verify") } func getCurrentTimestamp() int64 { return 20250711000000 // 格式: YYYYMMDDHHMMSS,可以根据需要调整 } func getCurrentMicroseconds() int { return 1 // 当天的迁移计数,可以根据需要调整 }