873 lines
27 KiB
Markdown
873 lines
27 KiB
Markdown
# Repository Layer - CLAUDE.md
|
|
|
|
本文件为 Claude Code 在数据访问层中工作时提供指导。
|
|
|
|
## 🎯 模块概览
|
|
|
|
Repository 层负责数据访问逻辑,提供数据库操作的抽象接口,隔离业务逻辑与数据存储细节。
|
|
|
|
### 主要职责
|
|
- 🔧 提供数据库操作接口
|
|
- 📊 实现 CRUD 操作
|
|
- 🔍 提供复杂查询支持
|
|
- 💾 管理数据库连接和事务
|
|
- 🚀 优化查询性能
|
|
- 🔄 支持多种数据源
|
|
|
|
## 📁 模块结构
|
|
|
|
```
|
|
internal/repository/
|
|
├── CLAUDE.md # 📋 当前文件 - 数据访问开发指导
|
|
├── interfaces/ # 🔗 仓储接口定义
|
|
│ ├── base_repository.go # 基础仓储接口
│ ├── user_repository.go # 用户仓储接口
|
|
│ ├── photo_repository.go # 照片仓储接口
|
|
│ ├── category_repository.go # 分类仓储接口
|
|
│ ├── tag_repository.go # 标签仓储接口
|
|
│ └── album_repository.go # 相册仓储接口
|
|
├── postgres/ # 🐘 PostgreSQL 实现
|
|
│ ├── user_repository.go # 用户仓储实现
|
|
│ ├── photo_repository.go # 照片仓储实现
|
|
│ ├── category_repository.go # 分类仓储实现
|
|
│ ├── tag_repository.go # 标签仓储实现
|
|
│ ├── album_repository.go # 相册仓储实现
|
|
│ └── base_repository.go # 基础仓储实现
|
|
├── redis/ # 🟥 Redis 缓存实现
|
|
│ ├── user_cache.go # 用户缓存
|
|
│ ├── photo_cache.go # 照片缓存
|
|
│ └── cache_manager.go # 缓存管理器
|
|
├── sqlite/ # 📦 SQLite 实现(开发用)
|
|
│ ├── user_repository.go # 用户仓储实现
|
|
│ ├── photo_repository.go # 照片仓储实现
|
|
│ └── base_repository.go # 基础仓储实现
|
|
├── mocks/ # 🧪 模拟对象(测试用)
|
|
│ ├── user_repository_mock.go # 用户仓储模拟
|
|
│ ├── photo_repository_mock.go # 照片仓储模拟
|
|
│ └── generate.go # 模拟对象生成
|
|
└── errors.go # 📝 仓储层错误定义
|
|
```
|
|
|
|
## 🔗 接口设计
|
|
|
|
### 基础仓储接口
|
|
```go
|
|
// interfaces/base_repository.go - 基础仓储接口
|
|
package interfaces
|
|
|
|
import (
|
|
"context"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// BaseRepositoryr 基础仓储接口
|
|
type BaseRepositoryr[T any] interface {
|
|
// 基础 CRUD 操作
|
|
Create(ctx context.Context, entity *T) (*T, error)
|
|
GetByID(ctx context.Context, id uint) (*T, error)
|
|
Update(ctx context.Context, entity *T) (*T, error)
|
|
Delete(ctx context.Context, id uint) error
|
|
|
|
// 批量操作
|
|
CreateBatch(ctx context.Context, entities []*T) error
|
|
UpdateBatch(ctx context.Context, entities []*T) error
|
|
DeleteBatch(ctx context.Context, ids []uint) error
|
|
|
|
// 查询操作
|
|
List(ctx context.Context, opts *ListOptions) ([]*T, int64, error)
|
|
Count(ctx context.Context, conditions map[string]interface{}) (int64, error)
|
|
Exists(ctx context.Context, conditions map[string]interface{}) (bool, error)
|
|
|
|
// 事务操作
|
|
WithTx(tx *gorm.DB) BaseRepositoryr[T]
|
|
Transaction(ctx context.Context, fn func(repo BaseRepositoryr[T]) error) error
|
|
}
|
|
|
|
// ListOptions carries the filtering, sorting, pagination and preloading
// parameters accepted by list-style repository queries.
type ListOptions struct {
	// Page is the 1-based page number and Limit the page size.
	Page  int `json:"page"`
	Limit int `json:"limit"`
	// Sort is the column to order by; Order selects the direction.
	Sort  string `json:"sort"`
	Order string `json:"order"`
	// Conditions maps a query fragment to the value bound to it.
	Conditions map[string]any `json:"conditions"`
	// Preloads lists the associations to load with each result.
	Preloads []string `json:"preloads"`
}
|
|
```
|
|
|
|
### 用户仓储接口
|
|
```go
|
|
// interfaces/user_repository.go - 用户仓储接口
|
|
package interfaces
|
|
|
|
import (
|
|
"context"
|
|
"photography-backend/internal/model/entity"
|
|
)
|
|
|
|
// UserRepositoryr 用户仓储接口
|
|
type UserRepositoryr interface {
|
|
BaseRepositoryr[entity.User]
|
|
|
|
// 用户特定查询
|
|
GetByEmail(ctx context.Context, email string) (*entity.User, error)
|
|
GetByUsername(ctx context.Context, username string) (*entity.User, error)
|
|
GetByEmailOrUsername(ctx context.Context, emailOrUsername string) (*entity.User, error)
|
|
|
|
// 用户列表查询
|
|
ListByRole(ctx context.Context, role entity.UserRole, opts *ListOptions) ([]*entity.User, int64, error)
|
|
ListByStatus(ctx context.Context, status entity.UserStatus, opts *ListOptions) ([]*entity.User, int64, error)
|
|
SearchUsers(ctx context.Context, keyword string, opts *ListOptions) ([]*entity.User, int64, error)
|
|
|
|
// 用户统计
|
|
CountByRole(ctx context.Context, role entity.UserRole) (int64, error)
|
|
CountByStatus(ctx context.Context, status entity.UserStatus) (int64, error)
|
|
CountActiveUsers(ctx context.Context) (int64, error)
|
|
|
|
// 用户状态更新
|
|
UpdateStatus(ctx context.Context, id uint, status entity.UserStatus) error
|
|
UpdateLastLogin(ctx context.Context, id uint) error
|
|
|
|
// 密码相关
|
|
UpdatePassword(ctx context.Context, id uint, hashedPassword string) error
|
|
|
|
// 软删除恢复
|
|
Restore(ctx context.Context, id uint) error
|
|
}
|
|
```
|
|
|
|
### 照片仓储接口
|
|
```go
|
|
// interfaces/photo_repository.go - 照片仓储接口
|
|
package interfaces
|
|
|
|
import (
|
|
"context"
|
|
"time"
|
|
"photography-backend/internal/model/entity"
|
|
)
|
|
|
|
// PhotoRepositoryr 照片仓储接口
|
|
type PhotoRepositoryr interface {
|
|
BaseRepositoryr[entity.Photo]
|
|
|
|
// 照片查询
|
|
GetByFilename(ctx context.Context, filename string) (*entity.Photo, error)
|
|
GetByUserID(ctx context.Context, userID uint) ([]*entity.Photo, error)
|
|
ListByUserID(ctx context.Context, userID uint, opts *ListOptions) ([]*entity.Photo, int64, error)
|
|
ListByStatus(ctx context.Context, status entity.PhotoStatus, opts *ListOptions) ([]*entity.Photo, int64, error)
|
|
|
|
// 分类和标签查询
|
|
ListByCategory(ctx context.Context, categoryID uint, opts *ListOptions) ([]*entity.Photo, int64, error)
|
|
ListByTag(ctx context.Context, tagID uint, opts *ListOptions) ([]*entity.Photo, int64, error)
|
|
ListByAlbum(ctx context.Context, albumID uint, opts *ListOptions) ([]*entity.Photo, int64, error)
|
|
|
|
// 搜索功能
|
|
SearchPhotos(ctx context.Context, keyword string, opts *SearchOptions) ([]*entity.Photo, int64, error)
|
|
SearchByMetadata(ctx context.Context, metadata map[string]interface{}, opts *ListOptions) ([]*entity.Photo, int64, error)
|
|
|
|
// 时间范围查询
|
|
ListByDateRange(ctx context.Context, startDate, endDate time.Time, opts *ListOptions) ([]*entity.Photo, int64, error)
|
|
ListByCreatedDateRange(ctx context.Context, startDate, endDate time.Time, opts *ListOptions) ([]*entity.Photo, int64, error)
|
|
|
|
// 统计查询
|
|
CountByUser(ctx context.Context, userID uint) (int64, error)
|
|
CountByStatus(ctx context.Context, status entity.PhotoStatus) (int64, error)
|
|
CountByCategory(ctx context.Context, categoryID uint) (int64, error)
|
|
CountByTag(ctx context.Context, tagID uint) (int64, error)
|
|
|
|
// 关联操作
|
|
AddCategories(ctx context.Context, photoID uint, categoryIDs []uint) error
|
|
RemoveCategories(ctx context.Context, photoID uint, categoryIDs []uint) error
|
|
AddTags(ctx context.Context, photoID uint, tagIDs []uint) error
|
|
RemoveTags(ctx context.Context, photoID uint, tagIDs []uint) error
|
|
|
|
// 统计更新
|
|
IncrementViewCount(ctx context.Context, id uint) error
|
|
IncrementDownloadCount(ctx context.Context, id uint) error
|
|
|
|
// 批量操作
|
|
BatchUpdateStatus(ctx context.Context, ids []uint, status entity.PhotoStatus) error
|
|
BatchDelete(ctx context.Context, ids []uint) error
|
|
}
|
|
|
|
// SearchOptions 搜索选项
|
|
type SearchOptions struct {
|
|
ListOptions
|
|
Fields []string `json:"fields"`
|
|
DateFrom *time.Time `json:"date_from"`
|
|
DateTo *time.Time `json:"date_to"`
|
|
MinWidth int `json:"min_width"`
|
|
MaxWidth int `json:"max_width"`
|
|
MinHeight int `json:"min_height"`
|
|
MaxHeight int `json:"max_height"`
|
|
}
|
|
```
|
|
|
|
## 🔧 仓储实现
|
|
|
|
### 基础仓储实现
|
|
```go
|
|
// postgres/base_repository.go - 基础仓储实现
|
|
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
|
|
"gorm.io/gorm"
|
|
"go.uber.org/zap"
|
|
|
|
"photography-backend/internal/repository/interfaces"
|
|
"photography-backend/pkg/logger"
|
|
)
|
|
|
|
// BaseRepository 基础仓储实现
|
|
type BaseRepository[T any] struct {
|
|
db *gorm.DB
|
|
logger logger.Logger
|
|
}
|
|
|
|
// NewBaseRepository 创建基础仓储
|
|
func NewBaseRepository[T any](db *gorm.DB, logger logger.Logger) *BaseRepository[T] {
|
|
return &BaseRepository[T]{
|
|
db: db,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// Create 创建记录
|
|
func (r *BaseRepository[T]) Create(ctx context.Context, entity *T) (*T, error) {
|
|
if err := r.db.WithContext(ctx).Create(entity).Error; err != nil {
|
|
r.logger.Error("failed to create entity", zap.Error(err))
|
|
return nil, err
|
|
}
|
|
return entity, nil
|
|
}
|
|
|
|
// GetByID 根据ID获取记录
|
|
func (r *BaseRepository[T]) GetByID(ctx context.Context, id uint) (*T, error) {
|
|
var entity T
|
|
if err := r.db.WithContext(ctx).First(&entity, id).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrNotFound
|
|
}
|
|
r.logger.Error("failed to get entity by id", zap.Error(err), zap.Uint("id", id))
|
|
return nil, err
|
|
}
|
|
return &entity, nil
|
|
}
|
|
|
|
// Update 更新记录
|
|
func (r *BaseRepository[T]) Update(ctx context.Context, entity *T) (*T, error) {
|
|
if err := r.db.WithContext(ctx).Save(entity).Error; err != nil {
|
|
r.logger.Error("failed to update entity", zap.Error(err))
|
|
return nil, err
|
|
}
|
|
return entity, nil
|
|
}
|
|
|
|
// Delete 删除记录
|
|
func (r *BaseRepository[T]) Delete(ctx context.Context, id uint) error {
|
|
var entity T
|
|
if err := r.db.WithContext(ctx).Delete(&entity, id).Error; err != nil {
|
|
r.logger.Error("failed to delete entity", zap.Error(err), zap.Uint("id", id))
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// List 列表查询
|
|
func (r *BaseRepository[T]) List(ctx context.Context, opts *interfaces.ListOptions) ([]*T, int64, error) {
|
|
var entities []*T
|
|
var total int64
|
|
|
|
db := r.db.WithContext(ctx)
|
|
|
|
// 应用条件
|
|
if opts.Conditions != nil {
|
|
for key, value := range opts.Conditions {
|
|
db = db.Where(key, value)
|
|
}
|
|
}
|
|
|
|
// 获取总数
|
|
if err := db.Model(new(T)).Count(&total).Error; err != nil {
|
|
r.logger.Error("failed to count entities", zap.Error(err))
|
|
return nil, 0, err
|
|
}
|
|
|
|
// 应用排序
|
|
if opts.Sort != "" {
|
|
order := "ASC"
|
|
if opts.Order == "desc" {
|
|
order = "DESC"
|
|
}
|
|
db = db.Order(fmt.Sprintf("%s %s", opts.Sort, order))
|
|
}
|
|
|
|
// 应用分页
|
|
if opts.Page > 0 && opts.Limit > 0 {
|
|
offset := (opts.Page - 1) * opts.Limit
|
|
db = db.Offset(offset).Limit(opts.Limit)
|
|
}
|
|
|
|
// 应用预加载
|
|
if opts.Preloads != nil {
|
|
for _, preload := range opts.Preloads {
|
|
db = db.Preload(preload)
|
|
}
|
|
}
|
|
|
|
// 查询数据
|
|
if err := db.Find(&entities).Error; err != nil {
|
|
r.logger.Error("failed to list entities", zap.Error(err))
|
|
return nil, 0, err
|
|
}
|
|
|
|
return entities, total, nil
|
|
}
|
|
|
|
// Transaction 事务执行
|
|
func (r *BaseRepository[T]) Transaction(ctx context.Context, fn func(repo interfaces.BaseRepositoryr[T]) error) error {
|
|
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
repo := &BaseRepository[T]{
|
|
db: tx,
|
|
logger: r.logger,
|
|
}
|
|
return fn(repo)
|
|
})
|
|
}
|
|
|
|
// WithTx 使用事务
|
|
func (r *BaseRepository[T]) WithTx(tx *gorm.DB) interfaces.BaseRepositoryr[T] {
|
|
return &BaseRepository[T]{
|
|
db: tx,
|
|
logger: r.logger,
|
|
}
|
|
}
|
|
```
|
|
|
|
### 用户仓储实现
|
|
```go
|
|
// postgres/user_repository.go - 用户仓储实现
|
|
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"strings"
|
|
"time"
|
|
|
|
"gorm.io/gorm"
|
|
"go.uber.org/zap"
|
|
|
|
"photography-backend/internal/model/entity"
|
|
"photography-backend/internal/repository/interfaces"
|
|
"photography-backend/pkg/logger"
|
|
)
|
|
|
|
// UserRepository 用户仓储实现
|
|
type UserRepository struct {
|
|
*BaseRepository[entity.User]
|
|
db *gorm.DB
|
|
logger logger.Logger
|
|
}
|
|
|
|
// NewUserRepository 创建用户仓储
|
|
func NewUserRepository(db *gorm.DB, logger logger.Logger) interfaces.UserRepositoryr {
|
|
return &UserRepository{
|
|
BaseRepository: NewBaseRepository[entity.User](db, logger),
|
|
db: db,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// GetByEmail 根据邮箱获取用户
|
|
func (r *UserRepository) GetByEmail(ctx context.Context, email string) (*entity.User, error) {
|
|
var user entity.User
|
|
err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrNotFound
|
|
}
|
|
r.logger.Error("failed to get user by email", zap.Error(err), zap.String("email", email))
|
|
return nil, err
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
// GetByUsername 根据用户名获取用户
|
|
func (r *UserRepository) GetByUsername(ctx context.Context, username string) (*entity.User, error) {
|
|
var user entity.User
|
|
err := r.db.WithContext(ctx).Where("username = ?", username).First(&user).Error
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrNotFound
|
|
}
|
|
r.logger.Error("failed to get user by username", zap.Error(err), zap.String("username", username))
|
|
return nil, err
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
// GetByEmailOrUsername 根据邮箱或用户名获取用户
|
|
func (r *UserRepository) GetByEmailOrUsername(ctx context.Context, emailOrUsername string) (*entity.User, error) {
|
|
var user entity.User
|
|
err := r.db.WithContext(ctx).Where("email = ? OR username = ?", emailOrUsername, emailOrUsername).First(&user).Error
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrNotFound
|
|
}
|
|
r.logger.Error("failed to get user by email or username", zap.Error(err), zap.String("emailOrUsername", emailOrUsername))
|
|
return nil, err
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
// SearchUsers 搜索用户
|
|
func (r *UserRepository) SearchUsers(ctx context.Context, keyword string, opts *interfaces.ListOptions) ([]*entity.User, int64, error) {
|
|
var users []*entity.User
|
|
var total int64
|
|
|
|
db := r.db.WithContext(ctx)
|
|
|
|
// 搜索条件
|
|
searchCondition := fmt.Sprintf("username ILIKE %s OR email ILIKE %s OR first_name ILIKE %s OR last_name ILIKE %s",
|
|
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
|
|
db = db.Where(searchCondition)
|
|
|
|
// 应用其他条件
|
|
if opts.Conditions != nil {
|
|
for key, value := range opts.Conditions {
|
|
db = db.Where(key, value)
|
|
}
|
|
}
|
|
|
|
// 获取总数
|
|
if err := db.Model(&entity.User{}).Count(&total).Error; err != nil {
|
|
r.logger.Error("failed to count users", zap.Error(err))
|
|
return nil, 0, err
|
|
}
|
|
|
|
// 应用排序和分页
|
|
if opts.Sort != "" {
|
|
order := "ASC"
|
|
if opts.Order == "desc" {
|
|
order = "DESC"
|
|
}
|
|
db = db.Order(fmt.Sprintf("%s %s", opts.Sort, order))
|
|
}
|
|
|
|
if opts.Page > 0 && opts.Limit > 0 {
|
|
offset := (opts.Page - 1) * opts.Limit
|
|
db = db.Offset(offset).Limit(opts.Limit)
|
|
}
|
|
|
|
// 查询数据
|
|
if err := db.Find(&users).Error; err != nil {
|
|
r.logger.Error("failed to search users", zap.Error(err))
|
|
return nil, 0, err
|
|
}
|
|
|
|
return users, total, nil
|
|
}
|
|
|
|
// UpdateStatus 更新用户状态
|
|
func (r *UserRepository) UpdateStatus(ctx context.Context, id uint, status entity.UserStatus) error {
|
|
err := r.db.WithContext(ctx).Model(&entity.User{}).Where("id = ?", id).Update("status", status).Error
|
|
if err != nil {
|
|
r.logger.Error("failed to update user status", zap.Error(err), zap.Uint("id", id), zap.String("status", string(status)))
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateLastLogin 更新最后登录时间
|
|
func (r *UserRepository) UpdateLastLogin(ctx context.Context, id uint) error {
|
|
now := time.Now()
|
|
err := r.db.WithContext(ctx).Model(&entity.User{}).Where("id = ?", id).Update("last_login_at", now).Error
|
|
if err != nil {
|
|
r.logger.Error("failed to update last login time", zap.Error(err), zap.Uint("id", id))
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdatePassword 更新密码
|
|
func (r *UserRepository) UpdatePassword(ctx context.Context, id uint, hashedPassword string) error {
|
|
err := r.db.WithContext(ctx).Model(&entity.User{}).Where("id = ?", id).Update("password", hashedPassword).Error
|
|
if err != nil {
|
|
r.logger.Error("failed to update password", zap.Error(err), zap.Uint("id", id))
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// CountByRole 按角色统计用户数
|
|
func (r *UserRepository) CountByRole(ctx context.Context, role entity.UserRole) (int64, error) {
|
|
var count int64
|
|
err := r.db.WithContext(ctx).Model(&entity.User{}).Where("role = ?", role).Count(&count).Error
|
|
if err != nil {
|
|
r.logger.Error("failed to count users by role", zap.Error(err), zap.String("role", string(role)))
|
|
return 0, err
|
|
}
|
|
return count, nil
|
|
}
|
|
|
|
// CountActiveUsers 统计活跃用户数
|
|
func (r *UserRepository) CountActiveUsers(ctx context.Context) (int64, error) {
|
|
var count int64
|
|
err := r.db.WithContext(ctx).Model(&entity.User{}).Where("status = ?", entity.UserStatusActive).Count(&count).Error
|
|
if err != nil {
|
|
r.logger.Error("failed to count active users", zap.Error(err))
|
|
return 0, err
|
|
}
|
|
return count, nil
|
|
}
|
|
```
|
|
|
|
## 💾 Redis 缓存实现
|
|
|
|
### 缓存管理器
|
|
```go
|
|
// redis/cache_manager.go - 缓存管理器
|
|
package redis
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/go-redis/redis/v8"
|
|
"go.uber.org/zap"
|
|
|
|
"photography-backend/pkg/logger"
|
|
)
|
|
|
|
// CacheManager 缓存管理器
|
|
type CacheManager struct {
|
|
client *redis.Client
|
|
logger logger.Logger
|
|
}
|
|
|
|
// NewCacheManager 创建缓存管理器
|
|
func NewCacheManager(client *redis.Client, logger logger.Logger) *CacheManager {
|
|
return &CacheManager{
|
|
client: client,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// Set 设置缓存
|
|
func (cm *CacheManager) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
|
|
data, err := json.Marshal(value)
|
|
if err != nil {
|
|
cm.logger.Error("failed to marshal cache value", zap.Error(err), zap.String("key", key))
|
|
return err
|
|
}
|
|
|
|
err = cm.client.Set(ctx, key, data, ttl).Err()
|
|
if err != nil {
|
|
cm.logger.Error("failed to set cache", zap.Error(err), zap.String("key", key))
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Get 获取缓存
|
|
func (cm *CacheManager) Get(ctx context.Context, key string, dest interface{}) error {
|
|
data, err := cm.client.Get(ctx, key).Result()
|
|
if err != nil {
|
|
if err == redis.Nil {
|
|
return ErrCacheNotFound
|
|
}
|
|
cm.logger.Error("failed to get cache", zap.Error(err), zap.String("key", key))
|
|
return err
|
|
}
|
|
|
|
err = json.Unmarshal([]byte(data), dest)
|
|
if err != nil {
|
|
cm.logger.Error("failed to unmarshal cache value", zap.Error(err), zap.String("key", key))
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Delete 删除缓存
|
|
func (cm *CacheManager) Delete(ctx context.Context, key string) error {
|
|
err := cm.client.Del(ctx, key).Err()
|
|
if err != nil {
|
|
cm.logger.Error("failed to delete cache", zap.Error(err), zap.String("key", key))
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DeletePattern 批量删除缓存
|
|
func (cm *CacheManager) DeletePattern(ctx context.Context, pattern string) error {
|
|
keys, err := cm.client.Keys(ctx, pattern).Result()
|
|
if err != nil {
|
|
cm.logger.Error("failed to get keys by pattern", zap.Error(err), zap.String("pattern", pattern))
|
|
return err
|
|
}
|
|
|
|
if len(keys) > 0 {
|
|
err = cm.client.Del(ctx, keys...).Err()
|
|
if err != nil {
|
|
cm.logger.Error("failed to delete keys by pattern", zap.Error(err), zap.String("pattern", pattern))
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Exists 检查缓存是否存在
|
|
func (cm *CacheManager) Exists(ctx context.Context, key string) (bool, error) {
|
|
count, err := cm.client.Exists(ctx, key).Result()
|
|
if err != nil {
|
|
cm.logger.Error("failed to check cache existence", zap.Error(err), zap.String("key", key))
|
|
return false, err
|
|
}
|
|
return count > 0, nil
|
|
}
|
|
|
|
// SetTTL 设置过期时间
|
|
func (cm *CacheManager) SetTTL(ctx context.Context, key string, ttl time.Duration) error {
|
|
err := cm.client.Expire(ctx, key, ttl).Err()
|
|
if err != nil {
|
|
cm.logger.Error("failed to set cache ttl", zap.Error(err), zap.String("key", key))
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
```
|
|
|
|
### 用户缓存
|
|
```go
|
|
// redis/user_cache.go - 用户缓存
|
|
package redis
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"photography-backend/internal/model/entity"
|
|
"photography-backend/pkg/logger"
|
|
)
|
|
|
|
// UserCache 用户缓存
|
|
type UserCache struct {
|
|
*CacheManager
|
|
ttl time.Duration
|
|
}
|
|
|
|
// NewUserCache 创建用户缓存
|
|
func NewUserCache(cm *CacheManager, ttl time.Duration) *UserCache {
|
|
return &UserCache{
|
|
CacheManager: cm,
|
|
ttl: ttl,
|
|
}
|
|
}
|
|
|
|
// SetUser 缓存用户
|
|
func (uc *UserCache) SetUser(ctx context.Context, user *entity.User) error {
|
|
key := fmt.Sprintf("user:id:%d", user.ID)
|
|
return uc.Set(ctx, key, user, uc.ttl)
|
|
}
|
|
|
|
// GetUser 获取用户缓存
|
|
func (uc *UserCache) GetUser(ctx context.Context, id uint) (*entity.User, error) {
|
|
key := fmt.Sprintf("user:id:%d", id)
|
|
var user entity.User
|
|
err := uc.Get(ctx, key, &user)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
// SetUserByEmail 按邮箱缓存用户
|
|
func (uc *UserCache) SetUserByEmail(ctx context.Context, user *entity.User) error {
|
|
key := fmt.Sprintf("user:email:%s", user.Email)
|
|
return uc.Set(ctx, key, user, uc.ttl)
|
|
}
|
|
|
|
// GetUserByEmail 按邮箱获取用户缓存
|
|
func (uc *UserCache) GetUserByEmail(ctx context.Context, email string) (*entity.User, error) {
|
|
key := fmt.Sprintf("user:email:%s", email)
|
|
var user entity.User
|
|
err := uc.Get(ctx, key, &user)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
// DeleteUser 删除用户缓存
|
|
func (uc *UserCache) DeleteUser(ctx context.Context, id uint) error {
|
|
key := fmt.Sprintf("user:id:%d", id)
|
|
return uc.Delete(ctx, key)
|
|
}
|
|
|
|
// DeleteUserByEmail 按邮箱删除用户缓存
|
|
func (uc *UserCache) DeleteUserByEmail(ctx context.Context, email string) error {
|
|
key := fmt.Sprintf("user:email:%s", email)
|
|
return uc.Delete(ctx, key)
|
|
}
|
|
|
|
// InvalidateUserCache 失效用户相关缓存
|
|
func (uc *UserCache) InvalidateUserCache(ctx context.Context, userID uint) error {
|
|
pattern := fmt.Sprintf("user:*:%d", userID)
|
|
return uc.DeletePattern(ctx, pattern)
|
|
}
|
|
```
|
|
|
|
## 🔍 错误处理
|
|
|
|
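### 仓储层错误定义

> 注意:下面的错误变量定义在 `repository` 包中,而上文 `postgres`、`redis` 包的示例直接使用了未加限定的 `ErrNotFound`、`ErrCacheNotFound`。实际实现时需要导入 `photography-backend/internal/repository` 并以 `repository.ErrNotFound` 等形式引用,或在各子包中重新导出这些变量。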
|
|
```go
|
|
// errors.go - 仓储层错误定义
|
|
package repository
|
|
|
|
import "errors"
|
|
|
|
// Sentinel errors returned by the repository layer; compare with errors.Is.

// General persistence errors.
var (
	ErrNotFound         = errors.New("record not found")
	ErrDuplicateKey     = errors.New("duplicate key")
	ErrInvalidParameter = errors.New("invalid parameter")
	ErrDatabaseError    = errors.New("database error")
	ErrTransactionError = errors.New("transaction error")
)

// Cache errors.
var (
	ErrCacheNotFound = errors.New("cache not found")
	ErrCacheError    = errors.New("cache error")
	ErrCacheExpired  = errors.New("cache expired")
)

// Connection errors.
var (
	ErrConnectionFailed  = errors.New("connection failed")
	ErrConnectionTimeout = errors.New("connection timeout")
)
|
|
```
|
|
|
|
## 🧪 测试
|
|
|
|
### 仓储测试
|
|
```go
|
|
// postgres/user_repository_test.go - 用户仓储测试
|
|
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/suite"
|
|
"gorm.io/driver/sqlite"
|
|
"gorm.io/gorm"
|
|
|
|
"photography-backend/internal/model/entity"
|
|
"photography-backend/pkg/logger"
|
|
)
|
|
|
|
type UserRepositoryTestSuite struct {
|
|
suite.Suite
|
|
db *gorm.DB
|
|
repo *UserRepository
|
|
}
|
|
|
|
func (suite *UserRepositoryTestSuite) SetupTest() {
|
|
// 使用内存数据库进行测试
|
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
|
suite.Require().NoError(err)
|
|
|
|
// 自动迁移
|
|
err = db.AutoMigrate(&entity.User{})
|
|
suite.Require().NoError(err)
|
|
|
|
suite.db = db
|
|
suite.repo = NewUserRepository(db, logger.NewNoop()).(*UserRepository)
|
|
}
|
|
|
|
func (suite *UserRepositoryTestSuite) TearDownTest() {
|
|
sqlDB, _ := suite.db.DB()
|
|
sqlDB.Close()
|
|
}
|
|
|
|
func (suite *UserRepositoryTestSuite) TestCreateUser() {
|
|
ctx := context.Background()
|
|
|
|
user := &entity.User{
|
|
Username: "testuser",
|
|
Email: "test@example.com",
|
|
Password: "hashedpassword",
|
|
Role: entity.UserRoleUser,
|
|
Status: entity.UserStatusActive,
|
|
}
|
|
|
|
createdUser, err := suite.repo.Create(ctx, user)
|
|
|
|
assert.NoError(suite.T(), err)
|
|
assert.NotZero(suite.T(), createdUser.ID)
|
|
assert.Equal(suite.T(), user.Username, createdUser.Username)
|
|
assert.Equal(suite.T(), user.Email, createdUser.Email)
|
|
}
|
|
|
|
func (suite *UserRepositoryTestSuite) TestGetUserByEmail() {
|
|
ctx := context.Background()
|
|
|
|
// 创建测试用户
|
|
user := &entity.User{
|
|
Username: "testuser",
|
|
Email: "test@example.com",
|
|
Password: "hashedpassword",
|
|
Role: entity.UserRoleUser,
|
|
Status: entity.UserStatusActive,
|
|
}
|
|
|
|
createdUser, err := suite.repo.Create(ctx, user)
|
|
suite.Require().NoError(err)
|
|
|
|
// 根据邮箱获取用户
|
|
foundUser, err := suite.repo.GetByEmail(ctx, user.Email)
|
|
|
|
assert.NoError(suite.T(), err)
|
|
assert.Equal(suite.T(), createdUser.ID, foundUser.ID)
|
|
assert.Equal(suite.T(), createdUser.Email, foundUser.Email)
|
|
}
|
|
|
|
func TestUserRepositoryTestSuite(t *testing.T) {
|
|
suite.Run(t, new(UserRepositoryTestSuite))
|
|
}
|
|
```
|
|
|
|
## 💡 最佳实践
|
|
|
|
### 设计原则
|
|
1. **接口隔离**: 定义清晰的仓储接口
|
|
2. **依赖倒置**: 依赖接口而非具体实现
|
|
3. **单一职责**: 每个仓储只负责一个实体
|
|
4. **错误处理**: 统一错误处理和日志记录
|
|
5. **事务支持**: 提供事务操作支持
|
|
|
|
### 性能优化
|
|
1. **查询优化**: 使用适当的索引和查询条件
|
|
2. **批量操作**: 支持批量插入和更新
|
|
3. **缓存策略**: 合理使用缓存减少数据库访问
|
|
4. **连接池**: 使用连接池管理数据库连接
|
|
5. **预加载**: 避免 N+1 查询问题
|
|
|
|
### 测试策略
|
|
1. **单元测试**: 为每个仓储编写单元测试
|
|
2. **集成测试**: 测试数据库交互
|
|
3. **模拟对象**: 使用 Mock 对象进行测试
|
|
4. **测试数据**: 准备充分的测试数据
|
|
5. **性能测试**: 测试查询性能和并发性能
|
|
|
|
本模块是数据访问的核心,确保数据操作的正确性和性能是关键。 |