diff --git a/decision/engine.go b/decision/engine.go index c8aa4f51..cc3ea3dc 100644 --- a/decision/engine.go +++ b/decision/engine.go @@ -121,12 +121,12 @@ type FullDecision struct { } // GetFullDecision 获取AI的完整交易决策(批量分析所有币种和持仓) -func GetFullDecision(ctx *Context, mcpClient *mcp.Client) (*FullDecision, error) { +func GetFullDecision(ctx *Context, mcpClient mcp.AIClient) (*FullDecision, error) { return GetFullDecisionWithCustomPrompt(ctx, mcpClient, "", false, "") } // GetFullDecisionWithCustomPrompt 获取AI的完整交易决策(支持自定义prompt和模板选择) -func GetFullDecisionWithCustomPrompt(ctx *Context, mcpClient *mcp.Client, customPrompt string, overrideBase bool, templateName string) (*FullDecision, error) { +func GetFullDecisionWithCustomPrompt(ctx *Context, mcpClient mcp.AIClient, customPrompt string, overrideBase bool, templateName string) (*FullDecision, error) { // 1. 为所有币种获取市场数据 if err := fetchMarketDataForContext(ctx); err != nil { return nil, fmt.Errorf("获取市场数据失败: %w", err) diff --git a/logger/decision_logger.go b/logger/decision_logger.go index a15377f2..1886f51c 100644 --- a/logger/decision_logger.go +++ b/logger/decision_logger.go @@ -64,6 +64,22 @@ type DecisionAction struct { Error string `json:"error"` // 错误信息 } +// IDecisionLogger 决策日志记录器接口 +type IDecisionLogger interface { + // LogDecision 记录决策 + LogDecision(record *DecisionRecord) error + // GetLatestRecords 获取最近N条记录(按时间正序:从旧到新) + GetLatestRecords(n int) ([]*DecisionRecord, error) + // GetRecordByDate 获取指定日期的所有记录 + GetRecordByDate(date time.Time) ([]*DecisionRecord, error) + // CleanOldRecords 清理N天前的旧记录 + CleanOldRecords(days int) error + // GetStatistics 获取统计信息 + GetStatistics() (*Statistics, error) + // AnalyzePerformance 分析最近N个周期的交易表现 + AnalyzePerformance(lookbackCycles int) (*PerformanceAnalysis, error) +} + // DecisionLogger 决策日志记录器 type DecisionLogger struct { logDir string @@ -71,7 +87,7 @@ type DecisionLogger struct { } // NewDecisionLogger 创建决策日志记录器 -func NewDecisionLogger(logDir string) *DecisionLogger { +func NewDecisionLogger(logDir string) IDecisionLogger { if logDir == "" { logDir = "decision_logs" } diff --git a/mcp/client.go b/mcp/client.go index 0f785534..a35d92e7 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -13,18 +13,17 @@ import ( "time" ) -// Provider AI提供商类型 -type Provider string - const ( - ProviderDeepSeek Provider = "deepseek" - ProviderQwen Provider = "qwen" - ProviderCustom Provider = "custom" + ProviderCustom = "custom" +) + +var ( + DefaultTimeout = 120 * time.Second ) // Client AI API配置 type Client struct { - Provider Provider + Provider string APIKey string BaseURL string Model string @@ -33,7 +32,7 @@ type Client struct { MaxTokens int // AI响应的最大token数 } -func New() *Client { +func New() AIClient { // 从环境变量读取 MaxTokens,默认 2000 maxTokens := 2000 if envMaxTokens := os.Getenv("AI_MAX_TOKENS"); envMaxTokens != "" { @@ -48,65 +47,15 @@ func New() *Client { // 默认配置 return &Client{ Provider: ProviderDeepSeek, - BaseURL: "https://api.deepseek.com/v1", - Model: "deepseek-chat", - Timeout: 120 * time.Second, // 增加到120秒,因为AI需要分析大量数据 + BaseURL: DefaultDeepSeekBaseURL, + Model: DefaultDeepSeekModel, + Timeout: DefaultTimeout, MaxTokens: maxTokens, } } -// SetDeepSeekAPIKey 设置DeepSeek API密钥 -// customURL 为空时使用默认URL,customModel 为空时使用默认模型 -func (client *Client) SetDeepSeekAPIKey(apiKey string, customURL string, customModel string) { - client.Provider = ProviderDeepSeek - client.APIKey = apiKey - if customURL != "" { - client.BaseURL = customURL - log.Printf("🔧 [MCP] DeepSeek 使用自定义 BaseURL: %s", customURL) - } else { - client.BaseURL = "https://api.deepseek.com/v1" - log.Printf("🔧 [MCP] DeepSeek 使用默认 BaseURL: %s", client.BaseURL) - } - if customModel != "" { - client.Model = customModel - log.Printf("🔧 [MCP] DeepSeek 使用自定义 Model: %s", customModel) - } else { - client.Model = "deepseek-chat" - log.Printf("🔧 [MCP] DeepSeek 使用默认 Model: %s", client.Model) - } - // 打印 API Key 的前后各4位用于验证 - if len(apiKey) > 8 { - log.Printf("🔧 [MCP] DeepSeek API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:]) - } -} - -// SetQwenAPIKey 设置阿里云Qwen API密钥 -// customURL 为空时使用默认URL,customModel 为空时使用默认模型 -func (client *Client) SetQwenAPIKey(apiKey string, customURL string, customModel string) { - client.Provider = ProviderQwen - client.APIKey = apiKey - if customURL != "" { - client.BaseURL = customURL - log.Printf("🔧 [MCP] Qwen 使用自定义 BaseURL: %s", customURL) - } else { - client.BaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1" - log.Printf("🔧 [MCP] Qwen 使用默认 BaseURL: %s", client.BaseURL) - } - if customModel != "" { - client.Model = customModel - log.Printf("🔧 [MCP] Qwen 使用自定义 Model: %s", customModel) - } else { - client.Model = "qwen3-max" - log.Printf("🔧 [MCP] Qwen 使用默认 Model: %s", client.Model) - } - // 打印 API Key 的前后各4位用于验证 - if len(apiKey) > 8 { - log.Printf("🔧 [MCP] Qwen API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:]) - } -} - // SetCustomAPI 设置自定义OpenAI兼容API -func (client *Client) SetCustomAPI(apiURL, apiKey, modelName string) { +func (client *Client) SetAPIKey(apiKey, apiURL, customModel string) { client.Provider = ProviderCustom client.APIKey = apiKey @@ -119,22 +68,14 @@ func (client *Client) SetCustomAPI(apiURL, apiKey, modelName string) { client.UseFullURL = false } - client.Model = modelName + client.Model = customModel client.Timeout = 120 * time.Second } -// SetClient 设置完整的AI配置(高级用户) -func (client *Client) SetClient(Client Client) { - if Client.Timeout == 0 { - Client.Timeout = 30 * time.Second - } - client = &Client -} - // CallWithMessages 使用 system + user prompt 调用AI API(推荐) func (client *Client) CallWithMessages(systemPrompt, userPrompt string) (string, error) { if client.APIKey == "" { - return "", fmt.Errorf("AI API密钥未设置,请先调用 SetDeepSeekAPIKey() 或 SetQwenAPIKey()") + return "", fmt.Errorf("AI API密钥未设置,请先调用 SetAPIKey") } // 重试配置 @@ -171,6 +112,10 @@ func (client *Client) CallWithMessages(systemPrompt, userPrompt string) (string, return "", fmt.Errorf("重试%d次后仍然失败: %w", maxRetries, lastErr) } +func (client *Client) setAuthHeader(reqHeader http.Header) { + reqHeader.Set("Authorization", fmt.Sprintf("Bearer %s", client.APIKey)) +} + // callOnce 单次调用AI API(内部使用) func (client *Client) callOnce(systemPrompt, userPrompt string) (string, error) { // 打印当前 AI 配置 @@ -234,17 +179,7 @@ func (client *Client) callOnce(systemPrompt, userPrompt string) (string, error) req.Header.Set("Content-Type", "application/json") - // 根据不同的Provider设置认证方式 - switch client.Provider { - case ProviderDeepSeek: - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", client.APIKey)) - case ProviderQwen: - // 阿里云Qwen使用API-Key认证 - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", client.APIKey)) - // 注意:如果使用的不是兼容模式,可能需要不同的认证方式 - default: - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", client.APIKey)) - } + client.setAuthHeader(req.Header) // 发送请求 httpClient := &http.Client{Timeout: client.Timeout} diff --git a/mcp/deepseek_client.go b/mcp/deepseek_client.go new file mode 100644 index 00000000..12489292 --- /dev/null +++ b/mcp/deepseek_client.go @@ -0,0 +1,53 @@ +package mcp + +import ( + "log" + "net/http" +) + +const ( + ProviderDeepSeek = "deepseek" + DefaultDeepSeekBaseURL = "https://api.deepseek.com/v1" + DefaultDeepSeekModel = "deepseek-chat" +) + +type DeepSeekClient struct { + *Client +} + +func NewDeepSeekClient() AIClient { + client := New().(*Client) + client.Provider = ProviderDeepSeek + client.Model = DefaultDeepSeekModel + client.BaseURL = DefaultDeepSeekBaseURL + return &DeepSeekClient{ + Client: client, + } +} + +func (dsClient *DeepSeekClient) SetAPIKey(apiKey string, customURL string, customModel string) { + if dsClient.Client == nil { + dsClient.Client = New().(*Client) + } + dsClient.Client.APIKey = apiKey + + if len(apiKey) > 8 { + log.Printf("🔧 [MCP] DeepSeek API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:]) + } + if customURL != "" { + dsClient.Client.BaseURL = customURL + log.Printf("🔧 [MCP] DeepSeek 使用自定义 BaseURL: %s", customURL) + } else { + log.Printf("🔧 [MCP] DeepSeek 使用默认 BaseURL: %s", dsClient.Client.BaseURL) + } + if customModel != "" { + dsClient.Client.Model = customModel + log.Printf("🔧 [MCP] DeepSeek 使用自定义 Model: %s", customModel) + } else { + log.Printf("🔧 [MCP] DeepSeek 使用默认 Model: %s", dsClient.Client.Model) + } +} + +func (dsClient *DeepSeekClient) setAuthHeader(reqHeaders http.Header) { + dsClient.Client.setAuthHeader(reqHeaders) +} diff --git a/mcp/interface.go b/mcp/interface.go new file mode 100644 index 00000000..8c9b9574 --- /dev/null +++ b/mcp/interface.go @@ -0,0 +1,12 @@ +package mcp + +import "net/http" + +// AIClient AI客户端接口 +type AIClient interface { + SetAPIKey(apiKey string, customURL string, customModel string) + // CallWithMessages 使用 system + user prompt 调用AI API + CallWithMessages(systemPrompt, userPrompt string) (string, error) + + setAuthHeader(reqHeaders http.Header) +} diff --git a/mcp/qwen_client.go b/mcp/qwen_client.go new file mode 100644 index 00000000..e56ed42d --- /dev/null +++ b/mcp/qwen_client.go @@ -0,0 +1,53 @@ +package mcp + +import ( + "log" + "net/http" +) + +const ( + ProviderQwen = "qwen" + DefaultQwenBaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1" + DefaultQwenModel = "qwen3-max" +) + +type QwenClient struct { + *Client +} + +func NewQwenClient() AIClient { + client := New().(*Client) + client.Provider = ProviderQwen + client.Model = DefaultQwenModel + client.BaseURL = DefaultQwenBaseURL + return &QwenClient{ + Client: client, + } +} + +func (qwenClient *QwenClient) SetAPIKey(apiKey string, customURL string, customModel string) { + if qwenClient.Client == nil { + qwenClient.Client = New().(*Client) + } + qwenClient.Client.APIKey = apiKey + + if len(apiKey) > 8 { + log.Printf("🔧 [MCP] Qwen API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:]) + } + if customURL != "" { + qwenClient.Client.BaseURL = customURL + log.Printf("🔧 [MCP] Qwen 使用自定义 BaseURL: %s", customURL) + } else { + log.Printf("🔧 [MCP] Qwen 使用默认 BaseURL: %s", qwenClient.Client.BaseURL) + } + if customModel != "" { + qwenClient.Client.Model = customModel + log.Printf("🔧 [MCP] Qwen 使用自定义 Model: %s", customModel) + } else { + log.Printf("🔧 [MCP] Qwen 使用默认 Model: %s", qwenClient.Client.Model) + } +} + +func (qwenClient *QwenClient) setAuthHeader(reqHeaders http.Header) { + qwenClient.Client.setAuthHeader(reqHeaders) +} diff --git a/trader/auto_trader.go b/trader/auto_trader.go index 7362e786..4e53a9b4 100644 --- a/trader/auto_trader.go +++ b/trader/auto_trader.go @@ -85,8 +85,8 @@ type AutoTrader struct { exchange string // 交易平台名称 config AutoTraderConfig trader Trader // 使用Trader接口(支持多平台) - mcpClient *mcp.Client - decisionLogger *logger.DecisionLogger // 决策日志记录器 + mcpClient mcp.AIClient + decisionLogger logger.IDecisionLogger // 决策日志记录器 initialBalance float64 dailyPnL float64 customPrompt string // 自定义交易策略prompt @@ -131,11 +131,12 @@ func NewAutoTrader(config AutoTraderConfig, database interface{}, userID string) // 初始化AI if config.AIModel == "custom" { // 使用自定义API - mcpClient.SetCustomAPI(config.CustomAPIURL, config.CustomAPIKey, config.CustomModelName) + mcpClient.SetAPIKey(config.CustomAPIKey, config.CustomAPIURL, config.CustomModelName) log.Printf("🤖 [%s] 使用自定义AI API: %s (模型: %s)", config.Name, config.CustomAPIURL, config.CustomModelName) } else if config.UseQwen || config.AIModel == "qwen" { // 使用Qwen (支持自定义URL和Model) - mcpClient.SetQwenAPIKey(config.QwenKey, config.CustomAPIURL, config.CustomModelName) + mcpClient = mcp.NewQwenClient() + mcpClient.SetAPIKey(config.QwenKey, config.CustomAPIURL, config.CustomModelName) if config.CustomAPIURL != "" || config.CustomModelName != "" { log.Printf("🤖 [%s] 使用阿里云Qwen AI (自定义URL: %s, 模型: %s)", config.Name, config.CustomAPIURL, config.CustomModelName) } else { @@ -143,7 +144,8 @@ func NewAutoTrader(config AutoTraderConfig, database interface{}, userID string) } } else { // 默认使用DeepSeek (支持自定义URL和Model) - mcpClient.SetDeepSeekAPIKey(config.DeepSeekKey, config.CustomAPIURL, config.CustomModelName) + mcpClient = mcp.NewDeepSeekClient() + mcpClient.SetAPIKey(config.DeepSeekKey, config.CustomAPIURL, config.CustomModelName) if config.CustomAPIURL != "" || config.CustomModelName != "" { log.Printf("🤖 [%s] 使用DeepSeek AI (自定义URL: %s, 模型: %s)", config.Name, config.CustomAPIURL, config.CustomModelName) } else { @@ -1205,7 +1207,7 @@ func (at *AutoTrader) GetSystemPromptTemplate() string { } // GetDecisionLogger 获取决策日志记录器 -func (at *AutoTrader) GetDecisionLogger() *logger.DecisionLogger { +func (at *AutoTrader) GetDecisionLogger() logger.IDecisionLogger { return at.decisionLogger } diff --git a/trader/auto_trader_test.go b/trader/auto_trader_test.go index 09a2c428..9316981f 100644 --- a/trader/auto_trader_test.go +++ b/trader/auto_trader_test.go @@ -31,7 +31,7 @@ type AutoTraderTestSuite struct { // Mock 依赖 mockTrader *MockTrader mockDB *MockDatabase - mockLogger *logger.DecisionLogger + mockLogger logger.IDecisionLogger // gomonkey patches patches *gomonkey.Patches