mirror of
https://github.com/NoFxAiOS/nofx.git
synced 2026-07-06 04:20:59 +08:00
feat: enhance backtest with real-time positions, P&L fixes, and strategy integration
- Add real-time position display with unrealized P&L during backtest - Fix P&L calculation by tracking accumulated opening fees - Add strategy coin source resolution (AI500, OI Top, mixed) - Infer AI provider from model name for better compatibility - Cap position size to available margin to prevent insufficient cash errors - Fix trade markers on K-line chart (long/short instead of buy/sell) - Add QuantData and OI ranking to backtest decision context
This commit is contained in:
301
api/backtest.go
301
api/backtest.go
@@ -3,6 +3,7 @@ package api
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -12,6 +13,9 @@ import (
|
||||
"time"
|
||||
|
||||
"nofx/backtest"
|
||||
"nofx/logger"
|
||||
"nofx/market"
|
||||
"nofx/provider"
|
||||
"nofx/store"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -32,6 +36,7 @@ func (s *Server) registerBacktestRoutes(router *gin.RouterGroup) {
|
||||
router.GET("/trace", s.handleBacktestTrace)
|
||||
router.GET("/decisions", s.handleBacktestDecisions)
|
||||
router.GET("/export", s.handleBacktestExport)
|
||||
router.GET("/klines", s.handleBacktestKlines)
|
||||
}
|
||||
|
||||
type backtestStartRequest struct {
|
||||
@@ -65,11 +70,54 @@ func (s *Server) handleBacktestStart(c *gin.Context) {
|
||||
}
|
||||
cfg.CustomPrompt = strings.TrimSpace(cfg.CustomPrompt)
|
||||
cfg.UserID = normalizeUserID(c.GetString("user_id"))
|
||||
|
||||
logger.Infof("📊 Backtest request - symbols from request: %v (count=%d), strategyID: %s",
|
||||
cfg.Symbols, len(cfg.Symbols), cfg.StrategyID)
|
||||
|
||||
// Load strategy config if strategy_id is provided
|
||||
if cfg.StrategyID != "" {
|
||||
strategy, err := s.store.Strategy().Get(cfg.UserID, cfg.StrategyID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("failed to load strategy: %v", err)})
|
||||
return
|
||||
}
|
||||
if strategy == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("strategy not found: %s", cfg.StrategyID)})
|
||||
return
|
||||
}
|
||||
var strategyConfig store.StrategyConfig
|
||||
if err := json.Unmarshal([]byte(strategy.Config), &strategyConfig); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("failed to parse strategy config: %v", err)})
|
||||
return
|
||||
}
|
||||
cfg.SetLoadedStrategy(&strategyConfig)
|
||||
logger.Infof("📊 Backtest using saved strategy: %s (%s)", strategy.Name, strategy.ID)
|
||||
logger.Infof("📊 Strategy coin source: type=%s, use_coin_pool=%v, use_oi_top=%v, static_coins=%v",
|
||||
strategyConfig.CoinSource.SourceType,
|
||||
strategyConfig.CoinSource.UseCoinPool,
|
||||
strategyConfig.CoinSource.UseOITop,
|
||||
strategyConfig.CoinSource.StaticCoins)
|
||||
|
||||
// If no symbols provided, fetch from strategy's coin source
|
||||
if len(cfg.Symbols) == 0 {
|
||||
symbols, err := s.resolveStrategyCoins(&strategyConfig)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("failed to resolve coins from strategy: %v", err)})
|
||||
return
|
||||
}
|
||||
cfg.Symbols = symbols
|
||||
logger.Infof("📊 Resolved %d coins from strategy: %v", len(symbols), symbols)
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.hydrateBacktestAIConfig(&cfg); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
logger.Infof("📊 Starting backtest with final config: runID=%s, symbols=%v (count=%d), strategyID=%s",
|
||||
cfg.RunID, cfg.Symbols, len(cfg.Symbols), cfg.StrategyID)
|
||||
|
||||
runner, err := s.backtestManager.Start(context.Background(), cfg)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
@@ -443,6 +491,89 @@ func (s *Server) handleBacktestExport(c *gin.Context) {
|
||||
c.FileAttachment(path, filename)
|
||||
}
|
||||
|
||||
func (s *Server) handleBacktestKlines(c *gin.Context) {
|
||||
if s.backtestManager == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "backtest manager unavailable"})
|
||||
return
|
||||
}
|
||||
userID := normalizeUserID(c.GetString("user_id"))
|
||||
runID := c.Query("run_id")
|
||||
symbol := c.Query("symbol")
|
||||
timeframe := c.Query("timeframe")
|
||||
|
||||
if runID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "run_id is required"})
|
||||
return
|
||||
}
|
||||
if symbol == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "symbol is required"})
|
||||
return
|
||||
}
|
||||
|
||||
meta, err := s.ensureBacktestRunOwnership(runID, userID)
|
||||
if writeBacktestAccessError(c, err) {
|
||||
return
|
||||
}
|
||||
|
||||
// Load config to get time range
|
||||
cfg, err := backtest.LoadConfig(runID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "failed to load backtest config"})
|
||||
return
|
||||
}
|
||||
|
||||
// Use decision timeframe if not specified
|
||||
if timeframe == "" {
|
||||
timeframe = cfg.DecisionTimeframe
|
||||
if timeframe == "" {
|
||||
timeframe = "15m"
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch klines for the backtest time range
|
||||
startTime := time.Unix(cfg.StartTS, 0)
|
||||
endTime := time.Unix(cfg.EndTS, 0)
|
||||
|
||||
klines, err := market.GetKlinesRange(symbol, timeframe, startTime, endTime)
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to fetch klines for %s: %v", symbol, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to fetch klines: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
// Convert to response format
|
||||
type KlineResponse struct {
|
||||
Time int64 `json:"time"`
|
||||
Open float64 `json:"open"`
|
||||
High float64 `json:"high"`
|
||||
Low float64 `json:"low"`
|
||||
Close float64 `json:"close"`
|
||||
Volume float64 `json:"volume"`
|
||||
}
|
||||
|
||||
result := make([]KlineResponse, len(klines))
|
||||
for i, k := range klines {
|
||||
result[i] = KlineResponse{
|
||||
Time: k.OpenTime / 1000, // Convert to seconds for lightweight-charts
|
||||
Open: k.Open,
|
||||
High: k.High,
|
||||
Low: k.Low,
|
||||
Close: k.Close,
|
||||
Volume: k.Volume,
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"symbol": symbol,
|
||||
"timeframe": timeframe,
|
||||
"start_ts": cfg.StartTS,
|
||||
"end_ts": cfg.EndTS,
|
||||
"count": len(result),
|
||||
"klines": result,
|
||||
"run_id": meta.RunID,
|
||||
})
|
||||
}
|
||||
|
||||
func queryInt(c *gin.Context, name string, fallback int) int {
|
||||
if value := c.Query(name); value != "" {
|
||||
if v, err := strconv.Atoi(value); err == nil {
|
||||
@@ -498,6 +629,155 @@ func writeBacktestAccessError(c *gin.Context, err error) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// resolveStrategyCoins fetches coins based on strategy's coin source configuration
|
||||
func (s *Server) resolveStrategyCoins(strategyConfig *store.StrategyConfig) ([]string, error) {
|
||||
if strategyConfig == nil {
|
||||
return nil, fmt.Errorf("strategy config is nil")
|
||||
}
|
||||
|
||||
coinSource := strategyConfig.CoinSource
|
||||
var symbols []string
|
||||
symbolSet := make(map[string]bool)
|
||||
|
||||
// Set custom API URLs if provided
|
||||
if coinSource.CoinPoolAPIURL != "" {
|
||||
provider.SetCoinPoolAPI(coinSource.CoinPoolAPIURL)
|
||||
}
|
||||
if coinSource.OITopAPIURL != "" {
|
||||
provider.SetOITopAPI(coinSource.OITopAPIURL)
|
||||
}
|
||||
|
||||
// Handle empty source_type - check flags for backward compatibility
|
||||
sourceType := coinSource.SourceType
|
||||
if sourceType == "" {
|
||||
if coinSource.UseCoinPool && coinSource.UseOITop {
|
||||
sourceType = "mixed"
|
||||
} else if coinSource.UseCoinPool {
|
||||
sourceType = "coinpool"
|
||||
} else if coinSource.UseOITop {
|
||||
sourceType = "oi_top"
|
||||
} else if len(coinSource.StaticCoins) > 0 {
|
||||
sourceType = "static"
|
||||
} else {
|
||||
return nil, fmt.Errorf("strategy has no coin source configured")
|
||||
}
|
||||
logger.Infof("📊 Inferred source_type=%s from flags", sourceType)
|
||||
}
|
||||
|
||||
switch sourceType {
|
||||
case "static":
|
||||
for _, sym := range coinSource.StaticCoins {
|
||||
sym = market.Normalize(sym)
|
||||
if !symbolSet[sym] {
|
||||
symbols = append(symbols, sym)
|
||||
symbolSet[sym] = true
|
||||
}
|
||||
}
|
||||
|
||||
case "coinpool":
|
||||
limit := coinSource.CoinPoolLimit
|
||||
if limit <= 0 {
|
||||
limit = 30
|
||||
}
|
||||
logger.Infof("📊 Fetching AI500 coins with limit=%d", limit)
|
||||
coins, err := provider.GetTopRatedCoins(limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get AI500 coins: %w", err)
|
||||
}
|
||||
logger.Infof("📊 Got %d coins from AI500: %v", len(coins), coins)
|
||||
for _, sym := range coins {
|
||||
sym = market.Normalize(sym)
|
||||
if !symbolSet[sym] {
|
||||
symbols = append(symbols, sym)
|
||||
symbolSet[sym] = true
|
||||
}
|
||||
}
|
||||
|
||||
case "oi_top":
|
||||
coins, err := provider.GetOITopSymbols()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get OI Top coins: %w", err)
|
||||
}
|
||||
limit := coinSource.OITopLimit
|
||||
if limit <= 0 || limit > len(coins) {
|
||||
limit = len(coins)
|
||||
}
|
||||
for i, sym := range coins {
|
||||
if i >= limit {
|
||||
break
|
||||
}
|
||||
sym = market.Normalize(sym)
|
||||
if !symbolSet[sym] {
|
||||
symbols = append(symbols, sym)
|
||||
symbolSet[sym] = true
|
||||
}
|
||||
}
|
||||
|
||||
case "mixed":
|
||||
// Get from coin pool
|
||||
if coinSource.UseCoinPool {
|
||||
limit := coinSource.CoinPoolLimit
|
||||
if limit <= 0 {
|
||||
limit = 30
|
||||
}
|
||||
coins, err := provider.GetTopRatedCoins(limit)
|
||||
if err != nil {
|
||||
logger.Warnf("Failed to get AI500 coins: %v", err)
|
||||
} else {
|
||||
for _, sym := range coins {
|
||||
sym = market.Normalize(sym)
|
||||
if !symbolSet[sym] {
|
||||
symbols = append(symbols, sym)
|
||||
symbolSet[sym] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get from OI Top
|
||||
if coinSource.UseOITop {
|
||||
coins, err := provider.GetOITopSymbols()
|
||||
if err != nil {
|
||||
logger.Warnf("Failed to get OI Top coins: %v", err)
|
||||
} else {
|
||||
limit := coinSource.OITopLimit
|
||||
if limit <= 0 || limit > len(coins) {
|
||||
limit = len(coins)
|
||||
}
|
||||
for i, sym := range coins {
|
||||
if i >= limit {
|
||||
break
|
||||
}
|
||||
sym = market.Normalize(sym)
|
||||
if !symbolSet[sym] {
|
||||
symbols = append(symbols, sym)
|
||||
symbolSet[sym] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add static coins
|
||||
for _, sym := range coinSource.StaticCoins {
|
||||
sym = market.Normalize(sym)
|
||||
if !symbolSet[sym] {
|
||||
symbols = append(symbols, sym)
|
||||
symbolSet[sym] = true
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown coin source type: %s", sourceType)
|
||||
}
|
||||
|
||||
if len(symbols) == 0 {
|
||||
return nil, fmt.Errorf("no coins resolved from strategy")
|
||||
}
|
||||
|
||||
logger.Infof("📊 Final resolved symbols: %d coins - %v", len(symbols), symbols)
|
||||
return symbols, nil
|
||||
}
|
||||
|
||||
func (s *Server) resolveBacktestAIConfig(cfg *backtest.BacktestConfig, userID string) error {
|
||||
if cfg == nil {
|
||||
return fmt.Errorf("config is nil")
|
||||
@@ -549,7 +829,26 @@ func (s *Server) hydrateBacktestAIConfig(cfg *backtest.BacktestConfig) error {
|
||||
return fmt.Errorf("AI model %s is missing API Key, please configure it in the system first", model.Name)
|
||||
}
|
||||
|
||||
cfg.AICfg.Provider = strings.ToLower(model.Provider)
|
||||
provider := strings.ToLower(strings.TrimSpace(model.Provider))
|
||||
// Ensure provider is never empty or "inherit" - infer from model name if needed
|
||||
if provider == "" || provider == "inherit" {
|
||||
modelNameLower := strings.ToLower(model.Name)
|
||||
if strings.Contains(modelNameLower, "claude") || strings.Contains(modelNameLower, "anthropic") {
|
||||
provider = "anthropic"
|
||||
} else if strings.Contains(modelNameLower, "gpt") || strings.Contains(modelNameLower, "openai") {
|
||||
provider = "openai"
|
||||
} else if strings.Contains(modelNameLower, "gemini") || strings.Contains(modelNameLower, "google") {
|
||||
provider = "google"
|
||||
} else if strings.Contains(modelNameLower, "deepseek") {
|
||||
provider = "deepseek"
|
||||
} else if model.CustomAPIURL != "" {
|
||||
provider = "custom"
|
||||
} else {
|
||||
provider = "openai" // default fallback
|
||||
}
|
||||
logger.Infof("📊 Inferred AI provider '%s' from model name '%s'", provider, model.Name)
|
||||
}
|
||||
cfg.AICfg.Provider = provider
|
||||
cfg.AICfg.APIKey = apiKey
|
||||
cfg.AICfg.BaseURL = strings.TrimSpace(model.CustomAPIURL)
|
||||
modelName := strings.TrimSpace(model.CustomModelName)
|
||||
|
||||
Reference in New Issue
Block a user