Files
photography/backend-old/internal/repository/CLAUDE.md
xujiang 010fe2a8c7
Some checks failed
部署后端服务 / 🧪 测试后端 (push) Failing after 5m8s
部署后端服务 / 🚀 构建并部署 (push) Has been skipped
部署后端服务 / 🔄 回滚部署 (push) Has been skipped
fix
2025-07-10 18:09:11 +08:00

27 KiB

Repository Layer - CLAUDE.md

本文件为 Claude Code 在数据访问层中工作时提供指导。

🎯 模块概览

Repository 层负责数据访问逻辑,提供数据库操作的抽象接口,隔离业务逻辑与数据存储细节。

主要职责

  • 🔧 提供数据库操作接口
  • 📊 实现 CRUD 操作
  • 🔍 提供复杂查询支持
  • 💾 管理数据库连接和事务
  • 🚀 优化查询性能
  • 🔄 支持多种数据源

📁 模块结构

internal/repository/
├── CLAUDE.md                    # 📋 当前文件 - 数据访问开发指导
├── interfaces/                  # 🔗 仓储接口定义
│   ├── base_repository.go       # 基础仓储接口(泛型 BaseRepositoryr[T]、ListOptions)
│   ├── user_repository.go       # 用户仓储接口
│   ├── photo_repository.go      # 照片仓储接口
│   ├── category_repository.go   # 分类仓储接口
│   ├── tag_repository.go        # 标签仓储接口
│   └── album_repository.go      # 相册仓储接口
├── postgres/                    # 🐘 PostgreSQL 实现
│   ├── user_repository.go       # 用户仓储实现
│   ├── photo_repository.go      # 照片仓储实现
│   ├── category_repository.go   # 分类仓储实现
│   ├── tag_repository.go        # 标签仓储实现
│   ├── album_repository.go      # 相册仓储实现
│   └── base_repository.go       # 基础仓储实现
├── redis/                       # 🟥 Redis 缓存实现
│   ├── user_cache.go            # 用户缓存
│   ├── photo_cache.go           # 照片缓存
│   └── cache_manager.go         # 缓存管理器
├── sqlite/                      # 📦 SQLite 实现(开发用)
│   ├── user_repository.go       # 用户仓储实现
│   ├── photo_repository.go      # 照片仓储实现
│   └── base_repository.go       # 基础仓储实现
├── mocks/                       # 🧪 模拟对象(测试用)
│   ├── user_repository_mock.go  # 用户仓储模拟
│   ├── photo_repository_mock.go # 照片仓储模拟
│   └── generate.go              # 模拟对象生成
└── errors.go                    # 📝 仓储层错误定义

🔗 接口设计

基础仓储接口

// interfaces/base_repository.go - 基础仓储接口
package interfaces

import (
    "context"
    "gorm.io/gorm"
)

// BaseRepositoryr 基础仓储接口
type BaseRepositoryr[T any] interface {
    // 基础 CRUD 操作
    Create(ctx context.Context, entity *T) (*T, error)
    GetByID(ctx context.Context, id uint) (*T, error)
    Update(ctx context.Context, entity *T) (*T, error)
    Delete(ctx context.Context, id uint) error
    
    // 批量操作
    CreateBatch(ctx context.Context, entities []*T) error
    UpdateBatch(ctx context.Context, entities []*T) error
    DeleteBatch(ctx context.Context, ids []uint) error
    
    // 查询操作
    List(ctx context.Context, opts *ListOptions) ([]*T, int64, error)
    Count(ctx context.Context, conditions map[string]interface{}) (int64, error)
    Exists(ctx context.Context, conditions map[string]interface{}) (bool, error)
    
    // 事务操作
    WithTx(tx *gorm.DB) BaseRepositoryr[T]
    Transaction(ctx context.Context, fn func(repo BaseRepositoryr[T]) error) error
}

// ListOptions carries pagination, sorting, filtering and eager-loading
// parameters for list queries. It has json tags, so it may be decoded from
// request payloads; read the Sort and Conditions notes before doing that.
type ListOptions struct {
    // Page is the page number. The postgres BaseRepository.List treats it as
    // 1-based and applies paging only when both Page and Limit are > 0.
    Page       int                    `json:"page"`
    // Limit is the page size.
    Limit      int                    `json:"limit"`
    // Sort is the column name placed in ORDER BY. It ends up inside SQL text,
    // so it must be a plain column name, never unvalidated user input.
    Sort       string                 `json:"sort"`
    // Order is "desc" for descending; any other value sorts ascending.
    Order      string                 `json:"order"`
    // Conditions maps a raw SQL WHERE fragment (e.g. "status = ?") to its bind
    // argument. The keys are SQL and must not come from untrusted input.
    Conditions map[string]interface{} `json:"conditions"`
    // Preloads lists association names to eager-load via GORM Preload.
    Preloads   []string               `json:"preloads"`
}

用户仓储接口

// interfaces/user_repository.go - 用户仓储接口
package interfaces

import (
    "context"
    "photography-backend/internal/model/entity"
)

// UserRepositoryr is the data-access contract for users. It extends the
// generic CRUD in BaseRepositoryr[entity.User] with user-specific lookups,
// listings, counts and field updates.
type UserRepositoryr interface {
    BaseRepositoryr[entity.User]

    // Targeted writes. UpdatePassword expects an already-hashed password;
    // Restore undoes a soft delete.
    UpdateStatus(ctx context.Context, id uint, status entity.UserStatus) error
    UpdateLastLogin(ctx context.Context, id uint) error
    UpdatePassword(ctx context.Context, id uint, hashedPassword string) error
    Restore(ctx context.Context, id uint) error

    // Single-user lookups.
    GetByEmail(ctx context.Context, email string) (*entity.User, error)
    GetByUsername(ctx context.Context, username string) (*entity.User, error)
    GetByEmailOrUsername(ctx context.Context, emailOrUsername string) (*entity.User, error)

    // Paginated listings and search; each returns the page plus a total count.
    ListByRole(ctx context.Context, role entity.UserRole, opts *ListOptions) ([]*entity.User, int64, error)
    ListByStatus(ctx context.Context, status entity.UserStatus, opts *ListOptions) ([]*entity.User, int64, error)
    SearchUsers(ctx context.Context, keyword string, opts *ListOptions) ([]*entity.User, int64, error)

    // Aggregate counts.
    CountByRole(ctx context.Context, role entity.UserRole) (int64, error)
    CountByStatus(ctx context.Context, status entity.UserStatus) (int64, error)
    CountActiveUsers(ctx context.Context) (int64, error)
}

照片仓储接口

// interfaces/photo_repository.go - 照片仓储接口
package interfaces

import (
    "context"
    "time"
    "photography-backend/internal/model/entity"
)

// PhotoRepositoryr is the data-access contract for photos. It extends the
// generic CRUD in BaseRepositoryr[entity.Photo] with lookups, filtered
// listings, search, statistics, category/tag association management and
// counter updates.
type PhotoRepositoryr interface {
    BaseRepositoryr[entity.Photo]

    // Lookups. GetByUserID is unpaginated.
    GetByFilename(ctx context.Context, filename string) (*entity.Photo, error)
    GetByUserID(ctx context.Context, userID uint) ([]*entity.Photo, error)

    // Paginated listings; each returns the page plus a total count.
    // NOTE(review): ListByDateRange presumably filters on the photo's own
    // (capture) date and ListByCreatedDateRange on the row creation time;
    // confirm against the implementation.
    ListByUserID(ctx context.Context, userID uint, opts *ListOptions) ([]*entity.Photo, int64, error)
    ListByStatus(ctx context.Context, status entity.PhotoStatus, opts *ListOptions) ([]*entity.Photo, int64, error)
    ListByCategory(ctx context.Context, categoryID uint, opts *ListOptions) ([]*entity.Photo, int64, error)
    ListByTag(ctx context.Context, tagID uint, opts *ListOptions) ([]*entity.Photo, int64, error)
    ListByAlbum(ctx context.Context, albumID uint, opts *ListOptions) ([]*entity.Photo, int64, error)
    ListByDateRange(ctx context.Context, startDate, endDate time.Time, opts *ListOptions) ([]*entity.Photo, int64, error)
    ListByCreatedDateRange(ctx context.Context, startDate, endDate time.Time, opts *ListOptions) ([]*entity.Photo, int64, error)

    // Search.
    SearchPhotos(ctx context.Context, keyword string, opts *SearchOptions) ([]*entity.Photo, int64, error)
    SearchByMetadata(ctx context.Context, metadata map[string]interface{}, opts *ListOptions) ([]*entity.Photo, int64, error)

    // Statistics.
    CountByUser(ctx context.Context, userID uint) (int64, error)
    CountByStatus(ctx context.Context, status entity.PhotoStatus) (int64, error)
    CountByCategory(ctx context.Context, categoryID uint) (int64, error)
    CountByTag(ctx context.Context, tagID uint) (int64, error)

    // Category and tag associations.
    AddCategories(ctx context.Context, photoID uint, categoryIDs []uint) error
    RemoveCategories(ctx context.Context, photoID uint, categoryIDs []uint) error
    AddTags(ctx context.Context, photoID uint, tagIDs []uint) error
    RemoveTags(ctx context.Context, photoID uint, tagIDs []uint) error

    // Counters.
    IncrementViewCount(ctx context.Context, id uint) error
    IncrementDownloadCount(ctx context.Context, id uint) error

    // Bulk writes. NOTE(review): BatchDelete overlaps with the inherited
    // DeleteBatch; consider keeping only one.
    BatchUpdateStatus(ctx context.Context, ids []uint, status entity.PhotoStatus) error
    BatchDelete(ctx context.Context, ids []uint) error
}

// SearchOptions is the option set for PhotoRepositoryr.SearchPhotos. The
// embedded ListOptions supplies paging, sorting, conditions and preloads.
// NOTE(review): no implementation is shown in this document; the field
// semantics below are presumed from their names and should be confirmed.
type SearchOptions struct {
    ListOptions
    // Fields presumably selects which columns the keyword is matched against.
    Fields    []string   `json:"fields"`
    // DateFrom/DateTo are optional date bounds; nil presumably means unbounded.
    DateFrom  *time.Time `json:"date_from"`
    DateTo    *time.Time `json:"date_to"`
    // Dimension filters (presumably in pixels); 0 presumably means "no bound".
    MinWidth  int        `json:"min_width"`
    MaxWidth  int        `json:"max_width"`
    MinHeight int        `json:"min_height"`
    MaxHeight int        `json:"max_height"`
}

🔧 仓储实现

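基础仓储实现

注意:下面的示例只展示了部分方法。完整实现还必须提供 CreateBatch、UpdateBatch、DeleteBatch、Count 和 Exists,否则 *BaseRepository[T] 不满足 BaseRepositoryr[T] 接口,WithTx / Transaction 也无法通过编译。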

// postgres/base_repository.go - 基础仓储实现
package postgres

import (
    "context"
    "errors"
    "fmt"
    "reflect"
    
    "gorm.io/gorm"
    "go.uber.org/zap"
    
    "photography-backend/internal/repository/interfaces"
    "photography-backend/pkg/logger"
)

// BaseRepository is the GORM-backed implementation of the generic CRUD
// operations for model type T; entity repositories embed it.
type BaseRepository[T any] struct {
    // db is the GORM handle all queries run through; WithTx/Transaction build
    // copies whose db is a transaction handle.
    db     *gorm.DB
    // logger receives error-level entries for failed queries.
    logger logger.Logger
}

// NewBaseRepository builds a BaseRepository for T that issues queries through
// db and reports failures to logger.
func NewBaseRepository[T any](db *gorm.DB, logger logger.Logger) *BaseRepository[T] {
    repo := new(BaseRepository[T])
    repo.db = db
    repo.logger = logger
    return repo
}

// Create inserts entity and, on success, returns the same pointer. GORM
// back-fills generated fields such as the primary key into it. On failure
// the error is logged and returned together with a nil entity.
func (r *BaseRepository[T]) Create(ctx context.Context, entity *T) (*T, error) {
    err := r.db.WithContext(ctx).Create(entity).Error
    if err == nil {
        return entity, nil
    }
    r.logger.Error("failed to create entity", zap.Error(err))
    return nil, err
}

// GetByID loads the row whose primary key is id.
//
// Returns ErrNotFound when no row exists. Any other database error is logged
// and returned unchanged.
func (r *BaseRepository[T]) GetByID(ctx context.Context, id uint) (*T, error) {
    found := new(T)
    err := r.db.WithContext(ctx).First(found, id).Error
    switch {
    case err == nil:
        return found, nil
    case errors.Is(err, gorm.ErrRecordNotFound):
        return nil, ErrNotFound
    default:
        r.logger.Error("failed to get entity by id", zap.Error(err), zap.Uint("id", id))
        return nil, err
    }
}

// Update persists entity with GORM's Save, which writes all fields and
// inserts when the primary key is zero. It returns the same pointer on
// success; on failure it logs the error and returns it with a nil entity.
func (r *BaseRepository[T]) Update(ctx context.Context, entity *T) (*T, error) {
    err := r.db.WithContext(ctx).Save(entity).Error
    if err == nil {
        return entity, nil
    }
    r.logger.Error("failed to update entity", zap.Error(err))
    return nil, err
}

// Delete removes the row whose primary key is id. If T has a gorm.DeletedAt
// field, GORM performs a soft delete instead.
//
// Deleting an id that does not exist is not reported as an error.
func (r *BaseRepository[T]) Delete(ctx context.Context, id uint) error {
    err := r.db.WithContext(ctx).Delete(new(T), id).Error
    if err == nil {
        return nil
    }
    r.logger.Error("failed to delete entity", zap.Error(err), zap.Uint("id", id))
    return err
}

// List returns one page of T rows together with the total number of rows
// matching opts.Conditions. The total ignores pagination.
//
// opts may be nil, which means no filtering, sorting, paging or preloading.
//
//   - Conditions: each key is a raw SQL fragment passed to GORM's Where, with
//     the map value as its bind argument (e.g. {"status = ?": "active"}).
//     Keys are SQL text and must never be built from untrusted input.
//   - Sort/Order: Sort is a column name, optionally table-qualified. It is
//     validated because it is spliced into ORDER BY. Order "desc" sorts
//     descending; any other value sorts ascending.
//   - Page/Limit: Page is 1-based; paging applies only when both are > 0.
//   - Preloads: association names passed to GORM's Preload.
//
// An invalid Sort value is rejected with an error before any query runs.
func (r *BaseRepository[T]) List(ctx context.Context, opts *interfaces.ListOptions) ([]*T, int64, error) {
    if opts == nil {
        opts = &interfaces.ListOptions{}
    }

    // Sort is interpolated into the ORDER BY clause, so accept only plain
    // identifier characters to rule out SQL injection.
    for _, c := range opts.Sort {
        isIdent := c == '_' || c == '.' ||
            ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || ('0' <= c && c <= '9')
        if !isIdent {
            return nil, 0, fmt.Errorf("invalid sort field %q", opts.Sort)
        }
    }

    var entities []*T
    var total int64

    db := r.db.WithContext(ctx)
    for key, value := range opts.Conditions {
        db = db.Where(key, value)
    }

    // Count on a cloned session. After Where has been chained, db holds a
    // mutable statement; running Count on it directly would share that
    // statement with the page query below.
    if err := db.Session(&gorm.Session{}).Model(new(T)).Count(&total).Error; err != nil {
        r.logger.Error("failed to count entities", zap.Error(err))
        return nil, 0, err
    }

    if opts.Sort != "" {
        order := "ASC"
        if opts.Order == "desc" {
            order = "DESC"
        }
        db = db.Order(fmt.Sprintf("%s %s", opts.Sort, order))
    }

    if opts.Page > 0 && opts.Limit > 0 {
        db = db.Offset((opts.Page - 1) * opts.Limit).Limit(opts.Limit)
    }

    for _, preload := range opts.Preloads {
        db = db.Preload(preload)
    }

    if err := db.Find(&entities).Error; err != nil {
        r.logger.Error("failed to list entities", zap.Error(err))
        return nil, 0, err
    }

    return entities, total, nil
}

// Transaction runs fn inside a database transaction opened by GORM. fn
// receives a repository bound to the transaction handle. GORM commits when
// fn returns nil and rolls back when it returns an error.
func (r *BaseRepository[T]) Transaction(ctx context.Context, fn func(repo interfaces.BaseRepositoryr[T]) error) error {
    run := func(tx *gorm.DB) error {
        return fn(r.WithTx(tx))
    }
    return r.db.WithContext(ctx).Transaction(run)
}

// WithTx returns a new repository that issues its queries through tx,
// typically a transaction handle. The logger is shared with r.
func (r *BaseRepository[T]) WithTx(tx *gorm.DB) interfaces.BaseRepositoryr[T] {
    scoped := NewBaseRepository[T](tx, r.logger)
    return scoped
}

用户仓储实现

// postgres/user_repository.go - 用户仓储实现
package postgres

import (
    "context"
    "errors"
    "strings"
    "time"
    
    "gorm.io/gorm"
    "go.uber.org/zap"
    
    "photography-backend/internal/model/entity"
    "photography-backend/internal/repository/interfaces"
    "photography-backend/pkg/logger"
)

// UserRepository is the postgres-package implementation returned as
// interfaces.UserRepositoryr. It embeds BaseRepository[entity.User] for the
// generic CRUD methods and adds user-specific queries.
// NOTE(review): WithTx/Transaction come from the embedded base and yield a
// plain BaseRepository, so the user-specific methods are not transaction-
// scoped through them.
type UserRepository struct {
    *BaseRepository[entity.User]
    // db and logger duplicate the embedded BaseRepository's fields; these
    // outer fields shadow the embedded ones inside UserRepository's methods.
    db     *gorm.DB
    logger logger.Logger
}

// NewUserRepository wires a UserRepository around db and logger. The same
// handle and logger serve both the embedded base and the user-specific
// methods.
func NewUserRepository(db *gorm.DB, logger logger.Logger) interfaces.UserRepositoryr {
    repo := &UserRepository{
        db:     db,
        logger: logger,
    }
    repo.BaseRepository = NewBaseRepository[entity.User](db, logger)
    return repo
}

// GetByEmail returns the first user whose email equals email exactly.
// It returns ErrNotFound when there is none; other errors are logged and
// returned unchanged.
func (r *UserRepository) GetByEmail(ctx context.Context, email string) (*entity.User, error) {
    user := new(entity.User)
    err := r.db.WithContext(ctx).Where("email = ?", email).First(user).Error
    switch {
    case err == nil:
        return user, nil
    case errors.Is(err, gorm.ErrRecordNotFound):
        return nil, ErrNotFound
    default:
        r.logger.Error("failed to get user by email", zap.Error(err), zap.String("email", email))
        return nil, err
    }
}

// GetByUsername returns the first user whose username equals username
// exactly. It returns ErrNotFound when there is none; other errors are
// logged and returned unchanged.
func (r *UserRepository) GetByUsername(ctx context.Context, username string) (*entity.User, error) {
    user := new(entity.User)
    err := r.db.WithContext(ctx).Where("username = ?", username).First(user).Error
    if err == nil {
        return user, nil
    }
    if errors.Is(err, gorm.ErrRecordNotFound) {
        return nil, ErrNotFound
    }
    r.logger.Error("failed to get user by username", zap.Error(err), zap.String("username", username))
    return nil, err
}

// GetByEmailOrUsername returns the first user whose email OR username equals
// the given value, for example a login identifier. It returns ErrNotFound
// when no user matches; other errors are logged and returned.
func (r *UserRepository) GetByEmailOrUsername(ctx context.Context, emailOrUsername string) (*entity.User, error) {
    user := new(entity.User)
    query := r.db.WithContext(ctx).Where("email = ? OR username = ?", emailOrUsername, emailOrUsername)
    err := query.First(user).Error
    switch {
    case err == nil:
        return user, nil
    case errors.Is(err, gorm.ErrRecordNotFound):
        return nil, ErrNotFound
    }
    r.logger.Error("failed to get user by email or username", zap.Error(err), zap.String("emailOrUsername", emailOrUsername))
    return nil, err
}

// SearchUsers performs a case-insensitive substring search for keyword over
// username, email, first_name and last_name. The search is ANDed with any
// extra opts.Conditions. It returns one page of matches plus the total match
// count.
//
// The keyword is sent as a bind parameter and never spliced into SQL. The
// LIKE metacharacters \, % and _ in it are escaped so they match literally
// (backslash is PostgreSQL's default LIKE escape). ILIKE is
// PostgreSQL-specific.
//
// opts may be nil. opts.Sort must be a plain column name, because it is
// spliced into ORDER BY; any other value is rejected before querying.
func (r *UserRepository) SearchUsers(ctx context.Context, keyword string, opts *interfaces.ListOptions) ([]*entity.User, int64, error) {
    if opts == nil {
        opts = &interfaces.ListOptions{}
    }

    // Reject anything but identifier characters in Sort (SQL injection guard).
    for _, c := range opts.Sort {
        isIdent := c == '_' || c == '.' ||
            ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || ('0' <= c && c <= '9')
        if !isIdent {
            return nil, 0, errors.New("invalid sort field: " + opts.Sort)
        }
    }

    var users []*entity.User
    var total int64

    escaped := strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`).Replace(keyword)
    pattern := "%" + escaped + "%"

    // The parentheses keep the OR group intact when more conditions are ANDed on.
    db := r.db.WithContext(ctx).Where(
        "(username ILIKE ? OR email ILIKE ? OR first_name ILIKE ? OR last_name ILIKE ?)",
        pattern, pattern, pattern, pattern,
    )

    for key, value := range opts.Conditions {
        db = db.Where(key, value)
    }

    // Count on a cloned session so the COUNT statement is not shared with
    // the page query built below.
    if err := db.Session(&gorm.Session{}).Model(&entity.User{}).Count(&total).Error; err != nil {
        r.logger.Error("failed to count users", zap.Error(err))
        return nil, 0, err
    }

    if opts.Sort != "" {
        order := "ASC"
        if opts.Order == "desc" {
            order = "DESC"
        }
        db = db.Order(opts.Sort + " " + order)
    }

    if opts.Page > 0 && opts.Limit > 0 {
        db = db.Offset((opts.Page - 1) * opts.Limit).Limit(opts.Limit)
    }

    if err := db.Find(&users).Error; err != nil {
        r.logger.Error("failed to search users", zap.Error(err))
        return nil, 0, err
    }

    return users, total, nil
}

// UpdateStatus sets the status column of the user with the given id.
// No error is returned when no row matches id.
func (r *UserRepository) UpdateStatus(ctx context.Context, id uint, status entity.UserStatus) error {
    result := r.db.WithContext(ctx).
        Model(&entity.User{}).
        Where("id = ?", id).
        Update("status", status)
    if result.Error == nil {
        return nil
    }
    r.logger.Error("failed to update user status", zap.Error(result.Error), zap.Uint("id", id), zap.String("status", string(status)))
    return result.Error
}

// UpdateLastLogin stamps last_login_at with the current local time for the
// user with the given id.
func (r *UserRepository) UpdateLastLogin(ctx context.Context, id uint) error {
    result := r.db.WithContext(ctx).
        Model(&entity.User{}).
        Where("id = ?", id).
        Update("last_login_at", time.Now())
    if result.Error == nil {
        return nil
    }
    r.logger.Error("failed to update last login time", zap.Error(result.Error), zap.Uint("id", id))
    return result.Error
}

// UpdatePassword stores hashedPassword in the password column of the user
// with the given id. The caller is responsible for hashing.
func (r *UserRepository) UpdatePassword(ctx context.Context, id uint, hashedPassword string) error {
    result := r.db.WithContext(ctx).
        Model(&entity.User{}).
        Where("id = ?", id).
        Update("password", hashedPassword)
    if result.Error == nil {
        return nil
    }
    r.logger.Error("failed to update password", zap.Error(result.Error), zap.Uint("id", id))
    return result.Error
}

// CountByRole returns how many users have the given role.
func (r *UserRepository) CountByRole(ctx context.Context, role entity.UserRole) (int64, error) {
    var n int64
    query := r.db.WithContext(ctx).Model(&entity.User{}).Where("role = ?", role)
    if err := query.Count(&n).Error; err != nil {
        r.logger.Error("failed to count users by role", zap.Error(err), zap.String("role", string(role)))
        return 0, err
    }
    return n, nil
}

// CountActiveUsers returns how many users have status entity.UserStatusActive.
func (r *UserRepository) CountActiveUsers(ctx context.Context) (int64, error) {
    var n int64
    query := r.db.WithContext(ctx).Model(&entity.User{}).Where("status = ?", entity.UserStatusActive)
    if err := query.Count(&n).Error; err != nil {
        r.logger.Error("failed to count active users", zap.Error(err))
        return 0, err
    }
    return n, nil
}

💾 Redis 缓存实现

缓存管理器

// redis/cache_manager.go - 缓存管理器
package redis

import (
    "context"
    "encoding/json"
    "fmt"
    "time"
    
    "github.com/go-redis/redis/v8"
    "go.uber.org/zap"
    
    "photography-backend/pkg/logger"
)

// CacheManager wraps a go-redis client with JSON (de)serialization and error
// logging. Typed caches such as UserCache embed it.
type CacheManager struct {
    // client is the underlying go-redis (v8) client.
    client *redis.Client
    // logger receives error-level entries for failed cache operations.
    logger logger.Logger
}

// NewCacheManager returns a CacheManager that talks to Redis through client
// and reports failures to logger.
func NewCacheManager(client *redis.Client, logger logger.Logger) *CacheManager {
    cm := new(CacheManager)
    cm.client = client
    cm.logger = logger
    return cm
}

// Set JSON-encodes value and stores it under key with the given TTL. Under
// Redis SET semantics, a TTL of 0 means the key does not expire. Marshal and
// Redis errors are logged and returned.
func (cm *CacheManager) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
    payload, marshalErr := json.Marshal(value)
    if marshalErr != nil {
        cm.logger.Error("failed to marshal cache value", zap.Error(marshalErr), zap.String("key", key))
        return marshalErr
    }

    if setErr := cm.client.Set(ctx, key, payload, ttl).Err(); setErr != nil {
        cm.logger.Error("failed to set cache", zap.Error(setErr), zap.String("key", key))
        return setErr
    }
    return nil
}

// Get loads the JSON value stored under key and decodes it into dest, which
// must be a pointer.
//
// A missing key yields ErrCacheNotFound without logging. Redis and decode
// errors are logged and returned.
func (cm *CacheManager) Get(ctx context.Context, key string, dest interface{}) error {
    raw, err := cm.client.Get(ctx, key).Result()
    switch {
    case err == redis.Nil:
        return ErrCacheNotFound
    case err != nil:
        cm.logger.Error("failed to get cache", zap.Error(err), zap.String("key", key))
        return err
    }

    if decodeErr := json.Unmarshal([]byte(raw), dest); decodeErr != nil {
        cm.logger.Error("failed to unmarshal cache value", zap.Error(decodeErr), zap.String("key", key))
        return decodeErr
    }
    return nil
}

// Delete removes key from Redis. Deleting a missing key is not an error.
func (cm *CacheManager) Delete(ctx context.Context, key string) error {
    if err := cm.client.Del(ctx, key).Err(); err != nil {
        cm.logger.Error("failed to delete cache", zap.Error(err), zap.String("key", key))
        return err
    }
    return nil
}

// DeletePattern deletes every key matching the Redis glob pattern.
//
// Keys are enumerated with SCAN rather than KEYS. KEYS walks the whole
// keyspace in one blocking command and can stall a production Redis, while
// SCAN iterates incrementally. Matches are deleted in batches so a large
// match set does not become a single huge DEL.
//
// SCAN may return a key more than once, which is harmless here. Keys added
// while the scan is running may or may not be deleted.
func (cm *CacheManager) DeletePattern(ctx context.Context, pattern string) error {
    const batchSize = 500

    batch := make([]string, 0, batchSize)
    flush := func() error {
        if len(batch) == 0 {
            return nil
        }
        if err := cm.client.Del(ctx, batch...).Err(); err != nil {
            cm.logger.Error("failed to delete keys by pattern", zap.Error(err), zap.String("pattern", pattern))
            return err
        }
        batch = batch[:0]
        return nil
    }

    iter := cm.client.Scan(ctx, 0, pattern, batchSize).Iterator()
    for iter.Next(ctx) {
        batch = append(batch, iter.Val())
        if len(batch) >= batchSize {
            if err := flush(); err != nil {
                return err
            }
        }
    }
    if err := iter.Err(); err != nil {
        cm.logger.Error("failed to get keys by pattern", zap.Error(err), zap.String("pattern", pattern))
        return err
    }

    return flush()
}

// Exists reports whether key is present in Redis.
func (cm *CacheManager) Exists(ctx context.Context, key string) (bool, error) {
    n, err := cm.client.Exists(ctx, key).Result()
    if err == nil {
        return n > 0, nil
    }
    cm.logger.Error("failed to check cache existence", zap.Error(err), zap.String("key", key))
    return false, err
}

// SetTTL (re)sets the expiry of an existing key via Redis EXPIRE.
func (cm *CacheManager) SetTTL(ctx context.Context, key string, ttl time.Duration) error {
    if err := cm.client.Expire(ctx, key, ttl).Err(); err != nil {
        cm.logger.Error("failed to set cache ttl", zap.Error(err), zap.String("key", key))
        return err
    }
    return nil
}

用户缓存

// redis/user_cache.go - 用户缓存
package redis

import (
    "context"
    "fmt"
    "time"
    
    "photography-backend/internal/model/entity"
    "photography-backend/pkg/logger"
)

// UserCache stores entity.User values in Redis under two key families,
// "user:id:<id>" and "user:email:<email>", each written with the same TTL.
type UserCache struct {
    // Embedded to reuse Set/Get/Delete/DeletePattern.
    *CacheManager
    // ttl is applied to every entry this cache writes.
    ttl time.Duration
}

// NewUserCache returns a UserCache that writes through cm and expires every
// entry after ttl.
func NewUserCache(cm *CacheManager, ttl time.Duration) *UserCache {
    uc := new(UserCache)
    uc.CacheManager = cm
    uc.ttl = ttl
    return uc
}

// SetUser caches user under "user:id:<user.ID>".
func (uc *UserCache) SetUser(ctx context.Context, user *entity.User) error {
    return uc.Set(ctx, fmt.Sprintf("user:id:%d", user.ID), user, uc.ttl)
}

// GetUser reads the user cached under "user:id:<id>". A cache miss surfaces
// as ErrCacheNotFound from CacheManager.Get.
func (uc *UserCache) GetUser(ctx context.Context, id uint) (*entity.User, error) {
    user := new(entity.User)
    if err := uc.Get(ctx, fmt.Sprintf("user:id:%d", id), user); err != nil {
        return nil, err
    }
    return user, nil
}

// SetUserByEmail caches user under "user:email:<user.Email>".
func (uc *UserCache) SetUserByEmail(ctx context.Context, user *entity.User) error {
    return uc.Set(ctx, fmt.Sprintf("user:email:%s", user.Email), user, uc.ttl)
}

// GetUserByEmail reads the user cached under "user:email:<email>". The email
// must match the cached key exactly; no normalization is applied.
func (uc *UserCache) GetUserByEmail(ctx context.Context, email string) (*entity.User, error) {
    user := new(entity.User)
    if err := uc.Get(ctx, fmt.Sprintf("user:email:%s", email), user); err != nil {
        return nil, err
    }
    return user, nil
}

// DeleteUser drops the "user:id:<id>" entry.
func (uc *UserCache) DeleteUser(ctx context.Context, id uint) error {
    return uc.Delete(ctx, fmt.Sprintf("user:id:%d", id))
}

// DeleteUserByEmail drops the "user:email:<email>" entry.
func (uc *UserCache) DeleteUserByEmail(ctx context.Context, email string) error {
    return uc.Delete(ctx, fmt.Sprintf("user:email:%s", email))
}

// InvalidateUserCache removes the cache entries for userID.
//
// The pattern "user:*:<id>" only reaches keys that end in the ID, such as
// "user:id:<id>". The email-keyed entry "user:email:<email>" cannot be
// derived from the ID, so the cached user is read first to learn its email
// and delete that key as well. This lookup is best-effort: on a miss or a
// decode error it is skipped, and the pattern delete still runs.
// NOTE(review): if the id entry has already expired while the email entry
// survives, the email entry is still missed. Callers that know the email
// should also call DeleteUserByEmail.
func (uc *UserCache) InvalidateUserCache(ctx context.Context, userID uint) error {
    if user, err := uc.GetUser(ctx, userID); err == nil && user.Email != "" {
        if err := uc.DeleteUserByEmail(ctx, user.Email); err != nil {
            return err
        }
    }

    pattern := fmt.Sprintf("user:*:%d", userID)
    return uc.DeletePattern(ctx, pattern)
}

🔍 错误处理

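仓储层错误定义

注意:下面的错误声明在 repository 包中(errors.go),但上文 postgres、redis 子包的示例代码直接写作 ErrNotFound / ErrCacheNotFound。实际代码需要写成 repository.ErrNotFound 等(由子包导入 repository 包),或在各子包中重新导出这些错误变量。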

// errors.go - 仓储层错误定义
package repository

import "errors"

var (
    // 通用错误
    ErrNotFound         = errors.New("record not found")
    ErrDuplicateKey     = errors.New("duplicate key")
    ErrInvalidParameter = errors.New("invalid parameter")
    ErrDatabaseError    = errors.New("database error")
    ErrTransactionError = errors.New("transaction error")
    
    // 缓存错误
    ErrCacheNotFound    = errors.New("cache not found")
    ErrCacheError       = errors.New("cache error")
    ErrCacheExpired     = errors.New("cache expired")
    
    // 连接错误
    ErrConnectionFailed = errors.New("connection failed")
    ErrConnectionTimeout = errors.New("connection timeout")
)

🧪 测试

仓储测试

// postgres/user_repository_test.go - 用户仓储测试
package postgres

import (
    "context"
    "testing"
    
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/suite"
    "gorm.io/driver/sqlite"
    "gorm.io/gorm"
    
    "photography-backend/internal/model/entity"
    "photography-backend/pkg/logger"
)

// UserRepositoryTestSuite exercises UserRepository against an in-memory
// SQLite database that SetupTest recreates before every test.
type UserRepositoryTestSuite struct {
    suite.Suite
    // db is the per-test GORM handle; TearDownTest closes it.
    db   *gorm.DB
    // repo is the repository under test, built on db.
    repo *UserRepository
}

// SetupTest runs before every test. It opens a fresh in-memory SQLite
// database, migrates the User schema and builds the repository under test.
func (suite *UserRepositoryTestSuite) SetupTest() {
    db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
    suite.Require().NoError(err)

    // Each new connection to ":memory:" opens its own empty database. Pin the
    // pool to one connection so the migrated schema, and the rows a test
    // writes, are visible to every query the test makes.
    sqlDB, err := db.DB()
    suite.Require().NoError(err)
    sqlDB.SetMaxOpenConns(1)

    err = db.AutoMigrate(&entity.User{})
    suite.Require().NoError(err)

    suite.db = db
    suite.repo = NewUserRepository(db, logger.NewNoop()).(*UserRepository)
}

// TearDownTest runs after every test and closes the per-test database.
// Errors are reported instead of ignored: the original discarded the DB()
// error and would then dereference a nil *sql.DB.
func (suite *UserRepositoryTestSuite) TearDownTest() {
    // SetupTest may have failed before assigning db.
    if suite.db == nil {
        return
    }
    sqlDB, err := suite.db.DB()
    suite.Require().NoError(err)
    suite.Require().NoError(sqlDB.Close())
    suite.db = nil
}

// TestCreateUser verifies that Create persists a user, assigns it a primary
// key, and returns the stored values.
func (suite *UserRepositoryTestSuite) TestCreateUser() {
    ctx := context.Background()

    user := &entity.User{
        Username: "testuser",
        Email:    "test@example.com",
        Password: "hashedpassword",
        Role:     entity.UserRoleUser,
        Status:   entity.UserStatusActive,
    }

    createdUser, err := suite.repo.Create(ctx, user)

    // Require (fatal) rather than assert: on error createdUser is nil, and
    // the field checks below would panic instead of failing cleanly.
    suite.Require().NoError(err)
    assert.NotZero(suite.T(), createdUser.ID)
    assert.Equal(suite.T(), user.Username, createdUser.Username)
    assert.Equal(suite.T(), user.Email, createdUser.Email)
}

// TestGetUserByEmail verifies that a freshly created user can be fetched
// back by its email address.
func (suite *UserRepositoryTestSuite) TestGetUserByEmail() {
    ctx := context.Background()

    // Arrange: insert the user to look up.
    user := &entity.User{
        Username: "testuser",
        Email:    "test@example.com",
        Password: "hashedpassword",
        Role:     entity.UserRoleUser,
        Status:   entity.UserStatusActive,
    }

    createdUser, err := suite.repo.Create(ctx, user)
    suite.Require().NoError(err)

    // Act: look the user up by email.
    foundUser, err := suite.repo.GetByEmail(ctx, user.Email)

    // Require (fatal) so a failed lookup stops the test instead of
    // dereferencing a nil foundUser below.
    suite.Require().NoError(err)
    assert.Equal(suite.T(), createdUser.ID, foundUser.ID)
    assert.Equal(suite.T(), createdUser.Email, foundUser.Email)
}

func TestUserRepositoryTestSuite(t *testing.T) {
    suite.Run(t, new(UserRepositoryTestSuite))
}

💡 最佳实践

设计原则

  1. 接口隔离: 定义清晰的仓储接口
  2. 依赖倒置: 依赖接口而非具体实现
  3. 单一职责: 每个仓储只负责一个实体
  4. 错误处理: 统一错误处理和日志记录
  5. 事务支持: 提供事务操作支持

性能优化

  1. 查询优化: 使用适当的索引和查询条件
  2. 批量操作: 支持批量插入和更新
  3. 缓存策略: 合理使用缓存减少数据库访问
  4. 连接池: 使用连接池管理数据库连接
  5. 预加载: 避免 N+1 查询问题

测试策略

  1. 单元测试: 为每个仓储编写单元测试
  2. 集成测试: 测试数据库交互
  3. 模拟对象: 使用 Mock 对象进行测试
  4. 测试数据: 准备充分的测试数据
  5. 性能测试: 测试查询性能和并发性能

本模块是数据访问的核心,确保数据操作的正确性和性能是关键。