feat: 实现后端和管理后台基础架构

## 后端架构 (Go + Gin + GORM)
-  完整的分层架构 (API/Service/Repository)
-  PostgreSQL数据库设计和迁移脚本
-  JWT认证系统和权限控制
-  用户、照片、分类、标签等核心模型
-  中间件系统 (认证、CORS、日志)
-  配置管理和环境变量支持
-  结构化日志和错误处理
-  Makefile构建和部署脚本

## 管理后台架构 (React + TypeScript)
-  Vite + React 18 + TypeScript现代化架构
-  路由系统和状态管理 (Zustand + TanStack Query)
-  基于Radix UI的组件库基础
-  认证流程和权限控制
-  响应式设计和主题系统

## 数据库设计
-  用户表 (角色权限、认证信息)
-  照片表 (元数据、EXIF、状态管理)
-  分类表 (层级结构、封面图片)
-  标签表 (使用统计、标签云)
-  关联表 (照片-标签多对多)

## 技术特点
- 🚀 高性能: Gin框架 + GORM ORM
- 🔐 安全: JWT认证 + 密码加密 + 权限控制
- 📊 监控: 结构化日志 + 健康检查
- 🎨 现代化: React 18 + TypeScript + Vite
- 📱 响应式: Tailwind CSS + Radix UI

参考文档: docs/development/saved-docs/
This commit is contained in:
xujiang
2025-07-09 14:56:22 +08:00
parent 180fbd2ae9
commit c57ec3aa82
34 changed files with 3432 additions and 0 deletions

View File

@ -0,0 +1,211 @@
package postgres
import (
"fmt"
"photography-backend/internal/models"
"gorm.io/gorm"
)
// CategoryRepository 分类仓库接口
type CategoryRepository interface {
Create(category *models.Category) error
GetByID(id uint) (*models.Category, error)
Update(category *models.Category) error
Delete(id uint) error
List(params *models.CategoryListParams) ([]*models.Category, error)
GetTree() ([]*models.Category, error)
GetChildren(parentID uint) ([]*models.Category, error)
GetStats() (*models.CategoryStats, error)
UpdateSort(id uint, sort int) error
GetPhotoCount(id uint) (int64, error)
}
// categoryRepository 分类仓库实现
type categoryRepository struct {
db *gorm.DB
}
// NewCategoryRepository 创建分类仓库
func NewCategoryRepository(db *gorm.DB) CategoryRepository {
return &categoryRepository{db: db}
}
// Create 创建分类
func (r *categoryRepository) Create(category *models.Category) error {
if err := r.db.Create(category).Error; err != nil {
return fmt.Errorf("failed to create category: %w", err)
}
return nil
}
// GetByID 根据ID获取分类
func (r *categoryRepository) GetByID(id uint) (*models.Category, error) {
var category models.Category
if err := r.db.Preload("Parent").Preload("Children").
First(&category, id).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, fmt.Errorf("failed to get category by id: %w", err)
}
// 计算照片数量
var photoCount int64
if err := r.db.Model(&models.Photo{}).Where("category_id = ?", id).
Count(&photoCount).Error; err != nil {
return nil, fmt.Errorf("failed to count photos: %w", err)
}
category.PhotoCount = int(photoCount)
return &category, nil
}
// Update 更新分类
func (r *categoryRepository) Update(category *models.Category) error {
if err := r.db.Save(category).Error; err != nil {
return fmt.Errorf("failed to update category: %w", err)
}
return nil
}
// Delete 删除分类
func (r *categoryRepository) Delete(id uint) error {
// 开启事务
tx := r.db.Begin()
// 将子分类的父分类设置为NULL
if err := tx.Model(&models.Category{}).Where("parent_id = ?", id).
Update("parent_id", nil).Error; err != nil {
tx.Rollback()
return fmt.Errorf("failed to update child categories: %w", err)
}
// 删除分类
if err := tx.Delete(&models.Category{}, id).Error; err != nil {
tx.Rollback()
return fmt.Errorf("failed to delete category: %w", err)
}
return tx.Commit().Error
}
// List 获取分类列表
func (r *categoryRepository) List(params *models.CategoryListParams) ([]*models.Category, error) {
var categories []*models.Category
query := r.db.Model(&models.Category{})
// 添加过滤条件
if params.ParentID > 0 {
query = query.Where("parent_id = ?", params.ParentID)
}
if params.IsActive {
query = query.Where("is_active = ?", true)
}
if err := query.Order("sort ASC, created_at DESC").
Find(&categories).Error; err != nil {
return nil, fmt.Errorf("failed to list categories: %w", err)
}
// 如果需要包含统计信息
if params.IncludeStats {
for _, category := range categories {
var photoCount int64
if err := r.db.Model(&models.Photo{}).Where("category_id = ?", category.ID).
Count(&photoCount).Error; err != nil {
return nil, fmt.Errorf("failed to count photos for category %d: %w", category.ID, err)
}
category.PhotoCount = int(photoCount)
}
}
return categories, nil
}
// GetTree 获取分类树
func (r *categoryRepository) GetTree() ([]*models.Category, error) {
var categories []*models.Category
// 获取所有分类
if err := r.db.Where("is_active = ?", true).
Order("sort ASC, created_at DESC").
Find(&categories).Error; err != nil {
return nil, fmt.Errorf("failed to get categories: %w", err)
}
// 构建分类树
categoryMap := make(map[uint]*models.Category)
var rootCategories []*models.Category
// 第一次遍历:建立映射
for _, category := range categories {
categoryMap[category.ID] = category
category.Children = []*models.Category{}
}
// 第二次遍历:构建树形结构
for _, category := range categories {
if category.ParentID == nil {
rootCategories = append(rootCategories, category)
} else {
if parent, exists := categoryMap[*category.ParentID]; exists {
parent.Children = append(parent.Children, category)
}
}
}
return rootCategories, nil
}
// GetChildren 获取子分类
func (r *categoryRepository) GetChildren(parentID uint) ([]*models.Category, error) {
var categories []*models.Category
if err := r.db.Where("parent_id = ? AND is_active = ?", parentID, true).
Order("sort ASC, created_at DESC").
Find(&categories).Error; err != nil {
return nil, fmt.Errorf("failed to get child categories: %w", err)
}
return categories, nil
}
// GetStats 获取分类统计
func (r *categoryRepository) GetStats() (*models.CategoryStats, error) {
var stats models.CategoryStats
// 总分类数
if err := r.db.Model(&models.Category{}).Count(&stats.TotalCategories).Error; err != nil {
return nil, fmt.Errorf("failed to count total categories: %w", err)
}
// 计算最大层级
// 这里简化处理,实际应用中可能需要递归查询
stats.MaxLevel = 3
// 特色分类数量这里假设有一个is_featured字段实际可能需要调整
stats.FeaturedCount = 0
return &stats, nil
}
// UpdateSort 更新排序
func (r *categoryRepository) UpdateSort(id uint, sort int) error {
if err := r.db.Model(&models.Category{}).Where("id = ?", id).
Update("sort", sort).Error; err != nil {
return fmt.Errorf("failed to update sort: %w", err)
}
return nil
}
// GetPhotoCount 获取分类的照片数量
func (r *categoryRepository) GetPhotoCount(id uint) (int64, error) {
var count int64
if err := r.db.Model(&models.Photo{}).Where("category_id = ?", id).
Count(&count).Error; err != nil {
return 0, fmt.Errorf("failed to count photos: %w", err)
}
return count, nil
}

View File

@ -0,0 +1,78 @@
package postgres
import (
"fmt"
"time"
"gorm.io/gorm"
"gorm.io/driver/postgres"
"photography-backend/internal/config"
"photography-backend/internal/models"
)
// Database 数据库连接
type Database struct {
DB *gorm.DB
}
// NewDatabase 创建数据库连接
func NewDatabase(cfg *config.DatabaseConfig) (*Database, error) {
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
cfg.Host,
cfg.Port,
cfg.Username,
cfg.Password,
cfg.Database,
cfg.SSLMode,
)
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
// 获取底层sql.DB实例配置连接池
sqlDB, err := db.DB()
if err != nil {
return nil, fmt.Errorf("failed to get sql.DB instance: %w", err)
}
// 设置连接池参数
sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
sqlDB.SetConnMaxLifetime(time.Duration(cfg.ConnMaxLifetime) * time.Second)
// 测试连接
if err := sqlDB.Ping(); err != nil {
return nil, fmt.Errorf("failed to ping database: %w", err)
}
return &Database{DB: db}, nil
}
// AutoMigrate 自动迁移数据库表结构
func (d *Database) AutoMigrate() error {
return d.DB.AutoMigrate(
&models.User{},
&models.Category{},
&models.Tag{},
&models.Photo{},
)
}
// Close 关闭数据库连接
func (d *Database) Close() error {
sqlDB, err := d.DB.DB()
if err != nil {
return err
}
return sqlDB.Close()
}
// Health 检查数据库健康状态
func (d *Database) Health() error {
sqlDB, err := d.DB.DB()
if err != nil {
return err
}
return sqlDB.Ping()
}

View File

@ -0,0 +1,303 @@
package postgres
import (
"fmt"
"photography-backend/internal/models"
"gorm.io/gorm"
)
// PhotoRepository 照片仓库接口
type PhotoRepository interface {
Create(photo *models.Photo) error
GetByID(id uint) (*models.Photo, error)
Update(photo *models.Photo) error
Delete(id uint) error
List(params *models.PhotoListParams) ([]*models.Photo, int64, error)
GetByCategory(categoryID uint, page, limit int) ([]*models.Photo, int64, error)
GetByTag(tagID uint, page, limit int) ([]*models.Photo, int64, error)
GetByUser(userID uint, page, limit int) ([]*models.Photo, int64, error)
Search(query string, page, limit int) ([]*models.Photo, int64, error)
IncrementViewCount(id uint) error
IncrementLikeCount(id uint) error
UpdateStatus(id uint, status string) error
GetStats() (*PhotoStats, error)
}
// PhotoStats 照片统计
type PhotoStats struct {
Total int64 `json:"total"`
Published int64 `json:"published"`
Draft int64 `json:"draft"`
Archived int64 `json:"archived"`
}
// photoRepository 照片仓库实现
type photoRepository struct {
db *gorm.DB
}
// NewPhotoRepository 创建照片仓库
func NewPhotoRepository(db *gorm.DB) PhotoRepository {
return &photoRepository{db: db}
}
// Create 创建照片
func (r *photoRepository) Create(photo *models.Photo) error {
if err := r.db.Create(photo).Error; err != nil {
return fmt.Errorf("failed to create photo: %w", err)
}
return nil
}
// GetByID 根据ID获取照片
func (r *photoRepository) GetByID(id uint) (*models.Photo, error) {
var photo models.Photo
if err := r.db.Preload("Category").Preload("Tags").Preload("User").
First(&photo, id).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, fmt.Errorf("failed to get photo by id: %w", err)
}
return &photo, nil
}
// Update 更新照片
func (r *photoRepository) Update(photo *models.Photo) error {
if err := r.db.Save(photo).Error; err != nil {
return fmt.Errorf("failed to update photo: %w", err)
}
return nil
}
// Delete 删除照片
func (r *photoRepository) Delete(id uint) error {
if err := r.db.Delete(&models.Photo{}, id).Error; err != nil {
return fmt.Errorf("failed to delete photo: %w", err)
}
return nil
}
// List 获取照片列表
func (r *photoRepository) List(params *models.PhotoListParams) ([]*models.Photo, int64, error) {
var photos []*models.Photo
var total int64
query := r.db.Model(&models.Photo{}).
Preload("Category").
Preload("Tags").
Preload("User")
// 添加过滤条件
if params.CategoryID > 0 {
query = query.Where("category_id = ?", params.CategoryID)
}
if params.TagID > 0 {
query = query.Joins("JOIN photo_tags ON photos.id = photo_tags.photo_id").
Where("photo_tags.tag_id = ?", params.TagID)
}
if params.UserID > 0 {
query = query.Where("user_id = ?", params.UserID)
}
if params.Status != "" {
query = query.Where("status = ?", params.Status)
}
if params.Search != "" {
query = query.Where("title ILIKE ? OR description ILIKE ?",
"%"+params.Search+"%", "%"+params.Search+"%")
}
if params.Year > 0 {
query = query.Where("EXTRACT(YEAR FROM taken_at) = ?", params.Year)
}
if params.Month > 0 {
query = query.Where("EXTRACT(MONTH FROM taken_at) = ?", params.Month)
}
// 计算总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, fmt.Errorf("failed to count photos: %w", err)
}
// 排序
orderClause := fmt.Sprintf("%s %s", params.SortBy, params.SortOrder)
// 分页查询
offset := (params.Page - 1) * params.Limit
if err := query.Offset(offset).Limit(params.Limit).
Order(orderClause).
Find(&photos).Error; err != nil {
return nil, 0, fmt.Errorf("failed to list photos: %w", err)
}
return photos, total, nil
}
// GetByCategory 根据分类获取照片
func (r *photoRepository) GetByCategory(categoryID uint, page, limit int) ([]*models.Photo, int64, error) {
var photos []*models.Photo
var total int64
query := r.db.Model(&models.Photo{}).
Where("category_id = ? AND is_public = ?", categoryID, true).
Preload("Category").
Preload("Tags")
// 计算总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, fmt.Errorf("failed to count photos by category: %w", err)
}
// 分页查询
offset := (page - 1) * limit
if err := query.Offset(offset).Limit(limit).
Order("created_at DESC").
Find(&photos).Error; err != nil {
return nil, 0, fmt.Errorf("failed to get photos by category: %w", err)
}
return photos, total, nil
}
// GetByTag 根据标签获取照片
func (r *photoRepository) GetByTag(tagID uint, page, limit int) ([]*models.Photo, int64, error) {
var photos []*models.Photo
var total int64
query := r.db.Model(&models.Photo{}).
Joins("JOIN photo_tags ON photos.id = photo_tags.photo_id").
Where("photo_tags.tag_id = ? AND photos.is_public = ?", tagID, true).
Preload("Category").
Preload("Tags")
// 计算总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, fmt.Errorf("failed to count photos by tag: %w", err)
}
// 分页查询
offset := (page - 1) * limit
if err := query.Offset(offset).Limit(limit).
Order("photos.created_at DESC").
Find(&photos).Error; err != nil {
return nil, 0, fmt.Errorf("failed to get photos by tag: %w", err)
}
return photos, total, nil
}
// GetByUser 根据用户获取照片
func (r *photoRepository) GetByUser(userID uint, page, limit int) ([]*models.Photo, int64, error) {
var photos []*models.Photo
var total int64
query := r.db.Model(&models.Photo{}).
Where("user_id = ?", userID).
Preload("Category").
Preload("Tags")
// 计算总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, fmt.Errorf("failed to count photos by user: %w", err)
}
// 分页查询
offset := (page - 1) * limit
if err := query.Offset(offset).Limit(limit).
Order("created_at DESC").
Find(&photos).Error; err != nil {
return nil, 0, fmt.Errorf("failed to get photos by user: %w", err)
}
return photos, total, nil
}
// Search 搜索照片
func (r *photoRepository) Search(query string, page, limit int) ([]*models.Photo, int64, error) {
var photos []*models.Photo
var total int64
searchQuery := r.db.Model(&models.Photo{}).
Where("title ILIKE ? OR description ILIKE ? OR location ILIKE ?",
"%"+query+"%", "%"+query+"%", "%"+query+"%").
Where("is_public = ?", true).
Preload("Category").
Preload("Tags")
// 计算总数
if err := searchQuery.Count(&total).Error; err != nil {
return nil, 0, fmt.Errorf("failed to count search results: %w", err)
}
// 分页查询
offset := (page - 1) * limit
if err := searchQuery.Offset(offset).Limit(limit).
Order("created_at DESC").
Find(&photos).Error; err != nil {
return nil, 0, fmt.Errorf("failed to search photos: %w", err)
}
return photos, total, nil
}
// IncrementViewCount 增加浏览次数
func (r *photoRepository) IncrementViewCount(id uint) error {
if err := r.db.Model(&models.Photo{}).Where("id = ?", id).
Update("view_count", gorm.Expr("view_count + 1")).Error; err != nil {
return fmt.Errorf("failed to increment view count: %w", err)
}
return nil
}
// IncrementLikeCount 增加点赞次数
func (r *photoRepository) IncrementLikeCount(id uint) error {
if err := r.db.Model(&models.Photo{}).Where("id = ?", id).
Update("like_count", gorm.Expr("like_count + 1")).Error; err != nil {
return fmt.Errorf("failed to increment like count: %w", err)
}
return nil
}
// UpdateStatus 更新状态
func (r *photoRepository) UpdateStatus(id uint, status string) error {
if err := r.db.Model(&models.Photo{}).Where("id = ?", id).
Update("status", status).Error; err != nil {
return fmt.Errorf("failed to update status: %w", err)
}
return nil
}
// GetStats 获取照片统计
func (r *photoRepository) GetStats() (*PhotoStats, error) {
var stats PhotoStats
// 总数
if err := r.db.Model(&models.Photo{}).Count(&stats.Total).Error; err != nil {
return nil, fmt.Errorf("failed to count total photos: %w", err)
}
// 已发布
if err := r.db.Model(&models.Photo{}).Where("status = ?", models.StatusPublished).
Count(&stats.Published).Error; err != nil {
return nil, fmt.Errorf("failed to count published photos: %w", err)
}
// 草稿
if err := r.db.Model(&models.Photo{}).Where("status = ?", models.StatusDraft).
Count(&stats.Draft).Error; err != nil {
return nil, fmt.Errorf("failed to count draft photos: %w", err)
}
// 已归档
if err := r.db.Model(&models.Photo{}).Where("status = ?", models.StatusArchived).
Count(&stats.Archived).Error; err != nil {
return nil, fmt.Errorf("failed to count archived photos: %w", err)
}
return &stats, nil
}

View File

@ -0,0 +1,217 @@
package postgres
import (
"fmt"
"photography-backend/internal/models"
"gorm.io/gorm"
)
// TagRepository 标签仓库接口
type TagRepository interface {
Create(tag *models.Tag) error
GetByID(id uint) (*models.Tag, error)
GetByName(name string) (*models.Tag, error)
Update(tag *models.Tag) error
Delete(id uint) error
List(params *models.TagListParams) ([]*models.Tag, int64, error)
Search(query string, limit int) ([]*models.Tag, error)
GetPopular(limit int) ([]*models.Tag, error)
GetOrCreate(name string) (*models.Tag, error)
IncrementUseCount(id uint) error
DecrementUseCount(id uint) error
GetCloud(minUsage int, maxTags int) ([]*models.Tag, error)
}
// tagRepository 标签仓库实现
type tagRepository struct {
db *gorm.DB
}
// NewTagRepository 创建标签仓库
func NewTagRepository(db *gorm.DB) TagRepository {
return &tagRepository{db: db}
}
// Create 创建标签
func (r *tagRepository) Create(tag *models.Tag) error {
if err := r.db.Create(tag).Error; err != nil {
return fmt.Errorf("failed to create tag: %w", err)
}
return nil
}
// GetByID 根据ID获取标签
func (r *tagRepository) GetByID(id uint) (*models.Tag, error) {
var tag models.Tag
if err := r.db.First(&tag, id).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, fmt.Errorf("failed to get tag by id: %w", err)
}
return &tag, nil
}
// GetByName 根据名称获取标签
func (r *tagRepository) GetByName(name string) (*models.Tag, error) {
var tag models.Tag
if err := r.db.Where("name = ?", name).First(&tag).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, fmt.Errorf("failed to get tag by name: %w", err)
}
return &tag, nil
}
// Update 更新标签
func (r *tagRepository) Update(tag *models.Tag) error {
if err := r.db.Save(tag).Error; err != nil {
return fmt.Errorf("failed to update tag: %w", err)
}
return nil
}
// Delete 删除标签
func (r *tagRepository) Delete(id uint) error {
// 开启事务
tx := r.db.Begin()
// 删除照片标签关联
if err := tx.Exec("DELETE FROM photo_tags WHERE tag_id = ?", id).Error; err != nil {
tx.Rollback()
return fmt.Errorf("failed to delete photo tag relations: %w", err)
}
// 删除标签
if err := tx.Delete(&models.Tag{}, id).Error; err != nil {
tx.Rollback()
return fmt.Errorf("failed to delete tag: %w", err)
}
return tx.Commit().Error
}
// List 获取标签列表
func (r *tagRepository) List(params *models.TagListParams) ([]*models.Tag, int64, error) {
var tags []*models.Tag
var total int64
query := r.db.Model(&models.Tag{})
// 添加过滤条件
if params.Search != "" {
query = query.Where("name ILIKE ?", "%"+params.Search+"%")
}
if params.IsActive {
query = query.Where("is_active = ?", true)
}
// 计算总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, fmt.Errorf("failed to count tags: %w", err)
}
// 排序
orderClause := fmt.Sprintf("%s %s", params.SortBy, params.SortOrder)
// 分页查询
offset := (params.Page - 1) * params.Limit
if err := query.Offset(offset).Limit(params.Limit).
Order(orderClause).
Find(&tags).Error; err != nil {
return nil, 0, fmt.Errorf("failed to list tags: %w", err)
}
return tags, total, nil
}
// Search 搜索标签
func (r *tagRepository) Search(query string, limit int) ([]*models.Tag, error) {
var tags []*models.Tag
if err := r.db.Where("name ILIKE ? AND is_active = ?", "%"+query+"%", true).
Order("use_count DESC").
Limit(limit).
Find(&tags).Error; err != nil {
return nil, fmt.Errorf("failed to search tags: %w", err)
}
return tags, nil
}
// GetPopular 获取热门标签
func (r *tagRepository) GetPopular(limit int) ([]*models.Tag, error) {
var tags []*models.Tag
if err := r.db.Where("is_active = ?", true).
Order("use_count DESC").
Limit(limit).
Find(&tags).Error; err != nil {
return nil, fmt.Errorf("failed to get popular tags: %w", err)
}
return tags, nil
}
// GetOrCreate 获取或创建标签
func (r *tagRepository) GetOrCreate(name string) (*models.Tag, error) {
var tag models.Tag
// 先尝试获取
if err := r.db.Where("name = ?", name).First(&tag).Error; err != nil {
if err == gorm.ErrRecordNotFound {
// 不存在则创建
tag = models.Tag{
Name: name,
UseCount: 0,
IsActive: true,
}
if err := r.db.Create(&tag).Error; err != nil {
return nil, fmt.Errorf("failed to create tag: %w", err)
}
} else {
return nil, fmt.Errorf("failed to get tag: %w", err)
}
}
return &tag, nil
}
// IncrementUseCount 增加使用次数
func (r *tagRepository) IncrementUseCount(id uint) error {
if err := r.db.Model(&models.Tag{}).Where("id = ?", id).
Update("use_count", gorm.Expr("use_count + 1")).Error; err != nil {
return fmt.Errorf("failed to increment use count: %w", err)
}
return nil
}
// DecrementUseCount 减少使用次数
func (r *tagRepository) DecrementUseCount(id uint) error {
if err := r.db.Model(&models.Tag{}).Where("id = ?", id).
Update("use_count", gorm.Expr("GREATEST(use_count - 1, 0)")).Error; err != nil {
return fmt.Errorf("failed to decrement use count: %w", err)
}
return nil
}
// GetCloud 获取标签云数据
func (r *tagRepository) GetCloud(minUsage int, maxTags int) ([]*models.Tag, error) {
var tags []*models.Tag
query := r.db.Where("is_active = ?", true)
if minUsage > 0 {
query = query.Where("use_count >= ?", minUsage)
}
if err := query.Order("use_count DESC").
Limit(maxTags).
Find(&tags).Error; err != nil {
return nil, fmt.Errorf("failed to get tag cloud: %w", err)
}
return tags, nil
}

View File

@ -0,0 +1,129 @@
package postgres
import (
"fmt"
"photography-backend/internal/models"
"gorm.io/gorm"
)
// UserRepository 用户仓库接口
type UserRepository interface {
Create(user *models.User) error
GetByID(id uint) (*models.User, error)
GetByUsername(username string) (*models.User, error)
GetByEmail(email string) (*models.User, error)
Update(user *models.User) error
Delete(id uint) error
List(page, limit int, role string, isActive *bool) ([]*models.User, int64, error)
UpdateLastLogin(id uint) error
}
// userRepository 用户仓库实现
type userRepository struct {
db *gorm.DB
}
// NewUserRepository 创建用户仓库
func NewUserRepository(db *gorm.DB) UserRepository {
return &userRepository{db: db}
}
// Create 创建用户
func (r *userRepository) Create(user *models.User) error {
if err := r.db.Create(user).Error; err != nil {
return fmt.Errorf("failed to create user: %w", err)
}
return nil
}
// GetByID 根据ID获取用户
func (r *userRepository) GetByID(id uint) (*models.User, error) {
var user models.User
if err := r.db.First(&user, id).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, fmt.Errorf("failed to get user by id: %w", err)
}
return &user, nil
}
// GetByUsername 根据用户名获取用户
func (r *userRepository) GetByUsername(username string) (*models.User, error) {
var user models.User
if err := r.db.Where("username = ?", username).First(&user).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, fmt.Errorf("failed to get user by username: %w", err)
}
return &user, nil
}
// GetByEmail 根据邮箱获取用户
func (r *userRepository) GetByEmail(email string) (*models.User, error) {
var user models.User
if err := r.db.Where("email = ?", email).First(&user).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, fmt.Errorf("failed to get user by email: %w", err)
}
return &user, nil
}
// Update 更新用户
func (r *userRepository) Update(user *models.User) error {
if err := r.db.Save(user).Error; err != nil {
return fmt.Errorf("failed to update user: %w", err)
}
return nil
}
// Delete 删除用户
func (r *userRepository) Delete(id uint) error {
if err := r.db.Delete(&models.User{}, id).Error; err != nil {
return fmt.Errorf("failed to delete user: %w", err)
}
return nil
}
// List 获取用户列表
func (r *userRepository) List(page, limit int, role string, isActive *bool) ([]*models.User, int64, error) {
var users []*models.User
var total int64
query := r.db.Model(&models.User{})
// 添加过滤条件
if role != "" {
query = query.Where("role = ?", role)
}
if isActive != nil {
query = query.Where("is_active = ?", *isActive)
}
// 计算总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, fmt.Errorf("failed to count users: %w", err)
}
// 分页查询
offset := (page - 1) * limit
if err := query.Offset(offset).Limit(limit).
Order("created_at DESC").
Find(&users).Error; err != nil {
return nil, 0, fmt.Errorf("failed to list users: %w", err)
}
return users, total, nil
}
// UpdateLastLogin 更新最后登录时间
func (r *userRepository) UpdateLastLogin(id uint) error {
if err := r.db.Model(&models.User{}).Where("id = ?", id).
Update("last_login", gorm.Expr("NOW()")).Error; err != nil {
return fmt.Errorf("failed to update last login: %w", err)
}
return nil
}