From 7e96c5d0f2c131fb9373e879b67e116f70b4c480 Mon Sep 17 00:00:00 2001 From: tinkle-community Date: Mon, 19 Jan 2026 12:07:14 +0800 Subject: [PATCH] Ai grid (#1344) * feat: add AI grid trading and market regime classification - Add GridTrader interface with PlaceLimitOrder, CancelOrder, GetOrderBook - Implement GridTrader for all exchanges (Binance, Bybit, OKX, Bitget, Hyperliquid, Aster, Lighter) - Add grid engine with ATR-based boundary calculation and fund distribution - Add market regime classification documents (Chinese/English) - Add GridConfigEditor component for frontend configuration * fix: implement GetOpenOrders for Lighter exchange * debug: add logging for Lighter GetActiveOrders API call * fix: correct Lighter API response parsing for GetOpenOrders - Changed response field from 'data' to 'orders' to match Lighter API - Updated OrderResponse struct to match Lighter's actual field names - Fixed field types: price/quantity as strings, is_ask for side * feat: implement GetOpenOrders for Aster, OKX, Bitget exchanges - Aster: uses /fapi/v3/openOrders endpoint - OKX: uses /api/v5/trade/orders-pending and orders-algo-pending - Bitget: uses /api/v2/mix/order/orders-pending and orders-plan-pending * fix: address code review issues for GetOpenOrders - Add error logging for OKX/Bitget API failures (was silently swallowed) - Fix Lighter position side logic to handle reduce-only orders - Change verbose debug logs from Infof to Debugf level * fix: provide FromAccountIndex and ApiKeyIndex for Lighter nonce auto-fetch Root cause: SDK requires these fields to fetch nonce from API, otherwise nonce gets cached/stuck * fix: use auth query parameter instead of Authorization header for Lighter API * test: add Lighter API authentication tests and diagnostic tools * fix(grid): add leverage setting before order placement CRITICAL BUG FIX: - Call SetLeverage() in GridTraderAdapter.PlaceLimitOrder() - Set leverage during grid initialization - Log leverage setting results * fix(grid): prevent CancelOrder from canceling all orders CRITICAL BUG FIX: - CancelOrder no longer calls CancelAllOrders - Try exchange-specific CancelOrder if available - Return error if individual cancellation not supported * fix(grid): add total position value limit check CRITICAL: Prevent excessive position accumulation - New checkTotalPositionLimit() function - Checks current + pending + new order value - Rejects orders that would exceed TotalInvestment x Leverage - Logs clear error messages when limit exceeded * feat(grid): implement stop loss execution CRITICAL: Add code-level stop loss protection - New checkAndExecuteStopLoss() function - Checks each filled level against StopLossPct - Automatically closes positions exceeding stop loss - Called during every grid state sync * feat(grid): add breakout detection and auto-pause CRITICAL: Detect price breakout from grid range - New checkBreakout() function to detect upper/lower breakouts - Auto-pause grid on significant breakout (>2%) - Cancel all orders when breakout detected - Prevent continued losses in trending market - Minor breakouts (1-2%) logged for AI consideration * feat(grid): enforce max drawdown limit with emergency exit CRITICAL: Add drawdown protection - New checkMaxDrawdown() function tracks peak equity - emergencyExit() closes all positions and cancels orders - Auto-pause grid when MaxDrawdownPct exceeded - Protect capital from excessive losses * feat(grid): enforce daily loss limit - Add checkDailyLossLimit() function to check if daily loss exceeds limit - Track daily PnL with auto-reset at midnight - Pause grid when DailyLossLimitPct exceeded - Add updateDailyPnL() helper for realized PnL tracking - Prevent excessive single-day losses * fix(grid): update daily PnL when stop loss is executed The updateDailyPnL() function was added but never called, leaving DailyPnL always at 0 and preventing daily loss limit checks from triggering. This fix updates DailyPnL and TotalProfit directly in checkAndExecuteStopLoss() when a stop loss is executed. We update directly rather than calling updateDailyPnL() because the mutex is already held in that function. * feat(grid): add automatic grid adjustment - New checkGridSkew() detects imbalanced grid - autoAdjustGrid() reinitializes around current price - Prevents grid from becoming ineffective after drift - Triggers when one side is 3x more filled than other * fix(grid): recalculate bounds in autoAdjustGrid before reinitializing levels Critical fix for grid auto-adjustment: - Recalculate grid bounds (UpperPrice, LowerPrice, GridSpacing) centered on current price before reinitializing grid levels - Preserve filled positions during adjustment by saving and restoring them to the closest new level after reinitialization - Hold mutex lock for the entire adjustment operation to ensure atomicity - Add locked variants of calculateDefaultBounds, calculateATRBounds, and initializeGridLevels to use during adjustment Without this fix, autoAdjustGrid was using old boundaries when creating new grid levels, defeating the purpose of auto-adjustment when price moved significantly. * fix(grid): improve order state sync logic - Don't assume missing orders are filled - Compare position size to determine fill vs cancel - Properly reset cancelled orders to empty state - More accurate grid state tracking * fix(grid): use actual PositionSize sum instead of count in syncGridState heuristic The position-based heuristic was using `float64(previousFilledCount) * level.OrderQuantity` which incorrectly assumed uniform order quantities. Since the grid uses weighted distribution (gaussian, pyramid, uniform) where orders have different quantities, this could lead to incorrect fill detection. Now sums the actual PositionSize from filled levels for accurate comparison. Also adds warning log when GetPositions() fails. * docs: add grid market regime detection design Design for enhanced market state recognition with: - Multi-dimensional indicators (ATR, Bollinger, EMA, MACD, RSI) - Multi-period box indicators (72/240/500 1h candles) - 4-level ranging classification - Breakout detection and handling - Frontend risk control panel * docs: add grid market regime implementation plan 20 tasks covering: - Donchian channel calculation - Box data types and API - Regime classification (4 levels) - Breakout detection and handling - False breakout recovery - Frontend risk panel - AI prompt updates * feat(market): add Donchian channel calculation Add calculateDonchian function to compute highest high and lowest low over a specified period. This is the foundation for box (range) detection in the multi-period box indicator system for grid trading. * fix(market): handle invalid period in calculateDonchian * feat(market): add BoxData and RegimeLevel types * feat(market): add GetBoxData for multi-period box calculation Adds calculateBoxData internal function and GetBoxData public API that fetches 1h klines and computes three Donchian box levels (short/mid/long). This will be used by the grid trading system to detect market regime. * feat(store): add box and regime fields to grid models * feat(trader): add regime classification and breakout detection Implements Tasks 6-9 for grid market regime awareness: - Task 6: classifyRegimeLevel with Bollinger/ATR thresholds - Task 7: detectBoxBreakout for multi-period box breakouts - Task 8: confirmBreakout with 3-candle confirmation logic - Task 9: getBreakoutAction mapping breakout levels to actions * feat(trader): integrate box breakout detection into grid cycle - Task 10: Add checkBoxBreakout with 3-candle confirmation - Task 11: Add checkFalseBreakoutRecovery for 50% position recovery - Task 12: Add box/breakout/regime fields to GridState * feat: add grid risk panel with API endpoint - Task 13: Add GridRiskInfo type to frontend - Task 14: Add /traders/:id/grid-risk API endpoint - Task 15: Add GetGridRiskInfo method to AutoTrader - Task 16: Create GridRiskPanel component with i18n * feat(kernel): add box indicators to AI prompt - Add BoxData field to GridContext - Add box indicator table to both zh/en prompts - Show breakout/warning alerts based on price position * feat(web): integrate GridRiskPanel into TraderDashboardPage * feat(lighter): improve API key validation and market caching - Add API key validation status tracking - Add market list caching to reduce API calls - Improve logging (debug vs info levels) - Add comprehensive integration tests - Update trader manager and store for lighter support * fix: remove hardcoded test wallet address * fix(grid): improve GridRiskPanel layout and fix liquidation data - Make panel collapsible with summary badges when collapsed - Use compact 2-column grid layout for detailed info - Fix auth token key (token -> auth_token) - Only calculate liquidation distance when position exists * fix(grid): add isRunning checks to prevent trades after Stop() is called --- api/server.go | 47 +- cmd/lighter_test/main.go | 233 +++ docs/market-regime-classification-en.md | 281 +++ docs/market-regime-classification-zh.md | 281 +++ docs/plans/2026-01-14-grid-trading-fixes.md | 1072 +++++++++++ .../2026-01-17-grid-market-regime-design.md | 151 ++ .../2026-01-17-grid-market-regime-impl.md | 1655 +++++++++++++++++ kernel/engine.go | 9 +- kernel/grid_engine.go | 587 ++++++ manager/trader_manager.go | 8 +- market/data.go | 88 + market/data_test.go | 83 + market/types.go | 39 + scripts/test_lighter_orders.go | 168 ++ store/grid.go | 585 ++++++ store/store.go | 14 + store/strategy.go | 36 + store/trader.go | 20 + trader/aster_trader.go | 189 +- trader/auto_trader.go | 49 +- trader/auto_trader_grid.go | 1579 ++++++++++++++++ trader/binance_futures.go | 155 ++ trader/binance_sync_e2e_test.go | 14 +- trader/bitget_trader.go | 238 ++- trader/bybit_trader.go | 156 ++ trader/exchange_sync_test.go | 10 +- trader/grid_regime.go | 196 ++ trader/grid_regime_test.go | 122 ++ trader/hyperliquid_sync_test.go | 20 +- trader/hyperliquid_trader.go | 115 ++ trader/interface.go | 118 +- trader/lighter_integration_test.go | 589 +++++- trader/lighter_trader_v2.go | 107 +- trader/lighter_trader_v2_account.go | 119 +- trader/lighter_trader_v2_orders.go | 135 +- trader/lighter_trader_v2_orders_test.go | 421 +++++ trader/lighter_trader_v2_trading.go | 325 +++- trader/lighter_types.go | 28 +- trader/okx_trader.go | 252 ++- .../components/strategy/GridConfigEditor.tsx | 424 +++++ web/src/components/strategy/GridRiskPanel.tsx | 372 ++++ web/src/pages/StrategyStudioPage.tsx | 94 +- web/src/pages/TraderDashboardPage.tsx | 19 + web/src/types.ts | 69 + 44 files changed, 11038 insertions(+), 234 deletions(-) create mode 100644 cmd/lighter_test/main.go create mode 100644 docs/market-regime-classification-en.md create mode 100644 docs/market-regime-classification-zh.md create mode 100644 docs/plans/2026-01-14-grid-trading-fixes.md create mode 100644 docs/plans/2026-01-17-grid-market-regime-design.md create mode 100644 docs/plans/2026-01-17-grid-market-regime-impl.md create mode 100644 kernel/grid_engine.go create mode 100644 scripts/test_lighter_orders.go create mode 100644 store/grid.go create mode 100644 trader/auto_trader_grid.go create mode 100644 trader/grid_regime.go create mode 100644 trader/grid_regime_test.go create mode 100644 trader/lighter_trader_v2_orders_test.go create mode 100644 web/src/components/strategy/GridConfigEditor.tsx create mode 100644 web/src/components/strategy/GridRiskPanel.tsx diff --git a/api/server.go b/api/server.go index 99eeb8c0..dfe4d7b3 100644 --- a/api/server.go +++ b/api/server.go @@ -157,6 +157,7 @@ func (s *Server) setupRoutes() { protected.POST("/traders/:id/sync-balance", s.handleSyncBalance) protected.POST("/traders/:id/close-position", s.handleClosePosition) protected.PUT("/traders/:id/competition", s.handleToggleCompetition) + protected.GET("/traders/:id/grid-risk", s.handleGetGridRiskInfo) // AI model configuration protected.GET("/models", s.handleGetModelConfigs) @@ -1096,6 +1097,20 @@ func (s *Server) handleToggleCompetition(c *gin.Context) { }) } +// handleGetGridRiskInfo returns current risk information for a grid trader +func (s *Server) handleGetGridRiskInfo(c *gin.Context) { + traderID := c.Param("id") + + autoTrader, err := s.traderManager.GetTrader(traderID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "trader not found"}) + return + } + + riskInfo := autoTrader.GetGridRiskInfo() + c.JSON(http.StatusOK, riskInfo) +} + // handleSyncBalance Sync exchange balance to initial_balance (Option B: Manual Sync + Option C: Smart Detection) func (s *Server) handleSyncBalance(c *gin.Context) { userID := c.GetString("user_id") @@ -1369,7 +1384,7 @@ func (s *Server) handleClosePosition(c *gin.Context) { if closeErr != nil { logger.Infof("❌ Close position failed: symbol=%s, side=%s, error=%v", req.Symbol, req.Side, closeErr) - SafeInternalError(c, "Failed to close position", closeErr) + SafeInternalError(c, "Close position", closeErr) return } @@ -1705,8 +1720,15 @@ func (s *Server) handleUpdateModelConfigs(c *gin.Context) { logger.Infof("🔓 Decrypted model config data (UserID: %s)", userID) } - // Update each model's configuration + // Update each model's configuration and track traders that need reload + tradersToReload := make(map[string]bool) for modelID, modelData := range req.Models { + // Find traders using this AI model BEFORE updating + traders, _ := s.store.Trader().ListByAIModelID(userID, modelID) + for _, t := range traders { + tradersToReload[t.ID] = true + } + err := s.store.AIModel().Update(userID, modelID, modelData.Enabled, modelData.APIKey, modelData.CustomAPIURL, modelData.CustomModelName) if err != nil { SafeInternalError(c, fmt.Sprintf("Update model %s", modelID), err) @@ -1714,6 +1736,12 @@ func (s *Server) handleUpdateModelConfigs(c *gin.Context) { } } + // Remove affected traders from memory BEFORE reloading to pick up new config + for traderID := range tradersToReload { + logger.Infof("🔄 Removing trader %s from memory to reload with new AI model config", traderID) + s.traderManager.RemoveTrader(traderID) + } + // Reload all traders for this user to make new config take effect immediately err = s.traderManager.LoadUserTradersFromStore(s.store, userID) if err != nil { @@ -1825,8 +1853,15 @@ func (s *Server) handleUpdateExchangeConfigs(c *gin.Context) { logger.Infof("🔓 Decrypted exchange config data (UserID: %s)", userID) } - // Update each exchange's configuration + // Update each exchange's configuration and track traders that need reload + tradersToReload := make(map[string]bool) for exchangeID, exchangeData := range req.Exchanges { + // Find traders using this exchange BEFORE updating + traders, _ := s.store.Trader().ListByExchangeID(userID, exchangeID) + for _, t := range traders { + tradersToReload[t.ID] = true + } + err := s.store.Exchange().Update(userID, exchangeID, exchangeData.Enabled, exchangeData.APIKey, exchangeData.SecretKey, exchangeData.Passphrase, exchangeData.Testnet, exchangeData.HyperliquidWalletAddr, exchangeData.AsterUser, exchangeData.AsterSigner, exchangeData.AsterPrivateKey, exchangeData.LighterWalletAddr, exchangeData.LighterPrivateKey, exchangeData.LighterAPIKeyPrivateKey, exchangeData.LighterAPIKeyIndex) if err != nil { SafeInternalError(c, fmt.Sprintf("Update exchange %s", exchangeID), err) @@ -1834,6 +1869,12 @@ func (s *Server) handleUpdateExchangeConfigs(c *gin.Context) { } } + // Remove affected traders from memory BEFORE reloading to pick up new config + for traderID := range tradersToReload { + logger.Infof("🔄 Removing trader %s from memory to reload with new exchange config", traderID) + s.traderManager.RemoveTrader(traderID) + } + // Reload all traders for this user to make new config take effect immediately err = s.traderManager.LoadUserTradersFromStore(s.store, userID) if err != nil { diff --git a/cmd/lighter_test/main.go b/cmd/lighter_test/main.go new file mode 100644 index 00000000..6f896a23 --- /dev/null +++ b/cmd/lighter_test/main.go @@ -0,0 +1,233 @@ +// Lighter API Authentication Test Tool +// Usage: go run cmd/lighter_test/main.go -wallet=0x... -apikey=... [-testnet] +package main + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "io" + "net/http" + "net/url" + "os" + "time" + + lighterClient "github.com/elliottech/lighter-go/client" + lighterHTTP "github.com/elliottech/lighter-go/client/http" +) + +func main() { + // Parse command line flags + walletAddr := flag.String("wallet", "", "Ethereum wallet address") + apiKeyPrivateKey := flag.String("apikey", "", "API key private key (40 bytes hex)") + apiKeyIndex := flag.Int("apikeyindex", 0, "API key index (0-255)") + testnet := flag.Bool("testnet", false, "Use testnet instead of mainnet") + flag.Parse() + + if *walletAddr == "" || *apiKeyPrivateKey == "" { + fmt.Println("Usage: go run cmd/lighter_test/main.go -wallet=0x... -apikey=...") + fmt.Println("Options:") + fmt.Println(" -wallet Ethereum wallet address (required)") + fmt.Println(" -apikey API key private key, 40 bytes hex (required)") + fmt.Println(" -apikeyindex API key index, 0-255 (default: 0)") + fmt.Println(" -testnet Use testnet instead of mainnet") + os.Exit(1) + } + + fmt.Println("=== Lighter API Authentication Test ===") + fmt.Printf("Wallet: %s\n", *walletAddr) + fmt.Printf("API Key Index: %d\n", *apiKeyIndex) + fmt.Printf("Testnet: %v\n", *testnet) + fmt.Println() + + // Determine base URL + baseURL := "https://mainnet.zklighter.elliot.ai" + chainID := uint32(304) + if *testnet { + baseURL = "https://testnet.zklighter.elliot.ai" + chainID = uint32(300) + } + + // Create HTTP client + httpClient := lighterHTTP.NewClient(baseURL) + client := &http.Client{Timeout: 30 * time.Second} + + // Step 1: Get account info + fmt.Println("Step 1: Getting account info...") + accountInfo, err := getAccountByL1Address(client, baseURL, *walletAddr) + if err != nil { + fmt.Printf("ERROR: Failed to get account info: %v\n", err) + os.Exit(1) + } + fmt.Printf("SUCCESS: Account index = %d\n\n", accountInfo.AccountIndex) + + // Step 2: Create TxClient + fmt.Println("Step 2: Creating TxClient...") + txClient, err := lighterClient.NewTxClient( + httpClient, + *apiKeyPrivateKey, + accountInfo.AccountIndex, + uint8(*apiKeyIndex), + chainID, + ) + if err != nil { + fmt.Printf("ERROR: Failed to create TxClient: %v\n", err) + os.Exit(1) + } + fmt.Println("SUCCESS: TxClient created\n") + + // Step 3: Generate auth token + fmt.Println("Step 3: Generating auth token...") + deadline := time.Now().Add(1 * time.Hour) + authToken, err := txClient.GetAuthToken(deadline) + if err != nil { + fmt.Printf("ERROR: Failed to generate auth token: %v\n", err) + os.Exit(1) + } + fmt.Printf("SUCCESS: Auth token generated\n") + fmt.Printf("Token: %s...\n", authToken[:min(50, len(authToken))]) + fmt.Printf("Valid until: %s\n\n", deadline.Format(time.RFC3339)) + + // Step 4: Test GetActiveOrders API with auth query parameter + fmt.Println("Step 4: Testing GetActiveOrders API...") + encodedAuth := url.QueryEscape(authToken) + endpoint := fmt.Sprintf("%s/api/v1/accountActiveOrders?account_index=%d&market_id=0&auth=%s", + baseURL, accountInfo.AccountIndex, encodedAuth) + + fmt.Printf("Endpoint: %s...\n", endpoint[:min(120, len(endpoint))]) + + req, err := http.NewRequest("GET", endpoint, nil) + if err != nil { + fmt.Printf("ERROR: Failed to create request: %v\n", err) + os.Exit(1) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) + if err != nil { + fmt.Printf("ERROR: Request failed: %v\n", err) + os.Exit(1) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + fmt.Printf("Status: %d\n", resp.StatusCode) + fmt.Printf("Response: %s\n\n", string(body)) + + // Parse response + var apiResp struct { + Code int `json:"code"` + Message string `json:"message"` + Orders []struct { + OrderID string `json:"order_id"` + Side string `json:"side"` + Type string `json:"type"` + Price string `json:"price"` + } `json:"orders"` + } + if err := json.Unmarshal(body, &apiResp); err != nil { + fmt.Printf("ERROR: Failed to parse response: %v\n", err) + os.Exit(1) + } + + if apiResp.Code != 200 { + fmt.Printf("API ERROR: code=%d, message=%s\n", apiResp.Code, apiResp.Message) + fmt.Println("\n=== DIAGNOSTIC INFO ===") + fmt.Println("If you see 'invalid signature', possible causes:") + fmt.Println("1. API key is not registered on-chain") + fmt.Println("2. API key private key is incorrect") + fmt.Println("3. API key index is wrong") + fmt.Println("4. Account index mismatch") + fmt.Println("\nTo fix:") + fmt.Println("- Go to app.lighter.xyz and register/verify your API key") + fmt.Println("- Make sure you're using the correct API key private key") + os.Exit(1) + } + + fmt.Printf("SUCCESS: Retrieved %d orders\n", len(apiResp.Orders)) + for i, order := range apiResp.Orders { + if i >= 5 { + fmt.Printf("... and %d more orders\n", len(apiResp.Orders)-5) + break + } + fmt.Printf(" Order %s: %s %s @ %s\n", order.OrderID, order.Side, order.Type, order.Price) + } + + // Step 5: Test GetTrades API (also needs auth) + fmt.Println("\nStep 5: Testing GetTrades API...") + tradesEndpoint := fmt.Sprintf("%s/api/v1/trades?account_index=%d&sort_by=timestamp&sort_dir=desc&limit=5&auth=%s", + baseURL, accountInfo.AccountIndex, encodedAuth) + + tradesReq, _ := http.NewRequest("GET", tradesEndpoint, nil) + tradesResp, err := client.Do(tradesReq) + if err != nil { + fmt.Printf("ERROR: Trades request failed: %v\n", err) + } else { + defer tradesResp.Body.Close() + tradesBody, _ := io.ReadAll(tradesResp.Body) + fmt.Printf("Status: %d\n", tradesResp.StatusCode) + if tradesResp.StatusCode == 200 { + fmt.Println("SUCCESS: GetTrades API working") + } else { + fmt.Printf("Response: %s\n", string(tradesBody)) + } + } + + fmt.Println("\n=== ALL TESTS PASSED ===") +} + +// AccountInfo represents Lighter account information +type AccountInfo struct { + AccountIndex int64 `json:"account_index"` + L1Address string `json:"l1_address"` +} + +// getAccountByL1Address gets account info by L1 wallet address +func getAccountByL1Address(client *http.Client, baseURL, walletAddr string) (*AccountInfo, error) { + endpoint := fmt.Sprintf("%s/api/v1/account?by=l1_address&value=%s", baseURL, walletAddr) + + req, err := http.NewRequest("GET", endpoint, nil) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + req = req.WithContext(ctx) + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + // Parse response - can be in "accounts" or "sub_accounts" field + var apiResp struct { + Code int `json:"code"` + Message string `json:"message"` + Accounts []AccountInfo `json:"accounts"` + SubAccounts []AccountInfo `json:"sub_accounts"` + } + + if err := json.Unmarshal(body, &apiResp); err != nil { + return nil, fmt.Errorf("failed to parse response: %w, body: %s", err, string(body)) + } + + // Check main accounts first + if len(apiResp.Accounts) > 0 { + return &apiResp.Accounts[0], nil + } + + // Check sub-accounts + if len(apiResp.SubAccounts) > 0 { + return &apiResp.SubAccounts[0], nil + } + + return nil, fmt.Errorf("no account found for address: %s", walletAddr) +} diff --git a/docs/market-regime-classification-en.md b/docs/market-regime-classification-en.md new file mode 100644 index 00000000..ddc83c89 --- /dev/null +++ b/docs/market-regime-classification-en.md @@ -0,0 +1,281 @@ +# Market Regime Classification Framework + +> A comprehensive market state identification system for quantitative trading strategy matching + +--- + +## 1. Classification Dimensions Overview + +Market state identification requires analysis across multiple dimensions: + +| Dimension | Sub-dimensions | Description | +|-----------|---------------|-------------| +| **Trend** | Direction, Strength | Determine market movement direction and momentum | +| **Volatility** | Amplitude, Frequency | Measure price fluctuation characteristics | +| **Structure** | Pattern, Phase | Identify market structure and cycle position | + +--- + +## 2. Primary Classification (5 Categories) + +### 2.1 Classification Overview + +| Code | Name | Key Characteristics | Suitable Strategies | +|------|------|---------------------|---------------------| +| `TREND_UP` | Uptrend | Higher highs & higher lows | Trend following, Breakout | +| `TREND_DOWN` | Downtrend | Lower highs & lower lows | Trend following, Short selling | +| `RANGE` | Range-bound | Price oscillates within bounds | Grid trading, Mean reversion | +| `TRANSITION` | Transition | Uncertain directional period | Wait & watch, Small positions | +| `BREAKOUT` | Breakout | Price breaks key levels | Breakout trading | + +### 2.2 Identification Indicators + +- **ADX (Average Directional Index)**: Measures trend strength + - ADX > 25: Clear trend exists + - ADX < 20: Range-bound market +- **EMA Alignment**: Determines trend direction + - EMA20 > EMA50 > EMA200: Bullish alignment + - EMA20 < EMA50 < EMA200: Bearish alignment + +--- + +## 3. Secondary Classification (18 Sub-categories) + +### 3.1 Uptrend Sub-categories (5 Types) + +| Code | Name | Technical Features | Quantitative Indicators | +|------|------|-------------------|------------------------| +| `TU_STRONG_LOW_VOL` | Strong Uptrend · Low Vol | Steady rise, shallow pullbacks | ADX>40, ATR%<2%, Pullback<38.2% | +| `TU_STRONG_HIGH_VOL` | Strong Uptrend · High Vol | Rapid surge, high volatility | ADX>40, ATR%>4%, MACD histogram expanding | +| `TU_WEAK_CHOPPY` | Weak Uptrend · Choppy | Two steps forward, one back | ADX 20-30, RSI oscillating 50-70 | +| `TU_PARABOLIC` | Parabolic Acceleration | Exponential price increase | Price far from MA, RSI>80, Volume surge | +| `TU_EXHAUSTION` | Uptrend Exhaustion | New highs but weakening momentum | Price new high + MACD/RSI divergence | + +**Strategy Matching:** +- Strong Low Vol: Heavy trend following, pyramid adding +- Strong High Vol: Medium position, trailing stops +- Weak Choppy: Light swing trading +- Parabolic: Cautious, prepare to exit +- Exhaustion: Reduce positions, prepare for reversal + +### 3.2 Downtrend Sub-categories (5 Types) + +| Code | Name | Technical Features | Quantitative Indicators | +|------|------|-------------------|------------------------| +| `TD_STRONG_LOW_VOL` | Strong Downtrend · Low Vol | Steady decline, weak bounces | ADX>40, ATR%<2%, Bounce<38.2% | +| `TD_STRONG_HIGH_VOL` | Strong Downtrend · High Vol | Panic selling, wild swings | ADX>40, ATR%>5%, VIX spike | +| `TD_WEAK_CHOPPY` | Weak Downtrend · Choppy | Grinding lower with bounces | ADX 20-30, RSI oscillating 30-50 | +| `TD_CAPITULATION` | Capitulation | High volume crash, extreme fear | RSI<20, Volume>3x average | +| `TD_EXHAUSTION` | Downtrend Exhaustion | New lows but selling pressure fading | Price new low + MACD/RSI divergence | + +**Strategy Matching:** +- Strong Low Vol: Short trend following +- Strong High Vol: Stay flat or light hedge +- Weak Choppy: Wait for stabilization +- Capitulation: Light bottom fishing possible +- Exhaustion: Gradually build long positions + +### 3.3 Range Sub-categories (4 Types) + +| Code | Name | Technical Features | Quantitative Indicators | +|------|------|-------------------|------------------------| +| `RG_TIGHT_LOW_VOL` | Tight Range · Low Vol | Extreme contraction, coiling | BB Width<2%, ATR at new lows | +| `RG_TIGHT_HIGH_VOL` | Tight Range · High Vol | Violent swings within range | BB Width<3%, ATR%>3% | +| `RG_WIDE_LOW_VOL` | Wide Range · Low Vol | Large range, slow movement | BB Width>5%, ATR%<2% | +| `RG_WIDE_HIGH_VOL` | Wide Range · High Vol | Large range, fast movement | BB Width>5%, ATR%>3% | + +**Strategy Matching:** +- Tight Low Vol: Dense grid, wait for breakout +- Tight High Vol: Fast grid, small frequent profits +- Wide Low Vol: Sparse grid, patient holding +- Wide High Vol: Swing trading, high profit targets + +### 3.4 Transition (2 Types) + +| Code | Name | Technical Features | Quantitative Indicators | +|------|------|-------------------|------------------------| +| `TR_BOTTOM_FORMING` | Bottom Forming | Decline slowing, testing support | Price stabilizing + Volume drying up + RSI divergence | +| `TR_TOP_FORMING` | Top Forming | Rally slowing, testing resistance | Price stalling + Volume drying up + RSI divergence | + +### 3.5 Breakout (2 Types) + +| Code | Name | Technical Features | Quantitative Indicators | +|------|------|-------------------|------------------------| +| `BK_UPWARD` | Upward Breakout | Breaking resistance with volume | Price>Previous high, Volume>2x, BB breakout | +| `BK_DOWNWARD` | Downward Breakout | Breaking support with volume | Price2x, BB breakdown | + +--- + +## 4. Tertiary Classification (36 Ultra-fine Categories) + +### 4.1 Trend Phase Classification + +Uptrend lifecycle consists of 5 phases: + +| Phase Code | Name | Description | Quantitative Criteria | +|------------|------|-------------|----------------------| +| `TU_S1_INITIATION` | Uptrend Initiation | First break above MA or previous high | MACD bullish cross, Price>EMA20 | +| `TU_S2_ACCELERATION` | Uptrend Acceleration | Momentum increasing, slope steepening | MACD histogram expanding, ADX rising | +| `TU_S3_MAIN_WAVE` | Main Wave | Sustained rise, shallow pullbacks | RSI 60-80, Pullbacks hold EMA20 | +| `TU_S4_EXHAUSTION` | Uptrend Exhaustion | Slowing momentum, divergences appearing | RSI divergence, MACD divergence | +| `TU_S5_REVERSAL` | Trend Reversal | Breakdown, trend ending | Break below EMA50, MACD bearish cross | + +Downtrend phases follow same pattern: `TD_S1` through `TD_S5` + +### 4.2 Range Position Classification + +| Position Code | Name | Description | Strategy Suggestion | +|---------------|------|-------------|---------------------| +| `RG_UPPER` | Upper Range | Price near resistance | Bias toward short | +| `RG_MIDDLE` | Mid Range | Price near middle band | Neutral grid trading | +| `RG_LOWER` | Lower Range | Price near support | Bias toward long | +| `RG_SQUEEZE` | Squeeze Pattern | Highs and lows converging | Wait for direction | +| `RG_EXPAND` | Expanding Pattern | Highs and lows diverging | Boundary reversal | + +### 4.3 Volatility Grades + +| Code | Name | ATR% | BB Width | Strategy Suggestion | +|------|------|------|----------|---------------------| +| `VOL_EXTREME_LOW` | Extreme Low Vol | <1% | <1.5% | Option selling | +| `VOL_LOW` | Low Volatility | 1-2% | 1.5-2.5% | Grid / Mean reversion | +| `VOL_NORMAL` | Normal Volatility | 2-3% | 2.5-4% | Trend following | +| `VOL_HIGH` | High Volatility | 3-5% | 4-6% | Momentum / Breakout | +| `VOL_EXTREME_HIGH` | Extreme High Vol | >5% | >6% | Reduce exposure / Hedge | + +--- + +## 5. Complete State Encoding Rules + +### 5.1 Encoding Format + +``` +{Primary}_{Volatility}_{Phase}_{Position} +``` + +### 5.2 Encoding Examples + +| Full Code | Interpretation | +|-----------|----------------| +| `TU_LV_S3_M` | Uptrend_LowVol_MainWave_Middle | +| `TD_HV_S2_L` | Downtrend_HighVol_Acceleration_Lower | +| `RG_NV_SQ_U` | Range_NormalVol_Squeeze_Upper | +| `BK_HV_UP_M` | Breakout_HighVol_Upward_Middle | + +--- + +## 6. Core Identification Indicators + +### 6.1 Trend Indicators + +| Indicator | Calculation | Criteria | +|-----------|-------------|----------| +| ADX | 14-period Average Directional Index | >40 Strong, 25-40 Medium, <25 Weak/Range | +| Trend Score | Composite EMA/MACD/Price structure | -100 to +100, Positive=Bullish, Negative=Bearish | +| EMA Alignment | Relative position of EMA20/50/200 | Bullish/Bearish/Mixed alignment | + +### 6.2 Volatility Indicators + +| Indicator | Calculation | Purpose | +|-----------|-------------|---------| +| ATR Percent | ATR(14) / Current Price × 100% | Measure relative volatility | +| BB Width | (Upper - Lower) / Middle × 100% | Measure price range | +| Volatility Rank | Current vol percentile in history | Determine vol level | + +### 6.3 Momentum Indicators + +| Indicator | Calculation | Criteria | +|-----------|-------------|----------| +| RSI | 14-period Relative Strength Index | >70 Overbought, <30 Oversold, 50 Neutral | +| MACD Histogram | MACD - Signal | Positive=Bullish momentum, Negative=Bearish | +| Momentum Score | Composite RSI/MACD/Volume | Measure current momentum | + +### 6.4 Structure Indicators + +| Indicator | Description | Purpose | +|-----------|-------------|---------| +| Swing Structure | HH/HL/LH/LL sequence | Determine trend structure | +| Support/Resistance | Key price levels | Define trading range | +| Volume Profile | Volume-price relationship | Validate price action | + +--- + +## 7. Strategy Matching Matrix + +### 7.1 Regime-Strategy Mapping + +| Regime Type | Recommended Strategy | Position Size | Stop Loss | +|-------------|---------------------|---------------|-----------| +| Strong Uptrend · Low Vol | Trend following + Pyramid | 60-80% | ATR×2 | +| Strong Uptrend · High Vol | Momentum + Quick profit | 40-60% | ATR×1.5 | +| Uptrend Exhaustion | Reduce + Reversal short | 20-30% | Previous high | +| Panic Decline | Wait or light bottom fish | 10-20% | Wide stop | +| Low Vol Range | Grid trading | 50-70% | Range boundary | +| High Vol Range | Swing trading | 30-50% | ATR×2 | +| Squeeze Pattern | Wait for breakout | 10-20% | - | +| Upward Breakout | Chase + Add on pullback | 50-70% | Breakout level | +| Bottom Formation | Scale in gradually | 20-40% | New low | + +### 7.2 Grid Strategy Parameter Matching + +| Range Type | Grid Levels | Grid Spacing | Other Parameters | +|------------|-------------|--------------|------------------| +| Tight Low Vol | 30-50 levels | Small spacing | Enable Maker Only | +| Tight High Vol | 15-25 levels | Small spacing | Fast execution mode | +| Wide Low Vol | 10-20 levels | Large spacing | Patient execution | +| Wide High Vol | 15-25 levels | Large spacing | High profit targets | +| Squeeze Pattern | Pause grid | - | Wait for breakout signal | +| Upper Range | Short bias | Medium | Increase sell weight | +| Lower Range | Long bias | Medium | Increase buy weight | + +--- + +## 8. Real-time Monitoring Guidelines + +### 8.1 State Transition Triggers + +| Current State | Trigger Condition | Transitions To | +|---------------|-------------------|----------------| +| Range | Price breakout + Volume + ADX rising | Breakout | +| Uptrend | RSI divergence + Volume decline | Exhaustion | +| Downtrend | RSI divergence + Volume decline | Exhaustion | +| Breakout | Failed breakout, price returns | Range | +| Exhaustion | Confirmed reversal breakout | Opposite trend | + +### 8.2 Risk Control Rules + +| Regime State | Max Position | Risk Per Trade | Special Rules | +|--------------|--------------|----------------|---------------| +| Strong Trend | 80% | 2% | Adding allowed | +| Weak Trend | 50% | 1.5% | No adding | +| Range | 60% | 1% | Diversified holding | +| Transition | 30% | 1% | Reduce activity | +| High Volatility | 40% | 0.5% | Wide stops | + +--- + +## 9. Appendix + +### 9.1 Abbreviation Reference + +| Abbrev | Full Form | Description | +|--------|-----------|-------------| +| TU | Trend Up | Upward trend | +| TD | Trend Down | Downward trend | +| RG | Range | Range-bound market | +| TR | Transition | Trend transition | +| BK | Breakout | Breakout pattern | +| LV | Low Volatility | Low volatility regime | +| HV | High Volatility | High volatility regime | +| NV | Normal Volatility | Normal volatility regime | +| XLV | Extreme Low Vol | Extremely low volatility | +| XHV | Extreme High Vol | Extremely high volatility | + +### 9.2 Document Information + +- Version: v1.0 +- Created: January 2026 +- Applicable: Cryptocurrency, Forex, Stocks, and other financial markets + +--- + +*This document is designed for market state identification and strategy matching in quantitative trading systems* diff --git a/docs/market-regime-classification-zh.md b/docs/market-regime-classification-zh.md new file mode 100644 index 00000000..36e0092d --- /dev/null +++ b/docs/market-regime-classification-zh.md @@ -0,0 +1,281 @@ +# 市场行情精细分类体系 + +> 用于量化交易策略匹配的市场状态识别框架 + +--- + +## 一、分类维度概览 + +市场状态识别需要从多个维度进行分析: + +| 维度 | 子维度 | 说明 | +|------|--------|------| +| **趋势维度** | 方向、强度 | 判断市场运动方向和力度 | +| **波动维度** | 幅度、频率 | 衡量价格波动特征 | +| **结构维度** | 形态、阶段 | 识别市场结构和所处周期 | + +--- + +## 二、一级分类(5大类) + +### 2.1 分类总览 + +| 代码 | 名称 | 核心特征 | 适合策略 | +|------|------|----------|----------| +| `TREND_UP` | 上涨趋势 | 高点/低点持续抬升 | 趋势跟踪、突破追涨 | +| `TREND_DOWN` | 下跌趋势 | 高点/低点持续降低 | 趋势跟踪、做空策略 | +| `RANGE` | 震荡区间 | 价格在区间内波动 | 网格交易、均值回归 | +| `TRANSITION` | 趋势转换 | 方向不明确的过渡期 | 观望、小仓位试探 | +| `BREAKOUT` | 突破行情 | 价格突破关键位置 | 突破追踪策略 | + +### 2.2 识别指标 + +- **ADX(平均方向指数)**:衡量趋势强度 + - ADX > 25:存在明确趋势 + - ADX < 20:震荡市场 +- **EMA排列**:判断趋势方向 + - EMA20 > EMA50 > EMA200:多头排列 + - EMA20 < EMA50 < EMA200:空头排列 + +--- + +## 三、二级分类(18细分类) + +### 3.1 上涨趋势细分(5种) + +| 代码 | 名称 | 技术特征 | 量化指标 | +|------|------|----------|----------| +| `TU_STRONG_LOW_VOL` | 强势上涨·低波动 | 稳步上涨,回调幅度小 | ADX>40, ATR%<2%, 回调<38.2% | +| `TU_STRONG_HIGH_VOL` | 强势上涨·高波动 | 快速拉升,波动剧烈 | ADX>40, ATR%>4%, MACD柱放大 | +| `TU_WEAK_CHOPPY` | 弱势上涨·震荡 | 涨三退二,反复磨蹭 | ADX 20-30, RSI在50-70震荡 | +| `TU_PARABOLIC` | 抛物线加速 | 指数级加速上涨 | 价格远离均线, RSI>80, 成交量放大 | +| `TU_EXHAUSTION` | 上涨衰竭 | 创新高但动能减弱 | 价格新高 + MACD/RSI顶背离 | + +**策略匹配:** +- 强势低波动:重仓趋势跟踪,金字塔加仓 +- 强势高波动:中等仓位,设置移动止盈 +- 弱势震荡:轻仓波段,高抛低吸 +- 抛物线加速:谨慎追涨,准备离场 +- 上涨衰竭:减仓观望,准备反转做空 + +### 3.2 下跌趋势细分(5种) + +| 代码 | 名称 | 技术特征 | 量化指标 | +|------|------|----------|----------| +| `TD_STRONG_LOW_VOL` | 强势下跌·低波动 | 稳步下跌,反弹无力 | ADX>40, ATR%<2%, 反弹<38.2% | +| `TD_STRONG_HIGH_VOL` | 强势下跌·高波动 | 恐慌抛售,波动剧烈 | ADX>40, ATR%>5%, 恐慌指数飙升 | +| `TD_WEAK_CHOPPY` | 弱势下跌·震荡 | 跌跌涨涨,磨底过程 | ADX 20-30, RSI在30-50震荡 | +| `TD_CAPITULATION` | 恐慌投降 | 放量暴跌,情绪极端 | RSI<20, 成交量>3倍均量 | +| `TD_EXHAUSTION` | 下跌衰竭 | 创新低但卖压减弱 | 价格新低 + MACD/RSI底背离 | + +**策略匹配:** +- 强势低波动:空头趋势跟踪 +- 强势高波动:观望或轻仓对冲 +- 弱势震荡:等待企稳信号 +- 恐慌投降:极端情况可轻仓抄底 +- 下跌衰竭:逐步建立多头仓位 + +### 3.3 震荡区间细分(4种) + +| 代码 | 名称 | 技术特征 | 量化指标 | +|------|------|----------|----------| +| `RG_TIGHT_LOW_VOL` | 窄幅震荡·低波动 | 极度收敛,蓄势待发 | 布林带宽度<2%, ATR创新低 | +| `RG_TIGHT_HIGH_VOL` | 窄幅震荡·高波动 | 区间内剧烈波动 | 布林带宽度<3%, ATR%>3% | +| `RG_WIDE_LOW_VOL` | 宽幅震荡·低波动 | 大区间慢速波动 | 布林带宽度>5%, ATR%<2% | +| `RG_WIDE_HIGH_VOL` | 宽幅震荡·高波动 | 大区间快速波动 | 布林带宽度>5%, ATR%>3% | + +**策略匹配:** +- 窄幅低波动:密集网格,等待突破 +- 窄幅高波动:快速网格,小利润多次 +- 宽幅低波动:稀疏网格,耐心持有 +- 宽幅高波动:波段交易,高利润目标 + +### 3.4 转换过渡(2种) + +| 代码 | 名称 | 技术特征 | 量化指标 | +|------|------|----------|----------| +| `TR_BOTTOM_FORMING` | 底部形成中 | 下跌放缓,试探支撑 | 价格止跌 + 成交量萎缩 + RSI底背离 | +| `TR_TOP_FORMING` | 顶部形成中 | 上涨放缓,试探压力 | 价格滞涨 + 成交量萎缩 + RSI顶背离 | + +### 3.5 突破行情(2种) + +| 代码 | 名称 | 技术特征 | 量化指标 | +|------|------|----------|----------| +| `BK_UPWARD` | 向上突破 | 突破阻力位并放量 | 价格>前高, 成交量>2倍, 布林带突破 | +| `BK_DOWNWARD` | 向下突破 | 跌破支撑位并放量 | 价格<前低, 成交量>2倍, 布林带跌破 | + +--- + +## 四、三级分类(36超细分类) + +### 4.1 趋势阶段细分 + +上涨趋势生命周期分为5个阶段: + +| 阶段代码 | 名称 | 特征描述 | 量化判断标准 | +|----------|------|----------|--------------| +| `TU_S1_INITIATION` | 上涨启动期 | 首次突破均线或前高 | MACD金叉, 价格突破EMA20 | +| `TU_S2_ACCELERATION` | 上涨加速期 | 动能增强,斜率加大 | MACD柱持续增大, ADX上升 | +| `TU_S3_MAIN_WAVE` | 主升浪阶段 | 持续上涨,回调幅度浅 | RSI维持60-80, 回调不破EMA20 | +| `TU_S4_EXHAUSTION` | 上涨衰竭期 | 涨速放缓,出现背离 | RSI顶背离, MACD顶背离 | +| `TU_S5_REVERSAL` | 趋势反转期 | 破位下跌,趋势结束 | 跌破EMA50, MACD死叉 | + +下跌趋势同理,代码为 `TD_S1` 至 `TD_S5` + +### 4.2 震荡位置细分 + +| 位置代码 | 名称 | 特征描述 | 策略建议 | +|----------|------|----------|----------| +| `RG_UPPER` | 区间上沿震荡 | 价格接近阻力位 | 偏空操作为主 | +| `RG_MIDDLE` | 区间中部震荡 | 价格在中轨附近 | 双向网格交易 | +| `RG_LOWER` | 区间下沿震荡 | 价格接近支撑位 | 偏多操作为主 | +| `RG_SQUEEZE` | 收敛三角震荡 | 高低点逐渐收窄 | 等待方向选择 | +| `RG_EXPAND` | 扩散三角震荡 | 高低点逐渐扩张 | 边界反转操作 | + +### 4.3 波动率等级 + +| 代码 | 名称 | ATR百分比 | 布林带宽度 | 策略建议 | +|------|------|-----------|------------|----------| +| `VOL_EXTREME_LOW` | 极低波动 | <1% | <1.5% | 期权卖方策略 | +| `VOL_LOW` | 低波动 | 1-2% | 1.5-2.5% | 网格/均值回归 | +| `VOL_NORMAL` | 正常波动 | 2-3% | 2.5-4% | 趋势跟踪 | +| `VOL_HIGH` | 高波动 | 3-5% | 4-6% | 动量/突破 | +| `VOL_EXTREME_HIGH` | 极高波动 | >5% | >6% | 减仓/对冲 | + +--- + +## 五、完整状态编码规则 + +### 5.1 编码格式 + +``` +{一级分类}_{波动等级}_{阶段}_{位置} +``` + +### 5.2 编码示例 + +| 完整代码 | 含义解释 | +|----------|----------| +| `TU_LV_S3_M` | 上涨趋势_低波动_主升浪_中部位置 | +| `TD_HV_S2_L` | 下跌趋势_高波动_加速期_下部位置 | +| `RG_NV_SQ_U` | 震荡区间_正常波动_收敛形态_上沿位置 | +| `BK_HV_UP_M` | 突破行情_高波动_向上突破_中部位置 | + +--- + +## 六、核心识别指标 + +### 6.1 趋势指标 + +| 指标 | 计算方法 | 判断标准 | +|------|----------|----------| +| ADX | 14周期平均方向指数 | >40强趋势, 25-40中等, <25弱/震荡 | +| 趋势评分 | 综合EMA/MACD/价格结构 | -100到+100, 正数多头,负数空头 | +| EMA排列 | EMA20/50/200相对位置 | 多头排列/空头排列/混乱 | + +### 6.2 波动指标 + +| 指标 | 计算方法 | 用途 | +|------|----------|------| +| ATR百分比 | ATR(14) / 当前价格 × 100% | 衡量相对波动幅度 | +| 布林带宽度 | (上轨-下轨) / 中轨 × 100% | 衡量价格波动区间 | +| 波动率排名 | 当前波动在历史中的分位 | 判断波动率高低 | + +### 6.3 动量指标 + +| 指标 | 计算方法 | 判断标准 | +|------|----------|----------| +| RSI | 14周期相对强弱指数 | >70超买, <30超卖, 50中性 | +| MACD柱 | MACD - Signal | 正数多头动能,负数空头动能 | +| 动量评分 | 综合RSI/MACD/成交量 | 衡量当前动能强弱 | + +### 6.4 结构指标 + +| 指标 | 说明 | 用途 | +|------|------|------| +| 高低点结构 | HH/HL/LH/LL序列 | 判断趋势结构 | +| 支撑阻力位 | 关键价格水平 | 确定交易区间 | +| 成交量形态 | 量价配合关系 | 验证价格走势 | + +--- + +## 七、策略匹配矩阵 + +### 7.1 行情类型与策略对应 + +| 行情类型 | 推荐策略 | 建议仓位 | 止损设置 | +|----------|----------|----------|----------| +| 强势上涨·低波动 | 趋势跟踪+金字塔加仓 | 60-80% | ATR×2 | +| 强势上涨·高波动 | 动量突破+快速止盈 | 40-60% | ATR×1.5 | +| 上涨衰竭期 | 减仓+反转信号做空 | 20-30% | 前高 | +| 恐慌下跌 | 观望或轻仓抄底 | 10-20% | 宽止损 | +| 低波动震荡 | 网格交易 | 50-70% | 区间边界 | +| 高波动震荡 | 波段高抛低吸 | 30-50% | ATR×2 | +| 收敛等待 | 蓄势等突破 | 10-20% | - | +| 向上突破 | 追涨+回踩加仓 | 50-70% | 突破位 | +| 底部形成 | 分批建仓 | 20-40% | 新低 | + +### 7.2 网格策略参数匹配 + +| 震荡类型 | 网格层数 | 网格间距 | 其他参数 | +|----------|----------|----------|----------| +| 窄幅低波动 | 30-50层 | 小间距 | 启用Maker Only | +| 窄幅高波动 | 15-25层 | 小间距 | 快速成交模式 | +| 宽幅低波动 | 10-20层 | 大间距 | 耐心等待成交 | +| 宽幅高波动 | 15-25层 | 大间距 | 高利润目标 | +| 收敛形态 | 暂停网格 | - | 等待突破信号 | +| 区间上沿 | 偏空配置 | 中等 | 卖单权重增加 | +| 区间下沿 | 偏多配置 | 中等 | 买单权重增加 | + +--- + +## 八、实时监控建议 + +### 8.1 状态转换触发条件 + +| 当前状态 | 触发条件 | 转换到 | +|----------|----------|--------| +| 震荡区间 | 价格突破+放量+ADX上升 | 突破行情 | +| 上涨趋势 | RSI顶背离+成交量萎缩 | 上涨衰竭 | +| 下跌趋势 | RSI底背离+成交量萎缩 | 下跌衰竭 | +| 突破行情 | 突破失败回落 | 震荡区间 | +| 趋势衰竭 | 反向突破确认 | 反向趋势 | + +### 8.2 风险控制规则 + +| 行情状态 | 最大仓位 | 单笔风险 | 特殊规则 | +|----------|----------|----------|----------| +| 强趋势 | 80% | 2% | 可加仓 | +| 弱趋势 | 50% | 1.5% | 不加仓 | +| 震荡 | 60% | 1% | 分散持仓 | +| 转换期 | 30% | 1% | 减少操作 | +| 高波动 | 40% | 0.5% | 宽止损 | + +--- + +## 九、附录 + +### 9.1 缩写对照表 + +| 缩写 | 英文全称 | 中文含义 | +|------|----------|----------| +| TU | Trend Up | 上涨趋势 | +| TD | Trend Down | 下跌趋势 | +| RG | Range | 震荡区间 | +| TR | Transition | 趋势转换 | +| BK | Breakout | 突破行情 | +| LV | Low Volatility | 低波动 | +| HV | High Volatility | 高波动 | +| NV | Normal Volatility | 正常波动 | +| XLV | Extreme Low Vol | 极低波动 | +| XHV | Extreme High Vol | 极高波动 | + +### 9.2 版本信息 + +- 文档版本:v1.0 +- 创建日期:2026年1月 +- 适用范围:加密货币、外汇、股票等金融市场 + +--- + +*本文档用于量化交易系统的市场状态识别和策略匹配* diff --git a/docs/plans/2026-01-14-grid-trading-fixes.md b/docs/plans/2026-01-14-grid-trading-fixes.md new file mode 100644 index 00000000..48d37435 --- /dev/null +++ b/docs/plans/2026-01-14-grid-trading-fixes.md @@ -0,0 +1,1072 @@ +# AI自适应网格交易系统修复计划 + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** 修复AI网格交易系统的所有致命和严重问题,添加代码级风控保护机制。 + +**Architecture:** +1. 在AI决策和订单执行之间添加风控验证层 +2. 实现代码级止损、仓位限制、突破检测 +3. 修复杠杆设置和订单取消的BUG +4. 添加自动网格调整机制 + +**Tech Stack:** Go, GORM, 交易所API接口 + +--- + +## 问题优先级 + +| 优先级 | 问题 | Task | +|--------|------|------| +| P0 致命 | 杠杆未生效 | Task 1 | +| P0 致命 | 取消订单逻辑错误 | Task 2 | +| P0 致命 | 无总仓位限制 | Task 3 | +| P1 严重 | 无止损执行 | Task 4 | +| P1 严重 | 无突破检测 | Task 5 | +| P1 严重 | MaxDrawdown未执行 | Task 6 | +| P1 严重 | DailyLossLimit未执行 | Task 7 | +| P2 中等 | 无动态调整 | Task 8 | +| P2 中等 | 订单状态同步错误 | Task 9 | + +--- + +## Task 1: 修复杠杆设置BUG + +**问题:** `PlaceLimitOrder` 完全忽略 `Leverage` 字段,从未调用 `SetLeverage()` + +**Files:** +- Modify: `trader/interface.go:171-194` +- Modify: `trader/auto_trader_grid.go:324-409` +- Create: `trader/grid_test.go` (新增测试) + +### Step 1.1: 在 GridTraderAdapter.PlaceLimitOrder 中添加杠杆设置 + +修改 `trader/interface.go`: + +```go +// PlaceLimitOrder implements limit order using available methods +// For exchanges without native limit order support, this uses conditional orders +func (a *GridTraderAdapter) PlaceLimitOrder(req *LimitOrderRequest) (*LimitOrderResult, error) { + // CRITICAL FIX: Set leverage before placing order + if req.Leverage > 0 { + if err := a.Trader.SetLeverage(req.Symbol, req.Leverage); err != nil { + logger.Warnf("[Grid] Failed to set leverage %dx: %v", req.Leverage, err) + // Continue anyway - some exchanges don't require explicit leverage setting + } + } + + // Use SetStopLoss/SetTakeProfit as conditional limit orders + // For buy orders below current price, use stop-loss mechanism + // For sell orders above current price, use take-profit mechanism + var err error + if req.Side == "BUY" { + err = a.Trader.SetStopLoss(req.Symbol, "SHORT", req.Quantity, req.Price) + } else { + err = a.Trader.SetTakeProfit(req.Symbol, "LONG", req.Quantity, req.Price) + } + if err != nil { + return nil, err + } + return &LimitOrderResult{ + OrderID: req.ClientID, + ClientID: req.ClientID, + Symbol: req.Symbol, + Side: req.Side, + PositionSide: req.PositionSide, + Price: req.Price, + Quantity: req.Quantity, + Status: "NEW", + }, nil +} +``` + +### Step 1.2: 在 InitializeGrid 中设置杠杆 + +修改 `trader/auto_trader_grid.go`, 在 `InitializeGrid()` 函数末尾添加: + +```go +// InitializeGrid initializes the grid state and calculates levels +func (at *AutoTrader) InitializeGrid() error { + // ... 现有代码 ... + + at.gridState.IsInitialized = true + + // CRITICAL: Set leverage on exchange before trading + if err := at.trader.SetLeverage(gridConfig.Symbol, gridConfig.Leverage); err != nil { + logger.Warnf("[Grid] Failed to set leverage %dx on exchange: %v", gridConfig.Leverage, err) + // Not fatal - continue with default leverage + } else { + logger.Infof("[Grid] Leverage set to %dx for %s", gridConfig.Leverage, gridConfig.Symbol) + } + + logger.Infof("[Grid] Initialized: %d levels, $%.2f - $%.2f, spacing $%.2f", + gridConfig.GridCount, at.gridState.LowerPrice, at.gridState.UpperPrice, at.gridState.GridSpacing) + + return nil +} +``` + +### Step 1.3: 运行测试验证 + +```bash +go build ./trader/ +go test -v -run "TestLighter.*Leverage" ./trader/ -timeout 60s +``` + +### Step 1.4: 提交 + +```bash +git add trader/interface.go trader/auto_trader_grid.go +git commit -m "fix(grid): add leverage setting before order placement + +CRITICAL BUG FIX: +- Call SetLeverage() in GridTraderAdapter.PlaceLimitOrder() +- Set leverage during grid initialization +- Log leverage setting results" +``` + +--- + +## Task 2: 修复订单取消逻辑BUG + +**问题:** `GridTraderAdapter.CancelOrder()` 错误地调用 `CancelAllOrders()` + +**Files:** +- Modify: `trader/interface.go:196-200` + +### Step 2.1: 修复 CancelOrder 实现 + +修改 `trader/interface.go`: + +```go +// CancelOrder cancels a specific order +func (a *GridTraderAdapter) CancelOrder(symbol, orderID string) error { + // Try to use CancelOrder if trader supports it directly + if canceler, ok := a.Trader.(interface { + CancelOrder(symbol, orderID string) error + }); ok { + return canceler.CancelOrder(symbol, orderID) + } + + // For traders that only support CancelAllOrders, log a warning + // This is a limitation - we cannot cancel individual orders + logger.Warnf("[Grid] Trader does not support individual order cancellation, " + + "cannot cancel order %s. Consider using exchange-specific GridTrader implementation.", orderID) + + // Return error instead of canceling all orders + return fmt.Errorf("individual order cancellation not supported for this exchange") +} +``` + +### Step 2.2: 添加 fmt import (如果缺失) + +确保 `trader/interface.go` 顶部有: +```go +import ( + "fmt" + // ... 其他imports +) +``` + +### Step 2.3: 运行测试验证 + +```bash +go build ./trader/ +``` + +### Step 2.4: 提交 + +```bash +git add trader/interface.go +git commit -m "fix(grid): prevent CancelOrder from canceling all orders + +CRITICAL BUG FIX: +- CancelOrder no longer calls CancelAllOrders +- Try exchange-specific CancelOrder if available +- Return error if individual cancellation not supported" +``` + +--- + +## Task 3: 添加总仓位限制 + +**问题:** 只检查单层仓位,不检查总仓位,导致可能开出巨额仓位 + +**Files:** +- Modify: `trader/auto_trader_grid.go:324-409` +- Modify: `trader/auto_trader_grid.go` (新增 `checkTotalPositionLimit` 函数) + +### Step 3.1: 添加总仓位检查函数 + +在 `trader/auto_trader_grid.go` 中 `placeGridLimitOrder` 函数之前添加: + +```go +// checkTotalPositionLimit checks if adding a new position would exceed total limits +// Returns: (allowed bool, currentPositionValue float64, maxAllowed float64) +func (at *AutoTrader) checkTotalPositionLimit(symbol string, additionalValue float64) (bool, float64, float64) { + gridConfig := at.config.StrategyConfig.GridConfig + + // Calculate max allowed total position value + // Total position should not exceed: TotalInvestment × Leverage + maxTotalPositionValue := gridConfig.TotalInvestment * float64(gridConfig.Leverage) + + // Get current position value from exchange + currentPositionValue := 0.0 + positions, err := at.trader.GetPositions() + if err == nil { + for _, pos := range positions { + if sym, ok := pos["symbol"].(string); ok && sym == symbol { + if size, ok := pos["positionAmt"].(float64); ok { + if price, ok := pos["markPrice"].(float64); ok { + currentPositionValue = math.Abs(size) * price + } else if entryPrice, ok := pos["entryPrice"].(float64); ok { + currentPositionValue = math.Abs(size) * entryPrice + } + } + } + } + } + + // Also count pending orders as potential position + at.gridState.mu.RLock() + pendingValue := 0.0 + for _, level := range at.gridState.Levels { + if level.State == "pending" { + pendingValue += level.OrderQuantity * level.Price + } + } + at.gridState.mu.RUnlock() + + totalAfterOrder := currentPositionValue + pendingValue + additionalValue + allowed := totalAfterOrder <= maxTotalPositionValue + + return allowed, currentPositionValue + pendingValue, maxTotalPositionValue +} +``` + +### Step 3.2: 在 placeGridLimitOrder 中使用总仓位检查 + +修改 `trader/auto_trader_grid.go` 的 `placeGridLimitOrder` 函数,在现有检查之后添加: + +```go +func (at *AutoTrader) placeGridLimitOrder(d *kernel.Decision, side string) error { + // ... 现有代码到 line 377 ... + + // CRITICAL: Check total position limit before placing order + orderValue := quantity * d.Price + allowed, currentValue, maxValue := at.checkTotalPositionLimit(d.Symbol, orderValue) + if !allowed { + logger.Errorf("[Grid] TOTAL POSITION LIMIT EXCEEDED: current=$%.2f + order=$%.2f > max=$%.2f. Rejecting order.", + currentValue, orderValue, maxValue) + return fmt.Errorf("total position value $%.2f would exceed limit $%.2f", currentValue+orderValue, maxValue) + } + + req := &LimitOrderRequest{ + // ... 现有代码 ... + } + // ... 其余代码 ... +} +``` + +### Step 3.3: 运行测试验证 + +```bash +go build ./trader/ +``` + +### Step 3.4: 提交 + +```bash +git add trader/auto_trader_grid.go +git commit -m "fix(grid): add total position value limit check + +CRITICAL: Prevent excessive position accumulation +- New checkTotalPositionLimit() function +- Checks current + pending + new order value +- Rejects orders that would exceed TotalInvestment × Leverage +- Logs clear error messages when limit exceeded" +``` + +--- + +## Task 4: 添加止损执行机制 + +**问题:** `StopLossPct` 存在于配置但从未使用 + +**Files:** +- Modify: `trader/auto_trader_grid.go` (添加 `checkAndExecuteStopLoss` 函数) +- Modify: `trader/auto_trader_grid.go:504-565` (在 `syncGridState` 中调用) + +### Step 4.1: 添加止损检查和执行函数 + +在 `trader/auto_trader_grid.go` 中添加: + +```go +// checkAndExecuteStopLoss checks if any filled level has exceeded stop loss and closes it +func (at *AutoTrader) checkAndExecuteStopLoss() { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig.StopLossPct <= 0 { + return // Stop loss not configured + } + + currentPrice, err := at.trader.GetMarketPrice(gridConfig.Symbol) + if err != nil { + logger.Warnf("[Grid] Failed to get market price for stop loss check: %v", err) + return + } + + at.gridState.mu.Lock() + defer at.gridState.mu.Unlock() + + for i := range at.gridState.Levels { + level := &at.gridState.Levels[i] + if level.State != "filled" || level.PositionEntry <= 0 { + continue + } + + // Calculate loss percentage + var lossPct float64 + if level.Side == "buy" { + // Long position: loss when price drops + lossPct = (level.PositionEntry - currentPrice) / level.PositionEntry * 100 + } else { + // Short position: loss when price rises + lossPct = (currentPrice - level.PositionEntry) / level.PositionEntry * 100 + } + + // Check if stop loss triggered + if lossPct >= gridConfig.StopLossPct { + logger.Warnf("[Grid] STOP LOSS TRIGGERED: Level %d, entry=$%.2f, current=$%.2f, loss=%.2f%%", + i, level.PositionEntry, currentPrice, lossPct) + + // Close the position + var closeErr error + if level.Side == "buy" { + _, closeErr = at.trader.CloseLong(gridConfig.Symbol, level.PositionSize) + } else { + _, closeErr = at.trader.CloseShort(gridConfig.Symbol, level.PositionSize) + } + + if closeErr != nil { + logger.Errorf("[Grid] Failed to execute stop loss for level %d: %v", i, closeErr) + } else { + level.State = "stopped" + level.UnrealizedPnL = -lossPct * level.AllocatedUSD / 100 + at.gridState.TotalTrades++ + logger.Infof("[Grid] Stop loss executed: Level %d closed at $%.2f (loss %.2f%%)", + i, currentPrice, lossPct) + } + } + } +} +``` + +### Step 4.2: 在 syncGridState 中调用止损检查 + +修改 `trader/auto_trader_grid.go` 的 `syncGridState` 函数末尾: + +```go +func (at *AutoTrader) syncGridState() { + // ... 现有代码 ... + + logger.Debugf("[Grid] Synced state: position=%.4f, orders=%d", totalPosition, len(openOrders)) + + // CRITICAL: Check stop loss for filled levels + at.checkAndExecuteStopLoss() +} +``` + +### Step 4.3: 运行测试验证 + +```bash +go build ./trader/ +``` + +### Step 4.4: 提交 + +```bash +git add trader/auto_trader_grid.go +git commit -m "feat(grid): implement stop loss execution + +CRITICAL: Add code-level stop loss protection +- New checkAndExecuteStopLoss() function +- Checks each filled level against StopLossPct +- Automatically closes positions exceeding stop loss +- Called during every grid state sync" +``` + +--- + +## Task 5: 添加突破检测机制 + +**问题:** 价格突破网格边界时无响应,继续执行导致单边亏损 + +**Files:** +- Modify: `trader/auto_trader_grid.go` (添加 `checkBreakout` 函数) +- Modify: `trader/auto_trader_grid.go:184-224` (在 `RunGridCycle` 中调用) + +### Step 5.1: 添加突破检测函数 + +在 `trader/auto_trader_grid.go` 中添加: + +```go +// BreakoutType represents the type of price breakout +type BreakoutType string + +const ( + BreakoutNone BreakoutType = "none" + BreakoutUpper BreakoutType = "upper" + BreakoutLower BreakoutType = "lower" +) + +// checkBreakout detects if price has broken out of grid range +// Returns breakout type and percentage beyond boundary +func (at *AutoTrader) checkBreakout() (BreakoutType, float64) { + gridConfig := at.config.StrategyConfig.GridConfig + + currentPrice, err := at.trader.GetMarketPrice(gridConfig.Symbol) + if err != nil { + return BreakoutNone, 0 + } + + at.gridState.mu.RLock() + upper := at.gridState.UpperPrice + lower := at.gridState.LowerPrice + at.gridState.mu.RUnlock() + + if upper <= 0 || lower <= 0 { + return BreakoutNone, 0 + } + + // Check upper breakout + if currentPrice > upper { + breakoutPct := (currentPrice - upper) / upper * 100 + return BreakoutUpper, breakoutPct + } + + // Check lower breakout + if currentPrice < lower { + breakoutPct := (lower - currentPrice) / lower * 100 + return BreakoutLower, breakoutPct + } + + return BreakoutNone, 0 +} + +// handleBreakout handles price breakout from grid range +func (at *AutoTrader) handleBreakout(breakoutType BreakoutType, breakoutPct float64) error { + gridConfig := at.config.StrategyConfig.GridConfig + + logger.Warnf("[Grid] BREAKOUT DETECTED: %s, %.2f%% beyond boundary", breakoutType, breakoutPct) + + // If breakout exceeds 2%, pause grid and cancel orders + if breakoutPct >= 2.0 { + logger.Warnf("[Grid] Significant breakout (%.2f%%), pausing grid and canceling orders", breakoutPct) + + // Cancel all pending orders to prevent further losses + if err := at.cancelAllGridOrders(); err != nil { + logger.Errorf("[Grid] Failed to cancel orders on breakout: %v", err) + } + + // Pause grid trading + at.gridState.mu.Lock() + at.gridState.IsPaused = true + at.gridState.mu.Unlock() + + return fmt.Errorf("grid paused due to %s breakout (%.2f%%)", breakoutType, breakoutPct) + } + + // If breakout is minor (< 2%), consider adjusting grid + if breakoutPct >= 1.0 { + logger.Infof("[Grid] Minor breakout (%.2f%%), considering grid adjustment", breakoutPct) + // Let AI decide whether to adjust + } + + return nil +} +``` + +### Step 5.2: 在 RunGridCycle 中添加突破检测 + +修改 `trader/auto_trader_grid.go` 的 `RunGridCycle` 函数: + +```go +func (at *AutoTrader) RunGridCycle() error { + if at.gridState == nil || !at.gridState.IsInitialized { + if err := at.InitializeGrid(); err != nil { + return fmt.Errorf("failed to initialize grid: %w", err) + } + } + + // CRITICAL: Check for breakout before executing any trades + breakoutType, breakoutPct := at.checkBreakout() + if breakoutType != BreakoutNone { + if err := at.handleBreakout(breakoutType, breakoutPct); err != nil { + return err // Grid paused due to breakout + } + } + + // Check if grid is paused + at.gridState.mu.RLock() + isPaused := at.gridState.IsPaused + at.gridState.mu.RUnlock() + if isPaused { + logger.Infof("[Grid] Grid is paused, skipping cycle") + return nil + } + + gridConfig := at.config.StrategyConfig.GridConfig + // ... 其余现有代码 ... +} +``` + +### Step 5.3: 运行测试验证 + +```bash +go build ./trader/ +``` + +### Step 5.4: 提交 + +```bash +git add trader/auto_trader_grid.go +git commit -m "feat(grid): add breakout detection and auto-pause + +CRITICAL: Detect price breakout from grid range +- New checkBreakout() function +- Auto-pause grid on significant breakout (>2%) +- Cancel all orders when breakout detected +- Prevent continued losses in trending market" +``` + +--- + +## Task 6: 添加 MaxDrawdown 强制执行 + +**问题:** `MaxDrawdownPct` 存在于配置但从未检查 + +**Files:** +- Modify: `trader/auto_trader_grid.go` (添加 `checkMaxDrawdown` 函数) +- Modify: `trader/auto_trader_grid.go:184-224` (在 `RunGridCycle` 中调用) + +### Step 6.1: 添加最大回撤检查函数 + +在 `trader/auto_trader_grid.go` 中添加: + +```go +// checkMaxDrawdown checks if current drawdown exceeds maximum allowed +// Returns: (exceeded bool, currentDrawdown float64) +func (at *AutoTrader) checkMaxDrawdown() (bool, float64) { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig.MaxDrawdownPct <= 0 { + return false, 0 + } + + // Get current equity + balance, err := at.trader.GetBalance() + if err != nil { + return false, 0 + } + + currentEquity := 0.0 + if equity, ok := balance["total_equity"].(float64); ok { + currentEquity = equity + } else if total, ok := balance["totalWalletBalance"].(float64); ok { + if unrealized, ok := balance["totalUnrealizedProfit"].(float64); ok { + currentEquity = total + unrealized + } + } + + if currentEquity <= 0 { + return false, 0 + } + + // Update peak equity + at.gridState.mu.Lock() + if currentEquity > at.gridState.PeakEquity { + at.gridState.PeakEquity = currentEquity + } + peakEquity := at.gridState.PeakEquity + at.gridState.mu.Unlock() + + if peakEquity <= 0 { + return false, 0 + } + + // Calculate current drawdown + drawdown := (peakEquity - currentEquity) / peakEquity * 100 + + // Update max drawdown tracking + at.gridState.mu.Lock() + if drawdown > at.gridState.MaxDrawdown { + at.gridState.MaxDrawdown = drawdown + } + at.gridState.mu.Unlock() + + return drawdown >= gridConfig.MaxDrawdownPct, drawdown +} + +// emergencyExit closes all positions and cancels all orders +func (at *AutoTrader) emergencyExit(reason string) error { + gridConfig := at.config.StrategyConfig.GridConfig + + logger.Errorf("[Grid] EMERGENCY EXIT: %s", reason) + + // Cancel all orders + if err := at.cancelAllGridOrders(); err != nil { + logger.Errorf("[Grid] Failed to cancel orders in emergency: %v", err) + } + + // Close all positions + positions, err := at.trader.GetPositions() + if err == nil { + for _, pos := range positions { + if sym, ok := pos["symbol"].(string); ok && sym == gridConfig.Symbol { + if size, ok := pos["positionAmt"].(float64); ok && size != 0 { + if size > 0 { + at.trader.CloseLong(gridConfig.Symbol, size) + } else { + at.trader.CloseShort(gridConfig.Symbol, -size) + } + } + } + } + } + + // Pause grid + at.gridState.mu.Lock() + at.gridState.IsPaused = true + at.gridState.mu.Unlock() + + return nil +} +``` + +### Step 6.2: 在 RunGridCycle 中添加回撤检查 + +修改 `trader/auto_trader_grid.go` 的 `RunGridCycle` 函数,在突破检测后添加: + +```go +func (at *AutoTrader) RunGridCycle() error { + // ... 初始化检查 ... + + // CRITICAL: Check for breakout + // ... 突破检测代码 ... + + // CRITICAL: Check max drawdown + exceeded, drawdown := at.checkMaxDrawdown() + if exceeded { + return at.emergencyExit(fmt.Sprintf("max drawdown exceeded: %.2f%%", drawdown)) + } + + // ... 其余代码 ... +} +``` + +### Step 6.3: 运行测试验证 + +```bash +go build ./trader/ +``` + +### Step 6.4: 提交 + +```bash +git add trader/auto_trader_grid.go +git commit -m "feat(grid): enforce max drawdown limit with emergency exit + +CRITICAL: Add drawdown protection +- New checkMaxDrawdown() function tracks peak equity +- emergencyExit() closes all positions and cancels orders +- Auto-pause grid when MaxDrawdownPct exceeded +- Protect capital from excessive losses" +``` + +--- + +## Task 7: 添加 DailyLossLimit 强制执行 + +**问题:** `DailyLossLimitPct` 存在于配置但从未检查 + +**Files:** +- Modify: `trader/auto_trader_grid.go` (添加 `checkDailyLossLimit` 函数) +- Modify: `trader/auto_trader_grid.go:184-224` (在 `RunGridCycle` 中调用) + +### Step 7.1: 添加日损失限制检查函数 + +在 `trader/auto_trader_grid.go` 中添加: + +```go +// checkDailyLossLimit checks if daily loss exceeds limit +// Returns: (exceeded bool, dailyLossPct float64) +func (at *AutoTrader) checkDailyLossLimit() (bool, float64) { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig.DailyLossLimitPct <= 0 { + return false, 0 + } + + at.gridState.mu.Lock() + // Reset daily PnL if new day + now := time.Now() + if now.YearDay() != at.gridState.LastDailyReset.YearDay() || + now.Year() != at.gridState.LastDailyReset.Year() { + at.gridState.DailyPnL = 0 + at.gridState.LastDailyReset = now + } + dailyPnL := at.gridState.DailyPnL + at.gridState.mu.Unlock() + + // Calculate daily loss as percentage of total investment + dailyLossPct := 0.0 + if gridConfig.TotalInvestment > 0 && dailyPnL < 0 { + dailyLossPct = (-dailyPnL) / gridConfig.TotalInvestment * 100 + } + + return dailyLossPct >= gridConfig.DailyLossLimitPct, dailyLossPct +} + +// updateDailyPnL updates the daily PnL tracking +func (at *AutoTrader) updateDailyPnL(realizedPnL float64) { + at.gridState.mu.Lock() + at.gridState.DailyPnL += realizedPnL + at.gridState.TotalProfit += realizedPnL + at.gridState.mu.Unlock() +} +``` + +### Step 7.2: 在 RunGridCycle 中添加日损失检查 + +修改 `trader/auto_trader_grid.go` 的 `RunGridCycle` 函数: + +```go +func (at *AutoTrader) RunGridCycle() error { + // ... 初始化和突破检测 ... + + // CRITICAL: Check max drawdown + // ... + + // CRITICAL: Check daily loss limit + exceeded, dailyLossPct := at.checkDailyLossLimit() + if exceeded { + logger.Errorf("[Grid] Daily loss limit exceeded: %.2f%%", dailyLossPct) + at.gridState.mu.Lock() + at.gridState.IsPaused = true + at.gridState.mu.Unlock() + return fmt.Errorf("daily loss limit exceeded: %.2f%%", dailyLossPct) + } + + // ... 其余代码 ... +} +``` + +### Step 7.3: 运行测试验证 + +```bash +go build ./trader/ +``` + +### Step 7.4: 提交 + +```bash +git add trader/auto_trader_grid.go +git commit -m "feat(grid): enforce daily loss limit + +- New checkDailyLossLimit() function +- Track daily PnL with auto-reset at midnight +- Pause grid when DailyLossLimitPct exceeded +- Prevent excessive single-day losses" +``` + +--- + +## Task 8: 添加自动网格调整 + +**问题:** 网格无法自动适应价格偏移 + +**Files:** +- Modify: `trader/auto_trader_grid.go` (添加 `checkGridSkew` 函数) +- Modify: `trader/auto_trader_grid.go:504-565` (在 `syncGridState` 中调用) + +### Step 8.1: 添加网格倾斜检测函数 + +在 `trader/auto_trader_grid.go` 中添加: + +```go +// checkGridSkew checks if grid is heavily skewed (too many fills on one side) +// Returns: (skewed bool, buyFilledCount int, sellFilledCount int) +func (at *AutoTrader) checkGridSkew() (bool, int, int) { + at.gridState.mu.RLock() + defer at.gridState.mu.RUnlock() + + buyFilled := 0 + sellFilled := 0 + buyEmpty := 0 + sellEmpty := 0 + + for _, level := range at.gridState.Levels { + if level.Side == "buy" { + if level.State == "filled" { + buyFilled++ + } else if level.State == "empty" { + buyEmpty++ + } + } else { + if level.State == "filled" { + sellFilled++ + } else if level.State == "empty" { + sellEmpty++ + } + } + } + + // Grid is skewed if one side has 3x more fills than the other + // or if one side is completely empty + skewed := false + if buyFilled > 0 && sellFilled == 0 && sellEmpty > 5 { + skewed = true // All buys filled, no sells + } else if sellFilled > 0 && buyFilled == 0 && buyEmpty > 5 { + skewed = true // All sells filled, no buys + } else if buyFilled >= 3*sellFilled && buyFilled > 5 { + skewed = true + } else if sellFilled >= 3*buyFilled && sellFilled > 5 { + skewed = true + } + + return skewed, buyFilled, sellFilled +} + +// autoAdjustGrid automatically adjusts grid when heavily skewed +func (at *AutoTrader) autoAdjustGrid() { + skewed, buyFilled, sellFilled := at.checkGridSkew() + if !skewed { + return + } + + logger.Warnf("[Grid] Grid heavily skewed: buy_filled=%d, sell_filled=%d. Auto-adjusting...", + buyFilled, sellFilled) + + gridConfig := at.config.StrategyConfig.GridConfig + + // Get current price + currentPrice, err := at.trader.GetMarketPrice(gridConfig.Symbol) + if err != nil { + logger.Errorf("[Grid] Failed to get price for auto-adjust: %v", err) + return + } + + // Check if price is near grid boundary + at.gridState.mu.RLock() + upper := at.gridState.UpperPrice + lower := at.gridState.LowerPrice + at.gridState.mu.RUnlock() + + // Only adjust if price has moved significantly (>50% of grid range) + gridRange := upper - lower + midPrice := (upper + lower) / 2 + priceDeviation := math.Abs(currentPrice - midPrice) + + if priceDeviation < gridRange*0.3 { + return // Price still near center, don't adjust + } + + // Cancel existing orders and reinitialize + logger.Infof("[Grid] Adjusting grid around new price $%.2f", currentPrice) + at.cancelAllGridOrders() + at.initializeGridLevels(currentPrice, gridConfig) +} +``` + +### Step 8.2: 在 syncGridState 中调用自动调整 + +修改 `trader/auto_trader_grid.go` 的 `syncGridState` 函数: + +```go +func (at *AutoTrader) syncGridState() { + // ... 现有代码 ... + + // Check stop loss + at.checkAndExecuteStopLoss() + + // Check grid skew and auto-adjust if needed + at.autoAdjustGrid() +} +``` + +### Step 8.3: 运行测试验证 + +```bash +go build ./trader/ +``` + +### Step 8.4: 提交 + +```bash +git add trader/auto_trader_grid.go +git commit -m "feat(grid): add automatic grid adjustment + +- New checkGridSkew() detects imbalanced grid +- autoAdjustGrid() reinitializes around current price +- Prevents grid from becoming ineffective after drift +- Triggers when one side is 3x more filled than other" +``` + +--- + +## Task 9: 修复订单状态同步逻辑 + +**问题:** 假设订单不存在就是成交,但可能是被取消 + +**Files:** +- Modify: `trader/auto_trader_grid.go:504-565` + +### Step 9.1: 改进订单状态同步逻辑 + +修改 `trader/auto_trader_grid.go` 的 `syncGridState` 函数: + +```go +// syncGridState syncs grid state with exchange +func (at *AutoTrader) syncGridState() { + gridConfig := at.config.StrategyConfig.GridConfig + + // Get open orders from exchange + openOrders, err := at.trader.GetOpenOrders(gridConfig.Symbol) + if err != nil { + logger.Warnf("[Grid] Failed to get open orders: %v", err) + return + } + + // Build set of active order IDs + activeOrderIDs := make(map[string]bool) + for _, order := range openOrders { + activeOrderIDs[order.OrderID] = true + } + + // Get current positions to verify fills + positions, err := at.trader.GetPositions() + currentPositionSize := 0.0 + if err == nil { + for _, pos := range positions { + if sym, ok := pos["symbol"].(string); ok && sym == gridConfig.Symbol { + if size, ok := pos["positionAmt"].(float64); ok { + currentPositionSize = size + } + } + } + } + + // Update levels based on order status + at.gridState.mu.Lock() + previousFilledCount := 0 + for _, level := range at.gridState.Levels { + if level.State == "filled" { + previousFilledCount++ + } + } + + for i := range at.gridState.Levels { + level := &at.gridState.Levels[i] + if level.State == "pending" && level.OrderID != "" { + if !activeOrderIDs[level.OrderID] { + // Order no longer exists - check if position changed to determine fill vs cancel + // This is a heuristic - ideally we'd query order history + if math.Abs(currentPositionSize) > math.Abs(float64(previousFilledCount)*level.OrderQuantity) { + // Position increased, likely filled + level.State = "filled" + level.PositionEntry = level.Price + level.PositionSize = level.OrderQuantity + at.gridState.TotalTrades++ + logger.Infof("[Grid] Level %d order filled at $%.2f", i, level.Price) + } else { + // Position didn't increase as expected, likely cancelled + level.State = "empty" + level.OrderID = "" + level.OrderQuantity = 0 + logger.Infof("[Grid] Level %d order cancelled/expired", i) + } + delete(at.gridState.OrderBook, level.OrderID) + } + } + } + at.gridState.mu.Unlock() + + logger.Debugf("[Grid] Synced state: position=%.4f, orders=%d", currentPositionSize, len(openOrders)) + + // Check stop loss + at.checkAndExecuteStopLoss() + + // Check grid skew + at.autoAdjustGrid() +} +``` + +### Step 9.2: 运行测试验证 + +```bash +go build ./trader/ +``` + +### Step 9.3: 提交 + +```bash +git add trader/auto_trader_grid.go +git commit -m "fix(grid): improve order state sync logic + +- Don't assume missing orders are filled +- Compare position size to determine fill vs cancel +- Properly reset cancelled orders to empty state +- More accurate grid state tracking" +``` + +--- + +## 完成后的验证步骤 + +### 全面测试 + +```bash +# 编译验证 +go build ./... + +# 运行所有trader测试 +go test -v ./trader/... -timeout 300s + +# 运行网格相关测试 +go test -v -run "Grid" ./trader/ -timeout 60s +``` + +### 代码审查清单 + +- [ ] 所有P0致命问题已修复 +- [ ] 所有P1严重问题已修复 +- [ ] 杠杆在初始化时设置 +- [ ] 订单取消逻辑正确 +- [ ] 总仓位有限制 +- [ ] 止损被执行 +- [ ] 突破时自动暂停 +- [ ] MaxDrawdown触发紧急退出 +- [ ] DailyLossLimit暂停交易 +- [ ] 网格自动调整 + +--- + +## 架构改进总结 + +``` +修复后的架构: + +┌─────────────┐ ┌─────────────┐ ┌─────────────────────────┐ ┌─────────────┐ +│ 市场数据 │ ──▶ │ AI决策 │ ──▶ │ 代码级风控验证 │ ──▶ │ 执行交易 │ +└─────────────┘ └─────────────┘ └─────────────────────────┘ └─────────────┘ + │ + ▼ + ┌────────────────────────────────────────────────────┐ + │ 风控检查清单 (每个周期执行) │ + │ ✓ checkBreakout() - 突破检测 │ + │ ✓ checkMaxDrawdown() - 最大回撤 │ + │ ✓ checkDailyLossLimit() - 日损失限制 │ + │ ✓ checkTotalPositionLimit() - 总仓位限制 │ + │ ✓ checkAndExecuteStopLoss() - 止损执行 │ + │ ✓ checkGridSkew() - 网格平衡 │ + │ ✓ SetLeverage() - 杠杆设置 │ + └────────────────────────────────────────────────────┘ +``` diff --git a/docs/plans/2026-01-17-grid-market-regime-design.md b/docs/plans/2026-01-17-grid-market-regime-design.md new file mode 100644 index 00000000..bd9ba7c1 --- /dev/null +++ b/docs/plans/2026-01-17-grid-market-regime-design.md @@ -0,0 +1,151 @@ +# 网格策略市场状态识别与风控设计 + +## 概述 + +增强网格策略的市场状态识别能力,实现震荡/趋势的精准判断,并根据不同震荡级别自动调整网格参数和风控策略。 + +--- + +## 一、市场状态识别 + +### 1.1 识别维度(3个) + +| 维度 | 指标 | 作用 | +|------|------|------| +| 价格波动 | ATR14 + Bollinger带宽 | 判断震荡幅度 | +| 趋势强度 | EMA20/50距离 + MACD | 判断是否有趋势 | +| 动量 | RSI14 + 1h/4h涨跌幅 | 判断超买超卖 | + +### 1.2 箱体指标(新增) + +基于1小时K线的多周期Donchian通道: + +| 箱体级别 | 周期 | 覆盖时间 | 用途 | +|----------|------|----------|------| +| 短期箱体 | 72根1小时 | 3天 | 日内波动边界 | +| 中期箱体 | 240根1小时 | 10天 | 周级别震荡区间 | +| 长期箱体 | 500根1小时 | ~21天 | 大级别趋势边界 | + +### 1.3 判断方式 + +由AI综合分析以上指标 + 原始K线序列 + 箱体位置,输出市场状态判断。 + +--- + +## 二、震荡分级与网格策略 + +### 2.1 四级震荡分类 + +| 级别 | 特征 | 判断依据 | +|------|------|----------| +| 窄幅震荡 | 价格在短期箱体内小幅波动 | Bollinger带宽 < 2%,ATR低 | +| 标准震荡 | 价格在中期箱体内正常波动 | Bollinger带宽 2-3%,ATR正常 | +| 宽幅震荡 | 价格接近中期箱体边缘 | Bollinger带宽 3-4%,ATR较高 | +| 剧烈震荡 | 价格接近长期箱体边缘 | Bollinger带宽 > 4%,ATR高 | + +### 2.2 各级别对应的网格策略 + +| 级别 | 网格密度 | 网格范围 | 单格仓位 | 总仓位上限 | 有效杠杆上限 | +|------|----------|----------|----------|------------|--------------| +| 窄幅震荡 | 密集 | 窄 | 小 | 30-40% | 2x | +| 标准震荡 | 正常 | 中等 | 正常 | 60-70% | 3-4x | +| 宽幅震荡 | 稀疏 | 宽 | 正常 | 50-60% | 3x | +| 剧烈震荡 | 最稀疏 | 最宽 | 小 | 30-40% | 2x | + +**核心原则:** +- 窄幅震荡:单格仓位小 + 总仓位上限低(防击穿风险) +- 剧烈震荡:同样保守(随时可能变趋势) +- 标准震荡:才是放量的最佳时机 + +--- + +## 三、突破处理与恢复机制 + +### 3.1 突破判断与处理 + +**确认方式:** 收盘价突破箱体后,持续3根1小时K线不回箱体 + +| 箱体级别 | 突破处理 | +|----------|----------| +| 短期箱体突破 | 降低仓位到 50% | +| 中期箱体突破 | 暂停网格 + 取消挂单 | +| 长期箱体突破 | 暂停网格 + 取消挂单 + 平掉所有持仓 | + +### 3.2 假突破恢复 + +**价格回到箱体内 → 以50%仓位恢复网格** + +--- + +## 四、前端风控面板 + +### 4.1 需要展示的信息 + +| 类别 | 显示内容 | +|------|----------| +| 杠杆信息 | 当前杠杆、有效杠杆、系统推荐杠杆 | +| 仓位信息 | 当前仓位、最大仓位、仓位占比 | +| 爆仓信息 | 爆仓价格、爆仓距离(%) | +| 市场状态 | 当前震荡级别(窄幅/标准/宽幅/剧烈) | +| 箱体状态 | 短期/中期/长期箱体上下沿、当前价格位置 | + +--- + +## 五、实现要点 + +### 5.1 后端新增 + +1. **箱体指标计算** (`market/data.go`) + - 新增 `calculateDonchian(klines, period)` 函数 + - 返回 upper(最高价), lower(最低价) + - 支持72/240/500三个周期 + +2. **市场状态评估** (`kernel/grid_engine.go`) + - 更新AI prompt,加入箱体指标和K线序列 + - AI输出震荡级别判断 + +3. **网格参数动态调整** (`trader/auto_trader_grid.go`) + - 根据震荡级别自动调整:网格密度、范围、仓位、杠杆 + - 实现有效杠杆上限控制 + +4. **突破处理逻辑** (`trader/auto_trader_grid.go`) + - 实现三级箱体突破检测 + - 实现3根K线确认逻辑 + - 实现降级恢复机制 + +### 5.2 前端新增 + +1. **风控面板组件** + - 杠杆信息展示 + - 仓位信息展示 + - 爆仓信息展示 + - 市场状态展示 + - 箱体状态可视化 + +### 5.3 数据模型更新 + +1. **GridConfigModel** 新增字段: + - `EffectiveLeverageLimit` - 有效杠杆上限 + - `ShortBoxPeriod` - 短期箱体周期 (默认72) + - `MidBoxPeriod` - 中期箱体周期 (默认240) + - `LongBoxPeriod` - 长期箱体周期 (默认500) + +2. **GridInstanceModel** 新增字段: + - `CurrentRegimeLevel` - 当前震荡级别 (narrow/standard/wide/volatile) + - `ShortBoxUpper/Lower` - 短期箱体上下沿 + - `MidBoxUpper/Lower` - 中期箱体上下沿 + - `LongBoxUpper/Lower` - 长期箱体上下沿 + - `BreakoutStatus` - 突破状态 (none/short/mid/long) + - `BreakoutConfirmCount` - 突破确认K线计数 + +--- + +## 六、风险控制总结 + +| 控制点 | 机制 | +|--------|------| +| 仓位控制 | 根据震荡级别限制总仓位上限 (30-70%) | +| 杠杆控制 | 根据震荡级别限制有效杠杆 (2-4x) | +| 突破保护 | 三级箱体突破分级处理 | +| 假突破恢复 | 50%仓位降级恢复 | +| 爆仓预防 | 前端展示爆仓距离,系统自动限制杠杆 | diff --git a/docs/plans/2026-01-17-grid-market-regime-impl.md b/docs/plans/2026-01-17-grid-market-regime-impl.md new file mode 100644 index 00000000..a28d2b44 --- /dev/null +++ b/docs/plans/2026-01-17-grid-market-regime-impl.md @@ -0,0 +1,1655 @@ +# Grid Market Regime Detection Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Implement multi-period box indicators and 4-level ranging classification for grid trading with automatic parameter adjustment and breakout handling. + +**Architecture:** Add Donchian channel calculation to market package, extend grid models with box/regime fields, implement breakout detection in auto_trader_grid, add risk control panel to frontend. + +**Tech Stack:** Go (backend), React/TypeScript (frontend), GORM (database), 1-hour Kline data + +--- + +## Task 1: Add Donchian Channel Calculation + +**Files:** +- Modify: `market/data.go` +- Test: `market/data_test.go` + +**Step 1: Write the failing test** + +Add to `market/data_test.go`: + +```go +func TestCalculateDonchian(t *testing.T) { + // Create test klines with known high/low values + klines := []Kline{ + {High: 100, Low: 90}, + {High: 105, Low: 88}, + {High: 102, Low: 92}, + {High: 108, Low: 85}, + {High: 103, Low: 91}, + } + + upper, lower := calculateDonchian(klines, 5) + + if upper != 108 { + t.Errorf("Expected upper = 108, got %v", upper) + } + if lower != 85 { + t.Errorf("Expected lower = 85, got %v", lower) + } +} + +func TestCalculateDonchian_PartialPeriod(t *testing.T) { + klines := []Kline{ + {High: 100, Low: 90}, + {High: 105, Low: 88}, + } + + upper, lower := calculateDonchian(klines, 10) + + // Should use all available klines when period > len(klines) + if upper != 105 { + t.Errorf("Expected upper = 105, got %v", upper) + } + if lower != 88 { + t.Errorf("Expected lower = 88, got %v", lower) + } +} +``` + +**Step 2: Run test to verify it fails** + +Run: `cd /Users/yida/gopro/open-nofx && go test -v ./market/... -run TestCalculateDonchian` +Expected: FAIL with "undefined: calculateDonchian" + +**Step 3: Write minimal implementation** + +Add to `market/data.go`: + +```go +// calculateDonchian calculates Donchian channel (highest high, lowest low) for given period +func calculateDonchian(klines []Kline, period int) (upper, lower float64) { + if len(klines) == 0 { + return 0, 0 + } + + // Use all available klines if period > len(klines) + start := len(klines) - period + if start < 0 { + start = 0 + } + + upper = klines[start].High + lower = klines[start].Low + + for i := start + 1; i < len(klines); i++ { + if klines[i].High > upper { + upper = klines[i].High + } + if klines[i].Low < lower { + lower = klines[i].Low + } + } + + return upper, lower +} + +// ExportCalculateDonchian exports calculateDonchian for testing +func ExportCalculateDonchian(klines []Kline, period int) (float64, float64) { + return calculateDonchian(klines, period) +} +``` + +**Step 4: Run test to verify it passes** + +Run: `cd /Users/yida/gopro/open-nofx && go test -v ./market/... -run TestCalculateDonchian` +Expected: PASS + +**Step 5: Commit** + +```bash +git add market/data.go market/data_test.go +git commit -m "feat(market): add Donchian channel calculation" +``` + +--- + +## Task 2: Add Box Data Types + +**Files:** +- Modify: `market/types.go` + +**Step 1: Add BoxData struct** + +Add to `market/types.go`: + +```go +// BoxData represents multi-period Donchian channel (box) data +type BoxData struct { + // Short-term box (72 1h candles = 3 days) + ShortUpper float64 `json:"short_upper"` + ShortLower float64 `json:"short_lower"` + + // Mid-term box (240 1h candles = 10 days) + MidUpper float64 `json:"mid_upper"` + MidLower float64 `json:"mid_lower"` + + // Long-term box (500 1h candles = ~21 days) + LongUpper float64 `json:"long_upper"` + LongLower float64 `json:"long_lower"` + + // Current price position relative to boxes + CurrentPrice float64 `json:"current_price"` +} + +// RegimeLevel represents the ranging classification level +type RegimeLevel string + +const ( + RegimeLevelNarrow RegimeLevel = "narrow" // 窄幅震荡 + RegimeLevelStandard RegimeLevel = "standard" // 标准震荡 + RegimeLevelWide RegimeLevel = "wide" // 宽幅震荡 + RegimeLevelVolatile RegimeLevel = "volatile" // 剧烈震荡 + RegimeLevelTrending RegimeLevel = "trending" // 趋势 +) + +// BreakoutLevel represents which box level has been broken +type BreakoutLevel string + +const ( + BreakoutNone BreakoutLevel = "none" + BreakoutShort BreakoutLevel = "short" + BreakoutMid BreakoutLevel = "mid" + BreakoutLong BreakoutLevel = "long" +) +``` + +**Step 2: Commit** + +```bash +git add market/types.go +git commit -m "feat(market): add BoxData and RegimeLevel types" +``` + +--- + +## Task 3: Add GetBoxData Function + +**Files:** +- Modify: `market/data.go` +- Test: `market/data_test.go` + +**Step 1: Write the failing test** + +Add to `market/data_test.go`: + +```go +func TestGetBoxData(t *testing.T) { + // This test requires mocking kline data source + // For now, test the internal calculation logic + klines := make([]Kline, 500) + for i := 0; i < 500; i++ { + // Create synthetic price data + basePrice := 100.0 + klines[i] = Kline{ + High: basePrice + float64(i%10), + Low: basePrice - float64(i%10), + } + } + + box := calculateBoxData(klines, 100.0) + + if box.ShortUpper == 0 || box.ShortLower == 0 { + t.Error("Short box should not be zero") + } + if box.MidUpper == 0 || box.MidLower == 0 { + t.Error("Mid box should not be zero") + } + if box.LongUpper == 0 || box.LongLower == 0 { + t.Error("Long box should not be zero") + } + if box.CurrentPrice != 100.0 { + t.Errorf("Expected CurrentPrice = 100.0, got %v", box.CurrentPrice) + } +} +``` + +**Step 2: Run test to verify it fails** + +Run: `cd /Users/yida/gopro/open-nofx && go test -v ./market/... -run TestGetBoxData` +Expected: FAIL with "undefined: calculateBoxData" + +**Step 3: Write minimal implementation** + +Add to `market/data.go`: + +```go +const ( + ShortBoxPeriod = 72 // 3 days of 1h candles + MidBoxPeriod = 240 // 10 days of 1h candles + LongBoxPeriod = 500 // ~21 days of 1h candles +) + +// calculateBoxData calculates multi-period box data from klines +func calculateBoxData(klines []Kline, currentPrice float64) *BoxData { + box := &BoxData{ + CurrentPrice: currentPrice, + } + + if len(klines) == 0 { + return box + } + + box.ShortUpper, box.ShortLower = calculateDonchian(klines, ShortBoxPeriod) + box.MidUpper, box.MidLower = calculateDonchian(klines, MidBoxPeriod) + box.LongUpper, box.LongLower = calculateDonchian(klines, LongBoxPeriod) + + return box +} + +// GetBoxData fetches 1h klines and calculates box data for a symbol +func GetBoxData(symbol string) (*BoxData, error) { + symbol = Normalize(symbol) + + // Fetch 500 1h klines + var klines []Kline + var err error + + if IsXyzDexAsset(symbol) { + klines, err = getKlinesFromHyperliquid(symbol, "1h", LongBoxPeriod) + } else { + klines, err = getKlinesFromCoinAnk(symbol, "1h", LongBoxPeriod) + } + + if err != nil { + return nil, fmt.Errorf("failed to get 1h klines: %w", err) + } + + if len(klines) == 0 { + return nil, fmt.Errorf("no kline data available") + } + + currentPrice := klines[len(klines)-1].Close + + return calculateBoxData(klines, currentPrice), nil +} +``` + +**Step 4: Run test to verify it passes** + +Run: `cd /Users/yida/gopro/open-nofx && go test -v ./market/... -run TestGetBoxData` +Expected: PASS + +**Step 5: Commit** + +```bash +git add market/data.go market/data_test.go +git commit -m "feat(market): add GetBoxData for multi-period box calculation" +``` + +--- + +## Task 4: Update GridConfigModel with Box Parameters + +**Files:** +- Modify: `store/grid.go` + +**Step 1: Add new fields to GridConfigModel** + +Add fields after `TrendResumeThreshold` in `store/grid.go`: + +```go + // Box indicator periods (1h candles) + ShortBoxPeriod int `json:"short_box_period" gorm:"default:72"` // 3 days + MidBoxPeriod int `json:"mid_box_period" gorm:"default:240"` // 10 days + LongBoxPeriod int `json:"long_box_period" gorm:"default:500"` // 21 days + + // Effective leverage limits by regime level + NarrowRegimeLeverage int `json:"narrow_regime_leverage" gorm:"default:2"` + StandardRegimeLeverage int `json:"standard_regime_leverage" gorm:"default:4"` + WideRegimeLeverage int `json:"wide_regime_leverage" gorm:"default:3"` + VolatileRegimeLeverage int `json:"volatile_regime_leverage" gorm:"default:2"` + + // Position limits by regime level (percentage of total investment) + NarrowRegimePositionPct float64 `json:"narrow_regime_position_pct" gorm:"default:40"` + StandardRegimePositionPct float64 `json:"standard_regime_position_pct" gorm:"default:70"` + WideRegimePositionPct float64 `json:"wide_regime_position_pct" gorm:"default:60"` + VolatileRegimePositionPct float64 `json:"volatile_regime_position_pct" gorm:"default:40"` +``` + +**Step 2: Commit** + +```bash +git add store/grid.go +git commit -m "feat(store): add box period and regime leverage fields to GridConfigModel" +``` + +--- + +## Task 5: Update GridInstanceModel with Box State + +**Files:** +- Modify: `store/grid.go` + +**Step 1: Add new fields to GridInstanceModel** + +Add fields after `ConsecutiveTrending` in `store/grid.go`: + +```go + // Current regime level (narrow/standard/wide/volatile/trending) + CurrentRegimeLevel string `json:"current_regime_level" gorm:"default:standard"` + + // Box state + ShortBoxUpper float64 `json:"short_box_upper"` + ShortBoxLower float64 `json:"short_box_lower"` + MidBoxUpper float64 `json:"mid_box_upper"` + MidBoxLower float64 `json:"mid_box_lower"` + LongBoxUpper float64 `json:"long_box_upper"` + LongBoxLower float64 `json:"long_box_lower"` + + // Breakout state + BreakoutLevel string `json:"breakout_level" gorm:"default:none"` // none/short/mid/long + BreakoutDirection string `json:"breakout_direction"` // up/down + BreakoutConfirmCount int `json:"breakout_confirm_count" gorm:"default:0"` + BreakoutStartTime time.Time `json:"breakout_start_time"` + + // Position adjustment due to breakout + PositionReductionPct float64 `json:"position_reduction_pct" gorm:"default:0"` // 0 = normal, 50 = reduced +``` + +**Step 2: Commit** + +```bash +git add store/grid.go +git commit -m "feat(store): add box state and breakout fields to GridInstanceModel" +``` + +--- + +## Task 6: Add Regime Level Classification + +**Files:** +- Create: `trader/grid_regime.go` +- Test: `trader/grid_regime_test.go` + +**Step 1: Write the failing test** + +Create `trader/grid_regime_test.go`: + +```go +package trader + +import ( + "nofx/market" + "testing" +) + +func TestClassifyRegimeLevel(t *testing.T) { + tests := []struct { + name string + bollingerWidth float64 + atr14Pct float64 + expected market.RegimeLevel + }{ + {"narrow", 1.5, 0.8, market.RegimeLevelNarrow}, + {"standard", 2.5, 1.5, market.RegimeLevelStandard}, + {"wide", 3.5, 2.5, market.RegimeLevelWide}, + {"volatile", 5.0, 4.0, market.RegimeLevelVolatile}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := classifyRegimeLevel(tt.bollingerWidth, tt.atr14Pct) + if result != tt.expected { + t.Errorf("Expected %v, got %v", tt.expected, result) + } + }) + } +} +``` + +**Step 2: Run test to verify it fails** + +Run: `cd /Users/yida/gopro/open-nofx && go test -v ./trader/... -run TestClassifyRegimeLevel` +Expected: FAIL with "undefined: classifyRegimeLevel" + +**Step 3: Write minimal implementation** + +Create `trader/grid_regime.go`: + +```go +package trader + +import "nofx/market" + +// classifyRegimeLevel determines the regime level based on market indicators +// bollingerWidth: Bollinger band width as percentage +// atr14Pct: ATR14 as percentage of current price +func classifyRegimeLevel(bollingerWidth, atr14Pct float64) market.RegimeLevel { + // Narrow: Bollinger < 2%, ATR < 1% + if bollingerWidth < 2.0 && atr14Pct < 1.0 { + return market.RegimeLevelNarrow + } + + // Standard: Bollinger 2-3%, ATR 1-2% + if bollingerWidth <= 3.0 && atr14Pct <= 2.0 { + return market.RegimeLevelStandard + } + + // Wide: Bollinger 3-4%, ATR 2-3% + if bollingerWidth <= 4.0 && atr14Pct <= 3.0 { + return market.RegimeLevelWide + } + + // Volatile: Bollinger > 4%, ATR > 3% + return market.RegimeLevelVolatile +} + +// getRegimeLeverageLimit returns the effective leverage limit for a regime level +func getRegimeLeverageLimit(level market.RegimeLevel, config *store.GridStrategyConfig) int { + switch level { + case market.RegimeLevelNarrow: + return config.NarrowRegimeLeverage + case market.RegimeLevelStandard: + return config.StandardRegimeLeverage + case market.RegimeLevelWide: + return config.WideRegimeLeverage + case market.RegimeLevelVolatile: + return config.VolatileRegimeLeverage + default: + return 2 // Conservative default + } +} + +// getRegimePositionLimit returns the position limit percentage for a regime level +func getRegimePositionLimit(level market.RegimeLevel, config *store.GridStrategyConfig) float64 { + switch level { + case market.RegimeLevelNarrow: + return config.NarrowRegimePositionPct + case market.RegimeLevelStandard: + return config.StandardRegimePositionPct + case market.RegimeLevelWide: + return config.WideRegimePositionPct + case market.RegimeLevelVolatile: + return config.VolatileRegimePositionPct + default: + return 40.0 // Conservative default + } +} +``` + +**Step 4: Run test to verify it passes** + +Run: `cd /Users/yida/gopro/open-nofx && go test -v ./trader/... -run TestClassifyRegimeLevel` +Expected: PASS + +**Step 5: Commit** + +```bash +git add trader/grid_regime.go trader/grid_regime_test.go +git commit -m "feat(trader): add regime level classification" +``` + +--- + +## Task 7: Add Breakout Detection + +**Files:** +- Modify: `trader/grid_regime.go` +- Test: `trader/grid_regime_test.go` + +**Step 1: Write the failing test** + +Add to `trader/grid_regime_test.go`: + +```go +func TestDetectBoxBreakout(t *testing.T) { + box := &market.BoxData{ + ShortUpper: 100, + ShortLower: 90, + MidUpper: 105, + MidLower: 85, + LongUpper: 110, + LongLower: 80, + CurrentPrice: 95, + } + + // No breakout + level, direction := detectBoxBreakout(box) + if level != market.BreakoutNone { + t.Errorf("Expected no breakout, got %v", level) + } + + // Short breakout up + box.CurrentPrice = 101 + level, direction = detectBoxBreakout(box) + if level != market.BreakoutShort || direction != "up" { + t.Errorf("Expected short breakout up, got %v %v", level, direction) + } + + // Mid breakout down + box.CurrentPrice = 84 + level, direction = detectBoxBreakout(box) + if level != market.BreakoutMid || direction != "down" { + t.Errorf("Expected mid breakout down, got %v %v", level, direction) + } + + // Long breakout up + box.CurrentPrice = 112 + level, direction = detectBoxBreakout(box) + if level != market.BreakoutLong || direction != "up" { + t.Errorf("Expected long breakout up, got %v %v", level, direction) + } +} +``` + +**Step 2: Run test to verify it fails** + +Run: `cd /Users/yida/gopro/open-nofx && go test -v ./trader/... -run TestDetectBoxBreakout` +Expected: FAIL with "undefined: detectBoxBreakout" + +**Step 3: Write minimal implementation** + +Add to `trader/grid_regime.go`: + +```go +// detectBoxBreakout checks if price has broken out of any box level +// Returns the highest breakout level and direction +func detectBoxBreakout(box *market.BoxData) (market.BreakoutLevel, string) { + price := box.CurrentPrice + + // Check long box first (highest priority) + if price > box.LongUpper { + return market.BreakoutLong, "up" + } + if price < box.LongLower { + return market.BreakoutLong, "down" + } + + // Check mid box + if price > box.MidUpper { + return market.BreakoutMid, "up" + } + if price < box.MidLower { + return market.BreakoutMid, "down" + } + + // Check short box + if price > box.ShortUpper { + return market.BreakoutShort, "up" + } + if price < box.ShortLower { + return market.BreakoutShort, "down" + } + + return market.BreakoutNone, "" +} +``` + +**Step 4: Run test to verify it passes** + +Run: `cd /Users/yida/gopro/open-nofx && go test -v ./trader/... -run TestDetectBoxBreakout` +Expected: PASS + +**Step 5: Commit** + +```bash +git add trader/grid_regime.go trader/grid_regime_test.go +git commit -m "feat(trader): add box breakout detection" +``` + +--- + +## Task 8: Add Breakout Confirmation Logic + +**Files:** +- Modify: `trader/grid_regime.go` +- Test: `trader/grid_regime_test.go` + +**Step 1: Write the failing test** + +Add to `trader/grid_regime_test.go`: + +```go +func TestBreakoutConfirmation(t *testing.T) { + state := &BreakoutState{ + Level: market.BreakoutShort, + Direction: "up", + ConfirmCount: 0, + } + + // First confirmation + confirmed := confirmBreakout(state, market.BreakoutShort, "up") + if confirmed || state.ConfirmCount != 1 { + t.Errorf("Expected not confirmed, count=1, got confirmed=%v count=%d", confirmed, state.ConfirmCount) + } + + // Second confirmation + confirmed = confirmBreakout(state, market.BreakoutShort, "up") + if confirmed || state.ConfirmCount != 2 { + t.Errorf("Expected not confirmed, count=2, got confirmed=%v count=%d", confirmed, state.ConfirmCount) + } + + // Third confirmation - should confirm + confirmed = confirmBreakout(state, market.BreakoutShort, "up") + if !confirmed || state.ConfirmCount != 3 { + t.Errorf("Expected confirmed, count=3, got confirmed=%v count=%d", confirmed, state.ConfirmCount) + } + + // Reset on price return + state.ConfirmCount = 2 + confirmed = confirmBreakout(state, market.BreakoutNone, "") + if state.ConfirmCount != 0 { + t.Errorf("Expected count reset to 0, got %d", state.ConfirmCount) + } +} +``` + +**Step 2: Run test to verify it fails** + +Run: `cd /Users/yida/gopro/open-nofx && go test -v ./trader/... -run TestBreakoutConfirmation` +Expected: FAIL with "undefined: BreakoutState" + +**Step 3: Write minimal implementation** + +Add to `trader/grid_regime.go`: + +```go +const BreakoutConfirmRequired = 3 // 3 candles to confirm breakout + +// BreakoutState tracks the current breakout state +type BreakoutState struct { + Level market.BreakoutLevel + Direction string + ConfirmCount int + StartTime time.Time +} + +// confirmBreakout updates breakout state and returns true if breakout is confirmed +func confirmBreakout(state *BreakoutState, currentLevel market.BreakoutLevel, direction string) bool { + // If price returned to box, reset state + if currentLevel == market.BreakoutNone { + state.ConfirmCount = 0 + state.Level = market.BreakoutNone + state.Direction = "" + return false + } + + // If same breakout continues, increment count + if state.Level == currentLevel && state.Direction == direction { + state.ConfirmCount++ + } else { + // New breakout, reset count + state.Level = currentLevel + state.Direction = direction + state.ConfirmCount = 1 + state.StartTime = time.Now() + } + + return state.ConfirmCount >= BreakoutConfirmRequired +} +``` + +**Step 4: Run test to verify it passes** + +Run: `cd /Users/yida/gopro/open-nofx && go test -v ./trader/... -run TestBreakoutConfirmation` +Expected: PASS + +**Step 5: Commit** + +```bash +git add trader/grid_regime.go trader/grid_regime_test.go +git commit -m "feat(trader): add breakout confirmation logic" +``` + +--- + +## Task 9: Add Breakout Handler + +**Files:** +- Modify: `trader/grid_regime.go` +- Test: `trader/grid_regime_test.go` + +**Step 1: Write the failing test** + +Add to `trader/grid_regime_test.go`: + +```go +func TestGetBreakoutAction(t *testing.T) { + tests := []struct { + level market.BreakoutLevel + expected BreakoutAction + }{ + {market.BreakoutNone, BreakoutActionNone}, + {market.BreakoutShort, BreakoutActionReducePosition}, + {market.BreakoutMid, BreakoutActionPauseGrid}, + {market.BreakoutLong, BreakoutActionCloseAll}, + } + + for _, tt := range tests { + t.Run(string(tt.level), func(t *testing.T) { + action := getBreakoutAction(tt.level) + if action != tt.expected { + t.Errorf("Expected %v, got %v", tt.expected, action) + } + }) + } +} +``` + +**Step 2: Run test to verify it fails** + +Run: `cd /Users/yida/gopro/open-nofx && go test -v ./trader/... -run TestGetBreakoutAction` +Expected: FAIL with "undefined: BreakoutAction" + +**Step 3: Write minimal implementation** + +Add to `trader/grid_regime.go`: + +```go +// BreakoutAction represents the action to take on breakout +type BreakoutAction int + +const ( + BreakoutActionNone BreakoutAction = iota + BreakoutActionReducePosition // Short box breakout: reduce to 50% + BreakoutActionPauseGrid // Mid box breakout: pause grid + cancel orders + BreakoutActionCloseAll // Long box breakout: pause + cancel + close all +) + +// getBreakoutAction returns the appropriate action for a breakout level +func getBreakoutAction(level market.BreakoutLevel) BreakoutAction { + switch level { + case market.BreakoutShort: + return BreakoutActionReducePosition + case market.BreakoutMid: + return BreakoutActionPauseGrid + case market.BreakoutLong: + return BreakoutActionCloseAll + default: + return BreakoutActionNone + } +} +``` + +**Step 4: Run test to verify it passes** + +Run: `cd /Users/yida/gopro/open-nofx && go test -v ./trader/... -run TestGetBreakoutAction` +Expected: PASS + +**Step 5: Commit** + +```bash +git add trader/grid_regime.go trader/grid_regime_test.go +git commit -m "feat(trader): add breakout action handler" +``` + +--- + +## Task 10: Integrate Breakout Detection into Grid Cycle + +**Files:** +- Modify: `trader/auto_trader_grid.go` + +**Step 1: Add checkBoxBreakout method** + +Add to `trader/auto_trader_grid.go` after `checkBreakout` function: + +```go +// checkBoxBreakout checks for multi-period box breakouts and takes appropriate action +func (at *AutoTrader) checkBoxBreakout() error { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig == nil { + return nil + } + + // Get box data + box, err := market.GetBoxData(gridConfig.Symbol) + if err != nil { + logger.Infof("Failed to get box data: %v", err) + return nil // Non-fatal, continue with other checks + } + + // Update instance with box values + at.gridState.mu.Lock() + // Store box values in grid state for reference + at.gridState.mu.Unlock() + + // Detect breakout + breakoutLevel, direction := detectBoxBreakout(box) + + // Get current breakout state from instance + state := &BreakoutState{ + Level: market.BreakoutLevel(at.gridState.BreakoutLevel), + Direction: at.gridState.BreakoutDirection, + ConfirmCount: at.gridState.BreakoutConfirmCount, + } + + // Check if breakout is confirmed (3 candles) + confirmed := confirmBreakout(state, breakoutLevel, direction) + + // Update grid state + at.gridState.mu.Lock() + at.gridState.BreakoutLevel = string(state.Level) + at.gridState.BreakoutDirection = state.Direction + at.gridState.BreakoutConfirmCount = state.ConfirmCount + at.gridState.mu.Unlock() + + if !confirmed { + return nil + } + + // Take action based on breakout level + action := getBreakoutAction(breakoutLevel) + return at.executeBreakoutAction(action) +} + +// executeBreakoutAction executes the appropriate action for a breakout +func (at *AutoTrader) executeBreakoutAction(action BreakoutAction) error { + gridConfig := at.config.StrategyConfig.GridConfig + + switch action { + case BreakoutActionReducePosition: + // Short box breakout: reduce position to 50% + logger.Infof("Short box breakout confirmed, reducing position to 50%%") + at.gridState.mu.Lock() + at.gridState.PositionReductionPct = 50 + at.gridState.mu.Unlock() + return nil + + case BreakoutActionPauseGrid: + // Mid box breakout: pause grid + cancel orders + logger.Infof("Mid box breakout confirmed, pausing grid and canceling orders") + at.gridState.mu.Lock() + at.gridState.IsPaused = true + at.gridState.mu.Unlock() + return at.cancelAllGridOrders() + + case BreakoutActionCloseAll: + // Long box breakout: pause + cancel + close all + logger.Infof("Long box breakout confirmed, closing all positions") + at.gridState.mu.Lock() + at.gridState.IsPaused = true + at.gridState.mu.Unlock() + if err := at.cancelAllGridOrders(); err != nil { + logger.Infof("Failed to cancel orders: %v", err) + } + return at.closeAllPositions() + } + + return nil +} + +// closeAllPositions closes all open positions +func (at *AutoTrader) closeAllPositions() error { + gridConfig := at.config.StrategyConfig.GridConfig + + positions, err := at.trader.GetPositions() + if err != nil { + return fmt.Errorf("failed to get positions: %w", err) + } + + for _, pos := range positions { + symbol, _ := pos["symbol"].(string) + if symbol != gridConfig.Symbol { + continue + } + + size, _ := pos["positionAmt"].(float64) + if size == 0 { + continue + } + + if size > 0 { + _, err = at.trader.CloseLong(symbol, size) + } else { + _, err = at.trader.CloseShort(symbol, -size) + } + if err != nil { + logger.Infof("Failed to close position: %v", err) + } + } + + return nil +} +``` + +**Step 2: Add checkBoxBreakout call to RunGridCycle** + +In `RunGridCycle`, add after existing breakout check: + +```go + // Check multi-period box breakout + if err := at.checkBoxBreakout(); err != nil { + logger.Infof("Box breakout check error: %v", err) + } +``` + +**Step 3: Commit** + +```bash +git add trader/auto_trader_grid.go +git commit -m "feat(trader): integrate box breakout detection into grid cycle" +``` + +--- + +## Task 11: Add False Breakout Recovery + +**Files:** +- Modify: `trader/auto_trader_grid.go` + +**Step 1: Add recovery logic** + +Add to `trader/auto_trader_grid.go`: + +```go +// checkFalseBreakoutRecovery checks if price has returned to box after breakout +func (at *AutoTrader) checkFalseBreakoutRecovery() error { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig == nil { + return nil + } + + at.gridState.mu.RLock() + breakoutLevel := at.gridState.BreakoutLevel + isPaused := at.gridState.IsPaused + positionReduction := at.gridState.PositionReductionPct + at.gridState.mu.RUnlock() + + // Only check if we had a breakout + if breakoutLevel == string(market.BreakoutNone) && positionReduction == 0 && !isPaused { + return nil + } + + // Get current box data + box, err := market.GetBoxData(gridConfig.Symbol) + if err != nil { + return nil + } + + // Check if price is back inside the long box + if box.CurrentPrice >= box.LongLower && box.CurrentPrice <= box.LongUpper { + logger.Infof("Price returned to box, recovering with 50%% position") + + at.gridState.mu.Lock() + at.gridState.BreakoutLevel = string(market.BreakoutNone) + at.gridState.BreakoutDirection = "" + at.gridState.BreakoutConfirmCount = 0 + at.gridState.PositionReductionPct = 50 // Recover at 50% + at.gridState.IsPaused = false + at.gridState.mu.Unlock() + } + + return nil +} +``` + +**Step 2: Add call in RunGridCycle** + +```go + // Check for false breakout recovery + if err := at.checkFalseBreakoutRecovery(); err != nil { + logger.Infof("False breakout recovery check error: %v", err) + } +``` + +**Step 3: Commit** + +```bash +git add trader/auto_trader_grid.go +git commit -m "feat(trader): add false breakout recovery logic" +``` + +--- + +## Task 12: Update GridState with Box Fields + +**Files:** +- Modify: `trader/auto_trader_grid.go` + +**Step 1: Add box fields to GridState struct** + +Add to `GridState` struct in `trader/auto_trader_grid.go`: + +```go + // Box state + ShortBoxUpper float64 + ShortBoxLower float64 + MidBoxUpper float64 + MidBoxLower float64 + LongBoxUpper float64 + LongBoxLower float64 + + // Breakout state + BreakoutLevel string + BreakoutDirection string + BreakoutConfirmCount int + + // Position reduction (0 = normal, 50 = reduced after false breakout) + PositionReductionPct float64 + + // Current regime level + CurrentRegimeLevel string +``` + +**Step 2: Commit** + +```bash +git add trader/auto_trader_grid.go +git commit -m "feat(trader): add box and regime fields to GridState" +``` + +--- + +## Task 13: Add Frontend Types + +**Files:** +- Modify: `web/src/types.ts` (or equivalent types file) + +**Step 1: Add grid risk info types** + +Add to types file: + +```typescript +export interface GridRiskInfo { + // Leverage info + currentLeverage: number + effectiveLeverage: number + recommendedLeverage: number + + // Position info + currentPosition: number + maxPosition: number + positionPercent: number + + // Liquidation info + liquidationPrice: number + liquidationDistance: number // percentage + + // Market state + regimeLevel: 'narrow' | 'standard' | 'wide' | 'volatile' | 'trending' + + // Box state + shortBoxUpper: number + shortBoxLower: number + midBoxUpper: number + midBoxLower: number + longBoxUpper: number + longBoxLower: number + currentPrice: number + + // Breakout state + breakoutLevel: 'none' | 'short' | 'mid' | 'long' + breakoutDirection: 'up' | 'down' | '' +} +``` + +**Step 2: Commit** + +```bash +git add web/src/types.ts +git commit -m "feat(web): add GridRiskInfo type" +``` + +--- + +## Task 14: Add API Endpoint for Risk Info + +**Files:** +- Modify: `api/server.go` + +**Step 1: Add handler function** + +Add to `api/server.go`: + +```go +// handleGetGridRiskInfo returns current risk information for a grid trader +func (s *Server) handleGetGridRiskInfo(c *gin.Context) { + traderID := c.Param("id") + + trader, err := s.manager.GetTrader(traderID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "trader not found"}) + return + } + + autoTrader, ok := trader.(*trader.AutoTrader) + if !ok { + c.JSON(http.StatusBadRequest, gin.H{"error": "not an auto trader"}) + return + } + + riskInfo := autoTrader.GetGridRiskInfo() + c.JSON(http.StatusOK, riskInfo) +} +``` + +**Step 2: Add route** + +Add route in `setupRoutes`: + +```go + api.GET("/traders/:id/grid-risk", s.handleGetGridRiskInfo) +``` + +**Step 3: Commit** + +```bash +git add api/server.go +git commit -m "feat(api): add grid risk info endpoint" +``` + +--- + +## Task 15: Add GetGridRiskInfo Method to AutoTrader + +**Files:** +- Modify: `trader/auto_trader_grid.go` + +**Step 1: Add method** + +Add to `trader/auto_trader_grid.go`: + +```go +// GridRiskInfo contains risk information for frontend display +type GridRiskInfo struct { + CurrentLeverage int `json:"current_leverage"` + EffectiveLeverage float64 `json:"effective_leverage"` + RecommendedLeverage int `json:"recommended_leverage"` + + CurrentPosition float64 `json:"current_position"` + MaxPosition float64 `json:"max_position"` + PositionPercent float64 `json:"position_percent"` + + LiquidationPrice float64 `json:"liquidation_price"` + LiquidationDistance float64 `json:"liquidation_distance"` + + RegimeLevel string `json:"regime_level"` + + ShortBoxUpper float64 `json:"short_box_upper"` + ShortBoxLower float64 `json:"short_box_lower"` + MidBoxUpper float64 `json:"mid_box_upper"` + MidBoxLower float64 `json:"mid_box_lower"` + LongBoxUpper float64 `json:"long_box_upper"` + LongBoxLower float64 `json:"long_box_lower"` + CurrentPrice float64 `json:"current_price"` + + BreakoutLevel string `json:"breakout_level"` + BreakoutDirection string `json:"breakout_direction"` +} + +// GetGridRiskInfo returns current risk information +func (at *AutoTrader) GetGridRiskInfo() *GridRiskInfo { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig == nil { + return &GridRiskInfo{} + } + + at.gridState.mu.RLock() + defer at.gridState.mu.RUnlock() + + // Get current price + currentPrice, _ := at.trader.GetMarketPrice(gridConfig.Symbol) + + // Calculate effective leverage + totalInvestment := gridConfig.TotalInvestment + leverage := gridConfig.Leverage + + // Get current position value + positions, _ := at.trader.GetPositions() + var currentPositionValue float64 + for _, pos := range positions { + if sym, _ := pos["symbol"].(string); sym == gridConfig.Symbol { + size, _ := pos["positionAmt"].(float64) + entry, _ := pos["entryPrice"].(float64) + currentPositionValue = math.Abs(size * entry) + break + } + } + + effectiveLeverage := currentPositionValue / totalInvestment + + // Calculate max position based on regime + regimeLevel := market.RegimeLevel(at.gridState.CurrentRegimeLevel) + maxPositionPct := getRegimePositionLimit(regimeLevel, gridConfig) + maxPosition := totalInvestment * maxPositionPct / 100 * float64(leverage) + recommendedLeverage := getRegimeLeverageLimit(regimeLevel, gridConfig) + + // Calculate liquidation distance + liquidationDistance := 100.0 / float64(leverage) * 0.9 // ~90% of theoretical max + + var liquidationPrice float64 + if currentPositionValue > 0 { + liquidationPrice = currentPrice * (1 - liquidationDistance/100) + } + + return &GridRiskInfo{ + CurrentLeverage: leverage, + EffectiveLeverage: effectiveLeverage, + RecommendedLeverage: recommendedLeverage, + + CurrentPosition: currentPositionValue, + MaxPosition: maxPosition, + PositionPercent: currentPositionValue / maxPosition * 100, + + LiquidationPrice: liquidationPrice, + LiquidationDistance: liquidationDistance, + + RegimeLevel: at.gridState.CurrentRegimeLevel, + + ShortBoxUpper: at.gridState.ShortBoxUpper, + ShortBoxLower: at.gridState.ShortBoxLower, + MidBoxUpper: at.gridState.MidBoxUpper, + MidBoxLower: at.gridState.MidBoxLower, + LongBoxUpper: at.gridState.LongBoxUpper, + LongBoxLower: at.gridState.LongBoxLower, + CurrentPrice: currentPrice, + + BreakoutLevel: at.gridState.BreakoutLevel, + BreakoutDirection: at.gridState.BreakoutDirection, + } +} +``` + +**Step 2: Commit** + +```bash +git add trader/auto_trader_grid.go +git commit -m "feat(trader): add GetGridRiskInfo method" +``` + +--- + +## Task 16: Create GridRiskPanel Component + +**Files:** +- Create: `web/src/components/strategy/GridRiskPanel.tsx` + +**Step 1: Create component** + +Create `web/src/components/strategy/GridRiskPanel.tsx`: + +```tsx +import { useState, useEffect } from 'react' +import { AlertTriangle, TrendingUp, Shield, Box } from 'lucide-react' + +interface GridRiskInfo { + currentLeverage: number + effectiveLeverage: number + recommendedLeverage: number + currentPosition: number + maxPosition: number + positionPercent: number + liquidationPrice: number + liquidationDistance: number + regimeLevel: string + shortBoxUpper: number + shortBoxLower: number + midBoxUpper: number + midBoxLower: number + longBoxUpper: number + longBoxLower: number + currentPrice: number + breakoutLevel: string + breakoutDirection: string +} + +interface GridRiskPanelProps { + traderId: string + language: string +} + +export function GridRiskPanel({ traderId, language }: GridRiskPanelProps) { + const [riskInfo, setRiskInfo] = useState(null) + const [loading, setLoading] = useState(true) + + const t = (key: string) => { + const translations: Record> = { + leverageInfo: { zh: '杠杆信息', en: 'Leverage Info' }, + currentLeverage: { zh: '当前杠杆', en: 'Current Leverage' }, + effectiveLeverage: { zh: '有效杠杆', en: 'Effective Leverage' }, + recommendedLeverage: { zh: '推荐杠杆', en: 'Recommended Leverage' }, + positionInfo: { zh: '仓位信息', en: 'Position Info' }, + currentPosition: { zh: '当前仓位', en: 'Current Position' }, + maxPosition: { zh: '最大仓位', en: 'Max Position' }, + liquidationInfo: { zh: '爆仓信息', en: 'Liquidation Info' }, + liquidationPrice: { zh: '爆仓价格', en: 'Liquidation Price' }, + liquidationDistance: { zh: '爆仓距离', en: 'Distance' }, + marketState: { zh: '市场状态', en: 'Market State' }, + regimeLevel: { zh: '震荡级别', en: 'Regime Level' }, + boxState: { zh: '箱体状态', en: 'Box State' }, + shortBox: { zh: '短期箱体', en: 'Short Box' }, + midBox: { zh: '中期箱体', en: 'Mid Box' }, + longBox: { zh: '长期箱体', en: 'Long Box' }, + narrow: { zh: '窄幅震荡', en: 'Narrow' }, + standard: { zh: '标准震荡', en: 'Standard' }, + wide: { zh: '宽幅震荡', en: 'Wide' }, + volatile: { zh: '剧烈震荡', en: 'Volatile' }, + trending: { zh: '趋势', en: 'Trending' }, + breakout: { zh: '突破', en: 'Breakout' }, + none: { zh: '无', en: 'None' }, + } + return translations[key]?.[language] || key + } + + useEffect(() => { + const fetchRiskInfo = async () => { + try { + const res = await fetch(`/api/traders/${traderId}/grid-risk`) + if (res.ok) { + const data = await res.json() + setRiskInfo(data) + } + } catch (err) { + console.error('Failed to fetch risk info:', err) + } finally { + setLoading(false) + } + } + + fetchRiskInfo() + const interval = setInterval(fetchRiskInfo, 10000) // Update every 10s + return () => clearInterval(interval) + }, [traderId]) + + if (loading || !riskInfo) { + return
+ } + + const getRegimeColor = (level: string) => { + switch (level) { + case 'narrow': return 'text-green-400' + case 'standard': return 'text-blue-400' + case 'wide': return 'text-yellow-400' + case 'volatile': return 'text-orange-400' + case 'trending': return 'text-red-400' + default: return 'text-gray-400' + } + } + + return ( +
+ {/* Leverage Info */} +
+

+ + {t('leverageInfo')} +

+
+
+
{t('currentLeverage')}
+
{riskInfo.currentLeverage}x
+
+
+
{t('effectiveLeverage')}
+
{riskInfo.effectiveLeverage.toFixed(2)}x
+
+
+
{t('recommendedLeverage')}
+
{riskInfo.recommendedLeverage}x
+
+
+
+ + {/* Position Info */} +
+

+ + {t('positionInfo')} +

+
+
+
{t('currentPosition')}
+
${riskInfo.currentPosition.toFixed(2)}
+
+
+
{t('maxPosition')}
+
${riskInfo.maxPosition.toFixed(2)}
+
+
+
+
+
+
+ + {/* Liquidation Info */} +
+

+ + {t('liquidationInfo')} +

+
+
+
{t('liquidationPrice')}
+
${riskInfo.liquidationPrice.toFixed(2)}
+
+
+
{t('liquidationDistance')}
+
{riskInfo.liquidationDistance.toFixed(1)}%
+
+
+
+ + {/* Market State */} +
+

+ + {t('marketState')} +

+
+
+
{t('regimeLevel')}
+
+ {t(riskInfo.regimeLevel)} +
+
+ {riskInfo.breakoutLevel !== 'none' && ( +
+ {t('breakout')}: {riskInfo.breakoutLevel} ({riskInfo.breakoutDirection}) +
+ )} +
+
+ + {/* Box State */} +
+

{t('boxState')}

+
+
+ {t('shortBox')} + {riskInfo.shortBoxLower.toFixed(2)} - {riskInfo.shortBoxUpper.toFixed(2)} +
+
+ {t('midBox')} + {riskInfo.midBoxLower.toFixed(2)} - {riskInfo.midBoxUpper.toFixed(2)} +
+
+ {t('longBox')} + {riskInfo.longBoxLower.toFixed(2)} - {riskInfo.longBoxUpper.toFixed(2)} +
+
+ Current Price + ${riskInfo.currentPrice.toFixed(2)} +
+
+
+
+ ) +} +``` + +**Step 2: Commit** + +```bash +git add web/src/components/strategy/GridRiskPanel.tsx +git commit -m "feat(web): add GridRiskPanel component" +``` + +--- + +## Task 17: Update AI Prompt with Box Indicators + +**Files:** +- Modify: `kernel/grid_engine.go` + +**Step 1: Update BuildGridUserPrompt to include box data** + +Add box data section to the prompt in `kernel/grid_engine.go`: + +```go +// In BuildGridUserPrompt function, add after market data section: + + // Box Indicator Section + if gridCtx.BoxData != nil { + sb.WriteString("\n## Box Indicators (Donchian Channels)\n\n") + sb.WriteString("| Box Level | Upper | Lower | Width |\n") + sb.WriteString("|-----------|-------|-------|-------|\n") + + shortWidth := (gridCtx.BoxData.ShortUpper - gridCtx.BoxData.ShortLower) / gridCtx.BoxData.CurrentPrice * 100 + midWidth := (gridCtx.BoxData.MidUpper - gridCtx.BoxData.MidLower) / gridCtx.BoxData.CurrentPrice * 100 + longWidth := (gridCtx.BoxData.LongUpper - gridCtx.BoxData.LongLower) / gridCtx.BoxData.CurrentPrice * 100 + + sb.WriteString(fmt.Sprintf("| Short (3d) | %.2f | %.2f | %.2f%% |\n", + gridCtx.BoxData.ShortUpper, gridCtx.BoxData.ShortLower, shortWidth)) + sb.WriteString(fmt.Sprintf("| Mid (10d) | %.2f | %.2f | %.2f%% |\n", + gridCtx.BoxData.MidUpper, gridCtx.BoxData.MidLower, midWidth)) + sb.WriteString(fmt.Sprintf("| Long (21d) | %.2f | %.2f | %.2f%% |\n", + gridCtx.BoxData.LongUpper, gridCtx.BoxData.LongLower, longWidth)) + + // Price position + sb.WriteString(fmt.Sprintf("\nCurrent Price: %.2f\n", gridCtx.BoxData.CurrentPrice)) + + // Check position relative to boxes + price := gridCtx.BoxData.CurrentPrice + if price > gridCtx.BoxData.LongUpper || price < gridCtx.BoxData.LongLower { + sb.WriteString("⚠️ BREAKOUT: Price outside long-term box!\n") + } else if price > gridCtx.BoxData.MidUpper || price < gridCtx.BoxData.MidLower { + sb.WriteString("⚠️ WARNING: Price approaching long-term box boundary\n") + } + } +``` + +**Step 2: Update GridContext struct** + +Add BoxData field to GridContext: + +```go +type GridContext struct { + // ... existing fields ... + + // Box data + BoxData *market.BoxData +} +``` + +**Step 3: Commit** + +```bash +git add kernel/grid_engine.go +git commit -m "feat(kernel): add box indicators to AI prompt" +``` + +--- + +## Task 18: Database Migration + +**Files:** +- Modify: `store/grid.go` + +**Step 1: Update InitGridSchema to migrate new fields** + +The GORM AutoMigrate will handle adding new columns. Verify by running: + +```bash +cd /Users/yida/gopro/open-nofx && go run . migrate +``` + +**Step 2: Commit** + +```bash +git add store/grid.go +git commit -m "chore(store): ensure new grid fields are migrated" +``` + +--- + +## Task 19: Run All Tests + +**Step 1: Run backend tests** + +```bash +cd /Users/yida/gopro/open-nofx && go test -v ./... +``` + +**Step 2: Run frontend tests (if available)** + +```bash +cd /Users/yida/gopro/open-nofx/web && npm test +``` + +**Step 3: Fix any failing tests and commit** + +```bash +git add . +git commit -m "test: fix tests for grid regime implementation" +``` + +--- + +## Task 20: Final Integration Test + +**Step 1: Start the server** + +```bash +cd /Users/yida/gopro/open-nofx && go run . +``` + +**Step 2: Verify API endpoint** + +```bash +curl http://localhost:8080/api/traders//grid-risk +``` + +**Step 3: Verify frontend displays risk panel** + +Open browser and check grid trading page shows risk panel. + +**Step 4: Final commit** + +```bash +git add . +git commit -m "feat: complete grid market regime detection implementation" +``` + +--- + +## Summary + +| Task | Description | Files | +|------|-------------|-------| +| 1 | Donchian calculation | market/data.go | +| 2 | Box data types | market/types.go | +| 3 | GetBoxData function | market/data.go | +| 4 | GridConfigModel fields | store/grid.go | +| 5 | GridInstanceModel fields | store/grid.go | +| 6 | Regime classification | trader/grid_regime.go | +| 7 | Breakout detection | trader/grid_regime.go | +| 8 | Breakout confirmation | trader/grid_regime.go | +| 9 | Breakout handler | trader/grid_regime.go | +| 10 | Grid cycle integration | trader/auto_trader_grid.go | +| 11 | False breakout recovery | trader/auto_trader_grid.go | +| 12 | GridState fields | trader/auto_trader_grid.go | +| 13 | Frontend types | web/src/types.ts | +| 14 | API endpoint | api/server.go | +| 15 | GetGridRiskInfo method | trader/auto_trader_grid.go | +| 16 | GridRiskPanel component | web/src/components/ | +| 17 | AI prompt update | kernel/grid_engine.go | +| 18 | Database migration | store/grid.go | +| 19 | Run all tests | - | +| 20 | Integration test | - | diff --git a/kernel/engine.go b/kernel/engine.go index e452c730..6ce5aa52 100644 --- a/kernel/engine.go +++ b/kernel/engine.go @@ -130,7 +130,8 @@ type Context struct { // Decision AI trading decision type Decision struct { Symbol string `json:"symbol"` - Action string `json:"action"` // "open_long", "open_short", "close_long", "close_short", "hold", "wait" + Action string `json:"action"` // Standard: "open_long", "open_short", "close_long", "close_short", "hold", "wait" + // Grid actions: "place_buy_limit", "place_sell_limit", "cancel_order", "cancel_all_orders", "pause_grid", "resume_grid", "adjust_grid" // Opening position parameters Leverage int `json:"leverage,omitempty"` @@ -138,6 +139,12 @@ type Decision struct { StopLoss float64 `json:"stop_loss,omitempty"` TakeProfit float64 `json:"take_profit,omitempty"` + // Grid trading parameters + Price float64 `json:"price,omitempty"` // Limit order price (for grid) + Quantity float64 `json:"quantity,omitempty"` // Order quantity (for grid) + LevelIndex int `json:"level_index,omitempty"` // Grid level index + OrderID string `json:"order_id,omitempty"` // Order ID (for cancel) + // Common parameters Confidence int `json:"confidence,omitempty"` // Confidence level (0-100) RiskUSD float64 `json:"risk_usd,omitempty"` // Maximum USD risk diff --git a/kernel/grid_engine.go b/kernel/grid_engine.go new file mode 100644 index 00000000..243432a2 --- /dev/null +++ b/kernel/grid_engine.go @@ -0,0 +1,587 @@ +package kernel + +import ( + "encoding/json" + "fmt" + "nofx/logger" + "nofx/market" + "nofx/mcp" + "nofx/store" + "strings" + "time" +) + +// ============================================================================ +// Grid Trading Context and Types +// ============================================================================ + +// GridLevelInfo represents a single grid level's current state +type GridLevelInfo struct { + Index int `json:"index"` // Level index (0 = lowest) + Price float64 `json:"price"` // Target price for this level + State string `json:"state"` // "empty", "pending", "filled" + Side string `json:"side"` // "buy" or "sell" + OrderID string `json:"order_id"` // Current order ID (if pending) + OrderQuantity float64 `json:"order_quantity"` // Order quantity + PositionSize float64 `json:"position_size"` // Position size (if filled) + PositionEntry float64 `json:"position_entry"` // Entry price (if filled) + AllocatedUSD float64 `json:"allocated_usd"` // USD allocated to this level + UnrealizedPnL float64 `json:"unrealized_pnl"` // Unrealized P&L (if filled) +} + +// GridContext contains all information needed for AI grid decision making +type GridContext struct { + // Basic info + Symbol string `json:"symbol"` + CurrentTime string `json:"current_time"` + CurrentPrice float64 `json:"current_price"` + + // Grid configuration + GridCount int `json:"grid_count"` + TotalInvestment float64 `json:"total_investment"` + Leverage int `json:"leverage"` + UpperPrice float64 `json:"upper_price"` + LowerPrice float64 `json:"lower_price"` + GridSpacing float64 `json:"grid_spacing"` + Distribution string `json:"distribution"` + + // Grid state + Levels []GridLevelInfo `json:"levels"` + ActiveOrderCount int `json:"active_order_count"` + FilledLevelCount int `json:"filled_level_count"` + IsPaused bool `json:"is_paused"` + + // Market data + ATR14 float64 `json:"atr14"` + BollingerUpper float64 `json:"bollinger_upper"` + BollingerMiddle float64 `json:"bollinger_middle"` + BollingerLower float64 `json:"bollinger_lower"` + BollingerWidth float64 `json:"bollinger_width"` // Percentage + EMA20 float64 `json:"ema20"` + EMA50 float64 `json:"ema50"` + EMADistance float64 `json:"ema_distance"` // Percentage + RSI14 float64 `json:"rsi14"` + MACD float64 `json:"macd"` + MACDSignal float64 `json:"macd_signal"` + MACDHistogram float64 `json:"macd_histogram"` + FundingRate float64 `json:"funding_rate"` + Volume24h float64 `json:"volume_24h"` + PriceChange1h float64 `json:"price_change_1h"` + PriceChange4h float64 `json:"price_change_4h"` + + // Account info + TotalEquity float64 `json:"total_equity"` + AvailableBalance float64 `json:"available_balance"` + CurrentPosition float64 `json:"current_position"` // Net position size + UnrealizedPnL float64 `json:"unrealized_pnl"` + + // Performance + TotalProfit float64 `json:"total_profit"` + TotalTrades int `json:"total_trades"` + WinningTrades int `json:"winning_trades"` + MaxDrawdown float64 `json:"max_drawdown"` + DailyPnL float64 `json:"daily_pnl"` + + // Box indicators (Donchian Channels) + BoxData *market.BoxData `json:"box_data,omitempty"` +} + +// ============================================================================ +// Grid Prompt Building +// ============================================================================ + +// BuildGridSystemPrompt builds the system prompt for grid trading AI +func BuildGridSystemPrompt(config *store.GridStrategyConfig, lang string) string { + if lang == "zh" { + return buildGridSystemPromptZh(config) + } + return buildGridSystemPromptEn(config) +} + +func buildGridSystemPromptZh(config *store.GridStrategyConfig) string { + return fmt.Sprintf(`# 你是一个专业的网格交易AI + +## 角色定义 +你是一个经验丰富的网格交易专家,负责管理 %s 的网格交易策略。你的任务是: +1. 判断当前市场状态(震荡/趋势/高波动) +2. 决定是否需要调整网格或暂停交易 +3. 管理每个网格层级的订单 + +## 网格配置 +- 交易对: %s +- 网格层数: %d +- 总投资: %.2f USDT +- 杠杆: %dx +- 价格分布: %s + +## 决策规则 + +### 市场状态判断 +- **震荡市场** (适合网格): 布林带宽度 < 3%%, EMA20/50 距离 < 1%%, 价格在布林带中轨附近 +- **趋势市场** (暂停网格): 布林带宽度 > 4%%, EMA20/50 距离 > 2%%, 价格持续突破布林带 +- **高波动市场** (谨慎): ATR异常放大, 价格剧烈波动 + +### 可执行的操作 +- place_buy_limit: 在指定价格下买入限价单 +- place_sell_limit: 在指定价格下卖出限价单 +- cancel_order: 取消指定订单 +- cancel_all_orders: 取消所有订单 +- pause_grid: 暂停网格交易(趋势市场时) +- resume_grid: 恢复网格交易(震荡市场时) +- adjust_grid: 调整网格边界 +- hold: 保持当前状态不操作 + +## 输出格式 +输出JSON数组,每个决策包含: +- symbol: 交易对 +- action: 操作类型 +- price: 价格(限价单用) +- quantity: 数量 +- level_index: 网格层级索引 +- order_id: 订单ID(取消订单用) +- confidence: 置信度 0-100 +- reasoning: 决策理由 + +示例: +[ + {"symbol": "BTCUSDT", "action": "place_buy_limit", "price": 94000, "quantity": 0.01, "level_index": 2, "confidence": 85, "reasoning": "第2层价格接近,下买单"}, + {"symbol": "BTCUSDT", "action": "hold", "confidence": 90, "reasoning": "市场震荡,保持当前网格"} +] +`, config.Symbol, config.Symbol, config.GridCount, config.TotalInvestment, config.Leverage, config.Distribution) +} + +func buildGridSystemPromptEn(config *store.GridStrategyConfig) string { + return fmt.Sprintf(`# You are a Professional Grid Trading AI + +## Role Definition +You are an experienced grid trading expert managing a grid strategy for %s. Your tasks are: +1. Assess current market regime (ranging/trending/volatile) +2. Decide whether to adjust grid or pause trading +3. Manage orders at each grid level + +## Grid Configuration +- Symbol: %s +- Grid Levels: %d +- Total Investment: %.2f USDT +- Leverage: %dx +- Distribution: %s + +## Decision Rules + +### Market Regime Assessment +- **Ranging Market** (ideal for grid): Bollinger width < 3%%, EMA20/50 distance < 1%%, price near middle band +- **Trending Market** (pause grid): Bollinger width > 4%%, EMA20/50 distance > 2%%, price breaking bands +- **High Volatility** (caution): ATR spike, erratic price movement + +### Available Actions +- place_buy_limit: Place buy limit order at specified price +- place_sell_limit: Place sell limit order at specified price +- cancel_order: Cancel specific order +- cancel_all_orders: Cancel all orders +- pause_grid: Pause grid trading (in trending market) +- resume_grid: Resume grid trading (in ranging market) +- adjust_grid: Adjust grid boundaries +- hold: Maintain current state + +## Output Format +Output JSON array, each decision contains: +- symbol: Trading pair +- action: Action type +- price: Price (for limit orders) +- quantity: Quantity +- level_index: Grid level index +- order_id: Order ID (for cancel) +- confidence: Confidence 0-100 +- reasoning: Decision reason + +Example: +[ + {"symbol": "BTCUSDT", "action": "place_buy_limit", "price": 94000, "quantity": 0.01, "level_index": 2, "confidence": 85, "reasoning": "Level 2 price approaching, place buy order"}, + {"symbol": "BTCUSDT", "action": "hold", "confidence": 90, "reasoning": "Market ranging, maintain current grid"} +] +`, config.Symbol, config.Symbol, config.GridCount, config.TotalInvestment, config.Leverage, config.Distribution) +} + +// BuildGridUserPrompt builds the user prompt with current grid context +func BuildGridUserPrompt(ctx *GridContext, lang string) string { + if lang == "zh" { + return buildGridUserPromptZh(ctx) + } + return buildGridUserPromptEn(ctx) +} + +func buildGridUserPromptZh(ctx *GridContext) string { + var sb strings.Builder + + sb.WriteString(fmt.Sprintf("## 当前时间: %s\n\n", ctx.CurrentTime)) + + // Market data section + sb.WriteString("## 市场数据\n") + sb.WriteString(fmt.Sprintf("- 当前价格: $%.2f\n", ctx.CurrentPrice)) + sb.WriteString(fmt.Sprintf("- 1小时涨跌: %.2f%%\n", ctx.PriceChange1h)) + sb.WriteString(fmt.Sprintf("- 4小时涨跌: %.2f%%\n", ctx.PriceChange4h)) + sb.WriteString(fmt.Sprintf("- ATR14: $%.2f (%.2f%%)\n", ctx.ATR14, ctx.ATR14/ctx.CurrentPrice*100)) + sb.WriteString(fmt.Sprintf("- 布林带: 上轨 $%.2f, 中轨 $%.2f, 下轨 $%.2f\n", ctx.BollingerUpper, ctx.BollingerMiddle, ctx.BollingerLower)) + sb.WriteString(fmt.Sprintf("- 布林带宽度: %.2f%%\n", ctx.BollingerWidth)) + sb.WriteString(fmt.Sprintf("- EMA20: $%.2f, EMA50: $%.2f, 距离: %.2f%%\n", ctx.EMA20, ctx.EMA50, ctx.EMADistance)) + sb.WriteString(fmt.Sprintf("- RSI14: %.1f\n", ctx.RSI14)) + sb.WriteString(fmt.Sprintf("- MACD: %.4f, Signal: %.4f, Histogram: %.4f\n", ctx.MACD, ctx.MACDSignal, ctx.MACDHistogram)) + sb.WriteString(fmt.Sprintf("- 资金费率: %.4f%%\n", ctx.FundingRate*100)) + sb.WriteString("\n") + + // Box Indicator Section + if ctx.BoxData != nil { + sb.WriteString("## 箱体指标 (唐奇安通道)\n\n") + sb.WriteString("| 箱体级别 | 上轨 | 下轨 | 宽度 |\n") + sb.WriteString("|----------|------|------|------|\n") + + shortWidth := 0.0 + midWidth := 0.0 + longWidth := 0.0 + + if ctx.BoxData.CurrentPrice > 0 { + shortWidth = (ctx.BoxData.ShortUpper - ctx.BoxData.ShortLower) / ctx.BoxData.CurrentPrice * 100 + midWidth = (ctx.BoxData.MidUpper - ctx.BoxData.MidLower) / ctx.BoxData.CurrentPrice * 100 + longWidth = (ctx.BoxData.LongUpper - ctx.BoxData.LongLower) / ctx.BoxData.CurrentPrice * 100 + } + + sb.WriteString(fmt.Sprintf("| 短期 (3天) | %.2f | %.2f | %.2f%% |\n", + ctx.BoxData.ShortUpper, ctx.BoxData.ShortLower, shortWidth)) + sb.WriteString(fmt.Sprintf("| 中期 (10天) | %.2f | %.2f | %.2f%% |\n", + ctx.BoxData.MidUpper, ctx.BoxData.MidLower, midWidth)) + sb.WriteString(fmt.Sprintf("| 长期 (21天) | %.2f | %.2f | %.2f%% |\n", + ctx.BoxData.LongUpper, ctx.BoxData.LongLower, longWidth)) + + sb.WriteString(fmt.Sprintf("\n当前价格: %.2f\n", ctx.BoxData.CurrentPrice)) + + // Check position relative to boxes + price := ctx.BoxData.CurrentPrice + if price > ctx.BoxData.LongUpper || price < ctx.BoxData.LongLower { + sb.WriteString("⚠️ 突破: 价格突破长期箱体!\n") + } else if price > ctx.BoxData.MidUpper || price < ctx.BoxData.MidLower { + sb.WriteString("⚠️ 警告: 价格接近长期箱体边界\n") + } + sb.WriteString("\n") + } + + // Account section + sb.WriteString("## 账户状态\n") + sb.WriteString(fmt.Sprintf("- 总权益: $%.2f\n", ctx.TotalEquity)) + sb.WriteString(fmt.Sprintf("- 可用余额: $%.2f\n", ctx.AvailableBalance)) + sb.WriteString(fmt.Sprintf("- 当前持仓: %.4f (净头寸)\n", ctx.CurrentPosition)) + sb.WriteString(fmt.Sprintf("- 未实现盈亏: $%.2f\n", ctx.UnrealizedPnL)) + sb.WriteString("\n") + + // Grid state section + sb.WriteString("## 网格状态\n") + sb.WriteString(fmt.Sprintf("- 网格范围: $%.2f - $%.2f\n", ctx.LowerPrice, ctx.UpperPrice)) + sb.WriteString(fmt.Sprintf("- 网格间距: $%.2f\n", ctx.GridSpacing)) + sb.WriteString(fmt.Sprintf("- 活跃订单数: %d\n", ctx.ActiveOrderCount)) + sb.WriteString(fmt.Sprintf("- 已成交层数: %d\n", ctx.FilledLevelCount)) + sb.WriteString(fmt.Sprintf("- 网格已暂停: %v\n", ctx.IsPaused)) + sb.WriteString("\n") + + // Grid levels detail + sb.WriteString("## 网格层级详情\n") + sb.WriteString("| 层级 | 价格 | 状态 | 方向 | 订单数量 | 持仓数量 | 未实现盈亏 |\n") + sb.WriteString("|------|------|------|------|----------|----------|------------|\n") + for _, level := range ctx.Levels { + sb.WriteString(fmt.Sprintf("| %d | $%.2f | %s | %s | %.4f | %.4f | $%.2f |\n", + level.Index, level.Price, level.State, level.Side, + level.OrderQuantity, level.PositionSize, level.UnrealizedPnL)) + } + sb.WriteString("\n") + + // Performance section + sb.WriteString("## 绩效统计\n") + sb.WriteString(fmt.Sprintf("- 总利润: $%.2f\n", ctx.TotalProfit)) + sb.WriteString(fmt.Sprintf("- 总交易次数: %d\n", ctx.TotalTrades)) + sb.WriteString(fmt.Sprintf("- 胜率: %.1f%%\n", float64(ctx.WinningTrades)/float64(max(ctx.TotalTrades, 1))*100)) + sb.WriteString(fmt.Sprintf("- 最大回撤: %.2f%%\n", ctx.MaxDrawdown)) + sb.WriteString(fmt.Sprintf("- 今日盈亏: $%.2f\n", ctx.DailyPnL)) + sb.WriteString("\n") + + sb.WriteString("## 请分析以上数据,做出网格交易决策\n") + sb.WriteString("输出JSON数组格式的决策列表。\n") + + return sb.String() +} + +func buildGridUserPromptEn(ctx *GridContext) string { + var sb strings.Builder + + sb.WriteString(fmt.Sprintf("## Current Time: %s\n\n", ctx.CurrentTime)) + + // Market data section + sb.WriteString("## Market Data\n") + sb.WriteString(fmt.Sprintf("- Current Price: $%.2f\n", ctx.CurrentPrice)) + sb.WriteString(fmt.Sprintf("- 1h Change: %.2f%%\n", ctx.PriceChange1h)) + sb.WriteString(fmt.Sprintf("- 4h Change: %.2f%%\n", ctx.PriceChange4h)) + sb.WriteString(fmt.Sprintf("- ATR14: $%.2f (%.2f%%)\n", ctx.ATR14, ctx.ATR14/ctx.CurrentPrice*100)) + sb.WriteString(fmt.Sprintf("- Bollinger Bands: Upper $%.2f, Middle $%.2f, Lower $%.2f\n", ctx.BollingerUpper, ctx.BollingerMiddle, ctx.BollingerLower)) + sb.WriteString(fmt.Sprintf("- Bollinger Width: %.2f%%\n", ctx.BollingerWidth)) + sb.WriteString(fmt.Sprintf("- EMA20: $%.2f, EMA50: $%.2f, Distance: %.2f%%\n", ctx.EMA20, ctx.EMA50, ctx.EMADistance)) + sb.WriteString(fmt.Sprintf("- RSI14: %.1f\n", ctx.RSI14)) + sb.WriteString(fmt.Sprintf("- MACD: %.4f, Signal: %.4f, Histogram: %.4f\n", ctx.MACD, ctx.MACDSignal, ctx.MACDHistogram)) + sb.WriteString(fmt.Sprintf("- Funding Rate: %.4f%%\n", ctx.FundingRate*100)) + sb.WriteString("\n") + + // Box Indicator Section + if ctx.BoxData != nil { + sb.WriteString("## Box Indicators (Donchian Channels)\n\n") + sb.WriteString("| Box Level | Upper | Lower | Width |\n") + sb.WriteString("|-----------|-------|-------|-------|\n") + + shortWidth := 0.0 + midWidth := 0.0 + longWidth := 0.0 + + if ctx.BoxData.CurrentPrice > 0 { + shortWidth = (ctx.BoxData.ShortUpper - ctx.BoxData.ShortLower) / ctx.BoxData.CurrentPrice * 100 + midWidth = (ctx.BoxData.MidUpper - ctx.BoxData.MidLower) / ctx.BoxData.CurrentPrice * 100 + longWidth = (ctx.BoxData.LongUpper - ctx.BoxData.LongLower) / ctx.BoxData.CurrentPrice * 100 + } + + sb.WriteString(fmt.Sprintf("| Short (3d) | %.2f | %.2f | %.2f%% |\n", + ctx.BoxData.ShortUpper, ctx.BoxData.ShortLower, shortWidth)) + sb.WriteString(fmt.Sprintf("| Mid (10d) | %.2f | %.2f | %.2f%% |\n", + ctx.BoxData.MidUpper, ctx.BoxData.MidLower, midWidth)) + sb.WriteString(fmt.Sprintf("| Long (21d) | %.2f | %.2f | %.2f%% |\n", + ctx.BoxData.LongUpper, ctx.BoxData.LongLower, longWidth)) + + sb.WriteString(fmt.Sprintf("\nCurrent Price: %.2f\n", ctx.BoxData.CurrentPrice)) + + // Check position relative to boxes + price := ctx.BoxData.CurrentPrice + if price > ctx.BoxData.LongUpper || price < ctx.BoxData.LongLower { + sb.WriteString("⚠️ BREAKOUT: Price outside long-term box!\n") + } else if price > ctx.BoxData.MidUpper || price < ctx.BoxData.MidLower { + sb.WriteString("⚠️ WARNING: Price approaching long-term box boundary\n") + } + sb.WriteString("\n") + } + + // Account section + sb.WriteString("## Account Status\n") + sb.WriteString(fmt.Sprintf("- Total Equity: $%.2f\n", ctx.TotalEquity)) + sb.WriteString(fmt.Sprintf("- Available Balance: $%.2f\n", ctx.AvailableBalance)) + sb.WriteString(fmt.Sprintf("- Current Position: %.4f (net)\n", ctx.CurrentPosition)) + sb.WriteString(fmt.Sprintf("- Unrealized PnL: $%.2f\n", ctx.UnrealizedPnL)) + sb.WriteString("\n") + + // Grid state section + sb.WriteString("## Grid Status\n") + sb.WriteString(fmt.Sprintf("- Grid Range: $%.2f - $%.2f\n", ctx.LowerPrice, ctx.UpperPrice)) + sb.WriteString(fmt.Sprintf("- Grid Spacing: $%.2f\n", ctx.GridSpacing)) + sb.WriteString(fmt.Sprintf("- Active Orders: %d\n", ctx.ActiveOrderCount)) + sb.WriteString(fmt.Sprintf("- Filled Levels: %d\n", ctx.FilledLevelCount)) + sb.WriteString(fmt.Sprintf("- Grid Paused: %v\n", ctx.IsPaused)) + sb.WriteString("\n") + + // Grid levels detail + sb.WriteString("## Grid Levels Detail\n") + sb.WriteString("| Level | Price | State | Side | Order Qty | Position | Unrealized PnL |\n") + sb.WriteString("|-------|-------|-------|------|-----------|----------|----------------|\n") + for _, level := range ctx.Levels { + sb.WriteString(fmt.Sprintf("| %d | $%.2f | %s | %s | %.4f | %.4f | $%.2f |\n", + level.Index, level.Price, level.State, level.Side, + level.OrderQuantity, level.PositionSize, level.UnrealizedPnL)) + } + sb.WriteString("\n") + + // Performance section + sb.WriteString("## Performance Stats\n") + sb.WriteString(fmt.Sprintf("- Total Profit: $%.2f\n", ctx.TotalProfit)) + sb.WriteString(fmt.Sprintf("- Total Trades: %d\n", ctx.TotalTrades)) + sb.WriteString(fmt.Sprintf("- Win Rate: %.1f%%\n", float64(ctx.WinningTrades)/float64(max(ctx.TotalTrades, 1))*100)) + sb.WriteString(fmt.Sprintf("- Max Drawdown: %.2f%%\n", ctx.MaxDrawdown)) + sb.WriteString(fmt.Sprintf("- Daily PnL: $%.2f\n", ctx.DailyPnL)) + sb.WriteString("\n") + + sb.WriteString("## Please analyze the data above and make grid trading decisions\n") + sb.WriteString("Output a JSON array of decisions.\n") + + return sb.String() +} + +// ============================================================================ +// Grid Decision Functions +// ============================================================================ + +// GetGridDecisions gets AI decisions for grid trading +func GetGridDecisions(ctx *GridContext, mcpClient mcp.AIClient, config *store.GridStrategyConfig, lang string) (*FullDecision, error) { + startTime := time.Now() + + // Build prompts + systemPrompt := BuildGridSystemPrompt(config, lang) + userPrompt := BuildGridUserPrompt(ctx, lang) + + logger.Infof("🤖 [Grid] Calling AI for grid decisions...") + + // Call AI + response, err := mcpClient.CallWithMessages(systemPrompt, userPrompt) + if err != nil { + return nil, fmt.Errorf("AI call failed: %w", err) + } + + // Parse decisions from response + decisions, err := parseGridDecisions(response, ctx.Symbol) + if err != nil { + logger.Warnf("Failed to parse grid decisions: %v", err) + // Return hold decision as fallback + decisions = []Decision{{ + Symbol: ctx.Symbol, + Action: "hold", + Confidence: 50, + Reasoning: "Failed to parse AI response, holding current state", + }} + } + + duration := time.Since(startTime).Milliseconds() + logger.Infof("⏱️ [Grid] AI call duration: %d ms, decisions: %d", duration, len(decisions)) + + // Extract chain of thought from response + cotTrace := extractCoTTrace(response) + + return &FullDecision{ + SystemPrompt: systemPrompt, + UserPrompt: userPrompt, + CoTTrace: cotTrace, + Decisions: decisions, + RawResponse: response, + AIRequestDurationMs: duration, + Timestamp: time.Now(), + }, nil +} + +// parseGridDecisions parses AI response into grid decisions +func parseGridDecisions(response string, symbol string) ([]Decision, error) { + // Try to find JSON array in response + jsonStr := extractJSONArray(response) + if jsonStr == "" { + return nil, fmt.Errorf("no JSON array found in response") + } + + var decisions []Decision + if err := json.Unmarshal([]byte(jsonStr), &decisions); err != nil { + return nil, fmt.Errorf("failed to parse JSON: %w", err) + } + + // Validate and set default symbol + for i := range decisions { + if decisions[i].Symbol == "" { + decisions[i].Symbol = symbol + } + // Validate action + if !isValidGridAction(decisions[i].Action) { + logger.Warnf("Invalid grid action: %s", decisions[i].Action) + } + } + + return decisions, nil +} + +// extractJSONArray extracts JSON array from AI response +func extractJSONArray(response string) string { + // Try to find ```json code block first + matches := reJSONFence.FindStringSubmatch(response) + if len(matches) > 1 { + return matches[1] + } + + // Try to find raw JSON array + matches = reJSONArray.FindStringSubmatch(response) + if len(matches) > 0 { + return matches[0] + } + + return "" +} + +// isValidGridAction checks if action is a valid grid action +func isValidGridAction(action string) bool { + validActions := map[string]bool{ + "place_buy_limit": true, + "place_sell_limit": true, + "cancel_order": true, + "cancel_all_orders": true, + "pause_grid": true, + "resume_grid": true, + "adjust_grid": true, + "hold": true, + // Also support standard actions for compatibility + "open_long": true, + "open_short": true, + "close_long": true, + "close_short": true, + } + return validActions[action] +} + +// ============================================================================ +// Grid Context Builder Helpers +// ============================================================================ + +// BuildGridContextFromMarketData builds grid context from market data +func BuildGridContextFromMarketData(mktData *market.Data, config *store.GridStrategyConfig) *GridContext { + ctx := &GridContext{ + Symbol: config.Symbol, + CurrentTime: time.Now().Format("2006-01-02 15:04:05"), + CurrentPrice: mktData.CurrentPrice, + + // Grid config + GridCount: config.GridCount, + TotalInvestment: config.TotalInvestment, + Leverage: config.Leverage, + Distribution: config.Distribution, + + // Market data + PriceChange1h: mktData.PriceChange1h, + PriceChange4h: mktData.PriceChange4h, + FundingRate: mktData.FundingRate, + } + + // Extract indicators from timeframe data + if mktData.TimeframeData != nil { + if tf5m, ok := mktData.TimeframeData["5m"]; ok { + if len(tf5m.BOLLUpper) > 0 { + ctx.BollingerUpper = tf5m.BOLLUpper[len(tf5m.BOLLUpper)-1] + ctx.BollingerMiddle = tf5m.BOLLMiddle[len(tf5m.BOLLMiddle)-1] + ctx.BollingerLower = tf5m.BOLLLower[len(tf5m.BOLLLower)-1] + if ctx.BollingerMiddle > 0 { + ctx.BollingerWidth = (ctx.BollingerUpper - ctx.BollingerLower) / ctx.BollingerMiddle * 100 + } + } + ctx.ATR14 = tf5m.ATR14 + if len(tf5m.RSI14Values) > 0 { + ctx.RSI14 = tf5m.RSI14Values[len(tf5m.RSI14Values)-1] + } + } + } + + // Extract longer term context + if mktData.LongerTermContext != nil { + if ctx.ATR14 == 0 { + ctx.ATR14 = mktData.LongerTermContext.ATR14 + } + ctx.EMA50 = mktData.LongerTermContext.EMA50 + } + + ctx.EMA20 = mktData.CurrentEMA20 + ctx.MACD = mktData.CurrentMACD + + // Calculate EMA distance + if ctx.EMA50 > 0 { + ctx.EMADistance = (ctx.EMA20 - ctx.EMA50) / ctx.EMA50 * 100 + } + + return ctx +} + +// Helper function for max +func max(a, b int) int { + if a > b { + return a + } + return b +} diff --git a/manager/trader_manager.go b/manager/trader_manager.go index 8e39670e..4060a794 100644 --- a/manager/trader_manager.go +++ b/manager/trader_manager.go @@ -292,8 +292,8 @@ func (tm *TraderManager) getConcurrentTraderData(traders []*trader.AutoTrader) [ // Concurrently fetch data for each trader for i, t := range traders { go func(index int, trader *trader.AutoTrader) { - // Set timeout to 3 seconds for single trader - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + // Set timeout to 10 seconds for single trader (increased from 3s for DEX reliability) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() // Use channel for timeout control @@ -330,7 +330,7 @@ func (tm *TraderManager) getConcurrentTraderData(traders []*trader.AutoTrader) [ } case err := <-errorChan: // Failed to get account info - logger.Infof("⚠️ Failed to get account info for trader %s: %v", trader.GetID(), err) + logger.Infof("⚠️ Failed to get account info for trader %s (%s/%s): %v", trader.GetName(), trader.GetID(), trader.GetExchange(), err) traderData = map[string]interface{}{ "trader_id": trader.GetID(), "trader_name": trader.GetName(), @@ -347,7 +347,7 @@ func (tm *TraderManager) getConcurrentTraderData(traders []*trader.AutoTrader) [ } case <-ctx.Done(): // Timeout - logger.Infof("⏰ Timeout getting account info for trader %s", trader.GetID()) + logger.Infof("⏰ Timeout (10s) getting account info for trader %s (%s/%s)", trader.GetName(), trader.GetID(), trader.GetExchange()) traderData = map[string]interface{}{ "trader_id": trader.GetID(), "trader_name": trader.GetName(), diff --git a/market/data.go b/market/data.go index 1fa990c0..a993736a 100644 --- a/market/data.go +++ b/market/data.go @@ -1210,3 +1210,91 @@ func ExportCalculateATR(klines []Kline, period int) float64 { func ExportCalculateBOLL(klines []Kline, period int, multiplier float64) (upper, middle, lower float64) { return calculateBOLL(klines, period, multiplier) } + +// calculateDonchian calculates Donchian channel (highest high, lowest low) for given period +func calculateDonchian(klines []Kline, period int) (upper, lower float64) { + if len(klines) == 0 || period <= 0 { + return 0, 0 + } + + // Use all available klines if period > len(klines) + start := len(klines) - period + if start < 0 { + start = 0 + } + + upper = klines[start].High + lower = klines[start].Low + + for i := start + 1; i < len(klines); i++ { + if klines[i].High > upper { + upper = klines[i].High + } + if klines[i].Low < lower { + lower = klines[i].Low + } + } + + return upper, lower +} + +// ExportCalculateDonchian exports calculateDonchian for testing +func ExportCalculateDonchian(klines []Kline, period int) (float64, float64) { + return calculateDonchian(klines, period) +} + +// Box period constants (in 1h candles) +const ( + ShortBoxPeriod = 72 // 3 days of 1h candles + MidBoxPeriod = 240 // 10 days of 1h candles + LongBoxPeriod = 500 // ~21 days of 1h candles +) + +// calculateBoxData calculates multi-period box data from klines +func calculateBoxData(klines []Kline, currentPrice float64) *BoxData { + box := &BoxData{ + CurrentPrice: currentPrice, + } + + if len(klines) == 0 { + return box + } + + box.ShortUpper, box.ShortLower = calculateDonchian(klines, ShortBoxPeriod) + box.MidUpper, box.MidLower = calculateDonchian(klines, MidBoxPeriod) + box.LongUpper, box.LongLower = calculateDonchian(klines, LongBoxPeriod) + + return box +} + +// ExportCalculateBoxData exports calculateBoxData for testing +func ExportCalculateBoxData(klines []Kline, currentPrice float64) *BoxData { + return calculateBoxData(klines, currentPrice) +} + +// GetBoxData fetches 1h klines and calculates box data for a symbol +func GetBoxData(symbol string) (*BoxData, error) { + symbol = Normalize(symbol) + + // Fetch 500 1h klines + var klines []Kline + var err error + + if IsXyzDexAsset(symbol) { + klines, err = getKlinesFromHyperliquid(symbol, "1h", LongBoxPeriod) + } else { + klines, err = getKlinesFromCoinAnk(symbol, "1h", LongBoxPeriod) + } + + if err != nil { + return nil, fmt.Errorf("failed to get 1h klines: %w", err) + } + + if len(klines) == 0 { + return nil, fmt.Errorf("no kline data available") + } + + currentPrice := klines[len(klines)-1].Close + + return calculateBoxData(klines, currentPrice), nil +} diff --git a/market/data_test.go b/market/data_test.go index b20a336c..231c4806 100644 --- a/market/data_test.go +++ b/market/data_test.go @@ -500,3 +500,86 @@ func TestIsStaleData_EmptyKlines(t *testing.T) { t.Error("Expected false for empty klines, got true") } } + +func TestCalculateDonchian(t *testing.T) { + // Create test klines with known high/low values + klines := []Kline{ + {High: 100, Low: 90}, + {High: 105, Low: 88}, + {High: 102, Low: 92}, + {High: 108, Low: 85}, + {High: 103, Low: 91}, + } + + upper, lower := ExportCalculateDonchian(klines, 5) + + if upper != 108 { + t.Errorf("Expected upper = 108, got %v", upper) + } + if lower != 85 { + t.Errorf("Expected lower = 85, got %v", lower) + } +} + +func TestCalculateDonchian_PartialPeriod(t *testing.T) { + klines := []Kline{ + {High: 100, Low: 90}, + {High: 105, Low: 88}, + } + + upper, lower := ExportCalculateDonchian(klines, 10) + + // Should use all available klines when period > len(klines) + if upper != 105 { + t.Errorf("Expected upper = 105, got %v", upper) + } + if lower != 88 { + t.Errorf("Expected lower = 88, got %v", lower) + } +} + +func TestCalculateDonchian_InvalidPeriod(t *testing.T) { + klines := []Kline{ + {High: 100, Low: 90}, + } + + // Zero period should return (0, 0) + upper, lower := ExportCalculateDonchian(klines, 0) + if upper != 0 || lower != 0 { + t.Errorf("Expected (0, 0) for zero period, got (%v, %v)", upper, lower) + } + + // Negative period should return (0, 0) + upper, lower = ExportCalculateDonchian(klines, -1) + if upper != 0 || lower != 0 { + t.Errorf("Expected (0, 0) for negative period, got (%v, %v)", upper, lower) + } +} + +func TestCalculateBoxData(t *testing.T) { + // Create synthetic kline data + klines := make([]Kline, 500) + for i := 0; i < 500; i++ { + basePrice := 100.0 + klines[i] = Kline{ + High: basePrice + float64(i%10), + Low: basePrice - float64(i%10), + Close: basePrice, + } + } + + box := ExportCalculateBoxData(klines, 100.0) + + if box.ShortUpper == 0 || box.ShortLower == 0 { + t.Error("Short box should not be zero") + } + if box.MidUpper == 0 || box.MidLower == 0 { + t.Error("Mid box should not be zero") + } + if box.LongUpper == 0 || box.LongLower == 0 { + t.Error("Long box should not be zero") + } + if box.CurrentPrice != 100.0 { + t.Errorf("Expected CurrentPrice = 100.0, got %v", box.CurrentPrice) + } +} diff --git a/market/types.go b/market/types.go index 1ca71a68..7569c9f3 100644 --- a/market/types.go +++ b/market/types.go @@ -187,3 +187,42 @@ var config = Config{ }, UpdateInterval: 60, // 1 minute } + +// BoxData represents multi-period Donchian channel (box) data +type BoxData struct { + // Short-term box (72 1h candles = 3 days) + ShortUpper float64 `json:"short_upper"` + ShortLower float64 `json:"short_lower"` + + // Mid-term box (240 1h candles = 10 days) + MidUpper float64 `json:"mid_upper"` + MidLower float64 `json:"mid_lower"` + + // Long-term box (500 1h candles = ~21 days) + LongUpper float64 `json:"long_upper"` + LongLower float64 `json:"long_lower"` + + // Current price position relative to boxes + CurrentPrice float64 `json:"current_price"` +} + +// RegimeLevel represents the ranging classification level +type RegimeLevel string + +const ( + RegimeLevelNarrow RegimeLevel = "narrow" // 窄幅震荡 + RegimeLevelStandard RegimeLevel = "standard" // 标准震荡 + RegimeLevelWide RegimeLevel = "wide" // 宽幅震荡 + RegimeLevelVolatile RegimeLevel = "volatile" // 剧烈震荡 + RegimeLevelTrending RegimeLevel = "trending" // 趋势 +) + +// BreakoutLevel represents which box level has been broken +type BreakoutLevel string + +const ( + BreakoutNone BreakoutLevel = "none" + BreakoutShort BreakoutLevel = "short" + BreakoutMid BreakoutLevel = "mid" + BreakoutLong BreakoutLevel = "long" +) diff --git a/scripts/test_lighter_orders.go b/scripts/test_lighter_orders.go new file mode 100644 index 00000000..e064cac2 --- /dev/null +++ b/scripts/test_lighter_orders.go @@ -0,0 +1,168 @@ +//go:build ignore + +// Test script to verify Lighter API authentication +// Run: go run scripts/test_lighter_orders.go +package main + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "time" + + lighterClient "github.com/elliottech/lighter-go/client" + lighterHTTP "github.com/elliottech/lighter-go/client/http" +) + +func main() { + // Configuration - update these values + walletAddr := os.Getenv("LIGHTER_WALLET") + apiKeyPrivateKey := os.Getenv("LIGHTER_API_KEY") + + if walletAddr == "" || apiKeyPrivateKey == "" { + fmt.Println("Usage: LIGHTER_WALLET=0x... LIGHTER_API_KEY=... go run scripts/test_lighter_orders.go") + fmt.Println("Environment variables required:") + fmt.Println(" LIGHTER_WALLET - Ethereum wallet address") + fmt.Println(" LIGHTER_API_KEY - API key private key (40 bytes hex)") + os.Exit(1) + } + + fmt.Println("=== Lighter API Test ===") + fmt.Printf("Wallet: %s\n\n", walletAddr) + + baseURL := "https://mainnet.zklighter.elliot.ai" + chainID := uint32(304) + client := &http.Client{Timeout: 30 * time.Second} + + // Step 1: Get account info (no auth required) + fmt.Println("1. Getting account info...") + accountIndex, err := getAccountIndex(client, baseURL, walletAddr) + if err != nil { + fmt.Printf(" FAILED: %v\n", err) + os.Exit(1) + } + fmt.Printf(" OK: account_index = %d\n\n", accountIndex) + + // Step 2: Create TxClient and generate auth token + fmt.Println("2. Creating TxClient and generating auth token...") + httpClient := lighterHTTP.NewClient(baseURL) + txClient, err := lighterClient.NewTxClient(httpClient, apiKeyPrivateKey, accountIndex, 0, chainID) + if err != nil { + fmt.Printf(" FAILED: %v\n", err) + os.Exit(1) + } + + authToken, err := txClient.GetAuthToken(time.Now().Add(1 * time.Hour)) + if err != nil { + fmt.Printf(" FAILED: %v\n", err) + os.Exit(1) + } + fmt.Printf(" OK: auth token generated\n\n") + + // Step 3: Test GetActiveOrders with auth query parameter (NEW method) + fmt.Println("3. Testing GetActiveOrders with auth query parameter (FIXED)...") + encodedAuth := url.QueryEscape(authToken) + endpoint := fmt.Sprintf("%s/api/v1/accountActiveOrders?account_index=%d&market_id=0&auth=%s", + baseURL, accountIndex, encodedAuth) + + resp, err := client.Get(endpoint) + if err != nil { + fmt.Printf(" FAILED: %v\n", err) + os.Exit(1) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + var result map[string]interface{} + json.Unmarshal(body, &result) + + if code, ok := result["code"].(float64); ok && code == 200 { + orders := result["orders"].([]interface{}) + fmt.Printf(" OK: Retrieved %d orders\n", len(orders)) + if len(orders) > 0 { + fmt.Println(" Sample orders:") + for i, o := range orders { + if i >= 3 { + fmt.Printf(" ... and %d more\n", len(orders)-3) + break + } + order := o.(map[string]interface{}) + fmt.Printf(" - ID: %v, Price: %v, Side: %v\n", + order["order_id"], order["price"], order["is_ask"]) + } + } + } else { + fmt.Printf(" FAILED: %s\n", string(body)) + fmt.Println("\n Possible causes:") + fmt.Println(" - API key not registered on-chain") + fmt.Println(" - API key private key incorrect") + fmt.Println(" - Account index mismatch") + os.Exit(1) + } + + // Step 4: Test GetActiveOrders with Authorization header (OLD method - for comparison) + fmt.Println("\n4. Testing GetActiveOrders with Authorization header (OLD method)...") + endpoint2 := fmt.Sprintf("%s/api/v1/accountActiveOrders?account_index=%d&market_id=0", + baseURL, accountIndex) + + req, _ := http.NewRequest("GET", endpoint2, nil) + req.Header.Set("Authorization", authToken) + req.Header.Set("Content-Type", "application/json") + + resp2, err := client.Do(req) + if err != nil { + fmt.Printf(" FAILED: %v\n", err) + } else { + defer resp2.Body.Close() + body2, _ := io.ReadAll(resp2.Body) + var result2 map[string]interface{} + json.Unmarshal(body2, &result2) + + if code, ok := result2["code"].(float64); ok && code == 200 { + orders := result2["orders"].([]interface{}) + fmt.Printf(" OK: Retrieved %d orders (both methods work!)\n", len(orders)) + } else { + fmt.Printf(" FAILED: %s\n", string(body2)) + fmt.Println(" ^ This is expected - Authorization header doesn't work consistently") + } + } + + fmt.Println("\n=== TEST COMPLETE ===") + fmt.Println("If test 3 passed, the fix is working correctly.") +} + +func getAccountIndex(client *http.Client, baseURL, walletAddr string) (int64, error) { + endpoint := fmt.Sprintf("%s/api/v1/account?by=l1_address&value=%s", baseURL, walletAddr) + resp, err := client.Get(endpoint) + if err != nil { + return 0, err + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + var result struct { + Code int `json:"code"` + Accounts []struct { + AccountIndex int64 `json:"account_index"` + } `json:"accounts"` + SubAccounts []struct { + AccountIndex int64 `json:"account_index"` + } `json:"sub_accounts"` + } + + if err := json.Unmarshal(body, &result); err != nil { + return 0, fmt.Errorf("failed to parse: %w", err) + } + + if len(result.Accounts) > 0 { + return result.Accounts[0].AccountIndex, nil + } + if len(result.SubAccounts) > 0 { + return result.SubAccounts[0].AccountIndex, nil + } + + return 0, fmt.Errorf("no account found") +} diff --git a/store/grid.go b/store/grid.go new file mode 100644 index 00000000..49ce5708 --- /dev/null +++ b/store/grid.go @@ -0,0 +1,585 @@ +package store + +import ( + "fmt" + "time" + + "gorm.io/gorm" +) + +// ==================== Grid Store Models ==================== +// These models mirror the grid package types but are defined here +// to avoid import cycles between store and grid packages. + +// GridConfigModel GORM model for grid_configs table +type GridConfigModel struct { + ID string `json:"id" gorm:"primaryKey"` + UserID string `json:"user_id" gorm:"index"` + TraderID string `json:"trader_id" gorm:"index"` + Symbol string `json:"symbol" gorm:"not null"` + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` + + GridCount int `json:"grid_count" gorm:"default:10"` + TotalInvestment float64 `json:"total_investment" gorm:"not null"` + Leverage int `json:"leverage" gorm:"default:5"` + UpperPrice float64 `json:"upper_price"` + LowerPrice float64 `json:"lower_price"` + UseATRBounds bool `json:"use_atr_bounds" gorm:"default:true"` + ATRMultiplier float64 `json:"atr_multiplier" gorm:"default:2.0"` + Distribution string `json:"distribution" gorm:"default:gaussian"` + + MaxDrawdownPct float64 `json:"max_drawdown_pct" gorm:"default:15.0"` + StopLossPct float64 `json:"stop_loss_pct" gorm:"default:5.0"` + DailyLossLimitPct float64 `json:"daily_loss_limit_pct" gorm:"default:10"` + MaxPositionSizePct float64 `json:"max_position_size_pct" gorm:"default:30"` + + RegimeCheckInterval int `json:"regime_check_interval" gorm:"default:30"` + AutoPauseOnTrend bool `json:"auto_pause_on_trend" gorm:"default:true"` + MinRangingScore int `json:"min_ranging_score" gorm:"default:60"` + TrendResumeThreshold int `json:"trend_resume_threshold" gorm:"default:70"` + + // Box indicator periods (1h candles) + ShortBoxPeriod int `json:"short_box_period" gorm:"default:72"` // 3 days + MidBoxPeriod int `json:"mid_box_period" gorm:"default:240"` // 10 days + LongBoxPeriod int `json:"long_box_period" gorm:"default:500"` // 21 days + + // Effective leverage limits by regime level + NarrowRegimeLeverage int `json:"narrow_regime_leverage" gorm:"default:2"` + StandardRegimeLeverage int `json:"standard_regime_leverage" gorm:"default:4"` + WideRegimeLeverage int `json:"wide_regime_leverage" gorm:"default:3"` + VolatileRegimeLeverage int `json:"volatile_regime_leverage" gorm:"default:2"` + + // Position limits by regime level (percentage of total investment) + NarrowRegimePositionPct float64 `json:"narrow_regime_position_pct" gorm:"default:40"` + StandardRegimePositionPct float64 `json:"standard_regime_position_pct" gorm:"default:70"` + WideRegimePositionPct float64 `json:"wide_regime_position_pct" gorm:"default:60"` + VolatileRegimePositionPct float64 `json:"volatile_regime_position_pct" gorm:"default:40"` + + OrderRefreshSec int `json:"order_refresh_sec" gorm:"default:300"` + UseMakerOnly bool `json:"use_maker_only" gorm:"default:true"` + SlippageTolerPct float64 `json:"slippage_toler_pct" gorm:"default:0.1"` + + AIProvider string `json:"ai_provider" gorm:"default:deepseek"` + AIModel string `json:"ai_model" gorm:"default:deepseek-chat"` + IsActive bool `json:"is_active" gorm:"default:false"` +} + +func (GridConfigModel) TableName() string { + return "grid_configs" +} + +// GridInstanceModel GORM model for grid_instances table +type GridInstanceModel struct { + ID string `json:"id" gorm:"primaryKey"` + ConfigID string `json:"config_id" gorm:"index;not null"` + Symbol string `json:"symbol" gorm:"not null"` + State string `json:"state" gorm:"not null"` + StartedAt time.Time `json:"started_at"` + StoppedAt *time.Time `json:"stopped_at,omitempty"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` + + CurrentUpperPrice float64 `json:"current_upper_price"` + CurrentLowerPrice float64 `json:"current_lower_price"` + CurrentGridSpacing float64 `json:"current_grid_spacing"` + ActiveLevelCount int `json:"active_level_count"` + CurrentRegime string `json:"current_regime"` + RegimeScore int `json:"regime_score"` + LastRegimeCheck time.Time `json:"last_regime_check"` + ConsecutiveTrending int `json:"consecutive_trending"` + + // Current regime level (narrow/standard/wide/volatile/trending) + CurrentRegimeLevel string `json:"current_regime_level" gorm:"default:standard"` + + // Box state + ShortBoxUpper float64 `json:"short_box_upper"` + ShortBoxLower float64 `json:"short_box_lower"` + MidBoxUpper float64 `json:"mid_box_upper"` + MidBoxLower float64 `json:"mid_box_lower"` + LongBoxUpper float64 `json:"long_box_upper"` + LongBoxLower float64 `json:"long_box_lower"` + + // Breakout state + BreakoutLevel string `json:"breakout_level" gorm:"default:none"` // none/short/mid/long + BreakoutDirection string `json:"breakout_direction"` // up/down + BreakoutConfirmCount int `json:"breakout_confirm_count" gorm:"default:0"` + BreakoutStartTime time.Time `json:"breakout_start_time"` + + // Position adjustment due to breakout + PositionReductionPct float64 `json:"position_reduction_pct" gorm:"default:0"` // 0 = normal, 50 = reduced + + TotalProfit float64 `json:"total_profit" gorm:"default:0"` + TotalFees float64 `json:"total_fees" gorm:"default:0"` + TotalTrades int `json:"total_trades" gorm:"default:0"` + WinningTrades int `json:"winning_trades" gorm:"default:0"` + MaxDrawdown float64 `json:"max_drawdown" gorm:"default:0"` + CurrentDrawdown float64 `json:"current_drawdown" gorm:"default:0"` + PeakEquity float64 `json:"peak_equity" gorm:"default:0"` + DailyProfit float64 `json:"daily_profit" gorm:"default:0"` + DailyLoss float64 `json:"daily_loss" gorm:"default:0"` + LastDailyReset time.Time `json:"last_daily_reset"` +} + +func (GridInstanceModel) TableName() string { + return "grid_instances" +} + +// GridLevelModel GORM model for grid_levels table +type GridLevelModel struct { + ID string `json:"id" gorm:"primaryKey"` + InstanceID string `json:"instance_id" gorm:"index;not null"` + LevelIndex int `json:"level_index" gorm:"not null"` + Price float64 `json:"price" gorm:"not null"` + State string `json:"state" gorm:"not null"` + Side string `json:"side"` + OrderID string `json:"order_id,omitempty"` + OrderPrice float64 `json:"order_price,omitempty"` + OrderQuantity float64 `json:"order_quantity,omitempty"` + OrderCreatedAt *time.Time `json:"order_created_at,omitempty"` + PositionSize float64 `json:"position_size,omitempty"` + PositionEntry float64 `json:"position_entry,omitempty"` + PositionOpenAt *time.Time `json:"position_open_at,omitempty"` + AllocationWeight float64 `json:"allocation_weight"` + AllocatedUSD float64 `json:"allocated_usd"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` +} + +func (GridLevelModel) TableName() string { + return "grid_levels" +} + +// GridEventModel GORM model for grid_events table +type GridEventModel struct { + ID string `json:"id" gorm:"primaryKey"` + InstanceID string `json:"instance_id" gorm:"index;not null"` + LevelID string `json:"level_id,omitempty" gorm:"index"` + EventType string `json:"event_type" gorm:"not null"` + EventTime time.Time `json:"event_time" gorm:"autoCreateTime"` + Price float64 `json:"price,omitempty"` + Quantity float64 `json:"quantity,omitempty"` + Side string `json:"side,omitempty"` + PnL float64 `json:"pnl,omitempty"` + Fee float64 `json:"fee,omitempty"` + Message string `json:"message,omitempty"` + OldRegime string `json:"old_regime,omitempty"` + NewRegime string `json:"new_regime,omitempty"` + TriggerType string `json:"trigger_type,omitempty"` + RawData string `json:"raw_data,omitempty" gorm:"type:text"` +} + +func (GridEventModel) TableName() string { + return "grid_events" +} + +// GridRegimeAssessmentModel GORM model for grid_regime_assessments table +type GridRegimeAssessmentModel struct { + ID string `json:"id" gorm:"primaryKey"` + InstanceID string `json:"instance_id" gorm:"index;not null"` + AssessedAt time.Time `json:"assessed_at" gorm:"autoCreateTime"` + Regime string `json:"regime" gorm:"not null"` + Score int `json:"score" gorm:"not null"` + Confidence float64 `json:"confidence"` + BollingerSignal int `json:"bollinger_signal"` + EMASignal int `json:"ema_signal"` + MACDSignal int `json:"macd_signal"` + VolumeSignal int `json:"volume_signal"` + OISignal int `json:"oi_signal"` + FundingSignal int `json:"funding_signal"` + CandleSignal int `json:"candle_signal"` + ATR14 float64 `json:"atr14"` + BollingerWidth float64 `json:"bollinger_width"` + EMADistance float64 `json:"ema_distance"` + CurrentPrice float64 `json:"current_price"` + AIReasoning string `json:"ai_reasoning" gorm:"type:text"` +} + +func (GridRegimeAssessmentModel) TableName() string { + return "grid_regime_assessments" +} + +// ==================== Grid Store ==================== + +// GridStore provides database operations for grid trading +type GridStore struct { + db *gorm.DB +} + +// NewGridStore creates a new grid store +func NewGridStore(db *gorm.DB) *GridStore { + return &GridStore{db: db} +} + +// InitTables initializes grid-related tables +func (s *GridStore) InitTables() error { + // For PostgreSQL with existing tables, skip AutoMigrate to avoid type conflicts + if s.db.Dialector.Name() == "postgres" { + var tableExists int64 + s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'grid_configs'`).Scan(&tableExists) + + if tableExists > 0 { + // Tables exist, just ensure indexes + s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_grid_configs_user_id ON grid_configs(user_id)`) + s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_grid_configs_trader_id ON grid_configs(trader_id)`) + s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_grid_instances_config_id ON grid_instances(config_id)`) + s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_grid_levels_instance_id ON grid_levels(instance_id)`) + s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_grid_events_instance_id ON grid_events(instance_id)`) + s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_grid_events_level_id ON grid_events(level_id)`) + s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_grid_regime_assessments_instance_id ON grid_regime_assessments(instance_id)`) + return nil + } + } + + // AutoMigrate all grid tables + if err := s.db.AutoMigrate( + &GridConfigModel{}, + &GridInstanceModel{}, + &GridLevelModel{}, + &GridEventModel{}, + &GridRegimeAssessmentModel{}, + ); err != nil { + return fmt.Errorf("failed to migrate grid tables: %w", err) + } + + return nil +} + +// ==================== Config Operations ==================== + +// SaveGridConfig saves or updates a grid configuration +func (s *GridStore) SaveGridConfig(config *GridConfigModel) error { + config.UpdatedAt = time.Now() + if config.CreatedAt.IsZero() { + config.CreatedAt = time.Now() + } + return s.db.Save(config).Error +} + +// LoadGridConfig loads a grid configuration by ID +func (s *GridStore) LoadGridConfig(id string) (*GridConfigModel, error) { + var config GridConfigModel + err := s.db.Where("id = ?", id).First(&config).Error + if err != nil { + return nil, err + } + return &config, nil +} + +// LoadGridConfigByTrader loads a grid configuration by trader ID +func (s *GridStore) LoadGridConfigByTrader(traderID string) (*GridConfigModel, error) { + var config GridConfigModel + err := s.db.Where("trader_id = ? AND is_active = true", traderID).First(&config).Error + if err != nil { + return nil, err + } + return &config, nil +} + +// ListGridConfigs lists all grid configurations for a user +func (s *GridStore) ListGridConfigs(userID string) ([]GridConfigModel, error) { + var configs []GridConfigModel + err := s.db.Where("user_id = ?", userID).Order("created_at DESC").Find(&configs).Error + if err != nil { + return nil, err + } + return configs, nil +} + +// DeleteGridConfig deletes a grid configuration and all related data +func (s *GridStore) DeleteGridConfig(id string) error { + return s.db.Transaction(func(tx *gorm.DB) error { + // Get all instances for this config + var instances []GridInstanceModel + if err := tx.Where("config_id = ?", id).Find(&instances).Error; err != nil { + return err + } + + // Delete related data for each instance + for _, instance := range instances { + if err := tx.Where("instance_id = ?", instance.ID).Delete(&GridLevelModel{}).Error; err != nil { + return err + } + if err := tx.Where("instance_id = ?", instance.ID).Delete(&GridEventModel{}).Error; err != nil { + return err + } + if err := tx.Where("instance_id = ?", instance.ID).Delete(&GridRegimeAssessmentModel{}).Error; err != nil { + return err + } + } + + // Delete instances + if err := tx.Where("config_id = ?", id).Delete(&GridInstanceModel{}).Error; err != nil { + return err + } + + // Delete config + return tx.Where("id = ?", id).Delete(&GridConfigModel{}).Error + }) +} + +// ==================== Instance Operations ==================== + +// SaveGridInstance saves or updates a grid instance +func (s *GridStore) SaveGridInstance(instance *GridInstanceModel) error { + instance.UpdatedAt = time.Now() + return s.db.Save(instance).Error +} + +// LoadGridInstance loads a grid instance by config ID +func (s *GridStore) LoadGridInstance(configID string) (*GridInstanceModel, error) { + var instance GridInstanceModel + err := s.db.Where("config_id = ?", configID). + Order("started_at DESC"). + First(&instance).Error + if err != nil { + return nil, err + } + return &instance, nil +} + +// LoadGridInstanceByID loads a grid instance by ID +func (s *GridStore) LoadGridInstanceByID(id string) (*GridInstanceModel, error) { + var instance GridInstanceModel + err := s.db.Where("id = ?", id).First(&instance).Error + if err != nil { + return nil, err + } + return &instance, nil +} + +// ListGridInstances lists all instances for a config +func (s *GridStore) ListGridInstances(configID string) ([]GridInstanceModel, error) { + var instances []GridInstanceModel + err := s.db.Where("config_id = ?", configID). + Order("started_at DESC"). + Find(&instances).Error + if err != nil { + return nil, err + } + return instances, nil +} + +// ==================== Level Operations ==================== + +// SaveGridLevel saves or updates a grid level +func (s *GridStore) SaveGridLevel(level *GridLevelModel) error { + level.UpdatedAt = time.Now() + return s.db.Save(level).Error +} + +// SaveGridLevels saves multiple grid levels +func (s *GridStore) SaveGridLevels(levels []GridLevelModel) error { + if len(levels) == 0 { + return nil + } + now := time.Now() + for i := range levels { + levels[i].UpdatedAt = now + } + return s.db.Save(&levels).Error +} + +// LoadGridLevels loads all levels for an instance +func (s *GridStore) LoadGridLevels(instanceID string) ([]GridLevelModel, error) { + var levels []GridLevelModel + err := s.db.Where("instance_id = ?", instanceID). + Order("level_index ASC"). + Find(&levels).Error + if err != nil { + return nil, err + } + return levels, nil +} + +// DeleteGridLevels deletes all levels for an instance +func (s *GridStore) DeleteGridLevels(instanceID string) error { + return s.db.Where("instance_id = ?", instanceID).Delete(&GridLevelModel{}).Error +} + +// ==================== Event Operations ==================== + +// SaveGridEvent saves a grid event +func (s *GridStore) SaveGridEvent(event *GridEventModel) error { + if event.EventTime.IsZero() { + event.EventTime = time.Now() + } + return s.db.Create(event).Error +} + +// LoadRecentGridEvents loads recent events for an instance +func (s *GridStore) LoadRecentGridEvents(instanceID string, limit int) ([]GridEventModel, error) { + var events []GridEventModel + query := s.db.Where("instance_id = ?", instanceID). + Order("event_time DESC") + if limit > 0 { + query = query.Limit(limit) + } + err := query.Find(&events).Error + if err != nil { + return nil, err + } + return events, nil +} + +// LoadGridEventsByType loads events of a specific type +func (s *GridStore) LoadGridEventsByType(instanceID, eventType string, limit int) ([]GridEventModel, error) { + var events []GridEventModel + query := s.db.Where("instance_id = ? AND event_type = ?", instanceID, eventType). + Order("event_time DESC") + if limit > 0 { + query = query.Limit(limit) + } + err := query.Find(&events).Error + if err != nil { + return nil, err + } + return events, nil +} + +// CountGridEvents counts events for an instance +func (s *GridStore) CountGridEvents(instanceID string) (int64, error) { + var count int64 + err := s.db.Model(&GridEventModel{}). + Where("instance_id = ?", instanceID). + Count(&count).Error + return count, err +} + +// ==================== Regime Assessment Operations ==================== + +// SaveGridRegimeAssessment saves a regime assessment +func (s *GridStore) SaveGridRegimeAssessment(assessment *GridRegimeAssessmentModel) error { + if assessment.AssessedAt.IsZero() { + assessment.AssessedAt = time.Now() + } + return s.db.Create(assessment).Error +} + +// LoadLatestGridRegime loads the latest regime assessment +func (s *GridStore) LoadLatestGridRegime(instanceID string) (*GridRegimeAssessmentModel, error) { + var assessment GridRegimeAssessmentModel + err := s.db.Where("instance_id = ?", instanceID). + Order("assessed_at DESC"). + First(&assessment).Error + if err != nil { + return nil, err + } + return &assessment, nil +} + +// LoadGridRegimeHistory loads regime assessment history +func (s *GridStore) LoadGridRegimeHistory(instanceID string, limit int) ([]GridRegimeAssessmentModel, error) { + var assessments []GridRegimeAssessmentModel + query := s.db.Where("instance_id = ?", instanceID). + Order("assessed_at DESC") + if limit > 0 { + query = query.Limit(limit) + } + err := query.Find(&assessments).Error + if err != nil { + return nil, err + } + return assessments, nil +} + +// ==================== Statistics Operations ==================== + +// GetGridInstanceStatistics returns statistics for an instance +func (s *GridStore) GetGridInstanceStatistics(instanceID string) (map[string]interface{}, error) { + var instance GridInstanceModel + if err := s.db.Where("id = ?", instanceID).First(&instance).Error; err != nil { + return nil, err + } + + // Count events by type + var eventCounts []struct { + EventType string + Count int64 + } + s.db.Model(&GridEventModel{}). + Select("event_type, count(*) as count"). + Where("instance_id = ?", instanceID). + Group("event_type"). + Find(&eventCounts) + + eventCountMap := make(map[string]int64) + for _, ec := range eventCounts { + eventCountMap[ec.EventType] = ec.Count + } + + // Get latest regime + var latestRegime GridRegimeAssessmentModel + s.db.Where("instance_id = ?", instanceID). + Order("assessed_at DESC"). + First(&latestRegime) + + winRate := 0.0 + if instance.TotalTrades > 0 { + winRate = float64(instance.WinningTrades) / float64(instance.TotalTrades) * 100 + } + + return map[string]interface{}{ + "instance_id": instance.ID, + "state": instance.State, + "started_at": instance.StartedAt, + "stopped_at": instance.StoppedAt, + "total_profit": instance.TotalProfit, + "total_fees": instance.TotalFees, + "total_trades": instance.TotalTrades, + "winning_trades": instance.WinningTrades, + "win_rate": winRate, + "max_drawdown": instance.MaxDrawdown, + "current_drawdown": instance.CurrentDrawdown, + "peak_equity": instance.PeakEquity, + "active_level_count": instance.ActiveLevelCount, + "current_regime": instance.CurrentRegime, + "regime_score": instance.RegimeScore, + "event_counts": eventCountMap, + "latest_regime_score": latestRegime.Score, + }, nil +} + +// GetGridPerformanceMetrics returns performance metrics for a time period +func (s *GridStore) GetGridPerformanceMetrics(instanceID string, from, to time.Time) (map[string]interface{}, error) { + // Count trades in period + var tradeCounts struct { + TotalFills int64 + BuyFills int64 + SellFills int64 + } + s.db.Model(&GridEventModel{}). + Select("count(*) as total_fills, "+ + "sum(case when side = 'buy' then 1 else 0 end) as buy_fills, "+ + "sum(case when side = 'sell' then 1 else 0 end) as sell_fills"). + Where("instance_id = ? AND event_type = 'order_filled' AND event_time BETWEEN ? AND ?", + instanceID, from, to). + Scan(&tradeCounts) + + // Sum profit/loss + var pnlSum struct { + TotalPnL float64 + TotalFee float64 + } + s.db.Model(&GridEventModel{}). + Select("coalesce(sum(pnl), 0) as total_pnl, coalesce(sum(fee), 0) as total_fee"). + Where("instance_id = ? AND event_time BETWEEN ? AND ?", instanceID, from, to). + Scan(&pnlSum) + + // Count regime changes + var regimeChanges int64 + s.db.Model(&GridEventModel{}). + Where("instance_id = ? AND event_type = 'regime_change' AND event_time BETWEEN ? AND ?", + instanceID, from, to). + Count(®imeChanges) + + return map[string]interface{}{ + "period_start": from, + "period_end": to, + "total_fills": tradeCounts.TotalFills, + "buy_fills": tradeCounts.BuyFills, + "sell_fills": tradeCounts.SellFills, + "total_pnl": pnlSum.TotalPnL, + "total_fees": pnlSum.TotalFee, + "net_pnl": pnlSum.TotalPnL - pnlSum.TotalFee, + "regime_changes": regimeChanges, + }, nil +} diff --git a/store/store.go b/store/store.go index 21c15813..8119b935 100644 --- a/store/store.go +++ b/store/store.go @@ -28,6 +28,7 @@ type Store struct { strategy *StrategyStore equity *EquityStore order *OrderStore + grid *GridStore mu sync.RWMutex } @@ -156,6 +157,9 @@ func (s *Store) initTables() error { if err := s.Order().InitTables(); err != nil { return fmt.Errorf("failed to initialize order tables: %w", err) } + if err := s.Grid().InitTables(); err != nil { + return fmt.Errorf("failed to initialize grid tables: %w", err) + } return nil } @@ -279,6 +283,16 @@ func (s *Store) Order() *OrderStore { return s.order } +// Grid gets grid trading storage +func (s *Store) Grid() *GridStore { + s.mu.Lock() + defer s.mu.Unlock() + if s.grid == nil { + s.grid = NewGridStore(s.gdb) + } + return s.grid +} + // Close closes database connection func (s *Store) Close() error { if s.driver != nil { diff --git a/store/strategy.go b/store/strategy.go index 1b3f5b11..be009851 100644 --- a/store/strategy.go +++ b/store/strategy.go @@ -32,6 +32,9 @@ func (Strategy) TableName() string { return "strategies" } // StrategyConfig strategy configuration details (JSON structure) type StrategyConfig struct { + // Strategy type: "ai_trading" (default) or "grid_trading" + StrategyType string `json:"strategy_type,omitempty"` + // language setting: "zh" for Chinese, "en" for English // This determines the language used for data formatting and prompt generation Language string `json:"language,omitempty"` @@ -45,6 +48,39 @@ type StrategyConfig struct { RiskControl RiskControlConfig `json:"risk_control"` // editable sections of System Prompt PromptSections PromptSectionsConfig `json:"prompt_sections,omitempty"` + + // Grid trading configuration (only used when StrategyType == "grid_trading") + GridConfig *GridStrategyConfig `json:"grid_config,omitempty"` +} + +// GridStrategyConfig grid trading specific configuration +type GridStrategyConfig struct { + // Trading pair (e.g., "BTCUSDT") + Symbol string `json:"symbol"` + // Number of grid levels (5-50) + GridCount int `json:"grid_count"` + // Total investment in USDT + TotalInvestment float64 `json:"total_investment"` + // Leverage (1-20) + Leverage int `json:"leverage"` + // Upper price boundary (0 = auto-calculate from ATR) + UpperPrice float64 `json:"upper_price"` + // Lower price boundary (0 = auto-calculate from ATR) + LowerPrice float64 `json:"lower_price"` + // Use ATR to auto-calculate bounds + UseATRBounds bool `json:"use_atr_bounds"` + // ATR multiplier for bound calculation (default 2.0) + ATRMultiplier float64 `json:"atr_multiplier"` + // Position distribution: "uniform" | "gaussian" | "pyramid" + Distribution string `json:"distribution"` + // Maximum drawdown percentage before emergency exit + MaxDrawdownPct float64 `json:"max_drawdown_pct"` + // Stop loss percentage per position + StopLossPct float64 `json:"stop_loss_pct"` + // Daily loss limit percentage + DailyLossLimitPct float64 `json:"daily_loss_limit_pct"` + // Use maker-only orders for lower fees + UseMakerOnly bool `json:"use_maker_only"` } // PromptSectionsConfig editable sections of System Prompt diff --git a/store/trader.go b/store/trader.go index ebe93e55..8b983baa 100644 --- a/store/trader.go +++ b/store/trader.go @@ -248,3 +248,23 @@ func (s *TraderStore) ListAll() ([]*Trader, error) { } return traders, nil } + +// ListByExchangeID gets traders that use a specific exchange +func (s *TraderStore) ListByExchangeID(userID, exchangeID string) ([]*Trader, error) { + var traders []*Trader + err := s.db.Where("user_id = ? AND exchange_id = ?", userID, exchangeID).Find(&traders).Error + if err != nil { + return nil, err + } + return traders, nil +} + +// ListByAIModelID gets traders that use a specific AI model +func (s *TraderStore) ListByAIModelID(userID, aiModelID string) ([]*Trader, error) { + var traders []*Trader + err := s.db.Where("user_id = ? AND ai_model_id = ?", userID, aiModelID).Find(&traders).Error + if err != nil { + return nil, err + } + return traders, nil +} diff --git a/trader/aster_trader.go b/trader/aster_trader.go index 2c1bbe7d..1d9a547b 100644 --- a/trader/aster_trader.go +++ b/trader/aster_trader.go @@ -1417,6 +1417,191 @@ func (t *AsterTrader) GetTrades(startTime time.Time, limit int) ([]TradeRecord, // GetOpenOrders gets all open/pending orders for a symbol func (t *AsterTrader) GetOpenOrders(symbol string) ([]OpenOrder, error) { - // TODO: Implement Aster open orders - return []OpenOrder{}, nil + params := map[string]interface{}{ + "symbol": symbol, + } + + body, err := t.request("GET", "/fapi/v3/openOrders", params) + if err != nil { + return nil, fmt.Errorf("failed to get open orders: %w", err) + } + + var orders []struct { + OrderID int64 `json:"orderId"` + Symbol string `json:"symbol"` + Side string `json:"side"` + PositionSide string `json:"positionSide"` + Type string `json:"type"` + Price string `json:"price"` + StopPrice string `json:"stopPrice"` + OrigQty string `json:"origQty"` + Status string `json:"status"` + } + + if err := json.Unmarshal(body, &orders); err != nil { + return nil, fmt.Errorf("failed to parse open orders: %w", err) + } + + var result []OpenOrder + for _, order := range orders { + price, _ := strconv.ParseFloat(order.Price, 64) + stopPrice, _ := strconv.ParseFloat(order.StopPrice, 64) + quantity, _ := strconv.ParseFloat(order.OrigQty, 64) + + result = append(result, OpenOrder{ + OrderID: fmt.Sprintf("%d", order.OrderID), + Symbol: order.Symbol, + Side: order.Side, + PositionSide: order.PositionSide, + Type: order.Type, + Price: price, + StopPrice: stopPrice, + Quantity: quantity, + Status: order.Status, + }) + } + + logger.Infof("✓ ASTER GetOpenOrders: found %d open orders for %s", len(result), symbol) + return result, nil +} + +// PlaceLimitOrder places a limit order for grid trading +func (t *AsterTrader) PlaceLimitOrder(req *LimitOrderRequest) (*LimitOrderResult, error) { + // Format price and quantity to correct precision + formattedPrice, err := t.formatPrice(req.Symbol, req.Price) + if err != nil { + return nil, fmt.Errorf("failed to format price: %w", err) + } + formattedQty, err := t.formatQuantity(req.Symbol, req.Quantity) + if err != nil { + return nil, fmt.Errorf("failed to format quantity: %w", err) + } + + // Get precision information + prec, err := t.getPrecision(req.Symbol) + if err != nil { + return nil, fmt.Errorf("failed to get precision: %w", err) + } + + // Convert to string with correct precision format + priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision) + qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision) + + // Determine side + side := "BUY" + if req.Side == "SELL" || req.Side == "Sell" || req.Side == "sell" { + side = "SELL" + } + + params := map[string]interface{}{ + "symbol": req.Symbol, + "positionSide": "BOTH", + "type": "LIMIT", + "side": side, + "timeInForce": "GTC", + "quantity": qtyStr, + "price": priceStr, + } + + // Add reduceOnly if specified + if req.ReduceOnly { + params["reduceOnly"] = "true" + } + + body, err := t.request("POST", "/fapi/v3/order", params) + if err != nil { + return nil, fmt.Errorf("failed to place limit order: %w", err) + } + + var result map[string]interface{} + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse order response: %w", err) + } + + // Extract order ID + orderID := "" + if id, ok := result["orderId"].(float64); ok { + orderID = fmt.Sprintf("%.0f", id) + } else if id, ok := result["orderId"].(string); ok { + orderID = id + } + + // Extract client order ID + clientOrderID := "" + if cid, ok := result["clientOrderId"].(string); ok { + clientOrderID = cid + } + + return &LimitOrderResult{ + OrderID: orderID, + ClientID: clientOrderID, + Symbol: req.Symbol, + Side: side, + Price: formattedPrice, + Quantity: formattedQty, + Status: "NEW", + }, nil +} + +// CancelOrder cancels a specific order by order ID +func (t *AsterTrader) CancelOrder(symbol, orderID string) error { + params := map[string]interface{}{ + "symbol": symbol, + "orderId": orderID, + } + + _, err := t.request("DELETE", "/fapi/v3/order", params) + if err != nil { + return fmt.Errorf("failed to cancel order %s: %w", orderID, err) + } + + return nil +} + +// GetOrderBook gets the order book for a symbol +func (t *AsterTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { + if depth <= 0 { + depth = 20 + } + + // Aster uses public endpoint (no signature required) + resp, err := t.client.Get(fmt.Sprintf("%s/fapi/v3/depth?symbol=%s&limit=%d", t.baseURL, symbol, depth)) + if err != nil { + return nil, nil, fmt.Errorf("failed to fetch order book: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)) + } + + var result struct { + Bids [][]string `json:"bids"` // [[price, qty], ...] + Asks [][]string `json:"asks"` // [[price, qty], ...] + } + if err := json.Unmarshal(body, &result); err != nil { + return nil, nil, fmt.Errorf("failed to parse order book: %w", err) + } + + // Convert string arrays to float64 arrays + bids = make([][]float64, len(result.Bids)) + for i, bid := range result.Bids { + if len(bid) >= 2 { + price, _ := strconv.ParseFloat(bid[0], 64) + qty, _ := strconv.ParseFloat(bid[1], 64) + bids[i] = []float64{price, qty} + } + } + + asks = make([][]float64, len(result.Asks)) + for i, ask := range result.Asks { + if len(ask) >= 2 { + price, _ := strconv.ParseFloat(ask[0], 64) + qty, _ := strconv.ParseFloat(ask[1], 64) + asks[i] = []float64{price, qty} + } + } + + return bids, asks, nil } diff --git a/trader/auto_trader.go b/trader/auto_trader.go index 2f0daee0..3ec70635 100644 --- a/trader/auto_trader.go +++ b/trader/auto_trader.go @@ -123,6 +123,7 @@ type AutoTrader struct { peakPnLCacheMutex sync.RWMutex // Cache read-write lock lastBalanceSyncTime time.Time // Last balance sync time userID string // User ID + gridState *GridState // Grid trading state (only used when StrategyType == "grid_trading") } // NewAutoTrader creates an automatic trader @@ -419,9 +420,25 @@ func (at *AutoTrader) Run() error { ticker := time.NewTicker(at.config.ScanInterval) defer ticker.Stop() + // Check if this is a grid trading strategy + isGridStrategy := at.IsGridStrategy() + if isGridStrategy { + logger.Infof("🔲 [%s] Grid trading strategy detected, initializing grid...", at.name) + if err := at.InitializeGrid(); err != nil { + logger.Errorf("❌ [%s] Failed to initialize grid: %v", at.name, err) + return fmt.Errorf("grid initialization failed: %w", err) + } + } + // Execute immediately on first run - if err := at.runCycle(); err != nil { - logger.Infof("❌ Execution failed: %v", err) + if isGridStrategy { + if err := at.RunGridCycle(); err != nil { + logger.Infof("❌ Grid execution failed: %v", err) + } + } else { + if err := at.runCycle(); err != nil { + logger.Infof("❌ Execution failed: %v", err) + } } for { @@ -435,8 +452,14 @@ func (at *AutoTrader) Run() error { select { case <-ticker.C: - if err := at.runCycle(); err != nil { - logger.Infof("❌ Execution failed: %v", err) + if isGridStrategy { + if err := at.RunGridCycle(); err != nil { + logger.Infof("❌ Grid execution failed: %v", err) + } + } else { + if err := at.runCycle(); err != nil { + logger.Infof("❌ Execution failed: %v", err) + } } case <-at.stopMonitorCh: logger.Infof("[%s] ⏹ Stop signal received, exiting automatic trading main loop", at.name) @@ -1365,6 +1388,12 @@ func (at *AutoTrader) GetID() string { return at.id } +// GetUnderlyingTrader returns the underlying Trader interface implementation +// This is used by grid trading and other components that need direct exchange access +func (at *AutoTrader) GetUnderlyingTrader() Trader { + return at.trader +} + // GetName gets trader name func (at *AutoTrader) GetName() string { return at.name @@ -1471,7 +1500,7 @@ func (at *AutoTrader) GetStatus() map[string]interface{} { isRunning := at.isRunning at.isRunningMutex.RUnlock() - return map[string]interface{}{ + result := map[string]interface{}{ "trader_id": at.id, "trader_name": at.name, "ai_model": at.aiModel, @@ -1486,6 +1515,16 @@ func (at *AutoTrader) GetStatus() map[string]interface{} { "last_reset_time": at.lastResetTime.Format(time.RFC3339), "ai_provider": aiProvider, } + + // Add strategy info + if at.config.StrategyConfig != nil { + result["strategy_type"] = at.config.StrategyConfig.StrategyType + if at.config.StrategyConfig.GridConfig != nil { + result["grid_symbol"] = at.config.StrategyConfig.GridConfig.Symbol + } + } + + return result } // GetAccountInfo gets account information (for API) diff --git a/trader/auto_trader_grid.go b/trader/auto_trader_grid.go new file mode 100644 index 00000000..5b445b7f --- /dev/null +++ b/trader/auto_trader_grid.go @@ -0,0 +1,1579 @@ +package trader + +import ( + "encoding/json" + "fmt" + "math" + "nofx/kernel" + "nofx/logger" + "nofx/market" + "nofx/store" + "sync" + "time" +) + +// ============================================================================ +// Grid Trading State Management +// ============================================================================ + +// GridState holds the runtime state for grid trading +type GridState struct { + mu sync.RWMutex + + // Configuration + Config *store.GridStrategyConfig + + // Grid levels + Levels []kernel.GridLevelInfo + + // Calculated bounds + UpperPrice float64 + LowerPrice float64 + GridSpacing float64 + + // State flags + IsPaused bool + IsInitialized bool + + // Performance tracking + TotalProfit float64 + TotalTrades int + WinningTrades int + MaxDrawdown float64 + PeakEquity float64 + DailyPnL float64 + LastDailyReset time.Time + + // Order tracking + OrderBook map[string]int // OrderID -> LevelIndex + + // Box state + ShortBoxUpper float64 + ShortBoxLower float64 + MidBoxUpper float64 + MidBoxLower float64 + LongBoxUpper float64 + LongBoxLower float64 + + // Breakout state + BreakoutLevel string + BreakoutDirection string + BreakoutConfirmCount int + + // Position reduction (0 = normal, 50 = reduced after false breakout) + PositionReductionPct float64 + + // Current regime level + CurrentRegimeLevel string +} + +// NewGridState creates a new grid state +func NewGridState(config *store.GridStrategyConfig) *GridState { + return &GridState{ + Config: config, + Levels: make([]kernel.GridLevelInfo, 0), + OrderBook: make(map[string]int), + } +} + +// ============================================================================ +// Breakout Detection +// ============================================================================ + +// BreakoutType represents the type of price breakout +type BreakoutType string + +const ( + BreakoutNone BreakoutType = "none" + BreakoutUpper BreakoutType = "upper" + BreakoutLower BreakoutType = "lower" +) + +// checkBreakout detects if price has broken out of grid range +// Returns breakout type and percentage beyond boundary +func (at *AutoTrader) checkBreakout() (BreakoutType, float64) { + gridConfig := at.config.StrategyConfig.GridConfig + + currentPrice, err := at.trader.GetMarketPrice(gridConfig.Symbol) + if err != nil { + return BreakoutNone, 0 + } + + at.gridState.mu.RLock() + upper := at.gridState.UpperPrice + lower := at.gridState.LowerPrice + at.gridState.mu.RUnlock() + + if upper <= 0 || lower <= 0 { + return BreakoutNone, 0 + } + + // Check upper breakout + if currentPrice > upper { + breakoutPct := (currentPrice - upper) / upper * 100 + return BreakoutUpper, breakoutPct + } + + // Check lower breakout + if currentPrice < lower { + breakoutPct := (lower - currentPrice) / lower * 100 + return BreakoutLower, breakoutPct + } + + return BreakoutNone, 0 +} + +// checkMaxDrawdown checks if current drawdown exceeds maximum allowed +// Returns: (exceeded bool, currentDrawdown float64) +func (at *AutoTrader) checkMaxDrawdown() (bool, float64) { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig.MaxDrawdownPct <= 0 { + return false, 0 + } + + // Get current equity + balance, err := at.trader.GetBalance() + if err != nil { + return false, 0 + } + + currentEquity := 0.0 + if equity, ok := balance["total_equity"].(float64); ok { + currentEquity = equity + } else if total, ok := balance["totalWalletBalance"].(float64); ok { + if unrealized, ok := balance["totalUnrealizedProfit"].(float64); ok { + currentEquity = total + unrealized + } + } + + if currentEquity <= 0 { + return false, 0 + } + + // Update peak equity + at.gridState.mu.Lock() + if currentEquity > at.gridState.PeakEquity { + at.gridState.PeakEquity = currentEquity + } + peakEquity := at.gridState.PeakEquity + at.gridState.mu.Unlock() + + if peakEquity <= 0 { + return false, 0 + } + + // Calculate current drawdown + drawdown := (peakEquity - currentEquity) / peakEquity * 100 + + // Update max drawdown tracking + at.gridState.mu.Lock() + if drawdown > at.gridState.MaxDrawdown { + at.gridState.MaxDrawdown = drawdown + } + at.gridState.mu.Unlock() + + return drawdown >= gridConfig.MaxDrawdownPct, drawdown +} + +// checkDailyLossLimit checks if daily loss exceeds limit +// Returns: (exceeded bool, dailyLossPct float64) +func (at *AutoTrader) checkDailyLossLimit() (bool, float64) { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig.DailyLossLimitPct <= 0 { + return false, 0 + } + + at.gridState.mu.Lock() + // Reset daily PnL if new day + now := time.Now() + if now.YearDay() != at.gridState.LastDailyReset.YearDay() || + now.Year() != at.gridState.LastDailyReset.Year() { + at.gridState.DailyPnL = 0 + at.gridState.LastDailyReset = now + } + dailyPnL := at.gridState.DailyPnL + at.gridState.mu.Unlock() + + // Calculate daily loss as percentage of total investment + dailyLossPct := 0.0 + if gridConfig.TotalInvestment > 0 && dailyPnL < 0 { + dailyLossPct = (-dailyPnL) / gridConfig.TotalInvestment * 100 + } + + return dailyLossPct >= gridConfig.DailyLossLimitPct, dailyLossPct +} + +// updateDailyPnL updates the daily PnL tracking +func (at *AutoTrader) updateDailyPnL(realizedPnL float64) { + at.gridState.mu.Lock() + at.gridState.DailyPnL += realizedPnL + at.gridState.TotalProfit += realizedPnL + at.gridState.mu.Unlock() +} + +// emergencyExit closes all positions and cancels all orders +func (at *AutoTrader) emergencyExit(reason string) error { + gridConfig := at.config.StrategyConfig.GridConfig + + logger.Errorf("[Grid] EMERGENCY EXIT: %s", reason) + + // Cancel all orders + if err := at.cancelAllGridOrders(); err != nil { + logger.Errorf("[Grid] Failed to cancel orders in emergency: %v", err) + } + + // Close all positions + positions, err := at.trader.GetPositions() + if err == nil { + for _, pos := range positions { + if sym, ok := pos["symbol"].(string); ok && sym == gridConfig.Symbol { + if size, ok := pos["positionAmt"].(float64); ok && size != 0 { + if size > 0 { + at.trader.CloseLong(gridConfig.Symbol, size) + } else { + at.trader.CloseShort(gridConfig.Symbol, -size) + } + } + } + } + } + + // Pause grid + at.gridState.mu.Lock() + at.gridState.IsPaused = true + at.gridState.mu.Unlock() + + return nil +} + +// handleBreakout handles price breakout from grid range +func (at *AutoTrader) handleBreakout(breakoutType BreakoutType, breakoutPct float64) error { + logger.Warnf("[Grid] BREAKOUT DETECTED: %s, %.2f%% beyond boundary", breakoutType, breakoutPct) + + // If breakout exceeds 2%, pause grid and cancel orders + if breakoutPct >= 2.0 { + logger.Warnf("[Grid] Significant breakout (%.2f%%), pausing grid and canceling orders", breakoutPct) + + // Cancel all pending orders to prevent further losses + if err := at.cancelAllGridOrders(); err != nil { + logger.Errorf("[Grid] Failed to cancel orders on breakout: %v", err) + } + + // Pause grid trading + at.gridState.mu.Lock() + at.gridState.IsPaused = true + at.gridState.mu.Unlock() + + return fmt.Errorf("grid paused due to %s breakout (%.2f%%)", breakoutType, breakoutPct) + } + + // If breakout is minor (< 2%), consider adjusting grid + if breakoutPct >= 1.0 { + logger.Infof("[Grid] Minor breakout (%.2f%%), considering grid adjustment", breakoutPct) + // Let AI decide whether to adjust + } + + return nil +} + +// checkBoxBreakout checks for multi-period box breakouts and takes appropriate action +func (at *AutoTrader) checkBoxBreakout() error { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig == nil { + return nil + } + + // Get box data + box, err := market.GetBoxData(gridConfig.Symbol) + if err != nil { + logger.Infof("Failed to get box data: %v", err) + return nil // Non-fatal, continue with other checks + } + + // Update grid state with box values + at.gridState.mu.Lock() + at.gridState.ShortBoxUpper = box.ShortUpper + at.gridState.ShortBoxLower = box.ShortLower + at.gridState.MidBoxUpper = box.MidUpper + at.gridState.MidBoxLower = box.MidLower + at.gridState.LongBoxUpper = box.LongUpper + at.gridState.LongBoxLower = box.LongLower + at.gridState.mu.Unlock() + + // Detect breakout + breakoutLevel, direction := detectBoxBreakout(box) + + // Get current breakout state + state := &BreakoutState{ + Level: market.BreakoutLevel(at.gridState.BreakoutLevel), + Direction: at.gridState.BreakoutDirection, + ConfirmCount: at.gridState.BreakoutConfirmCount, + } + + // Check if breakout is confirmed (3 candles) + confirmed := confirmBreakout(state, breakoutLevel, direction) + + // Update grid state + at.gridState.mu.Lock() + at.gridState.BreakoutLevel = string(state.Level) + at.gridState.BreakoutDirection = state.Direction + at.gridState.BreakoutConfirmCount = state.ConfirmCount + at.gridState.mu.Unlock() + + if !confirmed { + return nil + } + + // Take action based on breakout level + action := getBreakoutAction(breakoutLevel) + return at.executeBreakoutAction(action) +} + +// executeBreakoutAction executes the appropriate action for a breakout +func (at *AutoTrader) executeBreakoutAction(action BreakoutAction) error { + switch action { + case BreakoutActionReducePosition: + // Short box breakout: reduce position to 50% + logger.Infof("Short box breakout confirmed, reducing position to 50%%") + at.gridState.mu.Lock() + at.gridState.PositionReductionPct = 50 + at.gridState.mu.Unlock() + return nil + + case BreakoutActionPauseGrid: + // Mid box breakout: pause grid + cancel orders + logger.Infof("Mid box breakout confirmed, pausing grid and canceling orders") + at.gridState.mu.Lock() + at.gridState.IsPaused = true + at.gridState.mu.Unlock() + return at.cancelAllGridOrders() + + case BreakoutActionCloseAll: + // Long box breakout: pause + cancel + close all + logger.Infof("Long box breakout confirmed, closing all positions") + at.gridState.mu.Lock() + at.gridState.IsPaused = true + at.gridState.mu.Unlock() + if err := at.cancelAllGridOrders(); err != nil { + logger.Infof("Failed to cancel orders: %v", err) + } + return at.closeAllPositions() + } + + return nil +} + +// closeAllPositions closes all open positions for the grid symbol +func (at *AutoTrader) closeAllPositions() error { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig == nil { + return nil + } + + positions, err := at.trader.GetPositions() + if err != nil { + return fmt.Errorf("failed to get positions: %w", err) + } + + for _, pos := range positions { + symbol, _ := pos["symbol"].(string) + if symbol != gridConfig.Symbol { + continue + } + + size, _ := pos["positionAmt"].(float64) + if size == 0 { + continue + } + + if size > 0 { + _, err = at.trader.CloseLong(symbol, size) + } else { + _, err = at.trader.CloseShort(symbol, -size) + } + if err != nil { + logger.Infof("Failed to close position: %v", err) + } + } + + return nil +} + +// checkFalseBreakoutRecovery checks if price has returned to box after breakout +func (at *AutoTrader) checkFalseBreakoutRecovery() error { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig == nil { + return nil + } + + at.gridState.mu.RLock() + breakoutLevel := at.gridState.BreakoutLevel + isPaused := at.gridState.IsPaused + positionReduction := at.gridState.PositionReductionPct + at.gridState.mu.RUnlock() + + // Only check if we had a breakout + if breakoutLevel == string(market.BreakoutNone) && positionReduction == 0 && !isPaused { + return nil + } + + // Get current box data + box, err := market.GetBoxData(gridConfig.Symbol) + if err != nil { + return nil + } + + // Check if price is back inside the long box + if box.CurrentPrice >= box.LongLower && box.CurrentPrice <= box.LongUpper { + logger.Infof("Price returned to box, recovering with 50%% position") + + at.gridState.mu.Lock() + at.gridState.BreakoutLevel = string(market.BreakoutNone) + at.gridState.BreakoutDirection = "" + at.gridState.BreakoutConfirmCount = 0 + at.gridState.PositionReductionPct = 50 // Recover at 50% + at.gridState.IsPaused = false + at.gridState.mu.Unlock() + } + + return nil +} + +// ============================================================================ +// AutoTrader Grid Methods +// ============================================================================ + +// InitializeGrid initializes the grid state and calculates levels +func (at *AutoTrader) InitializeGrid() error { + if at.config.StrategyConfig == nil || at.config.StrategyConfig.GridConfig == nil { + return fmt.Errorf("grid configuration not found") + } + + gridConfig := at.config.StrategyConfig.GridConfig + at.gridState = NewGridState(gridConfig) + + // Get current market price + price, err := at.trader.GetMarketPrice(gridConfig.Symbol) + if err != nil { + return fmt.Errorf("failed to get market price: %w", err) + } + + // Calculate grid bounds + if gridConfig.UseATRBounds { + // Get ATR for bound calculation + mktData, err := market.GetWithTimeframes(gridConfig.Symbol, []string{"4h"}, "4h", 20) + if err != nil { + logger.Warnf("Failed to get market data for ATR: %v, using default bounds", err) + at.calculateDefaultBounds(price, gridConfig) + } else { + at.calculateATRBounds(price, mktData, gridConfig) + } + } else { + // Use manual bounds + at.gridState.UpperPrice = gridConfig.UpperPrice + at.gridState.LowerPrice = gridConfig.LowerPrice + } + + // Calculate grid spacing + at.gridState.GridSpacing = (at.gridState.UpperPrice - at.gridState.LowerPrice) / float64(gridConfig.GridCount-1) + + // Initialize grid levels + at.initializeGridLevels(price, gridConfig) + + at.gridState.IsInitialized = true + + // CRITICAL: Set leverage on exchange before trading + if err := at.trader.SetLeverage(gridConfig.Symbol, gridConfig.Leverage); err != nil { + logger.Warnf("[Grid] Failed to set leverage %dx on exchange: %v", gridConfig.Leverage, err) + // Not fatal - continue with default leverage + } else { + logger.Infof("[Grid] Leverage set to %dx for %s", gridConfig.Leverage, gridConfig.Symbol) + } + + logger.Infof("📊 [Grid] Initialized: %d levels, $%.2f - $%.2f, spacing $%.2f", + gridConfig.GridCount, at.gridState.LowerPrice, at.gridState.UpperPrice, at.gridState.GridSpacing) + + return nil +} + +// calculateDefaultBounds calculates default bounds based on price +func (at *AutoTrader) calculateDefaultBounds(price float64, config *store.GridStrategyConfig) { + // Default: ±3% from current price + multiplier := 0.03 * float64(config.GridCount) / 10 + at.gridState.UpperPrice = price * (1 + multiplier) + at.gridState.LowerPrice = price * (1 - multiplier) +} + +// calculateATRBounds calculates bounds using ATR +func (at *AutoTrader) calculateATRBounds(price float64, mktData *market.Data, config *store.GridStrategyConfig) { + atr := 0.0 + if mktData.LongerTermContext != nil { + atr = mktData.LongerTermContext.ATR14 + } + + if atr <= 0 { + at.calculateDefaultBounds(price, config) + return + } + + multiplier := config.ATRMultiplier + if multiplier <= 0 { + multiplier = 2.0 + } + + halfRange := atr * multiplier + at.gridState.UpperPrice = price + halfRange + at.gridState.LowerPrice = price - halfRange +} + +// initializeGridLevels creates the grid level structure +func (at *AutoTrader) initializeGridLevels(currentPrice float64, config *store.GridStrategyConfig) { + levels := make([]kernel.GridLevelInfo, config.GridCount) + totalWeight := 0.0 + weights := make([]float64, config.GridCount) + + // Calculate weights based on distribution + for i := 0; i < config.GridCount; i++ { + switch config.Distribution { + case "gaussian": + // Gaussian distribution - more weight in the middle + center := float64(config.GridCount-1) / 2 + sigma := float64(config.GridCount) / 4 + weights[i] = math.Exp(-math.Pow(float64(i)-center, 2) / (2 * sigma * sigma)) + case "pyramid": + // Pyramid - more weight at bottom + weights[i] = float64(config.GridCount - i) + default: // uniform + weights[i] = 1.0 + } + totalWeight += weights[i] + } + + // Create levels + for i := 0; i < config.GridCount; i++ { + price := at.gridState.LowerPrice + float64(i)*at.gridState.GridSpacing + allocatedUSD := config.TotalInvestment * weights[i] / totalWeight + + // Determine initial side (below current price = buy, above = sell) + side := "buy" + if price > currentPrice { + side = "sell" + } + + levels[i] = kernel.GridLevelInfo{ + Index: i, + Price: price, + State: "empty", + Side: side, + AllocatedUSD: allocatedUSD, + } + } + + at.gridState.Levels = levels +} + +// RunGridCycle executes one grid trading cycle +func (at *AutoTrader) RunGridCycle() error { + // Check if trader is stopped (early exit to prevent trades after Stop() is called) + at.isRunningMutex.RLock() + running := at.isRunning + at.isRunningMutex.RUnlock() + if !running { + logger.Infof("[Grid] Trader is stopped, aborting grid cycle") + return nil + } + + if at.gridState == nil || !at.gridState.IsInitialized { + if err := at.InitializeGrid(); err != nil { + return fmt.Errorf("failed to initialize grid: %w", err) + } + } + + // CRITICAL: Check for breakout before executing any trades + breakoutType, breakoutPct := at.checkBreakout() + if breakoutType != BreakoutNone { + if err := at.handleBreakout(breakoutType, breakoutPct); err != nil { + return err // Grid paused due to breakout + } + } + + // CRITICAL: Check max drawdown + exceeded, drawdown := at.checkMaxDrawdown() + if exceeded { + return at.emergencyExit(fmt.Sprintf("max drawdown exceeded: %.2f%%", drawdown)) + } + + // CRITICAL: Check daily loss limit + dailyExceeded, dailyLossPct := at.checkDailyLossLimit() + if dailyExceeded { + logger.Errorf("[Grid] Daily loss limit exceeded: %.2f%%", dailyLossPct) + at.gridState.mu.Lock() + at.gridState.IsPaused = true + at.gridState.mu.Unlock() + return fmt.Errorf("daily loss limit exceeded: %.2f%%", dailyLossPct) + } + + // Check multi-period box breakout + if err := at.checkBoxBreakout(); err != nil { + logger.Infof("Box breakout check error: %v", err) + } + + // Check for false breakout recovery + if err := at.checkFalseBreakoutRecovery(); err != nil { + logger.Infof("False breakout recovery check error: %v", err) + } + + // Check if grid is paused + at.gridState.mu.RLock() + isPaused := at.gridState.IsPaused + at.gridState.mu.RUnlock() + if isPaused { + logger.Infof("[Grid] Grid is paused, skipping cycle") + return nil + } + + gridConfig := at.config.StrategyConfig.GridConfig + lang := at.config.StrategyConfig.Language + if lang == "" { + lang = "en" + } + + // Build grid context + gridCtx, err := at.buildGridContext() + if err != nil { + return fmt.Errorf("failed to build grid context: %w", err) + } + + // Get AI decisions + decision, err := kernel.GetGridDecisions(gridCtx, at.mcpClient, gridConfig, lang) + if err != nil { + return fmt.Errorf("failed to get grid decisions: %w", err) + } + + // Check if trader is stopped before executing any decisions (prevent trades after Stop()) + at.isRunningMutex.RLock() + running = at.isRunning + at.isRunningMutex.RUnlock() + if !running { + logger.Infof("[Grid] Trader stopped before decision execution, aborting grid cycle") + return nil + } + + // Execute decisions + for _, d := range decision.Decisions { + // Check if trader is still running before each decision + at.isRunningMutex.RLock() + running := at.isRunning + at.isRunningMutex.RUnlock() + if !running { + logger.Infof("[Grid] Trader stopped, skipping remaining %d decisions", len(decision.Decisions)) + break + } + + if err := at.executeGridDecision(&d); err != nil { + logger.Warnf("[Grid] Failed to execute decision %s: %v", d.Action, err) + } + } + + // Sync state with exchange + at.syncGridState() + + // Save decision record + at.saveGridDecisionRecord(decision) + + return nil +} + +// buildGridContext builds the context for AI grid decisions +func (at *AutoTrader) buildGridContext() (*kernel.GridContext, error) { + gridConfig := at.config.StrategyConfig.GridConfig + + // Get market data + mktData, err := market.GetWithTimeframes(gridConfig.Symbol, []string{"5m", "4h"}, "5m", 50) + if err != nil { + return nil, fmt.Errorf("failed to get market data: %w", err) + } + + // Build base context from market data + ctx := kernel.BuildGridContextFromMarketData(mktData, gridConfig) + + // Add grid state + at.gridState.mu.RLock() + ctx.Levels = at.gridState.Levels + ctx.UpperPrice = at.gridState.UpperPrice + ctx.LowerPrice = at.gridState.LowerPrice + ctx.GridSpacing = at.gridState.GridSpacing + ctx.IsPaused = at.gridState.IsPaused + ctx.TotalProfit = at.gridState.TotalProfit + ctx.TotalTrades = at.gridState.TotalTrades + ctx.WinningTrades = at.gridState.WinningTrades + ctx.MaxDrawdown = at.gridState.MaxDrawdown + ctx.DailyPnL = at.gridState.DailyPnL + + // Count active orders and filled levels + for _, level := range at.gridState.Levels { + if level.State == "pending" { + ctx.ActiveOrderCount++ + } else if level.State == "filled" { + ctx.FilledLevelCount++ + } + } + at.gridState.mu.RUnlock() + + // Get account info + balance, err := at.trader.GetBalance() + if err == nil { + if equity, ok := balance["total_equity"].(float64); ok { + ctx.TotalEquity = equity + } + if available, ok := balance["availableBalance"].(float64); ok { + ctx.AvailableBalance = available + } + if unrealized, ok := balance["totalUnrealizedProfit"].(float64); ok { + ctx.UnrealizedPnL = unrealized + } + } + + // Get current position + positions, err := at.trader.GetPositions() + if err == nil { + for _, pos := range positions { + if sym, ok := pos["symbol"].(string); ok && sym == gridConfig.Symbol { + if size, ok := pos["positionAmt"].(float64); ok { + ctx.CurrentPosition = size + } + } + } + } + + return ctx, nil +} + +// executeGridDecision executes a single grid decision +func (at *AutoTrader) executeGridDecision(d *kernel.Decision) error { + switch d.Action { + case "place_buy_limit": + return at.placeGridLimitOrder(d, "BUY") + case "place_sell_limit": + return at.placeGridLimitOrder(d, "SELL") + case "cancel_order": + return at.cancelGridOrder(d) + case "cancel_all_orders": + return at.cancelAllGridOrders() + case "pause_grid": + return at.pauseGrid(d.Reasoning) + case "resume_grid": + return at.resumeGrid() + case "adjust_grid": + return at.adjustGrid(d) + case "hold": + logger.Infof("[Grid] Holding current state: %s", d.Reasoning) + return nil + // Support standard actions for closing positions + case "close_long": + _, err := at.trader.CloseLong(d.Symbol, d.Quantity) + return err + case "close_short": + _, err := at.trader.CloseShort(d.Symbol, d.Quantity) + return err + default: + logger.Warnf("[Grid] Unknown action: %s", d.Action) + return nil + } +} + +// checkTotalPositionLimit checks if adding a new position would exceed total limits +// Returns: (allowed bool, currentPositionValue float64, maxAllowed float64) +func (at *AutoTrader) checkTotalPositionLimit(symbol string, additionalValue float64) (bool, float64, float64) { + gridConfig := at.config.StrategyConfig.GridConfig + + // Calculate max allowed total position value + // Total position should not exceed: TotalInvestment × Leverage + maxTotalPositionValue := gridConfig.TotalInvestment * float64(gridConfig.Leverage) + + // Get current position value from exchange + currentPositionValue := 0.0 + positions, err := at.trader.GetPositions() + if err == nil { + for _, pos := range positions { + if sym, ok := pos["symbol"].(string); ok && sym == symbol { + if size, ok := pos["positionAmt"].(float64); ok { + if price, ok := pos["markPrice"].(float64); ok { + currentPositionValue = math.Abs(size) * price + } else if entryPrice, ok := pos["entryPrice"].(float64); ok { + currentPositionValue = math.Abs(size) * entryPrice + } + } + } + } + } + + // Also count pending orders as potential position + at.gridState.mu.RLock() + pendingValue := 0.0 + for _, level := range at.gridState.Levels { + if level.State == "pending" { + pendingValue += level.OrderQuantity * level.Price + } + } + at.gridState.mu.RUnlock() + + totalAfterOrder := currentPositionValue + pendingValue + additionalValue + allowed := totalAfterOrder <= maxTotalPositionValue + + return allowed, currentPositionValue + pendingValue, maxTotalPositionValue +} + +// placeGridLimitOrder places a limit order for grid trading +func (at *AutoTrader) placeGridLimitOrder(d *kernel.Decision, side string) error { + // Check if trader supports GridTrader interface + gridTrader, ok := at.trader.(GridTrader) + if !ok { + // Fallback to adapter + gridTrader = NewGridTraderAdapter(at.trader) + } + + gridConfig := at.config.StrategyConfig.GridConfig + + // CRITICAL: Validate and cap quantity to prevent excessive position sizes + // This protects against AI miscalculations or leverage misconfigurations + quantity := d.Quantity + if d.Price > 0 && gridConfig.TotalInvestment > 0 { + // Calculate max allowed position value per grid level + // Each level gets proportional share of total investment + maxMarginPerLevel := gridConfig.TotalInvestment / float64(gridConfig.GridCount) + maxPositionValuePerLevel := maxMarginPerLevel * float64(gridConfig.Leverage) + maxQuantityPerLevel := maxPositionValuePerLevel / d.Price + + // Also get the level's allocated USD for additional validation + at.gridState.mu.RLock() + var levelAllocatedUSD float64 + if d.LevelIndex >= 0 && d.LevelIndex < len(at.gridState.Levels) { + levelAllocatedUSD = at.gridState.Levels[d.LevelIndex].AllocatedUSD + } + at.gridState.mu.RUnlock() + + // Use level-specific allocation if available + if levelAllocatedUSD > 0 { + levelMaxPositionValue := levelAllocatedUSD * float64(gridConfig.Leverage) + levelMaxQuantity := levelMaxPositionValue / d.Price + if levelMaxQuantity < maxQuantityPerLevel { + maxQuantityPerLevel = levelMaxQuantity + } + } + + // Cap quantity if it exceeds the maximum allowed + if quantity > maxQuantityPerLevel { + logger.Warnf("[Grid] ⚠️ Quantity %.4f exceeds max allowed %.4f (position_value $%.2f > max $%.2f), capping", + quantity, maxQuantityPerLevel, quantity*d.Price, maxPositionValuePerLevel) + quantity = maxQuantityPerLevel + } + + // Safety check: ensure position value is reasonable (within 2x of intended max as absolute limit) + positionValue := quantity * d.Price + absoluteMaxValue := gridConfig.TotalInvestment * float64(gridConfig.Leverage) * 2 // 2x safety margin + if positionValue > absoluteMaxValue { + logger.Errorf("[Grid] CRITICAL: Position value $%.2f exceeds absolute max $%.2f! Rejecting order.", + positionValue, absoluteMaxValue) + return fmt.Errorf("position value $%.2f exceeds safety limit $%.2f", positionValue, absoluteMaxValue) + } + } + + // CRITICAL: Check total position limit before placing order + orderValue := quantity * d.Price + allowed, currentValue, maxValue := at.checkTotalPositionLimit(d.Symbol, orderValue) + if !allowed { + logger.Errorf("[Grid] TOTAL POSITION LIMIT EXCEEDED: current=$%.2f + order=$%.2f > max=$%.2f. Rejecting order.", + currentValue, orderValue, maxValue) + return fmt.Errorf("total position value $%.2f would exceed limit $%.2f", currentValue+orderValue, maxValue) + } + + req := &LimitOrderRequest{ + Symbol: d.Symbol, + Side: side, + Price: d.Price, + Quantity: quantity, // Use validated/capped quantity + Leverage: gridConfig.Leverage, + PostOnly: gridConfig.UseMakerOnly, + ReduceOnly: false, + ClientID: fmt.Sprintf("grid-%d-%d", d.LevelIndex, time.Now().UnixNano()%1000000), + } + + result, err := gridTrader.PlaceLimitOrder(req) + if err != nil { + return fmt.Errorf("failed to place limit order: %w", err) + } + + // Update grid level state + at.gridState.mu.Lock() + if d.LevelIndex >= 0 && d.LevelIndex < len(at.gridState.Levels) { + at.gridState.Levels[d.LevelIndex].State = "pending" + at.gridState.Levels[d.LevelIndex].OrderID = result.OrderID + at.gridState.Levels[d.LevelIndex].OrderQuantity = d.Quantity + at.gridState.OrderBook[result.OrderID] = d.LevelIndex + } + at.gridState.mu.Unlock() + + logger.Infof("[Grid] Placed %s limit order at $%.2f, qty=%.4f, level=%d, orderID=%s", + side, d.Price, d.Quantity, d.LevelIndex, result.OrderID) + + return nil +} + +// cancelGridOrder cancels a specific grid order +func (at *AutoTrader) cancelGridOrder(d *kernel.Decision) error { + gridTrader, ok := at.trader.(GridTrader) + if !ok { + gridTrader = NewGridTraderAdapter(at.trader) + } + + if err := gridTrader.CancelOrder(d.Symbol, d.OrderID); err != nil { + return fmt.Errorf("failed to cancel order: %w", err) + } + + // Update state + at.gridState.mu.Lock() + if levelIdx, ok := at.gridState.OrderBook[d.OrderID]; ok { + if levelIdx >= 0 && levelIdx < len(at.gridState.Levels) { + at.gridState.Levels[levelIdx].State = "empty" + at.gridState.Levels[levelIdx].OrderID = "" + at.gridState.Levels[levelIdx].OrderQuantity = 0 + } + delete(at.gridState.OrderBook, d.OrderID) + } + at.gridState.mu.Unlock() + + logger.Infof("[Grid] Cancelled order: %s", d.OrderID) + return nil +} + +// cancelAllGridOrders cancels all grid orders +func (at *AutoTrader) cancelAllGridOrders() error { + gridConfig := at.config.StrategyConfig.GridConfig + + if err := at.trader.CancelAllOrders(gridConfig.Symbol); err != nil { + return fmt.Errorf("failed to cancel all orders: %w", err) + } + + // Reset all pending levels + at.gridState.mu.Lock() + for i := range at.gridState.Levels { + if at.gridState.Levels[i].State == "pending" { + at.gridState.Levels[i].State = "empty" + at.gridState.Levels[i].OrderID = "" + at.gridState.Levels[i].OrderQuantity = 0 + } + } + at.gridState.OrderBook = make(map[string]int) + at.gridState.mu.Unlock() + + logger.Infof("[Grid] Cancelled all orders") + return nil +} + +// pauseGrid pauses grid trading +func (at *AutoTrader) pauseGrid(reason string) error { + at.cancelAllGridOrders() + + at.gridState.mu.Lock() + at.gridState.IsPaused = true + at.gridState.mu.Unlock() + + logger.Infof("[Grid] Paused: %s", reason) + return nil +} + +// resumeGrid resumes grid trading +func (at *AutoTrader) resumeGrid() error { + at.gridState.mu.Lock() + at.gridState.IsPaused = false + at.gridState.mu.Unlock() + + logger.Infof("[Grid] Resumed") + return nil +} + +// adjustGrid adjusts grid parameters +func (at *AutoTrader) adjustGrid(d *kernel.Decision) error { + // Cancel existing orders first + at.cancelAllGridOrders() + + gridConfig := at.config.StrategyConfig.GridConfig + + // Get current price + price, err := at.trader.GetMarketPrice(gridConfig.Symbol) + if err != nil { + return fmt.Errorf("failed to get market price: %w", err) + } + + // Reinitialize grid levels + at.initializeGridLevels(price, gridConfig) + + logger.Infof("[Grid] Adjusted grid bounds around price $%.2f", price) + return nil +} + +// syncGridState syncs grid state with exchange +func (at *AutoTrader) syncGridState() { + gridConfig := at.config.StrategyConfig.GridConfig + + // Get open orders from exchange + openOrders, err := at.trader.GetOpenOrders(gridConfig.Symbol) + if err != nil { + logger.Warnf("[Grid] Failed to get open orders: %v", err) + return + } + + // Build set of active order IDs + activeOrderIDs := make(map[string]bool) + for _, order := range openOrders { + activeOrderIDs[order.OrderID] = true + } + + // Get current positions to verify fills + positions, err := at.trader.GetPositions() + currentPositionSize := 0.0 + if err != nil { + logger.Warnf("[Grid] Failed to get positions for state sync: %v", err) + } else { + for _, pos := range positions { + if sym, ok := pos["symbol"].(string); ok && sym == gridConfig.Symbol { + if size, ok := pos["positionAmt"].(float64); ok { + currentPositionSize = size + } + } + } + } + + // Update levels based on order status + at.gridState.mu.Lock() + expectedPositionSize := 0.0 + for _, level := range at.gridState.Levels { + if level.State == "filled" { + expectedPositionSize += level.PositionSize + } + } + + for i := range at.gridState.Levels { + level := &at.gridState.Levels[i] + if level.State == "pending" && level.OrderID != "" { + if !activeOrderIDs[level.OrderID] { + // Order no longer exists - check if position changed to determine fill vs cancel + // This is a heuristic - ideally we'd query order history + // If current position is larger than expected filled positions, this order was likely filled + if math.Abs(currentPositionSize) > math.Abs(expectedPositionSize) { + // Position increased, likely filled + level.State = "filled" + level.PositionEntry = level.Price + level.PositionSize = level.OrderQuantity + at.gridState.TotalTrades++ + logger.Infof("[Grid] Level %d order filled at $%.2f", i, level.Price) + } else { + // Position didn't increase as expected, likely cancelled + level.State = "empty" + level.OrderID = "" + level.OrderQuantity = 0 + logger.Infof("[Grid] Level %d order cancelled/expired", i) + } + delete(at.gridState.OrderBook, level.OrderID) + } + } + } + at.gridState.mu.Unlock() + + logger.Debugf("[Grid] Synced state: position=%.4f, orders=%d", currentPositionSize, len(openOrders)) + + // Check stop loss + at.checkAndExecuteStopLoss() + + // Check grid skew + at.autoAdjustGrid() +} + +// saveGridDecisionRecord saves the grid decision to database +func (at *AutoTrader) saveGridDecisionRecord(decision *kernel.FullDecision) { + if at.store == nil { + return + } + + at.cycleNumber++ + + record := &store.DecisionRecord{ + TraderID: at.id, + CycleNumber: at.cycleNumber, + Timestamp: time.Now().UTC(), + SystemPrompt: decision.SystemPrompt, + InputPrompt: decision.UserPrompt, + CoTTrace: decision.CoTTrace, + RawResponse: decision.RawResponse, + AIRequestDurationMs: decision.AIRequestDurationMs, + Success: true, + } + + if len(decision.Decisions) > 0 { + decisionJSON, _ := json.MarshalIndent(decision.Decisions, "", " ") + record.DecisionJSON = string(decisionJSON) + + // Convert kernel.Decision to store.DecisionAction for frontend display + for _, d := range decision.Decisions { + actionRecord := store.DecisionAction{ + Action: d.Action, + Symbol: d.Symbol, + Quantity: d.Quantity, + Leverage: d.Leverage, + Price: d.Price, + StopLoss: d.StopLoss, + TakeProfit: d.TakeProfit, + Confidence: d.Confidence, + Reasoning: d.Reasoning, + Timestamp: time.Now().UTC(), + Success: true, // Grid decisions are executed inline + } + record.Decisions = append(record.Decisions, actionRecord) + } + } + + record.ExecutionLog = append(record.ExecutionLog, fmt.Sprintf("Grid cycle completed with %d decisions", len(decision.Decisions))) + + if err := at.store.Decision().LogDecision(record); err != nil { + logger.Warnf("[Grid] Failed to save decision record: %v", err) + } +} + +// IsGridStrategy returns true if current strategy is grid trading +func (at *AutoTrader) IsGridStrategy() bool { + if at.config.StrategyConfig == nil { + return false + } + return at.config.StrategyConfig.StrategyType == "grid_trading" && at.config.StrategyConfig.GridConfig != nil +} + +// checkGridSkew checks if grid is heavily skewed (too many fills on one side) +// Returns: (skewed bool, buyFilledCount int, sellFilledCount int) +func (at *AutoTrader) checkGridSkew() (bool, int, int) { + at.gridState.mu.RLock() + defer at.gridState.mu.RUnlock() + + buyFilled := 0 + sellFilled := 0 + buyEmpty := 0 + sellEmpty := 0 + + for _, level := range at.gridState.Levels { + if level.Side == "buy" { + if level.State == "filled" { + buyFilled++ + } else if level.State == "empty" { + buyEmpty++ + } + } else { + if level.State == "filled" { + sellFilled++ + } else if level.State == "empty" { + sellEmpty++ + } + } + } + + // Grid is skewed if one side has 3x more fills than the other + // or if one side is completely empty + skewed := false + if buyFilled > 0 && sellFilled == 0 && sellEmpty > 5 { + skewed = true // All buys filled, no sells + } else if sellFilled > 0 && buyFilled == 0 && buyEmpty > 5 { + skewed = true // All sells filled, no buys + } else if buyFilled >= 3*sellFilled && buyFilled > 5 { + skewed = true + } else if sellFilled >= 3*buyFilled && sellFilled > 5 { + skewed = true + } + + return skewed, buyFilled, sellFilled +} + +// autoAdjustGrid automatically adjusts grid when heavily skewed +func (at *AutoTrader) autoAdjustGrid() { + skewed, buyFilled, sellFilled := at.checkGridSkew() + if !skewed { + return + } + + logger.Warnf("[Grid] Grid heavily skewed: buy_filled=%d, sell_filled=%d. Auto-adjusting...", + buyFilled, sellFilled) + + gridConfig := at.config.StrategyConfig.GridConfig + + // Get current price + currentPrice, err := at.trader.GetMarketPrice(gridConfig.Symbol) + if err != nil { + logger.Errorf("[Grid] Failed to get price for auto-adjust: %v", err) + return + } + + // Check if price is near grid boundary + at.gridState.mu.RLock() + upper := at.gridState.UpperPrice + lower := at.gridState.LowerPrice + at.gridState.mu.RUnlock() + + // Only adjust if price has moved significantly (>30% of grid range) + gridRange := upper - lower + midPrice := (upper + lower) / 2 + priceDeviation := math.Abs(currentPrice - midPrice) + + if priceDeviation < gridRange*0.3 { + return // Price still near center, don't adjust + } + + logger.Infof("[Grid] Adjusting grid around new price $%.2f", currentPrice) + + // Cancel existing orders first (before taking the lock for state modification) + if err := at.cancelAllGridOrders(); err != nil { + logger.Errorf("[Grid] Failed to cancel orders during auto-adjust: %v", err) + // Continue with adjustment anyway + } + + // CRITICAL FIX: Hold lock for the entire adjustment operation to ensure atomicity + at.gridState.mu.Lock() + defer at.gridState.mu.Unlock() + + // Preserve filled positions before reinitializing + filledPositions := make(map[int]kernel.GridLevelInfo) + for i, level := range at.gridState.Levels { + if level.State == "filled" { + filledPositions[i] = level + } + } + + // CRITICAL FIX: Recalculate grid bounds centered on current price + // Use the same logic as InitializeGrid() - either ATR-based or default percentage + if gridConfig.UseATRBounds { + // Try to get ATR for bound calculation + mktData, err := market.GetWithTimeframes(gridConfig.Symbol, []string{"4h"}, "4h", 20) + if err != nil { + logger.Warnf("[Grid] Failed to get market data for ATR during adjust: %v, using default bounds", err) + at.calculateDefaultBoundsLocked(currentPrice, gridConfig) + } else { + at.calculateATRBoundsLocked(currentPrice, mktData, gridConfig) + } + } else { + // Use default bounds calculation (scaled by grid count) + at.calculateDefaultBoundsLocked(currentPrice, gridConfig) + } + + // Recalculate grid spacing based on new bounds + at.gridState.GridSpacing = (at.gridState.UpperPrice - at.gridState.LowerPrice) / float64(gridConfig.GridCount-1) + + logger.Infof("[Grid] New bounds: $%.2f - $%.2f, spacing: $%.2f", + at.gridState.LowerPrice, at.gridState.UpperPrice, at.gridState.GridSpacing) + + // Initialize new grid levels (without lock since we already hold it) + at.initializeGridLevelsLocked(currentPrice, gridConfig) + + // CRITICAL FIX: Restore filled positions - find closest new level for each filled position + for _, filledLevel := range filledPositions { + closestIdx := -1 + closestDist := math.MaxFloat64 + + for i, newLevel := range at.gridState.Levels { + dist := math.Abs(newLevel.Price - filledLevel.PositionEntry) + if dist < closestDist { + closestDist = dist + closestIdx = i + } + } + + if closestIdx >= 0 { + // Restore the filled state to the closest level + at.gridState.Levels[closestIdx].State = "filled" + at.gridState.Levels[closestIdx].PositionEntry = filledLevel.PositionEntry + at.gridState.Levels[closestIdx].PositionSize = filledLevel.PositionSize + at.gridState.Levels[closestIdx].UnrealizedPnL = filledLevel.UnrealizedPnL + at.gridState.Levels[closestIdx].OrderID = filledLevel.OrderID + at.gridState.Levels[closestIdx].OrderQuantity = filledLevel.OrderQuantity + logger.Infof("[Grid] Restored filled position at level %d (entry $%.2f)", closestIdx, filledLevel.PositionEntry) + } + } +} + +// calculateDefaultBoundsLocked calculates default bounds (caller must hold lock) +func (at *AutoTrader) calculateDefaultBoundsLocked(price float64, config *store.GridStrategyConfig) { + // Default: ±3% from current price, scaled by grid count + multiplier := 0.03 * float64(config.GridCount) / 10 + at.gridState.UpperPrice = price * (1 + multiplier) + at.gridState.LowerPrice = price * (1 - multiplier) +} + +// calculateATRBoundsLocked calculates bounds using ATR (caller must hold lock) +func (at *AutoTrader) calculateATRBoundsLocked(price float64, mktData *market.Data, config *store.GridStrategyConfig) { + atr := 0.0 + if mktData.LongerTermContext != nil { + atr = mktData.LongerTermContext.ATR14 + } + + if atr <= 0 { + at.calculateDefaultBoundsLocked(price, config) + return + } + + multiplier := config.ATRMultiplier + if multiplier <= 0 { + multiplier = 2.0 + } + + halfRange := atr * multiplier + at.gridState.UpperPrice = price + halfRange + at.gridState.LowerPrice = price - halfRange +} + +// initializeGridLevelsLocked creates the grid level structure (caller must hold lock) +func (at *AutoTrader) initializeGridLevelsLocked(currentPrice float64, config *store.GridStrategyConfig) { + levels := make([]kernel.GridLevelInfo, config.GridCount) + totalWeight := 0.0 + weights := make([]float64, config.GridCount) + + // Calculate weights based on distribution + for i := 0; i < config.GridCount; i++ { + switch config.Distribution { + case "gaussian": + // Gaussian distribution - more weight in the middle + center := float64(config.GridCount-1) / 2 + sigma := float64(config.GridCount) / 4 + weights[i] = math.Exp(-math.Pow(float64(i)-center, 2) / (2 * sigma * sigma)) + case "pyramid": + // Pyramid - more weight at bottom + weights[i] = float64(config.GridCount - i) + default: // uniform + weights[i] = 1.0 + } + totalWeight += weights[i] + } + + // Create levels + for i := 0; i < config.GridCount; i++ { + price := at.gridState.LowerPrice + float64(i)*at.gridState.GridSpacing + allocatedUSD := config.TotalInvestment * weights[i] / totalWeight + + // Determine initial side (below current price = buy, above = sell) + side := "buy" + if price > currentPrice { + side = "sell" + } + + levels[i] = kernel.GridLevelInfo{ + Index: i, + Price: price, + State: "empty", + Side: side, + AllocatedUSD: allocatedUSD, + } + } + + at.gridState.Levels = levels +} + +// GridRiskInfo contains risk information for frontend display +type GridRiskInfo struct { + CurrentLeverage int `json:"current_leverage"` + EffectiveLeverage float64 `json:"effective_leverage"` + RecommendedLeverage int `json:"recommended_leverage"` + + CurrentPosition float64 `json:"current_position"` + MaxPosition float64 `json:"max_position"` + PositionPercent float64 `json:"position_percent"` + + LiquidationPrice float64 `json:"liquidation_price"` + LiquidationDistance float64 `json:"liquidation_distance"` + + RegimeLevel string `json:"regime_level"` + + ShortBoxUpper float64 `json:"short_box_upper"` + ShortBoxLower float64 `json:"short_box_lower"` + MidBoxUpper float64 `json:"mid_box_upper"` + MidBoxLower float64 `json:"mid_box_lower"` + LongBoxUpper float64 `json:"long_box_upper"` + LongBoxLower float64 `json:"long_box_lower"` + CurrentPrice float64 `json:"current_price"` + + BreakoutLevel string `json:"breakout_level"` + BreakoutDirection string `json:"breakout_direction"` +} + +// GetGridRiskInfo returns current risk information for frontend display +func (at *AutoTrader) GetGridRiskInfo() *GridRiskInfo { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig == nil { + return &GridRiskInfo{} + } + + at.gridState.mu.RLock() + defer at.gridState.mu.RUnlock() + + // Get current price + currentPrice, _ := at.trader.GetMarketPrice(gridConfig.Symbol) + + // Calculate effective leverage + totalInvestment := gridConfig.TotalInvestment + leverage := gridConfig.Leverage + + // Get current position value + positions, _ := at.trader.GetPositions() + var currentPositionValue float64 + var currentPositionSize float64 + for _, pos := range positions { + if sym, _ := pos["symbol"].(string); sym == gridConfig.Symbol { + size, _ := pos["positionAmt"].(float64) + entry, _ := pos["entryPrice"].(float64) + currentPositionValue = math.Abs(size * entry) + currentPositionSize = size + break + } + } + + effectiveLeverage := 0.0 + if totalInvestment > 0 { + effectiveLeverage = currentPositionValue / totalInvestment + } + + // Calculate max position based on regime + regimeLevel := market.RegimeLevel(at.gridState.CurrentRegimeLevel) + if regimeLevel == "" { + regimeLevel = market.RegimeLevelStandard + } + + // Use default position limit since GridStrategyConfig doesn't have regime-specific limits + // Default is 70% for standard regime + maxPositionPct := 70.0 + switch regimeLevel { + case market.RegimeLevelNarrow: + maxPositionPct = 40.0 + case market.RegimeLevelStandard: + maxPositionPct = 70.0 + case market.RegimeLevelWide: + maxPositionPct = 60.0 + case market.RegimeLevelVolatile: + maxPositionPct = 40.0 + } + + maxPosition := totalInvestment * maxPositionPct / 100 * float64(leverage) + + // Use default leverage limits since GridStrategyConfig doesn't have regime-specific limits + recommendedLeverage := leverage + switch regimeLevel { + case market.RegimeLevelNarrow: + recommendedLeverage = min(leverage, 2) + case market.RegimeLevelStandard: + recommendedLeverage = min(leverage, 4) + case market.RegimeLevelWide: + recommendedLeverage = min(leverage, 3) + case market.RegimeLevelVolatile: + recommendedLeverage = min(leverage, 2) + } + + // Calculate liquidation distance and price only when there's a position + var liquidationDistance float64 + var liquidationPrice float64 + if currentPositionSize != 0 && currentPrice > 0 { + liquidationDistance = 100.0 / float64(leverage) * 0.9 // ~90% of theoretical max + if currentPositionSize > 0 { + // Long position: liquidation below entry + liquidationPrice = currentPrice * (1 - liquidationDistance/100) + } else { + // Short position: liquidation above entry + liquidationPrice = currentPrice * (1 + liquidationDistance/100) + } + } + + positionPercent := 0.0 + if maxPosition > 0 { + positionPercent = currentPositionValue / maxPosition * 100 + } + + return &GridRiskInfo{ + CurrentLeverage: leverage, + EffectiveLeverage: effectiveLeverage, + RecommendedLeverage: recommendedLeverage, + + CurrentPosition: currentPositionValue, + MaxPosition: maxPosition, + PositionPercent: positionPercent, + + LiquidationPrice: liquidationPrice, + LiquidationDistance: liquidationDistance, + + RegimeLevel: string(regimeLevel), + + ShortBoxUpper: at.gridState.ShortBoxUpper, + ShortBoxLower: at.gridState.ShortBoxLower, + MidBoxUpper: at.gridState.MidBoxUpper, + MidBoxLower: at.gridState.MidBoxLower, + LongBoxUpper: at.gridState.LongBoxUpper, + LongBoxLower: at.gridState.LongBoxLower, + CurrentPrice: currentPrice, + + BreakoutLevel: at.gridState.BreakoutLevel, + BreakoutDirection: at.gridState.BreakoutDirection, + } +} + +// checkAndExecuteStopLoss checks if any filled level has exceeded stop loss and closes it +func (at *AutoTrader) checkAndExecuteStopLoss() { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig.StopLossPct <= 0 { + return // Stop loss not configured + } + + currentPrice, err := at.trader.GetMarketPrice(gridConfig.Symbol) + if err != nil { + logger.Warnf("[Grid] Failed to get market price for stop loss check: %v", err) + return + } + + at.gridState.mu.Lock() + defer at.gridState.mu.Unlock() + + for i := range at.gridState.Levels { + level := &at.gridState.Levels[i] + if level.State != "filled" || level.PositionEntry <= 0 { + continue + } + + // Calculate loss percentage + var lossPct float64 + if level.Side == "buy" { + // Long position: loss when price drops + lossPct = (level.PositionEntry - currentPrice) / level.PositionEntry * 100 + } else { + // Short position: loss when price rises + lossPct = (currentPrice - level.PositionEntry) / level.PositionEntry * 100 + } + + // Check if stop loss triggered + if lossPct >= gridConfig.StopLossPct { + logger.Warnf("[Grid] STOP LOSS TRIGGERED: Level %d, entry=$%.2f, current=$%.2f, loss=%.2f%%", + i, level.PositionEntry, currentPrice, lossPct) + + // Close the position + var closeErr error + if level.Side == "buy" { + _, closeErr = at.trader.CloseLong(gridConfig.Symbol, level.PositionSize) + } else { + _, closeErr = at.trader.CloseShort(gridConfig.Symbol, level.PositionSize) + } + + if closeErr != nil { + logger.Errorf("[Grid] Failed to execute stop loss for level %d: %v", i, closeErr) + } else { + level.State = "stopped" + realizedLoss := -lossPct * level.AllocatedUSD / 100 + level.UnrealizedPnL = realizedLoss + at.gridState.TotalTrades++ + // Update daily PnL tracking (lock already held, update directly) + at.gridState.DailyPnL += realizedLoss + at.gridState.TotalProfit += realizedLoss + logger.Infof("[Grid] Stop loss executed: Level %d closed at $%.2f (loss %.2f%%)", + i, currentPrice, lossPct) + } + } + } +} diff --git a/trader/binance_futures.go b/trader/binance_futures.go index 5a54db04..a7ef6dd0 100644 --- a/trader/binance_futures.go +++ b/trader/binance_futures.go @@ -716,6 +716,125 @@ func (t *FuturesTrader) CancelAllOrders(symbol string) error { return nil } +// PlaceLimitOrder places a limit order for grid trading +// This implements the GridTrader interface for FuturesTrader +func (t *FuturesTrader) PlaceLimitOrder(req *LimitOrderRequest) (*LimitOrderResult, error) { + // Format quantity to correct precision + quantityStr, err := t.FormatQuantity(req.Symbol, req.Quantity) + if err != nil { + return nil, fmt.Errorf("failed to format quantity: %w", err) + } + + // Format price to correct precision + priceStr, err := t.FormatPrice(req.Symbol, req.Price) + if err != nil { + return nil, fmt.Errorf("failed to format price: %w", err) + } + + // Set leverage if specified + if req.Leverage > 0 { + if err := t.SetLeverage(req.Symbol, req.Leverage); err != nil { + logger.Warnf("Failed to set leverage: %v", err) + } + } + + // Determine side and position side + var side futures.SideType + var positionSide futures.PositionSideType + + if req.Side == "BUY" { + side = futures.SideTypeBuy + positionSide = futures.PositionSideTypeLong + } else { + side = futures.SideTypeSell + positionSide = futures.PositionSideTypeShort + } + + // Build order service with broker ID + orderService := t.client.NewCreateOrderService(). + Symbol(req.Symbol). + Side(side). + PositionSide(positionSide). + Type(futures.OrderTypeLimit). + TimeInForce(futures.TimeInForceTypeGTC). + Quantity(quantityStr). + Price(priceStr). + NewClientOrderID(getBrOrderID()) + + // Execute order + order, err := orderService.Do(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to place limit order: %w", err) + } + + logger.Infof("✓ [Grid] Placed limit order: %s %s %s @ %s, qty=%s, orderID=%d", + req.Symbol, req.Side, positionSide, priceStr, quantityStr, order.OrderID) + + return &LimitOrderResult{ + OrderID: fmt.Sprintf("%d", order.OrderID), + ClientID: order.ClientOrderID, + Symbol: order.Symbol, + Side: string(order.Side), + PositionSide: string(order.PositionSide), + Price: req.Price, + Quantity: req.Quantity, + Status: string(order.Status), + }, nil +} + +// CancelOrder cancels a specific order by ID +// This implements the GridTrader interface for FuturesTrader +func (t *FuturesTrader) CancelOrder(symbol, orderID string) error { + // Parse order ID to int64 + orderIDInt, err := strconv.ParseInt(orderID, 10, 64) + if err != nil { + return fmt.Errorf("invalid order ID: %w", err) + } + + _, err = t.client.NewCancelOrderService(). + Symbol(symbol). + OrderID(orderIDInt). + Do(context.Background()) + + if err != nil { + return fmt.Errorf("failed to cancel order: %w", err) + } + + logger.Infof("✓ [Grid] Cancelled order: %s/%s", symbol, orderID) + return nil +} + +// GetOrderBook gets the order book for a symbol +// This implements the GridTrader interface for FuturesTrader +func (t *FuturesTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { + book, err := t.client.NewDepthService(). + Symbol(symbol). + Limit(depth). + Do(context.Background()) + + if err != nil { + return nil, nil, fmt.Errorf("failed to get order book: %w", err) + } + + // Convert bids + bids = make([][]float64, len(book.Bids)) + for i, bid := range book.Bids { + price, _ := strconv.ParseFloat(bid.Price, 64) + qty, _ := strconv.ParseFloat(bid.Quantity, 64) + bids[i] = []float64{price, qty} + } + + // Convert asks + asks = make([][]float64, len(book.Asks)) + for i, ask := range book.Asks { + price, _ := strconv.ParseFloat(ask.Price, 64) + qty, _ := strconv.ParseFloat(ask.Quantity, 64) + asks[i] = []float64{price, qty} + } + + return bids, asks, nil +} + // CancelStopOrders cancels take-profit/stop-loss orders for this symbol (used to adjust TP/SL positions) // Now uses both legacy API and new Algo Order API (Binance migrated stop orders to Algo system) func (t *FuturesTrader) CancelStopOrders(symbol string) error { @@ -1035,6 +1154,42 @@ func (t *FuturesTrader) FormatQuantity(symbol string, quantity float64) (string, return fmt.Sprintf(format, quantity), nil } +// GetSymbolPricePrecision gets the price precision for a trading pair +func (t *FuturesTrader) GetSymbolPricePrecision(symbol string) (int, error) { + exchangeInfo, err := t.client.NewExchangeInfoService().Do(context.Background()) + if err != nil { + return 0, fmt.Errorf("failed to get trading rules: %w", err) + } + + for _, s := range exchangeInfo.Symbols { + if s.Symbol == symbol { + // Get precision from PRICE_FILTER filter + for _, filter := range s.Filters { + if filter["filterType"] == "PRICE_FILTER" { + tickSize := filter["tickSize"].(string) + precision := calculatePrecision(tickSize) + return precision, nil + } + } + } + } + + // Default to 2 decimal places for price + return 2, nil +} + +// FormatPrice formats price to correct precision +func (t *FuturesTrader) FormatPrice(symbol string, price float64) (string, error) { + precision, err := t.GetSymbolPricePrecision(symbol) + if err != nil { + // If retrieval fails, use default format + return fmt.Sprintf("%.2f", price), nil + } + + format := fmt.Sprintf("%%.%df", precision) + return fmt.Sprintf(format, price), nil +} + // Helper functions func contains(s, substr string) bool { return len(s) >= len(substr) && stringContains(s, substr) diff --git a/trader/binance_sync_e2e_test.go b/trader/binance_sync_e2e_test.go index 9024b43b..91f42436 100644 --- a/trader/binance_sync_e2e_test.go +++ b/trader/binance_sync_e2e_test.go @@ -92,7 +92,7 @@ func TestBinanceSyncE2E(t *testing.T) { t.Logf(" [%d] %s %s %s qty=%.6f price=%.4f action=%s time=%s", i+1, order.ExchangeOrderID, order.Symbol, order.Side, order.Quantity, order.Price, order.OrderAction, - order.FilledAt.Format(time.RFC3339)) + time.UnixMilli(order.FilledAt).Format(time.RFC3339)) } } @@ -118,10 +118,11 @@ func TestBinanceSyncE2E(t *testing.T) { } // Test GetLastFillTimeByExchange - lastFillTime, err := orderStore.GetLastFillTimeByExchange(exchangeID) + lastFillTimeMs, err := orderStore.GetLastFillTimeByExchange(exchangeID) if err != nil { t.Logf(" ⚠️ GetLastFillTimeByExchange error: %v", err) } else { + lastFillTime := time.UnixMilli(lastFillTimeMs) t.Logf("\n📅 Last fill time from DB: %s", lastFillTime.Format(time.RFC3339)) // Check if it would be in the future (the bug we fixed) @@ -175,7 +176,7 @@ func TestBinanceSyncWithExistingData(t *testing.T) { Price: 50000, Quantity: 0.001, QuoteQuantity: 50, - CreatedAt: localTime, // This time is "in the future" if interpreted as UTC + CreatedAt: localTime.UnixMilli(), // This time is "in the future" if interpreted as UTC } if err := orderStore.CreateFill(fakeFill); err != nil { t.Fatalf("Failed to create fake fill: %v", err) @@ -186,10 +187,11 @@ func TestBinanceSyncWithExistingData(t *testing.T) { t.Logf(" Current UTC time: %s", time.Now().UTC().Format(time.RFC3339)) // Check GetLastFillTimeByExchange - lastFillTime, _ := orderStore.GetLastFillTimeByExchange(exchangeID) - t.Logf(" GetLastFillTimeByExchange returned: %s", lastFillTime.Format(time.RFC3339)) + lastFillTimeMs2, _ := orderStore.GetLastFillTimeByExchange(exchangeID) + lastFillTime2 := time.UnixMilli(lastFillTimeMs2) + t.Logf(" GetLastFillTimeByExchange returned: %s", lastFillTime2.Format(time.RFC3339)) - if lastFillTime.After(time.Now().UTC()) { + if lastFillTime2.After(time.Now().UTC()) { t.Logf(" ⚠️ Last fill time is in the future - this is the bug scenario!") } diff --git a/trader/bitget_trader.go b/trader/bitget_trader.go index 41f42f4a..8990df52 100644 --- a/trader/bitget_trader.go +++ b/trader/bitget_trader.go @@ -1099,6 +1099,240 @@ func genBitgetClientOid() string { // GetOpenOrders gets all open/pending orders for a symbol func (t *BitgetTrader) GetOpenOrders(symbol string) ([]OpenOrder, error) { - // TODO: Implement Bitget open orders - return []OpenOrder{}, nil + symbol = t.convertSymbol(symbol) + var result []OpenOrder + + // 1. Get pending limit orders + params := map[string]interface{}{ + "symbol": symbol, + "productType": "USDT-FUTURES", + } + + data, err := t.doRequest("GET", bitgetPendingPath, params) + if err != nil { + logger.Warnf("[Bitget] Failed to get pending orders: %v", err) + } + if err == nil && data != nil { + var orders struct { + EntrustedList []struct { + OrderId string `json:"orderId"` + Symbol string `json:"symbol"` + Side string `json:"side"` // buy/sell + TradeSide string `json:"tradeSide"` // open/close + PosSide string `json:"posSide"` // long/short + OrderType string `json:"orderType"` // limit/market + Price string `json:"price"` + Size string `json:"size"` + State string `json:"state"` + } `json:"entrustedList"` + } + if err := json.Unmarshal(data, &orders); err == nil { + for _, order := range orders.EntrustedList { + price, _ := strconv.ParseFloat(order.Price, 64) + quantity, _ := strconv.ParseFloat(order.Size, 64) + + // Convert side to standard format + side := strings.ToUpper(order.Side) + positionSide := strings.ToUpper(order.PosSide) + + result = append(result, OpenOrder{ + OrderID: order.OrderId, + Symbol: symbol, + Side: side, + PositionSide: positionSide, + Type: strings.ToUpper(order.OrderType), + Price: price, + StopPrice: 0, + Quantity: quantity, + Status: "NEW", + }) + } + } + } + + // 2. Get pending plan orders (stop-loss/take-profit) + planParams := map[string]interface{}{ + "symbol": symbol, + "productType": "USDT-FUTURES", + } + + planData, err := t.doRequest("GET", "/api/v2/mix/order/orders-plan-pending", planParams) + if err != nil { + logger.Warnf("[Bitget] Failed to get plan orders: %v", err) + } + if err == nil && planData != nil { + var planOrders struct { + EntrustedList []struct { + OrderId string `json:"orderId"` + Symbol string `json:"symbol"` + Side string `json:"side"` + PosSide string `json:"posSide"` + PlanType string `json:"planType"` // normal_plan/profit_plan/loss_plan + TriggerPrice string `json:"triggerPrice"` + Size string `json:"size"` + State string `json:"state"` + } `json:"entrustedList"` + } + if err := json.Unmarshal(planData, &planOrders); err == nil { + for _, order := range planOrders.EntrustedList { + triggerPrice, _ := strconv.ParseFloat(order.TriggerPrice, 64) + quantity, _ := strconv.ParseFloat(order.Size, 64) + + side := strings.ToUpper(order.Side) + positionSide := strings.ToUpper(order.PosSide) + + // Map Bitget plan type to order type + orderType := "STOP_MARKET" + if order.PlanType == "profit_plan" { + orderType = "TAKE_PROFIT_MARKET" + } + + result = append(result, OpenOrder{ + OrderID: order.OrderId, + Symbol: symbol, + Side: side, + PositionSide: positionSide, + Type: orderType, + Price: 0, + StopPrice: triggerPrice, + Quantity: quantity, + Status: "NEW", + }) + } + } + } + + logger.Infof("✓ BITGET GetOpenOrders: found %d open orders for %s", len(result), symbol) + return result, nil +} + +// PlaceLimitOrder places a limit order for grid trading +// Implements GridTrader interface +func (t *BitgetTrader) PlaceLimitOrder(req *LimitOrderRequest) (*LimitOrderResult, error) { + symbol := t.convertSymbol(req.Symbol) + + // Set leverage if specified + if req.Leverage > 0 { + if err := t.SetLeverage(symbol, req.Leverage); err != nil { + logger.Warnf("[Bitget] Failed to set leverage: %v", err) + } + } + + // Format quantity + qtyStr, _ := t.FormatQuantity(symbol, req.Quantity) + + // Determine side + side := "buy" + if req.Side == "SELL" { + side = "sell" + } + + body := map[string]interface{}{ + "symbol": symbol, + "productType": "USDT-FUTURES", + "marginMode": "crossed", + "marginCoin": "USDT", + "side": side, + "orderType": "limit", + "size": qtyStr, + "price": fmt.Sprintf("%.8f", req.Price), + "force": "GTC", // Good Till Cancel + "clientOid": genBitgetClientOid(), + } + + // Add reduce only if specified + if req.ReduceOnly { + body["reduceOnly"] = "YES" + } + + logger.Infof("[Bitget] PlaceLimitOrder: %s %s @ %.4f, qty=%s", symbol, side, req.Price, qtyStr) + + data, err := t.doRequest("POST", bitgetOrderPath, body) + if err != nil { + return nil, fmt.Errorf("failed to place limit order: %w", err) + } + + var order struct { + OrderId string `json:"orderId"` + ClientOid string `json:"clientOid"` + } + + if err := json.Unmarshal(data, &order); err != nil { + return nil, fmt.Errorf("failed to parse order response: %w", err) + } + + logger.Infof("✓ [Bitget] Limit order placed: %s %s @ %.4f, orderID=%s", + symbol, side, req.Price, order.OrderId) + + return &LimitOrderResult{ + OrderID: order.OrderId, + ClientID: order.ClientOid, + Symbol: req.Symbol, + Side: req.Side, + PositionSide: req.PositionSide, + Price: req.Price, + Quantity: req.Quantity, + Status: "NEW", + }, nil +} + +// CancelOrder cancels a specific order by ID +// Implements GridTrader interface +func (t *BitgetTrader) CancelOrder(symbol, orderID string) error { + symbol = t.convertSymbol(symbol) + + body := map[string]interface{}{ + "symbol": symbol, + "productType": "USDT-FUTURES", + "orderId": orderID, + } + + _, err := t.doRequest("POST", "/api/v2/mix/order/cancel-order", body) + if err != nil { + return fmt.Errorf("failed to cancel order: %w", err) + } + + logger.Infof("✓ [Bitget] Order cancelled: %s %s", symbol, orderID) + return nil +} + +// GetOrderBook gets the order book for a symbol +// Implements GridTrader interface +func (t *BitgetTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { + symbol = t.convertSymbol(symbol) + path := fmt.Sprintf("/api/v2/mix/market/depth?symbol=%s&productType=USDT-FUTURES&limit=%d", symbol, depth) + + data, err := t.doRequest("GET", path, nil) + if err != nil { + return nil, nil, fmt.Errorf("failed to get order book: %w", err) + } + + var result struct { + Bids [][]string `json:"bids"` + Asks [][]string `json:"asks"` + } + + if err := json.Unmarshal(data, &result); err != nil { + return nil, nil, fmt.Errorf("failed to parse order book: %w", err) + } + + // Parse bids + for _, b := range result.Bids { + if len(b) >= 2 { + price, _ := strconv.ParseFloat(b[0], 64) + qty, _ := strconv.ParseFloat(b[1], 64) + bids = append(bids, []float64{price, qty}) + } + } + + // Parse asks + for _, a := range result.Asks { + if len(a) >= 2 { + price, _ := strconv.ParseFloat(a[0], 64) + qty, _ := strconv.ParseFloat(a[1], 64) + asks = append(asks, []float64{price, qty}) + } + } + + return bids, asks, nil } diff --git a/trader/bybit_trader.go b/trader/bybit_trader.go index d40de870..745a58e4 100644 --- a/trader/bybit_trader.go +++ b/trader/bybit_trader.go @@ -1105,3 +1105,159 @@ func (t *BybitTrader) GetOpenOrders(symbol string) ([]OpenOrder, error) { return result, nil } + +// PlaceLimitOrder places a limit order for grid trading +// Implements GridTrader interface +func (t *BybitTrader) PlaceLimitOrder(req *LimitOrderRequest) (*LimitOrderResult, error) { + // Format quantity + qtyStr, err := t.FormatQuantity(req.Symbol, req.Quantity) + if err != nil { + return nil, fmt.Errorf("failed to format quantity: %w", err) + } + + // Format price + priceStr := fmt.Sprintf("%.8f", req.Price) + + // Set leverage if specified + if req.Leverage > 0 { + if err := t.SetLeverage(req.Symbol, req.Leverage); err != nil { + logger.Warnf("[Bybit] Failed to set leverage: %v", err) + } + } + + // Determine side + side := "Buy" + if req.Side == "SELL" { + side = "Sell" + } + + params := map[string]interface{}{ + "category": "linear", + "symbol": req.Symbol, + "side": side, + "orderType": "Limit", + "qty": qtyStr, + "price": priceStr, + "timeInForce": "GTC", // Good Till Cancel + "positionIdx": 0, // One-way position mode + } + + // Add reduce only if specified + if req.ReduceOnly { + params["reduceOnly"] = true + } + + logger.Infof("[Bybit] PlaceLimitOrder: %s %s @ %s, qty=%s", req.Symbol, side, priceStr, qtyStr) + + result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to place limit order: %w", err) + } + + // Parse result + orderID := "" + if result.RetCode == 0 { + if resultData, ok := result.Result.(map[string]interface{}); ok { + if id, ok := resultData["orderId"].(string); ok { + orderID = id + } + } + } else { + return nil, fmt.Errorf("Bybit order failed: %s", result.RetMsg) + } + + logger.Infof("✓ [Bybit] Limit order placed: %s %s @ %s, qty=%s, orderID=%s", + req.Symbol, side, priceStr, qtyStr, orderID) + + return &LimitOrderResult{ + OrderID: orderID, + ClientID: req.ClientID, + Symbol: req.Symbol, + Side: req.Side, + PositionSide: req.PositionSide, + Price: req.Price, + Quantity: req.Quantity, + Status: "NEW", + }, nil +} + +// CancelOrder cancels a specific order by ID +// Implements GridTrader interface +func (t *BybitTrader) CancelOrder(symbol, orderID string) error { + params := map[string]interface{}{ + "category": "linear", + "symbol": symbol, + "orderId": orderID, + } + + result, err := t.client.NewUtaBybitServiceWithParams(params).CancelOrder(context.Background()) + if err != nil { + return fmt.Errorf("failed to cancel order: %w", err) + } + + if result.RetCode != 0 { + return fmt.Errorf("Bybit cancel order failed: %s", result.RetMsg) + } + + logger.Infof("✓ [Bybit] Order cancelled: %s %s", symbol, orderID) + return nil +} + +// GetOrderBook gets the order book for a symbol +// Implements GridTrader interface +func (t *BybitTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { + if depth <= 0 { + depth = 25 + } + + // Use HTTP request directly since the SDK doesn't expose GetOrderbook + url := fmt.Sprintf("https://api.bybit.com/v5/market/orderbook?category=linear&symbol=%s&limit=%d", symbol, depth) + resp, err := http.Get(url) + if err != nil { + return nil, nil, fmt.Errorf("failed to get order book: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)) + } + + var result struct { + RetCode int `json:"retCode"` + RetMsg string `json:"retMsg"` + Result struct { + S string `json:"s"` // symbol + B [][]string `json:"b"` // bids [[price, size], ...] + A [][]string `json:"a"` // asks [[price, size], ...] + } `json:"result"` + } + + if err := json.Unmarshal(body, &result); err != nil { + return nil, nil, fmt.Errorf("failed to parse order book: %w", err) + } + + if result.RetCode != 0 { + return nil, nil, fmt.Errorf("Bybit get orderbook failed: %s", result.RetMsg) + } + + // Parse bids + for _, b := range result.Result.B { + if len(b) >= 2 { + price, _ := strconv.ParseFloat(b[0], 64) + qty, _ := strconv.ParseFloat(b[1], 64) + bids = append(bids, []float64{price, qty}) + } + } + + // Parse asks + for _, a := range result.Result.A { + if len(a) >= 2 { + price, _ := strconv.ParseFloat(a[0], 64) + qty, _ := strconv.ParseFloat(a[1], 64) + asks = append(asks, []float64{price, qty}) + } + } + + return bids, asks, nil +} diff --git a/trader/exchange_sync_test.go b/trader/exchange_sync_test.go index 7811d4cd..fc0c4866 100644 --- a/trader/exchange_sync_test.go +++ b/trader/exchange_sync_test.go @@ -141,7 +141,7 @@ func runStandardTests(t *testing.T, exchangeName string) { traderID, exchangeID, exchangeType, trade.Symbol, trade.Side, trade.Action, trade.Quantity, trade.Price, trade.Fee, trade.RealizedPnL, - time.Now().Add(time.Duration(i)*time.Second), + time.Now().Add(time.Duration(i)*time.Second).UnixMilli(), "", ) if err != nil { @@ -227,7 +227,7 @@ func TestPositionAccumulationBug(t *testing.T) { traderID, exchangeID, exchangeType, "ETHUSDT", "LONG", "open_long", 0.1, 3500+float64(i*10), 0.5, 0, - time.Now().Add(time.Duration(i*2)*time.Second), + time.Now().Add(time.Duration(i*2)*time.Second).UnixMilli(), "", ) if err != nil { @@ -239,7 +239,7 @@ func TestPositionAccumulationBug(t *testing.T) { traderID, exchangeID, exchangeType, "ETHUSDT", "LONG", "close_long", 0.1, 3600+float64(i*10), 0.5, 10, - time.Now().Add(time.Duration(i*2+1)*time.Second), + time.Now().Add(time.Duration(i*2+1)*time.Second).UnixMilli(), "", ) if err != nil { @@ -309,7 +309,7 @@ func TestQuantityPrecision(t *testing.T) { traderID, exchangeID, exchangeType, "BTCUSDT", "LONG", "open_long", 0.01, 50000, 1.0, 0, - time.Now(), + time.Now().UnixMilli(), "", ) if err != nil { @@ -322,7 +322,7 @@ func TestQuantityPrecision(t *testing.T) { traderID, exchangeID, exchangeType, "BTCUSDT", "LONG", "close_long", 0.00999999, 51000, 1.0, 10, - time.Now().Add(time.Second), + time.Now().Add(time.Second).UnixMilli(), "", ) if err != nil { diff --git a/trader/grid_regime.go b/trader/grid_regime.go new file mode 100644 index 00000000..e574cc1b --- /dev/null +++ b/trader/grid_regime.go @@ -0,0 +1,196 @@ +package trader + +import ( + "nofx/market" + "nofx/store" + "time" +) + +// ============================================================================ +// Task 6: Regime Level Classification +// ============================================================================ + +// classifyRegimeLevel determines the regime level based on market indicators +// bollingerWidth: Bollinger band width as percentage +// atr14Pct: ATR14 as percentage of current price +func classifyRegimeLevel(bollingerWidth, atr14Pct float64) market.RegimeLevel { + // Narrow: Bollinger < 2%, ATR < 1% + if bollingerWidth < 2.0 && atr14Pct < 1.0 { + return market.RegimeLevelNarrow + } + + // Standard: Bollinger 2-3%, ATR 1-2% + if bollingerWidth <= 3.0 && atr14Pct <= 2.0 { + return market.RegimeLevelStandard + } + + // Wide: Bollinger 3-4%, ATR 2-3% + if bollingerWidth <= 4.0 && atr14Pct <= 3.0 { + return market.RegimeLevelWide + } + + // Volatile: Bollinger > 4%, ATR > 3% + return market.RegimeLevelVolatile +} + +// getRegimeLeverageLimit returns the effective leverage limit for a regime level +func getRegimeLeverageLimit(level market.RegimeLevel, config *store.GridConfigModel) int { + switch level { + case market.RegimeLevelNarrow: + if config.NarrowRegimeLeverage > 0 { + return config.NarrowRegimeLeverage + } + return 2 + case market.RegimeLevelStandard: + if config.StandardRegimeLeverage > 0 { + return config.StandardRegimeLeverage + } + return 4 + case market.RegimeLevelWide: + if config.WideRegimeLeverage > 0 { + return config.WideRegimeLeverage + } + return 3 + case market.RegimeLevelVolatile: + if config.VolatileRegimeLeverage > 0 { + return config.VolatileRegimeLeverage + } + return 2 + default: + return 2 // Conservative default + } +} + +// getRegimePositionLimit returns the position limit percentage for a regime level +func getRegimePositionLimit(level market.RegimeLevel, config *store.GridConfigModel) float64 { + switch level { + case market.RegimeLevelNarrow: + if config.NarrowRegimePositionPct > 0 { + return config.NarrowRegimePositionPct + } + return 40.0 + case market.RegimeLevelStandard: + if config.StandardRegimePositionPct > 0 { + return config.StandardRegimePositionPct + } + return 70.0 + case market.RegimeLevelWide: + if config.WideRegimePositionPct > 0 { + return config.WideRegimePositionPct + } + return 60.0 + case market.RegimeLevelVolatile: + if config.VolatileRegimePositionPct > 0 { + return config.VolatileRegimePositionPct + } + return 40.0 + default: + return 40.0 // Conservative default + } +} + +// ============================================================================ +// Task 7: Breakout Detection +// ============================================================================ + +// detectBoxBreakout checks if price has broken out of any box level +// Returns the highest breakout level and direction +func detectBoxBreakout(box *market.BoxData) (market.BreakoutLevel, string) { + if box == nil { + return market.BreakoutNone, "" + } + + price := box.CurrentPrice + + // Check long box first (highest priority) + if price > box.LongUpper { + return market.BreakoutLong, "up" + } + if price < box.LongLower { + return market.BreakoutLong, "down" + } + + // Check mid box + if price > box.MidUpper { + return market.BreakoutMid, "up" + } + if price < box.MidLower { + return market.BreakoutMid, "down" + } + + // Check short box + if price > box.ShortUpper { + return market.BreakoutShort, "up" + } + if price < box.ShortLower { + return market.BreakoutShort, "down" + } + + return market.BreakoutNone, "" +} + +// ============================================================================ +// Task 8: Breakout Confirmation Logic +// ============================================================================ + +const BreakoutConfirmRequired = 3 // 3 candles to confirm breakout + +// BreakoutState tracks the current breakout state +type BreakoutState struct { + Level market.BreakoutLevel + Direction string + ConfirmCount int + StartTime time.Time +} + +// confirmBreakout updates breakout state and returns true if breakout is confirmed +func confirmBreakout(state *BreakoutState, currentLevel market.BreakoutLevel, direction string) bool { + // If price returned to box, reset state + if currentLevel == market.BreakoutNone { + state.ConfirmCount = 0 + state.Level = market.BreakoutNone + state.Direction = "" + return false + } + + // If same breakout continues, increment count + if state.Level == currentLevel && state.Direction == direction { + state.ConfirmCount++ + } else { + // New breakout, reset count + state.Level = currentLevel + state.Direction = direction + state.ConfirmCount = 1 + state.StartTime = time.Now() + } + + return state.ConfirmCount >= BreakoutConfirmRequired +} + +// ============================================================================ +// Task 9: Breakout Handler +// ============================================================================ + +// BreakoutAction represents the action to take on breakout +type BreakoutAction int + +const ( + BreakoutActionNone BreakoutAction = iota + BreakoutActionReducePosition // Short box breakout: reduce to 50% + BreakoutActionPauseGrid // Mid box breakout: pause grid + cancel orders + BreakoutActionCloseAll // Long box breakout: pause + cancel + close all +) + +// getBreakoutAction returns the appropriate action for a breakout level +func getBreakoutAction(level market.BreakoutLevel) BreakoutAction { + switch level { + case market.BreakoutShort: + return BreakoutActionReducePosition + case market.BreakoutMid: + return BreakoutActionPauseGrid + case market.BreakoutLong: + return BreakoutActionCloseAll + default: + return BreakoutActionNone + } +} diff --git a/trader/grid_regime_test.go b/trader/grid_regime_test.go new file mode 100644 index 00000000..25d0753a --- /dev/null +++ b/trader/grid_regime_test.go @@ -0,0 +1,122 @@ +package trader + +import ( + "nofx/market" + "testing" +) + +func TestClassifyRegimeLevel(t *testing.T) { + tests := []struct { + name string + bollingerWidth float64 + atr14Pct float64 + expected market.RegimeLevel + }{ + {"narrow", 1.5, 0.8, market.RegimeLevelNarrow}, + {"standard", 2.5, 1.5, market.RegimeLevelStandard}, + {"wide", 3.5, 2.5, market.RegimeLevelWide}, + {"volatile", 5.0, 4.0, market.RegimeLevelVolatile}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := classifyRegimeLevel(tt.bollingerWidth, tt.atr14Pct) + if result != tt.expected { + t.Errorf("Expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestDetectBoxBreakout(t *testing.T) { + box := &market.BoxData{ + ShortUpper: 100, + ShortLower: 90, + MidUpper: 105, + MidLower: 85, + LongUpper: 110, + LongLower: 80, + CurrentPrice: 95, + } + + // No breakout + level, direction := detectBoxBreakout(box) + if level != market.BreakoutNone { + t.Errorf("Expected no breakout, got %v", level) + } + + // Short breakout up + box.CurrentPrice = 101 + level, direction = detectBoxBreakout(box) + if level != market.BreakoutShort || direction != "up" { + t.Errorf("Expected short breakout up, got %v %v", level, direction) + } + + // Mid breakout down + box.CurrentPrice = 84 + level, direction = detectBoxBreakout(box) + if level != market.BreakoutMid || direction != "down" { + t.Errorf("Expected mid breakout down, got %v %v", level, direction) + } + + // Long breakout up + box.CurrentPrice = 112 + level, direction = detectBoxBreakout(box) + if level != market.BreakoutLong || direction != "up" { + t.Errorf("Expected long breakout up, got %v %v", level, direction) + } +} + +func TestBreakoutConfirmation(t *testing.T) { + state := &BreakoutState{ + Level: market.BreakoutNone, + Direction: "", + ConfirmCount: 0, + } + + // First detection + confirmed := confirmBreakout(state, market.BreakoutShort, "up") + if confirmed || state.ConfirmCount != 1 { + t.Errorf("Expected not confirmed, count=1, got confirmed=%v count=%d", confirmed, state.ConfirmCount) + } + + // Second confirmation + confirmed = confirmBreakout(state, market.BreakoutShort, "up") + if confirmed || state.ConfirmCount != 2 { + t.Errorf("Expected not confirmed, count=2, got confirmed=%v count=%d", confirmed, state.ConfirmCount) + } + + // Third confirmation - should confirm + confirmed = confirmBreakout(state, market.BreakoutShort, "up") + if !confirmed || state.ConfirmCount != 3 { + t.Errorf("Expected confirmed, count=3, got confirmed=%v count=%d", confirmed, state.ConfirmCount) + } + + // Reset on price return + state.ConfirmCount = 2 + confirmed = confirmBreakout(state, market.BreakoutNone, "") + if state.ConfirmCount != 0 { + t.Errorf("Expected count reset to 0, got %d", state.ConfirmCount) + } +} + +func TestGetBreakoutAction(t *testing.T) { + tests := []struct { + level market.BreakoutLevel + expected BreakoutAction + }{ + {market.BreakoutNone, BreakoutActionNone}, + {market.BreakoutShort, BreakoutActionReducePosition}, + {market.BreakoutMid, BreakoutActionPauseGrid}, + {market.BreakoutLong, BreakoutActionCloseAll}, + } + + for _, tt := range tests { + t.Run(string(tt.level), func(t *testing.T) { + action := getBreakoutAction(tt.level) + if action != tt.expected { + t.Errorf("Expected %v, got %v", tt.expected, action) + } + }) + } +} diff --git a/trader/hyperliquid_sync_test.go b/trader/hyperliquid_sync_test.go index fce8fdc8..4eb4bd81 100644 --- a/trader/hyperliquid_sync_test.go +++ b/trader/hyperliquid_sync_test.go @@ -103,7 +103,7 @@ func TestHyperliquidPositionBuilding(t *testing.T) { traderID, exchangeID, exchangeType, symbol, "LONG", "open_long", 0.1, 3500, 0.5, 0, - time.Now(), "order-1", + time.Now().UnixMilli(), "order-1", ) if err != nil { t.Fatalf("Failed to process open long: %v", err) @@ -126,7 +126,7 @@ func TestHyperliquidPositionBuilding(t *testing.T) { traderID, exchangeID, exchangeType, symbol, "LONG", "close_long", 0.1, 3600, 0.5, 10.0, // PnL = (3600-3500)*0.1 = 10 - time.Now(), "order-2", + time.Now().UnixMilli(), "order-2", ) if err != nil { t.Fatalf("Failed to process close long: %v", err) @@ -152,7 +152,7 @@ func TestHyperliquidPositionBuilding(t *testing.T) { traderID, exchangeID, exchangeType, symbol, "SHORT", "open_short", 0.05, 3500, 0.25, 0, - time.Now(), "order-3", + time.Now().UnixMilli(), "order-3", ) if err != nil { t.Fatalf("Failed to process open short: %v", err) @@ -176,7 +176,7 @@ func TestHyperliquidPositionBuilding(t *testing.T) { traderID, exchangeID, exchangeType, symbol, "SHORT", "close_short", 0.05, 3400, 0.25, 5.0, // PnL = (3500-3400)*0.05 = 5 - time.Now(), "order-4", + time.Now().UnixMilli(), "order-4", ) if err != nil { t.Fatalf("Failed to process close short: %v", err) @@ -205,7 +205,7 @@ func TestHyperliquidPositionBuilding(t *testing.T) { traderID, exchangeID, exchangeType, symbol, "LONG", "open_long", 0.1, 3500, 0.5, 0, - time.Now(), "order-5", + time.Now().UnixMilli(), "order-5", ) if err != nil { t.Fatalf("Failed to process first open: %v", err) @@ -216,7 +216,7 @@ func TestHyperliquidPositionBuilding(t *testing.T) { traderID, exchangeID, exchangeType, symbol, "LONG", "open_long", 0.1, 3600, 0.5, 0, - time.Now(), "order-6", + time.Now().UnixMilli(), "order-6", ) if err != nil { t.Fatalf("Failed to process add position: %v", err) @@ -243,7 +243,7 @@ func TestHyperliquidPositionBuilding(t *testing.T) { traderID, exchangeID, exchangeType, symbol, "LONG", "close_long", 0.2, 3700, 1.0, 30.0, - time.Now(), "order-7", + time.Now().UnixMilli(), "order-7", ) if err != nil { t.Fatalf("Failed to process close: %v", err) @@ -269,7 +269,7 @@ func TestHyperliquidPositionBuilding(t *testing.T) { traderID, exchangeID, exchangeType, symbol, "LONG", "open_long", 1.0, 3500, 2.0, 0, - time.Now(), "order-8", + time.Now().UnixMilli(), "order-8", ) if err != nil { t.Fatalf("Failed to process open: %v", err) @@ -280,7 +280,7 @@ func TestHyperliquidPositionBuilding(t *testing.T) { traderID, exchangeID, exchangeType, symbol, "LONG", "close_long", 0.3, 3600, 0.6, 30.0, - time.Now(), "order-9", + time.Now().UnixMilli(), "order-9", ) if err != nil { t.Fatalf("Failed to process partial close: %v", err) @@ -351,7 +351,7 @@ func TestHyperliquidBugScenario(t *testing.T) { traderID, exchangeID, exchangeType, trade.symbol, trade.side, trade.action, trade.qty, trade.price, trade.fee, trade.pnl, - time.Now().Add(time.Duration(i)*time.Second), + time.Now().Add(time.Duration(i)*time.Second).UnixMilli(), "", ) if err != nil { diff --git a/trader/hyperliquid_trader.go b/trader/hyperliquid_trader.go index 354acd99..aa89f39e 100644 --- a/trader/hyperliquid_trader.go +++ b/trader/hyperliquid_trader.go @@ -2114,3 +2114,118 @@ func (t *HyperliquidTrader) GetOpenOrders(symbol string) ([]OpenOrder, error) { return result, nil } + +// PlaceLimitOrder places a limit order for grid trading +// Implements GridTrader interface +func (t *HyperliquidTrader) PlaceLimitOrder(req *LimitOrderRequest) (*LimitOrderResult, error) { + coin := convertSymbolToHyperliquid(req.Symbol) + + // Set leverage if specified and not xyz dex + isXyz := strings.HasPrefix(coin, "xyz:") + if req.Leverage > 0 && !isXyz { + if err := t.SetLeverage(req.Symbol, req.Leverage); err != nil { + logger.Warnf("[Hyperliquid] Failed to set leverage: %v", err) + } + } + + // Round quantity to allowed decimals + roundedQuantity := t.roundToSzDecimals(coin, req.Quantity) + + // Round price to 5 significant figures + roundedPrice := t.roundPriceToSigfigs(req.Price) + + // Determine if buy or sell + isBuy := req.Side == "BUY" + + logger.Infof("[Hyperliquid] PlaceLimitOrder: %s %s @ %.4f, qty=%.4f", coin, req.Side, roundedPrice, roundedQuantity) + + order := hyperliquid.CreateOrderRequest{ + Coin: coin, + IsBuy: isBuy, + Size: roundedQuantity, + Price: roundedPrice, + OrderType: hyperliquid.OrderType{ + Limit: &hyperliquid.LimitOrderType{ + Tif: hyperliquid.TifGtc, // Good Till Cancel for grid orders + }, + }, + ReduceOnly: req.ReduceOnly, + } + + _, err := t.exchange.Order(t.ctx, order, defaultBuilder) + if err != nil { + return nil, fmt.Errorf("failed to place limit order: %w", err) + } + + // Note: Hyperliquid's Order response doesn't return the order ID directly + // We would need to query open orders to get it, but for grid trading + // we can track orders by price level instead + orderID := fmt.Sprintf("%d", time.Now().UnixNano()) + + logger.Infof("✓ [Hyperliquid] Limit order placed: %s %s @ %.4f", + coin, req.Side, roundedPrice) + + return &LimitOrderResult{ + OrderID: orderID, + ClientID: req.ClientID, + Symbol: req.Symbol, + Side: req.Side, + PositionSide: req.PositionSide, + Price: roundedPrice, + Quantity: roundedQuantity, + Status: "NEW", + }, nil +} + +// CancelOrder cancels a specific order by ID +// Implements GridTrader interface +func (t *HyperliquidTrader) CancelOrder(symbol, orderID string) error { + coin := convertSymbolToHyperliquid(symbol) + + // Parse order ID + oid, err := strconv.ParseInt(orderID, 10, 64) + if err != nil { + return fmt.Errorf("invalid order ID: %w", err) + } + + _, err = t.exchange.Cancel(t.ctx, coin, oid) + if err != nil { + return fmt.Errorf("failed to cancel order: %w", err) + } + + logger.Infof("✓ [Hyperliquid] Order cancelled: %s %s", symbol, orderID) + return nil +} + +// GetOrderBook gets the order book for a symbol +// Implements GridTrader interface +func (t *HyperliquidTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { + coin := convertSymbolToHyperliquid(symbol) + + l2Book, err := t.exchange.Info().L2Snapshot(t.ctx, coin) + if err != nil { + return nil, nil, fmt.Errorf("failed to get order book: %w", err) + } + + if l2Book == nil || len(l2Book.Levels) < 2 { + return nil, nil, fmt.Errorf("invalid order book data") + } + + // Parse bids (first level array) + for i, level := range l2Book.Levels[0] { + if i >= depth { + break + } + bids = append(bids, []float64{level.Px, level.Sz}) + } + + // Parse asks (second level array) + for i, level := range l2Book.Levels[1] { + if i >= depth { + break + } + asks = append(asks, []float64{level.Px, level.Sz}) + } + + return bids, asks, nil +} diff --git a/trader/interface.go b/trader/interface.go index 35618633..741e6e31 100644 --- a/trader/interface.go +++ b/trader/interface.go @@ -1,6 +1,10 @@ package trader -import "time" +import ( + "fmt" + "nofx/logger" + "time" +) // ClosedPnLRecord represents a single closed position record from exchange type ClosedPnLRecord struct { @@ -112,3 +116,115 @@ type OpenOrder struct { Quantity float64 `json:"quantity"` Status string `json:"status"` // NEW } + +// LimitOrderRequest represents a limit order request for grid trading +type LimitOrderRequest struct { + Symbol string `json:"symbol"` + Side string `json:"side"` // BUY/SELL + PositionSide string `json:"position_side"` // LONG/SHORT (for hedge mode) + Price float64 `json:"price"` // Limit price + Quantity float64 `json:"quantity"` + Leverage int `json:"leverage"` + PostOnly bool `json:"post_only"` // Maker only order + ReduceOnly bool `json:"reduce_only"` // Reduce position only + ClientID string `json:"client_id"` // Client order ID for tracking +} + +// LimitOrderResult represents the result of placing a limit order +type LimitOrderResult struct { + OrderID string `json:"order_id"` + ClientID string `json:"client_id"` + Symbol string `json:"symbol"` + Side string `json:"side"` + PositionSide string `json:"position_side"` + Price float64 `json:"price"` + Quantity float64 `json:"quantity"` + Status string `json:"status"` // NEW, PARTIALLY_FILLED, FILLED, CANCELED +} + +// GridTrader extends Trader interface with limit order support for grid trading +// Exchanges that support grid trading should implement this interface +type GridTrader interface { + Trader + + // PlaceLimitOrder places a limit order at specified price + // Returns order ID and status + PlaceLimitOrder(req *LimitOrderRequest) (*LimitOrderResult, error) + + // CancelOrder cancels a specific order by ID + CancelOrder(symbol, orderID string) error + + // GetOrderBook gets current order book (for price validation) + // Returns best bid/ask prices + GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) +} + +// GridTraderAdapter wraps a basic Trader to provide GridTrader interface +// Uses stop orders as a fallback when limit orders aren't directly available +type GridTraderAdapter struct { + Trader +} + +// NewGridTraderAdapter creates an adapter for basic Trader +func NewGridTraderAdapter(t Trader) *GridTraderAdapter { + return &GridTraderAdapter{Trader: t} +} + +// PlaceLimitOrder implements limit order using available methods +// For exchanges without native limit order support, this uses conditional orders +func (a *GridTraderAdapter) PlaceLimitOrder(req *LimitOrderRequest) (*LimitOrderResult, error) { + // CRITICAL FIX: Set leverage before placing order + if req.Leverage > 0 { + if err := a.Trader.SetLeverage(req.Symbol, req.Leverage); err != nil { + logger.Warnf("[Grid] Failed to set leverage %dx: %v", req.Leverage, err) + // Continue anyway - some exchanges don't require explicit leverage setting + } + } + + // Use SetStopLoss/SetTakeProfit as conditional limit orders + // For buy orders below current price, use stop-loss mechanism + // For sell orders above current price, use take-profit mechanism + var err error + if req.Side == "BUY" { + err = a.Trader.SetStopLoss(req.Symbol, "SHORT", req.Quantity, req.Price) + } else { + err = a.Trader.SetTakeProfit(req.Symbol, "LONG", req.Quantity, req.Price) + } + if err != nil { + return nil, err + } + return &LimitOrderResult{ + OrderID: req.ClientID, + ClientID: req.ClientID, + Symbol: req.Symbol, + Side: req.Side, + PositionSide: req.PositionSide, + Price: req.Price, + Quantity: req.Quantity, + Status: "NEW", + }, nil +} + +// CancelOrder cancels a specific order +func (a *GridTraderAdapter) CancelOrder(symbol, orderID string) error { + // Try to use CancelOrder if trader supports it directly + if canceler, ok := a.Trader.(interface { + CancelOrder(symbol, orderID string) error + }); ok { + return canceler.CancelOrder(symbol, orderID) + } + + // For traders that only support CancelAllOrders, log a warning + // This is a limitation - we cannot cancel individual orders + logger.Warnf("[Grid] Trader does not support individual order cancellation, "+ + "cannot cancel order %s. Consider using exchange-specific GridTrader implementation.", orderID) + + // Return error instead of canceling all orders + return fmt.Errorf("individual order cancellation not supported for this exchange") +} + +// GetOrderBook returns empty order book (not supported in basic Trader) +func (a *GridTraderAdapter) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { + // Not supported, return empty + return nil, nil, nil +} diff --git a/trader/lighter_integration_test.go b/trader/lighter_integration_test.go index c8c414e7..11281201 100644 --- a/trader/lighter_integration_test.go +++ b/trader/lighter_integration_test.go @@ -1,25 +1,41 @@ package trader import ( + "fmt" "os" "strings" "testing" "time" ) -// Test configuration - uses real account -// Run with: LIGHTER_TEST=1 go test -v ./trader -run TestLighter -timeout 120s -const ( - testWalletAddr = "" - testAPIKeyPrivateKey = "" - testAPIKeyIndex = 0 - testAccountIndex = int64(681514) -) +// Test configuration - uses environment variables for security +// Run with: +// LIGHTER_TEST=1 LIGHTER_WALLET=0x... LIGHTER_API_KEY=... LIGHTER_API_KEY_INDEX=2 go test -v ./trader -run TestLighter -timeout 300s +// Run with trading: +// LIGHTER_TEST=1 LIGHTER_TRADE_TEST=1 LIGHTER_WALLET=0x... LIGHTER_API_KEY=... go test -v ./trader -run TestLighter -timeout 300s + +// getTestConfig returns test configuration from environment variables +func getTestConfig() (walletAddr, apiKey string, apiKeyIndex int) { + walletAddr = os.Getenv("LIGHTER_WALLET") + apiKey = os.Getenv("LIGHTER_API_KEY") + // All credentials must be provided via environment variables for security + apiKeyIndex = 2 // Default to index 2 (more stable than index 0) + if idx := os.Getenv("LIGHTER_API_KEY_INDEX"); idx != "" { + fmt.Sscanf(idx, "%d", &apiKeyIndex) + } + return +} func skipIfNoEnv(t *testing.T) { if os.Getenv("LIGHTER_TEST") != "1" { t.Skip("Skipping Lighter integration test. Set LIGHTER_TEST=1 to run") } + if os.Getenv("LIGHTER_WALLET") == "" { + t.Skip("Skipping: LIGHTER_WALLET environment variable not set") + } + if os.Getenv("LIGHTER_API_KEY") == "" { + t.Skip("Skipping: LIGHTER_API_KEY environment variable not set") + } } // skipIfJurisdictionRestricted checks if error is due to geographic restriction @@ -31,7 +47,8 @@ func skipIfJurisdictionRestricted(t *testing.T, err error) { } func createTestTrader(t *testing.T) *LighterTraderV2 { - trader, err := NewLighterTraderV2(testWalletAddr, testAPIKeyPrivateKey, testAPIKeyIndex, false) + walletAddr, apiKey, apiKeyIndex := getTestConfig() + trader, err := NewLighterTraderV2(walletAddr, apiKey, apiKeyIndex, false) if err != nil { t.Fatalf("Failed to create trader: %v", err) } @@ -46,9 +63,9 @@ func TestLighterAccountInit(t *testing.T) { trader := createTestTrader(t) defer trader.Cleanup() - // Verify account index - if trader.accountIndex != testAccountIndex { - t.Errorf("Expected account index %d, got %d", testAccountIndex, trader.accountIndex) + // Verify account index is valid (non-zero) + if trader.accountIndex <= 0 { + t.Errorf("Expected valid account index, got %d", trader.accountIndex) } t.Logf("✅ Account initialized: index=%d", trader.accountIndex) @@ -253,11 +270,11 @@ func TestLighterCreateAndCancelLimitOrder(t *testing.T) { t.Fatalf("CreateOrder failed: %v", err) } - orderID, _ := result["order_id"].(string) + orderID, _ := result["orderId"].(string) t.Logf("✅ Order created: %s", orderID) if orderID == "" { - t.Fatal("Expected order ID in response") + t.Fatal("Expected orderId in response") } // Wait a moment for order to be processed @@ -517,11 +534,12 @@ func TestLighterOrderSync(t *testing.T) { // ==================== Benchmark Tests ==================== func BenchmarkLighterGetBalance(b *testing.B) { - if os.Getenv("LIGHTER_TEST") != "1" { - b.Skip("Skipping benchmark. Set LIGHTER_TEST=1 to run") + if os.Getenv("LIGHTER_TEST") != "1" || os.Getenv("LIGHTER_API_KEY") == "" { + b.Skip("Skipping benchmark. Set LIGHTER_TEST=1 and LIGHTER_API_KEY to run") } - trader, err := NewLighterTraderV2(testWalletAddr, testAPIKeyPrivateKey, testAPIKeyIndex, false) + walletAddr, apiKey, apiKeyIndex := getTestConfig() + trader, err := NewLighterTraderV2(walletAddr, apiKey, apiKeyIndex, false) if err != nil { b.Fatalf("Failed to create trader: %v", err) } @@ -537,11 +555,12 @@ func BenchmarkLighterGetBalance(b *testing.B) { } func BenchmarkLighterGetMarketPrice(b *testing.B) { - if os.Getenv("LIGHTER_TEST") != "1" { - b.Skip("Skipping benchmark. Set LIGHTER_TEST=1 to run") + if os.Getenv("LIGHTER_TEST") != "1" || os.Getenv("LIGHTER_API_KEY") == "" { + b.Skip("Skipping benchmark. Set LIGHTER_TEST=1 and LIGHTER_API_KEY to run") } - trader, err := NewLighterTraderV2(testWalletAddr, testAPIKeyPrivateKey, testAPIKeyIndex, false) + walletAddr, apiKey, apiKeyIndex := getTestConfig() + trader, err := NewLighterTraderV2(walletAddr, apiKey, apiKeyIndex, false) if err != nil { b.Fatalf("Failed to create trader: %v", err) } @@ -555,3 +574,533 @@ func BenchmarkLighterGetMarketPrice(b *testing.B) { } } } + +// ==================== GetOpenOrders Tests ==================== + +func TestLighterGetOpenOrders(t *testing.T) { + skipIfNoEnv(t) + + trader := createTestTrader(t) + defer trader.Cleanup() + + // Test GetOpenOrders + orders, err := trader.GetOpenOrders("ETH") + skipIfJurisdictionRestricted(t, err) + if err != nil { + t.Fatalf("GetOpenOrders failed: %v", err) + } + + t.Logf("✅ GetOpenOrders: found %d open orders", len(orders)) + for i, order := range orders { + if i >= 5 { + t.Logf(" ... and %d more", len(orders)-5) + break + } + t.Logf(" [%d] %s %s %s: qty=%.4f @ %.2f, status=%s", + i+1, order.Symbol, order.Side, order.Type, order.Quantity, order.Price, order.Status) + } +} + +func TestLighterGetActiveOrders(t *testing.T) { + skipIfNoEnv(t) + + trader := createTestTrader(t) + defer trader.Cleanup() + + // Test GetActiveOrders (internal API) + orders, err := trader.GetActiveOrders("ETH") + skipIfJurisdictionRestricted(t, err) + if err != nil { + t.Fatalf("GetActiveOrders failed: %v", err) + } + + t.Logf("✅ GetActiveOrders: found %d active orders", len(orders)) + for i, order := range orders { + if i >= 5 { + t.Logf(" ... and %d more", len(orders)-5) + break + } + t.Logf(" [%d] OrderID=%s, Type=%s, Price=%s, RemainingAmount=%s", + i+1, order.OrderID, order.Type, order.Price, order.RemainingBaseAmount) + } +} + +// ==================== OrderBook Tests ==================== + +func TestLighterGetOrderBook(t *testing.T) { + skipIfNoEnv(t) + + trader := createTestTrader(t) + defer trader.Cleanup() + + // Test GetOrderBook + bids, asks, err := trader.GetOrderBook("ETH", 10) + if err != nil { + // OrderBook API may not be available in all regions or require special permissions + if strings.Contains(err.Error(), "403") || strings.Contains(err.Error(), "restricted") { + t.Skipf("Skipping: OrderBook API not available: %v", err) + } + t.Fatalf("GetOrderBook failed: %v", err) + } + + t.Logf("✅ GetOrderBook: %d bids, %d asks", len(bids), len(asks)) + + if len(bids) > 0 { + t.Logf(" Best Bid: %.2f @ %.4f", bids[0][0], bids[0][1]) + } + if len(asks) > 0 { + t.Logf(" Best Ask: %.2f @ %.4f", asks[0][0], asks[0][1]) + } + + // Verify spread makes sense + if len(bids) > 0 && len(asks) > 0 { + spread := asks[0][0] - bids[0][0] + spreadPct := spread / bids[0][0] * 100 + t.Logf(" Spread: %.2f (%.4f%%)", spread, spreadPct) + + if spread < 0 { + t.Error("Invalid spread: ask < bid") + } + } +} + +// ==================== PlaceLimitOrder (GridTrader) Tests ==================== + +func TestLighterPlaceLimitOrder(t *testing.T) { + skipIfNoEnv(t) + + trader := createTestTrader(t) + defer trader.Cleanup() + + // Get current market price + marketPrice, err := trader.GetMarketPrice("ETH") + if err != nil { + t.Fatalf("Failed to get market price: %v", err) + } + t.Logf("Current ETH price: %.2f", marketPrice) + + // Create a limit order using PlaceLimitOrder (GridTrader interface) + // Buy order at 75% of market price (won't fill) + limitPrice := marketPrice * 0.75 + quantity := 0.01 + + req := &LimitOrderRequest{ + Symbol: "ETH", + Side: "BUY", + PositionSide: "LONG", + Price: limitPrice, + Quantity: quantity, + Leverage: 10, + ClientID: "test-order-001", + ReduceOnly: false, + } + + t.Logf("Placing limit order via PlaceLimitOrder: %s %.4f @ %.2f", req.Side, req.Quantity, req.Price) + + result, err := trader.PlaceLimitOrder(req) + skipIfJurisdictionRestricted(t, err) + if err != nil { + t.Fatalf("PlaceLimitOrder failed: %v", err) + } + + t.Logf("✅ PlaceLimitOrder result: OrderID=%s, Status=%s", result.OrderID, result.Status) + + if result.OrderID == "" { + t.Fatal("Expected OrderID in result") + } + + // Wait and cancel + time.Sleep(3 * time.Second) + + // Cancel the order + err = trader.CancelOrder("ETH", result.OrderID) + if err != nil { + t.Logf("⚠️ Failed to cancel order: %v", err) + } else { + t.Log("✅ Order cancelled successfully") + } +} + +// ==================== SetMarginMode Tests ==================== + +func TestLighterSetMarginMode(t *testing.T) { + skipIfNoEnv(t) + + trader := createTestTrader(t) + defer trader.Cleanup() + + // Test setting cross margin + t.Log("Setting margin mode to CROSS...") + err := trader.SetMarginMode("ETH", true) + skipIfJurisdictionRestricted(t, err) + if err != nil { + t.Errorf("SetMarginMode(cross) failed: %v", err) + } else { + t.Log("✅ SetMarginMode(cross) succeeded") + } + + time.Sleep(2 * time.Second) + + // Note: Isolated margin may fail if there's an open position + // Just test cross margin for safety +} + +// ==================== Stop-Loss/Take-Profit Tests ==================== + +func TestLighterStopLossOrder(t *testing.T) { + skipIfNoEnv(t) + + if os.Getenv("LIGHTER_TRADE_TEST") != "1" { + t.Skip("Skipping stop-loss test. Set LIGHTER_TRADE_TEST=1 to run") + } + + trader := createTestTrader(t) + defer trader.Cleanup() + + // Check if we have a position first + pos, err := trader.GetPosition("ETH") + if err != nil { + t.Fatalf("GetPosition failed: %v", err) + } + + if pos == nil || pos.Size == 0 { + t.Skip("No ETH position to set stop-loss for") + } + + // Calculate stop-loss price (5% below entry for long, 5% above for short) + var stopPrice float64 + if pos.Side == "long" { + stopPrice = pos.EntryPrice * 0.95 + } else { + stopPrice = pos.EntryPrice * 1.05 + } + + t.Logf("Position: %s %s, size=%.4f, entry=%.2f", pos.Symbol, pos.Side, pos.Size, pos.EntryPrice) + t.Logf("Setting stop-loss at %.2f", stopPrice) + + err = trader.SetStopLoss("ETH", strings.ToUpper(pos.Side), pos.Size, stopPrice) + skipIfJurisdictionRestricted(t, err) + if err != nil { + t.Errorf("SetStopLoss failed: %v", err) + } else { + t.Log("✅ SetStopLoss succeeded") + } +} + +func TestLighterTakeProfitOrder(t *testing.T) { + skipIfNoEnv(t) + + if os.Getenv("LIGHTER_TRADE_TEST") != "1" { + t.Skip("Skipping take-profit test. Set LIGHTER_TRADE_TEST=1 to run") + } + + trader := createTestTrader(t) + defer trader.Cleanup() + + // Check if we have a position first + pos, err := trader.GetPosition("ETH") + if err != nil { + t.Fatalf("GetPosition failed: %v", err) + } + + if pos == nil || pos.Size == 0 { + t.Skip("No ETH position to set take-profit for") + } + + // Calculate take-profit price (10% above entry for long, 10% below for short) + var takeProfitPrice float64 + if pos.Side == "long" { + takeProfitPrice = pos.EntryPrice * 1.10 + } else { + takeProfitPrice = pos.EntryPrice * 0.90 + } + + t.Logf("Position: %s %s, size=%.4f, entry=%.2f", pos.Symbol, pos.Side, pos.Size, pos.EntryPrice) + t.Logf("Setting take-profit at %.2f", takeProfitPrice) + + err = trader.SetTakeProfit("ETH", strings.ToUpper(pos.Side), pos.Size, takeProfitPrice) + skipIfJurisdictionRestricted(t, err) + if err != nil { + t.Errorf("SetTakeProfit failed: %v", err) + } else { + t.Log("✅ SetTakeProfit succeeded") + } +} + +// ==================== Full Trading Flow Tests ==================== + +func TestLighterFullTradingFlow(t *testing.T) { + skipIfNoEnv(t) + + if os.Getenv("LIGHTER_TRADE_TEST") != "1" { + t.Skip("Skipping full trading flow test. Set LIGHTER_TRADE_TEST=1 to run") + } + + trader := createTestTrader(t) + defer trader.Cleanup() + + symbol := "ETH" + quantity := 0.01 // Minimum quantity + leverage := 10 + + // Step 1: Get initial state + t.Log("=== Step 1: Get Initial State ===") + balance, _ := trader.GetBalance() + if equity, ok := balance["total_equity"].(float64); ok { + t.Logf(" Initial equity: %.2f", equity) + } + + marketPrice, err := trader.GetMarketPrice(symbol) + if err != nil { + t.Fatalf("Failed to get market price: %v", err) + } + t.Logf(" Market price: %.2f", marketPrice) + + // Step 2: Set leverage + t.Log("=== Step 2: Set Leverage ===") + err = trader.SetLeverage(symbol, leverage) + skipIfJurisdictionRestricted(t, err) + if err != nil { + t.Fatalf("SetLeverage failed: %v", err) + } + t.Logf(" Leverage set to %dx", leverage) + time.Sleep(2 * time.Second) + + // Step 3: Open Long Position + t.Log("=== Step 3: Open Long Position ===") + result, err := trader.OpenLong(symbol, quantity, leverage) + skipIfJurisdictionRestricted(t, err) + if err != nil { + t.Fatalf("OpenLong failed: %v", err) + } + t.Logf(" OpenLong result: %v", result) + time.Sleep(3 * time.Second) + + // Step 4: Verify position + t.Log("=== Step 4: Verify Position ===") + pos, err := trader.GetPosition(symbol) + if err != nil { + t.Errorf("GetPosition failed: %v", err) + } else if pos != nil { + t.Logf(" Position: %s %s, size=%.4f, entry=%.2f, pnl=%.2f", + pos.Symbol, pos.Side, pos.Size, pos.EntryPrice, pos.UnrealizedPnL) + } + + // Step 5: Place limit order (sell at higher price) + t.Log("=== Step 5: Place Limit Sell Order ===") + limitPrice := marketPrice * 1.05 // 5% above market + limitResult, err := trader.CreateOrder(symbol, true, quantity, limitPrice, "limit", true) + if err != nil { + t.Logf(" Failed to place limit order: %v", err) + } else { + t.Logf(" Limit order placed: %v", limitResult) + } + time.Sleep(2 * time.Second) + + // Step 6: Get open orders + t.Log("=== Step 6: Get Open Orders ===") + orders, err := trader.GetOpenOrders(symbol) + if err != nil { + t.Logf(" Failed to get open orders: %v", err) + } else { + t.Logf(" Open orders: %d", len(orders)) + for _, o := range orders { + t.Logf(" - %s %s: qty=%.4f @ %.2f", o.Side, o.Type, o.Quantity, o.Price) + } + } + + // Step 7: Cancel all orders + t.Log("=== Step 7: Cancel All Orders ===") + err = trader.CancelAllOrders(symbol) + if err != nil { + t.Logf(" Failed to cancel orders: %v", err) + } else { + t.Log(" All orders cancelled") + } + time.Sleep(2 * time.Second) + + // Step 8: Close position + t.Log("=== Step 8: Close Position ===") + closeResult, err := trader.CloseLong(symbol, 0) // 0 = close all + if err != nil { + t.Errorf("CloseLong failed: %v", err) + } else { + t.Logf(" CloseLong result: %v", closeResult) + } + time.Sleep(3 * time.Second) + + // Step 9: Verify position closed + t.Log("=== Step 9: Verify Position Closed ===") + pos, _ = trader.GetPosition(symbol) + if pos == nil || pos.Size == 0 { + t.Log(" ✅ Position closed successfully") + } else { + t.Logf(" ⚠️ Position still exists: size=%.4f", pos.Size) + } + + // Step 10: Get final balance + t.Log("=== Step 10: Get Final State ===") + balance, _ = trader.GetBalance() + if equity, ok := balance["total_equity"].(float64); ok { + t.Logf(" Final equity: %.2f", equity) + } + + t.Log("=== Full Trading Flow Completed ===") +} + +// ==================== API Key Validation Tests ==================== + +func TestLighterAPIKeyValid(t *testing.T) { + skipIfNoEnv(t) + + trader := createTestTrader(t) + defer trader.Cleanup() + + // Check if API key is valid + if trader.apiKeyValid { + t.Log("✅ API key is VALID and matches server") + } else { + t.Error("❌ API key is INVALID - does not match server") + } + + // Verify by checking the actual API key + err := trader.checkClient() + if err != nil { + t.Errorf("API key verification error: %v", err) + } else { + t.Log("✅ API key verification passed") + } +} + +// ==================== Market Order Tests ==================== + +func TestLighterMarketOrderBuy(t *testing.T) { + skipIfNoEnv(t) + + if os.Getenv("LIGHTER_TRADE_TEST") != "1" { + t.Skip("Skipping market order test. Set LIGHTER_TRADE_TEST=1 to run") + } + + trader := createTestTrader(t) + defer trader.Cleanup() + + // Create a small market buy order + quantity := 0.01 + t.Logf("Creating market buy order: %.4f ETH", quantity) + + result, err := trader.CreateOrder("ETH", false, quantity, 0, "market", false) + skipIfJurisdictionRestricted(t, err) + if err != nil { + t.Fatalf("Market buy failed: %v", err) + } + + t.Logf("✅ Market buy result: %v", result) + + // Wait and close + time.Sleep(3 * time.Second) + + // Close the position + _, err = trader.CloseLong("ETH", quantity) + if err != nil { + t.Logf("⚠️ Failed to close position: %v", err) + } else { + t.Log("✅ Position closed") + } +} + +func TestLighterMarketOrderSell(t *testing.T) { + skipIfNoEnv(t) + + if os.Getenv("LIGHTER_TRADE_TEST") != "1" { + t.Skip("Skipping market order test. Set LIGHTER_TRADE_TEST=1 to run") + } + + trader := createTestTrader(t) + defer trader.Cleanup() + + // Create a small market sell order (short) + quantity := 0.01 + t.Logf("Creating market sell order (short): %.4f ETH", quantity) + + result, err := trader.CreateOrder("ETH", true, quantity, 0, "market", false) + skipIfJurisdictionRestricted(t, err) + if err != nil { + t.Fatalf("Market sell failed: %v", err) + } + + t.Logf("✅ Market sell result: %v", result) + + // Wait and close + time.Sleep(3 * time.Second) + + // Close the position + _, err = trader.CloseShort("ETH", quantity) + if err != nil { + t.Logf("⚠️ Failed to close position: %v", err) + } else { + t.Log("✅ Position closed") + } +} + +// ==================== GetPosition Tests ==================== + +func TestLighterGetPosition(t *testing.T) { + skipIfNoEnv(t) + + trader := createTestTrader(t) + defer trader.Cleanup() + + // Test GetPosition for ETH + pos, err := trader.GetPosition("ETH") + if err != nil { + t.Fatalf("GetPosition failed: %v", err) + } + + if pos == nil { + t.Log("✅ No ETH position (pos is nil)") + } else if pos.Size == 0 { + t.Log("✅ No ETH position (size is 0)") + } else { + t.Logf("✅ ETH position found:") + t.Logf(" Symbol: %s", pos.Symbol) + t.Logf(" Side: %s", pos.Side) + t.Logf(" Size: %.4f", pos.Size) + t.Logf(" Entry Price: %.2f", pos.EntryPrice) + t.Logf(" Mark Price: %.2f", pos.MarkPrice) + t.Logf(" Liquidation Price: %.2f", pos.LiquidationPrice) + t.Logf(" Unrealized PnL: %.2f", pos.UnrealizedPnL) + t.Logf(" Leverage: %.1fx", pos.Leverage) + } +} + +// ==================== Symbol Normalization Tests ==================== + +func TestLighterSymbolNormalization(t *testing.T) { + skipIfNoEnv(t) + + trader := createTestTrader(t) + defer trader.Cleanup() + + // Test different symbol formats + testCases := []struct { + input string + expected string + }{ + {"ETH", "ETH"}, + {"ETH-PERP", "ETH"}, + {"ETHUSDT", "ETH"}, + {"ETH/USDT", "ETH"}, + {"BTC", "BTC"}, + {"BTCUSDT", "BTC"}, + } + + for _, tc := range testCases { + // Try to get market price with different formats + price, err := trader.GetMarketPrice(tc.input) + if err != nil { + t.Logf("⚠️ GetMarketPrice(%s) failed: %v", tc.input, err) + } else { + t.Logf("✅ GetMarketPrice(%s) = %.2f", tc.input, price) + } + } +} diff --git a/trader/lighter_trader_v2.go b/trader/lighter_trader_v2.go index 6abdf405..60b11570 100644 --- a/trader/lighter_trader_v2.go +++ b/trader/lighter_trader_v2.go @@ -74,6 +74,7 @@ type LighterTraderV2 struct { apiKeyPrivateKey string // 40-byte API Key private key (for signing transactions) apiKeyIndex uint8 // API Key index (default 0) accountIndex int64 // Account index + apiKeyValid bool // Whether API key has been validated against server // Authentication token authToken string @@ -85,8 +86,10 @@ type LighterTraderV2 struct { precisionMutex sync.RWMutex // Market index cache - marketIndexMap map[string]uint16 // symbol -> market_id - marketMutex sync.RWMutex + marketIndexMap map[string]uint16 // symbol -> market_id + marketMutex sync.RWMutex + marketListCache []MarketInfo // Cached market list + marketListCacheTime time.Time // Time when cache was populated } // NewLighterTraderV2 Create new LIGHTER trader (using official SDK) @@ -127,9 +130,6 @@ func NewLighterTraderV2(walletAddr, apiKeyPrivateKeyHex string, apiKeyIndex int, walletAddr: walletAddr, client: &http.Client{ Timeout: 30 * time.Second, - Transport: &http.Transport{ - Proxy: nil, // Disable proxy for direct connection to Lighter API - }, }, baseURL: baseURL, testnet: testnet, @@ -162,14 +162,18 @@ func NewLighterTraderV2(walletAddr, apiKeyPrivateKeyHex string, apiKeyIndex int, // 7. Verify API Key is correct if err := trader.checkClient(); err != nil { - logger.Warnf("⚠️ API Key verification failed: %v", err) - logger.Warnf("⚠️ The API key may not be registered on-chain. Authenticated API calls (like GetTrades) will fail.") - logger.Warnf("⚠️ To fix: Register this API key using change_api_key transaction from app.lighter.xyz") - // Don't fail here, allow trader to continue (may work with some operations) + trader.apiKeyValid = false + logger.Warnf("⚠️ API Key verification FAILED: %v", err) + logger.Warnf("⚠️ ❌ The API key stored in NOFX does NOT match the API key registered on Lighter.") + logger.Warnf("⚠️ ❌ ALL trading operations (open/close positions, cancel orders) WILL FAIL with 'invalid signature' error.") + logger.Warnf("⚠️ 🔧 To fix: Update your Lighter API key in NOFX Exchange settings with the correct key from app.lighter.xyz") + // Don't fail here, allow trader to continue for read operations (balance, positions) + } else { + trader.apiKeyValid = true } - logger.Infof("✓ LIGHTER trader initialized successfully (account=%d, apiKey=%d, testnet=%v)", - trader.accountIndex, trader.apiKeyIndex, testnet) + logger.Infof("✓ LIGHTER trader initialized (account=%d, apiKey=%d, testnet=%v, apiKeyValid=%v)", + trader.accountIndex, trader.apiKeyIndex, testnet, trader.apiKeyValid) return trader, nil } @@ -212,7 +216,7 @@ func (t *LighterTraderV2) getAccountByL1Address() (*AccountInfo, error) { } // Log raw response for debugging - logger.Infof("LIGHTER account API response: %s", string(body)) + logger.Debugf("LIGHTER account API response: %s", string(body)) if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("failed to get account (status %d): %s", resp.StatusCode, string(body)) @@ -238,10 +242,10 @@ func (t *LighterTraderV2) getAccountByL1Address() (*AccountInfo, error) { return nil, fmt.Errorf("no account found for wallet address: %s (try depositing funds first at app.lighter.xyz)", t.walletAddr) } - // Log all found accounts - logger.Infof("Found %d accounts (main: %d, sub: %d)", len(allAccounts), len(accountResp.Accounts), len(accountResp.SubAccounts)) + // Log account summary + logger.Infof("Found %d account(s) (main: %d, sub: %d)", len(allAccounts), len(accountResp.Accounts), len(accountResp.SubAccounts)) for i, acc := range allAccounts { - logger.Infof(" Account[%d]: index=%d, collateral=%s", i, acc.AccountIndex, acc.Collateral) + logger.Debugf(" Account[%d]: index=%d, collateral=%s", i, acc.AccountIndex, acc.Collateral) } account := &allAccounts[0] @@ -253,26 +257,79 @@ func (t *LighterTraderV2) getAccountByL1Address() (*AccountInfo, error) { return account, nil } +// ApiKeyResponse API key query response +type ApiKeyResponse struct { + Code int `json:"code"` + ApiKeys []struct { + AccountIndex int64 `json:"account_index"` + ApiKeyIndex uint8 `json:"api_key_index"` + Nonce int64 `json:"nonce"` + PublicKey string `json:"public_key"` + } `json:"api_keys"` +} + +// getApiKeyFromServer Get API Key public key from Lighter server +// Uses our own HTTP client instead of SDK's global client to avoid connection issues +func (t *LighterTraderV2) getApiKeyFromServer() (string, error) { + endpoint := fmt.Sprintf("%s/api/v1/apikeys?account_index=%d&api_key_index=%d", + t.baseURL, t.accountIndex, t.apiKeyIndex) + + req, err := http.NewRequest("GET", endpoint, nil) + if err != nil { + return "", err + } + + resp, err := t.client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) + } + + var result ApiKeyResponse + if err := json.Unmarshal(body, &result); err != nil { + return "", fmt.Errorf("failed to parse response: %w", err) + } + + if result.Code != 200 { + return "", fmt.Errorf("API error (code %d)", result.Code) + } + + if len(result.ApiKeys) == 0 { + return "", fmt.Errorf("no API keys found for account %d", t.accountIndex) + } + + return result.ApiKeys[0].PublicKey, nil +} + // checkClient Verify if API Key is correct func (t *LighterTraderV2) checkClient() error { if t.txClient == nil { return fmt.Errorf("TxClient not initialized") } - // Get API Key public key registered on server - publicKey, err := t.httpClient.GetApiKey(t.accountIndex, t.apiKeyIndex) + // Get API Key public key registered on server (using our own HTTP client) + serverPubKey, err := t.getApiKeyFromServer() if err != nil { return fmt.Errorf("failed to get API Key: %w", err) } - // Get local API Key public key + // Get local API Key public key from SDK pubKeyBytes := t.txClient.GetKeyManager().PubKeyBytes() localPubKey := hexutil.Encode(pubKeyBytes[:]) - localPubKey = strings.Replace(localPubKey, "0x", "", 1) + localPubKey = strings.TrimPrefix(localPubKey, "0x") // Compare public keys - if publicKey != localPubKey { - return fmt.Errorf("API Key mismatch: local=%s, server=%s", localPubKey, publicKey) + if serverPubKey != localPubKey { + return fmt.Errorf("API Key mismatch: local=%s, server=%s", localPubKey, serverPubKey) } logger.Infof("✓ API Key verification passed") @@ -436,12 +493,8 @@ func (t *LighterTraderV2) GetTrades(startTime time.Time, limit int) ([]TradeReco return []TradeRecord{}, nil } - // Debug: log raw response (first 500 chars) - logBody := string(body) - if len(logBody) > 500 { - logBody = logBody[:500] + "..." - } - logger.Infof("📋 Lighter trades API raw response: %s", logBody) + // Debug: log raw response + logger.Debugf("Lighter trades API response: %s", string(body)) var response LighterTradeResponse if err := json.Unmarshal(body, &response); err != nil { diff --git a/trader/lighter_trader_v2_account.go b/trader/lighter_trader_v2_account.go index b9c84c18..b2865c32 100644 --- a/trader/lighter_trader_v2_account.go +++ b/trader/lighter_trader_v2_account.go @@ -11,6 +11,7 @@ import ( ) // getFullAccountInfo Fetch full account info from Lighter API (includes balance and positions) +// Supports both main accounts and sub-accounts func (t *LighterTraderV2) getFullAccountInfo() (*AccountInfo, error) { endpoint := fmt.Sprintf("%s/api/v1/account?by=l1_address&value=%s", t.baseURL, t.walletAddr) @@ -34,20 +35,47 @@ func (t *LighterTraderV2) getFullAccountInfo() (*AccountInfo, error) { return nil, fmt.Errorf("failed to get account (status %d): %s", resp.StatusCode, string(body)) } - // Parse response - Lighter returns {"accounts": [...]} + // Parse response - Lighter may return accounts in "accounts" or "sub_accounts" field var accountResp AccountResponse if err := json.Unmarshal(body, &accountResp); err != nil { return nil, fmt.Errorf("failed to parse account response: %w", err) } - if len(accountResp.Accounts) == 0 { - return nil, fmt.Errorf("no account found for wallet address: %s", t.walletAddr) + // Check for API error code + if accountResp.Code != 0 && accountResp.Code != 200 { + return nil, fmt.Errorf("Lighter API error (code %d): %s", accountResp.Code, accountResp.Message) } - account := &accountResp.Accounts[0] - // Use index field if account_index is 0 - if account.AccountIndex == 0 && account.Index != 0 { - account.AccountIndex = account.Index + // Combine both accounts and sub_accounts - some users have sub-accounts + var allAccounts []AccountInfo + allAccounts = append(allAccounts, accountResp.Accounts...) + allAccounts = append(allAccounts, accountResp.SubAccounts...) + + if len(allAccounts) == 0 { + return nil, fmt.Errorf("no account found for wallet address: %s (try depositing funds first at app.lighter.xyz)", t.walletAddr) + } + + // Find the account that matches our stored accountIndex, or use the first one + var account *AccountInfo + for i := range allAccounts { + acc := &allAccounts[i] + // Use index field if account_index is 0 + if acc.AccountIndex == 0 && acc.Index != 0 { + acc.AccountIndex = acc.Index + } + // Match by stored accountIndex if we have one + if t.accountIndex != 0 && acc.AccountIndex == t.accountIndex { + account = acc + break + } + } + + // If no specific match, use the first account + if account == nil { + account = &allAccounts[0] + if account.AccountIndex == 0 && account.Index != 0 { + account.AccountIndex = account.Index + } } return account, nil @@ -328,12 +356,13 @@ func (t *LighterTraderV2) FormatQuantity(symbol string, quantity float64) (strin return fmt.Sprintf("%.4f", quantity), nil } -// GetOrderBook Get order book with best bid/ask prices -func (t *LighterTraderV2) GetOrderBook(symbol string) (bestBid, bestAsk float64, err error) { +// GetOrderBook Get order book (implements GridTrader interface) +// Returns bids and asks as [][]float64 where each element is [price, quantity] +func (t *LighterTraderV2) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { // Get market_id first marketID, err := t.getMarketIndex(symbol) if err != nil { - return 0, 0, fmt.Errorf("failed to get market ID: %w", err) + return nil, nil, fmt.Errorf("failed to get market ID: %w", err) } // Get order book from Lighter API @@ -341,22 +370,22 @@ func (t *LighterTraderV2) GetOrderBook(symbol string) (bestBid, bestAsk float64, req, err := http.NewRequest("GET", endpoint, nil) if err != nil { - return 0, 0, err + return nil, nil, err } resp, err := t.client.Do(req) if err != nil { - return 0, 0, err + return nil, nil, err } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - return 0, 0, err + return nil, nil, err } if resp.StatusCode != http.StatusOK { - return 0, 0, fmt.Errorf("failed to get order book (status %d): %s", resp.StatusCode, string(body)) + return nil, nil, fmt.Errorf("failed to get order book (status %d): %s", resp.StatusCode, string(body)) } // Parse response @@ -369,35 +398,61 @@ func (t *LighterTraderV2) GetOrderBook(symbol string) (bestBid, bestAsk float64, } if err := json.Unmarshal(body, &apiResp); err != nil { - return 0, 0, fmt.Errorf("failed to parse order book: %w", err) + return nil, nil, fmt.Errorf("failed to parse order book: %w", err) } if apiResp.Code != 200 { - return 0, 0, fmt.Errorf("API error code: %d", apiResp.Code) + return nil, nil, fmt.Errorf("API error code: %d", apiResp.Code) } - // Get best bid (highest buy price) - if len(apiResp.Data.Bids) > 0 && len(apiResp.Data.Bids[0]) >= 1 { - if price, ok := apiResp.Data.Bids[0][0].(float64); ok { - bestBid = price - } else if priceStr, ok := apiResp.Data.Bids[0][0].(string); ok { - bestBid, _ = strconv.ParseFloat(priceStr, 64) + // Helper to parse price/quantity from interface{} + parseFloat := func(v interface{}) float64 { + if f, ok := v.(float64); ok { + return f + } + if s, ok := v.(string); ok { + f, _ := strconv.ParseFloat(s, 64) + return f + } + return 0 + } + + // Convert bids to [][]float64 + maxBids := len(apiResp.Data.Bids) + if depth > 0 && depth < maxBids { + maxBids = depth + } + bids = make([][]float64, 0, maxBids) + for i := 0; i < maxBids; i++ { + if len(apiResp.Data.Bids[i]) >= 2 { + price := parseFloat(apiResp.Data.Bids[i][0]) + qty := parseFloat(apiResp.Data.Bids[i][1]) + if price > 0 && qty > 0 { + bids = append(bids, []float64{price, qty}) + } } } - // Get best ask (lowest sell price) - if len(apiResp.Data.Asks) > 0 && len(apiResp.Data.Asks[0]) >= 1 { - if price, ok := apiResp.Data.Asks[0][0].(float64); ok { - bestAsk = price - } else if priceStr, ok := apiResp.Data.Asks[0][0].(string); ok { - bestAsk, _ = strconv.ParseFloat(priceStr, 64) + // Convert asks to [][]float64 + maxAsks := len(apiResp.Data.Asks) + if depth > 0 && depth < maxAsks { + maxAsks = depth + } + asks = make([][]float64, 0, maxAsks) + for i := 0; i < maxAsks; i++ { + if len(apiResp.Data.Asks[i]) >= 2 { + price := parseFloat(apiResp.Data.Asks[i][0]) + qty := parseFloat(apiResp.Data.Asks[i][1]) + if price > 0 && qty > 0 { + asks = append(asks, []float64{price, qty}) + } } } - if bestBid <= 0 || bestAsk <= 0 { - return 0, 0, fmt.Errorf("invalid order book prices: bid=%.2f, ask=%.2f", bestBid, bestAsk) + if len(bids) > 0 && len(asks) > 0 { + logger.Infof("✓ Lighter order book: %s best_bid=%.2f, best_ask=%.2f, depth=%d/%d", + symbol, bids[0][0], asks[0][0], len(bids), len(asks)) } - logger.Infof("✓ Lighter order book: %s bid=%.2f, ask=%.2f", symbol, bestBid, bestAsk) - return bestBid, bestAsk, nil + return bids, asks, nil } diff --git a/trader/lighter_trader_v2_orders.go b/trader/lighter_trader_v2_orders.go index c30783c2..024b512c 100644 --- a/trader/lighter_trader_v2_orders.go +++ b/trader/lighter_trader_v2_orders.go @@ -1,12 +1,11 @@ package trader import ( - "bytes" "encoding/json" "fmt" "io" - "mime/multipart" "net/http" + "net/url" "nofx/logger" "strconv" @@ -100,15 +99,18 @@ func (t *LighterTraderV2) GetOrderStatus(symbol string, orderID string) (map[str return nil, fmt.Errorf("invalid auth token: %w", err) } - // Build request URL - endpoint := fmt.Sprintf("%s/api/v1/order/%s", t.baseURL, orderID) + // URL encode auth token (contains colons that need encoding) + // Authentication: Use "auth" query parameter (not Authorization header) + encodedAuth := url.QueryEscape(t.authToken) + + // Build request URL with auth query parameter + endpoint := fmt.Sprintf("%s/api/v1/order/%s?auth=%s", t.baseURL, orderID, encodedAuth) req, err := http.NewRequest("GET", endpoint, nil) if err != nil { return nil, err } - req.Header.Set("Authorization", t.authToken) req.Header.Set("Content-Type", "application/json") resp, err := t.client.Do(req) @@ -148,7 +150,7 @@ func (t *LighterTraderV2) GetOrderStatus(symbol string, orderID string) (map[str "orderId": order.OrderID, "status": unifiedStatus, "avgPrice": order.Price, - "executedQty": order.FilledQty, + "executedQty": order.FilledBaseAmount, "commission": 0.0, }, nil } @@ -210,9 +212,15 @@ func (t *LighterTraderV2) GetActiveOrders(symbol string) ([]OrderResponse, error return nil, fmt.Errorf("failed to get market index: %w", err) } - // Build request URL - endpoint := fmt.Sprintf("%s/api/v1/accountActiveOrders?account_index=%d&market_id=%d", - t.baseURL, t.accountIndex, marketIndex) + // URL encode auth token (contains colons that need encoding) + // Authentication: Use "auth" query parameter (not Authorization header) + encodedAuth := url.QueryEscape(t.authToken) + + // Build request URL with auth query parameter + endpoint := fmt.Sprintf("%s/api/v1/accountActiveOrders?account_index=%d&market_id=%d&auth=%s", + t.baseURL, t.accountIndex, marketIndex, encodedAuth) + + logger.Debugf("📋 LIGHTER GetActiveOrders: endpoint=%s", endpoint[:min(len(endpoint), 120)]+"...") // Send GET request req, err := http.NewRequest("GET", endpoint, nil) @@ -220,8 +228,6 @@ func (t *LighterTraderV2) GetActiveOrders(symbol string) ([]OrderResponse, error return nil, fmt.Errorf("failed to create request: %w", err) } - // Add authentication header - req.Header.Set("Authorization", t.authToken) req.Header.Set("Content-Type", "application/json") resp, err := t.client.Do(req) @@ -235,11 +241,13 @@ func (t *LighterTraderV2) GetActiveOrders(symbol string) ([]OrderResponse, error return nil, fmt.Errorf("failed to read response: %w", err) } - // Parse response + logger.Debugf("📋 LIGHTER GetActiveOrders raw response: %s", string(body)) + + // Parse response - Lighter API uses "orders" field, not "data" var apiResp struct { Code int `json:"code"` Message string `json:"message"` - Data []OrderResponse `json:"data"` + Orders []OrderResponse `json:"orders"` } if err := json.Unmarshal(body, &apiResp); err != nil { @@ -250,11 +258,15 @@ func (t *LighterTraderV2) GetActiveOrders(symbol string) ([]OrderResponse, error return nil, fmt.Errorf("failed to get active orders (code %d): %s", apiResp.Code, apiResp.Message) } - logger.Infof("✓ LIGHTER - Retrieved %d active orders", len(apiResp.Data)) - return apiResp.Data, nil + logger.Infof("✓ LIGHTER - Retrieved %d active orders", len(apiResp.Orders)) + for i, order := range apiResp.Orders { + logger.Debugf(" Order[%d]: order_id=%s, order_index=%d, market=%d", i, order.OrderID, order.OrderIndex, order.MarketIndex) + } + return apiResp.Orders, nil } // CancelOrder Cancel a single order +// orderID can be either a numeric order_index or a tx_hash string func (t *LighterTraderV2) CancelOrder(symbol, orderID string) error { if t.txClient == nil { return fmt.Errorf("TxClient not initialized") @@ -267,10 +279,15 @@ func (t *LighterTraderV2) CancelOrder(symbol, orderID string) error { } marketIndex := uint8(marketIndexU16) // SDK expects uint8 - // Convert orderID to int64 + // Try to parse orderID as numeric order_index first orderIndex, err := strconv.ParseInt(orderID, 10, 64) if err != nil { - return fmt.Errorf("invalid order ID: %w", err) + // orderID is a tx_hash, need to query order to get numeric order_index + logger.Debugf("📋 LIGHTER CancelOrder: orderID is tx_hash, querying order...") + orderIndex, err = t.getOrderIndexByTxHash(symbol, orderID) + if err != nil { + return fmt.Errorf("failed to get order index from tx_hash: %w", err) + } } // Build cancel order request @@ -280,22 +297,26 @@ func (t *LighterTraderV2) CancelOrder(symbol, orderID string) error { } // Sign transaction using SDK + // Must provide FromAccountIndex and ApiKeyIndex for nonce auto-fetch to work nonce := int64(-1) // -1 means auto-fetch + apiKeyIdx := t.apiKeyIndex tx, err := t.txClient.GetCancelOrderTransaction(txReq, &types.TransactOpts{ - Nonce: &nonce, + FromAccountIndex: &t.accountIndex, + ApiKeyIndex: &apiKeyIdx, + Nonce: &nonce, }) if err != nil { return fmt.Errorf("failed to sign cancel order: %w", err) } - // Serialize transaction - txBytes, err := json.Marshal(tx) + // Get tx_info from SDK (consistent with CreateOrder and other transactions) + txInfo, err := tx.GetTxInfo() if err != nil { - return fmt.Errorf("failed to serialize transaction: %w", err) + return fmt.Errorf("failed to get tx info: %w", err) } - // Submit cancel order to LIGHTER API - _, err = t.submitCancelOrder(txBytes) + // Submit cancel order to LIGHTER API using unified submitOrder function + _, err = t.submitOrder(int(tx.GetTxType()), txInfo) if err != nil { return fmt.Errorf("failed to submit cancel order: %w", err) } @@ -304,65 +325,21 @@ func (t *LighterTraderV2) CancelOrder(symbol, orderID string) error { return nil } -// submitCancelOrder Submit signed cancel order to LIGHTER API using multipart/form-data -func (t *LighterTraderV2) submitCancelOrder(signedTx []byte) (map[string]interface{}, error) { - const TX_TYPE_CANCEL_ORDER = 15 - - // Build multipart form data (Lighter API requires form-data, not JSON) - var body bytes.Buffer - writer := multipart.NewWriter(&body) - - // Add tx_type field - if err := writer.WriteField("tx_type", strconv.Itoa(TX_TYPE_CANCEL_ORDER)); err != nil { - return nil, fmt.Errorf("failed to write tx_type: %w", err) - } - - // Add tx_info field - if err := writer.WriteField("tx_info", string(signedTx)); err != nil { - return nil, fmt.Errorf("failed to write tx_info: %w", err) - } - - // Close multipart writer - if err := writer.Close(); err != nil { - return nil, fmt.Errorf("failed to close multipart writer: %w", err) - } - - // Send POST request to /api/v1/sendTx - endpoint := fmt.Sprintf("%s/api/v1/sendTx", t.baseURL) - httpReq, err := http.NewRequest("POST", endpoint, &body) +// getOrderIndexByTxHash finds the numeric order_index by searching active orders for the tx_hash +func (t *LighterTraderV2) getOrderIndexByTxHash(symbol, txHash string) (int64, error) { + // Get all active orders for this symbol + orders, err := t.GetActiveOrders(symbol) if err != nil { - return nil, err + return 0, fmt.Errorf("failed to get active orders: %w", err) } - httpReq.Header.Set("Content-Type", writer.FormDataContentType()) - - resp, err := t.client.Do(httpReq) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err + // Search for the order with matching tx_hash (order_id) + for _, order := range orders { + if order.OrderID == txHash { + logger.Debugf("📋 LIGHTER Found order_index %d for tx_hash %s", order.OrderIndex, txHash) + return order.OrderIndex, nil + } } - // Parse response - var sendResp SendTxResponse - if err := json.Unmarshal(respBody, &sendResp); err != nil { - return nil, fmt.Errorf("failed to parse response: %w, body: %s", err, string(respBody)) - } - - // Check response code - if sendResp.Code != 200 { - return nil, fmt.Errorf("failed to submit cancel order (code %d): %s", sendResp.Code, sendResp.Message) - } - - result := map[string]interface{}{ - "tx_hash": sendResp.Data["tx_hash"], - "status": "cancelled", - } - - logger.Infof("✓ Cancel order submitted to LIGHTER - tx_hash: %v", sendResp.Data["tx_hash"]) - return result, nil + return 0, fmt.Errorf("order not found with tx_hash: %s (may already be filled or cancelled)", txHash) } diff --git a/trader/lighter_trader_v2_orders_test.go b/trader/lighter_trader_v2_orders_test.go new file mode 100644 index 00000000..7b84912f --- /dev/null +++ b/trader/lighter_trader_v2_orders_test.go @@ -0,0 +1,421 @@ +package trader + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestGetActiveOrders_ParseResponse tests parsing of Lighter API response +func TestGetActiveOrders_ParseResponse(t *testing.T) { + // Mock response from Lighter API + mockResponse := `{ + "code": 200, + "message": "success", + "orders": [ + { + "order_id": "123456", + "order_index": 123456, + "market_index": 0, + "side": "ask", + "type": "limit", + "is_ask": true, + "price": "3150.50", + "initial_base_amount": "1.5", + "remaining_base_amount": "1.5", + "filled_base_amount": "0", + "status": "open", + "trigger_price": "", + "reduce_only": false, + "timestamp": 1736745600000, + "created_at": 1736745600000 + }, + { + "order_id": "123457", + "order_index": 123457, + "market_index": 0, + "side": "bid", + "type": "limit", + "is_ask": false, + "price": "3100.00", + "initial_base_amount": "2.0", + "remaining_base_amount": "2.0", + "filled_base_amount": "0", + "status": "open", + "trigger_price": "", + "reduce_only": false, + "timestamp": 1736745601000, + "created_at": 1736745601000 + }, + { + "order_id": "123458", + "order_index": 123458, + "market_index": 0, + "side": "ask", + "type": "stop_loss", + "is_ask": true, + "price": "0", + "initial_base_amount": "1.0", + "remaining_base_amount": "1.0", + "filled_base_amount": "0", + "status": "open", + "trigger_price": "3000.00", + "reduce_only": true, + "timestamp": 1736745602000, + "created_at": 1736745602000 + } + ] + }` + + // Parse the response + var apiResp struct { + Code int `json:"code"` + Message string `json:"message"` + Orders []OrderResponse `json:"orders"` + } + + err := json.Unmarshal([]byte(mockResponse), &apiResp) + require.NoError(t, err, "Should parse response without error") + + // Verify parsed data + assert.Equal(t, 200, apiResp.Code) + assert.Equal(t, 3, len(apiResp.Orders)) + + // Test first order (sell limit) + order1 := apiResp.Orders[0] + assert.Equal(t, "123456", order1.OrderID) + assert.True(t, order1.IsAsk, "First order should be ask (sell)") + assert.Equal(t, "3150.50", order1.Price) + assert.Equal(t, "1.5", order1.RemainingBaseAmount) + assert.False(t, order1.ReduceOnly) + + // Test second order (buy limit) + order2 := apiResp.Orders[1] + assert.Equal(t, "123457", order2.OrderID) + assert.False(t, order2.IsAsk, "Second order should be bid (buy)") + assert.Equal(t, "3100.00", order2.Price) + + // Test third order (stop-loss) + order3 := apiResp.Orders[2] + assert.Equal(t, "123458", order3.OrderID) + assert.Equal(t, "stop_loss", order3.Type) + assert.Equal(t, "3000.00", order3.TriggerPrice) + assert.True(t, order3.ReduceOnly) +} + +// TestGetActiveOrders_EmptyResponse tests handling of empty orders +func TestGetActiveOrders_EmptyResponse(t *testing.T) { + mockResponse := `{ + "code": 200, + "message": "success", + "orders": [] + }` + + var apiResp struct { + Code int `json:"code"` + Message string `json:"message"` + Orders []OrderResponse `json:"orders"` + } + + err := json.Unmarshal([]byte(mockResponse), &apiResp) + require.NoError(t, err) + assert.Equal(t, 200, apiResp.Code) + assert.Equal(t, 0, len(apiResp.Orders)) +} + +// TestGetActiveOrders_ErrorResponse tests handling of API error +func TestGetActiveOrders_ErrorResponse(t *testing.T) { + mockResponse := `{ + "code": 29500, + "message": "internal server error: invalid signature" + }` + + var apiResp struct { + Code int `json:"code"` + Message string `json:"message"` + Orders []OrderResponse `json:"orders"` + } + + err := json.Unmarshal([]byte(mockResponse), &apiResp) + require.NoError(t, err) + assert.Equal(t, 29500, apiResp.Code) + assert.Contains(t, apiResp.Message, "invalid signature") +} + +// TestConvertOrderResponseToOpenOrder tests conversion logic +func TestConvertOrderResponseToOpenOrder(t *testing.T) { + testCases := []struct { + name string + order OrderResponse + expectedSide string + expectedType string + expectedPosSide string + }{ + { + name: "Sell limit order (opening short)", + order: OrderResponse{ + OrderID: "1", + IsAsk: true, + Type: "limit", + Price: "3150.00", + RemainingBaseAmount: "1.0", + ReduceOnly: false, + }, + expectedSide: "SELL", + expectedType: "LIMIT", + expectedPosSide: "SHORT", + }, + { + name: "Buy limit order (opening long)", + order: OrderResponse{ + OrderID: "2", + IsAsk: false, + Type: "limit", + Price: "3100.00", + RemainingBaseAmount: "1.0", + ReduceOnly: false, + }, + expectedSide: "BUY", + expectedType: "LIMIT", + expectedPosSide: "LONG", + }, + { + name: "Sell stop-loss (closing long)", + order: OrderResponse{ + OrderID: "3", + IsAsk: true, + Type: "stop_loss", + TriggerPrice: "3000.00", + RemainingBaseAmount: "1.0", + ReduceOnly: true, + }, + expectedSide: "SELL", + expectedType: "STOP_MARKET", + expectedPosSide: "LONG", + }, + { + name: "Buy stop-loss (closing short)", + order: OrderResponse{ + OrderID: "4", + IsAsk: false, + Type: "stop_loss", + TriggerPrice: "3200.00", + RemainingBaseAmount: "1.0", + ReduceOnly: true, + }, + expectedSide: "BUY", + expectedType: "STOP_MARKET", + expectedPosSide: "SHORT", + }, + { + name: "Take profit (closing long)", + order: OrderResponse{ + OrderID: "5", + IsAsk: true, + Type: "take_profit", + TriggerPrice: "3500.00", + RemainingBaseAmount: "1.0", + ReduceOnly: true, + }, + expectedSide: "SELL", + expectedType: "TAKE_PROFIT_MARKET", + expectedPosSide: "LONG", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Convert side + side := "BUY" + if tc.order.IsAsk { + side = "SELL" + } + assert.Equal(t, tc.expectedSide, side) + + // Convert order type + orderType := "LIMIT" + if tc.order.Type == "market" { + orderType = "MARKET" + } else if tc.order.Type == "stop_loss" || tc.order.Type == "stop" { + orderType = "STOP_MARKET" + } else if tc.order.Type == "take_profit" { + orderType = "TAKE_PROFIT_MARKET" + } + assert.Equal(t, tc.expectedType, orderType) + + // Convert position side + positionSide := "LONG" + if tc.order.ReduceOnly { + if side == "BUY" { + positionSide = "SHORT" + } else { + positionSide = "LONG" + } + } else { + if side == "SELL" { + positionSide = "SHORT" + } + } + assert.Equal(t, tc.expectedPosSide, positionSide) + }) + } +} + +// TestGetActiveOrders_MockServer tests the full HTTP flow with a mock server +func TestGetActiveOrders_MockServer(t *testing.T) { + // Create mock server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request path and auth parameter + assert.Contains(t, r.URL.Path, "/api/v1/accountActiveOrders") + + // Check that auth query parameter is present + authParam := r.URL.Query().Get("auth") + if authParam == "" { + // Return error if no auth parameter + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(map[string]interface{}{ + "code": 29500, + "message": "internal server error: invalid signature", + }) + return + } + + // Return success response + response := map[string]interface{}{ + "code": 200, + "message": "success", + "orders": []map[string]interface{}{ + { + "order_id": "123456", + "order_index": 123456, + "market_index": 0, + "side": "ask", + "type": "limit", + "is_ask": true, + "price": "3150.50", + "initial_base_amount": "1.5", + "remaining_base_amount": "1.5", + "filled_base_amount": "0", + "status": "open", + "trigger_price": "", + "reduce_only": false, + }, + }, + } + json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + // Test request without auth - should fail + resp, err := http.Get(server.URL + "/api/v1/accountActiveOrders?account_index=123&market_id=0") + require.NoError(t, err) + defer resp.Body.Close() + + var errorResp struct { + Code int `json:"code"` + Message string `json:"message"` + } + json.NewDecoder(resp.Body).Decode(&errorResp) + assert.Equal(t, 29500, errorResp.Code) + + // Test request with auth - should succeed + resp2, err := http.Get(server.URL + "/api/v1/accountActiveOrders?account_index=123&market_id=0&auth=test_token") + require.NoError(t, err) + defer resp2.Body.Close() + + var successResp struct { + Code int `json:"code"` + Message string `json:"message"` + Orders []OrderResponse `json:"orders"` + } + json.NewDecoder(resp2.Body).Decode(&successResp) + assert.Equal(t, 200, successResp.Code) + assert.Equal(t, 1, len(successResp.Orders)) +} + +// TestAuthTokenFormat tests the auth token format +func TestAuthTokenFormat(t *testing.T) { + // Auth token format: timestamp:account_index:api_key_index:signature + // Example: 1768308847:687247:0:742e02... + + sampleToken := "1768308847:687247:0:742e02abc123" + + // The token should be URL encoded when used as query parameter + // Colons become %3A + expectedEncoded := "1768308847%3A687247%3A0%3A742e02abc123" + + // URL encode the token + encoded := url.QueryEscape(sampleToken) + + assert.Equal(t, expectedEncoded, encoded) +} + +// TestOrderResponseStruct tests that OrderResponse struct matches API response +func TestOrderResponseStruct(t *testing.T) { + // Real API response sample (from logs) + realResponse := `{ + "order_id": "4609885", + "order_index": 4609885, + "market_index": 0, + "side": "ask", + "type": "limit", + "is_ask": true, + "price": "3150.00", + "initial_base_amount": "0.0300", + "remaining_base_amount": "0.0300", + "filled_base_amount": "0", + "status": "open", + "trigger_price": "", + "reduce_only": false, + "timestamp": 1736745600000, + "created_at": 1736745600000 + }` + + var order OrderResponse + err := json.Unmarshal([]byte(realResponse), &order) + require.NoError(t, err) + + assert.Equal(t, "4609885", order.OrderID) + assert.Equal(t, int64(4609885), order.OrderIndex) + assert.Equal(t, 0, order.MarketIndex) + assert.Equal(t, "ask", order.Side) + assert.Equal(t, "limit", order.Type) + assert.True(t, order.IsAsk) + assert.Equal(t, "3150.00", order.Price) + assert.Equal(t, "0.0300", order.InitialBaseAmount) + assert.Equal(t, "0.0300", order.RemainingBaseAmount) + assert.Equal(t, "0", order.FilledBaseAmount) + assert.Equal(t, "open", order.Status) + assert.Equal(t, "", order.TriggerPrice) + assert.False(t, order.ReduceOnly) + assert.Equal(t, int64(1736745600000), order.Timestamp) + assert.Equal(t, int64(1736745600000), order.CreatedAt) +} + +// BenchmarkParseOrderResponse benchmarks response parsing +func BenchmarkParseOrderResponse(b *testing.B) { + mockResponse := `{ + "code": 200, + "message": "success", + "orders": [ + {"order_id": "1", "is_ask": true, "price": "3150.50", "remaining_base_amount": "1.5"}, + {"order_id": "2", "is_ask": false, "price": "3100.00", "remaining_base_amount": "2.0"}, + {"order_id": "3", "is_ask": true, "price": "3200.00", "remaining_base_amount": "0.5"} + ] + }` + + b.ResetTimer() + for i := 0; i < b.N; i++ { + var apiResp struct { + Code int `json:"code"` + Message string `json:"message"` + Orders []OrderResponse `json:"orders"` + } + json.Unmarshal([]byte(mockResponse), &apiResp) + } +} diff --git a/trader/lighter_trader_v2_trading.go b/trader/lighter_trader_v2_trading.go index ff5a7341..1f4ff417 100644 --- a/trader/lighter_trader_v2_trading.go +++ b/trader/lighter_trader_v2_trading.go @@ -273,9 +273,13 @@ func (t *LighterTraderV2) CreateOrder(symbol string, isAsk bool, quantity float6 } // Sign transaction using SDK (nonce will be auto-fetched) + // Must provide FromAccountIndex and ApiKeyIndex for nonce auto-fetch to work nonce := int64(-1) // -1 means auto-fetch + apiKeyIdx := t.apiKeyIndex tx, err := t.txClient.GetCreateOrderTransaction(txReq, &types.TransactOpts{ - Nonce: &nonce, + FromAccountIndex: &t.accountIndex, + ApiKeyIndex: &apiKeyIdx, + Nonce: &nonce, }) if err != nil { return nil, fmt.Errorf("failed to sign order: %w", err) @@ -288,7 +292,7 @@ func (t *LighterTraderV2) CreateOrder(symbol string, isAsk bool, quantity float6 } // Debug: Log the tx_info content - logger.Infof("DEBUG tx_type: %d, tx_info: %s", tx.GetTxType(), txInfo) + logger.Debugf("tx_type: %d, tx_info: %s", tx.GetTxType(), txInfo) // Submit order to LIGHTER API orderResp, err := t.submitOrder(int(tx.GetTxType()), txInfo) @@ -302,6 +306,16 @@ func (t *LighterTraderV2) CreateOrder(symbol string, isAsk bool, quantity float6 } logger.Infof("✓ LIGHTER order created: %s %s qty=%.4f", symbol, side, quantity) + // For limit orders, poll for the actual order_index after submission + // This is needed because CancelOrder requires the numeric order_index, not tx_hash + if orderType == "limit" { + txHash, _ := orderResp["tx_hash"].(string) + if orderIndex, err := t.pollForOrderIndex(symbol, txHash); err == nil && orderIndex > 0 { + orderResp["orderId"] = fmt.Sprintf("%d", orderIndex) + orderResp["order_index"] = orderIndex + } + } + return orderResp, nil } @@ -386,10 +400,19 @@ func (t *LighterTraderV2) submitOrder(txType int, txInfo string) (map[string]int } // Log full response for debugging - logger.Infof("DEBUG API response: %s", string(respBody)) + logger.Debugf("API response: %s", string(respBody)) // Check response code if sendResp.Code != 200 { + // Provide more specific error message for signature errors + // Code 21120: invalid signature (order submission) + // Code 29500: internal server error: invalid signature (authenticated GET APIs) + if (sendResp.Code == 21120 || sendResp.Code == 29500) && strings.Contains(sendResp.Message, "invalid signature") { + if !t.apiKeyValid { + return nil, fmt.Errorf("API Key MISMATCH (code %d): The API key stored in NOFX does not match the one registered on Lighter. Please update your Lighter API key in Exchange settings at app.lighter.xyz", sendResp.Code) + } + return nil, fmt.Errorf("API Key signature invalid (code %d): Please verify your Lighter API Key in Exchange settings matches the key registered at app.lighter.xyz", sendResp.Code) + } return nil, fmt.Errorf("failed to submit order (code %d): %s", sendResp.Code, sendResp.Message) } @@ -403,17 +426,45 @@ func (t *LighterTraderV2) submitOrder(txType int, txInfo string) (map[string]int } } + logger.Infof("✓ Order submitted to LIGHTER - tx_hash: %s", txHash) + result := map[string]interface{}{ "tx_hash": txHash, "status": "submitted", - "orderId": txHash, // Use tx_hash as orderId + "orderId": txHash, // Use tx_hash as orderId initially } - logger.Infof("✓ Order submitted to LIGHTER - tx_hash: %s", txHash) - return result, nil } +// pollForOrderIndex polls active orders to find the order_index for a newly created order +// Returns the highest order_index (newest order) for the given symbol +func (t *LighterTraderV2) pollForOrderIndex(symbol string, txHash string) (int64, error) { + // Wait a moment for the order to be processed + time.Sleep(500 * time.Millisecond) + + // Get active orders + orders, err := t.GetActiveOrders(symbol) + if err != nil { + return 0, fmt.Errorf("failed to get active orders: %w", err) + } + + if len(orders) == 0 { + return 0, fmt.Errorf("no active orders found (order may have been filled immediately)") + } + + // Find the highest order_index (newest order) + var highestIndex int64 + for _, order := range orders { + if order.OrderIndex > highestIndex { + highestIndex = order.OrderIndex + } + } + + logger.Infof("✓ Order created with order_index: %d (tx_hash: %s)", highestIndex, txHash) + return highestIndex, nil +} + // normalizeSymbol Convert NOFX symbol format to Lighter format // NOFX uses "BTC-PERP", "BTCUSDT", etc. Lighter uses "BTC", "ETH", etc. func normalizeSymbol(symbol string) string { @@ -431,7 +482,7 @@ func (t *LighterTraderV2) getMarketInfo(symbol string) (*MarketInfo, error) { // Normalize symbol to Lighter format normalizedSymbol := normalizeSymbol(symbol) - // 1. Fetch market list from API (TODO: cache this) + // Fetch market list from API (cached for 1 hour) markets, err := t.fetchMarketList() if err != nil { return nil, fmt.Errorf("failed to fetch market list: %w", err) @@ -467,8 +518,18 @@ type MarketInfo struct { PriceDecimals int `json:"price_decimals"` } -// fetchMarketList Fetch market list from API +// fetchMarketList Fetch market list from API with caching (TTL: 1 hour) func (t *LighterTraderV2) fetchMarketList() ([]MarketInfo, error) { + // Check cache (TTL: 1 hour) + t.marketMutex.RLock() + if len(t.marketListCache) > 0 && time.Since(t.marketListCacheTime) < time.Hour { + cached := t.marketListCache + t.marketMutex.RUnlock() + return cached, nil + } + t.marketMutex.RUnlock() + + // Fetch from API endpoint := fmt.Sprintf("%s/api/v1/orderBooks", t.baseURL) req, err := http.NewRequest("GET", endpoint, nil) @@ -514,14 +575,20 @@ func (t *LighterTraderV2) fetchMarketList() ([]MarketInfo, error) { for _, market := range apiResp.OrderBooks { if market.Status == "active" { markets = append(markets, MarketInfo{ - Symbol: market.Symbol, - MarketID: market.MarketID, - SizeDecimals: market.SupportedSizeDecimals, - PriceDecimals: market.SupportedPriceDecimals, + Symbol: market.Symbol, + MarketID: market.MarketID, + SizeDecimals: market.SupportedSizeDecimals, + PriceDecimals: market.SupportedPriceDecimals, }) } } + // Update cache + t.marketMutex.Lock() + t.marketListCache = markets + t.marketListCacheTime = time.Now() + t.marketMutex.Unlock() + logger.Infof("✓ Retrieved %d active markets from Lighter", len(markets)) return markets, nil } @@ -550,31 +617,132 @@ func (t *LighterTraderV2) getFallbackMarketIndex(symbol string) (uint16, error) } // SetLeverage Set leverage (implements Trader interface) +// Lighter uses InitialMarginFraction to represent leverage: +// - InitialMarginFraction = (100 / leverage) * 100 (stored as percentage * 100) +// - e.g., 5x leverage = 20% margin = 2000 in API +// - e.g., 20x leverage = 5% margin = 500 in API func (t *LighterTraderV2) SetLeverage(symbol string, leverage int) error { if t.txClient == nil { return fmt.Errorf("TxClient not initialized") } - // TODO: Sign and submit SetLeverage transaction using SDK - logger.Infof("⚙️ Setting leverage: %s = %dx", symbol, leverage) + // Validate leverage range (1x to 50x typical max) + if leverage < 1 || leverage > 50 { + return fmt.Errorf("leverage must be between 1 and 50, got %d", leverage) + } - return nil // Return success for now + // Get market info (includes market_id) + marketInfo, err := t.getMarketInfo(symbol) + if err != nil { + return fmt.Errorf("failed to get market info: %w", err) + } + marketIndex := uint8(marketInfo.MarketID) + + // Calculate InitialMarginFraction from leverage + // leverage = 100 / margin_fraction_percent + // margin_fraction_percent = 100 / leverage + // API value = margin_fraction_percent * 100 + marginFractionPercent := 100.0 / float64(leverage) + initialMarginFraction := uint16(marginFractionPercent * 100) // e.g., 5x => 20% => 2000 + + logger.Infof("⚙️ Setting leverage: %s = %dx (margin_fraction=%.2f%%, API value=%d)", + symbol, leverage, marginFractionPercent, initialMarginFraction) + + // Build UpdateLeverage request + txReq := &types.UpdateLeverageTxReq{ + MarketIndex: marketIndex, + InitialMarginFraction: initialMarginFraction, + MarginMode: 0, // 0 = cross margin (default) + } + + // Sign transaction using SDK + nonce := int64(-1) // Auto-fetch nonce + tx, err := t.txClient.GetUpdateLeverageTransaction(txReq, &types.TransactOpts{ + Nonce: &nonce, + }) + if err != nil { + return fmt.Errorf("failed to sign leverage transaction: %w", err) + } + + // Get tx_info from SDK + txInfo, err := tx.GetTxInfo() + if err != nil { + return fmt.Errorf("failed to get tx info: %w", err) + } + + // Submit to Lighter API (reuse submitOrder which handles any transaction type) + result, err := t.submitOrder(int(tx.GetTxType()), txInfo) + if err != nil { + return fmt.Errorf("failed to submit leverage transaction: %w", err) + } + + logger.Infof("✓ Leverage set successfully: %s = %dx (tx_hash: %v)", symbol, leverage, result["tx_hash"]) + return nil } // SetMarginMode Set margin mode (implements Trader interface) +// Lighter uses UpdateLeverage transaction which includes both leverage and margin mode +// MarginMode: 0 = cross, 1 = isolated func (t *LighterTraderV2) SetMarginMode(symbol string, isCrossMargin bool) error { if t.txClient == nil { return fmt.Errorf("TxClient not initialized") } - modeStr := "isolated" - if isCrossMargin { - modeStr = "cross" + // Get market info + marketInfo, err := t.getMarketInfo(symbol) + if err != nil { + return fmt.Errorf("failed to get market info: %w", err) + } + marketIndex := uint8(marketInfo.MarketID) + + // Determine margin mode value + var marginMode uint8 = 0 // cross + modeStr := "cross" + if !isCrossMargin { + marginMode = 1 // isolated + modeStr = "isolated" } - logger.Infof("⚙️ Setting margin mode: %s = %s", symbol, modeStr) + // Get current position to preserve leverage, or use default 10x if no position + var initialMarginFraction uint16 = 1000 // Default 10x leverage (10% margin = 1000) + pos, err := t.GetPosition(symbol) + if err == nil && pos != nil && pos.Leverage > 0 { + // Calculate InitialMarginFraction from current leverage + marginFractionPercent := 100.0 / pos.Leverage + initialMarginFraction = uint16(marginFractionPercent * 100) + } - // TODO: Sign and submit SetMarginMode transaction using SDK + logger.Infof("⚙️ Setting margin mode: %s = %s (margin_mode=%d, preserving leverage)", symbol, modeStr, marginMode) + + // Build UpdateLeverage request (also updates margin mode) + txReq := &types.UpdateLeverageTxReq{ + MarketIndex: marketIndex, + InitialMarginFraction: initialMarginFraction, + MarginMode: marginMode, + } + + // Sign transaction + nonce := int64(-1) + tx, err := t.txClient.GetUpdateLeverageTransaction(txReq, &types.TransactOpts{ + Nonce: &nonce, + }) + if err != nil { + return fmt.Errorf("failed to sign margin mode transaction: %w", err) + } + + // Get tx_info + txInfo, err := tx.GetTxInfo() + if err != nil { + return fmt.Errorf("failed to get tx info: %w", err) + } + + // Submit to Lighter API + result, err := t.submitOrder(int(tx.GetTxType()), txInfo) + if err != nil { + return fmt.Errorf("failed to submit margin mode transaction: %w", err) + } + + logger.Infof("✓ Margin mode set successfully: %s = %s (tx_hash: %v)", symbol, modeStr, result["tx_hash"]) return nil } @@ -653,7 +821,7 @@ func (t *LighterTraderV2) CreateStopOrder(symbol string, isAsk bool, quantity fl return nil, fmt.Errorf("failed to get tx info: %w", err) } - logger.Infof("DEBUG stop order - type: %d, trigger: %.2f, price: %.2f, isAsk: %v", orderTypeValue, triggerPrice, float64(priceValue)/100, isAsk) + logger.Debugf("stop order - type: %d, trigger: %.2f, price: %.2f, isAsk: %v", orderTypeValue, triggerPrice, float64(priceValue)/100, isAsk) // Submit order orderResp, err := t.submitOrder(int(tx.GetTxType()), txInfo) @@ -689,6 +857,117 @@ func pow10(n int) int64 { // GetOpenOrders gets all open/pending orders for a symbol func (t *LighterTraderV2) GetOpenOrders(symbol string) ([]OpenOrder, error) { - // TODO: Implement Lighter open orders - return []OpenOrder{}, nil + // Get active orders from Lighter API + activeOrders, err := t.GetActiveOrders(symbol) + if err != nil { + return nil, fmt.Errorf("failed to get active orders: %w", err) + } + + var result []OpenOrder + for _, order := range activeOrders { + // Convert side: Lighter uses is_ask (true=sell, false=buy) + side := "BUY" + if order.IsAsk { + side = "SELL" + } + + // Determine order type from Lighter's type field + orderType := "LIMIT" + if order.Type == "market" { + orderType = "MARKET" + } else if order.Type == "stop_loss" || order.Type == "stop" { + orderType = "STOP_MARKET" + } else if order.Type == "take_profit" { + orderType = "TAKE_PROFIT_MARKET" + } + + // Determine position side based on order direction and reduce-only flag + positionSide := "LONG" + if order.ReduceOnly { + // For reduce-only orders, position side is opposite to order side + if side == "BUY" { + positionSide = "SHORT" // Buying to close short + } else { + positionSide = "LONG" // Selling to close long + } + } else { + // For opening orders + if side == "SELL" { + positionSide = "SHORT" + } + } + + // Parse price and quantity from string fields + price, _ := strconv.ParseFloat(order.Price, 64) + quantity, _ := strconv.ParseFloat(order.RemainingBaseAmount, 64) + if quantity == 0 { + quantity, _ = strconv.ParseFloat(order.InitialBaseAmount, 64) + } + triggerPrice, _ := strconv.ParseFloat(order.TriggerPrice, 64) + + openOrder := OpenOrder{ + OrderID: order.OrderID, + Symbol: symbol, + Side: side, + PositionSide: positionSide, + Type: orderType, + Price: price, + StopPrice: triggerPrice, + Quantity: quantity, + Status: "NEW", + } + result = append(result, openOrder) + } + + logger.Infof("✓ LIGHTER GetOpenOrders: found %d open orders for %s", len(result), symbol) + return result, nil +} + +// PlaceLimitOrder implements GridTrader interface for grid trading +// Places a limit order at the specified price +func (t *LighterTraderV2) PlaceLimitOrder(req *LimitOrderRequest) (*LimitOrderResult, error) { + if t.txClient == nil { + return nil, fmt.Errorf("TxClient not initialized") + } + + // Determine if this is a sell (ask) order + isAsk := req.Side == "SELL" + + logger.Infof("📝 LIGHTER placing limit order: %s %s @ %.4f, qty=%.4f, leverage=%dx", + req.Symbol, req.Side, req.Price, req.Quantity, req.Leverage) + + // Set leverage before placing order (important for grid trading) + if req.Leverage > 0 { + if err := t.SetLeverage(req.Symbol, req.Leverage); err != nil { + logger.Warnf("⚠️ Failed to set leverage: %v (continuing with current leverage)", err) + } + } + + // Create limit order using existing CreateOrder function + orderResult, err := t.CreateOrder(req.Symbol, isAsk, req.Quantity, req.Price, "limit", req.ReduceOnly) + if err != nil { + return nil, fmt.Errorf("failed to place limit order: %w", err) + } + + // Extract order ID from result + orderID := "" + if id, ok := orderResult["orderId"]; ok { + orderID = fmt.Sprintf("%v", id) + } else if txHash, ok := orderResult["tx_hash"]; ok { + orderID = fmt.Sprintf("%v", txHash) + } + + logger.Infof("✓ LIGHTER limit order placed: %s %s @ %.4f, OrderID: %s", + req.Symbol, req.Side, req.Price, orderID) + + return &LimitOrderResult{ + OrderID: orderID, + ClientID: req.ClientID, + Symbol: req.Symbol, + Side: req.Side, + PositionSide: req.PositionSide, + Price: req.Price, + Quantity: req.Quantity, + Status: "NEW", + }, nil } diff --git a/trader/lighter_types.go b/trader/lighter_types.go index 3a76b7c4..e4670cdd 100644 --- a/trader/lighter_types.go +++ b/trader/lighter_types.go @@ -41,18 +41,24 @@ type CreateOrderRequest struct { PostOnly bool `json:"post_only"` // Post-only (maker only) } -// OrderResponse Order response (Lighter) +// OrderResponse Order response (Lighter API) +// Field names must match Lighter API response exactly type OrderResponse struct { - OrderID string `json:"order_id"` - Symbol string `json:"symbol"` - Side string `json:"side"` - OrderType string `json:"order_type"` - Quantity float64 `json:"quantity"` - Price float64 `json:"price"` - Status string `json:"status"` // "open", "filled", "cancelled" - FilledQty float64 `json:"filled_qty"` - RemainingQty float64 `json:"remaining_qty"` - CreateTime int64 `json:"create_time"` + OrderID string `json:"order_id"` + OrderIndex int64 `json:"order_index"` + MarketIndex int `json:"market_index"` + Side string `json:"side"` // "bid" or "ask" + Type string `json:"type"` // "limit", "market", etc. + IsAsk bool `json:"is_ask"` // true = sell, false = buy + Price string `json:"price"` // Price as string + InitialBaseAmount string `json:"initial_base_amount"` // Original quantity + RemainingBaseAmount string `json:"remaining_base_amount"` // Remaining quantity + FilledBaseAmount string `json:"filled_base_amount"` // Filled quantity + Status string `json:"status"` // "open", "filled", "cancelled" + TriggerPrice string `json:"trigger_price"` // For stop orders + ReduceOnly bool `json:"reduce_only"` + Timestamp int64 `json:"timestamp"` + CreatedAt int64 `json:"created_at"` } // LighterTradeResponse represents the response from Lighter trades API diff --git a/trader/okx_trader.go b/trader/okx_trader.go index 2d4b8b89..fc3fae69 100644 --- a/trader/okx_trader.go +++ b/trader/okx_trader.go @@ -1390,6 +1390,254 @@ func (t *OKXTrader) GetClosedPnL(startTime time.Time, limit int) ([]ClosedPnLRec // GetOpenOrders gets all open/pending orders for a symbol func (t *OKXTrader) GetOpenOrders(symbol string) ([]OpenOrder, error) { - // TODO: Implement OKX open orders - return []OpenOrder{}, nil + instId := t.convertSymbol(symbol) + var result []OpenOrder + + // 1. Get pending limit orders + path := fmt.Sprintf("%s?instId=%s&instType=SWAP", okxPendingOrdersPath, instId) + data, err := t.doRequest("GET", path, nil) + if err != nil { + logger.Warnf("[OKX] Failed to get pending orders: %v", err) + } + if err == nil && data != nil { + var orders []struct { + OrdId string `json:"ordId"` + InstId string `json:"instId"` + Side string `json:"side"` // buy/sell + PosSide string `json:"posSide"` // long/short/net + OrdType string `json:"ordType"` // limit/market/post_only + Px string `json:"px"` // price + Sz string `json:"sz"` // size + State string `json:"state"` // live/partially_filled + } + if err := json.Unmarshal(data, &orders); err == nil { + for _, order := range orders { + price, _ := strconv.ParseFloat(order.Px, 64) + quantity, _ := strconv.ParseFloat(order.Sz, 64) + + // Convert OKX side to standard format + side := strings.ToUpper(order.Side) + positionSide := strings.ToUpper(order.PosSide) + if positionSide == "NET" { + positionSide = "BOTH" + } + + result = append(result, OpenOrder{ + OrderID: order.OrdId, + Symbol: symbol, + Side: side, + PositionSide: positionSide, + Type: strings.ToUpper(order.OrdType), + Price: price, + StopPrice: 0, + Quantity: quantity, + Status: "NEW", + }) + } + } + } + + // 2. Get pending algo orders (stop-loss/take-profit) + algoPath := fmt.Sprintf("%s?instId=%s&instType=SWAP", okxAlgoPendingPath, instId) + algoData, err := t.doRequest("GET", algoPath, nil) + if err != nil { + logger.Warnf("[OKX] Failed to get algo orders: %v", err) + } + if err == nil && algoData != nil { + var algoOrders []struct { + AlgoId string `json:"algoId"` + InstId string `json:"instId"` + Side string `json:"side"` + PosSide string `json:"posSide"` + OrdType string `json:"ordType"` // conditional/oco/trigger + TriggerPx string `json:"triggerPx"` + Sz string `json:"sz"` + State string `json:"state"` + } + if err := json.Unmarshal(algoData, &algoOrders); err == nil { + for _, order := range algoOrders { + triggerPrice, _ := strconv.ParseFloat(order.TriggerPx, 64) + quantity, _ := strconv.ParseFloat(order.Sz, 64) + + side := strings.ToUpper(order.Side) + positionSide := strings.ToUpper(order.PosSide) + if positionSide == "NET" { + positionSide = "BOTH" + } + + // Map OKX algo order type + orderType := "STOP_MARKET" + if order.OrdType == "oco" { + orderType = "TAKE_PROFIT_MARKET" + } + + result = append(result, OpenOrder{ + OrderID: order.AlgoId, + Symbol: symbol, + Side: side, + PositionSide: positionSide, + Type: orderType, + Price: 0, + StopPrice: triggerPrice, + Quantity: quantity, + Status: "NEW", + }) + } + } + } + + logger.Infof("✓ OKX GetOpenOrders: found %d open orders for %s", len(result), symbol) + return result, nil +} + +// PlaceLimitOrder places a limit order for grid trading +// Implements GridTrader interface +func (t *OKXTrader) PlaceLimitOrder(req *LimitOrderRequest) (*LimitOrderResult, error) { + instId := t.convertSymbol(req.Symbol) + + // Get instrument info + inst, err := t.getInstrument(req.Symbol) + if err != nil { + return nil, fmt.Errorf("failed to get instrument info: %w", err) + } + + // Set leverage if specified + if req.Leverage > 0 { + if err := t.SetLeverage(req.Symbol, req.Leverage); err != nil { + logger.Warnf("[OKX] Failed to set leverage: %v", err) + } + } + + // Convert quantity to contract size + sz := req.Quantity / inst.CtVal + szStr := t.formatSize(sz, inst) + + // Determine side and position side + side := "buy" + posSide := "long" + if req.Side == "SELL" { + side = "sell" + posSide = "short" + } + + body := map[string]interface{}{ + "instId": instId, + "tdMode": "cross", + "side": side, + "posSide": posSide, + "ordType": "limit", + "sz": szStr, + "px": fmt.Sprintf("%.8f", req.Price), + "clOrdId": genOkxClOrdID(), + "tag": okxTag, + } + + // Add reduce only if specified + if req.ReduceOnly { + body["reduceOnly"] = true + } + + logger.Infof("[OKX] PlaceLimitOrder: %s %s @ %.4f, sz=%s", instId, side, req.Price, szStr) + + data, err := t.doRequest("POST", okxOrderPath, body) + if err != nil { + return nil, fmt.Errorf("failed to place limit order: %w", err) + } + + var orders []struct { + OrdId string `json:"ordId"` + ClOrdId string `json:"clOrdId"` + SCode string `json:"sCode"` + SMsg string `json:"sMsg"` + } + + if err := json.Unmarshal(data, &orders); err != nil { + return nil, fmt.Errorf("failed to parse order response: %w", err) + } + + if len(orders) == 0 { + return nil, fmt.Errorf("empty order response") + } + + if orders[0].SCode != "0" { + return nil, fmt.Errorf("OKX order failed: %s", orders[0].SMsg) + } + + logger.Infof("✓ [OKX] Limit order placed: %s %s @ %.4f, orderID=%s", + instId, side, req.Price, orders[0].OrdId) + + return &LimitOrderResult{ + OrderID: orders[0].OrdId, + ClientID: orders[0].ClOrdId, + Symbol: req.Symbol, + Side: req.Side, + PositionSide: req.PositionSide, + Price: req.Price, + Quantity: req.Quantity, + Status: "NEW", + }, nil +} + +// CancelOrder cancels a specific order by ID +// Implements GridTrader interface +func (t *OKXTrader) CancelOrder(symbol, orderID string) error { + instId := t.convertSymbol(symbol) + + body := map[string]interface{}{ + "instId": instId, + "ordId": orderID, + } + + _, err := t.doRequest("POST", "/api/v5/trade/cancel-order", body) + if err != nil { + return fmt.Errorf("failed to cancel order: %w", err) + } + + logger.Infof("✓ [OKX] Order cancelled: %s %s", symbol, orderID) + return nil +} + +// GetOrderBook gets the order book for a symbol +// Implements GridTrader interface +func (t *OKXTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { + instId := t.convertSymbol(symbol) + path := fmt.Sprintf("/api/v5/market/books?instId=%s&sz=%d", instId, depth) + + data, err := t.doRequest("GET", path, nil) + if err != nil { + return nil, nil, fmt.Errorf("failed to get order book: %w", err) + } + + var result []struct { + Bids [][]string `json:"bids"` + Asks [][]string `json:"asks"` + } + + if err := json.Unmarshal(data, &result); err != nil { + return nil, nil, fmt.Errorf("failed to parse order book: %w", err) + } + + if len(result) == 0 { + return nil, nil, nil + } + + // Parse bids + for _, b := range result[0].Bids { + if len(b) >= 2 { + price, _ := strconv.ParseFloat(b[0], 64) + qty, _ := strconv.ParseFloat(b[1], 64) + bids = append(bids, []float64{price, qty}) + } + } + + // Parse asks + for _, a := range result[0].Asks { + if len(a) >= 2 { + price, _ := strconv.ParseFloat(a[0], 64) + qty, _ := strconv.ParseFloat(a[1], 64) + asks = append(asks, []float64{price, qty}) + } + } + + return bids, asks, nil } diff --git a/web/src/components/strategy/GridConfigEditor.tsx b/web/src/components/strategy/GridConfigEditor.tsx new file mode 100644 index 00000000..7756f2ba --- /dev/null +++ b/web/src/components/strategy/GridConfigEditor.tsx @@ -0,0 +1,424 @@ +import { Grid, DollarSign, TrendingUp, Shield } from 'lucide-react' +import type { GridStrategyConfig } from '../../types' + +interface GridConfigEditorProps { + config: GridStrategyConfig + onChange: (config: GridStrategyConfig) => void + disabled?: boolean + language: string +} + +// Default grid config +export const defaultGridConfig: GridStrategyConfig = { + symbol: 'BTCUSDT', + grid_count: 10, + total_investment: 1000, + leverage: 5, + upper_price: 0, + lower_price: 0, + use_atr_bounds: true, + atr_multiplier: 2.0, + distribution: 'gaussian', + max_drawdown_pct: 15, + stop_loss_pct: 5, + daily_loss_limit_pct: 10, + use_maker_only: true, +} + +export function GridConfigEditor({ + config, + onChange, + disabled, + language, +}: GridConfigEditorProps) { + const t = (key: string) => { + const translations: Record> = { + // Section titles + tradingPair: { zh: '交易设置', en: 'Trading Setup' }, + gridParameters: { zh: '网格参数', en: 'Grid Parameters' }, + priceBounds: { zh: '价格边界', en: 'Price Bounds' }, + riskControl: { zh: '风险控制', en: 'Risk Control' }, + + // Trading pair + symbol: { zh: '交易对', en: 'Trading Pair' }, + symbolDesc: { zh: '选择要进行网格交易的交易对', en: 'Select trading pair for grid trading' }, + + // Investment + totalInvestment: { zh: '投资金额 (USDT)', en: 'Investment (USDT)' }, + totalInvestmentDesc: { zh: '网格策略的总投资金额', en: 'Total investment for grid strategy' }, + leverage: { zh: '杠杆倍数', en: 'Leverage' }, + leverageDesc: { zh: '交易使用的杠杆倍数 (1-20)', en: 'Leverage for trading (1-20)' }, + + // Grid parameters + gridCount: { zh: '网格数量', en: 'Grid Count' }, + gridCountDesc: { zh: '网格层级数量 (5-50)', en: 'Number of grid levels (5-50)' }, + distribution: { zh: '资金分配方式', en: 'Distribution' }, + distributionDesc: { zh: '网格层级的资金分配方式', en: 'Fund allocation across grid levels' }, + uniform: { zh: '均匀分配', en: 'Uniform' }, + gaussian: { zh: '高斯分配 (推荐)', en: 'Gaussian (Recommended)' }, + pyramid: { zh: '金字塔分配', en: 'Pyramid' }, + + // Price bounds + useAtrBounds: { zh: '自动计算边界 (ATR)', en: 'Auto-calculate Bounds (ATR)' }, + useAtrBoundsDesc: { zh: '基于 ATR 自动计算网格上下边界', en: 'Auto-calculate bounds based on ATR' }, + atrMultiplier: { zh: 'ATR 倍数', en: 'ATR Multiplier' }, + atrMultiplierDesc: { zh: '边界距离当前价格的 ATR 倍数', en: 'ATR multiplier for bounds distance' }, + upperPrice: { zh: '上边界价格', en: 'Upper Price' }, + upperPriceDesc: { zh: '网格上边界价格 (0=自动计算)', en: 'Grid upper bound (0=auto)' }, + lowerPrice: { zh: '下边界价格', en: 'Lower Price' }, + lowerPriceDesc: { zh: '网格下边界价格 (0=自动计算)', en: 'Grid lower bound (0=auto)' }, + + // Risk control + maxDrawdown: { zh: '最大回撤 (%)', en: 'Max Drawdown (%)' }, + maxDrawdownDesc: { zh: '触发紧急退出的最大回撤百分比', en: 'Max drawdown before emergency exit' }, + stopLoss: { zh: '止损 (%)', en: 'Stop Loss (%)' }, + stopLossDesc: { zh: '单仓位止损百分比', en: 'Stop loss per position' }, + dailyLossLimit: { zh: '日损失限制 (%)', en: 'Daily Loss Limit (%)' }, + dailyLossLimitDesc: { zh: '每日最大亏损百分比', en: 'Maximum daily loss percentage' }, + useMakerOnly: { zh: '仅使用 Maker 订单', en: 'Maker Only Orders' }, + useMakerOnlyDesc: { zh: '使用限价单以降低手续费', en: 'Use limit orders for lower fees' }, + } + return translations[key]?.[language] || key + } + + const updateField = ( + key: K, + value: GridStrategyConfig[K] + ) => { + if (!disabled) { + onChange({ ...config, [key]: value }) + } + } + + const inputStyle = { + background: '#1E2329', + border: '1px solid #2B3139', + color: '#EAECEF', + } + + const sectionStyle = { + background: '#0B0E11', + border: '1px solid #2B3139', + } + + return ( +
+ {/* Trading Setup */} +
+
+ +

+ {t('tradingPair')} +

+
+ +
+ {/* Symbol */} +
+ +

+ {t('symbolDesc')} +

+ +
+ + {/* Investment */} +
+ +

+ {t('totalInvestmentDesc')} +

+ updateField('total_investment', parseFloat(e.target.value) || 1000)} + disabled={disabled} + min={100} + step={100} + className="w-full px-3 py-2 rounded" + style={inputStyle} + /> +
+ + {/* Leverage */} +
+ +

+ {t('leverageDesc')} +

+ updateField('leverage', parseInt(e.target.value) || 5)} + disabled={disabled} + min={1} + max={20} + className="w-full px-3 py-2 rounded" + style={inputStyle} + /> +
+
+
+ + {/* Grid Parameters */} +
+
+ +

+ {t('gridParameters')} +

+
+ +
+ {/* Grid Count */} +
+ +

+ {t('gridCountDesc')} +

+ updateField('grid_count', parseInt(e.target.value) || 10)} + disabled={disabled} + min={5} + max={50} + className="w-full px-3 py-2 rounded" + style={inputStyle} + /> +
+ + {/* Distribution */} +
+ +

+ {t('distributionDesc')} +

+ +
+
+
+ + {/* Price Bounds */} +
+
+ +

+ {t('priceBounds')} +

+
+ + {/* ATR Toggle */} +
+
+
+ +

+ {t('useAtrBoundsDesc')} +

+
+ +
+
+ + {config.use_atr_bounds ? ( +
+ +

+ {t('atrMultiplierDesc')} +

+ updateField('atr_multiplier', parseFloat(e.target.value) || 2.0)} + disabled={disabled} + min={1} + max={5} + step={0.5} + className="w-32 px-3 py-2 rounded" + style={inputStyle} + /> +
+ ) : ( +
+
+ +

+ {t('upperPriceDesc')} +

+ updateField('upper_price', parseFloat(e.target.value) || 0)} + disabled={disabled} + min={0} + step={0.01} + className="w-full px-3 py-2 rounded" + style={inputStyle} + /> +
+
+ +

+ {t('lowerPriceDesc')} +

+ updateField('lower_price', parseFloat(e.target.value) || 0)} + disabled={disabled} + min={0} + step={0.01} + className="w-full px-3 py-2 rounded" + style={inputStyle} + /> +
+
+ )} +
+ + {/* Risk Control */} +
+
+ +

+ {t('riskControl')} +

+
+ +
+
+ +

+ {t('maxDrawdownDesc')} +

+ updateField('max_drawdown_pct', parseFloat(e.target.value) || 15)} + disabled={disabled} + min={5} + max={50} + className="w-full px-3 py-2 rounded" + style={inputStyle} + /> +
+ +
+ +

+ {t('stopLossDesc')} +

+ updateField('stop_loss_pct', parseFloat(e.target.value) || 5)} + disabled={disabled} + min={1} + max={20} + className="w-full px-3 py-2 rounded" + style={inputStyle} + /> +
+ +
+ +

+ {t('dailyLossLimitDesc')} +

+ updateField('daily_loss_limit_pct', parseFloat(e.target.value) || 10)} + disabled={disabled} + min={1} + max={30} + className="w-full px-3 py-2 rounded" + style={inputStyle} + /> +
+
+ + {/* Maker Only Toggle */} +
+
+
+ +

+ {t('useMakerOnlyDesc')} +

+
+ +
+
+
+
+ ) +} diff --git a/web/src/components/strategy/GridRiskPanel.tsx b/web/src/components/strategy/GridRiskPanel.tsx new file mode 100644 index 00000000..75c5b73b --- /dev/null +++ b/web/src/components/strategy/GridRiskPanel.tsx @@ -0,0 +1,372 @@ +import { useState, useEffect, useCallback } from 'react' +import { Shield, TrendingUp, AlertTriangle, Activity, Box, ChevronDown, ChevronUp } from 'lucide-react' +import type { GridRiskInfo } from '../../types' + +interface GridRiskPanelProps { + traderId: string + language?: string + refreshInterval?: number // ms, default 5000 +} + +export function GridRiskPanel({ + traderId, + language = 'en', + refreshInterval = 5000, +}: GridRiskPanelProps) { + const [riskInfo, setRiskInfo] = useState(null) + const [loading, setLoading] = useState(true) + const [error, setError] = useState(null) + const [expanded, setExpanded] = useState(false) + + const t = (key: string) => { + const translations: Record> = { + // Section titles + gridRisk: { zh: '网格风控', en: 'Grid Risk' }, + leverageInfo: { zh: '杠杆', en: 'Leverage' }, + positionInfo: { zh: '仓位', en: 'Position' }, + liquidationInfo: { zh: '清算', en: 'Liquidation' }, + marketState: { zh: '市场', en: 'Market' }, + boxState: { zh: '箱体', en: 'Box' }, + + // Leverage + currentLeverage: { zh: '当前', en: 'Current' }, + effectiveLeverage: { zh: '有效', en: 'Effective' }, + recommendedLeverage: { zh: '建议', en: 'Recommend' }, + + // Position + currentPosition: { zh: '当前', en: 'Current' }, + maxPosition: { zh: '最大', en: 'Max' }, + positionPercent: { zh: '占比', en: 'Usage' }, + + // Liquidation + liquidationPrice: { zh: '清算价', en: 'Liq Price' }, + liquidationDistance: { zh: '距离', en: 'Distance' }, + + // Market + regimeLevel: { zh: '波动', en: 'Regime' }, + currentPrice: { zh: '价格', en: 'Price' }, + breakoutLevel: { zh: '突破', en: 'Breakout' }, + breakoutDirection: { zh: '方向', en: 'Direction' }, + + // Box + shortBox: { zh: '短期', en: 'Short' }, + midBox: { zh: '中期', en: 'Mid' }, + longBox: { zh: '长期', en: 'Long' }, + + // Regime levels + narrow: { zh: '窄幅', en: 'Narrow' }, + standard: { zh: '标准', en: 'Standard' }, + wide: { zh: '宽幅', en: 'Wide' }, + volatile: { zh: '剧烈', en: 'Volatile' }, + trending: { zh: '趋势', en: 'Trending' }, + + // Breakout levels + none: { zh: '无', en: 'None' }, + short: { zh: '短期', en: 'Short' }, + mid: { zh: '中期', en: 'Mid' }, + long: { zh: '长期', en: 'Long' }, + + // Directions + up: { zh: '↑', en: '↑' }, + down: { zh: '↓', en: '↓' }, + + // Status + loading: { zh: '加载中...', en: 'Loading...' }, + error: { zh: '加载失败', en: 'Load Failed' }, + noData: { zh: '暂无数据', en: 'No Data' }, + } + return translations[key]?.[language] || key + } + + const fetchRiskInfo = useCallback(async () => { + try { + const token = localStorage.getItem('auth_token') + const response = await fetch(`/api/traders/${traderId}/grid-risk`, { + headers: { + Authorization: `Bearer ${token}`, + }, + }) + + if (!response.ok) { + throw new Error(`HTTP ${response.status}`) + } + + const data = await response.json() + setRiskInfo(data) + setError(null) + } catch (err) { + setError(err instanceof Error ? err.message : 'Unknown error') + } finally { + setLoading(false) + } + }, [traderId]) + + useEffect(() => { + fetchRiskInfo() + const interval = setInterval(fetchRiskInfo, refreshInterval) + return () => clearInterval(interval) + }, [fetchRiskInfo, refreshInterval]) + + const getRegimeColor = (regime: string) => { + switch (regime) { + case 'narrow': return '#0ECB81' + case 'standard': return '#F0B90B' + case 'wide': return '#F7931A' + case 'volatile': return '#F6465D' + case 'trending': return '#8B5CF6' + default: return '#848E9C' + } + } + + const getBreakoutColor = (level: string) => { + switch (level) { + case 'none': return '#0ECB81' + case 'short': return '#F0B90B' + case 'mid': return '#F7931A' + case 'long': return '#F6465D' + default: return '#848E9C' + } + } + + const getPositionColor = (percent: number) => { + if (percent < 50) return '#0ECB81' + if (percent < 80) return '#F0B90B' + return '#F6465D' + } + + const formatPrice = (price: number) => { + if (price === 0) return '-' + if (price >= 1000) return price.toLocaleString('en-US', { minimumFractionDigits: 2, maximumFractionDigits: 2 }) + if (price >= 1) return price.toFixed(4) + return price.toFixed(6) + } + + const formatUSD = (value: number) => { + return `$${value.toLocaleString('en-US', { minimumFractionDigits: 0, maximumFractionDigits: 0 })}` + } + + const cardStyle = { + background: '#0B0E11', + border: '1px solid #2B3139', + } + + if (loading) { + return ( +
+ {t('loading')} +
+ ) + } + + if (error) { + return ( +
+ {t('error')}: {error} +
+ ) + } + + if (!riskInfo) { + return ( +
+ {t('noData')} +
+ ) + } + + return ( +
+ {/* Collapsible Header */} +
setExpanded(!expanded)} + > +
+ + + {t('gridRisk')} + +
+
+ {/* Summary badges when collapsed */} +
+ + {t(riskInfo.regime_level || 'standard')} + + + {riskInfo.effective_leverage.toFixed(1)}x + + + {riskInfo.position_percent.toFixed(0)}% + +
+ {expanded ? ( + + ) : ( + + )} +
+
+ + {/* Expanded Content */} + {expanded && ( +
+ {/* Row 1: Leverage & Position */} +
+ {/* Leverage */} +
+
+ + {t('leverageInfo')} +
+
+
+
{t('currentLeverage')}
+
{riskInfo.current_leverage}x
+
+
+
{t('effectiveLeverage')}
+
{riskInfo.effective_leverage.toFixed(2)}x
+
+
+
{t('recommendedLeverage')}
+
riskInfo.recommended_leverage ? '#F6465D' : '#0ECB81' }} + > + {riskInfo.recommended_leverage}x +
+
+
+
+ + {/* Position */} +
+
+ + {t('positionInfo')} +
+
+
+
{t('currentPosition')}
+
{formatUSD(riskInfo.current_position)}
+
+
+
{t('maxPosition')}
+
{formatUSD(riskInfo.max_position)}
+
+
+
{t('positionPercent')}
+
+ {riskInfo.position_percent.toFixed(1)}% +
+
+
+ {/* Mini progress bar */} +
+
+
+
+
+ + {/* Row 2: Market State & Liquidation */} +
+ {/* Market State */} +
+
+ + {t('marketState')} +
+
+
+
{t('regimeLevel')}
+
+ {t(riskInfo.regime_level || 'standard')} +
+
+
+
{t('currentPrice')}
+
{formatPrice(riskInfo.current_price)}
+
+
+
{t('breakoutLevel')}
+
+ {t(riskInfo.breakout_level || 'none')} +
+
+
+
{t('breakoutDirection')}
+
+ {riskInfo.breakout_direction ? t(riskInfo.breakout_direction) : '-'} +
+
+
+
+ + {/* Liquidation */} +
+
+ + {t('liquidationInfo')} +
+
+
+
{t('liquidationPrice')}
+
+ {riskInfo.liquidation_price > 0 ? formatPrice(riskInfo.liquidation_price) : '-'} +
+
+
+
{t('liquidationDistance')}
+
+ {riskInfo.liquidation_distance > 0 ? `${riskInfo.liquidation_distance.toFixed(1)}%` : '-'} +
+
+
+
+
+ + {/* Row 3: Box State */} +
+
+ + {t('boxState')} +
+
+
+ {t('shortBox')} + + {formatPrice(riskInfo.short_box_lower)} - {formatPrice(riskInfo.short_box_upper)} + +
+
+ {t('midBox')} + + {formatPrice(riskInfo.mid_box_lower)} - {formatPrice(riskInfo.mid_box_upper)} + +
+
+ {t('longBox')} + + {formatPrice(riskInfo.long_box_lower)} - {formatPrice(riskInfo.long_box_upper)} + +
+
+
+
+ )} +
+ ) +} diff --git a/web/src/pages/StrategyStudioPage.tsx b/web/src/pages/StrategyStudioPage.tsx index fcb3fc17..8b21d502 100644 --- a/web/src/pages/StrategyStudioPage.tsx +++ b/web/src/pages/StrategyStudioPage.tsx @@ -37,6 +37,7 @@ import { IndicatorEditor } from '../components/strategy/IndicatorEditor' import { RiskControlEditor } from '../components/strategy/RiskControlEditor' import { PromptSectionsEditor } from '../components/strategy/PromptSectionsEditor' import { PublishSettingsEditor } from '../components/strategy/PublishSettingsEditor' +import { GridConfigEditor, defaultGridConfig } from '../components/strategy/GridConfigEditor' import { DeepVoidBackground } from '../components/DeepVoidBackground' const API_BASE = import.meta.env.VITE_API_BASE || '' @@ -59,6 +60,7 @@ export function StrategyStudioPage() { // Accordion states for left panel const [expandedSections, setExpandedSections] = useState({ + gridConfig: true, coinSource: true, indicators: false, riskControl: false, @@ -486,6 +488,12 @@ export function StrategyStudioPage() { subtitle: { zh: '可视化配置和测试交易策略', en: 'Configure and test trading strategies' }, strategies: { zh: '策略', en: 'Strategies' }, newStrategy: { zh: '新建', en: 'New' }, + strategyType: { zh: '策略类型', en: 'Strategy Type' }, + aiTrading: { zh: 'AI 智能交易', en: 'AI Trading' }, + aiTradingDesc: { zh: 'AI 分析市场并自主决策买卖', en: 'AI analyzes market and makes trading decisions' }, + gridTrading: { zh: 'AI 网格交易', en: 'AI Grid Trading' }, + gridTradingDesc: { zh: 'AI 控制网格策略,在震荡市场获利', en: 'AI-controlled grid strategy for ranging markets' }, + gridConfig: { zh: '网格配置', en: 'Grid Configuration' }, coinSource: { zh: '币种来源', en: 'Coin Source' }, indicators: { zh: '技术指标', en: 'Indicators' }, riskControl: { zh: '风控参数', en: 'Risk Control' }, @@ -533,12 +541,33 @@ export function StrategyStudioPage() { ) } + // Get current strategy type (default to ai_trading if not set) + const currentStrategyType = editingConfig?.strategy_type || 'ai_trading' + const configSections = [ + // Grid Config - only for grid_trading + { + key: 'gridConfig' as const, + icon: Activity, + color: '#0ECB81', + title: t('gridConfig'), + forStrategyType: 'grid_trading' as const, + content: editingConfig?.grid_config && ( + updateConfig('grid_config', gridConfig)} + disabled={selectedStrategy?.is_default} + language={language} + /> + ), + }, + // AI Trading sections { key: 'coinSource' as const, icon: Target, color: '#F0B90B', title: t('coinSource'), + forStrategyType: 'ai_trading' as const, content: editingConfig && (

@@ -616,6 +649,7 @@ export function StrategyStudioPage() { icon: Globe, color: '#0ECB81', title: t('publishSettings'), + forStrategyType: 'both' as const, content: selectedStrategy && ( ), }, - ] + ].filter(section => + section.forStrategyType === 'both' || section.forStrategyType === currentStrategyType + ) return ( @@ -813,6 +849,62 @@ export function StrategyStudioPage() {

+ {/* Strategy Type Selector */} + {editingConfig && ( +
+
+ + {t('strategyType')} +
+
+ + +
+
+ )} + {/* Config Sections */}
{configSections.map(({ key, icon: Icon, color, title, content }) => ( diff --git a/web/src/pages/TraderDashboardPage.tsx b/web/src/pages/TraderDashboardPage.tsx index 6f60b76f..3c46f2f1 100644 --- a/web/src/pages/TraderDashboardPage.tsx +++ b/web/src/pages/TraderDashboardPage.tsx @@ -9,6 +9,7 @@ import { confirmToast, notify } from '../lib/notify' import { t, type Language } from '../i18n/translations' import { LogOut, Loader2, Eye, EyeOff, Copy, Check } from 'lucide-react' import { DeepVoidBackground } from '../components/DeepVoidBackground' +import { GridRiskPanel } from '../components/strategy/GridRiskPanel' import type { SystemStatus, AccountInfo, @@ -151,6 +152,13 @@ export function TraderDashboardPage({ setPositionsCurrentPage(1) }, [selectedTraderId, positionsPageSize]) + // Auto-set chart symbol for grid trading + useEffect(() => { + if (status?.strategy_type === 'grid_trading' && status?.grid_symbol) { + setSelectedChartSymbol(status.grid_symbol) + } + }, [status?.strategy_type, status?.grid_symbol]) + // Get current exchange info for perp-dex wallet display const currentExchange = exchanges?.find( (e) => e.id === selectedTrader?.exchange_id @@ -532,6 +540,17 @@ export function TraderDashboardPage({ />
+ {/* Grid Risk Panel - Only show for grid trading strategy */} + {status?.strategy_type === 'grid_trading' && selectedTraderId && ( +
+ +
+ )} + {/* Main Content Area */}
{/* Left Column: Charts + Positions */} diff --git a/web/src/types.ts b/web/src/types.ts index b4f39278..d2e1d1b0 100644 --- a/web/src/types.ts +++ b/web/src/types.ts @@ -11,6 +11,8 @@ export interface SystemStatus { stop_until: string last_reset_time: string ai_provider: string + strategy_type?: 'ai_trading' | 'grid_trading' + grid_symbol?: string } export interface AccountInfo { @@ -462,6 +464,8 @@ export interface PromptSectionsConfig { } export interface StrategyConfig { + // Strategy type: "ai_trading" (default) or "grid_trading" + strategy_type?: 'ai_trading' | 'grid_trading'; // Language setting: "zh" for Chinese, "en" for English // Determines the language used for data formatting and prompt generation language?: 'zh' | 'en'; @@ -470,6 +474,38 @@ export interface StrategyConfig { custom_prompt?: string; risk_control: RiskControlConfig; prompt_sections?: PromptSectionsConfig; + // Grid trading configuration (only used when strategy_type is 'grid_trading') + grid_config?: GridStrategyConfig; +} + +// Grid trading specific configuration +export interface GridStrategyConfig { + // Trading pair (e.g., "BTCUSDT") + symbol: string; + // Number of grid levels (5-50) + grid_count: number; + // Total investment in USDT + total_investment: number; + // Leverage (1-20) + leverage: number; + // Upper price boundary (0 = auto-calculate from ATR) + upper_price: number; + // Lower price boundary (0 = auto-calculate from ATR) + lower_price: number; + // Use ATR to auto-calculate bounds + use_atr_bounds: boolean; + // ATR multiplier for bound calculation (default 2.0) + atr_multiplier: number; + // Position distribution: "uniform" | "gaussian" | "pyramid" + distribution: 'uniform' | 'gaussian' | 'pyramid'; + // Maximum drawdown percentage before emergency exit + max_drawdown_pct: number; + // Stop loss percentage per position + stop_loss_pct: number; + // Daily loss limit percentage + daily_loss_limit_pct: number; + // Use maker-only orders for lower fees + use_maker_only: boolean; } export interface CoinSourceConfig { @@ -750,3 +786,36 @@ export interface PositionHistoryResponse { symbol_stats: SymbolStats[]; direction_stats: DirectionStats[]; } + +// Grid Risk Information for frontend display +export interface GridRiskInfo { + // Leverage info + current_leverage: number + effective_leverage: number + recommended_leverage: number + + // Position info + current_position: number + max_position: number + position_percent: number + + // Liquidation info + liquidation_price: number + liquidation_distance: number + + // Market state + regime_level: string + + // Box state + short_box_upper: number + short_box_lower: number + mid_box_upper: number + mid_box_lower: number + long_box_upper: number + long_box_lower: number + current_price: number + + // Breakout state + breakout_level: string + breakout_direction: string +}