diff --git a/kernel/grid_engine.go b/kernel/grid_engine.go index 243432a2..f376e259 100644 --- a/kernel/grid_engine.go +++ b/kernel/grid_engine.go @@ -84,6 +84,9 @@ type GridContext struct { // Box indicators (Donchian Channels) BoxData *market.BoxData `json:"box_data,omitempty"` + + // Grid direction (neutral, long, short, long_bias, short_bias) + CurrentDirection string `json:"current_direction,omitempty"` } // ============================================================================ @@ -279,6 +282,20 @@ func buildGridUserPromptZh(ctx *GridContext) string { sb.WriteString(fmt.Sprintf("- 活跃订单数: %d\n", ctx.ActiveOrderCount)) sb.WriteString(fmt.Sprintf("- 已成交层数: %d\n", ctx.FilledLevelCount)) sb.WriteString(fmt.Sprintf("- 网格已暂停: %v\n", ctx.IsPaused)) + if ctx.CurrentDirection != "" { + directionDescZh := map[string]string{ + "neutral": "中性 (50%买+50%卖)", + "long": "做多 (100%买)", + "short": "做空 (100%卖)", + "long_bias": "偏多 (70%买+30%卖)", + "short_bias": "偏空 (30%买+70%卖)", + } + desc := directionDescZh[ctx.CurrentDirection] + if desc == "" { + desc = ctx.CurrentDirection + } + sb.WriteString(fmt.Sprintf("- 网格方向: %s\n", desc)) + } sb.WriteString("\n") // Grid levels detail @@ -376,6 +393,20 @@ func buildGridUserPromptEn(ctx *GridContext) string { sb.WriteString(fmt.Sprintf("- Active Orders: %d\n", ctx.ActiveOrderCount)) sb.WriteString(fmt.Sprintf("- Filled Levels: %d\n", ctx.FilledLevelCount)) sb.WriteString(fmt.Sprintf("- Grid Paused: %v\n", ctx.IsPaused)) + if ctx.CurrentDirection != "" { + directionDescEn := map[string]string{ + "neutral": "Neutral (50% buy + 50% sell)", + "long": "Long (100% buy)", + "short": "Short (100% sell)", + "long_bias": "Long Bias (70% buy + 30% sell)", + "short_bias": "Short Bias (30% buy + 70% sell)", + } + desc := directionDescEn[ctx.CurrentDirection] + if desc == "" { + desc = ctx.CurrentDirection + } + sb.WriteString(fmt.Sprintf("- Grid Direction: %s\n", desc)) + } sb.WriteString("\n") // Grid levels detail diff --git a/market/types.go b/market/types.go index 7569c9f3..95c335da 100644 --- a/market/types.go +++ b/market/types.go @@ -226,3 +226,37 @@ const ( BreakoutMid BreakoutLevel = "mid" BreakoutLong BreakoutLevel = "long" ) + +// GridDirection represents the current grid trading direction bias +type GridDirection string + +const ( + GridDirectionNeutral GridDirection = "neutral" // 50% buy + 50% sell + GridDirectionLong GridDirection = "long" // 100% buy + GridDirectionShort GridDirection = "short" // 100% sell + GridDirectionLongBias GridDirection = "long_bias" // 70% buy + 30% sell (default) + GridDirectionShortBias GridDirection = "short_bias" // 30% buy + 70% sell (default) +) + +// GetBuySellRatio returns the buy and sell ratio for this direction +// biasRatio is the ratio for biased directions (default 0.7 means 70%/30%) +func (d GridDirection) GetBuySellRatio(biasRatio float64) (buyRatio, sellRatio float64) { + if biasRatio <= 0 || biasRatio > 1 { + biasRatio = 0.7 // Default 70%/30% + } + + switch d { + case GridDirectionNeutral: + return 0.5, 0.5 + case GridDirectionLong: + return 1.0, 0.0 + case GridDirectionShort: + return 0.0, 1.0 + case GridDirectionLongBias: + return biasRatio, 1.0 - biasRatio + case GridDirectionShortBias: + return 1.0 - biasRatio, biasRatio + default: + return 0.5, 0.5 + } +} diff --git a/store/grid.go b/store/grid.go index 49ce5708..10073897 100644 --- a/store/grid.go +++ b/store/grid.go @@ -63,6 +63,10 @@ type GridConfigModel struct { AIProvider string `json:"ai_provider" gorm:"default:deepseek"` AIModel string `json:"ai_model" gorm:"default:deepseek-chat"` IsActive bool `json:"is_active" gorm:"default:false"` + + // Direction adjustment settings + EnableDirectionAdjust bool `json:"enable_direction_adjust" gorm:"default:false"` + DirectionBiasRatio float64 `json:"direction_bias_ratio" gorm:"default:0.7"` } func (GridConfigModel) TableName() string { @@ -108,6 +112,11 @@ type GridInstanceModel struct { // Position adjustment due to breakout PositionReductionPct float64 `json:"position_reduction_pct" gorm:"default:0"` // 0 = normal, 50 = reduced + // Grid direction adjustment state + CurrentDirection string `json:"current_direction" gorm:"default:neutral"` + DirectionChangedAt time.Time `json:"direction_changed_at"` + DirectionChangeCount int `json:"direction_change_count" gorm:"default:0"` + TotalProfit float64 `json:"total_profit" gorm:"default:0"` TotalFees float64 `json:"total_fees" gorm:"default:0"` TotalTrades int `json:"total_trades" gorm:"default:0"` diff --git a/store/strategy.go b/store/strategy.go index 7fbdef99..d27a9948 100644 --- a/store/strategy.go +++ b/store/strategy.go @@ -81,6 +81,10 @@ type GridStrategyConfig struct { DailyLossLimitPct float64 `json:"daily_loss_limit_pct"` // Use maker-only orders for lower fees UseMakerOnly bool `json:"use_maker_only"` + // Enable automatic grid direction adjustment based on box breakouts + EnableDirectionAdjust bool `json:"enable_direction_adjust"` + // Direction bias ratio for long_bias/short_bias modes (default 0.7 = 70%/30%) + DirectionBiasRatio float64 `json:"direction_bias_ratio"` } // PromptSectionsConfig editable sections of System Prompt diff --git a/trader/auto_trader_grid.go b/trader/auto_trader_grid.go index 5b445b7f..37b62d2d 100644 --- a/trader/auto_trader_grid.go +++ b/trader/auto_trader_grid.go @@ -65,14 +65,20 @@ type GridState struct { // Current regime level CurrentRegimeLevel string + + // Grid direction adjustment + CurrentDirection market.GridDirection + DirectionChangedAt time.Time + DirectionChangeCount int } // NewGridState creates a new grid state func NewGridState(config *store.GridStrategyConfig) *GridState { return &GridState{ - Config: config, - Levels: make([]kernel.GridLevelInfo, 0), - OrderBook: make(map[string]int), + Config: config, + Levels: make([]kernel.GridLevelInfo, 0), + OrderBook: make(map[string]int), + CurrentDirection: market.GridDirectionNeutral, } } @@ -325,7 +331,17 @@ func (at *AutoTrader) checkBoxBreakout() error { } // Take action based on breakout level - action := getBreakoutAction(breakoutLevel) + // Use direction-aware action if enabled + enableDirectionAdjust := gridConfig.EnableDirectionAdjust + action := getBreakoutActionWithDirection(breakoutLevel, enableDirectionAdjust) + + // If direction adjustment action, determine the new direction + if action == BreakoutActionAdjustDirection { + box, _ := market.GetBoxData(gridConfig.Symbol) + newDirection := determineGridDirection(box, at.gridState.CurrentDirection, breakoutLevel, direction) + return at.executeDirectionAdjustment(newDirection) + } + return at.executeBreakoutAction(action) } @@ -358,11 +374,38 @@ func (at *AutoTrader) executeBreakoutAction(action BreakoutAction) error { logger.Infof("Failed to cancel orders: %v", err) } return at.closeAllPositions() + + case BreakoutActionAdjustDirection: + // Direction adjustment is handled separately via executeDirectionAdjustment + // This case should not be reached, but handle gracefully + logger.Infof("Direction adjustment action received via executeBreakoutAction") + return nil } return nil } +// executeDirectionAdjustment handles grid direction changes based on box breakout +func (at *AutoTrader) executeDirectionAdjustment(newDirection market.GridDirection) error { + at.gridState.mu.RLock() + oldDirection := at.gridState.CurrentDirection + at.gridState.mu.RUnlock() + + if oldDirection == newDirection { + return nil // No change needed + } + + logger.Infof("[Grid] Direction adjustment: %s → %s", oldDirection, newDirection) + + // Cancel existing orders before adjusting + if err := at.cancelAllGridOrders(); err != nil { + logger.Warnf("[Grid] Failed to cancel orders during direction adjustment: %v", err) + } + + // Apply the new direction + return at.adjustGridDirection(newDirection) +} + // closeAllPositions closes all open positions for the grid symbol func (at *AutoTrader) closeAllPositions() error { gridConfig := at.config.StrategyConfig.GridConfig @@ -410,10 +453,16 @@ func (at *AutoTrader) checkFalseBreakoutRecovery() error { breakoutLevel := at.gridState.BreakoutLevel isPaused := at.gridState.IsPaused positionReduction := at.gridState.PositionReductionPct + currentDirection := at.gridState.CurrentDirection at.gridState.mu.RUnlock() - // Only check if we had a breakout - if breakoutLevel == string(market.BreakoutNone) && positionReduction == 0 && !isPaused { + // Only check if we had a breakout or non-neutral direction + needsRecoveryCheck := breakoutLevel != string(market.BreakoutNone) || + positionReduction != 0 || + isPaused || + (gridConfig.EnableDirectionAdjust && currentDirection != market.GridDirectionNeutral) + + if !needsRecoveryCheck { return nil } @@ -436,6 +485,18 @@ func (at *AutoTrader) checkFalseBreakoutRecovery() error { at.gridState.mu.Unlock() } + // Check for direction recovery toward neutral (if direction adjustment is enabled) + if gridConfig.EnableDirectionAdjust && currentDirection != market.GridDirectionNeutral { + if shouldRecoverDirection(box, currentDirection) { + newDirection := determineRecoveryDirection(box.CurrentPrice, box, currentDirection) + if newDirection != currentDirection { + logger.Infof("[Grid] Direction recovery: %s → %s (price back in short box)", + currentDirection, newDirection) + at.adjustGridDirection(newDirection) + } + } + } + return nil } @@ -570,6 +631,128 @@ func (at *AutoTrader) initializeGridLevels(currentPrice float64, config *store.G } at.gridState.Levels = levels + + // Apply direction-based side assignment if enabled + if config.EnableDirectionAdjust { + at.applyGridDirection(currentPrice) + } +} + +// applyGridDirection adjusts grid level sides based on the current direction +// This redistributes buy/sell levels according to the direction bias ratio +func (at *AutoTrader) applyGridDirection(currentPrice float64) { + config := at.gridState.Config + direction := at.gridState.CurrentDirection + + // Get bias ratio from config, default to 0.7 (70%/30%) + biasRatio := config.DirectionBiasRatio + if biasRatio <= 0 || biasRatio > 1 { + biasRatio = 0.7 + } + + buyRatio, _ := direction.GetBuySellRatio(biasRatio) + + // Calculate how many levels should be buy vs sell based on direction + totalLevels := len(at.gridState.Levels) + targetBuyLevels := int(float64(totalLevels) * buyRatio) + + // For neutral: use price-based assignment (buy below, sell above) + if direction == market.GridDirectionNeutral { + for i := range at.gridState.Levels { + if at.gridState.Levels[i].Price <= currentPrice { + at.gridState.Levels[i].Side = "buy" + } else { + at.gridState.Levels[i].Side = "sell" + } + } + return + } + + // For long/long_bias: more buy levels + // For short/short_bias: more sell levels + switch direction { + case market.GridDirectionLong: + // 100% buy - all levels are buy + for i := range at.gridState.Levels { + at.gridState.Levels[i].Side = "buy" + } + + case market.GridDirectionShort: + // 100% sell - all levels are sell + for i := range at.gridState.Levels { + at.gridState.Levels[i].Side = "sell" + } + + case market.GridDirectionLongBias, market.GridDirectionShortBias: + // Assign sides based on position relative to current price + // For long_bias: keep all below as buy, convert some above to buy + // For short_bias: keep all above as sell, convert some below to sell + buyCount := 0 + sellCount := 0 + + for i := range at.gridState.Levels { + needMoreBuys := buyCount < targetBuyLevels + needMoreSells := sellCount < (totalLevels - targetBuyLevels) + + if at.gridState.Levels[i].Price <= currentPrice { + // Level below or at current price + if needMoreBuys { + at.gridState.Levels[i].Side = "buy" + buyCount++ + } else { + at.gridState.Levels[i].Side = "sell" + sellCount++ + } + } else { + // Level above current price + if needMoreSells && direction == market.GridDirectionShortBias { + at.gridState.Levels[i].Side = "sell" + sellCount++ + } else if needMoreBuys && direction == market.GridDirectionLongBias { + at.gridState.Levels[i].Side = "buy" + buyCount++ + } else if needMoreSells { + at.gridState.Levels[i].Side = "sell" + sellCount++ + } else { + at.gridState.Levels[i].Side = "buy" + buyCount++ + } + } + } + } + + logger.Infof("[Grid] Applied direction %s: buy_ratio=%.0f%%, levels reconfigured", + direction, buyRatio*100) +} + +// adjustGridDirection handles runtime direction adjustment when breakout is detected +func (at *AutoTrader) adjustGridDirection(newDirection market.GridDirection) error { + at.gridState.mu.Lock() + defer at.gridState.mu.Unlock() + + oldDirection := at.gridState.CurrentDirection + if oldDirection == newDirection { + return nil // No change needed + } + + at.gridState.CurrentDirection = newDirection + at.gridState.DirectionChangedAt = time.Now() + at.gridState.DirectionChangeCount++ + + logger.Infof("[Grid] Direction changed: %s → %s (change count: %d)", + oldDirection, newDirection, at.gridState.DirectionChangeCount) + + // Get current price for recalculation + currentPrice, err := at.trader.GetMarketPrice(at.gridState.Config.Symbol) + if err != nil { + return fmt.Errorf("failed to get market price: %w", err) + } + + // Reapply direction to grid levels + at.applyGridDirection(currentPrice) + + return nil } // RunGridCycle executes one grid trading cycle @@ -1370,6 +1553,85 @@ func (at *AutoTrader) initializeGridLevelsLocked(currentPrice float64, config *s } at.gridState.Levels = levels + + // Apply direction-based side assignment if enabled (note: caller holds lock) + if config.EnableDirectionAdjust { + at.applyGridDirectionLocked(currentPrice) + } +} + +// applyGridDirectionLocked adjusts grid level sides based on the current direction (caller must hold lock) +func (at *AutoTrader) applyGridDirectionLocked(currentPrice float64) { + config := at.gridState.Config + direction := at.gridState.CurrentDirection + + // Get bias ratio from config, default to 0.7 (70%/30%) + biasRatio := config.DirectionBiasRatio + if biasRatio <= 0 || biasRatio > 1 { + biasRatio = 0.7 + } + + buyRatio, _ := direction.GetBuySellRatio(biasRatio) + + // For neutral: use price-based assignment (buy below, sell above) + if direction == market.GridDirectionNeutral { + for i := range at.gridState.Levels { + if at.gridState.Levels[i].Price <= currentPrice { + at.gridState.Levels[i].Side = "buy" + } else { + at.gridState.Levels[i].Side = "sell" + } + } + return + } + + totalLevels := len(at.gridState.Levels) + targetBuyLevels := int(float64(totalLevels) * buyRatio) + + switch direction { + case market.GridDirectionLong: + for i := range at.gridState.Levels { + at.gridState.Levels[i].Side = "buy" + } + + case market.GridDirectionShort: + for i := range at.gridState.Levels { + at.gridState.Levels[i].Side = "sell" + } + + case market.GridDirectionLongBias, market.GridDirectionShortBias: + buyCount := 0 + sellCount := 0 + + for i := range at.gridState.Levels { + needMoreBuys := buyCount < targetBuyLevels + needMoreSells := sellCount < (totalLevels - targetBuyLevels) + + if at.gridState.Levels[i].Price <= currentPrice { + if needMoreBuys { + at.gridState.Levels[i].Side = "buy" + buyCount++ + } else { + at.gridState.Levels[i].Side = "sell" + sellCount++ + } + } else { + if needMoreSells && direction == market.GridDirectionShortBias { + at.gridState.Levels[i].Side = "sell" + sellCount++ + } else if needMoreBuys && direction == market.GridDirectionLongBias { + at.gridState.Levels[i].Side = "buy" + buyCount++ + } else if needMoreSells { + at.gridState.Levels[i].Side = "sell" + sellCount++ + } else { + at.gridState.Levels[i].Side = "buy" + buyCount++ + } + } + } + } } // GridRiskInfo contains risk information for frontend display @@ -1397,6 +1659,11 @@ type GridRiskInfo struct { BreakoutLevel string `json:"breakout_level"` BreakoutDirection string `json:"breakout_direction"` + + // Grid direction + CurrentGridDirection string `json:"current_grid_direction"` + DirectionChangeCount int `json:"direction_change_count"` + EnableDirectionAdjust bool `json:"enable_direction_adjust"` } // GetGridRiskInfo returns current risk information for frontend display @@ -1513,6 +1780,10 @@ func (at *AutoTrader) GetGridRiskInfo() *GridRiskInfo { BreakoutLevel: at.gridState.BreakoutLevel, BreakoutDirection: at.gridState.BreakoutDirection, + + CurrentGridDirection: string(at.gridState.CurrentDirection), + DirectionChangeCount: at.gridState.DirectionChangeCount, + EnableDirectionAdjust: gridConfig.EnableDirectionAdjust, } } diff --git a/trader/grid_regime.go b/trader/grid_regime.go index e574cc1b..b45a9f92 100644 --- a/trader/grid_regime.go +++ b/trader/grid_regime.go @@ -194,3 +194,119 @@ func getBreakoutAction(level market.BreakoutLevel) BreakoutAction { return BreakoutActionNone } } + +// ============================================================================ +// Task 10: Grid Direction Adjustment +// ============================================================================ + +const ( + // BreakoutActionAdjustDirection adjusts grid direction based on breakout + BreakoutActionAdjustDirection BreakoutAction = 4 +) + +// determineGridDirection determines the new grid direction based on box breakout +// currentDirection: the current grid direction +// breakoutLevel: which box level has been broken (short/mid/long) +// direction: breakout direction ("up" or "down") +// Returns: the new grid direction +func determineGridDirection(box *market.BoxData, currentDirection market.GridDirection, breakoutLevel market.BreakoutLevel, direction string) market.GridDirection { + if box == nil { + return currentDirection + } + + price := box.CurrentPrice + + switch breakoutLevel { + case market.BreakoutShort: + // Short box breakout: bias direction + // Still within mid box, so not a full trend yet + if direction == "up" { + return market.GridDirectionLongBias + } + return market.GridDirectionShortBias + + case market.BreakoutMid: + // Mid box breakout: full direction + // More significant move, commit fully + if direction == "up" { + return market.GridDirectionLong + } + return market.GridDirectionShort + + case market.BreakoutLong: + // Long box breakout: handled by existing emergency logic + // Return current direction, let existing handlers take over + return currentDirection + + case market.BreakoutNone: + // No breakout - check if we should recover toward neutral + return determineRecoveryDirection(price, box, currentDirection) + + default: + return currentDirection + } +} + +// determineRecoveryDirection determines if grid direction should recover toward neutral +// This implements the gradual recovery logic: long → long_bias → neutral ← short_bias ← short +func determineRecoveryDirection(price float64, box *market.BoxData, currentDirection market.GridDirection) market.GridDirection { + // Check if price is back inside the short box + insideShortBox := price >= box.ShortLower && price <= box.ShortUpper + + if !insideShortBox { + // Still outside short box, maintain current direction + return currentDirection + } + + // Price is inside short box, start recovery toward neutral + switch currentDirection { + case market.GridDirectionLong: + // Full long → bias long + return market.GridDirectionLongBias + case market.GridDirectionLongBias: + // Bias long → neutral + return market.GridDirectionNeutral + case market.GridDirectionShort: + // Full short → bias short + return market.GridDirectionShortBias + case market.GridDirectionShortBias: + // Bias short → neutral + return market.GridDirectionNeutral + default: + return currentDirection + } +} + +// getBreakoutActionWithDirection returns the appropriate action for a breakout level +// when direction adjustment is enabled +func getBreakoutActionWithDirection(level market.BreakoutLevel, enableDirectionAdjust bool) BreakoutAction { + if !enableDirectionAdjust { + // Fall back to original behavior + return getBreakoutAction(level) + } + + switch level { + case market.BreakoutShort: + // Short box breakout with direction adjustment: adjust direction instead of reducing position + return BreakoutActionAdjustDirection + case market.BreakoutMid: + // Mid box breakout with direction adjustment: adjust to full direction + return BreakoutActionAdjustDirection + case market.BreakoutLong: + // Long box breakout: always trigger emergency handling + return BreakoutActionCloseAll + default: + return BreakoutActionNone + } +} + +// shouldRecoverDirection checks if the current grid direction should start recovering toward neutral +func shouldRecoverDirection(box *market.BoxData, currentDirection market.GridDirection) bool { + if box == nil || currentDirection == market.GridDirectionNeutral { + return false + } + + price := box.CurrentPrice + // Check if price is back inside the short box + return price >= box.ShortLower && price <= box.ShortUpper +} diff --git a/trader/grid_regime_test.go b/trader/grid_regime_test.go index 25d0753a..32428c9c 100644 --- a/trader/grid_regime_test.go +++ b/trader/grid_regime_test.go @@ -120,3 +120,223 @@ func TestGetBreakoutAction(t *testing.T) { }) } } + +// ============================================================================ +// Grid Direction Tests +// ============================================================================ + +func TestGetBuySellRatio(t *testing.T) { + tests := []struct { + name string + direction market.GridDirection + biasRatio float64 + wantBuy float64 + wantSell float64 + }{ + {"neutral", market.GridDirectionNeutral, 0.7, 0.5, 0.5}, + {"long", market.GridDirectionLong, 0.7, 1.0, 0.0}, + {"short", market.GridDirectionShort, 0.7, 0.0, 1.0}, + {"long_bias_default", market.GridDirectionLongBias, 0.7, 0.7, 0.3}, + {"short_bias_default", market.GridDirectionShortBias, 0.7, 0.3, 0.7}, + {"long_bias_custom", market.GridDirectionLongBias, 0.8, 0.8, 0.2}, + {"short_bias_custom", market.GridDirectionShortBias, 0.8, 0.2, 0.8}, + {"invalid_bias_uses_default", market.GridDirectionLongBias, 0, 0.7, 0.3}, + {"negative_bias_uses_default", market.GridDirectionLongBias, -1, 0.7, 0.3}, + } + + const tolerance = 0.0001 + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buy, sell := tt.direction.GetBuySellRatio(tt.biasRatio) + buyDiff := buy - tt.wantBuy + sellDiff := sell - tt.wantSell + if buyDiff < -tolerance || buyDiff > tolerance || sellDiff < -tolerance || sellDiff > tolerance { + t.Errorf("GetBuySellRatio(%v, %v) = (%v, %v), want (%v, %v)", + tt.direction, tt.biasRatio, buy, sell, tt.wantBuy, tt.wantSell) + } + }) + } +} + +func TestDetermineGridDirection(t *testing.T) { + box := &market.BoxData{ + ShortUpper: 100, + ShortLower: 90, + MidUpper: 105, + MidLower: 85, + LongUpper: 110, + LongLower: 80, + CurrentPrice: 95, + } + + tests := []struct { + name string + currentDirection market.GridDirection + breakoutLevel market.BreakoutLevel + direction string + expected market.GridDirection + }{ + // Short box breakouts + { + name: "short_breakout_up_neutral", + currentDirection: market.GridDirectionNeutral, + breakoutLevel: market.BreakoutShort, + direction: "up", + expected: market.GridDirectionLongBias, + }, + { + name: "short_breakout_down_neutral", + currentDirection: market.GridDirectionNeutral, + breakoutLevel: market.BreakoutShort, + direction: "down", + expected: market.GridDirectionShortBias, + }, + // Mid box breakouts + { + name: "mid_breakout_up", + currentDirection: market.GridDirectionLongBias, + breakoutLevel: market.BreakoutMid, + direction: "up", + expected: market.GridDirectionLong, + }, + { + name: "mid_breakout_down", + currentDirection: market.GridDirectionShortBias, + breakoutLevel: market.BreakoutMid, + direction: "down", + expected: market.GridDirectionShort, + }, + // Long box breakout - maintains current (emergency handling) + { + name: "long_breakout_maintains", + currentDirection: market.GridDirectionLong, + breakoutLevel: market.BreakoutLong, + direction: "up", + expected: market.GridDirectionLong, + }, + // No breakout - tests recovery logic + { + name: "no_breakout_neutral_stays", + currentDirection: market.GridDirectionNeutral, + breakoutLevel: market.BreakoutNone, + direction: "", + expected: market.GridDirectionNeutral, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := determineGridDirection(box, tt.currentDirection, tt.breakoutLevel, tt.direction) + if result != tt.expected { + t.Errorf("determineGridDirection() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestDetermineRecoveryDirection(t *testing.T) { + box := &market.BoxData{ + ShortUpper: 100, + ShortLower: 90, + MidUpper: 105, + MidLower: 85, + LongUpper: 110, + LongLower: 80, + CurrentPrice: 95, // Inside short box + } + + tests := []struct { + name string + price float64 + currentDirection market.GridDirection + expected market.GridDirection + }{ + // Inside short box - should recover + {"long_to_long_bias", 95, market.GridDirectionLong, market.GridDirectionLongBias}, + {"long_bias_to_neutral", 95, market.GridDirectionLongBias, market.GridDirectionNeutral}, + {"short_to_short_bias", 95, market.GridDirectionShort, market.GridDirectionShortBias}, + {"short_bias_to_neutral", 95, market.GridDirectionShortBias, market.GridDirectionNeutral}, + {"neutral_stays_neutral", 95, market.GridDirectionNeutral, market.GridDirectionNeutral}, + + // Outside short box - should maintain + {"long_outside_stays", 101, market.GridDirectionLong, market.GridDirectionLong}, + {"short_outside_stays", 89, market.GridDirectionShort, market.GridDirectionShort}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := determineRecoveryDirection(tt.price, box, tt.currentDirection) + if result != tt.expected { + t.Errorf("determineRecoveryDirection(%v, %v) = %v, want %v", + tt.price, tt.currentDirection, result, tt.expected) + } + }) + } +} + +func TestGetBreakoutActionWithDirection(t *testing.T) { + tests := []struct { + name string + level market.BreakoutLevel + enableDirectionAdjust bool + expected BreakoutAction + }{ + // Direction adjustment disabled - original behavior + {"short_disabled", market.BreakoutShort, false, BreakoutActionReducePosition}, + {"mid_disabled", market.BreakoutMid, false, BreakoutActionPauseGrid}, + {"long_disabled", market.BreakoutLong, false, BreakoutActionCloseAll}, + + // Direction adjustment enabled + {"short_enabled", market.BreakoutShort, true, BreakoutActionAdjustDirection}, + {"mid_enabled", market.BreakoutMid, true, BreakoutActionAdjustDirection}, + {"long_enabled", market.BreakoutLong, true, BreakoutActionCloseAll}, // Long always triggers emergency + {"none_enabled", market.BreakoutNone, true, BreakoutActionNone}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + action := getBreakoutActionWithDirection(tt.level, tt.enableDirectionAdjust) + if action != tt.expected { + t.Errorf("getBreakoutActionWithDirection(%v, %v) = %v, want %v", + tt.level, tt.enableDirectionAdjust, action, tt.expected) + } + }) + } +} + +func TestShouldRecoverDirection(t *testing.T) { + box := &market.BoxData{ + ShortUpper: 100, + ShortLower: 90, + MidUpper: 105, + MidLower: 85, + LongUpper: 110, + LongLower: 80, + CurrentPrice: 95, + } + + tests := []struct { + name string + price float64 + direction market.GridDirection + expected bool + }{ + {"neutral_inside_no_recovery", 95, market.GridDirectionNeutral, false}, + {"long_inside_should_recover", 95, market.GridDirectionLong, true}, + {"long_outside_no_recovery", 101, market.GridDirectionLong, false}, + {"short_inside_should_recover", 95, market.GridDirectionShort, true}, + {"short_outside_no_recovery", 89, market.GridDirectionShort, false}, + {"long_bias_inside_should_recover", 95, market.GridDirectionLongBias, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + box.CurrentPrice = tt.price + result := shouldRecoverDirection(box, tt.direction) + if result != tt.expected { + t.Errorf("shouldRecoverDirection(price=%v, %v) = %v, want %v", + tt.price, tt.direction, result, tt.expected) + } + }) + } +}