feat: enhance backtest with real-time positions, P&L fixes, and strategy integration

- Add real-time position display with unrealized P&L during backtest
- Fix P&L calculation by tracking accumulated opening fees
- Add strategy coin source resolution (AI500, OI Top, mixed)
- Infer AI provider from model name for better compatibility
- Cap position size to available margin to prevent insufficient cash errors
- Fix trade markers on K-line chart (long/short instead of buy/sell)
- Add QuantData and OI ranking to backtest decision context
This commit is contained in:
tinkle-community
2025-12-20 01:10:11 +08:00
parent 5534861fe5
commit e2d702c662
9 changed files with 1144 additions and 62 deletions

View File

@@ -3,6 +3,7 @@ package api
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"net/http"
@@ -12,6 +13,9 @@ import (
"time"
"nofx/backtest"
"nofx/logger"
"nofx/market"
"nofx/provider"
"nofx/store"
"github.com/gin-gonic/gin"
@@ -32,6 +36,7 @@ func (s *Server) registerBacktestRoutes(router *gin.RouterGroup) {
router.GET("/trace", s.handleBacktestTrace)
router.GET("/decisions", s.handleBacktestDecisions)
router.GET("/export", s.handleBacktestExport)
router.GET("/klines", s.handleBacktestKlines)
}
type backtestStartRequest struct {
@@ -65,11 +70,54 @@ func (s *Server) handleBacktestStart(c *gin.Context) {
}
cfg.CustomPrompt = strings.TrimSpace(cfg.CustomPrompt)
cfg.UserID = normalizeUserID(c.GetString("user_id"))
logger.Infof("📊 Backtest request - symbols from request: %v (count=%d), strategyID: %s",
cfg.Symbols, len(cfg.Symbols), cfg.StrategyID)
// Load strategy config if strategy_id is provided
if cfg.StrategyID != "" {
strategy, err := s.store.Strategy().Get(cfg.UserID, cfg.StrategyID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("failed to load strategy: %v", err)})
return
}
if strategy == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("strategy not found: %s", cfg.StrategyID)})
return
}
var strategyConfig store.StrategyConfig
if err := json.Unmarshal([]byte(strategy.Config), &strategyConfig); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("failed to parse strategy config: %v", err)})
return
}
cfg.SetLoadedStrategy(&strategyConfig)
logger.Infof("📊 Backtest using saved strategy: %s (%s)", strategy.Name, strategy.ID)
logger.Infof("📊 Strategy coin source: type=%s, use_coin_pool=%v, use_oi_top=%v, static_coins=%v",
strategyConfig.CoinSource.SourceType,
strategyConfig.CoinSource.UseCoinPool,
strategyConfig.CoinSource.UseOITop,
strategyConfig.CoinSource.StaticCoins)
// If no symbols provided, fetch from strategy's coin source
if len(cfg.Symbols) == 0 {
symbols, err := s.resolveStrategyCoins(&strategyConfig)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("failed to resolve coins from strategy: %v", err)})
return
}
cfg.Symbols = symbols
logger.Infof("📊 Resolved %d coins from strategy: %v", len(symbols), symbols)
}
}
if err := s.hydrateBacktestAIConfig(&cfg); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
logger.Infof("📊 Starting backtest with final config: runID=%s, symbols=%v (count=%d), strategyID=%s",
cfg.RunID, cfg.Symbols, len(cfg.Symbols), cfg.StrategyID)
runner, err := s.backtestManager.Start(context.Background(), cfg)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -443,6 +491,89 @@ func (s *Server) handleBacktestExport(c *gin.Context) {
c.FileAttachment(path, filename)
}
func (s *Server) handleBacktestKlines(c *gin.Context) {
if s.backtestManager == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "backtest manager unavailable"})
return
}
userID := normalizeUserID(c.GetString("user_id"))
runID := c.Query("run_id")
symbol := c.Query("symbol")
timeframe := c.Query("timeframe")
if runID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "run_id is required"})
return
}
if symbol == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "symbol is required"})
return
}
meta, err := s.ensureBacktestRunOwnership(runID, userID)
if writeBacktestAccessError(c, err) {
return
}
// Load config to get time range
cfg, err := backtest.LoadConfig(runID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "failed to load backtest config"})
return
}
// Use decision timeframe if not specified
if timeframe == "" {
timeframe = cfg.DecisionTimeframe
if timeframe == "" {
timeframe = "15m"
}
}
// Fetch klines for the backtest time range
startTime := time.Unix(cfg.StartTS, 0)
endTime := time.Unix(cfg.EndTS, 0)
klines, err := market.GetKlinesRange(symbol, timeframe, startTime, endTime)
if err != nil {
logger.Errorf("Failed to fetch klines for %s: %v", symbol, err)
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to fetch klines: %v", err)})
return
}
// Convert to response format
type KlineResponse struct {
Time int64 `json:"time"`
Open float64 `json:"open"`
High float64 `json:"high"`
Low float64 `json:"low"`
Close float64 `json:"close"`
Volume float64 `json:"volume"`
}
result := make([]KlineResponse, len(klines))
for i, k := range klines {
result[i] = KlineResponse{
Time: k.OpenTime / 1000, // Convert to seconds for lightweight-charts
Open: k.Open,
High: k.High,
Low: k.Low,
Close: k.Close,
Volume: k.Volume,
}
}
c.JSON(http.StatusOK, gin.H{
"symbol": symbol,
"timeframe": timeframe,
"start_ts": cfg.StartTS,
"end_ts": cfg.EndTS,
"count": len(result),
"klines": result,
"run_id": meta.RunID,
})
}
func queryInt(c *gin.Context, name string, fallback int) int {
if value := c.Query(name); value != "" {
if v, err := strconv.Atoi(value); err == nil {
@@ -498,6 +629,155 @@ func writeBacktestAccessError(c *gin.Context, err error) bool {
return true
}
// resolveStrategyCoins fetches coins based on strategy's coin source configuration
func (s *Server) resolveStrategyCoins(strategyConfig *store.StrategyConfig) ([]string, error) {
if strategyConfig == nil {
return nil, fmt.Errorf("strategy config is nil")
}
coinSource := strategyConfig.CoinSource
var symbols []string
symbolSet := make(map[string]bool)
// Set custom API URLs if provided
if coinSource.CoinPoolAPIURL != "" {
provider.SetCoinPoolAPI(coinSource.CoinPoolAPIURL)
}
if coinSource.OITopAPIURL != "" {
provider.SetOITopAPI(coinSource.OITopAPIURL)
}
// Handle empty source_type - check flags for backward compatibility
sourceType := coinSource.SourceType
if sourceType == "" {
if coinSource.UseCoinPool && coinSource.UseOITop {
sourceType = "mixed"
} else if coinSource.UseCoinPool {
sourceType = "coinpool"
} else if coinSource.UseOITop {
sourceType = "oi_top"
} else if len(coinSource.StaticCoins) > 0 {
sourceType = "static"
} else {
return nil, fmt.Errorf("strategy has no coin source configured")
}
logger.Infof("📊 Inferred source_type=%s from flags", sourceType)
}
switch sourceType {
case "static":
for _, sym := range coinSource.StaticCoins {
sym = market.Normalize(sym)
if !symbolSet[sym] {
symbols = append(symbols, sym)
symbolSet[sym] = true
}
}
case "coinpool":
limit := coinSource.CoinPoolLimit
if limit <= 0 {
limit = 30
}
logger.Infof("📊 Fetching AI500 coins with limit=%d", limit)
coins, err := provider.GetTopRatedCoins(limit)
if err != nil {
return nil, fmt.Errorf("failed to get AI500 coins: %w", err)
}
logger.Infof("📊 Got %d coins from AI500: %v", len(coins), coins)
for _, sym := range coins {
sym = market.Normalize(sym)
if !symbolSet[sym] {
symbols = append(symbols, sym)
symbolSet[sym] = true
}
}
case "oi_top":
coins, err := provider.GetOITopSymbols()
if err != nil {
return nil, fmt.Errorf("failed to get OI Top coins: %w", err)
}
limit := coinSource.OITopLimit
if limit <= 0 || limit > len(coins) {
limit = len(coins)
}
for i, sym := range coins {
if i >= limit {
break
}
sym = market.Normalize(sym)
if !symbolSet[sym] {
symbols = append(symbols, sym)
symbolSet[sym] = true
}
}
case "mixed":
// Get from coin pool
if coinSource.UseCoinPool {
limit := coinSource.CoinPoolLimit
if limit <= 0 {
limit = 30
}
coins, err := provider.GetTopRatedCoins(limit)
if err != nil {
logger.Warnf("Failed to get AI500 coins: %v", err)
} else {
for _, sym := range coins {
sym = market.Normalize(sym)
if !symbolSet[sym] {
symbols = append(symbols, sym)
symbolSet[sym] = true
}
}
}
}
// Get from OI Top
if coinSource.UseOITop {
coins, err := provider.GetOITopSymbols()
if err != nil {
logger.Warnf("Failed to get OI Top coins: %v", err)
} else {
limit := coinSource.OITopLimit
if limit <= 0 || limit > len(coins) {
limit = len(coins)
}
for i, sym := range coins {
if i >= limit {
break
}
sym = market.Normalize(sym)
if !symbolSet[sym] {
symbols = append(symbols, sym)
symbolSet[sym] = true
}
}
}
}
// Add static coins
for _, sym := range coinSource.StaticCoins {
sym = market.Normalize(sym)
if !symbolSet[sym] {
symbols = append(symbols, sym)
symbolSet[sym] = true
}
}
default:
return nil, fmt.Errorf("unknown coin source type: %s", sourceType)
}
if len(symbols) == 0 {
return nil, fmt.Errorf("no coins resolved from strategy")
}
logger.Infof("📊 Final resolved symbols: %d coins - %v", len(symbols), symbols)
return symbols, nil
}
func (s *Server) resolveBacktestAIConfig(cfg *backtest.BacktestConfig, userID string) error {
if cfg == nil {
return fmt.Errorf("config is nil")
@@ -549,7 +829,26 @@ func (s *Server) hydrateBacktestAIConfig(cfg *backtest.BacktestConfig) error {
return fmt.Errorf("AI model %s is missing API Key, please configure it in the system first", model.Name)
}
cfg.AICfg.Provider = strings.ToLower(model.Provider)
provider := strings.ToLower(strings.TrimSpace(model.Provider))
// Ensure provider is never empty or "inherit" - infer from model name if needed
if provider == "" || provider == "inherit" {
modelNameLower := strings.ToLower(model.Name)
if strings.Contains(modelNameLower, "claude") || strings.Contains(modelNameLower, "anthropic") {
provider = "anthropic"
} else if strings.Contains(modelNameLower, "gpt") || strings.Contains(modelNameLower, "openai") {
provider = "openai"
} else if strings.Contains(modelNameLower, "gemini") || strings.Contains(modelNameLower, "google") {
provider = "google"
} else if strings.Contains(modelNameLower, "deepseek") {
provider = "deepseek"
} else if model.CustomAPIURL != "" {
provider = "custom"
} else {
provider = "openai" // default fallback
}
logger.Infof("📊 Inferred AI provider '%s' from model name '%s'", provider, model.Name)
}
cfg.AICfg.Provider = provider
cfg.AICfg.APIKey = apiKey
cfg.AICfg.BaseURL = strings.TrimSpace(model.CustomAPIURL)
modelName := strings.TrimSpace(model.CustomModelName)