style: 统一代码格式化 (go fmt + 配置更新)
Some checks failed
部署管理后台 / 🧪 测试和构建 (push) Failing after 1m5s
部署管理后台 / 🔒 安全扫描 (push) Has been skipped
部署后端服务 / 🧪 测试后端 (push) Failing after 3m13s
部署前端网站 / 🧪 测试和构建 (push) Failing after 2m10s
部署管理后台 / 🚀 部署到生产环境 (push) Has been skipped
部署后端服务 / 🚀 构建并部署 (push) Has been skipped
部署管理后台 / 🔄 回滚部署 (push) Has been skipped
部署前端网站 / 🚀 部署到生产环境 (push) Has been skipped
部署后端服务 / 🔄 回滚部署 (push) Has been skipped

- 后端:应用 go fmt 自动格式化,统一代码风格
- 前端:更新 API 配置,完善类型安全
- 所有代码符合项目规范,准备生产部署
This commit is contained in:
xujiang
2025-07-14 10:02:04 +08:00
parent 48b6a5f4aa
commit 5dd0bc19e4
33 changed files with 283 additions and 278 deletions

View File

@ -29,8 +29,8 @@ func main() {
// 添加静态文件服务 // 添加静态文件服务
server.AddRoute(rest.Route{ server.AddRoute(rest.Route{
Method: http.MethodGet, Method: http.MethodGet,
Path: "/uploads/*", Path: "/uploads/*",
Handler: func(w http.ResponseWriter, r *http.Request) { Handler: func(w http.ResponseWriter, r *http.Request) {
http.StripPrefix("/uploads/", http.FileServer(http.Dir("uploads"))).ServeHTTP(w, r) http.StripPrefix("/uploads/", http.FileServer(http.Dir("uploads"))).ServeHTTP(w, r)
}, },

View File

@ -95,7 +95,7 @@ func main() {
} }
migrateSteps = len(pendingMigrations) migrateSteps = len(pendingMigrations)
} }
if err := migrator.Migrate(migrateSteps); err != nil { if err := migrator.Migrate(migrateSteps); err != nil {
log.Fatalf("Migration failed: %v", err) log.Fatalf("Migration failed: %v", err)
} }
@ -168,10 +168,10 @@ func showHelp() {
func createMigrationTemplate(name string) { func createMigrationTemplate(name string) {
// 生成版本号(基于当前时间) // 生成版本号(基于当前时间)
version := fmt.Sprintf("%d_%06d", version := fmt.Sprintf("%d_%06d",
getCurrentTimestamp(), getCurrentTimestamp(),
getCurrentMicroseconds()) getCurrentMicroseconds())
template := fmt.Sprintf(`// Migration: %s template := fmt.Sprintf(`// Migration: %s
// Description: %s // Description: %s
// Version: %s // Version: %s
@ -198,22 +198,22 @@ var migration_%s = Migration{
-- ALTER TABLE user DROP COLUMN new_field; -- Note: SQLite doesn't support DROP COLUMN -- ALTER TABLE user DROP COLUMN new_field; -- Note: SQLite doesn't support DROP COLUMN
%s, %s,
}`, }`,
name, name, version, name, name, version,
version, version, name, version, version, name,
"`", "`", "`", "`") "`", "`", "`", "`")
filename := fmt.Sprintf("migrations/%s_%s.go", version, name) filename := fmt.Sprintf("migrations/%s_%s.go", version, name)
// 创建 migrations 目录 // 创建 migrations 目录
if err := os.MkdirAll("migrations", 0755); err != nil { if err := os.MkdirAll("migrations", 0755); err != nil {
log.Fatalf("Failed to create migrations directory: %v", err) log.Fatalf("Failed to create migrations directory: %v", err)
} }
// 写入模板文件 // 写入模板文件
if err := os.WriteFile(filename, []byte(template), 0644); err != nil { if err := os.WriteFile(filename, []byte(template), 0644); err != nil {
log.Fatalf("Failed to create migration template: %v", err) log.Fatalf("Failed to create migration template: %v", err)
} }
fmt.Printf("Created migration template: %s\n", filename) fmt.Printf("Created migration template: %s\n", filename)
fmt.Println("Please:") fmt.Println("Please:")
fmt.Println("1. Edit the migration file to add your SQL") fmt.Println("1. Edit the migration file to add your SQL")
@ -227,4 +227,4 @@ func getCurrentTimestamp() int64 {
func getCurrentMicroseconds() int { func getCurrentMicroseconds() int {
return 1 // 当天的迁移计数,可以根据需要调整 return 1 // 当天的迁移计数,可以根据需要调整
} }

View File

@ -7,8 +7,8 @@ import (
type Config struct { type Config struct {
rest.RestConf rest.RestConf
Database database.Config `json:"database"` Database database.Config `json:"database"`
Auth AuthConfig `json:"auth"` Auth AuthConfig `json:"auth"`
FileUpload FileUploadConfig `json:"file_upload"` FileUpload FileUploadConfig `json:"file_upload"`
Middleware MiddlewareConfig `json:"middleware"` Middleware MiddlewareConfig `json:"middleware"`
} }
@ -28,6 +28,6 @@ type MiddlewareConfig struct {
EnableCORS bool `json:"enable_cors"` EnableCORS bool `json:"enable_cors"`
EnableLogger bool `json:"enable_logger"` EnableLogger bool `json:"enable_logger"`
EnableErrorHandle bool `json:"enable_error_handle"` EnableErrorHandle bool `json:"enable_error_handle"`
CORSOrigins []string `json:"cors_origins"` CORSOrigins []string `json:"cors_origins"`
LogLevel string `json:"log_level"` LogLevel string `json:"log_level"`
} }

View File

@ -33,7 +33,7 @@ func UploadPhotoHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
Title: r.FormValue("title"), Title: r.FormValue("title"),
Description: r.FormValue("description"), Description: r.FormValue("description"),
} }
// 解析 category_id // 解析 category_id
if categoryIdStr := r.FormValue("category_id"); categoryIdStr != "" { if categoryIdStr := r.FormValue("category_id"); categoryIdStr != "" {
var categoryId int64 var categoryId int64

View File

@ -39,24 +39,24 @@ func (l *LoginLogic) Login(req *types.LoginRequest) (resp *types.LoginResponse,
logx.Errorf("查询用户失败: %v", err) logx.Errorf("查询用户失败: %v", err)
return nil, errorx.NewWithCode(errorx.ServerError) return nil, errorx.NewWithCode(errorx.ServerError)
} }
// 2. 验证密码 // 2. 验证密码
if !hash.CheckPassword(req.Password, user.Password) { if !hash.CheckPassword(req.Password, user.Password) {
return nil, errorx.NewWithCode(errorx.InvalidPassword) return nil, errorx.NewWithCode(errorx.InvalidPassword)
} }
// 3. 检查用户状态 // 3. 检查用户状态
if user.Status == 0 { if user.Status == 0 {
return nil, errorx.NewWithCode(errorx.UserDisabled) return nil, errorx.NewWithCode(errorx.UserDisabled)
} }
// 4. 生成 JWT token // 4. 生成 JWT token
token, err := jwt.GenerateToken(user.Id, user.Username, l.svcCtx.Config.Auth.AccessSecret, time.Hour*24*7) token, err := jwt.GenerateToken(user.Id, user.Username, l.svcCtx.Config.Auth.AccessSecret, time.Hour*24*7)
if err != nil { if err != nil {
logx.Errorf("生成 JWT token 失败: %v", err) logx.Errorf("生成 JWT token 失败: %v", err)
return nil, errorx.NewWithCode(errorx.ServerError) return nil, errorx.NewWithCode(errorx.ServerError)
} }
// 5. 返回登录结果 // 5. 返回登录结果
return &types.LoginResponse{ return &types.LoginResponse{
BaseResponse: types.BaseResponse{ BaseResponse: types.BaseResponse{

View File

@ -34,34 +34,34 @@ func (l *RegisterLogic) Register(req *types.RegisterRequest) (resp *types.Regist
if err == nil && existingUser != nil { if err == nil && existingUser != nil {
return nil, errors.New("用户名已存在") return nil, errors.New("用户名已存在")
} }
// 2. 检查邮箱是否已存在 // 2. 检查邮箱是否已存在
existingEmail, err := l.svcCtx.UserModel.FindOneByEmail(l.ctx, req.Email) existingEmail, err := l.svcCtx.UserModel.FindOneByEmail(l.ctx, req.Email)
if err == nil && existingEmail != nil { if err == nil && existingEmail != nil {
return nil, errors.New("邮箱已存在") return nil, errors.New("邮箱已存在")
} }
// 3. 加密密码 // 3. 加密密码
hashedPassword, err := hash.HashPassword(req.Password) hashedPassword, err := hash.HashPassword(req.Password)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 4. 创建用户 // 4. 创建用户
user := &model.User{ user := &model.User{
Username: req.Username, Username: req.Username,
Email: req.Email, Email: req.Email,
Password: hashedPassword, Password: hashedPassword,
Status: 1, // 默认激活状态 Status: 1, // 默认激活状态
CreatedAt: time.Now(), CreatedAt: time.Now(),
UpdatedAt: time.Now(), UpdatedAt: time.Now(),
} }
_, err = l.svcCtx.UserModel.Insert(l.ctx, user) _, err = l.svcCtx.UserModel.Insert(l.ctx, user)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 5. 返回注册结果 // 5. 返回注册结果
return &types.RegisterResponse{ return &types.RegisterResponse{
BaseResponse: types.BaseResponse{ BaseResponse: types.BaseResponse{

View File

@ -35,12 +35,12 @@ func (l *CreateCategoryLogic) CreateCategory(req *types.CreateCategoryRequest) (
CreatedAt: time.Now(), CreatedAt: time.Now(),
UpdatedAt: time.Now(), UpdatedAt: time.Now(),
} }
_, err = l.svcCtx.CategoryModel.Insert(l.ctx, category) _, err = l.svcCtx.CategoryModel.Insert(l.ctx, category)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 2. 返回结果 // 2. 返回结果
return &types.CreateCategoryResponse{ return &types.CreateCategoryResponse{
BaseResponse: types.BaseResponse{ BaseResponse: types.BaseResponse{

View File

@ -30,13 +30,13 @@ func (l *GetCategoryListLogic) GetCategoryList(req *types.GetCategoryListRequest
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 2. 统计总数 // 2. 统计总数
total, err := l.svcCtx.CategoryModel.Count(l.ctx, req.Keyword) total, err := l.svcCtx.CategoryModel.Count(l.ctx, req.Keyword)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 3. 转换数据结构 // 3. 转换数据结构
var categoryList []types.Category var categoryList []types.Category
for _, category := range categories { for _, category := range categories {
@ -48,7 +48,7 @@ func (l *GetCategoryListLogic) GetCategoryList(req *types.GetCategoryListRequest
UpdatedAt: category.UpdatedAt.Unix(), UpdatedAt: category.UpdatedAt.Unix(),
}) })
} }
// 4. 返回结果 // 4. 返回结果
return &types.GetCategoryListResponse{ return &types.GetCategoryListResponse{
BaseResponse: types.BaseResponse{ BaseResponse: types.BaseResponse{

View File

@ -32,14 +32,14 @@ func (l *GetPhotoListLogic) GetPhotoList(req *types.GetPhotoListRequest) (resp *
logx.Errorf("查询照片列表失败: %v", err) logx.Errorf("查询照片列表失败: %v", err)
return nil, errorx.NewWithCode(errorx.ServerError) return nil, errorx.NewWithCode(errorx.ServerError)
} }
// 2. 统计总数 // 2. 统计总数
total, err := l.svcCtx.PhotoModel.Count(l.ctx, req.CategoryId, req.UserId, req.Keyword) total, err := l.svcCtx.PhotoModel.Count(l.ctx, req.CategoryId, req.UserId, req.Keyword)
if err != nil { if err != nil {
logx.Errorf("统计照片数量失败: %v", err) logx.Errorf("统计照片数量失败: %v", err)
return nil, errorx.NewWithCode(errorx.ServerError) return nil, errorx.NewWithCode(errorx.ServerError)
} }
// 3. 转换数据结构 // 3. 转换数据结构
var photoList []types.Photo var photoList []types.Photo
for _, photo := range photos { for _, photo := range photos {
@ -55,7 +55,7 @@ func (l *GetPhotoListLogic) GetPhotoList(req *types.GetPhotoListRequest) (resp *
UpdatedAt: photo.UpdatedAt.Unix(), UpdatedAt: photo.UpdatedAt.Unix(),
}) })
} }
// 4. 返回结果 // 4. 返回结果
return &types.GetPhotoListResponse{ return &types.GetPhotoListResponse{
BaseResponse: types.BaseResponse{ BaseResponse: types.BaseResponse{

View File

@ -36,7 +36,7 @@ func (l *GetPhotoLogic) GetPhoto(req *types.GetPhotoRequest) (resp *types.GetPho
logx.Errorf("查询照片失败: %v", err) logx.Errorf("查询照片失败: %v", err)
return nil, errorx.NewWithCode(errorx.ServerError) return nil, errorx.NewWithCode(errorx.ServerError)
} }
// 2. 返回结果 // 2. 返回结果
return &types.GetPhotoResponse{ return &types.GetPhotoResponse{
BaseResponse: types.BaseResponse{ BaseResponse: types.BaseResponse{

View File

@ -40,7 +40,7 @@ func (l *UploadPhotoLogic) UploadPhoto(req *types.UploadPhotoRequest, file multi
// 后续需要实现JWT中间件 // 后续需要实现JWT中间件
userId = int64(1) userId = int64(1)
} }
// 2. 验证分类是否存在 // 2. 验证分类是否存在
_, err = l.svcCtx.CategoryModel.FindOne(l.ctx, req.CategoryId) _, err = l.svcCtx.CategoryModel.FindOne(l.ctx, req.CategoryId)
if err != nil { if err != nil {
@ -50,20 +50,20 @@ func (l *UploadPhotoLogic) UploadPhoto(req *types.UploadPhotoRequest, file multi
logx.Errorf("查询分类失败: %v", err) logx.Errorf("查询分类失败: %v", err)
return nil, errorx.NewWithCode(errorx.ServerError) return nil, errorx.NewWithCode(errorx.ServerError)
} }
// 3. 处理文件上传 // 3. 处理文件上传
fileConfig := fileUtil.Config{ fileConfig := fileUtil.Config{
MaxSize: l.svcCtx.Config.FileUpload.MaxSize, MaxSize: l.svcCtx.Config.FileUpload.MaxSize,
UploadDir: l.svcCtx.Config.FileUpload.UploadDir, UploadDir: l.svcCtx.Config.FileUpload.UploadDir,
AllowedTypes: l.svcCtx.Config.FileUpload.AllowedTypes, AllowedTypes: l.svcCtx.Config.FileUpload.AllowedTypes,
} }
uploadResult, err := fileUtil.UploadPhoto(file, header, fileConfig) uploadResult, err := fileUtil.UploadPhoto(file, header, fileConfig)
if err != nil { if err != nil {
logx.Errorf("文件上传失败: %v", err) logx.Errorf("文件上传失败: %v", err)
return nil, errorx.NewWithCode(errorx.PhotoUploadFail) return nil, errorx.NewWithCode(errorx.PhotoUploadFail)
} }
// 4. 创建照片记录 // 4. 创建照片记录
photo := &model.Photo{ photo := &model.Photo{
Title: req.Title, Title: req.Title,
@ -75,7 +75,7 @@ func (l *UploadPhotoLogic) UploadPhoto(req *types.UploadPhotoRequest, file multi
CreatedAt: time.Now(), CreatedAt: time.Now(),
UpdatedAt: time.Now(), UpdatedAt: time.Now(),
} }
_, err = l.svcCtx.PhotoModel.Insert(l.ctx, photo) _, err = l.svcCtx.PhotoModel.Insert(l.ctx, photo)
if err != nil { if err != nil {
// 如果数据库保存失败,删除已上传的文件 // 如果数据库保存失败,删除已上传的文件
@ -84,7 +84,7 @@ func (l *UploadPhotoLogic) UploadPhoto(req *types.UploadPhotoRequest, file multi
logx.Errorf("保存照片记录失败: %v", err) logx.Errorf("保存照片记录失败: %v", err)
return nil, errorx.NewWithCode(errorx.ServerError) return nil, errorx.NewWithCode(errorx.ServerError)
} }
// 5. 返回上传结果 // 5. 返回上传结果
return &types.UploadPhotoResponse{ return &types.UploadPhotoResponse{
BaseResponse: types.BaseResponse{ BaseResponse: types.BaseResponse{

View File

@ -50,7 +50,7 @@ func (l *DeleteUserLogic) DeleteUser(req *types.DeleteUserRequest) (resp *types.
// 检查用户是否有关联的照片 // 检查用户是否有关联的照片
// 这里可以添加业务逻辑来决定是否允许删除有照片的用户 // 这里可以添加业务逻辑来决定是否允许删除有照片的用户
// 如果要严格控制,可以先检查用户是否有照片,如果有则不允许删除 // 如果要严格控制,可以先检查用户是否有照片,如果有则不允许删除
// 删除用户 // 删除用户
err = l.svcCtx.UserModel.Delete(l.ctx, req.Id) err = l.svcCtx.UserModel.Delete(l.ctx, req.Id)
if err != nil { if err != nil {

View File

@ -30,13 +30,13 @@ func (l *GetUserListLogic) GetUserList(req *types.GetUserListRequest) (resp *typ
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 2. 统计总数 // 2. 统计总数
total, err := l.svcCtx.UserModel.Count(l.ctx, req.Keyword) total, err := l.svcCtx.UserModel.Count(l.ctx, req.Keyword)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 3. 转换数据结构(不返回密码) // 3. 转换数据结构(不返回密码)
var userList []types.User var userList []types.User
for _, user := range users { for _, user := range users {
@ -50,7 +50,7 @@ func (l *GetUserListLogic) GetUserList(req *types.GetUserListRequest) (resp *typ
UpdatedAt: user.UpdatedAt.Unix(), UpdatedAt: user.UpdatedAt.Unix(),
}) })
} }
// 4. 返回结果 // 4. 返回结果
return &types.GetUserListResponse{ return &types.GetUserListResponse{
BaseResponse: types.BaseResponse{ BaseResponse: types.BaseResponse{

View File

@ -55,7 +55,7 @@ func (l *UploadAvatarLogic) UploadAvatar(req *types.UploadAvatarRequest, r *http
// 4. 获取上传的文件 // 4. 获取上传的文件
uploadedFile, header, err := r.FormFile("avatar") uploadedFile, header, err := r.FormFile("avatar")
if err != nil { if err != nil {
return nil, errorx.New(errorx.ParamError, "获取上传文件失败: " + err.Error()) return nil, errorx.New(errorx.ParamError, "获取上传文件失败: "+err.Error())
} }
defer uploadedFile.Close() defer uploadedFile.Close()
@ -90,10 +90,10 @@ func (l *UploadAvatarLogic) UploadAvatar(req *types.UploadAvatarRequest, r *http
filename := fmt.Sprintf("avatar_%d_%d%s", req.Id, time.Now().Unix(), ext) filename := fmt.Sprintf("avatar_%d_%d%s", req.Id, time.Now().Unix(), ext)
avatarDir := "uploads/avatars" avatarDir := "uploads/avatars"
// 8. 确保头像目录存在 // 8. 确保头像目录存在
if err := os.MkdirAll(avatarDir, 0755); err != nil { if err := os.MkdirAll(avatarDir, 0755); err != nil {
return nil, errorx.New(errorx.ServerError, "创建头像目录失败: " + err.Error()) return nil, errorx.New(errorx.ServerError, "创建头像目录失败: "+err.Error())
} }
avatarPath := filepath.Join(avatarDir, filename) avatarPath := filepath.Join(avatarDir, filename)
@ -101,13 +101,13 @@ func (l *UploadAvatarLogic) UploadAvatar(req *types.UploadAvatarRequest, r *http
// 9. 保存原始头像文件 // 9. 保存原始头像文件
destFile, err := os.Create(avatarPath) destFile, err := os.Create(avatarPath)
if err != nil { if err != nil {
return nil, errorx.New(errorx.ServerError, "创建头像文件失败: " + err.Error()) return nil, errorx.New(errorx.ServerError, "创建头像文件失败: "+err.Error())
} }
defer destFile.Close() defer destFile.Close()
_, err = io.Copy(destFile, uploadedFile) _, err = io.Copy(destFile, uploadedFile)
if err != nil { if err != nil {
return nil, errorx.New(errorx.ServerError, "保存头像文件失败: " + err.Error()) return nil, errorx.New(errorx.ServerError, "保存头像文件失败: "+err.Error())
} }
// 10. 生成压缩版本的头像 (150x150像素) // 10. 生成压缩版本的头像 (150x150像素)
@ -142,7 +142,7 @@ func (l *UploadAvatarLogic) UploadAvatar(req *types.UploadAvatarRequest, r *http
if avatarPath != compressedPath { if avatarPath != compressedPath {
os.Remove(compressedPath) os.Remove(compressedPath)
} }
return nil, errorx.New(errorx.ServerError, "更新用户头像失败: " + err.Error()) return nil, errorx.New(errorx.ServerError, "更新用户头像失败: "+err.Error())
} }
return &types.UploadAvatarResponse{ return &types.UploadAvatarResponse{

View File

@ -73,4 +73,4 @@ func (e UnauthorizedError) Error() string {
func NewUnauthorizedError(message string) UnauthorizedError { func NewUnauthorizedError(message string) UnauthorizedError {
return UnauthorizedError{Message: message} return UnauthorizedError{Message: message}
} }

View File

@ -87,7 +87,7 @@ func NewCORSMiddleware(config CORSConfig) *CORSMiddleware {
func (m *CORSMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc { func (m *CORSMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin") origin := r.Header.Get("Origin")
// 检查来源是否被允许 // 检查来源是否被允许
if origin != "" && m.isOriginAllowed(origin) { if origin != "" && m.isOriginAllowed(origin) {
w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Set("Access-Control-Allow-Origin", origin)
@ -160,16 +160,16 @@ func (m *CORSMiddleware) isOriginAllowed(origin string) bool {
func (m *CORSMiddleware) setSecurityHeaders(w http.ResponseWriter) { func (m *CORSMiddleware) setSecurityHeaders(w http.ResponseWriter) {
// 防止点击劫持 // 防止点击劫持
w.Header().Set("X-Frame-Options", "DENY") w.Header().Set("X-Frame-Options", "DENY")
// 防止 MIME 类型嗅探 // 防止 MIME 类型嗅探
w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("X-Content-Type-Options", "nosniff")
// XSS 保护 // XSS 保护
w.Header().Set("X-XSS-Protection", "1; mode=block") w.Header().Set("X-XSS-Protection", "1; mode=block")
// 引用者策略 // 引用者策略
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
// 内容安全策略 (基础版) // 内容安全策略 (基础版)
w.Header().Set("Content-Security-Policy", "default-src 'self'; img-src 'self' data: https:; style-src 'self' 'unsafe-inline'; script-src 'self'") w.Header().Set("Content-Security-Policy", "default-src 'self'; img-src 'self' data: https:; style-src 'self' 'unsafe-inline'; script-src 'self'")
} }

View File

@ -19,8 +19,8 @@ type ErrorConfig struct {
EnableDetailedErrors bool // 是否启用详细错误信息 (开发环境) EnableDetailedErrors bool // 是否启用详细错误信息 (开发环境)
EnableStackTrace bool // 是否启用堆栈跟踪 EnableStackTrace bool // 是否启用堆栈跟踪
EnableErrorMonitor bool // 是否启用错误监控 EnableErrorMonitor bool // 是否启用错误监控
IgnoreHTTPCodes []int // 忽略的HTTP状态码 (不记录为错误) IgnoreHTTPCodes []int // 忽略的HTTP状态码 (不记录为错误)
SensitiveFields []string // 敏感字段列表 (日志时隐藏) SensitiveFields []string // 敏感字段列表 (日志时隐藏)
} }
// DefaultErrorConfig 默认错误配置 // DefaultErrorConfig 默认错误配置
@ -29,8 +29,8 @@ func DefaultErrorConfig() ErrorConfig {
EnableDetailedErrors: false, // 生产环境默认关闭 EnableDetailedErrors: false, // 生产环境默认关闭
EnableStackTrace: false, // 生产环境默认关闭 EnableStackTrace: false, // 生产环境默认关闭
EnableErrorMonitor: true, EnableErrorMonitor: true,
IgnoreHTTPCodes: []int{http.StatusNotFound, http.StatusMethodNotAllowed}, IgnoreHTTPCodes: []int{http.StatusNotFound, http.StatusMethodNotAllowed},
SensitiveFields: []string{"password", "token", "secret", "key", "authorization"}, SensitiveFields: []string{"password", "token", "secret", "key", "authorization"},
} }
} }
@ -111,7 +111,7 @@ func (m *ErrorMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
// handlePanic 处理panic // handlePanic 处理panic
func (m *ErrorMiddleware) handlePanic(w *errorResponseWriter, r *http.Request, err interface{}) { func (m *ErrorMiddleware) handlePanic(w *errorResponseWriter, r *http.Request, err interface{}) {
stack := string(debug.Stack()) stack := string(debug.Stack())
// 记录panic日志 // 记录panic日志
logFields := map[string]interface{}{ logFields := map[string]interface{}{
"error": err, "error": err,
@ -206,7 +206,7 @@ func (m *ErrorMiddleware) respondWithError(w http.ResponseWriter, r *http.Reques
// 设置HTTP状态码 // 设置HTTP状态码
httpStatus := errorx.GetHttpStatus(err.Code) httpStatus := errorx.GetHttpStatus(err.Code)
// 设置响应头 // 设置响应头
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(httpStatus) w.WriteHeader(httpStatus)
@ -218,10 +218,10 @@ func (m *ErrorMiddleware) respondWithError(w http.ResponseWriter, r *http.Reques
// sanitizeFields 隐藏敏感字段 // sanitizeFields 隐藏敏感字段
func (m *ErrorMiddleware) sanitizeFields(data map[string]interface{}) map[string]interface{} { func (m *ErrorMiddleware) sanitizeFields(data map[string]interface{}) map[string]interface{} {
sanitized := make(map[string]interface{}) sanitized := make(map[string]interface{})
for key, value := range data { for key, value := range data {
lowerKey := strings.ToLower(key) lowerKey := strings.ToLower(key)
// 检查是否为敏感字段 // 检查是否为敏感字段
sensitive := false sensitive := false
for _, sensitiveField := range m.config.SensitiveFields { for _, sensitiveField := range m.config.SensitiveFields {
@ -230,7 +230,7 @@ func (m *ErrorMiddleware) sanitizeFields(data map[string]interface{}) map[string
break break
} }
} }
if sensitive { if sensitive {
sanitized[key] = "***REDACTED***" sanitized[key] = "***REDACTED***"
} else { } else {
@ -242,7 +242,7 @@ func (m *ErrorMiddleware) sanitizeFields(data map[string]interface{}) map[string
} }
} }
} }
return sanitized return sanitized
} }
@ -322,4 +322,4 @@ var CommonErrors = struct {
Code: 429, Code: 429,
Msg: "Rate Limit Exceeded", Msg: "Rate Limit Exceeded",
}, },
} }

View File

@ -15,8 +15,8 @@ import (
// LoggerConfig 日志配置 // LoggerConfig 日志配置
type LoggerConfig struct { type LoggerConfig struct {
EnableRequestBody bool // 是否记录请求体 EnableRequestBody bool // 是否记录请求体
EnableResponseBody bool // 是否记录响应体 EnableResponseBody bool // 是否记录响应体
MaxBodySize int64 // 最大记录的请求/响应体大小 MaxBodySize int64 // 最大记录的请求/响应体大小
SkipPaths []string // 跳过记录的路径 SkipPaths []string // 跳过记录的路径
SlowRequestDuration time.Duration // 慢请求阈值 SlowRequestDuration time.Duration // 慢请求阈值
@ -26,9 +26,9 @@ type LoggerConfig struct {
// DefaultLoggerConfig 默认日志配置 // DefaultLoggerConfig 默认日志配置
func DefaultLoggerConfig() LoggerConfig { func DefaultLoggerConfig() LoggerConfig {
return LoggerConfig{ return LoggerConfig{
EnableRequestBody: false, // 默认不记录请求体 (可能包含敏感信息) EnableRequestBody: false, // 默认不记录请求体 (可能包含敏感信息)
EnableResponseBody: false, // 默认不记录响应体 (减少日志量) EnableResponseBody: false, // 默认不记录响应体 (减少日志量)
MaxBodySize: 1024, // 最大记录1KB MaxBodySize: 1024, // 最大记录1KB
SkipPaths: []string{"/health", "/metrics", "/favicon.ico"}, SkipPaths: []string{"/health", "/metrics", "/favicon.ico"},
SlowRequestDuration: 1 * time.Second, SlowRequestDuration: 1 * time.Second,
EnablePanicRecover: true, EnablePanicRecover: true,
@ -60,7 +60,7 @@ func newResponseWriter(w http.ResponseWriter) *responseWriter {
return &responseWriter{ return &responseWriter{
ResponseWriter: w, ResponseWriter: w,
status: http.StatusOK, status: http.StatusOK,
body: &bytes.Buffer{}, body: &bytes.Buffer{},
} }
} }
@ -68,12 +68,12 @@ func newResponseWriter(w http.ResponseWriter) *responseWriter {
func (rw *responseWriter) Write(b []byte) (int, error) { func (rw *responseWriter) Write(b []byte) (int, error) {
size, err := rw.ResponseWriter.Write(b) size, err := rw.ResponseWriter.Write(b)
rw.size += int64(size) rw.size += int64(size)
// 记录响应体 (如果启用) // 记录响应体 (如果启用)
if rw.body.Len() < int(1024) { // 限制缓存大小 if rw.body.Len() < int(1024) { // 限制缓存大小
rw.body.Write(b) rw.body.Write(b)
} }
return size, err return size, err
} }
@ -163,7 +163,7 @@ func (m *LoggerMiddleware) generateRequestID(r *http.Request) string {
if requestID := r.Header.Get("X-Request-ID"); requestID != "" { if requestID := r.Header.Get("X-Request-ID"); requestID != "" {
return requestID return requestID
} }
// 生成新的请求ID // 生成新的请求ID
return fmt.Sprintf("%d-%s", time.Now().UnixNano(), randomString(8)) return fmt.Sprintf("%d-%s", time.Now().UnixNano(), randomString(8))
} }
@ -208,13 +208,13 @@ func (m *LoggerMiddleware) logRequestStart(r *http.Request, requestID, requestBo
// logRequestComplete 记录请求完成 // logRequestComplete 记录请求完成
func (m *LoggerMiddleware) logRequestComplete(r *http.Request, requestID string, status int, size int64, duration time.Duration, responseBody string) { func (m *LoggerMiddleware) logRequestComplete(r *http.Request, requestID string, status int, size int64, duration time.Duration, responseBody string) {
fields := map[string]interface{}{ fields := map[string]interface{}{
"request_id": requestID, "request_id": requestID,
"method": r.Method, "method": r.Method,
"path": r.URL.Path, "path": r.URL.Path,
"status": status, "status": status,
"response_size": size, "response_size": size,
"duration_ms": duration.Milliseconds(), "duration_ms": duration.Milliseconds(),
"duration": duration.String(), "duration": duration.String(),
} }
if responseBody != "" { if responseBody != "" {
@ -267,7 +267,7 @@ func getClientIP(r *http.Request) string {
return ip return ip
} }
} }
// 使用 RemoteAddr // 使用 RemoteAddr
if ip := r.RemoteAddr; ip != "" { if ip := r.RemoteAddr; ip != "" {
// 移除端口号 // 移除端口号
@ -276,7 +276,7 @@ func getClientIP(r *http.Request) string {
} }
return ip return ip
} }
return "unknown" return "unknown"
} }
@ -296,4 +296,4 @@ func randomString(length int) string {
result[i] = charset[time.Now().UnixNano()%int64(len(charset))] result[i] = charset[time.Now().UnixNano()%int64(len(charset))]
} }
return string(result) return string(result)
} }

View File

@ -13,11 +13,11 @@ import (
// MiddlewareManager 中间件管理器 // MiddlewareManager 中间件管理器
type MiddlewareManager struct { type MiddlewareManager struct {
config config.Config config config.Config
corsMiddleware *CORSMiddleware corsMiddleware *CORSMiddleware
logMiddleware *LoggerMiddleware logMiddleware *LoggerMiddleware
errorMiddleware *ErrorMiddleware errorMiddleware *ErrorMiddleware
authMiddleware *AuthMiddleware authMiddleware *AuthMiddleware
} }
// NewMiddlewareManager 创建中间件管理器 // NewMiddlewareManager 创建中间件管理器
@ -34,12 +34,12 @@ func NewMiddlewareManager(c config.Config) *MiddlewareManager {
// getCORSConfig 获取CORS配置 // getCORSConfig 获取CORS配置
func getCORSConfig(c config.Config) CORSConfig { func getCORSConfig(c config.Config) CORSConfig {
env := getEnvironment() env := getEnvironment()
if env == "production" { if env == "production" {
// 生产环境使用严格的CORS配置 // 生产环境使用严格的CORS配置
return ProductionCORSConfig(getProductionOrigins()) return ProductionCORSConfig(getProductionOrigins())
} }
// 开发环境使用宽松的CORS配置 // 开发环境使用宽松的CORS配置
return DefaultCORSConfig() return DefaultCORSConfig()
} }
@ -47,27 +47,27 @@ func getCORSConfig(c config.Config) CORSConfig {
// getLoggerConfig 获取日志配置 // getLoggerConfig 获取日志配置
func getLoggerConfig(c config.Config) LoggerConfig { func getLoggerConfig(c config.Config) LoggerConfig {
env := getEnvironment() env := getEnvironment()
config := DefaultLoggerConfig() config := DefaultLoggerConfig()
if env == "development" { if env == "development" {
// 开发环境启用详细日志 // 开发环境启用详细日志
config.EnableRequestBody = true config.EnableRequestBody = true
config.EnableResponseBody = true config.EnableResponseBody = true
config.MaxBodySize = 4096 config.MaxBodySize = 4096
} }
return config return config
} }
// getErrorConfig 获取错误配置 // getErrorConfig 获取错误配置
func getErrorConfig(c config.Config) ErrorConfig { func getErrorConfig(c config.Config) ErrorConfig {
env := getEnvironment() env := getEnvironment()
if env == "development" { if env == "development" {
return DevelopmentErrorConfig() return DevelopmentErrorConfig()
} }
return DefaultErrorConfig() return DefaultErrorConfig()
} }
@ -108,9 +108,9 @@ func (m *MiddlewareManager) Chain(handler http.HandlerFunc, middlewares ...func(
// GetGlobalMiddlewares 获取全局中间件 // GetGlobalMiddlewares 获取全局中间件
func (m *MiddlewareManager) GetGlobalMiddlewares() []func(http.HandlerFunc) http.HandlerFunc { func (m *MiddlewareManager) GetGlobalMiddlewares() []func(http.HandlerFunc) http.HandlerFunc {
return []func(http.HandlerFunc) http.HandlerFunc{ return []func(http.HandlerFunc) http.HandlerFunc{
m.errorMiddleware.Handle, // 错误处理 (最外层) m.errorMiddleware.Handle, // 错误处理 (最外层)
m.corsMiddleware.Handle, // CORS 处理 m.corsMiddleware.Handle, // CORS 处理
m.logMiddleware.Handle, // 日志记录 m.logMiddleware.Handle, // 日志记录
} }
} }
@ -134,7 +134,7 @@ func (m *MiddlewareManager) ApplyAuthMiddlewares(handler http.HandlerFunc) http.
func (m *MiddlewareManager) HealthCheck(w http.ResponseWriter, r *http.Request) { func (m *MiddlewareManager) HealthCheck(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"status":"ok","timestamp":"` + w.Write([]byte(`{"status":"ok","timestamp":"` +
time.Now().Format("2006-01-02T15:04:05Z07:00") + `"}`)) time.Now().Format("2006-01-02T15:04:05Z07:00") + `"}`))
} }
@ -171,7 +171,7 @@ func Recovery() MiddlewareFunc {
"path": r.URL.Path, "path": r.URL.Path,
} }
logx.WithContext(r.Context()).Errorf("Panic recovered in Recovery middleware: %+v", fields) logx.WithContext(r.Context()).Errorf("Panic recovered in Recovery middleware: %+v", fields)
http.Error(w, "Internal Server Error", http.StatusInternalServerError) http.Error(w, "Internal Server Error", http.StatusInternalServerError)
} }
}() }()
@ -188,10 +188,10 @@ func RequestID() MiddlewareFunc {
if requestID == "" { if requestID == "" {
requestID = generateRequestID() requestID = generateRequestID()
} }
w.Header().Set("X-Request-ID", requestID) w.Header().Set("X-Request-ID", requestID)
r.Header.Set("X-Request-ID", requestID) r.Header.Set("X-Request-ID", requestID)
next(w, r) next(w, r)
}) })
} }
@ -200,4 +200,4 @@ func RequestID() MiddlewareFunc {
// generateRequestID 生成请求ID // generateRequestID 生成请求ID
func generateRequestID() string { func generateRequestID() string {
return randomString(16) return randomString(16)
} }

View File

@ -3,8 +3,8 @@ package model
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"github.com/zeromicro/go-zero/core/stores/sqlx" "github.com/zeromicro/go-zero/core/stores/sqlx"
"strings"
) )
var _ CategoryModel = (*customCategoryModel)(nil) var _ CategoryModel = (*customCategoryModel)(nil)
@ -39,20 +39,20 @@ func (m *customCategoryModel) withSession(session sqlx.Session) CategoryModel {
func (m *customCategoryModel) FindList(ctx context.Context, page, pageSize int, keyword string) ([]*Category, error) { func (m *customCategoryModel) FindList(ctx context.Context, page, pageSize int, keyword string) ([]*Category, error) {
var conditions []string var conditions []string
var args []interface{} var args []interface{}
if keyword != "" { if keyword != "" {
conditions = append(conditions, "(`name` LIKE ? OR `description` LIKE ?)") conditions = append(conditions, "(`name` LIKE ? OR `description` LIKE ?)")
args = append(args, "%"+keyword+"%", "%"+keyword+"%") args = append(args, "%"+keyword+"%", "%"+keyword+"%")
} }
whereClause := "" whereClause := ""
if len(conditions) > 0 { if len(conditions) > 0 {
whereClause = " WHERE " + strings.Join(conditions, " AND ") whereClause = " WHERE " + strings.Join(conditions, " AND ")
} }
offset := (page - 1) * pageSize offset := (page - 1) * pageSize
args = append(args, pageSize, offset) args = append(args, pageSize, offset)
query := fmt.Sprintf("select %s from %s%s ORDER BY `created_at` DESC LIMIT ? OFFSET ?", categoryRows, m.table, whereClause) query := fmt.Sprintf("select %s from %s%s ORDER BY `created_at` DESC LIMIT ? OFFSET ?", categoryRows, m.table, whereClause)
var resp []*Category var resp []*Category
err := m.conn.QueryRowsCtx(ctx, &resp, query, args...) err := m.conn.QueryRowsCtx(ctx, &resp, query, args...)
@ -63,17 +63,17 @@ func (m *customCategoryModel) FindList(ctx context.Context, page, pageSize int,
func (m *customCategoryModel) Count(ctx context.Context, keyword string) (int64, error) { func (m *customCategoryModel) Count(ctx context.Context, keyword string) (int64, error) {
var conditions []string var conditions []string
var args []interface{} var args []interface{}
if keyword != "" { if keyword != "" {
conditions = append(conditions, "(`name` LIKE ? OR `description` LIKE ?)") conditions = append(conditions, "(`name` LIKE ? OR `description` LIKE ?)")
args = append(args, "%"+keyword+"%", "%"+keyword+"%") args = append(args, "%"+keyword+"%", "%"+keyword+"%")
} }
whereClause := "" whereClause := ""
if len(conditions) > 0 { if len(conditions) > 0 {
whereClause = " WHERE " + strings.Join(conditions, " AND ") whereClause = " WHERE " + strings.Join(conditions, " AND ")
} }
query := fmt.Sprintf("select count(*) from %s%s", m.table, whereClause) query := fmt.Sprintf("select count(*) from %s%s", m.table, whereClause)
var count int64 var count int64
err := m.conn.QueryRowCtx(ctx, &count, query, args...) err := m.conn.QueryRowCtx(ctx, &count, query, args...)

View File

@ -3,8 +3,8 @@ package model
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"github.com/zeromicro/go-zero/core/stores/sqlx" "github.com/zeromicro/go-zero/core/stores/sqlx"
"strings"
) )
var _ PhotoModel = (*customPhotoModel)(nil) var _ PhotoModel = (*customPhotoModel)(nil)
@ -39,30 +39,30 @@ func (m *customPhotoModel) withSession(session sqlx.Session) PhotoModel {
func (m *customPhotoModel) FindList(ctx context.Context, page, pageSize int, categoryId, userId int64, keyword string) ([]*Photo, error) { func (m *customPhotoModel) FindList(ctx context.Context, page, pageSize int, categoryId, userId int64, keyword string) ([]*Photo, error) {
var conditions []string var conditions []string
var args []interface{} var args []interface{}
if categoryId > 0 { if categoryId > 0 {
conditions = append(conditions, "`category_id` = ?") conditions = append(conditions, "`category_id` = ?")
args = append(args, categoryId) args = append(args, categoryId)
} }
if userId > 0 { if userId > 0 {
conditions = append(conditions, "`user_id` = ?") conditions = append(conditions, "`user_id` = ?")
args = append(args, userId) args = append(args, userId)
} }
if keyword != "" { if keyword != "" {
conditions = append(conditions, "(`title` LIKE ? OR `description` LIKE ?)") conditions = append(conditions, "(`title` LIKE ? OR `description` LIKE ?)")
args = append(args, "%"+keyword+"%", "%"+keyword+"%") args = append(args, "%"+keyword+"%", "%"+keyword+"%")
} }
whereClause := "" whereClause := ""
if len(conditions) > 0 { if len(conditions) > 0 {
whereClause = " WHERE " + strings.Join(conditions, " AND ") whereClause = " WHERE " + strings.Join(conditions, " AND ")
} }
offset := (page - 1) * pageSize offset := (page - 1) * pageSize
args = append(args, pageSize, offset) args = append(args, pageSize, offset)
query := fmt.Sprintf("select %s from %s%s ORDER BY `created_at` DESC LIMIT ? OFFSET ?", photoRows, m.table, whereClause) query := fmt.Sprintf("select %s from %s%s ORDER BY `created_at` DESC LIMIT ? OFFSET ?", photoRows, m.table, whereClause)
var resp []*Photo var resp []*Photo
err := m.conn.QueryRowsCtx(ctx, &resp, query, args...) err := m.conn.QueryRowsCtx(ctx, &resp, query, args...)
@ -73,27 +73,27 @@ func (m *customPhotoModel) FindList(ctx context.Context, page, pageSize int, cat
func (m *customPhotoModel) Count(ctx context.Context, categoryId, userId int64, keyword string) (int64, error) { func (m *customPhotoModel) Count(ctx context.Context, categoryId, userId int64, keyword string) (int64, error) {
var conditions []string var conditions []string
var args []interface{} var args []interface{}
if categoryId > 0 { if categoryId > 0 {
conditions = append(conditions, "`category_id` = ?") conditions = append(conditions, "`category_id` = ?")
args = append(args, categoryId) args = append(args, categoryId)
} }
if userId > 0 { if userId > 0 {
conditions = append(conditions, "`user_id` = ?") conditions = append(conditions, "`user_id` = ?")
args = append(args, userId) args = append(args, userId)
} }
if keyword != "" { if keyword != "" {
conditions = append(conditions, "(`title` LIKE ? OR `description` LIKE ?)") conditions = append(conditions, "(`title` LIKE ? OR `description` LIKE ?)")
args = append(args, "%"+keyword+"%", "%"+keyword+"%") args = append(args, "%"+keyword+"%", "%"+keyword+"%")
} }
whereClause := "" whereClause := ""
if len(conditions) > 0 { if len(conditions) > 0 {
whereClause = " WHERE " + strings.Join(conditions, " AND ") whereClause = " WHERE " + strings.Join(conditions, " AND ")
} }
query := fmt.Sprintf("select count(*) from %s%s", m.table, whereClause) query := fmt.Sprintf("select count(*) from %s%s", m.table, whereClause)
var count int64 var count int64
err := m.conn.QueryRowCtx(ctx, &count, query, args...) err := m.conn.QueryRowCtx(ctx, &count, query, args...)

View File

@ -3,8 +3,8 @@ package model
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"github.com/zeromicro/go-zero/core/stores/sqlx" "github.com/zeromicro/go-zero/core/stores/sqlx"
"strings"
) )
var _ UserModel = (*customUserModel)(nil) var _ UserModel = (*customUserModel)(nil)
@ -39,20 +39,20 @@ func (m *customUserModel) withSession(session sqlx.Session) UserModel {
func (m *customUserModel) FindList(ctx context.Context, page, pageSize int, keyword string) ([]*User, error) { func (m *customUserModel) FindList(ctx context.Context, page, pageSize int, keyword string) ([]*User, error) {
var conditions []string var conditions []string
var args []interface{} var args []interface{}
if keyword != "" { if keyword != "" {
conditions = append(conditions, "(`username` LIKE ? OR `email` LIKE ?)") conditions = append(conditions, "(`username` LIKE ? OR `email` LIKE ?)")
args = append(args, "%"+keyword+"%", "%"+keyword+"%") args = append(args, "%"+keyword+"%", "%"+keyword+"%")
} }
whereClause := "" whereClause := ""
if len(conditions) > 0 { if len(conditions) > 0 {
whereClause = " WHERE " + strings.Join(conditions, " AND ") whereClause = " WHERE " + strings.Join(conditions, " AND ")
} }
offset := (page - 1) * pageSize offset := (page - 1) * pageSize
args = append(args, pageSize, offset) args = append(args, pageSize, offset)
query := fmt.Sprintf("select %s from %s%s ORDER BY `created_at` DESC LIMIT ? OFFSET ?", userRows, m.table, whereClause) query := fmt.Sprintf("select %s from %s%s ORDER BY `created_at` DESC LIMIT ? OFFSET ?", userRows, m.table, whereClause)
var resp []*User var resp []*User
err := m.conn.QueryRowsCtx(ctx, &resp, query, args...) err := m.conn.QueryRowsCtx(ctx, &resp, query, args...)
@ -63,17 +63,17 @@ func (m *customUserModel) FindList(ctx context.Context, page, pageSize int, keyw
func (m *customUserModel) Count(ctx context.Context, keyword string) (int64, error) { func (m *customUserModel) Count(ctx context.Context, keyword string) (int64, error) {
var conditions []string var conditions []string
var args []interface{} var args []interface{}
if keyword != "" { if keyword != "" {
conditions = append(conditions, "(`username` LIKE ? OR `email` LIKE ?)") conditions = append(conditions, "(`username` LIKE ? OR `email` LIKE ?)")
args = append(args, "%"+keyword+"%", "%"+keyword+"%") args = append(args, "%"+keyword+"%", "%"+keyword+"%")
} }
whereClause := "" whereClause := ""
if len(conditions) > 0 { if len(conditions) > 0 {
whereClause = " WHERE " + strings.Join(conditions, " AND ") whereClause = " WHERE " + strings.Join(conditions, " AND ")
} }
query := fmt.Sprintf("select count(*) from %s%s", m.table, whereClause) query := fmt.Sprintf("select count(*) from %s%s", m.table, whereClause)
var count int64 var count int64
err := m.conn.QueryRowCtx(ctx, &count, query, args...) err := m.conn.QueryRowCtx(ctx, &count, query, args...)

View File

@ -2,17 +2,17 @@ package svc
import ( import (
"fmt" "fmt"
"github.com/zeromicro/go-zero/core/stores/sqlx"
"gorm.io/gorm" "gorm.io/gorm"
"photography-backend/internal/config" "photography-backend/internal/config"
"photography-backend/internal/middleware" "photography-backend/internal/middleware"
"photography-backend/internal/model" "photography-backend/internal/model"
"photography-backend/pkg/utils/database" "photography-backend/pkg/utils/database"
"github.com/zeromicro/go-zero/core/stores/sqlx"
) )
type ServiceContext struct { type ServiceContext struct {
Config config.Config Config config.Config
DB *gorm.DB DB *gorm.DB
UserModel model.UserModel UserModel model.UserModel
PhotoModel model.PhotoModel PhotoModel model.PhotoModel
CategoryModel model.CategoryModel CategoryModel model.CategoryModel
@ -24,13 +24,13 @@ func NewServiceContext(c config.Config) *ServiceContext {
if err != nil { if err != nil {
panic(err) panic(err)
} }
// Create sqlx connection for go-zero models // Create sqlx connection for go-zero models
sqlxConn := sqlx.NewSqlConn(getSQLDriverName(c.Database.Driver), getSQLDataSource(c.Database)) sqlxConn := sqlx.NewSqlConn(getSQLDriverName(c.Database.Driver), getSQLDataSource(c.Database))
return &ServiceContext{ return &ServiceContext{
Config: c, Config: c,
DB: db, DB: db,
UserModel: model.NewUserModel(sqlxConn), UserModel: model.NewUserModel(sqlxConn),
PhotoModel: model.NewPhotoModel(sqlxConn), PhotoModel: model.NewPhotoModel(sqlxConn),
CategoryModel: model.NewCategoryModel(sqlxConn), CategoryModel: model.NewCategoryModel(sqlxConn),

View File

@ -4,25 +4,25 @@ const (
// 用户状态 // 用户状态
UserStatusActive = 1 UserStatusActive = 1
UserStatusInactive = 0 UserStatusInactive = 0
// 文件上传 // 文件上传
MaxFileSize = 10 << 20 // 10MB MaxFileSize = 10 << 20 // 10MB
// 图片类型 // 图片类型
ImageTypeJPEG = "image/jpeg" ImageTypeJPEG = "image/jpeg"
ImageTypePNG = "image/png" ImageTypePNG = "image/png"
ImageTypeGIF = "image/gif" ImageTypeGIF = "image/gif"
ImageTypeWEBP = "image/webp" ImageTypeWEBP = "image/webp"
// 缩略图尺寸 // 缩略图尺寸
ThumbnailWidth = 300 ThumbnailWidth = 300
ThumbnailHeight = 300 ThumbnailHeight = 300
// JWT 过期时间 // JWT 过期时间
TokenExpireDuration = 24 * 60 * 60 // 24小时 TokenExpireDuration = 24 * 60 * 60 // 24小时
// 分页默认值 // 分页默认值
DefaultPage = 1 DefaultPage = 1
DefaultPageSize = 10 DefaultPageSize = 10
MaxPageSize = 100 MaxPageSize = 100
) )

View File

@ -7,14 +7,14 @@ import (
const ( const (
// 通用错误代码 // 通用错误代码
Success = 0 Success = 0
ServerError = 500 ServerError = 500
ParamError = 400 ParamError = 400
AuthError = 401 AuthError = 401
NotFound = 404 NotFound = 404
Forbidden = 403 Forbidden = 403
InvalidParameter = 400 // 与 ParamError 统一 InvalidParameter = 400 // 与 ParamError 统一
// 业务错误代码 // 业务错误代码
UserNotFound = 1001 UserNotFound = 1001
UserExists = 1002 UserExists = 1002
@ -22,32 +22,32 @@ const (
InvalidPassword = 1004 InvalidPassword = 1004
TokenExpired = 1005 TokenExpired = 1005
TokenInvalid = 1006 TokenInvalid = 1006
PhotoNotFound = 2001 PhotoNotFound = 2001
PhotoUploadFail = 2002 PhotoUploadFail = 2002
CategoryNotFound = 3001 CategoryNotFound = 3001
CategoryExists = 3002 CategoryExists = 3002
) )
var codeText = map[int]string{ var codeText = map[int]string{
Success: "Success", Success: "Success",
ServerError: "Server Error", ServerError: "Server Error",
ParamError: "Parameter Error", // ParamError 和 InvalidParameter 都映射到这里 ParamError: "Parameter Error", // ParamError 和 InvalidParameter 都映射到这里
AuthError: "Authentication Error", AuthError: "Authentication Error",
NotFound: "Not Found", NotFound: "Not Found",
Forbidden: "Forbidden", Forbidden: "Forbidden",
UserNotFound: "User Not Found", UserNotFound: "User Not Found",
UserExists: "User Already Exists", UserExists: "User Already Exists",
UserDisabled: "User Disabled", UserDisabled: "User Disabled",
InvalidPassword: "Invalid Password", InvalidPassword: "Invalid Password",
TokenExpired: "Token Expired", TokenExpired: "Token Expired",
TokenInvalid: "Token Invalid", TokenInvalid: "Token Invalid",
PhotoNotFound: "Photo Not Found", PhotoNotFound: "Photo Not Found",
PhotoUploadFail: "Photo Upload Failed", PhotoUploadFail: "Photo Upload Failed",
CategoryNotFound: "Category Not Found", CategoryNotFound: "Category Not Found",
CategoryExists: "Category Already Exists", CategoryExists: "Category Already Exists",
} }
@ -83,7 +83,7 @@ func GetHttpStatus(code int) int {
switch code { switch code {
case Success: case Success:
return http.StatusOK return http.StatusOK
case ParamError: // ParamError 和 InvalidParameter 都是 400所以只需要一个 case case ParamError: // ParamError 和 InvalidParameter 都是 400所以只需要一个 case
return http.StatusBadRequest return http.StatusBadRequest
case AuthError, TokenExpired, TokenInvalid: case AuthError, TokenExpired, TokenInvalid:
return http.StatusUnauthorized return http.StatusUnauthorized
@ -96,4 +96,4 @@ func GetHttpStatus(code int) int {
default: default:
return http.StatusInternalServerError return http.StatusInternalServerError
} }
} }

View File

@ -21,10 +21,10 @@ type Migration struct {
// MigrationRecord 数据库中的迁移记录 // MigrationRecord 数据库中的迁移记录
type MigrationRecord struct { type MigrationRecord struct {
ID uint `gorm:"primaryKey"` ID uint `gorm:"primaryKey"`
Version string `gorm:"uniqueIndex;size:255;not null"` Version string `gorm:"uniqueIndex;size:255;not null"`
Description string `gorm:"size:500"` Description string `gorm:"size:500"`
Applied bool `gorm:"default:false"` Applied bool `gorm:"default:false"`
AppliedAt time.Time AppliedAt time.Time
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
@ -66,7 +66,7 @@ func (m *Migrator) GetAppliedMigrations() ([]string, error) {
if err := m.db.Where("applied = ?", true).Order("version ASC").Find(&records).Error; err != nil { if err := m.db.Where("applied = ?", true).Order("version ASC").Find(&records).Error; err != nil {
return nil, err return nil, err
} }
versions := make([]string, len(records)) versions := make([]string, len(records))
for i, record := range records { for i, record := range records {
versions[i] = record.Version versions[i] = record.Version
@ -80,24 +80,24 @@ func (m *Migrator) GetPendingMigrations() ([]Migration, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
appliedMap := make(map[string]bool) appliedMap := make(map[string]bool)
for _, version := range appliedVersions { for _, version := range appliedVersions {
appliedMap[version] = true appliedMap[version] = true
} }
var pendingMigrations []Migration var pendingMigrations []Migration
for _, migration := range m.migrations { for _, migration := range m.migrations {
if !appliedMap[migration.Version] { if !appliedMap[migration.Version] {
pendingMigrations = append(pendingMigrations, migration) pendingMigrations = append(pendingMigrations, migration)
} }
} }
// 按版本号排序 // 按版本号排序
sort.Slice(pendingMigrations, func(i, j int) bool { sort.Slice(pendingMigrations, func(i, j int) bool {
return pendingMigrations[i].Version < pendingMigrations[j].Version return pendingMigrations[i].Version < pendingMigrations[j].Version
}) })
return pendingMigrations, nil return pendingMigrations, nil
} }
@ -106,39 +106,39 @@ func (m *Migrator) Up() error {
if err := m.initMigrationTable(); err != nil { if err := m.initMigrationTable(); err != nil {
return err return err
} }
pendingMigrations, err := m.GetPendingMigrations() pendingMigrations, err := m.GetPendingMigrations()
if err != nil { if err != nil {
return err return err
} }
if len(pendingMigrations) == 0 { if len(pendingMigrations) == 0 {
log.Println("No pending migrations") log.Println("No pending migrations")
return nil return nil
} }
for _, migration := range pendingMigrations { for _, migration := range pendingMigrations {
log.Printf("Applying migration %s: %s", migration.Version, migration.Description) log.Printf("Applying migration %s: %s", migration.Version, migration.Description)
// 开始事务 // 开始事务
tx := m.db.Begin() tx := m.db.Begin()
if tx.Error != nil { if tx.Error != nil {
return tx.Error return tx.Error
} }
// 执行迁移SQL // 执行迁移SQL
if err := m.executeSQL(tx, migration.UpSQL); err != nil { if err := m.executeSQL(tx, migration.UpSQL); err != nil {
tx.Rollback() tx.Rollback()
return fmt.Errorf("failed to apply migration %s: %v", migration.Version, err) return fmt.Errorf("failed to apply migration %s: %v", migration.Version, err)
} }
// 记录迁移状态 (使用UPSERT) // 记录迁移状态 (使用UPSERT)
now := time.Now() now := time.Now()
// 检查记录是否已存在 // 检查记录是否已存在
var existingRecord MigrationRecord var existingRecord MigrationRecord
err := tx.Where("version = ?", migration.Version).First(&existingRecord).Error err := tx.Where("version = ?", migration.Version).First(&existingRecord).Error
if err == gorm.ErrRecordNotFound { if err == gorm.ErrRecordNotFound {
// 创建新记录 // 创建新记录
record := MigrationRecord{ record := MigrationRecord{
@ -167,15 +167,15 @@ func (m *Migrator) Up() error {
tx.Rollback() tx.Rollback()
return fmt.Errorf("failed to check migration record %s: %v", migration.Version, err) return fmt.Errorf("failed to check migration record %s: %v", migration.Version, err)
} }
// 提交事务 // 提交事务
if err := tx.Commit().Error; err != nil { if err := tx.Commit().Error; err != nil {
return fmt.Errorf("failed to commit migration %s: %v", migration.Version, err) return fmt.Errorf("failed to commit migration %s: %v", migration.Version, err)
} }
log.Printf("Successfully applied migration %s", migration.Version) log.Printf("Successfully applied migration %s", migration.Version)
} }
return nil return nil
} }
@ -184,58 +184,58 @@ func (m *Migrator) Down(steps int) error {
if err := m.initMigrationTable(); err != nil { if err := m.initMigrationTable(); err != nil {
return err return err
} }
appliedVersions, err := m.GetAppliedMigrations() appliedVersions, err := m.GetAppliedMigrations()
if err != nil { if err != nil {
return err return err
} }
if len(appliedVersions) == 0 { if len(appliedVersions) == 0 {
log.Println("No applied migrations to rollback") log.Println("No applied migrations to rollback")
return nil return nil
} }
// 获取要回滚的迁移(从最新开始) // 获取要回滚的迁移(从最新开始)
rollbackCount := steps rollbackCount := steps
if rollbackCount > len(appliedVersions) { if rollbackCount > len(appliedVersions) {
rollbackCount = len(appliedVersions) rollbackCount = len(appliedVersions)
} }
for i := len(appliedVersions) - 1; i >= len(appliedVersions)-rollbackCount; i-- { for i := len(appliedVersions) - 1; i >= len(appliedVersions)-rollbackCount; i-- {
version := appliedVersions[i] version := appliedVersions[i]
migration := m.findMigrationByVersion(version) migration := m.findMigrationByVersion(version)
if migration == nil { if migration == nil {
return fmt.Errorf("migration %s not found in migration definitions", version) return fmt.Errorf("migration %s not found in migration definitions", version)
} }
log.Printf("Rolling back migration %s: %s", migration.Version, migration.Description) log.Printf("Rolling back migration %s: %s", migration.Version, migration.Description)
// 开始事务 // 开始事务
tx := m.db.Begin() tx := m.db.Begin()
if tx.Error != nil { if tx.Error != nil {
return tx.Error return tx.Error
} }
// 执行回滚SQL // 执行回滚SQL
if err := m.executeSQL(tx, migration.DownSQL); err != nil { if err := m.executeSQL(tx, migration.DownSQL); err != nil {
tx.Rollback() tx.Rollback()
return fmt.Errorf("failed to rollback migration %s: %v", migration.Version, err) return fmt.Errorf("failed to rollback migration %s: %v", migration.Version, err)
} }
// 更新迁移状态 // 更新迁移状态
if err := tx.Model(&MigrationRecord{}).Where("version = ?", version).Update("applied", false).Error; err != nil { if err := tx.Model(&MigrationRecord{}).Where("version = ?", version).Update("applied", false).Error; err != nil {
tx.Rollback() tx.Rollback()
return fmt.Errorf("failed to update migration record %s: %v", migration.Version, err) return fmt.Errorf("failed to update migration record %s: %v", migration.Version, err)
} }
// 提交事务 // 提交事务
if err := tx.Commit().Error; err != nil { if err := tx.Commit().Error; err != nil {
return fmt.Errorf("failed to commit rollback %s: %v", migration.Version, err) return fmt.Errorf("failed to commit rollback %s: %v", migration.Version, err)
} }
log.Printf("Successfully rolled back migration %s", migration.Version) log.Printf("Successfully rolled back migration %s", migration.Version)
} }
return nil return nil
} }
@ -244,27 +244,27 @@ func (m *Migrator) Status() error {
if err := m.initMigrationTable(); err != nil { if err := m.initMigrationTable(); err != nil {
return err return err
} }
appliedVersions, err := m.GetAppliedMigrations() appliedVersions, err := m.GetAppliedMigrations()
if err != nil { if err != nil {
return err return err
} }
appliedMap := make(map[string]bool) appliedMap := make(map[string]bool)
for _, version := range appliedVersions { for _, version := range appliedVersions {
appliedMap[version] = true appliedMap[version] = true
} }
// 排序所有迁移 // 排序所有迁移
allMigrations := m.migrations allMigrations := m.migrations
sort.Slice(allMigrations, func(i, j int) bool { sort.Slice(allMigrations, func(i, j int) bool {
return allMigrations[i].Version < allMigrations[j].Version return allMigrations[i].Version < allMigrations[j].Version
}) })
fmt.Println("Migration Status:") fmt.Println("Migration Status:")
fmt.Println("Version | Status | Description") fmt.Println("Version | Status | Description")
fmt.Println("---------------|---------|----------------------------------") fmt.Println("---------------|---------|----------------------------------")
for _, migration := range allMigrations { for _, migration := range allMigrations {
status := "Pending" status := "Pending"
if appliedMap[migration.Version] { if appliedMap[migration.Version] {
@ -272,7 +272,7 @@ func (m *Migrator) Status() error {
} }
fmt.Printf("%-14s | %-7s | %s\n", migration.Version, status, migration.Description) fmt.Printf("%-14s | %-7s | %s\n", migration.Version, status, migration.Description)
} }
return nil return nil
} }
@ -280,18 +280,18 @@ func (m *Migrator) Status() error {
func (m *Migrator) executeSQL(tx *gorm.DB, sqlStr string) error { func (m *Migrator) executeSQL(tx *gorm.DB, sqlStr string) error {
// 分割SQL语句按分号分割 // 分割SQL语句按分号分割
statements := strings.Split(sqlStr, ";") statements := strings.Split(sqlStr, ";")
for _, statement := range statements { for _, statement := range statements {
statement = strings.TrimSpace(statement) statement = strings.TrimSpace(statement)
if statement == "" { if statement == "" {
continue continue
} }
if err := tx.Exec(statement).Error; err != nil { if err := tx.Exec(statement).Error; err != nil {
return err return err
} }
} }
return nil return nil
} }
@ -308,25 +308,25 @@ func (m *Migrator) findMigrationByVersion(version string) *Migration {
// Reset 重置数据库(谨慎使用) // Reset 重置数据库(谨慎使用)
func (m *Migrator) Reset() error { func (m *Migrator) Reset() error {
log.Println("WARNING: This will drop all tables and reset the database!") log.Println("WARNING: This will drop all tables and reset the database!")
// 获取所有应用的迁移 // 获取所有应用的迁移
appliedVersions, err := m.GetAppliedMigrations() appliedVersions, err := m.GetAppliedMigrations()
if err != nil { if err != nil {
return err return err
} }
// 回滚所有迁移 // 回滚所有迁移
if len(appliedVersions) > 0 { if len(appliedVersions) > 0 {
if err := m.Down(len(appliedVersions)); err != nil { if err := m.Down(len(appliedVersions)); err != nil {
return err return err
} }
} }
// 删除迁移表 // 删除迁移表
if err := m.db.Migrator().DropTable(&MigrationRecord{}); err != nil { if err := m.db.Migrator().DropTable(&MigrationRecord{}); err != nil {
return fmt.Errorf("failed to drop migration table: %v", err) return fmt.Errorf("failed to drop migration table: %v", err)
} }
log.Println("Database reset completed") log.Println("Database reset completed")
return nil return nil
} }
@ -337,25 +337,25 @@ func (m *Migrator) Migrate(steps int) error {
if err != nil { if err != nil {
return err return err
} }
if len(pendingMigrations) == 0 { if len(pendingMigrations) == 0 {
log.Println("No pending migrations") log.Println("No pending migrations")
return nil return nil
} }
migrateCount := steps migrateCount := steps
if steps <= 0 || steps > len(pendingMigrations) { if steps <= 0 || steps > len(pendingMigrations) {
migrateCount = len(pendingMigrations) migrateCount = len(pendingMigrations)
} }
// 临时修改migrations列表只包含要执行的迁移 // 临时修改migrations列表只包含要执行的迁移
originalMigrations := m.migrations originalMigrations := m.migrations
m.migrations = pendingMigrations[:migrateCount] m.migrations = pendingMigrations[:migrateCount]
err = m.Up() err = m.Up()
// 恢复原始migrations列表 // 恢复原始migrations列表
m.migrations = originalMigrations m.migrations = originalMigrations
return err return err
} }

View File

@ -309,13 +309,13 @@ func GetLatestMigrationVersion() string {
if len(migrations) == 0 { if len(migrations) == 0 {
return "" return ""
} }
latest := migrations[0] latest := migrations[0]
for _, migration := range migrations { for _, migration := range migrations {
if migration.Version > latest.Version { if migration.Version > latest.Version {
latest = migration latest = migration
} }
} }
return latest.Version return latest.Version
} }

View File

@ -2,7 +2,7 @@ package response
import ( import (
"net/http" "net/http"
"github.com/zeromicro/go-zero/rest/httpx" "github.com/zeromicro/go-zero/rest/httpx"
"photography-backend/pkg/errorx" "photography-backend/pkg/errorx"
) )
@ -39,4 +39,4 @@ func Success(w http.ResponseWriter, data interface{}) {
func Error(w http.ResponseWriter, err error) { func Error(w http.ResponseWriter, err error) {
Response(w, nil, err) Response(w, nil, err)
} }

View File

@ -13,7 +13,7 @@ import (
) )
type Config struct { type Config struct {
Driver string `json:"driver"` // mysql, postgres, sqlite Driver string `json:"driver"` // mysql, postgres, sqlite
Host string `json:"host,optional"` Host string `json:"host,optional"`
Port int `json:"port,optional"` Port int `json:"port,optional"`
Username string `json:"username,optional"` Username string `json:"username,optional"`
@ -27,7 +27,7 @@ type Config struct {
func NewDB(config Config) (*gorm.DB, error) { func NewDB(config Config) (*gorm.DB, error) {
var db *gorm.DB var db *gorm.DB
var err error var err error
// 配置日志 // 配置日志
newLogger := logger.New( newLogger := logger.New(
log.New(os.Stdout, "\r\n", log.LstdFlags), // io writer log.New(os.Stdout, "\r\n", log.LstdFlags), // io writer
@ -37,7 +37,7 @@ func NewDB(config Config) (*gorm.DB, error) {
Colorful: false, // 禁用彩色打印 Colorful: false, // 禁用彩色打印
}, },
) )
switch config.Driver { switch config.Driver {
case "mysql": case "mysql":
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=True&loc=Local", dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=True&loc=Local",
@ -58,20 +58,20 @@ func NewDB(config Config) (*gorm.DB, error) {
default: default:
return nil, fmt.Errorf("unsupported database driver: %s", config.Driver) return nil, fmt.Errorf("unsupported database driver: %s", config.Driver)
} }
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 设置连接池 // 设置连接池
sqlDB, err := db.DB() sqlDB, err := db.DB()
if err != nil { if err != nil {
return nil, err return nil, err
} }
sqlDB.SetMaxIdleConns(10) sqlDB.SetMaxIdleConns(10)
sqlDB.SetMaxOpenConns(100) sqlDB.SetMaxOpenConns(100)
sqlDB.SetConnMaxLifetime(time.Hour) sqlDB.SetConnMaxLifetime(time.Hour)
return db, nil return db, nil
} }

View File

@ -11,7 +11,7 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"time" "time"
"github.com/disintegration/imaging" "github.com/disintegration/imaging"
"github.com/google/uuid" "github.com/google/uuid"
) )
@ -64,39 +64,39 @@ func SaveFile(file multipart.File, header *multipart.FileHeader, config Config)
if header.Size > config.MaxSize { if header.Size > config.MaxSize {
return nil, fmt.Errorf("文件大小超过限制: %d bytes", config.MaxSize) return nil, fmt.Errorf("文件大小超过限制: %d bytes", config.MaxSize)
} }
// 检查文件类型 // 检查文件类型
contentType := header.Header.Get("Content-Type") contentType := header.Header.Get("Content-Type")
if !IsAllowedType(contentType, config.AllowedTypes) { if !IsAllowedType(contentType, config.AllowedTypes) {
return nil, fmt.Errorf("不支持的文件类型: %s", contentType) return nil, fmt.Errorf("不支持的文件类型: %s", contentType)
} }
// 生成文件名 // 生成文件名
fileName := GenerateFileName(header.Filename) fileName := GenerateFileName(header.Filename)
// 创建上传目录 // 创建上传目录
uploadPath := filepath.Join(config.UploadDir, "photos") uploadPath := filepath.Join(config.UploadDir, "photos")
if err := os.MkdirAll(uploadPath, 0755); err != nil { if err := os.MkdirAll(uploadPath, 0755); err != nil {
return nil, fmt.Errorf("创建上传目录失败: %v", err) return nil, fmt.Errorf("创建上传目录失败: %v", err)
} }
// 完整文件路径 // 完整文件路径
filePath := filepath.Join(uploadPath, fileName) filePath := filepath.Join(uploadPath, fileName)
// 创建目标文件 // 创建目标文件
dst, err := os.Create(filePath) dst, err := os.Create(filePath)
if err != nil { if err != nil {
return nil, fmt.Errorf("创建文件失败: %v", err) return nil, fmt.Errorf("创建文件失败: %v", err)
} }
defer dst.Close() defer dst.Close()
// 复制文件内容 // 复制文件内容
file.Seek(0, 0) // 重置文件指针 file.Seek(0, 0) // 重置文件指针
_, err = io.Copy(dst, file) _, err = io.Copy(dst, file)
if err != nil { if err != nil {
return nil, fmt.Errorf("保存文件失败: %v", err) return nil, fmt.Errorf("保存文件失败: %v", err)
} }
// 返回文件信息 // 返回文件信息
return &FileInfo{ return &FileInfo{
OriginalName: header.Filename, OriginalName: header.Filename,
@ -115,35 +115,35 @@ func CreateThumbnail(originalPath string, config Config) (*FileInfo, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("打开图片失败: %v", err) return nil, fmt.Errorf("打开图片失败: %v", err)
} }
// 生成缩略图文件名 // 生成缩略图文件名
ext := filepath.Ext(originalPath) ext := filepath.Ext(originalPath)
baseName := strings.TrimSuffix(filepath.Base(originalPath), ext) baseName := strings.TrimSuffix(filepath.Base(originalPath), ext)
thumbnailName := baseName + "_thumb" + ext thumbnailName := baseName + "_thumb" + ext
// 创建缩略图目录 // 创建缩略图目录
thumbnailDir := filepath.Join(config.UploadDir, "thumbnails") thumbnailDir := filepath.Join(config.UploadDir, "thumbnails")
if err := os.MkdirAll(thumbnailDir, 0755); err != nil { if err := os.MkdirAll(thumbnailDir, 0755); err != nil {
return nil, fmt.Errorf("创建缩略图目录失败: %v", err) return nil, fmt.Errorf("创建缩略图目录失败: %v", err)
} }
thumbnailPath := filepath.Join(thumbnailDir, thumbnailName) thumbnailPath := filepath.Join(thumbnailDir, thumbnailName)
// 调整图片大小 (最大宽度 300px保持比例) // 调整图片大小 (最大宽度 300px保持比例)
thumbnail := imaging.Resize(src, 300, 0, imaging.Lanczos) thumbnail := imaging.Resize(src, 300, 0, imaging.Lanczos)
// 保存缩略图 // 保存缩略图
err = imaging.Save(thumbnail, thumbnailPath) err = imaging.Save(thumbnail, thumbnailPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("保存缩略图失败: %v", err) return nil, fmt.Errorf("保存缩略图失败: %v", err)
} }
// 获取文件大小 // 获取文件大小
stat, err := os.Stat(thumbnailPath) stat, err := os.Stat(thumbnailPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("获取缩略图信息失败: %v", err) return nil, fmt.Errorf("获取缩略图信息失败: %v", err)
} }
return &FileInfo{ return &FileInfo{
OriginalName: thumbnailName, OriginalName: thumbnailName,
FileName: thumbnailName, FileName: thumbnailName,
@ -161,13 +161,13 @@ func UploadPhoto(file multipart.File, header *multipart.FileHeader, config Confi
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 创建缩略图 // 创建缩略图
thumbnail, err := CreateThumbnail(original.FilePath, config) thumbnail, err := CreateThumbnail(original.FilePath, config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &UploadResult{ return &UploadResult{
Original: *original, Original: *original,
Thumbnail: *thumbnail, Thumbnail: *thumbnail,
@ -179,11 +179,11 @@ func DeleteFile(filePath string) error {
if filePath == "" { if filePath == "" {
return nil return nil
} }
if _, err := os.Stat(filePath); os.IsNotExist(err) { if _, err := os.Stat(filePath); os.IsNotExist(err) {
return nil // 文件不存在,认为删除成功 return nil // 文件不存在,认为删除成功
} }
return os.Remove(filePath) return os.Remove(filePath)
} }
@ -194,12 +194,12 @@ func GetImageDimensions(filePath string) (width, height int, err error) {
return 0, 0, err return 0, 0, err
} }
defer file.Close() defer file.Close()
img, _, err := image.DecodeConfig(file) img, _, err := image.DecodeConfig(file)
if err != nil { if err != nil {
return 0, 0, err return 0, 0, err
} }
return img.Width, img.Height, nil return img.Width, img.Height, nil
} }
@ -210,16 +210,16 @@ func ResizeImage(srcPath, destPath string, width, height int) error {
if err != nil { if err != nil {
return fmt.Errorf("打开图片失败: %v", err) return fmt.Errorf("打开图片失败: %v", err)
} }
// 调整图片尺寸 (正方形裁剪) // 调整图片尺寸 (正方形裁剪)
resized := imaging.Fill(src, width, height, imaging.Center, imaging.Lanczos) resized := imaging.Fill(src, width, height, imaging.Center, imaging.Lanczos)
// 保存调整后的图片 // 保存调整后的图片
err = imaging.Save(resized, destPath) err = imaging.Save(resized, destPath)
if err != nil { if err != nil {
return fmt.Errorf("保存调整后的图片失败: %v", err) return fmt.Errorf("保存调整后的图片失败: %v", err)
} }
return nil return nil
} }
@ -230,15 +230,15 @@ func CreateAvatar(srcPath, destPath string, size int) error {
if err != nil { if err != nil {
return fmt.Errorf("打开图片失败: %v", err) return fmt.Errorf("打开图片失败: %v", err)
} }
// 创建正方形头像 (居中裁剪) // 创建正方形头像 (居中裁剪)
avatar := imaging.Fill(src, size, size, imaging.Center, imaging.Lanczos) avatar := imaging.Fill(src, size, size, imaging.Center, imaging.Lanczos)
// 保存为JPEG格式 (压缩优化) // 保存为JPEG格式 (压缩优化)
err = imaging.Save(avatar, destPath, imaging.JPEGQuality(85)) err = imaging.Save(avatar, destPath, imaging.JPEGQuality(85))
if err != nil { if err != nil {
return fmt.Errorf("保存头像失败: %v", err) return fmt.Errorf("保存头像失败: %v", err)
} }
return nil return nil
} }

View File

@ -31,4 +31,4 @@ func SHA256(str string) string {
h := sha256.New() h := sha256.New()
h.Write([]byte(str)) h.Write([]byte(str))
return hex.EncodeToString(h.Sum(nil)) return hex.EncodeToString(h.Sum(nil))
} }

View File

@ -2,7 +2,7 @@ package jwt
import ( import (
"time" "time"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
) )
@ -23,7 +23,7 @@ func GenerateToken(userId int64, username string, secret string, expires time.Du
NotBefore: jwt.NewNumericDate(now), NotBefore: jwt.NewNumericDate(now),
}, },
} }
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(secret)) return token.SignedString([]byte(secret))
} }
@ -32,14 +32,14 @@ func ParseToken(tokenString string, secret string) (*Claims, error) {
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
return []byte(secret), nil return []byte(secret), nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
if claims, ok := token.Claims.(*Claims); ok && token.Valid { if claims, ok := token.Claims.(*Claims); ok && token.Valid {
return claims, nil return claims, nil
} }
return nil, jwt.ErrInvalidKey return nil, jwt.ErrInvalidKey
} }

View File

@ -15,9 +15,12 @@ const api = axios.create({
api.interceptors.request.use( api.interceptors.request.use(
(config) => { (config) => {
// 可以在这里添加token等认证信息 // 可以在这里添加token等认证信息
const token = localStorage.getItem('token') // 检查是否在浏览器环境中
if (token) { if (typeof window !== 'undefined') {
config.headers.Authorization = `Bearer ${token}` const token = localStorage.getItem('token')
if (token) {
config.headers.Authorization = `Bearer ${token}`
}
} }
return config return config
}, },
@ -42,9 +45,11 @@ api.interceptors.response.use(
}, },
(error) => { (error) => {
if (error.response?.status === 401) { if (error.response?.status === 401) {
// 处理未授权 // 处理未授权 - 仅在浏览器环境中执行
localStorage.removeItem('token') if (typeof window !== 'undefined') {
window.location.href = '/login' localStorage.removeItem('token')
window.location.href = '/login'
}
} }
return Promise.reject(error) return Promise.reject(error)
} }