package middleware import ( "bytes" "context" "fmt" "io" "net/http" "runtime/debug" "strings" "time" "github.com/zeromicro/go-zero/core/logx" ) // LoggerConfig 日志配置 type LoggerConfig struct { EnableRequestBody bool // 是否记录请求体 EnableResponseBody bool // 是否记录响应体 MaxBodySize int64 // 最大记录的请求/响应体大小 SkipPaths []string // 跳过记录的路径 SlowRequestDuration time.Duration // 慢请求阈值 EnablePanicRecover bool // 是否启用panic恢复 } // DefaultLoggerConfig 默认日志配置 func DefaultLoggerConfig() LoggerConfig { return LoggerConfig{ EnableRequestBody: false, // 默认不记录请求体 (可能包含敏感信息) EnableResponseBody: false, // 默认不记录响应体 (减少日志量) MaxBodySize: 1024, // 最大记录1KB SkipPaths: []string{"/health", "/metrics", "/favicon.ico"}, SlowRequestDuration: 1 * time.Second, EnablePanicRecover: true, } } // LoggerMiddleware 日志中间件 type LoggerMiddleware struct { config LoggerConfig } // NewLoggerMiddleware 创建日志中间件 func NewLoggerMiddleware(config LoggerConfig) *LoggerMiddleware { return &LoggerMiddleware{ config: config, } } // responseWriter 包装 http.ResponseWriter 用于记录响应 type responseWriter struct { http.ResponseWriter status int size int64 body *bytes.Buffer } // newResponseWriter 创建响应写入器 func newResponseWriter(w http.ResponseWriter) *responseWriter { return &responseWriter{ ResponseWriter: w, status: http.StatusOK, body: &bytes.Buffer{}, } } // Write 写入响应数据 func (rw *responseWriter) Write(b []byte) (int, error) { size, err := rw.ResponseWriter.Write(b) rw.size += int64(size) // 记录响应体 (如果启用) if rw.body.Len() < int(1024) { // 限制缓存大小 rw.body.Write(b) } return size, err } // WriteHeader 写入响应头 func (rw *responseWriter) WriteHeader(statusCode int) { rw.status = statusCode rw.ResponseWriter.WriteHeader(statusCode) } // Handle 处理请求日志 func (m *LoggerMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // 检查是否跳过此路径 if m.shouldSkipPath(r.URL.Path) { next(w, r) return } // 记录请求开始时间 start := time.Now() // 生成请求ID requestID := m.generateRequestID(r) ctx := context.WithValue(r.Context(), "request_id", requestID) r = r.WithContext(ctx) // 添加请求ID到响应头 w.Header().Set("X-Request-ID", requestID) // 包装响应写入器 rw := newResponseWriter(w) // 记录请求体 (如果启用) var requestBody string if m.config.EnableRequestBody && r.Body != nil { requestBody = m.readRequestBody(r) } // 设置panic恢复 if m.config.EnablePanicRecover { defer func() { if err := recover(); err != nil { m.logPanic(r, err, string(debug.Stack())) http.Error(rw, "Internal Server Error", http.StatusInternalServerError) } }() } // 记录请求开始 m.logRequestStart(r, requestID, requestBody) // 执行下一个处理器 next(rw, r) // 计算处理时间 duration := time.Since(start) // 记录响应体 (如果启用) var responseBody string if m.config.EnableResponseBody && rw.body.Len() > 0 { responseBody = m.truncateString(rw.body.String(), int(m.config.MaxBodySize)) } // 记录请求完成 m.logRequestComplete(r, requestID, rw.status, rw.size, duration, responseBody) // 记录慢请求 if duration > m.config.SlowRequestDuration { m.logSlowRequest(r, requestID, duration) } }) } // shouldSkipPath 检查是否应该跳过此路径 func (m *LoggerMiddleware) shouldSkipPath(path string) bool { for _, skipPath := range m.config.SkipPaths { if strings.HasPrefix(path, skipPath) { return true } } return false } // generateRequestID 生成请求ID func (m *LoggerMiddleware) generateRequestID(r *http.Request) string { // 优先使用现有的请求ID if requestID := r.Header.Get("X-Request-ID"); requestID != "" { return requestID } // 生成新的请求ID return fmt.Sprintf("%d-%s", time.Now().UnixNano(), randomString(8)) } // readRequestBody 读取请求体 func (m *LoggerMiddleware) readRequestBody(r *http.Request) string { if r.Body == nil { return "" } // 读取请求体 body, err := io.ReadAll(io.LimitReader(r.Body, m.config.MaxBodySize)) if err != nil { return fmt.Sprintf("error reading body: %v", err) } // 恢复请求体 (以便后续处理器可以读取) r.Body = io.NopCloser(bytes.NewBuffer(body)) return string(body) } // logRequestStart 记录请求开始 func (m *LoggerMiddleware) logRequestStart(r *http.Request, requestID, requestBody string) { fields := map[string]interface{}{ "request_id": requestID, "method": r.Method, "path": r.URL.Path, "query": r.URL.RawQuery, "user_agent": r.UserAgent(), "remote_addr": getClientIP(r), "content_length": r.ContentLength, } if requestBody != "" { fields["request_body"] = requestBody } logx.WithContext(r.Context()).Infof("Request started: %+v", fields) } // logRequestComplete 记录请求完成 func (m *LoggerMiddleware) logRequestComplete(r *http.Request, requestID string, status int, size int64, duration time.Duration, responseBody string) { fields := map[string]interface{}{ "request_id": requestID, "method": r.Method, "path": r.URL.Path, "status": status, "response_size": size, "duration_ms": duration.Milliseconds(), "duration": duration.String(), } if responseBody != "" { fields["response_body"] = responseBody } // 根据状态码选择日志级别 switch { case status >= 500: logx.WithContext(r.Context()).Errorf("Request completed with server error: %+v", fields) case status >= 400: logx.WithContext(r.Context()).Infof("Request completed with client error: %+v", fields) default: logx.WithContext(r.Context()).Infof("Request completed: %+v", fields) } } // logSlowRequest 记录慢请求 func (m *LoggerMiddleware) logSlowRequest(r *http.Request, requestID string, duration time.Duration) { fields := map[string]interface{}{ "request_id": requestID, "method": r.Method, "path": r.URL.Path, "duration": duration.String(), "duration_ms": duration.Milliseconds(), } logx.WithContext(r.Context()).Infof("Slow request detected: %+v", fields) } // logPanic 记录panic func (m *LoggerMiddleware) logPanic(r *http.Request, err interface{}, stack string) { fields := map[string]interface{}{ "method": r.Method, "path": r.URL.Path, "error": err, "stack_trace": stack, } logx.WithContext(r.Context()).Errorf("Panic recovered: %+v", fields) } // getClientIP 获取客户端IP func getClientIP(r *http.Request) string { // 检查各种可能的IP头部 for _, header := range []string{"X-Forwarded-For", "X-Real-IP", "X-Client-IP"} { if ip := r.Header.Get(header); ip != "" { // X-Forwarded-For 可能包含多个IP,取第一个 if strings.Contains(ip, ",") { return strings.TrimSpace(strings.Split(ip, ",")[0]) } return ip } } // 使用 RemoteAddr if ip := r.RemoteAddr; ip != "" { // 移除端口号 if strings.Contains(ip, ":") { return strings.Split(ip, ":")[0] } return ip } return "unknown" } // truncateString 截断字符串 func (m *LoggerMiddleware) truncateString(s string, maxLen int) string { if len(s) <= maxLen { return s } return s[:maxLen] + "..." } // randomString 生成随机字符串 func randomString(length int) string { const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" result := make([]byte, length) for i := range result { result[i] = charset[time.Now().UnixNano()%int64(len(charset))] } return string(result) }