package config import ( "fmt" "time" "github.com/spf13/viper" ) // Config 应用配置 type Config struct { App AppConfig `mapstructure:"app"` Database DatabaseConfig `mapstructure:"database"` Redis RedisConfig `mapstructure:"redis"` JWT JWTConfig `mapstructure:"jwt"` Storage StorageConfig `mapstructure:"storage"` Upload UploadConfig `mapstructure:"upload"` Logger LoggerConfig `mapstructure:"logger"` CORS CORSConfig `mapstructure:"cors"` RateLimit RateLimitConfig `mapstructure:"rate_limit"` } // AppConfig 应用配置 type AppConfig struct { Name string `mapstructure:"name"` Version string `mapstructure:"version"` Environment string `mapstructure:"environment"` Port int `mapstructure:"port"` Debug bool `mapstructure:"debug"` } // DatabaseConfig 数据库配置 type DatabaseConfig struct { Host string `mapstructure:"host"` Port int `mapstructure:"port"` Username string `mapstructure:"username"` Password string `mapstructure:"password"` Database string `mapstructure:"database"` SSLMode string `mapstructure:"ssl_mode"` MaxOpenConns int `mapstructure:"max_open_conns"` MaxIdleConns int `mapstructure:"max_idle_conns"` ConnMaxLifetime int `mapstructure:"conn_max_lifetime"` } // RedisConfig Redis配置 type RedisConfig struct { Host string `mapstructure:"host"` Port int `mapstructure:"port"` Password string `mapstructure:"password"` Database int `mapstructure:"database"` PoolSize int `mapstructure:"pool_size"` MinIdleConns int `mapstructure:"min_idle_conns"` } // JWTConfig JWT配置 type JWTConfig struct { Secret string `mapstructure:"secret"` ExpiresIn string `mapstructure:"expires_in"` RefreshExpiresIn string `mapstructure:"refresh_expires_in"` } // StorageConfig 存储配置 type StorageConfig struct { Type string `mapstructure:"type"` Local LocalConfig `mapstructure:"local"` S3 S3Config `mapstructure:"s3"` } // LocalConfig 本地存储配置 type LocalConfig struct { BasePath string `mapstructure:"base_path"` BaseURL string `mapstructure:"base_url"` } // S3Config S3存储配置 type S3Config struct { Region string `mapstructure:"region"` Bucket string `mapstructure:"bucket"` AccessKey string `mapstructure:"access_key"` SecretKey string `mapstructure:"secret_key"` Endpoint string `mapstructure:"endpoint"` } // UploadConfig 上传配置 type UploadConfig struct { MaxFileSize int64 `mapstructure:"max_file_size"` AllowedTypes []string `mapstructure:"allowed_types"` ThumbnailSizes []ThumbnailSize `mapstructure:"thumbnail_sizes"` } // ThumbnailSize 缩略图尺寸 type ThumbnailSize struct { Name string `mapstructure:"name"` Width int `mapstructure:"width"` Height int `mapstructure:"height"` } // LoggerConfig 日志配置 type LoggerConfig struct { Level string `mapstructure:"level"` Format string `mapstructure:"format"` Output string `mapstructure:"output"` Filename string `mapstructure:"filename"` MaxSize int `mapstructure:"max_size"` MaxAge int `mapstructure:"max_age"` Compress bool `mapstructure:"compress"` } // CORSConfig CORS配置 type CORSConfig struct { AllowedOrigins []string `mapstructure:"allowed_origins"` AllowedMethods []string `mapstructure:"allowed_methods"` AllowedHeaders []string `mapstructure:"allowed_headers"` AllowCredentials bool `mapstructure:"allow_credentials"` } // RateLimitConfig 限流配置 type RateLimitConfig struct { Enabled bool `mapstructure:"enabled"` RequestsPerMinute int `mapstructure:"requests_per_minute"` Burst int `mapstructure:"burst"` } var AppConfig *Config // LoadConfig 加载配置 func LoadConfig(configPath string) (*Config, error) { viper.SetConfigFile(configPath) viper.SetConfigType("yaml") // 设置环境变量前缀 viper.SetEnvPrefix("PHOTOGRAPHY") viper.AutomaticEnv() // 环境变量替换配置 viper.BindEnv("database.host", "DB_HOST") viper.BindEnv("database.port", "DB_PORT") viper.BindEnv("database.username", "DB_USER") viper.BindEnv("database.password", "DB_PASSWORD") viper.BindEnv("database.database", "DB_NAME") viper.BindEnv("redis.host", "REDIS_HOST") viper.BindEnv("redis.port", "REDIS_PORT") viper.BindEnv("redis.password", "REDIS_PASSWORD") viper.BindEnv("jwt.secret", "JWT_SECRET") viper.BindEnv("storage.s3.access_key", "AWS_ACCESS_KEY_ID") viper.BindEnv("storage.s3.secret_key", "AWS_SECRET_ACCESS_KEY") viper.BindEnv("app.port", "PORT") if err := viper.ReadInConfig(); err != nil { return nil, fmt.Errorf("failed to read config file: %w", err) } var config Config if err := viper.Unmarshal(&config); err != nil { return nil, fmt.Errorf("failed to unmarshal config: %w", err) } // 验证配置 if err := validateConfig(&config); err != nil { return nil, fmt.Errorf("config validation failed: %w", err) } AppConfig = &config return &config, nil } // validateConfig 验证配置 func validateConfig(config *Config) error { if config.App.Name == "" { return fmt.Errorf("app name is required") } if config.Database.Host == "" { return fmt.Errorf("database host is required") } if config.JWT.Secret == "" { return fmt.Errorf("jwt secret is required") } return nil } // GetJWTExpiration 获取JWT过期时间 func (c *Config) GetJWTExpiration() time.Duration { duration, err := time.ParseDuration(c.JWT.ExpiresIn) if err != nil { return 24 * time.Hour // 默认24小时 } return duration } // GetJWTRefreshExpiration 获取JWT刷新过期时间 func (c *Config) GetJWTRefreshExpiration() time.Duration { duration, err := time.ParseDuration(c.JWT.RefreshExpiresIn) if err != nil { return 7 * 24 * time.Hour // 默认7天 } return duration } // GetDatabaseDSN 获取数据库DSN func (c *Config) GetDatabaseDSN() string { return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", c.Database.Host, c.Database.Port, c.Database.Username, c.Database.Password, c.Database.Database, c.Database.SSLMode, ) } // GetRedisAddr 获取Redis地址 func (c *Config) GetRedisAddr() string { return fmt.Sprintf("%s:%d", c.Redis.Host, c.Redis.Port) } // GetServerAddr 获取服务器地址 func (c *Config) GetServerAddr() string { return fmt.Sprintf(":%d", c.App.Port) } // IsDevelopment 是否为开发环境 func (c *Config) IsDevelopment() bool { return c.App.Environment == "development" } // IsProduction 是否为生产环境 func (c *Config) IsProduction() bool { return c.App.Environment == "production" }