diff --git a/backend/etc/photography-api.yaml b/backend/etc/photography-api.yaml index 6b55679..c939fc3 100644 --- a/backend/etc/photography-api.yaml +++ b/backend/etc/photography-api.yaml @@ -1,3 +1,43 @@ Name: photography-api Host: 0.0.0.0 -Port: 8888 +Port: 8080 + +# 数据库配置 +Database: + Driver: sqlite + FilePath: data/photography.db + Host: localhost + Port: 5432 + Database: photography + Username: postgres + Password: "" + Charset: utf8mb4 + SSLMode: disable + MaxOpenConns: 100 + MaxIdleConns: 10 + +# 认证配置 +Auth: + AccessSecret: photography-secret-key-2024 + AccessExpire: 86400 + +# 文件上传配置 +FileUpload: + MaxSize: 10485760 # 10MB + UploadDir: uploads + AllowedTypes: + - image/jpeg + - image/png + - image/gif + - image/webp + +# 中间件配置 +Middleware: + EnableCORS: true + EnableLogger: true + EnableErrorHandle: true + CORSOrigins: + - http://localhost:3000 + - http://localhost:3001 + - http://localhost:5173 + LogLevel: info diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index fba989c..deee050 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -10,6 +10,7 @@ type Config struct { Database database.Config `json:"database"` Auth AuthConfig `json:"auth"` FileUpload FileUploadConfig `json:"file_upload"` + Middleware MiddlewareConfig `json:"middleware"` } type AuthConfig struct { @@ -22,3 +23,11 @@ type FileUploadConfig struct { UploadDir string `json:"upload_dir"` AllowedTypes []string `json:"allowed_types"` } + +type MiddlewareConfig struct { + EnableCORS bool `json:"enable_cors"` + EnableLogger bool `json:"enable_logger"` + EnableErrorHandle bool `json:"enable_error_handle"` + CORSOrigins []string `json:"cors_origins"` + LogLevel string `json:"log_level"` +} diff --git a/backend/internal/middleware/cors.go b/backend/internal/middleware/cors.go new file mode 100644 index 0000000..20d6d90 --- /dev/null +++ b/backend/internal/middleware/cors.go @@ -0,0 +1,175 @@ +package middleware + +import ( + "net/http" + "strings" + "time" + + "github.com/zeromicro/go-zero/core/logx" +) + +// CORSConfig CORS 配置 +type CORSConfig struct { + AllowOrigins []string // 允许的来源 + AllowMethods []string // 允许的方法 + AllowHeaders []string // 允许的头部 + ExposeHeaders []string // 暴露的头部 + AllowCredentials bool // 是否允许携带凭证 + MaxAge time.Duration // 预检请求缓存时间 +} + +// DefaultCORSConfig 默认 CORS 配置 +func DefaultCORSConfig() CORSConfig { + return CORSConfig{ + AllowOrigins: []string{ + "http://localhost:3000", + "http://localhost:3001", + "http://localhost:5173", + "http://localhost:8080", + }, + AllowMethods: []string{ + "GET", + "POST", + "PUT", + "DELETE", + "OPTIONS", + "HEAD", + }, + AllowHeaders: []string{ + "Origin", + "Content-Type", + "Content-Length", + "Accept-Encoding", + "X-CSRF-Token", + "Authorization", + "accept", + "origin", + "Cache-Control", + "X-Requested-With", + }, + ExposeHeaders: []string{ + "Content-Length", + "Content-Type", + }, + AllowCredentials: true, + MaxAge: 12 * time.Hour, + } +} + +// ProductionCORSConfig 生产环境 CORS 配置 +func ProductionCORSConfig(allowedOrigins []string) CORSConfig { + config := DefaultCORSConfig() + if len(allowedOrigins) > 0 { + config.AllowOrigins = allowedOrigins + } else { + // 生产环境默认只允许 HTTPS + config.AllowOrigins = []string{ + "https://photography.iriver.top", + "https://admin.photography.iriver.top", + } + } + return config +} + +// CORSMiddleware CORS 中间件 +type CORSMiddleware struct { + config CORSConfig +} + +// NewCORSMiddleware 创建 CORS 中间件 +func NewCORSMiddleware(config CORSConfig) *CORSMiddleware { + return &CORSMiddleware{ + config: config, + } +} + +// Handle 处理 CORS +func (m *CORSMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + + // 检查来源是否被允许 + if origin != "" && m.isOriginAllowed(origin) { + w.Header().Set("Access-Control-Allow-Origin", origin) + } + + // 设置允许的方法 + if len(m.config.AllowMethods) > 0 { + w.Header().Set("Access-Control-Allow-Methods", strings.Join(m.config.AllowMethods, ", ")) + } + + // 设置允许的头部 + if len(m.config.AllowHeaders) > 0 { + w.Header().Set("Access-Control-Allow-Headers", strings.Join(m.config.AllowHeaders, ", ")) + } + + // 设置暴露的头部 + if len(m.config.ExposeHeaders) > 0 { + w.Header().Set("Access-Control-Expose-Headers", strings.Join(m.config.ExposeHeaders, ", ")) + } + + // 设置是否允许携带凭证 + if m.config.AllowCredentials { + w.Header().Set("Access-Control-Allow-Credentials", "true") + } + + // 设置预检请求缓存时间 + if m.config.MaxAge > 0 { + w.Header().Set("Access-Control-Max-Age", m.config.MaxAge.String()) + } + + // 添加安全头部 + m.setSecurityHeaders(w) + + // 处理预检请求 + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return + } + + // 记录跨域请求 + if origin != "" { + logx.Infof("CORS request from origin: %s, method: %s, path: %s", origin, r.Method, r.URL.Path) + } + + next(w, r) + }) +} + +// isOriginAllowed 检查来源是否被允许 +func (m *CORSMiddleware) isOriginAllowed(origin string) bool { + for _, allowedOrigin := range m.config.AllowOrigins { + if allowedOrigin == "*" { + return true + } + if allowedOrigin == origin { + return true + } + // 支持通配符匹配 + if strings.HasSuffix(allowedOrigin, "*") { + prefix := strings.TrimSuffix(allowedOrigin, "*") + if strings.HasPrefix(origin, prefix) { + return true + } + } + } + return false +} + +// setSecurityHeaders 设置安全头部 +func (m *CORSMiddleware) setSecurityHeaders(w http.ResponseWriter) { + // 防止点击劫持 + w.Header().Set("X-Frame-Options", "DENY") + + // 防止 MIME 类型嗅探 + w.Header().Set("X-Content-Type-Options", "nosniff") + + // XSS 保护 + w.Header().Set("X-XSS-Protection", "1; mode=block") + + // 引用者策略 + w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") + + // 内容安全策略 (基础版) + w.Header().Set("Content-Security-Policy", "default-src 'self'; img-src 'self' data: https:; style-src 'self' 'unsafe-inline'; script-src 'self'") +} \ No newline at end of file diff --git a/backend/internal/middleware/error.go b/backend/internal/middleware/error.go new file mode 100644 index 0000000..a1a328b --- /dev/null +++ b/backend/internal/middleware/error.go @@ -0,0 +1,325 @@ +package middleware + +import ( + "encoding/json" + "net/http" + "runtime/debug" + "strings" + "time" + + "photography-backend/pkg/errorx" + "photography-backend/pkg/response" + + "github.com/zeromicro/go-zero/core/logx" + "github.com/zeromicro/go-zero/rest/httpx" +) + +// ErrorConfig 错误处理配置 +type ErrorConfig struct { + EnableDetailedErrors bool // 是否启用详细错误信息 (开发环境) + EnableStackTrace bool // 是否启用堆栈跟踪 + EnableErrorMonitor bool // 是否启用错误监控 + IgnoreHTTPCodes []int // 忽略的HTTP状态码 (不记录为错误) + SensitiveFields []string // 敏感字段列表 (日志时隐藏) +} + +// DefaultErrorConfig 默认错误配置 +func DefaultErrorConfig() ErrorConfig { + return ErrorConfig{ + EnableDetailedErrors: false, // 生产环境默认关闭 + EnableStackTrace: false, // 生产环境默认关闭 + EnableErrorMonitor: true, + IgnoreHTTPCodes: []int{http.StatusNotFound, http.StatusMethodNotAllowed}, + SensitiveFields: []string{"password", "token", "secret", "key", "authorization"}, + } +} + +// DevelopmentErrorConfig 开发环境错误配置 +func DevelopmentErrorConfig() ErrorConfig { + config := DefaultErrorConfig() + config.EnableDetailedErrors = true + config.EnableStackTrace = true + return config +} + +// ErrorMiddleware 全局错误处理中间件 +type ErrorMiddleware struct { + config ErrorConfig +} + +// NewErrorMiddleware 创建错误处理中间件 +func NewErrorMiddleware(config ErrorConfig) *ErrorMiddleware { + return &ErrorMiddleware{ + config: config, + } +} + +// errorResponseWriter 包装 ResponseWriter 用于捕获错误 +type errorResponseWriter struct { + http.ResponseWriter + statusCode int + body []byte + written bool +} + +// newErrorResponseWriter 创建错误响应写入器 +func newErrorResponseWriter(w http.ResponseWriter) *errorResponseWriter { + return &errorResponseWriter{ + ResponseWriter: w, + statusCode: http.StatusOK, + } +} + +// Write 写入响应数据 +func (erw *errorResponseWriter) Write(b []byte) (int, error) { + if !erw.written { + erw.body = append(erw.body, b...) + } + return erw.ResponseWriter.Write(b) +} + +// WriteHeader 写入响应头 +func (erw *errorResponseWriter) WriteHeader(statusCode int) { + erw.statusCode = statusCode + erw.written = true + erw.ResponseWriter.WriteHeader(statusCode) +} + +// Handle 处理全局错误 +func (m *ErrorMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 包装响应写入器 + erw := newErrorResponseWriter(w) + + // 设置panic恢复 + defer func() { + if err := recover(); err != nil { + m.handlePanic(erw, r, err) + } + }() + + // 执行下一个处理器 + next(erw, r) + + // 检查是否需要处理错误 + if m.shouldHandleError(erw.statusCode) { + m.handleHTTPError(erw, r) + } + }) +} + +// handlePanic 处理panic +func (m *ErrorMiddleware) handlePanic(w *errorResponseWriter, r *http.Request, err interface{}) { + stack := string(debug.Stack()) + + // 记录panic日志 + logFields := map[string]interface{}{ + "error": err, + "method": r.Method, + "path": r.URL.Path, + "user_agent": r.UserAgent(), + "remote_ip": getClientIP(r), + "timestamp": time.Now().Format(time.RFC3339), + } + + if m.config.EnableStackTrace { + logFields["stack_trace"] = stack + } + + logx.WithContext(r.Context()).Errorf("Panic recovered: %+v", logFields) + + // 响应错误 + m.respondWithError(w.ResponseWriter, r, &errorx.CodeError{ + Code: errorx.ServerError, + Msg: "Internal Server Error", + }, map[string]interface{}{ + "error": err, + "stack": stack, + }) + + // 错误监控 + if m.config.EnableErrorMonitor { + m.reportError("panic", err, logFields) + } +} + +// handleHTTPError 处理HTTP错误 +func (m *ErrorMiddleware) handleHTTPError(w *errorResponseWriter, r *http.Request) { + // 记录HTTP错误 + logFields := map[string]interface{}{ + "status_code": w.statusCode, + "method": r.Method, + "path": r.URL.Path, + "user_agent": r.UserAgent(), + "remote_ip": getClientIP(r), + "timestamp": time.Now().Format(time.RFC3339), + } + + // 尝试解析响应体中的错误信息 + if len(w.body) > 0 { + var errorBody map[string]interface{} + if err := json.Unmarshal(w.body, &errorBody); err == nil { + // 隐藏敏感字段 + errorBody = m.sanitizeFields(errorBody) + logFields["response_body"] = errorBody + } + } + + // 根据状态码选择日志级别 + switch { + case w.statusCode >= 500: + logx.WithContext(r.Context()).Errorf("Server error occurred: %+v", logFields) + case w.statusCode >= 400: + logx.WithContext(r.Context()).Infof("Client error occurred: %+v", logFields) + } + + // 错误监控 + if m.config.EnableErrorMonitor { + m.reportError("http_error", w.statusCode, logFields) + } +} + +// shouldHandleError 检查是否应该处理此错误 +func (m *ErrorMiddleware) shouldHandleError(statusCode int) bool { + for _, ignoreCode := range m.config.IgnoreHTTPCodes { + if statusCode == ignoreCode { + return false + } + } + return statusCode >= 400 +} + +// respondWithError 响应错误 +func (m *ErrorMiddleware) respondWithError(w http.ResponseWriter, r *http.Request, err *errorx.CodeError, extra map[string]interface{}) { + // 构建响应体 + body := response.Body{ + Code: err.Code, + Message: err.Msg, + } + + // 添加详细错误信息 (仅开发环境) + if m.config.EnableDetailedErrors && extra != nil { + // 隐藏敏感信息 + extra = m.sanitizeFields(extra) + body.Data = extra + } + + // 设置HTTP状态码 + httpStatus := errorx.GetHttpStatus(err.Code) + + // 设置响应头 + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(httpStatus) + + // 写入响应 + httpx.WriteJson(w, httpStatus, body) +} + +// sanitizeFields 隐藏敏感字段 +func (m *ErrorMiddleware) sanitizeFields(data map[string]interface{}) map[string]interface{} { + sanitized := make(map[string]interface{}) + + for key, value := range data { + lowerKey := strings.ToLower(key) + + // 检查是否为敏感字段 + sensitive := false + for _, sensitiveField := range m.config.SensitiveFields { + if strings.Contains(lowerKey, strings.ToLower(sensitiveField)) { + sensitive = true + break + } + } + + if sensitive { + sanitized[key] = "***REDACTED***" + } else { + // 递归处理嵌套对象 + if nestedMap, ok := value.(map[string]interface{}); ok { + sanitized[key] = m.sanitizeFields(nestedMap) + } else { + sanitized[key] = value + } + } + } + + return sanitized +} + +// reportError 报告错误到监控系统 +func (m *ErrorMiddleware) reportError(errorType string, error interface{}, context map[string]interface{}) { + // 这里可以集成第三方监控服务 (如 Sentry, DataDog 等) + // 目前只记录到日志 + fields := map[string]interface{}{ + "error_type": errorType, + "error": error, + "context": context, + "timestamp": time.Now().Format(time.RFC3339), + } + logx.Infof("Error reported to monitoring system: %+v", fields) +} + +// ErrorResponse 标准化错误响应 +type ErrorResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Details map[string]interface{} `json:"details,omitempty"` + Timestamp string `json:"timestamp"` + RequestID string `json:"request_id,omitempty"` +} + +// NewErrorResponse 创建标准化错误响应 +func NewErrorResponse(code int, message string, details map[string]interface{}, requestID string) *ErrorResponse { + return &ErrorResponse{ + Code: code, + Message: message, + Details: details, + Timestamp: time.Now().Format(time.RFC3339), + RequestID: requestID, + } +} + +// CommonErrors 常用错误响应 +var CommonErrors = struct { + BadRequest *errorx.CodeError + Unauthorized *errorx.CodeError + Forbidden *errorx.CodeError + NotFound *errorx.CodeError + MethodNotAllowed *errorx.CodeError + InternalServerError *errorx.CodeError + ValidationFailed *errorx.CodeError + RateLimitExceeded *errorx.CodeError +}{ + BadRequest: &errorx.CodeError{ + Code: errorx.ParamError, + Msg: "Bad Request", + }, + Unauthorized: &errorx.CodeError{ + Code: errorx.AuthError, + Msg: "Unauthorized", + }, + Forbidden: &errorx.CodeError{ + Code: errorx.Forbidden, + Msg: "Forbidden", + }, + NotFound: &errorx.CodeError{ + Code: errorx.NotFound, + Msg: "Not Found", + }, + MethodNotAllowed: &errorx.CodeError{ + Code: 405, + Msg: "Method Not Allowed", + }, + InternalServerError: &errorx.CodeError{ + Code: errorx.ServerError, + Msg: "Internal Server Error", + }, + ValidationFailed: &errorx.CodeError{ + Code: errorx.ParamError, + Msg: "Validation Failed", + }, + RateLimitExceeded: &errorx.CodeError{ + Code: 429, + Msg: "Rate Limit Exceeded", + }, +} \ No newline at end of file diff --git a/backend/internal/middleware/logger.go b/backend/internal/middleware/logger.go new file mode 100644 index 0000000..288d1b0 --- /dev/null +++ b/backend/internal/middleware/logger.go @@ -0,0 +1,299 @@ +package middleware + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "runtime/debug" + "strings" + "time" + + "github.com/zeromicro/go-zero/core/logx" +) + +// LoggerConfig 日志配置 +type LoggerConfig struct { + EnableRequestBody bool // 是否记录请求体 + EnableResponseBody bool // 是否记录响应体 + MaxBodySize int64 // 最大记录的请求/响应体大小 + SkipPaths []string // 跳过记录的路径 + SlowRequestDuration time.Duration // 慢请求阈值 + EnablePanicRecover bool // 是否启用panic恢复 +} + +// DefaultLoggerConfig 默认日志配置 +func DefaultLoggerConfig() LoggerConfig { + return LoggerConfig{ + EnableRequestBody: false, // 默认不记录请求体 (可能包含敏感信息) + EnableResponseBody: false, // 默认不记录响应体 (减少日志量) + MaxBodySize: 1024, // 最大记录1KB + SkipPaths: []string{"/health", "/metrics", "/favicon.ico"}, + SlowRequestDuration: 1 * time.Second, + EnablePanicRecover: true, + } +} + +// LoggerMiddleware 日志中间件 +type LoggerMiddleware struct { + config LoggerConfig +} + +// NewLoggerMiddleware 创建日志中间件 +func NewLoggerMiddleware(config LoggerConfig) *LoggerMiddleware { + return &LoggerMiddleware{ + config: config, + } +} + +// responseWriter 包装 http.ResponseWriter 用于记录响应 +type responseWriter struct { + http.ResponseWriter + status int + size int64 + body *bytes.Buffer +} + +// newResponseWriter 创建响应写入器 +func newResponseWriter(w http.ResponseWriter) *responseWriter { + return &responseWriter{ + ResponseWriter: w, + status: http.StatusOK, + body: &bytes.Buffer{}, + } +} + +// Write 写入响应数据 +func (rw *responseWriter) Write(b []byte) (int, error) { + size, err := rw.ResponseWriter.Write(b) + rw.size += int64(size) + + // 记录响应体 (如果启用) + if rw.body.Len() < int(1024) { // 限制缓存大小 + rw.body.Write(b) + } + + return size, err +} + +// WriteHeader 写入响应头 +func (rw *responseWriter) WriteHeader(statusCode int) { + rw.status = statusCode + rw.ResponseWriter.WriteHeader(statusCode) +} + +// Handle 处理请求日志 +func (m *LoggerMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 检查是否跳过此路径 + if m.shouldSkipPath(r.URL.Path) { + next(w, r) + return + } + + // 记录请求开始时间 + start := time.Now() + + // 生成请求ID + requestID := m.generateRequestID(r) + ctx := context.WithValue(r.Context(), "request_id", requestID) + r = r.WithContext(ctx) + + // 添加请求ID到响应头 + w.Header().Set("X-Request-ID", requestID) + + // 包装响应写入器 + rw := newResponseWriter(w) + + // 记录请求体 (如果启用) + var requestBody string + if m.config.EnableRequestBody && r.Body != nil { + requestBody = m.readRequestBody(r) + } + + // 设置panic恢复 + if m.config.EnablePanicRecover { + defer func() { + if err := recover(); err != nil { + m.logPanic(r, err, string(debug.Stack())) + http.Error(rw, "Internal Server Error", http.StatusInternalServerError) + } + }() + } + + // 记录请求开始 + m.logRequestStart(r, requestID, requestBody) + + // 执行下一个处理器 + next(rw, r) + + // 计算处理时间 + duration := time.Since(start) + + // 记录响应体 (如果启用) + var responseBody string + if m.config.EnableResponseBody && rw.body.Len() > 0 { + responseBody = m.truncateString(rw.body.String(), int(m.config.MaxBodySize)) + } + + // 记录请求完成 + m.logRequestComplete(r, requestID, rw.status, rw.size, duration, responseBody) + + // 记录慢请求 + if duration > m.config.SlowRequestDuration { + m.logSlowRequest(r, requestID, duration) + } + }) +} + +// shouldSkipPath 检查是否应该跳过此路径 +func (m *LoggerMiddleware) shouldSkipPath(path string) bool { + for _, skipPath := range m.config.SkipPaths { + if strings.HasPrefix(path, skipPath) { + return true + } + } + return false +} + +// generateRequestID 生成请求ID +func (m *LoggerMiddleware) generateRequestID(r *http.Request) string { + // 优先使用现有的请求ID + if requestID := r.Header.Get("X-Request-ID"); requestID != "" { + return requestID + } + + // 生成新的请求ID + return fmt.Sprintf("%d-%s", time.Now().UnixNano(), randomString(8)) +} + +// readRequestBody 读取请求体 +func (m *LoggerMiddleware) readRequestBody(r *http.Request) string { + if r.Body == nil { + return "" + } + + // 读取请求体 + body, err := io.ReadAll(io.LimitReader(r.Body, m.config.MaxBodySize)) + if err != nil { + return fmt.Sprintf("error reading body: %v", err) + } + + // 恢复请求体 (以便后续处理器可以读取) + r.Body = io.NopCloser(bytes.NewBuffer(body)) + + return string(body) +} + +// logRequestStart 记录请求开始 +func (m *LoggerMiddleware) logRequestStart(r *http.Request, requestID, requestBody string) { + fields := map[string]interface{}{ + "request_id": requestID, + "method": r.Method, + "path": r.URL.Path, + "query": r.URL.RawQuery, + "user_agent": r.UserAgent(), + "remote_addr": getClientIP(r), + "content_length": r.ContentLength, + } + + if requestBody != "" { + fields["request_body"] = requestBody + } + + logx.WithContext(r.Context()).Infof("Request started: %+v", fields) +} + +// logRequestComplete 记录请求完成 +func (m *LoggerMiddleware) logRequestComplete(r *http.Request, requestID string, status int, size int64, duration time.Duration, responseBody string) { + fields := map[string]interface{}{ + "request_id": requestID, + "method": r.Method, + "path": r.URL.Path, + "status": status, + "response_size": size, + "duration_ms": duration.Milliseconds(), + "duration": duration.String(), + } + + if responseBody != "" { + fields["response_body"] = responseBody + } + + // 根据状态码选择日志级别 + switch { + case status >= 500: + logx.WithContext(r.Context()).Errorf("Request completed with server error: %+v", fields) + case status >= 400: + logx.WithContext(r.Context()).Infof("Request completed with client error: %+v", fields) + default: + logx.WithContext(r.Context()).Infof("Request completed: %+v", fields) + } +} + +// logSlowRequest 记录慢请求 +func (m *LoggerMiddleware) logSlowRequest(r *http.Request, requestID string, duration time.Duration) { + fields := map[string]interface{}{ + "request_id": requestID, + "method": r.Method, + "path": r.URL.Path, + "duration": duration.String(), + "duration_ms": duration.Milliseconds(), + } + logx.WithContext(r.Context()).Infof("Slow request detected: %+v", fields) +} + +// logPanic 记录panic +func (m *LoggerMiddleware) logPanic(r *http.Request, err interface{}, stack string) { + fields := map[string]interface{}{ + "method": r.Method, + "path": r.URL.Path, + "error": err, + "stack_trace": stack, + } + logx.WithContext(r.Context()).Errorf("Panic recovered: %+v", fields) +} + +// getClientIP 获取客户端IP +func getClientIP(r *http.Request) string { + // 检查各种可能的IP头部 + for _, header := range []string{"X-Forwarded-For", "X-Real-IP", "X-Client-IP"} { + if ip := r.Header.Get(header); ip != "" { + // X-Forwarded-For 可能包含多个IP,取第一个 + if strings.Contains(ip, ",") { + return strings.TrimSpace(strings.Split(ip, ",")[0]) + } + return ip + } + } + + // 使用 RemoteAddr + if ip := r.RemoteAddr; ip != "" { + // 移除端口号 + if strings.Contains(ip, ":") { + return strings.Split(ip, ":")[0] + } + return ip + } + + return "unknown" +} + +// truncateString 截断字符串 +func (m *LoggerMiddleware) truncateString(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + "..." +} + +// randomString 生成随机字符串 +func randomString(length int) string { + const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + result := make([]byte, length) + for i := range result { + result[i] = charset[time.Now().UnixNano()%int64(len(charset))] + } + return string(result) +} \ No newline at end of file diff --git a/backend/internal/middleware/middleware.go b/backend/internal/middleware/middleware.go new file mode 100644 index 0000000..ecc2fba --- /dev/null +++ b/backend/internal/middleware/middleware.go @@ -0,0 +1,203 @@ +package middleware + +import ( + "net/http" + "os" + "strings" + "time" + + "photography-backend/internal/config" + + "github.com/zeromicro/go-zero/core/logx" +) + +// MiddlewareManager 中间件管理器 +type MiddlewareManager struct { + config config.Config + corsMiddleware *CORSMiddleware + logMiddleware *LoggerMiddleware + errorMiddleware *ErrorMiddleware + authMiddleware *AuthMiddleware +} + +// NewMiddlewareManager 创建中间件管理器 +func NewMiddlewareManager(c config.Config) *MiddlewareManager { + return &MiddlewareManager{ + config: c, + corsMiddleware: NewCORSMiddleware(getCORSConfig(c)), + logMiddleware: NewLoggerMiddleware(getLoggerConfig(c)), + errorMiddleware: NewErrorMiddleware(getErrorConfig(c)), + authMiddleware: NewAuthMiddleware(c.Auth.AccessSecret), + } +} + +// getCORSConfig 获取CORS配置 +func getCORSConfig(c config.Config) CORSConfig { + env := getEnvironment() + + if env == "production" { + // 生产环境使用严格的CORS配置 + return ProductionCORSConfig(getProductionOrigins()) + } + + // 开发环境使用宽松的CORS配置 + return DefaultCORSConfig() +} + +// getLoggerConfig 获取日志配置 +func getLoggerConfig(c config.Config) LoggerConfig { + env := getEnvironment() + + config := DefaultLoggerConfig() + + if env == "development" { + // 开发环境启用详细日志 + config.EnableRequestBody = true + config.EnableResponseBody = true + config.MaxBodySize = 4096 + } + + return config +} + +// getErrorConfig 获取错误配置 +func getErrorConfig(c config.Config) ErrorConfig { + env := getEnvironment() + + if env == "development" { + return DevelopmentErrorConfig() + } + + return DefaultErrorConfig() +} + +// getEnvironment 获取环境变量 +func getEnvironment() string { + env := os.Getenv("GO_ENV") + if env == "" { + env = os.Getenv("ENV") + } + if env == "" { + // 默认为开发环境 + env = "development" + } + return strings.ToLower(env) +} + +// getProductionOrigins 获取生产环境允许的来源 +func getProductionOrigins() []string { + origins := os.Getenv("CORS_ALLOWED_ORIGINS") + if origins == "" { + return []string{ + "https://photography.iriver.top", + "https://admin.photography.iriver.top", + } + } + return strings.Split(origins, ",") +} + +// Chain 链式中间件 +func (m *MiddlewareManager) Chain(handler http.HandlerFunc, middlewares ...func(http.HandlerFunc) http.HandlerFunc) http.HandlerFunc { + // 从右到左包装中间件 + for i := len(middlewares) - 1; i >= 0; i-- { + handler = middlewares[i](handler) + } + return handler +} + +// GetGlobalMiddlewares 获取全局中间件 +func (m *MiddlewareManager) GetGlobalMiddlewares() []func(http.HandlerFunc) http.HandlerFunc { + return []func(http.HandlerFunc) http.HandlerFunc{ + m.errorMiddleware.Handle, // 错误处理 (最外层) + m.corsMiddleware.Handle, // CORS 处理 + m.logMiddleware.Handle, // 日志记录 + } +} + +// GetAuthMiddleware 获取认证中间件 +func (m *MiddlewareManager) GetAuthMiddleware() func(http.HandlerFunc) http.HandlerFunc { + return m.authMiddleware.Handle +} + +// ApplyGlobalMiddlewares 应用全局中间件 +func (m *MiddlewareManager) ApplyGlobalMiddlewares(handler http.HandlerFunc) http.HandlerFunc { + return m.Chain(handler, m.GetGlobalMiddlewares()...) +} + +// ApplyAuthMiddlewares 应用认证中间件 +func (m *MiddlewareManager) ApplyAuthMiddlewares(handler http.HandlerFunc) http.HandlerFunc { + middlewares := append(m.GetGlobalMiddlewares(), m.GetAuthMiddleware()) + return m.Chain(handler, middlewares...) +} + +// HealthCheck 健康检查处理器 (不使用中间件) +func (m *MiddlewareManager) HealthCheck(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status":"ok","timestamp":"` + + time.Now().Format("2006-01-02T15:04:05Z07:00") + `"}`)) +} + +// Middleware 中间件接口 +type Middleware interface { + Handle(next http.HandlerFunc) http.HandlerFunc +} + +// MiddlewareFunc 中间件函数类型 +type MiddlewareFunc func(next http.HandlerFunc) http.HandlerFunc + +// Handle 实现Middleware接口 +func (f MiddlewareFunc) Handle(next http.HandlerFunc) http.HandlerFunc { + return f(next) +} + +// Use 应用中间件到处理器 +func Use(handler http.HandlerFunc, middlewares ...Middleware) http.HandlerFunc { + for i := len(middlewares) - 1; i >= 0; i-- { + handler = middlewares[i].Handle(handler) + } + return handler +} + +// Recovery 通用恢复中间件 +func Recovery() MiddlewareFunc { + return func(next http.HandlerFunc) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if err := recover(); err != nil { + fields := map[string]interface{}{ + "error": err, + "method": r.Method, + "path": r.URL.Path, + } + logx.WithContext(r.Context()).Errorf("Panic recovered in Recovery middleware: %+v", fields) + + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } + }() + next(w, r) + }) + } +} + +// RequestID 请求ID中间件 +func RequestID() MiddlewareFunc { + return func(next http.HandlerFunc) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestID := r.Header.Get("X-Request-ID") + if requestID == "" { + requestID = generateRequestID() + } + + w.Header().Set("X-Request-ID", requestID) + r.Header.Set("X-Request-ID", requestID) + + next(w, r) + }) + } +} + +// generateRequestID 生成请求ID +func generateRequestID() string { + return randomString(16) +} \ No newline at end of file diff --git a/backend/internal/svc/servicecontext.go b/backend/internal/svc/servicecontext.go index f7851a1..ae886f7 100644 --- a/backend/internal/svc/servicecontext.go +++ b/backend/internal/svc/servicecontext.go @@ -4,6 +4,7 @@ import ( "fmt" "gorm.io/gorm" "photography-backend/internal/config" + "photography-backend/internal/middleware" "photography-backend/internal/model" "photography-backend/pkg/utils/database" "github.com/zeromicro/go-zero/core/stores/sqlx" @@ -15,6 +16,7 @@ type ServiceContext struct { UserModel model.UserModel PhotoModel model.PhotoModel CategoryModel model.CategoryModel + Middleware *middleware.MiddlewareManager } func NewServiceContext(c config.Config) *ServiceContext { @@ -32,6 +34,7 @@ func NewServiceContext(c config.Config) *ServiceContext { UserModel: model.NewUserModel(sqlxConn), PhotoModel: model.NewPhotoModel(sqlxConn), CategoryModel: model.NewCategoryModel(sqlxConn), + Middleware: middleware.NewMiddlewareManager(c), } } diff --git a/backend/test_middleware.http b/backend/test_middleware.http new file mode 100644 index 0000000..e154b59 --- /dev/null +++ b/backend/test_middleware.http @@ -0,0 +1,63 @@ +### 1. 测试健康检查 - 不包含中间件 +GET http://localhost:8080/health +Content-Type: application/json + +### 2. 测试CORS预检请求 +OPTIONS http://localhost:8080/api/v1/photos +Origin: http://localhost:3000 +Access-Control-Request-Method: GET +Access-Control-Request-Headers: Authorization, Content-Type + +### 3. 测试CORS跨域请求 +GET http://localhost:8080/api/v1/photos +Origin: http://localhost:3000 +Authorization: Bearer invalid-token + +### 4. 测试日志记录 - GET请求 +GET http://localhost:8080/api/v1/photos +Content-Type: application/json + +### 5. 测试日志记录 - POST请求带请求体 +POST http://localhost:8080/api/v1/auth/login +Content-Type: application/json + +{ + "username": "test", + "password": "password" +} + +### 6. 测试错误处理 - 404错误 +GET http://localhost:8080/api/v1/nonexistent +Content-Type: application/json + +### 7. 测试错误处理 - 认证错误 +GET http://localhost:8080/api/v1/photos +Authorization: Bearer invalid-token +Content-Type: application/json + +### 8. 测试错误处理 - 参数错误 +POST http://localhost:8080/api/v1/auth/login +Content-Type: application/json + +{ + "invalid": "data" +} + +### 9. 测试慢请求记录 (如果有延迟接口) +GET http://localhost:8080/api/v1/photos?delay=2000 +Content-Type: application/json + +### 10. 测试请求ID传递 +GET http://localhost:8080/api/v1/photos +X-Request-ID: test-request-id-12345 +Content-Type: application/json + +### 11. 测试安全头部 +GET http://localhost:8080/api/v1/photos +Origin: http://localhost:3000 +Content-Type: application/json + +### 12. 测试不同来源的CORS +GET http://localhost:8080/api/v1/photos +Origin: http://malicious-site.com +Content-Type: application/json \ No newline at end of file