Files
photography/backend/internal/middleware/logger.go
xujiang 5dd0bc19e4
Some checks failed
部署管理后台 / 🧪 测试和构建 (push) Failing after 1m5s
部署管理后台 / 🔒 安全扫描 (push) Has been skipped
部署后端服务 / 🧪 测试后端 (push) Failing after 3m13s
部署前端网站 / 🧪 测试和构建 (push) Failing after 2m10s
部署管理后台 / 🚀 部署到生产环境 (push) Has been skipped
部署后端服务 / 🚀 构建并部署 (push) Has been skipped
部署管理后台 / 🔄 回滚部署 (push) Has been skipped
部署前端网站 / 🚀 部署到生产环境 (push) Has been skipped
部署后端服务 / 🔄 回滚部署 (push) Has been skipped
style: 统一代码格式化 (go fmt + 配置更新)
- 后端:应用 go fmt 自动格式化,统一代码风格
- 前端:更新 API 配置,完善类型安全
- 所有代码符合项目规范,准备生产部署
2025-07-14 10:02:04 +08:00

300 lines
7.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
	"bytes"
	"context"
	"crypto/rand"
	"fmt"
	"io"
	"net"
	"net/http"
	"runtime/debug"
	"strings"
	"time"

	"github.com/zeromicro/go-zero/core/logx"
)
// LoggerConfig configures LoggerMiddleware.
type LoggerConfig struct {
	EnableRequestBody   bool          // log the request body (at most MaxBodySize bytes); may expose sensitive data
	EnableResponseBody  bool          // log the captured response body, truncated to MaxBodySize bytes (responseWriter itself only captures up to 1024 bytes)
	MaxBodySize         int64         // maximum number of request/response body bytes to include in logs
	SkipPaths           []string      // paths for which logging is skipped entirely (see shouldSkipPath for the matching rule)
	SlowRequestDuration time.Duration // requests taking longer than this are additionally logged as slow
	EnablePanicRecover  bool          // recover panics raised by the wrapped handler, log them and reply 500
}
// DefaultLoggerConfig returns a conservative configuration.
//
// Neither request nor response bodies are logged: request bodies may carry
// sensitive data, and skipping response bodies keeps log volume down. When body
// logging is turned on, at most 1 KiB of each body is kept. Requests to
// /health, /metrics and /favicon.ico are not logged. Requests slower than one
// second are flagged, and handler panics are recovered.
func DefaultLoggerConfig() LoggerConfig {
	var cfg LoggerConfig // EnableRequestBody / EnableResponseBody stay false

	cfg.MaxBodySize = 1 << 10
	cfg.SkipPaths = []string{"/health", "/metrics", "/favicon.ico"}
	cfg.SlowRequestDuration = time.Second
	cfg.EnablePanicRecover = true

	return cfg
}
// LoggerMiddleware is an HTTP middleware that logs requests according to its
// LoggerConfig; wrap handlers with Handle.
type LoggerMiddleware struct {
	config LoggerConfig
}

// NewLoggerMiddleware returns a LoggerMiddleware driven by config.
//
// The config struct is copied into the middleware. Slice fields such as
// SkipPaths still share their backing array with the caller.
func NewLoggerMiddleware(config LoggerConfig) *LoggerMiddleware {
	m := new(LoggerMiddleware)
	m.config = config
	return m
}
// maxCapturedResponseBody caps how many response-body bytes responseWriter
// keeps in memory for optional response-body logging.
const maxCapturedResponseBody = 1024

// responseWriter wraps an http.ResponseWriter and records what the handler sent.
// It keeps the final status code, the number of body bytes written, and a copy
// of at most maxCapturedResponseBody bytes of the body.
type responseWriter struct {
	http.ResponseWriter
	status      int           // final status code; 200 until a header is committed
	size        int64         // total body bytes successfully written downstream
	body        *bytes.Buffer // body prefix, never longer than maxCapturedResponseBody
	wroteHeader bool          // whether the final (non-1xx) status has been committed
}

// newResponseWriter wraps w. The status defaults to 200 because net/http
// sends 200 when a handler writes a body without calling WriteHeader.
func newResponseWriter(w http.ResponseWriter) *responseWriter {
	return &responseWriter{
		ResponseWriter: w,
		status:         http.StatusOK,
		body:           &bytes.Buffer{},
	}
}

// Write forwards b to the wrapped writer and adds the number of bytes actually
// written to size. It copies those bytes into body until the capture limit is
// reached. The copy never exceeds maxCapturedResponseBody, however large
// individual writes are. (Previously a whole chunk was appended whenever the
// buffer held fewer than 1024 bytes, so one large write was buffered in full.)
func (rw *responseWriter) Write(b []byte) (int, error) {
	// The first Write implicitly commits the header (200 unless WriteHeader was
	// called); net/http ignores any later WriteHeader, and so must we.
	rw.wroteHeader = true

	n, err := rw.ResponseWriter.Write(b)
	rw.size += int64(n)

	if room := maxCapturedResponseBody - rw.body.Len(); room > 0 {
		written := b[:n]
		if len(written) > room {
			written = written[:room]
		}
		rw.body.Write(written)
	}
	return n, err
}

// WriteHeader records the final status code and forwards the call.
//
// Informational 1xx codes (except 101 Switching Protocols) may precede the
// final header, so they are forwarded but not recorded. Once a final status is
// committed, later calls are still forwarded but no longer change status.
// net/http warns about those calls and ignores them, so the client never saw
// the later code.
func (rw *responseWriter) WriteHeader(statusCode int) {
	informational := statusCode >= 100 && statusCode <= 199 &&
		statusCode != http.StatusSwitchingProtocols
	if !rw.wroteHeader && !informational {
		rw.status = statusCode
		rw.wroteHeader = true
	}
	rw.ResponseWriter.WriteHeader(statusCode)
}

// Flush forwards to the wrapped writer when it implements http.Flusher, so
// streaming handlers keep working behind this middleware.
func (rw *responseWriter) Flush() {
	if f, ok := rw.ResponseWriter.(http.Flusher); ok {
		// Flushing commits the header with the current status.
		rw.wroteHeader = true
		f.Flush()
	}
}

// Unwrap returns the wrapped writer so that http.ResponseController can reach
// optional interfaces this wrapper does not forward itself.
func (rw *responseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}
// Handle wraps next with request/response logging.
//
// Requests whose path matches shouldSkipPath are passed straight through.
// Every other request goes through these steps:
//   - A request ID is resolved (see generateRequestID). It is stored in the
//     request context under the key "request_id" and echoed in the
//     X-Request-ID response header.
//   - The request body is captured when EnableRequestBody is set.
//   - When EnablePanicRecover is set, panics from next are recovered, logged,
//     answered with 500 and recorded as a completed request.
//   - The start and completion of the request are logged.
//   - The request is also logged as slow when SlowRequestDuration is positive
//     and exceeded.
func (m *LoggerMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if m.shouldSkipPath(r.URL.Path) {
			next(w, r)
			return
		}

		start := time.Now()

		requestID := m.generateRequestID(r)
		// NOTE(review): a plain string context key can collide with other
		// packages (staticcheck SA1029). It is kept because downstream code may
		// read the ID with ctx.Value("request_id").
		r = r.WithContext(context.WithValue(r.Context(), "request_id", requestID))
		w.Header().Set("X-Request-ID", requestID)

		rw := newResponseWriter(w)

		var requestBody string
		if m.config.EnableRequestBody && r.Body != nil {
			requestBody = m.readRequestBody(r)
		}

		if m.config.EnablePanicRecover {
			defer func() {
				err := recover()
				if err == nil {
					return
				}
				// http.ErrAbortHandler is net/http's sanctioned way to abort a
				// response; the server handles it quietly, so let it propagate
				// instead of reporting it as a handler crash.
				if err == http.ErrAbortHandler {
					panic(err)
				}
				m.logPanic(r, err, string(debug.Stack()))
				// NOTE(review): if the handler already wrote headers, this 500 cannot
				// reach the client; net/http only logs a superfluous WriteHeader.
				http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
				// Log completion too, so panicked requests still appear in the access
				// log with their status and latency.
				m.logRequestComplete(r, requestID, rw.status, rw.size, time.Since(start), "")
			}()
		}

		m.logRequestStart(r, requestID, requestBody)

		next(rw, r)

		duration := time.Since(start)

		var responseBody string
		if m.config.EnableResponseBody && rw.body.Len() > 0 {
			responseBody = m.truncateString(rw.body.String(), int(m.config.MaxBodySize))
		}

		m.logRequestComplete(r, requestID, rw.status, rw.size, duration, responseBody)

		// A non-positive threshold means "not configured". Without this guard a
		// zero-value LoggerConfig would report every request as slow.
		if m.config.SlowRequestDuration > 0 && duration > m.config.SlowRequestDuration {
			m.logSlowRequest(r, requestID, duration)
		}
	}
}
// shouldSkipPath reports whether requests to path bypass logging.
//
// A configured skip path matches the path itself and anything below it on a
// segment boundary. For example, "/health" matches "/health" and
// "/health/live" but not "/healthy-recipes". A skip path ending in "/" matches
// every path with that prefix. Empty entries are ignored; with plain prefix
// matching they would have silenced every request.
func (m *LoggerMiddleware) shouldSkipPath(path string) bool {
	for _, skip := range m.config.SkipPaths {
		if skip == "" || !strings.HasPrefix(path, skip) {
			continue
		}
		if len(path) == len(skip) || strings.HasSuffix(skip, "/") || path[len(skip)] == '/' {
			return true
		}
	}
	return false
}
// generateRequestID returns the request ID to use for r.
//
// An incoming X-Request-ID header is reused so IDs propagate across services,
// but only if isSafeRequestID accepts it. The header is client-controlled and
// is copied verbatim into every log line for the request and into the
// response, so oversized or odd values are discarded. In that case a fresh ID
// of the form "<unix-nanos>-<8 random chars>" is generated.
func (m *LoggerMiddleware) generateRequestID(r *http.Request) string {
	if requestID := r.Header.Get("X-Request-ID"); isSafeRequestID(requestID) {
		return requestID
	}
	return fmt.Sprintf("%d-%s", time.Now().UnixNano(), randomString(8))
}

// isSafeRequestID reports whether id is a non-empty string of at most 128
// printable, non-space ASCII characters.
func isSafeRequestID(id string) bool {
	const maxLen = 128
	if id == "" || len(id) > maxLen {
		return false
	}
	for i := 0; i < len(id); i++ {
		if c := id[i]; c < '!' || c > '~' {
			return false
		}
	}
	return true
}
// readRequestBody returns up to MaxBodySize bytes of the request body for
// logging. It then re-installs r.Body so downstream handlers still receive the
// complete, unmodified body.
//
// Previously only the peeked prefix was put back, which silently truncated any
// body larger than MaxBodySize for every later handler. On a read error, the
// bytes read so far are still given back to downstream handlers, and a
// description of the error is returned instead of the body.
func (m *LoggerMiddleware) readRequestBody(r *http.Request) string {
	if r.Body == nil || r.Body == http.NoBody {
		return ""
	}

	original := r.Body
	peeked, err := io.ReadAll(io.LimitReader(original, m.config.MaxBodySize))

	// The new body replays the peeked prefix, then streams the rest of the
	// original body; closing it closes the original body.
	type readCloser struct {
		io.Reader
		io.Closer
	}
	r.Body = readCloser{
		Reader: io.MultiReader(bytes.NewReader(peeked), original),
		Closer: original,
	}

	if err != nil {
		return fmt.Sprintf("error reading body: %v", err)
	}
	return string(peeked)
}
// logRequestStart writes an info-level "Request started" entry for r.
//
// The entry contains the request ID, method, path, raw query string,
// User-Agent, client IP (from getClientIP) and declared Content-Length. When
// requestBody is non-empty it is added under "request_body".
func (m *LoggerMiddleware) logRequestStart(r *http.Request, requestID, requestBody string) {
	entry := make(map[string]interface{}, 8)
	entry["request_id"] = requestID
	entry["method"] = r.Method
	entry["path"] = r.URL.Path
	entry["query"] = r.URL.RawQuery
	entry["user_agent"] = r.UserAgent()
	entry["remote_addr"] = getClientIP(r)
	entry["content_length"] = r.ContentLength

	if len(requestBody) > 0 {
		entry["request_body"] = requestBody
	}

	ctx := r.Context()
	logx.WithContext(ctx).Infof("Request started: %+v", entry)
}
// logRequestComplete writes the completion entry for r.
//
// The entry contains the request ID, method, path, status, response size and
// duration (in milliseconds and as a string). responseBody is added only when
// non-empty. The log level and message depend on the status: 5xx is logged at
// error level, and 4xx and everything else at info level with distinct
// messages.
func (m *LoggerMiddleware) logRequestComplete(r *http.Request, requestID string, status int, size int64, duration time.Duration, responseBody string) {
	entry := map[string]interface{}{
		"duration":      duration.String(),
		"duration_ms":   duration.Milliseconds(),
		"method":        r.Method,
		"path":          r.URL.Path,
		"request_id":    requestID,
		"response_size": size,
		"status":        status,
	}
	if len(responseBody) > 0 {
		entry["response_body"] = responseBody
	}

	logger := logx.WithContext(r.Context())
	if status >= 500 {
		logger.Errorf("Request completed with server error: %+v", entry)
		return
	}
	if status >= 400 {
		logger.Infof("Request completed with client error: %+v", entry)
		return
	}
	logger.Infof("Request completed: %+v", entry)
}
// logSlowRequest writes an info-level "Slow request detected" entry with the
// request ID, method, path and the elapsed time (as a string and in
// milliseconds).
func (m *LoggerMiddleware) logSlowRequest(r *http.Request, requestID string, duration time.Duration) {
	logger := logx.WithContext(r.Context())
	logger.Infof("Slow request detected: %+v", map[string]interface{}{
		"duration_ms": duration.Milliseconds(),
		"duration":    duration.String(),
		"path":        r.URL.Path,
		"method":      r.Method,
		"request_id":  requestID,
	})
}
// logPanic logs a value recovered from a panicking handler at error level.
//
// The entry includes the request method and path and the stack trace. When the
// request context carries a string "request_id" value (as set by Handle), the
// ID is included too. Every other log entry for the request carries that ID,
// and without it the panic cannot be correlated with them.
func (m *LoggerMiddleware) logPanic(r *http.Request, err interface{}, stack string) {
	fields := map[string]interface{}{
		"method":      r.Method,
		"path":        r.URL.Path,
		"error":       err,
		"stack_trace": stack,
	}
	if requestID, ok := r.Context().Value("request_id").(string); ok && requestID != "" {
		fields["request_id"] = requestID
	}
	logx.WithContext(r.Context()).Errorf("Panic recovered: %+v", fields)
}
// getClientIP returns a best-effort client address for logging.
//
// It checks X-Forwarded-For, X-Real-IP and X-Client-IP in that order. It
// returns the first header whose first comma-separated entry is non-empty
// after trimming; for X-Forwarded-For that entry is the originating client.
// These headers come from clients and proxies and can be spoofed, so the result
// must not be used for access control.
//
// Without a usable header it falls back to r.RemoteAddr with the port removed.
// net.SplitHostPort handles IPv6 literals such as "[::1]:8080", which splitting
// on the first ':' reduced to "[". It returns "unknown" when RemoteAddr is
// empty.
func getClientIP(r *http.Request) string {
	for _, header := range []string{"X-Forwarded-For", "X-Real-IP", "X-Client-IP"} {
		value := r.Header.Get(header)
		if i := strings.IndexByte(value, ','); i >= 0 {
			value = value[:i]
		}
		if ip := strings.TrimSpace(value); ip != "" {
			return ip
		}
	}

	addr := r.RemoteAddr
	if addr == "" {
		return "unknown"
	}
	if host, _, err := net.SplitHostPort(addr); err == nil {
		return host
	}
	// No port present (e.g. "10.0.0.1" or "[::1]"): return the address itself
	// without IPv6 brackets.
	return strings.TrimSuffix(strings.TrimPrefix(addr, "["), "]")
}
// truncateString shortens s to at most maxLen bytes and appends "..." when
// anything was cut.
//
// The cut point is moved back to the start of the rune it falls in, so
// multi-byte UTF-8 characters (e.g. Chinese text in JSON bodies) are never
// split into invalid byte sequences in the logs. A negative maxLen is treated
// as 0.
func (m *LoggerMiddleware) truncateString(s string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	if len(s) <= maxLen {
		return s
	}

	cut := maxLen // cut < len(s), so s[cut] is valid
	// Bytes of the form 10xxxxxx are UTF-8 continuation bytes; a rune has at
	// most three of them, so step back at most three bytes to its first byte.
	for i := 0; i < 3 && cut > 0 && s[cut]&0xC0 == 0x80; i++ {
		cut--
	}
	return s[:cut] + "..."
}
// randomString returns length characters drawn uniformly from [a-zA-Z0-9]
// using crypto/rand. It returns "" when length <= 0.
//
// The previous version indexed the charset with time.Now().UnixNano() for each
// character. The loop runs within the clock's resolution, so the result was
// usually one repeated character, and request IDs built from it were far more
// collision-prone than intended.
func randomString(length int) string {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	// 256 is not a multiple of len(charset). Rejecting bytes >= limit keeps
	// every character equally likely.
	const limit = 256 - 256%len(charset)

	if length <= 0 {
		return ""
	}

	result := make([]byte, 0, length)
	buf := make([]byte, length+length/2+1)
	for len(result) < length {
		if _, err := rand.Read(buf); err != nil {
			// crypto/rand failing is extremely unlikely. Degrade to clock-derived
			// characters rather than failing the request that needs an ID.
			for len(result) < length {
				seed := uint64(time.Now().UnixNano()) + uint64(len(result))*2654435761
				result = append(result, charset[seed%uint64(len(charset))])
			}
			break
		}
		for _, b := range buf {
			if int(b) >= limit {
				continue
			}
			result = append(result, charset[int(b)%len(charset)])
			if len(result) == length {
				break
			}
		}
	}
	return string(result)
}