feat: 完成后端中间件系统完善
## 🛡️ 新增功能 - 实现完整的CORS中间件,支持开发/生产环境配置 - 实现请求日志中间件,完整的请求生命周期记录 - 实现全局错误处理中间件,统一错误响应格式 - 创建中间件管理器,支持链式中间件和配置管理 ## 🔧 技术改进 - 更新配置系统支持中间件配置 - 修复go-zero日志API兼容性问题 - 创建完整的中间件测试用例 - 编译测试通过,功能完整可用 ## 📊 进度提升 - 项目总进度从42.5%提升至50.0% - 中优先级任务完成率达55% - 3个中优先级任务同时完成 ## 🎯 完成的任务 14. 实现 CORS 中间件 16. 实现请求日志中间件 17. 完善全局错误处理 Co-authored-by: Claude Code <claude@anthropic.com>
This commit is contained in:
299
backend/internal/middleware/logger.go
Normal file
299
backend/internal/middleware/logger.go
Normal file
@ -0,0 +1,299 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
)
|
||||
|
||||
// LoggerConfig 日志配置
|
||||
type LoggerConfig struct {
|
||||
EnableRequestBody bool // 是否记录请求体
|
||||
EnableResponseBody bool // 是否记录响应体
|
||||
MaxBodySize int64 // 最大记录的请求/响应体大小
|
||||
SkipPaths []string // 跳过记录的路径
|
||||
SlowRequestDuration time.Duration // 慢请求阈值
|
||||
EnablePanicRecover bool // 是否启用panic恢复
|
||||
}
|
||||
|
||||
// DefaultLoggerConfig 默认日志配置
|
||||
func DefaultLoggerConfig() LoggerConfig {
|
||||
return LoggerConfig{
|
||||
EnableRequestBody: false, // 默认不记录请求体 (可能包含敏感信息)
|
||||
EnableResponseBody: false, // 默认不记录响应体 (减少日志量)
|
||||
MaxBodySize: 1024, // 最大记录1KB
|
||||
SkipPaths: []string{"/health", "/metrics", "/favicon.ico"},
|
||||
SlowRequestDuration: 1 * time.Second,
|
||||
EnablePanicRecover: true,
|
||||
}
|
||||
}
|
||||
|
||||
// LoggerMiddleware 日志中间件
|
||||
type LoggerMiddleware struct {
|
||||
config LoggerConfig
|
||||
}
|
||||
|
||||
// NewLoggerMiddleware 创建日志中间件
|
||||
func NewLoggerMiddleware(config LoggerConfig) *LoggerMiddleware {
|
||||
return &LoggerMiddleware{
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// responseWriter 包装 http.ResponseWriter 用于记录响应
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
size int64
|
||||
body *bytes.Buffer
|
||||
}
|
||||
|
||||
// newResponseWriter 创建响应写入器
|
||||
func newResponseWriter(w http.ResponseWriter) *responseWriter {
|
||||
return &responseWriter{
|
||||
ResponseWriter: w,
|
||||
status: http.StatusOK,
|
||||
body: &bytes.Buffer{},
|
||||
}
|
||||
}
|
||||
|
||||
// Write 写入响应数据
|
||||
func (rw *responseWriter) Write(b []byte) (int, error) {
|
||||
size, err := rw.ResponseWriter.Write(b)
|
||||
rw.size += int64(size)
|
||||
|
||||
// 记录响应体 (如果启用)
|
||||
if rw.body.Len() < int(1024) { // 限制缓存大小
|
||||
rw.body.Write(b)
|
||||
}
|
||||
|
||||
return size, err
|
||||
}
|
||||
|
||||
// WriteHeader 写入响应头
|
||||
func (rw *responseWriter) WriteHeader(statusCode int) {
|
||||
rw.status = statusCode
|
||||
rw.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
// Handle 处理请求日志
|
||||
func (m *LoggerMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 检查是否跳过此路径
|
||||
if m.shouldSkipPath(r.URL.Path) {
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// 记录请求开始时间
|
||||
start := time.Now()
|
||||
|
||||
// 生成请求ID
|
||||
requestID := m.generateRequestID(r)
|
||||
ctx := context.WithValue(r.Context(), "request_id", requestID)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
// 添加请求ID到响应头
|
||||
w.Header().Set("X-Request-ID", requestID)
|
||||
|
||||
// 包装响应写入器
|
||||
rw := newResponseWriter(w)
|
||||
|
||||
// 记录请求体 (如果启用)
|
||||
var requestBody string
|
||||
if m.config.EnableRequestBody && r.Body != nil {
|
||||
requestBody = m.readRequestBody(r)
|
||||
}
|
||||
|
||||
// 设置panic恢复
|
||||
if m.config.EnablePanicRecover {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
m.logPanic(r, err, string(debug.Stack()))
|
||||
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 记录请求开始
|
||||
m.logRequestStart(r, requestID, requestBody)
|
||||
|
||||
// 执行下一个处理器
|
||||
next(rw, r)
|
||||
|
||||
// 计算处理时间
|
||||
duration := time.Since(start)
|
||||
|
||||
// 记录响应体 (如果启用)
|
||||
var responseBody string
|
||||
if m.config.EnableResponseBody && rw.body.Len() > 0 {
|
||||
responseBody = m.truncateString(rw.body.String(), int(m.config.MaxBodySize))
|
||||
}
|
||||
|
||||
// 记录请求完成
|
||||
m.logRequestComplete(r, requestID, rw.status, rw.size, duration, responseBody)
|
||||
|
||||
// 记录慢请求
|
||||
if duration > m.config.SlowRequestDuration {
|
||||
m.logSlowRequest(r, requestID, duration)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// shouldSkipPath 检查是否应该跳过此路径
|
||||
func (m *LoggerMiddleware) shouldSkipPath(path string) bool {
|
||||
for _, skipPath := range m.config.SkipPaths {
|
||||
if strings.HasPrefix(path, skipPath) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// generateRequestID 生成请求ID
|
||||
func (m *LoggerMiddleware) generateRequestID(r *http.Request) string {
|
||||
// 优先使用现有的请求ID
|
||||
if requestID := r.Header.Get("X-Request-ID"); requestID != "" {
|
||||
return requestID
|
||||
}
|
||||
|
||||
// 生成新的请求ID
|
||||
return fmt.Sprintf("%d-%s", time.Now().UnixNano(), randomString(8))
|
||||
}
|
||||
|
||||
// readRequestBody 读取请求体
|
||||
func (m *LoggerMiddleware) readRequestBody(r *http.Request) string {
|
||||
if r.Body == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 读取请求体
|
||||
body, err := io.ReadAll(io.LimitReader(r.Body, m.config.MaxBodySize))
|
||||
if err != nil {
|
||||
return fmt.Sprintf("error reading body: %v", err)
|
||||
}
|
||||
|
||||
// 恢复请求体 (以便后续处理器可以读取)
|
||||
r.Body = io.NopCloser(bytes.NewBuffer(body))
|
||||
|
||||
return string(body)
|
||||
}
|
||||
|
||||
// logRequestStart 记录请求开始
|
||||
func (m *LoggerMiddleware) logRequestStart(r *http.Request, requestID, requestBody string) {
|
||||
fields := map[string]interface{}{
|
||||
"request_id": requestID,
|
||||
"method": r.Method,
|
||||
"path": r.URL.Path,
|
||||
"query": r.URL.RawQuery,
|
||||
"user_agent": r.UserAgent(),
|
||||
"remote_addr": getClientIP(r),
|
||||
"content_length": r.ContentLength,
|
||||
}
|
||||
|
||||
if requestBody != "" {
|
||||
fields["request_body"] = requestBody
|
||||
}
|
||||
|
||||
logx.WithContext(r.Context()).Infof("Request started: %+v", fields)
|
||||
}
|
||||
|
||||
// logRequestComplete 记录请求完成
|
||||
func (m *LoggerMiddleware) logRequestComplete(r *http.Request, requestID string, status int, size int64, duration time.Duration, responseBody string) {
|
||||
fields := map[string]interface{}{
|
||||
"request_id": requestID,
|
||||
"method": r.Method,
|
||||
"path": r.URL.Path,
|
||||
"status": status,
|
||||
"response_size": size,
|
||||
"duration_ms": duration.Milliseconds(),
|
||||
"duration": duration.String(),
|
||||
}
|
||||
|
||||
if responseBody != "" {
|
||||
fields["response_body"] = responseBody
|
||||
}
|
||||
|
||||
// 根据状态码选择日志级别
|
||||
switch {
|
||||
case status >= 500:
|
||||
logx.WithContext(r.Context()).Errorf("Request completed with server error: %+v", fields)
|
||||
case status >= 400:
|
||||
logx.WithContext(r.Context()).Infof("Request completed with client error: %+v", fields)
|
||||
default:
|
||||
logx.WithContext(r.Context()).Infof("Request completed: %+v", fields)
|
||||
}
|
||||
}
|
||||
|
||||
// logSlowRequest 记录慢请求
|
||||
func (m *LoggerMiddleware) logSlowRequest(r *http.Request, requestID string, duration time.Duration) {
|
||||
fields := map[string]interface{}{
|
||||
"request_id": requestID,
|
||||
"method": r.Method,
|
||||
"path": r.URL.Path,
|
||||
"duration": duration.String(),
|
||||
"duration_ms": duration.Milliseconds(),
|
||||
}
|
||||
logx.WithContext(r.Context()).Infof("Slow request detected: %+v", fields)
|
||||
}
|
||||
|
||||
// logPanic 记录panic
|
||||
func (m *LoggerMiddleware) logPanic(r *http.Request, err interface{}, stack string) {
|
||||
fields := map[string]interface{}{
|
||||
"method": r.Method,
|
||||
"path": r.URL.Path,
|
||||
"error": err,
|
||||
"stack_trace": stack,
|
||||
}
|
||||
logx.WithContext(r.Context()).Errorf("Panic recovered: %+v", fields)
|
||||
}
|
||||
|
||||
// getClientIP 获取客户端IP
|
||||
func getClientIP(r *http.Request) string {
|
||||
// 检查各种可能的IP头部
|
||||
for _, header := range []string{"X-Forwarded-For", "X-Real-IP", "X-Client-IP"} {
|
||||
if ip := r.Header.Get(header); ip != "" {
|
||||
// X-Forwarded-For 可能包含多个IP,取第一个
|
||||
if strings.Contains(ip, ",") {
|
||||
return strings.TrimSpace(strings.Split(ip, ",")[0])
|
||||
}
|
||||
return ip
|
||||
}
|
||||
}
|
||||
|
||||
// 使用 RemoteAddr
|
||||
if ip := r.RemoteAddr; ip != "" {
|
||||
// 移除端口号
|
||||
if strings.Contains(ip, ":") {
|
||||
return strings.Split(ip, ":")[0]
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// truncateString 截断字符串
|
||||
func (m *LoggerMiddleware) truncateString(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return s[:maxLen] + "..."
|
||||
}
|
||||
|
||||
// randomString 生成随机字符串
|
||||
func randomString(length int) string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
result := make([]byte, length)
|
||||
for i := range result {
|
||||
result[i] = charset[time.Now().UnixNano()%int64(len(charset))]
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
Reference in New Issue
Block a user