package middleware

import (
	"net/http"
	"strconv"
	"strings"
	"time"

	"github.com/zeromicro/go-zero/core/logx"
)

// CORSConfig holds the Cross-Origin Resource Sharing settings applied by
// CORSMiddleware.
type CORSConfig struct {
	AllowOrigins     []string      // Allowed origins; "*" matches any origin, a trailing "*" is a prefix match.
	AllowMethods     []string      // Methods advertised in Access-Control-Allow-Methods.
	AllowHeaders     []string      // Request headers advertised in Access-Control-Allow-Headers.
	ExposeHeaders    []string      // Response headers advertised in Access-Control-Expose-Headers.
	AllowCredentials bool          // Whether to send Access-Control-Allow-Credentials: true.
	MaxAge           time.Duration // How long a browser may cache a preflight result (sent in whole seconds).
}

// DefaultCORSConfig returns a development-oriented CORS configuration that
// allows a fixed set of http://localhost origins, with credentials enabled
// and a 12 hour preflight cache.
func DefaultCORSConfig() CORSConfig {
	return CORSConfig{
		AllowOrigins: []string{
			"http://localhost:3000",
			"http://localhost:3001",
			"http://localhost:5173",
			"http://localhost:8080",
		},
		AllowMethods: []string{
			"GET",
			"POST",
			"PUT",
			"DELETE",
			"OPTIONS",
			"HEAD",
		},
		AllowHeaders: []string{
			"Origin",
			"Content-Type",
			"Content-Length",
			"Accept-Encoding",
			"X-CSRF-Token",
			"Authorization",
			"accept",
			"origin",
			"Cache-Control",
			"X-Requested-With",
		},
		ExposeHeaders: []string{
			"Content-Length",
			"Content-Type",
		},
		AllowCredentials: true,
		MaxAge:           12 * time.Hour,
	}
}

// ProductionCORSConfig returns DefaultCORSConfig with its origin list
// replaced. If allowedOrigins is non-empty, a copy of it is used (so later
// mutation of the caller's slice does not change the middleware's policy);
// otherwise a built-in list of HTTPS origins is used.
func ProductionCORSConfig(allowedOrigins []string) CORSConfig {
	config := DefaultCORSConfig()
	if len(allowedOrigins) > 0 {
		config.AllowOrigins = append([]string(nil), allowedOrigins...)
		return config
	}

	// Production defaults only allow HTTPS origins.
	config.AllowOrigins = []string{
		"https://photography.iriver.top",
		"https://admin.photography.iriver.top",
	}
	return config
}

// CORSMiddleware applies a CORSConfig, plus a fixed set of security
// headers, to every request it wraps.
type CORSMiddleware struct {
	config CORSConfig
}

// NewCORSMiddleware creates a CORSMiddleware that enforces config.
func NewCORSMiddleware(config CORSConfig) *CORSMiddleware {
	return &CORSMiddleware{
		config: config,
	}
}

// Handle wraps next so that every response carries the configured CORS
// headers and the security headers from setSecurityHeaders. Any OPTIONS
// request is answered directly with 200 OK and never reaches next.
// Non-OPTIONS requests carrying an Origin header are logged at info level.
func (m *CORSMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		origin := r.Header.Get("Origin")

		m.setCORSHeaders(w, origin)
		m.setSecurityHeaders(w)

		// Preflight: the headers written above are the complete answer.
		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusOK)
			return
		}

		// Log cross-origin requests.
		if origin != "" {
			logx.Infof("CORS request from origin: %s, method: %s, path: %s", origin, r.Method, r.URL.Path)
		}

		next(w, r)
	}
}

// setCORSHeaders writes the Access-Control-* response headers for a request
// whose Origin header value is origin (empty when the header was absent).
func (m *CORSMiddleware) setCORSHeaders(w http.ResponseWriter, origin string) {
	h := w.Header()

	// Access-Control-Allow-Origin depends on the request's Origin, so caches
	// must key on it; otherwise a response cached for one origin (or for a
	// same-origin request) could be replayed to another origin. Add rather
	// than Set so any Vary values set by outer middleware are kept.
	h.Add("Vary", "Origin")

	// The allowed origin is echoed back instead of sending "*", because
	// browsers reject "*" when credentials are allowed.
	if origin != "" && m.isOriginAllowed(origin) {
		h.Set("Access-Control-Allow-Origin", origin)
	}

	if len(m.config.AllowMethods) > 0 {
		h.Set("Access-Control-Allow-Methods", strings.Join(m.config.AllowMethods, ", "))
	}

	if len(m.config.AllowHeaders) > 0 {
		h.Set("Access-Control-Allow-Headers", strings.Join(m.config.AllowHeaders, ", "))
	}

	if len(m.config.ExposeHeaders) > 0 {
		h.Set("Access-Control-Expose-Headers", strings.Join(m.config.ExposeHeaders, ", "))
	}

	if m.config.AllowCredentials {
		h.Set("Access-Control-Allow-Credentials", "true")
	}

	if m.config.MaxAge > 0 {
		// Access-Control-Max-Age must be an integer number of seconds;
		// Duration.String() (e.g. "12h0m0s") is not a valid value and
		// browsers would ignore it, disabling preflight caching.
		h.Set("Access-Control-Max-Age", strconv.FormatInt(int64(m.config.MaxAge/time.Second), 10))
	}
}

// isOriginAllowed reports whether origin matches an entry in AllowOrigins.
// An entry matches if it is "*", equals origin exactly, or ends in "*" and
// the part before the "*" is a prefix of origin.
//
// NOTE(review): because the matched origin is reflected together with
// Access-Control-Allow-Credentials, a "*" entry grants credentialed access
// to every site. Prefix patterns are plain string prefixes, so
// "https://example.com*" also matches "https://example.com.evil.net";
// prefer patterns that end at a delimiter such as "http://localhost:*".
func (m *CORSMiddleware) isOriginAllowed(origin string) bool {
	for _, allowedOrigin := range m.config.AllowOrigins {
		if allowedOrigin == "*" || allowedOrigin == origin {
			return true
		}

		// Trailing-wildcard prefix match.
		if prefix, ok := strings.CutSuffix(allowedOrigin, "*"); ok && strings.HasPrefix(origin, prefix) {
			return true
		}
	}
	return false
}

// setSecurityHeaders sets a fixed set of browser security headers on every
// response passing through the middleware.
func (m *CORSMiddleware) setSecurityHeaders(w http.ResponseWriter) {
	// Prevent clickjacking by forbidding framing.
	w.Header().Set("X-Frame-Options", "DENY")
	// Prevent MIME type sniffing.
	w.Header().Set("X-Content-Type-Options", "nosniff")
	// Legacy XSS filter. NOTE(review): this header is deprecated and current
	// guidance (MDN/OWASP) is to send "0" or omit it and rely on CSP.
	w.Header().Set("X-XSS-Protection", "1; mode=block")
	// Referrer policy.
	w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
	// Baseline content security policy.
	w.Header().Set("Content-Security-Policy", "default-src 'self'; img-src 'self' data: https:; style-src 'self' 'unsafe-inline'; script-src 'self'")
}