Files
photography/backend/internal/middleware/logger.go
xujiang 5b3fc9bf9c feat: 完成后端中间件系统完善
## 🛡️ 新增功能
- 实现完整的CORS中间件,支持开发/生产环境配置
- 实现请求日志中间件,完整的请求生命周期记录
- 实现全局错误处理中间件,统一错误响应格式
- 创建中间件管理器,支持链式中间件和配置管理

## 🔧 技术改进
- 更新配置系统支持中间件配置
- 修复go-zero日志API兼容性问题
- 创建完整的中间件测试用例
- 编译测试通过,功能完整可用

## 📊 进度提升
- 项目总进度从42.5%提升至50.0%
- 中优先级任务完成率达55%
- 3个中优先级任务同时完成

## 🎯 完成的任务
14. 实现 CORS 中间件
16. 实现请求日志中间件
17. 完善全局错误处理

Co-authored-by: Claude Code <claude@anthropic.com>
2025-07-11 13:55:38 +08:00

299 lines
7.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"runtime/debug"
"strings"
"time"
"github.com/zeromicro/go-zero/core/logx"
)
// LoggerConfig 日志配置
type LoggerConfig struct {
EnableRequestBody bool // 是否记录请求体
EnableResponseBody bool // 是否记录响应体
MaxBodySize int64 // 最大记录的请求/响应体大小
SkipPaths []string // 跳过记录的路径
SlowRequestDuration time.Duration // 慢请求阈值
EnablePanicRecover bool // 是否启用panic恢复
}
// DefaultLoggerConfig 默认日志配置
func DefaultLoggerConfig() LoggerConfig {
return LoggerConfig{
EnableRequestBody: false, // 默认不记录请求体 (可能包含敏感信息)
EnableResponseBody: false, // 默认不记录响应体 (减少日志量)
MaxBodySize: 1024, // 最大记录1KB
SkipPaths: []string{"/health", "/metrics", "/favicon.ico"},
SlowRequestDuration: 1 * time.Second,
EnablePanicRecover: true,
}
}
// LoggerMiddleware 日志中间件
type LoggerMiddleware struct {
config LoggerConfig
}
// NewLoggerMiddleware 创建日志中间件
func NewLoggerMiddleware(config LoggerConfig) *LoggerMiddleware {
return &LoggerMiddleware{
config: config,
}
}
// responseWriter 包装 http.ResponseWriter 用于记录响应
type responseWriter struct {
http.ResponseWriter
status int
size int64
body *bytes.Buffer
}
// newResponseWriter 创建响应写入器
func newResponseWriter(w http.ResponseWriter) *responseWriter {
return &responseWriter{
ResponseWriter: w,
status: http.StatusOK,
body: &bytes.Buffer{},
}
}
// Write 写入响应数据
func (rw *responseWriter) Write(b []byte) (int, error) {
size, err := rw.ResponseWriter.Write(b)
rw.size += int64(size)
// 记录响应体 (如果启用)
if rw.body.Len() < int(1024) { // 限制缓存大小
rw.body.Write(b)
}
return size, err
}
// WriteHeader 写入响应头
func (rw *responseWriter) WriteHeader(statusCode int) {
rw.status = statusCode
rw.ResponseWriter.WriteHeader(statusCode)
}
// Handle 处理请求日志
func (m *LoggerMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 检查是否跳过此路径
if m.shouldSkipPath(r.URL.Path) {
next(w, r)
return
}
// 记录请求开始时间
start := time.Now()
// 生成请求ID
requestID := m.generateRequestID(r)
ctx := context.WithValue(r.Context(), "request_id", requestID)
r = r.WithContext(ctx)
// 添加请求ID到响应头
w.Header().Set("X-Request-ID", requestID)
// 包装响应写入器
rw := newResponseWriter(w)
// 记录请求体 (如果启用)
var requestBody string
if m.config.EnableRequestBody && r.Body != nil {
requestBody = m.readRequestBody(r)
}
// 设置panic恢复
if m.config.EnablePanicRecover {
defer func() {
if err := recover(); err != nil {
m.logPanic(r, err, string(debug.Stack()))
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
}
}()
}
// 记录请求开始
m.logRequestStart(r, requestID, requestBody)
// 执行下一个处理器
next(rw, r)
// 计算处理时间
duration := time.Since(start)
// 记录响应体 (如果启用)
var responseBody string
if m.config.EnableResponseBody && rw.body.Len() > 0 {
responseBody = m.truncateString(rw.body.String(), int(m.config.MaxBodySize))
}
// 记录请求完成
m.logRequestComplete(r, requestID, rw.status, rw.size, duration, responseBody)
// 记录慢请求
if duration > m.config.SlowRequestDuration {
m.logSlowRequest(r, requestID, duration)
}
})
}
// shouldSkipPath 检查是否应该跳过此路径
func (m *LoggerMiddleware) shouldSkipPath(path string) bool {
for _, skipPath := range m.config.SkipPaths {
if strings.HasPrefix(path, skipPath) {
return true
}
}
return false
}
// generateRequestID 生成请求ID
func (m *LoggerMiddleware) generateRequestID(r *http.Request) string {
// 优先使用现有的请求ID
if requestID := r.Header.Get("X-Request-ID"); requestID != "" {
return requestID
}
// 生成新的请求ID
return fmt.Sprintf("%d-%s", time.Now().UnixNano(), randomString(8))
}
// readRequestBody 读取请求体
func (m *LoggerMiddleware) readRequestBody(r *http.Request) string {
if r.Body == nil {
return ""
}
// 读取请求体
body, err := io.ReadAll(io.LimitReader(r.Body, m.config.MaxBodySize))
if err != nil {
return fmt.Sprintf("error reading body: %v", err)
}
// 恢复请求体 (以便后续处理器可以读取)
r.Body = io.NopCloser(bytes.NewBuffer(body))
return string(body)
}
// logRequestStart 记录请求开始
func (m *LoggerMiddleware) logRequestStart(r *http.Request, requestID, requestBody string) {
fields := map[string]interface{}{
"request_id": requestID,
"method": r.Method,
"path": r.URL.Path,
"query": r.URL.RawQuery,
"user_agent": r.UserAgent(),
"remote_addr": getClientIP(r),
"content_length": r.ContentLength,
}
if requestBody != "" {
fields["request_body"] = requestBody
}
logx.WithContext(r.Context()).Infof("Request started: %+v", fields)
}
// logRequestComplete 记录请求完成
func (m *LoggerMiddleware) logRequestComplete(r *http.Request, requestID string, status int, size int64, duration time.Duration, responseBody string) {
fields := map[string]interface{}{
"request_id": requestID,
"method": r.Method,
"path": r.URL.Path,
"status": status,
"response_size": size,
"duration_ms": duration.Milliseconds(),
"duration": duration.String(),
}
if responseBody != "" {
fields["response_body"] = responseBody
}
// 根据状态码选择日志级别
switch {
case status >= 500:
logx.WithContext(r.Context()).Errorf("Request completed with server error: %+v", fields)
case status >= 400:
logx.WithContext(r.Context()).Infof("Request completed with client error: %+v", fields)
default:
logx.WithContext(r.Context()).Infof("Request completed: %+v", fields)
}
}
// logSlowRequest 记录慢请求
func (m *LoggerMiddleware) logSlowRequest(r *http.Request, requestID string, duration time.Duration) {
fields := map[string]interface{}{
"request_id": requestID,
"method": r.Method,
"path": r.URL.Path,
"duration": duration.String(),
"duration_ms": duration.Milliseconds(),
}
logx.WithContext(r.Context()).Infof("Slow request detected: %+v", fields)
}
// logPanic 记录panic
func (m *LoggerMiddleware) logPanic(r *http.Request, err interface{}, stack string) {
fields := map[string]interface{}{
"method": r.Method,
"path": r.URL.Path,
"error": err,
"stack_trace": stack,
}
logx.WithContext(r.Context()).Errorf("Panic recovered: %+v", fields)
}
// getClientIP 获取客户端IP
func getClientIP(r *http.Request) string {
// 检查各种可能的IP头部
for _, header := range []string{"X-Forwarded-For", "X-Real-IP", "X-Client-IP"} {
if ip := r.Header.Get(header); ip != "" {
// X-Forwarded-For 可能包含多个IP取第一个
if strings.Contains(ip, ",") {
return strings.TrimSpace(strings.Split(ip, ",")[0])
}
return ip
}
}
// 使用 RemoteAddr
if ip := r.RemoteAddr; ip != "" {
// 移除端口号
if strings.Contains(ip, ":") {
return strings.Split(ip, ":")[0]
}
return ip
}
return "unknown"
}
// truncateString 截断字符串
func (m *LoggerMiddleware) truncateString(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen] + "..."
}
// randomString 生成随机字符串
func randomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, length)
for i := range result {
result[i] = charset[time.Now().UnixNano()%int64(len(charset))]
}
return string(result)
}