mirror of
https://github.com/NoFxAiOS/nofx.git
synced 2026-06-06 05:51:19 +08:00
Merge branch 'dev' into beta
# Conflicts: # .github/workflows/docker-build.yml # .gitignore # api/server.go # config/config.go # config/database.go # decision/engine.go # docker-compose.yml # go.mod # go.sum # logger/telegram_sender.go # main.go # mcp/client.go # prompts/adaptive.txt # prompts/default.txt # prompts/nof1.txt # start.sh # trader/aster_trader.go # trader/auto_trader.go # trader/binance_futures.go # trader/hyperliquid_trader.go # web/package-lock.json # web/package.json # web/src/App.tsx # web/src/components/AILearning.tsx # web/src/components/AITradersPage.tsx # web/src/components/CompetitionPage.tsx # web/src/components/EquityChart.tsx # web/src/components/Header.tsx # web/src/components/LoginPage.tsx # web/src/components/RegisterPage.tsx # web/src/components/TraderConfigModal.tsx # web/src/components/TraderConfigViewModal.tsx # web/src/components/landing/FooterSection.tsx # web/src/components/landing/HeaderBar.tsx # web/src/contexts/AuthContext.tsx # web/src/i18n/translations.ts # web/src/lib/api.ts # web/src/lib/config.ts # web/src/types.ts
This commit is contained in:
72
api/crypto_handler.go
Normal file
72
api/crypto_handler.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"nofx/crypto"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CryptoHandler 加密 API 處理器
|
||||
type CryptoHandler struct {
|
||||
cryptoService *crypto.CryptoService
|
||||
}
|
||||
|
||||
// NewCryptoHandler 創建加密處理器
|
||||
func NewCryptoHandler(cryptoService *crypto.CryptoService) *CryptoHandler {
|
||||
return &CryptoHandler{
|
||||
cryptoService: cryptoService,
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 公鑰端點 ====================
|
||||
|
||||
// HandleGetPublicKey 獲取伺服器公鑰
|
||||
func (h *CryptoHandler) HandleGetPublicKey(c *gin.Context) {
|
||||
publicKey := h.cryptoService.GetPublicKeyPEM()
|
||||
|
||||
c.JSON(http.StatusOK, map[string]string{
|
||||
"public_key": publicKey,
|
||||
"algorithm": "RSA-OAEP-2048",
|
||||
})
|
||||
}
|
||||
|
||||
// ==================== 加密數據解密端點 ====================
|
||||
|
||||
// HandleDecryptSensitiveData 解密客戶端傳送的加密数据
|
||||
func (h *CryptoHandler) HandleDecryptSensitiveData(c *gin.Context) {
|
||||
var payload crypto.EncryptedPayload
|
||||
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
|
||||
return
|
||||
}
|
||||
|
||||
// 解密
|
||||
decrypted, err := h.cryptoService.DecryptSensitiveData(&payload)
|
||||
if err != nil {
|
||||
log.Printf("❌ 解密失敗: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Decryption failed"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, map[string]string{
|
||||
"plaintext": decrypted,
|
||||
})
|
||||
}
|
||||
|
||||
// ==================== 審計日誌查詢端點 ====================
|
||||
|
||||
// 删除审计日志相关功能,在当前简化的实现中不需要
|
||||
|
||||
// ==================== 工具函數 ====================
|
||||
|
||||
// isValidPrivateKey 驗證私鑰格式
|
||||
func isValidPrivateKey(key string) bool {
|
||||
// EVM 私鑰: 64 位十六進制 (可選 0x 前綴)
|
||||
if len(key) == 64 || (len(key) == 66 && key[:2] == "0x") {
|
||||
return true
|
||||
}
|
||||
// TODO: 添加其他鏈的驗證
|
||||
return false
|
||||
}
|
||||
252
api/register_otp_test.go
Normal file
252
api/register_otp_test.go
Normal file
@@ -0,0 +1,252 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// MockUser 模擬用戶結構
|
||||
type MockUser struct {
|
||||
ID int
|
||||
Email string
|
||||
OTPSecret string
|
||||
OTPVerified bool
|
||||
}
|
||||
|
||||
// TestOTPRefetchLogic 測試 OTP 重新獲取邏輯
|
||||
func TestOTPRefetchLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
existingUser *MockUser
|
||||
userExists bool
|
||||
expectedAction string // "allow_refetch", "reject_duplicate", "create_new"
|
||||
expectedMessage string
|
||||
}{
|
||||
{
|
||||
name: "新用戶註冊_郵箱不存在",
|
||||
existingUser: nil,
|
||||
userExists: false,
|
||||
expectedAction: "create_new",
|
||||
expectedMessage: "創建新用戶",
|
||||
},
|
||||
{
|
||||
name: "未完成OTP驗證_允許重新獲取",
|
||||
existingUser: &MockUser{
|
||||
ID: 1,
|
||||
Email: "test@example.com",
|
||||
OTPSecret: "SECRET123",
|
||||
OTPVerified: false,
|
||||
},
|
||||
userExists: true,
|
||||
expectedAction: "allow_refetch",
|
||||
expectedMessage: "检测到未完成的注册,请继续完成OTP设置",
|
||||
},
|
||||
{
|
||||
name: "已完成OTP驗證_拒絕重複註冊",
|
||||
existingUser: &MockUser{
|
||||
ID: 2,
|
||||
Email: "verified@example.com",
|
||||
OTPSecret: "SECRET456",
|
||||
OTPVerified: true,
|
||||
},
|
||||
userExists: true,
|
||||
expectedAction: "reject_duplicate",
|
||||
expectedMessage: "邮箱已被注册",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 模擬邏輯處理流程
|
||||
var actualAction string
|
||||
var actualMessage string
|
||||
|
||||
if !tt.userExists {
|
||||
// 用戶不存在,創建新用戶
|
||||
actualAction = "create_new"
|
||||
actualMessage = "創建新用戶"
|
||||
} else {
|
||||
// 用戶已存在,檢查 OTP 驗證狀態
|
||||
if !tt.existingUser.OTPVerified {
|
||||
// 未完成 OTP 驗證,允許重新獲取
|
||||
actualAction = "allow_refetch"
|
||||
actualMessage = "检测到未完成的注册,请继续完成OTP设置"
|
||||
} else {
|
||||
// 已完成驗證,拒絕重複註冊
|
||||
actualAction = "reject_duplicate"
|
||||
actualMessage = "邮箱已被注册"
|
||||
}
|
||||
}
|
||||
|
||||
// 驗證結果
|
||||
if actualAction != tt.expectedAction {
|
||||
t.Errorf("Action 不符: got %s, want %s", actualAction, tt.expectedAction)
|
||||
}
|
||||
if actualMessage != tt.expectedMessage {
|
||||
t.Errorf("Message 不符: got %s, want %s", actualMessage, tt.expectedMessage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestOTPVerificationStates 測試 OTP 驗證狀態判斷
|
||||
func TestOTPVerificationStates(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
otpVerified bool
|
||||
shouldAllowRefetch bool
|
||||
}{
|
||||
{
|
||||
name: "OTP已驗證_不允許重新獲取",
|
||||
otpVerified: true,
|
||||
shouldAllowRefetch: false,
|
||||
},
|
||||
{
|
||||
name: "OTP未驗證_允許重新獲取",
|
||||
otpVerified: false,
|
||||
shouldAllowRefetch: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 模擬驗證邏輯
|
||||
allowRefetch := !tt.otpVerified
|
||||
|
||||
if allowRefetch != tt.shouldAllowRefetch {
|
||||
t.Errorf("Refetch logic error: OTPVerified=%v, allowRefetch=%v, expected=%v",
|
||||
tt.otpVerified, allowRefetch, tt.shouldAllowRefetch)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegistrationFlow 測試完整註冊流程的邏輯分支
|
||||
func TestRegistrationFlow(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scenario string
|
||||
userExists bool
|
||||
otpVerified bool
|
||||
expectHTTPCode int // 模擬的 HTTP 狀態碼
|
||||
expectResponse string
|
||||
}{
|
||||
{
|
||||
name: "場景1_新用戶首次註冊",
|
||||
scenario: "新用戶首次訪問註冊接口",
|
||||
userExists: false,
|
||||
otpVerified: false,
|
||||
expectHTTPCode: 200,
|
||||
expectResponse: "創建用戶並返回 OTP 設置信息",
|
||||
},
|
||||
{
|
||||
name: "場景2_用戶中斷註冊後重新訪問",
|
||||
scenario: "用戶之前註冊但未完成 OTP 設置,現在重新訪問",
|
||||
userExists: true,
|
||||
otpVerified: false,
|
||||
expectHTTPCode: 200,
|
||||
expectResponse: "返回現有用戶的 OTP 信息,允許繼續完成",
|
||||
},
|
||||
{
|
||||
name: "場景3_已註冊用戶嘗試重複註冊",
|
||||
scenario: "用戶已完成註冊,嘗試用同一郵箱再次註冊",
|
||||
userExists: true,
|
||||
otpVerified: true,
|
||||
expectHTTPCode: 409, // Conflict
|
||||
expectResponse: "邮箱已被注册",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 模擬註冊流程邏輯
|
||||
var actualHTTPCode int
|
||||
var actualResponse string
|
||||
|
||||
if !tt.userExists {
|
||||
// 新用戶,創建並返回 OTP 信息
|
||||
actualHTTPCode = 200
|
||||
actualResponse = "創建用戶並返回 OTP 設置信息"
|
||||
} else {
|
||||
// 用戶已存在
|
||||
if !tt.otpVerified {
|
||||
// 未完成 OTP 驗證,允許重新獲取
|
||||
actualHTTPCode = 200
|
||||
actualResponse = "返回現有用戶的 OTP 信息,允許繼續完成"
|
||||
} else {
|
||||
// 已完成驗證,拒絕重複註冊
|
||||
actualHTTPCode = 409
|
||||
actualResponse = "邮箱已被注册"
|
||||
}
|
||||
}
|
||||
|
||||
// 驗證
|
||||
if actualHTTPCode != tt.expectHTTPCode {
|
||||
t.Errorf("HTTP code 不符: got %d, want %d (scenario: %s)",
|
||||
actualHTTPCode, tt.expectHTTPCode, tt.scenario)
|
||||
}
|
||||
if actualResponse != tt.expectResponse {
|
||||
t.Errorf("Response 不符: got %s, want %s (scenario: %s)",
|
||||
actualResponse, tt.expectResponse, tt.scenario)
|
||||
}
|
||||
|
||||
t.Logf("✓ %s: HTTP %d, %s", tt.scenario, actualHTTPCode, actualResponse)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEdgeCases 測試邊界情況
|
||||
func TestEdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
user *MockUser
|
||||
expectAllow bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "用戶ID為0_視為新用戶",
|
||||
user: &MockUser{
|
||||
ID: 0,
|
||||
Email: "new@example.com",
|
||||
OTPVerified: false,
|
||||
},
|
||||
expectAllow: true,
|
||||
description: "ID為0通常表示用戶還未創建",
|
||||
},
|
||||
{
|
||||
name: "OTPSecret為空_仍可重新獲取",
|
||||
user: &MockUser{
|
||||
ID: 1,
|
||||
Email: "test@example.com",
|
||||
OTPSecret: "",
|
||||
OTPVerified: false,
|
||||
},
|
||||
expectAllow: true,
|
||||
description: "即使 OTPSecret 為空,只要未驗證就允許重新獲取",
|
||||
},
|
||||
{
|
||||
name: "OTPSecret存在但已驗證_不允許",
|
||||
user: &MockUser{
|
||||
ID: 2,
|
||||
Email: "verified@example.com",
|
||||
OTPSecret: "SECRET789",
|
||||
OTPVerified: true,
|
||||
},
|
||||
expectAllow: false,
|
||||
description: "OTP 已驗證的用戶不能重新獲取",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 核心邏輯:只要 OTPVerified 為 false,就允許重新獲取
|
||||
allowRefetch := !tt.user.OTPVerified
|
||||
|
||||
if allowRefetch != tt.expectAllow {
|
||||
t.Errorf("Edge case failed: %s\nUser: ID=%d, OTPVerified=%v\nExpected allow=%v, got=%v",
|
||||
tt.description, tt.user.ID, tt.user.OTPVerified, tt.expectAllow, allowRefetch)
|
||||
}
|
||||
|
||||
t.Logf("✓ %s", tt.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
454
api/server.go
454
api/server.go
@@ -1,6 +1,7 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
@@ -8,13 +9,14 @@ import (
|
||||
"net/http"
|
||||
"nofx/auth"
|
||||
"nofx/config"
|
||||
"nofx/crypto"
|
||||
"nofx/decision"
|
||||
"nofx/hook"
|
||||
"nofx/manager"
|
||||
"nofx/trader"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -22,13 +24,15 @@ import (
|
||||
// Server HTTP API服务器
|
||||
type Server struct {
|
||||
router *gin.Engine
|
||||
httpServer *http.Server
|
||||
traderManager *manager.TraderManager
|
||||
database *config.Database
|
||||
cryptoHandler *CryptoHandler
|
||||
port int
|
||||
}
|
||||
|
||||
// NewServer 创建API服务器
|
||||
func NewServer(traderManager *manager.TraderManager, database *config.Database, port int) *Server {
|
||||
func NewServer(traderManager *manager.TraderManager, database *config.Database, cryptoService *crypto.CryptoService, port int) *Server {
|
||||
// 设置为Release模式(减少日志输出)
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
|
||||
@@ -37,10 +41,14 @@ func NewServer(traderManager *manager.TraderManager, database *config.Database,
|
||||
// 启用CORS
|
||||
router.Use(corsMiddleware())
|
||||
|
||||
// 创建加密处理器
|
||||
cryptoHandler := NewCryptoHandler(cryptoService)
|
||||
|
||||
s := &Server{
|
||||
router: router,
|
||||
traderManager: traderManager,
|
||||
database: database,
|
||||
cryptoHandler: cryptoHandler,
|
||||
port: port,
|
||||
}
|
||||
|
||||
@@ -74,19 +82,19 @@ func (s *Server) setupRoutes() {
|
||||
// 健康检查
|
||||
api.Any("/health", s.handleHealth)
|
||||
|
||||
// 认证相关路由(无需认证)
|
||||
api.POST("/register", s.handleRegister)
|
||||
api.POST("/login", s.handleLogin)
|
||||
api.POST("/verify-otp", s.handleVerifyOTP)
|
||||
api.POST("/complete-registration", s.handleCompleteRegistration)
|
||||
// 管理员登录(管理员模式下使用,公共)
|
||||
|
||||
// 系统支持的模型和交易所(无需认证)
|
||||
api.GET("/supported-models", s.handleGetSupportedModels)
|
||||
api.GET("/supported-exchanges", s.handleGetSupportedExchanges)
|
||||
|
||||
// 系统配置(无需认证)
|
||||
// 系统配置(无需认证,用于前端判断是否管理员模式/注册是否开启)
|
||||
api.GET("/config", s.handleGetSystemConfig)
|
||||
|
||||
// 加密相关接口(无需认证)
|
||||
api.GET("/crypto/public-key", s.cryptoHandler.HandleGetPublicKey)
|
||||
api.POST("/crypto/decrypt", s.cryptoHandler.HandleDecryptSensitiveData)
|
||||
|
||||
// 系统提示词模板管理(无需认证)
|
||||
api.GET("/prompt-templates", s.handleGetPromptTemplates)
|
||||
api.GET("/prompt-templates/:name", s.handleGetPromptTemplate)
|
||||
@@ -99,9 +107,18 @@ func (s *Server) setupRoutes() {
|
||||
api.POST("/equity-history-batch", s.handleEquityHistoryBatch)
|
||||
api.GET("/traders/:id/public-config", s.handleGetPublicTraderConfig)
|
||||
|
||||
// 认证相关路由(无需认证)
|
||||
api.POST("/register", s.handleRegister)
|
||||
api.POST("/login", s.handleLogin)
|
||||
api.POST("/verify-otp", s.handleVerifyOTP)
|
||||
api.POST("/complete-registration", s.handleCompleteRegistration)
|
||||
|
||||
// 需要认证的路由
|
||||
protected := api.Group("/", s.authMiddleware())
|
||||
{
|
||||
// 注销(加入黑名单)
|
||||
protected.POST("/logout", s.handleLogout)
|
||||
|
||||
// 服务器IP查询(需要认证,用于白名单配置)
|
||||
protected.GET("/server-ip", s.handleGetServerIP)
|
||||
|
||||
@@ -180,7 +197,6 @@ func (s *Server) handleGetSystemConfig(c *gin.Context) {
|
||||
betaMode := betaModeStr == "true"
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"admin_mode": auth.IsAdminMode(),
|
||||
"beta_mode": betaMode,
|
||||
"default_coins": defaultCoins,
|
||||
"btc_eth_leverage": btcEthLeverage,
|
||||
@@ -190,6 +206,17 @@ func (s *Server) handleGetSystemConfig(c *gin.Context) {
|
||||
|
||||
// handleGetServerIP 获取服务器IP地址(用于白名单配置)
|
||||
func (s *Server) handleGetServerIP(c *gin.Context) {
|
||||
|
||||
// 首先尝试从Hook获取用户专用IP
|
||||
userIP := hook.HookExec[hook.IpResult](hook.GETIP, c.GetString("user_id"))
|
||||
if userIP != nil && userIP.Error() == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"public_ip": userIP.GetResult(),
|
||||
"message": "请将此IP地址添加到白名单中",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试通过第三方API获取公网IP
|
||||
publicIP := getPublicIPFromAPI()
|
||||
|
||||
@@ -372,6 +399,16 @@ type ModelConfig struct {
|
||||
CustomAPIURL string `json:"customApiUrl,omitempty"`
|
||||
}
|
||||
|
||||
// SafeModelConfig 安全的模型配置结构(不包含敏感信息)
|
||||
type SafeModelConfig struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
Enabled bool `json:"enabled"`
|
||||
CustomAPIURL string `json:"customApiUrl"` // 自定义API URL(通常不敏感)
|
||||
CustomModelName string `json:"customModelName"` // 自定义模型名(不敏感)
|
||||
}
|
||||
|
||||
type ExchangeConfig struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
@@ -382,6 +419,18 @@ type ExchangeConfig struct {
|
||||
Testnet bool `json:"testnet,omitempty"`
|
||||
}
|
||||
|
||||
// SafeExchangeConfig 安全的交易所配置结构(不包含敏感信息)
|
||||
type SafeExchangeConfig struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"` // "cex" or "dex"
|
||||
Enabled bool `json:"enabled"`
|
||||
Testnet bool `json:"testnet,omitempty"`
|
||||
HyperliquidWalletAddr string `json:"hyperliquidWalletAddr"` // Hyperliquid钱包地址(不敏感)
|
||||
AsterUser string `json:"asterUser"` // Aster用户名(不敏感)
|
||||
AsterSigner string `json:"asterSigner"` // Aster签名者(不敏感)
|
||||
}
|
||||
|
||||
type UpdateModelConfigRequest struct {
|
||||
Models map[string]struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
@@ -507,7 +556,7 @@ func (s *Server) handleCreateTrader(c *gin.Context) {
|
||||
|
||||
switch req.ExchangeID {
|
||||
case "binance":
|
||||
tempTrader = trader.NewFuturesTrader(exchangeCfg.APIKey, exchangeCfg.SecretKey)
|
||||
tempTrader = trader.NewFuturesTrader(exchangeCfg.APIKey, exchangeCfg.SecretKey, userID)
|
||||
case "hyperliquid":
|
||||
tempTrader, createErr = trader.NewHyperliquidTrader(
|
||||
exchangeCfg.APIKey, // private key
|
||||
@@ -576,9 +625,9 @@ func (s *Server) handleCreateTrader(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 立即将新交易员加载到TraderManager中
|
||||
err = s.traderManager.LoadUserTraders(s.database, userID)
|
||||
err = s.traderManager.LoadTraderByID(s.database, userID, traderID)
|
||||
if err != nil {
|
||||
log.Printf("⚠️ 加载用户交易员到内存失败: %v", err)
|
||||
log.Printf("⚠️ 加载交易员到内存失败: %v", err)
|
||||
// 这里不返回错误,因为交易员已经成功创建到数据库
|
||||
}
|
||||
|
||||
@@ -594,17 +643,18 @@ func (s *Server) handleCreateTrader(c *gin.Context) {
|
||||
|
||||
// UpdateTraderRequest 更新交易员请求
|
||||
type UpdateTraderRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
AIModelID string `json:"ai_model_id" binding:"required"`
|
||||
ExchangeID string `json:"exchange_id" binding:"required"`
|
||||
InitialBalance float64 `json:"initial_balance"`
|
||||
ScanIntervalMinutes int `json:"scan_interval_minutes"`
|
||||
BTCETHLeverage int `json:"btc_eth_leverage"`
|
||||
AltcoinLeverage int `json:"altcoin_leverage"`
|
||||
TradingSymbols string `json:"trading_symbols"`
|
||||
CustomPrompt string `json:"custom_prompt"`
|
||||
OverrideBasePrompt bool `json:"override_base_prompt"`
|
||||
IsCrossMargin *bool `json:"is_cross_margin"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
AIModelID string `json:"ai_model_id" binding:"required"`
|
||||
ExchangeID string `json:"exchange_id" binding:"required"`
|
||||
InitialBalance float64 `json:"initial_balance"`
|
||||
ScanIntervalMinutes int `json:"scan_interval_minutes"`
|
||||
BTCETHLeverage int `json:"btc_eth_leverage"`
|
||||
AltcoinLeverage int `json:"altcoin_leverage"`
|
||||
TradingSymbols string `json:"trading_symbols"`
|
||||
CustomPrompt string `json:"custom_prompt"`
|
||||
OverrideBasePrompt bool `json:"override_base_prompt"`
|
||||
SystemPromptTemplate string `json:"system_prompt_template"`
|
||||
IsCrossMargin *bool `json:"is_cross_margin"`
|
||||
}
|
||||
|
||||
// handleUpdateTrader 更新交易员配置
|
||||
@@ -662,6 +712,12 @@ func (s *Server) handleUpdateTrader(c *gin.Context) {
|
||||
scanIntervalMinutes = 3
|
||||
}
|
||||
|
||||
// 设置提示词模板,允许更新
|
||||
systemPromptTemplate := req.SystemPromptTemplate
|
||||
if systemPromptTemplate == "" {
|
||||
systemPromptTemplate = existingTrader.SystemPromptTemplate // 如果请求中没有提供,保持原值
|
||||
}
|
||||
|
||||
// 更新交易员配置
|
||||
trader := &config.TraderRecord{
|
||||
ID: traderID,
|
||||
@@ -675,7 +731,7 @@ func (s *Server) handleUpdateTrader(c *gin.Context) {
|
||||
TradingSymbols: req.TradingSymbols,
|
||||
CustomPrompt: req.CustomPrompt,
|
||||
OverrideBasePrompt: req.OverrideBasePrompt,
|
||||
SystemPromptTemplate: existingTrader.SystemPromptTemplate, // 保持原值
|
||||
SystemPromptTemplate: systemPromptTemplate,
|
||||
IsCrossMargin: isCrossMargin,
|
||||
ScanIntervalMinutes: scanIntervalMinutes,
|
||||
IsRunning: existingTrader.IsRunning, // 保持原值
|
||||
@@ -689,9 +745,9 @@ func (s *Server) handleUpdateTrader(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 重新加载交易员到内存
|
||||
err = s.traderManager.LoadUserTraders(s.database, userID)
|
||||
err = s.traderManager.LoadTraderByID(s.database, userID, traderID)
|
||||
if err != nil {
|
||||
log.Printf("⚠️ 重新加载用户交易员到内存失败: %v", err)
|
||||
log.Printf("⚠️ 重新加载交易员到内存失败: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("✓ 更新交易员成功: %s (模型: %s, 交易所: %s)", req.Name, req.AIModelID, req.ExchangeID)
|
||||
@@ -735,12 +791,15 @@ func (s *Server) handleStartTrader(c *gin.Context) {
|
||||
traderID := c.Param("id")
|
||||
|
||||
// 校验交易员是否属于当前用户
|
||||
_, _, _, err := s.database.GetTraderConfig(userID, traderID)
|
||||
traderRecord, _, _, err := s.database.GetTraderConfig(userID, traderID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "交易员不存在或无访问权限"})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取模板名称
|
||||
templateName := traderRecord.SystemPromptTemplate
|
||||
|
||||
trader, err := s.traderManager.GetTrader(traderID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "交易员不存在"})
|
||||
@@ -754,6 +813,9 @@ func (s *Server) handleStartTrader(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 重新加载系统提示词模板(确保使用最新的硬盘文件)
|
||||
s.reloadPromptTemplatesWithLog(templateName)
|
||||
|
||||
// 启动交易员
|
||||
go func() {
|
||||
log.Printf("▶️ 启动交易员 %s (%s)", traderID, trader.GetName())
|
||||
@@ -868,7 +930,7 @@ func (s *Server) handleSyncBalance(c *gin.Context) {
|
||||
|
||||
switch traderConfig.ExchangeID {
|
||||
case "binance":
|
||||
tempTrader = trader.NewFuturesTrader(exchangeCfg.APIKey, exchangeCfg.SecretKey)
|
||||
tempTrader = trader.NewFuturesTrader(exchangeCfg.APIKey, exchangeCfg.SecretKey, userID)
|
||||
case "hyperliquid":
|
||||
tempTrader, createErr = trader.NewHyperliquidTrader(
|
||||
exchangeCfg.APIKey,
|
||||
@@ -934,9 +996,9 @@ func (s *Server) handleSyncBalance(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 重新加载交易员到内存
|
||||
err = s.traderManager.LoadUserTraders(s.database, userID)
|
||||
err = s.traderManager.LoadTraderByID(s.database, userID, traderID)
|
||||
if err != nil {
|
||||
log.Printf("⚠️ 重新加载用户交易员到内存失败: %v", err)
|
||||
log.Printf("⚠️ 重新加载交易员到内存失败: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("✅ 已同步余额: %.2f → %.2f USDT (%s %.2f%%)", oldBalance, actualBalance, changeType, changePercent)
|
||||
@@ -962,18 +1024,69 @@ func (s *Server) handleGetModelConfigs(c *gin.Context) {
|
||||
}
|
||||
log.Printf("✅ 找到 %d 个AI模型配置", len(models))
|
||||
|
||||
c.JSON(http.StatusOK, models)
|
||||
// 转换为安全的响应结构,移除敏感信息
|
||||
safeModels := make([]SafeModelConfig, len(models))
|
||||
for i, model := range models {
|
||||
safeModels[i] = SafeModelConfig{
|
||||
ID: model.ID,
|
||||
Name: model.Name,
|
||||
Provider: model.Provider,
|
||||
Enabled: model.Enabled,
|
||||
CustomAPIURL: model.CustomAPIURL,
|
||||
CustomModelName: model.CustomModelName,
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, safeModels)
|
||||
}
|
||||
|
||||
// handleUpdateModelConfigs 更新AI模型配置
|
||||
// handleUpdateModelConfigs 更新AI模型配置(仅支持加密数据)
|
||||
func (s *Server) handleUpdateModelConfigs(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
var req UpdateModelConfigRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
|
||||
// 读取原始请求体
|
||||
bodyBytes, err := c.GetRawData()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"})
|
||||
return
|
||||
}
|
||||
|
||||
// 解析加密的 payload
|
||||
var encryptedPayload crypto.EncryptedPayload
|
||||
if err := json.Unmarshal(bodyBytes, &encryptedPayload); err != nil {
|
||||
log.Printf("❌ 解析加密载荷失败: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请求格式错误,必须使用加密传输"})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证是否为加密数据
|
||||
if encryptedPayload.WrappedKey == "" {
|
||||
log.Printf("❌ 检测到非加密请求 (UserID: %s)", userID)
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "此接口仅支持加密传输,请使用加密客户端",
|
||||
"code": "ENCRYPTION_REQUIRED",
|
||||
"message": "Encrypted transmission is required for security reasons",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 解密数据
|
||||
decrypted, err := s.cryptoHandler.cryptoService.DecryptSensitiveData(&encryptedPayload)
|
||||
if err != nil {
|
||||
log.Printf("❌ 解密模型配置失败 (UserID: %s): %v", userID, err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "解密数据失败"})
|
||||
return
|
||||
}
|
||||
|
||||
// 解析解密后的数据
|
||||
var req UpdateModelConfigRequest
|
||||
if err := json.Unmarshal([]byte(decrypted), &req); err != nil {
|
||||
log.Printf("❌ 解析解密数据失败: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "解析解密数据失败"})
|
||||
return
|
||||
}
|
||||
log.Printf("🔓 已解密模型配置数据 (UserID: %s)", userID)
|
||||
|
||||
// 更新每个模型的配置
|
||||
for modelID, modelData := range req.Models {
|
||||
err := s.database.UpdateAIModel(userID, modelID, modelData.Enabled, modelData.APIKey, modelData.CustomAPIURL, modelData.CustomModelName)
|
||||
@@ -984,13 +1097,13 @@ func (s *Server) handleUpdateModelConfigs(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 重新加载该用户的所有交易员,使新配置立即生效
|
||||
err := s.traderManager.LoadUserTraders(s.database, userID)
|
||||
err = s.traderManager.LoadUserTraders(s.database, userID)
|
||||
if err != nil {
|
||||
log.Printf("⚠️ 重新加载用户交易员到内存失败: %v", err)
|
||||
// 这里不返回错误,因为模型配置已经成功更新到数据库
|
||||
}
|
||||
|
||||
log.Printf("✓ AI模型配置已更新: %+v", req.Models)
|
||||
log.Printf("✓ AI模型配置已更新: %+v", SanitizeModelConfigForLog(req.Models))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "模型配置已更新"})
|
||||
}
|
||||
|
||||
@@ -1006,18 +1119,71 @@ func (s *Server) handleGetExchangeConfigs(c *gin.Context) {
|
||||
}
|
||||
log.Printf("✅ 找到 %d 个交易所配置", len(exchanges))
|
||||
|
||||
c.JSON(http.StatusOK, exchanges)
|
||||
// 转换为安全的响应结构,移除敏感信息
|
||||
safeExchanges := make([]SafeExchangeConfig, len(exchanges))
|
||||
for i, exchange := range exchanges {
|
||||
safeExchanges[i] = SafeExchangeConfig{
|
||||
ID: exchange.ID,
|
||||
Name: exchange.Name,
|
||||
Type: exchange.Type,
|
||||
Enabled: exchange.Enabled,
|
||||
Testnet: exchange.Testnet,
|
||||
HyperliquidWalletAddr: exchange.HyperliquidWalletAddr,
|
||||
AsterUser: exchange.AsterUser,
|
||||
AsterSigner: exchange.AsterSigner,
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, safeExchanges)
|
||||
}
|
||||
|
||||
// handleUpdateExchangeConfigs 更新交易所配置
|
||||
// handleUpdateExchangeConfigs 更新交易所配置(仅支持加密数据)
|
||||
func (s *Server) handleUpdateExchangeConfigs(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
var req UpdateExchangeConfigRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
|
||||
// 读取原始请求体
|
||||
bodyBytes, err := c.GetRawData()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"})
|
||||
return
|
||||
}
|
||||
|
||||
// 解析加密的 payload
|
||||
var encryptedPayload crypto.EncryptedPayload
|
||||
if err := json.Unmarshal(bodyBytes, &encryptedPayload); err != nil {
|
||||
log.Printf("❌ 解析加密载荷失败: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请求格式错误,必须使用加密传输"})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证是否为加密数据
|
||||
if encryptedPayload.WrappedKey == "" {
|
||||
log.Printf("❌ 检测到非加密请求 (UserID: %s)", userID)
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "此接口仅支持加密传输,请使用加密客户端",
|
||||
"code": "ENCRYPTION_REQUIRED",
|
||||
"message": "Encrypted transmission is required for security reasons",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 解密数据
|
||||
decrypted, err := s.cryptoHandler.cryptoService.DecryptSensitiveData(&encryptedPayload)
|
||||
if err != nil {
|
||||
log.Printf("❌ 解密交易所配置失败 (UserID: %s): %v", userID, err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "解密数据失败"})
|
||||
return
|
||||
}
|
||||
|
||||
// 解析解密后的数据
|
||||
var req UpdateExchangeConfigRequest
|
||||
if err := json.Unmarshal([]byte(decrypted), &req); err != nil {
|
||||
log.Printf("❌ 解析解密数据失败: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "解析解密数据失败"})
|
||||
return
|
||||
}
|
||||
log.Printf("🔓 已解密交易所配置数据 (UserID: %s)", userID)
|
||||
|
||||
// 更新每个交易所的配置
|
||||
for exchangeID, exchangeData := range req.Exchanges {
|
||||
err := s.database.UpdateExchange(userID, exchangeID, exchangeData.Enabled, exchangeData.APIKey, exchangeData.SecretKey, exchangeData.Testnet, exchangeData.HyperliquidWalletAddr, exchangeData.AsterUser, exchangeData.AsterSigner, exchangeData.AsterPrivateKey)
|
||||
@@ -1028,13 +1194,13 @@ func (s *Server) handleUpdateExchangeConfigs(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 重新加载该用户的所有交易员,使新配置立即生效
|
||||
err := s.traderManager.LoadUserTraders(s.database, userID)
|
||||
err = s.traderManager.LoadUserTraders(s.database, userID)
|
||||
if err != nil {
|
||||
log.Printf("⚠️ 重新加载用户交易员到内存失败: %v", err)
|
||||
// 这里不返回错误,因为交易所配置已经成功更新到数据库
|
||||
}
|
||||
|
||||
log.Printf("✓ 交易所配置已更新: %+v", req.Exchanges)
|
||||
log.Printf("✓ 交易所配置已更新: %+v", SanitizeExchangeConfigForLog(req.Exchanges))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "交易所配置已更新"})
|
||||
}
|
||||
|
||||
@@ -1144,21 +1310,22 @@ func (s *Server) handleGetTraderConfig(c *gin.Context) {
|
||||
aiModelID := traderConfig.AIModelID
|
||||
|
||||
result := map[string]interface{}{
|
||||
"trader_id": traderConfig.ID,
|
||||
"trader_name": traderConfig.Name,
|
||||
"ai_model": aiModelID,
|
||||
"exchange_id": traderConfig.ExchangeID,
|
||||
"initial_balance": traderConfig.InitialBalance,
|
||||
"scan_interval_minutes": traderConfig.ScanIntervalMinutes,
|
||||
"btc_eth_leverage": traderConfig.BTCETHLeverage,
|
||||
"altcoin_leverage": traderConfig.AltcoinLeverage,
|
||||
"trading_symbols": traderConfig.TradingSymbols,
|
||||
"custom_prompt": traderConfig.CustomPrompt,
|
||||
"override_base_prompt": traderConfig.OverrideBasePrompt,
|
||||
"is_cross_margin": traderConfig.IsCrossMargin,
|
||||
"use_coin_pool": traderConfig.UseCoinPool,
|
||||
"use_oi_top": traderConfig.UseOITop,
|
||||
"is_running": isRunning,
|
||||
"trader_id": traderConfig.ID,
|
||||
"trader_name": traderConfig.Name,
|
||||
"ai_model": aiModelID,
|
||||
"exchange_id": traderConfig.ExchangeID,
|
||||
"initial_balance": traderConfig.InitialBalance,
|
||||
"scan_interval_minutes": traderConfig.ScanIntervalMinutes,
|
||||
"btc_eth_leverage": traderConfig.BTCETHLeverage,
|
||||
"altcoin_leverage": traderConfig.AltcoinLeverage,
|
||||
"trading_symbols": traderConfig.TradingSymbols,
|
||||
"custom_prompt": traderConfig.CustomPrompt,
|
||||
"override_base_prompt": traderConfig.OverrideBasePrompt,
|
||||
"system_prompt_template": traderConfig.SystemPromptTemplate,
|
||||
"is_cross_margin": traderConfig.IsCrossMargin,
|
||||
"use_coin_pool": traderConfig.UseCoinPool,
|
||||
"use_oi_top": traderConfig.UseOITop,
|
||||
"is_running": isRunning,
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, result)
|
||||
@@ -1280,7 +1447,15 @@ func (s *Server) handleLatestDecisions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
records, err := trader.GetDecisionLogger().GetLatestRecords(5)
|
||||
// 从 query 参数读取 limit,默认 5,最大 50
|
||||
limit := 5
|
||||
if limitStr := c.Query("limit"); limitStr != "" {
|
||||
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 && l <= 50 {
|
||||
limit = l
|
||||
}
|
||||
}
|
||||
|
||||
records, err := trader.GetDecisionLogger().GetLatestRecords(limit)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": fmt.Sprintf("获取决策日志失败: %v", err),
|
||||
@@ -1459,14 +1634,6 @@ func (s *Server) handlePerformance(c *gin.Context) {
|
||||
// authMiddleware JWT认证中间件
|
||||
func (s *Server) authMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 如果是管理员模式,直接使用admin用户
|
||||
if auth.IsAdminMode() {
|
||||
c.Set("user_id", "admin")
|
||||
c.Set("email", "admin@localhost")
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "缺少Authorization头"})
|
||||
@@ -1482,8 +1649,17 @@ func (s *Server) authMiddleware() gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := tokenParts[1]
|
||||
|
||||
// 黑名单检查
|
||||
if auth.IsTokenBlacklisted(tokenString) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "token已失效,请重新登录"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 验证JWT token
|
||||
claims, err := auth.ValidateJWT(tokenParts[1])
|
||||
claims, err := auth.ValidateJWT(tokenString)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "无效的token: " + err.Error()})
|
||||
c.Abort()
|
||||
@@ -1497,8 +1673,37 @@ func (s *Server) authMiddleware() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// handleLogout 将当前token加入黑名单
|
||||
func (s *Server) handleLogout(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "缺少Authorization头"})
|
||||
return
|
||||
}
|
||||
parts := strings.Split(authHeader, " ")
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "无效的Authorization格式"})
|
||||
return
|
||||
}
|
||||
tokenString := parts[1]
|
||||
claims, err := auth.ValidateJWT(tokenString)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "无效的token"})
|
||||
return
|
||||
}
|
||||
var exp time.Time
|
||||
if claims.ExpiresAt != nil {
|
||||
exp = claims.ExpiresAt.Time
|
||||
} else {
|
||||
exp = time.Now().Add(24 * time.Hour)
|
||||
}
|
||||
auth.BlacklistToken(tokenString, exp)
|
||||
c.JSON(http.StatusOK, gin.H{"message": "已登出"})
|
||||
}
|
||||
|
||||
// handleRegister 处理用户注册请求
|
||||
func (s *Server) handleRegister(c *gin.Context) {
|
||||
|
||||
var req struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
@@ -1532,8 +1737,21 @@ func (s *Server) handleRegister(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
_, err := s.database.GetUserByEmail(req.Email)
|
||||
existingUser, err := s.database.GetUserByEmail(req.Email)
|
||||
if err == nil {
|
||||
// 如果用户未完成OTP验证,允许重新获取OTP(支持中断后恢复注册)
|
||||
if !existingUser.OTPVerified {
|
||||
qrCodeURL := auth.GetOTPQRCodeURL(existingUser.OTPSecret, req.Email)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"user_id": existingUser.ID,
|
||||
"email": req.Email,
|
||||
"otp_secret": existingUser.OTPSecret,
|
||||
"qr_code_url": qrCodeURL,
|
||||
"message": "检测到未完成的注册,请继续完成OTP设置",
|
||||
})
|
||||
return
|
||||
}
|
||||
// 用户已完成验证,拒绝重复注册
|
||||
c.JSON(http.StatusConflict, gin.H{"error": "邮箱已被注册"})
|
||||
return
|
||||
}
|
||||
@@ -1728,6 +1946,50 @@ func (s *Server) handleVerifyOTP(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// handleResetPassword 重置密码(通过邮箱 + OTP 验证)
|
||||
func (s *Server) handleResetPassword(c *gin.Context) {
|
||||
var req struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
NewPassword string `json:"new_password" binding:"required,min=6"`
|
||||
OTPCode string `json:"otp_code" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 查询用户
|
||||
user, err := s.database.GetUserByEmail(req.Email)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "邮箱不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证 OTP
|
||||
if !auth.VerifyOTP(user.OTPSecret, req.OTPCode) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Google Authenticator 验证码错误"})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成新密码哈希
|
||||
newPasswordHash, err := auth.HashPassword(req.NewPassword)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "密码处理失败"})
|
||||
return
|
||||
}
|
||||
|
||||
// 更新密码
|
||||
err = s.database.UpdateUserPassword(user.ID, newPasswordHash)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "密码更新失败"})
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("✓ 用户 %s 密码已重置", user.Email)
|
||||
c.JSON(http.StatusOK, gin.H{"message": "密码重置成功,请使用新密码登录"})
|
||||
}
|
||||
|
||||
// initUserDefaultConfigs 为新用户初始化默认的模型和交易所配置
|
||||
func (s *Server) initUserDefaultConfigs(userID string) error {
|
||||
// 注释掉自动创建默认配置,让用户手动添加
|
||||
@@ -1759,7 +2021,22 @@ func (s *Server) handleGetSupportedExchanges(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, exchanges)
|
||||
// 转换为安全的响应结构,移除敏感信息
|
||||
safeExchanges := make([]SafeExchangeConfig, len(exchanges))
|
||||
for i, exchange := range exchanges {
|
||||
safeExchanges[i] = SafeExchangeConfig{
|
||||
ID: exchange.ID,
|
||||
Name: exchange.Name,
|
||||
Type: exchange.Type,
|
||||
Enabled: exchange.Enabled,
|
||||
Testnet: exchange.Testnet,
|
||||
HyperliquidWalletAddr: "", // 默认配置不包含钱包地址
|
||||
AsterUser: "", // 默认配置不包含用户信息
|
||||
AsterSigner: "",
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, safeExchanges)
|
||||
}
|
||||
|
||||
// Start 启动服务器
|
||||
@@ -1791,7 +2068,26 @@ func (s *Server) Start() error {
|
||||
log.Printf(" • GET /api/performance?trader_id=xxx - 指定trader的AI学习表现分析")
|
||||
log.Println()
|
||||
|
||||
return s.router.Run(addr)
|
||||
// 创建 http.Server 以支持 graceful shutdown
|
||||
s.httpServer = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: s.router,
|
||||
}
|
||||
|
||||
return s.httpServer.ListenAndServe()
|
||||
}
|
||||
|
||||
// Shutdown 优雅关闭 API 服务器
|
||||
func (s *Server) Shutdown() error {
|
||||
if s.httpServer == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 设置 5 秒超时
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
return s.httpServer.Shutdown(ctx)
|
||||
}
|
||||
|
||||
// handleGetPromptTemplates 获取所有系统提示词模板列表
|
||||
@@ -2035,3 +2331,17 @@ func (s *Server) handleGetPublicTraderConfig(c *gin.Context) {
|
||||
|
||||
c.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
// reloadPromptTemplatesWithLog 重新加载提示词模板并记录日志
|
||||
func (s *Server) reloadPromptTemplatesWithLog(templateName string) {
|
||||
if err := decision.ReloadPromptTemplates(); err != nil {
|
||||
log.Printf("⚠️ 重新加载提示词模板失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if templateName == "" {
|
||||
log.Printf("✓ 已重新加载系统提示词模板 [当前使用: default (未指定,使用默认)]")
|
||||
} else {
|
||||
log.Printf("✓ 已重新加载系统提示词模板 [当前使用: %s]", templateName)
|
||||
}
|
||||
}
|
||||
|
||||
227
api/server_test.go
Normal file
227
api/server_test.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"nofx/config"
|
||||
)
|
||||
|
||||
// TestUpdateTraderRequest_SystemPromptTemplate 测试更新交易员时 SystemPromptTemplate 字段是否存在
|
||||
func TestUpdateTraderRequest_SystemPromptTemplate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
requestJSON string
|
||||
expectedPromptTemplate string
|
||||
}{
|
||||
{
|
||||
name: "更新时应该能接收 system_prompt_template=nof1",
|
||||
requestJSON: `{
|
||||
"name": "Test Trader",
|
||||
"ai_model_id": "gpt-4",
|
||||
"exchange_id": "binance",
|
||||
"initial_balance": 1000,
|
||||
"scan_interval_minutes": 5,
|
||||
"btc_eth_leverage": 5,
|
||||
"altcoin_leverage": 3,
|
||||
"trading_symbols": "BTC,ETH",
|
||||
"custom_prompt": "test",
|
||||
"override_base_prompt": false,
|
||||
"is_cross_margin": true,
|
||||
"system_prompt_template": "nof1"
|
||||
}`,
|
||||
expectedPromptTemplate: "nof1",
|
||||
},
|
||||
{
|
||||
name: "更新时应该能接收 system_prompt_template=default",
|
||||
requestJSON: `{
|
||||
"name": "Test Trader",
|
||||
"ai_model_id": "gpt-4",
|
||||
"exchange_id": "binance",
|
||||
"initial_balance": 1000,
|
||||
"scan_interval_minutes": 5,
|
||||
"btc_eth_leverage": 5,
|
||||
"altcoin_leverage": 3,
|
||||
"trading_symbols": "BTC,ETH",
|
||||
"custom_prompt": "test",
|
||||
"override_base_prompt": false,
|
||||
"is_cross_margin": true,
|
||||
"system_prompt_template": "default"
|
||||
}`,
|
||||
expectedPromptTemplate: "default",
|
||||
},
|
||||
{
|
||||
name: "更新时应该能接收 system_prompt_template=custom",
|
||||
requestJSON: `{
|
||||
"name": "Test Trader",
|
||||
"ai_model_id": "gpt-4",
|
||||
"exchange_id": "binance",
|
||||
"initial_balance": 1000,
|
||||
"scan_interval_minutes": 5,
|
||||
"btc_eth_leverage": 5,
|
||||
"altcoin_leverage": 3,
|
||||
"trading_symbols": "BTC,ETH",
|
||||
"custom_prompt": "test",
|
||||
"override_base_prompt": false,
|
||||
"is_cross_margin": true,
|
||||
"system_prompt_template": "custom"
|
||||
}`,
|
||||
expectedPromptTemplate: "custom",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 测试 UpdateTraderRequest 结构体是否能正确解析 system_prompt_template 字段
|
||||
var req UpdateTraderRequest
|
||||
err := json.Unmarshal([]byte(tt.requestJSON), &req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal JSON: %v", err)
|
||||
}
|
||||
|
||||
// ✅ 验证 SystemPromptTemplate 字段是否被正确读取
|
||||
if req.SystemPromptTemplate != tt.expectedPromptTemplate {
|
||||
t.Errorf("Expected SystemPromptTemplate=%q, got %q",
|
||||
tt.expectedPromptTemplate, req.SystemPromptTemplate)
|
||||
}
|
||||
|
||||
// 验证其他字段也被正确解析
|
||||
if req.Name != "Test Trader" {
|
||||
t.Errorf("Name not parsed correctly")
|
||||
}
|
||||
if req.AIModelID != "gpt-4" {
|
||||
t.Errorf("AIModelID not parsed correctly")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetTraderConfigResponse_SystemPromptTemplate 测试获取交易员配置时返回值是否包含 system_prompt_template
|
||||
func TestGetTraderConfigResponse_SystemPromptTemplate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
traderConfig *config.TraderRecord
|
||||
expectedTemplate string
|
||||
}{
|
||||
{
|
||||
name: "获取配置应该返回 system_prompt_template=nof1",
|
||||
traderConfig: &config.TraderRecord{
|
||||
ID: "trader-123",
|
||||
UserID: "user-1",
|
||||
Name: "Test Trader",
|
||||
AIModelID: "gpt-4",
|
||||
ExchangeID: "binance",
|
||||
InitialBalance: 1000,
|
||||
ScanIntervalMinutes: 5,
|
||||
BTCETHLeverage: 5,
|
||||
AltcoinLeverage: 3,
|
||||
TradingSymbols: "BTC,ETH",
|
||||
CustomPrompt: "test",
|
||||
OverrideBasePrompt: false,
|
||||
SystemPromptTemplate: "nof1",
|
||||
IsCrossMargin: true,
|
||||
IsRunning: false,
|
||||
},
|
||||
expectedTemplate: "nof1",
|
||||
},
|
||||
{
|
||||
name: "获取配置应该返回 system_prompt_template=default",
|
||||
traderConfig: &config.TraderRecord{
|
||||
ID: "trader-456",
|
||||
UserID: "user-1",
|
||||
Name: "Test Trader 2",
|
||||
AIModelID: "gpt-4",
|
||||
ExchangeID: "binance",
|
||||
InitialBalance: 2000,
|
||||
ScanIntervalMinutes: 10,
|
||||
BTCETHLeverage: 10,
|
||||
AltcoinLeverage: 5,
|
||||
TradingSymbols: "BTC",
|
||||
CustomPrompt: "",
|
||||
OverrideBasePrompt: false,
|
||||
SystemPromptTemplate: "default",
|
||||
IsCrossMargin: false,
|
||||
IsRunning: false,
|
||||
},
|
||||
expectedTemplate: "default",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 模拟 handleGetTraderConfig 的返回值构造逻辑(修复后的实现)
|
||||
result := map[string]interface{}{
|
||||
"trader_id": tt.traderConfig.ID,
|
||||
"trader_name": tt.traderConfig.Name,
|
||||
"ai_model": tt.traderConfig.AIModelID,
|
||||
"exchange_id": tt.traderConfig.ExchangeID,
|
||||
"initial_balance": tt.traderConfig.InitialBalance,
|
||||
"scan_interval_minutes": tt.traderConfig.ScanIntervalMinutes,
|
||||
"btc_eth_leverage": tt.traderConfig.BTCETHLeverage,
|
||||
"altcoin_leverage": tt.traderConfig.AltcoinLeverage,
|
||||
"trading_symbols": tt.traderConfig.TradingSymbols,
|
||||
"custom_prompt": tt.traderConfig.CustomPrompt,
|
||||
"override_base_prompt": tt.traderConfig.OverrideBasePrompt,
|
||||
"system_prompt_template": tt.traderConfig.SystemPromptTemplate,
|
||||
"is_cross_margin": tt.traderConfig.IsCrossMargin,
|
||||
"is_running": tt.traderConfig.IsRunning,
|
||||
}
|
||||
|
||||
// ✅ 检查响应中是否包含 system_prompt_template
|
||||
if _, exists := result["system_prompt_template"]; !exists {
|
||||
t.Errorf("Response is missing 'system_prompt_template' field")
|
||||
} else {
|
||||
actualTemplate := result["system_prompt_template"].(string)
|
||||
if actualTemplate != tt.expectedTemplate {
|
||||
t.Errorf("Expected system_prompt_template=%q, got %q",
|
||||
tt.expectedTemplate, actualTemplate)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证其他字段是否正确
|
||||
if result["trader_id"] != tt.traderConfig.ID {
|
||||
t.Errorf("trader_id mismatch")
|
||||
}
|
||||
if result["trader_name"] != tt.traderConfig.Name {
|
||||
t.Errorf("trader_name mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateTraderRequest_CompleteFields 验证 UpdateTraderRequest 结构体定义完整性
|
||||
func TestUpdateTraderRequest_CompleteFields(t *testing.T) {
|
||||
jsonData := `{
|
||||
"name": "Test Trader",
|
||||
"ai_model_id": "gpt-4",
|
||||
"exchange_id": "binance",
|
||||
"initial_balance": 1000,
|
||||
"scan_interval_minutes": 5,
|
||||
"btc_eth_leverage": 5,
|
||||
"altcoin_leverage": 3,
|
||||
"trading_symbols": "BTC,ETH",
|
||||
"custom_prompt": "test",
|
||||
"override_base_prompt": false,
|
||||
"is_cross_margin": true,
|
||||
"system_prompt_template": "nof1"
|
||||
}`
|
||||
|
||||
var req UpdateTraderRequest
|
||||
err := json.Unmarshal([]byte(jsonData), &req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal JSON: %v", err)
|
||||
}
|
||||
|
||||
// 验证基本字段是否正确解析
|
||||
if req.Name != "Test Trader" {
|
||||
t.Errorf("Name mismatch: got %q", req.Name)
|
||||
}
|
||||
if req.AIModelID != "gpt-4" {
|
||||
t.Errorf("AIModelID mismatch: got %q", req.AIModelID)
|
||||
}
|
||||
|
||||
// ✅ 验证 SystemPromptTemplate 字段已正确添加到结构体
|
||||
if req.SystemPromptTemplate != "nof1" {
|
||||
t.Errorf("SystemPromptTemplate mismatch: expected %q, got %q", "nof1", req.SystemPromptTemplate)
|
||||
}
|
||||
}
|
||||
97
api/utils.go
Normal file
97
api/utils.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package api
|
||||
|
||||
import "strings"
|
||||
|
||||
// MaskSensitiveString 脱敏敏感字符串,只显示前4位和后4位
|
||||
// 用于脱敏 API Key、Secret Key、Private Key 等敏感信息
|
||||
func MaskSensitiveString(s string) string {
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
length := len(s)
|
||||
if length <= 8 {
|
||||
return "****" // 字符串太短,全部隐藏
|
||||
}
|
||||
return s[:4] + "****" + s[length-4:]
|
||||
}
|
||||
|
||||
// SanitizeModelConfigForLog 脱敏模型配置用于日志输出
|
||||
func SanitizeModelConfigForLog(models map[string]struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"api_key"`
|
||||
CustomAPIURL string `json:"custom_api_url"`
|
||||
CustomModelName string `json:"custom_model_name"`
|
||||
}) map[string]interface{} {
|
||||
safe := make(map[string]interface{})
|
||||
for modelID, cfg := range models {
|
||||
safe[modelID] = map[string]interface{}{
|
||||
"enabled": cfg.Enabled,
|
||||
"api_key": MaskSensitiveString(cfg.APIKey),
|
||||
"custom_api_url": cfg.CustomAPIURL,
|
||||
"custom_model_name": cfg.CustomModelName,
|
||||
}
|
||||
}
|
||||
return safe
|
||||
}
|
||||
|
||||
// SanitizeExchangeConfigForLog 脱敏交易所配置用于日志输出
|
||||
func SanitizeExchangeConfigForLog(exchanges map[string]struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"api_key"`
|
||||
SecretKey string `json:"secret_key"`
|
||||
Testnet bool `json:"testnet"`
|
||||
HyperliquidWalletAddr string `json:"hyperliquid_wallet_addr"`
|
||||
AsterUser string `json:"aster_user"`
|
||||
AsterSigner string `json:"aster_signer"`
|
||||
AsterPrivateKey string `json:"aster_private_key"`
|
||||
}) map[string]interface{} {
|
||||
safe := make(map[string]interface{})
|
||||
for exchangeID, cfg := range exchanges {
|
||||
safeExchange := map[string]interface{}{
|
||||
"enabled": cfg.Enabled,
|
||||
"testnet": cfg.Testnet,
|
||||
}
|
||||
|
||||
// 只在有值时才添加脱敏后的敏感字段
|
||||
if cfg.APIKey != "" {
|
||||
safeExchange["api_key"] = MaskSensitiveString(cfg.APIKey)
|
||||
}
|
||||
if cfg.SecretKey != "" {
|
||||
safeExchange["secret_key"] = MaskSensitiveString(cfg.SecretKey)
|
||||
}
|
||||
if cfg.AsterPrivateKey != "" {
|
||||
safeExchange["aster_private_key"] = MaskSensitiveString(cfg.AsterPrivateKey)
|
||||
}
|
||||
|
||||
// 非敏感字段直接添加
|
||||
if cfg.HyperliquidWalletAddr != "" {
|
||||
safeExchange["hyperliquid_wallet_addr"] = cfg.HyperliquidWalletAddr
|
||||
}
|
||||
if cfg.AsterUser != "" {
|
||||
safeExchange["aster_user"] = cfg.AsterUser
|
||||
}
|
||||
if cfg.AsterSigner != "" {
|
||||
safeExchange["aster_signer"] = cfg.AsterSigner
|
||||
}
|
||||
|
||||
safe[exchangeID] = safeExchange
|
||||
}
|
||||
return safe
|
||||
}
|
||||
|
||||
// MaskEmail 脱敏邮箱地址,保留前2位和@后部分
|
||||
func MaskEmail(email string) string {
|
||||
if email == "" {
|
||||
return ""
|
||||
}
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) != 2 {
|
||||
return "****" // 格式不正确
|
||||
}
|
||||
username := parts[0]
|
||||
domain := parts[1]
|
||||
if len(username) <= 2 {
|
||||
return "**@" + domain
|
||||
}
|
||||
return username[:2] + "****@" + domain
|
||||
}
|
||||
193
api/utils_test.go
Normal file
193
api/utils_test.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMaskSensitiveString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "空字符串",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "短字符串(小于等于8位)",
|
||||
input: "short",
|
||||
expected: "****",
|
||||
},
|
||||
{
|
||||
name: "正常API key",
|
||||
input: "sk-1234567890abcdefghijklmnopqrstuvwxyz",
|
||||
expected: "sk-1****wxyz",
|
||||
},
|
||||
{
|
||||
name: "正常私钥",
|
||||
input: "0x1234567890abcdef1234567890abcdef12345678",
|
||||
expected: "0x12****5678",
|
||||
},
|
||||
{
|
||||
name: "刚好9位",
|
||||
input: "123456789",
|
||||
expected: "1234****6789",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := MaskSensitiveString(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("MaskSensitiveString(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeModelConfigForLog(t *testing.T) {
|
||||
models := map[string]struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"api_key"`
|
||||
CustomAPIURL string `json:"custom_api_url"`
|
||||
CustomModelName string `json:"custom_model_name"`
|
||||
}{
|
||||
"deepseek": {
|
||||
Enabled: true,
|
||||
APIKey: "sk-1234567890abcdefghijklmnopqrstuvwxyz",
|
||||
CustomAPIURL: "https://api.deepseek.com",
|
||||
CustomModelName: "deepseek-chat",
|
||||
},
|
||||
}
|
||||
|
||||
result := SanitizeModelConfigForLog(models)
|
||||
|
||||
deepseekConfig, ok := result["deepseek"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("deepseek config not found or wrong type")
|
||||
}
|
||||
|
||||
if deepseekConfig["enabled"] != true {
|
||||
t.Errorf("expected enabled=true, got %v", deepseekConfig["enabled"])
|
||||
}
|
||||
|
||||
maskedKey, ok := deepseekConfig["api_key"].(string)
|
||||
if !ok {
|
||||
t.Fatal("api_key not found or wrong type")
|
||||
}
|
||||
|
||||
if maskedKey != "sk-1****wxyz" {
|
||||
t.Errorf("expected masked api_key='sk-1****wxyz', got %q", maskedKey)
|
||||
}
|
||||
|
||||
if deepseekConfig["custom_api_url"] != "https://api.deepseek.com" {
|
||||
t.Errorf("custom_api_url should not be masked")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeExchangeConfigForLog(t *testing.T) {
|
||||
exchanges := map[string]struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"api_key"`
|
||||
SecretKey string `json:"secret_key"`
|
||||
Testnet bool `json:"testnet"`
|
||||
HyperliquidWalletAddr string `json:"hyperliquid_wallet_addr"`
|
||||
AsterUser string `json:"aster_user"`
|
||||
AsterSigner string `json:"aster_signer"`
|
||||
AsterPrivateKey string `json:"aster_private_key"`
|
||||
}{
|
||||
"binance": {
|
||||
Enabled: true,
|
||||
APIKey: "binance_api_key_1234567890abcdef",
|
||||
SecretKey: "binance_secret_key_1234567890abcdef",
|
||||
Testnet: false,
|
||||
},
|
||||
"hyperliquid": {
|
||||
Enabled: true,
|
||||
HyperliquidWalletAddr: "0x1234567890abcdef1234567890abcdef12345678",
|
||||
Testnet: false,
|
||||
},
|
||||
}
|
||||
|
||||
result := SanitizeExchangeConfigForLog(exchanges)
|
||||
|
||||
// 检查币安配置
|
||||
binanceConfig, ok := result["binance"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("binance config not found or wrong type")
|
||||
}
|
||||
|
||||
maskedAPIKey, ok := binanceConfig["api_key"].(string)
|
||||
if !ok {
|
||||
t.Fatal("binance api_key not found or wrong type")
|
||||
}
|
||||
|
||||
if maskedAPIKey != "bina****cdef" {
|
||||
t.Errorf("expected masked api_key='bina****cdef', got %q", maskedAPIKey)
|
||||
}
|
||||
|
||||
maskedSecretKey, ok := binanceConfig["secret_key"].(string)
|
||||
if !ok {
|
||||
t.Fatal("binance secret_key not found or wrong type")
|
||||
}
|
||||
|
||||
if maskedSecretKey != "bina****cdef" {
|
||||
t.Errorf("expected masked secret_key='bina****cdef', got %q", maskedSecretKey)
|
||||
}
|
||||
|
||||
// 检查 Hyperliquid 配置
|
||||
hlConfig, ok := result["hyperliquid"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("hyperliquid config not found or wrong type")
|
||||
}
|
||||
|
||||
walletAddr, ok := hlConfig["hyperliquid_wallet_addr"].(string)
|
||||
if !ok {
|
||||
t.Fatal("hyperliquid_wallet_addr not found or wrong type")
|
||||
}
|
||||
|
||||
// 钱包地址不应该被脱敏
|
||||
if walletAddr != "0x1234567890abcdef1234567890abcdef12345678" {
|
||||
t.Errorf("wallet address should not be masked, got %q", walletAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaskEmail(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "空邮箱",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "格式错误",
|
||||
input: "notanemail",
|
||||
expected: "****",
|
||||
},
|
||||
{
|
||||
name: "正常邮箱",
|
||||
input: "user@example.com",
|
||||
expected: "us****@example.com",
|
||||
},
|
||||
{
|
||||
name: "短用户名",
|
||||
input: "a@example.com",
|
||||
expected: "**@example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := MaskEmail(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("MaskEmail(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user