fix
Some checks failed
部署后端服务 / 🧪 测试后端 (push) Failing after 5m8s
部署后端服务 / 🚀 构建并部署 (push) Has been skipped
部署后端服务 / 🔄 回滚部署 (push) Has been skipped

This commit is contained in:
xujiang
2025-07-10 18:09:11 +08:00
parent 35004f224e
commit 010fe2a8c7
96 changed files with 23709 additions and 19 deletions

View File

@ -0,0 +1,873 @@
# Repository Layer - CLAUDE.md
本文件为 Claude Code 在数据访问层中工作时提供指导。
## 🎯 模块概览
Repository 层负责数据访问逻辑,提供数据库操作的抽象接口,隔离业务逻辑与数据存储细节。
### 主要职责
- 🔧 提供数据库操作接口
- 📊 实现 CRUD 操作
- 🔍 提供复杂查询支持
- 💾 管理数据库连接和事务
- 🚀 优化查询性能
- 🔄 支持多种数据源
## 📁 模块结构
```
internal/repository/
├── CLAUDE.md # 📋 当前文件 - 数据访问开发指导
├── interfaces/ # 🔗 仓储接口定义
│ ├── user_repository.go # 用户仓储接口
│ ├── photo_repository.go # 照片仓储接口
│ ├── category_repository.go # 分类仓储接口
│ ├── tag_repository.go # 标签仓储接口
│ └── album_repository.go # 相册仓储接口
├── postgres/ # 🐘 PostgreSQL 实现
│ ├── user_repository.go # 用户仓储实现
│ ├── photo_repository.go # 照片仓储实现
│ ├── category_repository.go # 分类仓储实现
│ ├── tag_repository.go # 标签仓储实现
│ ├── album_repository.go # 相册仓储实现
│ └── base_repository.go # 基础仓储实现
├── redis/ # 🟥 Redis 缓存实现
│ ├── user_cache.go # 用户缓存
│ ├── photo_cache.go # 照片缓存
│ └── cache_manager.go # 缓存管理器
├── sqlite/ # 📦 SQLite 实现(开发用)
│ ├── user_repository.go # 用户仓储实现
│ ├── photo_repository.go # 照片仓储实现
│ └── base_repository.go # 基础仓储实现
├── mocks/ # 🧪 模拟对象(测试用)
│ ├── user_repository_mock.go # 用户仓储模拟
│ ├── photo_repository_mock.go # 照片仓储模拟
│ └── generate.go # 模拟对象生成
└── errors.go # 📝 仓储层错误定义
```
## 🔗 接口设计
### 基础仓储接口
```go
// interfaces/base_repository.go - 基础仓储接口
package interfaces
import (
"context"
"gorm.io/gorm"
)
// BaseRepositoryr 基础仓储接口
type BaseRepositoryr[T any] interface {
// 基础 CRUD 操作
Create(ctx context.Context, entity *T) (*T, error)
GetByID(ctx context.Context, id uint) (*T, error)
Update(ctx context.Context, entity *T) (*T, error)
Delete(ctx context.Context, id uint) error
// 批量操作
CreateBatch(ctx context.Context, entities []*T) error
UpdateBatch(ctx context.Context, entities []*T) error
DeleteBatch(ctx context.Context, ids []uint) error
// 查询操作
List(ctx context.Context, opts *ListOptions) ([]*T, int64, error)
Count(ctx context.Context, conditions map[string]interface{}) (int64, error)
Exists(ctx context.Context, conditions map[string]interface{}) (bool, error)
// 事务操作
WithTx(tx *gorm.DB) BaseRepositoryr[T]
Transaction(ctx context.Context, fn func(repo BaseRepositoryr[T]) error) error
}
// ListOptions 查询选项
type ListOptions struct {
Page int `json:"page"`
Limit int `json:"limit"`
Sort string `json:"sort"`
Order string `json:"order"`
Conditions map[string]interface{} `json:"conditions"`
Preloads []string `json:"preloads"`
}
```
### 用户仓储接口
```go
// interfaces/user_repository.go - 用户仓储接口
package interfaces
import (
"context"
"photography-backend/internal/model/entity"
)
// UserRepositoryr 用户仓储接口
type UserRepositoryr interface {
BaseRepositoryr[entity.User]
// 用户特定查询
GetByEmail(ctx context.Context, email string) (*entity.User, error)
GetByUsername(ctx context.Context, username string) (*entity.User, error)
GetByEmailOrUsername(ctx context.Context, emailOrUsername string) (*entity.User, error)
// 用户列表查询
ListByRole(ctx context.Context, role entity.UserRole, opts *ListOptions) ([]*entity.User, int64, error)
ListByStatus(ctx context.Context, status entity.UserStatus, opts *ListOptions) ([]*entity.User, int64, error)
SearchUsers(ctx context.Context, keyword string, opts *ListOptions) ([]*entity.User, int64, error)
// 用户统计
CountByRole(ctx context.Context, role entity.UserRole) (int64, error)
CountByStatus(ctx context.Context, status entity.UserStatus) (int64, error)
CountActiveUsers(ctx context.Context) (int64, error)
// 用户状态更新
UpdateStatus(ctx context.Context, id uint, status entity.UserStatus) error
UpdateLastLogin(ctx context.Context, id uint) error
// 密码相关
UpdatePassword(ctx context.Context, id uint, hashedPassword string) error
// 软删除恢复
Restore(ctx context.Context, id uint) error
}
```
### 照片仓储接口
```go
// interfaces/photo_repository.go - 照片仓储接口
package interfaces
import (
"context"
"time"
"photography-backend/internal/model/entity"
)
// PhotoRepositoryr 照片仓储接口
type PhotoRepositoryr interface {
BaseRepositoryr[entity.Photo]
// 照片查询
GetByFilename(ctx context.Context, filename string) (*entity.Photo, error)
GetByUserID(ctx context.Context, userID uint) ([]*entity.Photo, error)
ListByUserID(ctx context.Context, userID uint, opts *ListOptions) ([]*entity.Photo, int64, error)
ListByStatus(ctx context.Context, status entity.PhotoStatus, opts *ListOptions) ([]*entity.Photo, int64, error)
// 分类和标签查询
ListByCategory(ctx context.Context, categoryID uint, opts *ListOptions) ([]*entity.Photo, int64, error)
ListByTag(ctx context.Context, tagID uint, opts *ListOptions) ([]*entity.Photo, int64, error)
ListByAlbum(ctx context.Context, albumID uint, opts *ListOptions) ([]*entity.Photo, int64, error)
// 搜索功能
SearchPhotos(ctx context.Context, keyword string, opts *SearchOptions) ([]*entity.Photo, int64, error)
SearchByMetadata(ctx context.Context, metadata map[string]interface{}, opts *ListOptions) ([]*entity.Photo, int64, error)
// 时间范围查询
ListByDateRange(ctx context.Context, startDate, endDate time.Time, opts *ListOptions) ([]*entity.Photo, int64, error)
ListByCreatedDateRange(ctx context.Context, startDate, endDate time.Time, opts *ListOptions) ([]*entity.Photo, int64, error)
// 统计查询
CountByUser(ctx context.Context, userID uint) (int64, error)
CountByStatus(ctx context.Context, status entity.PhotoStatus) (int64, error)
CountByCategory(ctx context.Context, categoryID uint) (int64, error)
CountByTag(ctx context.Context, tagID uint) (int64, error)
// 关联操作
AddCategories(ctx context.Context, photoID uint, categoryIDs []uint) error
RemoveCategories(ctx context.Context, photoID uint, categoryIDs []uint) error
AddTags(ctx context.Context, photoID uint, tagIDs []uint) error
RemoveTags(ctx context.Context, photoID uint, tagIDs []uint) error
// 统计更新
IncrementViewCount(ctx context.Context, id uint) error
IncrementDownloadCount(ctx context.Context, id uint) error
// 批量操作
BatchUpdateStatus(ctx context.Context, ids []uint, status entity.PhotoStatus) error
BatchDelete(ctx context.Context, ids []uint) error
}
// SearchOptions 搜索选项
type SearchOptions struct {
ListOptions
Fields []string `json:"fields"`
DateFrom *time.Time `json:"date_from"`
DateTo *time.Time `json:"date_to"`
MinWidth int `json:"min_width"`
MaxWidth int `json:"max_width"`
MinHeight int `json:"min_height"`
MaxHeight int `json:"max_height"`
}
```
## 🔧 仓储实现
### 基础仓储实现
```go
// postgres/base_repository.go - 基础仓储实现
package postgres
import (
"context"
"errors"
"fmt"
"reflect"
"gorm.io/gorm"
"go.uber.org/zap"
"photography-backend/internal/repository/interfaces"
"photography-backend/pkg/logger"
)
// BaseRepository 基础仓储实现
type BaseRepository[T any] struct {
db *gorm.DB
logger logger.Logger
}
// NewBaseRepository 创建基础仓储
func NewBaseRepository[T any](db *gorm.DB, logger logger.Logger) *BaseRepository[T] {
return &BaseRepository[T]{
db: db,
logger: logger,
}
}
// Create 创建记录
func (r *BaseRepository[T]) Create(ctx context.Context, entity *T) (*T, error) {
if err := r.db.WithContext(ctx).Create(entity).Error; err != nil {
r.logger.Error("failed to create entity", zap.Error(err))
return nil, err
}
return entity, nil
}
// GetByID 根据ID获取记录
func (r *BaseRepository[T]) GetByID(ctx context.Context, id uint) (*T, error) {
var entity T
if err := r.db.WithContext(ctx).First(&entity, id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
r.logger.Error("failed to get entity by id", zap.Error(err), zap.Uint("id", id))
return nil, err
}
return &entity, nil
}
// Update 更新记录
func (r *BaseRepository[T]) Update(ctx context.Context, entity *T) (*T, error) {
if err := r.db.WithContext(ctx).Save(entity).Error; err != nil {
r.logger.Error("failed to update entity", zap.Error(err))
return nil, err
}
return entity, nil
}
// Delete 删除记录
func (r *BaseRepository[T]) Delete(ctx context.Context, id uint) error {
var entity T
if err := r.db.WithContext(ctx).Delete(&entity, id).Error; err != nil {
r.logger.Error("failed to delete entity", zap.Error(err), zap.Uint("id", id))
return err
}
return nil
}
// List 列表查询
func (r *BaseRepository[T]) List(ctx context.Context, opts *interfaces.ListOptions) ([]*T, int64, error) {
var entities []*T
var total int64
db := r.db.WithContext(ctx)
// 应用条件
if opts.Conditions != nil {
for key, value := range opts.Conditions {
db = db.Where(key, value)
}
}
// 获取总数
if err := db.Model(new(T)).Count(&total).Error; err != nil {
r.logger.Error("failed to count entities", zap.Error(err))
return nil, 0, err
}
// 应用排序
if opts.Sort != "" {
order := "ASC"
if opts.Order == "desc" {
order = "DESC"
}
db = db.Order(fmt.Sprintf("%s %s", opts.Sort, order))
}
// 应用分页
if opts.Page > 0 && opts.Limit > 0 {
offset := (opts.Page - 1) * opts.Limit
db = db.Offset(offset).Limit(opts.Limit)
}
// 应用预加载
if opts.Preloads != nil {
for _, preload := range opts.Preloads {
db = db.Preload(preload)
}
}
// 查询数据
if err := db.Find(&entities).Error; err != nil {
r.logger.Error("failed to list entities", zap.Error(err))
return nil, 0, err
}
return entities, total, nil
}
// Transaction 事务执行
func (r *BaseRepository[T]) Transaction(ctx context.Context, fn func(repo interfaces.BaseRepositoryr[T]) error) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
repo := &BaseRepository[T]{
db: tx,
logger: r.logger,
}
return fn(repo)
})
}
// WithTx 使用事务
func (r *BaseRepository[T]) WithTx(tx *gorm.DB) interfaces.BaseRepositoryr[T] {
return &BaseRepository[T]{
db: tx,
logger: r.logger,
}
}
```
### 用户仓储实现
```go
// postgres/user_repository.go - 用户仓储实现
package postgres
import (
"context"
"errors"
"strings"
"time"
"gorm.io/gorm"
"go.uber.org/zap"
"photography-backend/internal/model/entity"
"photography-backend/internal/repository/interfaces"
"photography-backend/pkg/logger"
)
// UserRepository 用户仓储实现
type UserRepository struct {
*BaseRepository[entity.User]
db *gorm.DB
logger logger.Logger
}
// NewUserRepository 创建用户仓储
func NewUserRepository(db *gorm.DB, logger logger.Logger) interfaces.UserRepositoryr {
return &UserRepository{
BaseRepository: NewBaseRepository[entity.User](db, logger),
db: db,
logger: logger,
}
}
// GetByEmail 根据邮箱获取用户
func (r *UserRepository) GetByEmail(ctx context.Context, email string) (*entity.User, error) {
var user entity.User
err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
r.logger.Error("failed to get user by email", zap.Error(err), zap.String("email", email))
return nil, err
}
return &user, nil
}
// GetByUsername 根据用户名获取用户
func (r *UserRepository) GetByUsername(ctx context.Context, username string) (*entity.User, error) {
var user entity.User
err := r.db.WithContext(ctx).Where("username = ?", username).First(&user).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
r.logger.Error("failed to get user by username", zap.Error(err), zap.String("username", username))
return nil, err
}
return &user, nil
}
// GetByEmailOrUsername 根据邮箱或用户名获取用户
func (r *UserRepository) GetByEmailOrUsername(ctx context.Context, emailOrUsername string) (*entity.User, error) {
var user entity.User
err := r.db.WithContext(ctx).Where("email = ? OR username = ?", emailOrUsername, emailOrUsername).First(&user).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
r.logger.Error("failed to get user by email or username", zap.Error(err), zap.String("emailOrUsername", emailOrUsername))
return nil, err
}
return &user, nil
}
// SearchUsers 搜索用户
func (r *UserRepository) SearchUsers(ctx context.Context, keyword string, opts *interfaces.ListOptions) ([]*entity.User, int64, error) {
var users []*entity.User
var total int64
db := r.db.WithContext(ctx)
// 搜索条件
searchCondition := fmt.Sprintf("username ILIKE %s OR email ILIKE %s OR first_name ILIKE %s OR last_name ILIKE %s",
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
db = db.Where(searchCondition)
// 应用其他条件
if opts.Conditions != nil {
for key, value := range opts.Conditions {
db = db.Where(key, value)
}
}
// 获取总数
if err := db.Model(&entity.User{}).Count(&total).Error; err != nil {
r.logger.Error("failed to count users", zap.Error(err))
return nil, 0, err
}
// 应用排序和分页
if opts.Sort != "" {
order := "ASC"
if opts.Order == "desc" {
order = "DESC"
}
db = db.Order(fmt.Sprintf("%s %s", opts.Sort, order))
}
if opts.Page > 0 && opts.Limit > 0 {
offset := (opts.Page - 1) * opts.Limit
db = db.Offset(offset).Limit(opts.Limit)
}
// 查询数据
if err := db.Find(&users).Error; err != nil {
r.logger.Error("failed to search users", zap.Error(err))
return nil, 0, err
}
return users, total, nil
}
// UpdateStatus 更新用户状态
func (r *UserRepository) UpdateStatus(ctx context.Context, id uint, status entity.UserStatus) error {
err := r.db.WithContext(ctx).Model(&entity.User{}).Where("id = ?", id).Update("status", status).Error
if err != nil {
r.logger.Error("failed to update user status", zap.Error(err), zap.Uint("id", id), zap.String("status", string(status)))
return err
}
return nil
}
// UpdateLastLogin 更新最后登录时间
func (r *UserRepository) UpdateLastLogin(ctx context.Context, id uint) error {
now := time.Now()
err := r.db.WithContext(ctx).Model(&entity.User{}).Where("id = ?", id).Update("last_login_at", now).Error
if err != nil {
r.logger.Error("failed to update last login time", zap.Error(err), zap.Uint("id", id))
return err
}
return nil
}
// UpdatePassword 更新密码
func (r *UserRepository) UpdatePassword(ctx context.Context, id uint, hashedPassword string) error {
err := r.db.WithContext(ctx).Model(&entity.User{}).Where("id = ?", id).Update("password", hashedPassword).Error
if err != nil {
r.logger.Error("failed to update password", zap.Error(err), zap.Uint("id", id))
return err
}
return nil
}
// CountByRole 按角色统计用户数
func (r *UserRepository) CountByRole(ctx context.Context, role entity.UserRole) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entity.User{}).Where("role = ?", role).Count(&count).Error
if err != nil {
r.logger.Error("failed to count users by role", zap.Error(err), zap.String("role", string(role)))
return 0, err
}
return count, nil
}
// CountActiveUsers 统计活跃用户数
func (r *UserRepository) CountActiveUsers(ctx context.Context) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entity.User{}).Where("status = ?", entity.UserStatusActive).Count(&count).Error
if err != nil {
r.logger.Error("failed to count active users", zap.Error(err))
return 0, err
}
return count, nil
}
```
## 💾 Redis 缓存实现
### 缓存管理器
```go
// redis/cache_manager.go - 缓存管理器
package redis
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/go-redis/redis/v8"
"go.uber.org/zap"
"photography-backend/pkg/logger"
)
// CacheManager 缓存管理器
type CacheManager struct {
client *redis.Client
logger logger.Logger
}
// NewCacheManager 创建缓存管理器
func NewCacheManager(client *redis.Client, logger logger.Logger) *CacheManager {
return &CacheManager{
client: client,
logger: logger,
}
}
// Set 设置缓存
func (cm *CacheManager) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
data, err := json.Marshal(value)
if err != nil {
cm.logger.Error("failed to marshal cache value", zap.Error(err), zap.String("key", key))
return err
}
err = cm.client.Set(ctx, key, data, ttl).Err()
if err != nil {
cm.logger.Error("failed to set cache", zap.Error(err), zap.String("key", key))
return err
}
return nil
}
// Get 获取缓存
func (cm *CacheManager) Get(ctx context.Context, key string, dest interface{}) error {
data, err := cm.client.Get(ctx, key).Result()
if err != nil {
if err == redis.Nil {
return ErrCacheNotFound
}
cm.logger.Error("failed to get cache", zap.Error(err), zap.String("key", key))
return err
}
err = json.Unmarshal([]byte(data), dest)
if err != nil {
cm.logger.Error("failed to unmarshal cache value", zap.Error(err), zap.String("key", key))
return err
}
return nil
}
// Delete 删除缓存
func (cm *CacheManager) Delete(ctx context.Context, key string) error {
err := cm.client.Del(ctx, key).Err()
if err != nil {
cm.logger.Error("failed to delete cache", zap.Error(err), zap.String("key", key))
return err
}
return nil
}
// DeletePattern 批量删除缓存
func (cm *CacheManager) DeletePattern(ctx context.Context, pattern string) error {
keys, err := cm.client.Keys(ctx, pattern).Result()
if err != nil {
cm.logger.Error("failed to get keys by pattern", zap.Error(err), zap.String("pattern", pattern))
return err
}
if len(keys) > 0 {
err = cm.client.Del(ctx, keys...).Err()
if err != nil {
cm.logger.Error("failed to delete keys by pattern", zap.Error(err), zap.String("pattern", pattern))
return err
}
}
return nil
}
// Exists 检查缓存是否存在
func (cm *CacheManager) Exists(ctx context.Context, key string) (bool, error) {
count, err := cm.client.Exists(ctx, key).Result()
if err != nil {
cm.logger.Error("failed to check cache existence", zap.Error(err), zap.String("key", key))
return false, err
}
return count > 0, nil
}
// SetTTL 设置过期时间
func (cm *CacheManager) SetTTL(ctx context.Context, key string, ttl time.Duration) error {
err := cm.client.Expire(ctx, key, ttl).Err()
if err != nil {
cm.logger.Error("failed to set cache ttl", zap.Error(err), zap.String("key", key))
return err
}
return nil
}
```
### 用户缓存
```go
// redis/user_cache.go - 用户缓存
package redis
import (
"context"
"fmt"
"time"
"photography-backend/internal/model/entity"
"photography-backend/pkg/logger"
)
// UserCache 用户缓存
type UserCache struct {
*CacheManager
ttl time.Duration
}
// NewUserCache 创建用户缓存
func NewUserCache(cm *CacheManager, ttl time.Duration) *UserCache {
return &UserCache{
CacheManager: cm,
ttl: ttl,
}
}
// SetUser 缓存用户
func (uc *UserCache) SetUser(ctx context.Context, user *entity.User) error {
key := fmt.Sprintf("user:id:%d", user.ID)
return uc.Set(ctx, key, user, uc.ttl)
}
// GetUser 获取用户缓存
func (uc *UserCache) GetUser(ctx context.Context, id uint) (*entity.User, error) {
key := fmt.Sprintf("user:id:%d", id)
var user entity.User
err := uc.Get(ctx, key, &user)
if err != nil {
return nil, err
}
return &user, nil
}
// SetUserByEmail 按邮箱缓存用户
func (uc *UserCache) SetUserByEmail(ctx context.Context, user *entity.User) error {
key := fmt.Sprintf("user:email:%s", user.Email)
return uc.Set(ctx, key, user, uc.ttl)
}
// GetUserByEmail 按邮箱获取用户缓存
func (uc *UserCache) GetUserByEmail(ctx context.Context, email string) (*entity.User, error) {
key := fmt.Sprintf("user:email:%s", email)
var user entity.User
err := uc.Get(ctx, key, &user)
if err != nil {
return nil, err
}
return &user, nil
}
// DeleteUser 删除用户缓存
func (uc *UserCache) DeleteUser(ctx context.Context, id uint) error {
key := fmt.Sprintf("user:id:%d", id)
return uc.Delete(ctx, key)
}
// DeleteUserByEmail 按邮箱删除用户缓存
func (uc *UserCache) DeleteUserByEmail(ctx context.Context, email string) error {
key := fmt.Sprintf("user:email:%s", email)
return uc.Delete(ctx, key)
}
// InvalidateUserCache 失效用户相关缓存
func (uc *UserCache) InvalidateUserCache(ctx context.Context, userID uint) error {
pattern := fmt.Sprintf("user:*:%d", userID)
return uc.DeletePattern(ctx, pattern)
}
```
## 🔍 错误处理
### 仓储层错误定义
```go
// errors.go - 仓储层错误定义
package repository
import "errors"
var (
// 通用错误
ErrNotFound = errors.New("record not found")
ErrDuplicateKey = errors.New("duplicate key")
ErrInvalidParameter = errors.New("invalid parameter")
ErrDatabaseError = errors.New("database error")
ErrTransactionError = errors.New("transaction error")
// 缓存错误
ErrCacheNotFound = errors.New("cache not found")
ErrCacheError = errors.New("cache error")
ErrCacheExpired = errors.New("cache expired")
// 连接错误
ErrConnectionFailed = errors.New("connection failed")
ErrConnectionTimeout = errors.New("connection timeout")
)
```
## 🧪 测试
### 仓储测试
```go
// postgres/user_repository_test.go - 用户仓储测试
package postgres
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"photography-backend/internal/model/entity"
"photography-backend/pkg/logger"
)
type UserRepositoryTestSuite struct {
suite.Suite
db *gorm.DB
repo *UserRepository
}
func (suite *UserRepositoryTestSuite) SetupTest() {
// 使用内存数据库进行测试
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
suite.Require().NoError(err)
// 自动迁移
err = db.AutoMigrate(&entity.User{})
suite.Require().NoError(err)
suite.db = db
suite.repo = NewUserRepository(db, logger.NewNoop()).(*UserRepository)
}
func (suite *UserRepositoryTestSuite) TearDownTest() {
sqlDB, _ := suite.db.DB()
sqlDB.Close()
}
func (suite *UserRepositoryTestSuite) TestCreateUser() {
ctx := context.Background()
user := &entity.User{
Username: "testuser",
Email: "test@example.com",
Password: "hashedpassword",
Role: entity.UserRoleUser,
Status: entity.UserStatusActive,
}
createdUser, err := suite.repo.Create(ctx, user)
assert.NoError(suite.T(), err)
assert.NotZero(suite.T(), createdUser.ID)
assert.Equal(suite.T(), user.Username, createdUser.Username)
assert.Equal(suite.T(), user.Email, createdUser.Email)
}
func (suite *UserRepositoryTestSuite) TestGetUserByEmail() {
ctx := context.Background()
// 创建测试用户
user := &entity.User{
Username: "testuser",
Email: "test@example.com",
Password: "hashedpassword",
Role: entity.UserRoleUser,
Status: entity.UserStatusActive,
}
createdUser, err := suite.repo.Create(ctx, user)
suite.Require().NoError(err)
// 根据邮箱获取用户
foundUser, err := suite.repo.GetByEmail(ctx, user.Email)
assert.NoError(suite.T(), err)
assert.Equal(suite.T(), createdUser.ID, foundUser.ID)
assert.Equal(suite.T(), createdUser.Email, foundUser.Email)
}
func TestUserRepositoryTestSuite(t *testing.T) {
suite.Run(t, new(UserRepositoryTestSuite))
}
```
## 💡 最佳实践
### 设计原则
1. **接口隔离**: 定义清晰的仓储接口
2. **依赖倒置**: 依赖接口而非具体实现
3. **单一职责**: 每个仓储只负责一个实体
4. **错误处理**: 统一错误处理和日志记录
5. **事务支持**: 提供事务操作支持
### 性能优化
1. **查询优化**: 使用适当的索引和查询条件
2. **批量操作**: 支持批量插入和更新
3. **缓存策略**: 合理使用缓存减少数据库访问
4. **连接池**: 使用连接池管理数据库连接
5. **预加载**: 避免 N+1 查询问题
### 测试策略
1. **单元测试**: 为每个仓储编写单元测试
2. **集成测试**: 测试数据库交互
3. **模拟对象**: 使用 Mock 对象进行测试
4. **测试数据**: 准备充分的测试数据
5. **性能测试**: 测试查询性能和并发性能
本模块是数据访问的核心,确保数据操作的正确性和性能是关键。

View File

@ -0,0 +1,39 @@
package interfaces
import (
"context"
"photography-backend/internal/model/entity"
)
// CategoryRepository 分类仓储接口
type CategoryRepository interface {
// 基本CRUD操作
Create(ctx context.Context, category *entity.Category) error
GetByID(ctx context.Context, id uint) (*entity.Category, error)
GetBySlug(ctx context.Context, slug string) (*entity.Category, error)
Update(ctx context.Context, category *entity.Category) error
Delete(ctx context.Context, id uint) error
// 查询操作
List(ctx context.Context, parentID *uint) ([]*entity.Category, error)
GetTree(ctx context.Context) ([]*entity.CategoryTree, error)
GetChildren(ctx context.Context, parentID uint) ([]*entity.Category, error)
GetParent(ctx context.Context, categoryID uint) (*entity.Category, error)
// 排序操作
Reorder(ctx context.Context, parentID *uint, categoryIDs []uint) error
GetNextSortOrder(ctx context.Context, parentID *uint) (int, error)
// 验证操作
ValidateSlugUnique(ctx context.Context, slug string, excludeID uint) error
ValidateParentCategory(ctx context.Context, categoryID, parentID uint) error
// 统计操作
Count(ctx context.Context) (int64, error)
CountActive(ctx context.Context) (int64, error)
CountTopLevel(ctx context.Context) (int64, error)
GetStats(ctx context.Context) (*entity.CategoryStats, error)
// 工具方法
GenerateUniqueSlug(ctx context.Context, baseName string) (string, error)
}

View File

@ -0,0 +1,33 @@
package interfaces
import (
"context"
"photography-backend/internal/model/entity"
)
// PhotoRepository 照片仓储接口
type PhotoRepository interface {
// 基本CRUD操作
Create(ctx context.Context, photo *entity.Photo) error
GetByID(ctx context.Context, id uint) (*entity.Photo, error)
Update(ctx context.Context, photo *entity.Photo) error
Delete(ctx context.Context, id uint) error
// 查询操作
List(ctx context.Context, params *entity.PhotoListParams) ([]*entity.Photo, int64, error)
ListByUserID(ctx context.Context, userID uint, params *entity.PhotoListParams) ([]*entity.Photo, int64, error)
ListByCategory(ctx context.Context, categoryID uint, params *entity.PhotoListParams) ([]*entity.Photo, int64, error)
Search(ctx context.Context, query string, params *entity.PhotoListParams) ([]*entity.Photo, int64, error)
// 批量操作
BatchUpdate(ctx context.Context, ids []uint, updates map[string]interface{}) error
BatchDelete(ctx context.Context, ids []uint) error
BatchUpdateCategories(ctx context.Context, photoIDs []uint, categoryIDs []uint) error
BatchUpdateTags(ctx context.Context, photoIDs []uint, tagIDs []uint) error
// 统计操作
Count(ctx context.Context) (int64, error)
CountByStatus(ctx context.Context, status string) (int64, error)
CountByUser(ctx context.Context, userID uint) (int64, error)
GetStats(ctx context.Context) (*entity.PhotoStats, error)
}

View File

@ -0,0 +1,42 @@
package interfaces
import (
"context"
"photography-backend/internal/model/entity"
)
// TagRepository 标签仓储接口
type TagRepository interface {
// 基本CRUD操作
Create(ctx context.Context, tag *entity.Tag) error
GetByID(ctx context.Context, id uint) (*entity.Tag, error)
GetBySlug(ctx context.Context, slug string) (*entity.Tag, error)
GetByName(ctx context.Context, name string) (*entity.Tag, error)
Update(ctx context.Context, tag *entity.Tag) error
Delete(ctx context.Context, id uint) error
// 查询操作
List(ctx context.Context, params *entity.TagListParams) ([]*entity.Tag, int64, error)
Search(ctx context.Context, query string) ([]*entity.Tag, error)
GetPopular(ctx context.Context, limit int) ([]*entity.Tag, error)
GetByPhotos(ctx context.Context, photoIDs []uint) ([]*entity.Tag, error)
// 批量操作
CreateMultiple(ctx context.Context, tags []*entity.Tag) error
GetOrCreateByNames(ctx context.Context, names []string) ([]*entity.Tag, error)
BatchDelete(ctx context.Context, ids []uint) error
// 关联操作
AttachToPhoto(ctx context.Context, tagID, photoID uint) error
DetachFromPhoto(ctx context.Context, tagID, photoID uint) error
GetPhotoTags(ctx context.Context, photoID uint) ([]*entity.Tag, error)
// 统计操作
Count(ctx context.Context) (int64, error)
CountByPhotos(ctx context.Context) (map[uint]int64, error)
GetStats(ctx context.Context) (*entity.TagStats, error)
// 工具方法
GenerateUniqueSlug(ctx context.Context, baseName string) (string, error)
ValidateSlugUnique(ctx context.Context, slug string, excludeID uint) error
}

View File

@ -0,0 +1,40 @@
package interfaces
import (
"context"
"photography-backend/internal/model/entity"
)
// UserRepository 用户仓储接口
type UserRepository interface {
// 基本CRUD操作
Create(ctx context.Context, user *entity.User) error
GetByID(ctx context.Context, id uint) (*entity.User, error)
GetByEmail(ctx context.Context, email string) (*entity.User, error)
GetByUsername(ctx context.Context, username string) (*entity.User, error)
Update(ctx context.Context, user *entity.User) error
Delete(ctx context.Context, id uint) error
// 查询操作
List(ctx context.Context, params *entity.UserListParams) ([]*entity.User, int64, error)
Search(ctx context.Context, query string, params *entity.UserListParams) ([]*entity.User, int64, error)
// 认证相关
UpdatePassword(ctx context.Context, userID uint, hashedPassword string) error
UpdateLastLogin(ctx context.Context, userID uint) error
IncrementLoginCount(ctx context.Context, userID uint) error
// 状态管理
SetActive(ctx context.Context, userID uint, isActive bool) error
VerifyEmail(ctx context.Context, userID uint) error
// 统计操作
Count(ctx context.Context) (int64, error)
CountByRole(ctx context.Context, role entity.UserRole) (int64, error)
CountActive(ctx context.Context) (int64, error)
GetStats(ctx context.Context) (*entity.UserStats, error)
// 验证操作
ExistsByEmail(ctx context.Context, email string) (bool, error)
ExistsByUsername(ctx context.Context, username string) (bool, error)
}

View File

@ -0,0 +1,345 @@
package postgres
import (
"context"
"errors"
"fmt"
"photography-backend/internal/model/entity"
"photography-backend/internal/repository/interfaces"
"photography-backend/internal/utils"
"go.uber.org/zap"
"gorm.io/gorm"
)
// categoryRepositoryImpl 分类仓储实现
type categoryRepositoryImpl struct {
db *gorm.DB
logger *zap.Logger
}
// NewCategoryRepository 创建分类仓储实现
func NewCategoryRepository(db *gorm.DB, logger *zap.Logger) interfaces.CategoryRepository {
return &categoryRepositoryImpl{
db: db,
logger: logger,
}
}
// Create 创建分类
func (r *categoryRepositoryImpl) Create(ctx context.Context, category *entity.Category) error {
return r.db.WithContext(ctx).Create(category).Error
}
// GetByID 根据ID获取分类
func (r *categoryRepositoryImpl) GetByID(ctx context.Context, id uint) (*entity.Category, error) {
var category entity.Category
err := r.db.WithContext(ctx).First(&category, id).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("category not found")
}
return nil, err
}
return &category, nil
}
// GetBySlug 根据slug获取分类
func (r *categoryRepositoryImpl) GetBySlug(ctx context.Context, slug string) (*entity.Category, error) {
var category entity.Category
err := r.db.WithContext(ctx).Where("slug = ?", slug).First(&category).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("category not found")
}
return nil, err
}
return &category, nil
}
// Update 更新分类
func (r *categoryRepositoryImpl) Update(ctx context.Context, category *entity.Category) error {
return r.db.WithContext(ctx).Save(category).Error
}
// Delete 删除分类
func (r *categoryRepositoryImpl) Delete(ctx context.Context, id uint) error {
return r.db.WithContext(ctx).Delete(&entity.Category{}, id).Error
}
// List 获取分类列表
func (r *categoryRepositoryImpl) List(ctx context.Context, parentID *uint) ([]*entity.Category, error) {
var categories []*entity.Category
query := r.db.WithContext(ctx).Order("sort_order ASC, created_at ASC")
if parentID != nil {
query = query.Where("parent_id = ?", *parentID)
} else {
query = query.Where("parent_id IS NULL")
}
err := query.Find(&categories).Error
return categories, err
}
// GetTree 获取分类树
func (r *categoryRepositoryImpl) GetTree(ctx context.Context) ([]*entity.CategoryTree, error) {
var categories []*entity.Category
if err := r.db.WithContext(ctx).
Order("sort_order ASC, created_at ASC").
Find(&categories).Error; err != nil {
return nil, err
}
// 构建树形结构
tree := r.buildCategoryTree(categories, nil)
return tree, nil
}
// GetChildren 获取子分类
func (r *categoryRepositoryImpl) GetChildren(ctx context.Context, parentID uint) ([]*entity.Category, error) {
var children []*entity.Category
err := r.db.WithContext(ctx).
Where("parent_id = ?", parentID).
Order("sort_order ASC").
Find(&children).Error
return children, err
}
// GetParent 获取父分类
func (r *categoryRepositoryImpl) GetParent(ctx context.Context, categoryID uint) (*entity.Category, error) {
var category entity.Category
err := r.db.WithContext(ctx).
Preload("Parent").
First(&category, categoryID).Error
if err != nil {
return nil, err
}
return category.Parent, nil
}
// Reorder 重新排序分类
func (r *categoryRepositoryImpl) Reorder(ctx context.Context, parentID *uint, categoryIDs []uint) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for i, categoryID := range categoryIDs {
if err := tx.Model(&entity.Category{}).
Where("id = ?", categoryID).
Update("sort_order", i+1).Error; err != nil {
return err
}
}
return nil
})
}
// GetNextSortOrder 获取下一个排序顺序
func (r *categoryRepositoryImpl) GetNextSortOrder(ctx context.Context, parentID *uint) (int, error) {
var maxOrder int
query := r.db.WithContext(ctx).Model(&entity.Category{}).Select("COALESCE(MAX(sort_order), 0)")
if parentID != nil {
query = query.Where("parent_id = ?", *parentID)
} else {
query = query.Where("parent_id IS NULL")
}
err := query.Row().Scan(&maxOrder)
return maxOrder + 1, err
}
// ValidateSlugUnique 验证slug唯一性
func (r *categoryRepositoryImpl) ValidateSlugUnique(ctx context.Context, slug string, excludeID uint) error {
var count int64
query := r.db.WithContext(ctx).Model(&entity.Category{}).Where("slug = ?", slug)
if excludeID > 0 {
query = query.Where("id != ?", excludeID)
}
if err := query.Count(&count).Error; err != nil {
return err
}
if count > 0 {
return errors.New("slug already exists")
}
return nil
}
// ValidateParentCategory 验证父分类(防止循环引用)
func (r *categoryRepositoryImpl) ValidateParentCategory(ctx context.Context, categoryID, parentID uint) error {
if categoryID == parentID {
return errors.New("category cannot be its own parent")
}
// 检查父分类是否存在
var parent entity.Category
if err := r.db.WithContext(ctx).First(&parent, parentID).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("parent category not found")
}
return err
}
// 检查是否会形成循环引用
current := parentID
for current != 0 {
var category entity.Category
if err := r.db.WithContext(ctx).First(&category, current).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("parent category not found")
}
return err
}
if category.ParentID == nil {
break
}
if *category.ParentID == categoryID {
return errors.New("circular reference detected")
}
current = *category.ParentID
}
return nil
}
// Count 统计分类总数
func (r *categoryRepositoryImpl) Count(ctx context.Context) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entity.Category{}).Count(&count).Error
return count, err
}
// CountActive 统计活跃分类数
func (r *categoryRepositoryImpl) CountActive(ctx context.Context) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entity.Category{}).
Where("is_active = ?", true).Count(&count).Error
return count, err
}
// CountTopLevel 统计顶级分类数
func (r *categoryRepositoryImpl) CountTopLevel(ctx context.Context) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entity.Category{}).
Where("parent_id IS NULL").Count(&count).Error
return count, err
}
// GetStats 获取分类统计信息
func (r *categoryRepositoryImpl) GetStats(ctx context.Context) (*entity.CategoryStats, error) {
var stats entity.CategoryStats
// 总分类数
if total, err := r.Count(ctx); err != nil {
return nil, err
} else {
stats.Total = total
}
// 活跃分类数
if active, err := r.CountActive(ctx); err != nil {
return nil, err
} else {
stats.Active = active
}
// 顶级分类数
if topLevel, err := r.CountTopLevel(ctx); err != nil {
return nil, err
} else {
stats.TopLevel = topLevel
}
// 各分类照片数量
var categoryPhotoStats []struct {
CategoryID uint `json:"category_id"`
Name string `json:"name"`
PhotoCount int64 `json:"photo_count"`
}
if err := r.db.WithContext(ctx).
Table("categories").
Select("categories.id as category_id, categories.name, COUNT(photo_categories.photo_id) as photo_count").
Joins("LEFT JOIN photo_categories ON categories.id = photo_categories.category_id").
Group("categories.id, categories.name").
Order("photo_count DESC").
Limit(10).
Find(&categoryPhotoStats).Error; err != nil {
return nil, err
}
stats.PhotoCounts = make(map[string]int64)
for _, stat := range categoryPhotoStats {
stats.PhotoCounts[stat.Name] = stat.PhotoCount
}
return &stats, nil
}
// GenerateUniqueSlug 生成唯一slug
func (r *categoryRepositoryImpl) GenerateUniqueSlug(ctx context.Context, baseName string) (string, error) {
baseSlug := utils.GenerateSlug(baseName)
slug := baseSlug
counter := 1
for {
var count int64
if err := r.db.WithContext(ctx).Model(&entity.Category{}).
Where("slug = ?", slug).Count(&count).Error; err != nil {
return "", err
}
if count == 0 {
break
}
slug = fmt.Sprintf("%s-%d", baseSlug, counter)
counter++
}
return slug, nil
}
// buildCategoryTree 构建分类树
func (r *categoryRepositoryImpl) buildCategoryTree(categories []*entity.Category, parentID *uint) []*entity.CategoryTree {
var tree []*entity.CategoryTree
for _, category := range categories {
// 检查是否匹配父分类
if (parentID == nil && category.ParentID == nil) ||
(parentID != nil && category.ParentID != nil && *category.ParentID == *parentID) {
node := &entity.CategoryTree{
ID: category.ID,
Name: category.Name,
Slug: category.Slug,
Description: category.Description,
ParentID: category.ParentID,
SortOrder: category.SortOrder,
IsActive: category.IsActive,
PhotoCount: category.PhotoCount,
CreatedAt: category.CreatedAt,
UpdatedAt: category.UpdatedAt,
}
// 递归构建子分类
children := r.buildCategoryTree(categories, &category.ID)
node.Children = make([]entity.CategoryTree, len(children))
for i, child := range children {
node.Children[i] = *child
}
tree = append(tree, node)
}
}
return tree
}

View File

@ -0,0 +1,78 @@
package postgres
import (
"fmt"
"time"
"gorm.io/gorm"
"gorm.io/driver/postgres"
"photography-backend/internal/config"
"photography-backend/internal/model/entity"
)
// Database 数据库连接
type Database struct {
DB *gorm.DB
}
// NewDatabase 创建数据库连接
func NewDatabase(cfg *config.DatabaseConfig) (*Database, error) {
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
cfg.Host,
cfg.Port,
cfg.Username,
cfg.Password,
cfg.Database,
cfg.SSLMode,
)
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
// 获取底层sql.DB实例配置连接池
sqlDB, err := db.DB()
if err != nil {
return nil, fmt.Errorf("failed to get sql.DB instance: %w", err)
}
// 设置连接池参数
sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
sqlDB.SetConnMaxLifetime(time.Duration(cfg.ConnMaxLifetime) * time.Second)
// 测试连接
if err := sqlDB.Ping(); err != nil {
return nil, fmt.Errorf("failed to ping database: %w", err)
}
return &Database{DB: db}, nil
}
// AutoMigrate 自动迁移数据库表结构
func (d *Database) AutoMigrate() error {
return d.DB.AutoMigrate(
&entity.User{},
&entity.Category{},
&entity.Tag{},
&entity.Photo{},
)
}
// Close 关闭数据库连接
func (d *Database) Close() error {
sqlDB, err := d.DB.DB()
if err != nil {
return err
}
return sqlDB.Close()
}
// Health 检查数据库健康状态
func (d *Database) Health() error {
sqlDB, err := d.DB.DB()
if err != nil {
return err
}
return sqlDB.Ping()
}

View File

@ -0,0 +1,375 @@
package postgres
import (
"context"
"errors"
"fmt"
"time"
"photography-backend/internal/model/entity"
"photography-backend/internal/repository/interfaces"
"go.uber.org/zap"
"gorm.io/gorm"
)
// photoRepositoryImpl 照片仓储实现
type photoRepositoryImpl struct {
db *gorm.DB
logger *zap.Logger
}
// NewPhotoRepository 创建照片仓储实现
func NewPhotoRepository(db *gorm.DB, logger *zap.Logger) interfaces.PhotoRepository {
return &photoRepositoryImpl{
db: db,
logger: logger,
}
}
// Create 创建照片
func (r *photoRepositoryImpl) Create(ctx context.Context, photo *entity.Photo) error {
return r.db.WithContext(ctx).Create(photo).Error
}
// GetByID 根据ID获取照片
func (r *photoRepositoryImpl) GetByID(ctx context.Context, id uint) (*entity.Photo, error) {
var photo entity.Photo
err := r.db.WithContext(ctx).
Preload("User").
Preload("Categories").
Preload("Tags").
First(&photo, id).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("photo not found")
}
return nil, err
}
return &photo, nil
}
// GetByFilename 根据文件名获取照片
func (r *photoRepositoryImpl) GetByFilename(ctx context.Context, filename string) (*entity.Photo, error) {
var photo entity.Photo
err := r.db.WithContext(ctx).Where("filename = ?", filename).First(&photo).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("photo not found")
}
return nil, err
}
return &photo, nil
}
// Update 更新照片
func (r *photoRepositoryImpl) Update(ctx context.Context, photo *entity.Photo) error {
return r.db.WithContext(ctx).Save(photo).Error
}
// Delete 删除照片
func (r *photoRepositoryImpl) Delete(ctx context.Context, id uint) error {
return r.db.WithContext(ctx).Delete(&entity.Photo{}, id).Error
}
// List 获取照片列表
func (r *photoRepositoryImpl) List(ctx context.Context, params *entity.PhotoListParams) ([]*entity.Photo, int64, error) {
var photos []*entity.Photo
var total int64
query := r.db.WithContext(ctx).Model(&entity.Photo{})
// 应用过滤条件
if params.UserID != nil {
query = query.Where("user_id = ?", *params.UserID)
}
if params.Status != nil {
query = query.Where("status = ?", *params.Status)
}
if params.CategoryID != nil {
query = query.Joins("JOIN photo_categories ON photos.id = photo_categories.photo_id").
Where("photo_categories.category_id = ?", *params.CategoryID)
}
if params.TagID != nil {
query = query.Joins("JOIN photo_tags ON photos.id = photo_tags.photo_id").
Where("photo_tags.tag_id = ?", *params.TagID)
}
if params.DateFrom != nil {
query = query.Where("taken_at >= ?", *params.DateFrom)
}
if params.DateTo != nil {
query = query.Where("taken_at <= ?", *params.DateTo)
}
if params.Search != "" {
query = query.Where("title ILIKE ? OR description ILIKE ?",
"%"+params.Search+"%", "%"+params.Search+"%")
}
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 应用排序
orderBy := "created_at DESC"
if params.Sort != "" {
order := "ASC"
if params.Order == "desc" {
order = "DESC"
}
orderBy = fmt.Sprintf("%s %s", params.Sort, order)
}
query = query.Order(orderBy)
// 应用分页
if params.Page > 0 && params.Limit > 0 {
offset := (params.Page - 1) * params.Limit
query = query.Offset(offset).Limit(params.Limit)
}
// 预加载关联数据
query = query.Preload("User").Preload("Categories").Preload("Tags")
// 查询数据
if err := query.Find(&photos).Error; err != nil {
return nil, 0, err
}
return photos, total, nil
}
// ListByUserID 根据用户ID获取照片列表
func (r *photoRepositoryImpl) ListByUserID(ctx context.Context, userID uint, params *entity.PhotoListParams) ([]*entity.Photo, int64, error) {
if params == nil {
params = &entity.PhotoListParams{}
}
params.UserID = &userID
return r.List(ctx, params)
}
// ListByStatus 根据状态获取照片列表
func (r *photoRepositoryImpl) ListByStatus(ctx context.Context, status entity.PhotoStatus, params *entity.PhotoListParams) ([]*entity.Photo, int64, error) {
if params == nil {
params = &entity.PhotoListParams{}
}
params.Status = &status
return r.List(ctx, params)
}
// ListByCategory 根据分类获取照片列表
func (r *photoRepositoryImpl) ListByCategory(ctx context.Context, categoryID uint, params *entity.PhotoListParams) ([]*entity.Photo, int64, error) {
if params == nil {
params = &entity.PhotoListParams{}
}
params.CategoryID = &categoryID
return r.List(ctx, params)
}
// Search 搜索照片
func (r *photoRepositoryImpl) Search(ctx context.Context, query string, params *entity.PhotoListParams) ([]*entity.Photo, int64, error) {
if params == nil {
params = &entity.PhotoListParams{}
}
params.Search = query
return r.List(ctx, params)
}
// Count 统计照片总数
func (r *photoRepositoryImpl) Count(ctx context.Context) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entity.Photo{}).Count(&count).Error
return count, err
}
// CountByUser 统计用户照片数
func (r *photoRepositoryImpl) CountByUser(ctx context.Context, userID uint) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entity.Photo{}).
Where("user_id = ?", userID).Count(&count).Error
return count, err
}
// CountByStatus 统计指定状态照片数
func (r *photoRepositoryImpl) CountByStatus(ctx context.Context, status entity.PhotoStatus) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entity.Photo{}).
Where("status = ?", status).Count(&count).Error
return count, err
}
// CountByCategory 统计分类照片数
func (r *photoRepositoryImpl) CountByCategory(ctx context.Context, categoryID uint) (int64, error) {
var count int64
err := r.db.WithContext(ctx).
Table("photo_categories").
Where("category_id = ?", categoryID).
Count(&count).Error
return count, err
}
// CountByStatus 统计指定状态照片数
func (r *photoRepositoryImpl) CountByStatus(ctx context.Context, status string) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entity.Photo{}).
Where("status = ?", status).Count(&count).Error
return count, err
}
// BatchUpdate 批量更新
func (r *photoRepositoryImpl) BatchUpdate(ctx context.Context, ids []uint, updates map[string]interface{}) error {
if len(ids) == 0 || len(updates) == 0 {
return nil
}
return r.db.WithContext(ctx).Model(&entity.Photo{}).
Where("id IN ?", ids).
Updates(updates).Error
}
// BatchUpdateCategories 批量更新分类
func (r *photoRepositoryImpl) BatchUpdateCategories(ctx context.Context, photoIDs []uint, categoryIDs []uint) error {
if len(photoIDs) == 0 {
return nil
}
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 删除现有关联
if err := tx.Exec("DELETE FROM photo_categories WHERE photo_id IN ?", photoIDs).Error; err != nil {
return err
}
// 添加新关联
for _, photoID := range photoIDs {
for _, categoryID := range categoryIDs {
if err := tx.Exec("INSERT INTO photo_categories (photo_id, category_id) VALUES (?, ?)",
photoID, categoryID).Error; err != nil {
return err
}
}
}
return nil
})
}
// BatchUpdateTags 批量更新标签
func (r *photoRepositoryImpl) BatchUpdateTags(ctx context.Context, photoIDs []uint, tagIDs []uint) error {
if len(photoIDs) == 0 {
return nil
}
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 删除现有关联
if err := tx.Exec("DELETE FROM photo_tags WHERE photo_id IN ?", photoIDs).Error; err != nil {
return err
}
// 添加新关联
for _, photoID := range photoIDs {
for _, tagID := range tagIDs {
if err := tx.Exec("INSERT INTO photo_tags (photo_id, tag_id) VALUES (?, ?)",
photoID, tagID).Error; err != nil {
return err
}
}
}
return nil
})
}
// BatchDelete 批量删除
func (r *photoRepositoryImpl) BatchDelete(ctx context.Context, ids []uint) error {
if len(ids) == 0 {
return nil
}
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 删除关联关系
if err := tx.Exec("DELETE FROM photo_categories WHERE photo_id IN ?", ids).Error; err != nil {
return err
}
if err := tx.Exec("DELETE FROM photo_tags WHERE photo_id IN ?", ids).Error; err != nil {
return err
}
// 删除照片记录
return tx.Delete(&entity.Photo{}, ids).Error
})
}
// GetStats 获取照片统计信息
func (r *photoRepositoryImpl) GetStats(ctx context.Context) (*entity.PhotoStats, error) {
var stats entity.PhotoStats
// 总照片数
if total, err := r.Count(ctx); err != nil {
return nil, err
} else {
stats.Total = total
}
// 按状态统计
for _, status := range []entity.PhotoStatus{
entity.PhotoStatusActive,
entity.PhotoStatusDraft,
entity.PhotoStatusArchived,
} {
if count, err := r.CountByStatus(ctx, status); err != nil {
return nil, err
} else {
switch status {
case entity.PhotoStatusActive:
stats.Published = count
case entity.PhotoStatusDraft:
stats.Draft = count
case entity.PhotoStatusArchived:
stats.Archived = count
}
}
}
// 本月新增照片数
now := time.Now()
startOfMonth := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
endOfMonth := startOfMonth.AddDate(0, 1, 0).Add(-time.Nanosecond)
var monthlyCount int64
if err := r.db.WithContext(ctx).Model(&entity.Photo{}).
Where("created_at >= ? AND created_at <= ?", startOfMonth, endOfMonth).
Count(&monthlyCount).Error; err != nil {
return nil, err
}
stats.ThisMonth = monthlyCount
// 用户照片分布Top 10
var userPhotoStats []struct {
UserID uint `json:"user_id"`
Username string `json:"username"`
PhotoCount int64 `json:"photo_count"`
}
if err := r.db.WithContext(ctx).
Table("photos").
Select("photos.user_id, users.username, COUNT(photos.id) as photo_count").
Joins("LEFT JOIN users ON photos.user_id = users.id").
Group("photos.user_id, users.username").
Order("photo_count DESC").
Limit(10).
Find(&userPhotoStats).Error; err != nil {
return nil, err
}
stats.UserPhotoCounts = make(map[string]int64)
for _, stat := range userPhotoStats {
stats.UserPhotoCounts[stat.Username] = stat.PhotoCount
}
return &stats, nil
}

View File

@ -0,0 +1,468 @@
package postgres
import (
"context"
"errors"
"fmt"
"strings"
"photography-backend/internal/model/entity"
"photography-backend/internal/repository/interfaces"
"photography-backend/internal/utils"
"go.uber.org/zap"
"gorm.io/gorm"
)
// tagRepositoryImpl 标签仓储实现
type tagRepositoryImpl struct {
db *gorm.DB
logger *zap.Logger
}
// NewTagRepository 创建标签仓储实现
func NewTagRepository(db *gorm.DB, logger *zap.Logger) interfaces.TagRepository {
return &tagRepositoryImpl{
db: db,
logger: logger,
}
}
// Create 创建标签
func (r *tagRepositoryImpl) Create(ctx context.Context, tag *entity.Tag) error {
return r.db.WithContext(ctx).Create(tag).Error
}
// GetByID 根据ID获取标签
func (r *tagRepositoryImpl) GetByID(ctx context.Context, id uint) (*entity.Tag, error) {
var tag entity.Tag
err := r.db.WithContext(ctx).First(&tag, id).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("tag not found")
}
return nil, err
}
return &tag, nil
}
// GetBySlug 根据slug获取标签
func (r *tagRepositoryImpl) GetBySlug(ctx context.Context, slug string) (*entity.Tag, error) {
var tag entity.Tag
err := r.db.WithContext(ctx).Where("slug = ?", slug).First(&tag).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("tag not found")
}
return nil, err
}
return &tag, nil
}
// GetByName 根据名称获取标签
func (r *tagRepositoryImpl) GetByName(ctx context.Context, name string) (*entity.Tag, error) {
var tag entity.Tag
err := r.db.WithContext(ctx).Where("name = ?", name).First(&tag).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("tag not found")
}
return nil, err
}
return &tag, nil
}
// Update 更新标签
func (r *tagRepositoryImpl) Update(ctx context.Context, tag *entity.Tag) error {
return r.db.WithContext(ctx).Save(tag).Error
}
// Delete 删除标签
func (r *tagRepositoryImpl) Delete(ctx context.Context, id uint) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 删除照片标签关联
if err := tx.Exec("DELETE FROM photo_tags WHERE tag_id = ?", id).Error; err != nil {
return err
}
// 删除标签
return tx.Delete(&entity.Tag{}, id).Error
})
}
// List 获取标签列表
func (r *tagRepositoryImpl) List(ctx context.Context, params *entity.TagListParams) ([]*entity.Tag, int64, error) {
var tags []*entity.Tag
var total int64
query := r.db.WithContext(ctx).Model(&entity.Tag{})
// 应用过滤条件
if params.Search != "" {
query = query.Where("name ILIKE ? OR description ILIKE ?",
"%"+params.Search+"%", "%"+params.Search+"%")
}
if params.Color != "" {
query = query.Where("color = ?", params.Color)
}
if params.CreatedFrom != nil {
query = query.Where("created_at >= ?", *params.CreatedFrom)
}
if params.CreatedTo != nil {
query = query.Where("created_at <= ?", *params.CreatedTo)
}
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 应用排序
orderBy := "created_at DESC"
if params.Sort != "" {
order := "ASC"
if params.Order == "desc" {
order = "DESC"
}
orderBy = fmt.Sprintf("%s %s", params.Sort, order)
}
query = query.Order(orderBy)
// 应用分页
if params.Page > 0 && params.Limit > 0 {
offset := (params.Page - 1) * params.Limit
query = query.Offset(offset).Limit(params.Limit)
}
// 如果需要包含照片计数
if params.IncludePhotoCount {
query = query.Select("tags.*, COUNT(photo_tags.photo_id) as photo_count").
Joins("LEFT JOIN photo_tags ON tags.id = photo_tags.tag_id").
Group("tags.id")
}
// 查询数据
if err := query.Find(&tags).Error; err != nil {
return nil, 0, err
}
return tags, total, nil
}
// Search 搜索标签
func (r *tagRepositoryImpl) Search(ctx context.Context, query string) ([]*entity.Tag, error) {
var tags []*entity.Tag
err := r.db.WithContext(ctx).
Where("name ILIKE ? OR description ILIKE ?", "%"+query+"%", "%"+query+"%").
Order("name ASC").
Limit(50).
Find(&tags).Error
return tags, err
}
// GetPopular 获取热门标签
func (r *tagRepositoryImpl) GetPopular(ctx context.Context, limit int) ([]*entity.Tag, error) {
var tags []*entity.Tag
err := r.db.WithContext(ctx).
Select("tags.*, COUNT(photo_tags.photo_id) as photo_count").
Joins("LEFT JOIN photo_tags ON tags.id = photo_tags.tag_id").
Group("tags.id").
Order("photo_count DESC").
Limit(limit).
Find(&tags).Error
return tags, err
}
// GetByPhotos 根据照片IDs获取标签
func (r *tagRepositoryImpl) GetByPhotos(ctx context.Context, photoIDs []uint) ([]*entity.Tag, error) {
if len(photoIDs) == 0 {
return []*entity.Tag{}, nil
}
var tags []*entity.Tag
err := r.db.WithContext(ctx).
Joins("JOIN photo_tags ON tags.id = photo_tags.tag_id").
Where("photo_tags.photo_id IN ?", photoIDs).
Distinct().
Find(&tags).Error
return tags, err
}
// CreateMultiple 批量创建标签
func (r *tagRepositoryImpl) CreateMultiple(ctx context.Context, tags []*entity.Tag) error {
if len(tags) == 0 {
return nil
}
return r.db.WithContext(ctx).CreateInBatches(tags, 100).Error
}
// GetOrCreateByNames 根据名称获取或创建标签
func (r *tagRepositoryImpl) GetOrCreateByNames(ctx context.Context, names []string) ([]*entity.Tag, error) {
if len(names) == 0 {
return []*entity.Tag{}, nil
}
var tags []*entity.Tag
return tags, r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for _, name := range names {
name = strings.TrimSpace(name)
if name == "" {
continue
}
var tag entity.Tag
// 尝试获取现有标签
err := tx.Where("name = ?", name).First(&tag).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
// 创建新标签
slug, err := r.generateUniqueSlug(ctx, name)
if err != nil {
return err
}
tag = entity.Tag{
Name: name,
Slug: slug,
Color: r.generateRandomColor(),
PhotoCount: 0,
}
if err := tx.Create(&tag).Error; err != nil {
return err
}
} else {
return err
}
}
tags = append(tags, &tag)
}
return nil
})
}
// BatchDelete 批量删除标签
func (r *tagRepositoryImpl) BatchDelete(ctx context.Context, ids []uint) error {
if len(ids) == 0 {
return nil
}
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 删除照片标签关联
if err := tx.Exec("DELETE FROM photo_tags WHERE tag_id IN ?", ids).Error; err != nil {
return err
}
// 删除标签
return tx.Delete(&entity.Tag{}, ids).Error
})
}
// AttachToPhoto 为照片添加标签
func (r *tagRepositoryImpl) AttachToPhoto(ctx context.Context, tagID, photoID uint) error {
// 检查关联是否已存在
var count int64
if err := r.db.WithContext(ctx).Table("photo_tags").
Where("tag_id = ? AND photo_id = ?", tagID, photoID).
Count(&count).Error; err != nil {
return err
}
if count > 0 {
return nil // 关联已存在
}
// 创建关联
return r.db.WithContext(ctx).Exec(
"INSERT INTO photo_tags (tag_id, photo_id) VALUES (?, ?)",
tagID, photoID,
).Error
}
// DetachFromPhoto 从照片移除标签
func (r *tagRepositoryImpl) DetachFromPhoto(ctx context.Context, tagID, photoID uint) error {
return r.db.WithContext(ctx).Exec(
"DELETE FROM photo_tags WHERE tag_id = ? AND photo_id = ?",
tagID, photoID,
).Error
}
// GetPhotoTags 获取照片的标签
func (r *tagRepositoryImpl) GetPhotoTags(ctx context.Context, photoID uint) ([]*entity.Tag, error) {
var tags []*entity.Tag
err := r.db.WithContext(ctx).
Joins("JOIN photo_tags ON tags.id = photo_tags.tag_id").
Where("photo_tags.photo_id = ?", photoID).
Order("tags.name ASC").
Find(&tags).Error
return tags, err
}
// Count 统计标签总数
func (r *tagRepositoryImpl) Count(ctx context.Context) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entity.Tag{}).Count(&count).Error
return count, err
}
// CountByPhotos 统计各标签的照片数量
func (r *tagRepositoryImpl) CountByPhotos(ctx context.Context) (map[uint]int64, error) {
var results []struct {
TagID uint `json:"tag_id"`
PhotoCount int64 `json:"photo_count"`
}
err := r.db.WithContext(ctx).
Table("photo_tags").
Select("tag_id, COUNT(photo_id) as photo_count").
Group("tag_id").
Find(&results).Error
if err != nil {
return nil, err
}
counts := make(map[uint]int64)
for _, result := range results {
counts[result.TagID] = result.PhotoCount
}
return counts, nil
}
// GetStats 获取标签统计信息
func (r *tagRepositoryImpl) GetStats(ctx context.Context) (*entity.TagStats, error) {
var stats entity.TagStats
// 总标签数
if total, err := r.Count(ctx); err != nil {
return nil, err
} else {
stats.Total = total
}
// 已使用标签数(有照片关联的标签)
var usedCount int64
if err := r.db.WithContext(ctx).
Table("tags").
Joins("JOIN photo_tags ON tags.id = photo_tags.tag_id").
Distinct("tags.id").
Count(&usedCount).Error; err != nil {
return nil, err
}
stats.Used = usedCount
// 未使用标签数
stats.Unused = stats.Total - stats.Used
// 平均每个标签的照片数
if stats.Used > 0 {
var totalPhotos int64
if err := r.db.WithContext(ctx).
Table("photo_tags").
Count(&totalPhotos).Error; err != nil {
return nil, err
}
stats.AvgPhotosPerTag = float64(totalPhotos) / float64(stats.Used)
}
// 最受欢迎的标签前10
var popularTags []struct {
TagID uint `json:"tag_id"`
Name string `json:"name"`
PhotoCount int64 `json:"photo_count"`
}
if err := r.db.WithContext(ctx).
Table("tags").
Select("tags.id as tag_id, tags.name, COUNT(photo_tags.photo_id) as photo_count").
Joins("LEFT JOIN photo_tags ON tags.id = photo_tags.tag_id").
Group("tags.id, tags.name").
Order("photo_count DESC").
Limit(10).
Find(&popularTags).Error; err != nil {
return nil, err
}
stats.PopularTags = make(map[string]int64)
for _, tag := range popularTags {
stats.PopularTags[tag.Name] = tag.PhotoCount
}
return &stats, nil
}
// GenerateUniqueSlug 生成唯一slug
func (r *tagRepositoryImpl) GenerateUniqueSlug(ctx context.Context, baseName string) (string, error) {
return r.generateUniqueSlug(ctx, baseName)
}
// ValidateSlugUnique 验证slug唯一性
func (r *tagRepositoryImpl) ValidateSlugUnique(ctx context.Context, slug string, excludeID uint) error {
var count int64
query := r.db.WithContext(ctx).Model(&entity.Tag{}).Where("slug = ?", slug)
if excludeID > 0 {
query = query.Where("id != ?", excludeID)
}
if err := query.Count(&count).Error; err != nil {
return err
}
if count > 0 {
return errors.New("slug already exists")
}
return nil
}
// generateUniqueSlug 生成唯一slug
func (r *tagRepositoryImpl) generateUniqueSlug(ctx context.Context, baseName string) (string, error) {
baseSlug := utils.GenerateSlug(baseName)
slug := baseSlug
counter := 1
for {
var count int64
if err := r.db.WithContext(ctx).Model(&entity.Tag{}).
Where("slug = ?", slug).Count(&count).Error; err != nil {
return "", err
}
if count == 0 {
break
}
slug = fmt.Sprintf("%s-%d", baseSlug, counter)
counter++
}
return slug, nil
}
// generateRandomColor 生成随机颜色
func (r *tagRepositoryImpl) generateRandomColor() string {
colors := []string{
"#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4", "#FFEAA7",
"#DDA0DD", "#98D8C8", "#F7DC6F", "#BB8FCE", "#85C1E9",
"#F8C471", "#82E0AA", "#F1948A", "#85C1E9", "#D7BDE2",
}
return colors[len(colors)%15] // 简单的颜色选择
}

View File

@ -0,0 +1,516 @@
package postgres
import (
"context"
"errors"
"fmt"
"time"
"photography-backend/internal/model/entity"
"photography-backend/internal/repository/interfaces"
"go.uber.org/zap"
"gorm.io/gorm"
)
// userRepositoryImpl 用户仓储实现
type userRepositoryImpl struct {
db *gorm.DB
logger *zap.Logger
}
// NewUserRepository 创建用户仓储实现
func NewUserRepository(db *gorm.DB, logger *zap.Logger) interfaces.UserRepository {
return &userRepositoryImpl{
db: db,
logger: logger,
}
}
// Create 创建用户
func (r *userRepositoryImpl) Create(ctx context.Context, user *entity.User) error {
return r.db.WithContext(ctx).Create(user).Error
}
// GetByID 根据ID获取用户
func (r *userRepositoryImpl) GetByID(ctx context.Context, id uint) (*entity.User, error) {
var user entity.User
err := r.db.WithContext(ctx).First(&user, id).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("user not found")
}
return nil, err
}
return &user, nil
}
// GetByEmail 根据邮箱获取用户
func (r *userRepositoryImpl) GetByEmail(ctx context.Context, email string) (*entity.User, error) {
var user entity.User
err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("user not found")
}
return nil, err
}
return &user, nil
}
// GetByUsername 根据用户名获取用户
func (r *userRepositoryImpl) GetByUsername(ctx context.Context, username string) (*entity.User, error) {
var user entity.User
err := r.db.WithContext(ctx).Where("username = ?", username).First(&user).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("user not found")
}
return nil, err
}
return &user, nil
}
// GetByEmailOrUsername 根据邮箱或用户名获取用户
func (r *userRepositoryImpl) GetByEmailOrUsername(ctx context.Context, emailOrUsername string) (*entity.User, error) {
var user entity.User
err := r.db.WithContext(ctx).
Where("email = ? OR username = ?", emailOrUsername, emailOrUsername).
First(&user).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("user not found")
}
return nil, err
}
return &user, nil
}
// Update 更新用户
func (r *userRepositoryImpl) Update(ctx context.Context, user *entity.User) error {
return r.db.WithContext(ctx).Save(user).Error
}
// Delete 删除用户
func (r *userRepositoryImpl) Delete(ctx context.Context, id uint) error {
return r.db.WithContext(ctx).Delete(&entity.User{}, id).Error
}
// List 获取用户列表
func (r *userRepositoryImpl) List(ctx context.Context, params *entity.UserListParams) ([]*entity.User, int64, error) {
var users []*entity.User
var total int64
query := r.db.WithContext(ctx).Model(&entity.User{})
// 应用过滤条件
if params.Role != nil {
query = query.Where("role = ?", *params.Role)
}
if params.Status != nil {
query = query.Where("status = ?", *params.Status)
}
if params.Search != "" {
query = query.Where("username ILIKE ? OR email ILIKE ? OR first_name ILIKE ? OR last_name ILIKE ?",
"%"+params.Search+"%", "%"+params.Search+"%", "%"+params.Search+"%", "%"+params.Search+"%")
}
if params.CreatedFrom != nil {
query = query.Where("created_at >= ?", *params.CreatedFrom)
}
if params.CreatedTo != nil {
query = query.Where("created_at <= ?", *params.CreatedTo)
}
if params.LastLoginFrom != nil {
query = query.Where("last_login_at >= ?", *params.LastLoginFrom)
}
if params.LastLoginTo != nil {
query = query.Where("last_login_at <= ?", *params.LastLoginTo)
}
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 应用排序
orderBy := "created_at DESC"
if params.Sort != "" {
order := "ASC"
if params.Order == "desc" {
order = "DESC"
}
orderBy = fmt.Sprintf("%s %s", params.Sort, order)
}
query = query.Order(orderBy)
// 应用分页
if params.Page > 0 && params.Limit > 0 {
offset := (params.Page - 1) * params.Limit
query = query.Offset(offset).Limit(params.Limit)
}
// 查询数据
if err := query.Find(&users).Error; err != nil {
return nil, 0, err
}
return users, total, nil
}
// ListByRole 根据角色获取用户列表
func (r *userRepositoryImpl) ListByRole(ctx context.Context, role entity.UserRole, params *entity.UserListParams) ([]*entity.User, int64, error) {
if params == nil {
params = &entity.UserListParams{}
}
params.Role = &role
return r.List(ctx, params)
}
// ListByStatus 根据状态获取用户列表
func (r *userRepositoryImpl) ListByStatus(ctx context.Context, status entity.UserStatus, params *entity.UserListParams) ([]*entity.User, int64, error) {
if params == nil {
params = &entity.UserListParams{}
}
params.Status = &status
return r.List(ctx, params)
}
// SearchUsers 搜索用户
func (r *userRepositoryImpl) SearchUsers(ctx context.Context, keyword string, params *entity.UserListParams) ([]*entity.User, int64, error) {
if params == nil {
params = &entity.UserListParams{}
}
params.Search = keyword
return r.List(ctx, params)
}
// Count 统计用户总数
func (r *userRepositoryImpl) Count(ctx context.Context) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entity.User{}).Count(&count).Error
return count, err
}
// CountByRole 根据角色统计用户数
func (r *userRepositoryImpl) CountByRole(ctx context.Context, role entity.UserRole) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entity.User{}).
Where("role = ?", role).Count(&count).Error
return count, err
}
// CountByStatus 根据状态统计用户数
func (r *userRepositoryImpl) CountByStatus(ctx context.Context, status entity.UserStatus) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entity.User{}).
Where("status = ?", status).Count(&count).Error
return count, err
}
// CountActiveUsers 统计活跃用户数
func (r *userRepositoryImpl) CountActiveUsers(ctx context.Context) (int64, error) {
return r.CountByStatus(ctx, entity.UserStatusActive)
}
// UpdateStatus 更新用户状态
func (r *userRepositoryImpl) UpdateStatus(ctx context.Context, id uint, status entity.UserStatus) error {
return r.db.WithContext(ctx).Model(&entity.User{}).
Where("id = ?", id).
Update("status", status).Error
}
// UpdateLastLogin 更新最后登录时间
func (r *userRepositoryImpl) UpdateLastLogin(ctx context.Context, id uint) error {
now := time.Now()
return r.db.WithContext(ctx).Model(&entity.User{}).
Where("id = ?", id).
Update("last_login_at", now).Error
}
// UpdatePassword 更新密码
func (r *userRepositoryImpl) UpdatePassword(ctx context.Context, id uint, hashedPassword string) error {
return r.db.WithContext(ctx).Model(&entity.User{}).
Where("id = ?", id).
Update("password", hashedPassword).Error
}
// SoftDelete 软删除用户
func (r *userRepositoryImpl) SoftDelete(ctx context.Context, id uint) error {
return r.db.WithContext(ctx).Delete(&entity.User{}, id).Error
}
// Restore 恢复软删除用户
func (r *userRepositoryImpl) Restore(ctx context.Context, id uint) error {
return r.db.WithContext(ctx).Unscoped().Model(&entity.User{}).
Where("id = ?", id).
Update("deleted_at", nil).Error
}
// BatchUpdateStatus 批量更新用户状态
func (r *userRepositoryImpl) BatchUpdateStatus(ctx context.Context, ids []uint, status entity.UserStatus) error {
if len(ids) == 0 {
return nil
}
return r.db.WithContext(ctx).Model(&entity.User{}).
Where("id IN ?", ids).
Update("status", status).Error
}
// BatchDelete 批量删除用户
func (r *userRepositoryImpl) BatchDelete(ctx context.Context, ids []uint) error {
if len(ids) == 0 {
return nil
}
return r.db.WithContext(ctx).Delete(&entity.User{}, ids).Error
}
// ValidateEmailUnique 验证邮箱唯一性
func (r *userRepositoryImpl) ValidateEmailUnique(ctx context.Context, email string, excludeID uint) error {
var count int64
query := r.db.WithContext(ctx).Model(&entity.User{}).Where("email = ?", email)
if excludeID > 0 {
query = query.Where("id != ?", excludeID)
}
if err := query.Count(&count).Error; err != nil {
return err
}
if count > 0 {
return errors.New("email already exists")
}
return nil
}
// ValidateUsernameUnique 验证用户名唯一性
func (r *userRepositoryImpl) ValidateUsernameUnique(ctx context.Context, username string, excludeID uint) error {
var count int64
query := r.db.WithContext(ctx).Model(&entity.User{}).Where("username = ?", username)
if excludeID > 0 {
query = query.Where("id != ?", excludeID)
}
if err := query.Count(&count).Error; err != nil {
return err
}
if count > 0 {
return errors.New("username already exists")
}
return nil
}
// GetUserPhotos 获取用户照片
func (r *userRepositoryImpl) GetUserPhotos(ctx context.Context, userID uint, params *entity.PhotoListParams) ([]*entity.Photo, int64, error) {
var photos []*entity.Photo
var total int64
query := r.db.WithContext(ctx).Model(&entity.Photo{}).Where("user_id = ?", userID)
// 应用过滤条件
if params != nil {
if params.Status != nil {
query = query.Where("status = ?", *params.Status)
}
if params.CategoryID != nil {
query = query.Joins("JOIN photo_categories ON photos.id = photo_categories.photo_id").
Where("photo_categories.category_id = ?", *params.CategoryID)
}
if params.TagID != nil {
query = query.Joins("JOIN photo_tags ON photos.id = photo_tags.photo_id").
Where("photo_tags.tag_id = ?", *params.TagID)
}
if params.DateFrom != nil {
query = query.Where("taken_at >= ?", *params.DateFrom)
}
if params.DateTo != nil {
query = query.Where("taken_at <= ?", *params.DateTo)
}
if params.Search != "" {
query = query.Where("title ILIKE ? OR description ILIKE ?",
"%"+params.Search+"%", "%"+params.Search+"%")
}
}
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 应用排序和分页
if params != nil {
orderBy := "created_at DESC"
if params.Sort != "" {
order := "ASC"
if params.Order == "desc" {
order = "DESC"
}
orderBy = fmt.Sprintf("%s %s", params.Sort, order)
}
query = query.Order(orderBy)
if params.Page > 0 && params.Limit > 0 {
offset := (params.Page - 1) * params.Limit
query = query.Offset(offset).Limit(params.Limit)
}
}
// 预加载关联数据
query = query.Preload("Categories").Preload("Tags")
// 查询数据
if err := query.Find(&photos).Error; err != nil {
return nil, 0, err
}
return photos, total, nil
}
// GetUserStats 获取用户统计信息
func (r *userRepositoryImpl) GetUserStats(ctx context.Context, userID uint) (*entity.UserStats, error) {
var stats entity.UserStats
// 照片统计
var photoCount int64
if err := r.db.WithContext(ctx).Model(&entity.Photo{}).
Where("user_id = ?", userID).Count(&photoCount).Error; err != nil {
return nil, err
}
stats.PhotoCount = photoCount
// 按状态统计照片
for _, status := range []entity.PhotoStatus{
entity.PhotoStatusActive,
entity.PhotoStatusDraft,
entity.PhotoStatusArchived,
} {
var count int64
if err := r.db.WithContext(ctx).Model(&entity.Photo{}).
Where("user_id = ? AND status = ?", userID, status).
Count(&count).Error; err != nil {
return nil, err
}
switch status {
case entity.PhotoStatusActive:
stats.PublishedPhotos = count
case entity.PhotoStatusDraft:
stats.DraftPhotos = count
case entity.PhotoStatusArchived:
stats.ArchivedPhotos = count
}
}
// 总浏览数
var totalViews int64
if err := r.db.WithContext(ctx).Model(&entity.Photo{}).
Where("user_id = ?", userID).
Select("COALESCE(SUM(view_count), 0)").Row().Scan(&totalViews); err != nil {
return nil, err
}
stats.TotalViews = totalViews
// 总下载数
var totalDownloads int64
if err := r.db.WithContext(ctx).Model(&entity.Photo{}).
Where("user_id = ?", userID).
Select("COALESCE(SUM(download_count), 0)").Row().Scan(&totalDownloads); err != nil {
return nil, err
}
stats.TotalDownloads = totalDownloads
// 存储空间使用
var storageUsed int64
if err := r.db.WithContext(ctx).Model(&entity.Photo{}).
Where("user_id = ?", userID).
Select("COALESCE(SUM(file_size), 0)").Row().Scan(&storageUsed); err != nil {
return nil, err
}
stats.StorageUsed = storageUsed
// 本月新增照片
now := time.Now()
startOfMonth := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
endOfMonth := startOfMonth.AddDate(0, 1, 0).Add(-time.Nanosecond)
var monthlyPhotos int64
if err := r.db.WithContext(ctx).Model(&entity.Photo{}).
Where("user_id = ? AND created_at >= ? AND created_at <= ?", userID, startOfMonth, endOfMonth).
Count(&monthlyPhotos).Error; err != nil {
return nil, err
}
stats.MonthlyPhotos = monthlyPhotos
return &stats, nil
}
// GetAllStats 获取全部用户统计信息
func (r *userRepositoryImpl) GetAllStats(ctx context.Context) (*entity.UserGlobalStats, error) {
var stats entity.UserGlobalStats
// 总用户数
if total, err := r.Count(ctx); err != nil {
return nil, err
} else {
stats.Total = total
}
// 活跃用户数
if active, err := r.CountActiveUsers(ctx); err != nil {
return nil, err
} else {
stats.Active = active
}
// 按角色统计
for _, role := range []entity.UserRole{
entity.UserRoleAdmin,
entity.UserRoleEditor,
entity.UserRoleUser,
} {
if count, err := r.CountByRole(ctx, role); err != nil {
return nil, err
} else {
switch role {
case entity.UserRoleAdmin:
stats.Admins = count
case entity.UserRoleEditor:
stats.Editors = count
case entity.UserRoleUser:
stats.Users = count
}
}
}
// 本月新注册用户
now := time.Now()
startOfMonth := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
endOfMonth := startOfMonth.AddDate(0, 1, 0).Add(-time.Nanosecond)
var monthlyUsers int64
if err := r.db.WithContext(ctx).Model(&entity.User{}).
Where("created_at >= ? AND created_at <= ?", startOfMonth, endOfMonth).
Count(&monthlyUsers).Error; err != nil {
return nil, err
}
stats.MonthlyRegistrations = monthlyUsers
return &stats, nil
}