package service import ( "context" "errors" "fmt" "strings" "photography-backend/internal/models" "photography-backend/internal/utils" "go.uber.org/zap" "gorm.io/gorm" ) type CategoryService struct { db *gorm.DB logger *zap.Logger } func NewCategoryService(db *gorm.DB, logger *zap.Logger) *CategoryService { return &CategoryService{ db: db, logger: logger, } } // GetCategories 获取分类列表 func (s *CategoryService) GetCategories(ctx context.Context, parentID *uint) ([]models.Category, error) { var categories []models.Category query := s.db.WithContext(ctx).Order("sort_order ASC, created_at ASC") if parentID != nil { query = query.Where("parent_id = ?", *parentID) } else { query = query.Where("parent_id IS NULL") } if err := query.Find(&categories).Error; err != nil { s.logger.Error("Failed to get categories", zap.Error(err)) return nil, err } return categories, nil } // GetCategoryTree 获取分类树 func (s *CategoryService) GetCategoryTree(ctx context.Context) ([]models.CategoryTree, error) { var categories []models.Category if err := s.db.WithContext(ctx). Order("sort_order ASC, created_at ASC"). Find(&categories).Error; err != nil { s.logger.Error("Failed to get all categories", zap.Error(err)) return nil, err } // 构建树形结构 tree := s.buildCategoryTree(categories, nil) return tree, nil } // GetCategoryByID 根据ID获取分类 func (s *CategoryService) GetCategoryByID(ctx context.Context, id uint) (*models.Category, error) { var category models.Category if err := s.db.WithContext(ctx).First(&category, id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, errors.New("category not found") } s.logger.Error("Failed to get category by ID", zap.Error(err), zap.Uint("id", id)) return nil, err } return &category, nil } // GetCategoryBySlug 根据slug获取分类 func (s *CategoryService) GetCategoryBySlug(ctx context.Context, slug string) (*models.Category, error) { var category models.Category if err := s.db.WithContext(ctx).Where("slug = ?", slug).First(&category).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, errors.New("category not found") } s.logger.Error("Failed to get category by slug", zap.Error(err), zap.String("slug", slug)) return nil, err } return &category, nil } // CreateCategory 创建分类 func (s *CategoryService) CreateCategory(ctx context.Context, req *models.CreateCategoryRequest) (*models.Category, error) { // 验证slug唯一性 if err := s.validateSlugUnique(ctx, req.Slug, 0); err != nil { return nil, err } // 验证父分类存在性 if req.ParentID != nil { var parentCategory models.Category if err := s.db.WithContext(ctx).First(&parentCategory, *req.ParentID).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, errors.New("parent category not found") } return nil, err } } // 获取排序顺序 sortOrder := s.getNextSortOrder(ctx, req.ParentID) category := &models.Category{ Name: req.Name, Slug: req.Slug, Description: req.Description, ParentID: req.ParentID, SortOrder: sortOrder, IsActive: true, } if err := s.db.WithContext(ctx).Create(category).Error; err != nil { s.logger.Error("Failed to create category", zap.Error(err)) return nil, err } s.logger.Info("Category created successfully", zap.Uint("id", category.ID)) return category, nil } // UpdateCategory 更新分类 func (s *CategoryService) UpdateCategory(ctx context.Context, id uint, req *models.UpdateCategoryRequest) (*models.Category, error) { // 检查分类是否存在 var category models.Category if err := s.db.WithContext(ctx).First(&category, id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, errors.New("category not found") } return nil, err } // 验证slug唯一性 if req.Slug != nil && *req.Slug != category.Slug { if err := s.validateSlugUnique(ctx, *req.Slug, id); err != nil { return nil, err } } // 验证父分类(防止循环引用) if req.ParentID != nil && *req.ParentID != category.ParentID { if err := s.validateParentCategory(ctx, id, *req.ParentID); err != nil { return nil, err } } // 构建更新数据 updates := map[string]interface{}{} if req.Name != nil { updates["name"] = *req.Name } if req.Slug != nil { updates["slug"] = *req.Slug } if req.Description != nil { updates["description"] = *req.Description } if req.ParentID != nil { if *req.ParentID == 0 { updates["parent_id"] = nil } else { updates["parent_id"] = *req.ParentID } } if req.SortOrder != nil { updates["sort_order"] = *req.SortOrder } if req.IsActive != nil { updates["is_active"] = *req.IsActive } if len(updates) > 0 { if err := s.db.WithContext(ctx).Model(&category).Updates(updates).Error; err != nil { s.logger.Error("Failed to update category", zap.Error(err)) return nil, err } } s.logger.Info("Category updated successfully", zap.Uint("id", id)) return &category, nil } // DeleteCategory 删除分类 func (s *CategoryService) DeleteCategory(ctx context.Context, id uint) error { // 检查分类是否存在 var category models.Category if err := s.db.WithContext(ctx).First(&category, id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return errors.New("category not found") } return err } // 检查是否有子分类 var childCount int64 if err := s.db.WithContext(ctx).Model(&models.Category{}). Where("parent_id = ?", id).Count(&childCount).Error; err != nil { return err } if childCount > 0 { return errors.New("cannot delete category with subcategories") } // 检查是否有关联的照片 var photoCount int64 if err := s.db.WithContext(ctx).Table("photo_categories"). Where("category_id = ?", id).Count(&photoCount).Error; err != nil { return err } if photoCount > 0 { return errors.New("cannot delete category with associated photos") } // 删除分类 if err := s.db.WithContext(ctx).Delete(&category).Error; err != nil { s.logger.Error("Failed to delete category", zap.Error(err)) return err } s.logger.Info("Category deleted successfully", zap.Uint("id", id)) return nil } // ReorderCategories 重新排序分类 func (s *CategoryService) ReorderCategories(ctx context.Context, parentID *uint, categoryIDs []uint) error { // 验证所有分类都属于同一父分类 var categories []models.Category query := s.db.WithContext(ctx).Where("id IN ?", categoryIDs) if parentID != nil { query = query.Where("parent_id = ?", *parentID) } else { query = query.Where("parent_id IS NULL") } if err := query.Find(&categories).Error; err != nil { return err } if len(categories) != len(categoryIDs) { return errors.New("invalid category IDs") } // 开始事务 tx := s.db.WithContext(ctx).Begin() if tx.Error != nil { return tx.Error } defer tx.Rollback() // 更新排序 for i, categoryID := range categoryIDs { if err := tx.Model(&models.Category{}). Where("id = ?", categoryID). Update("sort_order", i+1).Error; err != nil { return err } } // 提交事务 if err := tx.Commit().Error; err != nil { return err } s.logger.Info("Categories reordered successfully", zap.Int("count", len(categoryIDs))) return nil } // GetCategoryStats 获取分类统计信息 func (s *CategoryService) GetCategoryStats(ctx context.Context) (*models.CategoryStats, error) { var stats models.CategoryStats // 总分类数 if err := s.db.WithContext(ctx).Model(&models.Category{}).Count(&stats.Total).Error; err != nil { return nil, err } // 活跃分类数 if err := s.db.WithContext(ctx).Model(&models.Category{}). Where("is_active = ?", true).Count(&stats.Active).Error; err != nil { return nil, err } // 顶级分类数 if err := s.db.WithContext(ctx).Model(&models.Category{}). Where("parent_id IS NULL").Count(&stats.TopLevel).Error; err != nil { return nil, err } // 各分类照片数量 var categoryPhotoStats []struct { CategoryID uint `json:"category_id"` Name string `json:"name"` PhotoCount int64 `json:"photo_count"` } if err := s.db.WithContext(ctx). Table("categories"). Select("categories.id as category_id, categories.name, COUNT(photo_categories.photo_id) as photo_count"). Joins("LEFT JOIN photo_categories ON categories.id = photo_categories.category_id"). Group("categories.id, categories.name"). Order("photo_count DESC"). Limit(10). Find(&categoryPhotoStats).Error; err != nil { return nil, err } stats.PhotoCounts = make(map[string]int64) for _, stat := range categoryPhotoStats { stats.PhotoCounts[stat.Name] = stat.PhotoCount } return &stats, nil } // validateSlugUnique 验证slug唯一性 func (s *CategoryService) validateSlugUnique(ctx context.Context, slug string, excludeID uint) error { var count int64 query := s.db.WithContext(ctx).Model(&models.Category{}).Where("slug = ?", slug) if excludeID > 0 { query = query.Where("id != ?", excludeID) } if err := query.Count(&count).Error; err != nil { return err } if count > 0 { return errors.New("slug already exists") } return nil } // validateParentCategory 验证父分类(防止循环引用) func (s *CategoryService) validateParentCategory(ctx context.Context, categoryID, parentID uint) error { if categoryID == parentID { return errors.New("category cannot be its own parent") } // 检查是否会形成循环引用 current := parentID for current != 0 { var parent models.Category if err := s.db.WithContext(ctx).First(&parent, current).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return errors.New("parent category not found") } return err } if parent.ParentID == nil { break } if *parent.ParentID == categoryID { return errors.New("circular reference detected") } current = *parent.ParentID } return nil } // getNextSortOrder 获取下一个排序顺序 func (s *CategoryService) getNextSortOrder(ctx context.Context, parentID *uint) int { var maxOrder int query := s.db.WithContext(ctx).Model(&models.Category{}).Select("COALESCE(MAX(sort_order), 0)") if parentID != nil { query = query.Where("parent_id = ?", *parentID) } else { query = query.Where("parent_id IS NULL") } query.Row().Scan(&maxOrder) return maxOrder + 1 } // buildCategoryTree 构建分类树 func (s *CategoryService) buildCategoryTree(categories []models.Category, parentID *uint) []models.CategoryTree { var tree []models.CategoryTree for _, category := range categories { // 检查是否匹配父分类 if (parentID == nil && category.ParentID == nil) || (parentID != nil && category.ParentID != nil && *category.ParentID == *parentID) { node := models.CategoryTree{ ID: category.ID, Name: category.Name, Slug: category.Slug, Description: category.Description, ParentID: category.ParentID, SortOrder: category.SortOrder, IsActive: category.IsActive, PhotoCount: category.PhotoCount, CreatedAt: category.CreatedAt, UpdatedAt: category.UpdatedAt, } // 递归构建子分类 node.Children = s.buildCategoryTree(categories, &category.ID) tree = append(tree, node) } } return tree } // GenerateSlug 生成slug func (s *CategoryService) GenerateSlug(ctx context.Context, name string) (string, error) { baseSlug := utils.GenerateSlug(name) slug := baseSlug counter := 1 for { var count int64 if err := s.db.WithContext(ctx).Model(&models.Category{}). Where("slug = ?", slug).Count(&count).Error; err != nil { return "", err } if count == 0 { break } slug = fmt.Sprintf("%s-%d", baseSlug, counter) counter++ } return slug, nil }