Compare commits
2 Commits
8a0792500e
...
5dd0bc19e4
| Author | SHA1 | Date | |
|---|---|---|---|
| 5dd0bc19e4 | |||
| 48b6a5f4aa |
31
admin/.eslintrc.json
Normal file
31
admin/.eslintrc.json
Normal file
@ -0,0 +1,31 @@
|
||||
{
|
||||
"env": {
|
||||
"browser": true,
|
||||
"es2020": true
|
||||
},
|
||||
"extends": [
|
||||
"eslint:recommended"
|
||||
],
|
||||
"ignorePatterns": [
|
||||
"dist",
|
||||
"node_modules",
|
||||
"*.config.*"
|
||||
],
|
||||
"parser": "@typescript-eslint/parser",
|
||||
"parserOptions": {
|
||||
"ecmaVersion": "latest",
|
||||
"sourceType": "module",
|
||||
"ecmaFeatures": {
|
||||
"jsx": true
|
||||
}
|
||||
},
|
||||
"plugins": ["@typescript-eslint"],
|
||||
"rules": {
|
||||
"no-unused-vars": "off",
|
||||
"@typescript-eslint/no-unused-vars": [
|
||||
"error",
|
||||
{ "argsIgnorePattern": "^_", "varsIgnorePattern": "^_" }
|
||||
],
|
||||
"no-console": "warn"
|
||||
}
|
||||
}
|
||||
@ -1,3 +1,4 @@
|
||||
import React from "react"
|
||||
import { cn } from "@/lib/utils"
|
||||
|
||||
function Skeleton({
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { useState, useCallback, useRef, useEffect } from 'react'
|
||||
import React, { useState, useCallback, useRef, useEffect } from 'react'
|
||||
import { useNavigate } from 'react-router-dom'
|
||||
import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'
|
||||
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'
|
||||
|
||||
@ -29,8 +29,8 @@ func main() {
|
||||
|
||||
// 添加静态文件服务
|
||||
server.AddRoute(rest.Route{
|
||||
Method: http.MethodGet,
|
||||
Path: "/uploads/*",
|
||||
Method: http.MethodGet,
|
||||
Path: "/uploads/*",
|
||||
Handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
http.StripPrefix("/uploads/", http.FileServer(http.Dir("uploads"))).ServeHTTP(w, r)
|
||||
},
|
||||
|
||||
@ -95,7 +95,7 @@ func main() {
|
||||
}
|
||||
migrateSteps = len(pendingMigrations)
|
||||
}
|
||||
|
||||
|
||||
if err := migrator.Migrate(migrateSteps); err != nil {
|
||||
log.Fatalf("Migration failed: %v", err)
|
||||
}
|
||||
@ -168,10 +168,10 @@ func showHelp() {
|
||||
|
||||
func createMigrationTemplate(name string) {
|
||||
// 生成版本号(基于当前时间)
|
||||
version := fmt.Sprintf("%d_%06d",
|
||||
getCurrentTimestamp(),
|
||||
version := fmt.Sprintf("%d_%06d",
|
||||
getCurrentTimestamp(),
|
||||
getCurrentMicroseconds())
|
||||
|
||||
|
||||
template := fmt.Sprintf(`// Migration: %s
|
||||
// Description: %s
|
||||
// Version: %s
|
||||
@ -198,22 +198,22 @@ var migration_%s = Migration{
|
||||
-- ALTER TABLE user DROP COLUMN new_field; -- Note: SQLite doesn't support DROP COLUMN
|
||||
%s,
|
||||
}`,
|
||||
name, name, version,
|
||||
name, name, version,
|
||||
version, version, name,
|
||||
"`", "`", "`", "`")
|
||||
|
||||
|
||||
filename := fmt.Sprintf("migrations/%s_%s.go", version, name)
|
||||
|
||||
|
||||
// 创建 migrations 目录
|
||||
if err := os.MkdirAll("migrations", 0755); err != nil {
|
||||
log.Fatalf("Failed to create migrations directory: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// 写入模板文件
|
||||
if err := os.WriteFile(filename, []byte(template), 0644); err != nil {
|
||||
log.Fatalf("Failed to create migration template: %v", err)
|
||||
}
|
||||
|
||||
|
||||
fmt.Printf("Created migration template: %s\n", filename)
|
||||
fmt.Println("Please:")
|
||||
fmt.Println("1. Edit the migration file to add your SQL")
|
||||
@ -227,4 +227,4 @@ func getCurrentTimestamp() int64 {
|
||||
|
||||
func getCurrentMicroseconds() int {
|
||||
return 1 // 当天的迁移计数,可以根据需要调整
|
||||
}
|
||||
}
|
||||
|
||||
@ -7,8 +7,8 @@ import (
|
||||
|
||||
type Config struct {
|
||||
rest.RestConf
|
||||
Database database.Config `json:"database"`
|
||||
Auth AuthConfig `json:"auth"`
|
||||
Database database.Config `json:"database"`
|
||||
Auth AuthConfig `json:"auth"`
|
||||
FileUpload FileUploadConfig `json:"file_upload"`
|
||||
Middleware MiddlewareConfig `json:"middleware"`
|
||||
}
|
||||
@ -28,6 +28,6 @@ type MiddlewareConfig struct {
|
||||
EnableCORS bool `json:"enable_cors"`
|
||||
EnableLogger bool `json:"enable_logger"`
|
||||
EnableErrorHandle bool `json:"enable_error_handle"`
|
||||
CORSOrigins []string `json:"cors_origins"`
|
||||
LogLevel string `json:"log_level"`
|
||||
CORSOrigins []string `json:"cors_origins"`
|
||||
LogLevel string `json:"log_level"`
|
||||
}
|
||||
|
||||
@ -33,7 +33,7 @@ func UploadPhotoHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
|
||||
Title: r.FormValue("title"),
|
||||
Description: r.FormValue("description"),
|
||||
}
|
||||
|
||||
|
||||
// 解析 category_id
|
||||
if categoryIdStr := r.FormValue("category_id"); categoryIdStr != "" {
|
||||
var categoryId int64
|
||||
|
||||
@ -39,24 +39,24 @@ func (l *LoginLogic) Login(req *types.LoginRequest) (resp *types.LoginResponse,
|
||||
logx.Errorf("查询用户失败: %v", err)
|
||||
return nil, errorx.NewWithCode(errorx.ServerError)
|
||||
}
|
||||
|
||||
|
||||
// 2. 验证密码
|
||||
if !hash.CheckPassword(req.Password, user.Password) {
|
||||
return nil, errorx.NewWithCode(errorx.InvalidPassword)
|
||||
}
|
||||
|
||||
|
||||
// 3. 检查用户状态
|
||||
if user.Status == 0 {
|
||||
return nil, errorx.NewWithCode(errorx.UserDisabled)
|
||||
}
|
||||
|
||||
|
||||
// 4. 生成 JWT token
|
||||
token, err := jwt.GenerateToken(user.Id, user.Username, l.svcCtx.Config.Auth.AccessSecret, time.Hour*24*7)
|
||||
if err != nil {
|
||||
logx.Errorf("生成 JWT token 失败: %v", err)
|
||||
return nil, errorx.NewWithCode(errorx.ServerError)
|
||||
}
|
||||
|
||||
|
||||
// 5. 返回登录结果
|
||||
return &types.LoginResponse{
|
||||
BaseResponse: types.BaseResponse{
|
||||
|
||||
@ -34,34 +34,34 @@ func (l *RegisterLogic) Register(req *types.RegisterRequest) (resp *types.Regist
|
||||
if err == nil && existingUser != nil {
|
||||
return nil, errors.New("用户名已存在")
|
||||
}
|
||||
|
||||
|
||||
// 2. 检查邮箱是否已存在
|
||||
existingEmail, err := l.svcCtx.UserModel.FindOneByEmail(l.ctx, req.Email)
|
||||
if err == nil && existingEmail != nil {
|
||||
return nil, errors.New("邮箱已存在")
|
||||
}
|
||||
|
||||
|
||||
// 3. 加密密码
|
||||
hashedPassword, err := hash.HashPassword(req.Password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
// 4. 创建用户
|
||||
user := &model.User{
|
||||
Username: req.Username,
|
||||
Email: req.Email,
|
||||
Password: hashedPassword,
|
||||
Status: 1, // 默认激活状态
|
||||
Username: req.Username,
|
||||
Email: req.Email,
|
||||
Password: hashedPassword,
|
||||
Status: 1, // 默认激活状态
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
|
||||
_, err = l.svcCtx.UserModel.Insert(l.ctx, user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
// 5. 返回注册结果
|
||||
return &types.RegisterResponse{
|
||||
BaseResponse: types.BaseResponse{
|
||||
|
||||
@ -35,12 +35,12 @@ func (l *CreateCategoryLogic) CreateCategory(req *types.CreateCategoryRequest) (
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
|
||||
_, err = l.svcCtx.CategoryModel.Insert(l.ctx, category)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
// 2. 返回结果
|
||||
return &types.CreateCategoryResponse{
|
||||
BaseResponse: types.BaseResponse{
|
||||
|
||||
@ -30,13 +30,13 @@ func (l *GetCategoryListLogic) GetCategoryList(req *types.GetCategoryListRequest
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
// 2. 统计总数
|
||||
total, err := l.svcCtx.CategoryModel.Count(l.ctx, req.Keyword)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
// 3. 转换数据结构
|
||||
var categoryList []types.Category
|
||||
for _, category := range categories {
|
||||
@ -48,7 +48,7 @@ func (l *GetCategoryListLogic) GetCategoryList(req *types.GetCategoryListRequest
|
||||
UpdatedAt: category.UpdatedAt.Unix(),
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 4. 返回结果
|
||||
return &types.GetCategoryListResponse{
|
||||
BaseResponse: types.BaseResponse{
|
||||
|
||||
@ -32,14 +32,14 @@ func (l *GetPhotoListLogic) GetPhotoList(req *types.GetPhotoListRequest) (resp *
|
||||
logx.Errorf("查询照片列表失败: %v", err)
|
||||
return nil, errorx.NewWithCode(errorx.ServerError)
|
||||
}
|
||||
|
||||
|
||||
// 2. 统计总数
|
||||
total, err := l.svcCtx.PhotoModel.Count(l.ctx, req.CategoryId, req.UserId, req.Keyword)
|
||||
if err != nil {
|
||||
logx.Errorf("统计照片数量失败: %v", err)
|
||||
return nil, errorx.NewWithCode(errorx.ServerError)
|
||||
}
|
||||
|
||||
|
||||
// 3. 转换数据结构
|
||||
var photoList []types.Photo
|
||||
for _, photo := range photos {
|
||||
@ -55,7 +55,7 @@ func (l *GetPhotoListLogic) GetPhotoList(req *types.GetPhotoListRequest) (resp *
|
||||
UpdatedAt: photo.UpdatedAt.Unix(),
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 4. 返回结果
|
||||
return &types.GetPhotoListResponse{
|
||||
BaseResponse: types.BaseResponse{
|
||||
|
||||
@ -36,7 +36,7 @@ func (l *GetPhotoLogic) GetPhoto(req *types.GetPhotoRequest) (resp *types.GetPho
|
||||
logx.Errorf("查询照片失败: %v", err)
|
||||
return nil, errorx.NewWithCode(errorx.ServerError)
|
||||
}
|
||||
|
||||
|
||||
// 2. 返回结果
|
||||
return &types.GetPhotoResponse{
|
||||
BaseResponse: types.BaseResponse{
|
||||
|
||||
@ -40,7 +40,7 @@ func (l *UploadPhotoLogic) UploadPhoto(req *types.UploadPhotoRequest, file multi
|
||||
// 后续需要实现JWT中间件
|
||||
userId = int64(1)
|
||||
}
|
||||
|
||||
|
||||
// 2. 验证分类是否存在
|
||||
_, err = l.svcCtx.CategoryModel.FindOne(l.ctx, req.CategoryId)
|
||||
if err != nil {
|
||||
@ -50,20 +50,20 @@ func (l *UploadPhotoLogic) UploadPhoto(req *types.UploadPhotoRequest, file multi
|
||||
logx.Errorf("查询分类失败: %v", err)
|
||||
return nil, errorx.NewWithCode(errorx.ServerError)
|
||||
}
|
||||
|
||||
|
||||
// 3. 处理文件上传
|
||||
fileConfig := fileUtil.Config{
|
||||
MaxSize: l.svcCtx.Config.FileUpload.MaxSize,
|
||||
UploadDir: l.svcCtx.Config.FileUpload.UploadDir,
|
||||
AllowedTypes: l.svcCtx.Config.FileUpload.AllowedTypes,
|
||||
}
|
||||
|
||||
|
||||
uploadResult, err := fileUtil.UploadPhoto(file, header, fileConfig)
|
||||
if err != nil {
|
||||
logx.Errorf("文件上传失败: %v", err)
|
||||
return nil, errorx.NewWithCode(errorx.PhotoUploadFail)
|
||||
}
|
||||
|
||||
|
||||
// 4. 创建照片记录
|
||||
photo := &model.Photo{
|
||||
Title: req.Title,
|
||||
@ -75,7 +75,7 @@ func (l *UploadPhotoLogic) UploadPhoto(req *types.UploadPhotoRequest, file multi
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
|
||||
_, err = l.svcCtx.PhotoModel.Insert(l.ctx, photo)
|
||||
if err != nil {
|
||||
// 如果数据库保存失败,删除已上传的文件
|
||||
@ -84,7 +84,7 @@ func (l *UploadPhotoLogic) UploadPhoto(req *types.UploadPhotoRequest, file multi
|
||||
logx.Errorf("保存照片记录失败: %v", err)
|
||||
return nil, errorx.NewWithCode(errorx.ServerError)
|
||||
}
|
||||
|
||||
|
||||
// 5. 返回上传结果
|
||||
return &types.UploadPhotoResponse{
|
||||
BaseResponse: types.BaseResponse{
|
||||
|
||||
@ -50,7 +50,7 @@ func (l *DeleteUserLogic) DeleteUser(req *types.DeleteUserRequest) (resp *types.
|
||||
// 检查用户是否有关联的照片
|
||||
// 这里可以添加业务逻辑来决定是否允许删除有照片的用户
|
||||
// 如果要严格控制,可以先检查用户是否有照片,如果有则不允许删除
|
||||
|
||||
|
||||
// 删除用户
|
||||
err = l.svcCtx.UserModel.Delete(l.ctx, req.Id)
|
||||
if err != nil {
|
||||
|
||||
@ -30,13 +30,13 @@ func (l *GetUserListLogic) GetUserList(req *types.GetUserListRequest) (resp *typ
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
// 2. 统计总数
|
||||
total, err := l.svcCtx.UserModel.Count(l.ctx, req.Keyword)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
// 3. 转换数据结构(不返回密码)
|
||||
var userList []types.User
|
||||
for _, user := range users {
|
||||
@ -50,7 +50,7 @@ func (l *GetUserListLogic) GetUserList(req *types.GetUserListRequest) (resp *typ
|
||||
UpdatedAt: user.UpdatedAt.Unix(),
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 4. 返回结果
|
||||
return &types.GetUserListResponse{
|
||||
BaseResponse: types.BaseResponse{
|
||||
|
||||
@ -55,7 +55,7 @@ func (l *UploadAvatarLogic) UploadAvatar(req *types.UploadAvatarRequest, r *http
|
||||
// 4. 获取上传的文件
|
||||
uploadedFile, header, err := r.FormFile("avatar")
|
||||
if err != nil {
|
||||
return nil, errorx.New(errorx.ParamError, "获取上传文件失败: " + err.Error())
|
||||
return nil, errorx.New(errorx.ParamError, "获取上传文件失败: "+err.Error())
|
||||
}
|
||||
defer uploadedFile.Close()
|
||||
|
||||
@ -90,10 +90,10 @@ func (l *UploadAvatarLogic) UploadAvatar(req *types.UploadAvatarRequest, r *http
|
||||
|
||||
filename := fmt.Sprintf("avatar_%d_%d%s", req.Id, time.Now().Unix(), ext)
|
||||
avatarDir := "uploads/avatars"
|
||||
|
||||
|
||||
// 8. 确保头像目录存在
|
||||
if err := os.MkdirAll(avatarDir, 0755); err != nil {
|
||||
return nil, errorx.New(errorx.ServerError, "创建头像目录失败: " + err.Error())
|
||||
return nil, errorx.New(errorx.ServerError, "创建头像目录失败: "+err.Error())
|
||||
}
|
||||
|
||||
avatarPath := filepath.Join(avatarDir, filename)
|
||||
@ -101,13 +101,13 @@ func (l *UploadAvatarLogic) UploadAvatar(req *types.UploadAvatarRequest, r *http
|
||||
// 9. 保存原始头像文件
|
||||
destFile, err := os.Create(avatarPath)
|
||||
if err != nil {
|
||||
return nil, errorx.New(errorx.ServerError, "创建头像文件失败: " + err.Error())
|
||||
return nil, errorx.New(errorx.ServerError, "创建头像文件失败: "+err.Error())
|
||||
}
|
||||
defer destFile.Close()
|
||||
|
||||
_, err = io.Copy(destFile, uploadedFile)
|
||||
if err != nil {
|
||||
return nil, errorx.New(errorx.ServerError, "保存头像文件失败: " + err.Error())
|
||||
return nil, errorx.New(errorx.ServerError, "保存头像文件失败: "+err.Error())
|
||||
}
|
||||
|
||||
// 10. 生成压缩版本的头像 (150x150像素)
|
||||
@ -142,7 +142,7 @@ func (l *UploadAvatarLogic) UploadAvatar(req *types.UploadAvatarRequest, r *http
|
||||
if avatarPath != compressedPath {
|
||||
os.Remove(compressedPath)
|
||||
}
|
||||
return nil, errorx.New(errorx.ServerError, "更新用户头像失败: " + err.Error())
|
||||
return nil, errorx.New(errorx.ServerError, "更新用户头像失败: "+err.Error())
|
||||
}
|
||||
|
||||
return &types.UploadAvatarResponse{
|
||||
|
||||
@ -73,4 +73,4 @@ func (e UnauthorizedError) Error() string {
|
||||
|
||||
func NewUnauthorizedError(message string) UnauthorizedError {
|
||||
return UnauthorizedError{Message: message}
|
||||
}
|
||||
}
|
||||
|
||||
@ -87,7 +87,7 @@ func NewCORSMiddleware(config CORSConfig) *CORSMiddleware {
|
||||
func (m *CORSMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
origin := r.Header.Get("Origin")
|
||||
|
||||
|
||||
// 检查来源是否被允许
|
||||
if origin != "" && m.isOriginAllowed(origin) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
@ -160,16 +160,16 @@ func (m *CORSMiddleware) isOriginAllowed(origin string) bool {
|
||||
func (m *CORSMiddleware) setSecurityHeaders(w http.ResponseWriter) {
|
||||
// 防止点击劫持
|
||||
w.Header().Set("X-Frame-Options", "DENY")
|
||||
|
||||
|
||||
// 防止 MIME 类型嗅探
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
|
||||
|
||||
// XSS 保护
|
||||
w.Header().Set("X-XSS-Protection", "1; mode=block")
|
||||
|
||||
|
||||
// 引用者策略
|
||||
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
|
||||
|
||||
// 内容安全策略 (基础版)
|
||||
w.Header().Set("Content-Security-Policy", "default-src 'self'; img-src 'self' data: https:; style-src 'self' 'unsafe-inline'; script-src 'self'")
|
||||
}
|
||||
}
|
||||
|
||||
@ -19,8 +19,8 @@ type ErrorConfig struct {
|
||||
EnableDetailedErrors bool // 是否启用详细错误信息 (开发环境)
|
||||
EnableStackTrace bool // 是否启用堆栈跟踪
|
||||
EnableErrorMonitor bool // 是否启用错误监控
|
||||
IgnoreHTTPCodes []int // 忽略的HTTP状态码 (不记录为错误)
|
||||
SensitiveFields []string // 敏感字段列表 (日志时隐藏)
|
||||
IgnoreHTTPCodes []int // 忽略的HTTP状态码 (不记录为错误)
|
||||
SensitiveFields []string // 敏感字段列表 (日志时隐藏)
|
||||
}
|
||||
|
||||
// DefaultErrorConfig 默认错误配置
|
||||
@ -29,8 +29,8 @@ func DefaultErrorConfig() ErrorConfig {
|
||||
EnableDetailedErrors: false, // 生产环境默认关闭
|
||||
EnableStackTrace: false, // 生产环境默认关闭
|
||||
EnableErrorMonitor: true,
|
||||
IgnoreHTTPCodes: []int{http.StatusNotFound, http.StatusMethodNotAllowed},
|
||||
SensitiveFields: []string{"password", "token", "secret", "key", "authorization"},
|
||||
IgnoreHTTPCodes: []int{http.StatusNotFound, http.StatusMethodNotAllowed},
|
||||
SensitiveFields: []string{"password", "token", "secret", "key", "authorization"},
|
||||
}
|
||||
}
|
||||
|
||||
@ -111,7 +111,7 @@ func (m *ErrorMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
|
||||
// handlePanic 处理panic
|
||||
func (m *ErrorMiddleware) handlePanic(w *errorResponseWriter, r *http.Request, err interface{}) {
|
||||
stack := string(debug.Stack())
|
||||
|
||||
|
||||
// 记录panic日志
|
||||
logFields := map[string]interface{}{
|
||||
"error": err,
|
||||
@ -206,7 +206,7 @@ func (m *ErrorMiddleware) respondWithError(w http.ResponseWriter, r *http.Reques
|
||||
|
||||
// 设置HTTP状态码
|
||||
httpStatus := errorx.GetHttpStatus(err.Code)
|
||||
|
||||
|
||||
// 设置响应头
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(httpStatus)
|
||||
@ -218,10 +218,10 @@ func (m *ErrorMiddleware) respondWithError(w http.ResponseWriter, r *http.Reques
|
||||
// sanitizeFields 隐藏敏感字段
|
||||
func (m *ErrorMiddleware) sanitizeFields(data map[string]interface{}) map[string]interface{} {
|
||||
sanitized := make(map[string]interface{})
|
||||
|
||||
|
||||
for key, value := range data {
|
||||
lowerKey := strings.ToLower(key)
|
||||
|
||||
|
||||
// 检查是否为敏感字段
|
||||
sensitive := false
|
||||
for _, sensitiveField := range m.config.SensitiveFields {
|
||||
@ -230,7 +230,7 @@ func (m *ErrorMiddleware) sanitizeFields(data map[string]interface{}) map[string
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if sensitive {
|
||||
sanitized[key] = "***REDACTED***"
|
||||
} else {
|
||||
@ -242,7 +242,7 @@ func (m *ErrorMiddleware) sanitizeFields(data map[string]interface{}) map[string
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return sanitized
|
||||
}
|
||||
|
||||
@ -322,4 +322,4 @@ var CommonErrors = struct {
|
||||
Code: 429,
|
||||
Msg: "Rate Limit Exceeded",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@ -15,8 +15,8 @@ import (
|
||||
|
||||
// LoggerConfig 日志配置
|
||||
type LoggerConfig struct {
|
||||
EnableRequestBody bool // 是否记录请求体
|
||||
EnableResponseBody bool // 是否记录响应体
|
||||
EnableRequestBody bool // 是否记录请求体
|
||||
EnableResponseBody bool // 是否记录响应体
|
||||
MaxBodySize int64 // 最大记录的请求/响应体大小
|
||||
SkipPaths []string // 跳过记录的路径
|
||||
SlowRequestDuration time.Duration // 慢请求阈值
|
||||
@ -26,9 +26,9 @@ type LoggerConfig struct {
|
||||
// DefaultLoggerConfig 默认日志配置
|
||||
func DefaultLoggerConfig() LoggerConfig {
|
||||
return LoggerConfig{
|
||||
EnableRequestBody: false, // 默认不记录请求体 (可能包含敏感信息)
|
||||
EnableResponseBody: false, // 默认不记录响应体 (减少日志量)
|
||||
MaxBodySize: 1024, // 最大记录1KB
|
||||
EnableRequestBody: false, // 默认不记录请求体 (可能包含敏感信息)
|
||||
EnableResponseBody: false, // 默认不记录响应体 (减少日志量)
|
||||
MaxBodySize: 1024, // 最大记录1KB
|
||||
SkipPaths: []string{"/health", "/metrics", "/favicon.ico"},
|
||||
SlowRequestDuration: 1 * time.Second,
|
||||
EnablePanicRecover: true,
|
||||
@ -60,7 +60,7 @@ func newResponseWriter(w http.ResponseWriter) *responseWriter {
|
||||
return &responseWriter{
|
||||
ResponseWriter: w,
|
||||
status: http.StatusOK,
|
||||
body: &bytes.Buffer{},
|
||||
body: &bytes.Buffer{},
|
||||
}
|
||||
}
|
||||
|
||||
@ -68,12 +68,12 @@ func newResponseWriter(w http.ResponseWriter) *responseWriter {
|
||||
func (rw *responseWriter) Write(b []byte) (int, error) {
|
||||
size, err := rw.ResponseWriter.Write(b)
|
||||
rw.size += int64(size)
|
||||
|
||||
|
||||
// 记录响应体 (如果启用)
|
||||
if rw.body.Len() < int(1024) { // 限制缓存大小
|
||||
rw.body.Write(b)
|
||||
}
|
||||
|
||||
|
||||
return size, err
|
||||
}
|
||||
|
||||
@ -163,7 +163,7 @@ func (m *LoggerMiddleware) generateRequestID(r *http.Request) string {
|
||||
if requestID := r.Header.Get("X-Request-ID"); requestID != "" {
|
||||
return requestID
|
||||
}
|
||||
|
||||
|
||||
// 生成新的请求ID
|
||||
return fmt.Sprintf("%d-%s", time.Now().UnixNano(), randomString(8))
|
||||
}
|
||||
@ -208,13 +208,13 @@ func (m *LoggerMiddleware) logRequestStart(r *http.Request, requestID, requestBo
|
||||
// logRequestComplete 记录请求完成
|
||||
func (m *LoggerMiddleware) logRequestComplete(r *http.Request, requestID string, status int, size int64, duration time.Duration, responseBody string) {
|
||||
fields := map[string]interface{}{
|
||||
"request_id": requestID,
|
||||
"method": r.Method,
|
||||
"path": r.URL.Path,
|
||||
"status": status,
|
||||
"response_size": size,
|
||||
"duration_ms": duration.Milliseconds(),
|
||||
"duration": duration.String(),
|
||||
"request_id": requestID,
|
||||
"method": r.Method,
|
||||
"path": r.URL.Path,
|
||||
"status": status,
|
||||
"response_size": size,
|
||||
"duration_ms": duration.Milliseconds(),
|
||||
"duration": duration.String(),
|
||||
}
|
||||
|
||||
if responseBody != "" {
|
||||
@ -267,7 +267,7 @@ func getClientIP(r *http.Request) string {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 使用 RemoteAddr
|
||||
if ip := r.RemoteAddr; ip != "" {
|
||||
// 移除端口号
|
||||
@ -276,7 +276,7 @@ func getClientIP(r *http.Request) string {
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
@ -296,4 +296,4 @@ func randomString(length int) string {
|
||||
result[i] = charset[time.Now().UnixNano()%int64(len(charset))]
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
}
|
||||
|
||||
@ -13,11 +13,11 @@ import (
|
||||
|
||||
// MiddlewareManager 中间件管理器
|
||||
type MiddlewareManager struct {
|
||||
config config.Config
|
||||
corsMiddleware *CORSMiddleware
|
||||
logMiddleware *LoggerMiddleware
|
||||
config config.Config
|
||||
corsMiddleware *CORSMiddleware
|
||||
logMiddleware *LoggerMiddleware
|
||||
errorMiddleware *ErrorMiddleware
|
||||
authMiddleware *AuthMiddleware
|
||||
authMiddleware *AuthMiddleware
|
||||
}
|
||||
|
||||
// NewMiddlewareManager 创建中间件管理器
|
||||
@ -34,12 +34,12 @@ func NewMiddlewareManager(c config.Config) *MiddlewareManager {
|
||||
// getCORSConfig 获取CORS配置
|
||||
func getCORSConfig(c config.Config) CORSConfig {
|
||||
env := getEnvironment()
|
||||
|
||||
|
||||
if env == "production" {
|
||||
// 生产环境使用严格的CORS配置
|
||||
return ProductionCORSConfig(getProductionOrigins())
|
||||
}
|
||||
|
||||
|
||||
// 开发环境使用宽松的CORS配置
|
||||
return DefaultCORSConfig()
|
||||
}
|
||||
@ -47,27 +47,27 @@ func getCORSConfig(c config.Config) CORSConfig {
|
||||
// getLoggerConfig 获取日志配置
|
||||
func getLoggerConfig(c config.Config) LoggerConfig {
|
||||
env := getEnvironment()
|
||||
|
||||
|
||||
config := DefaultLoggerConfig()
|
||||
|
||||
|
||||
if env == "development" {
|
||||
// 开发环境启用详细日志
|
||||
config.EnableRequestBody = true
|
||||
config.EnableResponseBody = true
|
||||
config.MaxBodySize = 4096
|
||||
}
|
||||
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// getErrorConfig 获取错误配置
|
||||
func getErrorConfig(c config.Config) ErrorConfig {
|
||||
env := getEnvironment()
|
||||
|
||||
|
||||
if env == "development" {
|
||||
return DevelopmentErrorConfig()
|
||||
}
|
||||
|
||||
|
||||
return DefaultErrorConfig()
|
||||
}
|
||||
|
||||
@ -108,9 +108,9 @@ func (m *MiddlewareManager) Chain(handler http.HandlerFunc, middlewares ...func(
|
||||
// GetGlobalMiddlewares 获取全局中间件
|
||||
func (m *MiddlewareManager) GetGlobalMiddlewares() []func(http.HandlerFunc) http.HandlerFunc {
|
||||
return []func(http.HandlerFunc) http.HandlerFunc{
|
||||
m.errorMiddleware.Handle, // 错误处理 (最外层)
|
||||
m.corsMiddleware.Handle, // CORS 处理
|
||||
m.logMiddleware.Handle, // 日志记录
|
||||
m.errorMiddleware.Handle, // 错误处理 (最外层)
|
||||
m.corsMiddleware.Handle, // CORS 处理
|
||||
m.logMiddleware.Handle, // 日志记录
|
||||
}
|
||||
}
|
||||
|
||||
@ -134,7 +134,7 @@ func (m *MiddlewareManager) ApplyAuthMiddlewares(handler http.HandlerFunc) http.
|
||||
func (m *MiddlewareManager) HealthCheck(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"status":"ok","timestamp":"` +
|
||||
w.Write([]byte(`{"status":"ok","timestamp":"` +
|
||||
time.Now().Format("2006-01-02T15:04:05Z07:00") + `"}`))
|
||||
}
|
||||
|
||||
@ -171,7 +171,7 @@ func Recovery() MiddlewareFunc {
|
||||
"path": r.URL.Path,
|
||||
}
|
||||
logx.WithContext(r.Context()).Errorf("Panic recovered in Recovery middleware: %+v", fields)
|
||||
|
||||
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
@ -188,10 +188,10 @@ func RequestID() MiddlewareFunc {
|
||||
if requestID == "" {
|
||||
requestID = generateRequestID()
|
||||
}
|
||||
|
||||
|
||||
w.Header().Set("X-Request-ID", requestID)
|
||||
r.Header.Set("X-Request-ID", requestID)
|
||||
|
||||
|
||||
next(w, r)
|
||||
})
|
||||
}
|
||||
@ -200,4 +200,4 @@ func RequestID() MiddlewareFunc {
|
||||
// generateRequestID 生成请求ID
|
||||
func generateRequestID() string {
|
||||
return randomString(16)
|
||||
}
|
||||
}
|
||||
|
||||
@ -3,8 +3,8 @@ package model
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"github.com/zeromicro/go-zero/core/stores/sqlx"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var _ CategoryModel = (*customCategoryModel)(nil)
|
||||
@ -39,20 +39,20 @@ func (m *customCategoryModel) withSession(session sqlx.Session) CategoryModel {
|
||||
func (m *customCategoryModel) FindList(ctx context.Context, page, pageSize int, keyword string) ([]*Category, error) {
|
||||
var conditions []string
|
||||
var args []interface{}
|
||||
|
||||
|
||||
if keyword != "" {
|
||||
conditions = append(conditions, "(`name` LIKE ? OR `description` LIKE ?)")
|
||||
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
|
||||
|
||||
whereClause := ""
|
||||
if len(conditions) > 0 {
|
||||
whereClause = " WHERE " + strings.Join(conditions, " AND ")
|
||||
}
|
||||
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
args = append(args, pageSize, offset)
|
||||
|
||||
|
||||
query := fmt.Sprintf("select %s from %s%s ORDER BY `created_at` DESC LIMIT ? OFFSET ?", categoryRows, m.table, whereClause)
|
||||
var resp []*Category
|
||||
err := m.conn.QueryRowsCtx(ctx, &resp, query, args...)
|
||||
@ -63,17 +63,17 @@ func (m *customCategoryModel) FindList(ctx context.Context, page, pageSize int,
|
||||
func (m *customCategoryModel) Count(ctx context.Context, keyword string) (int64, error) {
|
||||
var conditions []string
|
||||
var args []interface{}
|
||||
|
||||
|
||||
if keyword != "" {
|
||||
conditions = append(conditions, "(`name` LIKE ? OR `description` LIKE ?)")
|
||||
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
|
||||
|
||||
whereClause := ""
|
||||
if len(conditions) > 0 {
|
||||
whereClause = " WHERE " + strings.Join(conditions, " AND ")
|
||||
}
|
||||
|
||||
|
||||
query := fmt.Sprintf("select count(*) from %s%s", m.table, whereClause)
|
||||
var count int64
|
||||
err := m.conn.QueryRowCtx(ctx, &count, query, args...)
|
||||
|
||||
@ -3,8 +3,8 @@ package model
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"github.com/zeromicro/go-zero/core/stores/sqlx"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var _ PhotoModel = (*customPhotoModel)(nil)
|
||||
@ -39,30 +39,30 @@ func (m *customPhotoModel) withSession(session sqlx.Session) PhotoModel {
|
||||
func (m *customPhotoModel) FindList(ctx context.Context, page, pageSize int, categoryId, userId int64, keyword string) ([]*Photo, error) {
|
||||
var conditions []string
|
||||
var args []interface{}
|
||||
|
||||
|
||||
if categoryId > 0 {
|
||||
conditions = append(conditions, "`category_id` = ?")
|
||||
args = append(args, categoryId)
|
||||
}
|
||||
|
||||
|
||||
if userId > 0 {
|
||||
conditions = append(conditions, "`user_id` = ?")
|
||||
args = append(args, userId)
|
||||
}
|
||||
|
||||
|
||||
if keyword != "" {
|
||||
conditions = append(conditions, "(`title` LIKE ? OR `description` LIKE ?)")
|
||||
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
|
||||
|
||||
whereClause := ""
|
||||
if len(conditions) > 0 {
|
||||
whereClause = " WHERE " + strings.Join(conditions, " AND ")
|
||||
}
|
||||
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
args = append(args, pageSize, offset)
|
||||
|
||||
|
||||
query := fmt.Sprintf("select %s from %s%s ORDER BY `created_at` DESC LIMIT ? OFFSET ?", photoRows, m.table, whereClause)
|
||||
var resp []*Photo
|
||||
err := m.conn.QueryRowsCtx(ctx, &resp, query, args...)
|
||||
@ -73,27 +73,27 @@ func (m *customPhotoModel) FindList(ctx context.Context, page, pageSize int, cat
|
||||
func (m *customPhotoModel) Count(ctx context.Context, categoryId, userId int64, keyword string) (int64, error) {
|
||||
var conditions []string
|
||||
var args []interface{}
|
||||
|
||||
|
||||
if categoryId > 0 {
|
||||
conditions = append(conditions, "`category_id` = ?")
|
||||
args = append(args, categoryId)
|
||||
}
|
||||
|
||||
|
||||
if userId > 0 {
|
||||
conditions = append(conditions, "`user_id` = ?")
|
||||
args = append(args, userId)
|
||||
}
|
||||
|
||||
|
||||
if keyword != "" {
|
||||
conditions = append(conditions, "(`title` LIKE ? OR `description` LIKE ?)")
|
||||
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
|
||||
|
||||
whereClause := ""
|
||||
if len(conditions) > 0 {
|
||||
whereClause = " WHERE " + strings.Join(conditions, " AND ")
|
||||
}
|
||||
|
||||
|
||||
query := fmt.Sprintf("select count(*) from %s%s", m.table, whereClause)
|
||||
var count int64
|
||||
err := m.conn.QueryRowCtx(ctx, &count, query, args...)
|
||||
|
||||
@ -3,8 +3,8 @@ package model
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"github.com/zeromicro/go-zero/core/stores/sqlx"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var _ UserModel = (*customUserModel)(nil)
|
||||
@ -39,20 +39,20 @@ func (m *customUserModel) withSession(session sqlx.Session) UserModel {
|
||||
func (m *customUserModel) FindList(ctx context.Context, page, pageSize int, keyword string) ([]*User, error) {
|
||||
var conditions []string
|
||||
var args []interface{}
|
||||
|
||||
|
||||
if keyword != "" {
|
||||
conditions = append(conditions, "(`username` LIKE ? OR `email` LIKE ?)")
|
||||
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
|
||||
|
||||
whereClause := ""
|
||||
if len(conditions) > 0 {
|
||||
whereClause = " WHERE " + strings.Join(conditions, " AND ")
|
||||
}
|
||||
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
args = append(args, pageSize, offset)
|
||||
|
||||
|
||||
query := fmt.Sprintf("select %s from %s%s ORDER BY `created_at` DESC LIMIT ? OFFSET ?", userRows, m.table, whereClause)
|
||||
var resp []*User
|
||||
err := m.conn.QueryRowsCtx(ctx, &resp, query, args...)
|
||||
@ -63,17 +63,17 @@ func (m *customUserModel) FindList(ctx context.Context, page, pageSize int, keyw
|
||||
func (m *customUserModel) Count(ctx context.Context, keyword string) (int64, error) {
|
||||
var conditions []string
|
||||
var args []interface{}
|
||||
|
||||
|
||||
if keyword != "" {
|
||||
conditions = append(conditions, "(`username` LIKE ? OR `email` LIKE ?)")
|
||||
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
|
||||
|
||||
whereClause := ""
|
||||
if len(conditions) > 0 {
|
||||
whereClause = " WHERE " + strings.Join(conditions, " AND ")
|
||||
}
|
||||
|
||||
|
||||
query := fmt.Sprintf("select count(*) from %s%s", m.table, whereClause)
|
||||
var count int64
|
||||
err := m.conn.QueryRowCtx(ctx, &count, query, args...)
|
||||
|
||||
@ -2,17 +2,17 @@ package svc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/zeromicro/go-zero/core/stores/sqlx"
|
||||
"gorm.io/gorm"
|
||||
"photography-backend/internal/config"
|
||||
"photography-backend/internal/middleware"
|
||||
"photography-backend/internal/model"
|
||||
"photography-backend/pkg/utils/database"
|
||||
"github.com/zeromicro/go-zero/core/stores/sqlx"
|
||||
)
|
||||
|
||||
type ServiceContext struct {
|
||||
Config config.Config
|
||||
DB *gorm.DB
|
||||
Config config.Config
|
||||
DB *gorm.DB
|
||||
UserModel model.UserModel
|
||||
PhotoModel model.PhotoModel
|
||||
CategoryModel model.CategoryModel
|
||||
@ -24,13 +24,13 @@ func NewServiceContext(c config.Config) *ServiceContext {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
|
||||
// Create sqlx connection for go-zero models
|
||||
sqlxConn := sqlx.NewSqlConn(getSQLDriverName(c.Database.Driver), getSQLDataSource(c.Database))
|
||||
|
||||
|
||||
return &ServiceContext{
|
||||
Config: c,
|
||||
DB: db,
|
||||
Config: c,
|
||||
DB: db,
|
||||
UserModel: model.NewUserModel(sqlxConn),
|
||||
PhotoModel: model.NewPhotoModel(sqlxConn),
|
||||
CategoryModel: model.NewCategoryModel(sqlxConn),
|
||||
|
||||
@ -4,25 +4,25 @@ const (
|
||||
// 用户状态
|
||||
UserStatusActive = 1
|
||||
UserStatusInactive = 0
|
||||
|
||||
|
||||
// 文件上传
|
||||
MaxFileSize = 10 << 20 // 10MB
|
||||
|
||||
|
||||
// 图片类型
|
||||
ImageTypeJPEG = "image/jpeg"
|
||||
ImageTypePNG = "image/png"
|
||||
ImageTypeGIF = "image/gif"
|
||||
ImageTypeWEBP = "image/webp"
|
||||
|
||||
|
||||
// 缩略图尺寸
|
||||
ThumbnailWidth = 300
|
||||
ThumbnailHeight = 300
|
||||
|
||||
|
||||
// JWT 过期时间
|
||||
TokenExpireDuration = 24 * 60 * 60 // 24小时
|
||||
|
||||
|
||||
// 分页默认值
|
||||
DefaultPage = 1
|
||||
DefaultPageSize = 10
|
||||
MaxPageSize = 100
|
||||
)
|
||||
)
|
||||
|
||||
@ -7,14 +7,14 @@ import (
|
||||
|
||||
const (
|
||||
// 通用错误代码
|
||||
Success = 0
|
||||
ServerError = 500
|
||||
ParamError = 400
|
||||
AuthError = 401
|
||||
NotFound = 404
|
||||
Forbidden = 403
|
||||
InvalidParameter = 400 // 与 ParamError 统一
|
||||
|
||||
Success = 0
|
||||
ServerError = 500
|
||||
ParamError = 400
|
||||
AuthError = 401
|
||||
NotFound = 404
|
||||
Forbidden = 403
|
||||
InvalidParameter = 400 // 与 ParamError 统一
|
||||
|
||||
// 业务错误代码
|
||||
UserNotFound = 1001
|
||||
UserExists = 1002
|
||||
@ -22,32 +22,32 @@ const (
|
||||
InvalidPassword = 1004
|
||||
TokenExpired = 1005
|
||||
TokenInvalid = 1006
|
||||
|
||||
|
||||
PhotoNotFound = 2001
|
||||
PhotoUploadFail = 2002
|
||||
|
||||
|
||||
CategoryNotFound = 3001
|
||||
CategoryExists = 3002
|
||||
)
|
||||
|
||||
var codeText = map[int]string{
|
||||
Success: "Success",
|
||||
ServerError: "Server Error",
|
||||
ParamError: "Parameter Error", // ParamError 和 InvalidParameter 都映射到这里
|
||||
AuthError: "Authentication Error",
|
||||
NotFound: "Not Found",
|
||||
Forbidden: "Forbidden",
|
||||
|
||||
Success: "Success",
|
||||
ServerError: "Server Error",
|
||||
ParamError: "Parameter Error", // ParamError 和 InvalidParameter 都映射到这里
|
||||
AuthError: "Authentication Error",
|
||||
NotFound: "Not Found",
|
||||
Forbidden: "Forbidden",
|
||||
|
||||
UserNotFound: "User Not Found",
|
||||
UserExists: "User Already Exists",
|
||||
UserDisabled: "User Disabled",
|
||||
InvalidPassword: "Invalid Password",
|
||||
TokenExpired: "Token Expired",
|
||||
TokenInvalid: "Token Invalid",
|
||||
|
||||
|
||||
PhotoNotFound: "Photo Not Found",
|
||||
PhotoUploadFail: "Photo Upload Failed",
|
||||
|
||||
|
||||
CategoryNotFound: "Category Not Found",
|
||||
CategoryExists: "Category Already Exists",
|
||||
}
|
||||
@ -83,7 +83,7 @@ func GetHttpStatus(code int) int {
|
||||
switch code {
|
||||
case Success:
|
||||
return http.StatusOK
|
||||
case ParamError: // ParamError 和 InvalidParameter 都是 400,所以只需要一个 case
|
||||
case ParamError: // ParamError 和 InvalidParameter 都是 400,所以只需要一个 case
|
||||
return http.StatusBadRequest
|
||||
case AuthError, TokenExpired, TokenInvalid:
|
||||
return http.StatusUnauthorized
|
||||
@ -96,4 +96,4 @@ func GetHttpStatus(code int) int {
|
||||
default:
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -21,10 +21,10 @@ type Migration struct {
|
||||
|
||||
// MigrationRecord 数据库中的迁移记录
|
||||
type MigrationRecord struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
Version string `gorm:"uniqueIndex;size:255;not null"`
|
||||
Description string `gorm:"size:500"`
|
||||
Applied bool `gorm:"default:false"`
|
||||
ID uint `gorm:"primaryKey"`
|
||||
Version string `gorm:"uniqueIndex;size:255;not null"`
|
||||
Description string `gorm:"size:500"`
|
||||
Applied bool `gorm:"default:false"`
|
||||
AppliedAt time.Time
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
@ -66,7 +66,7 @@ func (m *Migrator) GetAppliedMigrations() ([]string, error) {
|
||||
if err := m.db.Where("applied = ?", true).Order("version ASC").Find(&records).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
versions := make([]string, len(records))
|
||||
for i, record := range records {
|
||||
versions[i] = record.Version
|
||||
@ -80,24 +80,24 @@ func (m *Migrator) GetPendingMigrations() ([]Migration, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
appliedMap := make(map[string]bool)
|
||||
for _, version := range appliedVersions {
|
||||
appliedMap[version] = true
|
||||
}
|
||||
|
||||
|
||||
var pendingMigrations []Migration
|
||||
for _, migration := range m.migrations {
|
||||
if !appliedMap[migration.Version] {
|
||||
pendingMigrations = append(pendingMigrations, migration)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 按版本号排序
|
||||
sort.Slice(pendingMigrations, func(i, j int) bool {
|
||||
return pendingMigrations[i].Version < pendingMigrations[j].Version
|
||||
})
|
||||
|
||||
|
||||
return pendingMigrations, nil
|
||||
}
|
||||
|
||||
@ -106,39 +106,39 @@ func (m *Migrator) Up() error {
|
||||
if err := m.initMigrationTable(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
pendingMigrations, err := m.GetPendingMigrations()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
if len(pendingMigrations) == 0 {
|
||||
log.Println("No pending migrations")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
for _, migration := range pendingMigrations {
|
||||
log.Printf("Applying migration %s: %s", migration.Version, migration.Description)
|
||||
|
||||
|
||||
// 开始事务
|
||||
tx := m.db.Begin()
|
||||
if tx.Error != nil {
|
||||
return tx.Error
|
||||
}
|
||||
|
||||
|
||||
// 执行迁移SQL
|
||||
if err := m.executeSQL(tx, migration.UpSQL); err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("failed to apply migration %s: %v", migration.Version, err)
|
||||
}
|
||||
|
||||
|
||||
// 记录迁移状态 (使用UPSERT)
|
||||
now := time.Now()
|
||||
|
||||
|
||||
// 检查记录是否已存在
|
||||
var existingRecord MigrationRecord
|
||||
err := tx.Where("version = ?", migration.Version).First(&existingRecord).Error
|
||||
|
||||
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
// 创建新记录
|
||||
record := MigrationRecord{
|
||||
@ -167,15 +167,15 @@ func (m *Migrator) Up() error {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("failed to check migration record %s: %v", migration.Version, err)
|
||||
}
|
||||
|
||||
|
||||
// 提交事务
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return fmt.Errorf("failed to commit migration %s: %v", migration.Version, err)
|
||||
}
|
||||
|
||||
|
||||
log.Printf("Successfully applied migration %s", migration.Version)
|
||||
}
|
||||
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -184,58 +184,58 @@ func (m *Migrator) Down(steps int) error {
|
||||
if err := m.initMigrationTable(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
appliedVersions, err := m.GetAppliedMigrations()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
if len(appliedVersions) == 0 {
|
||||
log.Println("No applied migrations to rollback")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
// 获取要回滚的迁移(从最新开始)
|
||||
rollbackCount := steps
|
||||
if rollbackCount > len(appliedVersions) {
|
||||
rollbackCount = len(appliedVersions)
|
||||
}
|
||||
|
||||
|
||||
for i := len(appliedVersions) - 1; i >= len(appliedVersions)-rollbackCount; i-- {
|
||||
version := appliedVersions[i]
|
||||
migration := m.findMigrationByVersion(version)
|
||||
if migration == nil {
|
||||
return fmt.Errorf("migration %s not found in migration definitions", version)
|
||||
}
|
||||
|
||||
|
||||
log.Printf("Rolling back migration %s: %s", migration.Version, migration.Description)
|
||||
|
||||
|
||||
// 开始事务
|
||||
tx := m.db.Begin()
|
||||
if tx.Error != nil {
|
||||
return tx.Error
|
||||
}
|
||||
|
||||
|
||||
// 执行回滚SQL
|
||||
if err := m.executeSQL(tx, migration.DownSQL); err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("failed to rollback migration %s: %v", migration.Version, err)
|
||||
}
|
||||
|
||||
|
||||
// 更新迁移状态
|
||||
if err := tx.Model(&MigrationRecord{}).Where("version = ?", version).Update("applied", false).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("failed to update migration record %s: %v", migration.Version, err)
|
||||
}
|
||||
|
||||
|
||||
// 提交事务
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return fmt.Errorf("failed to commit rollback %s: %v", migration.Version, err)
|
||||
}
|
||||
|
||||
|
||||
log.Printf("Successfully rolled back migration %s", migration.Version)
|
||||
}
|
||||
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -244,27 +244,27 @@ func (m *Migrator) Status() error {
|
||||
if err := m.initMigrationTable(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
appliedVersions, err := m.GetAppliedMigrations()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
appliedMap := make(map[string]bool)
|
||||
for _, version := range appliedVersions {
|
||||
appliedMap[version] = true
|
||||
}
|
||||
|
||||
|
||||
// 排序所有迁移
|
||||
allMigrations := m.migrations
|
||||
sort.Slice(allMigrations, func(i, j int) bool {
|
||||
return allMigrations[i].Version < allMigrations[j].Version
|
||||
})
|
||||
|
||||
|
||||
fmt.Println("Migration Status:")
|
||||
fmt.Println("Version | Status | Description")
|
||||
fmt.Println("---------------|---------|----------------------------------")
|
||||
|
||||
|
||||
for _, migration := range allMigrations {
|
||||
status := "Pending"
|
||||
if appliedMap[migration.Version] {
|
||||
@ -272,7 +272,7 @@ func (m *Migrator) Status() error {
|
||||
}
|
||||
fmt.Printf("%-14s | %-7s | %s\n", migration.Version, status, migration.Description)
|
||||
}
|
||||
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -280,18 +280,18 @@ func (m *Migrator) Status() error {
|
||||
func (m *Migrator) executeSQL(tx *gorm.DB, sqlStr string) error {
|
||||
// 分割SQL语句(按分号分割)
|
||||
statements := strings.Split(sqlStr, ";")
|
||||
|
||||
|
||||
for _, statement := range statements {
|
||||
statement = strings.TrimSpace(statement)
|
||||
if statement == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
if err := tx.Exec(statement).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -308,25 +308,25 @@ func (m *Migrator) findMigrationByVersion(version string) *Migration {
|
||||
// Reset 重置数据库(谨慎使用)
|
||||
func (m *Migrator) Reset() error {
|
||||
log.Println("WARNING: This will drop all tables and reset the database!")
|
||||
|
||||
|
||||
// 获取所有应用的迁移
|
||||
appliedVersions, err := m.GetAppliedMigrations()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
// 回滚所有迁移
|
||||
if len(appliedVersions) > 0 {
|
||||
if err := m.Down(len(appliedVersions)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 删除迁移表
|
||||
if err := m.db.Migrator().DropTable(&MigrationRecord{}); err != nil {
|
||||
return fmt.Errorf("failed to drop migration table: %v", err)
|
||||
}
|
||||
|
||||
|
||||
log.Println("Database reset completed")
|
||||
return nil
|
||||
}
|
||||
@ -337,25 +337,25 @@ func (m *Migrator) Migrate(steps int) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
if len(pendingMigrations) == 0 {
|
||||
log.Println("No pending migrations")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
migrateCount := steps
|
||||
if steps <= 0 || steps > len(pendingMigrations) {
|
||||
migrateCount = len(pendingMigrations)
|
||||
}
|
||||
|
||||
|
||||
// 临时修改migrations列表,只包含要执行的迁移
|
||||
originalMigrations := m.migrations
|
||||
m.migrations = pendingMigrations[:migrateCount]
|
||||
|
||||
|
||||
err = m.Up()
|
||||
|
||||
|
||||
// 恢复原始migrations列表
|
||||
m.migrations = originalMigrations
|
||||
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@ -309,13 +309,13 @@ func GetLatestMigrationVersion() string {
|
||||
if len(migrations) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
|
||||
latest := migrations[0]
|
||||
for _, migration := range migrations {
|
||||
if migration.Version > latest.Version {
|
||||
latest = migration
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return latest.Version
|
||||
}
|
||||
}
|
||||
|
||||
@ -2,7 +2,7 @@ package response
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
|
||||
"github.com/zeromicro/go-zero/rest/httpx"
|
||||
"photography-backend/pkg/errorx"
|
||||
)
|
||||
@ -39,4 +39,4 @@ func Success(w http.ResponseWriter, data interface{}) {
|
||||
|
||||
func Error(w http.ResponseWriter, err error) {
|
||||
Response(w, nil, err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -13,7 +13,7 @@ import (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Driver string `json:"driver"` // mysql, postgres, sqlite
|
||||
Driver string `json:"driver"` // mysql, postgres, sqlite
|
||||
Host string `json:"host,optional"`
|
||||
Port int `json:"port,optional"`
|
||||
Username string `json:"username,optional"`
|
||||
@ -27,7 +27,7 @@ type Config struct {
|
||||
func NewDB(config Config) (*gorm.DB, error) {
|
||||
var db *gorm.DB
|
||||
var err error
|
||||
|
||||
|
||||
// 配置日志
|
||||
newLogger := logger.New(
|
||||
log.New(os.Stdout, "\r\n", log.LstdFlags), // io writer
|
||||
@ -37,7 +37,7 @@ func NewDB(config Config) (*gorm.DB, error) {
|
||||
Colorful: false, // 禁用彩色打印
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
switch config.Driver {
|
||||
case "mysql":
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=True&loc=Local",
|
||||
@ -58,20 +58,20 @@ func NewDB(config Config) (*gorm.DB, error) {
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported database driver: %s", config.Driver)
|
||||
}
|
||||
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
// 设置连接池
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
sqlDB.SetMaxIdleConns(10)
|
||||
sqlDB.SetMaxOpenConns(100)
|
||||
sqlDB.SetConnMaxLifetime(time.Hour)
|
||||
|
||||
|
||||
return db, nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -11,7 +11,7 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
||||
"github.com/disintegration/imaging"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@ -64,39 +64,39 @@ func SaveFile(file multipart.File, header *multipart.FileHeader, config Config)
|
||||
if header.Size > config.MaxSize {
|
||||
return nil, fmt.Errorf("文件大小超过限制: %d bytes", config.MaxSize)
|
||||
}
|
||||
|
||||
|
||||
// 检查文件类型
|
||||
contentType := header.Header.Get("Content-Type")
|
||||
if !IsAllowedType(contentType, config.AllowedTypes) {
|
||||
return nil, fmt.Errorf("不支持的文件类型: %s", contentType)
|
||||
}
|
||||
|
||||
|
||||
// 生成文件名
|
||||
fileName := GenerateFileName(header.Filename)
|
||||
|
||||
|
||||
// 创建上传目录
|
||||
uploadPath := filepath.Join(config.UploadDir, "photos")
|
||||
if err := os.MkdirAll(uploadPath, 0755); err != nil {
|
||||
return nil, fmt.Errorf("创建上传目录失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// 完整文件路径
|
||||
filePath := filepath.Join(uploadPath, fileName)
|
||||
|
||||
|
||||
// 创建目标文件
|
||||
dst, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建文件失败: %v", err)
|
||||
}
|
||||
defer dst.Close()
|
||||
|
||||
|
||||
// 复制文件内容
|
||||
file.Seek(0, 0) // 重置文件指针
|
||||
_, err = io.Copy(dst, file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("保存文件失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// 返回文件信息
|
||||
return &FileInfo{
|
||||
OriginalName: header.Filename,
|
||||
@ -115,35 +115,35 @@ func CreateThumbnail(originalPath string, config Config) (*FileInfo, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("打开图片失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// 生成缩略图文件名
|
||||
ext := filepath.Ext(originalPath)
|
||||
baseName := strings.TrimSuffix(filepath.Base(originalPath), ext)
|
||||
thumbnailName := baseName + "_thumb" + ext
|
||||
|
||||
|
||||
// 创建缩略图目录
|
||||
thumbnailDir := filepath.Join(config.UploadDir, "thumbnails")
|
||||
if err := os.MkdirAll(thumbnailDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("创建缩略图目录失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
thumbnailPath := filepath.Join(thumbnailDir, thumbnailName)
|
||||
|
||||
|
||||
// 调整图片大小 (最大宽度 300px,保持比例)
|
||||
thumbnail := imaging.Resize(src, 300, 0, imaging.Lanczos)
|
||||
|
||||
|
||||
// 保存缩略图
|
||||
err = imaging.Save(thumbnail, thumbnailPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("保存缩略图失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// 获取文件大小
|
||||
stat, err := os.Stat(thumbnailPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取缩略图信息失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
return &FileInfo{
|
||||
OriginalName: thumbnailName,
|
||||
FileName: thumbnailName,
|
||||
@ -161,13 +161,13 @@ func UploadPhoto(file multipart.File, header *multipart.FileHeader, config Confi
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
// 创建缩略图
|
||||
thumbnail, err := CreateThumbnail(original.FilePath, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
return &UploadResult{
|
||||
Original: *original,
|
||||
Thumbnail: *thumbnail,
|
||||
@ -179,11 +179,11 @@ func DeleteFile(filePath string) error {
|
||||
if filePath == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
return nil // 文件不存在,认为删除成功
|
||||
}
|
||||
|
||||
|
||||
return os.Remove(filePath)
|
||||
}
|
||||
|
||||
@ -194,12 +194,12 @@ func GetImageDimensions(filePath string) (width, height int, err error) {
|
||||
return 0, 0, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
|
||||
img, _, err := image.DecodeConfig(file)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
|
||||
return img.Width, img.Height, nil
|
||||
}
|
||||
|
||||
@ -210,16 +210,16 @@ func ResizeImage(srcPath, destPath string, width, height int) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开图片失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// 调整图片尺寸 (正方形裁剪)
|
||||
resized := imaging.Fill(src, width, height, imaging.Center, imaging.Lanczos)
|
||||
|
||||
|
||||
// 保存调整后的图片
|
||||
err = imaging.Save(resized, destPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("保存调整后的图片失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -230,15 +230,15 @@ func CreateAvatar(srcPath, destPath string, size int) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开图片失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// 创建正方形头像 (居中裁剪)
|
||||
avatar := imaging.Fill(src, size, size, imaging.Center, imaging.Lanczos)
|
||||
|
||||
|
||||
// 保存为JPEG格式 (压缩优化)
|
||||
err = imaging.Save(avatar, destPath, imaging.JPEGQuality(85))
|
||||
if err != nil {
|
||||
return fmt.Errorf("保存头像失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -31,4 +31,4 @@ func SHA256(str string) string {
|
||||
h := sha256.New()
|
||||
h.Write([]byte(str))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
}
|
||||
|
||||
@ -2,7 +2,7 @@ package jwt
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
@ -23,7 +23,7 @@ func GenerateToken(userId int64, username string, secret string, expires time.Du
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(secret))
|
||||
}
|
||||
@ -32,14 +32,14 @@ func ParseToken(tokenString string, secret string) (*Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(secret), nil
|
||||
})
|
||||
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
|
||||
return nil, jwt.ErrInvalidKey
|
||||
}
|
||||
}
|
||||
|
||||
@ -23,12 +23,12 @@ import (
|
||||
// IntegrationTestSuite 集成测试套件
|
||||
type IntegrationTestSuite struct {
|
||||
suite.Suite
|
||||
svcCtx *svc.ServiceContext
|
||||
cfg config.Config
|
||||
db *gorm.DB
|
||||
authToken string
|
||||
userID int64
|
||||
photoID int64
|
||||
svcCtx *svc.ServiceContext
|
||||
cfg config.Config
|
||||
db *gorm.DB
|
||||
authToken string
|
||||
userID int64
|
||||
photoID int64
|
||||
categoryID int64
|
||||
}
|
||||
|
||||
@ -37,24 +37,24 @@ func (suite *IntegrationTestSuite) SetupSuite() {
|
||||
// 加载配置
|
||||
var cfg config.Config
|
||||
conf.MustLoad("../etc/photography-api.yaml", &cfg)
|
||||
|
||||
|
||||
// 使用内存数据库
|
||||
cfg.Database.Driver = "sqlite"
|
||||
cfg.Database.FilePath = ":memory:"
|
||||
|
||||
|
||||
// 创建数据库连接
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
suite.Require().NoError(err)
|
||||
|
||||
|
||||
suite.db = db
|
||||
suite.cfg = cfg
|
||||
|
||||
|
||||
// 创建服务上下文
|
||||
suite.svcCtx = svc.NewServiceContext(cfg)
|
||||
|
||||
|
||||
// 初始化数据库表
|
||||
suite.initDatabase()
|
||||
|
||||
|
||||
// 创建测试数据
|
||||
suite.seedTestData()
|
||||
}
|
||||
@ -71,7 +71,7 @@ func (suite *IntegrationTestSuite) TearDownSuite() {
|
||||
func (suite *IntegrationTestSuite) initDatabase() {
|
||||
// 这里应该运行迁移或创建表
|
||||
// 简化示例,实际应该使用迁移系统
|
||||
|
||||
|
||||
// 创建用户表
|
||||
err := suite.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
@ -86,7 +86,7 @@ func (suite *IntegrationTestSuite) initDatabase() {
|
||||
)
|
||||
`).Error
|
||||
suite.Require().NoError(err)
|
||||
|
||||
|
||||
// 创建分类表
|
||||
err = suite.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS categories (
|
||||
@ -101,7 +101,7 @@ func (suite *IntegrationTestSuite) initDatabase() {
|
||||
)
|
||||
`).Error
|
||||
suite.Require().NoError(err)
|
||||
|
||||
|
||||
// 创建照片表
|
||||
err = suite.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS photos (
|
||||
@ -130,7 +130,7 @@ func (suite *IntegrationTestSuite) seedTestData() {
|
||||
VALUES ('testuser', 'test@example.com', '$2a$10$92IXUNpkjO0rOQ5byMi.Ye4oKoEa3Ro9llC/.og/at2.uheWG/igi', 1)
|
||||
`).Error
|
||||
suite.Require().NoError(err)
|
||||
|
||||
|
||||
// 获取用户ID
|
||||
var user struct {
|
||||
ID int64 `gorm:"column:id"`
|
||||
@ -138,14 +138,14 @@ func (suite *IntegrationTestSuite) seedTestData() {
|
||||
err = suite.db.Table("users").Where("username = ?", "testuser").First(&user).Error
|
||||
suite.Require().NoError(err)
|
||||
suite.userID = user.ID
|
||||
|
||||
|
||||
// 创建测试分类
|
||||
err = suite.db.Exec(`
|
||||
INSERT INTO categories (name, description, is_active)
|
||||
VALUES ('测试分类', '这是一个测试分类', 1)
|
||||
`).Error
|
||||
suite.Require().NoError(err)
|
||||
|
||||
|
||||
// 获取分类ID
|
||||
var category struct {
|
||||
ID int64 `gorm:"column:id"`
|
||||
@ -159,19 +159,19 @@ func (suite *IntegrationTestSuite) seedTestData() {
|
||||
func (suite *IntegrationTestSuite) TestCompleteWorkflow() {
|
||||
// 1. 用户注册
|
||||
suite.testUserRegistration()
|
||||
|
||||
|
||||
// 2. 用户登录
|
||||
suite.testUserLogin()
|
||||
|
||||
|
||||
// 3. 分类管理
|
||||
suite.testCategoryManagement()
|
||||
|
||||
|
||||
// 4. 照片管理
|
||||
suite.testPhotoManagement()
|
||||
|
||||
|
||||
// 5. 权限验证
|
||||
suite.testPermissionValidation()
|
||||
|
||||
|
||||
// 6. 数据关联性测试
|
||||
suite.testDataRelationships()
|
||||
}
|
||||
@ -184,21 +184,21 @@ func (suite *IntegrationTestSuite) testUserRegistration() {
|
||||
"email": "newuser@example.com",
|
||||
"password": "newpassword123",
|
||||
}
|
||||
|
||||
|
||||
resp := suite.makeRequest("POST", "/api/v1/auth/register", registerData, "")
|
||||
suite.Equal(200, resp.Code)
|
||||
|
||||
|
||||
// 测试重复用户名
|
||||
resp = suite.makeRequest("POST", "/api/v1/auth/register", registerData, "")
|
||||
suite.NotEqual(200, resp.Code)
|
||||
|
||||
|
||||
// 测试无效邮箱
|
||||
invalidData := map[string]interface{}{
|
||||
"username": "testuser2",
|
||||
"email": "invalid-email",
|
||||
"password": "password123",
|
||||
}
|
||||
|
||||
|
||||
resp = suite.makeRequest("POST", "/api/v1/auth/register", invalidData, "")
|
||||
suite.NotEqual(200, resp.Code)
|
||||
}
|
||||
@ -210,23 +210,23 @@ func (suite *IntegrationTestSuite) testUserLogin() {
|
||||
"username": "testuser",
|
||||
"password": "password",
|
||||
}
|
||||
|
||||
|
||||
resp := suite.makeRequest("POST", "/api/v1/auth/login", loginData, "")
|
||||
suite.Equal(200, resp.Code)
|
||||
|
||||
|
||||
var loginResp types.LoginResponse
|
||||
err := json.Unmarshal(resp.Body, &loginResp)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
|
||||
suite.authToken = loginResp.Data.Token
|
||||
suite.NotEmpty(suite.authToken)
|
||||
|
||||
|
||||
// 测试无效凭证
|
||||
invalidLogin := map[string]interface{}{
|
||||
"username": "testuser",
|
||||
"password": "wrongpassword",
|
||||
}
|
||||
|
||||
|
||||
resp = suite.makeRequest("POST", "/api/v1/auth/login", invalidLogin, "")
|
||||
suite.NotEqual(200, resp.Code)
|
||||
}
|
||||
@ -239,35 +239,35 @@ func (suite *IntegrationTestSuite) testCategoryManagement() {
|
||||
"description": "这是一个新的分类",
|
||||
"parent_id": suite.categoryID,
|
||||
}
|
||||
|
||||
|
||||
resp := suite.makeRequest("POST", "/api/v1/categories", categoryData, suite.authToken)
|
||||
suite.Equal(200, resp.Code)
|
||||
|
||||
|
||||
var createResp types.CreateCategoryResponse
|
||||
err := json.Unmarshal(resp.Body, &createResp)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
newCategoryID := createResp.Data.ID
|
||||
|
||||
|
||||
newCategoryID := createResp.Data.Id
|
||||
|
||||
// 测试获取分类列表
|
||||
resp = suite.makeRequest("GET", "/api/v1/categories", nil, suite.authToken)
|
||||
suite.Equal(200, resp.Code)
|
||||
|
||||
|
||||
var listResp types.GetCategoryListResponse
|
||||
err = json.Unmarshal(resp.Body, &listResp)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
suite.GreaterOrEqual(len(listResp.Data), 2)
|
||||
|
||||
|
||||
suite.GreaterOrEqual(len(listResp.Data.Categories), 2)
|
||||
|
||||
// 测试更新分类
|
||||
updateData := map[string]interface{}{
|
||||
"name": "更新的分类",
|
||||
"description": "这是一个更新的分类",
|
||||
}
|
||||
|
||||
|
||||
resp = suite.makeRequest("PUT", fmt.Sprintf("/api/v1/categories/%d", newCategoryID), updateData, suite.authToken)
|
||||
suite.Equal(200, resp.Code)
|
||||
|
||||
|
||||
// 测试删除分类
|
||||
resp = suite.makeRequest("DELETE", fmt.Sprintf("/api/v1/categories/%d", newCategoryID), nil, suite.authToken)
|
||||
suite.Equal(200, resp.Code)
|
||||
@ -276,20 +276,20 @@ func (suite *IntegrationTestSuite) testCategoryManagement() {
|
||||
// testPhotoManagement 照片管理测试
|
||||
func (suite *IntegrationTestSuite) testPhotoManagement() {
|
||||
// 测试创建照片记录(简化版,不包含实际文件上传)
|
||||
photoData := map[string]interface{}{
|
||||
_ = map[string]interface{}{
|
||||
"title": "测试照片",
|
||||
"description": "这是一个测试照片",
|
||||
"file_path": "/uploads/test.jpg",
|
||||
"category_id": suite.categoryID,
|
||||
}
|
||||
|
||||
|
||||
// 这里应该测试实际的文件上传,简化为直接插入数据库
|
||||
err := suite.db.Exec(`
|
||||
INSERT INTO photos (title, description, file_path, category_id, user_id)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
`, "测试照片", "这是一个测试照片", "/uploads/test.jpg", suite.categoryID, suite.userID).Error
|
||||
suite.Require().NoError(err)
|
||||
|
||||
|
||||
// 获取照片ID
|
||||
var photo struct {
|
||||
ID int64 `gorm:"column:id"`
|
||||
@ -297,27 +297,27 @@ func (suite *IntegrationTestSuite) testPhotoManagement() {
|
||||
err = suite.db.Table("photos").Where("title = ?", "测试照片").First(&photo).Error
|
||||
suite.Require().NoError(err)
|
||||
suite.photoID = photo.ID
|
||||
|
||||
|
||||
// 测试获取照片列表
|
||||
resp := suite.makeRequest("GET", "/api/v1/photos", nil, suite.authToken)
|
||||
suite.Equal(200, resp.Code)
|
||||
|
||||
|
||||
var listResp types.GetPhotoListResponse
|
||||
err = json.Unmarshal(resp.Body, &listResp)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
suite.GreaterOrEqual(len(listResp.Data), 1)
|
||||
|
||||
|
||||
suite.GreaterOrEqual(len(listResp.Data.Photos), 1)
|
||||
|
||||
// 测试获取照片详情
|
||||
resp = suite.makeRequest("GET", fmt.Sprintf("/api/v1/photos/%d", suite.photoID), nil, suite.authToken)
|
||||
suite.Equal(200, resp.Code)
|
||||
|
||||
|
||||
// 测试更新照片
|
||||
updateData := map[string]interface{}{
|
||||
"title": "更新的照片",
|
||||
"description": "这是一个更新的照片",
|
||||
}
|
||||
|
||||
|
||||
resp = suite.makeRequest("PUT", fmt.Sprintf("/api/v1/photos/%d", suite.photoID), updateData, suite.authToken)
|
||||
suite.Equal(200, resp.Code)
|
||||
}
|
||||
@ -327,11 +327,11 @@ func (suite *IntegrationTestSuite) testPermissionValidation() {
|
||||
// 测试未认证访问
|
||||
resp := suite.makeRequest("GET", "/api/v1/photos", nil, "")
|
||||
suite.Equal(401, resp.Code)
|
||||
|
||||
|
||||
// 测试无效token
|
||||
resp = suite.makeRequest("GET", "/api/v1/photos", nil, "invalid_token")
|
||||
suite.Equal(401, resp.Code)
|
||||
|
||||
|
||||
// 测试权限不足(尝试访问其他用户的照片)
|
||||
// 这里需要创建另一个用户的照片进行测试
|
||||
}
|
||||
@ -343,12 +343,12 @@ func (suite *IntegrationTestSuite) testDataRelationships() {
|
||||
err := suite.db.Table("photos").Where("category_id = ?", suite.categoryID).Count(&count).Error
|
||||
suite.Require().NoError(err)
|
||||
suite.GreaterOrEqual(count, int64(1))
|
||||
|
||||
|
||||
// 测试用户与照片的关联
|
||||
err = suite.db.Table("photos").Where("user_id = ?", suite.userID).Count(&count).Error
|
||||
suite.Require().NoError(err)
|
||||
suite.GreaterOrEqual(count, int64(1))
|
||||
|
||||
|
||||
// 测试级联删除(如果删除分类,照片的category_id应该被处理)
|
||||
// 这里可以测试数据库约束和业务逻辑
|
||||
}
|
||||
@ -357,20 +357,20 @@ func (suite *IntegrationTestSuite) testDataRelationships() {
|
||||
func (suite *IntegrationTestSuite) makeRequest(method, path string, data interface{}, token string) *TestResponse {
|
||||
var body []byte
|
||||
var err error
|
||||
|
||||
|
||||
if data != nil {
|
||||
body, err = json.Marshal(data)
|
||||
suite.Require().NoError(err)
|
||||
}
|
||||
|
||||
|
||||
req, err := http.NewRequest(method, path, bytes.NewReader(body))
|
||||
suite.Require().NoError(err)
|
||||
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
|
||||
|
||||
// 这里应该实际调用API服务
|
||||
// 简化示例,返回模拟响应
|
||||
return &TestResponse{
|
||||
@ -388,33 +388,33 @@ type TestResponse struct {
|
||||
// TestDataConsistency 数据一致性测试
|
||||
func (suite *IntegrationTestSuite) TestDataConsistency() {
|
||||
ctx := context.Background()
|
||||
|
||||
|
||||
// 测试事务操作
|
||||
tx := suite.db.Begin()
|
||||
|
||||
|
||||
// 创建用户
|
||||
err := tx.Exec(`
|
||||
INSERT INTO users (username, email, password_hash)
|
||||
VALUES (?, ?, ?)
|
||||
`, "txuser", "txuser@example.com", "hashedpassword").Error
|
||||
suite.Require().NoError(err)
|
||||
|
||||
|
||||
// 创建分类
|
||||
err = tx.Exec(`
|
||||
INSERT INTO categories (name, description)
|
||||
VALUES (?, ?)
|
||||
`, "事务分类", "事务测试分类").Error
|
||||
suite.Require().NoError(err)
|
||||
|
||||
|
||||
// 回滚事务
|
||||
tx.Rollback()
|
||||
|
||||
|
||||
// 验证数据未被插入
|
||||
var userCount int64
|
||||
err = suite.db.WithContext(ctx).Table("users").Where("username = ?", "txuser").Count(&userCount).Error
|
||||
suite.Require().NoError(err)
|
||||
suite.Equal(int64(0), userCount)
|
||||
|
||||
|
||||
var categoryCount int64
|
||||
err = suite.db.WithContext(ctx).Table("categories").Where("name = ?", "事务分类").Count(&categoryCount).Error
|
||||
suite.Require().NoError(err)
|
||||
@ -425,26 +425,26 @@ func (suite *IntegrationTestSuite) TestDataConsistency() {
|
||||
func (suite *IntegrationTestSuite) TestConcurrentOperations() {
|
||||
concurrency := 10
|
||||
done := make(chan bool, concurrency)
|
||||
|
||||
|
||||
// 并发创建分类
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func(index int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
|
||||
err := suite.db.Exec(`
|
||||
INSERT INTO categories (name, description)
|
||||
VALUES (?, ?)
|
||||
`, fmt.Sprintf("并发分类_%d", index), fmt.Sprintf("并发测试分类_%d", index)).Error
|
||||
|
||||
|
||||
suite.Require().NoError(err)
|
||||
}(i)
|
||||
}
|
||||
|
||||
|
||||
// 等待所有操作完成
|
||||
for i := 0; i < concurrency; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
|
||||
// 验证数据一致性
|
||||
var count int64
|
||||
err := suite.db.Table("categories").Where("name LIKE ?", "并发分类_%").Count(&count).Error
|
||||
@ -456,13 +456,13 @@ func (suite *IntegrationTestSuite) TestConcurrentOperations() {
|
||||
func (suite *IntegrationTestSuite) TestCacheOperations() {
|
||||
// 如果项目使用Redis缓存,这里测试缓存操作
|
||||
// 简化示例,测试内存缓存
|
||||
|
||||
|
||||
cache := make(map[string]interface{})
|
||||
|
||||
|
||||
// 测试缓存设置
|
||||
cache["test_key"] = "test_value"
|
||||
suite.Equal("test_value", cache["test_key"])
|
||||
|
||||
|
||||
// 测试缓存过期(简化)
|
||||
delete(cache, "test_key")
|
||||
_, exists := cache["test_key"]
|
||||
@ -473,9 +473,9 @@ func (suite *IntegrationTestSuite) TestCacheOperations() {
|
||||
func (suite *IntegrationTestSuite) TestPerformanceWithLoad() {
|
||||
// 创建大量测试数据
|
||||
batchSize := 1000
|
||||
|
||||
|
||||
start := time.Now()
|
||||
|
||||
|
||||
for i := 0; i < batchSize; i++ {
|
||||
err := suite.db.Exec(`
|
||||
INSERT INTO categories (name, description)
|
||||
@ -483,28 +483,28 @@ func (suite *IntegrationTestSuite) TestPerformanceWithLoad() {
|
||||
`, fmt.Sprintf("性能测试分类_%d", i), fmt.Sprintf("性能测试描述_%d", i)).Error
|
||||
suite.Require().NoError(err)
|
||||
}
|
||||
|
||||
|
||||
insertDuration := time.Since(start)
|
||||
|
||||
|
||||
// 测试查询性能
|
||||
start = time.Now()
|
||||
|
||||
|
||||
var categories []struct {
|
||||
ID int64 `gorm:"column:id"`
|
||||
Name string `gorm:"column:name"`
|
||||
}
|
||||
|
||||
|
||||
err := suite.db.Table("categories").Where("name LIKE ?", "性能测试分类_%").Find(&categories).Error
|
||||
suite.Require().NoError(err)
|
||||
|
||||
|
||||
queryDuration := time.Since(start)
|
||||
|
||||
|
||||
suite.Equal(batchSize, len(categories))
|
||||
|
||||
|
||||
// 记录性能指标
|
||||
suite.T().Logf("Insert %d records took: %v", batchSize, insertDuration)
|
||||
suite.T().Logf("Query %d records took: %v", batchSize, queryDuration)
|
||||
|
||||
|
||||
// 性能断言
|
||||
suite.Less(insertDuration, 5*time.Second)
|
||||
suite.Less(queryDuration, 1*time.Second)
|
||||
@ -514,23 +514,23 @@ func (suite *IntegrationTestSuite) TestPerformanceWithLoad() {
|
||||
func (suite *IntegrationTestSuite) TestErrorRecovery() {
|
||||
// 测试数据库连接错误恢复
|
||||
// 这里应该测试数据库连接中断后的恢复机制
|
||||
|
||||
|
||||
// 测试事务失败恢复
|
||||
tx := suite.db.Begin()
|
||||
|
||||
|
||||
// 故意创建一个会失败的操作
|
||||
err := tx.Exec(`
|
||||
INSERT INTO users (username, email, password_hash)
|
||||
VALUES (?, ?, ?)
|
||||
`, "testuser", "test@example.com", "password").Error // 重复的用户名
|
||||
|
||||
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
suite.T().Log("Transaction properly rolled back after error")
|
||||
} else {
|
||||
tx.Commit()
|
||||
}
|
||||
|
||||
|
||||
// 验证数据库状态仍然正常
|
||||
var count int64
|
||||
err = suite.db.Table("users").Count(&count).Error
|
||||
@ -541,4 +541,4 @@ func (suite *IntegrationTestSuite) TestErrorRecovery() {
|
||||
// TestIntegrationTestSuite 运行集成测试套件
|
||||
func TestIntegrationTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(IntegrationTestSuite))
|
||||
}
|
||||
}
|
||||
|
||||
@ -38,23 +38,23 @@ func SetupTestEnvironment(t *testing.T) *TestContext {
|
||||
// 加载测试配置
|
||||
var cfg config.Config
|
||||
conf.MustLoad("../etc/photography-api.yaml", &cfg)
|
||||
|
||||
|
||||
// 使用内存数据库进行测试
|
||||
cfg.Database.Driver = "sqlite"
|
||||
cfg.Database.FilePath = ":memory:"
|
||||
|
||||
|
||||
// 创建服务上下文
|
||||
svcCtx := svc.NewServiceContext(cfg)
|
||||
|
||||
|
||||
// 创建 REST 服务器
|
||||
server := rest.MustNewServer(rest.RestConf{
|
||||
ServiceConf: cfg.ServiceConf,
|
||||
Port: 0, // 使用随机端口
|
||||
})
|
||||
|
||||
|
||||
// 注册路由
|
||||
handler.RegisterHandlers(server, svcCtx)
|
||||
|
||||
|
||||
return &TestContext{
|
||||
server: server,
|
||||
svcCtx: svcCtx,
|
||||
@ -79,15 +79,15 @@ func (tc *TestContext) Login(t *testing.T) {
|
||||
Username: "admin",
|
||||
Password: "admin123",
|
||||
}
|
||||
|
||||
|
||||
respBody, err := tc.PostJSON("/api/v1/auth/login", loginReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
var resp types.LoginResponse
|
||||
err = json.Unmarshal(respBody, &resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
tc.authToken = resp.Token
|
||||
|
||||
tc.authToken = resp.Data.Token
|
||||
}
|
||||
|
||||
// PostJSON 发送 POST JSON 请求
|
||||
@ -96,16 +96,16 @@ func (tc *TestContext) PostJSON(path string, data interface{}) ([]byte, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, path, bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if tc.authToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+tc.authToken)
|
||||
}
|
||||
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
tc.server.ServeHTTP(w, req)
|
||||
|
||||
|
||||
return w.Body.Bytes(), nil
|
||||
}
|
||||
|
||||
@ -115,10 +115,10 @@ func (tc *TestContext) GetJSON(path string) ([]byte, error) {
|
||||
if tc.authToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+tc.authToken)
|
||||
}
|
||||
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
tc.server.ServeHTTP(w, req)
|
||||
|
||||
|
||||
return w.Body.Bytes(), nil
|
||||
}
|
||||
|
||||
@ -128,16 +128,16 @@ func (tc *TestContext) PutJSON(path string, data interface{}) ([]byte, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
req := httptest.NewRequest(http.MethodPut, path, bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if tc.authToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+tc.authToken)
|
||||
}
|
||||
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
tc.server.ServeHTTP(w, req)
|
||||
|
||||
|
||||
return w.Body.Bytes(), nil
|
||||
}
|
||||
|
||||
@ -147,10 +147,10 @@ func (tc *TestContext) DeleteJSON(path string) ([]byte, error) {
|
||||
if tc.authToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+tc.authToken)
|
||||
}
|
||||
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
tc.server.ServeHTTP(w, req)
|
||||
|
||||
|
||||
return w.Body.Bytes(), nil
|
||||
}
|
||||
|
||||
@ -159,14 +159,14 @@ func TestHealthCheck(t *testing.T) {
|
||||
tc := SetupTestEnvironment(t)
|
||||
defer tc.StopServer()
|
||||
tc.StartServer()
|
||||
|
||||
|
||||
respBody, err := tc.GetJSON("/api/v1/health")
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
var resp types.BaseResponse
|
||||
err = json.Unmarshal(respBody, &resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
assert.Equal(t, 200, resp.Code)
|
||||
assert.Equal(t, "success", resp.Message)
|
||||
}
|
||||
@ -176,53 +176,53 @@ func TestAuthFlow(t *testing.T) {
|
||||
tc := SetupTestEnvironment(t)
|
||||
defer tc.StopServer()
|
||||
tc.StartServer()
|
||||
|
||||
|
||||
// 测试注册
|
||||
registerReq := types.RegisterRequest{
|
||||
Username: "testuser",
|
||||
Password: "testpass123",
|
||||
Email: "test@example.com",
|
||||
}
|
||||
|
||||
|
||||
respBody, err := tc.PostJSON("/api/v1/auth/register", registerReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
var registerResp types.RegisterResponse
|
||||
err = json.Unmarshal(respBody, ®isterResp)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
assert.Equal(t, 200, registerResp.Code)
|
||||
assert.NotEmpty(t, registerResp.Data.Token)
|
||||
|
||||
|
||||
// 测试登录
|
||||
loginReq := types.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "testpass123",
|
||||
}
|
||||
|
||||
|
||||
respBody, err = tc.PostJSON("/api/v1/auth/login", loginReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
var loginResp types.LoginResponse
|
||||
err = json.Unmarshal(respBody, &loginResp)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
assert.Equal(t, 200, loginResp.Code)
|
||||
assert.NotEmpty(t, loginResp.Data.Token)
|
||||
|
||||
|
||||
// 测试无效凭证
|
||||
invalidLoginReq := types.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "wrongpassword",
|
||||
}
|
||||
|
||||
|
||||
respBody, err = tc.PostJSON("/api/v1/auth/login", invalidLoginReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
var invalidResp types.LoginResponse
|
||||
err = json.Unmarshal(respBody, &invalidResp)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
assert.NotEqual(t, 200, invalidResp.Code)
|
||||
}
|
||||
|
||||
@ -231,61 +231,61 @@ func TestUserCRUD(t *testing.T) {
|
||||
tc := SetupTestEnvironment(t)
|
||||
defer tc.StopServer()
|
||||
tc.StartServer()
|
||||
|
||||
|
||||
// 先登录获取token
|
||||
tc.Login(t)
|
||||
|
||||
|
||||
// 测试创建用户
|
||||
createReq := types.CreateUserRequest{
|
||||
Username: "newuser",
|
||||
Password: "newpass123",
|
||||
Email: "newuser@example.com",
|
||||
}
|
||||
|
||||
|
||||
respBody, err := tc.PostJSON("/api/v1/users", createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
var createResp types.CreateUserResponse
|
||||
err = json.Unmarshal(respBody, &createResp)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
assert.Equal(t, 200, createResp.Code)
|
||||
userID := createResp.Data.ID
|
||||
|
||||
|
||||
// 测试获取用户
|
||||
respBody, err = tc.GetJSON(fmt.Sprintf("/api/v1/users/%d", userID))
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
var getResp types.GetUserResponse
|
||||
err = json.Unmarshal(respBody, &getResp)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
assert.Equal(t, 200, getResp.Code)
|
||||
assert.Equal(t, "newuser", getResp.Data.Username)
|
||||
|
||||
|
||||
// 测试更新用户
|
||||
updateReq := types.UpdateUserRequest{
|
||||
Username: "updateduser",
|
||||
Email: "updated@example.com",
|
||||
}
|
||||
|
||||
|
||||
respBody, err = tc.PutJSON(fmt.Sprintf("/api/v1/users/%d", userID), updateReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
var updateResp types.UpdateUserResponse
|
||||
err = json.Unmarshal(respBody, &updateResp)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
assert.Equal(t, 200, updateResp.Code)
|
||||
|
||||
|
||||
// 测试删除用户
|
||||
respBody, err = tc.DeleteJSON(fmt.Sprintf("/api/v1/users/%d", userID))
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
var deleteResp types.DeleteUserResponse
|
||||
err = json.Unmarshal(respBody, &deleteResp)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
assert.Equal(t, 200, deleteResp.Code)
|
||||
}
|
||||
|
||||
@ -294,50 +294,50 @@ func TestCategoryCRUD(t *testing.T) {
|
||||
tc := SetupTestEnvironment(t)
|
||||
defer tc.StopServer()
|
||||
tc.StartServer()
|
||||
|
||||
|
||||
// 先登录获取token
|
||||
tc.Login(t)
|
||||
|
||||
|
||||
// 测试创建分类
|
||||
createReq := types.CreateCategoryRequest{
|
||||
Name: "测试分类",
|
||||
Description: "这是一个测试分类",
|
||||
}
|
||||
|
||||
|
||||
respBody, err := tc.PostJSON("/api/v1/categories", createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
var createResp types.CreateCategoryResponse
|
||||
err = json.Unmarshal(respBody, &createResp)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
assert.Equal(t, 200, createResp.Code)
|
||||
categoryID := createResp.Data.ID
|
||||
|
||||
|
||||
// 测试获取分类列表
|
||||
respBody, err = tc.GetJSON("/api/v1/categories")
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
var listResp types.GetCategoryListResponse
|
||||
err = json.Unmarshal(respBody, &listResp)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
assert.Equal(t, 200, listResp.Code)
|
||||
assert.GreaterOrEqual(t, len(listResp.Data), 1)
|
||||
|
||||
|
||||
// 测试更新分类
|
||||
updateReq := types.UpdateCategoryRequest{
|
||||
Name: "更新的分类",
|
||||
Description: "这是一个更新的分类",
|
||||
}
|
||||
|
||||
|
||||
respBody, err = tc.PutJSON(fmt.Sprintf("/api/v1/categories/%d", categoryID), updateReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
var updateResp types.UpdateCategoryResponse
|
||||
err = json.Unmarshal(respBody, &updateResp)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
assert.Equal(t, 200, updateResp.Code)
|
||||
}
|
||||
|
||||
@ -346,46 +346,46 @@ func TestPhotoUpload(t *testing.T) {
|
||||
tc := SetupTestEnvironment(t)
|
||||
defer tc.StopServer()
|
||||
tc.StartServer()
|
||||
|
||||
|
||||
// 先登录获取token
|
||||
tc.Login(t)
|
||||
|
||||
|
||||
// 创建测试图片文件
|
||||
testImageContent := []byte("fake image content")
|
||||
|
||||
|
||||
// 创建 multipart form
|
||||
var buf bytes.Buffer
|
||||
writer := multipart.NewWriter(&buf)
|
||||
|
||||
|
||||
// 添加文件字段
|
||||
part, err := writer.CreateFormFile("file", "test.jpg")
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
_, err = part.Write(testImageContent)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
// 添加其他字段
|
||||
_ = writer.WriteField("title", "测试照片")
|
||||
_ = writer.WriteField("description", "这是一个测试照片")
|
||||
_ = writer.WriteField("category_id", "1")
|
||||
|
||||
|
||||
err = writer.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
// 发送请求
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/photos", &buf)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
req.Header.Set("Authorization", "Bearer "+tc.authToken)
|
||||
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
tc.server.ServeHTTP(w, req)
|
||||
|
||||
|
||||
respBody := w.Body.Bytes()
|
||||
|
||||
|
||||
var resp types.UploadPhotoResponse
|
||||
err = json.Unmarshal(respBody, &resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
assert.Equal(t, 200, resp.Code)
|
||||
assert.NotEmpty(t, resp.Data.ID)
|
||||
}
|
||||
@ -395,18 +395,18 @@ func TestPhotoList(t *testing.T) {
|
||||
tc := SetupTestEnvironment(t)
|
||||
defer tc.StopServer()
|
||||
tc.StartServer()
|
||||
|
||||
|
||||
// 先登录获取token
|
||||
tc.Login(t)
|
||||
|
||||
|
||||
// 测试获取照片列表
|
||||
respBody, err := tc.GetJSON("/api/v1/photos?page=1&limit=10")
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
var resp types.GetPhotoListResponse
|
||||
err = json.Unmarshal(respBody, &resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
assert.Equal(t, 200, resp.Code)
|
||||
assert.IsType(t, []types.Photo{}, resp.Data)
|
||||
}
|
||||
@ -416,24 +416,24 @@ func TestMiddleware(t *testing.T) {
|
||||
tc := SetupTestEnvironment(t)
|
||||
defer tc.StopServer()
|
||||
tc.StartServer()
|
||||
|
||||
|
||||
// 测试 CORS 中间件
|
||||
req := httptest.NewRequest(http.MethodOptions, "/api/v1/health", nil)
|
||||
req.Header.Set("Origin", "http://localhost:3000")
|
||||
req.Header.Set("Access-Control-Request-Method", "GET")
|
||||
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
tc.server.ServeHTTP(w, req)
|
||||
|
||||
|
||||
assert.Equal(t, "http://localhost:3000", w.Header().Get("Access-Control-Allow-Origin"))
|
||||
assert.Contains(t, w.Header().Get("Access-Control-Allow-Methods"), "GET")
|
||||
|
||||
|
||||
// 测试认证中间件
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
|
||||
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
tc.server.ServeHTTP(w, req)
|
||||
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
|
||||
@ -442,24 +442,24 @@ func TestErrorHandling(t *testing.T) {
|
||||
tc := SetupTestEnvironment(t)
|
||||
defer tc.StopServer()
|
||||
tc.StartServer()
|
||||
|
||||
|
||||
// 测试不存在的接口
|
||||
respBody, err := tc.GetJSON("/api/v1/nonexistent")
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
var resp types.BaseResponse
|
||||
err = json.Unmarshal(respBody, &resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
assert.NotEqual(t, 200, resp.Code)
|
||||
|
||||
|
||||
// 测试无效的 JSON
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", strings.NewReader("invalid json"))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
tc.server.ServeHTTP(w, req)
|
||||
|
||||
|
||||
assert.NotEqual(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
@ -468,21 +468,21 @@ func TestPerformance(t *testing.T) {
|
||||
tc := SetupTestEnvironment(t)
|
||||
defer tc.StopServer()
|
||||
tc.StartServer()
|
||||
|
||||
|
||||
// 测试健康检查接口的性能
|
||||
start := time.Now()
|
||||
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
_, err := tc.GetJSON("/api/v1/health")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
|
||||
duration := time.Since(start)
|
||||
avgDuration := duration / 100
|
||||
|
||||
|
||||
// 平均响应时间应该小于 10ms
|
||||
assert.Less(t, avgDuration, 10*time.Millisecond)
|
||||
|
||||
|
||||
t.Logf("Average response time: %v", avgDuration)
|
||||
}
|
||||
|
||||
@ -491,11 +491,11 @@ func TestConcurrency(t *testing.T) {
|
||||
tc := SetupTestEnvironment(t)
|
||||
defer tc.StopServer()
|
||||
tc.StartServer()
|
||||
|
||||
|
||||
// 并发测试健康检查接口
|
||||
concurrency := 50
|
||||
done := make(chan bool, concurrency)
|
||||
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func() {
|
||||
_, err := tc.GetJSON("/api/v1/health")
|
||||
@ -503,12 +503,12 @@ func TestConcurrency(t *testing.T) {
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
|
||||
// 等待所有请求完成
|
||||
for i := 0; i < concurrency; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
|
||||
t.Log("Concurrency test completed successfully")
|
||||
}
|
||||
|
||||
@ -517,31 +517,31 @@ func TestFileOperations(t *testing.T) {
|
||||
tc := SetupTestEnvironment(t)
|
||||
defer tc.StopServer()
|
||||
tc.StartServer()
|
||||
|
||||
|
||||
// 测试文件上传目录创建
|
||||
uploadDir := "test_uploads"
|
||||
err := os.MkdirAll(uploadDir, 0755)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
defer os.RemoveAll(uploadDir)
|
||||
|
||||
|
||||
// 测试文件写入
|
||||
testFile := filepath.Join(uploadDir, "test.txt")
|
||||
content := []byte("test content")
|
||||
|
||||
|
||||
err = os.WriteFile(testFile, content, 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
// 测试文件读取
|
||||
readContent, err := os.ReadFile(testFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
assert.Equal(t, content, readContent)
|
||||
|
||||
|
||||
// 测试文件删除
|
||||
err = os.Remove(testFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
_, err = os.Stat(testFile)
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
}
|
||||
@ -551,35 +551,35 @@ func TestDatabaseOperations(t *testing.T) {
|
||||
tc := SetupTestEnvironment(t)
|
||||
defer tc.StopServer()
|
||||
tc.StartServer()
|
||||
|
||||
|
||||
// 测试数据库连接
|
||||
assert.NotNil(t, tc.svcCtx.DB)
|
||||
|
||||
|
||||
// 测试数据库查询
|
||||
var count int64
|
||||
err := tc.svcCtx.DB.Table("users").Count(&count).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
assert.GreaterOrEqual(t, count, int64(0))
|
||||
}
|
||||
|
||||
// TestConfigValidation 配置验证测试
|
||||
func TestConfigValidation(t *testing.T) {
|
||||
tc := SetupTestEnvironment(t)
|
||||
|
||||
|
||||
// 测试配置加载
|
||||
assert.NotEmpty(t, tc.cfg.Name)
|
||||
assert.NotEmpty(t, tc.cfg.Host)
|
||||
assert.Greater(t, tc.cfg.Port, 0)
|
||||
|
||||
|
||||
// 测试数据库配置
|
||||
assert.NotEmpty(t, tc.cfg.Database.Driver)
|
||||
|
||||
|
||||
// 测试认证配置
|
||||
assert.NotEmpty(t, tc.cfg.Auth.AccessSecret)
|
||||
assert.Greater(t, tc.cfg.Auth.AccessExpire, int64(0))
|
||||
|
||||
|
||||
// 测试文件上传配置
|
||||
assert.Greater(t, tc.cfg.FileUpload.MaxSize, int64(0))
|
||||
assert.NotEmpty(t, tc.cfg.FileUpload.UploadDir)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"use client"
|
||||
|
||||
import { useState, useEffect } from 'react'
|
||||
import { useState, useEffect, useCallback } from 'react'
|
||||
import { Badge } from './ui/badge'
|
||||
import { Button } from './ui/button'
|
||||
import { Alert, AlertDescription } from './ui/alert'
|
||||
@ -21,7 +21,7 @@ export function ApiStatus() {
|
||||
}
|
||||
}, [useRealApi])
|
||||
|
||||
const checkApiStatus = async () => {
|
||||
const checkApiStatus = useCallback(async () => {
|
||||
setIsLoading(true)
|
||||
try {
|
||||
if (useRealApi) {
|
||||
@ -38,14 +38,14 @@ export function ApiStatus() {
|
||||
} finally {
|
||||
setIsLoading(false)
|
||||
}
|
||||
}
|
||||
}, [useRealApi])
|
||||
|
||||
useEffect(() => {
|
||||
checkApiStatus()
|
||||
// 每30秒检查一次API状态
|
||||
const interval = setInterval(checkApiStatus, 30000)
|
||||
return () => clearInterval(interval)
|
||||
}, [useRealApi])
|
||||
}, [useRealApi, checkApiStatus])
|
||||
|
||||
const toggleApiMode = () => {
|
||||
const newMode = !useRealApi
|
||||
|
||||
@ -33,7 +33,7 @@ interface CategoryPageProps {
|
||||
}
|
||||
|
||||
export function CategoryPage({ photos, onCategorySelect, onPhotosView }: CategoryPageProps) {
|
||||
const { data: dynamicCategories = [] } = useCategories()
|
||||
const { data: _dynamicCategories = [] } = useCategories()
|
||||
const [searchQuery, setSearchQuery] = useState("")
|
||||
const [viewMode, setViewMode] = useState<'grid' | 'list'>('grid')
|
||||
|
||||
@ -273,7 +273,7 @@ export function CategoryPage({ photos, onCategorySelect, onPhotosView }: Categor
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-4">
|
||||
<div className="flex -space-x-2">
|
||||
{getCategoryPreviewImages(category.photos).slice(0, 3).map((photo, idx) => (
|
||||
{getCategoryPreviewImages(category.photos).slice(0, 3).map((photo, _idx) => (
|
||||
<div
|
||||
key={photo.id}
|
||||
className="w-12 h-12 rounded-lg border-2 border-white overflow-hidden"
|
||||
|
||||
@ -35,7 +35,7 @@ export function FilterBar({
|
||||
const [showAdvanced, setShowAdvanced] = useState(false)
|
||||
|
||||
// 静态分类作为备选
|
||||
const staticCategories = [
|
||||
const _staticCategories = [
|
||||
{ id: "all", name: "全部作品" },
|
||||
{ id: "urban", name: "城市风光" },
|
||||
{ id: "nature", name: "自然风景" },
|
||||
@ -191,7 +191,7 @@ export function FilterBar({
|
||||
|
||||
{searchText.trim() && (
|
||||
<Badge variant="outline" className="gap-1">
|
||||
搜索: "{searchText.trim()}"
|
||||
搜索: “{searchText.trim()}”
|
||||
<X
|
||||
className="h-3 w-3 cursor-pointer"
|
||||
onClick={handleClearSearch}
|
||||
|
||||
@ -18,7 +18,6 @@ import {
|
||||
Tag,
|
||||
TrendingUp,
|
||||
Hash,
|
||||
Filter,
|
||||
ArrowRight,
|
||||
Camera,
|
||||
Sparkles
|
||||
@ -95,7 +94,7 @@ export function TagCloud({ photos, onTagSelect, onPhotosView }: TagCloudProps) {
|
||||
|
||||
// 过滤和排序标签
|
||||
const filteredTags = useMemo(() => {
|
||||
let filtered = tagStats.filter(tag =>
|
||||
const filtered = tagStats.filter(tag =>
|
||||
tag.count >= minCount &&
|
||||
(searchQuery.trim() === '' || tag.name.toLowerCase().includes(searchQuery.toLowerCase()))
|
||||
)
|
||||
@ -237,7 +236,7 @@ export function TagCloud({ photos, onTagSelect, onPhotosView }: TagCloudProps) {
|
||||
<div className="flex items-center gap-2">
|
||||
<select
|
||||
value={sortBy}
|
||||
onChange={(e) => setSortBy(e.target.value as any)}
|
||||
onChange={(e) => setSortBy(e.target.value as 'popularity' | 'alphabetical' | 'recent')}
|
||||
className="px-3 py-2 border border-gray-300 rounded-md text-sm"
|
||||
>
|
||||
<option value="popularity">按热度</option>
|
||||
@ -324,7 +323,7 @@ export function TagCloud({ photos, onTagSelect, onPhotosView }: TagCloudProps) {
|
||||
|
||||
{/* 最近照片预览 */}
|
||||
<div className="flex -space-x-2 mt-3">
|
||||
{tag.recentPhotos.slice(0, 3).map((photo, idx) => (
|
||||
{tag.recentPhotos.slice(0, 3).map((photo, _idx) => (
|
||||
<div
|
||||
key={photo.id}
|
||||
className="w-8 h-8 rounded-full border-2 border-white overflow-hidden"
|
||||
|
||||
@ -15,9 +15,12 @@ const api = axios.create({
|
||||
api.interceptors.request.use(
|
||||
(config) => {
|
||||
// 可以在这里添加token等认证信息
|
||||
const token = localStorage.getItem('token')
|
||||
if (token) {
|
||||
config.headers.Authorization = `Bearer ${token}`
|
||||
// 检查是否在浏览器环境中
|
||||
if (typeof window !== 'undefined') {
|
||||
const token = localStorage.getItem('token')
|
||||
if (token) {
|
||||
config.headers.Authorization = `Bearer ${token}`
|
||||
}
|
||||
}
|
||||
return config
|
||||
},
|
||||
@ -42,9 +45,11 @@ api.interceptors.response.use(
|
||||
},
|
||||
(error) => {
|
||||
if (error.response?.status === 401) {
|
||||
// 处理未授权
|
||||
localStorage.removeItem('token')
|
||||
window.location.href = '/login'
|
||||
// 处理未授权 - 仅在浏览器环境中执行
|
||||
if (typeof window !== 'undefined') {
|
||||
localStorage.removeItem('token')
|
||||
window.location.href = '/login'
|
||||
}
|
||||
}
|
||||
return Promise.reject(error)
|
||||
}
|
||||
|
||||
@ -8,7 +8,7 @@ class CategoryService {
|
||||
// 获取所有分类
|
||||
async getAllCategories(): Promise<Category[]> {
|
||||
if (process.env.NEXT_PUBLIC_USE_REAL_API === 'true') {
|
||||
const response: any = await api.get('/categories?page=1&page_size=100')
|
||||
const response: { categories: Category[] } = await api.get('/categories?page=1&page_size=100')
|
||||
return response?.categories || []
|
||||
} else {
|
||||
// Mock API 返回字符串数组,需要转换
|
||||
|
||||
@ -58,27 +58,54 @@ export const queryKeys = {
|
||||
categories: ['categories'] as const,
|
||||
}
|
||||
|
||||
// 后端照片数据结构
|
||||
interface BackendPhoto {
|
||||
id: number
|
||||
title?: string
|
||||
description?: string
|
||||
src?: string
|
||||
url?: string
|
||||
image_path?: string
|
||||
file_path?: string
|
||||
thumbnail_path?: string
|
||||
category?: string
|
||||
category_id?: number
|
||||
user_id?: number
|
||||
tags?: string[]
|
||||
date?: string
|
||||
created_at?: number
|
||||
updated_at?: number
|
||||
exif?: {
|
||||
camera?: string
|
||||
lens?: string
|
||||
settings?: string
|
||||
location?: string
|
||||
}
|
||||
}
|
||||
|
||||
// 数据转换工具
|
||||
const transformPhoto = async (backendPhoto: any): Promise<Photo> => {
|
||||
const transformPhoto = async (backendPhoto: BackendPhoto): Promise<Photo> => {
|
||||
// 如果使用Mock API,直接返回
|
||||
if (process.env.NEXT_PUBLIC_USE_REAL_API !== 'true') {
|
||||
return {
|
||||
...backendPhoto,
|
||||
id: backendPhoto.id,
|
||||
title: backendPhoto.title || '无标题',
|
||||
description: backendPhoto.description || '',
|
||||
src: backendPhoto.src || '/placeholder.jpg',
|
||||
category: backendPhoto.category || 'general',
|
||||
tags: backendPhoto.tags || [],
|
||||
date: backendPhoto.date || new Date().toISOString().split('T')[0],
|
||||
exif: backendPhoto.exif || {
|
||||
camera: '未知',
|
||||
lens: '未知',
|
||||
settings: '未知',
|
||||
location: '未知'
|
||||
exif: {
|
||||
camera: backendPhoto.exif?.camera || '未知',
|
||||
lens: backendPhoto.exif?.lens || '未知',
|
||||
settings: backendPhoto.exif?.settings || '未知',
|
||||
location: backendPhoto.exif?.location || '未知'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取分类名称
|
||||
const categoryName = await categoryService.getCategoryName(backendPhoto.category_id)
|
||||
const categoryName = await categoryService.getCategoryName(backendPhoto.category_id || 1)
|
||||
|
||||
// 转换后端API数据格式
|
||||
return {
|
||||
@ -88,7 +115,7 @@ const transformPhoto = async (backendPhoto: any): Promise<Photo> => {
|
||||
src: backendPhoto.file_path ? `http://localhost:8080${backendPhoto.file_path}` : '/placeholder.jpg',
|
||||
category: categoryName,
|
||||
tags: [], // 后端暂无标签系统,使用空数组
|
||||
date: new Date(backendPhoto.created_at * 1000).toISOString().split('T')[0],
|
||||
date: new Date((backendPhoto.created_at || Date.now() / 1000) * 1000).toISOString().split('T')[0],
|
||||
exif: {
|
||||
camera: '未知',
|
||||
lens: '未知',
|
||||
@ -105,11 +132,11 @@ const transformPhoto = async (backendPhoto: any): Promise<Photo> => {
|
||||
}
|
||||
}
|
||||
|
||||
const transformCategory = (backendCategory: any): string => {
|
||||
const _transformCategory = (backendCategory: Category | string): string => {
|
||||
if (process.env.NEXT_PUBLIC_USE_REAL_API !== 'true') {
|
||||
return backendCategory
|
||||
return typeof backendCategory === 'string' ? backendCategory : backendCategory.name
|
||||
}
|
||||
return backendCategory.name
|
||||
return typeof backendCategory === 'string' ? backendCategory : backendCategory.name
|
||||
}
|
||||
|
||||
// 获取所有照片
|
||||
@ -119,13 +146,13 @@ export const usePhotos = () => {
|
||||
queryFn: async (): Promise<Photo[]> => {
|
||||
if (process.env.NEXT_PUBLIC_USE_REAL_API === 'true') {
|
||||
// 使用真实API,带分页参数
|
||||
const response: any = await api.get('/photos?page=1&page_size=100')
|
||||
const response: { photos: BackendPhoto[] } = await api.get('/photos?page=1&page_size=100')
|
||||
const photos = response?.photos || []
|
||||
// 并发处理所有照片的转换
|
||||
return Promise.all(photos.map(transformPhoto))
|
||||
} else {
|
||||
// 使用Mock API
|
||||
const photos: any[] = await api.get('/photos')
|
||||
const photos: BackendPhoto[] = await api.get('/photos')
|
||||
return Promise.all(photos.map(transformPhoto))
|
||||
}
|
||||
},
|
||||
@ -139,7 +166,7 @@ export const usePhotosPaginated = (page: number = 1, pageSize: number = 12) => {
|
||||
queryKey: [...queryKeys.photos, 'paginated', page, pageSize],
|
||||
queryFn: async (): Promise<{ photos: Photo[], total: number, hasMore: boolean }> => {
|
||||
if (process.env.NEXT_PUBLIC_USE_REAL_API === 'true') {
|
||||
const response: any = await api.get(`/photos?page=${page}&page_size=${pageSize}`)
|
||||
const response: { photos: BackendPhoto[], total: number } = await api.get(`/photos?page=${page}&page_size=${pageSize}`)
|
||||
const photos = response?.photos || []
|
||||
const total = response?.total || 0
|
||||
const transformedPhotos = await Promise.all(photos.map(transformPhoto))
|
||||
@ -150,7 +177,7 @@ export const usePhotosPaginated = (page: number = 1, pageSize: number = 12) => {
|
||||
}
|
||||
} else {
|
||||
// 使用Mock API - 模拟分页
|
||||
const allPhotos: any[] = await api.get('/photos')
|
||||
const allPhotos: BackendPhoto[] = await api.get('/photos')
|
||||
const startIndex = (page - 1) * pageSize
|
||||
const endIndex = startIndex + pageSize
|
||||
const paginatedPhotos = allPhotos.slice(startIndex, endIndex)
|
||||
@ -167,17 +194,17 @@ export const usePhotosPaginated = (page: number = 1, pageSize: number = 12) => {
|
||||
}
|
||||
|
||||
// 无限滚动照片查询
|
||||
export const useInfinitePhotos = (pageSize: number = 12) => {
|
||||
export const useInfinitePhotos = (_pageSize: number = 12) => {
|
||||
return useQuery({
|
||||
queryKey: [...queryKeys.photos, 'infinite'],
|
||||
queryFn: async (): Promise<Photo[]> => {
|
||||
if (process.env.NEXT_PUBLIC_USE_REAL_API === 'true') {
|
||||
// 获取所有照片用于前端分页
|
||||
const response: any = await api.get('/photos?page=1&page_size=200')
|
||||
const response: { photos: BackendPhoto[] } = await api.get('/photos?page=1&page_size=200')
|
||||
const photos = response?.photos || []
|
||||
return Promise.all(photos.map(transformPhoto))
|
||||
} else {
|
||||
const photos: any[] = await api.get('/photos')
|
||||
const photos: BackendPhoto[] = await api.get('/photos')
|
||||
return Promise.all(photos.map(transformPhoto))
|
||||
}
|
||||
},
|
||||
@ -191,7 +218,7 @@ export const usePhoto = (id: number) => {
|
||||
queryKey: queryKeys.photo(id),
|
||||
queryFn: async (): Promise<Photo> => {
|
||||
if (process.env.NEXT_PUBLIC_USE_REAL_API === 'true') {
|
||||
const response = await api.get(`/photos/${id}`)
|
||||
const response: BackendPhoto = await api.get(`/photos/${id}`)
|
||||
return await transformPhoto(response)
|
||||
} else {
|
||||
return api.get(`/photos/${id}`)
|
||||
@ -207,7 +234,7 @@ export const useCategories = () => {
|
||||
queryKey: queryKeys.categories,
|
||||
queryFn: async (): Promise<string[]> => {
|
||||
if (process.env.NEXT_PUBLIC_USE_REAL_API === 'true') {
|
||||
const response: any = await api.get('/categories?page=1&page_size=100')
|
||||
const response: { categories: Category[] } = await api.get('/categories?page=1&page_size=100')
|
||||
const categories = response?.categories || []
|
||||
return categories.map((cat: Category) => cat.name)
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user