diff --git a/.env.example b/.env.example index 50ad92dd..dc269f1b 100644 --- a/.env.example +++ b/.env.example @@ -13,6 +13,9 @@ REDIS_HOST=redis REDIS_PORT=6379 REDIS_PASSWORD=redis123456 +# 数据加密密钥 +DATA_ENCRYPTION_KEY=my_secret_encryption_key + # Ports Configuration # Backend API server port (internal: 8080, external: configurable) NOFX_BACKEND_PORT=8080 diff --git a/.gitignore b/.gitignore index 9f3bdd5d..a5c1c3c3 100644 --- a/.gitignore +++ b/.gitignore @@ -35,6 +35,11 @@ config.db certs/ beta_codes.txt +# 密钥文件 +keys/ +*.key +*.pem + # 决策日志 decision_logs/ coin_pool_cache/ diff --git a/api/server.go b/api/server.go index a10a39f6..66426706 100644 --- a/api/server.go +++ b/api/server.go @@ -8,6 +8,7 @@ import ( "net/http" "nofx/auth" "nofx/config" + "nofx/crypto" "nofx/decision" "nofx/manager" "nofx/trader" @@ -24,11 +25,12 @@ type Server struct { router *gin.Engine traderManager *manager.TraderManager database config.DatabaseInterface + cryptoService *crypto.CryptoService port int } // NewServer 创建API服务器 -func NewServer(traderManager *manager.TraderManager, database config.DatabaseInterface, port int) *Server { +func NewServer(traderManager *manager.TraderManager, database config.DatabaseInterface, cryptoService *crypto.CryptoService, port int) *Server { // 设置为Release模式(减少日志输出) gin.SetMode(gin.ReleaseMode) @@ -37,10 +39,17 @@ func NewServer(traderManager *manager.TraderManager, database config.DatabaseInt // 启用CORS router.Use(corsMiddleware()) + if cryptoService == nil { + log.Printf("⚠️ 加密服务未初始化,敏感数据加解密功能不可用") + } else { + database.SetCryptoService(cryptoService) + } + s := &Server{ router: router, traderManager: traderManager, database: database, + cryptoService: cryptoService, port: port, } @@ -123,6 +132,7 @@ func (s *Server) setupRoutes() { // 交易所配置 protected.GET("/exchanges", s.handleGetExchangeConfigs) protected.PUT("/exchanges", s.handleUpdateExchangeConfigs) + protected.PUT("/exchanges/encrypted", s.handleUpdateExchangeConfigsEncrypted) // 用户信号源配置 protected.GET("/user/signal-sources", s.handleGetUserSignalSource) @@ -179,11 +189,19 @@ func (s *Server) handleGetSystemConfig(c *gin.Context) { betaModeStr, _ := s.database.GetSystemConfig("beta_mode") betaMode := betaModeStr == "true" + // 获取RSA公钥 + var rsaPublicKey string + if s.cryptoService != nil { + rsaPublicKey = s.cryptoService.GetPublicKeyPEM() + } + c.JSON(http.StatusOK, gin.H{ "beta_mode": betaMode, "default_coins": defaultCoins, "btc_eth_leverage": btcEthLeverage, "altcoin_leverage": altcoinLeverage, + "rsa_public_key": rsaPublicKey, + "rsa_key_id": "rsa-key-2025-11-05", }) } @@ -381,6 +399,21 @@ type ExchangeConfig struct { Testnet bool `json:"testnet,omitempty"` } +// SafeExchangeConfig 安全的交易所配置响应结构(不包含敏感信息) +type SafeExchangeConfig struct { + ID string `json:"id"` + UserID string `json:"user_id"` + Name string `json:"name"` + Type string `json:"type"` + Enabled bool `json:"enabled"` + Testnet bool `json:"testnet"` + HyperliquidWalletAddr string `json:"hyperliquidWalletAddr"` // 钱包地址,非敏感信息 + AsterUser string `json:"asterUser"` // Aster用户名,非敏感信息 + Deleted bool `json:"deleted"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + type UpdateModelConfigRequest struct { Models map[string]struct { Enabled bool `json:"enabled"` @@ -1005,7 +1038,25 @@ func (s *Server) handleGetExchangeConfigs(c *gin.Context) { } log.Printf("✅ 找到 %d 个交易所配置", len(exchanges)) - c.JSON(http.StatusOK, exchanges) + // 转换为安全的响应结构,过滤敏感信息 + safeExchanges := make([]SafeExchangeConfig, len(exchanges)) + for i, exchange := range exchanges { + safeExchanges[i] = SafeExchangeConfig{ + ID: exchange.ID, + UserID: exchange.UserID, + Name: exchange.Name, + Type: exchange.Type, + Enabled: exchange.Enabled, + Testnet: exchange.Testnet, + HyperliquidWalletAddr: exchange.HyperliquidWalletAddr, // 钱包地址,非敏感信息 + AsterUser: exchange.AsterUser, // Aster用户名,非敏感信息 + Deleted: exchange.Deleted, + CreatedAt: exchange.CreatedAt, + UpdatedAt: exchange.UpdatedAt, + } + } + + c.JSON(http.StatusOK, safeExchanges) } // handleUpdateExchangeConfigs 更新交易所配置 @@ -1638,8 +1689,10 @@ func (s *Server) handleCompleteRegistration(c *gin.Context) { // handleLogin 处理用户登录请求 func (s *Server) handleLogin(c *gin.Context) { var req struct { - Email string `json:"email" binding:"required,email"` - Password string `json:"password" binding:"required"` + Email string `json:"email"` + EmailEncrypted *crypto.EncryptedPayload `json:"email_encrypted"` + Password string `json:"password"` + PasswordEncrypted *crypto.EncryptedPayload `json:"password_encrypted"` } if err := c.ShouldBindJSON(&req); err != nil { @@ -1647,6 +1700,51 @@ func (s *Server) handleLogin(c *gin.Context) { return } + if req.EmailEncrypted != nil { + if s.cryptoService == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "加密服务不可用"}) + return + } + + decryptedEmail, err := s.cryptoService.DecryptSensitiveData(req.EmailEncrypted) + if err != nil { + log.Printf("❌ 登录邮箱解密失败: %v", err) + c.JSON(http.StatusBadRequest, gin.H{"error": "邮箱解密失败"}) + return + } + req.Email = decryptedEmail + } + + if req.PasswordEncrypted != nil { + if s.cryptoService == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "加密服务不可用"}) + return + } + + decryptedPassword, err := s.cryptoService.DecryptSensitiveData(req.PasswordEncrypted) + if err != nil { + log.Printf("❌ 登录密码解密失败: %v", err) + c.JSON(http.StatusBadRequest, gin.H{"error": "密码解密失败"}) + return + } + req.Password = decryptedPassword + } + + req.Email = strings.TrimSpace(req.Email) + if req.Email == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "邮箱不能为空"}) + return + } + if !strings.Contains(req.Email, "@") { + c.JSON(http.StatusBadRequest, gin.H{"error": "邮箱格式错误"}) + return + } + + if strings.TrimSpace(req.Password) == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "密码不能为空"}) + return + } + // 获取用户信息 user, err := s.database.GetUserByEmail(req.Email) if err != nil { @@ -2026,3 +2124,64 @@ func (s *Server) handleGetPublicTraderConfig(c *gin.Context) { c.JSON(http.StatusOK, result) } + +// handleUpdateExchangeConfigsEncrypted 更新交易所配置(加密传输) +func (s *Server) handleUpdateExchangeConfigsEncrypted(c *gin.Context) { + if s.cryptoService == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "加密服务不可用"}) + return + } + + userID := c.GetString("user_id") + + // 接收加密载荷 + var payload crypto.EncryptedPayload + if err := c.ShouldBindJSON(&payload); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // 解密数据 + decryptedData, err := s.cryptoService.DecryptSensitiveData(&payload) + if err != nil { + log.Printf("❌ 解密失败: %v", err) + c.JSON(http.StatusBadRequest, gin.H{"error": "解密失败"}) + return + } + + // 解析解密后的数据 + var req UpdateExchangeConfigRequest + if err := json.Unmarshal([]byte(decryptedData), &req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "数据格式错误"}) + return + } + + // 更新每个交易所的配置 + for exchangeID, exchangeData := range req.Exchanges { + err := s.database.UpdateExchange( + userID, + exchangeID, + exchangeData.Enabled, + exchangeData.APIKey, + exchangeData.SecretKey, + exchangeData.Testnet, + exchangeData.HyperliquidWalletAddr, + exchangeData.AsterUser, + exchangeData.AsterSigner, + exchangeData.AsterPrivateKey, + ) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("更新交易所 %s 失败: %v", exchangeID, err)}) + return + } + } + + // 重新加载该用户的所有交易员,使新配置立即生效 + err = s.traderManager.LoadUserTraders(s.database, userID) + if err != nil { + log.Printf("⚠️ 重新加载用户交易员到内存失败: %v", err) + } + + log.Printf("✓ 交易所配置已通过加密方式更新") + c.JSON(http.StatusOK, gin.H{"message": "交易所配置已更新"}) +} diff --git a/config.json.example b/config.json.example index 6a169ff4..ccdccca1 100644 --- a/config.json.example +++ b/config.json.example @@ -22,5 +22,20 @@ "jwt_secret": "Qk0kAa+d0iIEzXVHXbNbm+UaN3RNabmWtH8rDWZ5OPf+4GX8pBflAHodfpbipVMyrw1fsDanHsNBjhgbDeK9Jg==", "log": { "level": "info" + }, + "proxy": { + "enabled": false, + "mode": "single", + "timeout": 30, + "proxy_url": "http://127.0.0.1:7890", + "proxy_list": [], + "brightdata_endpoint": "", + "brightdata_token": "", + "brightdata_zone": "", + "proxy_host": "", + "proxy_user": "", + "proxy_password": "", + "refresh_interval": 0, + "blacklist_ttl": 5 } } diff --git a/config/config.go b/config/config.go index b913212f..6d1a433d 100644 --- a/config/config.go +++ b/config/config.go @@ -75,8 +75,25 @@ type Config struct { StopTradingMinutes int `json:"stop_trading_minutes"` Leverage LeverageConfig `json:"leverage"` // 杠杆配置 Log *LogConfig `json:"log"` // 日志配置(可选) + Proxy *ProxyConfig `json:"proxy"` // HTTP 代理配置(可选) } +// ProxyConfig HTTP 代理配置 +type ProxyConfig struct { + Enabled bool `json:"enabled"` // 是否启用代理 + Mode string `json:"mode"` // 模式: "single", "pool", "brightdata" + Timeout int `json:"timeout"` // 超时时间(秒) + ProxyURL string `json:"proxy_url"` // 单个代理地址 + ProxyList []string `json:"proxy_list"` // 代理列表 + BrightDataEndpoint string `json:"brightdata_endpoint"` // Bright Data接口地址 + BrightDataToken string `json:"brightdata_token"` // Bright Data访问令牌 + BrightDataZone string `json:"brightdata_zone"` // Bright Data区域 + ProxyHost string `json:"proxy_host"` // 代理主机 + ProxyUser string `json:"proxy_user"` // 代理用户名模板 + ProxyPassword string `json:"proxy_password"` // 代理密码 + RefreshInterval int `json:"refresh_interval"` // 刷新间隔(秒) + BlacklistTTL int `json:"blacklist_ttl"` // 黑名单TTL +} // LoadConfig 从文件加载配置 func LoadConfig(filename string) (*Config, error) { data, err := os.ReadFile(filename) diff --git a/config/database.go b/config/database.go index 51876587..1e6e1504 100644 --- a/config/database.go +++ b/config/database.go @@ -1,126 +1,130 @@ package config import ( - "fmt" - "time" + "fmt" + "time" + + "nofx/crypto" ) // DatabaseInterface 定义了数据库实现需要提供的方法集合 type DatabaseInterface interface { - CreateUser(user *User) error - GetUserByEmail(email string) (*User, error) - GetUserByID(userID string) (*User, error) - GetAllUsers() ([]string, error) - UpdateUserOTPVerified(userID string, verified bool) error - GetAIModels(userID string) ([]*AIModelConfig, error) - UpdateAIModel(userID, id string, enabled bool, apiKey, customAPIURL, customModelName string) error - GetExchanges(userID string) ([]*ExchangeConfig, error) - UpdateExchange(userID, id string, enabled bool, apiKey, secretKey string, testnet bool, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string) error - CreateAIModel(userID, id, name, provider string, enabled bool, apiKey, customAPIURL string) error - CreateExchange(userID, id, name, typ string, enabled bool, apiKey, secretKey string, testnet bool, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string) error - CreateTrader(trader *TraderRecord) error - GetTraders(userID string) ([]*TraderRecord, error) - UpdateTraderStatus(userID, id string, isRunning bool) error - UpdateTrader(trader *TraderRecord) error - UpdateTraderInitialBalance(userID, id string, newBalance float64) error - UpdateTraderCustomPrompt(userID, id string, customPrompt string, overrideBase bool) error - DeleteTrader(userID, id string) error - GetTraderConfig(userID, traderID string) (*TraderRecord, *AIModelConfig, *ExchangeConfig, error) - GetSystemConfig(key string) (string, error) - SetSystemConfig(key, value string) error - CreateUserSignalSource(userID, coinPoolURL, oiTopURL string) error - GetUserSignalSource(userID string) (*UserSignalSource, error) - UpdateUserSignalSource(userID, coinPoolURL, oiTopURL string) error - GetCustomCoins() []string - LoadBetaCodesFromFile(filePath string) error - ValidateBetaCode(code string) (bool, error) - UseBetaCode(code, userEmail string) error - GetBetaCodeStats() (total, used int, err error) - Close() error + SetCryptoService(cs *crypto.CryptoService) + CreateUser(user *User) error + GetUserByEmail(email string) (*User, error) + GetUserByID(userID string) (*User, error) + GetAllUsers() ([]string, error) + UpdateUserOTPVerified(userID string, verified bool) error + GetAIModels(userID string) ([]*AIModelConfig, error) + UpdateAIModel(userID, id string, enabled bool, apiKey, customAPIURL, customModelName string) error + GetExchanges(userID string) ([]*ExchangeConfig, error) + UpdateExchange(userID, id string, enabled bool, apiKey, secretKey string, testnet bool, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string) error + CreateAIModel(userID, id, name, provider string, enabled bool, apiKey, customAPIURL string) error + CreateExchange(userID, id, name, typ string, enabled bool, apiKey, secretKey string, testnet bool, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string) error + CreateTrader(trader *TraderRecord) error + GetTraders(userID string) ([]*TraderRecord, error) + UpdateTraderStatus(userID, id string, isRunning bool) error + UpdateTrader(trader *TraderRecord) error + UpdateTraderInitialBalance(userID, id string, newBalance float64) error + UpdateTraderCustomPrompt(userID, id string, customPrompt string, overrideBase bool) error + DeleteTrader(userID, id string) error + GetTraderConfig(userID, traderID string) (*TraderRecord, *AIModelConfig, *ExchangeConfig, error) + GetSystemConfig(key string) (string, error) + SetSystemConfig(key, value string) error + CreateUserSignalSource(userID, coinPoolURL, oiTopURL string) error + GetUserSignalSource(userID string) (*UserSignalSource, error) + UpdateUserSignalSource(userID, coinPoolURL, oiTopURL string) error + GetCustomCoins() []string + LoadBetaCodesFromFile(filePath string) error + ValidateBetaCode(code string) (bool, error) + UseBetaCode(code, userEmail string) error + GetBetaCodeStats() (total, used int, err error) + Close() error } // User 用户配置 type User struct { - ID string `json:"id"` - Email string `json:"email"` - PasswordHash string `json:"-"` // 不返回到前端 - OTPSecret string `json:"-"` // 不返回到前端 - OTPVerified bool `json:"otp_verified"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID string `json:"id"` + Email string `json:"email"` + PasswordHash string `json:"-"` + OTPSecret string `json:"-"` + OTPVerified bool `json:"otp_verified"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // AIModelConfig AI模型配置 type AIModelConfig struct { - ID string `json:"id"` - UserID string `json:"user_id"` - Name string `json:"name"` - Provider string `json:"provider"` - Enabled bool `json:"enabled"` - APIKey string `json:"apiKey"` - CustomAPIURL string `json:"customApiUrl"` - CustomModelName string `json:"customModelName"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID string `json:"id"` + UserID string `json:"user_id"` + Name string `json:"name"` + Provider string `json:"provider"` + Enabled bool `json:"enabled"` + APIKey string `json:"apiKey"` + CustomAPIURL string `json:"customApiUrl"` + CustomModelName string `json:"customModelName"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // ExchangeConfig 交易所配置 type ExchangeConfig struct { - ID string `json:"id"` - UserID string `json:"user_id"` - Name string `json:"name"` - Type string `json:"type"` - Enabled bool `json:"enabled"` - APIKey string `json:"apiKey"` - SecretKey string `json:"secretKey"` - Testnet bool `json:"testnet"` - HyperliquidWalletAddr string `json:"hyperliquidWalletAddr"` - AsterUser string `json:"asterUser"` - AsterSigner string `json:"asterSigner"` - AsterPrivateKey string `json:"asterPrivateKey"` - Deleted bool `json:"deleted"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID string `json:"id"` + UserID string `json:"user_id"` + Name string `json:"name"` + Type string `json:"type"` + Enabled bool `json:"enabled"` + APIKey string `json:"apiKey"` + SecretKey string `json:"secretKey"` + Testnet bool `json:"testnet"` + HyperliquidWalletAddr string `json:"hyperliquidWalletAddr"` + AsterUser string `json:"asterUser"` + AsterSigner string `json:"asterSigner"` + AsterPrivateKey string `json:"asterPrivateKey"` + DEXWalletPrivateKey string `json:"dexWalletPrivateKey"` // 统一的DEX私钥字段 + Deleted bool `json:"deleted"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } -// TraderRecord 交易员配置(数据库实体) +// TraderRecord 交易员配置 type TraderRecord struct { - ID string `json:"id"` - UserID string `json:"user_id"` - Name string `json:"name"` - AIModelID string `json:"ai_model_id"` - ExchangeID string `json:"exchange_id"` - InitialBalance float64 `json:"initial_balance"` - ScanIntervalMinutes int `json:"scan_interval_minutes"` - IsRunning bool `json:"is_running"` - BTCETHLeverage int `json:"btc_eth_leverage"` - AltcoinLeverage int `json:"altcoin_leverage"` - TradingSymbols string `json:"trading_symbols"` - UseCoinPool bool `json:"use_coin_pool"` - UseOITop bool `json:"use_oi_top"` - CustomPrompt string `json:"custom_prompt"` - OverrideBasePrompt bool `json:"override_base_prompt"` - SystemPromptTemplate string `json:"system_prompt_template"` - IsCrossMargin bool `json:"is_cross_margin"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID string `json:"id"` + UserID string `json:"user_id"` + Name string `json:"name"` + AIModelID string `json:"ai_model_id"` + ExchangeID string `json:"exchange_id"` + InitialBalance float64 `json:"initial_balance"` + ScanIntervalMinutes int `json:"scan_interval_minutes"` + IsRunning bool `json:"is_running"` + BTCETHLeverage int `json:"btc_eth_leverage"` + AltcoinLeverage int `json:"altcoin_leverage"` + TradingSymbols string `json:"trading_symbols"` + UseCoinPool bool `json:"use_coin_pool"` + UseOITop bool `json:"use_oi_top"` + CustomPrompt string `json:"custom_prompt"` + OverrideBasePrompt bool `json:"override_base_prompt"` + SystemPromptTemplate string `json:"system_prompt_template"` + IsCrossMargin bool `json:"is_cross_margin"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // UserSignalSource 用户信号源配置 type UserSignalSource struct { - ID int `json:"id"` - UserID string `json:"user_id"` - CoinPoolURL string `json:"coin_pool_url"` - OITopURL string `json:"oi_top_url"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID int `json:"id"` + UserID string `json:"user_id"` + CoinPoolURL string `json:"coin_pool_url"` + OITopURL string `json:"oi_top_url"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // NewDatabase 创建数据库连接(仅支持 PostgreSQL) func NewDatabase() (DatabaseInterface, error) { - pgDB, err := NewPostgreSQLDatabase() - if err != nil { - return nil, fmt.Errorf("创建PostgreSQL数据库失败: %w", err) - } - return pgDB, nil + pgDB, err := NewPostgreSQLDatabase() + if err != nil { + return nil, fmt.Errorf("创建PostgreSQL数据库失败: %w", err) + } + return pgDB, nil } diff --git a/config/database_pg.go b/config/database_pg.go index a7da471e..1acee98f 100644 --- a/config/database_pg.go +++ b/config/database_pg.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "log" + "nofx/crypto" "nofx/market" "os" "slices" @@ -16,7 +17,8 @@ import ( // PostgreSQLDatabase PostgreSQL数据库配置 type PostgreSQLDatabase struct { - db *sql.DB + db *sql.DB + cryptoService *crypto.CryptoService } // NewPostgreSQLDatabase 创建PostgreSQL数据库连接 @@ -60,6 +62,42 @@ func NewPostgreSQLDatabase() (*PostgreSQLDatabase, error) { return database, nil } +func (d *PostgreSQLDatabase) SetCryptoService(cs *crypto.CryptoService) { + d.cryptoService = cs +} + +func (d *PostgreSQLDatabase) encryptValue(value string, aadParts ...string) (string, error) { + if value == "" { + return "", nil + } + if d.cryptoService == nil { + return "", fmt.Errorf("crypto service not initialized") + } + if !d.cryptoService.HasDataKey() { + return "", fmt.Errorf("data encryption key not configured") + } + if d.cryptoService.IsEncryptedStorageValue(value) { + return value, nil + } + return d.cryptoService.EncryptForStorage(value, aadParts...) +} + +func (d *PostgreSQLDatabase) decryptValue(value string, aadParts ...string) (string, error) { + if value == "" { + return "", nil + } + if d.cryptoService == nil { + return "", fmt.Errorf("crypto service not initialized") + } + if !d.cryptoService.HasDataKey() { + return "", fmt.Errorf("data encryption key not configured") + } + if !d.cryptoService.IsEncryptedStorageValue(value) { + return "", fmt.Errorf("value is not encrypted") + } + return d.cryptoService.DecryptFromStorage(value, aadParts...) +} + // getEnv 获取环境变量,如果不存在返回默认值 func getEnv(key, defaultValue string) string { if value := os.Getenv(key); value != "" { @@ -162,6 +200,15 @@ func (d *PostgreSQLDatabase) GetAIModels(userID string) ([]*AIModelConfig, error if err != nil { return nil, err } + + if model.APIKey != "" { + decrypted, err := d.decryptValue(model.APIKey, model.UserID, model.ID, "api_key") + if err != nil { + return nil, err + } + model.APIKey = decrypted + } + models = append(models, &model) } @@ -216,7 +263,7 @@ func (d *PostgreSQLDatabase) UpdateAIModel(userID, id string, enabled bool, apiK log.Printf("🗑️ UpdateAIModel: 已标记删除用户 %s 的模型配置 %s (通过provider匹配)", userID, existingID) return nil } - + // 没有找到配置,返回成功(幂等性) log.Printf("ℹ️ UpdateAIModel: 模型配置不存在,跳过删除: %s", id) return nil @@ -229,11 +276,18 @@ func (d *PostgreSQLDatabase) UpdateAIModel(userID, id string, enabled bool, apiK `, userID, id).Scan(&existingID) if err == nil { + apiKeyEnc, err := d.encryptValue(apiKey, userID, existingID, "api_key") + if err != nil { + return err + } // 找到了现有配置(精确匹配 ID),更新它 _, err = d.db.Exec(` UPDATE ai_models SET enabled = $1, api_key = $2, custom_api_url = $3, custom_model_name = $4, deleted = FALSE, updated_at = CURRENT_TIMESTAMP WHERE id = $5 AND user_id = $6 - `, enabled, apiKey, customAPIURL, customModelName, existingID, userID) + `, enabled, apiKeyEnc, customAPIURL, customModelName, existingID, userID) + return err + } + if err != sql.ErrNoRows { return err } @@ -244,12 +298,19 @@ func (d *PostgreSQLDatabase) UpdateAIModel(userID, id string, enabled bool, apiK `, userID, provider).Scan(&existingID) if err == nil { + apiKeyEnc, err := d.encryptValue(apiKey, userID, existingID, "api_key") + if err != nil { + return err + } // 找到了现有配置(通过 provider 匹配,兼容旧版),更新它 log.Printf("⚠️ 使用旧版 provider 匹配更新模型: %s -> %s", provider, existingID) _, err = d.db.Exec(` UPDATE ai_models SET enabled = $1, api_key = $2, custom_api_url = $3, custom_model_name = $4, deleted = FALSE, updated_at = CURRENT_TIMESTAMP WHERE id = $5 AND user_id = $6 - `, enabled, apiKey, customAPIURL, customModelName, existingID, userID) + `, enabled, apiKeyEnc, customAPIURL, customModelName, existingID, userID) + return err + } + if err != sql.ErrNoRows { return err } @@ -292,11 +353,16 @@ func (d *PostgreSQLDatabase) UpdateAIModel(userID, id string, enabled bool, apiK newModelID = fmt.Sprintf("%s_%s", userID, provider) } + apiKeyEnc, err := d.encryptValue(apiKey, userID, newModelID, "api_key") + if err != nil { + return err + } + log.Printf("✓ 创建新的 AI 模型配置: ID=%s, Provider=%s, Name=%s", newModelID, provider, name) _, err = d.db.Exec(` INSERT INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url, custom_model_name, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) - `, newModelID, userID, name, provider, enabled, apiKey, customAPIURL, customModelName) + `, newModelID, userID, name, provider, enabled, apiKeyEnc, customAPIURL, customModelName) return err } @@ -309,6 +375,7 @@ func (d *PostgreSQLDatabase) GetExchanges(userID string) ([]*ExchangeConfig, err COALESCE(aster_user, '') AS aster_user, COALESCE(aster_signer, '') AS aster_signer, COALESCE(aster_private_key, '') AS aster_private_key, + COALESCE(dex_wallet_private_key, '') AS dex_wallet_private_key, COALESCE(deleted, FALSE) AS deleted, created_at, updated_at FROM exchanges @@ -329,12 +396,50 @@ func (d *PostgreSQLDatabase) GetExchanges(userID string) ([]*ExchangeConfig, err &exchange.Enabled, &exchange.APIKey, &exchange.SecretKey, &exchange.Testnet, &exchange.HyperliquidWalletAddr, &exchange.AsterUser, &exchange.AsterSigner, &exchange.AsterPrivateKey, + &exchange.DEXWalletPrivateKey, &exchange.Deleted, &exchange.CreatedAt, &exchange.UpdatedAt, ) if err != nil { return nil, err } + + if decrypted, err := d.decryptValue(exchange.APIKey, exchange.UserID, exchange.ID, "api_key"); err == nil { + exchange.APIKey = decrypted + } else { + return nil, err + } + if decrypted, err := d.decryptValue(exchange.SecretKey, exchange.UserID, exchange.ID, "secret_key"); err == nil { + exchange.SecretKey = decrypted + } else { + return nil, err + } + if decrypted, err := d.decryptValue(exchange.HyperliquidWalletAddr, exchange.UserID, exchange.ID, "hyperliquid_wallet_addr"); err == nil { + exchange.HyperliquidWalletAddr = decrypted + } else { + return nil, err + } + if decrypted, err := d.decryptValue(exchange.AsterUser, exchange.UserID, exchange.ID, "aster_user"); err == nil { + exchange.AsterUser = decrypted + } else { + return nil, err + } + if decrypted, err := d.decryptValue(exchange.AsterSigner, exchange.UserID, exchange.ID, "aster_signer"); err == nil { + exchange.AsterSigner = decrypted + } else { + return nil, err + } + if decrypted, err := d.decryptValue(exchange.AsterPrivateKey, exchange.UserID, exchange.ID, "aster_private_key"); err == nil { + exchange.AsterPrivateKey = decrypted + } else { + return nil, err + } + if decrypted, err := d.decryptValue(exchange.DEXWalletPrivateKey, exchange.UserID, exchange.ID, "dex_wallet_private_key"); err == nil { + exchange.DEXWalletPrivateKey = decrypted + } else { + return nil, err + } + exchanges = append(exchanges, &exchange) } @@ -345,7 +450,7 @@ func (d *PostgreSQLDatabase) GetExchanges(userID string) ([]*ExchangeConfig, err func (d *PostgreSQLDatabase) UpdateExchange(userID, id string, enabled bool, apiKey, secretKey string, testnet bool, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string) error { log.Printf("🔧 UpdateExchange: userID=%s, id=%s, enabled=%v", userID, id, enabled) - // 如果请求禁用该交易所,标记为已删除 + // 如果请求禁用该交易所,执行软删除 if !enabled { _, err := d.db.Exec(` UPDATE exchanges @@ -369,13 +474,38 @@ func (d *PostgreSQLDatabase) UpdateExchange(userID, id string, enabled bool, api return nil } + apiKeyEnc, err := d.encryptValue(apiKey, userID, id, "api_key") + if err != nil { + return fmt.Errorf("encrypt api_key failed: %w", err) + } + secretKeyEnc, err := d.encryptValue(secretKey, userID, id, "secret_key") + if err != nil { + return fmt.Errorf("encrypt secret_key failed: %w", err) + } + hyperAddrEnc, err := d.encryptValue(hyperliquidWalletAddr, userID, id, "hyperliquid_wallet_addr") + if err != nil { + return fmt.Errorf("encrypt hyperliquid_wallet_addr failed: %w", err) + } + asterUserEnc, err := d.encryptValue(asterUser, userID, id, "aster_user") + if err != nil { + return fmt.Errorf("encrypt aster_user failed: %w", err) + } + asterSignerEnc, err := d.encryptValue(asterSigner, userID, id, "aster_signer") + if err != nil { + return fmt.Errorf("encrypt aster_signer failed: %w", err) + } + asterPrivateKeyEnc, err := d.encryptValue(asterPrivateKey, userID, id, "aster_private_key") + if err != nil { + return fmt.Errorf("encrypt aster_private_key failed: %w", err) + } + // 首先尝试更新现有的用户配置 result, err := d.db.Exec(` UPDATE exchanges SET enabled = $1, api_key = $2, secret_key = $3, testnet = $4, hyperliquid_wallet_addr = $5, aster_user = $6, aster_signer = $7, aster_private_key = $8, deleted = FALSE, updated_at = CURRENT_TIMESTAMP WHERE id = $9 AND user_id = $10 - `, enabled, apiKey, secretKey, testnet, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey, id, userID) + `, enabled, apiKeyEnc, secretKeyEnc, testnet, hyperAddrEnc, asterUserEnc, asterSignerEnc, asterPrivateKeyEnc, id, userID) if err != nil { log.Printf("❌ UpdateExchange: 更新失败: %v", err) return err @@ -418,7 +548,7 @@ func (d *PostgreSQLDatabase) UpdateExchange(userID, id string, enabled bool, api hyperliquid_wallet_addr, aster_user, aster_signer, aster_private_key, deleted, created_at, updated_at) VALUES ($1, $2, $3, $4, TRUE, $5, $6, $7, $8, $9, $10, $11, FALSE, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) - `, id, userID, name, typ, apiKey, secretKey, testnet, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey) + `, id, userID, name, typ, apiKeyEnc, secretKeyEnc, testnet, hyperAddrEnc, asterUserEnc, asterSignerEnc, asterPrivateKeyEnc) if err != nil { log.Printf("❌ UpdateExchange: 创建记录失败: %v", err) @@ -434,21 +564,51 @@ func (d *PostgreSQLDatabase) UpdateExchange(userID, id string, enabled bool, api // CreateAIModel 创建AI模型配置 func (d *PostgreSQLDatabase) CreateAIModel(userID, id, name, provider string, enabled bool, apiKey, customAPIURL string) error { - _, err := d.db.Exec(` + apiKeyEnc, err := d.encryptValue(apiKey, userID, id, "api_key") + if err != nil { + return err + } + + _, err = d.db.Exec(` INSERT INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url) VALUES ($1, $2, $3, $4, $5, $6, $7) ON CONFLICT (id) DO NOTHING - `, id, userID, name, provider, enabled, apiKey, customAPIURL) + `, id, userID, name, provider, enabled, apiKeyEnc, customAPIURL) return err } // CreateExchange 创建交易所配置 func (d *PostgreSQLDatabase) CreateExchange(userID, id, name, typ string, enabled bool, apiKey, secretKey string, testnet bool, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string) error { - _, err := d.db.Exec(` + apiKeyEnc, err := d.encryptValue(apiKey, userID, id, "api_key") + if err != nil { + return fmt.Errorf("encrypt api_key failed: %w", err) + } + secretKeyEnc, err := d.encryptValue(secretKey, userID, id, "secret_key") + if err != nil { + return fmt.Errorf("encrypt secret_key failed: %w", err) + } + hyperAddrEnc, err := d.encryptValue(hyperliquidWalletAddr, userID, id, "hyperliquid_wallet_addr") + if err != nil { + return fmt.Errorf("encrypt hyperliquid_wallet_addr failed: %w", err) + } + asterUserEnc, err := d.encryptValue(asterUser, userID, id, "aster_user") + if err != nil { + return fmt.Errorf("encrypt aster_user failed: %w", err) + } + asterSignerEnc, err := d.encryptValue(asterSigner, userID, id, "aster_signer") + if err != nil { + return fmt.Errorf("encrypt aster_signer failed: %w", err) + } + asterPrivateKeyEnc, err := d.encryptValue(asterPrivateKey, userID, id, "aster_private_key") + if err != nil { + return fmt.Errorf("encrypt aster_private_key failed: %w", err) + } + + _, err = d.db.Exec(` INSERT INTO exchanges (id, user_id, name, type, enabled, api_key, secret_key, testnet, hyperliquid_wallet_addr, aster_user, aster_signer, aster_private_key) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) ON CONFLICT (id, user_id) DO NOTHING - `, id, userID, name, typ, enabled, apiKey, secretKey, testnet, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey) + `, id, userID, name, typ, enabled, apiKeyEnc, secretKeyEnc, testnet, hyperAddrEnc, asterUserEnc, asterSignerEnc, asterPrivateKeyEnc) return err } @@ -575,6 +735,57 @@ func (d *PostgreSQLDatabase) GetTraderConfig(userID, traderID string) (*TraderRe return nil, nil, nil, err } + if aiModel.APIKey != "" { + decrypted, err := d.decryptValue(aiModel.APIKey, aiModel.UserID, aiModel.ID, "api_key") + if err != nil { + return nil, nil, nil, err + } + aiModel.APIKey = decrypted + } + + if exchange.APIKey != "" { + decrypted, err := d.decryptValue(exchange.APIKey, exchange.UserID, exchange.ID, "api_key") + if err != nil { + return nil, nil, nil, err + } + exchange.APIKey = decrypted + } + if exchange.SecretKey != "" { + decrypted, err := d.decryptValue(exchange.SecretKey, exchange.UserID, exchange.ID, "secret_key") + if err != nil { + return nil, nil, nil, err + } + exchange.SecretKey = decrypted + } + if exchange.HyperliquidWalletAddr != "" { + decrypted, err := d.decryptValue(exchange.HyperliquidWalletAddr, exchange.UserID, exchange.ID, "hyperliquid_wallet_addr") + if err != nil { + return nil, nil, nil, err + } + exchange.HyperliquidWalletAddr = decrypted + } + if exchange.AsterUser != "" { + decrypted, err := d.decryptValue(exchange.AsterUser, exchange.UserID, exchange.ID, "aster_user") + if err != nil { + return nil, nil, nil, err + } + exchange.AsterUser = decrypted + } + if exchange.AsterSigner != "" { + decrypted, err := d.decryptValue(exchange.AsterSigner, exchange.UserID, exchange.ID, "aster_signer") + if err != nil { + return nil, nil, nil, err + } + exchange.AsterSigner = decrypted + } + if exchange.AsterPrivateKey != "" { + decrypted, err := d.decryptValue(exchange.AsterPrivateKey, exchange.UserID, exchange.ID, "aster_private_key") + if err != nil { + return nil, nil, nil, err + } + exchange.AsterPrivateKey = decrypted + } + return &trader, &aiModel, &exchange, nil } diff --git a/crypto/crypto.go b/crypto/crypto.go new file mode 100644 index 00000000..9a29480f --- /dev/null +++ b/crypto/crypto.go @@ -0,0 +1,394 @@ +package crypto + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/hex" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strings" + "time" +) + +const ( + storagePrefix = "ENC:v1:" + storageDelimiter = ":" + dataKeyEnvName = "DATA_ENCRYPTION_KEY" +) + +type EncryptedPayload struct { + WrappedKey string `json:"wrappedKey"` + IV string `json:"iv"` + Ciphertext string `json:"ciphertext"` + AAD string `json:"aad,omitempty"` + KID string `json:"kid,omitempty"` + TS int64 `json:"ts,omitempty"` +} + +type AADData struct { + UserID string `json:"userId"` + SessionID string `json:"sessionId"` + TS int64 `json:"ts"` + Purpose string `json:"purpose"` +} + +type CryptoService struct { + privateKey *rsa.PrivateKey + publicKey *rsa.PublicKey + dataKey []byte +} + +func NewCryptoService(privateKeyPath string) (*CryptoService, error) { + // 读取私钥文件 + privateKeyPEM, err := ioutil.ReadFile(privateKeyPath) + if err != nil { + // 如果私钥文件不存在,生成新的密钥对 + if err := GenerateRSAKeyPair(privateKeyPath); err != nil { + return nil, fmt.Errorf("failed to generate RSA key pair: %w", err) + } + privateKeyPEM, err = ioutil.ReadFile(privateKeyPath) + if err != nil { + return nil, fmt.Errorf("failed to read generated private key: %w", err) + } + } + + // 解析私钥 + privateKey, err := ParseRSAPrivateKeyFromPEM(privateKeyPEM) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %w", err) + } + + dataKey, err := loadDataKeyFromEnv() + if err != nil { + return nil, fmt.Errorf("failed to load data encryption key: %w", err) + } + + return &CryptoService{ + privateKey: privateKey, + publicKey: &privateKey.PublicKey, + dataKey: dataKey, + }, nil +} + +func GenerateRSAKeyPair(privateKeyPath string) error { + // 确保目录存在 + dir := filepath.Dir(privateKeyPath) + if dir != "." { + if err := os.MkdirAll(dir, 0700); err != nil { + return fmt.Errorf("failed to create directory %s: %w", dir, err) + } + } + + // 生成 RSA 密钥对 + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return err + } + + // 编码私钥 + privateKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + }) + + // 保存私钥 + if err := ioutil.WriteFile(privateKeyPath, privateKeyPEM, 0600); err != nil { + return err + } + + // 编码公钥 + publicKeyDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey) + if err != nil { + return err + } + + publicKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "PUBLIC KEY", + Bytes: publicKeyDER, + }) + + // 保存公钥 + publicKeyPath := privateKeyPath + ".pub" + if err := ioutil.WriteFile(publicKeyPath, publicKeyPEM, 0644); err != nil { + return err + } + + return nil +} + +func ParseRSAPrivateKeyFromPEM(pemBytes []byte) (*rsa.PrivateKey, error) { + block, _ := pem.Decode(pemBytes) + if block == nil { + return nil, errors.New("no PEM block found") + } + + switch block.Type { + case "RSA PRIVATE KEY": + return x509.ParsePKCS1PrivateKey(block.Bytes) + case "PRIVATE KEY": + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + rsaKey, ok := key.(*rsa.PrivateKey) + if !ok { + return nil, errors.New("not an RSA key") + } + return rsaKey, nil + default: + return nil, errors.New("unsupported key type: " + block.Type) + } +} + +func loadDataKeyFromEnv() ([]byte, error) { + keyStr := strings.TrimSpace(os.Getenv(dataKeyEnvName)) + if keyStr == "" { + return nil, fmt.Errorf("%s not set", dataKeyEnvName) + } + + if key, ok := decodePossibleKey(keyStr); ok { + return key, nil + } + + sum := sha256.Sum256([]byte(keyStr)) + key := make([]byte, len(sum)) + copy(key, sum[:]) + return key, nil +} + +func decodePossibleKey(value string) ([]byte, bool) { + decoders := []func(string) ([]byte, error){ + base64.StdEncoding.DecodeString, + base64.RawStdEncoding.DecodeString, + func(s string) ([]byte, error) { return hex.DecodeString(s) }, + } + + for _, decoder := range decoders { + if decoded, err := decoder(value); err == nil { + if key, ok := normalizeAESKey(decoded); ok { + return key, true + } + } + } + + return nil, false +} + +func normalizeAESKey(raw []byte) ([]byte, bool) { + switch len(raw) { + case 16, 24, 32: + return raw, true + case 0: + return nil, false + default: + sum := sha256.Sum256(raw) + key := make([]byte, len(sum)) + copy(key, sum[:]) + return key, true + } +} + +func (cs *CryptoService) HasDataKey() bool { + return len(cs.dataKey) > 0 +} + +func (cs *CryptoService) GetPublicKeyPEM() string { + publicKeyDER, err := x509.MarshalPKIXPublicKey(cs.publicKey) + if err != nil { + return "" + } + + publicKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "PUBLIC KEY", + Bytes: publicKeyDER, + }) + + return string(publicKeyPEM) +} + +func (cs *CryptoService) EncryptForStorage(plaintext string, aadParts ...string) (string, error) { + if plaintext == "" { + return "", nil + } + if !cs.HasDataKey() { + return "", errors.New("data encryption key not configured") + } + if isEncryptedStorageValue(plaintext) { + return plaintext, nil + } + + block, err := aes.NewCipher(cs.dataKey) + if err != nil { + return "", err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + + nonce := make([]byte, gcm.NonceSize()) + if _, err := rand.Read(nonce); err != nil { + return "", err + } + + aad := composeAAD(aadParts) + ciphertext := gcm.Seal(nil, nonce, []byte(plaintext), aad) + + return storagePrefix + + base64.StdEncoding.EncodeToString(nonce) + storageDelimiter + + base64.StdEncoding.EncodeToString(ciphertext), nil +} + +func (cs *CryptoService) DecryptFromStorage(value string, aadParts ...string) (string, error) { + if value == "" { + return "", nil + } + if !cs.HasDataKey() { + return "", errors.New("data encryption key not configured") + } + if !isEncryptedStorageValue(value) { + return "", errors.New("value is not encrypted") + } + + payload := strings.TrimPrefix(value, storagePrefix) + parts := strings.SplitN(payload, storageDelimiter, 2) + if len(parts) != 2 { + return "", errors.New("invalid encrypted payload format") + } + + nonce, err := base64.StdEncoding.DecodeString(parts[0]) + if err != nil { + return "", fmt.Errorf("decode nonce failed: %w", err) + } + + ciphertext, err := base64.StdEncoding.DecodeString(parts[1]) + if err != nil { + return "", fmt.Errorf("decode ciphertext failed: %w", err) + } + + block, err := aes.NewCipher(cs.dataKey) + if err != nil { + return "", err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + + if len(nonce) != gcm.NonceSize() { + return "", fmt.Errorf("invalid nonce size: expected %d, got %d", gcm.NonceSize(), len(nonce)) + } + + aad := composeAAD(aadParts) + plaintext, err := gcm.Open(nil, nonce, ciphertext, aad) + if err != nil { + return "", fmt.Errorf("decryption failed: %w", err) + } + + return string(plaintext), nil +} + +func (cs *CryptoService) IsEncryptedStorageValue(value string) bool { + return isEncryptedStorageValue(value) +} + +func composeAAD(parts []string) []byte { + if len(parts) == 0 { + return nil + } + return []byte(strings.Join(parts, "|")) +} + +func isEncryptedStorageValue(value string) bool { + return strings.HasPrefix(value, storagePrefix) +} + +func (cs *CryptoService) DecryptPayload(payload *EncryptedPayload) ([]byte, error) { + // 1. 验证时间戳(防止重放攻击) + if payload.TS != 0 { + elapsed := time.Since(time.Unix(payload.TS, 0)) + if elapsed > 5*time.Minute || elapsed < -1*time.Minute { + return nil, errors.New("timestamp invalid or expired") + } + } + + // 2. 解码 base64url + wrappedKey, err := base64.RawURLEncoding.DecodeString(payload.WrappedKey) + if err != nil { + return nil, fmt.Errorf("failed to decode wrapped key: %w", err) + } + + iv, err := base64.RawURLEncoding.DecodeString(payload.IV) + if err != nil { + return nil, fmt.Errorf("failed to decode IV: %w", err) + } + + ciphertext, err := base64.RawURLEncoding.DecodeString(payload.Ciphertext) + if err != nil { + return nil, fmt.Errorf("failed to decode ciphertext: %w", err) + } + + var aad []byte + if payload.AAD != "" { + aad, err = base64.RawURLEncoding.DecodeString(payload.AAD) + if err != nil { + return nil, fmt.Errorf("failed to decode AAD: %w", err) + } + + // 验证 AAD + var aadData AADData + if err := json.Unmarshal(aad, &aadData); err == nil { + // 可以在这里添加额外的验证逻辑 + // 例如:验证 sessionID、userID 等 + } + } + + // 3. 使用 RSA-OAEP 解密 AES 密钥 + aesKey, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, cs.privateKey, wrappedKey, nil) + if err != nil { + return nil, fmt.Errorf("failed to unwrap AES key: %w", err) + } + + // 4. 使用 AES-GCM 解密数据 + block, err := aes.NewCipher(aesKey) + if err != nil { + return nil, fmt.Errorf("failed to create AES cipher: %w", err) + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("failed to create GCM: %w", err) + } + + if len(iv) != gcm.NonceSize() { + return nil, fmt.Errorf("invalid IV size: expected %d, got %d", gcm.NonceSize(), len(iv)) + } + + // 解密并验证认证标签 + plaintext, err := gcm.Open(nil, iv, ciphertext, aad) + if err != nil { + return nil, fmt.Errorf("authentication/decryption failed: %w", err) + } + + return plaintext, nil +} + +func (cs *CryptoService) DecryptSensitiveData(payload *EncryptedPayload) (string, error) { + plaintext, err := cs.DecryptPayload(payload) + if err != nil { + return "", err + } + return string(plaintext), nil +} diff --git a/docker-compose.yml b/docker-compose.yml index acdf459a..a15a01de 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -57,6 +57,7 @@ services: environment: - TZ=${NOFX_TIMEZONE:-Asia/Shanghai} # Set timezone - AI_MAX_TOKENS=4000 # AI响应的最大token数(默认2000,建议4000-8000) + - DATA_ENCRYPTION_KEY=${DATA_ENCRYPTION_KEY} # 数据加密密钥 - POSTGRES_HOST=postgres - POSTGRES_PORT=5432 - POSTGRES_DB=${POSTGRES_DB:-nofx} diff --git a/main.go b/main.go index 73dbab1b..dee1082e 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "nofx/api" "nofx/auth" "nofx/config" + "nofx/crypto" "nofx/manager" "nofx/market" "nofx/pool" @@ -171,6 +172,13 @@ func main() { } defer database.Close() + // 初始化加密服务(用于敏感数据加密存储与传输) + cryptoService, err := crypto.NewCryptoService("keys/rsa_private.key") + if err != nil { + log.Fatalf("❌ 初始化加密服务失败: %v", err) + } + database.SetCryptoService(cryptoService) + // 同步config.json到数据库 if err := syncConfigToDatabase(database, configFile); err != nil { log.Printf("⚠️ 同步config.json到数据库失败: %v", err) @@ -289,7 +297,7 @@ func main() { } // 创建并启动API服务器 - apiServer := api.NewServer(traderManager, database, apiPort) + apiServer := api.NewServer(traderManager, database, cryptoService, apiPort) go func() { if err := apiServer.Start(); err != nil { log.Printf("❌ API服务器错误: %v", err) diff --git a/proxy/README.md b/proxy/README.md new file mode 100644 index 00000000..f48a35d4 --- /dev/null +++ b/proxy/README.md @@ -0,0 +1,685 @@ +# HTTP 代理模块 + +## 概述 + +这是一个高度解耦的HTTP代理管理模块,专为解决高频API请求被限流/封禁问题而设计。支持单代理、代理池和动态IP获取三种模式,提供线程安全的IP轮换和智能黑名单管理机制。 + +## 功能特性 + +- ✅ **三种工作模式**:单代理、固定代理池、Bright Data API动态获取 +- ✅ **线程安全**:所有操作使用读写锁保护,支持并发访问 +- ✅ **智能黑名单**:失败的代理IP手动加入黑名单,TTL机制自动恢复 +- ✅ **自动刷新**:支持定时刷新代理IP列表(默认30分钟) +- ✅ **随机轮换**:从可用IP池中随机选择,避免单点压力 +- ✅ **防越界保护**:多层数组边界检查,确保运行时安全 +- ✅ **可选启用**:未配置或禁用时自动使用直连,不影响独立客户 + +## 架构设计 + +``` +proxy/ +├── README.md # 本文档 +├── types.go # 核心数据结构定义 +├── provider.go # IP提供者接口定义 +├── single_provider.go # 单代理实现 +├── fixed_provider.go # 固定代理池实现 +├── brightdata_provider.go # Bright Data API实现 +└── proxy_manager.go # 代理管理器(核心逻辑) +``` + +### 设计原则 + +1. **接口抽象**:通过 `IPProvider` 接口实现不同代理源的统一管理 +2. **策略模式**:三种Provider实现可灵活切换 +3. **单例模式**:全局ProxyManager确保资源统一管理 +4. **防御性编程**:多层边界检查,优雅处理异常情况 + +## 配置说明 + +在 `config.json` 中添加 `proxy` 配置段: + +```json +{ + "proxy": { + "enabled": true, + "mode": "single", + "timeout": 30, + "proxy_url": "http://127.0.0.1:7890", + "proxy_list": [], + "brightdata_endpoint": "", + "brightdata_token": "", + "brightdata_zone": "", + "proxy_host": "", + "proxy_user": "", + "proxy_password": "", + "refresh_interval": 1800, + "blacklist_ttl": 5 + } +} +``` + +### 配置字段详解 + +| 字段 | 类型 | 必填 | 说明 | +|------|------|------|------| +| `enabled` | bool | 是 | 是否启用代理(false时使用直连) | +| `mode` | string | 是 | 代理模式:`single`/`pool`/`brightdata` | +| `timeout` | int | 否 | HTTP请求超时时间(秒),默认30 | +| `proxy_url` | string | single模式必填 | 单个代理地址,如 `http://127.0.0.1:7890` | +| `proxy_list` | []string | pool模式必填 | 代理列表,支持 `http://`、`https://`、`socks5://` | +| `brightdata_endpoint` | string | brightdata模式必填 | Bright Data API端点 | +| `brightdata_token` | string | brightdata模式可选 | Bright Data访问令牌 | +| `brightdata_zone` | string | brightdata模式可选 | Bright Data区域参数 | +| `proxy_host` | string | 否 | 代理主机(用于认证代理) | +| `proxy_user` | string | 否 | 代理用户名模板,支持 `%s` 占位符替换IP | +| `proxy_password` | string | 否 | 代理密码 | +| `refresh_interval` | int | 否 | IP列表刷新间隔(秒),brightdata模式默认1800(30分钟) | +| `blacklist_ttl` | int | 否 | 黑名单IP的TTL(刷新次数),默认5 | + +## 使用方法 + +### 1. 初始化代理管理器 + +在 `main.go` 或初始化代码中: + +```go +import ( + "nofx/proxy" + "time" +) + +// 方式1:使用配置结构体初始化 +proxyConfig := &proxy.Config{ + Enabled: true, + Mode: "single", + Timeout: 30 * time.Second, + ProxyURL: "http://127.0.0.1:7890", + BlacklistTTL: 5, +} + +err := proxy.InitGlobalProxyManager(proxyConfig) +if err != nil { + log.Fatalf("初始化代理管理器失败: %v", err) +} +``` + +### 2. 获取代理HTTP客户端 + +在需要发送HTTP请求的地方: + +```go +// 获取代理客户端(包含ProxyID用于黑名单管理) +proxyClient, err := proxy.GetProxyHTTPClient() +if err != nil { + log.Printf("获取代理客户端失败: %v", err) + return +} + +// 使用代理客户端发送请求 +resp, err := proxyClient.Client.Get("https://api.example.com/data") +if err != nil { + // 请求失败,将此代理加入黑名单 + proxy.AddBlacklist(proxyClient.ProxyID) + log.Printf("请求失败,代理IP %s 已加入黑名单", proxyClient.IP) + return +} +defer resp.Body.Close() + +// 处理响应... +``` + +### 3. 黑名单管理 + +```go +// 添加失败的代理到黑名单 +proxy.AddBlacklist(proxyClient.ProxyID) + +// 获取黑名单状态 +total, blacklisted, available := proxy.GetGlobalProxyManager().GetBlacklistStatus() +log.Printf("代理状态: 总计%d个,黑名单%d个,可用%d个", total, blacklisted, available) +``` + +### 4. 手动刷新IP列表 + +```go +err := proxy.RefreshIPList() +if err != nil { + log.Printf("刷新IP列表失败: %v", err) +} +``` + +### 5. 检查代理是否启用 + +```go +if proxy.IsEnabled() { + log.Println("代理已启用") +} else { + log.Println("代理未启用,使用直连") +} +``` + +## 三种模式详解 + +### Mode 1: Single(单代理模式) + +适用场景:本地代理工具(如Clash、V2Ray)或单个固定代理服务器 + +```json +{ + "proxy": { + "enabled": true, + "mode": "single", + "proxy_url": "http://127.0.0.1:7890" + } +} +``` + +特点: +- 简单直接,适合本地开发和测试 +- 所有请求通过同一个代理 +- 不需要刷新和轮换 + +### Mode 2: Pool(代理池模式) + +适用场景:拥有多个固定代理服务器,需要轮换使用 + +```json +{ + "proxy": { + "enabled": true, + "mode": "pool", + "proxy_list": [ + "http://proxy1.example.com:8080", + "http://user:pass@proxy2.example.com:8080", + "socks5://proxy3.example.com:1080" + ], + "blacklist_ttl": 5 + } +} +``` + +特点: +- 支持多协议:HTTP、HTTPS、SOCKS5 +- 随机选择代理,分散请求压力 +- 失败的代理自动加入黑名单 +- 黑名单IP经过TTL次刷新后自动恢复 + +### Mode 3: BrightData(动态IP模式) + +适用场景:使用Bright Data等提供API的动态代理服务 + +```json +{ + "proxy": { + "enabled": true, + "mode": "brightdata", + "brightdata_endpoint": "https://api.brightdata.com/zones/get_ips", + "brightdata_token": "your_api_token", + "brightdata_zone": "residential", + "proxy_host": "brd.superproxy.io:22225", + "proxy_user": "brd-customer-xxx-zone-residential-ip-%s", + "proxy_password": "your_password", + "refresh_interval": 1800, + "blacklist_ttl": 5 + } +} +``` + +特点: +- 从API动态获取可用IP列表 +- 自动定时刷新(默认30分钟) +- 支持用户名模板(`%s` 替换为IP地址) +- 黑名单TTL机制避免频繁切换 + +**用户名模板说明**: +``` +proxy_user: "brd-customer-xxx-zone-residential-ip-%s" + ↑ + 自动替换为IP地址 +``` + +## 核心API + +### 全局函数 + +```go +// 初始化全局代理管理器(只执行一次) +func InitGlobalProxyManager(config *Config) error + +// 获取全局代理管理器实例 +func GetGlobalProxyManager() *ProxyManager + +// 获取代理HTTP客户端(包含ProxyID和IP信息) +func GetProxyHTTPClient() (*ProxyClient, error) + +// 将代理IP添加到黑名单 +func AddBlacklist(proxyID int) + +// 刷新IP列表 +func RefreshIPList() error + +// 检查代理是否启用 +func IsEnabled() bool +``` + +### ProxyManager 方法 + +```go +// 获取代理客户端 +func (m *ProxyManager) GetProxyClient() (*ProxyClient, error) + +// 刷新IP列表 +func (m *ProxyManager) RefreshIPList() error + +// 添加到黑名单 +func (m *ProxyManager) AddBlacklist(proxyID int) + +// 获取黑名单状态 +func (m *ProxyManager) GetBlacklistStatus() (total, blacklisted, available int) + +// 启动自动刷新 +func (m *ProxyManager) StartAutoRefresh() + +// 停止自动刷新 +func (m *ProxyManager) StopAutoRefresh() +``` + +## 黑名单机制 + +### 工作原理 + +1. **添加黑名单**:当代理请求失败时,调用 `AddBlacklist(proxyID)` 将该IP加入黑名单 +2. **TTL倒计时**:每次刷新IP列表时,黑名单中的IP的TTL减1 +3. **自动恢复**:当TTL归零时,IP自动从黑名单移除,重新可用 + +### 线程安全保证 + +```go +// 添加黑名单使用写锁 +func (m *ProxyManager) AddBlacklist(proxyID int) { + m.mutex.Lock() + defer m.mutex.Unlock() + + // 防越界检查 + if proxyID < 0 || proxyID >= len(m.ipList) { + log.Printf("⚠️ 无效的 ProxyID: %d", proxyID) + return + } + + ip := m.ipList[proxyID].IP + m.blacklist[proxyID] = ip + m.ipBlacklist[ip] = m.config.BlacklistTTL +} + +// 获取代理使用读锁(支持并发) +func (m *ProxyManager) getRandomProxy() (int, *ProxyIP, error) { + m.mutex.RLock() + defer m.mutex.RUnlock() + // ... 读取操作 +} +``` + +### 示例流程 + +``` +初始状态:5个代理IP,TTL=3 +IP列表: [IP1, IP2, IP3, IP4, IP5] +黑名单: {} + +第1次失败:IP2请求失败 +IP列表: [IP1, IP2, IP3, IP4, IP5] +黑名单: {IP2: TTL=3} + +第1次刷新:TTL-1 +黑名单: {IP2: TTL=2} + +第2次刷新:TTL-1 +黑名单: {IP2: TTL=1} + +第3次刷新:TTL-1 +黑名单: {IP2: TTL=0} → 从黑名单移除 + +第3次刷新后: +IP列表: [IP1, IP2, IP3, IP4, IP5] +黑名单: {} ← IP2已恢复可用 +``` + +## 完整使用示例 + +### 示例1:币安API请求(单代理模式) + +```go +package main + +import ( + "log" + "nofx/proxy" + "time" +) + +func main() { + // 初始化代理 + err := proxy.InitGlobalProxyManager(&proxy.Config{ + Enabled: true, + Mode: "single", + ProxyURL: "http://127.0.0.1:7890", + Timeout: 30 * time.Second, + }) + if err != nil { + log.Fatalf("初始化代理失败: %v", err) + } + + // 获取币安数据 + proxyClient, err := proxy.GetProxyHTTPClient() + if err != nil { + log.Fatalf("获取代理客户端失败: %v", err) + } + + resp, err := proxyClient.Client.Get("https://fapi.binance.com/fapi/v1/ticker/24hr") + if err != nil { + log.Printf("请求失败: %v", err) + return + } + defer resp.Body.Close() + + log.Printf("请求成功,使用代理: %s", proxyClient.IP) +} +``` + +### 示例2:OI数据获取(代理池模式 + 黑名单) + +```go +package main + +import ( + "fmt" + "io" + "log" + "nofx/proxy" + "time" +) + +func fetchOIData(symbol string) error { + proxyClient, err := proxy.GetProxyHTTPClient() + if err != nil { + return fmt.Errorf("获取代理失败: %w", err) + } + + url := fmt.Sprintf("https://fapi.binance.com/futures/data/openInterestHist?symbol=%s&period=5m&limit=1", symbol) + resp, err := proxyClient.Client.Get(url) + if err != nil { + // 请求失败,加入黑名单 + proxy.AddBlacklist(proxyClient.ProxyID) + return fmt.Errorf("请求失败 (代理: %s): %w", proxyClient.IP, err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + // 状态码异常,加入黑名单 + proxy.AddBlacklist(proxyClient.ProxyID) + return fmt.Errorf("状态码异常: %d (代理: %s)", resp.StatusCode, proxyClient.IP) + } + + body, _ := io.ReadAll(resp.Body) + log.Printf("✓ 获取 %s OI数据成功 (代理: %s): %s", symbol, proxyClient.IP, string(body)) + return nil +} + +func main() { + // 初始化代理池 + err := proxy.InitGlobalProxyManager(&proxy.Config{ + Enabled: true, + Mode: "pool", + ProxyList: []string{ + "http://proxy1.example.com:8080", + "http://proxy2.example.com:8080", + "http://proxy3.example.com:8080", + }, + Timeout: 30 * time.Second, + BlacklistTTL: 5, + }) + if err != nil { + log.Fatalf("初始化代理失败: %v", err) + } + + // 循环获取数据 + symbols := []string{"BTCUSDT", "ETHUSDT", "SOLUSDT"} + for { + for _, symbol := range symbols { + if err := fetchOIData(symbol); err != nil { + log.Printf("⚠️ %v", err) + } + time.Sleep(1 * time.Second) + } + time.Sleep(10 * time.Second) + } +} +``` + +### 示例3:Bright Data动态IP + +```go +package main + +import ( + "log" + "nofx/proxy" + "time" +) + +func main() { + // 初始化Bright Data代理 + err := proxy.InitGlobalProxyManager(&proxy.Config{ + Enabled: true, + Mode: "brightdata", + BrightDataEndpoint: "https://api.brightdata.com/zones/get_ips", + BrightDataToken: "your_token", + BrightDataZone: "residential", + ProxyHost: "brd.superproxy.io:22225", + ProxyUser: "brd-customer-xxx-zone-residential-ip-%s", + ProxyPassword: "your_password", + RefreshInterval: 30 * time.Minute, + Timeout: 30 * time.Second, + BlacklistTTL: 5, + }) + if err != nil { + log.Fatalf("初始化代理失败: %v", err) + } + + // 代理会自动每30分钟刷新IP列表 + log.Println("✓ Bright Data代理已启动,自动刷新已开启") + + // 获取并使用代理 + for i := 0; i < 10; i++ { + proxyClient, err := proxy.GetProxyHTTPClient() + if err != nil { + log.Printf("获取代理失败: %v", err) + continue + } + + resp, err := proxyClient.Client.Get("https://api.ipify.org?format=json") + if err != nil { + proxy.AddBlacklist(proxyClient.ProxyID) + log.Printf("请求失败,代理已加入黑名单: %s", proxyClient.IP) + continue + } + resp.Body.Close() + + log.Printf("✓ 请求成功 (代理ID: %d, IP: %s)", proxyClient.ProxyID, proxyClient.IP) + time.Sleep(2 * time.Second) + } +} +``` + +## 注意事项 + +### 1. 模块解耦性 + +- ✅ 代理模块完全独立,不依赖其他业务模块 +- ✅ 禁用代理时自动使用直连,对业务代码透明 +- ✅ 适合多租户/多客户环境,可按需启用 + +### 2. 线程安全 + +- ✅ 所有公开方法都是线程安全的 +- ✅ 支持高并发场景下的代理获取和黑名单操作 +- ✅ 读写锁优化性能:读操作可并发,写操作独占 + +### 3. 错误处理 + +```go +proxyClient, err := proxy.GetProxyHTTPClient() +if err != nil { + // 可能的错误: + // - 代理IP列表为空 + // - 所有代理都在黑名单中 + // - 代理URL解析失败 + log.Printf("获取代理失败: %v", err) + + // 建议:降级为直连或重试 + return +} +``` + +### 4. 性能优化建议 + +- 对于高频请求,复用 `http.Client` 而不是每次创建新的 +- 合理设置 `refresh_interval` 避免频繁刷新 +- `blacklist_ttl` 建议设置为 3-10,平衡恢复速度和稳定性 + +### 5. 安全建议 + +- 生产环境中代理密钥应使用环境变量或密钥管理服务 +- 避免在日志中打印完整的代理URL(包含密码) +- TLS验证默认开启,如需跳过请谨慎评估风险 + +### 6. 调试技巧 + +```go +// 获取当前代理状态 +total, blacklisted, available := proxy.GetGlobalProxyManager().GetBlacklistStatus() +log.Printf("代理池状态: 总计=%d, 黑名单=%d, 可用=%d", total, blacklisted, available) + +// 检查是否启用 +if !proxy.IsEnabled() { + log.Println("代理未启用,请检查配置") +} +``` + +## 故障排查 + +### 问题1:获取代理失败 - "代理IP列表为空" + +**原因**: +- `single` 模式:未配置 `proxy_url` +- `pool` 模式:`proxy_list` 为空 +- `brightdata` 模式:API返回空列表或请求失败 + +**解决方案**: +```bash +# 检查配置文件 +cat config.json | grep -A 15 "proxy" + +# 检查日志,查看初始化信息 +# 应该看到类似:🌐 HTTP 代理已启用 (xxx模式) +``` + +### 问题2:所有代理都在黑名单中 + +**原因**:请求持续失败,所有IP被加入黑名单 + +**解决方案**: +```go +// 方案1:手动刷新IP列表(会触发TTL倒计时) +proxy.RefreshIPList() + +// 方案2:降低blacklist_ttl,加快恢复速度 +// config.json: "blacklist_ttl": 2 (默认5) + +// 方案3:检查代理本身是否可用 +// 使用curl测试代理: +// curl -x http://proxy_url https://api.binance.com/api/v3/ping +``` + +### 问题3:Bright Data模式无法获取IP + +**原因**: +- API端点配置错误 +- Token无效或过期 +- Zone参数不正确 + +**解决方案**: +```bash +# 手动测试API +curl -H "Authorization: Bearer YOUR_TOKEN" \ + "https://api.brightdata.com/zones/get_ips?zone=residential" + +# 检查返回格式是否符合: +# {"ips": [{"ip": "1.2.3.4", ...}, ...]} +``` + +### 问题4:代理连接超时 + +**原因**:代理服务器响应慢或网络不稳定 + +**解决方案**: +```json +{ + "proxy": { + "timeout": 60 // 增加超时时间(秒) + } +} +``` + +## 扩展开发 + +### 添加新的Provider + +实现 `IPProvider` 接口即可: + +```go +// custom_provider.go +package proxy + +type CustomProvider struct { + // 自定义字段 +} + +func NewCustomProvider(config string) *CustomProvider { + return &CustomProvider{} +} + +func (p *CustomProvider) GetIPList() ([]ProxyIP, error) { + // 实现获取IP列表的逻辑 + return []ProxyIP{}, nil +} + +func (p *CustomProvider) RefreshIPList() ([]ProxyIP, error) { + // 实现刷新IP列表的逻辑 + return p.GetIPList() +} +``` + +然后在 `proxy_manager.go` 的 `NewProxyManager` 中添加新模式: + +```go +case "custom": + m.provider = NewCustomProvider(config.CustomEndpoint) + log.Printf("🌐 HTTP 代理已启用 (自定义模式)") +``` + +## 更新日志 + +### v1.0.0 (当前版本) +- ✅ 支持三种代理模式:single、pool、brightdata +- ✅ 线程安全的IP轮换和黑名单管理 +- ✅ 自动刷新机制(30分钟默认) +- ✅ TTL黑名单自动恢复 +- ✅ 防越界保护 +- ✅ ProxyID追踪机制 + + +## 技术支持 + +如有问题或建议,请联系项目维护者 @hzb1115 +。 diff --git a/proxy/brightdata_provider.go b/proxy/brightdata_provider.go new file mode 100644 index 00000000..e8febd55 --- /dev/null +++ b/proxy/brightdata_provider.go @@ -0,0 +1,105 @@ +package proxy + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "time" +) + +// BrightDataProvider Bright Data动态获取IP提供者 +type BrightDataProvider struct { + endpoint string + token string + zone string + client *http.Client +} + +// NewBrightDataProvider 创建Bright Data IP提供者 +func NewBrightDataProvider(endpoint, token, zone string) *BrightDataProvider { + return &BrightDataProvider{ + endpoint: endpoint, + token: token, + zone: zone, + client: &http.Client{ + Timeout: 30 * time.Second, + }, + } +} + +// BrightDataIPList Bright Data API返回的IP列表结构 +type BrightDataIPList struct { + IPs []struct { + IP string `json:"ip"` + Maxmind string `json:"maxmind"` + Ext map[string]interface{} `json:"ext"` + } `json:"ips"` +} + +func (p *BrightDataProvider) GetIPList() ([]ProxyIP, error) { + return p.fetchIPList() +} + +func (p *BrightDataProvider) RefreshIPList() ([]ProxyIP, error) { + return p.fetchIPList() +} + +func (p *BrightDataProvider) fetchIPList() ([]ProxyIP, error) { + // 构建请求URL + url := p.endpoint + if p.zone != "" { + url = fmt.Sprintf("%s?zone=%s", p.endpoint, p.zone) + } + + // 创建HTTP请求 + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, fmt.Errorf("创建HTTP请求失败: %w", err) + } + + // 设置授权头 + if p.token != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", p.token)) + } + + // 发送请求 + resp, err := p.client.Do(req) + if err != nil { + return nil, fmt.Errorf("发送HTTP请求失败: %w", err) + } + defer resp.Body.Close() + + // 读取响应体 + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("读取HTTP响应失败: %w", err) + } + + // 检查状态码 + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API返回错误状态码 %d: %s", resp.StatusCode, string(body)) + } + + // 解析JSON数据(支持Bright Data格式) + var ipList BrightDataIPList + if err := json.Unmarshal(body, &ipList); err != nil { + return nil, fmt.Errorf("解析JSON数据失败: %w", err) + } + + // 转换为ProxyIP列表 + result := make([]ProxyIP, 0, len(ipList.IPs)) + for _, ip := range ipList.IPs { + result = append(result, ProxyIP{ + IP: ip.IP, + Protocol: "http", + Ext: ip.Ext, + }) + } + + if len(result) == 0 { + return nil, fmt.Errorf("API返回的IP列表为空") + } + + return result, nil +} diff --git a/proxy/fixed_provider.go b/proxy/fixed_provider.go new file mode 100644 index 00000000..267b047e --- /dev/null +++ b/proxy/fixed_provider.go @@ -0,0 +1,42 @@ +package proxy + +import "strings" + +// FixedIPProvider 固定IP列表提供者 +type FixedIPProvider struct { + ips []ProxyIP +} + +// NewFixedIPProvider 创建固定IP列表提供者 +func NewFixedIPProvider(proxyURLs []string) *FixedIPProvider { + ips := make([]ProxyIP, 0, len(proxyURLs)) + for _, proxyURL := range proxyURLs { + // 简单解析代理URL + // 格式: http://ip:port 或 socks5://user:pass@ip:port + protocol := "http" + if strings.HasPrefix(proxyURL, "socks5://") { + protocol = "socks5" + proxyURL = strings.TrimPrefix(proxyURL, "socks5://") + } else if strings.HasPrefix(proxyURL, "http://") { + proxyURL = strings.TrimPrefix(proxyURL, "http://") + } else if strings.HasPrefix(proxyURL, "https://") { + protocol = "https" + proxyURL = strings.TrimPrefix(proxyURL, "https://") + } + + ips = append(ips, ProxyIP{ + IP: proxyURL, + Protocol: protocol, + }) + } + + return &FixedIPProvider{ips: ips} +} + +func (p *FixedIPProvider) GetIPList() ([]ProxyIP, error) { + return p.ips, nil +} + +func (p *FixedIPProvider) RefreshIPList() ([]ProxyIP, error) { + return p.ips, nil +} diff --git a/proxy/provider.go b/proxy/provider.go new file mode 100644 index 00000000..b4d6e06d --- /dev/null +++ b/proxy/provider.go @@ -0,0 +1,10 @@ +package proxy + +// IPProvider IP提供者接口 +type IPProvider interface { + // GetIPList 获取IP列表 + GetIPList() ([]ProxyIP, error) + + // RefreshIPList 刷新IP列表(可选实现) + RefreshIPList() ([]ProxyIP, error) +} diff --git a/proxy/proxy_client.go b/proxy/proxy_client.go new file mode 100644 index 00000000..cda50b00 --- /dev/null +++ b/proxy/proxy_client.go @@ -0,0 +1,47 @@ +package proxy + +import ( + "log" + "net/http" + "time" +) + +// --- 便捷函数(直接使用全局管理器) --- + +// GetProxyHTTPClient 获取代理 HTTP 客户端(返回 ProxyClient,包含 ProxyID) +func GetProxyHTTPClient() (*ProxyClient, error) { + return GetGlobalProxyManager().GetProxyClient() +} + +// NewHTTPClient 创建一个新的HTTP客户端(使用全局代理配置) +// 注意:不返回 ProxyID,如需 ProxyID 请使用 GetProxyHTTPClient() +func NewHTTPClient() *http.Client { + client, err := GetGlobalProxyManager().GetProxyClient() + if err != nil { + log.Printf("⚠️ 获取代理客户端失败,使用直连: %v", err) + return &http.Client{Timeout: 30 * time.Second} + } + return client.Client +} + +// NewHTTPClientWithTimeout 创建一个新的HTTP客户端并指定超时时间 +// 注意:不返回 ProxyID,如需 ProxyID 请使用 GetProxyHTTPClient() +func NewHTTPClientWithTimeout(timeout time.Duration) *http.Client { + client, err := GetGlobalProxyManager().GetProxyClient() + if err != nil { + log.Printf("⚠️ 获取代理客户端失败,使用直连: %v", err) + return &http.Client{Timeout: timeout} + } + client.Client.Timeout = timeout + return client.Client +} + +// GetTransport 获取HTTP Transport +func GetTransport() *http.Transport { + client, err := GetGlobalProxyManager().GetProxyClient() + if err != nil { + log.Printf("⚠️ 获取代理客户端失败,使用直连: %v", err) + return &http.Transport{} + } + return client.Client.Transport.(*http.Transport) +} \ No newline at end of file diff --git a/proxy/proxy_manager.go b/proxy/proxy_manager.go new file mode 100644 index 00000000..aaca00e4 --- /dev/null +++ b/proxy/proxy_manager.go @@ -0,0 +1,346 @@ +package proxy + +import ( + "crypto/tls" + "fmt" + "log" + "math/rand" + "net/http" + "net/url" + "sync" + "time" +) + +// ProxyManager 代理管理器 +type ProxyManager struct { + config *Config + provider IPProvider + + // IP池管理 + ipList []ProxyIP + blacklist map[int]string // ProxyID -> IP + ipBlacklist map[string]int // IP -> 剩余TTL + mutex sync.RWMutex // 读写锁,保证线程安全 + + // 刷新控制 + stopRefresh chan struct{} +} + +var ( + globalProxyManager *ProxyManager + once sync.Once +) + +// InitGlobalProxyManager 初始化全局代理管理器 +func InitGlobalProxyManager(config *Config) error { + var err error + once.Do(func() { + globalProxyManager, err = NewProxyManager(config) + if err == nil && config.Enabled && config.RefreshInterval > 0 { + globalProxyManager.StartAutoRefresh() + } + }) + return err +} + +// GetGlobalProxyManager 获取全局代理管理器 +func GetGlobalProxyManager() *ProxyManager { + if globalProxyManager == nil { + // 如果未初始化,使用默认配置(禁用代理) + _ = InitGlobalProxyManager(&Config{Enabled: false}) + } + return globalProxyManager +} + +// NewProxyManager 创建代理管理器 +func NewProxyManager(config *Config) (*ProxyManager, error) { + if config == nil { + config = &Config{Enabled: false} + } + + // 设置默认值 + if config.Timeout == 0 { + config.Timeout = 30 * time.Second + } + if config.BlacklistTTL == 0 { + config.BlacklistTTL = 5 // 默认 TTL 为 5 次刷新 + } + if config.RefreshInterval == 0 && config.Mode == "brightdata" { + config.RefreshInterval = 30 * time.Minute // 默认 30 分钟刷新一次 + } + + m := &ProxyManager{ + config: config, + blacklist: make(map[int]string), + ipBlacklist: make(map[string]int), + stopRefresh: make(chan struct{}), + } + + // 如果未启用代理,直接返回 + if !config.Enabled { + log.Printf("🌐 HTTP 代理未启用,使用直连") + return m, nil + } + + // 根据模式选择IP提供者 + switch config.Mode { + case "single": + // 单个代理模式 + if config.ProxyURL == "" { + return nil, fmt.Errorf("single模式下必须配置proxy_url") + } + m.provider = NewSingleProxyProvider(config.ProxyURL) + log.Printf("🌐 HTTP 代理已启用 (单代理模式): %s", config.ProxyURL) + + case "pool": + // 代理池模式(固定列表) + if len(config.ProxyList) == 0 { + return nil, fmt.Errorf("pool模式下必须配置proxy_list") + } + m.provider = NewFixedIPProvider(config.ProxyList) + log.Printf("🌐 HTTP 代理已启用 (代理池模式): %d个代理", len(config.ProxyList)) + + case "brightdata": + // Bright Data动态获取模式 + if config.BrightDataEndpoint == "" { + return nil, fmt.Errorf("brightdata模式下必须配置brightdata_endpoint") + } + m.provider = NewBrightDataProvider(config.BrightDataEndpoint, config.BrightDataToken, config.BrightDataZone) + log.Printf("🌐 HTTP 代理已启用 (Bright Data模式): %s", config.BrightDataEndpoint) + + default: + // 默认使用single模式 + if config.ProxyURL == "" { + return nil, fmt.Errorf("未知的proxy模式: %s", config.Mode) + } + m.provider = NewSingleProxyProvider(config.ProxyURL) + log.Printf("🌐 HTTP 代理已启用 (默认模式): %s", config.ProxyURL) + } + + // 初始化IP列表 + if err := m.RefreshIPList(); err != nil { + return nil, fmt.Errorf("初始化IP列表失败: %w", err) + } + + return m, nil +} + +// RefreshIPList 刷新IP列表(线程安全) +func (m *ProxyManager) RefreshIPList() error { + if m.provider == nil { + return nil + } + + ips, err := m.provider.RefreshIPList() + if err != nil { + return err + } + + m.mutex.Lock() + defer m.mutex.Unlock() + + // 清理黑名单,TTL倒计时 + validIPs := make([]ProxyIP, 0, len(ips)) + newBlacklist := make(map[int]string) + + for _, ip := range ips { + if ttl, inBlacklist := m.ipBlacklist[ip.IP]; inBlacklist { + // TTL 倒计时 + m.ipBlacklist[ip.IP] = ttl - 1 + if ttl > 0 { + // 仍在黑名单中,跳过 + continue + } + // TTL 归零,从黑名单移除 + delete(m.ipBlacklist, ip.IP) + log.Printf("✓ 代理IP已从黑名单恢复: %s", ip.IP) + } + validIPs = append(validIPs, ip) + } + + m.ipList = validIPs + m.blacklist = newBlacklist + + log.Printf("✓ 刷新代理IP列表: 总计%d个,黑名单%d个,可用%d个", + len(ips), len(m.ipBlacklist), len(validIPs)) + + return nil +} + +// StartAutoRefresh 启动自动刷新 +func (m *ProxyManager) StartAutoRefresh() { + if m.config.RefreshInterval <= 0 { + return + } + + go func() { + ticker := time.NewTicker(m.config.RefreshInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if err := m.RefreshIPList(); err != nil { + log.Printf("⚠️ 自动刷新IP列表失败: %v", err) + } + case <-m.stopRefresh: + return + } + } + }() + + log.Printf("✓ 已启动代理IP自动刷新 (间隔: %v)", m.config.RefreshInterval) +} + +// StopAutoRefresh 停止自动刷新 +func (m *ProxyManager) StopAutoRefresh() { + close(m.stopRefresh) +} + +// getRandomProxy 随机获取一个可用代理(线程安全 - 读锁,确保不越界) +func (m *ProxyManager) getRandomProxy() (int, *ProxyIP, error) { + m.mutex.RLock() + defer m.mutex.RUnlock() + + if len(m.ipList) == 0 { + return -1, nil, fmt.Errorf("代理IP列表为空") + } + + // 找到所有未被黑名单的索引 + availableIndices := make([]int, 0, len(m.ipList)) + for i := range m.ipList { + if _, inBlacklist := m.blacklist[i]; !inBlacklist { + availableIndices = append(availableIndices, i) + } + } + + if len(availableIndices) == 0 { + return -1, nil, fmt.Errorf("所有代理IP都在黑名单中") + } + + // 随机选择一个(确保不越界) + randomIdx := availableIndices[rand.Intn(len(availableIndices))] + + // 二次检查,确保索引有效(防御性编程) + if randomIdx < 0 || randomIdx >= len(m.ipList) { + return -1, nil, fmt.Errorf("代理索引越界: %d (总数: %d)", randomIdx, len(m.ipList)) + } + + return randomIdx, &m.ipList[randomIdx], nil +} + +// buildProxyURL 构建代理URL +func (m *ProxyManager) buildProxyURL(ip *ProxyIP) string { + if m.config.ProxyHost != "" && m.config.ProxyUser != "" { + // 使用配置的代理主机和认证信息 + user := m.config.ProxyUser + if m.config.ProxyUser != "" && ip.IP != "" { + // 支持%s占位符替换IP + user = fmt.Sprintf(m.config.ProxyUser, ip.IP) + } + + protocol := ip.Protocol + if protocol == "" { + protocol = "http" + } + + if m.config.ProxyPassword != "" { + return fmt.Sprintf("%s://%s:%s@%s", protocol, user, m.config.ProxyPassword, m.config.ProxyHost) + } + return fmt.Sprintf("%s://%s@%s", protocol, user, m.config.ProxyHost) + } + + // 直接使用IP信息 + return ip.IP +} + +// GetProxyClient 获取代理客户端(线程安全) +func (m *ProxyManager) GetProxyClient() (*ProxyClient, error) { + if !m.config.Enabled { + // 未启用代理,返回普通HTTP客户端 + return &ProxyClient{ + ProxyID: -1, // -1 表示未使用代理 + IP: "direct", + Client: &http.Client{ + Timeout: m.config.Timeout, + }, + }, nil + } + + // 获取随机代理(使用读锁,确保不越界) + proxyID, proxyIP, err := m.getRandomProxy() + if err != nil { + return nil, err + } + + // 构建代理URL + proxyURLStr := m.buildProxyURL(proxyIP) + proxyURL, err := url.Parse(proxyURLStr) + if err != nil { + return nil, fmt.Errorf("解析代理URL失败: %w", err) + } + + // 创建Transport + transport := &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: false, + }, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + } + + return &ProxyClient{ + ProxyID: proxyID, + IP: proxyIP.IP, + Client: &http.Client{ + Transport: transport, + Timeout: m.config.Timeout, + }, + }, nil +} + +// AddBlacklist 将代理IP添加到黑名单(线程安全 - 写锁) +func (m *ProxyManager) AddBlacklist(proxyID int) { + m.mutex.Lock() + defer m.mutex.Unlock() + + // 检查 proxyID 有效性,防止越界 + if proxyID < 0 || proxyID >= len(m.ipList) { + log.Printf("⚠️ 无效的 ProxyID: %d (有效范围: 0-%d)", proxyID, len(m.ipList)-1) + return + } + + ip := m.ipList[proxyID].IP + m.blacklist[proxyID] = ip + m.ipBlacklist[ip] = m.config.BlacklistTTL + + log.Printf("⚠️ 代理IP已加入黑名单: %s (ProxyID: %d, TTL: %d)", ip, proxyID, m.config.BlacklistTTL) +} + +// GetBlacklistStatus 获取黑名单状态(线程安全 - 读锁) +func (m *ProxyManager) GetBlacklistStatus() (total int, blacklisted int, available int) { + m.mutex.RLock() + defer m.mutex.RUnlock() + + total = len(m.ipList) + blacklisted = len(m.ipBlacklist) + available = total - len(m.blacklist) + return +} + +// IsEnabled 检查代理是否启用 +func IsEnabled() bool { + return GetGlobalProxyManager().config.Enabled +} + +// RefreshIPList 刷新全局代理IP列表 +func RefreshIPList() error { + return GetGlobalProxyManager().RefreshIPList() +} + +// AddBlacklist 将代理IP添加到全局黑名单 +func AddBlacklist(proxyID int) { + GetGlobalProxyManager().AddBlacklist(proxyID) +} diff --git a/proxy/single_provider.go b/proxy/single_provider.go new file mode 100644 index 00000000..bbea9fce --- /dev/null +++ b/proxy/single_provider.go @@ -0,0 +1,19 @@ +package proxy + +// SingleProxyProvider 单个代理提供者(不使用IP池) +type SingleProxyProvider struct { + proxyURL string +} + +// NewSingleProxyProvider 创建单个代理提供者 +func NewSingleProxyProvider(proxyURL string) *SingleProxyProvider { + return &SingleProxyProvider{proxyURL: proxyURL} +} + +func (p *SingleProxyProvider) GetIPList() ([]ProxyIP, error) { + return []ProxyIP{{IP: p.proxyURL}}, nil +} + +func (p *SingleProxyProvider) RefreshIPList() ([]ProxyIP, error) { + return p.GetIPList() +} diff --git a/proxy/types.go b/proxy/types.go new file mode 100644 index 00000000..89678c86 --- /dev/null +++ b/proxy/types.go @@ -0,0 +1,40 @@ +package proxy + +import ( + "net/http" + "time" +) + +// ProxyIP 代理IP信息 +type ProxyIP struct { + IP string `json:"ip"` // IP地址 + Port string `json:"port"` // 端口(可选) + Username string `json:"username"` // 用户名(可选) + Password string `json:"password"` // 密码(可选) + Protocol string `json:"protocol"` // 协议: http, https, socks5 + Ext map[string]interface{} `json:"ext"` // 扩展信息 +} + +// ProxyClient 代理客户端 +type ProxyClient struct { + ProxyID int // IP池中的代理ID(索引) + IP string // 使用的IP地址 + *http.Client // HTTP客户端 +} + +// Config 代理配置 +type Config struct { + Enabled bool // 是否启用代理 + Mode string // 模式: "single", "pool", "brightdata" + Timeout time.Duration // 超时时间 + ProxyURL string // 单个代理地址 (single模式) + ProxyList []string // 代理列表 (pool模式) + BrightDataEndpoint string // Bright Data接口地址 (brightdata模式) + BrightDataToken string // Bright Data访问令牌 (brightdata模式) + BrightDataZone string // Bright Data区域 (brightdata模式) + ProxyHost string // 代理主机 + ProxyUser string // 代理用户名模板(支持%s占位符) + ProxyPassword string // 代理密码 + RefreshInterval time.Duration // IP列表刷新间隔 + BlacklistTTL int // 黑名单IP的TTL(刷新次数) +} diff --git a/scripts/generate_rsa_keys/main.go b/scripts/generate_rsa_keys/main.go new file mode 100644 index 00000000..c3f642e2 --- /dev/null +++ b/scripts/generate_rsa_keys/main.go @@ -0,0 +1,76 @@ +package main + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "io/ioutil" + "os" + "path/filepath" +) + +func main() { + keysDir := "keys" + if err := os.MkdirAll(keysDir, 0700); err != nil { + fmt.Printf("创建keys目录失败: %v\n", err) + return + } + + privateKeyPath := filepath.Join(keysDir, "rsa_private.key") + publicKeyPath := filepath.Join(keysDir, "rsa_private.key.pub") + + if _, err := os.Stat(privateKeyPath); err == nil { + fmt.Println("RSA密钥对已存在:") + fmt.Printf(" 私钥: %s\n", privateKeyPath) + fmt.Printf(" 公钥: %s\n", publicKeyPath) + + publicKeyPEM, err := ioutil.ReadFile(publicKeyPath) + if err == nil { + fmt.Println("\n公钥内容:") + fmt.Println(string(publicKeyPEM)) + } + return + } + + fmt.Println("生成新的RSA密钥对...") + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + fmt.Printf("生成RSA密钥失败: %v\n", err) + return + } + + privateKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + }) + + if err := ioutil.WriteFile(privateKeyPath, privateKeyPEM, 0600); err != nil { + fmt.Printf("保存私钥失败: %v\n", err) + return + } + + publicKeyDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey) + if err != nil { + fmt.Printf("编码公钥失败: %v\n", err) + return + } + + publicKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "PUBLIC KEY", + Bytes: publicKeyDER, + }) + + if err := ioutil.WriteFile(publicKeyPath, publicKeyPEM, 0644); err != nil { + fmt.Printf("保存公钥失败: %v\n", err) + return + } + + fmt.Println("✓ RSA密钥对生成成功!") + fmt.Printf(" 私钥: %s\n", privateKeyPath) + fmt.Printf(" 公钥: %s\n", publicKeyPath) + fmt.Println("\n公钥内容(可用于前端配置):") + fmt.Println(string(publicKeyPEM)) + fmt.Println("\n注意: 请妥善保管私钥文件,不要提交到版本控制系统中!") +} diff --git a/scripts/migrate_sensitive_data/main.go b/scripts/migrate_sensitive_data/main.go new file mode 100644 index 00000000..f5db9a3b --- /dev/null +++ b/scripts/migrate_sensitive_data/main.go @@ -0,0 +1,367 @@ +package main + +import ( + "bufio" + "database/sql" + "flag" + "fmt" + "log" + "nofx/crypto" + "os" + "strings" + "time" + + _ "github.com/lib/pq" +) + +func main() { + privateKeyPath := flag.String("key", "keys/rsa_private.key", "RSA 私钥路径") + dryRun := flag.Bool("dry-run", false, "仅检查需要迁移的数据,不写入数据库") + flag.Parse() + + // 尝试加载 .env 文件(从项目根目录运行时) + envPaths := []string{ + ".env", // 项目根目录 + } + envLoaded := false + for _, envPath := range envPaths { + if err := loadEnvFile(envPath); err == nil { + log.Printf("成功加载 .env 文件: %s", envPath) + envLoaded = true + break + } + } + if !envLoaded { + log.Printf("警告: 未找到 .env 文件,请确保在项目根目录存在 .env 文件") + log.Printf("尝试的路径: %v", envPaths) + } + + // 确保环境变量已设置 + if os.Getenv("DATA_ENCRYPTION_KEY") == "" { + log.Fatalf("迁移失败: DATA_ENCRYPTION_KEY 环境变量未设置") + } + + if err := run(*privateKeyPath, *dryRun); err != nil { + log.Fatalf("迁移失败: %v", err) + } +} + +func run(privateKeyPath string, dryRun bool) error { + log.SetFlags(0) + + // 尝试多个可能的私钥路径(从项目根目录运行时) + keyPaths := []string{ + privateKeyPath, // 用户指定的路径 + "keys/rsa_private.key", // 项目根目录的 keys 文件夹 + } + + var finalKeyPath string + for _, path := range keyPaths { + if _, err := os.Stat(path); err == nil { + finalKeyPath = path + log.Printf("找到私钥文件: %s", path) + break + } + } + + if finalKeyPath == "" { + finalKeyPath = privateKeyPath // 使用默认路径,让 crypto 服务生成新密钥 + log.Printf("警告: 私钥文件不存在,将使用路径: %s, 系统将尝试生成新密钥", finalKeyPath) + } + + cryptoService, err := crypto.NewCryptoService(finalKeyPath) + if err != nil { + return fmt.Errorf("初始化加密服务失败: %w", err) + } + + db, err := openPostgres() + if err != nil { + return fmt.Errorf("连接数据库失败: %w", err) + } + defer db.Close() + + log.Printf("开始迁移 AI 模型密钥 (dry-run=%v)", dryRun) + if err := migrateAIModels(db, cryptoService, dryRun); err != nil { + return fmt.Errorf("迁移 AI 模型失败: %w", err) + } + + log.Printf("开始迁移交易所密钥 (dry-run=%v)", dryRun) + if err := migrateExchanges(db, cryptoService, dryRun); err != nil { + return fmt.Errorf("迁移交易所失败: %w", err) + } + + log.Printf("✓ 敏感数据迁移完成") + return nil +} + +func openPostgres() (*sql.DB, error) { + host := getEnv("POSTGRES_HOST", "localhost") + // 如果是 Docker 服务名,替换为 localhost + if host == "postgres" { + host = "localhost" + } + port := getEnv("POSTGRES_PORT", "5432") + dbname := getEnv("POSTGRES_DB", "nofx") + user := getEnv("POSTGRES_USER", "nofx") + password := getEnv("POSTGRES_PASSWORD", "nofx123456") + + dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable", + host, port, user, password, dbname) + + db, err := sql.Open("postgres", dsn) + if err != nil { + return nil, err + } + + db.SetMaxOpenConns(5) + db.SetMaxIdleConns(2) + db.SetConnMaxLifetime(5 * time.Minute) + + if err := db.Ping(); err != nil { + db.Close() + return nil, err + } + + return db, nil +} + +func migrateAIModels(db *sql.DB, cryptoService *crypto.CryptoService, dryRun bool) error { + type record struct { + ID string + UserID string + APIKey string + } + + rows, err := db.Query(` + SELECT id, user_id, COALESCE(api_key, '') + FROM ai_models + WHERE COALESCE(deleted, FALSE) = FALSE + `) + if err != nil { + return err + } + defer rows.Close() + + var records []record + for rows.Next() { + var r record + if err := rows.Scan(&r.ID, &r.UserID, &r.APIKey); err != nil { + return err + } + records = append(records, r) + } + if err := rows.Err(); err != nil { + return err + } + + var updated int + for _, r := range records { + if r.APIKey == "" || cryptoService.IsEncryptedStorageValue(r.APIKey) { + continue + } + + encrypted, err := cryptoService.EncryptForStorage(r.APIKey, r.UserID, r.ID, "api_key") + if err != nil { + return fmt.Errorf("加密 AI 模型 %s (%s) 失败: %w", r.ID, r.UserID, err) + } + + updated++ + if dryRun { + log.Printf("[DRY-RUN] AI 模型 %s (%s) 将被加密", r.ID, r.UserID) + continue + } + + if _, err := db.Exec(` + UPDATE ai_models + SET api_key = $1, updated_at = CURRENT_TIMESTAMP + WHERE id = $2 AND user_id = $3 + `, encrypted, r.ID, r.UserID); err != nil { + return fmt.Errorf("更新 AI 模型 %s (%s) 失败: %w", r.ID, r.UserID, err) + } + } + + log.Printf("AI 模型处理完成,需更新 %d 条记录", updated) + return nil +} + +func migrateExchanges(db *sql.DB, cryptoService *crypto.CryptoService, dryRun bool) error { + type record struct { + ID string + UserID string + APIKey string + SecretKey string + HyperliquidWallet string + AsterUser string + AsterSigner string + AsterPrivateKey string + } + + rows, err := db.Query(` + SELECT id, user_id, + COALESCE(api_key, '') AS api_key, + COALESCE(secret_key, '') AS secret_key, + COALESCE(hyperliquid_wallet_addr, '') AS hyperliquid_wallet_addr, + COALESCE(aster_user, '') AS aster_user, + COALESCE(aster_signer, '') AS aster_signer, + COALESCE(aster_private_key, '') AS aster_private_key + FROM exchanges + WHERE COALESCE(deleted, FALSE) = FALSE + `) + if err != nil { + return err + } + defer rows.Close() + + var records []record + for rows.Next() { + var r record + if err := rows.Scan( + &r.ID, &r.UserID, + &r.APIKey, &r.SecretKey, + &r.HyperliquidWallet, + &r.AsterUser, &r.AsterSigner, &r.AsterPrivateKey, + ); err != nil { + return err + } + records = append(records, r) + } + if err := rows.Err(); err != nil { + return err + } + + var updated int + for _, r := range records { + newAPIKey := r.APIKey + newSecretKey := r.SecretKey + newHyper := r.HyperliquidWallet + newAsterUser := r.AsterUser + newAsterSigner := r.AsterSigner + newAsterPrivate := r.AsterPrivateKey + + changed := false + + if r.APIKey != "" && !cryptoService.IsEncryptedStorageValue(r.APIKey) { + enc, err := cryptoService.EncryptForStorage(r.APIKey, r.UserID, r.ID, "api_key") + if err != nil { + return fmt.Errorf("加密交易所 API Key 失败: %s (%s): %w", r.ID, r.UserID, err) + } + newAPIKey = enc + changed = true + } + if r.SecretKey != "" && !cryptoService.IsEncryptedStorageValue(r.SecretKey) { + enc, err := cryptoService.EncryptForStorage(r.SecretKey, r.UserID, r.ID, "secret_key") + if err != nil { + return fmt.Errorf("加密交易所 Secret Key 失败: %s (%s): %w", r.ID, r.UserID, err) + } + newSecretKey = enc + changed = true + } + if r.HyperliquidWallet != "" && !cryptoService.IsEncryptedStorageValue(r.HyperliquidWallet) { + enc, err := cryptoService.EncryptForStorage(r.HyperliquidWallet, r.UserID, r.ID, "hyperliquid_wallet_addr") + if err != nil { + return fmt.Errorf("加密 Hyperliquid 地址失败: %s (%s): %w", r.ID, r.UserID, err) + } + newHyper = enc + changed = true + } + if r.AsterUser != "" && !cryptoService.IsEncryptedStorageValue(r.AsterUser) { + enc, err := cryptoService.EncryptForStorage(r.AsterUser, r.UserID, r.ID, "aster_user") + if err != nil { + return fmt.Errorf("加密 Aster 用户失败: %s (%s): %w", r.ID, r.UserID, err) + } + newAsterUser = enc + changed = true + } + if r.AsterSigner != "" && !cryptoService.IsEncryptedStorageValue(r.AsterSigner) { + enc, err := cryptoService.EncryptForStorage(r.AsterSigner, r.UserID, r.ID, "aster_signer") + if err != nil { + return fmt.Errorf("加密 Aster Signer 失败: %s (%s): %w", r.ID, r.UserID, err) + } + newAsterSigner = enc + changed = true + } + if r.AsterPrivateKey != "" && !cryptoService.IsEncryptedStorageValue(r.AsterPrivateKey) { + enc, err := cryptoService.EncryptForStorage(r.AsterPrivateKey, r.UserID, r.ID, "aster_private_key") + if err != nil { + return fmt.Errorf("加密 Aster 私钥失败: %s (%s): %w", r.ID, r.UserID, err) + } + newAsterPrivate = enc + changed = true + } + + if !changed { + continue + } + + updated++ + if dryRun { + log.Printf("[DRY-RUN] 交易所 %s (%s) 将被加密", r.ID, r.UserID) + continue + } + + if _, err := db.Exec(` + UPDATE exchanges + SET api_key = $1, + secret_key = $2, + hyperliquid_wallet_addr = $3, + aster_user = $4, + aster_signer = $5, + aster_private_key = $6, + updated_at = CURRENT_TIMESTAMP + WHERE id = $7 AND user_id = $8 + `, newAPIKey, newSecretKey, newHyper, newAsterUser, newAsterSigner, newAsterPrivate, r.ID, r.UserID); err != nil { + return fmt.Errorf("更新交易所 %s (%s) 失败: %w", r.ID, r.UserID, err) + } + } + + log.Printf("交易所处理完成,需更新 %d 条记录", updated) + return nil +} + +func getEnv(key, fallback string) string { + if val := os.Getenv(key); val != "" { + return val + } + return fallback +} + +func loadEnvFile(filename string) error { + // 检查文件是否存在 + if _, err := os.Stat(filename); os.IsNotExist(err) { + return fmt.Errorf("文件不存在: %s", filename) + } + + // 打开文件 + file, err := os.Open(filename) + if err != nil { + return fmt.Errorf("无法打开文件: %w", err) + } + defer file.Close() + + // 逐行读取 + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + + // 跳过空行和注释行 + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + // 解析 KEY=VALUE 格式 + parts := strings.SplitN(line, "=", 2) + if len(parts) != 2 { + continue + } + + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + // 只有当环境变量不存在时才设置 + if os.Getenv(key) == "" { + os.Setenv(key, value) + } + } + + return scanner.Err() +} diff --git a/web/src/components/AITradersPage.tsx b/web/src/components/AITradersPage.tsx index 198821bd..325c7483 100644 --- a/web/src/components/AITradersPage.tsx +++ b/web/src/components/AITradersPage.tsx @@ -13,6 +13,7 @@ import { useAuth } from '../contexts/AuthContext' import { getExchangeIcon } from './ExchangeIcons' import { getModelIcon } from './ModelIcons' import { TraderConfigModal } from './TraderConfigModal' +import { TwoStageKeyModal } from './TwoStageKeyModal' import { Bot, Brain, @@ -46,6 +47,12 @@ function getShortName(fullName: string): string { return parts.length > 1 ? parts[parts.length - 1] : fullName } +function maskSecret(value: string): string { + if (!value) return '' + const length = Math.min(value.length, 16) + return '•'.repeat(length) +} + interface AITradersPageProps { onTraderSelect?: (traderId: string) => void } @@ -143,30 +150,9 @@ export function AITradersPage({ onTraderSelect }: AITradersPageProps) { allExchanges?.filter((e) => { if (!e.enabled) return false - // Aster 交易所需要特殊字段 - if (e.id === 'aster') { - return ( - e.asterUser && - e.asterUser.trim() !== '' && - e.asterSigner && - e.asterSigner.trim() !== '' && - e.asterPrivateKey && - e.asterPrivateKey.trim() !== '' - ) - } - - // Hyperliquid 只需要私钥(作为apiKey),钱包地址会自动从私钥生成 - if (e.id === 'hyperliquid') { - return e.apiKey && e.apiKey.trim() !== '' - } - - // Binance 等其他交易所需要 apiKey 和 secretKey - return ( - e.apiKey && - e.apiKey.trim() !== '' && - e.secretKey && - e.secretKey.trim() !== '' - ) + // 由于API不再返回敏感字段信息,只能基于enabled状态判断 + // 实际的配置验证将在后端进行 + return true }) || [] // 检查模型是否正在被运行中的交易员使用 @@ -445,7 +431,7 @@ export function AITradersPage({ onTraderSelect }: AITradersPageProps) { }, } - await api.updateExchangeConfigs(request) + await api.updateExchangeConfigsEncrypted(request) const refreshed = await api.getExchangeConfigs() setAllExchanges(refreshed) @@ -494,7 +480,7 @@ export function AITradersPage({ onTraderSelect }: AITradersPageProps) { }, } - await api.updateExchangeConfigs(request) + await api.updateExchangeConfigsEncrypted(request) const refreshedExchanges = await api.getExchangeConfigs() setAllExchanges(refreshedExchanges) @@ -811,7 +797,7 @@ export function AITradersPage({ onTraderSelect }: AITradersPageProps) {
) @@ -1666,6 +1652,9 @@ function ExchangeConfigModal({ const [asterUser, setAsterUser] = useState('') const [asterSigner, setAsterSigner] = useState('') const [asterPrivateKey, setAsterPrivateKey] = useState('') + const [secureInputTarget, setSecureInputTarget] = useState< + null | 'hyperliquid' | 'aster' + >(null) // 获取当前选择的交易所信息 // 编辑模式:从 configuredExchanges 查找(包含用户配置的 apiKey、secretKey 等) @@ -1674,24 +1663,50 @@ function ExchangeConfigModal({ ? configuredExchanges?.find(e => e.id === selectedExchangeId) : supportedExchanges?.find(e => e.id === selectedExchangeId); - // 如果是编辑现有交易所,初始化表单数据 + const secureInputContextLabel = + secureInputTarget === 'aster' + ? t('asterExchangeName', language) + : secureInputTarget === 'hyperliquid' + ? t('hyperliquidExchangeName', language) + : undefined + + // 如果是编辑现有交易所,清空所有敏感字段以保证安全 useEffect(() => { if (editingExchangeId && selectedExchange) { - setApiKey(selectedExchange.apiKey || '') - setSecretKey(selectedExchange.secretKey || '') - setPassphrase('') // Don't load existing passphrase for security + // 编辑模式下清空所有敏感字段,用户需要重新输入 + setApiKey('') + setSecretKey('') + setPassphrase('') setTestnet(selectedExchange.testnet || false) - - // Hyperliquid 字段 setHyperliquidWalletAddr(selectedExchange.hyperliquidWalletAddr || '') - - // Aster 字段 setAsterUser(selectedExchange.asterUser || '') - setAsterSigner(selectedExchange.asterSigner || '') - setAsterPrivateKey('') // Don't load existing private key for security + setAsterSigner('') + setAsterPrivateKey('') } }, [editingExchangeId, selectedExchange]) + const handleSecureInputComplete = ({ + value, + obfuscationLog, + }: { + value: string + obfuscationLog: string[] + }) => { + const trimmed = value.trim() + if (secureInputTarget === 'hyperliquid') { + setApiKey(trimmed) + } + if (secureInputTarget === 'aster') { + setAsterPrivateKey(trimmed) + } + console.log('Secure input obfuscation log:', obfuscationLog) + setSecureInputTarget(null) + } + + const handleSecureInputCancel = () => { + setSecureInputTarget(null) + } + // 加载服务器IP(当选择binance时) useEffect(() => { if (selectedExchangeId === 'binance' && !serverIP) { @@ -1755,11 +1770,12 @@ function ExchangeConfigModal({ } return ( -+ {t('twoStageModalDescription', language, { length: expectedLength })} +
++ {t('twoStageStage1Hint', language)} +
+
+ {manualObfuscationValue}
+
+ )}
+ + {t('twoStageStage2Hint', language)} +
+