mirror of
https://github.com/NoFxAiOS/nofx.git
synced 2026-07-03 11:00:58 +08:00
Merge branch 'origin/beta' into nofxos/test
# Conflicts: # config/database_pg.go
This commit is contained in:
@@ -13,6 +13,9 @@ REDIS_HOST=redis
|
||||
REDIS_PORT=6379
|
||||
REDIS_PASSWORD=redis123456
|
||||
|
||||
# 数据加密密钥
|
||||
DATA_ENCRYPTION_KEY=my_secret_encryption_key
|
||||
|
||||
# Ports Configuration
|
||||
# Backend API server port (internal: 8080, external: configurable)
|
||||
NOFX_BACKEND_PORT=8080
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -35,6 +35,11 @@ config.db
|
||||
certs/
|
||||
beta_codes.txt
|
||||
|
||||
# 密钥文件
|
||||
keys/
|
||||
*.key
|
||||
*.pem
|
||||
|
||||
# 决策日志
|
||||
decision_logs/
|
||||
coin_pool_cache/
|
||||
|
||||
167
api/server.go
167
api/server.go
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"nofx/auth"
|
||||
"nofx/config"
|
||||
"nofx/crypto"
|
||||
"nofx/decision"
|
||||
"nofx/manager"
|
||||
"nofx/trader"
|
||||
@@ -24,11 +25,12 @@ type Server struct {
|
||||
router *gin.Engine
|
||||
traderManager *manager.TraderManager
|
||||
database config.DatabaseInterface
|
||||
cryptoService *crypto.CryptoService
|
||||
port int
|
||||
}
|
||||
|
||||
// NewServer 创建API服务器
|
||||
func NewServer(traderManager *manager.TraderManager, database config.DatabaseInterface, port int) *Server {
|
||||
func NewServer(traderManager *manager.TraderManager, database config.DatabaseInterface, cryptoService *crypto.CryptoService, port int) *Server {
|
||||
// 设置为Release模式(减少日志输出)
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
|
||||
@@ -37,10 +39,17 @@ func NewServer(traderManager *manager.TraderManager, database config.DatabaseInt
|
||||
// 启用CORS
|
||||
router.Use(corsMiddleware())
|
||||
|
||||
if cryptoService == nil {
|
||||
log.Printf("⚠️ 加密服务未初始化,敏感数据加解密功能不可用")
|
||||
} else {
|
||||
database.SetCryptoService(cryptoService)
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
router: router,
|
||||
traderManager: traderManager,
|
||||
database: database,
|
||||
cryptoService: cryptoService,
|
||||
port: port,
|
||||
}
|
||||
|
||||
@@ -123,6 +132,7 @@ func (s *Server) setupRoutes() {
|
||||
// 交易所配置
|
||||
protected.GET("/exchanges", s.handleGetExchangeConfigs)
|
||||
protected.PUT("/exchanges", s.handleUpdateExchangeConfigs)
|
||||
protected.PUT("/exchanges/encrypted", s.handleUpdateExchangeConfigsEncrypted)
|
||||
|
||||
// 用户信号源配置
|
||||
protected.GET("/user/signal-sources", s.handleGetUserSignalSource)
|
||||
@@ -179,11 +189,19 @@ func (s *Server) handleGetSystemConfig(c *gin.Context) {
|
||||
betaModeStr, _ := s.database.GetSystemConfig("beta_mode")
|
||||
betaMode := betaModeStr == "true"
|
||||
|
||||
// 获取RSA公钥
|
||||
var rsaPublicKey string
|
||||
if s.cryptoService != nil {
|
||||
rsaPublicKey = s.cryptoService.GetPublicKeyPEM()
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"beta_mode": betaMode,
|
||||
"default_coins": defaultCoins,
|
||||
"btc_eth_leverage": btcEthLeverage,
|
||||
"altcoin_leverage": altcoinLeverage,
|
||||
"rsa_public_key": rsaPublicKey,
|
||||
"rsa_key_id": "rsa-key-2025-11-05",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -381,6 +399,21 @@ type ExchangeConfig struct {
|
||||
Testnet bool `json:"testnet,omitempty"`
|
||||
}
|
||||
|
||||
// SafeExchangeConfig 安全的交易所配置响应结构(不包含敏感信息)
|
||||
type SafeExchangeConfig struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Testnet bool `json:"testnet"`
|
||||
HyperliquidWalletAddr string `json:"hyperliquidWalletAddr"` // 钱包地址,非敏感信息
|
||||
AsterUser string `json:"asterUser"` // Aster用户名,非敏感信息
|
||||
Deleted bool `json:"deleted"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type UpdateModelConfigRequest struct {
|
||||
Models map[string]struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
@@ -1005,7 +1038,25 @@ func (s *Server) handleGetExchangeConfigs(c *gin.Context) {
|
||||
}
|
||||
log.Printf("✅ 找到 %d 个交易所配置", len(exchanges))
|
||||
|
||||
c.JSON(http.StatusOK, exchanges)
|
||||
// 转换为安全的响应结构,过滤敏感信息
|
||||
safeExchanges := make([]SafeExchangeConfig, len(exchanges))
|
||||
for i, exchange := range exchanges {
|
||||
safeExchanges[i] = SafeExchangeConfig{
|
||||
ID: exchange.ID,
|
||||
UserID: exchange.UserID,
|
||||
Name: exchange.Name,
|
||||
Type: exchange.Type,
|
||||
Enabled: exchange.Enabled,
|
||||
Testnet: exchange.Testnet,
|
||||
HyperliquidWalletAddr: exchange.HyperliquidWalletAddr, // 钱包地址,非敏感信息
|
||||
AsterUser: exchange.AsterUser, // Aster用户名,非敏感信息
|
||||
Deleted: exchange.Deleted,
|
||||
CreatedAt: exchange.CreatedAt,
|
||||
UpdatedAt: exchange.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, safeExchanges)
|
||||
}
|
||||
|
||||
// handleUpdateExchangeConfigs 更新交易所配置
|
||||
@@ -1638,8 +1689,10 @@ func (s *Server) handleCompleteRegistration(c *gin.Context) {
|
||||
// handleLogin 处理用户登录请求
|
||||
func (s *Server) handleLogin(c *gin.Context) {
|
||||
var req struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
Email string `json:"email"`
|
||||
EmailEncrypted *crypto.EncryptedPayload `json:"email_encrypted"`
|
||||
Password string `json:"password"`
|
||||
PasswordEncrypted *crypto.EncryptedPayload `json:"password_encrypted"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
@@ -1647,6 +1700,51 @@ func (s *Server) handleLogin(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.EmailEncrypted != nil {
|
||||
if s.cryptoService == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "加密服务不可用"})
|
||||
return
|
||||
}
|
||||
|
||||
decryptedEmail, err := s.cryptoService.DecryptSensitiveData(req.EmailEncrypted)
|
||||
if err != nil {
|
||||
log.Printf("❌ 登录邮箱解密失败: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "邮箱解密失败"})
|
||||
return
|
||||
}
|
||||
req.Email = decryptedEmail
|
||||
}
|
||||
|
||||
if req.PasswordEncrypted != nil {
|
||||
if s.cryptoService == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "加密服务不可用"})
|
||||
return
|
||||
}
|
||||
|
||||
decryptedPassword, err := s.cryptoService.DecryptSensitiveData(req.PasswordEncrypted)
|
||||
if err != nil {
|
||||
log.Printf("❌ 登录密码解密失败: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "密码解密失败"})
|
||||
return
|
||||
}
|
||||
req.Password = decryptedPassword
|
||||
}
|
||||
|
||||
req.Email = strings.TrimSpace(req.Email)
|
||||
if req.Email == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "邮箱不能为空"})
|
||||
return
|
||||
}
|
||||
if !strings.Contains(req.Email, "@") {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "邮箱格式错误"})
|
||||
return
|
||||
}
|
||||
|
||||
if strings.TrimSpace(req.Password) == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "密码不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
user, err := s.database.GetUserByEmail(req.Email)
|
||||
if err != nil {
|
||||
@@ -2026,3 +2124,64 @@ func (s *Server) handleGetPublicTraderConfig(c *gin.Context) {
|
||||
|
||||
c.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
// handleUpdateExchangeConfigsEncrypted 更新交易所配置(加密传输)
|
||||
func (s *Server) handleUpdateExchangeConfigsEncrypted(c *gin.Context) {
|
||||
if s.cryptoService == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "加密服务不可用"})
|
||||
return
|
||||
}
|
||||
|
||||
userID := c.GetString("user_id")
|
||||
|
||||
// 接收加密载荷
|
||||
var payload crypto.EncryptedPayload
|
||||
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 解密数据
|
||||
decryptedData, err := s.cryptoService.DecryptSensitiveData(&payload)
|
||||
if err != nil {
|
||||
log.Printf("❌ 解密失败: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "解密失败"})
|
||||
return
|
||||
}
|
||||
|
||||
// 解析解密后的数据
|
||||
var req UpdateExchangeConfigRequest
|
||||
if err := json.Unmarshal([]byte(decryptedData), &req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "数据格式错误"})
|
||||
return
|
||||
}
|
||||
|
||||
// 更新每个交易所的配置
|
||||
for exchangeID, exchangeData := range req.Exchanges {
|
||||
err := s.database.UpdateExchange(
|
||||
userID,
|
||||
exchangeID,
|
||||
exchangeData.Enabled,
|
||||
exchangeData.APIKey,
|
||||
exchangeData.SecretKey,
|
||||
exchangeData.Testnet,
|
||||
exchangeData.HyperliquidWalletAddr,
|
||||
exchangeData.AsterUser,
|
||||
exchangeData.AsterSigner,
|
||||
exchangeData.AsterPrivateKey,
|
||||
)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("更新交易所 %s 失败: %v", exchangeID, err)})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 重新加载该用户的所有交易员,使新配置立即生效
|
||||
err = s.traderManager.LoadUserTraders(s.database, userID)
|
||||
if err != nil {
|
||||
log.Printf("⚠️ 重新加载用户交易员到内存失败: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("✓ 交易所配置已通过加密方式更新")
|
||||
c.JSON(http.StatusOK, gin.H{"message": "交易所配置已更新"})
|
||||
}
|
||||
|
||||
@@ -22,5 +22,20 @@
|
||||
"jwt_secret": "Qk0kAa+d0iIEzXVHXbNbm+UaN3RNabmWtH8rDWZ5OPf+4GX8pBflAHodfpbipVMyrw1fsDanHsNBjhgbDeK9Jg==",
|
||||
"log": {
|
||||
"level": "info"
|
||||
},
|
||||
"proxy": {
|
||||
"enabled": false,
|
||||
"mode": "single",
|
||||
"timeout": 30,
|
||||
"proxy_url": "http://127.0.0.1:7890",
|
||||
"proxy_list": [],
|
||||
"brightdata_endpoint": "",
|
||||
"brightdata_token": "",
|
||||
"brightdata_zone": "",
|
||||
"proxy_host": "",
|
||||
"proxy_user": "",
|
||||
"proxy_password": "",
|
||||
"refresh_interval": 0,
|
||||
"blacklist_ttl": 5
|
||||
}
|
||||
}
|
||||
|
||||
@@ -75,8 +75,25 @@ type Config struct {
|
||||
StopTradingMinutes int `json:"stop_trading_minutes"`
|
||||
Leverage LeverageConfig `json:"leverage"` // 杠杆配置
|
||||
Log *LogConfig `json:"log"` // 日志配置(可选)
|
||||
Proxy *ProxyConfig `json:"proxy"` // HTTP 代理配置(可选)
|
||||
}
|
||||
|
||||
// ProxyConfig HTTP 代理配置
|
||||
type ProxyConfig struct {
|
||||
Enabled bool `json:"enabled"` // 是否启用代理
|
||||
Mode string `json:"mode"` // 模式: "single", "pool", "brightdata"
|
||||
Timeout int `json:"timeout"` // 超时时间(秒)
|
||||
ProxyURL string `json:"proxy_url"` // 单个代理地址
|
||||
ProxyList []string `json:"proxy_list"` // 代理列表
|
||||
BrightDataEndpoint string `json:"brightdata_endpoint"` // Bright Data接口地址
|
||||
BrightDataToken string `json:"brightdata_token"` // Bright Data访问令牌
|
||||
BrightDataZone string `json:"brightdata_zone"` // Bright Data区域
|
||||
ProxyHost string `json:"proxy_host"` // 代理主机
|
||||
ProxyUser string `json:"proxy_user"` // 代理用户名模板
|
||||
ProxyPassword string `json:"proxy_password"` // 代理密码
|
||||
RefreshInterval int `json:"refresh_interval"` // 刷新间隔(秒)
|
||||
BlacklistTTL int `json:"blacklist_ttl"` // 黑名单TTL
|
||||
}
|
||||
// LoadConfig 从文件加载配置
|
||||
func LoadConfig(filename string) (*Config, error) {
|
||||
data, err := os.ReadFile(filename)
|
||||
|
||||
@@ -1,126 +1,130 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"nofx/crypto"
|
||||
)
|
||||
|
||||
// DatabaseInterface 定义了数据库实现需要提供的方法集合
|
||||
type DatabaseInterface interface {
|
||||
CreateUser(user *User) error
|
||||
GetUserByEmail(email string) (*User, error)
|
||||
GetUserByID(userID string) (*User, error)
|
||||
GetAllUsers() ([]string, error)
|
||||
UpdateUserOTPVerified(userID string, verified bool) error
|
||||
GetAIModels(userID string) ([]*AIModelConfig, error)
|
||||
UpdateAIModel(userID, id string, enabled bool, apiKey, customAPIURL, customModelName string) error
|
||||
GetExchanges(userID string) ([]*ExchangeConfig, error)
|
||||
UpdateExchange(userID, id string, enabled bool, apiKey, secretKey string, testnet bool, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string) error
|
||||
CreateAIModel(userID, id, name, provider string, enabled bool, apiKey, customAPIURL string) error
|
||||
CreateExchange(userID, id, name, typ string, enabled bool, apiKey, secretKey string, testnet bool, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string) error
|
||||
CreateTrader(trader *TraderRecord) error
|
||||
GetTraders(userID string) ([]*TraderRecord, error)
|
||||
UpdateTraderStatus(userID, id string, isRunning bool) error
|
||||
UpdateTrader(trader *TraderRecord) error
|
||||
UpdateTraderInitialBalance(userID, id string, newBalance float64) error
|
||||
UpdateTraderCustomPrompt(userID, id string, customPrompt string, overrideBase bool) error
|
||||
DeleteTrader(userID, id string) error
|
||||
GetTraderConfig(userID, traderID string) (*TraderRecord, *AIModelConfig, *ExchangeConfig, error)
|
||||
GetSystemConfig(key string) (string, error)
|
||||
SetSystemConfig(key, value string) error
|
||||
CreateUserSignalSource(userID, coinPoolURL, oiTopURL string) error
|
||||
GetUserSignalSource(userID string) (*UserSignalSource, error)
|
||||
UpdateUserSignalSource(userID, coinPoolURL, oiTopURL string) error
|
||||
GetCustomCoins() []string
|
||||
LoadBetaCodesFromFile(filePath string) error
|
||||
ValidateBetaCode(code string) (bool, error)
|
||||
UseBetaCode(code, userEmail string) error
|
||||
GetBetaCodeStats() (total, used int, err error)
|
||||
Close() error
|
||||
SetCryptoService(cs *crypto.CryptoService)
|
||||
CreateUser(user *User) error
|
||||
GetUserByEmail(email string) (*User, error)
|
||||
GetUserByID(userID string) (*User, error)
|
||||
GetAllUsers() ([]string, error)
|
||||
UpdateUserOTPVerified(userID string, verified bool) error
|
||||
GetAIModels(userID string) ([]*AIModelConfig, error)
|
||||
UpdateAIModel(userID, id string, enabled bool, apiKey, customAPIURL, customModelName string) error
|
||||
GetExchanges(userID string) ([]*ExchangeConfig, error)
|
||||
UpdateExchange(userID, id string, enabled bool, apiKey, secretKey string, testnet bool, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string) error
|
||||
CreateAIModel(userID, id, name, provider string, enabled bool, apiKey, customAPIURL string) error
|
||||
CreateExchange(userID, id, name, typ string, enabled bool, apiKey, secretKey string, testnet bool, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string) error
|
||||
CreateTrader(trader *TraderRecord) error
|
||||
GetTraders(userID string) ([]*TraderRecord, error)
|
||||
UpdateTraderStatus(userID, id string, isRunning bool) error
|
||||
UpdateTrader(trader *TraderRecord) error
|
||||
UpdateTraderInitialBalance(userID, id string, newBalance float64) error
|
||||
UpdateTraderCustomPrompt(userID, id string, customPrompt string, overrideBase bool) error
|
||||
DeleteTrader(userID, id string) error
|
||||
GetTraderConfig(userID, traderID string) (*TraderRecord, *AIModelConfig, *ExchangeConfig, error)
|
||||
GetSystemConfig(key string) (string, error)
|
||||
SetSystemConfig(key, value string) error
|
||||
CreateUserSignalSource(userID, coinPoolURL, oiTopURL string) error
|
||||
GetUserSignalSource(userID string) (*UserSignalSource, error)
|
||||
UpdateUserSignalSource(userID, coinPoolURL, oiTopURL string) error
|
||||
GetCustomCoins() []string
|
||||
LoadBetaCodesFromFile(filePath string) error
|
||||
ValidateBetaCode(code string) (bool, error)
|
||||
UseBetaCode(code, userEmail string) error
|
||||
GetBetaCodeStats() (total, used int, err error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
// User 用户配置
|
||||
type User struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
PasswordHash string `json:"-"` // 不返回到前端
|
||||
OTPSecret string `json:"-"` // 不返回到前端
|
||||
OTPVerified bool `json:"otp_verified"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
PasswordHash string `json:"-"`
|
||||
OTPSecret string `json:"-"`
|
||||
OTPVerified bool `json:"otp_verified"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// AIModelConfig AI模型配置
|
||||
type AIModelConfig struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"apiKey"`
|
||||
CustomAPIURL string `json:"customApiUrl"`
|
||||
CustomModelName string `json:"customModelName"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"apiKey"`
|
||||
CustomAPIURL string `json:"customApiUrl"`
|
||||
CustomModelName string `json:"customModelName"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// ExchangeConfig 交易所配置
|
||||
type ExchangeConfig struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"apiKey"`
|
||||
SecretKey string `json:"secretKey"`
|
||||
Testnet bool `json:"testnet"`
|
||||
HyperliquidWalletAddr string `json:"hyperliquidWalletAddr"`
|
||||
AsterUser string `json:"asterUser"`
|
||||
AsterSigner string `json:"asterSigner"`
|
||||
AsterPrivateKey string `json:"asterPrivateKey"`
|
||||
Deleted bool `json:"deleted"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"apiKey"`
|
||||
SecretKey string `json:"secretKey"`
|
||||
Testnet bool `json:"testnet"`
|
||||
HyperliquidWalletAddr string `json:"hyperliquidWalletAddr"`
|
||||
AsterUser string `json:"asterUser"`
|
||||
AsterSigner string `json:"asterSigner"`
|
||||
AsterPrivateKey string `json:"asterPrivateKey"`
|
||||
DEXWalletPrivateKey string `json:"dexWalletPrivateKey"` // 统一的DEX私钥字段
|
||||
Deleted bool `json:"deleted"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// TraderRecord 交易员配置(数据库实体)
|
||||
// TraderRecord 交易员配置
|
||||
type TraderRecord struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
AIModelID string `json:"ai_model_id"`
|
||||
ExchangeID string `json:"exchange_id"`
|
||||
InitialBalance float64 `json:"initial_balance"`
|
||||
ScanIntervalMinutes int `json:"scan_interval_minutes"`
|
||||
IsRunning bool `json:"is_running"`
|
||||
BTCETHLeverage int `json:"btc_eth_leverage"`
|
||||
AltcoinLeverage int `json:"altcoin_leverage"`
|
||||
TradingSymbols string `json:"trading_symbols"`
|
||||
UseCoinPool bool `json:"use_coin_pool"`
|
||||
UseOITop bool `json:"use_oi_top"`
|
||||
CustomPrompt string `json:"custom_prompt"`
|
||||
OverrideBasePrompt bool `json:"override_base_prompt"`
|
||||
SystemPromptTemplate string `json:"system_prompt_template"`
|
||||
IsCrossMargin bool `json:"is_cross_margin"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
AIModelID string `json:"ai_model_id"`
|
||||
ExchangeID string `json:"exchange_id"`
|
||||
InitialBalance float64 `json:"initial_balance"`
|
||||
ScanIntervalMinutes int `json:"scan_interval_minutes"`
|
||||
IsRunning bool `json:"is_running"`
|
||||
BTCETHLeverage int `json:"btc_eth_leverage"`
|
||||
AltcoinLeverage int `json:"altcoin_leverage"`
|
||||
TradingSymbols string `json:"trading_symbols"`
|
||||
UseCoinPool bool `json:"use_coin_pool"`
|
||||
UseOITop bool `json:"use_oi_top"`
|
||||
CustomPrompt string `json:"custom_prompt"`
|
||||
OverrideBasePrompt bool `json:"override_base_prompt"`
|
||||
SystemPromptTemplate string `json:"system_prompt_template"`
|
||||
IsCrossMargin bool `json:"is_cross_margin"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// UserSignalSource 用户信号源配置
|
||||
type UserSignalSource struct {
|
||||
ID int `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
CoinPoolURL string `json:"coin_pool_url"`
|
||||
OITopURL string `json:"oi_top_url"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
ID int `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
CoinPoolURL string `json:"coin_pool_url"`
|
||||
OITopURL string `json:"oi_top_url"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// NewDatabase 创建数据库连接(仅支持 PostgreSQL)
|
||||
func NewDatabase() (DatabaseInterface, error) {
|
||||
pgDB, err := NewPostgreSQLDatabase()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建PostgreSQL数据库失败: %w", err)
|
||||
}
|
||||
return pgDB, nil
|
||||
pgDB, err := NewPostgreSQLDatabase()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建PostgreSQL数据库失败: %w", err)
|
||||
}
|
||||
return pgDB, nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"nofx/crypto"
|
||||
"nofx/market"
|
||||
"os"
|
||||
"slices"
|
||||
@@ -16,7 +17,8 @@ import (
|
||||
|
||||
// PostgreSQLDatabase PostgreSQL数据库配置
|
||||
type PostgreSQLDatabase struct {
|
||||
db *sql.DB
|
||||
db *sql.DB
|
||||
cryptoService *crypto.CryptoService
|
||||
}
|
||||
|
||||
// NewPostgreSQLDatabase 创建PostgreSQL数据库连接
|
||||
@@ -60,6 +62,42 @@ func NewPostgreSQLDatabase() (*PostgreSQLDatabase, error) {
|
||||
return database, nil
|
||||
}
|
||||
|
||||
func (d *PostgreSQLDatabase) SetCryptoService(cs *crypto.CryptoService) {
|
||||
d.cryptoService = cs
|
||||
}
|
||||
|
||||
func (d *PostgreSQLDatabase) encryptValue(value string, aadParts ...string) (string, error) {
|
||||
if value == "" {
|
||||
return "", nil
|
||||
}
|
||||
if d.cryptoService == nil {
|
||||
return "", fmt.Errorf("crypto service not initialized")
|
||||
}
|
||||
if !d.cryptoService.HasDataKey() {
|
||||
return "", fmt.Errorf("data encryption key not configured")
|
||||
}
|
||||
if d.cryptoService.IsEncryptedStorageValue(value) {
|
||||
return value, nil
|
||||
}
|
||||
return d.cryptoService.EncryptForStorage(value, aadParts...)
|
||||
}
|
||||
|
||||
func (d *PostgreSQLDatabase) decryptValue(value string, aadParts ...string) (string, error) {
|
||||
if value == "" {
|
||||
return "", nil
|
||||
}
|
||||
if d.cryptoService == nil {
|
||||
return "", fmt.Errorf("crypto service not initialized")
|
||||
}
|
||||
if !d.cryptoService.HasDataKey() {
|
||||
return "", fmt.Errorf("data encryption key not configured")
|
||||
}
|
||||
if !d.cryptoService.IsEncryptedStorageValue(value) {
|
||||
return "", fmt.Errorf("value is not encrypted")
|
||||
}
|
||||
return d.cryptoService.DecryptFromStorage(value, aadParts...)
|
||||
}
|
||||
|
||||
// getEnv 获取环境变量,如果不存在返回默认值
|
||||
func getEnv(key, defaultValue string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
@@ -162,6 +200,15 @@ func (d *PostgreSQLDatabase) GetAIModels(userID string) ([]*AIModelConfig, error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if model.APIKey != "" {
|
||||
decrypted, err := d.decryptValue(model.APIKey, model.UserID, model.ID, "api_key")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
model.APIKey = decrypted
|
||||
}
|
||||
|
||||
models = append(models, &model)
|
||||
}
|
||||
|
||||
@@ -216,7 +263,7 @@ func (d *PostgreSQLDatabase) UpdateAIModel(userID, id string, enabled bool, apiK
|
||||
log.Printf("🗑️ UpdateAIModel: 已标记删除用户 %s 的模型配置 %s (通过provider匹配)", userID, existingID)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
// 没有找到配置,返回成功(幂等性)
|
||||
log.Printf("ℹ️ UpdateAIModel: 模型配置不存在,跳过删除: %s", id)
|
||||
return nil
|
||||
@@ -229,11 +276,18 @@ func (d *PostgreSQLDatabase) UpdateAIModel(userID, id string, enabled bool, apiK
|
||||
`, userID, id).Scan(&existingID)
|
||||
|
||||
if err == nil {
|
||||
apiKeyEnc, err := d.encryptValue(apiKey, userID, existingID, "api_key")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 找到了现有配置(精确匹配 ID),更新它
|
||||
_, err = d.db.Exec(`
|
||||
UPDATE ai_models SET enabled = $1, api_key = $2, custom_api_url = $3, custom_model_name = $4, deleted = FALSE, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = $5 AND user_id = $6
|
||||
`, enabled, apiKey, customAPIURL, customModelName, existingID, userID)
|
||||
`, enabled, apiKeyEnc, customAPIURL, customModelName, existingID, userID)
|
||||
return err
|
||||
}
|
||||
if err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -244,12 +298,19 @@ func (d *PostgreSQLDatabase) UpdateAIModel(userID, id string, enabled bool, apiK
|
||||
`, userID, provider).Scan(&existingID)
|
||||
|
||||
if err == nil {
|
||||
apiKeyEnc, err := d.encryptValue(apiKey, userID, existingID, "api_key")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 找到了现有配置(通过 provider 匹配,兼容旧版),更新它
|
||||
log.Printf("⚠️ 使用旧版 provider 匹配更新模型: %s -> %s", provider, existingID)
|
||||
_, err = d.db.Exec(`
|
||||
UPDATE ai_models SET enabled = $1, api_key = $2, custom_api_url = $3, custom_model_name = $4, deleted = FALSE, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = $5 AND user_id = $6
|
||||
`, enabled, apiKey, customAPIURL, customModelName, existingID, userID)
|
||||
`, enabled, apiKeyEnc, customAPIURL, customModelName, existingID, userID)
|
||||
return err
|
||||
}
|
||||
if err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -292,11 +353,16 @@ func (d *PostgreSQLDatabase) UpdateAIModel(userID, id string, enabled bool, apiK
|
||||
newModelID = fmt.Sprintf("%s_%s", userID, provider)
|
||||
}
|
||||
|
||||
apiKeyEnc, err := d.encryptValue(apiKey, userID, newModelID, "api_key")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("✓ 创建新的 AI 模型配置: ID=%s, Provider=%s, Name=%s", newModelID, provider, name)
|
||||
_, err = d.db.Exec(`
|
||||
INSERT INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url, custom_model_name, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
`, newModelID, userID, name, provider, enabled, apiKey, customAPIURL, customModelName)
|
||||
`, newModelID, userID, name, provider, enabled, apiKeyEnc, customAPIURL, customModelName)
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -309,6 +375,7 @@ func (d *PostgreSQLDatabase) GetExchanges(userID string) ([]*ExchangeConfig, err
|
||||
COALESCE(aster_user, '') AS aster_user,
|
||||
COALESCE(aster_signer, '') AS aster_signer,
|
||||
COALESCE(aster_private_key, '') AS aster_private_key,
|
||||
COALESCE(dex_wallet_private_key, '') AS dex_wallet_private_key,
|
||||
COALESCE(deleted, FALSE) AS deleted,
|
||||
created_at, updated_at
|
||||
FROM exchanges
|
||||
@@ -329,12 +396,50 @@ func (d *PostgreSQLDatabase) GetExchanges(userID string) ([]*ExchangeConfig, err
|
||||
&exchange.Enabled, &exchange.APIKey, &exchange.SecretKey, &exchange.Testnet,
|
||||
&exchange.HyperliquidWalletAddr, &exchange.AsterUser,
|
||||
&exchange.AsterSigner, &exchange.AsterPrivateKey,
|
||||
&exchange.DEXWalletPrivateKey,
|
||||
&exchange.Deleted,
|
||||
&exchange.CreatedAt, &exchange.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if decrypted, err := d.decryptValue(exchange.APIKey, exchange.UserID, exchange.ID, "api_key"); err == nil {
|
||||
exchange.APIKey = decrypted
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
if decrypted, err := d.decryptValue(exchange.SecretKey, exchange.UserID, exchange.ID, "secret_key"); err == nil {
|
||||
exchange.SecretKey = decrypted
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
if decrypted, err := d.decryptValue(exchange.HyperliquidWalletAddr, exchange.UserID, exchange.ID, "hyperliquid_wallet_addr"); err == nil {
|
||||
exchange.HyperliquidWalletAddr = decrypted
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
if decrypted, err := d.decryptValue(exchange.AsterUser, exchange.UserID, exchange.ID, "aster_user"); err == nil {
|
||||
exchange.AsterUser = decrypted
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
if decrypted, err := d.decryptValue(exchange.AsterSigner, exchange.UserID, exchange.ID, "aster_signer"); err == nil {
|
||||
exchange.AsterSigner = decrypted
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
if decrypted, err := d.decryptValue(exchange.AsterPrivateKey, exchange.UserID, exchange.ID, "aster_private_key"); err == nil {
|
||||
exchange.AsterPrivateKey = decrypted
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
if decrypted, err := d.decryptValue(exchange.DEXWalletPrivateKey, exchange.UserID, exchange.ID, "dex_wallet_private_key"); err == nil {
|
||||
exchange.DEXWalletPrivateKey = decrypted
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
exchanges = append(exchanges, &exchange)
|
||||
}
|
||||
|
||||
@@ -345,7 +450,7 @@ func (d *PostgreSQLDatabase) GetExchanges(userID string) ([]*ExchangeConfig, err
|
||||
func (d *PostgreSQLDatabase) UpdateExchange(userID, id string, enabled bool, apiKey, secretKey string, testnet bool, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string) error {
|
||||
log.Printf("🔧 UpdateExchange: userID=%s, id=%s, enabled=%v", userID, id, enabled)
|
||||
|
||||
// 如果请求禁用该交易所,标记为已删除
|
||||
// 如果请求禁用该交易所,执行软删除
|
||||
if !enabled {
|
||||
_, err := d.db.Exec(`
|
||||
UPDATE exchanges
|
||||
@@ -369,13 +474,38 @@ func (d *PostgreSQLDatabase) UpdateExchange(userID, id string, enabled bool, api
|
||||
return nil
|
||||
}
|
||||
|
||||
apiKeyEnc, err := d.encryptValue(apiKey, userID, id, "api_key")
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt api_key failed: %w", err)
|
||||
}
|
||||
secretKeyEnc, err := d.encryptValue(secretKey, userID, id, "secret_key")
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt secret_key failed: %w", err)
|
||||
}
|
||||
hyperAddrEnc, err := d.encryptValue(hyperliquidWalletAddr, userID, id, "hyperliquid_wallet_addr")
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt hyperliquid_wallet_addr failed: %w", err)
|
||||
}
|
||||
asterUserEnc, err := d.encryptValue(asterUser, userID, id, "aster_user")
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt aster_user failed: %w", err)
|
||||
}
|
||||
asterSignerEnc, err := d.encryptValue(asterSigner, userID, id, "aster_signer")
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt aster_signer failed: %w", err)
|
||||
}
|
||||
asterPrivateKeyEnc, err := d.encryptValue(asterPrivateKey, userID, id, "aster_private_key")
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt aster_private_key failed: %w", err)
|
||||
}
|
||||
|
||||
// 首先尝试更新现有的用户配置
|
||||
result, err := d.db.Exec(`
|
||||
UPDATE exchanges SET enabled = $1, api_key = $2, secret_key = $3, testnet = $4,
|
||||
hyperliquid_wallet_addr = $5, aster_user = $6, aster_signer = $7, aster_private_key = $8,
|
||||
deleted = FALSE, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = $9 AND user_id = $10
|
||||
`, enabled, apiKey, secretKey, testnet, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey, id, userID)
|
||||
`, enabled, apiKeyEnc, secretKeyEnc, testnet, hyperAddrEnc, asterUserEnc, asterSignerEnc, asterPrivateKeyEnc, id, userID)
|
||||
if err != nil {
|
||||
log.Printf("❌ UpdateExchange: 更新失败: %v", err)
|
||||
return err
|
||||
@@ -418,7 +548,7 @@ func (d *PostgreSQLDatabase) UpdateExchange(userID, id string, enabled bool, api
|
||||
hyperliquid_wallet_addr, aster_user, aster_signer, aster_private_key,
|
||||
deleted, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, TRUE, $5, $6, $7, $8, $9, $10, $11, FALSE, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
`, id, userID, name, typ, apiKey, secretKey, testnet, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey)
|
||||
`, id, userID, name, typ, apiKeyEnc, secretKeyEnc, testnet, hyperAddrEnc, asterUserEnc, asterSignerEnc, asterPrivateKeyEnc)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("❌ UpdateExchange: 创建记录失败: %v", err)
|
||||
@@ -434,21 +564,51 @@ func (d *PostgreSQLDatabase) UpdateExchange(userID, id string, enabled bool, api
|
||||
|
||||
// CreateAIModel 创建AI模型配置
|
||||
func (d *PostgreSQLDatabase) CreateAIModel(userID, id, name, provider string, enabled bool, apiKey, customAPIURL string) error {
|
||||
_, err := d.db.Exec(`
|
||||
apiKeyEnc, err := d.encryptValue(apiKey, userID, id, "api_key")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = d.db.Exec(`
|
||||
INSERT INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
ON CONFLICT (id) DO NOTHING
|
||||
`, id, userID, name, provider, enabled, apiKey, customAPIURL)
|
||||
`, id, userID, name, provider, enabled, apiKeyEnc, customAPIURL)
|
||||
return err
|
||||
}
|
||||
|
||||
// CreateExchange 创建交易所配置
|
||||
func (d *PostgreSQLDatabase) CreateExchange(userID, id, name, typ string, enabled bool, apiKey, secretKey string, testnet bool, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string) error {
|
||||
_, err := d.db.Exec(`
|
||||
apiKeyEnc, err := d.encryptValue(apiKey, userID, id, "api_key")
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt api_key failed: %w", err)
|
||||
}
|
||||
secretKeyEnc, err := d.encryptValue(secretKey, userID, id, "secret_key")
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt secret_key failed: %w", err)
|
||||
}
|
||||
hyperAddrEnc, err := d.encryptValue(hyperliquidWalletAddr, userID, id, "hyperliquid_wallet_addr")
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt hyperliquid_wallet_addr failed: %w", err)
|
||||
}
|
||||
asterUserEnc, err := d.encryptValue(asterUser, userID, id, "aster_user")
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt aster_user failed: %w", err)
|
||||
}
|
||||
asterSignerEnc, err := d.encryptValue(asterSigner, userID, id, "aster_signer")
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt aster_signer failed: %w", err)
|
||||
}
|
||||
asterPrivateKeyEnc, err := d.encryptValue(asterPrivateKey, userID, id, "aster_private_key")
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt aster_private_key failed: %w", err)
|
||||
}
|
||||
|
||||
_, err = d.db.Exec(`
|
||||
INSERT INTO exchanges (id, user_id, name, type, enabled, api_key, secret_key, testnet, hyperliquid_wallet_addr, aster_user, aster_signer, aster_private_key)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
|
||||
ON CONFLICT (id, user_id) DO NOTHING
|
||||
`, id, userID, name, typ, enabled, apiKey, secretKey, testnet, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey)
|
||||
`, id, userID, name, typ, enabled, apiKeyEnc, secretKeyEnc, testnet, hyperAddrEnc, asterUserEnc, asterSignerEnc, asterPrivateKeyEnc)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -575,6 +735,57 @@ func (d *PostgreSQLDatabase) GetTraderConfig(userID, traderID string) (*TraderRe
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
if aiModel.APIKey != "" {
|
||||
decrypted, err := d.decryptValue(aiModel.APIKey, aiModel.UserID, aiModel.ID, "api_key")
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
aiModel.APIKey = decrypted
|
||||
}
|
||||
|
||||
if exchange.APIKey != "" {
|
||||
decrypted, err := d.decryptValue(exchange.APIKey, exchange.UserID, exchange.ID, "api_key")
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
exchange.APIKey = decrypted
|
||||
}
|
||||
if exchange.SecretKey != "" {
|
||||
decrypted, err := d.decryptValue(exchange.SecretKey, exchange.UserID, exchange.ID, "secret_key")
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
exchange.SecretKey = decrypted
|
||||
}
|
||||
if exchange.HyperliquidWalletAddr != "" {
|
||||
decrypted, err := d.decryptValue(exchange.HyperliquidWalletAddr, exchange.UserID, exchange.ID, "hyperliquid_wallet_addr")
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
exchange.HyperliquidWalletAddr = decrypted
|
||||
}
|
||||
if exchange.AsterUser != "" {
|
||||
decrypted, err := d.decryptValue(exchange.AsterUser, exchange.UserID, exchange.ID, "aster_user")
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
exchange.AsterUser = decrypted
|
||||
}
|
||||
if exchange.AsterSigner != "" {
|
||||
decrypted, err := d.decryptValue(exchange.AsterSigner, exchange.UserID, exchange.ID, "aster_signer")
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
exchange.AsterSigner = decrypted
|
||||
}
|
||||
if exchange.AsterPrivateKey != "" {
|
||||
decrypted, err := d.decryptValue(exchange.AsterPrivateKey, exchange.UserID, exchange.ID, "aster_private_key")
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
exchange.AsterPrivateKey = decrypted
|
||||
}
|
||||
|
||||
return &trader, &aiModel, &exchange, nil
|
||||
}
|
||||
|
||||
|
||||
394
crypto/crypto.go
Normal file
394
crypto/crypto.go
Normal file
@@ -0,0 +1,394 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
storagePrefix = "ENC:v1:"
|
||||
storageDelimiter = ":"
|
||||
dataKeyEnvName = "DATA_ENCRYPTION_KEY"
|
||||
)
|
||||
|
||||
type EncryptedPayload struct {
|
||||
WrappedKey string `json:"wrappedKey"`
|
||||
IV string `json:"iv"`
|
||||
Ciphertext string `json:"ciphertext"`
|
||||
AAD string `json:"aad,omitempty"`
|
||||
KID string `json:"kid,omitempty"`
|
||||
TS int64 `json:"ts,omitempty"`
|
||||
}
|
||||
|
||||
type AADData struct {
|
||||
UserID string `json:"userId"`
|
||||
SessionID string `json:"sessionId"`
|
||||
TS int64 `json:"ts"`
|
||||
Purpose string `json:"purpose"`
|
||||
}
|
||||
|
||||
type CryptoService struct {
|
||||
privateKey *rsa.PrivateKey
|
||||
publicKey *rsa.PublicKey
|
||||
dataKey []byte
|
||||
}
|
||||
|
||||
func NewCryptoService(privateKeyPath string) (*CryptoService, error) {
|
||||
// 读取私钥文件
|
||||
privateKeyPEM, err := ioutil.ReadFile(privateKeyPath)
|
||||
if err != nil {
|
||||
// 如果私钥文件不存在,生成新的密钥对
|
||||
if err := GenerateRSAKeyPair(privateKeyPath); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate RSA key pair: %w", err)
|
||||
}
|
||||
privateKeyPEM, err = ioutil.ReadFile(privateKeyPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read generated private key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 解析私钥
|
||||
privateKey, err := ParseRSAPrivateKeyFromPEM(privateKeyPEM)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
||||
}
|
||||
|
||||
dataKey, err := loadDataKeyFromEnv()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load data encryption key: %w", err)
|
||||
}
|
||||
|
||||
return &CryptoService{
|
||||
privateKey: privateKey,
|
||||
publicKey: &privateKey.PublicKey,
|
||||
dataKey: dataKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func GenerateRSAKeyPair(privateKeyPath string) error {
|
||||
// 确保目录存在
|
||||
dir := filepath.Dir(privateKeyPath)
|
||||
if dir != "." {
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory %s: %w", dir, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 生成 RSA 密钥对
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 编码私钥
|
||||
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
|
||||
})
|
||||
|
||||
// 保存私钥
|
||||
if err := ioutil.WriteFile(privateKeyPath, privateKeyPEM, 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 编码公钥
|
||||
publicKeyDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PUBLIC KEY",
|
||||
Bytes: publicKeyDER,
|
||||
})
|
||||
|
||||
// 保存公钥
|
||||
publicKeyPath := privateKeyPath + ".pub"
|
||||
if err := ioutil.WriteFile(publicKeyPath, publicKeyPEM, 0644); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ParseRSAPrivateKeyFromPEM(pemBytes []byte) (*rsa.PrivateKey, error) {
|
||||
block, _ := pem.Decode(pemBytes)
|
||||
if block == nil {
|
||||
return nil, errors.New("no PEM block found")
|
||||
}
|
||||
|
||||
switch block.Type {
|
||||
case "RSA PRIVATE KEY":
|
||||
return x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
case "PRIVATE KEY":
|
||||
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rsaKey, ok := key.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, errors.New("not an RSA key")
|
||||
}
|
||||
return rsaKey, nil
|
||||
default:
|
||||
return nil, errors.New("unsupported key type: " + block.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func loadDataKeyFromEnv() ([]byte, error) {
|
||||
keyStr := strings.TrimSpace(os.Getenv(dataKeyEnvName))
|
||||
if keyStr == "" {
|
||||
return nil, fmt.Errorf("%s not set", dataKeyEnvName)
|
||||
}
|
||||
|
||||
if key, ok := decodePossibleKey(keyStr); ok {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
sum := sha256.Sum256([]byte(keyStr))
|
||||
key := make([]byte, len(sum))
|
||||
copy(key, sum[:])
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func decodePossibleKey(value string) ([]byte, bool) {
|
||||
decoders := []func(string) ([]byte, error){
|
||||
base64.StdEncoding.DecodeString,
|
||||
base64.RawStdEncoding.DecodeString,
|
||||
func(s string) ([]byte, error) { return hex.DecodeString(s) },
|
||||
}
|
||||
|
||||
for _, decoder := range decoders {
|
||||
if decoded, err := decoder(value); err == nil {
|
||||
if key, ok := normalizeAESKey(decoded); ok {
|
||||
return key, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func normalizeAESKey(raw []byte) ([]byte, bool) {
|
||||
switch len(raw) {
|
||||
case 16, 24, 32:
|
||||
return raw, true
|
||||
case 0:
|
||||
return nil, false
|
||||
default:
|
||||
sum := sha256.Sum256(raw)
|
||||
key := make([]byte, len(sum))
|
||||
copy(key, sum[:])
|
||||
return key, true
|
||||
}
|
||||
}
|
||||
|
||||
func (cs *CryptoService) HasDataKey() bool {
|
||||
return len(cs.dataKey) > 0
|
||||
}
|
||||
|
||||
func (cs *CryptoService) GetPublicKeyPEM() string {
|
||||
publicKeyDER, err := x509.MarshalPKIXPublicKey(cs.publicKey)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PUBLIC KEY",
|
||||
Bytes: publicKeyDER,
|
||||
})
|
||||
|
||||
return string(publicKeyPEM)
|
||||
}
|
||||
|
||||
func (cs *CryptoService) EncryptForStorage(plaintext string, aadParts ...string) (string, error) {
|
||||
if plaintext == "" {
|
||||
return "", nil
|
||||
}
|
||||
if !cs.HasDataKey() {
|
||||
return "", errors.New("data encryption key not configured")
|
||||
}
|
||||
if isEncryptedStorageValue(plaintext) {
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(cs.dataKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
aad := composeAAD(aadParts)
|
||||
ciphertext := gcm.Seal(nil, nonce, []byte(plaintext), aad)
|
||||
|
||||
return storagePrefix +
|
||||
base64.StdEncoding.EncodeToString(nonce) + storageDelimiter +
|
||||
base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
func (cs *CryptoService) DecryptFromStorage(value string, aadParts ...string) (string, error) {
|
||||
if value == "" {
|
||||
return "", nil
|
||||
}
|
||||
if !cs.HasDataKey() {
|
||||
return "", errors.New("data encryption key not configured")
|
||||
}
|
||||
if !isEncryptedStorageValue(value) {
|
||||
return "", errors.New("value is not encrypted")
|
||||
}
|
||||
|
||||
payload := strings.TrimPrefix(value, storagePrefix)
|
||||
parts := strings.SplitN(payload, storageDelimiter, 2)
|
||||
if len(parts) != 2 {
|
||||
return "", errors.New("invalid encrypted payload format")
|
||||
}
|
||||
|
||||
nonce, err := base64.StdEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode nonce failed: %w", err)
|
||||
}
|
||||
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode ciphertext failed: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(cs.dataKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(nonce) != gcm.NonceSize() {
|
||||
return "", fmt.Errorf("invalid nonce size: expected %d, got %d", gcm.NonceSize(), len(nonce))
|
||||
}
|
||||
|
||||
aad := composeAAD(aadParts)
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, aad)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decryption failed: %w", err)
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
|
||||
func (cs *CryptoService) IsEncryptedStorageValue(value string) bool {
|
||||
return isEncryptedStorageValue(value)
|
||||
}
|
||||
|
||||
func composeAAD(parts []string) []byte {
|
||||
if len(parts) == 0 {
|
||||
return nil
|
||||
}
|
||||
return []byte(strings.Join(parts, "|"))
|
||||
}
|
||||
|
||||
func isEncryptedStorageValue(value string) bool {
|
||||
return strings.HasPrefix(value, storagePrefix)
|
||||
}
|
||||
|
||||
func (cs *CryptoService) DecryptPayload(payload *EncryptedPayload) ([]byte, error) {
|
||||
// 1. 验证时间戳(防止重放攻击)
|
||||
if payload.TS != 0 {
|
||||
elapsed := time.Since(time.Unix(payload.TS, 0))
|
||||
if elapsed > 5*time.Minute || elapsed < -1*time.Minute {
|
||||
return nil, errors.New("timestamp invalid or expired")
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 解码 base64url
|
||||
wrappedKey, err := base64.RawURLEncoding.DecodeString(payload.WrappedKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode wrapped key: %w", err)
|
||||
}
|
||||
|
||||
iv, err := base64.RawURLEncoding.DecodeString(payload.IV)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode IV: %w", err)
|
||||
}
|
||||
|
||||
ciphertext, err := base64.RawURLEncoding.DecodeString(payload.Ciphertext)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode ciphertext: %w", err)
|
||||
}
|
||||
|
||||
var aad []byte
|
||||
if payload.AAD != "" {
|
||||
aad, err = base64.RawURLEncoding.DecodeString(payload.AAD)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode AAD: %w", err)
|
||||
}
|
||||
|
||||
// 验证 AAD
|
||||
var aadData AADData
|
||||
if err := json.Unmarshal(aad, &aadData); err == nil {
|
||||
// 可以在这里添加额外的验证逻辑
|
||||
// 例如:验证 sessionID、userID 等
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 使用 RSA-OAEP 解密 AES 密钥
|
||||
aesKey, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, cs.privateKey, wrappedKey, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unwrap AES key: %w", err)
|
||||
}
|
||||
|
||||
// 4. 使用 AES-GCM 解密数据
|
||||
block, err := aes.NewCipher(aesKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create AES cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create GCM: %w", err)
|
||||
}
|
||||
|
||||
if len(iv) != gcm.NonceSize() {
|
||||
return nil, fmt.Errorf("invalid IV size: expected %d, got %d", gcm.NonceSize(), len(iv))
|
||||
}
|
||||
|
||||
// 解密并验证认证标签
|
||||
plaintext, err := gcm.Open(nil, iv, ciphertext, aad)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authentication/decryption failed: %w", err)
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
func (cs *CryptoService) DecryptSensitiveData(payload *EncryptedPayload) (string, error) {
|
||||
plaintext, err := cs.DecryptPayload(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(plaintext), nil
|
||||
}
|
||||
@@ -57,6 +57,7 @@ services:
|
||||
environment:
|
||||
- TZ=${NOFX_TIMEZONE:-Asia/Shanghai} # Set timezone
|
||||
- AI_MAX_TOKENS=4000 # AI响应的最大token数(默认2000,建议4000-8000)
|
||||
- DATA_ENCRYPTION_KEY=${DATA_ENCRYPTION_KEY} # 数据加密密钥
|
||||
- POSTGRES_HOST=postgres
|
||||
- POSTGRES_PORT=5432
|
||||
- POSTGRES_DB=${POSTGRES_DB:-nofx}
|
||||
|
||||
10
main.go
10
main.go
@@ -7,6 +7,7 @@ import (
|
||||
"nofx/api"
|
||||
"nofx/auth"
|
||||
"nofx/config"
|
||||
"nofx/crypto"
|
||||
"nofx/manager"
|
||||
"nofx/market"
|
||||
"nofx/pool"
|
||||
@@ -171,6 +172,13 @@ func main() {
|
||||
}
|
||||
defer database.Close()
|
||||
|
||||
// 初始化加密服务(用于敏感数据加密存储与传输)
|
||||
cryptoService, err := crypto.NewCryptoService("keys/rsa_private.key")
|
||||
if err != nil {
|
||||
log.Fatalf("❌ 初始化加密服务失败: %v", err)
|
||||
}
|
||||
database.SetCryptoService(cryptoService)
|
||||
|
||||
// 同步config.json到数据库
|
||||
if err := syncConfigToDatabase(database, configFile); err != nil {
|
||||
log.Printf("⚠️ 同步config.json到数据库失败: %v", err)
|
||||
@@ -289,7 +297,7 @@ func main() {
|
||||
}
|
||||
|
||||
// 创建并启动API服务器
|
||||
apiServer := api.NewServer(traderManager, database, apiPort)
|
||||
apiServer := api.NewServer(traderManager, database, cryptoService, apiPort)
|
||||
go func() {
|
||||
if err := apiServer.Start(); err != nil {
|
||||
log.Printf("❌ API服务器错误: %v", err)
|
||||
|
||||
685
proxy/README.md
Normal file
685
proxy/README.md
Normal file
@@ -0,0 +1,685 @@
|
||||
# HTTP 代理模块
|
||||
|
||||
## 概述
|
||||
|
||||
这是一个高度解耦的HTTP代理管理模块,专为解决高频API请求被限流/封禁问题而设计。支持单代理、代理池和动态IP获取三种模式,提供线程安全的IP轮换和智能黑名单管理机制。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- ✅ **三种工作模式**:单代理、固定代理池、Bright Data API动态获取
|
||||
- ✅ **线程安全**:所有操作使用读写锁保护,支持并发访问
|
||||
- ✅ **智能黑名单**:失败的代理IP手动加入黑名单,TTL机制自动恢复
|
||||
- ✅ **自动刷新**:支持定时刷新代理IP列表(默认30分钟)
|
||||
- ✅ **随机轮换**:从可用IP池中随机选择,避免单点压力
|
||||
- ✅ **防越界保护**:多层数组边界检查,确保运行时安全
|
||||
- ✅ **可选启用**:未配置或禁用时自动使用直连,不影响独立客户
|
||||
|
||||
## 架构设计
|
||||
|
||||
```
|
||||
proxy/
|
||||
├── README.md # 本文档
|
||||
├── types.go # 核心数据结构定义
|
||||
├── provider.go # IP提供者接口定义
|
||||
├── single_provider.go # 单代理实现
|
||||
├── fixed_provider.go # 固定代理池实现
|
||||
├── brightdata_provider.go # Bright Data API实现
|
||||
└── proxy_manager.go # 代理管理器(核心逻辑)
|
||||
```
|
||||
|
||||
### 设计原则
|
||||
|
||||
1. **接口抽象**:通过 `IPProvider` 接口实现不同代理源的统一管理
|
||||
2. **策略模式**:三种Provider实现可灵活切换
|
||||
3. **单例模式**:全局ProxyManager确保资源统一管理
|
||||
4. **防御性编程**:多层边界检查,优雅处理异常情况
|
||||
|
||||
## 配置说明
|
||||
|
||||
在 `config.json` 中添加 `proxy` 配置段:
|
||||
|
||||
```json
|
||||
{
|
||||
"proxy": {
|
||||
"enabled": true,
|
||||
"mode": "single",
|
||||
"timeout": 30,
|
||||
"proxy_url": "http://127.0.0.1:7890",
|
||||
"proxy_list": [],
|
||||
"brightdata_endpoint": "",
|
||||
"brightdata_token": "",
|
||||
"brightdata_zone": "",
|
||||
"proxy_host": "",
|
||||
"proxy_user": "",
|
||||
"proxy_password": "",
|
||||
"refresh_interval": 1800,
|
||||
"blacklist_ttl": 5
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 配置字段详解
|
||||
|
||||
| 字段 | 类型 | 必填 | 说明 |
|
||||
|------|------|------|------|
|
||||
| `enabled` | bool | 是 | 是否启用代理(false时使用直连) |
|
||||
| `mode` | string | 是 | 代理模式:`single`/`pool`/`brightdata` |
|
||||
| `timeout` | int | 否 | HTTP请求超时时间(秒),默认30 |
|
||||
| `proxy_url` | string | single模式必填 | 单个代理地址,如 `http://127.0.0.1:7890` |
|
||||
| `proxy_list` | []string | pool模式必填 | 代理列表,支持 `http://`、`https://`、`socks5://` |
|
||||
| `brightdata_endpoint` | string | brightdata模式必填 | Bright Data API端点 |
|
||||
| `brightdata_token` | string | brightdata模式可选 | Bright Data访问令牌 |
|
||||
| `brightdata_zone` | string | brightdata模式可选 | Bright Data区域参数 |
|
||||
| `proxy_host` | string | 否 | 代理主机(用于认证代理) |
|
||||
| `proxy_user` | string | 否 | 代理用户名模板,支持 `%s` 占位符替换IP |
|
||||
| `proxy_password` | string | 否 | 代理密码 |
|
||||
| `refresh_interval` | int | 否 | IP列表刷新间隔(秒),brightdata模式默认1800(30分钟) |
|
||||
| `blacklist_ttl` | int | 否 | 黑名单IP的TTL(刷新次数),默认5 |
|
||||
|
||||
## 使用方法
|
||||
|
||||
### 1. 初始化代理管理器
|
||||
|
||||
在 `main.go` 或初始化代码中:
|
||||
|
||||
```go
|
||||
import (
|
||||
"nofx/proxy"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 方式1:使用配置结构体初始化
|
||||
proxyConfig := &proxy.Config{
|
||||
Enabled: true,
|
||||
Mode: "single",
|
||||
Timeout: 30 * time.Second,
|
||||
ProxyURL: "http://127.0.0.1:7890",
|
||||
BlacklistTTL: 5,
|
||||
}
|
||||
|
||||
err := proxy.InitGlobalProxyManager(proxyConfig)
|
||||
if err != nil {
|
||||
log.Fatalf("初始化代理管理器失败: %v", err)
|
||||
}
|
||||
```
|
||||
|
||||
### 2. 获取代理HTTP客户端
|
||||
|
||||
在需要发送HTTP请求的地方:
|
||||
|
||||
```go
|
||||
// 获取代理客户端(包含ProxyID用于黑名单管理)
|
||||
proxyClient, err := proxy.GetProxyHTTPClient()
|
||||
if err != nil {
|
||||
log.Printf("获取代理客户端失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 使用代理客户端发送请求
|
||||
resp, err := proxyClient.Client.Get("https://api.example.com/data")
|
||||
if err != nil {
|
||||
// 请求失败,将此代理加入黑名单
|
||||
proxy.AddBlacklist(proxyClient.ProxyID)
|
||||
log.Printf("请求失败,代理IP %s 已加入黑名单", proxyClient.IP)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 处理响应...
|
||||
```
|
||||
|
||||
### 3. 黑名单管理
|
||||
|
||||
```go
|
||||
// 添加失败的代理到黑名单
|
||||
proxy.AddBlacklist(proxyClient.ProxyID)
|
||||
|
||||
// 获取黑名单状态
|
||||
total, blacklisted, available := proxy.GetGlobalProxyManager().GetBlacklistStatus()
|
||||
log.Printf("代理状态: 总计%d个,黑名单%d个,可用%d个", total, blacklisted, available)
|
||||
```
|
||||
|
||||
### 4. 手动刷新IP列表
|
||||
|
||||
```go
|
||||
err := proxy.RefreshIPList()
|
||||
if err != nil {
|
||||
log.Printf("刷新IP列表失败: %v", err)
|
||||
}
|
||||
```
|
||||
|
||||
### 5. 检查代理是否启用
|
||||
|
||||
```go
|
||||
if proxy.IsEnabled() {
|
||||
log.Println("代理已启用")
|
||||
} else {
|
||||
log.Println("代理未启用,使用直连")
|
||||
}
|
||||
```
|
||||
|
||||
## 三种模式详解
|
||||
|
||||
### Mode 1: Single(单代理模式)
|
||||
|
||||
适用场景:本地代理工具(如Clash、V2Ray)或单个固定代理服务器
|
||||
|
||||
```json
|
||||
{
|
||||
"proxy": {
|
||||
"enabled": true,
|
||||
"mode": "single",
|
||||
"proxy_url": "http://127.0.0.1:7890"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
特点:
|
||||
- 简单直接,适合本地开发和测试
|
||||
- 所有请求通过同一个代理
|
||||
- 不需要刷新和轮换
|
||||
|
||||
### Mode 2: Pool(代理池模式)
|
||||
|
||||
适用场景:拥有多个固定代理服务器,需要轮换使用
|
||||
|
||||
```json
|
||||
{
|
||||
"proxy": {
|
||||
"enabled": true,
|
||||
"mode": "pool",
|
||||
"proxy_list": [
|
||||
"http://proxy1.example.com:8080",
|
||||
"http://user:pass@proxy2.example.com:8080",
|
||||
"socks5://proxy3.example.com:1080"
|
||||
],
|
||||
"blacklist_ttl": 5
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
特点:
|
||||
- 支持多协议:HTTP、HTTPS、SOCKS5
|
||||
- 随机选择代理,分散请求压力
|
||||
- 失败的代理自动加入黑名单
|
||||
- 黑名单IP经过TTL次刷新后自动恢复
|
||||
|
||||
### Mode 3: BrightData(动态IP模式)
|
||||
|
||||
适用场景:使用Bright Data等提供API的动态代理服务
|
||||
|
||||
```json
|
||||
{
|
||||
"proxy": {
|
||||
"enabled": true,
|
||||
"mode": "brightdata",
|
||||
"brightdata_endpoint": "https://api.brightdata.com/zones/get_ips",
|
||||
"brightdata_token": "your_api_token",
|
||||
"brightdata_zone": "residential",
|
||||
"proxy_host": "brd.superproxy.io:22225",
|
||||
"proxy_user": "brd-customer-xxx-zone-residential-ip-%s",
|
||||
"proxy_password": "your_password",
|
||||
"refresh_interval": 1800,
|
||||
"blacklist_ttl": 5
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
特点:
|
||||
- 从API动态获取可用IP列表
|
||||
- 自动定时刷新(默认30分钟)
|
||||
- 支持用户名模板(`%s` 替换为IP地址)
|
||||
- 黑名单TTL机制避免频繁切换
|
||||
|
||||
**用户名模板说明**:
|
||||
```
|
||||
proxy_user: "brd-customer-xxx-zone-residential-ip-%s"
|
||||
↑
|
||||
自动替换为IP地址
|
||||
```
|
||||
|
||||
## 核心API
|
||||
|
||||
### 全局函数
|
||||
|
||||
```go
|
||||
// 初始化全局代理管理器(只执行一次)
|
||||
func InitGlobalProxyManager(config *Config) error
|
||||
|
||||
// 获取全局代理管理器实例
|
||||
func GetGlobalProxyManager() *ProxyManager
|
||||
|
||||
// 获取代理HTTP客户端(包含ProxyID和IP信息)
|
||||
func GetProxyHTTPClient() (*ProxyClient, error)
|
||||
|
||||
// 将代理IP添加到黑名单
|
||||
func AddBlacklist(proxyID int)
|
||||
|
||||
// 刷新IP列表
|
||||
func RefreshIPList() error
|
||||
|
||||
// 检查代理是否启用
|
||||
func IsEnabled() bool
|
||||
```
|
||||
|
||||
### ProxyManager 方法
|
||||
|
||||
```go
|
||||
// 获取代理客户端
|
||||
func (m *ProxyManager) GetProxyClient() (*ProxyClient, error)
|
||||
|
||||
// 刷新IP列表
|
||||
func (m *ProxyManager) RefreshIPList() error
|
||||
|
||||
// 添加到黑名单
|
||||
func (m *ProxyManager) AddBlacklist(proxyID int)
|
||||
|
||||
// 获取黑名单状态
|
||||
func (m *ProxyManager) GetBlacklistStatus() (total, blacklisted, available int)
|
||||
|
||||
// 启动自动刷新
|
||||
func (m *ProxyManager) StartAutoRefresh()
|
||||
|
||||
// 停止自动刷新
|
||||
func (m *ProxyManager) StopAutoRefresh()
|
||||
```
|
||||
|
||||
## 黑名单机制
|
||||
|
||||
### 工作原理
|
||||
|
||||
1. **添加黑名单**:当代理请求失败时,调用 `AddBlacklist(proxyID)` 将该IP加入黑名单
|
||||
2. **TTL倒计时**:每次刷新IP列表时,黑名单中的IP的TTL减1
|
||||
3. **自动恢复**:当TTL归零时,IP自动从黑名单移除,重新可用
|
||||
|
||||
### 线程安全保证
|
||||
|
||||
```go
|
||||
// 添加黑名单使用写锁
|
||||
func (m *ProxyManager) AddBlacklist(proxyID int) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
// 防越界检查
|
||||
if proxyID < 0 || proxyID >= len(m.ipList) {
|
||||
log.Printf("⚠️ 无效的 ProxyID: %d", proxyID)
|
||||
return
|
||||
}
|
||||
|
||||
ip := m.ipList[proxyID].IP
|
||||
m.blacklist[proxyID] = ip
|
||||
m.ipBlacklist[ip] = m.config.BlacklistTTL
|
||||
}
|
||||
|
||||
// 获取代理使用读锁(支持并发)
|
||||
func (m *ProxyManager) getRandomProxy() (int, *ProxyIP, error) {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
// ... 读取操作
|
||||
}
|
||||
```
|
||||
|
||||
### 示例流程
|
||||
|
||||
```
|
||||
初始状态:5个代理IP,TTL=3
|
||||
IP列表: [IP1, IP2, IP3, IP4, IP5]
|
||||
黑名单: {}
|
||||
|
||||
第1次失败:IP2请求失败
|
||||
IP列表: [IP1, IP2, IP3, IP4, IP5]
|
||||
黑名单: {IP2: TTL=3}
|
||||
|
||||
第1次刷新:TTL-1
|
||||
黑名单: {IP2: TTL=2}
|
||||
|
||||
第2次刷新:TTL-1
|
||||
黑名单: {IP2: TTL=1}
|
||||
|
||||
第3次刷新:TTL-1
|
||||
黑名单: {IP2: TTL=0} → 从黑名单移除
|
||||
|
||||
第3次刷新后:
|
||||
IP列表: [IP1, IP2, IP3, IP4, IP5]
|
||||
黑名单: {} ← IP2已恢复可用
|
||||
```
|
||||
|
||||
## 完整使用示例
|
||||
|
||||
### 示例1:币安API请求(单代理模式)
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"nofx/proxy"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 初始化代理
|
||||
err := proxy.InitGlobalProxyManager(&proxy.Config{
|
||||
Enabled: true,
|
||||
Mode: "single",
|
||||
ProxyURL: "http://127.0.0.1:7890",
|
||||
Timeout: 30 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("初始化代理失败: %v", err)
|
||||
}
|
||||
|
||||
// 获取币安数据
|
||||
proxyClient, err := proxy.GetProxyHTTPClient()
|
||||
if err != nil {
|
||||
log.Fatalf("获取代理客户端失败: %v", err)
|
||||
}
|
||||
|
||||
resp, err := proxyClient.Client.Get("https://fapi.binance.com/fapi/v1/ticker/24hr")
|
||||
if err != nil {
|
||||
log.Printf("请求失败: %v", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
log.Printf("请求成功,使用代理: %s", proxyClient.IP)
|
||||
}
|
||||
```
|
||||
|
||||
### 示例2:OI数据获取(代理池模式 + 黑名单)
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"nofx/proxy"
|
||||
"time"
|
||||
)
|
||||
|
||||
func fetchOIData(symbol string) error {
|
||||
proxyClient, err := proxy.GetProxyHTTPClient()
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取代理失败: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("https://fapi.binance.com/futures/data/openInterestHist?symbol=%s&period=5m&limit=1", symbol)
|
||||
resp, err := proxyClient.Client.Get(url)
|
||||
if err != nil {
|
||||
// 请求失败,加入黑名单
|
||||
proxy.AddBlacklist(proxyClient.ProxyID)
|
||||
return fmt.Errorf("请求失败 (代理: %s): %w", proxyClient.IP, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
// 状态码异常,加入黑名单
|
||||
proxy.AddBlacklist(proxyClient.ProxyID)
|
||||
return fmt.Errorf("状态码异常: %d (代理: %s)", resp.StatusCode, proxyClient.IP)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
log.Printf("✓ 获取 %s OI数据成功 (代理: %s): %s", symbol, proxyClient.IP, string(body))
|
||||
return nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
// 初始化代理池
|
||||
err := proxy.InitGlobalProxyManager(&proxy.Config{
|
||||
Enabled: true,
|
||||
Mode: "pool",
|
||||
ProxyList: []string{
|
||||
"http://proxy1.example.com:8080",
|
||||
"http://proxy2.example.com:8080",
|
||||
"http://proxy3.example.com:8080",
|
||||
},
|
||||
Timeout: 30 * time.Second,
|
||||
BlacklistTTL: 5,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("初始化代理失败: %v", err)
|
||||
}
|
||||
|
||||
// 循环获取数据
|
||||
symbols := []string{"BTCUSDT", "ETHUSDT", "SOLUSDT"}
|
||||
for {
|
||||
for _, symbol := range symbols {
|
||||
if err := fetchOIData(symbol); err != nil {
|
||||
log.Printf("⚠️ %v", err)
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
time.Sleep(10 * time.Second)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 示例3:Bright Data动态IP
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"nofx/proxy"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 初始化Bright Data代理
|
||||
err := proxy.InitGlobalProxyManager(&proxy.Config{
|
||||
Enabled: true,
|
||||
Mode: "brightdata",
|
||||
BrightDataEndpoint: "https://api.brightdata.com/zones/get_ips",
|
||||
BrightDataToken: "your_token",
|
||||
BrightDataZone: "residential",
|
||||
ProxyHost: "brd.superproxy.io:22225",
|
||||
ProxyUser: "brd-customer-xxx-zone-residential-ip-%s",
|
||||
ProxyPassword: "your_password",
|
||||
RefreshInterval: 30 * time.Minute,
|
||||
Timeout: 30 * time.Second,
|
||||
BlacklistTTL: 5,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("初始化代理失败: %v", err)
|
||||
}
|
||||
|
||||
// 代理会自动每30分钟刷新IP列表
|
||||
log.Println("✓ Bright Data代理已启动,自动刷新已开启")
|
||||
|
||||
// 获取并使用代理
|
||||
for i := 0; i < 10; i++ {
|
||||
proxyClient, err := proxy.GetProxyHTTPClient()
|
||||
if err != nil {
|
||||
log.Printf("获取代理失败: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
resp, err := proxyClient.Client.Get("https://api.ipify.org?format=json")
|
||||
if err != nil {
|
||||
proxy.AddBlacklist(proxyClient.ProxyID)
|
||||
log.Printf("请求失败,代理已加入黑名单: %s", proxyClient.IP)
|
||||
continue
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
log.Printf("✓ 请求成功 (代理ID: %d, IP: %s)", proxyClient.ProxyID, proxyClient.IP)
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
### 1. 模块解耦性
|
||||
|
||||
- ✅ 代理模块完全独立,不依赖其他业务模块
|
||||
- ✅ 禁用代理时自动使用直连,对业务代码透明
|
||||
- ✅ 适合多租户/多客户环境,可按需启用
|
||||
|
||||
### 2. 线程安全
|
||||
|
||||
- ✅ 所有公开方法都是线程安全的
|
||||
- ✅ 支持高并发场景下的代理获取和黑名单操作
|
||||
- ✅ 读写锁优化性能:读操作可并发,写操作独占
|
||||
|
||||
### 3. 错误处理
|
||||
|
||||
```go
|
||||
proxyClient, err := proxy.GetProxyHTTPClient()
|
||||
if err != nil {
|
||||
// 可能的错误:
|
||||
// - 代理IP列表为空
|
||||
// - 所有代理都在黑名单中
|
||||
// - 代理URL解析失败
|
||||
log.Printf("获取代理失败: %v", err)
|
||||
|
||||
// 建议:降级为直连或重试
|
||||
return
|
||||
}
|
||||
```
|
||||
|
||||
### 4. 性能优化建议
|
||||
|
||||
- 对于高频请求,复用 `http.Client` 而不是每次创建新的
|
||||
- 合理设置 `refresh_interval` 避免频繁刷新
|
||||
- `blacklist_ttl` 建议设置为 3-10,平衡恢复速度和稳定性
|
||||
|
||||
### 5. 安全建议
|
||||
|
||||
- 生产环境中代理密钥应使用环境变量或密钥管理服务
|
||||
- 避免在日志中打印完整的代理URL(包含密码)
|
||||
- TLS验证默认开启,如需跳过请谨慎评估风险
|
||||
|
||||
### 6. 调试技巧
|
||||
|
||||
```go
|
||||
// 获取当前代理状态
|
||||
total, blacklisted, available := proxy.GetGlobalProxyManager().GetBlacklistStatus()
|
||||
log.Printf("代理池状态: 总计=%d, 黑名单=%d, 可用=%d", total, blacklisted, available)
|
||||
|
||||
// 检查是否启用
|
||||
if !proxy.IsEnabled() {
|
||||
log.Println("代理未启用,请检查配置")
|
||||
}
|
||||
```
|
||||
|
||||
## 故障排查
|
||||
|
||||
### 问题1:获取代理失败 - "代理IP列表为空"
|
||||
|
||||
**原因**:
|
||||
- `single` 模式:未配置 `proxy_url`
|
||||
- `pool` 模式:`proxy_list` 为空
|
||||
- `brightdata` 模式:API返回空列表或请求失败
|
||||
|
||||
**解决方案**:
|
||||
```bash
|
||||
# 检查配置文件
|
||||
cat config.json | grep -A 15 "proxy"
|
||||
|
||||
# 检查日志,查看初始化信息
|
||||
# 应该看到类似:🌐 HTTP 代理已启用 (xxx模式)
|
||||
```
|
||||
|
||||
### 问题2:所有代理都在黑名单中
|
||||
|
||||
**原因**:请求持续失败,所有IP被加入黑名单
|
||||
|
||||
**解决方案**:
|
||||
```go
|
||||
// 方案1:手动刷新IP列表(会触发TTL倒计时)
|
||||
proxy.RefreshIPList()
|
||||
|
||||
// 方案2:降低blacklist_ttl,加快恢复速度
|
||||
// config.json: "blacklist_ttl": 2 (默认5)
|
||||
|
||||
// 方案3:检查代理本身是否可用
|
||||
// 使用curl测试代理:
|
||||
// curl -x http://proxy_url https://api.binance.com/api/v3/ping
|
||||
```
|
||||
|
||||
### 问题3:Bright Data模式无法获取IP
|
||||
|
||||
**原因**:
|
||||
- API端点配置错误
|
||||
- Token无效或过期
|
||||
- Zone参数不正确
|
||||
|
||||
**解决方案**:
|
||||
```bash
|
||||
# 手动测试API
|
||||
curl -H "Authorization: Bearer YOUR_TOKEN" \
|
||||
"https://api.brightdata.com/zones/get_ips?zone=residential"
|
||||
|
||||
# 检查返回格式是否符合:
|
||||
# {"ips": [{"ip": "1.2.3.4", ...}, ...]}
|
||||
```
|
||||
|
||||
### 问题4:代理连接超时
|
||||
|
||||
**原因**:代理服务器响应慢或网络不稳定
|
||||
|
||||
**解决方案**:
|
||||
```json
|
||||
{
|
||||
"proxy": {
|
||||
"timeout": 60 // 增加超时时间(秒)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 扩展开发
|
||||
|
||||
### 添加新的Provider
|
||||
|
||||
实现 `IPProvider` 接口即可:
|
||||
|
||||
```go
|
||||
// custom_provider.go
|
||||
package proxy
|
||||
|
||||
type CustomProvider struct {
|
||||
// 自定义字段
|
||||
}
|
||||
|
||||
func NewCustomProvider(config string) *CustomProvider {
|
||||
return &CustomProvider{}
|
||||
}
|
||||
|
||||
func (p *CustomProvider) GetIPList() ([]ProxyIP, error) {
|
||||
// 实现获取IP列表的逻辑
|
||||
return []ProxyIP{}, nil
|
||||
}
|
||||
|
||||
func (p *CustomProvider) RefreshIPList() ([]ProxyIP, error) {
|
||||
// 实现刷新IP列表的逻辑
|
||||
return p.GetIPList()
|
||||
}
|
||||
```
|
||||
|
||||
然后在 `proxy_manager.go` 的 `NewProxyManager` 中添加新模式:
|
||||
|
||||
```go
|
||||
case "custom":
|
||||
m.provider = NewCustomProvider(config.CustomEndpoint)
|
||||
log.Printf("🌐 HTTP 代理已启用 (自定义模式)")
|
||||
```
|
||||
|
||||
## 更新日志
|
||||
|
||||
### v1.0.0 (当前版本)
|
||||
- ✅ 支持三种代理模式:single、pool、brightdata
|
||||
- ✅ 线程安全的IP轮换和黑名单管理
|
||||
- ✅ 自动刷新机制(30分钟默认)
|
||||
- ✅ TTL黑名单自动恢复
|
||||
- ✅ 防越界保护
|
||||
- ✅ ProxyID追踪机制
|
||||
|
||||
|
||||
## 技术支持
|
||||
|
||||
如有问题或建议,请联系项目维护者 @hzb1115
|
||||
。
|
||||
105
proxy/brightdata_provider.go
Normal file
105
proxy/brightdata_provider.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// BrightDataProvider Bright Data动态获取IP提供者
|
||||
type BrightDataProvider struct {
|
||||
endpoint string
|
||||
token string
|
||||
zone string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewBrightDataProvider 创建Bright Data IP提供者
|
||||
func NewBrightDataProvider(endpoint, token, zone string) *BrightDataProvider {
|
||||
return &BrightDataProvider{
|
||||
endpoint: endpoint,
|
||||
token: token,
|
||||
zone: zone,
|
||||
client: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// BrightDataIPList Bright Data API返回的IP列表结构
|
||||
type BrightDataIPList struct {
|
||||
IPs []struct {
|
||||
IP string `json:"ip"`
|
||||
Maxmind string `json:"maxmind"`
|
||||
Ext map[string]interface{} `json:"ext"`
|
||||
} `json:"ips"`
|
||||
}
|
||||
|
||||
func (p *BrightDataProvider) GetIPList() ([]ProxyIP, error) {
|
||||
return p.fetchIPList()
|
||||
}
|
||||
|
||||
func (p *BrightDataProvider) RefreshIPList() ([]ProxyIP, error) {
|
||||
return p.fetchIPList()
|
||||
}
|
||||
|
||||
func (p *BrightDataProvider) fetchIPList() ([]ProxyIP, error) {
|
||||
// 构建请求URL
|
||||
url := p.endpoint
|
||||
if p.zone != "" {
|
||||
url = fmt.Sprintf("%s?zone=%s", p.endpoint, p.zone)
|
||||
}
|
||||
|
||||
// 创建HTTP请求
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建HTTP请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置授权头
|
||||
if p.token != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", p.token))
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("发送HTTP请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 读取响应体
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取HTTP响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查状态码
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API返回错误状态码 %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// 解析JSON数据(支持Bright Data格式)
|
||||
var ipList BrightDataIPList
|
||||
if err := json.Unmarshal(body, &ipList); err != nil {
|
||||
return nil, fmt.Errorf("解析JSON数据失败: %w", err)
|
||||
}
|
||||
|
||||
// 转换为ProxyIP列表
|
||||
result := make([]ProxyIP, 0, len(ipList.IPs))
|
||||
for _, ip := range ipList.IPs {
|
||||
result = append(result, ProxyIP{
|
||||
IP: ip.IP,
|
||||
Protocol: "http",
|
||||
Ext: ip.Ext,
|
||||
})
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return nil, fmt.Errorf("API返回的IP列表为空")
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
42
proxy/fixed_provider.go
Normal file
42
proxy/fixed_provider.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package proxy
|
||||
|
||||
import "strings"
|
||||
|
||||
// FixedIPProvider 固定IP列表提供者
|
||||
type FixedIPProvider struct {
|
||||
ips []ProxyIP
|
||||
}
|
||||
|
||||
// NewFixedIPProvider 创建固定IP列表提供者
|
||||
func NewFixedIPProvider(proxyURLs []string) *FixedIPProvider {
|
||||
ips := make([]ProxyIP, 0, len(proxyURLs))
|
||||
for _, proxyURL := range proxyURLs {
|
||||
// 简单解析代理URL
|
||||
// 格式: http://ip:port 或 socks5://user:pass@ip:port
|
||||
protocol := "http"
|
||||
if strings.HasPrefix(proxyURL, "socks5://") {
|
||||
protocol = "socks5"
|
||||
proxyURL = strings.TrimPrefix(proxyURL, "socks5://")
|
||||
} else if strings.HasPrefix(proxyURL, "http://") {
|
||||
proxyURL = strings.TrimPrefix(proxyURL, "http://")
|
||||
} else if strings.HasPrefix(proxyURL, "https://") {
|
||||
protocol = "https"
|
||||
proxyURL = strings.TrimPrefix(proxyURL, "https://")
|
||||
}
|
||||
|
||||
ips = append(ips, ProxyIP{
|
||||
IP: proxyURL,
|
||||
Protocol: protocol,
|
||||
})
|
||||
}
|
||||
|
||||
return &FixedIPProvider{ips: ips}
|
||||
}
|
||||
|
||||
func (p *FixedIPProvider) GetIPList() ([]ProxyIP, error) {
|
||||
return p.ips, nil
|
||||
}
|
||||
|
||||
func (p *FixedIPProvider) RefreshIPList() ([]ProxyIP, error) {
|
||||
return p.ips, nil
|
||||
}
|
||||
10
proxy/provider.go
Normal file
10
proxy/provider.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package proxy
|
||||
|
||||
// IPProvider IP提供者接口
|
||||
type IPProvider interface {
|
||||
// GetIPList 获取IP列表
|
||||
GetIPList() ([]ProxyIP, error)
|
||||
|
||||
// RefreshIPList 刷新IP列表(可选实现)
|
||||
RefreshIPList() ([]ProxyIP, error)
|
||||
}
|
||||
47
proxy/proxy_client.go
Normal file
47
proxy/proxy_client.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// --- 便捷函数(直接使用全局管理器) ---
|
||||
|
||||
// GetProxyHTTPClient 获取代理 HTTP 客户端(返回 ProxyClient,包含 ProxyID)
|
||||
func GetProxyHTTPClient() (*ProxyClient, error) {
|
||||
return GetGlobalProxyManager().GetProxyClient()
|
||||
}
|
||||
|
||||
// NewHTTPClient 创建一个新的HTTP客户端(使用全局代理配置)
|
||||
// 注意:不返回 ProxyID,如需 ProxyID 请使用 GetProxyHTTPClient()
|
||||
func NewHTTPClient() *http.Client {
|
||||
client, err := GetGlobalProxyManager().GetProxyClient()
|
||||
if err != nil {
|
||||
log.Printf("⚠️ 获取代理客户端失败,使用直连: %v", err)
|
||||
return &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
return client.Client
|
||||
}
|
||||
|
||||
// NewHTTPClientWithTimeout 创建一个新的HTTP客户端并指定超时时间
|
||||
// 注意:不返回 ProxyID,如需 ProxyID 请使用 GetProxyHTTPClient()
|
||||
func NewHTTPClientWithTimeout(timeout time.Duration) *http.Client {
|
||||
client, err := GetGlobalProxyManager().GetProxyClient()
|
||||
if err != nil {
|
||||
log.Printf("⚠️ 获取代理客户端失败,使用直连: %v", err)
|
||||
return &http.Client{Timeout: timeout}
|
||||
}
|
||||
client.Client.Timeout = timeout
|
||||
return client.Client
|
||||
}
|
||||
|
||||
// GetTransport 获取HTTP Transport
|
||||
func GetTransport() *http.Transport {
|
||||
client, err := GetGlobalProxyManager().GetProxyClient()
|
||||
if err != nil {
|
||||
log.Printf("⚠️ 获取代理客户端失败,使用直连: %v", err)
|
||||
return &http.Transport{}
|
||||
}
|
||||
return client.Client.Transport.(*http.Transport)
|
||||
}
|
||||
346
proxy/proxy_manager.go
Normal file
346
proxy/proxy_manager.go
Normal file
@@ -0,0 +1,346 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ProxyManager 代理管理器
|
||||
type ProxyManager struct {
|
||||
config *Config
|
||||
provider IPProvider
|
||||
|
||||
// IP池管理
|
||||
ipList []ProxyIP
|
||||
blacklist map[int]string // ProxyID -> IP
|
||||
ipBlacklist map[string]int // IP -> 剩余TTL
|
||||
mutex sync.RWMutex // 读写锁,保证线程安全
|
||||
|
||||
// 刷新控制
|
||||
stopRefresh chan struct{}
|
||||
}
|
||||
|
||||
var (
|
||||
globalProxyManager *ProxyManager
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
// InitGlobalProxyManager 初始化全局代理管理器
|
||||
func InitGlobalProxyManager(config *Config) error {
|
||||
var err error
|
||||
once.Do(func() {
|
||||
globalProxyManager, err = NewProxyManager(config)
|
||||
if err == nil && config.Enabled && config.RefreshInterval > 0 {
|
||||
globalProxyManager.StartAutoRefresh()
|
||||
}
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// GetGlobalProxyManager 获取全局代理管理器
|
||||
func GetGlobalProxyManager() *ProxyManager {
|
||||
if globalProxyManager == nil {
|
||||
// 如果未初始化,使用默认配置(禁用代理)
|
||||
_ = InitGlobalProxyManager(&Config{Enabled: false})
|
||||
}
|
||||
return globalProxyManager
|
||||
}
|
||||
|
||||
// NewProxyManager 创建代理管理器
|
||||
func NewProxyManager(config *Config) (*ProxyManager, error) {
|
||||
if config == nil {
|
||||
config = &Config{Enabled: false}
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
if config.Timeout == 0 {
|
||||
config.Timeout = 30 * time.Second
|
||||
}
|
||||
if config.BlacklistTTL == 0 {
|
||||
config.BlacklistTTL = 5 // 默认 TTL 为 5 次刷新
|
||||
}
|
||||
if config.RefreshInterval == 0 && config.Mode == "brightdata" {
|
||||
config.RefreshInterval = 30 * time.Minute // 默认 30 分钟刷新一次
|
||||
}
|
||||
|
||||
m := &ProxyManager{
|
||||
config: config,
|
||||
blacklist: make(map[int]string),
|
||||
ipBlacklist: make(map[string]int),
|
||||
stopRefresh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// 如果未启用代理,直接返回
|
||||
if !config.Enabled {
|
||||
log.Printf("🌐 HTTP 代理未启用,使用直连")
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// 根据模式选择IP提供者
|
||||
switch config.Mode {
|
||||
case "single":
|
||||
// 单个代理模式
|
||||
if config.ProxyURL == "" {
|
||||
return nil, fmt.Errorf("single模式下必须配置proxy_url")
|
||||
}
|
||||
m.provider = NewSingleProxyProvider(config.ProxyURL)
|
||||
log.Printf("🌐 HTTP 代理已启用 (单代理模式): %s", config.ProxyURL)
|
||||
|
||||
case "pool":
|
||||
// 代理池模式(固定列表)
|
||||
if len(config.ProxyList) == 0 {
|
||||
return nil, fmt.Errorf("pool模式下必须配置proxy_list")
|
||||
}
|
||||
m.provider = NewFixedIPProvider(config.ProxyList)
|
||||
log.Printf("🌐 HTTP 代理已启用 (代理池模式): %d个代理", len(config.ProxyList))
|
||||
|
||||
case "brightdata":
|
||||
// Bright Data动态获取模式
|
||||
if config.BrightDataEndpoint == "" {
|
||||
return nil, fmt.Errorf("brightdata模式下必须配置brightdata_endpoint")
|
||||
}
|
||||
m.provider = NewBrightDataProvider(config.BrightDataEndpoint, config.BrightDataToken, config.BrightDataZone)
|
||||
log.Printf("🌐 HTTP 代理已启用 (Bright Data模式): %s", config.BrightDataEndpoint)
|
||||
|
||||
default:
|
||||
// 默认使用single模式
|
||||
if config.ProxyURL == "" {
|
||||
return nil, fmt.Errorf("未知的proxy模式: %s", config.Mode)
|
||||
}
|
||||
m.provider = NewSingleProxyProvider(config.ProxyURL)
|
||||
log.Printf("🌐 HTTP 代理已启用 (默认模式): %s", config.ProxyURL)
|
||||
}
|
||||
|
||||
// 初始化IP列表
|
||||
if err := m.RefreshIPList(); err != nil {
|
||||
return nil, fmt.Errorf("初始化IP列表失败: %w", err)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// RefreshIPList 刷新IP列表(线程安全)
|
||||
func (m *ProxyManager) RefreshIPList() error {
|
||||
if m.provider == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ips, err := m.provider.RefreshIPList()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
// 清理黑名单,TTL倒计时
|
||||
validIPs := make([]ProxyIP, 0, len(ips))
|
||||
newBlacklist := make(map[int]string)
|
||||
|
||||
for _, ip := range ips {
|
||||
if ttl, inBlacklist := m.ipBlacklist[ip.IP]; inBlacklist {
|
||||
// TTL 倒计时
|
||||
m.ipBlacklist[ip.IP] = ttl - 1
|
||||
if ttl > 0 {
|
||||
// 仍在黑名单中,跳过
|
||||
continue
|
||||
}
|
||||
// TTL 归零,从黑名单移除
|
||||
delete(m.ipBlacklist, ip.IP)
|
||||
log.Printf("✓ 代理IP已从黑名单恢复: %s", ip.IP)
|
||||
}
|
||||
validIPs = append(validIPs, ip)
|
||||
}
|
||||
|
||||
m.ipList = validIPs
|
||||
m.blacklist = newBlacklist
|
||||
|
||||
log.Printf("✓ 刷新代理IP列表: 总计%d个,黑名单%d个,可用%d个",
|
||||
len(ips), len(m.ipBlacklist), len(validIPs))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartAutoRefresh 启动自动刷新
|
||||
func (m *ProxyManager) StartAutoRefresh() {
|
||||
if m.config.RefreshInterval <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(m.config.RefreshInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := m.RefreshIPList(); err != nil {
|
||||
log.Printf("⚠️ 自动刷新IP列表失败: %v", err)
|
||||
}
|
||||
case <-m.stopRefresh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
log.Printf("✓ 已启动代理IP自动刷新 (间隔: %v)", m.config.RefreshInterval)
|
||||
}
|
||||
|
||||
// StopAutoRefresh 停止自动刷新
|
||||
func (m *ProxyManager) StopAutoRefresh() {
|
||||
close(m.stopRefresh)
|
||||
}
|
||||
|
||||
// getRandomProxy 随机获取一个可用代理(线程安全 - 读锁,确保不越界)
|
||||
func (m *ProxyManager) getRandomProxy() (int, *ProxyIP, error) {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
if len(m.ipList) == 0 {
|
||||
return -1, nil, fmt.Errorf("代理IP列表为空")
|
||||
}
|
||||
|
||||
// 找到所有未被黑名单的索引
|
||||
availableIndices := make([]int, 0, len(m.ipList))
|
||||
for i := range m.ipList {
|
||||
if _, inBlacklist := m.blacklist[i]; !inBlacklist {
|
||||
availableIndices = append(availableIndices, i)
|
||||
}
|
||||
}
|
||||
|
||||
if len(availableIndices) == 0 {
|
||||
return -1, nil, fmt.Errorf("所有代理IP都在黑名单中")
|
||||
}
|
||||
|
||||
// 随机选择一个(确保不越界)
|
||||
randomIdx := availableIndices[rand.Intn(len(availableIndices))]
|
||||
|
||||
// 二次检查,确保索引有效(防御性编程)
|
||||
if randomIdx < 0 || randomIdx >= len(m.ipList) {
|
||||
return -1, nil, fmt.Errorf("代理索引越界: %d (总数: %d)", randomIdx, len(m.ipList))
|
||||
}
|
||||
|
||||
return randomIdx, &m.ipList[randomIdx], nil
|
||||
}
|
||||
|
||||
// buildProxyURL 构建代理URL
|
||||
func (m *ProxyManager) buildProxyURL(ip *ProxyIP) string {
|
||||
if m.config.ProxyHost != "" && m.config.ProxyUser != "" {
|
||||
// 使用配置的代理主机和认证信息
|
||||
user := m.config.ProxyUser
|
||||
if m.config.ProxyUser != "" && ip.IP != "" {
|
||||
// 支持%s占位符替换IP
|
||||
user = fmt.Sprintf(m.config.ProxyUser, ip.IP)
|
||||
}
|
||||
|
||||
protocol := ip.Protocol
|
||||
if protocol == "" {
|
||||
protocol = "http"
|
||||
}
|
||||
|
||||
if m.config.ProxyPassword != "" {
|
||||
return fmt.Sprintf("%s://%s:%s@%s", protocol, user, m.config.ProxyPassword, m.config.ProxyHost)
|
||||
}
|
||||
return fmt.Sprintf("%s://%s@%s", protocol, user, m.config.ProxyHost)
|
||||
}
|
||||
|
||||
// 直接使用IP信息
|
||||
return ip.IP
|
||||
}
|
||||
|
||||
// GetProxyClient 获取代理客户端(线程安全)
|
||||
func (m *ProxyManager) GetProxyClient() (*ProxyClient, error) {
|
||||
if !m.config.Enabled {
|
||||
// 未启用代理,返回普通HTTP客户端
|
||||
return &ProxyClient{
|
||||
ProxyID: -1, // -1 表示未使用代理
|
||||
IP: "direct",
|
||||
Client: &http.Client{
|
||||
Timeout: m.config.Timeout,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 获取随机代理(使用读锁,确保不越界)
|
||||
proxyID, proxyIP, err := m.getRandomProxy()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 构建代理URL
|
||||
proxyURLStr := m.buildProxyURL(proxyIP)
|
||||
proxyURL, err := url.Parse(proxyURLStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("解析代理URL失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建Transport
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyURL),
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: false,
|
||||
},
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
}
|
||||
|
||||
return &ProxyClient{
|
||||
ProxyID: proxyID,
|
||||
IP: proxyIP.IP,
|
||||
Client: &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: m.config.Timeout,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AddBlacklist 将代理IP添加到黑名单(线程安全 - 写锁)
|
||||
func (m *ProxyManager) AddBlacklist(proxyID int) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
// 检查 proxyID 有效性,防止越界
|
||||
if proxyID < 0 || proxyID >= len(m.ipList) {
|
||||
log.Printf("⚠️ 无效的 ProxyID: %d (有效范围: 0-%d)", proxyID, len(m.ipList)-1)
|
||||
return
|
||||
}
|
||||
|
||||
ip := m.ipList[proxyID].IP
|
||||
m.blacklist[proxyID] = ip
|
||||
m.ipBlacklist[ip] = m.config.BlacklistTTL
|
||||
|
||||
log.Printf("⚠️ 代理IP已加入黑名单: %s (ProxyID: %d, TTL: %d)", ip, proxyID, m.config.BlacklistTTL)
|
||||
}
|
||||
|
||||
// GetBlacklistStatus 获取黑名单状态(线程安全 - 读锁)
|
||||
func (m *ProxyManager) GetBlacklistStatus() (total int, blacklisted int, available int) {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
total = len(m.ipList)
|
||||
blacklisted = len(m.ipBlacklist)
|
||||
available = total - len(m.blacklist)
|
||||
return
|
||||
}
|
||||
|
||||
// IsEnabled 检查代理是否启用
|
||||
func IsEnabled() bool {
|
||||
return GetGlobalProxyManager().config.Enabled
|
||||
}
|
||||
|
||||
// RefreshIPList 刷新全局代理IP列表
|
||||
func RefreshIPList() error {
|
||||
return GetGlobalProxyManager().RefreshIPList()
|
||||
}
|
||||
|
||||
// AddBlacklist 将代理IP添加到全局黑名单
|
||||
func AddBlacklist(proxyID int) {
|
||||
GetGlobalProxyManager().AddBlacklist(proxyID)
|
||||
}
|
||||
19
proxy/single_provider.go
Normal file
19
proxy/single_provider.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package proxy
|
||||
|
||||
// SingleProxyProvider 单个代理提供者(不使用IP池)
|
||||
type SingleProxyProvider struct {
|
||||
proxyURL string
|
||||
}
|
||||
|
||||
// NewSingleProxyProvider 创建单个代理提供者
|
||||
func NewSingleProxyProvider(proxyURL string) *SingleProxyProvider {
|
||||
return &SingleProxyProvider{proxyURL: proxyURL}
|
||||
}
|
||||
|
||||
func (p *SingleProxyProvider) GetIPList() ([]ProxyIP, error) {
|
||||
return []ProxyIP{{IP: p.proxyURL}}, nil
|
||||
}
|
||||
|
||||
func (p *SingleProxyProvider) RefreshIPList() ([]ProxyIP, error) {
|
||||
return p.GetIPList()
|
||||
}
|
||||
40
proxy/types.go
Normal file
40
proxy/types.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ProxyIP 代理IP信息
|
||||
type ProxyIP struct {
|
||||
IP string `json:"ip"` // IP地址
|
||||
Port string `json:"port"` // 端口(可选)
|
||||
Username string `json:"username"` // 用户名(可选)
|
||||
Password string `json:"password"` // 密码(可选)
|
||||
Protocol string `json:"protocol"` // 协议: http, https, socks5
|
||||
Ext map[string]interface{} `json:"ext"` // 扩展信息
|
||||
}
|
||||
|
||||
// ProxyClient 代理客户端
|
||||
type ProxyClient struct {
|
||||
ProxyID int // IP池中的代理ID(索引)
|
||||
IP string // 使用的IP地址
|
||||
*http.Client // HTTP客户端
|
||||
}
|
||||
|
||||
// Config 代理配置
|
||||
type Config struct {
|
||||
Enabled bool // 是否启用代理
|
||||
Mode string // 模式: "single", "pool", "brightdata"
|
||||
Timeout time.Duration // 超时时间
|
||||
ProxyURL string // 单个代理地址 (single模式)
|
||||
ProxyList []string // 代理列表 (pool模式)
|
||||
BrightDataEndpoint string // Bright Data接口地址 (brightdata模式)
|
||||
BrightDataToken string // Bright Data访问令牌 (brightdata模式)
|
||||
BrightDataZone string // Bright Data区域 (brightdata模式)
|
||||
ProxyHost string // 代理主机
|
||||
ProxyUser string // 代理用户名模板(支持%s占位符)
|
||||
ProxyPassword string // 代理密码
|
||||
RefreshInterval time.Duration // IP列表刷新间隔
|
||||
BlacklistTTL int // 黑名单IP的TTL(刷新次数)
|
||||
}
|
||||
76
scripts/generate_rsa_keys/main.go
Normal file
76
scripts/generate_rsa_keys/main.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func main() {
|
||||
keysDir := "keys"
|
||||
if err := os.MkdirAll(keysDir, 0700); err != nil {
|
||||
fmt.Printf("创建keys目录失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
privateKeyPath := filepath.Join(keysDir, "rsa_private.key")
|
||||
publicKeyPath := filepath.Join(keysDir, "rsa_private.key.pub")
|
||||
|
||||
if _, err := os.Stat(privateKeyPath); err == nil {
|
||||
fmt.Println("RSA密钥对已存在:")
|
||||
fmt.Printf(" 私钥: %s\n", privateKeyPath)
|
||||
fmt.Printf(" 公钥: %s\n", publicKeyPath)
|
||||
|
||||
publicKeyPEM, err := ioutil.ReadFile(publicKeyPath)
|
||||
if err == nil {
|
||||
fmt.Println("\n公钥内容:")
|
||||
fmt.Println(string(publicKeyPEM))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("生成新的RSA密钥对...")
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
fmt.Printf("生成RSA密钥失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
|
||||
})
|
||||
|
||||
if err := ioutil.WriteFile(privateKeyPath, privateKeyPEM, 0600); err != nil {
|
||||
fmt.Printf("保存私钥失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
publicKeyDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
|
||||
if err != nil {
|
||||
fmt.Printf("编码公钥失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PUBLIC KEY",
|
||||
Bytes: publicKeyDER,
|
||||
})
|
||||
|
||||
if err := ioutil.WriteFile(publicKeyPath, publicKeyPEM, 0644); err != nil {
|
||||
fmt.Printf("保存公钥失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("✓ RSA密钥对生成成功!")
|
||||
fmt.Printf(" 私钥: %s\n", privateKeyPath)
|
||||
fmt.Printf(" 公钥: %s\n", publicKeyPath)
|
||||
fmt.Println("\n公钥内容(可用于前端配置):")
|
||||
fmt.Println(string(publicKeyPEM))
|
||||
fmt.Println("\n注意: 请妥善保管私钥文件,不要提交到版本控制系统中!")
|
||||
}
|
||||
367
scripts/migrate_sensitive_data/main.go
Normal file
367
scripts/migrate_sensitive_data/main.go
Normal file
@@ -0,0 +1,367 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"database/sql"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"nofx/crypto"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
func main() {
|
||||
privateKeyPath := flag.String("key", "keys/rsa_private.key", "RSA 私钥路径")
|
||||
dryRun := flag.Bool("dry-run", false, "仅检查需要迁移的数据,不写入数据库")
|
||||
flag.Parse()
|
||||
|
||||
// 尝试加载 .env 文件(从项目根目录运行时)
|
||||
envPaths := []string{
|
||||
".env", // 项目根目录
|
||||
}
|
||||
envLoaded := false
|
||||
for _, envPath := range envPaths {
|
||||
if err := loadEnvFile(envPath); err == nil {
|
||||
log.Printf("成功加载 .env 文件: %s", envPath)
|
||||
envLoaded = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !envLoaded {
|
||||
log.Printf("警告: 未找到 .env 文件,请确保在项目根目录存在 .env 文件")
|
||||
log.Printf("尝试的路径: %v", envPaths)
|
||||
}
|
||||
|
||||
// 确保环境变量已设置
|
||||
if os.Getenv("DATA_ENCRYPTION_KEY") == "" {
|
||||
log.Fatalf("迁移失败: DATA_ENCRYPTION_KEY 环境变量未设置")
|
||||
}
|
||||
|
||||
if err := run(*privateKeyPath, *dryRun); err != nil {
|
||||
log.Fatalf("迁移失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func run(privateKeyPath string, dryRun bool) error {
|
||||
log.SetFlags(0)
|
||||
|
||||
// 尝试多个可能的私钥路径(从项目根目录运行时)
|
||||
keyPaths := []string{
|
||||
privateKeyPath, // 用户指定的路径
|
||||
"keys/rsa_private.key", // 项目根目录的 keys 文件夹
|
||||
}
|
||||
|
||||
var finalKeyPath string
|
||||
for _, path := range keyPaths {
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
finalKeyPath = path
|
||||
log.Printf("找到私钥文件: %s", path)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if finalKeyPath == "" {
|
||||
finalKeyPath = privateKeyPath // 使用默认路径,让 crypto 服务生成新密钥
|
||||
log.Printf("警告: 私钥文件不存在,将使用路径: %s, 系统将尝试生成新密钥", finalKeyPath)
|
||||
}
|
||||
|
||||
cryptoService, err := crypto.NewCryptoService(finalKeyPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("初始化加密服务失败: %w", err)
|
||||
}
|
||||
|
||||
db, err := openPostgres()
|
||||
if err != nil {
|
||||
return fmt.Errorf("连接数据库失败: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
log.Printf("开始迁移 AI 模型密钥 (dry-run=%v)", dryRun)
|
||||
if err := migrateAIModels(db, cryptoService, dryRun); err != nil {
|
||||
return fmt.Errorf("迁移 AI 模型失败: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("开始迁移交易所密钥 (dry-run=%v)", dryRun)
|
||||
if err := migrateExchanges(db, cryptoService, dryRun); err != nil {
|
||||
return fmt.Errorf("迁移交易所失败: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("✓ 敏感数据迁移完成")
|
||||
return nil
|
||||
}
|
||||
|
||||
func openPostgres() (*sql.DB, error) {
|
||||
host := getEnv("POSTGRES_HOST", "localhost")
|
||||
// 如果是 Docker 服务名,替换为 localhost
|
||||
if host == "postgres" {
|
||||
host = "localhost"
|
||||
}
|
||||
port := getEnv("POSTGRES_PORT", "5432")
|
||||
dbname := getEnv("POSTGRES_DB", "nofx")
|
||||
user := getEnv("POSTGRES_USER", "nofx")
|
||||
password := getEnv("POSTGRES_PASSWORD", "nofx123456")
|
||||
|
||||
dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable",
|
||||
host, port, user, password, dbname)
|
||||
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db.SetMaxOpenConns(5)
|
||||
db.SetMaxIdleConns(2)
|
||||
db.SetConnMaxLifetime(5 * time.Minute)
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func migrateAIModels(db *sql.DB, cryptoService *crypto.CryptoService, dryRun bool) error {
|
||||
type record struct {
|
||||
ID string
|
||||
UserID string
|
||||
APIKey string
|
||||
}
|
||||
|
||||
rows, err := db.Query(`
|
||||
SELECT id, user_id, COALESCE(api_key, '')
|
||||
FROM ai_models
|
||||
WHERE COALESCE(deleted, FALSE) = FALSE
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var records []record
|
||||
for rows.Next() {
|
||||
var r record
|
||||
if err := rows.Scan(&r.ID, &r.UserID, &r.APIKey); err != nil {
|
||||
return err
|
||||
}
|
||||
records = append(records, r)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var updated int
|
||||
for _, r := range records {
|
||||
if r.APIKey == "" || cryptoService.IsEncryptedStorageValue(r.APIKey) {
|
||||
continue
|
||||
}
|
||||
|
||||
encrypted, err := cryptoService.EncryptForStorage(r.APIKey, r.UserID, r.ID, "api_key")
|
||||
if err != nil {
|
||||
return fmt.Errorf("加密 AI 模型 %s (%s) 失败: %w", r.ID, r.UserID, err)
|
||||
}
|
||||
|
||||
updated++
|
||||
if dryRun {
|
||||
log.Printf("[DRY-RUN] AI 模型 %s (%s) 将被加密", r.ID, r.UserID)
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := db.Exec(`
|
||||
UPDATE ai_models
|
||||
SET api_key = $1, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = $2 AND user_id = $3
|
||||
`, encrypted, r.ID, r.UserID); err != nil {
|
||||
return fmt.Errorf("更新 AI 模型 %s (%s) 失败: %w", r.ID, r.UserID, err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("AI 模型处理完成,需更新 %d 条记录", updated)
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateExchanges(db *sql.DB, cryptoService *crypto.CryptoService, dryRun bool) error {
|
||||
type record struct {
|
||||
ID string
|
||||
UserID string
|
||||
APIKey string
|
||||
SecretKey string
|
||||
HyperliquidWallet string
|
||||
AsterUser string
|
||||
AsterSigner string
|
||||
AsterPrivateKey string
|
||||
}
|
||||
|
||||
rows, err := db.Query(`
|
||||
SELECT id, user_id,
|
||||
COALESCE(api_key, '') AS api_key,
|
||||
COALESCE(secret_key, '') AS secret_key,
|
||||
COALESCE(hyperliquid_wallet_addr, '') AS hyperliquid_wallet_addr,
|
||||
COALESCE(aster_user, '') AS aster_user,
|
||||
COALESCE(aster_signer, '') AS aster_signer,
|
||||
COALESCE(aster_private_key, '') AS aster_private_key
|
||||
FROM exchanges
|
||||
WHERE COALESCE(deleted, FALSE) = FALSE
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var records []record
|
||||
for rows.Next() {
|
||||
var r record
|
||||
if err := rows.Scan(
|
||||
&r.ID, &r.UserID,
|
||||
&r.APIKey, &r.SecretKey,
|
||||
&r.HyperliquidWallet,
|
||||
&r.AsterUser, &r.AsterSigner, &r.AsterPrivateKey,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
records = append(records, r)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var updated int
|
||||
for _, r := range records {
|
||||
newAPIKey := r.APIKey
|
||||
newSecretKey := r.SecretKey
|
||||
newHyper := r.HyperliquidWallet
|
||||
newAsterUser := r.AsterUser
|
||||
newAsterSigner := r.AsterSigner
|
||||
newAsterPrivate := r.AsterPrivateKey
|
||||
|
||||
changed := false
|
||||
|
||||
if r.APIKey != "" && !cryptoService.IsEncryptedStorageValue(r.APIKey) {
|
||||
enc, err := cryptoService.EncryptForStorage(r.APIKey, r.UserID, r.ID, "api_key")
|
||||
if err != nil {
|
||||
return fmt.Errorf("加密交易所 API Key 失败: %s (%s): %w", r.ID, r.UserID, err)
|
||||
}
|
||||
newAPIKey = enc
|
||||
changed = true
|
||||
}
|
||||
if r.SecretKey != "" && !cryptoService.IsEncryptedStorageValue(r.SecretKey) {
|
||||
enc, err := cryptoService.EncryptForStorage(r.SecretKey, r.UserID, r.ID, "secret_key")
|
||||
if err != nil {
|
||||
return fmt.Errorf("加密交易所 Secret Key 失败: %s (%s): %w", r.ID, r.UserID, err)
|
||||
}
|
||||
newSecretKey = enc
|
||||
changed = true
|
||||
}
|
||||
if r.HyperliquidWallet != "" && !cryptoService.IsEncryptedStorageValue(r.HyperliquidWallet) {
|
||||
enc, err := cryptoService.EncryptForStorage(r.HyperliquidWallet, r.UserID, r.ID, "hyperliquid_wallet_addr")
|
||||
if err != nil {
|
||||
return fmt.Errorf("加密 Hyperliquid 地址失败: %s (%s): %w", r.ID, r.UserID, err)
|
||||
}
|
||||
newHyper = enc
|
||||
changed = true
|
||||
}
|
||||
if r.AsterUser != "" && !cryptoService.IsEncryptedStorageValue(r.AsterUser) {
|
||||
enc, err := cryptoService.EncryptForStorage(r.AsterUser, r.UserID, r.ID, "aster_user")
|
||||
if err != nil {
|
||||
return fmt.Errorf("加密 Aster 用户失败: %s (%s): %w", r.ID, r.UserID, err)
|
||||
}
|
||||
newAsterUser = enc
|
||||
changed = true
|
||||
}
|
||||
if r.AsterSigner != "" && !cryptoService.IsEncryptedStorageValue(r.AsterSigner) {
|
||||
enc, err := cryptoService.EncryptForStorage(r.AsterSigner, r.UserID, r.ID, "aster_signer")
|
||||
if err != nil {
|
||||
return fmt.Errorf("加密 Aster Signer 失败: %s (%s): %w", r.ID, r.UserID, err)
|
||||
}
|
||||
newAsterSigner = enc
|
||||
changed = true
|
||||
}
|
||||
if r.AsterPrivateKey != "" && !cryptoService.IsEncryptedStorageValue(r.AsterPrivateKey) {
|
||||
enc, err := cryptoService.EncryptForStorage(r.AsterPrivateKey, r.UserID, r.ID, "aster_private_key")
|
||||
if err != nil {
|
||||
return fmt.Errorf("加密 Aster 私钥失败: %s (%s): %w", r.ID, r.UserID, err)
|
||||
}
|
||||
newAsterPrivate = enc
|
||||
changed = true
|
||||
}
|
||||
|
||||
if !changed {
|
||||
continue
|
||||
}
|
||||
|
||||
updated++
|
||||
if dryRun {
|
||||
log.Printf("[DRY-RUN] 交易所 %s (%s) 将被加密", r.ID, r.UserID)
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := db.Exec(`
|
||||
UPDATE exchanges
|
||||
SET api_key = $1,
|
||||
secret_key = $2,
|
||||
hyperliquid_wallet_addr = $3,
|
||||
aster_user = $4,
|
||||
aster_signer = $5,
|
||||
aster_private_key = $6,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = $7 AND user_id = $8
|
||||
`, newAPIKey, newSecretKey, newHyper, newAsterUser, newAsterSigner, newAsterPrivate, r.ID, r.UserID); err != nil {
|
||||
return fmt.Errorf("更新交易所 %s (%s) 失败: %w", r.ID, r.UserID, err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("交易所处理完成,需更新 %d 条记录", updated)
|
||||
return nil
|
||||
}
|
||||
|
||||
func getEnv(key, fallback string) string {
|
||||
if val := os.Getenv(key); val != "" {
|
||||
return val
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func loadEnvFile(filename string) error {
|
||||
// 检查文件是否存在
|
||||
if _, err := os.Stat(filename); os.IsNotExist(err) {
|
||||
return fmt.Errorf("文件不存在: %s", filename)
|
||||
}
|
||||
|
||||
// 打开文件
|
||||
file, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("无法打开文件: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// 逐行读取
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
|
||||
// 跳过空行和注释行
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
// 解析 KEY=VALUE 格式
|
||||
parts := strings.SplitN(line, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
|
||||
// 只有当环境变量不存在时才设置
|
||||
if os.Getenv(key) == "" {
|
||||
os.Setenv(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
return scanner.Err()
|
||||
}
|
||||
@@ -13,6 +13,7 @@ import { useAuth } from '../contexts/AuthContext'
|
||||
import { getExchangeIcon } from './ExchangeIcons'
|
||||
import { getModelIcon } from './ModelIcons'
|
||||
import { TraderConfigModal } from './TraderConfigModal'
|
||||
import { TwoStageKeyModal } from './TwoStageKeyModal'
|
||||
import {
|
||||
Bot,
|
||||
Brain,
|
||||
@@ -46,6 +47,12 @@ function getShortName(fullName: string): string {
|
||||
return parts.length > 1 ? parts[parts.length - 1] : fullName
|
||||
}
|
||||
|
||||
function maskSecret(value: string): string {
|
||||
if (!value) return ''
|
||||
const length = Math.min(value.length, 16)
|
||||
return '•'.repeat(length)
|
||||
}
|
||||
|
||||
interface AITradersPageProps {
|
||||
onTraderSelect?: (traderId: string) => void
|
||||
}
|
||||
@@ -143,30 +150,9 @@ export function AITradersPage({ onTraderSelect }: AITradersPageProps) {
|
||||
allExchanges?.filter((e) => {
|
||||
if (!e.enabled) return false
|
||||
|
||||
// Aster 交易所需要特殊字段
|
||||
if (e.id === 'aster') {
|
||||
return (
|
||||
e.asterUser &&
|
||||
e.asterUser.trim() !== '' &&
|
||||
e.asterSigner &&
|
||||
e.asterSigner.trim() !== '' &&
|
||||
e.asterPrivateKey &&
|
||||
e.asterPrivateKey.trim() !== ''
|
||||
)
|
||||
}
|
||||
|
||||
// Hyperliquid 只需要私钥(作为apiKey),钱包地址会自动从私钥生成
|
||||
if (e.id === 'hyperliquid') {
|
||||
return e.apiKey && e.apiKey.trim() !== ''
|
||||
}
|
||||
|
||||
// Binance 等其他交易所需要 apiKey 和 secretKey
|
||||
return (
|
||||
e.apiKey &&
|
||||
e.apiKey.trim() !== '' &&
|
||||
e.secretKey &&
|
||||
e.secretKey.trim() !== ''
|
||||
)
|
||||
// 由于API不再返回敏感字段信息,只能基于enabled状态判断
|
||||
// 实际的配置验证将在后端进行
|
||||
return true
|
||||
}) || []
|
||||
|
||||
// 检查模型是否正在被运行中的交易员使用
|
||||
@@ -445,7 +431,7 @@ export function AITradersPage({ onTraderSelect }: AITradersPageProps) {
|
||||
},
|
||||
}
|
||||
|
||||
await api.updateExchangeConfigs(request)
|
||||
await api.updateExchangeConfigsEncrypted(request)
|
||||
|
||||
const refreshed = await api.getExchangeConfigs()
|
||||
setAllExchanges(refreshed)
|
||||
@@ -494,7 +480,7 @@ export function AITradersPage({ onTraderSelect }: AITradersPageProps) {
|
||||
},
|
||||
}
|
||||
|
||||
await api.updateExchangeConfigs(request)
|
||||
await api.updateExchangeConfigsEncrypted(request)
|
||||
|
||||
const refreshedExchanges = await api.getExchangeConfigs()
|
||||
setAllExchanges(refreshedExchanges)
|
||||
@@ -811,7 +797,7 @@ export function AITradersPage({ onTraderSelect }: AITradersPageProps) {
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
className={`w-2.5 h-2.5 md:w-3 md:h-3 rounded-full flex-shrink-0 ${exchange.enabled && exchange.apiKey ? 'bg-green-400' : 'bg-gray-500'}`}
|
||||
className={`w-2.5 h-2.5 md:w-3 md:h-3 rounded-full flex-shrink-0 ${exchange.enabled ? 'bg-green-400' : 'bg-gray-500'}`}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
@@ -1666,6 +1652,9 @@ function ExchangeConfigModal({
|
||||
const [asterUser, setAsterUser] = useState('')
|
||||
const [asterSigner, setAsterSigner] = useState('')
|
||||
const [asterPrivateKey, setAsterPrivateKey] = useState('')
|
||||
const [secureInputTarget, setSecureInputTarget] = useState<
|
||||
null | 'hyperliquid' | 'aster'
|
||||
>(null)
|
||||
|
||||
// 获取当前选择的交易所信息
|
||||
// 编辑模式:从 configuredExchanges 查找(包含用户配置的 apiKey、secretKey 等)
|
||||
@@ -1674,24 +1663,50 @@ function ExchangeConfigModal({
|
||||
? configuredExchanges?.find(e => e.id === selectedExchangeId)
|
||||
: supportedExchanges?.find(e => e.id === selectedExchangeId);
|
||||
|
||||
// 如果是编辑现有交易所,初始化表单数据
|
||||
const secureInputContextLabel =
|
||||
secureInputTarget === 'aster'
|
||||
? t('asterExchangeName', language)
|
||||
: secureInputTarget === 'hyperliquid'
|
||||
? t('hyperliquidExchangeName', language)
|
||||
: undefined
|
||||
|
||||
// 如果是编辑现有交易所,清空所有敏感字段以保证安全
|
||||
useEffect(() => {
|
||||
if (editingExchangeId && selectedExchange) {
|
||||
setApiKey(selectedExchange.apiKey || '')
|
||||
setSecretKey(selectedExchange.secretKey || '')
|
||||
setPassphrase('') // Don't load existing passphrase for security
|
||||
// 编辑模式下清空所有敏感字段,用户需要重新输入
|
||||
setApiKey('')
|
||||
setSecretKey('')
|
||||
setPassphrase('')
|
||||
setTestnet(selectedExchange.testnet || false)
|
||||
|
||||
// Hyperliquid 字段
|
||||
setHyperliquidWalletAddr(selectedExchange.hyperliquidWalletAddr || '')
|
||||
|
||||
// Aster 字段
|
||||
setAsterUser(selectedExchange.asterUser || '')
|
||||
setAsterSigner(selectedExchange.asterSigner || '')
|
||||
setAsterPrivateKey('') // Don't load existing private key for security
|
||||
setAsterSigner('')
|
||||
setAsterPrivateKey('')
|
||||
}
|
||||
}, [editingExchangeId, selectedExchange])
|
||||
|
||||
const handleSecureInputComplete = ({
|
||||
value,
|
||||
obfuscationLog,
|
||||
}: {
|
||||
value: string
|
||||
obfuscationLog: string[]
|
||||
}) => {
|
||||
const trimmed = value.trim()
|
||||
if (secureInputTarget === 'hyperliquid') {
|
||||
setApiKey(trimmed)
|
||||
}
|
||||
if (secureInputTarget === 'aster') {
|
||||
setAsterPrivateKey(trimmed)
|
||||
}
|
||||
console.log('Secure input obfuscation log:', obfuscationLog)
|
||||
setSecureInputTarget(null)
|
||||
}
|
||||
|
||||
const handleSecureInputCancel = () => {
|
||||
setSecureInputTarget(null)
|
||||
}
|
||||
|
||||
// 加载服务器IP(当选择binance时)
|
||||
useEffect(() => {
|
||||
if (selectedExchangeId === 'binance' && !serverIP) {
|
||||
@@ -1755,11 +1770,12 @@ function ExchangeConfigModal({
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="fixed inset-0 bg-black bg-opacity-50 flex items-center justify-center z-50 p-4">
|
||||
<div
|
||||
className="bg-gray-800 rounded-lg p-6 w-full max-w-lg relative"
|
||||
style={{ background: '#1E2329' }}
|
||||
>
|
||||
<>
|
||||
<div className="fixed inset-0 bg-black bg-opacity-50 flex items-center justify-center z-50 p-4">
|
||||
<div
|
||||
className="bg-gray-800 rounded-lg p-6 w-full max-w-lg relative"
|
||||
style={{ background: '#1E2329' }}
|
||||
>
|
||||
<div className="flex items-center justify-between mb-4">
|
||||
<h3 className="text-xl font-bold" style={{ color: '#EAECEF' }}>
|
||||
{editingExchangeId
|
||||
@@ -2094,19 +2110,55 @@ function ExchangeConfigModal({
|
||||
>
|
||||
{t('privateKey', language)}
|
||||
</label>
|
||||
<input
|
||||
type="password"
|
||||
value={apiKey}
|
||||
onChange={(e) => setApiKey(e.target.value)}
|
||||
placeholder={t('enterPrivateKey', language)}
|
||||
className="w-full px-3 py-2 rounded"
|
||||
style={{
|
||||
background: '#0B0E11',
|
||||
border: '1px solid #2B3139',
|
||||
color: '#EAECEF',
|
||||
}}
|
||||
required
|
||||
/>
|
||||
<div className="flex flex-col gap-2">
|
||||
<div className="flex gap-2">
|
||||
<input
|
||||
type="text"
|
||||
value={maskSecret(apiKey)}
|
||||
readOnly
|
||||
placeholder={t('enterPrivateKey', language)}
|
||||
className="w-full px-3 py-2 rounded"
|
||||
style={{
|
||||
background: '#0B0E11',
|
||||
border: '1px solid #2B3139',
|
||||
color: '#EAECEF',
|
||||
}}
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setSecureInputTarget('hyperliquid')}
|
||||
className="px-3 py-2 rounded text-xs font-semibold transition-all hover:scale-105"
|
||||
style={{
|
||||
background: '#F0B90B',
|
||||
color: '#000',
|
||||
whiteSpace: 'nowrap',
|
||||
}}
|
||||
>
|
||||
{apiKey
|
||||
? t('secureInputReenter', language)
|
||||
: t('secureInputButton', language)}
|
||||
</button>
|
||||
{apiKey && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setApiKey('')}
|
||||
className="px-3 py-2 rounded text-xs font-semibold transition-all hover:scale-105"
|
||||
style={{
|
||||
background: '#1B1F2B',
|
||||
color: '#848E9C',
|
||||
whiteSpace: 'nowrap',
|
||||
}}
|
||||
>
|
||||
{t('secureInputClear', language)}
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
{apiKey && (
|
||||
<div className="text-xs" style={{ color: '#848E9C' }}>
|
||||
{t('secureInputHint', language)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<div className="text-xs mt-1" style={{ color: '#848E9C' }}>
|
||||
{t('hyperliquidPrivateKeyDesc', language)}
|
||||
</div>
|
||||
@@ -2209,19 +2261,55 @@ function ExchangeConfigModal({
|
||||
/>
|
||||
</Tooltip>
|
||||
</label>
|
||||
<input
|
||||
type="password"
|
||||
value={asterPrivateKey}
|
||||
onChange={(e) => setAsterPrivateKey(e.target.value)}
|
||||
placeholder={t('enterPrivateKey', language)}
|
||||
className="w-full px-3 py-2 rounded"
|
||||
style={{
|
||||
background: '#0B0E11',
|
||||
border: '1px solid #2B3139',
|
||||
color: '#EAECEF',
|
||||
}}
|
||||
required
|
||||
/>
|
||||
<div className="flex flex-col gap-2">
|
||||
<div className="flex gap-2">
|
||||
<input
|
||||
type="text"
|
||||
value={maskSecret(asterPrivateKey)}
|
||||
readOnly
|
||||
placeholder={t('enterPrivateKey', language)}
|
||||
className="w-full px-3 py-2 rounded"
|
||||
style={{
|
||||
background: '#0B0E11',
|
||||
border: '1px solid #2B3139',
|
||||
color: '#EAECEF',
|
||||
}}
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setSecureInputTarget('aster')}
|
||||
className="px-3 py-2 rounded text-xs font-semibold transition-all hover:scale-105"
|
||||
style={{
|
||||
background: '#F0B90B',
|
||||
color: '#000',
|
||||
whiteSpace: 'nowrap',
|
||||
}}
|
||||
>
|
||||
{asterPrivateKey
|
||||
? t('secureInputReenter', language)
|
||||
: t('secureInputButton', language)}
|
||||
</button>
|
||||
{asterPrivateKey && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setAsterPrivateKey('')}
|
||||
className="px-3 py-2 rounded text-xs font-semibold transition-all hover:scale-105"
|
||||
style={{
|
||||
background: '#1B1F2B',
|
||||
color: '#848E9C',
|
||||
whiteSpace: 'nowrap',
|
||||
}}
|
||||
>
|
||||
{t('secureInputClear', language)}
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
{asterPrivateKey && (
|
||||
<div className="text-xs" style={{ color: '#848E9C' }}>
|
||||
{t('secureInputHint', language)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
@@ -2349,6 +2437,16 @@ function ExchangeConfigModal({
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<TwoStageKeyModal
|
||||
isOpen={secureInputTarget !== null}
|
||||
language={language}
|
||||
contextLabel={secureInputContextLabel}
|
||||
expectedLength={64}
|
||||
onCancel={handleSecureInputCancel}
|
||||
onComplete={handleSecureInputComplete}
|
||||
/>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
320
web/src/components/TwoStageKeyModal.tsx
Normal file
320
web/src/components/TwoStageKeyModal.tsx
Normal file
@@ -0,0 +1,320 @@
|
||||
import { useEffect, useMemo, useRef, useState } from 'react'
|
||||
import { createPortal } from 'react-dom'
|
||||
import { t, type Language } from '../i18n/translations'
|
||||
|
||||
const DEFAULT_LENGTH = 64
|
||||
|
||||
function generateObfuscation(): string {
|
||||
const bytes = new Uint8Array(32)
|
||||
crypto.getRandomValues(bytes)
|
||||
return Array.from(bytes, (byte) => byte.toString(16).padStart(2, '0')).join('')
|
||||
}
|
||||
|
||||
function validatePrivateKeyFormat(value: string, expectedLength: number): boolean {
|
||||
const normalized = value.startsWith('0x') ? value.slice(2) : value
|
||||
if (normalized.length !== expectedLength) {
|
||||
return false
|
||||
}
|
||||
return /^[0-9a-fA-F]+$/.test(normalized)
|
||||
}
|
||||
|
||||
export interface TwoStageKeyModalResult {
|
||||
value: string
|
||||
obfuscationLog: string[]
|
||||
}
|
||||
|
||||
interface TwoStageKeyModalProps {
|
||||
isOpen: boolean
|
||||
language: Language
|
||||
onCancel: () => void
|
||||
onComplete: (result: TwoStageKeyModalResult) => void
|
||||
expectedLength?: number
|
||||
contextLabel?: string
|
||||
}
|
||||
|
||||
export function TwoStageKeyModal({
|
||||
isOpen,
|
||||
language,
|
||||
onCancel,
|
||||
onComplete,
|
||||
expectedLength = DEFAULT_LENGTH,
|
||||
contextLabel,
|
||||
}: TwoStageKeyModalProps) {
|
||||
const [stage, setStage] = useState<1 | 2>(1)
|
||||
const [part1, setPart1] = useState('')
|
||||
const [part2, setPart2] = useState('')
|
||||
const [error, setError] = useState<string | null>(null)
|
||||
const [clipboardStatus, setClipboardStatus] = useState<'idle' | 'copied' | 'failed'>('idle')
|
||||
const [obfuscationLog, setObfuscationLog] = useState<string[]>([])
|
||||
const [processing, setProcessing] = useState(false)
|
||||
const [manualObfuscationValue, setManualObfuscationValue] = useState<string | null>(null)
|
||||
const stage1InputRef = useRef<HTMLInputElement | null>(null)
|
||||
const stage2InputRef = useRef<HTMLInputElement | null>(null)
|
||||
|
||||
useEffect(() => {
|
||||
if (!isOpen) return
|
||||
const handler = (event: KeyboardEvent) => {
|
||||
if (event.key === 'Escape') {
|
||||
event.preventDefault()
|
||||
onCancel()
|
||||
}
|
||||
}
|
||||
document.addEventListener('keydown', handler)
|
||||
return () => document.removeEventListener('keydown', handler)
|
||||
}, [isOpen, onCancel])
|
||||
|
||||
useEffect(() => {
|
||||
if (!isOpen) {
|
||||
setStage(1)
|
||||
setPart1('')
|
||||
setPart2('')
|
||||
setError(null)
|
||||
setClipboardStatus('idle')
|
||||
setObfuscationLog([])
|
||||
setProcessing(false)
|
||||
setManualObfuscationValue(null)
|
||||
return
|
||||
}
|
||||
|
||||
const focusTimer = setTimeout(() => {
|
||||
if (stage === 1) {
|
||||
stage1InputRef.current?.focus()
|
||||
} else {
|
||||
stage2InputRef.current?.focus()
|
||||
}
|
||||
}, 10)
|
||||
|
||||
return () => clearTimeout(focusTimer)
|
||||
}, [isOpen, stage])
|
||||
|
||||
const heading = useMemo(() => {
|
||||
if (!contextLabel) {
|
||||
return t('twoStageModalTitle', language)
|
||||
}
|
||||
return `${t('twoStageModalTitle', language)} · ${contextLabel}`
|
||||
}, [contextLabel, language])
|
||||
|
||||
if (!isOpen) {
|
||||
return null
|
||||
}
|
||||
|
||||
const handleOverlayClick = () => {
|
||||
if (!processing) {
|
||||
onCancel()
|
||||
}
|
||||
}
|
||||
|
||||
const handleStage1Next = async () => {
|
||||
if (!part1.trim()) {
|
||||
setError(t('twoStageStage1Error', language))
|
||||
return
|
||||
}
|
||||
setProcessing(true)
|
||||
const obfuscation = generateObfuscation()
|
||||
let copied = false
|
||||
try {
|
||||
await navigator.clipboard.writeText(obfuscation)
|
||||
copied = true
|
||||
setClipboardStatus('copied')
|
||||
setManualObfuscationValue(null)
|
||||
} catch (err) {
|
||||
console.warn('Clipboard write failed', err)
|
||||
setClipboardStatus('failed')
|
||||
setManualObfuscationValue(obfuscation)
|
||||
}
|
||||
setObfuscationLog((prev) => [...prev, `stage1:${new Date().toISOString()}`])
|
||||
setProcessing(false)
|
||||
setError(null)
|
||||
setStage(2)
|
||||
if (copied) {
|
||||
setManualObfuscationValue(null)
|
||||
}
|
||||
}
|
||||
|
||||
const handleSubmit = () => {
|
||||
const cleanedPart1 = part1.trim()
|
||||
const cleanedPart2 = part2.trim()
|
||||
const combined = (cleanedPart1 + cleanedPart2).replace(/\s+/g, '')
|
||||
|
||||
if (!validatePrivateKeyFormat(combined, expectedLength)) {
|
||||
setError(t('twoStageInvalidFormat', language, { length: expectedLength }))
|
||||
return
|
||||
}
|
||||
|
||||
setObfuscationLog((prev) => [...prev, `stage2:${new Date().toISOString()}`])
|
||||
const result: TwoStageKeyModalResult = {
|
||||
value: combined,
|
||||
obfuscationLog: [...obfuscationLog, `stage2:${new Date().toISOString()}`],
|
||||
}
|
||||
onComplete(result)
|
||||
}
|
||||
|
||||
const modalContent = (
|
||||
<div
|
||||
className="fixed inset-0 z-50 flex items-center justify-center bg-black/70 px-4"
|
||||
onClick={handleOverlayClick}
|
||||
>
|
||||
<div
|
||||
className="w-full max-w-md rounded-xl border border-[#2B3139] bg-[#0B0E11] p-6 shadow-2xl"
|
||||
onClick={(event) => event.stopPropagation()}
|
||||
>
|
||||
<div className="mb-4">
|
||||
<h2 className="text-lg font-semibold" style={{ color: '#EAECEF' }}>
|
||||
{heading}
|
||||
</h2>
|
||||
<p className="text-xs mt-1" style={{ color: '#848E9C' }}>
|
||||
{t('twoStageModalDescription', language, { length: expectedLength })}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{stage === 1 ? (
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
<label
|
||||
className="block text-sm font-semibold mb-2"
|
||||
style={{ color: '#EAECEF' }}
|
||||
>
|
||||
{t('twoStageStage1Title', language)}
|
||||
</label>
|
||||
<input
|
||||
ref={stage1InputRef}
|
||||
type="password"
|
||||
value={part1}
|
||||
onChange={(event) => setPart1(event.target.value)}
|
||||
placeholder={t('twoStageStage1Placeholder', language)}
|
||||
className="w-full rounded border border-[#2B3139] bg-[#0F111C] px-3 py-2 text-sm text-[#EAECEF] outline-none focus:ring-2 focus:ring-[#F0B90B]/40"
|
||||
disabled={processing}
|
||||
/>
|
||||
<p className="mt-2 text-xs" style={{ color: '#848E9C' }}>
|
||||
{t('twoStageStage1Hint', language)}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{clipboardStatus === 'failed' && (
|
||||
<div
|
||||
className="rounded border border-red-500/40 bg-red-500/10 px-3 py-2 text-xs"
|
||||
style={{ color: '#F6465D' }}
|
||||
>
|
||||
<div>{t('twoStageClipboardManual', language)}</div>
|
||||
{manualObfuscationValue && (
|
||||
<code className="mt-2 block select-all rounded bg-black/40 px-2 py-1 text-[11px] text-[#F0B90B]">
|
||||
{manualObfuscationValue}
|
||||
</code>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{error && (
|
||||
<div
|
||||
className="rounded border border-red-500/40 bg-red-500/10 px-3 py-2 text-xs"
|
||||
style={{ color: '#F6465D' }}
|
||||
>
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex gap-2">
|
||||
<button
|
||||
type="button"
|
||||
onClick={onCancel}
|
||||
className="flex-1 rounded px-3 py-2 text-sm font-semibold transition-all hover:scale-[1.01]"
|
||||
style={{ background: '#1B1F2B', color: '#848E9C' }}
|
||||
disabled={processing}
|
||||
>
|
||||
{t('twoStageCancel', language)}
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
onClick={handleStage1Next}
|
||||
className="flex-1 rounded px-3 py-2 text-sm font-semibold transition-all hover:scale-[1.01]"
|
||||
style={{
|
||||
background: processing ? '#3d2e0d' : '#F0B90B',
|
||||
color: processing ? '#a18a43' : '#000',
|
||||
opacity: part1.trim().length === 0 ? 0.7 : 1,
|
||||
}}
|
||||
disabled={processing || part1.trim().length === 0}
|
||||
>
|
||||
{processing ? t('twoStageProcessing', language) : t('twoStageNext', language)}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
<label
|
||||
className="block text-sm font-semibold mb-2"
|
||||
style={{ color: '#EAECEF' }}
|
||||
>
|
||||
{t('twoStageStage2Title', language)}
|
||||
</label>
|
||||
<input
|
||||
ref={stage2InputRef}
|
||||
type="password"
|
||||
value={part2}
|
||||
onChange={(event) => setPart2(event.target.value)}
|
||||
placeholder={t('twoStageStage2Placeholder', language)}
|
||||
className="w-full rounded border border-[#2B3139] bg-[#0F111C] px-3 py-2 text-sm text-[#EAECEF] outline-none focus:ring-2 focus:ring-[#F0B90B]/40"
|
||||
/>
|
||||
<p className="mt-2 text-xs" style={{ color: '#848E9C' }}>
|
||||
{t('twoStageStage2Hint', language)}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{clipboardStatus === 'copied' && (
|
||||
<div
|
||||
className="rounded border border-[#F0B90B]/40 bg-[#F0B90B]/10 px-3 py-2 text-xs"
|
||||
style={{ color: '#F0B90B' }}
|
||||
>
|
||||
{t('twoStageClipboardSuccess', language)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{clipboardStatus === 'failed' && manualObfuscationValue && (
|
||||
<div
|
||||
className="rounded border border-[#2B3139] bg-[#141821] px-3 py-2 text-xs"
|
||||
style={{ color: '#EAECEF' }}
|
||||
>
|
||||
{t('twoStageClipboardReminder', language)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{error && (
|
||||
<div
|
||||
className="rounded border border-red-500/40 bg-red-500/10 px-3 py-2 text-xs"
|
||||
style={{ color: '#F6465D' }}
|
||||
>
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex gap-2">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setStage(1)
|
||||
setPart2('')
|
||||
setError(null)
|
||||
setClipboardStatus('idle')
|
||||
}}
|
||||
className="rounded px-3 py-2 text-sm font-semibold transition-all hover:scale-[1.01]"
|
||||
style={{ background: '#1B1F2B', color: '#848E9C' }}
|
||||
>
|
||||
{t('twoStageBack', language)}
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
onClick={handleSubmit}
|
||||
className="flex-1 rounded px-3 py-2 text-sm font-semibold transition-all hover:scale-[1.01]"
|
||||
style={{ background: '#F0B90B', color: '#000' }}
|
||||
>
|
||||
{t('twoStageSubmit', language)}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
|
||||
return createPortal(modalContent, document.body)
|
||||
}
|
||||
@@ -1,4 +1,6 @@
|
||||
import React, { createContext, useContext, useState, useEffect } from 'react'
|
||||
import React, { createContext, useContext, useState, useEffect } from 'react';
|
||||
import { getSystemConfig } from '../lib/config';
|
||||
import { CryptoService } from '../lib/crypto';
|
||||
|
||||
interface User {
|
||||
id: string
|
||||
@@ -61,12 +63,33 @@ export function AuthProvider({ children }: { children: React.ReactNode }) {
|
||||
|
||||
const login = async (email: string, password: string) => {
|
||||
try {
|
||||
const systemConfig = await getSystemConfig()
|
||||
if (!systemConfig.rsa_public_key) {
|
||||
throw new Error('系统未配置登录所需的RSA公钥')
|
||||
}
|
||||
|
||||
await CryptoService.initialize(systemConfig.rsa_public_key)
|
||||
const sessionId = sessionStorage.getItem('session_id') || ''
|
||||
|
||||
const requestBody = {
|
||||
email_encrypted: await CryptoService.encryptSensitiveData(
|
||||
email,
|
||||
email,
|
||||
sessionId
|
||||
),
|
||||
password_encrypted: await CryptoService.encryptSensitiveData(
|
||||
password,
|
||||
email,
|
||||
sessionId
|
||||
),
|
||||
}
|
||||
|
||||
const response = await fetch('/api/login', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ email, password }),
|
||||
body: JSON.stringify(requestBody),
|
||||
})
|
||||
|
||||
const data = await response.json()
|
||||
@@ -84,6 +107,7 @@ export function AuthProvider({ children }: { children: React.ReactNode }) {
|
||||
return { success: false, message: data.error }
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Login request failed:', error)
|
||||
return { success: false, message: '登录失败,请重试' }
|
||||
}
|
||||
|
||||
|
||||
@@ -204,6 +204,42 @@ export const translations = {
|
||||
'API wallet private key - Get from https://www.asterdex.com/en/api-wallet (only used locally for signing, never transmitted)',
|
||||
asterUsdtWarning:
|
||||
'Important: Aster only tracks USDT balance. Please ensure you use USDT as margin currency to avoid P&L calculation errors caused by price fluctuations of other assets (BNB, ETH, etc.)',
|
||||
hyperliquidExchangeName: 'Hyperliquid',
|
||||
asterExchangeName: 'Aster DEX',
|
||||
secureInputButton: 'Secure Input',
|
||||
secureInputReenter: 'Re-enter Securely',
|
||||
secureInputClear: 'Clear',
|
||||
secureInputHint:
|
||||
'Captured via secure two-step input. Use “Re-enter Securely” to update this value.',
|
||||
twoStageModalTitle: 'Secure Key Input',
|
||||
twoStageModalDescription:
|
||||
'Use a two-step flow to enter your {length}-character private key safely.',
|
||||
twoStageStage1Title: 'Step 1 · Enter the first half',
|
||||
twoStageStage1Placeholder: 'First 32 characters (include 0x if present)',
|
||||
twoStageStage1Hint:
|
||||
'Continuing copies an obfuscation string to your clipboard as a diversion.',
|
||||
twoStageStage1Error: 'Please enter the first part before continuing.',
|
||||
twoStageNext: 'Next',
|
||||
twoStageProcessing: 'Processing…',
|
||||
twoStageCancel: 'Cancel',
|
||||
twoStageStage2Title: 'Step 2 · Enter the rest',
|
||||
twoStageStage2Placeholder: 'Remaining characters of your private key',
|
||||
twoStageStage2Hint:
|
||||
'Paste the obfuscation string somewhere neutral, then finish entering your key.',
|
||||
twoStageClipboardSuccess:
|
||||
'Obfuscation string copied. Paste it into any text field once before completing.',
|
||||
twoStageClipboardReminder:
|
||||
'Remember to paste the obfuscation string before submitting to avoid clipboard leaks.',
|
||||
twoStageClipboardManual:
|
||||
'Automatic copy failed. Copy the obfuscation string below manually.',
|
||||
twoStageClipboardFailed:
|
||||
'Automatic clipboard copy failed. Please copy the obfuscation string manually.',
|
||||
twoStageClipboardInstruction:
|
||||
'Obfuscation string copied. Paste it once before finishing the input.',
|
||||
twoStageBack: 'Back',
|
||||
twoStageSubmit: 'Confirm',
|
||||
twoStageInvalidFormat:
|
||||
'Invalid private key format. Expected {length} hexadecimal characters (optional 0x prefix).',
|
||||
testnetDescription:
|
||||
'Enable to connect to exchange test environment for simulated trading',
|
||||
securityWarning: 'Security Warning',
|
||||
@@ -700,6 +736,34 @@ export const translations = {
|
||||
'API 钱包私钥 - 从 https://www.asterdex.com/zh-CN/api-wallet 获取(仅在本地用于签名,不会被传输)',
|
||||
asterUsdtWarning:
|
||||
'重要提示:Aster 仅统计 USDT 余额。请确保您使用 USDT 作为保证金币种,避免其他资产(BNB、ETH等)的价格波动导致盈亏统计错误',
|
||||
hyperliquidExchangeName: 'Hyperliquid',
|
||||
asterExchangeName: 'Aster DEX',
|
||||
secureInputButton: '安全输入',
|
||||
secureInputReenter: '重新安全输入',
|
||||
secureInputClear: '清除',
|
||||
secureInputHint: '已通过安全双阶段输入设置。若需修改,请点击“重新安全输入”。',
|
||||
twoStageModalTitle: '安全私钥输入',
|
||||
twoStageModalDescription: '使用双阶段流程安全输入长度为 {length} 的私钥。',
|
||||
twoStageStage1Title: '步骤一 · 输入前半段',
|
||||
twoStageStage1Placeholder: '前 32 位字符(若有 0x 前缀请保留)',
|
||||
twoStageStage1Hint: '继续后会将扰动字符串复制到剪贴板,用于迷惑剪贴板监控。',
|
||||
twoStageStage1Error: '请先输入第一段私钥。',
|
||||
twoStageNext: '下一步',
|
||||
twoStageProcessing: '处理中…',
|
||||
twoStageCancel: '取消',
|
||||
twoStageStage2Title: '步骤二 · 输入剩余部分',
|
||||
twoStageStage2Placeholder: '剩余的私钥字符',
|
||||
twoStageStage2Hint: '将扰动字符串粘贴到任意位置后,再完成私钥输入。',
|
||||
twoStageClipboardSuccess:
|
||||
'扰动字符串已复制。请在完成前在任意文本处粘贴一次以迷惑剪贴板记录。',
|
||||
twoStageClipboardReminder:
|
||||
'记得在提交前粘贴一次扰动字符串,降低剪贴板泄漏风险。',
|
||||
twoStageClipboardManual: '自动复制失败,请手动复制下面的扰动字符串。',
|
||||
twoStageClipboardFailed: '自动写入剪贴板失败,请手动复制扰动字符串。',
|
||||
twoStageClipboardInstruction: '扰动字符串已复制,请在完成输入前粘贴一次。',
|
||||
twoStageBack: '返回',
|
||||
twoStageSubmit: '确认',
|
||||
twoStageInvalidFormat: '私钥格式不正确,应为 {length} 位十六进制字符(可选 0x 前缀)。',
|
||||
testnetDescription: '启用后将连接到交易所测试环境,用于模拟交易',
|
||||
securityWarning: '安全提示',
|
||||
saveConfiguration: '保存配置',
|
||||
|
||||
@@ -11,7 +11,8 @@ import type {
|
||||
UpdateModelConfigRequest,
|
||||
UpdateExchangeConfigRequest,
|
||||
CompetitionData,
|
||||
} from '../types'
|
||||
} from '../types';
|
||||
import { CryptoService } from './crypto';
|
||||
|
||||
const API_BASE = '/api'
|
||||
|
||||
@@ -165,6 +166,40 @@ export const api = {
|
||||
if (!res.ok) throw new Error('更新交易所配置失败')
|
||||
},
|
||||
|
||||
// 使用加密传输更新交易所配置
|
||||
async updateExchangeConfigsEncrypted(request: UpdateExchangeConfigRequest): Promise<void> {
|
||||
// 从系统配置获取公钥
|
||||
const configRes = await fetch(`${API_BASE}/config`);
|
||||
if (!configRes.ok) throw new Error('获取系统配置失败');
|
||||
const config = await configRes.json();
|
||||
|
||||
if (!config.rsa_public_key) {
|
||||
throw new Error('系统未配置RSA公钥,无法使用加密传输');
|
||||
}
|
||||
|
||||
// 初始化加密服务
|
||||
await CryptoService.initialize(config.rsa_public_key);
|
||||
|
||||
// 获取用户信息(从localStorage或其他地方)
|
||||
const userId = localStorage.getItem('user_id') || '';
|
||||
const sessionId = sessionStorage.getItem('session_id') || '';
|
||||
|
||||
// 加密敏感数据
|
||||
const encryptedPayload = await CryptoService.encryptSensitiveData(
|
||||
JSON.stringify(request),
|
||||
userId,
|
||||
sessionId
|
||||
);
|
||||
|
||||
// 发送加密数据
|
||||
const res = await fetch(`${API_BASE}/exchanges/encrypted`, {
|
||||
method: 'PUT',
|
||||
headers: getAuthHeaders(),
|
||||
body: JSON.stringify(encryptedPayload),
|
||||
});
|
||||
if (!res.ok) throw new Error('更新交易所配置失败');
|
||||
},
|
||||
|
||||
// 获取系统状态(支持trader_id)
|
||||
async getStatus(traderId?: string): Promise<SystemStatus> {
|
||||
const url = traderId
|
||||
|
||||
@@ -3,6 +3,8 @@ export interface SystemConfig {
|
||||
default_coins?: string[]
|
||||
btc_eth_leverage?: number
|
||||
altcoin_leverage?: number
|
||||
rsa_public_key?: string
|
||||
rsa_key_id?: string
|
||||
}
|
||||
|
||||
let configPromise: Promise<SystemConfig> | null = null
|
||||
|
||||
147
web/src/lib/crypto.ts
Normal file
147
web/src/lib/crypto.ts
Normal file
@@ -0,0 +1,147 @@
|
||||
export interface EncryptedPayload {
|
||||
wrappedKey: string; // RSA-OAEP(K)
|
||||
iv: string; // 12 bytes
|
||||
ciphertext: string; // AES-GCM 输出(含 tag)
|
||||
aad?: string; // 可选:额外认证数据
|
||||
kid?: string; // 可选:服务端公钥标识
|
||||
ts?: number; // 可选:unix 秒,用于重放保护
|
||||
}
|
||||
|
||||
export class CryptoService {
|
||||
private static publicKey: CryptoKey | null = null;
|
||||
private static publicKeyPEM: string | null = null;
|
||||
|
||||
static async initialize(publicKeyPEM: string) {
|
||||
// 检查 Web Crypto API 是否可用
|
||||
if (!window.crypto || !window.crypto.subtle) {
|
||||
throw new Error('Web Crypto API is not available. Please use HTTPS or localhost to access the application.');
|
||||
}
|
||||
|
||||
if (this.publicKey && this.publicKeyPEM === publicKeyPEM) {
|
||||
return;
|
||||
}
|
||||
this.publicKeyPEM = publicKeyPEM;
|
||||
this.publicKey = await this.importPublicKey(publicKeyPEM);
|
||||
}
|
||||
|
||||
private static async importPublicKey(pem: string): Promise<CryptoKey> {
|
||||
const pemHeader = '-----BEGIN PUBLIC KEY-----';
|
||||
const pemFooter = '-----END PUBLIC KEY-----';
|
||||
const headerIndex = pem.indexOf(pemHeader);
|
||||
const footerIndex = pem.indexOf(pemFooter);
|
||||
|
||||
if (headerIndex === -1 || footerIndex === -1 || headerIndex >= footerIndex) {
|
||||
throw new Error('Invalid PEM formatted public key');
|
||||
}
|
||||
|
||||
const pemContents = pem
|
||||
.substring(headerIndex + pemHeader.length, footerIndex)
|
||||
.replace(/\s+/g, ''); // 移除所有空白字符(包括换行符、空格等)
|
||||
|
||||
const binaryDerString = atob(pemContents);
|
||||
const binaryDer = new Uint8Array(binaryDerString.length);
|
||||
for (let i = 0; i < binaryDerString.length; i++) {
|
||||
binaryDer[i] = binaryDerString.charCodeAt(i);
|
||||
}
|
||||
|
||||
return crypto.subtle.importKey(
|
||||
'spki',
|
||||
binaryDer,
|
||||
{
|
||||
name: 'RSA-OAEP',
|
||||
hash: 'SHA-256',
|
||||
},
|
||||
false,
|
||||
['encrypt']
|
||||
);
|
||||
}
|
||||
|
||||
static async encryptSensitiveData(
|
||||
plaintext: string,
|
||||
userId?: string,
|
||||
sessionId?: string
|
||||
): Promise<EncryptedPayload> {
|
||||
if (!this.publicKey) {
|
||||
throw new Error('Crypto service not initialized. Call initialize() first.');
|
||||
}
|
||||
|
||||
// 1. 生成 256-bit AES 密钥
|
||||
const aesKey = await crypto.subtle.generateKey(
|
||||
{
|
||||
name: 'AES-GCM',
|
||||
length: 256,
|
||||
},
|
||||
true,
|
||||
['encrypt']
|
||||
);
|
||||
|
||||
// 2. 生成 12 字节随机 IV
|
||||
const iv = crypto.getRandomValues(new Uint8Array(12));
|
||||
|
||||
// 3. 准备 AAD (额外认证数据)
|
||||
const ts = Math.floor(Date.now() / 1000);
|
||||
const aadObject = {
|
||||
userId: userId || '',
|
||||
sessionId: sessionId || '',
|
||||
ts: ts,
|
||||
purpose: 'sensitive_data_encryption'
|
||||
};
|
||||
const aadString = JSON.stringify(aadObject);
|
||||
const aadBytes = new TextEncoder().encode(aadString);
|
||||
|
||||
// 4. 使用 AES-GCM 加密数据
|
||||
const plaintextBytes = new TextEncoder().encode(plaintext);
|
||||
const ciphertext = await crypto.subtle.encrypt(
|
||||
{
|
||||
name: 'AES-GCM',
|
||||
iv: iv,
|
||||
additionalData: aadBytes,
|
||||
tagLength: 128, // 16 bytes tag
|
||||
},
|
||||
aesKey,
|
||||
plaintextBytes
|
||||
);
|
||||
|
||||
// 5. 导出 AES 密钥
|
||||
const aesKeyRaw = await crypto.subtle.exportKey('raw', aesKey);
|
||||
|
||||
// 6. 使用 RSA-OAEP 加密 AES 密钥
|
||||
const wrappedKey = await crypto.subtle.encrypt(
|
||||
{
|
||||
name: 'RSA-OAEP',
|
||||
},
|
||||
this.publicKey,
|
||||
aesKeyRaw
|
||||
);
|
||||
|
||||
// 7. 转换为 base64url 格式
|
||||
return {
|
||||
wrappedKey: this.arrayBufferToBase64Url(wrappedKey),
|
||||
iv: this.arrayBufferToBase64Url(iv),
|
||||
ciphertext: this.arrayBufferToBase64Url(ciphertext),
|
||||
aad: this.arrayBufferToBase64Url(aadBytes),
|
||||
kid: 'rsa-key-2025-11-05',
|
||||
ts: ts,
|
||||
};
|
||||
}
|
||||
|
||||
private static arrayBufferToBase64Url(buffer: ArrayBuffer | Uint8Array): string {
|
||||
const bytes = buffer instanceof Uint8Array ? buffer : new Uint8Array(buffer);
|
||||
let binary = '';
|
||||
for (let i = 0; i < bytes.byteLength; i++) {
|
||||
binary += String.fromCharCode(bytes[i]);
|
||||
}
|
||||
return btoa(binary)
|
||||
.replace(/\+/g, '-')
|
||||
.replace(/\//g, '_')
|
||||
.replace(/=/g, '');
|
||||
}
|
||||
|
||||
static async encryptWalletPrivateKey(privateKey: string, userId?: string, sessionId?: string): Promise<EncryptedPayload> {
|
||||
return this.encryptSensitiveData(privateKey, userId, sessionId);
|
||||
}
|
||||
|
||||
static async encryptExchangeSecret(secretKey: string, userId?: string, sessionId?: string): Promise<EncryptedPayload> {
|
||||
return this.encryptSensitiveData(secretKey, userId, sessionId);
|
||||
}
|
||||
}
|
||||
@@ -108,19 +108,16 @@ export interface AIModel {
|
||||
|
||||
export interface Exchange {
|
||||
id: string
|
||||
user_id: string
|
||||
name: string
|
||||
type: 'cex' | 'dex'
|
||||
enabled: boolean
|
||||
apiKey?: string
|
||||
secretKey?: string
|
||||
testnet?: boolean
|
||||
// Hyperliquid 特定字段
|
||||
hyperliquidWalletAddr?: string
|
||||
// Aster 特定字段
|
||||
asterUser?: string
|
||||
asterSigner?: string
|
||||
asterPrivateKey?: string
|
||||
deleted?: boolean
|
||||
hyperliquidWalletAddr?: string // 钱包地址,非敏感信息
|
||||
asterUser?: string // Aster用户名,非敏感信息
|
||||
deleted: boolean
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
export interface CreateTraderRequest {
|
||||
|
||||
Reference in New Issue
Block a user