diff --git a/.gitignore b/.gitignore index 89880c2a..a487ae58 100644 --- a/.gitignore +++ b/.gitignore @@ -44,6 +44,7 @@ decision_logs/ nofx_test # Node.js +web/node_modules web/node_modules/ node_modules/ web/dist/ @@ -52,6 +53,9 @@ web/.vite/ # ESLint 临时报告文件(调试时生成,不纳入版本控制) eslint-*.json +# 本地 Agent QA seed(个人调试用,不纳入版本控制) +docs/qa/fixtures/agent_self_play_seed.zh-CN.json + # VS code .vscode diff --git a/agent/active_session.go b/agent/active_session.go new file mode 100644 index 00000000..8483392b --- /dev/null +++ b/agent/active_session.go @@ -0,0 +1,272 @@ +package agent + +import ( + "encoding/json" + "fmt" + "strings" + "time" +) + +// ActiveSkillSession is the minimal session for the central brain architecture. +// It replaces the old skillSession + ExecutionState combo for management skill flows. +type ActiveSkillSession struct { + SessionID string `json:"session_id"` + UserID int64 `json:"user_id"` + SkillName string `json:"skill_name"` + ActionName string `json:"action_name"` + LegacyPhase string `json:"legacy_phase,omitempty"` + Goal string `json:"goal,omitempty"` + PendingHint *PendingHint `json:"pending_hint,omitempty"` + CollectedFields map[string]any `json:"collected_fields,omitempty"` + LocalHistory []chatMessage `json:"local_history,omitempty"` + UpdatedAt string `json:"updated_at"` +} + +type PendingHint struct { + Prompt string `json:"prompt,omitempty"` + HintType string `json:"hint_type,omitempty"` +} + +type PendingProposalSession struct { + UserID int64 `json:"user_id"` + SourceUserText string `json:"source_user_text,omitempty"` + ProposalText string `json:"proposal_text,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` +} + +func activeSkillSessionKey(userID int64) string { + return fmt.Sprintf("agent_active_skill_session_%d", userID) +} + +func pendingProposalSessionKey(userID int64) string { + return fmt.Sprintf("agent_pending_proposal_session_%d", userID) +} + +func (a *Agent) getActiveSkillSession(userID int64) (ActiveSkillSession, bool) { + if a.store == nil { + return ActiveSkillSession{}, false + } + raw, err := a.store.GetSystemConfig(activeSkillSessionKey(userID)) + if err != nil || strings.TrimSpace(raw) == "" { + return ActiveSkillSession{}, false + } + var s ActiveSkillSession + if err := json.Unmarshal([]byte(raw), &s); err != nil { + return ActiveSkillSession{}, false + } + if s.SessionID == "" || s.SkillName == "" { + return ActiveSkillSession{}, false + } + s.PendingHint = normalizePendingHint(s.PendingHint) + return s, true +} + +func (a *Agent) saveActiveSkillSession(s ActiveSkillSession) { + if a.store == nil { + return + } + s.PendingHint = normalizePendingHint(s.PendingHint) + s.UpdatedAt = time.Now().UTC().Format(time.RFC3339) + data, _ := json.Marshal(s) + _ = a.store.SetSystemConfig(activeSkillSessionKey(s.UserID), string(data)) +} + +func (a *Agent) clearActiveSkillSession(userID int64) { + if a.store == nil { + return + } + _ = a.store.SetSystemConfig(activeSkillSessionKey(userID), "") +} + +func (a *Agent) getPendingProposalSession(userID int64) (PendingProposalSession, bool) { + if a.store == nil { + return PendingProposalSession{}, false + } + raw, err := a.store.GetSystemConfig(pendingProposalSessionKey(userID)) + if err != nil || strings.TrimSpace(raw) == "" { + return PendingProposalSession{}, false + } + var s PendingProposalSession + if err := json.Unmarshal([]byte(raw), &s); err != nil { + return PendingProposalSession{}, false + } + if s.UserID == 0 || strings.TrimSpace(s.ProposalText) == "" { + return PendingProposalSession{}, false + } + return s, true +} + +func (a *Agent) savePendingProposalSession(s PendingProposalSession) { + if a.store == nil { + return + } + s.UpdatedAt = time.Now().UTC().Format(time.RFC3339) + data, _ := json.Marshal(s) + _ = a.store.SetSystemConfig(pendingProposalSessionKey(s.UserID), string(data)) +} + +func (a *Agent) clearPendingProposalSession(userID int64) { + if a.store == nil { + return + } + _ = a.store.SetSystemConfig(pendingProposalSessionKey(userID), "") +} + +func newActiveSkillSession(userID int64, skill, action string) ActiveSkillSession { + return ActiveSkillSession{ + SessionID: fmt.Sprintf("as_%d", time.Now().UnixNano()), + UserID: userID, + SkillName: skill, + ActionName: action, + LegacyPhase: "collecting", + Goal: "", + PendingHint: nil, + CollectedFields: map[string]any{}, + UpdatedAt: time.Now().UTC().Format(time.RFC3339), + } +} + +func normalizePendingHint(hint *PendingHint) *PendingHint { + if hint == nil { + return nil + } + prompt := strings.TrimSpace(hint.Prompt) + if prompt == "" { + return nil + } + out := &PendingHint{ + Prompt: prompt, + HintType: strings.TrimSpace(hint.HintType), + } + return out +} + +func pendingHintFromAssistantReply(reply string) *PendingHint { + reply = strings.TrimSpace(reply) + if reply == "" { + return nil + } + hintType := "" + switch { + case strings.Contains(reply, "请选择") || strings.Contains(strings.ToLower(reply), "choose"): + hintType = "choice" + case strings.Contains(reply, "确认") || strings.Contains(strings.ToLower(reply), "confirm"): + hintType = "confirmation" + case strings.HasSuffix(reply, "?") || strings.HasSuffix(reply, "?"): + hintType = "question" + } + if hintType == "" { + return nil + } + return &PendingHint{Prompt: reply, HintType: hintType} +} + +func setActiveSessionPendingHint(session *ActiveSkillSession, reply string) { + if session == nil { + return + } + session.PendingHint = pendingHintFromAssistantReply(reply) +} + +func clearActiveSessionPendingHint(session *ActiveSkillSession) { + if session == nil { + return + } + session.PendingHint = nil +} + +func (a *Agent) currentPendingHintText(userID int64) string { + if active, ok := a.getActiveSkillSession(userID); ok && active.PendingHint != nil && strings.TrimSpace(active.PendingHint.Prompt) != "" { + return strings.TrimSpace(active.PendingHint.Prompt) + } + if state := a.getExecutionState(userID); state.Waiting != nil && strings.TrimSpace(state.Waiting.Question) != "" { + return strings.TrimSpace(state.Waiting.Question) + } + if proposal, ok := a.getPendingProposalSession(userID); ok && strings.TrimSpace(proposal.ProposalText) != "" { + return strings.TrimSpace(proposal.ProposalText) + } + return strings.TrimSpace(a.getLastAssistantReply(userID)) +} + +func activeSessionHasField(s ActiveSkillSession, slot string) bool { + slot = strings.TrimSpace(slot) + if slot == "" { + return false + } + if len(s.CollectedFields) == 0 { + return false + } + switch slot { + case "target_ref": + if value, ok := s.CollectedFields["bulk_scope"]; ok && strings.EqualFold(strings.TrimSpace(fmt.Sprint(value)), "all") { + return true + } + for _, key := range []string{"target_ref", "target_ref_id", "target_ref_name"} { + if value, ok := s.CollectedFields[key]; ok && strings.TrimSpace(fmt.Sprint(value)) != "" { + return true + } + } + return false + case "exchange": + value, ok := s.CollectedFields["exchange_id"] + return ok && strings.TrimSpace(fmt.Sprint(value)) != "" + case "model": + for _, key := range []string{"model_id", "ai_model_id"} { + if value, ok := s.CollectedFields[key]; ok && strings.TrimSpace(fmt.Sprint(value)) != "" { + return true + } + } + return false + case "strategy": + value, ok := s.CollectedFields["strategy_id"] + return ok && strings.TrimSpace(fmt.Sprint(value)) != "" + default: + value, ok := s.CollectedFields[slot] + return ok && strings.TrimSpace(fmt.Sprint(value)) != "" + } +} + +// missingRequiredFields returns required slots not yet collected, reading from skill registry. +func missingRequiredFields(s ActiveSkillSession) []string { + def, ok := getSkillDefinition(s.SkillName) + if !ok { + return nil + } + actionDef, ok := def.Actions[s.ActionName] + if !ok { + return nil + } + var missing []string + for _, slot := range actionDef.RequiredSlots { + if !activeSessionHasField(s, slot) { + missing = append(missing, slot) + } + } + return missing +} + +// fieldConstraintSummary returns a compact description of missing fields for the LLM prompt. +func fieldConstraintSummary(s ActiveSkillSession) string { + def, ok := getSkillDefinition(s.SkillName) + if !ok { + return "" + } + missing := missingRequiredFields(s) + if len(missing) == 0 { + return "" + } + lines := make([]string, 0, len(missing)) + for _, key := range missing { + constraint, ok := def.FieldConstraints[key] + if !ok { + lines = append(lines, fmt.Sprintf("- %s (required)", key)) + continue + } + desc := constraint.Description + if len(constraint.Values) > 0 { + desc += fmt.Sprintf(" [options: %s]", strings.Join(constraint.Values, ", ")) + } + lines = append(lines, fmt.Sprintf("- %s: %s", key, desc)) + } + return strings.Join(lines, "\n") +} diff --git a/agent/agent.go b/agent/agent.go index 8eb51e95..6f4223d0 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -11,16 +11,20 @@ import ( "fmt" "log/slog" "net/http" + "os" "sort" "strconv" "strings" "sync" "time" + gethcrypto "github.com/ethereum/go-ethereum/crypto" + "nofx/manager" "nofx/market" "nofx/mcp" "nofx/store" + "nofx/wallet" ) type Agent struct { @@ -35,23 +39,35 @@ type Agent struct { history *chatHistory pending *pendingTrades stopCh chan struct{} // signals background goroutines to stop - stopOnce sync.Once + setupStates sync.Map + flowLocks sync.Map NotifyFunc func(userID int64, text string) error } type Config struct { - Language string `json:"language"` - WatchSymbols []string `json:"watch_symbols"` - EnableBriefs bool `json:"enable_briefs"` - EnableNews bool `json:"enable_news"` - EnableSentinel bool `json:"enable_sentinel"` - BriefTimes []int `json:"brief_times"` + Language string `json:"language"` + WatchSymbols []string `json:"watch_symbols"` + EnableBriefs bool `json:"enable_briefs"` + EnableNews bool `json:"enable_news"` + EnableSentinel bool `json:"enable_sentinel"` + AllowTradeExecution bool `json:"allow_trade_execution"` + BriefTimes []int `json:"brief_times"` } +var ( + agentWalletAddressFromPrivateKey = walletAddressFromPrivateKey + agentQueryUSDCBalanceCached = wallet.QueryUSDCBalanceCached +) + func DefaultConfig() *Config { return &Config{ - Language: "zh", WatchSymbols: []string{"BTCUSDT", "ETHUSDT", "SOLUSDT"}, - EnableBriefs: true, EnableNews: true, EnableSentinel: true, BriefTimes: []int{8, 20}, + Language: "zh", + WatchSymbols: []string{"BTCUSDT", "ETHUSDT", "SOLUSDT"}, + EnableBriefs: true, + EnableNews: true, + EnableSentinel: true, + AllowTradeExecution: true, + BriefTimes: []int{8, 20}, } } @@ -59,7 +75,7 @@ func New(tm *manager.TraderManager, st *store.Store, cfg *Config, logger *slog.L if cfg == nil { cfg = DefaultConfig() } - return &Agent{traderManager: tm, store: st, config: cfg, logger: logger, history: newChatHistory(100), pending: newPendingTrades(), stopCh: make(chan struct{})} + return &Agent{traderManager: tm, store: st, config: cfg, logger: logger, history: newChatHistory(chatHistoryMaxTurns), pending: newPendingTrades(), stopCh: make(chan struct{})} } func (a *Agent) SetAIClient(c mcp.AIClient) { a.aiClient = c } @@ -77,6 +93,14 @@ func (a *Agent) log() *slog.Logger { return slog.Default() } +func (a *Agent) flowLock(userID int64) *sync.Mutex { + if a == nil { + return &sync.Mutex{} + } + lock, _ := a.flowLocks.LoadOrStore(userID, &sync.Mutex{}) + return lock.(*sync.Mutex) +} + func (a *Agent) EnsureAIClient() { a.ensureAIClientForStoreUser("default") } @@ -108,57 +132,182 @@ func (a *Agent) loadAIClientFromStoreUser(storeUserID string) (mcp.AIClient, str if storeUserID == "" { storeUserID = "default" } - - model, err := a.store.AIModel().GetDefault(storeUserID) - if err != nil || model == nil { - a.log().Warn("no enabled AI model found for store user", "store_user_id", storeUserID, "error", err) - return nil, "", false + candidateUserIDs := []string{storeUserID} + if storeUserID != "default" { + candidateUserIDs = append(candidateUserIDs, "default") } - - a.log().Info( - "agent selected AI model config", - "store_user_id", storeUserID, - "model_id", model.ID, - "provider", model.Provider, - "enabled", model.Enabled, - "has_api_key", len(model.APIKey) > 0, - "custom_api_url", strings.TrimSpace(model.CustomAPIURL), - "custom_model_name", strings.TrimSpace(model.CustomModelName), - ) - - apiKey := string(model.APIKey) - customAPIURL := strings.TrimSpace(model.CustomAPIURL) - modelName := strings.TrimSpace(model.CustomModelName) - provider := strings.ToLower(strings.TrimSpace(model.Provider)) - - // Use the provider registry for providers like claw402 that have their own - // client implementation (x402 payment, custom auth, etc.). - if client := mcp.NewAIClientByProvider(provider); client != nil { - if modelName == "" { - modelName = model.ID + for _, candidateUserID := range candidateUserIDs { + models, err := a.store.AIModel().List(candidateUserID) + if err != nil { + a.log().Warn("failed to list AI models for store user", "store_user_id", candidateUserID, "error", err) + continue + } + candidates := rankAgentModelCandidates(models) + for _, candidate := range candidates { + model := candidate.model + if model == nil || !model.Enabled || !agentModelHasUsableAPIKey(model) { + continue + } + + a.log().Info( + "agent evaluating AI model config", + "store_user_id", candidateUserID, + "model_id", model.ID, + "provider", model.Provider, + "enabled", model.Enabled, + "has_api_key", len(model.APIKey) > 0, + "custom_api_url", strings.TrimSpace(model.CustomAPIURL), + "custom_model_name", strings.TrimSpace(model.CustomModelName), + "prefer_model_with_balance", candidate.preferModelWithBalance, + "wallet_balance_usdc", candidate.balanceUSDC, + ) + + apiKey := strings.TrimSpace(string(model.APIKey)) + customAPIURL := strings.TrimSpace(model.CustomAPIURL) + modelName := strings.TrimSpace(model.CustomModelName) + customAPIURL, modelName = resolveModelRuntimeConfig(model.Provider, customAPIURL, modelName, model.ID) + if apiKey == "" || customAPIURL == "" { + a.log().Warn( + "skipping incomplete enabled AI model", + "store_user_id", candidateUserID, + "model_id", model.ID, + "provider", model.Provider, + "has_api_key", apiKey != "", + "has_custom_api_url", customAPIURL != "", + ) + continue + } + + httpClient := &http.Client{Timeout: 60 * time.Second} + client := mcp.NewClient(mcp.WithHTTPClient(httpClient)) + client.SetAPIKey(apiKey, customAPIURL, modelName) + a.log().Info("agent AI client selected", "store_user_id", candidateUserID, "model_id", model.ID, "model", modelName) + return client, modelName, true } - client.SetAPIKey(apiKey, customAPIURL, modelName) - return client, modelName, true } - customAPIURL, modelName = resolveModelRuntimeConfig(provider, customAPIURL, modelName, model.ID) - if apiKey == "" || customAPIURL == "" { - a.log().Warn( - "enabled AI model is incomplete", - "store_user_id", storeUserID, - "model_id", model.ID, - "provider", model.Provider, - "has_api_key", apiKey != "", - "has_custom_api_url", customAPIURL != "", - ) - return nil, "", false + a.log().Warn("no enabled AI model found for store user", "store_user_id", storeUserID) + return nil, "", false +} + +type agentModelCandidate struct { + model *store.AIModel + preferModelWithBalance bool + balanceUSDC float64 +} + +func rankAgentModelCandidates(models []*store.AIModel) []agentModelCandidate { + candidates := make([]agentModelCandidate, 0, len(models)) + for _, model := range models { + if model == nil { + continue + } + candidate := agentModelCandidate{model: model} + if balance, ok := agentModelUSDCBalance(model); ok && balance > 0 { + candidate.preferModelWithBalance = true + candidate.balanceUSDC = balance + } + candidates = append(candidates, candidate) } - httpClient := &http.Client{Timeout: 60 * time.Second} - client := mcp.NewClient(mcp.WithHTTPClient(httpClient)) - name := modelName - client.SetAPIKey(apiKey, customAPIURL, name) - return client, name, true + sort.SliceStable(candidates, func(i, j int) bool { + left := candidates[i] + right := candidates[j] + if left.preferModelWithBalance != right.preferModelWithBalance { + return left.preferModelWithBalance + } + if left.balanceUSDC != right.balanceUSDC { + return left.balanceUSDC > right.balanceUSDC + } + leftUpdatedAt := time.Time{} + rightUpdatedAt := time.Time{} + if left.model != nil { + leftUpdatedAt = left.model.UpdatedAt + } + if right.model != nil { + rightUpdatedAt = right.model.UpdatedAt + } + if !leftUpdatedAt.Equal(rightUpdatedAt) { + return leftUpdatedAt.After(rightUpdatedAt) + } + leftID := "" + rightID := "" + if left.model != nil { + leftID = left.model.ID + } + if right.model != nil { + rightID = right.model.ID + } + return leftID < rightID + }) + + return candidates +} + +func agentModelUSDCBalance(model *store.AIModel) (float64, bool) { + if model == nil || !agentProviderSupportsUSDCBalance(model.Provider) { + return 0, false + } + privateKey := strings.TrimSpace(string(model.APIKey)) + if privateKey == "" { + return 0, false + } + walletAddress, err := agentWalletAddressFromPrivateKey(privateKey) + if err != nil || strings.TrimSpace(walletAddress) == "" { + return 0, false + } + balance, err := agentQueryUSDCBalanceCached(walletAddress) + if err != nil || balance <= 0 { + return 0, false + } + return balance, true +} + +func agentProviderSupportsUSDCBalance(provider string) bool { + switch strings.ToLower(strings.TrimSpace(provider)) { + case "claw402", "blockrun-base": + return true + default: + return false + } +} + +func agentModelHasUsableAPIKey(model *store.AIModel) bool { + if model == nil { + return false + } + if strings.TrimSpace(string(model.APIKey)) != "" { + return true + } + envKeyByProvider := map[string]string{ + "deepseek": "DEEPSEEK_API_KEY", + "openai": "OPENAI_API_KEY", + "claude": "ANTHROPIC_API_KEY", + "gemini": "GEMINI_API_KEY", + "grok": "XAI_API_KEY", + "kimi": "MOONSHOT_API_KEY", + "minimax": "MINIMAX_API_KEY", + "qwen": "DASHSCOPE_API_KEY", + } + envKey := envKeyByProvider[strings.ToLower(strings.TrimSpace(model.Provider))] + return envKey != "" && strings.TrimSpace(os.Getenv(envKey)) != "" +} + +func walletAddressFromPrivateKey(privateKey string) (string, error) { + key := strings.TrimSpace(privateKey) + if !strings.HasPrefix(key, "0x") { + return "", fmt.Errorf("private key must start with 0x") + } + if len(key) != 66 { + return "", fmt.Errorf("private key must be 66 characters") + } + + privateKeyObj, err := gethcrypto.HexToECDSA(strings.TrimPrefix(key, "0x")) + if err != nil { + return "", err + } + + return gethcrypto.PubkeyToAddress(privateKeyObj.PublicKey).Hex(), nil } func resolveModelRuntimeConfig(provider, customAPIURL, customModelName, fallbackModelID string) (string, string) { @@ -180,6 +329,7 @@ func resolveModelRuntimeConfig(provider, customAPIURL, customModelName, fallback "grok": {url: "https://api.x.ai/v1", model: "grok-3-latest"}, "kimi": {url: "https://api.moonshot.ai/v1", model: "moonshot-v1-auto"}, "minimax": {url: "https://api.minimax.chat/v1", model: "MiniMax-M2.5"}, + "claw402": {url: "https://claw402.ai", model: "deepseek"}, } if customAPIURL == "" { @@ -221,7 +371,12 @@ func (a *Agent) Start() { func (a *Agent) Stop() { // Signal all background goroutines (e.g. chat-history-cleanup) to exit. - a.stopOnce.Do(func() { close(a.stopCh) }) + select { + case <-a.stopCh: + // Already closed + default: + close(a.stopCh) + } if a.sentinel != nil { a.sentinel.Stop() } @@ -263,9 +418,7 @@ func (a *Agent) handleMessageForStoreUser(ctx context.Context, storeUserID strin return a.handleStatus(lang), nil } if text == "/clear" { - a.history.Clear(userID) - a.clearTaskState(userID) - a.clearExecutionState(userID) + a.clearConversationState(userID) if lang == "zh" { return "🧹 对话记忆已清除。", nil } @@ -274,6 +427,9 @@ func (a *Agent) handleMessageForStoreUser(ctx context.Context, storeUserID strin if reply, handled := a.handleTradeConfirmation(ctx, userID, text, lang); handled { return reply, nil } + if reply, handled := a.handleModelWalletBalanceQuestion(storeUserID, lang, text); handled { + return reply, nil + } // Everything else goes through the planner and tool system. return a.thinkAndAct(ctx, storeUserID, userID, lang, text) @@ -309,9 +465,7 @@ func (a *Agent) handleMessageStreamForStoreUser(ctx context.Context, storeUserID return a.handleStatus(lang), nil } if text == "/clear" { - a.history.Clear(userID) - a.clearTaskState(userID) - a.clearExecutionState(userID) + a.clearConversationState(userID) if lang == "zh" { return "🧹 对话记忆已清除。", nil } @@ -319,13 +473,37 @@ func (a *Agent) handleMessageStreamForStoreUser(ctx context.Context, storeUserID } if reply, handled := a.handleTradeConfirmation(ctx, userID, text, lang); handled { if onEvent != nil { - onEvent(StreamEventDelta, reply) + emitStreamText(onEvent, reply) + } + return reply, nil + } + if reply, handled := a.handleModelWalletBalanceQuestion(storeUserID, lang, text); handled { + if onEvent != nil { + emitStreamText(onEvent, reply) } return reply, nil } return a.thinkAndActStream(ctx, storeUserID, userID, lang, text, onEvent) } +func (a *Agent) clearConversationState(userID int64) { + if a == nil { + return + } + if a.history != nil { + a.history.Clear(userID) + } + a.clearTaskState(userID) + a.clearSkillSession(userID) + a.clearActiveSkillSession(userID) + a.clearPendingProposalSession(userID) + a.clearWorkflowSession(userID) + a.clearExecutionState(userID) + a.clearReferenceMemory(userID) + a.SnapshotManager(userID).Clear() + a.clearSetupState(userID) +} + // StreamEvent types sent via SSE to the frontend. const ( StreamEventPlanning = "planning" @@ -341,8 +519,12 @@ const ( // buildSystemPrompt creates the system prompt that makes NOFXi behave like a real agent. func (a *Agent) buildSystemPrompt(lang string) string { + return a.buildSystemPromptForStoreUser(lang, "default") +} + +func (a *Agent) buildSystemPromptForStoreUser(lang, storeUserID string) string { // Gather live system state - traderInfo := a.getTradersSummary() + traderInfo := a.getTradersSummaryForStoreUser(storeUserID) watchlist := "" if a.sentinel != nil { watchlist = a.sentinel.FormatWatchlist(lang) @@ -382,19 +564,24 @@ func (a *Agent) buildSystemPrompt(lang string) string { ## 工具使用 你可以调用以下工具来执行操作: - **search_stock** — 搜索股票(支持中文名、英文名、代码)。当用户提到你不认识的股票时,先用这个工具搜索。 -- **execute_trade** — 下单交易(加密货币或美股)。美股:open_long=买入,close_long=卖出。调用后创建待确认订单,用户需回复"确认 trade_xxx"。 +- **execute_trade** — 下单交易(加密货币或美股)。常见写法:"做多 BTC 0.01 x10"、"做空 ETH 0.1"、"平多 BTC"、"平空 ETH";英文也支持 "long BTC 0.01 x10"、"short ETH 0.1"、"close long BTC"、"close short ETH"。美股:open_long=买入,close_long=卖出。调用后先创建待确认订单,不会立刻成交。若触发大额风控,用户必须回复"确认大额 trade_xxx";待确认订单 5 分钟后自动失效。 - **get_positions** — 查看当前所有持仓(加密货币 + 股票) - **get_balance** — 查看账户余额 - **get_market_price** — 获取实时价格(加密货币或股票代码) +- **get_kline** — 获取最近 K 线 / 蜡烛图数据(适合“看 15 分钟 K 线”“最近 50 根 1 小时 K 线”) - **get_exchange_configs / manage_exchange_config** — 查看、新增、修改、删除交易所绑定配置 - **get_model_configs / manage_model_config** — 查看、新增、修改、删除 AI 模型配置 - **get_strategies / manage_strategy** — 查看、新增、修改、删除、激活、复制策略模板 - **manage_trader** — 查看、新增、修改、删除、启动、停止交易员 +- **get_watchlist / manage_watchlist** — 查看、添加、移除运行时监控币对,适合“把 BTC 加入监控”“别再监控 SOL”这类请求 ### 配置、策略与交易员管理规则 - 当用户要求创建、修改、删除、激活、复制策略模板时,优先使用 get_strategies / manage_strategy - **策略模板本身是独立资源,不默认依赖交易所或 AI 模型** -- 只有当用户要求“运行策略 / 创建交易员 / 把策略部署到账户”时,才需要进一步关联交易所、模型或 trader +- **策略模板创建成功后应立即出现在策略列表/策略页** +- **策略模板不能直接启动或运行;只有交易员有运行态。** +- 如果用户说“启动策略 / 运行策略”,要明确说明:应先把策略绑定到交易员,再启动交易员 +- 用户没问运行/部署/创建交易员时,不要主动延伸到交易员、模型或交易所绑定 - 当用户要求配置交易所、绑定 API Key、修改交易所账户时,优先使用 manage_exchange_config - 当用户要求配置大模型、设置 API Key、切换模型、修改模型地址时,优先使用 manage_model_config - 当用户要求创建、修改、删除、启动、停止交易员时,优先使用 manage_trader @@ -406,9 +593,10 @@ func (a *Agent) buildSystemPrompt(lang string) string { ### 交易安全规则 - 用户明确要求交易时才调用 execute_trade +- 下单前先尊重风控:数量过大、仓位太小、杠杆过高、超过权益上限时,不要假装能下单,要直接用人话解释原因 - 分析和建议不需要调用工具,直接回复即可 - 交易确认信息要清晰展示:品种、方向、数量、杠杆 -- 提醒用户确认命令格式 +- 提醒用户确认命令格式;普通订单用“确认 trade_xxx”,大额订单用“确认大额 trade_xxx” ### 数据真实性规则(极其重要!) - **持仓信息必须且只能通过 get_positions 工具获取**,绝对禁止编造持仓 @@ -419,6 +607,10 @@ func (a *Agent) buildSystemPrompt(lang string) string { - 查股票行情 ≠ 用户持有该股票。不要混淆"查价格"和"有持仓" ## 行为准则 +- 把用户当交易小白,而不是开发者或量化工程师。 +- 先说结论,再说原因和下一步。 +- 语言要简单、清楚、直接,少用术语。 +- 如果必须用术语,立刻用大白话解释。 - 简洁、专业、有观点。不说废话。 - 用户问什么答什么,不要推销配置。 - 有实时数据时给具体价位,没有时给策略框架和思路。 @@ -461,10 +653,11 @@ func (a *Agent) buildSystemPrompt(lang string) string { ## Tools You can call these tools to take action: - **search_stock** — Search for stocks by name, ticker, or code. Covers A-share, HK, and US markets. Use when the user mentions an unknown stock. -- **execute_trade** — Place a trade order (crypto or US stocks). For stocks: open_long=buy, close_long=sell. Creates a pending order that requires user confirmation. +- **execute_trade** — Place a trade order (crypto or US stocks). Common phrasings include "long BTC 0.01 x10", "short ETH 0.1", "close long BTC", and "close short ETH". For stocks: open_long=buy, close_long=sell. This creates a pending trade first; it does not execute immediately. Large orders require "confirm large trade_xxx", and pending trades expire after 5 minutes. - **get_positions** — View all current open positions (crypto + stocks) - **get_balance** — View account balance and equity - **get_market_price** — Get real-time price from the exchange (crypto or stock symbol) +- **get_kline** — Get recent candlestick / kline data for a crypto symbol - **get_exchange_configs / manage_exchange_config** — View, create, update, and delete exchange bindings - **get_model_configs / manage_model_config** — View, create, update, and delete AI model bindings - **get_strategies / manage_strategy** — View, create, update, delete, activate, and duplicate strategy templates @@ -473,10 +666,14 @@ You can call these tools to take action: ### Configuration, Strategy, and Trader Rules - When the user wants to create, edit, delete, activate, or duplicate a strategy template, prefer get_strategies / manage_strategy - **A strategy template is an independent asset and does not require exchange or model bindings by default** -- Only ask for exchange/model/trader details when the user wants to run, deploy, or attach a strategy to a trader +- **After creation, a strategy template should immediately appear in the strategy list/page** +- **A strategy template cannot be started or run directly; only traders have runtime state** +- If the user says "start the strategy" or "run this strategy", explain that the strategy must be attached to a trader first, then the trader can be started +- Do not proactively bring up traders, models, or exchange bindings unless the user asks to run, deploy, or create a trader - When the user wants to bind or edit an exchange account, prefer manage_exchange_config - When the user wants to bind or edit an AI model, prefer manage_model_config - When the user wants to create, edit, delete, start, or stop a trader, prefer manage_trader +- When the user wants to add, remove, or inspect monitored coins, prefer get_watchlist / manage_watchlist - If required fields are missing, ask a focused follow-up question first, then call the tool - **Do not claim the system lacks these capabilities when the tools exist** - For secrets such as API keys, secrets, and private keys: store them, but never echo them back in full @@ -485,9 +682,10 @@ You can call these tools to take action: ### Trade Safety Rules - Only call execute_trade when user explicitly requests a trade +- Respect risk guardrails before placing a trade: if the quantity is too large, the notional is too small, leverage is too high, or the order exceeds equity limits, explain the reason plainly instead of pretending it can be placed - Analysis and advice don't need tools — just reply directly - Show trade details clearly: symbol, direction, quantity, leverage -- Remind user of the confirmation command format +- Remind user of the confirmation command format; normal orders use "confirm trade_xxx", large orders use "confirm large trade_xxx" ### Data Truthfulness Rules (CRITICAL!) - **Position data MUST come from get_positions tool only** — NEVER fabricate positions @@ -498,6 +696,10 @@ You can call these tools to take action: - Checking a stock price ≠ user owns that stock. Never confuse "quote lookup" with "holding" ## Behavior +- Treat the user like a trading beginner, not a developer. +- Lead with the conclusion first, then explain the reason and next step. +- Use plain language and keep jargon to a minimum. +- If you must use a technical term, explain it in simple words immediately. - Concise, professional, opinionated. No fluff. - Answer what's asked. Don't push setup. - With real-time data: give specific levels. Without: give strategy frameworks. @@ -508,7 +710,7 @@ Current time: %s`, traderInfo, watchlist, skillCatalog, time.Now().Format("2006- } // gatherContext collects real-time market data relevant to the user's message. -func (a *Agent) gatherContext(text string) string { +func (a *Agent) gatherContext(storeUserID, text string) string { var parts []string upper := strings.ToUpper(text) @@ -573,8 +775,16 @@ func (a *Agent) gatherContext(text string) string { } // Trader positions - if a.traderManager != nil { - for _, t := range a.traderManager.GetAllTraders() { + if a.traderManager != nil && a.store != nil { + traderConfigs, _ := a.store.Trader().List(storeUserID) + for _, traderCfg := range traderConfigs { + if strings.TrimSpace(traderCfg.ID) == "" { + continue + } + t, err := a.traderManager.GetTrader(traderCfg.ID) + if err != nil { + continue + } positions, err := t.GetPositions() if err != nil { continue @@ -594,27 +804,51 @@ func (a *Agent) gatherContext(text string) string { } func (a *Agent) getTradersSummary() string { + return a.getTradersSummaryForStoreUser("default") +} + +func (a *Agent) getTradersSummaryForStoreUser(storeUserID string) string { if a.traderManager == nil { return "Traders: none configured" } - traders := a.traderManager.GetAllTraders() - if len(traders) == 0 { + if a.store == nil { + return "Traders: none configured" + } + if strings.TrimSpace(storeUserID) == "" { + storeUserID = "default" + } + traderConfigs, err := a.store.Trader().List(storeUserID) + if err != nil || len(traderConfigs) == 0 { return "Traders: none configured" } var lines []string - for id, t := range traders { - s := t.GetStatus() - running, _ := s["is_running"].(bool) + for _, traderCfg := range traderConfigs { + if strings.TrimSpace(traderCfg.ID) == "" { + continue + } + t, err := a.traderManager.GetTrader(traderCfg.ID) + isRunning := traderCfg.IsRunning + exchange := traderCfg.ExchangeID + if err == nil && t != nil { + s := t.GetStatus() + if running, ok := s["is_running"].(bool); ok { + isRunning = running + } + exchange = t.GetExchange() + } status := "stopped" - if running { + if isRunning { status = "running" } - tid := id + tid := traderCfg.ID if len(tid) > 8 { tid = tid[:8] } - lines = append(lines, fmt.Sprintf("• %s [%s] %s | %s", t.GetName(), tid, status, t.GetExchange())) + lines = append(lines, fmt.Sprintf("• %s [%s] %s | %s", traderCfg.Name, tid, status, exchange)) + } + if len(lines) == 0 { + return "Traders: none configured" } return "Traders:\n" + strings.Join(lines, "\n") } @@ -642,7 +876,7 @@ func (a *Agent) handleStatus(L string) string { } // noAIFallback — when no AI is available, still try to be useful. -func (a *Agent) noAIFallback(lang, text string) (string, error) { +func (a *Agent) noAIFallback(storeUserID, lang, text string) (string, error) { upper := strings.ToUpper(text) // Try to provide market data directly @@ -657,16 +891,16 @@ func (a *Agent) noAIFallback(lang, text string) (string, error) { // Check if asking about positions/balance if strings.Contains(text, "持仓") || strings.Contains(upper, "POSITION") { - return a.queryPositionsDirect(lang) + return a.queryPositionsDirect(storeUserID, lang) } if strings.Contains(text, "余额") || strings.Contains(upper, "BALANCE") { - return a.queryBalancesDirect(lang) + return a.queryBalancesDirect(storeUserID, lang) } if lang == "zh" { - return "🤖 我是 NOFXi。配置 AI 模型后我就能理解你的任何问题——分析股票、制定策略、管理交易。\n\n现在可用:\n• 加密货币实时行情(试试「BTC」)\n• `/status` 系统状态\n\n发送 *开始配置* 配置 AI 模型。", nil + return "🤖 我是 NOFXi。配置 AI 模型后我就能理解你的任何问题——分析股票、制定策略、管理交易。\n\n现在可用:\n• 加密货币实时行情(试试「BTC」)\n• `/status` 查看系统状态\n• `/clear` 清空当前对话记忆\n\n发送 *开始配置* 配置 AI 模型。", nil } - return "🤖 I'm NOFXi. Configure an AI model and I can understand anything — analyze stocks, build strategies, manage trades.\n\nAvailable now:\n• Crypto real-time data (try 'BTC')\n• `/status` system status\n\nSend *setup* to configure AI.", nil + return "🤖 I'm NOFXi. Configure an AI model and I can understand anything — analyze stocks, build strategies, manage trades.\n\nAvailable now:\n• Crypto real-time data (try 'BTC')\n• `/status` to check system status\n• `/clear` to clear the current conversation memory\n\nSend *setup* to configure AI.", nil } func (a *Agent) aiServiceFailure(lang string, err error) (string, error) { @@ -676,19 +910,89 @@ func (a *Agent) aiServiceFailure(lang string, err error) (string, error) { } a.logger.Error("AI service call failed", "error", reason) if lang == "zh" { - return fmt.Sprintf("当前 AI 服务调用失败:%s\n\n这不是“未配置模型”。更可能是模型服务余额不足、接口报错或超时。请检查当前启用模型的 API 状态后再试。", reason), nil + return fmt.Sprintf("当前 AI 服务调用失败:%s\n\n%s", reason, aiServiceFailureGuidance("zh", reason)), nil } - return fmt.Sprintf("The AI service call failed: %s\n\nThis is not a missing-model issue. The active model provider likely returned an error, timed out, or has insufficient balance. Please check the active model API and try again.", reason), nil + return fmt.Sprintf("The AI service call failed: %s\n\n%s", reason, aiServiceFailureGuidance(lang, reason)), nil } -func (a *Agent) queryPositionsDirect(L string) (string, error) { +func aiServiceFailureGuidance(lang, reason string) string { + lower := strings.ToLower(strings.TrimSpace(reason)) + looksLikeHTMLGateway := strings.Contains(lower, "invalid character '<'") || + strings.Contains(lower, "unexpected character '<'") || + strings.Contains(lower, " 8 { tid = tid[:8] } @@ -717,18 +1021,32 @@ func (a *Agent) queryPositionsDirect(L string) (string, error) { return sb.String(), nil } -func (a *Agent) queryBalancesDirect(L string) (string, error) { +func (a *Agent) queryBalancesDirect(storeUserID, L string) (string, error) { if a.traderManager == nil { return a.msg(L, "no_traders"), nil } + if a.store == nil { + return a.msg(L, "no_traders"), nil + } + traderConfigs, err := a.store.Trader().List(storeUserID) + if err != nil { + return a.msg(L, "no_traders"), nil + } var sb strings.Builder sb.WriteString("💰 *Balance*\n\n") - for id, t := range a.traderManager.GetAllTraders() { + for _, traderCfg := range traderConfigs { + if strings.TrimSpace(traderCfg.ID) == "" { + continue + } + t, err := a.traderManager.GetTrader(traderCfg.ID) + if err != nil { + continue + } info, err := t.GetAccountInfo() if err != nil { continue } - tid := id + tid := traderCfg.ID if len(tid) > 8 { tid = tid[:8] } diff --git a/agent/agent_model_selection_test.go b/agent/agent_model_selection_test.go new file mode 100644 index 00000000..dd908c4e --- /dev/null +++ b/agent/agent_model_selection_test.go @@ -0,0 +1,53 @@ +package agent + +import ( + "log/slog" + "path/filepath" + "testing" + + "nofx/store" +) + +func TestLoadAIClientFromStoreUserPrefersModelWithBalance(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "agent-model-selection.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + + if err := st.AIModel().UpdateWithName("default", "default_openai", "OpenAI", true, "sk-test", "", "gpt-5.2"); err != nil { + t.Fatalf("create openai model: %v", err) + } + if err := st.AIModel().UpdateWithName("default", "wallet_claw402", "Claw402", true, "0x205d759b80bae1afa31a36c4afaeec0b10378c1c55e3363bcde5a1db75c747ca", "", "glm-5"); err != nil { + t.Fatalf("create claw402 model: %v", err) + } + + restoreWalletAddress := agentWalletAddressFromPrivateKey + restoreBalanceQuery := agentQueryUSDCBalanceCached + t.Cleanup(func() { + agentWalletAddressFromPrivateKey = restoreWalletAddress + agentQueryUSDCBalanceCached = restoreBalanceQuery + }) + + agentWalletAddressFromPrivateKey = func(privateKey string) (string, error) { + if privateKey == "0x205d759b80bae1afa31a36c4afaeec0b10378c1c55e3363bcde5a1db75c747ca" { + return "0xabc", nil + } + return "", nil + } + agentQueryUSDCBalanceCached = func(address string) (float64, error) { + if address == "0xabc" { + return 12.5, nil + } + return 0, nil + } + + a := New(nil, st, DefaultConfig(), slog.Default()) + _, modelName, ok := a.loadAIClientFromStoreUser("default") + if !ok { + t.Fatalf("expected model selection to succeed") + } + if modelName != "glm-5" { + t.Fatalf("expected model with wallet balance to be selected, got %q", modelName) + } +} diff --git a/agent/ai_service_failure_test.go b/agent/ai_service_failure_test.go new file mode 100644 index 00000000..b63b6363 --- /dev/null +++ b/agent/ai_service_failure_test.go @@ -0,0 +1,128 @@ +package agent + +import ( + "errors" + "log/slog" + "strings" + "testing" +) + +func TestAIServiceFailureHighlightsHTMLGatewayResponse(t *testing.T) { + a := New(nil, nil, DefaultConfig(), slog.Default()) + + msg, err := a.aiServiceFailure("zh", errors.New("fail to parse AI server response: failed to parse response: invalid character '<' looking for beginning of value")) + if err != nil { + t.Fatalf("aiServiceFailure returned error: %v", err) + } + + for _, want := range []string{ + "当前 AI 服务调用失败", + "上游返回了 HTML 页面或网关/反代错误页", + "custom_api_url", + "不是“未配置模型”", + } { + if !strings.Contains(msg, want) { + t.Fatalf("expected message to contain %q, got: %s", want, msg) + } + } + if strings.Contains(msg, "更可能是模型服务余额不足、接口报错或超时") { + t.Fatalf("html parse error should not use the generic balance/timeout-only guidance: %s", msg) + } +} + +func TestAIServiceFailureHighlightsUpstreamEmptyOutputRateLimit(t *testing.T) { + a := New(nil, nil, DefaultConfig(), slog.Default()) + + msg, err := a.aiServiceFailure("zh", errors.New(`API returned error (status 429): {"error":{"code":"upstream_empty_output","message":"Upstream model returned empty output.","param":null,"type":"rate_limit_error"}}`)) + if err != nil { + t.Fatalf("aiServiceFailure returned error: %v", err) + } + + for _, want := range []string{ + "当前 AI 服务调用失败", + "上游模型没有返回有效内容", + "不应优先归因成“余额不足”", + "切换到另一个可用模型", + } { + if !strings.Contains(msg, want) { + t.Fatalf("expected message to contain %q, got: %s", want, msg) + } + } + if strings.Contains(msg, "更可能是模型服务余额不足、接口报错、鉴权失败或超时") { + t.Fatalf("upstream empty output should not use the generic balance/auth/timeout guidance: %s", msg) + } +} + +func TestAIServiceFailureHighlightsBannedAccountAuthFailure(t *testing.T) { + a := New(nil, nil, DefaultConfig(), slog.Default()) + + msg, err := a.aiServiceFailure("zh", errors.New(`API returned error (status 401): {"error":{"code":"authentication_failed","message":"login failed: USER_IS_BANNED","param":null,"type":"authentication_error"}}`)) + if err != nil { + t.Fatalf("aiServiceFailure returned error: %v", err) + } + + for _, want := range []string{ + "当前 AI 服务调用失败", + "账号被禁用/封禁", + "USER_IS_BANNED", + "换一个可用账号/API Key", + "切换到另一个已启用模型", + } { + if !strings.Contains(msg, want) { + t.Fatalf("expected message to contain %q, got: %s", want, msg) + } + } + for _, unexpected := range []string{"余额不足", "超时"} { + if strings.Contains(msg, unexpected) { + t.Fatalf("banned account auth failure should not mention %q: %s", unexpected, msg) + } + } +} + +func TestCompletedPlanFallbackDoesNotExposeFinalSummaryFailure(t *testing.T) { + msg := formatCompletedPlanFallback("zh", []PlanStep{ + { + Type: planStepTypeTool, + Status: planStepStatusCompleted, + Title: "创建名为 eeg 的策略", + }, + }) + if msg == "" { + t.Fatalf("expected fallback message") + } + for _, bad := range []string{"失败", "AI", "稍后"} { + if strings.Contains(msg, bad) { + t.Fatalf("fallback should not expose final summary failure %q: %s", bad, msg) + } + } + if !strings.Contains(msg, "已完成") || !strings.Contains(msg, "创建名为 eeg 的策略") { + t.Fatalf("fallback should summarize completed work, got: %s", msg) + } +} + +func TestDeterministicCompletedPlanResponseSkipsLLMForSimpleConfirmation(t *testing.T) { + state := ExecutionState{ + Steps: []PlanStep{ + { + ID: "create_strategy", + Type: planStepTypeTool, + Status: planStepStatusCompleted, + Title: "创建名为 eeg 的策略", + }, + { + ID: "respond", + Type: planStepTypeRespond, + Status: planStepStatusRunning, + Title: "策略创建成功", + Instruction: "确认策略创建成功", + }, + }, + } + msg := deterministicCompletedPlanResponse("zh", state, state.Steps[1]) + if msg == "" { + t.Fatalf("expected deterministic response") + } + if !strings.Contains(msg, "已完成") || !strings.Contains(msg, "创建名为 eeg 的策略") { + t.Fatalf("unexpected deterministic response: %s", msg) + } +} diff --git a/agent/atomic_skill_executor.go b/agent/atomic_skill_executor.go new file mode 100644 index 00000000..7fa2cfb5 --- /dev/null +++ b/agent/atomic_skill_executor.go @@ -0,0 +1,87 @@ +package agent + +import "strings" + +func (a *Agent) executeAtomicSkillTask(storeUserID string, userID int64, lang, text, skill, action string, onEvent func(event, data string)) (string, bool) { + return a.executeAtomicSkillTaskWithSession(storeUserID, userID, lang, text, skillSession{Name: strings.TrimSpace(skill), Action: normalizeAtomicSkillAction(strings.TrimSpace(skill), action), Phase: "collecting"}, onEvent) +} + +func (a *Agent) executeAtomicSkillTaskWithSession(storeUserID string, userID int64, lang, text string, session skillSession, onEvent func(event, data string)) (string, bool) { + skill := strings.TrimSpace(session.Name) + action := normalizeAtomicSkillAction(skill, session.Action) + session.Name = skill + session.Action = action + if strings.TrimSpace(session.Phase) == "" { + session.Phase = "collecting" + } + skill = strings.TrimSpace(skill) + action = normalizeAtomicSkillAction(skill, action) + + var ( + answer string + handled bool + ) + + switch skill { + case "trader_management": + if action == "create" { + answer, handled = a.handleCreateTraderSkill(storeUserID, userID, lang, text, session) + } else { + answer, handled = a.handleTraderManagementSkill(storeUserID, userID, lang, text, session) + if handled && action == "query_running" { + answer = applyTraderQueryFilter(lang, answer, a.toolListTraders(storeUserID), "running_only") + } + } + case "exchange_management": + answer, handled = a.handleExchangeManagementSkill(storeUserID, userID, lang, text, session) + case "model_management": + answer, handled = a.handleModelManagementSkill(storeUserID, userID, lang, text, session) + case "strategy_management": + answer, handled = a.handleStrategyManagementSkill(storeUserID, userID, lang, text, session) + case "model_diagnosis": + answer, handled = a.handleModelDiagnosisSkill(storeUserID, lang, text), true + case "exchange_diagnosis": + answer, handled = a.handleExchangeDiagnosisSkill(storeUserID, lang, text), true + case "trader_diagnosis": + answer, handled = a.handleTraderDiagnosisSkill(storeUserID, lang, text), true + case "strategy_diagnosis": + answer, handled = a.handleStrategyDiagnosisSkill(storeUserID, lang, text), true + default: + return "", false + } + + if handled && onEvent != nil { + label := "atomic_skill:" + skill + if action != "" { + label += ":" + action + } + onEvent(StreamEventTool, label) + emitStreamText(onEvent, answer) + } + return answer, handled +} + +func (a *Agent) executeAtomicSkillTaskOutcome(storeUserID string, userID int64, lang, text, skill, action string, onEvent func(event, data string)) (skillOutcome, bool) { + return a.executeAtomicSkillTaskOutcomeWithSession(storeUserID, userID, lang, text, skillSession{Name: strings.TrimSpace(skill), Action: normalizeAtomicSkillAction(strings.TrimSpace(skill), action), Phase: "collecting"}, onEvent) +} + +func (a *Agent) executeAtomicSkillTaskOutcomeWithSession(storeUserID string, userID int64, lang, text string, session skillSession, onEvent func(event, data string)) (skillOutcome, bool) { + answer, handled := a.executeAtomicSkillTaskWithSession(storeUserID, userID, lang, text, session, onEvent) + if !handled { + return skillOutcome{}, false + } + skill := strings.TrimSpace(session.Name) + action := normalizeAtomicSkillAction(skill, session.Action) + switch skill { + case "model_diagnosis", "exchange_diagnosis", "trader_diagnosis", "strategy_diagnosis": + return skillOutcome{ + Skill: skill, + Action: defaultIfEmpty(action, "diagnose"), + Status: skillOutcomeSuccess, + GoalAchieved: true, + UserMessage: answer, + }, true + default: + return inferSkillOutcome(skill, action, answer, a.getSkillSession(userID), skillDataForAction(storeUserID, skill, action, a)), true + } +} diff --git a/agent/backend_logs_test.go b/agent/backend_logs_test.go deleted file mode 100644 index 16f37a64..00000000 --- a/agent/backend_logs_test.go +++ /dev/null @@ -1,127 +0,0 @@ -package agent - -import ( - "encoding/json" - "os" - "path/filepath" - "strings" - "testing" - - "nofx/store" -) - -func TestReadBackendLogEntriesReturnsRecentErrorLines(t *testing.T) { - wd, err := os.Getwd() - if err != nil { - t.Fatalf("Getwd() error = %v", err) - } - tmp := t.TempDir() - if err := os.Chdir(tmp); err != nil { - t.Fatalf("Chdir(tmp) error = %v", err) - } - t.Cleanup(func() { - _ = os.Chdir(wd) - }) - - if err := os.MkdirAll("data", 0o755); err != nil { - t.Fatalf("MkdirAll(data) error = %v", err) - } - logPath := filepath.Join("data", "nofx_2099-01-01.log") - content := strings.Join([]string{ - "04-19 13:00:00 [INFO] api/server.go:590 API server starting", - "04-19 13:00:01 [ERRO] api/server.go:600 invalid signature for okx account", - "04-19 13:00:02 [ERRO] agent/tools.go:123 model update failed: missing api key", - }, "\n") + "\n" - if err := os.WriteFile(logPath, []byte(content), 0o644); err != nil { - t.Fatalf("WriteFile() error = %v", err) - } - - path, entries, err := readBackendLogEntries(10, "model", true) - if err != nil { - t.Fatalf("readBackendLogEntries() error = %v", err) - } - if !strings.Contains(path, "nofx_2099-01-01.log") { - t.Fatalf("unexpected log path: %s", path) - } - if len(entries) != 1 || !strings.Contains(entries[0], "missing api key") { - t.Fatalf("unexpected filtered entries: %#v", entries) - } -} - -func TestToolGetBackendLogsRequiresOwnedTrader(t *testing.T) { - wd, err := os.Getwd() - if err != nil { - t.Fatalf("Getwd() error = %v", err) - } - tmp := t.TempDir() - if err := os.Chdir(tmp); err != nil { - t.Fatalf("Chdir(tmp) error = %v", err) - } - t.Cleanup(func() { - _ = os.Chdir(wd) - }) - - if err := os.MkdirAll("data", 0o755); err != nil { - t.Fatalf("MkdirAll(data) error = %v", err) - } - logPath := filepath.Join("data", "nofx_2099-01-01.log") - content := strings.Join([]string{ - "04-19 13:00:00 [INFO] api/server.go:590 API server starting", - "04-19 13:00:01 [ERRO] trader/runtime.go:88 trader_id=trader-owned strategy execution failed", - "04-19 13:00:02 [ERRO] trader/runtime.go:89 trader_id=trader-other strategy execution failed", - }, "\n") + "\n" - if err := os.WriteFile(logPath, []byte(content), 0o644); err != nil { - t.Fatalf("WriteFile() error = %v", err) - } - - a := newTestAgentWithStore(t) - if err := a.store.Trader().Create(&store.Trader{ - ID: "trader-owned", - UserID: "user-1", - Name: "Owned Trader", - AIModelID: "model-1", - ExchangeID: "exchange-1", - StrategyID: "strategy-1", - InitialBalance: 1000, - }); err != nil { - t.Fatalf("create owned trader: %v", err) - } - if err := a.store.Trader().Create(&store.Trader{ - ID: "trader-other", - UserID: "user-2", - Name: "Other Trader", - AIModelID: "model-2", - ExchangeID: "exchange-2", - StrategyID: "strategy-2", - InitialBalance: 1000, - }); err != nil { - t.Fatalf("create other trader: %v", err) - } - - resp := a.toolGetBackendLogs("user-1", `{"trader_id":"trader-owned","limit":5}`) - var okResult struct { - TraderID string `json:"trader_id"` - Entries []string `json:"entries"` - Count int `json:"count"` - } - if err := json.Unmarshal([]byte(resp), &okResult); err != nil { - t.Fatalf("unmarshal owned response: %v\nraw=%s", err, resp) - } - if okResult.TraderID != "trader-owned" || okResult.Count != 1 { - t.Fatalf("unexpected owned response: %+v", okResult) - } - if len(okResult.Entries) != 1 || !strings.Contains(okResult.Entries[0], "trader-owned") { - t.Fatalf("unexpected owned entries: %#v", okResult.Entries) - } - - resp = a.toolGetBackendLogs("user-1", `{"trader_id":"trader-other","limit":5}`) - var denied struct { - Error string `json:"error"` - } - if err := json.Unmarshal([]byte(resp), &denied); err != nil { - t.Fatalf("unmarshal denied response: %v\nraw=%s", err, resp) - } - if denied.Error != "trader not found for current user" { - t.Fatalf("unexpected denied response: %+v", denied) - } -} diff --git a/agent/brain.go b/agent/brain.go index 03b16d37..dfb146bf 100644 --- a/agent/brain.go +++ b/agent/brain.go @@ -30,7 +30,11 @@ func NewBrain(agent *Agent, logger *slog.Logger) *Brain { } } -func (b *Brain) Stop() { b.stopOnce.Do(func() { close(b.stopCh) }) } +func (b *Brain) Stop() { + b.stopOnce.Do(func() { + close(b.stopCh) + }) +} // cleanStaleSignals removes debounce entries older than 30 minutes. func (b *Brain) cleanStaleSignals() { @@ -54,22 +58,26 @@ func (b *Brain) HandleSignal(sig Signal) { emoji := map[string]string{"info": "ℹ️", "warning": "⚠️", "critical": "🚨"} e := emoji[sig.Severity] - if e == "" { e = "📊" } + if e == "" { + e = "📊" + } b.agent.notifyAll(fmt.Sprintf("%s *%s*\n\n%s", e, sig.Title, sig.Detail)) } func (b *Brain) StartNewsScan(interval time.Duration) { seen := make(map[string]bool) + seenOrder := make([]string, 0, 1024) safe.GoNamed("brain-news-scan", func() { ticker := time.NewTicker(interval) defer ticker.Stop() cleanTick := 0 for { select { - case <-b.stopCh: return + case <-b.stopCh: + return case <-ticker.C: - b.scanNews(seen) + b.scanNews(seen, &seenOrder) cleanTick++ if cleanTick%6 == 0 { // every ~30 min b.cleanStaleSignals() @@ -79,16 +87,20 @@ func (b *Brain) StartNewsScan(interval time.Duration) { }) } -func (b *Brain) scanNews(seen map[string]bool) { +func (b *Brain) scanNews(seen map[string]bool, seenOrder *[]string) { resp, err := b.http.Get("https://min-api.cryptocompare.com/data/v2/news/?lang=EN&sortOrder=latest") - if err != nil { return } + if err != nil { + return + } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { b.logger.Debug("news API non-200", "status", resp.StatusCode) return } body, err := safe.ReadAllLimited(resp.Body, 1024*1024) // 1MB limit - if err != nil { return } + if err != nil { + return + } var result struct { Data []struct { @@ -100,39 +112,65 @@ func (b *Brain) scanNews(seen map[string]bool) { PublishedOn int64 `json:"published_on"` } `json:"Data"` } - if err := json.Unmarshal(body, &result); err != nil { return } + if err := json.Unmarshal(body, &result); err != nil { + return + } bullish := []string{"surge", "rally", "bullish", "breakout", "ath", "pump", "adoption"} bearish := []string{"crash", "dump", "bearish", "sell-off", "plunge", "hack", "ban", "fraud"} for _, d := range result.Data { - if seen[d.URL] { continue } + if seen[d.URL] { + continue + } seen[d.URL] = true - if time.Since(time.Unix(d.PublishedOn, 0)) > 10*time.Minute { continue } + *seenOrder = append(*seenOrder, d.URL) + if time.Since(time.Unix(d.PublishedOn, 0)) > 10*time.Minute { + continue + } lower := strings.ToLower(d.Title + " " + d.Body) bc, brc := 0, 0 - for _, w := range bullish { if strings.Contains(lower, w) { bc++ } } - for _, w := range bearish { if strings.Contains(lower, w) { brc++ } } + for _, w := range bullish { + if strings.Contains(lower, w) { + bc++ + } + } + for _, w := range bearish { + if strings.Contains(lower, w) { + brc++ + } + } - if bc == 0 && brc == 0 { continue } + if bc == 0 && brc == 0 { + continue + } emoji := "📰" sentiment := "NEUTRAL" - if bc > brc { emoji = "🟢"; sentiment = "BULLISH" } - if brc > bc { emoji = "🔴"; sentiment = "BEARISH" } + if bc > brc { + emoji = "🟢" + sentiment = "BULLISH" + } + if brc > bc { + emoji = "🔴" + sentiment = "BEARISH" + } b.agent.notifyAll(fmt.Sprintf("%s *News*\n\n%s\n\n• Source: %s\n• Sentiment: %s", emoji, d.Title, d.Source, sentiment)) } - // Evict ~half when seen map gets large (keep recent half to avoid re-notifying) + // Evict the oldest half when seen grows large so recent URLs stay deduped deterministically. if len(seen) > 1000 { - i, half := 0, len(seen)/2 - for k := range seen { - if i >= half { break } - delete(seen, k) - i++ + half := len(seen) / 2 + for i := 0; i < half && i < len(*seenOrder); i++ { + delete(seen, (*seenOrder)[i]) + } + if half < len(*seenOrder) { + *seenOrder = append((*seenOrder)[:0], (*seenOrder)[half:]...) + } else { + *seenOrder = (*seenOrder)[:0] } } } @@ -144,7 +182,8 @@ func (b *Brain) StartMarketBriefs(hours []int) { sent := make(map[string]bool) for { select { - case <-b.stopCh: return + case <-b.stopCh: + return case now := <-ticker.C: key := now.Format("2006-01-02-15") for _, h := range hours { @@ -160,21 +199,35 @@ func (b *Brain) StartMarketBriefs(hours []int) { func (b *Brain) sendBrief(hour int) { title := "☀️ *早间市场简报*" - if hour >= 18 { title = "🌙 *晚间市场简报*" } + if hour >= 18 { + title = "🌙 *晚间市场简报*" + } // Fetch BTC/ETH prices for the brief var btcPrice, ethPrice, btcChg, ethChg string for _, sym := range []string{"BTCUSDT", "ETHUSDT"} { resp, err := b.http.Get(fmt.Sprintf("https://fapi.binance.com/fapi/v1/ticker/24hr?symbol=%s", sym)) - if err != nil { continue } + if err != nil { + continue + } body, readErr := safe.ReadAllLimited(resp.Body, 64*1024) // 64KB limit statusOK := resp.StatusCode == http.StatusOK resp.Body.Close() - if readErr != nil || !statusOK { continue } + if readErr != nil || !statusOK { + continue + } var t map[string]string - if err := json.Unmarshal(body, &t); err != nil { continue } - if sym == "BTCUSDT" { btcPrice = t["lastPrice"]; btcChg = t["priceChangePercent"] } - if sym == "ETHUSDT" { ethPrice = t["lastPrice"]; ethChg = t["priceChangePercent"] } + if err := json.Unmarshal(body, &t); err != nil { + continue + } + if sym == "BTCUSDT" { + btcPrice = t["lastPrice"] + btcChg = t["priceChangePercent"] + } + if sym == "ETHUSDT" { + ethPrice = t["lastPrice"] + ethChg = t["priceChangePercent"] + } } brief := fmt.Sprintf("%s\n\n• BTC: $%s (%s%%)\n• ETH: $%s (%s%%)\n\n_%s_", diff --git a/agent/central_brain.go b/agent/central_brain.go new file mode 100644 index 00000000..52c8d2d6 --- /dev/null +++ b/agent/central_brain.go @@ -0,0 +1,1489 @@ +package agent + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "nofx/mcp" + "nofx/store" +) + +// brainDecision is the routing contract between the first-pass LLM and the executor. +type brainDecision struct { + ThoughtProcess string `json:"thought_process"` + ActionType string `json:"action_type"` // CONTINUE_TASK | NEW_TASK | EXPLAIN_KNOWLEDGE | CANCEL_TASK + TargetSkill string `json:"target_skill,omitempty"` // "skill_name:action" for NEW_TASK + ExtractedData map[string]any `json:"extracted_data,omitempty"` + ReplyToUser string `json:"reply_to_user"` +} + +// activeSessionStepDecision is the per-turn control loop inside one active skill task. +type activeSessionStepDecision struct { + Route string `json:"route"` // ask_user | execute_skill | finish_task | cancel_task + Reply string `json:"reply,omitempty"` + ExtractedData map[string]any `json:"extracted_data,omitempty"` +} + +// tryMinimalBrain is the single entry point replacing tryUnifiedSemanticGateway. +// Intelligence layer: one routing LLM call → active-session loop → legacy skill execution. +func (a *Agent) tryMinimalBrain(ctx context.Context, storeUserID string, userID int64, lang, text string, onEvent func(event, data string)) (string, bool, error) { + if a.aiClient == nil { + return "", false, nil + } + + activeSession, hasActive := a.getActiveSkillSession(userID) + recentHistory := a.buildRecentConversationContext(userID, text) + currentRefs := buildCurrentReferenceSummary(lang, a.semanticCurrentReferences(userID)) + previousAssistantReply := a.currentPendingHintText(userID) + + systemPrompt := buildBrainSystemPrompt(lang) + userPrompt := buildBrainUserPrompt(lang, text, previousAssistantReply, recentHistory, currentRefs, activeSession, hasActive) + + stageCtx, cancel := withPlannerStageTimeout(ctx, directReplyTimeout) + defer cancel() + + raw, err := a.aiClient.CallWithRequest(&mcp.Request{ + Messages: []mcp.Message{ + mcp.NewSystemMessage(systemPrompt), + mcp.NewUserMessage(userPrompt), + }, + Ctx: stageCtx, + }) + if err != nil { + return "", false, nil + } + + decision, ok := parseBrainDecision(raw) + if !ok { + return "", false, nil + } + + return a.executeBrainDecision(ctx, storeUserID, userID, lang, text, decision, activeSession, hasActive, onEvent) +} + +func buildBrainSystemPrompt(lang string) string { + return prependNOFXiAdvisorPreamble(`You are the central brain of NOFXi. Read the intelligence report and output ONE JSON decision. No markdown, no extra text. + +Available action_type values: +- "CONTINUE_TASK": user is continuing the current active task +- "NEW_TASK": user is starting a new task +- "EXPLAIN_KNOWLEDGE": user is asking a knowledge question only +- "CANCEL_TASK": user wants to stop the current task + +Available skills (for NEW_TASK target_skill): +trader_management, exchange_management, model_management, strategy_management, +trader_diagnosis, exchange_diagnosis, model_diagnosis, strategy_diagnosis + +Available actions: +create, update, update_name, update_bindings, configure_strategy, configure_exchange, configure_model, +update_status, update_endpoint, update_config, update_prompt, delete, start, stop, activate, duplicate, +query_list, query_detail, query_running + +Rules: +- Prefer CONTINUE_TASK when there is an active task and the user is still talking about it. +- If the current user message is only a greeting, thanks, acknowledgement, or lightweight social chat like "你好", "hi", "hello", "thanks", "谢谢", "收到", do NOT continue the task. +- For those lightweight social messages, choose EXPLAIN_KNOWLEDGE and reply naturally, or let the task stay suspended. +- Use NEW_TASK only when there is no active task, or the user clearly switches goals/domains. +- Use EXPLAIN_KNOWLEDGE for concept/range/help questions; do not change state. When answering, use ONLY the options/values listed in the active session's missing_required_fields. Never invent field values or provider names. +- For diagnosis, create, update, delete, start, stop, activation, duplication, or historical-performance analysis tasks, never reply only with a future promise such as "I'll do it now", "please wait", "diagnosis is running", or "I'll tell you later". If the next step is execution, choose the corresponding skill/planned execution. If execution is impossible, say exactly what information or data is missing. +- Use CANCEL_TASK for "cancel", "stop", "forget it", "never mind", "算了", "取消". +- Domain guard: if the user says "模型", "AI 模型", or "model" and asks to create or configure one, you must route to model_management, not exchange_management. +- Domain guard: for model_management, the field "provider" means the AI model vendor such as OpenAI, DeepSeek, Claude, Gemini, Qwen, Kimi, Grok, Minimax, claw402, blockrun-base, or blockrun-sol. It never means an exchange like Binance, OKX, Bybit, CFD, forex, or metals. +- extracted_data should include any concrete facts from the user's message. +- When an active session exposes allowed_field_spec_json, extracted_data must use only those canonical keys. Never output aliases, translated labels, or raw user wording as keys. +- If the user clearly means a bulk destructive operation like "删除所有策略" or "全部删除策略", put the intent signal into extracted_data too. Example: {"bulk_scope":"all"}. +- For strategy changes, do not use the generic "strategy_management:update" action. Use "strategy_management:update_name" for renaming, "strategy_management:update_prompt" for prompt changes, or "strategy_management:update_config" for parameter/config changes. For strategy_management:update_config, extracted_data may include a StrategyConfig-shaped "config_patch". +- Current references are context only. Do not turn a current reference into target_ref_id/target_ref_name unless the user explicitly names that object or clearly refers to "this/current/that previous one". If a mutating task has no clear target, ask instead of executing. +- reply_to_user should be concise and in the user's language. +- For NEW_TASK, target_skill format must be "skill_name:action", for example "strategy_management:create". + +Output shape (JSON only): +{"thought_process":"...","action_type":"...","target_skill":"...","extracted_data":{},"reply_to_user":"..."}`) +} + +func buildBrainUserPrompt(lang, text, previousAssistantReply, recentHistory, currentRefs string, activeSession ActiveSkillSession, hasActive bool) string { + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Language: %s\nUser message: %s\n\n", lang, text)) + sb.WriteString("=== PREVIOUS ASSISTANT REPLY ===\n") + sb.WriteString(defaultIfEmpty(strings.TrimSpace(previousAssistantReply), "none")) + sb.WriteString("\n\n") + sb.WriteString("=== MANAGEMENT DOMAIN PRIMER ===\n") + if hasActive { + sb.WriteString(defaultIfEmpty(buildSkillDomainPrimerForSession(lang, activeToLegacySkillSession(activeSession)), "none")) + } else { + sb.WriteString(defaultIfEmpty(buildManagementDomainPrimer(lang), "none")) + } + sb.WriteString("\n\n") + + sb.WriteString("=== ACTIVE SESSION ===\n") + if hasActive { + sb.WriteString(fmt.Sprintf("skill: %s\naction: %s\n", activeSession.SkillName, activeSession.ActionName)) + if strings.TrimSpace(activeSession.Goal) != "" { + sb.WriteString(fmt.Sprintf("goal: %s\n", activeSession.Goal)) + } + if activeSession.PendingHint != nil && strings.TrimSpace(activeSession.PendingHint.Prompt) != "" { + sb.WriteString(fmt.Sprintf("pending_hint: %s\n", strings.TrimSpace(activeSession.PendingHint.Prompt))) + } + if len(activeSession.CollectedFields) > 0 { + fieldsJSON, _ := json.Marshal(activeSession.CollectedFields) + sb.WriteString(fmt.Sprintf("collected_fields: %s\n", fieldsJSON)) + } + if missing := fieldConstraintSummary(activeSession); missing != "" { + sb.WriteString("missing_required_fields:\n") + sb.WriteString(missing) + sb.WriteString("\n") + } + fieldSpecs := allowedFieldSpecsForSkillSession(activeToLegacySkillSession(activeSession), lang) + if len(fieldSpecs) > 0 { + fieldSpecsJSON, _ := json.Marshal(fieldSpecs) + sb.WriteString(fmt.Sprintf("allowed_field_spec_json: %s\n", fieldSpecsJSON)) + } + } else { + sb.WriteString("none\n") + } + + sb.WriteString("\n=== CURRENT REFERENCES ===\n") + sb.WriteString(currentRefs) + + sb.WriteString("\n\n=== RECENT CONVERSATION ===\n") + sb.WriteString(recentHistory) + + return sb.String() +} + +func parseBrainDecision(raw string) (brainDecision, bool) { + raw = strings.TrimSpace(raw) + raw = strings.TrimPrefix(raw, "```json") + raw = strings.TrimPrefix(raw, "```") + raw = strings.TrimSuffix(raw, "```") + raw = strings.TrimSpace(raw) + + var d brainDecision + if err := json.Unmarshal([]byte(raw), &d); err != nil { + start := strings.Index(raw, "{") + end := strings.LastIndex(raw, "}") + if start < 0 || end <= start { + return brainDecision{}, false + } + if err := json.Unmarshal([]byte(raw[start:end+1]), &d); err != nil { + return brainDecision{}, false + } + } + d.ActionType = strings.ToUpper(strings.TrimSpace(d.ActionType)) + d.TargetSkill = strings.TrimSpace(d.TargetSkill) + d.ReplyToUser = strings.TrimSpace(d.ReplyToUser) + switch d.ActionType { + case "CONTINUE_TASK", "NEW_TASK", "EXPLAIN_KNOWLEDGE", "CANCEL_TASK": + return d, true + default: + return brainDecision{}, false + } +} + +func parseActiveSessionStepDecision(raw string) (activeSessionStepDecision, bool) { + raw = strings.TrimSpace(raw) + raw = strings.TrimPrefix(raw, "```json") + raw = strings.TrimPrefix(raw, "```") + raw = strings.TrimSuffix(raw, "```") + raw = strings.TrimSpace(raw) + + var d activeSessionStepDecision + if err := json.Unmarshal([]byte(raw), &d); err != nil { + start := strings.Index(raw, "{") + end := strings.LastIndex(raw, "}") + if start < 0 || end <= start { + return activeSessionStepDecision{}, false + } + if err := json.Unmarshal([]byte(raw[start:end+1]), &d); err != nil { + return activeSessionStepDecision{}, false + } + } + d.Route = strings.TrimSpace(strings.ToLower(d.Route)) + d.Reply = strings.TrimSpace(d.Reply) + switch d.Route { + case "ask_user", "execute_skill", "finish_task", "cancel_task": + return d, true + default: + return activeSessionStepDecision{}, false + } +} + +func (a *Agent) executeBrainDecision(ctx context.Context, storeUserID string, userID int64, lang, text string, d brainDecision, activeSession ActiveSkillSession, hasActive bool, onEvent func(event, data string)) (string, bool, error) { + switch d.ActionType { + case "CANCEL_TASK": + a.clearActiveSkillSession(userID) + a.clearAnyActiveContext(userID) + reply := d.ReplyToUser + if reply == "" { + if lang == "zh" { + reply = "已取消当前流程。" + } else { + reply = "Cancelled the current flow." + } + } + emitBrainReply(onEvent, reply) + a.recordSkillInteraction(userID, text, reply) + return reply, true, nil + + case "EXPLAIN_KNOWLEDGE": + reply := d.ReplyToUser + if reply == "" { + return "", false, nil + } + if guarded, blocked := guardUnsupportedAsyncPromise(lang, reply); blocked { + reply = guarded + } + emitBrainReply(onEvent, reply) + a.recordSkillInteraction(userID, text, reply) + return reply, true, nil + + case "NEW_TASK": + skill, action := parseTargetSkill(d.TargetSkill) + if skill == "" { + answer, err := a.runPlannedAgent(ctx, storeUserID, userID, lang, text, onEvent) + return answer, true, err + } + session := newActiveSkillSession(userID, skill, action) + session.Goal = strings.TrimSpace(text) + d.ExtractedData = filterExtractedDataForActiveSession(session, d.ExtractedData, lang) + markStrategyCreateConfigProgressThisTurn(&session, d.ExtractedData) + mergeExtractedData(&session, d.ExtractedData) + return a.driveActiveSession(ctx, storeUserID, userID, lang, text, session, onEvent) + + case "CONTINUE_TASK": + if !hasActive { + return "", false, nil + } + d.ExtractedData = filterExtractedDataForActiveSession(activeSession, d.ExtractedData, lang) + markStrategyCreateConfigProgressThisTurn(&activeSession, d.ExtractedData) + mergeExtractedData(&activeSession, d.ExtractedData) + return a.driveActiveSession(ctx, storeUserID, userID, lang, text, activeSession, onEvent) + + default: + return "", false, nil + } +} + +func (a *Agent) driveActiveSession(ctx context.Context, storeUserID string, userID int64, lang, text string, session ActiveSkillSession, onEvent func(event, data string)) (string, bool, error) { + session = appendActiveSessionLocalHistory(session, "user", text) + clearActiveSessionPendingHint(&session) + + stepDecision, ok := a.planActiveSessionStep(ctx, storeUserID, userID, lang, text, session) + if !ok { + stepDecision = activeSessionStepDecision{} + } + configProgressThisTurn := consumeStrategyCreateConfigProgressThisTurn(&session) + if strategyCreateDecisionHasConfigProgress(session, stepDecision.ExtractedData) { + configProgressThisTurn = true + } + mergeExtractedData(&session, stepDecision.ExtractedData) + maybeForceStrategyCreateExecutionOnConfirmation(lang, text, &session, &stepDecision) + + if stepDecision.Route == "" { + if len(missingRequiredFields(session)) > 0 { + stepDecision.Route = "ask_user" + } else { + stepDecision.Route = "execute_skill" + } + } + switch stepDecision.Route { + case "cancel_task": + a.clearActiveSkillSession(userID) + reply := defaultIfEmpty(stepDecision.Reply, "已取消当前流程。") + if lang != "zh" && strings.TrimSpace(stepDecision.Reply) == "" { + reply = "Cancelled the current flow." + } + emitBrainReply(onEvent, reply) + a.recordSkillInteraction(userID, text, reply) + return reply, true, nil + + case "finish_task": + reply := strings.TrimSpace(stepDecision.Reply) + if guarded, blocked := guardUnexecutedActiveTaskCompletion(lang, session, reply); blocked { + session = appendActiveSessionLocalHistory(session, "assistant", guarded) + setActiveSessionPendingHint(&session, guarded) + a.saveActiveSkillSession(session) + emitBrainReply(onEvent, guarded) + a.recordSkillInteraction(userID, text, guarded) + return guarded, true, nil + } + if guarded, blocked := guardUnsupportedAsyncPromise(lang, reply); blocked { + session = appendActiveSessionLocalHistory(session, "assistant", guarded) + setActiveSessionPendingHint(&session, guarded) + a.saveActiveSkillSession(session) + emitBrainReply(onEvent, guarded) + a.recordSkillInteraction(userID, text, guarded) + return guarded, true, nil + } + a.clearActiveSkillSession(userID) + if reply == "" { + return "", false, nil + } + emitBrainReply(onEvent, reply) + a.recordSkillInteraction(userID, text, reply) + return reply, true, nil + + case "ask_user": + reply := "" + if guarded, blocked := guardStrategyCreateBeforeFinalConfirmation(lang, session); blocked { + session.CollectedFields["awaiting_final_confirmation"] = true + reply = guarded + } + if reply == "" && configProgressThisTurn { + if deterministic, ok := strategyCreateTemplateMissingReply(lang, text, session); ok { + reply = deterministic + } + } + if reply == "" { + reply = strings.TrimSpace(stepDecision.Reply) + if reply == "" { + reply = a.askForMissingFields(lang, session) + } + } + if guarded, blocked := guardStrategyCreateAINonTemplateQuestion(lang, session, reply); blocked { + reply = guarded + } + if guarded, blocked := guardUnsupportedAsyncPromise(lang, reply); blocked { + reply = guarded + } + if len(missingRequiredFields(session)) == 0 && actionNeedsConfirmation(session.SkillName, session.ActionName) { + session.LegacyPhase = "await_confirmation" + session.CollectedFields["phase"] = "await_confirmation" + } + session = appendActiveSessionLocalHistory(session, "assistant", reply) + setActiveSessionPendingHint(&session, reply) + a.saveActiveSkillSession(session) + emitBrainReply(onEvent, reply) + a.recordSkillInteraction(userID, text, reply) + return reply, true, nil + + case "execute_skill": + var repairReply string + var canExecute bool + session, repairReply, canExecute = a.ensureStrategyCreateExecutableState(ctx, lang, text, session) + if !canExecute { + if strategyCreateLooseConfirmationReply(text) { + repairReply = a.askForMissingFields(lang, session) + } else { + repairReply = defaultIfEmpty(repairReply, a.askForMissingFields(lang, session)) + } + session = appendActiveSessionLocalHistory(session, "assistant", repairReply) + setActiveSessionPendingHint(&session, repairReply) + a.saveActiveSkillSession(session) + emitBrainReply(onEvent, repairReply) + a.recordSkillInteraction(userID, text, repairReply) + return repairReply, true, nil + } + if !strategyCreateLooseConfirmationReply(text) { + if guarded, blocked := guardStrategyCreateBeforeFinalConfirmation(lang, session); blocked { + session.CollectedFields["awaiting_final_confirmation"] = true + session = appendActiveSessionLocalHistory(session, "assistant", guarded) + setActiveSessionPendingHint(&session, guarded) + a.saveActiveSkillSession(session) + emitBrainReply(onEvent, guarded) + a.recordSkillInteraction(userID, text, guarded) + return guarded, true, nil + } + } + outcome, nextSession, pending, ok := a.executeActiveSkillSession(storeUserID, userID, lang, text, session) + if !ok { + return "", false, nil + } + if pending { + reply := strings.TrimSpace(outcome.UserMessage) + if reply == "" { + reply = a.askForMissingFields(lang, nextSession) + } + nextSession = appendActiveSessionLocalHistory(nextSession, "assistant", reply) + setActiveSessionPendingHint(&nextSession, reply) + a.saveActiveSkillSession(nextSession) + emitBrainReply(onEvent, reply) + a.recordSkillInteraction(userID, text, reply) + return reply, true, nil + } + + if shouldTrustDeterministicSkillReply(outcome) { + answer := strings.TrimSpace(outcome.UserMessage) + if answer == "" { + return "", false, nil + } + a.clearActiveSkillSession(userID) + emitBrainReply(onEvent, answer) + a.recordSkillInteraction(userID, text, answer) + return answer, true, nil + } + + review, err := a.reviewTaskCompletion(ctx, userID, lang, text, outcome) + if err != nil { + review = taskReviewDecision{Route: "complete", Answer: outcome.UserMessage} + } + answer := strings.TrimSpace(review.Answer) + if answer == "" { + answer = strings.TrimSpace(outcome.UserMessage) + } + if review.Route == "replan" && answer == "" { + answer = outcome.UserMessage + } + if answer == "" { + return "", false, nil + } + a.clearActiveSkillSession(userID) + emitBrainReply(onEvent, answer) + a.recordSkillInteraction(userID, text, answer) + return answer, true, nil + + default: + return "", false, nil + } +} + +func strategyCreateLooseConfirmationReply(text string) bool { + if strategyCreateConfirmationReply(text) { + return true + } + lower := strings.ToLower(strings.TrimSpace(text)) + return strings.Contains(lower, "确认创建") || + strings.Contains(lower, "按这个创建") || + strings.Contains(lower, "confirm create") +} + +func (a *Agent) ensureStrategyCreateExecutableState(ctx context.Context, lang, text string, session ActiveSkillSession) (ActiveSkillSession, string, bool) { + if session.SkillName != "strategy_management" || session.ActionName != "create" { + return session, "", true + } + if strategyCreateSessionReady(lang, session) { + return session, "", true + } + if a.aiClient == nil { + return session, "", true + } + + legacy := activeToLegacySkillSession(session) + collectedJSON, _ := json.Marshal(session.CollectedFields) + fieldSpecsJSON, _ := json.Marshal(allowedFieldSpecsForSkillSession(legacy, lang)) + history := formatActiveSessionLocalHistory(session.LocalHistory) + if history == "" { + history = "(empty)" + } + systemPrompt := prependNOFXiAdvisorPreamble(`You repair structured state for one active NOFXi strategy creation task. +Return JSON only. + +Rules: +- Think from the current user message, previous assistant proposal, and active history. +- If the previous assistant already asked the user to confirm a concrete creation proposal in chat and the current user confirms it, set extracted_data.awaiting_final_confirmation=true too. +- For each user message, decide how it relates to the currently selected strategy product template. +- If the message provides explicit values, corrections, preferences, constraints, or asks you to recommend/design, translate only the determinable template fields into extracted_data.config_patch as a StrategyConfig-shaped JSON patch. +- If the message is only a question, explanation request, greeting, or unrelated text, answer it without inventing config_patch. +- Do not silently fill missing fields when the user has not authorized it. But if the user explicitly says things like "你帮我定 / 你推荐 / 按稳健高频设计 / 其他你定", that is authorization for the Agent to design the remaining fields. In that case you must produce a recommended config_patch based on the current strategy template and field limits, and explain which values came from the user versus which values are Agent recommendations. +- The product editor template is the source of truth. Use only fields from the selected product template. +- If the user switches strategy type, set extracted_data.strategy_type to the new type and discard fields from the previous type. Keep only shared fields such as name/description/publish settings. +- In NOFXi product schema, AI500/OI Top/OI Low/static coin-source requests are ai_trading, not grid_trading. +- Strategy creation is chat-executable. Do not tell the user to click a web/app button, open a page, or manually create it elsewhere. +- Do not claim the strategy was created and do not promise future execution ("马上创建", "正在创建", "稍后通知"). This step only repairs state or asks for missing information. +- When the current user message is a confirmation, prefer route="ready" whenever the structured template can be repaired. If it cannot be repaired, route="ask_user" with only the missing fields; never reply that you are about to create it. +- If the template is still incomplete after applying determinable config_patch, ask one natural follow-up question or explain the missing fields. + +Return shape: +{"route":"ready|ask_user","reply":"","extracted_data":{}}`) + userPrompt := fmt.Sprintf("Language: %s\nCurrent user message: %s\n\nCurrent collected fields JSON:\n%s\n\nAllowed field spec JSON:\n%s\n\nActive task history:\n%s", lang, text, string(collectedJSON), string(fieldSpecsJSON), history) + + stageCtx, cancel := withPlannerStageTimeout(ctx, directReplyTimeout) + defer cancel() + raw, err := a.aiClient.CallWithRequest(&mcp.Request{ + Messages: []mcp.Message{ + mcp.NewSystemMessage(systemPrompt), + mcp.NewUserMessage(userPrompt), + }, + Ctx: stageCtx, + }) + if err != nil { + return session, "", false + } + decision, ok := parseStrategyCreateStateRepairDecision(raw) + if !ok { + return session, "", false + } + decision.ExtractedData = filterExtractedDataForActiveSession(session, decision.ExtractedData, lang) + mergeExtractedData(&session, decision.ExtractedData) + if decision.Route == "ask_user" { + return session, strings.TrimSpace(decision.Reply), false + } + if strategyCreateSessionReady(lang, session) { + return session, strings.TrimSpace(decision.Reply), true + } + return session, strings.TrimSpace(decision.Reply), false +} + +type strategyCreateStateRepairDecision struct { + Route string `json:"route"` + Reply string `json:"reply,omitempty"` + ExtractedData map[string]any `json:"extracted_data,omitempty"` +} + +func parseStrategyCreateStateRepairDecision(raw string) (strategyCreateStateRepairDecision, bool) { + raw = strings.TrimSpace(raw) + raw = strings.TrimPrefix(raw, "```json") + raw = strings.TrimPrefix(raw, "```") + raw = strings.TrimSuffix(raw, "```") + raw = strings.TrimSpace(raw) + var d strategyCreateStateRepairDecision + if err := json.Unmarshal([]byte(raw), &d); err != nil { + start := strings.Index(raw, "{") + end := strings.LastIndex(raw, "}") + if start < 0 || end <= start { + return strategyCreateStateRepairDecision{}, false + } + if err := json.Unmarshal([]byte(raw[start:end+1]), &d); err != nil { + return strategyCreateStateRepairDecision{}, false + } + } + d.Route = strings.ToLower(strings.TrimSpace(d.Route)) + d.Reply = strings.TrimSpace(d.Reply) + switch d.Route { + case "ready", "ask_user": + return d, true + default: + return strategyCreateStateRepairDecision{}, false + } +} + +func strategyCreateSessionReady(lang string, session ActiveSkillSession) bool { + legacy := activeToLegacySkillSession(session) + cfg, _, _, err := strategyCreateConfigFromSession(legacy, lang) + if err != nil { + return false + } + ready, _ := strategyCreateConfigReady(legacy, cfg, "") + return ready +} + +func strategyCreateDecisionHasConfigProgress(session ActiveSkillSession, data map[string]any) bool { + if session.SkillName != "strategy_management" || session.ActionName != "create" || len(data) == 0 { + return false + } + patch, ok := data[strategyCreateConfigPatchField] + if !ok { + return false + } + sanitized := sanitizeStrategyCreateConfigPatchForType(patch, defaultIfEmpty(strategyTypeFromExtractedData(data), strategyTypeFromCollectedFields(session.CollectedFields))) + return len(sanitized) > 0 +} + +const strategyCreateConfigProgressThisTurnField = "__strategy_create_config_progress_this_turn" + +func markStrategyCreateConfigProgressThisTurn(session *ActiveSkillSession, data map[string]any) { + if session == nil || !strategyCreateDecisionHasConfigProgress(*session, data) { + return + } + if session.CollectedFields == nil { + session.CollectedFields = map[string]any{} + } + session.CollectedFields[strategyCreateConfigProgressThisTurnField] = true +} + +func consumeStrategyCreateConfigProgressThisTurn(session *ActiveSkillSession) bool { + if session == nil || session.CollectedFields == nil { + return false + } + progress := activeFieldBool(session.CollectedFields[strategyCreateConfigProgressThisTurnField]) + delete(session.CollectedFields, strategyCreateConfigProgressThisTurnField) + return progress +} + +func maybeForceStrategyCreateExecutionOnConfirmation(lang, text string, session *ActiveSkillSession, decision *activeSessionStepDecision) bool { + if session == nil || decision == nil { + return false + } + if session.SkillName != "strategy_management" || session.ActionName != "create" { + return false + } + if !strategyCreateLooseConfirmationReply(text) { + return false + } + if !strategyCreateSessionReady(lang, *session) { + return false + } + if session.CollectedFields == nil { + session.CollectedFields = map[string]any{} + } + session.CollectedFields["awaiting_final_confirmation"] = true + decision.Route = "execute_skill" + decision.Reply = "" + return true +} + +func (a *Agent) activeStrategyCreateSession(userID int64) (ActiveSkillSession, bool) { + if session, ok := a.getActiveSkillSession(userID); ok && session.SkillName == "strategy_management" && session.ActionName == "create" { + return session, true + } + if legacy := a.getSkillSession(userID); legacy.Name == "strategy_management" && legacy.Action == "create" { + return activeSessionFromLegacy(ActiveSkillSession{ + UserID: userID, + SkillName: "strategy_management", + ActionName: "create", + }, legacy), true + } + return ActiveSkillSession{}, false +} + +func guardStrategyCreateBeforeFinalConfirmation(lang string, session ActiveSkillSession) (string, bool) { + if session.SkillName != "strategy_management" || session.ActionName != "create" { + return "", false + } + if activeFieldBool(session.CollectedFields["awaiting_final_confirmation"]) && strategyCreateHasPriorConfirmationPrompt(session) { + return "", false + } + legacy := activeToLegacySkillSession(session) + cfg, _, _, err := strategyCreateConfigFromSession(legacy, lang) + if err != nil { + return "", false + } + if ready, _ := strategyCreateConfigReady(legacy, cfg, ""); !ready { + return "", false + } + return formatStrategyCreateFinalConfirmation(lang, legacy, cfg), true +} + +func strategyCreateTemplateMissingReply(lang, text string, session ActiveSkillSession) (string, bool) { + if session.SkillName != "strategy_management" || session.ActionName != "create" { + return "", false + } + legacy := activeToLegacySkillSession(session) + cfg, _, _, err := strategyCreateConfigFromSession(legacy, lang) + if err != nil { + return "", false + } + ready, missingKind := strategyCreateConfigReady(legacy, cfg, "") + if ready || strings.TrimSpace(missingKind) == "" { + return "", false + } + if reply := formatStrategyCreateFieldOptionsReply(lang, text, missingKind); reply != "" { + return reply, true + } + return formatStrategyCreateConfigNeeded(lang, missingKind), true +} + +func strategyCreateHasPriorConfirmationPrompt(session ActiveSkillSession) bool { + for i := len(session.LocalHistory) - 1; i >= 0; i-- { + msg := session.LocalHistory[i] + if msg.Role != "assistant" { + continue + } + content := strings.TrimSpace(msg.Content) + if content == "" { + continue + } + lower := strings.ToLower(content) + return strings.Contains(content, "确认创建") || + strings.Contains(content, "确认后我再创建") || + strings.Contains(content, "配置整理好了") || + strings.Contains(content, "请确认是否按以上设置创建") || + strings.Contains(content, "如果没问题,我就执行创建") || + strings.Contains(content, "是否按以上设置创建") || + strings.Contains(lower, "confirm") || + strings.Contains(lower, "create it") + } + return false +} + +func activeFieldBool(v any) bool { + switch typed := v.(type) { + case bool: + return typed + case string: + return strings.EqualFold(strings.TrimSpace(typed), "true") + default: + return false + } +} + +func guardUnexecutedActiveTaskCompletion(lang string, session ActiveSkillSession, reply string) (string, bool) { + if !isMutatingActiveTask(session) || !looksLikeCompletionClaim(reply) { + return "", false + } + if lang == "zh" { + if session.SkillName == "strategy_management" { + return "还没有真正创建到策略列表里。刚才只是整理/确认配置方案;需要继续的话,我会先用结构化配置调用策略创建工具,再基于真实结果回复。", true + } + return "还没有真正执行完成。刚才只是继续当前配置流程;需要实际执行时,我会调用对应工具后再基于真实结果回复。", true + } + return "It has not actually been executed yet. The previous step only prepared or confirmed the draft; I need to run the structured tool before claiming completion.", true +} + +func guardStrategyCreateAINonTemplateQuestion(lang string, session ActiveSkillSession, reply string) (string, bool) { + if session.SkillName != "strategy_management" || session.ActionName != "create" { + return "", false + } + if strategyTypeFromCollectedFields(session.CollectedFields) != "ai_trading" { + return "", false + } + lower := strings.ToLower(strings.TrimSpace(reply)) + if lower == "" { + return "", false + } + if !containsAny(lower, []string{ + "投入多少", "投入资金", "总投入", "固定投入", "每笔交易", "每笔固定", "100u", "500u", "1000u", + "止损", "日亏损", "最大回撤", + "investment amount", "capital", "fixed amount", "per-trade", "stop loss", "daily loss", "max drawdown", + }) { + return "", false + } + legacy := activeToLegacySkillSession(session) + cfg, _, _, err := strategyCreateConfigFromSession(legacy, lang) + if err != nil { + return "", false + } + _, missingKind := strategyCreateConfigReady(legacy, cfg, "") + if strings.TrimSpace(missingKind) == "" { + return "", false + } + if lang == "zh" { + return "这些不是 AI 策略创建模板里的字段。我会继续按 AI 策略模板填写;当前还需要围绕选币来源、周期、杠杆、置信度、盈亏比、交易频率和开仓标准来确定配置。你也可以直接说“全部你定,按稳健/高频/激进”。", true + } + return "Those are not fields in the AI strategy creation template. I will continue using the AI strategy template: coin source, timeframes, leverage, confidence, risk/reward, trading frequency, and entry standards.", true +} + +func guardUnsupportedAsyncPromise(lang, reply string) (string, bool) { + lower := strings.ToLower(strings.TrimSpace(reply)) + if lower == "" { + return "", false + } + promiseSignals := []string{ + "请稍等", "稍等片刻", "再稍等", "马上", "稍后", "立刻告诉", "数据一出来", "一两分钟", + "还在进行", "正在进行", "正在为", "正在帮", "一直在帮", "诊断中", "分析中", + "please wait", "give me a moment", "still running", "i'll let you know", "i will let you know", + } + taskSignals := []string{ + "诊断", "分析", "历史交易", "历史表现", "亏损原因", "创建", "修改", "删除", "启动", "停止", + "diagnos", "analyz", "history", "performance", "loss", "create", "update", "delete", "start", "stop", + } + if !containsAny(lower, promiseSignals) || !containsAny(lower, taskSignals) { + return "", false + } + if lang == "zh" { + return "我需要纠正一下:我没有后台异步任务在运行,也不会自动推送后续结果。诊断/创建/修改/启动这类任务必须在当前回复里实际执行并给出真实结果;如果还不能执行,我应该直接说明缺少哪个对象、时间范围或数据。", true + } + return "I need to correct that: there is no background task running, and I will not automatically push a later result. Diagnosis/create/update/start tasks must actually execute and return a real result in the current response; if execution is not possible, I should state which target, range, or data is missing.", true +} + +func isMutatingActiveTask(session ActiveSkillSession) bool { + if strings.TrimSpace(session.SkillName) == "" { + return false + } + switch strings.TrimSpace(session.ActionName) { + case "create", "update", "update_name", "update_bindings", "configure_strategy", "configure_exchange", "configure_model", "update_status", "update_endpoint", "update_config", "update_prompt", "delete", "start", "stop", "activate", "duplicate": + return true + default: + return false + } +} + +func looksLikeCompletionClaim(reply string) bool { + lower := strings.ToLower(strings.TrimSpace(reply)) + if lower == "" { + return false + } + return containsAny(lower, []string{ + "已创建", "创建好了", "创建好", "已经创建", "已更新", "更新好了", "已修改", "已删除", "已启动", "已停止", "已激活", "已复制", "已经完成", "已完成", + "created", "has been created", "updated", "deleted", "started", "stopped", "activated", "duplicated", "completed", + }) +} + +func (a *Agent) planActiveSessionStep(ctx context.Context, storeUserID string, userID int64, lang, text string, session ActiveSkillSession) (activeSessionStepDecision, bool) { + if a.aiClient == nil { + return activeSessionStepDecision{}, false + } + + legacy := activeToLegacySkillSession(session) + resources := a.buildActiveSessionResources(storeUserID, legacy) + resourcesJSON, _ := json.Marshal(resources) + collectedJSON, _ := json.Marshal(session.CollectedFields) + missingSummary := formatConversationMissingFields(lang, missingRequiredFieldsForBrain(session)) + fieldSpecs := allowedFieldSpecsForSkillSession(legacy, lang) + fieldSpecsJSON, _ := json.Marshal(fieldSpecs) + localHistory := formatActiveSessionLocalHistory(session.LocalHistory) + if localHistory == "" { + localHistory = "(empty)" + } + previousAssistantReply := a.currentPendingHintText(userID) + + domainPrimer := buildSkillDomainPrimerForSession(lang, legacy) + specificRules := activeSessionSpecificRules(legacy) + + systemPrompt := prependNOFXiAdvisorPreamble(fmt.Sprintf(`You are the active-task orchestration loop for NOFXi. +You decide the NEXT step for exactly one active task. Return JSON only. + +Active task: +- skill: %s +- action: %s +- goal: %s + +Current collected fields: +%s + +Current missing field summary: +%s + +Relevant disclosed resources: +%s + +Allowed field spec JSON: +%s + +Domain knowledge: +%s + +Rules: +- Your job is to decide the next move, not to explain internal schema names. +- Read the previous assistant reply carefully. The user's short answer may be replying to that exact proposal, confirmation request, or question. +- Use contextual memory from the active task history and current references. +- Prefer "execute_skill" when the user has already given enough information to act. +- Prefer "ask_user" only when something truly necessary is still missing. +%s +- For any mutating task, a reply that only promises future execution ("now I will create/update/start it", "result soon") is not a valid finish_task or ask_user outcome. If execution is the next step, choose execute_skill. +- For diagnosis, create, update, delete, start, stop, query/history, and performance-analysis tasks, never answer with only "马上处理 / 请稍等 / 诊断中 / I'll tell you later". NOFXi has no background chat job that will later push an answer. Choose execute_skill/planned_agent when enough information exists; otherwise ask for the missing target/range/data. +- Never choose finish_task for an unfinished mutating active task by claiming it was created/updated/deleted/started/stopped. Only a real skill/tool execution outcome can support that claim. +- If the user says they do not understand the current form, choices, or required information, choose "ask_user" and explain the current pending question in plain language before asking the next easiest question. Cover the relevant concepts from the previous assistant reply; do not collapse the answer to only the first missing field. +- For beginner/confusion replies, give a safe recommended path when the domain supports one, but do not execute or create anything unless the user confirms after the explanation. +- If the current message is only a greeting, thanks, acknowledgement, or small talk and does not add task information, do NOT continue task execution. Choose "ask_user" only if you need to gently restate what is pending; otherwise choose "finish_task" with a short social reply. +- Ask naturally. Do not say raw slot names like target_ref unless the user explicitly asks for internal details. +- If the user clearly means a bulk destructive operation like "删除所有策略", "全部删除策略", "all strategies", set extracted_data to {"bulk_scope":"all"} and choose "execute_skill". Do not ask for target_ref. +- If the user refers to a specific object from disclosed targets, set target_ref_id and target_ref_name when you can resolve it. +- Current references are context for reasoning only. Do not copy a current reference into target_ref_id/target_ref_name unless the user explicitly refers to that object by name/id or clearly says "this/current/that previous one". If the target is not clear, ask instead of executing. +- For trader bindings, exchange/model/strategy must resolve to an ID from Relevant disclosed resources before execution. Never invent a resource name or use a generic venue type like Binance/OKX as the bound exchange unless it appears as an actual disclosed resource. +- If there are multiple targets and the user did not disambiguate, ask a natural question with the available names. +- If the current user message answers a missing field directly, extract it and continue. +- extracted_data must use only canonical keys from Allowed field spec JSON. Never output aliases, translated labels, or raw user wording as keys. +- If a user-provided value does not fit one of those canonical keys, omit it; never create another key. +- If this task is already done and the best next step is just to tell the user the result, choose "finish_task". +- If the user aborts the task, choose "cancel_task". + +Return JSON with this exact shape: +{"route":"ask_user|execute_skill|finish_task|cancel_task","reply":"","extracted_data":{}}`, + session.SkillName, + session.ActionName, + defaultIfEmpty(session.Goal, "(not set)"), + defaultIfEmpty(string(collectedJSON), "{}"), + missingSummary, + defaultIfEmpty(string(resourcesJSON), "{}"), + defaultIfEmpty(string(fieldSpecsJSON), "[]"), + defaultIfEmpty(domainPrimer, "(none)"), + specificRules, + )) + userPrompt := fmt.Sprintf("Language: %s\nCurrent user message: %s\n\nPrevious assistant reply:\n%s\n\nActive task local history:\n%s\n", lang, text, defaultIfEmpty(previousAssistantReply, "(empty)"), localHistory) + + stageCtx, cancel := withPlannerStageTimeout(ctx, directReplyTimeout) + defer cancel() + + raw, err := a.aiClient.CallWithRequest(&mcp.Request{ + Messages: []mcp.Message{ + mcp.NewSystemMessage(systemPrompt), + mcp.NewUserMessage(userPrompt), + }, + Ctx: stageCtx, + }) + if err != nil { + return activeSessionStepDecision{}, false + } + decision, ok := parseActiveSessionStepDecision(raw) + if !ok { + return activeSessionStepDecision{}, false + } + decision.ExtractedData = filterExtractedDataForActiveSession(session, decision.ExtractedData, lang) + return decision, true +} + +func activeSessionSpecificRules(session skillSession) string { + if session.Name != "strategy_management" { + return "" + } + switch session.Action { + case "create", "update_config": + return strings.Join([]string{ + "- For strategy_management:create/update_config, the selected product editor template is the only schema. Write values only through extracted_data.config_patch, using the current type branch only: ai_trading => strategy_type + ai_config + publish_config; grid_trading => strategy_type + grid_config + publish_config.", + "- For strategy_management:create/update_config, config_patch values must be product schema raw values, not user-facing labels. Examples: source_type=\"ai500\" not \"AI500\"; strategy_type=\"ai_trading\" not \"AI 策略\"; selected_timeframes=[\"1m\",\"5m\",\"15m\"] not a JSON string.", + "- For strategy_management:create, AI500/OI Top/OI Low/static coin-source requests imply strategy_type=\"ai_trading\". Do not leave strategy type ambiguous in that case.", + "- For strategy_management:create/update_config, judge the user's natural-language intent. Explicit values, corrections, constraints, preferences, or requests to recommend/design must become config_patch for every determinable current-template field; pure questions/greetings/acknowledgements must not invent config_patch.", + "- For strategy_management:create, the Relevant disclosed resources include product_default_template and current_missing_template_fields. Treat product_default_template as the product editor's default template and field shape.", + "- For strategy_management:create, do not ask for or present fields listed in product_default_template.non_fields. They are not part of the selected product editor template.", + "- For strategy_management:create, when the user states a strategy style/preference or authorizes the Agent to recommend/design remaining settings, use product_default_template as the base, adjust it to the user's stated preference, and output config_patch that fills every determinable missing template field. Do not ask the user to restate fields that can be responsibly selected from the default template.", + "- For grid_trading create, if the user authorizes the Agent to choose/recommend remaining settings, set grid_config.use_atr_bounds=true for the price range unless the user explicitly gives manual upper_price/lower_price. Never invent current market prices or say a symbol is currently near a price without a fresh market-data tool result.", + "- For strategy_management:create, any user-facing strategy plan must be generated from the post-merge structured config built from config_patch and the current strategy type. Do not display fields that would be filtered out or belong to the other strategy type.", + "- For strategy_management:create, once complete, ask for one chat confirmation with awaiting_final_confirmation=true; after confirmation execute synchronously with empty reply and only report success after the tool returns.", + }, "\n") + default: + return "" + } +} + +func (a *Agent) executeActiveSkillSession(storeUserID string, userID int64, lang, text string, session ActiveSkillSession) (skillOutcome, ActiveSkillSession, bool, bool) { + legacy := activeToLegacySkillSession(session) + a.saveSkillSession(userID, legacy) + answer, handled := a.dispatchBridgedSkillSession(storeUserID, userID, lang, text, legacy) + if !handled { + a.clearSkillSession(userID) + return skillOutcome{}, ActiveSkillSession{}, false, false + } + + updatedLegacy := a.getSkillSession(userID) + a.clearSkillSession(userID) + outcome := inferSkillOutcome(session.SkillName, session.ActionName, answer, updatedLegacy, skillDataForAction(storeUserID, session.SkillName, session.ActionName, a)) + if updatedLegacy.Name != "" { + nextSession := activeSessionFromLegacy(session, updatedLegacy) + return outcome, nextSession, true, true + } + return outcome, ActiveSkillSession{}, false, true +} + +func shouldTrustDeterministicSkillReply(outcome skillOutcome) bool { + if outcome.Status != skillOutcomeSuccess || !outcome.GoalAchieved { + return false + } + switch outcome.Skill { + case "strategy_management", "trader_management", "model_management", "exchange_management": + switch outcome.Action { + case "create", "update", "update_name", "update_bindings", "configure_strategy", "configure_exchange", "configure_model", "update_status", "update_endpoint", "update_config", "update_prompt", "delete", "start", "stop", "activate", "duplicate": + return true + } + } + return false +} + +func (a *Agent) askForMissingFields(lang string, session ActiveSkillSession) string { + missing := missingRequiredFieldsForBrain(session) + if len(missing) == 0 { + if lang == "zh" { + return "还需要一点信息,我再继续。" + } + return "I need a bit more information before I continue." + } + + if session.SkillName == "model_management" && session.ActionName == "create" { + for _, field := range missing { + if field == "provider" { + return modelProviderChoicePrompt(lang) + } + } + } + + def, ok := getSkillDefinition(session.SkillName) + if !ok { + if lang == "zh" { + return "还需要更多信息,请继续。" + } + return "I need a bit more information to continue." + } + + labels := make([]string, 0, len(missing)) + for _, field := range missing { + label := slotDisplayName(field, lang) + if constraint, ok := def.FieldConstraints[field]; ok { + desc := strings.TrimSpace(constraint.Description) + if len(constraint.Values) > 0 { + desc = strings.Join(constraint.Values, " / ") + } + if desc != "" { + label = fmt.Sprintf("%s(%s)", label, desc) + } + } + labels = append(labels, label) + } + + if lang == "zh" { + return "还差一点信息,我才能继续:" + strings.Join(labels, "、") + "。" + } + return "I still need a bit more information before I can continue: " + strings.Join(labels, ", ") + "." +} + +func activeToLegacySkillSession(s ActiveSkillSession) skillSession { + legacy := skillSession{ + Name: s.SkillName, + Action: s.ActionName, + Phase: defaultIfEmpty(strings.TrimSpace(s.LegacyPhase), "executing"), + Fields: make(map[string]string), + } + for k, v := range s.CollectedFields { + str := activeFieldString(v) + if str == "" || str == "" { + continue + } + switch k { + case "phase": + legacy.Phase = str + case "target_ref_id": + ensureTargetRef(&legacy) + legacy.TargetRef.ID = str + case "target_ref_name": + ensureTargetRef(&legacy) + legacy.TargetRef.Name = str + case "target_ref": + ensureTargetRef(&legacy) + if legacy.TargetRef.ID == "" { + legacy.TargetRef.ID = str + } + if legacy.TargetRef.Name == "" { + legacy.TargetRef.Name = str + } + default: + legacy.Fields[k] = str + } + } + if s.SkillName == "strategy_management" && s.ActionName == "create" && legacy.Fields["name"] == "" { + for i := len(s.LocalHistory) - 1; i > 0; i-- { + msg := s.LocalHistory[i] + if msg.Role != "user" || !activeHistoryMessageAsksStrategyName(s.LocalHistory[i-1].Content) { + continue + } + if inferred := inferStandaloneStrategyName(msg.Content); inferred != "" { + legacy.Fields["name"] = inferred + break + } + } + } + return legacy +} + +func activeFieldString(value any) string { + switch v := value.(type) { + case nil: + return "" + case string: + return strings.TrimSpace(v) + case map[string]any, []any, map[string]string, []string: + raw, err := json.Marshal(v) + if err != nil { + return "" + } + return strings.TrimSpace(string(raw)) + default: + return strings.TrimSpace(fmt.Sprint(v)) + } +} + +func activeSessionFromLegacy(base ActiveSkillSession, legacy skillSession) ActiveSkillSession { + next := base + next.LegacyPhase = strings.TrimSpace(legacy.Phase) + if next.CollectedFields == nil { + next.CollectedFields = map[string]any{} + } + for key, value := range legacy.Fields { + value = strings.TrimSpace(value) + if value == "" { + continue + } + next.CollectedFields[key] = value + } + if legacy.TargetRef != nil { + if value := strings.TrimSpace(legacy.TargetRef.ID); value != "" { + next.CollectedFields["target_ref_id"] = value + } + if value := strings.TrimSpace(legacy.TargetRef.Name); value != "" { + next.CollectedFields["target_ref_name"] = value + } + } + return next +} + +func ensureTargetRef(s *skillSession) { + if s.TargetRef == nil { + s.TargetRef = &EntityReference{} + } +} + +func (a *Agent) buildActiveSessionResources(storeUserID string, session skillSession) map[string]any { + switch session.Name { + case "trader_management": + if session.Action == "create" { + return a.buildTraderCreateConversationResources(storeUserID, session) + } + return a.buildSimpleEntityConversationResources(storeUserID, session, a.loadTraderOptions(storeUserID)) + case "exchange_management": + return a.buildSimpleEntityConversationResources(storeUserID, session, a.loadExchangeOptions(storeUserID)) + case "model_management": + return a.buildSimpleEntityConversationResources(storeUserID, session, a.loadEnabledModelOptions(storeUserID)) + case "strategy_management": + resources := a.buildSimpleEntityConversationResources(storeUserID, session, a.loadStrategyOptions(storeUserID)) + if strategyType := explicitStrategyCreateType(session); strategyType != "" { + lang := defaultIfEmpty(a.config.Language, "zh") + resources["current_strategy_type"] = strategyType + resources["current_editable_fields"] = manualStrategyEditableFieldKeysForType(strategyType) + if session.Action == "create" || session.Action == "update_config" { + resources["product_default_template"] = strategyProductDefaultTemplateResource(lang, strategyType) + if cfg, _, _, err := strategyCreateConfigFromSession(session, lang); err == nil { + resources["current_missing_template_fields"] = strategyCreateMissingTemplateFields(session, cfg) + } + } + } else if strategyType, ok := a.strategyTypeForTarget(storeUserID, session.TargetRef); ok { + lang := defaultIfEmpty(a.config.Language, "zh") + resources["target_strategy_type"] = strategyType + resources["target_editable_fields"] = manualStrategyEditableFieldKeysForType(strategyType) + resources["product_default_template"] = strategyProductDefaultTemplateResource(lang, strategyType) + } + return resources + default: + return nil + } +} + +func strategyProductDefaultTemplateResource(lang, strategyType string) map[string]any { + cfg := store.GetDefaultStrategyConfig(defaultIfEmpty(lang, "zh")) + cfg.StrategyType = strings.TrimSpace(strategyType) + cfg.ClampLimits() + publish := map[string]any{ + "is_public": false, + "config_visible": true, + } + switch cfg.StrategyType { + case "grid_trading": + grid := cfg.GridConfig + if grid == nil { + defaultGrid := store.DefaultGridStrategyConfig() + grid = &defaultGrid + } + return map[string]any{ + "strategy_type": "grid_trading", + "grid_config": grid, + "publish_config": publish, + "required_fields": strategyCreateMissingGridFields(skillSession{}), + } + default: + return map[string]any{ + "strategy_type": "ai_trading", + "ai_config": map[string]any{ + "coin_source": map[string]any{ + "source_type": cfg.CoinSource.SourceType, + "static_coins": cfg.CoinSource.StaticCoins, + "excluded_coins": cfg.CoinSource.ExcludedCoins, + "ai500_limit": cfg.CoinSource.AI500Limit, + "oi_top_limit": cfg.CoinSource.OITopLimit, + "oi_low_limit": cfg.CoinSource.OILowLimit, + }, + "indicators": map[string]any{ + "klines": map[string]any{ + "primary_timeframe": cfg.Indicators.Klines.PrimaryTimeframe, + "primary_count": cfg.Indicators.Klines.PrimaryCount, + "selected_timeframes": cfg.Indicators.Klines.SelectedTimeframes, + }, + "enable_ema": cfg.Indicators.EnableEMA, + "enable_macd": cfg.Indicators.EnableMACD, + "enable_rsi": cfg.Indicators.EnableRSI, + "enable_atr": cfg.Indicators.EnableATR, + "enable_boll": cfg.Indicators.EnableBOLL, + "enable_volume": cfg.Indicators.EnableVolume, + "enable_oi": cfg.Indicators.EnableOI, + "enable_funding_rate": cfg.Indicators.EnableFundingRate, + }, + "risk_control": map[string]any{ + "btc_eth_max_leverage": cfg.RiskControl.BTCETHMaxLeverage, + "altcoin_max_leverage": cfg.RiskControl.AltcoinMaxLeverage, + "min_confidence": cfg.RiskControl.MinConfidence, + "min_risk_reward_ratio": cfg.RiskControl.MinRiskRewardRatio, + }, + "prompt_sections": map[string]any{ + "trading_frequency": cfg.PromptSections.TradingFrequency, + "entry_standards": cfg.PromptSections.EntryStandards, + }, + "custom_prompt": cfg.CustomPrompt, + }, + "publish_config": publish, + "non_fields": []string{ + "investment_amount", + "fixed_position_size", + "stop_loss_pct", + "daily_loss_limit_pct", + "max_drawdown_pct", + }, + "required_fields": []string{ + "source_type", + "primary_timeframe", + "selected_timeframes", + "btceth_max_leverage", + "altcoin_max_leverage", + "min_confidence", + "min_risk_reward_ratio", + "trading_frequency", + "entry_standards", + }, + } + } +} + +func missingRequiredFieldsForBrain(session ActiveSkillSession) []string { + missing := missingRequiredFields(session) + if len(missing) == 0 { + return nil + } + out := make([]string, 0, len(missing)) + for _, field := range missing { + if field == "target_ref" { + if activeSessionHasField(session, "target_ref") { + continue + } + } + out = append(out, field) + } + return out +} + +func formatActiveSessionLocalHistory(history []chatMessage) string { + if len(history) == 0 { + return "" + } + start := 0 + if len(history) > 8 { + start = len(history) - 8 + } + lines := make([]string, 0, len(history)-start) + for _, msg := range history[start:] { + role := strings.TrimSpace(msg.Role) + if role == "" { + role = "unknown" + } + content := strings.TrimSpace(msg.Content) + if content == "" { + continue + } + lines = append(lines, fmt.Sprintf("%s: %s", role, content)) + } + return strings.Join(lines, "\n") +} + +func appendActiveSessionLocalHistory(session ActiveSkillSession, role, content string) ActiveSkillSession { + content = strings.TrimSpace(content) + if content == "" { + return session + } + session.LocalHistory = append(session.LocalHistory, chatMessage{ + Role: strings.TrimSpace(role), + Content: content, + }) + if len(session.LocalHistory) > 12 { + session.LocalHistory = append([]chatMessage(nil), session.LocalHistory[len(session.LocalHistory)-12:]...) + } + return session +} + +func parseTargetSkill(target string) (skill, action string) { + parts := strings.SplitN(target, ":", 2) + if len(parts) != 2 { + return "", "" + } + return strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]) +} + +func mergeExtractedData(s *ActiveSkillSession, data map[string]any) { + if s.CollectedFields == nil { + s.CollectedFields = map[string]any{} + } + if s.SkillName == "strategy_management" && s.ActionName == "create" { + if incomingType := strategyTypeFromExtractedData(data); incomingType != "" { + currentType := strategyTypeFromCollectedFields(s.CollectedFields) + if currentType != "" && currentType != incomingType { + resetActiveStrategyCreateFieldsForType(s, incomingType) + } + } + } + for k, v := range data { + k = strings.TrimSpace(k) + if k == "" { + continue + } + s.CollectedFields[k] = v + } +} + +func filterExtractedDataForActiveSession(session ActiveSkillSession, data map[string]any, lang string) map[string]any { + if len(data) == 0 { + return data + } + specs := allowedFieldSpecsForSkillSession(activeToLegacySkillSession(session), lang) + if len(specs) == 0 { + return nil + } + allowed := make(map[string]struct{}, len(specs)) + for _, spec := range specs { + key := strings.TrimSpace(spec.Key) + if key != "" { + allowed[key] = struct{}{} + } + } + out := make(map[string]any, len(data)) + for key, value := range data { + key = strings.TrimSpace(key) + if key == "" { + continue + } + if _, ok := allowed[key]; !ok { + continue + } + out[key] = value + } + if session.SkillName == "strategy_management" && session.ActionName == "create" { + out = filterStrategyCreateExtractedDataByTemplate(session, out) + } + if len(out) == 0 { + return nil + } + return out +} + +func strategyTypeFromExtractedData(data map[string]any) string { + if len(data) == 0 { + return "" + } + if value, ok := data["strategy_type"]; ok { + if strategyType := parseStrategyTypeValue(fmt.Sprint(value)); strategyType != "" { + return strategyType + } + } + if patch, ok := data[strategyCreateConfigPatchField]; ok { + if strategyType := strategyTypeFromConfigPatchAny(patch); strategyType != "" { + return strategyType + } + } + return "" +} + +func strategyTypeFromCollectedFields(fields map[string]any) string { + if len(fields) == 0 { + return "" + } + if value, ok := fields["strategy_type"]; ok { + if strategyType := parseStrategyTypeValue(fmt.Sprint(value)); strategyType != "" { + return strategyType + } + } + if patch, ok := fields[strategyCreateConfigPatchField]; ok { + if strategyType := strategyTypeFromConfigPatchAny(patch); strategyType != "" { + return strategyType + } + } + return "" +} + +func strategyTypeFromConfigPatchAny(value any) string { + patch := mapFromAny(value) + if len(patch) == 0 { + return "" + } + if strategyType := parseStrategyTypeValue(fmt.Sprint(patch["strategy_type"])); strategyType != "" { + return strategyType + } + if _, ok := patch["grid_config"]; ok { + return "grid_trading" + } + if _, ok := patch["ai_config"]; ok { + return "ai_trading" + } + return "" +} + +func resetActiveStrategyCreateFieldsForType(s *ActiveSkillSession, strategyType string) { + if s.CollectedFields == nil { + s.CollectedFields = map[string]any{} + } + keep := map[string]any{} + for _, key := range []string{"name", "description", "is_public", "config_visible", "lang"} { + if value, ok := s.CollectedFields[key]; ok { + keep[key] = value + } + } + keep["strategy_type"] = strategyType + s.CollectedFields = keep +} + +func filterStrategyCreateExtractedDataByTemplate(session ActiveSkillSession, data map[string]any) map[string]any { + if len(data) == 0 { + return data + } + strategyType := strategyTypeFromExtractedData(data) + if strategyType == "" { + strategyType = strategyTypeFromCollectedFields(session.CollectedFields) + } + if strategyType == "" { + return data + } + allowed := map[string]struct{}{} + for _, key := range manualStrategyEditableFieldKeysForType(strategyType) { + allowed[key] = struct{}{} + } + out := make(map[string]any, len(data)) + for key, value := range data { + if key == strategyCreateConfigPatchField { + if patch := sanitizeStrategyCreateConfigPatchForType(value, strategyType); len(patch) > 0 { + out[key] = patch + } + continue + } + if key == "awaiting_final_confirmation" { + out[key] = value + continue + } + if _, ok := allowed[key]; ok { + out[key] = value + } + } + if len(out) == 0 { + return nil + } + return out +} + +func sanitizeStrategyCreateConfigPatchForType(value any, strategyType string) map[string]any { + patch := mapFromAny(value) + if len(patch) == 0 { + return nil + } + out := map[string]any{ + "strategy_type": strategyType, + } + if publish := mapFromAny(patch["publish_config"]); len(publish) > 0 { + out["publish_config"] = publish + } + switch strategyType { + case "grid_trading": + if grid := mapFromAny(patch["grid_config"]); len(grid) > 0 { + out["grid_config"] = grid + } + case "ai_trading": + ai := mapFromAny(patch["ai_config"]) + if ai == nil { + ai = map[string]any{} + } + for _, key := range []string{"coin_source", "indicators", "risk_control", "prompt_sections", "custom_prompt"} { + if value, ok := patch[key]; ok { + ai[key] = value + } + } + if len(ai) > 0 { + out["ai_config"] = ai + } + } + if len(out) == 1 { + return nil + } + return out +} + +func mapFromAny(value any) map[string]any { + switch typed := value.(type) { + case map[string]any: + return typed + case string: + var out map[string]any + if err := json.Unmarshal([]byte(typed), &out); err == nil { + return out + } + } + return nil +} + +func emitBrainReply(onEvent func(event, data string), reply string) { + if onEvent == nil || reply == "" { + return + } + onEvent(StreamEventTool, "central_brain") + emitStreamText(onEvent, reply) +} diff --git a/agent/clear_memory_test.go b/agent/clear_memory_test.go new file mode 100644 index 00000000..05b5c91e --- /dev/null +++ b/agent/clear_memory_test.go @@ -0,0 +1,116 @@ +package agent + +import ( + "context" + "log/slog" + "path/filepath" + "strings" + "testing" + + "nofx/store" +) + +func TestClearRemovesActiveAndPendingConversationState(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "agent-clear.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + + a := New(nil, st, DefaultConfig(), slog.Default()) + userID := int64(42) + + a.history.Add(userID, "assistant", "之前的回复") + _ = a.saveTaskState(userID, TaskState{CurrentGoal: "配置模型"}) + a.saveActiveSkillSession(ActiveSkillSession{ + SessionID: "as_test", + UserID: userID, + SkillName: "model_management", + ActionName: "create", + PendingHint: &PendingHint{ + Prompt: "请选择 provider", + HintType: "question", + }, + }) + a.savePendingProposalSession(PendingProposalSession{ + UserID: userID, + SourceUserText: "帮我配置模型", + ProposalText: "推荐 claw402,你要继续吗?", + }) + a.saveSetupState(userID, &SetupState{ + Step: "await_ai_model", + AIProvider: "claw402", + }) + if err := st.SetSystemConfig(skillSessionConfigKey(userID), `{"name":"model_management","action":"create"}`); err != nil { + t.Fatalf("seed skill session: %v", err) + } + a.saveWorkflowSession(userID, WorkflowSession{ + Tasks: []WorkflowTask{{ + ID: "task_1", + Skill: "model_management", + Action: "create", + Request: "帮我配置模型", + Status: workflowTaskPending, + }}, + }) + if err := st.SetSystemConfig(ExecutionStateConfigKey(userID), `{"user_id":42,"session_id":"exec_1"}`); err != nil { + t.Fatalf("seed execution state: %v", err) + } + a.saveReferenceMemory(userID, &CurrentReferences{ + Model: &EntityReference{ID: "m1", Name: "claw402", Source: "context"}, + }, nil) + a.SnapshotManager(userID).Save(SuspendedTask{ResumeHint: "旧任务"}) + + reply, err := a.HandleMessage(context.Background(), userID, "/clear") + if err != nil { + t.Fatalf("clear returned error: %v", err) + } + if reply == "" { + t.Fatalf("expected clear reply") + } + + if got := a.history.Get(userID); len(got) != 0 { + t.Fatalf("history not cleared: %+v", got) + } + if got := a.buildRecentConversationContext(userID, "你好"); got != "" { + t.Fatalf("recent conversation context not cleared: %q", got) + } + if got := a.currentPendingHintText(userID); got != "" { + t.Fatalf("pending hint not cleared: %q", got) + } + if got := a.buildCurrentTurnContext(userID, "zh", "你好"); got != "" { + if strings.Contains(got, "Previous assistant reply:") || strings.Contains(got, "Recent conversation:") { + t.Fatalf("current turn context still contains prior chat memory: %q", got) + } + } + if got := a.buildActiveTaskStateContext(userID, "zh"); got != "" { + t.Fatalf("active task state context not cleared: %q", got) + } + if state := a.getTaskState(userID); state.CurrentGoal != "" || state.ActiveFlow != "" { + t.Fatalf("task state not cleared: %+v", state) + } + if _, ok := a.getActiveSkillSession(userID); ok { + t.Fatalf("active skill session not cleared") + } + if _, ok := a.getPendingProposalSession(userID); ok { + t.Fatalf("pending proposal session not cleared") + } + if session := a.getSkillSession(userID); session.Name != "" { + t.Fatalf("legacy skill session not cleared: %+v", session) + } + if session := a.getWorkflowSession(userID); len(session.Tasks) != 0 { + t.Fatalf("workflow session not cleared: %+v", session) + } + if state := a.getExecutionState(userID); state.SessionID != "" { + t.Fatalf("execution state not cleared: %+v", state) + } + if memory := a.getReferenceMemory(userID); memory.CurrentReferences != nil || len(memory.ReferenceHistory) != 0 { + t.Fatalf("reference memory not cleared: %+v", memory) + } + if stack := a.SnapshotManager(userID).List(); len(stack) != 0 { + t.Fatalf("snapshots not cleared: %+v", stack) + } + if setup := a.getSetupState(userID); setup.Step != "" || setup.AIProvider != "" { + t.Fatalf("setup state not cleared: %+v", setup) + } +} diff --git a/agent/config_tools_test.go b/agent/config_tools_test.go deleted file mode 100644 index 7d6d89ef..00000000 --- a/agent/config_tools_test.go +++ /dev/null @@ -1,387 +0,0 @@ -package agent - -import ( - "encoding/json" - "path/filepath" - "strings" - "testing" - - "nofx/store" -) - -func newTestAgentWithStore(t *testing.T) *Agent { - t.Helper() - st, err := store.New(filepath.Join(t.TempDir(), "test.db")) - if err != nil { - t.Fatalf("create test store: %v", err) - } - t.Cleanup(func() { - _ = st.Close() - }) - return &Agent{store: st} -} - -func TestToolManageExchangeConfigLifecycle(t *testing.T) { - a := newTestAgentWithStore(t) - - createResp := a.toolManageExchangeConfig("user-1", `{ - "action":"create", - "exchange_type":"binance", - "account_name":"Main", - "enabled":true, - "testnet":true - }`) - - var created struct { - Status string `json:"status"` - Action string `json:"action"` - Exchange safeExchangeToolConfig `json:"exchange"` - } - if err := json.Unmarshal([]byte(createResp), &created); err != nil { - t.Fatalf("unmarshal create response: %v\nraw=%s", err, createResp) - } - if created.Status != "ok" || created.Action != "create" { - t.Fatalf("unexpected create response: %+v", created) - } - if created.Exchange.AccountName != "Main" || created.Exchange.ExchangeType != "binance" { - t.Fatalf("unexpected exchange payload: %+v", created.Exchange) - } - - updateResp := a.toolManageExchangeConfig("user-1", `{ - "action":"update", - "exchange_id":"`+created.Exchange.ID+`", - "account_name":"Renamed", - "enabled":false - }`) - var updated struct { - Status string `json:"status"` - Action string `json:"action"` - Exchange safeExchangeToolConfig `json:"exchange"` - } - if err := json.Unmarshal([]byte(updateResp), &updated); err != nil { - t.Fatalf("unmarshal update response: %v\nraw=%s", err, updateResp) - } - if updated.Exchange.AccountName != "Renamed" || updated.Exchange.Enabled { - t.Fatalf("unexpected updated exchange payload: %+v", updated.Exchange) - } - - deleteResp := a.toolManageExchangeConfig("user-1", `{ - "action":"delete", - "exchange_id":"`+created.Exchange.ID+`" - }`) - var deleted map[string]any - if err := json.Unmarshal([]byte(deleteResp), &deleted); err != nil { - t.Fatalf("unmarshal delete response: %v\nraw=%s", err, deleteResp) - } - if deleted["status"] != "ok" || deleted["action"] != "delete" { - t.Fatalf("unexpected delete response: %+v", deleted) - } -} - -func TestToolManageModelConfigLifecycle(t *testing.T) { - a := newTestAgentWithStore(t) - - createResp := a.toolManageModelConfig("user-1", `{ - "action":"create", - "provider":"openai", - "enabled":true, - "api_key":"sk-test", - "custom_api_url":"https://api.openai.com/v1", - "custom_model_name":"gpt-5-mini" - }`) - - var created struct { - Status string `json:"status"` - Action string `json:"action"` - Model safeModelToolConfig `json:"model"` - } - if err := json.Unmarshal([]byte(createResp), &created); err != nil { - t.Fatalf("unmarshal create response: %v\nraw=%s", err, createResp) - } - if created.Status != "ok" || created.Action != "create" { - t.Fatalf("unexpected create response: %+v", created) - } - if created.Model.Provider != "openai" || created.Model.CustomModelName != "gpt-5-mini" { - t.Fatalf("unexpected model payload: %+v", created.Model) - } - - updateResp := a.toolManageModelConfig("user-1", `{ - "action":"update", - "model_id":"`+created.Model.ID+`", - "enabled":false, - "custom_model_name":"gpt-5" - }`) - var updated struct { - Status string `json:"status"` - Action string `json:"action"` - Model safeModelToolConfig `json:"model"` - } - if err := json.Unmarshal([]byte(updateResp), &updated); err != nil { - t.Fatalf("unmarshal update response: %v\nraw=%s", err, updateResp) - } - if updated.Model.Enabled || updated.Model.CustomModelName != "gpt-5" { - t.Fatalf("unexpected updated model payload: %+v", updated.Model) - } - - deleteResp := a.toolManageModelConfig("user-1", `{ - "action":"delete", - "model_id":"`+created.Model.ID+`" - }`) - var deleted map[string]any - if err := json.Unmarshal([]byte(deleteResp), &deleted); err != nil { - t.Fatalf("unmarshal delete response: %v\nraw=%s", err, deleteResp) - } - if deleted["status"] != "ok" || deleted["action"] != "delete" { - t.Fatalf("unexpected delete response: %+v", deleted) - } -} - -func TestToolManageModelConfigRejectsEnableWithoutAPIKey(t *testing.T) { - a := newTestAgentWithStore(t) - - createResp := a.toolManageModelConfig("user-1", `{ - "action":"create", - "provider":"openai", - "enabled":false, - "custom_model_name":"gpt-4o" - }`) - var created struct { - Model safeModelToolConfig `json:"model"` - } - if err := json.Unmarshal([]byte(createResp), &created); err != nil { - t.Fatalf("unmarshal create response: %v\nraw=%s", err, createResp) - } - - updateResp := a.toolManageModelConfig("user-1", `{ - "action":"update", - "model_id":"`+created.Model.ID+`", - "enabled":true - }`) - if !strings.Contains(updateResp, "cannot enable model config before API key is configured") { - t.Fatalf("expected enabling incomplete model to fail, got %s", updateResp) - } -} - -func TestGetDefaultSkipsEnabledModelWithoutAPIKey(t *testing.T) { - a := newTestAgentWithStore(t) - - incompleteCreate := a.toolManageModelConfig("user-1", `{ - "action":"create", - "provider":"openai", - "enabled":true, - "custom_model_name":"gpt-4o" - }`) - var incomplete struct { - Model safeModelToolConfig `json:"model"` - } - if err := json.Unmarshal([]byte(incompleteCreate), &incomplete); err != nil { - t.Fatalf("unmarshal incomplete create response: %v\nraw=%s", err, incompleteCreate) - } - - completeCreate := a.toolManageModelConfig("user-1", `{ - "action":"create", - "provider":"deepseek", - "enabled":true, - "api_key":"sk-test", - "custom_model_name":"deepseek-chat" - }`) - var complete struct { - Model safeModelToolConfig `json:"model"` - } - if err := json.Unmarshal([]byte(completeCreate), &complete); err != nil { - t.Fatalf("unmarshal complete create response: %v\nraw=%s", err, completeCreate) - } - - model, err := a.store.AIModel().GetDefault("user-1") - if err != nil { - t.Fatalf("GetDefault() error = %v", err) - } - if model.ID != complete.Model.ID { - t.Fatalf("expected GetDefault to skip incomplete enabled model and return %s, got %s", complete.Model.ID, model.ID) - } -} - -func TestToolManageTraderLifecycle(t *testing.T) { - a := newTestAgentWithStore(t) - - modelResp := a.toolManageModelConfig("user-1", `{ - "action":"create", - "provider":"openai", - "enabled":true, - "api_key":"sk-test", - "custom_api_url":"https://api.openai.com/v1", - "custom_model_name":"gpt-5-mini" - }`) - var modelCreated struct { - Model safeModelToolConfig `json:"model"` - } - if err := json.Unmarshal([]byte(modelResp), &modelCreated); err != nil { - t.Fatalf("unmarshal model response: %v", err) - } - - exchangeResp := a.toolManageExchangeConfig("user-1", `{ - "action":"create", - "exchange_type":"binance", - "account_name":"Main", - "enabled":true - }`) - var exchangeCreated struct { - Exchange safeExchangeToolConfig `json:"exchange"` - } - if err := json.Unmarshal([]byte(exchangeResp), &exchangeCreated); err != nil { - t.Fatalf("unmarshal exchange response: %v", err) - } - - createResp := a.toolManageTrader("user-1", `{ - "action":"create", - "name":"Momentum Trader", - "ai_model_id":"`+modelCreated.Model.ID+`", - "exchange_id":"`+exchangeCreated.Exchange.ID+`", - "scan_interval_minutes":5 - }`) - var created struct { - Status string `json:"status"` - Action string `json:"action"` - Trader safeTraderToolConfig `json:"trader"` - } - if err := json.Unmarshal([]byte(createResp), &created); err != nil { - t.Fatalf("unmarshal create trader response: %v\nraw=%s", err, createResp) - } - if created.Status != "ok" || created.Action != "create" { - t.Fatalf("unexpected create trader response: %+v", created) - } - if created.Trader.Name != "Momentum Trader" || created.Trader.ScanIntervalMinutes != 5 { - t.Fatalf("unexpected created trader: %+v", created.Trader) - } - - listResp := a.toolManageTrader("user-1", `{"action":"list"}`) - var listed struct { - Count int `json:"count"` - Traders []safeTraderToolConfig `json:"traders"` - } - if err := json.Unmarshal([]byte(listResp), &listed); err != nil { - t.Fatalf("unmarshal list response: %v\nraw=%s", err, listResp) - } - if listed.Count != 1 || len(listed.Traders) != 1 { - t.Fatalf("unexpected trader list: %+v", listed) - } - - updateResp := a.toolManageTrader("user-1", `{ - "action":"update", - "trader_id":"`+created.Trader.ID+`", - "name":"Renamed Trader", - "scan_interval_minutes":8 - }`) - var updated struct { - Status string `json:"status"` - Action string `json:"action"` - Trader safeTraderToolConfig `json:"trader"` - } - if err := json.Unmarshal([]byte(updateResp), &updated); err != nil { - t.Fatalf("unmarshal update trader response: %v\nraw=%s", err, updateResp) - } - if updated.Trader.Name != "Renamed Trader" || updated.Trader.ScanIntervalMinutes != 8 { - t.Fatalf("unexpected updated trader: %+v", updated.Trader) - } - - deleteResp := a.toolManageTrader("user-1", `{ - "action":"delete", - "trader_id":"`+created.Trader.ID+`" - }`) - var deleted map[string]any - if err := json.Unmarshal([]byte(deleteResp), &deleted); err != nil { - t.Fatalf("unmarshal delete trader response: %v\nraw=%s", err, deleteResp) - } - if deleted["status"] != "ok" || deleted["action"] != "delete" { - t.Fatalf("unexpected delete trader response: %+v", deleted) - } -} - -func TestToolManageStrategyLifecycle(t *testing.T) { - a := newTestAgentWithStore(t) - - createResp := a.toolManageStrategy("user-1", `{ - "action":"create", - "name":"激进", - "description":"激进策略模板", - "lang":"zh" - }`) - - var created struct { - Status string `json:"status"` - Action string `json:"action"` - Strategy safeStrategyToolConfig `json:"strategy"` - } - if err := json.Unmarshal([]byte(createResp), &created); err != nil { - t.Fatalf("unmarshal create response: %v\nraw=%s", err, createResp) - } - if created.Status != "ok" || created.Action != "create" { - t.Fatalf("unexpected create response: %+v", created) - } - if created.Strategy.Name != "激进" { - t.Fatalf("unexpected strategy payload: %+v", created.Strategy) - } - - listResp := a.toolGetStrategies("user-1") - if !strings.Contains(listResp, "激进") { - t.Fatalf("expected created strategy in list, got %s", listResp) - } - - updateResp := a.toolManageStrategy("user-1", `{ - "action":"update", - "strategy_id":"`+created.Strategy.ID+`", - "description":"更新后的描述" - }`) - var updated struct { - Status string `json:"status"` - Action string `json:"action"` - Strategy safeStrategyToolConfig `json:"strategy"` - } - if err := json.Unmarshal([]byte(updateResp), &updated); err != nil { - t.Fatalf("unmarshal update response: %v\nraw=%s", err, updateResp) - } - if updated.Strategy.Description != "更新后的描述" { - t.Fatalf("unexpected updated strategy payload: %+v", updated.Strategy) - } - - activateResp := a.toolManageStrategy("user-1", `{ - "action":"activate", - "strategy_id":"`+created.Strategy.ID+`" - }`) - if !strings.Contains(activateResp, `"action":"activate"`) { - t.Fatalf("unexpected activate response: %s", activateResp) - } - - deleteResp := a.toolManageStrategy("user-1", `{ - "action":"delete", - "strategy_id":"`+created.Strategy.ID+`" - }`) - if !strings.Contains(deleteResp, `"action":"delete"`) { - t.Fatalf("unexpected delete response: %s", deleteResp) - } -} - -func TestLoadAIClientFromStoreUserUsesUserSpecificEnabledModel(t *testing.T) { - a := newTestAgentWithStore(t) - - if err := a.store.AIModel().Update("user-42", "openai", true, "sk-test", "https://api.openai.com/v1", "gpt-5-mini"); err != nil { - t.Fatalf("seed model: %v", err) - } - - client, modelName, ok := a.loadAIClientFromStoreUser("user-42") - if !ok { - t.Fatal("expected AI client to load from user-specific model") - } - if client == nil { - t.Fatal("expected non-nil AI client") - } - if modelName != "gpt-5-mini" { - t.Fatalf("unexpected model name: %s", modelName) - } - - // After the provider registry refactor, registered providers (like openai) - // return their own AIClient implementation, not *mcp.Client. - if client == nil { - t.Fatal("expected non-nil AI client from provider registry") - } -} diff --git a/agent/config_validation.go b/agent/config_validation.go new file mode 100644 index 00000000..7c443caa --- /dev/null +++ b/agent/config_validation.go @@ -0,0 +1,466 @@ +package agent + +import ( + "encoding/json" + "fmt" + "regexp" + "strings" + + "nofx/security" + "nofx/store" +) + +type ConfigValidationResult struct { + Warnings []string +} + +type ConfigValidator interface { + Validate() error +} + +var ( + openAIAPIKeyPattern = regexp.MustCompile(`^sk-[A-Za-z0-9\-_]{4,}$`) + genericAPIKeyPattern = regexp.MustCompile(`^[A-Za-z0-9_\-]{8,}$`) + hexCredentialPattern = regexp.MustCompile(`^(0x)?[A-Fa-f0-9]{16,}$`) + supportedModelProvider = map[string]struct{}{ + "openai": {}, "deepseek": {}, "claude": {}, "gemini": {}, "qwen": {}, "kimi": {}, "grok": {}, "minimax": {}, "claw402": {}, "blockrun-base": {}, "blockrun-sol": {}, + } +) + +const ( + manualTraderScanIntervalMin = 3 + manualTraderScanIntervalMax = 60 + manualTraderInitialBalance = 100.0 + manualLighterAPIKeyIndexMin = 0 + manualLighterAPIKeyIndexMax = 255 +) + +type modelConfigValidator struct { + provider string + enabled bool + apiKey string + customAPIURL string + customModelName string + modelID string +} + +func (v modelConfigValidator) Validate() error { + provider := strings.ToLower(strings.TrimSpace(v.provider)) + if provider == "" { + return fmt.Errorf("provider is required") + } + if _, ok := supportedModelProvider[provider]; !ok { + return fmt.Errorf("unsupported provider: %s", provider) + } + if trimmed := strings.TrimSpace(v.customAPIURL); trimmed != "" { + if err := security.ValidateURL(strings.TrimSuffix(trimmed, "#")); err != nil { + return fmt.Errorf("invalid custom_api_url: %w", err) + } + } + if v.enabled && !modelConfigUsable(provider, v.modelID, strings.TrimSpace(v.apiKey), strings.TrimSpace(v.customAPIURL), strings.TrimSpace(v.customModelName)) { + return fmt.Errorf("cannot enable model config before a usable API key, URL, and model are configured") + } + if provider == "openai" && strings.TrimSpace(v.apiKey) != "" && !openAIAPIKeyPattern.MatchString(strings.TrimSpace(v.apiKey)) { + return fmt.Errorf("OpenAI API Key format looks invalid") + } + return nil +} + +type exchangeConfigValidator struct { + exchangeType string + enabled bool + apiKey string + secretKey string + passphrase string + hyperliquidWalletAddr string + asterUser string + asterSigner string + asterPrivateKey string + lighterWalletAddr string + lighterPrivateKey string + lighterAPIKeyPrivateKey string +} + +func (v exchangeConfigValidator) Validate() error { + exchangeType := strings.ToLower(strings.TrimSpace(v.exchangeType)) + if exchangeType == "" { + return fmt.Errorf("exchange_type is required") + } + if trimmed := strings.TrimSpace(v.apiKey); trimmed != "" && !genericAPIKeyPattern.MatchString(trimmed) { + return fmt.Errorf("API Key format looks invalid") + } + if trimmed := strings.TrimSpace(v.secretKey); trimmed != "" && !genericAPIKeyPattern.MatchString(trimmed) && !hexCredentialPattern.MatchString(trimmed) { + return fmt.Errorf("Secret format looks invalid") + } + if v.enabled { + missing := store.MissingRequiredExchangeCredentialFields( + exchangeType, + v.apiKey, + v.secretKey, + v.passphrase, + v.hyperliquidWalletAddr, + v.asterUser, + v.asterSigner, + v.asterPrivateKey, + v.lighterWalletAddr, + v.lighterAPIKeyPrivateKey, + ) + if len(missing) > 0 { + return fmt.Errorf("cannot enable exchange config before required fields are complete: %s", strings.Join(missing, ", ")) + } + } + return nil +} + +type traderBindingValidator struct { + store *store.Store + storeUserID string + aiModelID string + exchangeID string + strategyID string +} + +func (v traderBindingValidator) Validate() error { + if v.store == nil { + return fmt.Errorf("store unavailable") + } + if strings.TrimSpace(v.aiModelID) == "" { + return fmt.Errorf("ai_model_id is required") + } + if strings.TrimSpace(v.exchangeID) == "" { + return fmt.Errorf("exchange_id is required") + } + model, err := v.store.AIModel().Get(v.storeUserID, strings.TrimSpace(v.aiModelID)) + if err != nil { + return fmt.Errorf("invalid ai_model_id: %w", err) + } + if !model.Enabled { + return fmt.Errorf("ai model is disabled") + } + if !modelConfigUsable(model.Provider, model.ID, strings.TrimSpace(string(model.APIKey)), strings.TrimSpace(model.CustomAPIURL), strings.TrimSpace(model.CustomModelName)) { + return fmt.Errorf("ai model config is incomplete") + } + exchange, err := v.store.Exchange().GetByID(v.storeUserID, strings.TrimSpace(v.exchangeID)) + if err != nil { + return fmt.Errorf("invalid exchange_id: %w", err) + } + if !exchange.Enabled { + return fmt.Errorf("exchange is disabled") + } + if err := (exchangeConfigValidator{ + exchangeType: exchange.ExchangeType, + enabled: exchange.Enabled, + apiKey: strings.TrimSpace(string(exchange.APIKey)), + secretKey: strings.TrimSpace(string(exchange.SecretKey)), + passphrase: strings.TrimSpace(string(exchange.Passphrase)), + hyperliquidWalletAddr: exchange.HyperliquidWalletAddr, + asterUser: exchange.AsterUser, + asterSigner: exchange.AsterSigner, + asterPrivateKey: strings.TrimSpace(string(exchange.AsterPrivateKey)), + lighterWalletAddr: exchange.LighterWalletAddr, + lighterPrivateKey: strings.TrimSpace(string(exchange.LighterPrivateKey)), + lighterAPIKeyPrivateKey: strings.TrimSpace(string(exchange.LighterAPIKeyPrivateKey)), + }).Validate(); err != nil { + return fmt.Errorf("exchange config is incomplete: %w", err) + } + if trimmed := strings.TrimSpace(v.strategyID); trimmed != "" { + if _, err := v.store.Strategy().Get(v.storeUserID, trimmed); err != nil { + return fmt.Errorf("invalid strategy_id: %w", err) + } + } + return nil +} + +func (a *Agent) validateModelDraft(storeUserID, modelID, provider string, enabled bool, apiKey, customAPIURL, customModelName string) error { + if a == nil || a.store == nil { + return fmt.Errorf("store unavailable") + } + if strings.TrimSpace(provider) == "" && strings.TrimSpace(modelID) != "" { + model, err := a.store.AIModel().Get(storeUserID, strings.TrimSpace(modelID)) + if err != nil { + return err + } + provider = model.Provider + if strings.TrimSpace(apiKey) == "" { + apiKey = strings.TrimSpace(string(model.APIKey)) + } + if strings.TrimSpace(customAPIURL) == "" { + customAPIURL = strings.TrimSpace(model.CustomAPIURL) + } + if strings.TrimSpace(customModelName) == "" { + customModelName = strings.TrimSpace(model.CustomModelName) + } + } + return (modelConfigValidator{ + provider: provider, + enabled: enabled, + apiKey: apiKey, + customAPIURL: customAPIURL, + customModelName: customModelName, + modelID: modelID, + }).Validate() +} + +func (a *Agent) validateExchangeDraft(storeUserID, exchangeID, exchangeType string, enabled bool, apiKey, secretKey, passphrase, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey, lighterWalletAddr, lighterAPIKeyPrivateKey string) error { + if a == nil || a.store == nil { + return fmt.Errorf("store unavailable") + } + if strings.TrimSpace(exchangeType) == "" && strings.TrimSpace(exchangeID) != "" { + exchange, err := a.store.Exchange().GetByID(storeUserID, strings.TrimSpace(exchangeID)) + if err != nil { + return err + } + exchangeType = exchange.ExchangeType + if strings.TrimSpace(apiKey) == "" { + apiKey = strings.TrimSpace(string(exchange.APIKey)) + } + if strings.TrimSpace(secretKey) == "" { + secretKey = strings.TrimSpace(string(exchange.SecretKey)) + } + if strings.TrimSpace(passphrase) == "" { + passphrase = strings.TrimSpace(string(exchange.Passphrase)) + } + if strings.TrimSpace(hyperliquidWalletAddr) == "" { + hyperliquidWalletAddr = strings.TrimSpace(exchange.HyperliquidWalletAddr) + } + if strings.TrimSpace(asterUser) == "" { + asterUser = strings.TrimSpace(exchange.AsterUser) + } + if strings.TrimSpace(asterSigner) == "" { + asterSigner = strings.TrimSpace(exchange.AsterSigner) + } + if strings.TrimSpace(asterPrivateKey) == "" { + asterPrivateKey = strings.TrimSpace(string(exchange.AsterPrivateKey)) + } + if strings.TrimSpace(lighterWalletAddr) == "" { + lighterWalletAddr = strings.TrimSpace(exchange.LighterWalletAddr) + } + if strings.TrimSpace(lighterAPIKeyPrivateKey) == "" { + lighterAPIKeyPrivateKey = strings.TrimSpace(string(exchange.LighterAPIKeyPrivateKey)) + } + } + return (exchangeConfigValidator{ + exchangeType: exchangeType, + enabled: enabled, + apiKey: apiKey, + secretKey: secretKey, + passphrase: passphrase, + hyperliquidWalletAddr: hyperliquidWalletAddr, + asterUser: asterUser, + asterSigner: asterSigner, + asterPrivateKey: asterPrivateKey, + lighterWalletAddr: lighterWalletAddr, + lighterAPIKeyPrivateKey: lighterAPIKeyPrivateKey, + }).Validate() +} + +func (a *Agent) validateTraderDraft(storeUserID, aiModelID, exchangeID, strategyID string) error { + return (traderBindingValidator{ + store: a.store, + storeUserID: storeUserID, + aiModelID: aiModelID, + exchangeID: exchangeID, + strategyID: strategyID, + }).Validate() +} + +func formatValidationFeedback(lang, domain string, err error) string { + if err == nil { + return "" + } + raw := strings.TrimSpace(err.Error()) + lower := strings.ToLower(raw) + if lang == "zh" { + switch { + case strings.Contains(lower, "openai api key format looks invalid"): + return "这份配置还有问题:API Key 格式不对。OpenAI 的 API Key 通常以 `sk-` 开头,请直接发完整 Key,我继续帮你补进当前草稿。" + case strings.Contains(lower, "api key format looks invalid"): + return "这份配置还有问题:API Key 格式不对。请直接发完整的 API Key,不要附带多余说明文字。" + case strings.Contains(lower, "secret format looks invalid"): + return "这份配置还有问题:Secret 格式不对。请直接发完整的 Secret 值,不要和 API Key 填反。" + case strings.Contains(lower, "okx requires passphrase"): + return "这份配置还有问题:OKX 账户缺少 Passphrase,启用前需要补齐。你直接把 Passphrase 发我就行。" + case strings.Contains(lower, "hyperliquid requires wallet address"): + return "这份配置还有问题:Hyperliquid 账户缺少钱包地址,启用前需要补齐。" + case strings.Contains(lower, "aster requires user, signer, and private key"): + return "这份配置还有问题:Aster 账户还缺 user、signer 和 private key,启用前需要补齐。" + case strings.Contains(lower, "lighter requires wallet address and api key private key"): + return "这份配置还有问题:Lighter 账户还缺钱包地址和 API key private key,启用前需要补齐。" + case strings.Contains(lower, "cannot enable model config before a usable api key, url, and model are configured"): + return "这份配置还有问题:要先把 API Key、接口地址和模型名称配完整,才能启用。你可以继续把缺的字段发给我。" + case strings.Contains(lower, "unsupported provider"): + return "这份配置还有问题:provider 不在支持范围内。请从 OpenAI、DeepSeek、Claude、Gemini、Qwen、Kimi、Grok、Minimax 里选一个。" + case strings.Contains(lower, "invalid custom_api_url"): + return "这份配置还有问题:接口地址格式不对。请给我完整的 URL,或直接说使用默认地址。" + case strings.Contains(lower, "ai model is disabled"): + return "这份配置还有问题:绑定的模型当前是禁用状态。请换一个已启用模型,或先启用这个模型。" + case strings.Contains(lower, "exchange is disabled"): + return "这份配置还有问题:绑定的交易所当前已禁用。请换一个已启用交易所,或先启用这个交易所。" + case strings.Contains(lower, "ai model config is incomplete"): + return "这份配置还有问题:绑定的模型配置还没补完整,暂时不能使用。" + case strings.Contains(lower, "invalid ai_model_id"): + return "这份配置还有问题:模型引用无效。请明确告诉我你要绑定哪个模型。" + case strings.Contains(lower, "invalid exchange_id"): + return "这份配置还有问题:交易所引用无效。请明确告诉我你要绑定哪个交易所。" + case strings.Contains(lower, "invalid strategy_id"): + return "这份配置还有问题:策略引用无效。请明确告诉我你要绑定哪个策略。" + case strings.Contains(lower, "provider is required"): + return "这份配置还缺 provider。请先告诉我你要用哪个模型提供商。" + case strings.Contains(lower, "exchange_type is required"): + return "这份配置还缺交易所类型。请先告诉我你要接哪个交易所。" + } + switch domain { + case "model": + return "这份模型草稿还有问题:" + raw + case "exchange": + return "这份交易所草稿还有问题:" + raw + case "trader": + return "这份交易员草稿还有问题:" + raw + case "strategy": + return "这份策略草稿还有问题:" + raw + default: + return "这份配置还有问题:" + raw + } + } + + switch { + case strings.Contains(lower, "openai api key format looks invalid"): + return "This draft still has an issue: the API key format looks wrong. OpenAI keys usually start with `sk-`. Send the full key and I'll keep filling the draft." + case strings.Contains(lower, "api key format looks invalid"): + return "This draft still has an issue: the API key format looks wrong. Send the full API key directly." + case strings.Contains(lower, "secret format looks invalid"): + return "This draft still has an issue: the secret format looks wrong. Send the full secret value directly." + case strings.Contains(lower, "okx requires passphrase"): + return "This draft still has an issue: an OKX config needs a passphrase before it can be enabled. Send the passphrase and I'll keep going." + case strings.Contains(lower, "cannot enable model config before a usable api key, url, and model are configured"): + return "This draft still has an issue: the API key, endpoint URL, and model name must be completed before the config can be enabled." + } + switch domain { + case "model": + return "This model draft still has an issue: " + raw + case "exchange": + return "This exchange draft still has an issue: " + raw + case "trader": + return "This trader draft still has an issue: " + raw + case "strategy": + return "This strategy draft still has an issue: " + raw + default: + return "This draft still has an issue: " + raw + } +} + +func normalizeTraderArgsToManualLimits(lang string, args traderUpdateArgs) (traderUpdateArgs, []string) { + warnings := make([]string, 0, 2) + if args.ScanIntervalMinutes != nil { + requested := *args.ScanIntervalMinutes + normalized := requested + if normalized < manualTraderScanIntervalMin { + normalized = manualTraderScanIntervalMin + } + if normalized > manualTraderScanIntervalMax { + normalized = manualTraderScanIntervalMax + } + if normalized != requested { + args.ScanIntervalMinutes = &normalized + if lang == "zh" { + warnings = append(warnings, fmt.Sprintf("扫描间隔手动可配置范围是 %d 到 %d 分钟,已从 %d 调整为 %d", manualTraderScanIntervalMin, manualTraderScanIntervalMax, requested, normalized)) + } else { + warnings = append(warnings, fmt.Sprintf("scan interval is limited to %d-%d minutes in the manual config, adjusted from %d to %d", manualTraderScanIntervalMin, manualTraderScanIntervalMax, requested, normalized)) + } + } + } + return args, warnings +} + +func formatRiskControlAcceptancePrompt(lang string, warnings []string, confirmLabel string) string { + if len(warnings) == 0 { + return "" + } + if lang == "zh" { + lines := []string{ + "这些配置超出了手动面板允许的范围,我已经先按风控范围收敛:", + } + for _, warning := range warnings { + lines = append(lines, "- "+warning) + } + lines = append(lines, fmt.Sprintf("如果接受当前范围,回复“%s”;也可以继续告诉我你想怎么改。", confirmLabel)) + return strings.Join(lines, "\n") + } + lines := []string{ + "Some values were outside the manual editor limits, so I normalized them first:", + } + for _, warning := range warnings { + lines = append(lines, "- "+warning) + } + lines = append(lines, fmt.Sprintf("Reply %q to accept these safe values, or keep refining the draft.", confirmLabel)) + return strings.Join(lines, "\n") +} + +func formatRiskControlRefusalPrompt(lang string, warnings []string, confirmLabel string) string { + if len(warnings) == 0 { + return "" + } + if lang == "zh" { + lines := []string{ + "这些配置超出了手动面板允许的范围,本次不会按你给的原值直接保存:", + } + for _, warning := range warnings { + lines = append(lines, "- "+warning) + } + lines = append(lines, fmt.Sprintf("如果接受当前安全范围,回复“%s”;也可以继续告诉我你想怎么改。", confirmLabel)) + return strings.Join(lines, "\n") + } + lines := []string{ + "Some values were outside the manual editor limits, so I did not save the original request as-is:", + } + for _, warning := range warnings { + lines = append(lines, "- "+warning) + } + lines = append(lines, fmt.Sprintf("Reply %q to accept these safe values, or keep refining the draft.", confirmLabel)) + return strings.Join(lines, "\n") +} + +func marshalStringList(values []string) string { + if len(values) == 0 { + return "" + } + raw, err := json.Marshal(values) + if err != nil { + return "" + } + return string(raw) +} + +func unmarshalStringList(raw string) []string { + if strings.TrimSpace(raw) == "" { + return nil + } + var values []string + if err := json.Unmarshal([]byte(raw), &values); err != nil { + return nil + } + return values +} + +func normalizeExchangePatchToManualLimits(lang string, patch exchangeUpdatePatch) (exchangeUpdatePatch, []string) { + warnings := make([]string, 0, 1) + if patch.LighterAPIKeyIndex != nil { + requested := *patch.LighterAPIKeyIndex + normalized := requested + if normalized < manualLighterAPIKeyIndexMin { + normalized = manualLighterAPIKeyIndexMin + } + if normalized > manualLighterAPIKeyIndexMax { + normalized = manualLighterAPIKeyIndexMax + } + if normalized != requested { + patch.LighterAPIKeyIndex = &normalized + if lang == "zh" { + warnings = append(warnings, fmt.Sprintf("Lighter API Key Index 手动面板范围是 %d 到 %d,已从 %d 调整为 %d", manualLighterAPIKeyIndexMin, manualLighterAPIKeyIndexMax, requested, normalized)) + } else { + warnings = append(warnings, fmt.Sprintf("lighter API key index is limited to %d-%d in the manual editor, adjusted from %d to %d", manualLighterAPIKeyIndexMin, manualLighterAPIKeyIndexMax, requested, normalized)) + } + } + } + return patch, warnings +} diff --git a/agent/config_visibility_test.go b/agent/config_visibility_test.go new file mode 100644 index 00000000..df9666f9 --- /dev/null +++ b/agent/config_visibility_test.go @@ -0,0 +1,692 @@ +package agent + +import ( + "encoding/json" + "log/slog" + "path/filepath" + "strings" + "testing" + + "nofx/store" +) + +func TestToolManageModelConfigCreateRequiresCredential(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "visibility.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + resp := a.toolManageModelConfig("default", `{"action":"create","provider":"deepseek"}`) + if !strings.Contains(resp, `"error":"api_key is required for create"`) { + t.Fatalf("expected missing api_key error, got: %s", resp) + } +} + +func TestToolManageModelConfigCreateDefaultsToEnabledLikeManualPage(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "model-create-enabled.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + resp := a.toolManageModelConfig("default", `{"action":"create","provider":"qwen","name":"qwen","api_key":"sk-test-qwen-123456","custom_model_name":"qwen3-max"}`) + if strings.Contains(resp, `"error"`) { + t.Fatalf("expected create to succeed, got: %s", resp) + } + + model, err := st.AIModel().Get("default", "default_qwen") + if err != nil { + t.Fatalf("load created model: %v", err) + } + if !model.Enabled { + t.Fatalf("expected agent-created model to default to enabled so it matches manual creation") + } +} + +func TestToolManageModelConfigCreateReusesExistingProviderRecord(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "model-create-upsert.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + if err := st.AIModel().UpdateWithName("default", "default_qwen", "qwen1", false, "sk-old-qwen-123456", "", "qwen3-max"); err != nil { + t.Fatalf("seed existing qwen model: %v", err) + } + + resp := a.toolManageModelConfig("default", `{"action":"create","provider":"qwen","name":"Qwen","api_key":"sk-new-qwen-123456","custom_model_name":"qwen3-max"}`) + if strings.Contains(resp, `"error"`) { + t.Fatalf("expected create to reuse existing qwen config instead of failing, got: %s", resp) + } + + models, err := st.AIModel().List("default") + if err != nil { + t.Fatalf("list models: %v", err) + } + qwenCount := 0 + for _, model := range models { + if model != nil && model.Provider == "qwen" { + qwenCount++ + if model.ID != "default_qwen" { + t.Fatalf("expected existing qwen record to be reused, got model id %q", model.ID) + } + if model.Name != "Qwen" { + t.Fatalf("expected reused qwen record to be renamed, got %q", model.Name) + } + if !model.Enabled { + t.Fatalf("expected reused qwen record to be enabled after agent create") + } + } + } + if qwenCount != 1 { + t.Fatalf("expected exactly one qwen record after reuse, got %d", qwenCount) + } +} + +func TestToolManageExchangeConfigCreateDefaultsToEnabledLikeManualPage(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "exchange-create-enabled.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + resp := a.toolManageExchangeConfig("default", `{"action":"create","exchange_type":"binance","account_name":"Binance Main","api_key":"api-test-123456","secret_key":"secret-test-123456"}`) + if strings.Contains(resp, `"error"`) { + t.Fatalf("expected create to succeed, got: %s", resp) + } + + exchanges, err := st.Exchange().List("default") + if err != nil { + t.Fatalf("list exchanges: %v", err) + } + if len(exchanges) != 1 || exchanges[0] == nil { + t.Fatalf("expected one created exchange, got %#v", exchanges) + } + if !exchanges[0].Enabled { + t.Fatalf("expected agent-created exchange to default to enabled so it matches manual creation") + } +} + +func TestToolManageExchangeConfigCreateRejectsIncompleteDraft(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "exchange-create-incomplete.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + resp := a.toolManageExchangeConfig("default", `{"action":"create","exchange_type":"okx","account_name":"OKX Main","api_key":"api-test-123456","secret_key":"secret-test-123456"}`) + if !strings.Contains(resp, `"error"`) || !strings.Contains(resp, "passphrase") { + t.Fatalf("expected incomplete create to be rejected with missing passphrase, got: %s", resp) + } + + exchanges, err := st.Exchange().List("default") + if err != nil { + t.Fatalf("list exchanges: %v", err) + } + if len(exchanges) != 0 { + t.Fatalf("expected incomplete exchange not to be persisted, got %#v", exchanges) + } +} + +func TestToolGetModelConfigsHidesIncompleteRows(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "visibility-list.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + if err := st.AIModel().UpdateWithName("default", "default_openai", "OpenAI", false, "", "", ""); err != nil { + t.Fatalf("seed incomplete model: %v", err) + } + if err := st.AIModel().UpdateWithName("default", "default_deepseek", "DeepSeek", false, "sk-test-12345", "", "deepseek-chat"); err != nil { + t.Fatalf("seed configured model: %v", err) + } + + resp := a.toolGetModelConfigs("default") + if strings.Contains(resp, `"id":"default_openai"`) { + t.Fatalf("incomplete model should be hidden from tool query: %s", resp) + } + if !strings.Contains(resp, `"id":"default_deepseek"`) { + t.Fatalf("configured model should remain visible: %s", resp) + } +} + +func TestToolManageStrategyUpdateRejectsOutOfRangeLeverageBeforeSave(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "strategy-risk-guard.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + cfg := store.GetDefaultStrategyConfig("zh") + rawCfg, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("marshal strategy config: %v", err) + } + strategy := &store.Strategy{ + ID: "strategy-risk-guard", + UserID: "default", + Name: "AI500稳重策略", + Description: "test", + IsPublic: false, + ConfigVisible: true, + Config: string(rawCfg), + } + if err := st.Strategy().Create(strategy); err != nil { + t.Fatalf("create strategy: %v", err) + } + + resp := a.toolManageStrategy("default", `{"action":"update","strategy_id":"strategy-risk-guard","config":{"risk_control":{"btc_eth_max_leverage":100,"altcoin_max_leverage":100}}}`) + if !strings.Contains(resp, `不会按你给的原值直接保存`) { + t.Fatalf("expected out-of-range leverage update to be rejected before save, got: %s", resp) + } + + updated, err := st.Strategy().Get("default", strategy.ID) + if err != nil { + t.Fatalf("reload strategy: %v", err) + } + parsed, err := updated.ParseConfig() + if err != nil { + t.Fatalf("parse updated strategy config: %v", err) + } + if parsed.RiskControl.BTCETHMaxLeverage != 5 || parsed.RiskControl.AltcoinMaxLeverage != 5 { + t.Fatalf("expected stored leverage to remain unchanged at safe defaults, got btc_eth=%d alt=%d", parsed.RiskControl.BTCETHMaxLeverage, parsed.RiskControl.AltcoinMaxLeverage) + } +} + +func TestToolManageStrategyRejectsFixedMinPositionSizeUpdates(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "strategy-fixed-min-position.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + cfg := store.GetDefaultStrategyConfig("zh") + rawCfg, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("marshal strategy config: %v", err) + } + strategy := &store.Strategy{ + ID: "strategy-fixed-min-position", + UserID: "default", + Name: "固定最小开仓策略", + Description: "test", + IsPublic: false, + ConfigVisible: true, + Config: string(rawCfg), + } + if err := st.Strategy().Create(strategy); err != nil { + t.Fatalf("create strategy: %v", err) + } + + resp := a.toolManageStrategy("default", `{"action":"update","strategy_id":"strategy-fixed-min-position","config":{"risk_control":{"min_position_size":20}}}`) + if !strings.Contains(resp, "固定值 12 USDT") { + t.Fatalf("expected fixed min position size rejection, got: %s", resp) + } + + updated, err := st.Strategy().Get("default", strategy.ID) + if err != nil { + t.Fatalf("reload strategy: %v", err) + } + parsed, err := updated.ParseConfig() + if err != nil { + t.Fatalf("parse updated strategy config: %v", err) + } + if parsed.RiskControl.MinPositionSize != 12 { + t.Fatalf("expected stored min position size to remain fixed at 12, got %v", parsed.RiskControl.MinPositionSize) + } +} + +func TestExchangeSkillOptionSummaryMatchesManualPage(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "exchange-options.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + summary := a.exchangeSkillOptionSummary("zh") + for _, expected := range []string{"Binance", "Bybit", "OKX", "Bitget", "Gate", "KuCoin", "Hyperliquid", "Aster", "Lighter", "Indodax"} { + if !strings.Contains(summary, expected) { + t.Fatalf("expected option %q in summary, got: %s", expected, summary) + } + } + for _, hidden := range []string{"Alpaca", "Forex", "Metals"} { + if strings.Contains(summary, hidden) { + t.Fatalf("did not expect hidden manual-page option %q in summary: %s", hidden, summary) + } + } +} + +func TestLoadExchangeOptionsHidesInvisibleExchangeRows(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "exchange-options-visible.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + if err := store.DB().Create(&store.Exchange{ + ID: "hidden-exchange", + UserID: "default", + ExchangeType: "okx", + AccountName: "123413", + Name: "OKX Futures", + Type: "cex", + Enabled: false, + }).Error; err != nil { + t.Fatalf("seed legacy hidden exchange: %v", err) + } + if _, err := st.Exchange().Create("default", "okx", "我的主力OKX账户", true, "api-test", "secret-test", "pass-test", false, "", false, "", "", "", "", "", "", 0); err != nil { + t.Fatalf("create visible exchange: %v", err) + } + + options := a.loadExchangeOptions("default") + if len(options) != 1 { + t.Fatalf("expected only the visible exchange option, got %+v", options) + } + if options[0].Name != "我的主力OKX账户" { + t.Fatalf("expected visible exchange name, got %+v", options) + } +} + +func TestDescribeExchangeIncludesTypeSpecificVisibleFields(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "exchange-detail.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + hyperID, err := st.Exchange().Create("default", "hyperliquid", "Dex Pro", true, "hyper-api-key", "", "", true, "0xabc", true, "", "", "", "", "", "", 0) + if err != nil { + t.Fatalf("seed hyperliquid exchange: %v", err) + } + detail, ok := a.describeExchange("default", "zh", &EntityReference{ID: hyperID}) + if !ok { + t.Fatal("expected describeExchange to resolve hyperliquid config") + } + for _, expected := range []string{"交易所配置“Dex Pro”详情", "交易所:hyperliquid", "账户名:Dex Pro", "API Key:true", "Hyperliquid 钱包地址:0xabc"} { + if !strings.Contains(detail, expected) { + t.Fatalf("expected hyperliquid detail to contain %q, got: %s", expected, detail) + } + } + + lighterID, err := st.Exchange().Create("default", "lighter", "Lighter Main", false, "", "", "", false, "", true, "", "", "", "wallet-1", "", "lighter-secret", 7) + if err != nil { + t.Fatalf("seed lighter exchange: %v", err) + } + detail, ok = a.describeExchange("default", "zh", &EntityReference{ID: lighterID}) + if !ok { + t.Fatal("expected describeExchange to resolve lighter config") + } + for _, expected := range []string{"交易所:lighter", "Lighter 钱包地址:wallet-1", "Lighter API Key 私钥:true", "Lighter API Key Index:7"} { + if !strings.Contains(detail, expected) { + t.Fatalf("expected lighter detail to contain %q, got: %s", expected, detail) + } + } +} + +func TestSkillVisibleFieldSummaryForExchangeUsesReadableNames(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "exchange-field-summary.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + summary := a.skillVisibleFieldSummary("default", "zh", "exchange_management", "update") + for _, expected := range []string{"交易所类型", "账户名", "API Key", "Secret", "Passphrase", "Hyperliquid 钱包地址", "Aster User", "Lighter API Key 私钥", "Lighter API Key Index"} { + if !strings.Contains(summary, expected) { + t.Fatalf("expected field label %q in summary, got: %s", expected, summary) + } + } + if strings.Contains(summary, "hyperliquid_wallet_addr") || strings.Contains(summary, "lighter_api_key_private_key") { + t.Fatalf("field summary should use readable labels instead of raw keys: %s", summary) + } +} + +func TestSkillVisibleFieldSummaryForStrategyCoversManualPageFields(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "strategy-field-summary.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + summary := a.skillVisibleFieldSummary("default", "zh", "strategy_management", "update_config") + for _, expected := range []string{"发布到市场", "配置可见", "交易对", "杠杆", "主周期", "多周期时间框架", "NofxOS API key", "角色定义", "自定义 Prompt"} { + if !strings.Contains(summary, expected) { + t.Fatalf("expected field label %q in summary, got: %s", expected, summary) + } + } + if strings.Contains(summary, "最小开仓金额") { + t.Fatalf("strategy field summary should not expose fixed min position size editing: %s", summary) + } +} + +func TestStrategyVisibleFieldSummaryUsesTargetStrategyType(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "strategy-type-field-summary.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + cfg := store.GetDefaultStrategyConfig("zh") + cfg.StrategyType = "grid_trading" + cfg.GridConfig = &store.GridStrategyConfig{ + Symbol: "ETHUSDT", + GridCount: 12, + TotalInvestment: 1000, + Leverage: 3, + UseATRBounds: true, + ATRMultiplier: 2, + Distribution: "gaussian", + } + raw, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("marshal strategy config: %v", err) + } + strategy := &store.Strategy{ + ID: "strategy-grid-fields", + UserID: "default", + Name: "我的第一个网格策略", + Description: "", + IsPublic: false, + ConfigVisible: true, + Config: string(raw), + } + if err := st.Strategy().Create(strategy); err != nil { + t.Fatalf("create strategy: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + session := skillSession{ + Name: "strategy_management", + Action: "update_config", + TargetRef: &EntityReference{ + ID: strategy.ID, + Name: strategy.Name, + }, + } + resources := a.buildActiveSessionResources("default", session) + if got := resources["target_strategy_type"]; got != "grid_trading" { + t.Fatalf("expected grid strategy type in resources, got: %#v", got) + } + fields, ok := resources["target_editable_fields"].([]string) + if !ok { + t.Fatalf("expected editable field list in resources, got: %#v", resources["target_editable_fields"]) + } + joined := strings.Join(fields, ",") + if !strings.Contains(joined, "symbol") || strings.Contains(joined, "source_type") { + t.Fatalf("expected grid-only editable fields in resources, got: %s", joined) + } +} + +func TestSkillVisibleFieldSummaryForTraderMatchesManualPanelFields(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "trader-field-summary.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + summary := a.skillVisibleFieldSummary("default", "zh", "trader_management", "update") + for _, expected := range []string{"交易所", "模型", "策略", "扫描间隔", "全仓模式", "竞技场显示"} { + if !strings.Contains(summary, expected) { + t.Fatalf("expected trader field label %q in summary, got: %s", expected, summary) + } + } + for _, unexpected := range []string{"名称", "初始资金", "初始余额", "杠杆", "交易对", "Prompt", "AI500", "OI Top"} { + if strings.Contains(summary, unexpected) { + t.Fatalf("trader field summary should stay within manual panel fields, got: %s", summary) + } + } +} + +func TestToolUpdateTraderRejectsRenameOutsideManualPanel(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "trader-update-reject-rename.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + if err := st.AIModel().UpdateWithName("default", "default_deepseek", "DeepSeek", true, "sk-test-12345", "", "deepseek-chat"); err != nil { + t.Fatalf("seed model: %v", err) + } + exchangeID, err := st.Exchange().Create("default", "binance", "Main", true, "api-test", "secret-test", "", false, "", false, "", "", "", "", "", "", 0) + if err != nil { + t.Fatalf("seed exchange: %v", err) + } + cfg := store.GetDefaultStrategyConfig("zh") + rawCfg, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("marshal strategy config: %v", err) + } + if err := st.Strategy().Create(&store.Strategy{ + ID: "strategy-trader-rename", + UserID: "default", + Name: "Rename Strategy", + Description: "test", + IsPublic: false, + ConfigVisible: true, + Config: string(rawCfg), + }); err != nil { + t.Fatalf("seed strategy: %v", err) + } + if err := st.Trader().Create(&store.Trader{ + ID: "trader-rename", + UserID: "default", + Name: "原交易员", + AIModelID: "default_deepseek", + ExchangeID: exchangeID, + StrategyID: "strategy-trader-rename", + InitialBalance: 1000, + ScanIntervalMinutes: 5, + IsCrossMargin: true, + ShowInCompetition: true, + }); err != nil { + t.Fatalf("seed trader: %v", err) + } + + resp := a.toolManageTrader("default", `{"action":"update","trader_id":"trader-rename","name":"新名字"}`) + if !strings.Contains(resp, "trader rename is not supported here") { + t.Fatalf("expected rename rejection, got: %s", resp) + } +} + +func TestToolCreateTraderResponseHidesLegacyTraderTuningFields(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "trader-create-response-shape.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + if err := st.AIModel().UpdateWithName("default", "default_deepseek", "DeepSeek", true, "sk-test-12345", "", "deepseek-chat"); err != nil { + t.Fatalf("seed model: %v", err) + } + exchangeID, err := st.Exchange().Create("default", "binance", "Main", true, "api-test", "secret-test", "", false, "", false, "", "", "", "", "", "", 0) + if err != nil { + t.Fatalf("seed exchange: %v", err) + } + cfg := store.GetDefaultStrategyConfig("zh") + rawCfg, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("marshal strategy config: %v", err) + } + if err := st.Strategy().Create(&store.Strategy{ + ID: "strategy-trader-shape", + UserID: "default", + Name: "Shape Strategy", + Description: "test", + IsPublic: false, + ConfigVisible: true, + Config: string(rawCfg), + }); err != nil { + t.Fatalf("seed strategy: %v", err) + } + + originalFetcher := traderInitialBalanceFetcher + traderInitialBalanceFetcher = func(exchangeCfg *store.Exchange, userID string) (float64, bool, error) { + return 88.5, true, nil + } + defer func() { + traderInitialBalanceFetcher = originalFetcher + }() + + resp := a.toolManageTrader("default", `{"action":"create","name":"形状测试","ai_model_id":"default_deepseek","exchange_id":"`+exchangeID+`","strategy_id":"strategy-trader-shape"}`) + if strings.Contains(resp, `"error"`) { + t.Fatalf("expected trader create to succeed, got: %s", resp) + } + for _, blocked := range []string{"btc_eth_leverage", "altcoin_leverage", "trading_symbols", "custom_prompt", "system_prompt_template"} { + if strings.Contains(resp, blocked) { + t.Fatalf("expected trader create response to hide legacy tuning field %q, got: %s", blocked, resp) + } + } +} + +func TestToolCreateTraderAutoReadsInitialBalanceFromExchange(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "trader-auto-balance.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + if err := st.AIModel().UpdateWithName("default", "default_deepseek", "DeepSeek", true, "sk-test-12345", "", "deepseek-chat"); err != nil { + t.Fatalf("seed model: %v", err) + } + exchangeID, err := st.Exchange().Create("default", "binance", "Main", true, "api-test", "secret-test", "", false, "", false, "", "", "", "", "", "", 0) + if err != nil { + t.Fatalf("seed exchange: %v", err) + } + cfg := store.GetDefaultStrategyConfig("zh") + rawCfg, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("marshal strategy config: %v", err) + } + if err := st.Strategy().Create(&store.Strategy{ + ID: "strategy-auto-balance", + UserID: "default", + Name: "Auto Balance Strategy", + Description: "test", + IsPublic: false, + ConfigVisible: true, + Config: string(rawCfg), + }); err != nil { + t.Fatalf("seed strategy: %v", err) + } + + originalFetcher := traderInitialBalanceFetcher + traderInitialBalanceFetcher = func(exchangeCfg *store.Exchange, userID string) (float64, bool, error) { + if exchangeCfg == nil || exchangeCfg.ID != exchangeID { + t.Fatalf("unexpected exchange config passed to balance fetcher: %#v", exchangeCfg) + } + if userID != "default" { + t.Fatalf("unexpected user id %q", userID) + } + return 4321.25, true, nil + } + defer func() { + traderInitialBalanceFetcher = originalFetcher + }() + + resp := a.toolManageTrader("default", `{"action":"create","name":"奶茶","ai_model_id":"default_deepseek","exchange_id":"`+exchangeID+`","strategy_id":"strategy-auto-balance","initial_balance":999}`) + if strings.Contains(resp, `"error"`) { + t.Fatalf("expected trader create to succeed, got: %s", resp) + } + + traders, err := st.Trader().List("default") + if err != nil { + t.Fatalf("list traders: %v", err) + } + if len(traders) != 1 { + t.Fatalf("expected one trader, got %d", len(traders)) + } + if traders[0].InitialBalance != 4321.25 { + t.Fatalf("expected initial balance to be auto-read from exchange, got %.2f", traders[0].InitialBalance) + } +} + +func TestDescribeStrategyIncludesManualPageSections(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "strategy-detail.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + cfg := store.GetDefaultStrategyConfig("zh") + cfg.StrategyType = "grid_trading" + cfg.GridConfig = &store.GridStrategyConfig{ + Symbol: "BTCUSDT", + GridCount: 12, + TotalInvestment: 1500, + Leverage: 4, + UpperPrice: 120000, + LowerPrice: 90000, + UseATRBounds: false, + ATRMultiplier: 2, + Distribution: "gaussian", + MaxDrawdownPct: 15, + StopLossPct: 5, + DailyLossLimitPct: 10, + UseMakerOnly: true, + EnableDirectionAdjust: true, + DirectionBiasRatio: 0.7, + } + rawCfg, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("marshal strategy config: %v", err) + } + + strategy := &store.Strategy{ + ID: "strategy-detail-1", + UserID: "default", + Name: "Grid Alpha", + Description: "grid strategy for regression", + IsPublic: true, + ConfigVisible: true, + Config: string(rawCfg), + } + if err := st.Strategy().Create(strategy); err != nil { + t.Fatalf("create strategy: %v", err) + } + strategy.ConfigVisible = false + if err := st.Strategy().Update(strategy); err != nil { + t.Fatalf("update strategy visibility: %v", err) + } + + detail, ok := a.describeStrategy("default", "zh", &EntityReference{ID: strategy.ID}) + if !ok { + t.Fatal("expected describeStrategy to resolve seeded strategy") + } + for _, expected := range []string{ + "策略“Grid Alpha”概览", + "发布设置:已发布到市场;配置隐藏", + "网格参数:交易对 BTCUSDT;网格 12;总投资 1500.00;杠杆 4;分布 gaussian", + "网格边界:上沿 120000.0000,下沿 90000.0000", + } { + if !strings.Contains(detail, expected) { + t.Fatalf("expected strategy detail to contain %q, got: %s", expected, detail) + } + } + for _, unexpected := range []string{ + "标的来源:", + "NofxOS 数据:", + } { + if strings.Contains(detail, unexpected) { + t.Fatalf("expected grid strategy detail not to contain AI field %q, got: %s", unexpected, detail) + } + } +} diff --git a/agent/entity_field_catalog.go b/agent/entity_field_catalog.go new file mode 100644 index 00000000..7cc4712f --- /dev/null +++ b/agent/entity_field_catalog.go @@ -0,0 +1,111 @@ +package agent + +type entityFieldMeta struct { + Key string + Keywords []string + ValueType string + ManualEditable bool + AgentUpdatable bool +} + +var traderFieldCatalog = []entityFieldMeta{ + {Key: "ai_model_id", Keywords: []string{"换模型", "切换模型", "模型"}, ValueType: "entity_ref", ManualEditable: true, AgentUpdatable: true}, + {Key: "exchange_id", Keywords: []string{"换交易所", "切换交易所", "交易所"}, ValueType: "entity_ref", ManualEditable: true, AgentUpdatable: true}, + {Key: "strategy_id", Keywords: []string{"换策略", "切换策略", "策略"}, ValueType: "entity_ref", ManualEditable: true, AgentUpdatable: true}, + {Key: "scan_interval_minutes", Keywords: []string{"扫描间隔", "扫描频率", "scan interval", "scan frequency"}, ValueType: "int", ManualEditable: true, AgentUpdatable: true}, + {Key: "is_cross_margin", Keywords: []string{"全仓", "cross margin", "is_cross_margin"}, ValueType: "flag", ManualEditable: true, AgentUpdatable: true}, + {Key: "show_in_competition", Keywords: []string{"竞技场显示", "显示在竞技场", "show in competition", "competition"}, ValueType: "flag", ManualEditable: true, AgentUpdatable: true}, +} + +var modelFieldCatalog = []entityFieldMeta{ + {Key: "provider", Keywords: []string{"provider", "模型提供商", "模型厂商", "vendor"}, ValueType: "enum", ManualEditable: true, AgentUpdatable: true}, + {Key: "name", Keywords: []string{"名称", "名字", "name"}, ValueType: "name", ManualEditable: true, AgentUpdatable: true}, + {Key: "enabled", Keywords: []string{"启用", "禁用", "enable", "disable"}, ValueType: "enabled", AgentUpdatable: true}, + {Key: "api_key", Keywords: []string{"api key", "apikey", "api_key"}, ValueType: "credential", ManualEditable: true, AgentUpdatable: true}, + {Key: "custom_api_url", Keywords: []string{"url", "endpoint", "地址", "接口"}, ValueType: "url", ManualEditable: true, AgentUpdatable: true}, + {Key: "custom_model_name", Keywords: []string{"model name", "模型名称", "模型名"}, ValueType: "model_name", ManualEditable: true, AgentUpdatable: true}, +} + +var exchangeFieldCatalog = []entityFieldMeta{ + {Key: "exchange_type", Keywords: []string{"交易所类型", "交易所", "exchange type", "exchange"}, ValueType: "enum", ManualEditable: true, AgentUpdatable: true}, + {Key: "account_name", Keywords: []string{"账户名", "account name"}, ValueType: "account_name", ManualEditable: true, AgentUpdatable: true}, + {Key: "enabled", Keywords: []string{"启用", "禁用", "enable", "disable"}, ValueType: "enabled", AgentUpdatable: true}, + {Key: "api_key", Keywords: []string{"api key", "apikey", "api_key"}, ValueType: "credential", ManualEditable: true, AgentUpdatable: true}, + {Key: "secret_key", Keywords: []string{"secret key", "secret", "secret_key"}, ValueType: "credential", ManualEditable: true, AgentUpdatable: true}, + {Key: "passphrase", Keywords: []string{"passphrase", "密码短语"}, ValueType: "credential", ManualEditable: true, AgentUpdatable: true}, + {Key: "testnet", Keywords: []string{"testnet", "测试网"}, ValueType: "flag", ManualEditable: true, AgentUpdatable: true}, + {Key: "hyperliquid_wallet_addr", Keywords: []string{"hyperliquid wallet", "hyperliquid钱包", "主钱包地址", "wallet address"}, ValueType: "credential", ManualEditable: true, AgentUpdatable: true}, + {Key: "aster_user", Keywords: []string{"aster user", "aster用户", "用户地址", "user"}, ValueType: "credential", ManualEditable: true, AgentUpdatable: true}, + {Key: "aster_signer", Keywords: []string{"aster signer", "signer"}, ValueType: "credential", ManualEditable: true, AgentUpdatable: true}, + {Key: "aster_private_key", Keywords: []string{"aster private key", "aster私钥", "private key"}, ValueType: "credential", ManualEditable: true, AgentUpdatable: true}, + {Key: "lighter_wallet_addr", Keywords: []string{"lighter wallet", "lighter钱包", "wallet address"}, ValueType: "credential", ManualEditable: true, AgentUpdatable: true}, + {Key: "lighter_api_key_private_key", Keywords: []string{"lighter api key private key", "lighter api key", "api key private key"}, ValueType: "credential", ManualEditable: true, AgentUpdatable: true}, + {Key: "lighter_api_key_index", Keywords: []string{"lighter api key index", "lighter索引", "api key index"}, ValueType: "int", ManualEditable: true, AgentUpdatable: true}, +} + +func fieldKeysByCapability(catalog []entityFieldMeta, include func(entityFieldMeta) bool) []string { + keys := make([]string, 0, len(catalog)) + for _, field := range catalog { + if include(field) { + keys = append(keys, field.Key) + } + } + return keys +} + +func keywordsForField(catalog []entityFieldMeta, field string) []string { + for _, item := range catalog { + if item.Key == field { + return item.Keywords + } + } + return nil +} + +func manualTraderEditableFieldKeys() []string { + return fieldKeysByCapability(traderFieldCatalog, func(field entityFieldMeta) bool { + return field.ManualEditable + }) +} + +func agentTraderUpdatableFieldKeys() []string { + return fieldKeysByCapability(traderFieldCatalog, func(field entityFieldMeta) bool { + return field.AgentUpdatable + }) +} + +func manualModelEditableFieldKeys() []string { + return fieldKeysByCapability(modelFieldCatalog, func(field entityFieldMeta) bool { + return field.ManualEditable + }) +} + +func agentModelUpdatableFieldKeys() []string { + return fieldKeysByCapability(modelFieldCatalog, func(field entityFieldMeta) bool { + return field.AgentUpdatable + }) +} + +func manualExchangeEditableFieldKeys() []string { + return fieldKeysByCapability(exchangeFieldCatalog, func(field entityFieldMeta) bool { + return field.ManualEditable + }) +} + +func agentExchangeUpdatableFieldKeys() []string { + return fieldKeysByCapability(exchangeFieldCatalog, func(field entityFieldMeta) bool { + return field.AgentUpdatable + }) +} + +func traderFieldKeywords(field string) []string { + return keywordsForField(traderFieldCatalog, field) +} + +func modelFieldKeywords(field string) []string { + return keywordsForField(modelFieldCatalog, field) +} + +func exchangeFieldKeywords(field string) []string { + return keywordsForField(exchangeFieldCatalog, field) +} diff --git a/agent/execution_state.go b/agent/execution_state.go index fe6e7540..fc82176d 100644 --- a/agent/execution_state.go +++ b/agent/execution_state.go @@ -5,6 +5,8 @@ import ( "fmt" "strings" "time" + + "github.com/google/uuid" ) const ( @@ -30,22 +32,38 @@ const ( ) type ExecutionState struct { - SessionID string `json:"session_id"` - UserID int64 `json:"user_id"` - Goal string `json:"goal"` - Status string `json:"status"` - PlanID string `json:"plan_id"` - Steps []PlanStep `json:"steps,omitempty"` - CurrentStepID string `json:"current_step_id,omitempty"` + SessionID string `json:"session_id"` + UserID int64 `json:"user_id"` + Goal string `json:"goal"` + Status string `json:"status"` + PlanID string `json:"plan_id"` + Steps []PlanStep `json:"steps,omitempty"` + CurrentStepID string `json:"current_step_id,omitempty"` CurrentReferences *CurrentReferences `json:"current_references,omitempty"` - DynamicSnapshots []Observation `json:"dynamic_snapshots,omitempty"` - ExecutionLog []Observation `json:"execution_log,omitempty"` - SummaryNotes []Observation `json:"summary_notes,omitempty"` - Waiting *WaitingState `json:"waiting,omitempty"` - Observations []Observation `json:"observations,omitempty"` - FinalAnswer string `json:"final_answer,omitempty"` - LastError string `json:"last_error,omitempty"` - UpdatedAt string `json:"updated_at"` + ReferenceHistory []ReferenceRecord `json:"reference_history,omitempty"` + DynamicSnapshots []Observation `json:"dynamic_snapshots,omitempty"` + ExecutionLog []Observation `json:"execution_log,omitempty"` + SummaryNotes []Observation `json:"summary_notes,omitempty"` + Waiting *WaitingState `json:"waiting,omitempty"` + Observations []Observation `json:"observations,omitempty"` + FinalAnswer string `json:"final_answer,omitempty"` + LastError string `json:"last_error,omitempty"` + UpdatedAt string `json:"updated_at"` +} + +type SuspendedTask struct { + SnapshotID string `json:"snapshot_id,omitempty"` + IntentID string `json:"intent_id,omitempty"` + ParentIntentID string `json:"parent_intent_id,omitempty"` + Kind string `json:"kind,omitempty"` + ResumeHint string `json:"resume_hint,omitempty"` + ResumeOnSuccess bool `json:"resume_on_success,omitempty"` + ResumeTriggers []string `json:"resume_triggers,omitempty"` + SkillSession *skillSession `json:"skill_session,omitempty"` + WorkflowSession *WorkflowSession `json:"workflow_session,omitempty"` + ExecutionState *ExecutionState `json:"execution_state,omitempty"` + LocalHistory []chatMessage `json:"local_history,omitempty"` + SuspendedAt string `json:"suspended_at,omitempty"` } type PlanStep struct { @@ -78,8 +96,18 @@ type WaitingState struct { } type EntityReference struct { - ID string `json:"id,omitempty"` - Name string `json:"name,omitempty"` + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Source string `json:"source,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` +} + +type ReferenceRecord struct { + Kind string `json:"kind,omitempty"` + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Source string `json:"source,omitempty"` + CreatedAt string `json:"created_at,omitempty"` } type CurrentReferences struct { @@ -89,6 +117,20 @@ type CurrentReferences struct { Exchange *EntityReference `json:"exchange,omitempty"` } +type SnapshotSummary struct { + SnapshotID string `json:"snapshot_id,omitempty"` + IntentID string `json:"intent_id,omitempty"` + ParentIntentID string `json:"parent_intent_id,omitempty"` + Kind string `json:"kind,omitempty"` + ResumeHint string `json:"resume_hint,omitempty"` + SuspendedAt string `json:"suspended_at,omitempty"` +} + +type SnapshotManager struct { + agent *Agent + userID int64 +} + type executionPlan struct { Goal string `json:"goal"` Steps []PlanStep `json:"steps"` @@ -103,6 +145,82 @@ func ExecutionStateConfigKey(userID int64) string { return fmt.Sprintf("agent_execution_state_%d", userID) } +func taskStackConfigKey(userID int64) string { + return fmt.Sprintf("agent_task_stack_%d", userID) +} + +func (a *Agent) SnapshotManager(userID int64) SnapshotManager { + return SnapshotManager{agent: a, userID: userID} +} + +func (m SnapshotManager) Save(task SuspendedTask) { + if m.agent == nil { + return + } + m.agent.pushTaskStack(m.userID, task) +} + +func (m SnapshotManager) Load() (SuspendedTask, bool) { + if m.agent == nil { + return SuspendedTask{}, false + } + return m.agent.popTaskStack(m.userID) +} + +func (m SnapshotManager) Peek() (SuspendedTask, bool) { + if m.agent == nil { + return SuspendedTask{}, false + } + return m.agent.peekTaskStack(m.userID) +} + +func (m SnapshotManager) List() []SnapshotSummary { + if m.agent == nil { + return nil + } + stack := m.agent.getTaskStack(m.userID) + out := make([]SnapshotSummary, 0, len(stack)) + for _, item := range stack { + out = append(out, SnapshotSummary{ + SnapshotID: strings.TrimSpace(item.SnapshotID), + IntentID: strings.TrimSpace(item.IntentID), + ParentIntentID: strings.TrimSpace(item.ParentIntentID), + Kind: strings.TrimSpace(item.Kind), + ResumeHint: strings.TrimSpace(item.ResumeHint), + SuspendedAt: strings.TrimSpace(item.SuspendedAt), + }) + } + return out +} + +func (m SnapshotManager) Stack() []SuspendedTask { + if m.agent == nil { + return nil + } + return m.agent.getTaskStack(m.userID) +} + +func (m SnapshotManager) RemoveAt(index int) (SuspendedTask, bool) { + if m.agent == nil { + return SuspendedTask{}, false + } + stack := m.agent.getTaskStack(m.userID) + if index < 0 || index >= len(stack) { + return SuspendedTask{}, false + } + task := stack[index] + stack = append(stack[:index], stack[index+1:]...) + m.agent.saveTaskStack(m.userID, stack) + return task, true +} + +func (m SnapshotManager) Clear() { + if m.agent == nil { + return + } + m.agent.clearTaskStack(m.userID) +} + func (a *Agent) getExecutionState(userID int64) ExecutionState { if a.store == nil { return ExecutionState{} @@ -133,6 +251,9 @@ func (a *Agent) saveExecutionState(state ExecutionState) error { if state.SessionID == "" { return a.store.SetSystemConfig(ExecutionStateConfigKey(state.UserID), "") } + if state.UserID != 0 && (state.CurrentReferences != nil || len(state.ReferenceHistory) > 0) { + a.saveReferenceMemory(state.UserID, state.CurrentReferences, state.ReferenceHistory) + } data, err := json.Marshal(state) if err != nil { return err @@ -149,6 +270,80 @@ func (a *Agent) clearExecutionState(userID int64) { } } +func (a *Agent) getTaskStack(userID int64) []SuspendedTask { + if a.store == nil { + return nil + } + raw, err := a.store.GetSystemConfig(taskStackConfigKey(userID)) + if err != nil { + a.logger.Warn("failed to load task stack", "error", err, "user_id", userID) + return nil + } + raw = strings.TrimSpace(raw) + if raw == "" { + return nil + } + var stack []SuspendedTask + if err := json.Unmarshal([]byte(raw), &stack); err != nil { + a.logger.Warn("failed to parse task stack", "error", err, "user_id", userID) + return nil + } + return normalizeTaskStack(stack) +} + +func (a *Agent) saveTaskStack(userID int64, stack []SuspendedTask) { + if a.store == nil { + return + } + stack = normalizeTaskStack(stack) + if len(stack) == 0 { + _ = a.store.SetSystemConfig(taskStackConfigKey(userID), "") + return + } + data, err := json.Marshal(stack) + if err != nil { + return + } + _ = a.store.SetSystemConfig(taskStackConfigKey(userID), string(data)) +} + +func (a *Agent) peekTaskStack(userID int64) (SuspendedTask, bool) { + stack := a.getTaskStack(userID) + if len(stack) == 0 { + return SuspendedTask{}, false + } + return stack[len(stack)-1], true +} + +func (a *Agent) pushTaskStack(userID int64, task SuspendedTask) { + task = normalizeSuspendedTask(task) + if task.Kind == "" { + return + } + stack := a.getTaskStack(userID) + stack = append(stack, task) + stack = normalizeTaskStack(stack) + a.saveTaskStack(userID, stack) +} + +func (a *Agent) popTaskStack(userID int64) (SuspendedTask, bool) { + stack := a.getTaskStack(userID) + if len(stack) == 0 { + return SuspendedTask{}, false + } + task := stack[len(stack)-1] + stack = stack[:len(stack)-1] + a.saveTaskStack(userID, stack) + return task, true +} + +func (a *Agent) clearTaskStack(userID int64) { + if a.store == nil { + return + } + _ = a.store.SetSystemConfig(taskStackConfigKey(userID), "") +} + func newExecutionState(userID int64, goal string) ExecutionState { now := time.Now().UTC().Format(time.RFC3339) return normalizeExecutionState(ExecutionState{ @@ -168,6 +363,7 @@ func normalizeExecutionState(state ExecutionState) ExecutionState { state.FinalAnswer = strings.TrimSpace(state.FinalAnswer) state.LastError = strings.TrimSpace(state.LastError) state.CurrentReferences = normalizeCurrentReferences(state.CurrentReferences) + state.ReferenceHistory = normalizeReferenceHistory(state.ReferenceHistory) state.Waiting = normalizeWaitingState(state.Waiting) if state.Status == "" && state.SessionID != "" { state.Status = executionStatusPlanning @@ -201,6 +397,88 @@ func normalizeExecutionState(state ExecutionState) ExecutionState { return state } +func normalizeSuspendedTask(task SuspendedTask) SuspendedTask { + task.SnapshotID = strings.TrimSpace(task.SnapshotID) + task.IntentID = strings.TrimSpace(task.IntentID) + task.ParentIntentID = strings.TrimSpace(task.ParentIntentID) + task.Kind = strings.TrimSpace(task.Kind) + task.ResumeHint = strings.TrimSpace(task.ResumeHint) + task.ResumeTriggers = cleanStringList(task.ResumeTriggers) + task.SuspendedAt = strings.TrimSpace(task.SuspendedAt) + if task.SkillSession != nil { + session := normalizeSkillSession(*task.SkillSession) + if session.Name == "" { + task.SkillSession = nil + } else { + task.SkillSession = &session + } + } + if task.WorkflowSession != nil { + session := normalizeWorkflowSession(*task.WorkflowSession) + if len(session.Tasks) == 0 { + task.WorkflowSession = nil + } else { + task.WorkflowSession = &session + } + } + if task.ExecutionState != nil { + state := normalizeExecutionState(*task.ExecutionState) + if strings.TrimSpace(state.SessionID) == "" { + task.ExecutionState = nil + } else { + task.ExecutionState = &state + } + } + if task.Kind == "" { + switch { + case task.SkillSession != nil: + task.Kind = "skill_session" + case task.WorkflowSession != nil: + task.Kind = "workflow_session" + case task.ExecutionState != nil: + task.Kind = "execution_state" + } + } + if task.Kind == "" { + return SuspendedTask{} + } + if task.SnapshotID == "" { + task.SnapshotID = "snap_" + uuid.NewString() + } + if task.IntentID == "" { + task.IntentID = "intent_" + uuid.NewString() + } + if task.SuspendedAt == "" { + task.SuspendedAt = time.Now().UTC().Format(time.RFC3339) + } + return task +} + +func normalizeTaskStack(stack []SuspendedTask) []SuspendedTask { + if len(stack) == 0 { + return nil + } + now := time.Now().UTC() + out := make([]SuspendedTask, 0, len(stack)) + for _, item := range stack { + item = normalizeSuspendedTask(item) + if item.Kind == "" { + continue + } + if t, err := time.Parse(time.RFC3339, item.SuspendedAt); err == nil && now.Sub(t) > 24*time.Hour { + continue + } + out = append(out, item) + } + if len(out) == 0 { + return nil + } + if len(out) > 5 { + out = out[len(out)-5:] + } + return out +} + func normalizeWaitingState(waiting *WaitingState) *WaitingState { if waiting == nil { return nil @@ -224,9 +502,14 @@ func normalizeEntityReference(ref *EntityReference) *EntityReference { } ref.ID = strings.TrimSpace(ref.ID) ref.Name = strings.TrimSpace(ref.Name) + ref.Source = strings.TrimSpace(ref.Source) + ref.UpdatedAt = strings.TrimSpace(ref.UpdatedAt) if ref.ID == "" && ref.Name == "" { return nil } + if ref.UpdatedAt == "" { + ref.UpdatedAt = time.Now().UTC().Format(time.RFC3339) + } return ref } @@ -244,6 +527,34 @@ func normalizeCurrentReferences(refs *CurrentReferences) *CurrentReferences { return refs } +func normalizeReferenceHistory(history []ReferenceRecord) []ReferenceRecord { + if len(history) == 0 { + return nil + } + out := make([]ReferenceRecord, 0, len(history)) + for _, item := range history { + item.Kind = strings.TrimSpace(item.Kind) + item.ID = strings.TrimSpace(item.ID) + item.Name = strings.TrimSpace(item.Name) + item.Source = strings.TrimSpace(item.Source) + item.CreatedAt = strings.TrimSpace(item.CreatedAt) + if item.Kind == "" || (item.ID == "" && item.Name == "") { + continue + } + if item.CreatedAt == "" { + item.CreatedAt = time.Now().UTC().Format(time.RFC3339) + } + out = append(out, item) + } + if len(out) == 0 { + return nil + } + if len(out) > 12 { + out = out[len(out)-12:] + } + return out +} + func normalizeObservationList(values []Observation) []Observation { if len(values) == 0 { return nil @@ -332,8 +643,8 @@ func buildObservationContext(state ExecutionState) map[string]any { state = normalizeExecutionState(state) return map[string]any{ "current_references": state.CurrentReferences, - "dynamic_snapshots": state.DynamicSnapshots, - "execution_log": state.ExecutionLog, - "summary_notes": state.SummaryNotes, + "dynamic_snapshots": state.DynamicSnapshots, + "execution_log": state.ExecutionLog, + "summary_notes": state.SummaryNotes, } } diff --git a/agent/history.go b/agent/history.go index 662bbd31..cad912d1 100644 --- a/agent/history.go +++ b/agent/history.go @@ -1,6 +1,7 @@ package agent import ( + "strings" "sync" "time" ) @@ -101,3 +102,16 @@ func (h *chatHistory) CleanOld(maxAge time.Duration) { } } } + +func (a *Agent) getLastAssistantReply(userID int64) string { + if a == nil || a.history == nil { + return "" + } + msgs := a.history.Get(userID) + for i := len(msgs) - 1; i >= 0; i-- { + if strings.EqualFold(strings.TrimSpace(msgs[i].Role), "assistant") { + return strings.TrimSpace(msgs[i].Content) + } + } + return "" +} diff --git a/agent/i18n.go b/agent/i18n.go index 47425cab..3e5af0b8 100644 --- a/agent/i18n.go +++ b/agent/i18n.go @@ -3,20 +3,22 @@ package agent var i18nMessages = map[string]map[string]string{ "help": { "zh": "🤖 *NOFXi — 你的 AI 交易 Agent*\n\n" + - "*交易:* /buy /sell /long /short + 交易对 数量 杠杆\n" + + "*交易:* 做多 BTC 0.01 x10 · 做空 ETH 0.1 · 平多 BTC · 平空 ETH\n" + + " 也支持 /buy /sell /long /short + 交易对 数量 杠杆\n" + "*查询:* /positions /balance /pnl /traders\n" + "*分析:* /analyze BTC\n" + "*监控:* /watch BTC · /unwatch BTC\n" + "*策略:* /strategy\n" + - "*系统:* /status /help\n\n" + + "*系统:* /status /clear /help\n\n" + "直接跟我说话就行,中英文都可以 💬", "en": "🤖 *NOFXi — Your AI Trading Agent*\n\n" + - "*Trade:* /buy /sell /long /short + symbol qty leverage\n" + + "*Trade:* long BTC 0.01 x10 · short ETH 0.1 · close long BTC · close short ETH\n" + + " Also supports /buy /sell /long /short + symbol qty leverage\n" + "*Query:* /positions /balance /pnl /traders\n" + "*Analyze:* /analyze BTC\n" + "*Monitor:* /watch BTC · /unwatch BTC\n" + "*Strategy:* /strategy\n" + - "*System:* /status /help\n\n" + + "*System:* /status /clear /help\n\n" + "Just talk to me in any language 💬", }, "status": { @@ -52,8 +54,8 @@ var i18nMessages = map[string]map[string]string{ "en": "🤖 *Traders*\n\n", }, "trade_usage": { - "zh": "用法: `/buy BTC 0.01` 或 `/sell ETH 0.5 3x`", - "en": "Usage: `/buy BTC 0.01` or `/sell ETH 0.5 3x`", + "zh": "手动下单示例:`做多 BTC 0.01 x10`、`做空 ETH 0.1`、`平多 BTC`、`平空 ETH`。也支持 `/buy BTC 0.01` 或 `/sell ETH 0.5 3x`。下单后需要确认;大额订单要用“确认大额 trade_xxx”。", + "en": "Manual trade examples: `long BTC 0.01 x10`, `short ETH 0.1`, `close long BTC`, `close short ETH`. Also supports `/buy BTC 0.01` or `/sell ETH 0.5 3x`. Orders require confirmation; large orders use `confirm large trade_xxx`.", }, "invalid_qty": { "zh": "❓ 无效数量: %s", @@ -68,8 +70,8 @@ var i18nMessages = map[string]map[string]string{ "en": "⚠️ Sentinel not enabled.", }, "system_prompt": { - "zh": "你是 NOFXi,一个专业的 AI 交易 Agent。简洁、专业、用中文回复。使用交易相关 emoji。", - "en": "You are NOFXi, a professional AI trading agent. Be concise, professional. Use trading emojis.", + "zh": "你是 NOFXi,一个专业的 AI 交易 Agent。把用户当交易小白,用简单清楚的大白话回复,先说结论,再说下一步。使用少量交易相关 emoji。", + "en": "You are NOFXi, a professional AI trading agent. Treat the user like a trading beginner, use plain language, lead with the conclusion, then the next step. Use a small amount of trading emojis.", }, } diff --git a/agent/llm_flow_extractor.go b/agent/llm_flow_extractor.go new file mode 100644 index 00000000..3566bb1e --- /dev/null +++ b/agent/llm_flow_extractor.go @@ -0,0 +1,578 @@ +package agent + +import ( + "encoding/json" + "fmt" + "sort" + "strings" +) + +type llmFlowExtractionTask struct { + Skill string `json:"skill,omitempty"` + Action string `json:"action,omitempty"` + Fields map[string]string `json:"fields,omitempty"` +} + +type llmFlowExtractionResult struct { + Intent string `json:"intent,omitempty"` + TargetSnapshotID string `json:"target_snapshot_id,omitempty"` + InlineSubIntent string `json:"inline_sub_intent,omitempty"` + Fields map[string]string `json:"fields,omitempty"` + Tasks []llmFlowExtractionTask `json:"tasks,omitempty"` + Reason string `json:"reason,omitempty"` +} + +type llmFlowFieldSpec struct { + Key string `json:"key"` + Description string `json:"description"` + Required bool `json:"required,omitempty"` +} + +func buildActiveFlowExtractionPrompt(lang, flowLabel, flowContext string, text string, recentConversationCtx string, currentRefs any, suspendedSnapshots any, extraSections []string) (string, string) { + systemPrompt := `You extract structured continuation input for an active NOFXi flow. +Return JSON only. No markdown. + +You must decide one of: +- "continue": the user is continuing the current flow and may have supplied fields +- "switch": the user is switching away to another task +- "cancel": the user is cancelling the current flow +- "instant_reply": the user is only chatting / greeting and no task fields should be written + +Rules: +- Prefer "continue" only when the message clearly contributes to the current flow. +- Set target_snapshot_id only when the user is clearly referring to one suspended snapshot from Suspended snapshots JSON. +- For greetings, thanks, and casual chat, use "instant_reply". +- Consider Current references JSON and Suspended snapshots JSON when resolving vague references like "那个", "刚才那个", or "前面那个". +- Treat this as semantic slot filling, not keyword copying. +- Users will often speak in natural language, shorthand, colloquial labels, translated labels, or mild misspellings instead of exact schema keys. +- Your job is to decide which allowed canonical field each value belongs to based on the active flow, field descriptions, current missing fields, and conversation context. +- Never require the user to say the exact internal field key. +- In task.fields, always emit the canonical field keys from Allowed field spec JSON, never aliases, paraphrases, or user wording. +- If the user clearly supplied a value for one allowed field, normalize it to that canonical key before returning JSON.` + + sections := []string{ + fmt.Sprintf("Language: %s", lang), + fmt.Sprintf("Active flow label: %s", flowLabel), + flowContext, + fmt.Sprintf("Current references JSON: %s", mustMarshalJSON(currentRefs)), + fmt.Sprintf("Suspended snapshots JSON: %s", mustMarshalJSON(suspendedSnapshots)), + } + sections = append(sections, extraSections...) + sections = append(sections, fmt.Sprintf("User message: %s", text), fmt.Sprintf("Recent conversation:\n%s", recentConversationCtx)) + return systemPrompt, strings.Join(sections, "\n") +} + +func parseLLMFlowExtractionResult(raw string) llmFlowExtractionResult { + out, ok := parseRawFlowExtractionEnvelope(raw) + if !ok { + return llmFlowExtractionResult{} + } + switch out.Intent { + case "continue", "switch", "cancel", "instant_reply": + return out + default: + return llmFlowExtractionResult{} + } +} + +func parseRawFlowExtractionEnvelope(raw string) (llmFlowExtractionResult, bool) { + raw = strings.TrimSpace(raw) + raw = strings.TrimPrefix(raw, "```json") + raw = strings.TrimPrefix(raw, "```") + raw = strings.TrimSuffix(raw, "```") + raw = strings.TrimSpace(raw) + + var out llmFlowExtractionResult + if err := json.Unmarshal([]byte(raw), &out); err != nil { + start := strings.Index(raw, "{") + end := strings.LastIndex(raw, "}") + if start < 0 || end <= start || json.Unmarshal([]byte(raw[start:end+1]), &out) != nil { + return llmFlowExtractionResult{}, false + } + } + + out.Intent = strings.TrimSpace(strings.ToLower(out.Intent)) + out.TargetSnapshotID = strings.TrimSpace(out.TargetSnapshotID) + out.Reason = strings.TrimSpace(out.Reason) + if len(out.Fields) > 0 { + clean := make(map[string]string, len(out.Fields)) + for key, value := range out.Fields { + key = strings.TrimSpace(key) + value = strings.TrimSpace(value) + if key == "" || value == "" { + continue + } + clean[key] = value + } + out.Fields = clean + } + cleanTasks := make([]llmFlowExtractionTask, 0, len(out.Tasks)) + for _, task := range out.Tasks { + task.Skill = strings.TrimSpace(task.Skill) + task.Action = strings.TrimSpace(task.Action) + if len(task.Fields) > 0 { + clean := make(map[string]string, len(task.Fields)) + for key, value := range task.Fields { + key = strings.TrimSpace(key) + value = strings.TrimSpace(value) + if key == "" || value == "" { + continue + } + clean[key] = value + } + task.Fields = clean + } + cleanTasks = append(cleanTasks, task) + } + out.Tasks = cleanTasks + return out, out.Intent != "" +} + +func filterLLMFlowExtractionFields(result llmFlowExtractionResult, specs []llmFlowFieldSpec) llmFlowExtractionResult { + if len(specs) == 0 { + result.Fields = nil + for i := range result.Tasks { + result.Tasks[i].Fields = nil + } + return result + } + allowed := make(map[string]struct{}, len(specs)) + for _, spec := range specs { + key := strings.TrimSpace(spec.Key) + if key != "" { + allowed[key] = struct{}{} + } + } + filter := func(fields map[string]string) map[string]string { + if len(fields) == 0 { + return fields + } + clean := make(map[string]string, len(fields)) + for key, value := range fields { + if _, ok := allowed[key]; !ok { + continue + } + clean[key] = value + } + if len(clean) == 0 { + return nil + } + return clean + } + result.Fields = filter(result.Fields) + for i := range result.Tasks { + result.Tasks[i].Fields = filter(result.Tasks[i].Fields) + } + return result +} + +func formatConversationMissingFields(lang string, missingFields []string) string { + if len(missingFields) == 0 { + if lang == "zh" { + return "当前没有缺失槽位。" + } + return "There are currently no missing slots." + } + display := make([]string, 0, len(missingFields)) + for _, field := range missingFields { + display = append(display, slotDisplayName(field, lang)) + } + if lang == "zh" { + return "当前仍缺这些槽位:" + strings.Join(display, "、") + } + return "Current missing slots: " + strings.Join(display, ", ") +} + +func skillSessionExtractionContext(session skillSession, lang string) (string, []llmFlowFieldSpec, map[string]string, []string) { + currentStep, _ := currentSkillDAGStep(session) + fieldSpecs := allowedFieldSpecsForSkillSession(session, lang) + currentValues := currentFieldValuesForSkillSession(session) + missing := missingFieldKeysForSkillSession(session) + summary := fmt.Sprintf("Active flow type: skill_session\nSkill: %s\nAction: %s\nCurrent DAG step: %s", session.Name, session.Action, currentStep.ID) + return summary, fieldSpecs, currentValues, missing +} + +func allowedFieldSpecsForSkillSession(session skillSession, lang string) []llmFlowFieldSpec { + add := func(out *[]llmFlowFieldSpec, key, description string, required bool) { + *out = append(*out, llmFlowFieldSpec{Key: key, Description: description, Required: required}) + } + out := make([]llmFlowFieldSpec, 0, 24) + if actionRequiresSlot(session.Name, session.Action, "target_ref") { + add(&out, "target_ref_id", slotDisplayName("target_ref", lang)+" ID", true) + add(&out, "target_ref_name", slotDisplayName("target_ref", lang), true) + } + if supportsBulkTargetSelection(session.Name, session.Action) { + add(&out, "bulk_scope", "bulk deletion scope, use all only when the user clearly requested all targets", false) + } + switch session.Name { + case "model_management": + required := map[string]bool{"provider": true} + if strings.HasPrefix(session.Action, "update") { + add(&out, "update_field", displayCatalogFieldName("update_field", lang), false) + } + add(&out, "provider", slotDisplayName("provider", lang), required["provider"]) + add(&out, "name", displayCatalogFieldName("name", lang), required["name"]) + add(&out, "custom_model_name", displayCatalogFieldName("custom_model_name", lang), required["custom_model_name"]) + add(&out, "api_key", displayCatalogFieldName("api_key", lang), required["api_key"]) + add(&out, "custom_api_url", displayCatalogFieldName("custom_api_url", lang), false) + add(&out, "enabled", displayCatalogFieldName("enabled", lang), false) + case "exchange_management": + required := map[string]bool{"exchange_type": true, "account_name": true} + if strings.HasPrefix(session.Action, "update") { + add(&out, "update_field", displayCatalogFieldName("update_field", lang), false) + } + add(&out, "exchange_type", slotDisplayName("exchange_type", lang), required["exchange_type"]) + add(&out, "account_name", displayCatalogFieldName("account_name", lang), required["account_name"]) + add(&out, "api_key", displayCatalogFieldName("api_key", lang), false) + add(&out, "secret_key", displayCatalogFieldName("secret_key", lang), false) + add(&out, "passphrase", displayCatalogFieldName("passphrase", lang), false) + add(&out, "testnet", displayCatalogFieldName("testnet", lang), false) + add(&out, "enabled", displayCatalogFieldName("enabled", lang), false) + add(&out, "hyperliquid_wallet_addr", displayCatalogFieldName("hyperliquid_wallet_addr", lang), false) + add(&out, "aster_user", displayCatalogFieldName("aster_user", lang), false) + add(&out, "aster_signer", displayCatalogFieldName("aster_signer", lang), false) + add(&out, "aster_private_key", displayCatalogFieldName("aster_private_key", lang), false) + add(&out, "lighter_wallet_addr", displayCatalogFieldName("lighter_wallet_addr", lang), false) + add(&out, "lighter_api_key_private_key", displayCatalogFieldName("lighter_api_key_private_key", lang), false) + add(&out, "lighter_api_key_index", displayCatalogFieldName("lighter_api_key_index", lang), false) + case "trader_management": + if strings.HasPrefix(session.Action, "update") { + add(&out, "update_field", displayCatalogFieldName("update_field", lang), false) + } + add(&out, "name", slotDisplayName("name", lang), true) + add(&out, "exchange_id", slotDisplayName("exchange", lang)+" ID", false) + add(&out, "exchange_name", slotDisplayName("exchange", lang), true) + add(&out, "model_id", slotDisplayName("model", lang)+" ID", false) + add(&out, "model_name", slotDisplayName("model", lang), true) + add(&out, "strategy_id", slotDisplayName("strategy", lang)+" ID", false) + add(&out, "strategy_name", slotDisplayName("strategy", lang), true) + add(&out, "auto_start", "auto_start", false) + add(&out, "scan_interval_minutes", displayCatalogFieldName("scan_interval_minutes", lang), false) + add(&out, "is_cross_margin", displayCatalogFieldName("is_cross_margin", lang), false) + add(&out, "show_in_competition", displayCatalogFieldName("show_in_competition", lang), false) + case "strategy_management": + if session.Action == "create" || session.Action == "update_config" { + if session.Action == "create" { + add(&out, "strategy_type", "Strategy type. Use ai_trading for AI strategies, including AI500/OI/static coin-source requests; use grid_trading only for grid strategy requests.", false) + } + configPatchDescription := "Partial StrategyConfig JSON patch inferred from the user's strategy intent. Use exact product schema values, not display labels: source_type must be one of static, ai500, oi_top, oi_low; strategy_type must be ai_trading or grid_trading; selected_timeframes must be a JSON array of strings, not a JSON-encoded string." + switch explicitStrategyCreateType(session) { + case "grid_trading": + configPatchDescription += " Current strategy_type is grid_trading: use only top-level strategy_type, grid_config, publish_config, and language. Do not output ai_config or AI fields such as coin_source, indicators, risk_control, timeframes, confidence, or prompt_sections." + case "ai_trading": + configPatchDescription += " Current strategy_type is ai_trading: use top-level strategy_type, ai_config, publish_config, and language. Put coin_source, indicators, risk_control, prompt_sections, and custom_prompt inside ai_config. Do not output grid_config." + default: + configPatchDescription += " Include strategy_type first when the user chooses AI or grid; after strategy_type is known, use only the config branch for that type: grid_config for grid, ai_config for AI." + } + add(&out, "config_patch", configPatchDescription, false) + } + if session.Action == "create" { + add(&out, "awaiting_final_confirmation", "Set true only after you have produced a final user-facing creation summary from the current structured config and are waiting for the user's final confirmation before executing create.", false) + } + if session.Action == "update_prompt" { + add(&out, "prompt", "Full strategy prompt text to write into the strategy custom prompt.", false) + add(&out, "custom_prompt", strategyConfigFieldDisplayName("custom_prompt", lang), false) + } + if session.Action == "update_config" { + return out + } + add(&out, "name", slotDisplayName("name", lang), true) + if session.Action == "create" { + return out + } + keys := manualStrategyEditableFieldKeys() + if strategyType := explicitStrategyCreateType(session); strategyType != "" { + keys = manualStrategyEditableFieldKeysForType(strategyType) + } + for _, key := range keys { + add(&out, key, strategyConfigFieldDisplayName(key, lang), false) + } + } + return out +} + +func currentFieldValuesForSkillSession(session skillSession) map[string]string { + values := map[string]string{} + for key, value := range session.Fields { + if trimmed := strings.TrimSpace(value); trimmed != "" { + values[key] = trimmed + } + } + if session.TargetRef != nil { + if session.TargetRef.ID != "" { + values["target_ref_id"] = session.TargetRef.ID + } + if session.TargetRef.Name != "" { + values["target_ref_name"] = session.TargetRef.Name + } + } + for _, key := range []string{"name", "exchange_id", "exchange_name", "model_id", "model_name", "strategy_id", "strategy_name", "auto_start"} { + if value := fieldValue(session, key); value != "" { + values[key] = value + } + } + return values +} + +func missingFieldKeysForSkillSession(session skillSession) []string { + missing := make([]string, 0, 8) + switch session.Name { + case "model_management": + if session.Action != "create" && session.Action != "query_list" && session.Action != "query" && session.Action != "query_detail" && session.TargetRef == nil { + missing = append(missing, "target_ref") + } + if strings.HasPrefix(session.Action, "update") { + if session.Action == "update_status" { + if fieldValue(session, "enabled") == "" { + missing = append(missing, "enabled") + } + } else if session.Action == "update_endpoint" { + if fieldValue(session, "custom_api_url") == "" { + missing = append(missing, "custom_api_url") + } + } else { + if fieldValue(session, "update_field") == "" { + missing = append(missing, "update_field") + } + } + } else { + for _, key := range []string{"provider"} { + if fieldValue(session, key) == "" { + missing = append(missing, key) + } + } + if fieldValue(session, "api_key") == "" { + missing = append(missing, "api_key") + } + } + case "exchange_management": + if session.Action != "create" && session.Action != "query_list" && session.Action != "query" && session.Action != "query_detail" && session.TargetRef == nil { + missing = append(missing, "target_ref") + } + if strings.HasPrefix(session.Action, "update") { + if session.Action == "update_status" { + if fieldValue(session, "enabled") == "" { + missing = append(missing, "enabled") + } + } else { + if fieldValue(session, "update_field") == "" { + missing = append(missing, "update_field") + } + } + } else { + for _, key := range []string{"exchange_type", "account_name", "api_key", "secret_key"} { + if fieldValue(session, key) == "" { + missing = append(missing, key) + } + } + } + case "trader_management": + if strings.HasPrefix(session.Action, "update") || strings.HasPrefix(session.Action, "configure_") { + if session.TargetRef == nil { + missing = append(missing, "target_ref") + } + if session.Action == "update_bindings" || session.Action == "configure_strategy" || session.Action == "configure_exchange" || session.Action == "configure_model" { + switch session.Action { + case "configure_strategy": + if fieldValue(session, "strategy_id") == "" { + missing = append(missing, "strategy_name") + } + break + case "configure_exchange": + if fieldValue(session, "exchange_id") == "" { + missing = append(missing, "exchange_name") + } + break + case "configure_model": + if fieldValue(session, "model_id") == "" { + missing = append(missing, "model_name") + } + break + } + if len(missing) > 0 { + break + } + if fieldValue(session, "model_id") == "" && fieldValue(session, "exchange_id") == "" && fieldValue(session, "strategy_id") == "" && + fieldValue(session, "model_name") == "" && fieldValue(session, "exchange_name") == "" && fieldValue(session, "strategy_name") == "" { + missing = append(missing, "update_field") + } + } else { + if fieldValue(session, "update_field") == "" { + missing = append(missing, "update_field") + } + } + } else { + if fieldValue(session, "name") == "" { + missing = append(missing, "name") + } + if fieldValue(session, "exchange_id") == "" { + missing = append(missing, "exchange_name") + } + if fieldValue(session, "model_id") == "" { + missing = append(missing, "model_name") + } + if fieldValue(session, "strategy_id") == "" { + missing = append(missing, "strategy_name") + } + } + case "strategy_management": + if session.Action != "create" && session.Action != "query_list" && session.Action != "query" && session.Action != "query_detail" && session.TargetRef == nil { + missing = append(missing, "target_ref") + } + switch session.Action { + case "update_name": + if fieldValue(session, "name") == "" { + missing = append(missing, "name") + } + case "update_prompt": + if fieldValue(session, "prompt") == "" && fieldValue(session, "custom_prompt") == "" { + missing = append(missing, "prompt") + } + case "update_config": + if fieldValue(session, "config_patch") == "" { + missing = append(missing, "config_patch") + } + case "create": + if fieldValue(session, "name") == "" { + missing = append(missing, "name") + } + default: + missing = append(missing, "update_field") + } + } + sort.Strings(missing) + return missing +} + +func providerExplicitlyMentionedInText(provider, text string) bool { + provider = strings.ToLower(strings.TrimSpace(provider)) + lower := strings.ToLower(strings.TrimSpace(text)) + if provider == "" || lower == "" { + return false + } + spec, _ := modelProviderSpecByID(provider) + candidates := []string{provider, strings.ToLower(strings.TrimSpace(spec.DisplayName))} + switch provider { + case "blockrun-base": + candidates = append(candidates, "blockrun", "blockrun base", "base wallet") + case "blockrun-sol": + candidates = append(candidates, "blockrun", "blockrun sol", "solana wallet") + case "claw402": + candidates = append(candidates, "claw 402") + } + for _, candidate := range candidates { + candidate = strings.TrimSpace(candidate) + if candidate != "" && strings.Contains(lower, candidate) { + return true + } + } + return false +} + +func sanitizeLLMExtractionForSkillSession(text string, session skillSession, result llmFlowExtractionResult) llmFlowExtractionResult { + if session.Name != "model_management" || len(result.Tasks) == 0 { + return result + } + task := result.Tasks[0] + if task.Fields == nil { + return result + } + if provider := strings.TrimSpace(task.Fields["provider"]); provider != "" && !providerExplicitlyMentionedInText(provider, text) { + delete(task.Fields, "provider") + result.Tasks[0] = task + } + return result +} + +func (a *Agent) applyLLMExtractionToSkillSession(storeUserID string, session *skillSession, result llmFlowExtractionResult, lang string, text string) { + if session == nil { + return + } + result = sanitizeLLMExtractionForSkillSession(text, *session, result) + if sub := strings.TrimSpace(result.InlineSubIntent); sub == "create_sub_resource" || sub == "edit_sub_resource" { + setField(session, "inline_sub_intent", sub) + } + if len(result.Tasks) == 0 { + return + } + task := result.Tasks[0] + if task.Skill != "" && task.Skill != session.Name { + return + } + if task.Action != "" && session.Action != "" && task.Action != session.Action { + return + } + for key, value := range task.Fields { + value = strings.TrimSpace(value) + if value == "" { + continue + } + switch key { + case "target_ref_id": + if session.TargetRef == nil { + session.TargetRef = &EntityReference{} + } + session.TargetRef.ID = value + if session.TargetRef.Source == "" { + session.TargetRef.Source = "llm_extraction" + } + continue + case "target_ref_name": + if session.TargetRef == nil { + session.TargetRef = &EntityReference{} + } + session.TargetRef.Name = value + if session.TargetRef.Source == "" { + session.TargetRef.Source = "llm_extraction" + } + continue + } + switch session.Name { + case "model_management": + if key == "provider" || key == "name" || key == "custom_model_name" || key == "api_key" || key == "custom_api_url" || key == "enabled" || key == "update_field" { + setField(session, key, value) + } + case "exchange_management": + switch key { + case "exchange_type", "account_name", "api_key", "secret_key", "passphrase", "testnet", "enabled", "update_field": + setField(session, key, value) + } + case "trader_management": + switch key { + case "update_field": + setField(session, key, value) + case "name", "exchange_id", "exchange_name", "model_id", "ai_model_id", "model_name", "strategy_id", "strategy_name", "auto_start": + setField(session, key, value) + case "scan_interval_minutes", "is_cross_margin", "show_in_competition": + setField(session, key, value) + } + case "strategy_management": + if key == "name" { + setField(session, "name", value) + continue + } + if session.Action == "create" || session.Action == "update_config" { + switch key { + case "strategy_type": + if strategyType := parseStrategyTypeValue(value); strategyType != "" { + setStrategyCreateType(session, strategyType) + } + case strategyCreateConfigPatchField: + strategyType := explicitStrategyCreateType(*session) + if strategyType == "" { + strategyType = strategyTypeFromConfigPatchAny(value) + } + if sanitized := sanitizeStrategyCreateConfigPatchForType(value, strategyType); len(sanitized) > 0 { + raw, _ := json.Marshal(sanitized) + setField(session, strategyCreateConfigPatchField, string(raw)) + } + } + continue + } + cfg := unmarshalStrategyCreateDraft(fieldValue(*session, strategyCreateDraftConfigField), lang) + if err := applyStrategyConfigPatch(&cfg, key, value); err == nil { + setField(session, strategyCreateDraftConfigField, marshalStrategyCreateDraft(cfg)) + } + } + } +} diff --git a/agent/llm_flow_extractor_test.go b/agent/llm_flow_extractor_test.go new file mode 100644 index 00000000..e9df9c20 --- /dev/null +++ b/agent/llm_flow_extractor_test.go @@ -0,0 +1,28 @@ +package agent + +import ( + "strings" + "testing" +) + +func TestBuildActiveFlowExtractionPromptRequiresCanonicalFieldOutput(t *testing.T) { + systemPrompt, _ := buildActiveFlowExtractionPrompt( + "zh", + "skill_session", + "Active flow type: skill_session\nSkill: exchange_management\nAction: create", + "secret是abc123456", + "", + nil, + nil, + nil, + ) + + for _, want := range []string{ + "Treat this as semantic slot filling, not keyword copying.", + "always emit the canonical field keys from Allowed field spec JSON", + } { + if !strings.Contains(systemPrompt, want) { + t.Fatalf("expected system prompt to contain %q, got:\n%s", want, systemPrompt) + } + } +} diff --git a/agent/llm_skill_router.go b/agent/llm_skill_router.go index 3e53a699..2e9b943e 100644 --- a/agent/llm_skill_router.go +++ b/agent/llm_skill_router.go @@ -9,14 +9,19 @@ import ( "nofx/mcp" ) -type llmSkillRouteDecision struct { - Route string `json:"route"` - Skill string `json:"skill,omitempty"` - Action string `json:"action,omitempty"` - Filter string `json:"filter,omitempty"` +type unifiedTurnDecision struct { + TopicIntent string `json:"topic_intent,omitempty"` + BusinessAction string `json:"business_action,omitempty"` + TargetSkill string `json:"target_skill,omitempty"` + Tasks []WorkflowTask `json:"tasks,omitempty"` + TargetSnapshotID string `json:"target_snapshot_id,omitempty"` + ContextMode string `json:"context_mode,omitempty"` + ExtractedData map[string]any `json:"extracted_data,omitempty"` + ReplyToUser string `json:"reply_to_user,omitempty"` + Confidence float64 `json:"confidence,omitempty"` } -func (a *Agent) tryLLMSkillRoute(ctx context.Context, storeUserID string, userID int64, lang, text string, onEvent func(event, data string)) (string, bool, error) { +func (a *Agent) tryLLMIntentRoute(ctx context.Context, storeUserID string, userID int64, lang, text string, onEvent func(event, data string)) (string, bool, error) { if a.aiClient == nil { return "", false, nil } @@ -26,65 +31,101 @@ func (a *Agent) tryLLMSkillRoute(ctx context.Context, storeUserID string, userID return "", false, nil } - recentConversationCtx := a.buildRecentConversationContext(userID, text) - taskStateCtx := buildTaskStateContext(a.getTaskState(userID)) - executionState := normalizeExecutionState(a.getExecutionState(userID)) - executionJSON, _ := json.Marshal(executionState) - systemPrompt := `You are the lightweight skill router for NOFXi. -Decide whether the user's message should go to a structured skill or continue to the planner. -Return JSON only. Do not return markdown. + if decision, ok, err := a.routeTurnUnifiedWithLLM(ctx, userID, lang, text); err == nil && ok { + if answer, handled, execErr := a.executeUnifiedTurnDecision(ctx, storeUserID, userID, lang, text, decision, onEvent); handled || execErr != nil { + return answer, handled, execErr + } + } + return a.tryMinimalBrain(ctx, storeUserID, userID, lang, text, onEvent) +} -Use route "skill" only when the user intent is clear enough to send directly to one structured skill. -Use route "planner" for ambiguous, multi-step, open-ended, analytical, or diagnostic requests. +func parseUnifiedTurnDecision(raw string) (unifiedTurnDecision, error) { + raw = strings.TrimSpace(raw) + raw = strings.TrimPrefix(raw, "```json") + raw = strings.TrimPrefix(raw, "```") + raw = strings.TrimSuffix(raw, "```") + raw = strings.TrimSpace(raw) -Available skills: -- trader_management -- exchange_management -- model_management -- strategy_management -- trader_diagnosis -- exchange_diagnosis -- model_diagnosis -- strategy_diagnosis + var decision unifiedTurnDecision + if err := json.Unmarshal([]byte(raw), &decision); err == nil { + return normalizeUnifiedTurnDecision(decision), nil + } + start := strings.Index(raw, "{") + end := strings.LastIndex(raw, "}") + if start >= 0 && end > start { + if err := json.Unmarshal([]byte(raw[start:end+1]), &decision); err == nil { + return normalizeUnifiedTurnDecision(decision), nil + } + } + return unifiedTurnDecision{}, fmt.Errorf("invalid unified turn decision json") +} -For management skills, choose one atomic action from: -- query_list -- query_detail -- query_running -- create -- update_name -- update_bindings -- update_status -- update_endpoint -- update_config -- update_prompt -- delete -- start -- stop -- activate -- duplicate +func normalizeUnifiedTurnDecision(decision unifiedTurnDecision) unifiedTurnDecision { + decision.TopicIntent = strings.TrimSpace(strings.ToLower(decision.TopicIntent)) + decision.BusinessAction = strings.TrimSpace(strings.ToLower(decision.BusinessAction)) + decision.TargetSkill = strings.TrimSpace(decision.TargetSkill) + decision.TargetSnapshotID = strings.TrimSpace(decision.TargetSnapshotID) + decision.ContextMode = strings.TrimSpace(strings.ToLower(decision.ContextMode)) + decision.ReplyToUser = strings.TrimSpace(decision.ReplyToUser) + decision.Tasks = normalizeWorkflowDecomposition(workflowDecomposition{Tasks: decision.Tasks}).Tasks + if decision.ExtractedData == nil { + decision.ExtractedData = map[string]any{} + } + if decision.Confidence < 0 { + decision.Confidence = 0 + } + if decision.Confidence > 1 { + decision.Confidence = 1 + } + switch decision.TopicIntent { + case "continue", "continue_active": + decision.TopicIntent = "continue_active" + case "start_new", "resume_snapshot", "cancel", "instant_reply": + default: + decision.TopicIntent = "" + } + switch decision.BusinessAction { + case "direct_answer", "new_skill", "skill_tasks", "continue_skill", "planned_agent", "none": + default: + decision.BusinessAction = "" + } + switch decision.ContextMode { + case "use_current", "fresh_context", "resume_snapshot": + default: + decision.ContextMode = "use_current" + } + return decision +} -Set filter only when it is clearly implied by the user. Use values like: -- running_only -- stopped_only -- enabled_only -- disabled_only -- active_only -- default_only - -Rules: -- Prefer route "planner" when uncertain. -- Prefer route "planner" for market analysis, broad advice, multi-step troubleshooting, or requests that need synthesis. -- Prefer route "skill" for straightforward management requests like listing, creating, starting, stopping, enabling, disabling, renaming, or deleting known entities. -- Questions like "当前有运行中的trader吗" and "有没有 trader 在跑" are trader_management with action "query_running". -- Questions about one entity's details, config, parameters, or prompt should prefer action "query_detail". -- Do not use route "skill" for casual chat. -- Consider Recent conversation, Task state, and Execution state JSON before deciding. - -Return JSON with this exact shape: -{"route":"skill|planner","skill":"","action":"","filter":""}` - userPrompt := fmt.Sprintf("Language: %s\nUser message: %s\n\nRecent conversation:\n%s\n\nTask state:\n%s\n\nExecution state JSON:\n%s", lang, text, recentConversationCtx, taskStateCtx, string(executionJSON)) +func (d unifiedTurnDecision) reliable() bool { + if d.TopicIntent == "" || d.BusinessAction == "" { + return false + } + if d.Confidence > 0 && d.Confidence < 0.45 { + return false + } + switch d.BusinessAction { + case "direct_answer": + return strings.TrimSpace(d.ReplyToUser) != "" + case "new_skill": + if len(d.Tasks) > 0 { + return true + } + skill, _ := parseTargetSkill(d.TargetSkill) + return skill != "" + case "skill_tasks": + return len(d.Tasks) > 0 + case "continue_skill": + return d.TopicIntent == "continue_active" + case "planned_agent", "none": + return true + default: + return false + } +} +func (a *Agent) routeTurnUnifiedWithLLM(ctx context.Context, userID int64, lang, text string) (unifiedTurnDecision, bool, error) { + systemPrompt, userPrompt := a.buildUnifiedTurnRouterPrompt(userID, lang, text) stageCtx, cancel := withPlannerStageTimeout(ctx, directReplyTimeout) defer cancel() @@ -96,160 +137,469 @@ Return JSON with this exact shape: Ctx: stageCtx, }) if err != nil { + return unifiedTurnDecision{}, false, err + } + decision, err := parseUnifiedTurnDecision(raw) + if err != nil { + return unifiedTurnDecision{}, false, err + } + if !decision.reliable() { + return decision, false, nil + } + return decision, true, nil +} + +func (a *Agent) buildUnifiedTurnRouterPrompt(userID int64, lang, text string) (string, string) { + activeSkill := a.getSkillSession(userID) + activeTask, hasActiveTask := a.getActiveSkillSession(userID) + activeWorkflow := a.getWorkflowSession(userID) + activeExec := a.getExecutionState(userID) + pendingProposal, hasPendingProposal := a.getPendingProposalSession(userID) + previousAssistantReply := a.currentPendingHintText(userID) + snapshots := a.SnapshotManager(userID).List() + snapshotJSON, _ := json.Marshal(snapshots) + currentRefs := buildCurrentReferenceSummary(lang, a.semanticCurrentReferences(userID)) + recentConversation := a.buildRecentConversationContext(userID, text) + if strings.TrimSpace(recentConversation) == "" { + recentConversation = "(empty)" + } + activeFlowSummary := buildTopLevelActiveFlowSummary(lang, activeSkill, activeTask, hasActiveTask, activeWorkflow, activeExec, pendingProposal, hasPendingProposal) + if strings.TrimSpace(activeFlowSummary) == "" { + activeFlowSummary = "none" + } + + activeTaskDetails := "none" + if hasActiveTask { + activeTaskDetails = buildBrainUserPrompt(lang, text, previousAssistantReply, recentConversation, currentRefs, activeTask, true) + } + + systemPrompt := prependNOFXiAdvisorPreamble(`You are the unified turn router for NOFXi. +Return JSON only. No markdown. + +You must make ONE combined decision for this user turn: +1. Topic/context decision: continue active context, start fresh/new context, resume snapshot, cancel, or direct conversational reply. +2. Business routing decision: answer directly, start/continue a management skill, or hand off to the planner. +3. Context policy: whether downstream modules may use current references, must use fresh context, or must resume a snapshot. + +topic_intent values: +- "continue_active": user is answering or continuing the active flow +- "start_new": user starts or switches to a new task/topic +- "resume_snapshot": user wants to resume one suspended snapshot +- "cancel": user cancels the current active flow +- "instant_reply": user only greets, thanks, chats, or asks a direct explanation + +business_action values: +- "direct_answer": reply_to_user is the final answer; do not change state +- "skill_tasks": start one or more management/diagnosis skill tasks; tasks is required +- "new_skill": legacy single-skill route; target_skill is required if tasks is empty +- "continue_skill": continue the active skill session +- "planned_agent": hand off to the execution planner/tools +- "none": only valid with cancel when no more action is needed + +tasks format for skill_tasks: +- id: "task_1", "task_2", ... +- skill: one available skill name +- action: one available action +- request: the self-contained user-readable subtask +- depends_on: array of task ids, empty when independent + +target_skill format for legacy new_skill: +skill_name:action, for example "trader_management:create". +Available skills: +trader_management, exchange_management, model_management, strategy_management, +trader_diagnosis, exchange_diagnosis, model_diagnosis, strategy_diagnosis + +Available actions: +create, update, update_name, update_bindings, configure_strategy, configure_exchange, configure_model, +update_status, update_endpoint, update_config, update_prompt, delete, start, stop, activate, duplicate, +query_list, query_detail, query_running + +context_mode values: +- "use_current": downstream modules may use current references and recent context +- "fresh_context": the user is switching topic; do not use old current references to fill business fields +- "resume_snapshot": restore target_snapshot_id first + +Rules: +- This router decides what context downstream LLMs will see. Be conservative with stale references. +- Treat topic_intent as the primary decision. If the user is naturally responding to the active flow, choose topic_intent="continue_active", business_action="continue_skill", context_mode="use_current"; do not hand off a continuing active flow to planned_agent. +- When an active flow has a previous assistant question, proposal, or confirmation request, reason about what the user's message refers to in that context before deciding it is a new task. +- If the user clearly switches domain/entity, set topic_intent="start_new" and context_mode="fresh_context". +- If the user says "不是交易员,是策略" or similar corrections, use fresh_context. +- If the user answers the previous assistant question, choose continue_active. +- If the user only says "你好", "hi", "谢谢", "收到", choose instant_reply + direct_answer unless it clearly answers a pending task. +- If the user asks a read-only management query, prefer planned_agent unless the answer is already fully available in the provided context. +- Use skill_tasks for clear management tasks such as creating/updating/deleting/configuring trader/model/exchange/strategy. +- If the user request contains multiple management operations, include multiple tasks and depends_on where a later task needs an earlier result. +- If the request contains exactly one management operation, include exactly one task. +- Use planned_agent for multi-step, tool-heavy, market/account, diagnosis, or ambiguous tasks. +- For model_management, "provider" means AI vendor, never an exchange. +- Current references are context only. Do not copy them into extracted_data unless the user explicitly says this/current/that previous one. +- extracted_data must contain only concrete facts from the current user message. +- reply_to_user must be concise and in the user's language. +- confidence should reflect how safe it is to execute this decision without the old router fallback. + +Return JSON with this exact shape: +{"topic_intent":"continue_active|start_new|resume_snapshot|cancel|instant_reply","business_action":"direct_answer|skill_tasks|new_skill|continue_skill|planned_agent|none","target_skill":"","tasks":[{"id":"task_1","skill":"","action":"","request":"","depends_on":[]}],"target_snapshot_id":"","context_mode":"use_current|fresh_context|resume_snapshot","extracted_data":{},"reply_to_user":"","confidence":0.0}`) + + userPrompt := fmt.Sprintf("Language: %s\nUser message: %s\n\nPrevious assistant reply:\n%s\n\nCurrent reference summary:\n%s\n\nActive flow summary:\n%s\n\nSuspended snapshots JSON:\n%s\n\nRecent conversation:\n%s\n\nManagement domain primer:\n%s\n\nActive task details:\n%s\n", + lang, + text, + defaultIfEmpty(previousAssistantReply, "(empty)"), + currentRefs, + activeFlowSummary, + defaultIfEmpty(string(snapshotJSON), "[]"), + recentConversation, + defaultIfEmpty(buildManagementDomainPrimer(lang), "(empty)"), + activeTaskDetails, + ) + + return systemPrompt, userPrompt +} + +func (a *Agent) executeUnifiedTurnDecision(ctx context.Context, storeUserID string, userID int64, lang, text string, decision unifiedTurnDecision, onEvent func(event, data string)) (string, bool, error) { + if session, ok := a.activeStrategyCreateSession(userID); ok && strategyCreateConfirmationReply(text) { + return a.driveActiveSession(ctx, storeUserID, userID, lang, text, session, onEvent) + } + switch decision.TopicIntent { + case "cancel": + a.clearPendingProposalSession(userID) + if a.hasAnyActiveContext(userID) { + a.clearActiveSkillSession(userID) + a.clearAnyActiveContext(userID) + return a.maybeOfferParentTaskAfterCancel(userID, lang), true, nil + } + if decision.BusinessAction == "direct_answer" && decision.ReplyToUser != "" { + emitBrainReply(onEvent, decision.ReplyToUser) + a.recordSkillInteraction(userID, text, decision.ReplyToUser) + return decision.ReplyToUser, true, nil + } + return "", false, nil + case "resume_snapshot": + a.clearPendingProposalSession(userID) + if a.tryRestoreSuspendedTaskAfterSwitch(userID, text, decision.TargetSnapshotID) { + if decision.BusinessAction == "planned_agent" { + answer, err := a.runPlannedAgentWithContextMode(ctx, storeUserID, userID, lang, text, "use_current", onEvent) + return answer, true, err + } + return a.tryMinimalBrain(ctx, storeUserID, userID, lang, text, onEvent) + } return "", false, nil } - decision, err := parseLLMSkillRouteDecision(raw) - if err != nil || decision.Route != "skill" { - return "", false, nil + if decision.TopicIntent == "continue_active" { + if _, hasProposal := a.getPendingProposalSession(userID); hasProposal && !a.hasAnyActiveContext(userID) { + return a.handlePendingProposalResponse(ctx, storeUserID, userID, lang, text, onEvent) + } + if activeSession, hasActive := a.getActiveSkillSession(userID); hasActive { + decision.ExtractedData = filterExtractedDataForActiveSession(activeSession, decision.ExtractedData, lang) + mergeExtractedData(&activeSession, decision.ExtractedData) + return a.driveActiveSession(ctx, storeUserID, userID, lang, text, activeSession, onEvent) + } + if a.hasAnyActiveContext(userID) { + return a.tryStatePriorityPath(ctx, storeUserID, userID, lang, text, onEvent) + } } - outcome, ok := a.executeLLMSkillRoute(storeUserID, userID, lang, text, decision) + switch decision.BusinessAction { + case "direct_answer": + if decision.ReplyToUser == "" { + return "", false, nil + } + if decision.TopicIntent == "instant_reply" && a.hasAnyActiveContext(userID) { + return a.replyToActiveFlowInstantReply(ctx, userID, lang, text, onEvent), true, nil + } + if guarded, blocked := guardUnsupportedAsyncPromise(lang, decision.ReplyToUser); blocked { + decision.ReplyToUser = guarded + } + emitBrainReply(onEvent, decision.ReplyToUser) + a.recordSkillInteraction(userID, text, decision.ReplyToUser) + a.runPostResponseMaintenanceAsync(userID) + return decision.ReplyToUser, true, nil + case "new_skill": + if len(decision.Tasks) > 0 { + return a.executeUnifiedSkillTasks(ctx, storeUserID, userID, lang, text, decision, onEvent) + } + skill, action := parseTargetSkill(decision.TargetSkill) + if skill == "" { + return "", false, nil + } + if a.hasAnyActiveContext(userID) && decision.ContextMode == "fresh_context" { + if !a.suspendActiveContexts(userID, lang) { + a.clearSkillSession(userID) + a.clearWorkflowSession(userID) + a.clearExecutionState(userID) + } + a.clearActiveSkillSession(userID) + } + session := newActiveSkillSession(userID, skill, action) + session.Goal = strings.TrimSpace(text) + decision.ExtractedData = filterExtractedDataForActiveSession(session, decision.ExtractedData, lang) + mergeExtractedData(&session, decision.ExtractedData) + return a.driveActiveSession(ctx, storeUserID, userID, lang, text, session, onEvent) + case "skill_tasks": + return a.executeUnifiedSkillTasks(ctx, storeUserID, userID, lang, text, decision, onEvent) + case "continue_skill": + activeSession, hasActive := a.getActiveSkillSession(userID) + if !hasActive { + return "", false, nil + } + decision.ExtractedData = filterExtractedDataForActiveSession(activeSession, decision.ExtractedData, lang) + mergeExtractedData(&activeSession, decision.ExtractedData) + return a.driveActiveSession(ctx, storeUserID, userID, lang, text, activeSession, onEvent) + case "planned_agent": + if session, ok := a.activeStrategyCreateSession(userID); ok { + return a.driveActiveSession(ctx, storeUserID, userID, lang, text, session, onEvent) + } + contextMode := decision.ContextMode + if contextMode == "resume_snapshot" { + contextMode = "use_current" + } + answer, err := a.runPlannedAgentWithContextMode(ctx, storeUserID, userID, lang, text, contextMode, onEvent) + return answer, true, err + case "none": + return "", false, nil + default: + return "", false, nil + } +} + +func (a *Agent) executeUnifiedSkillTasks(ctx context.Context, storeUserID string, userID int64, lang, text string, decision unifiedTurnDecision, onEvent func(event, data string)) (string, bool, error) { + tasks := normalizeWorkflowDecomposition(workflowDecomposition{Tasks: decision.Tasks}).Tasks + if len(tasks) == 0 { + return "", false, nil + } + if task, ok := strategyCreateWorkflowTask(tasks); ok { + if a.hasAnyActiveContext(userID) && decision.ContextMode == "fresh_context" { + if !a.suspendActiveContexts(userID, lang) { + a.clearSkillSession(userID) + a.clearWorkflowSession(userID) + a.clearExecutionState(userID) + } + a.clearActiveSkillSession(userID) + } + a.clearWorkflowSession(userID) + a.clearExecutionState(userID) + session := newActiveSkillSession(userID, task.Skill, task.Action) + session.Goal = defaultIfEmpty(strings.TrimSpace(task.Request), strings.TrimSpace(text)) + decision.ExtractedData = filterExtractedDataForActiveSession(session, decision.ExtractedData, lang) + mergeExtractedData(&session, decision.ExtractedData) + return a.driveActiveSession(ctx, storeUserID, userID, lang, defaultIfEmpty(task.Request, text), session, onEvent) + } + if a.hasAnyActiveContext(userID) && decision.ContextMode == "fresh_context" { + if !a.suspendActiveContexts(userID, lang) { + a.clearSkillSession(userID) + a.clearWorkflowSession(userID) + a.clearExecutionState(userID) + } + a.clearActiveSkillSession(userID) + } + if len(tasks) == 1 { + task := tasks[0] + session := newActiveSkillSession(userID, task.Skill, task.Action) + session.Goal = defaultIfEmpty(strings.TrimSpace(task.Request), strings.TrimSpace(text)) + decision.ExtractedData = filterExtractedDataForActiveSession(session, decision.ExtractedData, lang) + mergeExtractedData(&session, decision.ExtractedData) + return a.driveActiveSession(ctx, storeUserID, userID, lang, defaultIfEmpty(task.Request, text), session, onEvent) + } + session := normalizeWorkflowSession(WorkflowSession{ + UserID: userID, + OriginalRequest: strings.TrimSpace(text), + Tasks: tasks, + }) + if len(session.Tasks) == 0 { + return "", false, nil + } + a.saveWorkflowSession(userID, session) + return a.maybeAdvanceWorkflow(ctx, storeUserID, userID, lang, session, onEvent) +} + +func strategyCreateWorkflowTask(tasks []WorkflowTask) (WorkflowTask, bool) { + for _, task := range tasks { + if strings.TrimSpace(task.Skill) == "strategy_management" && strings.TrimSpace(task.Action) == "create" { + return task, true + } + } + return WorkflowTask{}, false +} + +func buildTopLevelActiveFlowSummary(lang string, skill skillSession, activeTask ActiveSkillSession, hasActiveTask bool, workflow WorkflowSession, state ExecutionState, pendingProposal PendingProposalSession, hasPendingProposal bool) string { + lines := make([]string, 0, 8) + if hasActiveTask { + lines = append(lines, fmt.Sprintf("Active task session: %s / %s / phase=%s", activeTask.SkillName, activeTask.ActionName, defaultIfEmpty(activeTask.LegacyPhase, "collecting"))) + if strings.TrimSpace(activeTask.Goal) != "" { + lines = append(lines, "Active task goal: "+strings.TrimSpace(activeTask.Goal)) + } + if activeTask.PendingHint != nil && strings.TrimSpace(activeTask.PendingHint.Prompt) != "" { + lines = append(lines, "Active task pending hint: "+strings.TrimSpace(activeTask.PendingHint.Prompt)) + } + if len(activeTask.CollectedFields) > 0 { + fieldsJSON, _ := json.Marshal(activeTask.CollectedFields) + lines = append(lines, "Active task collected_fields: "+string(fieldsJSON)) + } + } + if strings.TrimSpace(skill.Name) != "" { + lines = append(lines, fmt.Sprintf("Active skill session: %s / %s / phase=%s", skill.Name, skill.Action, defaultIfEmpty(skill.Phase, "collecting"))) + if routing := buildSkillActionRoutingSummary(lang, skill); routing != "" { + lines = append(lines, routing) + } + } + if hasActiveWorkflowSession(workflow) { + lines = append(lines, fmt.Sprintf("Active workflow: original_request=%s pending_tasks=%d", workflow.OriginalRequest, countPendingWorkflowTasks(workflow))) + } + if hasActiveExecutionState(state) { + lines = append(lines, fmt.Sprintf("Active execution state: status=%s goal=%s", state.Status, state.Goal)) + if state.Waiting != nil && strings.TrimSpace(state.Waiting.Question) != "" { + lines = append(lines, "Waiting question: "+strings.TrimSpace(state.Waiting.Question)) + } + } + if hasPendingProposal { + lines = append(lines, "Pending assistant proposal awaiting user response.") + if strings.TrimSpace(pendingProposal.SourceUserText) != "" { + lines = append(lines, "Proposal source request: "+strings.TrimSpace(pendingProposal.SourceUserText)) + } + lines = append(lines, "Proposal text: "+strings.TrimSpace(pendingProposal.ProposalText)) + } + return strings.Join(lines, "\n") +} + +func (a *Agent) handlePendingProposalResponse(ctx context.Context, storeUserID string, userID int64, lang, text string, onEvent func(event, data string)) (string, bool, error) { + proposal, ok := a.getPendingProposalSession(userID) if !ok { return "", false, nil } - - review, err := a.reviewTaskCompletion(ctx, userID, lang, text, outcome) - if err != nil { - if outcome.Status == skillOutcomeRecoverableError || outcome.Status == skillOutcomeFatalError || outcome.Status == skillOutcomeNotHandled { - return "", false, nil - } - review = taskReviewDecision{Route: "complete", Answer: outcome.UserMessage} + answer, err := a.runPlannedAgent(ctx, storeUserID, userID, lang, fmt.Sprintf("The user is replying to the assistant's previous proposal.\n\nOriginal user request:\n%s\n\nPrevious assistant proposal:\n%s\n\nCurrent user reply:\n%s", proposal.SourceUserText, proposal.ProposalText, text), onEvent) + if err == nil && strings.TrimSpace(answer) != "" { + a.clearPendingProposalSession(userID) } - if review.Route == "replan" { - answer, planErr := a.runPlannedAgent(ctx, storeUserID, userID, lang, fmt.Sprintf("Original user request:\n%s\n\nPrevious skill outcome JSON:\n%s", text, mustMarshalJSON(outcome)), onEvent) - return answer, true, planErr - } - - answer := strings.TrimSpace(review.Answer) - if answer == "" { - answer = strings.TrimSpace(outcome.UserMessage) - } - if answer == "" { - return "", false, nil - } - - a.recordSkillInteraction(userID, text, answer) - if onEvent != nil { - label := "llm_skill_route" - if decision.Skill != "" { - label += ":" + decision.Skill - } - if decision.Action != "" { - label += ":" + decision.Action - } - onEvent(StreamEventTool, label) - onEvent(StreamEventDelta, answer) - } - return answer, true, nil + return answer, true, err } -func parseLLMSkillRouteDecision(raw string) (llmSkillRouteDecision, error) { - raw = strings.TrimSpace(raw) - raw = strings.TrimPrefix(raw, "```json") - raw = strings.TrimPrefix(raw, "```") - raw = strings.TrimSuffix(raw, "```") - raw = strings.TrimSpace(raw) - - var decision llmSkillRouteDecision - if err := json.Unmarshal([]byte(raw), &decision); err == nil { - return normalizeLLMSkillRouteDecision(decision), nil - } - start := strings.Index(raw, "{") - end := strings.LastIndex(raw, "}") - if start >= 0 && end > start { - if err := json.Unmarshal([]byte(raw[start:end+1]), &decision); err == nil { - return normalizeLLMSkillRouteDecision(decision), nil +func countPendingWorkflowTasks(session WorkflowSession) int { + count := 0 + for _, task := range session.Tasks { + switch task.Status { + case workflowTaskPending, workflowTaskRunning: + count++ } } - return llmSkillRouteDecision{}, fmt.Errorf("invalid llm skill route json") + return count } -func normalizeLLMSkillRouteDecision(decision llmSkillRouteDecision) llmSkillRouteDecision { - decision.Route = strings.TrimSpace(strings.ToLower(decision.Route)) - decision.Skill = strings.TrimSpace(strings.ToLower(decision.Skill)) - decision.Filter = strings.TrimSpace(strings.ToLower(decision.Filter)) - if decision.Action == "query" && decision.Filter == "running_only" && decision.Skill == "trader_management" { - decision.Action = "query_running" - } else { - decision.Action = normalizeAtomicSkillAction(decision.Skill, decision.Action) +func buildCurrentReferenceSummary(lang string, refs *CurrentReferences) string { + if refs == nil { + if lang == "zh" { + return "- 当前没有明确锁定的操作对象。" + } + return "- No current entity references are locked yet." } - return decision -} -func (a *Agent) executeLLMSkillRoute(storeUserID string, userID int64, lang, text string, decision llmSkillRouteDecision) (skillOutcome, bool) { - session := skillSession{Name: decision.Skill, Action: decision.Action} - - switch decision.Skill { - case "trader_management": - if decision.Action == "create" { - answer, handled := a.handleCreateTraderSkill(storeUserID, userID, lang, text, session) - if !handled { - return skillOutcome{}, false + lines := make([]string, 0, 4) + appendLine := func(kind string, ref *EntityReference) { + if ref == nil { + return + } + name := strings.TrimSpace(defaultIfEmpty(ref.Name, ref.ID)) + if name == "" { + return + } + source := formatReferenceSourceLabel(lang, ref.Source) + if lang == "zh" { + line := fmt.Sprintf("- 当前%s: %s", referenceKindDisplayName(lang, kind), name) + if source != "" { + line += fmt.Sprintf("(来源: %s)", source) } - return inferSkillOutcome(decision.Skill, decision.Action, answer, a.getSkillSession(userID), skillDataForAction(storeUserID, decision.Skill, decision.Action, a)), true + if strings.TrimSpace(ref.ID) != "" && strings.TrimSpace(ref.ID) != name { + line += fmt.Sprintf(" [id=%s]", ref.ID) + } + lines = append(lines, line) + return } - answer, handled := a.handleTraderManagementSkill(storeUserID, userID, lang, text, session) - if handled && decision.Action == "query_running" { - answer = applyTraderQueryFilter(lang, answer, a.toolListTraders(storeUserID), "running_only") + + line := fmt.Sprintf("- Current %s: %s", referenceKindDisplayName(lang, kind), name) + if source != "" { + line += fmt.Sprintf(" (source: %s)", source) } - if !handled { - return skillOutcome{}, false + if strings.TrimSpace(ref.ID) != "" && strings.TrimSpace(ref.ID) != name { + line += fmt.Sprintf(" [id=%s]", ref.ID) } - return inferSkillOutcome(decision.Skill, decision.Action, answer, a.getSkillSession(userID), skillDataForAction(storeUserID, decision.Skill, decision.Action, a)), true - case "exchange_management": - answer, handled := a.handleExchangeManagementSkill(storeUserID, userID, lang, text, session) - if !handled { - return skillOutcome{}, false - } - return inferSkillOutcome(decision.Skill, decision.Action, answer, a.getSkillSession(userID), skillDataForAction(storeUserID, decision.Skill, decision.Action, a)), true - case "model_management": - answer, handled := a.handleModelManagementSkill(storeUserID, userID, lang, text, session) - if !handled { - return skillOutcome{}, false - } - return inferSkillOutcome(decision.Skill, decision.Action, answer, a.getSkillSession(userID), skillDataForAction(storeUserID, decision.Skill, decision.Action, a)), true - case "strategy_management": - answer, handled := a.handleStrategyManagementSkill(storeUserID, userID, lang, text, session) - if !handled { - return skillOutcome{}, false - } - return inferSkillOutcome(decision.Skill, decision.Action, answer, a.getSkillSession(userID), skillDataForAction(storeUserID, decision.Skill, decision.Action, a)), true - case "model_diagnosis": - return skillOutcome{ - Skill: decision.Skill, - Action: defaultIfEmpty(decision.Action, "diagnose"), - Status: skillOutcomeSuccess, - GoalAchieved: true, - UserMessage: a.handleModelDiagnosisSkill(storeUserID, lang, text), - }, true - case "exchange_diagnosis": - return skillOutcome{ - Skill: decision.Skill, - Action: defaultIfEmpty(decision.Action, "diagnose"), - Status: skillOutcomeSuccess, - GoalAchieved: true, - UserMessage: a.handleExchangeDiagnosisSkill(storeUserID, lang, text), - }, true - case "trader_diagnosis": - return skillOutcome{ - Skill: decision.Skill, - Action: defaultIfEmpty(decision.Action, "diagnose"), - Status: skillOutcomeSuccess, - GoalAchieved: true, - UserMessage: a.handleTraderDiagnosisSkill(storeUserID, lang, text), - }, true - case "strategy_diagnosis": - return skillOutcome{ - Skill: decision.Skill, - Action: defaultIfEmpty(decision.Action, "diagnose"), - Status: skillOutcomeSuccess, - GoalAchieved: true, - UserMessage: a.handleStrategyDiagnosisSkill(storeUserID, lang, text), - }, true - default: - return skillOutcome{}, false + lines = append(lines, line) } + + appendLine("strategy", refs.Strategy) + appendLine("trader", refs.Trader) + appendLine("model", refs.Model) + appendLine("exchange", refs.Exchange) + + if len(lines) == 0 { + if lang == "zh" { + return "- 当前没有明确锁定的操作对象。" + } + return "- No current entity references are locked yet." + } + return strings.Join(lines, "\n") +} + +func formatReferenceSourceLabel(lang, source string) string { + source = strings.TrimSpace(source) + if source == "" { + return "" + } + if lang == "zh" { + switch source { + case "user_mention": + return "用户提及" + case "tool_output": + return "工具结果" + case "inferred_from_context": + return "上下文推断" + default: + return source + } + } + switch source { + case "user_mention": + return "user mention" + case "tool_output": + return "tool output" + case "inferred_from_context": + return "context inference" + default: + return source + } +} + +func hasAnyActiveContext(a *Agent, userID int64) bool { + if a == nil { + return false + } + if _, ok := a.getActiveSkillSession(userID); ok { + return true + } + return a.hasActiveSkillSession(userID) || hasActiveWorkflowSession(a.getWorkflowSession(userID)) || hasActiveExecutionState(a.getExecutionState(userID)) +} + +func (a *Agent) clearAnyActiveContext(userID int64) bool { + cleared := false + if _, ok := a.getActiveSkillSession(userID); ok { + a.clearActiveSkillSession(userID) + cleared = true + } + if a.hasActiveSkillSession(userID) { + a.clearSkillSession(userID) + cleared = true + } + if hasActiveWorkflowSession(a.getWorkflowSession(userID)) { + a.clearWorkflowSession(userID) + cleared = true + } + if hasActiveExecutionState(a.getExecutionState(userID)) { + a.clearExecutionState(userID) + cleared = true + } + if cleared { + a.SnapshotManager(userID).Clear() + } + return cleared } func skillDataForAction(storeUserID, skill, action string, a *Agent) map[string]any { diff --git a/agent/market_snapshot_test.go b/agent/market_snapshot_test.go new file mode 100644 index 00000000..9828b862 --- /dev/null +++ b/agent/market_snapshot_test.go @@ -0,0 +1,107 @@ +package agent + +import ( + "encoding/json" + "io" + "net/http" + "strings" + "testing" +) + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func TestToolGetMarketSnapshotReturnsRealtimeAnalysisContext(t *testing.T) { + prevBaseURL := binanceFuturesAPIBaseURL + prevClient := marketDataHTTPClient + binanceFuturesAPIBaseURL = "https://example.test" + marketDataHTTPClient = &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + body := "" + switch { + case strings.HasPrefix(req.URL.Path, "/fapi/v1/ticker/24hr"): + body = `{"symbol":"BTCUSDT","lastPrice":"65000","priceChange":"1200","priceChangePercent":"1.88","highPrice":"66000","lowPrice":"63800","volume":"12345","quoteVolume":"800000000","count":98765}` + case strings.HasPrefix(req.URL.Path, "/fapi/v1/premiumIndex"): + body = `{"symbol":"BTCUSDT","markPrice":"65010","indexPrice":"64990","lastFundingRate":"0.00010000","nextFundingTime":1710000000000}` + case strings.HasPrefix(req.URL.Path, "/fapi/v1/openInterest"): + body = `{"symbol":"BTCUSDT","openInterest":"45678.9","time":1710000000000}` + case strings.HasPrefix(req.URL.Path, "/fapi/v1/klines"): + body = `[[1710000000000,"64000","65100","63900","64500","100",1710000899999],[1710000900000,"64500","65500","64400","65000","120",1710001799999]]` + default: + body = `{"error":"not found"}` + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + }, nil + }), + } + defer func() { + binanceFuturesAPIBaseURL = prevBaseURL + marketDataHTTPClient = prevClient + }() + + a := New(nil, nil, DefaultConfig(), nil) + raw := a.toolGetMarketSnapshot(`{"symbol":"BTC","interval":"15m","limit":2}`) + + var resp struct { + Symbol string `json:"symbol"` + Price float64 `json:"price"` + Ticker24h struct { + PriceChangePercent float64 `json:"price_change_percent"` + } `json:"ticker_24h"` + PerpMetrics struct { + FundingRate float64 `json:"funding_rate"` + OpenInterest float64 `json:"open_interest"` + } `json:"perp_metrics"` + KlineSnapshot struct { + Interval string `json:"interval"` + Limit int `json:"limit"` + PeriodChangePercent float64 `json:"period_change_percent"` + RecentKlines []map[string]any `json:"recent_klines"` + } `json:"kline_snapshot"` + Error string `json:"error"` + } + if err := json.Unmarshal([]byte(raw), &resp); err != nil { + t.Fatalf("failed to parse tool response: %v\nraw=%s", err, raw) + } + if resp.Error != "" { + t.Fatalf("unexpected tool error: %s", resp.Error) + } + if resp.Symbol != "BTCUSDT" { + t.Fatalf("expected normalized symbol BTCUSDT, got %s", resp.Symbol) + } + if resp.Price != 65000 { + t.Fatalf("expected price 65000, got %v", resp.Price) + } + if resp.Ticker24h.PriceChangePercent != 1.88 { + t.Fatalf("expected 24h change 1.88, got %v", resp.Ticker24h.PriceChangePercent) + } + if resp.PerpMetrics.FundingRate != 0.0001 { + t.Fatalf("expected funding rate 0.0001, got %v", resp.PerpMetrics.FundingRate) + } + if resp.PerpMetrics.OpenInterest != 45678.9 { + t.Fatalf("expected open interest 45678.9, got %v", resp.PerpMetrics.OpenInterest) + } + if resp.KlineSnapshot.Interval != "15m" || resp.KlineSnapshot.Limit != 2 { + t.Fatalf("unexpected kline snapshot metadata: %+v", resp.KlineSnapshot) + } + if len(resp.KlineSnapshot.RecentKlines) != 2 { + t.Fatalf("expected 2 klines, got %d", len(resp.KlineSnapshot.RecentKlines)) + } + if resp.KlineSnapshot.PeriodChangePercent <= 0 { + t.Fatalf("expected positive period change, got %v", resp.KlineSnapshot.PeriodChangePercent) + } +} + +func TestToolGetMarketSnapshotRejectsStockSymbols(t *testing.T) { + a := New(nil, nil, DefaultConfig(), nil) + raw := a.toolGetMarketSnapshot(`{"symbol":"AAPL"}`) + if !strings.Contains(raw, "currently supports crypto symbols only") { + t.Fatalf("expected stock rejection, got: %s", raw) + } +} diff --git a/agent/memory.go b/agent/memory.go index 4b274648..7ffe5c36 100644 --- a/agent/memory.go +++ b/agent/memory.go @@ -11,8 +11,9 @@ import ( ) const ( - recentConversationRounds = 3 + recentConversationRounds = 6 recentConversationMessages = recentConversationRounds * 2 + chatHistoryMaxTurns = recentConversationMessages * 2 // fallback cap when compression is unavailable taskStateSummaryTokenLimit = 1200 shortTermCompressThreshold = 900 incrementalTaskStateMessages = 6 diff --git a/agent/memory_test.go b/agent/memory_test.go deleted file mode 100644 index ed772be9..00000000 --- a/agent/memory_test.go +++ /dev/null @@ -1,132 +0,0 @@ -package agent - -import ( - "context" - "log/slog" - "path/filepath" - "strings" - "testing" - "time" - - "nofx/mcp" - "nofx/store" -) - -type fakeAIClient struct { - callCount int -} - -func (f *fakeAIClient) SetAPIKey(string, string, string) {} -func (f *fakeAIClient) SetTimeout(time.Duration) {} -func (f *fakeAIClient) CallWithMessages(string, string) (string, error) { - return "", nil -} -func (f *fakeAIClient) CallWithRequest(req *mcp.Request) (string, error) { - f.callCount++ - return `{"current_goal":"continue setup","active_flow":"onboarding","open_loops":["finish trader setup after external exchange/model configuration is ready"],"important_facts":["user selected OKX"],"last_decision":{"action":"paused setup","reason":"user asked a market question","still_valid":true},"updated_at":"2026-04-01T00:00:00Z"}`, nil -} -func (f *fakeAIClient) CallWithRequestStream(req *mcp.Request, onChunk func(string)) (string, error) { - return "", nil -} -func (f *fakeAIClient) CallWithRequestFull(req *mcp.Request) (*mcp.LLMResponse, error) { - return nil, nil -} - -func TestMaybeCompressHistoryKeepsRecentThreeRounds(t *testing.T) { - st, err := store.New(filepath.Join(t.TempDir(), "nofxi-test.db")) - if err != nil { - t.Fatalf("store.New() error = %v", err) - } - - fakeClient := &fakeAIClient{} - a := &Agent{ - store: st, - logger: slog.Default(), - history: newChatHistory(100), - aiClient: fakeClient, - } - - userID := int64(42) - payload := strings.Repeat("BTC ETH market context ", 20) - for i := 0; i < 6; i++ { - a.history.Add(userID, "user", "user turn #"+string(rune('0'+i))+" "+payload) - a.history.Add(userID, "assistant", "assistant turn #"+string(rune('0'+i))+" "+payload) - } - - a.maybeCompressHistory(context.Background(), userID) - - msgs := a.history.Get(userID) - if len(msgs) != recentConversationMessages { - t.Fatalf("expected %d recent messages, got %d", recentConversationMessages, len(msgs)) - } - if fakeClient.callCount != 1 { - t.Fatalf("expected summarizer to be called once, got %d", fakeClient.callCount) - } - - state := a.getTaskState(userID) - if state.CurrentGoal != "continue setup" { - t.Fatalf("expected persisted task state goal, got %#v", state) - } - if state.LastDecision == nil || state.LastDecision.Action != "paused setup" { - t.Fatalf("expected persisted last_decision, got %#v", state.LastDecision) - } - if len(state.OpenLoops) != 1 || state.OpenLoops[0] != "finish trader setup after external exchange/model configuration is ready" { - t.Fatalf("expected high-level open loop, got %#v", state.OpenLoops) - } - if strings.Contains(msgs[0].Content, "#0") { - t.Fatalf("expected oldest round to be compressed away, first recent message = %q", msgs[0].Content) - } - if !strings.Contains(msgs[0].Content, "#3") { - t.Fatalf("expected recent window to start from round #3, got %q", msgs[0].Content) - } - if !strings.Contains(msgs[len(msgs)-1].Content, "#5") { - t.Fatalf("expected latest round to remain in short-term history, got %q", msgs[len(msgs)-1].Content) - } -} - -func TestNormalizeTaskStateDropsExecutionLevelOpenLoops(t *testing.T) { - state := normalizeTaskState(TaskState{ - OpenLoops: []string{ - "wait for API secret", - "call get_exchange_configs", - "finish trader setup after external configuration is ready", - }, - }) - - if len(state.OpenLoops) != 1 { - t.Fatalf("expected only one high-level open loop to remain, got %#v", state.OpenLoops) - } - if state.OpenLoops[0] != "finish trader setup after external configuration is ready" { - t.Fatalf("unexpected open loop after normalization: %#v", state.OpenLoops) - } -} - -func TestMaybeUpdateTaskStateIncrementallyPersistsShortConversationFacts(t *testing.T) { - st, err := store.New(filepath.Join(t.TempDir(), "nofxi-test.db")) - if err != nil { - t.Fatalf("store.New() error = %v", err) - } - - fakeClient := &fakeAIClient{} - a := &Agent{ - store: st, - logger: slog.Default(), - history: newChatHistory(100), - aiClient: fakeClient, - } - - userID := int64(7) - a.history.Add(userID, "user", "我是在运行测试1交易员时遇到的,错误是运行时出现的") - a.history.Add(userID, "assistant", "我会继续排查测试1交易员的运行时错误") - - a.maybeUpdateTaskStateIncrementally(context.Background(), userID) - - if fakeClient.callCount != 1 { - t.Fatalf("expected incremental summarizer to be called once, got %d", fakeClient.callCount) - } - - state := a.getTaskState(userID) - if state.CurrentGoal != "continue setup" { - t.Fatalf("expected incrementally persisted task state, got %#v", state) - } -} diff --git a/agent/model_create_flow_test.go b/agent/model_create_flow_test.go new file mode 100644 index 00000000..29bf4c34 --- /dev/null +++ b/agent/model_create_flow_test.go @@ -0,0 +1,75 @@ +package agent + +import ( + "log/slog" + "path/filepath" + "strings" + "testing" + + "nofx/store" +) + +func TestHandleModelCreateSkillAsksProviderFirstWithClaw402Recommendation(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "agent-model-create.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + + a := New(nil, st, DefaultConfig(), slog.Default()) + reply := a.handleModelCreateSkill("default", 42, "zh", "请帮我创建一个模型", skillSession{}) + + for _, want := range []string{ + "还缺这些字段:模型提供商", + "可选模型 provider", + "推荐 `claw402`", + "并列可选", + "按次付费", + "Base USDC 钱包支付", + "直接创建 Base 钱包", + "直接扫码充值/支付", + } { + if !strings.Contains(reply, want) { + t.Fatalf("expected reply to contain %q, got: %s", want, reply) + } + } + for _, unexpected := range []string{ + "还缺这些字段:模型提供商、API Key", + "还缺这些字段:模型提供商、钱包私钥", + "还缺这些字段:模型提供商、wallet private key", + } { + if strings.Contains(reply, unexpected) { + t.Fatalf("provider-first reply should not ask for credentials yet: %s", reply) + } + } +} + +func TestHandleModelCreateSkillUsesCollectedClaw402PrivateKey(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "agent-model-create-claw402.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + + a := New(nil, st, DefaultConfig(), slog.Default()) + session := skillSession{ + Name: "model_management", + Action: "create", + Phase: "collecting", + Fields: map[string]string{ + "provider": "claw402", + "name": "Claw402 (Base USDC)", + "api_key": "0x205d759b80bae1afa31a36c4afaeec0b10378c1c55e3363bcde5a1db75c747ca", + "custom_model_name": "deepseek", + }, + } + + reply := a.handleModelCreateSkill("default", 42, "zh", "继续", session) + + if strings.Contains(reply, "还缺这些字段:钱包私钥") { + t.Fatalf("expected bare private key to be accepted, got: %s", reply) + } + if !strings.Contains(reply, "我先整理了一份模型配置草稿") { + t.Fatalf("expected draft summary after accepting private key, got: %s", reply) + } +} diff --git a/agent/model_provider_catalog.go b/agent/model_provider_catalog.go new file mode 100644 index 00000000..a42eedc6 --- /dev/null +++ b/agent/model_provider_catalog.go @@ -0,0 +1,242 @@ +package agent + +import ( + "fmt" + "strings" +) + +type modelProviderSpec struct { + ID string + DisplayName string + DefaultModel string + CredentialLabelZH string + CredentialLabelEN string + SupportsCustomAPIURL bool + SupportsCustomModel bool + UsesWalletCredential bool + Recommended bool + RecommendedModelHints []string +} + +func supportedModelProviders() []modelProviderSpec { + return []modelProviderSpec{ + {ID: "deepseek", DisplayName: "DeepSeek", DefaultModel: "deepseek-chat", CredentialLabelZH: "API Key", CredentialLabelEN: "API key", SupportsCustomAPIURL: true, SupportsCustomModel: true}, + {ID: "qwen", DisplayName: "Qwen", DefaultModel: "qwen3-max", CredentialLabelZH: "API Key", CredentialLabelEN: "API key", SupportsCustomAPIURL: true, SupportsCustomModel: true}, + {ID: "openai", DisplayName: "OpenAI", DefaultModel: "gpt-5.1", CredentialLabelZH: "API Key", CredentialLabelEN: "API key", SupportsCustomAPIURL: true, SupportsCustomModel: true}, + {ID: "claude", DisplayName: "Claude", DefaultModel: "claude-opus-4-6", CredentialLabelZH: "API Key", CredentialLabelEN: "API key", SupportsCustomAPIURL: true, SupportsCustomModel: true}, + {ID: "gemini", DisplayName: "Google Gemini", DefaultModel: "gemini-3-pro-preview", CredentialLabelZH: "API Key", CredentialLabelEN: "API key", SupportsCustomAPIURL: true, SupportsCustomModel: true}, + {ID: "grok", DisplayName: "Grok (xAI)", DefaultModel: "grok-3-latest", CredentialLabelZH: "API Key", CredentialLabelEN: "API key", SupportsCustomAPIURL: true, SupportsCustomModel: true}, + {ID: "kimi", DisplayName: "Kimi (Moonshot)", DefaultModel: "moonshot-v1-auto", CredentialLabelZH: "API Key", CredentialLabelEN: "API key", SupportsCustomAPIURL: true, SupportsCustomModel: true}, + {ID: "minimax", DisplayName: "MiniMax", DefaultModel: "MiniMax-M2.5", CredentialLabelZH: "API Key", CredentialLabelEN: "API key", SupportsCustomAPIURL: true, SupportsCustomModel: true}, + { + ID: "claw402", + DisplayName: "Claw402 (Base USDC)", + DefaultModel: "deepseek", + CredentialLabelZH: "钱包私钥", + CredentialLabelEN: "wallet private key", + SupportsCustomAPIURL: false, + SupportsCustomModel: true, + UsesWalletCredential: true, + Recommended: true, + RecommendedModelHints: []string{"deepseek", "glm-5", "gpt-5.4", "claude-opus", "qwen-max", "grok-4.1"}, + }, + { + ID: "blockrun-base", + DisplayName: "BlockRun (Base Wallet)", + DefaultModel: "auto", + CredentialLabelZH: "钱包私钥", + CredentialLabelEN: "wallet private key", + SupportsCustomAPIURL: false, + SupportsCustomModel: false, + UsesWalletCredential: true, + }, + { + ID: "blockrun-sol", + DisplayName: "BlockRun (Solana Wallet)", + DefaultModel: "auto", + CredentialLabelZH: "钱包私钥", + CredentialLabelEN: "wallet private key", + SupportsCustomAPIURL: false, + SupportsCustomModel: false, + UsesWalletCredential: true, + }, + } +} + +func modelProviderSpecByID(provider string) (modelProviderSpec, bool) { + provider = strings.ToLower(strings.TrimSpace(provider)) + for _, spec := range supportedModelProviders() { + if spec.ID == provider { + return spec, true + } + } + return modelProviderSpec{}, false +} + +func supportedModelProviderIDs() []string { + specs := supportedModelProviders() + out := make([]string, 0, len(specs)) + for _, spec := range specs { + out = append(out, spec.ID) + } + return out +} + +func defaultModelNameForProvider(provider string) string { + spec, ok := modelProviderSpecByID(provider) + if !ok { + return "" + } + return strings.TrimSpace(spec.DefaultModel) +} + +func defaultModelConfigName(provider string) string { + spec, ok := modelProviderSpecByID(provider) + if !ok { + provider = strings.TrimSpace(provider) + if provider == "" { + return "" + } + return provider + " AI" + } + return spec.DisplayName +} + +func modelProviderSupportsCustomAPIURL(provider string) bool { + spec, ok := modelProviderSpecByID(provider) + return ok && spec.SupportsCustomAPIURL +} + +func modelProviderSupportsCustomModel(provider string) bool { + spec, ok := modelProviderSpecByID(provider) + return ok && spec.SupportsCustomModel +} + +func modelProviderCredentialLabel(lang, provider string) string { + spec, ok := modelProviderSpecByID(provider) + if !ok { + if lang == "zh" { + return "API Key" + } + return "API key" + } + if lang == "zh" { + return spec.CredentialLabelZH + } + return spec.CredentialLabelEN +} + +func modelProviderSummaryList(lang string) string { + parts := make([]string, 0, len(supportedModelProviders())) + for _, spec := range supportedModelProviders() { + if lang == "zh" { + item := fmt.Sprintf("%s(默认 %s)", spec.ID, spec.DefaultModel) + if spec.Recommended { + item += " [推荐]" + } + parts = append(parts, item) + continue + } + item := fmt.Sprintf("%s (default %s)", spec.ID, spec.DefaultModel) + if spec.Recommended { + item += " [recommended]" + } + parts = append(parts, item) + } + if lang == "zh" { + return strings.Join(parts, "、") + } + return strings.Join(parts, ", ") +} + +func modelProviderChoicePrompt(lang string) string { + if lang == "zh" { + return "可选模型 provider:" + modelProviderSummaryList(lang) + "。这些 provider 是并列可选的:你可以直接选 `claw402`、DeepSeek / OpenAI / Claude / Gemini / Qwen / Kimi / Grok / MiniMax 这类 API Key provider,或者选 `blockrun-base` / `blockrun-sol` 这类钱包 provider。我们优先推荐 `claw402`,因为它按次付费、用 Base USDC 钱包支付、默认配置更省事。对于第一次使用的新手,也可以直接去产品配置页的模型配置里选择 `claw402`:那里支持直接创建 Base 钱包,并且可以直接扫码充值/支付。请先告诉我你想用哪个 provider。" + } + return "Available model providers: " + modelProviderSummaryList(lang) + ". These providers are peer options: you can choose `claw402`, an API-key provider such as DeepSeek / OpenAI / Claude / Gemini / Qwen / Kimi / Grok / MiniMax, or a wallet-based provider such as `blockrun-base` / `blockrun-sol`. We recommend `claw402` first because it is pay-per-use, uses Base USDC wallet payment, and has the simplest default setup. If this is your first time, you can also open the product's model config page, choose `claw402`, create a Base wallet there directly, and pay by scanning the QR/deposit flow. Tell me which provider you want first." +} + +func modelProviderDetailedGuidance(lang, provider string) string { + spec, ok := modelProviderSpecByID(provider) + if !ok { + return "" + } + if lang == "zh" { + lines := []string{ + fmt.Sprintf("你现在选的是 %s。", spec.DisplayName), + fmt.Sprintf("- 默认模型名:%s", spec.DefaultModel), + fmt.Sprintf("- 凭证类型:%s", spec.CredentialLabelZH), + } + if spec.SupportsCustomModel { + lines = append(lines, "- `custom_model_name` 可选;留空时默认用上面的默认模型。") + } else { + lines = append(lines, "- 这个 provider 不需要单独填写 `custom_model_name`。") + } + if spec.SupportsCustomAPIURL { + lines = append(lines, "- `custom_api_url` 可选;留空时使用官方默认地址。") + } else { + lines = append(lines, "- 这个 provider 不需要 `custom_api_url`。") + } + if len(spec.RecommendedModelHints) > 0 { + lines = append(lines, "- 常见可选模型:"+strings.Join(spec.RecommendedModelHints, "、")) + } + if provider == "claw402" { + lines = append(lines, "- 这是我们优先推荐的 provider:按次付费、Base USDC 钱包支付,对新手最省事。") + lines = append(lines, "- 如果你是第一次用,也可以直接去配置页的模型配置里选择 `claw402`,那里支持直接创建 Base 钱包,并可直接扫码充值/支付。") + } + return strings.Join(lines, "\n") + } + lines := []string{ + fmt.Sprintf("You selected %s.", spec.DisplayName), + fmt.Sprintf("- Default model: %s", spec.DefaultModel), + fmt.Sprintf("- Credential type: %s", spec.CredentialLabelEN), + } + if spec.SupportsCustomModel { + lines = append(lines, "- `custom_model_name` is optional; if omitted, the default model will be used.") + } else { + lines = append(lines, "- This provider does not need a separate `custom_model_name`.") + } + if spec.SupportsCustomAPIURL { + lines = append(lines, "- `custom_api_url` is optional; if omitted, the official default endpoint will be used.") + } else { + lines = append(lines, "- This provider does not need `custom_api_url`.") + } + if len(spec.RecommendedModelHints) > 0 { + lines = append(lines, "- Common model choices: "+strings.Join(spec.RecommendedModelHints, ", ")) + } + if provider == "claw402" { + lines = append(lines, "- This is our recommended provider: pay-per-use, Base USDC wallet payment, and the easiest setup for first-time users.") + lines = append(lines, "- If this is your first time, you can also open the model config page, choose `claw402`, create a Base wallet there directly, and pay through the QR/deposit flow.") + } + return strings.Join(lines, "\n") +} + +func modelProviderCredentialGuidance(lang, provider string) string { + spec, ok := modelProviderSpecByID(provider) + if !ok { + return "" + } + provider = strings.TrimSpace(spec.ID) + if lang == "zh" { + switch provider { + case "claw402": + return "claw402 这里要填的是 Base 链 EVM 钱包私钥。\n- 如果你是第一次用,最省事的方式是直接去配置页的模型配置里选择 `claw402`。\n- 那里可以一键快速创建钱包,界面会直接展示新钱包私钥,并且提供 Base USDC 充值入口。\n- 创建后请立刻备份私钥;系统会用它完成 claw402 支付和模型调用。\n- 如果你已经有 MetaMask、Rabby、Coinbase Wallet 这类 Base/EVM 钱包,也可以从钱包里导出现有私钥再发我。" + case "blockrun-base": + return "blockrun-base 这里要填的是 Base 链 EVM 钱包私钥。你可以从现有 EVM 钱包导出私钥后发我。" + case "blockrun-sol": + return "blockrun-sol 这里要填的是 Solana 钱包私钥。你可以从现有 Solana 钱包导出私钥后发我。" + default: + return fmt.Sprintf("%s 这里要填的是 %s。你把完整值发我就行,我会继续当前模型草稿。", spec.DisplayName, spec.CredentialLabelZH) + } + } + switch provider { + case "claw402": + return "For claw402, this field expects a Base-chain EVM wallet private key.\n- If this is your first time, the easiest path is to open the model config page and choose `claw402`.\n- That flow can quickly create a wallet for you, show the new private key, and provide a Base USDC deposit path.\n- Back up the key immediately after creation; the system uses it for claw402 payments and model access.\n- If you already use MetaMask, Rabby, or Coinbase Wallet, you can also export an existing Base/EVM wallet private key and send it to me." + case "blockrun-base": + return "For blockrun-base, this field expects a Base-chain EVM wallet private key. You can export it from an existing EVM wallet and send it to me." + case "blockrun-sol": + return "For blockrun-sol, this field expects a Solana wallet private key. You can export it from an existing Solana wallet and send it to me." + default: + return fmt.Sprintf("For %s, this field expects your %s. Send me the full value and I'll continue the current model draft.", spec.DisplayName, spec.CredentialLabelEN) + } +} diff --git a/agent/model_provider_catalog_test.go b/agent/model_provider_catalog_test.go new file mode 100644 index 00000000..8921b598 --- /dev/null +++ b/agent/model_provider_catalog_test.go @@ -0,0 +1,57 @@ +package agent + +import ( + "strings" + "testing" +) + +func TestModelProviderChoicePromptIncludesRecommendationWithoutAutoSelection(t *testing.T) { + msg := modelProviderChoicePrompt("zh") + for _, want := range []string{ + "可选模型 provider", + "claw402", + "DeepSeek", + "OpenAI", + "并列可选", + "blockrun-base", + "直接创建 Base 钱包", + "直接扫码充值/支付", + "请先告诉我你想用哪个 provider", + } { + if !strings.Contains(msg, want) { + t.Fatalf("expected prompt to contain %q, got: %s", want, msg) + } + } + if strings.Contains(msg, "把私钥发给我") { + t.Fatalf("provider choice prompt should not jump ahead to credential collection: %s", msg) + } +} + +func TestModelProviderCredentialGuidanceForClaw402MentionsConfigPageWalletFlow(t *testing.T) { + msg := modelProviderCredentialGuidance("zh", "claw402") + for _, want := range []string{ + "Base 链 EVM 钱包私钥", + "配置页的模型配置里选择 `claw402`", + "快速创建钱包", + "充值入口", + } { + if !strings.Contains(msg, want) { + t.Fatalf("expected guidance to contain %q, got: %s", want, msg) + } + } +} + +func TestModelProviderDetailedGuidanceForClaw402MentionsBeginnerFlow(t *testing.T) { + msg := modelProviderDetailedGuidance("zh", "claw402") + for _, want := range []string{ + "优先推荐", + "按次付费", + "Base USDC 钱包支付", + "直接创建 Base 钱包", + "直接扫码充值/支付", + } { + if !strings.Contains(msg, want) { + t.Fatalf("expected detailed guidance to contain %q, got: %s", want, msg) + } + } +} diff --git a/agent/model_wallet_fastpath.go b/agent/model_wallet_fastpath.go new file mode 100644 index 00000000..6670defe --- /dev/null +++ b/agent/model_wallet_fastpath.go @@ -0,0 +1,86 @@ +package agent + +import ( + "fmt" + "strconv" + "strings" +) + +func isModelWalletBalanceQuestion(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" || !strings.Contains(lower, "claw402") { + return false + } + return containsAny(lower, []string{"余额", "balance", "usdc"}) && + containsAny(lower, []string{"钱包", "wallet", "主钱包", "base"}) +} + +func (a *Agent) handleModelWalletBalanceQuestion(storeUserID, lang, text string) (string, bool) { + if !isModelWalletBalanceQuestion(text) || a == nil || a.store == nil { + return "", false + } + models, err := a.store.AIModel().List(storeUserID) + if err != nil { + if lang == "zh" { + return "我现在读取模型配置失败,暂时查不到 claw402 钱包余额。", true + } + return "I could not read model configs, so I cannot check the claw402 wallet balance right now.", true + } + + var matches []safeModelToolConfig + for _, model := range models { + if model == nil || strings.ToLower(strings.TrimSpace(model.Provider)) != "claw402" { + continue + } + matches = append(matches, safeModelForTool(model)) + } + if len(matches) == 0 { + if lang == "zh" { + return "当前没有找到 claw402 模型钱包配置。", true + } + return "No claw402 model wallet config was found.", true + } + + if lang == "zh" { + lines := []string{"当前 claw402 模型钱包余额:"} + for _, model := range matches { + name := defaultIfEmpty(model.Name, model.ID) + lines = append(lines, fmt.Sprintf("- %s:%s USDC", name, defaultIfEmpty(model.BalanceUSDC, "暂时无法读取"))) + if strings.TrimSpace(model.WalletAddress) != "" { + lines = append(lines, fmt.Sprintf(" 钱包地址:%s", model.WalletAddress)) + } + if balanceIsZero(model.BalanceUSDC) { + if model.Enabled { + lines = append(lines, " 这个模型配置已启用,但钱包余额为 0 USDC;这不是“未启用”,而是需要先充值 Base USDC 后才能稳定调用。") + } else { + lines = append(lines, " 钱包余额为 0 USDC;启用并充值 Base USDC 后才能稳定调用。") + } + } + } + lines = append(lines, "注意:这是 claw402/Base 模型支付钱包余额,不是 OKX/Binance 等交易所账户余额。") + return strings.Join(lines, "\n"), true + } + + lines := []string{"Current claw402 model wallet balance:"} + for _, model := range matches { + name := defaultIfEmpty(model.Name, model.ID) + lines = append(lines, fmt.Sprintf("- %s: %s USDC", name, defaultIfEmpty(model.BalanceUSDC, "unavailable"))) + if strings.TrimSpace(model.WalletAddress) != "" { + lines = append(lines, fmt.Sprintf(" Wallet address: %s", model.WalletAddress)) + } + if balanceIsZero(model.BalanceUSDC) { + lines = append(lines, " This model config may be enabled, but the wallet balance is 0 USDC; recharge Base USDC before relying on it.") + } + } + lines = append(lines, "Note: this is the claw402/Base model payment wallet balance, not an exchange account balance.") + return strings.Join(lines, "\n"), true +} + +func balanceIsZero(value string) bool { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return false + } + parsed, err := strconv.ParseFloat(trimmed, 64) + return err == nil && parsed <= 0 +} diff --git a/agent/onboard.go b/agent/onboard.go index d047dd82..b89e5b18 100644 --- a/agent/onboard.go +++ b/agent/onboard.go @@ -11,6 +11,7 @@ import ( ) var titleCaser = cases.Title(language.English) + const setupExchangeAccountName = "Default" // Onboard handles first-time setup through natural language. @@ -41,6 +42,11 @@ func (a *Agent) needsSetup() bool { // getSetupState loads the current setup state from user preferences. func (a *Agent) getSetupState(userID int64) *SetupState { + if cached, ok := a.setupStates.Load(userID); ok { + if state, ok := cached.(*SetupState); ok && state != nil { + return cloneSetupState(state) + } + } step, _ := a.store.GetSystemConfig(fmt.Sprintf("setup_step_%d", userID)) if step == "" { return &SetupState{} @@ -49,48 +55,30 @@ func (a *Agent) getSetupState(userID int64) *SetupState { Step: step, Exchange: getConfig(a.store, userID, "exchange"), ExchangeID: getConfig(a.store, userID, "exchange_id"), - APIKey: getConfig(a.store, userID, "api_key"), - APISecret: getConfig(a.store, userID, "api_secret"), - Passphrase: getConfig(a.store, userID, "passphrase"), AIProvider: getConfig(a.store, userID, "ai_provider"), AIModel: getConfig(a.store, userID, "ai_model"), AIModelID: getConfig(a.store, userID, "ai_model_id"), - AIKey: getConfig(a.store, userID, "ai_key"), AIBaseURL: getConfig(a.store, userID, "ai_base_url"), } } func (a *Agent) saveSetupState(userID int64, s *SetupState) { + a.setupStates.Store(userID, cloneSetupState(s)) a.store.SetSystemConfig(fmt.Sprintf("setup_step_%d", userID), s.Step) setConfig(a.store, userID, "exchange", s.Exchange) setConfig(a.store, userID, "exchange_id", s.ExchangeID) - // Store only a masked marker for secrets — full values stay in memory only. - // This prevents plaintext credentials from lingering in the config store - // if the setup flow is interrupted before clearSetupState runs. - if s.APIKey != "" { - setConfig(a.store, userID, "api_key", "****") - } - if s.APISecret != "" { - setConfig(a.store, userID, "api_secret", "****") - } - if s.Passphrase != "" { - setConfig(a.store, userID, "passphrase", "****") - } setConfig(a.store, userID, "ai_provider", s.AIProvider) setConfig(a.store, userID, "ai_model", s.AIModel) setConfig(a.store, userID, "ai_model_id", s.AIModelID) - if s.AIKey != "" { - setConfig(a.store, userID, "ai_key", "****") - } setConfig(a.store, userID, "ai_base_url", s.AIBaseURL) } func (a *Agent) clearSetupState(userID int64) { - for _, k := range []string{"step", "exchange", "exchange_id", "api_key", "api_secret", "passphrase", "ai_provider", "ai_model", "ai_model_id", "ai_key", "ai_base_url"} { - if err := a.store.SetSystemConfig(fmt.Sprintf("setup_%s_%d", k, userID), ""); err != nil { - a.log().Warn("clearSetupState: failed to clear key", "key", k, "error", err) - } + a.setupStates.Delete(userID) + for _, k := range []string{"step", "exchange", "exchange_id", "ai_provider", "ai_model", "ai_model_id", "ai_base_url"} { + a.store.SetSystemConfig(fmt.Sprintf("setup_%s_%d", k, userID), "") } + a.store.SetSystemConfig(fmt.Sprintf("setup_step_%d", userID), "") } func getConfig(st *store.Store, uid int64, key string) string { @@ -102,6 +90,14 @@ func setConfig(st *store.Store, uid int64, key, val string) { st.SetSystemConfig(fmt.Sprintf("setup_%s_%d", key, uid), val) } +func cloneSetupState(s *SetupState) *SetupState { + if s == nil { + return &SetupState{} + } + copy := *s + return © +} + // handleSetupFlow processes the setup conversation. // Returns (response, handled). If handled=false, continue to normal routing. func (a *Agent) handleSetupFlow(userID int64, text string, L string) (string, bool) { @@ -165,7 +161,7 @@ func (a *Agent) handleSetupFlowForStoreUser(storeUserID string, userID int64, te if L == "zh" { return fmt.Sprintf("⚠️ 交易所配置保存失败: %v\n请再试一次,或稍后去 Web UI 继续。", err), true } - return fmt.Sprintf("⚠️ Failed to save exchange config: %v\nPlease try again, or continue later in the Web UI.", err), true + return fmt.Sprintf("⚠️ I could not save the exchange settings just now: %v\nPlease try again, or continue later on the web page.", err), true } state.ExchangeID = exchangeID state.Step = "await_ai_model" @@ -182,7 +178,7 @@ func (a *Agent) handleSetupFlowForStoreUser(storeUserID string, userID int64, te if L == "zh" { return fmt.Sprintf("⚠️ 交易所配置保存失败: %v\n请再试一次,或稍后去 Web UI 继续。", err), true } - return fmt.Sprintf("⚠️ Failed to save exchange config: %v\nPlease try again, or continue later in the Web UI.", err), true + return fmt.Sprintf("⚠️ I could not save the exchange settings just now: %v\nPlease try again, or continue later on the web page.", err), true } state.ExchangeID = exchangeID state.Step = "await_ai_model" @@ -201,7 +197,7 @@ func (a *Agent) handleSetupFlowForStoreUser(storeUserID string, userID int64, te if L == "zh" { return fmt.Sprintf("⚠️ AI 模型配置保存失败: %v\n请再试一次,或稍后去 Web UI 继续。", err), true } - return fmt.Sprintf("⚠️ Failed to save AI model config: %v\nPlease try again, or continue later in the Web UI.", err), true + return fmt.Sprintf("⚠️ I could not save the AI model settings just now: %v\nPlease try again, or continue later on the web page.", err), true } state.AIModelID = aiModelID return a.finishSetup(storeUserID, userID, state, L) @@ -226,7 +222,7 @@ func isDirectSetupCommand(text string) bool { return false } switch text { - case "setup", "/setup", "开始配置", "配置", "开始设置": + case "setup", "/setup": return true default: return false @@ -265,19 +261,19 @@ func (a *Agent) handleAIChoice(storeUserID string, userID int64, text string, st lower := strings.ToLower(strings.TrimSpace(text)) models := map[string]struct{ provider, model, url string }{ - "deepseek": {"deepseek", "deepseek-chat", "https://api.deepseek.com/v1"}, - "1": {"deepseek", "deepseek-chat", "https://api.deepseek.com/v1"}, - "qwen": {"qwen", "qwen-plus", "https://dashscope.aliyuncs.com/compatible-mode/v1"}, + "deepseek": {"deepseek", "deepseek-chat", "https://api.deepseek.com/v1"}, + "1": {"deepseek", "deepseek-chat", "https://api.deepseek.com/v1"}, + "qwen": {"qwen", "qwen-plus", "https://dashscope.aliyuncs.com/compatible-mode/v1"}, "通义": {"qwen", "qwen-plus", "https://dashscope.aliyuncs.com/compatible-mode/v1"}, - "2": {"qwen", "qwen-plus", "https://dashscope.aliyuncs.com/compatible-mode/v1"}, - "openai": {"openai", "gpt-4o", "https://api.openai.com/v1"}, - "gpt": {"openai", "gpt-4o", "https://api.openai.com/v1"}, - "3": {"openai", "gpt-4o", "https://api.openai.com/v1"}, - "claude": {"claude", "claude-3-5-sonnet-20241022", "https://api.anthropic.com/v1"}, - "4": {"claude", "claude-3-5-sonnet-20241022", "https://api.anthropic.com/v1"}, - "skip": {"", "", ""}, + "2": {"qwen", "qwen-plus", "https://dashscope.aliyuncs.com/compatible-mode/v1"}, + "openai": {"openai", "gpt-4o", "https://api.openai.com/v1"}, + "gpt": {"openai", "gpt-4o", "https://api.openai.com/v1"}, + "3": {"openai", "gpt-4o", "https://api.openai.com/v1"}, + "claude": {"claude", "claude-3-5-sonnet-20241022", "https://api.anthropic.com/v1"}, + "4": {"claude", "claude-3-5-sonnet-20241022", "https://api.anthropic.com/v1"}, + "skip": {"", "", ""}, "跳过": {"", "", ""}, - "5": {"", "", ""}, + "5": {"", "", ""}, } choice, ok := models[lower] @@ -502,7 +498,9 @@ func (a *Agent) saveSetupAIModel(storeUserID string, state *SetupState) (string, return "", err } - modelID = fmt.Sprintf("%s_%s", storeUserID, state.AIProvider) + if modelID == state.AIProvider { + modelID = fmt.Sprintf("%s_%s", storeUserID, state.AIProvider) + } return modelID, nil } diff --git a/agent/onboard_test.go b/agent/onboard_test.go deleted file mode 100644 index 117055c3..00000000 --- a/agent/onboard_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package agent - -import "testing" - -func TestIsDirectSetupCommand(t *testing.T) { - cases := []struct { - text string - want bool - }{ - {text: "setup", want: true}, - {text: "/setup", want: true}, - {text: "开始配置", want: true}, - {text: "配置", want: true}, - {text: "开始设置", want: true}, - {text: "/开始配置", want: false}, - {text: "创建全新的配置,杠杆你定", want: false}, - {text: "帮我配置一个 deepseek 模型", want: false}, - {text: "绑定交易所 okx", want: false}, - } - - for _, tc := range cases { - if got := isDirectSetupCommand(tc.text); got != tc.want { - t.Fatalf("isDirectSetupCommand(%q) = %v, want %v", tc.text, got, tc.want) - } - } -} diff --git a/agent/planner_runtime.go b/agent/planner_runtime.go index ea319181..95839e28 100644 --- a/agent/planner_runtime.go +++ b/agent/planner_runtime.go @@ -5,16 +5,18 @@ import ( "encoding/json" "errors" "fmt" + "sort" "strings" "time" "nofx/mcp" + "nofx/store" ) const ( plannerMaxSteps = 8 plannerMaxIterations = 12 - observationMaxLength = 400 + observationMaxLength = 1000 ) var ( @@ -198,6 +200,28 @@ func isRealtimeAccountIntent(text string) bool { func snapshotKindsForIntent(userText string) []string { kinds := make([]string, 0, 6) + lower := strings.ToLower(strings.TrimSpace(userText)) + if lower == "" || isRealtimeAccountIntent(lower) { + return nil + } + + configKeywords := []string{ + "交易员", "trader", "traders", + "交易所", "exchange", "exchanges", + "模型", "model", "models", "llm", "ai model", + "策略", "strategy", "strategies", + "配置", "config", "setup", "create", "创建", "修改", "更新", "删除", "delete", + } + if containsAnyKeyword(lower, configKeywords) { + kinds = append(kinds, + "current_model_configs", + "current_exchange_configs", + "current_traders", + ) + if strings.Contains(lower, "策略") || strings.Contains(lower, "strategy") { + kinds = append(kinds, "current_strategies") + } + } return uniqueStrings(kinds) } @@ -277,9 +301,10 @@ func ensureCurrentReferences(state *ExecutionState) { } } -func preferReference(current **EntityReference, id, name string) { +func preferReference(current **EntityReference, id, name, source string) { id = strings.TrimSpace(id) name = strings.TrimSpace(name) + source = strings.TrimSpace(source) if id == "" && name == "" { return } @@ -292,6 +317,31 @@ func preferReference(current **EntityReference, id, name string) { if name != "" { (*current).Name = name } + if source != "" { + (*current).Source = source + } + (*current).UpdatedAt = time.Now().UTC().Format(time.RFC3339) +} + +func appendReferenceHistory(state *ExecutionState, kind, id, name, source string) { + if state == nil { + return + } + kind = strings.TrimSpace(kind) + id = strings.TrimSpace(id) + name = strings.TrimSpace(name) + source = strings.TrimSpace(source) + if kind == "" || (id == "" && name == "") { + return + } + state.ReferenceHistory = append(state.ReferenceHistory, ReferenceRecord{ + Kind: kind, + ID: id, + Name: name, + Source: source, + CreatedAt: time.Now().UTC().Format(time.RFC3339), + }) + state.ReferenceHistory = normalizeReferenceHistory(state.ReferenceHistory) } func matchEntityReference(text string, candidates []EntityReference) *EntityReference { @@ -329,7 +379,8 @@ func (a *Agent) refreshCurrentReferencesForUserText(storeUserID, text string, st candidates = append(candidates, EntityReference{ID: strategy.ID, Name: strategy.Name}) } if ref := matchEntityReference(text, candidates); ref != nil { - preferReference(&state.CurrentReferences.Strategy, ref.ID, ref.Name) + preferReference(&state.CurrentReferences.Strategy, ref.ID, ref.Name, "user_mention") + appendReferenceHistory(state, "strategy", ref.ID, ref.Name, "user_mention") } } if traders, err := a.store.Trader().List(storeUserID); err == nil { @@ -338,7 +389,8 @@ func (a *Agent) refreshCurrentReferencesForUserText(storeUserID, text string, st candidates = append(candidates, EntityReference{ID: trader.ID, Name: trader.Name}) } if ref := matchEntityReference(text, candidates); ref != nil { - preferReference(&state.CurrentReferences.Trader, ref.ID, ref.Name) + preferReference(&state.CurrentReferences.Trader, ref.ID, ref.Name, "user_mention") + appendReferenceHistory(state, "trader", ref.ID, ref.Name, "user_mention") } } if models, err := a.store.AIModel().List(storeUserID); err == nil { @@ -354,12 +406,16 @@ func (a *Agent) refreshCurrentReferencesForUserText(storeUserID, text string, st candidates = append(candidates, EntityReference{ID: model.ID, Name: name}) } if ref := matchEntityReference(text, candidates); ref != nil { - preferReference(&state.CurrentReferences.Model, ref.ID, ref.Name) + preferReference(&state.CurrentReferences.Model, ref.ID, ref.Name, "user_mention") + appendReferenceHistory(state, "model", ref.ID, ref.Name, "user_mention") } } if exchanges, err := a.store.Exchange().List(storeUserID); err == nil { candidates := make([]EntityReference, 0, len(exchanges)) for _, exchange := range exchanges { + if !store.IsVisibleExchange(exchange) { + continue + } name := exchange.AccountName if name == "" { name = exchange.ExchangeType @@ -367,7 +423,8 @@ func (a *Agent) refreshCurrentReferencesForUserText(storeUserID, text string, st candidates = append(candidates, EntityReference{ID: exchange.ID, Name: name}) } if ref := matchEntityReference(text, candidates); ref != nil { - preferReference(&state.CurrentReferences.Exchange, ref.ID, ref.Name) + preferReference(&state.CurrentReferences.Exchange, ref.ID, ref.Name, "user_mention") + appendReferenceHistory(state, "exchange", ref.ID, ref.Name, "user_mention") } } } @@ -386,14 +443,18 @@ func updateCurrentReferencesFromToolResult(state *ExecutionState, toolName, raw switch toolName { case "manage_strategy": if item, ok := payload["strategy"].(map[string]any); ok { - preferReference(&state.CurrentReferences.Strategy, asString(item["id"]), asString(item["name"])) + id, name := asString(item["id"]), asString(item["name"]) + preferReference(&state.CurrentReferences.Strategy, id, name, "tool_output") + appendReferenceHistory(state, "strategy", id, name, "tool_output") } case "manage_trader": if item, ok := payload["trader"].(map[string]any); ok { - preferReference(&state.CurrentReferences.Trader, asString(item["id"]), asString(item["name"])) - preferReference(&state.CurrentReferences.Model, asString(item["ai_model_id"]), "") - preferReference(&state.CurrentReferences.Exchange, asString(item["exchange_id"]), "") - preferReference(&state.CurrentReferences.Strategy, asString(item["strategy_id"]), "") + id, name := asString(item["id"]), asString(item["name"]) + preferReference(&state.CurrentReferences.Trader, id, name, "tool_output") + appendReferenceHistory(state, "trader", id, name, "tool_output") + preferReference(&state.CurrentReferences.Model, asString(item["ai_model_id"]), "", "tool_output") + preferReference(&state.CurrentReferences.Exchange, asString(item["exchange_id"]), "", "tool_output") + preferReference(&state.CurrentReferences.Strategy, asString(item["strategy_id"]), "", "tool_output") } case "manage_model_config": if item, ok := payload["model"].(map[string]any); ok { @@ -401,7 +462,9 @@ func updateCurrentReferencesFromToolResult(state *ExecutionState, toolName, raw if name == "" { name = asString(item["provider"]) } - preferReference(&state.CurrentReferences.Model, asString(item["id"]), name) + id := asString(item["id"]) + preferReference(&state.CurrentReferences.Model, id, name, "tool_output") + appendReferenceHistory(state, "model", id, name, "tool_output") } case "manage_exchange_config": if item, ok := payload["exchange"].(map[string]any); ok { @@ -409,12 +472,33 @@ func updateCurrentReferencesFromToolResult(state *ExecutionState, toolName, raw if name == "" { name = asString(item["exchange_type"]) } - preferReference(&state.CurrentReferences.Exchange, asString(item["id"]), name) + id := asString(item["id"]) + preferReference(&state.CurrentReferences.Exchange, id, name, "tool_output") + appendReferenceHistory(state, "exchange", id, name, "tool_output") } case "get_strategies": - if items, ok := payload["strategies"].([]any); ok && len(items) == 1 { - if item, ok := items[0].(map[string]any); ok { - preferReference(&state.CurrentReferences.Strategy, asString(item["id"]), asString(item["name"])) + if items, ok := payload["strategies"].([]any); ok { + var matched map[string]any + if len(items) == 1 { + matched, _ = items[0].(map[string]any) + } else { + goal := strings.ToLower(strings.TrimSpace(state.Goal)) + for _, it := range items { + item, ok := it.(map[string]any) + if !ok { + continue + } + name := strings.ToLower(strings.TrimSpace(asString(item["name"]))) + if name != "" && goal != "" && strings.Contains(goal, name) { + matched = item + break + } + } + } + if matched != nil { + id, name := asString(matched["id"]), asString(matched["name"]) + preferReference(&state.CurrentReferences.Strategy, id, name, "tool_output") + appendReferenceHistory(state, "strategy", id, name, "tool_output") } } } @@ -459,28 +543,21 @@ func detectReadFastPath(text string) *readFastPathRequest { case "/history", "/trades": return &readFastPathRequest{Kind: "get_trade_history", ArgsJSON: `{"limit":10}`} default: - return nil + switch { + case containsAnyKeyword(lower, []string{"列出", "查看", "看看", "查询", "list", "show"}) && containsAnyKeyword(lower, []string{"策略", "strategy"}): + return &readFastPathRequest{Kind: "get_strategies"} + case containsAnyKeyword(lower, []string{"列出", "查看", "看看", "查询", "list", "show"}) && containsAnyKeyword(lower, []string{"交易员", "trader"}): + return &readFastPathRequest{Kind: "list_traders"} + case containsAnyKeyword(lower, []string{"列出", "查看", "看看", "查询", "list", "show"}) && containsAnyKeyword(lower, []string{"模型", "model"}): + return &readFastPathRequest{Kind: "get_model_configs"} + case containsAnyKeyword(lower, []string{"列出", "查看", "看看", "查询", "list", "show"}) && containsAnyKeyword(lower, []string{"交易所", "exchange"}): + return &readFastPathRequest{Kind: "get_exchange_configs"} + default: + return nil + } } } -func (a *Agent) tryReadFastPath(storeUserID string, userID int64, lang, text string) (string, bool) { - req := detectReadFastPath(text) - if req == nil { - return "", false - } - a.ensureHistory() - - a.history.Add(userID, "user", text) - raw := a.executeReadFastPath(storeUserID, userID, req) - answer := formatReadFastPathResponse(lang, req.Kind, raw) - a.history.Add(userID, "assistant", answer) - if !isEphemeralReadFastPathKind(req.Kind) { - a.maybeUpdateTaskStateIncrementally(context.Background(), userID) - a.maybeCompressHistory(context.Background(), userID) - } - return answer, true -} - func isEphemeralReadFastPathKind(kind string) bool { switch kind { case "get_balance", "get_positions", "get_trade_history": @@ -493,9 +570,9 @@ func isEphemeralReadFastPathKind(kind string) bool { func (a *Agent) executeReadFastPath(storeUserID string, _ int64, req *readFastPathRequest) string { switch req.Kind { case "get_balance": - return a.toolGetBalance() + return a.toolGetBalance(storeUserID) case "get_positions": - return a.toolGetPositions() + return a.toolGetPositions(storeUserID) case "get_trade_history": return a.toolGetTradeHistory(req.ArgsJSON) case "get_strategies": @@ -741,112 +818,79 @@ func formatReadFastPathResponse(lang, kind, raw string) string { } func (a *Agent) thinkAndAct(ctx context.Context, storeUserID string, userID int64, lang, text string) (string, error) { - if answer, ok, err := a.tryStatePriorityPath(ctx, storeUserID, userID, lang, text, nil); ok || err != nil { - return answer, err - } - if answer, ok := tryInstantDirectReply(lang, text); ok { - return answer, nil - } - if answer, ok := a.tryReadFastPath(storeUserID, userID, lang, text); ok { - return answer, nil - } - if answer, ok, err := a.tryWorkflowIntent(ctx, storeUserID, userID, lang, text, nil); ok || err != nil { - return answer, err - } - if answer, ok := a.tryHardSkill(ctx, storeUserID, userID, lang, text, nil); ok { - return answer, nil - } - // Check setup flow before falling back to noAI — handles "开始配置", "setup", etc. - if reply, handled := a.handleSetupFlowForStoreUser(storeUserID, userID, text, lang); handled { - return reply, nil + lock := a.flowLock(userID) + lock.Lock() + defer lock.Unlock() + if a.aiClient != nil { + if answer, ok, err := a.tryLLMIntentRoute(ctx, storeUserID, userID, lang, text, nil); ok || err != nil { + return a.maybeAppendResumePrompt(userID, lang, text, answer), err + } + } else if a.hasAnyActiveContext(userID) { + if answer, ok, err := a.tryStatePriorityPath(ctx, storeUserID, userID, lang, text, nil); ok || err != nil { + return a.maybeAppendResumePrompt(userID, lang, text, answer), err + } } if a.aiClient == nil { - return a.noAIFallback(lang, text) + if !a.hasAnyActiveContext(userID) { + if answer, ok, err := a.tryStatePriorityPath(ctx, storeUserID, userID, lang, text, nil); ok || err != nil { + return a.maybeAppendResumePrompt(userID, lang, text, answer), err + } + } + if answer, ok := a.tryDirectAnswer(ctx, userID, lang, text, nil); ok { + return a.maybeAppendResumePrompt(userID, lang, text, answer), nil + } + if answer, ok := a.tryHardSkill(ctx, storeUserID, userID, lang, text, nil); ok { + return a.maybeAppendResumePrompt(userID, lang, text, answer), nil + } + return a.noAIFallback(storeUserID, lang, text) } - return a.runPlannedAgent(ctx, storeUserID, userID, lang, text, nil) + answer, err := a.runPlannedAgent(ctx, storeUserID, userID, lang, text, nil) + return a.maybeAppendResumePrompt(userID, lang, text, answer), err } func (a *Agent) thinkAndActStream(ctx context.Context, storeUserID string, userID int64, lang, text string, onEvent func(event, data string)) (string, error) { - if answer, ok, err := a.tryStatePriorityPath(ctx, storeUserID, userID, lang, text, onEvent); ok || err != nil { - return answer, err - } - if answer, ok := tryInstantDirectReply(lang, text); ok { - if onEvent != nil { - onEvent(StreamEventDelta, answer) + lock := a.flowLock(userID) + lock.Lock() + defer lock.Unlock() + if a.aiClient != nil { + if answer, ok, err := a.tryLLMIntentRoute(ctx, storeUserID, userID, lang, text, onEvent); ok || err != nil { + answer = a.maybeAppendResumePrompt(userID, lang, text, answer) + return answer, err } - return answer, nil - } - if answer, ok := a.tryReadFastPath(storeUserID, userID, lang, text); ok { - if onEvent != nil { - onEvent(StreamEventTool, "read_fast_path") - onEvent(StreamEventDelta, answer) + } else if a.hasAnyActiveContext(userID) { + if answer, ok, err := a.tryStatePriorityPath(ctx, storeUserID, userID, lang, text, onEvent); ok || err != nil { + answer = a.maybeAppendResumePrompt(userID, lang, text, answer) + return answer, err } - return answer, nil - } - if answer, ok, err := a.tryWorkflowIntent(ctx, storeUserID, userID, lang, text, onEvent); ok || err != nil { - return answer, err - } - if answer, ok := a.tryHardSkill(ctx, storeUserID, userID, lang, text, onEvent); ok { - return answer, nil - } - // Check setup flow before falling back to noAI — handles "开始配置", "setup", etc. - if reply, handled := a.handleSetupFlowForStoreUser(storeUserID, userID, text, lang); handled { - if onEvent != nil { - onEvent(StreamEventDelta, reply) - } - return reply, nil } if a.aiClient == nil { - return a.noAIFallback(lang, text) + if !a.hasAnyActiveContext(userID) { + if answer, ok, err := a.tryStatePriorityPath(ctx, storeUserID, userID, lang, text, onEvent); ok || err != nil { + answer = a.maybeAppendResumePrompt(userID, lang, text, answer) + return answer, err + } + } + if answer, ok := a.tryDirectAnswer(ctx, userID, lang, text, onEvent); ok { + answer = a.maybeAppendResumePrompt(userID, lang, text, answer) + return answer, nil + } + if answer, ok := a.tryHardSkill(ctx, storeUserID, userID, lang, text, onEvent); ok { + return a.maybeAppendResumePrompt(userID, lang, text, answer), nil + } + return a.noAIFallback(storeUserID, lang, text) } - return a.runPlannedAgent(ctx, storeUserID, userID, lang, text, onEvent) + answer, err := a.runPlannedAgent(ctx, storeUserID, userID, lang, text, onEvent) + return a.maybeAppendResumePrompt(userID, lang, text, answer), err } -func tryInstantDirectReply(lang, text string) (string, bool) { +func isInstantDirectReplyText(text string) bool { lower := strings.ToLower(strings.TrimSpace(text)) - if lower == "" { - return "", false + switch lower { + case "hi", "hello", "hey", "你好", "嗨", "在吗", "你好吗", "最近怎么样", "最近还好吗", "谢谢", "多谢", "谢了", "ok", "好的", "收到", "thanks", "thank you", "okay", "got it", "how are you": + return true + default: + return false } - - zhReplies := map[string]string{ - "hi": "在,有什么我帮你看的?", - "hello": "在,有什么我帮你看的?", - "hey": "在,有什么我帮你看的?", - "你好": "在,有什么我帮你看的?", - "嗨": "在,有什么我帮你看的?", - "在吗": "在,有什么我帮你看的?", - "谢谢": "不客气。", - "多谢": "不客气。", - "谢了": "不客气。", - "ok": "好。", - "好的": "好。", - "收到": "好。", - } - enReplies := map[string]string{ - "hi": "I'm here. What should we look at?", - "hello": "I'm here. What should we look at?", - "hey": "I'm here. What should we look at?", - "thanks": "You're welcome.", - "thank you": "You're welcome.", - "ok": "Okay.", - "okay": "Okay.", - "got it": "Got it.", - } - - if lang == "zh" { - if reply, ok := zhReplies[lower]; ok { - return reply, true - } - if reply, ok := enReplies[lower]; ok { - return reply, true - } - return "", false - } - - if reply, ok := enReplies[lower]; ok { - return reply, true - } - return "", false } func (a *Agent) hasActiveSkillSession(userID int64) bool { @@ -854,6 +898,19 @@ func (a *Agent) hasActiveSkillSession(userID int64) bool { return strings.TrimSpace(session.Name) != "" } +func (a *Agent) hasAnyActiveContext(userID int64) bool { + if _, ok := a.getActiveSkillSession(userID); ok { + return true + } + if a.hasActiveSkillSession(userID) { + return true + } + if hasActiveWorkflowSession(a.getWorkflowSession(userID)) { + return true + } + return hasActiveExecutionState(a.getExecutionState(userID)) +} + func hasActiveExecutionState(state ExecutionState) bool { if strings.TrimSpace(state.SessionID) == "" { return false @@ -867,25 +924,61 @@ func hasActiveExecutionState(state ExecutionState) bool { } func (a *Agent) tryStatePriorityPath(ctx context.Context, storeUserID string, userID int64, lang, text string, onEvent func(event, data string)) (string, bool, error) { + if answer, ok := a.tryResumeSuspendedTask(userID, lang, text); ok { + return answer, true, nil + } + if !a.hasActiveSkillSession(userID) && !hasActiveWorkflowSession(a.getWorkflowSession(userID)) && !hasActiveExecutionState(a.getExecutionState(userID)) { + if a.tryRestoreSuspendedTaskFromIdle(ctx, userID, lang, text) { + return a.tryStatePriorityPath(ctx, storeUserID, userID, lang, text, onEvent) + } + } if workflow := a.getWorkflowSession(userID); hasActiveWorkflowSession(workflow) { + if task, _, ok := nextRunnableWorkflowTask(workflow); ok && strings.TrimSpace(task.Skill) == "strategy_management" && strings.TrimSpace(task.Action) == "create" { + a.clearWorkflowSession(userID) + session := newActiveSkillSession(userID, "strategy_management", "create") + session.Goal = defaultIfEmpty(strings.TrimSpace(task.Request), strings.TrimSpace(text)) + answer, handled, err := a.driveActiveSession(ctx, storeUserID, userID, lang, defaultIfEmpty(task.Request, text), session, onEvent) + return answer, handled, err + } answer, handled, err := a.handleWorkflowSession(ctx, storeUserID, userID, lang, text, workflow, onEvent) if handled || err != nil { return answer, true, err } } if session := a.getSkillSession(userID); strings.TrimSpace(session.Name) != "" { - switch a.classifySkillSessionInput(ctx, userID, lang, session, text) { + if answer, ok := a.redirectModelCreateSessionToStrategyCreateIfNeeded(storeUserID, userID, lang, text, session); ok { + if onEvent != nil && strings.TrimSpace(answer) != "" { + onEvent(StreamEventTool, "hard_skill:strategy_management") + emitStreamText(onEvent, answer) + } + return answer, true, nil + } + decision, _ := a.resolveSkillSessionTurn(ctx, userID, lang, text, session) + switch decision.Intent { case "cancel": a.clearSkillSession(userID) a.clearWorkflowSession(userID) - if lang == "zh" { - return "已取消当前流程。", true, nil - } - return "Cancelled the current flow.", true, nil - case "interrupt": - a.clearSkillSession(userID) + return a.maybeOfferParentTaskAfterCancel(userID, lang), true, nil + case "instant_reply": + return a.replyToActiveFlowInstantReply(ctx, userID, lang, text, onEvent), true, nil + case "resume_snapshot", "start_new": + answer, handled, err := a.handoffFromActiveFlow(ctx, storeUserID, userID, lang, text, decision.TargetSnapshotID, onEvent) + return answer, handled, err default: - if answer, ok := a.tryHardSkill(ctx, storeUserID, userID, lang, text, onEvent); ok { + if answer, ok := a.dispatchBridgedSkillSession(storeUserID, userID, lang, text, session); ok { + if onEvent != nil && strings.TrimSpace(answer) != "" { + switch session.Name { + case "trader_management": + onEvent(StreamEventTool, "hard_skill:trader_management") + case "model_management": + onEvent(StreamEventTool, "hard_skill:model_management") + case "exchange_management": + onEvent(StreamEventTool, "hard_skill:exchange_management") + case "strategy_management": + onEvent(StreamEventTool, "hard_skill:strategy_management") + } + emitStreamText(onEvent, answer) + } return answer, true, nil } } @@ -893,16 +986,32 @@ func (a *Agent) tryStatePriorityPath(ctx context.Context, storeUserID string, us state := a.getExecutionState(userID) if hasActiveExecutionState(state) { - switch classifyExecutionStateInput(state, text) { + decision, extraction := a.resolveExecutionStateTurn(ctx, userID, lang, state, text) + switch decision.Intent { case "cancel": a.clearExecutionState(userID) - if lang == "zh" { - return "已取消当前流程。", true, nil - } - return "Cancelled the current flow.", true, nil - case "interrupt": - a.clearExecutionState(userID) + return a.maybeOfferParentTaskAfterCancel(userID, lang), true, nil + case "instant_reply": + return a.replyToActiveFlowInstantReply(ctx, userID, lang, text, onEvent), true, nil + case "resume_snapshot", "start_new": + answer, handled, err := a.handoffFromActiveFlow(ctx, storeUserID, userID, lang, text, decision.TargetSnapshotID, onEvent) + return answer, handled, err default: + if decision.Intent == "continue_active" { + if answer, handled, err := a.redirectExecutionStateStrategyCreate(ctx, storeUserID, userID, lang, text, state, onEvent); handled || err != nil { + return answer, handled, err + } + if session, ok := a.bridgeExecutionStateToSkillSession(storeUserID, userID, text, state, extraction); ok { + answer, handled := a.dispatchBridgedSkillSession(storeUserID, userID, lang, text, session) + return answer, handled, nil + } + } + if extraction.Intent == "continue" { + a.applyExecutionStateExtraction(&state, extraction) + if err := a.saveExecutionState(state); err != nil { + return "", true, err + } + } answer, err := a.runPlannedAgent(ctx, storeUserID, userID, lang, text, onEvent) return answer, true, err } @@ -911,6 +1020,451 @@ func (a *Agent) tryStatePriorityPath(ctx context.Context, storeUserID string, us return "", false, nil } +func isTraderCreateWaitingState(state ExecutionState) bool { + lowerGoal := strings.ToLower(strings.TrimSpace(state.Goal)) + if strings.Contains(lowerGoal, "创建交易员") || strings.Contains(lowerGoal, "新建交易员") || strings.Contains(lowerGoal, "create trader") { + return true + } + if state.Waiting == nil { + return false + } + lowerIntent := strings.ToLower(strings.TrimSpace(state.Waiting.Intent)) + lowerTarget := strings.ToLower(strings.TrimSpace(state.Waiting.ConfirmationTarget)) + return lowerIntent == "complete_trader_setup" || (lowerIntent == "confirm_action" && lowerTarget == "trader") +} + +func hasSkillBridgeSignal(a *Agent, storeUserID, skillName, action, text string, extraction executionFlowExtractionResult) bool { + if len(extraction.Fields) > 0 { + return true + } + lower := strings.ToLower(strings.TrimSpace(text)) + if isYesReply(text) || isNoReply(text) { + return true + } + switch skillName { + case "trader_management": + if containsAny(lower, []string{"名称", "名字", "name", "交易所", "exchange", "模型", "model", "策略", "strategy"}) { + return true + } + case "model_management": + if containsAny(lower, []string{"provider", "模型名", "模型名称", "api key", "api_key", "apikey", "url", "endpoint", "名称", "名字", "name"}) { + return true + } + case "exchange_management": + if containsAny(lower, []string{"交易所", "exchange", "账户名", "account", "api key", "secret", "passphrase", "testnet", "名称", "名字", "name"}) { + return true + } + case "strategy_management": + if containsAny(lower, []string{"策略", "strategy", "名称", "名字", "name", "prompt", "提示词", "配置", "参数"}) { + return true + } + } + if action == "create" && containsAny(lower, []string{"名称", "名字", "name"}) { + return true + } + if a == nil { + return false + } + return hasStrictOptionMention(text, a.loadEnabledModelOptions(storeUserID)) || + hasStrictOptionMention(text, a.loadExchangeOptions(storeUserID)) || + hasStrictOptionMention(text, a.loadStrategyOptions(storeUserID)) +} + +func inferExecutionStateSkillBridge(state ExecutionState, text string) (string, string) { + lowerGoal := strings.ToLower(strings.TrimSpace(state.Goal)) + waitingIntent := "" + waitingTarget := "" + if state.Waiting != nil { + waitingIntent = strings.ToLower(strings.TrimSpace(state.Waiting.Intent)) + waitingTarget = strings.ToLower(strings.TrimSpace(state.Waiting.ConfirmationTarget)) + } + switch waitingIntent { + case "complete_trader_setup": + return "trader_management", "create" + case "complete_model_config": + return "model_management", "create" + case "complete_exchange_config": + return "exchange_management", "create" + } + switch waitingTarget { + case "trader": + if containsAny(lowerGoal, []string{"创建", "新建", "create", "setup", "配置"}) || hasExplicitCreateIntentForDomain(state.Goal, "trader") { + return "trader_management", "create" + } + return "trader_management", "create" + case "model", "model_config": + return "model_management", "create" + case "exchange", "exchange_config": + return "exchange_management", "create" + case "strategy", "manage_strategy": + return "strategy_management", "create" + } + switch { + case hasExplicitCreateIntentForDomain(state.Goal, "trader"): + return "trader_management", "create" + } + return "", "" +} + +func traderCreateFieldsFromExecutionExtraction(result executionFlowExtractionResult) map[string]string { + if len(result.Fields) == 0 { + return nil + } + fields := make(map[string]string, len(result.Fields)) + for key, value := range result.Fields { + value = strings.TrimSpace(value) + if value == "" { + continue + } + switch strings.TrimSpace(key) { + case "name": + fields["name"] = value + case "model", "model_id", "ai_model_id": + fields["model_id"] = value + case "model_name": + fields["model_name"] = value + case "exchange", "exchange_id": + fields["exchange_id"] = value + case "exchange_name": + fields["exchange_name"] = value + case "strategy", "strategy_id": + fields["strategy_id"] = value + case "strategy_name": + fields["strategy_name"] = value + case "auto_start", "scan_interval_minutes", "is_cross_margin", "show_in_competition": + fields[key] = value + } + } + if len(fields) == 0 { + return nil + } + return fields +} + +func (a *Agent) bridgeExecutionStateToSkillSession(storeUserID string, userID int64, text string, state ExecutionState, extraction executionFlowExtractionResult) (skillSession, bool) { + skillName, action := inferExecutionStateSkillBridge(state, text) + if a == nil || skillName == "" || action == "" || !hasSkillBridgeSignal(a, storeUserID, skillName, action, text, extraction) { + return skillSession{}, false + } + if skillName == "strategy_management" && action == "create" { + return skillSession{}, false + } + + session := a.getSkillSession(userID) + if session.Name != "" && (session.Name != skillName || session.Action != action) { + return skillSession{}, false + } + if session.Name == "" { + session = skillSession{ + Name: skillName, + Action: action, + Phase: "collecting", + } + } + if len(extraction.Fields) > 0 { + fields := extraction.Fields + if skillName == "trader_management" { + fields = traderCreateFieldsFromExecutionExtraction(extraction) + } + if len(fields) > 0 { + a.applyLLMExtractionToSkillSession(storeUserID, &session, llmFlowExtractionResult{ + Tasks: []llmFlowExtractionTask{{ + Skill: skillName, + Action: action, + Fields: fields, + }}, + }, "zh", text) + } + } + + switch skillName { + case "trader_management": + a.hydrateCreateTraderSlotReferences(storeUserID, &session) + } + a.saveSkillSession(userID, session) + a.clearExecutionState(userID) + return session, true +} + +func (a *Agent) redirectExecutionStateStrategyCreate(ctx context.Context, storeUserID string, userID int64, lang, text string, state ExecutionState, onEvent func(event, data string)) (string, bool, error) { + skillName, action := inferExecutionStateSkillBridge(state, text) + if skillName != "strategy_management" || action != "create" { + return "", false, nil + } + a.clearExecutionState(userID) + session := newActiveSkillSession(userID, "strategy_management", "create") + session.Goal = defaultIfEmpty(strings.TrimSpace(state.Goal), strings.TrimSpace(text)) + return a.driveActiveSession(ctx, storeUserID, userID, lang, text, session, onEvent) +} + +func (a *Agent) redirectModelCreateSessionToStrategyCreateIfNeeded(storeUserID string, userID int64, lang, text string, session skillSession) (string, bool) { + if strings.TrimSpace(session.Name) != "model_management" || strings.TrimSpace(session.Action) != "create" { + return "", false + } + strategyType := parseStrategyTypeValue(text) + if strategyType == "" && !hasExplicitCreateIntentForDomain(text, "strategy") { + return "", false + } + strategySession := skillSession{ + Name: "strategy_management", + Action: "create", + Phase: "collecting", + Fields: map[string]string{}, + } + if strategyType != "" { + setStrategyCreateType(&strategySession, strategyType) + } + a.clearSkillSession(userID) + return a.handleStrategyCreateSkill(storeUserID, userID, lang, text, strategySession), true +} + +func (a *Agent) dispatchBridgedSkillSession(storeUserID string, userID int64, lang, text string, session skillSession) (string, bool) { + switch session.Name { + case "trader_management": + if session.Action == "create" { + return a.handleCreateTraderSkill(storeUserID, userID, lang, text, session) + } + return a.handleTraderManagementSkill(storeUserID, userID, lang, text, session) + case "model_management": + if session.Action == "create" { + return a.handleModelCreateSkill(storeUserID, userID, lang, text, session), true + } + return a.handleModelManagementSkill(storeUserID, userID, lang, text, session) + case "exchange_management": + if session.Action == "create" { + return a.handleExchangeCreateSkill(storeUserID, userID, lang, text, session), true + } + return a.handleExchangeManagementSkill(storeUserID, userID, lang, text, session) + case "strategy_management": + if session.Action == "create" { + return a.handleStrategyCreateSkill(storeUserID, userID, lang, text, session), true + } + return a.handleStrategyManagementSkill(storeUserID, userID, lang, text, session) + default: + return "", false + } +} + +func (a *Agent) resolveSkillSessionTurn(ctx context.Context, userID int64, lang, text string, session skillSession) (unifiedFlowDecision, llmFlowExtractionResult) { + text = strings.TrimSpace(text) + if text == "" { + return unifiedFlowDecision{Intent: "continue_active"}, llmFlowExtractionResult{} + } + if isInstantDirectReplyText(text) { + return unifiedFlowDecision{Intent: "instant_reply"}, llmFlowExtractionResult{Intent: "instant_reply"} + } + return a.classifySkillSessionDecision(ctx, userID, lang, session, text), llmFlowExtractionResult{} +} + +func (a *Agent) resolveExecutionStateTurn(ctx context.Context, userID int64, lang string, state ExecutionState, text string) (unifiedFlowDecision, executionFlowExtractionResult) { + text = strings.TrimSpace(text) + if text == "" { + return unifiedFlowDecision{Intent: "continue_active"}, executionFlowExtractionResult{} + } + if isInstantDirectReplyText(text) { + return unifiedFlowDecision{Intent: "instant_reply"}, executionFlowExtractionResult{Intent: "instant_reply"} + } + if a.aiClient != nil { + result := a.extractExecutionStateContinuationWithLLM(ctx, userID, lang, state, text) + if decision := unifiedFlowDecisionFromIntent(result.Intent, result.TargetSnapshotID); decision.Intent != "" { + return decision, result + } + } + return a.classifyExecutionStateDecision(ctx, userID, lang, state, text), executionFlowExtractionResult{} +} + +func unifiedFlowDecisionFromIntent(intent, targetSnapshotID string) unifiedFlowDecision { + intent = strings.TrimSpace(strings.ToLower(intent)) + targetSnapshotID = strings.TrimSpace(targetSnapshotID) + switch intent { + case "continue", "continue_active": + return unifiedFlowDecision{Intent: "continue_active"} + case "cancel": + return unifiedFlowDecision{Intent: "cancel"} + case "instant_reply": + return unifiedFlowDecision{Intent: "instant_reply"} + case "switch", "interrupt", "start_new", "resume_snapshot": + if targetSnapshotID != "" { + return unifiedFlowDecision{Intent: "resume_snapshot", TargetSnapshotID: targetSnapshotID} + } + return unifiedFlowDecision{Intent: "start_new"} + default: + return unifiedFlowDecision{} + } +} + +func (a *Agent) replyToActiveFlowInstantReply(ctx context.Context, userID int64, lang, text string, onEvent func(event, data string)) string { + a.suspendActiveContexts(userID, lang) + if a.aiClient != nil { + if answer, ok := a.tryDirectAnswer(ctx, userID, lang, text, onEvent); ok { + return a.maybeAppendResumePrompt(userID, lang, text, answer) + } + } + if lang == "zh" { + return a.maybeAppendResumePrompt(userID, lang, text, "刚才的流程我先保留着。要继续的话,直接说“继续”。") + } + return a.maybeAppendResumePrompt(userID, lang, text, "I kept the previous flow available. Say “continue” when you want to resume it.") +} + +func (a *Agent) handoffFromActiveFlow(ctx context.Context, storeUserID string, userID int64, lang, text, targetSnapshotID string, onEvent func(event, data string)) (string, bool, error) { + if a.suspendAndTryRestoreSuspendedTask(userID, lang, text, targetSnapshotID) { + if a.aiClient != nil { + return a.tryMinimalBrain(ctx, storeUserID, userID, lang, text, onEvent) + } + return a.tryStatePriorityPath(ctx, storeUserID, userID, lang, text, onEvent) + } + if answer, ok, err := a.tryLLMIntentRoute(ctx, storeUserID, userID, lang, text, onEvent); ok || err != nil { + return a.maybeAppendResumePrompt(userID, lang, text, answer), true, err + } + if answer, ok := a.tryDirectAnswer(ctx, userID, lang, text, onEvent); ok { + return a.maybeAppendResumePrompt(userID, lang, text, answer), true, nil + } + if a.aiClient == nil { + if a.tryRestoreSuspendedTaskAfterSwitch(userID, text, "") { + if answer, ok := a.tryHardSkill(ctx, storeUserID, userID, lang, text, onEvent); ok { + return answer, true, nil + } + } + if answer, ok := a.tryHardSkill(ctx, storeUserID, userID, lang, text, onEvent); ok { + return a.maybeAppendResumePrompt(userID, lang, text, answer), true, nil + } + answer, err := a.noAIFallback(storeUserID, lang, text) + return a.maybeAppendResumePrompt(userID, lang, text, answer), true, err + } + answer, err := a.runPlannedAgent(ctx, storeUserID, userID, lang, text, onEvent) + return a.maybeAppendResumePrompt(userID, lang, text, answer), true, err +} + +func (a *Agent) extractExecutionStateContinuationWithLLM(ctx context.Context, userID int64, lang string, state ExecutionState, text string) executionFlowExtractionResult { + if a == nil || a.aiClient == nil || strings.TrimSpace(text) == "" { + return executionFlowExtractionResult{} + } + recentConversationCtx := a.buildRecentConversationContext(userID, text) + flowContext := fmt.Sprintf( + "Active flow type: execution_state\nGoal: %s\nStatus: %s", + state.Goal, + state.Status, + ) + waitingSummary := "" + if state.Waiting != nil { + waitingSummary = fmt.Sprintf("Waiting summary: question=%s pending_fields=%s", strings.TrimSpace(state.Waiting.Question), strings.Join(state.Waiting.PendingFields, ", ")) + } + systemPrompt, userPrompt := buildActiveFlowExtractionPrompt( + lang, + "execution_state", + flowContext, + text, + recentConversationCtx, + state.CurrentReferences, + a.SnapshotManager(userID).List(), + []string{ + fmt.Sprintf("Waiting JSON: %s", mustMarshalJSON(state.Waiting)), + waitingSummary, + }, + ) + systemPrompt += ` +- This is the structured continuation input for an active NOFXi execution flow. +- Prefer "continue" only when the message clearly contributes to the current waiting question or active execution goal. +- Use "switch" for read-only queries, unrelated requests, explanation requests, or clear topic changes. +- For "continue", extract only explicit field values that answer the waiting question or pending fields. +- Do not invent fields. If no field can be safely extracted, you may still return "continue" when the message is a meaningful free-form answer. + +Return JSON with this exact shape: +{"intent":"continue|switch|cancel|instant_reply","target_snapshot_id":"","fields":{},"reason":""}` + stageCtx, cancel := withPlannerStageTimeout(ctx, directReplyTimeout) + defer cancel() + raw, err := a.aiClient.CallWithRequest(&mcp.Request{ + Messages: []mcp.Message{ + mcp.NewSystemMessage(systemPrompt), + mcp.NewUserMessage(userPrompt), + }, + Ctx: stageCtx, + }) + if err != nil { + return executionFlowExtractionResult{} + } + envelope, ok := parseRawFlowExtractionEnvelope(raw) + if !ok { + return executionFlowExtractionResult{} + } + out := executionFlowExtractionResult{ + Intent: envelope.Intent, + TargetSnapshotID: envelope.TargetSnapshotID, + Reason: envelope.Reason, + } + if len(envelope.Fields) > 0 { + out.Fields = envelope.Fields + } else if len(envelope.Tasks) > 0 { + out.Fields = envelope.Tasks[0].Fields + } + switch out.Intent { + case "continue", "switch", "cancel", "instant_reply", "interrupt": + return out + default: + return executionFlowExtractionResult{} + } +} + +func parseSuspendedTaskSelectionResult(raw string) (suspendedTaskSelectionResult, bool) { + raw = strings.TrimSpace(raw) + raw = strings.TrimPrefix(raw, "```json") + raw = strings.TrimPrefix(raw, "```") + raw = strings.TrimSuffix(raw, "```") + raw = strings.TrimSpace(raw) + + var out suspendedTaskSelectionResult + if err := json.Unmarshal([]byte(raw), &out); err != nil { + start := strings.Index(raw, "{") + end := strings.LastIndex(raw, "}") + if start < 0 || end <= start || json.Unmarshal([]byte(raw[start:end+1]), &out) != nil { + return suspendedTaskSelectionResult{}, false + } + } + out.TargetSnapshotID = strings.TrimSpace(out.TargetSnapshotID) + if out.TargetSnapshotID == "" { + return suspendedTaskSelectionResult{}, false + } + return out, true +} + +func (a *Agent) applyExecutionStateExtraction(state *ExecutionState, result executionFlowExtractionResult) { + if state == nil || result.Intent != "continue" { + return + } + if len(result.Fields) == 0 && strings.TrimSpace(result.Reason) == "" { + return + } + fieldBits := make([]string, 0, len(result.Fields)) + for key, value := range result.Fields { + fieldBits = append(fieldBits, fmt.Sprintf("%s=%s", key, value)) + } + sort.Strings(fieldBits) + summary := "User continued the active execution flow." + if len(fieldBits) > 0 { + summary = "User supplied continuation fields: " + strings.Join(fieldBits, ", ") + } + appendExecutionLog(state, Observation{ + Kind: "waiting_user_input", + Summary: summary, + RawJSON: mustMarshalJSON(result), + CreatedAt: time.Now().UTC().Format(time.RFC3339), + }) + if state.Waiting != nil && len(state.Waiting.PendingFields) > 0 && len(result.Fields) > 0 { + remaining := make([]string, 0, len(state.Waiting.PendingFields)) + for _, field := range state.Waiting.PendingFields { + if _, ok := result.Fields[field]; ok { + continue + } + remaining = append(remaining, field) + } + state.Waiting.PendingFields = cleanStringList(remaining) + } +} + +func (a *Agent) classifySkillSessionDecision(ctx context.Context, userID int64, lang string, session skillSession, text string) unifiedFlowDecision { + return unifiedFlowDecisionFromIntent(a.classifySkillSessionInput(ctx, userID, lang, session, text), "") +} + func (a *Agent) classifySkillSessionInput(ctx context.Context, userID int64, lang string, session skillSession, text string) string { lower := strings.ToLower(strings.TrimSpace(text)) if lower == "" { @@ -922,16 +1476,25 @@ func (a *Agent) classifySkillSessionInput(ctx context.Context, userID int64, lan if isExplicitFlowAbort(text) { return "cancel" } - if shouldContinueSkillSessionByExpectedSlot(session, text) { + if strings.TrimSpace(session.Name) == "trader_management" && strings.TrimSpace(session.Action) == "create" { + if detectReadFastPath(text) == nil { + switch detectMentionedSkillDomain(text) { + case "exchange_management", "model_management", "strategy_management": + return "continue" + } + } + } + if a != nil && a.aiClient != nil { + if decision := a.classifySkillSessionIntentWithLLM(ctx, userID, lang, session, text); decision != "" { + return decision + } return "continue" } - if decision := a.classifySkillSessionIntentWithLLM(ctx, userID, lang, session, text); decision != "" { - return decision + if strings.TrimSpace(session.Name) != "" && strings.TrimSpace(session.Action) != "" && + !looksLikeNewTopLevelIntent(text) { + return "continue" } - if isNewSkillRootIntent(session, text) { - return "interrupt" - } - if isSkillFlowDeflection(session, text) { + if shouldInterruptSkillSessionBySnapshot(session, text) || shouldInterruptSkillSessionByExplicitDomainMention(session, text) || isNewSkillRootIntent(session, text) || isSkillFlowDeflection(session, text) { return "interrupt" } if belongsToSkillDomain(session.Name, text) || !looksLikeNewTopLevelIntent(text) { @@ -940,10 +1503,79 @@ func (a *Agent) classifySkillSessionInput(ctx context.Context, userID int64, lan return "interrupt" } -type skillSessionIntentDecision struct { +type activeFlowIntentDecision struct { Decision string `json:"decision"` } +type unifiedFlowDecision struct { + Intent string + TargetSnapshotID string +} + +type executionFlowExtractionResult struct { + Intent string `json:"intent,omitempty"` + TargetSnapshotID string `json:"target_snapshot_id,omitempty"` + Fields map[string]string `json:"fields,omitempty"` + Reason string `json:"reason,omitempty"` +} + +type suspendedTaskSelectionResult struct { + TargetSnapshotID string `json:"target_snapshot_id,omitempty"` +} + +func buildActiveFlowClassifierPrompt(lang, flowLabel, flowContext, text, recentConversationCtx string, currentRefs any, suspendedSnapshots any) (string, string) { + systemPrompt := `You classify one user message while an active NOFXi flow is in progress. +Return JSON only. No markdown. + +Possible decisions: +- "continue": the user is still continuing the current active flow +- "cancel": the user wants to stop the current active flow +- "interrupt": the user wants to leave the current active flow for another task, query, explanation, or topic +- "instant_reply": the user is only greeting, chatting, or thanking + +Be conservative: +- Prefer "continue" only when the message still contributes to the current active flow. +- Use "cancel" for explicit abandonment. +- Use "instant_reply" for greetings, thanks, and simple social chat. +- Use "interrupt" for unrelated requests, explanation requests, read-only queries, or clear topic shifts. +- Consider Current references JSON and Suspended snapshots JSON when resolving vague phrases like "那个", "刚才那个", or "前面那个". + +Return JSON with this exact shape: +{"decision":"continue|cancel|interrupt|instant_reply"}` + return systemPrompt, fmt.Sprintf( + "Language: %s\nActive flow label: %s\n%s\nCurrent references JSON: %s\nSuspended snapshots JSON: %s\nUser message: %s\n\nRecent conversation:\n%s", + lang, + flowLabel, + flowContext, + mustMarshalJSON(currentRefs), + mustMarshalJSON(suspendedSnapshots), + text, + recentConversationCtx, + ) +} + +func parseActiveFlowIntentDecision(raw string) string { + raw = strings.TrimSpace(raw) + raw = strings.TrimPrefix(raw, "```json") + raw = strings.TrimPrefix(raw, "```") + raw = strings.TrimSuffix(raw, "```") + raw = strings.TrimSpace(raw) + var decision activeFlowIntentDecision + if err := json.Unmarshal([]byte(raw), &decision); err != nil { + start := strings.Index(raw, "{") + end := strings.LastIndex(raw, "}") + if start < 0 || end <= start || json.Unmarshal([]byte(raw[start:end+1]), &decision) != nil { + return "" + } + } + switch strings.TrimSpace(decision.Decision) { + case "continue", "cancel", "interrupt", "instant_reply": + return decision.Decision + default: + return "" + } +} + func shouldUseLLMSkillSessionClassifier(session skillSession, text string) bool { if strings.TrimSpace(text) == "" { return false @@ -951,55 +1583,64 @@ func shouldUseLLMSkillSessionClassifier(session skillSession, text string) bool if isExplicitFlowAbort(text) || isYesReply(text) || isNoReply(text) { return false } - if shouldContinueSkillSessionByExpectedSlot(session, text) { - return false - } return true } -func shouldContinueSkillSessionByExpectedSlot(session skillSession, text string) bool { - text = strings.TrimSpace(text) - if text == "" { +func detectRootSkillIntent(text string) string { + return "" +} + +func shouldInterruptSkillSessionBySnapshot(session skillSession, text string) bool { + currentSkill := strings.TrimSpace(session.Name) + if currentSkill == "" { return false } - currentStep, ok := currentSkillDAGStep(session) - if !ok { + rootSkill := detectRootSkillIntent(text) + if rootSkill == "" { return false } - switch currentStep.ID { - case "await_start_confirmation", "await_confirmation": - return isYesReply(text) || isNoReply(text) - case "resolve_config_value": - if fieldValue(session, "config_field") == "selected_timeframes" { - return timeframeTokenRE.MatchString(strings.ToLower(text)) - } - return firstIntegerPattern.MatchString(text) - case "collect_enabled": - _, ok := parseEnabledValue(text) - return ok - case "collect_custom_api_url": - return extractURL(text) != "" - case "resolve_exchange_type": - return exchangeTypeFromText(text) != "" - case "resolve_provider": - return providerFromText(text) != "" - case "resolve_name", "collect_name", "collect_prompt", "collect_account_name", "collect_custom_model_name": - return !looksLikeNewTopLevelIntent(text) - } - for _, field := range currentStep.RequiredFields { - switch field { - case "config_value": - return firstIntegerPattern.MatchString(text) - case "enabled": - _, ok := parseEnabledValue(text) - return ok - case "custom_api_url": - return extractURL(text) != "" - } + if rootSkill != currentSkill && looksLikeNewTopLevelIntent(text) { + return true } return false } +func detectMentionedSkillDomain(text string) string { + lower := strings.ToLower(strings.TrimSpace(text)) + switch { + case containsAny(lower, []string{"交易员", "trader", "agent"}): + return "trader_management" + case containsAny(lower, []string{"策略", "strategy"}): + return "strategy_management" + case containsAny(lower, []string{"模型", "model"}): + return "model_management" + case containsAny(lower, []string{"交易所", "exchange"}): + return "exchange_management" + default: + return "" + } +} + +func shouldInterruptSkillSessionByExplicitDomainMention(session skillSession, text string) bool { + currentSkill := strings.TrimSpace(session.Name) + if currentSkill == "" { + return false + } + if currentSkill == "trader_management" { + if currentStep, ok := currentSkillDAGStep(session); ok { + switch currentStep.ID { + case "resolve_exchange", "resolve_model", "resolve_strategy", "collect_bindings": + return false + } + } + } + mentioned := detectMentionedSkillDomain(text) + if mentioned == "" || mentioned == currentSkill { + return false + } + return looksLikeNewTopLevelIntent(text) +} + func (a *Agent) classifySkillSessionIntentWithLLM(ctx context.Context, userID int64, lang string, session skillSession, text string) string { if a == nil || a.aiClient == nil { return "" @@ -1009,27 +1650,26 @@ func (a *Agent) classifySkillSessionIntentWithLLM(ctx context.Context, userID in } currentStep, _ := currentSkillDAGStep(session) recentConversationCtx := a.buildRecentConversationContext(userID, text) - systemPrompt := `You classify one user message while a NOFXi structured management flow is active. -Return JSON only. No markdown. - -Possible decisions: -- "continue": the user is still answering the current flow -- "cancel": the user wants to stop the current flow -- "interrupt": the user changed topic, wants diagnosis/query/new task, or should leave the current flow - -Be conservative: -- Prefer "continue" only when the message clearly answers the current slot/question. -- Use "cancel" for explicit abandonment like "算了", "不改了", "换话题", "别弄了". -- Use "interrupt" for diagnosis, query, new requests, or topic shifts.` - userPrompt := fmt.Sprintf( - "Language: %s\nActive skill: %s\nAction: %s\nCurrent DAG step: %s\nExpected required fields: %s\nUser message: %s\n\nRecent conversation:\n%s", - lang, + state := a.getExecutionState(userID) + flowContext := fmt.Sprintf( + "Active skill: %s\nAction: %s\nCurrent DAG step: %s\nExpected required fields: %s\nSkill session fields JSON: %s", session.Name, session.Action, currentStep.ID, strings.Join(currentStep.RequiredFields, ", "), + mustMarshalJSON(session.Fields), + ) + if skillContext := buildCurrentSkillExecutionContext(lang, session); skillContext != "" { + flowContext += "\n" + skillContext + } + systemPrompt, userPrompt := buildActiveFlowClassifierPrompt( + lang, + "skill_session", + flowContext, text, recentConversationCtx, + state.CurrentReferences, + a.SnapshotManager(userID).List(), ) stageCtx, cancel := withPlannerStageTimeout(ctx, directReplyTimeout) defer cancel() @@ -1043,25 +1683,45 @@ Be conservative: if err != nil { return "" } - raw = strings.TrimSpace(raw) - raw = strings.TrimPrefix(raw, "```json") - raw = strings.TrimPrefix(raw, "```") - raw = strings.TrimSuffix(raw, "```") - raw = strings.TrimSpace(raw) - var decision skillSessionIntentDecision - if err := json.Unmarshal([]byte(raw), &decision); err != nil { - start := strings.Index(raw, "{") - end := strings.LastIndex(raw, "}") - if start < 0 || end <= start || json.Unmarshal([]byte(raw[start:end+1]), &decision) != nil { - return "" - } - } - switch strings.TrimSpace(decision.Decision) { - case "continue", "cancel", "interrupt": - return decision.Decision - default: + return parseActiveFlowIntentDecision(raw) +} + +func (a *Agent) classifyExecutionStateIntentWithLLM(ctx context.Context, userID int64, lang string, state ExecutionState, text string) string { + if a == nil || a.aiClient == nil { return "" } + if strings.TrimSpace(text) == "" || isExplicitFlowAbort(text) || isYesReply(text) || isNoReply(text) || shouldResetExecutionStateForNewAttempt(text, state) { + return "" + } + recentConversationCtx := a.buildRecentConversationContext(userID, text) + flowContext := fmt.Sprintf( + "Goal: %s\nStatus: %s\nWaiting JSON: %s", + state.Goal, + state.Status, + mustMarshalJSON(state.Waiting), + ) + systemPrompt, userPrompt := buildActiveFlowClassifierPrompt( + lang, + "execution_state", + flowContext, + text, + recentConversationCtx, + state.CurrentReferences, + a.SnapshotManager(userID).List(), + ) + stageCtx, cancel := withPlannerStageTimeout(ctx, directReplyTimeout) + defer cancel() + raw, err := a.aiClient.CallWithRequest(&mcp.Request{ + Messages: []mcp.Message{ + mcp.NewSystemMessage(systemPrompt), + mcp.NewUserMessage(userPrompt), + }, + Ctx: stageCtx, + }) + if err != nil { + return "" + } + return parseActiveFlowIntentDecision(raw) } func isSkillFlowDeflection(session skillSession, text string) bool { @@ -1077,13 +1737,13 @@ func isSkillFlowDeflection(session skillSession, text string) bool { } switch strings.TrimSpace(session.Name) { case "exchange_management": - return detectModelDiagnosisSkill(text) || detectTraderDiagnosisSkill(text) || detectStrategyDiagnosisSkill(text) + return false case "model_management": - return detectExchangeDiagnosisSkill(text) || detectTraderDiagnosisSkill(text) || detectStrategyDiagnosisSkill(text) + return false case "strategy_management": - return detectExchangeDiagnosisSkill(text) || detectTraderDiagnosisSkill(text) || detectModelDiagnosisSkill(text) + return false case "trader_management": - return detectExchangeDiagnosisSkill(text) || detectModelDiagnosisSkill(text) || detectStrategyDiagnosisSkill(text) + return false default: return false } @@ -1095,31 +1755,57 @@ func isNewSkillRootIntent(session skillSession, text string) bool { if currentSkill == "" { return false } + if currentSkill != "trader_management" && hasExplicitManagementDomainCue(text, "trader") && containsAny(strings.ToLower(strings.TrimSpace(text)), []string{"创建", "新建", "create", "new"}) { + return true + } + if currentSkill != "strategy_management" && hasExplicitManagementDomainCue(text, "strategy") && containsAny(strings.ToLower(strings.TrimSpace(text)), []string{"创建", "新建", "create", "new"}) { + return true + } + if currentSkill != "model_management" && hasExplicitManagementDomainCue(text, "model") && containsAny(strings.ToLower(strings.TrimSpace(text)), []string{"创建", "新建", "create", "new"}) { + return true + } + if currentSkill != "exchange_management" && hasExplicitManagementDomainCue(text, "exchange") && containsAny(strings.ToLower(strings.TrimSpace(text)), []string{"创建", "新建", "create", "new"}) { + return true + } switch currentSkill { case "trader_management": - if detectCreateTraderSkill(text) && currentAction != "create" { - return true - } - if action := normalizeAtomicSkillAction("trader_management", detectManagementAction(text, "trader")); action == "create" && currentAction != "create" { - return true - } + return hasExplicitCreateIntentForDomain(text, "trader") && currentAction != "create" case "strategy_management": - if action := normalizeAtomicSkillAction("strategy_management", detectManagementAction(text, "strategy")); action == "create" && currentAction != "create" { - return true - } + return hasExplicitManagementDomainCue(text, "strategy") && containsAny(strings.ToLower(strings.TrimSpace(text)), []string{"创建", "新建", "create", "new"}) && currentAction != "create" case "model_management": - if action := normalizeAtomicSkillAction("model_management", detectManagementAction(text, "model")); action == "create" && currentAction != "create" { - return true - } + return hasExplicitManagementDomainCue(text, "model") && containsAny(strings.ToLower(strings.TrimSpace(text)), []string{"创建", "新建", "create", "new"}) && currentAction != "create" case "exchange_management": - if action := normalizeAtomicSkillAction("exchange_management", detectManagementAction(text, "exchange")); action == "create" && currentAction != "create" { - return true - } + return hasExplicitManagementDomainCue(text, "exchange") && containsAny(strings.ToLower(strings.TrimSpace(text)), []string{"创建", "新建", "create", "new"}) && currentAction != "create" } return false } -func classifyExecutionStateInput(state ExecutionState, text string) string { +func shouldSuspendInterruptedTask(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return false + } + if isConfigOrTraderIntent(text) || detectRootSkillIntent(text) != "" { + return false + } + if hasExplicitManagementDomainCue(text, "trader") || hasExplicitManagementDomainCue(text, "model") || + hasExplicitManagementDomainCue(text, "exchange") || hasExplicitManagementDomainCue(text, "strategy") { + return false + } + if req := detectReadFastPath(text); req != nil { + return isEphemeralReadFastPathKind(req.Kind) + } + return containsAny(lower, []string{ + "btc", "eth", "sol", "价格", "行情", "balance", "position", "positions", "portfolio", + "market", "price", "仓位", "持仓", "余额", "账户", "trade history", "历史成交", + }) +} + +func (a *Agent) classifyExecutionStateDecision(ctx context.Context, userID int64, lang string, state ExecutionState, text string) unifiedFlowDecision { + return unifiedFlowDecisionFromIntent(a.classifyExecutionStateInput(ctx, userID, lang, state, text), "") +} + +func (a *Agent) classifyExecutionStateInput(ctx context.Context, userID int64, lang string, state ExecutionState, text string) string { lower := strings.ToLower(strings.TrimSpace(text)) if lower == "" { return "continue" @@ -1130,6 +1816,12 @@ func classifyExecutionStateInput(state ExecutionState, text string) string { if isYesReply(text) || isNoReply(text) || shouldResetExecutionStateForNewAttempt(text, state) { return "continue" } + if a != nil && a.aiClient != nil { + if decision := a.classifyExecutionStateIntentWithLLM(ctx, userID, lang, state, text); decision != "" { + return decision + } + return "continue" + } if state.Waiting != nil && !looksLikeNewTopLevelIntent(text) { return "continue" } @@ -1139,6 +1831,435 @@ func classifyExecutionStateInput(state ExecutionState, text string) string { return "continue" } +func isResumeFlowReply(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + switch lower { + case "继续", "继续吧", "继续刚才的", "恢复", "恢复刚才的", "resume", "continue", "继续创建", "继续配置": + return true + default: + return false + } +} + +func isCancelParentFlowReply(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + switch lower { + case "一并取消", "也取消", "都取消", "全部取消", "取消父任务", "cancel all", "cancel parent", "drop all": + return true + default: + return false + } +} + +func suspendedTaskResumePrompt(lang string, task SuspendedTask) string { + hint := strings.TrimSpace(task.ResumeHint) + if hint == "" { + if lang == "zh" { + hint = "刚才未完成的任务还在,要继续吗?" + } else { + hint = "Your previous unfinished task is still here. Do you want to continue?" + } + } + return hint +} + +func (a *Agent) maybeOfferParentTaskAfterCancel(userID int64, lang string) string { + task, ok := a.SnapshotManager(userID).Peek() + if !ok { + if lang == "zh" { + return "已取消当前流程。" + } + return "Cancelled the current flow." + } + if lang == "zh" { + return "已取消当前流程。\n" + suspendedTaskResumePrompt(lang, task) + "\n如果父任务也不要了,回复“一并取消”。" + } + return "Cancelled the current flow.\n" + suspendedTaskResumePrompt(lang, task) + "\nReply 'cancel all' if you want to cancel the parent task too." +} + +func suspendedTaskDomain(task SuspendedTask) string { + task = normalizeSuspendedTask(task) + if task.SkillSession != nil { + return strings.TrimSpace(task.SkillSession.Name) + } + if task.WorkflowSession != nil { + for _, item := range task.WorkflowSession.Tasks { + if strings.TrimSpace(item.Skill) != "" { + return strings.TrimSpace(item.Skill) + } + } + } + return "" +} + +func (a *Agent) buildSuspendedTask(userID int64, lang string) SuspendedTask { + task := SuspendedTask{} + if session := a.getSkillSession(userID); strings.TrimSpace(session.Name) != "" { + sessionCopy := normalizeSkillSession(session) + task.Kind = "skill_session" + task.SkillSession = &sessionCopy + task.ResumeHint = buildSkillResumeHint(lang, sessionCopy) + if sessionCopy.Name == "trader_management" && sessionCopy.Action == "create" { + task.ResumeOnSuccess = true + task.ResumeTriggers = []string{"exchange_management", "model_management", "strategy_management"} + } + } + if workflow := a.getWorkflowSession(userID); hasActiveWorkflowSession(workflow) { + workflowCopy := normalizeWorkflowSession(workflow) + task.Kind = "workflow_session" + task.WorkflowSession = &workflowCopy + if task.ResumeHint == "" { + task.ResumeHint = buildWorkflowResumeHint(lang, workflowCopy) + } + } + if state := a.getExecutionState(userID); hasActiveExecutionState(state) { + stateCopy := normalizeExecutionState(state) + if task.Kind == "" { + task.Kind = "execution_state" + } + task.ExecutionState = &stateCopy + if task.ResumeHint == "" { + task.ResumeHint = buildExecutionResumeHint(lang, stateCopy) + } + } + if a.history != nil { + if msgs := a.history.Get(userID); len(msgs) > 0 { + if len(msgs) > chatHistoryMaxTurns { + msgs = msgs[len(msgs)-chatHistoryMaxTurns:] + } + task.LocalHistory = msgs + } + } + return normalizeSuspendedTask(task) +} + +func buildSkillResumeHint(lang string, session skillSession) string { + target := "" + if session.TargetRef != nil { + target = defaultIfEmpty(session.TargetRef.Name, session.TargetRef.ID) + } + if lang == "zh" { + switch session.Name { + case "strategy_management": + if target != "" { + return fmt.Sprintf("刚才关于策略“%s”的流程还没完成,要继续吗?", target) + } + return "刚才的策略配置流程还没完成,要继续吗?" + case "model_management": + if target != "" { + return fmt.Sprintf("刚才关于模型“%s”的流程还没完成,要继续吗?", target) + } + return "刚才的模型配置流程还没完成,要继续吗?" + case "exchange_management": + if target != "" { + return fmt.Sprintf("刚才关于交易所“%s”的流程还没完成,要继续吗?", target) + } + return "刚才的交易所配置流程还没完成,要继续吗?" + case "trader_management": + if target != "" { + return fmt.Sprintf("刚才关于交易员“%s”的流程还没完成,要继续吗?", target) + } + return "刚才的交易员配置流程还没完成,要继续吗?" + } + } + if target != "" { + return fmt.Sprintf("The flow for %s is still unfinished. Do you want to continue?", target) + } + return "The previous configuration flow is still unfinished. Do you want to continue?" +} + +func buildWorkflowResumeHint(lang string, session WorkflowSession) string { + req := strings.TrimSpace(session.OriginalRequest) + if req == "" { + if lang == "zh" { + return "刚才的多步任务还没完成,要继续吗?" + } + return "The previous workflow is still unfinished. Do you want to continue?" + } + if lang == "zh" { + return fmt.Sprintf("刚才关于“%s”的多步任务还没完成,要继续吗?", req) + } + return fmt.Sprintf("The workflow for %q is still unfinished. Do you want to continue?", req) +} + +func buildExecutionResumeHint(lang string, state ExecutionState) string { + if state.Waiting != nil && strings.TrimSpace(state.Waiting.Question) != "" { + if lang == "zh" { + return fmt.Sprintf("刚才我们停在这个问题:%s 回复“继续”我就接着来。", state.Waiting.Question) + } + return fmt.Sprintf("We paused at this question: %s Reply 'continue' and I'll resume.", state.Waiting.Question) + } + goal := strings.TrimSpace(state.Goal) + if goal == "" { + if lang == "zh" { + return "刚才未完成的任务还在,要继续吗?" + } + return "The previous unfinished task is still here. Do you want to continue?" + } + if lang == "zh" { + return fmt.Sprintf("刚才关于“%s”的任务还没完成,要继续吗?", goal) + } + return fmt.Sprintf("The task for %q is still unfinished. Do you want to continue?", goal) +} + +func (a *Agent) suspendActiveContexts(userID int64, lang string) bool { + task := a.buildSuspendedTask(userID, lang) + if task.Kind == "" { + return false + } + a.SnapshotManager(userID).Save(task) + a.clearSkillSession(userID) + a.clearWorkflowSession(userID) + a.clearExecutionState(userID) + return true +} + +func (a *Agent) restoreSuspendedTask(userID int64, task SuspendedTask) bool { + task = normalizeSuspendedTask(task) + if task.Kind == "" { + return false + } + a.clearSkillSession(userID) + a.clearWorkflowSession(userID) + a.clearExecutionState(userID) + if a.history != nil && len(task.LocalHistory) > 0 { + a.history.Replace(userID, task.LocalHistory) + } + if task.ExecutionState != nil { + _ = a.saveExecutionState(*task.ExecutionState) + } + if task.WorkflowSession != nil { + a.saveWorkflowSession(userID, *task.WorkflowSession) + } + if task.SkillSession != nil { + a.saveSkillSession(userID, *task.SkillSession) + } + return true +} + +func (a *Agent) restoreSuspendedTaskByID(userID int64, snapshotID string) bool { + snapshotID = strings.TrimSpace(snapshotID) + if snapshotID == "" { + return false + } + manager := a.SnapshotManager(userID) + stack := manager.Stack() + if len(stack) == 0 { + return false + } + match := -1 + for i := len(stack) - 1; i >= 0; i-- { + if strings.TrimSpace(stack[i].SnapshotID) == snapshotID { + match = i + break + } + } + if match < 0 { + return false + } + task, ok := manager.RemoveAt(match) + if !ok { + return false + } + return a.restoreSuspendedTask(userID, task) +} + +func (a *Agent) tryRestoreSuspendedTaskAfterSwitch(userID int64, text, targetSnapshotID string) bool { + if a.restoreSuspendedTaskByID(userID, targetSnapshotID) { + return true + } + return a.restoreMatchingSuspendedTask(userID, text) +} + +func (a *Agent) suspendAndTryRestoreSuspendedTask(userID int64, lang, text, targetSnapshotID string) bool { + a.suspendActiveContexts(userID, lang) + return a.tryRestoreSuspendedTaskAfterSwitch(userID, text, targetSnapshotID) +} + +func (a *Agent) tryResumeSuspendedTask(userID int64, lang, text string) (string, bool) { + if isCancelParentFlowReply(text) && !a.hasActiveSkillSession(userID) && !hasActiveWorkflowSession(a.getWorkflowSession(userID)) && !hasActiveExecutionState(a.getExecutionState(userID)) { + a.SnapshotManager(userID).Clear() + if lang == "zh" { + return "已把之前挂起的父任务也一并取消。", true + } + return "Cancelled the previously suspended parent tasks as well.", true + } + if !isResumeFlowReply(text) { + return "", false + } + if a.hasActiveSkillSession(userID) || hasActiveWorkflowSession(a.getWorkflowSession(userID)) || hasActiveExecutionState(a.getExecutionState(userID)) { + return "", false + } + task, ok := a.SnapshotManager(userID).Load() + if !ok { + return "", false + } + if !a.restoreSuspendedTask(userID, task) { + return "", false + } + return suspendedTaskResumePrompt(lang, task), true +} + +func (a *Agent) tryRestoreSuspendedTaskWithLLM(ctx context.Context, userID int64, lang, text string) bool { + if a == nil || a.aiClient == nil || strings.TrimSpace(text) == "" { + return false + } + snapshots := a.SnapshotManager(userID).List() + if len(snapshots) == 0 { + return false + } + snapshotsJSON, _ := json.Marshal(snapshots) + recentConversationCtx := a.buildRecentConversationContext(userID, text) + systemPrompt := `You select whether a user message refers to one suspended NOFXi snapshot that should be restored now. +Return JSON only. No markdown. + +Rules: +- Choose target_snapshot_id only when the user clearly refers to exactly one suspended snapshot. +- Prefer empty target_snapshot_id when uncertain. +- Use the snapshot resume hint, kind, and recent conversation to resolve references like "刚才那个", "the model one", or "继续那个策略". +- Never invent snapshot ids. + +Return JSON with this exact shape: +{"target_snapshot_id":""}` + userPrompt := fmt.Sprintf("Language: %s\nUser message: %s\nSuspended snapshots JSON: %s\n\nRecent conversation:\n%s", lang, text, string(snapshotsJSON), recentConversationCtx) + + stageCtx, cancel := withPlannerStageTimeout(ctx, directReplyTimeout) + defer cancel() + raw, err := a.aiClient.CallWithRequest(&mcp.Request{ + Messages: []mcp.Message{ + mcp.NewSystemMessage(systemPrompt), + mcp.NewUserMessage(userPrompt), + }, + Ctx: stageCtx, + }) + if err != nil { + return false + } + selection, ok := parseSuspendedTaskSelectionResult(raw) + if !ok { + return false + } + return a.restoreSuspendedTaskByID(userID, selection.TargetSnapshotID) +} + +func (a *Agent) tryRestoreSuspendedTaskFromIdle(ctx context.Context, userID int64, lang, text string) bool { + if a.tryRestoreAwaitingConfirmationSnapshot(userID, text) { + return true + } + if a.tryRestoreSuspendedTaskWithLLM(ctx, userID, lang, text) { + return true + } + return a.restoreMatchingSuspendedTask(userID, text) +} + +func (a *Agent) tryRestoreAwaitingConfirmationSnapshot(userID int64, text string) bool { + if !isYesReply(text) && !isNoReply(text) && !createConfirmationReply(text) { + return false + } + stack := a.SnapshotManager(userID).Stack() + if len(stack) != 1 { + return false + } + task := stack[0] + if task.Kind != "skill_session" || task.SkillSession == nil { + return false + } + phase := strings.TrimSpace(task.SkillSession.Phase) + switch phase { + case "await_confirmation", "await_create_confirmation", "await_start_confirmation": + return a.restoreSuspendedTask(userID, task) + default: + return false + } +} + +func (a *Agent) restoreMatchingSuspendedTask(userID int64, text string) bool { + wanted := detectRootSkillIntent(text) + if wanted == "" { + wanted = detectMentionedSkillDomain(text) + } + if wanted == "" { + return false + } + manager := a.SnapshotManager(userID) + fullStack := manager.Stack() + if len(fullStack) == 0 { + return false + } + match := -1 + for i := len(fullStack) - 1; i >= 0; i-- { + if suspendedTaskDomain(fullStack[i]) == wanted { + match = i + break + } + } + if match < 0 { + return false + } + task, ok := manager.RemoveAt(match) + if !ok { + return false + } + return a.restoreSuspendedTask(userID, task) +} + +func (a *Agent) maybeAppendResumePrompt(userID int64, lang, text, answer string) string { + a.trackPendingProposalSession(userID, lang, text, answer) + if strings.TrimSpace(answer) == "" || !shouldSuspendInterruptedTask(text) { + return answer + } + if a.hasActiveSkillSession(userID) || hasActiveWorkflowSession(a.getWorkflowSession(userID)) || hasActiveExecutionState(a.getExecutionState(userID)) { + return answer + } + task, ok := a.SnapshotManager(userID).Peek() + if !ok { + return answer + } + prompt := suspendedTaskResumePrompt(lang, task) + if prompt == "" || strings.Contains(answer, prompt) { + return answer + } + return strings.TrimSpace(answer) + "\n\n" + prompt +} + +func (a *Agent) trackPendingProposalSession(userID int64, lang, sourceUserText, answer string) { + answer = strings.TrimSpace(answer) + if answer == "" { + return + } + if looksLikePendingProposalReply(answer) { + if a.hasActiveSkillSession(userID) || hasActiveWorkflowSession(a.getWorkflowSession(userID)) || hasActiveExecutionState(a.getExecutionState(userID)) { + a.suspendActiveContexts(userID, lang) + } + a.clearActiveSkillSession(userID) + a.savePendingProposalSession(PendingProposalSession{ + UserID: userID, + SourceUserText: strings.TrimSpace(sourceUserText), + ProposalText: answer, + }) + return + } + a.clearPendingProposalSession(userID) +} + +func looksLikePendingProposalReply(answer string) bool { + lower := strings.ToLower(strings.TrimSpace(answer)) + if lower == "" { + return false + } + return containsAny(lower, []string{ + "需要我按这个方案操作吗", + "按这个方案操作吗", + "你想选哪个", + "请选择", + "两个选择", + "直接使用已有的", + "which option do you want", + "do you want me to follow this plan", + "should i proceed with this plan", + }) +} + func isExplicitFlowAbort(text string) bool { lower := strings.ToLower(strings.TrimSpace(text)) if lower == "" { @@ -1156,13 +2277,13 @@ func isExplicitFlowAbort(text string) bool { func belongsToSkillDomain(skillName, text string) bool { switch strings.TrimSpace(skillName) { case "trader_management": - return detectCreateTraderSkill(text) || detectTraderManagementIntent(text) || detectTraderDiagnosisSkill(text) + return hasExplicitCreateIntentForDomain(text, "trader") case "strategy_management": - return detectStrategyManagementIntent(text) || detectStrategyDiagnosisSkill(text) + return false case "model_management": - return detectModelManagementIntent(text) || detectModelDiagnosisSkill(text) + return false case "exchange_management": - return detectExchangeManagementIntent(text) || detectExchangeDiagnosisSkill(text) + return false default: return false } @@ -1176,15 +2297,7 @@ func looksLikeNewTopLevelIntent(text string) bool { if strings.HasPrefix(lower, "/") { return true } - if detectCreateTraderSkill(text) || - detectTraderManagementIntent(text) || - detectExchangeManagementIntent(text) || - detectModelManagementIntent(text) || - detectStrategyManagementIntent(text) || - detectTraderDiagnosisSkill(text) || - detectExchangeDiagnosisSkill(text) || - detectModelDiagnosisSkill(text) || - detectStrategyDiagnosisSkill(text) { + if hasExplicitCreateIntentForDomain(text, "trader") { return true } if detectReadFastPath(text) != nil { @@ -1206,10 +2319,8 @@ func (a *Agent) tryDirectAnswer(ctx context.Context, userID int64, lang, text st return "", false } - recentConversationCtx := a.buildRecentConversationContext(userID, text) - taskStateCtx := buildTaskStateContext(a.getTaskState(userID)) - executionState := normalizeExecutionState(a.getExecutionState(userID)) - executionJSON, _ := json.Marshal(executionState) + currentTurnCtx := a.buildCurrentTurnContext(userID, lang, text) + activeTaskCtx := a.buildActiveTaskStateContext(userID, lang) systemPrompt := `You are the first-pass router for NOFXi. Decide whether the assistant can answer the user's message directly without using skills, tools, or planning. Return JSON only. Do not return markdown. @@ -1226,17 +2337,24 @@ Use "defer" when the message likely needs: - tool reads - multi-step planning - continuation of an active execution flow that needs stateful follow-up +- interpretation of current product state, observations, counts, duplicates, balances, configuration-page findings, or anything that sounds like "I see / I noticed / there are still ..." Rules: -- Consider Recent conversation, Task state, and Execution state JSON before deciding. +- If you choose direct_answer, write for a trading beginner, not a developer. +- Keep the answer simple, clear, and easy to act on. +- Lead with the conclusion first, then one or two concrete next steps when helpful. +- Avoid internal jargon, architecture talk, tool names, or implementation detail unless the user explicitly asks. +- Use Current turn context as the primary memory for this turn. +- Use Active task state only as a compact summary of any unfinished operational flow. - Default to direct_answer for greetings, thanks, identity questions, and other lightweight conversational turns unless there is a clearly unfinished operational flow that the user is continuing. - If the user is clearly continuing an unfinished operational flow, choose defer. +- If the user mentions concrete operational entities or observations such as traders, strategies, models, exchanges, balances, counts, duplicate items, config pages, or numeric account facts, choose defer. - If you choose direct_answer, provide the final user-facing answer in the same language as the user. - Prefer defer when uncertain. Return JSON with this exact shape: {"action":"direct_answer|defer","answer":""}` - userPrompt := fmt.Sprintf("Language: %s\nUser message: %s\n\nRecent conversation:\n%s\n\nTask state:\n%s\n\nExecution state JSON:\n%s", lang, text, recentConversationCtx, taskStateCtx, string(executionJSON)) + userPrompt := fmt.Sprintf("Language: %s\nUser message: %s\n\nCurrent turn context:\n%s\n\nActive task state:\n%s", lang, text, defaultIfEmpty(currentTurnCtx, "(empty)"), defaultIfEmpty(activeTaskCtx, "(empty)")) stageCtx, cancel := withPlannerStageTimeout(ctx, directReplyTimeout) defer cancel() @@ -1265,13 +2383,14 @@ Return JSON with this exact shape: return "", false } - a.ensureHistory() + if a.history == nil { + a.history = newChatHistory(chatHistoryMaxTurns) + } a.history.Add(userID, "user", text) a.history.Add(userID, "assistant", answer) - a.maybeUpdateTaskStateIncrementally(ctx, userID) - a.maybeCompressHistory(ctx, userID) + a.runPostResponseMaintenanceAsync(userID) if onEvent != nil { - onEvent(StreamEventDelta, answer) + emitStreamText(onEvent, answer) } return answer, true } @@ -1303,7 +2422,135 @@ func normalizeDirectReplyDecision(decision directReplyDecision) directReplyDecis return decision } +func looksLikeInternalAgentJSON(raw string) bool { + raw = strings.TrimSpace(raw) + if raw == "" || !strings.HasPrefix(raw, "{") || !strings.HasSuffix(raw, "}") { + return false + } + var payload map[string]any + if err := json.Unmarshal([]byte(raw), &payload); err != nil { + return false + } + if _, ok := payload["intent"]; ok { + if _, hasTasks := payload["tasks"]; hasTasks { + return true + } + if _, hasFields := payload["fields"]; hasFields { + return true + } + if _, hasReason := payload["reason"]; hasReason { + return true + } + } + return false +} + +func firstFlowExtractionFields(result llmFlowExtractionResult) map[string]string { + if len(result.Fields) > 0 { + return result.Fields + } + if len(result.Tasks) > 0 && len(result.Tasks[0].Fields) > 0 { + return result.Tasks[0].Fields + } + return nil +} + +func (a *Agent) tryRecoverFromInternalAgentJSON(ctx context.Context, storeUserID string, userID int64, lang, text, raw string, onEvent func(event, data string)) (string, bool, error) { + result := parseLLMFlowExtractionResult(raw) + if result.Intent == "" { + return "", false, nil + } + switch result.Intent { + case "instant_reply": + return a.replyToActiveFlowInstantReply(ctx, userID, lang, text, onEvent), true, nil + case "cancel": + if a.hasActiveSkillSession(userID) { + a.clearSkillSession(userID) + } + if hasActiveExecutionState(a.getExecutionState(userID)) { + a.clearExecutionState(userID) + } + return a.maybeOfferParentTaskAfterCancel(userID, lang), true, nil + case "continue": + if session := a.getSkillSession(userID); strings.TrimSpace(session.Name) != "" { + a.applyLLMExtractionToSkillSession(storeUserID, &session, result, lang, text) + a.saveSkillSession(userID, session) + if answer, ok := a.dispatchBridgedSkillSession(storeUserID, userID, lang, text, session); ok { + return answer, true, nil + } + } + if len(result.Tasks) > 0 { + task := result.Tasks[0] + if strings.TrimSpace(task.Skill) != "" { + recovered := skillSession{ + Name: strings.TrimSpace(task.Skill), + Action: strings.TrimSpace(task.Action), + Phase: "collecting", + Fields: map[string]string{}, + } + if suspended, ok := a.SnapshotManager(userID).Peek(); ok && suspended.SkillSession != nil { + suspendedSkill := strings.TrimSpace(suspended.SkillSession.Name) + suspendedAction := strings.TrimSpace(suspended.SkillSession.Action) + if suspendedSkill == recovered.Name && (recovered.Action == "" || suspendedAction == recovered.Action) { + recovered = *suspended.SkillSession + } + } + for key, value := range task.Fields { + setField(&recovered, key, value) + } + recovered = normalizeSkillSession(recovered) + if recovered.Name == "trader_management" && recovered.Action == "create" { + a.hydrateCreateTraderSlotReferences(storeUserID, &recovered) + } + if recovered.Name == "trader_management" && recovered.Action == "create" && len(missingFieldKeysForSkillSession(recovered)) == 0 { + if fieldValue(recovered, "auto_start") == "true" { + recovered.Phase = "await_start_confirmation" + a.saveSkillSession(userID, recovered) + if lang == "zh" { + return fmt.Sprintf("准备创建交易员并立即启动。\n交易所:%s\n模型:%s\n策略:%s\n\n回复确认继续,回复先不用则只创建不启动。", + traderCreateExchangeNameOrID(recovered), traderCreateModelNameOrID(recovered), traderCreateStrategyNameOrID(recovered)), true, nil + } + return fmt.Sprintf("Ready to create trader and start it immediately.\nExchange: %s\nModel: %s\nStrategy: %s\n\nReply confirm to continue, or no to create without starting.", + traderCreateExchangeNameOrID(recovered), traderCreateModelNameOrID(recovered), traderCreateStrategyNameOrID(recovered)), true, nil + } + recovered.Phase = "await_create_confirmation" + a.saveSkillSession(userID, recovered) + return formatTraderCreateDraftSummary(lang, recovered), true, nil + } + a.saveSkillSession(userID, recovered) + if answer, ok := a.dispatchBridgedSkillSession(storeUserID, userID, lang, text, recovered); ok { + return answer, true, nil + } + } + } + if state := a.getExecutionState(userID); hasActiveExecutionState(state) { + extraction := executionFlowExtractionResult{ + Intent: "continue", + TargetSnapshotID: result.TargetSnapshotID, + Fields: firstFlowExtractionFields(result), + Reason: result.Reason, + } + if answer, handled, err := a.redirectExecutionStateStrategyCreate(ctx, storeUserID, userID, lang, text, state, onEvent); handled || err != nil { + return answer, handled, err + } + if session, ok := a.bridgeExecutionStateToSkillSession(storeUserID, userID, text, state, extraction); ok { + answer, handled := a.dispatchBridgedSkillSession(storeUserID, userID, lang, text, session) + return answer, handled, nil + } + } + } + return "", false, nil +} + func (a *Agent) runPlannedAgent(ctx context.Context, storeUserID string, userID int64, lang, text string, onEvent func(event, data string)) (string, error) { + return a.runPlannedAgentWithContextMode(ctx, storeUserID, userID, lang, text, "", onEvent) +} + +func (a *Agent) runPlannedAgentWithContextMode(ctx context.Context, storeUserID string, userID int64, lang, text string, contextMode string, onEvent func(event, data string)) (string, error) { + if session, ok := a.activeStrategyCreateSession(userID); ok { + answer, _, err := a.driveActiveSession(ctx, storeUserID, userID, lang, text, session, onEvent) + return answer, err + } a.ensureHistory() a.history.Add(userID, "user", text) if onEvent != nil { @@ -1311,19 +2558,26 @@ func (a *Agent) runPlannedAgent(ctx context.Context, storeUserID string, userID } requestStartedAt := time.Now() - state, err := a.prepareExecutionState(ctx, storeUserID, userID, lang, text) + state, err := a.prepareExecutionState(ctx, storeUserID, userID, lang, text, contextMode) if err != nil { a.logPlannerTiming("", userID, "prepare_execution_state", requestStartedAt, err) if isPlannerTimeoutError(err) { msg := plannerTimeoutMessage(lang) if onEvent != nil { onEvent(StreamEventError, msg) - onEvent(StreamEventDelta, msg) + emitStreamText(onEvent, msg) } return msg, nil } + if hasExplicitCreateIntentForDomain(text, "strategy") { + a.logger.Warn("planner failed during strategy create; using template strategy flow instead of legacy loop", "error", err, "user_id", userID) + session := newActiveSkillSession(userID, "strategy_management", "create") + session.Goal = strings.TrimSpace(text) + answer, _, flowErr := a.driveActiveSession(ctx, storeUserID, userID, lang, text, session, onEvent) + return answer, flowErr + } a.logger.Warn("planner failed, falling back to legacy loop", "error", err, "user_id", userID) - return a.thinkAndActLegacy(ctx, userID, lang, text, onEvent) + return a.thinkAndActLegacyWithStore(ctx, storeUserID, userID, lang, text, onEvent) } a.logPlannerTiming(state.SessionID, userID, "prepare_execution_state", requestStartedAt, nil) @@ -1335,22 +2589,58 @@ func (a *Agent) runPlannedAgent(ctx context.Context, storeUserID string, userID msg := plannerTimeoutMessage(lang) if onEvent != nil { onEvent(StreamEventError, msg) - onEvent(StreamEventDelta, msg) + emitStreamText(onEvent, msg) } return msg, nil } + if answer, ok := a.tryExecutionSummaryFallbackOnAIError(lang, &state, err, onEvent); ok { + return answer, nil + } + if hasExplicitCreateIntentForDomain(state.Goal, "strategy") || hasExplicitCreateIntentForDomain(text, "strategy") { + a.logger.Warn("plan execution failed during strategy create; using template strategy flow instead of legacy loop", "error", err, "user_id", userID) + a.clearExecutionState(userID) + session := newActiveSkillSession(userID, "strategy_management", "create") + session.Goal = defaultIfEmpty(strings.TrimSpace(state.Goal), strings.TrimSpace(text)) + answer, _, flowErr := a.driveActiveSession(ctx, storeUserID, userID, lang, text, session, onEvent) + return answer, flowErr + } a.logger.Warn("plan execution failed, falling back to legacy loop", "error", err, "user_id", userID) - return a.thinkAndActLegacy(ctx, userID, lang, text, onEvent) + return a.thinkAndActLegacyWithStore(ctx, storeUserID, userID, lang, text, onEvent) } + if guarded, blocked := guardUnsupportedAsyncPromise(lang, answer); blocked { + answer = guarded + } a.history.Add(userID, "assistant", answer) - a.maybeUpdateTaskStateIncrementally(ctx, userID) - a.maybeCompressHistory(ctx, userID) + a.runPostResponseMaintenanceAsync(userID) a.logPlannerTiming(state.SessionID, userID, "run_planned_agent_total", requestStartedAt, nil) return answer, nil } -func (a *Agent) prepareExecutionState(ctx context.Context, storeUserID string, userID int64, lang, text string) (ExecutionState, error) { +func (a *Agent) runPostResponseMaintenanceAsync(userID int64) { + if a == nil || a.aiClient == nil || a.history == nil { + return + } + go func() { + defer func() { + if r := recover(); r != nil { + a.log().Warn("post-response maintenance panicked", "user_id", userID, "panic", r) + } + }() + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + // Respect agent shutdown: abort early if stopCh is closed. + select { + case <-a.stopCh: + return + default: + } + a.maybeUpdateTaskStateIncrementally(ctx, userID) + a.maybeCompressHistory(ctx, userID) + }() +} + +func (a *Agent) prepareExecutionState(ctx context.Context, storeUserID string, userID int64, lang, text, contextMode string) (ExecutionState, error) { existing := a.getExecutionState(userID) if shouldResetExecutionStateForNewAttempt(text, existing) { a.clearExecutionState(userID) @@ -1384,6 +2674,16 @@ func (a *Agent) prepareExecutionState(ctx context.Context, storeUserID string, u } state := newExecutionState(userID, text) + mem := a.getReferenceMemory(userID) + switch strings.TrimSpace(contextMode) { + case "fresh_context": + a.SnapshotManager(userID).Clear() + default: + if mem.CurrentReferences != nil { + state.CurrentReferences = mem.CurrentReferences + state.ReferenceHistory = mem.ReferenceHistory + } + } a.refreshCurrentReferencesForUserText(storeUserID, text, &state) state = a.refreshStateForDynamicRequests(storeUserID, text, state) state.Status = executionStatusRunning @@ -1400,12 +2700,11 @@ type nextStepDecision struct { } func (a *Agent) decideNextStep(ctx context.Context, userID int64, lang string, state ExecutionState) (nextStepDecision, error) { - toolDefs, _ := json.Marshal(agentTools()) - stateJSON, _ := json.Marshal(normalizeExecutionState(state)) + toolDefs, _ := json.Marshal(plannerToolsForText(state.Goal)) obsJSON, _ := json.Marshal(buildObservationContext(state)) recentlyFetchedJSON, _ := json.Marshal(buildRecentlyFetchedData(state, time.Now().UTC())) - taskStateCtx := buildTaskStateContext(a.getTaskState(userID)) - recentConversationCtx := a.buildRecentConversationContext(userID, state.Goal) + currentTurnCtx := a.buildCurrentTurnContext(userID, lang, state.Goal) + activeTaskCtx := a.buildActiveTaskStateContext(userID, lang) systemPrompt := `You are the step selector for NOFXi. Return JSON only. Do not return markdown. @@ -1420,9 +2719,10 @@ Allowed step types: - respond Rules: -- Use all available memory layers: Execution state JSON, Observations JSON, Recent conversation, and Task state. +- Use Current turn context and Active task state as the main conversational memory. +- Use Observations JSON as the source of truth for what tools already revealed in this execution. - Use Recently fetched data JSON as the deduplication source of truth for fresh tool results. -- Prefer the freshest evidence in this order: execution state, observations, recent conversation, then task state. +- Prefer the freshest evidence in this order: observations, current turn context, active task state. - If fresh external or system data is needed, choose a tool step. - If the user is blocked on a missing parameter, choose ask_user. - If there is enough information to answer now, choose respond. @@ -1444,7 +2744,7 @@ Rules: Return JSON with this exact shape: {"goal":"","steps":[{"id":"step_1","type":"tool|reason|ask_user|respond","title":"","tool_name":"","tool_args":{},"instruction":"","requires_confirmation":false}]}` - userPrompt := fmt.Sprintf("Language: %s\nGoal: %s\n\nRecent conversation:\n%s\n\nAvailable tools JSON:\n%s\n\nPersistent preferences:\n%s\n\nTask state:\n%s\n\nExecution state JSON:\n%s\n\nObservations JSON:\n%s\n\nRecently fetched data JSON:\n%s", lang, state.Goal, recentConversationCtx, string(toolDefs), a.buildPersistentPreferencesContext(userID), taskStateCtx, string(stateJSON), string(obsJSON), string(recentlyFetchedJSON)) + userPrompt := fmt.Sprintf("Language: %s\nGoal: %s\n\nCurrent turn context:\n%s\n\nActive task state:\n%s\n\nAvailable tools JSON:\n%s\n\nPersistent preferences:\n%s\n\nObservations JSON:\n%s\n\nRecently fetched data JSON:\n%s", lang, state.Goal, defaultIfEmpty(currentTurnCtx, "(empty)"), defaultIfEmpty(activeTaskCtx, "(empty)"), string(toolDefs), a.buildPersistentPreferencesContext(userID), string(obsJSON), string(recentlyFetchedJSON)) stageCtx, cancel := withPlannerStageTimeout(ctx, plannerCreateTimeout) defer cancel() @@ -1543,9 +2843,9 @@ func (a *Agent) refreshStateForDynamicRequests(storeUserID, userText string, sta case "current_strategies": appendSnapshot(kind, a.toolGetStrategies(storeUserID)) case "current_balances": - appendSnapshot(kind, a.toolGetBalance()) + appendSnapshot(kind, a.toolGetBalance(storeUserID)) case "current_positions": - appendSnapshot(kind, a.toolGetPositions()) + appendSnapshot(kind, a.toolGetPositions(storeUserID)) case "recent_trade_history": appendSnapshot(kind, a.toolGetTradeHistory(`{"limit":10}`)) } @@ -1587,17 +2887,18 @@ func (a *Agent) buildRecentConversationContext(userID int64, currentUserText str } func (a *Agent) createExecutionPlan(ctx context.Context, userID int64, lang, userText string, state ExecutionState) (executionPlan, error) { - toolDefs, _ := json.Marshal(agentTools()) - stateJSON, _ := json.Marshal(normalizeExecutionState(state)) - taskStateCtx := buildTaskStateContext(a.getTaskState(userID)) - recentConversationCtx := a.buildRecentConversationContext(userID, userText) + toolDefs, _ := json.Marshal(plannerToolsForText(userText)) + currentTurnCtx := a.buildCurrentTurnContext(userID, lang, userText) + activeTaskCtx := a.buildActiveTaskStateContext(userID, lang) + currentReferenceSummary := buildCurrentReferenceSummary(lang, a.semanticCurrentReferences(userID)) + skillContext := buildManagementSkillRoutingContext(lang) if isConfigOrTraderIntent(userText) { // Configuration and trader setup requests are especially sensitive to stale - // summaries like "this capability does not exist". Prefer fresh tool checks. - taskStateCtx = "" + // durable summaries. Prefer the current turn context plus fresh tool checks. + activeTaskCtx = "" } - systemPrompt := `You are the planning module for NOFXi. + systemPrompt := prependNOFXiAdvisorPreamble(`You are the planning module for NOFXi. Return JSON only. Do not return markdown. Create a minimal safe execution plan using these step types only: @@ -1607,21 +2908,24 @@ Create a minimal safe execution plan using these step types only: - respond Rules: -- Use all available memory layers when planning: Execution state JSON, Recent conversation, and Task state. +- Use a compact memory layout when planning: Current reference summary, Current turn context, and Active task state. - Memory priority order: - 1. Execution state JSON = current operational truth for the active task. - 2. Recent conversation = the best source for what was said in the last few turns. - 3. Task state = compressed durable background only. -- If these memory layers conflict, prefer execution state first, then recent conversation. Do not let task state override fresher evidence. -- Do not ask the user to repeat a fact that is already explicit in execution state or recent conversation unless the inputs are contradictory. + 1. Current reference summary = the currently locked entity/object memory for follow-up turns. + 2. Current turn context = the best source for what was just said, especially the last assistant reply and latest turns. + 3. Active task state = compact unfinished-task memory only. +- If these memory layers conflict, prefer current reference summary first for the target entity, then current turn context, then active task state. +- Do not ask the user to repeat a fact that is already explicit in current reference summary, current turn context, or active task state unless the inputs are contradictory. - Use tool steps whenever fresh external data is required. - Use ask_user if required parameters are missing. +- For config or create flows, prefer multi-slot ask_user prompts: ask for the main missing fields together instead of one field per turn whenever practical. - Never place a trade unless the user intent is explicit. - For exchange binding or exchange credential requests, prefer get_exchange_configs/manage_exchange_config. - For AI model binding or model credential requests, prefer get_model_configs/manage_model_config. -- For strategy template creation or editing requests, prefer get_strategies/manage_strategy. +- For strategy template editing/query requests, prefer get_strategies/manage_strategy. +- For strategy template creation, do not call manage_strategy action=create from the planner. Strategy creation must be handled by the active strategy template flow so the selected product editor template can collect fields and require chat confirmation. - For trader creation or trader lifecycle requests, prefer manage_trader. - A strategy template is independent and does not require exchange/model bindings unless the user explicitly asks to run or deploy it through a trader. +- Do NOT expand the goal beyond what the user explicitly requested. When the user's request is fulfilled, respond and stop. Do not proactively suggest or ask about the next logical step (e.g. do not ask "should I bind this to a trader?" after a strategy update unless the user asked for that). - If these tools exist, never answer that the system lacks exchange/model/trader management capability. - When configuration, strategy, or trader creation is requested, gather missing required fields via ask_user, then call the appropriate tool. - Before concluding that exchange/model/trader/strategy setup is impossible or missing, first inspect current state with the relevant tools. @@ -1634,7 +2938,7 @@ Rules: - For ask_user steps, put the exact follow-up question in instruction. - For respond steps, put either a short instruction or leave instruction empty. - If resuming after a waiting_user state, incorporate the new user reply and return a fresh full plan. -- Never invent tools.` +- Never invent tools.`) resumeContext := "" if state.SessionID != "" { @@ -1647,7 +2951,7 @@ Rules: } } - userPrompt := fmt.Sprintf("Language: %s\nUser request: %s%s\n\nRecent conversation:\n%s\n\nAvailable tools JSON:\n%s\n\nPersistent preferences:\n%s\n\nTask state:\n%s\n\nExecution state JSON:\n%s\n\nReturn JSON with this exact shape:\n{\"goal\":\"\",\"steps\":[{\"id\":\"step_1\",\"type\":\"tool|reason|ask_user|respond\",\"title\":\"\",\"tool_name\":\"\",\"tool_args\":{},\"instruction\":\"\",\"requires_confirmation\":false}]}", lang, userText, resumeContext, recentConversationCtx, string(toolDefs), a.buildPersistentPreferencesContext(userID), taskStateCtx, string(stateJSON)) + userPrompt := fmt.Sprintf("Language: %s\nUser request: %s%s\n\n%s\n\nCurrent reference summary:\n%s\n\nCurrent turn context:\n%s\n\nActive task state:\n%s\n\nAvailable tools JSON:\n%s\n\nPersistent preferences:\n%s\n\nReturn JSON with this exact shape:\n{\"goal\":\"\",\"steps\":[{\"id\":\"step_1\",\"type\":\"tool|reason|ask_user|respond\",\"title\":\"\",\"tool_name\":\"\",\"tool_args\":{},\"instruction\":\"\",\"requires_confirmation\":false}]}", lang, userText, resumeContext, skillContext, currentReferenceSummary, defaultIfEmpty(currentTurnCtx, "(empty)"), defaultIfEmpty(activeTaskCtx, "(empty)"), string(toolDefs), a.buildPersistentPreferencesContext(userID)) stageCtx, cancel := withPlannerStageTimeout(ctx, plannerCreateTimeout) defer cancel() @@ -1729,16 +3033,7 @@ func (a *Agent) executePlan(ctx context.Context, storeUserID string, userID int6 } steps := filterFreshDuplicateToolSteps(decision.Steps, *state, time.Now().UTC()) if len(steps) == 0 { - appendExecutionLog(state, Observation{ - Kind: "decision_note", - Summary: "Skipped duplicate fresh tool calls from next-step decision", - CreatedAt: time.Now().UTC().Format(time.RFC3339), - }) - state.UpdatedAt = time.Now().UTC().Format(time.RFC3339) - if err := a.saveExecutionState(*state); err != nil { - return "", err - } - continue + return "", fmt.Errorf("all next steps are duplicate fresh tool calls") } if hasRepeatedReasonLoop(*state, steps) { return "", fmt.Errorf("repeated reasoning loop detected") @@ -1787,6 +3082,13 @@ func (a *Agent) executePlan(ctx context.Context, storeUserID string, userID int6 switch step.Type { case planStepTypeTool: + if answer, handled := a.redirectPlannerStrategyCreateStep(storeUserID, userID, lang, state.Goal, *step); handled { + a.clearExecutionState(userID) + if onEvent != nil && strings.TrimSpace(answer) != "" { + emitStreamText(onEvent, answer) + } + return answer, nil + } if onEvent != nil { onEvent(StreamEventTool, step.ToolName) } @@ -1805,9 +3107,7 @@ func (a *Agent) executePlan(ctx context.Context, storeUserID string, userID int6 CreatedAt: time.Now().UTC().Format(time.RFC3339), }) referencesChanged = updateCurrentReferencesFromToolResult(state, step.ToolName, result) - if referencesChanged { - a.log().Info("tool step updated references", "tool", step.ToolName, "session", state.SessionID) - } + _ = referencesChanged case planStepTypeReason: reasonStartedAt := time.Now() reasoning, err := a.executeReasonStep(ctx, userID, lang, state.Goal, *state, *step) @@ -1817,9 +3117,7 @@ func (a *Agent) executePlan(ctx context.Context, storeUserID string, userID int6 step.Error = err.Error() state.Status = executionStatusFailed state.LastError = err.Error() - if saveErr := a.saveExecutionState(*state); saveErr != nil { - a.log().Warn("failed to save execution state after reason step error", "error", saveErr) - } + _ = a.saveExecutionState(*state) return "", err } step.Status = planStepStatusCompleted @@ -1850,12 +3148,29 @@ func (a *Agent) executePlan(ctx context.Context, storeUserID string, userID int6 } if onEvent != nil { onEvent(StreamEventStepComplete, formatStepCompleteStatus(*step, lang)) - onEvent(StreamEventDelta, question) + emitStreamText(onEvent, question) } return question, nil case planStepTypeRespond: + if finalText := deterministicCompletedPlanResponse(lang, *state, *step); finalText != "" { + step.Status = planStepStatusCompleted + step.OutputSummary = finalText + state.Status = executionStatusCompleted + state.Waiting = nil + state.FinalAnswer = finalText + state.CurrentStepID = "" + state.UpdatedAt = time.Now().UTC().Format(time.RFC3339) + if err := a.saveExecutionState(*state); err != nil { + return "", err + } + if onEvent != nil { + onEvent(StreamEventStepComplete, formatStepCompleteStatus(*step, lang)) + emitStreamText(onEvent, finalText) + } + return finalText, nil + } respondStartedAt := time.Now() - finalText, err := a.generateFinalPlanResponse(ctx, userID, lang, *state, step.Instruction) + finalText, err := a.generateFinalPlanResponse(ctx, storeUserID, userID, lang, *state, step.Instruction) a.logPlannerTiming(state.SessionID, userID, "respond_step", respondStartedAt, err) if err != nil { return "", err @@ -1872,7 +3187,7 @@ func (a *Agent) executePlan(ctx context.Context, storeUserID string, userID int6 } if onEvent != nil { onEvent(StreamEventStepComplete, formatStepCompleteStatus(*step, lang)) - onEvent(StreamEventDelta, finalText) + emitStreamText(onEvent, finalText) } return finalText, nil default: @@ -1891,6 +3206,48 @@ func (a *Agent) executePlan(ctx context.Context, storeUserID string, userID int6 return "", fmt.Errorf("plan execution exceeded iteration limit") } +func deterministicCompletedPlanResponse(lang string, state ExecutionState, respondStep PlanStep) string { + if !isCompletionOnlyRespondStep(respondStep) { + return "" + } + completed := make([]PlanStep, 0, len(state.Steps)) + for _, step := range state.Steps { + if step.ID == respondStep.ID { + continue + } + if step.Status == planStepStatusCompleted && step.Type == planStepTypeTool { + completed = append(completed, step) + continue + } + if step.Status == planStepStatusCompleted && step.Type == planStepTypeReason { + return "" + } + } + if len(completed) == 0 { + return "" + } + return formatCompletedPlanFallback(lang, completed) +} + +func isCompletionOnlyRespondStep(step PlanStep) bool { + text := strings.ToLower(strings.TrimSpace(step.Title + " " + step.Instruction)) + if text == "" { + return false + } + return containsAny(text, []string{ + "成功", + "完成", + "确认", + "created", + "updated", + "deleted", + "activated", + "duplicated", + "completed", + "confirm", + }) +} + type fetchedToolRecord struct { ToolName string `json:"tool_name"` ToolArgsJSON string `json:"tool_args_json"` @@ -2022,7 +3379,7 @@ func parseRFC3339(value string) time.Time { func (a *Agent) replanAfterStep(ctx context.Context, userID int64, lang string, state ExecutionState, completedStep PlanStep) (replannerDecision, error) { obsJSON, _ := json.Marshal(buildObservationContext(state)) stepsJSON, _ := json.Marshal(state.Steps) - systemPrompt := `You are the replanning module for NOFXi. + systemPrompt := prependNOFXiAdvisorPreamble(`You are the replanning module for NOFXi. Return JSON only. Decide what to do after a plan step completed. @@ -2039,7 +3396,7 @@ Rules: - Use finish when there is enough information to answer and remaining steps are unnecessary. - If action=replace_remaining, return a fresh list of remaining steps only. - Keep plans short and safe. -- Never invent tools.` +- Never invent tools.`) userPrompt := fmt.Sprintf("Language: %s\nGoal: %s\nCompleted step: %s (%s)\nCompleted summary: %s\n\nCurrent steps JSON:\n%s\n\nObservations JSON:\n%s\n\nPersistent preferences:\n%s\n\nTask state:\n%s\n\nReturn JSON with this exact shape:\n{\"action\":\"continue|replace_remaining|ask_user|finish\",\"goal\":\"\",\"instruction\":\"\",\"question\":\"\",\"steps\":[{\"id\":\"step_x\",\"type\":\"tool|reason|ask_user|respond\",\"title\":\"\",\"tool_name\":\"\",\"tool_args\":{},\"instruction\":\"\",\"requires_confirmation\":false}]}", lang, state.Goal, completedStep.ID, completedStep.Type, completedStep.OutputSummary, string(stepsJSON), string(obsJSON), a.buildPersistentPreferencesContext(userID), buildTaskStateContext(a.getTaskState(userID))) @@ -2305,6 +3662,38 @@ func (a *Agent) executePlanTool(ctx context.Context, storeUserID string, userID }) } +func (a *Agent) redirectPlannerStrategyCreateStep(storeUserID string, userID int64, lang, text string, step PlanStep) (string, bool) { + if strings.TrimSpace(step.ToolName) != "manage_strategy" { + return "", false + } + action, _ := step.ToolArgs["action"].(string) + if strings.TrimSpace(action) != "create" { + return "", false + } + session := skillSession{ + Name: "strategy_management", + Action: "create", + Phase: "collecting", + Fields: map[string]string{}, + } + if name, _ := step.ToolArgs["name"].(string); strings.TrimSpace(name) != "" { + setField(&session, "name", name) + } + if rawConfig, ok := step.ToolArgs["config"]; ok { + if strategyType := strategyTypeFromConfigPatchAny(rawConfig); strategyType != "" { + setStrategyCreateType(&session, strategyType) + if sanitized := sanitizeStrategyCreateConfigPatchForType(rawConfig, strategyType); len(sanitized) > 0 { + raw, _ := json.Marshal(sanitized) + setField(&session, strategyCreateConfigPatchField, string(raw)) + } + } + } + if confirmed, ok := step.ToolArgs["confirmed"].(bool); ok && confirmed { + setField(&session, "awaiting_final_confirmation", "true") + } + return a.handleStrategyCreateSkill(storeUserID, userID, lang, text, session), true +} + func (a *Agent) executeReasonStep(ctx context.Context, userID int64, lang, goal string, state ExecutionState, step PlanStep) (string, error) { obsJSON, _ := json.Marshal(buildObservationContext(state)) stageCtx, cancel := withPlannerStageTimeout(ctx, plannerReasonTimeout) @@ -2325,9 +3714,8 @@ func (a *Agent) executeReasonStep(ctx context.Context, userID int64, lang, goal return summarizeObservation(resp), nil } -func (a *Agent) generateFinalPlanResponse(ctx context.Context, userID int64, lang string, state ExecutionState, instruction string) (string, error) { +func (a *Agent) generateFinalPlanResponse(ctx context.Context, storeUserID string, userID int64, lang string, state ExecutionState, instruction string) (string, error) { obsJSON, _ := json.Marshal(buildObservationContext(state)) - systemPrompt := a.buildSystemPrompt(lang) if instruction == "" { instruction = "Provide the best possible final response to the user based on the finished execution." } @@ -2336,8 +3724,9 @@ func (a *Agent) generateFinalPlanResponse(ctx context.Context, userID int64, lan startedAt := time.Now() resp, err := a.aiClient.CallWithRequest(&mcp.Request{ Messages: []mcp.Message{ - mcp.NewSystemMessage(systemPrompt), + mcp.NewSystemMessage(finalPlanResponseSystemPrompt(lang)), mcp.NewSystemMessage("You are responding after a completed execution plan. Use the observations as the source of truth. Be concise and actionable."), + mcp.NewSystemMessage(cleanUserFacingReplyInstruction), mcp.NewUserMessage(fmt.Sprintf("Goal: %s\nResponse instruction: %s\nObservations JSON: %s\nPersistent preferences: %s\nTask state: %s", state.Goal, instruction, string(obsJSON), a.buildPersistentPreferencesContext(userID), buildTaskStateContext(a.getTaskState(userID)))), }, Ctx: stageCtx, @@ -2346,6 +3735,21 @@ func (a *Agent) generateFinalPlanResponse(ctx context.Context, userID int64, lan return resp, err } +func finalPlanResponseSystemPrompt(lang string) string { + if lang == "zh" { + return `你是 NOFXi 的执行结果回复模块。 +只根据 Observations JSON 和已完成步骤回答用户。 +不要引入未观察到的策略、交易员、模型或交易所信息。 +不要承诺稍后通知;如果工具已经执行,直接说结果;如果工具失败,直接说失败原因和下一步。 +用中文,简洁清楚。` + } + return `You are NOFXi's execution-result response module. +Answer only from Observations JSON and completed steps. +Do not introduce unobserved strategy, trader, model, or exchange details. +Do not promise later notification; if a tool executed, state the result; if it failed, state the reason and next step. +Be concise and clear.` +} + func (a *Agent) logPlannerTiming(sessionID string, userID int64, stage string, startedAt time.Time, err error) { if stage == "" || startedAt.IsZero() { return @@ -2379,9 +3783,128 @@ func summarizeObservation(value string) string { return strings.TrimSpace(value[:observationMaxLength]) + "..." } +func isAIServiceFailureError(err error) bool { + if err == nil { + return false + } + lower := strings.ToLower(strings.TrimSpace(err.Error())) + if lower == "" { + return false + } + return strings.Contains(lower, "api returned error") || + strings.Contains(lower, "rate_limit_error") || + strings.Contains(lower, "upstream_empty_output") || + strings.Contains(lower, "insufficient balance") || + strings.Contains(lower, "context deadline exceeded") +} + +func planStepFallbackLabel(step PlanStep) string { + for _, candidate := range []string{ + strings.TrimSpace(step.Title), + strings.TrimSpace(step.Instruction), + strings.TrimSpace(step.ToolName), + } { + if candidate != "" { + return candidate + } + } + return strings.TrimSpace(step.ID) +} + +func formatCompletedPlanFallback(lang string, steps []PlanStep) string { + labels := make([]string, 0, len(steps)) + for _, step := range steps { + if label := planStepFallbackLabel(step); label != "" { + labels = append(labels, label) + } + } + if len(labels) == 0 { + return "" + } + if lang == "zh" { + lines := []string{"已完成:"} + for _, label := range labels { + lines = append(lines, "- "+label) + } + return strings.Join(lines, "\n") + } + lines := []string{"Completed:"} + for _, label := range labels { + lines = append(lines, "- "+label) + } + return strings.Join(lines, "\n") +} + +func (a *Agent) tryExecutionSummaryFallbackOnAIError(lang string, state *ExecutionState, err error, onEvent func(event, data string)) (string, bool) { + if a == nil || state == nil || !isAIServiceFailureError(err) { + return "", false + } + completed := make([]PlanStep, 0, len(state.Steps)) + for _, step := range state.Steps { + if step.Status == planStepStatusCompleted && step.Type == planStepTypeTool { + completed = append(completed, step) + } + } + if len(completed) == 0 { + return "", false + } + answer := formatCompletedPlanFallback(lang, completed) + if answer == "" { + return "", false + } + currentStepID := state.CurrentStepID + state.Status = executionStatusCompleted + state.Waiting = nil + state.FinalAnswer = answer + state.LastError = strings.TrimSpace(err.Error()) + state.UpdatedAt = time.Now().UTC().Format(time.RFC3339) + for i := range state.Steps { + if state.Steps[i].ID == currentStepID || (state.Steps[i].Status == planStepStatusRunning && state.Steps[i].Type == planStepTypeRespond) { + state.Steps[i].Status = planStepStatusCompleted + state.Steps[i].OutputSummary = answer + state.Steps[i].Error = "" + } + } + state.CurrentStepID = "" + appendExecutionLog(state, Observation{ + Kind: "respond_fallback", + Summary: summarizeObservation(answer), + RawJSON: err.Error(), + CreatedAt: time.Now().UTC().Format(time.RFC3339), + }) + _ = a.saveExecutionState(*state) + if onEvent != nil { + emitStreamText(onEvent, answer) + } + return answer, true +} + +func (a *Agent) tryDeterministicFallbackAfterAIServiceFailure(ctx context.Context, userID int64, lang, text string, onEvent func(event, data string)) (string, bool, error) { + storeUserID := storeUserIDFromContext(ctx) + if answer, ok := a.tryHardSkill(ctx, storeUserID, userID, lang, text, onEvent); ok { + return a.maybeAppendResumePrompt(userID, lang, text, answer), true, nil + } + if state := a.getExecutionState(userID); hasActiveExecutionState(state) || len(state.Steps) > 0 { + completed := make([]PlanStep, 0, len(state.Steps)) + for _, step := range state.Steps { + if step.Status == planStepStatusCompleted && step.Type == planStepTypeTool { + completed = append(completed, step) + } + } + if answer := formatCompletedPlanFallback(lang, completed); answer != "" { + return a.maybeAppendResumePrompt(userID, lang, text, answer), true, nil + } + } + return "", false, nil +} + func (a *Agent) thinkAndActLegacy(ctx context.Context, userID int64, lang, text string, onEvent func(event, data string)) (string, error) { - systemPrompt := a.buildSystemPrompt(lang) - enrichment := a.gatherContext(text) + return a.thinkAndActLegacyWithStore(ctx, storeUserIDFromContext(ctx), userID, lang, text, onEvent) +} + +func (a *Agent) thinkAndActLegacyWithStore(ctx context.Context, storeUserID string, userID int64, lang, text string, onEvent func(event, data string)) (string, error) { + systemPrompt := a.buildSystemPromptForStoreUser(lang, storeUserID) + enrichment := a.gatherContext(storeUserID, text) preferencesCtx := a.buildPersistentPreferencesContext(userID) userPrompt := text @@ -2429,20 +3952,60 @@ func (a *Agent) thinkAndActLegacy(ctx context.Context, userID int64, lang, text plainResp, plainErr := a.aiClient.CallWithRequest(&mcp.Request{Messages: messages, Ctx: ctx}) if plainErr != nil { a.logger.Warn("legacy AI plain fallback failed", "error", plainErr, "user_id", userID) + if answer, ok, fallbackErr := a.tryDeterministicFallbackAfterAIServiceFailure(ctx, userID, lang, text, onEvent); ok || fallbackErr != nil { + return answer, fallbackErr + } return a.aiServiceFailure(lang, plainErr) } + if looksLikeInternalAgentJSON(plainResp) { + a.logger.Warn("legacy AI plain fallback returned internal orchestration json; attempting active-flow recovery", "user_id", userID) + if answer, ok, err := a.tryRecoverFromInternalAgentJSON(ctx, storeUserID, userID, lang, text, plainResp, onEvent); ok || err != nil { + return answer, err + } + if answer, ok, fallbackErr := a.tryDeterministicFallbackAfterAIServiceFailure(ctx, userID, lang, text, onEvent); ok || fallbackErr != nil { + return answer, fallbackErr + } + if lang == "zh" { + return "我理解到你还在继续刚才的操作,但这次内部回复格式不对。你再说一次刚才想做的那一步,我继续接着帮你。", nil + } + return "I can tell you're continuing the previous task, but the internal response format was invalid. Please repeat that step and I'll keep going.", nil + } if onEvent != nil { - onEvent(StreamEventDelta, plainResp) + emitStreamText(onEvent, plainResp) } return plainResp, nil } a.logger.Warn("legacy AI tool round failed", "error", err, "user_id", userID, "round", round) + if answer, ok, fallbackErr := a.tryDeterministicFallbackAfterAIServiceFailure(ctx, userID, lang, text, onEvent); ok || fallbackErr != nil { + return answer, fallbackErr + } return a.aiServiceFailure(lang, err) } if len(resp.ToolCalls) == 0 { + if looksLikeInternalAgentJSON(resp.Content) { + a.logger.Warn("legacy AI returned internal orchestration json; attempting active-flow recovery", "user_id", userID) + if answer, ok, err := a.tryRecoverFromInternalAgentJSON(ctx, storeUserID, userID, lang, text, resp.Content, onEvent); ok || err != nil { + return answer, err + } + if answer, ok, fallbackErr := a.tryDeterministicFallbackAfterAIServiceFailure(ctx, userID, lang, text, onEvent); ok || fallbackErr != nil { + return answer, fallbackErr + } + if lang == "zh" { + return "我理解到你还在继续刚才的操作,但这次内部回复格式不对。你再说一次刚才想做的那一步,我继续接着帮你。", nil + } + return "I can tell you're continuing the previous task, but the internal response format was invalid. Please repeat that step and I'll keep going.", nil + } if onEvent != nil { - onEvent(StreamEventDelta, resp.Content) + reply := resp.Content + if guarded, blocked := guardUnsupportedAsyncPromise(lang, reply); blocked { + reply = guarded + } + emitStreamText(onEvent, reply) + return reply, nil + } + if guarded, blocked := guardUnsupportedAsyncPromise(lang, resp.Content); blocked { + return guarded, nil } return resp.Content, nil } @@ -2457,7 +4020,7 @@ func (a *Agent) thinkAndActLegacy(ctx context.Context, userID int64, lang, text if onEvent != nil { onEvent(StreamEventTool, tc.Function.Name) } - result := a.handleToolCall(ctx, storeUserIDFromContext(ctx), userID, lang, tc) + result := a.handleToolCall(ctx, storeUserID, userID, lang, tc) messages = append(messages, mcp.Message{ Role: "tool", Content: result, @@ -2469,10 +4032,33 @@ func (a *Agent) thinkAndActLegacy(ctx context.Context, userID int64, lang, text finalResp, err := a.aiClient.CallWithRequest(&mcp.Request{Messages: messages, Ctx: ctx}) if err != nil { a.logger.Warn("legacy AI final response failed", "error", err, "user_id", userID) + if answer, ok, fallbackErr := a.tryDeterministicFallbackAfterAIServiceFailure(ctx, userID, lang, text, onEvent); ok || fallbackErr != nil { + return answer, fallbackErr + } return a.aiServiceFailure(lang, err) } + if looksLikeInternalAgentJSON(finalResp) { + a.logger.Warn("legacy AI final response returned internal orchestration json; attempting active-flow recovery", "user_id", userID) + if answer, ok, err := a.tryRecoverFromInternalAgentJSON(ctx, storeUserID, userID, lang, text, finalResp, onEvent); ok || err != nil { + return answer, err + } + if answer, ok, fallbackErr := a.tryDeterministicFallbackAfterAIServiceFailure(ctx, userID, lang, text, onEvent); ok || fallbackErr != nil { + return answer, fallbackErr + } + if lang == "zh" { + return "我理解到你还在继续刚才的操作,但这次内部回复格式不对。你再说一次刚才想做的那一步,我继续接着帮你。", nil + } + return "I can tell you're continuing the previous task, but the internal response format was invalid. Please repeat that step and I'll keep going.", nil + } if onEvent != nil { - onEvent(StreamEventDelta, finalResp) + if guarded, blocked := guardUnsupportedAsyncPromise(lang, finalResp); blocked { + finalResp = guarded + } + emitStreamText(onEvent, finalResp) + return finalResp, nil + } + if guarded, blocked := guardUnsupportedAsyncPromise(lang, finalResp); blocked { + return guarded, nil } return finalResp, nil } diff --git a/agent/planner_runtime_state_test.go b/agent/planner_runtime_state_test.go deleted file mode 100644 index ed1b08da..00000000 --- a/agent/planner_runtime_state_test.go +++ /dev/null @@ -1,807 +0,0 @@ -package agent - -import ( - "context" - "encoding/json" - "errors" - "log/slog" - "strings" - "testing" - "time" - - "nofx/mcp" -) - -func TestIsConfigOrTraderIntent(t *testing.T) { - cases := []struct { - text string - want bool - }{ - {text: "帮我创建一个交易员", want: true}, - {text: "我已经配置好了 OKX 和 DeepSeek", want: true}, - {text: "List my traders", want: true}, - {text: "BTC 接下来怎么看", want: false}, - } - for _, tc := range cases { - if got := isConfigOrTraderIntent(tc.text); got != tc.want { - t.Fatalf("isConfigOrTraderIntent(%q) = %v, want %v", tc.text, got, tc.want) - } - } -} - -func TestIsRealtimeAccountIntent(t *testing.T) { - cases := []struct { - text string - want bool - }{ - {text: "现在余额多少", want: true}, - {text: "我的仓位还在吗", want: true}, - {text: "show recent trade history", want: true}, - {text: "帮我创建交易员", want: false}, - } - for _, tc := range cases { - if got := isRealtimeAccountIntent(tc.text); got != tc.want { - t.Fatalf("isRealtimeAccountIntent(%q) = %v, want %v", tc.text, got, tc.want) - } - } -} - -func TestDetectReadFastPath(t *testing.T) { - cases := []struct { - text string - want string - }{ - {text: "/traders", want: "list_traders"}, - {text: "/strategies", want: "get_strategies"}, - {text: "/models", want: "get_model_configs"}, - {text: "/exchanges", want: "get_exchange_configs"}, - {text: "/balance", want: "get_balance"}, - {text: "/positions", want: "get_positions"}, - {text: "/history", want: "get_trade_history"}, - {text: "/trades", want: "get_trade_history"}, - {text: "列出我当前的策略", want: ""}, - {text: "查看当前交易员", want: ""}, - {text: "现在余额多少", want: ""}, - {text: "我的仓位还在吗", want: ""}, - {text: "我现在有哪些账户", want: ""}, - {text: "我的余额", want: ""}, - {text: "根据我的余额帮我分析我应该买什么", want: ""}, - {text: "我的策略是AI100,但是No candidate coins available, cycle skipped", want: ""}, - {text: "帮我创建一个 trader", want: ""}, - } - for _, tc := range cases { - req := detectReadFastPath(tc.text) - got := "" - if req != nil { - got = req.Kind - } - if got != tc.want { - t.Fatalf("detectReadFastPath(%q) = %q, want %q", tc.text, got, tc.want) - } - } -} - -func TestShouldResetExecutionStateForNewAttempt(t *testing.T) { - state := ExecutionState{ - SessionID: "sess_1", - Status: executionStatusWaitingUser, - } - if !shouldResetExecutionStateForNewAttempt("我已经配置好了,继续创建交易员", state) { - t.Fatalf("expected retry-style config request to reset execution state") - } - if shouldResetExecutionStateForNewAttempt("BTC 价格多少", state) { - t.Fatalf("did not expect generic market query to reset execution state") - } -} - -func TestLatestAskedQuestion(t *testing.T) { - state := ExecutionState{ - Status: executionStatusWaitingUser, - Steps: []PlanStep{ - {ID: "step_1", Type: planStepTypeTool, Status: planStepStatusCompleted}, - {ID: "step_2", Type: planStepTypeAskUser, Status: planStepStatusCompleted, Instruction: "需要我用正确的参数重试创建交易员 lky 吗?"}, - }, - } - got := latestAskedQuestion(state) - want := "需要我用正确的参数重试创建交易员 lky 吗?" - if got != want { - t.Fatalf("latestAskedQuestion() = %q, want %q", got, want) - } -} - -func TestLatestAskedQuestionPrefersStructuredWaitingState(t *testing.T) { - state := ExecutionState{ - Status: executionStatusWaitingUser, - Waiting: &WaitingState{ - Question: "请确认是否继续创建交易员 lky", - Intent: "confirm_action", - }, - Steps: []PlanStep{ - {ID: "step_2", Type: planStepTypeAskUser, Status: planStepStatusCompleted, Instruction: "旧问题"}, - }, - } - if got := latestAskedQuestion(state); got != "请确认是否继续创建交易员 lky" { - t.Fatalf("latestAskedQuestion() = %q, want structured waiting question", got) - } -} - -func TestRefreshStateForDynamicRequestsAddsFreshSnapshots(t *testing.T) { - a := newTestAgentWithStore(t) - - _ = a.toolManageModelConfig("user-1", `{ - "action":"create", - "provider":"openai", - "enabled":true, - "custom_api_url":"https://api.openai.com/v1", - "custom_model_name":"gpt-5-mini" - }`) - _ = a.toolManageExchangeConfig("user-1", `{ - "action":"create", - "exchange_type":"okx", - "account_name":"Main", - "enabled":true - }`) - - state := ExecutionState{ - SessionID: "sess_1", - UserID: 1, - DynamicSnapshots: []Observation{ - {Kind: "current_model_configs", Summary: "stale"}, - }, - ExecutionLog: []Observation{{Kind: "user_reply", Summary: "continue"}}, - } - - refreshed := a.refreshStateForDynamicRequests("user-1", "帮我创建交易员", state) - - if len(refreshed.DynamicSnapshots) < 3 { - t.Fatalf("expected refreshed observations to include snapshots, got %+v", refreshed.DynamicSnapshots) - } - - var foundModel, foundExchange, foundTraders bool - for _, obs := range refreshed.DynamicSnapshots { - switch obs.Kind { - case "current_model_configs": - foundModel = strings.Contains(obs.RawJSON, "openai") - case "current_exchange_configs": - foundExchange = strings.Contains(obs.RawJSON, "okx") - case "current_traders": - foundTraders = strings.Contains(obs.RawJSON, `"traders"`) - } - } - - if !foundModel || !foundExchange || !foundTraders { - t.Fatalf("missing fresh snapshots: %+v", refreshed.DynamicSnapshots) - } -} - -func TestRefreshStateForRealtimeAccountRequestsAddsFreshSnapshots(t *testing.T) { - a := newTestAgentWithStore(t) - - state := ExecutionState{ - SessionID: "sess_2", - UserID: 1, - DynamicSnapshots: []Observation{ - {Kind: "current_balances", Summary: "stale balances"}, - {Kind: "current_positions", Summary: "stale positions"}, - }, - ExecutionLog: []Observation{{Kind: "user_reply", Summary: "现在余额多少"}}, - } - - refreshed := a.refreshStateForDynamicRequests("user-1", "现在余额多少,我的仓位还在吗", state) - - var keptBalances, keptPositions, foundHistory bool - for _, obs := range refreshed.DynamicSnapshots { - switch obs.Kind { - case "current_balances": - keptBalances = strings.Contains(obs.Summary, "stale balances") - case "current_positions": - keptPositions = strings.Contains(obs.Summary, "stale positions") - case "recent_trade_history": - foundHistory = obs.RawJSON != "" - } - } - - if !keptBalances || !keptPositions || foundHistory { - t.Fatalf("expected realtime snapshots to stay untouched, got %+v", refreshed.DynamicSnapshots) - } -} - -func TestThinkAndActNaturalLanguageReadCanBeHandledByHighLevelSkill(t *testing.T) { - a := newTestAgentWithStore(t) - _ = a.toolManageStrategy("user-1", `{ - "action":"create", - "name":"激进", - "description":"激进策略模板", - "lang":"zh" - }`) - - resp, err := a.thinkAndAct(context.Background(), "user-1", 1, "zh", "列出我当前的策略") - if err != nil { - t.Fatalf("thinkAndAct() error = %v", err) - } - if !strings.Contains(resp, "当前策略") || !strings.Contains(resp, "激进") { - t.Fatalf("expected natural-language read to be handled by high-level skill, got %q", resp) - } -} - -func TestNormalizeExecutionStateMigratesLegacyObservations(t *testing.T) { - state := normalizeExecutionState(ExecutionState{ - SessionID: "sess_legacy", - UserID: 1, - Observations: []Observation{ - {Kind: "tool_result", Summary: "legacy tool result"}, - }, - }) - - if len(state.Observations) != 0 { - t.Fatalf("expected legacy observations field to be cleared, got %+v", state.Observations) - } - if len(state.ExecutionLog) != 1 || state.ExecutionLog[0].Summary != "legacy tool result" { - t.Fatalf("expected legacy observations to migrate into execution log, got %+v", state.ExecutionLog) - } -} - -func TestBuildWaitingStateForTraderConfirmation(t *testing.T) { - state := ExecutionState{Goal: "创建交易员 lky"} - step := PlanStep{ - ID: "step_ask_1", - Type: planStepTypeAskUser, - Instruction: "需要我用正确的参数重试创建交易员 lky 吗?", - RequiresConfirmation: true, - } - - waiting := buildWaitingState(state, step, step.Instruction) - if waiting == nil { - t.Fatal("expected waiting state") - } - if waiting.Intent != "confirm_action" { - t.Fatalf("unexpected waiting intent: %+v", waiting) - } - if waiting.ConfirmationTarget != "trader" { - t.Fatalf("unexpected confirmation target: %+v", waiting) - } -} - -func TestNormalizeWaitingStateCleansFields(t *testing.T) { - state := normalizeExecutionState(ExecutionState{ - SessionID: "sess_waiting", - UserID: 1, - Waiting: &WaitingState{ - Question: " 请提供 strategy_id ", - Intent: " complete_trader_setup ", - PendingFields: []string{" strategy_id ", "strategy_id"}, - ConfirmationTarget: " trader ", - }, - }) - - if state.Waiting == nil { - t.Fatal("expected normalized waiting state") - } - if state.Waiting.Question != "请提供 strategy_id" { - t.Fatalf("unexpected normalized question: %+v", state.Waiting) - } - if len(state.Waiting.PendingFields) != 1 || state.Waiting.PendingFields[0] != "strategy_id" { - t.Fatalf("unexpected pending fields: %+v", state.Waiting) - } - if state.Waiting.ConfirmationTarget != "trader" { - t.Fatalf("unexpected confirmation target: %+v", state.Waiting) - } -} - -func TestRefreshCurrentReferencesForUserTextMatchesStrategyName(t *testing.T) { - a := newTestAgentWithStore(t) - _ = a.toolManageStrategy("user-1", `{ - "action":"create", - "name":"激进", - "description":"激进策略模板", - "lang":"zh" - }`) - - state := newExecutionState(1, "帮我改一下激进这个策略") - a.refreshCurrentReferencesForUserText("user-1", "帮我改一下激进这个策略", &state) - - if state.CurrentReferences == nil || state.CurrentReferences.Strategy == nil { - t.Fatalf("expected strategy reference, got %+v", state.CurrentReferences) - } - if state.CurrentReferences.Strategy.Name != "激进" { - t.Fatalf("unexpected strategy reference: %+v", state.CurrentReferences.Strategy) - } -} - -func TestUpdateCurrentReferencesFromToolResultTracksCreatedStrategy(t *testing.T) { - state := newExecutionState(1, "创建策略") - changed := updateCurrentReferencesFromToolResult(&state, "manage_strategy", `{ - "status":"ok", - "action":"create", - "strategy":{"id":"strategy_1","name":"激进"} - }`) - - if !changed { - t.Fatalf("expected reference update to report changed") - } - if state.CurrentReferences == nil || state.CurrentReferences.Strategy == nil { - t.Fatalf("expected strategy reference after tool result, got %+v", state.CurrentReferences) - } - if state.CurrentReferences.Strategy.ID != "strategy_1" { - t.Fatalf("unexpected strategy reference: %+v", state.CurrentReferences.Strategy) - } -} - -func TestShouldAttemptReplan(t *testing.T) { - state := ExecutionState{ - Steps: []PlanStep{ - {ID: "step_1", Type: planStepTypeTool, Status: planStepStatusCompleted}, - {ID: "step_2", Type: planStepTypeRespond, Status: planStepStatusPending}, - }, - } - - if !shouldAttemptReplan(state, PlanStep{ - Type: planStepTypeTool, - ToolName: "manage_trader", - ToolArgs: map[string]any{"action": "create"}, - OutputSummary: `{"status":"ok","action":"create"}`, - }, false) { - t.Fatalf("expected create trader step to trigger replan") - } - - if shouldAttemptReplan(state, PlanStep{ - Type: planStepTypeTool, - ToolName: "get_balance", - OutputSummary: `{"balances":[]}`, - }, false) { - t.Fatalf("did not expect read-only balance step to trigger replan") - } - - if !shouldAttemptReplan(state, PlanStep{ - Type: planStepTypeTool, - ToolName: "get_balance", - OutputSummary: `{"error":"ai_model_id is required"}`, - }, false) { - t.Fatalf("expected dependency/error result to trigger replan") - } -} - -type failingAIClient struct{} - -func (f *failingAIClient) SetAPIKey(string, string, string) {} -func (f *failingAIClient) SetTimeout(_ time.Duration) {} -func (f *failingAIClient) CallWithMessages(string, string) (string, error) { - return "", errors.New("unexpected CallWithMessages") -} -func (f *failingAIClient) CallWithRequest(*mcp.Request) (string, error) { - return "", errors.New("API returned error (status 402): insufficient balance") -} -func (f *failingAIClient) CallWithRequestStream(*mcp.Request, func(string)) (string, error) { - return "", errors.New("unexpected CallWithRequestStream") -} -func (f *failingAIClient) CallWithRequestFull(*mcp.Request) (*mcp.LLMResponse, error) { - return nil, errors.New("API returned error (status 402): insufficient balance") -} - -type capturePlannerAIClient struct { - systemPrompt string - userPrompt string -} - -func (c *capturePlannerAIClient) SetAPIKey(string, string, string) {} -func (c *capturePlannerAIClient) SetTimeout(time.Duration) {} -func (c *capturePlannerAIClient) CallWithMessages(string, string) (string, error) { - return "", errors.New("unexpected CallWithMessages") -} -func (c *capturePlannerAIClient) CallWithRequest(req *mcp.Request) (string, error) { - if len(req.Messages) > 0 { - c.systemPrompt = req.Messages[0].Content - } - if len(req.Messages) > 1 { - c.userPrompt = req.Messages[1].Content - } - return `{"goal":"test goal","steps":[{"id":"step_1","type":"respond","instruction":"ok"}]}`, nil -} -func (c *capturePlannerAIClient) CallWithRequestStream(*mcp.Request, func(string)) (string, error) { - return "", errors.New("unexpected CallWithRequestStream") -} -func (c *capturePlannerAIClient) CallWithRequestFull(*mcp.Request) (*mcp.LLMResponse, error) { - return nil, errors.New("unexpected CallWithRequestFull") -} - -type blockingAIClient struct{} - -func (b *blockingAIClient) SetAPIKey(string, string, string) {} -func (b *blockingAIClient) SetTimeout(time.Duration) {} -func (b *blockingAIClient) CallWithMessages(string, string) (string, error) { - return "", errors.New("unexpected CallWithMessages") -} -func (b *blockingAIClient) CallWithRequest(req *mcp.Request) (string, error) { - <-req.Ctx.Done() - return "", req.Ctx.Err() -} -func (b *blockingAIClient) CallWithRequestStream(*mcp.Request, func(string)) (string, error) { - return "", errors.New("unexpected CallWithRequestStream") -} -func (b *blockingAIClient) CallWithRequestFull(*mcp.Request) (*mcp.LLMResponse, error) { - return nil, errors.New("unexpected CallWithRequestFull") -} - -type directReplyAIClient struct { - lastSystemPrompt string - lastUserPrompt string - routerPrompt string - skillRouterPrompt string - plannerPrompt string -} - -func (d *directReplyAIClient) SetAPIKey(string, string, string) {} -func (d *directReplyAIClient) SetTimeout(time.Duration) {} -func (d *directReplyAIClient) CallWithMessages(string, string) (string, error) { - return "", errors.New("unexpected CallWithMessages") -} -func (d *directReplyAIClient) CallWithRequest(req *mcp.Request) (string, error) { - if len(req.Messages) > 0 { - d.lastSystemPrompt = req.Messages[0].Content - } - if len(req.Messages) > 1 { - d.lastUserPrompt = req.Messages[1].Content - } - if strings.Contains(d.lastSystemPrompt, "first-pass router for NOFXi") { - d.routerPrompt = d.lastSystemPrompt - if strings.Contains(d.lastUserPrompt, "你好") { - return `{"action":"direct_answer","answer":"你好,我在。想聊策略、配置还是排障?"}`, nil - } - return `{"action":"defer","answer":""}`, nil - } - if strings.Contains(d.lastSystemPrompt, "lightweight skill router for NOFXi") { - d.skillRouterPrompt = d.lastSystemPrompt - if strings.Contains(d.lastUserPrompt, "运行中的trader") || strings.Contains(d.lastUserPrompt, "有没有 trader 在跑") { - return `{"route":"skill","skill":"trader_management","action":"query","filter":"running_only"}`, nil - } - return `{"route":"planner","skill":"","action":"","filter":""}`, nil - } - if strings.Contains(d.lastSystemPrompt, "planning module for NOFXi") { - d.plannerPrompt = d.lastSystemPrompt - } - return `{"goal":"test goal","steps":[{"id":"step_1","type":"respond","instruction":"ok"}]}`, nil -} -func (d *directReplyAIClient) CallWithRequestStream(*mcp.Request, func(string)) (string, error) { - return "", errors.New("unexpected CallWithRequestStream") -} -func (d *directReplyAIClient) CallWithRequestFull(*mcp.Request) (*mcp.LLMResponse, error) { - return nil, errors.New("unexpected CallWithRequestFull") -} - -func TestThinkAndActLegacyReturnsProviderFailureInsteadOfNoAIFallback(t *testing.T) { - a := &Agent{ - aiClient: &failingAIClient{}, - config: DefaultConfig(), - logger: slog.Default(), - history: newChatHistory(10), - } - - resp, err := a.thinkAndActLegacy(context.Background(), 42, "zh", "你好", nil) - if err != nil { - t.Fatalf("thinkAndActLegacy() error = %v", err) - } - if strings.Contains(resp, "发送 *开始配置* 配置 AI 模型") { - t.Fatalf("expected provider failure message, got fallback: %q", resp) - } - if !strings.Contains(resp, "AI 服务调用失败") { - t.Fatalf("expected provider failure message, got %q", resp) - } -} - -func TestThinkAndActUsesDirectReplyGateForConversationalQuestion(t *testing.T) { - client := &directReplyAIClient{} - a := &Agent{ - aiClient: client, - config: DefaultConfig(), - logger: slog.Default(), - history: newChatHistory(10), - } - - resp, err := a.thinkAndAct(context.Background(), "user-1", 88, "zh", "你好") - if err != nil { - t.Fatalf("thinkAndAct() error = %v", err) - } - if !strings.Contains(resp, "你好,我在") { - t.Fatalf("expected direct reply response, got %q", resp) - } - if !strings.Contains(client.routerPrompt, "first-pass router for NOFXi") { - t.Fatalf("expected direct reply router prompt, got %q", client.routerPrompt) - } -} - -func TestThinkAndActDefersFromDirectReplyGateToHardSkill(t *testing.T) { - a := newTestAgentWithStore(t) - a.aiClient = &directReplyAIClient{} - - resp, err := a.thinkAndAct(context.Background(), "user-1", 89, "zh", "帮我创建一个 DeepSeek 模型配置") - if err != nil { - t.Fatalf("thinkAndAct() error = %v", err) - } - if !strings.Contains(resp, "已创建模型配置") { - t.Fatalf("expected direct reply gate to defer to hard skill, got %q", resp) - } -} - -func TestThinkAndActUsesLLMSkillRouterForNaturalLanguageTraderQuery(t *testing.T) { - client := &directReplyAIClient{} - a := newTestAgentWithStore(t) - a.aiClient = client - a.history = newChatHistory(10) - - modelResp := a.toolManageModelConfig("user-1", `{ - "action":"create", - "provider":"openai", - "enabled":true, - "custom_api_url":"https://api.openai.com/v1", - "custom_model_name":"gpt-5-mini" - }`) - var modelCreated struct { - Model safeModelToolConfig `json:"model"` - } - if err := json.Unmarshal([]byte(modelResp), &modelCreated); err != nil { - t.Fatalf("unmarshal model response: %v", err) - } - - exchangeResp := a.toolManageExchangeConfig("user-1", `{ - "action":"create", - "exchange_type":"binance", - "account_name":"Main", - "enabled":true - }`) - var exchangeCreated struct { - Exchange safeExchangeToolConfig `json:"exchange"` - } - if err := json.Unmarshal([]byte(exchangeResp), &exchangeCreated); err != nil { - t.Fatalf("unmarshal exchange response: %v", err) - } - - createResp := a.toolManageTrader("user-1", `{ - "action":"create", - "name":"Momentum Trader", - "ai_model_id":"`+modelCreated.Model.ID+`", - "exchange_id":"`+exchangeCreated.Exchange.ID+`", - "scan_interval_minutes":5 - }`) - var created struct { - Trader safeTraderToolConfig `json:"trader"` - } - if err := json.Unmarshal([]byte(createResp), &created); err != nil { - t.Fatalf("unmarshal create trader response: %v\nraw=%s", err, createResp) - } - if err := a.store.Trader().UpdateStatus("user-1", created.Trader.ID, true); err != nil { - t.Fatalf("update trader status: %v", err) - } - - resp, err := a.thinkAndAct(context.Background(), "user-1", 90, "zh", "当前有运行中的trader吗") - if err != nil { - t.Fatalf("thinkAndAct() error = %v", err) - } - if !strings.Contains(resp, "运行中的交易员") || !strings.Contains(resp, "Momentum Trader") { - t.Fatalf("expected routed running-trader answer, got %q", resp) - } - if client.skillRouterPrompt == "" { - t.Fatal("expected lightweight skill router prompt to be used") - } - if client.plannerPrompt != "" { - t.Fatalf("expected planner to be skipped, got prompt %q", client.plannerPrompt) - } -} - -func TestThinkAndActPrioritizesActiveExecutionStateOverDirectReply(t *testing.T) { - client := &directReplyAIClient{} - a := newTestAgentWithStore(t) - a.aiClient = client - a.history = newChatHistory(10) - a.logger = slog.Default() - - userID := int64(90) - state := newExecutionState(userID, "继续完成当前任务") - state.Status = executionStatusWaitingUser - state.Waiting = &WaitingState{ - Question: "请确认是否继续", - Intent: "confirm_action", - } - if err := a.saveExecutionState(state); err != nil { - t.Fatalf("saveExecutionState() error = %v", err) - } - - resp, err := a.thinkAndAct(context.Background(), "user-1", userID, "zh", "你好") - if err != nil { - t.Fatalf("thinkAndAct() error = %v", err) - } - if strings.Contains(resp, "你好,我在") { - t.Fatalf("expected active execution state to bypass direct reply gate, got %q", resp) - } - if !strings.Contains(client.plannerPrompt, "planning module for NOFXi") { - t.Fatalf("expected planner prompt when execution state is active, got %q", client.plannerPrompt) - } -} - -func TestThinkAndActInterruptsWaitingExecutionStateForNewTopic(t *testing.T) { - a := newTestAgentWithStore(t) - a.history = newChatHistory(10) - - _ = a.toolManageStrategy("user-1", `{ - "action":"create", - "name":"激进", - "lang":"zh" - }`) - - userID := int64(91) - state := newExecutionState(userID, "创建交易员") - state.Status = executionStatusWaitingUser - state.Waiting = &WaitingState{ - Question: "请告诉我交易员名称", - PendingFields: []string{"name"}, - } - if err := a.saveExecutionState(state); err != nil { - t.Fatalf("saveExecutionState() error = %v", err) - } - - resp, err := a.thinkAndAct(context.Background(), "user-1", userID, "zh", "列出我当前的策略") - if err != nil { - t.Fatalf("thinkAndAct() error = %v", err) - } - if !strings.Contains(resp, "当前策略") || !strings.Contains(resp, "激进") { - t.Fatalf("expected new topic to be handled, got %q", resp) - } - if got := a.getExecutionState(userID); got.SessionID != "" { - t.Fatalf("expected execution state to be cleared, got %+v", got) - } -} - -func TestCreateExecutionPlanIncludesRecentConversation(t *testing.T) { - client := &capturePlannerAIClient{} - a := &Agent{ - aiClient: client, - config: DefaultConfig(), - logger: slog.Default(), - history: newChatHistory(10), - } - - userID := int64(42) - a.history.Add(userID, "user", "先帮我看一下当前trader") - a.history.Add(userID, "assistant", "当前只有测试1这个trader。") - a.history.Add(userID, "user", "好的,那就按当前trader来") - - _, err := a.createExecutionPlan(context.Background(), userID, "zh", "好的,那就按当前trader来", newExecutionState(userID, "好的,那就按当前trader来")) - if err != nil { - t.Fatalf("createExecutionPlan() error = %v", err) - } - if !strings.Contains(client.userPrompt, "Recent conversation:") { - t.Fatalf("expected planner prompt to include recent conversation, got %q", client.userPrompt) - } - if !strings.Contains(client.userPrompt, "先帮我看一下当前trader") { - t.Fatalf("expected previous user turn in recent conversation, got %q", client.userPrompt) - } - if !strings.Contains(client.userPrompt, "当前只有测试1这个trader") { - t.Fatalf("expected previous assistant turn in recent conversation, got %q", client.userPrompt) - } - recentIdx := strings.Index(client.userPrompt, "Recent conversation:\n") - toolsIdx := strings.Index(client.userPrompt, "\n\nAvailable tools JSON:") - if recentIdx == -1 || toolsIdx == -1 || toolsIdx <= recentIdx { - t.Fatalf("expected recent conversation block boundaries, got %q", client.userPrompt) - } - recentBlock := client.userPrompt[recentIdx:toolsIdx] - if strings.Contains(recentBlock, "好的,那就按当前trader来") { - t.Fatalf("expected current user text to stay out of recent conversation block, got %q", recentBlock) - } - if !strings.Contains(client.systemPrompt, "Memory priority order:") { - t.Fatalf("expected planner system prompt to include memory priority guidance, got %q", client.systemPrompt) - } - if !strings.Contains(client.systemPrompt, "Execution state JSON = current operational truth") { - t.Fatalf("expected planner system prompt to prioritize execution state, got %q", client.systemPrompt) - } - if !strings.Contains(client.systemPrompt, "Do not ask the user to repeat a fact") { - t.Fatalf("expected planner system prompt to forbid unnecessary repeated questions, got %q", client.systemPrompt) - } -} - -func TestCreateExecutionPlanIncludesRecentConversationForFreshRequest(t *testing.T) { - client := &capturePlannerAIClient{} - a := &Agent{ - aiClient: client, - config: DefaultConfig(), - logger: slog.Default(), - history: newChatHistory(10), - } - - userID := int64(99) - a.history.Add(userID, "user", "先帮我看一下当前trader") - a.history.Add(userID, "assistant", "当前只有测试1这个trader。") - - _, err := a.createExecutionPlan(context.Background(), userID, "zh", "帮我分析一下比特币", ExecutionState{}) - if err != nil { - t.Fatalf("createExecutionPlan() error = %v", err) - } - if !strings.Contains(client.userPrompt, "Recent conversation:") { - t.Fatalf("expected fresh request to still include recent conversation block, got %q", client.userPrompt) - } - if !strings.Contains(client.userPrompt, "先帮我看一下当前trader") { - t.Fatalf("expected previous user turn in recent conversation, got %q", client.userPrompt) - } - if !strings.Contains(client.userPrompt, "当前只有测试1这个trader") { - t.Fatalf("expected previous assistant turn in recent conversation, got %q", client.userPrompt) - } -} - -func TestCreateExecutionPlanIncludesQuotedEarlierAssistantClaim(t *testing.T) { - client := &capturePlannerAIClient{} - a := &Agent{ - aiClient: client, - config: DefaultConfig(), - logger: slog.Default(), - history: newChatHistory(10), - } - - userID := int64(100) - a.history.Add(userID, "user", "配置页怎么只有三个交易所") - a.history.Add(userID, "assistant", "目前你看到的是三个交易所。") - - _, err := a.createExecutionPlan(context.Background(), userID, "zh", "你前面也跟我说只有三个交易所", ExecutionState{}) - if err != nil { - t.Fatalf("createExecutionPlan() error = %v", err) - } - if !strings.Contains(client.userPrompt, "目前你看到的是三个交易所") { - t.Fatalf("expected planner prompt to include earlier assistant claim, got %q", client.userPrompt) - } - if !strings.Contains(client.userPrompt, "配置页怎么只有三个交易所") { - t.Fatalf("expected planner prompt to include earlier user complaint, got %q", client.userPrompt) - } -} - -func TestRunPlannedAgentReturnsTimeoutMessageOnPlannerTimeout(t *testing.T) { - oldTimeout := plannerCreateTimeout - plannerCreateTimeout = 10 * time.Millisecond - defer func() { plannerCreateTimeout = oldTimeout }() - - a := &Agent{ - aiClient: &blockingAIClient{}, - config: DefaultConfig(), - logger: slog.Default(), - history: newChatHistory(10), - } - - resp, err := a.runPlannedAgent(context.Background(), "default", 7, "zh", "帮我分析一下当前市场", nil) - if err != nil { - t.Fatalf("runPlannedAgent() error = %v", err) - } - if !strings.Contains(resp, "处理超时") { - t.Fatalf("expected timeout message, got %q", resp) - } -} - -func TestHandleMessageForStoreUserBypassesPlannerForTradeConfirmation(t *testing.T) { - a := &Agent{ - config: DefaultConfig(), - logger: slog.Default(), - history: newChatHistory(10), - pending: newPendingTrades(), - } - - resp, err := a.handleMessageForStoreUser(context.Background(), "default", 1, "确认 trade_missing") - if err != nil { - t.Fatalf("handleMessageForStoreUser() error = %v", err) - } - if !strings.Contains(resp, "交易已过期或不存在") { - t.Fatalf("expected direct trade confirmation handling, got %q", resp) - } -} - -func TestResolveModelRuntimeConfigUsesProviderDefaults(t *testing.T) { - url, model := resolveModelRuntimeConfig("deepseek", "", "", "user_deepseek") - if url != "https://api.deepseek.com/v1" { - t.Fatalf("unexpected deepseek default url: %q", url) - } - if model != "deepseek-chat" { - t.Fatalf("unexpected deepseek default model: %q", model) - } - - url, model = resolveModelRuntimeConfig("deepseek", "", "deepseek1", "user_deepseek") - if url != "https://api.deepseek.com/v1" { - t.Fatalf("unexpected resolved url: %q", url) - } - if model != "deepseek1" { - t.Fatalf("expected existing custom model name to win, got %q", model) - } -} diff --git a/agent/planner_tools_test.go b/agent/planner_tools_test.go new file mode 100644 index 00000000..955efee2 --- /dev/null +++ b/agent/planner_tools_test.go @@ -0,0 +1,84 @@ +package agent + +import ( + "encoding/json" + "testing" + + "nofx/mcp" +) + +func TestPlannerToolsForMarketIntentAreTrimmed(t *testing.T) { + tools := plannerToolsForText("看一下 BTCUSDT 行情和 K线") + names := toolNamesForTest(tools) + + for _, expected := range []string{"get_market_snapshot", "get_market_price", "get_kline"} { + if !containsString(names, expected) { + t.Fatalf("expected market tool %q in %v", expected, names) + } + } + for _, unexpected := range []string{"manage_strategy", "manage_trader", "manage_exchange_config", "manage_model_config"} { + if containsString(names, unexpected) { + t.Fatalf("did not expect management tool %q in market tools %v", unexpected, names) + } + } +} + +func TestPlannerToolsForExchangeIntentAreTrimmed(t *testing.T) { + tools := plannerToolsForText("帮我添加 okx 交易所 API key") + names := toolNamesForTest(tools) + + if len(names) != 2 { + t.Fatalf("expected two exchange tools, got %v", names) + } + for _, expected := range []string{"get_exchange_configs", "manage_exchange_config"} { + if !containsString(names, expected) { + t.Fatalf("expected exchange tool %q in %v", expected, names) + } + } +} + +func TestPlannerToolsUseCompactManageStrategyForReadIntent(t *testing.T) { + tools := plannerToolsForText("列出我的策略") + tool := findToolForTest(tools, "manage_strategy") + if tool == nil { + t.Fatalf("expected manage_strategy in strategy tools") + } + + raw, _ := json.Marshal(tool.Function.Parameters) + if len(raw) > 900 { + t.Fatalf("expected compact strategy schema, got %d bytes", len(raw)) + } + if string(raw) == "" || !json.Valid(raw) { + t.Fatalf("expected valid strategy schema JSON") + } +} + +func TestPlannerToolsKeepFullManageStrategyForMutationIntent(t *testing.T) { + tools := plannerToolsForText("创建一个 BTC 网格策略") + tool := findToolForTest(tools, "manage_strategy") + if tool == nil { + t.Fatalf("expected manage_strategy in strategy tools") + } + + raw, _ := json.Marshal(tool.Function.Parameters) + if len(raw) < 1500 { + t.Fatalf("expected full strategy schema for mutation intent, got %d bytes", len(raw)) + } +} + +func toolNamesForTest(tools []mcp.Tool) []string { + names := make([]string, 0, len(tools)) + for _, tool := range tools { + names = append(names, tool.Function.Name) + } + return names +} + +func findToolForTest(tools []mcp.Tool, name string) *mcp.Tool { + for i := range tools { + if tools[i].Function.Name == name { + return &tools[i] + } + } + return nil +} diff --git a/agent/preferences.go b/agent/preferences.go index af43c9e8..5b834ecf 100644 --- a/agent/preferences.go +++ b/agent/preferences.go @@ -8,6 +8,8 @@ import ( "time" ) +const maxPersistentPreferenceLength = 500 + // PersistentPreference is a durable user instruction shown in the UI and // injected into the agent context for future conversations. type PersistentPreference struct { @@ -21,6 +23,9 @@ func NewPersistentPreference(text string) (PersistentPreference, error) { if text == "" { return PersistentPreference{}, fmt.Errorf("text required") } + if len([]rune(text)) > maxPersistentPreferenceLength { + return PersistentPreference{}, fmt.Errorf("text too long (max %d characters)", maxPersistentPreferenceLength) + } now := time.Now().UTC() return PersistentPreference{ diff --git a/agent/preferences_test.go b/agent/preferences_test.go deleted file mode 100644 index 5c45e2c5..00000000 --- a/agent/preferences_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package agent - -import ( - "strings" - "testing" -) - -func TestNewPersistentPreference(t *testing.T) { - pref, err := NewPersistentPreference(" Always answer in Chinese. ") - if err != nil { - t.Fatalf("expected preference to be created, got error: %v", err) - } - if pref.ID == "" { - t.Fatal("expected non-empty preference id") - } - if pref.Text != "Always answer in Chinese." { - t.Fatalf("expected trimmed text, got %q", pref.Text) - } - if pref.CreatedAt == "" { - t.Fatal("expected created_at to be set") - } - if strings.Contains(pref.ID, "Always") { - t.Fatalf("expected generated id, got %q", pref.ID) - } -} - -func TestNewPersistentPreferenceRejectsEmptyText(t *testing.T) { - if _, err := NewPersistentPreference(" "); err == nil { - t.Fatal("expected empty text to be rejected") - } -} diff --git a/agent/prompt_context.go b/agent/prompt_context.go new file mode 100644 index 00000000..23f7af52 --- /dev/null +++ b/agent/prompt_context.go @@ -0,0 +1,74 @@ +package agent + +import ( + "fmt" + "strings" +) + +func (a *Agent) buildCurrentTurnContext(userID int64, lang, currentUserText string) string { + var parts []string + previousAssistantReply := strings.TrimSpace(a.currentPendingHintText(userID)) + if previousAssistantReply != "" { + parts = append(parts, "Previous assistant reply:\n"+previousAssistantReply) + } + recentConversation := strings.TrimSpace(a.buildRecentConversationContext(userID, currentUserText)) + if recentConversation != "" { + parts = append(parts, "Recent conversation:\n"+recentConversation) + } + currentRefs := strings.TrimSpace(buildCurrentReferenceSummary(lang, a.semanticCurrentReferences(userID))) + if currentRefs != "" { + parts = append(parts, "Current references:\n"+currentRefs) + } + return strings.Join(parts, "\n\n") +} + +func (a *Agent) buildActiveTaskStateContext(userID int64, lang string) string { + activeSkill := a.getSkillSession(userID) + activeTask, hasActiveTask := a.getActiveSkillSession(userID) + activeWorkflow := a.getWorkflowSession(userID) + activeExec := normalizeExecutionState(a.getExecutionState(userID)) + pendingProposal, hasPendingProposal := a.getPendingProposalSession(userID) + + lines := []string{} + if hasActiveTask || strings.TrimSpace(activeSkill.Name) != "" || hasActiveWorkflowSession(activeWorkflow) || hasActiveExecutionState(activeExec) || hasPendingProposal { + summary := strings.TrimSpace(buildTopLevelActiveFlowSummary(lang, activeSkill, activeTask, hasActiveTask, activeWorkflow, activeExec, pendingProposal, hasPendingProposal)) + if summary != "" { + lines = append(lines, summary) + } + } + + taskState := normalizeTaskState(a.getTaskState(userID)) + if taskState.CurrentGoal != "" { + lines = append(lines, "Durable goal: "+taskState.CurrentGoal) + } + if taskState.ActiveFlow != "" { + lines = append(lines, "Durable active flow: "+taskState.ActiveFlow) + } + if len(taskState.OpenLoops) > 0 { + limit := len(taskState.OpenLoops) + if limit > 3 { + limit = 3 + } + for _, loop := range taskState.OpenLoops[:limit] { + lines = append(lines, "Open loop: "+loop) + } + } + + if hasActiveExecutionState(activeExec) { + lines = append(lines, fmt.Sprintf("Execution status: %s", activeExec.Status)) + if strings.TrimSpace(activeExec.Goal) != "" { + lines = append(lines, "Execution goal: "+strings.TrimSpace(activeExec.Goal)) + } + if activeExec.Waiting != nil && strings.TrimSpace(activeExec.Waiting.Question) != "" { + lines = append(lines, "Waiting question: "+strings.TrimSpace(activeExec.Waiting.Question)) + } + if strings.TrimSpace(activeExec.CurrentStepID) != "" { + lines = append(lines, "Current step id: "+strings.TrimSpace(activeExec.CurrentStepID)) + } + } + + if len(lines) == 0 { + return "" + } + return strings.Join(lines, "\n") +} diff --git a/agent/prompt_persona.go b/agent/prompt_persona.go new file mode 100644 index 00000000..9ae3337a --- /dev/null +++ b/agent/prompt_persona.go @@ -0,0 +1,25 @@ +package agent + +import "strings" + +const nofxiAdvisorSystemPreamble = `You are NOFXi, the core intelligence hub of the NOFX platform. +You understand NOFX's underlying logic, feature boundaries, and quantitative operating model. +Your first duty is not blind execution. You act as the user's senior quantitative advisor so every NOFX configuration is correct, safe, and logically consistent. +When the user runs into a problem, combine the current state with NOFX platform constraints, proactively diagnose what is wrong, and provide concrete next steps. + +User-facing response style rules: +- Treat the user like a trading beginner, not a developer. +- Prefer simple, plain language over technical jargon. +- Lead with the conclusion first, then one or two concrete next steps. +- Keep sentences short and easy to scan. +- If you must use a technical term, explain it in everyday words immediately. +- Do not expose internal architecture, tool names, JSON fields, or implementation details unless the user explicitly asks for them. +- When asking follow-up questions, make them specific, friendly, and easy to answer.` + +func prependNOFXiAdvisorPreamble(body string) string { + body = strings.TrimSpace(body) + if body == "" { + return nofxiAdvisorSystemPreamble + } + return nofxiAdvisorSystemPreamble + "\n\n" + body +} diff --git a/agent/reference_memory.go b/agent/reference_memory.go new file mode 100644 index 00000000..e07f037f --- /dev/null +++ b/agent/reference_memory.go @@ -0,0 +1,101 @@ +package agent + +import ( + "encoding/json" + "fmt" + "strings" + "time" +) + +type ReferenceMemory struct { + CurrentReferences *CurrentReferences `json:"current_references,omitempty"` + ReferenceHistory []ReferenceRecord `json:"reference_history,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` +} + +func referenceMemoryConfigKey(userID int64) string { + return fmt.Sprintf("agent_reference_memory_%d", userID) +} + +func (a *Agent) getReferenceMemory(userID int64) ReferenceMemory { + if a == nil || a.store == nil { + return ReferenceMemory{} + } + raw, err := a.store.GetSystemConfig(referenceMemoryConfigKey(userID)) + if err != nil { + return ReferenceMemory{} + } + raw = strings.TrimSpace(raw) + if raw == "" { + return ReferenceMemory{} + } + var memory ReferenceMemory + if err := json.Unmarshal([]byte(raw), &memory); err != nil { + return ReferenceMemory{} + } + memory.CurrentReferences = normalizeCurrentReferences(memory.CurrentReferences) + memory.ReferenceHistory = normalizeReferenceHistory(memory.ReferenceHistory) + return memory +} + +func (a *Agent) saveReferenceMemory(userID int64, refs *CurrentReferences, history []ReferenceRecord) { + if a == nil || a.store == nil { + return + } + memory := ReferenceMemory{ + CurrentReferences: normalizeCurrentReferences(refs), + ReferenceHistory: normalizeReferenceHistory(history), + UpdatedAt: time.Now().UTC().Format(time.RFC3339), + } + if memory.CurrentReferences == nil && len(memory.ReferenceHistory) == 0 { + _ = a.store.SetSystemConfig(referenceMemoryConfigKey(userID), "") + return + } + data, err := json.Marshal(memory) + if err != nil { + return + } + _ = a.store.SetSystemConfig(referenceMemoryConfigKey(userID), string(data)) +} + +func (a *Agent) clearReferenceMemory(userID int64) { + if a == nil || a.store == nil { + return + } + _ = a.store.SetSystemConfig(referenceMemoryConfigKey(userID), "") +} + +func (a *Agent) semanticCurrentReferences(userID int64) *CurrentReferences { + state := a.getExecutionState(userID) + if refs := normalizeCurrentReferences(state.CurrentReferences); refs != nil { + return refs + } + return a.getReferenceMemory(userID).CurrentReferences +} + +func (a *Agent) semanticReferenceHistory(userID int64) []ReferenceRecord { + state := a.getExecutionState(userID) + if history := normalizeReferenceHistory(state.ReferenceHistory); len(history) > 0 { + return history + } + return a.getReferenceMemory(userID).ReferenceHistory +} + +func (a *Agent) rememberReferencesFromToolResult(userID int64, toolName, raw string) { + if a == nil { + return + } + memory := a.getReferenceMemory(userID) + state := ExecutionState{ + UserID: userID, + CurrentReferences: memory.CurrentReferences, + ReferenceHistory: memory.ReferenceHistory, + } + if !updateCurrentReferencesFromToolResult(&state, toolName, raw) { + return + } + a.saveReferenceMemory(userID, state.CurrentReferences, state.ReferenceHistory) + execState := a.getExecutionState(userID) + execState.CurrentReferences = state.CurrentReferences + a.saveExecutionState(execState) +} diff --git a/agent/scheduler.go b/agent/scheduler.go index a96da16b..b2a8da2b 100644 --- a/agent/scheduler.go +++ b/agent/scheduler.go @@ -29,8 +29,10 @@ func (s *Scheduler) Start(ctx context.Context) { lastCheck := time.Time{} for { select { - case <-ctx.Done(): return - case <-s.stopCh: return + case <-ctx.Done(): + return + case <-s.stopCh: + return case now := <-ticker.C: // Daily report at 21:00 if now.Hour() == 21 && now.Sub(lastReport) > 12*time.Hour { @@ -53,13 +55,21 @@ func (s *Scheduler) Start(ctx context.Context) { }) } -func (s *Scheduler) Stop() { s.stopOnce.Do(func() { close(s.stopCh) }) } +func (s *Scheduler) Stop() { + s.stopOnce.Do(func() { + close(s.stopCh) + }) +} func (s *Scheduler) dailyReport() { - if s.agent.traderManager == nil { return } + if s.agent.traderManager == nil { + return + } traders := s.agent.traderManager.GetAllTraders() - if len(traders) == 0 { return } + if len(traders) == 0 { + return + } var sb strings.Builder sb.WriteString(fmt.Sprintf("📊 *NOFXi 每日报告 — %s*\n\n", time.Now().Format("2006-01-02"))) @@ -67,30 +77,40 @@ func (s *Scheduler) dailyReport() { totalPnL := 0.0 for _, t := range traders { info, err := t.GetAccountInfo() - if err != nil { continue } + if err != nil { + continue + } equity := toFloat(info["total_equity"]) pnl := toFloat(info["unrealized_pnl"]) sb.WriteString(fmt.Sprintf("• %s: $%.2f (P/L: $%.2f)\n", t.GetName(), equity, pnl)) totalPnL += pnl } e := "📈" - if totalPnL < 0 { e = "📉" } + if totalPnL < 0 { + e = "📉" + } sb.WriteString(fmt.Sprintf("\n%s Total P/L: $%.2f", e, totalPnL)) s.agent.notifyAll(sb.String()) } func (s *Scheduler) riskCheck() { - if s.agent.traderManager == nil { return } + if s.agent.traderManager == nil { + return + } var alerts []string for _, t := range s.agent.traderManager.GetAllTraders() { positions, err := t.GetPositions() - if err != nil { continue } + if err != nil { + continue + } for _, p := range positions { pnl := toFloat(p["unrealizedPnl"]) size := toFloat(p["size"]) - if size == 0 { continue } + if size == 0 { + continue + } entry := toFloat(p["entryPrice"]) if entry > 0 { pnlPct := (pnl / (entry * size)) * 100 diff --git a/agent/sentinel.go b/agent/sentinel.go index a8c4a3ed..91d76385 100644 --- a/agent/sentinel.go +++ b/agent/sentinel.go @@ -77,20 +77,51 @@ func (s *Sentinel) Start() { }) } -func (s *Sentinel) Stop() { s.stopOnce.Do(func() { close(s.stopCh) }) } -func (s *Sentinel) SymbolCount() int { s.mu.RLock(); defer s.mu.RUnlock(); return len(s.symbols) } -func (s *Sentinel) AddSymbol(sym string) { s.mu.Lock(); defer s.mu.Unlock(); for _, x := range s.symbols { if x == sym { return } }; s.symbols = append(s.symbols, sym) } -func (s *Sentinel) RemoveSymbol(sym string) { s.mu.Lock(); defer s.mu.Unlock(); for i, x := range s.symbols { if x == sym { s.symbols = append(s.symbols[:i], s.symbols[i+1:]...); return } } } +func (s *Sentinel) Stop() { s.stopOnce.Do(func() { close(s.stopCh) }) } +func (s *Sentinel) SymbolCount() int { s.mu.RLock(); defer s.mu.RUnlock(); return len(s.symbols) } +func (s *Sentinel) Symbols() []string { + s.mu.RLock() + defer s.mu.RUnlock() + out := make([]string, len(s.symbols)) + copy(out, s.symbols) + return out +} +func (s *Sentinel) AddSymbol(sym string) { + s.mu.Lock() + defer s.mu.Unlock() + for _, x := range s.symbols { + if x == sym { + return + } + } + s.symbols = append(s.symbols, sym) +} +func (s *Sentinel) RemoveSymbol(sym string) { + s.mu.Lock() + defer s.mu.Unlock() + for i, x := range s.symbols { + if x == sym { + s.symbols = append(s.symbols[:i], s.symbols[i+1:]...) + return + } + } +} func (s *Sentinel) FormatWatchlist(L string) string { s.mu.RLock() defer s.mu.RUnlock() if len(s.symbols) == 0 { - if L == "zh" { return "📭 监控列表为空。用 `/watch BTC` 添加。" } + if L == "zh" { + return "📭 监控列表为空。用 `/watch BTC` 添加。" + } return "📭 Watchlist empty. Use `/watch BTC` to add." } var sb strings.Builder - if L == "zh" { sb.WriteString("👁️ *监控列表*\n\n") } else { sb.WriteString("👁️ *Watchlist*\n\n") } + if L == "zh" { + sb.WriteString("👁️ *监控列表*\n\n") + } else { + sb.WriteString("👁️ *Watchlist*\n\n") + } for _, sym := range s.symbols { if pts, ok := s.history[sym]; ok && len(pts) > 0 { last := pts[len(pts)-1] @@ -114,16 +145,22 @@ func (s *Sentinel) scan() { func (s *Sentinel) check(symbol string) { resp, err := s.http.Get(fmt.Sprintf("https://fapi.binance.com/fapi/v1/ticker/24hr?symbol=%s", symbol)) - if err != nil { return } + if err != nil { + return + } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { s.logger.Debug("sentinel ticker non-200", "symbol", symbol, "status", resp.StatusCode) return } body, err := safe.ReadAllLimited(resp.Body, 256*1024) // 256KB limit - if err != nil { return } + if err != nil { + return + } var t map[string]interface{} - if err := json.Unmarshal(body, &t); err != nil { return } + if err := json.Unmarshal(body, &t); err != nil { + return + } price, _ := strconv.ParseFloat(fmt.Sprint(t["lastPrice"]), 64) vol, _ := strconv.ParseFloat(fmt.Sprint(t["quoteVolume"]), 64) @@ -133,41 +170,53 @@ func (s *Sentinel) check(symbol string) { s.mu.Lock() h := s.history[symbol] h = append(h, pt) - if len(h) > 60 { h = h[len(h)-60:] } + if len(h) > 60 { + h = h[len(h)-60:] + } s.history[symbol] = h s.mu.Unlock() - if len(h) < 5 { return } + if len(h) < 5 { + return + } // Price breakout (>3% in 5 min) old := h[len(h)-5] pct := ((price - old.Price) / old.Price) * 100 if math.Abs(pct) >= 3.0 { sev := "warning" - if math.Abs(pct) >= 6.0 { sev = "critical" } + if math.Abs(pct) >= 6.0 { + sev = "critical" + } dir := "📈 拉升" - if pct < 0 { dir = "📉 下跌" } + if pct < 0 { + dir = "📉 下跌" + } s.emit(Signal{Type: SignalPriceBreakout, Symbol: symbol, Severity: sev, - Title: fmt.Sprintf("%s %s %.1f%%", symbol, dir, math.Abs(pct)), + Title: fmt.Sprintf("%s %s %.1f%%", symbol, dir, math.Abs(pct)), Detail: fmt.Sprintf("5min: $%.2f → $%.2f (24h: %.1f%%)", old.Price, price, chg), - Price: price, Change: pct}) + Price: price, Change: pct}) } // Volume spike (>3x avg) if len(h) >= 10 { var avg float64 - for i := 0; i < len(h)-1; i++ { avg += h[i].Volume } + for i := 0; i < len(h)-1; i++ { + avg += h[i].Volume + } avg /= float64(len(h) - 1) if avg > 0 && vol > avg*3 { s.emit(Signal{Type: SignalVolumeSpike, Symbol: symbol, Severity: "warning", - Title: fmt.Sprintf("%s 成交量异常 %.1fx", symbol, vol/avg), + Title: fmt.Sprintf("%s 成交量异常 %.1fx", symbol, vol/avg), Detail: fmt.Sprintf("Price: $%.2f (24h: %.1f%%)", price, chg), - Price: price, Change: chg}) + Price: price, Change: chg}) } } } func (s *Sentinel) emit(sig Signal) { s.logger.Info("signal", "type", sig.Type, "symbol", sig.Symbol, "title", sig.Title) - if s.onSignal != nil { s.onSignal(sig) } + if s.onSignal != nil { + s.onSignal(sig) + } } diff --git a/agent/skill_catalog_test.go b/agent/skill_catalog_test.go deleted file mode 100644 index 36d96fbe..00000000 --- a/agent/skill_catalog_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package agent - -import ( - "log/slog" - "strings" - "testing" -) - -func TestSkillCatalogPromptZHIncludesDiagnosisSkills(t *testing.T) { - got := skillCatalogPrompt("zh") - for _, want := range []string{ - "多轮与 Skill-First 工作模式", - "skill_model_config_diagnosis", - "skill_exchange_api_diagnosis", - "skill_trader_start_diagnosis", - } { - if !strings.Contains(got, want) { - t.Fatalf("skillCatalogPrompt(zh) missing %q\n%s", want, got) - } - } -} - -func TestBuildSystemPromptIncludesSkillCatalog(t *testing.T) { - a := New(nil, nil, DefaultConfig(), slog.Default()) - got := a.buildSystemPrompt("zh") - for _, want := range []string{ - "多轮与 Skill-First 工作模式", - "skill_exchange_api_setup", - "skill_order_execution_diagnosis", - } { - if !strings.Contains(got, want) { - t.Fatalf("buildSystemPrompt(zh) missing %q", want) - } - } -} diff --git a/agent/skill_dag.go b/agent/skill_dag.go index ad026115..a0b84c15 100644 --- a/agent/skill_dag.go +++ b/agent/skill_dag.go @@ -35,15 +35,6 @@ func buildSkillDAGRegistry() map[string]SkillDAG { {ID: "execute_create_and_start", Kind: "execute", RequiredFields: []string{"name", "exchange_id", "model_id", "strategy_id"}, OptionalFields: []string{"auto_start"}, Terminal: true}, }, }, - { - SkillName: "trader_management", - Action: "update_name", - Steps: []SkillDAGStep{ - {ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"collect_name"}}, - {ID: "collect_name", Kind: "collect_slot", RequiredFields: []string{"name"}, Next: []string{"execute_update"}}, - {ID: "execute_update", Kind: "execute", RequiredFields: []string{"target_ref", "name"}, Terminal: true}, - }, - }, { SkillName: "trader_management", Action: "update_bindings", @@ -53,6 +44,33 @@ func buildSkillDAGRegistry() map[string]SkillDAG { {ID: "execute_update", Kind: "execute", RequiredFields: []string{"target_ref", "binding_update"}, OptionalFields: []string{"ai_model_id", "exchange_id", "strategy_id"}, Terminal: true}, }, }, + { + SkillName: "trader_management", + Action: "configure_strategy", + Steps: []SkillDAGStep{ + {ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"collect_bindings"}}, + {ID: "collect_bindings", Kind: "collect_slot", RequiredFields: []string{"binding_update"}, OptionalFields: []string{"strategy_id"}, Next: []string{"execute_update"}}, + {ID: "execute_update", Kind: "execute", RequiredFields: []string{"target_ref", "binding_update", "strategy_id"}, Terminal: true}, + }, + }, + { + SkillName: "trader_management", + Action: "configure_exchange", + Steps: []SkillDAGStep{ + {ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"collect_bindings"}}, + {ID: "collect_bindings", Kind: "collect_slot", RequiredFields: []string{"binding_update"}, OptionalFields: []string{"exchange_id"}, Next: []string{"execute_update"}}, + {ID: "execute_update", Kind: "execute", RequiredFields: []string{"target_ref", "binding_update", "exchange_id"}, Terminal: true}, + }, + }, + { + SkillName: "trader_management", + Action: "configure_model", + Steps: []SkillDAGStep{ + {ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"collect_bindings"}}, + {ID: "collect_bindings", Kind: "collect_slot", RequiredFields: []string{"binding_update"}, OptionalFields: []string{"ai_model_id"}, Next: []string{"execute_update"}}, + {ID: "execute_update", Kind: "execute", RequiredFields: []string{"target_ref", "binding_update", "ai_model_id"}, Terminal: true}, + }, + }, { SkillName: "trader_management", Action: "start", @@ -111,12 +129,9 @@ func buildSkillDAGRegistry() map[string]SkillDAG { SkillName: "strategy_management", Action: "update_config", Steps: []SkillDAGStep{ - {ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"resolve_config_field"}}, - {ID: "resolve_config_field", Kind: "collect_slot", RequiredFields: []string{"config_field"}, Next: []string{"resolve_config_value"}}, - {ID: "resolve_config_value", Kind: "collect_slot", RequiredFields: []string{"config_value"}, Next: []string{"load_config"}}, - {ID: "load_config", Kind: "load_state", RequiredFields: []string{"target_ref"}, Next: []string{"apply_field_update"}}, - {ID: "apply_field_update", Kind: "transform", RequiredFields: []string{"config_field", "config_value"}, Next: []string{"execute_update"}}, - {ID: "execute_update", Kind: "execute", RequiredFields: []string{"target_ref", "config_field", "config_value"}, Terminal: true}, + {ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"collect_config_patch"}}, + {ID: "collect_config_patch", Kind: "collect_slot", RequiredFields: []string{"config_patch"}, Next: []string{"execute_update"}}, + {ID: "execute_update", Kind: "execute", RequiredFields: []string{"target_ref", "config_patch"}, Terminal: true}, }, }, { @@ -274,4 +289,3 @@ func listSkillDAGs() []SkillDAG { } return out } - diff --git a/agent/skill_dag_runtime_test.go b/agent/skill_dag_runtime_test.go deleted file mode 100644 index 8085ceee..00000000 --- a/agent/skill_dag_runtime_test.go +++ /dev/null @@ -1,27 +0,0 @@ -package agent - -import "testing" - -func TestCurrentSkillDAGStepDefaultsToFirstStep(t *testing.T) { - session := skillSession{Name: "strategy_management", Action: "update_config"} - step, ok := currentSkillDAGStep(session) - if !ok { - t.Fatal("expected dag step") - } - if step.ID != "resolve_target" { - t.Fatalf("expected first step resolve_target, got %s", step.ID) - } -} - -func TestAdvanceSkillDAGStepMovesToNextStep(t *testing.T) { - session := skillSession{Name: "strategy_management", Action: "update_config"} - setSkillDAGStep(&session, "resolve_config_field") - advanceSkillDAGStep(&session, "resolve_config_field") - step, ok := currentSkillDAGStep(session) - if !ok { - t.Fatal("expected dag step") - } - if step.ID != "resolve_config_value" { - t.Fatalf("expected resolve_config_value, got %s", step.ID) - } -} diff --git a/agent/skill_dag_test.go b/agent/skill_dag_test.go deleted file mode 100644 index 73707474..00000000 --- a/agent/skill_dag_test.go +++ /dev/null @@ -1,67 +0,0 @@ -package agent - -import "testing" - -func TestGetSkillDAGForStructuredActions(t *testing.T) { - tests := []struct { - skill string - action string - }{ - {skill: "trader_management", action: "create"}, - {skill: "trader_management", action: "update_bindings"}, - {skill: "strategy_management", action: "update_config"}, - {skill: "strategy_management", action: "update_prompt"}, - {skill: "model_management", action: "update_status"}, - {skill: "exchange_management", action: "update_name"}, - } - - for _, tt := range tests { - dag, ok := getSkillDAG(tt.skill, tt.action) - if !ok { - t.Fatalf("expected DAG for %s/%s", tt.skill, tt.action) - } - if dag.SkillName != tt.skill || dag.Action != tt.action { - t.Fatalf("unexpected dag identity: %+v", dag) - } - if len(dag.Steps) == 0 { - t.Fatalf("expected DAG steps for %s/%s", tt.skill, tt.action) - } - } -} - -func TestStructuredDAGsHaveTerminalStep(t *testing.T) { - for _, dag := range listSkillDAGs() { - hasTerminal := false - for _, step := range dag.Steps { - if step.Terminal { - hasTerminal = true - break - } - } - if !hasTerminal { - t.Fatalf("expected terminal step for %s/%s", dag.SkillName, dag.Action) - } - } -} - -func TestStrategyUpdateConfigDAGMatchesCurrentAtomicFlow(t *testing.T) { - dag, ok := getSkillDAG("strategy_management", "update_config") - if !ok { - t.Fatal("missing strategy update_config dag") - } - if len(dag.Steps) != 6 { - t.Fatalf("expected 6 steps, got %d", len(dag.Steps)) - } - if dag.Steps[0].ID != "resolve_target" { - t.Fatalf("expected first step resolve_target, got %s", dag.Steps[0].ID) - } - if dag.Steps[1].ID != "resolve_config_field" { - t.Fatalf("expected second step resolve_config_field, got %s", dag.Steps[1].ID) - } - if dag.Steps[2].ID != "resolve_config_value" { - t.Fatalf("expected third step resolve_config_value, got %s", dag.Steps[2].ID) - } - if dag.Steps[5].ID != "execute_update" || !dag.Steps[5].Terminal { - t.Fatalf("expected final terminal execute step, got %+v", dag.Steps[5]) - } -} diff --git a/agent/skill_dispatcher.go b/agent/skill_dispatcher.go index db3ee326..41c919c7 100644 --- a/agent/skill_dispatcher.go +++ b/agent/skill_dispatcher.go @@ -4,9 +4,10 @@ import ( "context" "encoding/json" "fmt" - "regexp" "strings" "time" + + "nofx/store" ) type skillSession struct { @@ -34,13 +35,9 @@ type traderSkillOption struct { ID string Name string Enabled bool + Hint string } -var ( - quotedNamePattern = regexp.MustCompile(`[“"]([^“”"]{1,40})[”"]`) - traderNamedPattern = regexp.MustCompile(`(?:叫|名为|名字是)\s*([A-Za-z0-9_\-\p{Han}]{2,40})`) -) - func skillSessionConfigKey(userID int64) string { return fmt.Sprintf("agent_skill_session_%d", userID) } @@ -53,7 +50,7 @@ func normalizeSkillSession(session skillSession) skillSession { if len(session.Fields) > 0 { normalized := make(map[string]string, len(session.Fields)) for key, value := range session.Fields { - key = strings.TrimSpace(key) + key = normalizeFieldKey(&session, key) value = strings.TrimSpace(value) if key == "" || value == "" { continue @@ -67,6 +64,7 @@ func normalizeSkillSession(session skillSession) skillSession { } } if session.Slots != nil { + ensureSkillFields(&session) session.Slots.Name = strings.TrimSpace(session.Slots.Name) session.Slots.ExchangeID = strings.TrimSpace(session.Slots.ExchangeID) session.Slots.ExchangeName = strings.TrimSpace(session.Slots.ExchangeName) @@ -74,11 +72,43 @@ func normalizeSkillSession(session skillSession) skillSession { session.Slots.ModelName = strings.TrimSpace(session.Slots.ModelName) session.Slots.StrategyID = strings.TrimSpace(session.Slots.StrategyID) session.Slots.StrategyName = strings.TrimSpace(session.Slots.StrategyName) - if session.Slots.Name == "" && - session.Slots.ExchangeID == "" && - session.Slots.ModelID == "" && - session.Slots.StrategyID == "" && - session.Slots.AutoStart == nil { + if session.Slots.Name != "" { + session.Fields["name"] = session.Slots.Name + } + if session.Slots.ExchangeID != "" { + session.Fields["exchange_id"] = session.Slots.ExchangeID + } + if session.Slots.ExchangeName != "" { + session.Fields["exchange_name"] = session.Slots.ExchangeName + } + if session.Slots.ModelID != "" { + session.Fields["model_id"] = session.Slots.ModelID + } + if session.Slots.ModelName != "" { + session.Fields["model_name"] = session.Slots.ModelName + } + if session.Slots.StrategyID != "" { + session.Fields["strategy_id"] = session.Slots.StrategyID + } + if session.Slots.StrategyName != "" { + session.Fields["strategy_name"] = session.Slots.StrategyName + } + if session.Slots.AutoStart != nil { + if *session.Slots.AutoStart { + session.Fields["auto_start"] = "true" + } else { + session.Fields["auto_start"] = "false" + } + } + syncTraderCreateSlotMirror(&session) + if fieldValue(session, "name") == "" && + fieldValue(session, "exchange_id") == "" && + fieldValue(session, "model_id") == "" && + fieldValue(session, "strategy_id") == "" && + fieldValue(session, "exchange_name") == "" && + fieldValue(session, "model_name") == "" && + fieldValue(session, "strategy_name") == "" && + fieldValue(session, "auto_start") == "" { session.Slots = nil } } @@ -165,296 +195,28 @@ func isCancelSkillReply(text string) bool { } } -func detectCreateTraderSkill(text string) bool { - lower := strings.ToLower(strings.TrimSpace(text)) - if lower == "" { - return false - } - hasCreate := containsAny(lower, []string{"创建", "新建", "建一个", "create", "new"}) - hasTrader := containsAny(lower, []string{"交易员", "trader", "agent"}) - return hasCreate && hasTrader -} - -func detectModelDiagnosisSkill(text string) bool { - lower := strings.ToLower(strings.TrimSpace(text)) - if lower == "" { - return false - } - if containsAny(lower, []string{"custom_api_url", "invalid custom_api_url", "ai assistant unavailable", "模型配置失败", "模型不可用", "ai unavailable"}) { - return true - } - return containsAny(lower, []string{"模型", "model", "api key", "base url", "custom_api_url"}) && - containsAny(lower, []string{"报错", "错误", "失败", "不可用", "不生效", "invalid", "error", "failed"}) -} - -func detectExchangeDiagnosisSkill(text string) bool { - lower := strings.ToLower(strings.TrimSpace(text)) - if lower == "" { - return false - } - return containsAny(lower, []string{ - "invalid signature", "timestamp", "ip not allowed", "permission denied", - "签名错误", "签名失败", "时间戳", "白名单", "权限不足", "交易所 api 报错", "交易所连接不上", - }) -} - -func detectStartIntent(text string) bool { - lower := strings.ToLower(text) - return containsAny(lower, []string{"启动", "跑起来", "run", "start", "立即运行", "并启动"}) -} - -func looksLikeStandaloneValueReply(text string) bool { - lower := strings.ToLower(strings.TrimSpace(text)) - if lower == "" { - return false - } - if firstIntegerPattern.MatchString(lower) && len(strings.Fields(lower)) <= 4 { - return true - } - return containsAny(lower, []string{"启用", "禁用", "enable", "disable", "打开", "关闭"}) -} - -func detectImplicitStrategyAction(text string) string { - lower := strings.ToLower(strings.TrimSpace(text)) - switch { - case containsAny(lower, []string{"prompt", "提示词"}): - return "update_prompt" - case containsAny(lower, []string{"参数", "配置", "置信度", "持仓", "周期", "timeframe", "调到", "改到", "改成", "调整"}): - return "update_config" - default: +func normalizeTraderDraftName(value string) string { + value = strings.TrimSpace(value) + if value == "" { return "" } -} - -func detectImplicitTraderAction(text string) string { - lower := strings.ToLower(strings.TrimSpace(text)) - switch { - case containsAny(lower, []string{"启动", "开始", "run", "start"}): - return "start" - case containsAny(lower, []string{"停止", "停掉", "stop", "pause"}): - return "stop" - case containsAny(lower, []string{"换模型", "换交易所", "换策略", "绑定", "切换模型", "切换交易所", "切换策略"}): - return "update_bindings" - case containsAny(lower, []string{"改名", "重命名", "rename"}): - return "update_name" - default: - return "" - } -} - -func detectImplicitModelAction(text string) string { - lower := strings.ToLower(strings.TrimSpace(text)) - switch { - case containsAny(lower, []string{"启用", "禁用", "enable", "disable"}): - return "update_status" - case containsAny(lower, []string{"url", "endpoint", "地址", "接口"}): - return "update_endpoint" - case containsAny(lower, []string{"模型名", "模型名称", "model name", "改名", "重命名", "rename"}): - return "update_name" - default: - return "" - } -} - -func detectImplicitExchangeAction(text string) string { - lower := strings.ToLower(strings.TrimSpace(text)) - switch { - case containsAny(lower, []string{"启用", "禁用", "enable", "disable"}): - return "update_status" - case containsAny(lower, []string{"账户名", "改名", "重命名", "rename"}): - return "update_name" - default: - return "" - } -} - -func (a *Agent) inferContextualSkillSession(storeUserID string, userID int64, text string, session skillSession) skillSession { - if session.Name != "" || strings.TrimSpace(text) == "" { - return session - } - state := a.getExecutionState(userID) - lower := strings.ToLower(strings.TrimSpace(text)) - if state.CurrentReferences != nil { - if ref := state.CurrentReferences.Strategy; ref != nil { - if action := detectImplicitStrategyAction(text); action != "" || looksLikeStandaloneValueReply(text) { - return skillSession{Name: "strategy_management", Action: defaultIfEmpty(action, "update_config"), Phase: "collecting", TargetRef: ref} - } - } - if ref := state.CurrentReferences.Trader; ref != nil { - if action := detectImplicitTraderAction(text); action != "" { - return skillSession{Name: "trader_management", Action: action, Phase: "collecting", TargetRef: ref} - } - } - if ref := state.CurrentReferences.Model; ref != nil { - if action := detectImplicitModelAction(text); action != "" { - return skillSession{Name: "model_management", Action: action, Phase: "collecting", TargetRef: ref} - } - } - if ref := state.CurrentReferences.Exchange; ref != nil { - if action := detectImplicitExchangeAction(text); action != "" { - return skillSession{Name: "exchange_management", Action: action, Phase: "collecting", TargetRef: ref} - } + for _, prefix := range []string{"名称:", "名称:", "名字:", "名字:", "name:", "name:"} { + if strings.HasPrefix(strings.ToLower(value), strings.ToLower(prefix)) { + value = strings.TrimSpace(value[len(prefix):]) + break } } - if containsAny(lower, []string{"调整参数", "改参数", "改配置"}) { - options := a.loadStrategyOptions(storeUserID) - if len(options) == 1 { - return skillSession{ - Name: "strategy_management", - Action: "update_config", - Phase: "collecting", - TargetRef: &EntityReference{ - ID: options[0].ID, - Name: options[0].Name, - }, - } + for _, sep := range []string{"交易所:", "交易所:", "模型:", "模型:", "策略:", "策略:", "exchange:", "model:", "strategy:"} { + if idx := strings.Index(strings.ToLower(value), strings.ToLower(sep)); idx >= 0 { + value = strings.TrimSpace(value[:idx]) } } - return session -} - -func extractTraderName(text string) string { - text = strings.TrimSpace(text) - if text == "" { - return "" - } - if matches := quotedNamePattern.FindStringSubmatch(text); len(matches) == 2 { - return strings.TrimSpace(matches[1]) - } - if matches := traderNamedPattern.FindStringSubmatch(text); len(matches) == 2 { - return strings.TrimSpace(matches[1]) - } - return "" -} - -func extractSegmentAfterKeywords(text string, keywords []string) string { - trimmed := strings.TrimSpace(text) - if trimmed == "" { - return "" - } - lower := strings.ToLower(trimmed) - for _, keyword := range keywords { - idx := strings.Index(lower, strings.ToLower(keyword)) - if idx < 0 { - continue - } - segment := strings.TrimSpace(trimmed[idx+len(keyword):]) - if segment == "" { - continue - } - cut := len(segment) - for i, r := range segment { - switch r { - case ',', ',', '。', ';', ';', '\n', '、': - cut = i - goto done - } - } - done: - segment = strings.TrimSpace(segment[:cut]) - segment = strings.Trim(segment, "“”\"':: ") - if segment != "" { - return segment + for _, sep := range []string{",", ",", "。", ";", ";", "\n"} { + if idx := strings.Index(value, sep); idx >= 0 { + value = strings.TrimSpace(value[:idx]) } } - return "" -} - -func pickMentionedOption(text string, options []traderSkillOption) *traderSkillOption { - lower := strings.ToLower(strings.TrimSpace(text)) - if lower == "" { - return nil - } - bestScore := 0 - var matched *traderSkillOption - for _, option := range options { - id := strings.ToLower(strings.TrimSpace(option.ID)) - name := strings.ToLower(strings.TrimSpace(option.Name)) - if id == "" && name == "" { - continue - } - score := optionMatchScore(lower, id, name) - if score == 0 { - continue - } - if score == bestScore { - matched = nil - continue - } - if score > bestScore { - bestScore = score - copy := option - matched = © - } - } - return matched -} - -func pickOptionFromSegment(text string, keywords []string, options []traderSkillOption) *traderSkillOption { - segment := extractSegmentAfterKeywords(text, keywords) - if strings.TrimSpace(segment) == "" { - return nil - } - return pickMentionedOption(segment, options) -} - -func optionMatchScore(text, id, name string) int { - if id != "" && strings.Contains(text, id) { - return 4 - } - return optionNameMatchScore(text, name) -} - -func optionNameMatchScore(text, name string) int { - name = strings.TrimSpace(strings.ToLower(name)) - if name == "" { - return 0 - } - if strings.Contains(text, name) { - return 3 - } - fields := strings.FieldsFunc(name, func(r rune) bool { - switch r { - case ' ', ',', ',', '/', '|', '、', '(', ')', '(', ')': - return true - default: - return false - } - }) - best := 0 - for _, field := range fields { - field = strings.TrimSpace(field) - if field == "" { - continue - } - if len([]rune(field)) <= 2 && !containsHan(field) { - continue - } - if strings.Contains(text, field) { - if containsHan(field) && len([]rune(field)) >= 3 { - best = max(best, 2) - } else { - best = max(best, 1) - } - } - } - return best -} - -func containsHan(s string) bool { - for _, r := range s { - if r >= 0x4E00 && r <= 0x9FFF { - return true - } - } - return false -} - -func max(a, b int) int { - if a > b { - return a - } - return b + return strings.Trim(value, "“”\"':: ") } func choosePreferredOption(options []traderSkillOption) *traderSkillOption { @@ -482,6 +244,9 @@ func formatOptionList(prefix string, options []traderSkillOption) string { if label == "" { label = option.ID } + if hint := strings.TrimSpace(option.Hint); hint != "" { + label += "(" + hint + ")" + } if option.Enabled { label += "(已启用)" } else { @@ -505,6 +270,28 @@ func parseSkillError(raw string) string { return strings.TrimSpace(raw) } +func modelWalletBalanceHint(model *store.AIModel) string { + if model == nil || !agentProviderSupportsUSDCBalance(model.Provider) { + return "" + } + privateKey := strings.TrimSpace(string(model.APIKey)) + if privateKey == "" { + return "钱包未配置" + } + walletAddress, err := agentWalletAddressFromPrivateKey(privateKey) + if err != nil || strings.TrimSpace(walletAddress) == "" { + return "钱包私钥无效" + } + balance, err := agentQueryUSDCBalanceCached(walletAddress) + if err != nil { + return "钱包余额暂时无法读取" + } + if balance <= 0 { + return "钱包余额 0 USDC,需充值后才能稳定调用" + } + return fmt.Sprintf("钱包余额 %.4g USDC", balance) +} + func (a *Agent) loadEnabledModelOptions(storeUserID string) []traderSkillOption { if a.store == nil { return nil @@ -515,13 +302,16 @@ func (a *Agent) loadEnabledModelOptions(storeUserID string) []traderSkillOption } out := make([]traderSkillOption, 0, len(models)) for _, model := range models { - parts := cleanStringList([]string{ - strings.TrimSpace(model.Name), + name := strings.TrimSpace(model.Name) + if name == "" { + name = strings.TrimSpace(model.ID) + } + hint := strings.Join(cleanStringList([]string{ strings.TrimSpace(model.CustomModelName), strings.TrimSpace(model.Provider), - }) - name := strings.Join(parts, " ") - out = append(out, traderSkillOption{ID: model.ID, Name: name, Enabled: model.Enabled}) + modelWalletBalanceHint(model), + }), " / ") + out = append(out, traderSkillOption{ID: model.ID, Name: name, Hint: hint, Enabled: model.Enabled}) } return out } @@ -536,6 +326,9 @@ func (a *Agent) loadExchangeOptions(storeUserID string) []traderSkillOption { } out := make([]traderSkillOption, 0, len(exchanges)) for _, exchange := range exchanges { + if !store.IsVisibleExchange(exchange) { + continue + } name := strings.TrimSpace(exchange.AccountName) if name == "" { name = strings.TrimSpace(exchange.ExchangeType) @@ -560,112 +353,78 @@ func (a *Agent) loadStrategyOptions(storeUserID string) []traderSkillOption { return out } +func (a *Agent) buildTraderCreateConversationResources(storeUserID string, session skillSession) map[string]any { + missing := missingFieldKeysForSkillSession(session) + needExchange := false + needModel := false + needStrategy := false + for _, field := range missing { + switch strings.TrimSpace(field) { + case "exchange_name", "exchange_id", "exchange": + needExchange = true + case "model_name", "model_id", "ai_model_id", "model": + needModel = true + case "strategy_name", "strategy_id", "strategy": + needStrategy = true + } + } + resources := map[string]any{} + if needExchange { + resources["exchanges"] = a.loadExchangeOptions(storeUserID) + } + if needModel { + resources["models"] = a.loadEnabledModelOptions(storeUserID) + } + if needStrategy { + resources["strategies"] = a.loadStrategyOptions(storeUserID) + } + return resources +} + func (a *Agent) tryHardSkill(ctx context.Context, storeUserID string, userID int64, lang, text string, onEvent func(event, data string)) (string, bool) { if ctx != nil && ctx.Err() != nil { return "", false } - session := a.getSkillSession(userID) - session = a.inferContextualSkillSession(storeUserID, userID, text, session) - if (session.Name == "trader_management" && session.Action == "create") || detectCreateTraderSkill(text) { - answer, handled := a.handleCreateTraderSkill(storeUserID, userID, lang, text, session) + emptySession := skillSession{} + if hasExplicitCreateIntentForDomain(text, "trader") { + answer, handled := a.handleCreateTraderSkill(storeUserID, userID, lang, text, emptySession) if handled { a.recordSkillInteraction(userID, text, answer) if onEvent != nil { onEvent(StreamEventTool, "hard_skill:trader_management:create") - onEvent(StreamEventDelta, answer) + emitStreamText(onEvent, answer) } + return answer, true } - return answer, handled - } - if (session.Name == "trader_management" && session.Action != "create") || detectTraderManagementIntent(text) { - answer, handled := a.handleTraderManagementSkill(storeUserID, userID, lang, text, session) - if handled { - a.recordSkillInteraction(userID, text, answer) - if onEvent != nil { - onEvent(StreamEventTool, "hard_skill:trader_management") - onEvent(StreamEventDelta, answer) - } - } - return answer, handled - } - if session.Name == "exchange_management" || detectExchangeManagementIntent(text) { - answer, handled := a.handleExchangeManagementSkill(storeUserID, userID, lang, text, session) - if handled { - a.recordSkillInteraction(userID, text, answer) - if onEvent != nil { - onEvent(StreamEventTool, "hard_skill:exchange_management") - onEvent(StreamEventDelta, answer) - } - } - return answer, handled - } - if session.Name == "model_management" || detectModelManagementIntent(text) { - answer, handled := a.handleModelManagementSkill(storeUserID, userID, lang, text, session) - if handled { - a.recordSkillInteraction(userID, text, answer) - if onEvent != nil { - onEvent(StreamEventTool, "hard_skill:model_management") - onEvent(StreamEventDelta, answer) - } - } - return answer, handled - } - if session.Name == "strategy_management" || detectStrategyManagementIntent(text) { - answer, handled := a.handleStrategyManagementSkill(storeUserID, userID, lang, text, session) - if handled { - a.recordSkillInteraction(userID, text, answer) - if onEvent != nil { - onEvent(StreamEventTool, "hard_skill:strategy_management") - onEvent(StreamEventDelta, answer) - } - } - return answer, handled - } - if detectModelDiagnosisSkill(text) { - answer := a.handleModelDiagnosisSkill(storeUserID, lang, text) - a.recordSkillInteraction(userID, text, answer) - if onEvent != nil { - onEvent(StreamEventTool, "hard_skill:model_diagnosis") - onEvent(StreamEventDelta, answer) - } - return answer, true - } - if detectExchangeDiagnosisSkill(text) { - answer := a.handleExchangeDiagnosisSkill(storeUserID, lang, text) - a.recordSkillInteraction(userID, text, answer) - if onEvent != nil { - onEvent(StreamEventTool, "hard_skill:exchange_diagnosis") - onEvent(StreamEventDelta, answer) - } - return answer, true - } - if detectTraderDiagnosisSkill(text) { - answer := a.handleTraderDiagnosisSkill(storeUserID, lang, text) - a.recordSkillInteraction(userID, text, answer) - if onEvent != nil { - onEvent(StreamEventTool, "hard_skill:trader_diagnosis") - onEvent(StreamEventDelta, answer) - } - return answer, true - } - if detectStrategyDiagnosisSkill(text) { - answer := a.handleStrategyDiagnosisSkill(storeUserID, lang, text) - a.recordSkillInteraction(userID, text, answer) - if onEvent != nil { - onEvent(StreamEventTool, "hard_skill:strategy_diagnosis") - onEvent(StreamEventDelta, answer) - } - return answer, true } return "", false } func (a *Agent) recordSkillInteraction(userID int64, userText, answer string) { - a.ensureHistory() + if a.history == nil { + a.history = newChatHistory(chatHistoryMaxTurns) + } a.history.Add(userID, "user", userText) a.history.Add(userID, "assistant", answer) } +func (a *Agent) rerouteRejectedSkillFlow(ctx context.Context, storeUserID string, userID int64, lang, text string) (string, bool) { + a.clearSkillSession(userID) + if a == nil || a.aiClient == nil { + return "", false + } + if answer, handled, err := a.tryLLMIntentRoute(ctx, storeUserID, userID, lang, text, nil); err == nil && handled { + return answer, true + } + if answer, ok := a.tryDirectAnswer(ctx, userID, lang, text, nil); ok { + return answer, true + } + if answer, err := a.runPlannedAgent(ctx, storeUserID, userID, lang, text, nil); err == nil && strings.TrimSpace(answer) != "" { + return answer, true + } + return "", false +} + func ensureSkillFields(session *skillSession) { if session.Fields == nil { session.Fields = make(map[string]string) @@ -673,223 +432,75 @@ func ensureSkillFields(session *skillSession) { } func (a *Agent) handleCreateTraderSkill(storeUserID string, userID int64, lang, text string, session skillSession) (string, bool) { - if isCancelSkillReply(text) { - a.clearSkillSession(userID) - if lang == "zh" { - return "已取消当前创建交易员流程。", true - } - return "Cancelled the current trader creation flow.", true - } - if session.Name == "" { session = skillSession{ Name: "trader_management", Action: "create", Phase: "collecting", - Slots: &createTraderSkillSlots{}, - } - if detectStartIntent(text) { - autoStart := true - session.Slots.AutoStart = &autoStart + Fields: map[string]string{}, } } - if session.Slots == nil { - session.Slots = &createTraderSkillSlots{} - } - if fieldValue(session, skillDAGStepField) == "" { - setSkillDAGStep(&session, "resolve_name") + if session.Fields == nil { + session.Fields = map[string]string{} } + syncTraderCreateSlotMirror(&session) if session.Phase == "await_start_confirmation" { - setSkillDAGStep(&session, "await_start_confirmation") switch { case isYesReply(text): - answer := a.executeCreateTraderSkill(storeUserID, userID, lang, session, true) - return answer, true + return a.executeCreateTraderSkill(storeUserID, userID, lang, session, true), true case isNoReply(text): - answer := a.executeCreateTraderSkill(storeUserID, userID, lang, session, false) - return answer, true - default: + return a.executeCreateTraderSkill(storeUserID, userID, lang, session, false), true + } + } + if session.Phase == "await_create_confirmation" { + switch { + case isYesReply(text): + return a.executeCreateTraderSkill(storeUserID, userID, lang, session, false), true + case isNoReply(text), isCancelSkillReply(text): + session.Phase = "collecting" a.saveSkillSession(userID, session) if lang == "zh" { - return "当前流程在等待你确认是否立即启动交易员。回复“确认”继续启动,回复“先不用”则只创建不启动。", true + return "好的,那我先不创建。你也可以继续改名称、交易所、模型或策略。", true } - return "This flow is waiting for your confirmation to start the trader. Reply 'confirm' to start it now, or 'no' to create without starting.", true + return "Okay, I won't create it yet. You can keep adjusting the name, exchange, model, or strategy.", true } } - slots := session.Slots - if slots.Name == "" { - slots.Name = extractTraderName(text) - } - if slots.Name != "" { - setSkillDAGStep(&session, "resolve_exchange") - } - - models := a.loadEnabledModelOptions(storeUserID) - exchanges := a.loadExchangeOptions(storeUserID) - strategies := a.loadStrategyOptions(storeUserID) - - if slots.ModelID == "" { - if match := pickOptionFromSegment(text, []string{"模型用", "模型", "model"}, models); match != nil { - slots.ModelID = match.ID - slots.ModelName = match.Name - } else if match := pickMentionedOption(text, models); match != nil { - slots.ModelID = match.ID - slots.ModelName = match.Name - } else if choice := choosePreferredOption(models); choice != nil { - slots.ModelID = choice.ID - slots.ModelName = choice.Name + a.hydrateCreateTraderSlotReferences(storeUserID, &session) + if fieldValue(session, "exchange_id") != "" && fieldValue(session, "model_id") != "" && fieldValue(session, "strategy_id") != "" { + if err := a.validateTraderDraft(storeUserID, fieldValue(session, "model_id"), fieldValue(session, "exchange_id"), fieldValue(session, "strategy_id")); err != nil { + session.Phase = "collecting" + a.saveSkillSession(userID, session) + return formatValidationFeedback(lang, "trader", err), true } } - if slots.ExchangeID != "" { - setSkillDAGStep(&session, "resolve_model") - } - if slots.ExchangeID == "" { - if match := pickOptionFromSegment(text, []string{"交易所用", "交易所", "exchange"}, exchanges); match != nil { - if match.Enabled { - slots.ExchangeID = match.ID - slots.ExchangeName = match.Name - } else { - if lang == "zh" { - extra := "你刚才提到的交易所“" + defaultIfEmpty(match.Name, match.ID) + "”当前已禁用,请换一个已启用的交易所。" - a.saveSkillSession(userID, session) - return extra + "\n" + formatOptionList("可用交易所:", exchanges), true - } - a.saveSkillSession(userID, session) - return "The exchange you mentioned is currently disabled. Please choose an enabled exchange.\n" + formatOptionList("Available exchanges:", exchanges), true - } - } else if match := pickMentionedOption(text, exchanges); match != nil { - if match.Enabled { - slots.ExchangeID = match.ID - slots.ExchangeName = match.Name - } else { - if lang == "zh" { - extra := "你刚才提到的交易所“" + defaultIfEmpty(match.Name, match.ID) + "”当前已禁用,请换一个已启用的交易所。" - a.saveSkillSession(userID, session) - return extra + "\n" + formatOptionList("可用交易所:", exchanges), true - } - a.saveSkillSession(userID, session) - return "The exchange you mentioned is currently disabled. Please choose an enabled exchange.\n" + formatOptionList("Available exchanges:", exchanges), true - } - } else if choice := choosePreferredOption(exchanges); choice != nil { - slots.ExchangeID = choice.ID - slots.ExchangeName = choice.Name - } - } - if slots.StrategyID == "" { - if match := pickOptionFromSegment(text, []string{"策略用", "策略", "strategy"}, strategies); match != nil { - slots.StrategyID = match.ID - slots.StrategyName = match.Name - } else if match := pickMentionedOption(text, strategies); match != nil { - slots.StrategyID = match.ID - slots.StrategyName = match.Name - } else if choice := choosePreferredOption(strategies); choice != nil { - slots.StrategyID = choice.ID - slots.StrategyName = choice.Name - } - } - if slots.ModelID != "" { - setSkillDAGStep(&session, "resolve_strategy") - } - if slots.StrategyID != "" { - setSkillDAGStep(&session, "maybe_confirm_start") - } - - if slots.AutoStart == nil && detectStartIntent(text) { - autoStart := true - slots.AutoStart = &autoStart - } - - missing := make([]string, 0, 3) - extraLines := make([]string, 0, 3) - if actionRequiresSlot("trader_management", "create", "name") && slots.Name == "" { - missing = append(missing, slotDisplayName("name", lang)) - } - if actionRequiresSlot("trader_management", "create", "exchange") && slots.ExchangeID == "" { - missing = append(missing, slotDisplayName("exchange", lang)) - if len(exchanges) == 0 { - if lang == "zh" { - extraLines = append(extraLines, "当前还没有可用交易所配置,请先配置并启用一个交易所账户。") - } else { - extraLines = append(extraLines, "There is no enabled exchange config yet. Please create and enable one first.") - } - } else { - label := "Available exchanges:" - if lang == "zh" { - label = "可用交易所:" - } - extraLines = append(extraLines, formatOptionList(label, exchanges)) - } - } - if actionRequiresSlot("trader_management", "create", "model") && slots.ModelID == "" { - missing = append(missing, slotDisplayName("model", lang)) - if len(models) == 0 { - if lang == "zh" { - extraLines = append(extraLines, "当前还没有可用模型配置,请先配置并启用一个模型。") - } else { - extraLines = append(extraLines, "There is no enabled model config yet. Please create and enable one first.") - } - } else { - label := "Available models:" - if lang == "zh" { - label = "可用模型:" - } - extraLines = append(extraLines, formatOptionList(label, models)) - } - } - if slots.StrategyID == "" && (actionRequiresSlot("trader_management", "create", "strategy") || len(strategies) == 0) { - missing = append(missing, slotDisplayName("strategy", lang)) - } - if slots.StrategyID == "" { - if len(strategies) == 0 { - if lang == "zh" { - extraLines = append(extraLines, "当前还没有可用策略,请先创建一个策略。") - } else { - extraLines = append(extraLines, "There is no strategy available yet. Please create one first.") - } - } else { - label := "Available strategies:" - if lang == "zh" { - label = "可用策略:" - } - extraLines = append(extraLines, formatOptionList(label, strategies)) - } - } - - if len(missing) > 0 { + if missing := missingFieldKeysForSkillSession(session); len(missing) > 0 { session.Phase = "collecting" a.saveSkillSession(userID, session) - if lang == "zh" { - reply := "要继续创建交易员,还缺这些信息:" + strings.Join(missing, "、") + "。" - if len(extraLines) > 0 { - reply += "\n" + strings.Join(cleanStringList(extraLines), "\n") - } - reply += "\n你可以直接一次性告诉我,例如:名称、用哪个交易所、哪个模型、哪个策略。" - return reply, true - } - reply := "To continue creating the trader, I still need: " + strings.Join(missing, ", ") + "." - if len(extraLines) > 0 { - reply += "\n" + strings.Join(cleanStringList(extraLines), "\n") - } - reply += "\nYou can reply with all missing fields in one message." - return reply, true + return a.buildTraderCreateMissingPrompt(storeUserID, lang, session, a.buildTraderCreateConversationResources(storeUserID, session)), true } - if slots.AutoStart != nil && *slots.AutoStart { + if stillMissing := missingFieldKeysForSkillSession(session); len(stillMissing) > 0 { + session.Phase = "collecting" + a.saveSkillSession(userID, session) + return a.buildTraderCreateMissingPrompt(storeUserID, lang, session, a.buildTraderCreateConversationResources(storeUserID, session)), true + } + + if fieldValue(session, "auto_start") == "true" { session.Phase = "await_start_confirmation" - setSkillDAGStep(&session, "await_start_confirmation") a.saveSkillSession(userID, session) if lang == "zh" { - return fmt.Sprintf("我已经准备好创建交易员“%s”,并在创建后立即启动它。\n使用的交易所:%s\n使用的模型:%s\n使用的策略:%s\n\n这是高风险动作。回复“确认”继续,回复“先不用”则只创建不启动。", - slots.Name, slots.ExchangeNameOrID(), slots.ModelNameOrID(), slots.StrategyNameOrID()), true + return fmt.Sprintf("准备创建交易员并立即启动。\n交易所:%s\n模型:%s\n策略:%s\n\n回复确认继续,回复先不用则只创建不启动。", + traderCreateExchangeNameOrID(session), traderCreateModelNameOrID(session), traderCreateStrategyNameOrID(session)), true } - return fmt.Sprintf("I'm ready to create trader %q and start it immediately.\nExchange: %s\nModel: %s\nStrategy: %s\n\nThis is a high-risk action. Reply 'confirm' to continue, or 'no' to create it without starting.", - slots.Name, slots.ExchangeNameOrID(), slots.ModelNameOrID(), slots.StrategyNameOrID()), true + return fmt.Sprintf("Ready to create trader and start it immediately.\nExchange: %s\nModel: %s\nStrategy: %s\n\nReply confirm to continue, or no to create without starting.", + traderCreateExchangeNameOrID(session), traderCreateModelNameOrID(session), traderCreateStrategyNameOrID(session)), true } - answer := a.executeCreateTraderSkill(storeUserID, userID, lang, session, false) - return answer, true + session.Phase = "await_create_confirmation" + a.saveSkillSession(userID, session) + return formatTraderCreateDraftSummary(lang, session), true } func (s *createTraderSkillSlots) ExchangeNameOrID() string { @@ -913,28 +524,150 @@ func (s *createTraderSkillSlots) StrategyNameOrID() string { return s.StrategyID } +func traderCreateExchangeNameOrID(session skillSession) string { + if value := fieldValue(session, "exchange_name"); value != "" { + return value + } + return fieldValue(session, "exchange_id") +} + +func traderCreateModelNameOrID(session skillSession) string { + if value := fieldValue(session, "model_name"); value != "" { + return value + } + return fieldValue(session, "model_id") +} + +func traderCreateStrategyNameOrID(session skillSession) string { + if value := fieldValue(session, "strategy_name"); value != "" { + return value + } + return fieldValue(session, "strategy_id") +} + +func renderSkillMissingLabels(lang string, missing []string) []string { + out := make([]string, 0, len(missing)) + for _, field := range missing { + out = append(out, slotDisplayName(field, lang)) + } + return out +} + +func (a *Agent) buildTraderCreateMissingPrompt(storeUserID, lang string, session skillSession, availableResources map[string]any) string { + missing := missingFieldKeysForSkillSession(session) + missingLabels := strings.Join(renderSkillMissingLabels(lang, missing), "、") + prereqs := make([]string, 0, 3) + optionLines := make([]string, 0, 3) + if exchanges, _ := availableResources["exchanges"].([]traderSkillOption); len(exchanges) == 0 && containsString(missing, "exchange_name") { + if lang == "zh" { + prereqs = append(prereqs, "当前还没有可用交易所配置") + } else { + prereqs = append(prereqs, "there is no exchange config yet") + } + } else if containsString(missing, "exchange_name") { + if list := formatOptionList("现有交易所:", exchanges); lang == "zh" && list != "" { + optionLines = append(optionLines, list) + } else if list := formatOptionList("Available exchanges:", exchanges); lang != "zh" && list != "" { + optionLines = append(optionLines, list) + } + } + if models, _ := availableResources["models"].([]traderSkillOption); len(models) == 0 && containsString(missing, "model_name") { + if lang == "zh" { + prereqs = append(prereqs, "当前还没有可用模型配置") + } else { + prereqs = append(prereqs, "there is no model config yet") + } + } else if containsString(missing, "model_name") { + if list := formatOptionList("现有模型:", models); lang == "zh" && list != "" { + optionLines = append(optionLines, list) + } else if list := formatOptionList("Available models:", models); lang != "zh" && list != "" { + optionLines = append(optionLines, list) + } + } + if strategies, _ := availableResources["strategies"].([]traderSkillOption); len(strategies) == 0 && containsString(missing, "strategy_name") { + if lang == "zh" { + prereqs = append(prereqs, "当前还没有可用策略") + } else { + prereqs = append(prereqs, "there is no strategy yet") + } + } else if containsString(missing, "strategy_name") { + if list := formatOptionList("现有策略:", strategies); lang == "zh" && list != "" { + optionLines = append(optionLines, list) + } else if list := formatOptionList("Available strategies:", strategies); lang != "zh" && list != "" { + optionLines = append(optionLines, list) + } + } + if lang == "zh" { + reply := "新建交易员还缺这些槽位:" + missingLabels + "。" + if len(prereqs) > 0 { + reply += "\n" + strings.Join(prereqs, ";") + "。" + } + if len(optionLines) > 0 { + reply += "\n" + strings.Join(optionLines, "\n") + } + return reply + } + reply := "Creating the trader still needs these slots: " + strings.Join(renderSkillMissingLabels(lang, missing), ", ") + "." + if len(prereqs) > 0 { + reply += "\n" + strings.Join(prereqs, "; ") + "." + } + if len(optionLines) > 0 { + reply += "\n" + strings.Join(optionLines, "\n") + } + return reply +} + +func containsString(items []string, target string) bool { + for _, item := range items { + if item == target { + return true + } + } + return false +} + +func shouldPreserveTraderCreateSessionOnError(errMsg string) bool { + lower := strings.ToLower(strings.TrimSpace(errMsg)) + if lower == "" { + return false + } + return strings.Contains(lower, "exchange is disabled") || + strings.Contains(lower, "exchange_id is required") || + strings.Contains(lower, "model_id is required") || + strings.Contains(lower, "strategy_id is required") +} + func (a *Agent) executeCreateTraderSkill(storeUserID string, userID int64, lang string, session skillSession, startAfterCreate bool) string { + a.hydrateCreateTraderSlotReferences(storeUserID, &session) + normalizedArgs, _ := normalizeTraderArgsToManualLimits(lang, buildTraderUpdateArgsFromSession(session)) args := manageTraderArgs{ - Action: "create", - Name: session.Slots.Name, - AIModelID: session.Slots.ModelID, - ExchangeID: session.Slots.ExchangeID, - StrategyID: session.Slots.StrategyID, + Action: "create", + Name: fieldValue(session, "name"), + AIModelID: fieldValue(session, "model_id"), + ExchangeID: fieldValue(session, "exchange_id"), + StrategyID: fieldValue(session, "strategy_id"), + ScanIntervalMinutes: normalizedArgs.ScanIntervalMinutes, + IsCrossMargin: normalizedArgs.IsCrossMargin, + ShowInCompetition: normalizedArgs.ShowInCompetition, } createRaw := a.toolCreateTrader(storeUserID, args) if errMsg := parseSkillError(createRaw); errMsg != "" && strings.Contains(createRaw, `"error"`) { - session.Phase = "collecting" - a.saveSkillSession(userID, session) + if shouldPreserveTraderCreateSessionOnError(errMsg) { + session.Phase = "collecting" + a.saveSkillSession(userID, session) + } else { + a.clearSkillSession(userID) + } if strings.Contains(strings.ToLower(errMsg), "exchange is disabled") { exchanges := a.loadExchangeOptions(storeUserID) if lang == "zh" { - reply := fmt.Sprintf("创建交易员失败:你选的交易所“%s”当前已禁用,请换一个已启用的交易所。", session.Slots.ExchangeNameOrID()) + reply := fmt.Sprintf("创建交易员失败:你选的交易所“%s”当前已禁用,请换一个已启用的交易所。", traderCreateExchangeNameOrID(session)) if list := formatOptionList("可用交易所:", exchanges); list != "" { reply += "\n" + list } return reply } - reply := fmt.Sprintf("Failed to create trader: the selected exchange %q is disabled. Please choose an enabled exchange.", session.Slots.ExchangeNameOrID()) + reply := fmt.Sprintf("That trader could not be created because the exchange %q is turned off. Please choose one that is enabled.", traderCreateExchangeNameOrID(session)) if list := formatOptionList("Available exchanges:", exchanges); list != "" { reply += "\n" + list } @@ -943,7 +676,7 @@ func (a *Agent) executeCreateTraderSkill(storeUserID string, userID int64, lang if lang == "zh" { return "创建交易员失败:" + errMsg } - return "Failed to create trader: " + errMsg + return "That create request did not go through: " + errMsg } var created struct { Trader safeTraderToolConfig `json:"trader"` @@ -961,10 +694,10 @@ func (a *Agent) executeCreateTraderSkill(storeUserID string, userID int64, lang a.clearSkillSession(userID) if lang == "zh" { return fmt.Sprintf("已创建交易员“%s”。\n交易所:%s\n模型:%s\n策略:%s\n当前状态:未启动。", - created.Trader.Name, session.Slots.ExchangeNameOrID(), session.Slots.ModelNameOrID(), session.Slots.StrategyNameOrID()) + created.Trader.Name, traderCreateExchangeNameOrID(session), traderCreateModelNameOrID(session), traderCreateStrategyNameOrID(session)) } return fmt.Sprintf("Created trader %q.\nExchange: %s\nModel: %s\nStrategy: %s\nCurrent status: not started.", - created.Trader.Name, session.Slots.ExchangeNameOrID(), session.Slots.ModelNameOrID(), session.Slots.StrategyNameOrID()) + created.Trader.Name, traderCreateExchangeNameOrID(session), traderCreateModelNameOrID(session), traderCreateStrategyNameOrID(session)) } setSkillDAGStep(&session, "execute_create_and_start") @@ -980,10 +713,10 @@ func (a *Agent) executeCreateTraderSkill(storeUserID string, userID int64, lang a.clearSkillSession(userID) if lang == "zh" { return fmt.Sprintf("已创建并启动交易员“%s”。\n交易所:%s\n模型:%s\n策略:%s", - created.Trader.Name, session.Slots.ExchangeNameOrID(), session.Slots.ModelNameOrID(), session.Slots.StrategyNameOrID()) + created.Trader.Name, traderCreateExchangeNameOrID(session), traderCreateModelNameOrID(session), traderCreateStrategyNameOrID(session)) } return fmt.Sprintf("Created and started trader %q.\nExchange: %s\nModel: %s\nStrategy: %s", - created.Trader.Name, session.Slots.ExchangeNameOrID(), session.Slots.ModelNameOrID(), session.Slots.StrategyNameOrID()) + created.Trader.Name, traderCreateExchangeNameOrID(session), traderCreateModelNameOrID(session), traderCreateStrategyNameOrID(session)) } func (a *Agent) handleModelDiagnosisSkill(storeUserID, lang, text string) string { @@ -1123,3 +856,278 @@ func backendLogDiagnosisExcerpt(lang, text, fallbackFilter string) string { } return "Recent matching backend error logs:\n- " + strings.Join(entries, "\n- ") } + +type targetResolution struct { + Ref *EntityReference + Ambiguous []traderSkillOption + WasMentioned bool +} + +func enabledTraderSkillOptions(options []traderSkillOption) []traderSkillOption { + out := make([]traderSkillOption, 0, len(options)) + for _, o := range options { + if o.Enabled { + out = append(out, o) + } + } + return out +} + +func resolveSemanticExistingTraderDependency(currentRef *EntityReference, options []traderSkillOption) targetResolution { + if currentRef != nil && strings.TrimSpace(currentRef.ID) != "" { + for _, opt := range options { + if opt.ID == currentRef.ID { + return targetResolution{Ref: &EntityReference{ID: opt.ID, Name: opt.Name}} + } + } + } + enabled := enabledTraderSkillOptions(options) + if len(enabled) == 1 { + return targetResolution{Ref: &EntityReference{ID: enabled[0].ID, Name: enabled[0].Name}} + } + if len(enabled) > 1 { + return targetResolution{Ambiguous: enabled} + } + return targetResolution{} +} + +func (a *Agent) hydrateCreateTraderSlotReferences(storeUserID string, session *skillSession) { + if session == nil { + return + } + if fieldValue(*session, "exchange_id") == "" && fieldValue(*session, "exchange_name") != "" { + options := a.loadExchangeOptions(storeUserID) + if opt := findOptionByIDOrName(options, fieldValue(*session, "exchange_name")); opt != nil { + setField(session, "exchange_id", opt.ID) + } else if opt := findUniqueContainingOption(options, fieldValue(*session, "exchange_name")); opt != nil { + setField(session, "exchange_id", opt.ID) + } + } + if fieldValue(*session, "exchange_id") != "" { + options := a.loadExchangeOptions(storeUserID) + if opt := findOptionByIDOrName(options, fieldValue(*session, "exchange_id")); opt != nil { + setField(session, "exchange_id", opt.ID) + if fieldValue(*session, "exchange_name") == "" { + setField(session, "exchange_name", opt.Name) + } + } + } + if fieldValue(*session, "model_id") == "" && fieldValue(*session, "model_name") != "" { + options := a.loadEnabledModelOptions(storeUserID) + if opt := findOptionByIDOrName(options, fieldValue(*session, "model_name")); opt != nil { + setField(session, "model_id", opt.ID) + } else if opt := findUniqueContainingOption(options, fieldValue(*session, "model_name")); opt != nil { + setField(session, "model_id", opt.ID) + } + } + if fieldValue(*session, "model_id") != "" { + options := a.loadEnabledModelOptions(storeUserID) + if opt := findOptionByIDOrName(options, fieldValue(*session, "model_id")); opt != nil { + setField(session, "model_id", opt.ID) + if fieldValue(*session, "model_name") == "" { + setField(session, "model_name", opt.Name) + } + } + } + if fieldValue(*session, "strategy_id") == "" && fieldValue(*session, "strategy_name") != "" { + options := a.loadStrategyOptions(storeUserID) + if opt := findOptionByIDOrName(options, fieldValue(*session, "strategy_name")); opt != nil { + setField(session, "strategy_id", opt.ID) + } else if opt := findUniqueContainingOption(options, fieldValue(*session, "strategy_name")); opt != nil { + setField(session, "strategy_id", opt.ID) + } + } + if fieldValue(*session, "strategy_id") != "" { + options := a.loadStrategyOptions(storeUserID) + if opt := findOptionByIDOrName(options, fieldValue(*session, "strategy_id")); opt != nil { + setField(session, "strategy_id", opt.ID) + if fieldValue(*session, "strategy_name") == "" { + setField(session, "strategy_name", opt.Name) + } + } + } +} + +func (a *Agent) maybeResumeParentTaskAfterSuccessfulSkill(storeUserID string, userID int64, lang, skill, action, answer string) string { + sm := a.SnapshotManager(userID) + parent, ok := sm.Peek() + if !ok || !parent.ResumeOnSuccess { + return answer + } + triggered := false + for _, t := range parent.ResumeTriggers { + if t == skill { + triggered = true + break + } + } + if !triggered { + return answer + } + sm.Load() // pop + // restore parent history + if a.history != nil && len(parent.LocalHistory) > 0 { + a.history.Replace(userID, parent.LocalHistory) + } + // inject child result as system message + if a.history != nil && strings.TrimSpace(answer) != "" { + inject := fmt.Sprintf("[子任务 %s/%s 已完成,结果:%s]", skill, action, answer) + a.history.Add(userID, "system", inject) + } + // restore parent skill session + if parent.SkillSession != nil { + restored := *parent.SkillSession + a.hydrateCreateTraderSlotReferences(storeUserID, &restored) + a.saveSkillSession(userID, restored) + resumeNotice := "" + if lang == "zh" { + resumeNotice = "我已经切回刚才的主任务。" + } else { + resumeNotice = "I switched back to the earlier main task." + } + if restored.Name == "trader_management" && restored.Action == "create" { + followup := a.buildTraderCreateMissingPrompt(storeUserID, lang, restored, a.buildTraderCreateConversationResources(storeUserID, restored)) + if strings.TrimSpace(followup) != "" { + if strings.TrimSpace(answer) == "" { + return resumeNotice + "\n" + followup + } + return strings.TrimSpace(answer) + "\n" + resumeNotice + "\n" + followup + } + } + if strings.TrimSpace(answer) == "" { + return resumeNotice + } + return strings.TrimSpace(answer) + "\n" + resumeNotice + } + return answer +} + +func resolveTargetSelection(text string, options []traderSkillOption, existing *EntityReference) targetResolution { + if existing != nil && strings.TrimSpace(existing.ID) != "" { + for _, opt := range options { + if opt.ID == existing.ID { + return targetResolution{Ref: &EntityReference{ID: opt.ID, Name: defaultIfEmpty(opt.Name, existing.Name), Source: existing.Source}} + } + } + } + if existing != nil && strings.TrimSpace(existing.Name) != "" { + if opt := findOptionByIDOrName(options, existing.Name); opt != nil { + return targetResolution{Ref: &EntityReference{ID: opt.ID, Name: opt.Name, Source: existing.Source}} + } + if opt := findUniqueContainingOption(options, existing.Name); opt != nil { + return targetResolution{Ref: &EntityReference{ID: opt.ID, Name: opt.Name, Source: existing.Source}} + } + } + if opt := findOptionByIDOrName(options, text); opt != nil { + return targetResolution{Ref: &EntityReference{ID: opt.ID, Name: opt.Name, Source: "user_mention"}} + } + if opt := findUniqueContainingOption(options, text); opt != nil { + return targetResolution{Ref: &EntityReference{ID: opt.ID, Name: opt.Name, Source: "user_mention"}} + } + if len(options) > 1 { + return targetResolution{Ambiguous: options} + } + return targetResolution{} +} + +func findOptionByIDOrName(options []traderSkillOption, query string) *traderSkillOption { + query = strings.TrimSpace(query) + if query == "" { + return nil + } + for i, opt := range options { + if opt.ID == query || strings.EqualFold(opt.Name, query) || strings.EqualFold(opt.Hint, query) { + return &options[i] + } + } + return nil +} + +func findUniqueContainingOption(options []traderSkillOption, query string) *traderSkillOption { + query = strings.ToLower(strings.TrimSpace(query)) + if query == "" { + return nil + } + matches := make([]traderSkillOption, 0, 1) + for _, opt := range options { + name := strings.ToLower(strings.TrimSpace(opt.Name)) + hint := strings.ToLower(strings.TrimSpace(opt.Hint)) + id := strings.ToLower(strings.TrimSpace(opt.ID)) + if (name != "" && (strings.Contains(name, query) || strings.Contains(query, name))) || + (hint != "" && (strings.Contains(hint, query) || strings.Contains(query, hint))) || + (id != "" && (strings.Contains(id, query) || strings.Contains(query, id))) { + matches = append(matches, opt) + } + } + if len(matches) != 1 { + return nil + } + return &matches[0] +} + +func formatAmbiguousTargetPrompt(lang string, options []traderSkillOption) string { + if duplicateName, ok := sharedAmbiguousOptionName(options); ok { + if lang == "zh" { + return fmt.Sprintf("你提到的是“%s”,但当前有 %d 个同名对象。请告诉我你要操作哪一个。\n%s", duplicateName, len(options), formatDisambiguationOptionList("可选对象:", options)) + } + return fmt.Sprintf("You mentioned %q, but there are %d objects with the same name. Please tell me which one to operate on.\n%s", duplicateName, len(options), formatDisambiguationOptionList("Available targets:", options)) + } + if lang == "zh" { + return "找到多个匹配对象,请告诉我你要操作哪一个。\n" + formatDisambiguationOptionList("可选对象:", options) + } + return "Multiple matches found. Please tell me which one to operate on.\n" + formatDisambiguationOptionList("Available targets:", options) +} + +func sharedAmbiguousOptionName(options []traderSkillOption) (string, bool) { + if len(options) < 2 { + return "", false + } + base := strings.TrimSpace(options[0].Name) + if base == "" { + return "", false + } + for _, option := range options[1:] { + if !strings.EqualFold(strings.TrimSpace(option.Name), base) { + return "", false + } + } + return base, true +} + +func formatDisambiguationOptionList(prefix string, options []traderSkillOption) string { + parts := make([]string, 0, len(options)) + for _, option := range options { + label := strings.TrimSpace(option.Name) + if label == "" { + label = option.ID + } + if hint := strings.TrimSpace(option.Hint); hint != "" { + label += "(" + hint + ")" + } + if suffix := shortOptionIDSuffix(option.ID); suffix != "" { + label += fmt.Sprintf("(ID后缀 %s)", suffix) + } + if option.Enabled { + label += "(已启用)" + } else { + label += "(已禁用)" + } + parts = append(parts, label) + } + if len(parts) == 0 { + return "" + } + return prefix + strings.Join(parts, "、") +} + +func shortOptionIDSuffix(id string) string { + id = strings.TrimSpace(id) + if id == "" { + return "" + } + runes := []rune(id) + if len(runes) <= 4 { + return id + } + return string(runes[len(runes)-4:]) +} diff --git a/agent/skill_dispatcher_test.go b/agent/skill_dispatcher_test.go deleted file mode 100644 index bb292156..00000000 --- a/agent/skill_dispatcher_test.go +++ /dev/null @@ -1,828 +0,0 @@ -package agent - -import ( - "context" - "encoding/json" - "errors" - "strings" - "testing" - "time" - - "nofx/mcp" -) - -func TestCreateTraderSkillCollectsMissingFieldsAndCreatesTrader(t *testing.T) { - a := newTestAgentWithStore(t) - - modelResp := a.toolManageModelConfig("user-1", `{ - "action":"create", - "provider":"deepseek", - "enabled":true, - "api_key":"sk-test", - "custom_api_url":"https://api.deepseek.com/v1", - "custom_model_name":"deepseek-chat" - }`) - if strings.Contains(modelResp, `"error"`) { - t.Fatalf("failed to create model: %s", modelResp) - } - exchangeResp := a.toolManageExchangeConfig("user-1", `{ - "action":"create", - "exchange_type":"okx", - "account_name":"主账户", - "enabled":true - }`) - if strings.Contains(exchangeResp, `"error"`) { - t.Fatalf("failed to create exchange: %s", exchangeResp) - } - strategyResp := a.toolManageStrategy("user-1", `{ - "action":"create", - "name":"趋势策略", - "lang":"zh" - }`) - if strings.Contains(strategyResp, `"error"`) { - t.Fatalf("failed to create strategy: %s", strategyResp) - } - - resp, err := a.thinkAndAct(context.Background(), "user-1", 1, "zh", "帮我创建一个交易员") - if err != nil { - t.Fatalf("thinkAndAct() error = %v", err) - } - if !strings.Contains(resp, "还缺这些信息") || !strings.Contains(resp, "名称") { - t.Fatalf("expected missing-field prompt, got %q", resp) - } - - resp, err = a.thinkAndAct(context.Background(), "user-1", 1, "zh", "叫 波段一号") - if err != nil { - t.Fatalf("thinkAndAct() second turn error = %v", err) - } - if !strings.Contains(resp, "已创建交易员") || !strings.Contains(resp, "波段一号") { - t.Fatalf("expected trader creation confirmation, got %q", resp) - } - - listResp := a.toolListTraders("user-1") - if !strings.Contains(listResp, "波段一号") { - t.Fatalf("expected created trader in list, got %s", listResp) - } -} - -func TestCreateTraderSkillReportsAllMissingPrerequisitesAtOnce(t *testing.T) { - a := newTestAgentWithStore(t) - - resp, err := a.thinkAndAct(context.Background(), "user-1", 11, "zh", "帮我创建一个交易员") - if err != nil { - t.Fatalf("thinkAndAct() error = %v", err) - } - for _, want := range []string{"名称", "交易所", "模型", "策略"} { - if !strings.Contains(resp, want) { - t.Fatalf("expected response to mention %q, got %q", want, resp) - } - } - for _, want := range []string{"当前还没有可用交易所配置", "当前还没有可用模型配置", "当前还没有可用策略"} { - if !strings.Contains(resp, want) { - t.Fatalf("expected response to mention prerequisite %q, got %q", want, resp) - } - } -} - -func TestActiveSkillSessionYieldsToNewTopic(t *testing.T) { - a := newTestAgentWithStore(t) - - _ = a.toolManageStrategy("user-1", `{ - "action":"create", - "name":"测试策略", - "lang":"zh" - }`) - - resp, err := a.thinkAndAct(context.Background(), "user-1", 13, "zh", "帮我创建一个交易员") - if err != nil { - t.Fatalf("thinkAndAct() error = %v", err) - } - if !strings.Contains(resp, "还缺这些信息") { - t.Fatalf("expected trader creation flow prompt, got %q", resp) - } - - resp, err = a.thinkAndAct(context.Background(), "user-1", 13, "zh", "列出我当前的策略") - if err != nil { - t.Fatalf("thinkAndAct() interrupt error = %v", err) - } - if !strings.Contains(resp, "当前策略") || !strings.Contains(resp, "测试策略") { - t.Fatalf("expected new topic to be handled, got %q", resp) - } - if a.hasActiveSkillSession(13) { - t.Fatal("expected skill session to be cleared after interruption") - } -} - -func TestCreateTraderSkillRequestsStartConfirmation(t *testing.T) { - a := newTestAgentWithStore(t) - - _ = a.toolManageModelConfig("user-1", `{ - "action":"create", - "provider":"openai", - "enabled":true, - "api_key":"sk-test", - "custom_api_url":"https://api.openai.com/v1", - "custom_model_name":"gpt-5" - }`) - _ = a.toolManageExchangeConfig("user-1", `{ - "action":"create", - "exchange_type":"binance", - "account_name":"Main", - "enabled":true - }`) - _ = a.toolManageStrategy("user-1", `{ - "action":"create", - "name":"保守策略", - "lang":"zh" - }`) - - resp, err := a.thinkAndAct(context.Background(), "user-1", 2, "zh", "创建一个叫“实盘一号”的交易员并启动") - if err != nil { - t.Fatalf("thinkAndAct() error = %v", err) - } - if !strings.Contains(resp, "高风险动作") || !strings.Contains(resp, "确认") { - t.Fatalf("expected start confirmation prompt, got %q", resp) - } - - resp, err = a.thinkAndAct(context.Background(), "user-1", 2, "zh", "先不用") - if err != nil { - t.Fatalf("thinkAndAct() confirmation error = %v", err) - } - if !strings.Contains(resp, "已创建交易员") || strings.Contains(resp, "已创建并启动") { - t.Fatalf("expected create-without-start response, got %q", resp) - } -} - -func TestModelDiagnosisSkillHandledWithoutAIClient(t *testing.T) { - a := newTestAgentWithStore(t) - resp, err := a.thinkAndAct(context.Background(), "user-1", 3, "zh", "为什么我的模型配置失败了") - if err != nil { - t.Fatalf("thinkAndAct() error = %v", err) - } - if !strings.Contains(resp, "模型配置") { - t.Fatalf("expected model diagnosis response, got %q", resp) - } -} - -func TestExchangeDiagnosisSkillHandledWithoutAIClient(t *testing.T) { - a := newTestAgentWithStore(t) - resp, err := a.thinkAndAct(context.Background(), "user-1", 4, "zh", "交易所 API 报 invalid signature 怎么办") - if err != nil { - t.Fatalf("thinkAndAct() error = %v", err) - } - if !strings.Contains(resp, "invalid signature") && !strings.Contains(resp, "签名") { - t.Fatalf("expected exchange diagnosis response, got %q", resp) - } -} - -func TestExchangeManagementCreateAndQuerySkill(t *testing.T) { - a := newTestAgentWithStore(t) - - resp, err := a.thinkAndAct(context.Background(), "user-1", 5, "zh", "帮我创建一个 OKX 交易所配置") - if err != nil { - t.Fatalf("thinkAndAct() error = %v", err) - } - if !strings.Contains(resp, "已创建交易所配置") { - t.Fatalf("expected exchange create response, got %q", resp) - } - - resp, err = a.thinkAndAct(context.Background(), "user-1", 5, "zh", "列出我的交易所配置") - if err != nil { - t.Fatalf("thinkAndAct() query error = %v", err) - } - if !strings.Contains(resp, "当前交易所配置") && !strings.Contains(resp, "Default") { - t.Fatalf("expected exchange query response, got %q", resp) - } -} - -func TestModelManagementCreateSkill(t *testing.T) { - a := newTestAgentWithStore(t) - - resp, err := a.thinkAndAct(context.Background(), "user-1", 6, "zh", "帮我创建一个 DeepSeek 模型配置") - if err != nil { - t.Fatalf("thinkAndAct() error = %v", err) - } - if !strings.Contains(resp, "已创建模型配置") { - t.Fatalf("expected model create response, got %q", resp) - } -} - -func TestStrategyManagementCreateAndActivateSkill(t *testing.T) { - a := newTestAgentWithStore(t) - - resp, err := a.thinkAndAct(context.Background(), "user-1", 7, "zh", "创建一个叫“趋势策略B”的策略") - if err != nil { - t.Fatalf("thinkAndAct() create error = %v", err) - } - if !strings.Contains(resp, "已创建策略") { - t.Fatalf("expected strategy create response, got %q", resp) - } - - resp, err = a.thinkAndAct(context.Background(), "user-1", 7, "zh", "激活趋势策略B") - if err != nil { - t.Fatalf("thinkAndAct() activate error = %v", err) - } - if !strings.Contains(resp, "已激活策略") { - t.Fatalf("expected strategy activate response, got %q", resp) - } -} - -func TestStrategyManagementQueryCanExplainStrategyDetails(t *testing.T) { - a := newTestAgentWithStore(t) - - resp, err := a.thinkAndAct(context.Background(), "user-1", 12, "zh", "创建一个叫“激进的”的策略") - if err != nil { - t.Fatalf("thinkAndAct() create error = %v", err) - } - if !strings.Contains(resp, "已创建策略") { - t.Fatalf("expected strategy create response, got %q", resp) - } - - resp, err = a.thinkAndAct(context.Background(), "user-1", 12, "zh", "这个策略里面的参数和prompt分别是什么样的") - if err != nil { - t.Fatalf("thinkAndAct() detail query error = %v", err) - } - for _, want := range []string{"策略“激进的”概览", "K线周期", "仓位风险", "Prompt"} { - if !strings.Contains(resp, want) { - t.Fatalf("expected response to mention %q, got %q", want, resp) - } - } -} - -func TestTraderManagementQueryAndDiagnosisSkill(t *testing.T) { - a := newTestAgentWithStore(t) - - modelResp := a.toolManageModelConfig("user-1", `{ - "action":"create", - "provider":"openai", - "enabled":true, - "api_key":"sk-test", - "custom_api_url":"https://api.openai.com/v1", - "custom_model_name":"gpt-5" - }`) - var modelCreated struct { - Model safeModelToolConfig `json:"model"` - } - if err := json.Unmarshal([]byte(modelResp), &modelCreated); err != nil { - t.Fatalf("unmarshal model response: %v", err) - } - - exchangeResp := a.toolManageExchangeConfig("user-1", `{ - "action":"create", - "exchange_type":"binance", - "account_name":"Main", - "enabled":true - }`) - var exchangeCreated struct { - Exchange safeExchangeToolConfig `json:"exchange"` - } - if err := json.Unmarshal([]byte(exchangeResp), &exchangeCreated); err != nil { - t.Fatalf("unmarshal exchange response: %v", err) - } - _ = a.toolManageStrategy("user-1", `{ - "action":"create", - "name":"测试策略", - "lang":"zh" - }`) - _ = a.toolManageTrader("user-1", `{ - "action":"create", - "name":"测试交易员", - "ai_model_id":"`+modelCreated.Model.ID+`", - "exchange_id":"`+exchangeCreated.Exchange.ID+`", - "strategy_id":"" - }`) - - resp, err := a.thinkAndAct(context.Background(), "user-1", 8, "zh", "查看我的交易员") - if err != nil { - t.Fatalf("thinkAndAct() query error = %v", err) - } - if !strings.Contains(resp, "当前交易员") && !strings.Contains(resp, "测试交易员") { - t.Fatalf("expected trader query response, got %q", resp) - } - - resp, err = a.thinkAndAct(context.Background(), "user-1", 8, "zh", "为什么我的交易员不交易") - if err != nil { - t.Fatalf("thinkAndAct() diagnosis error = %v", err) - } - if !strings.Contains(resp, "交易员运行诊断") { - t.Fatalf("expected trader diagnosis response, got %q", resp) - } -} - -func TestExchangeManagementAtomicUpdates(t *testing.T) { - a := newTestAgentWithStore(t) - - createResp := a.toolManageExchangeConfig("user-1", `{ - "action":"create", - "exchange_type":"okx", - "account_name":"主账户", - "enabled":true - }`) - var created struct { - Exchange safeExchangeToolConfig `json:"exchange"` - } - if err := json.Unmarshal([]byte(createResp), &created); err != nil { - t.Fatalf("unmarshal exchange response: %v", err) - } - - resp, err := a.thinkAndAct(context.Background(), "user-1", 14, "zh", "更新交易所,把主账户改名为备用账户") - if err != nil { - t.Fatalf("rename exchange error = %v", err) - } - if !strings.Contains(resp, "已更新交易所配置") { - t.Fatalf("expected exchange update response, got %q", resp) - } - - raw := a.toolGetExchangeConfigs("user-1") - if !strings.Contains(raw, "备用账户") { - t.Fatalf("expected renamed exchange in list, got %s", raw) - } - - resp, err = a.thinkAndAct(context.Background(), "user-1", 14, "zh", "禁用这个交易所配置") - if err != nil { - t.Fatalf("disable exchange error = %v", err) - } - if !strings.Contains(resp, "已更新交易所配置") { - t.Fatalf("expected exchange status update response, got %q", resp) - } - - raw = a.toolGetExchangeConfigs("user-1") - if strings.Contains(raw, `"enabled":true`) && strings.Contains(raw, "备用账户") { - t.Fatalf("expected exchange to be disabled, got %s", raw) - } -} - -func TestModelManagementAtomicUpdates(t *testing.T) { - a := newTestAgentWithStore(t) - - createResp := a.toolManageModelConfig("user-1", `{ - "action":"create", - "provider":"deepseek", - "enabled":true, - "custom_api_url":"https://api.deepseek.com/v1", - "custom_model_name":"deepseek-chat" - }`) - var created struct { - Model safeModelToolConfig `json:"model"` - } - if err := json.Unmarshal([]byte(createResp), &created); err != nil { - t.Fatalf("unmarshal model response: %v", err) - } - - resp, err := a.thinkAndAct(context.Background(), "user-1", 15, "zh", "更新模型,把模型名称改成 deepseek-reasoner") - if err != nil { - t.Fatalf("rename model error = %v", err) - } - if !strings.Contains(resp, "已更新模型配置") { - t.Fatalf("expected model update response, got %q", resp) - } - - resp, err = a.thinkAndAct(context.Background(), "user-1", 15, "zh", "更新模型,把接口地址改成 https://api.deepseek.com/beta") - if err != nil { - t.Fatalf("update model endpoint error = %v", err) - } - if !strings.Contains(resp, "已更新模型配置") { - t.Fatalf("expected model endpoint update response, got %q", resp) - } - - resp, err = a.thinkAndAct(context.Background(), "user-1", 15, "zh", "禁用这个模型配置") - if err != nil { - t.Fatalf("disable model error = %v", err) - } - if !strings.Contains(resp, "已更新模型配置") { - t.Fatalf("expected model status update response, got %q", resp) - } - - raw := a.toolGetModelConfigs("user-1") - if !strings.Contains(raw, "deepseek-reasoner") || !strings.Contains(raw, "https://api.deepseek.com/beta") { - t.Fatalf("expected updated model fields, got %s", raw) - } - if strings.Contains(raw, `"enabled":true`) && strings.Contains(raw, created.Model.ID) { - t.Fatalf("expected model to be disabled, got %s", raw) - } -} - -func TestStrategyManagementAtomicUpdates(t *testing.T) { - a := newTestAgentWithStore(t) - - resp, err := a.thinkAndAct(context.Background(), "user-1", 16, "zh", "创建一个叫“激进策略C”的策略") - if err != nil { - t.Fatalf("create strategy error = %v", err) - } - if !strings.Contains(resp, "已创建策略") { - t.Fatalf("expected strategy create response, got %q", resp) - } - - resp, err = a.thinkAndAct(context.Background(), "user-1", 16, "zh", "更新这个策略的prompt,把提示词改成“优先观察BTC和ETH,信号不一致时不要开仓”") - if err != nil { - t.Fatalf("update strategy prompt error = %v", err) - } - if !strings.Contains(resp, "已更新策略 prompt") { - t.Fatalf("expected strategy prompt update response, got %q", resp) - } - - resp, err = a.thinkAndAct(context.Background(), "user-1", 16, "zh", "更新这个策略参数,把最大持仓改成2,最低置信度改成80,主周期改成15m,并使用15m 1h 4h") - if err != nil { - t.Fatalf("update strategy config error = %v", err) - } - if !strings.Contains(resp, "已更新策略参数") { - t.Fatalf("expected strategy config update response, got %q", resp) - } - - listRaw := a.toolGetStrategies("user-1") - if !strings.Contains(listRaw, "优先观察BTC和ETH") || !strings.Contains(listRaw, `"max_positions":2`) || !strings.Contains(listRaw, `"min_confidence":80`) || !strings.Contains(listRaw, `"primary_timeframe":"15m"`) { - t.Fatalf("expected updated strategy config, got %s", listRaw) - } -} - -func TestTraderManagementAtomicBindingUpdate(t *testing.T) { - a := newTestAgentWithStore(t) - - modelOpenAI := a.toolManageModelConfig("user-1", `{ - "action":"create", - "provider":"openai", - "enabled":true, - "custom_api_url":"https://api.openai.com/v1", - "custom_model_name":"gpt-5-mini" - }`) - var openAI struct { - Model safeModelToolConfig `json:"model"` - } - if err := json.Unmarshal([]byte(modelOpenAI), &openAI); err != nil { - t.Fatalf("unmarshal openai model: %v", err) - } - modelDeepSeek := a.toolManageModelConfig("user-1", `{ - "action":"create", - "provider":"deepseek", - "enabled":true, - "custom_api_url":"https://api.deepseek.com/v1", - "custom_model_name":"deepseek-chat" - }`) - var deepSeek struct { - Model safeModelToolConfig `json:"model"` - } - if err := json.Unmarshal([]byte(modelDeepSeek), &deepSeek); err != nil { - t.Fatalf("unmarshal deepseek model: %v", err) - } - - exchangeBinance := a.toolManageExchangeConfig("user-1", `{ - "action":"create", - "exchange_type":"binance", - "account_name":"Binance 主账户", - "enabled":true - }`) - var binance struct { - Exchange safeExchangeToolConfig `json:"exchange"` - } - if err := json.Unmarshal([]byte(exchangeBinance), &binance); err != nil { - t.Fatalf("unmarshal binance exchange: %v", err) - } - exchangeOKX := a.toolManageExchangeConfig("user-1", `{ - "action":"create", - "exchange_type":"okx", - "account_name":"OKX 主账户", - "enabled":true - }`) - var okx struct { - Exchange safeExchangeToolConfig `json:"exchange"` - } - if err := json.Unmarshal([]byte(exchangeOKX), &okx); err != nil { - t.Fatalf("unmarshal okx exchange: %v", err) - } - - strategyA := a.toolManageStrategy("user-1", `{"action":"create","name":"策略A","lang":"zh"}`) - var stA struct { - Strategy safeStrategyToolConfig `json:"strategy"` - } - if err := json.Unmarshal([]byte(strategyA), &stA); err != nil { - t.Fatalf("unmarshal strategy A: %v", err) - } - strategyB := a.toolManageStrategy("user-1", `{"action":"create","name":"策略B","lang":"zh"}`) - var stB struct { - Strategy safeStrategyToolConfig `json:"strategy"` - } - if err := json.Unmarshal([]byte(strategyB), &stB); err != nil { - t.Fatalf("unmarshal strategy B: %v", err) - } - - createTrader := a.toolManageTrader("user-1", `{ - "action":"create", - "name":"实盘一号", - "ai_model_id":"`+openAI.Model.ID+`", - "exchange_id":"`+binance.Exchange.ID+`", - "strategy_id":"`+stA.Strategy.ID+`" - }`) - var trader struct { - Trader safeTraderToolConfig `json:"trader"` - } - if err := json.Unmarshal([]byte(createTrader), &trader); err != nil { - t.Fatalf("unmarshal trader: %v", err) - } - - resp, err := a.thinkAndAct(context.Background(), "user-1", 17, "zh", "更新交易员绑定,把实盘一号换成 deepseek-chat、OKX 主账户 和 策略B") - if err != nil { - t.Fatalf("update trader bindings error = %v", err) - } - if !strings.Contains(resp, "已更新交易员绑定") { - t.Fatalf("expected trader binding update response, got %q", resp) - } - - listRaw := a.toolListTraders("user-1") - if !strings.Contains(listRaw, deepSeek.Model.ID) || !strings.Contains(listRaw, okx.Exchange.ID) || !strings.Contains(listRaw, stB.Strategy.ID) { - t.Fatalf("expected trader bindings to change, got %s", listRaw) - } -} - -func TestStrategyManagementDeleteAllUserStrategies(t *testing.T) { - a := newTestAgentWithStore(t) - - for _, name := range []string{"趋势策略A", "趋势策略B"} { - resp := a.toolManageStrategy("user-1", `{ - "action":"create", - "name":"`+name+`", - "lang":"zh" - }`) - if strings.Contains(resp, `"error"`) { - t.Fatalf("failed to create strategy %q: %s", name, resp) - } - } - - resp, err := a.thinkAndAct(context.Background(), "user-1", 21, "zh", "现在把所有的策略全部删除") - if err != nil { - t.Fatalf("thinkAndAct() bulk delete start error = %v", err) - } - if !strings.Contains(resp, "确认") || !strings.Contains(resp, "全部自定义策略") { - t.Fatalf("expected bulk delete confirmation, got %q", resp) - } - - resp, err = a.thinkAndAct(context.Background(), "user-1", 21, "zh", "确认") - if err != nil { - t.Fatalf("thinkAndAct() bulk delete confirm error = %v", err) - } - if !strings.Contains(resp, "成功删除 2 个") { - t.Fatalf("expected bulk delete success summary, got %q", resp) - } - - listResp := a.toolGetStrategies("user-1") - if strings.Contains(listResp, "趋势策略A") || strings.Contains(listResp, "趋势策略B") { - t.Fatalf("expected created strategies to be deleted, got %s", listResp) - } -} - -func TestCreateTraderSkillRejectsDisabledExchangeWithClearPrompt(t *testing.T) { - a := newTestAgentWithStore(t) - - _ = a.toolManageModelConfig("user-1", `{ - "action":"create", - "provider":"deepseek", - "enabled":true, - "api_key":"sk-test", - "custom_api_url":"https://api.deepseek.com/v1", - "custom_model_name":"deepseek-chat" - }`) - enabledExchange := a.toolManageExchangeConfig("user-1", `{ - "action":"create", - "exchange_type":"okx", - "account_name":"test", - "enabled":true - }`) - if strings.Contains(enabledExchange, `"error"`) { - t.Fatalf("failed to create enabled exchange: %s", enabledExchange) - } - anotherEnabledExchange := a.toolManageExchangeConfig("user-1", `{ - "action":"create", - "exchange_type":"okx", - "account_name":"lky", - "enabled":true - }`) - if strings.Contains(anotherEnabledExchange, `"error"`) { - t.Fatalf("failed to create second enabled exchange: %s", anotherEnabledExchange) - } - disabledExchange := a.toolManageExchangeConfig("user-1", `{ - "action":"create", - "exchange_type":"okx", - "account_name":"new", - "enabled":false - }`) - if strings.Contains(disabledExchange, `"error"`) { - t.Fatalf("failed to create disabled exchange: %s", disabledExchange) - } - _ = a.toolManageStrategy("user-1", `{"action":"create","name":"激进","lang":"zh"}`) - - resp, err := a.thinkAndAct(context.Background(), "user-1", 24, "zh", "给我创建一个trader") - if err != nil { - t.Fatalf("create trader start error = %v", err) - } - if !strings.Contains(resp, "new(已禁用)") { - t.Fatalf("expected disabled exchange to be labelled, got %q", resp) - } - - resp, err = a.thinkAndAct(context.Background(), "user-1", 24, "zh", "名称叫test,交易所用new、策略用激进") - if err != nil { - t.Fatalf("disabled exchange selection error = %v", err) - } - if !strings.Contains(resp, "当前已禁用") { - t.Fatalf("expected disabled exchange warning, got %q", resp) - } -} - -func TestCancelReplyExitsExchangeUpdateFlow(t *testing.T) { - a := newTestAgentWithStore(t) - _ = a.toolManageModelConfig("user-1", `{ - "action":"create", - "provider":"deepseek", - "enabled":true, - "api_key":"sk-test", - "custom_api_url":"https://api.deepseek.com/v1", - "custom_model_name":"deepseek-chat" - }`) - - exchangeResp := a.toolManageExchangeConfig("user-1", `{ - "action":"create", - "exchange_type":"okx", - "account_name":"test", - "enabled":true - }`) - if strings.Contains(exchangeResp, `"error"`) { - t.Fatalf("failed to create exchange: %s", exchangeResp) - } - - resp, err := a.thinkAndAct(context.Background(), "user-1", 25, "zh", "把test这个交易所改一下") - if err != nil { - t.Fatalf("enter exchange update flow error = %v", err) - } - if !strings.Contains(resp, "请告诉我你要改什么") { - t.Fatalf("expected exchange update prompt, got %q", resp) - } - - resp, err = a.thinkAndAct(context.Background(), "user-1", 25, "zh", "不改") - if err != nil { - t.Fatalf("cancel exchange flow error = %v", err) - } - if !strings.Contains(resp, "已取消当前流程") { - t.Fatalf("expected flow cancellation, got %q", resp) - } -} - -func TestClassifySkillSessionInputInterruptsOnDeflection(t *testing.T) { - session := skillSession{Name: "exchange_management", Action: "update"} - a := &Agent{} - - if got := a.classifySkillSessionInput(context.Background(), 0, "zh", session, "你能帮我看下报错吗"); got != "interrupt" { - t.Fatalf("expected diagnosis deflection to interrupt current skill flow, got %q", got) - } - if got := a.classifySkillSessionInput(context.Background(), 0, "zh", session, "换话题了大哥"); got != "cancel" { - t.Fatalf("expected topic shift to cancel current skill flow, got %q", got) - } -} - -type skillSessionClassifierAIClient struct { - lastSystemPrompt string - lastUserPrompt string - response string -} - -func (c *skillSessionClassifierAIClient) SetAPIKey(string, string, string) {} -func (c *skillSessionClassifierAIClient) SetTimeout(time.Duration) {} -func (c *skillSessionClassifierAIClient) CallWithMessages(string, string) (string, error) { - return "", errors.New("unexpected CallWithMessages") -} -func (c *skillSessionClassifierAIClient) CallWithRequest(req *mcp.Request) (string, error) { - if len(req.Messages) > 0 { - c.lastSystemPrompt = req.Messages[0].Content - } - if len(req.Messages) > 1 { - c.lastUserPrompt = req.Messages[1].Content - } - return c.response, nil -} -func (c *skillSessionClassifierAIClient) CallWithRequestStream(*mcp.Request, func(string)) (string, error) { - return "", errors.New("unexpected CallWithRequestStream") -} -func (c *skillSessionClassifierAIClient) CallWithRequestFull(*mcp.Request) (*mcp.LLMResponse, error) { - return nil, errors.New("unexpected CallWithRequestFull") -} - -func TestClassifySkillSessionInputUsesSlotExpectationWithoutLLM(t *testing.T) { - client := &skillSessionClassifierAIClient{response: `{"decision":"interrupt"}`} - a := &Agent{aiClient: client} - session := skillSession{ - Name: "strategy_management", - Action: "update_config", - Fields: map[string]string{ - skillDAGStepField: "resolve_config_value", - "config_field": "min_confidence", - }, - } - - if got := a.classifySkillSessionInput(context.Background(), 0, "zh", session, "70"); got != "continue" { - t.Fatalf("expected numeric slot fill to continue, got %q", got) - } - if client.lastSystemPrompt != "" { - t.Fatalf("expected no LLM call for direct slot expectation, got prompt %q", client.lastSystemPrompt) - } -} - -func TestClassifySkillSessionInputUsesLLMOnlyForAmbiguousDeflection(t *testing.T) { - client := &skillSessionClassifierAIClient{response: `{"decision":"interrupt"}`} - a := &Agent{ - aiClient: client, - history: newChatHistory(10), - } - session := skillSession{ - Name: "exchange_management", - Action: "update", - Fields: map[string]string{ - skillDAGStepField: "collect_account_name", - }, - } - - if got := a.classifySkillSessionInput(context.Background(), 0, "zh", session, "你能帮我看下报错吗"); got != "interrupt" { - t.Fatalf("expected ambiguous deflection to interrupt, got %q", got) - } - if !strings.Contains(client.lastSystemPrompt, "classify one user message while a NOFXi structured management flow is active") { - t.Fatalf("expected LLM classifier prompt, got %q", client.lastSystemPrompt) - } -} - -func TestClassifySkillSessionInputUsesLLMForUnmatchedActiveSessionInput(t *testing.T) { - client := &skillSessionClassifierAIClient{response: `{"decision":"continue"}`} - a := &Agent{ - aiClient: client, - history: newChatHistory(10), - } - session := skillSession{ - Name: "model_management", - Action: "create", - Fields: map[string]string{ - skillDAGStepField: "collect_optional_fields", - "provider": "openai", - }, - } - - if got := a.classifySkillSessionInput(context.Background(), 0, "zh", session, "新增一个"); got != "continue" { - t.Fatalf("expected unmatched active-session input to follow LLM decision, got %q", got) - } - if !strings.Contains(client.lastSystemPrompt, "classify one user message while a NOFXi structured management flow is active") { - t.Fatalf("expected LLM classifier prompt, got %q", client.lastSystemPrompt) - } -} - -func TestStrategyManagementCanDescribeDefaultConfig(t *testing.T) { - a := newTestAgentWithStore(t) - _ = a.toolManageModelConfig("user-1", `{ - "action":"create", - "provider":"deepseek", - "enabled":true, - "api_key":"sk-test", - "custom_api_url":"https://api.deepseek.com/v1", - "custom_model_name":"deepseek-chat" - }`) - - resp, err := a.thinkAndAct(context.Background(), "user-1", 22, "zh", "看一下默认配置") - if err != nil { - t.Fatalf("thinkAndAct() default config error = %v", err) - } - if !strings.Contains(resp, "默认策略模板") || !strings.Contains(resp, "最低置信度") { - t.Fatalf("expected default strategy config response, got %q", resp) - } -} - -func TestStrategyManagementSupportsMultiFieldConfigUpdate(t *testing.T) { - a := newTestAgentWithStore(t) - _ = a.toolManageModelConfig("user-1", `{ - "action":"create", - "provider":"deepseek", - "enabled":true, - "api_key":"sk-test", - "custom_api_url":"https://api.deepseek.com/v1", - "custom_model_name":"deepseek-chat" - }`) - - createResp := a.toolManageStrategy("user-1", `{ - "action":"create", - "name":"趋势策略A", - "lang":"zh" - }`) - if strings.Contains(createResp, `"error"`) { - t.Fatalf("failed to create strategy: %s", createResp) - } - - resp, err := a.thinkAndAct(context.Background(), "user-1", 23, "zh", "把趋势策略A的最小置信度改成70,核心指标都全选") - if err != nil { - t.Fatalf("thinkAndAct() multi-field update error = %v", err) - } - if !strings.Contains(resp, "最小置信度") || !strings.Contains(resp, "EMA") { - t.Fatalf("expected multi-field update confirmation, got %q", resp) - } - - strategiesRaw := a.toolGetStrategies("user-1") - if !strings.Contains(strategiesRaw, `"min_confidence":70`) || - !strings.Contains(strategiesRaw, `"enable_ema":true`) || - !strings.Contains(strategiesRaw, `"enable_macd":true`) || - !strings.Contains(strategiesRaw, `"enable_rsi":true`) || - !strings.Contains(strategiesRaw, `"enable_atr":true`) || - !strings.Contains(strategiesRaw, `"enable_boll":true`) { - t.Fatalf("expected strategy config to include updated confidence and indicators, got %s", strategiesRaw) - } -} diff --git a/agent/skill_domain_context.go b/agent/skill_domain_context.go new file mode 100644 index 00000000..965f694b --- /dev/null +++ b/agent/skill_domain_context.go @@ -0,0 +1,209 @@ +package agent + +import "strings" + +func buildSkillDomainPrimer(lang, skillName string) string { + skillName = strings.TrimSpace(skillName) + if skillName == "" { + return "" + } + switch skillName { + case "model_management": + fields := []string{ + fieldKnowledgeDisplayName("provider", lang), + displayCatalogFieldName("name", lang), + displayCatalogFieldName("api_key", lang), + displayCatalogFieldName("custom_api_url", lang), + displayCatalogFieldName("custom_model_name", lang), + displayCatalogFieldName("enabled", lang), + } + if lang == "zh" { + return strings.Join([]string{ + "### 模型配置领域约束", + "- 当前领域是 AI 模型配置,不是交易所配置。", + "- provider 指模型厂商,不是交易所类型。", + "- 关键字段:" + strings.Join(fields, "、"), + "- 候选 provider:" + modelProviderSummaryList(lang), + "- 推荐 provider:claw402。claw402 是 NOFXi 官方推荐方案,按次付费,使用 Base 链 EVM 钱包 + USDC 支付。", + "- 如果用户不确定选哪个 provider,可以优先推荐 claw402 并说明其优势,但绝不能替用户自动选中 claw402;必须先展示完整 provider 选项并让用户自己选择。", + "- 如果 provider 还没选定,下一步必须先让用户从完整 provider 列表里选一个,不能先收集 API Key、钱包私钥或其他凭证。", + "- 普通 provider(openai/deepseek/claude 等)通常要填 API Key;custom_model_name 和 custom_api_url 可以留空走默认值。", + "- claw402 需要钱包私钥,custom_model_name 留空时默认 deepseek。", + "- blockrun-base / blockrun-sol 走钱包私钥模式,不需要 custom_api_url,custom_model_name 默认 auto。", + }, "\n") + } + return strings.Join([]string{ + "### Model Config Domain Guard", + "- The current domain is AI model configuration, not exchange configuration.", + "- provider means the model vendor, not an exchange venue.", + "- Key fields: " + strings.Join(fields, ", "), + "- Supported providers: " + modelProviderSummaryList(lang), + "- Recommended provider: claw402. claw402 is the NOFXi recommended pay-per-use option that uses a Base chain wallet + USDC.", + "- If the user is unsure which provider to pick, you may recommend claw402 and explain its advantages, but you must not auto-select claw402 for them. Show the full provider options first and let the user choose.", + "- If provider is still missing, the next step must be to ask the user to choose one from the full provider list. Do not ask for an API key, wallet private key, or other credentials before the provider is chosen.", + "- Standard providers (openai/deepseek/claude etc.) usually require an API key; `custom_model_name` and `custom_api_url` can be omitted to use defaults.", + "- claw402 uses a wallet private key and defaults to `deepseek` if `custom_model_name` is omitted.", + "- blockrun-base / blockrun-sol use wallet private keys, do not need `custom_api_url`, and default to `auto`.", + }, "\n") + case "exchange_management": + fields := []string{ + slotDisplayName("exchange_type", lang), + displayCatalogFieldName("account_name", lang), + displayCatalogFieldName("api_key", lang), + displayCatalogFieldName("secret_key", lang), + displayCatalogFieldName("passphrase", lang), + displayCatalogFieldName("enabled", lang), + } + if lang == "zh" { + return strings.Join([]string{ + "### 交易所配置领域约束", + "- 当前领域是交易所账户配置,不是 AI 模型配置。", + "- exchange_type 指交易所类型,provider 这个词不应用来代指交易所。", + "- 关键字段:" + strings.Join(fields, "、"), + "- 支持的交易所类型:" + strings.Join(enumOptionValues("exchange_management", "exchange_type"), "、"), + }, "\n") + } + return strings.Join([]string{ + "### Exchange Config Domain Guard", + "- The current domain is exchange account configuration, not AI model configuration.", + "- exchange_type means the trading venue. Do not use provider to mean an exchange.", + "- Key fields: " + strings.Join(fields, ", "), + "- Supported exchange types: " + strings.Join(enumOptionValues("exchange_management", "exchange_type"), ", "), + }, "\n") + case "trader_management": + fields := []string{ + slotDisplayName("name", lang), + slotDisplayName("exchange", lang), + slotDisplayName("model", lang), + slotDisplayName("strategy", lang), + displayCatalogFieldName("scan_interval_minutes", lang), + } + if lang == "zh" { + return strings.Join([]string{ + "### 交易员配置领域约束", + "- 交易员是装配层,负责创建、换绑策略/交易所/模型,以及启动、停止、删除、查询。", + "- 编辑交易员时,默认只处理绑定关系;不要顺手改策略、模型、交易所内部配置。", + "- 交易员初始余额由系统在创建时自动读取绑定交易所账户净值,不接受手动设置、充值或人为改余额。", + "- 若用户要改策略参数、模型配置或交易所凭证,应切到对应 management skill。", + "- 创建交易员时最关键的是:名称、交易所、模型、策略。", + "- 关键字段:" + strings.Join(fields, "、"), + }, "\n") + } + return strings.Join([]string{ + "### Trader Config Domain Guard", + "- Traders are the assembly layer: create, rebind strategy/exchange/model, and control lifecycle.", + "- When editing a trader, default to changing bindings only; do not silently edit the internals of the strategy, model, or exchange.", + "- Trader initial balance is auto-read from the bound exchange account equity at creation time; do not ask the user to set, top up, or manually edit trader balance.", + "- If the user wants to change strategy parameters, model config, or exchange credentials, switch to the corresponding management skill.", + "- The key create fields are name, exchange, model, and strategy.", + "- Key fields: " + strings.Join(fields, ", "), + }, "\n") + case "strategy_management": + fields := []string{ + slotDisplayName("name", lang), + displayCatalogFieldName("strategy_type", lang), + } + if lang == "zh" { + return strings.Join([]string{ + "### 策略配置领域约束", + "- 本领域只处理策略模板。", + "- strategy_type 选项:ai_trading、grid_trading。", + "- 用户提到 AI500、OI Top、OI Low、静态币种/固定币种这类选币来源时,属于 ai_trading。", + "- 策略类型确定后,只能使用当前类型的产品编辑页模板。", + "- 策略类型未确定时,只判断类型,不要展示或混合任一分支的具体配置字段。", + "- 关键字段:" + strings.Join(fields, "、"), + }, "\n") + } + return strings.Join([]string{ + "### Strategy Config Domain Guard", + "- This domain only handles strategy templates.", + "- strategy_type options: ai_trading, grid_trading.", + "- AI500, OI Top, OI Low, and static coin-source requests imply ai_trading.", + "- Once strategy_type is known, use only that product editor template.", + "- Before strategy_type is known, only determine the type; do not show or mix concrete fields from either branch.", + "- Key fields: " + strings.Join(fields, ", "), + }, "\n") + default: + return "" + } +} + +func buildSkillDomainPrimerForSession(lang string, session skillSession) string { + if session.Name != "strategy_management" { + return buildSkillDomainPrimer(lang, session.Name) + } + strategyType := explicitStrategyCreateType(session) + if strategyType == "" { + return buildSkillDomainPrimer(lang, session.Name) + } + if lang == "zh" { + switch strategyType { + case "ai_trading": + return strings.Join([]string{ + "### AI 策略模板", + "- 只使用 ai_trading 模板:strategy_type + ai_config + publish_config。", + "- config_patch 必须使用产品 schema 原值,不要使用展示文案:strategy_type=ai_trading;source_type 只能是 static、ai500、oi_top、oi_low;没有 mixed/混合模式。", + "- 时间周期必须输出为产品枚举字符串,例如 1m、3m、5m、15m、1h;selected_timeframes 必须是字符串数组,例如 [\"1m\",\"5m\",\"15m\"],不要输出 JSON 字符串。", + "- AI500/OI Top/OI Low 选币数量范围 1~10;static_coins 最多 10 个;selected_timeframes 最多 4 个;primary_count 10~30。", + "- BTC/ETH 最大杠杆 1~20;山寨币最大杠杆 1~20;min_confidence 50~100;min_risk_reward_ratio 1~10。", + "- AI 策略创建方案不要展示或询问非 AI 模板字段:投入金额、每笔固定投入、止损、日亏损限制、最大回撤、网格字段。", + }, "\n") + case "grid_trading": + return strings.Join([]string{ + "### 网格策略模板", + "- 只使用 grid_trading 模板:strategy_type + grid_config + publish_config;config_patch 必须使用产品 schema 原值,strategy_type=grid_trading。", + "- 交易对选项:BTCUSDT、ETHUSDT、SOLUSDT、BNBUSDT、XRPUSDT、DOGEUSDT。", + "- grid_count 5~50;total_investment 最小 100;leverage 1~5;atr_multiplier 1~5。", + "- total_investment 是用户实际投入/保证金预算,不是杠杆后的名义仓位;最大名义仓位约等于 total_investment × leverage。用户说“投入/总投入/本金/保证金”时默认映射到 total_investment。", + "- max_drawdown_pct 5~50;stop_loss_pct 1~20;daily_loss_limit_pct 1~30;direction_bias_ratio 0.55~0.90。", + "- 没有实时行情工具结果时,不要猜当前价格或手动价格上下界;推荐 use_atr_bounds=true 的 ATR 自动边界。", + "- 如果用户让你选择/推荐剩余网格参数,价格区间默认写入 use_atr_bounds=true;不要反问用户手动价格区间,也不要编造“当前 BTC/ETH 在某价附近”。", + }, "\n") + } + } + switch strategyType { + case "ai_trading": + return strings.Join([]string{ + "### AI Strategy Template", + "- Use only ai_trading: strategy_type + ai_config + publish_config.", + "- config_patch must use product schema raw values, not display labels: strategy_type=ai_trading; source_type is only static, ai500, oi_top, or oi_low; no mixed mode.", + "- Timeframes must be product enum strings such as 1m, 3m, 5m, 15m, 1h; selected_timeframes must be a JSON string array such as [\"1m\",\"5m\",\"15m\"], not a JSON-encoded string.", + "- AI500/OI source counts 1-10; static_coins at most 10; selected_timeframes at most 4; primary_count 10-30.", + "- BTC/ETH leverage 1-20; altcoin leverage 1-20; min_confidence 50-100; min_risk_reward_ratio 1-10.", + "- Do not show or ask for non-AI-template fields in AI strategy drafts: investment amount, fixed per-trade amount, stop loss, daily loss limit, max drawdown, or grid fields.", + }, "\n") + case "grid_trading": + return strings.Join([]string{ + "### Grid Strategy Template", + "- Use only grid_trading: strategy_type + grid_config + publish_config; config_patch must use product schema raw values with strategy_type=grid_trading.", + "- Symbol options: BTCUSDT, ETHUSDT, SOLUSDT, BNBUSDT, XRPUSDT, DOGEUSDT.", + "- grid_count 5-50; total_investment >=100; leverage 1-5; atr_multiplier 1-5.", + "- total_investment is the user's actual capital/margin budget, not leveraged notional exposure; maximum notional exposure is approximately total_investment * leverage. When the user says investment, capital, amount to put in, or margin, map it to total_investment by default.", + "- max_drawdown_pct 5-50; stop_loss_pct 1-20; daily_loss_limit_pct 1-30; direction_bias_ratio 0.55-0.90.", + "- Without fresh market data, do not guess the current price or manual upper/lower prices; recommend ATR auto bounds with use_atr_bounds=true.", + "- If the user asks you to choose/recommend the remaining grid parameters, default the price range to use_atr_bounds=true; do not ask for manual price bounds or invent statements like the current BTC/ETH price is near a value.", + }, "\n") + } + return buildSkillDomainPrimer(lang, session.Name) +} + +func buildManagementDomainPrimer(lang string) string { + if lang == "zh" { + return strings.Join([]string{ + "### 管理领域路由速记", + "- 模型/API Key/provider:model_management。", + "- 交易所账户/API 凭证:exchange_management。", + "- 交易员创建、启动、停止、绑定策略/模型/交易所:trader_management。", + "- 策略模板创建、查看、修改、删除、激活、复制:strategy_management。", + "- 这里只用于路由;具体字段和模板只在进入对应 skill 后注入。", + }, "\n") + } + return strings.Join([]string{ + "### Management Routing Cheat Sheet", + "- Model/API key/provider: model_management.", + "- Exchange account/API credentials: exchange_management.", + "- Trader create/start/stop/bind strategy/model/exchange: trader_management.", + "- Strategy template create/query/update/delete/activate/duplicate: strategy_management.", + "- This is only for routing; detailed fields/templates are injected after entering the selected skill.", + }, "\n") +} diff --git a/agent/skill_execution_handlers.go b/agent/skill_execution_handlers.go index 791e69fa..bed12c7b 100644 --- a/agent/skill_execution_handlers.go +++ b/agent/skill_execution_handlers.go @@ -1,112 +1,610 @@ package agent import ( + "context" "encoding/json" "fmt" "regexp" - "sort" "strconv" "strings" + "time" + "nofx/mcp" "nofx/store" ) var ( firstIntegerPattern = regexp.MustCompile(`\d+`) + firstFloatPattern = regexp.MustCompile(`\d+(?:\.\d+)?`) timeframeTokenRE = regexp.MustCompile(`(?i)\b\d{1,2}[mhdw]\b`) + coinSymbolTokenRE = regexp.MustCompile(`(?i)^(?:xyz:)?[a-z0-9._-]{2,20}(?:usdt|usd|-usdc)?$`) + quotedContentRE = regexp.MustCompile(`[“"]([^“”"]{1,200})[”"]`) ) -func parseStandaloneInteger(text string) (int, bool) { - match := firstIntegerPattern.FindString(strings.TrimSpace(text)) - if match == "" { - return 0, false - } - value, err := strconv.Atoi(match) - if err != nil { - return 0, false - } - return value, true +const ( + strategyPendingUpdateConfigField = "_pending_strategy_update_config" + strategyPendingUpdateWarnings = "_pending_strategy_update_warnings" + strategyPendingUpdateZhMsg = "_pending_strategy_update_zh_msg" + strategyPendingUpdateEnMsg = "_pending_strategy_update_en_msg" +) + +func generatedDraftRequiresConfirmation(session skillSession) bool { + return fieldValue(session, "_requires_generated_confirmation") == "true" } -func parseEnabledValue(text string) (bool, bool) { +func clearGeneratedDraftConfirmation(session *skillSession, keys ...string) { + if session == nil || session.Fields == nil { + return + } + delete(session.Fields, "_requires_generated_confirmation") + for _, key := range keys { + if strings.TrimSpace(key) != "" { + delete(session.Fields, key) + } + } +} + +func detectCatalogField(text string, catalog []entityFieldMeta) string { lower := strings.ToLower(strings.TrimSpace(text)) - switch { - case containsAny(lower, []string{"启用", "打开", "开启", "enable", "enabled", "on"}): - return true, true - case containsAny(lower, []string{"禁用", "关闭", "停用", "disable", "disabled", "off"}): - return false, true + if lower == "" { + return "" + } + if strings.Contains(lower, "api key index") || strings.Contains(lower, "lighter api key index") { + for _, meta := range catalog { + if meta.Key == "lighter_api_key_index" { + return meta.Key + } + } + } + bestKey := "" + bestLen := -1 + for _, meta := range catalog { + for _, keyword := range meta.Keywords { + normalized := strings.ToLower(strings.TrimSpace(keyword)) + if normalized == "" { + continue + } + if entityFieldExplicitlyMentioned(lower, []string{normalized}) && len([]rune(normalized)) > bestLen { + bestKey = meta.Key + bestLen = len([]rune(normalized)) + } + } + } + return bestKey +} + +func displayCatalogFieldName(field, lang string) string { + switch field { + case "name": + if lang == "zh" { + return "名称" + } + return "name" + case "ai_model_id": + if lang == "zh" { + return "模型" + } + return "model" + case "exchange_id": + if lang == "zh" { + return "交易所" + } + return "exchange" + case "strategy_id": + if lang == "zh" { + return "策略" + } + return "strategy" + case "initial_balance": + if lang == "zh" { + return "初始资金" + } + return "initial balance" + case "scan_interval_minutes": + if lang == "zh" { + return "扫描间隔" + } + return "scan interval" + case "is_cross_margin": + if lang == "zh" { + return "全仓模式" + } + return "cross margin" + case "show_in_competition": + if lang == "zh" { + return "竞技场显示" + } + return "show in competition" + case "enabled": + if lang == "zh" { + return "启用状态" + } + return "enabled state" + case "api_key": + return "API Key" + case "custom_api_url": + if lang == "zh" { + return "接口地址" + } + return "API URL" + case "custom_model_name": + if lang == "zh" { + return "模型名称" + } + return "model name" + case "account_name": + if lang == "zh" { + return "账户名" + } + return "account name" + case "exchange_type": + if lang == "zh" { + return "交易所类型" + } + return "exchange type" + case "secret_key": + return "Secret" + case "passphrase": + return "Passphrase" + case "testnet": + if lang == "zh" { + return "测试网" + } + return "testnet" + case "hyperliquid_wallet_addr": + if lang == "zh" { + return "Hyperliquid 钱包地址" + } + return "Hyperliquid wallet address" + case "hyperliquid_unified_account": + if lang == "zh" { + return "Hyperliquid Unified Account" + } + return "Hyperliquid unified account" + case "aster_user": + if lang == "zh" { + return "Aster User" + } + return "Aster user" + case "aster_signer": + if lang == "zh" { + return "Aster Signer" + } + return "Aster signer" + case "aster_private_key": + if lang == "zh" { + return "Aster 私钥" + } + return "Aster private key" + case "lighter_wallet_addr": + if lang == "zh" { + return "Lighter 钱包地址" + } + return "Lighter wallet address" + case "lighter_private_key": + if lang == "zh" { + return "Lighter 私钥" + } + return "Lighter private key" + case "lighter_api_key_private_key": + if lang == "zh" { + return "Lighter API Key 私钥" + } + return "Lighter API key private key" + case "lighter_api_key_index": + if lang == "zh" { + return "Lighter API Key Index" + } + return "Lighter API key index" default: - return false, false + if lang == "zh" { + return field + } + return field } } -func parseFlagValue(text string, keywords []string) (bool, bool) { - lower := strings.ToLower(strings.TrimSpace(text)) - if lower == "" || !containsAny(lower, keywords) { - return false, false - } - switch { - case containsAny(lower, []string{"启用", "打开", "开启", "使用", "用", "是", "true", "enable", "enabled", "on", "use"}): - return true, true - case containsAny(lower, []string{"禁用", "关闭", "停用", "不用", "不要", "否", "false", "disable", "disabled", "off", "don't use", "do not use"}): - return false, true - default: - return false, false - } -} - -func extractCredentialValue(text string, keywords []string) string { - if value := extractQuotedContent(text); value != "" && containsAny(strings.ToLower(text), keywords) { - return value - } - return extractPostKeywordName(text, keywords) -} - -func parseScanIntervalMinutes(text string) (int, bool) { - if value, ok := extractLabeledInt(text, []string{"扫描间隔", "扫描频率", "scan interval", "scan frequency"}); ok { - return value, true - } - lower := strings.ToLower(strings.TrimSpace(text)) - if !containsAny(lower, []string{"扫描间隔", "扫描频率", "scan interval", "scan frequency"}) { - return 0, false - } - return parseStandaloneInteger(text) -} - -func detectStrategyConfigField(text string) string { +func detectCatalogDomainFromText(text string) string { lower := strings.ToLower(strings.TrimSpace(text)) switch { - case containsAny(lower, []string{"最大持仓", "最多持仓", "max positions"}): - return "max_positions" - case containsAny(lower, []string{"最低置信度", "最小置信度", "min confidence"}): - return "min_confidence" - case containsAny(lower, []string{"btc/eth杠杆", "btc eth杠杆", "btc eth leverage", "btc/eth leverage", "主流币杠杆"}): - return "btceth_max_leverage" - case containsAny(lower, []string{"山寨币杠杆", "altcoin leverage", "alts leverage"}): - return "altcoin_max_leverage" - case containsAny(lower, []string{"ema"}): - return "enable_ema" - case containsAny(lower, []string{"macd"}): - return "enable_macd" - case containsAny(lower, []string{"rsi"}): - return "enable_rsi" - case containsAny(lower, []string{"atr"}): - return "enable_atr" - case containsAny(lower, []string{"boll", "bollinger", "布林"}): - return "enable_boll" - case containsAny(lower, []string{"核心指标"}) && containsAny(lower, []string{"全选", "全部", "全开", "都开", "都启用", "全部启用"}): - return "enable_all_core_indicators" - case containsAny(lower, []string{"主周期", "主时间周期", "primary timeframe"}): - return "primary_timeframe" - case containsAny(lower, []string{"多周期", "时间框架", "timeframes", "selected timeframes"}): - return "selected_timeframes" + case containsAny(lower, []string{"策略", "strategy"}): + return "strategy_management" + case containsAny(lower, []string{"交易所", "exchange"}): + return "exchange_management" + case containsAny(lower, []string{"模型", "model"}): + return "model_management" default: return "" } } +func (a *Agent) executeAtomicSkillWithSession(storeUserID string, userID int64, lang, text string, session skillSession) string { + if answer, ok := a.dispatchBridgedSkillSession(storeUserID, userID, lang, text, session); ok { + return answer + } + return "" +} + +func parseLooseTextValue(text string) string { + return "" +} + +func entityFieldExplicitlyMentioned(text string, keywords []string) bool { + if len(keywords) == 0 { + return false + } + return containsAny(strings.ToLower(text), keywords) +} + +type traderUpdateArgs struct { + AIModelID string + ExchangeID string + StrategyID string + ScanIntervalMinutes *int + IsCrossMargin *bool + ShowInCompetition *bool +} + +func (a traderUpdateArgs) hasAny() bool { + return a.AIModelID != "" || a.ExchangeID != "" || a.StrategyID != "" || + a.ScanIntervalMinutes != nil || a.IsCrossMargin != nil || a.ShowInCompetition != nil +} + +func parseStandaloneTraderUpdateArgs(text string) traderUpdateArgs { + return traderUpdateArgs{} +} + +func mergeTraderUpdateArgs(base, patch traderUpdateArgs) traderUpdateArgs { + if patch.AIModelID != "" { + base.AIModelID = patch.AIModelID + } + if patch.ExchangeID != "" { + base.ExchangeID = patch.ExchangeID + } + if patch.StrategyID != "" { + base.StrategyID = patch.StrategyID + } + if patch.ScanIntervalMinutes != nil { + base.ScanIntervalMinutes = patch.ScanIntervalMinutes + } + if patch.IsCrossMargin != nil { + base.IsCrossMargin = patch.IsCrossMargin + } + if patch.ShowInCompetition != nil { + base.ShowInCompetition = patch.ShowInCompetition + } + return base +} + +func applyTraderUpdateArgsToSession(session *skillSession, args traderUpdateArgs) { + if args.AIModelID != "" { + setField(session, "ai_model_id", args.AIModelID) + } + if args.ExchangeID != "" { + setField(session, "exchange_id", args.ExchangeID) + } + if args.StrategyID != "" { + setField(session, "strategy_id", args.StrategyID) + } + if args.ScanIntervalMinutes != nil { + setField(session, "scan_interval_minutes", strconv.Itoa(*args.ScanIntervalMinutes)) + } + if args.IsCrossMargin != nil { + setField(session, "is_cross_margin", strconv.FormatBool(*args.IsCrossMargin)) + } + if args.ShowInCompetition != nil { + setField(session, "show_in_competition", strconv.FormatBool(*args.ShowInCompetition)) + } +} + +func buildTraderUpdateArgsFromSession(session skillSession) traderUpdateArgs { + var args traderUpdateArgs + args.AIModelID = fieldValue(session, "ai_model_id") + args.ExchangeID = fieldValue(session, "exchange_id") + args.StrategyID = fieldValue(session, "strategy_id") + if value := fieldValue(session, "scan_interval_minutes"); value != "" { + if parsed, err := strconv.Atoi(value); err == nil { + args.ScanIntervalMinutes = &parsed + } + } + if value := fieldValue(session, "is_cross_margin"); value != "" { + parsed := value == "true" + args.IsCrossMargin = &parsed + } + if value := fieldValue(session, "show_in_competition"); value != "" { + parsed := value == "true" + args.ShowInCompetition = &parsed + } + return args +} + +type modelUpdatePatch struct { + Enabled *bool + APIKey string + CustomAPIURL string + CustomModelName string +} + +func (p modelUpdatePatch) hasAny() bool { + return p.Enabled != nil || p.APIKey != "" || p.CustomAPIURL != "" || p.CustomModelName != "" +} + +func applyModelUpdatePatchToSession(session *skillSession, patch modelUpdatePatch) { + if patch.CustomAPIURL != "" { + setField(session, "custom_api_url", patch.CustomAPIURL) + } + if patch.Enabled != nil { + setField(session, "enabled", strconv.FormatBool(*patch.Enabled)) + } + if patch.APIKey != "" { + setField(session, "api_key", patch.APIKey) + } + if patch.CustomModelName != "" { + setField(session, "custom_model_name", patch.CustomModelName) + } +} + +func mergeModelUpdatePatch(base, patch modelUpdatePatch) modelUpdatePatch { + if patch.Enabled != nil { + base.Enabled = patch.Enabled + } + if patch.APIKey != "" { + base.APIKey = patch.APIKey + } + if patch.CustomAPIURL != "" { + base.CustomAPIURL = patch.CustomAPIURL + } + if patch.CustomModelName != "" { + base.CustomModelName = patch.CustomModelName + } + return base +} + +func buildModelUpdatePatchFromSession(session skillSession) modelUpdatePatch { + var patch modelUpdatePatch + if value := fieldValue(session, "enabled"); value != "" { + parsed := value == "true" + patch.Enabled = &parsed + } + patch.APIKey = fieldValue(session, "api_key") + patch.CustomAPIURL = fieldValue(session, "custom_api_url") + patch.CustomModelName = fieldValue(session, "custom_model_name") + return patch +} + +type exchangeUpdatePatch struct { + AccountName string + Enabled *bool + APIKey string + SecretKey string + Passphrase string + Testnet *bool + HyperliquidWalletAddr string + AsterUser string + AsterSigner string + AsterPrivateKey string + LighterWalletAddr string + LighterAPIKeyPrivateKey string + LighterAPIKeyIndex *int +} + +func (p exchangeUpdatePatch) hasAny() bool { + return p.AccountName != "" || p.Enabled != nil || p.APIKey != "" || p.SecretKey != "" || + p.Passphrase != "" || p.Testnet != nil || p.HyperliquidWalletAddr != "" || p.AsterUser != "" || + p.AsterSigner != "" || p.AsterPrivateKey != "" || p.LighterWalletAddr != "" || + p.LighterAPIKeyPrivateKey != "" || p.LighterAPIKeyIndex != nil +} + +func applyExchangeUpdatePatchToSession(session *skillSession, patch exchangeUpdatePatch) { + if patch.AccountName != "" { + setField(session, "account_name", patch.AccountName) + } + if patch.Enabled != nil { + setField(session, "enabled", strconv.FormatBool(*patch.Enabled)) + } + if patch.APIKey != "" { + setField(session, "api_key", patch.APIKey) + } + if patch.SecretKey != "" { + setField(session, "secret_key", patch.SecretKey) + } + if patch.Passphrase != "" { + setField(session, "passphrase", patch.Passphrase) + } + if patch.Testnet != nil { + setField(session, "testnet", strconv.FormatBool(*patch.Testnet)) + } + if patch.HyperliquidWalletAddr != "" { + setField(session, "hyperliquid_wallet_addr", patch.HyperliquidWalletAddr) + } + if patch.AsterUser != "" { + setField(session, "aster_user", patch.AsterUser) + } + if patch.AsterSigner != "" { + setField(session, "aster_signer", patch.AsterSigner) + } + if patch.AsterPrivateKey != "" { + setField(session, "aster_private_key", patch.AsterPrivateKey) + } + if patch.LighterWalletAddr != "" { + setField(session, "lighter_wallet_addr", patch.LighterWalletAddr) + } + if patch.LighterAPIKeyPrivateKey != "" { + setField(session, "lighter_api_key_private_key", patch.LighterAPIKeyPrivateKey) + } + if patch.LighterAPIKeyIndex != nil { + setField(session, "lighter_api_key_index", strconv.Itoa(*patch.LighterAPIKeyIndex)) + } +} + +func mergeExchangeUpdatePatch(base, patch exchangeUpdatePatch) exchangeUpdatePatch { + if patch.AccountName != "" { + base.AccountName = patch.AccountName + } + if patch.Enabled != nil { + base.Enabled = patch.Enabled + } + if patch.APIKey != "" { + base.APIKey = patch.APIKey + } + if patch.SecretKey != "" { + base.SecretKey = patch.SecretKey + } + if patch.Passphrase != "" { + base.Passphrase = patch.Passphrase + } + if patch.Testnet != nil { + base.Testnet = patch.Testnet + } + if patch.HyperliquidWalletAddr != "" { + base.HyperliquidWalletAddr = patch.HyperliquidWalletAddr + } + if patch.AsterUser != "" { + base.AsterUser = patch.AsterUser + } + if patch.AsterSigner != "" { + base.AsterSigner = patch.AsterSigner + } + if patch.AsterPrivateKey != "" { + base.AsterPrivateKey = patch.AsterPrivateKey + } + if patch.LighterWalletAddr != "" { + base.LighterWalletAddr = patch.LighterWalletAddr + } + if patch.LighterAPIKeyPrivateKey != "" { + base.LighterAPIKeyPrivateKey = patch.LighterAPIKeyPrivateKey + } + if patch.LighterAPIKeyIndex != nil { + base.LighterAPIKeyIndex = patch.LighterAPIKeyIndex + } + return base +} + +func buildExchangeUpdatePatchFromSession(session skillSession) exchangeUpdatePatch { + var patch exchangeUpdatePatch + patch.AccountName = fieldValue(session, "account_name") + if value := fieldValue(session, "enabled"); value != "" { + parsed := value == "true" + patch.Enabled = &parsed + } + patch.APIKey = fieldValue(session, "api_key") + patch.SecretKey = fieldValue(session, "secret_key") + patch.Passphrase = fieldValue(session, "passphrase") + if value := fieldValue(session, "testnet"); value != "" { + parsed := value == "true" + patch.Testnet = &parsed + } + patch.HyperliquidWalletAddr = fieldValue(session, "hyperliquid_wallet_addr") + patch.AsterUser = fieldValue(session, "aster_user") + patch.AsterSigner = fieldValue(session, "aster_signer") + patch.AsterPrivateKey = fieldValue(session, "aster_private_key") + patch.LighterWalletAddr = fieldValue(session, "lighter_wallet_addr") + patch.LighterAPIKeyPrivateKey = fieldValue(session, "lighter_api_key_private_key") + if value := fieldValue(session, "lighter_api_key_index"); value != "" { + if parsed, err := strconv.Atoi(value); err == nil { + patch.LighterAPIKeyIndex = &parsed + } + } + return patch +} + func strategyConfigFieldDisplayName(field, lang string) string { switch field { + case "name": + if lang == "zh" { + return "名称" + } + return "name" + case "strategy_type": + if lang == "zh" { + return "策略类型" + } + return "strategy type" + case "symbol": + if lang == "zh" { + return "交易对" + } + return "symbol" + case "grid_count": + if lang == "zh" { + return "网格数量" + } + return "grid count" + case "total_investment": + if lang == "zh" { + return "总投资" + } + return "total investment" + case "upper_price": + if lang == "zh" { + return "上沿价格" + } + return "upper price" + case "lower_price": + if lang == "zh" { + return "下沿价格" + } + return "lower price" + case "use_atr_bounds": + if lang == "zh" { + return "ATR 自动边界" + } + return "use ATR bounds" + case "atr_multiplier": + if lang == "zh" { + return "ATR 倍数" + } + return "ATR multiplier" + case "distribution": + if lang == "zh" { + return "分布方式" + } + return "distribution" + case "enable_direction_adjust": + if lang == "zh" { + return "方向自适应" + } + return "enable direction adjust" + case "direction_bias_ratio": + if lang == "zh" { + return "方向偏置比例" + } + return "direction bias ratio" + case "max_drawdown_pct": + if lang == "zh" { + return "最大回撤" + } + return "max drawdown pct" + case "stop_loss_pct": + if lang == "zh" { + return "止损比例" + } + return "stop loss pct" + case "daily_loss_limit_pct": + if lang == "zh" { + return "日亏损限制" + } + return "daily loss limit pct" + case "use_maker_only": + if lang == "zh" { + return "仅 Maker" + } + return "use maker only" + case "description": + if lang == "zh" { + return "描述" + } + return "description" + case "is_public": + if lang == "zh" { + return "发布到市场" + } + return "publish to market" + case "config_visible": + if lang == "zh" { + return "配置可见" + } + return "config visible" case "max_positions": if lang == "zh" { return "最大持仓" @@ -117,6 +615,16 @@ func strategyConfigFieldDisplayName(field, lang string) string { return "最小置信度" } return "min confidence" + case "min_risk_reward_ratio": + if lang == "zh" { + return "最小盈亏比" + } + return "min risk reward ratio" + case "leverage": + if lang == "zh" { + return "杠杆" + } + return "leverage" case "btceth_max_leverage": if lang == "zh" { return "BTC/ETH 最大杠杆" @@ -127,6 +635,26 @@ func strategyConfigFieldDisplayName(field, lang string) string { return "山寨币最大杠杆" } return "altcoin max leverage" + case "btceth_max_position_value_ratio": + if lang == "zh" { + return "BTC/ETH 最大仓位价值倍数" + } + return "BTC/ETH max position value ratio" + case "altcoin_max_position_value_ratio": + if lang == "zh" { + return "山寨币最大仓位价值倍数" + } + return "altcoin max position value ratio" + case "max_margin_usage": + if lang == "zh" { + return "最大保证金使用率" + } + return "max margin usage" + case "min_position_size": + if lang == "zh" { + return "最小开仓金额" + } + return "min position size" case "enable_ema": if lang == "zh" { return "EMA" @@ -167,135 +695,272 @@ func strategyConfigFieldDisplayName(field, lang string) string { return "多周期时间框架" } return "selected timeframes" + case "source_type": + if lang == "zh" { + return "来源类型" + } + return "source type" + case "static_coins": + if lang == "zh" { + return "静态币种" + } + return "static coins" + case "excluded_coins": + if lang == "zh" { + return "排除币种" + } + return "excluded coins" + case "use_ai500": + if lang == "zh" { + return "AI500" + } + return "use AI500" + case "ai500_limit": + if lang == "zh" { + return "AI500 数量" + } + return "AI500 limit" + case "use_oi_top": + if lang == "zh" { + return "OI Top" + } + return "use OI Top" + case "oi_top_limit": + if lang == "zh" { + return "OI Top 数量" + } + return "OI Top limit" + case "use_oi_low": + if lang == "zh" { + return "OI Low" + } + return "use OI Low" + case "oi_low_limit": + if lang == "zh" { + return "OI Low 数量" + } + return "OI Low limit" + case "primary_count": + if lang == "zh" { + return "K线数量" + } + return "kline count" + case "ema_periods": + return "EMA periods" + case "rsi_periods": + return "RSI periods" + case "atr_periods": + return "ATR periods" + case "boll_periods": + return "BOLL periods" + case "enable_volume": + if lang == "zh" { + return "成交量" + } + return "volume" + case "enable_oi": + if lang == "zh" { + return "持仓量" + } + return "OI" + case "enable_funding_rate": + if lang == "zh" { + return "资金费率" + } + return "funding rate" + case "nofxos_api_key": + return "NofxOS API key" + case "enable_quant_data": + if lang == "zh" { + return "量化数据" + } + return "quant data" + case "enable_quant_oi": + return "quant OI" + case "enable_quant_netflow": + return "quant netflow" + case "enable_oi_ranking": + return "OI ranking" + case "oi_ranking_duration": + return "OI ranking duration" + case "oi_ranking_limit": + return "OI ranking limit" + case "enable_netflow_ranking": + return "netflow ranking" + case "netflow_ranking_duration": + return "netflow ranking duration" + case "netflow_ranking_limit": + return "netflow ranking limit" + case "enable_price_ranking": + return "price ranking" + case "price_ranking_duration": + return "price ranking duration" + case "price_ranking_limit": + return "price ranking limit" + case "role_definition": + if lang == "zh" { + return "角色定义" + } + return "role definition" + case "trading_frequency": + if lang == "zh" { + return "交易频率" + } + return "trading frequency" + case "entry_standards": + if lang == "zh" { + return "开仓标准" + } + return "entry standards" + case "decision_process": + if lang == "zh" { + return "决策流程" + } + return "decision process" + case "custom_prompt": + if lang == "zh" { + return "自定义 Prompt" + } + return "custom prompt" default: return field } } -func extractStrategyConfigValue(text, field string) (string, bool) { - switch field { - case "max_positions": - if value, ok := extractLabeledInt(text, []string{"最大持仓", "最多持仓", "max positions"}); ok { - return strconv.Itoa(value), true - } - if value, ok := parseStandaloneInteger(text); ok { - return strconv.Itoa(value), true - } - case "min_confidence": - if value, ok := extractLabeledInt(text, []string{"最低置信度", "最小置信度", "min confidence"}); ok { - return strconv.Itoa(value), true - } - if value, ok := parseStandaloneInteger(text); ok { - return strconv.Itoa(value), true - } - case "btceth_max_leverage": - if value, ok := extractLabeledInt(text, []string{"btc/eth杠杆", "btc eth杠杆", "btc/eth leverage", "btc eth leverage", "主流币杠杆"}); ok { - return strconv.Itoa(value), true - } - if value, ok := parseStandaloneInteger(text); ok { - return strconv.Itoa(value), true - } - case "altcoin_max_leverage": - if value, ok := extractLabeledInt(text, []string{"山寨币杠杆", "altcoin leverage", "alts leverage"}); ok { - return strconv.Itoa(value), true - } - if value, ok := parseStandaloneInteger(text); ok { - return strconv.Itoa(value), true - } - case "enable_ema", "enable_macd", "enable_rsi", "enable_atr", "enable_boll": - if enabled, ok := parseEnabledValue(text); ok { - return strconv.FormatBool(enabled), true - } - case "enable_all_core_indicators": - lower := strings.ToLower(strings.TrimSpace(text)) - switch { - case containsAny(lower, []string{"全选", "全部", "全开", "都开", "都启用", "全部启用"}): - return "true", true - case containsAny(lower, []string{"关闭", "停用", "禁用", "全部关闭", "全部禁用"}): - return "false", true - } - case "primary_timeframe": - if tf := extractTimeframeAfterKeywords(text, []string{"主周期", "主时间周期", "primary timeframe", "timeframe"}); tf != "" { - return tf, true - } - case "selected_timeframes": - if tfs := extractTimeframes(text); len(tfs) > 0 { - return strings.Join(tfs, ","), true - } - } - return "", false -} - -type strategyConfigPatch struct { - Field string - Value string -} - -func detectStrategyConfigPatches(text string) []strategyConfigPatch { - seen := map[string]string{} - addPatch := func(field, value string) { - field = strings.TrimSpace(field) - value = strings.TrimSpace(value) - if field == "" || value == "" { - return - } - seen[field] = value - } - - for _, field := range []string{ - "max_positions", - "min_confidence", - "btceth_max_leverage", - "altcoin_max_leverage", - "primary_timeframe", - "selected_timeframes", - "enable_ema", - "enable_macd", - "enable_rsi", - "enable_atr", - "enable_boll", - "enable_all_core_indicators", - } { - if value, ok := extractStrategyConfigValue(text, field); ok { - if field == "enable_all_core_indicators" { - addPatch("enable_ema", value) - addPatch("enable_macd", value) - addPatch("enable_rsi", value) - addPatch("enable_atr", value) - addPatch("enable_boll", value) - continue - } - addPatch(field, value) - } - } - - fields := make([]string, 0, len(seen)) - for field := range seen { - fields = append(fields, field) - } - sort.Strings(fields) - - patches := make([]strategyConfigPatch, 0, len(fields)) - for _, field := range fields { - patches = append(patches, strategyConfigPatch{Field: field, Value: seen[field]}) - } - return patches -} - func applyStrategyConfigPatch(cfg *store.StrategyConfig, field, value string) error { + ensureGridConfig := func() *store.GridStrategyConfig { + if cfg.GridConfig == nil { + defaults := store.GetDefaultStrategyConfig(cfg.Language) + if defaults.GridConfig != nil { + copy := *defaults.GridConfig + cfg.GridConfig = © + } else { + cfg.GridConfig = &store.GridStrategyConfig{} + } + } + return cfg.GridConfig + } + switch field { - case "max_positions": + case "strategy_type": + cfg.StrategyType = value + case "symbol": + ensureGridConfig().Symbol = value + case "grid_count": parsed, err := strconv.Atoi(value) if err != nil { - return fmt.Errorf("最大持仓需要是整数") + return fmt.Errorf("网格数量需要是整数") } - cfg.RiskControl.MaxPositions = parsed + ensureGridConfig().GridCount = parsed + case "total_investment": + parsed, err := strconv.ParseFloat(value, 64) + if err != nil { + return fmt.Errorf("总投资需要是数字") + } + ensureGridConfig().TotalInvestment = parsed + case "upper_price": + parsed, err := strconv.ParseFloat(value, 64) + if err != nil { + return fmt.Errorf("上沿价格需要是数字") + } + ensureGridConfig().UpperPrice = parsed + case "lower_price": + parsed, err := strconv.ParseFloat(value, 64) + if err != nil { + return fmt.Errorf("下沿价格需要是数字") + } + ensureGridConfig().LowerPrice = parsed + case "use_atr_bounds": + ensureGridConfig().UseATRBounds = value == "true" + case "atr_multiplier": + parsed, err := strconv.ParseFloat(value, 64) + if err != nil { + return fmt.Errorf("ATR 倍数需要是数字") + } + ensureGridConfig().ATRMultiplier = parsed + case "distribution": + ensureGridConfig().Distribution = value + case "enable_direction_adjust": + ensureGridConfig().EnableDirectionAdjust = value == "true" + case "direction_bias_ratio": + parsed, err := strconv.ParseFloat(value, 64) + if err != nil { + return fmt.Errorf("方向偏置比例需要是数字") + } + ensureGridConfig().DirectionBiasRatio = parsed + case "max_drawdown_pct": + parsed, err := strconv.ParseFloat(value, 64) + if err != nil { + return fmt.Errorf("最大回撤需要是数字") + } + ensureGridConfig().MaxDrawdownPct = parsed + case "stop_loss_pct": + parsed, err := strconv.ParseFloat(value, 64) + if err != nil { + return fmt.Errorf("止损比例需要是数字") + } + ensureGridConfig().StopLossPct = parsed + case "daily_loss_limit_pct": + parsed, err := strconv.ParseFloat(value, 64) + if err != nil { + return fmt.Errorf("日亏损限制需要是数字") + } + ensureGridConfig().DailyLossLimitPct = parsed + case "use_maker_only": + ensureGridConfig().UseMakerOnly = value == "true" + case "description", "is_public", "config_visible": + return nil + case "max_positions": + return fmt.Errorf("%s", strategyLockedFieldError("zh", field)) + case "source_type": + cfg.CoinSource.SourceType = value + case "static_coins": + cfg.CoinSource.StaticCoins = cleanStringList(strings.Split(value, ",")) + case "excluded_coins": + cfg.CoinSource.ExcludedCoins = cleanStringList(strings.Split(value, ",")) + case "use_ai500": + cfg.CoinSource.UseAI500 = value == "true" + case "ai500_limit": + parsed, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("AI500 数量需要是整数") + } + cfg.CoinSource.AI500Limit = parsed + case "use_oi_top": + cfg.CoinSource.UseOITop = value == "true" + case "oi_top_limit": + parsed, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("OI Top 数量需要是整数") + } + cfg.CoinSource.OITopLimit = parsed + case "use_oi_low": + cfg.CoinSource.UseOILow = value == "true" + case "oi_low_limit": + parsed, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("OI Low 数量需要是整数") + } + cfg.CoinSource.OILowLimit = parsed case "min_confidence": parsed, err := strconv.Atoi(value) if err != nil { return fmt.Errorf("最小置信度需要是整数") } cfg.RiskControl.MinConfidence = parsed + case "min_risk_reward_ratio": + parsed, err := strconv.ParseFloat(value, 64) + if err != nil { + return fmt.Errorf("最小盈亏比需要是数字") + } + cfg.RiskControl.MinRiskRewardRatio = parsed + case "leverage": + parsed, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("杠杆需要是整数") + } + cfg.RiskControl.BTCETHMaxLeverage = parsed + cfg.RiskControl.AltcoinMaxLeverage = parsed case "btceth_max_leverage": parsed, err := strconv.Atoi(value) if err != nil { @@ -308,12 +973,34 @@ func applyStrategyConfigPatch(cfg *store.StrategyConfig, field, value string) er return fmt.Errorf("山寨币最大杠杆需要是整数") } cfg.RiskControl.AltcoinMaxLeverage = parsed + case "btceth_max_position_value_ratio": + return fmt.Errorf("%s", strategyLockedFieldError("zh", field)) + case "altcoin_max_position_value_ratio": + return fmt.Errorf("%s", strategyLockedFieldError("zh", field)) + case "max_margin_usage": + return fmt.Errorf("%s", strategyLockedFieldError("zh", field)) + case "min_position_size": + return fmt.Errorf("%s", strategyLockedFieldError("zh", field)) case "primary_timeframe": cfg.Indicators.Klines.PrimaryTimeframe = value + case "primary_count": + parsed, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("K线数量需要是整数") + } + cfg.Indicators.Klines.PrimaryCount = parsed case "selected_timeframes": tfs := strings.Split(value, ",") cfg.Indicators.Klines.SelectedTimeframes = tfs cfg.Indicators.Klines.EnableMultiTimeframe = len(tfs) > 1 + case "ema_periods": + cfg.Indicators.EMAPeriods = parseCSVIntegers(value) + case "rsi_periods": + cfg.Indicators.RSIPeriods = parseCSVIntegers(value) + case "atr_periods": + cfg.Indicators.ATRPeriods = parseCSVIntegers(value) + case "boll_periods": + cfg.Indicators.BOLLPeriods = parseCSVIntegers(value) case "enable_ema": cfg.Indicators.EnableEMA = value == "true" case "enable_macd": @@ -324,18 +1011,464 @@ func applyStrategyConfigPatch(cfg *store.StrategyConfig, field, value string) er cfg.Indicators.EnableATR = value == "true" case "enable_boll": cfg.Indicators.EnableBOLL = value == "true" + case "enable_volume": + cfg.Indicators.EnableVolume = value == "true" + case "enable_oi": + cfg.Indicators.EnableOI = value == "true" + case "enable_funding_rate": + cfg.Indicators.EnableFundingRate = value == "true" + case "nofxos_api_key": + cfg.Indicators.NofxOSAPIKey = value + case "enable_quant_data": + cfg.Indicators.EnableQuantData = value == "true" + case "enable_quant_oi": + cfg.Indicators.EnableQuantOI = value == "true" + case "enable_quant_netflow": + cfg.Indicators.EnableQuantNetflow = value == "true" + case "enable_oi_ranking": + cfg.Indicators.EnableOIRanking = value == "true" + case "oi_ranking_duration": + cfg.Indicators.OIRankingDuration = value + case "oi_ranking_limit": + parsed, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("OI 排行数量需要是整数") + } + cfg.Indicators.OIRankingLimit = parsed + case "enable_netflow_ranking": + cfg.Indicators.EnableNetFlowRanking = value == "true" + case "netflow_ranking_duration": + cfg.Indicators.NetFlowRankingDuration = value + case "netflow_ranking_limit": + parsed, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("资金流排行数量需要是整数") + } + cfg.Indicators.NetFlowRankingLimit = parsed + case "enable_price_ranking": + cfg.Indicators.EnablePriceRanking = value == "true" + case "price_ranking_duration": + cfg.Indicators.PriceRankingDuration = value + case "price_ranking_limit": + parsed, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("涨跌幅排行数量需要是整数") + } + cfg.Indicators.PriceRankingLimit = parsed + case "role_definition": + cfg.PromptSections.RoleDefinition = value + case "trading_frequency": + cfg.PromptSections.TradingFrequency = value + case "entry_standards": + cfg.PromptSections.EntryStandards = value + case "decision_process": + cfg.PromptSections.DecisionProcess = value + case "custom_prompt": + cfg.CustomPrompt = value default: return fmt.Errorf("unsupported strategy config field: %s", field) } return nil } -func (a *Agent) executeTraderManagementAction(storeUserID string, userID int64, lang, text string, session skillSession) string { - if session.TargetRef == nil && session.Action != "query" && session.Action != "query_list" && session.Action != "create" { - if lang == "zh" { - return "请先告诉我你要操作哪个交易员。" +func parseSourceTypeValue(text string) string { + lower := strings.ToLower(strings.TrimSpace(text)) + switch { + case containsAny(lower, []string{"静态", "固定", "static"}): + return "static" + case containsAny(lower, []string{"ai500"}): + return "ai500" + case containsAny(lower, []string{"oi top"}): + return "oi_top" + case containsAny(lower, []string{"oi low"}): + return "oi_low" + default: + return "" + } +} + +func extractSymbolList(text string, labels []string) []string { + segment := extractLongSegmentAfterKeywords(text, labels) + if segment == "" { + return nil + } + parts := strings.FieldsFunc(segment, func(r rune) bool { + return r == ',' || r == ',' || r == '、' || r == ' ' || r == '\n' || r == '\t' + }) + out := make([]string, 0, len(parts)) + for _, part := range parts { + if !looksLikeCoinSymbol(part) { + continue } - return "Please specify which trader you want to manage." + part = normalizeCoinSymbol(part) + if part == "" { + continue + } + out = append(out, part) + } + return cleanStringList(out) +} + +func looksLikeCoinSymbol(value string) bool { + value = strings.TrimSpace(value) + if value == "" { + return false + } + value = strings.Trim(value, `"'“”‘’()[]{}<>`) + value = strings.TrimSpace(value) + if value == "" { + return false + } + return coinSymbolTokenRE.MatchString(value) +} + +func normalizeCoinSymbol(symbol string) string { + symbol = strings.TrimSpace(strings.ToUpper(symbol)) + if symbol == "" { + return "" + } + if strings.HasPrefix(symbol, "XYZ:") { + return symbol + } + if strings.HasSuffix(symbol, "USDT") || strings.HasSuffix(symbol, "USD") || strings.HasSuffix(symbol, "-USDC") { + return symbol + } + return symbol + "USDT" +} + +func extractIntegerList(text string) []string { + matches := firstIntegerPattern.FindAllString(text, -1) + if len(matches) == 0 { + return nil + } + return matches +} + +func parseCSVIntegers(value string) []int { + parts := strings.Split(value, ",") + out := make([]int, 0, len(parts)) + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + n, err := strconv.Atoi(part) + if err != nil { + continue + } + out = append(out, n) + } + return out +} + +func extractDurationValue(text string) string { + lower := strings.ToLower(strings.TrimSpace(text)) + switch { + case strings.Contains(lower, "1h,4h,24h"): + return "1h,4h,24h" + case strings.Contains(lower, "24h"): + return "24h" + case strings.Contains(lower, "4h"): + return "4h" + case strings.Contains(lower, "1h"): + return "1h" + default: + return "" + } +} + +func parseStrategyTypeValue(text string) string { + lower := strings.ToLower(strings.TrimSpace(text)) + switch { + case lower == "grid_trading": + return "grid_trading" + case lower == "ai_trading": + return "ai_trading" + case containsAny(lower, []string{"grid", "网格"}): + return "grid_trading" + case containsAny(lower, []string{"ai500", "oi top", "oi low", "静态币", "固定币", "选币来源"}): + return "ai_trading" + case containsAny(lower, []string{"ai trading", "ai策略", "ai 策略", "ai交易", "ai 交易", "ai智能", "智能策略", "普通策略"}): + return "ai_trading" + default: + return "" + } +} + +func extractLongSegmentAfterKeywords(text string, keywords []string) string { + trimmed := strings.TrimSpace(text) + if trimmed == "" { + return "" + } + lower := strings.ToLower(trimmed) + for _, keyword := range keywords { + idx := strings.Index(lower, strings.ToLower(keyword)) + if idx < 0 { + continue + } + segment := strings.TrimSpace(trimmed[idx+len(keyword):]) + segment = strings.TrimLeft(segment, "“”\"':: ") + for _, prefix := range []string{"改成", "改为", "设为", "设置为", "变成"} { + segment = strings.TrimSpace(strings.TrimPrefix(segment, prefix)) + } + for _, marker := range []string{"排除币", "excluded coins", "exclude coins", "ai500", "oi top", "oi low", "并且", "然后"} { + if cut := strings.Index(strings.ToLower(segment), marker); cut > 0 { + segment = strings.TrimSpace(segment[:cut]) + break + } + } + segment = strings.Trim(segment, "“”\"':: ") + if segment != "" { + return segment + } + } + return "" +} + +func extractDelimitedSegmentAfterKeywords(text string, keywords []string) string { + segment := extractLongSegmentAfterKeywords(text, keywords) + if segment == "" { + return "" + } + for _, marker := range []string{",", ",", "。", ".", ";", ";", "\n", "\t", "并且", "然后"} { + if cut := strings.Index(segment, marker); cut > 0 { + segment = strings.TrimSpace(segment[:cut]) + break + } + } + return strings.Trim(segment, "“”\"':: ") +} + +func extractModelNameValue(text string) string { + lower := strings.ToLower(strings.TrimSpace(text)) + if !containsAny(lower, []string{"模型名", "模型名称", "model name"}) { + return "" + } + if value := extractDelimitedSegmentAfterKeywords(text, []string{"model name", "模型名称", "模型名"}); value != "" { + return value + } + if containsAny(lower, []string{"改成", "改为"}) { + if value := extractDelimitedSegmentAfterKeywords(text, []string{"改成", "改为"}); value != "" { + return value + } + } + if value := extractQuotedContent(text); value != "" { + return value + } + return "" +} + +func sanitizeExtractedURL(raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return "" + } + for _, marker := range []string{",", ",", "。", ";", ";", "并且", "然后"} { + if cut := strings.Index(raw, marker); cut > 0 { + raw = strings.TrimSpace(raw[:cut]) + break + } + } + return raw +} + +func strategyFieldKeywords(field string) []string { + switch field { + case "source_type": + return []string{"来源类型", "source type", "选币来源", "静态来源", "ai500来源", "oi top来源", "oi low来源"} + case "strategy_type": + return []string{"策略类型", "strategy type", "网格策略", "grid strategy", "ai策略"} + case "symbol": + return []string{"交易对", "symbol", "币对"} + case "grid_count": + return []string{"网格数量", "grid count", "grid levels"} + case "total_investment": + return []string{"总投入", "总投资", "total investment"} + case "upper_price": + return []string{"上沿价格", "上限价格", "upper price"} + case "lower_price": + return []string{"下沿价格", "下限价格", "lower price"} + case "use_atr_bounds": + return []string{"atr自动边界", "atr边界", "use atr bounds"} + case "atr_multiplier": + return []string{"atr倍数", "atr multiplier"} + case "distribution": + return []string{"分布方式", "distribution", "均匀分布", "高斯分布", "金字塔分布"} + case "enable_direction_adjust": + return []string{"方向调整", "direction adjust"} + case "direction_bias_ratio": + return []string{"方向偏置", "bias ratio", "direction bias"} + case "max_drawdown_pct": + return []string{"最大回撤", "max drawdown"} + case "stop_loss_pct": + return []string{"止损比例", "stop loss"} + case "daily_loss_limit_pct": + return []string{"日亏损限制", "daily loss limit"} + case "use_maker_only": + return []string{"maker only", "只挂maker", "仅maker"} + case "description": + return []string{"描述", "description"} + case "is_public": + return []string{"发布到市场", "公开", "publish"} + case "config_visible": + return []string{"配置可见", "显示配置", "config visible"} + case "nofxos_api_key": + return []string{"nofxos api key", "nofxos key", "api key"} + case "role_definition": + return []string{"角色定义", "role definition"} + case "trading_frequency": + return []string{"交易频率", "trading frequency"} + case "entry_standards": + return []string{"开仓标准", "入场标准", "entry standards"} + case "decision_process": + return []string{"决策流程", "decision process"} + case "custom_prompt": + return []string{"自定义prompt", "custom prompt", "提示词"} + case "ema_periods": + return []string{"ema周期", "ema periods"} + case "rsi_periods": + return []string{"rsi周期", "rsi periods"} + case "atr_periods": + return []string{"atr周期", "atr periods"} + case "boll_periods": + return []string{"boll周期", "布林周期", "boll periods"} + case "oi_ranking_duration": + return []string{"oi ranking duration", "oi排行周期"} + case "netflow_ranking_duration": + return []string{"netflow ranking duration", "资金流排行周期"} + case "price_ranking_duration": + return []string{"price ranking duration", "涨跌幅排行周期"} + case "oi_ranking_limit": + return []string{"oi ranking limit", "oi排行数量"} + case "netflow_ranking_limit": + return []string{"netflow ranking limit", "资金流排行数量"} + case "price_ranking_limit": + return []string{"price ranking limit", "涨跌幅排行数量"} + case "btceth_max_position_value_ratio": + return []string{"btc/eth仓位价值倍数", "btc eth position value", "主流币仓位价值倍数"} + case "altcoin_max_position_value_ratio": + return []string{"山寨币仓位价值倍数", "altcoin position value"} + case "max_margin_usage": + return []string{"最大保证金使用率", "max margin usage"} + default: + return nil + } +} + +func matchesStrategyFieldKeywords(text, field string) bool { + keywords := strategyFieldKeywords(field) + if len(keywords) == 0 { + return true + } + return containsAny(strings.ToLower(text), keywords) +} + +func strategyFieldExplicitlyMentioned(text, field string) bool { + keywords := strategyFieldKeywords(field) + if len(keywords) == 0 { + switch field { + case "max_positions": + keywords = []string{"最大持仓", "最多持仓", "max positions"} + case "symbol": + keywords = []string{"交易对", "symbol", "币对"} + case "grid_count": + keywords = []string{"网格数量", "grid count", "grid levels"} + case "total_investment": + keywords = []string{"总投入", "总投资", "total investment"} + case "upper_price": + keywords = []string{"上沿价格", "上限价格", "upper price"} + case "lower_price": + keywords = []string{"下沿价格", "下限价格", "lower price"} + case "use_atr_bounds": + keywords = []string{"atr自动边界", "atr边界", "use atr bounds"} + case "atr_multiplier": + keywords = []string{"atr倍数", "atr multiplier"} + case "distribution": + keywords = []string{"分布方式", "distribution", "均匀分布", "高斯分布", "金字塔分布"} + case "enable_direction_adjust": + keywords = []string{"方向调整", "direction adjust"} + case "direction_bias_ratio": + keywords = []string{"方向偏置", "bias ratio", "direction bias"} + case "max_drawdown_pct": + keywords = []string{"最大回撤", "max drawdown"} + case "stop_loss_pct": + keywords = []string{"止损比例", "stop loss"} + case "daily_loss_limit_pct": + keywords = []string{"日亏损限制", "daily loss limit"} + case "use_maker_only": + keywords = []string{"maker only", "只挂maker", "仅maker"} + case "min_confidence": + keywords = []string{"最低置信度", "最小置信度", "min confidence"} + case "min_risk_reward_ratio": + keywords = []string{"最小盈亏比", "风险回报比", "risk reward", "risk/reward"} + case "leverage": + keywords = []string{"杠杆", "leverage"} + case "btceth_max_leverage": + keywords = []string{"btc/eth杠杆", "btc eth杠杆", "btc/eth leverage", "btc eth leverage", "主流币杠杆"} + case "altcoin_max_leverage": + keywords = []string{"山寨币杠杆", "altcoin leverage", "alts leverage"} + case "btceth_max_position_value_ratio": + keywords = []string{"btc/eth仓位价值倍数", "btc eth position value", "主流币仓位价值倍数"} + case "altcoin_max_position_value_ratio": + keywords = []string{"山寨币仓位价值倍数", "altcoin position value"} + case "max_margin_usage": + keywords = []string{"最大保证金使用率", "max margin usage"} + case "primary_timeframe": + keywords = []string{"主周期", "主时间周期", "primary timeframe"} + case "primary_count": + keywords = []string{"k线数量", "k线根数", "primary count", "kline count"} + case "selected_timeframes": + keywords = []string{"多周期", "时间框架", "timeframes", "selected timeframes"} + case "enable_ema": + keywords = []string{"ema"} + case "enable_macd": + keywords = []string{"macd"} + case "enable_rsi": + keywords = []string{"rsi"} + case "enable_atr": + keywords = []string{"atr"} + case "enable_boll": + keywords = []string{"boll", "bollinger", "布林"} + case "enable_volume": + keywords = []string{"成交量", "volume"} + case "enable_oi": + keywords = []string{"持仓量", "open interest", "oi"} + case "enable_funding_rate": + keywords = []string{"资金费率", "funding rate"} + case "source_type": + keywords = []string{"来源类型", "source type", "选币来源"} + case "static_coins": + keywords = []string{"静态币", "固定币", "static coins", "static symbols"} + case "excluded_coins": + keywords = []string{"排除币", "排除币种", "excluded coins", "exclude coins"} + case "use_ai500": + keywords = []string{"ai500"} + case "ai500_limit": + keywords = []string{"ai500 limit", "ai500数量", "ai500上限"} + case "use_oi_top": + keywords = []string{"oi top", "持仓量增长", "持仓量排行上涨"} + case "oi_top_limit": + keywords = []string{"oi top limit", "oi top数量", "oi top上限"} + case "use_oi_low": + keywords = []string{"oi low", "持仓量下降", "持仓量排行下跌"} + case "oi_low_limit": + keywords = []string{"oi low limit", "oi low数量", "oi low上限"} + case "enable_all_core_indicators": + keywords = []string{"核心指标"} + } + } + if len(keywords) == 0 { + return false + } + return containsAny(strings.ToLower(text), keywords) +} + +func (a *Agent) executeTraderManagementAction(storeUserID string, userID int64, lang, text string, session skillSession) string { + if session.Action == "query_strategy_binding" || session.Action == "query_exchange_binding" || session.Action == "query_model_binding" { + if detail, ok := a.describeTrader(storeUserID, lang, session.TargetRef); ok { + return detail + } + return formatReadFastPathResponse(lang, "list_traders", a.toolListTraders(storeUserID)) } switch session.Action { case "query", "query_list": @@ -346,10 +1479,19 @@ func (a *Agent) executeTraderManagementAction(storeUserID string, userID int64, } return formatReadFastPathResponse(lang, "list_traders", a.toolListTraders(storeUserID)) case "start", "stop", "delete": + if session.TargetRef == nil && !(session.Action == "delete" && fieldValue(session, "bulk_scope") == "all") { + if lang == "zh" { + return "请先指定要操作的交易员。" + } + return "Please specify which trader to operate on." + } if fieldValue(session, skillDAGStepField) == "" { setSkillDAGStep(&session, "await_confirmation") } - if msg, waiting := beginConfirmationIfNeeded(userID, lang, &session, defaultIfEmpty(session.TargetRef.Name, session.TargetRef.ID)); waiting { + if session.Action == "delete" && fieldValue(session, "bulk_scope") == "all" { + return a.executeBulkTraderDelete(storeUserID, userID, lang, text, session) + } + if msg, waiting := a.beginConfirmationIfNeeded(userID, lang, &session, defaultIfEmpty(session.TargetRef.Name, session.TargetRef.ID)); waiting { a.saveSkillSession(userID, session) return msg } @@ -380,20 +1522,17 @@ func (a *Agent) executeTraderManagementAction(storeUserID string, userID int64, return fmt.Sprintf("已完成交易员操作:%s。", session.Action) } return fmt.Sprintf("Completed trader action: %s.", session.Action) - case "update", "update_name", "update_bindings": - if session.Action == "update_bindings" { + case "update", "update_bindings", "configure_strategy", "configure_exchange", "configure_model": + if session.Action == "update_bindings" || session.Action == "configure_strategy" || session.Action == "configure_exchange" || session.Action == "configure_model" { if fieldValue(session, skillDAGStepField) == "" { setSkillDAGStep(&session, "collect_bindings") } - args := manageTraderArgs{Action: "update", TraderID: session.TargetRef.ID} - if match := pickMentionedOption(text, a.loadEnabledModelOptions(storeUserID)); match != nil { - args.AIModelID = match.ID - } - if match := pickMentionedOption(text, a.loadExchangeOptions(storeUserID)); match != nil { - args.ExchangeID = match.ID - } - if match := pickMentionedOption(text, a.loadStrategyOptions(storeUserID)); match != nil { - args.StrategyID = match.ID + args := manageTraderArgs{ + Action: "update", + TraderID: session.TargetRef.ID, + AIModelID: fieldValue(session, "ai_model_id"), + ExchangeID: fieldValue(session, "exchange_id"), + StrategyID: fieldValue(session, "strategy_id"), } if args.AIModelID != "" { setField(&session, "ai_model_id", args.AIModelID) @@ -404,120 +1543,373 @@ func (a *Agent) executeTraderManagementAction(storeUserID string, userID int64, if args.StrategyID != "" { setField(&session, "strategy_id", args.StrategyID) } - if value := fieldValue(session, "ai_model_id"); value != "" { - args.AIModelID = value - } - if value := fieldValue(session, "exchange_id"); value != "" { - args.ExchangeID = value - } - if value := fieldValue(session, "strategy_id"); value != "" { - args.StrategyID = value + selectedField := fieldValue(session, "update_field") + if selectedField == "" { + switch session.Action { + case "configure_strategy": + selectedField = "strategy_id" + case "configure_exchange": + selectedField = "exchange_id" + case "configure_model": + selectedField = "ai_model_id" + default: + if args.AIModelID == "" && args.ExchangeID == "" && args.StrategyID == "" { + selectedField = detectCatalogField(text, traderFieldCatalog) + } + } + if selectedField == "name" || selectedField == "scan_interval_minutes" || selectedField == "is_cross_margin" || selectedField == "show_in_competition" { + selectedField = "" + } + if selectedField != "" { + setField(&session, "update_field", selectedField) + } } if args.AIModelID == "" && args.ExchangeID == "" && args.StrategyID == "" { + if fieldValue(session, "inline_sub_intent") == "create_sub_resource" { + delete(session.Fields, "inline_sub_intent") + a.saveSkillSession(userID, session) + task := a.buildSuspendedTask(userID, lang) + if task.Kind != "" && task.SkillSession != nil { + task.ResumeOnSuccess = true + var childSkill, childResumeTrigger string + switch session.Action { + case "configure_strategy": + childSkill = "strategy_management" + childResumeTrigger = "strategy_management" + case "configure_exchange": + childSkill = "exchange_management" + childResumeTrigger = "exchange_management" + case "configure_model": + childSkill = "model_management" + childResumeTrigger = "model_management" + case "create": + // infer child skill from which binding slot is missing + slots := session.Slots + if slots == nil || slots.StrategyID == "" { + childSkill = "strategy_management" + childResumeTrigger = "strategy_management" + } else if slots.ExchangeID == "" { + childSkill = "exchange_management" + childResumeTrigger = "exchange_management" + } else if slots.ModelID == "" { + childSkill = "model_management" + childResumeTrigger = "model_management" + } + } + if childSkill != "" { + task.ResumeTriggers = []string{childResumeTrigger} + a.SnapshotManager(userID).Save(task) + a.clearSkillSession(userID) + child := skillSession{Name: childSkill, Action: "create", Phase: "collecting"} + var answer string + var handled bool + switch childSkill { + case "strategy_management": + answer, handled = a.handleStrategyManagementSkill(storeUserID, userID, lang, text, child) + case "exchange_management": + answer, handled = a.handleExchangeManagementSkill(storeUserID, userID, lang, text, child) + case "model_management": + answer, handled = a.handleModelManagementSkill(storeUserID, userID, lang, text, child) + } + if !handled { + answer = "" + } + return a.maybeResumeParentTaskAfterSuccessfulSkill(storeUserID, userID, lang, childSkill, "create", answer) + } + } + } + if fieldValue(session, "inline_sub_intent") == "edit_sub_resource" { + delete(session.Fields, "inline_sub_intent") + a.saveSkillSession(userID, session) + task := a.buildSuspendedTask(userID, lang) + if task.Kind != "" && task.SkillSession != nil { + task.ResumeOnSuccess = true + var childSkill string + switch session.Action { + case "configure_strategy": + childSkill = "strategy_management" + case "configure_exchange": + childSkill = "exchange_management" + case "configure_model": + childSkill = "model_management" + case "create", "update_bindings": + childSkill = detectCatalogDomainFromText(text) + } + if childSkill != "" { + task.ResumeTriggers = []string{childSkill} + a.SnapshotManager(userID).Save(task) + a.clearSkillSession(userID) + child := skillSession{Name: childSkill, Action: "update", Phase: "collecting"} + var answer string + var handled bool + switch childSkill { + case "strategy_management": + answer, handled = a.handleStrategyManagementSkill(storeUserID, userID, lang, text, child) + case "exchange_management": + answer, handled = a.handleExchangeManagementSkill(storeUserID, userID, lang, text, child) + case "model_management": + answer, handled = a.handleModelManagementSkill(storeUserID, userID, lang, text, child) + } + if !handled { + answer = "" + } + return a.maybeResumeParentTaskAfterSuccessfulSkill(storeUserID, userID, lang, childSkill, "update", answer) + } + } + } setSkillDAGStep(&session, "collect_bindings") a.saveSkillSession(userID, session) if lang == "zh" { - return "这次是更新交易员绑定,请直接说要换成哪个模型、交易所或策略。" + if selectedField != "" { + return fmt.Sprintf("还差一步:请告诉我你想换成哪个%s。", displayCatalogFieldName(selectedField, lang)) + } + switch session.Action { + case "configure_strategy": + return "好,我来帮你换策略。直接告诉我想用哪个策略就行。" + case "configure_exchange": + return "好,我来帮你换交易所。直接告诉我想用哪个交易所就行。" + case "configure_model": + return "好,我来帮你换模型。直接告诉我想用哪个模型就行。" + default: + return "好,我来帮你调整交易员绑定。你直接告诉我想换成哪个模型、交易所或策略就行。" + } + } + if selectedField != "" { + return fmt.Sprintf("One more thing: tell me which %s you want to use.", displayCatalogFieldName(selectedField, lang)) + } + switch session.Action { + case "configure_strategy": + return "Sure. Tell me which strategy you want to use." + case "configure_exchange": + return "Sure. Tell me which exchange you want to use." + case "configure_model": + return "Sure. Tell me which model you want to use." + default: + return "Sure. Tell me which model, exchange, or strategy you want to switch to." } - return "This action updates trader bindings. Tell me which model, exchange, or strategy to switch to." } setSkillDAGStep(&session, "execute_update") resp := a.toolUpdateTrader(storeUserID, args) a.clearSkillSession(userID) if errMsg := parseSkillError(resp); strings.Contains(resp, `"error"`) { if lang == "zh" { - return "更新交易员绑定失败:" + errMsg + return "这次没改成功:" + errMsg } - return "Failed to update trader bindings: " + errMsg + return "That change did not go through: " + errMsg } + a.rememberReferencesFromToolResult(userID, "manage_trader", resp) if lang == "zh" { - return "已更新交易员绑定。" + switch session.Action { + case "configure_strategy": + return "已更新交易员策略。" + case "configure_exchange": + return "已更新交易员交易所。" + case "configure_model": + return "已更新交易员模型。" + default: + return "已更新交易员绑定。" + } + } + switch session.Action { + case "configure_strategy": + return "Updated the trader strategy." + case "configure_exchange": + return "Updated the trader exchange." + case "configure_model": + return "Updated the trader model." + default: + return "Updated trader bindings." } - return "Updated trader bindings." } if fieldValue(session, skillDAGStepField) == "" { setSkillDAGStep(&session, "collect_name") } - args := manageTraderArgs{Action: "update", TraderID: session.TargetRef.ID} - if minutes, ok := parseScanIntervalMinutes(text); ok && minutes > 0 { - args.ScanIntervalMinutes = &minutes - } - if value, ok := extractStrategyConfigValue(text, "btceth_max_leverage"); ok { - if parsed, err := strconv.Atoi(value); err == nil { - args.BTCETHLeverage = &parsed + parsedArgs := buildTraderUpdateArgsFromSession(session) + selectedField := fieldValue(session, "update_field") + if selectedField == "" { + if !parsedArgs.hasAny() { + selectedField = detectCatalogField(text, traderFieldCatalog) + } + if selectedField != "" { + setField(&session, "update_field", selectedField) } } - if value, ok := extractStrategyConfigValue(text, "altcoin_max_leverage"); ok { - if parsed, err := strconv.Atoi(value); err == nil { - args.AltcoinLeverage = &parsed + applyTraderUpdateArgsToSession(&session, parsedArgs) + parsedArgs = mergeTraderUpdateArgs(buildTraderUpdateArgsFromSession(session), parsedArgs) + if parsedArgs.hasAny() { + normalizedArgs, warnings := normalizeTraderArgsToManualLimits(lang, parsedArgs) + applyTraderUpdateArgsToSession(&session, normalizedArgs) + args := manageTraderArgs{ + Action: "update", + TraderID: session.TargetRef.ID, + AIModelID: normalizedArgs.AIModelID, + ExchangeID: normalizedArgs.ExchangeID, + StrategyID: normalizedArgs.StrategyID, + ScanIntervalMinutes: normalizedArgs.ScanIntervalMinutes, + IsCrossMargin: normalizedArgs.IsCrossMargin, + ShowInCompetition: normalizedArgs.ShowInCompetition, } - } - if prompt := extractCredentialValue(text, []string{"自定义提示词", "提示词", "custom prompt", "prompt"}); prompt != "" && - containsAny(strings.ToLower(text), []string{"提示词", "prompt"}) { - args.CustomPrompt = prompt - } - if enabled, ok := parseFlagValue(text, []string{"ai500"}); ok { - args.UseAI500 = &enabled - } - if enabled, ok := parseFlagValue(text, []string{"oi top", "oitop", "持仓量排名"}); ok { - args.UseOITop = &enabled - } - if args.ScanIntervalMinutes != nil || args.BTCETHLeverage != nil || args.AltcoinLeverage != nil || args.CustomPrompt != "" || args.UseAI500 != nil || args.UseOITop != nil { setSkillDAGStep(&session, "execute_update") resp := a.toolUpdateTrader(storeUserID, args) a.clearSkillSession(userID) if errMsg := parseSkillError(resp); strings.Contains(resp, `"error"`) { if lang == "zh" { - return "更新交易员失败:" + errMsg + return "这次没改成功:" + errMsg } - return "Failed to update trader: " + errMsg + return "That change did not go through: " + errMsg } if lang == "zh" { - return "已更新交易员配置。" + reply := "已更新交易员配置。" + if len(warnings) > 0 { + reply += "\n\n已按手动面板范围自动调整:\n- " + strings.Join(warnings, "\n- ") + } + return reply } - return "Updated trader config." + reply := "Updated trader config." + if len(warnings) > 0 { + reply += "\n\nAdjusted to stay within the manual editor limits:\n- " + strings.Join(warnings, "\n- ") + } + return reply } - newName := extractTraderName(text) - if newName == "" { - newName = extractPostKeywordName(text, []string{"改成", "改为", "rename to"}) - } - if newName != "" { - setField(&session, "name", newName) - } - newName = fieldValue(session, "name") - if newName == "" { + if selectedField != "" { + setSkillDAGStep(&session, "collect_field_value") + } else { setSkillDAGStep(&session, "collect_name") - a.saveSkillSession(userID, session) - if lang == "zh" { - return "目前更新交易员这条 skill 先支持改名。请直接告诉我新的名字。" - } - return "This trader update skill currently supports renaming first. Tell me the new name." - } - args = manageTraderArgs{Action: "update", TraderID: session.TargetRef.ID, Name: newName} - setSkillDAGStep(&session, "execute_update") - resp := a.toolUpdateTrader(storeUserID, args) - a.clearSkillSession(userID) - if errMsg := parseSkillError(resp); strings.Contains(resp, `"error"`) { - if lang == "zh" { - return "更新交易员失败:" + errMsg - } - return "Failed to update trader: " + errMsg } + a.saveSkillSession(userID, session) if lang == "zh" { - return fmt.Sprintf("已将交易员改名为“%s”。", newName) + if selectedField != "" { + if selectedField == "ai_model_id" || selectedField == "exchange_id" || selectedField == "strategy_id" { + return fmt.Sprintf("还差一步:请告诉我你想换成哪个%s。", displayCatalogFieldName(selectedField, lang)) + } + return fmt.Sprintf("还差一步:请告诉我新的%s。", displayCatalogFieldName(selectedField, lang)) + } + return "你可以直接告诉我想改哪一项,比如绑定的模型、交易所、策略,或者扫描间隔、保证金模式、是否展示到竞技场。若你要改策略参数、模型配置或交易所凭证,我会切到对应配置流程。" } - return fmt.Sprintf("Renamed trader to %q.", newName) + if selectedField != "" { + if selectedField == "ai_model_id" || selectedField == "exchange_id" || selectedField == "strategy_id" { + return fmt.Sprintf("One more thing: tell me which %s you want to use.", displayCatalogFieldName(selectedField, lang)) + } + return fmt.Sprintf("One more thing: tell me the new %s.", displayCatalogFieldName(selectedField, lang)) + } + return "Tell me what you want to change first, for example the linked model, exchange, strategy, scan interval, margin mode, or competition visibility. If you want to edit the internals of a strategy, model, or exchange, I'll switch to the right config flow." default: return "" } } -func (a *Agent) executeExchangeManagementAction(storeUserID string, userID int64, lang, text string, session skillSession) string { - if session.TargetRef == nil && session.Action != "query" && session.Action != "query_list" && session.Action != "create" { +func (a *Agent) executeBulkTraderDelete(storeUserID string, userID int64, lang, text string, session skillSession) string { + if a == nil || a.store == nil { if lang == "zh" { - return "请先告诉我你要操作哪个交易所配置。" + return "我这边暂时无法读取交易员列表。" + } + return "I cannot load the trader list right now." + } + traders, err := a.store.Trader().List(storeUserID) + if err != nil { + if lang == "zh" { + return "我这边暂时没读到交易员列表:" + err.Error() + } + return "I could not load the trader list just now: " + err.Error() + } + if len(traders) == 0 { + a.clearSkillSession(userID) + if lang == "zh" { + return "当前没有可删除的交易员。" + } + return "There are no traders to delete." + } + + deletable := make([]*store.Trader, 0, len(traders)) + runningNames := make([]string, 0) + for _, trader := range traders { + if trader == nil { + continue + } + isRunning := trader.IsRunning + if a.traderManager != nil { + if memTrader, err := a.traderManager.GetTrader(trader.ID); err == nil { + if running, ok := memTrader.GetStatus()["is_running"].(bool); ok { + isRunning = running + } + } + } + if isRunning { + runningNames = append(runningNames, defaultIfEmpty(trader.Name, trader.ID)) + continue + } + deletable = append(deletable, trader) + } + + if len(deletable) == 0 { + a.clearSkillSession(userID) + if lang == "zh" { + return "当前所有交易员都还在运行中,删除前需要先停止:" + strings.Join(runningNames, "、") + } + return "All traders are still running. Stop them before deleting: " + strings.Join(runningNames, ", ") + } + + targetLabel := fmt.Sprintf("全部已停止交易员(共 %d 个)", len(deletable)) + if msg, waiting := a.beginConfirmationIfNeeded(userID, lang, &session, targetLabel); waiting { + a.saveSkillSession(userID, session) + return msg + } + if msg, waiting := awaitingConfirmationButNotApproved(lang, session, text); waiting { + a.saveSkillSession(userID, session) + return msg + } + + setSkillDAGStep(&session, "execute_delete") + deletedNames := make([]string, 0, len(deletable)) + failedNames := make([]string, 0) + for _, trader := range deletable { + resp := a.toolDeleteTrader(storeUserID, trader.ID) + if errMsg := parseSkillError(resp); strings.Contains(resp, `"error"`) { + failedNames = append(failedNames, fmt.Sprintf("%s(%s)", defaultIfEmpty(trader.Name, trader.ID), errMsg)) + continue + } + deletedNames = append(deletedNames, defaultIfEmpty(trader.Name, trader.ID)) + } + a.clearSkillSession(userID) + + if lang == "zh" { + parts := []string{fmt.Sprintf("批量删除交易员已完成:成功删除 %d 个。", len(deletedNames))} + if len(runningNames) > 0 { + parts = append(parts, "这些交易员仍在运行,已跳过,删除前需要先停止:"+strings.Join(runningNames, "、")) + } + if len(failedNames) > 0 { + parts = append(parts, "这些没删成功:"+strings.Join(failedNames, ";")) + } + if len(deletedNames) > 0 { + parts = append(parts, "已删除:"+strings.Join(deletedNames, "、")) + } + return strings.Join(parts, "\n") + } + + parts := []string{fmt.Sprintf("Bulk trader deletion finished: deleted %d trader(s).", len(deletedNames))} + if len(runningNames) > 0 { + parts = append(parts, "Skipped running traders; stop them before deleting: "+strings.Join(runningNames, ", ")) + } + if len(failedNames) > 0 { + parts = append(parts, "These did not delete successfully: "+strings.Join(failedNames, "; ")) + } + if len(deletedNames) > 0 { + parts = append(parts, "Deleted: "+strings.Join(deletedNames, ", ")) + } + return strings.Join(parts, "\n") +} + +func (a *Agent) executeExchangeManagementAction(storeUserID string, userID int64, lang, text string, session skillSession) string { + switch session.Action { + case "query", "query_list", "create": + // These actions don't need a target — fall through. + default: + if session.TargetRef == nil { + if lang == "zh" { + return "请先指定要操作的交易所配置。" + } + return "Please specify which exchange config to operate on." } - return "Please specify which exchange config you want to manage." } switch session.Action { case "query_detail": @@ -529,7 +1921,7 @@ func (a *Agent) executeExchangeManagementAction(storeUserID string, userID int64 if fieldValue(session, skillDAGStepField) == "" { setSkillDAGStep(&session, "await_confirmation") } - if msg, waiting := beginConfirmationIfNeeded(userID, lang, &session, defaultIfEmpty(session.TargetRef.Name, session.TargetRef.ID)); waiting { + if msg, waiting := a.beginConfirmationIfNeeded(userID, lang, &session, defaultIfEmpty(session.TargetRef.Name, session.TargetRef.ID)); waiting { a.saveSkillSession(userID, session) return msg } @@ -543,14 +1935,14 @@ func (a *Agent) executeExchangeManagementAction(storeUserID string, userID int64 a.clearSkillSession(userID) if errMsg := parseSkillError(resp); strings.Contains(resp, `"error"`) { if lang == "zh" { - return "删除交易所配置失败:" + errMsg + return "这次没删成功:" + errMsg } - return "Failed to delete exchange config: " + errMsg + return "That delete did not go through: " + errMsg } if lang == "zh" { - return "已删除交易所配置。" + return a.maybeResumeParentTaskAfterSuccessfulSkill(storeUserID, userID, lang, "exchange_management", "delete", "已删除交易所配置。") } - return "Deleted exchange config." + return a.maybeResumeParentTaskAfterSuccessfulSkill(storeUserID, userID, lang, "exchange_management", "delete", "Deleted exchange config.") case "update", "update_name", "update_status": if fieldValue(session, skillDAGStepField) == "" { if session.Action == "update_status" { @@ -559,48 +1951,69 @@ func (a *Agent) executeExchangeManagementAction(storeUserID string, userID int64 setSkillDAGStep(&session, "collect_account_name") } } - accountName := extractTraderName(text) - if accountName == "" { - accountName = extractPostKeywordName(text, []string{"改成", "改为", "账户名改成", "rename to"}) - } - if accountName != "" { - setField(&session, "account_name", accountName) - } - if enabled, ok := parseEnabledValue(text); ok { - setField(&session, "enabled", strconv.FormatBool(enabled)) - } - if value := extractCredentialValue(text, []string{"api key", "apikey", "api_key"}); value != "" { - setField(&session, "api_key", value) - } - if value := extractCredentialValue(text, []string{"secret key", "secret", "secret_key"}); value != "" { - setField(&session, "secret_key", value) - } - if value := extractCredentialValue(text, []string{"passphrase", "密码短语"}); value != "" { - setField(&session, "passphrase", value) - } - if testnet, ok := parseFlagValue(text, []string{"testnet", "测试网"}); ok { - setField(&session, "testnet", strconv.FormatBool(testnet)) + patch := buildExchangeUpdatePatchFromSession(session) + selectedField := fieldValue(session, "update_field") + if selectedField == "" && session.Action == "update_status" { + selectedField = "enabled" + setField(&session, "update_field", selectedField) } + applyExchangeUpdatePatchToSession(&session, patch) + patch = mergeExchangeUpdatePatch(buildExchangeUpdatePatchFromSession(session), patch) + patch, warnings := normalizeExchangePatchToManualLimits(lang, patch) + applyExchangeUpdatePatchToSession(&session, patch) payload := map[string]any{"action": "update", "exchange_id": session.TargetRef.ID} - accountName = fieldValue(session, "account_name") + accountName := defaultIfEmpty(patch.AccountName, fieldValue(session, "account_name")) if accountName != "" && session.Action != "update_status" { payload["account_name"] = accountName } - if enabledRaw := fieldValue(session, "enabled"); enabledRaw != "" { + enabledRaw := fieldValue(session, "enabled") + if patch.Enabled != nil { + enabledRaw = strconv.FormatBool(*patch.Enabled) + } + if enabledRaw != "" { payload["enabled"] = enabledRaw == "true" } - if value := fieldValue(session, "api_key"); value != "" { + if value := defaultIfEmpty(patch.APIKey, fieldValue(session, "api_key")); value != "" { payload["api_key"] = value } - if value := fieldValue(session, "secret_key"); value != "" { + if value := defaultIfEmpty(patch.SecretKey, fieldValue(session, "secret_key")); value != "" { payload["secret_key"] = value } - if value := fieldValue(session, "passphrase"); value != "" { + if value := defaultIfEmpty(patch.Passphrase, fieldValue(session, "passphrase")); value != "" { payload["passphrase"] = value } - if value := fieldValue(session, "testnet"); value != "" { + testnetRaw := fieldValue(session, "testnet") + if patch.Testnet != nil { + testnetRaw = strconv.FormatBool(*patch.Testnet) + } + if value := testnetRaw; value != "" { payload["testnet"] = value == "true" } + if value := defaultIfEmpty(patch.HyperliquidWalletAddr, fieldValue(session, "hyperliquid_wallet_addr")); value != "" { + payload["hyperliquid_wallet_addr"] = value + } + if value := defaultIfEmpty(patch.AsterUser, fieldValue(session, "aster_user")); value != "" { + payload["aster_user"] = value + } + if value := defaultIfEmpty(patch.AsterSigner, fieldValue(session, "aster_signer")); value != "" { + payload["aster_signer"] = value + } + if value := defaultIfEmpty(patch.AsterPrivateKey, fieldValue(session, "aster_private_key")); value != "" { + payload["aster_private_key"] = value + } + if value := defaultIfEmpty(patch.LighterWalletAddr, fieldValue(session, "lighter_wallet_addr")); value != "" { + payload["lighter_wallet_addr"] = value + } + if value := defaultIfEmpty(patch.LighterAPIKeyPrivateKey, fieldValue(session, "lighter_api_key_private_key")); value != "" { + payload["lighter_api_key_private_key"] = value + } + if patch.LighterAPIKeyIndex != nil { + payload["lighter_api_key_index"] = *patch.LighterAPIKeyIndex + } else if value := fieldValue(session, "lighter_api_key_index"); value != "" { + if parsed, err := strconv.Atoi(value); err == nil { + payload["lighter_api_key_index"] = parsed + } + } if session.Action == "update_status" { delete(payload, "account_name") } @@ -608,13 +2021,41 @@ func (a *Agent) executeExchangeManagementAction(storeUserID string, userID int64 if session.Action == "update_status" { setSkillDAGStep(&session, "collect_enabled") } else { - setSkillDAGStep(&session, "collect_account_name") + if selectedField != "" { + setSkillDAGStep(&session, "collect_field_value") + } else { + setSkillDAGStep(&session, "collect_account_name") + } } a.saveSkillSession(userID, session) if lang == "zh" { - return "目前更新交易所 skill 支持改账户名、启用状态、API Key、Secret、Passphrase 和 testnet。请告诉我你要改什么。" + if selectedField != "" { + return fmt.Sprintf("还差一步:请告诉我你想把交易所配置里的%s改成什么。", displayCatalogFieldName(selectedField, lang)) + } + return "你可以直接告诉我想改交易所配置里的哪一项,比如账户名、启用开关、API Key、Passphrase、钱包地址或 testnet。" } - return "This exchange update skill supports account name, enabled state, API key, secret, passphrase, and testnet." + if selectedField != "" { + return fmt.Sprintf("One more thing: tell me what you want to change the exchange config %s to.", displayCatalogFieldName(selectedField, lang)) + } + return "Tell me which exchange config field you want to change, for example the account name, enabled switch, API key, passphrase, wallet address, or testnet." + } + if err := a.validateExchangeDraft( + storeUserID, + session.TargetRef.ID, + "", + payload["enabled"] == true, + asString(payload["api_key"]), + asString(payload["secret_key"]), + asString(payload["passphrase"]), + asString(payload["hyperliquid_wallet_addr"]), + asString(payload["aster_user"]), + asString(payload["aster_signer"]), + asString(payload["aster_private_key"]), + asString(payload["lighter_wallet_addr"]), + asString(payload["lighter_api_key_private_key"]), + ); err != nil { + a.saveSkillSession(userID, session) + return formatValidationFeedback(lang, "exchange", err) } setSkillDAGStep(&session, "execute_update") raw, _ := json.Marshal(payload) @@ -622,25 +2063,39 @@ func (a *Agent) executeExchangeManagementAction(storeUserID string, userID int64 a.clearSkillSession(userID) if errMsg := parseSkillError(resp); strings.Contains(resp, `"error"`) { if lang == "zh" { - return "更新交易所配置失败:" + errMsg + return "这次没改成功:" + errMsg } - return "Failed to update exchange config: " + errMsg + return "That change did not go through: " + errMsg } + a.rememberReferencesFromToolResult(userID, "manage_exchange_config", resp) if lang == "zh" { - return "已更新交易所配置。" + reply := "已更新交易所配置。" + if len(warnings) > 0 { + reply += "\n\n已按手动面板范围自动调整:\n- " + strings.Join(warnings, "\n- ") + } + return a.maybeResumeParentTaskAfterSuccessfulSkill(storeUserID, userID, lang, "exchange_management", "update", reply) } - return "Updated exchange config." + reply := "Updated exchange config." + if len(warnings) > 0 { + reply += "\n\nAdjusted to stay within the manual editor limits:\n- " + strings.Join(warnings, "\n- ") + } + return a.maybeResumeParentTaskAfterSuccessfulSkill(storeUserID, userID, lang, "exchange_management", "update", reply) default: return "" } } func (a *Agent) executeModelManagementAction(storeUserID string, userID int64, lang, text string, session skillSession) string { - if session.TargetRef == nil && session.Action != "query" && session.Action != "query_list" && session.Action != "create" { - if lang == "zh" { - return "请先告诉我你要操作哪个模型配置。" + switch session.Action { + case "query", "query_list", "create": + // These actions don't need a target — fall through. + default: + if session.TargetRef == nil { + if lang == "zh" { + return "请先指定要操作的模型。" + } + return "Please specify which model to operate on." } - return "Please specify which model config you want to manage." } switch session.Action { case "query_detail": @@ -652,7 +2107,7 @@ func (a *Agent) executeModelManagementAction(storeUserID string, userID int64, l if fieldValue(session, skillDAGStepField) == "" { setSkillDAGStep(&session, "await_confirmation") } - if msg, waiting := beginConfirmationIfNeeded(userID, lang, &session, defaultIfEmpty(session.TargetRef.Name, session.TargetRef.ID)); waiting { + if msg, waiting := a.beginConfirmationIfNeeded(userID, lang, &session, defaultIfEmpty(session.TargetRef.Name, session.TargetRef.ID)); waiting { a.saveSkillSession(userID, session) return msg } @@ -666,9 +2121,9 @@ func (a *Agent) executeModelManagementAction(storeUserID string, userID int64, l a.clearSkillSession(userID) if errMsg := parseSkillError(resp); strings.Contains(resp, `"error"`) { if lang == "zh" { - return "删除模型配置失败:" + errMsg + return "这次没删成功:" + errMsg } - return "Failed to delete model config: " + errMsg + return "That delete did not go through: " + errMsg } if lang == "zh" { return "已删除模型配置。" @@ -686,28 +2141,38 @@ func (a *Agent) executeModelManagementAction(storeUserID string, userID int64, l } } payload := map[string]any{"action": "update", "model_id": session.TargetRef.ID} - if url := extractURL(text); url != "" { - setField(&session, "custom_api_url", url) + patch := buildModelUpdatePatchFromSession(session) + selectedField := fieldValue(session, "update_field") + if selectedField == "" { + switch session.Action { + case "update_status": + selectedField = "enabled" + case "update_endpoint": + selectedField = "custom_api_url" + } + if selectedField != "" { + setField(&session, "update_field", selectedField) + } } - if enabled, ok := parseEnabledValue(text); ok { - setField(&session, "enabled", strconv.FormatBool(enabled)) + applyModelUpdatePatchToSession(&session, patch) + patch = mergeModelUpdatePatch(buildModelUpdatePatchFromSession(session), patch) + urlValue := patch.CustomAPIURL + enabledValue := "" + if patch.Enabled != nil { + enabledValue = strconv.FormatBool(*patch.Enabled) } - if apiKey := extractCredentialValue(text, []string{"api key", "apikey", "api_key"}); apiKey != "" { - setField(&session, "api_key", apiKey) - } - if modelName := extractPostKeywordName(text, []string{"model name", "模型名", "模型名称", "改成", "改为", "修改为", "换成", "换到", "切换为", "切换到", "change to", "switch to"}); modelName != "" { - setField(&session, "custom_model_name", normalizeModelName(modelName)) - } - if value := fieldValue(session, "custom_api_url"); value != "" { + apiKeyValue := patch.APIKey + modelNameValue := patch.CustomModelName + if value := defaultIfEmpty(urlValue, fieldValue(session, "custom_api_url")); value != "" { payload["custom_api_url"] = value } - if value := fieldValue(session, "enabled"); value != "" { + if value := defaultIfEmpty(enabledValue, fieldValue(session, "enabled")); value != "" { payload["enabled"] = value == "true" } - if value := fieldValue(session, "api_key"); value != "" { + if value := defaultIfEmpty(apiKeyValue, fieldValue(session, "api_key")); value != "" { payload["api_key"] = value } - if value := fieldValue(session, "custom_model_name"); value != "" { + if value := defaultIfEmpty(modelNameValue, fieldValue(session, "custom_model_name")); value != "" { payload["custom_model_name"] = value } if session.Action == "update_name" { @@ -732,13 +2197,35 @@ func (a *Agent) executeModelManagementAction(storeUserID string, userID int64, l case "update_endpoint": setSkillDAGStep(&session, "collect_custom_api_url") default: - setSkillDAGStep(&session, "collect_custom_model_name") + if selectedField != "" { + setSkillDAGStep(&session, "collect_field_value") + } else { + setSkillDAGStep(&session, "collect_custom_model_name") + } } a.saveSkillSession(userID, session) if lang == "zh" { - return "目前更新模型 skill 支持改 API Key、URL、模型名和启用状态。请告诉我你要改什么。" + if selectedField != "" { + return fmt.Sprintf("还差一步:请告诉我新的%s。", displayCatalogFieldName(selectedField, lang)) + } + return "你可以直接告诉我想改哪一项,比如模型名称、接口地址,或者开关状态。" } - return "This model update skill supports API key, URL, model name, and enabled state." + if selectedField != "" { + return fmt.Sprintf("One more thing: tell me the new %s.", displayCatalogFieldName(selectedField, lang)) + } + return "Tell me what you want to change, for example the model name, endpoint URL, or on or off status." + } + if err := a.validateModelDraft( + storeUserID, + session.TargetRef.ID, + "", + payload["enabled"] == true, + asString(payload["api_key"]), + asString(payload["custom_api_url"]), + asString(payload["custom_model_name"]), + ); err != nil { + a.saveSkillSession(userID, session) + return formatValidationFeedback(lang, "model", err) } setSkillDAGStep(&session, "execute_update") raw, _ := json.Marshal(payload) @@ -749,12 +2236,13 @@ func (a *Agent) executeModelManagementAction(storeUserID string, userID int64, l if strings.Contains(errMsg, "cannot enable model config before API key is configured") { return "更新模型配置失败:这个模型还没有配置 API Key,暂时不能启用。你可以直接把 API Key 发给我,我帮你继续配置。" } - return "更新模型配置失败:" + errMsg + return "这次没改成功:" + errMsg } a.saveSkillSession(userID, session) - return "Failed to update model config: " + errMsg + return "That change did not go through: " + errMsg } a.clearSkillSession(userID) + a.rememberReferencesFromToolResult(userID, "manage_model_config", resp) if lang == "zh" { if session.Action == "update_status" { return "已更新模型配置启用状态。" @@ -767,65 +2255,18 @@ func (a *Agent) executeModelManagementAction(storeUserID string, userID int64, l } } -// normalizeModelName maps common user-friendly model aliases to the canonical -// names used by claw402 and other providers (e.g. "claude opus4.6" → "claude-opus"). -func normalizeModelName(name string) string { - lower := strings.ToLower(strings.TrimSpace(name)) - aliases := map[string]string{ - // Claude - "claude opus": "claude-opus", - "claude opus4.6": "claude-opus", - "claude opus 4.6": "claude-opus", - "claude-opus-4-6": "claude-opus", - "claude sonnet": "claude-sonnet", - "claude sonnet4.6": "claude-sonnet", - "claude sonnet 4.6": "claude-sonnet", - "claude haiku": "claude-haiku", - // GPT - "gpt5.4": "gpt-5.4", - "gpt 5.4": "gpt-5.4", - "gpt5.4pro": "gpt-5.4-pro", - "gpt 5.4pro": "gpt-5.4-pro", - "gpt 5.4 pro": "gpt-5.4-pro", - "gpt5 mini": "gpt-5-mini", - "gpt 5 mini": "gpt-5-mini", - "gpt5.3": "gpt-5.3", - "gpt 5.3": "gpt-5.3", - // DeepSeek - "deepseek reasoner": "deepseek-reasoner", - "deepseek chat": "deepseek-chat", - // Qwen (通义千问) - "qwen max": "qwen-max", - "qwen plus": "qwen-plus", - "qwen turbo": "qwen-turbo", - "qwen flash": "qwen-flash", - "通义千问": "qwen-max", - // Gemini - "gemini 3.1 pro": "gemini-3.1-pro", - "gemini 3.1pro": "gemini-3.1-pro", - // Kimi - "kimi k2.5": "kimi-k2.5", - // GLM (智谱清言) - "glm5": "glm-5", - "glm 5": "glm-5", - "glm5 turbo": "glm-5-turbo", - "glm 5 turbo": "glm-5-turbo", - "glm5-turbo": "glm-5-turbo", - "智谱清言": "glm-5", - } - if canonical, ok := aliases[lower]; ok { - return canonical - } - // Replace spaces with hyphens as a general fallback - return strings.ReplaceAll(strings.TrimSpace(name), " ", "-") -} - func (a *Agent) executeStrategyManagementAction(storeUserID string, userID int64, lang, text string, session skillSession) string { - if session.TargetRef == nil && session.Action != "query" && session.Action != "query_list" && session.Action != "create" { - if lang == "zh" { - return "请先告诉我你要操作哪个策略。" + switch session.Action { + case "query", "query_list", "create": + // These actions don't need a target — fall through. + default: + isBulkDelete := session.Action == "delete" && fieldValue(session, "bulk_scope") == "all" + if session.TargetRef == nil && !isBulkDelete { + if lang == "zh" { + return "请先指定要操作的策略。" + } + return "Please specify which strategy to operate on." } - return "Please specify which strategy you want to manage." } switch session.Action { case "query", "query_list": @@ -841,9 +2282,9 @@ func (a *Agent) executeStrategyManagementAction(storeUserID string, userID int64 a.clearSkillSession(userID) if errMsg := parseSkillError(resp); strings.Contains(resp, `"error"`) { if lang == "zh" { - return "激活策略失败:" + errMsg + return "这次没激活成功:" + errMsg } - return "Failed to activate strategy: " + errMsg + return "That activation did not go through: " + errMsg } if lang == "zh" { return "已激活策略。" @@ -853,21 +2294,17 @@ func (a *Agent) executeStrategyManagementAction(storeUserID string, userID int64 if fieldValue(session, skillDAGStepField) == "" { setSkillDAGStep(&session, "collect_name") } - newName := extractTraderName(text) - if newName == "" { - newName = extractPostKeywordName(text, []string{"叫", "名为", "改成", "rename to"}) - } + newName := fieldValue(session, "name") if newName != "" { setField(&session, "name", newName) } - newName = fieldValue(session, "name") if newName == "" { setSkillDAGStep(&session, "collect_name") a.saveSkillSession(userID, session) if lang == "zh" { - return "复制策略时,我还需要一个新名称。" + return "还差一步:请给这个新策略起个名字。" } - return "I still need a new name for the duplicated strategy." + return "One more thing: give the new strategy a name." } setSkillDAGStep(&session, "execute_duplicate") raw, _ := json.Marshal(map[string]any{"action": "duplicate", "strategy_id": session.TargetRef.ID, "name": newName}) @@ -875,9 +2312,9 @@ func (a *Agent) executeStrategyManagementAction(storeUserID string, userID int64 a.clearSkillSession(userID) if errMsg := parseSkillError(resp); strings.Contains(resp, `"error"`) { if lang == "zh" { - return "复制策略失败:" + errMsg + return "这次没复制成功:" + errMsg } - return "Failed to duplicate strategy: " + errMsg + return "That copy did not go through: " + errMsg } if lang == "zh" { return fmt.Sprintf("已复制策略,新名称为“%s”。", newName) @@ -891,9 +2328,9 @@ func (a *Agent) executeStrategyManagementAction(storeUserID string, userID int64 strategies, err := a.store.Strategy().List(storeUserID) if err != nil { if lang == "zh" { - return "读取策略列表失败:" + err.Error() + return "我这边暂时没读到策略列表:" + err.Error() } - return "Failed to load strategies: " + err.Error() + return "I could not load the strategy list just now: " + err.Error() } deletable := make([]*store.Strategy, 0, len(strategies)) @@ -917,7 +2354,7 @@ func (a *Agent) executeStrategyManagementAction(storeUserID string, userID int64 } targetLabel := fmt.Sprintf("全部自定义策略(共 %d 个)", len(deletable)) - if msg, waiting := beginConfirmationIfNeeded(userID, lang, &session, targetLabel); waiting { + if msg, waiting := a.beginConfirmationIfNeeded(userID, lang, &session, targetLabel); waiting { a.saveSkillSession(userID, session) return msg } @@ -946,7 +2383,7 @@ func (a *Agent) executeStrategyManagementAction(storeUserID string, userID int64 parts = append(parts, fmt.Sprintf("已跳过系统默认策略 %d 个。", skippedDefault)) } if len(failedNames) > 0 { - parts = append(parts, "删除失败:"+strings.Join(failedNames, ";")) + parts = append(parts, "这些没删成功:"+strings.Join(failedNames, ";")) } if len(deletedNames) > 0 { parts = append(parts, "已删除:"+strings.Join(deletedNames, "、")) @@ -959,14 +2396,14 @@ func (a *Agent) executeStrategyManagementAction(storeUserID string, userID int64 parts = append(parts, fmt.Sprintf("Skipped %d default strategy(ies).", skippedDefault)) } if len(failedNames) > 0 { - parts = append(parts, "Failed: "+strings.Join(failedNames, "; ")) + parts = append(parts, "These did not delete successfully: "+strings.Join(failedNames, "; ")) } if len(deletedNames) > 0 { parts = append(parts, "Deleted: "+strings.Join(deletedNames, ", ")) } return strings.Join(parts, "\n") } - if msg, waiting := beginConfirmationIfNeeded(userID, lang, &session, defaultIfEmpty(session.TargetRef.Name, session.TargetRef.ID)); waiting { + if msg, waiting := a.beginConfirmationIfNeeded(userID, lang, &session, defaultIfEmpty(session.TargetRef.Name, session.TargetRef.ID)); waiting { a.saveSkillSession(userID, session) return msg } @@ -980,39 +2417,35 @@ func (a *Agent) executeStrategyManagementAction(storeUserID string, userID int64 a.clearSkillSession(userID) if errMsg := parseSkillError(resp); strings.Contains(resp, `"error"`) { if lang == "zh" { - return "删除策略失败:" + errMsg + return "这次没删成功:" + errMsg } - return "Failed to delete strategy: " + errMsg + return "That delete did not go through: " + errMsg } if lang == "zh" { return "已删除策略。" } return "Deleted strategy." - case "update", "update_name", "update_config", "update_prompt": + case "update_name", "update_config", "update_prompt": if session.Action == "update_prompt" { return a.executeStrategyPromptUpdate(storeUserID, userID, lang, text, session) } - if session.Action == "update_config" { + if session.Action == "update_config" || fieldValue(session, strategyPendingUpdateConfigField) != "" { return a.executeStrategyConfigUpdate(storeUserID, userID, lang, text, session) } if fieldValue(session, skillDAGStepField) == "" { setSkillDAGStep(&session, "collect_name") } - newName := extractTraderName(text) - if newName == "" { - newName = extractPostKeywordName(text, []string{"改成", "改为", "rename to"}) - } + newName := fieldValue(session, "name") if newName != "" { setField(&session, "name", newName) } - newName = fieldValue(session, "name") if newName == "" { setSkillDAGStep(&session, "collect_name") a.saveSkillSession(userID, session) if lang == "zh" { - return "目前更新策略 skill 先支持改名。请告诉我新的策略名称。" + return "目前这里先支持改策略名称。你直接把新名字发给我就行。" } - return "This strategy update skill currently supports renaming first." + return "For now, this step supports renaming the strategy. Just send me the new name." } setSkillDAGStep(&session, "execute_update") raw, _ := json.Marshal(map[string]any{"action": "update", "strategy_id": session.TargetRef.ID, "name": newName}) @@ -1020,14 +2453,20 @@ func (a *Agent) executeStrategyManagementAction(storeUserID string, userID int64 a.clearSkillSession(userID) if errMsg := parseSkillError(resp); strings.Contains(resp, `"error"`) { if lang == "zh" { - return "更新策略失败:" + errMsg + return "这次没改成功:" + errMsg } - return "Failed to update strategy: " + errMsg + return "That change did not go through: " + errMsg } if lang == "zh" { return fmt.Sprintf("已将策略改名为“%s”。", newName) } return fmt.Sprintf("Renamed strategy to %q.", newName) + case "update": + a.clearSkillSession(userID) + if lang == "zh" { + return "我需要先明确你要改策略的哪一部分:名称、提示词,还是策略参数。" + } + return "I need to know which part of the strategy to update: name, prompt, or config." default: return "" } @@ -1040,26 +2479,46 @@ func (a *Agent) executeStrategyPromptUpdate(storeUserID string, userID int64, la strategy, cfg, err := a.loadStrategyConfigForUpdate(storeUserID, session.TargetRef.ID) if err != nil { if lang == "zh" { - return "读取策略失败:" + err.Error() + return "我这边暂时没读到这份策略:" + err.Error() } - return "Failed to load strategy: " + err.Error() + return "I could not load that strategy just now: " + err.Error() } - prompt := extractQuotedContent(text) + prompt := fieldValue(session, "prompt") if prompt == "" { - prompt = extractPostKeywordName(text, []string{"prompt改成", "prompt 改成", "提示词改成", "提示词改为", "custom prompt 改成"}) + prompt = fieldValue(session, "custom_prompt") + if prompt != "" { + setField(&session, "prompt", prompt) + } } - if prompt != "" { - setField(&session, "prompt", prompt) + if generatedDraftRequiresConfirmation(session) { + switch { + case createConfirmationReply(text): + clearGeneratedDraftConfirmation(&session) + case isNoReply(text): + clearGeneratedDraftConfirmation(&session, "prompt", "custom_prompt") + setSkillDAGStep(&session, "collect_prompt") + session.Phase = "collecting" + a.saveSkillSession(userID, session) + if lang == "zh" { + return "好,我先不用这版草稿。你可以告诉我想保留的风格,或者直接让我重新设计一版 prompt。" + } + return "Okay, I won't use that draft. Tell me the style you want to keep, or ask me to draft another prompt." + } + } + if prompt == "" { + prompt = extractQuotedContent(text) + if prompt != "" { + setField(&session, "prompt", prompt) + } } - prompt = fieldValue(session, "prompt") if prompt == "" { setSkillDAGStep(&session, "collect_prompt") a.saveSkillSession(userID, session) if lang == "zh" { - return "这次是更新策略 prompt,请直接把新的 prompt 内容发给我,最好放在引号里。" + return "还差一步:请把新的提示词内容发给我,直接发正文就行。" } - return "This action updates the strategy prompt. Send me the new prompt text, ideally inside quotes." + return "One more thing: send me the new prompt text." } cfg.CustomPrompt = prompt @@ -1068,101 +2527,105 @@ func (a *Agent) executeStrategyPromptUpdate(storeUserID string, userID int64, la } func (a *Agent) executeStrategyConfigUpdate(storeUserID string, userID int64, lang, text string, session skillSession) string { + if rawPending := fieldValue(session, strategyPendingUpdateConfigField); rawPending != "" { + if createConfirmationReply(text) { + var pendingCfg store.StrategyConfig + if err := json.Unmarshal([]byte(rawPending), &pendingCfg); err != nil { + if session.Fields != nil { + delete(session.Fields, strategyPendingUpdateConfigField) + delete(session.Fields, strategyPendingUpdateWarnings) + delete(session.Fields, strategyPendingUpdateZhMsg) + delete(session.Fields, strategyPendingUpdateEnMsg) + } + session.Phase = "collecting" + a.saveSkillSession(userID, session) + if lang == "zh" { + return "我这边暂时没读到刚才那版草稿。你再告诉我想改哪一项,我马上继续。" + } + return "I could not read that draft just now. Tell me what you want to change and I will continue." + } + zhMsg := defaultIfEmpty(fieldValue(session, strategyPendingUpdateZhMsg), "已更新策略参数。") + enMsg := defaultIfEmpty(fieldValue(session, strategyPendingUpdateEnMsg), "Updated strategy config.") + return a.persistPendingStrategyConfigUpdate(storeUserID, userID, lang, session, pendingCfg, zhMsg, enMsg) + } + if session.Fields != nil { + delete(session.Fields, strategyPendingUpdateConfigField) + delete(session.Fields, strategyPendingUpdateWarnings) + delete(session.Fields, strategyPendingUpdateZhMsg) + delete(session.Fields, strategyPendingUpdateEnMsg) + } + session.Phase = "collecting" + } + if _, ok := getSkillDAG("strategy_management", "update_config"); ok { if fieldValue(session, skillDAGStepField) == "" { - setSkillDAGStep(&session, "resolve_config_field") + setSkillDAGStep(&session, "collect_config_patch") } } - currentStep, _ := currentSkillDAGStep(session) strategy, cfg, err := a.loadStrategyConfigForUpdate(storeUserID, session.TargetRef.ID) if err != nil { if lang == "zh" { - return "读取策略失败:" + err.Error() + return "我这边暂时没读到这份策略:" + err.Error() } - return "Failed to load strategy: " + err.Error() + return "I could not load that strategy just now: " + err.Error() } - if fieldValue(session, "config_field") == "" && fieldValue(session, "config_value") == "" { - patches := detectStrategyConfigPatches(text) - if len(patches) > 1 { - changed := make([]string, 0, len(patches)) - for _, patch := range patches { - if err := applyStrategyConfigPatch(&cfg, patch.Field, patch.Value); err != nil { - a.saveSkillSession(userID, session) - if lang == "zh" { - return "更新策略参数失败:" + err.Error() - } - return "Failed to update strategy config: " + err.Error() - } - changed = append(changed, strategyConfigFieldDisplayName(patch.Field, lang)) + if patchRaw := strings.TrimSpace(fieldValue(session, strategyCreateConfigPatchField)); patchRaw != "" { + var patch map[string]any + if err := json.Unmarshal([]byte(patchRaw), &patch); err != nil { + setSkillDAGStep(&session, "collect_config_patch") + a.saveSkillSession(userID, session) + if lang == "zh" { + return "策略配置 patch 不是合法 JSON:" + err.Error() } - cfg.ClampLimits() - setSkillDAGStep(&session, "apply_field_update") - setSkillDAGStep(&session, "execute_update") - msgZH := "已更新策略参数:" + strings.Join(changed, "、") + "。" - msgEN := "Updated strategy config fields: " + strings.Join(changed, ", ") + "." - return a.persistStrategyConfigUpdate(storeUserID, userID, lang, strategy, cfg, msgZH, msgEN) + return "The strategy config patch is not valid JSON: " + err.Error() } - } - - field := fieldValue(session, "config_field") - if field == "" { - field = detectStrategyConfigField(text) - if field != "" { - setField(&session, "config_field", field) - if currentStep.ID == "resolve_config_field" { - advanceSkillDAGStep(&session, currentStep.ID) - currentStep, _ = currentSkillDAGStep(session) + merged, err := store.MergeStrategyConfig(cfg, patch) + if err != nil { + setSkillDAGStep(&session, "collect_config_patch") + a.saveSkillSession(userID, session) + if lang == "zh" { + return "策略配置 patch 无法应用:" + err.Error() } + return "The strategy config patch could not be applied: " + err.Error() } - } - if field == "" { - setSkillDAGStep(&session, "resolve_config_field") - a.saveSkillSession(userID, session) + beforeClamp := merged + merged.ClampLimits() + msgZH := "已更新策略配置。" + msgEN := "Updated strategy config." + setSkillDAGStep(&session, "execute_update") + if warnings := store.StrategyClampWarnings(beforeClamp, merged, lang); len(warnings) > 0 { + return a.deferStrategyRiskControlledUpdate(userID, lang, &session, merged, warnings, msgZH, msgEN) + } + setSkillDAGStep(&session, "execute_update") + raw, _ := json.Marshal(map[string]any{ + "action": "update", + "strategy_id": strategy.ID, + "config": patch, + "allow_clamped_update": true, + }) + resp := a.toolManageStrategy(storeUserID, string(raw)) + a.clearSkillSession(userID) + if errMsg := parseSkillError(resp); strings.Contains(resp, `"error"`) { + if lang == "zh" { + return "这次没改成功:" + errMsg + } + return "That change did not go through: " + errMsg + } + a.rememberReferencesFromToolResult(userID, "manage_strategy", resp) if lang == "zh" { - return "这次是更新策略参数。我当前先支持这些字段:最大持仓、最低置信度、主周期、多周期时间框架。请先告诉我要改哪个字段。" + return msgZH } - return "This action updates strategy config. I currently support max positions, min confidence, primary timeframe, and selected timeframes. Tell me which field to change first." + return msgEN } - if value, ok := extractStrategyConfigValue(text, field); ok { - setField(&session, "config_value", value) - if currentStep.ID == "resolve_config_value" { - advanceSkillDAGStep(&session, currentStep.ID) - currentStep, _ = currentSkillDAGStep(session) - } + setSkillDAGStep(&session, "collect_config_patch") + a.saveSkillSession(userID, session) + if lang == "zh" { + return "你可以直接说想怎么改策略配置,比如“选币来源改成 AI500,最低置信度 80”。我会按当前策略类型的产品模板生成 config_patch 后再更新。" } - value := fieldValue(session, "config_value") - if value == "" { - setSkillDAGStep(&session, "resolve_config_value") - a.saveSkillSession(userID, session) - if lang == "zh" { - return fmt.Sprintf("要更新策略参数,我还需要 %s 的目标值。", strategyConfigFieldDisplayName(field, lang)) - } - return fmt.Sprintf("I still need the target value for %s.", strategyConfigFieldDisplayName(field, lang)) - } - - if err := applyStrategyConfigPatch(&cfg, field, value); err != nil { - setSkillDAGStep(&session, "resolve_config_value") - a.saveSkillSession(userID, session) - if lang == "zh" { - return err.Error() - } - return err.Error() - } - - cfg.ClampLimits() - changed := []string{field} - displayChanged := make([]string, 0, len(changed)) - for _, item := range changed { - displayChanged = append(displayChanged, strategyConfigFieldDisplayName(item, lang)) - } - msgZH := "已更新策略参数:" + strings.Join(displayChanged, "、") + "。" - msgEN := "Updated strategy config fields: " + strings.Join(displayChanged, ", ") + "." - setSkillDAGStep(&session, "apply_field_update") - setSkillDAGStep(&session, "execute_update") - return a.persistStrategyConfigUpdate(storeUserID, userID, lang, strategy, cfg, msgZH, msgEN) + return "Tell me how you want to change the strategy config, for example: set coin source to ai500 and minimum confidence to 80. I will turn it into a config_patch for the current strategy type before updating." } func (a *Agent) loadStrategyConfigForUpdate(storeUserID, strategyID string) (*store.Strategy, store.StrategyConfig, error) { @@ -1177,26 +2640,75 @@ func (a *Agent) loadStrategyConfigForUpdate(storeUserID, strategyID string) (*st return strategy, cfg, nil } +func (a *Agent) deferStrategyRiskControlledUpdate(userID int64, lang string, session *skillSession, cfg store.StrategyConfig, warnings []string, zhMsg, enMsg string) string { + rawConfig, _ := json.Marshal(cfg) + setField(session, strategyPendingUpdateConfigField, string(rawConfig)) + setField(session, strategyPendingUpdateWarnings, marshalStringList(warnings)) + setField(session, strategyPendingUpdateZhMsg, zhMsg) + setField(session, strategyPendingUpdateEnMsg, enMsg) + session.Phase = "await_confirmation" + setSkillDAGStep(session, "await_confirmation") + a.saveSkillSession(userID, *session) + task := SuspendedTask{ + Kind: "skill_session", + SkillSession: func() *skillSession { + copy := normalizeSkillSession(*session) + return © + }(), + ResumeHint: buildSkillResumeHint(lang, *session), + } + a.SnapshotManager(userID).Save(task) + return formatRiskControlAcceptancePrompt(lang, warnings, "确认应用") +} + +func (a *Agent) persistPendingStrategyConfigUpdate(storeUserID string, userID int64, lang string, session skillSession, cfg store.StrategyConfig, zhMsg, enMsg string) string { + if session.Fields != nil { + delete(session.Fields, strategyPendingUpdateConfigField) + delete(session.Fields, strategyPendingUpdateWarnings) + delete(session.Fields, strategyPendingUpdateZhMsg) + delete(session.Fields, strategyPendingUpdateEnMsg) + } + strategy, _, err := a.loadStrategyConfigForUpdate(storeUserID, session.TargetRef.ID) + if err != nil { + if lang == "zh" { + return "我这边暂时没读到这份策略:" + err.Error() + } + return "I could not load that strategy just now: " + err.Error() + } + return a.persistStrategyConfigUpdate(storeUserID, userID, lang, strategy, cfg, zhMsg, enMsg) +} + func (a *Agent) persistStrategyConfigUpdate(storeUserID string, userID int64, lang string, strategy *store.Strategy, cfg store.StrategyConfig, zhMsg, enMsg string) string { rawConfig, err := json.Marshal(cfg) if err != nil { if lang == "zh" { - return "序列化策略配置失败:" + err.Error() + return "我这边整理这份策略配置时出了点问题:" + err.Error() } - return "Failed to serialize strategy config: " + err.Error() + return "I ran into a problem while preparing that strategy config: " + err.Error() } raw, _ := json.Marshal(map[string]any{ - "action": "update", - "strategy_id": strategy.ID, - "config": json.RawMessage(rawConfig), + "action": "update", + "strategy_id": strategy.ID, + "name": strategy.Name, + "description": strategy.Description, + "is_public": strategy.IsPublic, + "config_visible": strategy.ConfigVisible, + "config": json.RawMessage(rawConfig), }) resp := a.toolManageStrategy(storeUserID, string(raw)) a.clearSkillSession(userID) if errMsg := parseSkillError(resp); strings.Contains(resp, `"error"`) { if lang == "zh" { - return "更新策略失败:" + errMsg + return "这次没改成功:" + errMsg + } + return "That change did not go through: " + errMsg + } + if warnings := parseToolWarnings(resp); len(warnings) > 0 { + if lang == "zh" { + zhMsg += "\n\n已按安全范围自动调整:\n- " + strings.Join(warnings, "\n- ") + } else { + enMsg += "\n\nAdjusted to stay within safe limits:\n- " + strings.Join(warnings, "\n- ") } - return "Failed to update strategy: " + errMsg } if lang == "zh" { return zhMsg @@ -1204,8 +2716,18 @@ func (a *Agent) persistStrategyConfigUpdate(storeUserID string, userID int64, la return enMsg } +func parseToolWarnings(raw string) []string { + var payload struct { + Warnings []string `json:"warnings"` + } + if err := json.Unmarshal([]byte(raw), &payload); err != nil { + return nil + } + return payload.Warnings +} + func extractQuotedContent(text string) string { - if matches := quotedNamePattern.FindStringSubmatch(text); len(matches) == 2 { + if matches := quotedContentRE.FindStringSubmatch(text); len(matches) == 2 { return strings.TrimSpace(matches[1]) } return "" @@ -1265,20 +2787,398 @@ func extractTimeframes(text string) []string { } func (a *Agent) handleTraderDiagnosisSkill(storeUserID, lang, text string) string { - raw := a.toolListTraders(storeUserID) - list := formatReadFastPathResponse(lang, "list_traders", raw) - if lang == "zh" { - reply := "现象:这是交易员运行诊断问题。\n优先排查:\n1. 交易员是否已创建并处于运行状态。\n2. 绑定的模型、交易所、策略是否齐全。\n3. 是“没有启动”、还是“启动了但 AI 没有下单”、还是“下单失败”。\n当前交易员概览:\n" + list - if excerpt := backendLogDiagnosisExcerpt(lang, text, "trader"); excerpt != "" { - reply += "\n" + excerpt + target := resolveDiagnosisTraderTarget(a.loadTraderOptions(storeUserID), text) + if target == nil { + raw := a.toolListTraders(storeUserID) + list := formatReadFastPathResponse(lang, "list_traders", raw) + if lang == "zh" { + return "我需要先确定要诊断哪个交易员。当前交易员:\n" + list } - return reply + return "I need to know which trader to diagnose first. Current traders:\n" + list } - reply := "This looks like a trader diagnosis issue.\nCheck whether the trader exists, is running, and has model/exchange/strategy bindings.\nCurrent trader overview:\n" + list - if excerpt := backendLogDiagnosisExcerpt(lang, text, "trader"); excerpt != "" { - reply += "\n" + excerpt + + evidence := a.collectTraderDiagnosisEvidence(storeUserID, target.ID, target.Name) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if answer, ok := a.generateTraderDiagnosisAnswerWithLLM(ctx, lang, text, evidence); ok { + return answer } - return reply + return formatTraderDiagnosisEvidence(lang, evidence) +} + +func resolveDiagnosisTraderTarget(options []traderSkillOption, text string) *traderSkillOption { + if opt := findOptionByIDOrName(options, text); opt != nil { + return opt + } + if opt := findUniqueContainingOption(options, text); opt != nil { + return opt + } + if len(options) == 1 { + return &options[0] + } + return nil +} + +type traderDecisionToolResponse struct { + Error string `json:"error"` + TraderID string `json:"trader_id"` + TraderName string `json:"trader_name"` + Count int `json:"count"` + Records []struct { + Success bool `json:"success"` + ErrorMessage string `json:"error_message"` + AIRequestDurationMs int `json:"ai_request_duration_ms"` + CandidateCoins []string `json:"candidate_coins"` + ExecutionLog []string `json:"execution_log"` + DecisionJSON string `json:"decision_json"` + Decisions []map[string]any `json:"decisions"` + } `json:"records"` +} + +type traderDiagnosisEvidence struct { + TraderName string + TraderConfig *store.Trader + Model *safeModelToolConfig + Exchange *safeExchangeToolConfig + Strategy *safeStrategyToolConfig + Runtime map[string]any + Account map[string]any + Positions []map[string]any + Decisions traderDecisionToolResponse + Logs struct { + Entries []any `json:"entries"` + Count int `json:"count"` + Error string `json:"error"` + } +} + +func (a *Agent) collectTraderDiagnosisEvidence(storeUserID, traderID, traderName string) traderDiagnosisEvidence { + ev := traderDiagnosisEvidence{TraderName: traderName} + if a.store != nil { + if traderCfg, err := a.resolveTraderForTool(storeUserID, traderID, traderName); err == nil { + ev.TraderConfig = traderCfg + ev.TraderName = defaultIfEmpty(traderCfg.Name, ev.TraderName) + if model, err := a.store.AIModel().Get(storeUserID, traderCfg.AIModelID); err == nil && model != nil { + safeModel := safeModelForTool(model) + ev.Model = &safeModel + } + if exchange, err := a.store.Exchange().GetByID(storeUserID, traderCfg.ExchangeID); err == nil && exchange != nil { + safeExchange := safeExchangeForTool(exchange) + ev.Exchange = &safeExchange + } + if strings.TrimSpace(traderCfg.StrategyID) != "" { + if strategy, err := a.store.Strategy().Get(storeUserID, traderCfg.StrategyID); err == nil && strategy != nil { + safeStrategy := safeStrategyForTool(strategy) + ev.Strategy = &safeStrategy + } + } + } + } + if a.traderManager != nil && ev.TraderConfig != nil { + if runtimeTrader, err := a.traderManager.GetTrader(ev.TraderConfig.ID); err == nil && runtimeTrader != nil { + ev.Runtime = runtimeTrader.GetStatus() + if account, err := runtimeTrader.GetAccountInfo(); err == nil { + ev.Account = account + } + if positions, err := runtimeTrader.GetPositions(); err == nil { + ev.Positions = positions + } + } + } + if ev.TraderConfig != nil { + decisionArgs, _ := json.Marshal(map[string]any{"trader_id": ev.TraderConfig.ID, "limit": 5}) + _ = json.Unmarshal([]byte(a.toolGetDecisions(storeUserID, string(decisionArgs))), &ev.Decisions) + logArgs, _ := json.Marshal(map[string]any{"trader_id": ev.TraderConfig.ID, "limit": 30, "errors_only": false}) + _ = json.Unmarshal([]byte(a.toolGetBackendLogs(storeUserID, string(logArgs))), &ev.Logs) + } + return ev +} + +func (a *Agent) generateTraderDiagnosisAnswerWithLLM(ctx context.Context, lang, userText string, ev traderDiagnosisEvidence) (string, bool) { + if a == nil || a.aiClient == nil || ev.TraderConfig == nil { + return "", false + } + evidenceJSON, err := json.MarshalIndent(ev, "", " ") + if err != nil { + return "", false + } + stageCtx, cancel := withPlannerStageTimeout(ctx, directReplyTimeout) + defer cancel() + systemPrompt := `You are the trader diagnosis reasoning layer for NOFXi. +You receive a complete evidence package collected by tools: trader config, bound model, bound exchange, bound strategy, account/positions, recent AI decisions, and backend logs. + +Your job: +- Reason from the evidence and produce the final user-facing diagnosis in the user's language. +- The answer must be short and useful: final cause + what the user should do. +- Prefer recent AI decisions, order validation, exchange result, runtime/account/positions over scattered backend logs. +- Do not expose evidence-package wording, tool names, raw logs, HTTP status codes, backend internals, or engineering troubleshooting unless the user explicitly asked for technical logs. +- Do not invent subscriptions, data services, websites, missing product fields, or unsupported actions. +- Never say "subscription expired" unless the evidence explicitly contains a confirmed subscription state. +- If an order is blocked because the amount is too small, explain it as account size/order minimum/system limit. Do not suggest editing position_size_usd, min_position_size, max_positions, position value ratios, or other System enforced fields. +- If the latest decision is wait/hold, explain that the trader is running and the AI chose to wait because the entry standard was not met. +- If evidence is insufficient, say what is missing and the next concrete check. + +Return plain text only. No markdown tables.` + userPrompt := fmt.Sprintf("Language: %s\nUser question: %s\n\nEvidence JSON:\n%s", lang, userText, string(evidenceJSON)) + raw, err := a.aiClient.CallWithRequest(&mcp.Request{ + Messages: []mcp.Message{ + mcp.NewSystemMessage(systemPrompt), + mcp.NewUserMessage(userPrompt), + }, + Ctx: stageCtx, + }) + if err != nil { + a.log().Warn("trader diagnosis LLM failed; using deterministic fallback", "error", err) + return "", false + } + answer := strings.TrimSpace(raw) + if answer == "" { + return "", false + } + return answer, true +} + +func formatTraderDiagnosisEvidence(lang string, ev traderDiagnosisEvidence) string { + traderName := defaultIfEmpty(ev.TraderName, "未知交易员") + if ev.TraderConfig == nil { + if lang == "zh" { + return fmt.Sprintf("我没有找到交易员“%s”,所以没法继续诊断。", traderName) + } + return fmt.Sprintf("I could not find trader %q, so I cannot diagnose it yet.", traderName) + } + latest := struct { + Success bool `json:"success"` + ErrorMessage string `json:"error_message"` + AIRequestDurationMs int `json:"ai_request_duration_ms"` + CandidateCoins []string `json:"candidate_coins"` + ExecutionLog []string `json:"execution_log"` + DecisionJSON string `json:"decision_json"` + Decisions []map[string]any `json:"decisions"` + }{} + hasDecision := len(ev.Decisions.Records) > 0 + if hasDecision { + latest = ev.Decisions.Records[0] + } + rawDecisions, _ := json.Marshal(ev.Decisions) + allEvidence := strings.ToLower(string(rawDecisions)) + latestEvidence := strings.ToLower(strings.Join(append(append([]string{}, latest.ExecutionLog...), latest.ErrorMessage, latest.DecisionJSON), "\n")) + hasAmountTooSmall := containsAny(allEvidence, []string{"opening amount too small", "below minimum", "must be ≥", "must be >=", "position value below minimum"}) + latestWait := containsAny(latestEvidence, []string{"wait succeeded", `"action":"wait"`, `"action":"hold"`}) + primarySymbol := primaryDiagnosisSymbol(latest.CandidateCoins, latest.DecisionJSON) + amount, minimum := openingAmountAndMinimum(string(rawDecisions)) + totalEquity := toFloat(ev.Account["total_equity"]) + available := toFloat(ev.Account["available_balance"]) + if available == 0 { + available = toFloat(ev.Account["available"]) + } + var maxBTCETHPositionValue float64 + if ev.Strategy != nil && ev.Strategy.Config != nil { + if risk, ok := nestedMap(ev.Strategy.Config, "ai_config", "risk_control"); ok { + maxBTCETHPositionValue = totalEquity * firstPositiveFloat(risk["btc_eth_max_position_value_ratio"], risk["btceth_max_position_value_ratio"]) + } + if maxBTCETHPositionValue == 0 { + if risk, ok := ev.Strategy.Config["risk_control"].(map[string]any); ok { + maxBTCETHPositionValue = totalEquity * firstPositiveFloat(risk["btc_eth_max_position_value_ratio"], risk["btceth_max_position_value_ratio"]) + } + } + } + + if lang == "zh" { + lines := []string{} + switch { + case !ev.TraderConfig.IsRunning: + lines = append(lines, fmt.Sprintf("%s 现在没有运行,所以不会开单。", traderName)) + lines = append(lines, "该怎么办:先启动这个交易员;启动后等它跑到下一个扫描周期,再看是否有新的 AI 决策。") + case strings.TrimSpace(ev.TraderConfig.AIModelID) == "": + lines = append(lines, fmt.Sprintf("%s 没有绑定 AI 模型,所以没法做交易决策。", traderName)) + lines = append(lines, "该怎么办:先给这个交易员绑定一个已启用、可正常调用的模型。") + case ev.Model != nil && !modelEnabled(ev.Model): + lines = append(lines, fmt.Sprintf("%s 绑定的 AI 模型目前没有启用,所以没法稳定做交易决策。", traderName)) + lines = append(lines, "该怎么办:启用当前模型,或者把交易员换到另一个可用模型。") + case strings.TrimSpace(ev.TraderConfig.ExchangeID) == "": + lines = append(lines, fmt.Sprintf("%s 没有绑定交易所账户,所以即使有信号也不能下单。", traderName)) + lines = append(lines, "该怎么办:先绑定一个可用的交易所账户。") + case ev.Exchange != nil && !exchangeEnabled(ev.Exchange): + lines = append(lines, fmt.Sprintf("%s 绑定的交易所账户目前没有启用,所以不能下单。", traderName)) + lines = append(lines, "该怎么办:启用这个交易所账户,或换成另一个可用账户。") + case hasAmountTooSmall: + summary := fmt.Sprintf("%s 不是没运行。最近它有尝试开 %s 的单,但账户资金太小,算出来的开仓金额", traderName, primarySymbol) + if amount > 0 { + summary += fmt.Sprintf("约 %.2f USDT", amount) + } + summary += ",低于系统最小下单要求" + if minimum > 0 { + summary += fmt.Sprintf(" %.2f USDT", minimum) + } + summary += ",所以这笔单被拦下了。" + lines = append(lines, summary) + if totalEquity > 0 && maxBTCETHPositionValue > 0 { + lines = append(lines, fmt.Sprintf("当前账户权益约 %.2f USDT,按策略风控算出来的单笔仓位上限约 %.2f USDT,容易达不到最小下单金额。", totalEquity, maxBTCETHPositionValue)) + } + if latestWait { + lines = append(lines, "另外,最近也有一些周期是 AI 主动选择等待,说明并不是系统完全没跑。") + } + lines = append(lines, "该怎么办:增加账户资金,或者换更适合小资金的策略/标的。AI 智能策略里的最小开仓金额是系统限制,不能手动修改。") + case latestWait: + lines = append(lines, fmt.Sprintf("%s 是运行的,最近 AI 决策也成功了;它不开单的原因是当前信号没有达到入场标准,所以主动选择等待。", traderName)) + lines = append(lines, "该怎么办:如果你想让它更容易出手,可以调整产品里真实可改的策略偏好,比如降低最低置信度或最低盈亏比;如果你更重视安全,就让它继续等待更明确的机会。") + case !hasDecision: + lines = append(lines, fmt.Sprintf("%s 目前没有读到最近 AI 决策记录,所以还不能证明它已经跑到完整决策周期。", traderName)) + lines = append(lines, "该怎么办:确认交易员已启动,并等待一个扫描周期后再查;如果仍然没有决策记录,再检查运行状态和模型调用。") + case len(latest.CandidateCoins) == 0: + lines = append(lines, fmt.Sprintf("%s 最近没有拿到可交易候选币,所以没有进入开单。", traderName)) + lines = append(lines, "该怎么办:检查策略的选币方式、指定币种或排除币设置,确认当前策略确实有可交易标的。") + case strings.TrimSpace(latest.ErrorMessage) != "": + lines = append(lines, fmt.Sprintf("%s 最近没有开单,是因为系统在决策或下单校验时返回了错误:%s", traderName, latest.ErrorMessage)) + lines = append(lines, "该怎么办:先按这条错误处理;如果它涉及交易所权限、余额、仓位模式或最小下单金额,就优先处理对应账户或策略可编辑项。") + default: + lines = append(lines, fmt.Sprintf("%s 最近没有开单,但现有记录没有显示明确的拒单原因。", traderName)) + lines = append(lines, "该怎么办:继续观察下一个扫描周期;如果连续没有开单,再重点看策略门槛、账户余额、交易所权限和模型调用是否正常。") + } + return strings.Join(lines, "\n") + } + + lines := []string{} + switch { + case !ev.TraderConfig.IsRunning: + lines = append(lines, fmt.Sprintf("%s is not running, so it will not open trades.", traderName)) + case hasAmountTooSmall: + lines = append(lines, fmt.Sprintf("%s did try to open a %s trade, but the calculated order size was below the system minimum, so it was blocked.", traderName, primarySymbol)) + case latestWait: + lines = append(lines, fmt.Sprintf("%s is running, but the latest AI decision chose to wait because the signal did not meet its entry standard.", traderName)) + case !hasDecision: + lines = append(lines, fmt.Sprintf("%s has no recent AI decision records yet, so there is not enough evidence that it completed a decision cycle.", traderName)) + case len(latest.CandidateCoins) == 0: + lines = append(lines, fmt.Sprintf("%s has no tradable candidate coins in the latest decision, so it did not open a trade.", traderName)) + case strings.TrimSpace(latest.ErrorMessage) != "": + lines = append(lines, fmt.Sprintf("%s did not open a trade because the latest decision/check returned: %s", traderName, latest.ErrorMessage)) + default: + lines = append(lines, fmt.Sprintf("%s has no clear rejection reason in the latest records yet.", traderName)) + } + lines = append(lines, "What to do: use the real editable product settings or account actions, such as adding funds, changing to a small-account-friendly symbol/strategy, or adjusting confidence/risk-reward preferences. Do not change system-enforced fields.") + return strings.Join(lines, "\n") +} + +func primaryDiagnosisSymbol(candidates []string, decisionJSON string) string { + for _, candidate := range candidates { + if trimmed := strings.TrimSpace(candidate); trimmed != "" { + return trimmed + } + } + match := regexp.MustCompile(`(?i)"symbol"\s*:\s*"([^"]+)"`).FindStringSubmatch(decisionJSON) + if len(match) >= 2 && strings.TrimSpace(match[1]) != "" { + return strings.ToUpper(strings.TrimSpace(match[1])) + } + return "当前标的" +} + +func openingAmountAndMinimum(evidence string) (float64, float64) { + amount := 0.0 + minimum := 0.0 + if match := regexp.MustCompile(`(?i)opening amount too small \((\d+(?:\.\d+)?)\s*USDT\)`).FindStringSubmatch(evidence); len(match) >= 2 { + amount, _ = strconv.ParseFloat(match[1], 64) + } + if amount == 0 { + if match := regexp.MustCompile(`(?i)"position_size_usd"\s*:\s*(\d+(?:\.\d+)?)`).FindStringSubmatch(evidence); len(match) >= 2 { + amount, _ = strconv.ParseFloat(match[1], 64) + } + } + if match := regexp.MustCompile(`(?:must be|must be ≥|>=|≥)\s*(\d+(?:\.\d+)?)\s*USDT`).FindStringSubmatch(evidence); len(match) >= 2 { + minimum, _ = strconv.ParseFloat(match[1], 64) + } + return amount, minimum +} + +func nestedMap(root map[string]any, path ...string) (map[string]any, bool) { + var current any = root + for _, key := range path { + obj, ok := current.(map[string]any) + if !ok { + return nil, false + } + current, ok = obj[key] + if !ok { + return nil, false + } + } + obj, ok := current.(map[string]any) + return obj, ok +} + +func firstPositiveFloat(values ...any) float64 { + for _, value := range values { + parsed := toFloat(value) + if parsed > 0 { + return parsed + } + } + return 0 +} + +func nonZeroPositions(positions []map[string]any) []map[string]any { + out := make([]map[string]any, 0, len(positions)) + for _, position := range positions { + if toFloat(position["size"]) != 0 { + out = append(out, position) + } + } + return out +} + +func joinAnyLines(values []any) string { + lines := make([]string, 0, len(values)) + for _, value := range values { + switch typed := value.(type) { + case string: + lines = append(lines, typed) + default: + raw, _ := json.Marshal(typed) + if len(raw) > 0 { + lines = append(lines, string(raw)) + } + } + } + return strings.Join(lines, "\n") +} + +func valueOrUnset(value string) string { + return defaultIfEmpty(strings.TrimSpace(value), "未设置") +} + +func modelName(model *safeModelToolConfig) string { + if model == nil { + return "" + } + return model.Name +} + +func modelProvider(model *safeModelToolConfig) string { + if model == nil { + return "" + } + return model.Provider +} + +func modelEnabled(model *safeModelToolConfig) bool { + return model != nil && model.Enabled +} + +func exchangeName(exchange *safeExchangeToolConfig) string { + if exchange == nil { + return "" + } + return defaultIfEmpty(exchange.AccountName, exchange.ExchangeType) +} + +func exchangeEnabled(exchange *safeExchangeToolConfig) bool { + return exchange != nil && exchange.Enabled +} + +func strategyName(strategy *safeStrategyToolConfig) string { + if strategy == nil { + return "" + } + return strategy.Name } func (a *Agent) handleStrategyDiagnosisSkill(storeUserID, lang, text string) string { diff --git a/agent/skill_management_handlers.go b/agent/skill_management_handlers.go index ce7bba2b..f27cf30d 100644 --- a/agent/skill_management_handlers.go +++ b/agent/skill_management_handlers.go @@ -5,6 +5,7 @@ import ( "fmt" "regexp" "sort" + "strconv" "strings" "nofx/store" @@ -12,173 +13,118 @@ import ( var urlPattern = regexp.MustCompile(`https://[^\s"'<>]+`) -func detectTraderManagementIntent(text string) bool { +func hasExplicitCreateIntentForDomain(text, domain string) bool { lower := strings.ToLower(strings.TrimSpace(text)) - if lower == "" { + if lower == "" || !hasExplicitManagementDomainCue(text, domain) { return false } - return containsAny(lower, []string{"交易员", "trader", "agent"}) && - containsAny(lower, []string{"修改", "编辑", "更新", "改", "改一下", "删除", "删了", "启动", "停止", "查看", "查询", "列出", "rename", "update", "delete", "start", "stop", "list", "show"}) -} - -func detectExchangeManagementIntent(text string) bool { - lower := strings.ToLower(strings.TrimSpace(text)) - if lower == "" { - return false - } - return containsAny(lower, []string{"交易所", "exchange", "okx", "binance", "bybit", "gate", "kucoin", "hyperliquid"}) && - containsAny(lower, []string{"创建", "新建", "修改", "编辑", "更新", "改", "改一下", "删除", "删了", "查询", "查看", "列出", "启用", "禁用", "改名", "rename", "create", "update", "delete", "list", "show", "enable", "disable"}) -} - -func detectModelManagementIntent(text string) bool { - lower := strings.ToLower(strings.TrimSpace(text)) - if lower == "" { - return false - } - return containsAny(lower, []string{"模型", "model", "provider", "deepseek", "openai", "claude", "gemini", "qwen", "kimi", "grok", "minimax"}) && - containsAny(lower, []string{"创建", "新建", "修改", "编辑", "更新", "改", "改一下", "删除", "删了", "查询", "查看", "列出", "启用", "禁用", "改名", "rename", "create", "update", "delete", "list", "show", "enable", "disable"}) -} - -func detectStrategyManagementIntent(text string) bool { - lower := strings.ToLower(strings.TrimSpace(text)) - if lower == "" { - return false - } - if wantsDefaultStrategyConfig(text) { - return true - } - return containsAny(lower, []string{"策略", "strategy"}) && - containsAny(lower, []string{"创建", "新建", "修改", "编辑", "更新", "改", "改一下", "改成", "改为", "删除", "删了", "查询", "查看", "列出", "激活", "复制", "参数", "配置", "详情", "详细", "prompt", "提示词", "什么样", "怎么样", "create", "update", "delete", "list", "show", "activate", "duplicate", "detail", "details", "config", "configuration", "parameter", "prompt", "what kind"}) -} - -func detectTraderDiagnosisSkill(text string) bool { - lower := strings.ToLower(strings.TrimSpace(text)) - return containsAny(lower, []string{"交易员", "trader"}) && - containsAny(lower, []string{"启动失败", "不交易", "没开仓", "无法启动", "异常", "失败", "diagnose", "error", "not trading"}) -} - -func detectStrategyDiagnosisSkill(text string) bool { - lower := strings.ToLower(strings.TrimSpace(text)) - return containsAny(lower, []string{"策略", "strategy", "prompt"}) && - containsAny(lower, []string{"不生效", "没生效", "异常", "失败", "不一致", "失效", "diagnose", "error"}) -} - -func detectManagementAction(text string, domain string) string { - lower := strings.ToLower(strings.TrimSpace(text)) - if lower == "" { - return "" - } - hasUpdateVerb := containsAny(lower, []string{"修改", "编辑", "更新", "改", "rename", "update", "切换", "换成", "换到"}) - switch { - case containsAny(lower, []string{"删除", "删掉", "删了", "remove", "delete"}): - return "delete" - case containsAny(lower, []string{"启动", "开始", "run", "start"}) && domain == "trader": - return "start" - case containsAny(lower, []string{"停止", "停掉", "stop", "pause"}) && domain == "trader": - return "stop" - case containsAny(lower, []string{"激活", "activate"}) && domain == "strategy": - return "activate" - case containsAny(lower, []string{"复制", "duplicate"}) && domain == "strategy": - return "duplicate" - case containsAny(lower, []string{"改名", "重命名", "rename"}): - return "update_name" - case domain == "trader" && containsAny(lower, []string{"换模型", "换交易所", "换策略", "绑定", "切换模型", "切换交易所", "切换策略"}): - return "update_bindings" - case (domain == "exchange" || domain == "model") && containsAny(lower, []string{"启用", "禁用", "enable", "disable"}): - return "update_status" - case domain == "model" && hasUpdateVerb && containsAny(lower, []string{"url", "endpoint", "地址", "接口"}): - return "update_endpoint" - case domain == "strategy" && hasUpdateVerb && containsAny(lower, []string{"prompt", "提示词"}): - return "update_prompt" - case domain == "strategy" && hasUpdateVerb && containsAny(lower, []string{ - "参数", "配置", "config", "configuration", "parameter", - "最大持仓", "最小置信度", "最低置信度", "主周期", "多周期", "时间框架", - "btc/eth杠杆", "btc eth杠杆", "山寨币杠杆", - "核心指标", "ema", "macd", "rsi", "atr", "boll", "bollinger", "布林", - }): - return "update_config" - case containsAny(lower, []string{"修改", "编辑", "更新", "改", "rename", "update"}): - return "update" - case domain == "trader" && containsAny(lower, []string{"运行中的", "在跑", "running"}): - return "query_running" - case !containsAny(lower, []string{"创建", "新建", "create", "new"}) && - containsAny(lower, []string{"详情", "详细", "prompt", "提示词", "什么样", "怎么样", "detail", "details", "what kind"}): - return "query_detail" - case containsAny(lower, []string{"查询", "查看", "列出", "list", "show", "有哪些"}): - return "query_list" - case containsAny(lower, []string{"创建", "新建", "加一个", "create", "new"}): - return "create" - default: - return "" - } -} - -func exchangeTypeFromText(text string) string { - lower := strings.ToLower(text) - candidates := []string{"binance", "okx", "bybit", "gate", "kucoin", "hyperliquid", "aster", "lighter"} - for _, candidate := range candidates { - if strings.Contains(lower, candidate) { - return candidate - } - } - switch { - case strings.Contains(text, "币安"): - return "binance" - case strings.Contains(text, "欧易"): - return "okx" - case strings.Contains(text, "库币"): - return "kucoin" - default: - return "" - } -} - -func providerFromText(text string) string { - lower := strings.ToLower(text) - candidates := []string{"openai", "deepseek", "claude", "gemini", "qwen", "kimi", "grok", "minimax"} - for _, candidate := range candidates { - if strings.Contains(lower, candidate) { - return candidate - } - } - if strings.Contains(text, "通义") { - return "qwen" - } - return "" + return containsAny(lower, []string{"创建", "新建", "创一个", "创个", "建一个", "create", "new"}) } func extractURL(text string) string { return strings.TrimSpace(urlPattern.FindString(text)) } -func extractPostKeywordName(text string, keywords []string) string { - trimmed := strings.TrimSpace(text) - for _, keyword := range keywords { - if idx := strings.Index(trimmed, keyword); idx >= 0 { - name := strings.TrimSpace(trimmed[idx+len(keyword):]) - name = strings.Trim(name, "“”\"':: ") - if name != "" && len([]rune(name)) <= 50 { - return name +func setField(session *skillSession, key, value string) { + ensureSkillFields(session) + key = normalizeFieldKey(session, key) + value = strings.TrimSpace(value) + if value == "" { + return + } + if session != nil && session.Name == "trader_management" && key == "name" { + value = normalizeTraderDraftName(value) + if value == "" { + return + } + } + session.Fields[key] = value + syncTraderCreateSlotMirror(session) +} + +func fieldValue(session skillSession, key string) string { + key = normalizeFieldKey(&session, key) + if session.Fields != nil { + if value := strings.TrimSpace(session.Fields[key]); value != "" { + return value + } + } + if session.Name == "trader_management" && session.Slots != nil { + switch key { + case "name": + return strings.TrimSpace(session.Slots.Name) + case "exchange_id": + return strings.TrimSpace(session.Slots.ExchangeID) + case "exchange_name": + return strings.TrimSpace(session.Slots.ExchangeName) + case "model_id": + return strings.TrimSpace(session.Slots.ModelID) + case "model_name": + return strings.TrimSpace(session.Slots.ModelName) + case "strategy_id": + return strings.TrimSpace(session.Slots.StrategyID) + case "strategy_name": + return strings.TrimSpace(session.Slots.StrategyName) + case "auto_start": + if session.Slots.AutoStart != nil { + if *session.Slots.AutoStart { + return "true" + } + return "false" } } } return "" } -func setField(session *skillSession, key, value string) { - ensureSkillFields(session) - value = strings.TrimSpace(value) - if value == "" { - return +func normalizeFieldKey(session *skillSession, key string) string { + key = strings.TrimSpace(key) + if session == nil || session.Name != "trader_management" { + return key + } + switch key { + case "ai_model_id": + return "model_id" + default: + return key } - session.Fields[key] = value } -func fieldValue(session skillSession, key string) string { - if session.Fields == nil { - return "" +func syncTraderCreateSlotMirror(session *skillSession) { + if session == nil || session.Name != "trader_management" { + return + } + if session.Slots == nil { + session.Slots = &createTraderSkillSlots{} + } + if session.Fields == nil { + return + } + if value := strings.TrimSpace(session.Fields["name"]); value != "" { + session.Slots.Name = value + } + if value := strings.TrimSpace(session.Fields["exchange_id"]); value != "" { + session.Slots.ExchangeID = value + } + if value := strings.TrimSpace(session.Fields["exchange_name"]); value != "" { + session.Slots.ExchangeName = value + } + if value := strings.TrimSpace(session.Fields["model_id"]); value != "" { + session.Slots.ModelID = value + } + if value := strings.TrimSpace(session.Fields["model_name"]); value != "" { + session.Slots.ModelName = value + } + if value := strings.TrimSpace(session.Fields["strategy_id"]); value != "" { + session.Slots.StrategyID = value + } + if value := strings.TrimSpace(session.Fields["strategy_name"]); value != "" { + session.Slots.StrategyName = value + } + if value := strings.TrimSpace(session.Fields["auto_start"]); value != "" { + b := strings.EqualFold(value, "true") + session.Slots.AutoStart = &b } - return strings.TrimSpace(session.Fields[key]) } func textMeansAllTargets(text string) bool { @@ -187,44 +133,129 @@ func textMeansAllTargets(text string) bool { return false } return containsAny(lower, []string{ - "全部", "所有", "全都", "全部策略", "所有策略", + "全部", "所有", "全都", "全部策略", "所有策略", "全部删除", "全部删掉", "全部删了", + "全删", "全删了", "都删", "都删了", "全清", "全清掉", "all", "all strategies", "every strategy", }) } func supportsBulkTargetSelection(skillName, action string) bool { - return skillName == "strategy_management" && action == "delete" + switch skillName { + case "strategy_management", "trader_management": + return action == "delete" + default: + return false + } } func resolveTargetFromText(text string, options []traderSkillOption, existing *EntityReference) *EntityReference { - if existing != nil && (existing.ID != "" || existing.Name != "") { - return existing + return resolveTargetSelection(text, options, existing).Ref +} + +func hasStrictOptionMention(text string, options []traderSkillOption) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return false } - if match := pickMentionedOption(text, options); match != nil { - return &EntityReference{ID: match.ID, Name: match.Name} + for _, option := range options { + name := strings.ToLower(strings.TrimSpace(option.Name)) + if name != "" && strings.Contains(lower, name) { + return true + } + id := strings.ToLower(strings.TrimSpace(option.ID)) + if id != "" && strings.Contains(lower, id) { + return true + } } - if choice := choosePreferredOption(options); choice != nil { - return &EntityReference{ID: choice.ID, Name: choice.Name} + return false +} + +func isSimpleEntityMutationAction(action string) bool { + switch strings.TrimSpace(action) { + case "update", "update_name", "update_status", "update_endpoint", "update_bindings", + "configure_strategy", "configure_exchange", "configure_model", + "update_prompt", "update_config", "activate", "duplicate": + return true + default: + return false } - return nil +} + +func hasExplicitManagementDomainCue(text, domain string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return false + } + switch strings.TrimSpace(domain) { + case "trader": + return containsAny(lower, []string{"交易员", "trader", "agent"}) + case "exchange": + return containsAny(lower, []string{"交易所", "exchange", "okx", "binance", "bybit", "gate", "kucoin", "hyperliquid"}) + case "model": + return containsAny(lower, []string{"模型", "model"}) + case "strategy": + return containsAny(lower, []string{"策略", "strategy"}) + default: + return false + } +} + +func ensureLiveTargetReference(session *skillSession, options []traderSkillOption) bool { + if session == nil || session.TargetRef == nil { + return true + } + var match *traderSkillOption + if id := strings.TrimSpace(session.TargetRef.ID); id != "" { + match = findOptionByIDOrName(options, id) + } + if match == nil { + if name := strings.TrimSpace(session.TargetRef.Name); name != "" { + match = findOptionByIDOrName(options, name) + if match == nil { + match = findUniqueContainingOption(options, name) + } + } + } + if match == nil { + session.TargetRef = nil + return false + } + session.TargetRef.ID = match.ID + session.TargetRef.Name = defaultIfEmpty(match.Name, session.TargetRef.Name) + return true +} + +func (a *Agent) buildSimpleEntityConversationResources(storeUserID string, session skillSession, options []traderSkillOption) map[string]any { + missing := missingFieldKeysForSkillSession(session) + resources := map[string]any{} + for _, field := range missing { + switch strings.TrimSpace(field) { + case "target_ref": + if len(options) > 0 { + resources["targets"] = options + } + case "exchange_name", "exchange_id", "exchange": + resources["exchanges"] = a.loadExchangeOptions(storeUserID) + case "model_name", "model_id", "ai_model_id", "model": + resources["models"] = a.loadEnabledModelOptions(storeUserID) + case "strategy_name", "strategy_id", "strategy": + resources["strategies"] = a.loadStrategyOptions(storeUserID) + } + } + return resources } func (a *Agent) handleTraderManagementSkill(storeUserID string, userID int64, lang, text string, session skillSession) (string, bool) { - action := detectManagementAction(text, "trader") - if session.Name == "trader_management" && session.Action != "" { - action = session.Action - } - if action == "" || action == "create" { + if session.Name != "trader_management" || session.Action == "" { return "", false } + action := session.Action if action == "query_running" { answer := formatReadFastPathResponse(lang, "list_traders", a.toolListTraders(storeUserID)) return applyTraderQueryFilter(lang, answer, a.toolListTraders(storeUserID), "running_only"), true } if action == "query_detail" { - options := a.loadTraderOptions(storeUserID) - target := resolveTargetFromText(text, options, session.TargetRef) - if detail, ok := a.describeTrader(storeUserID, lang, target); ok { + if detail, ok := a.describeTrader(storeUserID, lang, session.TargetRef); ok { return detail, true } return formatReadFastPathResponse(lang, "list_traders", a.toolListTraders(storeUserID)), true @@ -233,20 +264,16 @@ func (a *Agent) handleTraderManagementSkill(storeUserID string, userID int64, la } func (a *Agent) handleExchangeManagementSkill(storeUserID string, userID int64, lang, text string, session skillSession) (string, bool) { - action := detectManagementAction(text, "exchange") - if session.Name == "exchange_management" && session.Action != "" { - action = session.Action - } - if action == "" { + if session.Name != "exchange_management" || session.Action == "" { return "", false } + action := session.Action options := a.loadExchangeOptions(storeUserID) switch action { case "query_list": return formatReadFastPathResponse(lang, "get_exchange_configs", a.toolGetExchangeConfigs(storeUserID)), true case "query_detail": - target := resolveTargetFromText(text, options, session.TargetRef) - if detail, ok := a.describeExchange(storeUserID, lang, target); ok { + if detail, ok := a.describeExchange(storeUserID, lang, session.TargetRef); ok { return detail, true } return formatReadFastPathResponse(lang, "get_exchange_configs", a.toolGetExchangeConfigs(storeUserID)), true @@ -258,20 +285,16 @@ func (a *Agent) handleExchangeManagementSkill(storeUserID string, userID int64, } func (a *Agent) handleModelManagementSkill(storeUserID string, userID int64, lang, text string, session skillSession) (string, bool) { - action := detectManagementAction(text, "model") - if session.Name == "model_management" && session.Action != "" { - action = session.Action - } - if action == "" { + if session.Name != "model_management" || session.Action == "" { return "", false } + action := session.Action options := a.loadEnabledModelOptions(storeUserID) switch action { case "query_list": return formatReadFastPathResponse(lang, "get_model_configs", a.toolGetModelConfigs(storeUserID)), true case "query_detail": - target := resolveTargetFromText(text, options, session.TargetRef) - if detail, ok := a.describeModel(storeUserID, lang, target); ok { + if detail, ok := a.describeModel(storeUserID, lang, session.TargetRef); ok { return detail, true } return formatReadFastPathResponse(lang, "get_model_configs", a.toolGetModelConfigs(storeUserID)), true @@ -283,24 +306,14 @@ func (a *Agent) handleModelManagementSkill(storeUserID string, userID int64, lan } func (a *Agent) handleStrategyManagementSkill(storeUserID string, userID int64, lang, text string, session skillSession) (string, bool) { - action := detectManagementAction(text, "strategy") - if session.Name == "strategy_management" && session.Action != "" { - action = session.Action - } - if action == "" && wantsStrategyDetails(text) { - action = "query_detail" - } - if action == "" { + if session.Name != "strategy_management" || session.Action == "" { return "", false } + action := session.Action options := a.loadStrategyOptions(storeUserID) switch action { case "query_detail": - if wantsDefaultStrategyConfig(text) { - return a.describeDefaultStrategyConfig(lang), true - } - target := resolveTargetFromText(text, options, session.TargetRef) - if detail, ok := a.describeStrategy(storeUserID, lang, target); ok { + if detail, ok := a.describeStrategy(storeUserID, lang, session.TargetRef); ok { return detail, true } return formatReadFastPathResponse(lang, "get_strategies", a.toolGetStrategies(storeUserID)), true @@ -313,17 +326,1261 @@ func (a *Agent) handleStrategyManagementSkill(storeUserID string, userID int64, } } -func wantsStrategyDetails(text string) bool { +// strategyCreateDraftConfigField stores the materialized, product-normalized +// draft between turns. User-visible strategy proposals should still be rendered +// from the post-merge structured config, not from free-form LLM text. +const strategyCreateDraftConfigField = "strategy_create_draft_config" +const strategyCreateConfigPatchField = "config_patch" + +func marshalStrategyCreateDraft(cfg store.StrategyConfig) string { + raw, err := json.Marshal(cfg) + if err != nil { + return "" + } + return string(raw) +} + +func unmarshalStrategyCreateDraft(raw, lang string) store.StrategyConfig { + cfg := store.GetDefaultStrategyConfig(lang) + if strings.TrimSpace(raw) == "" { + return cfg + } + if err := json.Unmarshal([]byte(raw), &cfg); err != nil { + return store.GetDefaultStrategyConfig(lang) + } + return cfg +} + +func strategyCreateConfigFromSession(session skillSession, lang string) (store.StrategyConfig, map[string]any, []string, error) { + normalizeLegacyStrategyCreateSession(&session) + cfg := unmarshalStrategyCreateDraft(fieldValue(session, strategyCreateDraftConfigField), lang) + patchRaw := strings.TrimSpace(fieldValue(session, strategyCreateConfigPatchField)) + var patch map[string]any + if patchRaw != "" { + if err := json.Unmarshal([]byte(patchRaw), &patch); err != nil { + return cfg, nil, nil, fmt.Errorf("策略配置 patch 不是合法 JSON:%w", err) + } + merged, err := store.MergeStrategyConfig(cfg, patch) + if err != nil { + return cfg, nil, nil, fmt.Errorf("策略配置 patch 无法应用:%w", err) + } + cfg = merged + } + applyStrategyCreateTypeDefaults(&cfg) + beforeClamp := cfg + cfg.ClampLimits() + rawCfg, _ := json.Marshal(cfg) + var configMap map[string]any + _ = json.Unmarshal(rawCfg, &configMap) + removeLockedStrategyCreateFields(configMap) + return cfg, configMap, store.StrategyClampWarnings(beforeClamp, cfg, cfg.Language), nil +} + +func resolveStrategyCreateName(session *skillSession, text string) string { + if session == nil { + return "" + } + name := strings.TrimSpace(fieldValue(*session, "name")) + if name == "" { + if inferred := inferStandaloneStrategyName(text); inferred != "" { + name = inferred + } + } + if name != "" { + setField(session, "name", name) + } + return name +} + +func normalizeLegacyStrategyCreateSession(session *skillSession) { + if session == nil || session.Action != "create" { + return + } + strategyType := explicitStrategyCreateType(*session) + if strategyType == "" { + return + } + filterLegacyStrategyCreateFieldsForType(session, strategyType) + if patchRaw := strings.TrimSpace(fieldValue(*session, strategyCreateConfigPatchField)); patchRaw != "" { + if sanitized := sanitizeStrategyCreateConfigPatchForType(patchRaw, strategyType); len(sanitized) > 0 { + raw, _ := json.Marshal(sanitized) + setField(session, strategyCreateConfigPatchField, string(raw)) + } else { + delete(session.Fields, strategyCreateConfigPatchField) + } + } +} + +func filterLegacyStrategyCreateFieldsForType(session *skillSession, strategyType string) { + if session == nil || len(session.Fields) == 0 { + return + } + allowed := map[string]struct{}{} + for _, key := range []string{ + "name", + "description", + "is_public", + "config_visible", + "lang", + "strategy_type", + strategyCreateDraftConfigField, + strategyCreateConfigPatchField, + skillDAGStepField, + "awaiting_final_confirmation", + } { + allowed[key] = struct{}{} + } + for key := range session.Fields { + if _, ok := allowed[key]; !ok { + delete(session.Fields, key) + } + } +} + +func resetLegacyStrategyCreateSessionForType(session *skillSession, strategyType string) { + if session == nil { + return + } + keep := map[string]string{} + for _, key := range []string{"name", "description", "is_public", "config_visible", "lang"} { + if value := fieldValue(*session, key); strings.TrimSpace(value) != "" { + keep[key] = value + } + } + session.Fields = keep + setField(session, "strategy_type", strategyType) +} + +func setStrategyCreateType(session *skillSession, strategyType string) { + if session == nil || strategyType == "" { + return + } + current := explicitStrategyCreateType(*session) + if current != "" && current != strategyType { + resetLegacyStrategyCreateSessionForType(session, strategyType) + return + } + setField(session, "strategy_type", strategyType) + filterLegacyStrategyCreateFieldsForType(session, strategyType) +} + +func applyStrategyCreateTypeDefaults(cfg *store.StrategyConfig) { + if cfg == nil { + return + } + switch strings.TrimSpace(cfg.StrategyType) { + case "grid_trading": + defaultGrid := store.DefaultGridStrategyConfig() + if cfg.GridConfig == nil { + cfg.GridConfig = &defaultGrid + return + } + if strings.TrimSpace(cfg.GridConfig.Symbol) == "" { + cfg.GridConfig.Symbol = defaultGrid.Symbol + } + if cfg.GridConfig.GridCount <= 0 { + cfg.GridConfig.GridCount = defaultGrid.GridCount + } + if cfg.GridConfig.TotalInvestment <= 0 { + cfg.GridConfig.TotalInvestment = defaultGrid.TotalInvestment + } + if cfg.GridConfig.Leverage <= 0 { + cfg.GridConfig.Leverage = defaultGrid.Leverage + } + if cfg.GridConfig.ATRMultiplier <= 0 { + cfg.GridConfig.ATRMultiplier = defaultGrid.ATRMultiplier + } + if strings.TrimSpace(cfg.GridConfig.Distribution) == "" { + cfg.GridConfig.Distribution = defaultGrid.Distribution + } + if cfg.GridConfig.MaxDrawdownPct <= 0 { + cfg.GridConfig.MaxDrawdownPct = defaultGrid.MaxDrawdownPct + } + if cfg.GridConfig.StopLossPct <= 0 { + cfg.GridConfig.StopLossPct = defaultGrid.StopLossPct + } + if cfg.GridConfig.DailyLossLimitPct <= 0 { + cfg.GridConfig.DailyLossLimitPct = defaultGrid.DailyLossLimitPct + } + if cfg.GridConfig.DirectionBiasRatio <= 0 { + cfg.GridConfig.DirectionBiasRatio = defaultGrid.DirectionBiasRatio + } + if cfg.GridConfig.UpperPrice <= 0 && cfg.GridConfig.LowerPrice <= 0 { + cfg.GridConfig.UseATRBounds = true + } + case "": + cfg.StrategyType = "ai_trading" + } +} + +func removeLockedStrategyCreateFields(configMap map[string]any) { + if configMap == nil { + return + } + risk, ok := configMap["risk_control"].(map[string]any) + if ok { + removeLockedAIRiskFields(risk) + } + if aiConfig, ok := configMap["ai_config"].(map[string]any); ok { + if risk, ok := aiConfig["risk_control"].(map[string]any); ok { + removeLockedAIRiskFields(risk) + } + } +} + +func removeLockedAIRiskFields(risk map[string]any) { + delete(risk, "max_positions") + delete(risk, "btc_eth_max_position_value_ratio") + delete(risk, "btceth_max_position_value_ratio") + delete(risk, "altcoin_max_position_value_ratio") + delete(risk, "max_margin_usage") + delete(risk, "min_position_size") +} + +func strategyCreateConfirmationReply(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return false + } + for _, exact := range []string{ + "确认创建", "确认", "创建吧", "就按这个创建", "按这个创建", "确认应用", "就按这个应用", + "可以", "好的", "好", "没问题", "就这样", "按这个", "ok", "okay", "yes", "yep", "looks good", + } { + if lower == exact { + return true + } + } + return false +} + +func strategyCreateDefaultConfigReply(text string) bool { lower := strings.ToLower(strings.TrimSpace(text)) if lower == "" { return false } return containsAny(lower, []string{ - "什么样", "怎么样", "详情", "详细", "参数", "配置", "prompt", "提示词", - "what kind", "details", "detail", "config", "configuration", "parameter", "prompt", + "默认", "先创建", "直接创建", "不用配置", "其他默认", "用默认", "按默认", "默认配置", + "use default", "use defaults", "default config", "create now", "create directly", }) } +func explicitStrategyCreateType(session skillSession) string { + if value := strings.TrimSpace(fieldValue(session, "strategy_type")); value != "" { + return value + } + patchRaw := strings.TrimSpace(fieldValue(session, strategyCreateConfigPatchField)) + if patchRaw == "" { + return "" + } + var patch map[string]any + if err := json.Unmarshal([]byte(patchRaw), &patch); err != nil { + return "" + } + if value, ok := patch["strategy_type"].(string); ok { + return strings.TrimSpace(value) + } + if gridConfig, ok := patch["grid_config"]; ok && gridConfig != nil { + return "grid_trading" + } + if aiConfig, ok := patch["ai_config"]; ok && aiConfig != nil { + return "ai_trading" + } + return "" +} + +func strategyCreateConfigReady(session skillSession, cfg store.StrategyConfig, text string) (bool, string) { + strategyType := explicitStrategyCreateType(session) + if strategyType == "" { + return false, "strategy_type" + } + if missing := strategyCreateMissingTemplateFields(session, cfg); len(missing) > 0 { + return false, strings.Join(missing, ",") + } + return true, "" +} + +func strategyCreateFinalConfirmationReady(session skillSession) bool { + return strings.EqualFold(strings.TrimSpace(fieldValue(session, "awaiting_final_confirmation")), "true") +} + +func strategyCreateHasExplicitConfigBeyondType(session skillSession) bool { + for _, key := range manualStrategyEditableFieldKeys() { + switch key { + case "name", "description", "is_public", "config_visible", "strategy_type": + continue + } + if strings.TrimSpace(fieldValue(session, key)) != "" { + return true + } + } + patchRaw := strings.TrimSpace(fieldValue(session, strategyCreateConfigPatchField)) + if patchRaw == "" { + return false + } + var patch map[string]any + if err := json.Unmarshal([]byte(patchRaw), &patch); err != nil { + return true + } + for key := range patch { + if strings.TrimSpace(key) != "" && strings.TrimSpace(key) != "strategy_type" { + return true + } + } + return false +} + +func strategyCreateMissingTemplateFields(session skillSession, cfg store.StrategyConfig) []string { + switch explicitStrategyCreateType(session) { + case "ai_trading": + return strategyCreateMissingAIFields(session, cfg) + case "grid_trading": + return strategyCreateMissingGridFields(session) + default: + return []string{"strategy_type"} + } +} + +func strategyCreateMissingAIFields(session skillSession, cfg store.StrategyConfig) []string { + required := []string{ + "source_type", + "primary_timeframe", + "selected_timeframes", + "btceth_max_leverage", + "altcoin_max_leverage", + "min_confidence", + "min_risk_reward_ratio", + "trading_frequency", + "entry_standards", + } + missing := make([]string, 0, len(required)+1) + for _, field := range required { + if !strategyCreateFieldExplicit(session, field) { + missing = append(missing, field) + } + } + if strings.EqualFold(strings.TrimSpace(cfg.CoinSource.SourceType), "static") && !strategyCreateFieldExplicit(session, "static_coins") { + missing = append(missing, "static_coins") + } + return missing +} + +func strategyCreateMissingGridFields(session skillSession) []string { + required := []string{ + "symbol", + "grid_count", + "total_investment", + "leverage", + "distribution", + "max_drawdown_pct", + "stop_loss_pct", + "daily_loss_limit_pct", + "use_maker_only", + } + missing := make([]string, 0, len(required)+1) + for _, field := range required { + if !strategyCreateFieldExplicit(session, field) { + missing = append(missing, field) + } + } + if !strategyCreateFieldExplicit(session, "use_atr_bounds") && (!strategyCreateFieldExplicit(session, "upper_price") || !strategyCreateFieldExplicit(session, "lower_price")) { + missing = append(missing, "use_atr_bounds 或 upper_price/lower_price") + } + return missing +} + +func strategyCreateFieldExplicit(session skillSession, field string) bool { + field = strings.TrimSpace(field) + if field == "" { + return false + } + if strings.TrimSpace(fieldValue(session, field)) != "" { + return true + } + patchRaw := strings.TrimSpace(fieldValue(session, strategyCreateConfigPatchField)) + if patchRaw == "" { + return false + } + var patch map[string]any + if err := json.Unmarshal([]byte(patchRaw), &patch); err != nil { + return false + } + for _, path := range strategyCreatePatchPaths(field) { + if strategyCreatePatchHasPath(patch, path...) { + return true + } + } + return false +} + +func strategyCreatePatchPaths(field string) [][]string { + switch strings.TrimSpace(field) { + case "strategy_type": + return [][]string{{"strategy_type"}} + case "source_type": + return [][]string{ + {"ai_config", "coin_source", "source_type"}, {"coin_source", "source_type"}, + {"ai_config", "coin_source", "static_coins"}, {"coin_source", "static_coins"}, + {"ai_config", "coin_source", "use_ai500"}, {"coin_source", "use_ai500"}, + {"ai_config", "coin_source", "use_oi_top"}, {"coin_source", "use_oi_top"}, + {"ai_config", "coin_source", "use_oi_low"}, {"coin_source", "use_oi_low"}, + } + case "static_coins": + return [][]string{{"ai_config", "coin_source", "static_coins"}, {"coin_source", "static_coins"}} + case "primary_timeframe": + return [][]string{{"ai_config", "indicators", "klines", "primary_timeframe"}, {"indicators", "klines", "primary_timeframe"}} + case "selected_timeframes": + return [][]string{{"ai_config", "indicators", "klines", "selected_timeframes"}, {"indicators", "klines", "selected_timeframes"}} + case "btceth_max_leverage": + return [][]string{{"ai_config", "risk_control", "btc_eth_max_leverage"}, {"risk_control", "btc_eth_max_leverage"}, {"ai_config", "risk_control", "btceth_max_leverage"}, {"risk_control", "btceth_max_leverage"}} + case "altcoin_max_leverage": + return [][]string{{"ai_config", "risk_control", "altcoin_max_leverage"}, {"risk_control", "altcoin_max_leverage"}} + case "min_confidence": + return [][]string{{"ai_config", "risk_control", "min_confidence"}, {"risk_control", "min_confidence"}} + case "min_risk_reward_ratio": + return [][]string{{"ai_config", "risk_control", "min_risk_reward_ratio"}, {"risk_control", "min_risk_reward_ratio"}} + case "trading_frequency": + return [][]string{{"ai_config", "prompt_sections", "trading_frequency"}, {"prompt_sections", "trading_frequency"}} + case "entry_standards": + return [][]string{{"ai_config", "prompt_sections", "entry_standards"}, {"prompt_sections", "entry_standards"}} + case "symbol", "grid_count", "total_investment", "leverage", "distribution", "max_drawdown_pct", "stop_loss_pct", "daily_loss_limit_pct", "use_maker_only", "use_atr_bounds", "upper_price", "lower_price": + return [][]string{{"grid_config", field}} + default: + return [][]string{{field}} + } +} + +func strategyCreatePatchHasPath(value any, path ...string) bool { + current := value + for _, part := range path { + obj, ok := current.(map[string]any) + if !ok { + return false + } + next, ok := obj[part] + if !ok { + return false + } + current = next + } + return true +} + +func formatStrategyCreateConfigNeeded(lang, missingKind string) string { + if lang == "zh" { + if missingKind == "strategy_type" { + return "先选择策略类型:grid_trading(网格策略)或 ai_trading(AI 策略)。类型确认后我会继续收集对应配置,配置好后再创建。" + } + if hints := formatStrategyMissingFieldHints(lang, missingKind); hints != "" { + return "这份策略模板还没填完整,还缺这些字段。你可以按下面选,也可以直接说“你帮我按稳健/高频/激进来推荐”:\n" + hints + } + return "这份策略模板还没填完整,还缺:" + formatStrategyMissingFieldNames(lang, missingKind) + "。你可以一句话告诉我这些字段,我会继续填模板。" + } + if missingKind == "strategy_type" { + return "Choose the strategy type first: grid_trading or ai_trading. I will collect the matching config before creating it." + } + if hints := formatStrategyMissingFieldHints(lang, missingKind); hints != "" { + return "This strategy template is not complete yet. You can choose from these options, or ask me to recommend a conservative/balanced/high-frequency setup:\n" + hints + } + return "This strategy template is not complete yet. Missing: " + formatStrategyMissingFieldNames(lang, missingKind) + ". Tell me these fields in one message and I will keep filling the template." +} + +func formatStrategyMissingFieldHints(lang, missingKind string) string { + parts := strings.Split(missingKind, ",") + lines := make([]string, 0, len(parts)) + for _, part := range parts { + field := strings.TrimSpace(part) + if field == "" { + continue + } + hint := strategyCreateFieldInlineHint(lang, field) + if hint == "" { + hint = strategyCreateFieldDisplayName(lang, field) + } + if lang == "zh" { + lines = append(lines, "- "+hint) + } else { + lines = append(lines, "- "+hint) + } + } + return strings.Join(lines, "\n") +} + +func strategyCreateFieldInlineHint(lang, field string) string { + field = strings.TrimSpace(field) + if lang != "zh" { + switch field { + case "source_type": + return "Coin source: ai500 / oi_top / oi_low / static" + case "static_coins": + return "Static coins: up to 10 symbols, e.g. BTCUSDT, ETHUSDT" + case "primary_timeframe": + return "Primary timeframe: 1m / 3m / 5m / 15m / 30m / 1h / 2h / 4h / 6h / 8h / 12h / 1d / 3d / 1w" + case "selected_timeframes": + return "Multi-timeframes: up to 4, e.g. 5m,15m,1h" + case "btceth_max_leverage", "altcoin_max_leverage": + return strategyCreateFieldDisplayName(lang, field) + ": 1-20" + case "min_confidence": + return "Minimum confidence: 50-100" + case "min_risk_reward_ratio": + return "Minimum risk/reward ratio: 1-10, step 0.5" + case "trading_frequency": + return "Trading frequency rule: free text, e.g. max 2-4 trades per day" + case "entry_standards": + return "Entry standards: free text, e.g. enter only when trend and risk/reward align" + case "symbol": + return "Symbol: BTCUSDT / ETHUSDT / SOLUSDT / BNBUSDT / XRPUSDT / DOGEUSDT" + case "grid_count": + return "Grid count: 5-50" + case "total_investment": + return "Total investment: user's capital/margin budget, minimum 100 USDT; not leveraged notional exposure" + case "leverage": + return "Grid leverage: 1-5" + case "distribution": + return "Distribution: uniform / gaussian / pyramid" + case "max_drawdown_pct": + return "Max drawdown: 5%-50%" + case "stop_loss_pct": + return "Stop loss: 1%-20%" + case "daily_loss_limit_pct": + return "Daily loss limit: 1%-30%" + case "use_maker_only": + return "Maker only: on / off" + } + return "" + } + switch field { + case "source_type": + return "选币来源:AI500 / OI Top / OI Low / 静态币种(没有混合模式)" + case "static_coins": + return "静态币种:最多 10 个,例如 BTCUSDT、ETHUSDT" + case "primary_timeframe": + return "主周期:1m / 3m / 5m / 15m / 30m / 1h / 2h / 4h / 6h / 8h / 12h / 1d / 3d / 1w" + case "selected_timeframes": + return "多周期时间框架:最多 4 个,例如 5m,15m,1h" + case "btceth_max_leverage": + return "BTC/ETH 最大杠杆:1~20 倍" + case "altcoin_max_leverage": + return "山寨币最大杠杆:1~20 倍" + case "min_confidence": + return "最低置信度:50~100,越高越谨慎" + case "min_risk_reward_ratio": + return "最小盈亏比:1~10,步进 0.5" + case "trading_frequency": + return "交易频率规则:文本,例如“每天最多 2~4 笔,避免连续追单”" + case "entry_standards": + return "开仓标准:文本,例如“趋势明确、成交量配合、风险收益合理才开仓”" + case "symbol": + return "交易对:BTCUSDT / ETHUSDT / SOLUSDT / BNBUSDT / XRPUSDT / DOGEUSDT" + case "grid_count": + return "网格数量:5~50" + case "total_investment": + return "总投入:用户实际投入/保证金预算,最低 100 USDT;不是杠杆后的名义仓位" + case "leverage": + return "杠杆:1~5 倍" + case "distribution": + return "网格分布:uniform(均匀)/ gaussian(正态)/ pyramid(金字塔)" + case "max_drawdown_pct": + return "最大回撤:5%~50%" + case "stop_loss_pct": + return "止损:1%~20%" + case "daily_loss_limit_pct": + return "日亏损限制:1%~30%" + case "use_maker_only": + return "只挂 Maker:开启 / 关闭" + case "use_atr_bounds 或 upper_price/lower_price": + return "价格边界:开启 ATR 自动边界,或手动填写上边界/下边界" + } + return "" +} + +func formatStrategyCreateFieldOptionsReply(lang, text, missingKind string) string { + if !strategyCreateAsksFieldOptions(text) { + return "" + } + field := firstStrategyMissingField(missingKind) + if field == "" { + return "" + } + if lang != "zh" { + switch field { + case "source_type": + return "Coin source options: ai500, oi_top, oi_low, or static. Pick one and I will continue filling the AI strategy template." + case "primary_timeframe", "selected_timeframes": + return "Timeframe options: 1m, 3m, 5m, 15m, 30m, 1h, 2h, 4h, 6h, 8h, 12h, 1d, 3d, 1w." + } + return "For " + strategyCreateFieldDisplayName(lang, field) + ", tell me the value you want and I will keep filling the selected strategy template." + } + switch field { + case "strategy_type": + return "策略类型只有两个:\n- AI 策略:让 AI 根据行情和策略规则判断开平仓。\n- 网格策略:在价格区间内按网格低买高卖。\n你直接回复“AI 策略”或“网格策略”就行。" + case "source_type": + return "AI 策略的选币来源有 4 个:\n- AI500:从 NOFX AI500 榜单自动选币。\n- OI Top:选持仓量靠前/更活跃的币。\n- OI Low:选持仓量较低或变化较弱的币。\n- 静态币种:你指定固定币种,比如 BTCUSDT、ETHUSDT。\n没有混合模式。你选一个,我继续填模板。" + case "primary_timeframe": + return "主周期可选:1m、3m、5m、15m、30m、1h、2h、4h、6h、8h、12h、1d、3d、1w。高频一般偏 1m/3m/5m,稳健一点可以用 15m/1h。" + case "selected_timeframes": + return "多周期最多选 4 个,可选:1m、3m、5m、15m、30m、1h、2h、4h、6h、8h、12h、1d、3d、1w。常见组合比如 5m,15m,1h。" + case "btceth_max_leverage", "altcoin_max_leverage": + return strategyCreateFieldDisplayName(lang, field) + "范围是 1~20 倍。数值越高风险越大。" + case "min_confidence": + return "最低置信度范围是 50~100。数值越高越谨慎,开单会更少。" + case "min_risk_reward_ratio": + return "最小盈亏比范围是 1~10,步进 0.5。比如 1.5 表示预期收益至少是风险的 1.5 倍。" + case "trading_frequency": + return "交易频率规则是文本规则,例如“每天最多 2~4 笔,避免连续追单”。你也可以说“你帮我按高频但不过度交易来写”。" + case "entry_standards": + return "开仓标准是文本规则,例如“只在趋势明确、成交量配合、风险收益合理时开仓”。你也可以说“你帮我写一版稳健开仓标准”。" + case "symbol": + return "网格交易对可选:BTCUSDT、ETHUSDT、SOLUSDT、BNBUSDT、XRPUSDT、DOGEUSDT。" + case "grid_count": + return "网格数量范围是 5~50。数量越多越密,交易更频繁;数量越少,每格空间更大。" + case "total_investment": + return "网格总投入是用户实际投入/保证金预算,不是杠杆后的名义仓位;最小 100 USDT,按 100 USDT 步进。" + case "leverage": + return "网格杠杆范围是 1~5 倍。稳健一般用 1 倍。" + case "distribution": + return "网格分布可选:uniform(均匀)、gaussian(正态)、pyramid(金字塔)。" + case "max_drawdown_pct": + return "最大回撤范围是 5%~50%。" + case "stop_loss_pct": + return "止损范围是 1%~20%。" + case "daily_loss_limit_pct": + return "日亏损限制范围是 1%~30%。" + case "use_maker_only": + return "只挂 Maker 是开关项:开启会更偏向低手续费挂单,成交可能慢一些;关闭则更灵活。" + } + return strategyCreateFieldDisplayName(lang, field) + "是当前模板字段。你告诉我想怎么设置,我继续填模板。" +} + +func strategyCreateAsksFieldOptions(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return false + } + return containsAny(lower, []string{ + "有哪些", "有什么", "可选", "选项", "怎么选", "怎么填", "不知道", "不会填", + "what options", "which options", "options", "how to choose", "how should i fill", + }) +} + +func firstStrategyMissingField(missingKind string) string { + for _, part := range strings.Split(missingKind, ",") { + part = strings.TrimSpace(part) + if part != "" { + return part + } + } + return "" +} + +func formatStrategyMissingFieldNames(lang, missingKind string) string { + parts := strings.Split(missingKind, ",") + names := make([]string, 0, len(parts)) + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + if strings.Contains(part, "或") || strings.Contains(part, "/") { + names = append(names, part) + continue + } + names = append(names, strategyCreateFieldDisplayName(lang, part)) + } + if lang == "zh" { + return strings.Join(names, "、") + } + return strings.Join(names, ", ") +} + +func strategyCreateFieldDisplayName(lang, field string) string { + if lang != "zh" { + return field + } + switch strings.TrimSpace(field) { + case "source_type": + return "选币来源" + case "static_coins": + return "静态币种" + case "primary_timeframe": + return "主周期" + case "selected_timeframes": + return "多周期时间框架" + case "btceth_max_leverage": + return "BTC/ETH 最大杠杆" + case "altcoin_max_leverage": + return "山寨币最大杠杆" + case "min_confidence": + return "最低置信度" + case "min_risk_reward_ratio": + return "最小盈亏比" + case "trading_frequency": + return "交易频率规则" + case "entry_standards": + return "开仓标准" + case "symbol": + return "交易对" + case "grid_count": + return "网格数量" + case "total_investment": + return "总投入" + case "leverage": + return "杠杆" + case "distribution": + return "网格分布" + case "max_drawdown_pct": + return "最大回撤" + case "stop_loss_pct": + return "止损" + case "daily_loss_limit_pct": + return "日亏损限制" + case "use_maker_only": + return "只挂 Maker" + default: + return field + } +} + +func formatStrategyCreateDraftSummary(lang, name, strategyType string, changedFields, warnings []string) string { + name = strings.TrimSpace(name) + if name == "" { + if lang == "zh" { + name = "未命名策略" + } else { + name = "unnamed strategy" + } + } + if lang == "zh" { + lines := []string{ + fmt.Sprintf("我先把策略草稿整理成了“%s”。", name), + } + if len(changedFields) > 0 { + lines = append(lines, "我已经识别到这些配置意图:") + for _, field := range changedFields { + lines = append(lines, "- "+field) + } + } + if len(warnings) > 0 { + lines = append(lines, "其中有些参数超出了当前安全范围,我先拦下来了:") + for _, warning := range warnings { + lines = append(lines, "- "+warning) + } + lines = append(lines, "你可以继续告诉我其他字段怎么设计;如果接受当前安全范围,也可以直接回复“确认创建”。") + return strings.Join(lines, "\n") + } + switch strategyType { + case "grid_trading": + lines = append(lines, "这是网格策略草稿。请继续补充交易对、网格数量、总投入、杠杆、价格区间和网格风控;我只会按产品编辑页模板填你明确给出或明确委托我设计的字段。") + case "ai_trading": + lines = append(lines, "这是 AI 策略草稿。请继续补充选币来源、时间周期、风险参数和提示词方向;我只会按产品编辑页模板填你明确给出或明确委托我设计的字段。") + default: + lines = append(lines, "你可以继续补充策略类型和对应参数;如果现在就创建,直接回复“确认创建”。") + } + return strings.Join(lines, "\n") + } + + lines := []string{ + fmt.Sprintf("I turned that into a draft strategy named %q.", name), + } + if len(changedFields) > 0 { + lines = append(lines, "Recognized fields:") + for _, field := range changedFields { + lines = append(lines, "- "+field) + } + } + if len(warnings) > 0 { + lines = append(lines, "Some values exceeded the current safety limits, so I stopped before creating it:") + for _, warning := range warnings { + lines = append(lines, "- "+warning) + } + lines = append(lines, "You can keep refining the draft, or reply 'confirm' to create it with the safe adjusted values.") + return strings.Join(lines, "\n") + } + switch strategyType { + case "grid_trading": + lines = append(lines, "This is a grid strategy draft. Keep refining symbol, grid count, total investment, leverage, price bounds, and grid risk settings; I will only fill fields you explicitly provide or ask me to design.") + case "ai_trading": + lines = append(lines, "This is an AI strategy draft. Keep refining coin source, timeframes, risk settings, and prompt direction; I will only fill fields you explicitly provide or ask me to design.") + default: + lines = append(lines, "You can keep refining the strategy type and matching parameters, or reply 'confirm' to create it now.") + } + return strings.Join(lines, "\n") +} + +func formatStrategyCreateFinalConfirmation(lang string, session skillSession, cfg store.StrategyConfig) string { + name := defaultIfEmpty(fieldValue(session, "name"), "未命名策略") + if lang != "zh" { + name = defaultIfEmpty(fieldValue(session, "name"), "unnamed strategy") + } + if lang == "zh" { + lines := []string{fmt.Sprintf("我已经把“%s”的配置整理好了,确认后我再创建到策略列表。", name)} + switch cfg.StrategyType { + case "grid_trading": + grid := cfg.GridConfig + if grid == nil { + grid = &store.GridStrategyConfig{} + } + lines = append(lines, + "- 类型:网格策略", + fmt.Sprintf("- 发布到策略市场:%t", fieldValue(session, "is_public") == "true"), + fmt.Sprintf("- 发布后配置可见:%t", fieldValue(session, "config_visible") != "false"), + fmt.Sprintf("- 交易对:%s", defaultIfEmpty(grid.Symbol, "未设置")), + fmt.Sprintf("- 网格数量:%d", grid.GridCount), + fmt.Sprintf("- 总投入:%.2f USDT", grid.TotalInvestment), + fmt.Sprintf("- 杠杆:%d倍", grid.Leverage), + ) + if grid.UseATRBounds { + lines = append(lines, fmt.Sprintf("- 价格区间:ATR 动态范围(倍数 %.2f)", grid.ATRMultiplier)) + } else { + lines = append(lines, fmt.Sprintf("- 价格区间:%.2f ~ %.2f", grid.LowerPrice, grid.UpperPrice)) + } + lines = append(lines, + fmt.Sprintf("- 网格分布:%s", defaultIfEmpty(grid.Distribution, "uniform")), + fmt.Sprintf("- 最大回撤:%.2f%%", grid.MaxDrawdownPct), + fmt.Sprintf("- 止损:%.2f%%", grid.StopLossPct), + fmt.Sprintf("- 日亏损限制:%.2f%%", grid.DailyLossLimitPct), + ) + default: + lines = append(lines, + "- 类型:AI 策略", + fmt.Sprintf("- 发布到策略市场:%t", fieldValue(session, "is_public") == "true"), + fmt.Sprintf("- 发布后配置可见:%t", fieldValue(session, "config_visible") != "false"), + fmt.Sprintf("- 选币来源:%s", defaultIfEmpty(cfg.CoinSource.SourceType, "未设置")), + ) + lines = append(lines, formatAICoinSourceSummaryZH(cfg)...) + lines = append(lines, + fmt.Sprintf("- 主周期:%s", defaultIfEmpty(cfg.Indicators.Klines.PrimaryTimeframe, "未设置")), + fmt.Sprintf("- K线数量:%d", cfg.Indicators.Klines.PrimaryCount), + fmt.Sprintf("- 多周期:%s", defaultIfEmpty(strings.Join(cfg.Indicators.Klines.SelectedTimeframes, ","), "未设置")), + fmt.Sprintf("- 指标:%s", formatEnabledAIIndicatorsZH(cfg)), + fmt.Sprintf("- NofxOS 量化数据:%t", cfg.Indicators.EnableQuantData), + fmt.Sprintf("- OI 排行数据:%t(%s / %d)", cfg.Indicators.EnableOIRanking, defaultIfEmpty(cfg.Indicators.OIRankingDuration, "未设置"), cfg.Indicators.OIRankingLimit), + fmt.Sprintf("- 资金流排行数据:%t(%s / %d)", cfg.Indicators.EnableNetFlowRanking, defaultIfEmpty(cfg.Indicators.NetFlowRankingDuration, "未设置"), cfg.Indicators.NetFlowRankingLimit), + fmt.Sprintf("- 涨跌幅排行数据:%t(%s / %d)", cfg.Indicators.EnablePriceRanking, defaultIfEmpty(cfg.Indicators.PriceRankingDuration, "未设置"), cfg.Indicators.PriceRankingLimit), + fmt.Sprintf("- BTC/ETH 最大杠杆:%d倍", cfg.RiskControl.BTCETHMaxLeverage), + fmt.Sprintf("- 山寨币最大杠杆:%d倍", cfg.RiskControl.AltcoinMaxLeverage), + fmt.Sprintf("- 最小置信度:%d", cfg.RiskControl.MinConfidence), + fmt.Sprintf("- 最小盈亏比:%.2f", cfg.RiskControl.MinRiskRewardRatio), + fmt.Sprintf("- 最大持仓数(System enforced):%d", cfg.RiskControl.MaxPositions), + fmt.Sprintf("- BTC/ETH 单币仓位上限(System enforced):账户权益 %.2f 倍", cfg.RiskControl.BTCETHMaxPositionValueRatio), + fmt.Sprintf("- 山寨币单币仓位上限(System enforced):账户权益 %.2f 倍", cfg.RiskControl.AltcoinMaxPositionValueRatio), + fmt.Sprintf("- 最大保证金使用率(System enforced):%.0f%%", cfg.RiskControl.MaxMarginUsage*100), + fmt.Sprintf("- 最小开仓金额(System enforced):%.2f USDT", cfg.RiskControl.MinPositionSize), + fmt.Sprintf("- 角色定义:%s", compactSummaryText(cfg.PromptSections.RoleDefinition)), + fmt.Sprintf("- 交易频率规则:%s", compactSummaryText(cfg.PromptSections.TradingFrequency)), + fmt.Sprintf("- 开仓标准:%s", compactSummaryText(cfg.PromptSections.EntryStandards)), + fmt.Sprintf("- 决策流程:%s", compactSummaryText(cfg.PromptSections.DecisionProcess)), + fmt.Sprintf("- 自定义 Prompt:%s", compactSummaryText(cfg.CustomPrompt)), + ) + } + lines = append(lines, "确认创建的话,直接回复“确认创建”。要调整也可以直接说改哪项。") + return strings.Join(lines, "\n") + } + lines := []string{fmt.Sprintf("I prepared the config for %q. Confirm and I will create it in the strategy list.", name)} + if cfg.StrategyType == "grid_trading" && cfg.GridConfig != nil { + grid := cfg.GridConfig + lines = append(lines, + "- Type: grid strategy", + fmt.Sprintf("- Symbol: %s", defaultIfEmpty(grid.Symbol, "unset")), + fmt.Sprintf("- Grid count: %d", grid.GridCount), + fmt.Sprintf("- Total investment: %.2f USDT", grid.TotalInvestment), + fmt.Sprintf("- Leverage: %dx", grid.Leverage), + ) + } else { + lines = append(lines, "- Type: AI strategy") + } + lines = append(lines, "Reply 'confirm create' to create it, or tell me what to change.") + return strings.Join(lines, "\n") +} + +func formatEnabledAIIndicatorsZH(cfg store.StrategyConfig) string { + enabled := make([]string, 0, 8) + if cfg.Indicators.EnableRawKlines { + enabled = append(enabled, "K线") + } + if cfg.Indicators.EnableVolume { + enabled = append(enabled, "成交量") + } + if cfg.Indicators.EnableOI { + enabled = append(enabled, "OI") + } + if cfg.Indicators.EnableFundingRate { + enabled = append(enabled, "资金费率") + } + if cfg.Indicators.EnableEMA { + enabled = append(enabled, "EMA") + } + if cfg.Indicators.EnableMACD { + enabled = append(enabled, "MACD") + } + if cfg.Indicators.EnableRSI { + enabled = append(enabled, "RSI") + } + if cfg.Indicators.EnableATR { + enabled = append(enabled, "ATR") + } + if cfg.Indicators.EnableBOLL { + enabled = append(enabled, "BOLL") + } + if len(enabled) == 0 { + return "无" + } + return strings.Join(enabled, ",") +} + +func formatAICoinSourceSummaryZH(cfg store.StrategyConfig) []string { + lines := make([]string, 0, 4) + sourceType := strings.ToLower(strings.TrimSpace(cfg.CoinSource.SourceType)) + switch sourceType { + case "static": + lines = append(lines, fmt.Sprintf("- 静态币种:%s", defaultIfEmpty(strings.Join(cfg.CoinSource.StaticCoins, ","), "未设置"))) + case "ai500": + lines = append(lines, fmt.Sprintf("- AI500 数量:%d", cfg.CoinSource.AI500Limit)) + case "oi_top": + lines = append(lines, fmt.Sprintf("- OI Top 数量:%d", cfg.CoinSource.OITopLimit)) + case "oi_low": + lines = append(lines, fmt.Sprintf("- OI Low 数量:%d", cfg.CoinSource.OILowLimit)) + default: + if cfg.CoinSource.UseAI500 { + lines = append(lines, fmt.Sprintf("- AI500 数量:%d", cfg.CoinSource.AI500Limit)) + } + if cfg.CoinSource.UseOITop { + lines = append(lines, fmt.Sprintf("- OI Top 数量:%d", cfg.CoinSource.OITopLimit)) + } + if cfg.CoinSource.UseOILow { + lines = append(lines, fmt.Sprintf("- OI Low 数量:%d", cfg.CoinSource.OILowLimit)) + } + } + if len(cfg.CoinSource.ExcludedCoins) > 0 { + lines = append(lines, fmt.Sprintf("- 排除币种:%s", strings.Join(cfg.CoinSource.ExcludedCoins, ","))) + } + return lines +} + +func compactSummaryText(value string) string { + value = strings.Join(strings.Fields(strings.TrimSpace(value)), " ") + if value == "" { + return "未设置" + } + const maxLen = 120 + runes := []rune(value) + if len(runes) <= maxLen { + return value + } + return string(runes[:maxLen]) + "..." +} + +func createConfirmationReply(text string) bool { + return strategyCreateConfirmationReply(text) +} + +func formatMissingFieldList(lang string, fields []string) string { + if len(fields) == 0 { + return "" + } + if lang == "zh" { + return strings.Join(fields, "、") + } + return strings.Join(fields, ", ") +} + +func availableModelProvidersMessage(lang string) string { + return modelProviderChoicePrompt(lang) +} + +func inferCreateDisplayName(text string) string { + clean := func(value string) string { + value = strings.TrimSpace(value) + value = strings.Trim(value, "“”\"':: ,,。.;;") + for _, sep := range []string{",", ",", "。", ";", ";", "\n"} { + if idx := strings.Index(value, sep); idx >= 0 { + value = strings.TrimSpace(value[:idx]) + } + } + for _, marker := range []string{" 交易所", " 模型", " 策略", " exchange", " model", " strategy"} { + if idx := strings.Index(value, marker); idx >= 0 { + value = strings.TrimSpace(value[:idx]) + } + } + for _, suffix := range []string{"的交易员", "的模型", "的策略", "的交易所", "这个交易员", "这个模型", "这个策略", "这个交易所"} { + if strings.HasSuffix(value, suffix) { + value = strings.TrimSpace(strings.TrimSuffix(value, suffix)) + } + } + return strings.TrimSpace(value) + } + if value := extractDelimitedSegmentAfterKeywords(text, []string{"名称叫", "名字叫", "配置名", "叫", "名为", "名称", "名字是", "called"}); value != "" { + return clean(value) + } + if value := extractQuotedContent(text); value != "" && !containsAny(strings.ToLower(text), []string{"api key", "apikey", "api_key", "secret", "passphrase"}) { + return clean(value) + } + return "" +} + +func formatModelCreateDraftSummary(lang string, session skillSession) string { + providerID := fieldValue(session, "provider") + name := defaultIfEmpty(fieldValue(session, "name"), defaultIfEmpty(defaultModelConfigName(providerID), "未命名模型")) + provider := defaultIfEmpty(providerID, "未选择") + modelName := defaultIfEmpty(fieldValue(session, "custom_model_name"), defaultIfEmpty(defaultModelNameForProvider(providerID), "未设置")) + apiURL := defaultIfEmpty(fieldValue(session, "custom_api_url"), "默认官方地址") + if lang != "zh" { + apiURL = defaultIfEmpty(fieldValue(session, "custom_api_url"), "provider default endpoint") + } + enabled := fieldValue(session, "enabled") != "false" + if lang == "zh" { + lines := []string{ + fmt.Sprintf("我先整理了一份模型配置草稿“%s”。", name), + fmt.Sprintf("- Provider:%s", provider), + fmt.Sprintf("- 配置名称:%s", name), + fmt.Sprintf("- 模型名称:%s", modelName), + fmt.Sprintf("- 接口地址:%s", apiURL), + fmt.Sprintf("- 启用状态:%t(未指定时默认 true)", enabled), + modelProviderDetailedGuidance(lang, providerID), + "如果这些字段没问题,直接回复“确认创建”;也可以继续补充或修改任意字段。", + } + return strings.Join(lines, "\n") + } + lines := []string{ + fmt.Sprintf("I prepared a draft model config %q.", name), + fmt.Sprintf("- Provider: %s", provider), + fmt.Sprintf("- Config name: %s", name), + fmt.Sprintf("- Model name: %s", modelName), + fmt.Sprintf("- API URL: %s", apiURL), + fmt.Sprintf("- Enabled: %t (defaults to true if omitted)", enabled), + modelProviderDetailedGuidance(lang, providerID), + "Reply 'confirm' to create it, or keep refining any field.", + } + return strings.Join(lines, "\n") +} + +func formatExchangeCreateDraftSummary(lang string, session skillSession) string { + exType := defaultIfEmpty(fieldValue(session, "exchange_type"), "未选择") + accountName := defaultIfEmpty(fieldValue(session, "account_name"), "未命名账户") + enabled := fieldValue(session, "enabled") != "false" + testnet := fieldValue(session, "testnet") == "true" + if lang == "zh" { + lines := []string{ + fmt.Sprintf("我先整理了一份交易所配置草稿“%s”。", accountName), + fmt.Sprintf("- 交易所:%s", exType), + fmt.Sprintf("- 账户名:%s", accountName), + fmt.Sprintf("- 启用状态:%t(未指定时默认 true)", enabled), + fmt.Sprintf("- 测试网:%t(未指定时默认 false)", testnet), + } + switch exType { + case "binance", "bybit", "gate", "indodax": + lines = append(lines, + fmt.Sprintf("- 已提供 API Key:%t", fieldValue(session, "api_key") != ""), + fmt.Sprintf("- 已提供 Secret:%t", fieldValue(session, "secret_key") != ""), + ) + case "okx", "bitget", "kucoin": + lines = append(lines, + fmt.Sprintf("- 已提供 API Key:%t", fieldValue(session, "api_key") != ""), + fmt.Sprintf("- 已提供 Secret:%t", fieldValue(session, "secret_key") != ""), + fmt.Sprintf("- 已提供 Passphrase:%t", fieldValue(session, "passphrase") != ""), + ) + case "hyperliquid": + lines = append(lines, + fmt.Sprintf("- 已提供 API Key:%t", fieldValue(session, "api_key") != ""), + fmt.Sprintf("- Hyperliquid 钱包地址:%s", defaultIfEmpty(fieldValue(session, "hyperliquid_wallet_addr"), "未设置")), + ) + case "aster": + lines = append(lines, + fmt.Sprintf("- Aster User:%s", defaultIfEmpty(fieldValue(session, "aster_user"), "未设置")), + fmt.Sprintf("- Aster Signer:%s", defaultIfEmpty(fieldValue(session, "aster_signer"), "未设置")), + fmt.Sprintf("- 已提供 Aster 私钥:%t", fieldValue(session, "aster_private_key") != ""), + ) + case "lighter": + lines = append(lines, + fmt.Sprintf("- Lighter 钱包地址:%s", defaultIfEmpty(fieldValue(session, "lighter_wallet_addr"), "未设置")), + fmt.Sprintf("- 已提供 Lighter API Key 私钥:%t", fieldValue(session, "lighter_api_key_private_key") != ""), + ) + if value := fieldValue(session, "lighter_api_key_index"); value != "" { + lines = append(lines, fmt.Sprintf("- Lighter API Key Index:%s", value)) + } + default: + lines = append(lines, + fmt.Sprintf("- 已提供 API Key:%t", fieldValue(session, "api_key") != ""), + fmt.Sprintf("- 已提供 Secret:%t", fieldValue(session, "secret_key") != ""), + ) + } + lines = append(lines, "如果这些字段没问题,直接回复“确认创建”;也可以继续补充或修改任意字段。") + return strings.Join(lines, "\n") + } + lines := []string{ + fmt.Sprintf("I prepared a draft exchange config %q.", accountName), + fmt.Sprintf("- Exchange: %s", exType), + fmt.Sprintf("- Account name: %s", accountName), + fmt.Sprintf("- Enabled: %t (defaults to true if omitted)", enabled), + fmt.Sprintf("- Testnet: %t (defaults to false if omitted)", testnet), + } + switch exType { + case "binance", "bybit", "gate", "indodax": + lines = append(lines, + fmt.Sprintf("- API key provided: %t", fieldValue(session, "api_key") != ""), + fmt.Sprintf("- Secret provided: %t", fieldValue(session, "secret_key") != ""), + ) + case "okx", "bitget", "kucoin": + lines = append(lines, + fmt.Sprintf("- API key provided: %t", fieldValue(session, "api_key") != ""), + fmt.Sprintf("- Secret provided: %t", fieldValue(session, "secret_key") != ""), + fmt.Sprintf("- Passphrase provided: %t", fieldValue(session, "passphrase") != ""), + ) + case "hyperliquid": + lines = append(lines, + fmt.Sprintf("- API key provided: %t", fieldValue(session, "api_key") != ""), + fmt.Sprintf("- Hyperliquid wallet address: %s", defaultIfEmpty(fieldValue(session, "hyperliquid_wallet_addr"), "not set")), + ) + case "aster": + lines = append(lines, + fmt.Sprintf("- Aster user: %s", defaultIfEmpty(fieldValue(session, "aster_user"), "not set")), + fmt.Sprintf("- Aster signer: %s", defaultIfEmpty(fieldValue(session, "aster_signer"), "not set")), + fmt.Sprintf("- Aster private key provided: %t", fieldValue(session, "aster_private_key") != ""), + ) + case "lighter": + lines = append(lines, + fmt.Sprintf("- Lighter wallet address: %s", defaultIfEmpty(fieldValue(session, "lighter_wallet_addr"), "not set")), + fmt.Sprintf("- Lighter API key private key provided: %t", fieldValue(session, "lighter_api_key_private_key") != ""), + ) + if value := fieldValue(session, "lighter_api_key_index"); value != "" { + lines = append(lines, fmt.Sprintf("- Lighter API key index: %s", value)) + } + default: + lines = append(lines, + fmt.Sprintf("- API key provided: %t", fieldValue(session, "api_key") != ""), + fmt.Sprintf("- Secret provided: %t", fieldValue(session, "secret_key") != ""), + ) + } + lines = append(lines, "Reply 'confirm' to create it, or keep refining any field.") + return strings.Join(lines, "\n") +} + +func formatTraderCreateDraftSummary(lang string, session skillSession) string { + args := buildTraderUpdateArgsFromSession(session) + args, warnings := normalizeTraderArgsToManualLimits(lang, args) + scanInterval := 3 + if args.ScanIntervalMinutes != nil && *args.ScanIntervalMinutes > 0 { + scanInterval = *args.ScanIntervalMinutes + } + isCrossMargin := true + if args.IsCrossMargin != nil { + isCrossMargin = *args.IsCrossMargin + } + showInCompetition := true + if args.ShowInCompetition != nil { + showInCompetition = *args.ShowInCompetition + } + autoStart := fieldValue(session, "auto_start") == "true" + name := defaultIfEmpty(fieldValue(session, "name"), "未命名交易员") + if lang != "zh" { + name = defaultIfEmpty(fieldValue(session, "name"), "unnamed trader") + } + if lang == "zh" { + lines := []string{ + fmt.Sprintf("我先整理了一份交易员草稿“%s”。", name), + fmt.Sprintf("- 名称:%s", name), + fmt.Sprintf("- 交易所:%s", traderCreateExchangeNameOrID(session)), + fmt.Sprintf("- 模型:%s", traderCreateModelNameOrID(session)), + fmt.Sprintf("- 策略:%s", traderCreateStrategyNameOrID(session)), + fmt.Sprintf("- 扫描间隔:%d 分钟(未指定时默认 3)", scanInterval), + "- 初始余额:创建时由系统自动读取绑定交易所账户净值", + fmt.Sprintf("- 全仓模式:%t(未指定时默认 true)", isCrossMargin), + fmt.Sprintf("- 竞技场显示:%t(未指定时默认 true)", showInCompetition), + } + if autoStart { + lines = append(lines, "- 创建后立即启动:true") + if len(warnings) > 0 { + lines = append(lines, "这些字段里有超出手动面板范围的值,我已经先按风控范围收敛:") + for _, warning := range warnings { + lines = append(lines, "- "+warning) + } + } + lines = append(lines, "如果这些字段没问题,直接回复“确认创建并启动”;也可以继续补充或修改任意字段。") + } else { + if len(warnings) > 0 { + lines = append(lines, "这些字段里有超出手动面板范围的值,我已经先按风控范围收敛:") + for _, warning := range warnings { + lines = append(lines, "- "+warning) + } + } + lines = append(lines, "如果这些字段没问题,直接回复“确认创建”;也可以继续补充或修改任意字段。") + } + return strings.Join(lines, "\n") + } + lines := []string{ + fmt.Sprintf("I prepared a draft trader %q.", name), + fmt.Sprintf("- Name: %s", name), + fmt.Sprintf("- Exchange: %s", traderCreateExchangeNameOrID(session)), + fmt.Sprintf("- Model: %s", traderCreateModelNameOrID(session)), + fmt.Sprintf("- Strategy: %s", traderCreateStrategyNameOrID(session)), + fmt.Sprintf("- Scan interval: %d minutes (defaults to 3)", scanInterval), + "- Initial balance: auto-read from the bound exchange account equity at creation time", + fmt.Sprintf("- Cross margin: %t (defaults to true)", isCrossMargin), + fmt.Sprintf("- Show in competition: %t (defaults to true)", showInCompetition), + } + if autoStart { + lines = append(lines, "- Start immediately after creation: true") + if len(warnings) > 0 { + lines = append(lines, "Some values exceeded the manual editor limits, so I normalized them first:") + for _, warning := range warnings { + lines = append(lines, "- "+warning) + } + } + lines = append(lines, "Reply 'confirm' to create and start it, or keep refining any field.") + } else { + if len(warnings) > 0 { + lines = append(lines, "Some values exceeded the manual editor limits, so I normalized them first:") + for _, warning := range warnings { + lines = append(lines, "- "+warning) + } + } + lines = append(lines, "Reply 'confirm' to create it, or keep refining any field.") + } + return strings.Join(lines, "\n") +} + +func hasExplicitStrategyDetailIntent(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return false + } + if !hasExplicitManagementDomainCue(text, "strategy") { + return false + } + return containsAny(lower, []string{ + "什么样", "怎么样", "详情", "详细", "prompt", "提示词", + "哪个策略", "哪一个策略", "你改的是哪个策略", "你把哪个策略", + "what kind", "details", "detail", "prompt", "which strategy", + }) +} + +func shouldPreferStrategyQueryDetail(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return false + } + if !containsAny(lower, []string{"?", "?", "哪个", "哪一个", "哪条", "which"}) { + return false + } + return containsAny(lower, []string{"策略", "strategy"}) +} + +func shouldExplainStrategyRuntimeBoundary(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return false + } + if !containsAny(lower, []string{"策略", "strategy"}) { + return false + } + if !containsAny(lower, []string{"启动", "运行", "run", "start", "deploy"}) { + return false + } + if containsAny(lower, []string{"交易员", "trader", "机器人", "bot"}) { + return false + } + return true +} + func wantsDefaultStrategyConfig(text string) bool { lower := strings.ToLower(strings.TrimSpace(text)) if lower == "" { @@ -396,6 +1653,9 @@ func formatStrategyDetailResponse(lang string, strategy *store.Strategy, cfg sto if len(cfg.CoinSource.StaticCoins) > 0 { sourceBits = append(sourceBits, "static="+strings.Join(cfg.CoinSource.StaticCoins, ",")) } + if len(cfg.CoinSource.ExcludedCoins) > 0 { + sourceBits = append(sourceBits, "excluded="+strings.Join(cfg.CoinSource.ExcludedCoins, ",")) + } timeframes := append([]string(nil), cfg.Indicators.Klines.SelectedTimeframes...) if len(timeframes) == 0 { @@ -453,15 +1713,43 @@ func formatStrategyDetailResponse(lang string, strategy *store.Strategy, cfg sto customPromptPreview = string(runes[:120]) + "..." } + publishStatusZh := "未发布" + publishStatusEn := "private" + if strategy.IsPublic { + publishStatusZh = "已发布到市场" + publishStatusEn = "public" + } + configVisibleZh := "隐藏" + configVisibleEn := "hidden" + if strategy.ConfigVisible { + configVisibleZh = "可见" + configVisibleEn = "visible" + } + if lang == "zh" { lines := []string{ fmt.Sprintf("策略“%s”概览:", name), fmt.Sprintf("- 类型:%s", defaultIfEmpty(strings.TrimSpace(cfg.StrategyType), "ai_trading")), fmt.Sprintf("- 语言:%s", defaultIfEmpty(strings.TrimSpace(cfg.Language), "zh")), + fmt.Sprintf("- 发布设置:%s;配置%s", publishStatusZh, configVisibleZh), } if strings.TrimSpace(strategy.Description) != "" { lines = append(lines, fmt.Sprintf("- 描述:%s", strings.TrimSpace(strategy.Description))) } + if cfg.GridConfig != nil { + lines = append(lines, fmt.Sprintf("- 网格参数:交易对 %s;网格 %d;总投资 %.2f;杠杆 %d;分布 %s", + defaultIfEmpty(strings.TrimSpace(cfg.GridConfig.Symbol), "未设置"), + cfg.GridConfig.GridCount, + cfg.GridConfig.TotalInvestment, + cfg.GridConfig.Leverage, + defaultIfEmpty(strings.TrimSpace(cfg.GridConfig.Distribution), "未设置"), + )) + if cfg.GridConfig.UseATRBounds { + lines = append(lines, fmt.Sprintf("- 网格边界:ATR 自动边界,倍数 %.2f", cfg.GridConfig.ATRMultiplier)) + } else if cfg.GridConfig.UpperPrice > 0 || cfg.GridConfig.LowerPrice > 0 { + lines = append(lines, fmt.Sprintf("- 网格边界:上沿 %.4f,下沿 %.4f", cfg.GridConfig.UpperPrice, cfg.GridConfig.LowerPrice)) + } + } if len(sourceBits) > 0 { lines = append(lines, "- 标的来源:"+strings.Join(sourceBits, " | ")) } @@ -470,9 +1758,20 @@ func formatStrategyDetailResponse(lang string, strategy *store.Strategy, cfg sto } lines = append(lines, fmt.Sprintf("- 仓位风险:最多持仓 %d,BTC/ETH 最大杠杆 %d,山寨最大杠杆 %d,最低置信度 %d", cfg.RiskControl.MaxPositions, cfg.RiskControl.BTCETHMaxLeverage, cfg.RiskControl.AltcoinMaxLeverage, cfg.RiskControl.MinConfidence)) + lines = append(lines, fmt.Sprintf("- 风控阈值:最小盈亏比 %.2f;最大保证金使用率 %.2f;最小开仓金额 %.2f", + cfg.RiskControl.MinRiskRewardRatio, cfg.RiskControl.MaxMarginUsage, cfg.RiskControl.MinPositionSize)) if len(indicatorBits) > 0 { lines = append(lines, "- 已启用指标:"+strings.Join(indicatorBits, "、")) } + if strings.TrimSpace(cfg.Indicators.NofxOSAPIKey) != "" || cfg.Indicators.EnableQuantData || cfg.Indicators.EnableOIRanking || cfg.Indicators.EnableNetFlowRanking || cfg.Indicators.EnablePriceRanking { + lines = append(lines, fmt.Sprintf("- NofxOS 数据:API Key=%t,量化数据=%t,OI 排行=%t,净流入排行=%t,价格排行=%t", + strings.TrimSpace(cfg.Indicators.NofxOSAPIKey) != "", + cfg.Indicators.EnableQuantData, + cfg.Indicators.EnableOIRanking, + cfg.Indicators.EnableNetFlowRanking, + cfg.Indicators.EnablePriceRanking, + )) + } if len(promptBits) > 0 { lines = append(lines, "- Prompt 模块:"+strings.Join(promptBits, "、")) } @@ -489,10 +1788,25 @@ func formatStrategyDetailResponse(lang string, strategy *store.Strategy, cfg sto fmt.Sprintf("Strategy %q overview:", name), fmt.Sprintf("- Type: %s", defaultIfEmpty(strings.TrimSpace(cfg.StrategyType), "ai_trading")), fmt.Sprintf("- Language: %s", defaultIfEmpty(strings.TrimSpace(cfg.Language), "en")), + fmt.Sprintf("- Publish settings: %s; config %s", publishStatusEn, configVisibleEn), } if strings.TrimSpace(strategy.Description) != "" { lines = append(lines, fmt.Sprintf("- Description: %s", strings.TrimSpace(strategy.Description))) } + if cfg.GridConfig != nil { + lines = append(lines, fmt.Sprintf("- Grid config: symbol %s; grids %d; investment %.2f; leverage %d; distribution %s", + defaultIfEmpty(strings.TrimSpace(cfg.GridConfig.Symbol), "not set"), + cfg.GridConfig.GridCount, + cfg.GridConfig.TotalInvestment, + cfg.GridConfig.Leverage, + defaultIfEmpty(strings.TrimSpace(cfg.GridConfig.Distribution), "not set"), + )) + if cfg.GridConfig.UseATRBounds { + lines = append(lines, fmt.Sprintf("- Grid bounds: ATR auto bounds with multiplier %.2f", cfg.GridConfig.ATRMultiplier)) + } else if cfg.GridConfig.UpperPrice > 0 || cfg.GridConfig.LowerPrice > 0 { + lines = append(lines, fmt.Sprintf("- Grid bounds: upper %.4f, lower %.4f", cfg.GridConfig.UpperPrice, cfg.GridConfig.LowerPrice)) + } + } if len(sourceBits) > 0 { lines = append(lines, "- Coin source: "+strings.Join(sourceBits, " | ")) } @@ -501,9 +1815,20 @@ func formatStrategyDetailResponse(lang string, strategy *store.Strategy, cfg sto } lines = append(lines, fmt.Sprintf("- Risk: max positions %d, BTC/ETH max leverage %d, alt max leverage %d, min confidence %d", cfg.RiskControl.MaxPositions, cfg.RiskControl.BTCETHMaxLeverage, cfg.RiskControl.AltcoinMaxLeverage, cfg.RiskControl.MinConfidence)) + lines = append(lines, fmt.Sprintf("- Risk thresholds: min RR %.2f, max margin usage %.2f, min position size %.2f", + cfg.RiskControl.MinRiskRewardRatio, cfg.RiskControl.MaxMarginUsage, cfg.RiskControl.MinPositionSize)) if len(indicatorBits) > 0 { lines = append(lines, "- Enabled indicators: "+strings.Join(indicatorBits, ", ")) } + if strings.TrimSpace(cfg.Indicators.NofxOSAPIKey) != "" || cfg.Indicators.EnableQuantData || cfg.Indicators.EnableOIRanking || cfg.Indicators.EnableNetFlowRanking || cfg.Indicators.EnablePriceRanking { + lines = append(lines, fmt.Sprintf("- NofxOS data: API key=%t, quant data=%t, OI ranking=%t, netflow ranking=%t, price ranking=%t", + strings.TrimSpace(cfg.Indicators.NofxOSAPIKey) != "", + cfg.Indicators.EnableQuantData, + cfg.Indicators.EnableOIRanking, + cfg.Indicators.EnableNetFlowRanking, + cfg.Indicators.EnablePriceRanking, + )) + } if len(promptBits) > 0 { lines = append(lines, "- Prompt modules: "+strings.Join(promptBits, ", ")) } @@ -580,12 +1905,74 @@ func (a *Agent) describeExchange(storeUserID, lang string, target *EntityReferen } exchange = &payload.ExchangeConfigs[0] } - if lang == "zh" { - return fmt.Sprintf("交易所配置“%s”详情:\n- 交易所:%s\n- 已启用:%t\n- API Key:%t\n- Secret:%t\n- Passphrase:%t\n- Testnet:%t", - defaultIfEmpty(exchange.AccountName, exchange.ID), exchange.ExchangeType, exchange.Enabled, exchange.HasAPIKey, exchange.HasSecretKey, exchange.HasPassphrase, exchange.Testnet), true + name := defaultIfEmpty(exchange.AccountName, exchange.ID) + credentialLinesZh := make([]string, 0, 8) + credentialLinesEn := make([]string, 0, 8) + addCredentialLine := func(labelZh, labelEn string, present bool) { + credentialLinesZh = append(credentialLinesZh, fmt.Sprintf("- %s:%t", labelZh, present)) + credentialLinesEn = append(credentialLinesEn, fmt.Sprintf("- %s: %t", labelEn, present)) } - return fmt.Sprintf("Exchange config %q details:\n- Exchange: %s\n- Enabled: %t\n- API key present: %t\n- Secret present: %t\n- Passphrase present: %t\n- Testnet: %t", - defaultIfEmpty(exchange.AccountName, exchange.ID), exchange.ExchangeType, exchange.Enabled, exchange.HasAPIKey, exchange.HasSecretKey, exchange.HasPassphrase, exchange.Testnet), true + switch exchange.ExchangeType { + case "binance", "bybit", "gate", "indodax": + addCredentialLine("API Key", "API key present", exchange.HasAPIKey) + addCredentialLine("Secret", "Secret present", exchange.HasSecretKey) + case "okx", "bitget", "kucoin": + addCredentialLine("API Key", "API key present", exchange.HasAPIKey) + addCredentialLine("Secret", "Secret present", exchange.HasSecretKey) + addCredentialLine("Passphrase", "Passphrase present", exchange.HasPassphrase) + case "hyperliquid": + addCredentialLine("API Key", "API key present", exchange.HasAPIKey) + credentialLinesZh = append(credentialLinesZh, fmt.Sprintf("- Hyperliquid 钱包地址:%s", defaultIfEmpty(exchange.HyperliquidWalletAddr, "未设置"))) + credentialLinesEn = append(credentialLinesEn, fmt.Sprintf("- Hyperliquid wallet address: %s", defaultIfEmpty(exchange.HyperliquidWalletAddr, "not set"))) + case "aster": + credentialLinesZh = append(credentialLinesZh, + fmt.Sprintf("- Aster User:%s", defaultIfEmpty(exchange.AsterUser, "未设置")), + fmt.Sprintf("- Aster Signer:%s", defaultIfEmpty(exchange.AsterSigner, "未设置")), + fmt.Sprintf("- Aster 私钥:%t", exchange.HasAsterPrivateKey), + ) + credentialLinesEn = append(credentialLinesEn, + fmt.Sprintf("- Aster user: %s", defaultIfEmpty(exchange.AsterUser, "not set")), + fmt.Sprintf("- Aster signer: %s", defaultIfEmpty(exchange.AsterSigner, "not set")), + fmt.Sprintf("- Aster private key present: %t", exchange.HasAsterPrivateKey), + ) + case "lighter": + credentialLinesZh = append(credentialLinesZh, + fmt.Sprintf("- Lighter 钱包地址:%s", defaultIfEmpty(exchange.LighterWalletAddr, "未设置")), + fmt.Sprintf("- Lighter API Key 私钥:%t", exchange.HasLighterAPIKey), + fmt.Sprintf("- Lighter API Key Index:%d", exchange.LighterAPIKeyIndex), + ) + credentialLinesEn = append(credentialLinesEn, + fmt.Sprintf("- Lighter wallet address: %s", defaultIfEmpty(exchange.LighterWalletAddr, "not set")), + fmt.Sprintf("- Lighter API key private key present: %t", exchange.HasLighterAPIKey), + fmt.Sprintf("- Lighter API key index: %d", exchange.LighterAPIKeyIndex), + ) + default: + addCredentialLine("API Key", "API key present", exchange.HasAPIKey) + addCredentialLine("Secret", "Secret present", exchange.HasSecretKey) + if exchange.HasPassphrase { + addCredentialLine("Passphrase", "Passphrase present", true) + } + } + if lang == "zh" { + lines := []string{ + fmt.Sprintf("交易所配置“%s”详情:", name), + fmt.Sprintf("- 交易所:%s", exchange.ExchangeType), + fmt.Sprintf("- 账户名:%s", name), + fmt.Sprintf("- 已启用:%t", exchange.Enabled), + fmt.Sprintf("- Testnet:%t", exchange.Testnet), + } + lines = append(lines, credentialLinesZh...) + return strings.Join(lines, "\n"), true + } + lines := []string{ + fmt.Sprintf("Exchange config %q details:", name), + fmt.Sprintf("- Exchange: %s", exchange.ExchangeType), + fmt.Sprintf("- Account name: %s", name), + fmt.Sprintf("- Enabled: %t", exchange.Enabled), + fmt.Sprintf("- Testnet: %t", exchange.Testnet), + } + lines = append(lines, credentialLinesEn...) + return strings.Join(lines, "\n"), true } func (a *Agent) describeModel(storeUserID, lang string, target *EntityReference) (string, bool) { @@ -604,11 +1991,37 @@ func (a *Agent) describeModel(storeUserID, lang string, target *EntityReference) model = &payload.ModelConfigs[0] } if lang == "zh" { - return fmt.Sprintf("模型配置“%s”详情:\n- Provider:%s\n- 已启用:%t\n- API Key:%t\n- URL:%s\n- Model Name:%s", - defaultIfEmpty(model.Name, model.ID), model.Provider, model.Enabled, model.HasAPIKey, defaultIfEmpty(model.CustomAPIURL, "未设置"), defaultIfEmpty(model.CustomModelName, "未设置")), true + lines := []string{ + fmt.Sprintf("模型配置“%s”详情:", defaultIfEmpty(model.Name, model.ID)), + fmt.Sprintf("- Provider:%s", model.Provider), + fmt.Sprintf("- 已启用:%t", model.Enabled), + fmt.Sprintf("- API Key:%t", model.HasAPIKey), + fmt.Sprintf("- URL:%s", defaultIfEmpty(model.CustomAPIURL, "未设置")), + fmt.Sprintf("- Model Name:%s", defaultIfEmpty(model.CustomModelName, "未设置")), + } + if strings.TrimSpace(model.WalletAddress) != "" { + lines = append(lines, fmt.Sprintf("- 钱包地址:%s", model.WalletAddress)) + } + if strings.TrimSpace(model.BalanceUSDC) != "" { + lines = append(lines, fmt.Sprintf("- 钱包余额:%s USDC", model.BalanceUSDC)) + } + return strings.Join(lines, "\n"), true } - return fmt.Sprintf("Model config %q details:\n- Provider: %s\n- Enabled: %t\n- API key present: %t\n- URL: %s\n- Model name: %s", - defaultIfEmpty(model.Name, model.ID), model.Provider, model.Enabled, model.HasAPIKey, defaultIfEmpty(model.CustomAPIURL, "not set"), defaultIfEmpty(model.CustomModelName, "not set")), true + lines := []string{ + fmt.Sprintf("Model config %q details:", defaultIfEmpty(model.Name, model.ID)), + fmt.Sprintf("- Provider: %s", model.Provider), + fmt.Sprintf("- Enabled: %t", model.Enabled), + fmt.Sprintf("- API key present: %t", model.HasAPIKey), + fmt.Sprintf("- URL: %s", defaultIfEmpty(model.CustomAPIURL, "not set")), + fmt.Sprintf("- Model name: %s", defaultIfEmpty(model.CustomModelName, "not set")), + } + if strings.TrimSpace(model.WalletAddress) != "" { + lines = append(lines, fmt.Sprintf("- Wallet address: %s", model.WalletAddress)) + } + if strings.TrimSpace(model.BalanceUSDC) != "" { + lines = append(lines, fmt.Sprintf("- Wallet balance: %s USDC", model.BalanceUSDC)) + } + return strings.Join(lines, "\n"), true } func findTraderByReference(items []safeTraderToolConfig, target *EntityReference) *safeTraderToolConfig { @@ -665,9 +2078,48 @@ func (a *Agent) loadTraderOptions(storeUserID string) []traderSkillOption { if err != nil { return nil } + exchangeNames := map[string]string{} + if exchanges, err := a.store.Exchange().List(storeUserID); err == nil { + for _, exchange := range exchanges { + if !store.IsVisibleExchange(exchange) { + continue + } + name := strings.TrimSpace(exchange.AccountName) + if name == "" { + name = strings.TrimSpace(exchange.ExchangeType) + } + if name != "" { + exchangeNames[exchange.ID] = name + } + } + } + modelNames := map[string]string{} + if models, err := a.store.AIModel().List(storeUserID); err == nil { + for _, model := range models { + name := strings.TrimSpace(model.Name) + if name == "" { + name = strings.TrimSpace(model.CustomModelName) + } + if name != "" { + modelNames[model.ID] = name + } + } + } out := make([]traderSkillOption, 0, len(traders)) for _, trader := range traders { - out = append(out, traderSkillOption{ID: trader.ID, Name: trader.Name, Enabled: trader.IsRunning}) + hints := make([]string, 0, 2) + if exchangeName := strings.TrimSpace(exchangeNames[trader.ExchangeID]); exchangeName != "" { + hints = append(hints, "交易所 "+exchangeName) + } + if modelName := strings.TrimSpace(modelNames[trader.AIModelID]); modelName != "" { + hints = append(hints, "模型 "+modelName) + } + out = append(out, traderSkillOption{ + ID: trader.ID, + Name: trader.Name, + Enabled: trader.IsRunning, + Hint: strings.Join(hints, ","), + }) } return out } @@ -686,24 +2138,65 @@ func (a *Agent) handleExchangeCreateSkill(storeUserID string, userID int64, lang } return "Cancelled the current exchange creation flow." } - if v := exchangeTypeFromText(text); fieldValue(session, "exchange_type") == "" && v != "" { - setField(&session, "exchange_type", v) - } - if v := extractTraderName(text); fieldValue(session, "account_name") == "" && v != "" { - setField(&session, "account_name", v) - } exType := fieldValue(session, "exchange_type") + accountName := fieldValue(session, "account_name") + missing := make([]string, 0, 6) if actionRequiresSlot("exchange_management", "create", "exchange_type") && exType == "" { + missing = append(missing, slotDisplayName("exchange_type", lang)) + } + if accountName == "" { + missing = append(missing, displayCatalogFieldName("account_name", lang)) + } + if fieldValue(session, "api_key") == "" { + missing = append(missing, displayCatalogFieldName("api_key", lang)) + } + if fieldValue(session, "secret_key") == "" { + missing = append(missing, displayCatalogFieldName("secret_key", lang)) + } + switch exType { + case "okx": + if fieldValue(session, "passphrase") == "" { + missing = append(missing, displayCatalogFieldName("passphrase", lang)) + } + case "hyperliquid": + if fieldValue(session, "hyperliquid_wallet_addr") == "" { + missing = append(missing, "Hyperliquid Wallet") + } + } + if len(missing) > 0 { setSkillDAGStep(&session, "resolve_exchange_type") a.saveSkillSession(userID, session) if lang == "zh" { - return "要创建交易所配置,我还需要:" + slotDisplayName("exchange_type", lang) + "。例如:OKX、Binance、Bybit。" + reply := "要创建交易所配置,还缺这些字段:" + formatMissingFieldList(lang, missing) + "。" + if exType == "" { + reply += "\n例如:OKX、Binance、Bybit。" + } + return reply } - return "To create an exchange config, tell me which exchange to use, for example OKX, Binance, or Bybit." + return "One more thing: please tell me these details: " + formatMissingFieldList(lang, missing) + "." } - accountName := fieldValue(session, "account_name") - if accountName == "" { - accountName = "Default" + validator := exchangeConfigValidator{ + exchangeType: exType, + enabled: fieldValue(session, "enabled") == "true", + apiKey: fieldValue(session, "api_key"), + secretKey: fieldValue(session, "secret_key"), + passphrase: fieldValue(session, "passphrase"), + hyperliquidWalletAddr: fieldValue(session, "hyperliquid_wallet_addr"), + asterUser: fieldValue(session, "aster_user"), + asterSigner: fieldValue(session, "aster_signer"), + asterPrivateKey: fieldValue(session, "aster_private_key"), + lighterWalletAddr: fieldValue(session, "lighter_wallet_addr"), + lighterAPIKeyPrivateKey: fieldValue(session, "lighter_api_key_private_key"), + } + if err := validator.Validate(); err != nil { + a.saveSkillSession(userID, session) + return formatValidationFeedback(lang, "exchange", err) + } + if !createConfirmationReply(text) { + session.Phase = "await_create_confirmation" + setSkillDAGStep(&session, "await_create_confirmation") + a.saveSkillSession(userID, session) + return formatExchangeCreateDraftSummary(lang, session) } setSkillDAGStep(&session, "execute_create") args := map[string]any{ @@ -711,6 +2204,22 @@ func (a *Agent) handleExchangeCreateSkill(storeUserID string, userID int64, lang "exchange_type": exType, "account_name": accountName, } + for _, field := range []string{"api_key", "secret_key", "passphrase", "hyperliquid_wallet_addr", "aster_user", "aster_signer", "aster_private_key", "lighter_wallet_addr", "lighter_api_key_private_key"} { + if value := fieldValue(session, field); value != "" { + args[field] = value + } + } + if value := fieldValue(session, "enabled"); value != "" { + args["enabled"] = value == "true" + } + if value := fieldValue(session, "testnet"); value != "" { + args["testnet"] = value == "true" + } + if value := fieldValue(session, "lighter_api_key_index"); value != "" { + if parsed, err := strconv.Atoi(value); err == nil { + args["lighter_api_key_index"] = parsed + } + } raw, _ := json.Marshal(args) resp := a.toolManageExchangeConfig(storeUserID, string(raw)) if errMsg := parseSkillError(resp); strings.Contains(resp, `"error"`) { @@ -718,13 +2227,14 @@ func (a *Agent) handleExchangeCreateSkill(storeUserID string, userID int64, lang if lang == "zh" { return "创建交易所配置失败:" + errMsg } - return "Failed to create exchange config: " + errMsg + return "That create request did not go through: " + errMsg } a.clearSkillSession(userID) + a.rememberReferencesFromToolResult(userID, "manage_exchange_config", resp) if lang == "zh" { - return fmt.Sprintf("已创建交易所配置:%s(%s)。如需继续补 API Key、Secret 或 Passphrase,可以直接继续说。", accountName, exType) + return fmt.Sprintf("已创建交易所配置:%s(%s)。", accountName, exType) } - return fmt.Sprintf("Created exchange config %s (%s). You can continue by adding API key, secret, or passphrase.", accountName, exType) + return fmt.Sprintf("Created exchange config %s (%s).", accountName, exType) } func (a *Agent) handleModelCreateSkill(storeUserID string, userID int64, lang, text string, session skillSession) string { @@ -741,32 +2251,75 @@ func (a *Agent) handleModelCreateSkill(storeUserID string, userID int64, lang, t } return "Cancelled the current model creation flow." } - if v := providerFromText(text); fieldValue(session, "provider") == "" && v != "" { - setField(&session, "provider", v) - } - if v := extractTraderName(text); fieldValue(session, "name") == "" && v != "" { - setField(&session, "name", v) - } - if v := extractURL(text); fieldValue(session, "custom_api_url") == "" && v != "" { - setField(&session, "custom_api_url", v) - } provider := fieldValue(session, "provider") - if actionRequiresSlot("model_management", "create", "provider") && provider == "" { + if provider != "" { + if fieldValue(session, "name") == "" { + setField(&session, "name", defaultModelConfigName(provider)) + } + if modelProviderSupportsCustomModel(provider) && fieldValue(session, "custom_model_name") == "" { + if defaultModel := defaultModelNameForProvider(provider); defaultModel != "" { + setField(&session, "custom_model_name", defaultModel) + } + } + if !modelProviderSupportsCustomAPIURL(provider) { + setField(&session, "custom_api_url", "") + } + } + missing := make([]string, 0, 4) + providerMissing := actionRequiresSlot("model_management", "create", "provider") && provider == "" + if providerMissing { + missing = append(missing, slotDisplayName("provider", lang)) + } + if !providerMissing && fieldValue(session, "api_key") == "" { + missing = append(missing, modelProviderCredentialLabel(lang, provider)) + } + if len(missing) > 0 { setSkillDAGStep(&session, "resolve_provider") a.saveSkillSession(userID, session) if lang == "zh" { - return "要创建模型配置,我还需要:" + slotDisplayName("provider", lang) + ",例如:OpenAI、DeepSeek、Claude、Gemini。" + reply := "要创建模型配置,还缺这些字段:" + formatMissingFieldList(lang, missing) + "。" + if provider == "" { + reply += "\n" + availableModelProvidersMessage(lang) + } else { + reply += "\n" + modelProviderDetailedGuidance(lang, provider) + } + return reply } - return "To create a model config, I need the provider first, for example OpenAI, DeepSeek, Claude, or Gemini." + reply := "One more thing: please tell me these details: " + formatMissingFieldList(lang, missing) + "." + if provider != "" { + reply += "\n" + modelProviderDetailedGuidance(lang, provider) + } + return reply + } + validator := modelConfigValidator{ + provider: provider, + enabled: fieldValue(session, "enabled") == "true", + apiKey: fieldValue(session, "api_key"), + customAPIURL: fieldValue(session, "custom_api_url"), + customModelName: fieldValue(session, "custom_model_name"), + } + if err := validator.Validate(); err != nil { + a.saveSkillSession(userID, session) + return formatValidationFeedback(lang, "model", err) + } + if !createConfirmationReply(text) { + session.Phase = "await_create_confirmation" + setSkillDAGStep(&session, "await_create_confirmation") + a.saveSkillSession(userID, session) + return formatModelCreateDraftSummary(lang, session) } setSkillDAGStep(&session, "execute_create") args := map[string]any{ "action": "create", "provider": provider, - "name": defaultIfEmpty(fieldValue(session, "name"), provider), + "name": fieldValue(session, "name"), + "api_key": fieldValue(session, "api_key"), "custom_api_url": fieldValue(session, "custom_api_url"), "custom_model_name": fieldValue(session, "custom_model_name"), } + if value := fieldValue(session, "enabled"); value != "" { + args["enabled"] = value == "true" + } raw, _ := json.Marshal(args) resp := a.toolManageModelConfig(storeUserID, string(raw)) if errMsg := parseSkillError(resp); strings.Contains(resp, `"error"`) { @@ -774,13 +2327,51 @@ func (a *Agent) handleModelCreateSkill(storeUserID string, userID int64, lang, t if lang == "zh" { return "创建模型配置失败:" + errMsg } - return "Failed to create model config: " + errMsg + return "That create request did not go through: " + errMsg } a.clearSkillSession(userID) + a.rememberReferencesFromToolResult(userID, "manage_model_config", resp) if lang == "zh" { - return fmt.Sprintf("已创建模型配置:%s。你后续还可以继续补 API Key、URL 或模型名。", provider) + return fmt.Sprintf("已创建模型配置:%s。", fieldValue(session, "name")) + } + return fmt.Sprintf("Created model config %s.", fieldValue(session, "name")) +} + +func inferModelCredentialFromText(provider, text string) string { + provider = strings.ToLower(strings.TrimSpace(provider)) + text = strings.TrimSpace(text) + if provider == "" || text == "" { + return "" + } + + if value := extractQuotedContent(text); value != "" { + trimmed := strings.TrimSpace(value) + if credentialLooksCompatibleWithProvider(provider, trimmed) { + return trimmed + } + } + + if credentialLooksCompatibleWithProvider(provider, text) { + return text + } + return "" +} + +func credentialLooksCompatibleWithProvider(provider, value string) bool { + provider = strings.ToLower(strings.TrimSpace(provider)) + value = strings.TrimSpace(value) + if provider == "" || value == "" { + return false + } + + switch provider { + case "claw402", "blockrun-base", "blockrun-sol": + return hexCredentialPattern.MatchString(value) + case "openai": + return openAIAPIKeyPattern.MatchString(value) + default: + return genericAPIKeyPattern.MatchString(value) || hexCredentialPattern.MatchString(value) } - return fmt.Sprintf("Created model config for %s. You can continue by adding API key, URL, or model name.", provider) } func (a *Agent) handleStrategyCreateSkill(storeUserID string, userID int64, lang, text string, session skillSession) string { @@ -797,16 +2388,7 @@ func (a *Agent) handleStrategyCreateSkill(storeUserID string, userID int64, lang } return "Cancelled the current strategy creation flow." } - name := fieldValue(session, "name") - if name == "" { - name = extractTraderName(text) - if name == "" { - name = extractPostKeywordName(text, []string{"叫", "名为", "策略叫", "strategy called"}) - } - if name != "" { - setField(&session, "name", name) - } - } + name := resolveStrategyCreateName(&session, text) if actionRequiresSlot("strategy_management", "create", "name") && name == "" { setSkillDAGStep(&session, "resolve_name") a.saveSkillSession(userID, session) @@ -815,8 +2397,51 @@ func (a *Agent) handleStrategyCreateSkill(storeUserID string, userID int64, lang } return "To create a strategy, I need a strategy name. You can say: create a strategy called 'Trend A'." } + if fieldValue(session, "strategy_type") == "" { + if strategyType := parseStrategyTypeValue(text); strategyType != "" { + setStrategyCreateType(&session, strategyType) + } + } else if strategyType := parseStrategyTypeValue(text); strategyType != "" { + setStrategyCreateType(&session, strategyType) + } + cfg, configMap, warnings, cfgErr := strategyCreateConfigFromSession(session, lang) + if cfgErr != nil { + a.saveSkillSession(userID, session) + if lang == "zh" { + return "创建策略失败:" + cfgErr.Error() + } + return "That strategy config could not be prepared: " + cfgErr.Error() + } + if ready, missingKind := strategyCreateConfigReady(session, cfg, text); !ready { + setField(&session, strategyCreateDraftConfigField, marshalStrategyCreateDraft(cfg)) + setSkillDAGStep(&session, "collect_config") + session.Phase = "collecting" + a.saveSkillSession(userID, session) + if reply := formatStrategyCreateFieldOptionsReply(lang, text, missingKind); reply != "" { + return reply + } + return formatStrategyCreateConfigNeeded(lang, missingKind) + } + if !strategyCreateConfirmationReply(text) && !strategyCreateFinalConfirmationReady(session) { + setField(&session, strategyCreateDraftConfigField, marshalStrategyCreateDraft(cfg)) + setField(&session, "awaiting_final_confirmation", "true") + setSkillDAGStep(&session, "await_create_confirmation") + session.Phase = "await_create_confirmation" + a.saveSkillSession(userID, session) + return formatStrategyCreateFinalConfirmation(lang, session, cfg) + } + setSkillDAGStep(&session, "execute_create") - args := map[string]any{"action": "create", "name": name, "lang": "zh"} + args := map[string]any{ + "action": "create", + "name": name, + "lang": defaultIfEmpty(lang, "zh"), + "allow_clamped_update": true, + "confirmed": true, + } + if len(configMap) > 0 { + args["config"] = configMap + } raw, _ := json.Marshal(args) resp := a.toolManageStrategy(storeUserID, string(raw)) if errMsg := parseSkillError(resp); strings.Contains(resp, `"error"`) { @@ -824,38 +2449,99 @@ func (a *Agent) handleStrategyCreateSkill(storeUserID string, userID int64, lang if lang == "zh" { return "创建策略失败:" + errMsg } - return "Failed to create strategy: " + errMsg + return "That create request did not go through: " + errMsg } a.clearSkillSession(userID) - if lang == "zh" { - return fmt.Sprintf("已创建策略“%s”。默认配置已就绪,你后续可以继续让我帮你改细节。", name) + a.rememberReferencesFromToolResult(userID, "manage_strategy", resp) + return formatCreatedStrategyReply(lang, name, cfg, warnings) +} + +func formatCreatedStrategyReply(lang, name string, cfg store.StrategyConfig, warnings []string) string { + name = defaultIfEmpty(strings.TrimSpace(name), "未命名策略") + if lang != "zh" { + name = defaultIfEmpty(strings.TrimSpace(name), "unnamed strategy") } - return fmt.Sprintf("Created strategy %q with the default configuration.", name) + _ = warnings + if lang == "zh" { + lines := []string{fmt.Sprintf("已创建策略“%s”。实际保存配置如下:", name)} + if cfg.StrategyType == "grid_trading" && cfg.GridConfig != nil { + grid := cfg.GridConfig + lines = append(lines, + "- 类型:网格策略", + fmt.Sprintf("- 交易对:%s", defaultIfEmpty(grid.Symbol, "未设置")), + fmt.Sprintf("- 网格数量:%d", grid.GridCount), + fmt.Sprintf("- 总投入:%.2f USDT", grid.TotalInvestment), + fmt.Sprintf("- 杠杆:%d倍", grid.Leverage), + fmt.Sprintf("- 分布方式:%s", defaultIfEmpty(grid.Distribution, "未设置")), + ) + if grid.UseATRBounds { + lines = append(lines, fmt.Sprintf("- 价格范围:ATR 自动计算(倍数 %.2f)", grid.ATRMultiplier)) + } else { + lines = append(lines, fmt.Sprintf("- 价格范围:%.2f ~ %.2f USDT", grid.LowerPrice, grid.UpperPrice)) + } + lines = append(lines, + fmt.Sprintf("- 最大回撤:%.2f%%", grid.MaxDrawdownPct), + fmt.Sprintf("- 止损:%.2f%%", grid.StopLossPct), + fmt.Sprintf("- 日亏损限制:%.2f%%", grid.DailyLossLimitPct), + fmt.Sprintf("- 只挂 Maker:%t", grid.UseMakerOnly), + ) + } else { + lines = append(lines, + "- 类型:AI 策略", + fmt.Sprintf("- 选币来源:%s", defaultIfEmpty(cfg.CoinSource.SourceType, "未设置")), + fmt.Sprintf("- 主周期:%s", defaultIfEmpty(cfg.Indicators.Klines.PrimaryTimeframe, "未设置")), + fmt.Sprintf("- BTC/ETH 最大杠杆:%d倍", cfg.RiskControl.BTCETHMaxLeverage), + fmt.Sprintf("- 山寨币最大杠杆:%d倍", cfg.RiskControl.AltcoinMaxLeverage), + fmt.Sprintf("- 最小置信度:%d", cfg.RiskControl.MinConfidence), + fmt.Sprintf("- 最小盈亏比:%.2f", cfg.RiskControl.MinRiskRewardRatio), + ) + } + return strings.Join(lines, "\n") + } + + lines := []string{fmt.Sprintf("Created strategy %q with this saved config:", name)} + if cfg.StrategyType == "grid_trading" && cfg.GridConfig != nil { + grid := cfg.GridConfig + lines = append(lines, + "- Type: grid strategy", + fmt.Sprintf("- Symbol: %s", defaultIfEmpty(grid.Symbol, "unset")), + fmt.Sprintf("- Grid count: %d", grid.GridCount), + fmt.Sprintf("- Total investment: %.2f USDT", grid.TotalInvestment), + fmt.Sprintf("- Leverage: %dx", grid.Leverage), + fmt.Sprintf("- Distribution: %s", defaultIfEmpty(grid.Distribution, "unset")), + ) + if grid.UseATRBounds { + lines = append(lines, fmt.Sprintf("- Price range: ATR auto bounds (multiplier %.2f)", grid.ATRMultiplier)) + } else { + lines = append(lines, fmt.Sprintf("- Price range: %.2f - %.2f USDT", grid.LowerPrice, grid.UpperPrice)) + } + } else { + lines = append(lines, + "- Type: AI strategy", + fmt.Sprintf("- Coin source: %s", defaultIfEmpty(cfg.CoinSource.SourceType, "unset")), + fmt.Sprintf("- Primary timeframe: %s", defaultIfEmpty(cfg.Indicators.Klines.PrimaryTimeframe, "unset")), + ) + } + return strings.Join(lines, "\n") } func (a *Agent) handleSimpleEntitySkill(storeUserID string, userID int64, lang, text string, session skillSession, skillName, action string, options []traderSkillOption) (string, bool) { - if isCancelSkillReply(text) { - a.clearSkillSession(userID) - if lang == "zh" { - return "已取消当前流程。", true - } - return "Cancelled the current flow.", true - } if session.Name == "" { session = skillSession{Name: skillName, Action: action, Phase: "collecting"} } if session.Name != skillName || session.Action != action { return "", false } + if supportsBulkTargetSelection(skillName, action) && textMeansAllTargets(text) { + setField(&session, "bulk_scope", "all") + session.TargetRef = nil + } if dag, ok := getSkillDAG(skillName, action); ok && len(dag.Steps) > 0 { currentStep, _ := currentSkillDAGStep(session) if currentStep.ID == "resolve_target" { - if supportsBulkTargetSelection(skillName, action) && textMeansAllTargets(text) { - setField(&session, "bulk_scope", "all") - advanceSkillDAGStep(&session, currentStep.ID) - } else { - session.TargetRef = resolveTargetFromText(text, options, session.TargetRef) + if resolved := resolveTargetSelection(text, options, session.TargetRef); resolved.Ref != nil { + session.TargetRef = resolved.Ref } if session.TargetRef == nil { if !(supportsBulkTargetSelection(skillName, action) && fieldValue(session, "bulk_scope") == "all") { @@ -873,7 +2559,7 @@ func (a *Agent) handleSimpleEntitySkill(storeUserID string, userID int64, lang, } return reply, true } - reply := "This step needs a target object first. Tell me which one to operate on." + reply := "One more thing: tell me which one you want me to work on." if optionList != "" { reply += "\n" + optionList } @@ -885,10 +2571,8 @@ func (a *Agent) handleSimpleEntitySkill(storeUserID string, userID int64, lang, } } } else { - if supportsBulkTargetSelection(skillName, action) && textMeansAllTargets(text) { - setField(&session, "bulk_scope", "all") - } else { - session.TargetRef = resolveTargetFromText(text, options, session.TargetRef) + if resolved := resolveTargetSelection(text, options, session.TargetRef); resolved.Ref != nil { + session.TargetRef = resolved.Ref } if session.TargetRef == nil && fieldValue(session, "bulk_scope") != "all" && action != "query" && action != "query_list" && action != "query_detail" && action != "query_running" { a.saveSkillSession(userID, session) @@ -900,7 +2584,26 @@ func (a *Agent) handleSimpleEntitySkill(storeUserID string, userID int64, lang, } return reply, true } - reply := "I still need you to specify which object to operate on." + reply := "One more thing: tell me which one you want to work on." + if label != "" { + reply += "\n" + label + } + return reply, true + } + } + + if session.TargetRef != nil && action != "create" && action != "query_list" && action != "query_running" { + if !ensureLiveTargetReference(&session, options) { + a.saveSkillSession(userID, session) + label := formatOptionList("可选对象:", options) + if lang == "zh" { + reply := "我刚检查了一下,刚才记住的对象已经不存在或已失效了。请重新告诉我要操作哪一个对象。" + if label != "" { + reply += "\n" + label + } + return reply, true + } + reply := "The object remembered from earlier no longer exists. Please tell me which object to operate on now." if label != "" { reply += "\n" + label } @@ -922,6 +2625,10 @@ func (a *Agent) handleSimpleEntitySkill(storeUserID string, userID int64, lang, } } +func (a *Agent) askLLMAmbiguousTargetQuestion(storeUserID string, userID int64, lang, text string, session skillSession, skillName, action string, allOptions, ambiguous []traderSkillOption) string { + return formatAmbiguousTargetPrompt(lang, ambiguous) +} + func defaultIfEmpty(value, fallback string) string { value = strings.TrimSpace(value) if value == "" { diff --git a/agent/skill_outcome.go b/agent/skill_outcome.go index 1075a434..8922ad2d 100644 --- a/agent/skill_outcome.go +++ b/agent/skill_outcome.go @@ -42,11 +42,11 @@ func normalizeAtomicSkillAction(skill, action string) string { return "query_list" case "query_running": return "query_running" - case "query_detail": + case "query_detail", "query_strategy_binding", "query_exchange_binding", "query_model_binding": + return action + case "query_binding": return "query_detail" - case "update": - return "update_name" - case "update_name", "update_bindings": + case "update", "update_bindings", "configure_strategy", "configure_exchange", "configure_model": return action } case "exchange_management": @@ -55,9 +55,7 @@ func normalizeAtomicSkillAction(skill, action string) string { return "query_list" case "query_detail": return "query_detail" - case "update": - return "update_name" - case "update_name", "update_status": + case "update", "update_name", "update_status": return action } case "model_management": @@ -66,9 +64,7 @@ func normalizeAtomicSkillAction(skill, action string) string { return "query_list" case "query_detail": return "query_detail" - case "update": - return "update_name" - case "update_name", "update_endpoint", "update_status": + case "update", "update_name", "update_endpoint", "update_status": return action } case "strategy_management": @@ -77,9 +73,7 @@ func normalizeAtomicSkillAction(skill, action string) string { return "query_list" case "query_detail": return "query_detail" - case "update": - return "update_name" - case "update_name", "update_config", "update_prompt": + case "update", "update_name", "update_config", "update_prompt": return action } } @@ -158,6 +152,7 @@ Rules: - Use route "replan" when the user's task is not complete yet and the planner should continue from the new skill outcome. - Prefer route "replan" for recoverable errors, unmet goals, missing prerequisites, or cases where another skill/tool sequence may help. - If you choose "complete", produce the final user-facing answer in the user's language. +- ` + cleanUserFacingReplyInstruction + ` Return JSON with this exact shape: {"route":"complete|replan","answer":""}` diff --git a/agent/skill_registry.go b/agent/skill_registry.go index a74b3cbf..8b231c09 100644 --- a/agent/skill_registry.go +++ b/agent/skill_registry.go @@ -6,19 +6,40 @@ import ( "fmt" "sort" "strings" + "sync" ) //go:embed skills/*.json var embeddedSkillDefinitions embed.FS type SkillDefinition struct { - Name string `json:"name"` - Kind string `json:"kind"` - Domain string `json:"domain"` - Description string `json:"description"` - Intents []string `json:"intents,omitempty"` - Actions map[string]SkillActionDefinition `json:"actions,omitempty"` - ToolMapping map[string]string `json:"tool_mapping,omitempty"` + Name string `json:"name"` + Kind string `json:"kind"` + Domain string `json:"domain"` + Description string `json:"description"` + Intents []string `json:"intents,omitempty"` + Capabilities []string `json:"capabilities,omitempty"` + DynamicRules []string `json:"dynamic_rules,omitempty"` + Actions map[string]SkillActionDefinition `json:"actions,omitempty"` + ToolMapping map[string]string `json:"tool_mapping,omitempty"` + FieldConstraints map[string]SkillFieldConstraint `json:"field_constraints,omitempty"` + ValidationRules []string `json:"validation_rules,omitempty"` + PerExchangeRequiredFields map[string][]string `json:"per_exchange_required_fields,omitempty"` +} + +type SkillFieldConstraint struct { + Type string `json:"type,omitempty"` + Required bool `json:"required,omitempty"` + Values []string `json:"values,omitempty"` + Aliases map[string]string `json:"aliases,omitempty"` + Description string `json:"description,omitempty"` + RequiredFor []string `json:"required_for,omitempty"` + Default any `json:"default,omitempty"` + Min *float64 `json:"min,omitempty"` + Max *float64 `json:"max,omitempty"` + MaxLength int `json:"max_length,omitempty"` + MustBeHTTPS bool `json:"must_be_https,omitempty"` + Pattern string `json:"pattern,omitempty"` } type SkillActionDefinition struct { @@ -26,9 +47,14 @@ type SkillActionDefinition struct { RequiredSlots []string `json:"required_slots,omitempty"` OptionalSlots []string `json:"optional_slots,omitempty"` NeedsConfirmation bool `json:"needs_confirmation,omitempty"` + Goal string `json:"goal,omitempty"` + DynamicRules []string `json:"dynamic_rules,omitempty"` + SuccessOutput string `json:"success_output,omitempty"` + FailureOutput string `json:"failure_output,omitempty"` } var skillRegistry = mustLoadSkillRegistry() +var skillContextCache sync.Map func mustLoadSkillRegistry() map[string]SkillDefinition { registry, err := loadSkillRegistry() @@ -72,6 +98,8 @@ func normalizeSkillDefinition(def SkillDefinition) SkillDefinition { def.Domain = strings.TrimSpace(def.Domain) def.Description = strings.TrimSpace(def.Description) def.Intents = cleanStringList(def.Intents) + def.Capabilities = cleanStringList(def.Capabilities) + def.DynamicRules = cleanStringList(def.DynamicRules) if len(def.Actions) > 0 { normalized := make(map[string]SkillActionDefinition, len(def.Actions)) @@ -83,6 +111,10 @@ func normalizeSkillDefinition(def SkillDefinition) SkillDefinition { action.Description = strings.TrimSpace(action.Description) action.RequiredSlots = cleanStringList(action.RequiredSlots) action.OptionalSlots = cleanStringList(action.OptionalSlots) + action.Goal = strings.TrimSpace(action.Goal) + action.DynamicRules = cleanStringList(action.DynamicRules) + action.SuccessOutput = strings.TrimSpace(action.SuccessOutput) + action.FailureOutput = strings.TrimSpace(action.FailureOutput) normalized[key] = action } def.Actions = normalized @@ -101,6 +133,46 @@ func normalizeSkillDefinition(def SkillDefinition) SkillDefinition { def.ToolMapping = normalized } + if len(def.FieldConstraints) > 0 { + normalized := make(map[string]SkillFieldConstraint, len(def.FieldConstraints)) + for key, constraint := range def.FieldConstraints { + key = strings.TrimSpace(key) + if key == "" { + continue + } + constraint.Type = strings.TrimSpace(constraint.Type) + constraint.Values = cleanStringList(constraint.Values) + constraint.RequiredFor = cleanStringList(constraint.RequiredFor) + constraint.Description = strings.TrimSpace(constraint.Description) + if len(constraint.Aliases) > 0 { + aliases := make(map[string]string, len(constraint.Aliases)) + for alias, value := range constraint.Aliases { + alias = strings.TrimSpace(alias) + value = strings.TrimSpace(value) + if alias == "" || value == "" { + continue + } + aliases[alias] = value + } + constraint.Aliases = aliases + } + normalized[key] = constraint + } + def.FieldConstraints = normalized + } + def.ValidationRules = cleanStringList(def.ValidationRules) + if len(def.PerExchangeRequiredFields) > 0 { + normalized := make(map[string][]string, len(def.PerExchangeRequiredFields)) + for key, fields := range def.PerExchangeRequiredFields { + key = strings.TrimSpace(key) + if key == "" { + continue + } + normalized[key] = cleanStringList(fields) + } + def.PerExchangeRequiredFields = normalized + } + return def } @@ -117,3 +189,533 @@ func listSkillNames() []string { sort.Strings(names) return names } + +func buildSkillRoutingSummary(lang string, skillNames []string) string { + lines := make([]string, 0, len(skillNames)) + for _, name := range skillNames { + def, ok := getSkillDefinition(name) + if !ok { + continue + } + parts := []string{strings.TrimSpace(def.Description)} + if len(def.DynamicRules) > 0 { + parts = append(parts, strings.Join(def.DynamicRules, " ")) + } + switch name { + case "trader_management": + if lang == "zh" { + parts = append(parts, "这个 skill 负责交易员本体和绑定关系;交易员编辑默认只换绑定,不改策略、模型、交易所的内部配置。") + } else { + parts = append(parts, "This skill owns the trader itself and its bindings; trader edits should switch bindings, not mutate the internals of the strategy, model, or exchange.") + } + case "strategy_management": + if lang == "zh" { + parts = append(parts, "策略模板创建后应出现在策略列表/策略页。用户没问运行时,不要主动延伸到交易员绑定。") + } else { + parts = append(parts, "After creation, strategy templates should appear in the strategy list/page. Do not proactively bring up trader binding unless the user asks to run it.") + } + } + lines = append(lines, fmt.Sprintf("- %s: %s", name, strings.Join(cleanStringList(parts), " "))) + } + return strings.Join(lines, "\n") +} + +func buildSkillDefinitionSummary(lang string, skillNames []string) string { + lines := make([]string, 0, len(skillNames)) + for _, name := range skillNames { + def, ok := getSkillDefinition(name) + if !ok { + continue + } + parts := []string{strings.TrimSpace(def.Description)} + if len(def.Capabilities) > 0 { + if lang == "zh" { + parts = append(parts, "能力: "+strings.Join(def.Capabilities, ";")) + } else { + parts = append(parts, "capabilities: "+strings.Join(def.Capabilities, "; ")) + } + } + if len(def.DynamicRules) > 0 { + if lang == "zh" { + parts = append(parts, "规则: "+strings.Join(def.DynamicRules, ";")) + } else { + parts = append(parts, "rules: "+strings.Join(def.DynamicRules, "; ")) + } + } + if action, ok := def.Actions["create"]; ok && len(action.RequiredSlots) > 0 { + if lang == "zh" { + parts = append(parts, "创建必填: "+formatRequiredSlotList(lang, action.RequiredSlots)) + } else { + parts = append(parts, "create requires: "+formatRequiredSlotList(lang, action.RequiredSlots)) + } + } + switch name { + case "trader_management": + if lang == "zh" { + parts = append(parts, "这个 skill 负责交易员本体和绑定关系;交易员编辑默认只换绑定,不改策略、模型、交易所的内部配置。") + } else { + parts = append(parts, "This skill owns the trader itself and its bindings; trader edits should switch bindings, not mutate the internals of the strategy, model, or exchange.") + } + case "strategy_management": + if lang == "zh" { + parts = append(parts, "策略模板创建后应出现在策略列表/策略页。用户没问运行时,不要主动延伸到交易员绑定。") + } else { + parts = append(parts, "After creation, strategy templates should appear in the strategy list/page. Do not proactively bring up trader binding unless the user asks to run it.") + } + } + lines = append(lines, fmt.Sprintf("- %s: %s", name, strings.Join(cleanStringList(parts), " "))) + } + return strings.Join(lines, "\n") +} + +func defaultManagementSkillNames() []string { + return []string{ + "trader_management", + "exchange_management", + "model_management", + "strategy_management", + } +} + +func buildSkillDependencySummary(lang string, session skillSession) string { + if strings.TrimSpace(session.Name) == "" { + return "" + } + switch session.Name { + case "trader_management": + if session.Action == "create" { + if lang == "zh" { + return "trader_management:create 必须收齐 4 个核心槽位:交易员名称、交易所、模型、策略。后 3 个依赖项都允许两种补法:直接选用户已有可用资源,或在当前主流程里立即新建/启用后再回流继续创建交易员。若用户是在启用、修复或新建这些依赖资源,这仍然是在继续创建交易员主流程,不是新开平级任务。" + } + return "trader_management:create requires 4 core slots: trader name, exchange, model, and strategy. The last 3 dependencies can be satisfied in two ways: choose an existing usable resource, or create/enable one inline and then resume trader creation. If the user is enabling, fixing, or creating one of those dependencies, that is still continuation of the trader creation flow, not a new peer task." + } + if lang == "zh" { + return "当当前对象是交易员时,换绑模型、交易所、策略都属于 trader_management 的继续操作;但如果用户要改这些对象的内部配置,应切到对应 management skill。" + } + return "When the current object is a trader, rebinding its model, exchange, or strategy remains inside trader_management; but if the user wants to change the internals of those resources, switch to the corresponding management skill." + default: + return "" + } +} + +func buildSkillActionContractSummary(lang string, session skillSession) string { + if strings.TrimSpace(session.Name) == "" || strings.TrimSpace(session.Action) == "" { + return "" + } + + def, ok := getSkillDefinition(session.Name) + if !ok { + return "" + } + action, ok := def.Actions[session.Action] + if !ok { + return "" + } + + required := defaultIfEmpty(formatRequiredSlotList(lang, action.RequiredSlots), "无") + goal := strings.TrimSpace(action.Goal) + if goal == "" { + goal = strings.TrimSpace(action.Description) + } + + lines := []string{ + fmt.Sprintf("### Active Skill Contract: %s:%s", session.Name, session.Action), + } + if lang == "zh" { + lines = append(lines, "- 目标:"+defaultIfEmpty(goal, "按该动作的业务规则完成当前请求。")) + lines = append(lines, "- 必填输入:"+required) + if len(action.DynamicRules) > 0 { + lines = append(lines, "- 动态逻辑规则:") + for i, rule := range action.DynamicRules { + lines = append(lines, fmt.Sprintf(" %d. %s", i+1, rule)) + } + } + if action.SuccessOutput != "" || action.FailureOutput != "" { + lines = append(lines, "- 预期输出:"+strings.TrimSpace(strings.Join(cleanStringList([]string{ + ifThenElse(action.SuccessOutput != "", "成功:"+action.SuccessOutput, ""), + ifThenElse(action.FailureOutput != "", "失败:"+action.FailureOutput, ""), + }), ";"))) + } + } else { + lines = append(lines, "- Goal: "+defaultIfEmpty(goal, "Complete the current request under this action's business rules.")) + lines = append(lines, "- Required input: "+required) + if len(action.DynamicRules) > 0 { + lines = append(lines, "- Dynamic rules:") + for i, rule := range action.DynamicRules { + lines = append(lines, fmt.Sprintf(" %d. %s", i+1, rule)) + } + } + if action.SuccessOutput != "" || action.FailureOutput != "" { + lines = append(lines, "- Expected output: "+strings.TrimSpace(strings.Join(cleanStringList([]string{ + ifThenElse(action.SuccessOutput != "", "success: "+action.SuccessOutput, ""), + ifThenElse(action.FailureOutput != "", "failure: "+action.FailureOutput, ""), + }), "; "))) + } + } + return strings.Join(lines, "\n") +} + +func ifThenElse[T any](cond bool, a, b T) T { + if cond { + return a + } + return b +} + +func buildSkillForbiddenSummary(lang string, skillNames []string) string { + lines := make([]string, 0, len(skillNames)) + for _, name := range skillNames { + switch name { + case "trader_management": + if lang == "zh" { + lines = append(lines, "- trader_management 不能直接设计赚钱/不亏钱方案;那类目标应交给 planner。") + lines = append(lines, "- trader_management 不能让用户手动设置、充值或修改交易员余额;交易员初始余额应由系统自动读取绑定交易所净值。") + } else { + lines = append(lines, "- trader_management must not invent a profit-seeking plan; those requests belong to the planner.") + lines = append(lines, "- trader_management must not let the user set, top up, or manually edit trader balance; trader initial balance should be auto-read from the bound exchange equity.") + } + case "exchange_management": + if lang == "zh" { + lines = append(lines, "- exchange_management 只负责保存和修改交易所配置,不负责行情查询、交易执行或诊断 API 报错。") + } else { + lines = append(lines, "- exchange_management only saves and updates exchange configs; it does not do market reads, trading, or API diagnosis.") + } + case "model_management": + if lang == "zh" { + lines = append(lines, "- model_management 只负责保存和修改模型配置,不负责测试连接、诊断上游错误或生成策略方案。") + } else { + lines = append(lines, "- model_management only saves and updates model configs; it does not test connectivity, diagnose upstream failures, or design strategies.") + } + case "strategy_management": + if lang == "zh" { + lines = append(lines, "- strategy_management 只负责模板管理;策略模板不能直接启动运行,运行态属于 trader。") + } else { + lines = append(lines, "- strategy_management only manages templates; strategy templates do not run directly and runtime belongs to traders.") + } + } + } + return strings.Join(lines, "\n") +} + +func buildManagementSkillContext(lang string, session *skillSession) string { + key := fmt.Sprintf("full|%s|", lang) + if session != nil { + key = fmt.Sprintf("full|%s|%s|%s", lang, strings.TrimSpace(session.Name), strings.TrimSpace(session.Action)) + } + return cachedSkillContext(key, func() string { + parts := make([]string, 0, 3) + if summary := buildSkillDefinitionSummary(lang, defaultManagementSkillNames()); summary != "" { + parts = append(parts, "Management skill summary:\n"+summary) + } + if forbidden := buildSkillForbiddenSummary(lang, defaultManagementSkillNames()); forbidden != "" { + parts = append(parts, "Management skill negative constraints:\n"+forbidden) + } + if session != nil { + if dependency := buildSkillDependencySummary(lang, *session); dependency != "" { + parts = append(parts, "Active skill dependency summary:\n"+dependency) + } + if contract := buildSkillActionContractSummary(lang, *session); contract != "" { + parts = append(parts, contract) + } + } + return strings.Join(parts, "\n\n") + }) +} + +func buildManagementSkillRoutingContext(lang string) string { + return buildManagementSkillRoutingContextWithSession(lang, nil) +} + +func buildSkillActionRoutingSummary(lang string, session skillSession) string { + if strings.TrimSpace(session.Name) == "" || strings.TrimSpace(session.Action) == "" { + return "" + } + def, ok := getSkillDefinition(session.Name) + if !ok { + return "" + } + action, ok := def.Actions[session.Action] + if !ok { + return "" + } + + lines := []string{ + fmt.Sprintf("### Active skill routing hints: %s:%s", session.Name, session.Action), + } + if goal := strings.TrimSpace(action.Goal); goal != "" { + if lang == "zh" { + lines = append(lines, "- 当前动作目标:"+goal) + } else { + lines = append(lines, "- Current action goal: "+goal) + } + } + if dependency := buildSkillDependencySummary(lang, session); dependency != "" { + if lang == "zh" { + lines = append(lines, "- 当前 flow 依赖提示:"+dependency) + } else { + lines = append(lines, "- Flow dependency hint: "+dependency) + } + } + if len(action.DynamicRules) > 0 { + if lang == "zh" { + lines = append(lines, "- 当前动作动态规则:") + } else { + lines = append(lines, "- Current action dynamic rules:") + } + for i, rule := range action.DynamicRules { + lines = append(lines, fmt.Sprintf(" %d. %s", i+1, rule)) + } + } + return strings.Join(lines, "\n") +} + +func buildManagementSkillRoutingContextWithSession(lang string, session *skillSession) string { + key := fmt.Sprintf("routing|%s|", lang) + if session != nil { + key = fmt.Sprintf("routing|%s|%s|%s", lang, strings.TrimSpace(session.Name), strings.TrimSpace(session.Action)) + } + return cachedSkillContext(key, func() string { + parts := make([]string, 0, 1) + if summary := buildSkillRoutingSummary(lang, defaultManagementSkillNames()); summary != "" { + parts = append(parts, "Management skill summary:\n"+summary) + } + if session != nil { + if summary := buildSkillActionRoutingSummary(lang, *session); summary != "" { + parts = append(parts, summary) + } + } + return strings.Join(parts, "\n\n") + }) +} + +func buildCurrentSkillExecutionContext(lang string, session skillSession) string { + key := fmt.Sprintf("current|%s|%s|%s", lang, strings.TrimSpace(session.Name), strings.TrimSpace(session.Action)) + return cachedSkillContext(key, func() string { + parts := make([]string, 0, 3) + if dependency := buildSkillDependencySummary(lang, session); dependency != "" { + parts = append(parts, "Active skill dependency summary:\n"+dependency) + } + if contract := buildSkillActionContractSummary(lang, session); contract != "" { + parts = append(parts, contract) + } + if knowledge := buildSkillFieldKnowledgeSummary(lang, session); knowledge != "" { + parts = append(parts, knowledge) + } + return strings.Join(parts, "\n\n") + }) +} + +func buildSkillFieldKnowledgeSummary(lang string, session skillSession) string { + def, ok := getSkillDefinition(session.Name) + if !ok { + return "" + } + action, hasAction := def.Actions[session.Action] + relevant := orderedSkillFieldKeys(def, action, hasAction) + lines := make([]string, 0, len(relevant)+6) + title := "### Active Field Knowledge" + if lang == "zh" { + title = "### 当前字段知识" + } + lines = append(lines, title) + for _, field := range relevant { + constraint, ok := def.FieldConstraints[field] + if !ok { + continue + } + lines = append(lines, formatFieldKnowledgeLine(lang, field, constraint)) + } + if len(def.PerExchangeRequiredFields) > 0 { + if lang == "zh" { + lines = append(lines, "- 按交易所类型的必填字段:") + } else { + lines = append(lines, "- Required fields by exchange type:") + } + keys := make([]string, 0, len(def.PerExchangeRequiredFields)) + for key := range def.PerExchangeRequiredFields { + keys = append(keys, key) + } + sort.Strings(keys) + for _, key := range keys { + fields := make([]string, 0, len(def.PerExchangeRequiredFields[key])) + for _, field := range def.PerExchangeRequiredFields[key] { + fields = append(fields, fieldKnowledgeDisplayName(field, lang)) + } + lines = append(lines, fmt.Sprintf(" - %s: %s", key, strings.Join(fields, "、"))) + } + } + if len(def.ValidationRules) > 0 { + if lang == "zh" { + lines = append(lines, "- 关键校验规则:") + } else { + lines = append(lines, "- Key validation rules:") + } + for i, rule := range def.ValidationRules { + lines = append(lines, fmt.Sprintf(" %d. %s", i+1, rule)) + } + } + if len(lines) == 1 { + return "" + } + return strings.Join(lines, "\n") +} + +func orderedSkillFieldKeys(def SkillDefinition, action SkillActionDefinition, hasAction bool) []string { + keys := make([]string, 0, len(def.FieldConstraints)) + seen := map[string]struct{}{} + add := func(field string) { + field = strings.TrimSpace(field) + if field == "" { + return + } + if _, ok := def.FieldConstraints[field]; !ok { + return + } + if _, ok := seen[field]; ok { + return + } + seen[field] = struct{}{} + keys = append(keys, field) + } + if hasAction { + for _, field := range action.RequiredSlots { + add(field) + } + for _, field := range action.OptionalSlots { + add(field) + } + } + if len(keys) == 0 { + for field := range def.FieldConstraints { + add(field) + } + } + return keys +} + +func formatFieldKnowledgeLine(lang, field string, constraint SkillFieldConstraint) string { + parts := make([]string, 0, 8) + if constraint.Description != "" { + parts = append(parts, constraint.Description) + } + if constraint.Type != "" { + if lang == "zh" { + parts = append(parts, "类型="+constraint.Type) + } else { + parts = append(parts, "type="+constraint.Type) + } + } + if constraint.Required { + if lang == "zh" { + parts = append(parts, "当前全局必填") + } else { + parts = append(parts, "globally required") + } + } + if len(constraint.Values) > 0 { + label := "可选值=" + if lang != "zh" { + label = "values=" + } + parts = append(parts, label+strings.Join(constraint.Values, "/")) + } + if len(constraint.RequiredFor) > 0 { + label := "仅这些类型必填=" + if lang != "zh" { + label = "required_for=" + } + parts = append(parts, label+strings.Join(constraint.RequiredFor, "/")) + } + if len(constraint.Aliases) > 0 { + aliasPairs := make([]string, 0, len(constraint.Aliases)) + keys := make([]string, 0, len(constraint.Aliases)) + for alias := range constraint.Aliases { + keys = append(keys, alias) + } + sort.Strings(keys) + for _, alias := range keys { + aliasPairs = append(aliasPairs, alias+"->"+constraint.Aliases[alias]) + } + label := "别名=" + if lang != "zh" { + label = "aliases=" + } + parts = append(parts, label+strings.Join(aliasPairs, ", ")) + } + if constraint.MustBeHTTPS { + if lang == "zh" { + parts = append(parts, "必须是 HTTPS") + } else { + parts = append(parts, "must be HTTPS") + } + } + if constraint.Min != nil || constraint.Max != nil { + rangeText := "" + switch { + case constraint.Min != nil && constraint.Max != nil: + rangeText = fmt.Sprintf("%.0f~%.0f", *constraint.Min, *constraint.Max) + case constraint.Min != nil: + rangeText = fmt.Sprintf(">=%.0f", *constraint.Min) + case constraint.Max != nil: + rangeText = fmt.Sprintf("<=%.0f", *constraint.Max) + } + if rangeText != "" { + label := "范围=" + if lang != "zh" { + label = "range=" + } + parts = append(parts, label+rangeText) + } + } + return fmt.Sprintf("- %s: %s", fieldKnowledgeDisplayName(field, lang), strings.Join(cleanStringList(parts), ";")) +} + +func fieldKnowledgeDisplayName(field, lang string) string { + if lang == "zh" { + switch field { + case "exchange_type": + return "交易所类型" + case "account_name": + return "账户名" + case "provider": + return "模型提供商" + case "custom_model_name": + return "模型名称" + case "custom_api_url": + return "接口地址" + } + } + return displayCatalogFieldName(field, lang) +} + +func formatRequiredSlotList(lang string, slots []string) string { + display := make([]string, 0, len(slots)) + for _, slot := range cleanStringList(slots) { + display = append(display, slotDisplayName(slot, lang)) + } + return strings.Join(display, "、") +} + +func missingRequiredActionSlots(skillName, action string, values map[string]string) []string { + runtime, ok := getSkillActionRuntime(skillName, action) + if !ok { + return nil + } + missing := make([]string, 0, len(runtime.Action.RequiredSlots)) + for _, slot := range runtime.Action.RequiredSlots { + if strings.TrimSpace(values[slot]) == "" { + missing = append(missing, slot) + } + } + return missing +} +func cachedSkillContext(key string, build func() string) string { + if cached, ok := skillContextCache.Load(key); ok { + if s, ok := cached.(string); ok { + return s + } + } + value := build() + skillContextCache.Store(key, value) + return value +} diff --git a/agent/skill_registry_test.go b/agent/skill_registry_test.go deleted file mode 100644 index 99a14987..00000000 --- a/agent/skill_registry_test.go +++ /dev/null @@ -1,55 +0,0 @@ -package agent - -import "testing" - -func TestSkillRegistryLoadsDefinitions(t *testing.T) { - names := listSkillNames() - if len(names) < 4 { - t.Fatalf("expected skill registry to load definitions, got %v", names) - } - - for _, name := range []string{ - "trader_management", - "exchange_management", - "model_management", - "strategy_management", - "exchange_diagnosis", - "model_diagnosis", - } { - if _, ok := getSkillDefinition(name); !ok { - t.Fatalf("missing skill definition %q", name) - } - } -} - -func TestTraderManagementDefinitionHasCreateAction(t *testing.T) { - def, ok := getSkillDefinition("trader_management") - if !ok { - t.Fatalf("missing trader_management definition") - } - action, ok := def.Actions["create"] - if !ok { - t.Fatalf("missing create action in trader_management") - } - if len(action.RequiredSlots) == 0 { - t.Fatalf("expected required slots for trader_management create action") - } -} - -func TestActionNeedsConfirmationUsesSkillDefinition(t *testing.T) { - if !actionNeedsConfirmation("exchange_management", "delete") { - t.Fatalf("expected exchange_management delete to require confirmation") - } - if actionNeedsConfirmation("exchange_management", "query") { - t.Fatalf("did not expect exchange_management query to require confirmation") - } -} - -func TestActionRequiresSlotUsesSkillDefinition(t *testing.T) { - if !actionRequiresSlot("model_management", "create", "provider") { - t.Fatalf("expected model_management create to require provider") - } - if actionRequiresSlot("model_management", "create", "target_ref") { - t.Fatalf("did not expect model_management create to require target_ref") - } -} diff --git a/agent/skill_runner.go b/agent/skill_runner.go index a2b7fdbf..38a1940a 100644 --- a/agent/skill_runner.go +++ b/agent/skill_runner.go @@ -89,7 +89,7 @@ func slotDisplayName(slot, lang string) string { case "exchange_type": return "交易所类型" case "provider": - return "provider" + return "模型提供商" default: return slot } @@ -115,6 +115,39 @@ func formatAwaitConfirmationMessage(lang, action, targetLabel string) string { return fmt.Sprintf("You are about to %s %q. Please reply 'confirm' to continue or 'cancel' to stop.", actionLabel, targetLabel) } +func formatTargetConfirmationLabel(lang string, session *skillSession, targetLabel string) string { + targetLabel = strings.TrimSpace(targetLabel) + if session == nil || session.TargetRef == nil || targetLabel == "" { + return targetLabel + } + source := strings.TrimSpace(session.TargetRef.Source) + if source == "" { + return targetLabel + } + if lang == "zh" { + sourceLabel := "系统上下文" + switch source { + case "user_mention": + sourceLabel = "你刚才点名的对象" + case "tool_output": + sourceLabel = "刚刚工具返回的对象" + case "inferred_from_context": + sourceLabel = "上下文推断对象" + } + return fmt.Sprintf("%s(当前识别来源:%s)", targetLabel, sourceLabel) + } + sourceLabel := "context" + switch source { + case "user_mention": + sourceLabel = "your explicit mention" + case "tool_output": + sourceLabel = "recent tool output" + case "inferred_from_context": + sourceLabel = "context inference" + } + return fmt.Sprintf("%s (current reference source: %s)", targetLabel, sourceLabel) +} + func formatStillWaitingConfirmationMessage(lang string) string { if lang == "zh" { return "当前流程仍在等待你确认。回复“确认”继续,或“取消”终止。" @@ -122,13 +155,80 @@ func formatStillWaitingConfirmationMessage(lang string) string { return "This flow is still waiting for your confirmation." } -func beginConfirmationIfNeeded(userID int64, lang string, session *skillSession, targetLabel string) (string, bool) { +func referenceKindForSkill(skillName string) string { + switch strings.TrimSpace(skillName) { + case "strategy_management": + return "strategy" + case "trader_management": + return "trader" + case "model_management": + return "model" + case "exchange_management": + return "exchange" + default: + return "" + } +} + +func referenceKindDisplayName(lang, kind string) string { + if lang == "zh" { + switch kind { + case "strategy": + return "策略" + case "trader": + return "交易员" + case "model": + return "模型" + case "exchange": + return "交易所" + } + return "对象" + } + return kind +} + +func (a *Agent) formatConfirmationTargetLabel(userID int64, lang string, session *skillSession, targetLabel string) string { + label := formatTargetConfirmationLabel(lang, session, targetLabel) + if session == nil || session.TargetRef == nil { + return label + } + kind := referenceKindForSkill(session.Name) + if kind == "" { + return label + } + state := a.getExecutionState(userID) + recentNames := map[string]struct{}{} + for _, item := range state.ReferenceHistory { + if item.Kind != kind { + continue + } + name := strings.TrimSpace(defaultIfEmpty(item.Name, item.ID)) + if name == "" { + continue + } + recentNames[name] = struct{}{} + } + targetName := strings.TrimSpace(defaultIfEmpty(session.TargetRef.Name, session.TargetRef.ID)) + _, inferred := recentNames[targetName] + if targetName == "" { + return label + } + if len(recentNames) <= 1 && strings.TrimSpace(session.TargetRef.Source) != "inferred_from_context" && inferred { + return label + } + if lang == "zh" { + return fmt.Sprintf("%s。系统当前理解你要操作的%s是“%s”。", label, referenceKindDisplayName(lang, kind), targetName) + } + return fmt.Sprintf("%s. The current %s I'm about to operate on is %q.", label, referenceKindDisplayName(lang, kind), targetName) +} + +func (a *Agent) beginConfirmationIfNeeded(userID int64, lang string, session *skillSession, targetLabel string) (string, bool) { if session == nil || !actionNeedsConfirmation(session.Name, session.Action) { return "", false } if session.Phase != "await_confirmation" { session.Phase = "await_confirmation" - return formatAwaitConfirmationMessage(lang, session.Action, targetLabel), true + return formatAwaitConfirmationMessage(lang, session.Action, a.formatConfirmationTargetLabel(userID, lang, session, targetLabel)), true } return "", false } diff --git a/agent/skill_semantic_gate.go b/agent/skill_semantic_gate.go new file mode 100644 index 00000000..d110d70f --- /dev/null +++ b/agent/skill_semantic_gate.go @@ -0,0 +1,246 @@ +package agent + +import ( + "encoding/json" + "strings" + + "nofx/store" +) + +func (a *Agent) skillVisibleFieldSummary(storeUserID, lang, skillName, action string) string { + fieldNames := make([]string, 0, 20) + add := func(field string) { + field = strings.TrimSpace(field) + if field == "" { + return + } + for _, existing := range fieldNames { + if existing == field { + return + } + } + fieldNames = append(fieldNames, field) + } + + switch skillName { + case "model_management": + if lang == "zh" { + add("Provider") + } else { + add("provider") + } + add(displayCatalogFieldName("name", lang)) + for _, field := range manualModelEditableFieldKeys() { + add(displayCatalogFieldName(field, lang)) + } + case "exchange_management": + add(slotDisplayName("exchange_type", lang)) + for _, field := range manualExchangeEditableFieldKeys() { + add(displayCatalogFieldName(field, lang)) + } + case "trader_management": + if strings.TrimSpace(action) == "create" { + add(slotDisplayName("name", lang)) + } + for _, field := range manualTraderEditableFieldKeys() { + add(displayCatalogFieldName(field, lang)) + } + case "strategy_management": + add(slotDisplayName("name", lang)) + for _, field := range manualStrategyEditableFieldKeys() { + add(strategyConfigFieldDisplayName(field, lang)) + } + } + if len(fieldNames) == 0 { + return "" + } + prefix := "Visible UI fields" + if lang == "zh" { + prefix = "当前可见字段" + } + return prefix + ":" + strings.Join(fieldNames, "、") +} + +func (a *Agent) strategyTypeForTarget(storeUserID string, target *EntityReference) (string, bool) { + if a == nil || a.store == nil || target == nil { + return "", false + } + var strategy *store.Strategy + var err error + if id := strings.TrimSpace(target.ID); id != "" { + strategy, err = a.store.Strategy().Get(storeUserID, id) + } else if name := strings.TrimSpace(target.Name); name != "" { + strategies, listErr := a.store.Strategy().List(storeUserID) + if listErr != nil { + return "", false + } + for _, item := range strategies { + if item != nil && strings.EqualFold(strings.TrimSpace(item.Name), name) { + strategy = item + break + } + } + } else { + return "", false + } + if err != nil || strategy == nil { + return "", false + } + cfg := store.GetDefaultStrategyConfig("zh") + if strings.TrimSpace(strategy.Config) != "" { + _ = json.Unmarshal([]byte(strategy.Config), &cfg) + } + strategyType := strings.TrimSpace(cfg.StrategyType) + if strategyType == "" { + strategyType = "ai_trading" + } + return strategyType, true +} + +func (a *Agent) skillVisibleOptionSummary(storeUserID, lang, skillName, action string) string { + switch skillName { + case "model_management": + return a.modelSkillOptionSummary(lang) + case "exchange_management": + return a.exchangeSkillOptionSummary(lang) + case "trader_management": + return a.traderSkillOptionSummary(storeUserID, lang) + case "strategy_management": + return a.strategySkillOptionSummary(storeUserID, lang) + default: + return "" + } +} + +func (a *Agent) modelSkillOptionSummary(lang string) string { + if lang == "zh" { + return modelProviderChoicePrompt(lang) + } + return modelProviderChoicePrompt(lang) +} + +func (a *Agent) exchangeSkillOptionSummary(lang string) string { + options := enumOptionValues("exchange_management", "exchange_type") + if len(options) == 0 { + options = []string{"Binance", "Bybit", "OKX", "Bitget", "Gate", "KuCoin", "Hyperliquid", "Aster", "Lighter", "Indodax"} + } + if lang == "zh" { + return "交易所类型选项:" + strings.Join(options, "、") + } + return "Exchange type options: " + strings.Join(options, ", ") +} + +func enumOptionValues(skillName, field string) []string { + def, ok := getSkillDefinition(skillName) + if !ok { + return nil + } + constraint, ok := def.FieldConstraints[field] + if !ok || len(constraint.Values) == 0 { + return nil + } + values := make([]string, 0, len(constraint.Values)) + for _, value := range constraint.Values { + if value == "" { + continue + } + switch value { + case "openai": + values = append(values, "OpenAI") + case "deepseek": + values = append(values, "DeepSeek") + case "claude": + values = append(values, "Claude") + case "gemini": + values = append(values, "Gemini") + case "qwen": + values = append(values, "Qwen") + case "kimi": + values = append(values, "Kimi") + case "grok": + values = append(values, "Grok") + case "minimax": + values = append(values, "Minimax") + case "binance": + values = append(values, "Binance") + case "okx": + values = append(values, "OKX") + case "bybit": + values = append(values, "Bybit") + case "gate": + values = append(values, "Gate") + case "kucoin": + values = append(values, "KuCoin") + case "bitget": + values = append(values, "Bitget") + case "hyperliquid": + values = append(values, "Hyperliquid") + case "aster": + values = append(values, "Aster") + case "lighter": + values = append(values, "Lighter") + case "indodax": + values = append(values, "Indodax") + default: + values = append(values, value) + } + } + return values +} + +func (a *Agent) traderSkillOptionSummary(storeUserID, lang string) string { + parts := []string{ + formatSkillOptionList(lang, "可选模型", "Available models", a.loadEnabledModelOptions(storeUserID)), + formatSkillOptionList(lang, "可选交易所", "Available exchanges", a.loadExchangeOptions(storeUserID)), + formatSkillOptionList(lang, "可选策略", "Available strategies", a.loadStrategyOptions(storeUserID)), + } + return strings.Join(filterNonEmptyStrings(parts), "\n") +} + +func (a *Agent) strategySkillOptionSummary(storeUserID, lang string) string { + parts := []string{ + "", + formatSkillOptionList(lang, "现有策略", "Existing strategies", a.loadStrategyOptions(storeUserID)), + } + sourceOptions := []string{"static", "ai500", "oi_top", "oi_low"} + if lang == "zh" { + parts[0] = "选币来源选项:static、ai500、oi_top、oi_low" + } else { + parts[0] = "Coin source options: static, ai500, oi_top, oi_low" + } + _ = sourceOptions + return strings.Join(filterNonEmptyStrings(parts), "\n") +} + +func formatSkillOptionList(lang, zhPrefix, enPrefix string, options []traderSkillOption) string { + names := make([]string, 0, len(options)) + for _, option := range options { + label := strings.TrimSpace(defaultIfEmpty(option.Name, option.ID)) + if label == "" { + continue + } + names = append(names, label) + } + if len(names) == 0 { + if lang == "zh" { + return zhPrefix + ":暂无" + } + return enPrefix + ": none" + } + if lang == "zh" { + return zhPrefix + ":" + strings.Join(names, "、") + } + return enPrefix + ": " + strings.Join(names, ", ") +} + +func filterNonEmptyStrings(items []string) []string { + out := make([]string, 0, len(items)) + for _, item := range items { + item = strings.TrimSpace(item) + if item == "" { + continue + } + out = append(out, item) + } + return out +} diff --git a/agent/skills/exchange_diagnosis.json b/agent/skills/exchange_diagnosis.json index c8d9b0ba..cebe3363 100644 --- a/agent/skills/exchange_diagnosis.json +++ b/agent/skills/exchange_diagnosis.json @@ -2,5 +2,17 @@ "name": "exchange_diagnosis", "kind": "diagnosis", "domain": "exchange", - "description": "当用户反馈交易所 API 连接失败、签名错误、timestamp 异常、权限不足、IP 白名单限制、账户不可用等问题时调用。适用于用户在手动配置或运行交易员时遇到的交易所接入故障。不用于创建、修改、删除或查询交易所配置这类管理操作。" + "description": "当用户反馈交易所 API 连接失败、签名错误、timestamp 异常、权限不足、IP 白名单限制、账户不可用、余额读取失败、下单失败或仓位模式错误等问题时调用。适用于用户在手动配置或运行交易员时遇到的交易所接入与执行故障。不用于创建、修改、删除或查询交易所配置这类管理操作。", + "capabilities": [ + "区分凭证缺失、签名错误、时间戳偏差、IP 白名单、权限不足、余额不足、仓位模式和 symbol 不可交易等原因", + "解释不同交易所的必填字段差异,尤其是 OKX/Bitget/KuCoin passphrase、Hyperliquid 钱包地址、Aster signer/private key、Lighter API key private key", + "把交易所原始错误翻译成新手可执行的修复步骤" + ], + "dynamic_rules": [ + "交易所连接失败优先按顺序排查:配置是否启用 -> 必填凭证是否齐全 -> API Key/Secret/Passphrase 是否填反或过期 -> 系统时间/timestamp -> IP 白名单 -> 合约/交易权限 -> 测试网/主网是否选错。", + "OKX、Bitget、KuCoin 的 passphrase/API 口令不是可选项;如果缺失,必须明确提示补齐。", + "下单失败时继续排查:账户余额/可用保证金 -> 杠杆限制 -> 仓位模式(单向/双向) -> symbol 是否支持合约交易 -> 最小下单金额/数量。", + "Hyperliquid、Aster、Lighter 这类钱包/DEX 配置错误时,不要用 CEX 的 API Key/Secret 逻辑套用;按各自 required fields 解释。", + "诊断回复不得展示完整 API Key、Secret、Passphrase 或私钥。" + ] } diff --git a/agent/skills/exchange_management.json b/agent/skills/exchange_management.json index 1baf26ce..b2c12cdf 100644 --- a/agent/skills/exchange_management.json +++ b/agent/skills/exchange_management.json @@ -3,30 +3,205 @@ "kind": "management", "domain": "exchange", "description": "当用户想创建、查看、修改或删除交易所账户配置时调用。适用于用户提到交易所账户、API Key、Secret、Passphrase、测试网开关、启用状态等配置管理需求。不用于排查 invalid signature、timestamp、权限不足、白名单限制等连接或鉴权诊断问题。", + "field_constraints": { + "exchange_type": { + "type": "enum", + "required": true, + "values": ["binance", "bybit", "okx", "bitget", "gate", "kucoin", "hyperliquid", "aster", "lighter", "indodax"], + "aliases": {"币安": "binance", "欧易": "okx", "必安": "binance", "bitget": "bitget", "bitget futures": "bitget", "bitget合约": "bitget", "库币": "kucoin", "gate.io": "gate", "hyper": "hyperliquid", "印尼站": "indodax"}, + "description": "交易所类型,必填,决定后续需要哪些凭证字段。" + }, + "account_name": { + "type": "string", + "max_length": 50, + "description": "账户显示名称,可选,用于区分同一交易所的多个账户。" + }, + "api_key": { + "type": "credential", + "pattern": "^[A-Za-z0-9_\\-]{8,}$", + "description": "交易所 API Key,至少 8 位字母数字。" + }, + "secret_key": { + "type": "credential", + "pattern": "^([A-Za-z0-9_\\-]{8,}|(0x)?[A-Fa-f0-9]{16,})$", + "description": "交易所 Secret Key,至少 8 位字母数字,或十六进制格式。" + }, + "passphrase": { + "type": "credential", + "required_for": ["okx", "bitget", "kucoin"], + "description": "OKX、Bitget、KuCoin 专用 Passphrase/API 口令,对这些交易所启用前必须填写;Binance、Bybit、Gate、Indodax 通常不需要。" + }, + "testnet": { + "type": "bool", + "default": false, + "description": "是否使用测试网(沙盒环境),默认 false(主网)。" + }, + "enabled": { + "type": "bool", + "default": true, + "description": "是否启用该交易所配置。只要必要字段齐全并配置成功,就默认启用。" + }, + "hyperliquid_wallet_addr": { + "type": "credential", + "required_for": ["hyperliquid"], + "description": "Hyperliquid 主钱包地址,Hyperliquid 账户启用前必须填写。" + }, + "hyperliquid_unified_account": { + "type": "bool", + "default": false, + "required_for": ["hyperliquid"], + "description": "是否启用 Hyperliquid unified account 模式。" + }, + "aster_user": { + "type": "credential", + "required_for": ["aster"], + "description": "Aster 用户地址,Aster 账户启用前必须填写。" + }, + "aster_signer": { + "type": "credential", + "required_for": ["aster"], + "description": "Aster Signer 地址,Aster 账户启用前必须填写。" + }, + "aster_private_key": { + "type": "credential", + "required_for": ["aster"], + "description": "Aster 私钥,Aster 账户启用前必须填写。" + }, + "lighter_wallet_addr": { + "type": "credential", + "required_for": ["lighter"], + "description": "Lighter 钱包地址,Lighter 账户启用前必须填写。" + }, + "lighter_private_key": { + "type": "credential", + "required_for": ["lighter"], + "description": "Lighter 私钥,某些 Lighter 账户模式下启用前必须填写。" + }, + "lighter_api_key_private_key": { + "type": "credential", + "required_for": ["lighter"], + "description": "Lighter API Key 私钥,Lighter 账户启用前必须填写。" + }, + "lighter_api_key_index": { + "type": "int", + "min": 0, + "max": 255, + "required_for": ["lighter"], + "description": "Lighter API Key Index,范围 0~255,超出范围自动收敛并告知用户。" + } + }, + "validation_rules": [ + "api_key 格式:至少 8 位字母数字,不符合时提示用户重新输入完整 Key。", + "secret_key 格式:至少 8 位字母数字,或十六进制格式,不符合时提示用户重新输入。", + "OKX 账户启用前必须填写 passphrase,否则拒绝启用并提示补填。", + "Bitget 和 KuCoin 页面流程里也需要 passphrase/API 口令,不能回答“没有就留空”;缺失时应明确提示补填。", + "Hyperliquid 创建/更新时应与手动页面保持一致:至少收集 api_key + hyperliquid_wallet_addr。", + "Hyperliquid 账户启用前必须填写 hyperliquid_wallet_addr。", + "若用户使用 Hyperliquid unified account 模式,应明确记录 hyperliquid_unified_account 开关状态。", + "Aster 账户启用前必须填写 aster_user、aster_signer、aster_private_key 三个字段,任一缺失都不能启用。", + "Lighter 账户启用前必须填写 lighter_wallet_addr + lighter_api_key_private_key;若当前账户模式还依赖 lighter_private_key,也要先补齐后再启用。", + "lighter_api_key_index 超出 0~255 时自动收敛到边界值并告知用户。", + "删除操作不可逆,必须先向用户确认再执行。" + ], + "per_exchange_required_fields": { + "binance": ["api_key", "secret_key"], + "okx": ["api_key", "secret_key", "passphrase"], + "bybit": ["api_key", "secret_key"], + "bitget": ["api_key", "secret_key", "passphrase"], + "gate": ["api_key", "secret_key"], + "kucoin": ["api_key", "secret_key", "passphrase"], + "indodax": ["api_key", "secret_key"], + "hyperliquid": ["api_key", "hyperliquid_wallet_addr"], + "aster": ["aster_user", "aster_signer", "aster_private_key"], + "lighter": ["lighter_wallet_addr", "lighter_api_key_private_key"] + }, "actions": { "create": { - "description": "创建新的交易所配置。", - "required_slots": ["exchange_type"], - "optional_slots": ["account_name", "api_key", "secret_key", "passphrase", "testnet"] + "description": "创建新的交易所配置。根据 exchange_type 决定需要收集哪些凭证字段。", + "required_slots": ["exchange_type", "account_name"], + "optional_slots": ["account_name", "api_key", "secret_key", "passphrase", "testnet", "hyperliquid_wallet_addr", "hyperliquid_unified_account", "aster_user", "aster_signer", "aster_private_key", "lighter_wallet_addr", "lighter_private_key", "lighter_api_key_private_key", "lighter_api_key_index"], + "goal": "创建一个可供 trader 绑定使用的交易所配置。", + "dynamic_rules": [ + "确认 exchange_type 后,根据 per_exchange_required_fields 决定需要追问哪些凭证字段。", + "Binance/Bybit/Gate/Indodax 需要 API Key + Secret;OKX/Bitget/KuCoin 还必须追问 passphrase;Hyperliquid 必须追问 api_key + 钱包地址,并允许记录 unified account 开关;Aster 必须追问 user/signer/private_key;Lighter 必须追问钱包地址和 api_key_private_key。", + "如果用户选择 OKX、Bitget 或 KuCoin,不能把 passphrase 说成可选项;没有 passphrase 时应停在补字段,不要创建半成品。", + "凭证字段格式不符时,用人话告知用户正确格式,不要静默丢弃。", + "若当前父任务只是缺一个可用交易所,本动作完成后应允许父任务恢复并消费新的 exchange_id。", + "若请求只是在启用已有交易所,不应误走 create,应改走 update_status。" + ], + "success_output": "返回新 exchange_id 和创建后的交易所配置摘要(类型、账户名、是否启用)。", + "failure_output": "明确指出缺失的必填字段或非法凭证格式,禁止返回含糊的成功信息。" }, "update": { - "description": "更新已有交易所配置。", + "description": "更新已有交易所配置的任意可编辑字段。", "required_slots": ["target_ref"], - "optional_slots": ["account_name", "api_key", "secret_key", "passphrase", "enabled", "testnet"] + "optional_slots": ["account_name", "api_key", "secret_key", "passphrase", "enabled", "testnet", "hyperliquid_wallet_addr", "hyperliquid_unified_account", "aster_user", "aster_signer", "aster_private_key", "lighter_wallet_addr", "lighter_private_key", "lighter_api_key_private_key", "lighter_api_key_index"], + "goal": "更新一个已有交易所配置的指定字段,而不影响未提及字段。", + "dynamic_rules": [ + "只更新用户明确提到的字段,不要覆盖未提及的字段。", + "更新凭证字段时,格式不符则提示用户重新输入。" + ], + "success_output": "返回 exchange_id 和更新后的交易所配置摘要。", + "failure_output": "明确指出目标交易所不存在、凭证格式非法,或仍缺哪个字段。" + }, + "update_name": { + "description": "修改交易所配置中的账户显示名称字段。", + "required_slots": ["target_ref", "account_name"], + "goal": "修改交易所配置中的账户显示名称,而不影响其他字段。", + "dynamic_rules": [ + "若用户同时提到其他字段,应优先走更通用的 update。" + ], + "success_output": "返回 exchange_id,并明确告知交易所配置已更新。", + "failure_output": "明确指出目标交易所不存在,或新的账户名称仍缺失。" + }, + "update_status": { + "description": "修改交易所配置中的启用开关。启用前系统会校验凭证完整性。", + "required_slots": ["target_ref", "enabled"], + "goal": "修改交易所配置中的启用状态字段。", + "dynamic_rules": [ + "启用前根据 exchange_type 校验必填凭证是否齐全,不齐全则提示用户补填后再启用。" + ], + "success_output": "返回 exchange_id,并明确告知交易所配置已更新。", + "failure_output": "明确指出目标交易所不存在、缺少必填凭证,或当前状态切换失败。" }, "delete": { - "description": "删除交易所配置。", + "description": "删除交易所配置,不可逆操作,必须确认。", "required_slots": ["target_ref"], - "needs_confirmation": true + "needs_confirmation": true, + "goal": "删除一个交易所配置。", + "dynamic_rules": [ + "必须在确认后执行,并明确提醒删除不可逆。" + ], + "success_output": "返回删除成功结果,并明确告知该交易所配置已被移除。", + "failure_output": "明确指出缺少确认、目标交易所不存在,或删除失败原因。" }, - "query": { - "description": "查询交易所配置。" + "query_list": { + "description": "查询所有交易所配置列表,包含类型、账户名、启用状态。", + "goal": "列出当前用户可用的交易所配置,便于后续绑定或选择。", + "dynamic_rules": [ + "优先返回类型、账户名、启用状态,不返回敏感凭证明文。" + ], + "success_output": "返回交易所配置列表摘要。", + "failure_output": "若列表为空,应明确告知当前没有交易所配置。" + }, + "query_detail": { + "description": "查询某个交易所配置的详细信息。", + "required_slots": ["target_ref"], + "goal": "读取一个交易所配置的详细信息和当前状态。", + "dynamic_rules": [ + "详情返回中只能暴露凭证存在性,不得返回凭证明文。" + ], + "success_output": "返回目标交易所配置的详细摘要。", + "failure_output": "明确指出目标交易所不存在,或当前引用已经失效。" } }, "tool_mapping": { "create": "manage_exchange_config:create", "update": "manage_exchange_config:update", + "update_name": "manage_exchange_config:update", + "update_status": "manage_exchange_config:update", "delete": "manage_exchange_config:delete", - "query": "get_exchange_configs" + "query_list": "get_exchange_configs", + "query_detail": "get_exchange_configs" } } diff --git a/agent/skills/model_diagnosis.json b/agent/skills/model_diagnosis.json index d47e0d77..f09b5a90 100644 --- a/agent/skills/model_diagnosis.json +++ b/agent/skills/model_diagnosis.json @@ -2,5 +2,16 @@ "name": "model_diagnosis", "kind": "diagnosis", "domain": "model", - "description": "当用户反馈模型配置失败、API Key 无效、Base URL 非法、模型名不匹配、调用返回错误、模型不可用等问题时调用。适用于用户在接入或测试大模型时遇到的配置与兼容性故障。不用于创建、修改、删除或查询模型配置这类管理操作。" + "description": "当用户反馈模型配置失败、API Key 无效、Base URL 非法、模型名不匹配、调用返回错误、模型不可用、claw402 钱包余额不足或支付失败等问题时调用。适用于用户在接入或测试大模型时遇到的配置、兼容性、支付和调用故障。不用于创建、修改、删除或查询模型配置这类管理操作。", + "capabilities": [ + "区分模型未启用、凭证缺失、endpoint/model name 配置错误、钱包余额不足、上游限流或网关异常", + "对 claw402 / blockrun-base 这类钱包付费模型,解释钱包地址、USDC 余额和支付状态", + "给出不泄露敏感凭证的下一步修复建议" + ], + "dynamic_rules": [ + "诊断模型不可用时,按顺序检查:是否存在该模型配置 -> enabled 是否为 true -> provider 是否支持 -> 凭证/API Key 或钱包私钥是否存在 -> custom_api_url 是否合法 HTTPS 或可留空 -> custom_model_name 是否有默认值或已填写 -> 钱包余额/支付状态 -> 上游限流、超时或网关错误。", + "claw402 是模型 provider,使用 Base USDC 钱包按次付费;余额为 0 USDC 时应明确说需要充值,不要说成“未配置模型”。", + "429/rate_limit_error、空响应、超时不应默认归因为余额不足;只有工具结果或错误文本指向余额/支付失败时才这么判断。", + "任何诊断回复都不得展示 API Key、钱包私钥或完整敏感凭证。" + ] } diff --git a/agent/skills/model_management.json b/agent/skills/model_management.json index 98b159ee..5e06b36e 100644 --- a/agent/skills/model_management.json +++ b/agent/skills/model_management.json @@ -3,30 +3,155 @@ "kind": "management", "domain": "model", "description": "当用户想创建、查看、修改或删除 AI 模型配置时调用。适用于用户提到 provider、API Key、Base URL、模型名称、启用状态等配置管理需求。不用于排查模型调用失败、接口不兼容、鉴权错误、模型不存在等诊断问题。", + "field_constraints": { + "provider": { + "type": "enum", + "required": true, + "values": ["openai", "deepseek", "claude", "gemini", "qwen", "kimi", "grok", "minimax", "claw402", "blockrun-base", "blockrun-sol"], + "description": "模型提供商,必填。决定默认模型、凭证类型以及可选配置项。" + }, + "name": { + "type": "string", + "max_length": 50, + "description": "模型配置显示名称,可选,用于区分同一 provider 的多个配置。" + }, + "api_key": { + "type": "credential", + "description": "模型凭证。普通 provider 使用 API Key;claw402 和 blockrun 使用钱包私钥。启用前必须填写。" + }, + "custom_api_url": { + "type": "url", + "must_be_https": true, + "description": "自定义 API Base URL,必须是合法的 HTTPS 地址。普通 provider 可留空走默认地址;claw402 / blockrun 不需要。" + }, + "custom_model_name": { + "type": "string", + "description": "实际调用的模型 ID,例如 gpt-5.1、deepseek-chat。若 provider 有默认模型,可留空走默认值。" + }, + "enabled": { + "type": "bool", + "default": false, + "description": "是否启用该模型配置。启用前必须填写 provider 对应的凭证;若 provider 没有默认模型,还需要 custom_model_name。" + } + }, + "validation_rules": [ + "provider 必须是支持列表之一:openai、deepseek、claude、gemini、qwen、kimi、grok、minimax、claw402、blockrun-base、blockrun-sol。", + "OpenAI 的 api_key 格式校验:必须以 sk- 开头,不符合时提示用户检查 Key 是否完整。", + "custom_api_url 若填写,必须是合法 HTTPS 地址,系统拒绝 HTTP 地址,提示用户改用 HTTPS。", + "启用(enabled=true)前必须填写 provider 对应的凭证;如果 custom_model_name 留空,则系统应先尝试使用 provider 默认模型。", + "启用(enabled=true)前,custom_api_url 若填写必须是合法 HTTPS 地址;不允许用 HTTP 地址硬启用。", + "claw402 是 AI 模型 provider,不是交易所、策略或交易员名称;用户说“用 claw402”时应解释为选择/绑定 claw402 模型配置。", + "claw402 使用 Base 链 EVM 钱包 + USDC 按次付费;enabled=true 只代表模型配置已启用,不代表钱包一定有余额。", + "claw402 或 blockrun-base 钱包余额为 0 USDC 时,应明确提示“钱包余额不足/需要充值”,不要说成“模型未启用”或静默改用其他模型。", + "用户明确指定某个 provider 或模型时,如果当前不可用,必须先说明不可用原因,再让用户选择修复该模型或改用其他已可用模型;不得静默替换。", + "删除操作不可逆,必须先向用户确认再执行。" + ], "actions": { "create": { "description": "创建新的模型配置。", "required_slots": ["provider"], - "optional_slots": ["name", "api_key", "custom_api_url", "custom_model_name", "enabled"] + "optional_slots": ["name", "api_key", "custom_api_url", "custom_model_name", "enabled"], + "goal": "创建一个可供 trader 绑定使用的模型配置。", + "dynamic_rules": [ + "确认 provider 后,先说明该 provider 的默认模型和凭证类型,再按 provider 特性补充追问。", + "普通 provider(openai、deepseek、claude 等)通常需要 api_key;custom_api_url 和 custom_model_name 可留空走默认值。", + "claw402 需要钱包私钥,不需要 custom_api_url;custom_model_name 留空时默认 deepseek。", + "创建 claw402 后若钱包余额为 0 USDC,应提示用户充值 Base USDC 后再用于稳定调用;不要把余额不足误报为配置未启用。", + "blockrun-base 和 blockrun-sol 需要钱包私钥,不需要 custom_api_url;custom_model_name 留空时默认 auto。", + "若用户提供了 custom_api_url,校验是否为合法 HTTPS 地址,不合法则提示修正。", + "OpenAI 的 api_key 不以 sk- 开头时,提示用户检查 Key 格式。", + "若用户要在父任务里使用现有模型,应优先选择当前已启用模型,而不是误开新的 create。", + "若当前父任务只是缺一个可用模型,本动作完成后应允许父任务恢复并消费新的 model_id。" + ], + "success_output": "返回 model_id 和创建后的模型配置摘要(provider、名称、是否启用)。", + "failure_output": "明确指出缺失字段、非法 endpoint 或不支持的 provider,禁止只说泛化失败。" }, "update": { - "description": "更新已有模型配置。", + "description": "更新已有模型配置的任意可编辑字段。", "required_slots": ["target_ref"], - "optional_slots": ["api_key", "custom_api_url", "custom_model_name", "enabled"] + "optional_slots": ["name", "api_key", "custom_api_url", "custom_model_name", "enabled"], + "goal": "更新一个已有模型配置的指定字段,而不覆盖未提及字段。", + "dynamic_rules": [ + "只更新用户明确提到的字段,不要覆盖未提及的字段。", + "如果用户只是想给 trader 改用 claw402,不要在模型配置里误改显示名称;应把 claw402 作为 provider/model 选择处理。", + "更新 custom_api_url 时校验 HTTPS 格式。", + "更新 api_key 时对 OpenAI 校验 sk- 前缀。" + ], + "success_output": "返回 model_id 和更新后的模型配置摘要。", + "failure_output": "明确指出目标模型不存在、provider/endpoint 非法,或仍缺哪个关键字段。" + }, + "update_status": { + "description": "启用或禁用模型配置。启用前系统会校验 api_key 和 custom_model_name 是否已填写。", + "required_slots": ["target_ref", "enabled"], + "goal": "切换模型配置的启用状态。", + "dynamic_rules": [ + "启用前必须确保 provider 对应凭证已经齐全;若 provider 有默认模型,custom_model_name 可按默认值处理。", + "启用 claw402 只校验钱包私钥等配置完整性;若钱包 0 USDC,应提示充值,但不要把它等同于 enabled=false。" + ], + "success_output": "返回 model_id,并明确告知该模型已启用或已禁用。", + "failure_output": "明确指出目标模型不存在、缺少启用前必填项,或当前状态切换失败。" + }, + "update_endpoint": { + "description": "仅修改模型的 custom_api_url。", + "required_slots": ["target_ref", "custom_api_url"], + "goal": "仅更新模型配置的 custom_api_url。", + "dynamic_rules": [ + "custom_api_url 必须是合法 HTTPS 地址;若不合法,先让用户修正而不是继续执行。" + ], + "success_output": "返回 model_id,并明确告知新的接口地址。", + "failure_output": "明确指出目标模型不存在,或接口地址仍不合法。" + }, + "update_name": { + "description": "仅修改模型配置的 custom_model_name(实际调用的模型 ID)。", + "required_slots": ["target_ref", "custom_model_name"], + "goal": "仅更新模型配置的实际调用模型 ID。", + "dynamic_rules": [ + "若用户其实是在改显示名称或 provider,应转去更通用的 update,而不是误用本动作。" + ], + "success_output": "返回 model_id,并明确告知新的 custom_model_name。", + "failure_output": "明确指出目标模型不存在,或新的模型 ID 仍未收齐。" }, "delete": { - "description": "删除模型配置。", + "description": "删除模型配置,不可逆操作,必须确认。", "required_slots": ["target_ref"], - "needs_confirmation": true + "needs_confirmation": true, + "goal": "删除一个模型配置。", + "dynamic_rules": [ + "必须在确认后执行,并明确提醒删除不可逆。" + ], + "success_output": "返回删除成功结果,并明确告知该模型配置已被移除。", + "failure_output": "明确指出缺少确认、目标模型不存在,或删除失败原因。" }, - "query": { - "description": "查询模型配置。" + "query_list": { + "description": "查询所有模型配置列表,包含 provider、名称、启用状态。", + "goal": "列出当前用户可见的模型配置,便于后续选择或绑定。", + "dynamic_rules": [ + "优先返回 provider、名称、启用状态,不返回 API Key 明文。", + "对于 claw402 / blockrun-base,若工具结果包含钱包地址或 USDC 余额,应用它解释支付状态;余额不足时要说“需要充值”,不要说“没配置”。" + ], + "success_output": "返回模型配置列表摘要。", + "failure_output": "若列表为空,应明确告知当前没有模型配置。" + }, + "query_detail": { + "description": "查询某个模型配置的详细信息。", + "required_slots": ["target_ref"], + "goal": "读取一个模型配置的详细信息。", + "dynamic_rules": [ + "详情返回中只能暴露 API Key/钱包私钥是否存在,不得返回明文凭证。", + "对于 claw402,应区分三种状态:配置未启用、钱包凭证缺失、钱包余额不足。" + ], + "success_output": "返回目标模型配置的详细摘要。", + "failure_output": "明确指出目标模型不存在,或当前引用已经失效。" } }, "tool_mapping": { "create": "manage_model_config:create", "update": "manage_model_config:update", + "update_status": "manage_model_config:update", + "update_endpoint": "manage_model_config:update", + "update_name": "manage_model_config:update", "delete": "manage_model_config:delete", - "query": "get_model_configs" + "query_list": "get_model_configs", + "query_detail": "get_model_configs" } } diff --git a/agent/skills/strategy_diagnosis.json b/agent/skills/strategy_diagnosis.json index 827185c6..0e1eb8ec 100644 --- a/agent/skills/strategy_diagnosis.json +++ b/agent/skills/strategy_diagnosis.json @@ -2,5 +2,21 @@ "name": "strategy_diagnosis", "kind": "diagnosis", "domain": "strategy", - "description": "当用户反馈策略未生效、策略输出异常、提示词或配置结果与预期不一致、策略执行表现异常时调用。适用于策略内容和执行效果相关的排障与解释。不用于创建、修改、删除、激活、复制或查询策略模板这类管理操作。" + "description": "当用户反馈策略未生效、候选币为空、策略输出异常、提示词或配置结果与预期不一致、AI 一直 hold/wait、策略执行表现异常时调用。适用于策略内容、候选币、风控边界和执行效果相关的排障与解释。不用于创建、修改、删除、激活、复制或查询策略模板这类管理操作。", + "capabilities": [ + "区分策略模板配置问题、交易员绑定问题、市场数据/候选币问题、AI 决策为 hold/wait、风控拦截和交易所下单失败", + "解释 AI 策略与网格策略的字段边界、页面范围和 System enforced 字段", + "指出策略模板不能直接运行,必须由交易员绑定后执行" + ], + "dynamic_rules": [ + "策略没生效时,先区分:只是策略模板未被交易员绑定,还是交易员已绑定但运行结果不符合预期。", + "若候选币为空,检查 source_type/static_coins/AI500/OI 榜单/排除币/量化数据开关,不要直接归因为模型问题。", + "若 AI 一直 hold/wait,先检查 min_confidence、min_risk_reward_ratio、提示词是否过于保守、行情是否满足入场条件,再判断是否需要放宽策略。", + "若交易员绑定了策略但没有下单,应与 trader_diagnosis 协作区分策略无信号、风控拦截和交易所下单失败。", + "策略诊断必须区分可编辑策略字段和 System enforced 字段。AI 智能策略里的 max_positions、btceth_max_position_value_ratio、altcoin_max_position_value_ratio、max_margin_usage、min_position_size 只能解释,不能建议用户修改。", + "如果不开单原因来自最小下单金额、保证金或仓位价值边界,不要建议修改 min_position_size 或 position_size_usd;应建议增加账户权益、换更适合小资金的标的、调整可编辑风险偏好或让策略在资金不足时等待。", + "策略页不存在 position_size_usd 这类固定配置项;position_size_usd 是 AI 每轮决策输出,不是策略模板字段。不要把 AI 决策里的 position_size_usd 说成可以在策略页手动修改的参数。", + "后台 402/404/EOF 类数据源错误只能作为策略分析质量的辅助影响,不能在决策记录已经显示明确风控/最小金额拒绝时作为主因。", + "策略模板本身不保存交易所、模型、扫描间隔或初始余额;这些问题应引导到 trader/model/exchange 相关诊断。" + ] } diff --git a/agent/skills/strategy_management.json b/agent/skills/strategy_management.json index a6ce0465..0f5dfb91 100644 --- a/agent/skills/strategy_management.json +++ b/agent/skills/strategy_management.json @@ -2,41 +2,472 @@ "name": "strategy_management", "kind": "management", "domain": "strategy", - "description": "当用户想创建、查看、修改、删除、激活或复制策略模板时调用。适用于用户提到策略名称、策略配置、描述、语言、激活状态、复制新版本等管理需求。不用于排查策略未生效、策略输出异常、执行结果异常等诊断问题。", + "description": "当用户想创建、查看、修改、删除、激活或复制策略模板时调用。", + "field_constraints": { + "name": { + "type": "string", + "required": true, + "max_length": 50, + "description": "策略模板名称,必填,最多 50 个字符。" + }, + "description": { + "type": "string", + "description": "策略描述,可选。" + }, + "is_public": { + "type": "bool", + "default": false, + "description": "是否发布到策略市场。" + }, + "config_visible": { + "type": "bool", + "default": true, + "description": "发布到市场后,是否允许别人查看策略配置。" + }, + "lang": { + "type": "enum", + "values": ["zh", "en"], + "default": "zh", + "description": "策略语言,zh 或 en,影响 AI 决策时使用的语言。" + }, + "strategy_type": { + "type": "enum", + "values": ["ai_trading", "grid_trading"], + "description": "策略类型:ai_trading(AI 量化)或 grid_trading(网格策略)。创建策略时必须先由用户选择或从用户话语明确识别,不能默认成 ai_trading。" + }, + "symbol": { + "type": "enum", + "values": ["BTCUSDT", "ETHUSDT", "SOLUSDT", "BNBUSDT", "XRPUSDT", "DOGEUSDT"], + "description": "网格策略页面交易对下拉选项,只能从 BTCUSDT、ETHUSDT、SOLUSDT、BNBUSDT、XRPUSDT、DOGEUSDT 中选择。用户问“交易对有哪些选项”时,直接列出这些选项。" + }, + "source_type": { + "type": "enum", + "values": ["static", "ai500", "oi_top", "oi_low"], + "description": "选币来源类型。static=用户指定静态币池,ai500=AI500榜单,oi_top=持仓量增长,oi_low=持仓量下降。" + }, + "static_coins": { + "type": "string_array", + "max_items": 10, + "description": "静态币池,例如 [\"BTCUSDT\", \"ETHUSDT\"],source_type=static 时使用,手动页面最多 10 个。页面支持常规合约币种,也支持 xyz: 前缀资产(如 xyz:TSLA、xyz:GOLD、xyz:XYZ100)。" + }, + "excluded_coins": { + "type": "string_array", + "description": "排除币列表,所有来源均会排除这些币。" + }, + "primary_timeframe": { + "type": "string", + "values": ["1m", "3m", "5m", "15m", "30m", "1h", "2h", "4h", "6h", "8h", "12h", "1d", "3d", "1w"], + "description": "主 K 线周期,例如 5m、15m、1h。" + }, + "selected_timeframes": { + "type": "string_array", + "values": ["1m", "3m", "5m", "15m", "30m", "1h", "2h", "4h", "6h", "8h", "12h", "1d", "3d", "1w"], + "max_items": 4, + "description": "多周期分析时间框架列表,例如 [\"5m\",\"15m\",\"1h\"];手动页面最多选择 4 个。" + }, + "btceth_max_leverage": { + "type": "int", + "min": 1, + "max": 20, + "description": "BTC/ETH 最大杠杆倍数,范围 1~20。" + }, + "altcoin_max_leverage": { + "type": "int", + "min": 1, + "max": 20, + "description": "山寨币最大杠杆倍数,范围 1~20。" + }, + "min_confidence": { + "type": "int", + "min": 50, + "max": 100, + "description": "最小开仓置信度,手动页面范围 50~100,数值越高开单越谨慎。" + }, + "min_risk_reward_ratio": { + "type": "float", + "min": 1, + "max": 10, + "description": "最小盈亏比,手动页面范围 1~10,步进 0.5;例如 1.5 表示每笔交易至少 1.5 倍风险收益比。" + }, + "custom_prompt": { + "type": "text", + "description": "自定义 AI 提示词,追加到策略基础提示词之后。" + }, + "role_definition": { + "type": "text", + "description": "AI 角色定义,描述 AI 的交易风格和定位。" + }, + "trading_frequency": { + "type": "text", + "description": "交易频率描述,例如:每天最多开 3 笔。" + }, + "entry_standards": { + "type": "text", + "description": "入场标准描述,例如:只在趋势明确时开仓。" + }, + "decision_process": { + "type": "text", + "description": "决策流程描述,例如:先看大周期趋势,再看小周期入场点。" + }, + "grid_count": { + "type": "int", + "min": 5, + "max": 50, + "description": "网格数量,grid_trading 类型专用,手动页面范围 5~50。" + }, + "total_investment": { + "type": "float", + "min": 100, + "description": "网格总投入金额,grid_trading 类型专用,表示用户实际投入/保证金预算,不是杠杆后的名义仓位;名义仓位约等于 total_investment × leverage。手动页面最小 100 USDT,步进 100。" + }, + "leverage": { + "type": "int", + "min": 1, + "max": 5, + "description": "网格策略杠杆倍数,手动页面当前范围 1~5。" + }, + "upper_price": { + "type": "float", + "description": "网格上边界价格,grid_trading 类型专用。" + }, + "lower_price": { + "type": "float", + "description": "网格下边界价格,grid_trading 类型专用,必须小于 upper_price。" + }, + "distribution": { + "type": "enum", + "values": ["uniform", "gaussian", "pyramid"], + "description": "网格分布方式:uniform=均匀,gaussian=正态,pyramid=金字塔。" + }, + "use_atr_bounds": { + "type": "bool", + "default": false, + "description": "网格边界是否改为按 ATR 动态计算。" + }, + "atr_multiplier": { + "type": "float", + "min": 1, + "max": 5, + "description": "ATR 边界倍数,use_atr_bounds=true 时使用,手动页面范围 1~5,步进 0.5。" + }, + "enable_direction_adjust": { + "type": "bool", + "default": false, + "description": "是否启用方向偏置调整。" + }, + "direction_bias_ratio": { + "type": "float", + "min": 0.55, + "max": 0.9, + "description": "方向偏置比例,决定多空倾向强弱;手动页面范围 0.55~0.90,通常以 55%~90% 展示。" + }, + "max_drawdown_pct": { + "type": "float", + "min": 5, + "max": 50, + "description": "网格策略最大回撤百分比,手动页面范围 5~50。" + }, + "stop_loss_pct": { + "type": "float", + "min": 1, + "max": 20, + "description": "网格策略止损百分比,手动页面范围 1~20。" + }, + "daily_loss_limit_pct": { + "type": "float", + "min": 1, + "max": 30, + "description": "网格策略每日最大亏损比例,手动页面范围 1~30,达到后当天停止新开仓。" + }, + "use_maker_only": { + "type": "bool", + "default": false, + "description": "是否优先只挂 maker 单。" + }, + "use_ai500": { + "type": "bool", + "default": false, + "description": "是否启用 AI500 榜单作为候选币来源。" + }, + "ai500_limit": { + "type": "int", + "min": 1, + "max": 10, + "description": "AI500 榜单选取数量,手动页面范围 1~10。" + }, + "use_oi_top": { + "type": "bool", + "default": false, + "description": "是否启用 OI Top 作为候选币来源。" + }, + "oi_top_limit": { + "type": "int", + "min": 1, + "max": 10, + "description": "OI Top 选取数量,手动页面范围 1~10。" + }, + "use_oi_low": { + "type": "bool", + "default": false, + "description": "是否启用 OI Low 作为候选币来源。" + }, + "oi_low_limit": { + "type": "int", + "min": 1, + "max": 10, + "description": "OI Low 选取数量,手动页面范围 1~10。" + }, + "primary_count": { + "type": "int", + "min": 10, + "max": 30, + "description": "主周期 K 线样本数量,手动页面范围 10~30。" + }, + "enable_ema": { + "type": "bool", + "default": false, + "description": "是否启用 EMA 指标。" + }, + "enable_macd": { + "type": "bool", + "default": false, + "description": "是否启用 MACD 指标。" + }, + "enable_rsi": { + "type": "bool", + "default": false, + "description": "是否启用 RSI 指标。" + }, + "enable_atr": { + "type": "bool", + "default": false, + "description": "是否启用 ATR 指标。" + }, + "enable_boll": { + "type": "bool", + "default": false, + "description": "是否启用布林带指标。" + }, + "enable_volume": { + "type": "bool", + "default": false, + "description": "是否启用成交量指标。" + }, + "enable_oi": { + "type": "bool", + "default": false, + "description": "是否启用持仓量指标。" + }, + "enable_funding_rate": { + "type": "bool", + "default": false, + "description": "是否启用资金费率指标。" + }, + "ema_periods": { + "type": "int_array", + "description": "EMA 周期列表,例如 [9,21,55]。" + }, + "rsi_periods": { + "type": "int_array", + "description": "RSI 周期列表。" + }, + "atr_periods": { + "type": "int_array", + "description": "ATR 周期列表。" + }, + "boll_periods": { + "type": "int_array", + "description": "布林带周期列表。" + }, + "nofxos_api_key": { + "type": "credential", + "description": "量化数据 API Key。" + }, + "enable_quant_data": { + "type": "bool", + "default": false, + "description": "是否启用量化数据增强。" + }, + "enable_quant_oi": { + "type": "bool", + "default": false, + "description": "是否启用量化持仓量数据。" + }, + "enable_quant_netflow": { + "type": "bool", + "default": false, + "description": "是否启用量化净流入数据。" + }, + "enable_oi_ranking": { + "type": "bool", + "default": false, + "description": "是否启用 OI 排行榜。" + }, + "oi_ranking_duration": { + "type": "enum", + "values": ["1h", "4h", "24h"], + "description": "OI 排行榜统计周期,页面选项为 1h、4h、24h。" + }, + "oi_ranking_limit": { + "type": "int", + "min": 5, + "max": 20, + "description": "OI 排行榜返回数量,页面选项为 5、10、15、20。" + }, + "enable_netflow_ranking": { + "type": "bool", + "default": false, + "description": "是否启用净流入排行榜。" + }, + "netflow_ranking_duration": { + "type": "enum", + "values": ["1h", "4h", "24h"], + "description": "净流入排行榜统计周期,页面选项为 1h、4h、24h。" + }, + "netflow_ranking_limit": { + "type": "int", + "min": 5, + "max": 20, + "description": "净流入排行榜返回数量,页面选项为 5、10、15、20。" + }, + "enable_price_ranking": { + "type": "bool", + "default": false, + "description": "是否启用价格波动排行榜。" + }, + "price_ranking_duration": { + "type": "enum", + "values": ["1h", "4h", "24h", "1h,4h,24h"], + "description": "价格排行榜统计周期,页面选项为 1h、4h、24h、1h,4h,24h。" + }, + "price_ranking_limit": { + "type": "int", + "min": 5, + "max": 20, + "description": "价格排行榜返回数量,页面选项为 5、10、15、20。" + } + }, + "validation_rules": [ + "本 skill 只负责策略模板创建、查看、修改、删除、激活和复制。", + "字段选项和范围来自 field_constraints;产品行为规则由 active session prompt 负责。" + ], "actions": { "create": { "description": "创建策略模板。", "required_slots": ["name"], - "optional_slots": ["config", "description", "lang"] + "optional_slots": ["strategy_type", "config_patch"], + "goal": "创建一个可供 trader 绑定使用的策略模板。", + "success_output": "返回 strategy_id 和新策略摘要(名称、类型、主要配置)。", + "failure_output": "明确指出仍缺哪些核心参数,或说明需要先确认的风控收敛结果。" }, "update": { - "description": "更新策略模板。", + "description": "更新策略模板的任意可编辑字段。", "required_slots": ["target_ref"], - "optional_slots": ["name", "config", "description"] + "optional_slots": ["name", "description", "is_public", "config_visible", "config_patch"], + "goal": "更新一个已有策略模板的指定配置,而不覆盖未提及字段。", + "dynamic_rules": [ + "只更新用户明确提到的字段,不要覆盖未提及的字段。", + "杠杆超出 1~20 范围时,自动收敛并告知用户。", + "grid_trading 类型时,lower_price 必须小于 upper_price。" + ], + "success_output": "返回 strategy_id 和更新后的策略摘要。", + "failure_output": "明确指出目标策略不存在、参数非法,或仍缺哪个关键字段。" }, - "delete": { - "description": "删除策略模板。", + "update_name": { + "description": "仅修改策略模板名称。", + "required_slots": ["target_ref", "name"], + "goal": "仅修改策略模板名称。", + "dynamic_rules": [ + "若输入里还包含其他配置项,应优先转去更通用的 update 或 update_config。" + ], + "success_output": "返回 strategy_id,并明确告知新的策略名称。", + "failure_output": "明确指出目标策略不存在,或新的名称仍未收齐。" + }, + "update_prompt": { + "description": "仅修改策略的 custom_prompt 或 prompt_sections(role_definition、trading_frequency、entry_standards、decision_process)。", "required_slots": ["target_ref"], - "needs_confirmation": true + "optional_slots": ["custom_prompt", "role_definition", "trading_frequency", "entry_standards", "decision_process"], + "goal": "更新策略模板的提示词相关内容,而不改动其他配置。", + "dynamic_rules": [ + "若用户一次修改多个 prompt section,应整体应用并在结果里清楚说明。", + "若用户实际是在改纯配置项,应转去 update_config。", + "当需要收集 custom_prompt 或 prompt_sections 等长文本槽位,而用户表达了“交给你”“你帮我写”“你自己设计”等委托生成意图时,严禁再次机械索要原文。", + "此时你必须直接以量化专家身份先拟出一版高质量文本,将生成内容写入对应字段,并在回复里展示草稿让用户确认是否直接采用。" + ], + "success_output": "返回 strategy_id,并明确告知哪些 prompt 字段已更新。", + "failure_output": "明确指出目标策略不存在,或新的 prompt 内容仍不完整。" + }, + "update_config": { + "description": "修改策略的某个具体配置参数(选币来源、指标、风控参数等)。", + "required_slots": ["target_ref"], + "optional_slots": ["config_patch"], + "goal": "修改策略模板中的一个或一组具体配置参数。", + "dynamic_rules": [ + "配置变更统一通过 config_patch 表达,字段必须来自当前策略类型的产品模板。", + "字段选项、范围和非策略字段拦截由 active session prompt 与后端 schema 负责。" + ], + "success_output": "返回 strategy_id,并明确告知已修改的配置字段及其最终值。", + "failure_output": "明确指出目标策略不存在、配置字段非法,或值仍需用户澄清。" }, "activate": { - "description": "激活策略模板。", - "required_slots": ["target_ref"] + "description": "将策略模板设为默认模板(激活)。", + "required_slots": ["target_ref"], + "goal": "将某个策略模板设为默认模板。", + "success_output": "返回 strategy_id,并明确告知该策略已被设为默认模板。", + "failure_output": "明确指出目标策略不存在,或激活失败原因。" }, "duplicate": { - "description": "复制策略模板。", - "required_slots": ["target_ref", "name"] + "description": "复制策略模板,生成一个新的同配置模板。", + "required_slots": ["target_ref", "name"], + "goal": "复制一个现有策略模板并生成新的模板名称。", + "dynamic_rules": [ + "新名称必须单独收齐;若名称有歧义或为空,应先继续追问。" + ], + "success_output": "返回新的 strategy_id,并明确告知复制后的策略名称。", + "failure_output": "明确指出目标策略不存在,或新名称仍未收齐。" }, - "query": { - "description": "查询策略模板。" + "delete": { + "description": "删除策略模板,不可逆操作,必须确认。", + "required_slots": ["target_ref"], + "needs_confirmation": true, + "goal": "删除一个策略模板。", + "dynamic_rules": [ + "必须在确认后执行,并明确提醒删除不可逆。", + "若策略是默认模板或受系统保护,应向用户解释限制。" + ], + "success_output": "返回删除成功结果,并明确告知该策略模板已被移除。", + "failure_output": "明确指出缺少确认、目标策略不存在,或删除失败原因。" + }, + "query_list": { + "description": "查询所有策略模板列表,包含名称、类型、是否为默认模板。", + "goal": "列出当前用户可见的策略模板,便于后续选择或绑定。", + "dynamic_rules": [ + "优先返回名称、类型、默认状态,不必展开全部详细配置。" + ], + "success_output": "返回策略模板列表摘要。", + "failure_output": "若列表为空,应明确告知当前没有策略模板。" + }, + "query_detail": { + "description": "查询某个策略模板的详细配置,包括选币来源、指标、风控参数、提示词等。", + "required_slots": ["target_ref"], + "goal": "读取一个策略模板的详细配置。", + "dynamic_rules": [ + "若目标有歧义,应先澄清再返回详情。" + ], + "success_output": "返回目标策略模板的详细配置摘要。", + "failure_output": "明确指出目标策略不存在,或当前引用已经失效。" } }, "tool_mapping": { "create": "manage_strategy:create", "update": "manage_strategy:update", - "delete": "manage_strategy:delete", + "update_name": "manage_strategy:update", + "update_prompt": "manage_strategy:update", + "update_config": "manage_strategy:update", "activate": "manage_strategy:activate", "duplicate": "manage_strategy:duplicate", - "query": "get_strategies" + "delete": "manage_strategy:delete", + "query_list": "get_strategies", + "query_detail": "get_strategies" } } diff --git a/agent/skills/trade_execution.json b/agent/skills/trade_execution.json new file mode 100644 index 00000000..9adfe7f9 --- /dev/null +++ b/agent/skills/trade_execution.json @@ -0,0 +1,63 @@ +{ + "name": "trade_execution", + "kind": "execution", + "domain": "trade", + "description": "当用户明确要求开仓、平仓、买入、卖出,或确认待执行的大额订单时调用。负责真实下单前的安全校验、待确认订单、确认执行与交易历史查询。", + "intents": [ + "下单交易", + "开多开空", + "平仓", + "确认大额订单", + "查询交易历史" + ], + "actions": { + "execute": { + "description": "创建一笔待确认交易。不会直接成交,而是先做风险检查,再给用户确认指令。", + "required_slots": ["action", "symbol", "quantity"], + "optional_slots": ["leverage", "trader_id"], + "needs_confirmation": true, + "goal": "在真实执行前先做风险检查,并给用户一个可确认的待执行订单。", + "dynamic_rules": [ + "只有当用户明确要求交易时才允许进入本动作;分析、建议、解释行情都不应触发下单。", + "开仓数量必须大于 0,单笔数量硬上限为 1000000,超过时直接拒绝。", + "会先按实时价格估算名义价值;单笔名义价值硬上限为 100000 USDT,超过时直接拒绝。", + "若单笔名义价值达到 5000 USDT,或达到账户权益的 25%,必须标记为大额订单,要求用户发送“确认大额 trade_xxx”后才执行。", + "若单笔名义价值超过账户权益的 100%,直接拒绝,不允许创建待确认订单。", + "加密货币订单的杠杆上限受策略 btceth_max_leverage / altcoin_max_leverage 约束,默认上限为 5x;超出时直接拒绝。", + "BTC/ETH 单笔最大仓位价值默认不超过 5 倍账户权益,山寨币默认不超过 1 倍账户权益;若策略里有自定义比例,以策略为准。", + "最小仓位价值固定为 12 USDT;这是系统强制项,不允许通过 Agent 修改。低于最小值时直接拒绝。", + "创建后的待确认订单默认 5 分钟有效,超时自动失效。" + ], + "success_output": "返回 trade_id、估算仓位价值、是否触发大额确认、确认命令和 5 分钟有效期。", + "failure_output": "用简单清楚的话说明是哪条风控挡住了,例如数量过大、仓位太小、杠杆过高、超过权益上限。" + }, + "confirm_large_order": { + "description": "确认一笔已创建的大额待执行订单。", + "required_slots": ["trade_id"], + "needs_confirmation": true, + "goal": "在用户明确确认后,执行已通过初步检查的大额订单。", + "dynamic_rules": [ + "用户必须发送“确认大额 trade_xxx”或“confirm large trade_xxx”才能执行大额订单。", + "若订单已过期、已不存在,或 trade_id 无效,要直接说明这笔订单已经失效。", + "若用户只发送普通确认,但订单被标记为大额订单,必须继续要求“大额确认”,不能直接放行。" + ], + "success_output": "明确告知订单已执行,并展示方向、品种、数量。", + "failure_output": "明确说明订单已过期、风控未通过,或执行失败原因。" + }, + "query_history": { + "description": "查询最近的交易历史。", + "optional_slots": ["limit", "trader_id"], + "goal": "让用户快速查看最近成交记录和交易结果。", + "dynamic_rules": [ + "优先返回最近几笔最重要的交易,不要一次性给太长的开发者原始日志。", + "若当前没有交易记录,要直接说明当前还没有成交记录。" + ], + "success_output": "返回最近交易记录摘要,包括方向、品种、时间和结果。", + "failure_output": "若没有记录或查询失败,要明确告知用户。" + } + }, + "tool_mapping": { + "execute": "execute_trade", + "query_history": "get_trade_history" + } +} diff --git a/agent/skills/trader_diagnosis.json b/agent/skills/trader_diagnosis.json index ae263145..4e9c96e3 100644 --- a/agent/skills/trader_diagnosis.json +++ b/agent/skills/trader_diagnosis.json @@ -2,5 +2,38 @@ "name": "trader_diagnosis", "kind": "diagnosis", "domain": "trader", - "description": "当用户反馈交易员无法启动、启动后不交易、绑定模型或交易所缺失、运行状态异常、收益或仓位表现异常时调用。适用于交易员运行过程中的排障与原因定位。不用于创建、修改、删除、启动、停止或查询交易员这类管理操作。" + "description": "当用户反馈交易员无法启动、启动后不交易、反复报错、绑定模型或交易所缺失、运行状态异常、收益或仓位表现异常时调用。适用于交易员运行过程中的排障与原因定位。不用于创建、修改、删除、启动、停止或查询交易员这类管理操作。", + "capabilities": [ + "读取交易员当前状态、账户、持仓和最近决策记录", + "读取交易员绑定的策略、模型、交易所配置摘要,并把它们纳入不开单诊断证据包", + "在用户明确指定目标交易员后,读取该交易员最近的后端日志", + "把完整证据合并成适合新手理解的最终原因和下一步行动" + ], + "dynamic_rules": [ + "当用户问“为什么报错”“为什么不交易”“为什么停了”这类问题时,优先走诊断而不是管理类 skill。", + "如果已经能唯一确定目标交易员,应一次性收集完整诊断证据包:交易员配置/运行状态、绑定策略、绑定模型、绑定交易所、账户权益/可用余额、当前持仓、get_decisions 最近决策记录、get_backend_logs 后台日志。不要只查其中一项就下结论。", + "面向普通用户的诊断回复只说最终原因和该怎么办,不要输出证据包清单、工具名、后台日志片段、HTTP 状态码或工程排障过程。", + "诊断结论内部必须区分:直接原因、次要影响、待确认因素。直接原因必须来自最近决策记录、交易所下单结果、风控校验或明确运行状态;后台日志里的零散错误只能作为辅助证据。", + "证据优先级固定为:最近决策记录 > 交易员运行状态/账户/持仓 > 交易所下单结果 > 后台日志。除非最近决策记录本身显示数据获取失败或 AI 决策中断,否则不要让 backend logs 盖过决策记录。", + "交易员不下单的排查顺序固定为:是否运行中 -> 是否已到扫描间隔 -> 策略候选币/行情数据是否为空 -> 最近 AI 决策是否为 hold/wait -> 风控是否拦截 -> 交易所下单是否报错 -> 余额、杠杆、仓位模式或权限是否限制。", + "判断“不下单/不开单”的主因时,最近决策记录优先级高于零散 backend error 日志;如果最近决策显示 wait succeeded,应解释为 AI 主动等待;如果最新决策 error_message 显示 opening amount too small / below minimum / must be ≥,应解释为开仓金额低于系统或交易所最小下单门槛。", + "遇到 opening amount too small、position value below minimum、must be ≥ 这类错误时,不要建议用户修改 AI 智能策略的 min_position_size 或 position_size_usd。先说明这是系统/交易所门槛或 System enforced 边界,再建议增加账户权益、换更适合小资金的交易标的、调整可编辑策略偏好,或让策略在资金不足时等待。", + "AI 智能策略里的 System enforced 字段(max_positions、btceth_max_position_value_ratio、altcoin_max_position_value_ratio、max_margin_usage、min_position_size)只能解释,不能建议用户修改;如果限制来自这些字段,行动建议必须落在产品实际可改项或用户账户/标的选择上。", + "不要只因为 backend logs 里出现 402、404、EOF、payment retry failed 就直接归因为数据服务、订阅到期或付款失败;这些内部异常不应在普通用户回答里出现,除非用户明确追问后台日志或技术细节。", + "402 不要直接翻译成“订阅到期”。在没有钱包余额、支付状态或服务侧确认前,不能说订阅过期;普通用户回答里也不要主动说 402。", + "如果最近决策记录显示 candidate_coins 非空、AI call completed、wait succeeded 或 open_* 决策已生成,则说明核心决策链路并非完全拿不到数据;此时不要把 402/404/EOF 说成不开单主因。", + "行动建议必须对应产品里真实存在且可修改的字段或操作。不要编造策略页不存在的 position_size_usd 参数,不要建议修改 System enforced 字段。", + "如果模型是 claw402 或 blockrun-base,应单独检查钱包 USDC 余额;余额不足时应说“支付余额不足/需要充值”,不要泛化成“模型没启用”。", + "如果日志显示 AI 返回 hold/wait,应解释为模型判断当前没有足够交易信号,不应误判为系统没有运行。", + "如果日志显示下单失败,应优先归因到交易所权限、API 凭证、仓位模式、余额、杠杆或 symbol 可交易性,而不是策略没有生效。", + "当用户表达“启动不了”“启动失败”“无法启动”“一启动就报错”“为什么启动不起来”这类启动故障时,只要目标交易员能唯一确定,就优先自动读取 get_backend_logs。", + "当证据中已经出现明确错误原因时,直接用人话解释最终原因和下一步,不要复述原始日志。" + ], + "tool_mapping": { + "query_runtime_state": "get_trader_system_status", + "query_positions": "get_positions", + "query_account": "get_account_info", + "query_recent_decisions": "get_decisions", + "query_backend_logs": "get_backend_logs" + } } diff --git a/agent/skills/trader_management.json b/agent/skills/trader_management.json index babd251d..cdff2e11 100644 --- a/agent/skills/trader_management.json +++ b/agent/skills/trader_management.json @@ -2,7 +2,7 @@ "name": "trader_management", "kind": "management", "domain": "trader", - "description": "当用户想创建、查看、修改、删除、启动或停止交易员时调用。适用于用户提到交易员名称、绑定交易所、绑定模型、绑定策略、扫描频率、自定义提示词、运行状态等管理需求。不用于排查交易员启动失败、未下单、收益异常、仓位异常等诊断问题。", + "description": "当用户想创建、查看、修改、删除、启动或停止交易员时调用。交易员是装配层;创建交易员时需要名称以及绑定的交易所、模型、策略。编辑交易员只允许修改手动面板可改的 6 项:绑定交易所、绑定模型、绑定策略、扫描间隔、保证金模式、是否展示到竞技场;不修改这些依赖对象的内部配置,也不在这里改名。若用户要改策略参数、模型配置或交易所凭证,应切到各自的 management skill。创建交易员时交易所、模型、策略既可以直接选择用户已有可用资源,也可以在当前主流程里先新建/启用对应资源,再继续完成交易员创建。不用于排查交易员启动失败、未下单、收益异常、仓位异常等诊断问题。", "intents": [ "创建交易员", "修改交易员", @@ -11,42 +11,221 @@ "停止交易员", "查询交易员" ], + "field_constraints": { + "name": { + "type": "string", + "required": true, + "max_length": 50, + "description": "交易员名称,用于识别和管理,最多 50 个字符。" + }, + "exchange_id": { + "type": "entity_ref", + "required": true, + "description": "绑定的交易所配置 ID,必须是已存在且已启用的交易所配置。" + }, + "ai_model_id": { + "type": "entity_ref", + "required": true, + "description": "绑定的 AI 模型配置 ID,必须是已存在且已启用的模型配置。" + }, + "strategy_id": { + "type": "entity_ref", + "required": true, + "description": "绑定的策略模板 ID,必须是已存在的策略模板。" + }, + "scan_interval_minutes": { + "type": "int", + "min": 3, + "max": 60, + "default": 5, + "description": "AI 扫描决策间隔,单位分钟,手动面板可配置范围 3~60 分钟。超出范围会自动收敛到边界值并告知用户。" + }, + "is_cross_margin": { + "type": "bool", + "default": true, + "description": "保证金模式。true = 全仓(cross margin),false = 逐仓(isolated margin)。" + }, + "show_in_competition": { + "type": "bool", + "default": true, + "description": "是否在竞技场中显示该交易员的成绩。" + }, + "auto_start": { + "type": "bool", + "default": false, + "description": "创建后是否立即启动交易员。启动前系统会校验绑定的交易所、模型、策略均可用。" + } + }, + "validation_rules": [ + "exchange_id 对应的交易所配置必须已启用(enabled=true),否则无法创建或启动交易员。", + "ai_model_id 对应的模型配置必须已启用(enabled=true)且配置完整(api_key、custom_model_name 不为空;custom_api_url 若填写必须为合法 HTTPS),否则无法创建或启动交易员。", + "strategy_id 对应的策略模板必须存在,否则无法创建交易员。", + "scan_interval_minutes 超出 3~60 范围时,系统自动收敛到边界值,并通过 LLM 告知用户已调整,询问是否接受。", + "交易员初始余额由系统在创建时自动读取绑定交易所账户净值,不接受用户手动设置、充值或修改。", + "交易员名称不能从模型 provider 自动推断;用户说“用 claw402”表示模型选择,不表示交易员名称叫 claw402。", + "用户明确指定模型、交易所或策略时,若该资源不存在、被禁用、配置不完整或钱包余额不足,必须说明具体原因并让用户确认修复或替换;不得静默换成另一个资源。", + "若用户指定 claw402 作为模型,但 claw402 钱包余额为 0 USDC,应提示先充值或确认临时改用其他可用模型;不得说成 claw402 未启用,除非 enabled 确实为 false。", + "启动交易员前,绑定的模型必须已启用且完整,绑定的交易所也必须已启用且通过对应交易所的完整性校验,否则拒绝启动并明确指出缺哪一项。", + "若绑定的是 OKX 交易所,启用前必须已有 passphrase;若绑定的是 Hyperliquid,启用前必须已有 wallet_addr;若绑定的是 Aster,启用前必须已有 user、signer、private_key;若绑定的是 Lighter,启用前必须已有 wallet_addr 和 api_key_private_key。", + "启动(start)和停止(stop)操作属于高风险操作,必须先向用户确认再执行。", + "删除(delete)操作不可逆,必须先向用户确认再执行。" + ], "actions": { "create": { - "description": "创建新的交易员。", - "required_slots": ["name", "exchange", "model"], - "optional_slots": ["strategy", "auto_start"] + "description": "创建新的交易员。若缺少交易所、模型或策略,可在当前流程内先选择已有资源,或切去对应 skill 新建/启用后自动回流继续。", + "required_slots": ["name", "exchange", "model", "strategy"], + "optional_slots": ["auto_start", "scan_interval_minutes", "is_cross_margin", "show_in_competition"], + "goal": "创建并初始化一个交易员。", + "dynamic_rules": [ + "若用户提到的交易所、模型或策略已经存在且可用,应优先直接补入对应槽位,不要重新创建。", + "如果用户明确指定某个模型 provider(如 claw402),应先尝试匹配该 provider 对应的模型配置;只有在说明原因并得到用户确认后,才可改用其他模型。", + "若用户没有提供交易员名称,应生成一个来自交易所/策略/方向的清晰名称,或向用户追问;不要把模型 provider、交易所类型或策略字段误用为交易员名称。", + "若依赖资源不存在、被禁用,或用户明确要求新建或启用,禁止直接报缺字段;应切去对应 management:create 或 management:update_status 子任务。", + "子任务成功后,系统会恢复当前交易员草稿并继续补齐剩余槽位。", + "scan_interval_minutes 超出 3~60 时,自动收敛并告知用户。", + "不要向用户收集或确认初始余额;创建时由系统自动读取绑定交易所账户净值作为初始余额。", + "创建完成后询问用户是否立即启动(auto_start),启动前再次确认。" + ], + "success_output": "返回 trader_id,并给出创建结果摘要(名称、绑定的交易所/模型/策略、是否已启动)。", + "failure_output": "用人话指出缺失依赖项,或说明当前正在进入哪个依赖子任务。" }, "update": { - "description": "更新已有交易员。", + "description": "更新已有交易员,但只处理手动面板允许的字段:换绑策略、交易所、模型,或修改扫描间隔、保证金模式、竞技场显示。", "required_slots": ["target_ref"], - "optional_slots": ["name", "exchange", "model", "strategy", "scan_interval_minutes", "custom_prompt"] + "optional_slots": ["exchange_id", "ai_model_id", "strategy_id", "scan_interval_minutes", "is_cross_margin", "show_in_competition"], + "goal": "更新一个已有交易员的手动面板字段,但不改动策略、模型、交易所内部配置。", + "dynamic_rules": [ + "只更新用户明确提到的字段,不要覆盖未提及的字段。", + "换绑交易所/模型/策略时,新的资源必须已存在且已启用;若是钱包付费模型,还要解释余额不足等支付状态。", + "用户明确要求换成某个模型/交易所/策略时,不能自动选择另一个看起来可用的资源,除非用户确认。", + "如果用户要求改名,应明确告知交易员改名不在这里处理。", + "如果用户实际上是想修改策略参数、模型配置或交易所凭证,不要继续留在 trader update;应切到对应 management skill。" + ], + "success_output": "返回更新后的 trader_id 与简短配置摘要,明确哪些字段已经生效。", + "failure_output": "明确指出目标交易员不存在、依赖资源不可用,或哪一个字段值仍需用户补充/修正。" }, - "delete": { - "description": "删除交易员。", + "update_bindings": { + "description": "修改交易员手动面板可编辑的字段,可同时修改绑定关系、扫描间隔、保证金模式、竞技场显示。", "required_slots": ["target_ref"], - "needs_confirmation": true + "optional_slots": ["exchange_id", "ai_model_id", "strategy_id", "scan_interval_minutes", "is_cross_margin", "show_in_competition"], + "goal": "调整交易员手动面板可编辑的字段,而不改动无关配置。", + "dynamic_rules": [ + "新绑定的资源必须已存在且已启用,否则提示用户先启用或新建。", + "当指定模型是 claw402 或 blockrun-base 且钱包余额不足时,应提示充值或让用户确认临时切换模型。", + "扫描间隔超出 3~60 时,自动收敛并告知用户。" + ], + "success_output": "返回 trader_id,并明确展示新的模型/交易所/策略绑定结果。", + "failure_output": "明确指出缺少哪个绑定目标,或当前依赖资源为什么不可直接绑定。" + }, + "configure_strategy": { + "description": "仅修改交易员绑定的策略。", + "required_slots": ["target_ref", "strategy_id"], + "goal": "为指定交易员换绑一个策略模板。", + "dynamic_rules": [ + "若用户提到的是不存在的策略,应优先澄清或引导创建,而不是静默失败。" + ], + "success_output": "返回 trader_id,并明确告知当前生效的 strategy_id/策略名称。", + "failure_output": "明确指出目标交易员或策略不存在,或策略仍需用户澄清。" + }, + "configure_exchange": { + "description": "仅修改交易员绑定的交易所。", + "required_slots": ["target_ref", "exchange_id"], + "goal": "为指定交易员换绑一个交易所配置。", + "dynamic_rules": [ + "新的交易所配置必须已启用且可用,否则提示用户先启用或补齐凭证。" + ], + "success_output": "返回 trader_id,并明确告知当前生效的 exchange_id/交易所名称。", + "failure_output": "明确指出目标交易员或交易所不存在,或交易所当前不可用。" + }, + "configure_model": { + "description": "仅修改交易员绑定的 AI 模型。", + "required_slots": ["target_ref", "ai_model_id"], + "goal": "为指定交易员换绑一个 AI 模型配置。", + "dynamic_rules": [ + "新的模型配置必须已启用且可调用,否则提示用户先启用或补齐模型配置。", + "若用户指定的是 claw402,应优先绑定 claw402;只有在钱包余额不足、凭证缺失或配置不可用且用户确认后,才允许改绑其他模型。" + ], + "success_output": "返回 trader_id,并明确告知当前生效的 ai_model_id/模型名称。", + "failure_output": "明确指出目标交易员或模型不存在,或模型当前不可用。" }, "start": { - "description": "启动交易员。", + "description": "启动交易员,使其开始自动交易。高风险操作,必须确认。", "required_slots": ["target_ref"], - "needs_confirmation": true + "needs_confirmation": true, + "goal": "让一个已配置好的交易员进入运行状态。", + "dynamic_rules": [ + "启动前系统会自动校验绑定的交易所、模型、策略是否均可用。", + "若绑定模型为 claw402 或 blockrun-base 且钱包余额不足,应提示充值或换模型;不要把它泛化成“模型不可用”。", + "若校验失败,用人话告知用户具体哪个依赖不可用,并引导修复。" + ], + "success_output": "返回 trader_id,并明确告知交易员已开始运行。", + "failure_output": "明确指出缺少确认、依赖资源不可用,或启动未通过校验。" }, "stop": { - "description": "停止交易员。", + "description": "停止交易员,使其停止自动交易。高风险操作,必须确认。", "required_slots": ["target_ref"], - "needs_confirmation": true + "needs_confirmation": true, + "goal": "让一个运行中的交易员停止自动交易。", + "dynamic_rules": [ + "若交易员当前并未运行,也应给用户清晰说明,而不是假装停止成功。" + ], + "success_output": "返回 trader_id,并明确告知交易员已停止。", + "failure_output": "明确指出缺少确认、目标交易员不存在,或当前状态无法停止。" }, - "query": { - "description": "查询交易员列表或状态。" + "delete": { + "description": "删除交易员,不可逆操作,必须确认。支持删除单个、多个或全部交易员。", + "required_slots": [], + "needs_confirmation": true, + "goal": "删除一个、多个或全部交易员及其运行入口。", + "dynamic_rules": [ + "必须在确认后执行,并明确提醒该操作不可逆。", + "删除范围可以是单个 target_ref、多个目标,或 bulk_scope=all。", + "删除前必须确认目标交易员都已停止;若存在运行中的交易员,不能删除,应要求用户先停止这些交易员。" + ], + "success_output": "返回删除成功结果,并明确告知哪些交易员已被移除。", + "failure_output": "明确指出缺少确认、目标交易员不存在、目标仍在运行,或删除失败原因。" + }, + "query_list": { + "description": "查询所有交易员列表,包含名称、运行状态、绑定信息。", + "goal": "列出当前用户可见的交易员,并给出足够的摘要用于后续选择。", + "dynamic_rules": [ + "优先返回名称、运行状态、绑定的模型/交易所/策略,不要冗余展开全部详情。" + ], + "success_output": "返回交易员列表摘要,便于用户继续指定目标对象。", + "failure_output": "若列表为空,应明确告知当前没有交易员,而不是返回模糊空结果。" + }, + "query_running": { + "description": "查询当前运行中的交易员列表。", + "goal": "仅列出处于运行状态的交易员。", + "dynamic_rules": [ + "若当前没有运行中的交易员,应明确告知为空。" + ], + "success_output": "返回当前运行中的交易员列表摘要。", + "failure_output": "若没有运行中的交易员,应明确返回空列表说明。" + }, + "query_detail": { + "description": "查询某个交易员的详细配置,包括绑定的交易所、模型、策略、扫描间隔、保证金模式等。", + "required_slots": ["target_ref"], + "goal": "读取一个交易员的详细配置和当前绑定信息。", + "dynamic_rules": [ + "若目标对象有歧义,应先澄清再读取详情。" + ], + "success_output": "返回目标交易员的详细配置摘要。", + "failure_output": "明确指出目标交易员不存在,或当前引用需要重新指定。" } }, "tool_mapping": { "create": "manage_trader:create", "update": "manage_trader:update", - "delete": "manage_trader:delete", + "update_bindings": "manage_trader:update", + "configure_strategy": "manage_trader:update", + "configure_exchange": "manage_trader:update", + "configure_model": "manage_trader:update", "start": "manage_trader:start", "stop": "manage_trader:stop", - "query": "manage_trader:list" + "delete": "manage_trader:delete", + "query_list": "manage_trader:list", + "query_running": "manage_trader:list", + "query_detail": "manage_trader:list" } } diff --git a/agent/strategy_draft.go b/agent/strategy_draft.go new file mode 100644 index 00000000..5089ed61 --- /dev/null +++ b/agent/strategy_draft.go @@ -0,0 +1,27 @@ +package agent + +import ( + "strings" +) + +func inferStandaloneStrategyName(text string) string { + value := strings.TrimSpace(text) + if value == "" || len([]rune(value)) > 50 { + return "" + } + if strategyCreateConfirmationReply(value) || strategyCreateDefaultConfigReply(value) || isCancelSkillReply(value) { + return "" + } + if parseStrategyTypeValue(value) != "" { + return "" + } + if containsAny(strings.ToLower(value), []string{"创建", "新建", "create", "grid_trading", "ai_trading"}) { + return "" + } + return value +} + +func activeHistoryMessageAsksStrategyName(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + return containsAny(lower, []string{"策略名", "名称", "名字", "叫什么", "name"}) +} diff --git a/agent/strategy_field_catalog.go b/agent/strategy_field_catalog.go new file mode 100644 index 00000000..d21589fb --- /dev/null +++ b/agent/strategy_field_catalog.go @@ -0,0 +1,224 @@ +package agent + +func manualStrategyEditableFieldKeys() []string { + return []string{ + "name", + "description", + "is_public", + "config_visible", + "strategy_type", + "symbol", + "grid_count", + "total_investment", + "leverage", + "upper_price", + "lower_price", + "use_atr_bounds", + "atr_multiplier", + "distribution", + "enable_direction_adjust", + "direction_bias_ratio", + "max_drawdown_pct", + "stop_loss_pct", + "daily_loss_limit_pct", + "use_maker_only", + "source_type", + "static_coins", + "excluded_coins", + "use_ai500", + "ai500_limit", + "use_oi_top", + "oi_top_limit", + "use_oi_low", + "oi_low_limit", + "primary_timeframe", + "primary_count", + "selected_timeframes", + "enable_ema", + "enable_macd", + "enable_rsi", + "enable_atr", + "enable_boll", + "enable_volume", + "enable_oi", + "enable_funding_rate", + "ema_periods", + "rsi_periods", + "atr_periods", + "boll_periods", + "nofxos_api_key", + "enable_quant_data", + "enable_quant_oi", + "enable_quant_netflow", + "enable_oi_ranking", + "oi_ranking_duration", + "oi_ranking_limit", + "enable_netflow_ranking", + "netflow_ranking_duration", + "netflow_ranking_limit", + "enable_price_ranking", + "price_ranking_duration", + "price_ranking_limit", + "btceth_max_leverage", + "altcoin_max_leverage", + "min_risk_reward_ratio", + "min_confidence", + "role_definition", + "trading_frequency", + "entry_standards", + "decision_process", + "custom_prompt", + } +} + +func manualStrategyEditableFieldKeysForType(strategyType string) []string { + common := []string{ + "name", + "description", + "is_public", + "config_visible", + "strategy_type", + } + switch strategyType { + case "grid_trading": + return append(common, + "symbol", + "grid_count", + "total_investment", + "leverage", + "upper_price", + "lower_price", + "use_atr_bounds", + "atr_multiplier", + "distribution", + "enable_direction_adjust", + "direction_bias_ratio", + "max_drawdown_pct", + "stop_loss_pct", + "daily_loss_limit_pct", + "use_maker_only", + ) + case "ai_trading": + return append(common, + "source_type", + "static_coins", + "excluded_coins", + "use_ai500", + "ai500_limit", + "use_oi_top", + "oi_top_limit", + "use_oi_low", + "oi_low_limit", + "primary_timeframe", + "primary_count", + "selected_timeframes", + "enable_ema", + "enable_macd", + "enable_rsi", + "enable_atr", + "enable_boll", + "enable_volume", + "enable_oi", + "enable_funding_rate", + "ema_periods", + "rsi_periods", + "atr_periods", + "boll_periods", + "nofxos_api_key", + "enable_quant_data", + "enable_quant_oi", + "enable_quant_netflow", + "enable_oi_ranking", + "oi_ranking_duration", + "oi_ranking_limit", + "enable_netflow_ranking", + "netflow_ranking_duration", + "netflow_ranking_limit", + "enable_price_ranking", + "price_ranking_duration", + "price_ranking_limit", + "btceth_max_leverage", + "altcoin_max_leverage", + "min_risk_reward_ratio", + "min_confidence", + "role_definition", + "trading_frequency", + "entry_standards", + "decision_process", + "custom_prompt", + ) + default: + return manualStrategyEditableFieldKeys() + } +} + +func agentStrategyUpdatableFieldKeys() []string { + return []string{ + "name", + "description", + "is_public", + "config_visible", + "strategy_type", + "symbol", + "grid_count", + "total_investment", + "leverage", + "upper_price", + "lower_price", + "use_atr_bounds", + "atr_multiplier", + "distribution", + "enable_direction_adjust", + "direction_bias_ratio", + "max_drawdown_pct", + "stop_loss_pct", + "daily_loss_limit_pct", + "use_maker_only", + "source_type", + "static_coins", + "excluded_coins", + "use_ai500", + "ai500_limit", + "use_oi_top", + "oi_top_limit", + "use_oi_low", + "oi_low_limit", + "primary_timeframe", + "primary_count", + "selected_timeframes", + "enable_ema", + "enable_macd", + "enable_rsi", + "enable_atr", + "enable_boll", + "enable_volume", + "enable_oi", + "enable_funding_rate", + "ema_periods", + "rsi_periods", + "atr_periods", + "boll_periods", + "nofxos_api_key", + "enable_quant_data", + "enable_quant_oi", + "enable_quant_netflow", + "enable_oi_ranking", + "oi_ranking_duration", + "oi_ranking_limit", + "enable_netflow_ranking", + "netflow_ranking_duration", + "netflow_ranking_limit", + "enable_price_ranking", + "price_ranking_duration", + "price_ranking_limit", + "btceth_max_leverage", + "altcoin_max_leverage", + "min_risk_reward_ratio", + "min_confidence", + "role_definition", + "trading_frequency", + "entry_standards", + "decision_process", + "custom_prompt", + } +} diff --git a/agent/stream_text.go b/agent/stream_text.go new file mode 100644 index 00000000..f57e8f52 --- /dev/null +++ b/agent/stream_text.go @@ -0,0 +1,49 @@ +package agent + +import "strings" + +func emitStreamText(onEvent func(event, data string), text string) { + if onEvent == nil { + return + } + for _, chunk := range splitStreamText(text) { + onEvent(StreamEventDelta, chunk) + } +} + +func splitStreamText(text string) []string { + text = strings.TrimSpace(text) + if text == "" { + return nil + } + + lines := strings.Split(text, "\n") + chunks := make([]string, 0, len(lines)*2) + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + start := 0 + for i, r := range line { + switch r { + case '。', '!', '?', '.', '!', '?', ';', ';', ':', ':', ',', ',': + part := strings.TrimSpace(line[start : i+len(string(r))]) + if part != "" { + chunks = append(chunks, part) + } + start = i + len(string(r)) + } + } + if start < len(line) { + part := strings.TrimSpace(line[start:]) + if part != "" { + chunks = append(chunks, part) + } + } + } + if len(chunks) == 0 { + return []string{text} + } + return chunks +} diff --git a/agent/tools.go b/agent/tools.go index be7e1f24..981287db 100644 --- a/agent/tools.go +++ b/agent/tools.go @@ -5,9 +5,11 @@ import ( "context" "encoding/json" "fmt" + "net/http" "os" "path/filepath" "sort" + "strconv" "strings" "time" @@ -16,14 +18,458 @@ import ( "nofx/safe" "nofx/security" "nofx/store" + "nofx/trader" + "nofx/trader/aster" + "nofx/trader/binance" + "nofx/trader/bitget" + "nofx/trader/bybit" + "nofx/trader/gate" + hyperliquidtrader "nofx/trader/hyperliquid" + "nofx/trader/indodax" + "nofx/trader/kucoin" + "nofx/trader/lighter" + "nofx/trader/okx" ) // cachedTools holds the static tool definitions (built once, reused per message). var cachedTools = buildAgentTools() +var ( + binanceFuturesAPIBaseURL = "https://fapi.binance.com" + marketDataHTTPClient = http.DefaultClient + traderInitialBalanceFetcher = defaultTraderInitialBalanceFetcher +) + // agentTools returns the tools available to the LLM for autonomous action. func agentTools() []mcp.Tool { return cachedTools } +func plannerToolsForText(text string) []mcp.Tool { + domain := plannerToolDomainForText(text) + compactStrategy := !looksLikeStrategyMutationIntent(text) + names := plannerToolNamesForDomain(domain) + return toolsByName(names, compactStrategy) +} + +func plannerToolDomainForText(text string) string { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return "general" + } + if containsAny(lower, []string{"诊断", "排查", "为什么", "为啥", "失败", "报错", "异常", "停止", "没下单", "failed", "error", "diagnose", "debug", "logs", "stopped", "not trading"}) { + return "diagnosis" + } + if hasExplicitManagementDomainCue(text, "exchange") || containsAny(lower, []string{"交易所", "exchange", "apikey", "secret", "passphrase", "wallet address", "api凭证"}) { + return "exchange" + } + if hasExplicitManagementDomainCue(text, "model") || containsAny(lower, []string{"ai model", "模型", "provider", "api key", "custom_model", "custom api"}) { + return "model" + } + if hasExplicitManagementDomainCue(text, "strategy") || containsAny(lower, []string{"策略", "strategy", "选币", "止盈", "止损", "杠杆", "风控", "risk control"}) { + return "strategy" + } + if hasExplicitManagementDomainCue(text, "trader") || containsAny(lower, []string{"交易员", "trader", "启动", "停止交易员", "扫描间隔", "竞技场"}) { + return "trader" + } + if containsAny(lower, []string{"余额", "资产", "仓位", "持仓", "订单", "成交", "交易历史", "balance", "position", "positions", "trade history", "account"}) { + return "account" + } + if containsAny(lower, []string{"行情", "价格", "k线", "kline", "market", "price", "btc", "eth", "sol", "usdt", "股票", "stock"}) { + return "market" + } + return "general" +} + +func plannerToolNamesForDomain(domain string) []string { + switch domain { + case "market": + return []string{"get_market_snapshot", "get_market_price", "get_kline", "search_stock"} + case "account": + return []string{"get_balance", "get_positions", "get_trade_history"} + case "trader": + return []string{"get_model_configs", "get_exchange_configs", "get_strategies", "manage_trader"} + case "model": + return []string{"get_model_configs", "manage_model_config"} + case "exchange": + return []string{"get_exchange_configs", "manage_exchange_config"} + case "strategy": + return []string{"get_strategies", "manage_strategy"} + case "diagnosis": + return []string{"get_decisions", "get_backend_logs", "get_model_configs", "get_exchange_configs", "get_strategies", "manage_trader"} + default: + return []string{ + "get_preferences", "manage_preferences", + "get_decisions", "get_backend_logs", + "get_exchange_configs", "manage_exchange_config", + "get_model_configs", "manage_model_config", + "get_strategies", "manage_strategy", + "manage_trader", + "get_balance", "get_positions", "get_trade_history", + "get_market_snapshot", "get_market_price", "get_kline", "search_stock", + } + } +} + +func toolsByName(names []string, compactStrategy bool) []mcp.Tool { + if len(names) == 0 { + return nil + } + byName := make(map[string]mcp.Tool, len(cachedTools)) + for _, tool := range cachedTools { + byName[tool.Function.Name] = tool + } + out := make([]mcp.Tool, 0, len(names)) + seen := make(map[string]bool, len(names)) + for _, name := range names { + if seen[name] { + continue + } + seen[name] = true + tool, ok := byName[name] + if !ok { + continue + } + if compactStrategy && name == "manage_strategy" { + tool = compactManageStrategyTool(tool) + } + out = append(out, tool) + } + return out +} + +func compactManageStrategyTool(tool mcp.Tool) mcp.Tool { + tool.Function.Description = "List, query, delete, activate, duplicate, create, or update strategy templates. Planning schema is compact; use action plus strategy_id/name/description/lang/is_public/config_visible, and include config only when the user explicitly provides strategy config fields." + tool.Function.Parameters = map[string]any{ + "type": "object", + "properties": map[string]any{ + "action": map[string]any{"type": "string", "enum": []string{"list", "create", "update", "delete", "activate", "duplicate", "get_default_config"}}, + "strategy_id": map[string]any{"type": "string"}, + "name": map[string]any{"type": "string"}, + "description": map[string]any{"type": "string"}, + "lang": map[string]any{"type": "string", "enum": []string{"zh", "en"}}, + "is_public": map[string]any{"type": "boolean"}, + "config_visible": map[string]any{"type": "boolean"}, + "config": map[string]any{"type": "object", "description": "Strategy config patch. Use precise field paths/objects from the user request; omit when listing/querying/deleting/activating/duplicating."}, + }, + "required": []string{"action"}, + } + return tool +} + +func looksLikeStrategyMutationIntent(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + return hasExplicitManagementDomainCue(text, "strategy") && + containsAny(lower, []string{"创建", "新建", "创一个", "创个", "建一个", "修改", "更新", "编辑", "调整", "配置", "create", "new", "update", "edit", "configure"}) +} + +func normalizedEntityName(value string) string { + return strings.ToLower(strings.TrimSpace(value)) +} + +func sameEntityName(a, b string) bool { + return normalizedEntityName(a) != "" && normalizedEntityName(a) == normalizedEntityName(b) +} + +func (a *Agent) ensureUniqueModelName(storeUserID, name, excludeID string) error { + models, err := a.store.AIModel().List(storeUserID) + if err != nil { + return err + } + for _, model := range models { + if model == nil || strings.TrimSpace(model.ID) == strings.TrimSpace(excludeID) { + continue + } + if sameEntityName(model.Name, name) { + return fmt.Errorf("model name %q already exists", strings.TrimSpace(name)) + } + } + return nil +} + +func (a *Agent) findModelByProvider(storeUserID, provider string) (*store.AIModel, error) { + models, err := a.store.AIModel().List(storeUserID) + if err != nil { + return nil, err + } + normalizedProvider := strings.ToLower(strings.TrimSpace(provider)) + for _, model := range models { + if model == nil { + continue + } + if strings.ToLower(strings.TrimSpace(model.Provider)) == normalizedProvider { + return model, nil + } + } + return nil, nil +} + +func (a *Agent) ensureUniqueExchangeAccountName(storeUserID, accountName, excludeID string) error { + exchanges, err := a.store.Exchange().List(storeUserID) + if err != nil { + return err + } + for _, exchange := range exchanges { + if exchange == nil || strings.TrimSpace(exchange.ID) == strings.TrimSpace(excludeID) { + continue + } + if sameEntityName(exchange.AccountName, accountName) { + return fmt.Errorf("exchange account name %q already exists", strings.TrimSpace(accountName)) + } + } + return nil +} + +func (a *Agent) ensureUniqueStrategyName(storeUserID, name, excludeID string) error { + strategies, err := a.store.Strategy().List(storeUserID) + if err != nil { + return err + } + for _, strategy := range strategies { + if strategy == nil || strings.TrimSpace(strategy.ID) == strings.TrimSpace(excludeID) { + continue + } + if sameEntityName(strategy.Name, name) { + return fmt.Errorf("strategy name %q already exists", strings.TrimSpace(name)) + } + } + return nil +} + +func (a *Agent) ensureUniqueTraderName(storeUserID, name, excludeID string) error { + traders, err := a.store.Trader().List(storeUserID) + if err != nil { + return err + } + for _, trader := range traders { + if trader == nil || strings.TrimSpace(trader.ID) == strings.TrimSpace(excludeID) { + continue + } + if sameEntityName(trader.Name, name) { + return fmt.Errorf("trader name %q already exists", strings.TrimSpace(name)) + } + } + return nil +} + +func stringArraySchema(description string) map[string]any { + return map[string]any{ + "type": "array", + "description": description, + "items": map[string]any{"type": "string"}, + } +} + +func intArraySchema(description string) map[string]any { + return map[string]any{ + "type": "array", + "description": description, + "items": map[string]any{"type": "number"}, + } +} + +func strategyConfigSchema() map[string]any { + return map[string]any{ + "type": "object", + "description": "Full or partial strategy config. Only include the fields you want to create or update.", + "properties": map[string]any{ + "strategy_type": map[string]any{"type": "string", "enum": []string{"ai_trading", "grid_trading"}, "description": "Top-level discriminator. ai_trading must use ai_config only. grid_trading must use grid_config only."}, + "language": map[string]any{"type": "string", "enum": []string{"zh", "en"}}, + "ai_config": map[string]any{ + "type": "object", + "description": "AI trading only. Do not include this for grid_trading.", + "properties": map[string]any{ + "coin_source": map[string]any{ + "type": "object", + "properties": map[string]any{ + "source_type": map[string]any{"type": "string", "enum": []string{"static", "ai500", "oi_top", "oi_low"}, "description": "Manual page coin source: static, ai500, oi_top, oi_low."}, + "static_coins": stringArraySchema("Static coin symbols such as BTCUSDT or ETHUSDT. Manual page allows at most 10. xyz: assets such as xyz:TSLA, xyz:GOLD, xyz:XYZ100 are also supported."), + "excluded_coins": stringArraySchema("Coin symbols to exclude from all sources."), + "use_ai500": map[string]any{"type": "boolean"}, + "ai500_limit": map[string]any{"type": "number", "minimum": 1, "maximum": 10, "description": "Manual page range 1-10."}, + "use_oi_top": map[string]any{"type": "boolean"}, + "oi_top_limit": map[string]any{"type": "number", "minimum": 1, "maximum": 10, "description": "Manual page range 1-10."}, + "use_oi_low": map[string]any{"type": "boolean"}, + "oi_low_limit": map[string]any{"type": "number", "minimum": 1, "maximum": 10, "description": "Manual page range 1-10."}, + }, + }, + "indicators": map[string]any{ + "type": "object", + "properties": map[string]any{ + "klines": map[string]any{ + "type": "object", + "properties": map[string]any{ + "primary_timeframe": map[string]any{"type": "string", "enum": []string{"1m", "3m", "5m", "15m", "30m", "1h", "2h", "4h", "6h", "8h", "12h", "1d", "3d", "1w"}}, + "primary_count": map[string]any{"type": "number", "minimum": 10, "maximum": 30, "description": "Manual page range 10-30."}, + "longer_timeframe": map[string]any{"type": "string"}, + "longer_count": map[string]any{"type": "number"}, + "enable_multi_timeframe": map[string]any{"type": "boolean"}, + "selected_timeframes": stringArraySchema("Selected analysis timeframes. Allowed values: 1m,3m,5m,15m,30m,1h,2h,4h,6h,8h,12h,1d,3d,1w. Manual page allows at most 4."), + }, + }, + "enable_raw_klines": map[string]any{"type": "boolean"}, + "enable_ema": map[string]any{"type": "boolean"}, + "enable_macd": map[string]any{"type": "boolean"}, + "enable_rsi": map[string]any{"type": "boolean"}, + "enable_atr": map[string]any{"type": "boolean"}, + "enable_boll": map[string]any{"type": "boolean"}, + "enable_volume": map[string]any{"type": "boolean"}, + "enable_oi": map[string]any{"type": "boolean"}, + "enable_funding_rate": map[string]any{"type": "boolean"}, + "ema_periods": intArraySchema("EMA periods such as [20,50]."), + "rsi_periods": intArraySchema("RSI periods such as [7,14]."), + "atr_periods": intArraySchema("ATR periods such as [14]."), + "boll_periods": intArraySchema("BOLL periods such as [20]."), + "nofxos_api_key": map[string]any{"type": "string"}, + "enable_quant_data": map[string]any{"type": "boolean"}, + "enable_quant_oi": map[string]any{"type": "boolean"}, + "enable_quant_netflow": map[string]any{"type": "boolean"}, + "enable_oi_ranking": map[string]any{"type": "boolean"}, + "oi_ranking_duration": map[string]any{"type": "string", "enum": []string{"1h", "4h", "24h"}}, + "oi_ranking_limit": map[string]any{"type": "number", "enum": []int{5, 10, 15, 20}}, + "enable_netflow_ranking": map[string]any{"type": "boolean"}, + "netflow_ranking_duration": map[string]any{"type": "string", "enum": []string{"1h", "4h", "24h"}}, + "netflow_ranking_limit": map[string]any{"type": "number", "enum": []int{5, 10, 15, 20}}, + "enable_price_ranking": map[string]any{"type": "boolean"}, + "price_ranking_duration": map[string]any{"type": "string", "enum": []string{"1h", "4h", "24h", "1h,4h,24h"}}, + "price_ranking_limit": map[string]any{"type": "number", "enum": []int{5, 10, 15, 20}}, + }, + }, + "custom_prompt": map[string]any{"type": "string"}, + "risk_control": map[string]any{ + "type": "object", + "properties": map[string]any{ + "btc_eth_max_leverage": map[string]any{"type": "number", "minimum": 1, "maximum": 20}, + "altcoin_max_leverage": map[string]any{"type": "number", "minimum": 1, "maximum": 20}, + "min_risk_reward_ratio": map[string]any{"type": "number", "minimum": 1, "maximum": 10, "description": "Manual page range 1-10, step 0.5."}, + "min_confidence": map[string]any{"type": "number", "minimum": 50, "maximum": 100, "description": "Manual page range 50-100."}, + }, + }, + "prompt_sections": map[string]any{ + "type": "object", + "properties": map[string]any{ + "role_definition": map[string]any{"type": "string"}, + "trading_frequency": map[string]any{"type": "string"}, + "entry_standards": map[string]any{"type": "string"}, + "decision_process": map[string]any{"type": "string"}, + }, + }, + }, + }, + "grid_config": map[string]any{ + "description": "Grid trading only. Do not include this for ai_trading.", + "type": "object", + "properties": map[string]any{ + "symbol": map[string]any{"type": "string", "enum": []string{"BTCUSDT", "ETHUSDT", "SOLUSDT", "BNBUSDT", "XRPUSDT", "DOGEUSDT"}, "description": "Manual page dropdown options for grid trading symbols."}, + "grid_count": map[string]any{"type": "number", "minimum": 5, "maximum": 50, "description": "Manual page range 5-50."}, + "total_investment": map[string]any{"type": "number", "minimum": 100, "description": "User's actual capital/margin budget for the grid strategy, not leveraged notional exposure. Minimum 100 USDT."}, + "leverage": map[string]any{"type": "number", "minimum": 1, "maximum": 5, "description": "Manual page range 1-5."}, + "upper_price": map[string]any{"type": "number"}, + "lower_price": map[string]any{"type": "number"}, + "use_atr_bounds": map[string]any{"type": "boolean"}, + "atr_multiplier": map[string]any{"type": "number", "minimum": 1, "maximum": 5, "description": "Manual page range 1-5, step 0.5."}, + "distribution": map[string]any{"type": "string", "enum": []string{"uniform", "gaussian", "pyramid"}}, + "max_drawdown_pct": map[string]any{"type": "number", "minimum": 5, "maximum": 50, "description": "Manual page range 5-50."}, + "stop_loss_pct": map[string]any{"type": "number", "minimum": 1, "maximum": 20, "description": "Manual page range 1-20."}, + "daily_loss_limit_pct": map[string]any{"type": "number", "minimum": 1, "maximum": 30, "description": "Manual page range 1-30."}, + "use_maker_only": map[string]any{"type": "boolean"}, + "enable_direction_adjust": map[string]any{"type": "boolean"}, + "direction_bias_ratio": map[string]any{"type": "number", "minimum": 0.55, "maximum": 0.9, "description": "Manual page range 0.55-0.90 (shown as 55%-90%)."}, + }, + }, + "publish_config": map[string]any{ + "type": "object", + "description": "Shared publish settings for both AI and grid strategies.", + "properties": map[string]any{ + "is_public": map[string]any{"type": "boolean"}, + "config_visible": map[string]any{"type": "boolean"}, + }, + }, + }, + } +} + +func modelConfigFieldsSchema() map[string]any { + return map[string]any{ + "model_id": map[string]any{ + "type": "string", + "description": "Existing model id for update/delete, or the desired id for create.", + }, + "provider": map[string]any{ + "type": "string", + "description": "Provider slug such as openai, claude, gemini, deepseek, qwen, kimi, grok, minimax, claw402, blockrun-base, or blockrun-sol.", + }, + "name": map[string]any{ + "type": "string", + "description": "Display name for the model binding.", + }, + "enabled": map[string]any{ + "type": "boolean", + "description": "Whether this model binding is enabled.", + }, + "api_key": map[string]any{ + "type": "string", + "description": "Provider credential. For standard providers this is an API key; for claw402/blockrun it is the wallet private key. Sensitive and never returned in full.", + }, + "custom_api_url": map[string]any{ + "type": "string", + "description": "Custom API base URL or endpoint override. Optional for standard providers; not used by claw402/blockrun.", + }, + "custom_model_name": map[string]any{ + "type": "string", + "description": "Actual upstream model name to send to the provider. Optional when the provider has a default model.", + }, + } +} + +func exchangeConfigFieldsSchema() map[string]any { + return map[string]any{ + "exchange_id": map[string]any{ + "type": "string", + "description": "Existing exchange account id. Required for update and delete.", + }, + "exchange_type": map[string]any{ + "type": "string", + "description": "Exchange type such as binance, bybit, okx, bitget, gate, kucoin, hyperliquid, aster, lighter, or indodax.", + }, + "account_name": map[string]any{ + "type": "string", + "description": "User-visible account name like Main, Testnet, or Mom Account.", + }, + "enabled": map[string]any{ + "type": "boolean", + "description": "Whether this exchange binding should be enabled.", + }, + "api_key": map[string]any{"type": "string", "description": "API key for CEX-style exchanges."}, + "secret_key": map[string]any{"type": "string", "description": "Secret key for CEX-style exchanges."}, + "passphrase": map[string]any{"type": "string", "description": "Optional passphrase, required by exchanges like OKX, Bitget, and KuCoin."}, + "testnet": map[string]any{"type": "boolean", "description": "Whether to use the exchange testnet/sandbox."}, + "hyperliquid_wallet_addr": map[string]any{"type": "string", "description": "Hyperliquid wallet address."}, + "hyperliquid_unified_account": map[string]any{"type": "boolean", "description": "Whether Hyperliquid unified account mode is enabled."}, + "aster_user": map[string]any{"type": "string", "description": "Aster user address."}, + "aster_signer": map[string]any{"type": "string", "description": "Aster signer address."}, + "aster_private_key": map[string]any{"type": "string", "description": "Aster private key."}, + "lighter_wallet_addr": map[string]any{"type": "string", "description": "LIGHTER wallet address."}, + "lighter_private_key": map[string]any{"type": "string", "description": "LIGHTER private key."}, + "lighter_api_key_private_key": map[string]any{"type": "string", "description": "LIGHTER API key private key."}, + "lighter_api_key_index": map[string]any{"type": "number", "description": "LIGHTER API key index."}, + } +} + +func traderConfigFieldsSchema() map[string]any { + return map[string]any{ + "trader_id": map[string]any{ + "type": "string", + "description": "Required for update, delete, start, and stop.", + }, + "name": map[string]any{"type": "string", "description": "Trader display name. Required for create."}, + "ai_model_id": map[string]any{"type": "string", "description": "Bound AI model id."}, + "exchange_id": map[string]any{"type": "string", "description": "Bound exchange id."}, + "strategy_id": map[string]any{"type": "string", "description": "Bound strategy id."}, + "scan_interval_minutes": map[string]any{"type": "number", "description": "Trading scan interval in minutes."}, + "is_cross_margin": map[string]any{"type": "boolean", "description": "Whether cross margin is enabled."}, + "show_in_competition": map[string]any{"type": "boolean", "description": "Whether to show this trader in competition views."}, + } +} + func buildAgentTools() []mcp.Tool { return []mcp.Tool{ { @@ -64,20 +510,33 @@ func buildAgentTools() []mcp.Tool { Type: "function", Function: mcp.FunctionDef{ Name: "get_backend_logs", - Description: "Get recent backend log lines for a trader diagnosis. Prefer this when the user asks why a specific trader failed, stopped, or behaved unexpectedly. Returns recent matching log lines for the authenticated user's trader.", + Description: "Get recent backend log lines for a trader diagnosis. Prefer this when the user asks why a specific trader failed, stopped, or behaved unexpectedly. Returns recent matching log lines for the authenticated user's trader. You can identify the trader by name or id — name is preferred when the user provides it.", Parameters: map[string]any{ "type": "object", "properties": map[string]any{ - "trader_id": map[string]any{ - "type": "string", - "description": "Trader id to diagnose. The backend verifies that this trader belongs to the authenticated user before returning logs.", - }, + "trader_id": map[string]any{"type": "string", "description": "Trader id to diagnose."}, + "trader_name": map[string]any{"type": "string", "description": "Trader name to diagnose. Used to look up the trader when id is not known."}, "limit": map[string]any{"type": "number", "description": "Maximum number of recent log lines to return. Default 30."}, "errors_only": map[string]any{"type": "boolean", "description": "When true, only return error-like log lines. Default true."}, }, }, }, }, + { + Type: "function", + Function: mcp.FunctionDef{ + Name: "get_decisions", + Description: "Get recent AI decision records for a trader diagnosis. Use this before concluding why a trader is not opening orders: it shows candidate coins, wait/hold/open decisions, validation errors, execution logs, and AI call duration.", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "trader_id": map[string]any{"type": "string", "description": "Trader id to diagnose."}, + "trader_name": map[string]any{"type": "string", "description": "Trader name to diagnose. Used to look up the trader when id is not known."}, + "limit": map[string]any{"type": "number", "description": "Maximum number of recent decision records to return. Default 5, max 20."}, + }, + }, + }, + }, { Type: "function", Function: mcp.FunctionDef{ @@ -90,7 +549,7 @@ func buildAgentTools() []mcp.Tool { Type: "function", Function: mcp.FunctionDef{ Name: "manage_exchange_config", - Description: "Create, update, or delete an exchange account binding. Use this when the user asks to add/edit/remove an exchange account, API key, secret, passphrase, wallet address, or account name. Sensitive fields are stored securely and are never returned in full.", + Description: "Create, update, or delete an exchange account binding. Use this when the user asks to add/edit/remove an exchange account, API key, secret, passphrase, wallet address, or account name. Prefer passing exact field values instead of vague summaries. Sensitive fields are stored securely and are never returned in full.", Parameters: map[string]any{ "type": "object", "properties": map[string]any{ @@ -98,35 +557,23 @@ func buildAgentTools() []mcp.Tool { "type": "string", "enum": []string{"create", "update", "delete"}, }, - "exchange_id": map[string]any{ - "type": "string", - "description": "Existing exchange account id. Required for update and delete.", - }, - "exchange_type": map[string]any{ - "type": "string", - "description": "Exchange type for a new binding, such as binance, bybit, okx, hyperliquid, aster, lighter, gate, kucoin, alpaca, forex, or metals.", - }, - "account_name": map[string]any{ - "type": "string", - "description": "User-visible account name like Main, Testnet, or Mom Account.", - }, - "enabled": map[string]any{ - "type": "boolean", - "description": "Whether this exchange binding should be enabled.", - }, - "api_key": map[string]any{"type": "string"}, - "secret_key": map[string]any{"type": "string"}, - "passphrase": map[string]any{"type": "string"}, - "testnet": map[string]any{"type": "boolean"}, - "hyperliquid_wallet_addr": map[string]any{"type": "string"}, - "hyperliquid_unified_account": map[string]any{"type": "boolean"}, - "aster_user": map[string]any{"type": "string"}, - "aster_signer": map[string]any{"type": "string"}, - "aster_private_key": map[string]any{"type": "string"}, - "lighter_wallet_addr": map[string]any{"type": "string"}, - "lighter_private_key": map[string]any{"type": "string"}, - "lighter_api_key_private_key": map[string]any{"type": "string"}, - "lighter_api_key_index": map[string]any{"type": "number"}, + "exchange_id": exchangeConfigFieldsSchema()["exchange_id"], + "exchange_type": exchangeConfigFieldsSchema()["exchange_type"], + "account_name": exchangeConfigFieldsSchema()["account_name"], + "enabled": exchangeConfigFieldsSchema()["enabled"], + "api_key": exchangeConfigFieldsSchema()["api_key"], + "secret_key": exchangeConfigFieldsSchema()["secret_key"], + "passphrase": exchangeConfigFieldsSchema()["passphrase"], + "testnet": exchangeConfigFieldsSchema()["testnet"], + "hyperliquid_wallet_addr": exchangeConfigFieldsSchema()["hyperliquid_wallet_addr"], + "hyperliquid_unified_account": exchangeConfigFieldsSchema()["hyperliquid_unified_account"], + "aster_user": exchangeConfigFieldsSchema()["aster_user"], + "aster_signer": exchangeConfigFieldsSchema()["aster_signer"], + "aster_private_key": exchangeConfigFieldsSchema()["aster_private_key"], + "lighter_wallet_addr": exchangeConfigFieldsSchema()["lighter_wallet_addr"], + "lighter_private_key": exchangeConfigFieldsSchema()["lighter_private_key"], + "lighter_api_key_private_key": exchangeConfigFieldsSchema()["lighter_api_key_private_key"], + "lighter_api_key_index": exchangeConfigFieldsSchema()["lighter_api_key_index"], }, "required": []string{"action"}, }, @@ -144,7 +591,7 @@ func buildAgentTools() []mcp.Tool { Type: "function", Function: mcp.FunctionDef{ Name: "manage_model_config", - Description: "Create, update, or delete an AI model binding. Use this when the user asks to add/edit/remove a model provider, API key, custom API URL, or custom model name. Sensitive fields are stored securely and are never returned in full.", + Description: "Create, update, or delete an AI model binding. Use this when the user asks to add/edit/remove a model provider, API key, custom API URL, or custom model name. Prefer passing exact field values instead of vague summaries. Sensitive fields are stored securely and are never returned in full.", Parameters: map[string]any{ "type": "object", "properties": map[string]any{ @@ -152,22 +599,13 @@ func buildAgentTools() []mcp.Tool { "type": "string", "enum": []string{"create", "update", "delete"}, }, - "model_id": map[string]any{ - "type": "string", - "description": "Existing model id for update/delete, or the desired id for create.", - }, - "provider": map[string]any{ - "type": "string", - "description": "Provider slug such as openai, claude, gemini, deepseek, qwen, kimi, grok, minimax, claw402, or blockrun-base.", - }, - "name": map[string]any{ - "type": "string", - "description": "Display name for a newly created model binding.", - }, - "enabled": map[string]any{"type": "boolean"}, - "api_key": map[string]any{"type": "string"}, - "custom_api_url": map[string]any{"type": "string"}, - "custom_model_name": map[string]any{"type": "string"}, + "model_id": modelConfigFieldsSchema()["model_id"], + "provider": modelConfigFieldsSchema()["provider"], + "name": modelConfigFieldsSchema()["name"], + "enabled": modelConfigFieldsSchema()["enabled"], + "api_key": modelConfigFieldsSchema()["api_key"], + "custom_api_url": modelConfigFieldsSchema()["custom_api_url"], + "custom_model_name": modelConfigFieldsSchema()["custom_model_name"], }, "required": []string{"action"}, }, @@ -185,7 +623,7 @@ func buildAgentTools() []mcp.Tool { Type: "function", Function: mcp.FunctionDef{ Name: "manage_strategy", - Description: "List, create, update, delete, activate, duplicate strategies, or get the default strategy config template. Use this when the user asks to create or edit a strategy template. Strategy templates are independent assets and do not require exchange/model bindings unless the user asks to run them via a trader.", + Description: "List, create, update, delete, activate, duplicate strategies, or get the default strategy config template. Use this when the user asks to create or edit a strategy template. Prefer passing precise field-level config patches in `config` instead of vague natural-language summaries.", Parameters: map[string]any{ "type": "object", "properties": map[string]any{ @@ -199,7 +637,7 @@ func buildAgentTools() []mcp.Tool { "lang": map[string]any{"type": "string", "enum": []string{"zh", "en"}}, "is_public": map[string]any{"type": "boolean"}, "config_visible": map[string]any{"type": "boolean"}, - "config": map[string]any{"type": "object", "description": "Full or partial strategy config JSON object, depending on action."}, + "config": strategyConfigSchema(), }, "required": []string{"action"}, }, @@ -209,7 +647,7 @@ func buildAgentTools() []mcp.Tool { Type: "function", Function: mcp.FunctionDef{ Name: "manage_trader", - Description: "List, create, update, delete, start, or stop traders. Use this when the user asks to create a trader, rename one, switch its exchange/model/strategy, delete it, or control its running state.", + Description: "List, create, update, delete, start, or stop traders. Trader edits are limited to exchange/model/strategy bindings, scan interval, margin mode, and competition visibility so they match the manual trader panel. If the user wants to modify the internal config of a strategy, model, or exchange, use the corresponding management tool instead.", Parameters: map[string]any{ "type": "object", "properties": map[string]any{ @@ -217,26 +655,14 @@ func buildAgentTools() []mcp.Tool { "type": "string", "enum": []string{"list", "create", "update", "delete", "start", "stop"}, }, - "trader_id": map[string]any{ - "type": "string", - "description": "Required for update, delete, start, and stop.", - }, - "name": map[string]any{"type": "string"}, - "ai_model_id": map[string]any{"type": "string"}, - "exchange_id": map[string]any{"type": "string"}, - "strategy_id": map[string]any{"type": "string"}, - "initial_balance": map[string]any{"type": "number"}, - "scan_interval_minutes": map[string]any{"type": "number"}, - "is_cross_margin": map[string]any{"type": "boolean"}, - "show_in_competition": map[string]any{"type": "boolean"}, - "btc_eth_leverage": map[string]any{"type": "number"}, - "altcoin_leverage": map[string]any{"type": "number"}, - "trading_symbols": map[string]any{"type": "string"}, - "custom_prompt": map[string]any{"type": "string"}, - "override_base_prompt": map[string]any{"type": "boolean"}, - "system_prompt_template": map[string]any{"type": "string"}, - "use_ai500": map[string]any{"type": "boolean"}, - "use_oi_top": map[string]any{"type": "boolean"}, + "trader_id": traderConfigFieldsSchema()["trader_id"], + "name": traderConfigFieldsSchema()["name"], + "ai_model_id": traderConfigFieldsSchema()["ai_model_id"], + "exchange_id": traderConfigFieldsSchema()["exchange_id"], + "strategy_id": traderConfigFieldsSchema()["strategy_id"], + "scan_interval_minutes": traderConfigFieldsSchema()["scan_interval_minutes"], + "is_cross_margin": traderConfigFieldsSchema()["is_cross_margin"], + "show_in_competition": traderConfigFieldsSchema()["show_in_competition"], }, "required": []string{"action"}, }, @@ -263,7 +689,7 @@ func buildAgentTools() []mcp.Tool { Type: "function", Function: mcp.FunctionDef{ Name: "execute_trade", - Description: "Execute a trade order (crypto or US stocks). Use this when the user explicitly asks to open/close a position. For stocks (e.g. AAPL, TSLA), use open_long to buy and close_long to sell. This creates a pending trade that requires user confirmation.", + Description: "Execute a trade order (crypto or US stocks). Use this only when the user explicitly asks to trade. For stocks (e.g. AAPL, TSLA), use open_long to buy and close_long to sell. This creates a pending trade first; it does not execute immediately. Large orders require an extra confirmation with 确认大额 trade_xxx / confirm large trade_xxx, and pending trades expire after 5 minutes.", Parameters: map[string]any{ "type": "object", "properties": map[string]any{ @@ -322,6 +748,56 @@ func buildAgentTools() []mcp.Tool { }, }, }, + { + Type: "function", + Function: mcp.FunctionDef{ + Name: "get_market_snapshot", + Description: "Get a real-time crypto market snapshot for analysis. Returns current price, 24h change, high/low, volume, funding rate, open interest, and recent K-line structure in one tool call. Prefer this when the user asks to analyze a coin, assess current行情, or wants a richer market read than a single price.", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "symbol": map[string]any{ + "type": "string", + "description": "Crypto trading symbol, for example BTC, ETH, BTCUSDT, or ETHUSDT.", + }, + "interval": map[string]any{ + "type": "string", + "description": "Kline interval for the structure snapshot, for example 5m, 15m, 1h, or 4h. Defaults to 15m.", + }, + "limit": map[string]any{ + "type": "number", + "description": "Number of recent candles to fetch for the structure snapshot. Defaults to 20 and is capped at 100.", + }, + }, + "required": []string{"symbol"}, + }, + }, + }, + { + Type: "function", + Function: mcp.FunctionDef{ + Name: "get_kline", + Description: "Get recent kline/candlestick data for a crypto symbol. Use this when the user asks for recent candles, K 线, recent price structure, or a short-term chart context.", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "symbol": map[string]any{ + "type": "string", + "description": "Crypto trading symbol, for example BTC, ETH, BTCUSDT, or ETHUSDT.", + }, + "interval": map[string]any{ + "type": "string", + "description": "Kline interval, for example 1m, 5m, 15m, 1h, 4h, or 1d. Defaults to 15m.", + }, + "limit": map[string]any{ + "type": "number", + "description": "Number of recent candles to fetch. Defaults to 50 and is capped at 300.", + }, + }, + "required": []string{"symbol"}, + }, + }, + }, { Type: "function", Function: mcp.FunctionDef{ @@ -358,6 +834,36 @@ func buildAgentTools() []mcp.Tool { }, }, }, + { + Type: "function", + Function: mcp.FunctionDef{ + Name: "get_watchlist", + Description: "Get the current Sentinel watchlist of monitored crypto symbols. Use this when the user asks which coins are being watched or monitored right now.", + Parameters: map[string]any{"type": "object", "properties": map[string]any{}}, + }, + }, + { + Type: "function", + Function: mcp.FunctionDef{ + Name: "manage_watchlist", + Description: "Add or remove a monitored crypto symbol from the Sentinel watchlist at runtime. Use this when the user asks to watch, monitor, unwatch, or stop monitoring a coin.", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "action": map[string]any{ + "type": "string", + "enum": []string{"add", "remove"}, + "description": "Whether to add or remove the symbol from the watchlist.", + }, + "symbol": map[string]any{ + "type": "string", + "description": "Crypto symbol to watch, such as BTC, ETH, SOL, BTCUSDT, or ETHUSDT.", + }, + }, + "required": []string{"action", "symbol"}, + }, + }, + }, } } @@ -370,6 +876,8 @@ func (a *Agent) handleToolCall(ctx context.Context, storeUserID string, userID i return a.toolManagePreferences(userID, tc.Function.Arguments) case "get_backend_logs": return a.toolGetBackendLogs(storeUserID, tc.Function.Arguments) + case "get_decisions": + return a.toolGetDecisions(storeUserID, tc.Function.Arguments) case "get_exchange_configs": return a.toolGetExchangeConfigs(storeUserID) case "manage_exchange_config": @@ -389,15 +897,23 @@ func (a *Agent) handleToolCall(ctx context.Context, storeUserID string, userID i case "execute_trade": return a.toolExecuteTrade(ctx, userID, lang, tc.Function.Arguments) case "get_positions": - return a.toolGetPositions() + return a.toolGetPositions(storeUserID) case "get_balance": - return a.toolGetBalance() + return a.toolGetBalance(storeUserID) case "get_market_price": return a.toolGetMarketPrice(tc.Function.Arguments) + case "get_market_snapshot": + return a.toolGetMarketSnapshot(tc.Function.Arguments) + case "get_kline": + return a.toolGetKline(tc.Function.Arguments) case "get_trade_history": return a.toolGetTradeHistory(tc.Function.Arguments) case "get_candidate_coins": return a.toolGetCandidateCoins(storeUserID, userID, tc.Function.Arguments) + case "get_watchlist": + return a.toolGetWatchlist(lang) + case "manage_watchlist": + return a.toolManageWatchlist(lang, tc.Function.Arguments) default: return fmt.Sprintf(`{"error": "unknown tool: %s"}`, tc.Function.Name) } @@ -419,6 +935,7 @@ type safeExchangeToolConfig struct { AsterUser string `json:"aster_user,omitempty"` AsterSigner string `json:"aster_signer,omitempty"` LighterWalletAddr string `json:"lighter_wallet_addr,omitempty"` + LighterAPIKeyIndex int `json:"lighter_api_key_index,omitempty"` HasLighterPrivateKey bool `json:"has_lighter_private_key"` HasLighterAPIKey bool `json:"has_lighter_api_key_private_key"` } @@ -431,24 +948,21 @@ type safeModelToolConfig struct { HasAPIKey bool `json:"has_api_key"` CustomAPIURL string `json:"custom_api_url,omitempty"` CustomModelName string `json:"custom_model_name,omitempty"` + WalletAddress string `json:"wallet_address,omitempty"` + BalanceUSDC string `json:"balance_usdc,omitempty"` } type safeTraderToolConfig struct { - ID string `json:"id"` - Name string `json:"name"` - AIModelID string `json:"ai_model_id"` - ExchangeID string `json:"exchange_id"` - StrategyID string `json:"strategy_id,omitempty"` - InitialBalance float64 `json:"initial_balance"` - ScanIntervalMinutes int `json:"scan_interval_minutes"` - IsRunning bool `json:"is_running"` - IsCrossMargin bool `json:"is_cross_margin"` - ShowInCompetition bool `json:"show_in_competition"` - BTCETHLeverage int `json:"btc_eth_leverage,omitempty"` - AltcoinLeverage int `json:"altcoin_leverage,omitempty"` - TradingSymbols string `json:"trading_symbols,omitempty"` - CustomPrompt string `json:"custom_prompt,omitempty"` - SystemPromptTemplate string `json:"system_prompt_template,omitempty"` + ID string `json:"id"` + Name string `json:"name"` + AIModelID string `json:"ai_model_id"` + ExchangeID string `json:"exchange_id"` + StrategyID string `json:"strategy_id,omitempty"` + InitialBalance float64 `json:"initial_balance"` + ScanIntervalMinutes int `json:"scan_interval_minutes"` + IsRunning bool `json:"is_running"` + IsCrossMargin bool `json:"is_cross_margin"` + ShowInCompetition bool `json:"show_in_competition"` } type safeStrategyToolConfig struct { @@ -463,25 +977,47 @@ type safeStrategyToolConfig struct { HasConfig bool `json:"has_config"` } +var sensitiveToolKeys = map[string]struct{}{ + "api_key": {}, + "secret_key": {}, + "passphrase": {}, + "private_key": {}, + "password_hash": {}, + "lighter_api_key_private_key": {}, +} + +func stripSensitiveToolFields(value any) any { + switch typed := value.(type) { + case map[string]any: + cleaned := make(map[string]any, len(typed)) + for key, inner := range typed { + if _, blocked := sensitiveToolKeys[strings.ToLower(strings.TrimSpace(key))]; blocked { + continue + } + cleaned[key] = stripSensitiveToolFields(inner) + } + return cleaned + case []any: + out := make([]any, 0, len(typed)) + for _, inner := range typed { + out = append(out, stripSensitiveToolFields(inner)) + } + return out + default: + return value + } +} + type manageTraderArgs struct { - Action string `json:"action"` - TraderID string `json:"trader_id"` - Name string `json:"name"` - AIModelID string `json:"ai_model_id"` - ExchangeID string `json:"exchange_id"` - StrategyID string `json:"strategy_id"` - InitialBalance *float64 `json:"initial_balance"` - ScanIntervalMinutes *int `json:"scan_interval_minutes"` - IsCrossMargin *bool `json:"is_cross_margin"` - ShowInCompetition *bool `json:"show_in_competition"` - BTCETHLeverage *int `json:"btc_eth_leverage"` - AltcoinLeverage *int `json:"altcoin_leverage"` - TradingSymbols string `json:"trading_symbols"` - CustomPrompt string `json:"custom_prompt"` - OverrideBasePrompt *bool `json:"override_base_prompt"` - SystemPromptTemplate string `json:"system_prompt_template"` - UseAI500 *bool `json:"use_ai500"` - UseOITop *bool `json:"use_oi_top"` + Action string `json:"action"` + TraderID string `json:"trader_id"` + Name string `json:"name"` + AIModelID string `json:"ai_model_id"` + ExchangeID string `json:"exchange_id"` + StrategyID string `json:"strategy_id"` + ScanIntervalMinutes *int `json:"scan_interval_minutes"` + IsCrossMargin *bool `json:"is_cross_margin"` + ShowInCompetition *bool `json:"show_in_competition"` } func safeExchangeForTool(ex *store.Exchange) safeExchangeToolConfig { @@ -501,13 +1037,97 @@ func safeExchangeForTool(ex *store.Exchange) safeExchangeToolConfig { AsterUser: ex.AsterUser, AsterSigner: ex.AsterSigner, LighterWalletAddr: ex.LighterWalletAddr, + LighterAPIKeyIndex: ex.LighterAPIKeyIndex, HasLighterPrivateKey: ex.LighterPrivateKey != "", HasLighterAPIKey: ex.LighterAPIKeyPrivateKey != "", } } +func defaultTraderInitialBalanceFetcher(exchangeCfg *store.Exchange, userID string) (float64, bool, error) { + if exchangeCfg == nil { + return 0, false, fmt.Errorf("exchange config not found") + } + probe, err := buildTraderExchangeProbe(exchangeCfg, userID) + if err != nil { + return 0, false, err + } + balanceInfo, err := probe.GetBalance() + if err != nil { + return 0, false, err + } + return extractTraderInitialBalance(balanceInfo) +} + +func buildTraderExchangeProbe(exchangeCfg *store.Exchange, userID string) (trader.Trader, error) { + switch exchangeCfg.ExchangeType { + case "binance": + return binance.NewFuturesTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), userID), nil + case "bybit": + return bybit.NewBybitTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey)), nil + case "okx": + return okx.NewOKXTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), string(exchangeCfg.Passphrase)), nil + case "bitget": + return bitget.NewBitgetTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), string(exchangeCfg.Passphrase)), nil + case "gate": + return gate.NewGateTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey)), nil + case "kucoin": + return kucoin.NewKuCoinTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), string(exchangeCfg.Passphrase)), nil + case "indodax": + return indodax.NewIndodaxTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey)), nil + case "hyperliquid": + return hyperliquidtrader.NewHyperliquidTrader( + string(exchangeCfg.APIKey), + exchangeCfg.HyperliquidWalletAddr, + exchangeCfg.Testnet, + exchangeCfg.HyperliquidUnifiedAcct, + ) + case "aster": + return aster.NewAsterTrader( + exchangeCfg.AsterUser, + exchangeCfg.AsterSigner, + string(exchangeCfg.AsterPrivateKey), + ) + case "lighter": + return lighter.NewLighterTraderV2( + exchangeCfg.LighterWalletAddr, + string(exchangeCfg.LighterAPIKeyPrivateKey), + exchangeCfg.LighterAPIKeyIndex, + false, + ) + default: + return nil, fmt.Errorf("unsupported exchange type: %s", exchangeCfg.ExchangeType) + } +} + +func extractTraderInitialBalance(balanceInfo map[string]interface{}) (float64, bool, error) { + for _, key := range []string{"total_equity", "totalEquity", "totalWalletBalance", "wallet_balance", "totalEq", "balance"} { + raw, ok := balanceInfo[key] + if !ok { + continue + } + switch v := raw.(type) { + case float64: + return v, true, nil + case float32: + return float64(v), true, nil + case int: + return float64(v), true, nil + case int64: + return float64(v), true, nil + case int32: + return float64(v), true, nil + case string: + parsed, err := strconv.ParseFloat(v, 64) + if err == nil { + return parsed, true, nil + } + } + } + return 0, false, fmt.Errorf("initial balance not set and unable to fetch balance from exchange") +} + func safeModelForTool(model *store.AIModel) safeModelToolConfig { - return safeModelToolConfig{ + safeModel := safeModelToolConfig{ ID: model.ID, Name: model.Name, Provider: model.Provider, @@ -516,6 +1136,18 @@ func safeModelForTool(model *store.AIModel) safeModelToolConfig { CustomAPIURL: model.CustomAPIURL, CustomModelName: model.CustomModelName, } + if agentProviderSupportsUSDCBalance(model.Provider) { + privateKey := strings.TrimSpace(string(model.APIKey)) + if privateKey != "" { + if walletAddress, err := agentWalletAddressFromPrivateKey(privateKey); err == nil && strings.TrimSpace(walletAddress) != "" { + safeModel.WalletAddress = walletAddress + if balance, balanceErr := agentQueryUSDCBalanceCached(walletAddress); balanceErr == nil { + safeModel.BalanceUSDC = fmt.Sprintf("%.6f", balance) + } + } + } + } + return safeModel } func modelConfigUsable(provider, modelID, apiKey, customAPIURL, customModelName string) bool { @@ -528,21 +1160,16 @@ func modelConfigUsable(provider, modelID, apiKey, customAPIURL, customModelName func safeTraderForTool(trader *store.Trader, isRunning bool) safeTraderToolConfig { return safeTraderToolConfig{ - ID: trader.ID, - Name: trader.Name, - AIModelID: trader.AIModelID, - ExchangeID: trader.ExchangeID, - StrategyID: trader.StrategyID, - InitialBalance: trader.InitialBalance, - ScanIntervalMinutes: trader.ScanIntervalMinutes, - IsRunning: isRunning, - IsCrossMargin: trader.IsCrossMargin, - ShowInCompetition: trader.ShowInCompetition, - BTCETHLeverage: trader.BTCETHLeverage, - AltcoinLeverage: trader.AltcoinLeverage, - TradingSymbols: trader.TradingSymbols, - CustomPrompt: trader.CustomPrompt, - SystemPromptTemplate: trader.SystemPromptTemplate, + ID: trader.ID, + Name: trader.Name, + AIModelID: trader.AIModelID, + ExchangeID: trader.ExchangeID, + StrategyID: trader.StrategyID, + InitialBalance: trader.InitialBalance, + ScanIntervalMinutes: trader.ScanIntervalMinutes, + IsRunning: isRunning, + IsCrossMargin: trader.IsCrossMargin, + ShowInCompetition: trader.ShowInCompetition, } } @@ -576,12 +1203,19 @@ func (a *Agent) toolGetExchangeConfigs(storeUserID string) string { } safe := make([]safeExchangeToolConfig, 0, len(exchanges)) for _, ex := range exchanges { + if !store.IsVisibleExchange(ex) { + continue + } safe = append(safe, safeExchangeForTool(ex)) } result, _ := json.Marshal(map[string]any{ "exchange_configs": safe, "count": len(safe), }) + var payload any + if err := json.Unmarshal(result, &payload); err == nil { + result, _ = json.Marshal(stripSensitiveToolFields(payload)) + } return string(result) } @@ -644,9 +1278,66 @@ func readBackendLogEntries(limit int, contains string, errorsOnly bool) (string, return path, matches, nil } +func filterBackendLogEntriesAny(entries []string, needles ...string) []string { + if len(entries) == 0 { + return nil + } + normalized := make([]string, 0, len(needles)) + for _, needle := range needles { + needle = strings.ToLower(strings.TrimSpace(needle)) + if needle == "" { + continue + } + normalized = append(normalized, needle) + } + if len(normalized) == 0 { + return entries + } + filtered := make([]string, 0, len(entries)) + for _, entry := range entries { + lower := strings.ToLower(entry) + for _, needle := range normalized { + if strings.Contains(lower, needle) { + filtered = append(filtered, entry) + break + } + } + } + return filtered +} + +func (a *Agent) resolveTraderForTool(storeUserID, traderID, traderName string) (*store.Trader, error) { + traderID = strings.TrimSpace(traderID) + traderName = strings.TrimSpace(traderName) + if traderID == "" && traderName == "" { + return nil, fmt.Errorf("trader_id or trader_name is required") + } + if traderID != "" { + traderCfg, err := a.store.Trader().GetByID(traderID) + if err != nil { + return nil, fmt.Errorf("failed to load trader: %w", err) + } + if traderCfg.UserID != storeUserID { + return nil, fmt.Errorf("trader not found for current user") + } + return traderCfg, nil + } + traders, err := a.store.Trader().List(storeUserID) + if err != nil { + return nil, fmt.Errorf("failed to list traders: %w", err) + } + for _, traderCfg := range traders { + if strings.EqualFold(strings.TrimSpace(traderCfg.Name), traderName) { + return traderCfg, nil + } + } + return nil, fmt.Errorf("trader %q not found", traderName) +} + func (a *Agent) toolGetBackendLogs(storeUserID, argsJSON string) string { var args struct { TraderID string `json:"trader_id"` + TraderName string `json:"trader_name"` Limit int `json:"limit"` ErrorsOnly *bool `json:"errors_only"` } @@ -655,30 +1346,31 @@ func (a *Agent) toolGetBackendLogs(storeUserID, argsJSON string) string { return fmt.Sprintf(`{"error":"invalid arguments: %s"}`, err) } } + if a.store == nil { + return `{"error":"store unavailable"}` + } errorsOnly := true if args.ErrorsOnly != nil { errorsOnly = *args.ErrorsOnly } - traderID := strings.TrimSpace(args.TraderID) - if traderID == "" { - return `{"error":"trader_id is required"}` - } - if a.store == nil { - return `{"error":"store unavailable"}` - } - trader, err := a.store.Trader().GetByID(traderID) + traderCfg, err := a.resolveTraderForTool(storeUserID, args.TraderID, args.TraderName) if err != nil { - return fmt.Sprintf(`{"error":"failed to load trader: %s"}`, err) + return fmt.Sprintf(`{"error":"%s"}`, err) } - if trader.UserID != storeUserID { - return `{"error":"trader not found for current user"}` - } - path, entries, err := readBackendLogEntries(args.Limit, traderID, errorsOnly) + path, entries, err := readBackendLogEntries(args.Limit, "", errorsOnly) if err != nil { return fmt.Sprintf(`{"error":"failed to read backend logs: %s"}`, err) } + entries = filterBackendLogEntriesAny(entries, traderCfg.ID, traderCfg.Name) + if args.Limit <= 0 { + args.Limit = 30 + } + if len(entries) > args.Limit { + entries = entries[len(entries)-args.Limit:] + } result, _ := json.Marshal(map[string]any{ - "trader_id": traderID, + "trader_id": traderCfg.ID, + "trader_name": traderCfg.Name, "log_file": path, "entries": entries, "count": len(entries), @@ -687,6 +1379,59 @@ func (a *Agent) toolGetBackendLogs(storeUserID, argsJSON string) string { return string(result) } +func (a *Agent) toolGetDecisions(storeUserID, argsJSON string) string { + var args struct { + TraderID string `json:"trader_id"` + TraderName string `json:"trader_name"` + Limit int `json:"limit"` + } + if strings.TrimSpace(argsJSON) != "" { + if err := json.Unmarshal([]byte(argsJSON), &args); err != nil { + return fmt.Sprintf(`{"error":"invalid arguments: %s"}`, err) + } + } + if a.store == nil { + return `{"error":"store unavailable"}` + } + traderCfg, err := a.resolveTraderForTool(storeUserID, args.TraderID, args.TraderName) + if err != nil { + return fmt.Sprintf(`{"error":"%s"}`, err) + } + limit := args.Limit + if limit <= 0 { + limit = 5 + } + if limit > 20 { + limit = 20 + } + records, err := a.store.Decision().GetLatestRecords(traderCfg.ID, limit) + if err != nil { + return fmt.Sprintf(`{"error":"failed to get decision records: %s"}`, err) + } + items := make([]map[string]any, 0, len(records)) + for _, record := range records { + items = append(items, map[string]any{ + "id": record.ID, + "cycle_number": record.CycleNumber, + "timestamp": record.Timestamp, + "success": record.Success, + "error_message": record.ErrorMessage, + "ai_request_duration_ms": record.AIRequestDurationMs, + "candidate_coins": record.CandidateCoins, + "execution_log": record.ExecutionLog, + "decisions": record.Decisions, + "decision_json": record.DecisionJSON, + }) + } + result, _ := json.Marshal(map[string]any{ + "trader_id": traderCfg.ID, + "trader_name": traderCfg.Name, + "count": len(items), + "records": items, + }) + return string(result) +} + func (a *Agent) toolManageExchangeConfig(storeUserID, argsJSON string) string { if a.store == nil { return `{"error":"store unavailable"}` @@ -717,13 +1462,18 @@ func (a *Agent) toolManageExchangeConfig(storeUserID, argsJSON string) string { action := strings.TrimSpace(args.Action) switch action { case "create": - if strings.TrimSpace(args.ExchangeType) == "" { + missing := missingRequiredActionSlots("exchange_management", "create", map[string]string{ + "exchange_type": strings.TrimSpace(args.ExchangeType), + "account_name": strings.TrimSpace(args.AccountName), + }) + if len(missing) > 0 { + return fmt.Sprintf(`{"error":"missing required fields for create: %s"}`, strings.Join(missing, ", ")) + } + exchangeType := strings.TrimSpace(args.ExchangeType) + if exchangeType == "" { return `{"error":"exchange_type is required for create"}` } - enabled := false - if args.Enabled != nil { - enabled = *args.Enabled - } + enabled := true testnet := false if args.Testnet != nil { testnet = *args.Testnet @@ -736,9 +1486,28 @@ func (a *Agent) toolManageExchangeConfig(storeUserID, argsJSON string) string { if args.LighterAPIKeyIndex != nil { lighterIndex = *args.LighterAPIKeyIndex } + if err := (exchangeConfigValidator{ + exchangeType: exchangeType, + enabled: enabled, + apiKey: strings.TrimSpace(args.APIKey), + secretKey: strings.TrimSpace(args.SecretKey), + passphrase: strings.TrimSpace(args.Passphrase), + hyperliquidWalletAddr: strings.TrimSpace(args.HyperliquidWalletAddr), + asterUser: strings.TrimSpace(args.AsterUser), + asterSigner: strings.TrimSpace(args.AsterSigner), + asterPrivateKey: strings.TrimSpace(args.AsterPrivateKey), + lighterWalletAddr: strings.TrimSpace(args.LighterWalletAddr), + lighterPrivateKey: strings.TrimSpace(args.LighterPrivateKey), + lighterAPIKeyPrivateKey: strings.TrimSpace(args.LighterAPIKeyPrivateKey), + }).Validate(); err != nil { + return fmt.Sprintf(`{"error":"%s"}`, err) + } + if err := a.ensureUniqueExchangeAccountName(storeUserID, strings.TrimSpace(args.AccountName), ""); err != nil { + return fmt.Sprintf(`{"error":"%s"}`, err) + } id, err := a.store.Exchange().Create( storeUserID, - strings.TrimSpace(args.ExchangeType), + exchangeType, strings.TrimSpace(args.AccountName), enabled, strings.TrimSpace(args.APIKey), @@ -767,6 +1536,28 @@ func (a *Agent) toolManageExchangeConfig(storeUserID, argsJSON string) string { "action": "create", "exchange": safeExchangeForTool(created), }) + var payload any + if err := json.Unmarshal(result, &payload); err == nil { + result, _ = json.Marshal(stripSensitiveToolFields(payload)) + } + return string(result) + case "query": + if strings.TrimSpace(args.ExchangeID) == "" { + return `{"error":"exchange_id is required for query"}` + } + existing, err := a.store.Exchange().GetByID(storeUserID, strings.TrimSpace(args.ExchangeID)) + if err != nil { + return fmt.Sprintf(`{"error":"failed to load exchange config: %s"}`, err) + } + result, _ := json.Marshal(map[string]any{ + "status": "ok", + "action": "query", + "exchange": safeExchangeForTool(existing), + }) + var payload any + if err := json.Unmarshal(result, &payload); err == nil { + result, _ = json.Marshal(stripSensitiveToolFields(payload)) + } return string(result) case "update": if strings.TrimSpace(args.ExchangeID) == "" { @@ -776,10 +1567,7 @@ func (a *Agent) toolManageExchangeConfig(storeUserID, argsJSON string) string { if err != nil { return fmt.Sprintf(`{"error":"failed to load exchange config: %s"}`, err) } - enabled := existing.Enabled - if args.Enabled != nil { - enabled = *args.Enabled - } + enabled := true testnet := existing.Testnet if args.Testnet != nil { testnet = *args.Testnet @@ -808,6 +1596,47 @@ func (a *Agent) toolManageExchangeConfig(storeUserID, argsJSON string) string { if strings.TrimSpace(args.LighterWalletAddr) != "" { lighterWallet = strings.TrimSpace(args.LighterWalletAddr) } + effectiveAPIKey := strings.TrimSpace(string(existing.APIKey)) + if trimmed := strings.TrimSpace(args.APIKey); trimmed != "" { + effectiveAPIKey = trimmed + } + effectiveSecretKey := strings.TrimSpace(string(existing.SecretKey)) + if trimmed := strings.TrimSpace(args.SecretKey); trimmed != "" { + effectiveSecretKey = trimmed + } + effectivePassphrase := strings.TrimSpace(string(existing.Passphrase)) + if trimmed := strings.TrimSpace(args.Passphrase); trimmed != "" { + effectivePassphrase = trimmed + } + effectiveAsterPrivateKey := strings.TrimSpace(string(existing.AsterPrivateKey)) + if trimmed := strings.TrimSpace(args.AsterPrivateKey); trimmed != "" { + effectiveAsterPrivateKey = trimmed + } + effectiveLighterPrivateKey := strings.TrimSpace(string(existing.LighterPrivateKey)) + if trimmed := strings.TrimSpace(args.LighterPrivateKey); trimmed != "" { + effectiveLighterPrivateKey = trimmed + } + effectiveLighterAPIKeyPrivateKey := strings.TrimSpace(string(existing.LighterAPIKeyPrivateKey)) + if trimmed := strings.TrimSpace(args.LighterAPIKeyPrivateKey); trimmed != "" { + effectiveLighterAPIKeyPrivateKey = trimmed + } + validator := exchangeConfigValidator{ + exchangeType: existing.ExchangeType, + enabled: true, + apiKey: effectiveAPIKey, + secretKey: effectiveSecretKey, + passphrase: effectivePassphrase, + hyperliquidWalletAddr: hyperWallet, + asterUser: asterUser, + asterSigner: asterSigner, + asterPrivateKey: effectiveAsterPrivateKey, + lighterWalletAddr: lighterWallet, + lighterPrivateKey: effectiveLighterPrivateKey, + lighterAPIKeyPrivateKey: effectiveLighterAPIKeyPrivateKey, + } + if err := validator.Validate(); err != nil { + return fmt.Sprintf(`{"error":"%s"}`, err) + } if err := a.store.Exchange().Update( storeUserID, existing.ID, @@ -829,6 +1658,9 @@ func (a *Agent) toolManageExchangeConfig(storeUserID, argsJSON string) string { return fmt.Sprintf(`{"error":"failed to update exchange config: %s"}`, err) } if trimmed := strings.TrimSpace(args.AccountName); trimmed != "" && trimmed != existing.AccountName { + if err := a.ensureUniqueExchangeAccountName(storeUserID, trimmed, existing.ID); err != nil { + return fmt.Sprintf(`{"error":"%s"}`, err) + } if err := a.store.Exchange().UpdateAccountName(storeUserID, existing.ID, trimmed); err != nil { return fmt.Sprintf(`{"error":"exchange updated but failed to rename account: %s"}`, err) } @@ -842,6 +1674,10 @@ func (a *Agent) toolManageExchangeConfig(storeUserID, argsJSON string) string { "action": "update", "exchange": safeExchangeForTool(updated), }) + var payload any + if err := json.Unmarshal(result, &payload); err == nil { + result, _ = json.Marshal(stripSensitiveToolFields(payload)) + } return string(result) case "delete": if strings.TrimSpace(args.ExchangeID) == "" { @@ -871,12 +1707,19 @@ func (a *Agent) toolGetModelConfigs(storeUserID string) string { } safe := make([]safeModelToolConfig, 0, len(models)) for _, model := range models { + if !store.IsVisibleAIModel(model) { + continue + } safe = append(safe, safeModelForTool(model)) } result, _ := json.Marshal(map[string]any{ "model_configs": safe, "count": len(safe), }) + var payload any + if err := json.Unmarshal(result, &payload); err == nil { + result, _ = json.Marshal(stripSensitiveToolFields(payload)) + } return string(result) } @@ -905,19 +1748,72 @@ func (a *Agent) toolManageModelConfig(storeUserID, argsJSON string) string { action := strings.TrimSpace(args.Action) switch action { case "create": + missing := missingRequiredActionSlots("model_management", "create", map[string]string{ + "provider": strings.TrimSpace(args.Provider), + }) + if len(missing) > 0 { + return fmt.Sprintf(`{"error":"missing required fields for create: %s"}`, strings.Join(missing, ", ")) + } provider := strings.TrimSpace(args.Provider) if provider == "" { return `{"error":"provider is required for create"}` } + if strings.TrimSpace(args.APIKey) == "" { + return `{"error":"api_key is required for create"}` + } modelID := strings.TrimSpace(args.ModelID) if modelID == "" { modelID = provider } - enabled := false + // Match the manual settings page: newly created model configs should be + // enabled unless the caller explicitly asks to keep them disabled. + enabled := true if args.Enabled != nil { enabled = *args.Enabled } - if err := a.store.AIModel().Update(storeUserID, modelID, enabled, strings.TrimSpace(args.APIKey), strings.TrimSpace(args.CustomAPIURL), strings.TrimSpace(args.CustomModelName)); err != nil { + name := strings.TrimSpace(args.Name) + if name == "" { + name = defaultModelConfigName(provider) + } + customModelName := strings.TrimSpace(args.CustomModelName) + if customModelName == "" && modelProviderSupportsCustomModel(provider) { + customModelName = defaultModelNameForProvider(provider) + } + customAPIURL := strings.TrimSpace(args.CustomAPIURL) + if !modelProviderSupportsCustomAPIURL(provider) { + customAPIURL = "" + } + if err := (modelConfigValidator{ + provider: provider, + enabled: enabled, + apiKey: strings.TrimSpace(args.APIKey), + customAPIURL: customAPIURL, + customModelName: customModelName, + modelID: modelID, + }).Validate(); err != nil { + return fmt.Sprintf(`{"error":"%s"}`, err) + } + existingByProvider, err := a.findModelByProvider(storeUserID, provider) + if err != nil { + return fmt.Sprintf(`{"error":"failed to inspect existing model configs: %s"}`, err) + } + excludeID := "" + if existingByProvider != nil { + modelID = existingByProvider.ID + excludeID = existingByProvider.ID + } + if err := a.ensureUniqueModelName(storeUserID, name, excludeID); err != nil { + return fmt.Sprintf(`{"error":"%s"}`, err) + } + if err := a.store.AIModel().UpdateWithName( + storeUserID, + modelID, + name, + enabled, + strings.TrimSpace(args.APIKey), + customAPIURL, + customModelName, + ); err != nil { return fmt.Sprintf(`{"error":"failed to create model config: %s"}`, err) } createdID := modelID @@ -936,6 +1832,10 @@ func (a *Agent) toolManageModelConfig(storeUserID, argsJSON string) string { "action": "create", "model": safeModelForTool(model), }) + var payload any + if err := json.Unmarshal(result, &payload); err == nil { + result, _ = json.Marshal(stripSensitiveToolFields(payload)) + } return string(result) case "update": modelID := strings.TrimSpace(args.ModelID) @@ -963,10 +1863,30 @@ func (a *Agent) toolManageModelConfig(storeUserID, argsJSON string) string { if apiKey != "" { effectiveAPIKey = apiKey } - if enabled && !modelConfigUsable(existing.Provider, existing.ID, effectiveAPIKey, customAPIURL, customModelName) { - return `{"error":"cannot enable model config before API key is configured"}` + if err := (modelConfigValidator{ + provider: existing.Provider, + enabled: enabled, + apiKey: effectiveAPIKey, + customAPIURL: customAPIURL, + customModelName: customModelName, + modelID: existing.ID, + }).Validate(); err != nil { + return fmt.Sprintf(`{"error":"%s"}`, err) } - if err := a.store.AIModel().Update(storeUserID, existing.ID, enabled, apiKey, customAPIURL, customModelName); err != nil { + if trimmed := strings.TrimSpace(args.Name); trimmed != "" && !sameEntityName(trimmed, existing.Name) { + if err := a.ensureUniqueModelName(storeUserID, trimmed, existing.ID); err != nil { + return fmt.Sprintf(`{"error":"%s"}`, err) + } + } + if err := a.store.AIModel().UpdateWithName( + storeUserID, + existing.ID, + strings.TrimSpace(args.Name), + enabled, + apiKey, + customAPIURL, + customModelName, + ); err != nil { return fmt.Sprintf(`{"error":"failed to update model config: %s"}`, err) } updated, err := a.store.AIModel().Get(storeUserID, existing.ID) @@ -978,6 +1898,10 @@ func (a *Agent) toolManageModelConfig(storeUserID, argsJSON string) string { "action": "update", "model": safeModelForTool(updated), }) + var payload any + if err := json.Unmarshal(result, &payload); err == nil { + result, _ = json.Marshal(stripSensitiveToolFields(payload)) + } return string(result) case "delete": modelID := strings.TrimSpace(args.ModelID) @@ -1008,6 +1932,9 @@ func (a *Agent) toolGetStrategies(storeUserID string) string { } safeStrategies := make([]safeStrategyToolConfig, 0, len(strategies)) for _, strategy := range strategies { + if !store.IsVisibleStrategy(strategy) { + continue + } safeStrategies = append(safeStrategies, safeStrategyForTool(strategy)) } result, _ := json.Marshal(map[string]any{ @@ -1029,6 +1956,8 @@ func (a *Agent) toolManageStrategy(storeUserID, argsJSON string) string { Lang string `json:"lang"` IsPublic *bool `json:"is_public"` ConfigVisible *bool `json:"config_visible"` + AllowClamped bool `json:"allow_clamped_update"` + Confirmed bool `json:"confirmed"` Config map[string]any `json:"config"` } if err := json.Unmarshal([]byte(argsJSON), &args); err != nil { @@ -1055,9 +1984,30 @@ func (a *Agent) toolManageStrategy(storeUserID, argsJSON string) string { if name == "" { return `{"error":"name is required for create"}` } - var cfg any = store.GetDefaultStrategyConfig(strings.TrimSpace(args.Lang)) + if !args.Confirmed { + return `{"error":"strategy create requires explicit chat confirmation before execution. Present the strategy config summary to the user and ask them to reply 确认创建; do not claim the strategy was created.","requires_confirmation":true}` + } + if lockedField, ok := strategyConfigContainsLockedField(args.Config); ok { + return fmt.Sprintf(`{"error":"%s"}`, strategyLockedFieldError("zh", lockedField)) + } + if err := a.ensureUniqueStrategyName(storeUserID, name, ""); err != nil { + return fmt.Sprintf(`{"error":"%s"}`, err) + } + defaultConfig := store.GetDefaultStrategyConfig(strings.TrimSpace(args.Lang)) + var cfg any = defaultConfig + var warnings []string if len(args.Config) > 0 { - cfg = args.Config + merged, err := store.MergeStrategyConfig(defaultConfig, args.Config) + if err != nil { + return fmt.Sprintf(`{"error":"invalid strategy config: %s"}`, err) + } + before := merged + merged.ClampLimits() + warnings = store.StrategyClampWarnings(before, merged, merged.Language) + if len(warnings) > 0 && !args.AllowClamped { + return fmt.Sprintf(`{"error":"%s"}`, formatRiskControlRefusalPrompt(merged.Language, warnings, "确认应用")) + } + cfg = merged } configJSON, err := json.Marshal(cfg) if err != nil { @@ -1081,6 +2031,7 @@ func (a *Agent) toolManageStrategy(storeUserID, argsJSON string) string { "status": "ok", "action": "create", "strategy": safeStrategyForTool(record), + "warnings": warnings, }) return string(payload) case "update": @@ -1088,6 +2039,9 @@ func (a *Agent) toolManageStrategy(storeUserID, argsJSON string) string { if strategyID == "" { return `{"error":"strategy_id is required for update"}` } + if lockedField, ok := strategyConfigContainsLockedField(args.Config); ok { + return fmt.Sprintf(`{"error":"%s"}`, strategyLockedFieldError("zh", lockedField)) + } existing, err := a.store.Strategy().Get(storeUserID, strategyID) if err != nil { return fmt.Sprintf(`{"error":"failed to load strategy: %s"}`, err) @@ -1099,6 +2053,11 @@ func (a *Agent) toolManageStrategy(storeUserID, argsJSON string) string { if trimmed := strings.TrimSpace(args.Name); trimmed != "" { name = trimmed } + if !sameEntityName(name, existing.Name) { + if err := a.ensureUniqueStrategyName(storeUserID, name, existing.ID); err != nil { + return fmt.Sprintf(`{"error":"%s"}`, err) + } + } description := existing.Description if trimmed := strings.TrimSpace(args.Description); trimmed != "" { description = trimmed @@ -1112,12 +2071,29 @@ func (a *Agent) toolManageStrategy(storeUserID, argsJSON string) string { configVisible = *args.ConfigVisible } configJSON := existing.Config + var warnings []string if len(args.Config) > 0 { - raw, err := json.Marshal(args.Config) + var existingConfig store.StrategyConfig + if strings.TrimSpace(existing.Config) != "" { + if err := json.Unmarshal([]byte(existing.Config), &existingConfig); err != nil { + return fmt.Sprintf(`{"error":"failed to load existing strategy config: %s"}`, err) + } + } + merged, err := store.MergeStrategyConfig(existingConfig, args.Config) + if err != nil { + return fmt.Sprintf(`{"error":"invalid strategy config: %s"}`, err) + } + before := merged + merged.ClampLimits() + warnings = store.StrategyClampWarnings(before, merged, merged.Language) + if len(warnings) > 0 && !args.AllowClamped { + return fmt.Sprintf(`{"error":"%s"}`, formatRiskControlRefusalPrompt(merged.Language, warnings, "确认应用")) + } + normalized, err := json.Marshal(merged) if err != nil { return fmt.Sprintf(`{"error":"failed to serialize strategy config: %s"}`, err) } - configJSON = string(raw) + configJSON = string(normalized) } record := &store.Strategy{ ID: existing.ID, @@ -1139,6 +2115,7 @@ func (a *Agent) toolManageStrategy(storeUserID, argsJSON string) string { "status": "ok", "action": "update", "strategy": safeStrategyForTool(updated), + "warnings": warnings, }) return string(payload) case "delete": @@ -1229,6 +2206,9 @@ func (a *Agent) toolManageStrategy(storeUserID, argsJSON string) string { if name == "" { return `{"error":"name is required for duplicate"}` } + if err := a.ensureUniqueStrategyName(storeUserID, name, ""); err != nil { + return fmt.Sprintf(`{"error":"%s"}`, err) + } newID := fmt.Sprintf("strategy_%d", time.Now().UnixNano()) if err := a.store.Strategy().Duplicate(storeUserID, sourceID, newID, name); err != nil { return fmt.Sprintf(`{"error":"failed to duplicate strategy: %s"}`, err) @@ -1280,8 +2260,27 @@ func (a *Agent) toolListTraders(storeUserID string) string { if err != nil { return fmt.Sprintf(`{"error":"failed to list traders: %s"}`, err) } + if len(traders) == 0 && a != nil && a.store != nil { + if all, listErr := a.store.Trader().ListAll(); listErr == nil && len(all) > 0 { + counts := make(map[string]int) + for _, trader := range all { + uid := strings.TrimSpace(trader.UserID) + if uid == "" { + uid = "default" + } + counts[uid]++ + } + a.log().Warn("toolListTraders returned empty for current store user while traders exist under other user scopes", + "store_user_id", storeUserID, + "known_user_scopes", counts, + ) + } + } safeTraders := make([]safeTraderToolConfig, 0, len(traders)) for _, traderCfg := range traders { + if !store.IsVisibleTrader(traderCfg) { + continue + } isRunning := traderCfg.IsRunning if a.traderManager != nil { if memTrader, err := a.traderManager.GetTrader(traderCfg.ID); err == nil { @@ -1300,32 +2299,13 @@ func (a *Agent) toolListTraders(storeUserID string) string { } func (a *Agent) validateTraderReferences(storeUserID, aiModelID, exchangeID, strategyID string) error { - if strings.TrimSpace(aiModelID) == "" { - return fmt.Errorf("ai_model_id is required") - } - if strings.TrimSpace(exchangeID) == "" { - return fmt.Errorf("exchange_id is required") - } - model, err := a.store.AIModel().Get(storeUserID, strings.TrimSpace(aiModelID)) - if err != nil { - return fmt.Errorf("invalid ai_model_id: %w", err) - } - if !model.Enabled { - return fmt.Errorf("ai model is disabled") - } - exchange, err := a.store.Exchange().GetByID(storeUserID, strings.TrimSpace(exchangeID)) - if err != nil { - return fmt.Errorf("invalid exchange_id: %w", err) - } - if !exchange.Enabled { - return fmt.Errorf("exchange is disabled") - } - if trimmed := strings.TrimSpace(strategyID); trimmed != "" { - if _, err := a.store.Strategy().Get(storeUserID, trimmed); err != nil { - return fmt.Errorf("invalid strategy_id: %w", err) - } - } - return nil + return (traderBindingValidator{ + store: a.store, + storeUserID: storeUserID, + aiModelID: aiModelID, + exchangeID: exchangeID, + strategyID: strategyID, + }).Validate() } func (a *Agent) toolCreateTrader(storeUserID string, args manageTraderArgs) string { @@ -1333,9 +2313,16 @@ func (a *Agent) toolCreateTrader(storeUserID string, args manageTraderArgs) stri if name == "" { return `{"error":"name is required for create"}` } + if err := a.ensureUniqueTraderName(storeUserID, name, ""); err != nil { + return fmt.Sprintf(`{"error":"%s"}`, err) + } if err := a.validateTraderReferences(storeUserID, args.AIModelID, args.ExchangeID, args.StrategyID); err != nil { return fmt.Sprintf(`{"error":"%s"}`, err) } + exchangeCfg, err := a.store.Exchange().GetByID(storeUserID, strings.TrimSpace(args.ExchangeID)) + if err != nil { + return fmt.Sprintf(`{"error":"failed to load exchange config: %s"}`, err) + } scanInterval := 3 if args.ScanIntervalMinutes != nil && *args.ScanIntervalMinutes > 0 { scanInterval = *args.ScanIntervalMinutes @@ -1343,9 +2330,12 @@ func (a *Agent) toolCreateTrader(storeUserID string, args manageTraderArgs) stri scanInterval = 3 } } - initialBalance := 0.0 - if args.InitialBalance != nil && *args.InitialBalance > 0 { - initialBalance = *args.InitialBalance + initialBalance, found, err := traderInitialBalanceFetcher(exchangeCfg, storeUserID) + if err != nil { + return fmt.Sprintf(`{"error":"failed to auto-read trader initial balance from exchange: %s"}`, err) + } + if !found { + return `{"error":"failed to auto-read trader initial balance from exchange"}` } isCrossMargin := true if args.IsCrossMargin != nil { @@ -1356,29 +2346,11 @@ func (a *Agent) toolCreateTrader(storeUserID string, args manageTraderArgs) stri showInCompetition = *args.ShowInCompetition } btcEthLeverage := 10 - if args.BTCETHLeverage != nil && *args.BTCETHLeverage > 0 { - btcEthLeverage = *args.BTCETHLeverage - } altcoinLeverage := 5 - if args.AltcoinLeverage != nil && *args.AltcoinLeverage > 0 { - altcoinLeverage = *args.AltcoinLeverage - } overrideBasePrompt := false - if args.OverrideBasePrompt != nil { - overrideBasePrompt = *args.OverrideBasePrompt - } useAI500 := false - if args.UseAI500 != nil { - useAI500 = *args.UseAI500 - } useOITop := false - if args.UseOITop != nil { - useOITop = *args.UseOITop - } - systemPromptTemplate := strings.TrimSpace(args.SystemPromptTemplate) - if systemPromptTemplate == "" { - systemPromptTemplate = "default" - } + systemPromptTemplate := "default" exchangeIDShort := strings.TrimSpace(args.ExchangeID) if len(exchangeIDShort) > 8 { exchangeIDShort = exchangeIDShort[:8] @@ -1398,10 +2370,10 @@ func (a *Agent) toolCreateTrader(storeUserID string, args manageTraderArgs) stri ShowInCompetition: showInCompetition, BTCETHLeverage: btcEthLeverage, AltcoinLeverage: altcoinLeverage, - TradingSymbols: strings.TrimSpace(args.TradingSymbols), + TradingSymbols: "", UseAI500: useAI500, UseOITop: useOITop, - CustomPrompt: strings.TrimSpace(args.CustomPrompt), + CustomPrompt: "", OverrideBasePrompt: overrideBasePrompt, SystemPromptTemplate: systemPromptTemplate, } @@ -1438,9 +2410,8 @@ func (a *Agent) toolUpdateTrader(storeUserID string, args manageTraderArgs) stri if existing == nil { return `{"error":"trader not found"}` } - name := existing.Name - if trimmed := strings.TrimSpace(args.Name); trimmed != "" { - name = trimmed + if trimmed := strings.TrimSpace(args.Name); trimmed != "" && !sameEntityName(trimmed, existing.Name) { + return `{"error":"trader rename is not supported here; only bindings, scan interval, margin mode, and competition visibility can be edited"}` } aiModelID := existing.AIModelID if trimmed := strings.TrimSpace(args.AIModelID); trimmed != "" { @@ -1460,7 +2431,7 @@ func (a *Agent) toolUpdateTrader(storeUserID string, args manageTraderArgs) stri record := &store.Trader{ ID: existing.ID, UserID: storeUserID, - Name: name, + Name: existing.Name, AIModelID: aiModelID, ExchangeID: exchangeID, StrategyID: strategyID, @@ -1478,9 +2449,6 @@ func (a *Agent) toolUpdateTrader(storeUserID string, args manageTraderArgs) stri OverrideBasePrompt: existing.OverrideBasePrompt, SystemPromptTemplate: existing.SystemPromptTemplate, } - if args.InitialBalance != nil && *args.InitialBalance > 0 { - record.InitialBalance = *args.InitialBalance - } if args.ScanIntervalMinutes != nil && *args.ScanIntervalMinutes > 0 { record.ScanIntervalMinutes = *args.ScanIntervalMinutes if record.ScanIntervalMinutes < 3 { @@ -1493,30 +2461,6 @@ func (a *Agent) toolUpdateTrader(storeUserID string, args manageTraderArgs) stri if args.ShowInCompetition != nil { record.ShowInCompetition = *args.ShowInCompetition } - if args.BTCETHLeverage != nil && *args.BTCETHLeverage > 0 { - record.BTCETHLeverage = *args.BTCETHLeverage - } - if args.AltcoinLeverage != nil && *args.AltcoinLeverage > 0 { - record.AltcoinLeverage = *args.AltcoinLeverage - } - if trimmed := strings.TrimSpace(args.TradingSymbols); trimmed != "" { - record.TradingSymbols = trimmed - } - if trimmed := strings.TrimSpace(args.CustomPrompt); trimmed != "" { - record.CustomPrompt = trimmed - } - if args.OverrideBasePrompt != nil { - record.OverrideBasePrompt = *args.OverrideBasePrompt - } - if trimmed := strings.TrimSpace(args.SystemPromptTemplate); trimmed != "" { - record.SystemPromptTemplate = trimmed - } - if args.UseAI500 != nil { - record.UseAI500 = *args.UseAI500 - } - if args.UseOITop != nil { - record.UseOITop = *args.UseOITop - } if err := a.store.Trader().Update(record); err != nil { return fmt.Sprintf(`{"error":"failed to update trader: %s"}`, err) } @@ -1536,13 +2480,27 @@ func (a *Agent) toolDeleteTrader(storeUserID, traderID string) string { if traderID == "" { return `{"error":"trader_id is required for delete"}` } + if a.traderManager != nil { + if trader, err := a.traderManager.GetTrader(traderID); err == nil { + if running, ok := trader.GetStatus()["is_running"].(bool); ok && running { + return `{"error":"trader is running; stop it before deleting"}` + } + } + } + if record, err := a.store.Trader().GetFullConfig(storeUserID, traderID); err == nil && record != nil && record.Trader != nil && record.Trader.IsRunning { + return `{"error":"trader is running; stop it before deleting"}` + } + if traders, err := a.store.Trader().List(storeUserID); err == nil { + for _, trader := range traders { + if trader != nil && trader.ID == traderID && trader.IsRunning { + return `{"error":"trader is running; stop it before deleting"}` + } + } + } if err := a.store.Trader().Delete(storeUserID, traderID); err != nil { return fmt.Sprintf(`{"error":"failed to delete trader: %s"}`, err) } if a.traderManager != nil { - if trader, err := a.traderManager.GetTrader(traderID); err == nil { - trader.Stop() - } a.traderManager.RemoveTrader(traderID) } result, _ := json.Marshal(map[string]any{ @@ -1737,7 +2695,15 @@ func (a *Agent) toolSearchStock(argsJSON string) string { return string(result) } -func (a *Agent) toolExecuteTrade(_ context.Context, userID int64, lang, argsJSON string) string { +func (a *Agent) toolExecuteTrade(ctx context.Context, userID int64, lang, argsJSON string) string { + policy := sessionPolicyFromContext(ctx) + if !policy.Authenticated { + return `{"error": "trade execution requires an authenticated session"}` + } + if !policy.CanExecuteTrade || a == nil || a.config == nil || !a.config.AllowTradeExecution { + return `{"error": "trade execution is blocked by server policy for this session"}` + } + var args struct { Action string `json:"action"` Symbol string `json:"symbol"` @@ -1802,20 +2768,33 @@ func (a *Agent) toolExecuteTrade(_ context.Context, userID int64, lang, argsJSON Status: "pending_confirmation", CreatedAt: time.Now().Unix(), } + if _, selectedTrader, underlyingTrader, err := a.resolveTradeExecutionContext(trade); err != nil { + return fmt.Sprintf(`{"error": %q}`, err.Error()) + } else if err := validateTradeAction(trade, isStockSymbol(sym), selectedTrader, underlyingTrader); err != nil { + return fmt.Sprintf(`{"error": %q}`, err.Error()) + } a.pending.Add(trade) a.pending.CleanExpired() + confirmMessage := fmt.Sprintf("Trade created. User must confirm with: 确认 %s (or: confirm %s)", trade.ID, trade.ID) + if trade.RequiresLargeOrderConfirmation { + confirmMessage = fmt.Sprintf("Trade created but flagged as high-risk. User must confirm with: 确认大额 %s (or: confirm large %s)", trade.ID, trade.ID) + } + // Return confirmation info to LLM so it can present it to the user resultMap := map[string]any{ - "status": "pending_confirmation", - "trade_id": trade.ID, - "action": trade.Action, - "symbol": trade.Symbol, - "quantity": trade.Quantity, - "leverage": trade.Leverage, - "message": fmt.Sprintf("Trade created. User must confirm with: 确认 %s (or: confirm %s)", trade.ID, trade.ID), - "expires": "5 minutes", + "status": "pending_confirmation", + "trade_id": trade.ID, + "action": trade.Action, + "symbol": trade.Symbol, + "quantity": trade.Quantity, + "leverage": trade.Leverage, + "estimated_price": trade.EstimatedPrice, + "estimated_notional": trade.EstimatedNotional, + "requires_large_order_confirmation": trade.RequiresLargeOrderConfirmation, + "message": confirmMessage, + "expires": "5 minutes", } if marketWarning != "" { resultMap["market_warning"] = marketWarning @@ -1824,13 +2803,27 @@ func (a *Agent) toolExecuteTrade(_ context.Context, userID int64, lang, argsJSON return string(result) } -func (a *Agent) toolGetPositions() string { +func (a *Agent) toolGetPositions(storeUserID string) string { if a.traderManager == nil { return `{"error": "no trader manager configured"}` } + if a.store == nil { + return `{"error": "store unavailable"}` + } + traderConfigs, err := a.store.Trader().List(storeUserID) + if err != nil { + return fmt.Sprintf(`{"error": "failed to list traders: %s"}`, err) + } var positions []map[string]any - for id, t := range a.traderManager.GetAllTraders() { + for _, traderCfg := range traderConfigs { + if strings.TrimSpace(traderCfg.ID) == "" { + continue + } + t, err := a.traderManager.GetTrader(traderCfg.ID) + if err != nil { + continue + } pos, err := t.GetPositions() if err != nil { continue @@ -1840,7 +2833,7 @@ func (a *Agent) toolGetPositions() string { if size == 0 { continue } - tid := id + tid := traderCfg.ID if len(tid) > 8 { tid = tid[:8] } @@ -1866,18 +2859,32 @@ func (a *Agent) toolGetPositions() string { return string(result) } -func (a *Agent) toolGetBalance() string { +func (a *Agent) toolGetBalance(storeUserID string) string { if a.traderManager == nil { return `{"error": "no trader manager configured"}` } + if a.store == nil { + return `{"error": "store unavailable"}` + } + traderConfigs, err := a.store.Trader().List(storeUserID) + if err != nil { + return fmt.Sprintf(`{"error": "failed to list traders: %s"}`, err) + } var balances []map[string]any - for id, t := range a.traderManager.GetAllTraders() { + for _, traderCfg := range traderConfigs { + if strings.TrimSpace(traderCfg.ID) == "" { + continue + } + t, err := a.traderManager.GetTrader(traderCfg.ID) + if err != nil { + continue + } info, err := t.GetAccountInfo() if err != nil { continue } - tid := id + tid := traderCfg.ID if len(tid) > 8 { tid = tid[:8] } @@ -1953,6 +2960,369 @@ func (a *Agent) toolGetMarketPrice(argsJSON string) string { return fmt.Sprintf(`{"error": "could not get price for %s"}`, sym) } +func binanceFuturesGET(path string, out any) error { + req, err := http.NewRequest(http.MethodGet, binanceFuturesAPIBaseURL+path, nil) + if err != nil { + return err + } + ctx, cancel := context.WithTimeout(context.Background(), 8*time.Second) + defer cancel() + req = req.WithContext(ctx) + + resp, err := marketDataHTTPClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("source returned status %d", resp.StatusCode) + } + return json.NewDecoder(resp.Body).Decode(out) +} + +func (a *Agent) toolGetMarketSnapshot(argsJSON string) string { + var args struct { + Symbol string `json:"symbol"` + Interval string `json:"interval"` + Limit int `json:"limit"` + } + if err := json.Unmarshal([]byte(argsJSON), &args); err != nil { + return fmt.Sprintf(`{"error":"invalid arguments: %s"}`, err) + } + + symbol := strings.ToUpper(strings.TrimSpace(args.Symbol)) + if symbol == "" { + return `{"error":"symbol is required"}` + } + if isStockSymbol(symbol) { + return `{"error":"get_market_snapshot currently supports crypto symbols only"}` + } + if !strings.HasSuffix(symbol, "USDT") { + symbol += "USDT" + } + + interval := strings.TrimSpace(strings.ToLower(args.Interval)) + if interval == "" { + interval = "15m" + } + if !validKlineInterval(interval) { + return fmt.Sprintf(`{"error":"invalid interval %q"}`, interval) + } + + limit := args.Limit + switch { + case limit <= 0: + limit = 20 + case limit > 100: + limit = 100 + } + + var ticker24h struct { + Symbol string `json:"symbol"` + LastPrice string `json:"lastPrice"` + PriceChange string `json:"priceChange"` + PriceChangePercent string `json:"priceChangePercent"` + HighPrice string `json:"highPrice"` + LowPrice string `json:"lowPrice"` + Volume string `json:"volume"` + QuoteVolume string `json:"quoteVolume"` + Count int64 `json:"count"` + } + if err := binanceFuturesGET("/fapi/v1/ticker/24hr?symbol="+symbol, &ticker24h); err != nil { + return fmt.Sprintf(`{"error":"failed to fetch 24h ticker for %s: %s"}`, symbol, err) + } + + var premiumIndex struct { + Symbol string `json:"symbol"` + MarkPrice string `json:"markPrice"` + IndexPrice string `json:"indexPrice"` + LastFundingRate string `json:"lastFundingRate"` + NextFundingTime int64 `json:"nextFundingTime"` + Time int64 `json:"time"` + } + if err := binanceFuturesGET("/fapi/v1/premiumIndex?symbol="+symbol, &premiumIndex); err != nil { + return fmt.Sprintf(`{"error":"failed to fetch funding data for %s: %s"}`, symbol, err) + } + + var openInterest struct { + OpenInterest string `json:"openInterest"` + Symbol string `json:"symbol"` + Time int64 `json:"time"` + } + if err := binanceFuturesGET("/fapi/v1/openInterest?symbol="+symbol, &openInterest); err != nil { + return fmt.Sprintf(`{"error":"failed to fetch open interest for %s: %s"}`, symbol, err) + } + + var rawKlines [][]any + if err := binanceFuturesGET(fmt.Sprintf("/fapi/v1/klines?symbol=%s&interval=%s&limit=%d", symbol, interval, limit), &rawKlines); err != nil { + return fmt.Sprintf(`{"error":"failed to fetch kline for %s: %s"}`, symbol, err) + } + if len(rawKlines) == 0 { + return fmt.Sprintf(`{"error":"empty kline response for %s"}`, symbol) + } + + klines := make([]map[string]any, 0, len(rawKlines)) + highestHigh := 0.0 + lowestLow := 0.0 + firstClose := 0.0 + lastClose := 0.0 + totalVolume := 0.0 + for i, row := range rawKlines { + if len(row) < 7 { + continue + } + openVal := toSnapshotFloat(row[1]) + highVal := toSnapshotFloat(row[2]) + lowVal := toSnapshotFloat(row[3]) + closeVal := toSnapshotFloat(row[4]) + volumeVal := toSnapshotFloat(row[5]) + if i == 0 { + firstClose = closeVal + highestHigh = highVal + lowestLow = lowVal + } + if highVal > highestHigh { + highestHigh = highVal + } + if lowestLow == 0 || (lowVal > 0 && lowVal < lowestLow) { + lowestLow = lowVal + } + lastClose = closeVal + totalVolume += volumeVal + klines = append(klines, map[string]any{ + "open_time": row[0], + "open": openVal, + "high": highVal, + "low": lowVal, + "close": closeVal, + "volume": volumeVal, + "close_time": row[6], + }) + } + + periodChangePercent := 0.0 + if firstClose > 0 && lastClose > 0 { + periodChangePercent = ((lastClose - firstClose) / firstClose) * 100 + } + + tickerLastPrice, _ := strconv.ParseFloat(strings.TrimSpace(ticker24h.LastPrice), 64) + tickerPriceChange, _ := strconv.ParseFloat(strings.TrimSpace(ticker24h.PriceChange), 64) + tickerPriceChangePercent, _ := strconv.ParseFloat(strings.TrimSpace(ticker24h.PriceChangePercent), 64) + tickerHighPrice, _ := strconv.ParseFloat(strings.TrimSpace(ticker24h.HighPrice), 64) + tickerLowPrice, _ := strconv.ParseFloat(strings.TrimSpace(ticker24h.LowPrice), 64) + tickerVolume, _ := strconv.ParseFloat(strings.TrimSpace(ticker24h.Volume), 64) + tickerQuoteVolume, _ := strconv.ParseFloat(strings.TrimSpace(ticker24h.QuoteVolume), 64) + markPrice, _ := strconv.ParseFloat(strings.TrimSpace(premiumIndex.MarkPrice), 64) + indexPrice, _ := strconv.ParseFloat(strings.TrimSpace(premiumIndex.IndexPrice), 64) + fundingRate, _ := strconv.ParseFloat(strings.TrimSpace(premiumIndex.LastFundingRate), 64) + oiValue, _ := strconv.ParseFloat(strings.TrimSpace(openInterest.OpenInterest), 64) + + out, _ := json.Marshal(map[string]any{ + "symbol": symbol, + "price": tickerLastPrice, + "ticker_24h": map[string]any{ + "price_change": tickerPriceChange, + "price_change_percent": tickerPriceChangePercent, + "high_price": tickerHighPrice, + "low_price": tickerLowPrice, + "volume": tickerVolume, + "quote_volume": tickerQuoteVolume, + "trade_count": ticker24h.Count, + }, + "perp_metrics": map[string]any{ + "mark_price": markPrice, + "index_price": indexPrice, + "funding_rate": fundingRate, + "next_funding_time": premiumIndex.NextFundingTime, + "open_interest": oiValue, + }, + "kline_snapshot": map[string]any{ + "interval": interval, + "limit": len(klines), + "period_change_percent": periodChangePercent, + "highest_high": highestHigh, + "lowest_low": lowestLow, + "average_volume": totalVolume / float64(maxInt(len(klines), 1)), + "recent_klines": klines, + }, + }) + return string(out) +} + +func toSnapshotFloat(value any) float64 { + switch v := value.(type) { + case string: + f, _ := strconv.ParseFloat(strings.TrimSpace(v), 64) + return f + case float64: + return v + case json.Number: + f, _ := v.Float64() + return f + default: + return 0 + } +} + +func maxInt(a, b int) int { + if a > b { + return a + } + return b +} + +func strategyLockedFieldError(lang, field string) string { + switch strings.TrimSpace(field) { + case "max_positions": + if lang == "zh" { + return "最大持仓数是 System enforced 字段,策略编辑页不提供普通输入控件,Agent 不能修改。" + } + return "Max positions is System enforced in the strategy editor and cannot be changed by the agent." + case "btceth_max_position_value_ratio": + if lang == "zh" { + return "BTC/ETH 单币仓位上限是 System enforced 字段,策略编辑页不提供普通输入控件,Agent 不能修改。" + } + return "BTC/ETH position value ratio is System enforced in the strategy editor and cannot be changed by the agent." + case "altcoin_max_position_value_ratio": + if lang == "zh" { + return "山寨币单币仓位上限是 System enforced 字段,策略编辑页不提供普通输入控件,Agent 不能修改。" + } + return "Altcoin position value ratio is System enforced in the strategy editor and cannot be changed by the agent." + case "max_margin_usage": + if lang == "zh" { + return "最大保证金使用率是 System enforced 字段,策略编辑页不提供普通输入控件,Agent 不能修改。" + } + return "Max margin usage is System enforced in the strategy editor and cannot be changed by the agent." + case "min_position_size": + if lang == "zh" { + return "最小开仓金额是系统固定值 12 USDT,手动面板里也是 System enforced,Agent 不能修改。" + } + return "The minimum position size is a fixed system value of 12 USDT. It is System enforced in the manual panel and cannot be changed by the agent." + default: + if lang == "zh" { + return "这个字段是系统固定项,Agent 不能修改。" + } + return "This field is system enforced and cannot be changed by the agent." + } +} + +func strategyConfigContainsLockedField(config map[string]any) (string, bool) { + if len(config) == 0 { + return "", false + } + if _, ok := config["min_position_size"]; ok { + return "min_position_size", true + } + if risk, ok := config["risk_control"].(map[string]any); ok { + for _, field := range []string{"max_positions", "btc_eth_max_position_value_ratio", "btceth_max_position_value_ratio", "altcoin_max_position_value_ratio", "max_margin_usage", "min_position_size"} { + if _, ok := risk[field]; ok { + return field, true + } + } + } + if aiConfig, ok := config["ai_config"].(map[string]any); ok { + if risk, ok := aiConfig["risk_control"].(map[string]any); ok { + for _, field := range []string{"max_positions", "btc_eth_max_position_value_ratio", "btceth_max_position_value_ratio", "altcoin_max_position_value_ratio", "max_margin_usage", "min_position_size"} { + if _, ok := risk[field]; ok { + return field, true + } + } + } + } + return "", false +} + +func validKlineInterval(interval string) bool { + switch strings.TrimSpace(strings.ToLower(interval)) { + case "1m", "3m", "5m", "15m", "30m", "1h", "2h", "4h", "6h", "8h", "12h", "1d", "3d", "1w", "1mo": + return true + default: + return false + } +} + +func (a *Agent) toolGetKline(argsJSON string) string { + var args struct { + Symbol string `json:"symbol"` + Interval string `json:"interval"` + Limit int `json:"limit"` + } + if err := json.Unmarshal([]byte(argsJSON), &args); err != nil { + return fmt.Sprintf(`{"error": "invalid arguments: %s"}`, err) + } + + symbol := strings.ToUpper(strings.TrimSpace(args.Symbol)) + if symbol == "" { + return `{"error": "symbol is required"}` + } + if !strings.HasSuffix(symbol, "USDT") { + symbol += "USDT" + } + + interval := strings.TrimSpace(strings.ToLower(args.Interval)) + if interval == "" { + interval = "15m" + } + if !validKlineInterval(interval) { + return fmt.Sprintf(`{"error":"invalid interval %q"}`, interval) + } + + limit := args.Limit + switch { + case limit <= 0: + limit = 50 + case limit > 300: + limit = 300 + } + + url := fmt.Sprintf("https://fapi.binance.com/fapi/v1/klines?symbol=%s&interval=%s&limit=%d", symbol, interval, limit) + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return fmt.Sprintf(`{"error":"failed to create request: %s"}`, err) + } + ctx, cancel := context.WithTimeout(context.Background(), 8*time.Second) + defer cancel() + req = req.WithContext(ctx) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Sprintf(`{"error":"failed to fetch kline for %s: %s"}`, symbol, err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Sprintf(`{"error":"kline source returned status %d for %s"}`, resp.StatusCode, symbol) + } + + var raw [][]any + if err := json.NewDecoder(resp.Body).Decode(&raw); err != nil { + return fmt.Sprintf(`{"error":"failed to parse kline response: %s"}`, err) + } + + candles := make([]map[string]any, 0, len(raw)) + for _, row := range raw { + if len(row) < 7 { + continue + } + candles = append(candles, map[string]any{ + "open_time": row[0], + "open": row[1], + "high": row[2], + "low": row[3], + "close": row[4], + "volume": row[5], + "close_time": row[6], + }) + } + + out, _ := json.Marshal(map[string]any{ + "symbol": symbol, + "interval": interval, + "limit": limit, + "klines": candles, + }) + return string(out) +} + func (a *Agent) toolGetTradeHistory(argsJSON string) string { if a.store == nil { return `{"error": "store not available"}` @@ -2187,6 +3557,94 @@ func candidateCoinDetails(coins []kernel.CandidateCoin) []map[string]any { return out } +func normalizeWatchSymbol(raw string) string { + symbol := strings.ToUpper(strings.TrimSpace(raw)) + symbol = strings.ReplaceAll(symbol, " ", "") + if symbol == "" { + return "" + } + hasQuoteSuffix := strings.HasSuffix(symbol, "USDT") || strings.HasSuffix(symbol, "BUSD") || strings.HasSuffix(symbol, "USDC") + if !hasQuoteSuffix && isStockSymbol(symbol) == false { + return symbol + "USDT" + } + return symbol +} + +func (a *Agent) toolGetWatchlist(lang string) string { + if a.sentinel == nil { + return fmt.Sprintf(`{"error":"%s"}`, a.msg(lang, "sentinel_off")) + } + symbols := a.sentinel.Symbols() + payload := map[string]any{ + "enabled": true, + "count": len(symbols), + "symbols": symbols, + "text": a.sentinel.FormatWatchlist(lang), + } + raw, _ := json.Marshal(payload) + return string(raw) +} + +func (a *Agent) toolManageWatchlist(lang, argsJSON string) string { + if a.sentinel == nil { + return fmt.Sprintf(`{"error":"%s"}`, a.msg(lang, "sentinel_off")) + } + + var args struct { + Action string `json:"action"` + Symbol string `json:"symbol"` + } + if err := json.Unmarshal([]byte(argsJSON), &args); err != nil { + return fmt.Sprintf(`{"error":"invalid arguments: %s"}`, err) + } + + action := strings.ToLower(strings.TrimSpace(args.Action)) + symbol := normalizeWatchSymbol(args.Symbol) + if symbol == "" { + return `{"error":"symbol is required"}` + } + + switch action { + case "add": + a.sentinel.AddSymbol(symbol) + case "remove": + a.sentinel.RemoveSymbol(symbol) + default: + return `{"error":"unsupported action"}` + } + + symbols := a.sentinel.Symbols() + if a.config != nil { + a.config.WatchSymbols = symbols + } + + message := "" + if lang == "zh" { + if action == "add" { + message = fmt.Sprintf("已把 %s 加入监控。", symbol) + } else { + message = fmt.Sprintf("已把 %s 移出监控。", symbol) + } + } else { + if action == "add" { + message = fmt.Sprintf("Added %s to the watchlist.", symbol) + } else { + message = fmt.Sprintf("Removed %s from the watchlist.", symbol) + } + } + + payload := map[string]any{ + "ok": true, + "action": action, + "symbol": symbol, + "count": len(symbols), + "symbols": symbols, + "message": message, + } + raw, _ := json.Marshal(payload) + return string(raw) +} + // knownCryptoSymbols is a set of well-known cryptocurrency base symbols. // Without this, isStockSymbol("BTC") would incorrectly return true because // "BTC" is 3 uppercase letters and the suffix check only catches "BTCUSDT"-style pairs. diff --git a/agent/tools_test.go b/agent/tools_test.go deleted file mode 100644 index d7ea4918..00000000 --- a/agent/tools_test.go +++ /dev/null @@ -1,65 +0,0 @@ -package agent - -import "testing" - -func TestIsStockSymbol(t *testing.T) { - tests := []struct { - sym string - want bool - }{ - // Known crypto base symbols — must NOT be detected as stock - {"BTC", false}, - {"ETH", false}, - {"SOL", false}, - {"BNB", false}, - {"XRP", false}, - {"DOGE", false}, - {"ADA", false}, - {"AVAX", false}, - {"DOT", false}, - {"LINK", false}, - {"PEPE", false}, - {"SHIB", false}, - {"TRUMP", false}, - {"USDT", false}, - {"USDC", false}, - {"W", false}, // single letter crypto - - // Crypto pairs — must NOT be stock - {"BTCUSDT", false}, - {"ETHUSDT", false}, - {"SOLUSDT", false}, - {"DOGEUSDT", false}, - - // Real stock tickers — must be detected as stock - {"AAPL", true}, - {"TSLA", true}, - {"NVDA", true}, - {"MSFT", true}, - {"GOOGL", true}, - {"AMZN", true}, - {"META", true}, - {"AMD", true}, - {"PLTR", true}, - {"BA", true}, - {"F", true}, // Ford — 1 letter - {"GM", true}, // 2 letters - {"JPM", true}, // 3 letters - - // Mixed / edge cases - {"btc", false}, // lowercase crypto - {"aapl", true}, // lowercase stock (uppercased internally) - {"BTC123", false}, // not pure letters - {"123456", false}, // digits - {"", false}, - } - - for _, tt := range tests { - t.Run(tt.sym, func(t *testing.T) { - got := isStockSymbol(tt.sym) - if got != tt.want { - t.Errorf("isStockSymbol(%q) = %v, want %v", tt.sym, got, tt.want) - } - }) - } -} diff --git a/agent/trade.go b/agent/trade.go index 14e08a79..6c48c0ce 100644 --- a/agent/trade.go +++ b/agent/trade.go @@ -5,22 +5,50 @@ import ( "encoding/json" "fmt" "log/slog" + "math" + "nofx/store" "strings" "sync" "time" ) +const ( + tradeAbsoluteMaxQuantity = 1_000_000.0 + tradeLargeOrderNotionalUSDT = 5_000.0 + tradeHardMaxOrderNotionalUSDT = 100_000.0 + tradeLargeOrderEquityRatio = 0.25 + tradeHardMaxOrderEquityRatio = 1.00 + tradeLargeOrderConfirmCommandZH = "确认大额 %s" + tradeLargeOrderConfirmCommandEN = "confirm large %s" +) + +type tradeSelectedTrader interface { + GetStrategyConfig() *store.StrategyConfig + GetAccountInfo() (map[string]interface{}, error) +} + +type tradeUnderlyingTrader interface { + OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) + OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) + CloseLong(symbol string, quantity float64) (map[string]interface{}, error) + CloseShort(symbol string, quantity float64) (map[string]interface{}, error) + GetMarketPrice(symbol string) (float64, error) +} + // TradeAction represents a parsed trade intent from the LLM or user. type TradeAction struct { - ID string `json:"id"` - Action string `json:"action"` // "open_long", "open_short", "close_long", "close_short" - Symbol string `json:"symbol"` // e.g. "BTCUSDT" - Quantity float64 `json:"quantity"` // amount - Leverage int `json:"leverage"` // leverage multiplier - TraderID string `json:"trader_id"` // which trader to use - Status string `json:"status"` // "pending", "confirmed", "executed", "failed", "expired" - CreatedAt int64 `json:"created_at"` - Error string `json:"error,omitempty"` + ID string `json:"id"` + Action string `json:"action"` // "open_long", "open_short", "close_long", "close_short" + Symbol string `json:"symbol"` // e.g. "BTCUSDT" + Quantity float64 `json:"quantity"` // amount + Leverage int `json:"leverage"` // leverage multiplier + TraderID string `json:"trader_id"` // which trader to use + Status string `json:"status"` // "pending", "confirmed", "executed", "failed", "expired" + CreatedAt int64 `json:"created_at"` + EstimatedPrice float64 `json:"estimated_price,omitempty"` + EstimatedNotional float64 `json:"estimated_notional,omitempty"` + RequiresLargeOrderConfirmation bool `json:"requires_large_order_confirmation,omitempty"` + Error string `json:"error,omitempty"` } // pendingTrades stores pending trade confirmations. @@ -149,57 +177,12 @@ func (a *Agent) executeTrade(ctx context.Context, trade *TradeAction) error { return fmt.Errorf("no trader manager available") } - traders := a.traderManager.GetAllTraders() - if len(traders) == 0 { - return fmt.Errorf("no traders configured") + wantStock, selectedTrader, underlyingTrader, err := a.resolveTradeExecutionContext(trade) + if err != nil { + return err } - - // Determine if this is a stock trade to route to the right exchange - wantStock := isStockSymbol(trade.Symbol) - - // Find a running trader's underlying exchange interface - var underlyingTrader interface { - OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) - OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) - CloseLong(symbol string, quantity float64) (map[string]interface{}, error) - CloseShort(symbol string, quantity float64) (map[string]interface{}, error) - } - - for _, t := range traders { - s := t.GetStatus() - running, _ := s["is_running"].(bool) - if running { - ut := t.GetUnderlyingTrader() - if ut == nil { - continue - } - // Route stock symbols to alpaca traders, crypto to others - exchange := t.GetExchange() - isAlpaca := exchange == "alpaca" - if wantStock && !isAlpaca { - continue // Skip non-stock traders for stock symbols - } - if !wantStock && isAlpaca { - continue // Skip stock traders for crypto symbols - } - underlyingTrader = ut - break - } - } - - if underlyingTrader == nil { - if wantStock { - return fmt.Errorf("no running stock trader (Alpaca) found — configure one to trade stocks") - } - return fmt.Errorf("no running trader supports trade execution") - } - - // Sanity caps to prevent LLM hallucinations or input errors from causing damage. - const maxQuantity = 100000.0 - const maxLeverage = 125 - - if trade.Leverage > maxLeverage { - return fmt.Errorf("leverage %dx exceeds maximum allowed (%dx)", trade.Leverage, maxLeverage) + if err := validateTradeAction(trade, wantStock, selectedTrader, underlyingTrader); err != nil { + return err } switch trade.Action { @@ -207,18 +190,12 @@ func (a *Agent) executeTrade(ctx context.Context, trade *TradeAction) error { if trade.Quantity <= 0 { return fmt.Errorf("quantity must be > 0") } - if trade.Quantity > maxQuantity { - return fmt.Errorf("quantity %.4f exceeds maximum allowed (%.0f)", trade.Quantity, maxQuantity) - } _, err := underlyingTrader.OpenLong(trade.Symbol, trade.Quantity, trade.Leverage) return err case "open_short": if trade.Quantity <= 0 { return fmt.Errorf("quantity must be > 0") } - if trade.Quantity > maxQuantity { - return fmt.Errorf("quantity %.4f exceeds maximum allowed (%.0f)", trade.Quantity, maxQuantity) - } _, err := underlyingTrader.OpenShort(trade.Symbol, trade.Quantity, trade.Leverage) return err case "close_long": @@ -232,6 +209,172 @@ func (a *Agent) executeTrade(ctx context.Context, trade *TradeAction) error { } } +func (a *Agent) resolveTradeExecutionContext(trade *TradeAction) (bool, tradeSelectedTrader, tradeUnderlyingTrader, error) { + if a.traderManager == nil { + return false, nil, nil, fmt.Errorf("no trader manager available") + } + traders := a.traderManager.GetAllTraders() + if len(traders) == 0 { + return false, nil, nil, fmt.Errorf("no traders configured") + } + + wantStock := isStockSymbol(trade.Symbol) + for _, t := range traders { + s := t.GetStatus() + running, _ := s["is_running"].(bool) + if !running { + continue + } + ut := t.GetUnderlyingTrader() + if ut == nil { + continue + } + exchange := t.GetExchange() + isAlpaca := exchange == "alpaca" + if wantStock && !isAlpaca { + continue + } + if !wantStock && isAlpaca { + continue + } + return wantStock, t, ut, nil + } + + if wantStock { + return true, nil, nil, fmt.Errorf("no running stock trader (Alpaca) found — configure one to trade stocks") + } + return false, nil, nil, fmt.Errorf("no running trader supports trade execution") +} + +func validateTradeAction( + trade *TradeAction, + wantStock bool, + selectedTrader tradeSelectedTrader, + underlyingTrader tradeUnderlyingTrader, +) error { + if trade == nil { + return fmt.Errorf("trade is required") + } + if math.IsNaN(trade.Quantity) || math.IsInf(trade.Quantity, 0) { + return fmt.Errorf("quantity must be a finite number") + } + if !strings.HasPrefix(trade.Action, "open_") { + return nil + } + if trade.Quantity <= 0 { + return fmt.Errorf("quantity must be > 0") + } + if trade.Quantity > tradeAbsoluteMaxQuantity { + return fmt.Errorf("quantity %.4f exceeds hard sanity cap %.0f", trade.Quantity, tradeAbsoluteMaxQuantity) + } + + price, err := underlyingTrader.GetMarketPrice(trade.Symbol) + if err != nil { + return fmt.Errorf("failed to fetch market price for %s: %w", trade.Symbol, err) + } + if price <= 0 { + return fmt.Errorf("invalid market price for %s", trade.Symbol) + } + positionValue := trade.Quantity * price + trade.EstimatedPrice = price + trade.EstimatedNotional = positionValue + + if positionValue > tradeHardMaxOrderNotionalUSDT { + return fmt.Errorf("position value %.2f exceeds hard safety cap %.2f USDT", positionValue, tradeHardMaxOrderNotionalUSDT) + } + + var equity float64 + if selectedTrader != nil { + accountInfo, err := selectedTrader.GetAccountInfo() + if err != nil { + return fmt.Errorf("failed to load trader account info: %w", err) + } + equity = toFloat(accountInfo["total_equity"]) + if equity <= 0 { + equity = toFloat(accountInfo["totalEquity"]) + } + if equity <= 0 { + return fmt.Errorf("invalid trader equity for risk validation") + } + if positionValue > equity*tradeHardMaxOrderEquityRatio { + return fmt.Errorf( + "position value %.2f USDT exceeds hard safety cap %.2f USDT (equity %.2f x %.2f)", + positionValue, + equity*tradeHardMaxOrderEquityRatio, + equity, + tradeHardMaxOrderEquityRatio, + ) + } + if positionValue >= equity*tradeLargeOrderEquityRatio { + trade.RequiresLargeOrderConfirmation = true + } + } + if positionValue >= tradeLargeOrderNotionalUSDT { + trade.RequiresLargeOrderConfirmation = true + } + + if wantStock { + if trade.Leverage < 0 { + return fmt.Errorf("leverage must be >= 0") + } + return nil + } + + cfg := store.GetDefaultStrategyConfig("zh") + if selectedTrader != nil && selectedTrader.GetStrategyConfig() != nil { + cfg = *selectedTrader.GetStrategyConfig() + } + riskControl := cfg.RiskControl + + maxLeverage := riskControl.AltcoinMaxLeverage + maxPositionValueRatio := riskControl.AltcoinMaxPositionValueRatio + if isBTCETHSymbol(trade.Symbol) { + maxLeverage = riskControl.BTCETHMaxLeverage + maxPositionValueRatio = riskControl.BTCETHMaxPositionValueRatio + } + if maxLeverage <= 0 { + maxLeverage = 5 + } + if trade.Leverage <= 0 { + return fmt.Errorf("leverage must be > 0") + } + if trade.Leverage > maxLeverage { + return fmt.Errorf("leverage exceeds configured limit (%dx > %dx)", trade.Leverage, maxLeverage) + } + + minPositionSize := riskControl.MinPositionSize + if minPositionSize <= 0 { + minPositionSize = 12 + } + if positionValue < minPositionSize { + return fmt.Errorf("position value %.2f USDT is below configured minimum %.2f USDT", positionValue, minPositionSize) + } + + if maxPositionValueRatio <= 0 { + if isBTCETHSymbol(trade.Symbol) { + maxPositionValueRatio = 5.0 + } else { + maxPositionValueRatio = 1.0 + } + } + maxPositionValue := equity * maxPositionValueRatio + if positionValue > maxPositionValue { + return fmt.Errorf( + "position value %.2f USDT exceeds configured limit %.2f USDT (equity %.2f x %.2f)", + positionValue, + maxPositionValue, + equity, + maxPositionValueRatio, + ) + } + return nil +} + +func isBTCETHSymbol(symbol string) bool { + symbol = strings.ToUpper(strings.TrimSpace(symbol)) + return strings.HasPrefix(symbol, "BTC") || strings.HasPrefix(symbol, "ETH") +} + // formatTradeConfirmation creates a confirmation message for a pending trade. func formatTradeConfirmation(trade *TradeAction, lang string) string { actionNames := map[string]string{ @@ -260,6 +403,13 @@ func formatTradeConfirmation(trade *TradeAction, lang string) string { if trade.Leverage > 0 { msg += fmt.Sprintf("杠杆: %dx\n", trade.Leverage) } + if trade.EstimatedNotional > 0 { + msg += fmt.Sprintf("估算仓位价值: %.2f USDT\n", trade.EstimatedNotional) + } + if trade.RequiresLargeOrderConfirmation { + msg += fmt.Sprintf("\n⚠️ 该订单已触发大额风控,请发送 `"+tradeLargeOrderConfirmCommandZH+"` 执行交易,或忽略取消。", trade.ID) + return msg + } msg += fmt.Sprintf("\n发送 `确认 %s` 执行交易,或忽略取消。", trade.ID) return msg } @@ -273,6 +423,13 @@ func formatTradeConfirmation(trade *TradeAction, lang string) string { if trade.Leverage > 0 { msg += fmt.Sprintf("Leverage: %dx\n", trade.Leverage) } + if trade.EstimatedNotional > 0 { + msg += fmt.Sprintf("Estimated notional: %.2f USDT\n", trade.EstimatedNotional) + } + if trade.RequiresLargeOrderConfirmation { + msg += fmt.Sprintf("\n⚠️ This order triggered high-risk protection. Send `"+tradeLargeOrderConfirmCommandEN+"` to execute, or ignore to cancel.", trade.ID) + return msg + } msg += fmt.Sprintf("\nSend `confirm %s` to execute, or ignore to cancel.", trade.ID) return msg } @@ -282,7 +439,14 @@ func (a *Agent) handleTradeConfirmation(ctx context.Context, userID int64, text, upper := strings.ToUpper(strings.TrimSpace(text)) var tradeID string - if strings.HasPrefix(upper, "确认 ") || strings.HasPrefix(upper, "CONFIRM ") { + largeConfirm := false + if strings.HasPrefix(upper, "确认大额 ") || strings.HasPrefix(upper, "CONFIRM LARGE ") { + largeConfirm = true + parts := strings.Fields(text) + if len(parts) >= 2 { + tradeID = parts[len(parts)-1] + } + } else if strings.HasPrefix(upper, "确认 ") || strings.HasPrefix(upper, "CONFIRM ") { parts := strings.Fields(text) if len(parts) >= 2 { tradeID = parts[1] @@ -304,6 +468,12 @@ func (a *Agent) handleTradeConfirmation(ctx context.Context, userID int64, text, } return "❌ Trade expired or not found.", true } + if trade.RequiresLargeOrderConfirmation && !largeConfirm { + if lang == "zh" { + return fmt.Sprintf("⚠️ 这是一笔大额订单,请发送 `"+tradeLargeOrderConfirmCommandZH+"` 继续执行。", trade.ID), true + } + return fmt.Sprintf("⚠️ This is a high-risk order. Send `"+tradeLargeOrderConfirmCommandEN+"` to continue.", trade.ID), true + } a.pending.Remove(tradeID) trade.Status = "confirmed" diff --git a/agent/trader_scope_test.go b/agent/trader_scope_test.go new file mode 100644 index 00000000..24c6730d --- /dev/null +++ b/agent/trader_scope_test.go @@ -0,0 +1,2027 @@ +package agent + +import ( + "context" + "encoding/json" + "log/slog" + "path/filepath" + "strings" + "testing" + "time" + + "nofx/mcp" + "nofx/store" +) + +type staticAIClient struct { + response string + lastRequest *mcp.Request +} + +func (c *staticAIClient) SetAPIKey(apiKey string, customURL string, customModel string) {} +func (c *staticAIClient) SetTimeout(timeout time.Duration) {} +func (c *staticAIClient) CallWithMessages(systemPrompt, userPrompt string) (string, error) { + return c.response, nil +} +func (c *staticAIClient) CallWithRequest(req *mcp.Request) (string, error) { + c.lastRequest = req + return c.response, nil +} +func (c *staticAIClient) CallWithRequestStream(req *mcp.Request, onChunk func(string)) (string, error) { + c.lastRequest = req + if onChunk != nil { + onChunk(c.response) + } + return c.response, nil +} +func (c *staticAIClient) CallWithRequestFull(req *mcp.Request) (*mcp.LLMResponse, error) { + c.lastRequest = req + return &mcp.LLMResponse{Content: c.response}, nil +} + +func TestClassifyWorkflowTaskTreatsTraderEditAsManualPanelUpdate(t *testing.T) { + task, ok := classifyWorkflowTask("帮我把交易员小爱换策略") + if !ok { + t.Fatal("expected trader binding edit to classify") + } + if task.Skill != "trader_management" || task.Action != "update_bindings" { + t.Fatalf("unexpected task: %+v", task) + } + + task, ok = classifyWorkflowTask("帮我把交易员小爱扫描间隔改成10分钟") + if !ok { + t.Fatal("expected trader manual-panel edit to classify") + } + if task.Skill != "trader_management" || task.Action != "update_bindings" { + t.Fatalf("unexpected trader update task: %+v", task) + } +} + +func TestGetDecisionsToolReturnsRecentTraderDecisionEvidence(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "decision-evidence.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + traderCfg := &store.Trader{ + ID: "trader-claw402", + UserID: "default", + Name: "claw402", + AIModelID: "model-1", + ExchangeID: "exchange-1", + InitialBalance: 6.21, + ScanIntervalMinutes: 3, + IsRunning: true, + } + if err := st.Trader().Create(traderCfg); err != nil { + t.Fatalf("seed trader: %v", err) + } + if err := st.Decision().LogDecision(&store.DecisionRecord{ + TraderID: traderCfg.ID, + CycleNumber: 150, + Timestamp: time.Now().Add(-3 * time.Minute), + Success: true, + AIRequestDurationMs: 12095, + CandidateCoins: []string{"BTCUSDT"}, + ExecutionLog: []string{"AI call duration: 12095 ms", "✓ BTCUSDT wait succeeded"}, + Decisions: []store.DecisionAction{{ + Symbol: "BTCUSDT", + Action: "wait", + Success: true, + }}, + }); err != nil { + t.Fatalf("seed wait decision: %v", err) + } + if err := st.Decision().LogDecision(&store.DecisionRecord{ + TraderID: traderCfg.ID, + CycleNumber: 151, + Timestamp: time.Now(), + Success: false, + ErrorMessage: "Failed to get AI decision: failed to parse AI response: decision validation failed: decision #1 validation failed: BTCUSDT opening amount too small (28.00 USDT), must be ≥60.00 USDT", + AIRequestDurationMs: 25878, + CandidateCoins: []string{"BTCUSDT"}, + ExecutionLog: []string{"AI call duration: 25878 ms"}, + DecisionJSON: `[{"symbol":"BTCUSDT","action":"open_short","position_size_usd":28}]`, + }); err != nil { + t.Fatalf("seed rejected decision: %v", err) + } + + raw := a.toolGetDecisions("default", `{"trader_name":"claw402","limit":2}`) + for _, want := range []string{"claw402", "BTCUSDT", "wait", "wait succeeded", "opening amount too small", "must be ≥60.00 USDT"} { + if !strings.Contains(raw, want) { + t.Fatalf("expected decision evidence %q in tool response, got: %s", want, raw) + } + } +} + +func TestTraderDiagnosisReadsDecisionsInsteadOfAskingUserForScreenshot(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "trader-diagnosis-decisions.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + traderCfg := &store.Trader{ + ID: "trader-claw402", + UserID: "default", + Name: "claw402", + AIModelID: "model-1", + ExchangeID: "exchange-1", + InitialBalance: 6.21, + ScanIntervalMinutes: 3, + IsRunning: true, + } + if err := st.Trader().Create(traderCfg); err != nil { + t.Fatalf("seed trader: %v", err) + } + if err := st.Decision().LogDecision(&store.DecisionRecord{ + TraderID: traderCfg.ID, + CycleNumber: 1, + Timestamp: time.Now(), + Success: true, + AIRequestDurationMs: 13249, + CandidateCoins: []string{"BTCUSDT"}, + ExecutionLog: []string{"AI call duration: 13249 ms", "✓ BTCUSDT wait succeeded"}, + Decisions: []store.DecisionAction{{ + Symbol: "BTCUSDT", + Action: "wait", + Success: true, + }}, + }); err != nil { + t.Fatalf("seed decision: %v", err) + } + + reply := a.handleTraderDiagnosisSkill("default", "zh", "为什么我的claw402交易员一直不开单呢") + for _, want := range []string{"claw402 是运行的", "主动选择等待", "入场标准", "该怎么办"} { + if !strings.Contains(reply, want) { + t.Fatalf("expected diagnosis to include %q, got: %s", want, reply) + } + } + for _, unexpected := range []string{"截图", "自己点", "不能直接帮你查", "诊断证据包", "AI 调用耗时", "status 402", "404", "EOF", "订阅"} { + if strings.Contains(reply, unexpected) { + t.Fatalf("diagnosis should not ask user to self-serve %q, got: %s", unexpected, reply) + } + } +} + +func TestTraderDiagnosisAmountTooSmallUsesUserFacingCauseAndAction(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "trader-diagnosis-amount-too-small.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + traderCfg := &store.Trader{ + ID: "trader-claw402", + UserID: "default", + Name: "claw402", + AIModelID: "model-1", + ExchangeID: "exchange-1", + InitialBalance: 6.21, + ScanIntervalMinutes: 3, + IsRunning: true, + } + if err := st.Trader().Create(traderCfg); err != nil { + t.Fatalf("seed trader: %v", err) + } + if err := st.Decision().LogDecision(&store.DecisionRecord{ + TraderID: traderCfg.ID, + CycleNumber: 2, + Timestamp: time.Now(), + Success: false, + ErrorMessage: "Failed to get AI decision: failed to parse AI response: decision validation failed: decision #1 validation failed: BTCUSDT opening amount too small (28.00 USDT), must be ≥60.00 USDT", + AIRequestDurationMs: 25878, + CandidateCoins: []string{"BTCUSDT"}, + ExecutionLog: []string{"AI call duration: 25878 ms"}, + DecisionJSON: `[{"symbol":"BTCUSDT","action":"open_short","position_size_usd":28}]`, + }); err != nil { + t.Fatalf("seed decision: %v", err) + } + + reply := a.handleTraderDiagnosisSkill("default", "zh", "为什么我的claw402交易员一直不开单呢") + for _, want := range []string{"不是没运行", "账户资金太小", "开仓金额约 28.00 USDT", "最小下单要求 60.00 USDT", "增加账户资金", "不能手动修改"} { + if !strings.Contains(reply, want) { + t.Fatalf("expected diagnosis to include %q, got: %s", want, reply) + } + } + for _, unexpected := range []string{"诊断证据包", "辅助异常", "status 402", "404", "EOF", "订阅", "数据服务", "position_size_usd", "AI 调用耗时"} { + if strings.Contains(reply, unexpected) { + t.Fatalf("diagnosis should stay user-facing and avoid %q, got: %s", unexpected, reply) + } + } +} + +func TestTraderDiagnosisUsesLLMToReasonOverCollectedEvidence(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "trader-diagnosis-llm.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + llm := &staticAIClient{response: "claw402 的最终原因是账户资金太小,最近想开 BTCUSDT 空单但金额低于最小下单要求。该怎么办:增加账户资金,或换更适合小资金的策略/标的。"} + a := New(nil, st, DefaultConfig(), slog.Default()) + a.SetAIClient(llm) + traderCfg := &store.Trader{ + ID: "trader-claw402", + UserID: "default", + Name: "claw402", + AIModelID: "model-1", + ExchangeID: "exchange-1", + InitialBalance: 6.21, + ScanIntervalMinutes: 3, + IsRunning: true, + } + if err := st.Trader().Create(traderCfg); err != nil { + t.Fatalf("seed trader: %v", err) + } + if err := st.Decision().LogDecision(&store.DecisionRecord{ + TraderID: traderCfg.ID, + CycleNumber: 3, + Timestamp: time.Now(), + Success: false, + ErrorMessage: "BTCUSDT opening amount too small (28.00 USDT), must be ≥60.00 USDT", + CandidateCoins: []string{"BTCUSDT"}, + DecisionJSON: `[{"symbol":"BTCUSDT","action":"open_short","position_size_usd":28}]`, + }); err != nil { + t.Fatalf("seed decision: %v", err) + } + + reply := a.handleTraderDiagnosisSkill("default", "zh", "为什么我的claw402交易员一直不开单呢") + if reply != llm.response { + t.Fatalf("expected LLM diagnosis response, got: %s", reply) + } + if llm.lastRequest == nil || len(llm.lastRequest.Messages) < 2 { + t.Fatalf("expected LLM request to be captured") + } + prompt := llm.lastRequest.Messages[1].Content + for _, want := range []string{"Evidence JSON", "claw402", "BTCUSDT", "opening amount too small", "decision_json"} { + if !strings.Contains(prompt, want) { + t.Fatalf("expected LLM evidence prompt to include %q, got: %s", want, prompt) + } + } +} + +func TestTraderDomainPrimerExplainsInternalConfigBoundary(t *testing.T) { + primer := buildSkillDomainPrimer("zh", "trader_management") + for _, want := range []string{ + "交易员是装配层", + "默认只处理绑定关系", + "应切到对应 management skill", + } { + if !strings.Contains(primer, want) { + t.Fatalf("expected primer to contain %q, got: %s", want, primer) + } + } +} + +func TestStrategyDomainPrimerKeepsSourceCountsWithinEditorBounds(t *testing.T) { + primer := buildSkillDomainPrimerForSession("zh", skillSession{ + Name: "strategy_management", + Action: "create", + Fields: map[string]string{ + "strategy_type": "ai_trading", + }, + }) + for _, want := range []string{ + "AI500/OI Top/OI Low 选币数量范围 1~10", + "没有 mixed/混合模式", + "BTC/ETH 最大杠杆 1~20", + "min_confidence 50~100", + } { + if !strings.Contains(primer, want) { + t.Fatalf("expected primer to contain %q, got: %s", want, primer) + } + } +} + +func TestStrategyConfigSchemaOnlyExposesEditorCoinSourceFields(t *testing.T) { + schema := strategyConfigSchema() + properties := schema["properties"].(map[string]any) + aiConfig := properties["ai_config"].(map[string]any) + aiProperties := aiConfig["properties"].(map[string]any) + coinSource := aiProperties["coin_source"].(map[string]any) + coinProperties := coinSource["properties"].(map[string]any) + for _, unexpected := range []string{"use_hyper_all", "use_hyper_main", "hyper_main_limit"} { + if _, ok := coinProperties[unexpected]; ok { + t.Fatalf("strategy config schema should not expose non-editor coin source field %s", unexpected) + } + } + ai500 := coinProperties["ai500_limit"].(map[string]any) + if ai500["maximum"] != 10 { + t.Fatalf("expected AI500 maximum 10, got %+v", ai500) + } +} + +func TestLoadEnabledModelOptionsUseConfigNameAsPrimaryLabel(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "trader-model-options.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + if err := st.AIModel().UpdateWithName("default", "default_deepseek", "DeepSeek AI", true, "sk-test-12345", "", "deepseek-chat"); err != nil { + t.Fatalf("seed model: %v", err) + } + + options := a.loadEnabledModelOptions("default") + if len(options) != 1 { + t.Fatalf("expected one model option, got %d", len(options)) + } + if options[0].Name != "DeepSeek AI" { + t.Fatalf("expected primary option label to stay on config name, got %q", options[0].Name) + } + if !strings.Contains(options[0].Hint, "deepseek-chat") || !strings.Contains(options[0].Hint, "deepseek") { + t.Fatalf("expected hint to retain runtime model/provider context, got %q", options[0].Hint) + } +} + +func TestHydrateCreateTraderSlotReferencesNormalizesModelIDFromVisibleName(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "trader-model-id-normalize.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + if err := st.AIModel().UpdateWithName("default", "default_deepseek", "DeepSeek AI", true, "sk-test-12345", "", "deepseek-chat"); err != nil { + t.Fatalf("seed model: %v", err) + } + + session := skillSession{ + Name: "trader_management", + Action: "create", + Fields: map[string]string{ + "model_id": "DeepSeek AI", + }, + } + a.hydrateCreateTraderSlotReferences("default", &session) + if got := fieldValue(session, "model_id"); got != "default_deepseek" { + t.Fatalf("expected visible model name in model_id slot to normalize to actual id, got %q", got) + } + if got := fieldValue(session, "model_name"); got != "DeepSeek AI" { + t.Fatalf("expected normalized model name to be preserved, got %q", got) + } +} + +func TestHydrateCreateTraderSlotReferencesNormalizesExchangeIDFromVisibleName(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "trader-exchange-id-normalize.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + exchangeID, err := st.Exchange().Create("default", "okx", "小偶", true, "api-test", "secret-test", "pass", false, "", false, "", "", "", "", "", "", 0) + if err != nil { + t.Fatalf("seed exchange: %v", err) + } + + session := skillSession{ + Name: "trader_management", + Action: "create", + Fields: map[string]string{ + "exchange_id": "小偶", + }, + } + a.hydrateCreateTraderSlotReferences("default", &session) + if got := fieldValue(session, "exchange_id"); got != exchangeID { + t.Fatalf("expected visible exchange name in exchange_id slot to normalize to actual id, got %q", got) + } + if got := fieldValue(session, "exchange_name"); got != "小偶" { + t.Fatalf("expected normalized exchange name to be preserved, got %q", got) + } +} + +func TestToolDeleteTraderRejectsRunningTrader(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "delete-running-trader.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + if err := st.Trader().Create(&store.Trader{ + ID: "trader-running", + UserID: "default", + Name: "运行中", + AIModelID: "model-1", + ExchangeID: "exchange-1", + InitialBalance: 100, + ScanIntervalMinutes: 3, + IsRunning: true, + }); err != nil { + t.Fatalf("seed trader: %v", err) + } + + resp := a.toolDeleteTrader("default", "trader-running") + if !strings.Contains(resp, "stop it before deleting") { + t.Fatalf("expected running trader delete to be rejected, got: %s", resp) + } + traders, err := st.Trader().List("default") + if err != nil { + t.Fatalf("list traders: %v", err) + } + if len(traders) != 1 { + t.Fatalf("expected running trader to remain, got %d traders", len(traders)) + } +} + +func TestBulkTraderDeleteDeletesOnlyStoppedTraders(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "bulk-delete-traders.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + for _, trader := range []*store.Trader{ + {ID: "trader-stopped", UserID: "default", Name: "已停止", AIModelID: "model-1", ExchangeID: "exchange-1", InitialBalance: 100, ScanIntervalMinutes: 3, IsRunning: false}, + {ID: "trader-running", UserID: "default", Name: "运行中", AIModelID: "model-1", ExchangeID: "exchange-1", InitialBalance: 100, ScanIntervalMinutes: 3, IsRunning: true}, + } { + if err := st.Trader().Create(trader); err != nil { + t.Fatalf("seed trader %s: %v", trader.ID, err) + } + } + + session := skillSession{ + Name: "trader_management", + Action: "delete", + Phase: "await_confirmation", + Fields: map[string]string{ + "bulk_scope": "all", + skillDAGStepField: "await_confirmation", + }, + } + resp := a.executeBulkTraderDelete("default", 99, "zh", "确认", session) + if !strings.Contains(resp, "成功删除 1 个") || !strings.Contains(resp, "运行中") { + t.Fatalf("expected stopped trader deleted and running trader skipped, got: %s", resp) + } + traders, err := st.Trader().List("default") + if err != nil { + t.Fatalf("list traders: %v", err) + } + if len(traders) != 1 || traders[0].ID != "trader-running" { + t.Fatalf("expected only running trader to remain, got: %+v", traders) + } +} + +func TestBulkTraderDeleteRequiresConfirmationBeforeDeleting(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "bulk-delete-traders-confirmation.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + if err := st.Trader().Create(&store.Trader{ + ID: "trader-stopped", + UserID: "default", + Name: "已停止", + AIModelID: "model-1", + ExchangeID: "exchange-1", + InitialBalance: 100, + ScanIntervalMinutes: 3, + IsRunning: false, + }); err != nil { + t.Fatalf("seed trader: %v", err) + } + + session := skillSession{ + Name: "trader_management", + Action: "delete", + Fields: map[string]string{ + "bulk_scope": "all", + }, + } + resp := a.executeBulkTraderDelete("default", 99, "zh", "全部删除", session) + if !strings.Contains(resp, "请回复“确认”继续") { + t.Fatalf("expected confirmation prompt, got: %s", resp) + } + traders, err := st.Trader().List("default") + if err != nil { + t.Fatalf("list traders: %v", err) + } + if len(traders) != 1 { + t.Fatalf("expected trader to remain before confirmation, got %d traders", len(traders)) + } +} + +func TestResolveTargetSelectionMatchesUniqueNameInUserText(t *testing.T) { + options := []traderSkillOption{ + {ID: "exchange-a", Name: "okx"}, + {ID: "exchange-b", Name: "为:小易"}, + {ID: "exchange-c", Name: "小偶"}, + } + resolved := resolveTargetSelection("先把 为:小易 删掉,其他 5 个先保留", options, nil) + if resolved.Ref == nil { + t.Fatal("expected target ref to resolve from user text") + } + if resolved.Ref.ID != "exchange-b" || resolved.Ref.Name != "为:小易" { + t.Fatalf("unexpected resolved target: %+v", resolved.Ref) + } +} + +func TestStrategyUpdateUsesExplicitTargetOverCurrentReference(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "strategy-explicit-target-over-current.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + userID := int64(99) + + cfg := store.GetDefaultStrategyConfig("zh") + rawCfg, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("marshal strategy config: %v", err) + } + for _, strategy := range []*store.Strategy{ + {ID: "strategy-short", UserID: "default", Name: "BTC趋势做空", ConfigVisible: true, Config: string(rawCfg)}, + {ID: "strategy-long", UserID: "default", Name: "AI500 做多策略", ConfigVisible: true, Config: string(rawCfg)}, + } { + if err := st.Strategy().Create(strategy); err != nil { + t.Fatalf("seed strategy %s: %v", strategy.ID, err) + } + } + a.saveReferenceMemory(userID, &CurrentReferences{ + Strategy: &EntityReference{ID: "strategy-short", Name: "BTC趋势做空", Source: "tool_output"}, + }, nil) + + patch := map[string]any{ + "coin_source": map[string]any{ + "source_type": "ai500", + "use_ai500": true, + "ai500_limit": 5, + }, + "custom_prompt": "AI500 强做多策略:只寻找强趋势多头机会。", + } + rawPatch, _ := json.Marshal(patch) + session := skillSession{ + Name: "strategy_management", + Action: "update_config", + Phase: "collecting", + Fields: map[string]string{strategyCreateConfigPatchField: string(rawPatch)}, + } + + reply, handled := a.handleSimpleEntitySkill( + "default", + userID, + "zh", + "我想基于AI500 做多策略来调整成更强的做多逻辑", + session, + "strategy_management", + "update_config", + a.loadStrategyOptions("default"), + ) + if !handled { + t.Fatalf("expected handler to handle request") + } + if !strings.Contains(reply, "已更新策略配置") { + t.Fatalf("expected strategy update reply, got: %s", reply) + } + + shortStrategy, err := st.Strategy().Get("default", "strategy-short") + if err != nil { + t.Fatalf("load short strategy: %v", err) + } + longStrategy, err := st.Strategy().Get("default", "strategy-long") + if err != nil { + t.Fatalf("load long strategy: %v", err) + } + if strings.Contains(shortStrategy.Config, "强做多") { + t.Fatalf("current reference strategy was incorrectly updated: %s", shortStrategy.Config) + } + if !strings.Contains(longStrategy.Config, "强做多") { + t.Fatalf("explicitly named strategy was not updated: %s", longStrategy.Config) + } +} + +func TestStrategyUpdateDoesNotInferTargetFromCurrentReference(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "strategy-no-current-reference-fallback.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + userID := int64(100) + + cfg := store.GetDefaultStrategyConfig("zh") + rawCfg, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("marshal strategy config: %v", err) + } + if err := st.Strategy().Create(&store.Strategy{ + ID: "strategy-short", + UserID: "default", + Name: "BTC趋势做空", + ConfigVisible: true, + Config: string(rawCfg), + }); err != nil { + t.Fatalf("seed strategy: %v", err) + } + a.saveReferenceMemory(userID, &CurrentReferences{ + Strategy: &EntityReference{ID: "strategy-short", Name: "BTC趋势做空", Source: "tool_output"}, + }, nil) + + patch := map[string]any{"custom_prompt": "不应被写入"} + rawPatch, _ := json.Marshal(patch) + session := skillSession{ + Name: "strategy_management", + Action: "update_config", + Phase: "collecting", + Fields: map[string]string{strategyCreateConfigPatchField: string(rawPatch)}, + } + + reply, handled := a.handleSimpleEntitySkill( + "default", + userID, + "zh", + "帮我把策略改强一点", + session, + "strategy_management", + "update_config", + a.loadStrategyOptions("default"), + ) + if !handled { + t.Fatalf("expected handler to ask for target") + } + if !strings.Contains(reply, "确定目标对象") && !strings.Contains(reply, "明确要操作的是哪一个对象") { + t.Fatalf("expected target clarification, got: %s", reply) + } + strategy, err := st.Strategy().Get("default", "strategy-short") + if err != nil { + t.Fatalf("load strategy: %v", err) + } + if strings.Contains(strategy.Config, "不应被写入") { + t.Fatalf("strategy was incorrectly updated through current reference fallback: %s", strategy.Config) + } +} + +func TestBulkStrategyDeleteRequiresConfirmationBeforeDeleting(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "bulk-delete-strategies-confirmation.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + cfg := store.GetDefaultStrategyConfig("zh") + rawCfg, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("marshal strategy config: %v", err) + } + if err := st.Strategy().Create(&store.Strategy{ + ID: "strategy-custom", + UserID: "default", + Name: "自定义策略", + ConfigVisible: true, + Config: string(rawCfg), + }); err != nil { + t.Fatalf("seed strategy: %v", err) + } + + session := skillSession{ + Name: "strategy_management", + Action: "delete", + Fields: map[string]string{ + "bulk_scope": "all", + }, + } + resp := a.executeStrategyManagementAction("default", 99, "zh", "全部删除", session) + if !strings.Contains(resp, "请回复“确认”继续") { + t.Fatalf("expected confirmation prompt, got: %s", resp) + } + strategies, err := st.Strategy().List("default") + if err != nil { + t.Fatalf("list strategies: %v", err) + } + found := false + for _, strategy := range strategies { + if strategy.ID == "strategy-custom" { + found = true + } + } + if !found { + t.Fatal("expected strategy to remain before confirmation") + } +} + +func TestEnsureLiveTargetReferenceFallsBackFromStaleIDToName(t *testing.T) { + session := skillSession{ + TargetRef: &EntityReference{ + ID: "stale-id", + Name: "小易", + }, + } + options := []traderSkillOption{ + {ID: "exchange-a", Name: "okx"}, + {ID: "exchange-b", Name: "为:小易"}, + } + if !ensureLiveTargetReference(&session, options) { + t.Fatal("expected stale id with matching name to resolve") + } + if session.TargetRef == nil || session.TargetRef.ID != "exchange-b" || session.TargetRef.Name != "为:小易" { + t.Fatalf("unexpected target ref after live check: %+v", session.TargetRef) + } +} + +func TestBuildTraderCreateMissingPromptListsAllMissingSlots(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "trader-create-missing-prompt.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + if err := st.AIModel().UpdateWithName("default", "default_deepseek", "DeepSeek AI", true, "sk-test-12345", "", "deepseek-chat"); err != nil { + t.Fatalf("seed model: %v", err) + } + exchangeID, err := st.Exchange().Create("default", "okx", "OKX 主账户", true, "api-test", "secret-test", "pass", false, "", false, "", "", "", "", "", "", 0) + if err != nil { + t.Fatalf("seed exchange: %v", err) + } + _ = exchangeID + cfg := store.GetDefaultStrategyConfig("zh") + rawCfg, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("marshal strategy config: %v", err) + } + if err := st.Strategy().Create(&store.Strategy{ + ID: "strategy-ai500", + UserID: "default", + Name: "AI500稳重策略", + Description: "test", + IsPublic: false, + ConfigVisible: true, + Config: string(rawCfg), + }); err != nil { + t.Fatalf("seed strategy: %v", err) + } + + session := skillSession{ + Name: "trader_management", + Action: "create", + Phase: "collecting", + Fields: map[string]string{}, + } + prompt := a.buildTraderCreateMissingPrompt("default", "zh", session, a.buildTraderCreateConversationResources("default", session)) + for _, want := range []string{"名称", "交易所", "模型", "策略"} { + if !strings.Contains(prompt, want) { + t.Fatalf("expected missing prompt to include %q, got: %s", want, prompt) + } + } + for _, want := range []string{"现有交易所", "现有模型", "现有策略"} { + if !strings.Contains(prompt, want) { + t.Fatalf("expected missing prompt to include options line %q, got: %s", want, prompt) + } + } +} + +func TestTraderCreateRequiresResolvedResourceIDs(t *testing.T) { + session := skillSession{ + Name: "trader_management", + Action: "create", + Fields: map[string]string{ + "name": "凯茵", + "exchange_name": "Binance", + "model_name": "deepseek", + "strategy_name": "BTC趋势做空", + }, + } + + missing := missingFieldKeysForSkillSession(session) + for _, want := range []string{"exchange_name", "model_name", "strategy_name"} { + if !containsString(missing, want) { + t.Fatalf("expected unresolved %s to remain missing, got %v", want, missing) + } + } + + active := ActiveSkillSession{ + SkillName: "trader_management", + ActionName: "create", + CollectedFields: map[string]any{ + "name": "凯茵", + "exchange_name": "Binance", + "model_name": "deepseek", + "strategy_name": "BTC趋势做空", + }, + } + activeMissing := missingRequiredFields(active) + for _, want := range []string{"exchange", "model", "strategy"} { + if !containsString(activeMissing, want) { + t.Fatalf("expected unresolved active slot %s to remain missing, got %v", want, activeMissing) + } + } +} + +func TestStrategyCreateUsesConfigPatch(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "strategy-create-config-patch.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + patch := map[string]any{ + "strategy_type": "ai_trading", + "coin_source": map[string]any{ + "source_type": "static", + "static_coins": []any{"BTCUSDT"}, + "use_ai500": false, + "use_oi_low": true, + "oi_low_limit": 1, + }, + "risk_control": map[string]any{ + "max_positions": 1, + "btc_eth_max_leverage": 5, + "altcoin_max_leverage": 5, + "min_confidence": 80, + "min_risk_reward_ratio": 3, + }, + "indicators": map[string]any{ + "klines": map[string]any{ + "primary_timeframe": "5m", + "selected_timeframes": []any{"5m", "15m"}, + }, + }, + "prompt_sections": map[string]any{ + "trading_frequency": "每天最多 2-4 笔,避免过度交易。", + "entry_standards": "只在 BTC 下跌趋势确认时考虑做空,禁止把做多作为主方向。", + }, + "custom_prompt": "BTC 趋势做空策略:仅关注 BTCUSDT,趋势向下且反弹受阻时才考虑开空。", + } + rawPatch, _ := json.Marshal(patch) + session := skillSession{ + Name: "strategy_management", + Action: "create", + Fields: map[string]string{ + "name": "BTC趋势做空", + strategyCreateConfigPatchField: string(rawPatch), + }, + } + + reply := a.handleStrategyCreateSkill("default", 1, "zh", "确认创建", session) + if !strings.Contains(reply, "已创建策略") { + t.Fatalf("expected created reply, got: %s", reply) + } + + strategies, err := st.Strategy().List("default") + if err != nil { + t.Fatalf("list strategies: %v", err) + } + var created *store.Strategy + for _, strategy := range strategies { + if strategy.Name == "BTC趋势做空" { + created = strategy + break + } + } + if created == nil { + t.Fatalf("expected strategy to be created") + } + + var cfg store.StrategyConfig + if err := json.Unmarshal([]byte(created.Config), &cfg); err != nil { + t.Fatalf("unmarshal config: %v", err) + } + if cfg.CoinSource.SourceType != "static" || len(cfg.CoinSource.StaticCoins) != 1 || cfg.CoinSource.StaticCoins[0] != "BTCUSDT" { + t.Fatalf("expected BTC static coin source, got %+v", cfg.CoinSource) + } + if cfg.CoinSource.UseAI500 { + t.Fatalf("expected AI500 disabled for explicit BTC strategy") + } + if cfg.CoinSource.UseOILow { + t.Fatalf("expected OI low disabled when source_type is static, got %+v", cfg.CoinSource) + } + if cfg.RiskControl.MaxPositions != 3 || cfg.RiskControl.MinConfidence != 80 { + t.Fatalf("expected risk patch to apply, got %+v", cfg.RiskControl) + } + if !strings.Contains(cfg.CustomPrompt, "BTC 趋势做空") || !strings.Contains(cfg.PromptSections.EntryStandards, "做空") { + t.Fatalf("expected prompt patch to apply, got custom=%q entry=%q", cfg.CustomPrompt, cfg.PromptSections.EntryStandards) + } +} + +func TestAIStrategySystemEnforcedFieldsAreDisplayedButNotEditable(t *testing.T) { + cfg := store.GetDefaultStrategyConfig("zh") + session := skillSession{ + Name: "strategy_management", + Action: "create", + Fields: map[string]string{ + "name": "我的AI策略", + }, + } + reply := formatStrategyCreateFinalConfirmation("zh", session, cfg) + for _, want := range []string{"最大持仓数(System enforced)", "BTC/ETH 单币仓位上限(System enforced)", "最大保证金使用率(System enforced)", "最小开仓金额(System enforced)"} { + if !strings.Contains(reply, want) { + t.Fatalf("expected final summary to display %q, got: %s", want, reply) + } + } + + resp := applyStrategyConfigPatch(&cfg, "max_margin_usage", "0.5") + if resp == nil || !strings.Contains(resp.Error(), "System enforced") { + t.Fatalf("expected system enforced edit to be rejected, got: %v", resp) + } +} + +func TestStrategyCreateNaturalLanguageDoesNotBypassTemplateType(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "strategy-create-draft-two-turn.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + active := ActiveSkillSession{ + SessionID: "as_test", + UserID: 1, + SkillName: "strategy_management", + ActionName: "create", + Goal: "真的去创建一个趋势策略,交易BTC和ETH,15m,杠杆 5 倍", + CollectedFields: map[string]any{ + "name": "BTCETH_15m_趋势", + }, + LocalHistory: []chatMessage{ + {Role: "user", Content: "真的去创建一个趋势策略,交易BTC和ETH,15m,杠杆 5 倍"}, + {Role: "assistant", Content: "现在只差一个名称。"}, + {Role: "user", Content: "BTCETH_15m_趋势"}, + }, + } + session := activeToLegacySkillSession(active) + reply := a.handleStrategyCreateSkill("default", 1, "zh", "BTCETH_15m_趋势", session) + if !strings.Contains(reply, "先选择策略类型") { + t.Fatalf("expected strategy type question instead of legacy natural-language parsing, got: %s", reply) + } + + strategies, err := st.Strategy().List("default") + if err != nil { + t.Fatalf("list strategies: %v", err) + } + if len(strategies) != 0 { + t.Fatalf("expected no strategy before template is complete, got %d", len(strategies)) + } +} + +func TestStrategyCreateAsksTypeBeforeUsingDefaultTemplateType(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "strategy-create-ask-type.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + session := skillSession{ + Name: "strategy_management", + Action: "create", + Fields: map[string]string{ + "name": "我的策略", + }, + } + + reply := a.handleStrategyCreateSkill("default", 1, "zh", "我的策略", session) + if !strings.Contains(reply, "先选择策略类型") || strings.Contains(reply, "交易所") { + t.Fatalf("expected strategy type question without exchange binding, got: %s", reply) + } + strategies, err := st.Strategy().List("default") + if err != nil { + t.Fatalf("list strategies: %v", err) + } + for _, strategy := range strategies { + if strategy.Name == "我的策略" { + t.Fatalf("strategy should not be created before type is confirmed") + } + } +} + +func TestStrategyCreateConfirmationStillRequiresType(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "strategy-create-confirm-no-type.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + session := skillSession{ + Name: "strategy_management", + Action: "create", + Fields: map[string]string{ + "name": "我的策略", + }, + } + + reply := a.handleStrategyCreateSkill("default", 1, "zh", "确认创建", session) + if !strings.Contains(reply, "先选择策略类型") { + t.Fatalf("expected type question before create, got: %s", reply) + } + strategies, err := st.Strategy().List("default") + if err != nil { + t.Fatalf("list strategies: %v", err) + } + for _, strategy := range strategies { + if strategy.Name == "我的策略" { + t.Fatalf("strategy should not be created before type is known") + } + } +} + +func TestStrategyCreateStandaloneNameCanContainStrategyWord(t *testing.T) { + active := ActiveSkillSession{ + SessionID: "as_test", + UserID: 1, + SkillName: "strategy_management", + ActionName: "create", + Goal: "创建一个趋势策略,交易BTC和ETH,15m,杠杆 5 倍", + CollectedFields: map[string]any{}, + LocalHistory: []chatMessage{ + {Role: "user", Content: "创建一个趋势策略,交易BTC和ETH,15m,杠杆 5 倍"}, + {Role: "assistant", Content: "现在只差一个名称。"}, + {Role: "user", Content: "趋势策略A"}, + }, + } + + session := activeToLegacySkillSession(active) + if got := fieldValue(session, "name"); got != "趋势策略A" { + t.Fatalf("expected standalone strategy name to be preserved, got %q", got) + } +} + +func TestStrategyCreateProposesGridDefaultsBeforeCreate(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "strategy-grid-create-draft.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + session := skillSession{ + Name: "strategy_management", + Action: "create", + Fields: map[string]string{ + "name": "我的网格策略", + "strategy_type": "grid_trading", + }, + } + + reply := a.handleStrategyCreateSkill("default", 1, "zh", "grid_trading", session) + if !strings.Contains(reply, "还缺") || !strings.Contains(reply, "交易对") || !strings.Contains(reply, "网格数量") { + t.Fatalf("expected grid template missing-fields prompt, got: %s", reply) + } + strategies, err := st.Strategy().List("default") + if err != nil { + t.Fatalf("list strategies: %v", err) + } + for _, strategy := range strategies { + if strategy.Name == "我的网格策略" { + t.Fatalf("strategy should not be created before grid config is ready") + } + } +} + +func TestStrategyCreateSwitchingTypeDropsPreviousTemplateFields(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "strategy-create-switch-type.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + aiPatch := map[string]any{ + "strategy_type": "ai_trading", + "ai_config": map[string]any{ + "coin_source": map[string]any{"source_type": "ai500"}, + "risk_control": map[string]any{ + "min_confidence": 80, + "min_risk_reward_ratio": 3, + }, + }, + } + rawPatch, _ := json.Marshal(aiPatch) + session := skillSession{ + Name: "strategy_management", + Action: "create", + Phase: "collecting", + Fields: map[string]string{ + "name": "我的网格大大", + "strategy_type": "ai_trading", + strategyCreateConfigPatchField: string(rawPatch), + }, + } + + reply := a.handleStrategyCreateSkill("default", 1, "zh", "算了选网格策略吧", session) + if !strings.Contains(reply, "还缺") || !strings.Contains(reply, "交易对") { + t.Fatalf("expected grid missing fields after type switch, got: %s", reply) + } + if strings.Contains(reply, "AI500") || strings.Contains(reply, "置信度") { + t.Fatalf("type switch should not reuse AI fields or default BTC summary, got: %s", reply) + } + strategies, err := st.Strategy().List("default") + if err != nil { + t.Fatalf("list strategies: %v", err) + } + for _, strategy := range strategies { + if strategy.Name == "我的网格大大" { + t.Fatalf("strategy should not be created after switching type with missing grid fields") + } + } +} + +func TestActiveStrategyCreateFilterIsolatesTemplateOnTypeSwitch(t *testing.T) { + session := ActiveSkillSession{ + SkillName: "strategy_management", + ActionName: "create", + CollectedFields: map[string]any{ + "name": "我的网格大大", + "strategy_type": "ai_trading", + strategyCreateConfigPatchField: map[string]any{ + "strategy_type": "ai_trading", + "ai_config": map[string]any{ + "coin_source": map[string]any{"source_type": "ai500"}, + }, + }, + }, + } + filtered := filterExtractedDataForActiveSession(session, map[string]any{ + "strategy_type": "grid_trading", + strategyCreateConfigPatchField: map[string]any{ + "strategy_type": "grid_trading", + "grid_config": map[string]any{"symbol": "ETHUSDT"}, + "ai_config": map[string]any{"coin_source": map[string]any{"source_type": "ai500"}}, + }, + }, "zh") + mergeExtractedData(&session, filtered) + if got := session.CollectedFields["strategy_type"]; got != "grid_trading" { + t.Fatalf("expected switched strategy type, got %+v", session.CollectedFields) + } + if _, ok := session.CollectedFields["source_type"]; ok { + t.Fatalf("expected AI-only flat fields to be dropped, got %+v", session.CollectedFields) + } + patch := session.CollectedFields[strategyCreateConfigPatchField].(map[string]any) + if _, ok := patch["ai_config"]; ok { + t.Fatalf("expected ai_config to be removed from grid patch, got %+v", patch) + } + if _, ok := patch["grid_config"]; !ok { + t.Fatalf("expected grid_config to remain, got %+v", patch) + } +} + +func TestStrategyCreateConfirmationFillsMissingGridDefaults(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "strategy-grid-create-confirm-defaults.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + session := skillSession{ + Name: "strategy_management", + Action: "create", + Fields: map[string]string{ + "name": "餐巾纸", + "strategy_type": "grid_trading", + "symbol": "BTCUSDT", + "awaiting_final_confirmation": "true", + }, + } + + reply := a.handleStrategyCreateSkill("default", 1, "zh", "好的,就这样", session) + if !strings.Contains(reply, "还缺") || strings.Contains(reply, "已创建策略") { + t.Fatalf("expected missing grid fields instead of default create, got: %s", reply) + } + strategies, err := st.Strategy().List("default") + if err != nil { + t.Fatalf("list strategies: %v", err) + } + for _, strategy := range strategies { + if strategy.Name == "餐巾纸" { + t.Fatalf("strategy should not be created before grid template is complete") + } + } +} + +func TestStrategyCreateGridDraftSummaryDoesNotMentionAIFields(t *testing.T) { + reply := formatStrategyCreateDraftSummary("zh", "我的网格策略", "grid_trading", nil, nil) + for _, unexpected := range []string{"选币来源", "最大持仓", "置信度", "盈亏比", "多周期"} { + if strings.Contains(reply, unexpected) { + t.Fatalf("grid draft summary should not mention AI-only field %q: %s", unexpected, reply) + } + } + for _, expected := range []string{"网格策略", "交易对", "网格数量", "总投入", "杠杆", "价格区间"} { + if !strings.Contains(reply, expected) { + t.Fatalf("grid draft summary should mention %q, got: %s", expected, reply) + } + } +} + +func TestAllowedStrategyCreateFieldsUseConfigPatchOnly(t *testing.T) { + gridSession := skillSession{ + Name: "strategy_management", + Action: "create", + Fields: map[string]string{ + "strategy_type": "grid_trading", + }, + } + gridSpecs := allowedFieldSpecsForSkillSession(gridSession, "zh") + gridKeys := make(map[string]bool, len(gridSpecs)) + for _, spec := range gridSpecs { + gridKeys[spec.Key] = true + } + for _, expected := range []string{"strategy_type", "name", strategyCreateConfigPatchField, "awaiting_final_confirmation"} { + if !gridKeys[expected] { + t.Fatalf("expected grid field %q in specs", expected) + } + } + for _, unexpected := range []string{"symbol", "grid_count", "total_investment", "source_type", "selected_timeframes", "min_confidence", "min_risk_reward_ratio"} { + if gridKeys[unexpected] { + t.Fatalf("strategy create specs should not expose template field %q outside config_patch", unexpected) + } + } +} + +func TestStrategyCreateReadyConfigRequiresFinalConfirmation(t *testing.T) { + patch := map[string]any{ + "strategy_type": "grid_trading", + "grid_config": map[string]any{ + "symbol": "BTCUSDT", + "grid_count": 20, + "total_investment": 200, + "leverage": 2, + "use_atr_bounds": true, + "atr_multiplier": 2, + "distribution": "uniform", + "max_drawdown_pct": 15, + "stop_loss_pct": 8, + "daily_loss_limit_pct": 6, + "use_maker_only": true, + "enable_direction_adjust": false, + }, + } + rawPatch, _ := json.Marshal(patch) + session := ActiveSkillSession{ + SkillName: "strategy_management", + ActionName: "create", + CollectedFields: map[string]any{ + "name": "小白策略", + "strategy_type": "grid_trading", + strategyCreateConfigPatchField: string(rawPatch), + }, + } + + reply, blocked := guardStrategyCreateBeforeFinalConfirmation("zh", session) + if !blocked { + t.Fatalf("expected ready strategy create config to require final confirmation") + } + if !strings.Contains(reply, "确认后我再创建") || !strings.Contains(reply, "BTCUSDT") || !strings.Contains(reply, "20") { + t.Fatalf("expected final confirmation summary, got: %s", reply) + } + + session.CollectedFields["awaiting_final_confirmation"] = true + if _, blocked := guardStrategyCreateBeforeFinalConfirmation("zh", session); !blocked { + t.Fatalf("same-turn awaiting flag without prior assistant confirmation should still be blocked") + } + session.LocalHistory = append(session.LocalHistory, chatMessage{Role: "assistant", Content: reply}) + if _, blocked := guardStrategyCreateBeforeFinalConfirmation("zh", session); blocked { + t.Fatalf("already-confirmable session should not be blocked") + } +} + +func TestStrategyCreateConfirmationForcesSynchronousExecutionRoute(t *testing.T) { + patch := map[string]any{ + "strategy_type": "ai_trading", + "ai_config": map[string]any{ + "coin_source": map[string]any{ + "source_type": "ai500", + "use_ai500": true, + "ai500_limit": 5, + }, + "indicators": map[string]any{ + "klines": map[string]any{ + "primary_timeframe": "1m", + "selected_timeframes": []any{"1m", "5m"}, + }, + }, + "risk_control": map[string]any{ + "btc_eth_max_leverage": 3, + "altcoin_max_leverage": 2, + "min_confidence": 70, + "min_risk_reward_ratio": 1.5, + }, + "prompt_sections": map[string]any{ + "trading_frequency": "高频交易但避免过度交易。", + "entry_standards": "只在短周期趋势明确且风险收益合理时开仓。", + }, + }, + } + rawPatch, _ := json.Marshal(patch) + session := ActiveSkillSession{ + SkillName: "strategy_management", + ActionName: "create", + CollectedFields: map[string]any{ + "name": "AI500高频交易", + "strategy_type": "ai_trading", + strategyCreateConfigPatchField: string(rawPatch), + }, + LocalHistory: []chatMessage{ + {Role: "assistant", Content: "请确认是否按以上设置创建?如果没问题,我就执行创建。"}, + }, + } + for _, confirmation := range []string{"确认创建", "可以", "好的", "没问题", "ok"} { + t.Run(confirmation, func(t *testing.T) { + sessionCopy := session + sessionCopy.CollectedFields = map[string]any{ + "name": "AI500高频交易", + "strategy_type": "ai_trading", + strategyCreateConfigPatchField: string(rawPatch), + } + decision := activeSessionStepDecision{ + Route: "ask_user", + Reply: "好的,正在为你创建“AI500高频交易”策略……", + } + + if !maybeForceStrategyCreateExecutionOnConfirmation("zh", confirmation, &sessionCopy, &decision) { + t.Fatalf("expected confirmation %q to force execute route", confirmation) + } + if decision.Route != "execute_skill" || decision.Reply != "" { + t.Fatalf("expected synchronous execute route with empty reply, got %+v", decision) + } + if !activeFieldBool(sessionCopy.CollectedFields["awaiting_final_confirmation"]) { + t.Fatalf("expected awaiting_final_confirmation to be set before execution") + } + }) + } +} + +func TestStrategyCreateConfirmationForcesExecutionWithoutPriorPromptPhrase(t *testing.T) { + patch := map[string]any{ + "strategy_type": "ai_trading", + "ai_config": map[string]any{ + "coin_source": map[string]any{ + "source_type": "ai500", + "use_ai500": true, + "ai500_limit": 5, + }, + "indicators": map[string]any{ + "klines": map[string]any{ + "primary_timeframe": "3m", + "selected_timeframes": []any{"3m", "5m", "15m"}, + }, + }, + "risk_control": map[string]any{ + "btc_eth_max_leverage": 3, + "altcoin_max_leverage": 2, + "min_confidence": 75, + "min_risk_reward_ratio": 1.5, + }, + "prompt_sections": map[string]any{ + "trading_frequency": "高频但避免过度交易。", + "entry_standards": "趋势明确、成交量配合、风险收益合理才开仓。", + }, + }, + } + rawPatch, _ := json.Marshal(patch) + session := ActiveSkillSession{ + SkillName: "strategy_management", + ActionName: "create", + CollectedFields: map[string]any{ + "name": "高频稳健AI500", + "strategy_type": "ai_trading", + strategyCreateConfigPatchField: string(rawPatch), + }, + LocalHistory: []chatMessage{ + {Role: "assistant", Content: "这是我建议的一版配置。"}, + }, + } + decision := activeSessionStepDecision{ + Route: "ask_user", + Reply: "好的,马上为你创建“高频稳健AI500”策略。", + } + if !maybeForceStrategyCreateExecutionOnConfirmation("zh", "确认创建", &session, &decision) { + t.Fatalf("expected ready strategy confirmation to force execute even without prior prompt phrase") + } + if decision.Route != "execute_skill" || decision.Reply != "" { + t.Fatalf("expected execute route, got %+v", decision) + } +} + +func TestUnifiedPlannedAgentCannotStealActiveStrategyCreateConfirmation(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "strategy-create-planner-steal.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + a.SetAIClient(&staticAIClient{response: `{"route":"ask_user","reply":"好的,马上为你创建策略。","extracted_data":{}}`}) + + patch := map[string]any{ + "strategy_type": "ai_trading", + "ai_config": map[string]any{ + "coin_source": map[string]any{ + "source_type": "ai500", + "use_ai500": true, + "ai500_limit": 5, + }, + "indicators": map[string]any{ + "klines": map[string]any{ + "primary_timeframe": "5m", + "selected_timeframes": []any{"1m", "5m", "15m"}, + }, + }, + "risk_control": map[string]any{ + "btc_eth_max_leverage": 3, + "altcoin_max_leverage": 2, + "min_confidence": 80, + "min_risk_reward_ratio": 1.5, + }, + "prompt_sections": map[string]any{ + "trading_frequency": "每天最多 5-8 笔,避免连续亏损后追单。", + "entry_standards": "趋势确认、成交量放大、资金费率正常才开仓。", + }, + }, + } + rawPatch, _ := json.Marshal(patch) + userID := int64(42) + session := newActiveSkillSession(userID, "strategy_management", "create") + session.CollectedFields = map[string]any{ + "name": "AI500高频", + "strategy_type": "ai_trading", + "awaiting_final_confirmation": true, + strategyCreateConfigPatchField: string(rawPatch), + } + a.saveActiveSkillSession(session) + + decision := unifiedTurnDecision{ + TopicIntent: "continue_active", + BusinessAction: "planned_agent", + } + reply, handled, err := a.executeUnifiedTurnDecision(context.Background(), "default", userID, "zh", "确认", decision, nil) + if err != nil { + t.Fatalf("execute unified turn: %v", err) + } + if !handled { + t.Fatalf("expected turn to be handled") + } + if strings.Contains(reply, "马上") || strings.Contains(reply, "稍后") || strings.Contains(reply, "正在") { + t.Fatalf("expected planner promise to be bypassed, got: %s", reply) + } + if !strings.Contains(reply, "已创建策略") { + t.Fatalf("expected real strategy creation result, got: %s", reply) + } +} + +func TestStrategyCreateRepairPromiseIsNotReturnedOnConfirmation(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "strategy-create-repair-promise.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + a.SetAIClient(&staticAIClient{response: `{"route":"ask_user","reply":"好的,马上为你创建AI500高频稳健策略。","extracted_data":{}}`}) + + userID := int64(42) + session := newActiveSkillSession(userID, "strategy_management", "create") + session.CollectedFields = map[string]any{ + "name": "AI500高频", + "strategy_type": "ai_trading", + } + session.LocalHistory = []chatMessage{ + {Role: "assistant", Content: "如果你确认没问题,告诉我“确认创建”,我就帮你直接创建。"}, + } + a.saveActiveSkillSession(session) + + reply, handled, err := a.driveActiveSession(context.Background(), "default", userID, "zh", "确认创建", session, nil) + if err != nil { + t.Fatalf("drive active session: %v", err) + } + if !handled { + t.Fatalf("expected confirmation turn to be handled") + } + if strings.Contains(reply, "马上") || strings.Contains(reply, "正在") || strings.Contains(reply, "稍后") { + t.Fatalf("repair promise should not be returned on confirmation, got: %s", reply) + } +} + +func TestModelCreateSessionRedirectsStrategyTypeChoiceToStrategyCreate(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "strategy-type-choice-not-model-provider.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + const userID int64 = 42 + a.saveSkillSession(userID, skillSession{ + Name: "model_management", + Action: "create", + Phase: "collecting", + Fields: map[string]string{}, + }) + + reply, ok := a.redirectModelCreateSessionToStrategyCreateIfNeeded("default", userID, "zh", "1.AI交易策略", a.getSkillSession(userID)) + if !ok { + t.Fatalf("expected strategy type choice to redirect away from model create") + } + if strings.Contains(reply, "模型提供商") || strings.Contains(reply, "provider") { + t.Fatalf("strategy type choice must not ask for model provider, got: %s", reply) + } + session := a.getSkillSession(userID) + if session.Name != "strategy_management" || session.Action != "create" { + t.Fatalf("expected active session to be strategy create, got %+v", session) + } + if got := fieldValue(session, "strategy_type"); got != "ai_trading" { + t.Fatalf("expected ai strategy type to be captured, got %q in %+v", got, session) + } +} + +func TestStrategyCreateAskUserReplyIsNotOverriddenByTemplateMissingFields(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "strategy-create-llm-ask-reply.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + a.SetAIClient(&staticAIClient{response: `{"route":"ask_user","reply":"我会按 AI 策略模板继续填。你如果想稳健,我建议先用 OI Low、15m 主周期、最低置信度 70。确认这个方向吗?"}`}) + + session := ActiveSkillSession{ + UserID: 42, + SkillName: "strategy_management", + ActionName: "create", + CollectedFields: map[string]any{ + "name": "AI高频", + "strategy_type": "ai_trading", + }, + } + reply, handled, err := a.driveActiveSession(context.Background(), "default", 42, "zh", "1h", session, nil) + if err != nil { + t.Fatalf("drive active session: %v", err) + } + if !handled { + t.Fatalf("expected active session to be handled") + } + if strings.Contains(reply, "这份策略模板还没填完整") || strings.Contains(reply, "还缺这些字段") { + t.Fatalf("LLM ask_user reply should not be overridden by hard template missing list, got: %s", reply) + } + if !strings.Contains(reply, "OI Low") || !strings.Contains(reply, "70") { + t.Fatalf("expected LLM reply to pass through, got: %s", reply) + } +} + +func TestStrategyCreateAIReplyRejectsNonTemplateInvestmentQuestion(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "strategy-create-ai-non-template-question.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + a.SetAIClient(&staticAIClient{response: `{"route":"ask_user","reply":"你打算投入多少资金来运行这个策略?比如100U、500U?这样我可以帮你设置止损和仓位。"}`}) + + session := ActiveSkillSession{ + UserID: 42, + SkillName: "strategy_management", + ActionName: "create", + CollectedFields: map[string]any{ + "name": "AI500稳健", + "strategy_type": "ai_trading", + }, + } + reply, handled, err := a.driveActiveSession(context.Background(), "default", 42, "zh", "全部你定吧,稳健就行", session, nil) + if err != nil { + t.Fatalf("drive active session: %v", err) + } + if !handled { + t.Fatalf("expected active session to be handled") + } + for _, blocked := range []string{"投入多少", "100U", "500U", "止损", "仓位"} { + if strings.Contains(reply, blocked) { + t.Fatalf("AI strategy reply should not ask non-template field %q, got: %s", blocked, reply) + } + } +} + +func TestStrategyCreateOptionsQuestionExplainsCurrentMissingField(t *testing.T) { + session := ActiveSkillSession{ + UserID: 42, + SkillName: "strategy_management", + ActionName: "create", + CollectedFields: map[string]any{ + "name": "AI500高频交易", + "strategy_type": "ai_trading", + }, + } + reply, blocked := strategyCreateTemplateMissingReply("zh", "有哪些选择吗", session) + if !blocked { + t.Fatalf("expected options question to be handled") + } + for _, want := range []string{"AI500", "OI Top", "OI Low", "静态币种"} { + if !strings.Contains(reply, want) { + t.Fatalf("expected source options to include %q, got: %s", want, reply) + } + } + if strings.Contains(reply, "还缺") || strings.Contains(reply, "BTC/ETH 最大杠杆") { + t.Fatalf("options question should not repeat the full missing-field list, got: %s", reply) + } +} + +func TestStrategyCreateMissingFieldsIncludeInlineOptions(t *testing.T) { + reply := formatStrategyCreateConfigNeeded("zh", "source_type,primary_timeframe,btceth_max_leverage,min_confidence,trading_frequency") + for _, want := range []string{"AI500", "OI Top", "OI Low", "静态币种", "1m", "1h", "1~20", "50~100", "每天最多"} { + if !strings.Contains(reply, want) { + t.Fatalf("expected missing-field prompt to include option/range %q, got: %s", want, reply) + } + } + if !strings.Contains(reply, "你帮我按稳健/高频/激进来推荐") { + t.Fatalf("expected prompt to offer recommendation shortcut, got: %s", reply) + } +} + +func TestStrategyCreateConfigPatchReplyUsesStructuredMissingFields(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "strategy-create-recommendation-structured-missing.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + a.SetAIClient(&staticAIClient{response: `{"route":"ask_user","reply":"我建议按高频但稳健来填:主周期 3m,多周期 3m/5m/15m,BTC/ETH 3倍,山寨币 2倍。确认的话我会继续整理完整模板。","extracted_data":{"config_patch":{"strategy_type":"ai_trading","ai_config":{"indicators":{"klines":{"primary_timeframe":"3m","selected_timeframes":["3m","5m","15m"]}}}}}}`}) + + session := ActiveSkillSession{ + UserID: 42, + SkillName: "strategy_management", + ActionName: "create", + CollectedFields: map[string]any{ + "name": "AI500高频交易", + "strategy_type": "ai_trading", + strategyCreateConfigPatchField: map[string]any{ + "strategy_type": "ai_trading", + "ai_config": map[string]any{ + "coin_source": map[string]any{"source_type": "ai500", "use_ai500": true, "ai500_limit": 5}, + "risk_control": map[string]any{ + "min_confidence": 70, + }, + }, + }, + }, + } + reply, handled, err := a.driveActiveSession(context.Background(), "default", 42, "zh", "继续", session, nil) + if err != nil { + t.Fatalf("drive active session: %v", err) + } + if !handled { + t.Fatalf("expected recommendation request to be handled") + } + if !strings.Contains(reply, "这份策略模板还没填完整") { + t.Fatalf("expected structured missing-field prompt after partial config_patch, got: %s", reply) + } + if strings.Contains(reply, "我建议按高频但稳健来填") { + t.Fatalf("LLM free-form recommendation should not be used as the current plan, got: %s", reply) + } + if !strings.Contains(reply, "BTC/ETH 最大杠杆") || !strings.Contains(reply, "开仓标准") { + t.Fatalf("expected deterministic missing template fields, got: %s", reply) + } +} + +func TestStrategyCreateFirstStageConfigProgressUsesStructuredMissingFields(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "strategy-create-first-stage-structured-missing.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + a.SetAIClient(&staticAIClient{response: `{"route":"ask_user","reply":"收到,AI500 和最低置信度 80 是你指定的;其他我建议按高频稳健来定:主周期 3m,多周期 3m/5m/15m,BTC/ETH 3倍,山寨币 2倍,最小盈亏比 2。","extracted_data":{}}`}) + + session := ActiveSkillSession{ + UserID: 42, + SkillName: "strategy_management", + ActionName: "create", + CollectedFields: map[string]any{ + "name": "高频稳健AI500", + "strategy_type": "ai_trading", + strategyCreateConfigProgressThisTurnField: true, + strategyCreateConfigPatchField: map[string]any{ + "strategy_type": "ai_trading", + "ai_config": map[string]any{ + "coin_source": map[string]any{"source_type": "ai500", "use_ai500": true}, + "risk_control": map[string]any{ + "min_confidence": 80, + }, + }, + }, + }, + } + reply, handled, err := a.driveActiveSession(context.Background(), "default", 42, "zh", "选币选AI500,最新置信度80,其他你定,能高频交易稳定就行", session, nil) + if err != nil { + t.Fatalf("drive active session: %v", err) + } + if !handled { + t.Fatalf("expected active session to be handled") + } + if !strings.Contains(reply, "这份策略模板还没填完整") { + t.Fatalf("expected structured missing-field prompt after first-stage config progress, got: %s", reply) + } + if strings.Contains(reply, "其他我建议按高频稳健来定") { + t.Fatalf("LLM free-form recommendation should not be used as the current plan, got: %s", reply) + } + if !strings.Contains(reply, "主周期") || !strings.Contains(reply, "BTC/ETH 最大杠杆") { + t.Fatalf("expected deterministic missing template fields, got: %s", reply) + } +} + +func TestStrategyCreateConfirmationUsesModelRepairForPriorStyleProposal(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "strategy-create-style-repair.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + a.SetAIClient(&staticAIClient{response: `{"route":"execute_skill","extracted_data":{"config_patch":{"strategy_type":"ai_trading","ai_config":{"coin_source":{"source_type":"ai500","use_ai500":true,"ai500_limit":3},"indicators":{"klines":{"primary_timeframe":"1m","primary_count":20,"selected_timeframes":["1m","5m","15m"],"enable_multi_timeframe":true,"enable_raw_klines":true},"enable_volume":true,"enable_oi":true,"enable_funding_rate":true,"enable_quant_data":true},"risk_control":{"btc_eth_max_leverage":5,"altcoin_max_leverage":5,"min_confidence":75,"min_risk_reward_ratio":3},"prompt_sections":{"trading_frequency":"高频但不过度交易:目标每小时 1-3 笔;单笔持仓通常 10-30 分钟。","entry_standards":"只在短周期趋势、成交量/OI、资金费率或排行信号形成共振时入场。"}}}}}`}) + + userID := int64(42) + session := newActiveSkillSession(userID, "strategy_management", "create") + session.CollectedFields = map[string]any{ + "name": "AI500极致稳定高频", + "strategy_type": "ai_trading", + } + session.LocalHistory = []chatMessage{ + {Role: "assistant", Content: "我建议主周期改成1分钟,多周期改成1分钟、5分钟、15分钟,交易频率按高频但稳定来写。"}, + } + + reply, handled, err := a.driveActiveSession(context.Background(), "default", userID, "zh", "好的可以,确认创建", session, nil) + if err != nil { + t.Fatalf("drive active session: %v", err) + } + if !handled { + t.Fatalf("expected confirmation to be handled") + } + if !strings.Contains(reply, "已创建策略") { + t.Fatalf("expected real strategy creation after model repair, got: %s", reply) + } + strategies, err := st.Strategy().List("default") + if err != nil { + t.Fatalf("list strategies: %v", err) + } + if len(strategies) != 1 { + t.Fatalf("expected one created strategy, got %d", len(strategies)) + } + var cfg store.StrategyConfig + if err := json.Unmarshal([]byte(strategies[0].Config), &cfg); err != nil { + t.Fatalf("unmarshal config: %v", err) + } + if cfg.CoinSource.SourceType != "ai500" || cfg.Indicators.Klines.PrimaryTimeframe != "1m" { + t.Fatalf("expected model-repaired AI500 1m strategy, got %+v", cfg) + } +} + +func TestStrategyCreateCreatesGridAfterConfigPatch(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "strategy-grid-create-ready.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + patch := map[string]any{ + "strategy_type": "grid_trading", + "grid_config": map[string]any{ + "symbol": "ETHUSDT", + "grid_count": 12, + "total_investment": 1000, + "leverage": 3, + "use_atr_bounds": true, + "atr_multiplier": 2, + "distribution": "gaussian", + "max_drawdown_pct": 15, + "stop_loss_pct": 5, + "daily_loss_limit_pct": 10, + "use_maker_only": true, + }, + } + rawPatch, _ := json.Marshal(patch) + session := skillSession{ + Name: "strategy_management", + Action: "create", + Fields: map[string]string{ + "name": "我的网格策略", + strategyCreateConfigPatchField: string(rawPatch), + }, + } + + reply := a.handleStrategyCreateSkill("default", 1, "zh", "确认创建", session) + if !strings.Contains(reply, "已创建策略") { + t.Fatalf("expected create reply, got: %s", reply) + } + strategies, err := st.Strategy().List("default") + if err != nil { + t.Fatalf("list strategies: %v", err) + } + var created *store.Strategy + for _, strategy := range strategies { + if strategy.Name == "我的网格策略" { + created = strategy + break + } + } + if created == nil { + t.Fatalf("expected grid strategy to be created") + } + var cfg store.StrategyConfig + if err := json.Unmarshal([]byte(created.Config), &cfg); err != nil { + t.Fatalf("unmarshal config: %v", err) + } + if cfg.StrategyType != "grid_trading" || cfg.GridConfig == nil || cfg.GridConfig.Symbol != "ETHUSDT" { + t.Fatalf("expected grid config to persist, got %+v", cfg) + } +} + +func TestManageStrategyToolCreateRequiresConfirmation(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "strategy-tool-create-confirmation.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + resp := a.toolManageStrategy("default", `{"action":"create","name":"未确认网格","lang":"zh","config":{"strategy_type":"grid_trading","grid_config":{"symbol":"BTCUSDT","total_investment":200,"use_atr_bounds":true}}}`) + if !strings.Contains(resp, "requires_confirmation") { + t.Fatalf("expected tool create to require confirmation, got: %s", resp) + } + strategies, err := st.Strategy().List("default") + if err != nil { + t.Fatalf("list strategies: %v", err) + } + for _, strategy := range strategies { + if strategy.Name == "未确认网格" { + t.Fatalf("unconfirmed tool call should not create strategy") + } + } + + resp = a.toolManageStrategy("default", `{"action":"create","name":"已确认网格","lang":"zh","confirmed":true,"allow_clamped_update":true,"config":{"strategy_type":"grid_trading","grid_config":{"symbol":"BTCUSDT","total_investment":200,"use_atr_bounds":true}}}`) + if strings.Contains(resp, `"error"`) { + t.Fatalf("expected confirmed create to succeed, got: %s", resp) + } +} + +func TestStrategyCreateGridPatchInfersStrategyType(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "strategy-grid-create-infers-type.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + patch := map[string]any{ + "grid_config": map[string]any{ + "symbol": "BTCUSDT", + "grid_count": 20, + "total_investment": 200, + "leverage": 2, + "use_atr_bounds": true, + "atr_multiplier": 2, + "distribution": "uniform", + "max_drawdown_pct": 15, + "stop_loss_pct": 5, + "daily_loss_limit_pct": 10, + "use_maker_only": true, + }, + } + rawPatch, _ := json.Marshal(patch) + session := skillSession{ + Name: "strategy_management", + Action: "create", + Fields: map[string]string{ + "name": "小白网格", + strategyCreateConfigPatchField: string(rawPatch), + }, + } + + reply := a.handleStrategyCreateSkill("default", 1, "zh", "确认创建", session) + if !strings.Contains(reply, "已创建策略") { + t.Fatalf("expected create reply, got: %s", reply) + } + strategies, err := st.Strategy().List("default") + if err != nil { + t.Fatalf("list strategies: %v", err) + } + var cfg store.StrategyConfig + for _, strategy := range strategies { + if strategy.Name == "小白网格" { + if err := json.Unmarshal([]byte(strategy.Config), &cfg); err != nil { + t.Fatalf("unmarshal config: %v", err) + } + break + } + } + if cfg.StrategyType != "grid_trading" || cfg.GridConfig == nil || cfg.GridConfig.Symbol != "BTCUSDT" { + t.Fatalf("expected grid patch to infer grid_trading, got %+v", cfg) + } +} + +func TestStrategyCreateGridPatchKeepsBackendGridDefaults(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "strategy-grid-create-defaults.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), slog.Default()) + + patch := map[string]any{ + "strategy_type": "grid_trading", + "grid_config": map[string]any{ + "symbol": "ETHUSDT", + "grid_count": 20, + "total_investment": 500, + "leverage": 3, + }, + } + rawPatch, _ := json.Marshal(patch) + session := skillSession{ + Name: "strategy_management", + Action: "create", + Fields: map[string]string{ + "name": "餐巾纸", + strategyCreateConfigPatchField: string(rawPatch), + }, + } + + reply := a.handleStrategyCreateSkill("default", 1, "zh", "确认创建", session) + if !strings.Contains(reply, "还缺") || strings.Contains(reply, "已创建策略") { + t.Fatalf("expected incomplete grid patch to ask for missing fields, got: %s", reply) + } +} + +func TestLLMFlowExtractionFiltersFieldsToAllowedSchema(t *testing.T) { + result := llmFlowExtractionResult{ + Intent: "continue", + Tasks: []llmFlowExtractionTask{{ + Skill: "exchange_management", + Action: "create", + Fields: map[string]string{ + "secret": "wrong-key", + "secret_key": "canonical-secret", + "api_key": "api", + }, + }}, + } + filtered := filterLLMFlowExtractionFields(result, []llmFlowFieldSpec{ + {Key: "secret_key"}, + {Key: "api_key"}, + }) + fields := filtered.Tasks[0].Fields + if _, ok := fields["secret"]; ok { + t.Fatalf("expected invented field key to be filtered, got: %+v", fields) + } + if fields["secret_key"] != "canonical-secret" || fields["api_key"] != "api" { + t.Fatalf("expected canonical fields to remain, got: %+v", fields) + } +} + +func TestExchangeCreateAllowedFieldSpecsUseCanonicalSecretKey(t *testing.T) { + specs := allowedFieldSpecsForSkillSession(skillSession{Name: "exchange_management", Action: "create"}, "zh") + foundSecretKey := false + for _, spec := range specs { + if spec.Key == "secret" { + t.Fatal("exchange create schema should not expose non-canonical secret key") + } + if spec.Key == "secret_key" { + foundSecretKey = true + } + } + if !foundSecretKey { + t.Fatal("expected exchange create schema to include canonical secret_key") + } +} + +func TestActiveSessionExtractedDataFiltersToAllowedSchema(t *testing.T) { + session := ActiveSkillSession{ + SkillName: "exchange_management", + ActionName: "create", + CollectedFields: map[string]any{ + "exchange_type": "okx", + }, + } + filtered := filterExtractedDataForActiveSession(session, map[string]any{ + "account_name": "呢呢", + "api_key": "api", + "secret": "wrong-key", + "secret_key": "canonical-secret", + "passphrase": "pass", + }, "zh") + if _, ok := filtered["secret"]; ok { + t.Fatalf("expected central brain alias key to be filtered, got: %+v", filtered) + } + for _, key := range []string{"account_name", "api_key", "secret_key", "passphrase"} { + if _, ok := filtered[key]; !ok { + t.Fatalf("expected canonical key %q to remain, got: %+v", key, filtered) + } + } +} + +func TestBrainUserPromptIncludesActiveAllowedFieldSchema(t *testing.T) { + prompt := buildBrainUserPrompt( + "zh", + "密钥是abc123456", + "要创建交易所配置,还缺这些字段:Secret。", + "", + "", + ActiveSkillSession{SkillName: "exchange_management", ActionName: "create"}, + true, + ) + if !strings.Contains(prompt, "allowed_field_spec_json") || !strings.Contains(prompt, `"secret_key"`) { + t.Fatalf("expected brain prompt to expose canonical field schema, got:\n%s", prompt) + } +} diff --git a/agent/unified_turn_router_test.go b/agent/unified_turn_router_test.go new file mode 100644 index 00000000..9d3a86ec --- /dev/null +++ b/agent/unified_turn_router_test.go @@ -0,0 +1,251 @@ +package agent + +import ( + "context" + "path/filepath" + "strings" + "testing" + + "nofx/store" +) + +func TestParseUnifiedTurnDecisionNormalizesContextPolicy(t *testing.T) { + raw := `{ + "topic_intent": "start_new", + "business_action": "new_skill", + "target_skill": "strategy_management:update_config", + "context_mode": "fresh_context", + "extracted_data": {"name": "BTC趋势"}, + "confidence": 0.82 + }` + + decision, err := parseUnifiedTurnDecision(raw) + if err != nil { + t.Fatalf("parse unified decision: %v", err) + } + if decision.TopicIntent != "start_new" { + t.Fatalf("expected normalized topic intent, got %q", decision.TopicIntent) + } + if decision.BusinessAction != "new_skill" { + t.Fatalf("expected business action new_skill, got %q", decision.BusinessAction) + } + if decision.ContextMode != "fresh_context" { + t.Fatalf("expected fresh_context, got %q", decision.ContextMode) + } + if !decision.reliable() { + t.Fatalf("expected decision to be reliable: %+v", decision) + } +} + +func TestParseUnifiedTurnDecisionAcceptsSkillTaskList(t *testing.T) { + raw := `{ + "topic_intent": "start_new", + "business_action": "skill_tasks", + "context_mode": "fresh_context", + "tasks": [ + {"id":"task_1","skill":"strategy_management","action":"create","request":"创建高频交易策略","depends_on":[]}, + {"id":"task_2","skill":"trader_management","action":"configure_strategy","request":"绑定到交易员","depends_on":["task_1"]} + ], + "confidence": 0.86 + }` + + decision, err := parseUnifiedTurnDecision(raw) + if err != nil { + t.Fatalf("parse unified decision: %v", err) + } + if decision.BusinessAction != "skill_tasks" { + t.Fatalf("expected skill_tasks, got %q", decision.BusinessAction) + } + if len(decision.Tasks) != 2 { + t.Fatalf("expected 2 tasks, got %+v", decision.Tasks) + } + if decision.Tasks[0].Skill != "strategy_management" || decision.Tasks[0].Action != "create" { + t.Fatalf("unexpected first task: %+v", decision.Tasks[0]) + } + if !decision.reliable() { + t.Fatalf("expected task-list decision to be reliable: %+v", decision) + } +} + +func TestUnifiedTurnDecisionNewSkillCanUseSingleTask(t *testing.T) { + decision := normalizeUnifiedTurnDecision(unifiedTurnDecision{ + TopicIntent: "start_new", + BusinessAction: "new_skill", + ContextMode: "fresh_context", + Tasks: []WorkflowTask{{ + Skill: "strategy_management", + Action: "create", + Request: "创建高频交易策略", + }}, + Confidence: 0.9, + }) + if !decision.reliable() { + t.Fatalf("expected new_skill with task list to be reliable: %+v", decision) + } +} + +func TestUnifiedTurnDecisionRejectsLowConfidenceAndIncompleteDirectAnswer(t *testing.T) { + lowConfidence := unifiedTurnDecision{ + TopicIntent: "start_new", + BusinessAction: "planned_agent", + ContextMode: "fresh_context", + Confidence: 0.2, + } + lowConfidence = normalizeUnifiedTurnDecision(lowConfidence) + if lowConfidence.reliable() { + t.Fatalf("expected low confidence decision to fall back") + } + + emptyDirect := unifiedTurnDecision{ + TopicIntent: "instant_reply", + BusinessAction: "direct_answer", + ContextMode: "use_current", + Confidence: 0.9, + } + emptyDirect = normalizeUnifiedTurnDecision(emptyDirect) + if emptyDirect.reliable() { + t.Fatalf("expected direct_answer without reply_to_user to fall back") + } +} + +func TestExecuteUnifiedTurnDecisionDirectAnswerRecordsHistory(t *testing.T) { + a := New(nil, nil, DefaultConfig(), nil) + userID := int64(101) + decision := normalizeUnifiedTurnDecision(unifiedTurnDecision{ + TopicIntent: "instant_reply", + BusinessAction: "direct_answer", + ContextMode: "use_current", + ReplyToUser: "你好,我在。", + Confidence: 0.9, + }) + + answer, handled, err := a.executeUnifiedTurnDecision(context.Background(), "default", userID, "zh", "你好", decision, nil) + if err != nil { + t.Fatalf("execute unified decision: %v", err) + } + if !handled { + t.Fatal("expected direct answer to be handled") + } + if answer != "你好,我在。" { + t.Fatalf("unexpected answer: %q", answer) + } + + history := a.history.Get(userID) + if len(history) != 2 { + t.Fatalf("expected user and assistant history entries, got %d", len(history)) + } + if history[0].Role != "user" || history[0].Content != "你好" { + t.Fatalf("unexpected user history entry: %+v", history[0]) + } + if history[1].Role != "assistant" || history[1].Content != "你好,我在。" { + t.Fatalf("unexpected assistant history entry: %+v", history[1]) + } +} + +func TestExecuteUnifiedTurnDecisionContinueActiveDoesNotHandOffToPlanner(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "continue-active-router.db") + st, err := store.New(dbPath) + if err != nil { + t.Fatalf("create store: %v", err) + } + a := New(nil, st, DefaultConfig(), nil) + userID := int64(102) + + session := newActiveSkillSession(userID, "strategy_management", "create") + session.Goal = "创建网格策略" + session.CollectedFields["name"] = "我的网格策略" + session.CollectedFields["strategy_type"] = "grid_trading" + setActiveSessionPendingHint(&session, "现在还需要确认网格交易对、网格数量、总投入、杠杆和价格区间。") + a.saveActiveSkillSession(session) + + decision := normalizeUnifiedTurnDecision(unifiedTurnDecision{ + TopicIntent: "continue_active", + BusinessAction: "planned_agent", + ContextMode: "use_current", + Confidence: 0.9, + }) + answer, handled, err := a.executeUnifiedTurnDecision(context.Background(), "default", userID, "zh", "那你帮我创吧", decision, nil) + if err != nil { + t.Fatalf("execute unified decision: %v", err) + } + if !handled { + t.Fatal("expected active session continuation to be handled") + } + if !strings.Contains(answer, "还缺") || !strings.Contains(answer, "交易对") || strings.Contains(answer, "交易机器人") || strings.Contains(answer, "AI模型和交易所") { + t.Fatalf("expected strategy session to continue without planner/trader handoff, got: %s", answer) + } + if _, ok := a.getActiveSkillSession(userID); !ok { + t.Fatalf("expected strategy active session to remain pending") + } +} + +func TestGuardUnexecutedActiveTaskCompletionBlocksCreationClaim(t *testing.T) { + session := ActiveSkillSession{ + SkillName: "strategy_management", + ActionName: "create", + } + reply, blocked := guardUnexecutedActiveTaskCompletion("zh", session, "已经创建好了。策略现在就在你的策略列表里。") + if !blocked { + t.Fatalf("expected unexecuted active create completion claim to be blocked") + } + if !strings.Contains(reply, "还没有真正创建") { + t.Fatalf("expected honest not-created reply, got: %s", reply) + } + + _, blocked = guardUnexecutedActiveTaskCompletion("zh", session, "我建议先用 BTCUSDT 做新手网格策略。") + if blocked { + t.Fatalf("non-completion proposal should not be blocked") + } +} + +func TestGuardUnsupportedAsyncPromiseBlocksFakeDiagnosisProgress(t *testing.T) { + reply, blocked := guardUnsupportedAsyncPromise("zh", "诊断还在进行中,请再稍等一下。我马上分析完“小小”的历史交易记录,找到亏损原因后会立刻告诉您。") + if !blocked { + t.Fatal("expected fake async diagnosis progress to be blocked") + } + for _, want := range []string{"没有后台异步任务", "当前回复"} { + if !strings.Contains(reply, want) { + t.Fatalf("expected guarded reply to contain %q, got: %s", want, reply) + } + } + + _, blocked = guardUnsupportedAsyncPromise("zh", "我需要策略名称和历史记录范围,才能开始诊断。") + if blocked { + t.Fatal("missing-info diagnosis reply should not be blocked") + } + + _, blocked = guardUnsupportedAsyncPromise("zh", "好的,参数已确认,正在为您创建“餐巾纸”网格策略。") + if !blocked { + t.Fatal("expected fake async strategy create progress to be blocked") + } +} + +func TestFinishTaskGuardBlocksFakeCreateProgressPromise(t *testing.T) { + reply, blocked := guardUnsupportedAsyncPromise("zh", "策略正在创建中,请稍等一会儿。创建成功后我会立刻告诉您。") + if !blocked { + t.Fatal("expected fake create progress promise to be blocked") + } + if !strings.Contains(reply, "没有后台异步任务") || !strings.Contains(reply, "实际执行") { + t.Fatalf("expected honest execution correction, got: %s", reply) + } +} + +func TestBuildUnifiedTurnRouterPromptNamesContextPolicy(t *testing.T) { + a := New(nil, nil, DefaultConfig(), nil) + systemPrompt, userPrompt := a.buildUnifiedTurnRouterPrompt(42, "zh", "不是交易员,是策略") + for _, want := range []string{ + "context_mode values", + "fresh_context", + "downstream modules", + "tasks format", + "skill_tasks", + "topic_intent as the primary decision", + } { + if !strings.Contains(systemPrompt, want) { + t.Fatalf("expected system prompt to contain %q", want) + } + } + if !strings.Contains(userPrompt, "不是交易员,是策略") { + t.Fatalf("expected user prompt to contain current user message") + } +} diff --git a/agent/user_facing_prompt.go b/agent/user_facing_prompt.go new file mode 100644 index 00000000..8af3a506 --- /dev/null +++ b/agent/user_facing_prompt.go @@ -0,0 +1,3 @@ +package agent + +const cleanUserFacingReplyInstruction = "Your final reply must be clean and easy to understand, with no fluff, no internal jargon, and no unnecessary explanation." diff --git a/agent/user_facing_prompt_test.go b/agent/user_facing_prompt_test.go new file mode 100644 index 00000000..273503af --- /dev/null +++ b/agent/user_facing_prompt_test.go @@ -0,0 +1,12 @@ +package agent + +import "testing" + +func TestCleanUserFacingReplyInstruction(t *testing.T) { + if cleanUserFacingReplyInstruction == "" { + t.Fatal("expected clean user-facing reply instruction to be defined") + } + if got, want := cleanUserFacingReplyInstruction, "Your final reply must be clean and easy to understand, with no fluff, no internal jargon, and no unnecessary explanation."; got != want { + t.Fatalf("unexpected instruction\nwant: %q\ngot: %q", want, got) + } +} diff --git a/agent/web.go b/agent/web.go index 12865d84..e1571226 100644 --- a/agent/web.go +++ b/agent/web.go @@ -3,6 +3,7 @@ package agent import ( "context" "encoding/json" + "errors" "fmt" "io" "log/slog" @@ -13,6 +14,14 @@ import ( ) type storeUserIDContextKey struct{} +type sessionPolicyContextKey struct{} + +type SessionPolicy struct { + Authenticated bool + IsAdmin bool + CanExecuteTrade bool + CanViewSensitiveSecrets bool +} // WithStoreUserID annotates an HTTP request context with the authenticated store user ID. func WithStoreUserID(ctx context.Context, storeUserID string) context.Context { @@ -26,6 +35,17 @@ func storeUserIDFromContext(ctx context.Context) string { return "default" } +func WithSessionPolicy(ctx context.Context, policy SessionPolicy) context.Context { + return context.WithValue(ctx, sessionPolicyContextKey{}, policy) +} + +func sessionPolicyFromContext(ctx context.Context) SessionPolicy { + if v, ok := ctx.Value(sessionPolicyContextKey{}).(SessionPolicy); ok { + return v + } + return SessionPolicy{} +} + // validSymbolRe matches only alphanumeric trading symbols (e.g. BTCUSDT, ETH-USD). var validSymbolRe = regexp.MustCompile(`^[A-Za-z0-9\-_]{1,20}$`) @@ -80,7 +100,7 @@ func (w *WebHandler) HandleChat(rw http.ResponseWriter, r *http.Request) { return } if req.UserID == 0 { - req.UserID = SessionUserIDFromKey(req.UserKey) + req.UserID = SessionUserIDFromKey(storeUserIDFromContext(r.Context())) } msg := req.Message if req.Lang != "" { @@ -93,7 +113,7 @@ func (w *WebHandler) HandleChat(rw http.ResponseWriter, r *http.Request) { resp, err := w.agent.HandleMessageForStoreUser(ctx, storeUserIDFromContext(r.Context()), req.UserID, msg) if err != nil { w.logger.Error("agent HandleMessage failed", "error", err, "user_id", req.UserID) - writeJSON(rw, 500, map[string]string{"error": "Failed to process message. Please try again."}) + writeJSON(rw, 500, map[string]string{"error": "I ran into a problem while handling that message. Please try again."}) return } writeJSON(rw, 200, map[string]string{"response": resp}) @@ -122,7 +142,7 @@ func (w *WebHandler) HandleChatStream(rw http.ResponseWriter, r *http.Request) { return } if req.UserID == 0 { - req.UserID = SessionUserIDFromKey(req.UserKey) + req.UserID = SessionUserIDFromKey(storeUserIDFromContext(r.Context())) } msg := req.Message if req.Lang != "" { @@ -146,11 +166,21 @@ func (w *WebHandler) HandleChatStream(rw http.ResponseWriter, r *http.Request) { defer cancel() resp, err := w.agent.HandleMessageStreamForStoreUser(ctx, storeUserIDFromContext(r.Context()), req.UserID, msg, func(event, data string) { + if ctx.Err() != nil { + return + } writeSSE(rw, flusher, event, data) }) if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || ctx.Err() != nil { + w.logger.Info("agent stream cancelled", "user_id", req.UserID, "error", err) + return + } w.logger.Error("agent HandleMessageStream failed", "error", err, "user_id", req.UserID) - writeSSE(rw, flusher, "error", "Failed to process message. Please try again.") + writeSSE(rw, flusher, "error", "I ran into a problem while handling that message. Please try again.") + return + } + if ctx.Err() != nil { return } // Send final done event with complete response diff --git a/agent/workflow.go b/agent/workflow.go index fa704c3f..2f7ec355 100644 --- a/agent/workflow.go +++ b/agent/workflow.go @@ -161,49 +161,50 @@ func supportedWorkflowSkill(skill, action string) bool { if _, ok := getSkillDAG(skill, action); ok { return true } + if def, ok := getSkillDefinition(skill); ok { + if _, ok := def.Actions[action]; ok { + return true + } + } switch skill { case "trader_management", "strategy_management", "model_management", "exchange_management": - switch action { - case "create", "query_list", "query_detail", "query_running", "activate": + if action == "query_running" { return true } } return false } -func (a *Agent) tryWorkflowIntent(ctx context.Context, storeUserID string, userID int64, lang, text string, onEvent func(event, data string)) (string, bool, error) { - if session := a.getWorkflowSession(userID); hasActiveWorkflowSession(session) { - return a.handleWorkflowSession(ctx, storeUserID, userID, lang, text, session, onEvent) - } - - decomposition, err := a.decomposeWorkflowIntent(ctx, userID, lang, text) - if err != nil || len(decomposition.Tasks) <= 1 { - return "", false, err - } - session := WorkflowSession{ - UserID: userID, - OriginalRequest: text, - Tasks: decomposition.Tasks, - } - a.saveWorkflowSession(userID, session) - return a.handleWorkflowSession(ctx, storeUserID, userID, lang, text, session, onEvent) -} - func (a *Agent) handleWorkflowSession(ctx context.Context, storeUserID string, userID int64, lang, text string, session WorkflowSession, onEvent func(event, data string)) (string, bool, error) { if isExplicitFlowAbort(text) { a.clearSkillSession(userID) a.clearWorkflowSession(userID) - if lang == "zh" { - return "已取消当前任务流。", true, nil - } - return "Cancelled the current workflow.", true, nil + return a.maybeOfferParentTaskAfterCancel(userID, lang), true, nil } if activeSkill := a.getSkillSession(userID); strings.TrimSpace(activeSkill.Name) != "" { - answer, handled := a.tryHardSkill(ctx, storeUserID, userID, lang, text, onEvent) + decision, _ := a.resolveSkillSessionTurn(ctx, userID, lang, text, activeSkill) + switch decision.Intent { + case "cancel": + a.clearSkillSession(userID) + a.clearWorkflowSession(userID) + return a.maybeOfferParentTaskAfterCancel(userID, lang), true, nil + case "instant_reply": + return a.replyToActiveFlowInstantReply(ctx, userID, lang, text, onEvent), true, nil + case "resume_snapshot", "start_new": + if shouldSuspendInterruptedTask(text) || decision.Intent == "resume_snapshot" { + answer, handled, err := a.handoffFromActiveFlow(ctx, storeUserID, userID, lang, text, decision.TargetSnapshotID, onEvent) + return answer, handled, err + } + a.clearSkillSession(userID) + a.clearWorkflowSession(userID) + return "", false, nil + } + answer, handled := a.executeAtomicSkillTask(storeUserID, userID, lang, text, activeSkill.Name, activeSkill.Action, onEvent) if !handled { return "", false, nil } + a.recordSkillInteraction(userID, text, answer) session = a.getWorkflowSession(userID) if hasActiveWorkflowSession(session) && strings.TrimSpace(a.getSkillSession(userID).Name) == "" { session = markCurrentWorkflowTask(session, workflowTaskCompleted, "") @@ -221,9 +222,78 @@ func (a *Agent) handleWorkflowSession(ctx context.Context, storeUserID string, u return answer, true, nil } + if decision := a.classifyWorkflowSessionInput(ctx, userID, lang, session, text); decision.Intent != "" && decision.Intent != "continue_active" { + switch decision.Intent { + case "cancel": + a.clearWorkflowSession(userID) + return a.maybeOfferParentTaskAfterCancel(userID, lang), true, nil + case "instant_reply": + return a.replyToActiveFlowInstantReply(ctx, userID, lang, text, onEvent), true, nil + case "resume_snapshot", "start_new": + if shouldSuspendInterruptedTask(text) || decision.Intent == "resume_snapshot" { + answer, handled, err := a.handoffFromActiveFlow(ctx, storeUserID, userID, lang, text, decision.TargetSnapshotID, onEvent) + return answer, handled, err + } + a.clearWorkflowSession(userID) + return "", false, nil + } + } + return a.maybeAdvanceWorkflow(ctx, storeUserID, userID, lang, session, onEvent) } +func (a *Agent) classifyWorkflowSessionInput(ctx context.Context, userID int64, lang string, session WorkflowSession, text string) unifiedFlowDecision { + text = strings.TrimSpace(text) + if text == "" { + return unifiedFlowDecision{Intent: "continue_active"} + } + if isExplicitFlowAbort(text) { + return unifiedFlowDecision{Intent: "cancel"} + } + if isInstantDirectReplyText(text) { + return unifiedFlowDecision{Intent: "instant_reply"} + } + if a == nil || a.aiClient == nil { + if looksLikeNewTopLevelIntent(text) && !strings.EqualFold(text, strings.TrimSpace(session.OriginalRequest)) { + return unifiedFlowDecision{Intent: "start_new"} + } + return unifiedFlowDecision{Intent: "continue_active"} + } + currentTask, _, _ := nextRunnableWorkflowTask(session) + recentConversationCtx := a.buildRecentConversationContext(userID, text) + flowContext := fmt.Sprintf( + "Workflow original request: %s\nCurrent runnable task: %s / %s / %s\nWorkflow tasks JSON: %s", + session.OriginalRequest, + currentTask.Skill, + currentTask.Action, + currentTask.Request, + mustMarshalJSON(session.Tasks), + ) + state := a.getExecutionState(userID) + systemPrompt, userPrompt := buildActiveFlowClassifierPrompt( + lang, + "workflow_session", + flowContext, + text, + recentConversationCtx, + state.CurrentReferences, + a.SnapshotManager(userID).List(), + ) + stageCtx, cancel := withPlannerStageTimeout(ctx, directReplyTimeout) + defer cancel() + raw, err := a.aiClient.CallWithRequest(&mcp.Request{ + Messages: []mcp.Message{ + mcp.NewSystemMessage(systemPrompt), + mcp.NewUserMessage(userPrompt), + }, + Ctx: stageCtx, + }) + if err != nil { + return unifiedFlowDecision{} + } + return unifiedFlowDecisionFromIntent(parseActiveFlowIntentDecision(raw), "") +} + func (a *Agent) maybeAdvanceWorkflow(ctx context.Context, storeUserID string, userID int64, lang string, session WorkflowSession, onEvent func(event, data string)) (string, bool, error) { task, index, ok := nextRunnableWorkflowTask(session) if !ok { @@ -238,7 +308,7 @@ func (a *Agent) maybeAdvanceWorkflow(ctx context.Context, storeUserID string, us } if onEvent != nil { onEvent(StreamEventPlan, summary) - onEvent(StreamEventDelta, summary) + emitStreamText(onEvent, summary) } return summary, true, nil } @@ -253,13 +323,14 @@ func (a *Agent) maybeAdvanceWorkflow(ctx context.Context, storeUserID string, us onEvent(StreamEventTool, "workflow:"+task.Skill+":"+task.Action) } - answer, handled := a.tryHardSkill(ctx, storeUserID, userID, lang, task.Request, onEvent) + answer, handled := a.executeAtomicSkillTask(storeUserID, userID, lang, task.Request, task.Skill, task.Action, onEvent) if !handled { session.Tasks[index].Status = workflowTaskFailed session.Tasks[index].Error = "task_not_handled" a.saveWorkflowSession(userID, session) return "", false, nil } + a.recordSkillInteraction(userID, task.Request, answer) if strings.TrimSpace(a.getSkillSession(userID).Name) == "" { session = a.getWorkflowSession(userID) @@ -332,7 +403,8 @@ func (a *Agent) generateWorkflowSummary(ctx context.Context, userID int64, lang defer cancel() systemPrompt := `You are summarizing a finished workflow for NOFXi. Return one short user-facing summary in the user's language. -Do not mention internal DAG, scheduler, or JSON.` +Do not mention internal DAG, scheduler, or JSON. +` + cleanUserFacingReplyInstruction userPrompt := fmt.Sprintf("Language: %s\nOriginal request: %s\nCompleted tasks:\n- %s", lang, session.OriginalRequest, strings.Join(completed, "\n- ")) raw, err := a.aiClient.CallWithRequest(&mcp.Request{ Messages: []mcp.Message{ @@ -374,24 +446,88 @@ func looksLikeMultiTaskIntent(text string) bool { count++ } } - return count > 0 + if count > 0 { + return true + } + if looksLikeCompoundStrategyIntent(text) || looksLikeCompoundTraderIntent(text) || + looksLikeCompoundModelIntent(text) || looksLikeCompoundExchangeIntent(text) { + return true + } + return false +} + +func looksLikeCompoundStrategyIntent(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if !hasExplicitManagementDomainCue(text, "strategy") { + return false + } + hasCreate := containsAny(lower, []string{"创建", "新建", "创一个", "创个", "加一个", "create", "new"}) + hasConfigUpdate := containsAny(lower, []string{"修改", "更新", "参数", "配置", "prompt", "提示词", "改成", "改为"}) + hasLifecycle := containsAny(lower, []string{"激活", "activate", "复制", "duplicate", "删除", "删了", "删掉", "delete"}) + hasMetaUpdate := containsAny(lower, []string{"发布", "公开", "可见", "描述", "改成", "改为"}) + return (hasCreate && (hasConfigUpdate || hasLifecycle || hasMetaUpdate)) || + (hasConfigUpdate && hasLifecycle) +} + +func looksLikeCompoundTraderIntent(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if !(hasExplicitManagementDomainCue(text, "trader") || hasExplicitCreateIntentForDomain(text, "trader")) { + return false + } + hasCreate := containsAny(lower, []string{"创建", "新建", "创一个", "创个", "create", "new"}) + hasBindingsOrConfig := containsAny(lower, []string{"修改", "更新", "换模型", "换交易所", "换策略", "切换模型", "切换交易所", "切换策略", "扫描间隔", "全仓", "逐仓", "竞技场"}) + hasLifecycle := containsAny(lower, []string{"启动", "开始", "start", "停止", "stop"}) + return (hasCreate && (hasBindingsOrConfig || hasLifecycle)) || + (hasBindingsOrConfig && hasLifecycle) +} + +func looksLikeCompoundModelIntent(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if !hasExplicitManagementDomainCue(text, "model") { + return false + } + hasCreate := containsAny(lower, []string{"创建", "新建", "创一个", "创个", "create", "new"}) + hasConfig := containsAny(lower, []string{"修改", "更新", "改", "接口地址", "模型名", "启用", "禁用", "api key"}) + hasLifecycle := containsAny(lower, []string{"启用", "禁用", "enable", "disable", "删除", "删了", "删掉", "delete"}) + return (hasCreate && (hasConfig || hasLifecycle)) || (hasConfig && hasLifecycle) +} + +func looksLikeCompoundExchangeIntent(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if !hasExplicitManagementDomainCue(text, "exchange") { + return false + } + hasCreate := containsAny(lower, []string{"创建", "新建", "创一个", "创个", "create", "new"}) + hasConfig := containsAny(lower, []string{"修改", "更新", "改", "账户名", "api key", "secret", "passphrase", "钱包", "启用", "禁用"}) + hasLifecycle := containsAny(lower, []string{"启用", "禁用", "enable", "disable", "删除", "删了", "删掉", "delete"}) + return (hasCreate && (hasConfig || hasLifecycle)) || (hasConfig && hasLifecycle) } func (a *Agent) decomposeWorkflowIntentWithLLM(ctx context.Context, userID int64, lang, text string) (workflowDecomposition, error) { stageCtx, cancel := withPlannerStageTimeout(ctx, directReplyTimeout) defer cancel() - systemPrompt := `You decompose one NOFXi user request into a small task graph. + systemPrompt := `You decompose one NOFXi user request into a small task graph for execution. Return JSON only. No markdown. Only use these skills: trader_management, strategy_management, model_management, exchange_management. Only use one atomic action per task. +You are the action decomposition layer. Split complex requests into atomic management steps and decide dependencies. Each task must include: - id - skill - action - request - depends_on (array, may be empty) -If the request is effectively a single task, return one task only.` +Rules: +- Prefer atomic actions such as create, update_bindings, configure_strategy, configure_exchange, configure_model, update_status, update_endpoint, update_config, update_prompt, activate, duplicate, start, stop, delete, query_list, query_detail. +- If one request contains create plus follow-up edits in the same skill, split them into multiple tasks. +- If later tasks need an entity created earlier, make the dependency explicit in depends_on. +- Keep each request user-readable and self-contained enough for a single skill handler to execute. +- Do not merge two actions into one task. +- If the request is effectively a single task, return one task only.` userPrompt := fmt.Sprintf("Language: %s\nUser request: %s", lang, text) + if skillContext := buildManagementSkillRoutingContext(lang); skillContext != "" { + userPrompt += "\n\n" + skillContext + } raw, err := a.aiClient.CallWithRequest(&mcp.Request{ Messages: []mcp.Message{ mcp.NewSystemMessage(systemPrompt), @@ -451,21 +587,256 @@ func normalizeWorkflowDecomposition(out workflowDecomposition) workflowDecomposi func (a *Agent) decomposeWorkflowIntentFallback(text string) workflowDecomposition { segments := splitWorkflowSegments(text) tasks := make([]WorkflowTask, 0, len(segments)) - for i, segment := range segments { - task, ok := classifyWorkflowTask(segment) - if !ok { - continue - } - task.ID = fmt.Sprintf("task_%d", i+1) - task.Status = workflowTaskPending + nextID := 1 + for _, segment := range segments { + prevSkill := "" if len(tasks) > 0 { - task.DependsOn = []string{tasks[len(tasks)-1].ID} + prevSkill = tasks[len(tasks)-1].Skill + } + compound := classifyCompoundWorkflowTasksWithContext(segment, prevSkill) + if len(compound) == 0 { + task, ok := classifyWorkflowTaskWithContext(segment, prevSkill) + if !ok { + continue + } + compound = []WorkflowTask{task} + } + for i := range compound { + compound[i].ID = fmt.Sprintf("task_%d", nextID) + compound[i].Status = workflowTaskPending + if len(tasks) > 0 && len(compound[i].DependsOn) == 0 { + compound[i].DependsOn = []string{tasks[len(tasks)-1].ID} + } + if i > 0 { + compound[i].DependsOn = []string{compound[i-1].ID} + } + tasks = append(tasks, compound[i]) + nextID++ } - tasks = append(tasks, task) } return workflowDecomposition{Tasks: tasks} } +func classifyCompoundWorkflowTasksWithContext(text, previousSkill string) []WorkflowTask { + if tasks := classifyCompoundWorkflowTasks(text); len(tasks) > 1 { + return tasks + } + switch strings.TrimSpace(previousSkill) { + case "strategy_management": + return classifyContextualStrategyWorkflowTasks(text) + case "trader_management": + return classifyContextualTraderWorkflowTasks(text) + } + return nil +} + +func classifyCompoundWorkflowTasks(text string) []WorkflowTask { + segment := strings.TrimSpace(text) + if segment == "" { + return nil + } + + if tasks := classifyCompoundStrategyWorkflowTasks(segment); len(tasks) > 1 { + return tasks + } + if tasks := classifyCompoundTraderWorkflowTasks(segment); len(tasks) > 1 { + return tasks + } + if tasks := classifyCompoundModelWorkflowTasks(segment); len(tasks) > 1 { + return tasks + } + if tasks := classifyCompoundExchangeWorkflowTasks(segment); len(tasks) > 1 { + return tasks + } + return nil +} + +func classifyContextualStrategyWorkflowTasks(text string) []WorkflowTask { + lower := strings.ToLower(strings.TrimSpace(text)) + hasConfig := containsAny(lower, []string{"修改", "更新", "参数", "配置", "prompt", "提示词", "改成", "改为"}) + hasActivate := containsAny(lower, []string{"激活", "activate"}) + hasDuplicate := containsAny(lower, []string{"复制", "duplicate"}) + if !hasConfig && !hasActivate && !hasDuplicate { + return nil + } + var tasks []WorkflowTask + if hasConfig { + action := "update_config" + if containsAny(lower, []string{"prompt", "提示词"}) { + action = "update_prompt" + } + tasks = append(tasks, WorkflowTask{Skill: "strategy_management", Action: action, Request: text}) + } + if hasActivate { + tasks = append(tasks, WorkflowTask{Skill: "strategy_management", Action: "activate", Request: text}) + } + if hasDuplicate { + tasks = append(tasks, WorkflowTask{Skill: "strategy_management", Action: "duplicate", Request: text}) + } + if len(tasks) == 0 { + return nil + } + return tasks +} + +func classifyContextualTraderWorkflowTasks(text string) []WorkflowTask { + lower := strings.ToLower(strings.TrimSpace(text)) + hasUpdate := containsAny(lower, []string{"修改", "更新", "换模型", "换交易所", "换策略", "切换模型", "切换交易所", "切换策略", "扫描间隔", "全仓", "逐仓", "竞技场"}) + hasStart := containsAny(lower, []string{"启动", "开始", "run", "start"}) + hasStop := containsAny(lower, []string{"停止", "停掉", "stop", "pause"}) + if !hasUpdate && !hasStart && !hasStop { + return nil + } + var tasks []WorkflowTask + if hasUpdate { + tasks = append(tasks, WorkflowTask{Skill: "trader_management", Action: "update_bindings", Request: text}) + } + if hasStart { + tasks = append(tasks, WorkflowTask{Skill: "trader_management", Action: "start", Request: text}) + } + if hasStop { + tasks = append(tasks, WorkflowTask{Skill: "trader_management", Action: "stop", Request: text}) + } + if len(tasks) == 0 { + return nil + } + return tasks +} + +func classifyWorkflowTaskWithContext(text, previousSkill string) (WorkflowTask, bool) { + if task, ok := classifyWorkflowTask(text); ok { + return task, true + } + switch strings.TrimSpace(previousSkill) { + case "strategy_management": + if tasks := classifyContextualStrategyWorkflowTasks(text); len(tasks) > 0 { + return tasks[0], true + } + case "trader_management": + if tasks := classifyContextualTraderWorkflowTasks(text); len(tasks) > 0 { + return tasks[0], true + } + } + return WorkflowTask{}, false +} + +func classifyCompoundStrategyWorkflowTasks(text string) []WorkflowTask { + if !hasExplicitManagementDomainCue(text, "strategy") { + return nil + } + lower := strings.ToLower(strings.TrimSpace(text)) + hasCreate := containsAny(lower, []string{"创建", "新建", "创一个", "创个", "加一个", "create", "new"}) + hasConfig := containsAny(lower, []string{"修改", "更新", "参数", "配置", "prompt", "提示词", "改成", "改为"}) + hasActivate := containsAny(lower, []string{"激活", "activate"}) + hasDuplicate := containsAny(lower, []string{"复制", "duplicate"}) + + if !hasCreate && !hasConfig && !hasActivate && !hasDuplicate { + return nil + } + + var tasks []WorkflowTask + if hasCreate { + tasks = append(tasks, WorkflowTask{Skill: "strategy_management", Action: "create", Request: text}) + } + if hasConfig { + action := "update_config" + if containsAny(lower, []string{"prompt", "提示词"}) { + action = "update_prompt" + } + tasks = append(tasks, WorkflowTask{Skill: "strategy_management", Action: action, Request: text}) + } + if hasActivate { + tasks = append(tasks, WorkflowTask{Skill: "strategy_management", Action: "activate", Request: text}) + } + if hasDuplicate { + tasks = append(tasks, WorkflowTask{Skill: "strategy_management", Action: "duplicate", Request: text}) + } + if len(tasks) <= 1 { + return nil + } + return tasks +} + +func classifyCompoundTraderWorkflowTasks(text string) []WorkflowTask { + if !(hasExplicitManagementDomainCue(text, "trader") || hasExplicitCreateIntentForDomain(text, "trader")) { + return nil + } + lower := strings.ToLower(strings.TrimSpace(text)) + hasCreate := containsAny(lower, []string{"创建", "新建", "创一个", "创个", "create", "new"}) + hasUpdate := containsAny(lower, []string{"修改", "更新", "换模型", "换交易所", "换策略", "切换模型", "切换交易所", "切换策略", "扫描间隔", "全仓", "逐仓", "竞技场"}) + hasStart := containsAny(lower, []string{"启动", "开始", "run", "start"}) + hasStop := containsAny(lower, []string{"停止", "停掉", "stop", "pause"}) + + var tasks []WorkflowTask + if hasCreate { + tasks = append(tasks, WorkflowTask{Skill: "trader_management", Action: "create", Request: text}) + } + if hasUpdate { + tasks = append(tasks, WorkflowTask{Skill: "trader_management", Action: "update_bindings", Request: text}) + } + if hasStart { + tasks = append(tasks, WorkflowTask{Skill: "trader_management", Action: "start", Request: text}) + } + if hasStop { + tasks = append(tasks, WorkflowTask{Skill: "trader_management", Action: "stop", Request: text}) + } + if len(tasks) <= 1 { + return nil + } + return tasks +} + +func classifyCompoundModelWorkflowTasks(text string) []WorkflowTask { + if !hasExplicitManagementDomainCue(text, "model") { + return nil + } + lower := strings.ToLower(strings.TrimSpace(text)) + hasCreate := containsAny(lower, []string{"创建", "新建", "创一个", "创个", "create", "new"}) + hasConfig := containsAny(lower, []string{"修改", "更新", "改", "接口地址", "模型名", "api key"}) + hasStatus := containsAny(lower, []string{"启用", "禁用", "enable", "disable"}) + + var tasks []WorkflowTask + if hasCreate { + tasks = append(tasks, WorkflowTask{Skill: "model_management", Action: "create", Request: text}) + } + if hasConfig { + action := "update_endpoint" + tasks = append(tasks, WorkflowTask{Skill: "model_management", Action: action, Request: text}) + } + if hasStatus { + tasks = append(tasks, WorkflowTask{Skill: "model_management", Action: "update_status", Request: text}) + } + if len(tasks) <= 1 { + return nil + } + return tasks +} + +func classifyCompoundExchangeWorkflowTasks(text string) []WorkflowTask { + if !hasExplicitManagementDomainCue(text, "exchange") { + return nil + } + lower := strings.ToLower(strings.TrimSpace(text)) + hasCreate := containsAny(lower, []string{"创建", "新建", "创一个", "创个", "create", "new"}) + hasConfig := containsAny(lower, []string{"修改", "更新", "改", "账户名", "api key", "secret", "passphrase", "钱包"}) + hasStatus := containsAny(lower, []string{"启用", "禁用", "enable", "disable"}) + + var tasks []WorkflowTask + if hasCreate { + tasks = append(tasks, WorkflowTask{Skill: "exchange_management", Action: "create", Request: text}) + } + if hasConfig { + tasks = append(tasks, WorkflowTask{Skill: "exchange_management", Action: "update_name", Request: text}) + } + if hasStatus { + tasks = append(tasks, WorkflowTask{Skill: "exchange_management", Action: "update_status", Request: text}) + } + if len(tasks) <= 1 { + return nil + } + return tasks +} + func splitWorkflowSegments(text string) []string { parts := []string{strings.TrimSpace(text)} separators := []string{",", ",", "然后", "再", "并且", "同时", " and then ", " then ", " and "} @@ -490,27 +861,94 @@ func classifyWorkflowTask(text string) (WorkflowTask, bool) { if segment == "" { return WorkflowTask{}, false } + lower := strings.ToLower(segment) switch { - case detectCreateTraderSkill(segment): + case hasExplicitCreateIntentForDomain(segment, "trader"): return WorkflowTask{Skill: "trader_management", Action: "create", Request: segment}, true - case detectTraderManagementIntent(segment): - action := normalizeAtomicSkillAction("trader_management", detectManagementAction(segment, "trader")) + case hasExplicitManagementDomainCue(segment, "trader"): + action := "" + switch { + case containsAny(lower, []string{"创建", "新建", "创一个", "创个", "create", "new"}): + action = "create" + case containsAny(lower, []string{"启动", "开始", "run", "start"}): + action = "start" + case containsAny(lower, []string{"停止", "停掉", "stop", "pause"}): + action = "stop" + case containsAny(lower, []string{"删除", "删了", "删掉", "delete"}): + action = "delete" + case containsAny(lower, []string{"换模型", "换交易所", "换策略", "切换模型", "切换交易所", "切换策略", "扫描间隔", "全仓", "逐仓", "竞技场"}): + action = "update_bindings" + case containsAny(lower, []string{"修改", "更新", "改"}): + action = "update_bindings" + case containsAny(lower, []string{"详情", "配置", "参数", "what", "detail"}): + action = "query_detail" + case containsAny(lower, []string{"列表", "全部", "哪些", "list"}): + action = "query_list" + } if supportedWorkflowSkill("trader_management", action) { return WorkflowTask{Skill: "trader_management", Action: action, Request: segment}, true } - case detectExchangeManagementIntent(segment): - action := normalizeAtomicSkillAction("exchange_management", detectManagementAction(segment, "exchange")) + case hasExplicitManagementDomainCue(segment, "exchange"): + action := "" + switch { + case containsAny(lower, []string{"创建", "新建", "创一个", "创个", "create", "new"}): + action = "create" + case containsAny(lower, []string{"启用", "enable", "禁用", "disable"}): + action = "update_status" + case containsAny(lower, []string{"删除", "删了", "删掉", "delete"}): + action = "delete" + case containsAny(lower, []string{"修改", "更新", "改", "账户名", "api key", "secret", "passphrase", "钱包"}): + action = "update" + case containsAny(lower, []string{"详情", "配置", "参数", "what", "detail"}): + action = "query_detail" + case containsAny(lower, []string{"列表", "全部", "哪些", "list"}): + action = "query_list" + } if supportedWorkflowSkill("exchange_management", action) { return WorkflowTask{Skill: "exchange_management", Action: action, Request: segment}, true } - case detectModelManagementIntent(segment): - action := normalizeAtomicSkillAction("model_management", detectManagementAction(segment, "model")) + case hasExplicitManagementDomainCue(segment, "model"): + action := "" + switch { + case containsAny(lower, []string{"创建", "新建", "创一个", "创个", "create", "new"}): + action = "create" + case containsAny(lower, []string{"启用", "enable", "禁用", "disable"}): + action = "update_status" + case containsAny(lower, []string{"删除", "删了", "删掉", "delete"}): + action = "delete" + case containsAny(lower, []string{"接口地址", "endpoint", "url"}): + action = "update_endpoint" + case containsAny(lower, []string{"修改", "更新", "改", "模型名", "api key"}): + action = "update" + case containsAny(lower, []string{"详情", "配置", "参数", "what", "detail"}): + action = "query_detail" + case containsAny(lower, []string{"列表", "全部", "哪些", "list"}): + action = "query_list" + } if supportedWorkflowSkill("model_management", action) { return WorkflowTask{Skill: "model_management", Action: action, Request: segment}, true } - case detectStrategyManagementIntent(segment): - action := normalizeAtomicSkillAction("strategy_management", detectManagementAction(segment, "strategy")) - if action == "" && wantsStrategyDetails(segment) { + case hasExplicitManagementDomainCue(segment, "strategy"): + action := "" + switch { + case containsAny(lower, []string{"创建", "新建", "创一个", "创个", "create", "new"}): + action = "create" + case containsAny(lower, []string{"激活", "activate"}): + action = "activate" + case containsAny(lower, []string{"复制", "duplicate"}): + action = "duplicate" + case containsAny(lower, []string{"删除", "删了", "删掉", "delete"}): + action = "delete" + case containsAny(lower, []string{"prompt", "提示词"}): + action = "update_prompt" + case containsAny(lower, []string{"修改", "更新", "改", "参数", "配置"}): + action = "update_config" + case containsAny(lower, []string{"详情", "配置", "参数", "what", "detail"}) || hasExplicitStrategyDetailIntent(segment): + action = "query_detail" + case containsAny(lower, []string{"列表", "全部", "哪些", "list"}): + action = "query_list" + } + if action == "" && hasExplicitStrategyDetailIntent(segment) { action = "query_detail" } if supportedWorkflowSkill("strategy_management", action) { diff --git a/agent/workflow_test.go b/agent/workflow_test.go deleted file mode 100644 index bffed9bb..00000000 --- a/agent/workflow_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package agent - -import "testing" - -func TestSplitWorkflowSegments(t *testing.T) { - got := splitWorkflowSegments("把策略删了,再把交易所改名") - if len(got) != 2 { - t.Fatalf("expected 2 segments, got %d: %#v", len(got), got) - } -} - -func TestClassifyWorkflowTask(t *testing.T) { - task, ok := classifyWorkflowTask("把策略删了") - if !ok { - t.Fatal("expected task") - } - if task.Skill != "strategy_management" || task.Action != "delete" { - t.Fatalf("unexpected task: %+v", task) - } -} - -func TestFallbackWorkflowDecompositionBuildsTwoTasks(t *testing.T) { - a := &Agent{} - out := a.decomposeWorkflowIntentFallback("把策略删了,再把交易所改名") - if len(out.Tasks) != 2 { - t.Fatalf("expected 2 tasks, got %d", len(out.Tasks)) - } - if out.Tasks[0].Skill != "strategy_management" { - t.Fatalf("unexpected first task: %+v", out.Tasks[0]) - } - if out.Tasks[1].Skill != "exchange_management" { - t.Fatalf("unexpected second task: %+v", out.Tasks[1]) - } - if len(out.Tasks[1].DependsOn) != 1 || out.Tasks[1].DependsOn[0] != out.Tasks[0].ID { - t.Fatalf("expected dependency on first task, got %+v", out.Tasks[1].DependsOn) - } -} diff --git a/api/agent_preferences.go b/api/agent_preferences.go index 1c188840..f73c3e5f 100644 --- a/api/agent_preferences.go +++ b/api/agent_preferences.go @@ -39,6 +39,10 @@ func (s *Server) handleCreateAgentPreference(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": "text required"}) return } + if len([]rune(strings.TrimSpace(req.Text))) > 500 { + c.JSON(http.StatusBadRequest, gin.H{"error": "text too long"}) + return + } created, err := agent.NewPersistentPreference(req.Text) if err != nil { diff --git a/api/agent_routes.go b/api/agent_routes.go index 91d09a0c..b15f6114 100644 --- a/api/agent_routes.go +++ b/api/agent_routes.go @@ -11,11 +11,27 @@ import ( func (s *Server) RegisterAgentHandler(h *agent.WebHandler) { // Chat requires auth — can trigger trades and access account data s.router.POST("/api/agent/chat", s.authMiddleware(), func(c *gin.Context) { - req := c.Request.WithContext(agent.WithStoreUserID(c.Request.Context(), c.GetString("user_id"))) + isAdmin := c.GetString("user_id") == "admin" + ctx := agent.WithStoreUserID(c.Request.Context(), c.GetString("user_id")) + ctx = agent.WithSessionPolicy(ctx, agent.SessionPolicy{ + Authenticated: true, + IsAdmin: isAdmin, + CanExecuteTrade: true, + CanViewSensitiveSecrets: false, + }) + req := c.Request.WithContext(ctx) h.HandleChat(c.Writer, req) }) s.router.POST("/api/agent/chat/stream", s.authMiddleware(), func(c *gin.Context) { - req := c.Request.WithContext(agent.WithStoreUserID(c.Request.Context(), c.GetString("user_id"))) + isAdmin := c.GetString("user_id") == "admin" + ctx := agent.WithStoreUserID(c.Request.Context(), c.GetString("user_id")) + ctx = agent.WithSessionPolicy(ctx, agent.SessionPolicy{ + Authenticated: true, + IsAdmin: isAdmin, + CanExecuteTrade: true, + CanViewSensitiveSecrets: false, + }) + req := c.Request.WithContext(ctx) h.HandleChatStream(c.Writer, req) }) // Public endpoints — read-only market data diff --git a/api/exchange_account_state.go b/api/exchange_account_state.go index f3079496..91179907 100644 --- a/api/exchange_account_state.go +++ b/api/exchange_account_state.go @@ -319,29 +319,23 @@ func accountAssetForExchange(exchangeType string) string { } func missingExchangeCredentials(exchangeCfg *store.Exchange) (status string, code string, message string, missing bool) { - switch exchangeCfg.ExchangeType { - case "binance", "bybit", "gate", "indodax": - if exchangeCfg.APIKey == "" || exchangeCfg.SecretKey == "" { - return exchangeAccountStatusMissingCredentials, "MISSING_REQUIRED_FIELDS", "API key and secret key are required", true + missingFields := store.MissingRequiredExchangeCredentialFields( + exchangeCfg.ExchangeType, + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + string(exchangeCfg.Passphrase), + exchangeCfg.HyperliquidWalletAddr, + exchangeCfg.AsterUser, + exchangeCfg.AsterSigner, + string(exchangeCfg.AsterPrivateKey), + exchangeCfg.LighterWalletAddr, + string(exchangeCfg.LighterAPIKeyPrivateKey), + ) + if len(missingFields) > 0 { + if len(missingFields) == 1 && missingFields[0] == "exchange_type" { + return exchangeAccountStatusUnavailable, "UNSUPPORTED_EXCHANGE", "Unsupported exchange type", true } - case "okx", "bitget", "kucoin": - if exchangeCfg.APIKey == "" || exchangeCfg.SecretKey == "" || exchangeCfg.Passphrase == "" { - return exchangeAccountStatusMissingCredentials, "MISSING_REQUIRED_FIELDS", "API key, secret key, and passphrase are required", true - } - case "hyperliquid": - if exchangeCfg.APIKey == "" || exchangeCfg.HyperliquidWalletAddr == "" { - return exchangeAccountStatusMissingCredentials, "MISSING_REQUIRED_FIELDS", "Private key and wallet address are required", true - } - case "aster": - if exchangeCfg.AsterUser == "" || exchangeCfg.AsterSigner == "" || exchangeCfg.AsterPrivateKey == "" { - return exchangeAccountStatusMissingCredentials, "MISSING_REQUIRED_FIELDS", "Aster user, signer, and private key are required", true - } - case "lighter": - if exchangeCfg.LighterWalletAddr == "" || exchangeCfg.LighterAPIKeyPrivateKey == "" { - return exchangeAccountStatusMissingCredentials, "MISSING_REQUIRED_FIELDS", "Wallet address and API key private key are required", true - } - default: - return exchangeAccountStatusUnavailable, "UNSUPPORTED_EXCHANGE", "Unsupported exchange type", true + return exchangeAccountStatusMissingCredentials, "MISSING_REQUIRED_FIELDS", "Missing required fields: " + strings.Join(missingFields, ", "), true } return "", "", "", false diff --git a/api/handler_ai_model.go b/api/handler_ai_model.go index ffbb3fdc..f19b759c 100644 --- a/api/handler_ai_model.go +++ b/api/handler_ai_model.go @@ -10,6 +10,7 @@ import ( "nofx/crypto" "nofx/logger" "nofx/security" + "nofx/store" "nofx/wallet" "github.com/gin-gonic/gin" @@ -77,8 +78,11 @@ func (s *Server) handleGetModelConfigs(c *gin.Context) { logger.Infof("✅ Found %d AI model configs", len(models)) // Convert to safe response structure, remove sensitive information - safeModels := make([]SafeModelConfig, len(models)) - for i, model := range models { + safeModels := make([]SafeModelConfig, 0, len(models)) + for _, model := range models { + if !store.IsVisibleAIModel(model) { + continue + } safeModel := SafeModelConfig{ ID: model.ID, Name: model.Name, @@ -100,7 +104,23 @@ func (s *Server) handleGetModelConfigs(c *gin.Context) { } } - safeModels[i] = safeModel + safeModels = append(safeModels, safeModel) + } + + if len(safeModels) == 0 { + logger.Infof("⚠️ No visible AI models in database, returning defaults") + defaultModels := []SafeModelConfig{ + {ID: "deepseek", Name: "DeepSeek AI", Provider: "deepseek", Enabled: false, HasAPIKey: false}, + {ID: "qwen", Name: "Qwen AI", Provider: "qwen", Enabled: false, HasAPIKey: false}, + {ID: "openai", Name: "OpenAI", Provider: "openai", Enabled: false, HasAPIKey: false}, + {ID: "claude", Name: "Claude AI", Provider: "claude", Enabled: false, HasAPIKey: false}, + {ID: "gemini", Name: "Gemini AI", Provider: "gemini", Enabled: false, HasAPIKey: false}, + {ID: "grok", Name: "Grok AI", Provider: "grok", Enabled: false, HasAPIKey: false}, + {ID: "kimi", Name: "Kimi AI", Provider: "kimi", Enabled: false, HasAPIKey: false}, + {ID: "minimax", Name: "MiniMax AI", Provider: "minimax", Enabled: false, HasAPIKey: false}, + } + c.JSON(http.StatusOK, defaultModels) + return } c.JSON(http.StatusOK, safeModels) @@ -217,10 +237,12 @@ func (s *Server) handleGetSupportedModels(c *gin.Context) { {"id": "qwen", "name": "Qwen", "provider": "qwen", "defaultModel": "qwen3-max"}, {"id": "openai", "name": "OpenAI", "provider": "openai", "defaultModel": "gpt-5.1"}, {"id": "claude", "name": "Claude", "provider": "claude", "defaultModel": "claude-opus-4-6"}, - {"id": "gemini", "name": "Google Gemini", "provider": "gemini", "defaultModel": "gemini-3.1-pro"}, + {"id": "gemini", "name": "Google Gemini", "provider": "gemini", "defaultModel": "gemini-3-pro-preview"}, {"id": "grok", "name": "Grok (xAI)", "provider": "grok", "defaultModel": "grok-3-latest"}, {"id": "kimi", "name": "Kimi (Moonshot)", "provider": "kimi", "defaultModel": "moonshot-v1-auto"}, {"id": "minimax", "name": "MiniMax", "provider": "minimax", "defaultModel": "MiniMax-M2.7"}, + {"id": "blockrun-base", "name": "BlockRun (Base Wallet)", "provider": "blockrun-base", "defaultModel": "auto"}, + {"id": "blockrun-sol", "name": "BlockRun (Solana Wallet)", "provider": "blockrun-sol", "defaultModel": "auto"}, {"id": "claw402", "name": "Claw402 (Base USDC)", "provider": "claw402", "defaultModel": "deepseek-v4-flash"}, } diff --git a/api/handler_exchange.go b/api/handler_exchange.go index 0fb8650a..217aacaa 100644 --- a/api/handler_exchange.go +++ b/api/handler_exchange.go @@ -4,10 +4,12 @@ import ( "encoding/json" "fmt" "net/http" + "strings" "nofx/config" "nofx/crypto" "nofx/logger" + "nofx/store" "github.com/gin-gonic/gin" ) @@ -30,11 +32,39 @@ type SafeExchangeConfig struct { Name string `json:"name"` // Display name Type string `json:"type"` // "cex" or "dex" Enabled bool `json:"enabled"` + HasAPIKey bool `json:"has_api_key"` + HasSecretKey bool `json:"has_secret_key"` + HasPassphrase bool `json:"has_passphrase"` Testnet bool `json:"testnet,omitempty"` HyperliquidWalletAddr string `json:"hyperliquidWalletAddr"` // Hyperliquid wallet address (not sensitive) - AsterUser string `json:"asterUser"` // Aster username (not sensitive) - AsterSigner string `json:"asterSigner"` // Aster signer (not sensitive) - LighterWalletAddr string `json:"lighterWalletAddr"` // LIGHTER wallet address (not sensitive) + HasAsterPrivateKey bool `json:"has_aster_private_key"` + AsterUser string `json:"asterUser"` // Aster username (not sensitive) + AsterSigner string `json:"asterSigner"` // Aster signer (not sensitive) + LighterWalletAddr string `json:"lighterWalletAddr"` // LIGHTER wallet address (not sensitive) + HasLighterPrivateKey bool `json:"has_lighter_private_key"` + HasLighterAPIKey bool `json:"has_lighter_api_key_private_key"` +} + +func safeExchangeConfigFromStore(exchange *store.Exchange) SafeExchangeConfig { + return SafeExchangeConfig{ + ID: exchange.ID, + ExchangeType: exchange.ExchangeType, + AccountName: exchange.AccountName, + Name: exchange.Name, + Type: exchange.Type, + Enabled: exchange.Enabled, + HasAPIKey: exchange.APIKey != "", + HasSecretKey: exchange.SecretKey != "", + HasPassphrase: exchange.Passphrase != "", + Testnet: exchange.Testnet, + HyperliquidWalletAddr: exchange.HyperliquidWalletAddr, + HasAsterPrivateKey: exchange.AsterPrivateKey != "", + AsterUser: exchange.AsterUser, + AsterSigner: exchange.AsterSigner, + LighterWalletAddr: exchange.LighterWalletAddr, + HasLighterPrivateKey: exchange.LighterPrivateKey != "", + HasLighterAPIKey: exchange.LighterAPIKeyPrivateKey != "", + } } type UpdateExchangeConfigRequest struct { @@ -96,21 +126,12 @@ func (s *Server) handleGetExchangeConfigs(c *gin.Context) { logger.Infof("✅ Found %d exchange configs", len(exchanges)) // Convert to safe response structure, remove sensitive information - safeExchanges := make([]SafeExchangeConfig, len(exchanges)) - for i, exchange := range exchanges { - safeExchanges[i] = SafeExchangeConfig{ - ID: exchange.ID, - ExchangeType: exchange.ExchangeType, - AccountName: exchange.AccountName, - Name: exchange.Name, - Type: exchange.Type, - Enabled: exchange.Enabled, - Testnet: exchange.Testnet, - HyperliquidWalletAddr: exchange.HyperliquidWalletAddr, - AsterUser: exchange.AsterUser, - AsterSigner: exchange.AsterSigner, - LighterWalletAddr: exchange.LighterWalletAddr, + safeExchanges := make([]SafeExchangeConfig, 0, len(exchanges)) + for _, exchange := range exchanges { + if !store.IsVisibleExchange(exchange) { + continue } + safeExchanges = append(safeExchanges, safeExchangeConfigFromStore(exchange)) } c.JSON(http.StatusOK, safeExchanges) @@ -179,13 +200,73 @@ func (s *Server) handleUpdateExchangeConfigs(c *gin.Context) { // Update each exchange's configuration and track traders that need reload tradersToReload := make(map[string]bool) for exchangeID, exchangeData := range req.Exchanges { + existing, err := s.store.Exchange().GetByID(userID, exchangeID) + if err != nil { + SafeInternalError(c, fmt.Sprintf("Load exchange %s", exchangeID), err) + return + } + effectiveAPIKey := strings.TrimSpace(exchangeData.APIKey) + if effectiveAPIKey == "" { + effectiveAPIKey = strings.TrimSpace(string(existing.APIKey)) + } + effectiveSecretKey := strings.TrimSpace(exchangeData.SecretKey) + if effectiveSecretKey == "" { + effectiveSecretKey = strings.TrimSpace(string(existing.SecretKey)) + } + effectivePassphrase := strings.TrimSpace(exchangeData.Passphrase) + if effectivePassphrase == "" { + effectivePassphrase = strings.TrimSpace(string(existing.Passphrase)) + } + effectiveAsterPrivateKey := strings.TrimSpace(exchangeData.AsterPrivateKey) + if effectiveAsterPrivateKey == "" { + effectiveAsterPrivateKey = strings.TrimSpace(string(existing.AsterPrivateKey)) + } + effectiveLighterAPIKeyPrivateKey := strings.TrimSpace(exchangeData.LighterAPIKeyPrivateKey) + if effectiveLighterAPIKeyPrivateKey == "" { + effectiveLighterAPIKeyPrivateKey = strings.TrimSpace(string(existing.LighterAPIKeyPrivateKey)) + } + effectiveHyperliquidWalletAddr := strings.TrimSpace(exchangeData.HyperliquidWalletAddr) + if effectiveHyperliquidWalletAddr == "" { + effectiveHyperliquidWalletAddr = strings.TrimSpace(existing.HyperliquidWalletAddr) + } + effectiveAsterUser := strings.TrimSpace(exchangeData.AsterUser) + if effectiveAsterUser == "" { + effectiveAsterUser = strings.TrimSpace(existing.AsterUser) + } + effectiveAsterSigner := strings.TrimSpace(exchangeData.AsterSigner) + if effectiveAsterSigner == "" { + effectiveAsterSigner = strings.TrimSpace(existing.AsterSigner) + } + effectiveLighterWalletAddr := strings.TrimSpace(exchangeData.LighterWalletAddr) + if effectiveLighterWalletAddr == "" { + effectiveLighterWalletAddr = strings.TrimSpace(existing.LighterWalletAddr) + } + if missing := store.MissingRequiredExchangeCredentialFields( + existing.ExchangeType, + effectiveAPIKey, + effectiveSecretKey, + effectivePassphrase, + effectiveHyperliquidWalletAddr, + effectiveAsterUser, + effectiveAsterSigner, + effectiveAsterPrivateKey, + effectiveLighterWalletAddr, + effectiveLighterAPIKeyPrivateKey, + ); len(missing) > 0 { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Missing required exchange fields: %s", strings.Join(missing, ", ")), + "missing_fields": missing, + }) + return + } + // Find traders using this exchange BEFORE updating traders, _ := s.store.Trader().ListByExchangeID(userID, exchangeID) for _, t := range traders { tradersToReload[t.ID] = true } - err := s.store.Exchange().Update(userID, exchangeID, exchangeData.Enabled, exchangeData.APIKey, exchangeData.SecretKey, exchangeData.Passphrase, exchangeData.Testnet, exchangeData.HyperliquidWalletAddr, exchangeData.HyperliquidUnifiedAcct, exchangeData.AsterUser, exchangeData.AsterSigner, exchangeData.AsterPrivateKey, exchangeData.LighterWalletAddr, exchangeData.LighterPrivateKey, exchangeData.LighterAPIKeyPrivateKey, exchangeData.LighterAPIKeyIndex) + err = s.store.Exchange().Update(userID, exchangeID, true, exchangeData.APIKey, exchangeData.SecretKey, exchangeData.Passphrase, exchangeData.Testnet, effectiveHyperliquidWalletAddr, exchangeData.HyperliquidUnifiedAcct, effectiveAsterUser, effectiveAsterSigner, exchangeData.AsterPrivateKey, effectiveLighterWalletAddr, exchangeData.LighterPrivateKey, exchangeData.LighterAPIKeyPrivateKey, exchangeData.LighterAPIKeyIndex) if err != nil { SafeInternalError(c, fmt.Sprintf("Update exchange %s", exchangeID), err) return @@ -271,10 +352,28 @@ func (s *Server) handleCreateExchange(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid exchange type: %s", req.ExchangeType)}) return } + if missing := store.MissingRequiredExchangeCredentialFields( + req.ExchangeType, + req.APIKey, + req.SecretKey, + req.Passphrase, + req.HyperliquidWalletAddr, + req.AsterUser, + req.AsterSigner, + req.AsterPrivateKey, + req.LighterWalletAddr, + req.LighterAPIKeyPrivateKey, + ); len(missing) > 0 { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Missing required exchange fields: %s", strings.Join(missing, ", ")), + "missing_fields": missing, + }) + return + } - // Create new exchange account + // Exchange configs only persist once complete; persisted configs are always enabled. id, err := s.store.Exchange().Create( - userID, req.ExchangeType, req.AccountName, req.Enabled, + userID, req.ExchangeType, req.AccountName, true, req.APIKey, req.SecretKey, req.Passphrase, req.Testnet, req.HyperliquidWalletAddr, req.HyperliquidUnifiedAcct, req.AsterUser, req.AsterSigner, req.AsterPrivateKey, diff --git a/api/handler_exchange_test.go b/api/handler_exchange_test.go new file mode 100644 index 00000000..4c4da538 --- /dev/null +++ b/api/handler_exchange_test.go @@ -0,0 +1,45 @@ +package api + +import ( + "testing" + + "nofx/crypto" + "nofx/store" +) + +func TestSafeExchangeConfigFromStoreIncludesCredentialPresenceFlags(t *testing.T) { + cfg := &store.Exchange{ + ID: "ex-1", + ExchangeType: "okx", + AccountName: "OKX Main", + Name: "OKX Main", + Type: "cex", + Enabled: true, + APIKey: crypto.EncryptedString("api-test-123"), + SecretKey: crypto.EncryptedString("secret-test-123"), + Passphrase: crypto.EncryptedString("passphrase-test-123"), + AsterPrivateKey: crypto.EncryptedString("aster-private-key"), + LighterPrivateKey: crypto.EncryptedString("lighter-private-key"), + LighterAPIKeyPrivateKey: crypto.EncryptedString("lighter-api-key-private-key"), + } + + safe := safeExchangeConfigFromStore(cfg) + if !safe.HasAPIKey { + t.Fatalf("expected has_api_key to be true") + } + if !safe.HasSecretKey { + t.Fatalf("expected has_secret_key to be true") + } + if !safe.HasPassphrase { + t.Fatalf("expected has_passphrase to be true") + } + if !safe.HasAsterPrivateKey { + t.Fatalf("expected has_aster_private_key to be true") + } + if !safe.HasLighterPrivateKey { + t.Fatalf("expected has_lighter_private_key to be true") + } + if !safe.HasLighterAPIKey { + t.Fatalf("expected has_lighter_api_key_private_key to be true") + } +} diff --git a/api/handler_trader.go b/api/handler_trader.go index 6b503e61..67d3157e 100644 --- a/api/handler_trader.go +++ b/api/handler_trader.go @@ -14,6 +14,11 @@ import ( "gorm.io/gorm" ) +const ( + maxManualBTCETHLeverage = 20 + maxManualAltLeverage = 20 +) + // AI trader management related structures type CreateTraderRequest struct { Name string `json:"name" binding:"required"` @@ -65,6 +70,16 @@ func traderCreationRequestError(reason string) string { return formatTraderCreationError(reason, "请检查你刚刚填写的内容后,再重新提交") } +func validateTraderLeverageRange(btcEthLeverage, altcoinLeverage int) (string, string) { + if btcEthLeverage < 0 || btcEthLeverage > maxManualBTCETHLeverage { + return traderCreationRequestError("BTC/ETH 杠杆倍数需要在 1 到 20 倍之间"), "trader.create.invalid_btc_eth_leverage" + } + if altcoinLeverage < 0 || altcoinLeverage > maxManualAltLeverage { + return traderCreationRequestError("山寨币杠杆倍数需要在 1 到 20 倍之间"), "trader.create.invalid_altcoin_leverage" + } + return "", "" +} + func exchangeDisplayName(exchange *store.Exchange) string { if exchange == nil { return "所选交易所账户" @@ -306,13 +321,9 @@ func (s *Server) handleCreateTrader(c *gin.Context) { return } - // Validate leverage values - if req.BTCETHLeverage < 0 || req.BTCETHLeverage > 50 { - SafeBadRequestWithDetails(c, traderCreationRequestError("BTC/ETH 杠杆倍数需要在 1 到 50 倍之间"), "trader.create.invalid_btc_eth_leverage", nil) - return - } - if req.AltcoinLeverage < 0 || req.AltcoinLeverage > 20 { - SafeBadRequestWithDetails(c, traderCreationRequestError("山寨币杠杆倍数需要在 1 到 20 倍之间"), "trader.create.invalid_altcoin_leverage", nil) + // Validate leverage values against the same limits exposed by manual user config. + if errMsg, errCode := validateTraderLeverageRange(req.BTCETHLeverage, req.AltcoinLeverage); errMsg != "" { + SafeBadRequestWithDetails(c, errMsg, errCode, nil) return } @@ -574,6 +585,11 @@ func (s *Server) handleUpdateTrader(c *gin.Context) { return } + if errMsg, errCode := validateTraderLeverageRange(req.BTCETHLeverage, req.AltcoinLeverage); errMsg != "" { + SafeBadRequestWithDetails(c, errMsg, errCode, nil) + return + } + // Set default values isCrossMargin := existingTrader.IsCrossMargin // Keep original value if req.IsCrossMargin != nil { diff --git a/api/handler_trader_test.go b/api/handler_trader_test.go new file mode 100644 index 00000000..aee46644 --- /dev/null +++ b/api/handler_trader_test.go @@ -0,0 +1,17 @@ +package api + +import "testing" + +func TestValidateTraderLeverageRangeMatchesManualLimits(t *testing.T) { + if msg, code := validateTraderLeverageRange(20, 20); msg != "" || code != "" { + t.Fatalf("expected 20/20 leverage to be accepted, got msg=%q code=%q", msg, code) + } + + if msg, code := validateTraderLeverageRange(21, 20); msg == "" || code != "trader.create.invalid_btc_eth_leverage" { + t.Fatalf("expected BTC/ETH leverage > 20 to be rejected, got msg=%q code=%q", msg, code) + } + + if msg, code := validateTraderLeverageRange(20, 21); msg == "" || code != "trader.create.invalid_altcoin_leverage" { + t.Fatalf("expected altcoin leverage > 20 to be rejected, got msg=%q code=%q", msg, code) + } +} diff --git a/api/server.go b/api/server.go index 7d4bed95..3c657175 100644 --- a/api/server.go +++ b/api/server.go @@ -259,7 +259,7 @@ CRITICAL: Always use the "id" field for strategy_id.`, IMPORTANT: For most use cases just POST {"name":""} — the backend fills everything in. Only include "config" when the user explicitly requests custom settings (specific coins, custom leverage, custom timeframes). StrategyConfig fields: - coin_source.source_type: "static"(fixed coin list) | "ai500"(AI top500 ranking) | "oi_top"(OI increasing, suited for long) | "oi_low"(OI decreasing, suited for short) | "mixed" + coin_source.source_type: "static"(fixed coin list) | "ai500"(AI top500 ranking) | "oi_top"(OI increasing, suited for long) | "oi_low"(OI decreasing, suited for short) coin_source.static_coins: ["BTCUSDT","ETHUSDT"] — only when source_type="static" coin_source.use_ai500, ai500_limit: number of coins from AI500 pool (default 10) coin_source.use_oi_top/use_oi_low, oi_top_limit/oi_low_limit: OI-based coin selection diff --git a/api/strategy.go b/api/strategy.go index 64105fa5..023cd022 100644 --- a/api/strategy.go +++ b/api/strategy.go @@ -20,6 +20,9 @@ import ( // validateStrategyConfig validates strategy configuration and returns warnings func validateStrategyConfig(config *store.StrategyConfig) []string { var warnings []string + if config.StrategyType == "grid_trading" { + return warnings + } // Validate NofxOS API key if any NofxOS feature is enabled if (config.Indicators.EnableQuantData || config.Indicators.EnableOIRanking || @@ -31,6 +34,17 @@ func validateStrategyConfig(config *store.StrategyConfig) []string { return warnings } +func attachPublishConfig(config *store.StrategyConfig, strategy *store.Strategy) { + if config == nil || strategy == nil { + return + } + config.ClampLimits() + config.PublishConfig = &store.PublishStrategyConfig{ + IsPublic: strategy.IsPublic, + ConfigVisible: strategy.ConfigVisible, + } +} + // handleEstimateTokens estimates token usage for a strategy config (no auth required, pure computation) func (s *Server) handleEstimateTokens(c *gin.Context) { var req struct { @@ -71,6 +85,7 @@ func (s *Server) handlePublicStrategies(c *gin.Context) { if st.ConfigVisible { var config store.StrategyConfig json.Unmarshal([]byte(st.Config), &config) + attachPublishConfig(&config, st) item["config"] = config } @@ -101,6 +116,7 @@ func (s *Server) handleGetStrategies(c *gin.Context) { for _, st := range strategies { var config store.StrategyConfig json.Unmarshal([]byte(st.Config), &config) + attachPublishConfig(&config, st) result = append(result, gin.H{ "id": st.ID, @@ -139,6 +155,7 @@ func (s *Server) handleGetStrategy(c *gin.Context) { var config store.StrategyConfig json.Unmarshal([]byte(strategy.Config), &config) + attachPublishConfig(&config, strategy) c.JSON(http.StatusOK, gin.H{ "id": strategy.ID, @@ -162,10 +179,12 @@ func (s *Server) handleCreateStrategy(c *gin.Context) { } var req struct { - Name string `json:"name" binding:"required"` - Description string `json:"description"` - Lang string `json:"lang"` // "zh" or "en", used when config is omitted - Config *store.StrategyConfig `json:"config"` // optional — uses default if omitted + Name string `json:"name" binding:"required"` + Description string `json:"description"` + Lang string `json:"lang"` // "zh" or "en", used when config is omitted + Config *store.StrategyConfig `json:"config"` // optional — uses default if omitted + IsPublic bool `json:"is_public"` + ConfigVisible bool `json:"config_visible"` } if err := c.ShouldBindJSON(&req); err != nil { @@ -182,6 +201,19 @@ func (s *Server) handleCreateStrategy(c *gin.Context) { defaultCfg := store.GetDefaultStrategyConfig(lang) req.Config = &defaultCfg } + beforeClamp := *req.Config + req.Config.ClampLimits() + hadPublishConfig := req.Config.PublishConfig != nil + isPublic := req.IsPublic + configVisible := req.ConfigVisible + if hadPublishConfig { + isPublic = req.Config.PublishConfig.IsPublic + configVisible = req.Config.PublishConfig.ConfigVisible + } + req.Config.PublishConfig = &store.PublishStrategyConfig{ + IsPublic: isPublic, + ConfigVisible: configVisible, + } // Serialize configuration configJSON, err := json.Marshal(req.Config) @@ -197,7 +229,10 @@ func (s *Server) handleCreateStrategy(c *gin.Context) { Description: req.Description, IsActive: false, IsDefault: false, - Config: string(configJSON), + IsPublic: isPublic, + // Existing default is true; keep that behavior when no explicit publish config is sent. + ConfigVisible: configVisible || !hadPublishConfig, + Config: string(configJSON), } if err := s.store.Strategy().Create(strategy); err != nil { @@ -207,6 +242,7 @@ func (s *Server) handleCreateStrategy(c *gin.Context) { // Validate configuration and collect warnings warnings := validateStrategyConfig(req.Config) + warnings = append(warnings, store.StrategyClampWarnings(beforeClamp, *req.Config, req.Config.Language)...) response := gin.H{ "id": strategy.ID, @@ -263,14 +299,21 @@ func (s *Server) handleUpdateStrategy(c *gin.Context) { mergedConfig = store.StrategyConfig{} } - // Apply incoming config on top: top-level sections present in the request overwrite - // their corresponding existing section; absent sections remain unchanged. + // Apply incoming config on top while preserving nested fields that were not sent. if len(req.Config) > 0 && string(req.Config) != "null" { - if err := json.Unmarshal(req.Config, &mergedConfig); err != nil { + var patch map[string]any + if err := json.Unmarshal(req.Config, &patch); err != nil { + SafeBadRequest(c, "Invalid config JSON") + return + } + mergedConfig, err = store.MergeStrategyConfig(mergedConfig, patch) + if err != nil { SafeBadRequest(c, "Invalid config JSON") return } } + beforeClamp := mergedConfig + mergedConfig.ClampLimits() // Preserve existing name/description when not supplied name := req.Name @@ -324,6 +367,7 @@ func (s *Server) handleUpdateStrategy(c *gin.Context) { // Validate merged configuration and collect warnings warnings := validateStrategyConfig(&mergedConfig) + warnings = append(warnings, store.StrategyClampWarnings(beforeClamp, mergedConfig, mergedConfig.Language)...) response := gin.H{"message": "Strategy updated successfully"} if len(warnings) > 0 { @@ -417,6 +461,7 @@ func (s *Server) handleGetActiveStrategy(c *gin.Context) { var config store.StrategyConfig json.Unmarshal([]byte(strategy.Config), &config) + attachPublishConfig(&config, strategy) c.JSON(http.StatusOK, gin.H{ "id": strategy.ID, diff --git a/main.go b/main.go index b99ead0d..351fcce3 100644 --- a/main.go +++ b/main.go @@ -2,18 +2,18 @@ package main import ( "log/slog" - "nofx/api" nofxiagent "nofx/agent" + "nofx/api" "nofx/auth" "nofx/config" "nofx/crypto" "nofx/logger" "nofx/manager" - "nofx/telemetry" _ "nofx/mcp/payment" _ "nofx/mcp/provider" "nofx/store" "nofx/telegram" + "nofx/telemetry" "os" "os/signal" "path/filepath" @@ -121,10 +121,10 @@ func main() { status = "✅ Running" } idShort := t.ID - if len(idShort) > 8 { - idShort = idShort[:8] - } - logger.Infof(" • %s [%s] %s - AI Model: %s, Exchange: %s", + if len(idShort) > 8 { + idShort = idShort[:8] + } + logger.Infof(" • %s [%s] %s - AI Model: %s, Exchange: %s", t.Name, idShort, status, t.AIModelID, t.ExchangeID) } } @@ -137,20 +137,19 @@ func main() { telegramReloadCh := make(chan struct{}, 1) server.SetTelegramReloadCh(telegramReloadCh) + // Start the NOFXi web agent on top of the current dev branch services. + nofxiAgent := nofxiagent.New(traderManager, st, nil, slog.Default()) + agentWeb := nofxiagent.NewWebHandler(nofxiAgent, slog.Default()) + server.RegisterAgentHandler(agentWeb) + nofxiAgent.Start() + defer nofxiAgent.Stop() + go func() { if err := server.Start(); err != nil { logger.Fatalf("❌ Failed to start API server: %v", err) } }() - // Start the NOFXi web agent on top of the current dev branch services. - nofxiAgent := nofxiagent.New(traderManager, st, nil, slog.Default()) - nofxiAgent.Start() - defer nofxiAgent.Stop() - - agentWeb := nofxiagent.NewWebHandler(nofxiAgent, slog.Default()) - server.RegisterAgentHandler(agentWeb) - // Start Telegram bot (if TELEGRAM_BOT_TOKEN is configured) go telegram.Start(cfg, st, telegramReloadCh) diff --git a/mcp/client.go b/mcp/client.go index 6916de3c..97b4cded 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -32,10 +32,13 @@ var ( "no such host", "stream error", // HTTP/2 stream error "INTERNAL_ERROR", // Server internal error - "status 502", // Bad Gateway - "status 503", // Service Unavailable - "status 520", // Cloudflare origin error - "status 524", // Cloudflare timeout + "status 429", // Rate limit / upstream gateway throttling + "rate_limit_error", + "upstream_empty_output", + "status 502", // Bad Gateway + "status 503", // Service Unavailable + "status 520", // Cloudflare origin error + "status 524", // Cloudflare timeout } // TokenUsageCallback is called after each AI request with token usage info @@ -197,7 +200,9 @@ func (client *Client) CallWithMessages(systemPrompt, userPrompt string) (string, if attempt < maxRetries { waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt) client.Log.Infof("⏳ Waiting %v before retry...", waitTime) - time.Sleep(waitTime) + if err := sleepWithContext(context.Background(), waitTime); err != nil { + return "", err + } } } @@ -332,6 +337,38 @@ func (client *Client) BuildRequest(url string, jsonData []byte) (*http.Request, return req, nil } +func contextFromRequest(req *Request) context.Context { + if req != nil && req.Ctx != nil { + return req.Ctx + } + return context.Background() +} + +func (client *Client) buildHTTPRequestWithContext(ctx context.Context, url string, jsonData []byte) (*http.Request, error) { + if ctx == nil { + ctx = context.Background() + } + httpReq, err := client.Hooks.BuildRequest(url, jsonData) + if err != nil { + return nil, err + } + return httpReq.WithContext(ctx), nil +} + +func sleepWithContext(ctx context.Context, d time.Duration) error { + if ctx == nil { + ctx = context.Background() + } + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-timer.C: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + // Call single AI API call (fixed flow, cannot be overridden) func (client *Client) Call(systemPrompt, userPrompt string) (string, error) { // Print current AI configuration @@ -450,7 +487,9 @@ func (client *Client) CallWithRequest(req *Request) (string, error) { if attempt < maxRetries { waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt) client.Log.Infof("⏳ Waiting %v before retry...", waitTime) - time.Sleep(waitTime) + if err := sleepWithContext(contextFromRequest(req), waitTime); err != nil { + return "", err + } } } @@ -482,7 +521,9 @@ func (client *Client) CallWithRequestFull(req *Request) (*LLMResponse, error) { } if attempt < maxRetries { waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt) - time.Sleep(waitTime) + if err := sleepWithContext(contextFromRequest(req), waitTime); err != nil { + return nil, err + } } } return nil, fmt.Errorf("still failed after %d retries: %w", maxRetries, lastErr) @@ -499,7 +540,7 @@ func (client *Client) callWithRequestFull(req *Request) (*LLMResponse, error) { } url := client.Hooks.BuildUrl() - httpReq, err := client.Hooks.BuildRequest(url, jsonData) + httpReq, err := client.buildHTTPRequestWithContext(contextFromRequest(req), url, jsonData) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } @@ -537,7 +578,7 @@ func (client *Client) callWithRequest(req *Request) (string, error) { url := client.Hooks.BuildUrl() client.Log.Infof("📡 [MCP %s] Request URL: %s", client.String(), url) - httpReq, err := client.Hooks.BuildRequest(url, jsonData) + httpReq, err := client.buildHTTPRequestWithContext(contextFromRequest(req), url, jsonData) if err != nil { return "", fmt.Errorf("failed to create request: %w", err) } @@ -679,7 +720,7 @@ func (client *Client) CallWithRequestStream(req *Request, onChunk func(string)) } url := client.Hooks.BuildUrl() - httpReq, err := client.Hooks.BuildRequest(url, jsonData) + httpReq, err := client.buildHTTPRequestWithContext(contextFromRequest(req), url, jsonData) if err != nil { return "", err } @@ -687,7 +728,7 @@ func (client *Client) CallWithRequestStream(req *Request, onChunk func(string)) // Idle-timeout watchdog: cancel the request if no SSE line arrives for 60 seconds. // This breaks the scanner out of an indefinitely blocking Read on a hung connection. const idleTimeout = 60 * time.Second - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(contextFromRequest(req)) defer cancel() resetCh := make(chan struct{}, 1) go func() { diff --git a/mcp/client_test.go b/mcp/client_test.go index b76890dd..1fadeac9 100644 --- a/mcp/client_test.go +++ b/mcp/client_test.go @@ -345,6 +345,11 @@ func TestClient_IsRetryableError(t *testing.T) { err: errors.New("connection reset by peer"), expected: true, }, + { + name: "upstream empty output", + err: errors.New(`API returned error (status 429): {"error":{"code":"upstream_empty_output","message":"Upstream model returned empty output.","type":"rate_limit_error"}}`), + expected: true, + }, { name: "normal error", err: errors.New("bad request"), diff --git a/mcp/payment/x402.go b/mcp/payment/x402.go index 577da51f..4f2e460f 100644 --- a/mcp/payment/x402.go +++ b/mcp/payment/x402.go @@ -35,13 +35,34 @@ const ( X402Timeout = 5 * time.Minute ) +func x402ContextFromRequest(req *mcp.Request) context.Context { + if req != nil && req.Ctx != nil { + return req.Ctx + } + return context.Background() +} + +func x402Sleep(ctx context.Context, d time.Duration) error { + if ctx == nil { + ctx = context.Background() + } + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-timer.C: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + // ── Shared x402 types ──────────────────────────────────────────────────────── // X402v2PaymentRequired is the structure of the Payment-Required header (x402 v2). type X402v2PaymentRequired struct { - X402Version int `json:"x402Version"` + X402Version int `json:"x402Version"` Accepts []X402AcceptOption `json:"accepts"` - Resource *X402Resource `json:"resource"` + Resource *X402Resource `json:"resource"` } // X402AcceptOption is a payment option from the x402 v2 header. @@ -114,16 +135,21 @@ func SignBasePaymentHeader(privateKey *ecdsa.PrivateKey, paymentHeaderB64 string // DoX402Request executes an HTTP request and handles the x402 v2 payment flow. func DoX402Request( + ctx context.Context, httpClient *http.Client, buildReqFn func() (*http.Request, error), signFn X402SignFunc, providerTag string, logger mcp.Logger, ) ([]byte, error) { + if ctx == nil { + ctx = context.Background() + } req, err := buildReqFn() if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } + req = req.WithContext(ctx) resp, err := httpClient.Do(req) if err != nil { @@ -157,6 +183,7 @@ func DoX402Request( if err != nil { return nil, fmt.Errorf("failed to build retry request: %w", err) } + req2 = req2.WithContext(ctx) req2.Header.Set("X-Payment", paymentSig) req2.Header.Set("Payment-Signature", paymentSig) @@ -166,7 +193,9 @@ func DoX402Request( wait := X402RetryBaseWait * time.Duration(attempt) logger.Warnf("⚠️ [%s] Payment request failed: %v, retrying in %v (%d/%d)...", providerTag, err, wait, attempt+1, X402MaxPaymentRetries) - time.Sleep(wait) + if err := x402Sleep(ctx, wait); err != nil { + return nil, err + } continue } return nil, fmt.Errorf("failed to send payment retry: %w", err) @@ -221,7 +250,9 @@ func DoX402Request( providerTag, resp2.StatusCode, wait, attempt+1, X402MaxPaymentRetries) } - time.Sleep(wait) + if err := x402Sleep(ctx, wait); err != nil { + return nil, err + } continue } @@ -256,11 +287,15 @@ func DoX402RequestStream( providerTag string, logger mcp.Logger, ) (*http.Response, error) { - // Initial request — use background context (no idle timeout yet). + if ctx == nil { + ctx = context.Background() + } + // Initial request also inherits ctx so stage timeouts cancel the 402 handshake. req, err := buildReqFn() if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } + req = req.WithContext(ctx) resp, err := httpClient.Do(req) if err != nil { @@ -314,7 +349,9 @@ func DoX402RequestStream( wait := X402RetryBaseWait * time.Duration(attempt) logger.Warnf("⚠️ [%s] Payment request failed: %v, retrying in %v (%d/%d)...", providerTag, err, wait, attempt+1, X402MaxPaymentRetries) - time.Sleep(wait) + if err := x402Sleep(ctx, wait); err != nil { + return nil, err + } continue } return nil, fmt.Errorf("failed to send payment retry: %w", err) @@ -369,7 +406,9 @@ func DoX402RequestStream( providerTag, resp2.StatusCode, wait, attempt+1, X402MaxPaymentRetries) } - time.Sleep(wait) + if err := x402Sleep(ctx, wait); err != nil { + return nil, err + } continue } @@ -500,7 +539,7 @@ func X402Call(c *mcp.Client, signFn X402SignFunc, tag string, systemPrompt, user return "", err } - body, err := DoX402Request(c.HTTPClient, func() (*http.Request, error) { + body, err := DoX402Request(context.Background(), c.HTTPClient, func() (*http.Request, error) { return c.Hooks.BuildRequest(c.Hooks.BuildUrl(), jsonData) }, signFn, tag, c.Log) if err != nil { @@ -526,7 +565,7 @@ func X402CallFull(c *mcp.Client, signFn X402SignFunc, tag string, req *mcp.Reque return nil, err } - body, err := DoX402Request(c.HTTPClient, func() (*http.Request, error) { + body, err := DoX402Request(x402ContextFromRequest(req), c.HTTPClient, func() (*http.Request, error) { return c.Hooks.BuildRequest(c.Hooks.BuildUrl(), jsonData) }, signFn, tag, c.Log) if err != nil { diff --git a/mcp/request_builder_test.go b/mcp/request_builder_test.go index 4ec10a9f..176dc040 100644 --- a/mcp/request_builder_test.go +++ b/mcp/request_builder_test.go @@ -1,8 +1,13 @@ package mcp import ( + "context" "encoding/json" + "io" + "net/http" + "strings" "testing" + "time" ) // ============================================================ @@ -342,6 +347,110 @@ func TestClient_CallWithRequest_Success(t *testing.T) { } } +func TestClient_CallWithRequest_AttachesRequestContextToHTTP(t *testing.T) { + type contextKey string + const key contextKey = "stage" + ctx := context.WithValue(context.Background(), key, "planner") + + mockHTTP := NewMockHTTPClient() + mockHTTP.ResponseFunc = func(req *http.Request) (*http.Response, error) { + if req.Context().Value(key) != "planner" { + t.Fatalf("expected HTTP request to inherit mcp.Request context") + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"choices":[{"message":{"content":"ok"}}]}`)), + Header: make(http.Header), + }, nil + } + + client := NewClient( + WithHTTPClient(mockHTTP.ToHTTPClient()), + WithLogger(NewMockLogger()), + WithAPIKey("sk-test-key"), + ) + request := NewRequestBuilder().WithUserPrompt("Hello").MustBuild() + request.Ctx = ctx + + result, err := client.CallWithRequest(request) + if err != nil { + t.Fatalf("should not error: %v", err) + } + if result != "ok" { + t.Fatalf("expected ok, got %q", result) + } +} + +func TestClient_CallWithRequest_RetrySleepStopsWhenContextCancelled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + mockHTTP := NewMockHTTPClient() + mockHTTP.SetNetworkError(io.EOF) + client := NewClient( + WithHTTPClient(mockHTTP.ToHTTPClient()), + WithLogger(NewMockLogger()), + WithAPIKey("sk-test-key"), + WithMaxRetries(2), + WithRetryWaitBase(time.Hour), + ) + request := NewRequestBuilder().WithUserPrompt("Hello").MustBuild() + request.Ctx = ctx + + start := time.Now() + _, err := client.CallWithRequest(request) + if err == nil || !strings.Contains(err.Error(), "context canceled") { + t.Fatalf("expected context canceled during retry wait, got %v", err) + } + if elapsed := time.Since(start); elapsed > 500*time.Millisecond { + t.Fatalf("retry sleep did not respect context cancellation, elapsed=%v", elapsed) + } + if got := len(mockHTTP.GetRequests()); got != 1 { + t.Fatalf("expected no retry after context cancellation, got %d requests", got) + } +} + +func TestClient_CallWithRequest_RetriesUpstreamEmptyOutput(t *testing.T) { + mockHTTP := NewMockHTTPClient() + attempts := 0 + mockHTTP.ResponseFunc = func(req *http.Request) (*http.Response, error) { + attempts++ + if attempts == 1 { + body := `{"error":{"code":"upstream_empty_output","message":"Upstream model returned empty output.","type":"rate_limit_error"}}` + return &http.Response{ + StatusCode: http.StatusTooManyRequests, + Body: io.NopCloser(strings.NewReader(body)), + Header: make(http.Header), + }, nil + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"choices":[{"message":{"content":"ok after retry"}}]}`)), + Header: make(http.Header), + }, nil + } + + client := NewClient( + WithHTTPClient(mockHTTP.ToHTTPClient()), + WithLogger(NewMockLogger()), + WithAPIKey("sk-test-key"), + WithMaxRetries(2), + WithRetryWaitBase(time.Millisecond), + ) + request := NewRequestBuilder().WithUserPrompt("Hello").MustBuild() + + result, err := client.CallWithRequest(request) + if err != nil { + t.Fatalf("should retry upstream empty output and succeed: %v", err) + } + if result != "ok after retry" { + t.Fatalf("expected retry result, got %q", result) + } + if attempts != 2 { + t.Fatalf("expected 2 attempts, got %d", attempts) + } +} + func TestClient_CallWithRequest_MultiRound(t *testing.T) { mockHTTP := NewMockHTTPClient() mockHTTP.SetSuccessResponse("Multi-round response") diff --git a/provider/nofxos/claw402.go b/provider/nofxos/claw402.go index 0e4d6107..53f0a31a 100644 --- a/provider/nofxos/claw402.go +++ b/provider/nofxos/claw402.go @@ -98,6 +98,7 @@ func (c *Claw402DataClient) DoRequest(endpoint string) ([]byte, error) { signFn := payment.MakeClaw402SignFunc(c.privateKey) body, err := payment.DoX402Request( + context.Background(), c.httpClient, buildReq, signFn, diff --git a/store/ai_model.go b/store/ai_model.go index 7cd08e2a..15270a47 100644 --- a/store/ai_model.go +++ b/store/ai_model.go @@ -5,6 +5,7 @@ import ( "fmt" "nofx/crypto" "nofx/logger" + "os" "strings" "time" @@ -18,16 +19,16 @@ type AIModelStore struct { // AIModel AI model configuration type AIModel struct { - ID string `gorm:"primaryKey" json:"id"` - UserID string `gorm:"column:user_id;not null;default:default;index" json:"user_id"` - Name string `gorm:"not null" json:"name"` - Provider string `gorm:"not null" json:"provider"` - Enabled bool `gorm:"default:false" json:"enabled"` + ID string `gorm:"primaryKey" json:"id"` + UserID string `gorm:"column:user_id;not null;default:default;index" json:"user_id"` + Name string `gorm:"not null" json:"name"` + Provider string `gorm:"not null" json:"provider"` + Enabled bool `gorm:"default:false" json:"enabled"` APIKey crypto.EncryptedString `gorm:"column:api_key;default:''" json:"apiKey"` - CustomAPIURL string `gorm:"column:custom_api_url;default:''" json:"customApiUrl"` - CustomModelName string `gorm:"column:custom_model_name;default:''" json:"customModelName"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + CustomAPIURL string `gorm:"column:custom_api_url;default:''" json:"customApiUrl"` + CustomModelName string `gorm:"column:custom_model_name;default:''" json:"customModelName"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } func (AIModel) TableName() string { return "ai_models" } @@ -145,32 +146,64 @@ func (s *AIModelStore) GetDefault(userID string) (*AIModel, error) { } func (s *AIModelStore) firstEnabledUsable(userID string) (*AIModel, error) { - var model AIModel + var models []AIModel err := s.db.Where("user_id = ? AND enabled = ? AND api_key != ''", userID, true). Order("updated_at DESC, id ASC"). - First(&model).Error + Find(&models).Error if err != nil { return nil, err } - return &model, nil + for i := range models { + if hasUsableAPIKey(models[i]) { + return &models[i], nil + } + } + return nil, gorm.ErrRecordNotFound } // GetAnyEnabled returns the first enabled AI model across all users. // Used by single-user features (e.g. Telegram bot) that need any working LLM client. func (s *AIModelStore) GetAnyEnabled() (*AIModel, error) { - var model AIModel - err := s.db.Where("enabled = ? AND api_key != ''", true). + var models []AIModel + err := s.db.Where("enabled = ?", true). Order("updated_at DESC, id ASC"). - First(&model).Error + Find(&models).Error if err != nil { return nil, err } - return &model, nil + for i := range models { + if hasUsableAPIKey(models[i]) { + return &models[i], nil + } + } + return nil, gorm.ErrRecordNotFound +} + +func hasUsableAPIKey(model AIModel) bool { + if strings.TrimSpace(string(model.APIKey)) != "" { + return true + } + envKeyByProvider := map[string]string{ + "deepseek": "DEEPSEEK_API_KEY", + "openai": "OPENAI_API_KEY", + "claude": "ANTHROPIC_API_KEY", + "gemini": "GEMINI_API_KEY", + "grok": "XAI_API_KEY", + "kimi": "MOONSHOT_API_KEY", + "minimax": "MINIMAX_API_KEY", + "qwen": "DASHSCOPE_API_KEY", + } + envKey := envKeyByProvider[strings.ToLower(strings.TrimSpace(model.Provider))] + return envKey != "" && strings.TrimSpace(os.Getenv(envKey)) != "" } // Update updates AI model, creates if not exists // IMPORTANT: If apiKey is empty string, the existing API key will be preserved (not overwritten) func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPIURL, customModelName string) error { + return s.UpdateWithName(userID, id, "", enabled, apiKey, customAPIURL, customModelName) +} + +func (s *AIModelStore) UpdateWithName(userID, id, name string, enabled bool, apiKey, customAPIURL, customModelName string) error { // Try exact ID match first var existingModel AIModel err := s.db.Where("user_id = ? AND id = ?", userID, id).First(&existingModel).Error @@ -182,6 +215,9 @@ func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPI "custom_model_name": customModelName, "updated_at": time.Now().UTC(), } + if strings.TrimSpace(name) != "" { + updates["name"] = strings.TrimSpace(name) + } // If apiKey is not empty, update it (encryption handled by crypto.EncryptedString) if apiKey != "" { updates["api_key"] = crypto.EncryptedString(apiKey) @@ -200,6 +236,9 @@ func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPI "custom_model_name": customModelName, "updated_at": time.Now().UTC(), } + if strings.TrimSpace(name) != "" { + updates["name"] = strings.TrimSpace(name) + } if apiKey != "" { updates["api_key"] = crypto.EncryptedString(apiKey) } @@ -218,31 +257,35 @@ func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPI } } - // Try to get name from existing model with same provider + // Try to get a sensible default name from an existing model with the same provider. var refModel AIModel - var name string + defaultName := "" if err := s.db.Where("provider = ?", provider).First(&refModel).Error; err == nil { - name = refModel.Name + defaultName = refModel.Name } else { if provider == "deepseek" { - name = "DeepSeek AI" + defaultName = "DeepSeek AI" } else if provider == "qwen" { - name = "Qwen AI" + defaultName = "Qwen AI" } else { - name = provider + " AI" + defaultName = provider + " AI" } } + finalName := strings.TrimSpace(name) + if finalName == "" { + finalName = strings.TrimSpace(defaultName) + } newModelID := id if id == provider { newModelID = fmt.Sprintf("%s_%s", userID, provider) } - logger.Infof("✓ Creating new AI model configuration: ID=%s, Provider=%s, Name=%s", newModelID, provider, name) + logger.Infof("✓ Creating new AI model configuration: ID=%s, Provider=%s, Name=%s", newModelID, provider, finalName) newModel := &AIModel{ ID: newModelID, UserID: userID, - Name: name, + Name: finalName, Provider: provider, Enabled: enabled, APIKey: crypto.EncryptedString(apiKey), diff --git a/store/exchange.go b/store/exchange.go index e4acf69d..3819a376 100644 --- a/store/exchange.go +++ b/store/exchange.go @@ -4,6 +4,7 @@ import ( "fmt" "nofx/crypto" "nofx/logger" + "strings" "time" "github.com/google/uuid" @@ -57,6 +58,9 @@ func (s *ExchangeStore) initTables() error { // Still run data migrations s.migrateToMultiAccount() s.db.Model(&Exchange{}).Where("account_name = '' OR account_name IS NULL").Update("account_name", "Default") + if err := s.cleanupIncompleteExchangeConfigs(); err != nil { + logger.Warnf("Exchange cleanup migration warning: %v", err) + } return nil } } @@ -72,10 +76,48 @@ func (s *ExchangeStore) initTables() error { // Fix empty account_name for existing records s.db.Model(&Exchange{}).Where("account_name = '' OR account_name IS NULL").Update("account_name", "Default") + if err := s.cleanupIncompleteExchangeConfigs(); err != nil { + logger.Warnf("Exchange cleanup migration warning: %v", err) + } return nil } +func (s *ExchangeStore) cleanupIncompleteExchangeConfigs() error { + var exchanges []Exchange + if err := s.db.Find(&exchanges).Error; err != nil { + return err + } + for _, exchange := range exchanges { + missing := MissingRequiredExchangeCredentialFields( + exchange.ExchangeType, + string(exchange.APIKey), + string(exchange.SecretKey), + string(exchange.Passphrase), + exchange.HyperliquidWalletAddr, + exchange.AsterUser, + exchange.AsterSigner, + string(exchange.AsterPrivateKey), + exchange.LighterWalletAddr, + string(exchange.LighterAPIKeyPrivateKey), + ) + if len(missing) > 0 { + if err := s.db.Delete(&Exchange{}, "id = ? AND user_id = ?", exchange.ID, exchange.UserID).Error; err != nil { + return err + } + logger.Infof("🧹 Removed incomplete exchange config during migration: id=%s user=%s missing=%s", exchange.ID, exchange.UserID, strings.Join(missing, ",")) + continue + } + if !exchange.Enabled { + if err := s.db.Model(&Exchange{}).Where("id = ? AND user_id = ?", exchange.ID, exchange.UserID).Update("enabled", true).Error; err != nil { + return err + } + logger.Infof("🧹 Enabled complete exchange config during migration: id=%s user=%s", exchange.ID, exchange.UserID) + } + } + return nil +} + // migrateToMultiAccount migrates old schema (id=exchange_type) to new schema (id=UUID) func (s *ExchangeStore) migrateToMultiAccount() error { // Check if migration is needed by looking for old-style IDs (non-UUID) @@ -188,6 +230,10 @@ func (s *ExchangeStore) Create(userID, exchangeType, accountName string, enabled asterUser, asterSigner, asterPrivateKey, lighterWalletAddr, lighterPrivateKey, lighterApiKeyPrivateKey string, lighterApiKeyIndex int) (string, error) { + if missing := MissingRequiredExchangeCredentialFields(exchangeType, apiKey, secretKey, passphrase, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey, lighterWalletAddr, lighterApiKeyPrivateKey); len(missing) > 0 { + return "", fmt.Errorf("missing required exchange fields: %s", strings.Join(missing, ", ")) + } + id := uuid.New().String() name, typ := getExchangeNameAndType(exchangeType) @@ -205,7 +251,7 @@ func (s *ExchangeStore) Create(userID, exchangeType, accountName string, enabled UserID: userID, Name: name, Type: typ, - Enabled: enabled, + Enabled: true, APIKey: crypto.EncryptedString(apiKey), SecretKey: crypto.EncryptedString(secretKey), Passphrase: crypto.EncryptedString(passphrase), @@ -232,10 +278,10 @@ func (s *ExchangeStore) Update(userID, id string, enabled bool, apiKey, secretKe hyperliquidWalletAddr string, hyperliquidUnifiedAcct bool, asterUser, asterSigner, asterPrivateKey, lighterWalletAddr, lighterPrivateKey, lighterApiKeyPrivateKey string, lighterApiKeyIndex int) error { - logger.Debugf("🔧 ExchangeStore.Update: userID=%s, id=%s, enabled=%v", userID, id, enabled) + logger.Debugf("🔧 ExchangeStore.Update: userID=%s, id=%s", userID, id) updates := map[string]interface{}{ - "enabled": enabled, + "enabled": true, "testnet": testnet, "hyperliquid_wallet_addr": hyperliquidWalletAddr, "hyperliquid_unified_account": hyperliquidUnifiedAcct, diff --git a/store/strategy.go b/store/strategy.go index 8860e8fe..4380446a 100644 --- a/store/strategy.go +++ b/store/strategy.go @@ -17,10 +17,25 @@ const ( MaxTimeframes = 4 MinKlineCount = 10 MaxKlineCount = 30 + MinLeverage = 1 + MaxBTCETHLeverage = 20 + MaxAltLeverage = 20 + MinPositionRatio = 0.5 + MaxPositionRatio = 10.0 + MinRiskReward = 1.0 + MaxRiskReward = 10.0 + MinMarginUsage = 0.1 + MaxMarginUsage = 1.0 + MinPositionSize = 10.0 + MaxPositionSize = 1000.0 + MinConfidence = 50 + MaxConfidence = 100 ) // ClampLimits enforces product-level limits on strategy config to prevent token overflow. func (c *StrategyConfig) ClampLimits() { + c.NormalizeProductSchema() + // Clamp coin source limits if c.CoinSource.AI500Limit > MaxCandidateCoins { c.CoinSource.AI500Limit = MaxCandidateCoins @@ -54,10 +69,426 @@ func (c *StrategyConfig) ClampLimits() { } // Clamp max positions + if c.RiskControl.MaxPositions < 1 { + c.RiskControl.MaxPositions = 1 + } if c.RiskControl.MaxPositions > MaxPositions { c.RiskControl.MaxPositions = MaxPositions } + // Clamp leverage limits to the same bounds as the manual config UI. + if c.RiskControl.BTCETHMaxLeverage < MinLeverage { + c.RiskControl.BTCETHMaxLeverage = MinLeverage + } + if c.RiskControl.BTCETHMaxLeverage > MaxBTCETHLeverage { + c.RiskControl.BTCETHMaxLeverage = MaxBTCETHLeverage + } + if c.RiskControl.AltcoinMaxLeverage < MinLeverage { + c.RiskControl.AltcoinMaxLeverage = MinLeverage + } + if c.RiskControl.AltcoinMaxLeverage > MaxAltLeverage { + c.RiskControl.AltcoinMaxLeverage = MaxAltLeverage + } + + // Clamp position value ratio limits. + if c.RiskControl.BTCETHMaxPositionValueRatio < MinPositionRatio { + c.RiskControl.BTCETHMaxPositionValueRatio = MinPositionRatio + } + if c.RiskControl.BTCETHMaxPositionValueRatio > MaxPositionRatio { + c.RiskControl.BTCETHMaxPositionValueRatio = MaxPositionRatio + } + if c.RiskControl.AltcoinMaxPositionValueRatio < MinPositionRatio { + c.RiskControl.AltcoinMaxPositionValueRatio = MinPositionRatio + } + if c.RiskControl.AltcoinMaxPositionValueRatio > MaxPositionRatio { + c.RiskControl.AltcoinMaxPositionValueRatio = MaxPositionRatio + } + + // Clamp risk parameters and entry requirements. + if c.RiskControl.MinRiskRewardRatio < MinRiskReward { + c.RiskControl.MinRiskRewardRatio = MinRiskReward + } + if c.RiskControl.MinRiskRewardRatio > MaxRiskReward { + c.RiskControl.MinRiskRewardRatio = MaxRiskReward + } + if c.RiskControl.MaxMarginUsage < MinMarginUsage { + c.RiskControl.MaxMarginUsage = MinMarginUsage + } + if c.RiskControl.MaxMarginUsage > MaxMarginUsage { + c.RiskControl.MaxMarginUsage = MaxMarginUsage + } + if c.RiskControl.MinPositionSize < MinPositionSize { + c.RiskControl.MinPositionSize = MinPositionSize + } + if c.RiskControl.MinPositionSize > MaxPositionSize { + c.RiskControl.MinPositionSize = MaxPositionSize + } + if c.RiskControl.MinConfidence < MinConfidence { + c.RiskControl.MinConfidence = MinConfidence + } + if c.RiskControl.MinConfidence > MaxConfidence { + c.RiskControl.MinConfidence = MaxConfidence + } +} + +// NormalizeProductSchema keeps saved strategy JSON aligned with the product +// editor schema. LLMs may emit user-facing labels such as "AI500"; persistence +// must use the exact frontend/backend enum values. +func (c *StrategyConfig) NormalizeProductSchema() { + c.StrategyType = normalizeStrategyType(c.StrategyType) + c.CoinSource.SourceType = normalizeCoinSourceType(c.CoinSource.SourceType) + if c.CoinSource.SourceType == "" { + c.CoinSource.SourceType = inferCoinSourceType(c.CoinSource) + } + + switch c.CoinSource.SourceType { + case "ai500": + c.CoinSource.UseAI500 = true + c.CoinSource.UseOITop = false + c.CoinSource.UseOILow = false + if c.CoinSource.AI500Limit <= 0 { + c.CoinSource.AI500Limit = 3 + } + case "oi_top": + c.CoinSource.UseAI500 = false + c.CoinSource.UseOITop = true + c.CoinSource.UseOILow = false + if c.CoinSource.OITopLimit <= 0 { + c.CoinSource.OITopLimit = 3 + } + case "oi_low": + c.CoinSource.UseAI500 = false + c.CoinSource.UseOITop = false + c.CoinSource.UseOILow = true + if c.CoinSource.OILowLimit <= 0 { + c.CoinSource.OILowLimit = 3 + } + case "static": + c.CoinSource.UseAI500 = false + c.CoinSource.UseOITop = false + c.CoinSource.UseOILow = false + default: + c.CoinSource.SourceType = "ai500" + c.CoinSource.UseAI500 = true + if c.CoinSource.AI500Limit <= 0 { + c.CoinSource.AI500Limit = 3 + } + } + + c.CoinSource.StaticCoins = normalizeSymbols(c.CoinSource.StaticCoins) + c.CoinSource.ExcludedCoins = normalizeSymbols(c.CoinSource.ExcludedCoins) + c.Indicators.Klines.PrimaryTimeframe = normalizeTimeframe(c.Indicators.Klines.PrimaryTimeframe) + c.Indicators.Klines.LongerTimeframe = normalizeTimeframe(c.Indicators.Klines.LongerTimeframe) + c.Indicators.Klines.SelectedTimeframes = normalizeTimeframes(c.Indicators.Klines.SelectedTimeframes) + if len(c.Indicators.Klines.SelectedTimeframes) > 0 { + c.Indicators.Klines.EnableMultiTimeframe = true + } +} + +func normalizeStrategyType(value string) string { + value = strings.ToLower(strings.TrimSpace(value)) + switch value { + case "grid", "grid_strategy", "grid-trading", "grid trading", "grid_trading", "网格", "网格策略", "网格交易": + return "grid_trading" + case "", "ai", "ai_strategy", "ai-trading", "ai trading", "ai_trading", "ai策略", "ai 策略", "ai交易策略", "ai智能策略": + return "ai_trading" + default: + return value + } +} + +func normalizeCoinSourceType(value string) string { + value = strings.ToLower(strings.TrimSpace(value)) + compact := strings.NewReplacer(" ", "", "_", "", "-", "", "数据源", "", "选币", "", "币种", "").Replace(value) + switch { + case compact == "": + return "" + case strings.Contains(compact, "ai500"): + return "ai500" + case strings.Contains(compact, "oitop") || strings.Contains(value, "oi top") || strings.Contains(value, "持仓量最高") || strings.Contains(value, "持仓量靠前"): + return "oi_top" + case strings.Contains(compact, "oilow") || strings.Contains(value, "oi low") || strings.Contains(value, "持仓量最低") || strings.Contains(value, "持仓量较低"): + return "oi_low" + case strings.Contains(value, "static") || strings.Contains(value, "固定") || strings.Contains(value, "静态"): + return "static" + default: + return value + } +} + +func inferCoinSourceType(source CoinSourceConfig) string { + switch { + case len(source.StaticCoins) > 0: + return "static" + case source.UseAI500: + return "ai500" + case source.UseOITop: + return "oi_top" + case source.UseOILow: + return "oi_low" + default: + return "ai500" + } +} + +func normalizeSymbols(values []string) []string { + out := make([]string, 0, len(values)) + seen := make(map[string]bool, len(values)) + for _, value := range splitLooseStringList(values) { + value = strings.ToUpper(strings.TrimSpace(value)) + value = strings.Trim(value, ",,;; ") + if value == "" || seen[value] { + continue + } + seen[value] = true + out = append(out, value) + } + return out +} + +func normalizeTimeframes(values []string) []string { + out := make([]string, 0, len(values)) + seen := make(map[string]bool, len(values)) + for _, value := range splitLooseStringList(values) { + tf := normalizeTimeframe(value) + if tf == "" || seen[tf] { + continue + } + seen[tf] = true + out = append(out, tf) + } + return out +} + +func splitLooseStringList(values []string) []string { + if len(values) == 0 { + return nil + } + joined := strings.TrimSpace(strings.Join(values, ",")) + if strings.HasPrefix(joined, "[") && strings.HasSuffix(joined, "]") { + var parsed []string + if err := json.Unmarshal([]byte(joined), &parsed); err == nil { + return parsed + } + } + parts := make([]string, 0, len(values)) + for _, value := range values { + value = strings.TrimSpace(value) + if value == "" { + continue + } + if strings.HasPrefix(value, "[") && strings.HasSuffix(value, "]") { + var parsed []string + if err := json.Unmarshal([]byte(value), &parsed); err == nil { + parts = append(parts, parsed...) + continue + } + } + value = strings.Trim(value, "[]") + for _, part := range strings.FieldsFunc(value, func(r rune) bool { + return r == ',' || r == ',' || r == ';' || r == ';' || r == '\n' + }) { + part = strings.Trim(strings.TrimSpace(part), "\"'") + if part != "" { + parts = append(parts, part) + } + } + } + return parts +} + +func normalizeTimeframe(value string) string { + value = strings.ToLower(strings.TrimSpace(value)) + value = strings.Trim(value, "\"',,。 ") + if value == "" { + return "" + } + aliases := map[string]string{ + "1分钟": "1m", + "3分钟": "3m", + "5分钟": "5m", + "15分钟": "15m", + "30分钟": "30m", + "1小时": "1h", + "2小时": "2h", + "4小时": "4h", + "6小时": "6h", + "8小时": "8h", + "12小时": "12h", + "1天": "1d", + "3天": "3d", + "1周": "1w", + } + if alias, ok := aliases[value]; ok { + return alias + } + allowed := map[string]bool{ + "1m": true, "3m": true, "5m": true, "15m": true, "30m": true, + "1h": true, "2h": true, "4h": true, "6h": true, "8h": true, "12h": true, + "1d": true, "3d": true, "1w": true, + } + if !allowed[value] { + return "" + } + return value +} + +// MergeStrategyConfig applies a partial JSON-style patch onto a full strategy config. +// Nested objects are merged recursively so omitted fields keep their previous values. +func MergeStrategyConfig(base StrategyConfig, patch map[string]any) (StrategyConfig, error) { + baseJSON, err := json.Marshal(base) + if err != nil { + return StrategyConfig{}, err + } + + var mergedMap map[string]any + if err := json.Unmarshal(baseJSON, &mergedMap); err != nil { + return StrategyConfig{}, err + } + + normalizeStrategyConfigPatch(patch) + if fmt.Sprint(patch["strategy_type"]) == "grid_trading" { + ensureDefaultGridConfigMap(mergedMap) + } + mergeJSONMaps(mergedMap, patch) + + mergedJSON, err := json.Marshal(mergedMap) + if err != nil { + return StrategyConfig{}, err + } + + var merged StrategyConfig + if err := json.Unmarshal(mergedJSON, &merged); err != nil { + return StrategyConfig{}, err + } + return merged, nil +} + +func DefaultGridStrategyConfig() GridStrategyConfig { + return GridStrategyConfig{ + Symbol: "BTCUSDT", + GridCount: 10, + TotalInvestment: 1000, + Leverage: 5, + UpperPrice: 0, + LowerPrice: 0, + UseATRBounds: true, + ATRMultiplier: 2.0, + Distribution: "gaussian", + MaxDrawdownPct: 15, + StopLossPct: 5, + DailyLossLimitPct: 10, + UseMakerOnly: true, + EnableDirectionAdjust: false, + DirectionBiasRatio: 0.7, + } +} + +func ensureDefaultGridConfigMap(config map[string]any) { + if config == nil { + return + } + if existing, ok := config["grid_config"].(map[string]any); ok && len(existing) > 0 { + return + } + defaultGrid := DefaultGridStrategyConfig() + raw, err := json.Marshal(defaultGrid) + if err != nil { + return + } + var gridMap map[string]any + if err := json.Unmarshal(raw, &gridMap); err != nil { + return + } + config["grid_config"] = gridMap +} + +func normalizeStrategyConfigPatch(patch map[string]any) { + if patch == nil { + return + } + + if gridConfig, hasGrid := patch["grid_config"]; hasGrid && gridConfig != nil { + if _, hasType := patch["strategy_type"]; !hasType { + patch["strategy_type"] = "grid_trading" + } + } + + aiKeys := []string{"coin_source", "indicators", "risk_control", "prompt_sections", "custom_prompt"} + for _, key := range aiKeys { + value, ok := patch[key] + if !ok { + continue + } + aiConfig, _ := patch["ai_config"].(map[string]any) + if aiConfig == nil { + aiConfig = map[string]any{} + patch["ai_config"] = aiConfig + } + aiConfig[key] = value + delete(patch, key) + } + + if fmt.Sprint(patch["strategy_type"]) == "grid_trading" { + delete(patch, "ai_config") + } + + if _, hasType := patch["strategy_type"]; hasType { + return + } + if gridConfig, hasGrid := patch["grid_config"]; hasGrid && gridConfig != nil { + patch["strategy_type"] = "grid_trading" + } +} + +func mergeJSONMaps(dst, src map[string]any) { + for key, srcVal := range src { + srcMap, srcIsMap := srcVal.(map[string]any) + dstMap, dstIsMap := dst[key].(map[string]any) + if srcIsMap && dstIsMap { + mergeJSONMaps(dstMap, srcMap) + continue + } + dst[key] = srcVal + } +} + +func StrategyClampWarnings(before, after StrategyConfig, lang string) []string { + if lang != "zh" { + lang = "en" + } + warnings := make([]string, 0, 8) + appendInt := func(labelZH, labelEN string, from, to int) { + if from == to { + return + } + if lang == "zh" { + warnings = append(warnings, fmt.Sprintf("%s 已从 %d 调整为 %d", labelZH, from, to)) + return + } + warnings = append(warnings, fmt.Sprintf("%s adjusted from %d to %d", labelEN, from, to)) + } + appendFloat := func(labelZH, labelEN string, from, to float64) { + if from == to { + return + } + if lang == "zh" { + warnings = append(warnings, fmt.Sprintf("%s 已从 %.2f 调整为 %.2f", labelZH, from, to)) + return + } + warnings = append(warnings, fmt.Sprintf("%s adjusted from %.2f to %.2f", labelEN, from, to)) + } + + appendInt("最大持仓数", "max_positions", before.RiskControl.MaxPositions, after.RiskControl.MaxPositions) + appendInt("BTC/ETH 最大杠杆", "btc_eth_max_leverage", before.RiskControl.BTCETHMaxLeverage, after.RiskControl.BTCETHMaxLeverage) + appendInt("山寨币最大杠杆", "altcoin_max_leverage", before.RiskControl.AltcoinMaxLeverage, after.RiskControl.AltcoinMaxLeverage) + appendFloat("BTC/ETH 最大仓位价值倍数", "btc_eth_max_position_value_ratio", before.RiskControl.BTCETHMaxPositionValueRatio, after.RiskControl.BTCETHMaxPositionValueRatio) + appendFloat("山寨币最大仓位价值倍数", "altcoin_max_position_value_ratio", before.RiskControl.AltcoinMaxPositionValueRatio, after.RiskControl.AltcoinMaxPositionValueRatio) + appendFloat("最小盈亏比", "min_risk_reward_ratio", before.RiskControl.MinRiskRewardRatio, after.RiskControl.MinRiskRewardRatio) + appendFloat("最大保证金使用率", "max_margin_usage", before.RiskControl.MaxMarginUsage, after.RiskControl.MaxMarginUsage) + appendFloat("最小开仓金额", "min_position_size", before.RiskControl.MinPositionSize, after.RiskControl.MinPositionSize) + appendInt("最低置信度", "min_confidence", before.RiskControl.MinConfidence, after.RiskControl.MinConfidence) + return warnings } // StrategyStore strategy storage @@ -90,19 +521,128 @@ type StrategyConfig struct { // language setting: "zh" for Chinese, "en" for English // This determines the language used for data formatting and prompt generation Language string `json:"language,omitempty"` - // coin source configuration - CoinSource CoinSourceConfig `json:"coin_source"` - // quantitative data configuration - Indicators IndicatorConfig `json:"indicators"` - // custom prompt (appended at the end) - CustomPrompt string `json:"custom_prompt,omitempty"` - // risk control configuration - RiskControl RiskControlConfig `json:"risk_control"` - // editable sections of System Prompt - PromptSections PromptSectionsConfig `json:"prompt_sections,omitempty"` + // AI trading configuration fields are kept on the Go struct for engine + // compatibility, but JSON persistence nests them under ai_config. + CoinSource CoinSourceConfig `json:"-"` + Indicators IndicatorConfig `json:"-"` + CustomPrompt string `json:"-"` + RiskControl RiskControlConfig `json:"-"` + PromptSections PromptSectionsConfig `json:"-"` // Grid trading configuration (only used when StrategyType == "grid_trading") GridConfig *GridStrategyConfig `json:"grid_config,omitempty"` + + // Publish settings are shared by AI and grid strategies. The database still + // stores the authoritative booleans on Strategy, but config JSON may carry + // this object for agent/frontend schema consistency. + PublishConfig *PublishStrategyConfig `json:"publish_config,omitempty"` +} + +// AIStrategyConfig contains fields only used by AI trading strategies. +type AIStrategyConfig struct { + CoinSource CoinSourceConfig `json:"coin_source"` + Indicators IndicatorConfig `json:"indicators"` + CustomPrompt string `json:"custom_prompt,omitempty"` + RiskControl RiskControlConfig `json:"risk_control"` + PromptSections PromptSectionsConfig `json:"prompt_sections,omitempty"` +} + +// PublishStrategyConfig contains settings shared by all strategy types. +type PublishStrategyConfig struct { + IsPublic bool `json:"is_public"` + ConfigVisible bool `json:"config_visible"` +} + +// MarshalJSON writes the product-facing strategy schema: +// strategy_type + grid_config or ai_config + shared publish_config. +func (c StrategyConfig) MarshalJSON() ([]byte, error) { + strategyType := strings.TrimSpace(c.StrategyType) + if strategyType == "" { + strategyType = "ai_trading" + } + + out := struct { + StrategyType string `json:"strategy_type"` + Language string `json:"language,omitempty"` + AIConfig *AIStrategyConfig `json:"ai_config,omitempty"` + GridConfig *GridStrategyConfig `json:"grid_config,omitempty"` + PublishConfig *PublishStrategyConfig `json:"publish_config,omitempty"` + }{ + StrategyType: strategyType, + Language: c.Language, + PublishConfig: c.PublishConfig, + } + + if strategyType == "grid_trading" { + out.GridConfig = c.GridConfig + } else { + out.AIConfig = &AIStrategyConfig{ + CoinSource: c.CoinSource, + Indicators: c.Indicators, + CustomPrompt: c.CustomPrompt, + RiskControl: c.RiskControl, + PromptSections: c.PromptSections, + } + } + + return json.Marshal(out) +} + +// UnmarshalJSON accepts both the new nested schema and old flat configs. Old +// top-level AI fields are normalized into the Go compatibility fields. +func (c *StrategyConfig) UnmarshalJSON(data []byte) error { + type rawStrategyConfig struct { + StrategyType string `json:"strategy_type"` + Language string `json:"language"` + AIConfig *AIStrategyConfig `json:"ai_config"` + GridConfig *GridStrategyConfig `json:"grid_config"` + PublishConfig *PublishStrategyConfig `json:"publish_config"` + + CoinSource *CoinSourceConfig `json:"coin_source"` + Indicators *IndicatorConfig `json:"indicators"` + CustomPrompt *string `json:"custom_prompt"` + RiskControl *RiskControlConfig `json:"risk_control"` + PromptSections *PromptSectionsConfig `json:"prompt_sections"` + } + + var raw rawStrategyConfig + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + c.StrategyType = raw.StrategyType + c.Language = raw.Language + c.GridConfig = raw.GridConfig + c.PublishConfig = raw.PublishConfig + + if raw.AIConfig != nil { + c.CoinSource = raw.AIConfig.CoinSource + c.Indicators = raw.AIConfig.Indicators + c.CustomPrompt = raw.AIConfig.CustomPrompt + c.RiskControl = raw.AIConfig.RiskControl + c.PromptSections = raw.AIConfig.PromptSections + } else { + if raw.CoinSource != nil { + c.CoinSource = *raw.CoinSource + } + if raw.Indicators != nil { + c.Indicators = *raw.Indicators + } + if raw.CustomPrompt != nil { + c.CustomPrompt = *raw.CustomPrompt + } + if raw.RiskControl != nil { + c.RiskControl = *raw.RiskControl + } + if raw.PromptSections != nil { + c.PromptSections = *raw.PromptSections + } + } + + if strings.TrimSpace(c.StrategyType) == "" && c.GridConfig != nil { + c.StrategyType = "grid_trading" + } + return nil } // GridStrategyConfig grid trading specific configuration @@ -153,7 +693,7 @@ type PromptSectionsConfig struct { // CoinSourceConfig coin source configuration type CoinSourceConfig struct { - // source type: "static" | "ai500" | "oi_top" | "oi_low" | "mixed" + // source type shown in the product editor: "static" | "ai500" | "oi_top" | "oi_low" SourceType string `json:"source_type"` // static coin list (used when source_type = "static") StaticCoins []string `json:"static_coins,omitempty"` @@ -850,16 +1390,6 @@ func (c *StrategyConfig) getEffectiveCoinCount() int { count = c.CoinSource.OITopLimit case "oi_low": count = c.CoinSource.OILowLimit - case "mixed": - if c.CoinSource.UseAI500 { - count += c.CoinSource.AI500Limit - } - if c.CoinSource.UseOITop { - count += c.CoinSource.OITopLimit - } - if c.CoinSource.UseOILow { - count += c.CoinSource.OILowLimit - } default: count = c.CoinSource.AI500Limit } diff --git a/store/strategy_schema_test.go b/store/strategy_schema_test.go new file mode 100644 index 00000000..cc330eae --- /dev/null +++ b/store/strategy_schema_test.go @@ -0,0 +1,124 @@ +package store + +import ( + "encoding/json" + "testing" +) + +func TestStrategyConfigMarshalSeparatesGridAndAIConfig(t *testing.T) { + cfg := GetDefaultStrategyConfig("zh") + cfg.StrategyType = "grid_trading" + cfg.GridConfig = &GridStrategyConfig{ + Symbol: "BTCUSDT", + GridCount: 20, + TotalInvestment: 200, + Leverage: 2, + UseATRBounds: true, + ATRMultiplier: 2, + Distribution: "uniform", + } + + raw, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("marshal grid config: %v", err) + } + + var asMap map[string]any + if err := json.Unmarshal(raw, &asMap); err != nil { + t.Fatalf("unmarshal grid config map: %v", err) + } + if asMap["strategy_type"] != "grid_trading" { + t.Fatalf("expected grid strategy_type, got %v", asMap["strategy_type"]) + } + if _, ok := asMap["grid_config"]; !ok { + t.Fatalf("expected grid_config in grid strategy JSON: %s", string(raw)) + } + for _, key := range []string{"ai_config", "coin_source", "indicators", "risk_control", "prompt_sections", "custom_prompt"} { + if _, ok := asMap[key]; ok { + t.Fatalf("did not expect %s in grid strategy JSON: %s", key, string(raw)) + } + } +} + +func TestStrategyConfigUnmarshalLegacyFlatAIConfig(t *testing.T) { + raw := []byte(`{ + "strategy_type":"ai_trading", + "coin_source":{"source_type":"static","static_coins":["ETHUSDT"]}, + "indicators":{"klines":{"primary_timeframe":"15m"}}, + "risk_control":{"max_positions":2,"min_confidence":80}, + "prompt_sections":{"entry_standards":"trend only"}, + "custom_prompt":"prefer ETH" + }`) + + var cfg StrategyConfig + if err := json.Unmarshal(raw, &cfg); err != nil { + t.Fatalf("unmarshal legacy flat config: %v", err) + } + if cfg.CoinSource.SourceType != "static" || len(cfg.CoinSource.StaticCoins) != 1 || cfg.CoinSource.StaticCoins[0] != "ETHUSDT" { + t.Fatalf("legacy coin source was not normalized: %+v", cfg.CoinSource) + } + if cfg.Indicators.Klines.PrimaryTimeframe != "15m" { + t.Fatalf("legacy indicators were not normalized: %+v", cfg.Indicators.Klines) + } + + normalized, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("marshal normalized config: %v", err) + } + var asMap map[string]any + if err := json.Unmarshal(normalized, &asMap); err != nil { + t.Fatalf("unmarshal normalized map: %v", err) + } + if _, ok := asMap["ai_config"]; !ok { + t.Fatalf("expected ai_config after normalizing legacy config: %s", string(normalized)) + } + if _, ok := asMap["coin_source"]; ok { + t.Fatalf("did not expect legacy coin_source at top level: %s", string(normalized)) + } +} + +func TestStrategyConfigNormalizeProductSchemaForLLMLabels(t *testing.T) { + cfg := GetDefaultStrategyConfig("zh") + patch := map[string]any{ + "strategy_type": "AI 策略", + "ai_config": map[string]any{ + "coin_source": map[string]any{ + "source_type": "AI500", + }, + "indicators": map[string]any{ + "klines": map[string]any{ + "primary_timeframe": "1分钟", + "selected_timeframes": []any{`["1m"`, `"5m"`, `"15m"]`}, + }, + }, + }, + } + + merged, err := MergeStrategyConfig(cfg, patch) + if err != nil { + t.Fatalf("merge strategy config: %v", err) + } + merged.ClampLimits() + + if merged.StrategyType != "ai_trading" { + t.Fatalf("strategy_type = %q, want ai_trading", merged.StrategyType) + } + if merged.CoinSource.SourceType != "ai500" { + t.Fatalf("source_type = %q, want ai500", merged.CoinSource.SourceType) + } + if !merged.CoinSource.UseAI500 || merged.CoinSource.UseOITop || merged.CoinSource.UseOILow { + t.Fatalf("coin source flags not normalized: %+v", merged.CoinSource) + } + if merged.Indicators.Klines.PrimaryTimeframe != "1m" { + t.Fatalf("primary_timeframe = %q, want 1m", merged.Indicators.Klines.PrimaryTimeframe) + } + want := []string{"1m", "5m", "15m"} + if len(merged.Indicators.Klines.SelectedTimeframes) != len(want) { + t.Fatalf("selected_timeframes = %+v, want %+v", merged.Indicators.Klines.SelectedTimeframes, want) + } + for i := range want { + if merged.Indicators.Klines.SelectedTimeframes[i] != want[i] { + t.Fatalf("selected_timeframes = %+v, want %+v", merged.Indicators.Klines.SelectedTimeframes, want) + } + } +} diff --git a/store/trader.go b/store/trader.go index 8b983baa..9a4bd780 100644 --- a/store/trader.go +++ b/store/trader.go @@ -110,12 +110,20 @@ func (s *TraderStore) Update(trader *Trader) error { trader.ID, trader.Name, trader.AIModelID, trader.StrategyID) updates := map[string]interface{}{ - "name": trader.Name, - "ai_model_id": trader.AIModelID, - "exchange_id": trader.ExchangeID, - "strategy_id": trader.StrategyID, - "is_cross_margin": trader.IsCrossMargin, - "show_in_competition": trader.ShowInCompetition, + "name": trader.Name, + "ai_model_id": trader.AIModelID, + "exchange_id": trader.ExchangeID, + "strategy_id": trader.StrategyID, + "is_cross_margin": trader.IsCrossMargin, + "show_in_competition": trader.ShowInCompetition, + "btc_eth_leverage": trader.BTCETHLeverage, + "altcoin_leverage": trader.AltcoinLeverage, + "trading_symbols": trader.TradingSymbols, + "use_coin_pool": trader.UseAI500, + "use_oi_top": trader.UseOITop, + "custom_prompt": trader.CustomPrompt, + "override_base_prompt": trader.OverrideBasePrompt, + "system_prompt_template": trader.SystemPromptTemplate, } // Only update these if > 0 diff --git a/store/visibility.go b/store/visibility.go new file mode 100644 index 00000000..ed6d981f --- /dev/null +++ b/store/visibility.go @@ -0,0 +1,96 @@ +package store + +import "strings" + +func MissingRequiredExchangeCredentialFields(exchangeType, apiKey, secretKey, passphrase, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey, lighterWalletAddr, lighterAPIKeyPrivateKey string) []string { + switch strings.ToLower(strings.TrimSpace(exchangeType)) { + case "binance", "bybit", "gate", "indodax": + return missingNamedFields( + namedField{"api_key", apiKey}, + namedField{"secret_key", secretKey}, + ) + case "okx", "bitget", "kucoin": + return missingNamedFields( + namedField{"api_key", apiKey}, + namedField{"secret_key", secretKey}, + namedField{"passphrase", passphrase}, + ) + case "hyperliquid": + return missingNamedFields( + namedField{"api_key", apiKey}, + namedField{"hyperliquid_wallet_addr", hyperliquidWalletAddr}, + ) + case "aster": + return missingNamedFields( + namedField{"aster_user", asterUser}, + namedField{"aster_signer", asterSigner}, + namedField{"aster_private_key", asterPrivateKey}, + ) + case "lighter": + return missingNamedFields( + namedField{"lighter_wallet_addr", lighterWalletAddr}, + namedField{"lighter_api_key_private_key", lighterAPIKeyPrivateKey}, + ) + default: + return []string{"exchange_type"} + } +} + +type namedField struct { + name string + value string +} + +func missingNamedFields(fields ...namedField) []string { + missing := make([]string, 0, len(fields)) + for _, field := range fields { + if strings.TrimSpace(field.value) == "" { + missing = append(missing, field.name) + } + } + return missing +} + +func IsVisibleAIModel(model *AIModel) bool { + if model == nil { + return false + } + return model.Enabled || + strings.TrimSpace(string(model.APIKey)) != "" || + strings.TrimSpace(model.CustomAPIURL) != "" || + strings.TrimSpace(model.CustomModelName) != "" +} + +func IsVisibleExchange(exchange *Exchange) bool { + if exchange == nil { + return false + } + return exchange.Enabled || + strings.TrimSpace(string(exchange.APIKey)) != "" || + strings.TrimSpace(string(exchange.SecretKey)) != "" || + strings.TrimSpace(string(exchange.Passphrase)) != "" || + strings.TrimSpace(exchange.HyperliquidWalletAddr) != "" || + strings.TrimSpace(exchange.AsterUser) != "" || + strings.TrimSpace(exchange.AsterSigner) != "" || + strings.TrimSpace(string(exchange.AsterPrivateKey)) != "" || + strings.TrimSpace(exchange.LighterWalletAddr) != "" || + strings.TrimSpace(string(exchange.LighterPrivateKey)) != "" || + strings.TrimSpace(string(exchange.LighterAPIKeyPrivateKey)) != "" || + exchange.LighterAPIKeyIndex != 0 +} + +func IsVisibleTrader(trader *Trader) bool { + if trader == nil { + return false + } + return strings.TrimSpace(trader.Name) != "" && + strings.TrimSpace(trader.AIModelID) != "" && + strings.TrimSpace(trader.ExchangeID) != "" +} + +func IsVisibleStrategy(strategy *Strategy) bool { + if strategy == nil { + return false + } + return strings.TrimSpace(strategy.Name) != "" +} diff --git a/web/src/components/agent/AgentStepPanel.tsx b/web/src/components/agent/AgentStepPanel.tsx index 9e999bb7..acb4f216 100644 --- a/web/src/components/agent/AgentStepPanel.tsx +++ b/web/src/components/agent/AgentStepPanel.tsx @@ -1,9 +1,4 @@ -interface AgentStep { - id: string - label: string - status: 'planning' | 'pending' | 'running' | 'completed' | 'replanned' - detail?: string -} +import type { AgentStep } from '../../types/agent' interface AgentStepPanelProps { steps?: AgentStep[] @@ -23,6 +18,16 @@ export function AgentStepPanel({ steps, visible }: AgentStepPanelProps) { return null } + const sanitizedSteps = steps.filter((step) => { + const label = step.label.trim().toLowerCase() + const detail = (step.detail || '').trim().toLowerCase() + return !(label.startsWith('tool:') || detail === 'central_brain') + }) + + if (sanitizedSteps.length === 0) { + return null + } + return (
- {steps.map((step) => { + {sanitizedSteps.map((step) => { const style = statusStyles[step.status] return (
void @@ -10,43 +17,60 @@ export interface ChatInputHandle { interface ChatInputProps { language: string loading: boolean + value: string + onChange: (value: string) => void onSend: (text: string) => void + onStop: () => void } export const ChatInput = forwardRef( - function ChatInput({ language, loading, onSend }, ref) { - const [input, setInput] = useState('') + function ChatInput( + { language, loading, value, onChange, onSend, onStop }, + ref + ) { const [composing, setComposing] = useState(false) const inputRef = useRef(null) - useImperativeHandle(ref, () => ({ - focus: () => inputRef.current?.focus(), - clear: () => { - setInput('') - if (inputRef.current) inputRef.current.style.height = 'auto' - }, - getValue: () => input, - })) + useImperativeHandle( + ref, + () => ({ + focus: () => inputRef.current?.focus(), + clear: () => { + onChange('') + if (inputRef.current) inputRef.current.style.height = 'auto' + }, + getValue: () => value, + }), + [onChange, value] + ) + + const resizeInput = useCallback(() => { + const el = inputRef.current + if (!el) return + el.style.height = 'auto' + el.style.height = Math.min(el.scrollHeight, 150) + 'px' + }, []) const handleInputChange = useCallback( (e: React.ChangeEvent) => { - setInput(e.target.value) - const el = e.target - el.style.height = 'auto' - el.style.height = Math.min(el.scrollHeight, 150) + 'px' + onChange(e.target.value) }, - [] + [onChange] ) const handleSend = () => { - const msg = input.trim() + const msg = value.trim() if (!msg || loading) return - setInput('') + onChange('') if (inputRef.current) inputRef.current.style.height = 'auto' onSend(msg) inputRef.current?.focus() } + useEffect(() => { + resizeInput() + }, [resizeInput, value]) + // Keyboard shortcut: Cmd+K to focus useEffect(() => { const handleKeyDown = (e: KeyboardEvent) => { @@ -84,7 +108,7 @@ export const ChatInput = forwardRef( >