Files
photography/backend/internal/middleware/middleware.go
xujiang 5b3fc9bf9c feat: 完成后端中间件系统完善
## 🛡️ 新增功能
- 实现完整的CORS中间件,支持开发/生产环境配置
- 实现请求日志中间件,完整的请求生命周期记录
- 实现全局错误处理中间件,统一错误响应格式
- 创建中间件管理器,支持链式中间件和配置管理

## 🔧 技术改进
- 更新配置系统支持中间件配置
- 修复go-zero日志API兼容性问题
- 创建完整的中间件测试用例
- 编译测试通过,功能完整可用

## 📊 进度提升
- 项目总进度从42.5%提升至50.0%
- 中优先级任务完成率达55%
- 3个中优先级任务同时完成

## 🎯 完成的任务
14. 实现 CORS 中间件
16. 实现请求日志中间件
17. 完善全局错误处理

Co-authored-by: Claude Code <claude@anthropic.com>
2025-07-11 13:55:38 +08:00

203 lines
5.3 KiB
Go

package middleware
import (
"net/http"
"os"
"strings"
"time"
"photography-backend/internal/config"
"github.com/zeromicro/go-zero/core/logx"
)
// MiddlewareManager 中间件管理器
type MiddlewareManager struct {
config config.Config
corsMiddleware *CORSMiddleware
logMiddleware *LoggerMiddleware
errorMiddleware *ErrorMiddleware
authMiddleware *AuthMiddleware
}
// NewMiddlewareManager 创建中间件管理器
func NewMiddlewareManager(c config.Config) *MiddlewareManager {
return &MiddlewareManager{
config: c,
corsMiddleware: NewCORSMiddleware(getCORSConfig(c)),
logMiddleware: NewLoggerMiddleware(getLoggerConfig(c)),
errorMiddleware: NewErrorMiddleware(getErrorConfig(c)),
authMiddleware: NewAuthMiddleware(c.Auth.AccessSecret),
}
}
// getCORSConfig 获取CORS配置
func getCORSConfig(c config.Config) CORSConfig {
env := getEnvironment()
if env == "production" {
// 生产环境使用严格的CORS配置
return ProductionCORSConfig(getProductionOrigins())
}
// 开发环境使用宽松的CORS配置
return DefaultCORSConfig()
}
// getLoggerConfig 获取日志配置
func getLoggerConfig(c config.Config) LoggerConfig {
env := getEnvironment()
config := DefaultLoggerConfig()
if env == "development" {
// 开发环境启用详细日志
config.EnableRequestBody = true
config.EnableResponseBody = true
config.MaxBodySize = 4096
}
return config
}
// getErrorConfig 获取错误配置
func getErrorConfig(c config.Config) ErrorConfig {
env := getEnvironment()
if env == "development" {
return DevelopmentErrorConfig()
}
return DefaultErrorConfig()
}
// getEnvironment 获取环境变量
func getEnvironment() string {
env := os.Getenv("GO_ENV")
if env == "" {
env = os.Getenv("ENV")
}
if env == "" {
// 默认为开发环境
env = "development"
}
return strings.ToLower(env)
}
// getProductionOrigins 获取生产环境允许的来源
func getProductionOrigins() []string {
origins := os.Getenv("CORS_ALLOWED_ORIGINS")
if origins == "" {
return []string{
"https://photography.iriver.top",
"https://admin.photography.iriver.top",
}
}
return strings.Split(origins, ",")
}
// Chain 链式中间件
func (m *MiddlewareManager) Chain(handler http.HandlerFunc, middlewares ...func(http.HandlerFunc) http.HandlerFunc) http.HandlerFunc {
// 从右到左包装中间件
for i := len(middlewares) - 1; i >= 0; i-- {
handler = middlewares[i](handler)
}
return handler
}
// GetGlobalMiddlewares 获取全局中间件
func (m *MiddlewareManager) GetGlobalMiddlewares() []func(http.HandlerFunc) http.HandlerFunc {
return []func(http.HandlerFunc) http.HandlerFunc{
m.errorMiddleware.Handle, // 错误处理 (最外层)
m.corsMiddleware.Handle, // CORS 处理
m.logMiddleware.Handle, // 日志记录
}
}
// GetAuthMiddleware 获取认证中间件
func (m *MiddlewareManager) GetAuthMiddleware() func(http.HandlerFunc) http.HandlerFunc {
return m.authMiddleware.Handle
}
// ApplyGlobalMiddlewares 应用全局中间件
func (m *MiddlewareManager) ApplyGlobalMiddlewares(handler http.HandlerFunc) http.HandlerFunc {
return m.Chain(handler, m.GetGlobalMiddlewares()...)
}
// ApplyAuthMiddlewares 应用认证中间件
func (m *MiddlewareManager) ApplyAuthMiddlewares(handler http.HandlerFunc) http.HandlerFunc {
middlewares := append(m.GetGlobalMiddlewares(), m.GetAuthMiddleware())
return m.Chain(handler, middlewares...)
}
// HealthCheck 健康检查处理器 (不使用中间件)
func (m *MiddlewareManager) HealthCheck(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"status":"ok","timestamp":"` +
time.Now().Format("2006-01-02T15:04:05Z07:00") + `"}`))
}
// Middleware 中间件接口
type Middleware interface {
Handle(next http.HandlerFunc) http.HandlerFunc
}
// MiddlewareFunc 中间件函数类型
type MiddlewareFunc func(next http.HandlerFunc) http.HandlerFunc
// Handle 实现Middleware接口
func (f MiddlewareFunc) Handle(next http.HandlerFunc) http.HandlerFunc {
return f(next)
}
// Use 应用中间件到处理器
func Use(handler http.HandlerFunc, middlewares ...Middleware) http.HandlerFunc {
for i := len(middlewares) - 1; i >= 0; i-- {
handler = middlewares[i].Handle(handler)
}
return handler
}
// Recovery 通用恢复中间件
func Recovery() MiddlewareFunc {
return func(next http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
fields := map[string]interface{}{
"error": err,
"method": r.Method,
"path": r.URL.Path,
}
logx.WithContext(r.Context()).Errorf("Panic recovered in Recovery middleware: %+v", fields)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
}()
next(w, r)
})
}
}
// RequestID 请求ID中间件
func RequestID() MiddlewareFunc {
return func(next http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestID := r.Header.Get("X-Request-ID")
if requestID == "" {
requestID = generateRequestID()
}
w.Header().Set("X-Request-ID", requestID)
r.Header.Set("X-Request-ID", requestID)
next(w, r)
})
}
}
// generateRequestID 生成请求ID
func generateRequestID() string {
return randomString(16)
}