## 🛡️ 新增功能 - 实现完整的CORS中间件,支持开发/生产环境配置 - 实现请求日志中间件,完整的请求生命周期记录 - 实现全局错误处理中间件,统一错误响应格式 - 创建中间件管理器,支持链式中间件和配置管理 ## 🔧 技术改进 - 更新配置系统支持中间件配置 - 修复go-zero日志API兼容性问题 - 创建完整的中间件测试用例 - 编译测试通过,功能完整可用 ## 📊 进度提升 - 项目总进度从42.5%提升至50.0% - 中优先级任务完成率达55% - 3个中优先级任务同时完成 ## 🎯 完成的任务 14. 实现 CORS 中间件 16. 实现请求日志中间件 17. 完善全局错误处理 Co-authored-by: Claude Code <claude@anthropic.com>
203 lines
5.3 KiB
Go
203 lines
5.3 KiB
Go
package middleware
|
|
|
|
import (
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"photography-backend/internal/config"
|
|
|
|
"github.com/zeromicro/go-zero/core/logx"
|
|
)
|
|
|
|
// MiddlewareManager 中间件管理器
|
|
type MiddlewareManager struct {
|
|
config config.Config
|
|
corsMiddleware *CORSMiddleware
|
|
logMiddleware *LoggerMiddleware
|
|
errorMiddleware *ErrorMiddleware
|
|
authMiddleware *AuthMiddleware
|
|
}
|
|
|
|
// NewMiddlewareManager 创建中间件管理器
|
|
func NewMiddlewareManager(c config.Config) *MiddlewareManager {
|
|
return &MiddlewareManager{
|
|
config: c,
|
|
corsMiddleware: NewCORSMiddleware(getCORSConfig(c)),
|
|
logMiddleware: NewLoggerMiddleware(getLoggerConfig(c)),
|
|
errorMiddleware: NewErrorMiddleware(getErrorConfig(c)),
|
|
authMiddleware: NewAuthMiddleware(c.Auth.AccessSecret),
|
|
}
|
|
}
|
|
|
|
// getCORSConfig 获取CORS配置
|
|
func getCORSConfig(c config.Config) CORSConfig {
|
|
env := getEnvironment()
|
|
|
|
if env == "production" {
|
|
// 生产环境使用严格的CORS配置
|
|
return ProductionCORSConfig(getProductionOrigins())
|
|
}
|
|
|
|
// 开发环境使用宽松的CORS配置
|
|
return DefaultCORSConfig()
|
|
}
|
|
|
|
// getLoggerConfig 获取日志配置
|
|
func getLoggerConfig(c config.Config) LoggerConfig {
|
|
env := getEnvironment()
|
|
|
|
config := DefaultLoggerConfig()
|
|
|
|
if env == "development" {
|
|
// 开发环境启用详细日志
|
|
config.EnableRequestBody = true
|
|
config.EnableResponseBody = true
|
|
config.MaxBodySize = 4096
|
|
}
|
|
|
|
return config
|
|
}
|
|
|
|
// getErrorConfig 获取错误配置
|
|
func getErrorConfig(c config.Config) ErrorConfig {
|
|
env := getEnvironment()
|
|
|
|
if env == "development" {
|
|
return DevelopmentErrorConfig()
|
|
}
|
|
|
|
return DefaultErrorConfig()
|
|
}
|
|
|
|
// getEnvironment reports the runtime environment name in lower case.
//
// GO_ENV is consulted first, then ENV; the first non-empty value wins. When
// both are unset or empty the result is "development".
func getEnvironment() string {
	for _, key := range []string{"GO_ENV", "ENV"} {
		if value := os.Getenv(key); value != "" {
			return strings.ToLower(value)
		}
	}
	// Neither variable is set: default to development.
	return "development"
}
|
|
|
|
// getProductionOrigins returns the origins allowed by the production CORS
// policy.
//
// The list is read from the comma-separated CORS_ALLOWED_ORIGINS environment
// variable. Each entry is trimmed of surrounding whitespace and empty entries
// are dropped, so values such as "https://a.example, https://b.example," work
// as expected. When the variable is unset, or contains no non-empty entries,
// the built-in photography.iriver.top origins are returned.
func getProductionOrigins() []string {
	defaults := []string{
		"https://photography.iriver.top",
		"https://admin.photography.iriver.top",
	}

	raw := os.Getenv("CORS_ALLOWED_ORIGINS")
	if raw == "" {
		return defaults
	}

	parts := strings.Split(raw, ",")
	origins := make([]string, 0, len(parts))
	for _, part := range parts {
		// Browsers send the Origin header without surrounding spaces, so an
		// untrimmed entry such as " https://b.example" would never match.
		if origin := strings.TrimSpace(part); origin != "" {
			origins = append(origins, origin)
		}
	}

	// A value made only of separators/whitespace is treated like "unset".
	if len(origins) == 0 {
		return defaults
	}
	return origins
}
|
|
|
|
// Chain 链式中间件
|
|
func (m *MiddlewareManager) Chain(handler http.HandlerFunc, middlewares ...func(http.HandlerFunc) http.HandlerFunc) http.HandlerFunc {
|
|
// 从右到左包装中间件
|
|
for i := len(middlewares) - 1; i >= 0; i-- {
|
|
handler = middlewares[i](handler)
|
|
}
|
|
return handler
|
|
}
|
|
|
|
// GetGlobalMiddlewares 获取全局中间件
|
|
func (m *MiddlewareManager) GetGlobalMiddlewares() []func(http.HandlerFunc) http.HandlerFunc {
|
|
return []func(http.HandlerFunc) http.HandlerFunc{
|
|
m.errorMiddleware.Handle, // 错误处理 (最外层)
|
|
m.corsMiddleware.Handle, // CORS 处理
|
|
m.logMiddleware.Handle, // 日志记录
|
|
}
|
|
}
|
|
|
|
// GetAuthMiddleware 获取认证中间件
|
|
func (m *MiddlewareManager) GetAuthMiddleware() func(http.HandlerFunc) http.HandlerFunc {
|
|
return m.authMiddleware.Handle
|
|
}
|
|
|
|
// ApplyGlobalMiddlewares 应用全局中间件
|
|
func (m *MiddlewareManager) ApplyGlobalMiddlewares(handler http.HandlerFunc) http.HandlerFunc {
|
|
return m.Chain(handler, m.GetGlobalMiddlewares()...)
|
|
}
|
|
|
|
// ApplyAuthMiddlewares 应用认证中间件
|
|
func (m *MiddlewareManager) ApplyAuthMiddlewares(handler http.HandlerFunc) http.HandlerFunc {
|
|
middlewares := append(m.GetGlobalMiddlewares(), m.GetAuthMiddleware())
|
|
return m.Chain(handler, middlewares...)
|
|
}
|
|
|
|
// HealthCheck 健康检查处理器 (不使用中间件)
|
|
func (m *MiddlewareManager) HealthCheck(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(`{"status":"ok","timestamp":"` +
|
|
time.Now().Format("2006-01-02T15:04:05Z07:00") + `"}`))
|
|
}
|
|
|
|
// Middleware is implemented by types that wrap an http.HandlerFunc with
// additional behavior.
type Middleware interface {
	// Handle wraps next and returns the resulting handler.
	Handle(next http.HandlerFunc) http.HandlerFunc
}

// MiddlewareFunc adapts an ordinary wrapping function to the Middleware
// interface.
type MiddlewareFunc func(next http.HandlerFunc) http.HandlerFunc

// Compile-time check that MiddlewareFunc satisfies Middleware.
var _ Middleware = MiddlewareFunc(nil)

// Handle implements Middleware by calling f with next.
func (f MiddlewareFunc) Handle(next http.HandlerFunc) http.HandlerFunc {
	return f(next)
}
|
|
|
|
// Use 应用中间件到处理器
|
|
func Use(handler http.HandlerFunc, middlewares ...Middleware) http.HandlerFunc {
|
|
for i := len(middlewares) - 1; i >= 0; i-- {
|
|
handler = middlewares[i].Handle(handler)
|
|
}
|
|
return handler
|
|
}
|
|
|
|
// Recovery 通用恢复中间件
|
|
func Recovery() MiddlewareFunc {
|
|
return func(next http.HandlerFunc) http.HandlerFunc {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
defer func() {
|
|
if err := recover(); err != nil {
|
|
fields := map[string]interface{}{
|
|
"error": err,
|
|
"method": r.Method,
|
|
"path": r.URL.Path,
|
|
}
|
|
logx.WithContext(r.Context()).Errorf("Panic recovered in Recovery middleware: %+v", fields)
|
|
|
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
}
|
|
}()
|
|
next(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// RequestID 请求ID中间件
|
|
func RequestID() MiddlewareFunc {
|
|
return func(next http.HandlerFunc) http.HandlerFunc {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
requestID := r.Header.Get("X-Request-ID")
|
|
if requestID == "" {
|
|
requestID = generateRequestID()
|
|
}
|
|
|
|
w.Header().Set("X-Request-ID", requestID)
|
|
r.Header.Set("X-Request-ID", requestID)
|
|
|
|
next(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// generateRequestID 生成请求ID
|
|
func generateRequestID() string {
|
|
return randomString(16)
|
|
} |