style: 统一代码格式化 (go fmt + 配置更新)
Some checks failed
部署管理后台 / 🧪 测试和构建 (push) Failing after 1m5s
部署管理后台 / 🔒 安全扫描 (push) Has been skipped
部署后端服务 / 🧪 测试后端 (push) Failing after 3m13s
部署前端网站 / 🧪 测试和构建 (push) Failing after 2m10s
部署管理后台 / 🚀 部署到生产环境 (push) Has been skipped
部署后端服务 / 🚀 构建并部署 (push) Has been skipped
部署管理后台 / 🔄 回滚部署 (push) Has been skipped
部署前端网站 / 🚀 部署到生产环境 (push) Has been skipped
部署后端服务 / 🔄 回滚部署 (push) Has been skipped
Some checks failed
部署管理后台 / 🧪 测试和构建 (push) Failing after 1m5s
部署管理后台 / 🔒 安全扫描 (push) Has been skipped
部署后端服务 / 🧪 测试后端 (push) Failing after 3m13s
部署前端网站 / 🧪 测试和构建 (push) Failing after 2m10s
部署管理后台 / 🚀 部署到生产环境 (push) Has been skipped
部署后端服务 / 🚀 构建并部署 (push) Has been skipped
部署管理后台 / 🔄 回滚部署 (push) Has been skipped
部署前端网站 / 🚀 部署到生产环境 (push) Has been skipped
部署后端服务 / 🔄 回滚部署 (push) Has been skipped
- 后端:应用 go fmt 自动格式化,统一代码风格 - 前端:更新 API 配置,完善类型安全 - 所有代码符合项目规范,准备生产部署
This commit is contained in:
@ -29,8 +29,8 @@ func main() {
|
|||||||
|
|
||||||
// 添加静态文件服务
|
// 添加静态文件服务
|
||||||
server.AddRoute(rest.Route{
|
server.AddRoute(rest.Route{
|
||||||
Method: http.MethodGet,
|
Method: http.MethodGet,
|
||||||
Path: "/uploads/*",
|
Path: "/uploads/*",
|
||||||
Handler: func(w http.ResponseWriter, r *http.Request) {
|
Handler: func(w http.ResponseWriter, r *http.Request) {
|
||||||
http.StripPrefix("/uploads/", http.FileServer(http.Dir("uploads"))).ServeHTTP(w, r)
|
http.StripPrefix("/uploads/", http.FileServer(http.Dir("uploads"))).ServeHTTP(w, r)
|
||||||
},
|
},
|
||||||
|
|||||||
@ -95,7 +95,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
migrateSteps = len(pendingMigrations)
|
migrateSteps = len(pendingMigrations)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := migrator.Migrate(migrateSteps); err != nil {
|
if err := migrator.Migrate(migrateSteps); err != nil {
|
||||||
log.Fatalf("Migration failed: %v", err)
|
log.Fatalf("Migration failed: %v", err)
|
||||||
}
|
}
|
||||||
@ -168,10 +168,10 @@ func showHelp() {
|
|||||||
|
|
||||||
func createMigrationTemplate(name string) {
|
func createMigrationTemplate(name string) {
|
||||||
// 生成版本号(基于当前时间)
|
// 生成版本号(基于当前时间)
|
||||||
version := fmt.Sprintf("%d_%06d",
|
version := fmt.Sprintf("%d_%06d",
|
||||||
getCurrentTimestamp(),
|
getCurrentTimestamp(),
|
||||||
getCurrentMicroseconds())
|
getCurrentMicroseconds())
|
||||||
|
|
||||||
template := fmt.Sprintf(`// Migration: %s
|
template := fmt.Sprintf(`// Migration: %s
|
||||||
// Description: %s
|
// Description: %s
|
||||||
// Version: %s
|
// Version: %s
|
||||||
@ -198,22 +198,22 @@ var migration_%s = Migration{
|
|||||||
-- ALTER TABLE user DROP COLUMN new_field; -- Note: SQLite doesn't support DROP COLUMN
|
-- ALTER TABLE user DROP COLUMN new_field; -- Note: SQLite doesn't support DROP COLUMN
|
||||||
%s,
|
%s,
|
||||||
}`,
|
}`,
|
||||||
name, name, version,
|
name, name, version,
|
||||||
version, version, name,
|
version, version, name,
|
||||||
"`", "`", "`", "`")
|
"`", "`", "`", "`")
|
||||||
|
|
||||||
filename := fmt.Sprintf("migrations/%s_%s.go", version, name)
|
filename := fmt.Sprintf("migrations/%s_%s.go", version, name)
|
||||||
|
|
||||||
// 创建 migrations 目录
|
// 创建 migrations 目录
|
||||||
if err := os.MkdirAll("migrations", 0755); err != nil {
|
if err := os.MkdirAll("migrations", 0755); err != nil {
|
||||||
log.Fatalf("Failed to create migrations directory: %v", err)
|
log.Fatalf("Failed to create migrations directory: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 写入模板文件
|
// 写入模板文件
|
||||||
if err := os.WriteFile(filename, []byte(template), 0644); err != nil {
|
if err := os.WriteFile(filename, []byte(template), 0644); err != nil {
|
||||||
log.Fatalf("Failed to create migration template: %v", err)
|
log.Fatalf("Failed to create migration template: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("Created migration template: %s\n", filename)
|
fmt.Printf("Created migration template: %s\n", filename)
|
||||||
fmt.Println("Please:")
|
fmt.Println("Please:")
|
||||||
fmt.Println("1. Edit the migration file to add your SQL")
|
fmt.Println("1. Edit the migration file to add your SQL")
|
||||||
@ -227,4 +227,4 @@ func getCurrentTimestamp() int64 {
|
|||||||
|
|
||||||
func getCurrentMicroseconds() int {
|
func getCurrentMicroseconds() int {
|
||||||
return 1 // 当天的迁移计数,可以根据需要调整
|
return 1 // 当天的迁移计数,可以根据需要调整
|
||||||
}
|
}
|
||||||
|
|||||||
@ -7,8 +7,8 @@ import (
|
|||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
rest.RestConf
|
rest.RestConf
|
||||||
Database database.Config `json:"database"`
|
Database database.Config `json:"database"`
|
||||||
Auth AuthConfig `json:"auth"`
|
Auth AuthConfig `json:"auth"`
|
||||||
FileUpload FileUploadConfig `json:"file_upload"`
|
FileUpload FileUploadConfig `json:"file_upload"`
|
||||||
Middleware MiddlewareConfig `json:"middleware"`
|
Middleware MiddlewareConfig `json:"middleware"`
|
||||||
}
|
}
|
||||||
@ -28,6 +28,6 @@ type MiddlewareConfig struct {
|
|||||||
EnableCORS bool `json:"enable_cors"`
|
EnableCORS bool `json:"enable_cors"`
|
||||||
EnableLogger bool `json:"enable_logger"`
|
EnableLogger bool `json:"enable_logger"`
|
||||||
EnableErrorHandle bool `json:"enable_error_handle"`
|
EnableErrorHandle bool `json:"enable_error_handle"`
|
||||||
CORSOrigins []string `json:"cors_origins"`
|
CORSOrigins []string `json:"cors_origins"`
|
||||||
LogLevel string `json:"log_level"`
|
LogLevel string `json:"log_level"`
|
||||||
}
|
}
|
||||||
|
|||||||
@ -33,7 +33,7 @@ func UploadPhotoHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
|
|||||||
Title: r.FormValue("title"),
|
Title: r.FormValue("title"),
|
||||||
Description: r.FormValue("description"),
|
Description: r.FormValue("description"),
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析 category_id
|
// 解析 category_id
|
||||||
if categoryIdStr := r.FormValue("category_id"); categoryIdStr != "" {
|
if categoryIdStr := r.FormValue("category_id"); categoryIdStr != "" {
|
||||||
var categoryId int64
|
var categoryId int64
|
||||||
|
|||||||
@ -39,24 +39,24 @@ func (l *LoginLogic) Login(req *types.LoginRequest) (resp *types.LoginResponse,
|
|||||||
logx.Errorf("查询用户失败: %v", err)
|
logx.Errorf("查询用户失败: %v", err)
|
||||||
return nil, errorx.NewWithCode(errorx.ServerError)
|
return nil, errorx.NewWithCode(errorx.ServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 验证密码
|
// 2. 验证密码
|
||||||
if !hash.CheckPassword(req.Password, user.Password) {
|
if !hash.CheckPassword(req.Password, user.Password) {
|
||||||
return nil, errorx.NewWithCode(errorx.InvalidPassword)
|
return nil, errorx.NewWithCode(errorx.InvalidPassword)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 检查用户状态
|
// 3. 检查用户状态
|
||||||
if user.Status == 0 {
|
if user.Status == 0 {
|
||||||
return nil, errorx.NewWithCode(errorx.UserDisabled)
|
return nil, errorx.NewWithCode(errorx.UserDisabled)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. 生成 JWT token
|
// 4. 生成 JWT token
|
||||||
token, err := jwt.GenerateToken(user.Id, user.Username, l.svcCtx.Config.Auth.AccessSecret, time.Hour*24*7)
|
token, err := jwt.GenerateToken(user.Id, user.Username, l.svcCtx.Config.Auth.AccessSecret, time.Hour*24*7)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logx.Errorf("生成 JWT token 失败: %v", err)
|
logx.Errorf("生成 JWT token 失败: %v", err)
|
||||||
return nil, errorx.NewWithCode(errorx.ServerError)
|
return nil, errorx.NewWithCode(errorx.ServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5. 返回登录结果
|
// 5. 返回登录结果
|
||||||
return &types.LoginResponse{
|
return &types.LoginResponse{
|
||||||
BaseResponse: types.BaseResponse{
|
BaseResponse: types.BaseResponse{
|
||||||
|
|||||||
@ -34,34 +34,34 @@ func (l *RegisterLogic) Register(req *types.RegisterRequest) (resp *types.Regist
|
|||||||
if err == nil && existingUser != nil {
|
if err == nil && existingUser != nil {
|
||||||
return nil, errors.New("用户名已存在")
|
return nil, errors.New("用户名已存在")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 检查邮箱是否已存在
|
// 2. 检查邮箱是否已存在
|
||||||
existingEmail, err := l.svcCtx.UserModel.FindOneByEmail(l.ctx, req.Email)
|
existingEmail, err := l.svcCtx.UserModel.FindOneByEmail(l.ctx, req.Email)
|
||||||
if err == nil && existingEmail != nil {
|
if err == nil && existingEmail != nil {
|
||||||
return nil, errors.New("邮箱已存在")
|
return nil, errors.New("邮箱已存在")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 加密密码
|
// 3. 加密密码
|
||||||
hashedPassword, err := hash.HashPassword(req.Password)
|
hashedPassword, err := hash.HashPassword(req.Password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. 创建用户
|
// 4. 创建用户
|
||||||
user := &model.User{
|
user := &model.User{
|
||||||
Username: req.Username,
|
Username: req.Username,
|
||||||
Email: req.Email,
|
Email: req.Email,
|
||||||
Password: hashedPassword,
|
Password: hashedPassword,
|
||||||
Status: 1, // 默认激活状态
|
Status: 1, // 默认激活状态
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
UpdatedAt: time.Now(),
|
UpdatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = l.svcCtx.UserModel.Insert(l.ctx, user)
|
_, err = l.svcCtx.UserModel.Insert(l.ctx, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5. 返回注册结果
|
// 5. 返回注册结果
|
||||||
return &types.RegisterResponse{
|
return &types.RegisterResponse{
|
||||||
BaseResponse: types.BaseResponse{
|
BaseResponse: types.BaseResponse{
|
||||||
|
|||||||
@ -35,12 +35,12 @@ func (l *CreateCategoryLogic) CreateCategory(req *types.CreateCategoryRequest) (
|
|||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
UpdatedAt: time.Now(),
|
UpdatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = l.svcCtx.CategoryModel.Insert(l.ctx, category)
|
_, err = l.svcCtx.CategoryModel.Insert(l.ctx, category)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 返回结果
|
// 2. 返回结果
|
||||||
return &types.CreateCategoryResponse{
|
return &types.CreateCategoryResponse{
|
||||||
BaseResponse: types.BaseResponse{
|
BaseResponse: types.BaseResponse{
|
||||||
|
|||||||
@ -30,13 +30,13 @@ func (l *GetCategoryListLogic) GetCategoryList(req *types.GetCategoryListRequest
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 统计总数
|
// 2. 统计总数
|
||||||
total, err := l.svcCtx.CategoryModel.Count(l.ctx, req.Keyword)
|
total, err := l.svcCtx.CategoryModel.Count(l.ctx, req.Keyword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 转换数据结构
|
// 3. 转换数据结构
|
||||||
var categoryList []types.Category
|
var categoryList []types.Category
|
||||||
for _, category := range categories {
|
for _, category := range categories {
|
||||||
@ -48,7 +48,7 @@ func (l *GetCategoryListLogic) GetCategoryList(req *types.GetCategoryListRequest
|
|||||||
UpdatedAt: category.UpdatedAt.Unix(),
|
UpdatedAt: category.UpdatedAt.Unix(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. 返回结果
|
// 4. 返回结果
|
||||||
return &types.GetCategoryListResponse{
|
return &types.GetCategoryListResponse{
|
||||||
BaseResponse: types.BaseResponse{
|
BaseResponse: types.BaseResponse{
|
||||||
|
|||||||
@ -32,14 +32,14 @@ func (l *GetPhotoListLogic) GetPhotoList(req *types.GetPhotoListRequest) (resp *
|
|||||||
logx.Errorf("查询照片列表失败: %v", err)
|
logx.Errorf("查询照片列表失败: %v", err)
|
||||||
return nil, errorx.NewWithCode(errorx.ServerError)
|
return nil, errorx.NewWithCode(errorx.ServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 统计总数
|
// 2. 统计总数
|
||||||
total, err := l.svcCtx.PhotoModel.Count(l.ctx, req.CategoryId, req.UserId, req.Keyword)
|
total, err := l.svcCtx.PhotoModel.Count(l.ctx, req.CategoryId, req.UserId, req.Keyword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logx.Errorf("统计照片数量失败: %v", err)
|
logx.Errorf("统计照片数量失败: %v", err)
|
||||||
return nil, errorx.NewWithCode(errorx.ServerError)
|
return nil, errorx.NewWithCode(errorx.ServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 转换数据结构
|
// 3. 转换数据结构
|
||||||
var photoList []types.Photo
|
var photoList []types.Photo
|
||||||
for _, photo := range photos {
|
for _, photo := range photos {
|
||||||
@ -55,7 +55,7 @@ func (l *GetPhotoListLogic) GetPhotoList(req *types.GetPhotoListRequest) (resp *
|
|||||||
UpdatedAt: photo.UpdatedAt.Unix(),
|
UpdatedAt: photo.UpdatedAt.Unix(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. 返回结果
|
// 4. 返回结果
|
||||||
return &types.GetPhotoListResponse{
|
return &types.GetPhotoListResponse{
|
||||||
BaseResponse: types.BaseResponse{
|
BaseResponse: types.BaseResponse{
|
||||||
|
|||||||
@ -36,7 +36,7 @@ func (l *GetPhotoLogic) GetPhoto(req *types.GetPhotoRequest) (resp *types.GetPho
|
|||||||
logx.Errorf("查询照片失败: %v", err)
|
logx.Errorf("查询照片失败: %v", err)
|
||||||
return nil, errorx.NewWithCode(errorx.ServerError)
|
return nil, errorx.NewWithCode(errorx.ServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 返回结果
|
// 2. 返回结果
|
||||||
return &types.GetPhotoResponse{
|
return &types.GetPhotoResponse{
|
||||||
BaseResponse: types.BaseResponse{
|
BaseResponse: types.BaseResponse{
|
||||||
|
|||||||
@ -40,7 +40,7 @@ func (l *UploadPhotoLogic) UploadPhoto(req *types.UploadPhotoRequest, file multi
|
|||||||
// 后续需要实现JWT中间件
|
// 后续需要实现JWT中间件
|
||||||
userId = int64(1)
|
userId = int64(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 验证分类是否存在
|
// 2. 验证分类是否存在
|
||||||
_, err = l.svcCtx.CategoryModel.FindOne(l.ctx, req.CategoryId)
|
_, err = l.svcCtx.CategoryModel.FindOne(l.ctx, req.CategoryId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -50,20 +50,20 @@ func (l *UploadPhotoLogic) UploadPhoto(req *types.UploadPhotoRequest, file multi
|
|||||||
logx.Errorf("查询分类失败: %v", err)
|
logx.Errorf("查询分类失败: %v", err)
|
||||||
return nil, errorx.NewWithCode(errorx.ServerError)
|
return nil, errorx.NewWithCode(errorx.ServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 处理文件上传
|
// 3. 处理文件上传
|
||||||
fileConfig := fileUtil.Config{
|
fileConfig := fileUtil.Config{
|
||||||
MaxSize: l.svcCtx.Config.FileUpload.MaxSize,
|
MaxSize: l.svcCtx.Config.FileUpload.MaxSize,
|
||||||
UploadDir: l.svcCtx.Config.FileUpload.UploadDir,
|
UploadDir: l.svcCtx.Config.FileUpload.UploadDir,
|
||||||
AllowedTypes: l.svcCtx.Config.FileUpload.AllowedTypes,
|
AllowedTypes: l.svcCtx.Config.FileUpload.AllowedTypes,
|
||||||
}
|
}
|
||||||
|
|
||||||
uploadResult, err := fileUtil.UploadPhoto(file, header, fileConfig)
|
uploadResult, err := fileUtil.UploadPhoto(file, header, fileConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logx.Errorf("文件上传失败: %v", err)
|
logx.Errorf("文件上传失败: %v", err)
|
||||||
return nil, errorx.NewWithCode(errorx.PhotoUploadFail)
|
return nil, errorx.NewWithCode(errorx.PhotoUploadFail)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. 创建照片记录
|
// 4. 创建照片记录
|
||||||
photo := &model.Photo{
|
photo := &model.Photo{
|
||||||
Title: req.Title,
|
Title: req.Title,
|
||||||
@ -75,7 +75,7 @@ func (l *UploadPhotoLogic) UploadPhoto(req *types.UploadPhotoRequest, file multi
|
|||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
UpdatedAt: time.Now(),
|
UpdatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = l.svcCtx.PhotoModel.Insert(l.ctx, photo)
|
_, err = l.svcCtx.PhotoModel.Insert(l.ctx, photo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// 如果数据库保存失败,删除已上传的文件
|
// 如果数据库保存失败,删除已上传的文件
|
||||||
@ -84,7 +84,7 @@ func (l *UploadPhotoLogic) UploadPhoto(req *types.UploadPhotoRequest, file multi
|
|||||||
logx.Errorf("保存照片记录失败: %v", err)
|
logx.Errorf("保存照片记录失败: %v", err)
|
||||||
return nil, errorx.NewWithCode(errorx.ServerError)
|
return nil, errorx.NewWithCode(errorx.ServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5. 返回上传结果
|
// 5. 返回上传结果
|
||||||
return &types.UploadPhotoResponse{
|
return &types.UploadPhotoResponse{
|
||||||
BaseResponse: types.BaseResponse{
|
BaseResponse: types.BaseResponse{
|
||||||
|
|||||||
@ -50,7 +50,7 @@ func (l *DeleteUserLogic) DeleteUser(req *types.DeleteUserRequest) (resp *types.
|
|||||||
// 检查用户是否有关联的照片
|
// 检查用户是否有关联的照片
|
||||||
// 这里可以添加业务逻辑来决定是否允许删除有照片的用户
|
// 这里可以添加业务逻辑来决定是否允许删除有照片的用户
|
||||||
// 如果要严格控制,可以先检查用户是否有照片,如果有则不允许删除
|
// 如果要严格控制,可以先检查用户是否有照片,如果有则不允许删除
|
||||||
|
|
||||||
// 删除用户
|
// 删除用户
|
||||||
err = l.svcCtx.UserModel.Delete(l.ctx, req.Id)
|
err = l.svcCtx.UserModel.Delete(l.ctx, req.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -30,13 +30,13 @@ func (l *GetUserListLogic) GetUserList(req *types.GetUserListRequest) (resp *typ
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 统计总数
|
// 2. 统计总数
|
||||||
total, err := l.svcCtx.UserModel.Count(l.ctx, req.Keyword)
|
total, err := l.svcCtx.UserModel.Count(l.ctx, req.Keyword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 转换数据结构(不返回密码)
|
// 3. 转换数据结构(不返回密码)
|
||||||
var userList []types.User
|
var userList []types.User
|
||||||
for _, user := range users {
|
for _, user := range users {
|
||||||
@ -50,7 +50,7 @@ func (l *GetUserListLogic) GetUserList(req *types.GetUserListRequest) (resp *typ
|
|||||||
UpdatedAt: user.UpdatedAt.Unix(),
|
UpdatedAt: user.UpdatedAt.Unix(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. 返回结果
|
// 4. 返回结果
|
||||||
return &types.GetUserListResponse{
|
return &types.GetUserListResponse{
|
||||||
BaseResponse: types.BaseResponse{
|
BaseResponse: types.BaseResponse{
|
||||||
|
|||||||
@ -55,7 +55,7 @@ func (l *UploadAvatarLogic) UploadAvatar(req *types.UploadAvatarRequest, r *http
|
|||||||
// 4. 获取上传的文件
|
// 4. 获取上传的文件
|
||||||
uploadedFile, header, err := r.FormFile("avatar")
|
uploadedFile, header, err := r.FormFile("avatar")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errorx.New(errorx.ParamError, "获取上传文件失败: " + err.Error())
|
return nil, errorx.New(errorx.ParamError, "获取上传文件失败: "+err.Error())
|
||||||
}
|
}
|
||||||
defer uploadedFile.Close()
|
defer uploadedFile.Close()
|
||||||
|
|
||||||
@ -90,10 +90,10 @@ func (l *UploadAvatarLogic) UploadAvatar(req *types.UploadAvatarRequest, r *http
|
|||||||
|
|
||||||
filename := fmt.Sprintf("avatar_%d_%d%s", req.Id, time.Now().Unix(), ext)
|
filename := fmt.Sprintf("avatar_%d_%d%s", req.Id, time.Now().Unix(), ext)
|
||||||
avatarDir := "uploads/avatars"
|
avatarDir := "uploads/avatars"
|
||||||
|
|
||||||
// 8. 确保头像目录存在
|
// 8. 确保头像目录存在
|
||||||
if err := os.MkdirAll(avatarDir, 0755); err != nil {
|
if err := os.MkdirAll(avatarDir, 0755); err != nil {
|
||||||
return nil, errorx.New(errorx.ServerError, "创建头像目录失败: " + err.Error())
|
return nil, errorx.New(errorx.ServerError, "创建头像目录失败: "+err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
avatarPath := filepath.Join(avatarDir, filename)
|
avatarPath := filepath.Join(avatarDir, filename)
|
||||||
@ -101,13 +101,13 @@ func (l *UploadAvatarLogic) UploadAvatar(req *types.UploadAvatarRequest, r *http
|
|||||||
// 9. 保存原始头像文件
|
// 9. 保存原始头像文件
|
||||||
destFile, err := os.Create(avatarPath)
|
destFile, err := os.Create(avatarPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errorx.New(errorx.ServerError, "创建头像文件失败: " + err.Error())
|
return nil, errorx.New(errorx.ServerError, "创建头像文件失败: "+err.Error())
|
||||||
}
|
}
|
||||||
defer destFile.Close()
|
defer destFile.Close()
|
||||||
|
|
||||||
_, err = io.Copy(destFile, uploadedFile)
|
_, err = io.Copy(destFile, uploadedFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errorx.New(errorx.ServerError, "保存头像文件失败: " + err.Error())
|
return nil, errorx.New(errorx.ServerError, "保存头像文件失败: "+err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
// 10. 生成压缩版本的头像 (150x150像素)
|
// 10. 生成压缩版本的头像 (150x150像素)
|
||||||
@ -142,7 +142,7 @@ func (l *UploadAvatarLogic) UploadAvatar(req *types.UploadAvatarRequest, r *http
|
|||||||
if avatarPath != compressedPath {
|
if avatarPath != compressedPath {
|
||||||
os.Remove(compressedPath)
|
os.Remove(compressedPath)
|
||||||
}
|
}
|
||||||
return nil, errorx.New(errorx.ServerError, "更新用户头像失败: " + err.Error())
|
return nil, errorx.New(errorx.ServerError, "更新用户头像失败: "+err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
return &types.UploadAvatarResponse{
|
return &types.UploadAvatarResponse{
|
||||||
|
|||||||
@ -73,4 +73,4 @@ func (e UnauthorizedError) Error() string {
|
|||||||
|
|
||||||
func NewUnauthorizedError(message string) UnauthorizedError {
|
func NewUnauthorizedError(message string) UnauthorizedError {
|
||||||
return UnauthorizedError{Message: message}
|
return UnauthorizedError{Message: message}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -87,7 +87,7 @@ func NewCORSMiddleware(config CORSConfig) *CORSMiddleware {
|
|||||||
func (m *CORSMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
|
func (m *CORSMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
origin := r.Header.Get("Origin")
|
origin := r.Header.Get("Origin")
|
||||||
|
|
||||||
// 检查来源是否被允许
|
// 检查来源是否被允许
|
||||||
if origin != "" && m.isOriginAllowed(origin) {
|
if origin != "" && m.isOriginAllowed(origin) {
|
||||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||||
@ -160,16 +160,16 @@ func (m *CORSMiddleware) isOriginAllowed(origin string) bool {
|
|||||||
func (m *CORSMiddleware) setSecurityHeaders(w http.ResponseWriter) {
|
func (m *CORSMiddleware) setSecurityHeaders(w http.ResponseWriter) {
|
||||||
// 防止点击劫持
|
// 防止点击劫持
|
||||||
w.Header().Set("X-Frame-Options", "DENY")
|
w.Header().Set("X-Frame-Options", "DENY")
|
||||||
|
|
||||||
// 防止 MIME 类型嗅探
|
// 防止 MIME 类型嗅探
|
||||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||||
|
|
||||||
// XSS 保护
|
// XSS 保护
|
||||||
w.Header().Set("X-XSS-Protection", "1; mode=block")
|
w.Header().Set("X-XSS-Protection", "1; mode=block")
|
||||||
|
|
||||||
// 引用者策略
|
// 引用者策略
|
||||||
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||||
|
|
||||||
// 内容安全策略 (基础版)
|
// 内容安全策略 (基础版)
|
||||||
w.Header().Set("Content-Security-Policy", "default-src 'self'; img-src 'self' data: https:; style-src 'self' 'unsafe-inline'; script-src 'self'")
|
w.Header().Set("Content-Security-Policy", "default-src 'self'; img-src 'self' data: https:; style-src 'self' 'unsafe-inline'; script-src 'self'")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -19,8 +19,8 @@ type ErrorConfig struct {
|
|||||||
EnableDetailedErrors bool // 是否启用详细错误信息 (开发环境)
|
EnableDetailedErrors bool // 是否启用详细错误信息 (开发环境)
|
||||||
EnableStackTrace bool // 是否启用堆栈跟踪
|
EnableStackTrace bool // 是否启用堆栈跟踪
|
||||||
EnableErrorMonitor bool // 是否启用错误监控
|
EnableErrorMonitor bool // 是否启用错误监控
|
||||||
IgnoreHTTPCodes []int // 忽略的HTTP状态码 (不记录为错误)
|
IgnoreHTTPCodes []int // 忽略的HTTP状态码 (不记录为错误)
|
||||||
SensitiveFields []string // 敏感字段列表 (日志时隐藏)
|
SensitiveFields []string // 敏感字段列表 (日志时隐藏)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultErrorConfig 默认错误配置
|
// DefaultErrorConfig 默认错误配置
|
||||||
@ -29,8 +29,8 @@ func DefaultErrorConfig() ErrorConfig {
|
|||||||
EnableDetailedErrors: false, // 生产环境默认关闭
|
EnableDetailedErrors: false, // 生产环境默认关闭
|
||||||
EnableStackTrace: false, // 生产环境默认关闭
|
EnableStackTrace: false, // 生产环境默认关闭
|
||||||
EnableErrorMonitor: true,
|
EnableErrorMonitor: true,
|
||||||
IgnoreHTTPCodes: []int{http.StatusNotFound, http.StatusMethodNotAllowed},
|
IgnoreHTTPCodes: []int{http.StatusNotFound, http.StatusMethodNotAllowed},
|
||||||
SensitiveFields: []string{"password", "token", "secret", "key", "authorization"},
|
SensitiveFields: []string{"password", "token", "secret", "key", "authorization"},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -111,7 +111,7 @@ func (m *ErrorMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
|
|||||||
// handlePanic 处理panic
|
// handlePanic 处理panic
|
||||||
func (m *ErrorMiddleware) handlePanic(w *errorResponseWriter, r *http.Request, err interface{}) {
|
func (m *ErrorMiddleware) handlePanic(w *errorResponseWriter, r *http.Request, err interface{}) {
|
||||||
stack := string(debug.Stack())
|
stack := string(debug.Stack())
|
||||||
|
|
||||||
// 记录panic日志
|
// 记录panic日志
|
||||||
logFields := map[string]interface{}{
|
logFields := map[string]interface{}{
|
||||||
"error": err,
|
"error": err,
|
||||||
@ -206,7 +206,7 @@ func (m *ErrorMiddleware) respondWithError(w http.ResponseWriter, r *http.Reques
|
|||||||
|
|
||||||
// 设置HTTP状态码
|
// 设置HTTP状态码
|
||||||
httpStatus := errorx.GetHttpStatus(err.Code)
|
httpStatus := errorx.GetHttpStatus(err.Code)
|
||||||
|
|
||||||
// 设置响应头
|
// 设置响应头
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(httpStatus)
|
w.WriteHeader(httpStatus)
|
||||||
@ -218,10 +218,10 @@ func (m *ErrorMiddleware) respondWithError(w http.ResponseWriter, r *http.Reques
|
|||||||
// sanitizeFields 隐藏敏感字段
|
// sanitizeFields 隐藏敏感字段
|
||||||
func (m *ErrorMiddleware) sanitizeFields(data map[string]interface{}) map[string]interface{} {
|
func (m *ErrorMiddleware) sanitizeFields(data map[string]interface{}) map[string]interface{} {
|
||||||
sanitized := make(map[string]interface{})
|
sanitized := make(map[string]interface{})
|
||||||
|
|
||||||
for key, value := range data {
|
for key, value := range data {
|
||||||
lowerKey := strings.ToLower(key)
|
lowerKey := strings.ToLower(key)
|
||||||
|
|
||||||
// 检查是否为敏感字段
|
// 检查是否为敏感字段
|
||||||
sensitive := false
|
sensitive := false
|
||||||
for _, sensitiveField := range m.config.SensitiveFields {
|
for _, sensitiveField := range m.config.SensitiveFields {
|
||||||
@ -230,7 +230,7 @@ func (m *ErrorMiddleware) sanitizeFields(data map[string]interface{}) map[string
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if sensitive {
|
if sensitive {
|
||||||
sanitized[key] = "***REDACTED***"
|
sanitized[key] = "***REDACTED***"
|
||||||
} else {
|
} else {
|
||||||
@ -242,7 +242,7 @@ func (m *ErrorMiddleware) sanitizeFields(data map[string]interface{}) map[string
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return sanitized
|
return sanitized
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -322,4 +322,4 @@ var CommonErrors = struct {
|
|||||||
Code: 429,
|
Code: 429,
|
||||||
Msg: "Rate Limit Exceeded",
|
Msg: "Rate Limit Exceeded",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@ -15,8 +15,8 @@ import (
|
|||||||
|
|
||||||
// LoggerConfig 日志配置
|
// LoggerConfig 日志配置
|
||||||
type LoggerConfig struct {
|
type LoggerConfig struct {
|
||||||
EnableRequestBody bool // 是否记录请求体
|
EnableRequestBody bool // 是否记录请求体
|
||||||
EnableResponseBody bool // 是否记录响应体
|
EnableResponseBody bool // 是否记录响应体
|
||||||
MaxBodySize int64 // 最大记录的请求/响应体大小
|
MaxBodySize int64 // 最大记录的请求/响应体大小
|
||||||
SkipPaths []string // 跳过记录的路径
|
SkipPaths []string // 跳过记录的路径
|
||||||
SlowRequestDuration time.Duration // 慢请求阈值
|
SlowRequestDuration time.Duration // 慢请求阈值
|
||||||
@ -26,9 +26,9 @@ type LoggerConfig struct {
|
|||||||
// DefaultLoggerConfig 默认日志配置
|
// DefaultLoggerConfig 默认日志配置
|
||||||
func DefaultLoggerConfig() LoggerConfig {
|
func DefaultLoggerConfig() LoggerConfig {
|
||||||
return LoggerConfig{
|
return LoggerConfig{
|
||||||
EnableRequestBody: false, // 默认不记录请求体 (可能包含敏感信息)
|
EnableRequestBody: false, // 默认不记录请求体 (可能包含敏感信息)
|
||||||
EnableResponseBody: false, // 默认不记录响应体 (减少日志量)
|
EnableResponseBody: false, // 默认不记录响应体 (减少日志量)
|
||||||
MaxBodySize: 1024, // 最大记录1KB
|
MaxBodySize: 1024, // 最大记录1KB
|
||||||
SkipPaths: []string{"/health", "/metrics", "/favicon.ico"},
|
SkipPaths: []string{"/health", "/metrics", "/favicon.ico"},
|
||||||
SlowRequestDuration: 1 * time.Second,
|
SlowRequestDuration: 1 * time.Second,
|
||||||
EnablePanicRecover: true,
|
EnablePanicRecover: true,
|
||||||
@ -60,7 +60,7 @@ func newResponseWriter(w http.ResponseWriter) *responseWriter {
|
|||||||
return &responseWriter{
|
return &responseWriter{
|
||||||
ResponseWriter: w,
|
ResponseWriter: w,
|
||||||
status: http.StatusOK,
|
status: http.StatusOK,
|
||||||
body: &bytes.Buffer{},
|
body: &bytes.Buffer{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -68,12 +68,12 @@ func newResponseWriter(w http.ResponseWriter) *responseWriter {
|
|||||||
func (rw *responseWriter) Write(b []byte) (int, error) {
|
func (rw *responseWriter) Write(b []byte) (int, error) {
|
||||||
size, err := rw.ResponseWriter.Write(b)
|
size, err := rw.ResponseWriter.Write(b)
|
||||||
rw.size += int64(size)
|
rw.size += int64(size)
|
||||||
|
|
||||||
// 记录响应体 (如果启用)
|
// 记录响应体 (如果启用)
|
||||||
if rw.body.Len() < int(1024) { // 限制缓存大小
|
if rw.body.Len() < int(1024) { // 限制缓存大小
|
||||||
rw.body.Write(b)
|
rw.body.Write(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
return size, err
|
return size, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -163,7 +163,7 @@ func (m *LoggerMiddleware) generateRequestID(r *http.Request) string {
|
|||||||
if requestID := r.Header.Get("X-Request-ID"); requestID != "" {
|
if requestID := r.Header.Get("X-Request-ID"); requestID != "" {
|
||||||
return requestID
|
return requestID
|
||||||
}
|
}
|
||||||
|
|
||||||
// 生成新的请求ID
|
// 生成新的请求ID
|
||||||
return fmt.Sprintf("%d-%s", time.Now().UnixNano(), randomString(8))
|
return fmt.Sprintf("%d-%s", time.Now().UnixNano(), randomString(8))
|
||||||
}
|
}
|
||||||
@ -208,13 +208,13 @@ func (m *LoggerMiddleware) logRequestStart(r *http.Request, requestID, requestBo
|
|||||||
// logRequestComplete 记录请求完成
|
// logRequestComplete 记录请求完成
|
||||||
func (m *LoggerMiddleware) logRequestComplete(r *http.Request, requestID string, status int, size int64, duration time.Duration, responseBody string) {
|
func (m *LoggerMiddleware) logRequestComplete(r *http.Request, requestID string, status int, size int64, duration time.Duration, responseBody string) {
|
||||||
fields := map[string]interface{}{
|
fields := map[string]interface{}{
|
||||||
"request_id": requestID,
|
"request_id": requestID,
|
||||||
"method": r.Method,
|
"method": r.Method,
|
||||||
"path": r.URL.Path,
|
"path": r.URL.Path,
|
||||||
"status": status,
|
"status": status,
|
||||||
"response_size": size,
|
"response_size": size,
|
||||||
"duration_ms": duration.Milliseconds(),
|
"duration_ms": duration.Milliseconds(),
|
||||||
"duration": duration.String(),
|
"duration": duration.String(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if responseBody != "" {
|
if responseBody != "" {
|
||||||
@ -267,7 +267,7 @@ func getClientIP(r *http.Request) string {
|
|||||||
return ip
|
return ip
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用 RemoteAddr
|
// 使用 RemoteAddr
|
||||||
if ip := r.RemoteAddr; ip != "" {
|
if ip := r.RemoteAddr; ip != "" {
|
||||||
// 移除端口号
|
// 移除端口号
|
||||||
@ -276,7 +276,7 @@ func getClientIP(r *http.Request) string {
|
|||||||
}
|
}
|
||||||
return ip
|
return ip
|
||||||
}
|
}
|
||||||
|
|
||||||
return "unknown"
|
return "unknown"
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -296,4 +296,4 @@ func randomString(length int) string {
|
|||||||
result[i] = charset[time.Now().UnixNano()%int64(len(charset))]
|
result[i] = charset[time.Now().UnixNano()%int64(len(charset))]
|
||||||
}
|
}
|
||||||
return string(result)
|
return string(result)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -13,11 +13,11 @@ import (
|
|||||||
|
|
||||||
// MiddlewareManager 中间件管理器
|
// MiddlewareManager 中间件管理器
|
||||||
type MiddlewareManager struct {
|
type MiddlewareManager struct {
|
||||||
config config.Config
|
config config.Config
|
||||||
corsMiddleware *CORSMiddleware
|
corsMiddleware *CORSMiddleware
|
||||||
logMiddleware *LoggerMiddleware
|
logMiddleware *LoggerMiddleware
|
||||||
errorMiddleware *ErrorMiddleware
|
errorMiddleware *ErrorMiddleware
|
||||||
authMiddleware *AuthMiddleware
|
authMiddleware *AuthMiddleware
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMiddlewareManager 创建中间件管理器
|
// NewMiddlewareManager 创建中间件管理器
|
||||||
@ -34,12 +34,12 @@ func NewMiddlewareManager(c config.Config) *MiddlewareManager {
|
|||||||
// getCORSConfig 获取CORS配置
|
// getCORSConfig 获取CORS配置
|
||||||
func getCORSConfig(c config.Config) CORSConfig {
|
func getCORSConfig(c config.Config) CORSConfig {
|
||||||
env := getEnvironment()
|
env := getEnvironment()
|
||||||
|
|
||||||
if env == "production" {
|
if env == "production" {
|
||||||
// 生产环境使用严格的CORS配置
|
// 生产环境使用严格的CORS配置
|
||||||
return ProductionCORSConfig(getProductionOrigins())
|
return ProductionCORSConfig(getProductionOrigins())
|
||||||
}
|
}
|
||||||
|
|
||||||
// 开发环境使用宽松的CORS配置
|
// 开发环境使用宽松的CORS配置
|
||||||
return DefaultCORSConfig()
|
return DefaultCORSConfig()
|
||||||
}
|
}
|
||||||
@ -47,27 +47,27 @@ func getCORSConfig(c config.Config) CORSConfig {
|
|||||||
// getLoggerConfig 获取日志配置
|
// getLoggerConfig 获取日志配置
|
||||||
func getLoggerConfig(c config.Config) LoggerConfig {
|
func getLoggerConfig(c config.Config) LoggerConfig {
|
||||||
env := getEnvironment()
|
env := getEnvironment()
|
||||||
|
|
||||||
config := DefaultLoggerConfig()
|
config := DefaultLoggerConfig()
|
||||||
|
|
||||||
if env == "development" {
|
if env == "development" {
|
||||||
// 开发环境启用详细日志
|
// 开发环境启用详细日志
|
||||||
config.EnableRequestBody = true
|
config.EnableRequestBody = true
|
||||||
config.EnableResponseBody = true
|
config.EnableResponseBody = true
|
||||||
config.MaxBodySize = 4096
|
config.MaxBodySize = 4096
|
||||||
}
|
}
|
||||||
|
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
|
||||||
// getErrorConfig 获取错误配置
|
// getErrorConfig 获取错误配置
|
||||||
func getErrorConfig(c config.Config) ErrorConfig {
|
func getErrorConfig(c config.Config) ErrorConfig {
|
||||||
env := getEnvironment()
|
env := getEnvironment()
|
||||||
|
|
||||||
if env == "development" {
|
if env == "development" {
|
||||||
return DevelopmentErrorConfig()
|
return DevelopmentErrorConfig()
|
||||||
}
|
}
|
||||||
|
|
||||||
return DefaultErrorConfig()
|
return DefaultErrorConfig()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -108,9 +108,9 @@ func (m *MiddlewareManager) Chain(handler http.HandlerFunc, middlewares ...func(
|
|||||||
// GetGlobalMiddlewares 获取全局中间件
|
// GetGlobalMiddlewares 获取全局中间件
|
||||||
func (m *MiddlewareManager) GetGlobalMiddlewares() []func(http.HandlerFunc) http.HandlerFunc {
|
func (m *MiddlewareManager) GetGlobalMiddlewares() []func(http.HandlerFunc) http.HandlerFunc {
|
||||||
return []func(http.HandlerFunc) http.HandlerFunc{
|
return []func(http.HandlerFunc) http.HandlerFunc{
|
||||||
m.errorMiddleware.Handle, // 错误处理 (最外层)
|
m.errorMiddleware.Handle, // 错误处理 (最外层)
|
||||||
m.corsMiddleware.Handle, // CORS 处理
|
m.corsMiddleware.Handle, // CORS 处理
|
||||||
m.logMiddleware.Handle, // 日志记录
|
m.logMiddleware.Handle, // 日志记录
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -134,7 +134,7 @@ func (m *MiddlewareManager) ApplyAuthMiddlewares(handler http.HandlerFunc) http.
|
|||||||
func (m *MiddlewareManager) HealthCheck(w http.ResponseWriter, r *http.Request) {
|
func (m *MiddlewareManager) HealthCheck(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte(`{"status":"ok","timestamp":"` +
|
w.Write([]byte(`{"status":"ok","timestamp":"` +
|
||||||
time.Now().Format("2006-01-02T15:04:05Z07:00") + `"}`))
|
time.Now().Format("2006-01-02T15:04:05Z07:00") + `"}`))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -171,7 +171,7 @@ func Recovery() MiddlewareFunc {
|
|||||||
"path": r.URL.Path,
|
"path": r.URL.Path,
|
||||||
}
|
}
|
||||||
logx.WithContext(r.Context()).Errorf("Panic recovered in Recovery middleware: %+v", fields)
|
logx.WithContext(r.Context()).Errorf("Panic recovered in Recovery middleware: %+v", fields)
|
||||||
|
|
||||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -188,10 +188,10 @@ func RequestID() MiddlewareFunc {
|
|||||||
if requestID == "" {
|
if requestID == "" {
|
||||||
requestID = generateRequestID()
|
requestID = generateRequestID()
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("X-Request-ID", requestID)
|
w.Header().Set("X-Request-ID", requestID)
|
||||||
r.Header.Set("X-Request-ID", requestID)
|
r.Header.Set("X-Request-ID", requestID)
|
||||||
|
|
||||||
next(w, r)
|
next(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -200,4 +200,4 @@ func RequestID() MiddlewareFunc {
|
|||||||
// generateRequestID 生成请求ID
|
// generateRequestID 生成请求ID
|
||||||
func generateRequestID() string {
|
func generateRequestID() string {
|
||||||
return randomString(16)
|
return randomString(16)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,8 +3,8 @@ package model
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
|
||||||
"github.com/zeromicro/go-zero/core/stores/sqlx"
|
"github.com/zeromicro/go-zero/core/stores/sqlx"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ CategoryModel = (*customCategoryModel)(nil)
|
var _ CategoryModel = (*customCategoryModel)(nil)
|
||||||
@ -39,20 +39,20 @@ func (m *customCategoryModel) withSession(session sqlx.Session) CategoryModel {
|
|||||||
func (m *customCategoryModel) FindList(ctx context.Context, page, pageSize int, keyword string) ([]*Category, error) {
|
func (m *customCategoryModel) FindList(ctx context.Context, page, pageSize int, keyword string) ([]*Category, error) {
|
||||||
var conditions []string
|
var conditions []string
|
||||||
var args []interface{}
|
var args []interface{}
|
||||||
|
|
||||||
if keyword != "" {
|
if keyword != "" {
|
||||||
conditions = append(conditions, "(`name` LIKE ? OR `description` LIKE ?)")
|
conditions = append(conditions, "(`name` LIKE ? OR `description` LIKE ?)")
|
||||||
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
|
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
|
||||||
}
|
}
|
||||||
|
|
||||||
whereClause := ""
|
whereClause := ""
|
||||||
if len(conditions) > 0 {
|
if len(conditions) > 0 {
|
||||||
whereClause = " WHERE " + strings.Join(conditions, " AND ")
|
whereClause = " WHERE " + strings.Join(conditions, " AND ")
|
||||||
}
|
}
|
||||||
|
|
||||||
offset := (page - 1) * pageSize
|
offset := (page - 1) * pageSize
|
||||||
args = append(args, pageSize, offset)
|
args = append(args, pageSize, offset)
|
||||||
|
|
||||||
query := fmt.Sprintf("select %s from %s%s ORDER BY `created_at` DESC LIMIT ? OFFSET ?", categoryRows, m.table, whereClause)
|
query := fmt.Sprintf("select %s from %s%s ORDER BY `created_at` DESC LIMIT ? OFFSET ?", categoryRows, m.table, whereClause)
|
||||||
var resp []*Category
|
var resp []*Category
|
||||||
err := m.conn.QueryRowsCtx(ctx, &resp, query, args...)
|
err := m.conn.QueryRowsCtx(ctx, &resp, query, args...)
|
||||||
@ -63,17 +63,17 @@ func (m *customCategoryModel) FindList(ctx context.Context, page, pageSize int,
|
|||||||
func (m *customCategoryModel) Count(ctx context.Context, keyword string) (int64, error) {
|
func (m *customCategoryModel) Count(ctx context.Context, keyword string) (int64, error) {
|
||||||
var conditions []string
|
var conditions []string
|
||||||
var args []interface{}
|
var args []interface{}
|
||||||
|
|
||||||
if keyword != "" {
|
if keyword != "" {
|
||||||
conditions = append(conditions, "(`name` LIKE ? OR `description` LIKE ?)")
|
conditions = append(conditions, "(`name` LIKE ? OR `description` LIKE ?)")
|
||||||
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
|
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
|
||||||
}
|
}
|
||||||
|
|
||||||
whereClause := ""
|
whereClause := ""
|
||||||
if len(conditions) > 0 {
|
if len(conditions) > 0 {
|
||||||
whereClause = " WHERE " + strings.Join(conditions, " AND ")
|
whereClause = " WHERE " + strings.Join(conditions, " AND ")
|
||||||
}
|
}
|
||||||
|
|
||||||
query := fmt.Sprintf("select count(*) from %s%s", m.table, whereClause)
|
query := fmt.Sprintf("select count(*) from %s%s", m.table, whereClause)
|
||||||
var count int64
|
var count int64
|
||||||
err := m.conn.QueryRowCtx(ctx, &count, query, args...)
|
err := m.conn.QueryRowCtx(ctx, &count, query, args...)
|
||||||
|
|||||||
@ -3,8 +3,8 @@ package model
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
|
||||||
"github.com/zeromicro/go-zero/core/stores/sqlx"
|
"github.com/zeromicro/go-zero/core/stores/sqlx"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ PhotoModel = (*customPhotoModel)(nil)
|
var _ PhotoModel = (*customPhotoModel)(nil)
|
||||||
@ -39,30 +39,30 @@ func (m *customPhotoModel) withSession(session sqlx.Session) PhotoModel {
|
|||||||
func (m *customPhotoModel) FindList(ctx context.Context, page, pageSize int, categoryId, userId int64, keyword string) ([]*Photo, error) {
|
func (m *customPhotoModel) FindList(ctx context.Context, page, pageSize int, categoryId, userId int64, keyword string) ([]*Photo, error) {
|
||||||
var conditions []string
|
var conditions []string
|
||||||
var args []interface{}
|
var args []interface{}
|
||||||
|
|
||||||
if categoryId > 0 {
|
if categoryId > 0 {
|
||||||
conditions = append(conditions, "`category_id` = ?")
|
conditions = append(conditions, "`category_id` = ?")
|
||||||
args = append(args, categoryId)
|
args = append(args, categoryId)
|
||||||
}
|
}
|
||||||
|
|
||||||
if userId > 0 {
|
if userId > 0 {
|
||||||
conditions = append(conditions, "`user_id` = ?")
|
conditions = append(conditions, "`user_id` = ?")
|
||||||
args = append(args, userId)
|
args = append(args, userId)
|
||||||
}
|
}
|
||||||
|
|
||||||
if keyword != "" {
|
if keyword != "" {
|
||||||
conditions = append(conditions, "(`title` LIKE ? OR `description` LIKE ?)")
|
conditions = append(conditions, "(`title` LIKE ? OR `description` LIKE ?)")
|
||||||
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
|
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
|
||||||
}
|
}
|
||||||
|
|
||||||
whereClause := ""
|
whereClause := ""
|
||||||
if len(conditions) > 0 {
|
if len(conditions) > 0 {
|
||||||
whereClause = " WHERE " + strings.Join(conditions, " AND ")
|
whereClause = " WHERE " + strings.Join(conditions, " AND ")
|
||||||
}
|
}
|
||||||
|
|
||||||
offset := (page - 1) * pageSize
|
offset := (page - 1) * pageSize
|
||||||
args = append(args, pageSize, offset)
|
args = append(args, pageSize, offset)
|
||||||
|
|
||||||
query := fmt.Sprintf("select %s from %s%s ORDER BY `created_at` DESC LIMIT ? OFFSET ?", photoRows, m.table, whereClause)
|
query := fmt.Sprintf("select %s from %s%s ORDER BY `created_at` DESC LIMIT ? OFFSET ?", photoRows, m.table, whereClause)
|
||||||
var resp []*Photo
|
var resp []*Photo
|
||||||
err := m.conn.QueryRowsCtx(ctx, &resp, query, args...)
|
err := m.conn.QueryRowsCtx(ctx, &resp, query, args...)
|
||||||
@ -73,27 +73,27 @@ func (m *customPhotoModel) FindList(ctx context.Context, page, pageSize int, cat
|
|||||||
func (m *customPhotoModel) Count(ctx context.Context, categoryId, userId int64, keyword string) (int64, error) {
|
func (m *customPhotoModel) Count(ctx context.Context, categoryId, userId int64, keyword string) (int64, error) {
|
||||||
var conditions []string
|
var conditions []string
|
||||||
var args []interface{}
|
var args []interface{}
|
||||||
|
|
||||||
if categoryId > 0 {
|
if categoryId > 0 {
|
||||||
conditions = append(conditions, "`category_id` = ?")
|
conditions = append(conditions, "`category_id` = ?")
|
||||||
args = append(args, categoryId)
|
args = append(args, categoryId)
|
||||||
}
|
}
|
||||||
|
|
||||||
if userId > 0 {
|
if userId > 0 {
|
||||||
conditions = append(conditions, "`user_id` = ?")
|
conditions = append(conditions, "`user_id` = ?")
|
||||||
args = append(args, userId)
|
args = append(args, userId)
|
||||||
}
|
}
|
||||||
|
|
||||||
if keyword != "" {
|
if keyword != "" {
|
||||||
conditions = append(conditions, "(`title` LIKE ? OR `description` LIKE ?)")
|
conditions = append(conditions, "(`title` LIKE ? OR `description` LIKE ?)")
|
||||||
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
|
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
|
||||||
}
|
}
|
||||||
|
|
||||||
whereClause := ""
|
whereClause := ""
|
||||||
if len(conditions) > 0 {
|
if len(conditions) > 0 {
|
||||||
whereClause = " WHERE " + strings.Join(conditions, " AND ")
|
whereClause = " WHERE " + strings.Join(conditions, " AND ")
|
||||||
}
|
}
|
||||||
|
|
||||||
query := fmt.Sprintf("select count(*) from %s%s", m.table, whereClause)
|
query := fmt.Sprintf("select count(*) from %s%s", m.table, whereClause)
|
||||||
var count int64
|
var count int64
|
||||||
err := m.conn.QueryRowCtx(ctx, &count, query, args...)
|
err := m.conn.QueryRowCtx(ctx, &count, query, args...)
|
||||||
|
|||||||
@ -3,8 +3,8 @@ package model
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
|
||||||
"github.com/zeromicro/go-zero/core/stores/sqlx"
|
"github.com/zeromicro/go-zero/core/stores/sqlx"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ UserModel = (*customUserModel)(nil)
|
var _ UserModel = (*customUserModel)(nil)
|
||||||
@ -39,20 +39,20 @@ func (m *customUserModel) withSession(session sqlx.Session) UserModel {
|
|||||||
func (m *customUserModel) FindList(ctx context.Context, page, pageSize int, keyword string) ([]*User, error) {
|
func (m *customUserModel) FindList(ctx context.Context, page, pageSize int, keyword string) ([]*User, error) {
|
||||||
var conditions []string
|
var conditions []string
|
||||||
var args []interface{}
|
var args []interface{}
|
||||||
|
|
||||||
if keyword != "" {
|
if keyword != "" {
|
||||||
conditions = append(conditions, "(`username` LIKE ? OR `email` LIKE ?)")
|
conditions = append(conditions, "(`username` LIKE ? OR `email` LIKE ?)")
|
||||||
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
|
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
|
||||||
}
|
}
|
||||||
|
|
||||||
whereClause := ""
|
whereClause := ""
|
||||||
if len(conditions) > 0 {
|
if len(conditions) > 0 {
|
||||||
whereClause = " WHERE " + strings.Join(conditions, " AND ")
|
whereClause = " WHERE " + strings.Join(conditions, " AND ")
|
||||||
}
|
}
|
||||||
|
|
||||||
offset := (page - 1) * pageSize
|
offset := (page - 1) * pageSize
|
||||||
args = append(args, pageSize, offset)
|
args = append(args, pageSize, offset)
|
||||||
|
|
||||||
query := fmt.Sprintf("select %s from %s%s ORDER BY `created_at` DESC LIMIT ? OFFSET ?", userRows, m.table, whereClause)
|
query := fmt.Sprintf("select %s from %s%s ORDER BY `created_at` DESC LIMIT ? OFFSET ?", userRows, m.table, whereClause)
|
||||||
var resp []*User
|
var resp []*User
|
||||||
err := m.conn.QueryRowsCtx(ctx, &resp, query, args...)
|
err := m.conn.QueryRowsCtx(ctx, &resp, query, args...)
|
||||||
@ -63,17 +63,17 @@ func (m *customUserModel) FindList(ctx context.Context, page, pageSize int, keyw
|
|||||||
func (m *customUserModel) Count(ctx context.Context, keyword string) (int64, error) {
|
func (m *customUserModel) Count(ctx context.Context, keyword string) (int64, error) {
|
||||||
var conditions []string
|
var conditions []string
|
||||||
var args []interface{}
|
var args []interface{}
|
||||||
|
|
||||||
if keyword != "" {
|
if keyword != "" {
|
||||||
conditions = append(conditions, "(`username` LIKE ? OR `email` LIKE ?)")
|
conditions = append(conditions, "(`username` LIKE ? OR `email` LIKE ?)")
|
||||||
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
|
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
|
||||||
}
|
}
|
||||||
|
|
||||||
whereClause := ""
|
whereClause := ""
|
||||||
if len(conditions) > 0 {
|
if len(conditions) > 0 {
|
||||||
whereClause = " WHERE " + strings.Join(conditions, " AND ")
|
whereClause = " WHERE " + strings.Join(conditions, " AND ")
|
||||||
}
|
}
|
||||||
|
|
||||||
query := fmt.Sprintf("select count(*) from %s%s", m.table, whereClause)
|
query := fmt.Sprintf("select count(*) from %s%s", m.table, whereClause)
|
||||||
var count int64
|
var count int64
|
||||||
err := m.conn.QueryRowCtx(ctx, &count, query, args...)
|
err := m.conn.QueryRowCtx(ctx, &count, query, args...)
|
||||||
|
|||||||
@ -2,17 +2,17 @@ package svc
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/zeromicro/go-zero/core/stores/sqlx"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"photography-backend/internal/config"
|
"photography-backend/internal/config"
|
||||||
"photography-backend/internal/middleware"
|
"photography-backend/internal/middleware"
|
||||||
"photography-backend/internal/model"
|
"photography-backend/internal/model"
|
||||||
"photography-backend/pkg/utils/database"
|
"photography-backend/pkg/utils/database"
|
||||||
"github.com/zeromicro/go-zero/core/stores/sqlx"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type ServiceContext struct {
|
type ServiceContext struct {
|
||||||
Config config.Config
|
Config config.Config
|
||||||
DB *gorm.DB
|
DB *gorm.DB
|
||||||
UserModel model.UserModel
|
UserModel model.UserModel
|
||||||
PhotoModel model.PhotoModel
|
PhotoModel model.PhotoModel
|
||||||
CategoryModel model.CategoryModel
|
CategoryModel model.CategoryModel
|
||||||
@ -24,13 +24,13 @@ func NewServiceContext(c config.Config) *ServiceContext {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create sqlx connection for go-zero models
|
// Create sqlx connection for go-zero models
|
||||||
sqlxConn := sqlx.NewSqlConn(getSQLDriverName(c.Database.Driver), getSQLDataSource(c.Database))
|
sqlxConn := sqlx.NewSqlConn(getSQLDriverName(c.Database.Driver), getSQLDataSource(c.Database))
|
||||||
|
|
||||||
return &ServiceContext{
|
return &ServiceContext{
|
||||||
Config: c,
|
Config: c,
|
||||||
DB: db,
|
DB: db,
|
||||||
UserModel: model.NewUserModel(sqlxConn),
|
UserModel: model.NewUserModel(sqlxConn),
|
||||||
PhotoModel: model.NewPhotoModel(sqlxConn),
|
PhotoModel: model.NewPhotoModel(sqlxConn),
|
||||||
CategoryModel: model.NewCategoryModel(sqlxConn),
|
CategoryModel: model.NewCategoryModel(sqlxConn),
|
||||||
|
|||||||
@ -4,25 +4,25 @@ const (
|
|||||||
// 用户状态
|
// 用户状态
|
||||||
UserStatusActive = 1
|
UserStatusActive = 1
|
||||||
UserStatusInactive = 0
|
UserStatusInactive = 0
|
||||||
|
|
||||||
// 文件上传
|
// 文件上传
|
||||||
MaxFileSize = 10 << 20 // 10MB
|
MaxFileSize = 10 << 20 // 10MB
|
||||||
|
|
||||||
// 图片类型
|
// 图片类型
|
||||||
ImageTypeJPEG = "image/jpeg"
|
ImageTypeJPEG = "image/jpeg"
|
||||||
ImageTypePNG = "image/png"
|
ImageTypePNG = "image/png"
|
||||||
ImageTypeGIF = "image/gif"
|
ImageTypeGIF = "image/gif"
|
||||||
ImageTypeWEBP = "image/webp"
|
ImageTypeWEBP = "image/webp"
|
||||||
|
|
||||||
// 缩略图尺寸
|
// 缩略图尺寸
|
||||||
ThumbnailWidth = 300
|
ThumbnailWidth = 300
|
||||||
ThumbnailHeight = 300
|
ThumbnailHeight = 300
|
||||||
|
|
||||||
// JWT 过期时间
|
// JWT 过期时间
|
||||||
TokenExpireDuration = 24 * 60 * 60 // 24小时
|
TokenExpireDuration = 24 * 60 * 60 // 24小时
|
||||||
|
|
||||||
// 分页默认值
|
// 分页默认值
|
||||||
DefaultPage = 1
|
DefaultPage = 1
|
||||||
DefaultPageSize = 10
|
DefaultPageSize = 10
|
||||||
MaxPageSize = 100
|
MaxPageSize = 100
|
||||||
)
|
)
|
||||||
|
|||||||
@ -7,14 +7,14 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
// 通用错误代码
|
// 通用错误代码
|
||||||
Success = 0
|
Success = 0
|
||||||
ServerError = 500
|
ServerError = 500
|
||||||
ParamError = 400
|
ParamError = 400
|
||||||
AuthError = 401
|
AuthError = 401
|
||||||
NotFound = 404
|
NotFound = 404
|
||||||
Forbidden = 403
|
Forbidden = 403
|
||||||
InvalidParameter = 400 // 与 ParamError 统一
|
InvalidParameter = 400 // 与 ParamError 统一
|
||||||
|
|
||||||
// 业务错误代码
|
// 业务错误代码
|
||||||
UserNotFound = 1001
|
UserNotFound = 1001
|
||||||
UserExists = 1002
|
UserExists = 1002
|
||||||
@ -22,32 +22,32 @@ const (
|
|||||||
InvalidPassword = 1004
|
InvalidPassword = 1004
|
||||||
TokenExpired = 1005
|
TokenExpired = 1005
|
||||||
TokenInvalid = 1006
|
TokenInvalid = 1006
|
||||||
|
|
||||||
PhotoNotFound = 2001
|
PhotoNotFound = 2001
|
||||||
PhotoUploadFail = 2002
|
PhotoUploadFail = 2002
|
||||||
|
|
||||||
CategoryNotFound = 3001
|
CategoryNotFound = 3001
|
||||||
CategoryExists = 3002
|
CategoryExists = 3002
|
||||||
)
|
)
|
||||||
|
|
||||||
var codeText = map[int]string{
|
var codeText = map[int]string{
|
||||||
Success: "Success",
|
Success: "Success",
|
||||||
ServerError: "Server Error",
|
ServerError: "Server Error",
|
||||||
ParamError: "Parameter Error", // ParamError 和 InvalidParameter 都映射到这里
|
ParamError: "Parameter Error", // ParamError 和 InvalidParameter 都映射到这里
|
||||||
AuthError: "Authentication Error",
|
AuthError: "Authentication Error",
|
||||||
NotFound: "Not Found",
|
NotFound: "Not Found",
|
||||||
Forbidden: "Forbidden",
|
Forbidden: "Forbidden",
|
||||||
|
|
||||||
UserNotFound: "User Not Found",
|
UserNotFound: "User Not Found",
|
||||||
UserExists: "User Already Exists",
|
UserExists: "User Already Exists",
|
||||||
UserDisabled: "User Disabled",
|
UserDisabled: "User Disabled",
|
||||||
InvalidPassword: "Invalid Password",
|
InvalidPassword: "Invalid Password",
|
||||||
TokenExpired: "Token Expired",
|
TokenExpired: "Token Expired",
|
||||||
TokenInvalid: "Token Invalid",
|
TokenInvalid: "Token Invalid",
|
||||||
|
|
||||||
PhotoNotFound: "Photo Not Found",
|
PhotoNotFound: "Photo Not Found",
|
||||||
PhotoUploadFail: "Photo Upload Failed",
|
PhotoUploadFail: "Photo Upload Failed",
|
||||||
|
|
||||||
CategoryNotFound: "Category Not Found",
|
CategoryNotFound: "Category Not Found",
|
||||||
CategoryExists: "Category Already Exists",
|
CategoryExists: "Category Already Exists",
|
||||||
}
|
}
|
||||||
@ -83,7 +83,7 @@ func GetHttpStatus(code int) int {
|
|||||||
switch code {
|
switch code {
|
||||||
case Success:
|
case Success:
|
||||||
return http.StatusOK
|
return http.StatusOK
|
||||||
case ParamError: // ParamError 和 InvalidParameter 都是 400,所以只需要一个 case
|
case ParamError: // ParamError 和 InvalidParameter 都是 400,所以只需要一个 case
|
||||||
return http.StatusBadRequest
|
return http.StatusBadRequest
|
||||||
case AuthError, TokenExpired, TokenInvalid:
|
case AuthError, TokenExpired, TokenInvalid:
|
||||||
return http.StatusUnauthorized
|
return http.StatusUnauthorized
|
||||||
@ -96,4 +96,4 @@ func GetHttpStatus(code int) int {
|
|||||||
default:
|
default:
|
||||||
return http.StatusInternalServerError
|
return http.StatusInternalServerError
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -21,10 +21,10 @@ type Migration struct {
|
|||||||
|
|
||||||
// MigrationRecord 数据库中的迁移记录
|
// MigrationRecord 数据库中的迁移记录
|
||||||
type MigrationRecord struct {
|
type MigrationRecord struct {
|
||||||
ID uint `gorm:"primaryKey"`
|
ID uint `gorm:"primaryKey"`
|
||||||
Version string `gorm:"uniqueIndex;size:255;not null"`
|
Version string `gorm:"uniqueIndex;size:255;not null"`
|
||||||
Description string `gorm:"size:500"`
|
Description string `gorm:"size:500"`
|
||||||
Applied bool `gorm:"default:false"`
|
Applied bool `gorm:"default:false"`
|
||||||
AppliedAt time.Time
|
AppliedAt time.Time
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
@ -66,7 +66,7 @@ func (m *Migrator) GetAppliedMigrations() ([]string, error) {
|
|||||||
if err := m.db.Where("applied = ?", true).Order("version ASC").Find(&records).Error; err != nil {
|
if err := m.db.Where("applied = ?", true).Order("version ASC").Find(&records).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
versions := make([]string, len(records))
|
versions := make([]string, len(records))
|
||||||
for i, record := range records {
|
for i, record := range records {
|
||||||
versions[i] = record.Version
|
versions[i] = record.Version
|
||||||
@ -80,24 +80,24 @@ func (m *Migrator) GetPendingMigrations() ([]Migration, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
appliedMap := make(map[string]bool)
|
appliedMap := make(map[string]bool)
|
||||||
for _, version := range appliedVersions {
|
for _, version := range appliedVersions {
|
||||||
appliedMap[version] = true
|
appliedMap[version] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
var pendingMigrations []Migration
|
var pendingMigrations []Migration
|
||||||
for _, migration := range m.migrations {
|
for _, migration := range m.migrations {
|
||||||
if !appliedMap[migration.Version] {
|
if !appliedMap[migration.Version] {
|
||||||
pendingMigrations = append(pendingMigrations, migration)
|
pendingMigrations = append(pendingMigrations, migration)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 按版本号排序
|
// 按版本号排序
|
||||||
sort.Slice(pendingMigrations, func(i, j int) bool {
|
sort.Slice(pendingMigrations, func(i, j int) bool {
|
||||||
return pendingMigrations[i].Version < pendingMigrations[j].Version
|
return pendingMigrations[i].Version < pendingMigrations[j].Version
|
||||||
})
|
})
|
||||||
|
|
||||||
return pendingMigrations, nil
|
return pendingMigrations, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -106,39 +106,39 @@ func (m *Migrator) Up() error {
|
|||||||
if err := m.initMigrationTable(); err != nil {
|
if err := m.initMigrationTable(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
pendingMigrations, err := m.GetPendingMigrations()
|
pendingMigrations, err := m.GetPendingMigrations()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(pendingMigrations) == 0 {
|
if len(pendingMigrations) == 0 {
|
||||||
log.Println("No pending migrations")
|
log.Println("No pending migrations")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, migration := range pendingMigrations {
|
for _, migration := range pendingMigrations {
|
||||||
log.Printf("Applying migration %s: %s", migration.Version, migration.Description)
|
log.Printf("Applying migration %s: %s", migration.Version, migration.Description)
|
||||||
|
|
||||||
// 开始事务
|
// 开始事务
|
||||||
tx := m.db.Begin()
|
tx := m.db.Begin()
|
||||||
if tx.Error != nil {
|
if tx.Error != nil {
|
||||||
return tx.Error
|
return tx.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// 执行迁移SQL
|
// 执行迁移SQL
|
||||||
if err := m.executeSQL(tx, migration.UpSQL); err != nil {
|
if err := m.executeSQL(tx, migration.UpSQL); err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return fmt.Errorf("failed to apply migration %s: %v", migration.Version, err)
|
return fmt.Errorf("failed to apply migration %s: %v", migration.Version, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 记录迁移状态 (使用UPSERT)
|
// 记录迁移状态 (使用UPSERT)
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
// 检查记录是否已存在
|
// 检查记录是否已存在
|
||||||
var existingRecord MigrationRecord
|
var existingRecord MigrationRecord
|
||||||
err := tx.Where("version = ?", migration.Version).First(&existingRecord).Error
|
err := tx.Where("version = ?", migration.Version).First(&existingRecord).Error
|
||||||
|
|
||||||
if err == gorm.ErrRecordNotFound {
|
if err == gorm.ErrRecordNotFound {
|
||||||
// 创建新记录
|
// 创建新记录
|
||||||
record := MigrationRecord{
|
record := MigrationRecord{
|
||||||
@ -167,15 +167,15 @@ func (m *Migrator) Up() error {
|
|||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return fmt.Errorf("failed to check migration record %s: %v", migration.Version, err)
|
return fmt.Errorf("failed to check migration record %s: %v", migration.Version, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 提交事务
|
// 提交事务
|
||||||
if err := tx.Commit().Error; err != nil {
|
if err := tx.Commit().Error; err != nil {
|
||||||
return fmt.Errorf("failed to commit migration %s: %v", migration.Version, err)
|
return fmt.Errorf("failed to commit migration %s: %v", migration.Version, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Successfully applied migration %s", migration.Version)
|
log.Printf("Successfully applied migration %s", migration.Version)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -184,58 +184,58 @@ func (m *Migrator) Down(steps int) error {
|
|||||||
if err := m.initMigrationTable(); err != nil {
|
if err := m.initMigrationTable(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
appliedVersions, err := m.GetAppliedMigrations()
|
appliedVersions, err := m.GetAppliedMigrations()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(appliedVersions) == 0 {
|
if len(appliedVersions) == 0 {
|
||||||
log.Println("No applied migrations to rollback")
|
log.Println("No applied migrations to rollback")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取要回滚的迁移(从最新开始)
|
// 获取要回滚的迁移(从最新开始)
|
||||||
rollbackCount := steps
|
rollbackCount := steps
|
||||||
if rollbackCount > len(appliedVersions) {
|
if rollbackCount > len(appliedVersions) {
|
||||||
rollbackCount = len(appliedVersions)
|
rollbackCount = len(appliedVersions)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := len(appliedVersions) - 1; i >= len(appliedVersions)-rollbackCount; i-- {
|
for i := len(appliedVersions) - 1; i >= len(appliedVersions)-rollbackCount; i-- {
|
||||||
version := appliedVersions[i]
|
version := appliedVersions[i]
|
||||||
migration := m.findMigrationByVersion(version)
|
migration := m.findMigrationByVersion(version)
|
||||||
if migration == nil {
|
if migration == nil {
|
||||||
return fmt.Errorf("migration %s not found in migration definitions", version)
|
return fmt.Errorf("migration %s not found in migration definitions", version)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Rolling back migration %s: %s", migration.Version, migration.Description)
|
log.Printf("Rolling back migration %s: %s", migration.Version, migration.Description)
|
||||||
|
|
||||||
// 开始事务
|
// 开始事务
|
||||||
tx := m.db.Begin()
|
tx := m.db.Begin()
|
||||||
if tx.Error != nil {
|
if tx.Error != nil {
|
||||||
return tx.Error
|
return tx.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// 执行回滚SQL
|
// 执行回滚SQL
|
||||||
if err := m.executeSQL(tx, migration.DownSQL); err != nil {
|
if err := m.executeSQL(tx, migration.DownSQL); err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return fmt.Errorf("failed to rollback migration %s: %v", migration.Version, err)
|
return fmt.Errorf("failed to rollback migration %s: %v", migration.Version, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新迁移状态
|
// 更新迁移状态
|
||||||
if err := tx.Model(&MigrationRecord{}).Where("version = ?", version).Update("applied", false).Error; err != nil {
|
if err := tx.Model(&MigrationRecord{}).Where("version = ?", version).Update("applied", false).Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return fmt.Errorf("failed to update migration record %s: %v", migration.Version, err)
|
return fmt.Errorf("failed to update migration record %s: %v", migration.Version, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 提交事务
|
// 提交事务
|
||||||
if err := tx.Commit().Error; err != nil {
|
if err := tx.Commit().Error; err != nil {
|
||||||
return fmt.Errorf("failed to commit rollback %s: %v", migration.Version, err)
|
return fmt.Errorf("failed to commit rollback %s: %v", migration.Version, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Successfully rolled back migration %s", migration.Version)
|
log.Printf("Successfully rolled back migration %s", migration.Version)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -244,27 +244,27 @@ func (m *Migrator) Status() error {
|
|||||||
if err := m.initMigrationTable(); err != nil {
|
if err := m.initMigrationTable(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
appliedVersions, err := m.GetAppliedMigrations()
|
appliedVersions, err := m.GetAppliedMigrations()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
appliedMap := make(map[string]bool)
|
appliedMap := make(map[string]bool)
|
||||||
for _, version := range appliedVersions {
|
for _, version := range appliedVersions {
|
||||||
appliedMap[version] = true
|
appliedMap[version] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// 排序所有迁移
|
// 排序所有迁移
|
||||||
allMigrations := m.migrations
|
allMigrations := m.migrations
|
||||||
sort.Slice(allMigrations, func(i, j int) bool {
|
sort.Slice(allMigrations, func(i, j int) bool {
|
||||||
return allMigrations[i].Version < allMigrations[j].Version
|
return allMigrations[i].Version < allMigrations[j].Version
|
||||||
})
|
})
|
||||||
|
|
||||||
fmt.Println("Migration Status:")
|
fmt.Println("Migration Status:")
|
||||||
fmt.Println("Version | Status | Description")
|
fmt.Println("Version | Status | Description")
|
||||||
fmt.Println("---------------|---------|----------------------------------")
|
fmt.Println("---------------|---------|----------------------------------")
|
||||||
|
|
||||||
for _, migration := range allMigrations {
|
for _, migration := range allMigrations {
|
||||||
status := "Pending"
|
status := "Pending"
|
||||||
if appliedMap[migration.Version] {
|
if appliedMap[migration.Version] {
|
||||||
@ -272,7 +272,7 @@ func (m *Migrator) Status() error {
|
|||||||
}
|
}
|
||||||
fmt.Printf("%-14s | %-7s | %s\n", migration.Version, status, migration.Description)
|
fmt.Printf("%-14s | %-7s | %s\n", migration.Version, status, migration.Description)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -280,18 +280,18 @@ func (m *Migrator) Status() error {
|
|||||||
func (m *Migrator) executeSQL(tx *gorm.DB, sqlStr string) error {
|
func (m *Migrator) executeSQL(tx *gorm.DB, sqlStr string) error {
|
||||||
// 分割SQL语句(按分号分割)
|
// 分割SQL语句(按分号分割)
|
||||||
statements := strings.Split(sqlStr, ";")
|
statements := strings.Split(sqlStr, ";")
|
||||||
|
|
||||||
for _, statement := range statements {
|
for _, statement := range statements {
|
||||||
statement = strings.TrimSpace(statement)
|
statement = strings.TrimSpace(statement)
|
||||||
if statement == "" {
|
if statement == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Exec(statement).Error; err != nil {
|
if err := tx.Exec(statement).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -308,25 +308,25 @@ func (m *Migrator) findMigrationByVersion(version string) *Migration {
|
|||||||
// Reset 重置数据库(谨慎使用)
|
// Reset 重置数据库(谨慎使用)
|
||||||
func (m *Migrator) Reset() error {
|
func (m *Migrator) Reset() error {
|
||||||
log.Println("WARNING: This will drop all tables and reset the database!")
|
log.Println("WARNING: This will drop all tables and reset the database!")
|
||||||
|
|
||||||
// 获取所有应用的迁移
|
// 获取所有应用的迁移
|
||||||
appliedVersions, err := m.GetAppliedMigrations()
|
appliedVersions, err := m.GetAppliedMigrations()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 回滚所有迁移
|
// 回滚所有迁移
|
||||||
if len(appliedVersions) > 0 {
|
if len(appliedVersions) > 0 {
|
||||||
if err := m.Down(len(appliedVersions)); err != nil {
|
if err := m.Down(len(appliedVersions)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 删除迁移表
|
// 删除迁移表
|
||||||
if err := m.db.Migrator().DropTable(&MigrationRecord{}); err != nil {
|
if err := m.db.Migrator().DropTable(&MigrationRecord{}); err != nil {
|
||||||
return fmt.Errorf("failed to drop migration table: %v", err)
|
return fmt.Errorf("failed to drop migration table: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Println("Database reset completed")
|
log.Println("Database reset completed")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -337,25 +337,25 @@ func (m *Migrator) Migrate(steps int) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(pendingMigrations) == 0 {
|
if len(pendingMigrations) == 0 {
|
||||||
log.Println("No pending migrations")
|
log.Println("No pending migrations")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
migrateCount := steps
|
migrateCount := steps
|
||||||
if steps <= 0 || steps > len(pendingMigrations) {
|
if steps <= 0 || steps > len(pendingMigrations) {
|
||||||
migrateCount = len(pendingMigrations)
|
migrateCount = len(pendingMigrations)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 临时修改migrations列表,只包含要执行的迁移
|
// 临时修改migrations列表,只包含要执行的迁移
|
||||||
originalMigrations := m.migrations
|
originalMigrations := m.migrations
|
||||||
m.migrations = pendingMigrations[:migrateCount]
|
m.migrations = pendingMigrations[:migrateCount]
|
||||||
|
|
||||||
err = m.Up()
|
err = m.Up()
|
||||||
|
|
||||||
// 恢复原始migrations列表
|
// 恢复原始migrations列表
|
||||||
m.migrations = originalMigrations
|
m.migrations = originalMigrations
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@ -309,13 +309,13 @@ func GetLatestMigrationVersion() string {
|
|||||||
if len(migrations) == 0 {
|
if len(migrations) == 0 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
latest := migrations[0]
|
latest := migrations[0]
|
||||||
for _, migration := range migrations {
|
for _, migration := range migrations {
|
||||||
if migration.Version > latest.Version {
|
if migration.Version > latest.Version {
|
||||||
latest = migration
|
latest = migration
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return latest.Version
|
return latest.Version
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,7 +2,7 @@ package response
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/rest/httpx"
|
"github.com/zeromicro/go-zero/rest/httpx"
|
||||||
"photography-backend/pkg/errorx"
|
"photography-backend/pkg/errorx"
|
||||||
)
|
)
|
||||||
@ -39,4 +39,4 @@ func Success(w http.ResponseWriter, data interface{}) {
|
|||||||
|
|
||||||
func Error(w http.ResponseWriter, err error) {
|
func Error(w http.ResponseWriter, err error) {
|
||||||
Response(w, nil, err)
|
Response(w, nil, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -13,7 +13,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Driver string `json:"driver"` // mysql, postgres, sqlite
|
Driver string `json:"driver"` // mysql, postgres, sqlite
|
||||||
Host string `json:"host,optional"`
|
Host string `json:"host,optional"`
|
||||||
Port int `json:"port,optional"`
|
Port int `json:"port,optional"`
|
||||||
Username string `json:"username,optional"`
|
Username string `json:"username,optional"`
|
||||||
@ -27,7 +27,7 @@ type Config struct {
|
|||||||
func NewDB(config Config) (*gorm.DB, error) {
|
func NewDB(config Config) (*gorm.DB, error) {
|
||||||
var db *gorm.DB
|
var db *gorm.DB
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// 配置日志
|
// 配置日志
|
||||||
newLogger := logger.New(
|
newLogger := logger.New(
|
||||||
log.New(os.Stdout, "\r\n", log.LstdFlags), // io writer
|
log.New(os.Stdout, "\r\n", log.LstdFlags), // io writer
|
||||||
@ -37,7 +37,7 @@ func NewDB(config Config) (*gorm.DB, error) {
|
|||||||
Colorful: false, // 禁用彩色打印
|
Colorful: false, // 禁用彩色打印
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
switch config.Driver {
|
switch config.Driver {
|
||||||
case "mysql":
|
case "mysql":
|
||||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=True&loc=Local",
|
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=True&loc=Local",
|
||||||
@ -58,20 +58,20 @@ func NewDB(config Config) (*gorm.DB, error) {
|
|||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unsupported database driver: %s", config.Driver)
|
return nil, fmt.Errorf("unsupported database driver: %s", config.Driver)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置连接池
|
// 设置连接池
|
||||||
sqlDB, err := db.DB()
|
sqlDB, err := db.DB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlDB.SetMaxIdleConns(10)
|
sqlDB.SetMaxIdleConns(10)
|
||||||
sqlDB.SetMaxOpenConns(100)
|
sqlDB.SetMaxOpenConns(100)
|
||||||
sqlDB.SetConnMaxLifetime(time.Hour)
|
sqlDB.SetConnMaxLifetime(time.Hour)
|
||||||
|
|
||||||
return db, nil
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -11,7 +11,7 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/disintegration/imaging"
|
"github.com/disintegration/imaging"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
@ -64,39 +64,39 @@ func SaveFile(file multipart.File, header *multipart.FileHeader, config Config)
|
|||||||
if header.Size > config.MaxSize {
|
if header.Size > config.MaxSize {
|
||||||
return nil, fmt.Errorf("文件大小超过限制: %d bytes", config.MaxSize)
|
return nil, fmt.Errorf("文件大小超过限制: %d bytes", config.MaxSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查文件类型
|
// 检查文件类型
|
||||||
contentType := header.Header.Get("Content-Type")
|
contentType := header.Header.Get("Content-Type")
|
||||||
if !IsAllowedType(contentType, config.AllowedTypes) {
|
if !IsAllowedType(contentType, config.AllowedTypes) {
|
||||||
return nil, fmt.Errorf("不支持的文件类型: %s", contentType)
|
return nil, fmt.Errorf("不支持的文件类型: %s", contentType)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 生成文件名
|
// 生成文件名
|
||||||
fileName := GenerateFileName(header.Filename)
|
fileName := GenerateFileName(header.Filename)
|
||||||
|
|
||||||
// 创建上传目录
|
// 创建上传目录
|
||||||
uploadPath := filepath.Join(config.UploadDir, "photos")
|
uploadPath := filepath.Join(config.UploadDir, "photos")
|
||||||
if err := os.MkdirAll(uploadPath, 0755); err != nil {
|
if err := os.MkdirAll(uploadPath, 0755); err != nil {
|
||||||
return nil, fmt.Errorf("创建上传目录失败: %v", err)
|
return nil, fmt.Errorf("创建上传目录失败: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 完整文件路径
|
// 完整文件路径
|
||||||
filePath := filepath.Join(uploadPath, fileName)
|
filePath := filepath.Join(uploadPath, fileName)
|
||||||
|
|
||||||
// 创建目标文件
|
// 创建目标文件
|
||||||
dst, err := os.Create(filePath)
|
dst, err := os.Create(filePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("创建文件失败: %v", err)
|
return nil, fmt.Errorf("创建文件失败: %v", err)
|
||||||
}
|
}
|
||||||
defer dst.Close()
|
defer dst.Close()
|
||||||
|
|
||||||
// 复制文件内容
|
// 复制文件内容
|
||||||
file.Seek(0, 0) // 重置文件指针
|
file.Seek(0, 0) // 重置文件指针
|
||||||
_, err = io.Copy(dst, file)
|
_, err = io.Copy(dst, file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("保存文件失败: %v", err)
|
return nil, fmt.Errorf("保存文件失败: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 返回文件信息
|
// 返回文件信息
|
||||||
return &FileInfo{
|
return &FileInfo{
|
||||||
OriginalName: header.Filename,
|
OriginalName: header.Filename,
|
||||||
@ -115,35 +115,35 @@ func CreateThumbnail(originalPath string, config Config) (*FileInfo, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("打开图片失败: %v", err)
|
return nil, fmt.Errorf("打开图片失败: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 生成缩略图文件名
|
// 生成缩略图文件名
|
||||||
ext := filepath.Ext(originalPath)
|
ext := filepath.Ext(originalPath)
|
||||||
baseName := strings.TrimSuffix(filepath.Base(originalPath), ext)
|
baseName := strings.TrimSuffix(filepath.Base(originalPath), ext)
|
||||||
thumbnailName := baseName + "_thumb" + ext
|
thumbnailName := baseName + "_thumb" + ext
|
||||||
|
|
||||||
// 创建缩略图目录
|
// 创建缩略图目录
|
||||||
thumbnailDir := filepath.Join(config.UploadDir, "thumbnails")
|
thumbnailDir := filepath.Join(config.UploadDir, "thumbnails")
|
||||||
if err := os.MkdirAll(thumbnailDir, 0755); err != nil {
|
if err := os.MkdirAll(thumbnailDir, 0755); err != nil {
|
||||||
return nil, fmt.Errorf("创建缩略图目录失败: %v", err)
|
return nil, fmt.Errorf("创建缩略图目录失败: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
thumbnailPath := filepath.Join(thumbnailDir, thumbnailName)
|
thumbnailPath := filepath.Join(thumbnailDir, thumbnailName)
|
||||||
|
|
||||||
// 调整图片大小 (最大宽度 300px,保持比例)
|
// 调整图片大小 (最大宽度 300px,保持比例)
|
||||||
thumbnail := imaging.Resize(src, 300, 0, imaging.Lanczos)
|
thumbnail := imaging.Resize(src, 300, 0, imaging.Lanczos)
|
||||||
|
|
||||||
// 保存缩略图
|
// 保存缩略图
|
||||||
err = imaging.Save(thumbnail, thumbnailPath)
|
err = imaging.Save(thumbnail, thumbnailPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("保存缩略图失败: %v", err)
|
return nil, fmt.Errorf("保存缩略图失败: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取文件大小
|
// 获取文件大小
|
||||||
stat, err := os.Stat(thumbnailPath)
|
stat, err := os.Stat(thumbnailPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("获取缩略图信息失败: %v", err)
|
return nil, fmt.Errorf("获取缩略图信息失败: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &FileInfo{
|
return &FileInfo{
|
||||||
OriginalName: thumbnailName,
|
OriginalName: thumbnailName,
|
||||||
FileName: thumbnailName,
|
FileName: thumbnailName,
|
||||||
@ -161,13 +161,13 @@ func UploadPhoto(file multipart.File, header *multipart.FileHeader, config Confi
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 创建缩略图
|
// 创建缩略图
|
||||||
thumbnail, err := CreateThumbnail(original.FilePath, config)
|
thumbnail, err := CreateThumbnail(original.FilePath, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &UploadResult{
|
return &UploadResult{
|
||||||
Original: *original,
|
Original: *original,
|
||||||
Thumbnail: *thumbnail,
|
Thumbnail: *thumbnail,
|
||||||
@ -179,11 +179,11 @@ func DeleteFile(filePath string) error {
|
|||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||||
return nil // 文件不存在,认为删除成功
|
return nil // 文件不存在,认为删除成功
|
||||||
}
|
}
|
||||||
|
|
||||||
return os.Remove(filePath)
|
return os.Remove(filePath)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -194,12 +194,12 @@ func GetImageDimensions(filePath string) (width, height int, err error) {
|
|||||||
return 0, 0, err
|
return 0, 0, err
|
||||||
}
|
}
|
||||||
defer file.Close()
|
defer file.Close()
|
||||||
|
|
||||||
img, _, err := image.DecodeConfig(file)
|
img, _, err := image.DecodeConfig(file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, err
|
return 0, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return img.Width, img.Height, nil
|
return img.Width, img.Height, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -210,16 +210,16 @@ func ResizeImage(srcPath, destPath string, width, height int) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("打开图片失败: %v", err)
|
return fmt.Errorf("打开图片失败: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 调整图片尺寸 (正方形裁剪)
|
// 调整图片尺寸 (正方形裁剪)
|
||||||
resized := imaging.Fill(src, width, height, imaging.Center, imaging.Lanczos)
|
resized := imaging.Fill(src, width, height, imaging.Center, imaging.Lanczos)
|
||||||
|
|
||||||
// 保存调整后的图片
|
// 保存调整后的图片
|
||||||
err = imaging.Save(resized, destPath)
|
err = imaging.Save(resized, destPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("保存调整后的图片失败: %v", err)
|
return fmt.Errorf("保存调整后的图片失败: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -230,15 +230,15 @@ func CreateAvatar(srcPath, destPath string, size int) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("打开图片失败: %v", err)
|
return fmt.Errorf("打开图片失败: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 创建正方形头像 (居中裁剪)
|
// 创建正方形头像 (居中裁剪)
|
||||||
avatar := imaging.Fill(src, size, size, imaging.Center, imaging.Lanczos)
|
avatar := imaging.Fill(src, size, size, imaging.Center, imaging.Lanczos)
|
||||||
|
|
||||||
// 保存为JPEG格式 (压缩优化)
|
// 保存为JPEG格式 (压缩优化)
|
||||||
err = imaging.Save(avatar, destPath, imaging.JPEGQuality(85))
|
err = imaging.Save(avatar, destPath, imaging.JPEGQuality(85))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("保存头像失败: %v", err)
|
return fmt.Errorf("保存头像失败: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -31,4 +31,4 @@ func SHA256(str string) string {
|
|||||||
h := sha256.New()
|
h := sha256.New()
|
||||||
h.Write([]byte(str))
|
h.Write([]byte(str))
|
||||||
return hex.EncodeToString(h.Sum(nil))
|
return hex.EncodeToString(h.Sum(nil))
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,7 +2,7 @@ package jwt
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -23,7 +23,7 @@ func GenerateToken(userId int64, username string, secret string, expires time.Du
|
|||||||
NotBefore: jwt.NewNumericDate(now),
|
NotBefore: jwt.NewNumericDate(now),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
return token.SignedString([]byte(secret))
|
return token.SignedString([]byte(secret))
|
||||||
}
|
}
|
||||||
@ -32,14 +32,14 @@ func ParseToken(tokenString string, secret string) (*Claims, error) {
|
|||||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||||
return []byte(secret), nil
|
return []byte(secret), nil
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
|
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
|
||||||
return claims, nil
|
return claims, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, jwt.ErrInvalidKey
|
return nil, jwt.ErrInvalidKey
|
||||||
}
|
}
|
||||||
|
|||||||
@ -15,9 +15,12 @@ const api = axios.create({
|
|||||||
api.interceptors.request.use(
|
api.interceptors.request.use(
|
||||||
(config) => {
|
(config) => {
|
||||||
// 可以在这里添加token等认证信息
|
// 可以在这里添加token等认证信息
|
||||||
const token = localStorage.getItem('token')
|
// 检查是否在浏览器环境中
|
||||||
if (token) {
|
if (typeof window !== 'undefined') {
|
||||||
config.headers.Authorization = `Bearer ${token}`
|
const token = localStorage.getItem('token')
|
||||||
|
if (token) {
|
||||||
|
config.headers.Authorization = `Bearer ${token}`
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return config
|
return config
|
||||||
},
|
},
|
||||||
@ -42,9 +45,11 @@ api.interceptors.response.use(
|
|||||||
},
|
},
|
||||||
(error) => {
|
(error) => {
|
||||||
if (error.response?.status === 401) {
|
if (error.response?.status === 401) {
|
||||||
// 处理未授权
|
// 处理未授权 - 仅在浏览器环境中执行
|
||||||
localStorage.removeItem('token')
|
if (typeof window !== 'undefined') {
|
||||||
window.location.href = '/login'
|
localStorage.removeItem('token')
|
||||||
|
window.location.href = '/login'
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return Promise.reject(error)
|
return Promise.reject(error)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user