Merge branch 'dev' into beta

# Conflicts:
#	.github/workflows/docker-build.yml
#	.gitignore
#	api/server.go
#	config/config.go
#	config/database.go
#	decision/engine.go
#	docker-compose.yml
#	go.mod
#	go.sum
#	logger/telegram_sender.go
#	main.go
#	mcp/client.go
#	prompts/adaptive.txt
#	prompts/default.txt
#	prompts/nof1.txt
#	start.sh
#	trader/aster_trader.go
#	trader/auto_trader.go
#	trader/binance_futures.go
#	trader/hyperliquid_trader.go
#	web/package-lock.json
#	web/package.json
#	web/src/App.tsx
#	web/src/components/AILearning.tsx
#	web/src/components/AITradersPage.tsx
#	web/src/components/CompetitionPage.tsx
#	web/src/components/EquityChart.tsx
#	web/src/components/Header.tsx
#	web/src/components/LoginPage.tsx
#	web/src/components/RegisterPage.tsx
#	web/src/components/TraderConfigModal.tsx
#	web/src/components/TraderConfigViewModal.tsx
#	web/src/components/landing/FooterSection.tsx
#	web/src/components/landing/HeaderBar.tsx
#	web/src/contexts/AuthContext.tsx
#	web/src/i18n/translations.ts
#	web/src/lib/api.ts
#	web/src/lib/config.ts
#	web/src/types.ts
This commit is contained in:
Icy
2025-11-12 23:20:25 +08:00
143 changed files with 32902 additions and 3582 deletions

72
api/crypto_handler.go Normal file
View File

@@ -0,0 +1,72 @@
package api
import (
"log"
"net/http"
"nofx/crypto"
"github.com/gin-gonic/gin"
)
// CryptoHandler 加密 API 處理器
type CryptoHandler struct {
cryptoService *crypto.CryptoService
}
// NewCryptoHandler 創建加密處理器
func NewCryptoHandler(cryptoService *crypto.CryptoService) *CryptoHandler {
return &CryptoHandler{
cryptoService: cryptoService,
}
}
// ==================== 公鑰端點 ====================
// HandleGetPublicKey 獲取伺服器公鑰
func (h *CryptoHandler) HandleGetPublicKey(c *gin.Context) {
publicKey := h.cryptoService.GetPublicKeyPEM()
c.JSON(http.StatusOK, map[string]string{
"public_key": publicKey,
"algorithm": "RSA-OAEP-2048",
})
}
// ==================== 加密數據解密端點 ====================
// HandleDecryptSensitiveData 解密客戶端傳送的加密数据
func (h *CryptoHandler) HandleDecryptSensitiveData(c *gin.Context) {
var payload crypto.EncryptedPayload
if err := c.ShouldBindJSON(&payload); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
return
}
// 解密
decrypted, err := h.cryptoService.DecryptSensitiveData(&payload)
if err != nil {
log.Printf("❌ 解密失敗: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Decryption failed"})
return
}
c.JSON(http.StatusOK, map[string]string{
"plaintext": decrypted,
})
}
// ==================== 審計日誌查詢端點 ====================
// 删除审计日志相关功能,在当前简化的实现中不需要
// ==================== 工具函數 ====================
// isValidPrivateKey 驗證私鑰格式
func isValidPrivateKey(key string) bool {
// EVM 私鑰: 64 位十六進制 (可選 0x 前綴)
if len(key) == 64 || (len(key) == 66 && key[:2] == "0x") {
return true
}
// TODO: 添加其他鏈的驗證
return false
}

252
api/register_otp_test.go Normal file
View File

@@ -0,0 +1,252 @@
package api
import (
"testing"
)
// MockUser 模擬用戶結構
type MockUser struct {
ID int
Email string
OTPSecret string
OTPVerified bool
}
// TestOTPRefetchLogic 測試 OTP 重新獲取邏輯
func TestOTPRefetchLogic(t *testing.T) {
tests := []struct {
name string
existingUser *MockUser
userExists bool
expectedAction string // "allow_refetch", "reject_duplicate", "create_new"
expectedMessage string
}{
{
name: "新用戶註冊_郵箱不存在",
existingUser: nil,
userExists: false,
expectedAction: "create_new",
expectedMessage: "創建新用戶",
},
{
name: "未完成OTP驗證_允許重新獲取",
existingUser: &MockUser{
ID: 1,
Email: "test@example.com",
OTPSecret: "SECRET123",
OTPVerified: false,
},
userExists: true,
expectedAction: "allow_refetch",
expectedMessage: "检测到未完成的注册请继续完成OTP设置",
},
{
name: "已完成OTP驗證_拒絕重複註冊",
existingUser: &MockUser{
ID: 2,
Email: "verified@example.com",
OTPSecret: "SECRET456",
OTPVerified: true,
},
userExists: true,
expectedAction: "reject_duplicate",
expectedMessage: "邮箱已被注册",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 模擬邏輯處理流程
var actualAction string
var actualMessage string
if !tt.userExists {
// 用戶不存在,創建新用戶
actualAction = "create_new"
actualMessage = "創建新用戶"
} else {
// 用戶已存在,檢查 OTP 驗證狀態
if !tt.existingUser.OTPVerified {
// 未完成 OTP 驗證,允許重新獲取
actualAction = "allow_refetch"
actualMessage = "检测到未完成的注册请继续完成OTP设置"
} else {
// 已完成驗證,拒絕重複註冊
actualAction = "reject_duplicate"
actualMessage = "邮箱已被注册"
}
}
// 驗證結果
if actualAction != tt.expectedAction {
t.Errorf("Action 不符: got %s, want %s", actualAction, tt.expectedAction)
}
if actualMessage != tt.expectedMessage {
t.Errorf("Message 不符: got %s, want %s", actualMessage, tt.expectedMessage)
}
})
}
}
// TestOTPVerificationStates 測試 OTP 驗證狀態判斷
func TestOTPVerificationStates(t *testing.T) {
tests := []struct {
name string
otpVerified bool
shouldAllowRefetch bool
}{
{
name: "OTP已驗證_不允許重新獲取",
otpVerified: true,
shouldAllowRefetch: false,
},
{
name: "OTP未驗證_允許重新獲取",
otpVerified: false,
shouldAllowRefetch: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 模擬驗證邏輯
allowRefetch := !tt.otpVerified
if allowRefetch != tt.shouldAllowRefetch {
t.Errorf("Refetch logic error: OTPVerified=%v, allowRefetch=%v, expected=%v",
tt.otpVerified, allowRefetch, tt.shouldAllowRefetch)
}
})
}
}
// TestRegistrationFlow 測試完整註冊流程的邏輯分支
func TestRegistrationFlow(t *testing.T) {
tests := []struct {
name string
scenario string
userExists bool
otpVerified bool
expectHTTPCode int // 模擬的 HTTP 狀態碼
expectResponse string
}{
{
name: "場景1_新用戶首次註冊",
scenario: "新用戶首次訪問註冊接口",
userExists: false,
otpVerified: false,
expectHTTPCode: 200,
expectResponse: "創建用戶並返回 OTP 設置信息",
},
{
name: "場景2_用戶中斷註冊後重新訪問",
scenario: "用戶之前註冊但未完成 OTP 設置,現在重新訪問",
userExists: true,
otpVerified: false,
expectHTTPCode: 200,
expectResponse: "返回現有用戶的 OTP 信息,允許繼續完成",
},
{
name: "場景3_已註冊用戶嘗試重複註冊",
scenario: "用戶已完成註冊,嘗試用同一郵箱再次註冊",
userExists: true,
otpVerified: true,
expectHTTPCode: 409, // Conflict
expectResponse: "邮箱已被注册",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 模擬註冊流程邏輯
var actualHTTPCode int
var actualResponse string
if !tt.userExists {
// 新用戶,創建並返回 OTP 信息
actualHTTPCode = 200
actualResponse = "創建用戶並返回 OTP 設置信息"
} else {
// 用戶已存在
if !tt.otpVerified {
// 未完成 OTP 驗證,允許重新獲取
actualHTTPCode = 200
actualResponse = "返回現有用戶的 OTP 信息,允許繼續完成"
} else {
// 已完成驗證,拒絕重複註冊
actualHTTPCode = 409
actualResponse = "邮箱已被注册"
}
}
// 驗證
if actualHTTPCode != tt.expectHTTPCode {
t.Errorf("HTTP code 不符: got %d, want %d (scenario: %s)",
actualHTTPCode, tt.expectHTTPCode, tt.scenario)
}
if actualResponse != tt.expectResponse {
t.Errorf("Response 不符: got %s, want %s (scenario: %s)",
actualResponse, tt.expectResponse, tt.scenario)
}
t.Logf("✓ %s: HTTP %d, %s", tt.scenario, actualHTTPCode, actualResponse)
})
}
}
// TestEdgeCases 測試邊界情況
func TestEdgeCases(t *testing.T) {
tests := []struct {
name string
user *MockUser
expectAllow bool
description string
}{
{
name: "用戶ID為0_視為新用戶",
user: &MockUser{
ID: 0,
Email: "new@example.com",
OTPVerified: false,
},
expectAllow: true,
description: "ID為0通常表示用戶還未創建",
},
{
name: "OTPSecret為空_仍可重新獲取",
user: &MockUser{
ID: 1,
Email: "test@example.com",
OTPSecret: "",
OTPVerified: false,
},
expectAllow: true,
description: "即使 OTPSecret 為空,只要未驗證就允許重新獲取",
},
{
name: "OTPSecret存在但已驗證_不允許",
user: &MockUser{
ID: 2,
Email: "verified@example.com",
OTPSecret: "SECRET789",
OTPVerified: true,
},
expectAllow: false,
description: "OTP 已驗證的用戶不能重新獲取",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 核心邏輯:只要 OTPVerified 為 false就允許重新獲取
allowRefetch := !tt.user.OTPVerified
if allowRefetch != tt.expectAllow {
t.Errorf("Edge case failed: %s\nUser: ID=%d, OTPVerified=%v\nExpected allow=%v, got=%v",
tt.description, tt.user.ID, tt.user.OTPVerified, tt.expectAllow, allowRefetch)
}
t.Logf("✓ %s", tt.description)
})
}
}

View File

@@ -1,6 +1,7 @@
package api
import (
"context"
"encoding/json"
"fmt"
"log"
@@ -8,13 +9,14 @@ import (
"net/http"
"nofx/auth"
"nofx/config"
"nofx/crypto"
"nofx/decision"
"nofx/hook"
"nofx/manager"
"nofx/trader"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
@@ -22,13 +24,15 @@ import (
// Server HTTP API服务器
type Server struct {
router *gin.Engine
httpServer *http.Server
traderManager *manager.TraderManager
database *config.Database
cryptoHandler *CryptoHandler
port int
}
// NewServer 创建API服务器
func NewServer(traderManager *manager.TraderManager, database *config.Database, port int) *Server {
func NewServer(traderManager *manager.TraderManager, database *config.Database, cryptoService *crypto.CryptoService, port int) *Server {
// 设置为Release模式减少日志输出
gin.SetMode(gin.ReleaseMode)
@@ -37,10 +41,14 @@ func NewServer(traderManager *manager.TraderManager, database *config.Database,
// 启用CORS
router.Use(corsMiddleware())
// 创建加密处理器
cryptoHandler := NewCryptoHandler(cryptoService)
s := &Server{
router: router,
traderManager: traderManager,
database: database,
cryptoHandler: cryptoHandler,
port: port,
}
@@ -74,19 +82,19 @@ func (s *Server) setupRoutes() {
// 健康检查
api.Any("/health", s.handleHealth)
// 认证相关路由(无需认证
api.POST("/register", s.handleRegister)
api.POST("/login", s.handleLogin)
api.POST("/verify-otp", s.handleVerifyOTP)
api.POST("/complete-registration", s.handleCompleteRegistration)
// 管理员登录(管理员模式下使用,公共
// 系统支持的模型和交易所(无需认证)
api.GET("/supported-models", s.handleGetSupportedModels)
api.GET("/supported-exchanges", s.handleGetSupportedExchanges)
// 系统配置(无需认证)
// 系统配置(无需认证,用于前端判断是否管理员模式/注册是否开启
api.GET("/config", s.handleGetSystemConfig)
// 加密相关接口(无需认证)
api.GET("/crypto/public-key", s.cryptoHandler.HandleGetPublicKey)
api.POST("/crypto/decrypt", s.cryptoHandler.HandleDecryptSensitiveData)
// 系统提示词模板管理(无需认证)
api.GET("/prompt-templates", s.handleGetPromptTemplates)
api.GET("/prompt-templates/:name", s.handleGetPromptTemplate)
@@ -99,9 +107,18 @@ func (s *Server) setupRoutes() {
api.POST("/equity-history-batch", s.handleEquityHistoryBatch)
api.GET("/traders/:id/public-config", s.handleGetPublicTraderConfig)
// 认证相关路由(无需认证)
api.POST("/register", s.handleRegister)
api.POST("/login", s.handleLogin)
api.POST("/verify-otp", s.handleVerifyOTP)
api.POST("/complete-registration", s.handleCompleteRegistration)
// 需要认证的路由
protected := api.Group("/", s.authMiddleware())
{
// 注销(加入黑名单)
protected.POST("/logout", s.handleLogout)
// 服务器IP查询需要认证用于白名单配置
protected.GET("/server-ip", s.handleGetServerIP)
@@ -180,7 +197,6 @@ func (s *Server) handleGetSystemConfig(c *gin.Context) {
betaMode := betaModeStr == "true"
c.JSON(http.StatusOK, gin.H{
"admin_mode": auth.IsAdminMode(),
"beta_mode": betaMode,
"default_coins": defaultCoins,
"btc_eth_leverage": btcEthLeverage,
@@ -190,6 +206,17 @@ func (s *Server) handleGetSystemConfig(c *gin.Context) {
// handleGetServerIP 获取服务器IP地址用于白名单配置
func (s *Server) handleGetServerIP(c *gin.Context) {
// 首先尝试从Hook获取用户专用IP
userIP := hook.HookExec[hook.IpResult](hook.GETIP, c.GetString("user_id"))
if userIP != nil && userIP.Error() == nil {
c.JSON(http.StatusOK, gin.H{
"public_ip": userIP.GetResult(),
"message": "请将此IP地址添加到白名单中",
})
return
}
// 尝试通过第三方API获取公网IP
publicIP := getPublicIPFromAPI()
@@ -372,6 +399,16 @@ type ModelConfig struct {
CustomAPIURL string `json:"customApiUrl,omitempty"`
}
// SafeModelConfig 安全的模型配置结构(不包含敏感信息)
type SafeModelConfig struct {
ID string `json:"id"`
Name string `json:"name"`
Provider string `json:"provider"`
Enabled bool `json:"enabled"`
CustomAPIURL string `json:"customApiUrl"` // 自定义API URL通常不敏感
CustomModelName string `json:"customModelName"` // 自定义模型名(不敏感)
}
type ExchangeConfig struct {
ID string `json:"id"`
Name string `json:"name"`
@@ -382,6 +419,18 @@ type ExchangeConfig struct {
Testnet bool `json:"testnet,omitempty"`
}
// SafeExchangeConfig 安全的交易所配置结构(不包含敏感信息)
type SafeExchangeConfig struct {
ID string `json:"id"`
Name string `json:"name"`
Type string `json:"type"` // "cex" or "dex"
Enabled bool `json:"enabled"`
Testnet bool `json:"testnet,omitempty"`
HyperliquidWalletAddr string `json:"hyperliquidWalletAddr"` // Hyperliquid钱包地址不敏感
AsterUser string `json:"asterUser"` // Aster用户名不敏感
AsterSigner string `json:"asterSigner"` // Aster签名者不敏感
}
type UpdateModelConfigRequest struct {
Models map[string]struct {
Enabled bool `json:"enabled"`
@@ -507,7 +556,7 @@ func (s *Server) handleCreateTrader(c *gin.Context) {
switch req.ExchangeID {
case "binance":
tempTrader = trader.NewFuturesTrader(exchangeCfg.APIKey, exchangeCfg.SecretKey)
tempTrader = trader.NewFuturesTrader(exchangeCfg.APIKey, exchangeCfg.SecretKey, userID)
case "hyperliquid":
tempTrader, createErr = trader.NewHyperliquidTrader(
exchangeCfg.APIKey, // private key
@@ -576,9 +625,9 @@ func (s *Server) handleCreateTrader(c *gin.Context) {
}
// 立即将新交易员加载到TraderManager中
err = s.traderManager.LoadUserTraders(s.database, userID)
err = s.traderManager.LoadTraderByID(s.database, userID, traderID)
if err != nil {
log.Printf("⚠️ 加载用户交易员到内存失败: %v", err)
log.Printf("⚠️ 加载交易员到内存失败: %v", err)
// 这里不返回错误,因为交易员已经成功创建到数据库
}
@@ -594,17 +643,18 @@ func (s *Server) handleCreateTrader(c *gin.Context) {
// UpdateTraderRequest 更新交易员请求
type UpdateTraderRequest struct {
Name string `json:"name" binding:"required"`
AIModelID string `json:"ai_model_id" binding:"required"`
ExchangeID string `json:"exchange_id" binding:"required"`
InitialBalance float64 `json:"initial_balance"`
ScanIntervalMinutes int `json:"scan_interval_minutes"`
BTCETHLeverage int `json:"btc_eth_leverage"`
AltcoinLeverage int `json:"altcoin_leverage"`
TradingSymbols string `json:"trading_symbols"`
CustomPrompt string `json:"custom_prompt"`
OverrideBasePrompt bool `json:"override_base_prompt"`
IsCrossMargin *bool `json:"is_cross_margin"`
Name string `json:"name" binding:"required"`
AIModelID string `json:"ai_model_id" binding:"required"`
ExchangeID string `json:"exchange_id" binding:"required"`
InitialBalance float64 `json:"initial_balance"`
ScanIntervalMinutes int `json:"scan_interval_minutes"`
BTCETHLeverage int `json:"btc_eth_leverage"`
AltcoinLeverage int `json:"altcoin_leverage"`
TradingSymbols string `json:"trading_symbols"`
CustomPrompt string `json:"custom_prompt"`
OverrideBasePrompt bool `json:"override_base_prompt"`
SystemPromptTemplate string `json:"system_prompt_template"`
IsCrossMargin *bool `json:"is_cross_margin"`
}
// handleUpdateTrader 更新交易员配置
@@ -662,6 +712,12 @@ func (s *Server) handleUpdateTrader(c *gin.Context) {
scanIntervalMinutes = 3
}
// 设置提示词模板,允许更新
systemPromptTemplate := req.SystemPromptTemplate
if systemPromptTemplate == "" {
systemPromptTemplate = existingTrader.SystemPromptTemplate // 如果请求中没有提供,保持原值
}
// 更新交易员配置
trader := &config.TraderRecord{
ID: traderID,
@@ -675,7 +731,7 @@ func (s *Server) handleUpdateTrader(c *gin.Context) {
TradingSymbols: req.TradingSymbols,
CustomPrompt: req.CustomPrompt,
OverrideBasePrompt: req.OverrideBasePrompt,
SystemPromptTemplate: existingTrader.SystemPromptTemplate, // 保持原值
SystemPromptTemplate: systemPromptTemplate,
IsCrossMargin: isCrossMargin,
ScanIntervalMinutes: scanIntervalMinutes,
IsRunning: existingTrader.IsRunning, // 保持原值
@@ -689,9 +745,9 @@ func (s *Server) handleUpdateTrader(c *gin.Context) {
}
// 重新加载交易员到内存
err = s.traderManager.LoadUserTraders(s.database, userID)
err = s.traderManager.LoadTraderByID(s.database, userID, traderID)
if err != nil {
log.Printf("⚠️ 重新加载用户交易员到内存失败: %v", err)
log.Printf("⚠️ 重新加载交易员到内存失败: %v", err)
}
log.Printf("✓ 更新交易员成功: %s (模型: %s, 交易所: %s)", req.Name, req.AIModelID, req.ExchangeID)
@@ -735,12 +791,15 @@ func (s *Server) handleStartTrader(c *gin.Context) {
traderID := c.Param("id")
// 校验交易员是否属于当前用户
_, _, _, err := s.database.GetTraderConfig(userID, traderID)
traderRecord, _, _, err := s.database.GetTraderConfig(userID, traderID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "交易员不存在或无访问权限"})
return
}
// 获取模板名称
templateName := traderRecord.SystemPromptTemplate
trader, err := s.traderManager.GetTrader(traderID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "交易员不存在"})
@@ -754,6 +813,9 @@ func (s *Server) handleStartTrader(c *gin.Context) {
return
}
// 重新加载系统提示词模板(确保使用最新的硬盘文件)
s.reloadPromptTemplatesWithLog(templateName)
// 启动交易员
go func() {
log.Printf("▶️ 启动交易员 %s (%s)", traderID, trader.GetName())
@@ -868,7 +930,7 @@ func (s *Server) handleSyncBalance(c *gin.Context) {
switch traderConfig.ExchangeID {
case "binance":
tempTrader = trader.NewFuturesTrader(exchangeCfg.APIKey, exchangeCfg.SecretKey)
tempTrader = trader.NewFuturesTrader(exchangeCfg.APIKey, exchangeCfg.SecretKey, userID)
case "hyperliquid":
tempTrader, createErr = trader.NewHyperliquidTrader(
exchangeCfg.APIKey,
@@ -934,9 +996,9 @@ func (s *Server) handleSyncBalance(c *gin.Context) {
}
// 重新加载交易员到内存
err = s.traderManager.LoadUserTraders(s.database, userID)
err = s.traderManager.LoadTraderByID(s.database, userID, traderID)
if err != nil {
log.Printf("⚠️ 重新加载用户交易员到内存失败: %v", err)
log.Printf("⚠️ 重新加载交易员到内存失败: %v", err)
}
log.Printf("✅ 已同步余额: %.2f → %.2f USDT (%s %.2f%%)", oldBalance, actualBalance, changeType, changePercent)
@@ -962,18 +1024,69 @@ func (s *Server) handleGetModelConfigs(c *gin.Context) {
}
log.Printf("✅ 找到 %d 个AI模型配置", len(models))
c.JSON(http.StatusOK, models)
// 转换为安全的响应结构,移除敏感信息
safeModels := make([]SafeModelConfig, len(models))
for i, model := range models {
safeModels[i] = SafeModelConfig{
ID: model.ID,
Name: model.Name,
Provider: model.Provider,
Enabled: model.Enabled,
CustomAPIURL: model.CustomAPIURL,
CustomModelName: model.CustomModelName,
}
}
c.JSON(http.StatusOK, safeModels)
}
// handleUpdateModelConfigs 更新AI模型配置
// handleUpdateModelConfigs 更新AI模型配置(仅支持加密数据)
func (s *Server) handleUpdateModelConfigs(c *gin.Context) {
userID := c.GetString("user_id")
var req UpdateModelConfigRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
// 读取原始请求体
bodyBytes, err := c.GetRawData()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"})
return
}
// 解析加密的 payload
var encryptedPayload crypto.EncryptedPayload
if err := json.Unmarshal(bodyBytes, &encryptedPayload); err != nil {
log.Printf("❌ 解析加密载荷失败: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "请求格式错误,必须使用加密传输"})
return
}
// 验证是否为加密数据
if encryptedPayload.WrappedKey == "" {
log.Printf("❌ 检测到非加密请求 (UserID: %s)", userID)
c.JSON(http.StatusBadRequest, gin.H{
"error": "此接口仅支持加密传输,请使用加密客户端",
"code": "ENCRYPTION_REQUIRED",
"message": "Encrypted transmission is required for security reasons",
})
return
}
// 解密数据
decrypted, err := s.cryptoHandler.cryptoService.DecryptSensitiveData(&encryptedPayload)
if err != nil {
log.Printf("❌ 解密模型配置失败 (UserID: %s): %v", userID, err)
c.JSON(http.StatusBadRequest, gin.H{"error": "解密数据失败"})
return
}
// 解析解密后的数据
var req UpdateModelConfigRequest
if err := json.Unmarshal([]byte(decrypted), &req); err != nil {
log.Printf("❌ 解析解密数据失败: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "解析解密数据失败"})
return
}
log.Printf("🔓 已解密模型配置数据 (UserID: %s)", userID)
// 更新每个模型的配置
for modelID, modelData := range req.Models {
err := s.database.UpdateAIModel(userID, modelID, modelData.Enabled, modelData.APIKey, modelData.CustomAPIURL, modelData.CustomModelName)
@@ -984,13 +1097,13 @@ func (s *Server) handleUpdateModelConfigs(c *gin.Context) {
}
// 重新加载该用户的所有交易员,使新配置立即生效
err := s.traderManager.LoadUserTraders(s.database, userID)
err = s.traderManager.LoadUserTraders(s.database, userID)
if err != nil {
log.Printf("⚠️ 重新加载用户交易员到内存失败: %v", err)
// 这里不返回错误,因为模型配置已经成功更新到数据库
}
log.Printf("✓ AI模型配置已更新: %+v", req.Models)
log.Printf("✓ AI模型配置已更新: %+v", SanitizeModelConfigForLog(req.Models))
c.JSON(http.StatusOK, gin.H{"message": "模型配置已更新"})
}
@@ -1006,18 +1119,71 @@ func (s *Server) handleGetExchangeConfigs(c *gin.Context) {
}
log.Printf("✅ 找到 %d 个交易所配置", len(exchanges))
c.JSON(http.StatusOK, exchanges)
// 转换为安全的响应结构,移除敏感信息
safeExchanges := make([]SafeExchangeConfig, len(exchanges))
for i, exchange := range exchanges {
safeExchanges[i] = SafeExchangeConfig{
ID: exchange.ID,
Name: exchange.Name,
Type: exchange.Type,
Enabled: exchange.Enabled,
Testnet: exchange.Testnet,
HyperliquidWalletAddr: exchange.HyperliquidWalletAddr,
AsterUser: exchange.AsterUser,
AsterSigner: exchange.AsterSigner,
}
}
c.JSON(http.StatusOK, safeExchanges)
}
// handleUpdateExchangeConfigs 更新交易所配置
// handleUpdateExchangeConfigs 更新交易所配置(仅支持加密数据)
func (s *Server) handleUpdateExchangeConfigs(c *gin.Context) {
userID := c.GetString("user_id")
var req UpdateExchangeConfigRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
// 读取原始请求体
bodyBytes, err := c.GetRawData()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"})
return
}
// 解析加密的 payload
var encryptedPayload crypto.EncryptedPayload
if err := json.Unmarshal(bodyBytes, &encryptedPayload); err != nil {
log.Printf("❌ 解析加密载荷失败: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "请求格式错误,必须使用加密传输"})
return
}
// 验证是否为加密数据
if encryptedPayload.WrappedKey == "" {
log.Printf("❌ 检测到非加密请求 (UserID: %s)", userID)
c.JSON(http.StatusBadRequest, gin.H{
"error": "此接口仅支持加密传输,请使用加密客户端",
"code": "ENCRYPTION_REQUIRED",
"message": "Encrypted transmission is required for security reasons",
})
return
}
// 解密数据
decrypted, err := s.cryptoHandler.cryptoService.DecryptSensitiveData(&encryptedPayload)
if err != nil {
log.Printf("❌ 解密交易所配置失败 (UserID: %s): %v", userID, err)
c.JSON(http.StatusBadRequest, gin.H{"error": "解密数据失败"})
return
}
// 解析解密后的数据
var req UpdateExchangeConfigRequest
if err := json.Unmarshal([]byte(decrypted), &req); err != nil {
log.Printf("❌ 解析解密数据失败: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "解析解密数据失败"})
return
}
log.Printf("🔓 已解密交易所配置数据 (UserID: %s)", userID)
// 更新每个交易所的配置
for exchangeID, exchangeData := range req.Exchanges {
err := s.database.UpdateExchange(userID, exchangeID, exchangeData.Enabled, exchangeData.APIKey, exchangeData.SecretKey, exchangeData.Testnet, exchangeData.HyperliquidWalletAddr, exchangeData.AsterUser, exchangeData.AsterSigner, exchangeData.AsterPrivateKey)
@@ -1028,13 +1194,13 @@ func (s *Server) handleUpdateExchangeConfigs(c *gin.Context) {
}
// 重新加载该用户的所有交易员,使新配置立即生效
err := s.traderManager.LoadUserTraders(s.database, userID)
err = s.traderManager.LoadUserTraders(s.database, userID)
if err != nil {
log.Printf("⚠️ 重新加载用户交易员到内存失败: %v", err)
// 这里不返回错误,因为交易所配置已经成功更新到数据库
}
log.Printf("✓ 交易所配置已更新: %+v", req.Exchanges)
log.Printf("✓ 交易所配置已更新: %+v", SanitizeExchangeConfigForLog(req.Exchanges))
c.JSON(http.StatusOK, gin.H{"message": "交易所配置已更新"})
}
@@ -1144,21 +1310,22 @@ func (s *Server) handleGetTraderConfig(c *gin.Context) {
aiModelID := traderConfig.AIModelID
result := map[string]interface{}{
"trader_id": traderConfig.ID,
"trader_name": traderConfig.Name,
"ai_model": aiModelID,
"exchange_id": traderConfig.ExchangeID,
"initial_balance": traderConfig.InitialBalance,
"scan_interval_minutes": traderConfig.ScanIntervalMinutes,
"btc_eth_leverage": traderConfig.BTCETHLeverage,
"altcoin_leverage": traderConfig.AltcoinLeverage,
"trading_symbols": traderConfig.TradingSymbols,
"custom_prompt": traderConfig.CustomPrompt,
"override_base_prompt": traderConfig.OverrideBasePrompt,
"is_cross_margin": traderConfig.IsCrossMargin,
"use_coin_pool": traderConfig.UseCoinPool,
"use_oi_top": traderConfig.UseOITop,
"is_running": isRunning,
"trader_id": traderConfig.ID,
"trader_name": traderConfig.Name,
"ai_model": aiModelID,
"exchange_id": traderConfig.ExchangeID,
"initial_balance": traderConfig.InitialBalance,
"scan_interval_minutes": traderConfig.ScanIntervalMinutes,
"btc_eth_leverage": traderConfig.BTCETHLeverage,
"altcoin_leverage": traderConfig.AltcoinLeverage,
"trading_symbols": traderConfig.TradingSymbols,
"custom_prompt": traderConfig.CustomPrompt,
"override_base_prompt": traderConfig.OverrideBasePrompt,
"system_prompt_template": traderConfig.SystemPromptTemplate,
"is_cross_margin": traderConfig.IsCrossMargin,
"use_coin_pool": traderConfig.UseCoinPool,
"use_oi_top": traderConfig.UseOITop,
"is_running": isRunning,
}
c.JSON(http.StatusOK, result)
@@ -1280,7 +1447,15 @@ func (s *Server) handleLatestDecisions(c *gin.Context) {
return
}
records, err := trader.GetDecisionLogger().GetLatestRecords(5)
// 从 query 参数读取 limit默认 5最大 50
limit := 5
if limitStr := c.Query("limit"); limitStr != "" {
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 && l <= 50 {
limit = l
}
}
records, err := trader.GetDecisionLogger().GetLatestRecords(limit)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": fmt.Sprintf("获取决策日志失败: %v", err),
@@ -1459,14 +1634,6 @@ func (s *Server) handlePerformance(c *gin.Context) {
// authMiddleware JWT认证中间件
func (s *Server) authMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// 如果是管理员模式直接使用admin用户
if auth.IsAdminMode() {
c.Set("user_id", "admin")
c.Set("email", "admin@localhost")
c.Next()
return
}
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "缺少Authorization头"})
@@ -1482,8 +1649,17 @@ func (s *Server) authMiddleware() gin.HandlerFunc {
return
}
tokenString := tokenParts[1]
// 黑名单检查
if auth.IsTokenBlacklisted(tokenString) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "token已失效请重新登录"})
c.Abort()
return
}
// 验证JWT token
claims, err := auth.ValidateJWT(tokenParts[1])
claims, err := auth.ValidateJWT(tokenString)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "无效的token: " + err.Error()})
c.Abort()
@@ -1497,8 +1673,37 @@ func (s *Server) authMiddleware() gin.HandlerFunc {
}
}
// handleLogout 将当前token加入黑名单
func (s *Server) handleLogout(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "缺少Authorization头"})
return
}
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "无效的Authorization格式"})
return
}
tokenString := parts[1]
claims, err := auth.ValidateJWT(tokenString)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "无效的token"})
return
}
var exp time.Time
if claims.ExpiresAt != nil {
exp = claims.ExpiresAt.Time
} else {
exp = time.Now().Add(24 * time.Hour)
}
auth.BlacklistToken(tokenString, exp)
c.JSON(http.StatusOK, gin.H{"message": "已登出"})
}
// handleRegister 处理用户注册请求
func (s *Server) handleRegister(c *gin.Context) {
var req struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
@@ -1532,8 +1737,21 @@ func (s *Server) handleRegister(c *gin.Context) {
}
// 检查邮箱是否已存在
_, err := s.database.GetUserByEmail(req.Email)
existingUser, err := s.database.GetUserByEmail(req.Email)
if err == nil {
// 如果用户未完成OTP验证允许重新获取OTP支持中断后恢复注册
if !existingUser.OTPVerified {
qrCodeURL := auth.GetOTPQRCodeURL(existingUser.OTPSecret, req.Email)
c.JSON(http.StatusOK, gin.H{
"user_id": existingUser.ID,
"email": req.Email,
"otp_secret": existingUser.OTPSecret,
"qr_code_url": qrCodeURL,
"message": "检测到未完成的注册请继续完成OTP设置",
})
return
}
// 用户已完成验证,拒绝重复注册
c.JSON(http.StatusConflict, gin.H{"error": "邮箱已被注册"})
return
}
@@ -1728,6 +1946,50 @@ func (s *Server) handleVerifyOTP(c *gin.Context) {
})
}
// handleResetPassword 重置密码(通过邮箱 + OTP 验证)
func (s *Server) handleResetPassword(c *gin.Context) {
var req struct {
Email string `json:"email" binding:"required,email"`
NewPassword string `json:"new_password" binding:"required,min=6"`
OTPCode string `json:"otp_code" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 查询用户
user, err := s.database.GetUserByEmail(req.Email)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "邮箱不存在"})
return
}
// 验证 OTP
if !auth.VerifyOTP(user.OTPSecret, req.OTPCode) {
c.JSON(http.StatusBadRequest, gin.H{"error": "Google Authenticator 验证码错误"})
return
}
// 生成新密码哈希
newPasswordHash, err := auth.HashPassword(req.NewPassword)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "密码处理失败"})
return
}
// 更新密码
err = s.database.UpdateUserPassword(user.ID, newPasswordHash)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "密码更新失败"})
return
}
log.Printf("✓ 用户 %s 密码已重置", user.Email)
c.JSON(http.StatusOK, gin.H{"message": "密码重置成功,请使用新密码登录"})
}
// initUserDefaultConfigs 为新用户初始化默认的模型和交易所配置
func (s *Server) initUserDefaultConfigs(userID string) error {
// 注释掉自动创建默认配置,让用户手动添加
@@ -1759,7 +2021,22 @@ func (s *Server) handleGetSupportedExchanges(c *gin.Context) {
return
}
c.JSON(http.StatusOK, exchanges)
// 转换为安全的响应结构,移除敏感信息
safeExchanges := make([]SafeExchangeConfig, len(exchanges))
for i, exchange := range exchanges {
safeExchanges[i] = SafeExchangeConfig{
ID: exchange.ID,
Name: exchange.Name,
Type: exchange.Type,
Enabled: exchange.Enabled,
Testnet: exchange.Testnet,
HyperliquidWalletAddr: "", // 默认配置不包含钱包地址
AsterUser: "", // 默认配置不包含用户信息
AsterSigner: "",
}
}
c.JSON(http.StatusOK, safeExchanges)
}
// Start 启动服务器
@@ -1791,7 +2068,26 @@ func (s *Server) Start() error {
log.Printf(" • GET /api/performance?trader_id=xxx - 指定trader的AI学习表现分析")
log.Println()
return s.router.Run(addr)
// 创建 http.Server 以支持 graceful shutdown
s.httpServer = &http.Server{
Addr: addr,
Handler: s.router,
}
return s.httpServer.ListenAndServe()
}
// Shutdown 优雅关闭 API 服务器
func (s *Server) Shutdown() error {
if s.httpServer == nil {
return nil
}
// 设置 5 秒超时
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
return s.httpServer.Shutdown(ctx)
}
// handleGetPromptTemplates 获取所有系统提示词模板列表
@@ -2035,3 +2331,17 @@ func (s *Server) handleGetPublicTraderConfig(c *gin.Context) {
c.JSON(http.StatusOK, result)
}
// reloadPromptTemplatesWithLog 重新加载提示词模板并记录日志
func (s *Server) reloadPromptTemplatesWithLog(templateName string) {
if err := decision.ReloadPromptTemplates(); err != nil {
log.Printf("⚠️ 重新加载提示词模板失败: %v", err)
return
}
if templateName == "" {
log.Printf("✓ 已重新加载系统提示词模板 [当前使用: default (未指定,使用默认)]")
} else {
log.Printf("✓ 已重新加载系统提示词模板 [当前使用: %s]", templateName)
}
}

227
api/server_test.go Normal file
View File

@@ -0,0 +1,227 @@
package api
import (
"encoding/json"
"testing"
"nofx/config"
)
// TestUpdateTraderRequest_SystemPromptTemplate 测试更新交易员时 SystemPromptTemplate 字段是否存在
func TestUpdateTraderRequest_SystemPromptTemplate(t *testing.T) {
tests := []struct {
name string
requestJSON string
expectedPromptTemplate string
}{
{
name: "更新时应该能接收 system_prompt_template=nof1",
requestJSON: `{
"name": "Test Trader",
"ai_model_id": "gpt-4",
"exchange_id": "binance",
"initial_balance": 1000,
"scan_interval_minutes": 5,
"btc_eth_leverage": 5,
"altcoin_leverage": 3,
"trading_symbols": "BTC,ETH",
"custom_prompt": "test",
"override_base_prompt": false,
"is_cross_margin": true,
"system_prompt_template": "nof1"
}`,
expectedPromptTemplate: "nof1",
},
{
name: "更新时应该能接收 system_prompt_template=default",
requestJSON: `{
"name": "Test Trader",
"ai_model_id": "gpt-4",
"exchange_id": "binance",
"initial_balance": 1000,
"scan_interval_minutes": 5,
"btc_eth_leverage": 5,
"altcoin_leverage": 3,
"trading_symbols": "BTC,ETH",
"custom_prompt": "test",
"override_base_prompt": false,
"is_cross_margin": true,
"system_prompt_template": "default"
}`,
expectedPromptTemplate: "default",
},
{
name: "更新时应该能接收 system_prompt_template=custom",
requestJSON: `{
"name": "Test Trader",
"ai_model_id": "gpt-4",
"exchange_id": "binance",
"initial_balance": 1000,
"scan_interval_minutes": 5,
"btc_eth_leverage": 5,
"altcoin_leverage": 3,
"trading_symbols": "BTC,ETH",
"custom_prompt": "test",
"override_base_prompt": false,
"is_cross_margin": true,
"system_prompt_template": "custom"
}`,
expectedPromptTemplate: "custom",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 测试 UpdateTraderRequest 结构体是否能正确解析 system_prompt_template 字段
var req UpdateTraderRequest
err := json.Unmarshal([]byte(tt.requestJSON), &req)
if err != nil {
t.Fatalf("Failed to unmarshal JSON: %v", err)
}
// ✅ 验证 SystemPromptTemplate 字段是否被正确读取
if req.SystemPromptTemplate != tt.expectedPromptTemplate {
t.Errorf("Expected SystemPromptTemplate=%q, got %q",
tt.expectedPromptTemplate, req.SystemPromptTemplate)
}
// 验证其他字段也被正确解析
if req.Name != "Test Trader" {
t.Errorf("Name not parsed correctly")
}
if req.AIModelID != "gpt-4" {
t.Errorf("AIModelID not parsed correctly")
}
})
}
}
// TestGetTraderConfigResponse_SystemPromptTemplate 测试获取交易员配置时返回值是否包含 system_prompt_template
func TestGetTraderConfigResponse_SystemPromptTemplate(t *testing.T) {
tests := []struct {
name string
traderConfig *config.TraderRecord
expectedTemplate string
}{
{
name: "获取配置应该返回 system_prompt_template=nof1",
traderConfig: &config.TraderRecord{
ID: "trader-123",
UserID: "user-1",
Name: "Test Trader",
AIModelID: "gpt-4",
ExchangeID: "binance",
InitialBalance: 1000,
ScanIntervalMinutes: 5,
BTCETHLeverage: 5,
AltcoinLeverage: 3,
TradingSymbols: "BTC,ETH",
CustomPrompt: "test",
OverrideBasePrompt: false,
SystemPromptTemplate: "nof1",
IsCrossMargin: true,
IsRunning: false,
},
expectedTemplate: "nof1",
},
{
name: "获取配置应该返回 system_prompt_template=default",
traderConfig: &config.TraderRecord{
ID: "trader-456",
UserID: "user-1",
Name: "Test Trader 2",
AIModelID: "gpt-4",
ExchangeID: "binance",
InitialBalance: 2000,
ScanIntervalMinutes: 10,
BTCETHLeverage: 10,
AltcoinLeverage: 5,
TradingSymbols: "BTC",
CustomPrompt: "",
OverrideBasePrompt: false,
SystemPromptTemplate: "default",
IsCrossMargin: false,
IsRunning: false,
},
expectedTemplate: "default",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 模拟 handleGetTraderConfig 的返回值构造逻辑(修复后的实现)
result := map[string]interface{}{
"trader_id": tt.traderConfig.ID,
"trader_name": tt.traderConfig.Name,
"ai_model": tt.traderConfig.AIModelID,
"exchange_id": tt.traderConfig.ExchangeID,
"initial_balance": tt.traderConfig.InitialBalance,
"scan_interval_minutes": tt.traderConfig.ScanIntervalMinutes,
"btc_eth_leverage": tt.traderConfig.BTCETHLeverage,
"altcoin_leverage": tt.traderConfig.AltcoinLeverage,
"trading_symbols": tt.traderConfig.TradingSymbols,
"custom_prompt": tt.traderConfig.CustomPrompt,
"override_base_prompt": tt.traderConfig.OverrideBasePrompt,
"system_prompt_template": tt.traderConfig.SystemPromptTemplate,
"is_cross_margin": tt.traderConfig.IsCrossMargin,
"is_running": tt.traderConfig.IsRunning,
}
// ✅ 检查响应中是否包含 system_prompt_template
if _, exists := result["system_prompt_template"]; !exists {
t.Errorf("Response is missing 'system_prompt_template' field")
} else {
actualTemplate := result["system_prompt_template"].(string)
if actualTemplate != tt.expectedTemplate {
t.Errorf("Expected system_prompt_template=%q, got %q",
tt.expectedTemplate, actualTemplate)
}
}
// 验证其他字段是否正确
if result["trader_id"] != tt.traderConfig.ID {
t.Errorf("trader_id mismatch")
}
if result["trader_name"] != tt.traderConfig.Name {
t.Errorf("trader_name mismatch")
}
})
}
}
// TestUpdateTraderRequest_CompleteFields 验证 UpdateTraderRequest 结构体定义完整性
func TestUpdateTraderRequest_CompleteFields(t *testing.T) {
jsonData := `{
"name": "Test Trader",
"ai_model_id": "gpt-4",
"exchange_id": "binance",
"initial_balance": 1000,
"scan_interval_minutes": 5,
"btc_eth_leverage": 5,
"altcoin_leverage": 3,
"trading_symbols": "BTC,ETH",
"custom_prompt": "test",
"override_base_prompt": false,
"is_cross_margin": true,
"system_prompt_template": "nof1"
}`
var req UpdateTraderRequest
err := json.Unmarshal([]byte(jsonData), &req)
if err != nil {
t.Fatalf("Failed to unmarshal JSON: %v", err)
}
// 验证基本字段是否正确解析
if req.Name != "Test Trader" {
t.Errorf("Name mismatch: got %q", req.Name)
}
if req.AIModelID != "gpt-4" {
t.Errorf("AIModelID mismatch: got %q", req.AIModelID)
}
// ✅ 验证 SystemPromptTemplate 字段已正确添加到结构体
if req.SystemPromptTemplate != "nof1" {
t.Errorf("SystemPromptTemplate mismatch: expected %q, got %q", "nof1", req.SystemPromptTemplate)
}
}

97
api/utils.go Normal file
View File

@@ -0,0 +1,97 @@
package api
import "strings"
// MaskSensitiveString 脱敏敏感字符串只显示前4位和后4位
// 用于脱敏 API Key、Secret Key、Private Key 等敏感信息
func MaskSensitiveString(s string) string {
if s == "" {
return ""
}
length := len(s)
if length <= 8 {
return "****" // 字符串太短,全部隐藏
}
return s[:4] + "****" + s[length-4:]
}
// SanitizeModelConfigForLog 脱敏模型配置用于日志输出
func SanitizeModelConfigForLog(models map[string]struct {
Enabled bool `json:"enabled"`
APIKey string `json:"api_key"`
CustomAPIURL string `json:"custom_api_url"`
CustomModelName string `json:"custom_model_name"`
}) map[string]interface{} {
safe := make(map[string]interface{})
for modelID, cfg := range models {
safe[modelID] = map[string]interface{}{
"enabled": cfg.Enabled,
"api_key": MaskSensitiveString(cfg.APIKey),
"custom_api_url": cfg.CustomAPIURL,
"custom_model_name": cfg.CustomModelName,
}
}
return safe
}
// SanitizeExchangeConfigForLog 脱敏交易所配置用于日志输出
func SanitizeExchangeConfigForLog(exchanges map[string]struct {
Enabled bool `json:"enabled"`
APIKey string `json:"api_key"`
SecretKey string `json:"secret_key"`
Testnet bool `json:"testnet"`
HyperliquidWalletAddr string `json:"hyperliquid_wallet_addr"`
AsterUser string `json:"aster_user"`
AsterSigner string `json:"aster_signer"`
AsterPrivateKey string `json:"aster_private_key"`
}) map[string]interface{} {
safe := make(map[string]interface{})
for exchangeID, cfg := range exchanges {
safeExchange := map[string]interface{}{
"enabled": cfg.Enabled,
"testnet": cfg.Testnet,
}
// 只在有值时才添加脱敏后的敏感字段
if cfg.APIKey != "" {
safeExchange["api_key"] = MaskSensitiveString(cfg.APIKey)
}
if cfg.SecretKey != "" {
safeExchange["secret_key"] = MaskSensitiveString(cfg.SecretKey)
}
if cfg.AsterPrivateKey != "" {
safeExchange["aster_private_key"] = MaskSensitiveString(cfg.AsterPrivateKey)
}
// 非敏感字段直接添加
if cfg.HyperliquidWalletAddr != "" {
safeExchange["hyperliquid_wallet_addr"] = cfg.HyperliquidWalletAddr
}
if cfg.AsterUser != "" {
safeExchange["aster_user"] = cfg.AsterUser
}
if cfg.AsterSigner != "" {
safeExchange["aster_signer"] = cfg.AsterSigner
}
safe[exchangeID] = safeExchange
}
return safe
}
// MaskEmail 脱敏邮箱地址保留前2位和@后部分
func MaskEmail(email string) string {
if email == "" {
return ""
}
parts := strings.Split(email, "@")
if len(parts) != 2 {
return "****" // 格式不正确
}
username := parts[0]
domain := parts[1]
if len(username) <= 2 {
return "**@" + domain
}
return username[:2] + "****@" + domain
}

193
api/utils_test.go Normal file
View File

@@ -0,0 +1,193 @@
package api
import (
"testing"
)
func TestMaskSensitiveString(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "空字符串",
input: "",
expected: "",
},
{
name: "短字符串小于等于8位",
input: "short",
expected: "****",
},
{
name: "正常API key",
input: "sk-1234567890abcdefghijklmnopqrstuvwxyz",
expected: "sk-1****wxyz",
},
{
name: "正常私钥",
input: "0x1234567890abcdef1234567890abcdef12345678",
expected: "0x12****5678",
},
{
name: "刚好9位",
input: "123456789",
expected: "1234****6789",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := MaskSensitiveString(tt.input)
if result != tt.expected {
t.Errorf("MaskSensitiveString(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
func TestSanitizeModelConfigForLog(t *testing.T) {
models := map[string]struct {
Enabled bool `json:"enabled"`
APIKey string `json:"api_key"`
CustomAPIURL string `json:"custom_api_url"`
CustomModelName string `json:"custom_model_name"`
}{
"deepseek": {
Enabled: true,
APIKey: "sk-1234567890abcdefghijklmnopqrstuvwxyz",
CustomAPIURL: "https://api.deepseek.com",
CustomModelName: "deepseek-chat",
},
}
result := SanitizeModelConfigForLog(models)
deepseekConfig, ok := result["deepseek"].(map[string]interface{})
if !ok {
t.Fatal("deepseek config not found or wrong type")
}
if deepseekConfig["enabled"] != true {
t.Errorf("expected enabled=true, got %v", deepseekConfig["enabled"])
}
maskedKey, ok := deepseekConfig["api_key"].(string)
if !ok {
t.Fatal("api_key not found or wrong type")
}
if maskedKey != "sk-1****wxyz" {
t.Errorf("expected masked api_key='sk-1****wxyz', got %q", maskedKey)
}
if deepseekConfig["custom_api_url"] != "https://api.deepseek.com" {
t.Errorf("custom_api_url should not be masked")
}
}
func TestSanitizeExchangeConfigForLog(t *testing.T) {
exchanges := map[string]struct {
Enabled bool `json:"enabled"`
APIKey string `json:"api_key"`
SecretKey string `json:"secret_key"`
Testnet bool `json:"testnet"`
HyperliquidWalletAddr string `json:"hyperliquid_wallet_addr"`
AsterUser string `json:"aster_user"`
AsterSigner string `json:"aster_signer"`
AsterPrivateKey string `json:"aster_private_key"`
}{
"binance": {
Enabled: true,
APIKey: "binance_api_key_1234567890abcdef",
SecretKey: "binance_secret_key_1234567890abcdef",
Testnet: false,
},
"hyperliquid": {
Enabled: true,
HyperliquidWalletAddr: "0x1234567890abcdef1234567890abcdef12345678",
Testnet: false,
},
}
result := SanitizeExchangeConfigForLog(exchanges)
// 检查币安配置
binanceConfig, ok := result["binance"].(map[string]interface{})
if !ok {
t.Fatal("binance config not found or wrong type")
}
maskedAPIKey, ok := binanceConfig["api_key"].(string)
if !ok {
t.Fatal("binance api_key not found or wrong type")
}
if maskedAPIKey != "bina****cdef" {
t.Errorf("expected masked api_key='bina****cdef', got %q", maskedAPIKey)
}
maskedSecretKey, ok := binanceConfig["secret_key"].(string)
if !ok {
t.Fatal("binance secret_key not found or wrong type")
}
if maskedSecretKey != "bina****cdef" {
t.Errorf("expected masked secret_key='bina****cdef', got %q", maskedSecretKey)
}
// 检查 Hyperliquid 配置
hlConfig, ok := result["hyperliquid"].(map[string]interface{})
if !ok {
t.Fatal("hyperliquid config not found or wrong type")
}
walletAddr, ok := hlConfig["hyperliquid_wallet_addr"].(string)
if !ok {
t.Fatal("hyperliquid_wallet_addr not found or wrong type")
}
// 钱包地址不应该被脱敏
if walletAddr != "0x1234567890abcdef1234567890abcdef12345678" {
t.Errorf("wallet address should not be masked, got %q", walletAddr)
}
}
func TestMaskEmail(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "空邮箱",
input: "",
expected: "",
},
{
name: "格式错误",
input: "notanemail",
expected: "****",
},
{
name: "正常邮箱",
input: "user@example.com",
expected: "us****@example.com",
},
{
name: "短用户名",
input: "a@example.com",
expected: "**@example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := MaskEmail(tt.input)
if result != tt.expected {
t.Errorf("MaskEmail(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}