package postgres

import (
	"errors"
	"fmt"

	"photography-backend/internal/models"

	"gorm.io/gorm"
)

// UserRepository is the persistence interface for user records.
//
// The single-record getters (GetByID, GetByUsername, GetByEmail) return
// (nil, nil) when no matching user exists; callers must check the returned
// pointer for nil in addition to checking the error.
type UserRepository interface {
	// Create inserts the given user.
	Create(user *models.User) error
	// GetByID looks a user up by primary key.
	GetByID(id uint) (*models.User, error)
	// GetByUsername looks a user up by the username column.
	GetByUsername(username string) (*models.User, error)
	// GetByEmail looks a user up by the email column.
	GetByEmail(email string) (*models.User, error)
	// Update persists the given user via gorm's Save.
	Update(user *models.User) error
	// Delete removes the user with the given primary key.
	Delete(id uint) error
	// List returns one page of users plus the total number of users
	// matching the filters.
	List(page, limit int, role string, isActive *bool) ([]*models.User, int64, error)
	// UpdateLastLogin stamps the user's last_login column with the
	// database's current time.
	UpdateLastLogin(id uint) error
}

// userRepository is the gorm-backed implementation of UserRepository.
type userRepository struct {
	db *gorm.DB
}

// NewUserRepository returns a UserRepository that runs its queries on db.
func NewUserRepository(db *gorm.DB) UserRepository {
	return &userRepository{db: db}
}

// Create inserts user and wraps any database error.
func (r *userRepository) Create(user *models.User) error {
	if err := r.db.Create(user).Error; err != nil {
		return fmt.Errorf("failed to create user: %w", err)
	}
	return nil
}

// GetByID fetches the user whose primary key equals id.
// It returns (nil, nil) when no such user exists.
func (r *userRepository) GetByID(id uint) (*models.User, error) {
	var user models.User
	if err := r.db.First(&user, id).Error; err != nil {
		// errors.Is also matches a wrapped ErrRecordNotFound, which a plain
		// == comparison would miss.
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil
		}
		return nil, fmt.Errorf("failed to get user by id: %w", err)
	}
	return &user, nil
}

// GetByUsername fetches the first user whose username column equals username.
// It returns (nil, nil) when no such user exists.
func (r *userRepository) GetByUsername(username string) (*models.User, error) {
	var user models.User
	if err := r.db.Where("username = ?", username).First(&user).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil
		}
		return nil, fmt.Errorf("failed to get user by username: %w", err)
	}
	return &user, nil
}

// GetByEmail fetches the first user whose email column equals email.
// It returns (nil, nil) when no such user exists.
func (r *userRepository) GetByEmail(email string) (*models.User, error) {
	var user models.User
	if err := r.db.Where("email = ?", email).First(&user).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil
		}
		return nil, fmt.Errorf("failed to get user by email: %w", err)
	}
	return &user, nil
}

// Update persists user with gorm's Save and wraps any database error.
func (r *userRepository) Update(user *models.User) error {
	if err := r.db.Save(user).Error; err != nil {
		return fmt.Errorf("failed to update user: %w", err)
	}
	return nil
}

// Delete removes the user with primary key id and wraps any database error.
// NOTE(review): whether this is a soft or hard delete depends on how
// models.User is declared, which is not visible here — confirm.
func (r *userRepository) Delete(id uint) error {
	if err := r.db.Delete(&models.User{}, id).Error; err != nil {
		return fmt.Errorf("failed to delete user: %w", err)
	}
	return nil
}

// List returns one page of users, newest first (created_at DESC), together
// with the total number of users matching the filters.
//
// page is 1-based; values below 1 are treated as the first page. limit is the
// page size. An empty role disables the role filter, and a nil isActive
// disables the is_active filter. The total is counted before pagination is
// applied.
func (r *userRepository) List(page, limit int, role string, isActive *bool) ([]*models.User, int64, error) {
	var users []*models.User
	var total int64

	query := r.db.Model(&models.User{})

	// Optional filters.
	if role != "" {
		query = query.Where("role = ?", role)
	}
	if isActive != nil {
		query = query.Where("is_active = ?", *isActive)
	}

	// Count every row matching the filters, ignoring pagination.
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, fmt.Errorf("failed to count users: %w", err)
	}

	// Clamp so a page below 1 maps to the first page instead of producing a
	// negative offset.
	if page < 1 {
		page = 1
	}
	offset := (page - 1) * limit

	if err := query.Offset(offset).Limit(limit).
		Order("created_at DESC").
		Find(&users).Error; err != nil {
		return nil, 0, fmt.Errorf("failed to list users: %w", err)
	}

	return users, total, nil
}

// UpdateLastLogin sets the last_login column of the user with primary key id
// to the database server's NOW().
func (r *userRepository) UpdateLastLogin(id uint) error {
	if err := r.db.Model(&models.User{}).Where("id = ?", id).
		Update("last_login", gorm.Expr("NOW()")).Error; err != nil {
		return fmt.Errorf("failed to update last login: %w", err)
	}
	return nil
}