From 737f9bca95a4d2126e9ee53dfc3ba4fcbf72c14d Mon Sep 17 00:00:00 2001 From: lky-spec Date: Sun, 19 Apr 2026 16:06:28 +0800 Subject: [PATCH] Enhance NOFXi agent workflow and diagnostics --- agent/backend_logs_test.go | 127 +++ agent/config_tools_test.go | 67 ++ agent/llm_skill_router.go | 162 +++- agent/memory.go | 10 +- agent/planner_runtime.go | 771 ++++++++++++++-- agent/planner_runtime_state_test.go | 33 + agent/skill_dag.go | 277 ++++++ agent/skill_dag_runtime.go | 51 ++ agent/skill_dag_runtime_test.go | 27 + agent/skill_dag_test.go | 67 ++ agent/skill_dispatcher.go | 449 +++++++-- agent/skill_dispatcher_test.go | 592 ++++++++++++ agent/skill_execution_handlers.go | 966 +++++++++++++++++++- agent/skill_management_handlers.go | 530 ++++++++++- agent/skill_outcome.go | 180 ++++ agent/tools.go | 422 +++++++-- agent/workflow.go | 521 +++++++++++ agent/workflow_test.go | 37 + manager/trader_manager.go | 25 +- store/ai_model.go | 8 +- trader/auto_trader.go | 77 +- trader/auto_trader_loop.go | 60 +- web/src/components/agent/AgentStepPanel.tsx | 2 +- web/src/pages/AgentChatPage.tsx | 98 +- web/src/stores/agentChatStore.ts | 52 ++ 25 files changed, 5233 insertions(+), 378 deletions(-) create mode 100644 agent/backend_logs_test.go create mode 100644 agent/skill_dag.go create mode 100644 agent/skill_dag_runtime.go create mode 100644 agent/skill_dag_runtime_test.go create mode 100644 agent/skill_dag_test.go create mode 100644 agent/skill_outcome.go create mode 100644 agent/workflow.go create mode 100644 agent/workflow_test.go create mode 100644 web/src/stores/agentChatStore.ts diff --git a/agent/backend_logs_test.go b/agent/backend_logs_test.go new file mode 100644 index 00000000..16f37a64 --- /dev/null +++ b/agent/backend_logs_test.go @@ -0,0 +1,127 @@ +package agent + +import ( + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + + "nofx/store" +) + +func TestReadBackendLogEntriesReturnsRecentErrorLines(t *testing.T) { + wd, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd() error = %v", err) + } + tmp := t.TempDir() + if err := os.Chdir(tmp); err != nil { + t.Fatalf("Chdir(tmp) error = %v", err) + } + t.Cleanup(func() { + _ = os.Chdir(wd) + }) + + if err := os.MkdirAll("data", 0o755); err != nil { + t.Fatalf("MkdirAll(data) error = %v", err) + } + logPath := filepath.Join("data", "nofx_2099-01-01.log") + content := strings.Join([]string{ + "04-19 13:00:00 [INFO] api/server.go:590 API server starting", + "04-19 13:00:01 [ERRO] api/server.go:600 invalid signature for okx account", + "04-19 13:00:02 [ERRO] agent/tools.go:123 model update failed: missing api key", + }, "\n") + "\n" + if err := os.WriteFile(logPath, []byte(content), 0o644); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + path, entries, err := readBackendLogEntries(10, "model", true) + if err != nil { + t.Fatalf("readBackendLogEntries() error = %v", err) + } + if !strings.Contains(path, "nofx_2099-01-01.log") { + t.Fatalf("unexpected log path: %s", path) + } + if len(entries) != 1 || !strings.Contains(entries[0], "missing api key") { + t.Fatalf("unexpected filtered entries: %#v", entries) + } +} + +func TestToolGetBackendLogsRequiresOwnedTrader(t *testing.T) { + wd, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd() error = %v", err) + } + tmp := t.TempDir() + if err := os.Chdir(tmp); err != nil { + t.Fatalf("Chdir(tmp) error = %v", err) + } + t.Cleanup(func() { + _ = os.Chdir(wd) + }) + + if err := os.MkdirAll("data", 0o755); err != nil { + t.Fatalf("MkdirAll(data) error = %v", err) + } + logPath := filepath.Join("data", "nofx_2099-01-01.log") + content := strings.Join([]string{ + "04-19 13:00:00 [INFO] api/server.go:590 API server starting", + "04-19 13:00:01 [ERRO] trader/runtime.go:88 trader_id=trader-owned strategy execution failed", + "04-19 13:00:02 [ERRO] trader/runtime.go:89 trader_id=trader-other strategy execution failed", + }, "\n") + "\n" + if err := os.WriteFile(logPath, []byte(content), 0o644); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + a := newTestAgentWithStore(t) + if err := a.store.Trader().Create(&store.Trader{ + ID: "trader-owned", + UserID: "user-1", + Name: "Owned Trader", + AIModelID: "model-1", + ExchangeID: "exchange-1", + StrategyID: "strategy-1", + InitialBalance: 1000, + }); err != nil { + t.Fatalf("create owned trader: %v", err) + } + if err := a.store.Trader().Create(&store.Trader{ + ID: "trader-other", + UserID: "user-2", + Name: "Other Trader", + AIModelID: "model-2", + ExchangeID: "exchange-2", + StrategyID: "strategy-2", + InitialBalance: 1000, + }); err != nil { + t.Fatalf("create other trader: %v", err) + } + + resp := a.toolGetBackendLogs("user-1", `{"trader_id":"trader-owned","limit":5}`) + var okResult struct { + TraderID string `json:"trader_id"` + Entries []string `json:"entries"` + Count int `json:"count"` + } + if err := json.Unmarshal([]byte(resp), &okResult); err != nil { + t.Fatalf("unmarshal owned response: %v\nraw=%s", err, resp) + } + if okResult.TraderID != "trader-owned" || okResult.Count != 1 { + t.Fatalf("unexpected owned response: %+v", okResult) + } + if len(okResult.Entries) != 1 || !strings.Contains(okResult.Entries[0], "trader-owned") { + t.Fatalf("unexpected owned entries: %#v", okResult.Entries) + } + + resp = a.toolGetBackendLogs("user-1", `{"trader_id":"trader-other","limit":5}`) + var denied struct { + Error string `json:"error"` + } + if err := json.Unmarshal([]byte(resp), &denied); err != nil { + t.Fatalf("unmarshal denied response: %v\nraw=%s", err, resp) + } + if denied.Error != "trader not found for current user" { + t.Fatalf("unexpected denied response: %+v", denied) + } +} diff --git a/agent/config_tools_test.go b/agent/config_tools_test.go index db1676bb..4cf717d7 100644 --- a/agent/config_tools_test.go +++ b/agent/config_tools_test.go @@ -86,6 +86,7 @@ func TestToolManageModelConfigLifecycle(t *testing.T) { "action":"create", "provider":"openai", "enabled":true, + "api_key":"sk-test", "custom_api_url":"https://api.openai.com/v1", "custom_model_name":"gpt-5-mini" }`) @@ -136,6 +137,71 @@ func TestToolManageModelConfigLifecycle(t *testing.T) { } } +func TestToolManageModelConfigRejectsEnableWithoutAPIKey(t *testing.T) { + a := newTestAgentWithStore(t) + + createResp := a.toolManageModelConfig("user-1", `{ + "action":"create", + "provider":"openai", + "enabled":false, + "custom_model_name":"gpt-4o" + }`) + var created struct { + Model safeModelToolConfig `json:"model"` + } + if err := json.Unmarshal([]byte(createResp), &created); err != nil { + t.Fatalf("unmarshal create response: %v\nraw=%s", err, createResp) + } + + updateResp := a.toolManageModelConfig("user-1", `{ + "action":"update", + "model_id":"`+created.Model.ID+`", + "enabled":true + }`) + if !strings.Contains(updateResp, "cannot enable model config before API key is configured") { + t.Fatalf("expected enabling incomplete model to fail, got %s", updateResp) + } +} + +func TestGetDefaultSkipsEnabledModelWithoutAPIKey(t *testing.T) { + a := newTestAgentWithStore(t) + + incompleteCreate := a.toolManageModelConfig("user-1", `{ + "action":"create", + "provider":"openai", + "enabled":true, + "custom_model_name":"gpt-4o" + }`) + var incomplete struct { + Model safeModelToolConfig `json:"model"` + } + if err := json.Unmarshal([]byte(incompleteCreate), &incomplete); err != nil { + t.Fatalf("unmarshal incomplete create response: %v\nraw=%s", err, incompleteCreate) + } + + completeCreate := a.toolManageModelConfig("user-1", `{ + "action":"create", + "provider":"deepseek", + "enabled":true, + "api_key":"sk-test", + "custom_model_name":"deepseek-chat" + }`) + var complete struct { + Model safeModelToolConfig `json:"model"` + } + if err := json.Unmarshal([]byte(completeCreate), &complete); err != nil { + t.Fatalf("unmarshal complete create response: %v\nraw=%s", err, completeCreate) + } + + model, err := a.store.AIModel().GetDefault("user-1") + if err != nil { + t.Fatalf("GetDefault() error = %v", err) + } + if model.ID != complete.Model.ID { + t.Fatalf("expected GetDefault to skip incomplete enabled model and return %s, got %s", complete.Model.ID, model.ID) + } +} + func TestToolManageTraderLifecycle(t *testing.T) { a := newTestAgentWithStore(t) @@ -143,6 +209,7 @@ func TestToolManageTraderLifecycle(t *testing.T) { "action":"create", "provider":"openai", "enabled":true, + "api_key":"sk-test", "custom_api_url":"https://api.openai.com/v1", "custom_model_name":"gpt-5-mini" }`) diff --git a/agent/llm_skill_router.go b/agent/llm_skill_router.go index fddc819c..3e53a699 100644 --- a/agent/llm_skill_router.go +++ b/agent/llm_skill_router.go @@ -16,14 +16,14 @@ type llmSkillRouteDecision struct { Filter string `json:"filter,omitempty"` } -func (a *Agent) tryLLMSkillRoute(ctx context.Context, storeUserID string, userID int64, lang, text string, onEvent func(event, data string)) (string, bool) { +func (a *Agent) tryLLMSkillRoute(ctx context.Context, storeUserID string, userID int64, lang, text string, onEvent func(event, data string)) (string, bool, error) { if a.aiClient == nil { - return "", false + return "", false, nil } text = strings.TrimSpace(text) if text == "" { - return "", false + return "", false, nil } recentConversationCtx := a.buildRecentConversationContext(userID, text) @@ -47,10 +47,17 @@ Available skills: - model_diagnosis - strategy_diagnosis -For management skills, choose one action from: -- query +For management skills, choose one atomic action from: +- query_list +- query_detail +- query_running - create -- update +- update_name +- update_bindings +- update_status +- update_endpoint +- update_config +- update_prompt - delete - start - stop @@ -69,7 +76,8 @@ Rules: - Prefer route "planner" when uncertain. - Prefer route "planner" for market analysis, broad advice, multi-step troubleshooting, or requests that need synthesis. - Prefer route "skill" for straightforward management requests like listing, creating, starting, stopping, enabling, disabling, renaming, or deleting known entities. -- Questions like "当前有运行中的trader吗" and "有没有 trader 在跑" are trader_management with action "query" and filter "running_only". +- Questions like "当前有运行中的trader吗" and "有没有 trader 在跑" are trader_management with action "query_running". +- Questions about one entity's details, config, parameters, or prompt should prefer action "query_detail". - Do not use route "skill" for casual chat. - Consider Recent conversation, Task state, and Execution state JSON before deciding. @@ -88,17 +96,37 @@ Return JSON with this exact shape: Ctx: stageCtx, }) if err != nil { - return "", false + return "", false, nil } decision, err := parseLLMSkillRouteDecision(raw) if err != nil || decision.Route != "skill" { - return "", false + return "", false, nil } - answer, ok := a.executeLLMSkillRoute(storeUserID, userID, lang, text, decision) + outcome, ok := a.executeLLMSkillRoute(storeUserID, userID, lang, text, decision) if !ok { - return "", false + return "", false, nil + } + + review, err := a.reviewTaskCompletion(ctx, userID, lang, text, outcome) + if err != nil { + if outcome.Status == skillOutcomeRecoverableError || outcome.Status == skillOutcomeFatalError || outcome.Status == skillOutcomeNotHandled { + return "", false, nil + } + review = taskReviewDecision{Route: "complete", Answer: outcome.UserMessage} + } + if review.Route == "replan" { + answer, planErr := a.runPlannedAgent(ctx, storeUserID, userID, lang, fmt.Sprintf("Original user request:\n%s\n\nPrevious skill outcome JSON:\n%s", text, mustMarshalJSON(outcome)), onEvent) + return answer, true, planErr + } + + answer := strings.TrimSpace(review.Answer) + if answer == "" { + answer = strings.TrimSpace(outcome.UserMessage) + } + if answer == "" { + return "", false, nil } a.recordSkillInteraction(userID, text, answer) @@ -113,7 +141,7 @@ Return JSON with this exact shape: onEvent(StreamEventTool, label) onEvent(StreamEventDelta, answer) } - return answer, true + return answer, true, nil } func parseLLMSkillRouteDecision(raw string) (llmSkillRouteDecision, error) { @@ -140,43 +168,125 @@ func parseLLMSkillRouteDecision(raw string) (llmSkillRouteDecision, error) { func normalizeLLMSkillRouteDecision(decision llmSkillRouteDecision) llmSkillRouteDecision { decision.Route = strings.TrimSpace(strings.ToLower(decision.Route)) decision.Skill = strings.TrimSpace(strings.ToLower(decision.Skill)) - decision.Action = strings.TrimSpace(strings.ToLower(decision.Action)) decision.Filter = strings.TrimSpace(strings.ToLower(decision.Filter)) + if decision.Action == "query" && decision.Filter == "running_only" && decision.Skill == "trader_management" { + decision.Action = "query_running" + } else { + decision.Action = normalizeAtomicSkillAction(decision.Skill, decision.Action) + } return decision } -func (a *Agent) executeLLMSkillRoute(storeUserID string, userID int64, lang, text string, decision llmSkillRouteDecision) (string, bool) { +func (a *Agent) executeLLMSkillRoute(storeUserID string, userID int64, lang, text string, decision llmSkillRouteDecision) (skillOutcome, bool) { session := skillSession{Name: decision.Skill, Action: decision.Action} switch decision.Skill { case "trader_management": if decision.Action == "create" { - return a.handleCreateTraderSkill(storeUserID, userID, lang, text, session) + answer, handled := a.handleCreateTraderSkill(storeUserID, userID, lang, text, session) + if !handled { + return skillOutcome{}, false + } + return inferSkillOutcome(decision.Skill, decision.Action, answer, a.getSkillSession(userID), skillDataForAction(storeUserID, decision.Skill, decision.Action, a)), true } answer, handled := a.handleTraderManagementSkill(storeUserID, userID, lang, text, session) - if handled && decision.Action == "query" { - return applyTraderQueryFilter(lang, answer, a.toolListTraders(storeUserID), decision.Filter), true + if handled && decision.Action == "query_running" { + answer = applyTraderQueryFilter(lang, answer, a.toolListTraders(storeUserID), "running_only") } - return answer, handled + if !handled { + return skillOutcome{}, false + } + return inferSkillOutcome(decision.Skill, decision.Action, answer, a.getSkillSession(userID), skillDataForAction(storeUserID, decision.Skill, decision.Action, a)), true case "exchange_management": - return a.handleExchangeManagementSkill(storeUserID, userID, lang, text, session) + answer, handled := a.handleExchangeManagementSkill(storeUserID, userID, lang, text, session) + if !handled { + return skillOutcome{}, false + } + return inferSkillOutcome(decision.Skill, decision.Action, answer, a.getSkillSession(userID), skillDataForAction(storeUserID, decision.Skill, decision.Action, a)), true case "model_management": - return a.handleModelManagementSkill(storeUserID, userID, lang, text, session) + answer, handled := a.handleModelManagementSkill(storeUserID, userID, lang, text, session) + if !handled { + return skillOutcome{}, false + } + return inferSkillOutcome(decision.Skill, decision.Action, answer, a.getSkillSession(userID), skillDataForAction(storeUserID, decision.Skill, decision.Action, a)), true case "strategy_management": - return a.handleStrategyManagementSkill(storeUserID, userID, lang, text, session) + answer, handled := a.handleStrategyManagementSkill(storeUserID, userID, lang, text, session) + if !handled { + return skillOutcome{}, false + } + return inferSkillOutcome(decision.Skill, decision.Action, answer, a.getSkillSession(userID), skillDataForAction(storeUserID, decision.Skill, decision.Action, a)), true case "model_diagnosis": - return a.handleModelDiagnosisSkill(storeUserID, lang, text), true + return skillOutcome{ + Skill: decision.Skill, + Action: defaultIfEmpty(decision.Action, "diagnose"), + Status: skillOutcomeSuccess, + GoalAchieved: true, + UserMessage: a.handleModelDiagnosisSkill(storeUserID, lang, text), + }, true case "exchange_diagnosis": - return a.handleExchangeDiagnosisSkill(storeUserID, lang, text), true + return skillOutcome{ + Skill: decision.Skill, + Action: defaultIfEmpty(decision.Action, "diagnose"), + Status: skillOutcomeSuccess, + GoalAchieved: true, + UserMessage: a.handleExchangeDiagnosisSkill(storeUserID, lang, text), + }, true case "trader_diagnosis": - return a.handleTraderDiagnosisSkill(storeUserID, lang, text), true + return skillOutcome{ + Skill: decision.Skill, + Action: defaultIfEmpty(decision.Action, "diagnose"), + Status: skillOutcomeSuccess, + GoalAchieved: true, + UserMessage: a.handleTraderDiagnosisSkill(storeUserID, lang, text), + }, true case "strategy_diagnosis": - return a.handleStrategyDiagnosisSkill(storeUserID, lang, text), true + return skillOutcome{ + Skill: decision.Skill, + Action: defaultIfEmpty(decision.Action, "diagnose"), + Status: skillOutcomeSuccess, + GoalAchieved: true, + UserMessage: a.handleStrategyDiagnosisSkill(storeUserID, lang, text), + }, true default: - return "", false + return skillOutcome{}, false } } +func skillDataForAction(storeUserID, skill, action string, a *Agent) map[string]any { + var raw string + switch skill { + case "trader_management": + if strings.HasPrefix(action, "query") { + raw = a.toolListTraders(storeUserID) + } + case "exchange_management": + if strings.HasPrefix(action, "query") { + raw = a.toolGetExchangeConfigs(storeUserID) + } + case "model_management": + if strings.HasPrefix(action, "query") { + raw = a.toolGetModelConfigs(storeUserID) + } + case "strategy_management": + if strings.HasPrefix(action, "query") { + raw = a.toolGetStrategies(storeUserID) + } + } + if strings.TrimSpace(raw) == "" { + return nil + } + var data map[string]any + if err := json.Unmarshal([]byte(raw), &data); err != nil { + return nil + } + return data +} + +func mustMarshalJSON(v any) string { + data, _ := json.Marshal(v) + return string(data) +} + func applyTraderQueryFilter(lang, fallback, raw, filter string) string { filter = strings.TrimSpace(strings.ToLower(filter)) if filter == "" { diff --git a/agent/memory.go b/agent/memory.go index 01142985..4b274648 100644 --- a/agent/memory.go +++ b/agent/memory.go @@ -299,7 +299,7 @@ func (a *Agent) maybeCompressHistory(ctx context.Context, userID int64) { return } if err := a.saveTaskState(userID, updatedState); err != nil { - a.logger.Warn("failed to persist task state", "error", err, "user_id", userID) + a.log().Warn("failed to persist task state", "error", err, "user_id", userID) return } a.history.Replace(userID, recentPart) @@ -323,11 +323,11 @@ func (a *Agent) maybeUpdateTaskStateIncrementally(ctx context.Context, userID in existingState := a.getTaskState(userID) updatedState, err := a.summarizeRecentConversationToTaskState(ctx, userID, existingState, window) if err != nil { - a.logger.Warn("failed to incrementally update task state", "error", err, "user_id", userID) + a.log().Warn("failed to incrementally update task state", "error", err, "user_id", userID) return } if err := a.saveTaskState(userID, updatedState); err != nil { - a.logger.Warn("failed to persist incremental task state", "error", err, "user_id", userID) + a.log().Warn("failed to persist incremental task state", "error", err, "user_id", userID) } } @@ -383,7 +383,7 @@ Rules: return TaskState{}, err } state = normalizeTaskState(state) - a.logger.Info("compressed chat history into task state", "user_id", userID, "archived_messages", len(oldPart)) + a.log().Info("compressed chat history into task state", "user_id", userID, "archived_messages", len(oldPart)) return state, nil } @@ -436,7 +436,7 @@ Rules: return TaskState{}, err } state = normalizeTaskState(state) - a.logger.Info("incrementally refreshed task state", "user_id", userID, "window_messages", len(recentPart)) + a.log().Info("incrementally refreshed task state", "user_id", userID, "window_messages", len(recentPart)) return state, nil } diff --git a/agent/planner_runtime.go b/agent/planner_runtime.go index 9bec1150..5db56a64 100644 --- a/agent/planner_runtime.go +++ b/agent/planner_runtime.go @@ -198,16 +198,6 @@ func isRealtimeAccountIntent(text string) bool { func snapshotKindsForIntent(userText string) []string { kinds := make([]string, 0, 6) - if isConfigOrTraderIntent(userText) { - kinds = append(kinds, - "current_model_configs", - "current_exchange_configs", - "current_traders", - ) - } - if isStrategyIntent(userText) { - kinds = append(kinds, "current_strategies") - } return uniqueStrings(kinds) } @@ -756,18 +746,18 @@ func (a *Agent) thinkAndAct(ctx context.Context, storeUserID string, userID int6 if answer, ok, err := a.tryStatePriorityPath(ctx, storeUserID, userID, lang, text, nil); ok || err != nil { return answer, err } - if answer, ok := a.tryDirectAnswer(ctx, userID, lang, text, nil); ok { - return answer, nil - } - if answer, ok := a.tryLLMSkillRoute(ctx, storeUserID, userID, lang, text, nil); ok { - return answer, nil - } - if answer, ok := a.tryHardSkill(ctx, storeUserID, userID, lang, text, nil); ok { + if answer, ok := tryInstantDirectReply(lang, text); ok { return answer, nil } if answer, ok := a.tryReadFastPath(storeUserID, userID, lang, text); ok { return answer, nil } + if answer, ok, err := a.tryWorkflowIntent(ctx, storeUserID, userID, lang, text, nil); ok || err != nil { + return answer, err + } + if answer, ok := a.tryHardSkill(ctx, storeUserID, userID, lang, text, nil); ok { + return answer, nil + } if a.aiClient == nil { return a.noAIFallback(lang, text) } @@ -778,13 +768,10 @@ func (a *Agent) thinkAndActStream(ctx context.Context, storeUserID string, userI if answer, ok, err := a.tryStatePriorityPath(ctx, storeUserID, userID, lang, text, onEvent); ok || err != nil { return answer, err } - if answer, ok := a.tryDirectAnswer(ctx, userID, lang, text, onEvent); ok { - return answer, nil - } - if answer, ok := a.tryLLMSkillRoute(ctx, storeUserID, userID, lang, text, onEvent); ok { - return answer, nil - } - if answer, ok := a.tryHardSkill(ctx, storeUserID, userID, lang, text, onEvent); ok { + if answer, ok := tryInstantDirectReply(lang, text); ok { + if onEvent != nil { + onEvent(StreamEventDelta, answer) + } return answer, nil } if answer, ok := a.tryReadFastPath(storeUserID, userID, lang, text); ok { @@ -794,12 +781,65 @@ func (a *Agent) thinkAndActStream(ctx context.Context, storeUserID string, userI } return answer, nil } + if answer, ok, err := a.tryWorkflowIntent(ctx, storeUserID, userID, lang, text, onEvent); ok || err != nil { + return answer, err + } + if answer, ok := a.tryHardSkill(ctx, storeUserID, userID, lang, text, onEvent); ok { + return answer, nil + } if a.aiClient == nil { return a.noAIFallback(lang, text) } return a.runPlannedAgent(ctx, storeUserID, userID, lang, text, onEvent) } +func tryInstantDirectReply(lang, text string) (string, bool) { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return "", false + } + + zhReplies := map[string]string{ + "hi": "在,有什么我帮你看的?", + "hello": "在,有什么我帮你看的?", + "hey": "在,有什么我帮你看的?", + "你好": "在,有什么我帮你看的?", + "嗨": "在,有什么我帮你看的?", + "在吗": "在,有什么我帮你看的?", + "谢谢": "不客气。", + "多谢": "不客气。", + "谢了": "不客气。", + "ok": "好。", + "好的": "好。", + "收到": "好。", + } + enReplies := map[string]string{ + "hi": "I'm here. What should we look at?", + "hello": "I'm here. What should we look at?", + "hey": "I'm here. What should we look at?", + "thanks": "You're welcome.", + "thank you": "You're welcome.", + "ok": "Okay.", + "okay": "Okay.", + "got it": "Got it.", + } + + if lang == "zh" { + if reply, ok := zhReplies[lower]; ok { + return reply, true + } + if reply, ok := enReplies[lower]; ok { + return reply, true + } + return "", false + } + + if reply, ok := enReplies[lower]; ok { + return reply, true + } + return "", false +} + func (a *Agent) hasActiveSkillSession(userID int64) bool { session := a.getSkillSession(userID) return strings.TrimSpace(session.Name) != "" @@ -818,21 +858,335 @@ func hasActiveExecutionState(state ExecutionState) bool { } func (a *Agent) tryStatePriorityPath(ctx context.Context, storeUserID string, userID int64, lang, text string, onEvent func(event, data string)) (string, bool, error) { - if a.hasActiveSkillSession(userID) { - if answer, ok := a.tryHardSkill(ctx, storeUserID, userID, lang, text, onEvent); ok { - return answer, true, nil + if workflow := a.getWorkflowSession(userID); hasActiveWorkflowSession(workflow) { + answer, handled, err := a.handleWorkflowSession(ctx, storeUserID, userID, lang, text, workflow, onEvent) + if handled || err != nil { + return answer, true, err + } + } + if session := a.getSkillSession(userID); strings.TrimSpace(session.Name) != "" { + switch a.classifySkillSessionInput(ctx, userID, lang, session, text) { + case "cancel": + a.clearSkillSession(userID) + a.clearWorkflowSession(userID) + if lang == "zh" { + return "已取消当前流程。", true, nil + } + return "Cancelled the current flow.", true, nil + case "interrupt": + a.clearSkillSession(userID) + default: + if answer, ok := a.tryHardSkill(ctx, storeUserID, userID, lang, text, onEvent); ok { + return answer, true, nil + } } } state := a.getExecutionState(userID) if hasActiveExecutionState(state) { - answer, err := a.runPlannedAgent(ctx, storeUserID, userID, lang, text, onEvent) - return answer, true, err + switch classifyExecutionStateInput(state, text) { + case "cancel": + a.clearExecutionState(userID) + if lang == "zh" { + return "已取消当前流程。", true, nil + } + return "Cancelled the current flow.", true, nil + case "interrupt": + a.clearExecutionState(userID) + default: + answer, err := a.runPlannedAgent(ctx, storeUserID, userID, lang, text, onEvent) + return answer, true, err + } } return "", false, nil } +func (a *Agent) classifySkillSessionInput(ctx context.Context, userID int64, lang string, session skillSession, text string) string { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return "continue" + } + if isYesReply(text) || isNoReply(text) { + return "continue" + } + if isExplicitFlowAbort(text) { + return "cancel" + } + if shouldContinueSkillSessionByExpectedSlot(session, text) { + return "continue" + } + if decision := a.classifySkillSessionIntentWithLLM(ctx, userID, lang, session, text); decision != "" { + return decision + } + if isNewSkillRootIntent(session, text) { + return "interrupt" + } + if isSkillFlowDeflection(session, text) { + return "interrupt" + } + if belongsToSkillDomain(session.Name, text) || !looksLikeNewTopLevelIntent(text) { + return "continue" + } + return "interrupt" +} + +type skillSessionIntentDecision struct { + Decision string `json:"decision"` +} + +func shouldUseLLMSkillSessionClassifier(session skillSession, text string) bool { + if strings.TrimSpace(text) == "" { + return false + } + if isExplicitFlowAbort(text) || isYesReply(text) || isNoReply(text) { + return false + } + if shouldContinueSkillSessionByExpectedSlot(session, text) { + return false + } + return true +} + +func shouldContinueSkillSessionByExpectedSlot(session skillSession, text string) bool { + text = strings.TrimSpace(text) + if text == "" { + return false + } + currentStep, ok := currentSkillDAGStep(session) + if !ok { + return false + } + switch currentStep.ID { + case "await_start_confirmation", "await_confirmation": + return isYesReply(text) || isNoReply(text) + case "resolve_config_value": + if fieldValue(session, "config_field") == "selected_timeframes" { + return timeframeTokenRE.MatchString(strings.ToLower(text)) + } + return firstIntegerPattern.MatchString(text) + case "collect_enabled": + _, ok := parseEnabledValue(text) + return ok + case "collect_custom_api_url": + return extractURL(text) != "" + case "resolve_exchange_type": + return exchangeTypeFromText(text) != "" + case "resolve_provider": + return providerFromText(text) != "" + case "resolve_name", "collect_name", "collect_prompt", "collect_account_name", "collect_custom_model_name": + return !looksLikeNewTopLevelIntent(text) + } + for _, field := range currentStep.RequiredFields { + switch field { + case "config_value": + return firstIntegerPattern.MatchString(text) + case "enabled": + _, ok := parseEnabledValue(text) + return ok + case "custom_api_url": + return extractURL(text) != "" + } + } + return false +} + +func (a *Agent) classifySkillSessionIntentWithLLM(ctx context.Context, userID int64, lang string, session skillSession, text string) string { + if a == nil || a.aiClient == nil { + return "" + } + if !shouldUseLLMSkillSessionClassifier(session, text) { + return "" + } + currentStep, _ := currentSkillDAGStep(session) + recentConversationCtx := a.buildRecentConversationContext(userID, text) + systemPrompt := `You classify one user message while a NOFXi structured management flow is active. +Return JSON only. No markdown. + +Possible decisions: +- "continue": the user is still answering the current flow +- "cancel": the user wants to stop the current flow +- "interrupt": the user changed topic, wants diagnosis/query/new task, or should leave the current flow + +Be conservative: +- Prefer "continue" only when the message clearly answers the current slot/question. +- Use "cancel" for explicit abandonment like "算了", "不改了", "换话题", "别弄了". +- Use "interrupt" for diagnosis, query, new requests, or topic shifts.` + userPrompt := fmt.Sprintf( + "Language: %s\nActive skill: %s\nAction: %s\nCurrent DAG step: %s\nExpected required fields: %s\nUser message: %s\n\nRecent conversation:\n%s", + lang, + session.Name, + session.Action, + currentStep.ID, + strings.Join(currentStep.RequiredFields, ", "), + text, + recentConversationCtx, + ) + stageCtx, cancel := withPlannerStageTimeout(ctx, directReplyTimeout) + defer cancel() + raw, err := a.aiClient.CallWithRequest(&mcp.Request{ + Messages: []mcp.Message{ + mcp.NewSystemMessage(systemPrompt), + mcp.NewUserMessage(userPrompt), + }, + Ctx: stageCtx, + }) + if err != nil { + return "" + } + raw = strings.TrimSpace(raw) + raw = strings.TrimPrefix(raw, "```json") + raw = strings.TrimPrefix(raw, "```") + raw = strings.TrimSuffix(raw, "```") + raw = strings.TrimSpace(raw) + var decision skillSessionIntentDecision + if err := json.Unmarshal([]byte(raw), &decision); err != nil { + start := strings.Index(raw, "{") + end := strings.LastIndex(raw, "}") + if start < 0 || end <= start || json.Unmarshal([]byte(raw[start:end+1]), &decision) != nil { + return "" + } + } + switch strings.TrimSpace(decision.Decision) { + case "continue", "cancel", "interrupt": + return decision.Decision + default: + return "" + } +} + +func isSkillFlowDeflection(session skillSession, text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return false + } + if containsAny(lower, []string{ + "看下报错", "看看报错", "帮我看下报错", "帮我看看报错", "报错怎么回事", "错误怎么回事", + "换话题", "聊别的", "不是这个", "先说别的", "不聊这个", + }) { + return true + } + switch strings.TrimSpace(session.Name) { + case "exchange_management": + return detectModelDiagnosisSkill(text) || detectTraderDiagnosisSkill(text) || detectStrategyDiagnosisSkill(text) + case "model_management": + return detectExchangeDiagnosisSkill(text) || detectTraderDiagnosisSkill(text) || detectStrategyDiagnosisSkill(text) + case "strategy_management": + return detectExchangeDiagnosisSkill(text) || detectTraderDiagnosisSkill(text) || detectModelDiagnosisSkill(text) + case "trader_management": + return detectExchangeDiagnosisSkill(text) || detectModelDiagnosisSkill(text) || detectStrategyDiagnosisSkill(text) + default: + return false + } +} + +func isNewSkillRootIntent(session skillSession, text string) bool { + currentSkill := strings.TrimSpace(session.Name) + currentAction := strings.TrimSpace(session.Action) + if currentSkill == "" { + return false + } + switch currentSkill { + case "trader_management": + if detectCreateTraderSkill(text) && currentAction != "create" { + return true + } + if action := normalizeAtomicSkillAction("trader_management", detectManagementAction(text, "trader")); action == "create" && currentAction != "create" { + return true + } + case "strategy_management": + if action := normalizeAtomicSkillAction("strategy_management", detectManagementAction(text, "strategy")); action == "create" && currentAction != "create" { + return true + } + case "model_management": + if action := normalizeAtomicSkillAction("model_management", detectManagementAction(text, "model")); action == "create" && currentAction != "create" { + return true + } + case "exchange_management": + if action := normalizeAtomicSkillAction("exchange_management", detectManagementAction(text, "exchange")); action == "create" && currentAction != "create" { + return true + } + } + return false +} + +func classifyExecutionStateInput(state ExecutionState, text string) string { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return "continue" + } + if isExplicitFlowAbort(text) { + return "cancel" + } + if isYesReply(text) || isNoReply(text) || shouldResetExecutionStateForNewAttempt(text, state) { + return "continue" + } + if state.Waiting != nil && !looksLikeNewTopLevelIntent(text) { + return "continue" + } + if looksLikeNewTopLevelIntent(text) { + return "interrupt" + } + return "continue" +} + +func isExplicitFlowAbort(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return false + } + if isCancelSkillReply(text) { + return true + } + return containsAny(lower, []string{ + "算了", "先不", "不配了", "别弄了", "不搞了", "先停", "换个话题", "换话题", "聊点别的", "聊别的", + "stop this", "drop it", "never mind", "forget it", "skip this", + }) +} + +func belongsToSkillDomain(skillName, text string) bool { + switch strings.TrimSpace(skillName) { + case "trader_management": + return detectCreateTraderSkill(text) || detectTraderManagementIntent(text) || detectTraderDiagnosisSkill(text) + case "strategy_management": + return detectStrategyManagementIntent(text) || detectStrategyDiagnosisSkill(text) + case "model_management": + return detectModelManagementIntent(text) || detectModelDiagnosisSkill(text) + case "exchange_management": + return detectExchangeManagementIntent(text) || detectExchangeDiagnosisSkill(text) + default: + return false + } +} + +func looksLikeNewTopLevelIntent(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return false + } + if strings.HasPrefix(lower, "/") { + return true + } + if detectCreateTraderSkill(text) || + detectTraderManagementIntent(text) || + detectExchangeManagementIntent(text) || + detectModelManagementIntent(text) || + detectStrategyManagementIntent(text) || + detectTraderDiagnosisSkill(text) || + detectExchangeDiagnosisSkill(text) || + detectModelDiagnosisSkill(text) || + detectStrategyDiagnosisSkill(text) { + return true + } + if detectReadFastPath(text) != nil { + return true + } + return containsAny(lower, []string{ + "btc", "eth", "sol", "市场", "行情", "余额", "仓位", "持仓", "订单", "账户", + "price", "market", "balance", "position", "portfolio", "account", + }) +} + func (a *Agent) tryDirectAnswer(ctx context.Context, userID int64, lang, text string, onEvent func(event, data string)) (string, bool) { if a.aiClient == nil { return "", false @@ -948,8 +1302,10 @@ func (a *Agent) runPlannedAgent(ctx context.Context, storeUserID string, userID onEvent(StreamEventPlanning, a.planningStatusText(lang)) } + requestStartedAt := time.Now() state, err := a.prepareExecutionState(ctx, storeUserID, userID, lang, text) if err != nil { + a.logPlannerTiming("", userID, "prepare_execution_state", requestStartedAt, err) if isPlannerTimeoutError(err) { msg := plannerTimeoutMessage(lang) if onEvent != nil { @@ -961,8 +1317,11 @@ func (a *Agent) runPlannedAgent(ctx context.Context, storeUserID string, userID a.logger.Warn("planner failed, falling back to legacy loop", "error", err, "user_id", userID) return a.thinkAndActLegacy(ctx, userID, lang, text, onEvent) } + a.logPlannerTiming(state.SessionID, userID, "prepare_execution_state", requestStartedAt, nil) + executionStartedAt := time.Now() answer, err := a.executePlan(ctx, storeUserID, userID, lang, &state, onEvent) + a.logPlannerTiming(state.SessionID, userID, "execute_plan", executionStartedAt, err) if err != nil { if isPlannerTimeoutError(err) { msg := plannerTimeoutMessage(lang) @@ -979,6 +1338,7 @@ func (a *Agent) runPlannedAgent(ctx context.Context, storeUserID string, userID a.history.Add(userID, "assistant", answer) a.maybeUpdateTaskStateIncrementally(ctx, userID) a.maybeCompressHistory(ctx, userID) + a.logPlannerTiming(state.SessionID, userID, "run_planned_agent_total", requestStartedAt, nil) return answer, nil } @@ -1005,12 +1365,7 @@ func (a *Agent) prepareExecutionState(ctx context.Context, storeUserID string, u existing.FinalAnswer = "" existing.LastError = "" existing = a.refreshStateForDynamicRequests(storeUserID, text, existing) - plan, err := a.createExecutionPlan(ctx, userID, lang, text, existing) - if err != nil { - return ExecutionState{}, err - } - existing.Goal = plan.Goal - existing.Steps = plan.Steps + existing.Steps = completedSteps(existing.Steps) existing.CurrentStepID = "" existing.Status = executionStatusRunning existing.UpdatedAt = time.Now().UTC().Format(time.RFC3339) @@ -1023,12 +1378,6 @@ func (a *Agent) prepareExecutionState(ctx context.Context, storeUserID string, u state := newExecutionState(userID, text) a.refreshCurrentReferencesForUserText(storeUserID, text, &state) state = a.refreshStateForDynamicRequests(storeUserID, text, state) - plan, err := a.createExecutionPlan(ctx, userID, lang, text, state) - if err != nil { - return ExecutionState{}, err - } - state.Goal = plan.Goal - state.Steps = plan.Steps state.Status = executionStatusRunning if err := a.saveExecutionState(state); err != nil { return ExecutionState{}, err @@ -1036,6 +1385,114 @@ func (a *Agent) prepareExecutionState(ctx context.Context, storeUserID string, u return state, nil } +type nextStepDecision struct { + Goal string `json:"goal"` + Steps []PlanStep `json:"steps,omitempty"` + Step PlanStep `json:"step"` +} + +func (a *Agent) decideNextStep(ctx context.Context, userID int64, lang string, state ExecutionState) (nextStepDecision, error) { + toolDefs, _ := json.Marshal(agentTools()) + stateJSON, _ := json.Marshal(normalizeExecutionState(state)) + obsJSON, _ := json.Marshal(buildObservationContext(state)) + recentlyFetchedJSON, _ := json.Marshal(buildRecentlyFetchedData(state, time.Now().UTC())) + taskStateCtx := buildTaskStateContext(a.getTaskState(userID)) + recentConversationCtx := a.buildRecentConversationContext(userID, state.Goal) + + systemPrompt := `You are the step selector for NOFXi. +Return JSON only. Do not return markdown. + +You are operating in ReAct mode: Thought -> Action -> Observation. +Choose the immediate next action batch. Do not generate a long multi-step execution plan. + +Allowed step types: +- tool +- reason +- ask_user +- respond + +Rules: +- Use all available memory layers: Execution state JSON, Observations JSON, Recent conversation, and Task state. +- Use Recently fetched data JSON as the deduplication source of truth for fresh tool results. +- Prefer the freshest evidence in this order: execution state, observations, recent conversation, then task state. +- If fresh external or system data is needed, choose a tool step. +- If the user is blocked on a missing parameter, choose ask_user. +- If there is enough information to answer now, choose respond. +- Use reason only when a short intermediate synthesis is necessary before the next action. +- Prefer tool or respond over reason whenever possible. +- Never emit the same reason step twice in a row. +- After a reason step, the next batch should usually be tool, ask_user, or respond. Do not stay in analysis loops. +- Never invent tools. +- If the task needs multiple independent tool reads, emit ALL of them together in one response. +- Parallelism rule: when multiple tool reads are mutually independent, do not split them across turns. Return them together in steps. +- Never mix ask_user/respond with additional steps in the same batch. +- Only emit multiple steps when every emitted step is a tool step. +- Avoid repeated tool calls. If a matching tool call already exists in Recently fetched data and age_seconds <= 60, do not call it again unless the user explicitly asks to refresh. +- For tool steps, set tool_name exactly to one available tool and provide tool_args as a JSON object. +- For ask_user or respond steps, put the user-facing question/response instruction in instruction. +- If the latest observation already answers the goal, prefer respond over another tool call. +- Never place a trade unless the user intent is explicit. + +Return JSON with this exact shape: +{"goal":"","steps":[{"id":"step_1","type":"tool|reason|ask_user|respond","title":"","tool_name":"","tool_args":{},"instruction":"","requires_confirmation":false}]}` + + userPrompt := fmt.Sprintf("Language: %s\nGoal: %s\n\nRecent conversation:\n%s\n\nAvailable tools JSON:\n%s\n\nPersistent preferences:\n%s\n\nTask state:\n%s\n\nExecution state JSON:\n%s\n\nObservations JSON:\n%s\n\nRecently fetched data JSON:\n%s", lang, state.Goal, recentConversationCtx, string(toolDefs), a.buildPersistentPreferencesContext(userID), taskStateCtx, string(stateJSON), string(obsJSON), string(recentlyFetchedJSON)) + + stageCtx, cancel := withPlannerStageTimeout(ctx, plannerCreateTimeout) + defer cancel() + + startedAt := time.Now() + raw, err := a.aiClient.CallWithRequest(&mcp.Request{ + Messages: []mcp.Message{ + mcp.NewSystemMessage(systemPrompt), + mcp.NewUserMessage(userPrompt), + }, + Ctx: stageCtx, + }) + a.logPlannerTiming(state.SessionID, userID, "decide_next_step_llm", startedAt, err) + if err != nil { + return nextStepDecision{}, err + } + return parseNextStepDecisionJSON(raw) +} + +func parseNextStepDecisionJSON(raw string) (nextStepDecision, error) { + raw = strings.TrimSpace(raw) + raw = strings.TrimPrefix(raw, "```json") + raw = strings.TrimPrefix(raw, "```") + raw = strings.TrimSuffix(raw, "```") + raw = strings.TrimSpace(raw) + + var decision nextStepDecision + if err := json.Unmarshal([]byte(raw), &decision); err == nil { + return normalizeNextStepDecision(decision), nil + } + start := strings.Index(raw, "{") + end := strings.LastIndex(raw, "}") + if start >= 0 && end > start { + if err := json.Unmarshal([]byte(raw[start:end+1]), &decision); err == nil { + return normalizeNextStepDecision(decision), nil + } + } + return nextStepDecision{}, fmt.Errorf("invalid next step decision json") +} + +func normalizeNextStepDecision(decision nextStepDecision) nextStepDecision { + decision.Goal = strings.TrimSpace(decision.Goal) + steps := decision.Steps + if len(steps) == 0 && decision.Step.Type != "" { + steps = []PlanStep{decision.Step} + } + if len(steps) > 0 { + steps = normalizeExecutionState(ExecutionState{Steps: steps}).Steps + } + decision.Steps = steps + if len(steps) > 0 { + decision.Step = steps[0] + } + return decision +} + func (a *Agent) refreshStateForDynamicRequests(storeUserID, userText string, state ExecutionState) ExecutionState { kinds := snapshotKindsForIntent(userText) if len(kinds) == 0 { @@ -1187,6 +1644,7 @@ Rules: stageCtx, cancel := withPlannerStageTimeout(ctx, plannerCreateTimeout) defer cancel() + startedAt := time.Now() resp, err := a.aiClient.CallWithRequest(&mcp.Request{ Messages: []mcp.Message{ mcp.NewSystemMessage(systemPrompt), @@ -1194,6 +1652,7 @@ Rules: }, Ctx: stageCtx, }) + a.logPlannerTiming(state.SessionID, userID, "create_execution_plan_llm", startedAt, err) if err != nil { return executionPlan{}, err } @@ -1247,28 +1706,63 @@ func parseExecutionPlanJSON(raw string) (executionPlan, error) { } func (a *Agent) executePlan(ctx context.Context, storeUserID string, userID int64, lang string, state *ExecutionState, onEvent func(event, data string)) (string, error) { - if onEvent != nil { + if onEvent != nil && len(state.Steps) > 0 { onEvent(StreamEventPlan, formatPlanStatus(*state, lang)) } for i := 0; i < plannerMaxIterations; i++ { stepIndex := nextPendingStepIndex(state.Steps) if stepIndex < 0 { - finalText, err := a.generateFinalPlanResponse(ctx, userID, lang, *state, "") + decisionStartedAt := time.Now() + decision, err := a.decideNextStep(ctx, userID, lang, *state) + a.logPlannerTiming(state.SessionID, userID, "decide_next_step", decisionStartedAt, err) if err != nil { return "", err } - state.Status = executionStatusCompleted - state.FinalAnswer = finalText - state.CurrentStepID = "" + steps := filterFreshDuplicateToolSteps(decision.Steps, *state, time.Now().UTC()) + if len(steps) == 0 { + appendExecutionLog(state, Observation{ + Kind: "decision_note", + Summary: "Skipped duplicate fresh tool calls from next-step decision", + CreatedAt: time.Now().UTC().Format(time.RFC3339), + }) + state.UpdatedAt = time.Now().UTC().Format(time.RFC3339) + if err := a.saveExecutionState(*state); err != nil { + return "", err + } + continue + } + if hasRepeatedReasonLoop(*state, steps) { + return "", fmt.Errorf("repeated reasoning loop detected") + } + if decision.Goal != "" { + state.Goal = decision.Goal + } + base := len(completedSteps(state.Steps)) + for idx := range steps { + if steps[idx].Type == "" { + return "", fmt.Errorf("next step decision missing step type") + } + if steps[idx].ID == "" { + steps[idx].ID = fmt.Sprintf("step_%d", base+idx+1) + } + if steps[idx].Title == "" { + steps[idx].Title = strings.ReplaceAll(steps[idx].ID, "_", " ") + } + if steps[idx].Status == "" { + steps[idx].Status = planStepStatusPending + } + } + state.Steps = append(completedSteps(state.Steps), steps...) + state.Status = executionStatusRunning state.UpdatedAt = time.Now().UTC().Format(time.RFC3339) if err := a.saveExecutionState(*state); err != nil { return "", err } if onEvent != nil { - onEvent(StreamEventDelta, finalText) + onEvent(StreamEventPlan, formatPlanStatus(*state, lang)) } - return finalText, nil + continue } step := &state.Steps[stepIndex] @@ -1288,7 +1782,9 @@ func (a *Agent) executePlan(ctx context.Context, storeUserID string, userID int6 if onEvent != nil { onEvent(StreamEventTool, step.ToolName) } + stepStartedAt := time.Now() result := a.executePlanTool(ctx, storeUserID, userID, lang, *step) + a.logPlannerTiming(state.SessionID, userID, "tool:"+step.ToolName, stepStartedAt, nil) summary := summarizeObservation(result) referencesChanged := false step.Status = planStepStatusCompleted @@ -1301,29 +1797,11 @@ func (a *Agent) executePlan(ctx context.Context, storeUserID string, userID int6 CreatedAt: time.Now().UTC().Format(time.RFC3339), }) referencesChanged = updateCurrentReferencesFromToolResult(state, step.ToolName, result) - if shouldAttemptReplan(*state, *step, referencesChanged) { - state.UpdatedAt = time.Now().UTC().Format(time.RFC3339) - if err := a.saveExecutionState(*state); err != nil { - return "", err - } - if onEvent != nil { - onEvent(StreamEventStepComplete, formatStepCompleteStatus(*step, lang)) - } - decision, err := a.replanAfterStep(ctx, userID, lang, *state, *step) - if err == nil && applyReplannerDecision(state, decision) { - state.UpdatedAt = time.Now().UTC().Format(time.RFC3339) - if err := a.saveExecutionState(*state); err != nil { - return "", err - } - if onEvent != nil { - onEvent(StreamEventReplan, formatReplanStatus(decision, lang)) - onEvent(StreamEventPlan, formatPlanStatus(*state, lang)) - } - } - continue - } + _ = referencesChanged case planStepTypeReason: + reasonStartedAt := time.Now() reasoning, err := a.executeReasonStep(ctx, userID, lang, state.Goal, *state, *step) + a.logPlannerTiming(state.SessionID, userID, "reason_step", reasonStartedAt, err) if err != nil { step.Status = planStepStatusFailed step.Error = err.Error() @@ -1364,7 +1842,9 @@ func (a *Agent) executePlan(ctx context.Context, storeUserID string, userID int6 } return question, nil case planStepTypeRespond: + respondStartedAt := time.Now() finalText, err := a.generateFinalPlanResponse(ctx, userID, lang, *state, step.Instruction) + a.logPlannerTiming(state.SessionID, userID, "respond_step", respondStartedAt, err) if err != nil { return "", err } @@ -1399,6 +1879,134 @@ func (a *Agent) executePlan(ctx context.Context, storeUserID string, userID int6 return "", fmt.Errorf("plan execution exceeded iteration limit") } +type fetchedToolRecord struct { + ToolName string `json:"tool_name"` + ToolArgsJSON string `json:"tool_args_json"` + FetchedAt string `json:"fetched_at"` + AgeSeconds int64 `json:"age_seconds"` +} + +func buildRecentlyFetchedData(state ExecutionState, now time.Time) []fetchedToolRecord { + state = normalizeExecutionState(state) + stepByID := make(map[string]PlanStep, len(state.Steps)) + for _, step := range state.Steps { + stepByID[step.ID] = step + } + latest := map[string]fetchedToolRecord{} + for _, obs := range state.ExecutionLog { + if obs.Kind != "tool_result" { + continue + } + step, ok := stepByID[obs.StepID] + if !ok || step.ToolName == "" { + continue + } + sig := toolCallSignature(step.ToolName, step.ToolArgs) + createdAt := parseRFC3339(obs.CreatedAt) + record := fetchedToolRecord{ + ToolName: step.ToolName, + ToolArgsJSON: toolArgsJSONString(step.ToolArgs), + FetchedAt: obs.CreatedAt, + AgeSeconds: int64(now.Sub(createdAt).Seconds()), + } + prev, exists := latest[sig] + if !exists || prev.FetchedAt < record.FetchedAt { + latest[sig] = record + } + } + out := make([]fetchedToolRecord, 0, len(latest)) + for _, record := range latest { + if record.AgeSeconds < 0 { + record.AgeSeconds = 0 + } + out = append(out, record) + } + return out +} + +func filterFreshDuplicateToolSteps(steps []PlanStep, state ExecutionState, now time.Time) []PlanStep { + if len(steps) == 0 { + return nil + } + fresh := make(map[string]struct{}) + for _, item := range buildRecentlyFetchedData(state, now) { + if item.AgeSeconds <= 60 { + fresh[item.ToolName+"|"+item.ToolArgsJSON] = struct{}{} + } + } + out := make([]PlanStep, 0, len(steps)) + for _, step := range steps { + if step.Type != planStepTypeTool { + out = append(out, step) + continue + } + sig := toolCallSignature(step.ToolName, step.ToolArgs) + if _, ok := fresh[sig]; ok { + continue + } + fresh[sig] = struct{}{} + out = append(out, step) + } + return out +} + +func hasRepeatedReasonLoop(state ExecutionState, steps []PlanStep) bool { + if len(steps) == 0 { + return false + } + last := lastCompletedStep(state.Steps) + if last == nil || last.Type != planStepTypeReason { + return false + } + for _, step := range steps { + if step.Type != planStepTypeReason { + return false + } + if stepSemanticKey(*last) != stepSemanticKey(step) { + return false + } + } + return true +} + +func lastCompletedStep(steps []PlanStep) *PlanStep { + for i := len(steps) - 1; i >= 0; i-- { + if steps[i].Status == planStepStatusCompleted { + return &steps[i] + } + } + return nil +} + +func stepSemanticKey(step PlanStep) string { + return strings.ToLower(strings.TrimSpace( + step.Type + "|" + step.ToolName + "|" + step.Title + "|" + step.Instruction, + )) +} + +func toolCallSignature(toolName string, args map[string]any) string { + return strings.TrimSpace(toolName) + "|" + toolArgsJSONString(args) +} + +func toolArgsJSONString(args map[string]any) string { + if len(args) == 0 { + return "{}" + } + data, err := json.Marshal(args) + if err != nil { + return "{}" + } + return string(data) +} + +func parseRFC3339(value string) time.Time { + t, err := time.Parse(time.RFC3339, strings.TrimSpace(value)) + if err != nil { + return time.Time{} + } + return t +} + func (a *Agent) replanAfterStep(ctx context.Context, userID int64, lang string, state ExecutionState, completedStep PlanStep) (replannerDecision, error) { obsJSON, _ := json.Marshal(buildObservationContext(state)) stepsJSON, _ := json.Marshal(state.Steps) @@ -1426,6 +2034,7 @@ Rules: stageCtx, cancel := withPlannerStageTimeout(ctx, plannerReplanTimeout) defer cancel() + startedAt := time.Now() raw, err := a.aiClient.CallWithRequest(&mcp.Request{ Messages: []mcp.Message{ mcp.NewSystemMessage(systemPrompt), @@ -1434,6 +2043,7 @@ Rules: Ctx: stageCtx, MaxTokens: intPtr(500), }) + a.logPlannerTiming(state.SessionID, userID, "replan_after_step_llm", startedAt, err) if err != nil { return replannerDecision{}, err } @@ -1688,6 +2298,7 @@ func (a *Agent) executeReasonStep(ctx context.Context, userID int64, lang, goal stageCtx, cancel := withPlannerStageTimeout(ctx, plannerReasonTimeout) defer cancel() + startedAt := time.Now() resp, err := a.aiClient.CallWithRequest(&mcp.Request{ Messages: []mcp.Message{ mcp.NewSystemMessage("You are the reasoning module for NOFXi. Return one short paragraph only. No markdown, no bullet list."), @@ -1695,6 +2306,7 @@ func (a *Agent) executeReasonStep(ctx context.Context, userID int64, lang, goal }, Ctx: stageCtx, }) + a.logPlannerTiming(state.SessionID, userID, "reason_step_llm", startedAt, err) if err != nil { return "", err } @@ -1709,7 +2321,8 @@ func (a *Agent) generateFinalPlanResponse(ctx context.Context, userID int64, lan } stageCtx, cancel := withPlannerStageTimeout(ctx, plannerFinalTimeout) defer cancel() - return a.aiClient.CallWithRequest(&mcp.Request{ + startedAt := time.Now() + resp, err := a.aiClient.CallWithRequest(&mcp.Request{ Messages: []mcp.Message{ mcp.NewSystemMessage(systemPrompt), mcp.NewSystemMessage("You are responding after a completed execution plan. Use the observations as the source of truth. Be concise and actionable."), @@ -1717,6 +2330,24 @@ func (a *Agent) generateFinalPlanResponse(ctx context.Context, userID int64, lan }, Ctx: stageCtx, }) + a.logPlannerTiming(state.SessionID, userID, "generate_final_response_llm", startedAt, err) + return resp, err +} + +func (a *Agent) logPlannerTiming(sessionID string, userID int64, stage string, startedAt time.Time, err error) { + if stage == "" || startedAt.IsZero() { + return + } + attrs := []any{ + "session_id", sessionID, + "user_id", userID, + "stage", stage, + "elapsed_ms", time.Since(startedAt).Milliseconds(), + } + if err != nil { + attrs = append(attrs, "error", err.Error()) + } + a.log().Info("planner timing", attrs...) } func nextPendingStepIndex(steps []PlanStep) int { diff --git a/agent/planner_runtime_state_test.go b/agent/planner_runtime_state_test.go index 4d61eb78..ed1b08da 100644 --- a/agent/planner_runtime_state_test.go +++ b/agent/planner_runtime_state_test.go @@ -617,6 +617,39 @@ func TestThinkAndActPrioritizesActiveExecutionStateOverDirectReply(t *testing.T) } } +func TestThinkAndActInterruptsWaitingExecutionStateForNewTopic(t *testing.T) { + a := newTestAgentWithStore(t) + a.history = newChatHistory(10) + + _ = a.toolManageStrategy("user-1", `{ + "action":"create", + "name":"激进", + "lang":"zh" + }`) + + userID := int64(91) + state := newExecutionState(userID, "创建交易员") + state.Status = executionStatusWaitingUser + state.Waiting = &WaitingState{ + Question: "请告诉我交易员名称", + PendingFields: []string{"name"}, + } + if err := a.saveExecutionState(state); err != nil { + t.Fatalf("saveExecutionState() error = %v", err) + } + + resp, err := a.thinkAndAct(context.Background(), "user-1", userID, "zh", "列出我当前的策略") + if err != nil { + t.Fatalf("thinkAndAct() error = %v", err) + } + if !strings.Contains(resp, "当前策略") || !strings.Contains(resp, "激进") { + t.Fatalf("expected new topic to be handled, got %q", resp) + } + if got := a.getExecutionState(userID); got.SessionID != "" { + t.Fatalf("expected execution state to be cleared, got %+v", got) + } +} + func TestCreateExecutionPlanIncludesRecentConversation(t *testing.T) { client := &capturePlannerAIClient{} a := &Agent{ diff --git a/agent/skill_dag.go b/agent/skill_dag.go new file mode 100644 index 00000000..ad026115 --- /dev/null +++ b/agent/skill_dag.go @@ -0,0 +1,277 @@ +package agent + +import "strings" + +type SkillDAG struct { + SkillName string + Action string + Steps []SkillDAGStep +} + +type SkillDAGStep struct { + ID string + Kind string + RequiredFields []string + OptionalFields []string + Next []string + Terminal bool +} + +var skillDAGRegistry = buildSkillDAGRegistry() + +func buildSkillDAGRegistry() map[string]SkillDAG { + dags := []SkillDAG{ + { + SkillName: "trader_management", + Action: "create", + Steps: []SkillDAGStep{ + {ID: "resolve_name", Kind: "collect_slot", RequiredFields: []string{"name"}, Next: []string{"resolve_exchange"}}, + {ID: "resolve_exchange", Kind: "collect_slot", RequiredFields: []string{"exchange_id"}, OptionalFields: []string{"exchange_name"}, Next: []string{"resolve_model"}}, + {ID: "resolve_model", Kind: "collect_slot", RequiredFields: []string{"model_id"}, OptionalFields: []string{"model_name"}, Next: []string{"resolve_strategy"}}, + {ID: "resolve_strategy", Kind: "collect_slot", RequiredFields: []string{"strategy_id"}, OptionalFields: []string{"strategy_name"}, Next: []string{"maybe_confirm_start"}}, + {ID: "maybe_confirm_start", Kind: "branch", OptionalFields: []string{"auto_start"}, Next: []string{"await_start_confirmation", "execute_create_only"}}, + {ID: "await_start_confirmation", Kind: "confirm", RequiredFields: []string{"auto_start"}, Next: []string{"execute_create_and_start", "execute_create_only"}}, + {ID: "execute_create_only", Kind: "execute", RequiredFields: []string{"name", "exchange_id", "model_id", "strategy_id"}, Terminal: true}, + {ID: "execute_create_and_start", Kind: "execute", RequiredFields: []string{"name", "exchange_id", "model_id", "strategy_id"}, OptionalFields: []string{"auto_start"}, Terminal: true}, + }, + }, + { + SkillName: "trader_management", + Action: "update_name", + Steps: []SkillDAGStep{ + {ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"collect_name"}}, + {ID: "collect_name", Kind: "collect_slot", RequiredFields: []string{"name"}, Next: []string{"execute_update"}}, + {ID: "execute_update", Kind: "execute", RequiredFields: []string{"target_ref", "name"}, Terminal: true}, + }, + }, + { + SkillName: "trader_management", + Action: "update_bindings", + Steps: []SkillDAGStep{ + {ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"collect_bindings"}}, + {ID: "collect_bindings", Kind: "collect_slot", RequiredFields: []string{"binding_update"}, OptionalFields: []string{"ai_model_id", "exchange_id", "strategy_id"}, Next: []string{"execute_update"}}, + {ID: "execute_update", Kind: "execute", RequiredFields: []string{"target_ref", "binding_update"}, OptionalFields: []string{"ai_model_id", "exchange_id", "strategy_id"}, Terminal: true}, + }, + }, + { + SkillName: "trader_management", + Action: "start", + Steps: []SkillDAGStep{ + {ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"await_confirmation"}}, + {ID: "await_confirmation", Kind: "confirm", RequiredFields: []string{"target_ref"}, Next: []string{"execute_start"}}, + {ID: "execute_start", Kind: "execute", RequiredFields: []string{"target_ref"}, Terminal: true}, + }, + }, + { + SkillName: "trader_management", + Action: "stop", + Steps: []SkillDAGStep{ + {ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"await_confirmation"}}, + {ID: "await_confirmation", Kind: "confirm", RequiredFields: []string{"target_ref"}, Next: []string{"execute_stop"}}, + {ID: "execute_stop", Kind: "execute", RequiredFields: []string{"target_ref"}, Terminal: true}, + }, + }, + { + SkillName: "trader_management", + Action: "delete", + Steps: []SkillDAGStep{ + {ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"await_confirmation"}}, + {ID: "await_confirmation", Kind: "confirm", RequiredFields: []string{"target_ref"}, Next: []string{"execute_delete"}}, + {ID: "execute_delete", Kind: "execute", RequiredFields: []string{"target_ref"}, Terminal: true}, + }, + }, + { + SkillName: "strategy_management", + Action: "create", + Steps: []SkillDAGStep{ + {ID: "resolve_name", Kind: "collect_slot", RequiredFields: []string{"name"}, OptionalFields: []string{"lang", "description", "config"}, Next: []string{"execute_create"}}, + {ID: "execute_create", Kind: "execute", RequiredFields: []string{"name"}, OptionalFields: []string{"lang", "description", "config"}, Terminal: true}, + }, + }, + { + SkillName: "strategy_management", + Action: "update_name", + Steps: []SkillDAGStep{ + {ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"collect_name"}}, + {ID: "collect_name", Kind: "collect_slot", RequiredFields: []string{"name"}, Next: []string{"execute_update"}}, + {ID: "execute_update", Kind: "execute", RequiredFields: []string{"target_ref", "name"}, Terminal: true}, + }, + }, + { + SkillName: "strategy_management", + Action: "update_prompt", + Steps: []SkillDAGStep{ + {ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"collect_prompt"}}, + {ID: "collect_prompt", Kind: "collect_slot", RequiredFields: []string{"prompt"}, Next: []string{"load_config"}}, + {ID: "load_config", Kind: "load_state", RequiredFields: []string{"target_ref"}, Next: []string{"execute_update"}}, + {ID: "execute_update", Kind: "execute", RequiredFields: []string{"target_ref", "prompt"}, Terminal: true}, + }, + }, + { + SkillName: "strategy_management", + Action: "update_config", + Steps: []SkillDAGStep{ + {ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"resolve_config_field"}}, + {ID: "resolve_config_field", Kind: "collect_slot", RequiredFields: []string{"config_field"}, Next: []string{"resolve_config_value"}}, + {ID: "resolve_config_value", Kind: "collect_slot", RequiredFields: []string{"config_value"}, Next: []string{"load_config"}}, + {ID: "load_config", Kind: "load_state", RequiredFields: []string{"target_ref"}, Next: []string{"apply_field_update"}}, + {ID: "apply_field_update", Kind: "transform", RequiredFields: []string{"config_field", "config_value"}, Next: []string{"execute_update"}}, + {ID: "execute_update", Kind: "execute", RequiredFields: []string{"target_ref", "config_field", "config_value"}, Terminal: true}, + }, + }, + { + SkillName: "strategy_management", + Action: "duplicate", + Steps: []SkillDAGStep{ + {ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"collect_name"}}, + {ID: "collect_name", Kind: "collect_slot", RequiredFields: []string{"name"}, Next: []string{"execute_duplicate"}}, + {ID: "execute_duplicate", Kind: "execute", RequiredFields: []string{"target_ref", "name"}, Terminal: true}, + }, + }, + { + SkillName: "strategy_management", + Action: "activate", + Steps: []SkillDAGStep{ + {ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"execute_activate"}}, + {ID: "execute_activate", Kind: "execute", RequiredFields: []string{"target_ref"}, Terminal: true}, + }, + }, + { + SkillName: "strategy_management", + Action: "delete", + Steps: []SkillDAGStep{ + {ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"await_confirmation"}}, + {ID: "await_confirmation", Kind: "confirm", RequiredFields: []string{"target_ref"}, Next: []string{"execute_delete"}}, + {ID: "execute_delete", Kind: "execute", RequiredFields: []string{"target_ref"}, Terminal: true}, + }, + }, + { + SkillName: "model_management", + Action: "create", + Steps: []SkillDAGStep{ + {ID: "resolve_provider", Kind: "collect_slot", RequiredFields: []string{"provider"}, Next: []string{"collect_optional_fields"}}, + {ID: "collect_optional_fields", Kind: "collect_slot", OptionalFields: []string{"name", "custom_api_url", "custom_model_name"}, Next: []string{"execute_create"}}, + {ID: "execute_create", Kind: "execute", RequiredFields: []string{"provider"}, OptionalFields: []string{"name", "custom_api_url", "custom_model_name"}, Terminal: true}, + }, + }, + { + SkillName: "model_management", + Action: "update_status", + Steps: []SkillDAGStep{ + {ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"collect_enabled"}}, + {ID: "collect_enabled", Kind: "collect_slot", RequiredFields: []string{"enabled"}, Next: []string{"execute_update"}}, + {ID: "execute_update", Kind: "execute", RequiredFields: []string{"target_ref", "enabled"}, Terminal: true}, + }, + }, + { + SkillName: "model_management", + Action: "update_endpoint", + Steps: []SkillDAGStep{ + {ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"collect_custom_api_url"}}, + {ID: "collect_custom_api_url", Kind: "collect_slot", RequiredFields: []string{"custom_api_url"}, Next: []string{"execute_update"}}, + {ID: "execute_update", Kind: "execute", RequiredFields: []string{"target_ref", "custom_api_url"}, Terminal: true}, + }, + }, + { + SkillName: "model_management", + Action: "update_name", + Steps: []SkillDAGStep{ + {ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"collect_custom_model_name"}}, + {ID: "collect_custom_model_name", Kind: "collect_slot", RequiredFields: []string{"custom_model_name"}, Next: []string{"execute_update"}}, + {ID: "execute_update", Kind: "execute", RequiredFields: []string{"target_ref", "custom_model_name"}, Terminal: true}, + }, + }, + { + SkillName: "model_management", + Action: "delete", + Steps: []SkillDAGStep{ + {ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"await_confirmation"}}, + {ID: "await_confirmation", Kind: "confirm", RequiredFields: []string{"target_ref"}, Next: []string{"execute_delete"}}, + {ID: "execute_delete", Kind: "execute", RequiredFields: []string{"target_ref"}, Terminal: true}, + }, + }, + { + SkillName: "exchange_management", + Action: "create", + Steps: []SkillDAGStep{ + {ID: "resolve_exchange_type", Kind: "collect_slot", RequiredFields: []string{"exchange_type"}, Next: []string{"collect_account_name"}}, + {ID: "collect_account_name", Kind: "collect_slot", OptionalFields: []string{"account_name"}, Next: []string{"execute_create"}}, + {ID: "execute_create", Kind: "execute", RequiredFields: []string{"exchange_type"}, OptionalFields: []string{"account_name"}, Terminal: true}, + }, + }, + { + SkillName: "exchange_management", + Action: "update_name", + Steps: []SkillDAGStep{ + {ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"collect_account_name"}}, + {ID: "collect_account_name", Kind: "collect_slot", RequiredFields: []string{"account_name"}, Next: []string{"execute_update"}}, + {ID: "execute_update", Kind: "execute", RequiredFields: []string{"target_ref", "account_name"}, Terminal: true}, + }, + }, + { + SkillName: "exchange_management", + Action: "update_status", + Steps: []SkillDAGStep{ + {ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"collect_enabled"}}, + {ID: "collect_enabled", Kind: "collect_slot", RequiredFields: []string{"enabled"}, Next: []string{"execute_update"}}, + {ID: "execute_update", Kind: "execute", RequiredFields: []string{"target_ref", "enabled"}, Terminal: true}, + }, + }, + { + SkillName: "exchange_management", + Action: "delete", + Steps: []SkillDAGStep{ + {ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"await_confirmation"}}, + {ID: "await_confirmation", Kind: "confirm", RequiredFields: []string{"target_ref"}, Next: []string{"execute_delete"}}, + {ID: "execute_delete", Kind: "execute", RequiredFields: []string{"target_ref"}, Terminal: true}, + }, + }, + } + + registry := make(map[string]SkillDAG, len(dags)) + for _, dag := range dags { + dag = normalizeSkillDAG(dag) + if dag.SkillName == "" || dag.Action == "" { + continue + } + registry[skillDAGKey(dag.SkillName, dag.Action)] = dag + } + return registry +} + +func normalizeSkillDAG(dag SkillDAG) SkillDAG { + dag.SkillName = strings.TrimSpace(dag.SkillName) + dag.Action = strings.TrimSpace(dag.Action) + steps := make([]SkillDAGStep, 0, len(dag.Steps)) + for _, step := range dag.Steps { + step.ID = strings.TrimSpace(step.ID) + step.Kind = strings.TrimSpace(step.Kind) + step.RequiredFields = cleanStringList(step.RequiredFields) + step.OptionalFields = cleanStringList(step.OptionalFields) + step.Next = cleanStringList(step.Next) + if step.ID == "" { + continue + } + steps = append(steps, step) + } + dag.Steps = steps + return dag +} + +func skillDAGKey(skillName, action string) string { + return strings.TrimSpace(skillName) + ":" + strings.TrimSpace(action) +} + +func getSkillDAG(skillName, action string) (SkillDAG, bool) { + dag, ok := skillDAGRegistry[skillDAGKey(skillName, action)] + return dag, ok +} + +func listSkillDAGs() []SkillDAG { + out := make([]SkillDAG, 0, len(skillDAGRegistry)) + for _, dag := range skillDAGRegistry { + out = append(out, dag) + } + return out +} + diff --git a/agent/skill_dag_runtime.go b/agent/skill_dag_runtime.go new file mode 100644 index 00000000..8178536c --- /dev/null +++ b/agent/skill_dag_runtime.go @@ -0,0 +1,51 @@ +package agent + +const skillDAGStepField = "_dag_step" + +func currentSkillDAGStep(session skillSession) (SkillDAGStep, bool) { + dag, ok := getSkillDAG(session.Name, session.Action) + if !ok || len(dag.Steps) == 0 { + return SkillDAGStep{}, false + } + stepID := fieldValue(session, skillDAGStepField) + if stepID == "" { + return dag.Steps[0], true + } + for _, step := range dag.Steps { + if step.ID == stepID { + return step, true + } + } + return dag.Steps[0], true +} + +func setSkillDAGStep(session *skillSession, stepID string) { + ensureSkillFields(session) + if stepID == "" { + delete(session.Fields, skillDAGStepField) + return + } + session.Fields[skillDAGStepField] = stepID +} + +func clearSkillDAGStep(session *skillSession) { + if session == nil || session.Fields == nil { + return + } + delete(session.Fields, skillDAGStepField) +} + +func advanceSkillDAGStep(session *skillSession, currentStepID string) { + dag, ok := getSkillDAG(session.Name, session.Action) + if !ok { + return + } + for _, step := range dag.Steps { + if step.ID != currentStepID || len(step.Next) == 0 { + continue + } + setSkillDAGStep(session, step.Next[0]) + return + } +} + diff --git a/agent/skill_dag_runtime_test.go b/agent/skill_dag_runtime_test.go new file mode 100644 index 00000000..8085ceee --- /dev/null +++ b/agent/skill_dag_runtime_test.go @@ -0,0 +1,27 @@ +package agent + +import "testing" + +func TestCurrentSkillDAGStepDefaultsToFirstStep(t *testing.T) { + session := skillSession{Name: "strategy_management", Action: "update_config"} + step, ok := currentSkillDAGStep(session) + if !ok { + t.Fatal("expected dag step") + } + if step.ID != "resolve_target" { + t.Fatalf("expected first step resolve_target, got %s", step.ID) + } +} + +func TestAdvanceSkillDAGStepMovesToNextStep(t *testing.T) { + session := skillSession{Name: "strategy_management", Action: "update_config"} + setSkillDAGStep(&session, "resolve_config_field") + advanceSkillDAGStep(&session, "resolve_config_field") + step, ok := currentSkillDAGStep(session) + if !ok { + t.Fatal("expected dag step") + } + if step.ID != "resolve_config_value" { + t.Fatalf("expected resolve_config_value, got %s", step.ID) + } +} diff --git a/agent/skill_dag_test.go b/agent/skill_dag_test.go new file mode 100644 index 00000000..73707474 --- /dev/null +++ b/agent/skill_dag_test.go @@ -0,0 +1,67 @@ +package agent + +import "testing" + +func TestGetSkillDAGForStructuredActions(t *testing.T) { + tests := []struct { + skill string + action string + }{ + {skill: "trader_management", action: "create"}, + {skill: "trader_management", action: "update_bindings"}, + {skill: "strategy_management", action: "update_config"}, + {skill: "strategy_management", action: "update_prompt"}, + {skill: "model_management", action: "update_status"}, + {skill: "exchange_management", action: "update_name"}, + } + + for _, tt := range tests { + dag, ok := getSkillDAG(tt.skill, tt.action) + if !ok { + t.Fatalf("expected DAG for %s/%s", tt.skill, tt.action) + } + if dag.SkillName != tt.skill || dag.Action != tt.action { + t.Fatalf("unexpected dag identity: %+v", dag) + } + if len(dag.Steps) == 0 { + t.Fatalf("expected DAG steps for %s/%s", tt.skill, tt.action) + } + } +} + +func TestStructuredDAGsHaveTerminalStep(t *testing.T) { + for _, dag := range listSkillDAGs() { + hasTerminal := false + for _, step := range dag.Steps { + if step.Terminal { + hasTerminal = true + break + } + } + if !hasTerminal { + t.Fatalf("expected terminal step for %s/%s", dag.SkillName, dag.Action) + } + } +} + +func TestStrategyUpdateConfigDAGMatchesCurrentAtomicFlow(t *testing.T) { + dag, ok := getSkillDAG("strategy_management", "update_config") + if !ok { + t.Fatal("missing strategy update_config dag") + } + if len(dag.Steps) != 6 { + t.Fatalf("expected 6 steps, got %d", len(dag.Steps)) + } + if dag.Steps[0].ID != "resolve_target" { + t.Fatalf("expected first step resolve_target, got %s", dag.Steps[0].ID) + } + if dag.Steps[1].ID != "resolve_config_field" { + t.Fatalf("expected second step resolve_config_field, got %s", dag.Steps[1].ID) + } + if dag.Steps[2].ID != "resolve_config_value" { + t.Fatalf("expected third step resolve_config_value, got %s", dag.Steps[2].ID) + } + if dag.Steps[5].ID != "execute_update" || !dag.Steps[5].Terminal { + t.Fatalf("expected final terminal execute step, got %+v", dag.Steps[5]) + } +} diff --git a/agent/skill_dispatcher.go b/agent/skill_dispatcher.go index 6e38f824..96c2a71f 100644 --- a/agent/skill_dispatcher.go +++ b/agent/skill_dispatcher.go @@ -37,8 +37,8 @@ type traderSkillOption struct { } var ( - quotedNamePattern = regexp.MustCompile(`[“"]([^“”"]{1,40})[”"]`) - traderNamedPattern = regexp.MustCompile(`(?:叫|名为|名字是)\s*([A-Za-z0-9_\-\p{Han}]{2,40})`) + quotedNamePattern = regexp.MustCompile(`[“"]([^“”"]{1,40})[”"]`) + traderNamedPattern = regexp.MustCompile(`(?:叫|名为|名字是)\s*([A-Za-z0-9_\-\p{Han}]{2,40})`) ) func skillSessionConfigKey(userID int64) string { @@ -157,7 +157,12 @@ func isNoReply(text string) bool { func isCancelSkillReply(text string) bool { lower := strings.ToLower(strings.TrimSpace(text)) - return lower == "取消" || lower == "/cancel" || lower == "cancel" + switch lower { + case "取消", "/cancel", "cancel", "不改", "先不改", "算了", "先不用", "不用了", "不弄了", "不搞了", "换话题", "换话题了", "聊别的", "先聊别的": + return true + default: + return false + } } func detectCreateTraderSkill(text string) bool { @@ -198,6 +203,116 @@ func detectStartIntent(text string) bool { return containsAny(lower, []string{"启动", "跑起来", "run", "start", "立即运行", "并启动"}) } +func looksLikeStandaloneValueReply(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return false + } + if firstIntegerPattern.MatchString(lower) && len(strings.Fields(lower)) <= 4 { + return true + } + return containsAny(lower, []string{"启用", "禁用", "enable", "disable", "打开", "关闭"}) +} + +func detectImplicitStrategyAction(text string) string { + lower := strings.ToLower(strings.TrimSpace(text)) + switch { + case containsAny(lower, []string{"prompt", "提示词"}): + return "update_prompt" + case containsAny(lower, []string{"参数", "配置", "置信度", "持仓", "周期", "timeframe", "调到", "改到", "改成", "调整"}): + return "update_config" + default: + return "" + } +} + +func detectImplicitTraderAction(text string) string { + lower := strings.ToLower(strings.TrimSpace(text)) + switch { + case containsAny(lower, []string{"启动", "开始", "run", "start"}): + return "start" + case containsAny(lower, []string{"停止", "停掉", "stop", "pause"}): + return "stop" + case containsAny(lower, []string{"换模型", "换交易所", "换策略", "绑定", "切换模型", "切换交易所", "切换策略"}): + return "update_bindings" + case containsAny(lower, []string{"改名", "重命名", "rename"}): + return "update_name" + default: + return "" + } +} + +func detectImplicitModelAction(text string) string { + lower := strings.ToLower(strings.TrimSpace(text)) + switch { + case containsAny(lower, []string{"启用", "禁用", "enable", "disable"}): + return "update_status" + case containsAny(lower, []string{"url", "endpoint", "地址", "接口"}): + return "update_endpoint" + case containsAny(lower, []string{"模型名", "模型名称", "model name", "改名", "重命名", "rename"}): + return "update_name" + default: + return "" + } +} + +func detectImplicitExchangeAction(text string) string { + lower := strings.ToLower(strings.TrimSpace(text)) + switch { + case containsAny(lower, []string{"启用", "禁用", "enable", "disable"}): + return "update_status" + case containsAny(lower, []string{"账户名", "改名", "重命名", "rename"}): + return "update_name" + default: + return "" + } +} + +func (a *Agent) inferContextualSkillSession(storeUserID string, userID int64, text string, session skillSession) skillSession { + if session.Name != "" || strings.TrimSpace(text) == "" { + return session + } + state := a.getExecutionState(userID) + lower := strings.ToLower(strings.TrimSpace(text)) + if state.CurrentReferences != nil { + if ref := state.CurrentReferences.Strategy; ref != nil { + if action := detectImplicitStrategyAction(text); action != "" || looksLikeStandaloneValueReply(text) { + return skillSession{Name: "strategy_management", Action: defaultIfEmpty(action, "update_config"), Phase: "collecting", TargetRef: ref} + } + } + if ref := state.CurrentReferences.Trader; ref != nil { + if action := detectImplicitTraderAction(text); action != "" { + return skillSession{Name: "trader_management", Action: action, Phase: "collecting", TargetRef: ref} + } + } + if ref := state.CurrentReferences.Model; ref != nil { + if action := detectImplicitModelAction(text); action != "" { + return skillSession{Name: "model_management", Action: action, Phase: "collecting", TargetRef: ref} + } + } + if ref := state.CurrentReferences.Exchange; ref != nil { + if action := detectImplicitExchangeAction(text); action != "" { + return skillSession{Name: "exchange_management", Action: action, Phase: "collecting", TargetRef: ref} + } + } + } + if containsAny(lower, []string{"调整参数", "改参数", "改配置"}) { + options := a.loadStrategyOptions(storeUserID) + if len(options) == 1 { + return skillSession{ + Name: "strategy_management", + Action: "update_config", + Phase: "collecting", + TargetRef: &EntityReference{ + ID: options[0].ID, + Name: options[0].Name, + }, + } + } + } + return session +} + func extractTraderName(text string) string { text = strings.TrimSpace(text) if text == "" { @@ -212,11 +327,45 @@ func extractTraderName(text string) string { return "" } +func extractSegmentAfterKeywords(text string, keywords []string) string { + trimmed := strings.TrimSpace(text) + if trimmed == "" { + return "" + } + lower := strings.ToLower(trimmed) + for _, keyword := range keywords { + idx := strings.Index(lower, strings.ToLower(keyword)) + if idx < 0 { + continue + } + segment := strings.TrimSpace(trimmed[idx+len(keyword):]) + if segment == "" { + continue + } + cut := len(segment) + for i, r := range segment { + switch r { + case ',', ',', '。', ';', ';', '\n', '、': + cut = i + goto done + } + } + done: + segment = strings.TrimSpace(segment[:cut]) + segment = strings.Trim(segment, "“”\"':: ") + if segment != "" { + return segment + } + } + return "" +} + func pickMentionedOption(text string, options []traderSkillOption) *traderSkillOption { lower := strings.ToLower(strings.TrimSpace(text)) if lower == "" { return nil } + bestScore := 0 var matched *traderSkillOption for _, option := range options { id := strings.ToLower(strings.TrimSpace(option.ID)) @@ -224,10 +373,16 @@ func pickMentionedOption(text string, options []traderSkillOption) *traderSkillO if id == "" && name == "" { continue } - if (id != "" && strings.Contains(lower, id)) || (name != "" && strings.Contains(lower, name)) { - if matched != nil { - return nil - } + score := optionMatchScore(lower, id, name) + if score == 0 { + continue + } + if score == bestScore { + matched = nil + continue + } + if score > bestScore { + bestScore = score copy := option matched = © } @@ -235,6 +390,73 @@ func pickMentionedOption(text string, options []traderSkillOption) *traderSkillO return matched } +func pickOptionFromSegment(text string, keywords []string, options []traderSkillOption) *traderSkillOption { + segment := extractSegmentAfterKeywords(text, keywords) + if strings.TrimSpace(segment) == "" { + return nil + } + return pickMentionedOption(segment, options) +} + +func optionMatchScore(text, id, name string) int { + if id != "" && strings.Contains(text, id) { + return 4 + } + return optionNameMatchScore(text, name) +} + +func optionNameMatchScore(text, name string) int { + name = strings.TrimSpace(strings.ToLower(name)) + if name == "" { + return 0 + } + if strings.Contains(text, name) { + return 3 + } + fields := strings.FieldsFunc(name, func(r rune) bool { + switch r { + case ' ', ',', ',', '/', '|', '、', '(', ')', '(', ')': + return true + default: + return false + } + }) + best := 0 + for _, field := range fields { + field = strings.TrimSpace(field) + if field == "" { + continue + } + if len([]rune(field)) <= 2 && !containsHan(field) { + continue + } + if strings.Contains(text, field) { + if containsHan(field) && len([]rune(field)) >= 3 { + best = max(best, 2) + } else { + best = max(best, 1) + } + } + } + return best +} + +func containsHan(s string) bool { + for _, r := range s { + if r >= 0x4E00 && r <= 0x9FFF { + return true + } + } + return false +} + +func max(a, b int) int { + if a > b { + return a + } + return b +} + func choosePreferredOption(options []traderSkillOption) *traderSkillOption { if len(options) == 1 { copy := options[0] @@ -262,6 +484,8 @@ func formatOptionList(prefix string, options []traderSkillOption) string { } if option.Enabled { label += "(已启用)" + } else { + label += "(已禁用)" } parts = append(parts, label) } @@ -291,13 +515,12 @@ func (a *Agent) loadEnabledModelOptions(storeUserID string) []traderSkillOption } out := make([]traderSkillOption, 0, len(models)) for _, model := range models { - name := strings.TrimSpace(model.Name) - if name == "" { - name = strings.TrimSpace(model.CustomModelName) - } - if name == "" { - name = strings.TrimSpace(model.Provider) - } + parts := cleanStringList([]string{ + strings.TrimSpace(model.Name), + strings.TrimSpace(model.CustomModelName), + strings.TrimSpace(model.Provider), + }) + name := strings.Join(parts, " ") out = append(out, traderSkillOption{ID: model.ID, Name: name, Enabled: model.Enabled}) } return out @@ -342,6 +565,7 @@ func (a *Agent) tryHardSkill(ctx context.Context, storeUserID string, userID int return "", false } session := a.getSkillSession(userID) + session = a.inferContextualSkillSession(storeUserID, userID, text, session) if (session.Name == "trader_management" && session.Action == "create") || detectCreateTraderSkill(text) { answer, handled := a.handleCreateTraderSkill(storeUserID, userID, lang, text, session) if handled { @@ -459,13 +683,13 @@ func (a *Agent) handleCreateTraderSkill(storeUserID string, userID int64, lang, return "Cancelled the current trader creation flow.", true } - if session.Name == "" { - session = skillSession{ - Name: "trader_management", - Action: "create", - Phase: "collecting", - Slots: &createTraderSkillSlots{}, - } + if session.Name == "" { + session = skillSession{ + Name: "trader_management", + Action: "create", + Phase: "collecting", + Slots: &createTraderSkillSlots{}, + } if detectStartIntent(text) { autoStart := true session.Slots.AutoStart = &autoStart @@ -474,8 +698,12 @@ func (a *Agent) handleCreateTraderSkill(storeUserID string, userID int64, lang, if session.Slots == nil { session.Slots = &createTraderSkillSlots{} } + if fieldValue(session, skillDAGStepField) == "" { + setSkillDAGStep(&session, "resolve_name") + } if session.Phase == "await_start_confirmation" { + setSkillDAGStep(&session, "await_start_confirmation") switch { case isYesReply(text): answer := a.executeCreateTraderSkill(storeUserID, userID, lang, session, true) @@ -496,13 +724,19 @@ func (a *Agent) handleCreateTraderSkill(storeUserID string, userID int64, lang, if slots.Name == "" { slots.Name = extractTraderName(text) } + if slots.Name != "" { + setSkillDAGStep(&session, "resolve_exchange") + } models := a.loadEnabledModelOptions(storeUserID) exchanges := a.loadExchangeOptions(storeUserID) strategies := a.loadStrategyOptions(storeUserID) if slots.ModelID == "" { - if match := pickMentionedOption(text, models); match != nil { + if match := pickOptionFromSegment(text, []string{"模型用", "模型", "model"}, models); match != nil { + slots.ModelID = match.ID + slots.ModelName = match.Name + } else if match := pickMentionedOption(text, models); match != nil { slots.ModelID = match.ID slots.ModelName = match.Name } else if choice := choosePreferredOption(models); choice != nil { @@ -510,17 +744,46 @@ func (a *Agent) handleCreateTraderSkill(storeUserID string, userID int64, lang, slots.ModelName = choice.Name } } + if slots.ExchangeID != "" { + setSkillDAGStep(&session, "resolve_model") + } if slots.ExchangeID == "" { - if match := pickMentionedOption(text, exchanges); match != nil { - slots.ExchangeID = match.ID - slots.ExchangeName = match.Name + if match := pickOptionFromSegment(text, []string{"交易所用", "交易所", "exchange"}, exchanges); match != nil { + if match.Enabled { + slots.ExchangeID = match.ID + slots.ExchangeName = match.Name + } else { + if lang == "zh" { + extra := "你刚才提到的交易所“" + defaultIfEmpty(match.Name, match.ID) + "”当前已禁用,请换一个已启用的交易所。" + a.saveSkillSession(userID, session) + return extra + "\n" + formatOptionList("可用交易所:", exchanges), true + } + a.saveSkillSession(userID, session) + return "The exchange you mentioned is currently disabled. Please choose an enabled exchange.\n" + formatOptionList("Available exchanges:", exchanges), true + } + } else if match := pickMentionedOption(text, exchanges); match != nil { + if match.Enabled { + slots.ExchangeID = match.ID + slots.ExchangeName = match.Name + } else { + if lang == "zh" { + extra := "你刚才提到的交易所“" + defaultIfEmpty(match.Name, match.ID) + "”当前已禁用,请换一个已启用的交易所。" + a.saveSkillSession(userID, session) + return extra + "\n" + formatOptionList("可用交易所:", exchanges), true + } + a.saveSkillSession(userID, session) + return "The exchange you mentioned is currently disabled. Please choose an enabled exchange.\n" + formatOptionList("Available exchanges:", exchanges), true + } } else if choice := choosePreferredOption(exchanges); choice != nil { slots.ExchangeID = choice.ID slots.ExchangeName = choice.Name } } if slots.StrategyID == "" { - if match := pickMentionedOption(text, strategies); match != nil { + if match := pickOptionFromSegment(text, []string{"策略用", "策略", "strategy"}, strategies); match != nil { + slots.StrategyID = match.ID + slots.StrategyName = match.Name + } else if match := pickMentionedOption(text, strategies); match != nil { slots.StrategyID = match.ID slots.StrategyName = match.Name } else if choice := choosePreferredOption(strategies); choice != nil { @@ -528,34 +791,18 @@ func (a *Agent) handleCreateTraderSkill(storeUserID string, userID int64, lang, slots.StrategyName = choice.Name } } + if slots.ModelID != "" { + setSkillDAGStep(&session, "resolve_strategy") + } + if slots.StrategyID != "" { + setSkillDAGStep(&session, "maybe_confirm_start") + } if slots.AutoStart == nil && detectStartIntent(text) { autoStart := true slots.AutoStart = &autoStart } - if len(strategies) == 0 { - a.clearSkillSession(userID) - if lang == "zh" { - return "当前还没有可用策略,暂时不能创建交易员。请先创建一个策略,再回来继续。", true - } - return "There is no strategy available yet, so I can't create a trader. Please create a strategy first.", true - } - if len(models) == 0 { - a.clearSkillSession(userID) - if lang == "zh" { - return "当前还没有模型配置,暂时不能创建交易员。请先配置并启用一个模型。", true - } - return "There is no model config yet, so I can't create a trader. Please configure and enable a model first.", true - } - if len(exchanges) == 0 { - a.clearSkillSession(userID) - if lang == "zh" { - return "当前还没有交易所配置,暂时不能创建交易员。请先配置并启用一个交易所账户。", true - } - return "There is no exchange config yet, so I can't create a trader. Please configure and enable an exchange first.", true - } - missing := make([]string, 0, 3) extraLines := make([]string, 0, 3) if actionRequiresSlot("trader_management", "create", "name") && slots.Name == "" { @@ -563,15 +810,53 @@ func (a *Agent) handleCreateTraderSkill(storeUserID string, userID int64, lang, } if actionRequiresSlot("trader_management", "create", "exchange") && slots.ExchangeID == "" { missing = append(missing, slotDisplayName("exchange", lang)) - extraLines = append(extraLines, formatOptionList("可用交易所:", exchanges)) + if len(exchanges) == 0 { + if lang == "zh" { + extraLines = append(extraLines, "当前还没有可用交易所配置,请先配置并启用一个交易所账户。") + } else { + extraLines = append(extraLines, "There is no enabled exchange config yet. Please create and enable one first.") + } + } else { + label := "Available exchanges:" + if lang == "zh" { + label = "可用交易所:" + } + extraLines = append(extraLines, formatOptionList(label, exchanges)) + } } if actionRequiresSlot("trader_management", "create", "model") && slots.ModelID == "" { missing = append(missing, slotDisplayName("model", lang)) - extraLines = append(extraLines, formatOptionList("可用模型:", models)) + if len(models) == 0 { + if lang == "zh" { + extraLines = append(extraLines, "当前还没有可用模型配置,请先配置并启用一个模型。") + } else { + extraLines = append(extraLines, "There is no enabled model config yet. Please create and enable one first.") + } + } else { + label := "Available models:" + if lang == "zh" { + label = "可用模型:" + } + extraLines = append(extraLines, formatOptionList(label, models)) + } } - if actionRequiresSlot("trader_management", "create", "strategy") && slots.StrategyID == "" { + if slots.StrategyID == "" && (actionRequiresSlot("trader_management", "create", "strategy") || len(strategies) == 0) { missing = append(missing, slotDisplayName("strategy", lang)) - extraLines = append(extraLines, formatOptionList("可用策略:", strategies)) + } + if slots.StrategyID == "" { + if len(strategies) == 0 { + if lang == "zh" { + extraLines = append(extraLines, "当前还没有可用策略,请先创建一个策略。") + } else { + extraLines = append(extraLines, "There is no strategy available yet. Please create one first.") + } + } else { + label := "Available strategies:" + if lang == "zh" { + label = "可用策略:" + } + extraLines = append(extraLines, formatOptionList(label, strategies)) + } } if len(missing) > 0 { @@ -595,6 +880,7 @@ func (a *Agent) handleCreateTraderSkill(storeUserID string, userID int64, lang, if slots.AutoStart != nil && *slots.AutoStart { session.Phase = "await_start_confirmation" + setSkillDAGStep(&session, "await_start_confirmation") a.saveSkillSession(userID, session) if lang == "zh" { return fmt.Sprintf("我已经准备好创建交易员“%s”,并在创建后立即启动它。\n使用的交易所:%s\n使用的模型:%s\n使用的策略:%s\n\n这是高风险动作。回复“确认”继续,回复“先不用”则只创建不启动。", @@ -631,16 +917,31 @@ func (s *createTraderSkillSlots) StrategyNameOrID() string { func (a *Agent) executeCreateTraderSkill(storeUserID string, userID int64, lang string, session skillSession, startAfterCreate bool) string { args := manageTraderArgs{ - Action: "create", - Name: session.Slots.Name, - AIModelID: session.Slots.ModelID, - ExchangeID: session.Slots.ExchangeID, - StrategyID: session.Slots.StrategyID, + Action: "create", + Name: session.Slots.Name, + AIModelID: session.Slots.ModelID, + ExchangeID: session.Slots.ExchangeID, + StrategyID: session.Slots.StrategyID, } createRaw := a.toolCreateTrader(storeUserID, args) if errMsg := parseSkillError(createRaw); errMsg != "" && strings.Contains(createRaw, `"error"`) { session.Phase = "collecting" a.saveSkillSession(userID, session) + if strings.Contains(strings.ToLower(errMsg), "exchange is disabled") { + exchanges := a.loadExchangeOptions(storeUserID) + if lang == "zh" { + reply := fmt.Sprintf("创建交易员失败:你选的交易所“%s”当前已禁用,请换一个已启用的交易所。", session.Slots.ExchangeNameOrID()) + if list := formatOptionList("可用交易所:", exchanges); list != "" { + reply += "\n" + list + } + return reply + } + reply := fmt.Sprintf("Failed to create trader: the selected exchange %q is disabled. Please choose an enabled exchange.", session.Slots.ExchangeNameOrID()) + if list := formatOptionList("Available exchanges:", exchanges); list != "" { + reply += "\n" + list + } + return reply + } if lang == "zh" { return "创建交易员失败:" + errMsg } @@ -658,6 +959,7 @@ func (a *Agent) executeCreateTraderSkill(storeUserID string, userID int64, lang } if !startAfterCreate { + setSkillDAGStep(&session, "execute_create_only") a.clearSkillSession(userID) if lang == "zh" { return fmt.Sprintf("已创建交易员“%s”。\n交易所:%s\n模型:%s\n策略:%s\n当前状态:未启动。", @@ -667,6 +969,7 @@ func (a *Agent) executeCreateTraderSkill(storeUserID string, userID int64, lang created.Trader.Name, session.Slots.ExchangeNameOrID(), session.Slots.ModelNameOrID(), session.Slots.StrategyNameOrID()) } + setSkillDAGStep(&session, "execute_create_and_start") startRaw := a.toolStartTrader(storeUserID, created.Trader.ID) if errMsg := parseSkillError(startRaw); errMsg != "" && strings.Contains(startRaw, `"error"`) { a.clearSkillSession(userID) @@ -735,6 +1038,9 @@ func (a *Agent) handleModelDiagnosisSkill(storeUserID, lang, text string) string lines = append(lines, fmt.Sprintf("1. 当前共 %d 个模型配置,已启用 %d 个。", len(payload.ModelConfigs), enabledCount)) lines = append(lines, "2. 检查目标模型是否同时具备 enabled、API Key、custom_api_url。") lines = append(lines, "3. 如果是 OpenAI / Claude / DeepSeek 等 provider,确认 model name 填的是该 provider 实际可用的模型名。") + if excerpt := backendLogDiagnosisExcerpt(lang, text, "model"); excerpt != "" { + lines = append(lines, excerpt) + } lines = append(lines, "下一步:如果你愿意,我下一步可以继续帮你逐项检查你当前配置里的具体模型。") return strings.Join(lines, "\n") } @@ -749,6 +1055,9 @@ func (a *Agent) handleModelDiagnosisSkill(storeUserID, lang, text string) string lines = append(lines, "Likely cause: the model was saved, but the API key, custom_api_url, or custom_model_name does not match the provider runtime config.") } lines = append(lines, fmt.Sprintf("Check first: %d model configs exist, %d are enabled.", len(payload.ModelConfigs), enabledCount)) + if excerpt := backendLogDiagnosisExcerpt(lang, text, "model"); excerpt != "" { + lines = append(lines, excerpt) + } lines = append(lines, "Next step: verify the target model has enabled=true, a non-empty API key, a valid HTTPS custom_api_url, and a correct model name.") return strings.Join(lines, "\n") } @@ -779,6 +1088,9 @@ func (a *Agent) handleExchangeDiagnosisSkill(storeUserID, lang, text string) str } lines = append(lines, "4. 检查 API 白名单是否包含当前服务器 IP。") lines = append(lines, "5. 检查是否已经开启交易/合约权限。") + if excerpt := backendLogDiagnosisExcerpt(lang, text, "exchange"); excerpt != "" { + lines = append(lines, excerpt) + } lines = append(lines, "下一步:如果你把具体报错原文贴给我,我可以按报错类型继续缩小范围。") return strings.Join(lines, "\n") } @@ -788,5 +1100,28 @@ func (a *Agent) handleExchangeDiagnosisSkill(storeUserID, lang, text string) str if len(exchanges) > 0 { lines = append(lines, "Current exchange bindings exist, so the next step is to match the exact error text to the most likely cause.") } + if excerpt := backendLogDiagnosisExcerpt(lang, text, "exchange"); excerpt != "" { + lines = append(lines, excerpt) + } return strings.Join(lines, "\n") } + +func backendLogDiagnosisExcerpt(lang, text, fallbackFilter string) string { + filter := strings.TrimSpace(text) + if strings.TrimSpace(filter) == "" { + filter = fallbackFilter + } + _, entries, err := readBackendLogEntries(8, filter, true) + if err != nil || len(entries) == 0 { + if filter != fallbackFilter { + _, entries, err = readBackendLogEntries(8, fallbackFilter, true) + } + } + if err != nil || len(entries) == 0 { + return "" + } + if lang == "zh" { + return "最近命中的后端错误日志:\n- " + strings.Join(entries, "\n- ") + } + return "Recent matching backend error logs:\n- " + strings.Join(entries, "\n- ") +} diff --git a/agent/skill_dispatcher_test.go b/agent/skill_dispatcher_test.go index e25341dc..bb292156 100644 --- a/agent/skill_dispatcher_test.go +++ b/agent/skill_dispatcher_test.go @@ -3,8 +3,12 @@ package agent import ( "context" "encoding/json" + "errors" "strings" "testing" + "time" + + "nofx/mcp" ) func TestCreateTraderSkillCollectsMissingFieldsAndCreatesTrader(t *testing.T) { @@ -61,6 +65,54 @@ func TestCreateTraderSkillCollectsMissingFieldsAndCreatesTrader(t *testing.T) { } } +func TestCreateTraderSkillReportsAllMissingPrerequisitesAtOnce(t *testing.T) { + a := newTestAgentWithStore(t) + + resp, err := a.thinkAndAct(context.Background(), "user-1", 11, "zh", "帮我创建一个交易员") + if err != nil { + t.Fatalf("thinkAndAct() error = %v", err) + } + for _, want := range []string{"名称", "交易所", "模型", "策略"} { + if !strings.Contains(resp, want) { + t.Fatalf("expected response to mention %q, got %q", want, resp) + } + } + for _, want := range []string{"当前还没有可用交易所配置", "当前还没有可用模型配置", "当前还没有可用策略"} { + if !strings.Contains(resp, want) { + t.Fatalf("expected response to mention prerequisite %q, got %q", want, resp) + } + } +} + +func TestActiveSkillSessionYieldsToNewTopic(t *testing.T) { + a := newTestAgentWithStore(t) + + _ = a.toolManageStrategy("user-1", `{ + "action":"create", + "name":"测试策略", + "lang":"zh" + }`) + + resp, err := a.thinkAndAct(context.Background(), "user-1", 13, "zh", "帮我创建一个交易员") + if err != nil { + t.Fatalf("thinkAndAct() error = %v", err) + } + if !strings.Contains(resp, "还缺这些信息") { + t.Fatalf("expected trader creation flow prompt, got %q", resp) + } + + resp, err = a.thinkAndAct(context.Background(), "user-1", 13, "zh", "列出我当前的策略") + if err != nil { + t.Fatalf("thinkAndAct() interrupt error = %v", err) + } + if !strings.Contains(resp, "当前策略") || !strings.Contains(resp, "测试策略") { + t.Fatalf("expected new topic to be handled, got %q", resp) + } + if a.hasActiveSkillSession(13) { + t.Fatal("expected skill session to be cleared after interruption") + } +} + func TestCreateTraderSkillRequestsStartConfirmation(t *testing.T) { a := newTestAgentWithStore(t) @@ -175,6 +227,28 @@ func TestStrategyManagementCreateAndActivateSkill(t *testing.T) { } } +func TestStrategyManagementQueryCanExplainStrategyDetails(t *testing.T) { + a := newTestAgentWithStore(t) + + resp, err := a.thinkAndAct(context.Background(), "user-1", 12, "zh", "创建一个叫“激进的”的策略") + if err != nil { + t.Fatalf("thinkAndAct() create error = %v", err) + } + if !strings.Contains(resp, "已创建策略") { + t.Fatalf("expected strategy create response, got %q", resp) + } + + resp, err = a.thinkAndAct(context.Background(), "user-1", 12, "zh", "这个策略里面的参数和prompt分别是什么样的") + if err != nil { + t.Fatalf("thinkAndAct() detail query error = %v", err) + } + for _, want := range []string{"策略“激进的”概览", "K线周期", "仓位风险", "Prompt"} { + if !strings.Contains(resp, want) { + t.Fatalf("expected response to mention %q, got %q", want, resp) + } + } +} + func TestTraderManagementQueryAndDiagnosisSkill(t *testing.T) { a := newTestAgentWithStore(t) @@ -234,3 +308,521 @@ func TestTraderManagementQueryAndDiagnosisSkill(t *testing.T) { t.Fatalf("expected trader diagnosis response, got %q", resp) } } + +func TestExchangeManagementAtomicUpdates(t *testing.T) { + a := newTestAgentWithStore(t) + + createResp := a.toolManageExchangeConfig("user-1", `{ + "action":"create", + "exchange_type":"okx", + "account_name":"主账户", + "enabled":true + }`) + var created struct { + Exchange safeExchangeToolConfig `json:"exchange"` + } + if err := json.Unmarshal([]byte(createResp), &created); err != nil { + t.Fatalf("unmarshal exchange response: %v", err) + } + + resp, err := a.thinkAndAct(context.Background(), "user-1", 14, "zh", "更新交易所,把主账户改名为备用账户") + if err != nil { + t.Fatalf("rename exchange error = %v", err) + } + if !strings.Contains(resp, "已更新交易所配置") { + t.Fatalf("expected exchange update response, got %q", resp) + } + + raw := a.toolGetExchangeConfigs("user-1") + if !strings.Contains(raw, "备用账户") { + t.Fatalf("expected renamed exchange in list, got %s", raw) + } + + resp, err = a.thinkAndAct(context.Background(), "user-1", 14, "zh", "禁用这个交易所配置") + if err != nil { + t.Fatalf("disable exchange error = %v", err) + } + if !strings.Contains(resp, "已更新交易所配置") { + t.Fatalf("expected exchange status update response, got %q", resp) + } + + raw = a.toolGetExchangeConfigs("user-1") + if strings.Contains(raw, `"enabled":true`) && strings.Contains(raw, "备用账户") { + t.Fatalf("expected exchange to be disabled, got %s", raw) + } +} + +func TestModelManagementAtomicUpdates(t *testing.T) { + a := newTestAgentWithStore(t) + + createResp := a.toolManageModelConfig("user-1", `{ + "action":"create", + "provider":"deepseek", + "enabled":true, + "custom_api_url":"https://api.deepseek.com/v1", + "custom_model_name":"deepseek-chat" + }`) + var created struct { + Model safeModelToolConfig `json:"model"` + } + if err := json.Unmarshal([]byte(createResp), &created); err != nil { + t.Fatalf("unmarshal model response: %v", err) + } + + resp, err := a.thinkAndAct(context.Background(), "user-1", 15, "zh", "更新模型,把模型名称改成 deepseek-reasoner") + if err != nil { + t.Fatalf("rename model error = %v", err) + } + if !strings.Contains(resp, "已更新模型配置") { + t.Fatalf("expected model update response, got %q", resp) + } + + resp, err = a.thinkAndAct(context.Background(), "user-1", 15, "zh", "更新模型,把接口地址改成 https://api.deepseek.com/beta") + if err != nil { + t.Fatalf("update model endpoint error = %v", err) + } + if !strings.Contains(resp, "已更新模型配置") { + t.Fatalf("expected model endpoint update response, got %q", resp) + } + + resp, err = a.thinkAndAct(context.Background(), "user-1", 15, "zh", "禁用这个模型配置") + if err != nil { + t.Fatalf("disable model error = %v", err) + } + if !strings.Contains(resp, "已更新模型配置") { + t.Fatalf("expected model status update response, got %q", resp) + } + + raw := a.toolGetModelConfigs("user-1") + if !strings.Contains(raw, "deepseek-reasoner") || !strings.Contains(raw, "https://api.deepseek.com/beta") { + t.Fatalf("expected updated model fields, got %s", raw) + } + if strings.Contains(raw, `"enabled":true`) && strings.Contains(raw, created.Model.ID) { + t.Fatalf("expected model to be disabled, got %s", raw) + } +} + +func TestStrategyManagementAtomicUpdates(t *testing.T) { + a := newTestAgentWithStore(t) + + resp, err := a.thinkAndAct(context.Background(), "user-1", 16, "zh", "创建一个叫“激进策略C”的策略") + if err != nil { + t.Fatalf("create strategy error = %v", err) + } + if !strings.Contains(resp, "已创建策略") { + t.Fatalf("expected strategy create response, got %q", resp) + } + + resp, err = a.thinkAndAct(context.Background(), "user-1", 16, "zh", "更新这个策略的prompt,把提示词改成“优先观察BTC和ETH,信号不一致时不要开仓”") + if err != nil { + t.Fatalf("update strategy prompt error = %v", err) + } + if !strings.Contains(resp, "已更新策略 prompt") { + t.Fatalf("expected strategy prompt update response, got %q", resp) + } + + resp, err = a.thinkAndAct(context.Background(), "user-1", 16, "zh", "更新这个策略参数,把最大持仓改成2,最低置信度改成80,主周期改成15m,并使用15m 1h 4h") + if err != nil { + t.Fatalf("update strategy config error = %v", err) + } + if !strings.Contains(resp, "已更新策略参数") { + t.Fatalf("expected strategy config update response, got %q", resp) + } + + listRaw := a.toolGetStrategies("user-1") + if !strings.Contains(listRaw, "优先观察BTC和ETH") || !strings.Contains(listRaw, `"max_positions":2`) || !strings.Contains(listRaw, `"min_confidence":80`) || !strings.Contains(listRaw, `"primary_timeframe":"15m"`) { + t.Fatalf("expected updated strategy config, got %s", listRaw) + } +} + +func TestTraderManagementAtomicBindingUpdate(t *testing.T) { + a := newTestAgentWithStore(t) + + modelOpenAI := a.toolManageModelConfig("user-1", `{ + "action":"create", + "provider":"openai", + "enabled":true, + "custom_api_url":"https://api.openai.com/v1", + "custom_model_name":"gpt-5-mini" + }`) + var openAI struct { + Model safeModelToolConfig `json:"model"` + } + if err := json.Unmarshal([]byte(modelOpenAI), &openAI); err != nil { + t.Fatalf("unmarshal openai model: %v", err) + } + modelDeepSeek := a.toolManageModelConfig("user-1", `{ + "action":"create", + "provider":"deepseek", + "enabled":true, + "custom_api_url":"https://api.deepseek.com/v1", + "custom_model_name":"deepseek-chat" + }`) + var deepSeek struct { + Model safeModelToolConfig `json:"model"` + } + if err := json.Unmarshal([]byte(modelDeepSeek), &deepSeek); err != nil { + t.Fatalf("unmarshal deepseek model: %v", err) + } + + exchangeBinance := a.toolManageExchangeConfig("user-1", `{ + "action":"create", + "exchange_type":"binance", + "account_name":"Binance 主账户", + "enabled":true + }`) + var binance struct { + Exchange safeExchangeToolConfig `json:"exchange"` + } + if err := json.Unmarshal([]byte(exchangeBinance), &binance); err != nil { + t.Fatalf("unmarshal binance exchange: %v", err) + } + exchangeOKX := a.toolManageExchangeConfig("user-1", `{ + "action":"create", + "exchange_type":"okx", + "account_name":"OKX 主账户", + "enabled":true + }`) + var okx struct { + Exchange safeExchangeToolConfig `json:"exchange"` + } + if err := json.Unmarshal([]byte(exchangeOKX), &okx); err != nil { + t.Fatalf("unmarshal okx exchange: %v", err) + } + + strategyA := a.toolManageStrategy("user-1", `{"action":"create","name":"策略A","lang":"zh"}`) + var stA struct { + Strategy safeStrategyToolConfig `json:"strategy"` + } + if err := json.Unmarshal([]byte(strategyA), &stA); err != nil { + t.Fatalf("unmarshal strategy A: %v", err) + } + strategyB := a.toolManageStrategy("user-1", `{"action":"create","name":"策略B","lang":"zh"}`) + var stB struct { + Strategy safeStrategyToolConfig `json:"strategy"` + } + if err := json.Unmarshal([]byte(strategyB), &stB); err != nil { + t.Fatalf("unmarshal strategy B: %v", err) + } + + createTrader := a.toolManageTrader("user-1", `{ + "action":"create", + "name":"实盘一号", + "ai_model_id":"`+openAI.Model.ID+`", + "exchange_id":"`+binance.Exchange.ID+`", + "strategy_id":"`+stA.Strategy.ID+`" + }`) + var trader struct { + Trader safeTraderToolConfig `json:"trader"` + } + if err := json.Unmarshal([]byte(createTrader), &trader); err != nil { + t.Fatalf("unmarshal trader: %v", err) + } + + resp, err := a.thinkAndAct(context.Background(), "user-1", 17, "zh", "更新交易员绑定,把实盘一号换成 deepseek-chat、OKX 主账户 和 策略B") + if err != nil { + t.Fatalf("update trader bindings error = %v", err) + } + if !strings.Contains(resp, "已更新交易员绑定") { + t.Fatalf("expected trader binding update response, got %q", resp) + } + + listRaw := a.toolListTraders("user-1") + if !strings.Contains(listRaw, deepSeek.Model.ID) || !strings.Contains(listRaw, okx.Exchange.ID) || !strings.Contains(listRaw, stB.Strategy.ID) { + t.Fatalf("expected trader bindings to change, got %s", listRaw) + } +} + +func TestStrategyManagementDeleteAllUserStrategies(t *testing.T) { + a := newTestAgentWithStore(t) + + for _, name := range []string{"趋势策略A", "趋势策略B"} { + resp := a.toolManageStrategy("user-1", `{ + "action":"create", + "name":"`+name+`", + "lang":"zh" + }`) + if strings.Contains(resp, `"error"`) { + t.Fatalf("failed to create strategy %q: %s", name, resp) + } + } + + resp, err := a.thinkAndAct(context.Background(), "user-1", 21, "zh", "现在把所有的策略全部删除") + if err != nil { + t.Fatalf("thinkAndAct() bulk delete start error = %v", err) + } + if !strings.Contains(resp, "确认") || !strings.Contains(resp, "全部自定义策略") { + t.Fatalf("expected bulk delete confirmation, got %q", resp) + } + + resp, err = a.thinkAndAct(context.Background(), "user-1", 21, "zh", "确认") + if err != nil { + t.Fatalf("thinkAndAct() bulk delete confirm error = %v", err) + } + if !strings.Contains(resp, "成功删除 2 个") { + t.Fatalf("expected bulk delete success summary, got %q", resp) + } + + listResp := a.toolGetStrategies("user-1") + if strings.Contains(listResp, "趋势策略A") || strings.Contains(listResp, "趋势策略B") { + t.Fatalf("expected created strategies to be deleted, got %s", listResp) + } +} + +func TestCreateTraderSkillRejectsDisabledExchangeWithClearPrompt(t *testing.T) { + a := newTestAgentWithStore(t) + + _ = a.toolManageModelConfig("user-1", `{ + "action":"create", + "provider":"deepseek", + "enabled":true, + "api_key":"sk-test", + "custom_api_url":"https://api.deepseek.com/v1", + "custom_model_name":"deepseek-chat" + }`) + enabledExchange := a.toolManageExchangeConfig("user-1", `{ + "action":"create", + "exchange_type":"okx", + "account_name":"test", + "enabled":true + }`) + if strings.Contains(enabledExchange, `"error"`) { + t.Fatalf("failed to create enabled exchange: %s", enabledExchange) + } + anotherEnabledExchange := a.toolManageExchangeConfig("user-1", `{ + "action":"create", + "exchange_type":"okx", + "account_name":"lky", + "enabled":true + }`) + if strings.Contains(anotherEnabledExchange, `"error"`) { + t.Fatalf("failed to create second enabled exchange: %s", anotherEnabledExchange) + } + disabledExchange := a.toolManageExchangeConfig("user-1", `{ + "action":"create", + "exchange_type":"okx", + "account_name":"new", + "enabled":false + }`) + if strings.Contains(disabledExchange, `"error"`) { + t.Fatalf("failed to create disabled exchange: %s", disabledExchange) + } + _ = a.toolManageStrategy("user-1", `{"action":"create","name":"激进","lang":"zh"}`) + + resp, err := a.thinkAndAct(context.Background(), "user-1", 24, "zh", "给我创建一个trader") + if err != nil { + t.Fatalf("create trader start error = %v", err) + } + if !strings.Contains(resp, "new(已禁用)") { + t.Fatalf("expected disabled exchange to be labelled, got %q", resp) + } + + resp, err = a.thinkAndAct(context.Background(), "user-1", 24, "zh", "名称叫test,交易所用new、策略用激进") + if err != nil { + t.Fatalf("disabled exchange selection error = %v", err) + } + if !strings.Contains(resp, "当前已禁用") { + t.Fatalf("expected disabled exchange warning, got %q", resp) + } +} + +func TestCancelReplyExitsExchangeUpdateFlow(t *testing.T) { + a := newTestAgentWithStore(t) + _ = a.toolManageModelConfig("user-1", `{ + "action":"create", + "provider":"deepseek", + "enabled":true, + "api_key":"sk-test", + "custom_api_url":"https://api.deepseek.com/v1", + "custom_model_name":"deepseek-chat" + }`) + + exchangeResp := a.toolManageExchangeConfig("user-1", `{ + "action":"create", + "exchange_type":"okx", + "account_name":"test", + "enabled":true + }`) + if strings.Contains(exchangeResp, `"error"`) { + t.Fatalf("failed to create exchange: %s", exchangeResp) + } + + resp, err := a.thinkAndAct(context.Background(), "user-1", 25, "zh", "把test这个交易所改一下") + if err != nil { + t.Fatalf("enter exchange update flow error = %v", err) + } + if !strings.Contains(resp, "请告诉我你要改什么") { + t.Fatalf("expected exchange update prompt, got %q", resp) + } + + resp, err = a.thinkAndAct(context.Background(), "user-1", 25, "zh", "不改") + if err != nil { + t.Fatalf("cancel exchange flow error = %v", err) + } + if !strings.Contains(resp, "已取消当前流程") { + t.Fatalf("expected flow cancellation, got %q", resp) + } +} + +func TestClassifySkillSessionInputInterruptsOnDeflection(t *testing.T) { + session := skillSession{Name: "exchange_management", Action: "update"} + a := &Agent{} + + if got := a.classifySkillSessionInput(context.Background(), 0, "zh", session, "你能帮我看下报错吗"); got != "interrupt" { + t.Fatalf("expected diagnosis deflection to interrupt current skill flow, got %q", got) + } + if got := a.classifySkillSessionInput(context.Background(), 0, "zh", session, "换话题了大哥"); got != "cancel" { + t.Fatalf("expected topic shift to cancel current skill flow, got %q", got) + } +} + +type skillSessionClassifierAIClient struct { + lastSystemPrompt string + lastUserPrompt string + response string +} + +func (c *skillSessionClassifierAIClient) SetAPIKey(string, string, string) {} +func (c *skillSessionClassifierAIClient) SetTimeout(time.Duration) {} +func (c *skillSessionClassifierAIClient) CallWithMessages(string, string) (string, error) { + return "", errors.New("unexpected CallWithMessages") +} +func (c *skillSessionClassifierAIClient) CallWithRequest(req *mcp.Request) (string, error) { + if len(req.Messages) > 0 { + c.lastSystemPrompt = req.Messages[0].Content + } + if len(req.Messages) > 1 { + c.lastUserPrompt = req.Messages[1].Content + } + return c.response, nil +} +func (c *skillSessionClassifierAIClient) CallWithRequestStream(*mcp.Request, func(string)) (string, error) { + return "", errors.New("unexpected CallWithRequestStream") +} +func (c *skillSessionClassifierAIClient) CallWithRequestFull(*mcp.Request) (*mcp.LLMResponse, error) { + return nil, errors.New("unexpected CallWithRequestFull") +} + +func TestClassifySkillSessionInputUsesSlotExpectationWithoutLLM(t *testing.T) { + client := &skillSessionClassifierAIClient{response: `{"decision":"interrupt"}`} + a := &Agent{aiClient: client} + session := skillSession{ + Name: "strategy_management", + Action: "update_config", + Fields: map[string]string{ + skillDAGStepField: "resolve_config_value", + "config_field": "min_confidence", + }, + } + + if got := a.classifySkillSessionInput(context.Background(), 0, "zh", session, "70"); got != "continue" { + t.Fatalf("expected numeric slot fill to continue, got %q", got) + } + if client.lastSystemPrompt != "" { + t.Fatalf("expected no LLM call for direct slot expectation, got prompt %q", client.lastSystemPrompt) + } +} + +func TestClassifySkillSessionInputUsesLLMOnlyForAmbiguousDeflection(t *testing.T) { + client := &skillSessionClassifierAIClient{response: `{"decision":"interrupt"}`} + a := &Agent{ + aiClient: client, + history: newChatHistory(10), + } + session := skillSession{ + Name: "exchange_management", + Action: "update", + Fields: map[string]string{ + skillDAGStepField: "collect_account_name", + }, + } + + if got := a.classifySkillSessionInput(context.Background(), 0, "zh", session, "你能帮我看下报错吗"); got != "interrupt" { + t.Fatalf("expected ambiguous deflection to interrupt, got %q", got) + } + if !strings.Contains(client.lastSystemPrompt, "classify one user message while a NOFXi structured management flow is active") { + t.Fatalf("expected LLM classifier prompt, got %q", client.lastSystemPrompt) + } +} + +func TestClassifySkillSessionInputUsesLLMForUnmatchedActiveSessionInput(t *testing.T) { + client := &skillSessionClassifierAIClient{response: `{"decision":"continue"}`} + a := &Agent{ + aiClient: client, + history: newChatHistory(10), + } + session := skillSession{ + Name: "model_management", + Action: "create", + Fields: map[string]string{ + skillDAGStepField: "collect_optional_fields", + "provider": "openai", + }, + } + + if got := a.classifySkillSessionInput(context.Background(), 0, "zh", session, "新增一个"); got != "continue" { + t.Fatalf("expected unmatched active-session input to follow LLM decision, got %q", got) + } + if !strings.Contains(client.lastSystemPrompt, "classify one user message while a NOFXi structured management flow is active") { + t.Fatalf("expected LLM classifier prompt, got %q", client.lastSystemPrompt) + } +} + +func TestStrategyManagementCanDescribeDefaultConfig(t *testing.T) { + a := newTestAgentWithStore(t) + _ = a.toolManageModelConfig("user-1", `{ + "action":"create", + "provider":"deepseek", + "enabled":true, + "api_key":"sk-test", + "custom_api_url":"https://api.deepseek.com/v1", + "custom_model_name":"deepseek-chat" + }`) + + resp, err := a.thinkAndAct(context.Background(), "user-1", 22, "zh", "看一下默认配置") + if err != nil { + t.Fatalf("thinkAndAct() default config error = %v", err) + } + if !strings.Contains(resp, "默认策略模板") || !strings.Contains(resp, "最低置信度") { + t.Fatalf("expected default strategy config response, got %q", resp) + } +} + +func TestStrategyManagementSupportsMultiFieldConfigUpdate(t *testing.T) { + a := newTestAgentWithStore(t) + _ = a.toolManageModelConfig("user-1", `{ + "action":"create", + "provider":"deepseek", + "enabled":true, + "api_key":"sk-test", + "custom_api_url":"https://api.deepseek.com/v1", + "custom_model_name":"deepseek-chat" + }`) + + createResp := a.toolManageStrategy("user-1", `{ + "action":"create", + "name":"趋势策略A", + "lang":"zh" + }`) + if strings.Contains(createResp, `"error"`) { + t.Fatalf("failed to create strategy: %s", createResp) + } + + resp, err := a.thinkAndAct(context.Background(), "user-1", 23, "zh", "把趋势策略A的最小置信度改成70,核心指标都全选") + if err != nil { + t.Fatalf("thinkAndAct() multi-field update error = %v", err) + } + if !strings.Contains(resp, "最小置信度") || !strings.Contains(resp, "EMA") { + t.Fatalf("expected multi-field update confirmation, got %q", resp) + } + + strategiesRaw := a.toolGetStrategies("user-1") + if !strings.Contains(strategiesRaw, `"min_confidence":70`) || + !strings.Contains(strategiesRaw, `"enable_ema":true`) || + !strings.Contains(strategiesRaw, `"enable_macd":true`) || + !strings.Contains(strategiesRaw, `"enable_rsi":true`) || + !strings.Contains(strategiesRaw, `"enable_atr":true`) || + !strings.Contains(strategiesRaw, `"enable_boll":true`) { + t.Fatalf("expected strategy config to include updated confidence and indicators, got %s", strategiesRaw) + } +} diff --git a/agent/skill_execution_handlers.go b/agent/skill_execution_handlers.go index 871d1b74..98db45cd 100644 --- a/agent/skill_execution_handlers.go +++ b/agent/skill_execution_handlers.go @@ -3,14 +3,346 @@ package agent import ( "encoding/json" "fmt" + "regexp" + "sort" + "strconv" "strings" + + "nofx/store" ) +var ( + firstIntegerPattern = regexp.MustCompile(`\d+`) + timeframeTokenRE = regexp.MustCompile(`(?i)\b\d{1,2}[mhdw]\b`) +) + +func parseStandaloneInteger(text string) (int, bool) { + match := firstIntegerPattern.FindString(strings.TrimSpace(text)) + if match == "" { + return 0, false + } + value, err := strconv.Atoi(match) + if err != nil { + return 0, false + } + return value, true +} + +func parseEnabledValue(text string) (bool, bool) { + lower := strings.ToLower(strings.TrimSpace(text)) + switch { + case containsAny(lower, []string{"启用", "打开", "开启", "enable", "enabled", "on"}): + return true, true + case containsAny(lower, []string{"禁用", "关闭", "停用", "disable", "disabled", "off"}): + return false, true + default: + return false, false + } +} + +func parseFlagValue(text string, keywords []string) (bool, bool) { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" || !containsAny(lower, keywords) { + return false, false + } + switch { + case containsAny(lower, []string{"启用", "打开", "开启", "使用", "用", "是", "true", "enable", "enabled", "on", "use"}): + return true, true + case containsAny(lower, []string{"禁用", "关闭", "停用", "不用", "不要", "否", "false", "disable", "disabled", "off", "don't use", "do not use"}): + return false, true + default: + return false, false + } +} + +func extractCredentialValue(text string, keywords []string) string { + if value := extractQuotedContent(text); value != "" && containsAny(strings.ToLower(text), keywords) { + return value + } + return extractPostKeywordName(text, keywords) +} + +func parseScanIntervalMinutes(text string) (int, bool) { + if value, ok := extractLabeledInt(text, []string{"扫描间隔", "扫描频率", "scan interval", "scan frequency"}); ok { + return value, true + } + lower := strings.ToLower(strings.TrimSpace(text)) + if !containsAny(lower, []string{"扫描间隔", "扫描频率", "scan interval", "scan frequency"}) { + return 0, false + } + return parseStandaloneInteger(text) +} + +func detectStrategyConfigField(text string) string { + lower := strings.ToLower(strings.TrimSpace(text)) + switch { + case containsAny(lower, []string{"最大持仓", "最多持仓", "max positions"}): + return "max_positions" + case containsAny(lower, []string{"最低置信度", "最小置信度", "min confidence"}): + return "min_confidence" + case containsAny(lower, []string{"btc/eth杠杆", "btc eth杠杆", "btc eth leverage", "btc/eth leverage", "主流币杠杆"}): + return "btceth_max_leverage" + case containsAny(lower, []string{"山寨币杠杆", "altcoin leverage", "alts leverage"}): + return "altcoin_max_leverage" + case containsAny(lower, []string{"ema"}): + return "enable_ema" + case containsAny(lower, []string{"macd"}): + return "enable_macd" + case containsAny(lower, []string{"rsi"}): + return "enable_rsi" + case containsAny(lower, []string{"atr"}): + return "enable_atr" + case containsAny(lower, []string{"boll", "bollinger", "布林"}): + return "enable_boll" + case containsAny(lower, []string{"核心指标"}) && containsAny(lower, []string{"全选", "全部", "全开", "都开", "都启用", "全部启用"}): + return "enable_all_core_indicators" + case containsAny(lower, []string{"主周期", "主时间周期", "primary timeframe"}): + return "primary_timeframe" + case containsAny(lower, []string{"多周期", "时间框架", "timeframes", "selected timeframes"}): + return "selected_timeframes" + default: + return "" + } +} + +func strategyConfigFieldDisplayName(field, lang string) string { + switch field { + case "max_positions": + if lang == "zh" { + return "最大持仓" + } + return "max positions" + case "min_confidence": + if lang == "zh" { + return "最小置信度" + } + return "min confidence" + case "btceth_max_leverage": + if lang == "zh" { + return "BTC/ETH 最大杠杆" + } + return "BTC/ETH max leverage" + case "altcoin_max_leverage": + if lang == "zh" { + return "山寨币最大杠杆" + } + return "altcoin max leverage" + case "enable_ema": + if lang == "zh" { + return "EMA" + } + return "EMA" + case "enable_macd": + if lang == "zh" { + return "MACD" + } + return "MACD" + case "enable_rsi": + if lang == "zh" { + return "RSI" + } + return "RSI" + case "enable_atr": + if lang == "zh" { + return "ATR" + } + return "ATR" + case "enable_boll": + if lang == "zh" { + return "Bollinger" + } + return "Bollinger" + case "enable_all_core_indicators": + if lang == "zh" { + return "全部核心指标" + } + return "all core indicators" + case "primary_timeframe": + if lang == "zh" { + return "主周期" + } + return "primary timeframe" + case "selected_timeframes": + if lang == "zh" { + return "多周期时间框架" + } + return "selected timeframes" + default: + return field + } +} + +func extractStrategyConfigValue(text, field string) (string, bool) { + switch field { + case "max_positions": + if value, ok := extractLabeledInt(text, []string{"最大持仓", "最多持仓", "max positions"}); ok { + return strconv.Itoa(value), true + } + if value, ok := parseStandaloneInteger(text); ok { + return strconv.Itoa(value), true + } + case "min_confidence": + if value, ok := extractLabeledInt(text, []string{"最低置信度", "最小置信度", "min confidence"}); ok { + return strconv.Itoa(value), true + } + if value, ok := parseStandaloneInteger(text); ok { + return strconv.Itoa(value), true + } + case "btceth_max_leverage": + if value, ok := extractLabeledInt(text, []string{"btc/eth杠杆", "btc eth杠杆", "btc/eth leverage", "btc eth leverage", "主流币杠杆"}); ok { + return strconv.Itoa(value), true + } + if value, ok := parseStandaloneInteger(text); ok { + return strconv.Itoa(value), true + } + case "altcoin_max_leverage": + if value, ok := extractLabeledInt(text, []string{"山寨币杠杆", "altcoin leverage", "alts leverage"}); ok { + return strconv.Itoa(value), true + } + if value, ok := parseStandaloneInteger(text); ok { + return strconv.Itoa(value), true + } + case "enable_ema", "enable_macd", "enable_rsi", "enable_atr", "enable_boll": + if enabled, ok := parseEnabledValue(text); ok { + return strconv.FormatBool(enabled), true + } + case "enable_all_core_indicators": + lower := strings.ToLower(strings.TrimSpace(text)) + switch { + case containsAny(lower, []string{"全选", "全部", "全开", "都开", "都启用", "全部启用"}): + return "true", true + case containsAny(lower, []string{"关闭", "停用", "禁用", "全部关闭", "全部禁用"}): + return "false", true + } + case "primary_timeframe": + if tf := extractTimeframeAfterKeywords(text, []string{"主周期", "主时间周期", "primary timeframe", "timeframe"}); tf != "" { + return tf, true + } + case "selected_timeframes": + if tfs := extractTimeframes(text); len(tfs) > 0 { + return strings.Join(tfs, ","), true + } + } + return "", false +} + +type strategyConfigPatch struct { + Field string + Value string +} + +func detectStrategyConfigPatches(text string) []strategyConfigPatch { + seen := map[string]string{} + addPatch := func(field, value string) { + field = strings.TrimSpace(field) + value = strings.TrimSpace(value) + if field == "" || value == "" { + return + } + seen[field] = value + } + + for _, field := range []string{ + "max_positions", + "min_confidence", + "btceth_max_leverage", + "altcoin_max_leverage", + "primary_timeframe", + "selected_timeframes", + "enable_ema", + "enable_macd", + "enable_rsi", + "enable_atr", + "enable_boll", + "enable_all_core_indicators", + } { + if value, ok := extractStrategyConfigValue(text, field); ok { + if field == "enable_all_core_indicators" { + addPatch("enable_ema", value) + addPatch("enable_macd", value) + addPatch("enable_rsi", value) + addPatch("enable_atr", value) + addPatch("enable_boll", value) + continue + } + addPatch(field, value) + } + } + + fields := make([]string, 0, len(seen)) + for field := range seen { + fields = append(fields, field) + } + sort.Strings(fields) + + patches := make([]strategyConfigPatch, 0, len(fields)) + for _, field := range fields { + patches = append(patches, strategyConfigPatch{Field: field, Value: seen[field]}) + } + return patches +} + +func applyStrategyConfigPatch(cfg *store.StrategyConfig, field, value string) error { + switch field { + case "max_positions": + parsed, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("最大持仓需要是整数") + } + cfg.RiskControl.MaxPositions = parsed + case "min_confidence": + parsed, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("最小置信度需要是整数") + } + cfg.RiskControl.MinConfidence = parsed + case "btceth_max_leverage": + parsed, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("BTC/ETH 最大杠杆需要是整数") + } + cfg.RiskControl.BTCETHMaxLeverage = parsed + case "altcoin_max_leverage": + parsed, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("山寨币最大杠杆需要是整数") + } + cfg.RiskControl.AltcoinMaxLeverage = parsed + case "primary_timeframe": + cfg.Indicators.Klines.PrimaryTimeframe = value + case "selected_timeframes": + tfs := strings.Split(value, ",") + cfg.Indicators.Klines.SelectedTimeframes = tfs + cfg.Indicators.Klines.EnableMultiTimeframe = len(tfs) > 1 + case "enable_ema": + cfg.Indicators.EnableEMA = value == "true" + case "enable_macd": + cfg.Indicators.EnableMACD = value == "true" + case "enable_rsi": + cfg.Indicators.EnableRSI = value == "true" + case "enable_atr": + cfg.Indicators.EnableATR = value == "true" + case "enable_boll": + cfg.Indicators.EnableBOLL = value == "true" + default: + return fmt.Errorf("unsupported strategy config field: %s", field) + } + return nil +} + func (a *Agent) executeTraderManagementAction(storeUserID string, userID int64, lang, text string, session skillSession) string { switch session.Action { - case "query": + case "query", "query_list": + return formatReadFastPathResponse(lang, "list_traders", a.toolListTraders(storeUserID)) + case "query_detail": + if detail, ok := a.describeTrader(storeUserID, lang, session.TargetRef); ok { + return detail + } return formatReadFastPathResponse(lang, "list_traders", a.toolListTraders(storeUserID)) case "start", "stop", "delete": + if fieldValue(session, skillDAGStepField) == "" { + setSkillDAGStep(&session, "await_confirmation") + } if msg, waiting := beginConfirmationIfNeeded(userID, lang, &session, defaultIfEmpty(session.TargetRef.Name, session.TargetRef.ID)); waiting { a.saveSkillSession(userID, session) return msg @@ -22,10 +354,13 @@ func (a *Agent) executeTraderManagementAction(storeUserID string, userID int64, var resp string switch session.Action { case "start": + setSkillDAGStep(&session, "execute_start") resp = a.toolStartTrader(storeUserID, session.TargetRef.ID) case "stop": + setSkillDAGStep(&session, "execute_stop") resp = a.toolStopTrader(storeUserID, session.TargetRef.ID) case "delete": + setSkillDAGStep(&session, "execute_delete") resp = a.toolDeleteTrader(storeUserID, session.TargetRef.ID) } a.clearSkillSession(userID) @@ -39,19 +374,121 @@ func (a *Agent) executeTraderManagementAction(storeUserID string, userID int64, return fmt.Sprintf("已完成交易员操作:%s。", session.Action) } return fmt.Sprintf("Completed trader action: %s.", session.Action) - case "update": + case "update", "update_name", "update_bindings": + if session.Action == "update_bindings" { + if fieldValue(session, skillDAGStepField) == "" { + setSkillDAGStep(&session, "collect_bindings") + } + args := manageTraderArgs{Action: "update", TraderID: session.TargetRef.ID} + if match := pickMentionedOption(text, a.loadEnabledModelOptions(storeUserID)); match != nil { + args.AIModelID = match.ID + } + if match := pickMentionedOption(text, a.loadExchangeOptions(storeUserID)); match != nil { + args.ExchangeID = match.ID + } + if match := pickMentionedOption(text, a.loadStrategyOptions(storeUserID)); match != nil { + args.StrategyID = match.ID + } + if args.AIModelID != "" { + setField(&session, "ai_model_id", args.AIModelID) + } + if args.ExchangeID != "" { + setField(&session, "exchange_id", args.ExchangeID) + } + if args.StrategyID != "" { + setField(&session, "strategy_id", args.StrategyID) + } + if value := fieldValue(session, "ai_model_id"); value != "" { + args.AIModelID = value + } + if value := fieldValue(session, "exchange_id"); value != "" { + args.ExchangeID = value + } + if value := fieldValue(session, "strategy_id"); value != "" { + args.StrategyID = value + } + if args.AIModelID == "" && args.ExchangeID == "" && args.StrategyID == "" { + setSkillDAGStep(&session, "collect_bindings") + a.saveSkillSession(userID, session) + if lang == "zh" { + return "这次是更新交易员绑定,请直接说要换成哪个模型、交易所或策略。" + } + return "This action updates trader bindings. Tell me which model, exchange, or strategy to switch to." + } + setSkillDAGStep(&session, "execute_update") + resp := a.toolUpdateTrader(storeUserID, args) + a.clearSkillSession(userID) + if errMsg := parseSkillError(resp); strings.Contains(resp, `"error"`) { + if lang == "zh" { + return "更新交易员绑定失败:" + errMsg + } + return "Failed to update trader bindings: " + errMsg + } + if lang == "zh" { + return "已更新交易员绑定。" + } + return "Updated trader bindings." + } + if fieldValue(session, skillDAGStepField) == "" { + setSkillDAGStep(&session, "collect_name") + } + args := manageTraderArgs{Action: "update", TraderID: session.TargetRef.ID} + if minutes, ok := parseScanIntervalMinutes(text); ok && minutes > 0 { + args.ScanIntervalMinutes = &minutes + } + if value, ok := extractStrategyConfigValue(text, "btceth_max_leverage"); ok { + if parsed, err := strconv.Atoi(value); err == nil { + args.BTCETHLeverage = &parsed + } + } + if value, ok := extractStrategyConfigValue(text, "altcoin_max_leverage"); ok { + if parsed, err := strconv.Atoi(value); err == nil { + args.AltcoinLeverage = &parsed + } + } + if prompt := extractCredentialValue(text, []string{"自定义提示词", "提示词", "custom prompt", "prompt"}); prompt != "" && + containsAny(strings.ToLower(text), []string{"提示词", "prompt"}) { + args.CustomPrompt = prompt + } + if enabled, ok := parseFlagValue(text, []string{"ai500"}); ok { + args.UseAI500 = &enabled + } + if enabled, ok := parseFlagValue(text, []string{"oi top", "oitop", "持仓量排名"}); ok { + args.UseOITop = &enabled + } + if args.ScanIntervalMinutes != nil || args.BTCETHLeverage != nil || args.AltcoinLeverage != nil || args.CustomPrompt != "" || args.UseAI500 != nil || args.UseOITop != nil { + setSkillDAGStep(&session, "execute_update") + resp := a.toolUpdateTrader(storeUserID, args) + a.clearSkillSession(userID) + if errMsg := parseSkillError(resp); strings.Contains(resp, `"error"`) { + if lang == "zh" { + return "更新交易员失败:" + errMsg + } + return "Failed to update trader: " + errMsg + } + if lang == "zh" { + return "已更新交易员配置。" + } + return "Updated trader config." + } newName := extractTraderName(text) if newName == "" { newName = extractPostKeywordName(text, []string{"改成", "改为", "rename to"}) } + if newName != "" { + setField(&session, "name", newName) + } + newName = fieldValue(session, "name") if newName == "" { + setSkillDAGStep(&session, "collect_name") a.saveSkillSession(userID, session) if lang == "zh" { return "目前更新交易员这条 skill 先支持改名。请直接告诉我新的名字。" } return "This trader update skill currently supports renaming first. Tell me the new name." } - args := manageTraderArgs{Action: "update", TraderID: session.TargetRef.ID, Name: newName} + args = manageTraderArgs{Action: "update", TraderID: session.TargetRef.ID, Name: newName} + setSkillDAGStep(&session, "execute_update") resp := a.toolUpdateTrader(storeUserID, args) a.clearSkillSession(userID) if errMsg := parseSkillError(resp); strings.Contains(resp, `"error"`) { @@ -71,7 +508,15 @@ func (a *Agent) executeTraderManagementAction(storeUserID string, userID int64, func (a *Agent) executeExchangeManagementAction(storeUserID string, userID int64, lang, text string, session skillSession) string { switch session.Action { + case "query_detail": + if detail, ok := a.describeExchange(storeUserID, lang, session.TargetRef); ok { + return detail + } + return formatReadFastPathResponse(lang, "get_exchange_configs", a.toolGetExchangeConfigs(storeUserID)) case "delete": + if fieldValue(session, skillDAGStepField) == "" { + setSkillDAGStep(&session, "await_confirmation") + } if msg, waiting := beginConfirmationIfNeeded(userID, lang, &session, defaultIfEmpty(session.TargetRef.Name, session.TargetRef.ID)); waiting { a.saveSkillSession(userID, session) return msg @@ -80,6 +525,7 @@ func (a *Agent) executeExchangeManagementAction(storeUserID string, userID int64 a.saveSkillSession(userID, session) return msg } + setSkillDAGStep(&session, "execute_delete") args, _ := json.Marshal(map[string]any{"action": "delete", "exchange_id": session.TargetRef.ID}) resp := a.toolManageExchangeConfig(storeUserID, string(args)) a.clearSkillSession(userID) @@ -93,28 +539,72 @@ func (a *Agent) executeExchangeManagementAction(storeUserID string, userID int64 return "已删除交易所配置。" } return "Deleted exchange config." - case "update": + case "update", "update_name", "update_status": + if fieldValue(session, skillDAGStepField) == "" { + if session.Action == "update_status" { + setSkillDAGStep(&session, "collect_enabled") + } else { + setSkillDAGStep(&session, "collect_account_name") + } + } accountName := extractTraderName(text) if accountName == "" { accountName = extractPostKeywordName(text, []string{"改成", "改为", "账户名改成", "rename to"}) } - payload := map[string]any{"action": "update", "exchange_id": session.TargetRef.ID} if accountName != "" { + setField(&session, "account_name", accountName) + } + if enabled, ok := parseEnabledValue(text); ok { + setField(&session, "enabled", strconv.FormatBool(enabled)) + } + if value := extractCredentialValue(text, []string{"api key", "apikey", "api_key"}); value != "" { + setField(&session, "api_key", value) + } + if value := extractCredentialValue(text, []string{"secret key", "secret", "secret_key"}); value != "" { + setField(&session, "secret_key", value) + } + if value := extractCredentialValue(text, []string{"passphrase", "密码短语"}); value != "" { + setField(&session, "passphrase", value) + } + if testnet, ok := parseFlagValue(text, []string{"testnet", "测试网"}); ok { + setField(&session, "testnet", strconv.FormatBool(testnet)) + } + payload := map[string]any{"action": "update", "exchange_id": session.TargetRef.ID} + accountName = fieldValue(session, "account_name") + if accountName != "" && session.Action != "update_status" { payload["account_name"] = accountName } - if containsAny(strings.ToLower(text), []string{"启用", "enable"}) { - payload["enabled"] = true + if enabledRaw := fieldValue(session, "enabled"); enabledRaw != "" { + payload["enabled"] = enabledRaw == "true" } - if containsAny(strings.ToLower(text), []string{"禁用", "disable"}) { - payload["enabled"] = false + if value := fieldValue(session, "api_key"); value != "" { + payload["api_key"] = value + } + if value := fieldValue(session, "secret_key"); value != "" { + payload["secret_key"] = value + } + if value := fieldValue(session, "passphrase"); value != "" { + payload["passphrase"] = value + } + if value := fieldValue(session, "testnet"); value != "" { + payload["testnet"] = value == "true" + } + if session.Action == "update_status" { + delete(payload, "account_name") } if len(payload) == 2 { + if session.Action == "update_status" { + setSkillDAGStep(&session, "collect_enabled") + } else { + setSkillDAGStep(&session, "collect_account_name") + } a.saveSkillSession(userID, session) if lang == "zh" { - return "目前更新交易所 skill 先支持改账户名和启用/禁用。请告诉我你要改什么。" + return "目前更新交易所 skill 支持改账户名、启用状态、API Key、Secret、Passphrase 和 testnet。请告诉我你要改什么。" } - return "This exchange update skill currently supports renaming and enable/disable. Tell me what to change." + return "This exchange update skill supports account name, enabled state, API key, secret, passphrase, and testnet." } + setSkillDAGStep(&session, "execute_update") raw, _ := json.Marshal(payload) resp := a.toolManageExchangeConfig(storeUserID, string(raw)) a.clearSkillSession(userID) @@ -135,7 +625,15 @@ func (a *Agent) executeExchangeManagementAction(storeUserID string, userID int64 func (a *Agent) executeModelManagementAction(storeUserID string, userID int64, lang, text string, session skillSession) string { switch session.Action { + case "query_detail": + if detail, ok := a.describeModel(storeUserID, lang, session.TargetRef); ok { + return detail + } + return formatReadFastPathResponse(lang, "get_model_configs", a.toolGetModelConfigs(storeUserID)) case "delete": + if fieldValue(session, skillDAGStepField) == "" { + setSkillDAGStep(&session, "await_confirmation") + } if msg, waiting := beginConfirmationIfNeeded(userID, lang, &session, defaultIfEmpty(session.TargetRef.Name, session.TargetRef.ID)); waiting { a.saveSkillSession(userID, session) return msg @@ -144,6 +642,7 @@ func (a *Agent) executeModelManagementAction(storeUserID string, userID int64, l a.saveSkillSession(userID, session) return msg } + setSkillDAGStep(&session, "execute_delete") raw, _ := json.Marshal(map[string]any{"action": "delete", "model_id": session.TargetRef.ID}) resp := a.toolManageModelConfig(storeUserID, string(raw)) a.clearSkillSession(userID) @@ -157,37 +656,91 @@ func (a *Agent) executeModelManagementAction(storeUserID string, userID int64, l return "已删除模型配置。" } return "Deleted model config." - case "update": + case "update", "update_name", "update_endpoint", "update_status": + if fieldValue(session, skillDAGStepField) == "" { + switch session.Action { + case "update_status": + setSkillDAGStep(&session, "collect_enabled") + case "update_endpoint": + setSkillDAGStep(&session, "collect_custom_api_url") + default: + setSkillDAGStep(&session, "collect_custom_model_name") + } + } payload := map[string]any{"action": "update", "model_id": session.TargetRef.ID} if url := extractURL(text); url != "" { - payload["custom_api_url"] = url + setField(&session, "custom_api_url", url) } - if containsAny(strings.ToLower(text), []string{"启用", "enable"}) { - payload["enabled"] = true + if enabled, ok := parseEnabledValue(text); ok { + setField(&session, "enabled", strconv.FormatBool(enabled)) } - if containsAny(strings.ToLower(text), []string{"禁用", "disable"}) { - payload["enabled"] = false + if apiKey := extractCredentialValue(text, []string{"api key", "apikey", "api_key"}); apiKey != "" { + setField(&session, "api_key", apiKey) } if modelName := extractPostKeywordName(text, []string{"model name", "模型名", "模型名称", "改成"}); modelName != "" { - payload["custom_model_name"] = modelName + setField(&session, "custom_model_name", modelName) + } + if value := fieldValue(session, "custom_api_url"); value != "" { + payload["custom_api_url"] = value + } + if value := fieldValue(session, "enabled"); value != "" { + payload["enabled"] = value == "true" + } + if value := fieldValue(session, "api_key"); value != "" { + payload["api_key"] = value + } + if value := fieldValue(session, "custom_model_name"); value != "" { + payload["custom_model_name"] = value + } + if session.Action == "update_name" { + delete(payload, "custom_api_url") + delete(payload, "enabled") + delete(payload, "api_key") + } + if session.Action == "update_status" { + delete(payload, "custom_api_url") + delete(payload, "custom_model_name") + delete(payload, "api_key") + } + if session.Action == "update_endpoint" { + delete(payload, "custom_model_name") + delete(payload, "enabled") + delete(payload, "api_key") } if len(payload) == 2 { + switch session.Action { + case "update_status": + setSkillDAGStep(&session, "collect_enabled") + case "update_endpoint": + setSkillDAGStep(&session, "collect_custom_api_url") + default: + setSkillDAGStep(&session, "collect_custom_model_name") + } a.saveSkillSession(userID, session) if lang == "zh" { - return "目前更新模型 skill 先支持改 URL、模型名和启用状态。请告诉我你要改什么。" + return "目前更新模型 skill 支持改 API Key、URL、模型名和启用状态。请告诉我你要改什么。" } - return "This model update skill currently supports URL, model name, and enabled state." + return "This model update skill supports API key, URL, model name, and enabled state." } + setSkillDAGStep(&session, "execute_update") raw, _ := json.Marshal(payload) resp := a.toolManageModelConfig(storeUserID, string(raw)) - a.clearSkillSession(userID) if errMsg := parseSkillError(resp); strings.Contains(resp, `"error"`) { + a.saveSkillSession(userID, session) if lang == "zh" { + if strings.Contains(errMsg, "cannot enable model config before API key is configured") { + return "更新模型配置失败:这个模型还没有配置 API Key,暂时不能启用。你可以直接把 API Key 发给我,我帮你继续配置。" + } return "更新模型配置失败:" + errMsg } + a.saveSkillSession(userID, session) return "Failed to update model config: " + errMsg } + a.clearSkillSession(userID) if lang == "zh" { + if session.Action == "update_status" { + return "已更新模型配置启用状态。" + } return "已更新模型配置。" } return "Updated model config." @@ -198,6 +751,13 @@ func (a *Agent) executeModelManagementAction(storeUserID string, userID int64, l func (a *Agent) executeStrategyManagementAction(storeUserID string, userID int64, lang, text string, session skillSession) string { switch session.Action { + case "query", "query_list": + return formatReadFastPathResponse(lang, "get_strategies", a.toolGetStrategies(storeUserID)) + case "query_detail": + if detail, ok := a.describeStrategy(storeUserID, lang, session.TargetRef); ok { + return detail + } + return formatReadFastPathResponse(lang, "get_strategies", a.toolGetStrategies(storeUserID)) case "activate": raw, _ := json.Marshal(map[string]any{"action": "activate", "strategy_id": session.TargetRef.ID}) resp := a.toolManageStrategy(storeUserID, string(raw)) @@ -213,17 +773,26 @@ func (a *Agent) executeStrategyManagementAction(storeUserID string, userID int64 } return "Activated strategy." case "duplicate": + if fieldValue(session, skillDAGStepField) == "" { + setSkillDAGStep(&session, "collect_name") + } newName := extractTraderName(text) if newName == "" { newName = extractPostKeywordName(text, []string{"叫", "名为", "改成", "rename to"}) } + if newName != "" { + setField(&session, "name", newName) + } + newName = fieldValue(session, "name") if newName == "" { + setSkillDAGStep(&session, "collect_name") a.saveSkillSession(userID, session) if lang == "zh" { return "复制策略时,我还需要一个新名称。" } return "I still need a new name for the duplicated strategy." } + setSkillDAGStep(&session, "execute_duplicate") raw, _ := json.Marshal(map[string]any{"action": "duplicate", "strategy_id": session.TargetRef.ID, "name": newName}) resp := a.toolManageStrategy(storeUserID, string(raw)) a.clearSkillSession(userID) @@ -238,6 +807,88 @@ func (a *Agent) executeStrategyManagementAction(storeUserID string, userID int64 } return fmt.Sprintf("Duplicated strategy as %q.", newName) case "delete": + if fieldValue(session, skillDAGStepField) == "" { + setSkillDAGStep(&session, "await_confirmation") + } + if fieldValue(session, "bulk_scope") == "all" { + strategies, err := a.store.Strategy().List(storeUserID) + if err != nil { + if lang == "zh" { + return "读取策略列表失败:" + err.Error() + } + return "Failed to load strategies: " + err.Error() + } + + deletable := make([]*store.Strategy, 0, len(strategies)) + skippedDefault := 0 + for _, strategy := range strategies { + if strategy == nil { + continue + } + if strategy.IsDefault { + skippedDefault++ + continue + } + deletable = append(deletable, strategy) + } + if len(deletable) == 0 { + a.clearSkillSession(userID) + if lang == "zh" { + return "当前没有可删除的自定义策略。" + } + return "There are no user-created strategies to delete." + } + + targetLabel := fmt.Sprintf("全部自定义策略(共 %d 个)", len(deletable)) + if msg, waiting := beginConfirmationIfNeeded(userID, lang, &session, targetLabel); waiting { + a.saveSkillSession(userID, session) + return msg + } + if msg, waiting := awaitingConfirmationButNotApproved(lang, session, text); waiting { + a.saveSkillSession(userID, session) + return msg + } + + setSkillDAGStep(&session, "execute_delete") + deletedNames := make([]string, 0, len(deletable)) + failedNames := make([]string, 0) + for _, strategy := range deletable { + raw, _ := json.Marshal(map[string]any{"action": "delete", "strategy_id": strategy.ID}) + resp := a.toolManageStrategy(storeUserID, string(raw)) + if errMsg := parseSkillError(resp); strings.Contains(resp, `"error"`) { + failedNames = append(failedNames, fmt.Sprintf("%s(%s)", strategy.Name, errMsg)) + continue + } + deletedNames = append(deletedNames, strategy.Name) + } + a.clearSkillSession(userID) + + if lang == "zh" { + parts := []string{fmt.Sprintf("批量删除策略已完成:成功删除 %d 个。", len(deletedNames))} + if skippedDefault > 0 { + parts = append(parts, fmt.Sprintf("已跳过系统默认策略 %d 个。", skippedDefault)) + } + if len(failedNames) > 0 { + parts = append(parts, "删除失败:"+strings.Join(failedNames, ";")) + } + if len(deletedNames) > 0 { + parts = append(parts, "已删除:"+strings.Join(deletedNames, "、")) + } + return strings.Join(parts, "\n") + } + + parts := []string{fmt.Sprintf("Bulk strategy deletion finished: deleted %d strategy(s).", len(deletedNames))} + if skippedDefault > 0 { + parts = append(parts, fmt.Sprintf("Skipped %d default strategy(ies).", skippedDefault)) + } + if len(failedNames) > 0 { + parts = append(parts, "Failed: "+strings.Join(failedNames, "; ")) + } + if len(deletedNames) > 0 { + parts = append(parts, "Deleted: "+strings.Join(deletedNames, ", ")) + } + return strings.Join(parts, "\n") + } if msg, waiting := beginConfirmationIfNeeded(userID, lang, &session, defaultIfEmpty(session.TargetRef.Name, session.TargetRef.ID)); waiting { a.saveSkillSession(userID, session) return msg @@ -246,6 +897,7 @@ func (a *Agent) executeStrategyManagementAction(storeUserID string, userID int64 a.saveSkillSession(userID, session) return msg } + setSkillDAGStep(&session, "execute_delete") raw, _ := json.Marshal(map[string]any{"action": "delete", "strategy_id": session.TargetRef.ID}) resp := a.toolManageStrategy(storeUserID, string(raw)) a.clearSkillSession(userID) @@ -259,18 +911,33 @@ func (a *Agent) executeStrategyManagementAction(storeUserID string, userID int64 return "已删除策略。" } return "Deleted strategy." - case "update": + case "update", "update_name", "update_config", "update_prompt": + if session.Action == "update_prompt" { + return a.executeStrategyPromptUpdate(storeUserID, userID, lang, text, session) + } + if session.Action == "update_config" { + return a.executeStrategyConfigUpdate(storeUserID, userID, lang, text, session) + } + if fieldValue(session, skillDAGStepField) == "" { + setSkillDAGStep(&session, "collect_name") + } newName := extractTraderName(text) if newName == "" { newName = extractPostKeywordName(text, []string{"改成", "改为", "rename to"}) } + if newName != "" { + setField(&session, "name", newName) + } + newName = fieldValue(session, "name") if newName == "" { + setSkillDAGStep(&session, "collect_name") a.saveSkillSession(userID, session) if lang == "zh" { return "目前更新策略 skill 先支持改名。请告诉我新的策略名称。" } return "This strategy update skill currently supports renaming first." } + setSkillDAGStep(&session, "execute_update") raw, _ := json.Marshal(map[string]any{"action": "update", "strategy_id": session.TargetRef.ID, "name": newName}) resp := a.toolManageStrategy(storeUserID, string(raw)) a.clearSkillSession(userID) @@ -289,20 +956,267 @@ func (a *Agent) executeStrategyManagementAction(storeUserID string, userID int64 } } +func (a *Agent) executeStrategyPromptUpdate(storeUserID string, userID int64, lang, text string, session skillSession) string { + if fieldValue(session, skillDAGStepField) == "" { + setSkillDAGStep(&session, "collect_prompt") + } + strategy, cfg, err := a.loadStrategyConfigForUpdate(storeUserID, session.TargetRef.ID) + if err != nil { + if lang == "zh" { + return "读取策略失败:" + err.Error() + } + return "Failed to load strategy: " + err.Error() + } + + prompt := extractQuotedContent(text) + if prompt == "" { + prompt = extractPostKeywordName(text, []string{"prompt改成", "prompt 改成", "提示词改成", "提示词改为", "custom prompt 改成"}) + } + if prompt != "" { + setField(&session, "prompt", prompt) + } + prompt = fieldValue(session, "prompt") + if prompt == "" { + setSkillDAGStep(&session, "collect_prompt") + a.saveSkillSession(userID, session) + if lang == "zh" { + return "这次是更新策略 prompt,请直接把新的 prompt 内容发给我,最好放在引号里。" + } + return "This action updates the strategy prompt. Send me the new prompt text, ideally inside quotes." + } + + cfg.CustomPrompt = prompt + setSkillDAGStep(&session, "execute_update") + return a.persistStrategyConfigUpdate(storeUserID, userID, lang, strategy, cfg, "已更新策略 prompt。", "Updated strategy prompt.") +} + +func (a *Agent) executeStrategyConfigUpdate(storeUserID string, userID int64, lang, text string, session skillSession) string { + if _, ok := getSkillDAG("strategy_management", "update_config"); ok { + if fieldValue(session, skillDAGStepField) == "" { + setSkillDAGStep(&session, "resolve_config_field") + } + } + + currentStep, _ := currentSkillDAGStep(session) + strategy, cfg, err := a.loadStrategyConfigForUpdate(storeUserID, session.TargetRef.ID) + if err != nil { + if lang == "zh" { + return "读取策略失败:" + err.Error() + } + return "Failed to load strategy: " + err.Error() + } + + if fieldValue(session, "config_field") == "" && fieldValue(session, "config_value") == "" { + patches := detectStrategyConfigPatches(text) + if len(patches) > 1 { + changed := make([]string, 0, len(patches)) + for _, patch := range patches { + if err := applyStrategyConfigPatch(&cfg, patch.Field, patch.Value); err != nil { + a.saveSkillSession(userID, session) + if lang == "zh" { + return "更新策略参数失败:" + err.Error() + } + return "Failed to update strategy config: " + err.Error() + } + changed = append(changed, strategyConfigFieldDisplayName(patch.Field, lang)) + } + cfg.ClampLimits() + setSkillDAGStep(&session, "apply_field_update") + setSkillDAGStep(&session, "execute_update") + msgZH := "已更新策略参数:" + strings.Join(changed, "、") + "。" + msgEN := "Updated strategy config fields: " + strings.Join(changed, ", ") + "." + return a.persistStrategyConfigUpdate(storeUserID, userID, lang, strategy, cfg, msgZH, msgEN) + } + } + + field := fieldValue(session, "config_field") + if field == "" { + field = detectStrategyConfigField(text) + if field != "" { + setField(&session, "config_field", field) + if currentStep.ID == "resolve_config_field" { + advanceSkillDAGStep(&session, currentStep.ID) + currentStep, _ = currentSkillDAGStep(session) + } + } + } + if field == "" { + setSkillDAGStep(&session, "resolve_config_field") + a.saveSkillSession(userID, session) + if lang == "zh" { + return "这次是更新策略参数。我当前先支持这些字段:最大持仓、最低置信度、主周期、多周期时间框架。请先告诉我要改哪个字段。" + } + return "This action updates strategy config. I currently support max positions, min confidence, primary timeframe, and selected timeframes. Tell me which field to change first." + } + + if value, ok := extractStrategyConfigValue(text, field); ok { + setField(&session, "config_value", value) + if currentStep.ID == "resolve_config_value" { + advanceSkillDAGStep(&session, currentStep.ID) + currentStep, _ = currentSkillDAGStep(session) + } + } + value := fieldValue(session, "config_value") + if value == "" { + setSkillDAGStep(&session, "resolve_config_value") + a.saveSkillSession(userID, session) + if lang == "zh" { + return fmt.Sprintf("要更新策略参数,我还需要 %s 的目标值。", strategyConfigFieldDisplayName(field, lang)) + } + return fmt.Sprintf("I still need the target value for %s.", strategyConfigFieldDisplayName(field, lang)) + } + + if err := applyStrategyConfigPatch(&cfg, field, value); err != nil { + setSkillDAGStep(&session, "resolve_config_value") + a.saveSkillSession(userID, session) + if lang == "zh" { + return err.Error() + } + return err.Error() + } + + cfg.ClampLimits() + changed := []string{field} + displayChanged := make([]string, 0, len(changed)) + for _, item := range changed { + displayChanged = append(displayChanged, strategyConfigFieldDisplayName(item, lang)) + } + msgZH := "已更新策略参数:" + strings.Join(displayChanged, "、") + "。" + msgEN := "Updated strategy config fields: " + strings.Join(displayChanged, ", ") + "." + setSkillDAGStep(&session, "apply_field_update") + setSkillDAGStep(&session, "execute_update") + return a.persistStrategyConfigUpdate(storeUserID, userID, lang, strategy, cfg, msgZH, msgEN) +} + +func (a *Agent) loadStrategyConfigForUpdate(storeUserID, strategyID string) (*store.Strategy, store.StrategyConfig, error) { + strategy, err := a.store.Strategy().Get(storeUserID, strategyID) + if err != nil { + return nil, store.StrategyConfig{}, err + } + cfg := store.GetDefaultStrategyConfig("zh") + if strings.TrimSpace(strategy.Config) != "" { + _ = json.Unmarshal([]byte(strategy.Config), &cfg) + } + return strategy, cfg, nil +} + +func (a *Agent) persistStrategyConfigUpdate(storeUserID string, userID int64, lang string, strategy *store.Strategy, cfg store.StrategyConfig, zhMsg, enMsg string) string { + rawConfig, err := json.Marshal(cfg) + if err != nil { + if lang == "zh" { + return "序列化策略配置失败:" + err.Error() + } + return "Failed to serialize strategy config: " + err.Error() + } + raw, _ := json.Marshal(map[string]any{ + "action": "update", + "strategy_id": strategy.ID, + "config": json.RawMessage(rawConfig), + }) + resp := a.toolManageStrategy(storeUserID, string(raw)) + a.clearSkillSession(userID) + if errMsg := parseSkillError(resp); strings.Contains(resp, `"error"`) { + if lang == "zh" { + return "更新策略失败:" + errMsg + } + return "Failed to update strategy: " + errMsg + } + if lang == "zh" { + return zhMsg + } + return enMsg +} + +func extractQuotedContent(text string) string { + if matches := quotedNamePattern.FindStringSubmatch(text); len(matches) == 2 { + return strings.TrimSpace(matches[1]) + } + return "" +} + +func extractLabeledInt(text string, labels []string) (int, bool) { + lower := strings.ToLower(text) + for _, label := range labels { + idx := strings.Index(lower, strings.ToLower(label)) + if idx < 0 { + continue + } + segment := text[idx:] + if match := firstIntegerPattern.FindString(segment); match != "" { + if value, err := strconv.Atoi(match); err == nil { + return value, true + } + } + } + return 0, false +} + +func extractTimeframeAfterKeywords(text string, labels []string) string { + lower := strings.ToLower(text) + for _, label := range labels { + idx := strings.Index(lower, strings.ToLower(label)) + if idx < 0 { + continue + } + segment := text[idx:] + if match := timeframeTokenRE.FindString(segment); match != "" { + return strings.ToLower(match) + } + } + return "" +} + +func extractTimeframes(text string) []string { + matches := timeframeTokenRE.FindAllString(text, -1) + if len(matches) == 0 { + return nil + } + seen := make(map[string]struct{}, len(matches)) + out := make([]string, 0, len(matches)) + for _, match := range matches { + tf := strings.ToLower(strings.TrimSpace(match)) + if tf == "" { + continue + } + if _, ok := seen[tf]; ok { + continue + } + seen[tf] = struct{}{} + out = append(out, tf) + } + return out +} + func (a *Agent) handleTraderDiagnosisSkill(storeUserID, lang, text string) string { raw := a.toolListTraders(storeUserID) list := formatReadFastPathResponse(lang, "list_traders", raw) if lang == "zh" { - return "现象:这是交易员运行诊断问题。\n优先排查:\n1. 交易员是否已创建并处于运行状态。\n2. 绑定的模型、交易所、策略是否齐全。\n3. 是“没有启动”、还是“启动了但 AI 没有下单”、还是“下单失败”。\n当前交易员概览:\n" + list + reply := "现象:这是交易员运行诊断问题。\n优先排查:\n1. 交易员是否已创建并处于运行状态。\n2. 绑定的模型、交易所、策略是否齐全。\n3. 是“没有启动”、还是“启动了但 AI 没有下单”、还是“下单失败”。\n当前交易员概览:\n" + list + if excerpt := backendLogDiagnosisExcerpt(lang, text, "trader"); excerpt != "" { + reply += "\n" + excerpt + } + return reply } - return "This looks like a trader diagnosis issue.\nCheck whether the trader exists, is running, and has model/exchange/strategy bindings.\nCurrent trader overview:\n" + list + reply := "This looks like a trader diagnosis issue.\nCheck whether the trader exists, is running, and has model/exchange/strategy bindings.\nCurrent trader overview:\n" + list + if excerpt := backendLogDiagnosisExcerpt(lang, text, "trader"); excerpt != "" { + reply += "\n" + excerpt + } + return reply } func (a *Agent) handleStrategyDiagnosisSkill(storeUserID, lang, text string) string { raw := a.toolGetStrategies(storeUserID) list := formatReadFastPathResponse(lang, "get_strategies", raw) if lang == "zh" { - return "现象:这是策略或提示词生效问题。\n优先排查:\n1. 你改的是策略模板,还是 trader 上的 custom prompt。\n2. 策略是否真的保存成功。\n3. 运行结果不符合预期,是配置问题还是市场条件问题。\n当前策略概览:\n" + list + reply := "现象:这是策略或提示词生效问题。\n优先排查:\n1. 你改的是策略模板,还是 trader 上的 custom prompt。\n2. 策略是否真的保存成功。\n3. 运行结果不符合预期,是配置问题还是市场条件问题。\n当前策略概览:\n" + list + if excerpt := backendLogDiagnosisExcerpt(lang, text, "strategy"); excerpt != "" { + reply += "\n" + excerpt + } + return reply } - return "This looks like a strategy or prompt diagnosis issue.\nCheck whether you changed the strategy template or a trader-specific prompt override.\nCurrent strategy overview:\n" + list + reply := "This looks like a strategy or prompt diagnosis issue.\nCheck whether you changed the strategy template or a trader-specific prompt override.\nCurrent strategy overview:\n" + list + if excerpt := backendLogDiagnosisExcerpt(lang, text, "strategy"); excerpt != "" { + reply += "\n" + excerpt + } + return reply } diff --git a/agent/skill_management_handlers.go b/agent/skill_management_handlers.go index 6b606678..ce7bba2b 100644 --- a/agent/skill_management_handlers.go +++ b/agent/skill_management_handlers.go @@ -4,7 +4,10 @@ import ( "encoding/json" "fmt" "regexp" + "sort" "strings" + + "nofx/store" ) var urlPattern = regexp.MustCompile(`https://[^\s"'<>]+`) @@ -15,7 +18,7 @@ func detectTraderManagementIntent(text string) bool { return false } return containsAny(lower, []string{"交易员", "trader", "agent"}) && - containsAny(lower, []string{"修改", "编辑", "更新", "删除", "启动", "停止", "查看", "查询", "列出", "rename", "update", "delete", "start", "stop", "list", "show"}) + containsAny(lower, []string{"修改", "编辑", "更新", "改", "改一下", "删除", "删了", "启动", "停止", "查看", "查询", "列出", "rename", "update", "delete", "start", "stop", "list", "show"}) } func detectExchangeManagementIntent(text string) bool { @@ -24,7 +27,7 @@ func detectExchangeManagementIntent(text string) bool { return false } return containsAny(lower, []string{"交易所", "exchange", "okx", "binance", "bybit", "gate", "kucoin", "hyperliquid"}) && - containsAny(lower, []string{"创建", "新建", "修改", "编辑", "更新", "删除", "查询", "查看", "列出", "create", "update", "delete", "list", "show"}) + containsAny(lower, []string{"创建", "新建", "修改", "编辑", "更新", "改", "改一下", "删除", "删了", "查询", "查看", "列出", "启用", "禁用", "改名", "rename", "create", "update", "delete", "list", "show", "enable", "disable"}) } func detectModelManagementIntent(text string) bool { @@ -33,7 +36,7 @@ func detectModelManagementIntent(text string) bool { return false } return containsAny(lower, []string{"模型", "model", "provider", "deepseek", "openai", "claude", "gemini", "qwen", "kimi", "grok", "minimax"}) && - containsAny(lower, []string{"创建", "新建", "修改", "编辑", "更新", "删除", "查询", "查看", "列出", "create", "update", "delete", "list", "show"}) + containsAny(lower, []string{"创建", "新建", "修改", "编辑", "更新", "改", "改一下", "删除", "删了", "查询", "查看", "列出", "启用", "禁用", "改名", "rename", "create", "update", "delete", "list", "show", "enable", "disable"}) } func detectStrategyManagementIntent(text string) bool { @@ -41,8 +44,11 @@ func detectStrategyManagementIntent(text string) bool { if lower == "" { return false } + if wantsDefaultStrategyConfig(text) { + return true + } return containsAny(lower, []string{"策略", "strategy"}) && - containsAny(lower, []string{"创建", "新建", "修改", "编辑", "更新", "删除", "查询", "查看", "列出", "激活", "复制", "create", "update", "delete", "list", "show", "activate", "duplicate"}) + containsAny(lower, []string{"创建", "新建", "修改", "编辑", "更新", "改", "改一下", "改成", "改为", "删除", "删了", "查询", "查看", "列出", "激活", "复制", "参数", "配置", "详情", "详细", "prompt", "提示词", "什么样", "怎么样", "create", "update", "delete", "list", "show", "activate", "duplicate", "detail", "details", "config", "configuration", "parameter", "prompt", "what kind"}) } func detectTraderDiagnosisSkill(text string) bool { @@ -62,8 +68,9 @@ func detectManagementAction(text string, domain string) string { if lower == "" { return "" } + hasUpdateVerb := containsAny(lower, []string{"修改", "编辑", "更新", "改", "rename", "update", "切换", "换成", "换到"}) switch { - case containsAny(lower, []string{"删除", "删掉", "remove", "delete"}): + case containsAny(lower, []string{"删除", "删掉", "删了", "remove", "delete"}): return "delete" case containsAny(lower, []string{"启动", "开始", "run", "start"}) && domain == "trader": return "start" @@ -73,10 +80,32 @@ func detectManagementAction(text string, domain string) string { return "activate" case containsAny(lower, []string{"复制", "duplicate"}) && domain == "strategy": return "duplicate" + case containsAny(lower, []string{"改名", "重命名", "rename"}): + return "update_name" + case domain == "trader" && containsAny(lower, []string{"换模型", "换交易所", "换策略", "绑定", "切换模型", "切换交易所", "切换策略"}): + return "update_bindings" + case (domain == "exchange" || domain == "model") && containsAny(lower, []string{"启用", "禁用", "enable", "disable"}): + return "update_status" + case domain == "model" && hasUpdateVerb && containsAny(lower, []string{"url", "endpoint", "地址", "接口"}): + return "update_endpoint" + case domain == "strategy" && hasUpdateVerb && containsAny(lower, []string{"prompt", "提示词"}): + return "update_prompt" + case domain == "strategy" && hasUpdateVerb && containsAny(lower, []string{ + "参数", "配置", "config", "configuration", "parameter", + "最大持仓", "最小置信度", "最低置信度", "主周期", "多周期", "时间框架", + "btc/eth杠杆", "btc eth杠杆", "山寨币杠杆", + "核心指标", "ema", "macd", "rsi", "atr", "boll", "bollinger", "布林", + }): + return "update_config" case containsAny(lower, []string{"修改", "编辑", "更新", "改", "rename", "update"}): return "update" + case domain == "trader" && containsAny(lower, []string{"运行中的", "在跑", "running"}): + return "query_running" + case !containsAny(lower, []string{"创建", "新建", "create", "new"}) && + containsAny(lower, []string{"详情", "详细", "prompt", "提示词", "什么样", "怎么样", "detail", "details", "what kind"}): + return "query_detail" case containsAny(lower, []string{"查询", "查看", "列出", "list", "show", "有哪些"}): - return "query" + return "query_list" case containsAny(lower, []string{"创建", "新建", "加一个", "create", "new"}): return "create" default: @@ -152,6 +181,21 @@ func fieldValue(session skillSession, key string) string { return strings.TrimSpace(session.Fields[key]) } +func textMeansAllTargets(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return false + } + return containsAny(lower, []string{ + "全部", "所有", "全都", "全部策略", "所有策略", + "all", "all strategies", "every strategy", + }) +} + +func supportsBulkTargetSelection(skillName, action string) bool { + return skillName == "strategy_management" && action == "delete" +} + func resolveTargetFromText(text string, options []traderSkillOption, existing *EntityReference) *EntityReference { if existing != nil && (existing.ID != "" || existing.Name != "") { return existing @@ -173,6 +217,18 @@ func (a *Agent) handleTraderManagementSkill(storeUserID string, userID int64, la if action == "" || action == "create" { return "", false } + if action == "query_running" { + answer := formatReadFastPathResponse(lang, "list_traders", a.toolListTraders(storeUserID)) + return applyTraderQueryFilter(lang, answer, a.toolListTraders(storeUserID), "running_only"), true + } + if action == "query_detail" { + options := a.loadTraderOptions(storeUserID) + target := resolveTargetFromText(text, options, session.TargetRef) + if detail, ok := a.describeTrader(storeUserID, lang, target); ok { + return detail, true + } + return formatReadFastPathResponse(lang, "list_traders", a.toolListTraders(storeUserID)), true + } return a.handleSimpleEntitySkill(storeUserID, userID, lang, text, session, "trader_management", action, a.loadTraderOptions(storeUserID)) } @@ -186,7 +242,13 @@ func (a *Agent) handleExchangeManagementSkill(storeUserID string, userID int64, } options := a.loadExchangeOptions(storeUserID) switch action { - case "query": + case "query_list": + return formatReadFastPathResponse(lang, "get_exchange_configs", a.toolGetExchangeConfigs(storeUserID)), true + case "query_detail": + target := resolveTargetFromText(text, options, session.TargetRef) + if detail, ok := a.describeExchange(storeUserID, lang, target); ok { + return detail, true + } return formatReadFastPathResponse(lang, "get_exchange_configs", a.toolGetExchangeConfigs(storeUserID)), true case "create": return a.handleExchangeCreateSkill(storeUserID, userID, lang, text, session), true @@ -205,7 +267,13 @@ func (a *Agent) handleModelManagementSkill(storeUserID string, userID int64, lan } options := a.loadEnabledModelOptions(storeUserID) switch action { - case "query": + case "query_list": + return formatReadFastPathResponse(lang, "get_model_configs", a.toolGetModelConfigs(storeUserID)), true + case "query_detail": + target := resolveTargetFromText(text, options, session.TargetRef) + if detail, ok := a.describeModel(storeUserID, lang, target); ok { + return detail, true + } return formatReadFastPathResponse(lang, "get_model_configs", a.toolGetModelConfigs(storeUserID)), true case "create": return a.handleModelCreateSkill(storeUserID, userID, lang, text, session), true @@ -219,12 +287,24 @@ func (a *Agent) handleStrategyManagementSkill(storeUserID string, userID int64, if session.Name == "strategy_management" && session.Action != "" { action = session.Action } + if action == "" && wantsStrategyDetails(text) { + action = "query_detail" + } if action == "" { return "", false } options := a.loadStrategyOptions(storeUserID) switch action { - case "query": + case "query_detail": + if wantsDefaultStrategyConfig(text) { + return a.describeDefaultStrategyConfig(lang), true + } + target := resolveTargetFromText(text, options, session.TargetRef) + if detail, ok := a.describeStrategy(storeUserID, lang, target); ok { + return detail, true + } + return formatReadFastPathResponse(lang, "get_strategies", a.toolGetStrategies(storeUserID)), true + case "query_list": return formatReadFastPathResponse(lang, "get_strategies", a.toolGetStrategies(storeUserID)), true case "create": return a.handleStrategyCreateSkill(storeUserID, userID, lang, text, session), true @@ -233,6 +313,350 @@ func (a *Agent) handleStrategyManagementSkill(storeUserID string, userID int64, } } +func wantsStrategyDetails(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return false + } + return containsAny(lower, []string{ + "什么样", "怎么样", "详情", "详细", "参数", "配置", "prompt", "提示词", + "what kind", "details", "detail", "config", "configuration", "parameter", "prompt", + }) +} + +func wantsDefaultStrategyConfig(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return false + } + return containsAny(lower, []string{ + "默认配置", "默认策略", "默认模板", "模板配置", + "default config", "default strategy", "default template", + }) +} + +func (a *Agent) describeStrategy(storeUserID, lang string, target *EntityReference) (string, bool) { + if a.store == nil { + return "", false + } + + var strategy *store.Strategy + var err error + if target != nil && strings.TrimSpace(target.ID) != "" { + strategy, err = a.store.Strategy().Get(storeUserID, strings.TrimSpace(target.ID)) + } else if target != nil && strings.TrimSpace(target.Name) != "" { + strategies, listErr := a.store.Strategy().List(storeUserID) + if listErr != nil { + return "", false + } + for _, item := range strategies { + if item != nil && strings.EqualFold(strings.TrimSpace(item.Name), strings.TrimSpace(target.Name)) { + strategy = item + break + } + } + } else { + strategies, listErr := a.store.Strategy().List(storeUserID) + if listErr != nil || len(strategies) != 1 { + return "", false + } + strategy = strategies[0] + } + if err != nil || strategy == nil { + return "", false + } + + var cfg store.StrategyConfig + if strings.TrimSpace(strategy.Config) != "" { + _ = json.Unmarshal([]byte(strategy.Config), &cfg) + } + + return formatStrategyDetailResponse(lang, strategy, cfg), true +} + +func formatStrategyDetailResponse(lang string, strategy *store.Strategy, cfg store.StrategyConfig) string { + name := strings.TrimSpace(strategy.Name) + if name == "" { + name = strings.TrimSpace(strategy.ID) + } + + sourceBits := make([]string, 0, 4) + if strings.TrimSpace(cfg.CoinSource.SourceType) != "" { + sourceBits = append(sourceBits, cfg.CoinSource.SourceType) + } + if cfg.CoinSource.UseAI500 { + sourceBits = append(sourceBits, fmt.Sprintf("AI500=%d", cfg.CoinSource.AI500Limit)) + } + if cfg.CoinSource.UseOITop { + sourceBits = append(sourceBits, fmt.Sprintf("OITop=%d", cfg.CoinSource.OITopLimit)) + } + if cfg.CoinSource.UseOILow { + sourceBits = append(sourceBits, fmt.Sprintf("OILow=%d", cfg.CoinSource.OILowLimit)) + } + if len(cfg.CoinSource.StaticCoins) > 0 { + sourceBits = append(sourceBits, "static="+strings.Join(cfg.CoinSource.StaticCoins, ",")) + } + + timeframes := append([]string(nil), cfg.Indicators.Klines.SelectedTimeframes...) + if len(timeframes) == 0 { + timeframes = cleanStringList([]string{cfg.Indicators.Klines.PrimaryTimeframe, cfg.Indicators.Klines.LongerTimeframe}) + } + + indicatorBits := make([]string, 0, 8) + if cfg.Indicators.EnableRawKlines { + indicatorBits = append(indicatorBits, "raw_klines") + } + if cfg.Indicators.EnableVolume { + indicatorBits = append(indicatorBits, "volume") + } + if cfg.Indicators.EnableOI { + indicatorBits = append(indicatorBits, "oi") + } + if cfg.Indicators.EnableFundingRate { + indicatorBits = append(indicatorBits, "funding_rate") + } + if cfg.Indicators.EnableEMA { + indicatorBits = append(indicatorBits, "ema") + } + if cfg.Indicators.EnableMACD { + indicatorBits = append(indicatorBits, "macd") + } + if cfg.Indicators.EnableRSI { + indicatorBits = append(indicatorBits, "rsi") + } + if cfg.Indicators.EnableATR { + indicatorBits = append(indicatorBits, "atr") + } + if cfg.Indicators.EnableBOLL { + indicatorBits = append(indicatorBits, "boll") + } + sort.Strings(indicatorBits) + + promptBits := make([]string, 0, 5) + if strings.TrimSpace(cfg.PromptSections.RoleDefinition) != "" { + promptBits = append(promptBits, "role_definition") + } + if strings.TrimSpace(cfg.PromptSections.TradingFrequency) != "" { + promptBits = append(promptBits, "trading_frequency") + } + if strings.TrimSpace(cfg.PromptSections.EntryStandards) != "" { + promptBits = append(promptBits, "entry_standards") + } + if strings.TrimSpace(cfg.PromptSections.DecisionProcess) != "" { + promptBits = append(promptBits, "decision_process") + } + + customPrompt := strings.TrimSpace(cfg.CustomPrompt) + customPromptPreview := customPrompt + if len([]rune(customPromptPreview)) > 120 { + runes := []rune(customPromptPreview) + customPromptPreview = string(runes[:120]) + "..." + } + + if lang == "zh" { + lines := []string{ + fmt.Sprintf("策略“%s”概览:", name), + fmt.Sprintf("- 类型:%s", defaultIfEmpty(strings.TrimSpace(cfg.StrategyType), "ai_trading")), + fmt.Sprintf("- 语言:%s", defaultIfEmpty(strings.TrimSpace(cfg.Language), "zh")), + } + if strings.TrimSpace(strategy.Description) != "" { + lines = append(lines, fmt.Sprintf("- 描述:%s", strings.TrimSpace(strategy.Description))) + } + if len(sourceBits) > 0 { + lines = append(lines, "- 标的来源:"+strings.Join(sourceBits, " | ")) + } + if len(timeframes) > 0 { + lines = append(lines, "- K线周期:"+strings.Join(timeframes, " / ")) + } + lines = append(lines, fmt.Sprintf("- 仓位风险:最多持仓 %d,BTC/ETH 最大杠杆 %d,山寨最大杠杆 %d,最低置信度 %d", + cfg.RiskControl.MaxPositions, cfg.RiskControl.BTCETHMaxLeverage, cfg.RiskControl.AltcoinMaxLeverage, cfg.RiskControl.MinConfidence)) + if len(indicatorBits) > 0 { + lines = append(lines, "- 已启用指标:"+strings.Join(indicatorBits, "、")) + } + if len(promptBits) > 0 { + lines = append(lines, "- Prompt 模块:"+strings.Join(promptBits, "、")) + } + if customPromptPreview != "" { + lines = append(lines, "- 自定义 Prompt:"+customPromptPreview) + } else { + lines = append(lines, "- 自定义 Prompt:当前为空,主要使用策略模板内置 prompt sections。") + } + lines = append(lines, "- 如果你要,我还可以继续展开这条策略的完整参数 JSON,或者逐段解释它的 prompt。") + return strings.Join(lines, "\n") + } + + lines := []string{ + fmt.Sprintf("Strategy %q overview:", name), + fmt.Sprintf("- Type: %s", defaultIfEmpty(strings.TrimSpace(cfg.StrategyType), "ai_trading")), + fmt.Sprintf("- Language: %s", defaultIfEmpty(strings.TrimSpace(cfg.Language), "en")), + } + if strings.TrimSpace(strategy.Description) != "" { + lines = append(lines, fmt.Sprintf("- Description: %s", strings.TrimSpace(strategy.Description))) + } + if len(sourceBits) > 0 { + lines = append(lines, "- Coin source: "+strings.Join(sourceBits, " | ")) + } + if len(timeframes) > 0 { + lines = append(lines, "- Timeframes: "+strings.Join(timeframes, " / ")) + } + lines = append(lines, fmt.Sprintf("- Risk: max positions %d, BTC/ETH max leverage %d, alt max leverage %d, min confidence %d", + cfg.RiskControl.MaxPositions, cfg.RiskControl.BTCETHMaxLeverage, cfg.RiskControl.AltcoinMaxLeverage, cfg.RiskControl.MinConfidence)) + if len(indicatorBits) > 0 { + lines = append(lines, "- Enabled indicators: "+strings.Join(indicatorBits, ", ")) + } + if len(promptBits) > 0 { + lines = append(lines, "- Prompt modules: "+strings.Join(promptBits, ", ")) + } + if customPromptPreview != "" { + lines = append(lines, "- Custom prompt: "+customPromptPreview) + } else { + lines = append(lines, "- Custom prompt: empty right now; it mainly uses the built-in prompt sections from the strategy template.") + } + lines = append(lines, "- I can also expand the full strategy config JSON or walk through the prompt section by section.") + return strings.Join(lines, "\n") +} + +func (a *Agent) describeDefaultStrategyConfig(lang string) string { + if lang != "zh" { + lang = "en" + } + cfg := store.GetDefaultStrategyConfig(lang) + name := "Default Strategy Template" + description := "System default strategy configuration template" + if lang == "zh" { + name = "默认策略模板" + description = "系统默认策略配置模板" + } + return formatStrategyDetailResponse(lang, &store.Strategy{ + ID: "default_strategy_template", + Name: name, + Description: description, + }, cfg) +} + +func (a *Agent) describeTrader(storeUserID, lang string, target *EntityReference) (string, bool) { + raw := a.toolListTraders(storeUserID) + var payload struct { + Traders []safeTraderToolConfig `json:"traders"` + } + if err := json.Unmarshal([]byte(raw), &payload); err != nil { + return "", false + } + trader := findTraderByReference(payload.Traders, target) + if trader == nil { + if len(payload.Traders) != 1 { + return "", false + } + trader = &payload.Traders[0] + } + if lang == "zh" { + status := "未运行" + if trader.IsRunning { + status = "运行中" + } + return fmt.Sprintf("交易员“%s”详情:\n- 状态:%s\n- 模型:%s\n- 交易所:%s\n- 策略:%s\n- 扫描间隔:%d 分钟\n- 初始余额:%.2f", + trader.Name, status, trader.AIModelID, trader.ExchangeID, defaultIfEmpty(trader.StrategyID, "未绑定"), trader.ScanIntervalMinutes, trader.InitialBalance), true + } + status := "stopped" + if trader.IsRunning { + status = "running" + } + return fmt.Sprintf("Trader %q details:\n- Status: %s\n- Model: %s\n- Exchange: %s\n- Strategy: %s\n- Scan interval: %d minutes\n- Initial balance: %.2f", + trader.Name, status, trader.AIModelID, trader.ExchangeID, defaultIfEmpty(trader.StrategyID, "none"), trader.ScanIntervalMinutes, trader.InitialBalance), true +} + +func (a *Agent) describeExchange(storeUserID, lang string, target *EntityReference) (string, bool) { + raw := a.toolGetExchangeConfigs(storeUserID) + var payload struct { + ExchangeConfigs []safeExchangeToolConfig `json:"exchange_configs"` + } + if err := json.Unmarshal([]byte(raw), &payload); err != nil { + return "", false + } + exchange := findExchangeByReference(payload.ExchangeConfigs, target) + if exchange == nil { + if len(payload.ExchangeConfigs) != 1 { + return "", false + } + exchange = &payload.ExchangeConfigs[0] + } + if lang == "zh" { + return fmt.Sprintf("交易所配置“%s”详情:\n- 交易所:%s\n- 已启用:%t\n- API Key:%t\n- Secret:%t\n- Passphrase:%t\n- Testnet:%t", + defaultIfEmpty(exchange.AccountName, exchange.ID), exchange.ExchangeType, exchange.Enabled, exchange.HasAPIKey, exchange.HasSecretKey, exchange.HasPassphrase, exchange.Testnet), true + } + return fmt.Sprintf("Exchange config %q details:\n- Exchange: %s\n- Enabled: %t\n- API key present: %t\n- Secret present: %t\n- Passphrase present: %t\n- Testnet: %t", + defaultIfEmpty(exchange.AccountName, exchange.ID), exchange.ExchangeType, exchange.Enabled, exchange.HasAPIKey, exchange.HasSecretKey, exchange.HasPassphrase, exchange.Testnet), true +} + +func (a *Agent) describeModel(storeUserID, lang string, target *EntityReference) (string, bool) { + raw := a.toolGetModelConfigs(storeUserID) + var payload struct { + ModelConfigs []safeModelToolConfig `json:"model_configs"` + } + if err := json.Unmarshal([]byte(raw), &payload); err != nil { + return "", false + } + model := findModelByReference(payload.ModelConfigs, target) + if model == nil { + if len(payload.ModelConfigs) != 1 { + return "", false + } + model = &payload.ModelConfigs[0] + } + if lang == "zh" { + return fmt.Sprintf("模型配置“%s”详情:\n- Provider:%s\n- 已启用:%t\n- API Key:%t\n- URL:%s\n- Model Name:%s", + defaultIfEmpty(model.Name, model.ID), model.Provider, model.Enabled, model.HasAPIKey, defaultIfEmpty(model.CustomAPIURL, "未设置"), defaultIfEmpty(model.CustomModelName, "未设置")), true + } + return fmt.Sprintf("Model config %q details:\n- Provider: %s\n- Enabled: %t\n- API key present: %t\n- URL: %s\n- Model name: %s", + defaultIfEmpty(model.Name, model.ID), model.Provider, model.Enabled, model.HasAPIKey, defaultIfEmpty(model.CustomAPIURL, "not set"), defaultIfEmpty(model.CustomModelName, "not set")), true +} + +func findTraderByReference(items []safeTraderToolConfig, target *EntityReference) *safeTraderToolConfig { + if target == nil { + return nil + } + for i := range items { + if strings.TrimSpace(target.ID) != "" && items[i].ID == strings.TrimSpace(target.ID) { + return &items[i] + } + if strings.TrimSpace(target.Name) != "" && strings.EqualFold(strings.TrimSpace(items[i].Name), strings.TrimSpace(target.Name)) { + return &items[i] + } + } + return nil +} + +func findExchangeByReference(items []safeExchangeToolConfig, target *EntityReference) *safeExchangeToolConfig { + if target == nil { + return nil + } + for i := range items { + name := defaultIfEmpty(items[i].AccountName, items[i].Name) + if strings.TrimSpace(target.ID) != "" && items[i].ID == strings.TrimSpace(target.ID) { + return &items[i] + } + if strings.TrimSpace(target.Name) != "" && strings.EqualFold(strings.TrimSpace(name), strings.TrimSpace(target.Name)) { + return &items[i] + } + } + return nil +} + +func findModelByReference(items []safeModelToolConfig, target *EntityReference) *safeModelToolConfig { + if target == nil { + return nil + } + for i := range items { + if strings.TrimSpace(target.ID) != "" && items[i].ID == strings.TrimSpace(target.ID) { + return &items[i] + } + if strings.TrimSpace(target.Name) != "" && strings.EqualFold(strings.TrimSpace(items[i].Name), strings.TrimSpace(target.Name)) { + return &items[i] + } + } + return nil +} + func (a *Agent) loadTraderOptions(storeUserID string) []traderSkillOption { if a.store == nil { return nil @@ -252,6 +676,9 @@ func (a *Agent) handleExchangeCreateSkill(storeUserID string, userID int64, lang if session.Name == "" { session = skillSession{Name: "exchange_management", Action: "create", Phase: "collecting"} } + if fieldValue(session, skillDAGStepField) == "" { + setSkillDAGStep(&session, "resolve_exchange_type") + } if isCancelSkillReply(text) { a.clearSkillSession(userID) if lang == "zh" { @@ -267,6 +694,7 @@ func (a *Agent) handleExchangeCreateSkill(storeUserID string, userID int64, lang } exType := fieldValue(session, "exchange_type") if actionRequiresSlot("exchange_management", "create", "exchange_type") && exType == "" { + setSkillDAGStep(&session, "resolve_exchange_type") a.saveSkillSession(userID, session) if lang == "zh" { return "要创建交易所配置,我还需要:" + slotDisplayName("exchange_type", lang) + "。例如:OKX、Binance、Bybit。" @@ -277,6 +705,7 @@ func (a *Agent) handleExchangeCreateSkill(storeUserID string, userID int64, lang if accountName == "" { accountName = "Default" } + setSkillDAGStep(&session, "execute_create") args := map[string]any{ "action": "create", "exchange_type": exType, @@ -302,6 +731,9 @@ func (a *Agent) handleModelCreateSkill(storeUserID string, userID int64, lang, t if session.Name == "" { session = skillSession{Name: "model_management", Action: "create", Phase: "collecting"} } + if fieldValue(session, skillDAGStepField) == "" { + setSkillDAGStep(&session, "resolve_provider") + } if isCancelSkillReply(text) { a.clearSkillSession(userID) if lang == "zh" { @@ -320,17 +752,19 @@ func (a *Agent) handleModelCreateSkill(storeUserID string, userID int64, lang, t } provider := fieldValue(session, "provider") if actionRequiresSlot("model_management", "create", "provider") && provider == "" { + setSkillDAGStep(&session, "resolve_provider") a.saveSkillSession(userID, session) if lang == "zh" { return "要创建模型配置,我还需要:" + slotDisplayName("provider", lang) + ",例如:OpenAI、DeepSeek、Claude、Gemini。" } return "To create a model config, I need the provider first, for example OpenAI, DeepSeek, Claude, or Gemini." } + setSkillDAGStep(&session, "execute_create") args := map[string]any{ - "action": "create", - "provider": provider, - "name": defaultIfEmpty(fieldValue(session, "name"), provider), - "custom_api_url": fieldValue(session, "custom_api_url"), + "action": "create", + "provider": provider, + "name": defaultIfEmpty(fieldValue(session, "name"), provider), + "custom_api_url": fieldValue(session, "custom_api_url"), "custom_model_name": fieldValue(session, "custom_model_name"), } raw, _ := json.Marshal(args) @@ -353,6 +787,9 @@ func (a *Agent) handleStrategyCreateSkill(storeUserID string, userID int64, lang if session.Name == "" { session = skillSession{Name: "strategy_management", Action: "create", Phase: "collecting"} } + if fieldValue(session, skillDAGStepField) == "" { + setSkillDAGStep(&session, "resolve_name") + } if isCancelSkillReply(text) { a.clearSkillSession(userID) if lang == "zh" { @@ -371,12 +808,14 @@ func (a *Agent) handleStrategyCreateSkill(storeUserID string, userID int64, lang } } if actionRequiresSlot("strategy_management", "create", "name") && name == "" { + setSkillDAGStep(&session, "resolve_name") a.saveSkillSession(userID, session) if lang == "zh" { return "要创建策略,我还需要:" + slotDisplayName("name", lang) + "。你可以直接说:创建一个叫“趋势策略A”的策略。" } return "To create a strategy, I need a strategy name. You can say: create a strategy called 'Trend A'." } + setSkillDAGStep(&session, "execute_create") args := map[string]any{"action": "create", "name": name, "lang": "zh"} raw, _ := json.Marshal(args) resp := a.toolManageStrategy(storeUserID, string(raw)) @@ -408,22 +847,65 @@ func (a *Agent) handleSimpleEntitySkill(storeUserID string, userID int64, lang, if session.Name != skillName || session.Action != action { return "", false } - session.TargetRef = resolveTargetFromText(text, options, session.TargetRef) - if session.TargetRef == nil && action != "query" { - a.saveSkillSession(userID, session) - label := formatOptionList("可选对象:", options) - if lang == "zh" { - reply := "我还需要你明确要操作的是哪一个对象。" + + if dag, ok := getSkillDAG(skillName, action); ok && len(dag.Steps) > 0 { + currentStep, _ := currentSkillDAGStep(session) + if currentStep.ID == "resolve_target" { + if supportsBulkTargetSelection(skillName, action) && textMeansAllTargets(text) { + setField(&session, "bulk_scope", "all") + advanceSkillDAGStep(&session, currentStep.ID) + } else { + session.TargetRef = resolveTargetFromText(text, options, session.TargetRef) + } + if session.TargetRef == nil { + if !(supportsBulkTargetSelection(skillName, action) && fieldValue(session, "bulk_scope") == "all") { + setSkillDAGStep(&session, "resolve_target") + a.saveSkillSession(userID, session) + label := "可选对象:" + if lang != "zh" { + label = "Available targets:" + } + optionList := formatOptionList(label, options) + if lang == "zh" { + reply := "当前这一步需要先确定目标对象。请告诉我你要操作哪一个。" + if optionList != "" { + reply += "\n" + optionList + } + return reply, true + } + reply := "This step needs a target object first. Tell me which one to operate on." + if optionList != "" { + reply += "\n" + optionList + } + return reply, true + } + } + if fieldValue(session, skillDAGStepField) == currentStep.ID { + advanceSkillDAGStep(&session, currentStep.ID) + } + } + } else { + if supportsBulkTargetSelection(skillName, action) && textMeansAllTargets(text) { + setField(&session, "bulk_scope", "all") + } else { + session.TargetRef = resolveTargetFromText(text, options, session.TargetRef) + } + if session.TargetRef == nil && fieldValue(session, "bulk_scope") != "all" && action != "query" && action != "query_list" && action != "query_detail" && action != "query_running" { + a.saveSkillSession(userID, session) + label := formatOptionList("可选对象:", options) + if lang == "zh" { + reply := "我还需要你明确要操作的是哪一个对象。" + if label != "" { + reply += "\n" + label + } + return reply, true + } + reply := "I still need you to specify which object to operate on." if label != "" { reply += "\n" + label } return reply, true } - reply := "I still need you to specify which object to operate on." - if label != "" { - reply += "\n" + label - } - return reply, true } switch skillName { diff --git a/agent/skill_outcome.go b/agent/skill_outcome.go new file mode 100644 index 00000000..1075a434 --- /dev/null +++ b/agent/skill_outcome.go @@ -0,0 +1,180 @@ +package agent + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "nofx/mcp" +) + +const ( + skillOutcomeSuccess = "success" + skillOutcomeNeedMoreInfo = "need_more_info" + skillOutcomeRecoverableError = "recoverable_error" + skillOutcomeFatalError = "fatal_error" + skillOutcomeNotHandled = "not_handled" +) + +type skillOutcome struct { + Skill string `json:"skill"` + Action string `json:"action"` + Status string `json:"status"` + GoalAchieved bool `json:"goal_achieved"` + UserMessage string `json:"user_message,omitempty"` + ErrorCode string `json:"error_code,omitempty"` + Error string `json:"error,omitempty"` + Data map[string]any `json:"data,omitempty"` +} + +type taskReviewDecision struct { + Route string `json:"route"` + Answer string `json:"answer,omitempty"` +} + +func normalizeAtomicSkillAction(skill, action string) string { + action = strings.TrimSpace(strings.ToLower(action)) + switch skill { + case "trader_management": + switch action { + case "query", "query_list": + return "query_list" + case "query_running": + return "query_running" + case "query_detail": + return "query_detail" + case "update": + return "update_name" + case "update_name", "update_bindings": + return action + } + case "exchange_management": + switch action { + case "query", "query_list": + return "query_list" + case "query_detail": + return "query_detail" + case "update": + return "update_name" + case "update_name", "update_status": + return action + } + case "model_management": + switch action { + case "query", "query_list": + return "query_list" + case "query_detail": + return "query_detail" + case "update": + return "update_name" + case "update_name", "update_endpoint", "update_status": + return action + } + case "strategy_management": + switch action { + case "query", "query_list": + return "query_list" + case "query_detail": + return "query_detail" + case "update": + return "update_name" + case "update_name", "update_config", "update_prompt": + return action + } + } + return action +} + +func inferSkillOutcome(skill, action, answer string, activeSession skillSession, data map[string]any) skillOutcome { + outcome := skillOutcome{ + Skill: skill, + Action: action, + Status: skillOutcomeSuccess, + UserMessage: strings.TrimSpace(answer), + Data: data, + } + if activeSession.Name != "" { + outcome.Status = skillOutcomeNeedMoreInfo + outcome.GoalAchieved = false + return outcome + } + + lower := strings.ToLower(strings.TrimSpace(answer)) + switch { + case lower == "": + outcome.Status = skillOutcomeNotHandled + case strings.Contains(lower, "失败") || strings.Contains(lower, "failed") || strings.Contains(lower, "error"): + outcome.Status = skillOutcomeRecoverableError + outcome.Error = strings.TrimSpace(answer) + default: + outcome.GoalAchieved = true + } + return outcome +} + +func parseTaskReviewDecision(raw string) (taskReviewDecision, error) { + raw = strings.TrimSpace(raw) + raw = strings.TrimPrefix(raw, "```json") + raw = strings.TrimPrefix(raw, "```") + raw = strings.TrimSuffix(raw, "```") + raw = strings.TrimSpace(raw) + + var decision taskReviewDecision + if err := json.Unmarshal([]byte(raw), &decision); err == nil { + decision.Route = strings.TrimSpace(strings.ToLower(decision.Route)) + decision.Answer = strings.TrimSpace(decision.Answer) + return decision, nil + } + start := strings.Index(raw, "{") + end := strings.LastIndex(raw, "}") + if start >= 0 && end > start { + if err := json.Unmarshal([]byte(raw[start:end+1]), &decision); err == nil { + decision.Route = strings.TrimSpace(strings.ToLower(decision.Route)) + decision.Answer = strings.TrimSpace(decision.Answer) + return decision, nil + } + } + return taskReviewDecision{}, fmt.Errorf("invalid task review json") +} + +func (a *Agent) reviewTaskCompletion(ctx context.Context, userID int64, lang, text string, outcome skillOutcome) (taskReviewDecision, error) { + if a.aiClient == nil { + if outcome.Status == skillOutcomeRecoverableError || outcome.Status == skillOutcomeFatalError || outcome.Status == skillOutcomeNotHandled { + return taskReviewDecision{Route: "replan"}, nil + } + return taskReviewDecision{Route: "complete", Answer: outcome.UserMessage}, nil + } + + recentConversationCtx := a.buildRecentConversationContext(userID, text) + outcomeJSON, _ := json.Marshal(outcome) + systemPrompt := `You are the task-level Plan-Execute-Review supervisor for NOFXi. +You are reviewing the JSON result returned by one structured skill execution. +Return JSON only. Do not return markdown. + +Rules: +- Decide whether the OVERALL user task is finished, not whether the skill itself ran successfully. +- Use route "complete" only when the user's task is now complete or the best next message is a final user-facing reply. +- Use route "replan" when the user's task is not complete yet and the planner should continue from the new skill outcome. +- Prefer route "replan" for recoverable errors, unmet goals, missing prerequisites, or cases where another skill/tool sequence may help. +- If you choose "complete", produce the final user-facing answer in the user's language. + +Return JSON with this exact shape: +{"route":"complete|replan","answer":""}` + userPrompt := fmt.Sprintf("Language: %s\nUser message: %s\n\nRecent conversation:\n%s\n\nSkill outcome JSON:\n%s", lang, text, recentConversationCtx, string(outcomeJSON)) + + stageCtx, cancel := withPlannerStageTimeout(ctx, directReplyTimeout) + defer cancel() + + raw, err := a.aiClient.CallWithRequest(&mcp.Request{ + Messages: []mcp.Message{ + mcp.NewSystemMessage(systemPrompt), + mcp.NewUserMessage(userPrompt), + }, + Ctx: stageCtx, + }) + if err != nil { + return taskReviewDecision{}, err + } + return parseTaskReviewDecision(raw) +} diff --git a/agent/tools.go b/agent/tools.go index ccddc39e..be7e1f24 100644 --- a/agent/tools.go +++ b/agent/tools.go @@ -1,13 +1,17 @@ package agent import ( + "bufio" "context" "encoding/json" "fmt" + "os" + "path/filepath" "sort" "strings" "time" + "nofx/kernel" "nofx/mcp" "nofx/safe" "nofx/security" @@ -56,6 +60,24 @@ func buildAgentTools() []mcp.Tool { }, }, }, + { + Type: "function", + Function: mcp.FunctionDef{ + Name: "get_backend_logs", + Description: "Get recent backend log lines for a trader diagnosis. Prefer this when the user asks why a specific trader failed, stopped, or behaved unexpectedly. Returns recent matching log lines for the authenticated user's trader.", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "trader_id": map[string]any{ + "type": "string", + "description": "Trader id to diagnose. The backend verifies that this trader belongs to the authenticated user before returning logs.", + }, + "limit": map[string]any{"type": "number", "description": "Maximum number of recent log lines to return. Default 30."}, + "errors_only": map[string]any{"type": "boolean", "description": "When true, only return error-like log lines. Default true."}, + }, + }, + }, + }, { Type: "function", Function: mcp.FunctionDef{ @@ -92,19 +114,19 @@ func buildAgentTools() []mcp.Tool { "type": "boolean", "description": "Whether this exchange binding should be enabled.", }, - "api_key": map[string]any{"type": "string"}, - "secret_key": map[string]any{"type": "string"}, - "passphrase": map[string]any{"type": "string"}, - "testnet": map[string]any{"type": "boolean"}, - "hyperliquid_wallet_addr": map[string]any{"type": "string"}, + "api_key": map[string]any{"type": "string"}, + "secret_key": map[string]any{"type": "string"}, + "passphrase": map[string]any{"type": "string"}, + "testnet": map[string]any{"type": "boolean"}, + "hyperliquid_wallet_addr": map[string]any{"type": "string"}, "hyperliquid_unified_account": map[string]any{"type": "boolean"}, - "aster_user": map[string]any{"type": "string"}, - "aster_signer": map[string]any{"type": "string"}, - "aster_private_key": map[string]any{"type": "string"}, - "lighter_wallet_addr": map[string]any{"type": "string"}, - "lighter_private_key": map[string]any{"type": "string"}, + "aster_user": map[string]any{"type": "string"}, + "aster_signer": map[string]any{"type": "string"}, + "aster_private_key": map[string]any{"type": "string"}, + "lighter_wallet_addr": map[string]any{"type": "string"}, + "lighter_private_key": map[string]any{"type": "string"}, "lighter_api_key_private_key": map[string]any{"type": "string"}, - "lighter_api_key_index": map[string]any{"type": "number"}, + "lighter_api_key_index": map[string]any{"type": "number"}, }, "required": []string{"action"}, }, @@ -171,10 +193,10 @@ func buildAgentTools() []mcp.Tool { "type": "string", "enum": []string{"list", "create", "update", "delete", "activate", "duplicate", "get_default_config"}, }, - "strategy_id": map[string]any{"type": "string"}, - "name": map[string]any{"type": "string"}, - "description": map[string]any{"type": "string"}, - "lang": map[string]any{"type": "string", "enum": []string{"zh", "en"}}, + "strategy_id": map[string]any{"type": "string"}, + "name": map[string]any{"type": "string"}, + "description": map[string]any{"type": "string"}, + "lang": map[string]any{"type": "string", "enum": []string{"zh", "en"}}, "is_public": map[string]any{"type": "boolean"}, "config_visible": map[string]any{"type": "boolean"}, "config": map[string]any{"type": "object", "description": "Full or partial strategy config JSON object, depending on action."}, @@ -199,22 +221,22 @@ func buildAgentTools() []mcp.Tool { "type": "string", "description": "Required for update, delete, start, and stop.", }, - "name": map[string]any{"type": "string"}, - "ai_model_id": map[string]any{"type": "string"}, - "exchange_id": map[string]any{"type": "string"}, - "strategy_id": map[string]any{"type": "string"}, - "initial_balance": map[string]any{"type": "number"}, - "scan_interval_minutes": map[string]any{"type": "number"}, - "is_cross_margin": map[string]any{"type": "boolean"}, - "show_in_competition": map[string]any{"type": "boolean"}, - "btc_eth_leverage": map[string]any{"type": "number"}, - "altcoin_leverage": map[string]any{"type": "number"}, - "trading_symbols": map[string]any{"type": "string"}, - "custom_prompt": map[string]any{"type": "string"}, - "override_base_prompt": map[string]any{"type": "boolean"}, + "name": map[string]any{"type": "string"}, + "ai_model_id": map[string]any{"type": "string"}, + "exchange_id": map[string]any{"type": "string"}, + "strategy_id": map[string]any{"type": "string"}, + "initial_balance": map[string]any{"type": "number"}, + "scan_interval_minutes": map[string]any{"type": "number"}, + "is_cross_margin": map[string]any{"type": "boolean"}, + "show_in_competition": map[string]any{"type": "boolean"}, + "btc_eth_leverage": map[string]any{"type": "number"}, + "altcoin_leverage": map[string]any{"type": "number"}, + "trading_symbols": map[string]any{"type": "string"}, + "custom_prompt": map[string]any{"type": "string"}, + "override_base_prompt": map[string]any{"type": "boolean"}, "system_prompt_template": map[string]any{"type": "string"}, - "use_ai500": map[string]any{"type": "boolean"}, - "use_oi_top": map[string]any{"type": "boolean"}, + "use_ai500": map[string]any{"type": "boolean"}, + "use_oi_top": map[string]any{"type": "boolean"}, }, "required": []string{"action"}, }, @@ -316,6 +338,26 @@ func buildAgentTools() []mcp.Tool { }, }, }, + { + Type: "function", + Function: mcp.FunctionDef{ + Name: "get_candidate_coins", + Description: "Get the current candidate coin list for a trader or strategy, including AI500 coin-source settings and the selected symbols.", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "trader_id": map[string]any{ + "type": "string", + "description": "Optional trader id. Prefer this when asking about a running trader.", + }, + "strategy_id": map[string]any{ + "type": "string", + "description": "Optional strategy id. Use this when asking about a strategy template directly.", + }, + }, + }, + }, + }, } } @@ -326,6 +368,8 @@ func (a *Agent) handleToolCall(ctx context.Context, storeUserID string, userID i return a.toolGetPreferences(userID) case "manage_preferences": return a.toolManagePreferences(userID, tc.Function.Arguments) + case "get_backend_logs": + return a.toolGetBackendLogs(storeUserID, tc.Function.Arguments) case "get_exchange_configs": return a.toolGetExchangeConfigs(storeUserID) case "manage_exchange_config": @@ -352,6 +396,8 @@ func (a *Agent) handleToolCall(ctx context.Context, storeUserID string, userID i return a.toolGetMarketPrice(tc.Function.Arguments) case "get_trade_history": return a.toolGetTradeHistory(tc.Function.Arguments) + case "get_candidate_coins": + return a.toolGetCandidateCoins(storeUserID, userID, tc.Function.Arguments) default: return fmt.Sprintf(`{"error": "unknown tool: %s"}`, tc.Function.Name) } @@ -388,21 +434,21 @@ type safeModelToolConfig struct { } type safeTraderToolConfig struct { - ID string `json:"id"` - Name string `json:"name"` - AIModelID string `json:"ai_model_id"` - ExchangeID string `json:"exchange_id"` - StrategyID string `json:"strategy_id,omitempty"` - InitialBalance float64 `json:"initial_balance"` - ScanIntervalMinutes int `json:"scan_interval_minutes"` - IsRunning bool `json:"is_running"` - IsCrossMargin bool `json:"is_cross_margin"` - ShowInCompetition bool `json:"show_in_competition"` - BTCETHLeverage int `json:"btc_eth_leverage,omitempty"` - AltcoinLeverage int `json:"altcoin_leverage,omitempty"` - TradingSymbols string `json:"trading_symbols,omitempty"` - CustomPrompt string `json:"custom_prompt,omitempty"` - SystemPromptTemplate string `json:"system_prompt_template,omitempty"` + ID string `json:"id"` + Name string `json:"name"` + AIModelID string `json:"ai_model_id"` + ExchangeID string `json:"exchange_id"` + StrategyID string `json:"strategy_id,omitempty"` + InitialBalance float64 `json:"initial_balance"` + ScanIntervalMinutes int `json:"scan_interval_minutes"` + IsRunning bool `json:"is_running"` + IsCrossMargin bool `json:"is_cross_margin"` + ShowInCompetition bool `json:"show_in_competition"` + BTCETHLeverage int `json:"btc_eth_leverage,omitempty"` + AltcoinLeverage int `json:"altcoin_leverage,omitempty"` + TradingSymbols string `json:"trading_symbols,omitempty"` + CustomPrompt string `json:"custom_prompt,omitempty"` + SystemPromptTemplate string `json:"system_prompt_template,omitempty"` } type safeStrategyToolConfig struct { @@ -472,6 +518,14 @@ func safeModelForTool(model *store.AIModel) safeModelToolConfig { } } +func modelConfigUsable(provider, modelID, apiKey, customAPIURL, customModelName string) bool { + if strings.TrimSpace(apiKey) == "" { + return false + } + resolvedURL, resolvedModel := resolveModelRuntimeConfig(provider, customAPIURL, customModelName, modelID) + return strings.TrimSpace(resolvedURL) != "" && strings.TrimSpace(resolvedModel) != "" +} + func safeTraderForTool(trader *store.Trader, isRunning bool) safeTraderToolConfig { return safeTraderToolConfig{ ID: trader.ID, @@ -531,29 +585,131 @@ func (a *Agent) toolGetExchangeConfigs(storeUserID string) string { return string(result) } +func latestBackendLogFilePath() string { + matches, err := filepath.Glob(filepath.Join("data", "nofx_*.log")) + if err != nil || len(matches) == 0 { + return "" + } + sort.Strings(matches) + return matches[len(matches)-1] +} + +func isBackendErrorLikeLogLine(line string) bool { + lower := strings.ToLower(strings.TrimSpace(line)) + if lower == "" { + return false + } + return strings.Contains(lower, "[erro]") || + strings.Contains(lower, " panic") || + strings.Contains(lower, "🔥") || + strings.Contains(lower, "❌") || + strings.Contains(lower, " failed") || + strings.Contains(lower, " error") || + strings.Contains(lower, "invalid ") +} + +func readBackendLogEntries(limit int, contains string, errorsOnly bool) (string, []string, error) { + path := latestBackendLogFilePath() + if path == "" { + return "", nil, fmt.Errorf("backend log file not found") + } + file, err := os.Open(path) + if err != nil { + return path, nil, err + } + defer file.Close() + + filter := strings.ToLower(strings.TrimSpace(contains)) + matches := make([]string, 0, max(limit, 1)) + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + if errorsOnly && !isBackendErrorLikeLogLine(line) { + continue + } + if filter != "" && !strings.Contains(strings.ToLower(line), filter) { + continue + } + matches = append(matches, line) + } + if err := scanner.Err(); err != nil { + return path, nil, err + } + if limit <= 0 { + limit = 30 + } + if len(matches) > limit { + matches = matches[len(matches)-limit:] + } + return path, matches, nil +} + +func (a *Agent) toolGetBackendLogs(storeUserID, argsJSON string) string { + var args struct { + TraderID string `json:"trader_id"` + Limit int `json:"limit"` + ErrorsOnly *bool `json:"errors_only"` + } + if strings.TrimSpace(argsJSON) != "" { + if err := json.Unmarshal([]byte(argsJSON), &args); err != nil { + return fmt.Sprintf(`{"error":"invalid arguments: %s"}`, err) + } + } + errorsOnly := true + if args.ErrorsOnly != nil { + errorsOnly = *args.ErrorsOnly + } + traderID := strings.TrimSpace(args.TraderID) + if traderID == "" { + return `{"error":"trader_id is required"}` + } + if a.store == nil { + return `{"error":"store unavailable"}` + } + trader, err := a.store.Trader().GetByID(traderID) + if err != nil { + return fmt.Sprintf(`{"error":"failed to load trader: %s"}`, err) + } + if trader.UserID != storeUserID { + return `{"error":"trader not found for current user"}` + } + path, entries, err := readBackendLogEntries(args.Limit, traderID, errorsOnly) + if err != nil { + return fmt.Sprintf(`{"error":"failed to read backend logs: %s"}`, err) + } + result, _ := json.Marshal(map[string]any{ + "trader_id": traderID, + "log_file": path, + "entries": entries, + "count": len(entries), + "errors_only": errorsOnly, + }) + return string(result) +} + func (a *Agent) toolManageExchangeConfig(storeUserID, argsJSON string) string { if a.store == nil { return `{"error":"store unavailable"}` } var args struct { - Action string `json:"action"` - ExchangeID string `json:"exchange_id"` - ExchangeType string `json:"exchange_type"` - AccountName string `json:"account_name"` - Enabled *bool `json:"enabled"` - APIKey string `json:"api_key"` - SecretKey string `json:"secret_key"` - Passphrase string `json:"passphrase"` - Testnet *bool `json:"testnet"` - HyperliquidWalletAddr string `json:"hyperliquid_wallet_addr"` - HyperliquidUnifiedAccount *bool `json:"hyperliquid_unified_account"` - AsterUser string `json:"aster_user"` - AsterSigner string `json:"aster_signer"` - AsterPrivateKey string `json:"aster_private_key"` - LighterWalletAddr string `json:"lighter_wallet_addr"` - LighterPrivateKey string `json:"lighter_private_key"` - LighterAPIKeyPrivateKey string `json:"lighter_api_key_private_key"` - LighterAPIKeyIndex *int `json:"lighter_api_key_index"` + Action string `json:"action"` + ExchangeID string `json:"exchange_id"` + ExchangeType string `json:"exchange_type"` + AccountName string `json:"account_name"` + Enabled *bool `json:"enabled"` + APIKey string `json:"api_key"` + SecretKey string `json:"secret_key"` + Passphrase string `json:"passphrase"` + Testnet *bool `json:"testnet"` + HyperliquidWalletAddr string `json:"hyperliquid_wallet_addr"` + HyperliquidUnifiedAccount *bool `json:"hyperliquid_unified_account"` + AsterUser string `json:"aster_user"` + AsterSigner string `json:"aster_signer"` + AsterPrivateKey string `json:"aster_private_key"` + LighterWalletAddr string `json:"lighter_wallet_addr"` + LighterPrivateKey string `json:"lighter_private_key"` + LighterAPIKeyPrivateKey string `json:"lighter_api_key_private_key"` + LighterAPIKeyIndex *int `json:"lighter_api_key_index"` } if err := json.Unmarshal([]byte(argsJSON), &args); err != nil { return fmt.Sprintf(`{"error":"invalid arguments: %s"}`, err) @@ -802,7 +958,15 @@ func (a *Agent) toolManageModelConfig(storeUserID, argsJSON string) string { if strings.TrimSpace(args.CustomModelName) != "" { customModelName = strings.TrimSpace(args.CustomModelName) } - if err := a.store.AIModel().Update(storeUserID, existing.ID, enabled, strings.TrimSpace(args.APIKey), customAPIURL, customModelName); err != nil { + apiKey := strings.TrimSpace(args.APIKey) + effectiveAPIKey := string(existing.APIKey) + if apiKey != "" { + effectiveAPIKey = apiKey + } + if enabled && !modelConfigUsable(existing.Provider, existing.ID, effectiveAPIKey, customAPIURL, customModelName) { + return `{"error":"cannot enable model config before API key is configured"}` + } + if err := a.store.AIModel().Update(storeUserID, existing.ID, enabled, apiKey, customAPIURL, customModelName); err != nil { return fmt.Sprintf(`{"error":"failed to update model config: %s"}`, err) } updated, err := a.store.AIModel().Get(storeUserID, existing.ID) @@ -1893,6 +2057,136 @@ func (a *Agent) toolGetTradeHistory(argsJSON string) string { return string(result) } +func (a *Agent) toolGetCandidateCoins(storeUserID string, userID int64, argsJSON string) string { + if a.store == nil { + return `{"error":"store unavailable"}` + } + + var args struct { + TraderID string `json:"trader_id"` + StrategyID string `json:"strategy_id"` + } + if strings.TrimSpace(argsJSON) != "" { + if err := json.Unmarshal([]byte(argsJSON), &args); err != nil { + return fmt.Sprintf(`{"error":"invalid arguments: %s"}`, err) + } + } + + traderID := strings.TrimSpace(args.TraderID) + strategyID := strings.TrimSpace(args.StrategyID) + state := a.getExecutionState(userID) + if traderID == "" && state.CurrentReferences != nil && state.CurrentReferences.Trader != nil { + traderID = strings.TrimSpace(state.CurrentReferences.Trader.ID) + } + if strategyID == "" && state.CurrentReferences != nil && state.CurrentReferences.Strategy != nil { + strategyID = strings.TrimSpace(state.CurrentReferences.Strategy.ID) + } + + if traderID != "" { + return a.toolGetCandidateCoinsForTrader(storeUserID, traderID) + } + if strategyID != "" { + return a.toolGetCandidateCoinsForStrategy(storeUserID, strategyID) + } + return `{"error":"trader_id or strategy_id is required"}` +} + +func (a *Agent) toolGetCandidateCoinsForTrader(storeUserID, traderID string) string { + if a.traderManager == nil { + return `{"error":"no trader manager configured"}` + } + record, err := a.store.Trader().GetFullConfig(storeUserID, traderID) + if err != nil { + return fmt.Sprintf(`{"error":"failed to load trader: %s"}`, err) + } + memTrader, err := a.traderManager.GetTrader(traderID) + if err != nil { + return fmt.Sprintf(`{"error":"trader is not loaded in memory: %s"}`, err) + } + + coins, coinErr := memTrader.GetCandidateCoins() + cfg := memTrader.GetStrategyConfig() + status := memTrader.GetStatus() + isRunning, _ := status["is_running"].(bool) + payload := map[string]any{ + "trader": safeTraderForTool(record.Trader, isRunning), + "coin_source": candidateCoinSourceSummary(cfg), + "candidate_count": len(coins), + "candidate_symbols": candidateCoinSymbols(coins), + "candidates": candidateCoinDetails(coins), + } + if coinErr != nil { + payload["error"] = coinErr.Error() + } + result, _ := json.Marshal(payload) + return string(result) +} + +func (a *Agent) toolGetCandidateCoinsForStrategy(storeUserID, strategyID string) string { + record, err := a.store.Strategy().Get(storeUserID, strategyID) + if err != nil { + return fmt.Sprintf(`{"error":"failed to load strategy: %s"}`, err) + } + cfg, err := record.ParseConfig() + if err != nil { + return fmt.Sprintf(`{"error":"failed to parse strategy config: %s"}`, err) + } + + engine := kernel.NewStrategyEngine(cfg) + coins, coinErr := engine.GetCandidateCoins() + payload := map[string]any{ + "strategy": safeStrategyForTool(record), + "coin_source": candidateCoinSourceSummary(cfg), + "candidate_count": len(coins), + "candidate_symbols": candidateCoinSymbols(coins), + "candidates": candidateCoinDetails(coins), + } + if coinErr != nil { + payload["error"] = coinErr.Error() + } + result, _ := json.Marshal(payload) + return string(result) +} + +func candidateCoinSourceSummary(cfg *store.StrategyConfig) map[string]any { + if cfg == nil { + return nil + } + return map[string]any{ + "source_type": cfg.CoinSource.SourceType, + "use_ai500": cfg.CoinSource.UseAI500, + "ai500_limit": cfg.CoinSource.AI500Limit, + "use_oi_top": cfg.CoinSource.UseOITop, + "oi_top_limit": cfg.CoinSource.OITopLimit, + "use_oi_low": cfg.CoinSource.UseOILow, + "oi_low_limit": cfg.CoinSource.OILowLimit, + "use_hyper_all": cfg.CoinSource.UseHyperAll, + "use_hyper_main": cfg.CoinSource.UseHyperMain, + "hyper_main_limit": cfg.CoinSource.HyperMainLimit, + "static_coins": cfg.CoinSource.StaticCoins, + "excluded_coins": cfg.CoinSource.ExcludedCoins, + } +} + +func candidateCoinSymbols(coins []kernel.CandidateCoin) []string { + out := make([]string, 0, len(coins)) + for _, coin := range coins { + out = append(out, coin.Symbol) + } + return out +} + +func candidateCoinDetails(coins []kernel.CandidateCoin) []map[string]any { + out := make([]map[string]any, 0, len(coins)) + for _, coin := range coins { + out = append(out, map[string]any{ + "symbol": coin.Symbol, + "sources": coin.Sources, + }) + } + return out +} + // knownCryptoSymbols is a set of well-known cryptocurrency base symbols. // Without this, isStockSymbol("BTC") would incorrectly return true because // "BTC" is 3 uppercase letters and the suffix check only catches "BTCUSDT"-style pairs. diff --git a/agent/workflow.go b/agent/workflow.go new file mode 100644 index 00000000..fa704c3f --- /dev/null +++ b/agent/workflow.go @@ -0,0 +1,521 @@ +package agent + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + "nofx/mcp" +) + +const ( + workflowTaskPending = "pending" + workflowTaskRunning = "running" + workflowTaskCompleted = "completed" + workflowTaskFailed = "failed" +) + +type WorkflowTask struct { + ID string `json:"id,omitempty"` + Skill string `json:"skill,omitempty"` + Action string `json:"action,omitempty"` + Request string `json:"request,omitempty"` + DependsOn []string `json:"depends_on,omitempty"` + Status string `json:"status,omitempty"` + Error string `json:"error,omitempty"` +} + +type WorkflowSession struct { + UserID int64 `json:"user_id"` + OriginalRequest string `json:"original_request,omitempty"` + Tasks []WorkflowTask `json:"tasks,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` +} + +type workflowDecomposition struct { + Tasks []WorkflowTask `json:"tasks"` +} + +func workflowSessionConfigKey(userID int64) string { + return fmt.Sprintf("agent_workflow_session_%d", userID) +} + +func normalizeWorkflowSession(session WorkflowSession) WorkflowSession { + session.OriginalRequest = strings.TrimSpace(session.OriginalRequest) + normalized := make([]WorkflowTask, 0, len(session.Tasks)) + for i, task := range session.Tasks { + task.ID = strings.TrimSpace(task.ID) + if task.ID == "" { + task.ID = fmt.Sprintf("task_%d", i+1) + } + task.Skill = strings.TrimSpace(task.Skill) + task.Action = normalizeAtomicSkillAction(task.Skill, task.Action) + task.Request = strings.TrimSpace(task.Request) + task.DependsOn = cleanStringList(task.DependsOn) + task.Status = strings.TrimSpace(task.Status) + if task.Status == "" { + task.Status = workflowTaskPending + } + task.Error = strings.TrimSpace(task.Error) + if task.Skill == "" || task.Action == "" || task.Request == "" { + continue + } + normalized = append(normalized, task) + } + session.Tasks = normalized + if len(session.Tasks) == 0 { + return WorkflowSession{} + } + if session.UpdatedAt == "" { + session.UpdatedAt = time.Now().UTC().Format(time.RFC3339) + } + return session +} + +func (a *Agent) getWorkflowSession(userID int64) WorkflowSession { + if a.store == nil { + return WorkflowSession{} + } + raw, err := a.store.GetSystemConfig(workflowSessionConfigKey(userID)) + if err != nil || strings.TrimSpace(raw) == "" { + return WorkflowSession{} + } + var session WorkflowSession + if err := json.Unmarshal([]byte(raw), &session); err != nil { + return WorkflowSession{} + } + return normalizeWorkflowSession(session) +} + +func (a *Agent) saveWorkflowSession(userID int64, session WorkflowSession) { + if a.store == nil { + return + } + session = normalizeWorkflowSession(session) + if len(session.Tasks) == 0 { + _ = a.store.SetSystemConfig(workflowSessionConfigKey(userID), "") + return + } + session.UserID = userID + session.UpdatedAt = time.Now().UTC().Format(time.RFC3339) + data, err := json.Marshal(session) + if err != nil { + return + } + _ = a.store.SetSystemConfig(workflowSessionConfigKey(userID), string(data)) +} + +func (a *Agent) clearWorkflowSession(userID int64) { + if a.store == nil { + return + } + _ = a.store.SetSystemConfig(workflowSessionConfigKey(userID), "") +} + +func hasActiveWorkflowSession(session WorkflowSession) bool { + if len(session.Tasks) == 0 { + return false + } + for _, task := range session.Tasks { + if task.Status == workflowTaskPending || task.Status == workflowTaskRunning { + return true + } + } + return false +} + +func nextRunnableWorkflowTask(session WorkflowSession) (WorkflowTask, int, bool) { + for i, task := range session.Tasks { + if task.Status != workflowTaskPending && task.Status != workflowTaskRunning { + continue + } + depsReady := true + for _, dep := range task.DependsOn { + ok := false + for _, candidate := range session.Tasks { + if candidate.ID == dep && candidate.Status == workflowTaskCompleted { + ok = true + break + } + } + if !ok { + depsReady = false + break + } + } + if depsReady { + return task, i, true + } + } + return WorkflowTask{}, -1, false +} + +func supportedWorkflowSkill(skill, action string) bool { + skill = strings.TrimSpace(skill) + action = normalizeAtomicSkillAction(skill, action) + if skill == "" || action == "" { + return false + } + if _, ok := getSkillDAG(skill, action); ok { + return true + } + switch skill { + case "trader_management", "strategy_management", "model_management", "exchange_management": + switch action { + case "create", "query_list", "query_detail", "query_running", "activate": + return true + } + } + return false +} + +func (a *Agent) tryWorkflowIntent(ctx context.Context, storeUserID string, userID int64, lang, text string, onEvent func(event, data string)) (string, bool, error) { + if session := a.getWorkflowSession(userID); hasActiveWorkflowSession(session) { + return a.handleWorkflowSession(ctx, storeUserID, userID, lang, text, session, onEvent) + } + + decomposition, err := a.decomposeWorkflowIntent(ctx, userID, lang, text) + if err != nil || len(decomposition.Tasks) <= 1 { + return "", false, err + } + session := WorkflowSession{ + UserID: userID, + OriginalRequest: text, + Tasks: decomposition.Tasks, + } + a.saveWorkflowSession(userID, session) + return a.handleWorkflowSession(ctx, storeUserID, userID, lang, text, session, onEvent) +} + +func (a *Agent) handleWorkflowSession(ctx context.Context, storeUserID string, userID int64, lang, text string, session WorkflowSession, onEvent func(event, data string)) (string, bool, error) { + if isExplicitFlowAbort(text) { + a.clearSkillSession(userID) + a.clearWorkflowSession(userID) + if lang == "zh" { + return "已取消当前任务流。", true, nil + } + return "Cancelled the current workflow.", true, nil + } + + if activeSkill := a.getSkillSession(userID); strings.TrimSpace(activeSkill.Name) != "" { + answer, handled := a.tryHardSkill(ctx, storeUserID, userID, lang, text, onEvent) + if !handled { + return "", false, nil + } + session = a.getWorkflowSession(userID) + if hasActiveWorkflowSession(session) && strings.TrimSpace(a.getSkillSession(userID).Name) == "" { + session = markCurrentWorkflowTask(session, workflowTaskCompleted, "") + a.saveWorkflowSession(userID, session) + if final, done, err := a.maybeAdvanceWorkflow(ctx, storeUserID, userID, lang, session, onEvent); done || err != nil { + if final != "" && answer != "" { + return answer + "\n\n" + final, true, err + } + if answer != "" { + return answer, true, err + } + return final, true, err + } + } + return answer, true, nil + } + + return a.maybeAdvanceWorkflow(ctx, storeUserID, userID, lang, session, onEvent) +} + +func (a *Agent) maybeAdvanceWorkflow(ctx context.Context, storeUserID string, userID int64, lang string, session WorkflowSession, onEvent func(event, data string)) (string, bool, error) { + task, index, ok := nextRunnableWorkflowTask(session) + if !ok { + summary := a.generateWorkflowSummary(ctx, userID, lang, session) + a.clearWorkflowSession(userID) + if summary == "" { + if lang == "zh" { + summary = "已完成当前任务流。" + } else { + summary = "Completed the current workflow." + } + } + if onEvent != nil { + onEvent(StreamEventPlan, summary) + onEvent(StreamEventDelta, summary) + } + return summary, true, nil + } + + session.Tasks[index].Status = workflowTaskRunning + a.saveWorkflowSession(userID, session) + taskSession := skillSession{Name: task.Skill, Action: task.Action, Phase: "collecting"} + a.saveSkillSession(userID, taskSession) + + if onEvent != nil { + onEvent(StreamEventPlan, a.formatWorkflowStatus(lang, session)) + onEvent(StreamEventTool, "workflow:"+task.Skill+":"+task.Action) + } + + answer, handled := a.tryHardSkill(ctx, storeUserID, userID, lang, task.Request, onEvent) + if !handled { + session.Tasks[index].Status = workflowTaskFailed + session.Tasks[index].Error = "task_not_handled" + a.saveWorkflowSession(userID, session) + return "", false, nil + } + + if strings.TrimSpace(a.getSkillSession(userID).Name) == "" { + session = a.getWorkflowSession(userID) + session = markCurrentWorkflowTask(session, workflowTaskCompleted, "") + a.saveWorkflowSession(userID, session) + if more, ok, err := a.maybeAdvanceWorkflow(ctx, storeUserID, userID, lang, session, onEvent); ok || err != nil { + if answer != "" && more != "" { + return answer + "\n\n" + more, true, err + } + if answer != "" { + return answer, true, err + } + return more, true, err + } + } + return answer, true, nil +} + +func markCurrentWorkflowTask(session WorkflowSession, status, errMsg string) WorkflowSession { + for i := range session.Tasks { + if session.Tasks[i].Status == workflowTaskRunning { + session.Tasks[i].Status = status + session.Tasks[i].Error = strings.TrimSpace(errMsg) + return session + } + } + return session +} + +func (a *Agent) formatWorkflowStatus(lang string, session WorkflowSession) string { + parts := make([]string, 0, len(session.Tasks)) + for _, task := range session.Tasks { + label := task.Request + if label == "" { + label = task.Skill + ":" + task.Action + } + switch task.Status { + case workflowTaskCompleted: + label = "✓ " + label + case workflowTaskRunning: + label = "→ " + label + default: + label = "· " + label + } + parts = append(parts, label) + } + if lang == "zh" { + return "任务流:" + strings.Join(parts, " | ") + } + return "Workflow: " + strings.Join(parts, " | ") +} + +func (a *Agent) generateWorkflowSummary(ctx context.Context, userID int64, lang string, session WorkflowSession) string { + completed := make([]string, 0, len(session.Tasks)) + for _, task := range session.Tasks { + if task.Status == workflowTaskCompleted { + completed = append(completed, task.Request) + } + } + if len(completed) == 0 { + return "" + } + if a.aiClient == nil { + if lang == "zh" { + return "已完成这些任务:" + strings.Join(completed, ";") + } + return "Completed these tasks: " + strings.Join(completed, "; ") + } + stageCtx, cancel := withPlannerStageTimeout(ctx, directReplyTimeout) + defer cancel() + systemPrompt := `You are summarizing a finished workflow for NOFXi. +Return one short user-facing summary in the user's language. +Do not mention internal DAG, scheduler, or JSON.` + userPrompt := fmt.Sprintf("Language: %s\nOriginal request: %s\nCompleted tasks:\n- %s", lang, session.OriginalRequest, strings.Join(completed, "\n- ")) + raw, err := a.aiClient.CallWithRequest(&mcp.Request{ + Messages: []mcp.Message{ + mcp.NewSystemMessage(systemPrompt), + mcp.NewUserMessage(userPrompt), + }, + Ctx: stageCtx, + }) + if err != nil { + if lang == "zh" { + return "已完成这些任务:" + strings.Join(completed, ";") + } + return "Completed these tasks: " + strings.Join(completed, "; ") + } + return strings.TrimSpace(raw) +} + +func (a *Agent) decomposeWorkflowIntent(ctx context.Context, userID int64, lang, text string) (workflowDecomposition, error) { + if !looksLikeMultiTaskIntent(text) { + return workflowDecomposition{}, nil + } + if a.aiClient != nil { + if dec, err := a.decomposeWorkflowIntentWithLLM(ctx, userID, lang, text); err == nil && len(dec.Tasks) > 1 { + return dec, nil + } + } + return a.decomposeWorkflowIntentFallback(text), nil +} + +func looksLikeMultiTaskIntent(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return false + } + connectors := []string{",", ",", "然后", "再", "并且", "并", "同时", "and", "then"} + count := 0 + for _, c := range connectors { + if strings.Contains(lower, c) { + count++ + } + } + return count > 0 +} + +func (a *Agent) decomposeWorkflowIntentWithLLM(ctx context.Context, userID int64, lang, text string) (workflowDecomposition, error) { + stageCtx, cancel := withPlannerStageTimeout(ctx, directReplyTimeout) + defer cancel() + systemPrompt := `You decompose one NOFXi user request into a small task graph. +Return JSON only. No markdown. +Only use these skills: trader_management, strategy_management, model_management, exchange_management. +Only use one atomic action per task. +Each task must include: +- id +- skill +- action +- request +- depends_on (array, may be empty) +If the request is effectively a single task, return one task only.` + userPrompt := fmt.Sprintf("Language: %s\nUser request: %s", lang, text) + raw, err := a.aiClient.CallWithRequest(&mcp.Request{ + Messages: []mcp.Message{ + mcp.NewSystemMessage(systemPrompt), + mcp.NewUserMessage(userPrompt), + }, + Ctx: stageCtx, + }) + if err != nil { + return workflowDecomposition{}, err + } + return parseWorkflowDecomposition(raw) +} + +func parseWorkflowDecomposition(raw string) (workflowDecomposition, error) { + raw = strings.TrimSpace(raw) + raw = strings.TrimPrefix(raw, "```json") + raw = strings.TrimPrefix(raw, "```") + raw = strings.TrimSuffix(raw, "```") + raw = strings.TrimSpace(raw) + var out workflowDecomposition + if err := json.Unmarshal([]byte(raw), &out); err == nil { + out = normalizeWorkflowDecomposition(out) + return out, nil + } + start := strings.Index(raw, "{") + end := strings.LastIndex(raw, "}") + if start >= 0 && end > start { + if err := json.Unmarshal([]byte(raw[start:end+1]), &out); err == nil { + out = normalizeWorkflowDecomposition(out) + return out, nil + } + } + return workflowDecomposition{}, fmt.Errorf("invalid workflow json") +} + +func normalizeWorkflowDecomposition(out workflowDecomposition) workflowDecomposition { + normalized := make([]WorkflowTask, 0, len(out.Tasks)) + for i, task := range out.Tasks { + task.ID = strings.TrimSpace(task.ID) + if task.ID == "" { + task.ID = fmt.Sprintf("task_%d", i+1) + } + task.Skill = strings.TrimSpace(task.Skill) + task.Action = normalizeAtomicSkillAction(task.Skill, task.Action) + task.Request = strings.TrimSpace(task.Request) + task.DependsOn = cleanStringList(task.DependsOn) + if !supportedWorkflowSkill(task.Skill, task.Action) || task.Request == "" { + continue + } + task.Status = workflowTaskPending + normalized = append(normalized, task) + } + out.Tasks = normalized + return out +} + +func (a *Agent) decomposeWorkflowIntentFallback(text string) workflowDecomposition { + segments := splitWorkflowSegments(text) + tasks := make([]WorkflowTask, 0, len(segments)) + for i, segment := range segments { + task, ok := classifyWorkflowTask(segment) + if !ok { + continue + } + task.ID = fmt.Sprintf("task_%d", i+1) + task.Status = workflowTaskPending + if len(tasks) > 0 { + task.DependsOn = []string{tasks[len(tasks)-1].ID} + } + tasks = append(tasks, task) + } + return workflowDecomposition{Tasks: tasks} +} + +func splitWorkflowSegments(text string) []string { + parts := []string{strings.TrimSpace(text)} + separators := []string{",", ",", "然后", "再", "并且", "同时", " and then ", " then ", " and "} + for _, sep := range separators { + next := make([]string, 0, len(parts)) + for _, part := range parts { + split := strings.Split(part, sep) + for _, candidate := range split { + candidate = strings.TrimSpace(candidate) + if candidate != "" { + next = append(next, candidate) + } + } + } + parts = next + } + return parts +} + +func classifyWorkflowTask(text string) (WorkflowTask, bool) { + segment := strings.TrimSpace(text) + if segment == "" { + return WorkflowTask{}, false + } + switch { + case detectCreateTraderSkill(segment): + return WorkflowTask{Skill: "trader_management", Action: "create", Request: segment}, true + case detectTraderManagementIntent(segment): + action := normalizeAtomicSkillAction("trader_management", detectManagementAction(segment, "trader")) + if supportedWorkflowSkill("trader_management", action) { + return WorkflowTask{Skill: "trader_management", Action: action, Request: segment}, true + } + case detectExchangeManagementIntent(segment): + action := normalizeAtomicSkillAction("exchange_management", detectManagementAction(segment, "exchange")) + if supportedWorkflowSkill("exchange_management", action) { + return WorkflowTask{Skill: "exchange_management", Action: action, Request: segment}, true + } + case detectModelManagementIntent(segment): + action := normalizeAtomicSkillAction("model_management", detectManagementAction(segment, "model")) + if supportedWorkflowSkill("model_management", action) { + return WorkflowTask{Skill: "model_management", Action: action, Request: segment}, true + } + case detectStrategyManagementIntent(segment): + action := normalizeAtomicSkillAction("strategy_management", detectManagementAction(segment, "strategy")) + if action == "" && wantsStrategyDetails(segment) { + action = "query_detail" + } + if supportedWorkflowSkill("strategy_management", action) { + return WorkflowTask{Skill: "strategy_management", Action: action, Request: segment}, true + } + } + return WorkflowTask{}, false +} diff --git a/agent/workflow_test.go b/agent/workflow_test.go new file mode 100644 index 00000000..bffed9bb --- /dev/null +++ b/agent/workflow_test.go @@ -0,0 +1,37 @@ +package agent + +import "testing" + +func TestSplitWorkflowSegments(t *testing.T) { + got := splitWorkflowSegments("把策略删了,再把交易所改名") + if len(got) != 2 { + t.Fatalf("expected 2 segments, got %d: %#v", len(got), got) + } +} + +func TestClassifyWorkflowTask(t *testing.T) { + task, ok := classifyWorkflowTask("把策略删了") + if !ok { + t.Fatal("expected task") + } + if task.Skill != "strategy_management" || task.Action != "delete" { + t.Fatalf("unexpected task: %+v", task) + } +} + +func TestFallbackWorkflowDecompositionBuildsTwoTasks(t *testing.T) { + a := &Agent{} + out := a.decomposeWorkflowIntentFallback("把策略删了,再把交易所改名") + if len(out.Tasks) != 2 { + t.Fatalf("expected 2 tasks, got %d", len(out.Tasks)) + } + if out.Tasks[0].Skill != "strategy_management" { + t.Fatalf("unexpected first task: %+v", out.Tasks[0]) + } + if out.Tasks[1].Skill != "exchange_management" { + t.Fatalf("unexpected second task: %+v", out.Tasks[1]) + } + if len(out.Tasks[1].DependsOn) != 1 || out.Tasks[1].DependsOn[0] != out.Tasks[0].ID { + t.Fatalf("expected dependency on first task, got %+v", out.Tasks[1].DependsOn) + } +} diff --git a/manager/trader_manager.go b/manager/trader_manager.go index 36b745b8..dce65785 100644 --- a/manager/trader_manager.go +++ b/manager/trader_manager.go @@ -11,6 +11,13 @@ import ( "time" ) +func traderLogTag(traderID, traderName string) string { + if traderName != "" { + return fmt.Sprintf("[trader_id=%s trader_name=%s]", traderID, traderName) + } + return fmt.Sprintf("[trader_id=%s]", traderID) +} + // CompetitionCache competition data cache type CompetitionCache struct { data map[string]interface{} @@ -88,9 +95,9 @@ func (tm *TraderManager) StartAll() { logger.Info("🚀 Starting all traders...") for id, t := range tm.traders { go func(traderID string, at *trader.AutoTrader) { - logger.Infof("▶️ Starting %s...", at.GetName()) + logger.Infof("%s ▶️ Starting trader runtime", traderLogTag(traderID, at.GetName())) if err := at.Run(); err != nil { - logger.Infof("❌ %s runtime error: %v", at.GetName(), err) + logger.Warnf("%s runtime error: %v", traderLogTag(traderID, at.GetName()), err) } }(id, t) } @@ -136,9 +143,9 @@ func (tm *TraderManager) AutoStartRunningTraders(st *store.Store) { for id, t := range tm.traders { if runningTraderIDs[id] { go func(traderID string, at *trader.AutoTrader) { - logger.Infof("▶️ Auto-restoring %s...", at.GetName()) + logger.Infof("%s ▶️ Auto-restoring trader runtime", traderLogTag(traderID, at.GetName())) if err := at.Run(); err != nil { - logger.Infof("❌ %s runtime error: %v", at.GetName(), err) + logger.Warnf("%s runtime error: %v", traderLogTag(traderID, at.GetName()), err) } }(id, t) startedCount++ @@ -487,7 +494,7 @@ func (tm *TraderManager) LoadUserTradersFromStore(st *store.Store, userID string logger.Infof("📦 Loading trader %s (AI Model: %s, Exchange: %s/%s, Strategy ID: %s)", traderCfg.Name, aiModelCfg.Provider, exchangeCfg.ExchangeType, exchangeCfg.AccountName, traderCfg.StrategyID) err = tm.addTraderFromStore(traderCfg, aiModelCfg, exchangeCfg, st) if err != nil { - logger.Infof("❌ Failed to load trader %s: %v", traderCfg.Name, err) + logger.Warnf("%s failed to load trader: %v", traderLogTag(traderCfg.ID, traderCfg.Name), err) // Save error for later retrieval tm.loadErrors[traderCfg.ID] = err } else { @@ -592,7 +599,7 @@ func (tm *TraderManager) LoadTradersFromStore(st *store.Store) error { // Add to TraderManager (ai500APIURL/oiTopAPIURL already obtained from strategy config) err = tm.addTraderFromStore(traderCfg, aiModelCfg, exchangeCfg, st) if err != nil { - logger.Infof("❌ Failed to add trader %s: %v", traderCfg.Name, err) + logger.Warnf("%s failed to add trader: %v", traderLogTag(traderCfg.ID, traderCfg.Name), err) continue } } @@ -727,17 +734,17 @@ func (tm *TraderManager) addTraderFromStore(traderCfg *store.Trader, aiModelCfg // Auto-start if trader was running before shutdown if traderCfg.IsRunning { - logger.Infof("🔄 Auto-starting trader '%s' (was running before shutdown)...", traderCfg.Name) + logger.Infof("%s 🔄 Auto-starting trader (was running before shutdown)...", traderLogTag(traderCfg.ID, traderCfg.Name)) go func(trader *trader.AutoTrader, traderName, traderID, userID string) { if err := trader.Run(); err != nil { - logger.Warnf("⚠️ Trader '%s' stopped with error: %v", traderName, err) + logger.Warnf("%s trader stopped with error: %v", traderLogTag(traderID, traderName), err) // Update database to reflect stopped state if st != nil { _ = st.Trader().UpdateStatus(userID, traderID, false) } } }(at, traderCfg.Name, traderCfg.ID, traderCfg.UserID) - logger.Infof("✅ Trader '%s' auto-started successfully", traderCfg.Name) + logger.Infof("%s ✅ Trader auto-started successfully", traderLogTag(traderCfg.ID, traderCfg.Name)) } return nil diff --git a/store/ai_model.go b/store/ai_model.go index b14a88ed..7cd08e2a 100644 --- a/store/ai_model.go +++ b/store/ai_model.go @@ -131,7 +131,7 @@ func (s *AIModelStore) GetDefault(userID string) (*AIModel, error) { if userID == "" { userID = "default" } - model, err := s.firstEnabled(userID) + model, err := s.firstEnabledUsable(userID) if err == nil { return model, nil } @@ -139,14 +139,14 @@ func (s *AIModelStore) GetDefault(userID string) (*AIModel, error) { return nil, err } if userID != "default" { - return s.firstEnabled("default") + return s.firstEnabledUsable("default") } return nil, fmt.Errorf("please configure an available AI model in the system first") } -func (s *AIModelStore) firstEnabled(userID string) (*AIModel, error) { +func (s *AIModelStore) firstEnabledUsable(userID string) (*AIModel, error) { var model AIModel - err := s.db.Where("user_id = ? AND enabled = ?", userID, true). + err := s.db.Where("user_id = ? AND enabled = ? AND api_key != ''", userID, true). Order("updated_at DESC, id ASC"). First(&model).Error if err != nil { diff --git a/trader/auto_trader.go b/trader/auto_trader.go index 0a0a6786..1faaafd4 100644 --- a/trader/auto_trader.go +++ b/trader/auto_trader.go @@ -24,6 +24,31 @@ import ( "time" ) +func (at *AutoTrader) logTag() string { + if at == nil { + return "[trader_id=unknown]" + } + if at.name != "" { + return fmt.Sprintf("[trader_id=%s trader_name=%s]", at.id, at.name) + } + return fmt.Sprintf("[trader_id=%s]", at.id) +} + +func (at *AutoTrader) logInfof(format string, args ...interface{}) { + values := append([]interface{}{at.logTag()}, args...) + logger.Infof("%s "+format, values...) +} + +func (at *AutoTrader) logWarnf(format string, args ...interface{}) { + values := append([]interface{}{at.logTag()}, args...) + logger.Warnf("%s "+format, values...) +} + +func (at *AutoTrader) logErrorf(format string, args ...interface{}) { + values := append([]interface{}{at.logTag()}, args...) + logger.Errorf("%s "+format, values...) +} + // AutoTraderConfig auto trading configuration (simplified version - AI makes all decisions) type AutoTraderConfig struct { // Trader identification @@ -381,8 +406,8 @@ func (at *AutoTrader) Run() error { at.startTime = time.Now() logger.Info("🚀 AI-driven automatic trading system started") - logger.Infof("💰 Initial balance: %.2f USDT", at.initialBalance) - logger.Infof("⚙️ Scan interval: %v", at.config.ScanInterval) + at.logInfof("💰 Initial balance: %.2f USDT", at.initialBalance) + at.logInfof("⚙️ Scan interval: %v", at.config.ScanInterval) logger.Info("🤖 AI will make full decisions on leverage, position size, stop loss/take profit, etc.") // Pre-launch checks for claw402 users @@ -397,7 +422,7 @@ func (at *AutoTrader) Run() error { if at.exchange == "lighter" { if lighterTrader, ok := at.trader.(*lighter.LighterTraderV2); ok && at.store != nil { lighterTrader.StartOrderSync(at.id, at.exchangeID, at.exchange, at.store, 30*time.Second) - logger.Infof("🔄 [%s] Lighter order+position sync enabled (every 30s)", at.name) + at.logInfof("🔄 Lighter order+position sync enabled (every 30s)") } } @@ -405,7 +430,7 @@ func (at *AutoTrader) Run() error { if at.exchange == "hyperliquid" { if hyperliquidTrader, ok := at.trader.(*hyperliquid.HyperliquidTrader); ok && at.store != nil { hyperliquidTrader.StartOrderSync(at.id, at.exchangeID, at.exchange, at.store, 30*time.Second) - logger.Infof("🔄 [%s] Hyperliquid order+position sync enabled (every 30s)", at.name) + at.logInfof("🔄 Hyperliquid order+position sync enabled (every 30s)") } } @@ -413,7 +438,7 @@ func (at *AutoTrader) Run() error { if at.exchange == "bybit" { if bybitTrader, ok := at.trader.(*bybit.BybitTrader); ok && at.store != nil { bybitTrader.StartOrderSync(at.id, at.exchangeID, at.exchange, at.store, 30*time.Second) - logger.Infof("🔄 [%s] Bybit order+position sync enabled (every 30s)", at.name) + at.logInfof("🔄 Bybit order+position sync enabled (every 30s)") } } @@ -421,7 +446,7 @@ func (at *AutoTrader) Run() error { if at.exchange == "okx" { if okxTrader, ok := at.trader.(*okx.OKXTrader); ok && at.store != nil { okxTrader.StartOrderSync(at.id, at.exchangeID, at.exchange, at.store, 30*time.Second) - logger.Infof("🔄 [%s] OKX order+position sync enabled (every 30s)", at.name) + at.logInfof("🔄 OKX order+position sync enabled (every 30s)") } } @@ -429,7 +454,7 @@ func (at *AutoTrader) Run() error { if at.exchange == "bitget" { if bitgetTrader, ok := at.trader.(*bitget.BitgetTrader); ok && at.store != nil { bitgetTrader.StartOrderSync(at.id, at.exchangeID, at.exchange, at.store, 30*time.Second) - logger.Infof("🔄 [%s] Bitget order+position sync enabled (every 30s)", at.name) + at.logInfof("🔄 Bitget order+position sync enabled (every 30s)") } } @@ -437,7 +462,7 @@ func (at *AutoTrader) Run() error { if at.exchange == "aster" { if asterTrader, ok := at.trader.(*aster.AsterTrader); ok && at.store != nil { asterTrader.StartOrderSync(at.id, at.exchangeID, at.exchange, at.store, 30*time.Second) - logger.Infof("🔄 [%s] Aster order+position sync enabled (every 30s)", at.name) + at.logInfof("🔄 Aster order+position sync enabled (every 30s)") } } @@ -445,7 +470,7 @@ func (at *AutoTrader) Run() error { if at.exchange == "binance" { if binanceTrader, ok := at.trader.(*binance.FuturesTrader); ok && at.store != nil { binanceTrader.StartOrderSync(at.id, at.exchangeID, at.exchange, at.store, 30*time.Second) - logger.Infof("🔄 [%s] Binance order+position sync enabled (every 30s)", at.name) + at.logInfof("🔄 Binance order+position sync enabled (every 30s)") } } @@ -453,7 +478,7 @@ func (at *AutoTrader) Run() error { if at.exchange == "gate" { if gateTrader, ok := at.trader.(*gate.GateTrader); ok && at.store != nil { gateTrader.StartOrderSync(at.id, at.exchangeID, at.exchange, at.store, 30*time.Second) - logger.Infof("🔄 [%s] Gate order+position sync enabled (every 30s)", at.name) + at.logInfof("🔄 Gate order+position sync enabled (every 30s)") } } @@ -461,7 +486,7 @@ func (at *AutoTrader) Run() error { if at.exchange == "kucoin" { if kucoinTrader, ok := at.trader.(*kucoin.KuCoinTrader); ok && at.store != nil { kucoinTrader.StartOrderSync(at.id, at.exchangeID, at.exchange, at.store, 30*time.Second) - logger.Infof("🔄 [%s] KuCoin order+position sync enabled (every 30s)", at.name) + at.logInfof("🔄 KuCoin order+position sync enabled (every 30s)") } } @@ -471,9 +496,9 @@ func (at *AutoTrader) Run() error { // Check if this is a grid trading strategy isGridStrategy := at.IsGridStrategy() if isGridStrategy { - logger.Infof("🔲 [%s] Grid trading strategy detected, initializing grid...", at.name) + at.logInfof("🔲 Grid trading strategy detected, initializing grid...") if err := at.InitializeGrid(); err != nil { - logger.Errorf("❌ [%s] Failed to initialize grid: %v", at.name, err) + at.logErrorf("❌ Failed to initialize grid: %v", err) return fmt.Errorf("grid initialization failed: %w", err) } } @@ -481,11 +506,11 @@ func (at *AutoTrader) Run() error { // Execute immediately on first run if isGridStrategy { if err := at.RunGridCycle(); err != nil { - logger.Infof("❌ Grid execution failed: %v", err) + at.logErrorf("❌ Grid execution failed: %v", err) } } else { if err := at.runCycle(); err != nil { - logger.Infof("❌ Execution failed: %v", err) + at.logErrorf("❌ Execution failed: %v", err) } } @@ -502,15 +527,15 @@ func (at *AutoTrader) Run() error { case <-ticker.C: if isGridStrategy { if err := at.RunGridCycle(); err != nil { - logger.Infof("❌ Grid execution failed: %v", err) + at.logErrorf("❌ Grid execution failed: %v", err) } } else { if err := at.runCycle(); err != nil { - logger.Infof("❌ Execution failed: %v", err) + at.logErrorf("❌ Execution failed: %v", err) } } case <-at.stopMonitorCh: - logger.Infof("[%s] ⏹ Stop signal received, exiting automatic trading main loop", at.name) + at.logInfof("⏹ Stop signal received, exiting automatic trading main loop") return nil } } @@ -590,6 +615,22 @@ func (at *AutoTrader) GetSystemPromptTemplate() string { return "strategy" } +// GetCandidateCoins returns the current candidate coin set from the trader's strategy engine. +func (at *AutoTrader) GetCandidateCoins() ([]kernel.CandidateCoin, error) { + if at.strategyEngine == nil { + return nil, fmt.Errorf("strategy engine not configured") + } + return at.strategyEngine.GetCandidateCoins() +} + +// GetStrategyConfig returns the current strategy config used by the trader. +func (at *AutoTrader) GetStrategyConfig() *store.StrategyConfig { + if at.strategyEngine == nil { + return at.config.StrategyConfig + } + return at.strategyEngine.GetConfig() +} + // GetStore gets data store (for external access to decision records, etc.) func (at *AutoTrader) GetStore() *store.Store { return at.store diff --git a/trader/auto_trader_loop.go b/trader/auto_trader_loop.go index ae440699..c01b91f5 100644 --- a/trader/auto_trader_loop.go +++ b/trader/auto_trader_loop.go @@ -24,7 +24,7 @@ func (at *AutoTrader) runCycle() error { running := at.isRunning at.isRunningMutex.RUnlock() if !running { - logger.Infof("⏹ Trader is stopped, aborting cycle #%d", at.callCount) + at.logInfof("⏹ Trader is stopped, aborting cycle #%d", at.callCount) return nil } @@ -42,7 +42,7 @@ func (at *AutoTrader) runCycle() error { // 1. Check if trading needs to be stopped if time.Now().Before(at.stopUntil) { remaining := at.stopUntil.Sub(time.Now()) - logger.Infof("⏸ Risk control: Trading paused, remaining %.0f minutes", remaining.Minutes()) + at.logWarnf("⏸ Risk control: Trading paused, remaining %.0f minutes", remaining.Minutes()) record.Success = false record.ErrorMessage = fmt.Sprintf("Risk control paused, remaining %.0f minutes", remaining.Minutes()) at.saveDecision(record) @@ -59,6 +59,7 @@ func (at *AutoTrader) runCycle() error { // 4. Collect trading context ctx, err := at.buildTradingContext() if err != nil { + at.logErrorf("failed to build trading context: %v", err) record.Success = false record.ErrorMessage = fmt.Sprintf("Failed to build trading context: %v", err) at.saveDecision(record) @@ -71,7 +72,7 @@ func (at *AutoTrader) runCycle() error { // If no candidate coins available, log but do not error if len(ctx.CandidateCoins) == 0 { - logger.Infof("ℹ️ No candidate coins available, skipping this cycle") + at.logInfof("ℹ️ No candidate coins available, skipping this cycle") record.Success = true // Not an error, just no candidate coins record.ExecutionLog = append(record.ExecutionLog, "No candidate coins available, cycle skipped") record.AccountState = store.AccountSnapshot{ @@ -90,16 +91,16 @@ func (at *AutoTrader) runCycle() error { record.CandidateCoins = append(record.CandidateCoins, coin.Symbol) } - logger.Infof("📊 Account equity: %.2f USDT | Available: %.2f USDT | Positions: %d", + at.logInfof("📊 Account equity: %.2f USDT | Available: %.2f USDT | Positions: %d", ctx.Account.TotalEquity, ctx.Account.AvailableBalance, ctx.Account.PositionCount) // 5. Use strategy engine to call AI for decision - logger.Infof("🤖 Requesting AI analysis and decision... [Strategy Engine]") + at.logInfof("🤖 Requesting AI analysis and decision... [Strategy Engine]") aiDecision, err := kernel.GetFullDecisionWithStrategy(ctx, at.mcpClient, at.strategyEngine, "balanced") if aiDecision != nil && aiDecision.AIRequestDurationMs > 0 { record.AIRequestDurationMs = aiDecision.AIRequestDurationMs - logger.Infof("⏱️ AI call duration: %.2f seconds", float64(record.AIRequestDurationMs)/1000) + at.logInfof("⏱️ AI call duration: %.2f seconds", float64(record.AIRequestDurationMs)/1000) record.ExecutionLog = append(record.ExecutionLog, fmt.Sprintf("AI call duration: %d ms", record.AIRequestDurationMs)) } @@ -119,7 +120,7 @@ func (at *AutoTrader) runCycle() error { // Record AI charge (track cost regardless of decision outcome) if aiDecision != nil && at.store != nil { if chargeErr := at.store.AICharge().Record(at.id, at.aiModel, at.config.AIModel); chargeErr != nil { - logger.Warnf("⚠️ Failed to record AI charge: %v", chargeErr) + at.logWarnf("⚠️ Failed to record AI charge: %v", chargeErr) } } @@ -132,10 +133,9 @@ func (at *AutoTrader) runCycle() error { if at.consecutiveAIFailures >= 3 && !at.safeMode { at.safeMode = true at.safeModeReason = fmt.Sprintf("AI failed %d consecutive times: %v", at.consecutiveAIFailures, err) - logger.Errorf("🛡️ [%s] SAFE MODE ACTIVATED — AI failed %d times in a row. No new positions will be opened. Existing positions are protected with current stop-loss settings.", - at.name, at.consecutiveAIFailures) - logger.Errorf("🛡️ [%s] Reason: %v", at.name, err) - logger.Errorf("🛡️ [%s] Action: Will keep trying AI each cycle. Safe mode auto-deactivates when AI recovers.", at.name) + at.logErrorf("🛡️ SAFE MODE ACTIVATED — AI failed %d times in a row. No new positions will be opened. Existing positions are protected with current stop-loss settings.", at.consecutiveAIFailures) + at.logErrorf("🛡️ Reason: %v", err) + at.logErrorf("🛡️ Action: Will keep trying AI each cycle. Safe mode auto-deactivates when AI recovers.") } // Print system prompt and AI chain of thought (output even with errors for debugging) @@ -159,7 +159,7 @@ func (at *AutoTrader) runCycle() error { // In safe mode, don't return error — keep the loop running to retry next cycle if at.safeMode { - logger.Warnf("🛡️ [%s] Safe mode: skipping this cycle, will retry in %v", at.name, at.config.ScanInterval) + at.logWarnf("🛡️ Safe mode: skipping this cycle, will retry in %v", at.config.ScanInterval) return nil } @@ -168,11 +168,11 @@ func (at *AutoTrader) runCycle() error { // AI succeeded — reset failure counter and deactivate safe mode if at.consecutiveAIFailures > 0 { - logger.Infof("✅ [%s] AI recovered after %d consecutive failures", at.name, at.consecutiveAIFailures) + at.logInfof("✅ AI recovered after %d consecutive failures", at.consecutiveAIFailures) } at.consecutiveAIFailures = 0 if at.safeMode { - logger.Infof("🛡️ [%s] SAFE MODE DEACTIVATED — AI is working again. Resuming normal trading.", at.name) + at.logInfof("🛡️ SAFE MODE DEACTIVATED — AI is working again. Resuming normal trading.") at.safeMode = false at.safeModeReason = "" } @@ -219,7 +219,7 @@ func (at *AutoTrader) runCycle() error { running = at.isRunning at.isRunningMutex.RUnlock() if !running { - logger.Infof("⏹ Trader stopped before decision execution, aborting cycle #%d", at.callCount) + at.logInfof("⏹ Trader stopped before decision execution, aborting cycle #%d", at.callCount) return nil } @@ -228,14 +228,14 @@ func (at *AutoTrader) runCycle() error { filtered := make([]kernel.Decision, 0) for _, d := range sortedDecisions { if d.Action == "open_long" || d.Action == "open_short" { - logger.Warnf("🛡️ [%s] Safe mode: BLOCKED %s %s (no new positions allowed)", at.name, d.Action, d.Symbol) + at.logWarnf("🛡️ Safe mode: BLOCKED %s %s (no new positions allowed)", d.Action, d.Symbol) continue } filtered = append(filtered, d) } sortedDecisions = filtered if len(sortedDecisions) == 0 { - logger.Infof("🛡️ [%s] Safe mode: all decisions were open positions, nothing to execute", at.name) + at.logInfof("🛡️ Safe mode: all decisions were open positions, nothing to execute") } } @@ -246,7 +246,7 @@ func (at *AutoTrader) runCycle() error { running = at.isRunning at.isRunningMutex.RUnlock() if !running { - logger.Infof("⏹ Trader stopped during decision execution, aborting remaining decisions") + at.logInfof("⏹ Trader stopped during decision execution, aborting remaining decisions") break } @@ -265,7 +265,7 @@ func (at *AutoTrader) runCycle() error { } if err := at.executeDecisionWithRecord(&d, &actionRecord); err != nil { - logger.Infof("❌ Failed to execute decision (%s %s): %v", d.Symbol, d.Action, err) + at.logErrorf("❌ Failed to execute decision (%s %s): %v", d.Symbol, d.Action, err) actionRecord.Error = err.Error() record.ExecutionLog = append(record.ExecutionLog, fmt.Sprintf("❌ %s %s failed: %v", d.Symbol, d.Action, err)) } else { @@ -280,7 +280,7 @@ func (at *AutoTrader) runCycle() error { // 9. Save decision record if err := at.saveDecision(record); err != nil { - logger.Infof("⚠ Failed to save decision record: %v", err) + at.logWarnf("⚠ Failed to save decision record: %v", err) } return nil @@ -417,12 +417,12 @@ func (at *AutoTrader) buildTradingContext() (*kernel.Context, error) { // 3. Use strategy engine to get candidate coins (must have strategy engine) var candidateCoins []kernel.CandidateCoin if at.strategyEngine == nil { - logger.Infof("⚠️ [%s] No strategy engine configured, skipping candidate coins", at.name) + at.logWarnf("⚠️ No strategy engine configured, skipping candidate coins") } else { coins, err := at.strategyEngine.GetCandidateCoins() if err != nil { // Log warning but don't fail - equity snapshot should still be saved - logger.Infof("⚠️ [%s] Failed to get candidate coins: %v (will use empty list)", at.name, err) + at.logWarnf("⚠️ Failed to get candidate coins: %v (will use empty list)", err) } else { candidateCoins = coins logger.Infof("📋 [%s] Strategy engine fetched candidate coins: %d", at.name, len(candidateCoins)) @@ -473,7 +473,7 @@ func (at *AutoTrader) buildTradingContext() (*kernel.Context, error) { // Get recent 10 closed trades for AI context recentTrades, err := at.store.Position().GetRecentTrades(at.id, 10) if err != nil { - logger.Infof("⚠️ [%s] Failed to get recent trades: %v", at.name, err) + at.logWarnf("⚠️ Failed to get recent trades: %v", err) } else { logger.Infof("📊 [%s] Found %d recent closed trades for AI context", at.name, len(recentTrades)) for _, trade := range recentTrades { @@ -503,11 +503,11 @@ func (at *AutoTrader) buildTradingContext() (*kernel.Context, error) { // Get trading statistics for AI context stats, err := at.store.Position().GetFullStats(at.id) if err != nil { - logger.Infof("⚠️ [%s] Failed to get trading stats: %v", at.name, err) + at.logWarnf("⚠️ Failed to get trading stats: %v", err) } else if stats == nil { - logger.Infof("⚠️ [%s] GetFullStats returned nil", at.name) + at.logWarnf("⚠️ GetFullStats returned nil") } else if stats.TotalTrades == 0 { - logger.Infof("⚠️ [%s] GetFullStats returned 0 trades (traderID=%s)", at.name, at.id) + at.logWarnf("⚠️ GetFullStats returned 0 trades") } else { ctx.TradingStats = &kernel.TradingStats{ TotalTrades: stats.TotalTrades, @@ -523,7 +523,7 @@ func (at *AutoTrader) buildTradingContext() (*kernel.Context, error) { at.name, stats.TotalTrades, stats.WinRate, stats.ProfitFactor, stats.SharpeRatio, stats.MaxDrawdownPct) } } else { - logger.Infof("⚠️ [%s] Store is nil, cannot get recent trades", at.name) + at.logWarnf("⚠️ Store is nil, cannot get recent trades") } // 8. Get quantitative data (if enabled in strategy config) @@ -630,15 +630,15 @@ func (at *AutoTrader) checkClaw402Balance() { if at.claw402WalletAddr != "" { balance, err := wallet.QueryUSDCBalance(at.claw402WalletAddr) if err != nil { - logger.Warnf("⚠️ [%s] Failed to query USDC balance: %v", at.name, err) + at.logWarnf("⚠️ Failed to query USDC balance: %v", err) return } if balance < 1.0 { - logger.Warnf("⚠️ [%s] Low USDC balance: $%.2f — AI may stop soon!", at.name, balance) + at.logWarnf("⚠️ Low USDC balance: $%.2f — AI may stop soon!", balance) } if balance <= 0 { - logger.Errorf("🚨 [%s] USDC balance is ZERO — AI calls will fail!", at.name) + at.logErrorf("🚨 USDC balance is ZERO — AI calls will fail!") } runway := float64(0) diff --git a/web/src/components/agent/AgentStepPanel.tsx b/web/src/components/agent/AgentStepPanel.tsx index 0826a110..9e999bb7 100644 --- a/web/src/components/agent/AgentStepPanel.tsx +++ b/web/src/components/agent/AgentStepPanel.tsx @@ -43,7 +43,7 @@ export function AgentStepPanel({ steps, visible }: AgentStepPanelProps) { marginBottom: 10, }} > - Agent Steps + Live Run
{steps.map((step) => { diff --git a/web/src/pages/AgentChatPage.tsx b/web/src/pages/AgentChatPage.tsx index c71b339f..859bebda 100644 --- a/web/src/pages/AgentChatPage.tsx +++ b/web/src/pages/AgentChatPage.tsx @@ -19,6 +19,11 @@ import { WelcomeScreen } from '../components/agent/WelcomeScreen' import { ChatMessages } from '../components/agent/ChatMessages' import { ChatInput, type ChatInputHandle } from '../components/agent/ChatInput' import { UserPreferencesPanel } from '../components/agent/UserPreferencesPanel' +import { + useAgentChatStore, + type AgentMessage as Message, + type AgentStep, +} from '../stores/agentChatStore' import { chatStorageKey, clearAgentMessages, @@ -29,22 +34,6 @@ import { persistAgentMessages, } from '../lib/agentChatStorage' -interface Message { - id: string - role: 'user' | 'bot' - text: string - time: string - streaming?: boolean - steps?: AgentStep[] -} - -interface AgentStep { - id: string - label: string - status: 'planning' | 'pending' | 'running' | 'completed' | 'replanned' - detail?: string -} - let msgIdCounter = 0 function nextId() { return `msg-${Date.now()}-${++msgIdCounter}` @@ -66,7 +55,7 @@ function parsePlanSteps(data: string): AgentStep[] { return text.split(/\s*->\s*/).map((part, index) => { const cleaned = part.replace(/^\d+\./, '').trim() return { - id: `plan-${index + 1}`, + id: `action-${index + 1}`, label: cleaned || `Step ${index + 1}`, status: 'pending', } @@ -76,7 +65,7 @@ function parsePlanSteps(data: string): AgentStep[] { function parseStepEvent(data: string, fallbackIndex: number): AgentStep { const match = data.match(/Step\s+(\d+)\/(\d+):\s+(.+)$/i) || data.match(/步骤\s+(\d+)\/(\d+):\s+(.+)$/) if (match) { - const id = `plan-${match[1]}` + const id = `action-${match[1]}` return { id, label: match[3].trim(), @@ -110,11 +99,14 @@ export function AgentChatPage() { const [storageUserId, setStorageUserId] = useState(() => getStoredAuthUserId()) const [sidebarOpen, setSidebarOpen] = useState(() => window.innerWidth > 1024) const storageKey = chatStorageKey(user?.id || storageUserId) - const [messages, setMessages] = useState( - () => loadAgentMessages(window.localStorage, user?.id || storageUserId).messages - ) - const [historyHydrated, setHistoryHydrated] = useState(false) - const [loading, setLoading] = useState(false) + const messages = useAgentChatStore((state) => state.messages) + const loading = useAgentChatStore((state) => state.loading) + const historyHydrated = useAgentChatStore((state) => state.hydrated) + const activeUserId = useAgentChatStore((state) => state.activeUserId) + const setMessages = useAgentChatStore((state) => state.setMessages) + const updateMessages = useAgentChatStore((state) => state.updateMessages) + const setLoading = useAgentChatStore((state) => state.setLoading) + const resetForUser = useAgentChatStore((state) => state.resetForUser) const messagesEndRef = useRef(null) const chatInputRef = useRef(null) const abortRef = useRef(null) @@ -147,10 +139,13 @@ export function AgentChatPage() { // Restore chat history for the current user when opening the agent page. useEffect(() => { - setHistoryHydrated(false) - setMessages(loadAgentMessages(window.localStorage, user?.id || storageUserId).messages) - setHistoryHydrated(true) - }, [storageKey, storageUserId, user?.id]) + const nextUserId = user?.id || storageUserId + if (activeUserId === nextUserId && historyHydrated) return + resetForUser( + nextUserId, + loadAgentMessages(window.localStorage, nextUserId).messages + ) + }, [activeUserId, historyHydrated, resetForUser, storageKey, storageUserId, user?.id]) // Persist chat history locally so page navigation does not wipe the conversation. useEffect(() => { @@ -163,6 +158,26 @@ export function AgentChatPage() { } }, [historyHydrated, messages, storageKey, storageUserId, user?.id]) + const persistMessagesSnapshot = (nextMessages: Message[]) => { + const persistable = prepareAgentMessagesForPersistence(nextMessages).slice(-100) + persistAgentMessages(window.localStorage, user?.id || storageUserId, persistable) + } + + const replaceMessages = (nextMessages: Message[]) => { + setMessages(nextMessages) + if (historyHydrated) { + persistMessagesSnapshot(nextMessages) + } + } + + const patchMessages = (updater: (prev: Message[]) => Message[]) => { + const nextMessages = updater(useAgentChatStore.getState().messages) + updateMessages(() => nextMessages) + if (useAgentChatStore.getState().hydrated) { + persistMessagesSnapshot(nextMessages) + } + } + // Responsive sidebar useEffect(() => { const handleResize = () => { @@ -201,10 +216,10 @@ export function AgentChatPage() { streaming: true, }, ] - setMessages((prev) => + replaceMessages( text.trim() === '/clear' ? nextConversation - : [...prev, ...nextConversation] + : [...useAgentChatStore.getState().messages, ...nextConversation] ) setLoading(true) @@ -275,7 +290,7 @@ export function AgentChatPage() { if (eventType === 'delta') { // data is the accumulated text so far finalText = data - setMessages((prev) => + patchMessages((prev) => prev.map((m) => m.id === botId ? { ...m, text: data, time: now() } @@ -284,13 +299,12 @@ export function AgentChatPage() { ) } else if (eventType === 'plan') { const parsedSteps = parsePlanSteps(data) - setMessages((prev) => + patchMessages((prev) => prev.map((m) => m.id === botId ? { ...m, steps: parsedSteps.length > 0 ? parsedSteps : m.steps, - text: m.text || data, time: now(), } : m @@ -299,33 +313,31 @@ export function AgentChatPage() { } else if (eventType === 'step_start') { stepCounter += 1 const nextStep = parseStepEvent(data, stepCounter) - setMessages((prev) => + patchMessages((prev) => prev.map((m) => m.id === botId ? { ...m, steps: appendStep(m.steps, nextStep), - text: m.text || data, time: now(), } : m ) ) } else if (eventType === 'step_complete') { - setMessages((prev) => + patchMessages((prev) => prev.map((m) => m.id === botId ? { ...m, steps: markLatestRunningCompleted(m.steps, data), - text: m.text || data, time: now(), } : m ) ) } else if (eventType === 'replan') { - setMessages((prev) => + patchMessages((prev) => prev.map((m) => m.id === botId ? { @@ -336,7 +348,6 @@ export function AgentChatPage() { status: 'replanned', detail: data, }), - text: m.text || data, time: now(), } : m @@ -346,12 +357,11 @@ export function AgentChatPage() { eventType === 'tool' ) { // Show tool being called as a status indicator - setMessages((prev) => + patchMessages((prev) => prev.map((m) => m.id === botId ? { ...m, - text: m.text || `🔧 _Calling ${data}..._`, steps: appendStep(m.steps, { id: `tool-${Date.now()}`, label: `Tool: ${data}`, @@ -365,7 +375,7 @@ export function AgentChatPage() { ) } else if (eventType === 'done') { finalText = data - setMessages((prev) => + patchMessages((prev) => prev.map((m) => m.id === botId ? { ...m, text: data, time: now(), streaming: false } @@ -381,7 +391,7 @@ export function AgentChatPage() { } // If stream ended without a "done" event, mark as done - setMessages((prev) => + patchMessages((prev) => prev.map((m) => m.id === botId && m.streaming ? { @@ -398,9 +408,9 @@ export function AgentChatPage() { } catch (e: any) { if (e.name === 'AbortError') { // Request was cancelled (e.g. user sent a new message), clean up silently - setMessages((prev) => prev.filter((m) => m.id !== botId)) + patchMessages((prev) => prev.filter((m) => m.id !== botId)) } else { - setMessages((prev) => + patchMessages((prev) => prev.map((m) => m.id === botId ? { diff --git a/web/src/stores/agentChatStore.ts b/web/src/stores/agentChatStore.ts new file mode 100644 index 00000000..74332bb8 --- /dev/null +++ b/web/src/stores/agentChatStore.ts @@ -0,0 +1,52 @@ +import { create } from 'zustand' + +export interface AgentStep { + id: string + label: string + status: 'planning' | 'pending' | 'running' | 'completed' | 'replanned' + detail?: string +} + +export interface AgentMessage { + id: string + role: 'user' | 'bot' + text: string + time: string + streaming?: boolean + steps?: AgentStep[] +} + +interface AgentChatStoreState { + activeUserId?: string + messages: AgentMessage[] + loading: boolean + hydrated: boolean + setActiveUserId: (userId?: string) => void + setMessages: (messages: AgentMessage[]) => void + updateMessages: ( + updater: (messages: AgentMessage[]) => AgentMessage[] + ) => void + setLoading: (loading: boolean) => void + setHydrated: (hydrated: boolean) => void + resetForUser: (userId?: string, messages?: AgentMessage[]) => void +} + +export const useAgentChatStore = create((set) => ({ + activeUserId: undefined, + messages: [], + loading: false, + hydrated: false, + setActiveUserId: (userId) => set({ activeUserId: userId }), + setMessages: (messages) => set({ messages }), + updateMessages: (updater) => + set((state) => ({ messages: updater(state.messages) })), + setLoading: (loading) => set({ loading }), + setHydrated: (hydrated) => set({ hydrated }), + resetForUser: (userId, messages = []) => + set({ + activeUserId: userId, + messages, + loading: false, + hydrated: true, + }), +}))