diff --git a/decision/engine.go b/decision/engine.go index a6b0113e..ea0f4fbc 100644 --- a/decision/engine.go +++ b/decision/engine.go @@ -346,13 +346,17 @@ func buildSystemPrompt(accountEquity float64, btcEthLeverage, altcoinLeverage in sb.WriteString("\n") sb.WriteString("```json\n[\n") sb.WriteString(fmt.Sprintf(" {\"symbol\": \"BTCUSDT\", \"action\": \"open_short\", \"leverage\": %d, \"position_size_usd\": %.0f, \"stop_loss\": 97000, \"take_profit\": 91000, \"confidence\": 85, \"risk_usd\": 300, \"reasoning\": \"下跌趋势+MACD死叉\"},\n", btcEthLeverage, accountEquity*5)) + sb.WriteString(" {\"symbol\": \"SOLUSDT\", \"action\": \"update_stop_loss\", \"new_stop_loss\": 155, \"reasoning\": \"移动止损至保本位\"},\n") sb.WriteString(" {\"symbol\": \"ETHUSDT\", \"action\": \"close_long\", \"reasoning\": \"止盈离场\"}\n") sb.WriteString("]\n```\n") sb.WriteString("\n\n") sb.WriteString("## 字段说明\n\n") sb.WriteString("- `action`: open_long | open_short | close_long | close_short | update_stop_loss | update_take_profit | partial_close | hold | wait\n") sb.WriteString("- `confidence`: 0-100(开仓建议≥75)\n") - sb.WriteString("- 开仓时必填: leverage, position_size_usd, stop_loss, take_profit, confidence, risk_usd, reasoning\n\n") + sb.WriteString("- 开仓时必填: leverage, position_size_usd, stop_loss, take_profit, confidence, risk_usd, reasoning\n") + sb.WriteString("- update_stop_loss 时必填: new_stop_loss (注意是 new_stop_loss,不是 stop_loss)\n") + sb.WriteString("- update_take_profit 时必填: new_take_profit (注意是 new_take_profit,不是 take_profit)\n") + sb.WriteString("- partial_close 时必填: close_percentage (0-100)\n\n") return sb.String() } diff --git a/decision/validate_test.go b/decision/validate_test.go index faac4fe5..d7e89229 100644 --- a/decision/validate_test.go +++ b/decision/validate_test.go @@ -98,3 +98,198 @@ func TestLeverageFallback(t *testing.T) { }) } } + +// TestUpdateStopLossValidation 测试 update_stop_loss 动作的字段验证 +func TestUpdateStopLossValidation(t *testing.T) { + tests := []struct { + name string + decision Decision + wantError bool + errorMsg string + }{ + { + name: "正确使用new_stop_loss字段", + decision: Decision{ + Symbol: "SOLUSDT", + Action: "update_stop_loss", + NewStopLoss: 155.5, + Reasoning: "移动止损至保本位", + }, + wantError: false, + }, + { + name: "new_stop_loss为0应该报错", + decision: Decision{ + Symbol: "SOLUSDT", + Action: "update_stop_loss", + NewStopLoss: 0, + Reasoning: "测试错误情况", + }, + wantError: true, + errorMsg: "新止损价格必须大于0", + }, + { + name: "new_stop_loss为负数应该报错", + decision: Decision{ + Symbol: "SOLUSDT", + Action: "update_stop_loss", + NewStopLoss: -100, + Reasoning: "测试错误情况", + }, + wantError: true, + errorMsg: "新止损价格必须大于0", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateDecision(&tt.decision, 1000.0, 10, 5) + + if (err != nil) != tt.wantError { + t.Errorf("validateDecision() error = %v, wantError %v", err, tt.wantError) + return + } + + if tt.wantError && err != nil { + if tt.errorMsg != "" && !contains(err.Error(), tt.errorMsg) { + t.Errorf("错误信息不匹配: got %q, want to contain %q", err.Error(), tt.errorMsg) + } + } + }) + } +} + +// TestUpdateTakeProfitValidation 测试 update_take_profit 动作的字段验证 +func TestUpdateTakeProfitValidation(t *testing.T) { + tests := []struct { + name string + decision Decision + wantError bool + errorMsg string + }{ + { + name: "正确使用new_take_profit字段", + decision: Decision{ + Symbol: "BTCUSDT", + Action: "update_take_profit", + NewTakeProfit: 98000, + Reasoning: "调整止盈至关键阻力位", + }, + wantError: false, + }, + { + name: "new_take_profit为0应该报错", + decision: Decision{ + Symbol: "BTCUSDT", + Action: "update_take_profit", + NewTakeProfit: 0, + Reasoning: "测试错误情况", + }, + wantError: true, + errorMsg: "新止盈价格必须大于0", + }, + { + name: "new_take_profit为负数应该报错", + decision: Decision{ + Symbol: "BTCUSDT", + Action: "update_take_profit", + NewTakeProfit: -1000, + Reasoning: "测试错误情况", + }, + wantError: true, + errorMsg: "新止盈价格必须大于0", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateDecision(&tt.decision, 1000.0, 10, 5) + + if (err != nil) != tt.wantError { + t.Errorf("validateDecision() error = %v, wantError %v", err, tt.wantError) + return + } + + if tt.wantError && err != nil { + if tt.errorMsg != "" && !contains(err.Error(), tt.errorMsg) { + t.Errorf("错误信息不匹配: got %q, want to contain %q", err.Error(), tt.errorMsg) + } + } + }) + } +} + +// TestPartialCloseValidation 测试 partial_close 动作的字段验证 +func TestPartialCloseValidation(t *testing.T) { + tests := []struct { + name string + decision Decision + wantError bool + errorMsg string + }{ + { + name: "正确使用close_percentage字段", + decision: Decision{ + Symbol: "ETHUSDT", + Action: "partial_close", + ClosePercentage: 50.0, + Reasoning: "锁定一半利润", + }, + wantError: false, + }, + { + name: "close_percentage为0应该报错", + decision: Decision{ + Symbol: "ETHUSDT", + Action: "partial_close", + ClosePercentage: 0, + Reasoning: "测试错误情况", + }, + wantError: true, + errorMsg: "平仓百分比必须在0-100之间", + }, + { + name: "close_percentage超过100应该报错", + decision: Decision{ + Symbol: "ETHUSDT", + Action: "partial_close", + ClosePercentage: 150, + Reasoning: "测试错误情况", + }, + wantError: true, + errorMsg: "平仓百分比必须在0-100之间", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateDecision(&tt.decision, 1000.0, 10, 5) + + if (err != nil) != tt.wantError { + t.Errorf("validateDecision() error = %v, wantError %v", err, tt.wantError) + return + } + + if tt.wantError && err != nil { + if tt.errorMsg != "" && !contains(err.Error(), tt.errorMsg) { + t.Errorf("错误信息不匹配: got %q, want to contain %q", err.Error(), tt.errorMsg) + } + } + }) + } +} + +// contains 检查字符串是否包含子串(辅助函数) +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(substr) == 0 || + (len(s) > 0 && len(substr) > 0 && stringContains(s, substr))) +} + +func stringContains(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +}