account system、custom prompt

This commit is contained in:
icy
2025-10-31 03:42:01 +08:00
parent bbe1e1f929
commit ceb2f7b435
32 changed files with 3873 additions and 465 deletions

View File

@@ -4,11 +4,14 @@ import (
"fmt"
"log"
"net/http"
"nofx/auth"
"nofx/config"
"nofx/manager"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// Server HTTP API服务器
@@ -66,30 +69,51 @@ func (s *Server) setupRoutes() {
// API路由组
api := s.router.Group("/api")
{
// AI交易员管理
api.GET("/traders", s.handleTraderList)
api.POST("/traders", s.handleCreateTrader)
api.DELETE("/traders/:id", s.handleDeleteTrader)
api.POST("/traders/:id/start", s.handleStartTrader)
api.POST("/traders/:id/stop", s.handleStopTrader)
// 认证相关路由(无需认证)
api.POST("/register", s.handleRegister)
api.POST("/login", s.handleLogin)
api.POST("/verify-otp", s.handleVerifyOTP)
api.POST("/complete-registration", s.handleCompleteRegistration)
// 系统支持的模型和交易所(无需认证)
api.GET("/supported-models", s.handleGetSupportedModels)
api.GET("/supported-exchanges", s.handleGetSupportedExchanges)
// 系统配置(无需认证)
api.GET("/config", s.handleGetSystemConfig)
// AI模型配置
api.GET("/models", s.handleGetModelConfigs)
api.PUT("/models", s.handleUpdateModelConfigs)
// 需要认证的路由
protected := api.Group("/", s.authMiddleware())
{
// AI交易员管理
protected.GET("/traders", s.handleTraderList)
protected.POST("/traders", s.handleCreateTrader)
protected.DELETE("/traders/:id", s.handleDeleteTrader)
protected.POST("/traders/:id/start", s.handleStartTrader)
protected.POST("/traders/:id/stop", s.handleStopTrader)
protected.PUT("/traders/:id/prompt", s.handleUpdateTraderPrompt)
// 交易所配置
api.GET("/exchanges", s.handleGetExchangeConfigs)
api.PUT("/exchanges", s.handleUpdateExchangeConfigs)
// AI模型配置
protected.GET("/models", s.handleGetModelConfigs)
protected.PUT("/models", s.handleUpdateModelConfigs)
// 指定trader的数据使用query参数 ?trader_id=xxx
api.GET("/status", s.handleStatus)
api.GET("/account", s.handleAccount)
api.GET("/positions", s.handlePositions)
api.GET("/decisions", s.handleDecisions)
api.GET("/decisions/latest", s.handleLatestDecisions)
api.GET("/statistics", s.handleStatistics)
api.GET("/equity-history", s.handleEquityHistory)
api.GET("/performance", s.handlePerformance)
// 交易所配置
protected.GET("/exchanges", s.handleGetExchangeConfigs)
protected.PUT("/exchanges", s.handleUpdateExchangeConfigs)
// 竞赛总览
protected.GET("/competition", s.handleCompetition)
// 指定trader的数据使用query参数 ?trader_id=xxx
protected.GET("/status", s.handleStatus)
protected.GET("/account", s.handleAccount)
protected.GET("/positions", s.handlePositions)
protected.GET("/decisions", s.handleDecisions)
protected.GET("/decisions/latest", s.handleLatestDecisions)
protected.GET("/statistics", s.handleStatistics)
protected.GET("/equity-history", s.handleEquityHistory)
protected.GET("/performance", s.handlePerformance)
}
}
}
@@ -101,17 +125,40 @@ func (s *Server) handleHealth(c *gin.Context) {
})
}
// handleGetSystemConfig 获取系统配置(客户端需要知道的配置)
func (s *Server) handleGetSystemConfig(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"admin_mode": auth.IsAdminMode(),
})
}
// getTraderFromQuery 从query参数获取trader
func (s *Server) getTraderFromQuery(c *gin.Context) (*manager.TraderManager, string, error) {
userID := c.GetString("user_id")
traderID := c.Query("trader_id")
// 确保用户的交易员已加载到内存中
err := s.traderManager.LoadUserTraders(s.database, userID)
if err != nil {
log.Printf("⚠️ 加载用户 %s 的交易员失败: %v", userID, err)
}
if traderID == "" {
// 如果没有指定trader_id返回第一个trader
// 如果没有指定trader_id返回该用户的第一个trader
ids := s.traderManager.GetTraderIDs()
if len(ids) == 0 {
return nil, "", fmt.Errorf("没有可用的trader")
}
traderID = ids[0]
// 获取用户的交易员列表,优先返回用户自己的交易员
userTraders, err := s.database.GetTraders(userID)
if err == nil && len(userTraders) > 0 {
traderID = userTraders[0].ID
} else {
traderID = ids[0]
}
}
return s.traderManager, traderID, nil
}
@@ -121,6 +168,8 @@ type CreateTraderRequest struct {
AIModelID string `json:"ai_model_id" binding:"required"`
ExchangeID string `json:"exchange_id" binding:"required"`
InitialBalance float64 `json:"initial_balance"`
CustomPrompt string `json:"custom_prompt"`
OverrideBasePrompt bool `json:"override_base_prompt"`
}
type ModelConfig struct {
@@ -150,15 +199,20 @@ type UpdateModelConfigRequest struct {
type UpdateExchangeConfigRequest struct {
Exchanges map[string]struct {
Enabled bool `json:"enabled"`
APIKey string `json:"api_key"`
SecretKey string `json:"secret_key"`
Testnet bool `json:"testnet"`
Enabled bool `json:"enabled"`
APIKey string `json:"api_key"`
SecretKey string `json:"secret_key"`
Testnet bool `json:"testnet"`
HyperliquidWalletAddr string `json:"hyperliquid_wallet_addr"`
AsterUser string `json:"aster_user"`
AsterSigner string `json:"aster_signer"`
AsterPrivateKey string `json:"aster_private_key"`
} `json:"exchanges"`
}
// handleCreateTrader 创建新的AI交易员
func (s *Server) handleCreateTrader(c *gin.Context) {
userID := c.GetString("user_id")
var req CreateTraderRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -171,10 +225,13 @@ func (s *Server) handleCreateTrader(c *gin.Context) {
// 创建交易员配置
trader := &config.TraderConfig{
ID: traderID,
UserID: userID,
Name: req.Name,
AIModelID: req.AIModelID,
ExchangeID: req.ExchangeID,
InitialBalance: req.InitialBalance,
CustomPrompt: req.CustomPrompt,
OverrideBasePrompt: req.OverrideBasePrompt,
ScanIntervalMinutes: 3, // 默认3分钟
IsRunning: false,
}
@@ -186,6 +243,13 @@ func (s *Server) handleCreateTrader(c *gin.Context) {
return
}
// 立即将新交易员加载到TraderManager中
err = s.traderManager.LoadUserTraders(s.database, userID)
if err != nil {
log.Printf("⚠️ 加载用户交易员到内存失败: %v", err)
// 这里不返回错误,因为交易员已经成功创建到数据库
}
log.Printf("✓ 创建交易员成功: %s (模型: %s, 交易所: %s)", req.Name, req.AIModelID, req.ExchangeID)
c.JSON(http.StatusCreated, gin.H{
@@ -198,10 +262,11 @@ func (s *Server) handleCreateTrader(c *gin.Context) {
// handleDeleteTrader 删除交易员
func (s *Server) handleDeleteTrader(c *gin.Context) {
userID := c.GetString("user_id")
traderID := c.Param("id")
// 从数据库删除
err := s.database.DeleteTrader(traderID)
err := s.database.DeleteTrader(userID, traderID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("删除交易员失败: %v", err)})
return
@@ -246,7 +311,8 @@ func (s *Server) handleStartTrader(c *gin.Context) {
}()
// 更新数据库中的运行状态
err = s.database.UpdateTraderStatus(traderID, true)
userID := c.GetString("user_id")
err = s.database.UpdateTraderStatus(userID, traderID, true)
if err != nil {
log.Printf("⚠️ 更新交易员状态失败: %v", err)
}
@@ -276,7 +342,8 @@ func (s *Server) handleStopTrader(c *gin.Context) {
trader.Stop()
// 更新数据库中的运行状态
err = s.database.UpdateTraderStatus(traderID, false)
userID := c.GetString("user_id")
err = s.database.UpdateTraderStatus(userID, traderID, false)
if err != nil {
log.Printf("⚠️ 更新交易员状态失败: %v", err)
}
@@ -285,19 +352,57 @@ func (s *Server) handleStopTrader(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "交易员已停止"})
}
// handleUpdateTraderPrompt 更新交易员自定义Prompt
func (s *Server) handleUpdateTraderPrompt(c *gin.Context) {
traderID := c.Param("id")
userID := c.GetString("user_id")
var req struct {
CustomPrompt string `json:"custom_prompt"`
OverrideBasePrompt bool `json:"override_base_prompt"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 更新数据库
err := s.database.UpdateTraderCustomPrompt(userID, traderID, req.CustomPrompt, req.OverrideBasePrompt)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("更新自定义prompt失败: %v", err)})
return
}
// 如果trader在内存中更新其custom prompt和override设置
trader, err := s.traderManager.GetTrader(traderID)
if err == nil {
trader.SetCustomPrompt(req.CustomPrompt)
trader.SetOverrideBasePrompt(req.OverrideBasePrompt)
log.Printf("✓ 已更新交易员 %s 的自定义prompt (覆盖基础=%v)", trader.GetName(), req.OverrideBasePrompt)
}
c.JSON(http.StatusOK, gin.H{"message": "自定义prompt已更新"})
}
// handleGetModelConfigs 获取AI模型配置
func (s *Server) handleGetModelConfigs(c *gin.Context) {
models, err := s.database.GetAIModels()
userID := c.GetString("user_id")
log.Printf("🔍 查询用户 %s 的AI模型配置", userID)
models, err := s.database.GetAIModels(userID)
if err != nil {
log.Printf("❌ 获取AI模型配置失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("获取AI模型配置失败: %v", err)})
return
}
log.Printf("✅ 找到 %d 个AI模型配置", len(models))
c.JSON(http.StatusOK, models)
}
// handleUpdateModelConfigs 更新AI模型配置
func (s *Server) handleUpdateModelConfigs(c *gin.Context) {
userID := c.GetString("user_id")
var req UpdateModelConfigRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -306,7 +411,7 @@ func (s *Server) handleUpdateModelConfigs(c *gin.Context) {
// 更新每个模型的配置
for modelID, modelData := range req.Models {
err := s.database.UpdateAIModel(modelID, modelData.Enabled, modelData.APIKey)
err := s.database.UpdateAIModel(userID, modelID, modelData.Enabled, modelData.APIKey)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("更新模型 %s 失败: %v", modelID, err)})
return
@@ -319,17 +424,22 @@ func (s *Server) handleUpdateModelConfigs(c *gin.Context) {
// handleGetExchangeConfigs 获取交易所配置
func (s *Server) handleGetExchangeConfigs(c *gin.Context) {
exchanges, err := s.database.GetExchanges()
userID := c.GetString("user_id")
log.Printf("🔍 查询用户 %s 的交易所配置", userID)
exchanges, err := s.database.GetExchanges(userID)
if err != nil {
log.Printf("❌ 获取交易所配置失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("获取交易所配置失败: %v", err)})
return
}
log.Printf("✅ 找到 %d 个交易所配置", len(exchanges))
c.JSON(http.StatusOK, exchanges)
}
// handleUpdateExchangeConfigs 更新交易所配置
func (s *Server) handleUpdateExchangeConfigs(c *gin.Context) {
userID := c.GetString("user_id")
var req UpdateExchangeConfigRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -338,7 +448,7 @@ func (s *Server) handleUpdateExchangeConfigs(c *gin.Context) {
// 更新每个交易所的配置
for exchangeID, exchangeData := range req.Exchanges {
err := s.database.UpdateExchange(exchangeID, exchangeData.Enabled, exchangeData.APIKey, exchangeData.SecretKey, exchangeData.Testnet)
err := s.database.UpdateExchange(userID, exchangeID, exchangeData.Enabled, exchangeData.APIKey, exchangeData.SecretKey, exchangeData.Testnet, exchangeData.HyperliquidWalletAddr, exchangeData.AsterUser, exchangeData.AsterSigner, exchangeData.AsterPrivateKey)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("更新交易所 %s 失败: %v", exchangeID, err)})
return
@@ -351,7 +461,8 @@ func (s *Server) handleUpdateExchangeConfigs(c *gin.Context) {
// handleTraderList trader列表
func (s *Server) handleTraderList(c *gin.Context) {
traders, err := s.database.GetTraders()
userID := c.GetString("user_id")
traders, err := s.database.GetTraders(userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("获取交易员列表失败: %v", err)})
return
@@ -539,6 +650,27 @@ func (s *Server) handleStatistics(c *gin.Context) {
c.JSON(http.StatusOK, stats)
}
// handleCompetition 竞赛总览对比所有trader
func (s *Server) handleCompetition(c *gin.Context) {
userID := c.GetString("user_id")
// 确保用户的交易员已加载到内存中
err := s.traderManager.LoadUserTraders(s.database, userID)
if err != nil {
log.Printf("⚠️ 加载用户 %s 的交易员失败: %v", userID, err)
}
competition, err := s.traderManager.GetCompetitionData(userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": fmt.Sprintf("获取竞赛数据失败: %v", err),
})
return
}
c.JSON(http.StatusOK, competition)
}
// handleEquityHistory 收益率历史数据
func (s *Server) handleEquityHistory(c *gin.Context) {
_, traderID, err := s.getTraderFromQuery(c)
@@ -652,6 +784,278 @@ func (s *Server) handlePerformance(c *gin.Context) {
c.JSON(http.StatusOK, performance)
}
// authMiddleware JWT认证中间件
func (s *Server) authMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// 如果是管理员模式直接使用admin用户
if auth.IsAdminMode() {
c.Set("user_id", "admin")
c.Set("email", "admin@localhost")
c.Next()
return
}
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "缺少Authorization头"})
c.Abort()
return
}
// 检查Bearer token格式
tokenParts := strings.Split(authHeader, " ")
if len(tokenParts) != 2 || tokenParts[0] != "Bearer" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "无效的Authorization格式"})
c.Abort()
return
}
// 验证JWT token
claims, err := auth.ValidateJWT(tokenParts[1])
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "无效的token: " + err.Error()})
c.Abort()
return
}
// 将用户信息存储到上下文中
c.Set("user_id", claims.UserID)
c.Set("email", claims.Email)
c.Next()
}
}
// handleRegister 处理用户注册请求
func (s *Server) handleRegister(c *gin.Context) {
var req struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 检查邮箱是否已存在
_, err := s.database.GetUserByEmail(req.Email)
if err == nil {
c.JSON(http.StatusConflict, gin.H{"error": "邮箱已被注册"})
return
}
// 生成密码哈希
passwordHash, err := auth.HashPassword(req.Password)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "密码处理失败"})
return
}
// 生成OTP密钥
otpSecret, err := auth.GenerateOTPSecret()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "OTP密钥生成失败"})
return
}
// 创建用户未验证OTP状态
userID := uuid.New().String()
user := &config.User{
ID: userID,
Email: req.Email,
PasswordHash: passwordHash,
OTPSecret: otpSecret,
OTPVerified: false,
}
err = s.database.CreateUser(user)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建用户失败: " + err.Error()})
return
}
// 返回OTP设置信息
qrCodeURL := auth.GetOTPQRCodeURL(otpSecret, req.Email)
c.JSON(http.StatusOK, gin.H{
"user_id": userID,
"email": req.Email,
"otp_secret": otpSecret,
"qr_code_url": qrCodeURL,
"message": "请使用Google Authenticator扫描二维码并验证OTP",
})
}
// handleCompleteRegistration 完成注册验证OTP
func (s *Server) handleCompleteRegistration(c *gin.Context) {
var req struct {
UserID string `json:"user_id" binding:"required"`
OTPCode string `json:"otp_code" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 获取用户信息
user, err := s.database.GetUserByID(req.UserID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "用户不存在"})
return
}
// 验证OTP
if !auth.VerifyOTP(user.OTPSecret, req.OTPCode) {
c.JSON(http.StatusBadRequest, gin.H{"error": "OTP验证码错误"})
return
}
// 更新用户OTP验证状态
err = s.database.UpdateUserOTPVerified(req.UserID, true)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "更新用户状态失败"})
return
}
// 生成JWT token
token, err := auth.GenerateJWT(user.ID, user.Email)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "生成token失败"})
return
}
// 初始化用户的默认模型和交易所配置
err = s.initUserDefaultConfigs(user.ID)
if err != nil {
log.Printf("初始化用户默认配置失败: %v", err)
}
c.JSON(http.StatusOK, gin.H{
"token": token,
"user_id": user.ID,
"email": user.Email,
"message": "注册完成",
})
}
// handleLogin 处理用户登录请求
func (s *Server) handleLogin(c *gin.Context) {
var req struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 获取用户信息
user, err := s.database.GetUserByEmail(req.Email)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "邮箱或密码错误"})
return
}
// 验证密码
if !auth.CheckPassword(req.Password, user.PasswordHash) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "邮箱或密码错误"})
return
}
// 检查OTP是否已验证
if !user.OTPVerified {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "账户未完成OTP设置",
"user_id": user.ID,
"requires_otp_setup": true,
})
return
}
// 返回需要OTP验证的状态
c.JSON(http.StatusOK, gin.H{
"user_id": user.ID,
"email": user.Email,
"message": "请输入Google Authenticator验证码",
"requires_otp": true,
})
}
// handleVerifyOTP 验证OTP并完成登录
func (s *Server) handleVerifyOTP(c *gin.Context) {
var req struct {
UserID string `json:"user_id" binding:"required"`
OTPCode string `json:"otp_code" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 获取用户信息
user, err := s.database.GetUserByID(req.UserID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "用户不存在"})
return
}
// 验证OTP
if !auth.VerifyOTP(user.OTPSecret, req.OTPCode) {
c.JSON(http.StatusBadRequest, gin.H{"error": "验证码错误"})
return
}
// 生成JWT token
token, err := auth.GenerateJWT(user.ID, user.Email)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "生成token失败"})
return
}
c.JSON(http.StatusOK, gin.H{
"token": token,
"user_id": user.ID,
"email": user.Email,
"message": "登录成功",
})
}
// initUserDefaultConfigs 为新用户初始化默认的模型和交易所配置
func (s *Server) initUserDefaultConfigs(userID string) error {
// 注释掉自动创建默认配置,让用户手动添加
// 这样新用户注册后不会自动有配置项
log.Printf("用户 %s 注册完成等待手动配置AI模型和交易所", userID)
return nil
}
// handleGetSupportedModels 获取系统支持的AI模型列表
func (s *Server) handleGetSupportedModels(c *gin.Context) {
// 返回系统支持的AI模型从default用户获取
models, err := s.database.GetAIModels("default")
if err != nil {
log.Printf("❌ 获取支持的AI模型失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取支持的AI模型失败"})
return
}
c.JSON(http.StatusOK, models)
}
// handleGetSupportedExchanges 获取系统支持的交易所列表
func (s *Server) handleGetSupportedExchanges(c *gin.Context) {
// 返回系统支持的交易所从default用户获取
exchanges, err := s.database.GetExchanges("default")
if err != nil {
log.Printf("❌ 获取支持的交易所失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取支持的交易所失败"})
return
}
c.JSON(http.StatusOK, exchanges)
}
// Start 启动服务器
func (s *Server) Start() error {
addr := fmt.Sprintf(":%d", s.port)