feat: 实现后端和管理后台基础架构

## 后端架构 (Go + Gin + GORM)
-  完整的分层架构 (API/Service/Repository)
-  PostgreSQL数据库设计和迁移脚本
-  JWT认证系统和权限控制
-  用户、照片、分类、标签等核心模型
-  中间件系统 (认证、CORS、日志)
-  配置管理和环境变量支持
-  结构化日志和错误处理
-  Makefile构建和部署脚本

## 管理后台架构 (React + TypeScript)
-  Vite + React 18 + TypeScript现代化架构
-  路由系统和状态管理 (Zustand + TanStack Query)
-  基于Radix UI的组件库基础
-  认证流程和权限控制
-  响应式设计和主题系统

## 数据库设计
-  用户表 (角色权限、认证信息)
-  照片表 (元数据、EXIF、状态管理)
-  分类表 (层级结构、封面图片)
-  标签表 (使用统计、标签云)
-  关联表 (照片-标签多对多)

## 技术特点
- 🚀 高性能: Gin框架 + GORM ORM
- 🔐 安全: JWT认证 + 密码加密 + 权限控制
- 📊 监控: 结构化日志 + 健康检查
- 🎨 现代化: React 18 + TypeScript + Vite
- 📱 响应式: Tailwind CSS + Radix UI

参考文档: docs/development/saved-docs/
This commit is contained in:
xujiang
2025-07-09 14:56:22 +08:00
parent 180fbd2ae9
commit c57ec3aa82
34 changed files with 3432 additions and 0 deletions

View File

@ -0,0 +1,118 @@
package handlers
import (
"net/http"
"github.com/gin-gonic/gin"
"photography-backend/internal/models"
"photography-backend/internal/service/auth"
"photography-backend/internal/api/middleware"
"photography-backend/pkg/response"
)
// AuthHandler 认证处理器
type AuthHandler struct {
authService *auth.AuthService
}
// NewAuthHandler 创建认证处理器
func NewAuthHandler(authService *auth.AuthService) *AuthHandler {
return &AuthHandler{
authService: authService,
}
}
// Login 用户登录
func (h *AuthHandler) Login(c *gin.Context) {
var req models.LoginRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, response.Error(http.StatusBadRequest, err.Error()))
return
}
loginResp, err := h.authService.Login(&req)
if err != nil {
c.JSON(http.StatusUnauthorized, response.Error(http.StatusUnauthorized, err.Error()))
return
}
c.JSON(http.StatusOK, response.Success(loginResp))
}
// Register 用户注册
func (h *AuthHandler) Register(c *gin.Context) {
var req models.CreateUserRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, response.Error(http.StatusBadRequest, err.Error()))
return
}
user, err := h.authService.Register(&req)
if err != nil {
c.JSON(http.StatusBadRequest, response.Error(http.StatusBadRequest, err.Error()))
return
}
c.JSON(http.StatusCreated, response.Success(user))
}
// RefreshToken 刷新令牌
func (h *AuthHandler) RefreshToken(c *gin.Context) {
var req models.RefreshTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, response.Error(http.StatusBadRequest, err.Error()))
return
}
loginResp, err := h.authService.RefreshToken(&req)
if err != nil {
c.JSON(http.StatusUnauthorized, response.Error(http.StatusUnauthorized, err.Error()))
return
}
c.JSON(http.StatusOK, response.Success(loginResp))
}
// GetProfile 获取用户资料
func (h *AuthHandler) GetProfile(c *gin.Context) {
userID, exists := middleware.GetCurrentUser(c)
if !exists {
c.JSON(http.StatusUnauthorized, response.Error(http.StatusUnauthorized, "User not authenticated"))
return
}
user, err := h.authService.GetUserByID(userID)
if err != nil {
c.JSON(http.StatusInternalServerError, response.Error(http.StatusInternalServerError, err.Error()))
return
}
c.JSON(http.StatusOK, response.Success(user))
}
// UpdatePassword 更新密码
func (h *AuthHandler) UpdatePassword(c *gin.Context) {
userID, exists := middleware.GetCurrentUser(c)
if !exists {
c.JSON(http.StatusUnauthorized, response.Error(http.StatusUnauthorized, "User not authenticated"))
return
}
var req models.UpdatePasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, response.Error(http.StatusBadRequest, err.Error()))
return
}
if err := h.authService.UpdatePassword(userID, &req); err != nil {
c.JSON(http.StatusBadRequest, response.Error(http.StatusBadRequest, err.Error()))
return
}
c.JSON(http.StatusOK, response.Success(gin.H{"message": "Password updated successfully"}))
}
// Logout 用户登出
func (h *AuthHandler) Logout(c *gin.Context) {
// 简单实现实际应用中可能需要将token加入黑名单
c.JSON(http.StatusOK, response.Success(gin.H{"message": "Logged out successfully"}))
}

View File

@ -0,0 +1,217 @@
package middleware
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
"photography-backend/internal/service/auth"
"photography-backend/internal/models"
)
// AuthMiddleware 认证中间件
type AuthMiddleware struct {
jwtService *auth.JWTService
}
// NewAuthMiddleware 创建认证中间件
func NewAuthMiddleware(jwtService *auth.JWTService) *AuthMiddleware {
return &AuthMiddleware{
jwtService: jwtService,
}
}
// RequireAuth 需要认证的中间件
func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc {
return func(c *gin.Context) {
// 从Header中获取Authorization
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Authorization header is required",
})
c.Abort()
return
}
// 检查Bearer前缀
if !strings.HasPrefix(authHeader, "Bearer ") {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Invalid authorization header format",
})
c.Abort()
return
}
// 提取token
token := strings.TrimPrefix(authHeader, "Bearer ")
// 验证token
claims, err := m.jwtService.ValidateToken(token)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Invalid or expired token",
})
c.Abort()
return
}
// 将用户信息存入上下文
c.Set("user_id", claims.UserID)
c.Set("username", claims.Username)
c.Set("user_role", claims.Role)
c.Next()
}
}
// RequireRole 需要特定角色的中间件
func (m *AuthMiddleware) RequireRole(requiredRole string) gin.HandlerFunc {
return func(c *gin.Context) {
userRole, exists := c.Get("user_role")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "User role not found in context",
})
c.Abort()
return
}
roleStr, ok := userRole.(string)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Invalid user role",
})
c.Abort()
return
}
// 检查角色权限
if !m.hasPermission(roleStr, requiredRole) {
c.JSON(http.StatusForbidden, gin.H{
"error": "Insufficient permissions",
})
c.Abort()
return
}
c.Next()
}
}
// RequireAdmin 需要管理员权限的中间件
func (m *AuthMiddleware) RequireAdmin() gin.HandlerFunc {
return m.RequireRole(models.RoleAdmin)
}
// RequireEditor 需要编辑者权限的中间件
func (m *AuthMiddleware) RequireEditor() gin.HandlerFunc {
return m.RequireRole(models.RoleEditor)
}
// OptionalAuth 可选认证中间件
func (m *AuthMiddleware) OptionalAuth() gin.HandlerFunc {
return func(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.Next()
return
}
if !strings.HasPrefix(authHeader, "Bearer ") {
c.Next()
return
}
token := strings.TrimPrefix(authHeader, "Bearer ")
claims, err := m.jwtService.ValidateToken(token)
if err != nil {
c.Next()
return
}
// 将用户信息存入上下文
c.Set("user_id", claims.UserID)
c.Set("username", claims.Username)
c.Set("user_role", claims.Role)
c.Next()
}
}
// GetCurrentUser 获取当前用户ID
func GetCurrentUser(c *gin.Context) (uint, bool) {
userID, exists := c.Get("user_id")
if !exists {
return 0, false
}
id, ok := userID.(uint)
return id, ok
}
// GetCurrentUserRole 获取当前用户角色
func GetCurrentUserRole(c *gin.Context) (string, bool) {
userRole, exists := c.Get("user_role")
if !exists {
return "", false
}
role, ok := userRole.(string)
return role, ok
}
// GetCurrentUsername 获取当前用户名
func GetCurrentUsername(c *gin.Context) (string, bool) {
username, exists := c.Get("username")
if !exists {
return "", false
}
name, ok := username.(string)
return name, ok
}
// IsAuthenticated 检查是否已认证
func IsAuthenticated(c *gin.Context) bool {
_, exists := c.Get("user_id")
return exists
}
// IsAdmin 检查是否为管理员
func IsAdmin(c *gin.Context) bool {
role, exists := GetCurrentUserRole(c)
if !exists {
return false
}
return role == models.RoleAdmin
}
// IsEditor 检查是否为编辑者或以上
func IsEditor(c *gin.Context) bool {
role, exists := GetCurrentUserRole(c)
if !exists {
return false
}
return role == models.RoleEditor || role == models.RoleAdmin
}
// hasPermission 检查权限
func (m *AuthMiddleware) hasPermission(userRole, requiredRole string) bool {
roleLevel := map[string]int{
models.RoleUser: 1,
models.RoleEditor: 2,
models.RoleAdmin: 3,
}
userLevel, exists := roleLevel[userRole]
if !exists {
return false
}
requiredLevel, exists := roleLevel[requiredRole]
if !exists {
return false
}
return userLevel >= requiredLevel
}

View File

@ -0,0 +1,58 @@
package middleware
import (
"net/http"
"github.com/gin-gonic/gin"
"photography-backend/internal/config"
)
// CORSMiddleware CORS中间件
func CORSMiddleware(cfg *config.CORSConfig) gin.HandlerFunc {
return func(c *gin.Context) {
origin := c.GetHeader("Origin")
// 检查是否允许的来源
allowed := false
for _, allowedOrigin := range cfg.AllowedOrigins {
if allowedOrigin == "*" || allowedOrigin == origin {
allowed = true
break
}
}
if allowed {
c.Header("Access-Control-Allow-Origin", origin)
}
// 设置其他CORS头
c.Header("Access-Control-Allow-Methods", joinStrings(cfg.AllowedMethods, ", "))
c.Header("Access-Control-Allow-Headers", joinStrings(cfg.AllowedHeaders, ", "))
c.Header("Access-Control-Max-Age", "86400")
if cfg.AllowCredentials {
c.Header("Access-Control-Allow-Credentials", "true")
}
// 处理预检请求
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
}
}
// joinStrings 连接字符串数组
func joinStrings(strs []string, sep string) string {
if len(strs) == 0 {
return ""
}
result := strs[0]
for i := 1; i < len(strs); i++ {
result += sep + strs[i]
}
return result
}

View File

@ -0,0 +1,74 @@
package middleware
import (
"time"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// LoggerMiddleware 日志中间件
func LoggerMiddleware(logger *zap.Logger) gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
path := c.Request.URL.Path
raw := c.Request.URL.RawQuery
// 处理请求
c.Next()
// 计算延迟
latency := time.Since(start)
// 获取请求信息
clientIP := c.ClientIP()
method := c.Request.Method
statusCode := c.Writer.Status()
bodySize := c.Writer.Size()
if raw != "" {
path = path + "?" + raw
}
// 记录日志
logger.Info("HTTP Request",
zap.String("method", method),
zap.String("path", path),
zap.String("client_ip", clientIP),
zap.Int("status_code", statusCode),
zap.Int("body_size", bodySize),
zap.Duration("latency", latency),
zap.String("user_agent", c.Request.UserAgent()),
)
}
}
// RequestIDMiddleware 请求ID中间件
func RequestIDMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
requestID := c.GetHeader("X-Request-ID")
if requestID == "" {
requestID = generateRequestID()
}
c.Set("request_id", requestID)
c.Header("X-Request-ID", requestID)
c.Next()
}
}
// generateRequestID 生成请求ID
func generateRequestID() string {
// 简单实现实际应用中可能需要更复杂的ID生成逻辑
return time.Now().Format("20060102150405") + "-" + randomString(8)
}
// randomString 生成随机字符串
func randomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
b := make([]byte, length)
for i := range b {
b[i] = charset[time.Now().UnixNano()%int64(len(charset))]
}
return string(b)
}

View File

@ -0,0 +1,44 @@
package routes
import (
"github.com/gin-gonic/gin"
"photography-backend/internal/api/handlers"
"photography-backend/internal/api/middleware"
)
// Handlers 处理器集合
type Handlers struct {
Auth *handlers.AuthHandler
}
// SetupRoutes 设置路由
func SetupRoutes(r *gin.Engine, handlers *Handlers, authMiddleware *middleware.AuthMiddleware) {
// API v1路由组
v1 := r.Group("/api/v1")
// 公开路由
public := v1.Group("/auth")
{
public.POST("/login", handlers.Auth.Login)
public.POST("/register", handlers.Auth.Register)
public.POST("/refresh", handlers.Auth.RefreshToken)
}
// 需要认证的路由
protected := v1.Group("/")
protected.Use(authMiddleware.RequireAuth())
{
// 用户资料
protected.GET("/auth/profile", handlers.Auth.GetProfile)
protected.PUT("/auth/password", handlers.Auth.UpdatePassword)
protected.POST("/auth/logout", handlers.Auth.Logout)
}
// 管理员路由
admin := v1.Group("/admin")
admin.Use(authMiddleware.RequireAuth())
admin.Use(authMiddleware.RequireAdmin())
{
// 将在后续添加管理员相关路由
}
}

View File

@ -0,0 +1,231 @@
package config
import (
"fmt"
"time"
"github.com/spf13/viper"
)
// Config 应用配置
type Config struct {
App AppConfig `mapstructure:"app"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
JWT JWTConfig `mapstructure:"jwt"`
Storage StorageConfig `mapstructure:"storage"`
Upload UploadConfig `mapstructure:"upload"`
Logger LoggerConfig `mapstructure:"logger"`
CORS CORSConfig `mapstructure:"cors"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
}
// AppConfig 应用配置
type AppConfig struct {
Name string `mapstructure:"name"`
Version string `mapstructure:"version"`
Environment string `mapstructure:"environment"`
Port int `mapstructure:"port"`
Debug bool `mapstructure:"debug"`
}
// DatabaseConfig 数据库配置
type DatabaseConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
Database string `mapstructure:"database"`
SSLMode string `mapstructure:"ssl_mode"`
MaxOpenConns int `mapstructure:"max_open_conns"`
MaxIdleConns int `mapstructure:"max_idle_conns"`
ConnMaxLifetime int `mapstructure:"conn_max_lifetime"`
}
// RedisConfig Redis配置
type RedisConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Password string `mapstructure:"password"`
Database int `mapstructure:"database"`
PoolSize int `mapstructure:"pool_size"`
MinIdleConns int `mapstructure:"min_idle_conns"`
}
// JWTConfig JWT配置
type JWTConfig struct {
Secret string `mapstructure:"secret"`
ExpiresIn string `mapstructure:"expires_in"`
RefreshExpiresIn string `mapstructure:"refresh_expires_in"`
}
// StorageConfig 存储配置
type StorageConfig struct {
Type string `mapstructure:"type"`
Local LocalConfig `mapstructure:"local"`
S3 S3Config `mapstructure:"s3"`
}
// LocalConfig 本地存储配置
type LocalConfig struct {
BasePath string `mapstructure:"base_path"`
BaseURL string `mapstructure:"base_url"`
}
// S3Config S3存储配置
type S3Config struct {
Region string `mapstructure:"region"`
Bucket string `mapstructure:"bucket"`
AccessKey string `mapstructure:"access_key"`
SecretKey string `mapstructure:"secret_key"`
Endpoint string `mapstructure:"endpoint"`
}
// UploadConfig 上传配置
type UploadConfig struct {
MaxFileSize int64 `mapstructure:"max_file_size"`
AllowedTypes []string `mapstructure:"allowed_types"`
ThumbnailSizes []ThumbnailSize `mapstructure:"thumbnail_sizes"`
}
// ThumbnailSize 缩略图尺寸
type ThumbnailSize struct {
Name string `mapstructure:"name"`
Width int `mapstructure:"width"`
Height int `mapstructure:"height"`
}
// LoggerConfig 日志配置
type LoggerConfig struct {
Level string `mapstructure:"level"`
Format string `mapstructure:"format"`
Output string `mapstructure:"output"`
Filename string `mapstructure:"filename"`
MaxSize int `mapstructure:"max_size"`
MaxAge int `mapstructure:"max_age"`
Compress bool `mapstructure:"compress"`
}
// CORSConfig CORS配置
type CORSConfig struct {
AllowedOrigins []string `mapstructure:"allowed_origins"`
AllowedMethods []string `mapstructure:"allowed_methods"`
AllowedHeaders []string `mapstructure:"allowed_headers"`
AllowCredentials bool `mapstructure:"allow_credentials"`
}
// RateLimitConfig 限流配置
type RateLimitConfig struct {
Enabled bool `mapstructure:"enabled"`
RequestsPerMinute int `mapstructure:"requests_per_minute"`
Burst int `mapstructure:"burst"`
}
var AppConfig *Config
// LoadConfig 加载配置
func LoadConfig(configPath string) (*Config, error) {
viper.SetConfigFile(configPath)
viper.SetConfigType("yaml")
// 设置环境变量前缀
viper.SetEnvPrefix("PHOTOGRAPHY")
viper.AutomaticEnv()
// 环境变量替换配置
viper.BindEnv("database.host", "DB_HOST")
viper.BindEnv("database.port", "DB_PORT")
viper.BindEnv("database.username", "DB_USER")
viper.BindEnv("database.password", "DB_PASSWORD")
viper.BindEnv("database.database", "DB_NAME")
viper.BindEnv("redis.host", "REDIS_HOST")
viper.BindEnv("redis.port", "REDIS_PORT")
viper.BindEnv("redis.password", "REDIS_PASSWORD")
viper.BindEnv("jwt.secret", "JWT_SECRET")
viper.BindEnv("storage.s3.access_key", "AWS_ACCESS_KEY_ID")
viper.BindEnv("storage.s3.secret_key", "AWS_SECRET_ACCESS_KEY")
viper.BindEnv("app.port", "PORT")
if err := viper.ReadInConfig(); err != nil {
return nil, fmt.Errorf("failed to read config file: %w", err)
}
var config Config
if err := viper.Unmarshal(&config); err != nil {
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
}
// 验证配置
if err := validateConfig(&config); err != nil {
return nil, fmt.Errorf("config validation failed: %w", err)
}
AppConfig = &config
return &config, nil
}
// validateConfig 验证配置
func validateConfig(config *Config) error {
if config.App.Name == "" {
return fmt.Errorf("app name is required")
}
if config.Database.Host == "" {
return fmt.Errorf("database host is required")
}
if config.JWT.Secret == "" {
return fmt.Errorf("jwt secret is required")
}
return nil
}
// GetJWTExpiration 获取JWT过期时间
func (c *Config) GetJWTExpiration() time.Duration {
duration, err := time.ParseDuration(c.JWT.ExpiresIn)
if err != nil {
return 24 * time.Hour // 默认24小时
}
return duration
}
// GetJWTRefreshExpiration 获取JWT刷新过期时间
func (c *Config) GetJWTRefreshExpiration() time.Duration {
duration, err := time.ParseDuration(c.JWT.RefreshExpiresIn)
if err != nil {
return 7 * 24 * time.Hour // 默认7天
}
return duration
}
// GetDatabaseDSN 获取数据库DSN
func (c *Config) GetDatabaseDSN() string {
return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
c.Database.Host,
c.Database.Port,
c.Database.Username,
c.Database.Password,
c.Database.Database,
c.Database.SSLMode,
)
}
// GetRedisAddr 获取Redis地址
func (c *Config) GetRedisAddr() string {
return fmt.Sprintf("%s:%d", c.Redis.Host, c.Redis.Port)
}
// GetServerAddr 获取服务器地址
func (c *Config) GetServerAddr() string {
return fmt.Sprintf(":%d", c.App.Port)
}
// IsDevelopment 是否为开发环境
func (c *Config) IsDevelopment() bool {
return c.App.Environment == "development"
}
// IsProduction 是否为生产环境
func (c *Config) IsProduction() bool {
return c.App.Environment == "production"
}

View File

@ -0,0 +1,85 @@
package models
import (
"time"
"gorm.io/gorm"
)
// Category 分类模型
type Category struct {
ID uint `gorm:"primaryKey" json:"id"`
Name string `gorm:"size:100;not null" json:"name"`
Description string `gorm:"type:text" json:"description"`
ParentID *uint `json:"parent_id"`
Parent *Category `gorm:"foreignKey:ParentID" json:"parent,omitempty"`
Children []Category `gorm:"foreignKey:ParentID" json:"children,omitempty"`
Color string `gorm:"size:7;default:#3b82f6" json:"color"`
CoverImage string `gorm:"size:500" json:"cover_image"`
Sort int `gorm:"default:0" json:"sort"`
IsActive bool `gorm:"default:true" json:"is_active"`
PhotoCount int `gorm:"-" json:"photo_count"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
}
// TableName 返回分类表名
func (Category) TableName() string {
return "categories"
}
// CreateCategoryRequest 创建分类请求
type CreateCategoryRequest struct {
Name string `json:"name" binding:"required,max=100"`
Description string `json:"description"`
ParentID *uint `json:"parent_id"`
Color string `json:"color" binding:"omitempty,hexcolor"`
CoverImage string `json:"cover_image" binding:"omitempty,max=500"`
Sort int `json:"sort"`
}
// UpdateCategoryRequest 更新分类请求
type UpdateCategoryRequest struct {
Name *string `json:"name" binding:"omitempty,max=100"`
Description *string `json:"description"`
ParentID *uint `json:"parent_id"`
Color *string `json:"color" binding:"omitempty,hexcolor"`
CoverImage *string `json:"cover_image" binding:"omitempty,max=500"`
Sort *int `json:"sort"`
IsActive *bool `json:"is_active"`
}
// CategoryListParams 分类列表查询参数
type CategoryListParams struct {
IncludeStats bool `form:"include_stats"`
IncludeTree bool `form:"include_tree"`
ParentID uint `form:"parent_id"`
IsActive bool `form:"is_active"`
}
// CategoryResponse 分类响应
type CategoryResponse struct {
*Category
}
// CategoryTreeNode 分类树节点
type CategoryTreeNode struct {
ID uint `json:"id"`
Name string `json:"name"`
PhotoCount int `json:"photo_count"`
Children []CategoryTreeNode `json:"children"`
}
// CategoryListResponse 分类列表响应
type CategoryListResponse struct {
Categories []CategoryResponse `json:"categories"`
Tree []CategoryTreeNode `json:"tree,omitempty"`
Stats *CategoryStats `json:"stats,omitempty"`
}
// CategoryStats 分类统计
type CategoryStats struct {
TotalCategories int `json:"total_categories"`
MaxLevel int `json:"max_level"`
FeaturedCount int `json:"featured_count"`
}

View File

@ -0,0 +1,99 @@
package models
import (
"time"
"gorm.io/gorm"
)
// Photo 照片模型
type Photo struct {
ID uint `gorm:"primaryKey" json:"id"`
Title string `gorm:"size:255;not null" json:"title"`
Description string `gorm:"type:text" json:"description"`
Filename string `gorm:"size:255;not null" json:"filename"`
FilePath string `gorm:"size:500;not null" json:"file_path"`
FileSize int64 `json:"file_size"`
MimeType string `gorm:"size:100" json:"mime_type"`
Width int `json:"width"`
Height int `json:"height"`
CategoryID uint `json:"category_id"`
Category *Category `gorm:"foreignKey:CategoryID" json:"category,omitempty"`
Tags []Tag `gorm:"many2many:photo_tags;" json:"tags,omitempty"`
EXIF string `gorm:"type:jsonb" json:"exif"`
TakenAt *time.Time `json:"taken_at"`
Location string `gorm:"size:255" json:"location"`
IsPublic bool `gorm:"default:true" json:"is_public"`
Status string `gorm:"size:20;default:draft" json:"status"`
ViewCount int `gorm:"default:0" json:"view_count"`
LikeCount int `gorm:"default:0" json:"like_count"`
UserID uint `gorm:"not null" json:"user_id"`
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
}
// TableName 返回照片表名
func (Photo) TableName() string {
return "photos"
}
// PhotoStatus 照片状态常量
const (
StatusDraft = "draft"
StatusPublished = "published"
StatusArchived = "archived"
)
// CreatePhotoRequest 创建照片请求
type CreatePhotoRequest struct {
Title string `json:"title" binding:"required,max=255"`
Description string `json:"description"`
CategoryID uint `json:"category_id" binding:"required"`
TagIDs []uint `json:"tag_ids"`
TakenAt *time.Time `json:"taken_at"`
Location string `json:"location" binding:"max=255"`
IsPublic *bool `json:"is_public"`
Status string `json:"status" binding:"omitempty,oneof=draft published archived"`
}
// UpdatePhotoRequest 更新照片请求
type UpdatePhotoRequest struct {
Title *string `json:"title" binding:"omitempty,max=255"`
Description *string `json:"description"`
CategoryID *uint `json:"category_id"`
TagIDs []uint `json:"tag_ids"`
TakenAt *time.Time `json:"taken_at"`
Location *string `json:"location" binding:"omitempty,max=255"`
IsPublic *bool `json:"is_public"`
Status *string `json:"status" binding:"omitempty,oneof=draft published archived"`
}
// PhotoListParams 照片列表查询参数
type PhotoListParams struct {
Page int `form:"page,default=1" binding:"min=1"`
Limit int `form:"limit,default=20" binding:"min=1,max=100"`
CategoryID uint `form:"category_id"`
TagID uint `form:"tag_id"`
UserID uint `form:"user_id"`
Status string `form:"status" binding:"omitempty,oneof=draft published archived"`
Search string `form:"search"`
SortBy string `form:"sort_by,default=created_at" binding:"omitempty,oneof=created_at taken_at title view_count like_count"`
SortOrder string `form:"sort_order,default=desc" binding:"omitempty,oneof=asc desc"`
Year int `form:"year"`
Month int `form:"month" binding:"min=1,max=12"`
}
// PhotoResponse 照片响应
type PhotoResponse struct {
*Photo
ThumbnailURLs map[string]string `json:"thumbnail_urls,omitempty"`
}
// PhotoListResponse 照片列表响应
type PhotoListResponse struct {
Photos []PhotoResponse `json:"photos"`
Total int64 `json:"total"`
Page int `json:"page"`
Limit int `json:"limit"`
}

View File

@ -0,0 +1,95 @@
package models
import (
"time"
"gorm.io/gorm"
)
// Tag 标签模型
type Tag struct {
ID uint `gorm:"primaryKey" json:"id"`
Name string `gorm:"size:50;not null;unique" json:"name"`
Color string `gorm:"size:7;default:#6b7280" json:"color"`
UseCount int `gorm:"default:0" json:"use_count"`
IsActive bool `gorm:"default:true" json:"is_active"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
}
// TableName 返回标签表名
func (Tag) TableName() string {
return "tags"
}
// CreateTagRequest 创建标签请求
type CreateTagRequest struct {
Name string `json:"name" binding:"required,max=50"`
Color string `json:"color" binding:"omitempty,hexcolor"`
}
// UpdateTagRequest 更新标签请求
type UpdateTagRequest struct {
Name *string `json:"name" binding:"omitempty,max=50"`
Color *string `json:"color" binding:"omitempty,hexcolor"`
IsActive *bool `json:"is_active"`
}
// TagListParams 标签列表查询参数
type TagListParams struct {
Page int `form:"page,default=1" binding:"min=1"`
Limit int `form:"limit,default=50" binding:"min=1,max=100"`
Search string `form:"search"`
SortBy string `form:"sort_by,default=use_count" binding:"omitempty,oneof=use_count name created_at"`
SortOrder string `form:"sort_order,default=desc" binding:"omitempty,oneof=asc desc"`
IsActive bool `form:"is_active"`
}
// TagSuggestionsParams 标签建议查询参数
type TagSuggestionsParams struct {
Query string `form:"q" binding:"required"`
Limit int `form:"limit,default=10" binding:"min=1,max=20"`
}
// TagResponse 标签响应
type TagResponse struct {
*Tag
MatchScore float64 `json:"match_score,omitempty"`
}
// TagListResponse 标签列表响应
type TagListResponse struct {
Tags []TagResponse `json:"tags"`
Total int64 `json:"total"`
Page int `json:"page"`
Limit int `json:"limit"`
Groups *TagGroups `json:"groups,omitempty"`
}
// TagGroups 标签分组
type TagGroups struct {
Style TagGroup `json:"style"`
Subject TagGroup `json:"subject"`
Technique TagGroup `json:"technique"`
Location TagGroup `json:"location"`
}
// TagGroup 标签组
type TagGroup struct {
Name string `json:"name"`
Count int `json:"count"`
}
// TagCloudItem 标签云项目
type TagCloudItem struct {
ID uint `json:"id"`
Name string `json:"name"`
UseCount int `json:"use_count"`
RelativeSize int `json:"relative_size"`
Color string `json:"color"`
}
// TagCloudResponse 标签云响应
type TagCloudResponse struct {
Tags []TagCloudItem `json:"tags"`
}

View File

@ -0,0 +1,76 @@
package models
import (
"time"
"gorm.io/gorm"
)
// User 用户模型
type User struct {
ID uint `gorm:"primaryKey" json:"id"`
Username string `gorm:"size:50;not null;unique" json:"username"`
Email string `gorm:"size:100;not null;unique" json:"email"`
Password string `gorm:"size:255;not null" json:"-"`
Name string `gorm:"size:100" json:"name"`
Avatar string `gorm:"size:500" json:"avatar"`
Role string `gorm:"size:20;default:user" json:"role"`
IsActive bool `gorm:"default:true" json:"is_active"`
LastLogin *time.Time `json:"last_login"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
}
// TableName 返回用户表名
func (User) TableName() string {
return "users"
}
// UserRole 用户角色常量
const (
RoleUser = "user"
RoleEditor = "editor"
RoleAdmin = "admin"
)
// CreateUserRequest 创建用户请求
type CreateUserRequest struct {
Username string `json:"username" binding:"required,min=3,max=50"`
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
Name string `json:"name" binding:"max=100"`
Role string `json:"role" binding:"omitempty,oneof=user editor admin"`
}
// UpdateUserRequest 更新用户请求
type UpdateUserRequest struct {
Name *string `json:"name" binding:"omitempty,max=100"`
Avatar *string `json:"avatar" binding:"omitempty,max=500"`
IsActive *bool `json:"is_active"`
}
// UpdatePasswordRequest 更新密码请求
type UpdatePasswordRequest struct {
OldPassword string `json:"old_password" binding:"required"`
NewPassword string `json:"new_password" binding:"required,min=6"`
}
// LoginRequest 登录请求
type LoginRequest struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
}
// LoginResponse 登录响应
type LoginResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
User *User `json:"user"`
}
// RefreshTokenRequest 刷新令牌请求
type RefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"`
}

View File

@ -0,0 +1,211 @@
package postgres
import (
"fmt"
"photography-backend/internal/models"
"gorm.io/gorm"
)
// CategoryRepository 分类仓库接口
type CategoryRepository interface {
Create(category *models.Category) error
GetByID(id uint) (*models.Category, error)
Update(category *models.Category) error
Delete(id uint) error
List(params *models.CategoryListParams) ([]*models.Category, error)
GetTree() ([]*models.Category, error)
GetChildren(parentID uint) ([]*models.Category, error)
GetStats() (*models.CategoryStats, error)
UpdateSort(id uint, sort int) error
GetPhotoCount(id uint) (int64, error)
}
// categoryRepository 分类仓库实现
type categoryRepository struct {
db *gorm.DB
}
// NewCategoryRepository 创建分类仓库
func NewCategoryRepository(db *gorm.DB) CategoryRepository {
return &categoryRepository{db: db}
}
// Create 创建分类
func (r *categoryRepository) Create(category *models.Category) error {
if err := r.db.Create(category).Error; err != nil {
return fmt.Errorf("failed to create category: %w", err)
}
return nil
}
// GetByID 根据ID获取分类
func (r *categoryRepository) GetByID(id uint) (*models.Category, error) {
var category models.Category
if err := r.db.Preload("Parent").Preload("Children").
First(&category, id).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, fmt.Errorf("failed to get category by id: %w", err)
}
// 计算照片数量
var photoCount int64
if err := r.db.Model(&models.Photo{}).Where("category_id = ?", id).
Count(&photoCount).Error; err != nil {
return nil, fmt.Errorf("failed to count photos: %w", err)
}
category.PhotoCount = int(photoCount)
return &category, nil
}
// Update 更新分类
func (r *categoryRepository) Update(category *models.Category) error {
if err := r.db.Save(category).Error; err != nil {
return fmt.Errorf("failed to update category: %w", err)
}
return nil
}
// Delete 删除分类
func (r *categoryRepository) Delete(id uint) error {
// 开启事务
tx := r.db.Begin()
// 将子分类的父分类设置为NULL
if err := tx.Model(&models.Category{}).Where("parent_id = ?", id).
Update("parent_id", nil).Error; err != nil {
tx.Rollback()
return fmt.Errorf("failed to update child categories: %w", err)
}
// 删除分类
if err := tx.Delete(&models.Category{}, id).Error; err != nil {
tx.Rollback()
return fmt.Errorf("failed to delete category: %w", err)
}
return tx.Commit().Error
}
// List 获取分类列表
func (r *categoryRepository) List(params *models.CategoryListParams) ([]*models.Category, error) {
var categories []*models.Category
query := r.db.Model(&models.Category{})
// 添加过滤条件
if params.ParentID > 0 {
query = query.Where("parent_id = ?", params.ParentID)
}
if params.IsActive {
query = query.Where("is_active = ?", true)
}
if err := query.Order("sort ASC, created_at DESC").
Find(&categories).Error; err != nil {
return nil, fmt.Errorf("failed to list categories: %w", err)
}
// 如果需要包含统计信息
if params.IncludeStats {
for _, category := range categories {
var photoCount int64
if err := r.db.Model(&models.Photo{}).Where("category_id = ?", category.ID).
Count(&photoCount).Error; err != nil {
return nil, fmt.Errorf("failed to count photos for category %d: %w", category.ID, err)
}
category.PhotoCount = int(photoCount)
}
}
return categories, nil
}
// GetTree 获取分类树
func (r *categoryRepository) GetTree() ([]*models.Category, error) {
var categories []*models.Category
// 获取所有分类
if err := r.db.Where("is_active = ?", true).
Order("sort ASC, created_at DESC").
Find(&categories).Error; err != nil {
return nil, fmt.Errorf("failed to get categories: %w", err)
}
// 构建分类树
categoryMap := make(map[uint]*models.Category)
var rootCategories []*models.Category
// 第一次遍历:建立映射
for _, category := range categories {
categoryMap[category.ID] = category
category.Children = []*models.Category{}
}
// 第二次遍历:构建树形结构
for _, category := range categories {
if category.ParentID == nil {
rootCategories = append(rootCategories, category)
} else {
if parent, exists := categoryMap[*category.ParentID]; exists {
parent.Children = append(parent.Children, category)
}
}
}
return rootCategories, nil
}
// GetChildren 获取子分类
func (r *categoryRepository) GetChildren(parentID uint) ([]*models.Category, error) {
var categories []*models.Category
if err := r.db.Where("parent_id = ? AND is_active = ?", parentID, true).
Order("sort ASC, created_at DESC").
Find(&categories).Error; err != nil {
return nil, fmt.Errorf("failed to get child categories: %w", err)
}
return categories, nil
}
// GetStats 获取分类统计
func (r *categoryRepository) GetStats() (*models.CategoryStats, error) {
var stats models.CategoryStats
// 总分类数
if err := r.db.Model(&models.Category{}).Count(&stats.TotalCategories).Error; err != nil {
return nil, fmt.Errorf("failed to count total categories: %w", err)
}
// 计算最大层级
// 这里简化处理,实际应用中可能需要递归查询
stats.MaxLevel = 3
// 特色分类数量这里假设有一个is_featured字段实际可能需要调整
stats.FeaturedCount = 0
return &stats, nil
}
// UpdateSort 更新排序
func (r *categoryRepository) UpdateSort(id uint, sort int) error {
if err := r.db.Model(&models.Category{}).Where("id = ?", id).
Update("sort", sort).Error; err != nil {
return fmt.Errorf("failed to update sort: %w", err)
}
return nil
}
// GetPhotoCount 获取分类的照片数量
func (r *categoryRepository) GetPhotoCount(id uint) (int64, error) {
var count int64
if err := r.db.Model(&models.Photo{}).Where("category_id = ?", id).
Count(&count).Error; err != nil {
return 0, fmt.Errorf("failed to count photos: %w", err)
}
return count, nil
}

View File

@ -0,0 +1,78 @@
package postgres
import (
"fmt"
"time"
"gorm.io/gorm"
"gorm.io/driver/postgres"
"photography-backend/internal/config"
"photography-backend/internal/models"
)
// Database 数据库连接
type Database struct {
DB *gorm.DB
}
// NewDatabase 创建数据库连接
func NewDatabase(cfg *config.DatabaseConfig) (*Database, error) {
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
cfg.Host,
cfg.Port,
cfg.Username,
cfg.Password,
cfg.Database,
cfg.SSLMode,
)
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
// 获取底层sql.DB实例配置连接池
sqlDB, err := db.DB()
if err != nil {
return nil, fmt.Errorf("failed to get sql.DB instance: %w", err)
}
// 设置连接池参数
sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
sqlDB.SetConnMaxLifetime(time.Duration(cfg.ConnMaxLifetime) * time.Second)
// 测试连接
if err := sqlDB.Ping(); err != nil {
return nil, fmt.Errorf("failed to ping database: %w", err)
}
return &Database{DB: db}, nil
}
// AutoMigrate 自动迁移数据库表结构
func (d *Database) AutoMigrate() error {
return d.DB.AutoMigrate(
&models.User{},
&models.Category{},
&models.Tag{},
&models.Photo{},
)
}
// Close 关闭数据库连接
func (d *Database) Close() error {
sqlDB, err := d.DB.DB()
if err != nil {
return err
}
return sqlDB.Close()
}
// Health 检查数据库健康状态
func (d *Database) Health() error {
sqlDB, err := d.DB.DB()
if err != nil {
return err
}
return sqlDB.Ping()
}

View File

@ -0,0 +1,303 @@
package postgres
import (
"fmt"
"photography-backend/internal/models"
"gorm.io/gorm"
)
// PhotoRepository 照片仓库接口
type PhotoRepository interface {
Create(photo *models.Photo) error
GetByID(id uint) (*models.Photo, error)
Update(photo *models.Photo) error
Delete(id uint) error
List(params *models.PhotoListParams) ([]*models.Photo, int64, error)
GetByCategory(categoryID uint, page, limit int) ([]*models.Photo, int64, error)
GetByTag(tagID uint, page, limit int) ([]*models.Photo, int64, error)
GetByUser(userID uint, page, limit int) ([]*models.Photo, int64, error)
Search(query string, page, limit int) ([]*models.Photo, int64, error)
IncrementViewCount(id uint) error
IncrementLikeCount(id uint) error
UpdateStatus(id uint, status string) error
GetStats() (*PhotoStats, error)
}
// PhotoStats 照片统计
type PhotoStats struct {
Total int64 `json:"total"`
Published int64 `json:"published"`
Draft int64 `json:"draft"`
Archived int64 `json:"archived"`
}
// photoRepository 照片仓库实现
type photoRepository struct {
db *gorm.DB
}
// NewPhotoRepository 创建照片仓库
func NewPhotoRepository(db *gorm.DB) PhotoRepository {
return &photoRepository{db: db}
}
// Create 创建照片
func (r *photoRepository) Create(photo *models.Photo) error {
if err := r.db.Create(photo).Error; err != nil {
return fmt.Errorf("failed to create photo: %w", err)
}
return nil
}
// GetByID 根据ID获取照片
func (r *photoRepository) GetByID(id uint) (*models.Photo, error) {
var photo models.Photo
if err := r.db.Preload("Category").Preload("Tags").Preload("User").
First(&photo, id).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, fmt.Errorf("failed to get photo by id: %w", err)
}
return &photo, nil
}
// Update 更新照片
func (r *photoRepository) Update(photo *models.Photo) error {
if err := r.db.Save(photo).Error; err != nil {
return fmt.Errorf("failed to update photo: %w", err)
}
return nil
}
// Delete 删除照片
func (r *photoRepository) Delete(id uint) error {
if err := r.db.Delete(&models.Photo{}, id).Error; err != nil {
return fmt.Errorf("failed to delete photo: %w", err)
}
return nil
}
// List 获取照片列表
func (r *photoRepository) List(params *models.PhotoListParams) ([]*models.Photo, int64, error) {
var photos []*models.Photo
var total int64
query := r.db.Model(&models.Photo{}).
Preload("Category").
Preload("Tags").
Preload("User")
// 添加过滤条件
if params.CategoryID > 0 {
query = query.Where("category_id = ?", params.CategoryID)
}
if params.TagID > 0 {
query = query.Joins("JOIN photo_tags ON photos.id = photo_tags.photo_id").
Where("photo_tags.tag_id = ?", params.TagID)
}
if params.UserID > 0 {
query = query.Where("user_id = ?", params.UserID)
}
if params.Status != "" {
query = query.Where("status = ?", params.Status)
}
if params.Search != "" {
query = query.Where("title ILIKE ? OR description ILIKE ?",
"%"+params.Search+"%", "%"+params.Search+"%")
}
if params.Year > 0 {
query = query.Where("EXTRACT(YEAR FROM taken_at) = ?", params.Year)
}
if params.Month > 0 {
query = query.Where("EXTRACT(MONTH FROM taken_at) = ?", params.Month)
}
// 计算总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, fmt.Errorf("failed to count photos: %w", err)
}
// 排序
orderClause := fmt.Sprintf("%s %s", params.SortBy, params.SortOrder)
// 分页查询
offset := (params.Page - 1) * params.Limit
if err := query.Offset(offset).Limit(params.Limit).
Order(orderClause).
Find(&photos).Error; err != nil {
return nil, 0, fmt.Errorf("failed to list photos: %w", err)
}
return photos, total, nil
}
// GetByCategory 根据分类获取照片
func (r *photoRepository) GetByCategory(categoryID uint, page, limit int) ([]*models.Photo, int64, error) {
var photos []*models.Photo
var total int64
query := r.db.Model(&models.Photo{}).
Where("category_id = ? AND is_public = ?", categoryID, true).
Preload("Category").
Preload("Tags")
// 计算总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, fmt.Errorf("failed to count photos by category: %w", err)
}
// 分页查询
offset := (page - 1) * limit
if err := query.Offset(offset).Limit(limit).
Order("created_at DESC").
Find(&photos).Error; err != nil {
return nil, 0, fmt.Errorf("failed to get photos by category: %w", err)
}
return photos, total, nil
}
// GetByTag 根据标签获取照片
func (r *photoRepository) GetByTag(tagID uint, page, limit int) ([]*models.Photo, int64, error) {
var photos []*models.Photo
var total int64
query := r.db.Model(&models.Photo{}).
Joins("JOIN photo_tags ON photos.id = photo_tags.photo_id").
Where("photo_tags.tag_id = ? AND photos.is_public = ?", tagID, true).
Preload("Category").
Preload("Tags")
// 计算总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, fmt.Errorf("failed to count photos by tag: %w", err)
}
// 分页查询
offset := (page - 1) * limit
if err := query.Offset(offset).Limit(limit).
Order("photos.created_at DESC").
Find(&photos).Error; err != nil {
return nil, 0, fmt.Errorf("failed to get photos by tag: %w", err)
}
return photos, total, nil
}
// GetByUser 根据用户获取照片
func (r *photoRepository) GetByUser(userID uint, page, limit int) ([]*models.Photo, int64, error) {
var photos []*models.Photo
var total int64
query := r.db.Model(&models.Photo{}).
Where("user_id = ?", userID).
Preload("Category").
Preload("Tags")
// 计算总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, fmt.Errorf("failed to count photos by user: %w", err)
}
// 分页查询
offset := (page - 1) * limit
if err := query.Offset(offset).Limit(limit).
Order("created_at DESC").
Find(&photos).Error; err != nil {
return nil, 0, fmt.Errorf("failed to get photos by user: %w", err)
}
return photos, total, nil
}
// Search 搜索照片
func (r *photoRepository) Search(query string, page, limit int) ([]*models.Photo, int64, error) {
var photos []*models.Photo
var total int64
searchQuery := r.db.Model(&models.Photo{}).
Where("title ILIKE ? OR description ILIKE ? OR location ILIKE ?",
"%"+query+"%", "%"+query+"%", "%"+query+"%").
Where("is_public = ?", true).
Preload("Category").
Preload("Tags")
// 计算总数
if err := searchQuery.Count(&total).Error; err != nil {
return nil, 0, fmt.Errorf("failed to count search results: %w", err)
}
// 分页查询
offset := (page - 1) * limit
if err := searchQuery.Offset(offset).Limit(limit).
Order("created_at DESC").
Find(&photos).Error; err != nil {
return nil, 0, fmt.Errorf("failed to search photos: %w", err)
}
return photos, total, nil
}
// IncrementViewCount 增加浏览次数
func (r *photoRepository) IncrementViewCount(id uint) error {
if err := r.db.Model(&models.Photo{}).Where("id = ?", id).
Update("view_count", gorm.Expr("view_count + 1")).Error; err != nil {
return fmt.Errorf("failed to increment view count: %w", err)
}
return nil
}
// IncrementLikeCount 增加点赞次数
func (r *photoRepository) IncrementLikeCount(id uint) error {
if err := r.db.Model(&models.Photo{}).Where("id = ?", id).
Update("like_count", gorm.Expr("like_count + 1")).Error; err != nil {
return fmt.Errorf("failed to increment like count: %w", err)
}
return nil
}
// UpdateStatus 更新状态
func (r *photoRepository) UpdateStatus(id uint, status string) error {
if err := r.db.Model(&models.Photo{}).Where("id = ?", id).
Update("status", status).Error; err != nil {
return fmt.Errorf("failed to update status: %w", err)
}
return nil
}
// GetStats 获取照片统计
func (r *photoRepository) GetStats() (*PhotoStats, error) {
var stats PhotoStats
// 总数
if err := r.db.Model(&models.Photo{}).Count(&stats.Total).Error; err != nil {
return nil, fmt.Errorf("failed to count total photos: %w", err)
}
// 已发布
if err := r.db.Model(&models.Photo{}).Where("status = ?", models.StatusPublished).
Count(&stats.Published).Error; err != nil {
return nil, fmt.Errorf("failed to count published photos: %w", err)
}
// 草稿
if err := r.db.Model(&models.Photo{}).Where("status = ?", models.StatusDraft).
Count(&stats.Draft).Error; err != nil {
return nil, fmt.Errorf("failed to count draft photos: %w", err)
}
// 已归档
if err := r.db.Model(&models.Photo{}).Where("status = ?", models.StatusArchived).
Count(&stats.Archived).Error; err != nil {
return nil, fmt.Errorf("failed to count archived photos: %w", err)
}
return &stats, nil
}

View File

@ -0,0 +1,217 @@
package postgres
import (
"fmt"
"photography-backend/internal/models"
"gorm.io/gorm"
)
// TagRepository 标签仓库接口
type TagRepository interface {
Create(tag *models.Tag) error
GetByID(id uint) (*models.Tag, error)
GetByName(name string) (*models.Tag, error)
Update(tag *models.Tag) error
Delete(id uint) error
List(params *models.TagListParams) ([]*models.Tag, int64, error)
Search(query string, limit int) ([]*models.Tag, error)
GetPopular(limit int) ([]*models.Tag, error)
GetOrCreate(name string) (*models.Tag, error)
IncrementUseCount(id uint) error
DecrementUseCount(id uint) error
GetCloud(minUsage int, maxTags int) ([]*models.Tag, error)
}
// tagRepository 标签仓库实现
type tagRepository struct {
db *gorm.DB
}
// NewTagRepository 创建标签仓库
func NewTagRepository(db *gorm.DB) TagRepository {
return &tagRepository{db: db}
}
// Create 创建标签
func (r *tagRepository) Create(tag *models.Tag) error {
if err := r.db.Create(tag).Error; err != nil {
return fmt.Errorf("failed to create tag: %w", err)
}
return nil
}
// GetByID 根据ID获取标签
func (r *tagRepository) GetByID(id uint) (*models.Tag, error) {
var tag models.Tag
if err := r.db.First(&tag, id).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, fmt.Errorf("failed to get tag by id: %w", err)
}
return &tag, nil
}
// GetByName 根据名称获取标签
func (r *tagRepository) GetByName(name string) (*models.Tag, error) {
var tag models.Tag
if err := r.db.Where("name = ?", name).First(&tag).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, fmt.Errorf("failed to get tag by name: %w", err)
}
return &tag, nil
}
// Update 更新标签
func (r *tagRepository) Update(tag *models.Tag) error {
if err := r.db.Save(tag).Error; err != nil {
return fmt.Errorf("failed to update tag: %w", err)
}
return nil
}
// Delete 删除标签
func (r *tagRepository) Delete(id uint) error {
// 开启事务
tx := r.db.Begin()
// 删除照片标签关联
if err := tx.Exec("DELETE FROM photo_tags WHERE tag_id = ?", id).Error; err != nil {
tx.Rollback()
return fmt.Errorf("failed to delete photo tag relations: %w", err)
}
// 删除标签
if err := tx.Delete(&models.Tag{}, id).Error; err != nil {
tx.Rollback()
return fmt.Errorf("failed to delete tag: %w", err)
}
return tx.Commit().Error
}
// List 获取标签列表
func (r *tagRepository) List(params *models.TagListParams) ([]*models.Tag, int64, error) {
var tags []*models.Tag
var total int64
query := r.db.Model(&models.Tag{})
// 添加过滤条件
if params.Search != "" {
query = query.Where("name ILIKE ?", "%"+params.Search+"%")
}
if params.IsActive {
query = query.Where("is_active = ?", true)
}
// 计算总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, fmt.Errorf("failed to count tags: %w", err)
}
// 排序
orderClause := fmt.Sprintf("%s %s", params.SortBy, params.SortOrder)
// 分页查询
offset := (params.Page - 1) * params.Limit
if err := query.Offset(offset).Limit(params.Limit).
Order(orderClause).
Find(&tags).Error; err != nil {
return nil, 0, fmt.Errorf("failed to list tags: %w", err)
}
return tags, total, nil
}
// Search 搜索标签
func (r *tagRepository) Search(query string, limit int) ([]*models.Tag, error) {
var tags []*models.Tag
if err := r.db.Where("name ILIKE ? AND is_active = ?", "%"+query+"%", true).
Order("use_count DESC").
Limit(limit).
Find(&tags).Error; err != nil {
return nil, fmt.Errorf("failed to search tags: %w", err)
}
return tags, nil
}
// GetPopular 获取热门标签
func (r *tagRepository) GetPopular(limit int) ([]*models.Tag, error) {
var tags []*models.Tag
if err := r.db.Where("is_active = ?", true).
Order("use_count DESC").
Limit(limit).
Find(&tags).Error; err != nil {
return nil, fmt.Errorf("failed to get popular tags: %w", err)
}
return tags, nil
}
// GetOrCreate 获取或创建标签
func (r *tagRepository) GetOrCreate(name string) (*models.Tag, error) {
var tag models.Tag
// 先尝试获取
if err := r.db.Where("name = ?", name).First(&tag).Error; err != nil {
if err == gorm.ErrRecordNotFound {
// 不存在则创建
tag = models.Tag{
Name: name,
UseCount: 0,
IsActive: true,
}
if err := r.db.Create(&tag).Error; err != nil {
return nil, fmt.Errorf("failed to create tag: %w", err)
}
} else {
return nil, fmt.Errorf("failed to get tag: %w", err)
}
}
return &tag, nil
}
// IncrementUseCount 增加使用次数
func (r *tagRepository) IncrementUseCount(id uint) error {
if err := r.db.Model(&models.Tag{}).Where("id = ?", id).
Update("use_count", gorm.Expr("use_count + 1")).Error; err != nil {
return fmt.Errorf("failed to increment use count: %w", err)
}
return nil
}
// DecrementUseCount 减少使用次数
func (r *tagRepository) DecrementUseCount(id uint) error {
if err := r.db.Model(&models.Tag{}).Where("id = ?", id).
Update("use_count", gorm.Expr("GREATEST(use_count - 1, 0)")).Error; err != nil {
return fmt.Errorf("failed to decrement use count: %w", err)
}
return nil
}
// GetCloud 获取标签云数据
func (r *tagRepository) GetCloud(minUsage int, maxTags int) ([]*models.Tag, error) {
var tags []*models.Tag
query := r.db.Where("is_active = ?", true)
if minUsage > 0 {
query = query.Where("use_count >= ?", minUsage)
}
if err := query.Order("use_count DESC").
Limit(maxTags).
Find(&tags).Error; err != nil {
return nil, fmt.Errorf("failed to get tag cloud: %w", err)
}
return tags, nil
}

View File

@ -0,0 +1,129 @@
package postgres
import (
"fmt"
"photography-backend/internal/models"
"gorm.io/gorm"
)
// UserRepository 用户仓库接口
type UserRepository interface {
Create(user *models.User) error
GetByID(id uint) (*models.User, error)
GetByUsername(username string) (*models.User, error)
GetByEmail(email string) (*models.User, error)
Update(user *models.User) error
Delete(id uint) error
List(page, limit int, role string, isActive *bool) ([]*models.User, int64, error)
UpdateLastLogin(id uint) error
}
// userRepository 用户仓库实现
type userRepository struct {
db *gorm.DB
}
// NewUserRepository 创建用户仓库
func NewUserRepository(db *gorm.DB) UserRepository {
return &userRepository{db: db}
}
// Create 创建用户
func (r *userRepository) Create(user *models.User) error {
if err := r.db.Create(user).Error; err != nil {
return fmt.Errorf("failed to create user: %w", err)
}
return nil
}
// GetByID 根据ID获取用户
func (r *userRepository) GetByID(id uint) (*models.User, error) {
var user models.User
if err := r.db.First(&user, id).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, fmt.Errorf("failed to get user by id: %w", err)
}
return &user, nil
}
// GetByUsername 根据用户名获取用户
func (r *userRepository) GetByUsername(username string) (*models.User, error) {
var user models.User
if err := r.db.Where("username = ?", username).First(&user).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, fmt.Errorf("failed to get user by username: %w", err)
}
return &user, nil
}
// GetByEmail 根据邮箱获取用户
func (r *userRepository) GetByEmail(email string) (*models.User, error) {
var user models.User
if err := r.db.Where("email = ?", email).First(&user).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, fmt.Errorf("failed to get user by email: %w", err)
}
return &user, nil
}
// Update 更新用户
func (r *userRepository) Update(user *models.User) error {
if err := r.db.Save(user).Error; err != nil {
return fmt.Errorf("failed to update user: %w", err)
}
return nil
}
// Delete 删除用户
func (r *userRepository) Delete(id uint) error {
if err := r.db.Delete(&models.User{}, id).Error; err != nil {
return fmt.Errorf("failed to delete user: %w", err)
}
return nil
}
// List 获取用户列表
func (r *userRepository) List(page, limit int, role string, isActive *bool) ([]*models.User, int64, error) {
var users []*models.User
var total int64
query := r.db.Model(&models.User{})
// 添加过滤条件
if role != "" {
query = query.Where("role = ?", role)
}
if isActive != nil {
query = query.Where("is_active = ?", *isActive)
}
// 计算总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, fmt.Errorf("failed to count users: %w", err)
}
// 分页查询
offset := (page - 1) * limit
if err := query.Offset(offset).Limit(limit).
Order("created_at DESC").
Find(&users).Error; err != nil {
return nil, 0, fmt.Errorf("failed to list users: %w", err)
}
return users, total, nil
}
// UpdateLastLogin 更新最后登录时间
func (r *userRepository) UpdateLastLogin(id uint) error {
if err := r.db.Model(&models.User{}).Where("id = ?", id).
Update("last_login", gorm.Expr("NOW()")).Error; err != nil {
return fmt.Errorf("failed to update last login: %w", err)
}
return nil
}

View File

@ -0,0 +1,253 @@
package auth
import (
"fmt"
"time"
"golang.org/x/crypto/bcrypt"
"photography-backend/internal/models"
"photography-backend/internal/repository/postgres"
)
// AuthService 认证服务
type AuthService struct {
userRepo postgres.UserRepository
jwtService *JWTService
}
// NewAuthService 创建认证服务
func NewAuthService(userRepo postgres.UserRepository, jwtService *JWTService) *AuthService {
return &AuthService{
userRepo: userRepo,
jwtService: jwtService,
}
}
// Login 用户登录
func (s *AuthService) Login(req *models.LoginRequest) (*models.LoginResponse, error) {
// 根据用户名或邮箱查找用户
var user *models.User
var err error
// 尝试按用户名查找
user, err = s.userRepo.GetByUsername(req.Username)
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
// 如果用户名未找到,尝试按邮箱查找
if user == nil {
user, err = s.userRepo.GetByEmail(req.Username)
if err != nil {
return nil, fmt.Errorf("failed to get user by email: %w", err)
}
}
if user == nil {
return nil, fmt.Errorf("invalid credentials")
}
// 检查用户是否激活
if !user.IsActive {
return nil, fmt.Errorf("user account is deactivated")
}
// 验证密码
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(req.Password)); err != nil {
return nil, fmt.Errorf("invalid credentials")
}
// 生成JWT令牌
tokenPair, err := s.jwtService.GenerateTokenPair(user.ID, user.Username, user.Role)
if err != nil {
return nil, fmt.Errorf("failed to generate tokens: %w", err)
}
// 更新最后登录时间
if err := s.userRepo.UpdateLastLogin(user.ID); err != nil {
// 记录错误但不中断登录流程
fmt.Printf("failed to update last login: %v\n", err)
}
// 清除密码字段
user.Password = ""
return &models.LoginResponse{
AccessToken: tokenPair.AccessToken,
RefreshToken: tokenPair.RefreshToken,
TokenType: tokenPair.TokenType,
ExpiresIn: tokenPair.ExpiresIn,
User: user,
}, nil
}
// Register 用户注册
func (s *AuthService) Register(req *models.CreateUserRequest) (*models.User, error) {
// 检查用户名是否已存在
existingUser, err := s.userRepo.GetByUsername(req.Username)
if err != nil {
return nil, fmt.Errorf("failed to check username: %w", err)
}
if existingUser != nil {
return nil, fmt.Errorf("username already exists")
}
// 检查邮箱是否已存在
existingUser, err = s.userRepo.GetByEmail(req.Email)
if err != nil {
return nil, fmt.Errorf("failed to check email: %w", err)
}
if existingUser != nil {
return nil, fmt.Errorf("email already exists")
}
// 加密密码
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("failed to hash password: %w", err)
}
// 创建用户
user := &models.User{
Username: req.Username,
Email: req.Email,
Password: string(hashedPassword),
Name: req.Name,
Role: req.Role,
IsActive: true,
}
// 如果没有指定角色,默认为普通用户
if user.Role == "" {
user.Role = models.RoleUser
}
if err := s.userRepo.Create(user); err != nil {
return nil, fmt.Errorf("failed to create user: %w", err)
}
// 清除密码字段
user.Password = ""
return user, nil
}
// RefreshToken 刷新令牌
func (s *AuthService) RefreshToken(req *models.RefreshTokenRequest) (*models.LoginResponse, error) {
// 验证刷新令牌
claims, err := s.jwtService.ValidateToken(req.RefreshToken)
if err != nil {
return nil, fmt.Errorf("invalid refresh token: %w", err)
}
// 获取用户信息
user, err := s.userRepo.GetByID(claims.UserID)
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
if user == nil {
return nil, fmt.Errorf("user not found")
}
// 检查用户是否激活
if !user.IsActive {
return nil, fmt.Errorf("user account is deactivated")
}
// 生成新的令牌对
tokenPair, err := s.jwtService.GenerateTokenPair(user.ID, user.Username, user.Role)
if err != nil {
return nil, fmt.Errorf("failed to generate tokens: %w", err)
}
// 清除密码字段
user.Password = ""
return &models.LoginResponse{
AccessToken: tokenPair.AccessToken,
RefreshToken: tokenPair.RefreshToken,
TokenType: tokenPair.TokenType,
ExpiresIn: tokenPair.ExpiresIn,
User: user,
}, nil
}
// GetUserByID 根据ID获取用户
func (s *AuthService) GetUserByID(id uint) (*models.User, error) {
user, err := s.userRepo.GetByID(id)
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
if user == nil {
return nil, fmt.Errorf("user not found")
}
// 清除密码字段
user.Password = ""
return user, nil
}
// UpdatePassword 更新密码
func (s *AuthService) UpdatePassword(userID uint, req *models.UpdatePasswordRequest) error {
// 获取用户信息
user, err := s.userRepo.GetByID(userID)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
if user == nil {
return fmt.Errorf("user not found")
}
// 验证旧密码
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(req.OldPassword)); err != nil {
return fmt.Errorf("invalid old password")
}
// 加密新密码
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash password: %w", err)
}
// 更新密码
user.Password = string(hashedPassword)
if err := s.userRepo.Update(user); err != nil {
return fmt.Errorf("failed to update password: %w", err)
}
return nil
}
// CheckPermission 检查权限
func (s *AuthService) CheckPermission(userRole string, requiredRole string) bool {
roleLevel := map[string]int{
models.RoleUser: 1,
models.RoleEditor: 2,
models.RoleAdmin: 3,
}
userLevel, exists := roleLevel[userRole]
if !exists {
return false
}
requiredLevel, exists := roleLevel[requiredRole]
if !exists {
return false
}
return userLevel >= requiredLevel
}
// IsAdmin 检查是否为管理员
func (s *AuthService) IsAdmin(userRole string) bool {
return userRole == models.RoleAdmin
}
// IsEditor 检查是否为编辑者或以上
func (s *AuthService) IsEditor(userRole string) bool {
return userRole == models.RoleEditor || userRole == models.RoleAdmin
}

View File

@ -0,0 +1,129 @@
package auth
import (
"fmt"
"time"
"github.com/golang-jwt/jwt/v5"
"photography-backend/internal/config"
)
// JWTService JWT服务
type JWTService struct {
secretKey []byte
accessTokenDuration time.Duration
refreshTokenDuration time.Duration
}
// JWTClaims JWT声明
type JWTClaims struct {
UserID uint `json:"user_id"`
Username string `json:"username"`
Role string `json:"role"`
jwt.RegisteredClaims
}
// TokenPair 令牌对
type TokenPair struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
}
// NewJWTService 创建JWT服务
func NewJWTService(cfg *config.JWTConfig) *JWTService {
return &JWTService{
secretKey: []byte(cfg.Secret),
accessTokenDuration: config.AppConfig.GetJWTExpiration(),
refreshTokenDuration: config.AppConfig.GetJWTRefreshExpiration(),
}
}
// GenerateTokenPair 生成令牌对
func (s *JWTService) GenerateTokenPair(userID uint, username, role string) (*TokenPair, error) {
// 生成访问令牌
accessToken, err := s.generateToken(userID, username, role, s.accessTokenDuration)
if err != nil {
return nil, fmt.Errorf("failed to generate access token: %w", err)
}
// 生成刷新令牌
refreshToken, err := s.generateToken(userID, username, role, s.refreshTokenDuration)
if err != nil {
return nil, fmt.Errorf("failed to generate refresh token: %w", err)
}
return &TokenPair{
AccessToken: accessToken,
RefreshToken: refreshToken,
TokenType: "Bearer",
ExpiresIn: int64(s.accessTokenDuration.Seconds()),
}, nil
}
// generateToken 生成令牌
func (s *JWTService) generateToken(userID uint, username, role string, duration time.Duration) (string, error) {
now := time.Now()
claims := &JWTClaims{
UserID: userID,
Username: username,
Role: role,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(duration)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
Issuer: "photography-backend",
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString(s.secretKey)
}
// ValidateToken 验证令牌
func (s *JWTService) ValidateToken(tokenString string) (*JWTClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return s.secretKey, nil
})
if err != nil {
return nil, fmt.Errorf("failed to parse token: %w", err)
}
if claims, ok := token.Claims.(*JWTClaims); ok && token.Valid {
return claims, nil
}
return nil, fmt.Errorf("invalid token")
}
// RefreshToken 刷新令牌
func (s *JWTService) RefreshToken(refreshToken string) (*TokenPair, error) {
claims, err := s.ValidateToken(refreshToken)
if err != nil {
return nil, fmt.Errorf("invalid refresh token: %w", err)
}
// 生成新的令牌对
return s.GenerateTokenPair(claims.UserID, claims.Username, claims.Role)
}
// GetClaimsFromToken 从令牌中获取声明
func (s *JWTService) GetClaimsFromToken(tokenString string) (*JWTClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
return s.secretKey, nil
})
if err != nil {
return nil, err
}
if claims, ok := token.Claims.(*JWTClaims); ok {
return claims, nil
}
return nil, fmt.Errorf("invalid claims")
}