package middleware

import (
	"encoding/json"
	"fmt"
	"net/http"
	"runtime/debug"
	"strings"
	"time"

	"photography-backend/pkg/errorx"
	"photography-backend/pkg/response"

	"github.com/zeromicro/go-zero/core/logx"
	"github.com/zeromicro/go-zero/rest/httpx"
)

// maxCapturedErrorBody bounds how many bytes of an error response body are
// kept in memory for logging. Larger bodies are truncated (and will then
// usually fail JSON parsing, in which case no body is logged).
const maxCapturedErrorBody = 64 << 10

// ErrorConfig configures the error-handling middleware.
type ErrorConfig struct {
	EnableDetailedErrors bool     // Include detailed error data in responses (development).
	EnableStackTrace     bool     // Include the stack trace in panic logs and detailed responses.
	EnableErrorMonitor   bool     // Forward errors to reportError.
	IgnoreHTTPCodes      []int    // HTTP status codes that are not treated as errors.
	SensitiveFields      []string // Key fragments whose values are redacted before logging.
}

// DefaultErrorConfig returns the production defaults: detailed errors and
// stack traces off, error monitoring on, 404/405 ignored.
func DefaultErrorConfig() ErrorConfig {
	return ErrorConfig{
		EnableDetailedErrors: false,
		EnableStackTrace:     false,
		EnableErrorMonitor:   true,
		IgnoreHTTPCodes:      []int{http.StatusNotFound, http.StatusMethodNotAllowed},
		SensitiveFields:      []string{"password", "token", "secret", "key", "authorization"},
	}
}

// DevelopmentErrorConfig returns DefaultErrorConfig with detailed errors and
// stack traces switched on.
func DevelopmentErrorConfig() ErrorConfig {
	config := DefaultErrorConfig()
	config.EnableDetailedErrors = true
	config.EnableStackTrace = true
	return config
}

// ErrorMiddleware is the global error-handling middleware.
type ErrorMiddleware struct {
	config ErrorConfig
}

// NewErrorMiddleware creates an ErrorMiddleware using config.
func NewErrorMiddleware(config ErrorConfig) *ErrorMiddleware {
	return &ErrorMiddleware{
		config: config,
	}
}

// errorResponseWriter wraps an http.ResponseWriter and records the status
// code and (for error statuses) a bounded copy of the response body.
type errorResponseWriter struct {
	http.ResponseWriter
	statusCode int    // First final status written; http.StatusOK until then.
	body       []byte // Captured body, only filled when statusCode >= 400.
	written    bool   // True once a final header or any body bytes were sent.
}

// newErrorResponseWriter wraps w with the status defaulting to 200 OK.
func newErrorResponseWriter(w http.ResponseWriter) *errorResponseWriter {
	return &errorResponseWriter{
		ResponseWriter: w,
		statusCode:     http.StatusOK,
	}
}

// Write forwards b to the wrapped writer. When the recorded status is an
// error (>= 400) it also keeps up to maxCapturedErrorBody bytes of the body
// so handleHTTPError can log it.
func (erw *errorResponseWriter) Write(b []byte) (int, error) {
	// A Write without a prior WriteHeader makes net/http send an implicit
	// 200, so the response has started either way.
	erw.written = true

	if erw.statusCode >= http.StatusBadRequest {
		if room := maxCapturedErrorBody - len(erw.body); room > 0 {
			if len(b) < room {
				room = len(b)
			}
			erw.body = append(erw.body, b[:room]...)
		}
	}
	return erw.ResponseWriter.Write(b)
}

// WriteHeader records the first final status code and forwards the call.
func (erw *errorResponseWriter) WriteHeader(statusCode int) {
	// 1xx headers (other than 101) are informational and may precede the
	// final status; they must not be recorded as the response status.
	informational := statusCode >= 100 && statusCode < 200 &&
		statusCode != http.StatusSwitchingProtocols
	if !informational && !erw.written {
		erw.statusCode = statusCode
		erw.written = true
	}
	erw.ResponseWriter.WriteHeader(statusCode)
}

// Handle wraps next so that panics are recovered and answered with an error
// response, and responses with an error status are logged/reported.
func (m *ErrorMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		erw := newErrorResponseWriter(w)

		// recover must run in a function deferred directly by this handler.
		defer func() {
			if rec := recover(); rec != nil {
				m.handlePanic(erw, r, rec)
			}
		}()

		next(erw, r)

		if m.shouldHandleError(erw.statusCode) {
			m.handleHTTPError(erw, r)
		}
	}
}

// handlePanic logs a recovered panic, writes an internal-server-error
// response if the handler had not started one yet, and reports the panic
// when monitoring is enabled.
func (m *ErrorMiddleware) handlePanic(w *errorResponseWriter, r *http.Request, err interface{}) {
	logFields := map[string]interface{}{
		"error":      err,
		"method":     r.Method,
		"path":       r.URL.Path,
		"user_agent": r.UserAgent(),
		"remote_ip":  getClientIP(r),
		"timestamp":  time.Now().Format(time.RFC3339),
	}

	// The panic value is stringified for the response: error values would
	// otherwise typically JSON-encode as an empty object.
	details := map[string]interface{}{
		"error": fmt.Sprint(err),
	}

	if m.config.EnableStackTrace {
		stack := string(debug.Stack())
		logFields["stack_trace"] = stack
		details["stack"] = stack
	}

	logger := logx.WithContext(r.Context())
	logger.Errorf("Panic recovered: %+v", logFields)

	if w.written {
		// The status line (and possibly body bytes) are already on the wire;
		// writing a second response would only corrupt the output.
		logger.Errorf("Panic occurred after the response started (status %d); error response skipped", w.statusCode)
	} else {
		m.respondWithError(w.ResponseWriter, r, &errorx.CodeError{
			Code: errorx.ServerError,
			Msg:  "Internal Server Error",
		}, details)
	}

	if m.config.EnableErrorMonitor {
		m.reportError("panic", err, logFields)
	}
}

// handleHTTPError logs a response that finished with an error status and
// reports it when monitoring is enabled. If the captured body is a JSON
// object it is sanitized and attached to the log fields.
func (m *ErrorMiddleware) handleHTTPError(w *errorResponseWriter, r *http.Request) {
	logFields := map[string]interface{}{
		"status_code": w.statusCode,
		"method":      r.Method,
		"path":        r.URL.Path,
		"user_agent":  r.UserAgent(),
		"remote_ip":   getClientIP(r),
		"timestamp":   time.Now().Format(time.RFC3339),
	}

	if len(w.body) > 0 {
		var errorBody map[string]interface{}
		// Non-JSON (or truncated) bodies are simply not logged.
		if err := json.Unmarshal(w.body, &errorBody); err == nil {
			logFields["response_body"] = m.sanitizeFields(errorBody)
		}
	}

	// Server errors are logged at error level, client errors at info level.
	logger := logx.WithContext(r.Context())
	switch {
	case w.statusCode >= 500:
		logger.Errorf("Server error occurred: %+v", logFields)
	case w.statusCode >= 400:
		logger.Infof("Client error occurred: %+v", logFields)
	}

	if m.config.EnableErrorMonitor {
		m.reportError("http_error", w.statusCode, logFields)
	}
}

// shouldHandleError reports whether statusCode is an error status (>= 400)
// that is not listed in IgnoreHTTPCodes.
func (m *ErrorMiddleware) shouldHandleError(statusCode int) bool {
	if statusCode < 400 {
		return false
	}
	for _, ignored := range m.config.IgnoreHTTPCodes {
		if statusCode == ignored {
			return false
		}
	}
	return true
}

// respondWithError writes err as a JSON response body. extra is attached as
// the body's data only when EnableDetailedErrors is set, after redaction of
// sensitive keys.
func (m *ErrorMiddleware) respondWithError(w http.ResponseWriter, r *http.Request, err *errorx.CodeError, extra map[string]interface{}) {
	body := response.Body{
		Code:    err.Code,
		Message: err.Msg,
	}

	if m.config.EnableDetailedErrors && extra != nil {
		body.Data = m.sanitizeFields(extra)
	}

	// httpx.WriteJson sets the Content-Type header and writes the status
	// itself; calling WriteHeader here as well would produce a superfluous
	// second WriteHeader call.
	httpx.WriteJson(w, errorx.GetHttpStatus(err.Code), body)
}

// sanitizeFields returns a copy of data in which the value of every key that
// contains one of the configured SensitiveFields (case-insensitively) is
// replaced by a redaction marker. Nested objects and arrays are processed
// recursively; data itself is not modified.
func (m *ErrorMiddleware) sanitizeFields(data map[string]interface{}) map[string]interface{} {
	sanitized := make(map[string]interface{}, len(data))
	for key, value := range data {
		if m.isSensitiveKey(key) {
			sanitized[key] = "***REDACTED***"
			continue
		}
		sanitized[key] = m.sanitizeValue(value)
	}
	return sanitized
}

// isSensitiveKey reports whether key contains any configured sensitive
// fragment, ignoring case.
func (m *ErrorMiddleware) isSensitiveKey(key string) bool {
	lowerKey := strings.ToLower(key)
	for _, field := range m.config.SensitiveFields {
		if strings.Contains(lowerKey, strings.ToLower(field)) {
			return true
		}
	}
	return false
}

// sanitizeValue applies sanitizeFields to nested objects, including objects
// inside arrays; all other values are returned unchanged.
func (m *ErrorMiddleware) sanitizeValue(value interface{}) interface{} {
	switch v := value.(type) {
	case map[string]interface{}:
		return m.sanitizeFields(v)
	case []interface{}:
		items := make([]interface{}, len(v))
		for i, item := range v {
			items[i] = m.sanitizeValue(item)
		}
		return items
	default:
		return value
	}
}

// reportError is the hook for an external monitoring service (e.g. Sentry,
// DataDog). At present it only writes the report to the log.
func (m *ErrorMiddleware) reportError(errorType string, errValue interface{}, logContext map[string]interface{}) {
	fields := map[string]interface{}{
		"error_type": errorType,
		"error":      errValue,
		"context":    logContext,
		"timestamp":  time.Now().Format(time.RFC3339),
	}

	logx.Infof("Error reported to monitoring system: %+v", fields)
}

// ErrorResponse is a standardized error response payload.
type ErrorResponse struct {
	Code      int                    `json:"code"`
	Message   string                 `json:"message"`
	Details   map[string]interface{} `json:"details,omitempty"`
	Timestamp string                 `json:"timestamp"`
	RequestID string                 `json:"request_id,omitempty"`
}

// NewErrorResponse builds an ErrorResponse stamped with the current time in
// RFC 3339 format.
func NewErrorResponse(code int, message string, details map[string]interface{}, requestID string) *ErrorResponse {
	return &ErrorResponse{
		Code:      code,
		Message:   message,
		Details:   details,
		Timestamp: time.Now().Format(time.RFC3339),
		RequestID: requestID,
	}
}

// CommonErrors holds frequently used error values.
//
// NOTE(review): the entries are shared pointers; callers should treat them
// as read-only. MethodNotAllowed and RateLimitExceeded use literal codes
// (405, 429) rather than errorx constants — presumably because errorx
// defines none for them; confirm.
var CommonErrors = struct {
	BadRequest          *errorx.CodeError
	Unauthorized        *errorx.CodeError
	Forbidden           *errorx.CodeError
	NotFound            *errorx.CodeError
	MethodNotAllowed    *errorx.CodeError
	InternalServerError *errorx.CodeError
	ValidationFailed    *errorx.CodeError
	RateLimitExceeded   *errorx.CodeError
}{
	BadRequest: &errorx.CodeError{
		Code: errorx.ParamError,
		Msg:  "Bad Request",
	},
	Unauthorized: &errorx.CodeError{
		Code: errorx.AuthError,
		Msg:  "Unauthorized",
	},
	Forbidden: &errorx.CodeError{
		Code: errorx.Forbidden,
		Msg:  "Forbidden",
	},
	NotFound: &errorx.CodeError{
		Code: errorx.NotFound,
		Msg:  "Not Found",
	},
	MethodNotAllowed: &errorx.CodeError{
		Code: 405,
		Msg:  "Method Not Allowed",
	},
	InternalServerError: &errorx.CodeError{
		Code: errorx.ServerError,
		Msg:  "Internal Server Error",
	},
	ValidationFailed: &errorx.CodeError{
		Code: errorx.ParamError,
		Msg:  "Validation Failed",
	},
	RateLimitExceeded: &errorx.CodeError{
		Code: 429,
		Msg:  "Rate Limit Exceeded",
	},
}