diff --git a/api/server.go b/api/server.go index 99eeb8c0..dfe4d7b3 100644 --- a/api/server.go +++ b/api/server.go @@ -157,6 +157,7 @@ func (s *Server) setupRoutes() { protected.POST("/traders/:id/sync-balance", s.handleSyncBalance) protected.POST("/traders/:id/close-position", s.handleClosePosition) protected.PUT("/traders/:id/competition", s.handleToggleCompetition) + protected.GET("/traders/:id/grid-risk", s.handleGetGridRiskInfo) // AI model configuration protected.GET("/models", s.handleGetModelConfigs) @@ -1096,6 +1097,20 @@ func (s *Server) handleToggleCompetition(c *gin.Context) { }) } +// handleGetGridRiskInfo returns current risk information for a grid trader +func (s *Server) handleGetGridRiskInfo(c *gin.Context) { + traderID := c.Param("id") + + autoTrader, err := s.traderManager.GetTrader(traderID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "trader not found"}) + return + } + + riskInfo := autoTrader.GetGridRiskInfo() + c.JSON(http.StatusOK, riskInfo) +} + // handleSyncBalance Sync exchange balance to initial_balance (Option B: Manual Sync + Option C: Smart Detection) func (s *Server) handleSyncBalance(c *gin.Context) { userID := c.GetString("user_id") @@ -1369,7 +1384,7 @@ func (s *Server) handleClosePosition(c *gin.Context) { if closeErr != nil { logger.Infof("❌ Close position failed: symbol=%s, side=%s, error=%v", req.Symbol, req.Side, closeErr) - SafeInternalError(c, "Failed to close position", closeErr) + SafeInternalError(c, "Close position", closeErr) return } @@ -1705,8 +1720,15 @@ func (s *Server) handleUpdateModelConfigs(c *gin.Context) { logger.Infof("🔓 Decrypted model config data (UserID: %s)", userID) } - // Update each model's configuration + // Update each model's configuration and track traders that need reload + tradersToReload := make(map[string]bool) for modelID, modelData := range req.Models { + // Find traders using this AI model BEFORE updating + traders, _ := s.store.Trader().ListByAIModelID(userID, modelID) + for _, t := range traders { + tradersToReload[t.ID] = true + } + err := s.store.AIModel().Update(userID, modelID, modelData.Enabled, modelData.APIKey, modelData.CustomAPIURL, modelData.CustomModelName) if err != nil { SafeInternalError(c, fmt.Sprintf("Update model %s", modelID), err) @@ -1714,6 +1736,12 @@ func (s *Server) handleUpdateModelConfigs(c *gin.Context) { } } + // Remove affected traders from memory BEFORE reloading to pick up new config + for traderID := range tradersToReload { + logger.Infof("🔄 Removing trader %s from memory to reload with new AI model config", traderID) + s.traderManager.RemoveTrader(traderID) + } + // Reload all traders for this user to make new config take effect immediately err = s.traderManager.LoadUserTradersFromStore(s.store, userID) if err != nil { @@ -1825,8 +1853,15 @@ func (s *Server) handleUpdateExchangeConfigs(c *gin.Context) { logger.Infof("🔓 Decrypted exchange config data (UserID: %s)", userID) } - // Update each exchange's configuration + // Update each exchange's configuration and track traders that need reload + tradersToReload := make(map[string]bool) for exchangeID, exchangeData := range req.Exchanges { + // Find traders using this exchange BEFORE updating + traders, _ := s.store.Trader().ListByExchangeID(userID, exchangeID) + for _, t := range traders { + tradersToReload[t.ID] = true + } + err := s.store.Exchange().Update(userID, exchangeID, exchangeData.Enabled, exchangeData.APIKey, exchangeData.SecretKey, exchangeData.Passphrase, exchangeData.Testnet, exchangeData.HyperliquidWalletAddr, exchangeData.AsterUser, exchangeData.AsterSigner, exchangeData.AsterPrivateKey, exchangeData.LighterWalletAddr, exchangeData.LighterPrivateKey, exchangeData.LighterAPIKeyPrivateKey, exchangeData.LighterAPIKeyIndex) if err != nil { SafeInternalError(c, fmt.Sprintf("Update exchange %s", exchangeID), err) @@ -1834,6 +1869,12 @@ func (s *Server) handleUpdateExchangeConfigs(c *gin.Context) { } } + // Remove affected traders from memory BEFORE reloading to pick up new config + for traderID := range tradersToReload { + logger.Infof("🔄 Removing trader %s from memory to reload with new exchange config", traderID) + s.traderManager.RemoveTrader(traderID) + } + // Reload all traders for this user to make new config take effect immediately err = s.traderManager.LoadUserTradersFromStore(s.store, userID) if err != nil { diff --git a/cmd/lighter_test/main.go b/cmd/lighter_test/main.go new file mode 100644 index 00000000..6f896a23 --- /dev/null +++ b/cmd/lighter_test/main.go @@ -0,0 +1,233 @@ +// Lighter API Authentication Test Tool +// Usage: go run cmd/lighter_test/main.go -wallet=0x... -apikey=... [-testnet] +package main + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "io" + "net/http" + "net/url" + "os" + "time" + + lighterClient "github.com/elliottech/lighter-go/client" + lighterHTTP "github.com/elliottech/lighter-go/client/http" +) + +func main() { + // Parse command line flags + walletAddr := flag.String("wallet", "", "Ethereum wallet address") + apiKeyPrivateKey := flag.String("apikey", "", "API key private key (40 bytes hex)") + apiKeyIndex := flag.Int("apikeyindex", 0, "API key index (0-255)") + testnet := flag.Bool("testnet", false, "Use testnet instead of mainnet") + flag.Parse() + + if *walletAddr == "" || *apiKeyPrivateKey == "" { + fmt.Println("Usage: go run cmd/lighter_test/main.go -wallet=0x... -apikey=...") + fmt.Println("Options:") + fmt.Println(" -wallet Ethereum wallet address (required)") + fmt.Println(" -apikey API key private key, 40 bytes hex (required)") + fmt.Println(" -apikeyindex API key index, 0-255 (default: 0)") + fmt.Println(" -testnet Use testnet instead of mainnet") + os.Exit(1) + } + + fmt.Println("=== Lighter API Authentication Test ===") + fmt.Printf("Wallet: %s\n", *walletAddr) + fmt.Printf("API Key Index: %d\n", *apiKeyIndex) + fmt.Printf("Testnet: %v\n", *testnet) + fmt.Println() + + // Determine base URL + baseURL := "https://mainnet.zklighter.elliot.ai" + chainID := uint32(304) + if *testnet { + baseURL = "https://testnet.zklighter.elliot.ai" + chainID = uint32(300) + } + + // Create HTTP client + httpClient := lighterHTTP.NewClient(baseURL) + client := &http.Client{Timeout: 30 * time.Second} + + // Step 1: Get account info + fmt.Println("Step 1: Getting account info...") + accountInfo, err := getAccountByL1Address(client, baseURL, *walletAddr) + if err != nil { + fmt.Printf("ERROR: Failed to get account info: %v\n", err) + os.Exit(1) + } + fmt.Printf("SUCCESS: Account index = %d\n\n", accountInfo.AccountIndex) + + // Step 2: Create TxClient + fmt.Println("Step 2: Creating TxClient...") + txClient, err := lighterClient.NewTxClient( + httpClient, + *apiKeyPrivateKey, + accountInfo.AccountIndex, + uint8(*apiKeyIndex), + chainID, + ) + if err != nil { + fmt.Printf("ERROR: Failed to create TxClient: %v\n", err) + os.Exit(1) + } + fmt.Println("SUCCESS: TxClient created\n") + + // Step 3: Generate auth token + fmt.Println("Step 3: Generating auth token...") + deadline := time.Now().Add(1 * time.Hour) + authToken, err := txClient.GetAuthToken(deadline) + if err != nil { + fmt.Printf("ERROR: Failed to generate auth token: %v\n", err) + os.Exit(1) + } + fmt.Printf("SUCCESS: Auth token generated\n") + fmt.Printf("Token: %s...\n", authToken[:min(50, len(authToken))]) + fmt.Printf("Valid until: %s\n\n", deadline.Format(time.RFC3339)) + + // Step 4: Test GetActiveOrders API with auth query parameter + fmt.Println("Step 4: Testing GetActiveOrders API...") + encodedAuth := url.QueryEscape(authToken) + endpoint := fmt.Sprintf("%s/api/v1/accountActiveOrders?account_index=%d&market_id=0&auth=%s", + baseURL, accountInfo.AccountIndex, encodedAuth) + + fmt.Printf("Endpoint: %s...\n", endpoint[:min(120, len(endpoint))]) + + req, err := http.NewRequest("GET", endpoint, nil) + if err != nil { + fmt.Printf("ERROR: Failed to create request: %v\n", err) + os.Exit(1) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) + if err != nil { + fmt.Printf("ERROR: Request failed: %v\n", err) + os.Exit(1) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + fmt.Printf("Status: %d\n", resp.StatusCode) + fmt.Printf("Response: %s\n\n", string(body)) + + // Parse response + var apiResp struct { + Code int `json:"code"` + Message string `json:"message"` + Orders []struct { + OrderID string `json:"order_id"` + Side string `json:"side"` + Type string `json:"type"` + Price string `json:"price"` + } `json:"orders"` + } + if err := json.Unmarshal(body, &apiResp); err != nil { + fmt.Printf("ERROR: Failed to parse response: %v\n", err) + os.Exit(1) + } + + if apiResp.Code != 200 { + fmt.Printf("API ERROR: code=%d, message=%s\n", apiResp.Code, apiResp.Message) + fmt.Println("\n=== DIAGNOSTIC INFO ===") + fmt.Println("If you see 'invalid signature', possible causes:") + fmt.Println("1. API key is not registered on-chain") + fmt.Println("2. API key private key is incorrect") + fmt.Println("3. API key index is wrong") + fmt.Println("4. Account index mismatch") + fmt.Println("\nTo fix:") + fmt.Println("- Go to app.lighter.xyz and register/verify your API key") + fmt.Println("- Make sure you're using the correct API key private key") + os.Exit(1) + } + + fmt.Printf("SUCCESS: Retrieved %d orders\n", len(apiResp.Orders)) + for i, order := range apiResp.Orders { + if i >= 5 { + fmt.Printf("... and %d more orders\n", len(apiResp.Orders)-5) + break + } + fmt.Printf(" Order %s: %s %s @ %s\n", order.OrderID, order.Side, order.Type, order.Price) + } + + // Step 5: Test GetTrades API (also needs auth) + fmt.Println("\nStep 5: Testing GetTrades API...") + tradesEndpoint := fmt.Sprintf("%s/api/v1/trades?account_index=%d&sort_by=timestamp&sort_dir=desc&limit=5&auth=%s", + baseURL, accountInfo.AccountIndex, encodedAuth) + + tradesReq, _ := http.NewRequest("GET", tradesEndpoint, nil) + tradesResp, err := client.Do(tradesReq) + if err != nil { + fmt.Printf("ERROR: Trades request failed: %v\n", err) + } else { + defer tradesResp.Body.Close() + tradesBody, _ := io.ReadAll(tradesResp.Body) + fmt.Printf("Status: %d\n", tradesResp.StatusCode) + if tradesResp.StatusCode == 200 { + fmt.Println("SUCCESS: GetTrades API working") + } else { + fmt.Printf("Response: %s\n", string(tradesBody)) + } + } + + fmt.Println("\n=== ALL TESTS PASSED ===") +} + +// AccountInfo represents Lighter account information +type AccountInfo struct { + AccountIndex int64 `json:"account_index"` + L1Address string `json:"l1_address"` +} + +// getAccountByL1Address gets account info by L1 wallet address +func getAccountByL1Address(client *http.Client, baseURL, walletAddr string) (*AccountInfo, error) { + endpoint := fmt.Sprintf("%s/api/v1/account?by=l1_address&value=%s", baseURL, walletAddr) + + req, err := http.NewRequest("GET", endpoint, nil) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + req = req.WithContext(ctx) + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + // Parse response - can be in "accounts" or "sub_accounts" field + var apiResp struct { + Code int `json:"code"` + Message string `json:"message"` + Accounts []AccountInfo `json:"accounts"` + SubAccounts []AccountInfo `json:"sub_accounts"` + } + + if err := json.Unmarshal(body, &apiResp); err != nil { + return nil, fmt.Errorf("failed to parse response: %w, body: %s", err, string(body)) + } + + // Check main accounts first + if len(apiResp.Accounts) > 0 { + return &apiResp.Accounts[0], nil + } + + // Check sub-accounts + if len(apiResp.SubAccounts) > 0 { + return &apiResp.SubAccounts[0], nil + } + + return nil, fmt.Errorf("no account found for address: %s", walletAddr) +} diff --git a/docs/market-regime-classification-en.md b/docs/market-regime-classification-en.md new file mode 100644 index 00000000..ddc83c89 --- /dev/null +++ b/docs/market-regime-classification-en.md @@ -0,0 +1,281 @@ +# Market Regime Classification Framework + +> A comprehensive market state identification system for quantitative trading strategy matching + +--- + +## 1. Classification Dimensions Overview + +Market state identification requires analysis across multiple dimensions: + +| Dimension | Sub-dimensions | Description | +|-----------|---------------|-------------| +| **Trend** | Direction, Strength | Determine market movement direction and momentum | +| **Volatility** | Amplitude, Frequency | Measure price fluctuation characteristics | +| **Structure** | Pattern, Phase | Identify market structure and cycle position | + +--- + +## 2. Primary Classification (5 Categories) + +### 2.1 Classification Overview + +| Code | Name | Key Characteristics | Suitable Strategies | +|------|------|---------------------|---------------------| +| `TREND_UP` | Uptrend | Higher highs & higher lows | Trend following, Breakout | +| `TREND_DOWN` | Downtrend | Lower highs & lower lows | Trend following, Short selling | +| `RANGE` | Range-bound | Price oscillates within bounds | Grid trading, Mean reversion | +| `TRANSITION` | Transition | Uncertain directional period | Wait & watch, Small positions | +| `BREAKOUT` | Breakout | Price breaks key levels | Breakout trading | + +### 2.2 Identification Indicators + +- **ADX (Average Directional Index)**: Measures trend strength + - ADX > 25: Clear trend exists + - ADX < 20: Range-bound market +- **EMA Alignment**: Determines trend direction + - EMA20 > EMA50 > EMA200: Bullish alignment + - EMA20 < EMA50 < EMA200: Bearish alignment + +--- + +## 3. Secondary Classification (18 Sub-categories) + +### 3.1 Uptrend Sub-categories (5 Types) + +| Code | Name | Technical Features | Quantitative Indicators | +|------|------|-------------------|------------------------| +| `TU_STRONG_LOW_VOL` | Strong Uptrend · Low Vol | Steady rise, shallow pullbacks | ADX>40, ATR%<2%, Pullback<38.2% | +| `TU_STRONG_HIGH_VOL` | Strong Uptrend · High Vol | Rapid surge, high volatility | ADX>40, ATR%>4%, MACD histogram expanding | +| `TU_WEAK_CHOPPY` | Weak Uptrend · Choppy | Two steps forward, one back | ADX 20-30, RSI oscillating 50-70 | +| `TU_PARABOLIC` | Parabolic Acceleration | Exponential price increase | Price far from MA, RSI>80, Volume surge | +| `TU_EXHAUSTION` | Uptrend Exhaustion | New highs but weakening momentum | Price new high + MACD/RSI divergence | + +**Strategy Matching:** +- Strong Low Vol: Heavy trend following, pyramid adding +- Strong High Vol: Medium position, trailing stops +- Weak Choppy: Light swing trading +- Parabolic: Cautious, prepare to exit +- Exhaustion: Reduce positions, prepare for reversal + +### 3.2 Downtrend Sub-categories (5 Types) + +| Code | Name | Technical Features | Quantitative Indicators | +|------|------|-------------------|------------------------| +| `TD_STRONG_LOW_VOL` | Strong Downtrend · Low Vol | Steady decline, weak bounces | ADX>40, ATR%<2%, Bounce<38.2% | +| `TD_STRONG_HIGH_VOL` | Strong Downtrend · High Vol | Panic selling, wild swings | ADX>40, ATR%>5%, VIX spike | +| `TD_WEAK_CHOPPY` | Weak Downtrend · Choppy | Grinding lower with bounces | ADX 20-30, RSI oscillating 30-50 | +| `TD_CAPITULATION` | Capitulation | High volume crash, extreme fear | RSI<20, Volume>3x average | +| `TD_EXHAUSTION` | Downtrend Exhaustion | New lows but selling pressure fading | Price new low + MACD/RSI divergence | + +**Strategy Matching:** +- Strong Low Vol: Short trend following +- Strong High Vol: Stay flat or light hedge +- Weak Choppy: Wait for stabilization +- Capitulation: Light bottom fishing possible +- Exhaustion: Gradually build long positions + +### 3.3 Range Sub-categories (4 Types) + +| Code | Name | Technical Features | Quantitative Indicators | +|------|------|-------------------|------------------------| +| `RG_TIGHT_LOW_VOL` | Tight Range · Low Vol | Extreme contraction, coiling | BB Width<2%, ATR at new lows | +| `RG_TIGHT_HIGH_VOL` | Tight Range · High Vol | Violent swings within range | BB Width<3%, ATR%>3% | +| `RG_WIDE_LOW_VOL` | Wide Range · Low Vol | Large range, slow movement | BB Width>5%, ATR%<2% | +| `RG_WIDE_HIGH_VOL` | Wide Range · High Vol | Large range, fast movement | BB Width>5%, ATR%>3% | + +**Strategy Matching:** +- Tight Low Vol: Dense grid, wait for breakout +- Tight High Vol: Fast grid, small frequent profits +- Wide Low Vol: Sparse grid, patient holding +- Wide High Vol: Swing trading, high profit targets + +### 3.4 Transition (2 Types) + +| Code | Name | Technical Features | Quantitative Indicators | +|------|------|-------------------|------------------------| +| `TR_BOTTOM_FORMING` | Bottom Forming | Decline slowing, testing support | Price stabilizing + Volume drying up + RSI divergence | +| `TR_TOP_FORMING` | Top Forming | Rally slowing, testing resistance | Price stalling + Volume drying up + RSI divergence | + +### 3.5 Breakout (2 Types) + +| Code | Name | Technical Features | Quantitative Indicators | +|------|------|-------------------|------------------------| +| `BK_UPWARD` | Upward Breakout | Breaking resistance with volume | Price>Previous high, Volume>2x, BB breakout | +| `BK_DOWNWARD` | Downward Breakout | Breaking support with volume | Price2x, BB breakdown | + +--- + +## 4. Tertiary Classification (36 Ultra-fine Categories) + +### 4.1 Trend Phase Classification + +Uptrend lifecycle consists of 5 phases: + +| Phase Code | Name | Description | Quantitative Criteria | +|------------|------|-------------|----------------------| +| `TU_S1_INITIATION` | Uptrend Initiation | First break above MA or previous high | MACD bullish cross, Price>EMA20 | +| `TU_S2_ACCELERATION` | Uptrend Acceleration | Momentum increasing, slope steepening | MACD histogram expanding, ADX rising | +| `TU_S3_MAIN_WAVE` | Main Wave | Sustained rise, shallow pullbacks | RSI 60-80, Pullbacks hold EMA20 | +| `TU_S4_EXHAUSTION` | Uptrend Exhaustion | Slowing momentum, divergences appearing | RSI divergence, MACD divergence | +| `TU_S5_REVERSAL` | Trend Reversal | Breakdown, trend ending | Break below EMA50, MACD bearish cross | + +Downtrend phases follow same pattern: `TD_S1` through `TD_S5` + +### 4.2 Range Position Classification + +| Position Code | Name | Description | Strategy Suggestion | +|---------------|------|-------------|---------------------| +| `RG_UPPER` | Upper Range | Price near resistance | Bias toward short | +| `RG_MIDDLE` | Mid Range | Price near middle band | Neutral grid trading | +| `RG_LOWER` | Lower Range | Price near support | Bias toward long | +| `RG_SQUEEZE` | Squeeze Pattern | Highs and lows converging | Wait for direction | +| `RG_EXPAND` | Expanding Pattern | Highs and lows diverging | Boundary reversal | + +### 4.3 Volatility Grades + +| Code | Name | ATR% | BB Width | Strategy Suggestion | +|------|------|------|----------|---------------------| +| `VOL_EXTREME_LOW` | Extreme Low Vol | <1% | <1.5% | Option selling | +| `VOL_LOW` | Low Volatility | 1-2% | 1.5-2.5% | Grid / Mean reversion | +| `VOL_NORMAL` | Normal Volatility | 2-3% | 2.5-4% | Trend following | +| `VOL_HIGH` | High Volatility | 3-5% | 4-6% | Momentum / Breakout | +| `VOL_EXTREME_HIGH` | Extreme High Vol | >5% | >6% | Reduce exposure / Hedge | + +--- + +## 5. Complete State Encoding Rules + +### 5.1 Encoding Format + +``` +{Primary}_{Volatility}_{Phase}_{Position} +``` + +### 5.2 Encoding Examples + +| Full Code | Interpretation | +|-----------|----------------| +| `TU_LV_S3_M` | Uptrend_LowVol_MainWave_Middle | +| `TD_HV_S2_L` | Downtrend_HighVol_Acceleration_Lower | +| `RG_NV_SQ_U` | Range_NormalVol_Squeeze_Upper | +| `BK_HV_UP_M` | Breakout_HighVol_Upward_Middle | + +--- + +## 6. Core Identification Indicators + +### 6.1 Trend Indicators + +| Indicator | Calculation | Criteria | +|-----------|-------------|----------| +| ADX | 14-period Average Directional Index | >40 Strong, 25-40 Medium, <25 Weak/Range | +| Trend Score | Composite EMA/MACD/Price structure | -100 to +100, Positive=Bullish, Negative=Bearish | +| EMA Alignment | Relative position of EMA20/50/200 | Bullish/Bearish/Mixed alignment | + +### 6.2 Volatility Indicators + +| Indicator | Calculation | Purpose | +|-----------|-------------|---------| +| ATR Percent | ATR(14) / Current Price × 100% | Measure relative volatility | +| BB Width | (Upper - Lower) / Middle × 100% | Measure price range | +| Volatility Rank | Current vol percentile in history | Determine vol level | + +### 6.3 Momentum Indicators + +| Indicator | Calculation | Criteria | +|-----------|-------------|----------| +| RSI | 14-period Relative Strength Index | >70 Overbought, <30 Oversold, 50 Neutral | +| MACD Histogram | MACD - Signal | Positive=Bullish momentum, Negative=Bearish | +| Momentum Score | Composite RSI/MACD/Volume | Measure current momentum | + +### 6.4 Structure Indicators + +| Indicator | Description | Purpose | +|-----------|-------------|---------| +| Swing Structure | HH/HL/LH/LL sequence | Determine trend structure | +| Support/Resistance | Key price levels | Define trading range | +| Volume Profile | Volume-price relationship | Validate price action | + +--- + +## 7. Strategy Matching Matrix + +### 7.1 Regime-Strategy Mapping + +| Regime Type | Recommended Strategy | Position Size | Stop Loss | +|-------------|---------------------|---------------|-----------| +| Strong Uptrend · Low Vol | Trend following + Pyramid | 60-80% | ATR×2 | +| Strong Uptrend · High Vol | Momentum + Quick profit | 40-60% | ATR×1.5 | +| Uptrend Exhaustion | Reduce + Reversal short | 20-30% | Previous high | +| Panic Decline | Wait or light bottom fish | 10-20% | Wide stop | +| Low Vol Range | Grid trading | 50-70% | Range boundary | +| High Vol Range | Swing trading | 30-50% | ATR×2 | +| Squeeze Pattern | Wait for breakout | 10-20% | - | +| Upward Breakout | Chase + Add on pullback | 50-70% | Breakout level | +| Bottom Formation | Scale in gradually | 20-40% | New low | + +### 7.2 Grid Strategy Parameter Matching + +| Range Type | Grid Levels | Grid Spacing | Other Parameters | +|------------|-------------|--------------|------------------| +| Tight Low Vol | 30-50 levels | Small spacing | Enable Maker Only | +| Tight High Vol | 15-25 levels | Small spacing | Fast execution mode | +| Wide Low Vol | 10-20 levels | Large spacing | Patient execution | +| Wide High Vol | 15-25 levels | Large spacing | High profit targets | +| Squeeze Pattern | Pause grid | - | Wait for breakout signal | +| Upper Range | Short bias | Medium | Increase sell weight | +| Lower Range | Long bias | Medium | Increase buy weight | + +--- + +## 8. Real-time Monitoring Guidelines + +### 8.1 State Transition Triggers + +| Current State | Trigger Condition | Transitions To | +|---------------|-------------------|----------------| +| Range | Price breakout + Volume + ADX rising | Breakout | +| Uptrend | RSI divergence + Volume decline | Exhaustion | +| Downtrend | RSI divergence + Volume decline | Exhaustion | +| Breakout | Failed breakout, price returns | Range | +| Exhaustion | Confirmed reversal breakout | Opposite trend | + +### 8.2 Risk Control Rules + +| Regime State | Max Position | Risk Per Trade | Special Rules | +|--------------|--------------|----------------|---------------| +| Strong Trend | 80% | 2% | Adding allowed | +| Weak Trend | 50% | 1.5% | No adding | +| Range | 60% | 1% | Diversified holding | +| Transition | 30% | 1% | Reduce activity | +| High Volatility | 40% | 0.5% | Wide stops | + +--- + +## 9. Appendix + +### 9.1 Abbreviation Reference + +| Abbrev | Full Form | Description | +|--------|-----------|-------------| +| TU | Trend Up | Upward trend | +| TD | Trend Down | Downward trend | +| RG | Range | Range-bound market | +| TR | Transition | Trend transition | +| BK | Breakout | Breakout pattern | +| LV | Low Volatility | Low volatility regime | +| HV | High Volatility | High volatility regime | +| NV | Normal Volatility | Normal volatility regime | +| XLV | Extreme Low Vol | Extremely low volatility | +| XHV | Extreme High Vol | Extremely high volatility | + +### 9.2 Document Information + +- Version: v1.0 +- Created: January 2026 +- Applicable: Cryptocurrency, Forex, Stocks, and other financial markets + +--- + +*This document is designed for market state identification and strategy matching in quantitative trading systems* diff --git a/docs/market-regime-classification-zh.md b/docs/market-regime-classification-zh.md new file mode 100644 index 00000000..36e0092d --- /dev/null +++ b/docs/market-regime-classification-zh.md @@ -0,0 +1,281 @@ +# 市场行情精细分类体系 + +> 用于量化交易策略匹配的市场状态识别框架 + +--- + +## 一、分类维度概览 + +市场状态识别需要从多个维度进行分析: + +| 维度 | 子维度 | 说明 | +|------|--------|------| +| **趋势维度** | 方向、强度 | 判断市场运动方向和力度 | +| **波动维度** | 幅度、频率 | 衡量价格波动特征 | +| **结构维度** | 形态、阶段 | 识别市场结构和所处周期 | + +--- + +## 二、一级分类(5大类) + +### 2.1 分类总览 + +| 代码 | 名称 | 核心特征 | 适合策略 | +|------|------|----------|----------| +| `TREND_UP` | 上涨趋势 | 高点/低点持续抬升 | 趋势跟踪、突破追涨 | +| `TREND_DOWN` | 下跌趋势 | 高点/低点持续降低 | 趋势跟踪、做空策略 | +| `RANGE` | 震荡区间 | 价格在区间内波动 | 网格交易、均值回归 | +| `TRANSITION` | 趋势转换 | 方向不明确的过渡期 | 观望、小仓位试探 | +| `BREAKOUT` | 突破行情 | 价格突破关键位置 | 突破追踪策略 | + +### 2.2 识别指标 + +- **ADX(平均方向指数)**:衡量趋势强度 + - ADX > 25:存在明确趋势 + - ADX < 20:震荡市场 +- **EMA排列**:判断趋势方向 + - EMA20 > EMA50 > EMA200:多头排列 + - EMA20 < EMA50 < EMA200:空头排列 + +--- + +## 三、二级分类(18细分类) + +### 3.1 上涨趋势细分(5种) + +| 代码 | 名称 | 技术特征 | 量化指标 | +|------|------|----------|----------| +| `TU_STRONG_LOW_VOL` | 强势上涨·低波动 | 稳步上涨,回调幅度小 | ADX>40, ATR%<2%, 回调<38.2% | +| `TU_STRONG_HIGH_VOL` | 强势上涨·高波动 | 快速拉升,波动剧烈 | ADX>40, ATR%>4%, MACD柱放大 | +| `TU_WEAK_CHOPPY` | 弱势上涨·震荡 | 涨三退二,反复磨蹭 | ADX 20-30, RSI在50-70震荡 | +| `TU_PARABOLIC` | 抛物线加速 | 指数级加速上涨 | 价格远离均线, RSI>80, 成交量放大 | +| `TU_EXHAUSTION` | 上涨衰竭 | 创新高但动能减弱 | 价格新高 + MACD/RSI顶背离 | + +**策略匹配:** +- 强势低波动:重仓趋势跟踪,金字塔加仓 +- 强势高波动:中等仓位,设置移动止盈 +- 弱势震荡:轻仓波段,高抛低吸 +- 抛物线加速:谨慎追涨,准备离场 +- 上涨衰竭:减仓观望,准备反转做空 + +### 3.2 下跌趋势细分(5种) + +| 代码 | 名称 | 技术特征 | 量化指标 | +|------|------|----------|----------| +| `TD_STRONG_LOW_VOL` | 强势下跌·低波动 | 稳步下跌,反弹无力 | ADX>40, ATR%<2%, 反弹<38.2% | +| `TD_STRONG_HIGH_VOL` | 强势下跌·高波动 | 恐慌抛售,波动剧烈 | ADX>40, ATR%>5%, 恐慌指数飙升 | +| `TD_WEAK_CHOPPY` | 弱势下跌·震荡 | 跌跌涨涨,磨底过程 | ADX 20-30, RSI在30-50震荡 | +| `TD_CAPITULATION` | 恐慌投降 | 放量暴跌,情绪极端 | RSI<20, 成交量>3倍均量 | +| `TD_EXHAUSTION` | 下跌衰竭 | 创新低但卖压减弱 | 价格新低 + MACD/RSI底背离 | + +**策略匹配:** +- 强势低波动:空头趋势跟踪 +- 强势高波动:观望或轻仓对冲 +- 弱势震荡:等待企稳信号 +- 恐慌投降:极端情况可轻仓抄底 +- 下跌衰竭:逐步建立多头仓位 + +### 3.3 震荡区间细分(4种) + +| 代码 | 名称 | 技术特征 | 量化指标 | +|------|------|----------|----------| +| `RG_TIGHT_LOW_VOL` | 窄幅震荡·低波动 | 极度收敛,蓄势待发 | 布林带宽度<2%, ATR创新低 | +| `RG_TIGHT_HIGH_VOL` | 窄幅震荡·高波动 | 区间内剧烈波动 | 布林带宽度<3%, ATR%>3% | +| `RG_WIDE_LOW_VOL` | 宽幅震荡·低波动 | 大区间慢速波动 | 布林带宽度>5%, ATR%<2% | +| `RG_WIDE_HIGH_VOL` | 宽幅震荡·高波动 | 大区间快速波动 | 布林带宽度>5%, ATR%>3% | + +**策略匹配:** +- 窄幅低波动:密集网格,等待突破 +- 窄幅高波动:快速网格,小利润多次 +- 宽幅低波动:稀疏网格,耐心持有 +- 宽幅高波动:波段交易,高利润目标 + +### 3.4 转换过渡(2种) + +| 代码 | 名称 | 技术特征 | 量化指标 | +|------|------|----------|----------| +| `TR_BOTTOM_FORMING` | 底部形成中 | 下跌放缓,试探支撑 | 价格止跌 + 成交量萎缩 + RSI底背离 | +| `TR_TOP_FORMING` | 顶部形成中 | 上涨放缓,试探压力 | 价格滞涨 + 成交量萎缩 + RSI顶背离 | + +### 3.5 突破行情(2种) + +| 代码 | 名称 | 技术特征 | 量化指标 | +|------|------|----------|----------| +| `BK_UPWARD` | 向上突破 | 突破阻力位并放量 | 价格>前高, 成交量>2倍, 布林带突破 | +| `BK_DOWNWARD` | 向下突破 | 跌破支撑位并放量 | 价格<前低, 成交量>2倍, 布林带跌破 | + +--- + +## 四、三级分类(36超细分类) + +### 4.1 趋势阶段细分 + +上涨趋势生命周期分为5个阶段: + +| 阶段代码 | 名称 | 特征描述 | 量化判断标准 | +|----------|------|----------|--------------| +| `TU_S1_INITIATION` | 上涨启动期 | 首次突破均线或前高 | MACD金叉, 价格突破EMA20 | +| `TU_S2_ACCELERATION` | 上涨加速期 | 动能增强,斜率加大 | MACD柱持续增大, ADX上升 | +| `TU_S3_MAIN_WAVE` | 主升浪阶段 | 持续上涨,回调幅度浅 | RSI维持60-80, 回调不破EMA20 | +| `TU_S4_EXHAUSTION` | 上涨衰竭期 | 涨速放缓,出现背离 | RSI顶背离, MACD顶背离 | +| `TU_S5_REVERSAL` | 趋势反转期 | 破位下跌,趋势结束 | 跌破EMA50, MACD死叉 | + +下跌趋势同理,代码为 `TD_S1` 至 `TD_S5` + +### 4.2 震荡位置细分 + +| 位置代码 | 名称 | 特征描述 | 策略建议 | +|----------|------|----------|----------| +| `RG_UPPER` | 区间上沿震荡 | 价格接近阻力位 | 偏空操作为主 | +| `RG_MIDDLE` | 区间中部震荡 | 价格在中轨附近 | 双向网格交易 | +| `RG_LOWER` | 区间下沿震荡 | 价格接近支撑位 | 偏多操作为主 | +| `RG_SQUEEZE` | 收敛三角震荡 | 高低点逐渐收窄 | 等待方向选择 | +| `RG_EXPAND` | 扩散三角震荡 | 高低点逐渐扩张 | 边界反转操作 | + +### 4.3 波动率等级 + +| 代码 | 名称 | ATR百分比 | 布林带宽度 | 策略建议 | +|------|------|-----------|------------|----------| +| `VOL_EXTREME_LOW` | 极低波动 | <1% | <1.5% | 期权卖方策略 | +| `VOL_LOW` | 低波动 | 1-2% | 1.5-2.5% | 网格/均值回归 | +| `VOL_NORMAL` | 正常波动 | 2-3% | 2.5-4% | 趋势跟踪 | +| `VOL_HIGH` | 高波动 | 3-5% | 4-6% | 动量/突破 | +| `VOL_EXTREME_HIGH` | 极高波动 | >5% | >6% | 减仓/对冲 | + +--- + +## 五、完整状态编码规则 + +### 5.1 编码格式 + +``` +{一级分类}_{波动等级}_{阶段}_{位置} +``` + +### 5.2 编码示例 + +| 完整代码 | 含义解释 | +|----------|----------| +| `TU_LV_S3_M` | 上涨趋势_低波动_主升浪_中部位置 | +| `TD_HV_S2_L` | 下跌趋势_高波动_加速期_下部位置 | +| `RG_NV_SQ_U` | 震荡区间_正常波动_收敛形态_上沿位置 | +| `BK_HV_UP_M` | 突破行情_高波动_向上突破_中部位置 | + +--- + +## 六、核心识别指标 + +### 6.1 趋势指标 + +| 指标 | 计算方法 | 判断标准 | +|------|----------|----------| +| ADX | 14周期平均方向指数 | >40强趋势, 25-40中等, <25弱/震荡 | +| 趋势评分 | 综合EMA/MACD/价格结构 | -100到+100, 正数多头,负数空头 | +| EMA排列 | EMA20/50/200相对位置 | 多头排列/空头排列/混乱 | + +### 6.2 波动指标 + +| 指标 | 计算方法 | 用途 | +|------|----------|------| +| ATR百分比 | ATR(14) / 当前价格 × 100% | 衡量相对波动幅度 | +| 布林带宽度 | (上轨-下轨) / 中轨 × 100% | 衡量价格波动区间 | +| 波动率排名 | 当前波动在历史中的分位 | 判断波动率高低 | + +### 6.3 动量指标 + +| 指标 | 计算方法 | 判断标准 | +|------|----------|----------| +| RSI | 14周期相对强弱指数 | >70超买, <30超卖, 50中性 | +| MACD柱 | MACD - Signal | 正数多头动能,负数空头动能 | +| 动量评分 | 综合RSI/MACD/成交量 | 衡量当前动能强弱 | + +### 6.4 结构指标 + +| 指标 | 说明 | 用途 | +|------|------|------| +| 高低点结构 | HH/HL/LH/LL序列 | 判断趋势结构 | +| 支撑阻力位 | 关键价格水平 | 确定交易区间 | +| 成交量形态 | 量价配合关系 | 验证价格走势 | + +--- + +## 七、策略匹配矩阵 + +### 7.1 行情类型与策略对应 + +| 行情类型 | 推荐策略 | 建议仓位 | 止损设置 | +|----------|----------|----------|----------| +| 强势上涨·低波动 | 趋势跟踪+金字塔加仓 | 60-80% | ATR×2 | +| 强势上涨·高波动 | 动量突破+快速止盈 | 40-60% | ATR×1.5 | +| 上涨衰竭期 | 减仓+反转信号做空 | 20-30% | 前高 | +| 恐慌下跌 | 观望或轻仓抄底 | 10-20% | 宽止损 | +| 低波动震荡 | 网格交易 | 50-70% | 区间边界 | +| 高波动震荡 | 波段高抛低吸 | 30-50% | ATR×2 | +| 收敛等待 | 蓄势等突破 | 10-20% | - | +| 向上突破 | 追涨+回踩加仓 | 50-70% | 突破位 | +| 底部形成 | 分批建仓 | 20-40% | 新低 | + +### 7.2 网格策略参数匹配 + +| 震荡类型 | 网格层数 | 网格间距 | 其他参数 | +|----------|----------|----------|----------| +| 窄幅低波动 | 30-50层 | 小间距 | 启用Maker Only | +| 窄幅高波动 | 15-25层 | 小间距 | 快速成交模式 | +| 宽幅低波动 | 10-20层 | 大间距 | 耐心等待成交 | +| 宽幅高波动 | 15-25层 | 大间距 | 高利润目标 | +| 收敛形态 | 暂停网格 | - | 等待突破信号 | +| 区间上沿 | 偏空配置 | 中等 | 卖单权重增加 | +| 区间下沿 | 偏多配置 | 中等 | 买单权重增加 | + +--- + +## 八、实时监控建议 + +### 8.1 状态转换触发条件 + +| 当前状态 | 触发条件 | 转换到 | +|----------|----------|--------| +| 震荡区间 | 价格突破+放量+ADX上升 | 突破行情 | +| 上涨趋势 | RSI顶背离+成交量萎缩 | 上涨衰竭 | +| 下跌趋势 | RSI底背离+成交量萎缩 | 下跌衰竭 | +| 突破行情 | 突破失败回落 | 震荡区间 | +| 趋势衰竭 | 反向突破确认 | 反向趋势 | + +### 8.2 风险控制规则 + +| 行情状态 | 最大仓位 | 单笔风险 | 特殊规则 | +|----------|----------|----------|----------| +| 强趋势 | 80% | 2% | 可加仓 | +| 弱趋势 | 50% | 1.5% | 不加仓 | +| 震荡 | 60% | 1% | 分散持仓 | +| 转换期 | 30% | 1% | 减少操作 | +| 高波动 | 40% | 0.5% | 宽止损 | + +--- + +## 九、附录 + +### 9.1 缩写对照表 + +| 缩写 | 英文全称 | 中文含义 | +|------|----------|----------| +| TU | Trend Up | 上涨趋势 | +| TD | Trend Down | 下跌趋势 | +| RG | Range | 震荡区间 | +| TR | Transition | 趋势转换 | +| BK | Breakout | 突破行情 | +| LV | Low Volatility | 低波动 | +| HV | High Volatility | 高波动 | +| NV | Normal Volatility | 正常波动 | +| XLV | Extreme Low Vol | 极低波动 | +| XHV | Extreme High Vol | 极高波动 | + +### 9.2 版本信息 + +- 文档版本:v1.0 +- 创建日期:2026年1月 +- 适用范围:加密货币、外汇、股票等金融市场 + +--- + +*本文档用于量化交易系统的市场状态识别和策略匹配* diff --git a/docs/plans/2026-01-14-grid-trading-fixes.md b/docs/plans/2026-01-14-grid-trading-fixes.md new file mode 100644 index 00000000..48d37435 --- /dev/null +++ b/docs/plans/2026-01-14-grid-trading-fixes.md @@ -0,0 +1,1072 @@ +# AI自适应网格交易系统修复计划 + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** 修复AI网格交易系统的所有致命和严重问题,添加代码级风控保护机制。 + +**Architecture:** +1. 在AI决策和订单执行之间添加风控验证层 +2. 实现代码级止损、仓位限制、突破检测 +3. 修复杠杆设置和订单取消的BUG +4. 添加自动网格调整机制 + +**Tech Stack:** Go, GORM, 交易所API接口 + +--- + +## 问题优先级 + +| 优先级 | 问题 | Task | +|--------|------|------| +| P0 致命 | 杠杆未生效 | Task 1 | +| P0 致命 | 取消订单逻辑错误 | Task 2 | +| P0 致命 | 无总仓位限制 | Task 3 | +| P1 严重 | 无止损执行 | Task 4 | +| P1 严重 | 无突破检测 | Task 5 | +| P1 严重 | MaxDrawdown未执行 | Task 6 | +| P1 严重 | DailyLossLimit未执行 | Task 7 | +| P2 中等 | 无动态调整 | Task 8 | +| P2 中等 | 订单状态同步错误 | Task 9 | + +--- + +## Task 1: 修复杠杆设置BUG + +**问题:** `PlaceLimitOrder` 完全忽略 `Leverage` 字段,从未调用 `SetLeverage()` + +**Files:** +- Modify: `trader/interface.go:171-194` +- Modify: `trader/auto_trader_grid.go:324-409` +- Create: `trader/grid_test.go` (新增测试) + +### Step 1.1: 在 GridTraderAdapter.PlaceLimitOrder 中添加杠杆设置 + +修改 `trader/interface.go`: + +```go +// PlaceLimitOrder implements limit order using available methods +// For exchanges without native limit order support, this uses conditional orders +func (a *GridTraderAdapter) PlaceLimitOrder(req *LimitOrderRequest) (*LimitOrderResult, error) { + // CRITICAL FIX: Set leverage before placing order + if req.Leverage > 0 { + if err := a.Trader.SetLeverage(req.Symbol, req.Leverage); err != nil { + logger.Warnf("[Grid] Failed to set leverage %dx: %v", req.Leverage, err) + // Continue anyway - some exchanges don't require explicit leverage setting + } + } + + // Use SetStopLoss/SetTakeProfit as conditional limit orders + // For buy orders below current price, use stop-loss mechanism + // For sell orders above current price, use take-profit mechanism + var err error + if req.Side == "BUY" { + err = a.Trader.SetStopLoss(req.Symbol, "SHORT", req.Quantity, req.Price) + } else { + err = a.Trader.SetTakeProfit(req.Symbol, "LONG", req.Quantity, req.Price) + } + if err != nil { + return nil, err + } + return &LimitOrderResult{ + OrderID: req.ClientID, + ClientID: req.ClientID, + Symbol: req.Symbol, + Side: req.Side, + PositionSide: req.PositionSide, + Price: req.Price, + Quantity: req.Quantity, + Status: "NEW", + }, nil +} +``` + +### Step 1.2: 在 InitializeGrid 中设置杠杆 + +修改 `trader/auto_trader_grid.go`, 在 `InitializeGrid()` 函数末尾添加: + +```go +// InitializeGrid initializes the grid state and calculates levels +func (at *AutoTrader) InitializeGrid() error { + // ... 现有代码 ... + + at.gridState.IsInitialized = true + + // CRITICAL: Set leverage on exchange before trading + if err := at.trader.SetLeverage(gridConfig.Symbol, gridConfig.Leverage); err != nil { + logger.Warnf("[Grid] Failed to set leverage %dx on exchange: %v", gridConfig.Leverage, err) + // Not fatal - continue with default leverage + } else { + logger.Infof("[Grid] Leverage set to %dx for %s", gridConfig.Leverage, gridConfig.Symbol) + } + + logger.Infof("[Grid] Initialized: %d levels, $%.2f - $%.2f, spacing $%.2f", + gridConfig.GridCount, at.gridState.LowerPrice, at.gridState.UpperPrice, at.gridState.GridSpacing) + + return nil +} +``` + +### Step 1.3: 运行测试验证 + +```bash +go build ./trader/ +go test -v -run "TestLighter.*Leverage" ./trader/ -timeout 60s +``` + +### Step 1.4: 提交 + +```bash +git add trader/interface.go trader/auto_trader_grid.go +git commit -m "fix(grid): add leverage setting before order placement + +CRITICAL BUG FIX: +- Call SetLeverage() in GridTraderAdapter.PlaceLimitOrder() +- Set leverage during grid initialization +- Log leverage setting results" +``` + +--- + +## Task 2: 修复订单取消逻辑BUG + +**问题:** `GridTraderAdapter.CancelOrder()` 错误地调用 `CancelAllOrders()` + +**Files:** +- Modify: `trader/interface.go:196-200` + +### Step 2.1: 修复 CancelOrder 实现 + +修改 `trader/interface.go`: + +```go +// CancelOrder cancels a specific order +func (a *GridTraderAdapter) CancelOrder(symbol, orderID string) error { + // Try to use CancelOrder if trader supports it directly + if canceler, ok := a.Trader.(interface { + CancelOrder(symbol, orderID string) error + }); ok { + return canceler.CancelOrder(symbol, orderID) + } + + // For traders that only support CancelAllOrders, log a warning + // This is a limitation - we cannot cancel individual orders + logger.Warnf("[Grid] Trader does not support individual order cancellation, " + + "cannot cancel order %s. Consider using exchange-specific GridTrader implementation.", orderID) + + // Return error instead of canceling all orders + return fmt.Errorf("individual order cancellation not supported for this exchange") +} +``` + +### Step 2.2: 添加 fmt import (如果缺失) + +确保 `trader/interface.go` 顶部有: +```go +import ( + "fmt" + // ... 其他imports +) +``` + +### Step 2.3: 运行测试验证 + +```bash +go build ./trader/ +``` + +### Step 2.4: 提交 + +```bash +git add trader/interface.go +git commit -m "fix(grid): prevent CancelOrder from canceling all orders + +CRITICAL BUG FIX: +- CancelOrder no longer calls CancelAllOrders +- Try exchange-specific CancelOrder if available +- Return error if individual cancellation not supported" +``` + +--- + +## Task 3: 添加总仓位限制 + +**问题:** 只检查单层仓位,不检查总仓位,导致可能开出巨额仓位 + +**Files:** +- Modify: `trader/auto_trader_grid.go:324-409` +- Modify: `trader/auto_trader_grid.go` (新增 `checkTotalPositionLimit` 函数) + +### Step 3.1: 添加总仓位检查函数 + +在 `trader/auto_trader_grid.go` 中 `placeGridLimitOrder` 函数之前添加: + +```go +// checkTotalPositionLimit checks if adding a new position would exceed total limits +// Returns: (allowed bool, currentPositionValue float64, maxAllowed float64) +func (at *AutoTrader) checkTotalPositionLimit(symbol string, additionalValue float64) (bool, float64, float64) { + gridConfig := at.config.StrategyConfig.GridConfig + + // Calculate max allowed total position value + // Total position should not exceed: TotalInvestment × Leverage + maxTotalPositionValue := gridConfig.TotalInvestment * float64(gridConfig.Leverage) + + // Get current position value from exchange + currentPositionValue := 0.0 + positions, err := at.trader.GetPositions() + if err == nil { + for _, pos := range positions { + if sym, ok := pos["symbol"].(string); ok && sym == symbol { + if size, ok := pos["positionAmt"].(float64); ok { + if price, ok := pos["markPrice"].(float64); ok { + currentPositionValue = math.Abs(size) * price + } else if entryPrice, ok := pos["entryPrice"].(float64); ok { + currentPositionValue = math.Abs(size) * entryPrice + } + } + } + } + } + + // Also count pending orders as potential position + at.gridState.mu.RLock() + pendingValue := 0.0 + for _, level := range at.gridState.Levels { + if level.State == "pending" { + pendingValue += level.OrderQuantity * level.Price + } + } + at.gridState.mu.RUnlock() + + totalAfterOrder := currentPositionValue + pendingValue + additionalValue + allowed := totalAfterOrder <= maxTotalPositionValue + + return allowed, currentPositionValue + pendingValue, maxTotalPositionValue +} +``` + +### Step 3.2: 在 placeGridLimitOrder 中使用总仓位检查 + +修改 `trader/auto_trader_grid.go` 的 `placeGridLimitOrder` 函数,在现有检查之后添加: + +```go +func (at *AutoTrader) placeGridLimitOrder(d *kernel.Decision, side string) error { + // ... 现有代码到 line 377 ... + + // CRITICAL: Check total position limit before placing order + orderValue := quantity * d.Price + allowed, currentValue, maxValue := at.checkTotalPositionLimit(d.Symbol, orderValue) + if !allowed { + logger.Errorf("[Grid] TOTAL POSITION LIMIT EXCEEDED: current=$%.2f + order=$%.2f > max=$%.2f. Rejecting order.", + currentValue, orderValue, maxValue) + return fmt.Errorf("total position value $%.2f would exceed limit $%.2f", currentValue+orderValue, maxValue) + } + + req := &LimitOrderRequest{ + // ... 现有代码 ... + } + // ... 其余代码 ... +} +``` + +### Step 3.3: 运行测试验证 + +```bash +go build ./trader/ +``` + +### Step 3.4: 提交 + +```bash +git add trader/auto_trader_grid.go +git commit -m "fix(grid): add total position value limit check + +CRITICAL: Prevent excessive position accumulation +- New checkTotalPositionLimit() function +- Checks current + pending + new order value +- Rejects orders that would exceed TotalInvestment × Leverage +- Logs clear error messages when limit exceeded" +``` + +--- + +## Task 4: 添加止损执行机制 + +**问题:** `StopLossPct` 存在于配置但从未使用 + +**Files:** +- Modify: `trader/auto_trader_grid.go` (添加 `checkAndExecuteStopLoss` 函数) +- Modify: `trader/auto_trader_grid.go:504-565` (在 `syncGridState` 中调用) + +### Step 4.1: 添加止损检查和执行函数 + +在 `trader/auto_trader_grid.go` 中添加: + +```go +// checkAndExecuteStopLoss checks if any filled level has exceeded stop loss and closes it +func (at *AutoTrader) checkAndExecuteStopLoss() { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig.StopLossPct <= 0 { + return // Stop loss not configured + } + + currentPrice, err := at.trader.GetMarketPrice(gridConfig.Symbol) + if err != nil { + logger.Warnf("[Grid] Failed to get market price for stop loss check: %v", err) + return + } + + at.gridState.mu.Lock() + defer at.gridState.mu.Unlock() + + for i := range at.gridState.Levels { + level := &at.gridState.Levels[i] + if level.State != "filled" || level.PositionEntry <= 0 { + continue + } + + // Calculate loss percentage + var lossPct float64 + if level.Side == "buy" { + // Long position: loss when price drops + lossPct = (level.PositionEntry - currentPrice) / level.PositionEntry * 100 + } else { + // Short position: loss when price rises + lossPct = (currentPrice - level.PositionEntry) / level.PositionEntry * 100 + } + + // Check if stop loss triggered + if lossPct >= gridConfig.StopLossPct { + logger.Warnf("[Grid] STOP LOSS TRIGGERED: Level %d, entry=$%.2f, current=$%.2f, loss=%.2f%%", + i, level.PositionEntry, currentPrice, lossPct) + + // Close the position + var closeErr error + if level.Side == "buy" { + _, closeErr = at.trader.CloseLong(gridConfig.Symbol, level.PositionSize) + } else { + _, closeErr = at.trader.CloseShort(gridConfig.Symbol, level.PositionSize) + } + + if closeErr != nil { + logger.Errorf("[Grid] Failed to execute stop loss for level %d: %v", i, closeErr) + } else { + level.State = "stopped" + level.UnrealizedPnL = -lossPct * level.AllocatedUSD / 100 + at.gridState.TotalTrades++ + logger.Infof("[Grid] Stop loss executed: Level %d closed at $%.2f (loss %.2f%%)", + i, currentPrice, lossPct) + } + } + } +} +``` + +### Step 4.2: 在 syncGridState 中调用止损检查 + +修改 `trader/auto_trader_grid.go` 的 `syncGridState` 函数末尾: + +```go +func (at *AutoTrader) syncGridState() { + // ... 现有代码 ... + + logger.Debugf("[Grid] Synced state: position=%.4f, orders=%d", totalPosition, len(openOrders)) + + // CRITICAL: Check stop loss for filled levels + at.checkAndExecuteStopLoss() +} +``` + +### Step 4.3: 运行测试验证 + +```bash +go build ./trader/ +``` + +### Step 4.4: 提交 + +```bash +git add trader/auto_trader_grid.go +git commit -m "feat(grid): implement stop loss execution + +CRITICAL: Add code-level stop loss protection +- New checkAndExecuteStopLoss() function +- Checks each filled level against StopLossPct +- Automatically closes positions exceeding stop loss +- Called during every grid state sync" +``` + +--- + +## Task 5: 添加突破检测机制 + +**问题:** 价格突破网格边界时无响应,继续执行导致单边亏损 + +**Files:** +- Modify: `trader/auto_trader_grid.go` (添加 `checkBreakout` 函数) +- Modify: `trader/auto_trader_grid.go:184-224` (在 `RunGridCycle` 中调用) + +### Step 5.1: 添加突破检测函数 + +在 `trader/auto_trader_grid.go` 中添加: + +```go +// BreakoutType represents the type of price breakout +type BreakoutType string + +const ( + BreakoutNone BreakoutType = "none" + BreakoutUpper BreakoutType = "upper" + BreakoutLower BreakoutType = "lower" +) + +// checkBreakout detects if price has broken out of grid range +// Returns breakout type and percentage beyond boundary +func (at *AutoTrader) checkBreakout() (BreakoutType, float64) { + gridConfig := at.config.StrategyConfig.GridConfig + + currentPrice, err := at.trader.GetMarketPrice(gridConfig.Symbol) + if err != nil { + return BreakoutNone, 0 + } + + at.gridState.mu.RLock() + upper := at.gridState.UpperPrice + lower := at.gridState.LowerPrice + at.gridState.mu.RUnlock() + + if upper <= 0 || lower <= 0 { + return BreakoutNone, 0 + } + + // Check upper breakout + if currentPrice > upper { + breakoutPct := (currentPrice - upper) / upper * 100 + return BreakoutUpper, breakoutPct + } + + // Check lower breakout + if currentPrice < lower { + breakoutPct := (lower - currentPrice) / lower * 100 + return BreakoutLower, breakoutPct + } + + return BreakoutNone, 0 +} + +// handleBreakout handles price breakout from grid range +func (at *AutoTrader) handleBreakout(breakoutType BreakoutType, breakoutPct float64) error { + gridConfig := at.config.StrategyConfig.GridConfig + + logger.Warnf("[Grid] BREAKOUT DETECTED: %s, %.2f%% beyond boundary", breakoutType, breakoutPct) + + // If breakout exceeds 2%, pause grid and cancel orders + if breakoutPct >= 2.0 { + logger.Warnf("[Grid] Significant breakout (%.2f%%), pausing grid and canceling orders", breakoutPct) + + // Cancel all pending orders to prevent further losses + if err := at.cancelAllGridOrders(); err != nil { + logger.Errorf("[Grid] Failed to cancel orders on breakout: %v", err) + } + + // Pause grid trading + at.gridState.mu.Lock() + at.gridState.IsPaused = true + at.gridState.mu.Unlock() + + return fmt.Errorf("grid paused due to %s breakout (%.2f%%)", breakoutType, breakoutPct) + } + + // If breakout is minor (< 2%), consider adjusting grid + if breakoutPct >= 1.0 { + logger.Infof("[Grid] Minor breakout (%.2f%%), considering grid adjustment", breakoutPct) + // Let AI decide whether to adjust + } + + return nil +} +``` + +### Step 5.2: 在 RunGridCycle 中添加突破检测 + +修改 `trader/auto_trader_grid.go` 的 `RunGridCycle` 函数: + +```go +func (at *AutoTrader) RunGridCycle() error { + if at.gridState == nil || !at.gridState.IsInitialized { + if err := at.InitializeGrid(); err != nil { + return fmt.Errorf("failed to initialize grid: %w", err) + } + } + + // CRITICAL: Check for breakout before executing any trades + breakoutType, breakoutPct := at.checkBreakout() + if breakoutType != BreakoutNone { + if err := at.handleBreakout(breakoutType, breakoutPct); err != nil { + return err // Grid paused due to breakout + } + } + + // Check if grid is paused + at.gridState.mu.RLock() + isPaused := at.gridState.IsPaused + at.gridState.mu.RUnlock() + if isPaused { + logger.Infof("[Grid] Grid is paused, skipping cycle") + return nil + } + + gridConfig := at.config.StrategyConfig.GridConfig + // ... 其余现有代码 ... +} +``` + +### Step 5.3: 运行测试验证 + +```bash +go build ./trader/ +``` + +### Step 5.4: 提交 + +```bash +git add trader/auto_trader_grid.go +git commit -m "feat(grid): add breakout detection and auto-pause + +CRITICAL: Detect price breakout from grid range +- New checkBreakout() function +- Auto-pause grid on significant breakout (>2%) +- Cancel all orders when breakout detected +- Prevent continued losses in trending market" +``` + +--- + +## Task 6: 添加 MaxDrawdown 强制执行 + +**问题:** `MaxDrawdownPct` 存在于配置但从未检查 + +**Files:** +- Modify: `trader/auto_trader_grid.go` (添加 `checkMaxDrawdown` 函数) +- Modify: `trader/auto_trader_grid.go:184-224` (在 `RunGridCycle` 中调用) + +### Step 6.1: 添加最大回撤检查函数 + +在 `trader/auto_trader_grid.go` 中添加: + +```go +// checkMaxDrawdown checks if current drawdown exceeds maximum allowed +// Returns: (exceeded bool, currentDrawdown float64) +func (at *AutoTrader) checkMaxDrawdown() (bool, float64) { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig.MaxDrawdownPct <= 0 { + return false, 0 + } + + // Get current equity + balance, err := at.trader.GetBalance() + if err != nil { + return false, 0 + } + + currentEquity := 0.0 + if equity, ok := balance["total_equity"].(float64); ok { + currentEquity = equity + } else if total, ok := balance["totalWalletBalance"].(float64); ok { + if unrealized, ok := balance["totalUnrealizedProfit"].(float64); ok { + currentEquity = total + unrealized + } + } + + if currentEquity <= 0 { + return false, 0 + } + + // Update peak equity + at.gridState.mu.Lock() + if currentEquity > at.gridState.PeakEquity { + at.gridState.PeakEquity = currentEquity + } + peakEquity := at.gridState.PeakEquity + at.gridState.mu.Unlock() + + if peakEquity <= 0 { + return false, 0 + } + + // Calculate current drawdown + drawdown := (peakEquity - currentEquity) / peakEquity * 100 + + // Update max drawdown tracking + at.gridState.mu.Lock() + if drawdown > at.gridState.MaxDrawdown { + at.gridState.MaxDrawdown = drawdown + } + at.gridState.mu.Unlock() + + return drawdown >= gridConfig.MaxDrawdownPct, drawdown +} + +// emergencyExit closes all positions and cancels all orders +func (at *AutoTrader) emergencyExit(reason string) error { + gridConfig := at.config.StrategyConfig.GridConfig + + logger.Errorf("[Grid] EMERGENCY EXIT: %s", reason) + + // Cancel all orders + if err := at.cancelAllGridOrders(); err != nil { + logger.Errorf("[Grid] Failed to cancel orders in emergency: %v", err) + } + + // Close all positions + positions, err := at.trader.GetPositions() + if err == nil { + for _, pos := range positions { + if sym, ok := pos["symbol"].(string); ok && sym == gridConfig.Symbol { + if size, ok := pos["positionAmt"].(float64); ok && size != 0 { + if size > 0 { + at.trader.CloseLong(gridConfig.Symbol, size) + } else { + at.trader.CloseShort(gridConfig.Symbol, -size) + } + } + } + } + } + + // Pause grid + at.gridState.mu.Lock() + at.gridState.IsPaused = true + at.gridState.mu.Unlock() + + return nil +} +``` + +### Step 6.2: 在 RunGridCycle 中添加回撤检查 + +修改 `trader/auto_trader_grid.go` 的 `RunGridCycle` 函数,在突破检测后添加: + +```go +func (at *AutoTrader) RunGridCycle() error { + // ... 初始化检查 ... + + // CRITICAL: Check for breakout + // ... 突破检测代码 ... + + // CRITICAL: Check max drawdown + exceeded, drawdown := at.checkMaxDrawdown() + if exceeded { + return at.emergencyExit(fmt.Sprintf("max drawdown exceeded: %.2f%%", drawdown)) + } + + // ... 其余代码 ... +} +``` + +### Step 6.3: 运行测试验证 + +```bash +go build ./trader/ +``` + +### Step 6.4: 提交 + +```bash +git add trader/auto_trader_grid.go +git commit -m "feat(grid): enforce max drawdown limit with emergency exit + +CRITICAL: Add drawdown protection +- New checkMaxDrawdown() function tracks peak equity +- emergencyExit() closes all positions and cancels orders +- Auto-pause grid when MaxDrawdownPct exceeded +- Protect capital from excessive losses" +``` + +--- + +## Task 7: 添加 DailyLossLimit 强制执行 + +**问题:** `DailyLossLimitPct` 存在于配置但从未检查 + +**Files:** +- Modify: `trader/auto_trader_grid.go` (添加 `checkDailyLossLimit` 函数) +- Modify: `trader/auto_trader_grid.go:184-224` (在 `RunGridCycle` 中调用) + +### Step 7.1: 添加日损失限制检查函数 + +在 `trader/auto_trader_grid.go` 中添加: + +```go +// checkDailyLossLimit checks if daily loss exceeds limit +// Returns: (exceeded bool, dailyLossPct float64) +func (at *AutoTrader) checkDailyLossLimit() (bool, float64) { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig.DailyLossLimitPct <= 0 { + return false, 0 + } + + at.gridState.mu.Lock() + // Reset daily PnL if new day + now := time.Now() + if now.YearDay() != at.gridState.LastDailyReset.YearDay() || + now.Year() != at.gridState.LastDailyReset.Year() { + at.gridState.DailyPnL = 0 + at.gridState.LastDailyReset = now + } + dailyPnL := at.gridState.DailyPnL + at.gridState.mu.Unlock() + + // Calculate daily loss as percentage of total investment + dailyLossPct := 0.0 + if gridConfig.TotalInvestment > 0 && dailyPnL < 0 { + dailyLossPct = (-dailyPnL) / gridConfig.TotalInvestment * 100 + } + + return dailyLossPct >= gridConfig.DailyLossLimitPct, dailyLossPct +} + +// updateDailyPnL updates the daily PnL tracking +func (at *AutoTrader) updateDailyPnL(realizedPnL float64) { + at.gridState.mu.Lock() + at.gridState.DailyPnL += realizedPnL + at.gridState.TotalProfit += realizedPnL + at.gridState.mu.Unlock() +} +``` + +### Step 7.2: 在 RunGridCycle 中添加日损失检查 + +修改 `trader/auto_trader_grid.go` 的 `RunGridCycle` 函数: + +```go +func (at *AutoTrader) RunGridCycle() error { + // ... 初始化和突破检测 ... + + // CRITICAL: Check max drawdown + // ... + + // CRITICAL: Check daily loss limit + exceeded, dailyLossPct := at.checkDailyLossLimit() + if exceeded { + logger.Errorf("[Grid] Daily loss limit exceeded: %.2f%%", dailyLossPct) + at.gridState.mu.Lock() + at.gridState.IsPaused = true + at.gridState.mu.Unlock() + return fmt.Errorf("daily loss limit exceeded: %.2f%%", dailyLossPct) + } + + // ... 其余代码 ... +} +``` + +### Step 7.3: 运行测试验证 + +```bash +go build ./trader/ +``` + +### Step 7.4: 提交 + +```bash +git add trader/auto_trader_grid.go +git commit -m "feat(grid): enforce daily loss limit + +- New checkDailyLossLimit() function +- Track daily PnL with auto-reset at midnight +- Pause grid when DailyLossLimitPct exceeded +- Prevent excessive single-day losses" +``` + +--- + +## Task 8: 添加自动网格调整 + +**问题:** 网格无法自动适应价格偏移 + +**Files:** +- Modify: `trader/auto_trader_grid.go` (添加 `checkGridSkew` 函数) +- Modify: `trader/auto_trader_grid.go:504-565` (在 `syncGridState` 中调用) + +### Step 8.1: 添加网格倾斜检测函数 + +在 `trader/auto_trader_grid.go` 中添加: + +```go +// checkGridSkew checks if grid is heavily skewed (too many fills on one side) +// Returns: (skewed bool, buyFilledCount int, sellFilledCount int) +func (at *AutoTrader) checkGridSkew() (bool, int, int) { + at.gridState.mu.RLock() + defer at.gridState.mu.RUnlock() + + buyFilled := 0 + sellFilled := 0 + buyEmpty := 0 + sellEmpty := 0 + + for _, level := range at.gridState.Levels { + if level.Side == "buy" { + if level.State == "filled" { + buyFilled++ + } else if level.State == "empty" { + buyEmpty++ + } + } else { + if level.State == "filled" { + sellFilled++ + } else if level.State == "empty" { + sellEmpty++ + } + } + } + + // Grid is skewed if one side has 3x more fills than the other + // or if one side is completely empty + skewed := false + if buyFilled > 0 && sellFilled == 0 && sellEmpty > 5 { + skewed = true // All buys filled, no sells + } else if sellFilled > 0 && buyFilled == 0 && buyEmpty > 5 { + skewed = true // All sells filled, no buys + } else if buyFilled >= 3*sellFilled && buyFilled > 5 { + skewed = true + } else if sellFilled >= 3*buyFilled && sellFilled > 5 { + skewed = true + } + + return skewed, buyFilled, sellFilled +} + +// autoAdjustGrid automatically adjusts grid when heavily skewed +func (at *AutoTrader) autoAdjustGrid() { + skewed, buyFilled, sellFilled := at.checkGridSkew() + if !skewed { + return + } + + logger.Warnf("[Grid] Grid heavily skewed: buy_filled=%d, sell_filled=%d. Auto-adjusting...", + buyFilled, sellFilled) + + gridConfig := at.config.StrategyConfig.GridConfig + + // Get current price + currentPrice, err := at.trader.GetMarketPrice(gridConfig.Symbol) + if err != nil { + logger.Errorf("[Grid] Failed to get price for auto-adjust: %v", err) + return + } + + // Check if price is near grid boundary + at.gridState.mu.RLock() + upper := at.gridState.UpperPrice + lower := at.gridState.LowerPrice + at.gridState.mu.RUnlock() + + // Only adjust if price has moved significantly (>50% of grid range) + gridRange := upper - lower + midPrice := (upper + lower) / 2 + priceDeviation := math.Abs(currentPrice - midPrice) + + if priceDeviation < gridRange*0.3 { + return // Price still near center, don't adjust + } + + // Cancel existing orders and reinitialize + logger.Infof("[Grid] Adjusting grid around new price $%.2f", currentPrice) + at.cancelAllGridOrders() + at.initializeGridLevels(currentPrice, gridConfig) +} +``` + +### Step 8.2: 在 syncGridState 中调用自动调整 + +修改 `trader/auto_trader_grid.go` 的 `syncGridState` 函数: + +```go +func (at *AutoTrader) syncGridState() { + // ... 现有代码 ... + + // Check stop loss + at.checkAndExecuteStopLoss() + + // Check grid skew and auto-adjust if needed + at.autoAdjustGrid() +} +``` + +### Step 8.3: 运行测试验证 + +```bash +go build ./trader/ +``` + +### Step 8.4: 提交 + +```bash +git add trader/auto_trader_grid.go +git commit -m "feat(grid): add automatic grid adjustment + +- New checkGridSkew() detects imbalanced grid +- autoAdjustGrid() reinitializes around current price +- Prevents grid from becoming ineffective after drift +- Triggers when one side is 3x more filled than other" +``` + +--- + +## Task 9: 修复订单状态同步逻辑 + +**问题:** 假设订单不存在就是成交,但可能是被取消 + +**Files:** +- Modify: `trader/auto_trader_grid.go:504-565` + +### Step 9.1: 改进订单状态同步逻辑 + +修改 `trader/auto_trader_grid.go` 的 `syncGridState` 函数: + +```go +// syncGridState syncs grid state with exchange +func (at *AutoTrader) syncGridState() { + gridConfig := at.config.StrategyConfig.GridConfig + + // Get open orders from exchange + openOrders, err := at.trader.GetOpenOrders(gridConfig.Symbol) + if err != nil { + logger.Warnf("[Grid] Failed to get open orders: %v", err) + return + } + + // Build set of active order IDs + activeOrderIDs := make(map[string]bool) + for _, order := range openOrders { + activeOrderIDs[order.OrderID] = true + } + + // Get current positions to verify fills + positions, err := at.trader.GetPositions() + currentPositionSize := 0.0 + if err == nil { + for _, pos := range positions { + if sym, ok := pos["symbol"].(string); ok && sym == gridConfig.Symbol { + if size, ok := pos["positionAmt"].(float64); ok { + currentPositionSize = size + } + } + } + } + + // Update levels based on order status + at.gridState.mu.Lock() + previousFilledCount := 0 + for _, level := range at.gridState.Levels { + if level.State == "filled" { + previousFilledCount++ + } + } + + for i := range at.gridState.Levels { + level := &at.gridState.Levels[i] + if level.State == "pending" && level.OrderID != "" { + if !activeOrderIDs[level.OrderID] { + // Order no longer exists - check if position changed to determine fill vs cancel + // This is a heuristic - ideally we'd query order history + if math.Abs(currentPositionSize) > math.Abs(float64(previousFilledCount)*level.OrderQuantity) { + // Position increased, likely filled + level.State = "filled" + level.PositionEntry = level.Price + level.PositionSize = level.OrderQuantity + at.gridState.TotalTrades++ + logger.Infof("[Grid] Level %d order filled at $%.2f", i, level.Price) + } else { + // Position didn't increase as expected, likely cancelled + level.State = "empty" + level.OrderID = "" + level.OrderQuantity = 0 + logger.Infof("[Grid] Level %d order cancelled/expired", i) + } + delete(at.gridState.OrderBook, level.OrderID) + } + } + } + at.gridState.mu.Unlock() + + logger.Debugf("[Grid] Synced state: position=%.4f, orders=%d", currentPositionSize, len(openOrders)) + + // Check stop loss + at.checkAndExecuteStopLoss() + + // Check grid skew + at.autoAdjustGrid() +} +``` + +### Step 9.2: 运行测试验证 + +```bash +go build ./trader/ +``` + +### Step 9.3: 提交 + +```bash +git add trader/auto_trader_grid.go +git commit -m "fix(grid): improve order state sync logic + +- Don't assume missing orders are filled +- Compare position size to determine fill vs cancel +- Properly reset cancelled orders to empty state +- More accurate grid state tracking" +``` + +--- + +## 完成后的验证步骤 + +### 全面测试 + +```bash +# 编译验证 +go build ./... + +# 运行所有trader测试 +go test -v ./trader/... -timeout 300s + +# 运行网格相关测试 +go test -v -run "Grid" ./trader/ -timeout 60s +``` + +### 代码审查清单 + +- [ ] 所有P0致命问题已修复 +- [ ] 所有P1严重问题已修复 +- [ ] 杠杆在初始化时设置 +- [ ] 订单取消逻辑正确 +- [ ] 总仓位有限制 +- [ ] 止损被执行 +- [ ] 突破时自动暂停 +- [ ] MaxDrawdown触发紧急退出 +- [ ] DailyLossLimit暂停交易 +- [ ] 网格自动调整 + +--- + +## 架构改进总结 + +``` +修复后的架构: + +┌─────────────┐ ┌─────────────┐ ┌─────────────────────────┐ ┌─────────────┐ +│ 市场数据 │ ──▶ │ AI决策 │ ──▶ │ 代码级风控验证 │ ──▶ │ 执行交易 │ +└─────────────┘ └─────────────┘ └─────────────────────────┘ └─────────────┘ + │ + ▼ + ┌────────────────────────────────────────────────────┐ + │ 风控检查清单 (每个周期执行) │ + │ ✓ checkBreakout() - 突破检测 │ + │ ✓ checkMaxDrawdown() - 最大回撤 │ + │ ✓ checkDailyLossLimit() - 日损失限制 │ + │ ✓ checkTotalPositionLimit() - 总仓位限制 │ + │ ✓ checkAndExecuteStopLoss() - 止损执行 │ + │ ✓ checkGridSkew() - 网格平衡 │ + │ ✓ SetLeverage() - 杠杆设置 │ + └────────────────────────────────────────────────────┘ +``` diff --git a/docs/plans/2026-01-17-grid-market-regime-design.md b/docs/plans/2026-01-17-grid-market-regime-design.md new file mode 100644 index 00000000..bd9ba7c1 --- /dev/null +++ b/docs/plans/2026-01-17-grid-market-regime-design.md @@ -0,0 +1,151 @@ +# 网格策略市场状态识别与风控设计 + +## 概述 + +增强网格策略的市场状态识别能力,实现震荡/趋势的精准判断,并根据不同震荡级别自动调整网格参数和风控策略。 + +--- + +## 一、市场状态识别 + +### 1.1 识别维度(3个) + +| 维度 | 指标 | 作用 | +|------|------|------| +| 价格波动 | ATR14 + Bollinger带宽 | 判断震荡幅度 | +| 趋势强度 | EMA20/50距离 + MACD | 判断是否有趋势 | +| 动量 | RSI14 + 1h/4h涨跌幅 | 判断超买超卖 | + +### 1.2 箱体指标(新增) + +基于1小时K线的多周期Donchian通道: + +| 箱体级别 | 周期 | 覆盖时间 | 用途 | +|----------|------|----------|------| +| 短期箱体 | 72根1小时 | 3天 | 日内波动边界 | +| 中期箱体 | 240根1小时 | 10天 | 周级别震荡区间 | +| 长期箱体 | 500根1小时 | ~21天 | 大级别趋势边界 | + +### 1.3 判断方式 + +由AI综合分析以上指标 + 原始K线序列 + 箱体位置,输出市场状态判断。 + +--- + +## 二、震荡分级与网格策略 + +### 2.1 四级震荡分类 + +| 级别 | 特征 | 判断依据 | +|------|------|----------| +| 窄幅震荡 | 价格在短期箱体内小幅波动 | Bollinger带宽 < 2%,ATR低 | +| 标准震荡 | 价格在中期箱体内正常波动 | Bollinger带宽 2-3%,ATR正常 | +| 宽幅震荡 | 价格接近中期箱体边缘 | Bollinger带宽 3-4%,ATR较高 | +| 剧烈震荡 | 价格接近长期箱体边缘 | Bollinger带宽 > 4%,ATR高 | + +### 2.2 各级别对应的网格策略 + +| 级别 | 网格密度 | 网格范围 | 单格仓位 | 总仓位上限 | 有效杠杆上限 | +|------|----------|----------|----------|------------|--------------| +| 窄幅震荡 | 密集 | 窄 | 小 | 30-40% | 2x | +| 标准震荡 | 正常 | 中等 | 正常 | 60-70% | 3-4x | +| 宽幅震荡 | 稀疏 | 宽 | 正常 | 50-60% | 3x | +| 剧烈震荡 | 最稀疏 | 最宽 | 小 | 30-40% | 2x | + +**核心原则:** +- 窄幅震荡:单格仓位小 + 总仓位上限低(防击穿风险) +- 剧烈震荡:同样保守(随时可能变趋势) +- 标准震荡:才是放量的最佳时机 + +--- + +## 三、突破处理与恢复机制 + +### 3.1 突破判断与处理 + +**确认方式:** 收盘价突破箱体后,持续3根1小时K线不回箱体 + +| 箱体级别 | 突破处理 | +|----------|----------| +| 短期箱体突破 | 降低仓位到 50% | +| 中期箱体突破 | 暂停网格 + 取消挂单 | +| 长期箱体突破 | 暂停网格 + 取消挂单 + 平掉所有持仓 | + +### 3.2 假突破恢复 + +**价格回到箱体内 → 以50%仓位恢复网格** + +--- + +## 四、前端风控面板 + +### 4.1 需要展示的信息 + +| 类别 | 显示内容 | +|------|----------| +| 杠杆信息 | 当前杠杆、有效杠杆、系统推荐杠杆 | +| 仓位信息 | 当前仓位、最大仓位、仓位占比 | +| 爆仓信息 | 爆仓价格、爆仓距离(%) | +| 市场状态 | 当前震荡级别(窄幅/标准/宽幅/剧烈) | +| 箱体状态 | 短期/中期/长期箱体上下沿、当前价格位置 | + +--- + +## 五、实现要点 + +### 5.1 后端新增 + +1. **箱体指标计算** (`market/data.go`) + - 新增 `calculateDonchian(klines, period)` 函数 + - 返回 upper(最高价), lower(最低价) + - 支持72/240/500三个周期 + +2. **市场状态评估** (`kernel/grid_engine.go`) + - 更新AI prompt,加入箱体指标和K线序列 + - AI输出震荡级别判断 + +3. **网格参数动态调整** (`trader/auto_trader_grid.go`) + - 根据震荡级别自动调整:网格密度、范围、仓位、杠杆 + - 实现有效杠杆上限控制 + +4. **突破处理逻辑** (`trader/auto_trader_grid.go`) + - 实现三级箱体突破检测 + - 实现3根K线确认逻辑 + - 实现降级恢复机制 + +### 5.2 前端新增 + +1. **风控面板组件** + - 杠杆信息展示 + - 仓位信息展示 + - 爆仓信息展示 + - 市场状态展示 + - 箱体状态可视化 + +### 5.3 数据模型更新 + +1. **GridConfigModel** 新增字段: + - `EffectiveLeverageLimit` - 有效杠杆上限 + - `ShortBoxPeriod` - 短期箱体周期 (默认72) + - `MidBoxPeriod` - 中期箱体周期 (默认240) + - `LongBoxPeriod` - 长期箱体周期 (默认500) + +2. **GridInstanceModel** 新增字段: + - `CurrentRegimeLevel` - 当前震荡级别 (narrow/standard/wide/volatile) + - `ShortBoxUpper/Lower` - 短期箱体上下沿 + - `MidBoxUpper/Lower` - 中期箱体上下沿 + - `LongBoxUpper/Lower` - 长期箱体上下沿 + - `BreakoutStatus` - 突破状态 (none/short/mid/long) + - `BreakoutConfirmCount` - 突破确认K线计数 + +--- + +## 六、风险控制总结 + +| 控制点 | 机制 | +|--------|------| +| 仓位控制 | 根据震荡级别限制总仓位上限 (30-70%) | +| 杠杆控制 | 根据震荡级别限制有效杠杆 (2-4x) | +| 突破保护 | 三级箱体突破分级处理 | +| 假突破恢复 | 50%仓位降级恢复 | +| 爆仓预防 | 前端展示爆仓距离,系统自动限制杠杆 | diff --git a/docs/plans/2026-01-17-grid-market-regime-impl.md b/docs/plans/2026-01-17-grid-market-regime-impl.md new file mode 100644 index 00000000..a28d2b44 --- /dev/null +++ b/docs/plans/2026-01-17-grid-market-regime-impl.md @@ -0,0 +1,1655 @@ +# Grid Market Regime Detection Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Implement multi-period box indicators and 4-level ranging classification for grid trading with automatic parameter adjustment and breakout handling. + +**Architecture:** Add Donchian channel calculation to market package, extend grid models with box/regime fields, implement breakout detection in auto_trader_grid, add risk control panel to frontend. + +**Tech Stack:** Go (backend), React/TypeScript (frontend), GORM (database), 1-hour Kline data + +--- + +## Task 1: Add Donchian Channel Calculation + +**Files:** +- Modify: `market/data.go` +- Test: `market/data_test.go` + +**Step 1: Write the failing test** + +Add to `market/data_test.go`: + +```go +func TestCalculateDonchian(t *testing.T) { + // Create test klines with known high/low values + klines := []Kline{ + {High: 100, Low: 90}, + {High: 105, Low: 88}, + {High: 102, Low: 92}, + {High: 108, Low: 85}, + {High: 103, Low: 91}, + } + + upper, lower := calculateDonchian(klines, 5) + + if upper != 108 { + t.Errorf("Expected upper = 108, got %v", upper) + } + if lower != 85 { + t.Errorf("Expected lower = 85, got %v", lower) + } +} + +func TestCalculateDonchian_PartialPeriod(t *testing.T) { + klines := []Kline{ + {High: 100, Low: 90}, + {High: 105, Low: 88}, + } + + upper, lower := calculateDonchian(klines, 10) + + // Should use all available klines when period > len(klines) + if upper != 105 { + t.Errorf("Expected upper = 105, got %v", upper) + } + if lower != 88 { + t.Errorf("Expected lower = 88, got %v", lower) + } +} +``` + +**Step 2: Run test to verify it fails** + +Run: `cd /Users/yida/gopro/open-nofx && go test -v ./market/... -run TestCalculateDonchian` +Expected: FAIL with "undefined: calculateDonchian" + +**Step 3: Write minimal implementation** + +Add to `market/data.go`: + +```go +// calculateDonchian calculates Donchian channel (highest high, lowest low) for given period +func calculateDonchian(klines []Kline, period int) (upper, lower float64) { + if len(klines) == 0 { + return 0, 0 + } + + // Use all available klines if period > len(klines) + start := len(klines) - period + if start < 0 { + start = 0 + } + + upper = klines[start].High + lower = klines[start].Low + + for i := start + 1; i < len(klines); i++ { + if klines[i].High > upper { + upper = klines[i].High + } + if klines[i].Low < lower { + lower = klines[i].Low + } + } + + return upper, lower +} + +// ExportCalculateDonchian exports calculateDonchian for testing +func ExportCalculateDonchian(klines []Kline, period int) (float64, float64) { + return calculateDonchian(klines, period) +} +``` + +**Step 4: Run test to verify it passes** + +Run: `cd /Users/yida/gopro/open-nofx && go test -v ./market/... -run TestCalculateDonchian` +Expected: PASS + +**Step 5: Commit** + +```bash +git add market/data.go market/data_test.go +git commit -m "feat(market): add Donchian channel calculation" +``` + +--- + +## Task 2: Add Box Data Types + +**Files:** +- Modify: `market/types.go` + +**Step 1: Add BoxData struct** + +Add to `market/types.go`: + +```go +// BoxData represents multi-period Donchian channel (box) data +type BoxData struct { + // Short-term box (72 1h candles = 3 days) + ShortUpper float64 `json:"short_upper"` + ShortLower float64 `json:"short_lower"` + + // Mid-term box (240 1h candles = 10 days) + MidUpper float64 `json:"mid_upper"` + MidLower float64 `json:"mid_lower"` + + // Long-term box (500 1h candles = ~21 days) + LongUpper float64 `json:"long_upper"` + LongLower float64 `json:"long_lower"` + + // Current price position relative to boxes + CurrentPrice float64 `json:"current_price"` +} + +// RegimeLevel represents the ranging classification level +type RegimeLevel string + +const ( + RegimeLevelNarrow RegimeLevel = "narrow" // 窄幅震荡 + RegimeLevelStandard RegimeLevel = "standard" // 标准震荡 + RegimeLevelWide RegimeLevel = "wide" // 宽幅震荡 + RegimeLevelVolatile RegimeLevel = "volatile" // 剧烈震荡 + RegimeLevelTrending RegimeLevel = "trending" // 趋势 +) + +// BreakoutLevel represents which box level has been broken +type BreakoutLevel string + +const ( + BreakoutNone BreakoutLevel = "none" + BreakoutShort BreakoutLevel = "short" + BreakoutMid BreakoutLevel = "mid" + BreakoutLong BreakoutLevel = "long" +) +``` + +**Step 2: Commit** + +```bash +git add market/types.go +git commit -m "feat(market): add BoxData and RegimeLevel types" +``` + +--- + +## Task 3: Add GetBoxData Function + +**Files:** +- Modify: `market/data.go` +- Test: `market/data_test.go` + +**Step 1: Write the failing test** + +Add to `market/data_test.go`: + +```go +func TestGetBoxData(t *testing.T) { + // This test requires mocking kline data source + // For now, test the internal calculation logic + klines := make([]Kline, 500) + for i := 0; i < 500; i++ { + // Create synthetic price data + basePrice := 100.0 + klines[i] = Kline{ + High: basePrice + float64(i%10), + Low: basePrice - float64(i%10), + } + } + + box := calculateBoxData(klines, 100.0) + + if box.ShortUpper == 0 || box.ShortLower == 0 { + t.Error("Short box should not be zero") + } + if box.MidUpper == 0 || box.MidLower == 0 { + t.Error("Mid box should not be zero") + } + if box.LongUpper == 0 || box.LongLower == 0 { + t.Error("Long box should not be zero") + } + if box.CurrentPrice != 100.0 { + t.Errorf("Expected CurrentPrice = 100.0, got %v", box.CurrentPrice) + } +} +``` + +**Step 2: Run test to verify it fails** + +Run: `cd /Users/yida/gopro/open-nofx && go test -v ./market/... -run TestGetBoxData` +Expected: FAIL with "undefined: calculateBoxData" + +**Step 3: Write minimal implementation** + +Add to `market/data.go`: + +```go +const ( + ShortBoxPeriod = 72 // 3 days of 1h candles + MidBoxPeriod = 240 // 10 days of 1h candles + LongBoxPeriod = 500 // ~21 days of 1h candles +) + +// calculateBoxData calculates multi-period box data from klines +func calculateBoxData(klines []Kline, currentPrice float64) *BoxData { + box := &BoxData{ + CurrentPrice: currentPrice, + } + + if len(klines) == 0 { + return box + } + + box.ShortUpper, box.ShortLower = calculateDonchian(klines, ShortBoxPeriod) + box.MidUpper, box.MidLower = calculateDonchian(klines, MidBoxPeriod) + box.LongUpper, box.LongLower = calculateDonchian(klines, LongBoxPeriod) + + return box +} + +// GetBoxData fetches 1h klines and calculates box data for a symbol +func GetBoxData(symbol string) (*BoxData, error) { + symbol = Normalize(symbol) + + // Fetch 500 1h klines + var klines []Kline + var err error + + if IsXyzDexAsset(symbol) { + klines, err = getKlinesFromHyperliquid(symbol, "1h", LongBoxPeriod) + } else { + klines, err = getKlinesFromCoinAnk(symbol, "1h", LongBoxPeriod) + } + + if err != nil { + return nil, fmt.Errorf("failed to get 1h klines: %w", err) + } + + if len(klines) == 0 { + return nil, fmt.Errorf("no kline data available") + } + + currentPrice := klines[len(klines)-1].Close + + return calculateBoxData(klines, currentPrice), nil +} +``` + +**Step 4: Run test to verify it passes** + +Run: `cd /Users/yida/gopro/open-nofx && go test -v ./market/... -run TestGetBoxData` +Expected: PASS + +**Step 5: Commit** + +```bash +git add market/data.go market/data_test.go +git commit -m "feat(market): add GetBoxData for multi-period box calculation" +``` + +--- + +## Task 4: Update GridConfigModel with Box Parameters + +**Files:** +- Modify: `store/grid.go` + +**Step 1: Add new fields to GridConfigModel** + +Add fields after `TrendResumeThreshold` in `store/grid.go`: + +```go + // Box indicator periods (1h candles) + ShortBoxPeriod int `json:"short_box_period" gorm:"default:72"` // 3 days + MidBoxPeriod int `json:"mid_box_period" gorm:"default:240"` // 10 days + LongBoxPeriod int `json:"long_box_period" gorm:"default:500"` // 21 days + + // Effective leverage limits by regime level + NarrowRegimeLeverage int `json:"narrow_regime_leverage" gorm:"default:2"` + StandardRegimeLeverage int `json:"standard_regime_leverage" gorm:"default:4"` + WideRegimeLeverage int `json:"wide_regime_leverage" gorm:"default:3"` + VolatileRegimeLeverage int `json:"volatile_regime_leverage" gorm:"default:2"` + + // Position limits by regime level (percentage of total investment) + NarrowRegimePositionPct float64 `json:"narrow_regime_position_pct" gorm:"default:40"` + StandardRegimePositionPct float64 `json:"standard_regime_position_pct" gorm:"default:70"` + WideRegimePositionPct float64 `json:"wide_regime_position_pct" gorm:"default:60"` + VolatileRegimePositionPct float64 `json:"volatile_regime_position_pct" gorm:"default:40"` +``` + +**Step 2: Commit** + +```bash +git add store/grid.go +git commit -m "feat(store): add box period and regime leverage fields to GridConfigModel" +``` + +--- + +## Task 5: Update GridInstanceModel with Box State + +**Files:** +- Modify: `store/grid.go` + +**Step 1: Add new fields to GridInstanceModel** + +Add fields after `ConsecutiveTrending` in `store/grid.go`: + +```go + // Current regime level (narrow/standard/wide/volatile/trending) + CurrentRegimeLevel string `json:"current_regime_level" gorm:"default:standard"` + + // Box state + ShortBoxUpper float64 `json:"short_box_upper"` + ShortBoxLower float64 `json:"short_box_lower"` + MidBoxUpper float64 `json:"mid_box_upper"` + MidBoxLower float64 `json:"mid_box_lower"` + LongBoxUpper float64 `json:"long_box_upper"` + LongBoxLower float64 `json:"long_box_lower"` + + // Breakout state + BreakoutLevel string `json:"breakout_level" gorm:"default:none"` // none/short/mid/long + BreakoutDirection string `json:"breakout_direction"` // up/down + BreakoutConfirmCount int `json:"breakout_confirm_count" gorm:"default:0"` + BreakoutStartTime time.Time `json:"breakout_start_time"` + + // Position adjustment due to breakout + PositionReductionPct float64 `json:"position_reduction_pct" gorm:"default:0"` // 0 = normal, 50 = reduced +``` + +**Step 2: Commit** + +```bash +git add store/grid.go +git commit -m "feat(store): add box state and breakout fields to GridInstanceModel" +``` + +--- + +## Task 6: Add Regime Level Classification + +**Files:** +- Create: `trader/grid_regime.go` +- Test: `trader/grid_regime_test.go` + +**Step 1: Write the failing test** + +Create `trader/grid_regime_test.go`: + +```go +package trader + +import ( + "nofx/market" + "testing" +) + +func TestClassifyRegimeLevel(t *testing.T) { + tests := []struct { + name string + bollingerWidth float64 + atr14Pct float64 + expected market.RegimeLevel + }{ + {"narrow", 1.5, 0.8, market.RegimeLevelNarrow}, + {"standard", 2.5, 1.5, market.RegimeLevelStandard}, + {"wide", 3.5, 2.5, market.RegimeLevelWide}, + {"volatile", 5.0, 4.0, market.RegimeLevelVolatile}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := classifyRegimeLevel(tt.bollingerWidth, tt.atr14Pct) + if result != tt.expected { + t.Errorf("Expected %v, got %v", tt.expected, result) + } + }) + } +} +``` + +**Step 2: Run test to verify it fails** + +Run: `cd /Users/yida/gopro/open-nofx && go test -v ./trader/... -run TestClassifyRegimeLevel` +Expected: FAIL with "undefined: classifyRegimeLevel" + +**Step 3: Write minimal implementation** + +Create `trader/grid_regime.go`: + +```go +package trader + +import "nofx/market" + +// classifyRegimeLevel determines the regime level based on market indicators +// bollingerWidth: Bollinger band width as percentage +// atr14Pct: ATR14 as percentage of current price +func classifyRegimeLevel(bollingerWidth, atr14Pct float64) market.RegimeLevel { + // Narrow: Bollinger < 2%, ATR < 1% + if bollingerWidth < 2.0 && atr14Pct < 1.0 { + return market.RegimeLevelNarrow + } + + // Standard: Bollinger 2-3%, ATR 1-2% + if bollingerWidth <= 3.0 && atr14Pct <= 2.0 { + return market.RegimeLevelStandard + } + + // Wide: Bollinger 3-4%, ATR 2-3% + if bollingerWidth <= 4.0 && atr14Pct <= 3.0 { + return market.RegimeLevelWide + } + + // Volatile: Bollinger > 4%, ATR > 3% + return market.RegimeLevelVolatile +} + +// getRegimeLeverageLimit returns the effective leverage limit for a regime level +func getRegimeLeverageLimit(level market.RegimeLevel, config *store.GridStrategyConfig) int { + switch level { + case market.RegimeLevelNarrow: + return config.NarrowRegimeLeverage + case market.RegimeLevelStandard: + return config.StandardRegimeLeverage + case market.RegimeLevelWide: + return config.WideRegimeLeverage + case market.RegimeLevelVolatile: + return config.VolatileRegimeLeverage + default: + return 2 // Conservative default + } +} + +// getRegimePositionLimit returns the position limit percentage for a regime level +func getRegimePositionLimit(level market.RegimeLevel, config *store.GridStrategyConfig) float64 { + switch level { + case market.RegimeLevelNarrow: + return config.NarrowRegimePositionPct + case market.RegimeLevelStandard: + return config.StandardRegimePositionPct + case market.RegimeLevelWide: + return config.WideRegimePositionPct + case market.RegimeLevelVolatile: + return config.VolatileRegimePositionPct + default: + return 40.0 // Conservative default + } +} +``` + +**Step 4: Run test to verify it passes** + +Run: `cd /Users/yida/gopro/open-nofx && go test -v ./trader/... -run TestClassifyRegimeLevel` +Expected: PASS + +**Step 5: Commit** + +```bash +git add trader/grid_regime.go trader/grid_regime_test.go +git commit -m "feat(trader): add regime level classification" +``` + +--- + +## Task 7: Add Breakout Detection + +**Files:** +- Modify: `trader/grid_regime.go` +- Test: `trader/grid_regime_test.go` + +**Step 1: Write the failing test** + +Add to `trader/grid_regime_test.go`: + +```go +func TestDetectBoxBreakout(t *testing.T) { + box := &market.BoxData{ + ShortUpper: 100, + ShortLower: 90, + MidUpper: 105, + MidLower: 85, + LongUpper: 110, + LongLower: 80, + CurrentPrice: 95, + } + + // No breakout + level, direction := detectBoxBreakout(box) + if level != market.BreakoutNone { + t.Errorf("Expected no breakout, got %v", level) + } + + // Short breakout up + box.CurrentPrice = 101 + level, direction = detectBoxBreakout(box) + if level != market.BreakoutShort || direction != "up" { + t.Errorf("Expected short breakout up, got %v %v", level, direction) + } + + // Mid breakout down + box.CurrentPrice = 84 + level, direction = detectBoxBreakout(box) + if level != market.BreakoutMid || direction != "down" { + t.Errorf("Expected mid breakout down, got %v %v", level, direction) + } + + // Long breakout up + box.CurrentPrice = 112 + level, direction = detectBoxBreakout(box) + if level != market.BreakoutLong || direction != "up" { + t.Errorf("Expected long breakout up, got %v %v", level, direction) + } +} +``` + +**Step 2: Run test to verify it fails** + +Run: `cd /Users/yida/gopro/open-nofx && go test -v ./trader/... -run TestDetectBoxBreakout` +Expected: FAIL with "undefined: detectBoxBreakout" + +**Step 3: Write minimal implementation** + +Add to `trader/grid_regime.go`: + +```go +// detectBoxBreakout checks if price has broken out of any box level +// Returns the highest breakout level and direction +func detectBoxBreakout(box *market.BoxData) (market.BreakoutLevel, string) { + price := box.CurrentPrice + + // Check long box first (highest priority) + if price > box.LongUpper { + return market.BreakoutLong, "up" + } + if price < box.LongLower { + return market.BreakoutLong, "down" + } + + // Check mid box + if price > box.MidUpper { + return market.BreakoutMid, "up" + } + if price < box.MidLower { + return market.BreakoutMid, "down" + } + + // Check short box + if price > box.ShortUpper { + return market.BreakoutShort, "up" + } + if price < box.ShortLower { + return market.BreakoutShort, "down" + } + + return market.BreakoutNone, "" +} +``` + +**Step 4: Run test to verify it passes** + +Run: `cd /Users/yida/gopro/open-nofx && go test -v ./trader/... -run TestDetectBoxBreakout` +Expected: PASS + +**Step 5: Commit** + +```bash +git add trader/grid_regime.go trader/grid_regime_test.go +git commit -m "feat(trader): add box breakout detection" +``` + +--- + +## Task 8: Add Breakout Confirmation Logic + +**Files:** +- Modify: `trader/grid_regime.go` +- Test: `trader/grid_regime_test.go` + +**Step 1: Write the failing test** + +Add to `trader/grid_regime_test.go`: + +```go +func TestBreakoutConfirmation(t *testing.T) { + state := &BreakoutState{ + Level: market.BreakoutShort, + Direction: "up", + ConfirmCount: 0, + } + + // First confirmation + confirmed := confirmBreakout(state, market.BreakoutShort, "up") + if confirmed || state.ConfirmCount != 1 { + t.Errorf("Expected not confirmed, count=1, got confirmed=%v count=%d", confirmed, state.ConfirmCount) + } + + // Second confirmation + confirmed = confirmBreakout(state, market.BreakoutShort, "up") + if confirmed || state.ConfirmCount != 2 { + t.Errorf("Expected not confirmed, count=2, got confirmed=%v count=%d", confirmed, state.ConfirmCount) + } + + // Third confirmation - should confirm + confirmed = confirmBreakout(state, market.BreakoutShort, "up") + if !confirmed || state.ConfirmCount != 3 { + t.Errorf("Expected confirmed, count=3, got confirmed=%v count=%d", confirmed, state.ConfirmCount) + } + + // Reset on price return + state.ConfirmCount = 2 + confirmed = confirmBreakout(state, market.BreakoutNone, "") + if state.ConfirmCount != 0 { + t.Errorf("Expected count reset to 0, got %d", state.ConfirmCount) + } +} +``` + +**Step 2: Run test to verify it fails** + +Run: `cd /Users/yida/gopro/open-nofx && go test -v ./trader/... -run TestBreakoutConfirmation` +Expected: FAIL with "undefined: BreakoutState" + +**Step 3: Write minimal implementation** + +Add to `trader/grid_regime.go`: + +```go +const BreakoutConfirmRequired = 3 // 3 candles to confirm breakout + +// BreakoutState tracks the current breakout state +type BreakoutState struct { + Level market.BreakoutLevel + Direction string + ConfirmCount int + StartTime time.Time +} + +// confirmBreakout updates breakout state and returns true if breakout is confirmed +func confirmBreakout(state *BreakoutState, currentLevel market.BreakoutLevel, direction string) bool { + // If price returned to box, reset state + if currentLevel == market.BreakoutNone { + state.ConfirmCount = 0 + state.Level = market.BreakoutNone + state.Direction = "" + return false + } + + // If same breakout continues, increment count + if state.Level == currentLevel && state.Direction == direction { + state.ConfirmCount++ + } else { + // New breakout, reset count + state.Level = currentLevel + state.Direction = direction + state.ConfirmCount = 1 + state.StartTime = time.Now() + } + + return state.ConfirmCount >= BreakoutConfirmRequired +} +``` + +**Step 4: Run test to verify it passes** + +Run: `cd /Users/yida/gopro/open-nofx && go test -v ./trader/... -run TestBreakoutConfirmation` +Expected: PASS + +**Step 5: Commit** + +```bash +git add trader/grid_regime.go trader/grid_regime_test.go +git commit -m "feat(trader): add breakout confirmation logic" +``` + +--- + +## Task 9: Add Breakout Handler + +**Files:** +- Modify: `trader/grid_regime.go` +- Test: `trader/grid_regime_test.go` + +**Step 1: Write the failing test** + +Add to `trader/grid_regime_test.go`: + +```go +func TestGetBreakoutAction(t *testing.T) { + tests := []struct { + level market.BreakoutLevel + expected BreakoutAction + }{ + {market.BreakoutNone, BreakoutActionNone}, + {market.BreakoutShort, BreakoutActionReducePosition}, + {market.BreakoutMid, BreakoutActionPauseGrid}, + {market.BreakoutLong, BreakoutActionCloseAll}, + } + + for _, tt := range tests { + t.Run(string(tt.level), func(t *testing.T) { + action := getBreakoutAction(tt.level) + if action != tt.expected { + t.Errorf("Expected %v, got %v", tt.expected, action) + } + }) + } +} +``` + +**Step 2: Run test to verify it fails** + +Run: `cd /Users/yida/gopro/open-nofx && go test -v ./trader/... -run TestGetBreakoutAction` +Expected: FAIL with "undefined: BreakoutAction" + +**Step 3: Write minimal implementation** + +Add to `trader/grid_regime.go`: + +```go +// BreakoutAction represents the action to take on breakout +type BreakoutAction int + +const ( + BreakoutActionNone BreakoutAction = iota + BreakoutActionReducePosition // Short box breakout: reduce to 50% + BreakoutActionPauseGrid // Mid box breakout: pause grid + cancel orders + BreakoutActionCloseAll // Long box breakout: pause + cancel + close all +) + +// getBreakoutAction returns the appropriate action for a breakout level +func getBreakoutAction(level market.BreakoutLevel) BreakoutAction { + switch level { + case market.BreakoutShort: + return BreakoutActionReducePosition + case market.BreakoutMid: + return BreakoutActionPauseGrid + case market.BreakoutLong: + return BreakoutActionCloseAll + default: + return BreakoutActionNone + } +} +``` + +**Step 4: Run test to verify it passes** + +Run: `cd /Users/yida/gopro/open-nofx && go test -v ./trader/... -run TestGetBreakoutAction` +Expected: PASS + +**Step 5: Commit** + +```bash +git add trader/grid_regime.go trader/grid_regime_test.go +git commit -m "feat(trader): add breakout action handler" +``` + +--- + +## Task 10: Integrate Breakout Detection into Grid Cycle + +**Files:** +- Modify: `trader/auto_trader_grid.go` + +**Step 1: Add checkBoxBreakout method** + +Add to `trader/auto_trader_grid.go` after `checkBreakout` function: + +```go +// checkBoxBreakout checks for multi-period box breakouts and takes appropriate action +func (at *AutoTrader) checkBoxBreakout() error { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig == nil { + return nil + } + + // Get box data + box, err := market.GetBoxData(gridConfig.Symbol) + if err != nil { + logger.Infof("Failed to get box data: %v", err) + return nil // Non-fatal, continue with other checks + } + + // Update instance with box values + at.gridState.mu.Lock() + // Store box values in grid state for reference + at.gridState.mu.Unlock() + + // Detect breakout + breakoutLevel, direction := detectBoxBreakout(box) + + // Get current breakout state from instance + state := &BreakoutState{ + Level: market.BreakoutLevel(at.gridState.BreakoutLevel), + Direction: at.gridState.BreakoutDirection, + ConfirmCount: at.gridState.BreakoutConfirmCount, + } + + // Check if breakout is confirmed (3 candles) + confirmed := confirmBreakout(state, breakoutLevel, direction) + + // Update grid state + at.gridState.mu.Lock() + at.gridState.BreakoutLevel = string(state.Level) + at.gridState.BreakoutDirection = state.Direction + at.gridState.BreakoutConfirmCount = state.ConfirmCount + at.gridState.mu.Unlock() + + if !confirmed { + return nil + } + + // Take action based on breakout level + action := getBreakoutAction(breakoutLevel) + return at.executeBreakoutAction(action) +} + +// executeBreakoutAction executes the appropriate action for a breakout +func (at *AutoTrader) executeBreakoutAction(action BreakoutAction) error { + gridConfig := at.config.StrategyConfig.GridConfig + + switch action { + case BreakoutActionReducePosition: + // Short box breakout: reduce position to 50% + logger.Infof("Short box breakout confirmed, reducing position to 50%%") + at.gridState.mu.Lock() + at.gridState.PositionReductionPct = 50 + at.gridState.mu.Unlock() + return nil + + case BreakoutActionPauseGrid: + // Mid box breakout: pause grid + cancel orders + logger.Infof("Mid box breakout confirmed, pausing grid and canceling orders") + at.gridState.mu.Lock() + at.gridState.IsPaused = true + at.gridState.mu.Unlock() + return at.cancelAllGridOrders() + + case BreakoutActionCloseAll: + // Long box breakout: pause + cancel + close all + logger.Infof("Long box breakout confirmed, closing all positions") + at.gridState.mu.Lock() + at.gridState.IsPaused = true + at.gridState.mu.Unlock() + if err := at.cancelAllGridOrders(); err != nil { + logger.Infof("Failed to cancel orders: %v", err) + } + return at.closeAllPositions() + } + + return nil +} + +// closeAllPositions closes all open positions +func (at *AutoTrader) closeAllPositions() error { + gridConfig := at.config.StrategyConfig.GridConfig + + positions, err := at.trader.GetPositions() + if err != nil { + return fmt.Errorf("failed to get positions: %w", err) + } + + for _, pos := range positions { + symbol, _ := pos["symbol"].(string) + if symbol != gridConfig.Symbol { + continue + } + + size, _ := pos["positionAmt"].(float64) + if size == 0 { + continue + } + + if size > 0 { + _, err = at.trader.CloseLong(symbol, size) + } else { + _, err = at.trader.CloseShort(symbol, -size) + } + if err != nil { + logger.Infof("Failed to close position: %v", err) + } + } + + return nil +} +``` + +**Step 2: Add checkBoxBreakout call to RunGridCycle** + +In `RunGridCycle`, add after existing breakout check: + +```go + // Check multi-period box breakout + if err := at.checkBoxBreakout(); err != nil { + logger.Infof("Box breakout check error: %v", err) + } +``` + +**Step 3: Commit** + +```bash +git add trader/auto_trader_grid.go +git commit -m "feat(trader): integrate box breakout detection into grid cycle" +``` + +--- + +## Task 11: Add False Breakout Recovery + +**Files:** +- Modify: `trader/auto_trader_grid.go` + +**Step 1: Add recovery logic** + +Add to `trader/auto_trader_grid.go`: + +```go +// checkFalseBreakoutRecovery checks if price has returned to box after breakout +func (at *AutoTrader) checkFalseBreakoutRecovery() error { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig == nil { + return nil + } + + at.gridState.mu.RLock() + breakoutLevel := at.gridState.BreakoutLevel + isPaused := at.gridState.IsPaused + positionReduction := at.gridState.PositionReductionPct + at.gridState.mu.RUnlock() + + // Only check if we had a breakout + if breakoutLevel == string(market.BreakoutNone) && positionReduction == 0 && !isPaused { + return nil + } + + // Get current box data + box, err := market.GetBoxData(gridConfig.Symbol) + if err != nil { + return nil + } + + // Check if price is back inside the long box + if box.CurrentPrice >= box.LongLower && box.CurrentPrice <= box.LongUpper { + logger.Infof("Price returned to box, recovering with 50%% position") + + at.gridState.mu.Lock() + at.gridState.BreakoutLevel = string(market.BreakoutNone) + at.gridState.BreakoutDirection = "" + at.gridState.BreakoutConfirmCount = 0 + at.gridState.PositionReductionPct = 50 // Recover at 50% + at.gridState.IsPaused = false + at.gridState.mu.Unlock() + } + + return nil +} +``` + +**Step 2: Add call in RunGridCycle** + +```go + // Check for false breakout recovery + if err := at.checkFalseBreakoutRecovery(); err != nil { + logger.Infof("False breakout recovery check error: %v", err) + } +``` + +**Step 3: Commit** + +```bash +git add trader/auto_trader_grid.go +git commit -m "feat(trader): add false breakout recovery logic" +``` + +--- + +## Task 12: Update GridState with Box Fields + +**Files:** +- Modify: `trader/auto_trader_grid.go` + +**Step 1: Add box fields to GridState struct** + +Add to `GridState` struct in `trader/auto_trader_grid.go`: + +```go + // Box state + ShortBoxUpper float64 + ShortBoxLower float64 + MidBoxUpper float64 + MidBoxLower float64 + LongBoxUpper float64 + LongBoxLower float64 + + // Breakout state + BreakoutLevel string + BreakoutDirection string + BreakoutConfirmCount int + + // Position reduction (0 = normal, 50 = reduced after false breakout) + PositionReductionPct float64 + + // Current regime level + CurrentRegimeLevel string +``` + +**Step 2: Commit** + +```bash +git add trader/auto_trader_grid.go +git commit -m "feat(trader): add box and regime fields to GridState" +``` + +--- + +## Task 13: Add Frontend Types + +**Files:** +- Modify: `web/src/types.ts` (or equivalent types file) + +**Step 1: Add grid risk info types** + +Add to types file: + +```typescript +export interface GridRiskInfo { + // Leverage info + currentLeverage: number + effectiveLeverage: number + recommendedLeverage: number + + // Position info + currentPosition: number + maxPosition: number + positionPercent: number + + // Liquidation info + liquidationPrice: number + liquidationDistance: number // percentage + + // Market state + regimeLevel: 'narrow' | 'standard' | 'wide' | 'volatile' | 'trending' + + // Box state + shortBoxUpper: number + shortBoxLower: number + midBoxUpper: number + midBoxLower: number + longBoxUpper: number + longBoxLower: number + currentPrice: number + + // Breakout state + breakoutLevel: 'none' | 'short' | 'mid' | 'long' + breakoutDirection: 'up' | 'down' | '' +} +``` + +**Step 2: Commit** + +```bash +git add web/src/types.ts +git commit -m "feat(web): add GridRiskInfo type" +``` + +--- + +## Task 14: Add API Endpoint for Risk Info + +**Files:** +- Modify: `api/server.go` + +**Step 1: Add handler function** + +Add to `api/server.go`: + +```go +// handleGetGridRiskInfo returns current risk information for a grid trader +func (s *Server) handleGetGridRiskInfo(c *gin.Context) { + traderID := c.Param("id") + + trader, err := s.manager.GetTrader(traderID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "trader not found"}) + return + } + + autoTrader, ok := trader.(*trader.AutoTrader) + if !ok { + c.JSON(http.StatusBadRequest, gin.H{"error": "not an auto trader"}) + return + } + + riskInfo := autoTrader.GetGridRiskInfo() + c.JSON(http.StatusOK, riskInfo) +} +``` + +**Step 2: Add route** + +Add route in `setupRoutes`: + +```go + api.GET("/traders/:id/grid-risk", s.handleGetGridRiskInfo) +``` + +**Step 3: Commit** + +```bash +git add api/server.go +git commit -m "feat(api): add grid risk info endpoint" +``` + +--- + +## Task 15: Add GetGridRiskInfo Method to AutoTrader + +**Files:** +- Modify: `trader/auto_trader_grid.go` + +**Step 1: Add method** + +Add to `trader/auto_trader_grid.go`: + +```go +// GridRiskInfo contains risk information for frontend display +type GridRiskInfo struct { + CurrentLeverage int `json:"current_leverage"` + EffectiveLeverage float64 `json:"effective_leverage"` + RecommendedLeverage int `json:"recommended_leverage"` + + CurrentPosition float64 `json:"current_position"` + MaxPosition float64 `json:"max_position"` + PositionPercent float64 `json:"position_percent"` + + LiquidationPrice float64 `json:"liquidation_price"` + LiquidationDistance float64 `json:"liquidation_distance"` + + RegimeLevel string `json:"regime_level"` + + ShortBoxUpper float64 `json:"short_box_upper"` + ShortBoxLower float64 `json:"short_box_lower"` + MidBoxUpper float64 `json:"mid_box_upper"` + MidBoxLower float64 `json:"mid_box_lower"` + LongBoxUpper float64 `json:"long_box_upper"` + LongBoxLower float64 `json:"long_box_lower"` + CurrentPrice float64 `json:"current_price"` + + BreakoutLevel string `json:"breakout_level"` + BreakoutDirection string `json:"breakout_direction"` +} + +// GetGridRiskInfo returns current risk information +func (at *AutoTrader) GetGridRiskInfo() *GridRiskInfo { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig == nil { + return &GridRiskInfo{} + } + + at.gridState.mu.RLock() + defer at.gridState.mu.RUnlock() + + // Get current price + currentPrice, _ := at.trader.GetMarketPrice(gridConfig.Symbol) + + // Calculate effective leverage + totalInvestment := gridConfig.TotalInvestment + leverage := gridConfig.Leverage + + // Get current position value + positions, _ := at.trader.GetPositions() + var currentPositionValue float64 + for _, pos := range positions { + if sym, _ := pos["symbol"].(string); sym == gridConfig.Symbol { + size, _ := pos["positionAmt"].(float64) + entry, _ := pos["entryPrice"].(float64) + currentPositionValue = math.Abs(size * entry) + break + } + } + + effectiveLeverage := currentPositionValue / totalInvestment + + // Calculate max position based on regime + regimeLevel := market.RegimeLevel(at.gridState.CurrentRegimeLevel) + maxPositionPct := getRegimePositionLimit(regimeLevel, gridConfig) + maxPosition := totalInvestment * maxPositionPct / 100 * float64(leverage) + recommendedLeverage := getRegimeLeverageLimit(regimeLevel, gridConfig) + + // Calculate liquidation distance + liquidationDistance := 100.0 / float64(leverage) * 0.9 // ~90% of theoretical max + + var liquidationPrice float64 + if currentPositionValue > 0 { + liquidationPrice = currentPrice * (1 - liquidationDistance/100) + } + + return &GridRiskInfo{ + CurrentLeverage: leverage, + EffectiveLeverage: effectiveLeverage, + RecommendedLeverage: recommendedLeverage, + + CurrentPosition: currentPositionValue, + MaxPosition: maxPosition, + PositionPercent: currentPositionValue / maxPosition * 100, + + LiquidationPrice: liquidationPrice, + LiquidationDistance: liquidationDistance, + + RegimeLevel: at.gridState.CurrentRegimeLevel, + + ShortBoxUpper: at.gridState.ShortBoxUpper, + ShortBoxLower: at.gridState.ShortBoxLower, + MidBoxUpper: at.gridState.MidBoxUpper, + MidBoxLower: at.gridState.MidBoxLower, + LongBoxUpper: at.gridState.LongBoxUpper, + LongBoxLower: at.gridState.LongBoxLower, + CurrentPrice: currentPrice, + + BreakoutLevel: at.gridState.BreakoutLevel, + BreakoutDirection: at.gridState.BreakoutDirection, + } +} +``` + +**Step 2: Commit** + +```bash +git add trader/auto_trader_grid.go +git commit -m "feat(trader): add GetGridRiskInfo method" +``` + +--- + +## Task 16: Create GridRiskPanel Component + +**Files:** +- Create: `web/src/components/strategy/GridRiskPanel.tsx` + +**Step 1: Create component** + +Create `web/src/components/strategy/GridRiskPanel.tsx`: + +```tsx +import { useState, useEffect } from 'react' +import { AlertTriangle, TrendingUp, Shield, Box } from 'lucide-react' + +interface GridRiskInfo { + currentLeverage: number + effectiveLeverage: number + recommendedLeverage: number + currentPosition: number + maxPosition: number + positionPercent: number + liquidationPrice: number + liquidationDistance: number + regimeLevel: string + shortBoxUpper: number + shortBoxLower: number + midBoxUpper: number + midBoxLower: number + longBoxUpper: number + longBoxLower: number + currentPrice: number + breakoutLevel: string + breakoutDirection: string +} + +interface GridRiskPanelProps { + traderId: string + language: string +} + +export function GridRiskPanel({ traderId, language }: GridRiskPanelProps) { + const [riskInfo, setRiskInfo] = useState(null) + const [loading, setLoading] = useState(true) + + const t = (key: string) => { + const translations: Record> = { + leverageInfo: { zh: '杠杆信息', en: 'Leverage Info' }, + currentLeverage: { zh: '当前杠杆', en: 'Current Leverage' }, + effectiveLeverage: { zh: '有效杠杆', en: 'Effective Leverage' }, + recommendedLeverage: { zh: '推荐杠杆', en: 'Recommended Leverage' }, + positionInfo: { zh: '仓位信息', en: 'Position Info' }, + currentPosition: { zh: '当前仓位', en: 'Current Position' }, + maxPosition: { zh: '最大仓位', en: 'Max Position' }, + liquidationInfo: { zh: '爆仓信息', en: 'Liquidation Info' }, + liquidationPrice: { zh: '爆仓价格', en: 'Liquidation Price' }, + liquidationDistance: { zh: '爆仓距离', en: 'Distance' }, + marketState: { zh: '市场状态', en: 'Market State' }, + regimeLevel: { zh: '震荡级别', en: 'Regime Level' }, + boxState: { zh: '箱体状态', en: 'Box State' }, + shortBox: { zh: '短期箱体', en: 'Short Box' }, + midBox: { zh: '中期箱体', en: 'Mid Box' }, + longBox: { zh: '长期箱体', en: 'Long Box' }, + narrow: { zh: '窄幅震荡', en: 'Narrow' }, + standard: { zh: '标准震荡', en: 'Standard' }, + wide: { zh: '宽幅震荡', en: 'Wide' }, + volatile: { zh: '剧烈震荡', en: 'Volatile' }, + trending: { zh: '趋势', en: 'Trending' }, + breakout: { zh: '突破', en: 'Breakout' }, + none: { zh: '无', en: 'None' }, + } + return translations[key]?.[language] || key + } + + useEffect(() => { + const fetchRiskInfo = async () => { + try { + const res = await fetch(`/api/traders/${traderId}/grid-risk`) + if (res.ok) { + const data = await res.json() + setRiskInfo(data) + } + } catch (err) { + console.error('Failed to fetch risk info:', err) + } finally { + setLoading(false) + } + } + + fetchRiskInfo() + const interval = setInterval(fetchRiskInfo, 10000) // Update every 10s + return () => clearInterval(interval) + }, [traderId]) + + if (loading || !riskInfo) { + return
+ } + + const getRegimeColor = (level: string) => { + switch (level) { + case 'narrow': return 'text-green-400' + case 'standard': return 'text-blue-400' + case 'wide': return 'text-yellow-400' + case 'volatile': return 'text-orange-400' + case 'trending': return 'text-red-400' + default: return 'text-gray-400' + } + } + + return ( +
+ {/* Leverage Info */} +
+

+ + {t('leverageInfo')} +

+
+
+
{t('currentLeverage')}
+
{riskInfo.currentLeverage}x
+
+
+
{t('effectiveLeverage')}
+
{riskInfo.effectiveLeverage.toFixed(2)}x
+
+
+
{t('recommendedLeverage')}
+
{riskInfo.recommendedLeverage}x
+
+
+
+ + {/* Position Info */} +
+

+ + {t('positionInfo')} +

+
+
+
{t('currentPosition')}
+
${riskInfo.currentPosition.toFixed(2)}
+
+
+
{t('maxPosition')}
+
${riskInfo.maxPosition.toFixed(2)}
+
+
+
+
+
+
+ + {/* Liquidation Info */} +
+

+ + {t('liquidationInfo')} +

+
+
+
{t('liquidationPrice')}
+
${riskInfo.liquidationPrice.toFixed(2)}
+
+
+
{t('liquidationDistance')}
+
{riskInfo.liquidationDistance.toFixed(1)}%
+
+
+
+ + {/* Market State */} +
+

+ + {t('marketState')} +

+
+
+
{t('regimeLevel')}
+
+ {t(riskInfo.regimeLevel)} +
+
+ {riskInfo.breakoutLevel !== 'none' && ( +
+ {t('breakout')}: {riskInfo.breakoutLevel} ({riskInfo.breakoutDirection}) +
+ )} +
+
+ + {/* Box State */} +
+

{t('boxState')}

+
+
+ {t('shortBox')} + {riskInfo.shortBoxLower.toFixed(2)} - {riskInfo.shortBoxUpper.toFixed(2)} +
+
+ {t('midBox')} + {riskInfo.midBoxLower.toFixed(2)} - {riskInfo.midBoxUpper.toFixed(2)} +
+
+ {t('longBox')} + {riskInfo.longBoxLower.toFixed(2)} - {riskInfo.longBoxUpper.toFixed(2)} +
+
+ Current Price + ${riskInfo.currentPrice.toFixed(2)} +
+
+
+
+ ) +} +``` + +**Step 2: Commit** + +```bash +git add web/src/components/strategy/GridRiskPanel.tsx +git commit -m "feat(web): add GridRiskPanel component" +``` + +--- + +## Task 17: Update AI Prompt with Box Indicators + +**Files:** +- Modify: `kernel/grid_engine.go` + +**Step 1: Update BuildGridUserPrompt to include box data** + +Add box data section to the prompt in `kernel/grid_engine.go`: + +```go +// In BuildGridUserPrompt function, add after market data section: + + // Box Indicator Section + if gridCtx.BoxData != nil { + sb.WriteString("\n## Box Indicators (Donchian Channels)\n\n") + sb.WriteString("| Box Level | Upper | Lower | Width |\n") + sb.WriteString("|-----------|-------|-------|-------|\n") + + shortWidth := (gridCtx.BoxData.ShortUpper - gridCtx.BoxData.ShortLower) / gridCtx.BoxData.CurrentPrice * 100 + midWidth := (gridCtx.BoxData.MidUpper - gridCtx.BoxData.MidLower) / gridCtx.BoxData.CurrentPrice * 100 + longWidth := (gridCtx.BoxData.LongUpper - gridCtx.BoxData.LongLower) / gridCtx.BoxData.CurrentPrice * 100 + + sb.WriteString(fmt.Sprintf("| Short (3d) | %.2f | %.2f | %.2f%% |\n", + gridCtx.BoxData.ShortUpper, gridCtx.BoxData.ShortLower, shortWidth)) + sb.WriteString(fmt.Sprintf("| Mid (10d) | %.2f | %.2f | %.2f%% |\n", + gridCtx.BoxData.MidUpper, gridCtx.BoxData.MidLower, midWidth)) + sb.WriteString(fmt.Sprintf("| Long (21d) | %.2f | %.2f | %.2f%% |\n", + gridCtx.BoxData.LongUpper, gridCtx.BoxData.LongLower, longWidth)) + + // Price position + sb.WriteString(fmt.Sprintf("\nCurrent Price: %.2f\n", gridCtx.BoxData.CurrentPrice)) + + // Check position relative to boxes + price := gridCtx.BoxData.CurrentPrice + if price > gridCtx.BoxData.LongUpper || price < gridCtx.BoxData.LongLower { + sb.WriteString("⚠️ BREAKOUT: Price outside long-term box!\n") + } else if price > gridCtx.BoxData.MidUpper || price < gridCtx.BoxData.MidLower { + sb.WriteString("⚠️ WARNING: Price approaching long-term box boundary\n") + } + } +``` + +**Step 2: Update GridContext struct** + +Add BoxData field to GridContext: + +```go +type GridContext struct { + // ... existing fields ... + + // Box data + BoxData *market.BoxData +} +``` + +**Step 3: Commit** + +```bash +git add kernel/grid_engine.go +git commit -m "feat(kernel): add box indicators to AI prompt" +``` + +--- + +## Task 18: Database Migration + +**Files:** +- Modify: `store/grid.go` + +**Step 1: Update InitGridSchema to migrate new fields** + +The GORM AutoMigrate will handle adding new columns. Verify by running: + +```bash +cd /Users/yida/gopro/open-nofx && go run . migrate +``` + +**Step 2: Commit** + +```bash +git add store/grid.go +git commit -m "chore(store): ensure new grid fields are migrated" +``` + +--- + +## Task 19: Run All Tests + +**Step 1: Run backend tests** + +```bash +cd /Users/yida/gopro/open-nofx && go test -v ./... +``` + +**Step 2: Run frontend tests (if available)** + +```bash +cd /Users/yida/gopro/open-nofx/web && npm test +``` + +**Step 3: Fix any failing tests and commit** + +```bash +git add . +git commit -m "test: fix tests for grid regime implementation" +``` + +--- + +## Task 20: Final Integration Test + +**Step 1: Start the server** + +```bash +cd /Users/yida/gopro/open-nofx && go run . +``` + +**Step 2: Verify API endpoint** + +```bash +curl http://localhost:8080/api/traders//grid-risk +``` + +**Step 3: Verify frontend displays risk panel** + +Open browser and check grid trading page shows risk panel. + +**Step 4: Final commit** + +```bash +git add . +git commit -m "feat: complete grid market regime detection implementation" +``` + +--- + +## Summary + +| Task | Description | Files | +|------|-------------|-------| +| 1 | Donchian calculation | market/data.go | +| 2 | Box data types | market/types.go | +| 3 | GetBoxData function | market/data.go | +| 4 | GridConfigModel fields | store/grid.go | +| 5 | GridInstanceModel fields | store/grid.go | +| 6 | Regime classification | trader/grid_regime.go | +| 7 | Breakout detection | trader/grid_regime.go | +| 8 | Breakout confirmation | trader/grid_regime.go | +| 9 | Breakout handler | trader/grid_regime.go | +| 10 | Grid cycle integration | trader/auto_trader_grid.go | +| 11 | False breakout recovery | trader/auto_trader_grid.go | +| 12 | GridState fields | trader/auto_trader_grid.go | +| 13 | Frontend types | web/src/types.ts | +| 14 | API endpoint | api/server.go | +| 15 | GetGridRiskInfo method | trader/auto_trader_grid.go | +| 16 | GridRiskPanel component | web/src/components/ | +| 17 | AI prompt update | kernel/grid_engine.go | +| 18 | Database migration | store/grid.go | +| 19 | Run all tests | - | +| 20 | Integration test | - | diff --git a/kernel/engine.go b/kernel/engine.go index e452c730..6ce5aa52 100644 --- a/kernel/engine.go +++ b/kernel/engine.go @@ -130,7 +130,8 @@ type Context struct { // Decision AI trading decision type Decision struct { Symbol string `json:"symbol"` - Action string `json:"action"` // "open_long", "open_short", "close_long", "close_short", "hold", "wait" + Action string `json:"action"` // Standard: "open_long", "open_short", "close_long", "close_short", "hold", "wait" + // Grid actions: "place_buy_limit", "place_sell_limit", "cancel_order", "cancel_all_orders", "pause_grid", "resume_grid", "adjust_grid" // Opening position parameters Leverage int `json:"leverage,omitempty"` @@ -138,6 +139,12 @@ type Decision struct { StopLoss float64 `json:"stop_loss,omitempty"` TakeProfit float64 `json:"take_profit,omitempty"` + // Grid trading parameters + Price float64 `json:"price,omitempty"` // Limit order price (for grid) + Quantity float64 `json:"quantity,omitempty"` // Order quantity (for grid) + LevelIndex int `json:"level_index,omitempty"` // Grid level index + OrderID string `json:"order_id,omitempty"` // Order ID (for cancel) + // Common parameters Confidence int `json:"confidence,omitempty"` // Confidence level (0-100) RiskUSD float64 `json:"risk_usd,omitempty"` // Maximum USD risk diff --git a/kernel/grid_engine.go b/kernel/grid_engine.go new file mode 100644 index 00000000..243432a2 --- /dev/null +++ b/kernel/grid_engine.go @@ -0,0 +1,587 @@ +package kernel + +import ( + "encoding/json" + "fmt" + "nofx/logger" + "nofx/market" + "nofx/mcp" + "nofx/store" + "strings" + "time" +) + +// ============================================================================ +// Grid Trading Context and Types +// ============================================================================ + +// GridLevelInfo represents a single grid level's current state +type GridLevelInfo struct { + Index int `json:"index"` // Level index (0 = lowest) + Price float64 `json:"price"` // Target price for this level + State string `json:"state"` // "empty", "pending", "filled" + Side string `json:"side"` // "buy" or "sell" + OrderID string `json:"order_id"` // Current order ID (if pending) + OrderQuantity float64 `json:"order_quantity"` // Order quantity + PositionSize float64 `json:"position_size"` // Position size (if filled) + PositionEntry float64 `json:"position_entry"` // Entry price (if filled) + AllocatedUSD float64 `json:"allocated_usd"` // USD allocated to this level + UnrealizedPnL float64 `json:"unrealized_pnl"` // Unrealized P&L (if filled) +} + +// GridContext contains all information needed for AI grid decision making +type GridContext struct { + // Basic info + Symbol string `json:"symbol"` + CurrentTime string `json:"current_time"` + CurrentPrice float64 `json:"current_price"` + + // Grid configuration + GridCount int `json:"grid_count"` + TotalInvestment float64 `json:"total_investment"` + Leverage int `json:"leverage"` + UpperPrice float64 `json:"upper_price"` + LowerPrice float64 `json:"lower_price"` + GridSpacing float64 `json:"grid_spacing"` + Distribution string `json:"distribution"` + + // Grid state + Levels []GridLevelInfo `json:"levels"` + ActiveOrderCount int `json:"active_order_count"` + FilledLevelCount int `json:"filled_level_count"` + IsPaused bool `json:"is_paused"` + + // Market data + ATR14 float64 `json:"atr14"` + BollingerUpper float64 `json:"bollinger_upper"` + BollingerMiddle float64 `json:"bollinger_middle"` + BollingerLower float64 `json:"bollinger_lower"` + BollingerWidth float64 `json:"bollinger_width"` // Percentage + EMA20 float64 `json:"ema20"` + EMA50 float64 `json:"ema50"` + EMADistance float64 `json:"ema_distance"` // Percentage + RSI14 float64 `json:"rsi14"` + MACD float64 `json:"macd"` + MACDSignal float64 `json:"macd_signal"` + MACDHistogram float64 `json:"macd_histogram"` + FundingRate float64 `json:"funding_rate"` + Volume24h float64 `json:"volume_24h"` + PriceChange1h float64 `json:"price_change_1h"` + PriceChange4h float64 `json:"price_change_4h"` + + // Account info + TotalEquity float64 `json:"total_equity"` + AvailableBalance float64 `json:"available_balance"` + CurrentPosition float64 `json:"current_position"` // Net position size + UnrealizedPnL float64 `json:"unrealized_pnl"` + + // Performance + TotalProfit float64 `json:"total_profit"` + TotalTrades int `json:"total_trades"` + WinningTrades int `json:"winning_trades"` + MaxDrawdown float64 `json:"max_drawdown"` + DailyPnL float64 `json:"daily_pnl"` + + // Box indicators (Donchian Channels) + BoxData *market.BoxData `json:"box_data,omitempty"` +} + +// ============================================================================ +// Grid Prompt Building +// ============================================================================ + +// BuildGridSystemPrompt builds the system prompt for grid trading AI +func BuildGridSystemPrompt(config *store.GridStrategyConfig, lang string) string { + if lang == "zh" { + return buildGridSystemPromptZh(config) + } + return buildGridSystemPromptEn(config) +} + +func buildGridSystemPromptZh(config *store.GridStrategyConfig) string { + return fmt.Sprintf(`# 你是一个专业的网格交易AI + +## 角色定义 +你是一个经验丰富的网格交易专家,负责管理 %s 的网格交易策略。你的任务是: +1. 判断当前市场状态(震荡/趋势/高波动) +2. 决定是否需要调整网格或暂停交易 +3. 管理每个网格层级的订单 + +## 网格配置 +- 交易对: %s +- 网格层数: %d +- 总投资: %.2f USDT +- 杠杆: %dx +- 价格分布: %s + +## 决策规则 + +### 市场状态判断 +- **震荡市场** (适合网格): 布林带宽度 < 3%%, EMA20/50 距离 < 1%%, 价格在布林带中轨附近 +- **趋势市场** (暂停网格): 布林带宽度 > 4%%, EMA20/50 距离 > 2%%, 价格持续突破布林带 +- **高波动市场** (谨慎): ATR异常放大, 价格剧烈波动 + +### 可执行的操作 +- place_buy_limit: 在指定价格下买入限价单 +- place_sell_limit: 在指定价格下卖出限价单 +- cancel_order: 取消指定订单 +- cancel_all_orders: 取消所有订单 +- pause_grid: 暂停网格交易(趋势市场时) +- resume_grid: 恢复网格交易(震荡市场时) +- adjust_grid: 调整网格边界 +- hold: 保持当前状态不操作 + +## 输出格式 +输出JSON数组,每个决策包含: +- symbol: 交易对 +- action: 操作类型 +- price: 价格(限价单用) +- quantity: 数量 +- level_index: 网格层级索引 +- order_id: 订单ID(取消订单用) +- confidence: 置信度 0-100 +- reasoning: 决策理由 + +示例: +[ + {"symbol": "BTCUSDT", "action": "place_buy_limit", "price": 94000, "quantity": 0.01, "level_index": 2, "confidence": 85, "reasoning": "第2层价格接近,下买单"}, + {"symbol": "BTCUSDT", "action": "hold", "confidence": 90, "reasoning": "市场震荡,保持当前网格"} +] +`, config.Symbol, config.Symbol, config.GridCount, config.TotalInvestment, config.Leverage, config.Distribution) +} + +func buildGridSystemPromptEn(config *store.GridStrategyConfig) string { + return fmt.Sprintf(`# You are a Professional Grid Trading AI + +## Role Definition +You are an experienced grid trading expert managing a grid strategy for %s. Your tasks are: +1. Assess current market regime (ranging/trending/volatile) +2. Decide whether to adjust grid or pause trading +3. Manage orders at each grid level + +## Grid Configuration +- Symbol: %s +- Grid Levels: %d +- Total Investment: %.2f USDT +- Leverage: %dx +- Distribution: %s + +## Decision Rules + +### Market Regime Assessment +- **Ranging Market** (ideal for grid): Bollinger width < 3%%, EMA20/50 distance < 1%%, price near middle band +- **Trending Market** (pause grid): Bollinger width > 4%%, EMA20/50 distance > 2%%, price breaking bands +- **High Volatility** (caution): ATR spike, erratic price movement + +### Available Actions +- place_buy_limit: Place buy limit order at specified price +- place_sell_limit: Place sell limit order at specified price +- cancel_order: Cancel specific order +- cancel_all_orders: Cancel all orders +- pause_grid: Pause grid trading (in trending market) +- resume_grid: Resume grid trading (in ranging market) +- adjust_grid: Adjust grid boundaries +- hold: Maintain current state + +## Output Format +Output JSON array, each decision contains: +- symbol: Trading pair +- action: Action type +- price: Price (for limit orders) +- quantity: Quantity +- level_index: Grid level index +- order_id: Order ID (for cancel) +- confidence: Confidence 0-100 +- reasoning: Decision reason + +Example: +[ + {"symbol": "BTCUSDT", "action": "place_buy_limit", "price": 94000, "quantity": 0.01, "level_index": 2, "confidence": 85, "reasoning": "Level 2 price approaching, place buy order"}, + {"symbol": "BTCUSDT", "action": "hold", "confidence": 90, "reasoning": "Market ranging, maintain current grid"} +] +`, config.Symbol, config.Symbol, config.GridCount, config.TotalInvestment, config.Leverage, config.Distribution) +} + +// BuildGridUserPrompt builds the user prompt with current grid context +func BuildGridUserPrompt(ctx *GridContext, lang string) string { + if lang == "zh" { + return buildGridUserPromptZh(ctx) + } + return buildGridUserPromptEn(ctx) +} + +func buildGridUserPromptZh(ctx *GridContext) string { + var sb strings.Builder + + sb.WriteString(fmt.Sprintf("## 当前时间: %s\n\n", ctx.CurrentTime)) + + // Market data section + sb.WriteString("## 市场数据\n") + sb.WriteString(fmt.Sprintf("- 当前价格: $%.2f\n", ctx.CurrentPrice)) + sb.WriteString(fmt.Sprintf("- 1小时涨跌: %.2f%%\n", ctx.PriceChange1h)) + sb.WriteString(fmt.Sprintf("- 4小时涨跌: %.2f%%\n", ctx.PriceChange4h)) + sb.WriteString(fmt.Sprintf("- ATR14: $%.2f (%.2f%%)\n", ctx.ATR14, ctx.ATR14/ctx.CurrentPrice*100)) + sb.WriteString(fmt.Sprintf("- 布林带: 上轨 $%.2f, 中轨 $%.2f, 下轨 $%.2f\n", ctx.BollingerUpper, ctx.BollingerMiddle, ctx.BollingerLower)) + sb.WriteString(fmt.Sprintf("- 布林带宽度: %.2f%%\n", ctx.BollingerWidth)) + sb.WriteString(fmt.Sprintf("- EMA20: $%.2f, EMA50: $%.2f, 距离: %.2f%%\n", ctx.EMA20, ctx.EMA50, ctx.EMADistance)) + sb.WriteString(fmt.Sprintf("- RSI14: %.1f\n", ctx.RSI14)) + sb.WriteString(fmt.Sprintf("- MACD: %.4f, Signal: %.4f, Histogram: %.4f\n", ctx.MACD, ctx.MACDSignal, ctx.MACDHistogram)) + sb.WriteString(fmt.Sprintf("- 资金费率: %.4f%%\n", ctx.FundingRate*100)) + sb.WriteString("\n") + + // Box Indicator Section + if ctx.BoxData != nil { + sb.WriteString("## 箱体指标 (唐奇安通道)\n\n") + sb.WriteString("| 箱体级别 | 上轨 | 下轨 | 宽度 |\n") + sb.WriteString("|----------|------|------|------|\n") + + shortWidth := 0.0 + midWidth := 0.0 + longWidth := 0.0 + + if ctx.BoxData.CurrentPrice > 0 { + shortWidth = (ctx.BoxData.ShortUpper - ctx.BoxData.ShortLower) / ctx.BoxData.CurrentPrice * 100 + midWidth = (ctx.BoxData.MidUpper - ctx.BoxData.MidLower) / ctx.BoxData.CurrentPrice * 100 + longWidth = (ctx.BoxData.LongUpper - ctx.BoxData.LongLower) / ctx.BoxData.CurrentPrice * 100 + } + + sb.WriteString(fmt.Sprintf("| 短期 (3天) | %.2f | %.2f | %.2f%% |\n", + ctx.BoxData.ShortUpper, ctx.BoxData.ShortLower, shortWidth)) + sb.WriteString(fmt.Sprintf("| 中期 (10天) | %.2f | %.2f | %.2f%% |\n", + ctx.BoxData.MidUpper, ctx.BoxData.MidLower, midWidth)) + sb.WriteString(fmt.Sprintf("| 长期 (21天) | %.2f | %.2f | %.2f%% |\n", + ctx.BoxData.LongUpper, ctx.BoxData.LongLower, longWidth)) + + sb.WriteString(fmt.Sprintf("\n当前价格: %.2f\n", ctx.BoxData.CurrentPrice)) + + // Check position relative to boxes + price := ctx.BoxData.CurrentPrice + if price > ctx.BoxData.LongUpper || price < ctx.BoxData.LongLower { + sb.WriteString("⚠️ 突破: 价格突破长期箱体!\n") + } else if price > ctx.BoxData.MidUpper || price < ctx.BoxData.MidLower { + sb.WriteString("⚠️ 警告: 价格接近长期箱体边界\n") + } + sb.WriteString("\n") + } + + // Account section + sb.WriteString("## 账户状态\n") + sb.WriteString(fmt.Sprintf("- 总权益: $%.2f\n", ctx.TotalEquity)) + sb.WriteString(fmt.Sprintf("- 可用余额: $%.2f\n", ctx.AvailableBalance)) + sb.WriteString(fmt.Sprintf("- 当前持仓: %.4f (净头寸)\n", ctx.CurrentPosition)) + sb.WriteString(fmt.Sprintf("- 未实现盈亏: $%.2f\n", ctx.UnrealizedPnL)) + sb.WriteString("\n") + + // Grid state section + sb.WriteString("## 网格状态\n") + sb.WriteString(fmt.Sprintf("- 网格范围: $%.2f - $%.2f\n", ctx.LowerPrice, ctx.UpperPrice)) + sb.WriteString(fmt.Sprintf("- 网格间距: $%.2f\n", ctx.GridSpacing)) + sb.WriteString(fmt.Sprintf("- 活跃订单数: %d\n", ctx.ActiveOrderCount)) + sb.WriteString(fmt.Sprintf("- 已成交层数: %d\n", ctx.FilledLevelCount)) + sb.WriteString(fmt.Sprintf("- 网格已暂停: %v\n", ctx.IsPaused)) + sb.WriteString("\n") + + // Grid levels detail + sb.WriteString("## 网格层级详情\n") + sb.WriteString("| 层级 | 价格 | 状态 | 方向 | 订单数量 | 持仓数量 | 未实现盈亏 |\n") + sb.WriteString("|------|------|------|------|----------|----------|------------|\n") + for _, level := range ctx.Levels { + sb.WriteString(fmt.Sprintf("| %d | $%.2f | %s | %s | %.4f | %.4f | $%.2f |\n", + level.Index, level.Price, level.State, level.Side, + level.OrderQuantity, level.PositionSize, level.UnrealizedPnL)) + } + sb.WriteString("\n") + + // Performance section + sb.WriteString("## 绩效统计\n") + sb.WriteString(fmt.Sprintf("- 总利润: $%.2f\n", ctx.TotalProfit)) + sb.WriteString(fmt.Sprintf("- 总交易次数: %d\n", ctx.TotalTrades)) + sb.WriteString(fmt.Sprintf("- 胜率: %.1f%%\n", float64(ctx.WinningTrades)/float64(max(ctx.TotalTrades, 1))*100)) + sb.WriteString(fmt.Sprintf("- 最大回撤: %.2f%%\n", ctx.MaxDrawdown)) + sb.WriteString(fmt.Sprintf("- 今日盈亏: $%.2f\n", ctx.DailyPnL)) + sb.WriteString("\n") + + sb.WriteString("## 请分析以上数据,做出网格交易决策\n") + sb.WriteString("输出JSON数组格式的决策列表。\n") + + return sb.String() +} + +func buildGridUserPromptEn(ctx *GridContext) string { + var sb strings.Builder + + sb.WriteString(fmt.Sprintf("## Current Time: %s\n\n", ctx.CurrentTime)) + + // Market data section + sb.WriteString("## Market Data\n") + sb.WriteString(fmt.Sprintf("- Current Price: $%.2f\n", ctx.CurrentPrice)) + sb.WriteString(fmt.Sprintf("- 1h Change: %.2f%%\n", ctx.PriceChange1h)) + sb.WriteString(fmt.Sprintf("- 4h Change: %.2f%%\n", ctx.PriceChange4h)) + sb.WriteString(fmt.Sprintf("- ATR14: $%.2f (%.2f%%)\n", ctx.ATR14, ctx.ATR14/ctx.CurrentPrice*100)) + sb.WriteString(fmt.Sprintf("- Bollinger Bands: Upper $%.2f, Middle $%.2f, Lower $%.2f\n", ctx.BollingerUpper, ctx.BollingerMiddle, ctx.BollingerLower)) + sb.WriteString(fmt.Sprintf("- Bollinger Width: %.2f%%\n", ctx.BollingerWidth)) + sb.WriteString(fmt.Sprintf("- EMA20: $%.2f, EMA50: $%.2f, Distance: %.2f%%\n", ctx.EMA20, ctx.EMA50, ctx.EMADistance)) + sb.WriteString(fmt.Sprintf("- RSI14: %.1f\n", ctx.RSI14)) + sb.WriteString(fmt.Sprintf("- MACD: %.4f, Signal: %.4f, Histogram: %.4f\n", ctx.MACD, ctx.MACDSignal, ctx.MACDHistogram)) + sb.WriteString(fmt.Sprintf("- Funding Rate: %.4f%%\n", ctx.FundingRate*100)) + sb.WriteString("\n") + + // Box Indicator Section + if ctx.BoxData != nil { + sb.WriteString("## Box Indicators (Donchian Channels)\n\n") + sb.WriteString("| Box Level | Upper | Lower | Width |\n") + sb.WriteString("|-----------|-------|-------|-------|\n") + + shortWidth := 0.0 + midWidth := 0.0 + longWidth := 0.0 + + if ctx.BoxData.CurrentPrice > 0 { + shortWidth = (ctx.BoxData.ShortUpper - ctx.BoxData.ShortLower) / ctx.BoxData.CurrentPrice * 100 + midWidth = (ctx.BoxData.MidUpper - ctx.BoxData.MidLower) / ctx.BoxData.CurrentPrice * 100 + longWidth = (ctx.BoxData.LongUpper - ctx.BoxData.LongLower) / ctx.BoxData.CurrentPrice * 100 + } + + sb.WriteString(fmt.Sprintf("| Short (3d) | %.2f | %.2f | %.2f%% |\n", + ctx.BoxData.ShortUpper, ctx.BoxData.ShortLower, shortWidth)) + sb.WriteString(fmt.Sprintf("| Mid (10d) | %.2f | %.2f | %.2f%% |\n", + ctx.BoxData.MidUpper, ctx.BoxData.MidLower, midWidth)) + sb.WriteString(fmt.Sprintf("| Long (21d) | %.2f | %.2f | %.2f%% |\n", + ctx.BoxData.LongUpper, ctx.BoxData.LongLower, longWidth)) + + sb.WriteString(fmt.Sprintf("\nCurrent Price: %.2f\n", ctx.BoxData.CurrentPrice)) + + // Check position relative to boxes + price := ctx.BoxData.CurrentPrice + if price > ctx.BoxData.LongUpper || price < ctx.BoxData.LongLower { + sb.WriteString("⚠️ BREAKOUT: Price outside long-term box!\n") + } else if price > ctx.BoxData.MidUpper || price < ctx.BoxData.MidLower { + sb.WriteString("⚠️ WARNING: Price approaching long-term box boundary\n") + } + sb.WriteString("\n") + } + + // Account section + sb.WriteString("## Account Status\n") + sb.WriteString(fmt.Sprintf("- Total Equity: $%.2f\n", ctx.TotalEquity)) + sb.WriteString(fmt.Sprintf("- Available Balance: $%.2f\n", ctx.AvailableBalance)) + sb.WriteString(fmt.Sprintf("- Current Position: %.4f (net)\n", ctx.CurrentPosition)) + sb.WriteString(fmt.Sprintf("- Unrealized PnL: $%.2f\n", ctx.UnrealizedPnL)) + sb.WriteString("\n") + + // Grid state section + sb.WriteString("## Grid Status\n") + sb.WriteString(fmt.Sprintf("- Grid Range: $%.2f - $%.2f\n", ctx.LowerPrice, ctx.UpperPrice)) + sb.WriteString(fmt.Sprintf("- Grid Spacing: $%.2f\n", ctx.GridSpacing)) + sb.WriteString(fmt.Sprintf("- Active Orders: %d\n", ctx.ActiveOrderCount)) + sb.WriteString(fmt.Sprintf("- Filled Levels: %d\n", ctx.FilledLevelCount)) + sb.WriteString(fmt.Sprintf("- Grid Paused: %v\n", ctx.IsPaused)) + sb.WriteString("\n") + + // Grid levels detail + sb.WriteString("## Grid Levels Detail\n") + sb.WriteString("| Level | Price | State | Side | Order Qty | Position | Unrealized PnL |\n") + sb.WriteString("|-------|-------|-------|------|-----------|----------|----------------|\n") + for _, level := range ctx.Levels { + sb.WriteString(fmt.Sprintf("| %d | $%.2f | %s | %s | %.4f | %.4f | $%.2f |\n", + level.Index, level.Price, level.State, level.Side, + level.OrderQuantity, level.PositionSize, level.UnrealizedPnL)) + } + sb.WriteString("\n") + + // Performance section + sb.WriteString("## Performance Stats\n") + sb.WriteString(fmt.Sprintf("- Total Profit: $%.2f\n", ctx.TotalProfit)) + sb.WriteString(fmt.Sprintf("- Total Trades: %d\n", ctx.TotalTrades)) + sb.WriteString(fmt.Sprintf("- Win Rate: %.1f%%\n", float64(ctx.WinningTrades)/float64(max(ctx.TotalTrades, 1))*100)) + sb.WriteString(fmt.Sprintf("- Max Drawdown: %.2f%%\n", ctx.MaxDrawdown)) + sb.WriteString(fmt.Sprintf("- Daily PnL: $%.2f\n", ctx.DailyPnL)) + sb.WriteString("\n") + + sb.WriteString("## Please analyze the data above and make grid trading decisions\n") + sb.WriteString("Output a JSON array of decisions.\n") + + return sb.String() +} + +// ============================================================================ +// Grid Decision Functions +// ============================================================================ + +// GetGridDecisions gets AI decisions for grid trading +func GetGridDecisions(ctx *GridContext, mcpClient mcp.AIClient, config *store.GridStrategyConfig, lang string) (*FullDecision, error) { + startTime := time.Now() + + // Build prompts + systemPrompt := BuildGridSystemPrompt(config, lang) + userPrompt := BuildGridUserPrompt(ctx, lang) + + logger.Infof("🤖 [Grid] Calling AI for grid decisions...") + + // Call AI + response, err := mcpClient.CallWithMessages(systemPrompt, userPrompt) + if err != nil { + return nil, fmt.Errorf("AI call failed: %w", err) + } + + // Parse decisions from response + decisions, err := parseGridDecisions(response, ctx.Symbol) + if err != nil { + logger.Warnf("Failed to parse grid decisions: %v", err) + // Return hold decision as fallback + decisions = []Decision{{ + Symbol: ctx.Symbol, + Action: "hold", + Confidence: 50, + Reasoning: "Failed to parse AI response, holding current state", + }} + } + + duration := time.Since(startTime).Milliseconds() + logger.Infof("⏱️ [Grid] AI call duration: %d ms, decisions: %d", duration, len(decisions)) + + // Extract chain of thought from response + cotTrace := extractCoTTrace(response) + + return &FullDecision{ + SystemPrompt: systemPrompt, + UserPrompt: userPrompt, + CoTTrace: cotTrace, + Decisions: decisions, + RawResponse: response, + AIRequestDurationMs: duration, + Timestamp: time.Now(), + }, nil +} + +// parseGridDecisions parses AI response into grid decisions +func parseGridDecisions(response string, symbol string) ([]Decision, error) { + // Try to find JSON array in response + jsonStr := extractJSONArray(response) + if jsonStr == "" { + return nil, fmt.Errorf("no JSON array found in response") + } + + var decisions []Decision + if err := json.Unmarshal([]byte(jsonStr), &decisions); err != nil { + return nil, fmt.Errorf("failed to parse JSON: %w", err) + } + + // Validate and set default symbol + for i := range decisions { + if decisions[i].Symbol == "" { + decisions[i].Symbol = symbol + } + // Validate action + if !isValidGridAction(decisions[i].Action) { + logger.Warnf("Invalid grid action: %s", decisions[i].Action) + } + } + + return decisions, nil +} + +// extractJSONArray extracts JSON array from AI response +func extractJSONArray(response string) string { + // Try to find ```json code block first + matches := reJSONFence.FindStringSubmatch(response) + if len(matches) > 1 { + return matches[1] + } + + // Try to find raw JSON array + matches = reJSONArray.FindStringSubmatch(response) + if len(matches) > 0 { + return matches[0] + } + + return "" +} + +// isValidGridAction checks if action is a valid grid action +func isValidGridAction(action string) bool { + validActions := map[string]bool{ + "place_buy_limit": true, + "place_sell_limit": true, + "cancel_order": true, + "cancel_all_orders": true, + "pause_grid": true, + "resume_grid": true, + "adjust_grid": true, + "hold": true, + // Also support standard actions for compatibility + "open_long": true, + "open_short": true, + "close_long": true, + "close_short": true, + } + return validActions[action] +} + +// ============================================================================ +// Grid Context Builder Helpers +// ============================================================================ + +// BuildGridContextFromMarketData builds grid context from market data +func BuildGridContextFromMarketData(mktData *market.Data, config *store.GridStrategyConfig) *GridContext { + ctx := &GridContext{ + Symbol: config.Symbol, + CurrentTime: time.Now().Format("2006-01-02 15:04:05"), + CurrentPrice: mktData.CurrentPrice, + + // Grid config + GridCount: config.GridCount, + TotalInvestment: config.TotalInvestment, + Leverage: config.Leverage, + Distribution: config.Distribution, + + // Market data + PriceChange1h: mktData.PriceChange1h, + PriceChange4h: mktData.PriceChange4h, + FundingRate: mktData.FundingRate, + } + + // Extract indicators from timeframe data + if mktData.TimeframeData != nil { + if tf5m, ok := mktData.TimeframeData["5m"]; ok { + if len(tf5m.BOLLUpper) > 0 { + ctx.BollingerUpper = tf5m.BOLLUpper[len(tf5m.BOLLUpper)-1] + ctx.BollingerMiddle = tf5m.BOLLMiddle[len(tf5m.BOLLMiddle)-1] + ctx.BollingerLower = tf5m.BOLLLower[len(tf5m.BOLLLower)-1] + if ctx.BollingerMiddle > 0 { + ctx.BollingerWidth = (ctx.BollingerUpper - ctx.BollingerLower) / ctx.BollingerMiddle * 100 + } + } + ctx.ATR14 = tf5m.ATR14 + if len(tf5m.RSI14Values) > 0 { + ctx.RSI14 = tf5m.RSI14Values[len(tf5m.RSI14Values)-1] + } + } + } + + // Extract longer term context + if mktData.LongerTermContext != nil { + if ctx.ATR14 == 0 { + ctx.ATR14 = mktData.LongerTermContext.ATR14 + } + ctx.EMA50 = mktData.LongerTermContext.EMA50 + } + + ctx.EMA20 = mktData.CurrentEMA20 + ctx.MACD = mktData.CurrentMACD + + // Calculate EMA distance + if ctx.EMA50 > 0 { + ctx.EMADistance = (ctx.EMA20 - ctx.EMA50) / ctx.EMA50 * 100 + } + + return ctx +} + +// Helper function for max +func max(a, b int) int { + if a > b { + return a + } + return b +} diff --git a/manager/trader_manager.go b/manager/trader_manager.go index 8e39670e..4060a794 100644 --- a/manager/trader_manager.go +++ b/manager/trader_manager.go @@ -292,8 +292,8 @@ func (tm *TraderManager) getConcurrentTraderData(traders []*trader.AutoTrader) [ // Concurrently fetch data for each trader for i, t := range traders { go func(index int, trader *trader.AutoTrader) { - // Set timeout to 3 seconds for single trader - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + // Set timeout to 10 seconds for single trader (increased from 3s for DEX reliability) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() // Use channel for timeout control @@ -330,7 +330,7 @@ func (tm *TraderManager) getConcurrentTraderData(traders []*trader.AutoTrader) [ } case err := <-errorChan: // Failed to get account info - logger.Infof("⚠️ Failed to get account info for trader %s: %v", trader.GetID(), err) + logger.Infof("⚠️ Failed to get account info for trader %s (%s/%s): %v", trader.GetName(), trader.GetID(), trader.GetExchange(), err) traderData = map[string]interface{}{ "trader_id": trader.GetID(), "trader_name": trader.GetName(), @@ -347,7 +347,7 @@ func (tm *TraderManager) getConcurrentTraderData(traders []*trader.AutoTrader) [ } case <-ctx.Done(): // Timeout - logger.Infof("⏰ Timeout getting account info for trader %s", trader.GetID()) + logger.Infof("⏰ Timeout (10s) getting account info for trader %s (%s/%s)", trader.GetName(), trader.GetID(), trader.GetExchange()) traderData = map[string]interface{}{ "trader_id": trader.GetID(), "trader_name": trader.GetName(), diff --git a/market/data.go b/market/data.go index 1fa990c0..a993736a 100644 --- a/market/data.go +++ b/market/data.go @@ -1210,3 +1210,91 @@ func ExportCalculateATR(klines []Kline, period int) float64 { func ExportCalculateBOLL(klines []Kline, period int, multiplier float64) (upper, middle, lower float64) { return calculateBOLL(klines, period, multiplier) } + +// calculateDonchian calculates Donchian channel (highest high, lowest low) for given period +func calculateDonchian(klines []Kline, period int) (upper, lower float64) { + if len(klines) == 0 || period <= 0 { + return 0, 0 + } + + // Use all available klines if period > len(klines) + start := len(klines) - period + if start < 0 { + start = 0 + } + + upper = klines[start].High + lower = klines[start].Low + + for i := start + 1; i < len(klines); i++ { + if klines[i].High > upper { + upper = klines[i].High + } + if klines[i].Low < lower { + lower = klines[i].Low + } + } + + return upper, lower +} + +// ExportCalculateDonchian exports calculateDonchian for testing +func ExportCalculateDonchian(klines []Kline, period int) (float64, float64) { + return calculateDonchian(klines, period) +} + +// Box period constants (in 1h candles) +const ( + ShortBoxPeriod = 72 // 3 days of 1h candles + MidBoxPeriod = 240 // 10 days of 1h candles + LongBoxPeriod = 500 // ~21 days of 1h candles +) + +// calculateBoxData calculates multi-period box data from klines +func calculateBoxData(klines []Kline, currentPrice float64) *BoxData { + box := &BoxData{ + CurrentPrice: currentPrice, + } + + if len(klines) == 0 { + return box + } + + box.ShortUpper, box.ShortLower = calculateDonchian(klines, ShortBoxPeriod) + box.MidUpper, box.MidLower = calculateDonchian(klines, MidBoxPeriod) + box.LongUpper, box.LongLower = calculateDonchian(klines, LongBoxPeriod) + + return box +} + +// ExportCalculateBoxData exports calculateBoxData for testing +func ExportCalculateBoxData(klines []Kline, currentPrice float64) *BoxData { + return calculateBoxData(klines, currentPrice) +} + +// GetBoxData fetches 1h klines and calculates box data for a symbol +func GetBoxData(symbol string) (*BoxData, error) { + symbol = Normalize(symbol) + + // Fetch 500 1h klines + var klines []Kline + var err error + + if IsXyzDexAsset(symbol) { + klines, err = getKlinesFromHyperliquid(symbol, "1h", LongBoxPeriod) + } else { + klines, err = getKlinesFromCoinAnk(symbol, "1h", LongBoxPeriod) + } + + if err != nil { + return nil, fmt.Errorf("failed to get 1h klines: %w", err) + } + + if len(klines) == 0 { + return nil, fmt.Errorf("no kline data available") + } + + currentPrice := klines[len(klines)-1].Close + + return calculateBoxData(klines, currentPrice), nil +} diff --git a/market/data_test.go b/market/data_test.go index b20a336c..231c4806 100644 --- a/market/data_test.go +++ b/market/data_test.go @@ -500,3 +500,86 @@ func TestIsStaleData_EmptyKlines(t *testing.T) { t.Error("Expected false for empty klines, got true") } } + +func TestCalculateDonchian(t *testing.T) { + // Create test klines with known high/low values + klines := []Kline{ + {High: 100, Low: 90}, + {High: 105, Low: 88}, + {High: 102, Low: 92}, + {High: 108, Low: 85}, + {High: 103, Low: 91}, + } + + upper, lower := ExportCalculateDonchian(klines, 5) + + if upper != 108 { + t.Errorf("Expected upper = 108, got %v", upper) + } + if lower != 85 { + t.Errorf("Expected lower = 85, got %v", lower) + } +} + +func TestCalculateDonchian_PartialPeriod(t *testing.T) { + klines := []Kline{ + {High: 100, Low: 90}, + {High: 105, Low: 88}, + } + + upper, lower := ExportCalculateDonchian(klines, 10) + + // Should use all available klines when period > len(klines) + if upper != 105 { + t.Errorf("Expected upper = 105, got %v", upper) + } + if lower != 88 { + t.Errorf("Expected lower = 88, got %v", lower) + } +} + +func TestCalculateDonchian_InvalidPeriod(t *testing.T) { + klines := []Kline{ + {High: 100, Low: 90}, + } + + // Zero period should return (0, 0) + upper, lower := ExportCalculateDonchian(klines, 0) + if upper != 0 || lower != 0 { + t.Errorf("Expected (0, 0) for zero period, got (%v, %v)", upper, lower) + } + + // Negative period should return (0, 0) + upper, lower = ExportCalculateDonchian(klines, -1) + if upper != 0 || lower != 0 { + t.Errorf("Expected (0, 0) for negative period, got (%v, %v)", upper, lower) + } +} + +func TestCalculateBoxData(t *testing.T) { + // Create synthetic kline data + klines := make([]Kline, 500) + for i := 0; i < 500; i++ { + basePrice := 100.0 + klines[i] = Kline{ + High: basePrice + float64(i%10), + Low: basePrice - float64(i%10), + Close: basePrice, + } + } + + box := ExportCalculateBoxData(klines, 100.0) + + if box.ShortUpper == 0 || box.ShortLower == 0 { + t.Error("Short box should not be zero") + } + if box.MidUpper == 0 || box.MidLower == 0 { + t.Error("Mid box should not be zero") + } + if box.LongUpper == 0 || box.LongLower == 0 { + t.Error("Long box should not be zero") + } + if box.CurrentPrice != 100.0 { + t.Errorf("Expected CurrentPrice = 100.0, got %v", box.CurrentPrice) + } +} diff --git a/market/types.go b/market/types.go index 1ca71a68..7569c9f3 100644 --- a/market/types.go +++ b/market/types.go @@ -187,3 +187,42 @@ var config = Config{ }, UpdateInterval: 60, // 1 minute } + +// BoxData represents multi-period Donchian channel (box) data +type BoxData struct { + // Short-term box (72 1h candles = 3 days) + ShortUpper float64 `json:"short_upper"` + ShortLower float64 `json:"short_lower"` + + // Mid-term box (240 1h candles = 10 days) + MidUpper float64 `json:"mid_upper"` + MidLower float64 `json:"mid_lower"` + + // Long-term box (500 1h candles = ~21 days) + LongUpper float64 `json:"long_upper"` + LongLower float64 `json:"long_lower"` + + // Current price position relative to boxes + CurrentPrice float64 `json:"current_price"` +} + +// RegimeLevel represents the ranging classification level +type RegimeLevel string + +const ( + RegimeLevelNarrow RegimeLevel = "narrow" // 窄幅震荡 + RegimeLevelStandard RegimeLevel = "standard" // 标准震荡 + RegimeLevelWide RegimeLevel = "wide" // 宽幅震荡 + RegimeLevelVolatile RegimeLevel = "volatile" // 剧烈震荡 + RegimeLevelTrending RegimeLevel = "trending" // 趋势 +) + +// BreakoutLevel represents which box level has been broken +type BreakoutLevel string + +const ( + BreakoutNone BreakoutLevel = "none" + BreakoutShort BreakoutLevel = "short" + BreakoutMid BreakoutLevel = "mid" + BreakoutLong BreakoutLevel = "long" +) diff --git a/scripts/test_lighter_orders.go b/scripts/test_lighter_orders.go new file mode 100644 index 00000000..e064cac2 --- /dev/null +++ b/scripts/test_lighter_orders.go @@ -0,0 +1,168 @@ +//go:build ignore + +// Test script to verify Lighter API authentication +// Run: go run scripts/test_lighter_orders.go +package main + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "time" + + lighterClient "github.com/elliottech/lighter-go/client" + lighterHTTP "github.com/elliottech/lighter-go/client/http" +) + +func main() { + // Configuration - update these values + walletAddr := os.Getenv("LIGHTER_WALLET") + apiKeyPrivateKey := os.Getenv("LIGHTER_API_KEY") + + if walletAddr == "" || apiKeyPrivateKey == "" { + fmt.Println("Usage: LIGHTER_WALLET=0x... LIGHTER_API_KEY=... go run scripts/test_lighter_orders.go") + fmt.Println("Environment variables required:") + fmt.Println(" LIGHTER_WALLET - Ethereum wallet address") + fmt.Println(" LIGHTER_API_KEY - API key private key (40 bytes hex)") + os.Exit(1) + } + + fmt.Println("=== Lighter API Test ===") + fmt.Printf("Wallet: %s\n\n", walletAddr) + + baseURL := "https://mainnet.zklighter.elliot.ai" + chainID := uint32(304) + client := &http.Client{Timeout: 30 * time.Second} + + // Step 1: Get account info (no auth required) + fmt.Println("1. Getting account info...") + accountIndex, err := getAccountIndex(client, baseURL, walletAddr) + if err != nil { + fmt.Printf(" FAILED: %v\n", err) + os.Exit(1) + } + fmt.Printf(" OK: account_index = %d\n\n", accountIndex) + + // Step 2: Create TxClient and generate auth token + fmt.Println("2. Creating TxClient and generating auth token...") + httpClient := lighterHTTP.NewClient(baseURL) + txClient, err := lighterClient.NewTxClient(httpClient, apiKeyPrivateKey, accountIndex, 0, chainID) + if err != nil { + fmt.Printf(" FAILED: %v\n", err) + os.Exit(1) + } + + authToken, err := txClient.GetAuthToken(time.Now().Add(1 * time.Hour)) + if err != nil { + fmt.Printf(" FAILED: %v\n", err) + os.Exit(1) + } + fmt.Printf(" OK: auth token generated\n\n") + + // Step 3: Test GetActiveOrders with auth query parameter (NEW method) + fmt.Println("3. Testing GetActiveOrders with auth query parameter (FIXED)...") + encodedAuth := url.QueryEscape(authToken) + endpoint := fmt.Sprintf("%s/api/v1/accountActiveOrders?account_index=%d&market_id=0&auth=%s", + baseURL, accountIndex, encodedAuth) + + resp, err := client.Get(endpoint) + if err != nil { + fmt.Printf(" FAILED: %v\n", err) + os.Exit(1) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + var result map[string]interface{} + json.Unmarshal(body, &result) + + if code, ok := result["code"].(float64); ok && code == 200 { + orders := result["orders"].([]interface{}) + fmt.Printf(" OK: Retrieved %d orders\n", len(orders)) + if len(orders) > 0 { + fmt.Println(" Sample orders:") + for i, o := range orders { + if i >= 3 { + fmt.Printf(" ... and %d more\n", len(orders)-3) + break + } + order := o.(map[string]interface{}) + fmt.Printf(" - ID: %v, Price: %v, Side: %v\n", + order["order_id"], order["price"], order["is_ask"]) + } + } + } else { + fmt.Printf(" FAILED: %s\n", string(body)) + fmt.Println("\n Possible causes:") + fmt.Println(" - API key not registered on-chain") + fmt.Println(" - API key private key incorrect") + fmt.Println(" - Account index mismatch") + os.Exit(1) + } + + // Step 4: Test GetActiveOrders with Authorization header (OLD method - for comparison) + fmt.Println("\n4. Testing GetActiveOrders with Authorization header (OLD method)...") + endpoint2 := fmt.Sprintf("%s/api/v1/accountActiveOrders?account_index=%d&market_id=0", + baseURL, accountIndex) + + req, _ := http.NewRequest("GET", endpoint2, nil) + req.Header.Set("Authorization", authToken) + req.Header.Set("Content-Type", "application/json") + + resp2, err := client.Do(req) + if err != nil { + fmt.Printf(" FAILED: %v\n", err) + } else { + defer resp2.Body.Close() + body2, _ := io.ReadAll(resp2.Body) + var result2 map[string]interface{} + json.Unmarshal(body2, &result2) + + if code, ok := result2["code"].(float64); ok && code == 200 { + orders := result2["orders"].([]interface{}) + fmt.Printf(" OK: Retrieved %d orders (both methods work!)\n", len(orders)) + } else { + fmt.Printf(" FAILED: %s\n", string(body2)) + fmt.Println(" ^ This is expected - Authorization header doesn't work consistently") + } + } + + fmt.Println("\n=== TEST COMPLETE ===") + fmt.Println("If test 3 passed, the fix is working correctly.") +} + +func getAccountIndex(client *http.Client, baseURL, walletAddr string) (int64, error) { + endpoint := fmt.Sprintf("%s/api/v1/account?by=l1_address&value=%s", baseURL, walletAddr) + resp, err := client.Get(endpoint) + if err != nil { + return 0, err + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + var result struct { + Code int `json:"code"` + Accounts []struct { + AccountIndex int64 `json:"account_index"` + } `json:"accounts"` + SubAccounts []struct { + AccountIndex int64 `json:"account_index"` + } `json:"sub_accounts"` + } + + if err := json.Unmarshal(body, &result); err != nil { + return 0, fmt.Errorf("failed to parse: %w", err) + } + + if len(result.Accounts) > 0 { + return result.Accounts[0].AccountIndex, nil + } + if len(result.SubAccounts) > 0 { + return result.SubAccounts[0].AccountIndex, nil + } + + return 0, fmt.Errorf("no account found") +} diff --git a/store/grid.go b/store/grid.go new file mode 100644 index 00000000..49ce5708 --- /dev/null +++ b/store/grid.go @@ -0,0 +1,585 @@ +package store + +import ( + "fmt" + "time" + + "gorm.io/gorm" +) + +// ==================== Grid Store Models ==================== +// These models mirror the grid package types but are defined here +// to avoid import cycles between store and grid packages. + +// GridConfigModel GORM model for grid_configs table +type GridConfigModel struct { + ID string `json:"id" gorm:"primaryKey"` + UserID string `json:"user_id" gorm:"index"` + TraderID string `json:"trader_id" gorm:"index"` + Symbol string `json:"symbol" gorm:"not null"` + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` + + GridCount int `json:"grid_count" gorm:"default:10"` + TotalInvestment float64 `json:"total_investment" gorm:"not null"` + Leverage int `json:"leverage" gorm:"default:5"` + UpperPrice float64 `json:"upper_price"` + LowerPrice float64 `json:"lower_price"` + UseATRBounds bool `json:"use_atr_bounds" gorm:"default:true"` + ATRMultiplier float64 `json:"atr_multiplier" gorm:"default:2.0"` + Distribution string `json:"distribution" gorm:"default:gaussian"` + + MaxDrawdownPct float64 `json:"max_drawdown_pct" gorm:"default:15.0"` + StopLossPct float64 `json:"stop_loss_pct" gorm:"default:5.0"` + DailyLossLimitPct float64 `json:"daily_loss_limit_pct" gorm:"default:10"` + MaxPositionSizePct float64 `json:"max_position_size_pct" gorm:"default:30"` + + RegimeCheckInterval int `json:"regime_check_interval" gorm:"default:30"` + AutoPauseOnTrend bool `json:"auto_pause_on_trend" gorm:"default:true"` + MinRangingScore int `json:"min_ranging_score" gorm:"default:60"` + TrendResumeThreshold int `json:"trend_resume_threshold" gorm:"default:70"` + + // Box indicator periods (1h candles) + ShortBoxPeriod int `json:"short_box_period" gorm:"default:72"` // 3 days + MidBoxPeriod int `json:"mid_box_period" gorm:"default:240"` // 10 days + LongBoxPeriod int `json:"long_box_period" gorm:"default:500"` // 21 days + + // Effective leverage limits by regime level + NarrowRegimeLeverage int `json:"narrow_regime_leverage" gorm:"default:2"` + StandardRegimeLeverage int `json:"standard_regime_leverage" gorm:"default:4"` + WideRegimeLeverage int `json:"wide_regime_leverage" gorm:"default:3"` + VolatileRegimeLeverage int `json:"volatile_regime_leverage" gorm:"default:2"` + + // Position limits by regime level (percentage of total investment) + NarrowRegimePositionPct float64 `json:"narrow_regime_position_pct" gorm:"default:40"` + StandardRegimePositionPct float64 `json:"standard_regime_position_pct" gorm:"default:70"` + WideRegimePositionPct float64 `json:"wide_regime_position_pct" gorm:"default:60"` + VolatileRegimePositionPct float64 `json:"volatile_regime_position_pct" gorm:"default:40"` + + OrderRefreshSec int `json:"order_refresh_sec" gorm:"default:300"` + UseMakerOnly bool `json:"use_maker_only" gorm:"default:true"` + SlippageTolerPct float64 `json:"slippage_toler_pct" gorm:"default:0.1"` + + AIProvider string `json:"ai_provider" gorm:"default:deepseek"` + AIModel string `json:"ai_model" gorm:"default:deepseek-chat"` + IsActive bool `json:"is_active" gorm:"default:false"` +} + +func (GridConfigModel) TableName() string { + return "grid_configs" +} + +// GridInstanceModel GORM model for grid_instances table +type GridInstanceModel struct { + ID string `json:"id" gorm:"primaryKey"` + ConfigID string `json:"config_id" gorm:"index;not null"` + Symbol string `json:"symbol" gorm:"not null"` + State string `json:"state" gorm:"not null"` + StartedAt time.Time `json:"started_at"` + StoppedAt *time.Time `json:"stopped_at,omitempty"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` + + CurrentUpperPrice float64 `json:"current_upper_price"` + CurrentLowerPrice float64 `json:"current_lower_price"` + CurrentGridSpacing float64 `json:"current_grid_spacing"` + ActiveLevelCount int `json:"active_level_count"` + CurrentRegime string `json:"current_regime"` + RegimeScore int `json:"regime_score"` + LastRegimeCheck time.Time `json:"last_regime_check"` + ConsecutiveTrending int `json:"consecutive_trending"` + + // Current regime level (narrow/standard/wide/volatile/trending) + CurrentRegimeLevel string `json:"current_regime_level" gorm:"default:standard"` + + // Box state + ShortBoxUpper float64 `json:"short_box_upper"` + ShortBoxLower float64 `json:"short_box_lower"` + MidBoxUpper float64 `json:"mid_box_upper"` + MidBoxLower float64 `json:"mid_box_lower"` + LongBoxUpper float64 `json:"long_box_upper"` + LongBoxLower float64 `json:"long_box_lower"` + + // Breakout state + BreakoutLevel string `json:"breakout_level" gorm:"default:none"` // none/short/mid/long + BreakoutDirection string `json:"breakout_direction"` // up/down + BreakoutConfirmCount int `json:"breakout_confirm_count" gorm:"default:0"` + BreakoutStartTime time.Time `json:"breakout_start_time"` + + // Position adjustment due to breakout + PositionReductionPct float64 `json:"position_reduction_pct" gorm:"default:0"` // 0 = normal, 50 = reduced + + TotalProfit float64 `json:"total_profit" gorm:"default:0"` + TotalFees float64 `json:"total_fees" gorm:"default:0"` + TotalTrades int `json:"total_trades" gorm:"default:0"` + WinningTrades int `json:"winning_trades" gorm:"default:0"` + MaxDrawdown float64 `json:"max_drawdown" gorm:"default:0"` + CurrentDrawdown float64 `json:"current_drawdown" gorm:"default:0"` + PeakEquity float64 `json:"peak_equity" gorm:"default:0"` + DailyProfit float64 `json:"daily_profit" gorm:"default:0"` + DailyLoss float64 `json:"daily_loss" gorm:"default:0"` + LastDailyReset time.Time `json:"last_daily_reset"` +} + +func (GridInstanceModel) TableName() string { + return "grid_instances" +} + +// GridLevelModel GORM model for grid_levels table +type GridLevelModel struct { + ID string `json:"id" gorm:"primaryKey"` + InstanceID string `json:"instance_id" gorm:"index;not null"` + LevelIndex int `json:"level_index" gorm:"not null"` + Price float64 `json:"price" gorm:"not null"` + State string `json:"state" gorm:"not null"` + Side string `json:"side"` + OrderID string `json:"order_id,omitempty"` + OrderPrice float64 `json:"order_price,omitempty"` + OrderQuantity float64 `json:"order_quantity,omitempty"` + OrderCreatedAt *time.Time `json:"order_created_at,omitempty"` + PositionSize float64 `json:"position_size,omitempty"` + PositionEntry float64 `json:"position_entry,omitempty"` + PositionOpenAt *time.Time `json:"position_open_at,omitempty"` + AllocationWeight float64 `json:"allocation_weight"` + AllocatedUSD float64 `json:"allocated_usd"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` +} + +func (GridLevelModel) TableName() string { + return "grid_levels" +} + +// GridEventModel GORM model for grid_events table +type GridEventModel struct { + ID string `json:"id" gorm:"primaryKey"` + InstanceID string `json:"instance_id" gorm:"index;not null"` + LevelID string `json:"level_id,omitempty" gorm:"index"` + EventType string `json:"event_type" gorm:"not null"` + EventTime time.Time `json:"event_time" gorm:"autoCreateTime"` + Price float64 `json:"price,omitempty"` + Quantity float64 `json:"quantity,omitempty"` + Side string `json:"side,omitempty"` + PnL float64 `json:"pnl,omitempty"` + Fee float64 `json:"fee,omitempty"` + Message string `json:"message,omitempty"` + OldRegime string `json:"old_regime,omitempty"` + NewRegime string `json:"new_regime,omitempty"` + TriggerType string `json:"trigger_type,omitempty"` + RawData string `json:"raw_data,omitempty" gorm:"type:text"` +} + +func (GridEventModel) TableName() string { + return "grid_events" +} + +// GridRegimeAssessmentModel GORM model for grid_regime_assessments table +type GridRegimeAssessmentModel struct { + ID string `json:"id" gorm:"primaryKey"` + InstanceID string `json:"instance_id" gorm:"index;not null"` + AssessedAt time.Time `json:"assessed_at" gorm:"autoCreateTime"` + Regime string `json:"regime" gorm:"not null"` + Score int `json:"score" gorm:"not null"` + Confidence float64 `json:"confidence"` + BollingerSignal int `json:"bollinger_signal"` + EMASignal int `json:"ema_signal"` + MACDSignal int `json:"macd_signal"` + VolumeSignal int `json:"volume_signal"` + OISignal int `json:"oi_signal"` + FundingSignal int `json:"funding_signal"` + CandleSignal int `json:"candle_signal"` + ATR14 float64 `json:"atr14"` + BollingerWidth float64 `json:"bollinger_width"` + EMADistance float64 `json:"ema_distance"` + CurrentPrice float64 `json:"current_price"` + AIReasoning string `json:"ai_reasoning" gorm:"type:text"` +} + +func (GridRegimeAssessmentModel) TableName() string { + return "grid_regime_assessments" +} + +// ==================== Grid Store ==================== + +// GridStore provides database operations for grid trading +type GridStore struct { + db *gorm.DB +} + +// NewGridStore creates a new grid store +func NewGridStore(db *gorm.DB) *GridStore { + return &GridStore{db: db} +} + +// InitTables initializes grid-related tables +func (s *GridStore) InitTables() error { + // For PostgreSQL with existing tables, skip AutoMigrate to avoid type conflicts + if s.db.Dialector.Name() == "postgres" { + var tableExists int64 + s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'grid_configs'`).Scan(&tableExists) + + if tableExists > 0 { + // Tables exist, just ensure indexes + s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_grid_configs_user_id ON grid_configs(user_id)`) + s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_grid_configs_trader_id ON grid_configs(trader_id)`) + s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_grid_instances_config_id ON grid_instances(config_id)`) + s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_grid_levels_instance_id ON grid_levels(instance_id)`) + s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_grid_events_instance_id ON grid_events(instance_id)`) + s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_grid_events_level_id ON grid_events(level_id)`) + s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_grid_regime_assessments_instance_id ON grid_regime_assessments(instance_id)`) + return nil + } + } + + // AutoMigrate all grid tables + if err := s.db.AutoMigrate( + &GridConfigModel{}, + &GridInstanceModel{}, + &GridLevelModel{}, + &GridEventModel{}, + &GridRegimeAssessmentModel{}, + ); err != nil { + return fmt.Errorf("failed to migrate grid tables: %w", err) + } + + return nil +} + +// ==================== Config Operations ==================== + +// SaveGridConfig saves or updates a grid configuration +func (s *GridStore) SaveGridConfig(config *GridConfigModel) error { + config.UpdatedAt = time.Now() + if config.CreatedAt.IsZero() { + config.CreatedAt = time.Now() + } + return s.db.Save(config).Error +} + +// LoadGridConfig loads a grid configuration by ID +func (s *GridStore) LoadGridConfig(id string) (*GridConfigModel, error) { + var config GridConfigModel + err := s.db.Where("id = ?", id).First(&config).Error + if err != nil { + return nil, err + } + return &config, nil +} + +// LoadGridConfigByTrader loads a grid configuration by trader ID +func (s *GridStore) LoadGridConfigByTrader(traderID string) (*GridConfigModel, error) { + var config GridConfigModel + err := s.db.Where("trader_id = ? AND is_active = true", traderID).First(&config).Error + if err != nil { + return nil, err + } + return &config, nil +} + +// ListGridConfigs lists all grid configurations for a user +func (s *GridStore) ListGridConfigs(userID string) ([]GridConfigModel, error) { + var configs []GridConfigModel + err := s.db.Where("user_id = ?", userID).Order("created_at DESC").Find(&configs).Error + if err != nil { + return nil, err + } + return configs, nil +} + +// DeleteGridConfig deletes a grid configuration and all related data +func (s *GridStore) DeleteGridConfig(id string) error { + return s.db.Transaction(func(tx *gorm.DB) error { + // Get all instances for this config + var instances []GridInstanceModel + if err := tx.Where("config_id = ?", id).Find(&instances).Error; err != nil { + return err + } + + // Delete related data for each instance + for _, instance := range instances { + if err := tx.Where("instance_id = ?", instance.ID).Delete(&GridLevelModel{}).Error; err != nil { + return err + } + if err := tx.Where("instance_id = ?", instance.ID).Delete(&GridEventModel{}).Error; err != nil { + return err + } + if err := tx.Where("instance_id = ?", instance.ID).Delete(&GridRegimeAssessmentModel{}).Error; err != nil { + return err + } + } + + // Delete instances + if err := tx.Where("config_id = ?", id).Delete(&GridInstanceModel{}).Error; err != nil { + return err + } + + // Delete config + return tx.Where("id = ?", id).Delete(&GridConfigModel{}).Error + }) +} + +// ==================== Instance Operations ==================== + +// SaveGridInstance saves or updates a grid instance +func (s *GridStore) SaveGridInstance(instance *GridInstanceModel) error { + instance.UpdatedAt = time.Now() + return s.db.Save(instance).Error +} + +// LoadGridInstance loads a grid instance by config ID +func (s *GridStore) LoadGridInstance(configID string) (*GridInstanceModel, error) { + var instance GridInstanceModel + err := s.db.Where("config_id = ?", configID). + Order("started_at DESC"). + First(&instance).Error + if err != nil { + return nil, err + } + return &instance, nil +} + +// LoadGridInstanceByID loads a grid instance by ID +func (s *GridStore) LoadGridInstanceByID(id string) (*GridInstanceModel, error) { + var instance GridInstanceModel + err := s.db.Where("id = ?", id).First(&instance).Error + if err != nil { + return nil, err + } + return &instance, nil +} + +// ListGridInstances lists all instances for a config +func (s *GridStore) ListGridInstances(configID string) ([]GridInstanceModel, error) { + var instances []GridInstanceModel + err := s.db.Where("config_id = ?", configID). + Order("started_at DESC"). + Find(&instances).Error + if err != nil { + return nil, err + } + return instances, nil +} + +// ==================== Level Operations ==================== + +// SaveGridLevel saves or updates a grid level +func (s *GridStore) SaveGridLevel(level *GridLevelModel) error { + level.UpdatedAt = time.Now() + return s.db.Save(level).Error +} + +// SaveGridLevels saves multiple grid levels +func (s *GridStore) SaveGridLevels(levels []GridLevelModel) error { + if len(levels) == 0 { + return nil + } + now := time.Now() + for i := range levels { + levels[i].UpdatedAt = now + } + return s.db.Save(&levels).Error +} + +// LoadGridLevels loads all levels for an instance +func (s *GridStore) LoadGridLevels(instanceID string) ([]GridLevelModel, error) { + var levels []GridLevelModel + err := s.db.Where("instance_id = ?", instanceID). + Order("level_index ASC"). + Find(&levels).Error + if err != nil { + return nil, err + } + return levels, nil +} + +// DeleteGridLevels deletes all levels for an instance +func (s *GridStore) DeleteGridLevels(instanceID string) error { + return s.db.Where("instance_id = ?", instanceID).Delete(&GridLevelModel{}).Error +} + +// ==================== Event Operations ==================== + +// SaveGridEvent saves a grid event +func (s *GridStore) SaveGridEvent(event *GridEventModel) error { + if event.EventTime.IsZero() { + event.EventTime = time.Now() + } + return s.db.Create(event).Error +} + +// LoadRecentGridEvents loads recent events for an instance +func (s *GridStore) LoadRecentGridEvents(instanceID string, limit int) ([]GridEventModel, error) { + var events []GridEventModel + query := s.db.Where("instance_id = ?", instanceID). + Order("event_time DESC") + if limit > 0 { + query = query.Limit(limit) + } + err := query.Find(&events).Error + if err != nil { + return nil, err + } + return events, nil +} + +// LoadGridEventsByType loads events of a specific type +func (s *GridStore) LoadGridEventsByType(instanceID, eventType string, limit int) ([]GridEventModel, error) { + var events []GridEventModel + query := s.db.Where("instance_id = ? AND event_type = ?", instanceID, eventType). + Order("event_time DESC") + if limit > 0 { + query = query.Limit(limit) + } + err := query.Find(&events).Error + if err != nil { + return nil, err + } + return events, nil +} + +// CountGridEvents counts events for an instance +func (s *GridStore) CountGridEvents(instanceID string) (int64, error) { + var count int64 + err := s.db.Model(&GridEventModel{}). + Where("instance_id = ?", instanceID). + Count(&count).Error + return count, err +} + +// ==================== Regime Assessment Operations ==================== + +// SaveGridRegimeAssessment saves a regime assessment +func (s *GridStore) SaveGridRegimeAssessment(assessment *GridRegimeAssessmentModel) error { + if assessment.AssessedAt.IsZero() { + assessment.AssessedAt = time.Now() + } + return s.db.Create(assessment).Error +} + +// LoadLatestGridRegime loads the latest regime assessment +func (s *GridStore) LoadLatestGridRegime(instanceID string) (*GridRegimeAssessmentModel, error) { + var assessment GridRegimeAssessmentModel + err := s.db.Where("instance_id = ?", instanceID). + Order("assessed_at DESC"). + First(&assessment).Error + if err != nil { + return nil, err + } + return &assessment, nil +} + +// LoadGridRegimeHistory loads regime assessment history +func (s *GridStore) LoadGridRegimeHistory(instanceID string, limit int) ([]GridRegimeAssessmentModel, error) { + var assessments []GridRegimeAssessmentModel + query := s.db.Where("instance_id = ?", instanceID). + Order("assessed_at DESC") + if limit > 0 { + query = query.Limit(limit) + } + err := query.Find(&assessments).Error + if err != nil { + return nil, err + } + return assessments, nil +} + +// ==================== Statistics Operations ==================== + +// GetGridInstanceStatistics returns statistics for an instance +func (s *GridStore) GetGridInstanceStatistics(instanceID string) (map[string]interface{}, error) { + var instance GridInstanceModel + if err := s.db.Where("id = ?", instanceID).First(&instance).Error; err != nil { + return nil, err + } + + // Count events by type + var eventCounts []struct { + EventType string + Count int64 + } + s.db.Model(&GridEventModel{}). + Select("event_type, count(*) as count"). + Where("instance_id = ?", instanceID). + Group("event_type"). + Find(&eventCounts) + + eventCountMap := make(map[string]int64) + for _, ec := range eventCounts { + eventCountMap[ec.EventType] = ec.Count + } + + // Get latest regime + var latestRegime GridRegimeAssessmentModel + s.db.Where("instance_id = ?", instanceID). + Order("assessed_at DESC"). + First(&latestRegime) + + winRate := 0.0 + if instance.TotalTrades > 0 { + winRate = float64(instance.WinningTrades) / float64(instance.TotalTrades) * 100 + } + + return map[string]interface{}{ + "instance_id": instance.ID, + "state": instance.State, + "started_at": instance.StartedAt, + "stopped_at": instance.StoppedAt, + "total_profit": instance.TotalProfit, + "total_fees": instance.TotalFees, + "total_trades": instance.TotalTrades, + "winning_trades": instance.WinningTrades, + "win_rate": winRate, + "max_drawdown": instance.MaxDrawdown, + "current_drawdown": instance.CurrentDrawdown, + "peak_equity": instance.PeakEquity, + "active_level_count": instance.ActiveLevelCount, + "current_regime": instance.CurrentRegime, + "regime_score": instance.RegimeScore, + "event_counts": eventCountMap, + "latest_regime_score": latestRegime.Score, + }, nil +} + +// GetGridPerformanceMetrics returns performance metrics for a time period +func (s *GridStore) GetGridPerformanceMetrics(instanceID string, from, to time.Time) (map[string]interface{}, error) { + // Count trades in period + var tradeCounts struct { + TotalFills int64 + BuyFills int64 + SellFills int64 + } + s.db.Model(&GridEventModel{}). + Select("count(*) as total_fills, "+ + "sum(case when side = 'buy' then 1 else 0 end) as buy_fills, "+ + "sum(case when side = 'sell' then 1 else 0 end) as sell_fills"). + Where("instance_id = ? AND event_type = 'order_filled' AND event_time BETWEEN ? AND ?", + instanceID, from, to). + Scan(&tradeCounts) + + // Sum profit/loss + var pnlSum struct { + TotalPnL float64 + TotalFee float64 + } + s.db.Model(&GridEventModel{}). + Select("coalesce(sum(pnl), 0) as total_pnl, coalesce(sum(fee), 0) as total_fee"). + Where("instance_id = ? AND event_time BETWEEN ? AND ?", instanceID, from, to). + Scan(&pnlSum) + + // Count regime changes + var regimeChanges int64 + s.db.Model(&GridEventModel{}). + Where("instance_id = ? AND event_type = 'regime_change' AND event_time BETWEEN ? AND ?", + instanceID, from, to). + Count(®imeChanges) + + return map[string]interface{}{ + "period_start": from, + "period_end": to, + "total_fills": tradeCounts.TotalFills, + "buy_fills": tradeCounts.BuyFills, + "sell_fills": tradeCounts.SellFills, + "total_pnl": pnlSum.TotalPnL, + "total_fees": pnlSum.TotalFee, + "net_pnl": pnlSum.TotalPnL - pnlSum.TotalFee, + "regime_changes": regimeChanges, + }, nil +} diff --git a/store/store.go b/store/store.go index 21c15813..8119b935 100644 --- a/store/store.go +++ b/store/store.go @@ -28,6 +28,7 @@ type Store struct { strategy *StrategyStore equity *EquityStore order *OrderStore + grid *GridStore mu sync.RWMutex } @@ -156,6 +157,9 @@ func (s *Store) initTables() error { if err := s.Order().InitTables(); err != nil { return fmt.Errorf("failed to initialize order tables: %w", err) } + if err := s.Grid().InitTables(); err != nil { + return fmt.Errorf("failed to initialize grid tables: %w", err) + } return nil } @@ -279,6 +283,16 @@ func (s *Store) Order() *OrderStore { return s.order } +// Grid gets grid trading storage +func (s *Store) Grid() *GridStore { + s.mu.Lock() + defer s.mu.Unlock() + if s.grid == nil { + s.grid = NewGridStore(s.gdb) + } + return s.grid +} + // Close closes database connection func (s *Store) Close() error { if s.driver != nil { diff --git a/store/strategy.go b/store/strategy.go index 1b3f5b11..be009851 100644 --- a/store/strategy.go +++ b/store/strategy.go @@ -32,6 +32,9 @@ func (Strategy) TableName() string { return "strategies" } // StrategyConfig strategy configuration details (JSON structure) type StrategyConfig struct { + // Strategy type: "ai_trading" (default) or "grid_trading" + StrategyType string `json:"strategy_type,omitempty"` + // language setting: "zh" for Chinese, "en" for English // This determines the language used for data formatting and prompt generation Language string `json:"language,omitempty"` @@ -45,6 +48,39 @@ type StrategyConfig struct { RiskControl RiskControlConfig `json:"risk_control"` // editable sections of System Prompt PromptSections PromptSectionsConfig `json:"prompt_sections,omitempty"` + + // Grid trading configuration (only used when StrategyType == "grid_trading") + GridConfig *GridStrategyConfig `json:"grid_config,omitempty"` +} + +// GridStrategyConfig grid trading specific configuration +type GridStrategyConfig struct { + // Trading pair (e.g., "BTCUSDT") + Symbol string `json:"symbol"` + // Number of grid levels (5-50) + GridCount int `json:"grid_count"` + // Total investment in USDT + TotalInvestment float64 `json:"total_investment"` + // Leverage (1-20) + Leverage int `json:"leverage"` + // Upper price boundary (0 = auto-calculate from ATR) + UpperPrice float64 `json:"upper_price"` + // Lower price boundary (0 = auto-calculate from ATR) + LowerPrice float64 `json:"lower_price"` + // Use ATR to auto-calculate bounds + UseATRBounds bool `json:"use_atr_bounds"` + // ATR multiplier for bound calculation (default 2.0) + ATRMultiplier float64 `json:"atr_multiplier"` + // Position distribution: "uniform" | "gaussian" | "pyramid" + Distribution string `json:"distribution"` + // Maximum drawdown percentage before emergency exit + MaxDrawdownPct float64 `json:"max_drawdown_pct"` + // Stop loss percentage per position + StopLossPct float64 `json:"stop_loss_pct"` + // Daily loss limit percentage + DailyLossLimitPct float64 `json:"daily_loss_limit_pct"` + // Use maker-only orders for lower fees + UseMakerOnly bool `json:"use_maker_only"` } // PromptSectionsConfig editable sections of System Prompt diff --git a/store/trader.go b/store/trader.go index ebe93e55..8b983baa 100644 --- a/store/trader.go +++ b/store/trader.go @@ -248,3 +248,23 @@ func (s *TraderStore) ListAll() ([]*Trader, error) { } return traders, nil } + +// ListByExchangeID gets traders that use a specific exchange +func (s *TraderStore) ListByExchangeID(userID, exchangeID string) ([]*Trader, error) { + var traders []*Trader + err := s.db.Where("user_id = ? AND exchange_id = ?", userID, exchangeID).Find(&traders).Error + if err != nil { + return nil, err + } + return traders, nil +} + +// ListByAIModelID gets traders that use a specific AI model +func (s *TraderStore) ListByAIModelID(userID, aiModelID string) ([]*Trader, error) { + var traders []*Trader + err := s.db.Where("user_id = ? AND ai_model_id = ?", userID, aiModelID).Find(&traders).Error + if err != nil { + return nil, err + } + return traders, nil +} diff --git a/trader/aster_trader.go b/trader/aster_trader.go index 2c1bbe7d..1d9a547b 100644 --- a/trader/aster_trader.go +++ b/trader/aster_trader.go @@ -1417,6 +1417,191 @@ func (t *AsterTrader) GetTrades(startTime time.Time, limit int) ([]TradeRecord, // GetOpenOrders gets all open/pending orders for a symbol func (t *AsterTrader) GetOpenOrders(symbol string) ([]OpenOrder, error) { - // TODO: Implement Aster open orders - return []OpenOrder{}, nil + params := map[string]interface{}{ + "symbol": symbol, + } + + body, err := t.request("GET", "/fapi/v3/openOrders", params) + if err != nil { + return nil, fmt.Errorf("failed to get open orders: %w", err) + } + + var orders []struct { + OrderID int64 `json:"orderId"` + Symbol string `json:"symbol"` + Side string `json:"side"` + PositionSide string `json:"positionSide"` + Type string `json:"type"` + Price string `json:"price"` + StopPrice string `json:"stopPrice"` + OrigQty string `json:"origQty"` + Status string `json:"status"` + } + + if err := json.Unmarshal(body, &orders); err != nil { + return nil, fmt.Errorf("failed to parse open orders: %w", err) + } + + var result []OpenOrder + for _, order := range orders { + price, _ := strconv.ParseFloat(order.Price, 64) + stopPrice, _ := strconv.ParseFloat(order.StopPrice, 64) + quantity, _ := strconv.ParseFloat(order.OrigQty, 64) + + result = append(result, OpenOrder{ + OrderID: fmt.Sprintf("%d", order.OrderID), + Symbol: order.Symbol, + Side: order.Side, + PositionSide: order.PositionSide, + Type: order.Type, + Price: price, + StopPrice: stopPrice, + Quantity: quantity, + Status: order.Status, + }) + } + + logger.Infof("✓ ASTER GetOpenOrders: found %d open orders for %s", len(result), symbol) + return result, nil +} + +// PlaceLimitOrder places a limit order for grid trading +func (t *AsterTrader) PlaceLimitOrder(req *LimitOrderRequest) (*LimitOrderResult, error) { + // Format price and quantity to correct precision + formattedPrice, err := t.formatPrice(req.Symbol, req.Price) + if err != nil { + return nil, fmt.Errorf("failed to format price: %w", err) + } + formattedQty, err := t.formatQuantity(req.Symbol, req.Quantity) + if err != nil { + return nil, fmt.Errorf("failed to format quantity: %w", err) + } + + // Get precision information + prec, err := t.getPrecision(req.Symbol) + if err != nil { + return nil, fmt.Errorf("failed to get precision: %w", err) + } + + // Convert to string with correct precision format + priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision) + qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision) + + // Determine side + side := "BUY" + if req.Side == "SELL" || req.Side == "Sell" || req.Side == "sell" { + side = "SELL" + } + + params := map[string]interface{}{ + "symbol": req.Symbol, + "positionSide": "BOTH", + "type": "LIMIT", + "side": side, + "timeInForce": "GTC", + "quantity": qtyStr, + "price": priceStr, + } + + // Add reduceOnly if specified + if req.ReduceOnly { + params["reduceOnly"] = "true" + } + + body, err := t.request("POST", "/fapi/v3/order", params) + if err != nil { + return nil, fmt.Errorf("failed to place limit order: %w", err) + } + + var result map[string]interface{} + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse order response: %w", err) + } + + // Extract order ID + orderID := "" + if id, ok := result["orderId"].(float64); ok { + orderID = fmt.Sprintf("%.0f", id) + } else if id, ok := result["orderId"].(string); ok { + orderID = id + } + + // Extract client order ID + clientOrderID := "" + if cid, ok := result["clientOrderId"].(string); ok { + clientOrderID = cid + } + + return &LimitOrderResult{ + OrderID: orderID, + ClientID: clientOrderID, + Symbol: req.Symbol, + Side: side, + Price: formattedPrice, + Quantity: formattedQty, + Status: "NEW", + }, nil +} + +// CancelOrder cancels a specific order by order ID +func (t *AsterTrader) CancelOrder(symbol, orderID string) error { + params := map[string]interface{}{ + "symbol": symbol, + "orderId": orderID, + } + + _, err := t.request("DELETE", "/fapi/v3/order", params) + if err != nil { + return fmt.Errorf("failed to cancel order %s: %w", orderID, err) + } + + return nil +} + +// GetOrderBook gets the order book for a symbol +func (t *AsterTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { + if depth <= 0 { + depth = 20 + } + + // Aster uses public endpoint (no signature required) + resp, err := t.client.Get(fmt.Sprintf("%s/fapi/v3/depth?symbol=%s&limit=%d", t.baseURL, symbol, depth)) + if err != nil { + return nil, nil, fmt.Errorf("failed to fetch order book: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)) + } + + var result struct { + Bids [][]string `json:"bids"` // [[price, qty], ...] + Asks [][]string `json:"asks"` // [[price, qty], ...] + } + if err := json.Unmarshal(body, &result); err != nil { + return nil, nil, fmt.Errorf("failed to parse order book: %w", err) + } + + // Convert string arrays to float64 arrays + bids = make([][]float64, len(result.Bids)) + for i, bid := range result.Bids { + if len(bid) >= 2 { + price, _ := strconv.ParseFloat(bid[0], 64) + qty, _ := strconv.ParseFloat(bid[1], 64) + bids[i] = []float64{price, qty} + } + } + + asks = make([][]float64, len(result.Asks)) + for i, ask := range result.Asks { + if len(ask) >= 2 { + price, _ := strconv.ParseFloat(ask[0], 64) + qty, _ := strconv.ParseFloat(ask[1], 64) + asks[i] = []float64{price, qty} + } + } + + return bids, asks, nil } diff --git a/trader/auto_trader.go b/trader/auto_trader.go index 2f0daee0..3ec70635 100644 --- a/trader/auto_trader.go +++ b/trader/auto_trader.go @@ -123,6 +123,7 @@ type AutoTrader struct { peakPnLCacheMutex sync.RWMutex // Cache read-write lock lastBalanceSyncTime time.Time // Last balance sync time userID string // User ID + gridState *GridState // Grid trading state (only used when StrategyType == "grid_trading") } // NewAutoTrader creates an automatic trader @@ -419,9 +420,25 @@ func (at *AutoTrader) Run() error { ticker := time.NewTicker(at.config.ScanInterval) defer ticker.Stop() + // Check if this is a grid trading strategy + isGridStrategy := at.IsGridStrategy() + if isGridStrategy { + logger.Infof("🔲 [%s] Grid trading strategy detected, initializing grid...", at.name) + if err := at.InitializeGrid(); err != nil { + logger.Errorf("❌ [%s] Failed to initialize grid: %v", at.name, err) + return fmt.Errorf("grid initialization failed: %w", err) + } + } + // Execute immediately on first run - if err := at.runCycle(); err != nil { - logger.Infof("❌ Execution failed: %v", err) + if isGridStrategy { + if err := at.RunGridCycle(); err != nil { + logger.Infof("❌ Grid execution failed: %v", err) + } + } else { + if err := at.runCycle(); err != nil { + logger.Infof("❌ Execution failed: %v", err) + } } for { @@ -435,8 +452,14 @@ func (at *AutoTrader) Run() error { select { case <-ticker.C: - if err := at.runCycle(); err != nil { - logger.Infof("❌ Execution failed: %v", err) + if isGridStrategy { + if err := at.RunGridCycle(); err != nil { + logger.Infof("❌ Grid execution failed: %v", err) + } + } else { + if err := at.runCycle(); err != nil { + logger.Infof("❌ Execution failed: %v", err) + } } case <-at.stopMonitorCh: logger.Infof("[%s] ⏹ Stop signal received, exiting automatic trading main loop", at.name) @@ -1365,6 +1388,12 @@ func (at *AutoTrader) GetID() string { return at.id } +// GetUnderlyingTrader returns the underlying Trader interface implementation +// This is used by grid trading and other components that need direct exchange access +func (at *AutoTrader) GetUnderlyingTrader() Trader { + return at.trader +} + // GetName gets trader name func (at *AutoTrader) GetName() string { return at.name @@ -1471,7 +1500,7 @@ func (at *AutoTrader) GetStatus() map[string]interface{} { isRunning := at.isRunning at.isRunningMutex.RUnlock() - return map[string]interface{}{ + result := map[string]interface{}{ "trader_id": at.id, "trader_name": at.name, "ai_model": at.aiModel, @@ -1486,6 +1515,16 @@ func (at *AutoTrader) GetStatus() map[string]interface{} { "last_reset_time": at.lastResetTime.Format(time.RFC3339), "ai_provider": aiProvider, } + + // Add strategy info + if at.config.StrategyConfig != nil { + result["strategy_type"] = at.config.StrategyConfig.StrategyType + if at.config.StrategyConfig.GridConfig != nil { + result["grid_symbol"] = at.config.StrategyConfig.GridConfig.Symbol + } + } + + return result } // GetAccountInfo gets account information (for API) diff --git a/trader/auto_trader_grid.go b/trader/auto_trader_grid.go new file mode 100644 index 00000000..5b445b7f --- /dev/null +++ b/trader/auto_trader_grid.go @@ -0,0 +1,1579 @@ +package trader + +import ( + "encoding/json" + "fmt" + "math" + "nofx/kernel" + "nofx/logger" + "nofx/market" + "nofx/store" + "sync" + "time" +) + +// ============================================================================ +// Grid Trading State Management +// ============================================================================ + +// GridState holds the runtime state for grid trading +type GridState struct { + mu sync.RWMutex + + // Configuration + Config *store.GridStrategyConfig + + // Grid levels + Levels []kernel.GridLevelInfo + + // Calculated bounds + UpperPrice float64 + LowerPrice float64 + GridSpacing float64 + + // State flags + IsPaused bool + IsInitialized bool + + // Performance tracking + TotalProfit float64 + TotalTrades int + WinningTrades int + MaxDrawdown float64 + PeakEquity float64 + DailyPnL float64 + LastDailyReset time.Time + + // Order tracking + OrderBook map[string]int // OrderID -> LevelIndex + + // Box state + ShortBoxUpper float64 + ShortBoxLower float64 + MidBoxUpper float64 + MidBoxLower float64 + LongBoxUpper float64 + LongBoxLower float64 + + // Breakout state + BreakoutLevel string + BreakoutDirection string + BreakoutConfirmCount int + + // Position reduction (0 = normal, 50 = reduced after false breakout) + PositionReductionPct float64 + + // Current regime level + CurrentRegimeLevel string +} + +// NewGridState creates a new grid state +func NewGridState(config *store.GridStrategyConfig) *GridState { + return &GridState{ + Config: config, + Levels: make([]kernel.GridLevelInfo, 0), + OrderBook: make(map[string]int), + } +} + +// ============================================================================ +// Breakout Detection +// ============================================================================ + +// BreakoutType represents the type of price breakout +type BreakoutType string + +const ( + BreakoutNone BreakoutType = "none" + BreakoutUpper BreakoutType = "upper" + BreakoutLower BreakoutType = "lower" +) + +// checkBreakout detects if price has broken out of grid range +// Returns breakout type and percentage beyond boundary +func (at *AutoTrader) checkBreakout() (BreakoutType, float64) { + gridConfig := at.config.StrategyConfig.GridConfig + + currentPrice, err := at.trader.GetMarketPrice(gridConfig.Symbol) + if err != nil { + return BreakoutNone, 0 + } + + at.gridState.mu.RLock() + upper := at.gridState.UpperPrice + lower := at.gridState.LowerPrice + at.gridState.mu.RUnlock() + + if upper <= 0 || lower <= 0 { + return BreakoutNone, 0 + } + + // Check upper breakout + if currentPrice > upper { + breakoutPct := (currentPrice - upper) / upper * 100 + return BreakoutUpper, breakoutPct + } + + // Check lower breakout + if currentPrice < lower { + breakoutPct := (lower - currentPrice) / lower * 100 + return BreakoutLower, breakoutPct + } + + return BreakoutNone, 0 +} + +// checkMaxDrawdown checks if current drawdown exceeds maximum allowed +// Returns: (exceeded bool, currentDrawdown float64) +func (at *AutoTrader) checkMaxDrawdown() (bool, float64) { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig.MaxDrawdownPct <= 0 { + return false, 0 + } + + // Get current equity + balance, err := at.trader.GetBalance() + if err != nil { + return false, 0 + } + + currentEquity := 0.0 + if equity, ok := balance["total_equity"].(float64); ok { + currentEquity = equity + } else if total, ok := balance["totalWalletBalance"].(float64); ok { + if unrealized, ok := balance["totalUnrealizedProfit"].(float64); ok { + currentEquity = total + unrealized + } + } + + if currentEquity <= 0 { + return false, 0 + } + + // Update peak equity + at.gridState.mu.Lock() + if currentEquity > at.gridState.PeakEquity { + at.gridState.PeakEquity = currentEquity + } + peakEquity := at.gridState.PeakEquity + at.gridState.mu.Unlock() + + if peakEquity <= 0 { + return false, 0 + } + + // Calculate current drawdown + drawdown := (peakEquity - currentEquity) / peakEquity * 100 + + // Update max drawdown tracking + at.gridState.mu.Lock() + if drawdown > at.gridState.MaxDrawdown { + at.gridState.MaxDrawdown = drawdown + } + at.gridState.mu.Unlock() + + return drawdown >= gridConfig.MaxDrawdownPct, drawdown +} + +// checkDailyLossLimit checks if daily loss exceeds limit +// Returns: (exceeded bool, dailyLossPct float64) +func (at *AutoTrader) checkDailyLossLimit() (bool, float64) { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig.DailyLossLimitPct <= 0 { + return false, 0 + } + + at.gridState.mu.Lock() + // Reset daily PnL if new day + now := time.Now() + if now.YearDay() != at.gridState.LastDailyReset.YearDay() || + now.Year() != at.gridState.LastDailyReset.Year() { + at.gridState.DailyPnL = 0 + at.gridState.LastDailyReset = now + } + dailyPnL := at.gridState.DailyPnL + at.gridState.mu.Unlock() + + // Calculate daily loss as percentage of total investment + dailyLossPct := 0.0 + if gridConfig.TotalInvestment > 0 && dailyPnL < 0 { + dailyLossPct = (-dailyPnL) / gridConfig.TotalInvestment * 100 + } + + return dailyLossPct >= gridConfig.DailyLossLimitPct, dailyLossPct +} + +// updateDailyPnL updates the daily PnL tracking +func (at *AutoTrader) updateDailyPnL(realizedPnL float64) { + at.gridState.mu.Lock() + at.gridState.DailyPnL += realizedPnL + at.gridState.TotalProfit += realizedPnL + at.gridState.mu.Unlock() +} + +// emergencyExit closes all positions and cancels all orders +func (at *AutoTrader) emergencyExit(reason string) error { + gridConfig := at.config.StrategyConfig.GridConfig + + logger.Errorf("[Grid] EMERGENCY EXIT: %s", reason) + + // Cancel all orders + if err := at.cancelAllGridOrders(); err != nil { + logger.Errorf("[Grid] Failed to cancel orders in emergency: %v", err) + } + + // Close all positions + positions, err := at.trader.GetPositions() + if err == nil { + for _, pos := range positions { + if sym, ok := pos["symbol"].(string); ok && sym == gridConfig.Symbol { + if size, ok := pos["positionAmt"].(float64); ok && size != 0 { + if size > 0 { + at.trader.CloseLong(gridConfig.Symbol, size) + } else { + at.trader.CloseShort(gridConfig.Symbol, -size) + } + } + } + } + } + + // Pause grid + at.gridState.mu.Lock() + at.gridState.IsPaused = true + at.gridState.mu.Unlock() + + return nil +} + +// handleBreakout handles price breakout from grid range +func (at *AutoTrader) handleBreakout(breakoutType BreakoutType, breakoutPct float64) error { + logger.Warnf("[Grid] BREAKOUT DETECTED: %s, %.2f%% beyond boundary", breakoutType, breakoutPct) + + // If breakout exceeds 2%, pause grid and cancel orders + if breakoutPct >= 2.0 { + logger.Warnf("[Grid] Significant breakout (%.2f%%), pausing grid and canceling orders", breakoutPct) + + // Cancel all pending orders to prevent further losses + if err := at.cancelAllGridOrders(); err != nil { + logger.Errorf("[Grid] Failed to cancel orders on breakout: %v", err) + } + + // Pause grid trading + at.gridState.mu.Lock() + at.gridState.IsPaused = true + at.gridState.mu.Unlock() + + return fmt.Errorf("grid paused due to %s breakout (%.2f%%)", breakoutType, breakoutPct) + } + + // If breakout is minor (< 2%), consider adjusting grid + if breakoutPct >= 1.0 { + logger.Infof("[Grid] Minor breakout (%.2f%%), considering grid adjustment", breakoutPct) + // Let AI decide whether to adjust + } + + return nil +} + +// checkBoxBreakout checks for multi-period box breakouts and takes appropriate action +func (at *AutoTrader) checkBoxBreakout() error { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig == nil { + return nil + } + + // Get box data + box, err := market.GetBoxData(gridConfig.Symbol) + if err != nil { + logger.Infof("Failed to get box data: %v", err) + return nil // Non-fatal, continue with other checks + } + + // Update grid state with box values + at.gridState.mu.Lock() + at.gridState.ShortBoxUpper = box.ShortUpper + at.gridState.ShortBoxLower = box.ShortLower + at.gridState.MidBoxUpper = box.MidUpper + at.gridState.MidBoxLower = box.MidLower + at.gridState.LongBoxUpper = box.LongUpper + at.gridState.LongBoxLower = box.LongLower + at.gridState.mu.Unlock() + + // Detect breakout + breakoutLevel, direction := detectBoxBreakout(box) + + // Get current breakout state + state := &BreakoutState{ + Level: market.BreakoutLevel(at.gridState.BreakoutLevel), + Direction: at.gridState.BreakoutDirection, + ConfirmCount: at.gridState.BreakoutConfirmCount, + } + + // Check if breakout is confirmed (3 candles) + confirmed := confirmBreakout(state, breakoutLevel, direction) + + // Update grid state + at.gridState.mu.Lock() + at.gridState.BreakoutLevel = string(state.Level) + at.gridState.BreakoutDirection = state.Direction + at.gridState.BreakoutConfirmCount = state.ConfirmCount + at.gridState.mu.Unlock() + + if !confirmed { + return nil + } + + // Take action based on breakout level + action := getBreakoutAction(breakoutLevel) + return at.executeBreakoutAction(action) +} + +// executeBreakoutAction executes the appropriate action for a breakout +func (at *AutoTrader) executeBreakoutAction(action BreakoutAction) error { + switch action { + case BreakoutActionReducePosition: + // Short box breakout: reduce position to 50% + logger.Infof("Short box breakout confirmed, reducing position to 50%%") + at.gridState.mu.Lock() + at.gridState.PositionReductionPct = 50 + at.gridState.mu.Unlock() + return nil + + case BreakoutActionPauseGrid: + // Mid box breakout: pause grid + cancel orders + logger.Infof("Mid box breakout confirmed, pausing grid and canceling orders") + at.gridState.mu.Lock() + at.gridState.IsPaused = true + at.gridState.mu.Unlock() + return at.cancelAllGridOrders() + + case BreakoutActionCloseAll: + // Long box breakout: pause + cancel + close all + logger.Infof("Long box breakout confirmed, closing all positions") + at.gridState.mu.Lock() + at.gridState.IsPaused = true + at.gridState.mu.Unlock() + if err := at.cancelAllGridOrders(); err != nil { + logger.Infof("Failed to cancel orders: %v", err) + } + return at.closeAllPositions() + } + + return nil +} + +// closeAllPositions closes all open positions for the grid symbol +func (at *AutoTrader) closeAllPositions() error { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig == nil { + return nil + } + + positions, err := at.trader.GetPositions() + if err != nil { + return fmt.Errorf("failed to get positions: %w", err) + } + + for _, pos := range positions { + symbol, _ := pos["symbol"].(string) + if symbol != gridConfig.Symbol { + continue + } + + size, _ := pos["positionAmt"].(float64) + if size == 0 { + continue + } + + if size > 0 { + _, err = at.trader.CloseLong(symbol, size) + } else { + _, err = at.trader.CloseShort(symbol, -size) + } + if err != nil { + logger.Infof("Failed to close position: %v", err) + } + } + + return nil +} + +// checkFalseBreakoutRecovery checks if price has returned to box after breakout +func (at *AutoTrader) checkFalseBreakoutRecovery() error { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig == nil { + return nil + } + + at.gridState.mu.RLock() + breakoutLevel := at.gridState.BreakoutLevel + isPaused := at.gridState.IsPaused + positionReduction := at.gridState.PositionReductionPct + at.gridState.mu.RUnlock() + + // Only check if we had a breakout + if breakoutLevel == string(market.BreakoutNone) && positionReduction == 0 && !isPaused { + return nil + } + + // Get current box data + box, err := market.GetBoxData(gridConfig.Symbol) + if err != nil { + return nil + } + + // Check if price is back inside the long box + if box.CurrentPrice >= box.LongLower && box.CurrentPrice <= box.LongUpper { + logger.Infof("Price returned to box, recovering with 50%% position") + + at.gridState.mu.Lock() + at.gridState.BreakoutLevel = string(market.BreakoutNone) + at.gridState.BreakoutDirection = "" + at.gridState.BreakoutConfirmCount = 0 + at.gridState.PositionReductionPct = 50 // Recover at 50% + at.gridState.IsPaused = false + at.gridState.mu.Unlock() + } + + return nil +} + +// ============================================================================ +// AutoTrader Grid Methods +// ============================================================================ + +// InitializeGrid initializes the grid state and calculates levels +func (at *AutoTrader) InitializeGrid() error { + if at.config.StrategyConfig == nil || at.config.StrategyConfig.GridConfig == nil { + return fmt.Errorf("grid configuration not found") + } + + gridConfig := at.config.StrategyConfig.GridConfig + at.gridState = NewGridState(gridConfig) + + // Get current market price + price, err := at.trader.GetMarketPrice(gridConfig.Symbol) + if err != nil { + return fmt.Errorf("failed to get market price: %w", err) + } + + // Calculate grid bounds + if gridConfig.UseATRBounds { + // Get ATR for bound calculation + mktData, err := market.GetWithTimeframes(gridConfig.Symbol, []string{"4h"}, "4h", 20) + if err != nil { + logger.Warnf("Failed to get market data for ATR: %v, using default bounds", err) + at.calculateDefaultBounds(price, gridConfig) + } else { + at.calculateATRBounds(price, mktData, gridConfig) + } + } else { + // Use manual bounds + at.gridState.UpperPrice = gridConfig.UpperPrice + at.gridState.LowerPrice = gridConfig.LowerPrice + } + + // Calculate grid spacing + at.gridState.GridSpacing = (at.gridState.UpperPrice - at.gridState.LowerPrice) / float64(gridConfig.GridCount-1) + + // Initialize grid levels + at.initializeGridLevels(price, gridConfig) + + at.gridState.IsInitialized = true + + // CRITICAL: Set leverage on exchange before trading + if err := at.trader.SetLeverage(gridConfig.Symbol, gridConfig.Leverage); err != nil { + logger.Warnf("[Grid] Failed to set leverage %dx on exchange: %v", gridConfig.Leverage, err) + // Not fatal - continue with default leverage + } else { + logger.Infof("[Grid] Leverage set to %dx for %s", gridConfig.Leverage, gridConfig.Symbol) + } + + logger.Infof("📊 [Grid] Initialized: %d levels, $%.2f - $%.2f, spacing $%.2f", + gridConfig.GridCount, at.gridState.LowerPrice, at.gridState.UpperPrice, at.gridState.GridSpacing) + + return nil +} + +// calculateDefaultBounds calculates default bounds based on price +func (at *AutoTrader) calculateDefaultBounds(price float64, config *store.GridStrategyConfig) { + // Default: ±3% from current price + multiplier := 0.03 * float64(config.GridCount) / 10 + at.gridState.UpperPrice = price * (1 + multiplier) + at.gridState.LowerPrice = price * (1 - multiplier) +} + +// calculateATRBounds calculates bounds using ATR +func (at *AutoTrader) calculateATRBounds(price float64, mktData *market.Data, config *store.GridStrategyConfig) { + atr := 0.0 + if mktData.LongerTermContext != nil { + atr = mktData.LongerTermContext.ATR14 + } + + if atr <= 0 { + at.calculateDefaultBounds(price, config) + return + } + + multiplier := config.ATRMultiplier + if multiplier <= 0 { + multiplier = 2.0 + } + + halfRange := atr * multiplier + at.gridState.UpperPrice = price + halfRange + at.gridState.LowerPrice = price - halfRange +} + +// initializeGridLevels creates the grid level structure +func (at *AutoTrader) initializeGridLevels(currentPrice float64, config *store.GridStrategyConfig) { + levels := make([]kernel.GridLevelInfo, config.GridCount) + totalWeight := 0.0 + weights := make([]float64, config.GridCount) + + // Calculate weights based on distribution + for i := 0; i < config.GridCount; i++ { + switch config.Distribution { + case "gaussian": + // Gaussian distribution - more weight in the middle + center := float64(config.GridCount-1) / 2 + sigma := float64(config.GridCount) / 4 + weights[i] = math.Exp(-math.Pow(float64(i)-center, 2) / (2 * sigma * sigma)) + case "pyramid": + // Pyramid - more weight at bottom + weights[i] = float64(config.GridCount - i) + default: // uniform + weights[i] = 1.0 + } + totalWeight += weights[i] + } + + // Create levels + for i := 0; i < config.GridCount; i++ { + price := at.gridState.LowerPrice + float64(i)*at.gridState.GridSpacing + allocatedUSD := config.TotalInvestment * weights[i] / totalWeight + + // Determine initial side (below current price = buy, above = sell) + side := "buy" + if price > currentPrice { + side = "sell" + } + + levels[i] = kernel.GridLevelInfo{ + Index: i, + Price: price, + State: "empty", + Side: side, + AllocatedUSD: allocatedUSD, + } + } + + at.gridState.Levels = levels +} + +// RunGridCycle executes one grid trading cycle +func (at *AutoTrader) RunGridCycle() error { + // Check if trader is stopped (early exit to prevent trades after Stop() is called) + at.isRunningMutex.RLock() + running := at.isRunning + at.isRunningMutex.RUnlock() + if !running { + logger.Infof("[Grid] Trader is stopped, aborting grid cycle") + return nil + } + + if at.gridState == nil || !at.gridState.IsInitialized { + if err := at.InitializeGrid(); err != nil { + return fmt.Errorf("failed to initialize grid: %w", err) + } + } + + // CRITICAL: Check for breakout before executing any trades + breakoutType, breakoutPct := at.checkBreakout() + if breakoutType != BreakoutNone { + if err := at.handleBreakout(breakoutType, breakoutPct); err != nil { + return err // Grid paused due to breakout + } + } + + // CRITICAL: Check max drawdown + exceeded, drawdown := at.checkMaxDrawdown() + if exceeded { + return at.emergencyExit(fmt.Sprintf("max drawdown exceeded: %.2f%%", drawdown)) + } + + // CRITICAL: Check daily loss limit + dailyExceeded, dailyLossPct := at.checkDailyLossLimit() + if dailyExceeded { + logger.Errorf("[Grid] Daily loss limit exceeded: %.2f%%", dailyLossPct) + at.gridState.mu.Lock() + at.gridState.IsPaused = true + at.gridState.mu.Unlock() + return fmt.Errorf("daily loss limit exceeded: %.2f%%", dailyLossPct) + } + + // Check multi-period box breakout + if err := at.checkBoxBreakout(); err != nil { + logger.Infof("Box breakout check error: %v", err) + } + + // Check for false breakout recovery + if err := at.checkFalseBreakoutRecovery(); err != nil { + logger.Infof("False breakout recovery check error: %v", err) + } + + // Check if grid is paused + at.gridState.mu.RLock() + isPaused := at.gridState.IsPaused + at.gridState.mu.RUnlock() + if isPaused { + logger.Infof("[Grid] Grid is paused, skipping cycle") + return nil + } + + gridConfig := at.config.StrategyConfig.GridConfig + lang := at.config.StrategyConfig.Language + if lang == "" { + lang = "en" + } + + // Build grid context + gridCtx, err := at.buildGridContext() + if err != nil { + return fmt.Errorf("failed to build grid context: %w", err) + } + + // Get AI decisions + decision, err := kernel.GetGridDecisions(gridCtx, at.mcpClient, gridConfig, lang) + if err != nil { + return fmt.Errorf("failed to get grid decisions: %w", err) + } + + // Check if trader is stopped before executing any decisions (prevent trades after Stop()) + at.isRunningMutex.RLock() + running = at.isRunning + at.isRunningMutex.RUnlock() + if !running { + logger.Infof("[Grid] Trader stopped before decision execution, aborting grid cycle") + return nil + } + + // Execute decisions + for _, d := range decision.Decisions { + // Check if trader is still running before each decision + at.isRunningMutex.RLock() + running := at.isRunning + at.isRunningMutex.RUnlock() + if !running { + logger.Infof("[Grid] Trader stopped, skipping remaining %d decisions", len(decision.Decisions)) + break + } + + if err := at.executeGridDecision(&d); err != nil { + logger.Warnf("[Grid] Failed to execute decision %s: %v", d.Action, err) + } + } + + // Sync state with exchange + at.syncGridState() + + // Save decision record + at.saveGridDecisionRecord(decision) + + return nil +} + +// buildGridContext builds the context for AI grid decisions +func (at *AutoTrader) buildGridContext() (*kernel.GridContext, error) { + gridConfig := at.config.StrategyConfig.GridConfig + + // Get market data + mktData, err := market.GetWithTimeframes(gridConfig.Symbol, []string{"5m", "4h"}, "5m", 50) + if err != nil { + return nil, fmt.Errorf("failed to get market data: %w", err) + } + + // Build base context from market data + ctx := kernel.BuildGridContextFromMarketData(mktData, gridConfig) + + // Add grid state + at.gridState.mu.RLock() + ctx.Levels = at.gridState.Levels + ctx.UpperPrice = at.gridState.UpperPrice + ctx.LowerPrice = at.gridState.LowerPrice + ctx.GridSpacing = at.gridState.GridSpacing + ctx.IsPaused = at.gridState.IsPaused + ctx.TotalProfit = at.gridState.TotalProfit + ctx.TotalTrades = at.gridState.TotalTrades + ctx.WinningTrades = at.gridState.WinningTrades + ctx.MaxDrawdown = at.gridState.MaxDrawdown + ctx.DailyPnL = at.gridState.DailyPnL + + // Count active orders and filled levels + for _, level := range at.gridState.Levels { + if level.State == "pending" { + ctx.ActiveOrderCount++ + } else if level.State == "filled" { + ctx.FilledLevelCount++ + } + } + at.gridState.mu.RUnlock() + + // Get account info + balance, err := at.trader.GetBalance() + if err == nil { + if equity, ok := balance["total_equity"].(float64); ok { + ctx.TotalEquity = equity + } + if available, ok := balance["availableBalance"].(float64); ok { + ctx.AvailableBalance = available + } + if unrealized, ok := balance["totalUnrealizedProfit"].(float64); ok { + ctx.UnrealizedPnL = unrealized + } + } + + // Get current position + positions, err := at.trader.GetPositions() + if err == nil { + for _, pos := range positions { + if sym, ok := pos["symbol"].(string); ok && sym == gridConfig.Symbol { + if size, ok := pos["positionAmt"].(float64); ok { + ctx.CurrentPosition = size + } + } + } + } + + return ctx, nil +} + +// executeGridDecision executes a single grid decision +func (at *AutoTrader) executeGridDecision(d *kernel.Decision) error { + switch d.Action { + case "place_buy_limit": + return at.placeGridLimitOrder(d, "BUY") + case "place_sell_limit": + return at.placeGridLimitOrder(d, "SELL") + case "cancel_order": + return at.cancelGridOrder(d) + case "cancel_all_orders": + return at.cancelAllGridOrders() + case "pause_grid": + return at.pauseGrid(d.Reasoning) + case "resume_grid": + return at.resumeGrid() + case "adjust_grid": + return at.adjustGrid(d) + case "hold": + logger.Infof("[Grid] Holding current state: %s", d.Reasoning) + return nil + // Support standard actions for closing positions + case "close_long": + _, err := at.trader.CloseLong(d.Symbol, d.Quantity) + return err + case "close_short": + _, err := at.trader.CloseShort(d.Symbol, d.Quantity) + return err + default: + logger.Warnf("[Grid] Unknown action: %s", d.Action) + return nil + } +} + +// checkTotalPositionLimit checks if adding a new position would exceed total limits +// Returns: (allowed bool, currentPositionValue float64, maxAllowed float64) +func (at *AutoTrader) checkTotalPositionLimit(symbol string, additionalValue float64) (bool, float64, float64) { + gridConfig := at.config.StrategyConfig.GridConfig + + // Calculate max allowed total position value + // Total position should not exceed: TotalInvestment × Leverage + maxTotalPositionValue := gridConfig.TotalInvestment * float64(gridConfig.Leverage) + + // Get current position value from exchange + currentPositionValue := 0.0 + positions, err := at.trader.GetPositions() + if err == nil { + for _, pos := range positions { + if sym, ok := pos["symbol"].(string); ok && sym == symbol { + if size, ok := pos["positionAmt"].(float64); ok { + if price, ok := pos["markPrice"].(float64); ok { + currentPositionValue = math.Abs(size) * price + } else if entryPrice, ok := pos["entryPrice"].(float64); ok { + currentPositionValue = math.Abs(size) * entryPrice + } + } + } + } + } + + // Also count pending orders as potential position + at.gridState.mu.RLock() + pendingValue := 0.0 + for _, level := range at.gridState.Levels { + if level.State == "pending" { + pendingValue += level.OrderQuantity * level.Price + } + } + at.gridState.mu.RUnlock() + + totalAfterOrder := currentPositionValue + pendingValue + additionalValue + allowed := totalAfterOrder <= maxTotalPositionValue + + return allowed, currentPositionValue + pendingValue, maxTotalPositionValue +} + +// placeGridLimitOrder places a limit order for grid trading +func (at *AutoTrader) placeGridLimitOrder(d *kernel.Decision, side string) error { + // Check if trader supports GridTrader interface + gridTrader, ok := at.trader.(GridTrader) + if !ok { + // Fallback to adapter + gridTrader = NewGridTraderAdapter(at.trader) + } + + gridConfig := at.config.StrategyConfig.GridConfig + + // CRITICAL: Validate and cap quantity to prevent excessive position sizes + // This protects against AI miscalculations or leverage misconfigurations + quantity := d.Quantity + if d.Price > 0 && gridConfig.TotalInvestment > 0 { + // Calculate max allowed position value per grid level + // Each level gets proportional share of total investment + maxMarginPerLevel := gridConfig.TotalInvestment / float64(gridConfig.GridCount) + maxPositionValuePerLevel := maxMarginPerLevel * float64(gridConfig.Leverage) + maxQuantityPerLevel := maxPositionValuePerLevel / d.Price + + // Also get the level's allocated USD for additional validation + at.gridState.mu.RLock() + var levelAllocatedUSD float64 + if d.LevelIndex >= 0 && d.LevelIndex < len(at.gridState.Levels) { + levelAllocatedUSD = at.gridState.Levels[d.LevelIndex].AllocatedUSD + } + at.gridState.mu.RUnlock() + + // Use level-specific allocation if available + if levelAllocatedUSD > 0 { + levelMaxPositionValue := levelAllocatedUSD * float64(gridConfig.Leverage) + levelMaxQuantity := levelMaxPositionValue / d.Price + if levelMaxQuantity < maxQuantityPerLevel { + maxQuantityPerLevel = levelMaxQuantity + } + } + + // Cap quantity if it exceeds the maximum allowed + if quantity > maxQuantityPerLevel { + logger.Warnf("[Grid] ⚠️ Quantity %.4f exceeds max allowed %.4f (position_value $%.2f > max $%.2f), capping", + quantity, maxQuantityPerLevel, quantity*d.Price, maxPositionValuePerLevel) + quantity = maxQuantityPerLevel + } + + // Safety check: ensure position value is reasonable (within 2x of intended max as absolute limit) + positionValue := quantity * d.Price + absoluteMaxValue := gridConfig.TotalInvestment * float64(gridConfig.Leverage) * 2 // 2x safety margin + if positionValue > absoluteMaxValue { + logger.Errorf("[Grid] CRITICAL: Position value $%.2f exceeds absolute max $%.2f! Rejecting order.", + positionValue, absoluteMaxValue) + return fmt.Errorf("position value $%.2f exceeds safety limit $%.2f", positionValue, absoluteMaxValue) + } + } + + // CRITICAL: Check total position limit before placing order + orderValue := quantity * d.Price + allowed, currentValue, maxValue := at.checkTotalPositionLimit(d.Symbol, orderValue) + if !allowed { + logger.Errorf("[Grid] TOTAL POSITION LIMIT EXCEEDED: current=$%.2f + order=$%.2f > max=$%.2f. Rejecting order.", + currentValue, orderValue, maxValue) + return fmt.Errorf("total position value $%.2f would exceed limit $%.2f", currentValue+orderValue, maxValue) + } + + req := &LimitOrderRequest{ + Symbol: d.Symbol, + Side: side, + Price: d.Price, + Quantity: quantity, // Use validated/capped quantity + Leverage: gridConfig.Leverage, + PostOnly: gridConfig.UseMakerOnly, + ReduceOnly: false, + ClientID: fmt.Sprintf("grid-%d-%d", d.LevelIndex, time.Now().UnixNano()%1000000), + } + + result, err := gridTrader.PlaceLimitOrder(req) + if err != nil { + return fmt.Errorf("failed to place limit order: %w", err) + } + + // Update grid level state + at.gridState.mu.Lock() + if d.LevelIndex >= 0 && d.LevelIndex < len(at.gridState.Levels) { + at.gridState.Levels[d.LevelIndex].State = "pending" + at.gridState.Levels[d.LevelIndex].OrderID = result.OrderID + at.gridState.Levels[d.LevelIndex].OrderQuantity = d.Quantity + at.gridState.OrderBook[result.OrderID] = d.LevelIndex + } + at.gridState.mu.Unlock() + + logger.Infof("[Grid] Placed %s limit order at $%.2f, qty=%.4f, level=%d, orderID=%s", + side, d.Price, d.Quantity, d.LevelIndex, result.OrderID) + + return nil +} + +// cancelGridOrder cancels a specific grid order +func (at *AutoTrader) cancelGridOrder(d *kernel.Decision) error { + gridTrader, ok := at.trader.(GridTrader) + if !ok { + gridTrader = NewGridTraderAdapter(at.trader) + } + + if err := gridTrader.CancelOrder(d.Symbol, d.OrderID); err != nil { + return fmt.Errorf("failed to cancel order: %w", err) + } + + // Update state + at.gridState.mu.Lock() + if levelIdx, ok := at.gridState.OrderBook[d.OrderID]; ok { + if levelIdx >= 0 && levelIdx < len(at.gridState.Levels) { + at.gridState.Levels[levelIdx].State = "empty" + at.gridState.Levels[levelIdx].OrderID = "" + at.gridState.Levels[levelIdx].OrderQuantity = 0 + } + delete(at.gridState.OrderBook, d.OrderID) + } + at.gridState.mu.Unlock() + + logger.Infof("[Grid] Cancelled order: %s", d.OrderID) + return nil +} + +// cancelAllGridOrders cancels all grid orders +func (at *AutoTrader) cancelAllGridOrders() error { + gridConfig := at.config.StrategyConfig.GridConfig + + if err := at.trader.CancelAllOrders(gridConfig.Symbol); err != nil { + return fmt.Errorf("failed to cancel all orders: %w", err) + } + + // Reset all pending levels + at.gridState.mu.Lock() + for i := range at.gridState.Levels { + if at.gridState.Levels[i].State == "pending" { + at.gridState.Levels[i].State = "empty" + at.gridState.Levels[i].OrderID = "" + at.gridState.Levels[i].OrderQuantity = 0 + } + } + at.gridState.OrderBook = make(map[string]int) + at.gridState.mu.Unlock() + + logger.Infof("[Grid] Cancelled all orders") + return nil +} + +// pauseGrid pauses grid trading +func (at *AutoTrader) pauseGrid(reason string) error { + at.cancelAllGridOrders() + + at.gridState.mu.Lock() + at.gridState.IsPaused = true + at.gridState.mu.Unlock() + + logger.Infof("[Grid] Paused: %s", reason) + return nil +} + +// resumeGrid resumes grid trading +func (at *AutoTrader) resumeGrid() error { + at.gridState.mu.Lock() + at.gridState.IsPaused = false + at.gridState.mu.Unlock() + + logger.Infof("[Grid] Resumed") + return nil +} + +// adjustGrid adjusts grid parameters +func (at *AutoTrader) adjustGrid(d *kernel.Decision) error { + // Cancel existing orders first + at.cancelAllGridOrders() + + gridConfig := at.config.StrategyConfig.GridConfig + + // Get current price + price, err := at.trader.GetMarketPrice(gridConfig.Symbol) + if err != nil { + return fmt.Errorf("failed to get market price: %w", err) + } + + // Reinitialize grid levels + at.initializeGridLevels(price, gridConfig) + + logger.Infof("[Grid] Adjusted grid bounds around price $%.2f", price) + return nil +} + +// syncGridState syncs grid state with exchange +func (at *AutoTrader) syncGridState() { + gridConfig := at.config.StrategyConfig.GridConfig + + // Get open orders from exchange + openOrders, err := at.trader.GetOpenOrders(gridConfig.Symbol) + if err != nil { + logger.Warnf("[Grid] Failed to get open orders: %v", err) + return + } + + // Build set of active order IDs + activeOrderIDs := make(map[string]bool) + for _, order := range openOrders { + activeOrderIDs[order.OrderID] = true + } + + // Get current positions to verify fills + positions, err := at.trader.GetPositions() + currentPositionSize := 0.0 + if err != nil { + logger.Warnf("[Grid] Failed to get positions for state sync: %v", err) + } else { + for _, pos := range positions { + if sym, ok := pos["symbol"].(string); ok && sym == gridConfig.Symbol { + if size, ok := pos["positionAmt"].(float64); ok { + currentPositionSize = size + } + } + } + } + + // Update levels based on order status + at.gridState.mu.Lock() + expectedPositionSize := 0.0 + for _, level := range at.gridState.Levels { + if level.State == "filled" { + expectedPositionSize += level.PositionSize + } + } + + for i := range at.gridState.Levels { + level := &at.gridState.Levels[i] + if level.State == "pending" && level.OrderID != "" { + if !activeOrderIDs[level.OrderID] { + // Order no longer exists - check if position changed to determine fill vs cancel + // This is a heuristic - ideally we'd query order history + // If current position is larger than expected filled positions, this order was likely filled + if math.Abs(currentPositionSize) > math.Abs(expectedPositionSize) { + // Position increased, likely filled + level.State = "filled" + level.PositionEntry = level.Price + level.PositionSize = level.OrderQuantity + at.gridState.TotalTrades++ + logger.Infof("[Grid] Level %d order filled at $%.2f", i, level.Price) + } else { + // Position didn't increase as expected, likely cancelled + level.State = "empty" + level.OrderID = "" + level.OrderQuantity = 0 + logger.Infof("[Grid] Level %d order cancelled/expired", i) + } + delete(at.gridState.OrderBook, level.OrderID) + } + } + } + at.gridState.mu.Unlock() + + logger.Debugf("[Grid] Synced state: position=%.4f, orders=%d", currentPositionSize, len(openOrders)) + + // Check stop loss + at.checkAndExecuteStopLoss() + + // Check grid skew + at.autoAdjustGrid() +} + +// saveGridDecisionRecord saves the grid decision to database +func (at *AutoTrader) saveGridDecisionRecord(decision *kernel.FullDecision) { + if at.store == nil { + return + } + + at.cycleNumber++ + + record := &store.DecisionRecord{ + TraderID: at.id, + CycleNumber: at.cycleNumber, + Timestamp: time.Now().UTC(), + SystemPrompt: decision.SystemPrompt, + InputPrompt: decision.UserPrompt, + CoTTrace: decision.CoTTrace, + RawResponse: decision.RawResponse, + AIRequestDurationMs: decision.AIRequestDurationMs, + Success: true, + } + + if len(decision.Decisions) > 0 { + decisionJSON, _ := json.MarshalIndent(decision.Decisions, "", " ") + record.DecisionJSON = string(decisionJSON) + + // Convert kernel.Decision to store.DecisionAction for frontend display + for _, d := range decision.Decisions { + actionRecord := store.DecisionAction{ + Action: d.Action, + Symbol: d.Symbol, + Quantity: d.Quantity, + Leverage: d.Leverage, + Price: d.Price, + StopLoss: d.StopLoss, + TakeProfit: d.TakeProfit, + Confidence: d.Confidence, + Reasoning: d.Reasoning, + Timestamp: time.Now().UTC(), + Success: true, // Grid decisions are executed inline + } + record.Decisions = append(record.Decisions, actionRecord) + } + } + + record.ExecutionLog = append(record.ExecutionLog, fmt.Sprintf("Grid cycle completed with %d decisions", len(decision.Decisions))) + + if err := at.store.Decision().LogDecision(record); err != nil { + logger.Warnf("[Grid] Failed to save decision record: %v", err) + } +} + +// IsGridStrategy returns true if current strategy is grid trading +func (at *AutoTrader) IsGridStrategy() bool { + if at.config.StrategyConfig == nil { + return false + } + return at.config.StrategyConfig.StrategyType == "grid_trading" && at.config.StrategyConfig.GridConfig != nil +} + +// checkGridSkew checks if grid is heavily skewed (too many fills on one side) +// Returns: (skewed bool, buyFilledCount int, sellFilledCount int) +func (at *AutoTrader) checkGridSkew() (bool, int, int) { + at.gridState.mu.RLock() + defer at.gridState.mu.RUnlock() + + buyFilled := 0 + sellFilled := 0 + buyEmpty := 0 + sellEmpty := 0 + + for _, level := range at.gridState.Levels { + if level.Side == "buy" { + if level.State == "filled" { + buyFilled++ + } else if level.State == "empty" { + buyEmpty++ + } + } else { + if level.State == "filled" { + sellFilled++ + } else if level.State == "empty" { + sellEmpty++ + } + } + } + + // Grid is skewed if one side has 3x more fills than the other + // or if one side is completely empty + skewed := false + if buyFilled > 0 && sellFilled == 0 && sellEmpty > 5 { + skewed = true // All buys filled, no sells + } else if sellFilled > 0 && buyFilled == 0 && buyEmpty > 5 { + skewed = true // All sells filled, no buys + } else if buyFilled >= 3*sellFilled && buyFilled > 5 { + skewed = true + } else if sellFilled >= 3*buyFilled && sellFilled > 5 { + skewed = true + } + + return skewed, buyFilled, sellFilled +} + +// autoAdjustGrid automatically adjusts grid when heavily skewed +func (at *AutoTrader) autoAdjustGrid() { + skewed, buyFilled, sellFilled := at.checkGridSkew() + if !skewed { + return + } + + logger.Warnf("[Grid] Grid heavily skewed: buy_filled=%d, sell_filled=%d. Auto-adjusting...", + buyFilled, sellFilled) + + gridConfig := at.config.StrategyConfig.GridConfig + + // Get current price + currentPrice, err := at.trader.GetMarketPrice(gridConfig.Symbol) + if err != nil { + logger.Errorf("[Grid] Failed to get price for auto-adjust: %v", err) + return + } + + // Check if price is near grid boundary + at.gridState.mu.RLock() + upper := at.gridState.UpperPrice + lower := at.gridState.LowerPrice + at.gridState.mu.RUnlock() + + // Only adjust if price has moved significantly (>30% of grid range) + gridRange := upper - lower + midPrice := (upper + lower) / 2 + priceDeviation := math.Abs(currentPrice - midPrice) + + if priceDeviation < gridRange*0.3 { + return // Price still near center, don't adjust + } + + logger.Infof("[Grid] Adjusting grid around new price $%.2f", currentPrice) + + // Cancel existing orders first (before taking the lock for state modification) + if err := at.cancelAllGridOrders(); err != nil { + logger.Errorf("[Grid] Failed to cancel orders during auto-adjust: %v", err) + // Continue with adjustment anyway + } + + // CRITICAL FIX: Hold lock for the entire adjustment operation to ensure atomicity + at.gridState.mu.Lock() + defer at.gridState.mu.Unlock() + + // Preserve filled positions before reinitializing + filledPositions := make(map[int]kernel.GridLevelInfo) + for i, level := range at.gridState.Levels { + if level.State == "filled" { + filledPositions[i] = level + } + } + + // CRITICAL FIX: Recalculate grid bounds centered on current price + // Use the same logic as InitializeGrid() - either ATR-based or default percentage + if gridConfig.UseATRBounds { + // Try to get ATR for bound calculation + mktData, err := market.GetWithTimeframes(gridConfig.Symbol, []string{"4h"}, "4h", 20) + if err != nil { + logger.Warnf("[Grid] Failed to get market data for ATR during adjust: %v, using default bounds", err) + at.calculateDefaultBoundsLocked(currentPrice, gridConfig) + } else { + at.calculateATRBoundsLocked(currentPrice, mktData, gridConfig) + } + } else { + // Use default bounds calculation (scaled by grid count) + at.calculateDefaultBoundsLocked(currentPrice, gridConfig) + } + + // Recalculate grid spacing based on new bounds + at.gridState.GridSpacing = (at.gridState.UpperPrice - at.gridState.LowerPrice) / float64(gridConfig.GridCount-1) + + logger.Infof("[Grid] New bounds: $%.2f - $%.2f, spacing: $%.2f", + at.gridState.LowerPrice, at.gridState.UpperPrice, at.gridState.GridSpacing) + + // Initialize new grid levels (without lock since we already hold it) + at.initializeGridLevelsLocked(currentPrice, gridConfig) + + // CRITICAL FIX: Restore filled positions - find closest new level for each filled position + for _, filledLevel := range filledPositions { + closestIdx := -1 + closestDist := math.MaxFloat64 + + for i, newLevel := range at.gridState.Levels { + dist := math.Abs(newLevel.Price - filledLevel.PositionEntry) + if dist < closestDist { + closestDist = dist + closestIdx = i + } + } + + if closestIdx >= 0 { + // Restore the filled state to the closest level + at.gridState.Levels[closestIdx].State = "filled" + at.gridState.Levels[closestIdx].PositionEntry = filledLevel.PositionEntry + at.gridState.Levels[closestIdx].PositionSize = filledLevel.PositionSize + at.gridState.Levels[closestIdx].UnrealizedPnL = filledLevel.UnrealizedPnL + at.gridState.Levels[closestIdx].OrderID = filledLevel.OrderID + at.gridState.Levels[closestIdx].OrderQuantity = filledLevel.OrderQuantity + logger.Infof("[Grid] Restored filled position at level %d (entry $%.2f)", closestIdx, filledLevel.PositionEntry) + } + } +} + +// calculateDefaultBoundsLocked calculates default bounds (caller must hold lock) +func (at *AutoTrader) calculateDefaultBoundsLocked(price float64, config *store.GridStrategyConfig) { + // Default: ±3% from current price, scaled by grid count + multiplier := 0.03 * float64(config.GridCount) / 10 + at.gridState.UpperPrice = price * (1 + multiplier) + at.gridState.LowerPrice = price * (1 - multiplier) +} + +// calculateATRBoundsLocked calculates bounds using ATR (caller must hold lock) +func (at *AutoTrader) calculateATRBoundsLocked(price float64, mktData *market.Data, config *store.GridStrategyConfig) { + atr := 0.0 + if mktData.LongerTermContext != nil { + atr = mktData.LongerTermContext.ATR14 + } + + if atr <= 0 { + at.calculateDefaultBoundsLocked(price, config) + return + } + + multiplier := config.ATRMultiplier + if multiplier <= 0 { + multiplier = 2.0 + } + + halfRange := atr * multiplier + at.gridState.UpperPrice = price + halfRange + at.gridState.LowerPrice = price - halfRange +} + +// initializeGridLevelsLocked creates the grid level structure (caller must hold lock) +func (at *AutoTrader) initializeGridLevelsLocked(currentPrice float64, config *store.GridStrategyConfig) { + levels := make([]kernel.GridLevelInfo, config.GridCount) + totalWeight := 0.0 + weights := make([]float64, config.GridCount) + + // Calculate weights based on distribution + for i := 0; i < config.GridCount; i++ { + switch config.Distribution { + case "gaussian": + // Gaussian distribution - more weight in the middle + center := float64(config.GridCount-1) / 2 + sigma := float64(config.GridCount) / 4 + weights[i] = math.Exp(-math.Pow(float64(i)-center, 2) / (2 * sigma * sigma)) + case "pyramid": + // Pyramid - more weight at bottom + weights[i] = float64(config.GridCount - i) + default: // uniform + weights[i] = 1.0 + } + totalWeight += weights[i] + } + + // Create levels + for i := 0; i < config.GridCount; i++ { + price := at.gridState.LowerPrice + float64(i)*at.gridState.GridSpacing + allocatedUSD := config.TotalInvestment * weights[i] / totalWeight + + // Determine initial side (below current price = buy, above = sell) + side := "buy" + if price > currentPrice { + side = "sell" + } + + levels[i] = kernel.GridLevelInfo{ + Index: i, + Price: price, + State: "empty", + Side: side, + AllocatedUSD: allocatedUSD, + } + } + + at.gridState.Levels = levels +} + +// GridRiskInfo contains risk information for frontend display +type GridRiskInfo struct { + CurrentLeverage int `json:"current_leverage"` + EffectiveLeverage float64 `json:"effective_leverage"` + RecommendedLeverage int `json:"recommended_leverage"` + + CurrentPosition float64 `json:"current_position"` + MaxPosition float64 `json:"max_position"` + PositionPercent float64 `json:"position_percent"` + + LiquidationPrice float64 `json:"liquidation_price"` + LiquidationDistance float64 `json:"liquidation_distance"` + + RegimeLevel string `json:"regime_level"` + + ShortBoxUpper float64 `json:"short_box_upper"` + ShortBoxLower float64 `json:"short_box_lower"` + MidBoxUpper float64 `json:"mid_box_upper"` + MidBoxLower float64 `json:"mid_box_lower"` + LongBoxUpper float64 `json:"long_box_upper"` + LongBoxLower float64 `json:"long_box_lower"` + CurrentPrice float64 `json:"current_price"` + + BreakoutLevel string `json:"breakout_level"` + BreakoutDirection string `json:"breakout_direction"` +} + +// GetGridRiskInfo returns current risk information for frontend display +func (at *AutoTrader) GetGridRiskInfo() *GridRiskInfo { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig == nil { + return &GridRiskInfo{} + } + + at.gridState.mu.RLock() + defer at.gridState.mu.RUnlock() + + // Get current price + currentPrice, _ := at.trader.GetMarketPrice(gridConfig.Symbol) + + // Calculate effective leverage + totalInvestment := gridConfig.TotalInvestment + leverage := gridConfig.Leverage + + // Get current position value + positions, _ := at.trader.GetPositions() + var currentPositionValue float64 + var currentPositionSize float64 + for _, pos := range positions { + if sym, _ := pos["symbol"].(string); sym == gridConfig.Symbol { + size, _ := pos["positionAmt"].(float64) + entry, _ := pos["entryPrice"].(float64) + currentPositionValue = math.Abs(size * entry) + currentPositionSize = size + break + } + } + + effectiveLeverage := 0.0 + if totalInvestment > 0 { + effectiveLeverage = currentPositionValue / totalInvestment + } + + // Calculate max position based on regime + regimeLevel := market.RegimeLevel(at.gridState.CurrentRegimeLevel) + if regimeLevel == "" { + regimeLevel = market.RegimeLevelStandard + } + + // Use default position limit since GridStrategyConfig doesn't have regime-specific limits + // Default is 70% for standard regime + maxPositionPct := 70.0 + switch regimeLevel { + case market.RegimeLevelNarrow: + maxPositionPct = 40.0 + case market.RegimeLevelStandard: + maxPositionPct = 70.0 + case market.RegimeLevelWide: + maxPositionPct = 60.0 + case market.RegimeLevelVolatile: + maxPositionPct = 40.0 + } + + maxPosition := totalInvestment * maxPositionPct / 100 * float64(leverage) + + // Use default leverage limits since GridStrategyConfig doesn't have regime-specific limits + recommendedLeverage := leverage + switch regimeLevel { + case market.RegimeLevelNarrow: + recommendedLeverage = min(leverage, 2) + case market.RegimeLevelStandard: + recommendedLeverage = min(leverage, 4) + case market.RegimeLevelWide: + recommendedLeverage = min(leverage, 3) + case market.RegimeLevelVolatile: + recommendedLeverage = min(leverage, 2) + } + + // Calculate liquidation distance and price only when there's a position + var liquidationDistance float64 + var liquidationPrice float64 + if currentPositionSize != 0 && currentPrice > 0 { + liquidationDistance = 100.0 / float64(leverage) * 0.9 // ~90% of theoretical max + if currentPositionSize > 0 { + // Long position: liquidation below entry + liquidationPrice = currentPrice * (1 - liquidationDistance/100) + } else { + // Short position: liquidation above entry + liquidationPrice = currentPrice * (1 + liquidationDistance/100) + } + } + + positionPercent := 0.0 + if maxPosition > 0 { + positionPercent = currentPositionValue / maxPosition * 100 + } + + return &GridRiskInfo{ + CurrentLeverage: leverage, + EffectiveLeverage: effectiveLeverage, + RecommendedLeverage: recommendedLeverage, + + CurrentPosition: currentPositionValue, + MaxPosition: maxPosition, + PositionPercent: positionPercent, + + LiquidationPrice: liquidationPrice, + LiquidationDistance: liquidationDistance, + + RegimeLevel: string(regimeLevel), + + ShortBoxUpper: at.gridState.ShortBoxUpper, + ShortBoxLower: at.gridState.ShortBoxLower, + MidBoxUpper: at.gridState.MidBoxUpper, + MidBoxLower: at.gridState.MidBoxLower, + LongBoxUpper: at.gridState.LongBoxUpper, + LongBoxLower: at.gridState.LongBoxLower, + CurrentPrice: currentPrice, + + BreakoutLevel: at.gridState.BreakoutLevel, + BreakoutDirection: at.gridState.BreakoutDirection, + } +} + +// checkAndExecuteStopLoss checks if any filled level has exceeded stop loss and closes it +func (at *AutoTrader) checkAndExecuteStopLoss() { + gridConfig := at.config.StrategyConfig.GridConfig + if gridConfig.StopLossPct <= 0 { + return // Stop loss not configured + } + + currentPrice, err := at.trader.GetMarketPrice(gridConfig.Symbol) + if err != nil { + logger.Warnf("[Grid] Failed to get market price for stop loss check: %v", err) + return + } + + at.gridState.mu.Lock() + defer at.gridState.mu.Unlock() + + for i := range at.gridState.Levels { + level := &at.gridState.Levels[i] + if level.State != "filled" || level.PositionEntry <= 0 { + continue + } + + // Calculate loss percentage + var lossPct float64 + if level.Side == "buy" { + // Long position: loss when price drops + lossPct = (level.PositionEntry - currentPrice) / level.PositionEntry * 100 + } else { + // Short position: loss when price rises + lossPct = (currentPrice - level.PositionEntry) / level.PositionEntry * 100 + } + + // Check if stop loss triggered + if lossPct >= gridConfig.StopLossPct { + logger.Warnf("[Grid] STOP LOSS TRIGGERED: Level %d, entry=$%.2f, current=$%.2f, loss=%.2f%%", + i, level.PositionEntry, currentPrice, lossPct) + + // Close the position + var closeErr error + if level.Side == "buy" { + _, closeErr = at.trader.CloseLong(gridConfig.Symbol, level.PositionSize) + } else { + _, closeErr = at.trader.CloseShort(gridConfig.Symbol, level.PositionSize) + } + + if closeErr != nil { + logger.Errorf("[Grid] Failed to execute stop loss for level %d: %v", i, closeErr) + } else { + level.State = "stopped" + realizedLoss := -lossPct * level.AllocatedUSD / 100 + level.UnrealizedPnL = realizedLoss + at.gridState.TotalTrades++ + // Update daily PnL tracking (lock already held, update directly) + at.gridState.DailyPnL += realizedLoss + at.gridState.TotalProfit += realizedLoss + logger.Infof("[Grid] Stop loss executed: Level %d closed at $%.2f (loss %.2f%%)", + i, currentPrice, lossPct) + } + } + } +} diff --git a/trader/binance_futures.go b/trader/binance_futures.go index 5a54db04..a7ef6dd0 100644 --- a/trader/binance_futures.go +++ b/trader/binance_futures.go @@ -716,6 +716,125 @@ func (t *FuturesTrader) CancelAllOrders(symbol string) error { return nil } +// PlaceLimitOrder places a limit order for grid trading +// This implements the GridTrader interface for FuturesTrader +func (t *FuturesTrader) PlaceLimitOrder(req *LimitOrderRequest) (*LimitOrderResult, error) { + // Format quantity to correct precision + quantityStr, err := t.FormatQuantity(req.Symbol, req.Quantity) + if err != nil { + return nil, fmt.Errorf("failed to format quantity: %w", err) + } + + // Format price to correct precision + priceStr, err := t.FormatPrice(req.Symbol, req.Price) + if err != nil { + return nil, fmt.Errorf("failed to format price: %w", err) + } + + // Set leverage if specified + if req.Leverage > 0 { + if err := t.SetLeverage(req.Symbol, req.Leverage); err != nil { + logger.Warnf("Failed to set leverage: %v", err) + } + } + + // Determine side and position side + var side futures.SideType + var positionSide futures.PositionSideType + + if req.Side == "BUY" { + side = futures.SideTypeBuy + positionSide = futures.PositionSideTypeLong + } else { + side = futures.SideTypeSell + positionSide = futures.PositionSideTypeShort + } + + // Build order service with broker ID + orderService := t.client.NewCreateOrderService(). + Symbol(req.Symbol). + Side(side). + PositionSide(positionSide). + Type(futures.OrderTypeLimit). + TimeInForce(futures.TimeInForceTypeGTC). + Quantity(quantityStr). + Price(priceStr). + NewClientOrderID(getBrOrderID()) + + // Execute order + order, err := orderService.Do(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to place limit order: %w", err) + } + + logger.Infof("✓ [Grid] Placed limit order: %s %s %s @ %s, qty=%s, orderID=%d", + req.Symbol, req.Side, positionSide, priceStr, quantityStr, order.OrderID) + + return &LimitOrderResult{ + OrderID: fmt.Sprintf("%d", order.OrderID), + ClientID: order.ClientOrderID, + Symbol: order.Symbol, + Side: string(order.Side), + PositionSide: string(order.PositionSide), + Price: req.Price, + Quantity: req.Quantity, + Status: string(order.Status), + }, nil +} + +// CancelOrder cancels a specific order by ID +// This implements the GridTrader interface for FuturesTrader +func (t *FuturesTrader) CancelOrder(symbol, orderID string) error { + // Parse order ID to int64 + orderIDInt, err := strconv.ParseInt(orderID, 10, 64) + if err != nil { + return fmt.Errorf("invalid order ID: %w", err) + } + + _, err = t.client.NewCancelOrderService(). + Symbol(symbol). + OrderID(orderIDInt). + Do(context.Background()) + + if err != nil { + return fmt.Errorf("failed to cancel order: %w", err) + } + + logger.Infof("✓ [Grid] Cancelled order: %s/%s", symbol, orderID) + return nil +} + +// GetOrderBook gets the order book for a symbol +// This implements the GridTrader interface for FuturesTrader +func (t *FuturesTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { + book, err := t.client.NewDepthService(). + Symbol(symbol). + Limit(depth). + Do(context.Background()) + + if err != nil { + return nil, nil, fmt.Errorf("failed to get order book: %w", err) + } + + // Convert bids + bids = make([][]float64, len(book.Bids)) + for i, bid := range book.Bids { + price, _ := strconv.ParseFloat(bid.Price, 64) + qty, _ := strconv.ParseFloat(bid.Quantity, 64) + bids[i] = []float64{price, qty} + } + + // Convert asks + asks = make([][]float64, len(book.Asks)) + for i, ask := range book.Asks { + price, _ := strconv.ParseFloat(ask.Price, 64) + qty, _ := strconv.ParseFloat(ask.Quantity, 64) + asks[i] = []float64{price, qty} + } + + return bids, asks, nil +} + // CancelStopOrders cancels take-profit/stop-loss orders for this symbol (used to adjust TP/SL positions) // Now uses both legacy API and new Algo Order API (Binance migrated stop orders to Algo system) func (t *FuturesTrader) CancelStopOrders(symbol string) error { @@ -1035,6 +1154,42 @@ func (t *FuturesTrader) FormatQuantity(symbol string, quantity float64) (string, return fmt.Sprintf(format, quantity), nil } +// GetSymbolPricePrecision gets the price precision for a trading pair +func (t *FuturesTrader) GetSymbolPricePrecision(symbol string) (int, error) { + exchangeInfo, err := t.client.NewExchangeInfoService().Do(context.Background()) + if err != nil { + return 0, fmt.Errorf("failed to get trading rules: %w", err) + } + + for _, s := range exchangeInfo.Symbols { + if s.Symbol == symbol { + // Get precision from PRICE_FILTER filter + for _, filter := range s.Filters { + if filter["filterType"] == "PRICE_FILTER" { + tickSize := filter["tickSize"].(string) + precision := calculatePrecision(tickSize) + return precision, nil + } + } + } + } + + // Default to 2 decimal places for price + return 2, nil +} + +// FormatPrice formats price to correct precision +func (t *FuturesTrader) FormatPrice(symbol string, price float64) (string, error) { + precision, err := t.GetSymbolPricePrecision(symbol) + if err != nil { + // If retrieval fails, use default format + return fmt.Sprintf("%.2f", price), nil + } + + format := fmt.Sprintf("%%.%df", precision) + return fmt.Sprintf(format, price), nil +} + // Helper functions func contains(s, substr string) bool { return len(s) >= len(substr) && stringContains(s, substr) diff --git a/trader/binance_sync_e2e_test.go b/trader/binance_sync_e2e_test.go index 9024b43b..91f42436 100644 --- a/trader/binance_sync_e2e_test.go +++ b/trader/binance_sync_e2e_test.go @@ -92,7 +92,7 @@ func TestBinanceSyncE2E(t *testing.T) { t.Logf(" [%d] %s %s %s qty=%.6f price=%.4f action=%s time=%s", i+1, order.ExchangeOrderID, order.Symbol, order.Side, order.Quantity, order.Price, order.OrderAction, - order.FilledAt.Format(time.RFC3339)) + time.UnixMilli(order.FilledAt).Format(time.RFC3339)) } } @@ -118,10 +118,11 @@ func TestBinanceSyncE2E(t *testing.T) { } // Test GetLastFillTimeByExchange - lastFillTime, err := orderStore.GetLastFillTimeByExchange(exchangeID) + lastFillTimeMs, err := orderStore.GetLastFillTimeByExchange(exchangeID) if err != nil { t.Logf(" ⚠️ GetLastFillTimeByExchange error: %v", err) } else { + lastFillTime := time.UnixMilli(lastFillTimeMs) t.Logf("\n📅 Last fill time from DB: %s", lastFillTime.Format(time.RFC3339)) // Check if it would be in the future (the bug we fixed) @@ -175,7 +176,7 @@ func TestBinanceSyncWithExistingData(t *testing.T) { Price: 50000, Quantity: 0.001, QuoteQuantity: 50, - CreatedAt: localTime, // This time is "in the future" if interpreted as UTC + CreatedAt: localTime.UnixMilli(), // This time is "in the future" if interpreted as UTC } if err := orderStore.CreateFill(fakeFill); err != nil { t.Fatalf("Failed to create fake fill: %v", err) @@ -186,10 +187,11 @@ func TestBinanceSyncWithExistingData(t *testing.T) { t.Logf(" Current UTC time: %s", time.Now().UTC().Format(time.RFC3339)) // Check GetLastFillTimeByExchange - lastFillTime, _ := orderStore.GetLastFillTimeByExchange(exchangeID) - t.Logf(" GetLastFillTimeByExchange returned: %s", lastFillTime.Format(time.RFC3339)) + lastFillTimeMs2, _ := orderStore.GetLastFillTimeByExchange(exchangeID) + lastFillTime2 := time.UnixMilli(lastFillTimeMs2) + t.Logf(" GetLastFillTimeByExchange returned: %s", lastFillTime2.Format(time.RFC3339)) - if lastFillTime.After(time.Now().UTC()) { + if lastFillTime2.After(time.Now().UTC()) { t.Logf(" ⚠️ Last fill time is in the future - this is the bug scenario!") } diff --git a/trader/bitget_trader.go b/trader/bitget_trader.go index 41f42f4a..8990df52 100644 --- a/trader/bitget_trader.go +++ b/trader/bitget_trader.go @@ -1099,6 +1099,240 @@ func genBitgetClientOid() string { // GetOpenOrders gets all open/pending orders for a symbol func (t *BitgetTrader) GetOpenOrders(symbol string) ([]OpenOrder, error) { - // TODO: Implement Bitget open orders - return []OpenOrder{}, nil + symbol = t.convertSymbol(symbol) + var result []OpenOrder + + // 1. Get pending limit orders + params := map[string]interface{}{ + "symbol": symbol, + "productType": "USDT-FUTURES", + } + + data, err := t.doRequest("GET", bitgetPendingPath, params) + if err != nil { + logger.Warnf("[Bitget] Failed to get pending orders: %v", err) + } + if err == nil && data != nil { + var orders struct { + EntrustedList []struct { + OrderId string `json:"orderId"` + Symbol string `json:"symbol"` + Side string `json:"side"` // buy/sell + TradeSide string `json:"tradeSide"` // open/close + PosSide string `json:"posSide"` // long/short + OrderType string `json:"orderType"` // limit/market + Price string `json:"price"` + Size string `json:"size"` + State string `json:"state"` + } `json:"entrustedList"` + } + if err := json.Unmarshal(data, &orders); err == nil { + for _, order := range orders.EntrustedList { + price, _ := strconv.ParseFloat(order.Price, 64) + quantity, _ := strconv.ParseFloat(order.Size, 64) + + // Convert side to standard format + side := strings.ToUpper(order.Side) + positionSide := strings.ToUpper(order.PosSide) + + result = append(result, OpenOrder{ + OrderID: order.OrderId, + Symbol: symbol, + Side: side, + PositionSide: positionSide, + Type: strings.ToUpper(order.OrderType), + Price: price, + StopPrice: 0, + Quantity: quantity, + Status: "NEW", + }) + } + } + } + + // 2. Get pending plan orders (stop-loss/take-profit) + planParams := map[string]interface{}{ + "symbol": symbol, + "productType": "USDT-FUTURES", + } + + planData, err := t.doRequest("GET", "/api/v2/mix/order/orders-plan-pending", planParams) + if err != nil { + logger.Warnf("[Bitget] Failed to get plan orders: %v", err) + } + if err == nil && planData != nil { + var planOrders struct { + EntrustedList []struct { + OrderId string `json:"orderId"` + Symbol string `json:"symbol"` + Side string `json:"side"` + PosSide string `json:"posSide"` + PlanType string `json:"planType"` // normal_plan/profit_plan/loss_plan + TriggerPrice string `json:"triggerPrice"` + Size string `json:"size"` + State string `json:"state"` + } `json:"entrustedList"` + } + if err := json.Unmarshal(planData, &planOrders); err == nil { + for _, order := range planOrders.EntrustedList { + triggerPrice, _ := strconv.ParseFloat(order.TriggerPrice, 64) + quantity, _ := strconv.ParseFloat(order.Size, 64) + + side := strings.ToUpper(order.Side) + positionSide := strings.ToUpper(order.PosSide) + + // Map Bitget plan type to order type + orderType := "STOP_MARKET" + if order.PlanType == "profit_plan" { + orderType = "TAKE_PROFIT_MARKET" + } + + result = append(result, OpenOrder{ + OrderID: order.OrderId, + Symbol: symbol, + Side: side, + PositionSide: positionSide, + Type: orderType, + Price: 0, + StopPrice: triggerPrice, + Quantity: quantity, + Status: "NEW", + }) + } + } + } + + logger.Infof("✓ BITGET GetOpenOrders: found %d open orders for %s", len(result), symbol) + return result, nil +} + +// PlaceLimitOrder places a limit order for grid trading +// Implements GridTrader interface +func (t *BitgetTrader) PlaceLimitOrder(req *LimitOrderRequest) (*LimitOrderResult, error) { + symbol := t.convertSymbol(req.Symbol) + + // Set leverage if specified + if req.Leverage > 0 { + if err := t.SetLeverage(symbol, req.Leverage); err != nil { + logger.Warnf("[Bitget] Failed to set leverage: %v", err) + } + } + + // Format quantity + qtyStr, _ := t.FormatQuantity(symbol, req.Quantity) + + // Determine side + side := "buy" + if req.Side == "SELL" { + side = "sell" + } + + body := map[string]interface{}{ + "symbol": symbol, + "productType": "USDT-FUTURES", + "marginMode": "crossed", + "marginCoin": "USDT", + "side": side, + "orderType": "limit", + "size": qtyStr, + "price": fmt.Sprintf("%.8f", req.Price), + "force": "GTC", // Good Till Cancel + "clientOid": genBitgetClientOid(), + } + + // Add reduce only if specified + if req.ReduceOnly { + body["reduceOnly"] = "YES" + } + + logger.Infof("[Bitget] PlaceLimitOrder: %s %s @ %.4f, qty=%s", symbol, side, req.Price, qtyStr) + + data, err := t.doRequest("POST", bitgetOrderPath, body) + if err != nil { + return nil, fmt.Errorf("failed to place limit order: %w", err) + } + + var order struct { + OrderId string `json:"orderId"` + ClientOid string `json:"clientOid"` + } + + if err := json.Unmarshal(data, &order); err != nil { + return nil, fmt.Errorf("failed to parse order response: %w", err) + } + + logger.Infof("✓ [Bitget] Limit order placed: %s %s @ %.4f, orderID=%s", + symbol, side, req.Price, order.OrderId) + + return &LimitOrderResult{ + OrderID: order.OrderId, + ClientID: order.ClientOid, + Symbol: req.Symbol, + Side: req.Side, + PositionSide: req.PositionSide, + Price: req.Price, + Quantity: req.Quantity, + Status: "NEW", + }, nil +} + +// CancelOrder cancels a specific order by ID +// Implements GridTrader interface +func (t *BitgetTrader) CancelOrder(symbol, orderID string) error { + symbol = t.convertSymbol(symbol) + + body := map[string]interface{}{ + "symbol": symbol, + "productType": "USDT-FUTURES", + "orderId": orderID, + } + + _, err := t.doRequest("POST", "/api/v2/mix/order/cancel-order", body) + if err != nil { + return fmt.Errorf("failed to cancel order: %w", err) + } + + logger.Infof("✓ [Bitget] Order cancelled: %s %s", symbol, orderID) + return nil +} + +// GetOrderBook gets the order book for a symbol +// Implements GridTrader interface +func (t *BitgetTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { + symbol = t.convertSymbol(symbol) + path := fmt.Sprintf("/api/v2/mix/market/depth?symbol=%s&productType=USDT-FUTURES&limit=%d", symbol, depth) + + data, err := t.doRequest("GET", path, nil) + if err != nil { + return nil, nil, fmt.Errorf("failed to get order book: %w", err) + } + + var result struct { + Bids [][]string `json:"bids"` + Asks [][]string `json:"asks"` + } + + if err := json.Unmarshal(data, &result); err != nil { + return nil, nil, fmt.Errorf("failed to parse order book: %w", err) + } + + // Parse bids + for _, b := range result.Bids { + if len(b) >= 2 { + price, _ := strconv.ParseFloat(b[0], 64) + qty, _ := strconv.ParseFloat(b[1], 64) + bids = append(bids, []float64{price, qty}) + } + } + + // Parse asks + for _, a := range result.Asks { + if len(a) >= 2 { + price, _ := strconv.ParseFloat(a[0], 64) + qty, _ := strconv.ParseFloat(a[1], 64) + asks = append(asks, []float64{price, qty}) + } + } + + return bids, asks, nil } diff --git a/trader/bybit_trader.go b/trader/bybit_trader.go index d40de870..745a58e4 100644 --- a/trader/bybit_trader.go +++ b/trader/bybit_trader.go @@ -1105,3 +1105,159 @@ func (t *BybitTrader) GetOpenOrders(symbol string) ([]OpenOrder, error) { return result, nil } + +// PlaceLimitOrder places a limit order for grid trading +// Implements GridTrader interface +func (t *BybitTrader) PlaceLimitOrder(req *LimitOrderRequest) (*LimitOrderResult, error) { + // Format quantity + qtyStr, err := t.FormatQuantity(req.Symbol, req.Quantity) + if err != nil { + return nil, fmt.Errorf("failed to format quantity: %w", err) + } + + // Format price + priceStr := fmt.Sprintf("%.8f", req.Price) + + // Set leverage if specified + if req.Leverage > 0 { + if err := t.SetLeverage(req.Symbol, req.Leverage); err != nil { + logger.Warnf("[Bybit] Failed to set leverage: %v", err) + } + } + + // Determine side + side := "Buy" + if req.Side == "SELL" { + side = "Sell" + } + + params := map[string]interface{}{ + "category": "linear", + "symbol": req.Symbol, + "side": side, + "orderType": "Limit", + "qty": qtyStr, + "price": priceStr, + "timeInForce": "GTC", // Good Till Cancel + "positionIdx": 0, // One-way position mode + } + + // Add reduce only if specified + if req.ReduceOnly { + params["reduceOnly"] = true + } + + logger.Infof("[Bybit] PlaceLimitOrder: %s %s @ %s, qty=%s", req.Symbol, side, priceStr, qtyStr) + + result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to place limit order: %w", err) + } + + // Parse result + orderID := "" + if result.RetCode == 0 { + if resultData, ok := result.Result.(map[string]interface{}); ok { + if id, ok := resultData["orderId"].(string); ok { + orderID = id + } + } + } else { + return nil, fmt.Errorf("Bybit order failed: %s", result.RetMsg) + } + + logger.Infof("✓ [Bybit] Limit order placed: %s %s @ %s, qty=%s, orderID=%s", + req.Symbol, side, priceStr, qtyStr, orderID) + + return &LimitOrderResult{ + OrderID: orderID, + ClientID: req.ClientID, + Symbol: req.Symbol, + Side: req.Side, + PositionSide: req.PositionSide, + Price: req.Price, + Quantity: req.Quantity, + Status: "NEW", + }, nil +} + +// CancelOrder cancels a specific order by ID +// Implements GridTrader interface +func (t *BybitTrader) CancelOrder(symbol, orderID string) error { + params := map[string]interface{}{ + "category": "linear", + "symbol": symbol, + "orderId": orderID, + } + + result, err := t.client.NewUtaBybitServiceWithParams(params).CancelOrder(context.Background()) + if err != nil { + return fmt.Errorf("failed to cancel order: %w", err) + } + + if result.RetCode != 0 { + return fmt.Errorf("Bybit cancel order failed: %s", result.RetMsg) + } + + logger.Infof("✓ [Bybit] Order cancelled: %s %s", symbol, orderID) + return nil +} + +// GetOrderBook gets the order book for a symbol +// Implements GridTrader interface +func (t *BybitTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { + if depth <= 0 { + depth = 25 + } + + // Use HTTP request directly since the SDK doesn't expose GetOrderbook + url := fmt.Sprintf("https://api.bybit.com/v5/market/orderbook?category=linear&symbol=%s&limit=%d", symbol, depth) + resp, err := http.Get(url) + if err != nil { + return nil, nil, fmt.Errorf("failed to get order book: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)) + } + + var result struct { + RetCode int `json:"retCode"` + RetMsg string `json:"retMsg"` + Result struct { + S string `json:"s"` // symbol + B [][]string `json:"b"` // bids [[price, size], ...] + A [][]string `json:"a"` // asks [[price, size], ...] + } `json:"result"` + } + + if err := json.Unmarshal(body, &result); err != nil { + return nil, nil, fmt.Errorf("failed to parse order book: %w", err) + } + + if result.RetCode != 0 { + return nil, nil, fmt.Errorf("Bybit get orderbook failed: %s", result.RetMsg) + } + + // Parse bids + for _, b := range result.Result.B { + if len(b) >= 2 { + price, _ := strconv.ParseFloat(b[0], 64) + qty, _ := strconv.ParseFloat(b[1], 64) + bids = append(bids, []float64{price, qty}) + } + } + + // Parse asks + for _, a := range result.Result.A { + if len(a) >= 2 { + price, _ := strconv.ParseFloat(a[0], 64) + qty, _ := strconv.ParseFloat(a[1], 64) + asks = append(asks, []float64{price, qty}) + } + } + + return bids, asks, nil +} diff --git a/trader/exchange_sync_test.go b/trader/exchange_sync_test.go index 7811d4cd..fc0c4866 100644 --- a/trader/exchange_sync_test.go +++ b/trader/exchange_sync_test.go @@ -141,7 +141,7 @@ func runStandardTests(t *testing.T, exchangeName string) { traderID, exchangeID, exchangeType, trade.Symbol, trade.Side, trade.Action, trade.Quantity, trade.Price, trade.Fee, trade.RealizedPnL, - time.Now().Add(time.Duration(i)*time.Second), + time.Now().Add(time.Duration(i)*time.Second).UnixMilli(), "", ) if err != nil { @@ -227,7 +227,7 @@ func TestPositionAccumulationBug(t *testing.T) { traderID, exchangeID, exchangeType, "ETHUSDT", "LONG", "open_long", 0.1, 3500+float64(i*10), 0.5, 0, - time.Now().Add(time.Duration(i*2)*time.Second), + time.Now().Add(time.Duration(i*2)*time.Second).UnixMilli(), "", ) if err != nil { @@ -239,7 +239,7 @@ func TestPositionAccumulationBug(t *testing.T) { traderID, exchangeID, exchangeType, "ETHUSDT", "LONG", "close_long", 0.1, 3600+float64(i*10), 0.5, 10, - time.Now().Add(time.Duration(i*2+1)*time.Second), + time.Now().Add(time.Duration(i*2+1)*time.Second).UnixMilli(), "", ) if err != nil { @@ -309,7 +309,7 @@ func TestQuantityPrecision(t *testing.T) { traderID, exchangeID, exchangeType, "BTCUSDT", "LONG", "open_long", 0.01, 50000, 1.0, 0, - time.Now(), + time.Now().UnixMilli(), "", ) if err != nil { @@ -322,7 +322,7 @@ func TestQuantityPrecision(t *testing.T) { traderID, exchangeID, exchangeType, "BTCUSDT", "LONG", "close_long", 0.00999999, 51000, 1.0, 10, - time.Now().Add(time.Second), + time.Now().Add(time.Second).UnixMilli(), "", ) if err != nil { diff --git a/trader/grid_regime.go b/trader/grid_regime.go new file mode 100644 index 00000000..e574cc1b --- /dev/null +++ b/trader/grid_regime.go @@ -0,0 +1,196 @@ +package trader + +import ( + "nofx/market" + "nofx/store" + "time" +) + +// ============================================================================ +// Task 6: Regime Level Classification +// ============================================================================ + +// classifyRegimeLevel determines the regime level based on market indicators +// bollingerWidth: Bollinger band width as percentage +// atr14Pct: ATR14 as percentage of current price +func classifyRegimeLevel(bollingerWidth, atr14Pct float64) market.RegimeLevel { + // Narrow: Bollinger < 2%, ATR < 1% + if bollingerWidth < 2.0 && atr14Pct < 1.0 { + return market.RegimeLevelNarrow + } + + // Standard: Bollinger 2-3%, ATR 1-2% + if bollingerWidth <= 3.0 && atr14Pct <= 2.0 { + return market.RegimeLevelStandard + } + + // Wide: Bollinger 3-4%, ATR 2-3% + if bollingerWidth <= 4.0 && atr14Pct <= 3.0 { + return market.RegimeLevelWide + } + + // Volatile: Bollinger > 4%, ATR > 3% + return market.RegimeLevelVolatile +} + +// getRegimeLeverageLimit returns the effective leverage limit for a regime level +func getRegimeLeverageLimit(level market.RegimeLevel, config *store.GridConfigModel) int { + switch level { + case market.RegimeLevelNarrow: + if config.NarrowRegimeLeverage > 0 { + return config.NarrowRegimeLeverage + } + return 2 + case market.RegimeLevelStandard: + if config.StandardRegimeLeverage > 0 { + return config.StandardRegimeLeverage + } + return 4 + case market.RegimeLevelWide: + if config.WideRegimeLeverage > 0 { + return config.WideRegimeLeverage + } + return 3 + case market.RegimeLevelVolatile: + if config.VolatileRegimeLeverage > 0 { + return config.VolatileRegimeLeverage + } + return 2 + default: + return 2 // Conservative default + } +} + +// getRegimePositionLimit returns the position limit percentage for a regime level +func getRegimePositionLimit(level market.RegimeLevel, config *store.GridConfigModel) float64 { + switch level { + case market.RegimeLevelNarrow: + if config.NarrowRegimePositionPct > 0 { + return config.NarrowRegimePositionPct + } + return 40.0 + case market.RegimeLevelStandard: + if config.StandardRegimePositionPct > 0 { + return config.StandardRegimePositionPct + } + return 70.0 + case market.RegimeLevelWide: + if config.WideRegimePositionPct > 0 { + return config.WideRegimePositionPct + } + return 60.0 + case market.RegimeLevelVolatile: + if config.VolatileRegimePositionPct > 0 { + return config.VolatileRegimePositionPct + } + return 40.0 + default: + return 40.0 // Conservative default + } +} + +// ============================================================================ +// Task 7: Breakout Detection +// ============================================================================ + +// detectBoxBreakout checks if price has broken out of any box level +// Returns the highest breakout level and direction +func detectBoxBreakout(box *market.BoxData) (market.BreakoutLevel, string) { + if box == nil { + return market.BreakoutNone, "" + } + + price := box.CurrentPrice + + // Check long box first (highest priority) + if price > box.LongUpper { + return market.BreakoutLong, "up" + } + if price < box.LongLower { + return market.BreakoutLong, "down" + } + + // Check mid box + if price > box.MidUpper { + return market.BreakoutMid, "up" + } + if price < box.MidLower { + return market.BreakoutMid, "down" + } + + // Check short box + if price > box.ShortUpper { + return market.BreakoutShort, "up" + } + if price < box.ShortLower { + return market.BreakoutShort, "down" + } + + return market.BreakoutNone, "" +} + +// ============================================================================ +// Task 8: Breakout Confirmation Logic +// ============================================================================ + +const BreakoutConfirmRequired = 3 // 3 candles to confirm breakout + +// BreakoutState tracks the current breakout state +type BreakoutState struct { + Level market.BreakoutLevel + Direction string + ConfirmCount int + StartTime time.Time +} + +// confirmBreakout updates breakout state and returns true if breakout is confirmed +func confirmBreakout(state *BreakoutState, currentLevel market.BreakoutLevel, direction string) bool { + // If price returned to box, reset state + if currentLevel == market.BreakoutNone { + state.ConfirmCount = 0 + state.Level = market.BreakoutNone + state.Direction = "" + return false + } + + // If same breakout continues, increment count + if state.Level == currentLevel && state.Direction == direction { + state.ConfirmCount++ + } else { + // New breakout, reset count + state.Level = currentLevel + state.Direction = direction + state.ConfirmCount = 1 + state.StartTime = time.Now() + } + + return state.ConfirmCount >= BreakoutConfirmRequired +} + +// ============================================================================ +// Task 9: Breakout Handler +// ============================================================================ + +// BreakoutAction represents the action to take on breakout +type BreakoutAction int + +const ( + BreakoutActionNone BreakoutAction = iota + BreakoutActionReducePosition // Short box breakout: reduce to 50% + BreakoutActionPauseGrid // Mid box breakout: pause grid + cancel orders + BreakoutActionCloseAll // Long box breakout: pause + cancel + close all +) + +// getBreakoutAction returns the appropriate action for a breakout level +func getBreakoutAction(level market.BreakoutLevel) BreakoutAction { + switch level { + case market.BreakoutShort: + return BreakoutActionReducePosition + case market.BreakoutMid: + return BreakoutActionPauseGrid + case market.BreakoutLong: + return BreakoutActionCloseAll + default: + return BreakoutActionNone + } +} diff --git a/trader/grid_regime_test.go b/trader/grid_regime_test.go new file mode 100644 index 00000000..25d0753a --- /dev/null +++ b/trader/grid_regime_test.go @@ -0,0 +1,122 @@ +package trader + +import ( + "nofx/market" + "testing" +) + +func TestClassifyRegimeLevel(t *testing.T) { + tests := []struct { + name string + bollingerWidth float64 + atr14Pct float64 + expected market.RegimeLevel + }{ + {"narrow", 1.5, 0.8, market.RegimeLevelNarrow}, + {"standard", 2.5, 1.5, market.RegimeLevelStandard}, + {"wide", 3.5, 2.5, market.RegimeLevelWide}, + {"volatile", 5.0, 4.0, market.RegimeLevelVolatile}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := classifyRegimeLevel(tt.bollingerWidth, tt.atr14Pct) + if result != tt.expected { + t.Errorf("Expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestDetectBoxBreakout(t *testing.T) { + box := &market.BoxData{ + ShortUpper: 100, + ShortLower: 90, + MidUpper: 105, + MidLower: 85, + LongUpper: 110, + LongLower: 80, + CurrentPrice: 95, + } + + // No breakout + level, direction := detectBoxBreakout(box) + if level != market.BreakoutNone { + t.Errorf("Expected no breakout, got %v", level) + } + + // Short breakout up + box.CurrentPrice = 101 + level, direction = detectBoxBreakout(box) + if level != market.BreakoutShort || direction != "up" { + t.Errorf("Expected short breakout up, got %v %v", level, direction) + } + + // Mid breakout down + box.CurrentPrice = 84 + level, direction = detectBoxBreakout(box) + if level != market.BreakoutMid || direction != "down" { + t.Errorf("Expected mid breakout down, got %v %v", level, direction) + } + + // Long breakout up + box.CurrentPrice = 112 + level, direction = detectBoxBreakout(box) + if level != market.BreakoutLong || direction != "up" { + t.Errorf("Expected long breakout up, got %v %v", level, direction) + } +} + +func TestBreakoutConfirmation(t *testing.T) { + state := &BreakoutState{ + Level: market.BreakoutNone, + Direction: "", + ConfirmCount: 0, + } + + // First detection + confirmed := confirmBreakout(state, market.BreakoutShort, "up") + if confirmed || state.ConfirmCount != 1 { + t.Errorf("Expected not confirmed, count=1, got confirmed=%v count=%d", confirmed, state.ConfirmCount) + } + + // Second confirmation + confirmed = confirmBreakout(state, market.BreakoutShort, "up") + if confirmed || state.ConfirmCount != 2 { + t.Errorf("Expected not confirmed, count=2, got confirmed=%v count=%d", confirmed, state.ConfirmCount) + } + + // Third confirmation - should confirm + confirmed = confirmBreakout(state, market.BreakoutShort, "up") + if !confirmed || state.ConfirmCount != 3 { + t.Errorf("Expected confirmed, count=3, got confirmed=%v count=%d", confirmed, state.ConfirmCount) + } + + // Reset on price return + state.ConfirmCount = 2 + confirmed = confirmBreakout(state, market.BreakoutNone, "") + if state.ConfirmCount != 0 { + t.Errorf("Expected count reset to 0, got %d", state.ConfirmCount) + } +} + +func TestGetBreakoutAction(t *testing.T) { + tests := []struct { + level market.BreakoutLevel + expected BreakoutAction + }{ + {market.BreakoutNone, BreakoutActionNone}, + {market.BreakoutShort, BreakoutActionReducePosition}, + {market.BreakoutMid, BreakoutActionPauseGrid}, + {market.BreakoutLong, BreakoutActionCloseAll}, + } + + for _, tt := range tests { + t.Run(string(tt.level), func(t *testing.T) { + action := getBreakoutAction(tt.level) + if action != tt.expected { + t.Errorf("Expected %v, got %v", tt.expected, action) + } + }) + } +} diff --git a/trader/hyperliquid_sync_test.go b/trader/hyperliquid_sync_test.go index fce8fdc8..4eb4bd81 100644 --- a/trader/hyperliquid_sync_test.go +++ b/trader/hyperliquid_sync_test.go @@ -103,7 +103,7 @@ func TestHyperliquidPositionBuilding(t *testing.T) { traderID, exchangeID, exchangeType, symbol, "LONG", "open_long", 0.1, 3500, 0.5, 0, - time.Now(), "order-1", + time.Now().UnixMilli(), "order-1", ) if err != nil { t.Fatalf("Failed to process open long: %v", err) @@ -126,7 +126,7 @@ func TestHyperliquidPositionBuilding(t *testing.T) { traderID, exchangeID, exchangeType, symbol, "LONG", "close_long", 0.1, 3600, 0.5, 10.0, // PnL = (3600-3500)*0.1 = 10 - time.Now(), "order-2", + time.Now().UnixMilli(), "order-2", ) if err != nil { t.Fatalf("Failed to process close long: %v", err) @@ -152,7 +152,7 @@ func TestHyperliquidPositionBuilding(t *testing.T) { traderID, exchangeID, exchangeType, symbol, "SHORT", "open_short", 0.05, 3500, 0.25, 0, - time.Now(), "order-3", + time.Now().UnixMilli(), "order-3", ) if err != nil { t.Fatalf("Failed to process open short: %v", err) @@ -176,7 +176,7 @@ func TestHyperliquidPositionBuilding(t *testing.T) { traderID, exchangeID, exchangeType, symbol, "SHORT", "close_short", 0.05, 3400, 0.25, 5.0, // PnL = (3500-3400)*0.05 = 5 - time.Now(), "order-4", + time.Now().UnixMilli(), "order-4", ) if err != nil { t.Fatalf("Failed to process close short: %v", err) @@ -205,7 +205,7 @@ func TestHyperliquidPositionBuilding(t *testing.T) { traderID, exchangeID, exchangeType, symbol, "LONG", "open_long", 0.1, 3500, 0.5, 0, - time.Now(), "order-5", + time.Now().UnixMilli(), "order-5", ) if err != nil { t.Fatalf("Failed to process first open: %v", err) @@ -216,7 +216,7 @@ func TestHyperliquidPositionBuilding(t *testing.T) { traderID, exchangeID, exchangeType, symbol, "LONG", "open_long", 0.1, 3600, 0.5, 0, - time.Now(), "order-6", + time.Now().UnixMilli(), "order-6", ) if err != nil { t.Fatalf("Failed to process add position: %v", err) @@ -243,7 +243,7 @@ func TestHyperliquidPositionBuilding(t *testing.T) { traderID, exchangeID, exchangeType, symbol, "LONG", "close_long", 0.2, 3700, 1.0, 30.0, - time.Now(), "order-7", + time.Now().UnixMilli(), "order-7", ) if err != nil { t.Fatalf("Failed to process close: %v", err) @@ -269,7 +269,7 @@ func TestHyperliquidPositionBuilding(t *testing.T) { traderID, exchangeID, exchangeType, symbol, "LONG", "open_long", 1.0, 3500, 2.0, 0, - time.Now(), "order-8", + time.Now().UnixMilli(), "order-8", ) if err != nil { t.Fatalf("Failed to process open: %v", err) @@ -280,7 +280,7 @@ func TestHyperliquidPositionBuilding(t *testing.T) { traderID, exchangeID, exchangeType, symbol, "LONG", "close_long", 0.3, 3600, 0.6, 30.0, - time.Now(), "order-9", + time.Now().UnixMilli(), "order-9", ) if err != nil { t.Fatalf("Failed to process partial close: %v", err) @@ -351,7 +351,7 @@ func TestHyperliquidBugScenario(t *testing.T) { traderID, exchangeID, exchangeType, trade.symbol, trade.side, trade.action, trade.qty, trade.price, trade.fee, trade.pnl, - time.Now().Add(time.Duration(i)*time.Second), + time.Now().Add(time.Duration(i)*time.Second).UnixMilli(), "", ) if err != nil { diff --git a/trader/hyperliquid_trader.go b/trader/hyperliquid_trader.go index 354acd99..aa89f39e 100644 --- a/trader/hyperliquid_trader.go +++ b/trader/hyperliquid_trader.go @@ -2114,3 +2114,118 @@ func (t *HyperliquidTrader) GetOpenOrders(symbol string) ([]OpenOrder, error) { return result, nil } + +// PlaceLimitOrder places a limit order for grid trading +// Implements GridTrader interface +func (t *HyperliquidTrader) PlaceLimitOrder(req *LimitOrderRequest) (*LimitOrderResult, error) { + coin := convertSymbolToHyperliquid(req.Symbol) + + // Set leverage if specified and not xyz dex + isXyz := strings.HasPrefix(coin, "xyz:") + if req.Leverage > 0 && !isXyz { + if err := t.SetLeverage(req.Symbol, req.Leverage); err != nil { + logger.Warnf("[Hyperliquid] Failed to set leverage: %v", err) + } + } + + // Round quantity to allowed decimals + roundedQuantity := t.roundToSzDecimals(coin, req.Quantity) + + // Round price to 5 significant figures + roundedPrice := t.roundPriceToSigfigs(req.Price) + + // Determine if buy or sell + isBuy := req.Side == "BUY" + + logger.Infof("[Hyperliquid] PlaceLimitOrder: %s %s @ %.4f, qty=%.4f", coin, req.Side, roundedPrice, roundedQuantity) + + order := hyperliquid.CreateOrderRequest{ + Coin: coin, + IsBuy: isBuy, + Size: roundedQuantity, + Price: roundedPrice, + OrderType: hyperliquid.OrderType{ + Limit: &hyperliquid.LimitOrderType{ + Tif: hyperliquid.TifGtc, // Good Till Cancel for grid orders + }, + }, + ReduceOnly: req.ReduceOnly, + } + + _, err := t.exchange.Order(t.ctx, order, defaultBuilder) + if err != nil { + return nil, fmt.Errorf("failed to place limit order: %w", err) + } + + // Note: Hyperliquid's Order response doesn't return the order ID directly + // We would need to query open orders to get it, but for grid trading + // we can track orders by price level instead + orderID := fmt.Sprintf("%d", time.Now().UnixNano()) + + logger.Infof("✓ [Hyperliquid] Limit order placed: %s %s @ %.4f", + coin, req.Side, roundedPrice) + + return &LimitOrderResult{ + OrderID: orderID, + ClientID: req.ClientID, + Symbol: req.Symbol, + Side: req.Side, + PositionSide: req.PositionSide, + Price: roundedPrice, + Quantity: roundedQuantity, + Status: "NEW", + }, nil +} + +// CancelOrder cancels a specific order by ID +// Implements GridTrader interface +func (t *HyperliquidTrader) CancelOrder(symbol, orderID string) error { + coin := convertSymbolToHyperliquid(symbol) + + // Parse order ID + oid, err := strconv.ParseInt(orderID, 10, 64) + if err != nil { + return fmt.Errorf("invalid order ID: %w", err) + } + + _, err = t.exchange.Cancel(t.ctx, coin, oid) + if err != nil { + return fmt.Errorf("failed to cancel order: %w", err) + } + + logger.Infof("✓ [Hyperliquid] Order cancelled: %s %s", symbol, orderID) + return nil +} + +// GetOrderBook gets the order book for a symbol +// Implements GridTrader interface +func (t *HyperliquidTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { + coin := convertSymbolToHyperliquid(symbol) + + l2Book, err := t.exchange.Info().L2Snapshot(t.ctx, coin) + if err != nil { + return nil, nil, fmt.Errorf("failed to get order book: %w", err) + } + + if l2Book == nil || len(l2Book.Levels) < 2 { + return nil, nil, fmt.Errorf("invalid order book data") + } + + // Parse bids (first level array) + for i, level := range l2Book.Levels[0] { + if i >= depth { + break + } + bids = append(bids, []float64{level.Px, level.Sz}) + } + + // Parse asks (second level array) + for i, level := range l2Book.Levels[1] { + if i >= depth { + break + } + asks = append(asks, []float64{level.Px, level.Sz}) + } + + return bids, asks, nil +} diff --git a/trader/interface.go b/trader/interface.go index 35618633..741e6e31 100644 --- a/trader/interface.go +++ b/trader/interface.go @@ -1,6 +1,10 @@ package trader -import "time" +import ( + "fmt" + "nofx/logger" + "time" +) // ClosedPnLRecord represents a single closed position record from exchange type ClosedPnLRecord struct { @@ -112,3 +116,115 @@ type OpenOrder struct { Quantity float64 `json:"quantity"` Status string `json:"status"` // NEW } + +// LimitOrderRequest represents a limit order request for grid trading +type LimitOrderRequest struct { + Symbol string `json:"symbol"` + Side string `json:"side"` // BUY/SELL + PositionSide string `json:"position_side"` // LONG/SHORT (for hedge mode) + Price float64 `json:"price"` // Limit price + Quantity float64 `json:"quantity"` + Leverage int `json:"leverage"` + PostOnly bool `json:"post_only"` // Maker only order + ReduceOnly bool `json:"reduce_only"` // Reduce position only + ClientID string `json:"client_id"` // Client order ID for tracking +} + +// LimitOrderResult represents the result of placing a limit order +type LimitOrderResult struct { + OrderID string `json:"order_id"` + ClientID string `json:"client_id"` + Symbol string `json:"symbol"` + Side string `json:"side"` + PositionSide string `json:"position_side"` + Price float64 `json:"price"` + Quantity float64 `json:"quantity"` + Status string `json:"status"` // NEW, PARTIALLY_FILLED, FILLED, CANCELED +} + +// GridTrader extends Trader interface with limit order support for grid trading +// Exchanges that support grid trading should implement this interface +type GridTrader interface { + Trader + + // PlaceLimitOrder places a limit order at specified price + // Returns order ID and status + PlaceLimitOrder(req *LimitOrderRequest) (*LimitOrderResult, error) + + // CancelOrder cancels a specific order by ID + CancelOrder(symbol, orderID string) error + + // GetOrderBook gets current order book (for price validation) + // Returns best bid/ask prices + GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) +} + +// GridTraderAdapter wraps a basic Trader to provide GridTrader interface +// Uses stop orders as a fallback when limit orders aren't directly available +type GridTraderAdapter struct { + Trader +} + +// NewGridTraderAdapter creates an adapter for basic Trader +func NewGridTraderAdapter(t Trader) *GridTraderAdapter { + return &GridTraderAdapter{Trader: t} +} + +// PlaceLimitOrder implements limit order using available methods +// For exchanges without native limit order support, this uses conditional orders +func (a *GridTraderAdapter) PlaceLimitOrder(req *LimitOrderRequest) (*LimitOrderResult, error) { + // CRITICAL FIX: Set leverage before placing order + if req.Leverage > 0 { + if err := a.Trader.SetLeverage(req.Symbol, req.Leverage); err != nil { + logger.Warnf("[Grid] Failed to set leverage %dx: %v", req.Leverage, err) + // Continue anyway - some exchanges don't require explicit leverage setting + } + } + + // Use SetStopLoss/SetTakeProfit as conditional limit orders + // For buy orders below current price, use stop-loss mechanism + // For sell orders above current price, use take-profit mechanism + var err error + if req.Side == "BUY" { + err = a.Trader.SetStopLoss(req.Symbol, "SHORT", req.Quantity, req.Price) + } else { + err = a.Trader.SetTakeProfit(req.Symbol, "LONG", req.Quantity, req.Price) + } + if err != nil { + return nil, err + } + return &LimitOrderResult{ + OrderID: req.ClientID, + ClientID: req.ClientID, + Symbol: req.Symbol, + Side: req.Side, + PositionSide: req.PositionSide, + Price: req.Price, + Quantity: req.Quantity, + Status: "NEW", + }, nil +} + +// CancelOrder cancels a specific order +func (a *GridTraderAdapter) CancelOrder(symbol, orderID string) error { + // Try to use CancelOrder if trader supports it directly + if canceler, ok := a.Trader.(interface { + CancelOrder(symbol, orderID string) error + }); ok { + return canceler.CancelOrder(symbol, orderID) + } + + // For traders that only support CancelAllOrders, log a warning + // This is a limitation - we cannot cancel individual orders + logger.Warnf("[Grid] Trader does not support individual order cancellation, "+ + "cannot cancel order %s. Consider using exchange-specific GridTrader implementation.", orderID) + + // Return error instead of canceling all orders + return fmt.Errorf("individual order cancellation not supported for this exchange") +} + +// GetOrderBook returns empty order book (not supported in basic Trader) +func (a *GridTraderAdapter) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { + // Not supported, return empty + return nil, nil, nil +} diff --git a/trader/lighter_integration_test.go b/trader/lighter_integration_test.go index c8c414e7..11281201 100644 --- a/trader/lighter_integration_test.go +++ b/trader/lighter_integration_test.go @@ -1,25 +1,41 @@ package trader import ( + "fmt" "os" "strings" "testing" "time" ) -// Test configuration - uses real account -// Run with: LIGHTER_TEST=1 go test -v ./trader -run TestLighter -timeout 120s -const ( - testWalletAddr = "" - testAPIKeyPrivateKey = "" - testAPIKeyIndex = 0 - testAccountIndex = int64(681514) -) +// Test configuration - uses environment variables for security +// Run with: +// LIGHTER_TEST=1 LIGHTER_WALLET=0x... LIGHTER_API_KEY=... LIGHTER_API_KEY_INDEX=2 go test -v ./trader -run TestLighter -timeout 300s +// Run with trading: +// LIGHTER_TEST=1 LIGHTER_TRADE_TEST=1 LIGHTER_WALLET=0x... LIGHTER_API_KEY=... go test -v ./trader -run TestLighter -timeout 300s + +// getTestConfig returns test configuration from environment variables +func getTestConfig() (walletAddr, apiKey string, apiKeyIndex int) { + walletAddr = os.Getenv("LIGHTER_WALLET") + apiKey = os.Getenv("LIGHTER_API_KEY") + // All credentials must be provided via environment variables for security + apiKeyIndex = 2 // Default to index 2 (more stable than index 0) + if idx := os.Getenv("LIGHTER_API_KEY_INDEX"); idx != "" { + fmt.Sscanf(idx, "%d", &apiKeyIndex) + } + return +} func skipIfNoEnv(t *testing.T) { if os.Getenv("LIGHTER_TEST") != "1" { t.Skip("Skipping Lighter integration test. Set LIGHTER_TEST=1 to run") } + if os.Getenv("LIGHTER_WALLET") == "" { + t.Skip("Skipping: LIGHTER_WALLET environment variable not set") + } + if os.Getenv("LIGHTER_API_KEY") == "" { + t.Skip("Skipping: LIGHTER_API_KEY environment variable not set") + } } // skipIfJurisdictionRestricted checks if error is due to geographic restriction @@ -31,7 +47,8 @@ func skipIfJurisdictionRestricted(t *testing.T, err error) { } func createTestTrader(t *testing.T) *LighterTraderV2 { - trader, err := NewLighterTraderV2(testWalletAddr, testAPIKeyPrivateKey, testAPIKeyIndex, false) + walletAddr, apiKey, apiKeyIndex := getTestConfig() + trader, err := NewLighterTraderV2(walletAddr, apiKey, apiKeyIndex, false) if err != nil { t.Fatalf("Failed to create trader: %v", err) } @@ -46,9 +63,9 @@ func TestLighterAccountInit(t *testing.T) { trader := createTestTrader(t) defer trader.Cleanup() - // Verify account index - if trader.accountIndex != testAccountIndex { - t.Errorf("Expected account index %d, got %d", testAccountIndex, trader.accountIndex) + // Verify account index is valid (non-zero) + if trader.accountIndex <= 0 { + t.Errorf("Expected valid account index, got %d", trader.accountIndex) } t.Logf("✅ Account initialized: index=%d", trader.accountIndex) @@ -253,11 +270,11 @@ func TestLighterCreateAndCancelLimitOrder(t *testing.T) { t.Fatalf("CreateOrder failed: %v", err) } - orderID, _ := result["order_id"].(string) + orderID, _ := result["orderId"].(string) t.Logf("✅ Order created: %s", orderID) if orderID == "" { - t.Fatal("Expected order ID in response") + t.Fatal("Expected orderId in response") } // Wait a moment for order to be processed @@ -517,11 +534,12 @@ func TestLighterOrderSync(t *testing.T) { // ==================== Benchmark Tests ==================== func BenchmarkLighterGetBalance(b *testing.B) { - if os.Getenv("LIGHTER_TEST") != "1" { - b.Skip("Skipping benchmark. Set LIGHTER_TEST=1 to run") + if os.Getenv("LIGHTER_TEST") != "1" || os.Getenv("LIGHTER_API_KEY") == "" { + b.Skip("Skipping benchmark. Set LIGHTER_TEST=1 and LIGHTER_API_KEY to run") } - trader, err := NewLighterTraderV2(testWalletAddr, testAPIKeyPrivateKey, testAPIKeyIndex, false) + walletAddr, apiKey, apiKeyIndex := getTestConfig() + trader, err := NewLighterTraderV2(walletAddr, apiKey, apiKeyIndex, false) if err != nil { b.Fatalf("Failed to create trader: %v", err) } @@ -537,11 +555,12 @@ func BenchmarkLighterGetBalance(b *testing.B) { } func BenchmarkLighterGetMarketPrice(b *testing.B) { - if os.Getenv("LIGHTER_TEST") != "1" { - b.Skip("Skipping benchmark. Set LIGHTER_TEST=1 to run") + if os.Getenv("LIGHTER_TEST") != "1" || os.Getenv("LIGHTER_API_KEY") == "" { + b.Skip("Skipping benchmark. Set LIGHTER_TEST=1 and LIGHTER_API_KEY to run") } - trader, err := NewLighterTraderV2(testWalletAddr, testAPIKeyPrivateKey, testAPIKeyIndex, false) + walletAddr, apiKey, apiKeyIndex := getTestConfig() + trader, err := NewLighterTraderV2(walletAddr, apiKey, apiKeyIndex, false) if err != nil { b.Fatalf("Failed to create trader: %v", err) } @@ -555,3 +574,533 @@ func BenchmarkLighterGetMarketPrice(b *testing.B) { } } } + +// ==================== GetOpenOrders Tests ==================== + +func TestLighterGetOpenOrders(t *testing.T) { + skipIfNoEnv(t) + + trader := createTestTrader(t) + defer trader.Cleanup() + + // Test GetOpenOrders + orders, err := trader.GetOpenOrders("ETH") + skipIfJurisdictionRestricted(t, err) + if err != nil { + t.Fatalf("GetOpenOrders failed: %v", err) + } + + t.Logf("✅ GetOpenOrders: found %d open orders", len(orders)) + for i, order := range orders { + if i >= 5 { + t.Logf(" ... and %d more", len(orders)-5) + break + } + t.Logf(" [%d] %s %s %s: qty=%.4f @ %.2f, status=%s", + i+1, order.Symbol, order.Side, order.Type, order.Quantity, order.Price, order.Status) + } +} + +func TestLighterGetActiveOrders(t *testing.T) { + skipIfNoEnv(t) + + trader := createTestTrader(t) + defer trader.Cleanup() + + // Test GetActiveOrders (internal API) + orders, err := trader.GetActiveOrders("ETH") + skipIfJurisdictionRestricted(t, err) + if err != nil { + t.Fatalf("GetActiveOrders failed: %v", err) + } + + t.Logf("✅ GetActiveOrders: found %d active orders", len(orders)) + for i, order := range orders { + if i >= 5 { + t.Logf(" ... and %d more", len(orders)-5) + break + } + t.Logf(" [%d] OrderID=%s, Type=%s, Price=%s, RemainingAmount=%s", + i+1, order.OrderID, order.Type, order.Price, order.RemainingBaseAmount) + } +} + +// ==================== OrderBook Tests ==================== + +func TestLighterGetOrderBook(t *testing.T) { + skipIfNoEnv(t) + + trader := createTestTrader(t) + defer trader.Cleanup() + + // Test GetOrderBook + bids, asks, err := trader.GetOrderBook("ETH", 10) + if err != nil { + // OrderBook API may not be available in all regions or require special permissions + if strings.Contains(err.Error(), "403") || strings.Contains(err.Error(), "restricted") { + t.Skipf("Skipping: OrderBook API not available: %v", err) + } + t.Fatalf("GetOrderBook failed: %v", err) + } + + t.Logf("✅ GetOrderBook: %d bids, %d asks", len(bids), len(asks)) + + if len(bids) > 0 { + t.Logf(" Best Bid: %.2f @ %.4f", bids[0][0], bids[0][1]) + } + if len(asks) > 0 { + t.Logf(" Best Ask: %.2f @ %.4f", asks[0][0], asks[0][1]) + } + + // Verify spread makes sense + if len(bids) > 0 && len(asks) > 0 { + spread := asks[0][0] - bids[0][0] + spreadPct := spread / bids[0][0] * 100 + t.Logf(" Spread: %.2f (%.4f%%)", spread, spreadPct) + + if spread < 0 { + t.Error("Invalid spread: ask < bid") + } + } +} + +// ==================== PlaceLimitOrder (GridTrader) Tests ==================== + +func TestLighterPlaceLimitOrder(t *testing.T) { + skipIfNoEnv(t) + + trader := createTestTrader(t) + defer trader.Cleanup() + + // Get current market price + marketPrice, err := trader.GetMarketPrice("ETH") + if err != nil { + t.Fatalf("Failed to get market price: %v", err) + } + t.Logf("Current ETH price: %.2f", marketPrice) + + // Create a limit order using PlaceLimitOrder (GridTrader interface) + // Buy order at 75% of market price (won't fill) + limitPrice := marketPrice * 0.75 + quantity := 0.01 + + req := &LimitOrderRequest{ + Symbol: "ETH", + Side: "BUY", + PositionSide: "LONG", + Price: limitPrice, + Quantity: quantity, + Leverage: 10, + ClientID: "test-order-001", + ReduceOnly: false, + } + + t.Logf("Placing limit order via PlaceLimitOrder: %s %.4f @ %.2f", req.Side, req.Quantity, req.Price) + + result, err := trader.PlaceLimitOrder(req) + skipIfJurisdictionRestricted(t, err) + if err != nil { + t.Fatalf("PlaceLimitOrder failed: %v", err) + } + + t.Logf("✅ PlaceLimitOrder result: OrderID=%s, Status=%s", result.OrderID, result.Status) + + if result.OrderID == "" { + t.Fatal("Expected OrderID in result") + } + + // Wait and cancel + time.Sleep(3 * time.Second) + + // Cancel the order + err = trader.CancelOrder("ETH", result.OrderID) + if err != nil { + t.Logf("⚠️ Failed to cancel order: %v", err) + } else { + t.Log("✅ Order cancelled successfully") + } +} + +// ==================== SetMarginMode Tests ==================== + +func TestLighterSetMarginMode(t *testing.T) { + skipIfNoEnv(t) + + trader := createTestTrader(t) + defer trader.Cleanup() + + // Test setting cross margin + t.Log("Setting margin mode to CROSS...") + err := trader.SetMarginMode("ETH", true) + skipIfJurisdictionRestricted(t, err) + if err != nil { + t.Errorf("SetMarginMode(cross) failed: %v", err) + } else { + t.Log("✅ SetMarginMode(cross) succeeded") + } + + time.Sleep(2 * time.Second) + + // Note: Isolated margin may fail if there's an open position + // Just test cross margin for safety +} + +// ==================== Stop-Loss/Take-Profit Tests ==================== + +func TestLighterStopLossOrder(t *testing.T) { + skipIfNoEnv(t) + + if os.Getenv("LIGHTER_TRADE_TEST") != "1" { + t.Skip("Skipping stop-loss test. Set LIGHTER_TRADE_TEST=1 to run") + } + + trader := createTestTrader(t) + defer trader.Cleanup() + + // Check if we have a position first + pos, err := trader.GetPosition("ETH") + if err != nil { + t.Fatalf("GetPosition failed: %v", err) + } + + if pos == nil || pos.Size == 0 { + t.Skip("No ETH position to set stop-loss for") + } + + // Calculate stop-loss price (5% below entry for long, 5% above for short) + var stopPrice float64 + if pos.Side == "long" { + stopPrice = pos.EntryPrice * 0.95 + } else { + stopPrice = pos.EntryPrice * 1.05 + } + + t.Logf("Position: %s %s, size=%.4f, entry=%.2f", pos.Symbol, pos.Side, pos.Size, pos.EntryPrice) + t.Logf("Setting stop-loss at %.2f", stopPrice) + + err = trader.SetStopLoss("ETH", strings.ToUpper(pos.Side), pos.Size, stopPrice) + skipIfJurisdictionRestricted(t, err) + if err != nil { + t.Errorf("SetStopLoss failed: %v", err) + } else { + t.Log("✅ SetStopLoss succeeded") + } +} + +func TestLighterTakeProfitOrder(t *testing.T) { + skipIfNoEnv(t) + + if os.Getenv("LIGHTER_TRADE_TEST") != "1" { + t.Skip("Skipping take-profit test. Set LIGHTER_TRADE_TEST=1 to run") + } + + trader := createTestTrader(t) + defer trader.Cleanup() + + // Check if we have a position first + pos, err := trader.GetPosition("ETH") + if err != nil { + t.Fatalf("GetPosition failed: %v", err) + } + + if pos == nil || pos.Size == 0 { + t.Skip("No ETH position to set take-profit for") + } + + // Calculate take-profit price (10% above entry for long, 10% below for short) + var takeProfitPrice float64 + if pos.Side == "long" { + takeProfitPrice = pos.EntryPrice * 1.10 + } else { + takeProfitPrice = pos.EntryPrice * 0.90 + } + + t.Logf("Position: %s %s, size=%.4f, entry=%.2f", pos.Symbol, pos.Side, pos.Size, pos.EntryPrice) + t.Logf("Setting take-profit at %.2f", takeProfitPrice) + + err = trader.SetTakeProfit("ETH", strings.ToUpper(pos.Side), pos.Size, takeProfitPrice) + skipIfJurisdictionRestricted(t, err) + if err != nil { + t.Errorf("SetTakeProfit failed: %v", err) + } else { + t.Log("✅ SetTakeProfit succeeded") + } +} + +// ==================== Full Trading Flow Tests ==================== + +func TestLighterFullTradingFlow(t *testing.T) { + skipIfNoEnv(t) + + if os.Getenv("LIGHTER_TRADE_TEST") != "1" { + t.Skip("Skipping full trading flow test. Set LIGHTER_TRADE_TEST=1 to run") + } + + trader := createTestTrader(t) + defer trader.Cleanup() + + symbol := "ETH" + quantity := 0.01 // Minimum quantity + leverage := 10 + + // Step 1: Get initial state + t.Log("=== Step 1: Get Initial State ===") + balance, _ := trader.GetBalance() + if equity, ok := balance["total_equity"].(float64); ok { + t.Logf(" Initial equity: %.2f", equity) + } + + marketPrice, err := trader.GetMarketPrice(symbol) + if err != nil { + t.Fatalf("Failed to get market price: %v", err) + } + t.Logf(" Market price: %.2f", marketPrice) + + // Step 2: Set leverage + t.Log("=== Step 2: Set Leverage ===") + err = trader.SetLeverage(symbol, leverage) + skipIfJurisdictionRestricted(t, err) + if err != nil { + t.Fatalf("SetLeverage failed: %v", err) + } + t.Logf(" Leverage set to %dx", leverage) + time.Sleep(2 * time.Second) + + // Step 3: Open Long Position + t.Log("=== Step 3: Open Long Position ===") + result, err := trader.OpenLong(symbol, quantity, leverage) + skipIfJurisdictionRestricted(t, err) + if err != nil { + t.Fatalf("OpenLong failed: %v", err) + } + t.Logf(" OpenLong result: %v", result) + time.Sleep(3 * time.Second) + + // Step 4: Verify position + t.Log("=== Step 4: Verify Position ===") + pos, err := trader.GetPosition(symbol) + if err != nil { + t.Errorf("GetPosition failed: %v", err) + } else if pos != nil { + t.Logf(" Position: %s %s, size=%.4f, entry=%.2f, pnl=%.2f", + pos.Symbol, pos.Side, pos.Size, pos.EntryPrice, pos.UnrealizedPnL) + } + + // Step 5: Place limit order (sell at higher price) + t.Log("=== Step 5: Place Limit Sell Order ===") + limitPrice := marketPrice * 1.05 // 5% above market + limitResult, err := trader.CreateOrder(symbol, true, quantity, limitPrice, "limit", true) + if err != nil { + t.Logf(" Failed to place limit order: %v", err) + } else { + t.Logf(" Limit order placed: %v", limitResult) + } + time.Sleep(2 * time.Second) + + // Step 6: Get open orders + t.Log("=== Step 6: Get Open Orders ===") + orders, err := trader.GetOpenOrders(symbol) + if err != nil { + t.Logf(" Failed to get open orders: %v", err) + } else { + t.Logf(" Open orders: %d", len(orders)) + for _, o := range orders { + t.Logf(" - %s %s: qty=%.4f @ %.2f", o.Side, o.Type, o.Quantity, o.Price) + } + } + + // Step 7: Cancel all orders + t.Log("=== Step 7: Cancel All Orders ===") + err = trader.CancelAllOrders(symbol) + if err != nil { + t.Logf(" Failed to cancel orders: %v", err) + } else { + t.Log(" All orders cancelled") + } + time.Sleep(2 * time.Second) + + // Step 8: Close position + t.Log("=== Step 8: Close Position ===") + closeResult, err := trader.CloseLong(symbol, 0) // 0 = close all + if err != nil { + t.Errorf("CloseLong failed: %v", err) + } else { + t.Logf(" CloseLong result: %v", closeResult) + } + time.Sleep(3 * time.Second) + + // Step 9: Verify position closed + t.Log("=== Step 9: Verify Position Closed ===") + pos, _ = trader.GetPosition(symbol) + if pos == nil || pos.Size == 0 { + t.Log(" ✅ Position closed successfully") + } else { + t.Logf(" ⚠️ Position still exists: size=%.4f", pos.Size) + } + + // Step 10: Get final balance + t.Log("=== Step 10: Get Final State ===") + balance, _ = trader.GetBalance() + if equity, ok := balance["total_equity"].(float64); ok { + t.Logf(" Final equity: %.2f", equity) + } + + t.Log("=== Full Trading Flow Completed ===") +} + +// ==================== API Key Validation Tests ==================== + +func TestLighterAPIKeyValid(t *testing.T) { + skipIfNoEnv(t) + + trader := createTestTrader(t) + defer trader.Cleanup() + + // Check if API key is valid + if trader.apiKeyValid { + t.Log("✅ API key is VALID and matches server") + } else { + t.Error("❌ API key is INVALID - does not match server") + } + + // Verify by checking the actual API key + err := trader.checkClient() + if err != nil { + t.Errorf("API key verification error: %v", err) + } else { + t.Log("✅ API key verification passed") + } +} + +// ==================== Market Order Tests ==================== + +func TestLighterMarketOrderBuy(t *testing.T) { + skipIfNoEnv(t) + + if os.Getenv("LIGHTER_TRADE_TEST") != "1" { + t.Skip("Skipping market order test. Set LIGHTER_TRADE_TEST=1 to run") + } + + trader := createTestTrader(t) + defer trader.Cleanup() + + // Create a small market buy order + quantity := 0.01 + t.Logf("Creating market buy order: %.4f ETH", quantity) + + result, err := trader.CreateOrder("ETH", false, quantity, 0, "market", false) + skipIfJurisdictionRestricted(t, err) + if err != nil { + t.Fatalf("Market buy failed: %v", err) + } + + t.Logf("✅ Market buy result: %v", result) + + // Wait and close + time.Sleep(3 * time.Second) + + // Close the position + _, err = trader.CloseLong("ETH", quantity) + if err != nil { + t.Logf("⚠️ Failed to close position: %v", err) + } else { + t.Log("✅ Position closed") + } +} + +func TestLighterMarketOrderSell(t *testing.T) { + skipIfNoEnv(t) + + if os.Getenv("LIGHTER_TRADE_TEST") != "1" { + t.Skip("Skipping market order test. Set LIGHTER_TRADE_TEST=1 to run") + } + + trader := createTestTrader(t) + defer trader.Cleanup() + + // Create a small market sell order (short) + quantity := 0.01 + t.Logf("Creating market sell order (short): %.4f ETH", quantity) + + result, err := trader.CreateOrder("ETH", true, quantity, 0, "market", false) + skipIfJurisdictionRestricted(t, err) + if err != nil { + t.Fatalf("Market sell failed: %v", err) + } + + t.Logf("✅ Market sell result: %v", result) + + // Wait and close + time.Sleep(3 * time.Second) + + // Close the position + _, err = trader.CloseShort("ETH", quantity) + if err != nil { + t.Logf("⚠️ Failed to close position: %v", err) + } else { + t.Log("✅ Position closed") + } +} + +// ==================== GetPosition Tests ==================== + +func TestLighterGetPosition(t *testing.T) { + skipIfNoEnv(t) + + trader := createTestTrader(t) + defer trader.Cleanup() + + // Test GetPosition for ETH + pos, err := trader.GetPosition("ETH") + if err != nil { + t.Fatalf("GetPosition failed: %v", err) + } + + if pos == nil { + t.Log("✅ No ETH position (pos is nil)") + } else if pos.Size == 0 { + t.Log("✅ No ETH position (size is 0)") + } else { + t.Logf("✅ ETH position found:") + t.Logf(" Symbol: %s", pos.Symbol) + t.Logf(" Side: %s", pos.Side) + t.Logf(" Size: %.4f", pos.Size) + t.Logf(" Entry Price: %.2f", pos.EntryPrice) + t.Logf(" Mark Price: %.2f", pos.MarkPrice) + t.Logf(" Liquidation Price: %.2f", pos.LiquidationPrice) + t.Logf(" Unrealized PnL: %.2f", pos.UnrealizedPnL) + t.Logf(" Leverage: %.1fx", pos.Leverage) + } +} + +// ==================== Symbol Normalization Tests ==================== + +func TestLighterSymbolNormalization(t *testing.T) { + skipIfNoEnv(t) + + trader := createTestTrader(t) + defer trader.Cleanup() + + // Test different symbol formats + testCases := []struct { + input string + expected string + }{ + {"ETH", "ETH"}, + {"ETH-PERP", "ETH"}, + {"ETHUSDT", "ETH"}, + {"ETH/USDT", "ETH"}, + {"BTC", "BTC"}, + {"BTCUSDT", "BTC"}, + } + + for _, tc := range testCases { + // Try to get market price with different formats + price, err := trader.GetMarketPrice(tc.input) + if err != nil { + t.Logf("⚠️ GetMarketPrice(%s) failed: %v", tc.input, err) + } else { + t.Logf("✅ GetMarketPrice(%s) = %.2f", tc.input, price) + } + } +} diff --git a/trader/lighter_trader_v2.go b/trader/lighter_trader_v2.go index 6abdf405..60b11570 100644 --- a/trader/lighter_trader_v2.go +++ b/trader/lighter_trader_v2.go @@ -74,6 +74,7 @@ type LighterTraderV2 struct { apiKeyPrivateKey string // 40-byte API Key private key (for signing transactions) apiKeyIndex uint8 // API Key index (default 0) accountIndex int64 // Account index + apiKeyValid bool // Whether API key has been validated against server // Authentication token authToken string @@ -85,8 +86,10 @@ type LighterTraderV2 struct { precisionMutex sync.RWMutex // Market index cache - marketIndexMap map[string]uint16 // symbol -> market_id - marketMutex sync.RWMutex + marketIndexMap map[string]uint16 // symbol -> market_id + marketMutex sync.RWMutex + marketListCache []MarketInfo // Cached market list + marketListCacheTime time.Time // Time when cache was populated } // NewLighterTraderV2 Create new LIGHTER trader (using official SDK) @@ -127,9 +130,6 @@ func NewLighterTraderV2(walletAddr, apiKeyPrivateKeyHex string, apiKeyIndex int, walletAddr: walletAddr, client: &http.Client{ Timeout: 30 * time.Second, - Transport: &http.Transport{ - Proxy: nil, // Disable proxy for direct connection to Lighter API - }, }, baseURL: baseURL, testnet: testnet, @@ -162,14 +162,18 @@ func NewLighterTraderV2(walletAddr, apiKeyPrivateKeyHex string, apiKeyIndex int, // 7. Verify API Key is correct if err := trader.checkClient(); err != nil { - logger.Warnf("⚠️ API Key verification failed: %v", err) - logger.Warnf("⚠️ The API key may not be registered on-chain. Authenticated API calls (like GetTrades) will fail.") - logger.Warnf("⚠️ To fix: Register this API key using change_api_key transaction from app.lighter.xyz") - // Don't fail here, allow trader to continue (may work with some operations) + trader.apiKeyValid = false + logger.Warnf("⚠️ API Key verification FAILED: %v", err) + logger.Warnf("⚠️ ❌ The API key stored in NOFX does NOT match the API key registered on Lighter.") + logger.Warnf("⚠️ ❌ ALL trading operations (open/close positions, cancel orders) WILL FAIL with 'invalid signature' error.") + logger.Warnf("⚠️ 🔧 To fix: Update your Lighter API key in NOFX Exchange settings with the correct key from app.lighter.xyz") + // Don't fail here, allow trader to continue for read operations (balance, positions) + } else { + trader.apiKeyValid = true } - logger.Infof("✓ LIGHTER trader initialized successfully (account=%d, apiKey=%d, testnet=%v)", - trader.accountIndex, trader.apiKeyIndex, testnet) + logger.Infof("✓ LIGHTER trader initialized (account=%d, apiKey=%d, testnet=%v, apiKeyValid=%v)", + trader.accountIndex, trader.apiKeyIndex, testnet, trader.apiKeyValid) return trader, nil } @@ -212,7 +216,7 @@ func (t *LighterTraderV2) getAccountByL1Address() (*AccountInfo, error) { } // Log raw response for debugging - logger.Infof("LIGHTER account API response: %s", string(body)) + logger.Debugf("LIGHTER account API response: %s", string(body)) if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("failed to get account (status %d): %s", resp.StatusCode, string(body)) @@ -238,10 +242,10 @@ func (t *LighterTraderV2) getAccountByL1Address() (*AccountInfo, error) { return nil, fmt.Errorf("no account found for wallet address: %s (try depositing funds first at app.lighter.xyz)", t.walletAddr) } - // Log all found accounts - logger.Infof("Found %d accounts (main: %d, sub: %d)", len(allAccounts), len(accountResp.Accounts), len(accountResp.SubAccounts)) + // Log account summary + logger.Infof("Found %d account(s) (main: %d, sub: %d)", len(allAccounts), len(accountResp.Accounts), len(accountResp.SubAccounts)) for i, acc := range allAccounts { - logger.Infof(" Account[%d]: index=%d, collateral=%s", i, acc.AccountIndex, acc.Collateral) + logger.Debugf(" Account[%d]: index=%d, collateral=%s", i, acc.AccountIndex, acc.Collateral) } account := &allAccounts[0] @@ -253,26 +257,79 @@ func (t *LighterTraderV2) getAccountByL1Address() (*AccountInfo, error) { return account, nil } +// ApiKeyResponse API key query response +type ApiKeyResponse struct { + Code int `json:"code"` + ApiKeys []struct { + AccountIndex int64 `json:"account_index"` + ApiKeyIndex uint8 `json:"api_key_index"` + Nonce int64 `json:"nonce"` + PublicKey string `json:"public_key"` + } `json:"api_keys"` +} + +// getApiKeyFromServer Get API Key public key from Lighter server +// Uses our own HTTP client instead of SDK's global client to avoid connection issues +func (t *LighterTraderV2) getApiKeyFromServer() (string, error) { + endpoint := fmt.Sprintf("%s/api/v1/apikeys?account_index=%d&api_key_index=%d", + t.baseURL, t.accountIndex, t.apiKeyIndex) + + req, err := http.NewRequest("GET", endpoint, nil) + if err != nil { + return "", err + } + + resp, err := t.client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) + } + + var result ApiKeyResponse + if err := json.Unmarshal(body, &result); err != nil { + return "", fmt.Errorf("failed to parse response: %w", err) + } + + if result.Code != 200 { + return "", fmt.Errorf("API error (code %d)", result.Code) + } + + if len(result.ApiKeys) == 0 { + return "", fmt.Errorf("no API keys found for account %d", t.accountIndex) + } + + return result.ApiKeys[0].PublicKey, nil +} + // checkClient Verify if API Key is correct func (t *LighterTraderV2) checkClient() error { if t.txClient == nil { return fmt.Errorf("TxClient not initialized") } - // Get API Key public key registered on server - publicKey, err := t.httpClient.GetApiKey(t.accountIndex, t.apiKeyIndex) + // Get API Key public key registered on server (using our own HTTP client) + serverPubKey, err := t.getApiKeyFromServer() if err != nil { return fmt.Errorf("failed to get API Key: %w", err) } - // Get local API Key public key + // Get local API Key public key from SDK pubKeyBytes := t.txClient.GetKeyManager().PubKeyBytes() localPubKey := hexutil.Encode(pubKeyBytes[:]) - localPubKey = strings.Replace(localPubKey, "0x", "", 1) + localPubKey = strings.TrimPrefix(localPubKey, "0x") // Compare public keys - if publicKey != localPubKey { - return fmt.Errorf("API Key mismatch: local=%s, server=%s", localPubKey, publicKey) + if serverPubKey != localPubKey { + return fmt.Errorf("API Key mismatch: local=%s, server=%s", localPubKey, serverPubKey) } logger.Infof("✓ API Key verification passed") @@ -436,12 +493,8 @@ func (t *LighterTraderV2) GetTrades(startTime time.Time, limit int) ([]TradeReco return []TradeRecord{}, nil } - // Debug: log raw response (first 500 chars) - logBody := string(body) - if len(logBody) > 500 { - logBody = logBody[:500] + "..." - } - logger.Infof("📋 Lighter trades API raw response: %s", logBody) + // Debug: log raw response + logger.Debugf("Lighter trades API response: %s", string(body)) var response LighterTradeResponse if err := json.Unmarshal(body, &response); err != nil { diff --git a/trader/lighter_trader_v2_account.go b/trader/lighter_trader_v2_account.go index b9c84c18..b2865c32 100644 --- a/trader/lighter_trader_v2_account.go +++ b/trader/lighter_trader_v2_account.go @@ -11,6 +11,7 @@ import ( ) // getFullAccountInfo Fetch full account info from Lighter API (includes balance and positions) +// Supports both main accounts and sub-accounts func (t *LighterTraderV2) getFullAccountInfo() (*AccountInfo, error) { endpoint := fmt.Sprintf("%s/api/v1/account?by=l1_address&value=%s", t.baseURL, t.walletAddr) @@ -34,20 +35,47 @@ func (t *LighterTraderV2) getFullAccountInfo() (*AccountInfo, error) { return nil, fmt.Errorf("failed to get account (status %d): %s", resp.StatusCode, string(body)) } - // Parse response - Lighter returns {"accounts": [...]} + // Parse response - Lighter may return accounts in "accounts" or "sub_accounts" field var accountResp AccountResponse if err := json.Unmarshal(body, &accountResp); err != nil { return nil, fmt.Errorf("failed to parse account response: %w", err) } - if len(accountResp.Accounts) == 0 { - return nil, fmt.Errorf("no account found for wallet address: %s", t.walletAddr) + // Check for API error code + if accountResp.Code != 0 && accountResp.Code != 200 { + return nil, fmt.Errorf("Lighter API error (code %d): %s", accountResp.Code, accountResp.Message) } - account := &accountResp.Accounts[0] - // Use index field if account_index is 0 - if account.AccountIndex == 0 && account.Index != 0 { - account.AccountIndex = account.Index + // Combine both accounts and sub_accounts - some users have sub-accounts + var allAccounts []AccountInfo + allAccounts = append(allAccounts, accountResp.Accounts...) + allAccounts = append(allAccounts, accountResp.SubAccounts...) + + if len(allAccounts) == 0 { + return nil, fmt.Errorf("no account found for wallet address: %s (try depositing funds first at app.lighter.xyz)", t.walletAddr) + } + + // Find the account that matches our stored accountIndex, or use the first one + var account *AccountInfo + for i := range allAccounts { + acc := &allAccounts[i] + // Use index field if account_index is 0 + if acc.AccountIndex == 0 && acc.Index != 0 { + acc.AccountIndex = acc.Index + } + // Match by stored accountIndex if we have one + if t.accountIndex != 0 && acc.AccountIndex == t.accountIndex { + account = acc + break + } + } + + // If no specific match, use the first account + if account == nil { + account = &allAccounts[0] + if account.AccountIndex == 0 && account.Index != 0 { + account.AccountIndex = account.Index + } } return account, nil @@ -328,12 +356,13 @@ func (t *LighterTraderV2) FormatQuantity(symbol string, quantity float64) (strin return fmt.Sprintf("%.4f", quantity), nil } -// GetOrderBook Get order book with best bid/ask prices -func (t *LighterTraderV2) GetOrderBook(symbol string) (bestBid, bestAsk float64, err error) { +// GetOrderBook Get order book (implements GridTrader interface) +// Returns bids and asks as [][]float64 where each element is [price, quantity] +func (t *LighterTraderV2) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { // Get market_id first marketID, err := t.getMarketIndex(symbol) if err != nil { - return 0, 0, fmt.Errorf("failed to get market ID: %w", err) + return nil, nil, fmt.Errorf("failed to get market ID: %w", err) } // Get order book from Lighter API @@ -341,22 +370,22 @@ func (t *LighterTraderV2) GetOrderBook(symbol string) (bestBid, bestAsk float64, req, err := http.NewRequest("GET", endpoint, nil) if err != nil { - return 0, 0, err + return nil, nil, err } resp, err := t.client.Do(req) if err != nil { - return 0, 0, err + return nil, nil, err } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - return 0, 0, err + return nil, nil, err } if resp.StatusCode != http.StatusOK { - return 0, 0, fmt.Errorf("failed to get order book (status %d): %s", resp.StatusCode, string(body)) + return nil, nil, fmt.Errorf("failed to get order book (status %d): %s", resp.StatusCode, string(body)) } // Parse response @@ -369,35 +398,61 @@ func (t *LighterTraderV2) GetOrderBook(symbol string) (bestBid, bestAsk float64, } if err := json.Unmarshal(body, &apiResp); err != nil { - return 0, 0, fmt.Errorf("failed to parse order book: %w", err) + return nil, nil, fmt.Errorf("failed to parse order book: %w", err) } if apiResp.Code != 200 { - return 0, 0, fmt.Errorf("API error code: %d", apiResp.Code) + return nil, nil, fmt.Errorf("API error code: %d", apiResp.Code) } - // Get best bid (highest buy price) - if len(apiResp.Data.Bids) > 0 && len(apiResp.Data.Bids[0]) >= 1 { - if price, ok := apiResp.Data.Bids[0][0].(float64); ok { - bestBid = price - } else if priceStr, ok := apiResp.Data.Bids[0][0].(string); ok { - bestBid, _ = strconv.ParseFloat(priceStr, 64) + // Helper to parse price/quantity from interface{} + parseFloat := func(v interface{}) float64 { + if f, ok := v.(float64); ok { + return f + } + if s, ok := v.(string); ok { + f, _ := strconv.ParseFloat(s, 64) + return f + } + return 0 + } + + // Convert bids to [][]float64 + maxBids := len(apiResp.Data.Bids) + if depth > 0 && depth < maxBids { + maxBids = depth + } + bids = make([][]float64, 0, maxBids) + for i := 0; i < maxBids; i++ { + if len(apiResp.Data.Bids[i]) >= 2 { + price := parseFloat(apiResp.Data.Bids[i][0]) + qty := parseFloat(apiResp.Data.Bids[i][1]) + if price > 0 && qty > 0 { + bids = append(bids, []float64{price, qty}) + } } } - // Get best ask (lowest sell price) - if len(apiResp.Data.Asks) > 0 && len(apiResp.Data.Asks[0]) >= 1 { - if price, ok := apiResp.Data.Asks[0][0].(float64); ok { - bestAsk = price - } else if priceStr, ok := apiResp.Data.Asks[0][0].(string); ok { - bestAsk, _ = strconv.ParseFloat(priceStr, 64) + // Convert asks to [][]float64 + maxAsks := len(apiResp.Data.Asks) + if depth > 0 && depth < maxAsks { + maxAsks = depth + } + asks = make([][]float64, 0, maxAsks) + for i := 0; i < maxAsks; i++ { + if len(apiResp.Data.Asks[i]) >= 2 { + price := parseFloat(apiResp.Data.Asks[i][0]) + qty := parseFloat(apiResp.Data.Asks[i][1]) + if price > 0 && qty > 0 { + asks = append(asks, []float64{price, qty}) + } } } - if bestBid <= 0 || bestAsk <= 0 { - return 0, 0, fmt.Errorf("invalid order book prices: bid=%.2f, ask=%.2f", bestBid, bestAsk) + if len(bids) > 0 && len(asks) > 0 { + logger.Infof("✓ Lighter order book: %s best_bid=%.2f, best_ask=%.2f, depth=%d/%d", + symbol, bids[0][0], asks[0][0], len(bids), len(asks)) } - logger.Infof("✓ Lighter order book: %s bid=%.2f, ask=%.2f", symbol, bestBid, bestAsk) - return bestBid, bestAsk, nil + return bids, asks, nil } diff --git a/trader/lighter_trader_v2_orders.go b/trader/lighter_trader_v2_orders.go index c30783c2..024b512c 100644 --- a/trader/lighter_trader_v2_orders.go +++ b/trader/lighter_trader_v2_orders.go @@ -1,12 +1,11 @@ package trader import ( - "bytes" "encoding/json" "fmt" "io" - "mime/multipart" "net/http" + "net/url" "nofx/logger" "strconv" @@ -100,15 +99,18 @@ func (t *LighterTraderV2) GetOrderStatus(symbol string, orderID string) (map[str return nil, fmt.Errorf("invalid auth token: %w", err) } - // Build request URL - endpoint := fmt.Sprintf("%s/api/v1/order/%s", t.baseURL, orderID) + // URL encode auth token (contains colons that need encoding) + // Authentication: Use "auth" query parameter (not Authorization header) + encodedAuth := url.QueryEscape(t.authToken) + + // Build request URL with auth query parameter + endpoint := fmt.Sprintf("%s/api/v1/order/%s?auth=%s", t.baseURL, orderID, encodedAuth) req, err := http.NewRequest("GET", endpoint, nil) if err != nil { return nil, err } - req.Header.Set("Authorization", t.authToken) req.Header.Set("Content-Type", "application/json") resp, err := t.client.Do(req) @@ -148,7 +150,7 @@ func (t *LighterTraderV2) GetOrderStatus(symbol string, orderID string) (map[str "orderId": order.OrderID, "status": unifiedStatus, "avgPrice": order.Price, - "executedQty": order.FilledQty, + "executedQty": order.FilledBaseAmount, "commission": 0.0, }, nil } @@ -210,9 +212,15 @@ func (t *LighterTraderV2) GetActiveOrders(symbol string) ([]OrderResponse, error return nil, fmt.Errorf("failed to get market index: %w", err) } - // Build request URL - endpoint := fmt.Sprintf("%s/api/v1/accountActiveOrders?account_index=%d&market_id=%d", - t.baseURL, t.accountIndex, marketIndex) + // URL encode auth token (contains colons that need encoding) + // Authentication: Use "auth" query parameter (not Authorization header) + encodedAuth := url.QueryEscape(t.authToken) + + // Build request URL with auth query parameter + endpoint := fmt.Sprintf("%s/api/v1/accountActiveOrders?account_index=%d&market_id=%d&auth=%s", + t.baseURL, t.accountIndex, marketIndex, encodedAuth) + + logger.Debugf("📋 LIGHTER GetActiveOrders: endpoint=%s", endpoint[:min(len(endpoint), 120)]+"...") // Send GET request req, err := http.NewRequest("GET", endpoint, nil) @@ -220,8 +228,6 @@ func (t *LighterTraderV2) GetActiveOrders(symbol string) ([]OrderResponse, error return nil, fmt.Errorf("failed to create request: %w", err) } - // Add authentication header - req.Header.Set("Authorization", t.authToken) req.Header.Set("Content-Type", "application/json") resp, err := t.client.Do(req) @@ -235,11 +241,13 @@ func (t *LighterTraderV2) GetActiveOrders(symbol string) ([]OrderResponse, error return nil, fmt.Errorf("failed to read response: %w", err) } - // Parse response + logger.Debugf("📋 LIGHTER GetActiveOrders raw response: %s", string(body)) + + // Parse response - Lighter API uses "orders" field, not "data" var apiResp struct { Code int `json:"code"` Message string `json:"message"` - Data []OrderResponse `json:"data"` + Orders []OrderResponse `json:"orders"` } if err := json.Unmarshal(body, &apiResp); err != nil { @@ -250,11 +258,15 @@ func (t *LighterTraderV2) GetActiveOrders(symbol string) ([]OrderResponse, error return nil, fmt.Errorf("failed to get active orders (code %d): %s", apiResp.Code, apiResp.Message) } - logger.Infof("✓ LIGHTER - Retrieved %d active orders", len(apiResp.Data)) - return apiResp.Data, nil + logger.Infof("✓ LIGHTER - Retrieved %d active orders", len(apiResp.Orders)) + for i, order := range apiResp.Orders { + logger.Debugf(" Order[%d]: order_id=%s, order_index=%d, market=%d", i, order.OrderID, order.OrderIndex, order.MarketIndex) + } + return apiResp.Orders, nil } // CancelOrder Cancel a single order +// orderID can be either a numeric order_index or a tx_hash string func (t *LighterTraderV2) CancelOrder(symbol, orderID string) error { if t.txClient == nil { return fmt.Errorf("TxClient not initialized") @@ -267,10 +279,15 @@ func (t *LighterTraderV2) CancelOrder(symbol, orderID string) error { } marketIndex := uint8(marketIndexU16) // SDK expects uint8 - // Convert orderID to int64 + // Try to parse orderID as numeric order_index first orderIndex, err := strconv.ParseInt(orderID, 10, 64) if err != nil { - return fmt.Errorf("invalid order ID: %w", err) + // orderID is a tx_hash, need to query order to get numeric order_index + logger.Debugf("📋 LIGHTER CancelOrder: orderID is tx_hash, querying order...") + orderIndex, err = t.getOrderIndexByTxHash(symbol, orderID) + if err != nil { + return fmt.Errorf("failed to get order index from tx_hash: %w", err) + } } // Build cancel order request @@ -280,22 +297,26 @@ func (t *LighterTraderV2) CancelOrder(symbol, orderID string) error { } // Sign transaction using SDK + // Must provide FromAccountIndex and ApiKeyIndex for nonce auto-fetch to work nonce := int64(-1) // -1 means auto-fetch + apiKeyIdx := t.apiKeyIndex tx, err := t.txClient.GetCancelOrderTransaction(txReq, &types.TransactOpts{ - Nonce: &nonce, + FromAccountIndex: &t.accountIndex, + ApiKeyIndex: &apiKeyIdx, + Nonce: &nonce, }) if err != nil { return fmt.Errorf("failed to sign cancel order: %w", err) } - // Serialize transaction - txBytes, err := json.Marshal(tx) + // Get tx_info from SDK (consistent with CreateOrder and other transactions) + txInfo, err := tx.GetTxInfo() if err != nil { - return fmt.Errorf("failed to serialize transaction: %w", err) + return fmt.Errorf("failed to get tx info: %w", err) } - // Submit cancel order to LIGHTER API - _, err = t.submitCancelOrder(txBytes) + // Submit cancel order to LIGHTER API using unified submitOrder function + _, err = t.submitOrder(int(tx.GetTxType()), txInfo) if err != nil { return fmt.Errorf("failed to submit cancel order: %w", err) } @@ -304,65 +325,21 @@ func (t *LighterTraderV2) CancelOrder(symbol, orderID string) error { return nil } -// submitCancelOrder Submit signed cancel order to LIGHTER API using multipart/form-data -func (t *LighterTraderV2) submitCancelOrder(signedTx []byte) (map[string]interface{}, error) { - const TX_TYPE_CANCEL_ORDER = 15 - - // Build multipart form data (Lighter API requires form-data, not JSON) - var body bytes.Buffer - writer := multipart.NewWriter(&body) - - // Add tx_type field - if err := writer.WriteField("tx_type", strconv.Itoa(TX_TYPE_CANCEL_ORDER)); err != nil { - return nil, fmt.Errorf("failed to write tx_type: %w", err) - } - - // Add tx_info field - if err := writer.WriteField("tx_info", string(signedTx)); err != nil { - return nil, fmt.Errorf("failed to write tx_info: %w", err) - } - - // Close multipart writer - if err := writer.Close(); err != nil { - return nil, fmt.Errorf("failed to close multipart writer: %w", err) - } - - // Send POST request to /api/v1/sendTx - endpoint := fmt.Sprintf("%s/api/v1/sendTx", t.baseURL) - httpReq, err := http.NewRequest("POST", endpoint, &body) +// getOrderIndexByTxHash finds the numeric order_index by searching active orders for the tx_hash +func (t *LighterTraderV2) getOrderIndexByTxHash(symbol, txHash string) (int64, error) { + // Get all active orders for this symbol + orders, err := t.GetActiveOrders(symbol) if err != nil { - return nil, err + return 0, fmt.Errorf("failed to get active orders: %w", err) } - httpReq.Header.Set("Content-Type", writer.FormDataContentType()) - - resp, err := t.client.Do(httpReq) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err + // Search for the order with matching tx_hash (order_id) + for _, order := range orders { + if order.OrderID == txHash { + logger.Debugf("📋 LIGHTER Found order_index %d for tx_hash %s", order.OrderIndex, txHash) + return order.OrderIndex, nil + } } - // Parse response - var sendResp SendTxResponse - if err := json.Unmarshal(respBody, &sendResp); err != nil { - return nil, fmt.Errorf("failed to parse response: %w, body: %s", err, string(respBody)) - } - - // Check response code - if sendResp.Code != 200 { - return nil, fmt.Errorf("failed to submit cancel order (code %d): %s", sendResp.Code, sendResp.Message) - } - - result := map[string]interface{}{ - "tx_hash": sendResp.Data["tx_hash"], - "status": "cancelled", - } - - logger.Infof("✓ Cancel order submitted to LIGHTER - tx_hash: %v", sendResp.Data["tx_hash"]) - return result, nil + return 0, fmt.Errorf("order not found with tx_hash: %s (may already be filled or cancelled)", txHash) } diff --git a/trader/lighter_trader_v2_orders_test.go b/trader/lighter_trader_v2_orders_test.go new file mode 100644 index 00000000..7b84912f --- /dev/null +++ b/trader/lighter_trader_v2_orders_test.go @@ -0,0 +1,421 @@ +package trader + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestGetActiveOrders_ParseResponse tests parsing of Lighter API response +func TestGetActiveOrders_ParseResponse(t *testing.T) { + // Mock response from Lighter API + mockResponse := `{ + "code": 200, + "message": "success", + "orders": [ + { + "order_id": "123456", + "order_index": 123456, + "market_index": 0, + "side": "ask", + "type": "limit", + "is_ask": true, + "price": "3150.50", + "initial_base_amount": "1.5", + "remaining_base_amount": "1.5", + "filled_base_amount": "0", + "status": "open", + "trigger_price": "", + "reduce_only": false, + "timestamp": 1736745600000, + "created_at": 1736745600000 + }, + { + "order_id": "123457", + "order_index": 123457, + "market_index": 0, + "side": "bid", + "type": "limit", + "is_ask": false, + "price": "3100.00", + "initial_base_amount": "2.0", + "remaining_base_amount": "2.0", + "filled_base_amount": "0", + "status": "open", + "trigger_price": "", + "reduce_only": false, + "timestamp": 1736745601000, + "created_at": 1736745601000 + }, + { + "order_id": "123458", + "order_index": 123458, + "market_index": 0, + "side": "ask", + "type": "stop_loss", + "is_ask": true, + "price": "0", + "initial_base_amount": "1.0", + "remaining_base_amount": "1.0", + "filled_base_amount": "0", + "status": "open", + "trigger_price": "3000.00", + "reduce_only": true, + "timestamp": 1736745602000, + "created_at": 1736745602000 + } + ] + }` + + // Parse the response + var apiResp struct { + Code int `json:"code"` + Message string `json:"message"` + Orders []OrderResponse `json:"orders"` + } + + err := json.Unmarshal([]byte(mockResponse), &apiResp) + require.NoError(t, err, "Should parse response without error") + + // Verify parsed data + assert.Equal(t, 200, apiResp.Code) + assert.Equal(t, 3, len(apiResp.Orders)) + + // Test first order (sell limit) + order1 := apiResp.Orders[0] + assert.Equal(t, "123456", order1.OrderID) + assert.True(t, order1.IsAsk, "First order should be ask (sell)") + assert.Equal(t, "3150.50", order1.Price) + assert.Equal(t, "1.5", order1.RemainingBaseAmount) + assert.False(t, order1.ReduceOnly) + + // Test second order (buy limit) + order2 := apiResp.Orders[1] + assert.Equal(t, "123457", order2.OrderID) + assert.False(t, order2.IsAsk, "Second order should be bid (buy)") + assert.Equal(t, "3100.00", order2.Price) + + // Test third order (stop-loss) + order3 := apiResp.Orders[2] + assert.Equal(t, "123458", order3.OrderID) + assert.Equal(t, "stop_loss", order3.Type) + assert.Equal(t, "3000.00", order3.TriggerPrice) + assert.True(t, order3.ReduceOnly) +} + +// TestGetActiveOrders_EmptyResponse tests handling of empty orders +func TestGetActiveOrders_EmptyResponse(t *testing.T) { + mockResponse := `{ + "code": 200, + "message": "success", + "orders": [] + }` + + var apiResp struct { + Code int `json:"code"` + Message string `json:"message"` + Orders []OrderResponse `json:"orders"` + } + + err := json.Unmarshal([]byte(mockResponse), &apiResp) + require.NoError(t, err) + assert.Equal(t, 200, apiResp.Code) + assert.Equal(t, 0, len(apiResp.Orders)) +} + +// TestGetActiveOrders_ErrorResponse tests handling of API error +func TestGetActiveOrders_ErrorResponse(t *testing.T) { + mockResponse := `{ + "code": 29500, + "message": "internal server error: invalid signature" + }` + + var apiResp struct { + Code int `json:"code"` + Message string `json:"message"` + Orders []OrderResponse `json:"orders"` + } + + err := json.Unmarshal([]byte(mockResponse), &apiResp) + require.NoError(t, err) + assert.Equal(t, 29500, apiResp.Code) + assert.Contains(t, apiResp.Message, "invalid signature") +} + +// TestConvertOrderResponseToOpenOrder tests conversion logic +func TestConvertOrderResponseToOpenOrder(t *testing.T) { + testCases := []struct { + name string + order OrderResponse + expectedSide string + expectedType string + expectedPosSide string + }{ + { + name: "Sell limit order (opening short)", + order: OrderResponse{ + OrderID: "1", + IsAsk: true, + Type: "limit", + Price: "3150.00", + RemainingBaseAmount: "1.0", + ReduceOnly: false, + }, + expectedSide: "SELL", + expectedType: "LIMIT", + expectedPosSide: "SHORT", + }, + { + name: "Buy limit order (opening long)", + order: OrderResponse{ + OrderID: "2", + IsAsk: false, + Type: "limit", + Price: "3100.00", + RemainingBaseAmount: "1.0", + ReduceOnly: false, + }, + expectedSide: "BUY", + expectedType: "LIMIT", + expectedPosSide: "LONG", + }, + { + name: "Sell stop-loss (closing long)", + order: OrderResponse{ + OrderID: "3", + IsAsk: true, + Type: "stop_loss", + TriggerPrice: "3000.00", + RemainingBaseAmount: "1.0", + ReduceOnly: true, + }, + expectedSide: "SELL", + expectedType: "STOP_MARKET", + expectedPosSide: "LONG", + }, + { + name: "Buy stop-loss (closing short)", + order: OrderResponse{ + OrderID: "4", + IsAsk: false, + Type: "stop_loss", + TriggerPrice: "3200.00", + RemainingBaseAmount: "1.0", + ReduceOnly: true, + }, + expectedSide: "BUY", + expectedType: "STOP_MARKET", + expectedPosSide: "SHORT", + }, + { + name: "Take profit (closing long)", + order: OrderResponse{ + OrderID: "5", + IsAsk: true, + Type: "take_profit", + TriggerPrice: "3500.00", + RemainingBaseAmount: "1.0", + ReduceOnly: true, + }, + expectedSide: "SELL", + expectedType: "TAKE_PROFIT_MARKET", + expectedPosSide: "LONG", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Convert side + side := "BUY" + if tc.order.IsAsk { + side = "SELL" + } + assert.Equal(t, tc.expectedSide, side) + + // Convert order type + orderType := "LIMIT" + if tc.order.Type == "market" { + orderType = "MARKET" + } else if tc.order.Type == "stop_loss" || tc.order.Type == "stop" { + orderType = "STOP_MARKET" + } else if tc.order.Type == "take_profit" { + orderType = "TAKE_PROFIT_MARKET" + } + assert.Equal(t, tc.expectedType, orderType) + + // Convert position side + positionSide := "LONG" + if tc.order.ReduceOnly { + if side == "BUY" { + positionSide = "SHORT" + } else { + positionSide = "LONG" + } + } else { + if side == "SELL" { + positionSide = "SHORT" + } + } + assert.Equal(t, tc.expectedPosSide, positionSide) + }) + } +} + +// TestGetActiveOrders_MockServer tests the full HTTP flow with a mock server +func TestGetActiveOrders_MockServer(t *testing.T) { + // Create mock server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request path and auth parameter + assert.Contains(t, r.URL.Path, "/api/v1/accountActiveOrders") + + // Check that auth query parameter is present + authParam := r.URL.Query().Get("auth") + if authParam == "" { + // Return error if no auth parameter + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(map[string]interface{}{ + "code": 29500, + "message": "internal server error: invalid signature", + }) + return + } + + // Return success response + response := map[string]interface{}{ + "code": 200, + "message": "success", + "orders": []map[string]interface{}{ + { + "order_id": "123456", + "order_index": 123456, + "market_index": 0, + "side": "ask", + "type": "limit", + "is_ask": true, + "price": "3150.50", + "initial_base_amount": "1.5", + "remaining_base_amount": "1.5", + "filled_base_amount": "0", + "status": "open", + "trigger_price": "", + "reduce_only": false, + }, + }, + } + json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + // Test request without auth - should fail + resp, err := http.Get(server.URL + "/api/v1/accountActiveOrders?account_index=123&market_id=0") + require.NoError(t, err) + defer resp.Body.Close() + + var errorResp struct { + Code int `json:"code"` + Message string `json:"message"` + } + json.NewDecoder(resp.Body).Decode(&errorResp) + assert.Equal(t, 29500, errorResp.Code) + + // Test request with auth - should succeed + resp2, err := http.Get(server.URL + "/api/v1/accountActiveOrders?account_index=123&market_id=0&auth=test_token") + require.NoError(t, err) + defer resp2.Body.Close() + + var successResp struct { + Code int `json:"code"` + Message string `json:"message"` + Orders []OrderResponse `json:"orders"` + } + json.NewDecoder(resp2.Body).Decode(&successResp) + assert.Equal(t, 200, successResp.Code) + assert.Equal(t, 1, len(successResp.Orders)) +} + +// TestAuthTokenFormat tests the auth token format +func TestAuthTokenFormat(t *testing.T) { + // Auth token format: timestamp:account_index:api_key_index:signature + // Example: 1768308847:687247:0:742e02... + + sampleToken := "1768308847:687247:0:742e02abc123" + + // The token should be URL encoded when used as query parameter + // Colons become %3A + expectedEncoded := "1768308847%3A687247%3A0%3A742e02abc123" + + // URL encode the token + encoded := url.QueryEscape(sampleToken) + + assert.Equal(t, expectedEncoded, encoded) +} + +// TestOrderResponseStruct tests that OrderResponse struct matches API response +func TestOrderResponseStruct(t *testing.T) { + // Real API response sample (from logs) + realResponse := `{ + "order_id": "4609885", + "order_index": 4609885, + "market_index": 0, + "side": "ask", + "type": "limit", + "is_ask": true, + "price": "3150.00", + "initial_base_amount": "0.0300", + "remaining_base_amount": "0.0300", + "filled_base_amount": "0", + "status": "open", + "trigger_price": "", + "reduce_only": false, + "timestamp": 1736745600000, + "created_at": 1736745600000 + }` + + var order OrderResponse + err := json.Unmarshal([]byte(realResponse), &order) + require.NoError(t, err) + + assert.Equal(t, "4609885", order.OrderID) + assert.Equal(t, int64(4609885), order.OrderIndex) + assert.Equal(t, 0, order.MarketIndex) + assert.Equal(t, "ask", order.Side) + assert.Equal(t, "limit", order.Type) + assert.True(t, order.IsAsk) + assert.Equal(t, "3150.00", order.Price) + assert.Equal(t, "0.0300", order.InitialBaseAmount) + assert.Equal(t, "0.0300", order.RemainingBaseAmount) + assert.Equal(t, "0", order.FilledBaseAmount) + assert.Equal(t, "open", order.Status) + assert.Equal(t, "", order.TriggerPrice) + assert.False(t, order.ReduceOnly) + assert.Equal(t, int64(1736745600000), order.Timestamp) + assert.Equal(t, int64(1736745600000), order.CreatedAt) +} + +// BenchmarkParseOrderResponse benchmarks response parsing +func BenchmarkParseOrderResponse(b *testing.B) { + mockResponse := `{ + "code": 200, + "message": "success", + "orders": [ + {"order_id": "1", "is_ask": true, "price": "3150.50", "remaining_base_amount": "1.5"}, + {"order_id": "2", "is_ask": false, "price": "3100.00", "remaining_base_amount": "2.0"}, + {"order_id": "3", "is_ask": true, "price": "3200.00", "remaining_base_amount": "0.5"} + ] + }` + + b.ResetTimer() + for i := 0; i < b.N; i++ { + var apiResp struct { + Code int `json:"code"` + Message string `json:"message"` + Orders []OrderResponse `json:"orders"` + } + json.Unmarshal([]byte(mockResponse), &apiResp) + } +} diff --git a/trader/lighter_trader_v2_trading.go b/trader/lighter_trader_v2_trading.go index ff5a7341..1f4ff417 100644 --- a/trader/lighter_trader_v2_trading.go +++ b/trader/lighter_trader_v2_trading.go @@ -273,9 +273,13 @@ func (t *LighterTraderV2) CreateOrder(symbol string, isAsk bool, quantity float6 } // Sign transaction using SDK (nonce will be auto-fetched) + // Must provide FromAccountIndex and ApiKeyIndex for nonce auto-fetch to work nonce := int64(-1) // -1 means auto-fetch + apiKeyIdx := t.apiKeyIndex tx, err := t.txClient.GetCreateOrderTransaction(txReq, &types.TransactOpts{ - Nonce: &nonce, + FromAccountIndex: &t.accountIndex, + ApiKeyIndex: &apiKeyIdx, + Nonce: &nonce, }) if err != nil { return nil, fmt.Errorf("failed to sign order: %w", err) @@ -288,7 +292,7 @@ func (t *LighterTraderV2) CreateOrder(symbol string, isAsk bool, quantity float6 } // Debug: Log the tx_info content - logger.Infof("DEBUG tx_type: %d, tx_info: %s", tx.GetTxType(), txInfo) + logger.Debugf("tx_type: %d, tx_info: %s", tx.GetTxType(), txInfo) // Submit order to LIGHTER API orderResp, err := t.submitOrder(int(tx.GetTxType()), txInfo) @@ -302,6 +306,16 @@ func (t *LighterTraderV2) CreateOrder(symbol string, isAsk bool, quantity float6 } logger.Infof("✓ LIGHTER order created: %s %s qty=%.4f", symbol, side, quantity) + // For limit orders, poll for the actual order_index after submission + // This is needed because CancelOrder requires the numeric order_index, not tx_hash + if orderType == "limit" { + txHash, _ := orderResp["tx_hash"].(string) + if orderIndex, err := t.pollForOrderIndex(symbol, txHash); err == nil && orderIndex > 0 { + orderResp["orderId"] = fmt.Sprintf("%d", orderIndex) + orderResp["order_index"] = orderIndex + } + } + return orderResp, nil } @@ -386,10 +400,19 @@ func (t *LighterTraderV2) submitOrder(txType int, txInfo string) (map[string]int } // Log full response for debugging - logger.Infof("DEBUG API response: %s", string(respBody)) + logger.Debugf("API response: %s", string(respBody)) // Check response code if sendResp.Code != 200 { + // Provide more specific error message for signature errors + // Code 21120: invalid signature (order submission) + // Code 29500: internal server error: invalid signature (authenticated GET APIs) + if (sendResp.Code == 21120 || sendResp.Code == 29500) && strings.Contains(sendResp.Message, "invalid signature") { + if !t.apiKeyValid { + return nil, fmt.Errorf("API Key MISMATCH (code %d): The API key stored in NOFX does not match the one registered on Lighter. Please update your Lighter API key in Exchange settings at app.lighter.xyz", sendResp.Code) + } + return nil, fmt.Errorf("API Key signature invalid (code %d): Please verify your Lighter API Key in Exchange settings matches the key registered at app.lighter.xyz", sendResp.Code) + } return nil, fmt.Errorf("failed to submit order (code %d): %s", sendResp.Code, sendResp.Message) } @@ -403,17 +426,45 @@ func (t *LighterTraderV2) submitOrder(txType int, txInfo string) (map[string]int } } + logger.Infof("✓ Order submitted to LIGHTER - tx_hash: %s", txHash) + result := map[string]interface{}{ "tx_hash": txHash, "status": "submitted", - "orderId": txHash, // Use tx_hash as orderId + "orderId": txHash, // Use tx_hash as orderId initially } - logger.Infof("✓ Order submitted to LIGHTER - tx_hash: %s", txHash) - return result, nil } +// pollForOrderIndex polls active orders to find the order_index for a newly created order +// Returns the highest order_index (newest order) for the given symbol +func (t *LighterTraderV2) pollForOrderIndex(symbol string, txHash string) (int64, error) { + // Wait a moment for the order to be processed + time.Sleep(500 * time.Millisecond) + + // Get active orders + orders, err := t.GetActiveOrders(symbol) + if err != nil { + return 0, fmt.Errorf("failed to get active orders: %w", err) + } + + if len(orders) == 0 { + return 0, fmt.Errorf("no active orders found (order may have been filled immediately)") + } + + // Find the highest order_index (newest order) + var highestIndex int64 + for _, order := range orders { + if order.OrderIndex > highestIndex { + highestIndex = order.OrderIndex + } + } + + logger.Infof("✓ Order created with order_index: %d (tx_hash: %s)", highestIndex, txHash) + return highestIndex, nil +} + // normalizeSymbol Convert NOFX symbol format to Lighter format // NOFX uses "BTC-PERP", "BTCUSDT", etc. Lighter uses "BTC", "ETH", etc. func normalizeSymbol(symbol string) string { @@ -431,7 +482,7 @@ func (t *LighterTraderV2) getMarketInfo(symbol string) (*MarketInfo, error) { // Normalize symbol to Lighter format normalizedSymbol := normalizeSymbol(symbol) - // 1. Fetch market list from API (TODO: cache this) + // Fetch market list from API (cached for 1 hour) markets, err := t.fetchMarketList() if err != nil { return nil, fmt.Errorf("failed to fetch market list: %w", err) @@ -467,8 +518,18 @@ type MarketInfo struct { PriceDecimals int `json:"price_decimals"` } -// fetchMarketList Fetch market list from API +// fetchMarketList Fetch market list from API with caching (TTL: 1 hour) func (t *LighterTraderV2) fetchMarketList() ([]MarketInfo, error) { + // Check cache (TTL: 1 hour) + t.marketMutex.RLock() + if len(t.marketListCache) > 0 && time.Since(t.marketListCacheTime) < time.Hour { + cached := t.marketListCache + t.marketMutex.RUnlock() + return cached, nil + } + t.marketMutex.RUnlock() + + // Fetch from API endpoint := fmt.Sprintf("%s/api/v1/orderBooks", t.baseURL) req, err := http.NewRequest("GET", endpoint, nil) @@ -514,14 +575,20 @@ func (t *LighterTraderV2) fetchMarketList() ([]MarketInfo, error) { for _, market := range apiResp.OrderBooks { if market.Status == "active" { markets = append(markets, MarketInfo{ - Symbol: market.Symbol, - MarketID: market.MarketID, - SizeDecimals: market.SupportedSizeDecimals, - PriceDecimals: market.SupportedPriceDecimals, + Symbol: market.Symbol, + MarketID: market.MarketID, + SizeDecimals: market.SupportedSizeDecimals, + PriceDecimals: market.SupportedPriceDecimals, }) } } + // Update cache + t.marketMutex.Lock() + t.marketListCache = markets + t.marketListCacheTime = time.Now() + t.marketMutex.Unlock() + logger.Infof("✓ Retrieved %d active markets from Lighter", len(markets)) return markets, nil } @@ -550,31 +617,132 @@ func (t *LighterTraderV2) getFallbackMarketIndex(symbol string) (uint16, error) } // SetLeverage Set leverage (implements Trader interface) +// Lighter uses InitialMarginFraction to represent leverage: +// - InitialMarginFraction = (100 / leverage) * 100 (stored as percentage * 100) +// - e.g., 5x leverage = 20% margin = 2000 in API +// - e.g., 20x leverage = 5% margin = 500 in API func (t *LighterTraderV2) SetLeverage(symbol string, leverage int) error { if t.txClient == nil { return fmt.Errorf("TxClient not initialized") } - // TODO: Sign and submit SetLeverage transaction using SDK - logger.Infof("⚙️ Setting leverage: %s = %dx", symbol, leverage) + // Validate leverage range (1x to 50x typical max) + if leverage < 1 || leverage > 50 { + return fmt.Errorf("leverage must be between 1 and 50, got %d", leverage) + } - return nil // Return success for now + // Get market info (includes market_id) + marketInfo, err := t.getMarketInfo(symbol) + if err != nil { + return fmt.Errorf("failed to get market info: %w", err) + } + marketIndex := uint8(marketInfo.MarketID) + + // Calculate InitialMarginFraction from leverage + // leverage = 100 / margin_fraction_percent + // margin_fraction_percent = 100 / leverage + // API value = margin_fraction_percent * 100 + marginFractionPercent := 100.0 / float64(leverage) + initialMarginFraction := uint16(marginFractionPercent * 100) // e.g., 5x => 20% => 2000 + + logger.Infof("⚙️ Setting leverage: %s = %dx (margin_fraction=%.2f%%, API value=%d)", + symbol, leverage, marginFractionPercent, initialMarginFraction) + + // Build UpdateLeverage request + txReq := &types.UpdateLeverageTxReq{ + MarketIndex: marketIndex, + InitialMarginFraction: initialMarginFraction, + MarginMode: 0, // 0 = cross margin (default) + } + + // Sign transaction using SDK + nonce := int64(-1) // Auto-fetch nonce + tx, err := t.txClient.GetUpdateLeverageTransaction(txReq, &types.TransactOpts{ + Nonce: &nonce, + }) + if err != nil { + return fmt.Errorf("failed to sign leverage transaction: %w", err) + } + + // Get tx_info from SDK + txInfo, err := tx.GetTxInfo() + if err != nil { + return fmt.Errorf("failed to get tx info: %w", err) + } + + // Submit to Lighter API (reuse submitOrder which handles any transaction type) + result, err := t.submitOrder(int(tx.GetTxType()), txInfo) + if err != nil { + return fmt.Errorf("failed to submit leverage transaction: %w", err) + } + + logger.Infof("✓ Leverage set successfully: %s = %dx (tx_hash: %v)", symbol, leverage, result["tx_hash"]) + return nil } // SetMarginMode Set margin mode (implements Trader interface) +// Lighter uses UpdateLeverage transaction which includes both leverage and margin mode +// MarginMode: 0 = cross, 1 = isolated func (t *LighterTraderV2) SetMarginMode(symbol string, isCrossMargin bool) error { if t.txClient == nil { return fmt.Errorf("TxClient not initialized") } - modeStr := "isolated" - if isCrossMargin { - modeStr = "cross" + // Get market info + marketInfo, err := t.getMarketInfo(symbol) + if err != nil { + return fmt.Errorf("failed to get market info: %w", err) + } + marketIndex := uint8(marketInfo.MarketID) + + // Determine margin mode value + var marginMode uint8 = 0 // cross + modeStr := "cross" + if !isCrossMargin { + marginMode = 1 // isolated + modeStr = "isolated" } - logger.Infof("⚙️ Setting margin mode: %s = %s", symbol, modeStr) + // Get current position to preserve leverage, or use default 10x if no position + var initialMarginFraction uint16 = 1000 // Default 10x leverage (10% margin = 1000) + pos, err := t.GetPosition(symbol) + if err == nil && pos != nil && pos.Leverage > 0 { + // Calculate InitialMarginFraction from current leverage + marginFractionPercent := 100.0 / pos.Leverage + initialMarginFraction = uint16(marginFractionPercent * 100) + } - // TODO: Sign and submit SetMarginMode transaction using SDK + logger.Infof("⚙️ Setting margin mode: %s = %s (margin_mode=%d, preserving leverage)", symbol, modeStr, marginMode) + + // Build UpdateLeverage request (also updates margin mode) + txReq := &types.UpdateLeverageTxReq{ + MarketIndex: marketIndex, + InitialMarginFraction: initialMarginFraction, + MarginMode: marginMode, + } + + // Sign transaction + nonce := int64(-1) + tx, err := t.txClient.GetUpdateLeverageTransaction(txReq, &types.TransactOpts{ + Nonce: &nonce, + }) + if err != nil { + return fmt.Errorf("failed to sign margin mode transaction: %w", err) + } + + // Get tx_info + txInfo, err := tx.GetTxInfo() + if err != nil { + return fmt.Errorf("failed to get tx info: %w", err) + } + + // Submit to Lighter API + result, err := t.submitOrder(int(tx.GetTxType()), txInfo) + if err != nil { + return fmt.Errorf("failed to submit margin mode transaction: %w", err) + } + + logger.Infof("✓ Margin mode set successfully: %s = %s (tx_hash: %v)", symbol, modeStr, result["tx_hash"]) return nil } @@ -653,7 +821,7 @@ func (t *LighterTraderV2) CreateStopOrder(symbol string, isAsk bool, quantity fl return nil, fmt.Errorf("failed to get tx info: %w", err) } - logger.Infof("DEBUG stop order - type: %d, trigger: %.2f, price: %.2f, isAsk: %v", orderTypeValue, triggerPrice, float64(priceValue)/100, isAsk) + logger.Debugf("stop order - type: %d, trigger: %.2f, price: %.2f, isAsk: %v", orderTypeValue, triggerPrice, float64(priceValue)/100, isAsk) // Submit order orderResp, err := t.submitOrder(int(tx.GetTxType()), txInfo) @@ -689,6 +857,117 @@ func pow10(n int) int64 { // GetOpenOrders gets all open/pending orders for a symbol func (t *LighterTraderV2) GetOpenOrders(symbol string) ([]OpenOrder, error) { - // TODO: Implement Lighter open orders - return []OpenOrder{}, nil + // Get active orders from Lighter API + activeOrders, err := t.GetActiveOrders(symbol) + if err != nil { + return nil, fmt.Errorf("failed to get active orders: %w", err) + } + + var result []OpenOrder + for _, order := range activeOrders { + // Convert side: Lighter uses is_ask (true=sell, false=buy) + side := "BUY" + if order.IsAsk { + side = "SELL" + } + + // Determine order type from Lighter's type field + orderType := "LIMIT" + if order.Type == "market" { + orderType = "MARKET" + } else if order.Type == "stop_loss" || order.Type == "stop" { + orderType = "STOP_MARKET" + } else if order.Type == "take_profit" { + orderType = "TAKE_PROFIT_MARKET" + } + + // Determine position side based on order direction and reduce-only flag + positionSide := "LONG" + if order.ReduceOnly { + // For reduce-only orders, position side is opposite to order side + if side == "BUY" { + positionSide = "SHORT" // Buying to close short + } else { + positionSide = "LONG" // Selling to close long + } + } else { + // For opening orders + if side == "SELL" { + positionSide = "SHORT" + } + } + + // Parse price and quantity from string fields + price, _ := strconv.ParseFloat(order.Price, 64) + quantity, _ := strconv.ParseFloat(order.RemainingBaseAmount, 64) + if quantity == 0 { + quantity, _ = strconv.ParseFloat(order.InitialBaseAmount, 64) + } + triggerPrice, _ := strconv.ParseFloat(order.TriggerPrice, 64) + + openOrder := OpenOrder{ + OrderID: order.OrderID, + Symbol: symbol, + Side: side, + PositionSide: positionSide, + Type: orderType, + Price: price, + StopPrice: triggerPrice, + Quantity: quantity, + Status: "NEW", + } + result = append(result, openOrder) + } + + logger.Infof("✓ LIGHTER GetOpenOrders: found %d open orders for %s", len(result), symbol) + return result, nil +} + +// PlaceLimitOrder implements GridTrader interface for grid trading +// Places a limit order at the specified price +func (t *LighterTraderV2) PlaceLimitOrder(req *LimitOrderRequest) (*LimitOrderResult, error) { + if t.txClient == nil { + return nil, fmt.Errorf("TxClient not initialized") + } + + // Determine if this is a sell (ask) order + isAsk := req.Side == "SELL" + + logger.Infof("📝 LIGHTER placing limit order: %s %s @ %.4f, qty=%.4f, leverage=%dx", + req.Symbol, req.Side, req.Price, req.Quantity, req.Leverage) + + // Set leverage before placing order (important for grid trading) + if req.Leverage > 0 { + if err := t.SetLeverage(req.Symbol, req.Leverage); err != nil { + logger.Warnf("⚠️ Failed to set leverage: %v (continuing with current leverage)", err) + } + } + + // Create limit order using existing CreateOrder function + orderResult, err := t.CreateOrder(req.Symbol, isAsk, req.Quantity, req.Price, "limit", req.ReduceOnly) + if err != nil { + return nil, fmt.Errorf("failed to place limit order: %w", err) + } + + // Extract order ID from result + orderID := "" + if id, ok := orderResult["orderId"]; ok { + orderID = fmt.Sprintf("%v", id) + } else if txHash, ok := orderResult["tx_hash"]; ok { + orderID = fmt.Sprintf("%v", txHash) + } + + logger.Infof("✓ LIGHTER limit order placed: %s %s @ %.4f, OrderID: %s", + req.Symbol, req.Side, req.Price, orderID) + + return &LimitOrderResult{ + OrderID: orderID, + ClientID: req.ClientID, + Symbol: req.Symbol, + Side: req.Side, + PositionSide: req.PositionSide, + Price: req.Price, + Quantity: req.Quantity, + Status: "NEW", + }, nil } diff --git a/trader/lighter_types.go b/trader/lighter_types.go index 3a76b7c4..e4670cdd 100644 --- a/trader/lighter_types.go +++ b/trader/lighter_types.go @@ -41,18 +41,24 @@ type CreateOrderRequest struct { PostOnly bool `json:"post_only"` // Post-only (maker only) } -// OrderResponse Order response (Lighter) +// OrderResponse Order response (Lighter API) +// Field names must match Lighter API response exactly type OrderResponse struct { - OrderID string `json:"order_id"` - Symbol string `json:"symbol"` - Side string `json:"side"` - OrderType string `json:"order_type"` - Quantity float64 `json:"quantity"` - Price float64 `json:"price"` - Status string `json:"status"` // "open", "filled", "cancelled" - FilledQty float64 `json:"filled_qty"` - RemainingQty float64 `json:"remaining_qty"` - CreateTime int64 `json:"create_time"` + OrderID string `json:"order_id"` + OrderIndex int64 `json:"order_index"` + MarketIndex int `json:"market_index"` + Side string `json:"side"` // "bid" or "ask" + Type string `json:"type"` // "limit", "market", etc. + IsAsk bool `json:"is_ask"` // true = sell, false = buy + Price string `json:"price"` // Price as string + InitialBaseAmount string `json:"initial_base_amount"` // Original quantity + RemainingBaseAmount string `json:"remaining_base_amount"` // Remaining quantity + FilledBaseAmount string `json:"filled_base_amount"` // Filled quantity + Status string `json:"status"` // "open", "filled", "cancelled" + TriggerPrice string `json:"trigger_price"` // For stop orders + ReduceOnly bool `json:"reduce_only"` + Timestamp int64 `json:"timestamp"` + CreatedAt int64 `json:"created_at"` } // LighterTradeResponse represents the response from Lighter trades API diff --git a/trader/okx_trader.go b/trader/okx_trader.go index 2d4b8b89..fc3fae69 100644 --- a/trader/okx_trader.go +++ b/trader/okx_trader.go @@ -1390,6 +1390,254 @@ func (t *OKXTrader) GetClosedPnL(startTime time.Time, limit int) ([]ClosedPnLRec // GetOpenOrders gets all open/pending orders for a symbol func (t *OKXTrader) GetOpenOrders(symbol string) ([]OpenOrder, error) { - // TODO: Implement OKX open orders - return []OpenOrder{}, nil + instId := t.convertSymbol(symbol) + var result []OpenOrder + + // 1. Get pending limit orders + path := fmt.Sprintf("%s?instId=%s&instType=SWAP", okxPendingOrdersPath, instId) + data, err := t.doRequest("GET", path, nil) + if err != nil { + logger.Warnf("[OKX] Failed to get pending orders: %v", err) + } + if err == nil && data != nil { + var orders []struct { + OrdId string `json:"ordId"` + InstId string `json:"instId"` + Side string `json:"side"` // buy/sell + PosSide string `json:"posSide"` // long/short/net + OrdType string `json:"ordType"` // limit/market/post_only + Px string `json:"px"` // price + Sz string `json:"sz"` // size + State string `json:"state"` // live/partially_filled + } + if err := json.Unmarshal(data, &orders); err == nil { + for _, order := range orders { + price, _ := strconv.ParseFloat(order.Px, 64) + quantity, _ := strconv.ParseFloat(order.Sz, 64) + + // Convert OKX side to standard format + side := strings.ToUpper(order.Side) + positionSide := strings.ToUpper(order.PosSide) + if positionSide == "NET" { + positionSide = "BOTH" + } + + result = append(result, OpenOrder{ + OrderID: order.OrdId, + Symbol: symbol, + Side: side, + PositionSide: positionSide, + Type: strings.ToUpper(order.OrdType), + Price: price, + StopPrice: 0, + Quantity: quantity, + Status: "NEW", + }) + } + } + } + + // 2. Get pending algo orders (stop-loss/take-profit) + algoPath := fmt.Sprintf("%s?instId=%s&instType=SWAP", okxAlgoPendingPath, instId) + algoData, err := t.doRequest("GET", algoPath, nil) + if err != nil { + logger.Warnf("[OKX] Failed to get algo orders: %v", err) + } + if err == nil && algoData != nil { + var algoOrders []struct { + AlgoId string `json:"algoId"` + InstId string `json:"instId"` + Side string `json:"side"` + PosSide string `json:"posSide"` + OrdType string `json:"ordType"` // conditional/oco/trigger + TriggerPx string `json:"triggerPx"` + Sz string `json:"sz"` + State string `json:"state"` + } + if err := json.Unmarshal(algoData, &algoOrders); err == nil { + for _, order := range algoOrders { + triggerPrice, _ := strconv.ParseFloat(order.TriggerPx, 64) + quantity, _ := strconv.ParseFloat(order.Sz, 64) + + side := strings.ToUpper(order.Side) + positionSide := strings.ToUpper(order.PosSide) + if positionSide == "NET" { + positionSide = "BOTH" + } + + // Map OKX algo order type + orderType := "STOP_MARKET" + if order.OrdType == "oco" { + orderType = "TAKE_PROFIT_MARKET" + } + + result = append(result, OpenOrder{ + OrderID: order.AlgoId, + Symbol: symbol, + Side: side, + PositionSide: positionSide, + Type: orderType, + Price: 0, + StopPrice: triggerPrice, + Quantity: quantity, + Status: "NEW", + }) + } + } + } + + logger.Infof("✓ OKX GetOpenOrders: found %d open orders for %s", len(result), symbol) + return result, nil +} + +// PlaceLimitOrder places a limit order for grid trading +// Implements GridTrader interface +func (t *OKXTrader) PlaceLimitOrder(req *LimitOrderRequest) (*LimitOrderResult, error) { + instId := t.convertSymbol(req.Symbol) + + // Get instrument info + inst, err := t.getInstrument(req.Symbol) + if err != nil { + return nil, fmt.Errorf("failed to get instrument info: %w", err) + } + + // Set leverage if specified + if req.Leverage > 0 { + if err := t.SetLeverage(req.Symbol, req.Leverage); err != nil { + logger.Warnf("[OKX] Failed to set leverage: %v", err) + } + } + + // Convert quantity to contract size + sz := req.Quantity / inst.CtVal + szStr := t.formatSize(sz, inst) + + // Determine side and position side + side := "buy" + posSide := "long" + if req.Side == "SELL" { + side = "sell" + posSide = "short" + } + + body := map[string]interface{}{ + "instId": instId, + "tdMode": "cross", + "side": side, + "posSide": posSide, + "ordType": "limit", + "sz": szStr, + "px": fmt.Sprintf("%.8f", req.Price), + "clOrdId": genOkxClOrdID(), + "tag": okxTag, + } + + // Add reduce only if specified + if req.ReduceOnly { + body["reduceOnly"] = true + } + + logger.Infof("[OKX] PlaceLimitOrder: %s %s @ %.4f, sz=%s", instId, side, req.Price, szStr) + + data, err := t.doRequest("POST", okxOrderPath, body) + if err != nil { + return nil, fmt.Errorf("failed to place limit order: %w", err) + } + + var orders []struct { + OrdId string `json:"ordId"` + ClOrdId string `json:"clOrdId"` + SCode string `json:"sCode"` + SMsg string `json:"sMsg"` + } + + if err := json.Unmarshal(data, &orders); err != nil { + return nil, fmt.Errorf("failed to parse order response: %w", err) + } + + if len(orders) == 0 { + return nil, fmt.Errorf("empty order response") + } + + if orders[0].SCode != "0" { + return nil, fmt.Errorf("OKX order failed: %s", orders[0].SMsg) + } + + logger.Infof("✓ [OKX] Limit order placed: %s %s @ %.4f, orderID=%s", + instId, side, req.Price, orders[0].OrdId) + + return &LimitOrderResult{ + OrderID: orders[0].OrdId, + ClientID: orders[0].ClOrdId, + Symbol: req.Symbol, + Side: req.Side, + PositionSide: req.PositionSide, + Price: req.Price, + Quantity: req.Quantity, + Status: "NEW", + }, nil +} + +// CancelOrder cancels a specific order by ID +// Implements GridTrader interface +func (t *OKXTrader) CancelOrder(symbol, orderID string) error { + instId := t.convertSymbol(symbol) + + body := map[string]interface{}{ + "instId": instId, + "ordId": orderID, + } + + _, err := t.doRequest("POST", "/api/v5/trade/cancel-order", body) + if err != nil { + return fmt.Errorf("failed to cancel order: %w", err) + } + + logger.Infof("✓ [OKX] Order cancelled: %s %s", symbol, orderID) + return nil +} + +// GetOrderBook gets the order book for a symbol +// Implements GridTrader interface +func (t *OKXTrader) GetOrderBook(symbol string, depth int) (bids, asks [][]float64, err error) { + instId := t.convertSymbol(symbol) + path := fmt.Sprintf("/api/v5/market/books?instId=%s&sz=%d", instId, depth) + + data, err := t.doRequest("GET", path, nil) + if err != nil { + return nil, nil, fmt.Errorf("failed to get order book: %w", err) + } + + var result []struct { + Bids [][]string `json:"bids"` + Asks [][]string `json:"asks"` + } + + if err := json.Unmarshal(data, &result); err != nil { + return nil, nil, fmt.Errorf("failed to parse order book: %w", err) + } + + if len(result) == 0 { + return nil, nil, nil + } + + // Parse bids + for _, b := range result[0].Bids { + if len(b) >= 2 { + price, _ := strconv.ParseFloat(b[0], 64) + qty, _ := strconv.ParseFloat(b[1], 64) + bids = append(bids, []float64{price, qty}) + } + } + + // Parse asks + for _, a := range result[0].Asks { + if len(a) >= 2 { + price, _ := strconv.ParseFloat(a[0], 64) + qty, _ := strconv.ParseFloat(a[1], 64) + asks = append(asks, []float64{price, qty}) + } + } + + return bids, asks, nil } diff --git a/web/src/components/strategy/GridConfigEditor.tsx b/web/src/components/strategy/GridConfigEditor.tsx new file mode 100644 index 00000000..7756f2ba --- /dev/null +++ b/web/src/components/strategy/GridConfigEditor.tsx @@ -0,0 +1,424 @@ +import { Grid, DollarSign, TrendingUp, Shield } from 'lucide-react' +import type { GridStrategyConfig } from '../../types' + +interface GridConfigEditorProps { + config: GridStrategyConfig + onChange: (config: GridStrategyConfig) => void + disabled?: boolean + language: string +} + +// Default grid config +export const defaultGridConfig: GridStrategyConfig = { + symbol: 'BTCUSDT', + grid_count: 10, + total_investment: 1000, + leverage: 5, + upper_price: 0, + lower_price: 0, + use_atr_bounds: true, + atr_multiplier: 2.0, + distribution: 'gaussian', + max_drawdown_pct: 15, + stop_loss_pct: 5, + daily_loss_limit_pct: 10, + use_maker_only: true, +} + +export function GridConfigEditor({ + config, + onChange, + disabled, + language, +}: GridConfigEditorProps) { + const t = (key: string) => { + const translations: Record> = { + // Section titles + tradingPair: { zh: '交易设置', en: 'Trading Setup' }, + gridParameters: { zh: '网格参数', en: 'Grid Parameters' }, + priceBounds: { zh: '价格边界', en: 'Price Bounds' }, + riskControl: { zh: '风险控制', en: 'Risk Control' }, + + // Trading pair + symbol: { zh: '交易对', en: 'Trading Pair' }, + symbolDesc: { zh: '选择要进行网格交易的交易对', en: 'Select trading pair for grid trading' }, + + // Investment + totalInvestment: { zh: '投资金额 (USDT)', en: 'Investment (USDT)' }, + totalInvestmentDesc: { zh: '网格策略的总投资金额', en: 'Total investment for grid strategy' }, + leverage: { zh: '杠杆倍数', en: 'Leverage' }, + leverageDesc: { zh: '交易使用的杠杆倍数 (1-20)', en: 'Leverage for trading (1-20)' }, + + // Grid parameters + gridCount: { zh: '网格数量', en: 'Grid Count' }, + gridCountDesc: { zh: '网格层级数量 (5-50)', en: 'Number of grid levels (5-50)' }, + distribution: { zh: '资金分配方式', en: 'Distribution' }, + distributionDesc: { zh: '网格层级的资金分配方式', en: 'Fund allocation across grid levels' }, + uniform: { zh: '均匀分配', en: 'Uniform' }, + gaussian: { zh: '高斯分配 (推荐)', en: 'Gaussian (Recommended)' }, + pyramid: { zh: '金字塔分配', en: 'Pyramid' }, + + // Price bounds + useAtrBounds: { zh: '自动计算边界 (ATR)', en: 'Auto-calculate Bounds (ATR)' }, + useAtrBoundsDesc: { zh: '基于 ATR 自动计算网格上下边界', en: 'Auto-calculate bounds based on ATR' }, + atrMultiplier: { zh: 'ATR 倍数', en: 'ATR Multiplier' }, + atrMultiplierDesc: { zh: '边界距离当前价格的 ATR 倍数', en: 'ATR multiplier for bounds distance' }, + upperPrice: { zh: '上边界价格', en: 'Upper Price' }, + upperPriceDesc: { zh: '网格上边界价格 (0=自动计算)', en: 'Grid upper bound (0=auto)' }, + lowerPrice: { zh: '下边界价格', en: 'Lower Price' }, + lowerPriceDesc: { zh: '网格下边界价格 (0=自动计算)', en: 'Grid lower bound (0=auto)' }, + + // Risk control + maxDrawdown: { zh: '最大回撤 (%)', en: 'Max Drawdown (%)' }, + maxDrawdownDesc: { zh: '触发紧急退出的最大回撤百分比', en: 'Max drawdown before emergency exit' }, + stopLoss: { zh: '止损 (%)', en: 'Stop Loss (%)' }, + stopLossDesc: { zh: '单仓位止损百分比', en: 'Stop loss per position' }, + dailyLossLimit: { zh: '日损失限制 (%)', en: 'Daily Loss Limit (%)' }, + dailyLossLimitDesc: { zh: '每日最大亏损百分比', en: 'Maximum daily loss percentage' }, + useMakerOnly: { zh: '仅使用 Maker 订单', en: 'Maker Only Orders' }, + useMakerOnlyDesc: { zh: '使用限价单以降低手续费', en: 'Use limit orders for lower fees' }, + } + return translations[key]?.[language] || key + } + + const updateField = ( + key: K, + value: GridStrategyConfig[K] + ) => { + if (!disabled) { + onChange({ ...config, [key]: value }) + } + } + + const inputStyle = { + background: '#1E2329', + border: '1px solid #2B3139', + color: '#EAECEF', + } + + const sectionStyle = { + background: '#0B0E11', + border: '1px solid #2B3139', + } + + return ( +
+ {/* Trading Setup */} +
+
+ +

+ {t('tradingPair')} +

+
+ +
+ {/* Symbol */} +
+ +

+ {t('symbolDesc')} +

+ +
+ + {/* Investment */} +
+ +

+ {t('totalInvestmentDesc')} +

+ updateField('total_investment', parseFloat(e.target.value) || 1000)} + disabled={disabled} + min={100} + step={100} + className="w-full px-3 py-2 rounded" + style={inputStyle} + /> +
+ + {/* Leverage */} +
+ +

+ {t('leverageDesc')} +

+ updateField('leverage', parseInt(e.target.value) || 5)} + disabled={disabled} + min={1} + max={20} + className="w-full px-3 py-2 rounded" + style={inputStyle} + /> +
+
+
+ + {/* Grid Parameters */} +
+
+ +

+ {t('gridParameters')} +

+
+ +
+ {/* Grid Count */} +
+ +

+ {t('gridCountDesc')} +

+ updateField('grid_count', parseInt(e.target.value) || 10)} + disabled={disabled} + min={5} + max={50} + className="w-full px-3 py-2 rounded" + style={inputStyle} + /> +
+ + {/* Distribution */} +
+ +

+ {t('distributionDesc')} +

+ +
+
+
+ + {/* Price Bounds */} +
+
+ +

+ {t('priceBounds')} +

+
+ + {/* ATR Toggle */} +
+
+
+ +

+ {t('useAtrBoundsDesc')} +

+
+ +
+
+ + {config.use_atr_bounds ? ( +
+ +

+ {t('atrMultiplierDesc')} +

+ updateField('atr_multiplier', parseFloat(e.target.value) || 2.0)} + disabled={disabled} + min={1} + max={5} + step={0.5} + className="w-32 px-3 py-2 rounded" + style={inputStyle} + /> +
+ ) : ( +
+
+ +

+ {t('upperPriceDesc')} +

+ updateField('upper_price', parseFloat(e.target.value) || 0)} + disabled={disabled} + min={0} + step={0.01} + className="w-full px-3 py-2 rounded" + style={inputStyle} + /> +
+
+ +

+ {t('lowerPriceDesc')} +

+ updateField('lower_price', parseFloat(e.target.value) || 0)} + disabled={disabled} + min={0} + step={0.01} + className="w-full px-3 py-2 rounded" + style={inputStyle} + /> +
+
+ )} +
+ + {/* Risk Control */} +
+
+ +

+ {t('riskControl')} +

+
+ +
+
+ +

+ {t('maxDrawdownDesc')} +

+ updateField('max_drawdown_pct', parseFloat(e.target.value) || 15)} + disabled={disabled} + min={5} + max={50} + className="w-full px-3 py-2 rounded" + style={inputStyle} + /> +
+ +
+ +

+ {t('stopLossDesc')} +

+ updateField('stop_loss_pct', parseFloat(e.target.value) || 5)} + disabled={disabled} + min={1} + max={20} + className="w-full px-3 py-2 rounded" + style={inputStyle} + /> +
+ +
+ +

+ {t('dailyLossLimitDesc')} +

+ updateField('daily_loss_limit_pct', parseFloat(e.target.value) || 10)} + disabled={disabled} + min={1} + max={30} + className="w-full px-3 py-2 rounded" + style={inputStyle} + /> +
+
+ + {/* Maker Only Toggle */} +
+
+
+ +

+ {t('useMakerOnlyDesc')} +

+
+ +
+
+
+
+ ) +} diff --git a/web/src/components/strategy/GridRiskPanel.tsx b/web/src/components/strategy/GridRiskPanel.tsx new file mode 100644 index 00000000..75c5b73b --- /dev/null +++ b/web/src/components/strategy/GridRiskPanel.tsx @@ -0,0 +1,372 @@ +import { useState, useEffect, useCallback } from 'react' +import { Shield, TrendingUp, AlertTriangle, Activity, Box, ChevronDown, ChevronUp } from 'lucide-react' +import type { GridRiskInfo } from '../../types' + +interface GridRiskPanelProps { + traderId: string + language?: string + refreshInterval?: number // ms, default 5000 +} + +export function GridRiskPanel({ + traderId, + language = 'en', + refreshInterval = 5000, +}: GridRiskPanelProps) { + const [riskInfo, setRiskInfo] = useState(null) + const [loading, setLoading] = useState(true) + const [error, setError] = useState(null) + const [expanded, setExpanded] = useState(false) + + const t = (key: string) => { + const translations: Record> = { + // Section titles + gridRisk: { zh: '网格风控', en: 'Grid Risk' }, + leverageInfo: { zh: '杠杆', en: 'Leverage' }, + positionInfo: { zh: '仓位', en: 'Position' }, + liquidationInfo: { zh: '清算', en: 'Liquidation' }, + marketState: { zh: '市场', en: 'Market' }, + boxState: { zh: '箱体', en: 'Box' }, + + // Leverage + currentLeverage: { zh: '当前', en: 'Current' }, + effectiveLeverage: { zh: '有效', en: 'Effective' }, + recommendedLeverage: { zh: '建议', en: 'Recommend' }, + + // Position + currentPosition: { zh: '当前', en: 'Current' }, + maxPosition: { zh: '最大', en: 'Max' }, + positionPercent: { zh: '占比', en: 'Usage' }, + + // Liquidation + liquidationPrice: { zh: '清算价', en: 'Liq Price' }, + liquidationDistance: { zh: '距离', en: 'Distance' }, + + // Market + regimeLevel: { zh: '波动', en: 'Regime' }, + currentPrice: { zh: '价格', en: 'Price' }, + breakoutLevel: { zh: '突破', en: 'Breakout' }, + breakoutDirection: { zh: '方向', en: 'Direction' }, + + // Box + shortBox: { zh: '短期', en: 'Short' }, + midBox: { zh: '中期', en: 'Mid' }, + longBox: { zh: '长期', en: 'Long' }, + + // Regime levels + narrow: { zh: '窄幅', en: 'Narrow' }, + standard: { zh: '标准', en: 'Standard' }, + wide: { zh: '宽幅', en: 'Wide' }, + volatile: { zh: '剧烈', en: 'Volatile' }, + trending: { zh: '趋势', en: 'Trending' }, + + // Breakout levels + none: { zh: '无', en: 'None' }, + short: { zh: '短期', en: 'Short' }, + mid: { zh: '中期', en: 'Mid' }, + long: { zh: '长期', en: 'Long' }, + + // Directions + up: { zh: '↑', en: '↑' }, + down: { zh: '↓', en: '↓' }, + + // Status + loading: { zh: '加载中...', en: 'Loading...' }, + error: { zh: '加载失败', en: 'Load Failed' }, + noData: { zh: '暂无数据', en: 'No Data' }, + } + return translations[key]?.[language] || key + } + + const fetchRiskInfo = useCallback(async () => { + try { + const token = localStorage.getItem('auth_token') + const response = await fetch(`/api/traders/${traderId}/grid-risk`, { + headers: { + Authorization: `Bearer ${token}`, + }, + }) + + if (!response.ok) { + throw new Error(`HTTP ${response.status}`) + } + + const data = await response.json() + setRiskInfo(data) + setError(null) + } catch (err) { + setError(err instanceof Error ? err.message : 'Unknown error') + } finally { + setLoading(false) + } + }, [traderId]) + + useEffect(() => { + fetchRiskInfo() + const interval = setInterval(fetchRiskInfo, refreshInterval) + return () => clearInterval(interval) + }, [fetchRiskInfo, refreshInterval]) + + const getRegimeColor = (regime: string) => { + switch (regime) { + case 'narrow': return '#0ECB81' + case 'standard': return '#F0B90B' + case 'wide': return '#F7931A' + case 'volatile': return '#F6465D' + case 'trending': return '#8B5CF6' + default: return '#848E9C' + } + } + + const getBreakoutColor = (level: string) => { + switch (level) { + case 'none': return '#0ECB81' + case 'short': return '#F0B90B' + case 'mid': return '#F7931A' + case 'long': return '#F6465D' + default: return '#848E9C' + } + } + + const getPositionColor = (percent: number) => { + if (percent < 50) return '#0ECB81' + if (percent < 80) return '#F0B90B' + return '#F6465D' + } + + const formatPrice = (price: number) => { + if (price === 0) return '-' + if (price >= 1000) return price.toLocaleString('en-US', { minimumFractionDigits: 2, maximumFractionDigits: 2 }) + if (price >= 1) return price.toFixed(4) + return price.toFixed(6) + } + + const formatUSD = (value: number) => { + return `$${value.toLocaleString('en-US', { minimumFractionDigits: 0, maximumFractionDigits: 0 })}` + } + + const cardStyle = { + background: '#0B0E11', + border: '1px solid #2B3139', + } + + if (loading) { + return ( +
+ {t('loading')} +
+ ) + } + + if (error) { + return ( +
+ {t('error')}: {error} +
+ ) + } + + if (!riskInfo) { + return ( +
+ {t('noData')} +
+ ) + } + + return ( +
+ {/* Collapsible Header */} +
setExpanded(!expanded)} + > +
+ + + {t('gridRisk')} + +
+
+ {/* Summary badges when collapsed */} +
+ + {t(riskInfo.regime_level || 'standard')} + + + {riskInfo.effective_leverage.toFixed(1)}x + + + {riskInfo.position_percent.toFixed(0)}% + +
+ {expanded ? ( + + ) : ( + + )} +
+
+ + {/* Expanded Content */} + {expanded && ( +
+ {/* Row 1: Leverage & Position */} +
+ {/* Leverage */} +
+
+ + {t('leverageInfo')} +
+
+
+
{t('currentLeverage')}
+
{riskInfo.current_leverage}x
+
+
+
{t('effectiveLeverage')}
+
{riskInfo.effective_leverage.toFixed(2)}x
+
+
+
{t('recommendedLeverage')}
+
riskInfo.recommended_leverage ? '#F6465D' : '#0ECB81' }} + > + {riskInfo.recommended_leverage}x +
+
+
+
+ + {/* Position */} +
+
+ + {t('positionInfo')} +
+
+
+
{t('currentPosition')}
+
{formatUSD(riskInfo.current_position)}
+
+
+
{t('maxPosition')}
+
{formatUSD(riskInfo.max_position)}
+
+
+
{t('positionPercent')}
+
+ {riskInfo.position_percent.toFixed(1)}% +
+
+
+ {/* Mini progress bar */} +
+
+
+
+
+ + {/* Row 2: Market State & Liquidation */} +
+ {/* Market State */} +
+
+ + {t('marketState')} +
+
+
+
{t('regimeLevel')}
+
+ {t(riskInfo.regime_level || 'standard')} +
+
+
+
{t('currentPrice')}
+
{formatPrice(riskInfo.current_price)}
+
+
+
{t('breakoutLevel')}
+
+ {t(riskInfo.breakout_level || 'none')} +
+
+
+
{t('breakoutDirection')}
+
+ {riskInfo.breakout_direction ? t(riskInfo.breakout_direction) : '-'} +
+
+
+
+ + {/* Liquidation */} +
+
+ + {t('liquidationInfo')} +
+
+
+
{t('liquidationPrice')}
+
+ {riskInfo.liquidation_price > 0 ? formatPrice(riskInfo.liquidation_price) : '-'} +
+
+
+
{t('liquidationDistance')}
+
+ {riskInfo.liquidation_distance > 0 ? `${riskInfo.liquidation_distance.toFixed(1)}%` : '-'} +
+
+
+
+
+ + {/* Row 3: Box State */} +
+
+ + {t('boxState')} +
+
+
+ {t('shortBox')} + + {formatPrice(riskInfo.short_box_lower)} - {formatPrice(riskInfo.short_box_upper)} + +
+
+ {t('midBox')} + + {formatPrice(riskInfo.mid_box_lower)} - {formatPrice(riskInfo.mid_box_upper)} + +
+
+ {t('longBox')} + + {formatPrice(riskInfo.long_box_lower)} - {formatPrice(riskInfo.long_box_upper)} + +
+
+
+
+ )} +
+ ) +} diff --git a/web/src/pages/StrategyStudioPage.tsx b/web/src/pages/StrategyStudioPage.tsx index fcb3fc17..8b21d502 100644 --- a/web/src/pages/StrategyStudioPage.tsx +++ b/web/src/pages/StrategyStudioPage.tsx @@ -37,6 +37,7 @@ import { IndicatorEditor } from '../components/strategy/IndicatorEditor' import { RiskControlEditor } from '../components/strategy/RiskControlEditor' import { PromptSectionsEditor } from '../components/strategy/PromptSectionsEditor' import { PublishSettingsEditor } from '../components/strategy/PublishSettingsEditor' +import { GridConfigEditor, defaultGridConfig } from '../components/strategy/GridConfigEditor' import { DeepVoidBackground } from '../components/DeepVoidBackground' const API_BASE = import.meta.env.VITE_API_BASE || '' @@ -59,6 +60,7 @@ export function StrategyStudioPage() { // Accordion states for left panel const [expandedSections, setExpandedSections] = useState({ + gridConfig: true, coinSource: true, indicators: false, riskControl: false, @@ -486,6 +488,12 @@ export function StrategyStudioPage() { subtitle: { zh: '可视化配置和测试交易策略', en: 'Configure and test trading strategies' }, strategies: { zh: '策略', en: 'Strategies' }, newStrategy: { zh: '新建', en: 'New' }, + strategyType: { zh: '策略类型', en: 'Strategy Type' }, + aiTrading: { zh: 'AI 智能交易', en: 'AI Trading' }, + aiTradingDesc: { zh: 'AI 分析市场并自主决策买卖', en: 'AI analyzes market and makes trading decisions' }, + gridTrading: { zh: 'AI 网格交易', en: 'AI Grid Trading' }, + gridTradingDesc: { zh: 'AI 控制网格策略,在震荡市场获利', en: 'AI-controlled grid strategy for ranging markets' }, + gridConfig: { zh: '网格配置', en: 'Grid Configuration' }, coinSource: { zh: '币种来源', en: 'Coin Source' }, indicators: { zh: '技术指标', en: 'Indicators' }, riskControl: { zh: '风控参数', en: 'Risk Control' }, @@ -533,12 +541,33 @@ export function StrategyStudioPage() { ) } + // Get current strategy type (default to ai_trading if not set) + const currentStrategyType = editingConfig?.strategy_type || 'ai_trading' + const configSections = [ + // Grid Config - only for grid_trading + { + key: 'gridConfig' as const, + icon: Activity, + color: '#0ECB81', + title: t('gridConfig'), + forStrategyType: 'grid_trading' as const, + content: editingConfig?.grid_config && ( + updateConfig('grid_config', gridConfig)} + disabled={selectedStrategy?.is_default} + language={language} + /> + ), + }, + // AI Trading sections { key: 'coinSource' as const, icon: Target, color: '#F0B90B', title: t('coinSource'), + forStrategyType: 'ai_trading' as const, content: editingConfig && (

@@ -616,6 +649,7 @@ export function StrategyStudioPage() { icon: Globe, color: '#0ECB81', title: t('publishSettings'), + forStrategyType: 'both' as const, content: selectedStrategy && ( ), }, - ] + ].filter(section => + section.forStrategyType === 'both' || section.forStrategyType === currentStrategyType + ) return ( @@ -813,6 +849,62 @@ export function StrategyStudioPage() {

+ {/* Strategy Type Selector */} + {editingConfig && ( +
+
+ + {t('strategyType')} +
+
+ + +
+
+ )} + {/* Config Sections */}
{configSections.map(({ key, icon: Icon, color, title, content }) => ( diff --git a/web/src/pages/TraderDashboardPage.tsx b/web/src/pages/TraderDashboardPage.tsx index 6f60b76f..3c46f2f1 100644 --- a/web/src/pages/TraderDashboardPage.tsx +++ b/web/src/pages/TraderDashboardPage.tsx @@ -9,6 +9,7 @@ import { confirmToast, notify } from '../lib/notify' import { t, type Language } from '../i18n/translations' import { LogOut, Loader2, Eye, EyeOff, Copy, Check } from 'lucide-react' import { DeepVoidBackground } from '../components/DeepVoidBackground' +import { GridRiskPanel } from '../components/strategy/GridRiskPanel' import type { SystemStatus, AccountInfo, @@ -151,6 +152,13 @@ export function TraderDashboardPage({ setPositionsCurrentPage(1) }, [selectedTraderId, positionsPageSize]) + // Auto-set chart symbol for grid trading + useEffect(() => { + if (status?.strategy_type === 'grid_trading' && status?.grid_symbol) { + setSelectedChartSymbol(status.grid_symbol) + } + }, [status?.strategy_type, status?.grid_symbol]) + // Get current exchange info for perp-dex wallet display const currentExchange = exchanges?.find( (e) => e.id === selectedTrader?.exchange_id @@ -532,6 +540,17 @@ export function TraderDashboardPage({ />
+ {/* Grid Risk Panel - Only show for grid trading strategy */} + {status?.strategy_type === 'grid_trading' && selectedTraderId && ( +
+ +
+ )} + {/* Main Content Area */}
{/* Left Column: Charts + Positions */} diff --git a/web/src/types.ts b/web/src/types.ts index b4f39278..d2e1d1b0 100644 --- a/web/src/types.ts +++ b/web/src/types.ts @@ -11,6 +11,8 @@ export interface SystemStatus { stop_until: string last_reset_time: string ai_provider: string + strategy_type?: 'ai_trading' | 'grid_trading' + grid_symbol?: string } export interface AccountInfo { @@ -462,6 +464,8 @@ export interface PromptSectionsConfig { } export interface StrategyConfig { + // Strategy type: "ai_trading" (default) or "grid_trading" + strategy_type?: 'ai_trading' | 'grid_trading'; // Language setting: "zh" for Chinese, "en" for English // Determines the language used for data formatting and prompt generation language?: 'zh' | 'en'; @@ -470,6 +474,38 @@ export interface StrategyConfig { custom_prompt?: string; risk_control: RiskControlConfig; prompt_sections?: PromptSectionsConfig; + // Grid trading configuration (only used when strategy_type is 'grid_trading') + grid_config?: GridStrategyConfig; +} + +// Grid trading specific configuration +export interface GridStrategyConfig { + // Trading pair (e.g., "BTCUSDT") + symbol: string; + // Number of grid levels (5-50) + grid_count: number; + // Total investment in USDT + total_investment: number; + // Leverage (1-20) + leverage: number; + // Upper price boundary (0 = auto-calculate from ATR) + upper_price: number; + // Lower price boundary (0 = auto-calculate from ATR) + lower_price: number; + // Use ATR to auto-calculate bounds + use_atr_bounds: boolean; + // ATR multiplier for bound calculation (default 2.0) + atr_multiplier: number; + // Position distribution: "uniform" | "gaussian" | "pyramid" + distribution: 'uniform' | 'gaussian' | 'pyramid'; + // Maximum drawdown percentage before emergency exit + max_drawdown_pct: number; + // Stop loss percentage per position + stop_loss_pct: number; + // Daily loss limit percentage + daily_loss_limit_pct: number; + // Use maker-only orders for lower fees + use_maker_only: boolean; } export interface CoinSourceConfig { @@ -750,3 +786,36 @@ export interface PositionHistoryResponse { symbol_stats: SymbolStats[]; direction_stats: DirectionStats[]; } + +// Grid Risk Information for frontend display +export interface GridRiskInfo { + // Leverage info + current_leverage: number + effective_leverage: number + recommended_leverage: number + + // Position info + current_position: number + max_position: number + position_percent: number + + // Liquidation info + liquidation_price: number + liquidation_distance: number + + // Market state + regime_level: string + + // Box state + short_box_upper: number + short_box_lower: number + mid_box_upper: number + mid_box_lower: number + long_box_upper: number + long_box_lower: number + current_price: number + + // Breakout state + breakout_level: string + breakout_direction: string +}