package middleware import ( "net/http" "os" "strings" "time" "photography-backend/internal/config" "github.com/zeromicro/go-zero/core/logx" ) // MiddlewareManager 中间件管理器 type MiddlewareManager struct { config config.Config corsMiddleware *CORSMiddleware logMiddleware *LoggerMiddleware errorMiddleware *ErrorMiddleware authMiddleware *AuthMiddleware } // NewMiddlewareManager 创建中间件管理器 func NewMiddlewareManager(c config.Config) *MiddlewareManager { return &MiddlewareManager{ config: c, corsMiddleware: NewCORSMiddleware(getCORSConfig(c)), logMiddleware: NewLoggerMiddleware(getLoggerConfig(c)), errorMiddleware: NewErrorMiddleware(getErrorConfig(c)), authMiddleware: NewAuthMiddleware(c.Auth.AccessSecret), } } // getCORSConfig 获取CORS配置 func getCORSConfig(c config.Config) CORSConfig { env := getEnvironment() if env == "production" { // 生产环境使用严格的CORS配置 return ProductionCORSConfig(getProductionOrigins()) } // 开发环境使用宽松的CORS配置 return DefaultCORSConfig() } // getLoggerConfig 获取日志配置 func getLoggerConfig(c config.Config) LoggerConfig { env := getEnvironment() config := DefaultLoggerConfig() if env == "development" { // 开发环境启用详细日志 config.EnableRequestBody = true config.EnableResponseBody = true config.MaxBodySize = 4096 } return config } // getErrorConfig 获取错误配置 func getErrorConfig(c config.Config) ErrorConfig { env := getEnvironment() if env == "development" { return DevelopmentErrorConfig() } return DefaultErrorConfig() } // getEnvironment 获取环境变量 func getEnvironment() string { env := os.Getenv("GO_ENV") if env == "" { env = os.Getenv("ENV") } if env == "" { // 默认为开发环境 env = "development" } return strings.ToLower(env) } // getProductionOrigins 获取生产环境允许的来源 func getProductionOrigins() []string { origins := os.Getenv("CORS_ALLOWED_ORIGINS") if origins == "" { return []string{ "https://photography.iriver.top", "https://admin.photography.iriver.top", } } return strings.Split(origins, ",") } // Chain 链式中间件 func (m *MiddlewareManager) Chain(handler http.HandlerFunc, middlewares ...func(http.HandlerFunc) http.HandlerFunc) http.HandlerFunc { // 从右到左包装中间件 for i := len(middlewares) - 1; i >= 0; i-- { handler = middlewares[i](handler) } return handler } // GetGlobalMiddlewares 获取全局中间件 func (m *MiddlewareManager) GetGlobalMiddlewares() []func(http.HandlerFunc) http.HandlerFunc { return []func(http.HandlerFunc) http.HandlerFunc{ m.errorMiddleware.Handle, // 错误处理 (最外层) m.corsMiddleware.Handle, // CORS 处理 m.logMiddleware.Handle, // 日志记录 } } // GetAuthMiddleware 获取认证中间件 func (m *MiddlewareManager) GetAuthMiddleware() func(http.HandlerFunc) http.HandlerFunc { return m.authMiddleware.Handle } // ApplyGlobalMiddlewares 应用全局中间件 func (m *MiddlewareManager) ApplyGlobalMiddlewares(handler http.HandlerFunc) http.HandlerFunc { return m.Chain(handler, m.GetGlobalMiddlewares()...) } // ApplyAuthMiddlewares 应用认证中间件 func (m *MiddlewareManager) ApplyAuthMiddlewares(handler http.HandlerFunc) http.HandlerFunc { middlewares := append(m.GetGlobalMiddlewares(), m.GetAuthMiddleware()) return m.Chain(handler, middlewares...) } // HealthCheck 健康检查处理器 (不使用中间件) func (m *MiddlewareManager) HealthCheck(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write([]byte(`{"status":"ok","timestamp":"` + time.Now().Format("2006-01-02T15:04:05Z07:00") + `"}`)) } // Middleware 中间件接口 type Middleware interface { Handle(next http.HandlerFunc) http.HandlerFunc } // MiddlewareFunc 中间件函数类型 type MiddlewareFunc func(next http.HandlerFunc) http.HandlerFunc // Handle 实现Middleware接口 func (f MiddlewareFunc) Handle(next http.HandlerFunc) http.HandlerFunc { return f(next) } // Use 应用中间件到处理器 func Use(handler http.HandlerFunc, middlewares ...Middleware) http.HandlerFunc { for i := len(middlewares) - 1; i >= 0; i-- { handler = middlewares[i].Handle(handler) } return handler } // Recovery 通用恢复中间件 func Recovery() MiddlewareFunc { return func(next http.HandlerFunc) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer func() { if err := recover(); err != nil { fields := map[string]interface{}{ "error": err, "method": r.Method, "path": r.URL.Path, } logx.WithContext(r.Context()).Errorf("Panic recovered in Recovery middleware: %+v", fields) http.Error(w, "Internal Server Error", http.StatusInternalServerError) } }() next(w, r) }) } } // RequestID 请求ID中间件 func RequestID() MiddlewareFunc { return func(next http.HandlerFunc) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { requestID := r.Header.Get("X-Request-ID") if requestID == "" { requestID = generateRequestID() } w.Header().Set("X-Request-ID", requestID) r.Header.Set("X-Request-ID", requestID) next(w, r) }) } } // generateRequestID 生成请求ID func generateRequestID() string { return randomString(16) }