package database import ( "fmt" "log" "time" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" "gorm.io/gorm" "gorm.io/gorm/logger" "photography-backend/internal/config" "photography-backend/internal/model/entity" ) // Database 数据库连接管理器 type Database struct { db *gorm.DB config *config.DatabaseConfig } // New 创建新的数据库连接 func New(cfg *config.Config) (*Database, error) { var db *gorm.DB var err error // 配置 GORM 日志 gormConfig := &gorm.Config{ Logger: logger.Default.LogMode(logger.Info), } if cfg.App.Environment == "production" { gormConfig.Logger = logger.Default.LogMode(logger.Error) } // 根据环境选择数据库 if cfg.App.Environment == "test" || cfg.Database.Host == "" { // 使用 SQLite 进行测试或开发 db, err = gorm.Open(sqlite.Open("photography_dev.db"), gormConfig) } else { // 使用 PostgreSQL 进行生产 dsn := cfg.GetDatabaseDSN() db, err = gorm.Open(postgres.Open(dsn), gormConfig) } if err != nil { return nil, fmt.Errorf("failed to connect to database: %w", err) } // 配置连接池 sqlDB, err := db.DB() if err != nil { return nil, fmt.Errorf("failed to get underlying sql.DB: %w", err) } // 设置连接池参数 sqlDB.SetMaxOpenConns(cfg.Database.MaxOpenConns) sqlDB.SetMaxIdleConns(cfg.Database.MaxIdleConns) sqlDB.SetConnMaxLifetime(time.Duration(cfg.Database.ConnMaxLifetime) * time.Minute) // 测试连接 if err := sqlDB.Ping(); err != nil { return nil, fmt.Errorf("failed to ping database: %w", err) } return &Database{ db: db, config: &cfg.Database, }, nil } // GetDB 获取数据库连接 func (d *Database) GetDB() *gorm.DB { return d.db } // Close 关闭数据库连接 func (d *Database) Close() error { sqlDB, err := d.db.DB() if err != nil { return err } return sqlDB.Close() } // AutoMigrate 自动迁移数据库表 func (d *Database) AutoMigrate() error { // 按依赖关系顺序迁移表 entities := []interface{}{ &entity.User{}, &entity.Category{}, &entity.Tag{}, &entity.Album{}, &entity.Photo{}, &entity.PhotoTag{}, } for _, entity := range entities { if err := d.db.AutoMigrate(entity); err != nil { return fmt.Errorf("failed to migrate %T: %w", entity, err) } } log.Println("Database migration completed successfully") return nil } // Seed 填充种子数据 func (d *Database) Seed() error { // 检查是否已有数据 var userCount int64 if err := d.db.Model(&entity.User{}).Count(&userCount).Error; err != nil { return fmt.Errorf("failed to count users: %w", err) } if userCount > 0 { log.Println("Database already has data, skipping seed") return nil } // 创建事务 tx := d.db.Begin() if tx.Error != nil { return fmt.Errorf("failed to begin transaction: %w", tx.Error) } defer tx.Rollback() // 创建默认用户 if err := d.seedUsers(tx); err != nil { return fmt.Errorf("failed to seed users: %w", err) } // 创建默认分类 if err := d.seedCategories(tx); err != nil { return fmt.Errorf("failed to seed categories: %w", err) } // 创建默认标签 if err := d.seedTags(tx); err != nil { return fmt.Errorf("failed to seed tags: %w", err) } // 提交事务 if err := tx.Commit().Error; err != nil { return fmt.Errorf("failed to commit transaction: %w", err) } log.Println("Database seeding completed successfully") return nil } // seedUsers 创建默认用户 func (d *Database) seedUsers(tx *gorm.DB) error { users := []entity.User{ { Username: "admin", Email: "admin@photography.com", Password: "$2a$10$D4Zz6m3j1YJzp8Y7zW4l2OXcQ5f6g7h8i9j0k1l2m3n4o5p6q7r8s9t0", // admin123 Name: "管理员", Role: entity.UserRoleAdmin, IsActive: true, IsPublic: true, EmailVerified: true, }, { Username: "photographer", Email: "photographer@photography.com", Password: "$2a$10$D4Zz6m3j1YJzp8Y7zW4l2OXcQ5f6g7h8i9j0k1l2m3n4o5p6q7r8s9t0", // admin123 Name: "摄影师", Role: entity.UserRolePhotographer, IsActive: true, IsPublic: true, EmailVerified: true, }, { Username: "demo", Email: "demo@photography.com", Password: "$2a$10$D4Zz6m3j1YJzp8Y7zW4l2OXcQ5f6g7h8i9j0k1l2m3n4o5p6q7r8s9t0", // admin123 Name: "演示用户", Role: entity.UserRoleUser, IsActive: true, IsPublic: true, EmailVerified: true, }, } for _, user := range users { if err := tx.Create(&user).Error; err != nil { return fmt.Errorf("failed to create user %s: %w", user.Username, err) } } return nil } // seedCategories 创建默认分类 func (d *Database) seedCategories(tx *gorm.DB) error { categories := []entity.Category{ { Name: "风景摄影", Description: "自然风景摄影作品", Color: "#10b981", Sort: 1, IsActive: true, }, { Name: "人像摄影", Description: "人物肖像摄影作品", Color: "#f59e0b", Sort: 2, IsActive: true, }, { Name: "街头摄影", Description: "街头纪实摄影作品", Color: "#ef4444", Sort: 3, IsActive: true, }, { Name: "建筑摄影", Description: "建筑和城市摄影作品", Color: "#3b82f6", Sort: 4, IsActive: true, }, { Name: "抽象摄影", Description: "抽象艺术摄影作品", Color: "#8b5cf6", Sort: 5, IsActive: true, }, } for _, category := range categories { if err := tx.Create(&category).Error; err != nil { return fmt.Errorf("failed to create category %s: %w", category.Name, err) } } return nil } // seedTags 创建默认标签 func (d *Database) seedTags(tx *gorm.DB) error { tags := []entity.Tag{ {Name: "自然", Color: "#10b981"}, {Name: "人物", Color: "#f59e0b"}, {Name: "城市", Color: "#3b82f6"}, {Name: "夜景", Color: "#1f2937"}, {Name: "黑白", Color: "#6b7280"}, {Name: "色彩", Color: "#ec4899"}, {Name: "构图", Color: "#8b5cf6"}, {Name: "光影", Color: "#f97316"}, {Name: "街头", Color: "#ef4444"}, {Name: "建筑", Color: "#0891b2"}, {Name: "风景", Color: "#10b981"}, {Name: "抽象", Color: "#8b5cf6"}, {Name: "微距", Color: "#84cc16"}, {Name: "运动", Color: "#f97316"}, {Name: "动物", Color: "#8b5cf6"}, } for _, tag := range tags { if err := tx.Create(&tag).Error; err != nil { return fmt.Errorf("failed to create tag %s: %w", tag.Name, err) } } return nil } // HealthCheck 健康检查 func (d *Database) HealthCheck() error { sqlDB, err := d.db.DB() if err != nil { return fmt.Errorf("failed to get underlying sql.DB: %w", err) } if err := sqlDB.Ping(); err != nil { return fmt.Errorf("database ping failed: %w", err) } return nil } // GetStats 获取数据库统计信息 func (d *Database) GetStats() (map[string]interface{}, error) { sqlDB, err := d.db.DB() if err != nil { return nil, fmt.Errorf("failed to get underlying sql.DB: %w", err) } stats := sqlDB.Stats() // 获取表记录数 var userCount, photoCount, albumCount, categoryCount, tagCount int64 d.db.Model(&entity.User{}).Count(&userCount) d.db.Model(&entity.Photo{}).Count(&photoCount) d.db.Model(&entity.Album{}).Count(&albumCount) d.db.Model(&entity.Category{}).Count(&categoryCount) d.db.Model(&entity.Tag{}).Count(&tagCount) return map[string]interface{}{ "connection_stats": map[string]interface{}{ "max_open_connections": stats.MaxOpenConnections, "open_connections": stats.OpenConnections, "in_use": stats.InUse, "idle": stats.Idle, }, "table_counts": map[string]interface{}{ "users": userCount, "photos": photoCount, "albums": albumCount, "categories": categoryCount, "tags": tagCount, }, }, nil } // Transaction 执行事务 func (d *Database) Transaction(fn func(*gorm.DB) error) error { tx := d.db.Begin() if tx.Error != nil { return fmt.Errorf("failed to begin transaction: %w", tx.Error) } defer tx.Rollback() if err := fn(tx); err != nil { return err } if err := tx.Commit().Error; err != nil { return fmt.Errorf("failed to commit transaction: %w", err) } return nil }