mirror of
https://github.com/NoFxAiOS/nofx.git
synced 2026-06-06 05:51:19 +08:00
account system、custom prompt
This commit is contained in:
474
api/server.go
474
api/server.go
@@ -4,11 +4,14 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"nofx/auth"
|
||||
"nofx/config"
|
||||
"nofx/manager"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Server HTTP API服务器
|
||||
@@ -66,30 +69,51 @@ func (s *Server) setupRoutes() {
|
||||
// API路由组
|
||||
api := s.router.Group("/api")
|
||||
{
|
||||
// AI交易员管理
|
||||
api.GET("/traders", s.handleTraderList)
|
||||
api.POST("/traders", s.handleCreateTrader)
|
||||
api.DELETE("/traders/:id", s.handleDeleteTrader)
|
||||
api.POST("/traders/:id/start", s.handleStartTrader)
|
||||
api.POST("/traders/:id/stop", s.handleStopTrader)
|
||||
// 认证相关路由(无需认证)
|
||||
api.POST("/register", s.handleRegister)
|
||||
api.POST("/login", s.handleLogin)
|
||||
api.POST("/verify-otp", s.handleVerifyOTP)
|
||||
api.POST("/complete-registration", s.handleCompleteRegistration)
|
||||
|
||||
// 系统支持的模型和交易所(无需认证)
|
||||
api.GET("/supported-models", s.handleGetSupportedModels)
|
||||
api.GET("/supported-exchanges", s.handleGetSupportedExchanges)
|
||||
|
||||
// 系统配置(无需认证)
|
||||
api.GET("/config", s.handleGetSystemConfig)
|
||||
|
||||
// AI模型配置
|
||||
api.GET("/models", s.handleGetModelConfigs)
|
||||
api.PUT("/models", s.handleUpdateModelConfigs)
|
||||
// 需要认证的路由
|
||||
protected := api.Group("/", s.authMiddleware())
|
||||
{
|
||||
// AI交易员管理
|
||||
protected.GET("/traders", s.handleTraderList)
|
||||
protected.POST("/traders", s.handleCreateTrader)
|
||||
protected.DELETE("/traders/:id", s.handleDeleteTrader)
|
||||
protected.POST("/traders/:id/start", s.handleStartTrader)
|
||||
protected.POST("/traders/:id/stop", s.handleStopTrader)
|
||||
protected.PUT("/traders/:id/prompt", s.handleUpdateTraderPrompt)
|
||||
|
||||
// 交易所配置
|
||||
api.GET("/exchanges", s.handleGetExchangeConfigs)
|
||||
api.PUT("/exchanges", s.handleUpdateExchangeConfigs)
|
||||
// AI模型配置
|
||||
protected.GET("/models", s.handleGetModelConfigs)
|
||||
protected.PUT("/models", s.handleUpdateModelConfigs)
|
||||
|
||||
// 指定trader的数据(使用query参数 ?trader_id=xxx)
|
||||
api.GET("/status", s.handleStatus)
|
||||
api.GET("/account", s.handleAccount)
|
||||
api.GET("/positions", s.handlePositions)
|
||||
api.GET("/decisions", s.handleDecisions)
|
||||
api.GET("/decisions/latest", s.handleLatestDecisions)
|
||||
api.GET("/statistics", s.handleStatistics)
|
||||
api.GET("/equity-history", s.handleEquityHistory)
|
||||
api.GET("/performance", s.handlePerformance)
|
||||
// 交易所配置
|
||||
protected.GET("/exchanges", s.handleGetExchangeConfigs)
|
||||
protected.PUT("/exchanges", s.handleUpdateExchangeConfigs)
|
||||
|
||||
// 竞赛总览
|
||||
protected.GET("/competition", s.handleCompetition)
|
||||
|
||||
// 指定trader的数据(使用query参数 ?trader_id=xxx)
|
||||
protected.GET("/status", s.handleStatus)
|
||||
protected.GET("/account", s.handleAccount)
|
||||
protected.GET("/positions", s.handlePositions)
|
||||
protected.GET("/decisions", s.handleDecisions)
|
||||
protected.GET("/decisions/latest", s.handleLatestDecisions)
|
||||
protected.GET("/statistics", s.handleStatistics)
|
||||
protected.GET("/equity-history", s.handleEquityHistory)
|
||||
protected.GET("/performance", s.handlePerformance)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,17 +125,40 @@ func (s *Server) handleHealth(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// handleGetSystemConfig 获取系统配置(客户端需要知道的配置)
|
||||
func (s *Server) handleGetSystemConfig(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"admin_mode": auth.IsAdminMode(),
|
||||
})
|
||||
}
|
||||
|
||||
// getTraderFromQuery 从query参数获取trader
|
||||
func (s *Server) getTraderFromQuery(c *gin.Context) (*manager.TraderManager, string, error) {
|
||||
userID := c.GetString("user_id")
|
||||
traderID := c.Query("trader_id")
|
||||
|
||||
// 确保用户的交易员已加载到内存中
|
||||
err := s.traderManager.LoadUserTraders(s.database, userID)
|
||||
if err != nil {
|
||||
log.Printf("⚠️ 加载用户 %s 的交易员失败: %v", userID, err)
|
||||
}
|
||||
|
||||
if traderID == "" {
|
||||
// 如果没有指定trader_id,返回第一个trader
|
||||
// 如果没有指定trader_id,返回该用户的第一个trader
|
||||
ids := s.traderManager.GetTraderIDs()
|
||||
if len(ids) == 0 {
|
||||
return nil, "", fmt.Errorf("没有可用的trader")
|
||||
}
|
||||
traderID = ids[0]
|
||||
|
||||
// 获取用户的交易员列表,优先返回用户自己的交易员
|
||||
userTraders, err := s.database.GetTraders(userID)
|
||||
if err == nil && len(userTraders) > 0 {
|
||||
traderID = userTraders[0].ID
|
||||
} else {
|
||||
traderID = ids[0]
|
||||
}
|
||||
}
|
||||
|
||||
return s.traderManager, traderID, nil
|
||||
}
|
||||
|
||||
@@ -121,6 +168,8 @@ type CreateTraderRequest struct {
|
||||
AIModelID string `json:"ai_model_id" binding:"required"`
|
||||
ExchangeID string `json:"exchange_id" binding:"required"`
|
||||
InitialBalance float64 `json:"initial_balance"`
|
||||
CustomPrompt string `json:"custom_prompt"`
|
||||
OverrideBasePrompt bool `json:"override_base_prompt"`
|
||||
}
|
||||
|
||||
type ModelConfig struct {
|
||||
@@ -150,15 +199,20 @@ type UpdateModelConfigRequest struct {
|
||||
|
||||
type UpdateExchangeConfigRequest struct {
|
||||
Exchanges map[string]struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"api_key"`
|
||||
SecretKey string `json:"secret_key"`
|
||||
Testnet bool `json:"testnet"`
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"api_key"`
|
||||
SecretKey string `json:"secret_key"`
|
||||
Testnet bool `json:"testnet"`
|
||||
HyperliquidWalletAddr string `json:"hyperliquid_wallet_addr"`
|
||||
AsterUser string `json:"aster_user"`
|
||||
AsterSigner string `json:"aster_signer"`
|
||||
AsterPrivateKey string `json:"aster_private_key"`
|
||||
} `json:"exchanges"`
|
||||
}
|
||||
|
||||
// handleCreateTrader 创建新的AI交易员
|
||||
func (s *Server) handleCreateTrader(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
var req CreateTraderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
@@ -171,10 +225,13 @@ func (s *Server) handleCreateTrader(c *gin.Context) {
|
||||
// 创建交易员配置
|
||||
trader := &config.TraderConfig{
|
||||
ID: traderID,
|
||||
UserID: userID,
|
||||
Name: req.Name,
|
||||
AIModelID: req.AIModelID,
|
||||
ExchangeID: req.ExchangeID,
|
||||
InitialBalance: req.InitialBalance,
|
||||
CustomPrompt: req.CustomPrompt,
|
||||
OverrideBasePrompt: req.OverrideBasePrompt,
|
||||
ScanIntervalMinutes: 3, // 默认3分钟
|
||||
IsRunning: false,
|
||||
}
|
||||
@@ -186,6 +243,13 @@ func (s *Server) handleCreateTrader(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 立即将新交易员加载到TraderManager中
|
||||
err = s.traderManager.LoadUserTraders(s.database, userID)
|
||||
if err != nil {
|
||||
log.Printf("⚠️ 加载用户交易员到内存失败: %v", err)
|
||||
// 这里不返回错误,因为交易员已经成功创建到数据库
|
||||
}
|
||||
|
||||
log.Printf("✓ 创建交易员成功: %s (模型: %s, 交易所: %s)", req.Name, req.AIModelID, req.ExchangeID)
|
||||
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
@@ -198,10 +262,11 @@ func (s *Server) handleCreateTrader(c *gin.Context) {
|
||||
|
||||
// handleDeleteTrader 删除交易员
|
||||
func (s *Server) handleDeleteTrader(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
traderID := c.Param("id")
|
||||
|
||||
// 从数据库删除
|
||||
err := s.database.DeleteTrader(traderID)
|
||||
err := s.database.DeleteTrader(userID, traderID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("删除交易员失败: %v", err)})
|
||||
return
|
||||
@@ -246,7 +311,8 @@ func (s *Server) handleStartTrader(c *gin.Context) {
|
||||
}()
|
||||
|
||||
// 更新数据库中的运行状态
|
||||
err = s.database.UpdateTraderStatus(traderID, true)
|
||||
userID := c.GetString("user_id")
|
||||
err = s.database.UpdateTraderStatus(userID, traderID, true)
|
||||
if err != nil {
|
||||
log.Printf("⚠️ 更新交易员状态失败: %v", err)
|
||||
}
|
||||
@@ -276,7 +342,8 @@ func (s *Server) handleStopTrader(c *gin.Context) {
|
||||
trader.Stop()
|
||||
|
||||
// 更新数据库中的运行状态
|
||||
err = s.database.UpdateTraderStatus(traderID, false)
|
||||
userID := c.GetString("user_id")
|
||||
err = s.database.UpdateTraderStatus(userID, traderID, false)
|
||||
if err != nil {
|
||||
log.Printf("⚠️ 更新交易员状态失败: %v", err)
|
||||
}
|
||||
@@ -285,19 +352,57 @@ func (s *Server) handleStopTrader(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "交易员已停止"})
|
||||
}
|
||||
|
||||
// handleUpdateTraderPrompt 更新交易员自定义Prompt
|
||||
func (s *Server) handleUpdateTraderPrompt(c *gin.Context) {
|
||||
traderID := c.Param("id")
|
||||
userID := c.GetString("user_id")
|
||||
|
||||
var req struct {
|
||||
CustomPrompt string `json:"custom_prompt"`
|
||||
OverrideBasePrompt bool `json:"override_base_prompt"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 更新数据库
|
||||
err := s.database.UpdateTraderCustomPrompt(userID, traderID, req.CustomPrompt, req.OverrideBasePrompt)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("更新自定义prompt失败: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
// 如果trader在内存中,更新其custom prompt和override设置
|
||||
trader, err := s.traderManager.GetTrader(traderID)
|
||||
if err == nil {
|
||||
trader.SetCustomPrompt(req.CustomPrompt)
|
||||
trader.SetOverrideBasePrompt(req.OverrideBasePrompt)
|
||||
log.Printf("✓ 已更新交易员 %s 的自定义prompt (覆盖基础=%v)", trader.GetName(), req.OverrideBasePrompt)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "自定义prompt已更新"})
|
||||
}
|
||||
|
||||
// handleGetModelConfigs 获取AI模型配置
|
||||
func (s *Server) handleGetModelConfigs(c *gin.Context) {
|
||||
models, err := s.database.GetAIModels()
|
||||
userID := c.GetString("user_id")
|
||||
log.Printf("🔍 查询用户 %s 的AI模型配置", userID)
|
||||
models, err := s.database.GetAIModels(userID)
|
||||
if err != nil {
|
||||
log.Printf("❌ 获取AI模型配置失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("获取AI模型配置失败: %v", err)})
|
||||
return
|
||||
}
|
||||
log.Printf("✅ 找到 %d 个AI模型配置", len(models))
|
||||
|
||||
c.JSON(http.StatusOK, models)
|
||||
}
|
||||
|
||||
// handleUpdateModelConfigs 更新AI模型配置
|
||||
func (s *Server) handleUpdateModelConfigs(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
var req UpdateModelConfigRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
@@ -306,7 +411,7 @@ func (s *Server) handleUpdateModelConfigs(c *gin.Context) {
|
||||
|
||||
// 更新每个模型的配置
|
||||
for modelID, modelData := range req.Models {
|
||||
err := s.database.UpdateAIModel(modelID, modelData.Enabled, modelData.APIKey)
|
||||
err := s.database.UpdateAIModel(userID, modelID, modelData.Enabled, modelData.APIKey)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("更新模型 %s 失败: %v", modelID, err)})
|
||||
return
|
||||
@@ -319,17 +424,22 @@ func (s *Server) handleUpdateModelConfigs(c *gin.Context) {
|
||||
|
||||
// handleGetExchangeConfigs 获取交易所配置
|
||||
func (s *Server) handleGetExchangeConfigs(c *gin.Context) {
|
||||
exchanges, err := s.database.GetExchanges()
|
||||
userID := c.GetString("user_id")
|
||||
log.Printf("🔍 查询用户 %s 的交易所配置", userID)
|
||||
exchanges, err := s.database.GetExchanges(userID)
|
||||
if err != nil {
|
||||
log.Printf("❌ 获取交易所配置失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("获取交易所配置失败: %v", err)})
|
||||
return
|
||||
}
|
||||
log.Printf("✅ 找到 %d 个交易所配置", len(exchanges))
|
||||
|
||||
c.JSON(http.StatusOK, exchanges)
|
||||
}
|
||||
|
||||
// handleUpdateExchangeConfigs 更新交易所配置
|
||||
func (s *Server) handleUpdateExchangeConfigs(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
var req UpdateExchangeConfigRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
@@ -338,7 +448,7 @@ func (s *Server) handleUpdateExchangeConfigs(c *gin.Context) {
|
||||
|
||||
// 更新每个交易所的配置
|
||||
for exchangeID, exchangeData := range req.Exchanges {
|
||||
err := s.database.UpdateExchange(exchangeID, exchangeData.Enabled, exchangeData.APIKey, exchangeData.SecretKey, exchangeData.Testnet)
|
||||
err := s.database.UpdateExchange(userID, exchangeID, exchangeData.Enabled, exchangeData.APIKey, exchangeData.SecretKey, exchangeData.Testnet, exchangeData.HyperliquidWalletAddr, exchangeData.AsterUser, exchangeData.AsterSigner, exchangeData.AsterPrivateKey)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("更新交易所 %s 失败: %v", exchangeID, err)})
|
||||
return
|
||||
@@ -351,7 +461,8 @@ func (s *Server) handleUpdateExchangeConfigs(c *gin.Context) {
|
||||
|
||||
// handleTraderList trader列表
|
||||
func (s *Server) handleTraderList(c *gin.Context) {
|
||||
traders, err := s.database.GetTraders()
|
||||
userID := c.GetString("user_id")
|
||||
traders, err := s.database.GetTraders(userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("获取交易员列表失败: %v", err)})
|
||||
return
|
||||
@@ -539,6 +650,27 @@ func (s *Server) handleStatistics(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, stats)
|
||||
}
|
||||
|
||||
// handleCompetition 竞赛总览(对比所有trader)
|
||||
func (s *Server) handleCompetition(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
|
||||
// 确保用户的交易员已加载到内存中
|
||||
err := s.traderManager.LoadUserTraders(s.database, userID)
|
||||
if err != nil {
|
||||
log.Printf("⚠️ 加载用户 %s 的交易员失败: %v", userID, err)
|
||||
}
|
||||
|
||||
competition, err := s.traderManager.GetCompetitionData(userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": fmt.Sprintf("获取竞赛数据失败: %v", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, competition)
|
||||
}
|
||||
|
||||
// handleEquityHistory 收益率历史数据
|
||||
func (s *Server) handleEquityHistory(c *gin.Context) {
|
||||
_, traderID, err := s.getTraderFromQuery(c)
|
||||
@@ -652,6 +784,278 @@ func (s *Server) handlePerformance(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, performance)
|
||||
}
|
||||
|
||||
// authMiddleware JWT认证中间件
|
||||
func (s *Server) authMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 如果是管理员模式,直接使用admin用户
|
||||
if auth.IsAdminMode() {
|
||||
c.Set("user_id", "admin")
|
||||
c.Set("email", "admin@localhost")
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "缺少Authorization头"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查Bearer token格式
|
||||
tokenParts := strings.Split(authHeader, " ")
|
||||
if len(tokenParts) != 2 || tokenParts[0] != "Bearer" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "无效的Authorization格式"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 验证JWT token
|
||||
claims, err := auth.ValidateJWT(tokenParts[1])
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "无效的token: " + err.Error()})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 将用户信息存储到上下文中
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("email", claims.Email)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// handleRegister 处理用户注册请求
|
||||
func (s *Server) handleRegister(c *gin.Context) {
|
||||
var req struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
_, err := s.database.GetUserByEmail(req.Email)
|
||||
if err == nil {
|
||||
c.JSON(http.StatusConflict, gin.H{"error": "邮箱已被注册"})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成密码哈希
|
||||
passwordHash, err := auth.HashPassword(req.Password)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "密码处理失败"})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成OTP密钥
|
||||
otpSecret, err := auth.GenerateOTPSecret()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "OTP密钥生成失败"})
|
||||
return
|
||||
}
|
||||
|
||||
// 创建用户(未验证OTP状态)
|
||||
userID := uuid.New().String()
|
||||
user := &config.User{
|
||||
ID: userID,
|
||||
Email: req.Email,
|
||||
PasswordHash: passwordHash,
|
||||
OTPSecret: otpSecret,
|
||||
OTPVerified: false,
|
||||
}
|
||||
|
||||
err = s.database.CreateUser(user)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建用户失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 返回OTP设置信息
|
||||
qrCodeURL := auth.GetOTPQRCodeURL(otpSecret, req.Email)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"user_id": userID,
|
||||
"email": req.Email,
|
||||
"otp_secret": otpSecret,
|
||||
"qr_code_url": qrCodeURL,
|
||||
"message": "请使用Google Authenticator扫描二维码并验证OTP",
|
||||
})
|
||||
}
|
||||
|
||||
// handleCompleteRegistration 完成注册(验证OTP)
|
||||
func (s *Server) handleCompleteRegistration(c *gin.Context) {
|
||||
var req struct {
|
||||
UserID string `json:"user_id" binding:"required"`
|
||||
OTPCode string `json:"otp_code" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
user, err := s.database.GetUserByID(req.UserID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "用户不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证OTP
|
||||
if !auth.VerifyOTP(user.OTPSecret, req.OTPCode) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "OTP验证码错误"})
|
||||
return
|
||||
}
|
||||
|
||||
// 更新用户OTP验证状态
|
||||
err = s.database.UpdateUserOTPVerified(req.UserID, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "更新用户状态失败"})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成JWT token
|
||||
token, err := auth.GenerateJWT(user.ID, user.Email)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "生成token失败"})
|
||||
return
|
||||
}
|
||||
|
||||
// 初始化用户的默认模型和交易所配置
|
||||
err = s.initUserDefaultConfigs(user.ID)
|
||||
if err != nil {
|
||||
log.Printf("初始化用户默认配置失败: %v", err)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"token": token,
|
||||
"user_id": user.ID,
|
||||
"email": user.Email,
|
||||
"message": "注册完成",
|
||||
})
|
||||
}
|
||||
|
||||
// handleLogin 处理用户登录请求
|
||||
func (s *Server) handleLogin(c *gin.Context) {
|
||||
var req struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
user, err := s.database.GetUserByEmail(req.Email)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "邮箱或密码错误"})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证密码
|
||||
if !auth.CheckPassword(req.Password, user.PasswordHash) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "邮箱或密码错误"})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查OTP是否已验证
|
||||
if !user.OTPVerified {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "账户未完成OTP设置",
|
||||
"user_id": user.ID,
|
||||
"requires_otp_setup": true,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 返回需要OTP验证的状态
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"user_id": user.ID,
|
||||
"email": user.Email,
|
||||
"message": "请输入Google Authenticator验证码",
|
||||
"requires_otp": true,
|
||||
})
|
||||
}
|
||||
|
||||
// handleVerifyOTP 验证OTP并完成登录
|
||||
func (s *Server) handleVerifyOTP(c *gin.Context) {
|
||||
var req struct {
|
||||
UserID string `json:"user_id" binding:"required"`
|
||||
OTPCode string `json:"otp_code" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
user, err := s.database.GetUserByID(req.UserID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "用户不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证OTP
|
||||
if !auth.VerifyOTP(user.OTPSecret, req.OTPCode) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "验证码错误"})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成JWT token
|
||||
token, err := auth.GenerateJWT(user.ID, user.Email)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "生成token失败"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"token": token,
|
||||
"user_id": user.ID,
|
||||
"email": user.Email,
|
||||
"message": "登录成功",
|
||||
})
|
||||
}
|
||||
|
||||
// initUserDefaultConfigs 为新用户初始化默认的模型和交易所配置
|
||||
func (s *Server) initUserDefaultConfigs(userID string) error {
|
||||
// 注释掉自动创建默认配置,让用户手动添加
|
||||
// 这样新用户注册后不会自动有配置项
|
||||
log.Printf("用户 %s 注册完成,等待手动配置AI模型和交易所", userID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleGetSupportedModels 获取系统支持的AI模型列表
|
||||
func (s *Server) handleGetSupportedModels(c *gin.Context) {
|
||||
// 返回系统支持的AI模型(从default用户获取)
|
||||
models, err := s.database.GetAIModels("default")
|
||||
if err != nil {
|
||||
log.Printf("❌ 获取支持的AI模型失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取支持的AI模型失败"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models)
|
||||
}
|
||||
|
||||
// handleGetSupportedExchanges 获取系统支持的交易所列表
|
||||
func (s *Server) handleGetSupportedExchanges(c *gin.Context) {
|
||||
// 返回系统支持的交易所(从default用户获取)
|
||||
exchanges, err := s.database.GetExchanges("default")
|
||||
if err != nil {
|
||||
log.Printf("❌ 获取支持的交易所失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取支持的交易所失败"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, exchanges)
|
||||
}
|
||||
|
||||
// Start 启动服务器
|
||||
func (s *Server) Start() error {
|
||||
addr := fmt.Sprintf(":%d", s.port)
|
||||
|
||||
Reference in New Issue
Block a user