Files
nofx/trader/auto_trader_grid.go
tinkle-community 5fb26c17dc feat: add AI grid trading and market regime classification
- Add GridTrader interface with PlaceLimitOrder, CancelOrder, GetOrderBook
- Implement GridTrader for all exchanges (Binance, Bybit, OKX, Bitget, Hyperliquid, Aster, Lighter)
- Add grid engine with ATR-based boundary calculation and fund distribution
- Add market regime classification documents (Chinese/English)
- Add GridConfigEditor component for frontend configuration
2026-01-13 10:33:02 +08:00

580 lines
16 KiB
Go

package trader
import (
"encoding/json"
"fmt"
"math"
"nofx/kernel"
"nofx/logger"
"nofx/market"
"nofx/store"
"sync"
"time"
)
// ============================================================================
// Grid Trading State Management
// ============================================================================
// GridState holds the runtime state for grid trading
type GridState struct {
mu sync.RWMutex
// Configuration
Config *store.GridStrategyConfig
// Grid levels
Levels []kernel.GridLevelInfo
// Calculated bounds
UpperPrice float64
LowerPrice float64
GridSpacing float64
// State flags
IsPaused bool
IsInitialized bool
// Performance tracking
TotalProfit float64
TotalTrades int
WinningTrades int
MaxDrawdown float64
PeakEquity float64
DailyPnL float64
LastDailyReset time.Time
// Order tracking
OrderBook map[string]int // OrderID -> LevelIndex
}
// NewGridState creates a new grid state
func NewGridState(config *store.GridStrategyConfig) *GridState {
return &GridState{
Config: config,
Levels: make([]kernel.GridLevelInfo, 0),
OrderBook: make(map[string]int),
}
}
// ============================================================================
// AutoTrader Grid Methods
// ============================================================================
// InitializeGrid initializes the grid state and calculates levels
func (at *AutoTrader) InitializeGrid() error {
if at.config.StrategyConfig == nil || at.config.StrategyConfig.GridConfig == nil {
return fmt.Errorf("grid configuration not found")
}
gridConfig := at.config.StrategyConfig.GridConfig
at.gridState = NewGridState(gridConfig)
// Get current market price
price, err := at.trader.GetMarketPrice(gridConfig.Symbol)
if err != nil {
return fmt.Errorf("failed to get market price: %w", err)
}
// Calculate grid bounds
if gridConfig.UseATRBounds {
// Get ATR for bound calculation
mktData, err := market.GetWithTimeframes(gridConfig.Symbol, []string{"4h"}, "4h", 20)
if err != nil {
logger.Warnf("Failed to get market data for ATR: %v, using default bounds", err)
at.calculateDefaultBounds(price, gridConfig)
} else {
at.calculateATRBounds(price, mktData, gridConfig)
}
} else {
// Use manual bounds
at.gridState.UpperPrice = gridConfig.UpperPrice
at.gridState.LowerPrice = gridConfig.LowerPrice
}
// Calculate grid spacing
at.gridState.GridSpacing = (at.gridState.UpperPrice - at.gridState.LowerPrice) / float64(gridConfig.GridCount-1)
// Initialize grid levels
at.initializeGridLevels(price, gridConfig)
at.gridState.IsInitialized = true
logger.Infof("📊 [Grid] Initialized: %d levels, $%.2f - $%.2f, spacing $%.2f",
gridConfig.GridCount, at.gridState.LowerPrice, at.gridState.UpperPrice, at.gridState.GridSpacing)
return nil
}
// calculateDefaultBounds calculates default bounds based on price
func (at *AutoTrader) calculateDefaultBounds(price float64, config *store.GridStrategyConfig) {
// Default: ±3% from current price
multiplier := 0.03 * float64(config.GridCount) / 10
at.gridState.UpperPrice = price * (1 + multiplier)
at.gridState.LowerPrice = price * (1 - multiplier)
}
// calculateATRBounds calculates bounds using ATR
func (at *AutoTrader) calculateATRBounds(price float64, mktData *market.Data, config *store.GridStrategyConfig) {
atr := 0.0
if mktData.LongerTermContext != nil {
atr = mktData.LongerTermContext.ATR14
}
if atr <= 0 {
at.calculateDefaultBounds(price, config)
return
}
multiplier := config.ATRMultiplier
if multiplier <= 0 {
multiplier = 2.0
}
halfRange := atr * multiplier
at.gridState.UpperPrice = price + halfRange
at.gridState.LowerPrice = price - halfRange
}
// initializeGridLevels creates the grid level structure
func (at *AutoTrader) initializeGridLevels(currentPrice float64, config *store.GridStrategyConfig) {
levels := make([]kernel.GridLevelInfo, config.GridCount)
totalWeight := 0.0
weights := make([]float64, config.GridCount)
// Calculate weights based on distribution
for i := 0; i < config.GridCount; i++ {
switch config.Distribution {
case "gaussian":
// Gaussian distribution - more weight in the middle
center := float64(config.GridCount-1) / 2
sigma := float64(config.GridCount) / 4
weights[i] = math.Exp(-math.Pow(float64(i)-center, 2) / (2 * sigma * sigma))
case "pyramid":
// Pyramid - more weight at bottom
weights[i] = float64(config.GridCount - i)
default: // uniform
weights[i] = 1.0
}
totalWeight += weights[i]
}
// Create levels
for i := 0; i < config.GridCount; i++ {
price := at.gridState.LowerPrice + float64(i)*at.gridState.GridSpacing
allocatedUSD := config.TotalInvestment * weights[i] / totalWeight
// Determine initial side (below current price = buy, above = sell)
side := "buy"
if price > currentPrice {
side = "sell"
}
levels[i] = kernel.GridLevelInfo{
Index: i,
Price: price,
State: "empty",
Side: side,
AllocatedUSD: allocatedUSD,
}
}
at.gridState.Levels = levels
}
// RunGridCycle executes one grid trading cycle
func (at *AutoTrader) RunGridCycle() error {
if at.gridState == nil || !at.gridState.IsInitialized {
if err := at.InitializeGrid(); err != nil {
return fmt.Errorf("failed to initialize grid: %w", err)
}
}
gridConfig := at.config.StrategyConfig.GridConfig
lang := at.config.StrategyConfig.Language
if lang == "" {
lang = "en"
}
// Build grid context
gridCtx, err := at.buildGridContext()
if err != nil {
return fmt.Errorf("failed to build grid context: %w", err)
}
// Get AI decisions
decision, err := kernel.GetGridDecisions(gridCtx, at.mcpClient, gridConfig, lang)
if err != nil {
return fmt.Errorf("failed to get grid decisions: %w", err)
}
// Execute decisions
for _, d := range decision.Decisions {
if err := at.executeGridDecision(&d); err != nil {
logger.Warnf("[Grid] Failed to execute decision %s: %v", d.Action, err)
}
}
// Sync state with exchange
at.syncGridState()
// Save decision record
at.saveGridDecisionRecord(decision)
return nil
}
// buildGridContext builds the context for AI grid decisions
func (at *AutoTrader) buildGridContext() (*kernel.GridContext, error) {
gridConfig := at.config.StrategyConfig.GridConfig
// Get market data
mktData, err := market.GetWithTimeframes(gridConfig.Symbol, []string{"5m", "4h"}, "5m", 50)
if err != nil {
return nil, fmt.Errorf("failed to get market data: %w", err)
}
// Build base context from market data
ctx := kernel.BuildGridContextFromMarketData(mktData, gridConfig)
// Add grid state
at.gridState.mu.RLock()
ctx.Levels = at.gridState.Levels
ctx.UpperPrice = at.gridState.UpperPrice
ctx.LowerPrice = at.gridState.LowerPrice
ctx.GridSpacing = at.gridState.GridSpacing
ctx.IsPaused = at.gridState.IsPaused
ctx.TotalProfit = at.gridState.TotalProfit
ctx.TotalTrades = at.gridState.TotalTrades
ctx.WinningTrades = at.gridState.WinningTrades
ctx.MaxDrawdown = at.gridState.MaxDrawdown
ctx.DailyPnL = at.gridState.DailyPnL
// Count active orders and filled levels
for _, level := range at.gridState.Levels {
if level.State == "pending" {
ctx.ActiveOrderCount++
} else if level.State == "filled" {
ctx.FilledLevelCount++
}
}
at.gridState.mu.RUnlock()
// Get account info
balance, err := at.trader.GetBalance()
if err == nil {
if equity, ok := balance["total_equity"].(float64); ok {
ctx.TotalEquity = equity
}
if available, ok := balance["availableBalance"].(float64); ok {
ctx.AvailableBalance = available
}
if unrealized, ok := balance["totalUnrealizedProfit"].(float64); ok {
ctx.UnrealizedPnL = unrealized
}
}
// Get current position
positions, err := at.trader.GetPositions()
if err == nil {
for _, pos := range positions {
if sym, ok := pos["symbol"].(string); ok && sym == gridConfig.Symbol {
if size, ok := pos["positionAmt"].(float64); ok {
ctx.CurrentPosition = size
}
}
}
}
return ctx, nil
}
// executeGridDecision executes a single grid decision
func (at *AutoTrader) executeGridDecision(d *kernel.Decision) error {
switch d.Action {
case "place_buy_limit":
return at.placeGridLimitOrder(d, "BUY")
case "place_sell_limit":
return at.placeGridLimitOrder(d, "SELL")
case "cancel_order":
return at.cancelGridOrder(d)
case "cancel_all_orders":
return at.cancelAllGridOrders()
case "pause_grid":
return at.pauseGrid(d.Reasoning)
case "resume_grid":
return at.resumeGrid()
case "adjust_grid":
return at.adjustGrid(d)
case "hold":
logger.Infof("[Grid] Holding current state: %s", d.Reasoning)
return nil
// Support standard actions for closing positions
case "close_long":
_, err := at.trader.CloseLong(d.Symbol, d.Quantity)
return err
case "close_short":
_, err := at.trader.CloseShort(d.Symbol, d.Quantity)
return err
default:
logger.Warnf("[Grid] Unknown action: %s", d.Action)
return nil
}
}
// placeGridLimitOrder places a limit order for grid trading
func (at *AutoTrader) placeGridLimitOrder(d *kernel.Decision, side string) error {
// Check if trader supports GridTrader interface
gridTrader, ok := at.trader.(GridTrader)
if !ok {
// Fallback to adapter
gridTrader = NewGridTraderAdapter(at.trader)
}
gridConfig := at.config.StrategyConfig.GridConfig
req := &LimitOrderRequest{
Symbol: d.Symbol,
Side: side,
Price: d.Price,
Quantity: d.Quantity,
Leverage: gridConfig.Leverage,
PostOnly: gridConfig.UseMakerOnly,
ReduceOnly: false,
ClientID: fmt.Sprintf("grid-%d-%d", d.LevelIndex, time.Now().UnixNano()%1000000),
}
result, err := gridTrader.PlaceLimitOrder(req)
if err != nil {
return fmt.Errorf("failed to place limit order: %w", err)
}
// Update grid level state
at.gridState.mu.Lock()
if d.LevelIndex >= 0 && d.LevelIndex < len(at.gridState.Levels) {
at.gridState.Levels[d.LevelIndex].State = "pending"
at.gridState.Levels[d.LevelIndex].OrderID = result.OrderID
at.gridState.Levels[d.LevelIndex].OrderQuantity = d.Quantity
at.gridState.OrderBook[result.OrderID] = d.LevelIndex
}
at.gridState.mu.Unlock()
logger.Infof("[Grid] Placed %s limit order at $%.2f, qty=%.4f, level=%d, orderID=%s",
side, d.Price, d.Quantity, d.LevelIndex, result.OrderID)
return nil
}
// cancelGridOrder cancels a specific grid order
func (at *AutoTrader) cancelGridOrder(d *kernel.Decision) error {
gridTrader, ok := at.trader.(GridTrader)
if !ok {
gridTrader = NewGridTraderAdapter(at.trader)
}
if err := gridTrader.CancelOrder(d.Symbol, d.OrderID); err != nil {
return fmt.Errorf("failed to cancel order: %w", err)
}
// Update state
at.gridState.mu.Lock()
if levelIdx, ok := at.gridState.OrderBook[d.OrderID]; ok {
if levelIdx >= 0 && levelIdx < len(at.gridState.Levels) {
at.gridState.Levels[levelIdx].State = "empty"
at.gridState.Levels[levelIdx].OrderID = ""
at.gridState.Levels[levelIdx].OrderQuantity = 0
}
delete(at.gridState.OrderBook, d.OrderID)
}
at.gridState.mu.Unlock()
logger.Infof("[Grid] Cancelled order: %s", d.OrderID)
return nil
}
// cancelAllGridOrders cancels all grid orders
func (at *AutoTrader) cancelAllGridOrders() error {
gridConfig := at.config.StrategyConfig.GridConfig
if err := at.trader.CancelAllOrders(gridConfig.Symbol); err != nil {
return fmt.Errorf("failed to cancel all orders: %w", err)
}
// Reset all pending levels
at.gridState.mu.Lock()
for i := range at.gridState.Levels {
if at.gridState.Levels[i].State == "pending" {
at.gridState.Levels[i].State = "empty"
at.gridState.Levels[i].OrderID = ""
at.gridState.Levels[i].OrderQuantity = 0
}
}
at.gridState.OrderBook = make(map[string]int)
at.gridState.mu.Unlock()
logger.Infof("[Grid] Cancelled all orders")
return nil
}
// pauseGrid pauses grid trading
func (at *AutoTrader) pauseGrid(reason string) error {
at.cancelAllGridOrders()
at.gridState.mu.Lock()
at.gridState.IsPaused = true
at.gridState.mu.Unlock()
logger.Infof("[Grid] Paused: %s", reason)
return nil
}
// resumeGrid resumes grid trading
func (at *AutoTrader) resumeGrid() error {
at.gridState.mu.Lock()
at.gridState.IsPaused = false
at.gridState.mu.Unlock()
logger.Infof("[Grid] Resumed")
return nil
}
// adjustGrid adjusts grid parameters
func (at *AutoTrader) adjustGrid(d *kernel.Decision) error {
// Cancel existing orders first
at.cancelAllGridOrders()
gridConfig := at.config.StrategyConfig.GridConfig
// Get current price
price, err := at.trader.GetMarketPrice(gridConfig.Symbol)
if err != nil {
return fmt.Errorf("failed to get market price: %w", err)
}
// Reinitialize grid levels
at.initializeGridLevels(price, gridConfig)
logger.Infof("[Grid] Adjusted grid bounds around price $%.2f", price)
return nil
}
// syncGridState syncs grid state with exchange
func (at *AutoTrader) syncGridState() {
gridConfig := at.config.StrategyConfig.GridConfig
// Get open orders from exchange
openOrders, err := at.trader.GetOpenOrders(gridConfig.Symbol)
if err != nil {
logger.Warnf("[Grid] Failed to get open orders: %v", err)
return
}
// Build set of active order IDs
activeOrderIDs := make(map[string]bool)
for _, order := range openOrders {
activeOrderIDs[order.OrderID] = true
}
// Update levels based on order status
at.gridState.mu.Lock()
for i := range at.gridState.Levels {
level := &at.gridState.Levels[i]
if level.State == "pending" && level.OrderID != "" {
if !activeOrderIDs[level.OrderID] {
// Order no longer exists - might be filled or cancelled
// Mark as filled (we'll need to verify with position data)
level.State = "filled"
level.PositionEntry = level.Price
at.gridState.TotalTrades++
logger.Infof("[Grid] Level %d order filled at $%.2f", i, level.Price)
}
}
}
at.gridState.mu.Unlock()
// Update position info
positions, err := at.trader.GetPositions()
if err != nil {
return
}
var totalPosition float64
for _, pos := range positions {
if sym, ok := pos["symbol"].(string); ok && sym == gridConfig.Symbol {
if size, ok := pos["positionAmt"].(float64); ok {
totalPosition = size
}
if pnl, ok := pos["unRealizedProfit"].(float64); ok {
// Update unrealized PnL for filled levels
at.gridState.mu.Lock()
for i := range at.gridState.Levels {
if at.gridState.Levels[i].State == "filled" {
// Distribute PnL (simplified - in production, track per-level)
at.gridState.Levels[i].UnrealizedPnL = pnl / float64(at.gridState.TotalTrades)
}
}
at.gridState.mu.Unlock()
}
}
}
logger.Debugf("[Grid] Synced state: position=%.4f, orders=%d", totalPosition, len(openOrders))
}
// saveGridDecisionRecord saves the grid decision to database
func (at *AutoTrader) saveGridDecisionRecord(decision *kernel.FullDecision) {
if at.store == nil {
return
}
at.cycleNumber++
record := &store.DecisionRecord{
TraderID: at.id,
CycleNumber: at.cycleNumber,
Timestamp: time.Now().UTC(),
SystemPrompt: decision.SystemPrompt,
InputPrompt: decision.UserPrompt,
CoTTrace: decision.CoTTrace,
RawResponse: decision.RawResponse,
AIRequestDurationMs: decision.AIRequestDurationMs,
Success: true,
}
if len(decision.Decisions) > 0 {
decisionJSON, _ := json.MarshalIndent(decision.Decisions, "", " ")
record.DecisionJSON = string(decisionJSON)
// Convert kernel.Decision to store.DecisionAction for frontend display
for _, d := range decision.Decisions {
actionRecord := store.DecisionAction{
Action: d.Action,
Symbol: d.Symbol,
Quantity: d.Quantity,
Leverage: d.Leverage,
Price: d.Price,
StopLoss: d.StopLoss,
TakeProfit: d.TakeProfit,
Confidence: d.Confidence,
Reasoning: d.Reasoning,
Timestamp: time.Now().UTC(),
Success: true, // Grid decisions are executed inline
}
record.Decisions = append(record.Decisions, actionRecord)
}
}
record.ExecutionLog = append(record.ExecutionLog, fmt.Sprintf("Grid cycle completed with %d decisions", len(decision.Decisions)))
if err := at.store.Decision().LogDecision(record); err != nil {
logger.Warnf("[Grid] Failed to save decision record: %v", err)
}
}
// IsGridStrategy returns true if current strategy is grid trading
func (at *AutoTrader) IsGridStrategy() bool {
if at.config.StrategyConfig == nil {
return false
}
return at.config.StrategyConfig.StrategyType == "grid_trading" && at.config.StrategyConfig.GridConfig != nil
}