mirror of
https://github.com/NoFxAiOS/nofx.git
synced 2026-06-06 05:51:19 +08:00
Enhance NOFXi agent workflow and diagnostics
This commit is contained in:
127
agent/backend_logs_test.go
Normal file
127
agent/backend_logs_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"nofx/store"
|
||||
)
|
||||
|
||||
func TestReadBackendLogEntriesReturnsRecentErrorLines(t *testing.T) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatalf("Getwd() error = %v", err)
|
||||
}
|
||||
tmp := t.TempDir()
|
||||
if err := os.Chdir(tmp); err != nil {
|
||||
t.Fatalf("Chdir(tmp) error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = os.Chdir(wd)
|
||||
})
|
||||
|
||||
if err := os.MkdirAll("data", 0o755); err != nil {
|
||||
t.Fatalf("MkdirAll(data) error = %v", err)
|
||||
}
|
||||
logPath := filepath.Join("data", "nofx_2099-01-01.log")
|
||||
content := strings.Join([]string{
|
||||
"04-19 13:00:00 [INFO] api/server.go:590 API server starting",
|
||||
"04-19 13:00:01 [ERRO] api/server.go:600 invalid signature for okx account",
|
||||
"04-19 13:00:02 [ERRO] agent/tools.go:123 model update failed: missing api key",
|
||||
}, "\n") + "\n"
|
||||
if err := os.WriteFile(logPath, []byte(content), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile() error = %v", err)
|
||||
}
|
||||
|
||||
path, entries, err := readBackendLogEntries(10, "model", true)
|
||||
if err != nil {
|
||||
t.Fatalf("readBackendLogEntries() error = %v", err)
|
||||
}
|
||||
if !strings.Contains(path, "nofx_2099-01-01.log") {
|
||||
t.Fatalf("unexpected log path: %s", path)
|
||||
}
|
||||
if len(entries) != 1 || !strings.Contains(entries[0], "missing api key") {
|
||||
t.Fatalf("unexpected filtered entries: %#v", entries)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolGetBackendLogsRequiresOwnedTrader(t *testing.T) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatalf("Getwd() error = %v", err)
|
||||
}
|
||||
tmp := t.TempDir()
|
||||
if err := os.Chdir(tmp); err != nil {
|
||||
t.Fatalf("Chdir(tmp) error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = os.Chdir(wd)
|
||||
})
|
||||
|
||||
if err := os.MkdirAll("data", 0o755); err != nil {
|
||||
t.Fatalf("MkdirAll(data) error = %v", err)
|
||||
}
|
||||
logPath := filepath.Join("data", "nofx_2099-01-01.log")
|
||||
content := strings.Join([]string{
|
||||
"04-19 13:00:00 [INFO] api/server.go:590 API server starting",
|
||||
"04-19 13:00:01 [ERRO] trader/runtime.go:88 trader_id=trader-owned strategy execution failed",
|
||||
"04-19 13:00:02 [ERRO] trader/runtime.go:89 trader_id=trader-other strategy execution failed",
|
||||
}, "\n") + "\n"
|
||||
if err := os.WriteFile(logPath, []byte(content), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile() error = %v", err)
|
||||
}
|
||||
|
||||
a := newTestAgentWithStore(t)
|
||||
if err := a.store.Trader().Create(&store.Trader{
|
||||
ID: "trader-owned",
|
||||
UserID: "user-1",
|
||||
Name: "Owned Trader",
|
||||
AIModelID: "model-1",
|
||||
ExchangeID: "exchange-1",
|
||||
StrategyID: "strategy-1",
|
||||
InitialBalance: 1000,
|
||||
}); err != nil {
|
||||
t.Fatalf("create owned trader: %v", err)
|
||||
}
|
||||
if err := a.store.Trader().Create(&store.Trader{
|
||||
ID: "trader-other",
|
||||
UserID: "user-2",
|
||||
Name: "Other Trader",
|
||||
AIModelID: "model-2",
|
||||
ExchangeID: "exchange-2",
|
||||
StrategyID: "strategy-2",
|
||||
InitialBalance: 1000,
|
||||
}); err != nil {
|
||||
t.Fatalf("create other trader: %v", err)
|
||||
}
|
||||
|
||||
resp := a.toolGetBackendLogs("user-1", `{"trader_id":"trader-owned","limit":5}`)
|
||||
var okResult struct {
|
||||
TraderID string `json:"trader_id"`
|
||||
Entries []string `json:"entries"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(resp), &okResult); err != nil {
|
||||
t.Fatalf("unmarshal owned response: %v\nraw=%s", err, resp)
|
||||
}
|
||||
if okResult.TraderID != "trader-owned" || okResult.Count != 1 {
|
||||
t.Fatalf("unexpected owned response: %+v", okResult)
|
||||
}
|
||||
if len(okResult.Entries) != 1 || !strings.Contains(okResult.Entries[0], "trader-owned") {
|
||||
t.Fatalf("unexpected owned entries: %#v", okResult.Entries)
|
||||
}
|
||||
|
||||
resp = a.toolGetBackendLogs("user-1", `{"trader_id":"trader-other","limit":5}`)
|
||||
var denied struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(resp), &denied); err != nil {
|
||||
t.Fatalf("unmarshal denied response: %v\nraw=%s", err, resp)
|
||||
}
|
||||
if denied.Error != "trader not found for current user" {
|
||||
t.Fatalf("unexpected denied response: %+v", denied)
|
||||
}
|
||||
}
|
||||
@@ -86,6 +86,7 @@ func TestToolManageModelConfigLifecycle(t *testing.T) {
|
||||
"action":"create",
|
||||
"provider":"openai",
|
||||
"enabled":true,
|
||||
"api_key":"sk-test",
|
||||
"custom_api_url":"https://api.openai.com/v1",
|
||||
"custom_model_name":"gpt-5-mini"
|
||||
}`)
|
||||
@@ -136,6 +137,71 @@ func TestToolManageModelConfigLifecycle(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolManageModelConfigRejectsEnableWithoutAPIKey(t *testing.T) {
|
||||
a := newTestAgentWithStore(t)
|
||||
|
||||
createResp := a.toolManageModelConfig("user-1", `{
|
||||
"action":"create",
|
||||
"provider":"openai",
|
||||
"enabled":false,
|
||||
"custom_model_name":"gpt-4o"
|
||||
}`)
|
||||
var created struct {
|
||||
Model safeModelToolConfig `json:"model"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(createResp), &created); err != nil {
|
||||
t.Fatalf("unmarshal create response: %v\nraw=%s", err, createResp)
|
||||
}
|
||||
|
||||
updateResp := a.toolManageModelConfig("user-1", `{
|
||||
"action":"update",
|
||||
"model_id":"`+created.Model.ID+`",
|
||||
"enabled":true
|
||||
}`)
|
||||
if !strings.Contains(updateResp, "cannot enable model config before API key is configured") {
|
||||
t.Fatalf("expected enabling incomplete model to fail, got %s", updateResp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDefaultSkipsEnabledModelWithoutAPIKey(t *testing.T) {
|
||||
a := newTestAgentWithStore(t)
|
||||
|
||||
incompleteCreate := a.toolManageModelConfig("user-1", `{
|
||||
"action":"create",
|
||||
"provider":"openai",
|
||||
"enabled":true,
|
||||
"custom_model_name":"gpt-4o"
|
||||
}`)
|
||||
var incomplete struct {
|
||||
Model safeModelToolConfig `json:"model"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(incompleteCreate), &incomplete); err != nil {
|
||||
t.Fatalf("unmarshal incomplete create response: %v\nraw=%s", err, incompleteCreate)
|
||||
}
|
||||
|
||||
completeCreate := a.toolManageModelConfig("user-1", `{
|
||||
"action":"create",
|
||||
"provider":"deepseek",
|
||||
"enabled":true,
|
||||
"api_key":"sk-test",
|
||||
"custom_model_name":"deepseek-chat"
|
||||
}`)
|
||||
var complete struct {
|
||||
Model safeModelToolConfig `json:"model"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(completeCreate), &complete); err != nil {
|
||||
t.Fatalf("unmarshal complete create response: %v\nraw=%s", err, completeCreate)
|
||||
}
|
||||
|
||||
model, err := a.store.AIModel().GetDefault("user-1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetDefault() error = %v", err)
|
||||
}
|
||||
if model.ID != complete.Model.ID {
|
||||
t.Fatalf("expected GetDefault to skip incomplete enabled model and return %s, got %s", complete.Model.ID, model.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolManageTraderLifecycle(t *testing.T) {
|
||||
a := newTestAgentWithStore(t)
|
||||
|
||||
@@ -143,6 +209,7 @@ func TestToolManageTraderLifecycle(t *testing.T) {
|
||||
"action":"create",
|
||||
"provider":"openai",
|
||||
"enabled":true,
|
||||
"api_key":"sk-test",
|
||||
"custom_api_url":"https://api.openai.com/v1",
|
||||
"custom_model_name":"gpt-5-mini"
|
||||
}`)
|
||||
|
||||
@@ -16,14 +16,14 @@ type llmSkillRouteDecision struct {
|
||||
Filter string `json:"filter,omitempty"`
|
||||
}
|
||||
|
||||
func (a *Agent) tryLLMSkillRoute(ctx context.Context, storeUserID string, userID int64, lang, text string, onEvent func(event, data string)) (string, bool) {
|
||||
func (a *Agent) tryLLMSkillRoute(ctx context.Context, storeUserID string, userID int64, lang, text string, onEvent func(event, data string)) (string, bool, error) {
|
||||
if a.aiClient == nil {
|
||||
return "", false
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return "", false
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
recentConversationCtx := a.buildRecentConversationContext(userID, text)
|
||||
@@ -47,10 +47,17 @@ Available skills:
|
||||
- model_diagnosis
|
||||
- strategy_diagnosis
|
||||
|
||||
For management skills, choose one action from:
|
||||
- query
|
||||
For management skills, choose one atomic action from:
|
||||
- query_list
|
||||
- query_detail
|
||||
- query_running
|
||||
- create
|
||||
- update
|
||||
- update_name
|
||||
- update_bindings
|
||||
- update_status
|
||||
- update_endpoint
|
||||
- update_config
|
||||
- update_prompt
|
||||
- delete
|
||||
- start
|
||||
- stop
|
||||
@@ -69,7 +76,8 @@ Rules:
|
||||
- Prefer route "planner" when uncertain.
|
||||
- Prefer route "planner" for market analysis, broad advice, multi-step troubleshooting, or requests that need synthesis.
|
||||
- Prefer route "skill" for straightforward management requests like listing, creating, starting, stopping, enabling, disabling, renaming, or deleting known entities.
|
||||
- Questions like "当前有运行中的trader吗" and "有没有 trader 在跑" are trader_management with action "query" and filter "running_only".
|
||||
- Questions like "当前有运行中的trader吗" and "有没有 trader 在跑" are trader_management with action "query_running".
|
||||
- Questions about one entity's details, config, parameters, or prompt should prefer action "query_detail".
|
||||
- Do not use route "skill" for casual chat.
|
||||
- Consider Recent conversation, Task state, and Execution state JSON before deciding.
|
||||
|
||||
@@ -88,17 +96,37 @@ Return JSON with this exact shape:
|
||||
Ctx: stageCtx,
|
||||
})
|
||||
if err != nil {
|
||||
return "", false
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
decision, err := parseLLMSkillRouteDecision(raw)
|
||||
if err != nil || decision.Route != "skill" {
|
||||
return "", false
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
answer, ok := a.executeLLMSkillRoute(storeUserID, userID, lang, text, decision)
|
||||
outcome, ok := a.executeLLMSkillRoute(storeUserID, userID, lang, text, decision)
|
||||
if !ok {
|
||||
return "", false
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
review, err := a.reviewTaskCompletion(ctx, userID, lang, text, outcome)
|
||||
if err != nil {
|
||||
if outcome.Status == skillOutcomeRecoverableError || outcome.Status == skillOutcomeFatalError || outcome.Status == skillOutcomeNotHandled {
|
||||
return "", false, nil
|
||||
}
|
||||
review = taskReviewDecision{Route: "complete", Answer: outcome.UserMessage}
|
||||
}
|
||||
if review.Route == "replan" {
|
||||
answer, planErr := a.runPlannedAgent(ctx, storeUserID, userID, lang, fmt.Sprintf("Original user request:\n%s\n\nPrevious skill outcome JSON:\n%s", text, mustMarshalJSON(outcome)), onEvent)
|
||||
return answer, true, planErr
|
||||
}
|
||||
|
||||
answer := strings.TrimSpace(review.Answer)
|
||||
if answer == "" {
|
||||
answer = strings.TrimSpace(outcome.UserMessage)
|
||||
}
|
||||
if answer == "" {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
a.recordSkillInteraction(userID, text, answer)
|
||||
@@ -113,7 +141,7 @@ Return JSON with this exact shape:
|
||||
onEvent(StreamEventTool, label)
|
||||
onEvent(StreamEventDelta, answer)
|
||||
}
|
||||
return answer, true
|
||||
return answer, true, nil
|
||||
}
|
||||
|
||||
func parseLLMSkillRouteDecision(raw string) (llmSkillRouteDecision, error) {
|
||||
@@ -140,43 +168,125 @@ func parseLLMSkillRouteDecision(raw string) (llmSkillRouteDecision, error) {
|
||||
func normalizeLLMSkillRouteDecision(decision llmSkillRouteDecision) llmSkillRouteDecision {
|
||||
decision.Route = strings.TrimSpace(strings.ToLower(decision.Route))
|
||||
decision.Skill = strings.TrimSpace(strings.ToLower(decision.Skill))
|
||||
decision.Action = strings.TrimSpace(strings.ToLower(decision.Action))
|
||||
decision.Filter = strings.TrimSpace(strings.ToLower(decision.Filter))
|
||||
if decision.Action == "query" && decision.Filter == "running_only" && decision.Skill == "trader_management" {
|
||||
decision.Action = "query_running"
|
||||
} else {
|
||||
decision.Action = normalizeAtomicSkillAction(decision.Skill, decision.Action)
|
||||
}
|
||||
return decision
|
||||
}
|
||||
|
||||
func (a *Agent) executeLLMSkillRoute(storeUserID string, userID int64, lang, text string, decision llmSkillRouteDecision) (string, bool) {
|
||||
func (a *Agent) executeLLMSkillRoute(storeUserID string, userID int64, lang, text string, decision llmSkillRouteDecision) (skillOutcome, bool) {
|
||||
session := skillSession{Name: decision.Skill, Action: decision.Action}
|
||||
|
||||
switch decision.Skill {
|
||||
case "trader_management":
|
||||
if decision.Action == "create" {
|
||||
return a.handleCreateTraderSkill(storeUserID, userID, lang, text, session)
|
||||
answer, handled := a.handleCreateTraderSkill(storeUserID, userID, lang, text, session)
|
||||
if !handled {
|
||||
return skillOutcome{}, false
|
||||
}
|
||||
return inferSkillOutcome(decision.Skill, decision.Action, answer, a.getSkillSession(userID), skillDataForAction(storeUserID, decision.Skill, decision.Action, a)), true
|
||||
}
|
||||
answer, handled := a.handleTraderManagementSkill(storeUserID, userID, lang, text, session)
|
||||
if handled && decision.Action == "query" {
|
||||
return applyTraderQueryFilter(lang, answer, a.toolListTraders(storeUserID), decision.Filter), true
|
||||
if handled && decision.Action == "query_running" {
|
||||
answer = applyTraderQueryFilter(lang, answer, a.toolListTraders(storeUserID), "running_only")
|
||||
}
|
||||
return answer, handled
|
||||
if !handled {
|
||||
return skillOutcome{}, false
|
||||
}
|
||||
return inferSkillOutcome(decision.Skill, decision.Action, answer, a.getSkillSession(userID), skillDataForAction(storeUserID, decision.Skill, decision.Action, a)), true
|
||||
case "exchange_management":
|
||||
return a.handleExchangeManagementSkill(storeUserID, userID, lang, text, session)
|
||||
answer, handled := a.handleExchangeManagementSkill(storeUserID, userID, lang, text, session)
|
||||
if !handled {
|
||||
return skillOutcome{}, false
|
||||
}
|
||||
return inferSkillOutcome(decision.Skill, decision.Action, answer, a.getSkillSession(userID), skillDataForAction(storeUserID, decision.Skill, decision.Action, a)), true
|
||||
case "model_management":
|
||||
return a.handleModelManagementSkill(storeUserID, userID, lang, text, session)
|
||||
answer, handled := a.handleModelManagementSkill(storeUserID, userID, lang, text, session)
|
||||
if !handled {
|
||||
return skillOutcome{}, false
|
||||
}
|
||||
return inferSkillOutcome(decision.Skill, decision.Action, answer, a.getSkillSession(userID), skillDataForAction(storeUserID, decision.Skill, decision.Action, a)), true
|
||||
case "strategy_management":
|
||||
return a.handleStrategyManagementSkill(storeUserID, userID, lang, text, session)
|
||||
answer, handled := a.handleStrategyManagementSkill(storeUserID, userID, lang, text, session)
|
||||
if !handled {
|
||||
return skillOutcome{}, false
|
||||
}
|
||||
return inferSkillOutcome(decision.Skill, decision.Action, answer, a.getSkillSession(userID), skillDataForAction(storeUserID, decision.Skill, decision.Action, a)), true
|
||||
case "model_diagnosis":
|
||||
return a.handleModelDiagnosisSkill(storeUserID, lang, text), true
|
||||
return skillOutcome{
|
||||
Skill: decision.Skill,
|
||||
Action: defaultIfEmpty(decision.Action, "diagnose"),
|
||||
Status: skillOutcomeSuccess,
|
||||
GoalAchieved: true,
|
||||
UserMessage: a.handleModelDiagnosisSkill(storeUserID, lang, text),
|
||||
}, true
|
||||
case "exchange_diagnosis":
|
||||
return a.handleExchangeDiagnosisSkill(storeUserID, lang, text), true
|
||||
return skillOutcome{
|
||||
Skill: decision.Skill,
|
||||
Action: defaultIfEmpty(decision.Action, "diagnose"),
|
||||
Status: skillOutcomeSuccess,
|
||||
GoalAchieved: true,
|
||||
UserMessage: a.handleExchangeDiagnosisSkill(storeUserID, lang, text),
|
||||
}, true
|
||||
case "trader_diagnosis":
|
||||
return a.handleTraderDiagnosisSkill(storeUserID, lang, text), true
|
||||
return skillOutcome{
|
||||
Skill: decision.Skill,
|
||||
Action: defaultIfEmpty(decision.Action, "diagnose"),
|
||||
Status: skillOutcomeSuccess,
|
||||
GoalAchieved: true,
|
||||
UserMessage: a.handleTraderDiagnosisSkill(storeUserID, lang, text),
|
||||
}, true
|
||||
case "strategy_diagnosis":
|
||||
return a.handleStrategyDiagnosisSkill(storeUserID, lang, text), true
|
||||
return skillOutcome{
|
||||
Skill: decision.Skill,
|
||||
Action: defaultIfEmpty(decision.Action, "diagnose"),
|
||||
Status: skillOutcomeSuccess,
|
||||
GoalAchieved: true,
|
||||
UserMessage: a.handleStrategyDiagnosisSkill(storeUserID, lang, text),
|
||||
}, true
|
||||
default:
|
||||
return "", false
|
||||
return skillOutcome{}, false
|
||||
}
|
||||
}
|
||||
|
||||
func skillDataForAction(storeUserID, skill, action string, a *Agent) map[string]any {
|
||||
var raw string
|
||||
switch skill {
|
||||
case "trader_management":
|
||||
if strings.HasPrefix(action, "query") {
|
||||
raw = a.toolListTraders(storeUserID)
|
||||
}
|
||||
case "exchange_management":
|
||||
if strings.HasPrefix(action, "query") {
|
||||
raw = a.toolGetExchangeConfigs(storeUserID)
|
||||
}
|
||||
case "model_management":
|
||||
if strings.HasPrefix(action, "query") {
|
||||
raw = a.toolGetModelConfigs(storeUserID)
|
||||
}
|
||||
case "strategy_management":
|
||||
if strings.HasPrefix(action, "query") {
|
||||
raw = a.toolGetStrategies(storeUserID)
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(raw) == "" {
|
||||
return nil
|
||||
}
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal([]byte(raw), &data); err != nil {
|
||||
return nil
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func mustMarshalJSON(v any) string {
|
||||
data, _ := json.Marshal(v)
|
||||
return string(data)
|
||||
}
|
||||
|
||||
func applyTraderQueryFilter(lang, fallback, raw, filter string) string {
|
||||
filter = strings.TrimSpace(strings.ToLower(filter))
|
||||
if filter == "" {
|
||||
|
||||
@@ -299,7 +299,7 @@ func (a *Agent) maybeCompressHistory(ctx context.Context, userID int64) {
|
||||
return
|
||||
}
|
||||
if err := a.saveTaskState(userID, updatedState); err != nil {
|
||||
a.logger.Warn("failed to persist task state", "error", err, "user_id", userID)
|
||||
a.log().Warn("failed to persist task state", "error", err, "user_id", userID)
|
||||
return
|
||||
}
|
||||
a.history.Replace(userID, recentPart)
|
||||
@@ -323,11 +323,11 @@ func (a *Agent) maybeUpdateTaskStateIncrementally(ctx context.Context, userID in
|
||||
existingState := a.getTaskState(userID)
|
||||
updatedState, err := a.summarizeRecentConversationToTaskState(ctx, userID, existingState, window)
|
||||
if err != nil {
|
||||
a.logger.Warn("failed to incrementally update task state", "error", err, "user_id", userID)
|
||||
a.log().Warn("failed to incrementally update task state", "error", err, "user_id", userID)
|
||||
return
|
||||
}
|
||||
if err := a.saveTaskState(userID, updatedState); err != nil {
|
||||
a.logger.Warn("failed to persist incremental task state", "error", err, "user_id", userID)
|
||||
a.log().Warn("failed to persist incremental task state", "error", err, "user_id", userID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -383,7 +383,7 @@ Rules:
|
||||
return TaskState{}, err
|
||||
}
|
||||
state = normalizeTaskState(state)
|
||||
a.logger.Info("compressed chat history into task state", "user_id", userID, "archived_messages", len(oldPart))
|
||||
a.log().Info("compressed chat history into task state", "user_id", userID, "archived_messages", len(oldPart))
|
||||
return state, nil
|
||||
}
|
||||
|
||||
@@ -436,7 +436,7 @@ Rules:
|
||||
return TaskState{}, err
|
||||
}
|
||||
state = normalizeTaskState(state)
|
||||
a.logger.Info("incrementally refreshed task state", "user_id", userID, "window_messages", len(recentPart))
|
||||
a.log().Info("incrementally refreshed task state", "user_id", userID, "window_messages", len(recentPart))
|
||||
return state, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -198,16 +198,6 @@ func isRealtimeAccountIntent(text string) bool {
|
||||
|
||||
func snapshotKindsForIntent(userText string) []string {
|
||||
kinds := make([]string, 0, 6)
|
||||
if isConfigOrTraderIntent(userText) {
|
||||
kinds = append(kinds,
|
||||
"current_model_configs",
|
||||
"current_exchange_configs",
|
||||
"current_traders",
|
||||
)
|
||||
}
|
||||
if isStrategyIntent(userText) {
|
||||
kinds = append(kinds, "current_strategies")
|
||||
}
|
||||
return uniqueStrings(kinds)
|
||||
}
|
||||
|
||||
@@ -756,18 +746,18 @@ func (a *Agent) thinkAndAct(ctx context.Context, storeUserID string, userID int6
|
||||
if answer, ok, err := a.tryStatePriorityPath(ctx, storeUserID, userID, lang, text, nil); ok || err != nil {
|
||||
return answer, err
|
||||
}
|
||||
if answer, ok := a.tryDirectAnswer(ctx, userID, lang, text, nil); ok {
|
||||
return answer, nil
|
||||
}
|
||||
if answer, ok := a.tryLLMSkillRoute(ctx, storeUserID, userID, lang, text, nil); ok {
|
||||
return answer, nil
|
||||
}
|
||||
if answer, ok := a.tryHardSkill(ctx, storeUserID, userID, lang, text, nil); ok {
|
||||
if answer, ok := tryInstantDirectReply(lang, text); ok {
|
||||
return answer, nil
|
||||
}
|
||||
if answer, ok := a.tryReadFastPath(storeUserID, userID, lang, text); ok {
|
||||
return answer, nil
|
||||
}
|
||||
if answer, ok, err := a.tryWorkflowIntent(ctx, storeUserID, userID, lang, text, nil); ok || err != nil {
|
||||
return answer, err
|
||||
}
|
||||
if answer, ok := a.tryHardSkill(ctx, storeUserID, userID, lang, text, nil); ok {
|
||||
return answer, nil
|
||||
}
|
||||
if a.aiClient == nil {
|
||||
return a.noAIFallback(lang, text)
|
||||
}
|
||||
@@ -778,13 +768,10 @@ func (a *Agent) thinkAndActStream(ctx context.Context, storeUserID string, userI
|
||||
if answer, ok, err := a.tryStatePriorityPath(ctx, storeUserID, userID, lang, text, onEvent); ok || err != nil {
|
||||
return answer, err
|
||||
}
|
||||
if answer, ok := a.tryDirectAnswer(ctx, userID, lang, text, onEvent); ok {
|
||||
return answer, nil
|
||||
}
|
||||
if answer, ok := a.tryLLMSkillRoute(ctx, storeUserID, userID, lang, text, onEvent); ok {
|
||||
return answer, nil
|
||||
}
|
||||
if answer, ok := a.tryHardSkill(ctx, storeUserID, userID, lang, text, onEvent); ok {
|
||||
if answer, ok := tryInstantDirectReply(lang, text); ok {
|
||||
if onEvent != nil {
|
||||
onEvent(StreamEventDelta, answer)
|
||||
}
|
||||
return answer, nil
|
||||
}
|
||||
if answer, ok := a.tryReadFastPath(storeUserID, userID, lang, text); ok {
|
||||
@@ -794,12 +781,65 @@ func (a *Agent) thinkAndActStream(ctx context.Context, storeUserID string, userI
|
||||
}
|
||||
return answer, nil
|
||||
}
|
||||
if answer, ok, err := a.tryWorkflowIntent(ctx, storeUserID, userID, lang, text, onEvent); ok || err != nil {
|
||||
return answer, err
|
||||
}
|
||||
if answer, ok := a.tryHardSkill(ctx, storeUserID, userID, lang, text, onEvent); ok {
|
||||
return answer, nil
|
||||
}
|
||||
if a.aiClient == nil {
|
||||
return a.noAIFallback(lang, text)
|
||||
}
|
||||
return a.runPlannedAgent(ctx, storeUserID, userID, lang, text, onEvent)
|
||||
}
|
||||
|
||||
func tryInstantDirectReply(lang, text string) (string, bool) {
|
||||
lower := strings.ToLower(strings.TrimSpace(text))
|
||||
if lower == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
zhReplies := map[string]string{
|
||||
"hi": "在,有什么我帮你看的?",
|
||||
"hello": "在,有什么我帮你看的?",
|
||||
"hey": "在,有什么我帮你看的?",
|
||||
"你好": "在,有什么我帮你看的?",
|
||||
"嗨": "在,有什么我帮你看的?",
|
||||
"在吗": "在,有什么我帮你看的?",
|
||||
"谢谢": "不客气。",
|
||||
"多谢": "不客气。",
|
||||
"谢了": "不客气。",
|
||||
"ok": "好。",
|
||||
"好的": "好。",
|
||||
"收到": "好。",
|
||||
}
|
||||
enReplies := map[string]string{
|
||||
"hi": "I'm here. What should we look at?",
|
||||
"hello": "I'm here. What should we look at?",
|
||||
"hey": "I'm here. What should we look at?",
|
||||
"thanks": "You're welcome.",
|
||||
"thank you": "You're welcome.",
|
||||
"ok": "Okay.",
|
||||
"okay": "Okay.",
|
||||
"got it": "Got it.",
|
||||
}
|
||||
|
||||
if lang == "zh" {
|
||||
if reply, ok := zhReplies[lower]; ok {
|
||||
return reply, true
|
||||
}
|
||||
if reply, ok := enReplies[lower]; ok {
|
||||
return reply, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
if reply, ok := enReplies[lower]; ok {
|
||||
return reply, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (a *Agent) hasActiveSkillSession(userID int64) bool {
|
||||
session := a.getSkillSession(userID)
|
||||
return strings.TrimSpace(session.Name) != ""
|
||||
@@ -818,21 +858,335 @@ func hasActiveExecutionState(state ExecutionState) bool {
|
||||
}
|
||||
|
||||
func (a *Agent) tryStatePriorityPath(ctx context.Context, storeUserID string, userID int64, lang, text string, onEvent func(event, data string)) (string, bool, error) {
|
||||
if a.hasActiveSkillSession(userID) {
|
||||
if answer, ok := a.tryHardSkill(ctx, storeUserID, userID, lang, text, onEvent); ok {
|
||||
return answer, true, nil
|
||||
if workflow := a.getWorkflowSession(userID); hasActiveWorkflowSession(workflow) {
|
||||
answer, handled, err := a.handleWorkflowSession(ctx, storeUserID, userID, lang, text, workflow, onEvent)
|
||||
if handled || err != nil {
|
||||
return answer, true, err
|
||||
}
|
||||
}
|
||||
if session := a.getSkillSession(userID); strings.TrimSpace(session.Name) != "" {
|
||||
switch a.classifySkillSessionInput(ctx, userID, lang, session, text) {
|
||||
case "cancel":
|
||||
a.clearSkillSession(userID)
|
||||
a.clearWorkflowSession(userID)
|
||||
if lang == "zh" {
|
||||
return "已取消当前流程。", true, nil
|
||||
}
|
||||
return "Cancelled the current flow.", true, nil
|
||||
case "interrupt":
|
||||
a.clearSkillSession(userID)
|
||||
default:
|
||||
if answer, ok := a.tryHardSkill(ctx, storeUserID, userID, lang, text, onEvent); ok {
|
||||
return answer, true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
state := a.getExecutionState(userID)
|
||||
if hasActiveExecutionState(state) {
|
||||
answer, err := a.runPlannedAgent(ctx, storeUserID, userID, lang, text, onEvent)
|
||||
return answer, true, err
|
||||
switch classifyExecutionStateInput(state, text) {
|
||||
case "cancel":
|
||||
a.clearExecutionState(userID)
|
||||
if lang == "zh" {
|
||||
return "已取消当前流程。", true, nil
|
||||
}
|
||||
return "Cancelled the current flow.", true, nil
|
||||
case "interrupt":
|
||||
a.clearExecutionState(userID)
|
||||
default:
|
||||
answer, err := a.runPlannedAgent(ctx, storeUserID, userID, lang, text, onEvent)
|
||||
return answer, true, err
|
||||
}
|
||||
}
|
||||
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
func (a *Agent) classifySkillSessionInput(ctx context.Context, userID int64, lang string, session skillSession, text string) string {
|
||||
lower := strings.ToLower(strings.TrimSpace(text))
|
||||
if lower == "" {
|
||||
return "continue"
|
||||
}
|
||||
if isYesReply(text) || isNoReply(text) {
|
||||
return "continue"
|
||||
}
|
||||
if isExplicitFlowAbort(text) {
|
||||
return "cancel"
|
||||
}
|
||||
if shouldContinueSkillSessionByExpectedSlot(session, text) {
|
||||
return "continue"
|
||||
}
|
||||
if decision := a.classifySkillSessionIntentWithLLM(ctx, userID, lang, session, text); decision != "" {
|
||||
return decision
|
||||
}
|
||||
if isNewSkillRootIntent(session, text) {
|
||||
return "interrupt"
|
||||
}
|
||||
if isSkillFlowDeflection(session, text) {
|
||||
return "interrupt"
|
||||
}
|
||||
if belongsToSkillDomain(session.Name, text) || !looksLikeNewTopLevelIntent(text) {
|
||||
return "continue"
|
||||
}
|
||||
return "interrupt"
|
||||
}
|
||||
|
||||
type skillSessionIntentDecision struct {
|
||||
Decision string `json:"decision"`
|
||||
}
|
||||
|
||||
func shouldUseLLMSkillSessionClassifier(session skillSession, text string) bool {
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return false
|
||||
}
|
||||
if isExplicitFlowAbort(text) || isYesReply(text) || isNoReply(text) {
|
||||
return false
|
||||
}
|
||||
if shouldContinueSkillSessionByExpectedSlot(session, text) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func shouldContinueSkillSessionByExpectedSlot(session skillSession, text string) bool {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return false
|
||||
}
|
||||
currentStep, ok := currentSkillDAGStep(session)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
switch currentStep.ID {
|
||||
case "await_start_confirmation", "await_confirmation":
|
||||
return isYesReply(text) || isNoReply(text)
|
||||
case "resolve_config_value":
|
||||
if fieldValue(session, "config_field") == "selected_timeframes" {
|
||||
return timeframeTokenRE.MatchString(strings.ToLower(text))
|
||||
}
|
||||
return firstIntegerPattern.MatchString(text)
|
||||
case "collect_enabled":
|
||||
_, ok := parseEnabledValue(text)
|
||||
return ok
|
||||
case "collect_custom_api_url":
|
||||
return extractURL(text) != ""
|
||||
case "resolve_exchange_type":
|
||||
return exchangeTypeFromText(text) != ""
|
||||
case "resolve_provider":
|
||||
return providerFromText(text) != ""
|
||||
case "resolve_name", "collect_name", "collect_prompt", "collect_account_name", "collect_custom_model_name":
|
||||
return !looksLikeNewTopLevelIntent(text)
|
||||
}
|
||||
for _, field := range currentStep.RequiredFields {
|
||||
switch field {
|
||||
case "config_value":
|
||||
return firstIntegerPattern.MatchString(text)
|
||||
case "enabled":
|
||||
_, ok := parseEnabledValue(text)
|
||||
return ok
|
||||
case "custom_api_url":
|
||||
return extractURL(text) != ""
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *Agent) classifySkillSessionIntentWithLLM(ctx context.Context, userID int64, lang string, session skillSession, text string) string {
|
||||
if a == nil || a.aiClient == nil {
|
||||
return ""
|
||||
}
|
||||
if !shouldUseLLMSkillSessionClassifier(session, text) {
|
||||
return ""
|
||||
}
|
||||
currentStep, _ := currentSkillDAGStep(session)
|
||||
recentConversationCtx := a.buildRecentConversationContext(userID, text)
|
||||
systemPrompt := `You classify one user message while a NOFXi structured management flow is active.
|
||||
Return JSON only. No markdown.
|
||||
|
||||
Possible decisions:
|
||||
- "continue": the user is still answering the current flow
|
||||
- "cancel": the user wants to stop the current flow
|
||||
- "interrupt": the user changed topic, wants diagnosis/query/new task, or should leave the current flow
|
||||
|
||||
Be conservative:
|
||||
- Prefer "continue" only when the message clearly answers the current slot/question.
|
||||
- Use "cancel" for explicit abandonment like "算了", "不改了", "换话题", "别弄了".
|
||||
- Use "interrupt" for diagnosis, query, new requests, or topic shifts.`
|
||||
userPrompt := fmt.Sprintf(
|
||||
"Language: %s\nActive skill: %s\nAction: %s\nCurrent DAG step: %s\nExpected required fields: %s\nUser message: %s\n\nRecent conversation:\n%s",
|
||||
lang,
|
||||
session.Name,
|
||||
session.Action,
|
||||
currentStep.ID,
|
||||
strings.Join(currentStep.RequiredFields, ", "),
|
||||
text,
|
||||
recentConversationCtx,
|
||||
)
|
||||
stageCtx, cancel := withPlannerStageTimeout(ctx, directReplyTimeout)
|
||||
defer cancel()
|
||||
raw, err := a.aiClient.CallWithRequest(&mcp.Request{
|
||||
Messages: []mcp.Message{
|
||||
mcp.NewSystemMessage(systemPrompt),
|
||||
mcp.NewUserMessage(userPrompt),
|
||||
},
|
||||
Ctx: stageCtx,
|
||||
})
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
raw = strings.TrimSpace(raw)
|
||||
raw = strings.TrimPrefix(raw, "```json")
|
||||
raw = strings.TrimPrefix(raw, "```")
|
||||
raw = strings.TrimSuffix(raw, "```")
|
||||
raw = strings.TrimSpace(raw)
|
||||
var decision skillSessionIntentDecision
|
||||
if err := json.Unmarshal([]byte(raw), &decision); err != nil {
|
||||
start := strings.Index(raw, "{")
|
||||
end := strings.LastIndex(raw, "}")
|
||||
if start < 0 || end <= start || json.Unmarshal([]byte(raw[start:end+1]), &decision) != nil {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
switch strings.TrimSpace(decision.Decision) {
|
||||
case "continue", "cancel", "interrupt":
|
||||
return decision.Decision
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func isSkillFlowDeflection(session skillSession, text string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(text))
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
if containsAny(lower, []string{
|
||||
"看下报错", "看看报错", "帮我看下报错", "帮我看看报错", "报错怎么回事", "错误怎么回事",
|
||||
"换话题", "聊别的", "不是这个", "先说别的", "不聊这个",
|
||||
}) {
|
||||
return true
|
||||
}
|
||||
switch strings.TrimSpace(session.Name) {
|
||||
case "exchange_management":
|
||||
return detectModelDiagnosisSkill(text) || detectTraderDiagnosisSkill(text) || detectStrategyDiagnosisSkill(text)
|
||||
case "model_management":
|
||||
return detectExchangeDiagnosisSkill(text) || detectTraderDiagnosisSkill(text) || detectStrategyDiagnosisSkill(text)
|
||||
case "strategy_management":
|
||||
return detectExchangeDiagnosisSkill(text) || detectTraderDiagnosisSkill(text) || detectModelDiagnosisSkill(text)
|
||||
case "trader_management":
|
||||
return detectExchangeDiagnosisSkill(text) || detectModelDiagnosisSkill(text) || detectStrategyDiagnosisSkill(text)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isNewSkillRootIntent(session skillSession, text string) bool {
|
||||
currentSkill := strings.TrimSpace(session.Name)
|
||||
currentAction := strings.TrimSpace(session.Action)
|
||||
if currentSkill == "" {
|
||||
return false
|
||||
}
|
||||
switch currentSkill {
|
||||
case "trader_management":
|
||||
if detectCreateTraderSkill(text) && currentAction != "create" {
|
||||
return true
|
||||
}
|
||||
if action := normalizeAtomicSkillAction("trader_management", detectManagementAction(text, "trader")); action == "create" && currentAction != "create" {
|
||||
return true
|
||||
}
|
||||
case "strategy_management":
|
||||
if action := normalizeAtomicSkillAction("strategy_management", detectManagementAction(text, "strategy")); action == "create" && currentAction != "create" {
|
||||
return true
|
||||
}
|
||||
case "model_management":
|
||||
if action := normalizeAtomicSkillAction("model_management", detectManagementAction(text, "model")); action == "create" && currentAction != "create" {
|
||||
return true
|
||||
}
|
||||
case "exchange_management":
|
||||
if action := normalizeAtomicSkillAction("exchange_management", detectManagementAction(text, "exchange")); action == "create" && currentAction != "create" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func classifyExecutionStateInput(state ExecutionState, text string) string {
|
||||
lower := strings.ToLower(strings.TrimSpace(text))
|
||||
if lower == "" {
|
||||
return "continue"
|
||||
}
|
||||
if isExplicitFlowAbort(text) {
|
||||
return "cancel"
|
||||
}
|
||||
if isYesReply(text) || isNoReply(text) || shouldResetExecutionStateForNewAttempt(text, state) {
|
||||
return "continue"
|
||||
}
|
||||
if state.Waiting != nil && !looksLikeNewTopLevelIntent(text) {
|
||||
return "continue"
|
||||
}
|
||||
if looksLikeNewTopLevelIntent(text) {
|
||||
return "interrupt"
|
||||
}
|
||||
return "continue"
|
||||
}
|
||||
|
||||
func isExplicitFlowAbort(text string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(text))
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
if isCancelSkillReply(text) {
|
||||
return true
|
||||
}
|
||||
return containsAny(lower, []string{
|
||||
"算了", "先不", "不配了", "别弄了", "不搞了", "先停", "换个话题", "换话题", "聊点别的", "聊别的",
|
||||
"stop this", "drop it", "never mind", "forget it", "skip this",
|
||||
})
|
||||
}
|
||||
|
||||
func belongsToSkillDomain(skillName, text string) bool {
|
||||
switch strings.TrimSpace(skillName) {
|
||||
case "trader_management":
|
||||
return detectCreateTraderSkill(text) || detectTraderManagementIntent(text) || detectTraderDiagnosisSkill(text)
|
||||
case "strategy_management":
|
||||
return detectStrategyManagementIntent(text) || detectStrategyDiagnosisSkill(text)
|
||||
case "model_management":
|
||||
return detectModelManagementIntent(text) || detectModelDiagnosisSkill(text)
|
||||
case "exchange_management":
|
||||
return detectExchangeManagementIntent(text) || detectExchangeDiagnosisSkill(text)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func looksLikeNewTopLevelIntent(text string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(text))
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
if strings.HasPrefix(lower, "/") {
|
||||
return true
|
||||
}
|
||||
if detectCreateTraderSkill(text) ||
|
||||
detectTraderManagementIntent(text) ||
|
||||
detectExchangeManagementIntent(text) ||
|
||||
detectModelManagementIntent(text) ||
|
||||
detectStrategyManagementIntent(text) ||
|
||||
detectTraderDiagnosisSkill(text) ||
|
||||
detectExchangeDiagnosisSkill(text) ||
|
||||
detectModelDiagnosisSkill(text) ||
|
||||
detectStrategyDiagnosisSkill(text) {
|
||||
return true
|
||||
}
|
||||
if detectReadFastPath(text) != nil {
|
||||
return true
|
||||
}
|
||||
return containsAny(lower, []string{
|
||||
"btc", "eth", "sol", "市场", "行情", "余额", "仓位", "持仓", "订单", "账户",
|
||||
"price", "market", "balance", "position", "portfolio", "account",
|
||||
})
|
||||
}
|
||||
|
||||
func (a *Agent) tryDirectAnswer(ctx context.Context, userID int64, lang, text string, onEvent func(event, data string)) (string, bool) {
|
||||
if a.aiClient == nil {
|
||||
return "", false
|
||||
@@ -948,8 +1302,10 @@ func (a *Agent) runPlannedAgent(ctx context.Context, storeUserID string, userID
|
||||
onEvent(StreamEventPlanning, a.planningStatusText(lang))
|
||||
}
|
||||
|
||||
requestStartedAt := time.Now()
|
||||
state, err := a.prepareExecutionState(ctx, storeUserID, userID, lang, text)
|
||||
if err != nil {
|
||||
a.logPlannerTiming("", userID, "prepare_execution_state", requestStartedAt, err)
|
||||
if isPlannerTimeoutError(err) {
|
||||
msg := plannerTimeoutMessage(lang)
|
||||
if onEvent != nil {
|
||||
@@ -961,8 +1317,11 @@ func (a *Agent) runPlannedAgent(ctx context.Context, storeUserID string, userID
|
||||
a.logger.Warn("planner failed, falling back to legacy loop", "error", err, "user_id", userID)
|
||||
return a.thinkAndActLegacy(ctx, userID, lang, text, onEvent)
|
||||
}
|
||||
a.logPlannerTiming(state.SessionID, userID, "prepare_execution_state", requestStartedAt, nil)
|
||||
|
||||
executionStartedAt := time.Now()
|
||||
answer, err := a.executePlan(ctx, storeUserID, userID, lang, &state, onEvent)
|
||||
a.logPlannerTiming(state.SessionID, userID, "execute_plan", executionStartedAt, err)
|
||||
if err != nil {
|
||||
if isPlannerTimeoutError(err) {
|
||||
msg := plannerTimeoutMessage(lang)
|
||||
@@ -979,6 +1338,7 @@ func (a *Agent) runPlannedAgent(ctx context.Context, storeUserID string, userID
|
||||
a.history.Add(userID, "assistant", answer)
|
||||
a.maybeUpdateTaskStateIncrementally(ctx, userID)
|
||||
a.maybeCompressHistory(ctx, userID)
|
||||
a.logPlannerTiming(state.SessionID, userID, "run_planned_agent_total", requestStartedAt, nil)
|
||||
return answer, nil
|
||||
}
|
||||
|
||||
@@ -1005,12 +1365,7 @@ func (a *Agent) prepareExecutionState(ctx context.Context, storeUserID string, u
|
||||
existing.FinalAnswer = ""
|
||||
existing.LastError = ""
|
||||
existing = a.refreshStateForDynamicRequests(storeUserID, text, existing)
|
||||
plan, err := a.createExecutionPlan(ctx, userID, lang, text, existing)
|
||||
if err != nil {
|
||||
return ExecutionState{}, err
|
||||
}
|
||||
existing.Goal = plan.Goal
|
||||
existing.Steps = plan.Steps
|
||||
existing.Steps = completedSteps(existing.Steps)
|
||||
existing.CurrentStepID = ""
|
||||
existing.Status = executionStatusRunning
|
||||
existing.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
|
||||
@@ -1023,12 +1378,6 @@ func (a *Agent) prepareExecutionState(ctx context.Context, storeUserID string, u
|
||||
state := newExecutionState(userID, text)
|
||||
a.refreshCurrentReferencesForUserText(storeUserID, text, &state)
|
||||
state = a.refreshStateForDynamicRequests(storeUserID, text, state)
|
||||
plan, err := a.createExecutionPlan(ctx, userID, lang, text, state)
|
||||
if err != nil {
|
||||
return ExecutionState{}, err
|
||||
}
|
||||
state.Goal = plan.Goal
|
||||
state.Steps = plan.Steps
|
||||
state.Status = executionStatusRunning
|
||||
if err := a.saveExecutionState(state); err != nil {
|
||||
return ExecutionState{}, err
|
||||
@@ -1036,6 +1385,114 @@ func (a *Agent) prepareExecutionState(ctx context.Context, storeUserID string, u
|
||||
return state, nil
|
||||
}
|
||||
|
||||
type nextStepDecision struct {
|
||||
Goal string `json:"goal"`
|
||||
Steps []PlanStep `json:"steps,omitempty"`
|
||||
Step PlanStep `json:"step"`
|
||||
}
|
||||
|
||||
func (a *Agent) decideNextStep(ctx context.Context, userID int64, lang string, state ExecutionState) (nextStepDecision, error) {
|
||||
toolDefs, _ := json.Marshal(agentTools())
|
||||
stateJSON, _ := json.Marshal(normalizeExecutionState(state))
|
||||
obsJSON, _ := json.Marshal(buildObservationContext(state))
|
||||
recentlyFetchedJSON, _ := json.Marshal(buildRecentlyFetchedData(state, time.Now().UTC()))
|
||||
taskStateCtx := buildTaskStateContext(a.getTaskState(userID))
|
||||
recentConversationCtx := a.buildRecentConversationContext(userID, state.Goal)
|
||||
|
||||
systemPrompt := `You are the step selector for NOFXi.
|
||||
Return JSON only. Do not return markdown.
|
||||
|
||||
You are operating in ReAct mode: Thought -> Action -> Observation.
|
||||
Choose the immediate next action batch. Do not generate a long multi-step execution plan.
|
||||
|
||||
Allowed step types:
|
||||
- tool
|
||||
- reason
|
||||
- ask_user
|
||||
- respond
|
||||
|
||||
Rules:
|
||||
- Use all available memory layers: Execution state JSON, Observations JSON, Recent conversation, and Task state.
|
||||
- Use Recently fetched data JSON as the deduplication source of truth for fresh tool results.
|
||||
- Prefer the freshest evidence in this order: execution state, observations, recent conversation, then task state.
|
||||
- If fresh external or system data is needed, choose a tool step.
|
||||
- If the user is blocked on a missing parameter, choose ask_user.
|
||||
- If there is enough information to answer now, choose respond.
|
||||
- Use reason only when a short intermediate synthesis is necessary before the next action.
|
||||
- Prefer tool or respond over reason whenever possible.
|
||||
- Never emit the same reason step twice in a row.
|
||||
- After a reason step, the next batch should usually be tool, ask_user, or respond. Do not stay in analysis loops.
|
||||
- Never invent tools.
|
||||
- If the task needs multiple independent tool reads, emit ALL of them together in one response.
|
||||
- Parallelism rule: when multiple tool reads are mutually independent, do not split them across turns. Return them together in steps.
|
||||
- Never mix ask_user/respond with additional steps in the same batch.
|
||||
- Only emit multiple steps when every emitted step is a tool step.
|
||||
- Avoid repeated tool calls. If a matching tool call already exists in Recently fetched data and age_seconds <= 60, do not call it again unless the user explicitly asks to refresh.
|
||||
- For tool steps, set tool_name exactly to one available tool and provide tool_args as a JSON object.
|
||||
- For ask_user or respond steps, put the user-facing question/response instruction in instruction.
|
||||
- If the latest observation already answers the goal, prefer respond over another tool call.
|
||||
- Never place a trade unless the user intent is explicit.
|
||||
|
||||
Return JSON with this exact shape:
|
||||
{"goal":"","steps":[{"id":"step_1","type":"tool|reason|ask_user|respond","title":"","tool_name":"","tool_args":{},"instruction":"","requires_confirmation":false}]}`
|
||||
|
||||
userPrompt := fmt.Sprintf("Language: %s\nGoal: %s\n\nRecent conversation:\n%s\n\nAvailable tools JSON:\n%s\n\nPersistent preferences:\n%s\n\nTask state:\n%s\n\nExecution state JSON:\n%s\n\nObservations JSON:\n%s\n\nRecently fetched data JSON:\n%s", lang, state.Goal, recentConversationCtx, string(toolDefs), a.buildPersistentPreferencesContext(userID), taskStateCtx, string(stateJSON), string(obsJSON), string(recentlyFetchedJSON))
|
||||
|
||||
stageCtx, cancel := withPlannerStageTimeout(ctx, plannerCreateTimeout)
|
||||
defer cancel()
|
||||
|
||||
startedAt := time.Now()
|
||||
raw, err := a.aiClient.CallWithRequest(&mcp.Request{
|
||||
Messages: []mcp.Message{
|
||||
mcp.NewSystemMessage(systemPrompt),
|
||||
mcp.NewUserMessage(userPrompt),
|
||||
},
|
||||
Ctx: stageCtx,
|
||||
})
|
||||
a.logPlannerTiming(state.SessionID, userID, "decide_next_step_llm", startedAt, err)
|
||||
if err != nil {
|
||||
return nextStepDecision{}, err
|
||||
}
|
||||
return parseNextStepDecisionJSON(raw)
|
||||
}
|
||||
|
||||
func parseNextStepDecisionJSON(raw string) (nextStepDecision, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
raw = strings.TrimPrefix(raw, "```json")
|
||||
raw = strings.TrimPrefix(raw, "```")
|
||||
raw = strings.TrimSuffix(raw, "```")
|
||||
raw = strings.TrimSpace(raw)
|
||||
|
||||
var decision nextStepDecision
|
||||
if err := json.Unmarshal([]byte(raw), &decision); err == nil {
|
||||
return normalizeNextStepDecision(decision), nil
|
||||
}
|
||||
start := strings.Index(raw, "{")
|
||||
end := strings.LastIndex(raw, "}")
|
||||
if start >= 0 && end > start {
|
||||
if err := json.Unmarshal([]byte(raw[start:end+1]), &decision); err == nil {
|
||||
return normalizeNextStepDecision(decision), nil
|
||||
}
|
||||
}
|
||||
return nextStepDecision{}, fmt.Errorf("invalid next step decision json")
|
||||
}
|
||||
|
||||
func normalizeNextStepDecision(decision nextStepDecision) nextStepDecision {
|
||||
decision.Goal = strings.TrimSpace(decision.Goal)
|
||||
steps := decision.Steps
|
||||
if len(steps) == 0 && decision.Step.Type != "" {
|
||||
steps = []PlanStep{decision.Step}
|
||||
}
|
||||
if len(steps) > 0 {
|
||||
steps = normalizeExecutionState(ExecutionState{Steps: steps}).Steps
|
||||
}
|
||||
decision.Steps = steps
|
||||
if len(steps) > 0 {
|
||||
decision.Step = steps[0]
|
||||
}
|
||||
return decision
|
||||
}
|
||||
|
||||
func (a *Agent) refreshStateForDynamicRequests(storeUserID, userText string, state ExecutionState) ExecutionState {
|
||||
kinds := snapshotKindsForIntent(userText)
|
||||
if len(kinds) == 0 {
|
||||
@@ -1187,6 +1644,7 @@ Rules:
|
||||
stageCtx, cancel := withPlannerStageTimeout(ctx, plannerCreateTimeout)
|
||||
defer cancel()
|
||||
|
||||
startedAt := time.Now()
|
||||
resp, err := a.aiClient.CallWithRequest(&mcp.Request{
|
||||
Messages: []mcp.Message{
|
||||
mcp.NewSystemMessage(systemPrompt),
|
||||
@@ -1194,6 +1652,7 @@ Rules:
|
||||
},
|
||||
Ctx: stageCtx,
|
||||
})
|
||||
a.logPlannerTiming(state.SessionID, userID, "create_execution_plan_llm", startedAt, err)
|
||||
if err != nil {
|
||||
return executionPlan{}, err
|
||||
}
|
||||
@@ -1247,28 +1706,63 @@ func parseExecutionPlanJSON(raw string) (executionPlan, error) {
|
||||
}
|
||||
|
||||
func (a *Agent) executePlan(ctx context.Context, storeUserID string, userID int64, lang string, state *ExecutionState, onEvent func(event, data string)) (string, error) {
|
||||
if onEvent != nil {
|
||||
if onEvent != nil && len(state.Steps) > 0 {
|
||||
onEvent(StreamEventPlan, formatPlanStatus(*state, lang))
|
||||
}
|
||||
|
||||
for i := 0; i < plannerMaxIterations; i++ {
|
||||
stepIndex := nextPendingStepIndex(state.Steps)
|
||||
if stepIndex < 0 {
|
||||
finalText, err := a.generateFinalPlanResponse(ctx, userID, lang, *state, "")
|
||||
decisionStartedAt := time.Now()
|
||||
decision, err := a.decideNextStep(ctx, userID, lang, *state)
|
||||
a.logPlannerTiming(state.SessionID, userID, "decide_next_step", decisionStartedAt, err)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
state.Status = executionStatusCompleted
|
||||
state.FinalAnswer = finalText
|
||||
state.CurrentStepID = ""
|
||||
steps := filterFreshDuplicateToolSteps(decision.Steps, *state, time.Now().UTC())
|
||||
if len(steps) == 0 {
|
||||
appendExecutionLog(state, Observation{
|
||||
Kind: "decision_note",
|
||||
Summary: "Skipped duplicate fresh tool calls from next-step decision",
|
||||
CreatedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
})
|
||||
state.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
|
||||
if err := a.saveExecutionState(*state); err != nil {
|
||||
return "", err
|
||||
}
|
||||
continue
|
||||
}
|
||||
if hasRepeatedReasonLoop(*state, steps) {
|
||||
return "", fmt.Errorf("repeated reasoning loop detected")
|
||||
}
|
||||
if decision.Goal != "" {
|
||||
state.Goal = decision.Goal
|
||||
}
|
||||
base := len(completedSteps(state.Steps))
|
||||
for idx := range steps {
|
||||
if steps[idx].Type == "" {
|
||||
return "", fmt.Errorf("next step decision missing step type")
|
||||
}
|
||||
if steps[idx].ID == "" {
|
||||
steps[idx].ID = fmt.Sprintf("step_%d", base+idx+1)
|
||||
}
|
||||
if steps[idx].Title == "" {
|
||||
steps[idx].Title = strings.ReplaceAll(steps[idx].ID, "_", " ")
|
||||
}
|
||||
if steps[idx].Status == "" {
|
||||
steps[idx].Status = planStepStatusPending
|
||||
}
|
||||
}
|
||||
state.Steps = append(completedSteps(state.Steps), steps...)
|
||||
state.Status = executionStatusRunning
|
||||
state.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
|
||||
if err := a.saveExecutionState(*state); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if onEvent != nil {
|
||||
onEvent(StreamEventDelta, finalText)
|
||||
onEvent(StreamEventPlan, formatPlanStatus(*state, lang))
|
||||
}
|
||||
return finalText, nil
|
||||
continue
|
||||
}
|
||||
|
||||
step := &state.Steps[stepIndex]
|
||||
@@ -1288,7 +1782,9 @@ func (a *Agent) executePlan(ctx context.Context, storeUserID string, userID int6
|
||||
if onEvent != nil {
|
||||
onEvent(StreamEventTool, step.ToolName)
|
||||
}
|
||||
stepStartedAt := time.Now()
|
||||
result := a.executePlanTool(ctx, storeUserID, userID, lang, *step)
|
||||
a.logPlannerTiming(state.SessionID, userID, "tool:"+step.ToolName, stepStartedAt, nil)
|
||||
summary := summarizeObservation(result)
|
||||
referencesChanged := false
|
||||
step.Status = planStepStatusCompleted
|
||||
@@ -1301,29 +1797,11 @@ func (a *Agent) executePlan(ctx context.Context, storeUserID string, userID int6
|
||||
CreatedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
})
|
||||
referencesChanged = updateCurrentReferencesFromToolResult(state, step.ToolName, result)
|
||||
if shouldAttemptReplan(*state, *step, referencesChanged) {
|
||||
state.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
|
||||
if err := a.saveExecutionState(*state); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if onEvent != nil {
|
||||
onEvent(StreamEventStepComplete, formatStepCompleteStatus(*step, lang))
|
||||
}
|
||||
decision, err := a.replanAfterStep(ctx, userID, lang, *state, *step)
|
||||
if err == nil && applyReplannerDecision(state, decision) {
|
||||
state.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
|
||||
if err := a.saveExecutionState(*state); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if onEvent != nil {
|
||||
onEvent(StreamEventReplan, formatReplanStatus(decision, lang))
|
||||
onEvent(StreamEventPlan, formatPlanStatus(*state, lang))
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
_ = referencesChanged
|
||||
case planStepTypeReason:
|
||||
reasonStartedAt := time.Now()
|
||||
reasoning, err := a.executeReasonStep(ctx, userID, lang, state.Goal, *state, *step)
|
||||
a.logPlannerTiming(state.SessionID, userID, "reason_step", reasonStartedAt, err)
|
||||
if err != nil {
|
||||
step.Status = planStepStatusFailed
|
||||
step.Error = err.Error()
|
||||
@@ -1364,7 +1842,9 @@ func (a *Agent) executePlan(ctx context.Context, storeUserID string, userID int6
|
||||
}
|
||||
return question, nil
|
||||
case planStepTypeRespond:
|
||||
respondStartedAt := time.Now()
|
||||
finalText, err := a.generateFinalPlanResponse(ctx, userID, lang, *state, step.Instruction)
|
||||
a.logPlannerTiming(state.SessionID, userID, "respond_step", respondStartedAt, err)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -1399,6 +1879,134 @@ func (a *Agent) executePlan(ctx context.Context, storeUserID string, userID int6
|
||||
return "", fmt.Errorf("plan execution exceeded iteration limit")
|
||||
}
|
||||
|
||||
type fetchedToolRecord struct {
|
||||
ToolName string `json:"tool_name"`
|
||||
ToolArgsJSON string `json:"tool_args_json"`
|
||||
FetchedAt string `json:"fetched_at"`
|
||||
AgeSeconds int64 `json:"age_seconds"`
|
||||
}
|
||||
|
||||
func buildRecentlyFetchedData(state ExecutionState, now time.Time) []fetchedToolRecord {
|
||||
state = normalizeExecutionState(state)
|
||||
stepByID := make(map[string]PlanStep, len(state.Steps))
|
||||
for _, step := range state.Steps {
|
||||
stepByID[step.ID] = step
|
||||
}
|
||||
latest := map[string]fetchedToolRecord{}
|
||||
for _, obs := range state.ExecutionLog {
|
||||
if obs.Kind != "tool_result" {
|
||||
continue
|
||||
}
|
||||
step, ok := stepByID[obs.StepID]
|
||||
if !ok || step.ToolName == "" {
|
||||
continue
|
||||
}
|
||||
sig := toolCallSignature(step.ToolName, step.ToolArgs)
|
||||
createdAt := parseRFC3339(obs.CreatedAt)
|
||||
record := fetchedToolRecord{
|
||||
ToolName: step.ToolName,
|
||||
ToolArgsJSON: toolArgsJSONString(step.ToolArgs),
|
||||
FetchedAt: obs.CreatedAt,
|
||||
AgeSeconds: int64(now.Sub(createdAt).Seconds()),
|
||||
}
|
||||
prev, exists := latest[sig]
|
||||
if !exists || prev.FetchedAt < record.FetchedAt {
|
||||
latest[sig] = record
|
||||
}
|
||||
}
|
||||
out := make([]fetchedToolRecord, 0, len(latest))
|
||||
for _, record := range latest {
|
||||
if record.AgeSeconds < 0 {
|
||||
record.AgeSeconds = 0
|
||||
}
|
||||
out = append(out, record)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func filterFreshDuplicateToolSteps(steps []PlanStep, state ExecutionState, now time.Time) []PlanStep {
|
||||
if len(steps) == 0 {
|
||||
return nil
|
||||
}
|
||||
fresh := make(map[string]struct{})
|
||||
for _, item := range buildRecentlyFetchedData(state, now) {
|
||||
if item.AgeSeconds <= 60 {
|
||||
fresh[item.ToolName+"|"+item.ToolArgsJSON] = struct{}{}
|
||||
}
|
||||
}
|
||||
out := make([]PlanStep, 0, len(steps))
|
||||
for _, step := range steps {
|
||||
if step.Type != planStepTypeTool {
|
||||
out = append(out, step)
|
||||
continue
|
||||
}
|
||||
sig := toolCallSignature(step.ToolName, step.ToolArgs)
|
||||
if _, ok := fresh[sig]; ok {
|
||||
continue
|
||||
}
|
||||
fresh[sig] = struct{}{}
|
||||
out = append(out, step)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func hasRepeatedReasonLoop(state ExecutionState, steps []PlanStep) bool {
|
||||
if len(steps) == 0 {
|
||||
return false
|
||||
}
|
||||
last := lastCompletedStep(state.Steps)
|
||||
if last == nil || last.Type != planStepTypeReason {
|
||||
return false
|
||||
}
|
||||
for _, step := range steps {
|
||||
if step.Type != planStepTypeReason {
|
||||
return false
|
||||
}
|
||||
if stepSemanticKey(*last) != stepSemanticKey(step) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func lastCompletedStep(steps []PlanStep) *PlanStep {
|
||||
for i := len(steps) - 1; i >= 0; i-- {
|
||||
if steps[i].Status == planStepStatusCompleted {
|
||||
return &steps[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func stepSemanticKey(step PlanStep) string {
|
||||
return strings.ToLower(strings.TrimSpace(
|
||||
step.Type + "|" + step.ToolName + "|" + step.Title + "|" + step.Instruction,
|
||||
))
|
||||
}
|
||||
|
||||
func toolCallSignature(toolName string, args map[string]any) string {
|
||||
return strings.TrimSpace(toolName) + "|" + toolArgsJSONString(args)
|
||||
}
|
||||
|
||||
func toolArgsJSONString(args map[string]any) string {
|
||||
if len(args) == 0 {
|
||||
return "{}"
|
||||
}
|
||||
data, err := json.Marshal(args)
|
||||
if err != nil {
|
||||
return "{}"
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
func parseRFC3339(value string) time.Time {
|
||||
t, err := time.Parse(time.RFC3339, strings.TrimSpace(value))
|
||||
if err != nil {
|
||||
return time.Time{}
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func (a *Agent) replanAfterStep(ctx context.Context, userID int64, lang string, state ExecutionState, completedStep PlanStep) (replannerDecision, error) {
|
||||
obsJSON, _ := json.Marshal(buildObservationContext(state))
|
||||
stepsJSON, _ := json.Marshal(state.Steps)
|
||||
@@ -1426,6 +2034,7 @@ Rules:
|
||||
stageCtx, cancel := withPlannerStageTimeout(ctx, plannerReplanTimeout)
|
||||
defer cancel()
|
||||
|
||||
startedAt := time.Now()
|
||||
raw, err := a.aiClient.CallWithRequest(&mcp.Request{
|
||||
Messages: []mcp.Message{
|
||||
mcp.NewSystemMessage(systemPrompt),
|
||||
@@ -1434,6 +2043,7 @@ Rules:
|
||||
Ctx: stageCtx,
|
||||
MaxTokens: intPtr(500),
|
||||
})
|
||||
a.logPlannerTiming(state.SessionID, userID, "replan_after_step_llm", startedAt, err)
|
||||
if err != nil {
|
||||
return replannerDecision{}, err
|
||||
}
|
||||
@@ -1688,6 +2298,7 @@ func (a *Agent) executeReasonStep(ctx context.Context, userID int64, lang, goal
|
||||
stageCtx, cancel := withPlannerStageTimeout(ctx, plannerReasonTimeout)
|
||||
defer cancel()
|
||||
|
||||
startedAt := time.Now()
|
||||
resp, err := a.aiClient.CallWithRequest(&mcp.Request{
|
||||
Messages: []mcp.Message{
|
||||
mcp.NewSystemMessage("You are the reasoning module for NOFXi. Return one short paragraph only. No markdown, no bullet list."),
|
||||
@@ -1695,6 +2306,7 @@ func (a *Agent) executeReasonStep(ctx context.Context, userID int64, lang, goal
|
||||
},
|
||||
Ctx: stageCtx,
|
||||
})
|
||||
a.logPlannerTiming(state.SessionID, userID, "reason_step_llm", startedAt, err)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -1709,7 +2321,8 @@ func (a *Agent) generateFinalPlanResponse(ctx context.Context, userID int64, lan
|
||||
}
|
||||
stageCtx, cancel := withPlannerStageTimeout(ctx, plannerFinalTimeout)
|
||||
defer cancel()
|
||||
return a.aiClient.CallWithRequest(&mcp.Request{
|
||||
startedAt := time.Now()
|
||||
resp, err := a.aiClient.CallWithRequest(&mcp.Request{
|
||||
Messages: []mcp.Message{
|
||||
mcp.NewSystemMessage(systemPrompt),
|
||||
mcp.NewSystemMessage("You are responding after a completed execution plan. Use the observations as the source of truth. Be concise and actionable."),
|
||||
@@ -1717,6 +2330,24 @@ func (a *Agent) generateFinalPlanResponse(ctx context.Context, userID int64, lan
|
||||
},
|
||||
Ctx: stageCtx,
|
||||
})
|
||||
a.logPlannerTiming(state.SessionID, userID, "generate_final_response_llm", startedAt, err)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (a *Agent) logPlannerTiming(sessionID string, userID int64, stage string, startedAt time.Time, err error) {
|
||||
if stage == "" || startedAt.IsZero() {
|
||||
return
|
||||
}
|
||||
attrs := []any{
|
||||
"session_id", sessionID,
|
||||
"user_id", userID,
|
||||
"stage", stage,
|
||||
"elapsed_ms", time.Since(startedAt).Milliseconds(),
|
||||
}
|
||||
if err != nil {
|
||||
attrs = append(attrs, "error", err.Error())
|
||||
}
|
||||
a.log().Info("planner timing", attrs...)
|
||||
}
|
||||
|
||||
func nextPendingStepIndex(steps []PlanStep) int {
|
||||
|
||||
@@ -617,6 +617,39 @@ func TestThinkAndActPrioritizesActiveExecutionStateOverDirectReply(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestThinkAndActInterruptsWaitingExecutionStateForNewTopic(t *testing.T) {
|
||||
a := newTestAgentWithStore(t)
|
||||
a.history = newChatHistory(10)
|
||||
|
||||
_ = a.toolManageStrategy("user-1", `{
|
||||
"action":"create",
|
||||
"name":"激进",
|
||||
"lang":"zh"
|
||||
}`)
|
||||
|
||||
userID := int64(91)
|
||||
state := newExecutionState(userID, "创建交易员")
|
||||
state.Status = executionStatusWaitingUser
|
||||
state.Waiting = &WaitingState{
|
||||
Question: "请告诉我交易员名称",
|
||||
PendingFields: []string{"name"},
|
||||
}
|
||||
if err := a.saveExecutionState(state); err != nil {
|
||||
t.Fatalf("saveExecutionState() error = %v", err)
|
||||
}
|
||||
|
||||
resp, err := a.thinkAndAct(context.Background(), "user-1", userID, "zh", "列出我当前的策略")
|
||||
if err != nil {
|
||||
t.Fatalf("thinkAndAct() error = %v", err)
|
||||
}
|
||||
if !strings.Contains(resp, "当前策略") || !strings.Contains(resp, "激进") {
|
||||
t.Fatalf("expected new topic to be handled, got %q", resp)
|
||||
}
|
||||
if got := a.getExecutionState(userID); got.SessionID != "" {
|
||||
t.Fatalf("expected execution state to be cleared, got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateExecutionPlanIncludesRecentConversation(t *testing.T) {
|
||||
client := &capturePlannerAIClient{}
|
||||
a := &Agent{
|
||||
|
||||
277
agent/skill_dag.go
Normal file
277
agent/skill_dag.go
Normal file
@@ -0,0 +1,277 @@
|
||||
package agent
|
||||
|
||||
import "strings"
|
||||
|
||||
type SkillDAG struct {
|
||||
SkillName string
|
||||
Action string
|
||||
Steps []SkillDAGStep
|
||||
}
|
||||
|
||||
type SkillDAGStep struct {
|
||||
ID string
|
||||
Kind string
|
||||
RequiredFields []string
|
||||
OptionalFields []string
|
||||
Next []string
|
||||
Terminal bool
|
||||
}
|
||||
|
||||
var skillDAGRegistry = buildSkillDAGRegistry()
|
||||
|
||||
func buildSkillDAGRegistry() map[string]SkillDAG {
|
||||
dags := []SkillDAG{
|
||||
{
|
||||
SkillName: "trader_management",
|
||||
Action: "create",
|
||||
Steps: []SkillDAGStep{
|
||||
{ID: "resolve_name", Kind: "collect_slot", RequiredFields: []string{"name"}, Next: []string{"resolve_exchange"}},
|
||||
{ID: "resolve_exchange", Kind: "collect_slot", RequiredFields: []string{"exchange_id"}, OptionalFields: []string{"exchange_name"}, Next: []string{"resolve_model"}},
|
||||
{ID: "resolve_model", Kind: "collect_slot", RequiredFields: []string{"model_id"}, OptionalFields: []string{"model_name"}, Next: []string{"resolve_strategy"}},
|
||||
{ID: "resolve_strategy", Kind: "collect_slot", RequiredFields: []string{"strategy_id"}, OptionalFields: []string{"strategy_name"}, Next: []string{"maybe_confirm_start"}},
|
||||
{ID: "maybe_confirm_start", Kind: "branch", OptionalFields: []string{"auto_start"}, Next: []string{"await_start_confirmation", "execute_create_only"}},
|
||||
{ID: "await_start_confirmation", Kind: "confirm", RequiredFields: []string{"auto_start"}, Next: []string{"execute_create_and_start", "execute_create_only"}},
|
||||
{ID: "execute_create_only", Kind: "execute", RequiredFields: []string{"name", "exchange_id", "model_id", "strategy_id"}, Terminal: true},
|
||||
{ID: "execute_create_and_start", Kind: "execute", RequiredFields: []string{"name", "exchange_id", "model_id", "strategy_id"}, OptionalFields: []string{"auto_start"}, Terminal: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
SkillName: "trader_management",
|
||||
Action: "update_name",
|
||||
Steps: []SkillDAGStep{
|
||||
{ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"collect_name"}},
|
||||
{ID: "collect_name", Kind: "collect_slot", RequiredFields: []string{"name"}, Next: []string{"execute_update"}},
|
||||
{ID: "execute_update", Kind: "execute", RequiredFields: []string{"target_ref", "name"}, Terminal: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
SkillName: "trader_management",
|
||||
Action: "update_bindings",
|
||||
Steps: []SkillDAGStep{
|
||||
{ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"collect_bindings"}},
|
||||
{ID: "collect_bindings", Kind: "collect_slot", RequiredFields: []string{"binding_update"}, OptionalFields: []string{"ai_model_id", "exchange_id", "strategy_id"}, Next: []string{"execute_update"}},
|
||||
{ID: "execute_update", Kind: "execute", RequiredFields: []string{"target_ref", "binding_update"}, OptionalFields: []string{"ai_model_id", "exchange_id", "strategy_id"}, Terminal: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
SkillName: "trader_management",
|
||||
Action: "start",
|
||||
Steps: []SkillDAGStep{
|
||||
{ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"await_confirmation"}},
|
||||
{ID: "await_confirmation", Kind: "confirm", RequiredFields: []string{"target_ref"}, Next: []string{"execute_start"}},
|
||||
{ID: "execute_start", Kind: "execute", RequiredFields: []string{"target_ref"}, Terminal: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
SkillName: "trader_management",
|
||||
Action: "stop",
|
||||
Steps: []SkillDAGStep{
|
||||
{ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"await_confirmation"}},
|
||||
{ID: "await_confirmation", Kind: "confirm", RequiredFields: []string{"target_ref"}, Next: []string{"execute_stop"}},
|
||||
{ID: "execute_stop", Kind: "execute", RequiredFields: []string{"target_ref"}, Terminal: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
SkillName: "trader_management",
|
||||
Action: "delete",
|
||||
Steps: []SkillDAGStep{
|
||||
{ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"await_confirmation"}},
|
||||
{ID: "await_confirmation", Kind: "confirm", RequiredFields: []string{"target_ref"}, Next: []string{"execute_delete"}},
|
||||
{ID: "execute_delete", Kind: "execute", RequiredFields: []string{"target_ref"}, Terminal: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
SkillName: "strategy_management",
|
||||
Action: "create",
|
||||
Steps: []SkillDAGStep{
|
||||
{ID: "resolve_name", Kind: "collect_slot", RequiredFields: []string{"name"}, OptionalFields: []string{"lang", "description", "config"}, Next: []string{"execute_create"}},
|
||||
{ID: "execute_create", Kind: "execute", RequiredFields: []string{"name"}, OptionalFields: []string{"lang", "description", "config"}, Terminal: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
SkillName: "strategy_management",
|
||||
Action: "update_name",
|
||||
Steps: []SkillDAGStep{
|
||||
{ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"collect_name"}},
|
||||
{ID: "collect_name", Kind: "collect_slot", RequiredFields: []string{"name"}, Next: []string{"execute_update"}},
|
||||
{ID: "execute_update", Kind: "execute", RequiredFields: []string{"target_ref", "name"}, Terminal: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
SkillName: "strategy_management",
|
||||
Action: "update_prompt",
|
||||
Steps: []SkillDAGStep{
|
||||
{ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"collect_prompt"}},
|
||||
{ID: "collect_prompt", Kind: "collect_slot", RequiredFields: []string{"prompt"}, Next: []string{"load_config"}},
|
||||
{ID: "load_config", Kind: "load_state", RequiredFields: []string{"target_ref"}, Next: []string{"execute_update"}},
|
||||
{ID: "execute_update", Kind: "execute", RequiredFields: []string{"target_ref", "prompt"}, Terminal: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
SkillName: "strategy_management",
|
||||
Action: "update_config",
|
||||
Steps: []SkillDAGStep{
|
||||
{ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"resolve_config_field"}},
|
||||
{ID: "resolve_config_field", Kind: "collect_slot", RequiredFields: []string{"config_field"}, Next: []string{"resolve_config_value"}},
|
||||
{ID: "resolve_config_value", Kind: "collect_slot", RequiredFields: []string{"config_value"}, Next: []string{"load_config"}},
|
||||
{ID: "load_config", Kind: "load_state", RequiredFields: []string{"target_ref"}, Next: []string{"apply_field_update"}},
|
||||
{ID: "apply_field_update", Kind: "transform", RequiredFields: []string{"config_field", "config_value"}, Next: []string{"execute_update"}},
|
||||
{ID: "execute_update", Kind: "execute", RequiredFields: []string{"target_ref", "config_field", "config_value"}, Terminal: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
SkillName: "strategy_management",
|
||||
Action: "duplicate",
|
||||
Steps: []SkillDAGStep{
|
||||
{ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"collect_name"}},
|
||||
{ID: "collect_name", Kind: "collect_slot", RequiredFields: []string{"name"}, Next: []string{"execute_duplicate"}},
|
||||
{ID: "execute_duplicate", Kind: "execute", RequiredFields: []string{"target_ref", "name"}, Terminal: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
SkillName: "strategy_management",
|
||||
Action: "activate",
|
||||
Steps: []SkillDAGStep{
|
||||
{ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"execute_activate"}},
|
||||
{ID: "execute_activate", Kind: "execute", RequiredFields: []string{"target_ref"}, Terminal: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
SkillName: "strategy_management",
|
||||
Action: "delete",
|
||||
Steps: []SkillDAGStep{
|
||||
{ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"await_confirmation"}},
|
||||
{ID: "await_confirmation", Kind: "confirm", RequiredFields: []string{"target_ref"}, Next: []string{"execute_delete"}},
|
||||
{ID: "execute_delete", Kind: "execute", RequiredFields: []string{"target_ref"}, Terminal: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
SkillName: "model_management",
|
||||
Action: "create",
|
||||
Steps: []SkillDAGStep{
|
||||
{ID: "resolve_provider", Kind: "collect_slot", RequiredFields: []string{"provider"}, Next: []string{"collect_optional_fields"}},
|
||||
{ID: "collect_optional_fields", Kind: "collect_slot", OptionalFields: []string{"name", "custom_api_url", "custom_model_name"}, Next: []string{"execute_create"}},
|
||||
{ID: "execute_create", Kind: "execute", RequiredFields: []string{"provider"}, OptionalFields: []string{"name", "custom_api_url", "custom_model_name"}, Terminal: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
SkillName: "model_management",
|
||||
Action: "update_status",
|
||||
Steps: []SkillDAGStep{
|
||||
{ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"collect_enabled"}},
|
||||
{ID: "collect_enabled", Kind: "collect_slot", RequiredFields: []string{"enabled"}, Next: []string{"execute_update"}},
|
||||
{ID: "execute_update", Kind: "execute", RequiredFields: []string{"target_ref", "enabled"}, Terminal: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
SkillName: "model_management",
|
||||
Action: "update_endpoint",
|
||||
Steps: []SkillDAGStep{
|
||||
{ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"collect_custom_api_url"}},
|
||||
{ID: "collect_custom_api_url", Kind: "collect_slot", RequiredFields: []string{"custom_api_url"}, Next: []string{"execute_update"}},
|
||||
{ID: "execute_update", Kind: "execute", RequiredFields: []string{"target_ref", "custom_api_url"}, Terminal: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
SkillName: "model_management",
|
||||
Action: "update_name",
|
||||
Steps: []SkillDAGStep{
|
||||
{ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"collect_custom_model_name"}},
|
||||
{ID: "collect_custom_model_name", Kind: "collect_slot", RequiredFields: []string{"custom_model_name"}, Next: []string{"execute_update"}},
|
||||
{ID: "execute_update", Kind: "execute", RequiredFields: []string{"target_ref", "custom_model_name"}, Terminal: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
SkillName: "model_management",
|
||||
Action: "delete",
|
||||
Steps: []SkillDAGStep{
|
||||
{ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"await_confirmation"}},
|
||||
{ID: "await_confirmation", Kind: "confirm", RequiredFields: []string{"target_ref"}, Next: []string{"execute_delete"}},
|
||||
{ID: "execute_delete", Kind: "execute", RequiredFields: []string{"target_ref"}, Terminal: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
SkillName: "exchange_management",
|
||||
Action: "create",
|
||||
Steps: []SkillDAGStep{
|
||||
{ID: "resolve_exchange_type", Kind: "collect_slot", RequiredFields: []string{"exchange_type"}, Next: []string{"collect_account_name"}},
|
||||
{ID: "collect_account_name", Kind: "collect_slot", OptionalFields: []string{"account_name"}, Next: []string{"execute_create"}},
|
||||
{ID: "execute_create", Kind: "execute", RequiredFields: []string{"exchange_type"}, OptionalFields: []string{"account_name"}, Terminal: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
SkillName: "exchange_management",
|
||||
Action: "update_name",
|
||||
Steps: []SkillDAGStep{
|
||||
{ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"collect_account_name"}},
|
||||
{ID: "collect_account_name", Kind: "collect_slot", RequiredFields: []string{"account_name"}, Next: []string{"execute_update"}},
|
||||
{ID: "execute_update", Kind: "execute", RequiredFields: []string{"target_ref", "account_name"}, Terminal: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
SkillName: "exchange_management",
|
||||
Action: "update_status",
|
||||
Steps: []SkillDAGStep{
|
||||
{ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"collect_enabled"}},
|
||||
{ID: "collect_enabled", Kind: "collect_slot", RequiredFields: []string{"enabled"}, Next: []string{"execute_update"}},
|
||||
{ID: "execute_update", Kind: "execute", RequiredFields: []string{"target_ref", "enabled"}, Terminal: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
SkillName: "exchange_management",
|
||||
Action: "delete",
|
||||
Steps: []SkillDAGStep{
|
||||
{ID: "resolve_target", Kind: "resolve_target", RequiredFields: []string{"target_ref"}, Next: []string{"await_confirmation"}},
|
||||
{ID: "await_confirmation", Kind: "confirm", RequiredFields: []string{"target_ref"}, Next: []string{"execute_delete"}},
|
||||
{ID: "execute_delete", Kind: "execute", RequiredFields: []string{"target_ref"}, Terminal: true},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
registry := make(map[string]SkillDAG, len(dags))
|
||||
for _, dag := range dags {
|
||||
dag = normalizeSkillDAG(dag)
|
||||
if dag.SkillName == "" || dag.Action == "" {
|
||||
continue
|
||||
}
|
||||
registry[skillDAGKey(dag.SkillName, dag.Action)] = dag
|
||||
}
|
||||
return registry
|
||||
}
|
||||
|
||||
func normalizeSkillDAG(dag SkillDAG) SkillDAG {
|
||||
dag.SkillName = strings.TrimSpace(dag.SkillName)
|
||||
dag.Action = strings.TrimSpace(dag.Action)
|
||||
steps := make([]SkillDAGStep, 0, len(dag.Steps))
|
||||
for _, step := range dag.Steps {
|
||||
step.ID = strings.TrimSpace(step.ID)
|
||||
step.Kind = strings.TrimSpace(step.Kind)
|
||||
step.RequiredFields = cleanStringList(step.RequiredFields)
|
||||
step.OptionalFields = cleanStringList(step.OptionalFields)
|
||||
step.Next = cleanStringList(step.Next)
|
||||
if step.ID == "" {
|
||||
continue
|
||||
}
|
||||
steps = append(steps, step)
|
||||
}
|
||||
dag.Steps = steps
|
||||
return dag
|
||||
}
|
||||
|
||||
func skillDAGKey(skillName, action string) string {
|
||||
return strings.TrimSpace(skillName) + ":" + strings.TrimSpace(action)
|
||||
}
|
||||
|
||||
func getSkillDAG(skillName, action string) (SkillDAG, bool) {
|
||||
dag, ok := skillDAGRegistry[skillDAGKey(skillName, action)]
|
||||
return dag, ok
|
||||
}
|
||||
|
||||
func listSkillDAGs() []SkillDAG {
|
||||
out := make([]SkillDAG, 0, len(skillDAGRegistry))
|
||||
for _, dag := range skillDAGRegistry {
|
||||
out = append(out, dag)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
51
agent/skill_dag_runtime.go
Normal file
51
agent/skill_dag_runtime.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package agent
|
||||
|
||||
const skillDAGStepField = "_dag_step"
|
||||
|
||||
func currentSkillDAGStep(session skillSession) (SkillDAGStep, bool) {
|
||||
dag, ok := getSkillDAG(session.Name, session.Action)
|
||||
if !ok || len(dag.Steps) == 0 {
|
||||
return SkillDAGStep{}, false
|
||||
}
|
||||
stepID := fieldValue(session, skillDAGStepField)
|
||||
if stepID == "" {
|
||||
return dag.Steps[0], true
|
||||
}
|
||||
for _, step := range dag.Steps {
|
||||
if step.ID == stepID {
|
||||
return step, true
|
||||
}
|
||||
}
|
||||
return dag.Steps[0], true
|
||||
}
|
||||
|
||||
func setSkillDAGStep(session *skillSession, stepID string) {
|
||||
ensureSkillFields(session)
|
||||
if stepID == "" {
|
||||
delete(session.Fields, skillDAGStepField)
|
||||
return
|
||||
}
|
||||
session.Fields[skillDAGStepField] = stepID
|
||||
}
|
||||
|
||||
func clearSkillDAGStep(session *skillSession) {
|
||||
if session == nil || session.Fields == nil {
|
||||
return
|
||||
}
|
||||
delete(session.Fields, skillDAGStepField)
|
||||
}
|
||||
|
||||
func advanceSkillDAGStep(session *skillSession, currentStepID string) {
|
||||
dag, ok := getSkillDAG(session.Name, session.Action)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for _, step := range dag.Steps {
|
||||
if step.ID != currentStepID || len(step.Next) == 0 {
|
||||
continue
|
||||
}
|
||||
setSkillDAGStep(session, step.Next[0])
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
27
agent/skill_dag_runtime_test.go
Normal file
27
agent/skill_dag_runtime_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package agent
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestCurrentSkillDAGStepDefaultsToFirstStep(t *testing.T) {
|
||||
session := skillSession{Name: "strategy_management", Action: "update_config"}
|
||||
step, ok := currentSkillDAGStep(session)
|
||||
if !ok {
|
||||
t.Fatal("expected dag step")
|
||||
}
|
||||
if step.ID != "resolve_target" {
|
||||
t.Fatalf("expected first step resolve_target, got %s", step.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdvanceSkillDAGStepMovesToNextStep(t *testing.T) {
|
||||
session := skillSession{Name: "strategy_management", Action: "update_config"}
|
||||
setSkillDAGStep(&session, "resolve_config_field")
|
||||
advanceSkillDAGStep(&session, "resolve_config_field")
|
||||
step, ok := currentSkillDAGStep(session)
|
||||
if !ok {
|
||||
t.Fatal("expected dag step")
|
||||
}
|
||||
if step.ID != "resolve_config_value" {
|
||||
t.Fatalf("expected resolve_config_value, got %s", step.ID)
|
||||
}
|
||||
}
|
||||
67
agent/skill_dag_test.go
Normal file
67
agent/skill_dag_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package agent
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGetSkillDAGForStructuredActions(t *testing.T) {
|
||||
tests := []struct {
|
||||
skill string
|
||||
action string
|
||||
}{
|
||||
{skill: "trader_management", action: "create"},
|
||||
{skill: "trader_management", action: "update_bindings"},
|
||||
{skill: "strategy_management", action: "update_config"},
|
||||
{skill: "strategy_management", action: "update_prompt"},
|
||||
{skill: "model_management", action: "update_status"},
|
||||
{skill: "exchange_management", action: "update_name"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
dag, ok := getSkillDAG(tt.skill, tt.action)
|
||||
if !ok {
|
||||
t.Fatalf("expected DAG for %s/%s", tt.skill, tt.action)
|
||||
}
|
||||
if dag.SkillName != tt.skill || dag.Action != tt.action {
|
||||
t.Fatalf("unexpected dag identity: %+v", dag)
|
||||
}
|
||||
if len(dag.Steps) == 0 {
|
||||
t.Fatalf("expected DAG steps for %s/%s", tt.skill, tt.action)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStructuredDAGsHaveTerminalStep(t *testing.T) {
|
||||
for _, dag := range listSkillDAGs() {
|
||||
hasTerminal := false
|
||||
for _, step := range dag.Steps {
|
||||
if step.Terminal {
|
||||
hasTerminal = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasTerminal {
|
||||
t.Fatalf("expected terminal step for %s/%s", dag.SkillName, dag.Action)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStrategyUpdateConfigDAGMatchesCurrentAtomicFlow(t *testing.T) {
|
||||
dag, ok := getSkillDAG("strategy_management", "update_config")
|
||||
if !ok {
|
||||
t.Fatal("missing strategy update_config dag")
|
||||
}
|
||||
if len(dag.Steps) != 6 {
|
||||
t.Fatalf("expected 6 steps, got %d", len(dag.Steps))
|
||||
}
|
||||
if dag.Steps[0].ID != "resolve_target" {
|
||||
t.Fatalf("expected first step resolve_target, got %s", dag.Steps[0].ID)
|
||||
}
|
||||
if dag.Steps[1].ID != "resolve_config_field" {
|
||||
t.Fatalf("expected second step resolve_config_field, got %s", dag.Steps[1].ID)
|
||||
}
|
||||
if dag.Steps[2].ID != "resolve_config_value" {
|
||||
t.Fatalf("expected third step resolve_config_value, got %s", dag.Steps[2].ID)
|
||||
}
|
||||
if dag.Steps[5].ID != "execute_update" || !dag.Steps[5].Terminal {
|
||||
t.Fatalf("expected final terminal execute step, got %+v", dag.Steps[5])
|
||||
}
|
||||
}
|
||||
@@ -37,8 +37,8 @@ type traderSkillOption struct {
|
||||
}
|
||||
|
||||
var (
|
||||
quotedNamePattern = regexp.MustCompile(`[“"]([^“”"]{1,40})[”"]`)
|
||||
traderNamedPattern = regexp.MustCompile(`(?:叫|名为|名字是)\s*([A-Za-z0-9_\-\p{Han}]{2,40})`)
|
||||
quotedNamePattern = regexp.MustCompile(`[“"]([^“”"]{1,40})[”"]`)
|
||||
traderNamedPattern = regexp.MustCompile(`(?:叫|名为|名字是)\s*([A-Za-z0-9_\-\p{Han}]{2,40})`)
|
||||
)
|
||||
|
||||
func skillSessionConfigKey(userID int64) string {
|
||||
@@ -157,7 +157,12 @@ func isNoReply(text string) bool {
|
||||
|
||||
func isCancelSkillReply(text string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(text))
|
||||
return lower == "取消" || lower == "/cancel" || lower == "cancel"
|
||||
switch lower {
|
||||
case "取消", "/cancel", "cancel", "不改", "先不改", "算了", "先不用", "不用了", "不弄了", "不搞了", "换话题", "换话题了", "聊别的", "先聊别的":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func detectCreateTraderSkill(text string) bool {
|
||||
@@ -198,6 +203,116 @@ func detectStartIntent(text string) bool {
|
||||
return containsAny(lower, []string{"启动", "跑起来", "run", "start", "立即运行", "并启动"})
|
||||
}
|
||||
|
||||
func looksLikeStandaloneValueReply(text string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(text))
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
if firstIntegerPattern.MatchString(lower) && len(strings.Fields(lower)) <= 4 {
|
||||
return true
|
||||
}
|
||||
return containsAny(lower, []string{"启用", "禁用", "enable", "disable", "打开", "关闭"})
|
||||
}
|
||||
|
||||
func detectImplicitStrategyAction(text string) string {
|
||||
lower := strings.ToLower(strings.TrimSpace(text))
|
||||
switch {
|
||||
case containsAny(lower, []string{"prompt", "提示词"}):
|
||||
return "update_prompt"
|
||||
case containsAny(lower, []string{"参数", "配置", "置信度", "持仓", "周期", "timeframe", "调到", "改到", "改成", "调整"}):
|
||||
return "update_config"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func detectImplicitTraderAction(text string) string {
|
||||
lower := strings.ToLower(strings.TrimSpace(text))
|
||||
switch {
|
||||
case containsAny(lower, []string{"启动", "开始", "run", "start"}):
|
||||
return "start"
|
||||
case containsAny(lower, []string{"停止", "停掉", "stop", "pause"}):
|
||||
return "stop"
|
||||
case containsAny(lower, []string{"换模型", "换交易所", "换策略", "绑定", "切换模型", "切换交易所", "切换策略"}):
|
||||
return "update_bindings"
|
||||
case containsAny(lower, []string{"改名", "重命名", "rename"}):
|
||||
return "update_name"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func detectImplicitModelAction(text string) string {
|
||||
lower := strings.ToLower(strings.TrimSpace(text))
|
||||
switch {
|
||||
case containsAny(lower, []string{"启用", "禁用", "enable", "disable"}):
|
||||
return "update_status"
|
||||
case containsAny(lower, []string{"url", "endpoint", "地址", "接口"}):
|
||||
return "update_endpoint"
|
||||
case containsAny(lower, []string{"模型名", "模型名称", "model name", "改名", "重命名", "rename"}):
|
||||
return "update_name"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func detectImplicitExchangeAction(text string) string {
|
||||
lower := strings.ToLower(strings.TrimSpace(text))
|
||||
switch {
|
||||
case containsAny(lower, []string{"启用", "禁用", "enable", "disable"}):
|
||||
return "update_status"
|
||||
case containsAny(lower, []string{"账户名", "改名", "重命名", "rename"}):
|
||||
return "update_name"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Agent) inferContextualSkillSession(storeUserID string, userID int64, text string, session skillSession) skillSession {
|
||||
if session.Name != "" || strings.TrimSpace(text) == "" {
|
||||
return session
|
||||
}
|
||||
state := a.getExecutionState(userID)
|
||||
lower := strings.ToLower(strings.TrimSpace(text))
|
||||
if state.CurrentReferences != nil {
|
||||
if ref := state.CurrentReferences.Strategy; ref != nil {
|
||||
if action := detectImplicitStrategyAction(text); action != "" || looksLikeStandaloneValueReply(text) {
|
||||
return skillSession{Name: "strategy_management", Action: defaultIfEmpty(action, "update_config"), Phase: "collecting", TargetRef: ref}
|
||||
}
|
||||
}
|
||||
if ref := state.CurrentReferences.Trader; ref != nil {
|
||||
if action := detectImplicitTraderAction(text); action != "" {
|
||||
return skillSession{Name: "trader_management", Action: action, Phase: "collecting", TargetRef: ref}
|
||||
}
|
||||
}
|
||||
if ref := state.CurrentReferences.Model; ref != nil {
|
||||
if action := detectImplicitModelAction(text); action != "" {
|
||||
return skillSession{Name: "model_management", Action: action, Phase: "collecting", TargetRef: ref}
|
||||
}
|
||||
}
|
||||
if ref := state.CurrentReferences.Exchange; ref != nil {
|
||||
if action := detectImplicitExchangeAction(text); action != "" {
|
||||
return skillSession{Name: "exchange_management", Action: action, Phase: "collecting", TargetRef: ref}
|
||||
}
|
||||
}
|
||||
}
|
||||
if containsAny(lower, []string{"调整参数", "改参数", "改配置"}) {
|
||||
options := a.loadStrategyOptions(storeUserID)
|
||||
if len(options) == 1 {
|
||||
return skillSession{
|
||||
Name: "strategy_management",
|
||||
Action: "update_config",
|
||||
Phase: "collecting",
|
||||
TargetRef: &EntityReference{
|
||||
ID: options[0].ID,
|
||||
Name: options[0].Name,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
return session
|
||||
}
|
||||
|
||||
func extractTraderName(text string) string {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
@@ -212,11 +327,45 @@ func extractTraderName(text string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractSegmentAfterKeywords(text string, keywords []string) string {
|
||||
trimmed := strings.TrimSpace(text)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
lower := strings.ToLower(trimmed)
|
||||
for _, keyword := range keywords {
|
||||
idx := strings.Index(lower, strings.ToLower(keyword))
|
||||
if idx < 0 {
|
||||
continue
|
||||
}
|
||||
segment := strings.TrimSpace(trimmed[idx+len(keyword):])
|
||||
if segment == "" {
|
||||
continue
|
||||
}
|
||||
cut := len(segment)
|
||||
for i, r := range segment {
|
||||
switch r {
|
||||
case ',', ',', '。', ';', ';', '\n', '、':
|
||||
cut = i
|
||||
goto done
|
||||
}
|
||||
}
|
||||
done:
|
||||
segment = strings.TrimSpace(segment[:cut])
|
||||
segment = strings.Trim(segment, "“”\"':: ")
|
||||
if segment != "" {
|
||||
return segment
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func pickMentionedOption(text string, options []traderSkillOption) *traderSkillOption {
|
||||
lower := strings.ToLower(strings.TrimSpace(text))
|
||||
if lower == "" {
|
||||
return nil
|
||||
}
|
||||
bestScore := 0
|
||||
var matched *traderSkillOption
|
||||
for _, option := range options {
|
||||
id := strings.ToLower(strings.TrimSpace(option.ID))
|
||||
@@ -224,10 +373,16 @@ func pickMentionedOption(text string, options []traderSkillOption) *traderSkillO
|
||||
if id == "" && name == "" {
|
||||
continue
|
||||
}
|
||||
if (id != "" && strings.Contains(lower, id)) || (name != "" && strings.Contains(lower, name)) {
|
||||
if matched != nil {
|
||||
return nil
|
||||
}
|
||||
score := optionMatchScore(lower, id, name)
|
||||
if score == 0 {
|
||||
continue
|
||||
}
|
||||
if score == bestScore {
|
||||
matched = nil
|
||||
continue
|
||||
}
|
||||
if score > bestScore {
|
||||
bestScore = score
|
||||
copy := option
|
||||
matched = ©
|
||||
}
|
||||
@@ -235,6 +390,73 @@ func pickMentionedOption(text string, options []traderSkillOption) *traderSkillO
|
||||
return matched
|
||||
}
|
||||
|
||||
func pickOptionFromSegment(text string, keywords []string, options []traderSkillOption) *traderSkillOption {
|
||||
segment := extractSegmentAfterKeywords(text, keywords)
|
||||
if strings.TrimSpace(segment) == "" {
|
||||
return nil
|
||||
}
|
||||
return pickMentionedOption(segment, options)
|
||||
}
|
||||
|
||||
func optionMatchScore(text, id, name string) int {
|
||||
if id != "" && strings.Contains(text, id) {
|
||||
return 4
|
||||
}
|
||||
return optionNameMatchScore(text, name)
|
||||
}
|
||||
|
||||
func optionNameMatchScore(text, name string) int {
|
||||
name = strings.TrimSpace(strings.ToLower(name))
|
||||
if name == "" {
|
||||
return 0
|
||||
}
|
||||
if strings.Contains(text, name) {
|
||||
return 3
|
||||
}
|
||||
fields := strings.FieldsFunc(name, func(r rune) bool {
|
||||
switch r {
|
||||
case ' ', ',', ',', '/', '|', '、', '(', ')', '(', ')':
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
})
|
||||
best := 0
|
||||
for _, field := range fields {
|
||||
field = strings.TrimSpace(field)
|
||||
if field == "" {
|
||||
continue
|
||||
}
|
||||
if len([]rune(field)) <= 2 && !containsHan(field) {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(text, field) {
|
||||
if containsHan(field) && len([]rune(field)) >= 3 {
|
||||
best = max(best, 2)
|
||||
} else {
|
||||
best = max(best, 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
return best
|
||||
}
|
||||
|
||||
func containsHan(s string) bool {
|
||||
for _, r := range s {
|
||||
if r >= 0x4E00 && r <= 0x9FFF {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func choosePreferredOption(options []traderSkillOption) *traderSkillOption {
|
||||
if len(options) == 1 {
|
||||
copy := options[0]
|
||||
@@ -262,6 +484,8 @@ func formatOptionList(prefix string, options []traderSkillOption) string {
|
||||
}
|
||||
if option.Enabled {
|
||||
label += "(已启用)"
|
||||
} else {
|
||||
label += "(已禁用)"
|
||||
}
|
||||
parts = append(parts, label)
|
||||
}
|
||||
@@ -291,13 +515,12 @@ func (a *Agent) loadEnabledModelOptions(storeUserID string) []traderSkillOption
|
||||
}
|
||||
out := make([]traderSkillOption, 0, len(models))
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
if name == "" {
|
||||
name = strings.TrimSpace(model.CustomModelName)
|
||||
}
|
||||
if name == "" {
|
||||
name = strings.TrimSpace(model.Provider)
|
||||
}
|
||||
parts := cleanStringList([]string{
|
||||
strings.TrimSpace(model.Name),
|
||||
strings.TrimSpace(model.CustomModelName),
|
||||
strings.TrimSpace(model.Provider),
|
||||
})
|
||||
name := strings.Join(parts, " ")
|
||||
out = append(out, traderSkillOption{ID: model.ID, Name: name, Enabled: model.Enabled})
|
||||
}
|
||||
return out
|
||||
@@ -342,6 +565,7 @@ func (a *Agent) tryHardSkill(ctx context.Context, storeUserID string, userID int
|
||||
return "", false
|
||||
}
|
||||
session := a.getSkillSession(userID)
|
||||
session = a.inferContextualSkillSession(storeUserID, userID, text, session)
|
||||
if (session.Name == "trader_management" && session.Action == "create") || detectCreateTraderSkill(text) {
|
||||
answer, handled := a.handleCreateTraderSkill(storeUserID, userID, lang, text, session)
|
||||
if handled {
|
||||
@@ -459,13 +683,13 @@ func (a *Agent) handleCreateTraderSkill(storeUserID string, userID int64, lang,
|
||||
return "Cancelled the current trader creation flow.", true
|
||||
}
|
||||
|
||||
if session.Name == "" {
|
||||
session = skillSession{
|
||||
Name: "trader_management",
|
||||
Action: "create",
|
||||
Phase: "collecting",
|
||||
Slots: &createTraderSkillSlots{},
|
||||
}
|
||||
if session.Name == "" {
|
||||
session = skillSession{
|
||||
Name: "trader_management",
|
||||
Action: "create",
|
||||
Phase: "collecting",
|
||||
Slots: &createTraderSkillSlots{},
|
||||
}
|
||||
if detectStartIntent(text) {
|
||||
autoStart := true
|
||||
session.Slots.AutoStart = &autoStart
|
||||
@@ -474,8 +698,12 @@ func (a *Agent) handleCreateTraderSkill(storeUserID string, userID int64, lang,
|
||||
if session.Slots == nil {
|
||||
session.Slots = &createTraderSkillSlots{}
|
||||
}
|
||||
if fieldValue(session, skillDAGStepField) == "" {
|
||||
setSkillDAGStep(&session, "resolve_name")
|
||||
}
|
||||
|
||||
if session.Phase == "await_start_confirmation" {
|
||||
setSkillDAGStep(&session, "await_start_confirmation")
|
||||
switch {
|
||||
case isYesReply(text):
|
||||
answer := a.executeCreateTraderSkill(storeUserID, userID, lang, session, true)
|
||||
@@ -496,13 +724,19 @@ func (a *Agent) handleCreateTraderSkill(storeUserID string, userID int64, lang,
|
||||
if slots.Name == "" {
|
||||
slots.Name = extractTraderName(text)
|
||||
}
|
||||
if slots.Name != "" {
|
||||
setSkillDAGStep(&session, "resolve_exchange")
|
||||
}
|
||||
|
||||
models := a.loadEnabledModelOptions(storeUserID)
|
||||
exchanges := a.loadExchangeOptions(storeUserID)
|
||||
strategies := a.loadStrategyOptions(storeUserID)
|
||||
|
||||
if slots.ModelID == "" {
|
||||
if match := pickMentionedOption(text, models); match != nil {
|
||||
if match := pickOptionFromSegment(text, []string{"模型用", "模型", "model"}, models); match != nil {
|
||||
slots.ModelID = match.ID
|
||||
slots.ModelName = match.Name
|
||||
} else if match := pickMentionedOption(text, models); match != nil {
|
||||
slots.ModelID = match.ID
|
||||
slots.ModelName = match.Name
|
||||
} else if choice := choosePreferredOption(models); choice != nil {
|
||||
@@ -510,17 +744,46 @@ func (a *Agent) handleCreateTraderSkill(storeUserID string, userID int64, lang,
|
||||
slots.ModelName = choice.Name
|
||||
}
|
||||
}
|
||||
if slots.ExchangeID != "" {
|
||||
setSkillDAGStep(&session, "resolve_model")
|
||||
}
|
||||
if slots.ExchangeID == "" {
|
||||
if match := pickMentionedOption(text, exchanges); match != nil {
|
||||
slots.ExchangeID = match.ID
|
||||
slots.ExchangeName = match.Name
|
||||
if match := pickOptionFromSegment(text, []string{"交易所用", "交易所", "exchange"}, exchanges); match != nil {
|
||||
if match.Enabled {
|
||||
slots.ExchangeID = match.ID
|
||||
slots.ExchangeName = match.Name
|
||||
} else {
|
||||
if lang == "zh" {
|
||||
extra := "你刚才提到的交易所“" + defaultIfEmpty(match.Name, match.ID) + "”当前已禁用,请换一个已启用的交易所。"
|
||||
a.saveSkillSession(userID, session)
|
||||
return extra + "\n" + formatOptionList("可用交易所:", exchanges), true
|
||||
}
|
||||
a.saveSkillSession(userID, session)
|
||||
return "The exchange you mentioned is currently disabled. Please choose an enabled exchange.\n" + formatOptionList("Available exchanges:", exchanges), true
|
||||
}
|
||||
} else if match := pickMentionedOption(text, exchanges); match != nil {
|
||||
if match.Enabled {
|
||||
slots.ExchangeID = match.ID
|
||||
slots.ExchangeName = match.Name
|
||||
} else {
|
||||
if lang == "zh" {
|
||||
extra := "你刚才提到的交易所“" + defaultIfEmpty(match.Name, match.ID) + "”当前已禁用,请换一个已启用的交易所。"
|
||||
a.saveSkillSession(userID, session)
|
||||
return extra + "\n" + formatOptionList("可用交易所:", exchanges), true
|
||||
}
|
||||
a.saveSkillSession(userID, session)
|
||||
return "The exchange you mentioned is currently disabled. Please choose an enabled exchange.\n" + formatOptionList("Available exchanges:", exchanges), true
|
||||
}
|
||||
} else if choice := choosePreferredOption(exchanges); choice != nil {
|
||||
slots.ExchangeID = choice.ID
|
||||
slots.ExchangeName = choice.Name
|
||||
}
|
||||
}
|
||||
if slots.StrategyID == "" {
|
||||
if match := pickMentionedOption(text, strategies); match != nil {
|
||||
if match := pickOptionFromSegment(text, []string{"策略用", "策略", "strategy"}, strategies); match != nil {
|
||||
slots.StrategyID = match.ID
|
||||
slots.StrategyName = match.Name
|
||||
} else if match := pickMentionedOption(text, strategies); match != nil {
|
||||
slots.StrategyID = match.ID
|
||||
slots.StrategyName = match.Name
|
||||
} else if choice := choosePreferredOption(strategies); choice != nil {
|
||||
@@ -528,34 +791,18 @@ func (a *Agent) handleCreateTraderSkill(storeUserID string, userID int64, lang,
|
||||
slots.StrategyName = choice.Name
|
||||
}
|
||||
}
|
||||
if slots.ModelID != "" {
|
||||
setSkillDAGStep(&session, "resolve_strategy")
|
||||
}
|
||||
if slots.StrategyID != "" {
|
||||
setSkillDAGStep(&session, "maybe_confirm_start")
|
||||
}
|
||||
|
||||
if slots.AutoStart == nil && detectStartIntent(text) {
|
||||
autoStart := true
|
||||
slots.AutoStart = &autoStart
|
||||
}
|
||||
|
||||
if len(strategies) == 0 {
|
||||
a.clearSkillSession(userID)
|
||||
if lang == "zh" {
|
||||
return "当前还没有可用策略,暂时不能创建交易员。请先创建一个策略,再回来继续。", true
|
||||
}
|
||||
return "There is no strategy available yet, so I can't create a trader. Please create a strategy first.", true
|
||||
}
|
||||
if len(models) == 0 {
|
||||
a.clearSkillSession(userID)
|
||||
if lang == "zh" {
|
||||
return "当前还没有模型配置,暂时不能创建交易员。请先配置并启用一个模型。", true
|
||||
}
|
||||
return "There is no model config yet, so I can't create a trader. Please configure and enable a model first.", true
|
||||
}
|
||||
if len(exchanges) == 0 {
|
||||
a.clearSkillSession(userID)
|
||||
if lang == "zh" {
|
||||
return "当前还没有交易所配置,暂时不能创建交易员。请先配置并启用一个交易所账户。", true
|
||||
}
|
||||
return "There is no exchange config yet, so I can't create a trader. Please configure and enable an exchange first.", true
|
||||
}
|
||||
|
||||
missing := make([]string, 0, 3)
|
||||
extraLines := make([]string, 0, 3)
|
||||
if actionRequiresSlot("trader_management", "create", "name") && slots.Name == "" {
|
||||
@@ -563,15 +810,53 @@ func (a *Agent) handleCreateTraderSkill(storeUserID string, userID int64, lang,
|
||||
}
|
||||
if actionRequiresSlot("trader_management", "create", "exchange") && slots.ExchangeID == "" {
|
||||
missing = append(missing, slotDisplayName("exchange", lang))
|
||||
extraLines = append(extraLines, formatOptionList("可用交易所:", exchanges))
|
||||
if len(exchanges) == 0 {
|
||||
if lang == "zh" {
|
||||
extraLines = append(extraLines, "当前还没有可用交易所配置,请先配置并启用一个交易所账户。")
|
||||
} else {
|
||||
extraLines = append(extraLines, "There is no enabled exchange config yet. Please create and enable one first.")
|
||||
}
|
||||
} else {
|
||||
label := "Available exchanges:"
|
||||
if lang == "zh" {
|
||||
label = "可用交易所:"
|
||||
}
|
||||
extraLines = append(extraLines, formatOptionList(label, exchanges))
|
||||
}
|
||||
}
|
||||
if actionRequiresSlot("trader_management", "create", "model") && slots.ModelID == "" {
|
||||
missing = append(missing, slotDisplayName("model", lang))
|
||||
extraLines = append(extraLines, formatOptionList("可用模型:", models))
|
||||
if len(models) == 0 {
|
||||
if lang == "zh" {
|
||||
extraLines = append(extraLines, "当前还没有可用模型配置,请先配置并启用一个模型。")
|
||||
} else {
|
||||
extraLines = append(extraLines, "There is no enabled model config yet. Please create and enable one first.")
|
||||
}
|
||||
} else {
|
||||
label := "Available models:"
|
||||
if lang == "zh" {
|
||||
label = "可用模型:"
|
||||
}
|
||||
extraLines = append(extraLines, formatOptionList(label, models))
|
||||
}
|
||||
}
|
||||
if actionRequiresSlot("trader_management", "create", "strategy") && slots.StrategyID == "" {
|
||||
if slots.StrategyID == "" && (actionRequiresSlot("trader_management", "create", "strategy") || len(strategies) == 0) {
|
||||
missing = append(missing, slotDisplayName("strategy", lang))
|
||||
extraLines = append(extraLines, formatOptionList("可用策略:", strategies))
|
||||
}
|
||||
if slots.StrategyID == "" {
|
||||
if len(strategies) == 0 {
|
||||
if lang == "zh" {
|
||||
extraLines = append(extraLines, "当前还没有可用策略,请先创建一个策略。")
|
||||
} else {
|
||||
extraLines = append(extraLines, "There is no strategy available yet. Please create one first.")
|
||||
}
|
||||
} else {
|
||||
label := "Available strategies:"
|
||||
if lang == "zh" {
|
||||
label = "可用策略:"
|
||||
}
|
||||
extraLines = append(extraLines, formatOptionList(label, strategies))
|
||||
}
|
||||
}
|
||||
|
||||
if len(missing) > 0 {
|
||||
@@ -595,6 +880,7 @@ func (a *Agent) handleCreateTraderSkill(storeUserID string, userID int64, lang,
|
||||
|
||||
if slots.AutoStart != nil && *slots.AutoStart {
|
||||
session.Phase = "await_start_confirmation"
|
||||
setSkillDAGStep(&session, "await_start_confirmation")
|
||||
a.saveSkillSession(userID, session)
|
||||
if lang == "zh" {
|
||||
return fmt.Sprintf("我已经准备好创建交易员“%s”,并在创建后立即启动它。\n使用的交易所:%s\n使用的模型:%s\n使用的策略:%s\n\n这是高风险动作。回复“确认”继续,回复“先不用”则只创建不启动。",
|
||||
@@ -631,16 +917,31 @@ func (s *createTraderSkillSlots) StrategyNameOrID() string {
|
||||
|
||||
func (a *Agent) executeCreateTraderSkill(storeUserID string, userID int64, lang string, session skillSession, startAfterCreate bool) string {
|
||||
args := manageTraderArgs{
|
||||
Action: "create",
|
||||
Name: session.Slots.Name,
|
||||
AIModelID: session.Slots.ModelID,
|
||||
ExchangeID: session.Slots.ExchangeID,
|
||||
StrategyID: session.Slots.StrategyID,
|
||||
Action: "create",
|
||||
Name: session.Slots.Name,
|
||||
AIModelID: session.Slots.ModelID,
|
||||
ExchangeID: session.Slots.ExchangeID,
|
||||
StrategyID: session.Slots.StrategyID,
|
||||
}
|
||||
createRaw := a.toolCreateTrader(storeUserID, args)
|
||||
if errMsg := parseSkillError(createRaw); errMsg != "" && strings.Contains(createRaw, `"error"`) {
|
||||
session.Phase = "collecting"
|
||||
a.saveSkillSession(userID, session)
|
||||
if strings.Contains(strings.ToLower(errMsg), "exchange is disabled") {
|
||||
exchanges := a.loadExchangeOptions(storeUserID)
|
||||
if lang == "zh" {
|
||||
reply := fmt.Sprintf("创建交易员失败:你选的交易所“%s”当前已禁用,请换一个已启用的交易所。", session.Slots.ExchangeNameOrID())
|
||||
if list := formatOptionList("可用交易所:", exchanges); list != "" {
|
||||
reply += "\n" + list
|
||||
}
|
||||
return reply
|
||||
}
|
||||
reply := fmt.Sprintf("Failed to create trader: the selected exchange %q is disabled. Please choose an enabled exchange.", session.Slots.ExchangeNameOrID())
|
||||
if list := formatOptionList("Available exchanges:", exchanges); list != "" {
|
||||
reply += "\n" + list
|
||||
}
|
||||
return reply
|
||||
}
|
||||
if lang == "zh" {
|
||||
return "创建交易员失败:" + errMsg
|
||||
}
|
||||
@@ -658,6 +959,7 @@ func (a *Agent) executeCreateTraderSkill(storeUserID string, userID int64, lang
|
||||
}
|
||||
|
||||
if !startAfterCreate {
|
||||
setSkillDAGStep(&session, "execute_create_only")
|
||||
a.clearSkillSession(userID)
|
||||
if lang == "zh" {
|
||||
return fmt.Sprintf("已创建交易员“%s”。\n交易所:%s\n模型:%s\n策略:%s\n当前状态:未启动。",
|
||||
@@ -667,6 +969,7 @@ func (a *Agent) executeCreateTraderSkill(storeUserID string, userID int64, lang
|
||||
created.Trader.Name, session.Slots.ExchangeNameOrID(), session.Slots.ModelNameOrID(), session.Slots.StrategyNameOrID())
|
||||
}
|
||||
|
||||
setSkillDAGStep(&session, "execute_create_and_start")
|
||||
startRaw := a.toolStartTrader(storeUserID, created.Trader.ID)
|
||||
if errMsg := parseSkillError(startRaw); errMsg != "" && strings.Contains(startRaw, `"error"`) {
|
||||
a.clearSkillSession(userID)
|
||||
@@ -735,6 +1038,9 @@ func (a *Agent) handleModelDiagnosisSkill(storeUserID, lang, text string) string
|
||||
lines = append(lines, fmt.Sprintf("1. 当前共 %d 个模型配置,已启用 %d 个。", len(payload.ModelConfigs), enabledCount))
|
||||
lines = append(lines, "2. 检查目标模型是否同时具备 enabled、API Key、custom_api_url。")
|
||||
lines = append(lines, "3. 如果是 OpenAI / Claude / DeepSeek 等 provider,确认 model name 填的是该 provider 实际可用的模型名。")
|
||||
if excerpt := backendLogDiagnosisExcerpt(lang, text, "model"); excerpt != "" {
|
||||
lines = append(lines, excerpt)
|
||||
}
|
||||
lines = append(lines, "下一步:如果你愿意,我下一步可以继续帮你逐项检查你当前配置里的具体模型。")
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
@@ -749,6 +1055,9 @@ func (a *Agent) handleModelDiagnosisSkill(storeUserID, lang, text string) string
|
||||
lines = append(lines, "Likely cause: the model was saved, but the API key, custom_api_url, or custom_model_name does not match the provider runtime config.")
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("Check first: %d model configs exist, %d are enabled.", len(payload.ModelConfigs), enabledCount))
|
||||
if excerpt := backendLogDiagnosisExcerpt(lang, text, "model"); excerpt != "" {
|
||||
lines = append(lines, excerpt)
|
||||
}
|
||||
lines = append(lines, "Next step: verify the target model has enabled=true, a non-empty API key, a valid HTTPS custom_api_url, and a correct model name.")
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
@@ -779,6 +1088,9 @@ func (a *Agent) handleExchangeDiagnosisSkill(storeUserID, lang, text string) str
|
||||
}
|
||||
lines = append(lines, "4. 检查 API 白名单是否包含当前服务器 IP。")
|
||||
lines = append(lines, "5. 检查是否已经开启交易/合约权限。")
|
||||
if excerpt := backendLogDiagnosisExcerpt(lang, text, "exchange"); excerpt != "" {
|
||||
lines = append(lines, excerpt)
|
||||
}
|
||||
lines = append(lines, "下一步:如果你把具体报错原文贴给我,我可以按报错类型继续缩小范围。")
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
@@ -788,5 +1100,28 @@ func (a *Agent) handleExchangeDiagnosisSkill(storeUserID, lang, text string) str
|
||||
if len(exchanges) > 0 {
|
||||
lines = append(lines, "Current exchange bindings exist, so the next step is to match the exact error text to the most likely cause.")
|
||||
}
|
||||
if excerpt := backendLogDiagnosisExcerpt(lang, text, "exchange"); excerpt != "" {
|
||||
lines = append(lines, excerpt)
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func backendLogDiagnosisExcerpt(lang, text, fallbackFilter string) string {
|
||||
filter := strings.TrimSpace(text)
|
||||
if strings.TrimSpace(filter) == "" {
|
||||
filter = fallbackFilter
|
||||
}
|
||||
_, entries, err := readBackendLogEntries(8, filter, true)
|
||||
if err != nil || len(entries) == 0 {
|
||||
if filter != fallbackFilter {
|
||||
_, entries, err = readBackendLogEntries(8, fallbackFilter, true)
|
||||
}
|
||||
}
|
||||
if err != nil || len(entries) == 0 {
|
||||
return ""
|
||||
}
|
||||
if lang == "zh" {
|
||||
return "最近命中的后端错误日志:\n- " + strings.Join(entries, "\n- ")
|
||||
}
|
||||
return "Recent matching backend error logs:\n- " + strings.Join(entries, "\n- ")
|
||||
}
|
||||
|
||||
@@ -3,8 +3,12 @@ package agent
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"nofx/mcp"
|
||||
)
|
||||
|
||||
func TestCreateTraderSkillCollectsMissingFieldsAndCreatesTrader(t *testing.T) {
|
||||
@@ -61,6 +65,54 @@ func TestCreateTraderSkillCollectsMissingFieldsAndCreatesTrader(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateTraderSkillReportsAllMissingPrerequisitesAtOnce(t *testing.T) {
|
||||
a := newTestAgentWithStore(t)
|
||||
|
||||
resp, err := a.thinkAndAct(context.Background(), "user-1", 11, "zh", "帮我创建一个交易员")
|
||||
if err != nil {
|
||||
t.Fatalf("thinkAndAct() error = %v", err)
|
||||
}
|
||||
for _, want := range []string{"名称", "交易所", "模型", "策略"} {
|
||||
if !strings.Contains(resp, want) {
|
||||
t.Fatalf("expected response to mention %q, got %q", want, resp)
|
||||
}
|
||||
}
|
||||
for _, want := range []string{"当前还没有可用交易所配置", "当前还没有可用模型配置", "当前还没有可用策略"} {
|
||||
if !strings.Contains(resp, want) {
|
||||
t.Fatalf("expected response to mention prerequisite %q, got %q", want, resp)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestActiveSkillSessionYieldsToNewTopic(t *testing.T) {
|
||||
a := newTestAgentWithStore(t)
|
||||
|
||||
_ = a.toolManageStrategy("user-1", `{
|
||||
"action":"create",
|
||||
"name":"测试策略",
|
||||
"lang":"zh"
|
||||
}`)
|
||||
|
||||
resp, err := a.thinkAndAct(context.Background(), "user-1", 13, "zh", "帮我创建一个交易员")
|
||||
if err != nil {
|
||||
t.Fatalf("thinkAndAct() error = %v", err)
|
||||
}
|
||||
if !strings.Contains(resp, "还缺这些信息") {
|
||||
t.Fatalf("expected trader creation flow prompt, got %q", resp)
|
||||
}
|
||||
|
||||
resp, err = a.thinkAndAct(context.Background(), "user-1", 13, "zh", "列出我当前的策略")
|
||||
if err != nil {
|
||||
t.Fatalf("thinkAndAct() interrupt error = %v", err)
|
||||
}
|
||||
if !strings.Contains(resp, "当前策略") || !strings.Contains(resp, "测试策略") {
|
||||
t.Fatalf("expected new topic to be handled, got %q", resp)
|
||||
}
|
||||
if a.hasActiveSkillSession(13) {
|
||||
t.Fatal("expected skill session to be cleared after interruption")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateTraderSkillRequestsStartConfirmation(t *testing.T) {
|
||||
a := newTestAgentWithStore(t)
|
||||
|
||||
@@ -175,6 +227,28 @@ func TestStrategyManagementCreateAndActivateSkill(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStrategyManagementQueryCanExplainStrategyDetails(t *testing.T) {
|
||||
a := newTestAgentWithStore(t)
|
||||
|
||||
resp, err := a.thinkAndAct(context.Background(), "user-1", 12, "zh", "创建一个叫“激进的”的策略")
|
||||
if err != nil {
|
||||
t.Fatalf("thinkAndAct() create error = %v", err)
|
||||
}
|
||||
if !strings.Contains(resp, "已创建策略") {
|
||||
t.Fatalf("expected strategy create response, got %q", resp)
|
||||
}
|
||||
|
||||
resp, err = a.thinkAndAct(context.Background(), "user-1", 12, "zh", "这个策略里面的参数和prompt分别是什么样的")
|
||||
if err != nil {
|
||||
t.Fatalf("thinkAndAct() detail query error = %v", err)
|
||||
}
|
||||
for _, want := range []string{"策略“激进的”概览", "K线周期", "仓位风险", "Prompt"} {
|
||||
if !strings.Contains(resp, want) {
|
||||
t.Fatalf("expected response to mention %q, got %q", want, resp)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTraderManagementQueryAndDiagnosisSkill(t *testing.T) {
|
||||
a := newTestAgentWithStore(t)
|
||||
|
||||
@@ -234,3 +308,521 @@ func TestTraderManagementQueryAndDiagnosisSkill(t *testing.T) {
|
||||
t.Fatalf("expected trader diagnosis response, got %q", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExchangeManagementAtomicUpdates(t *testing.T) {
|
||||
a := newTestAgentWithStore(t)
|
||||
|
||||
createResp := a.toolManageExchangeConfig("user-1", `{
|
||||
"action":"create",
|
||||
"exchange_type":"okx",
|
||||
"account_name":"主账户",
|
||||
"enabled":true
|
||||
}`)
|
||||
var created struct {
|
||||
Exchange safeExchangeToolConfig `json:"exchange"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(createResp), &created); err != nil {
|
||||
t.Fatalf("unmarshal exchange response: %v", err)
|
||||
}
|
||||
|
||||
resp, err := a.thinkAndAct(context.Background(), "user-1", 14, "zh", "更新交易所,把主账户改名为备用账户")
|
||||
if err != nil {
|
||||
t.Fatalf("rename exchange error = %v", err)
|
||||
}
|
||||
if !strings.Contains(resp, "已更新交易所配置") {
|
||||
t.Fatalf("expected exchange update response, got %q", resp)
|
||||
}
|
||||
|
||||
raw := a.toolGetExchangeConfigs("user-1")
|
||||
if !strings.Contains(raw, "备用账户") {
|
||||
t.Fatalf("expected renamed exchange in list, got %s", raw)
|
||||
}
|
||||
|
||||
resp, err = a.thinkAndAct(context.Background(), "user-1", 14, "zh", "禁用这个交易所配置")
|
||||
if err != nil {
|
||||
t.Fatalf("disable exchange error = %v", err)
|
||||
}
|
||||
if !strings.Contains(resp, "已更新交易所配置") {
|
||||
t.Fatalf("expected exchange status update response, got %q", resp)
|
||||
}
|
||||
|
||||
raw = a.toolGetExchangeConfigs("user-1")
|
||||
if strings.Contains(raw, `"enabled":true`) && strings.Contains(raw, "备用账户") {
|
||||
t.Fatalf("expected exchange to be disabled, got %s", raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelManagementAtomicUpdates(t *testing.T) {
|
||||
a := newTestAgentWithStore(t)
|
||||
|
||||
createResp := a.toolManageModelConfig("user-1", `{
|
||||
"action":"create",
|
||||
"provider":"deepseek",
|
||||
"enabled":true,
|
||||
"custom_api_url":"https://api.deepseek.com/v1",
|
||||
"custom_model_name":"deepseek-chat"
|
||||
}`)
|
||||
var created struct {
|
||||
Model safeModelToolConfig `json:"model"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(createResp), &created); err != nil {
|
||||
t.Fatalf("unmarshal model response: %v", err)
|
||||
}
|
||||
|
||||
resp, err := a.thinkAndAct(context.Background(), "user-1", 15, "zh", "更新模型,把模型名称改成 deepseek-reasoner")
|
||||
if err != nil {
|
||||
t.Fatalf("rename model error = %v", err)
|
||||
}
|
||||
if !strings.Contains(resp, "已更新模型配置") {
|
||||
t.Fatalf("expected model update response, got %q", resp)
|
||||
}
|
||||
|
||||
resp, err = a.thinkAndAct(context.Background(), "user-1", 15, "zh", "更新模型,把接口地址改成 https://api.deepseek.com/beta")
|
||||
if err != nil {
|
||||
t.Fatalf("update model endpoint error = %v", err)
|
||||
}
|
||||
if !strings.Contains(resp, "已更新模型配置") {
|
||||
t.Fatalf("expected model endpoint update response, got %q", resp)
|
||||
}
|
||||
|
||||
resp, err = a.thinkAndAct(context.Background(), "user-1", 15, "zh", "禁用这个模型配置")
|
||||
if err != nil {
|
||||
t.Fatalf("disable model error = %v", err)
|
||||
}
|
||||
if !strings.Contains(resp, "已更新模型配置") {
|
||||
t.Fatalf("expected model status update response, got %q", resp)
|
||||
}
|
||||
|
||||
raw := a.toolGetModelConfigs("user-1")
|
||||
if !strings.Contains(raw, "deepseek-reasoner") || !strings.Contains(raw, "https://api.deepseek.com/beta") {
|
||||
t.Fatalf("expected updated model fields, got %s", raw)
|
||||
}
|
||||
if strings.Contains(raw, `"enabled":true`) && strings.Contains(raw, created.Model.ID) {
|
||||
t.Fatalf("expected model to be disabled, got %s", raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStrategyManagementAtomicUpdates(t *testing.T) {
|
||||
a := newTestAgentWithStore(t)
|
||||
|
||||
resp, err := a.thinkAndAct(context.Background(), "user-1", 16, "zh", "创建一个叫“激进策略C”的策略")
|
||||
if err != nil {
|
||||
t.Fatalf("create strategy error = %v", err)
|
||||
}
|
||||
if !strings.Contains(resp, "已创建策略") {
|
||||
t.Fatalf("expected strategy create response, got %q", resp)
|
||||
}
|
||||
|
||||
resp, err = a.thinkAndAct(context.Background(), "user-1", 16, "zh", "更新这个策略的prompt,把提示词改成“优先观察BTC和ETH,信号不一致时不要开仓”")
|
||||
if err != nil {
|
||||
t.Fatalf("update strategy prompt error = %v", err)
|
||||
}
|
||||
if !strings.Contains(resp, "已更新策略 prompt") {
|
||||
t.Fatalf("expected strategy prompt update response, got %q", resp)
|
||||
}
|
||||
|
||||
resp, err = a.thinkAndAct(context.Background(), "user-1", 16, "zh", "更新这个策略参数,把最大持仓改成2,最低置信度改成80,主周期改成15m,并使用15m 1h 4h")
|
||||
if err != nil {
|
||||
t.Fatalf("update strategy config error = %v", err)
|
||||
}
|
||||
if !strings.Contains(resp, "已更新策略参数") {
|
||||
t.Fatalf("expected strategy config update response, got %q", resp)
|
||||
}
|
||||
|
||||
listRaw := a.toolGetStrategies("user-1")
|
||||
if !strings.Contains(listRaw, "优先观察BTC和ETH") || !strings.Contains(listRaw, `"max_positions":2`) || !strings.Contains(listRaw, `"min_confidence":80`) || !strings.Contains(listRaw, `"primary_timeframe":"15m"`) {
|
||||
t.Fatalf("expected updated strategy config, got %s", listRaw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTraderManagementAtomicBindingUpdate(t *testing.T) {
|
||||
a := newTestAgentWithStore(t)
|
||||
|
||||
modelOpenAI := a.toolManageModelConfig("user-1", `{
|
||||
"action":"create",
|
||||
"provider":"openai",
|
||||
"enabled":true,
|
||||
"custom_api_url":"https://api.openai.com/v1",
|
||||
"custom_model_name":"gpt-5-mini"
|
||||
}`)
|
||||
var openAI struct {
|
||||
Model safeModelToolConfig `json:"model"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(modelOpenAI), &openAI); err != nil {
|
||||
t.Fatalf("unmarshal openai model: %v", err)
|
||||
}
|
||||
modelDeepSeek := a.toolManageModelConfig("user-1", `{
|
||||
"action":"create",
|
||||
"provider":"deepseek",
|
||||
"enabled":true,
|
||||
"custom_api_url":"https://api.deepseek.com/v1",
|
||||
"custom_model_name":"deepseek-chat"
|
||||
}`)
|
||||
var deepSeek struct {
|
||||
Model safeModelToolConfig `json:"model"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(modelDeepSeek), &deepSeek); err != nil {
|
||||
t.Fatalf("unmarshal deepseek model: %v", err)
|
||||
}
|
||||
|
||||
exchangeBinance := a.toolManageExchangeConfig("user-1", `{
|
||||
"action":"create",
|
||||
"exchange_type":"binance",
|
||||
"account_name":"Binance 主账户",
|
||||
"enabled":true
|
||||
}`)
|
||||
var binance struct {
|
||||
Exchange safeExchangeToolConfig `json:"exchange"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(exchangeBinance), &binance); err != nil {
|
||||
t.Fatalf("unmarshal binance exchange: %v", err)
|
||||
}
|
||||
exchangeOKX := a.toolManageExchangeConfig("user-1", `{
|
||||
"action":"create",
|
||||
"exchange_type":"okx",
|
||||
"account_name":"OKX 主账户",
|
||||
"enabled":true
|
||||
}`)
|
||||
var okx struct {
|
||||
Exchange safeExchangeToolConfig `json:"exchange"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(exchangeOKX), &okx); err != nil {
|
||||
t.Fatalf("unmarshal okx exchange: %v", err)
|
||||
}
|
||||
|
||||
strategyA := a.toolManageStrategy("user-1", `{"action":"create","name":"策略A","lang":"zh"}`)
|
||||
var stA struct {
|
||||
Strategy safeStrategyToolConfig `json:"strategy"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(strategyA), &stA); err != nil {
|
||||
t.Fatalf("unmarshal strategy A: %v", err)
|
||||
}
|
||||
strategyB := a.toolManageStrategy("user-1", `{"action":"create","name":"策略B","lang":"zh"}`)
|
||||
var stB struct {
|
||||
Strategy safeStrategyToolConfig `json:"strategy"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(strategyB), &stB); err != nil {
|
||||
t.Fatalf("unmarshal strategy B: %v", err)
|
||||
}
|
||||
|
||||
createTrader := a.toolManageTrader("user-1", `{
|
||||
"action":"create",
|
||||
"name":"实盘一号",
|
||||
"ai_model_id":"`+openAI.Model.ID+`",
|
||||
"exchange_id":"`+binance.Exchange.ID+`",
|
||||
"strategy_id":"`+stA.Strategy.ID+`"
|
||||
}`)
|
||||
var trader struct {
|
||||
Trader safeTraderToolConfig `json:"trader"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(createTrader), &trader); err != nil {
|
||||
t.Fatalf("unmarshal trader: %v", err)
|
||||
}
|
||||
|
||||
resp, err := a.thinkAndAct(context.Background(), "user-1", 17, "zh", "更新交易员绑定,把实盘一号换成 deepseek-chat、OKX 主账户 和 策略B")
|
||||
if err != nil {
|
||||
t.Fatalf("update trader bindings error = %v", err)
|
||||
}
|
||||
if !strings.Contains(resp, "已更新交易员绑定") {
|
||||
t.Fatalf("expected trader binding update response, got %q", resp)
|
||||
}
|
||||
|
||||
listRaw := a.toolListTraders("user-1")
|
||||
if !strings.Contains(listRaw, deepSeek.Model.ID) || !strings.Contains(listRaw, okx.Exchange.ID) || !strings.Contains(listRaw, stB.Strategy.ID) {
|
||||
t.Fatalf("expected trader bindings to change, got %s", listRaw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStrategyManagementDeleteAllUserStrategies(t *testing.T) {
|
||||
a := newTestAgentWithStore(t)
|
||||
|
||||
for _, name := range []string{"趋势策略A", "趋势策略B"} {
|
||||
resp := a.toolManageStrategy("user-1", `{
|
||||
"action":"create",
|
||||
"name":"`+name+`",
|
||||
"lang":"zh"
|
||||
}`)
|
||||
if strings.Contains(resp, `"error"`) {
|
||||
t.Fatalf("failed to create strategy %q: %s", name, resp)
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := a.thinkAndAct(context.Background(), "user-1", 21, "zh", "现在把所有的策略全部删除")
|
||||
if err != nil {
|
||||
t.Fatalf("thinkAndAct() bulk delete start error = %v", err)
|
||||
}
|
||||
if !strings.Contains(resp, "确认") || !strings.Contains(resp, "全部自定义策略") {
|
||||
t.Fatalf("expected bulk delete confirmation, got %q", resp)
|
||||
}
|
||||
|
||||
resp, err = a.thinkAndAct(context.Background(), "user-1", 21, "zh", "确认")
|
||||
if err != nil {
|
||||
t.Fatalf("thinkAndAct() bulk delete confirm error = %v", err)
|
||||
}
|
||||
if !strings.Contains(resp, "成功删除 2 个") {
|
||||
t.Fatalf("expected bulk delete success summary, got %q", resp)
|
||||
}
|
||||
|
||||
listResp := a.toolGetStrategies("user-1")
|
||||
if strings.Contains(listResp, "趋势策略A") || strings.Contains(listResp, "趋势策略B") {
|
||||
t.Fatalf("expected created strategies to be deleted, got %s", listResp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateTraderSkillRejectsDisabledExchangeWithClearPrompt(t *testing.T) {
|
||||
a := newTestAgentWithStore(t)
|
||||
|
||||
_ = a.toolManageModelConfig("user-1", `{
|
||||
"action":"create",
|
||||
"provider":"deepseek",
|
||||
"enabled":true,
|
||||
"api_key":"sk-test",
|
||||
"custom_api_url":"https://api.deepseek.com/v1",
|
||||
"custom_model_name":"deepseek-chat"
|
||||
}`)
|
||||
enabledExchange := a.toolManageExchangeConfig("user-1", `{
|
||||
"action":"create",
|
||||
"exchange_type":"okx",
|
||||
"account_name":"test",
|
||||
"enabled":true
|
||||
}`)
|
||||
if strings.Contains(enabledExchange, `"error"`) {
|
||||
t.Fatalf("failed to create enabled exchange: %s", enabledExchange)
|
||||
}
|
||||
anotherEnabledExchange := a.toolManageExchangeConfig("user-1", `{
|
||||
"action":"create",
|
||||
"exchange_type":"okx",
|
||||
"account_name":"lky",
|
||||
"enabled":true
|
||||
}`)
|
||||
if strings.Contains(anotherEnabledExchange, `"error"`) {
|
||||
t.Fatalf("failed to create second enabled exchange: %s", anotherEnabledExchange)
|
||||
}
|
||||
disabledExchange := a.toolManageExchangeConfig("user-1", `{
|
||||
"action":"create",
|
||||
"exchange_type":"okx",
|
||||
"account_name":"new",
|
||||
"enabled":false
|
||||
}`)
|
||||
if strings.Contains(disabledExchange, `"error"`) {
|
||||
t.Fatalf("failed to create disabled exchange: %s", disabledExchange)
|
||||
}
|
||||
_ = a.toolManageStrategy("user-1", `{"action":"create","name":"激进","lang":"zh"}`)
|
||||
|
||||
resp, err := a.thinkAndAct(context.Background(), "user-1", 24, "zh", "给我创建一个trader")
|
||||
if err != nil {
|
||||
t.Fatalf("create trader start error = %v", err)
|
||||
}
|
||||
if !strings.Contains(resp, "new(已禁用)") {
|
||||
t.Fatalf("expected disabled exchange to be labelled, got %q", resp)
|
||||
}
|
||||
|
||||
resp, err = a.thinkAndAct(context.Background(), "user-1", 24, "zh", "名称叫test,交易所用new、策略用激进")
|
||||
if err != nil {
|
||||
t.Fatalf("disabled exchange selection error = %v", err)
|
||||
}
|
||||
if !strings.Contains(resp, "当前已禁用") {
|
||||
t.Fatalf("expected disabled exchange warning, got %q", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCancelReplyExitsExchangeUpdateFlow(t *testing.T) {
|
||||
a := newTestAgentWithStore(t)
|
||||
_ = a.toolManageModelConfig("user-1", `{
|
||||
"action":"create",
|
||||
"provider":"deepseek",
|
||||
"enabled":true,
|
||||
"api_key":"sk-test",
|
||||
"custom_api_url":"https://api.deepseek.com/v1",
|
||||
"custom_model_name":"deepseek-chat"
|
||||
}`)
|
||||
|
||||
exchangeResp := a.toolManageExchangeConfig("user-1", `{
|
||||
"action":"create",
|
||||
"exchange_type":"okx",
|
||||
"account_name":"test",
|
||||
"enabled":true
|
||||
}`)
|
||||
if strings.Contains(exchangeResp, `"error"`) {
|
||||
t.Fatalf("failed to create exchange: %s", exchangeResp)
|
||||
}
|
||||
|
||||
resp, err := a.thinkAndAct(context.Background(), "user-1", 25, "zh", "把test这个交易所改一下")
|
||||
if err != nil {
|
||||
t.Fatalf("enter exchange update flow error = %v", err)
|
||||
}
|
||||
if !strings.Contains(resp, "请告诉我你要改什么") {
|
||||
t.Fatalf("expected exchange update prompt, got %q", resp)
|
||||
}
|
||||
|
||||
resp, err = a.thinkAndAct(context.Background(), "user-1", 25, "zh", "不改")
|
||||
if err != nil {
|
||||
t.Fatalf("cancel exchange flow error = %v", err)
|
||||
}
|
||||
if !strings.Contains(resp, "已取消当前流程") {
|
||||
t.Fatalf("expected flow cancellation, got %q", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifySkillSessionInputInterruptsOnDeflection(t *testing.T) {
|
||||
session := skillSession{Name: "exchange_management", Action: "update"}
|
||||
a := &Agent{}
|
||||
|
||||
if got := a.classifySkillSessionInput(context.Background(), 0, "zh", session, "你能帮我看下报错吗"); got != "interrupt" {
|
||||
t.Fatalf("expected diagnosis deflection to interrupt current skill flow, got %q", got)
|
||||
}
|
||||
if got := a.classifySkillSessionInput(context.Background(), 0, "zh", session, "换话题了大哥"); got != "cancel" {
|
||||
t.Fatalf("expected topic shift to cancel current skill flow, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
type skillSessionClassifierAIClient struct {
|
||||
lastSystemPrompt string
|
||||
lastUserPrompt string
|
||||
response string
|
||||
}
|
||||
|
||||
func (c *skillSessionClassifierAIClient) SetAPIKey(string, string, string) {}
|
||||
func (c *skillSessionClassifierAIClient) SetTimeout(time.Duration) {}
|
||||
func (c *skillSessionClassifierAIClient) CallWithMessages(string, string) (string, error) {
|
||||
return "", errors.New("unexpected CallWithMessages")
|
||||
}
|
||||
func (c *skillSessionClassifierAIClient) CallWithRequest(req *mcp.Request) (string, error) {
|
||||
if len(req.Messages) > 0 {
|
||||
c.lastSystemPrompt = req.Messages[0].Content
|
||||
}
|
||||
if len(req.Messages) > 1 {
|
||||
c.lastUserPrompt = req.Messages[1].Content
|
||||
}
|
||||
return c.response, nil
|
||||
}
|
||||
func (c *skillSessionClassifierAIClient) CallWithRequestStream(*mcp.Request, func(string)) (string, error) {
|
||||
return "", errors.New("unexpected CallWithRequestStream")
|
||||
}
|
||||
func (c *skillSessionClassifierAIClient) CallWithRequestFull(*mcp.Request) (*mcp.LLMResponse, error) {
|
||||
return nil, errors.New("unexpected CallWithRequestFull")
|
||||
}
|
||||
|
||||
func TestClassifySkillSessionInputUsesSlotExpectationWithoutLLM(t *testing.T) {
|
||||
client := &skillSessionClassifierAIClient{response: `{"decision":"interrupt"}`}
|
||||
a := &Agent{aiClient: client}
|
||||
session := skillSession{
|
||||
Name: "strategy_management",
|
||||
Action: "update_config",
|
||||
Fields: map[string]string{
|
||||
skillDAGStepField: "resolve_config_value",
|
||||
"config_field": "min_confidence",
|
||||
},
|
||||
}
|
||||
|
||||
if got := a.classifySkillSessionInput(context.Background(), 0, "zh", session, "70"); got != "continue" {
|
||||
t.Fatalf("expected numeric slot fill to continue, got %q", got)
|
||||
}
|
||||
if client.lastSystemPrompt != "" {
|
||||
t.Fatalf("expected no LLM call for direct slot expectation, got prompt %q", client.lastSystemPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifySkillSessionInputUsesLLMOnlyForAmbiguousDeflection(t *testing.T) {
|
||||
client := &skillSessionClassifierAIClient{response: `{"decision":"interrupt"}`}
|
||||
a := &Agent{
|
||||
aiClient: client,
|
||||
history: newChatHistory(10),
|
||||
}
|
||||
session := skillSession{
|
||||
Name: "exchange_management",
|
||||
Action: "update",
|
||||
Fields: map[string]string{
|
||||
skillDAGStepField: "collect_account_name",
|
||||
},
|
||||
}
|
||||
|
||||
if got := a.classifySkillSessionInput(context.Background(), 0, "zh", session, "你能帮我看下报错吗"); got != "interrupt" {
|
||||
t.Fatalf("expected ambiguous deflection to interrupt, got %q", got)
|
||||
}
|
||||
if !strings.Contains(client.lastSystemPrompt, "classify one user message while a NOFXi structured management flow is active") {
|
||||
t.Fatalf("expected LLM classifier prompt, got %q", client.lastSystemPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifySkillSessionInputUsesLLMForUnmatchedActiveSessionInput(t *testing.T) {
|
||||
client := &skillSessionClassifierAIClient{response: `{"decision":"continue"}`}
|
||||
a := &Agent{
|
||||
aiClient: client,
|
||||
history: newChatHistory(10),
|
||||
}
|
||||
session := skillSession{
|
||||
Name: "model_management",
|
||||
Action: "create",
|
||||
Fields: map[string]string{
|
||||
skillDAGStepField: "collect_optional_fields",
|
||||
"provider": "openai",
|
||||
},
|
||||
}
|
||||
|
||||
if got := a.classifySkillSessionInput(context.Background(), 0, "zh", session, "新增一个"); got != "continue" {
|
||||
t.Fatalf("expected unmatched active-session input to follow LLM decision, got %q", got)
|
||||
}
|
||||
if !strings.Contains(client.lastSystemPrompt, "classify one user message while a NOFXi structured management flow is active") {
|
||||
t.Fatalf("expected LLM classifier prompt, got %q", client.lastSystemPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStrategyManagementCanDescribeDefaultConfig(t *testing.T) {
|
||||
a := newTestAgentWithStore(t)
|
||||
_ = a.toolManageModelConfig("user-1", `{
|
||||
"action":"create",
|
||||
"provider":"deepseek",
|
||||
"enabled":true,
|
||||
"api_key":"sk-test",
|
||||
"custom_api_url":"https://api.deepseek.com/v1",
|
||||
"custom_model_name":"deepseek-chat"
|
||||
}`)
|
||||
|
||||
resp, err := a.thinkAndAct(context.Background(), "user-1", 22, "zh", "看一下默认配置")
|
||||
if err != nil {
|
||||
t.Fatalf("thinkAndAct() default config error = %v", err)
|
||||
}
|
||||
if !strings.Contains(resp, "默认策略模板") || !strings.Contains(resp, "最低置信度") {
|
||||
t.Fatalf("expected default strategy config response, got %q", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStrategyManagementSupportsMultiFieldConfigUpdate(t *testing.T) {
|
||||
a := newTestAgentWithStore(t)
|
||||
_ = a.toolManageModelConfig("user-1", `{
|
||||
"action":"create",
|
||||
"provider":"deepseek",
|
||||
"enabled":true,
|
||||
"api_key":"sk-test",
|
||||
"custom_api_url":"https://api.deepseek.com/v1",
|
||||
"custom_model_name":"deepseek-chat"
|
||||
}`)
|
||||
|
||||
createResp := a.toolManageStrategy("user-1", `{
|
||||
"action":"create",
|
||||
"name":"趋势策略A",
|
||||
"lang":"zh"
|
||||
}`)
|
||||
if strings.Contains(createResp, `"error"`) {
|
||||
t.Fatalf("failed to create strategy: %s", createResp)
|
||||
}
|
||||
|
||||
resp, err := a.thinkAndAct(context.Background(), "user-1", 23, "zh", "把趋势策略A的最小置信度改成70,核心指标都全选")
|
||||
if err != nil {
|
||||
t.Fatalf("thinkAndAct() multi-field update error = %v", err)
|
||||
}
|
||||
if !strings.Contains(resp, "最小置信度") || !strings.Contains(resp, "EMA") {
|
||||
t.Fatalf("expected multi-field update confirmation, got %q", resp)
|
||||
}
|
||||
|
||||
strategiesRaw := a.toolGetStrategies("user-1")
|
||||
if !strings.Contains(strategiesRaw, `"min_confidence":70`) ||
|
||||
!strings.Contains(strategiesRaw, `"enable_ema":true`) ||
|
||||
!strings.Contains(strategiesRaw, `"enable_macd":true`) ||
|
||||
!strings.Contains(strategiesRaw, `"enable_rsi":true`) ||
|
||||
!strings.Contains(strategiesRaw, `"enable_atr":true`) ||
|
||||
!strings.Contains(strategiesRaw, `"enable_boll":true`) {
|
||||
t.Fatalf("expected strategy config to include updated confidence and indicators, got %s", strategiesRaw)
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,7 +4,10 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"nofx/store"
|
||||
)
|
||||
|
||||
var urlPattern = regexp.MustCompile(`https://[^\s"'<>]+`)
|
||||
@@ -15,7 +18,7 @@ func detectTraderManagementIntent(text string) bool {
|
||||
return false
|
||||
}
|
||||
return containsAny(lower, []string{"交易员", "trader", "agent"}) &&
|
||||
containsAny(lower, []string{"修改", "编辑", "更新", "删除", "启动", "停止", "查看", "查询", "列出", "rename", "update", "delete", "start", "stop", "list", "show"})
|
||||
containsAny(lower, []string{"修改", "编辑", "更新", "改", "改一下", "删除", "删了", "启动", "停止", "查看", "查询", "列出", "rename", "update", "delete", "start", "stop", "list", "show"})
|
||||
}
|
||||
|
||||
func detectExchangeManagementIntent(text string) bool {
|
||||
@@ -24,7 +27,7 @@ func detectExchangeManagementIntent(text string) bool {
|
||||
return false
|
||||
}
|
||||
return containsAny(lower, []string{"交易所", "exchange", "okx", "binance", "bybit", "gate", "kucoin", "hyperliquid"}) &&
|
||||
containsAny(lower, []string{"创建", "新建", "修改", "编辑", "更新", "删除", "查询", "查看", "列出", "create", "update", "delete", "list", "show"})
|
||||
containsAny(lower, []string{"创建", "新建", "修改", "编辑", "更新", "改", "改一下", "删除", "删了", "查询", "查看", "列出", "启用", "禁用", "改名", "rename", "create", "update", "delete", "list", "show", "enable", "disable"})
|
||||
}
|
||||
|
||||
func detectModelManagementIntent(text string) bool {
|
||||
@@ -33,7 +36,7 @@ func detectModelManagementIntent(text string) bool {
|
||||
return false
|
||||
}
|
||||
return containsAny(lower, []string{"模型", "model", "provider", "deepseek", "openai", "claude", "gemini", "qwen", "kimi", "grok", "minimax"}) &&
|
||||
containsAny(lower, []string{"创建", "新建", "修改", "编辑", "更新", "删除", "查询", "查看", "列出", "create", "update", "delete", "list", "show"})
|
||||
containsAny(lower, []string{"创建", "新建", "修改", "编辑", "更新", "改", "改一下", "删除", "删了", "查询", "查看", "列出", "启用", "禁用", "改名", "rename", "create", "update", "delete", "list", "show", "enable", "disable"})
|
||||
}
|
||||
|
||||
func detectStrategyManagementIntent(text string) bool {
|
||||
@@ -41,8 +44,11 @@ func detectStrategyManagementIntent(text string) bool {
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
if wantsDefaultStrategyConfig(text) {
|
||||
return true
|
||||
}
|
||||
return containsAny(lower, []string{"策略", "strategy"}) &&
|
||||
containsAny(lower, []string{"创建", "新建", "修改", "编辑", "更新", "删除", "查询", "查看", "列出", "激活", "复制", "create", "update", "delete", "list", "show", "activate", "duplicate"})
|
||||
containsAny(lower, []string{"创建", "新建", "修改", "编辑", "更新", "改", "改一下", "改成", "改为", "删除", "删了", "查询", "查看", "列出", "激活", "复制", "参数", "配置", "详情", "详细", "prompt", "提示词", "什么样", "怎么样", "create", "update", "delete", "list", "show", "activate", "duplicate", "detail", "details", "config", "configuration", "parameter", "prompt", "what kind"})
|
||||
}
|
||||
|
||||
func detectTraderDiagnosisSkill(text string) bool {
|
||||
@@ -62,8 +68,9 @@ func detectManagementAction(text string, domain string) string {
|
||||
if lower == "" {
|
||||
return ""
|
||||
}
|
||||
hasUpdateVerb := containsAny(lower, []string{"修改", "编辑", "更新", "改", "rename", "update", "切换", "换成", "换到"})
|
||||
switch {
|
||||
case containsAny(lower, []string{"删除", "删掉", "remove", "delete"}):
|
||||
case containsAny(lower, []string{"删除", "删掉", "删了", "remove", "delete"}):
|
||||
return "delete"
|
||||
case containsAny(lower, []string{"启动", "开始", "run", "start"}) && domain == "trader":
|
||||
return "start"
|
||||
@@ -73,10 +80,32 @@ func detectManagementAction(text string, domain string) string {
|
||||
return "activate"
|
||||
case containsAny(lower, []string{"复制", "duplicate"}) && domain == "strategy":
|
||||
return "duplicate"
|
||||
case containsAny(lower, []string{"改名", "重命名", "rename"}):
|
||||
return "update_name"
|
||||
case domain == "trader" && containsAny(lower, []string{"换模型", "换交易所", "换策略", "绑定", "切换模型", "切换交易所", "切换策略"}):
|
||||
return "update_bindings"
|
||||
case (domain == "exchange" || domain == "model") && containsAny(lower, []string{"启用", "禁用", "enable", "disable"}):
|
||||
return "update_status"
|
||||
case domain == "model" && hasUpdateVerb && containsAny(lower, []string{"url", "endpoint", "地址", "接口"}):
|
||||
return "update_endpoint"
|
||||
case domain == "strategy" && hasUpdateVerb && containsAny(lower, []string{"prompt", "提示词"}):
|
||||
return "update_prompt"
|
||||
case domain == "strategy" && hasUpdateVerb && containsAny(lower, []string{
|
||||
"参数", "配置", "config", "configuration", "parameter",
|
||||
"最大持仓", "最小置信度", "最低置信度", "主周期", "多周期", "时间框架",
|
||||
"btc/eth杠杆", "btc eth杠杆", "山寨币杠杆",
|
||||
"核心指标", "ema", "macd", "rsi", "atr", "boll", "bollinger", "布林",
|
||||
}):
|
||||
return "update_config"
|
||||
case containsAny(lower, []string{"修改", "编辑", "更新", "改", "rename", "update"}):
|
||||
return "update"
|
||||
case domain == "trader" && containsAny(lower, []string{"运行中的", "在跑", "running"}):
|
||||
return "query_running"
|
||||
case !containsAny(lower, []string{"创建", "新建", "create", "new"}) &&
|
||||
containsAny(lower, []string{"详情", "详细", "prompt", "提示词", "什么样", "怎么样", "detail", "details", "what kind"}):
|
||||
return "query_detail"
|
||||
case containsAny(lower, []string{"查询", "查看", "列出", "list", "show", "有哪些"}):
|
||||
return "query"
|
||||
return "query_list"
|
||||
case containsAny(lower, []string{"创建", "新建", "加一个", "create", "new"}):
|
||||
return "create"
|
||||
default:
|
||||
@@ -152,6 +181,21 @@ func fieldValue(session skillSession, key string) string {
|
||||
return strings.TrimSpace(session.Fields[key])
|
||||
}
|
||||
|
||||
func textMeansAllTargets(text string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(text))
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
return containsAny(lower, []string{
|
||||
"全部", "所有", "全都", "全部策略", "所有策略",
|
||||
"all", "all strategies", "every strategy",
|
||||
})
|
||||
}
|
||||
|
||||
func supportsBulkTargetSelection(skillName, action string) bool {
|
||||
return skillName == "strategy_management" && action == "delete"
|
||||
}
|
||||
|
||||
func resolveTargetFromText(text string, options []traderSkillOption, existing *EntityReference) *EntityReference {
|
||||
if existing != nil && (existing.ID != "" || existing.Name != "") {
|
||||
return existing
|
||||
@@ -173,6 +217,18 @@ func (a *Agent) handleTraderManagementSkill(storeUserID string, userID int64, la
|
||||
if action == "" || action == "create" {
|
||||
return "", false
|
||||
}
|
||||
if action == "query_running" {
|
||||
answer := formatReadFastPathResponse(lang, "list_traders", a.toolListTraders(storeUserID))
|
||||
return applyTraderQueryFilter(lang, answer, a.toolListTraders(storeUserID), "running_only"), true
|
||||
}
|
||||
if action == "query_detail" {
|
||||
options := a.loadTraderOptions(storeUserID)
|
||||
target := resolveTargetFromText(text, options, session.TargetRef)
|
||||
if detail, ok := a.describeTrader(storeUserID, lang, target); ok {
|
||||
return detail, true
|
||||
}
|
||||
return formatReadFastPathResponse(lang, "list_traders", a.toolListTraders(storeUserID)), true
|
||||
}
|
||||
return a.handleSimpleEntitySkill(storeUserID, userID, lang, text, session, "trader_management", action, a.loadTraderOptions(storeUserID))
|
||||
}
|
||||
|
||||
@@ -186,7 +242,13 @@ func (a *Agent) handleExchangeManagementSkill(storeUserID string, userID int64,
|
||||
}
|
||||
options := a.loadExchangeOptions(storeUserID)
|
||||
switch action {
|
||||
case "query":
|
||||
case "query_list":
|
||||
return formatReadFastPathResponse(lang, "get_exchange_configs", a.toolGetExchangeConfigs(storeUserID)), true
|
||||
case "query_detail":
|
||||
target := resolveTargetFromText(text, options, session.TargetRef)
|
||||
if detail, ok := a.describeExchange(storeUserID, lang, target); ok {
|
||||
return detail, true
|
||||
}
|
||||
return formatReadFastPathResponse(lang, "get_exchange_configs", a.toolGetExchangeConfigs(storeUserID)), true
|
||||
case "create":
|
||||
return a.handleExchangeCreateSkill(storeUserID, userID, lang, text, session), true
|
||||
@@ -205,7 +267,13 @@ func (a *Agent) handleModelManagementSkill(storeUserID string, userID int64, lan
|
||||
}
|
||||
options := a.loadEnabledModelOptions(storeUserID)
|
||||
switch action {
|
||||
case "query":
|
||||
case "query_list":
|
||||
return formatReadFastPathResponse(lang, "get_model_configs", a.toolGetModelConfigs(storeUserID)), true
|
||||
case "query_detail":
|
||||
target := resolveTargetFromText(text, options, session.TargetRef)
|
||||
if detail, ok := a.describeModel(storeUserID, lang, target); ok {
|
||||
return detail, true
|
||||
}
|
||||
return formatReadFastPathResponse(lang, "get_model_configs", a.toolGetModelConfigs(storeUserID)), true
|
||||
case "create":
|
||||
return a.handleModelCreateSkill(storeUserID, userID, lang, text, session), true
|
||||
@@ -219,12 +287,24 @@ func (a *Agent) handleStrategyManagementSkill(storeUserID string, userID int64,
|
||||
if session.Name == "strategy_management" && session.Action != "" {
|
||||
action = session.Action
|
||||
}
|
||||
if action == "" && wantsStrategyDetails(text) {
|
||||
action = "query_detail"
|
||||
}
|
||||
if action == "" {
|
||||
return "", false
|
||||
}
|
||||
options := a.loadStrategyOptions(storeUserID)
|
||||
switch action {
|
||||
case "query":
|
||||
case "query_detail":
|
||||
if wantsDefaultStrategyConfig(text) {
|
||||
return a.describeDefaultStrategyConfig(lang), true
|
||||
}
|
||||
target := resolveTargetFromText(text, options, session.TargetRef)
|
||||
if detail, ok := a.describeStrategy(storeUserID, lang, target); ok {
|
||||
return detail, true
|
||||
}
|
||||
return formatReadFastPathResponse(lang, "get_strategies", a.toolGetStrategies(storeUserID)), true
|
||||
case "query_list":
|
||||
return formatReadFastPathResponse(lang, "get_strategies", a.toolGetStrategies(storeUserID)), true
|
||||
case "create":
|
||||
return a.handleStrategyCreateSkill(storeUserID, userID, lang, text, session), true
|
||||
@@ -233,6 +313,350 @@ func (a *Agent) handleStrategyManagementSkill(storeUserID string, userID int64,
|
||||
}
|
||||
}
|
||||
|
||||
func wantsStrategyDetails(text string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(text))
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
return containsAny(lower, []string{
|
||||
"什么样", "怎么样", "详情", "详细", "参数", "配置", "prompt", "提示词",
|
||||
"what kind", "details", "detail", "config", "configuration", "parameter", "prompt",
|
||||
})
|
||||
}
|
||||
|
||||
func wantsDefaultStrategyConfig(text string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(text))
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
return containsAny(lower, []string{
|
||||
"默认配置", "默认策略", "默认模板", "模板配置",
|
||||
"default config", "default strategy", "default template",
|
||||
})
|
||||
}
|
||||
|
||||
func (a *Agent) describeStrategy(storeUserID, lang string, target *EntityReference) (string, bool) {
|
||||
if a.store == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
var strategy *store.Strategy
|
||||
var err error
|
||||
if target != nil && strings.TrimSpace(target.ID) != "" {
|
||||
strategy, err = a.store.Strategy().Get(storeUserID, strings.TrimSpace(target.ID))
|
||||
} else if target != nil && strings.TrimSpace(target.Name) != "" {
|
||||
strategies, listErr := a.store.Strategy().List(storeUserID)
|
||||
if listErr != nil {
|
||||
return "", false
|
||||
}
|
||||
for _, item := range strategies {
|
||||
if item != nil && strings.EqualFold(strings.TrimSpace(item.Name), strings.TrimSpace(target.Name)) {
|
||||
strategy = item
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
strategies, listErr := a.store.Strategy().List(storeUserID)
|
||||
if listErr != nil || len(strategies) != 1 {
|
||||
return "", false
|
||||
}
|
||||
strategy = strategies[0]
|
||||
}
|
||||
if err != nil || strategy == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
var cfg store.StrategyConfig
|
||||
if strings.TrimSpace(strategy.Config) != "" {
|
||||
_ = json.Unmarshal([]byte(strategy.Config), &cfg)
|
||||
}
|
||||
|
||||
return formatStrategyDetailResponse(lang, strategy, cfg), true
|
||||
}
|
||||
|
||||
func formatStrategyDetailResponse(lang string, strategy *store.Strategy, cfg store.StrategyConfig) string {
|
||||
name := strings.TrimSpace(strategy.Name)
|
||||
if name == "" {
|
||||
name = strings.TrimSpace(strategy.ID)
|
||||
}
|
||||
|
||||
sourceBits := make([]string, 0, 4)
|
||||
if strings.TrimSpace(cfg.CoinSource.SourceType) != "" {
|
||||
sourceBits = append(sourceBits, cfg.CoinSource.SourceType)
|
||||
}
|
||||
if cfg.CoinSource.UseAI500 {
|
||||
sourceBits = append(sourceBits, fmt.Sprintf("AI500=%d", cfg.CoinSource.AI500Limit))
|
||||
}
|
||||
if cfg.CoinSource.UseOITop {
|
||||
sourceBits = append(sourceBits, fmt.Sprintf("OITop=%d", cfg.CoinSource.OITopLimit))
|
||||
}
|
||||
if cfg.CoinSource.UseOILow {
|
||||
sourceBits = append(sourceBits, fmt.Sprintf("OILow=%d", cfg.CoinSource.OILowLimit))
|
||||
}
|
||||
if len(cfg.CoinSource.StaticCoins) > 0 {
|
||||
sourceBits = append(sourceBits, "static="+strings.Join(cfg.CoinSource.StaticCoins, ","))
|
||||
}
|
||||
|
||||
timeframes := append([]string(nil), cfg.Indicators.Klines.SelectedTimeframes...)
|
||||
if len(timeframes) == 0 {
|
||||
timeframes = cleanStringList([]string{cfg.Indicators.Klines.PrimaryTimeframe, cfg.Indicators.Klines.LongerTimeframe})
|
||||
}
|
||||
|
||||
indicatorBits := make([]string, 0, 8)
|
||||
if cfg.Indicators.EnableRawKlines {
|
||||
indicatorBits = append(indicatorBits, "raw_klines")
|
||||
}
|
||||
if cfg.Indicators.EnableVolume {
|
||||
indicatorBits = append(indicatorBits, "volume")
|
||||
}
|
||||
if cfg.Indicators.EnableOI {
|
||||
indicatorBits = append(indicatorBits, "oi")
|
||||
}
|
||||
if cfg.Indicators.EnableFundingRate {
|
||||
indicatorBits = append(indicatorBits, "funding_rate")
|
||||
}
|
||||
if cfg.Indicators.EnableEMA {
|
||||
indicatorBits = append(indicatorBits, "ema")
|
||||
}
|
||||
if cfg.Indicators.EnableMACD {
|
||||
indicatorBits = append(indicatorBits, "macd")
|
||||
}
|
||||
if cfg.Indicators.EnableRSI {
|
||||
indicatorBits = append(indicatorBits, "rsi")
|
||||
}
|
||||
if cfg.Indicators.EnableATR {
|
||||
indicatorBits = append(indicatorBits, "atr")
|
||||
}
|
||||
if cfg.Indicators.EnableBOLL {
|
||||
indicatorBits = append(indicatorBits, "boll")
|
||||
}
|
||||
sort.Strings(indicatorBits)
|
||||
|
||||
promptBits := make([]string, 0, 5)
|
||||
if strings.TrimSpace(cfg.PromptSections.RoleDefinition) != "" {
|
||||
promptBits = append(promptBits, "role_definition")
|
||||
}
|
||||
if strings.TrimSpace(cfg.PromptSections.TradingFrequency) != "" {
|
||||
promptBits = append(promptBits, "trading_frequency")
|
||||
}
|
||||
if strings.TrimSpace(cfg.PromptSections.EntryStandards) != "" {
|
||||
promptBits = append(promptBits, "entry_standards")
|
||||
}
|
||||
if strings.TrimSpace(cfg.PromptSections.DecisionProcess) != "" {
|
||||
promptBits = append(promptBits, "decision_process")
|
||||
}
|
||||
|
||||
customPrompt := strings.TrimSpace(cfg.CustomPrompt)
|
||||
customPromptPreview := customPrompt
|
||||
if len([]rune(customPromptPreview)) > 120 {
|
||||
runes := []rune(customPromptPreview)
|
||||
customPromptPreview = string(runes[:120]) + "..."
|
||||
}
|
||||
|
||||
if lang == "zh" {
|
||||
lines := []string{
|
||||
fmt.Sprintf("策略“%s”概览:", name),
|
||||
fmt.Sprintf("- 类型:%s", defaultIfEmpty(strings.TrimSpace(cfg.StrategyType), "ai_trading")),
|
||||
fmt.Sprintf("- 语言:%s", defaultIfEmpty(strings.TrimSpace(cfg.Language), "zh")),
|
||||
}
|
||||
if strings.TrimSpace(strategy.Description) != "" {
|
||||
lines = append(lines, fmt.Sprintf("- 描述:%s", strings.TrimSpace(strategy.Description)))
|
||||
}
|
||||
if len(sourceBits) > 0 {
|
||||
lines = append(lines, "- 标的来源:"+strings.Join(sourceBits, " | "))
|
||||
}
|
||||
if len(timeframes) > 0 {
|
||||
lines = append(lines, "- K线周期:"+strings.Join(timeframes, " / "))
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("- 仓位风险:最多持仓 %d,BTC/ETH 最大杠杆 %d,山寨最大杠杆 %d,最低置信度 %d",
|
||||
cfg.RiskControl.MaxPositions, cfg.RiskControl.BTCETHMaxLeverage, cfg.RiskControl.AltcoinMaxLeverage, cfg.RiskControl.MinConfidence))
|
||||
if len(indicatorBits) > 0 {
|
||||
lines = append(lines, "- 已启用指标:"+strings.Join(indicatorBits, "、"))
|
||||
}
|
||||
if len(promptBits) > 0 {
|
||||
lines = append(lines, "- Prompt 模块:"+strings.Join(promptBits, "、"))
|
||||
}
|
||||
if customPromptPreview != "" {
|
||||
lines = append(lines, "- 自定义 Prompt:"+customPromptPreview)
|
||||
} else {
|
||||
lines = append(lines, "- 自定义 Prompt:当前为空,主要使用策略模板内置 prompt sections。")
|
||||
}
|
||||
lines = append(lines, "- 如果你要,我还可以继续展开这条策略的完整参数 JSON,或者逐段解释它的 prompt。")
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
lines := []string{
|
||||
fmt.Sprintf("Strategy %q overview:", name),
|
||||
fmt.Sprintf("- Type: %s", defaultIfEmpty(strings.TrimSpace(cfg.StrategyType), "ai_trading")),
|
||||
fmt.Sprintf("- Language: %s", defaultIfEmpty(strings.TrimSpace(cfg.Language), "en")),
|
||||
}
|
||||
if strings.TrimSpace(strategy.Description) != "" {
|
||||
lines = append(lines, fmt.Sprintf("- Description: %s", strings.TrimSpace(strategy.Description)))
|
||||
}
|
||||
if len(sourceBits) > 0 {
|
||||
lines = append(lines, "- Coin source: "+strings.Join(sourceBits, " | "))
|
||||
}
|
||||
if len(timeframes) > 0 {
|
||||
lines = append(lines, "- Timeframes: "+strings.Join(timeframes, " / "))
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("- Risk: max positions %d, BTC/ETH max leverage %d, alt max leverage %d, min confidence %d",
|
||||
cfg.RiskControl.MaxPositions, cfg.RiskControl.BTCETHMaxLeverage, cfg.RiskControl.AltcoinMaxLeverage, cfg.RiskControl.MinConfidence))
|
||||
if len(indicatorBits) > 0 {
|
||||
lines = append(lines, "- Enabled indicators: "+strings.Join(indicatorBits, ", "))
|
||||
}
|
||||
if len(promptBits) > 0 {
|
||||
lines = append(lines, "- Prompt modules: "+strings.Join(promptBits, ", "))
|
||||
}
|
||||
if customPromptPreview != "" {
|
||||
lines = append(lines, "- Custom prompt: "+customPromptPreview)
|
||||
} else {
|
||||
lines = append(lines, "- Custom prompt: empty right now; it mainly uses the built-in prompt sections from the strategy template.")
|
||||
}
|
||||
lines = append(lines, "- I can also expand the full strategy config JSON or walk through the prompt section by section.")
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func (a *Agent) describeDefaultStrategyConfig(lang string) string {
|
||||
if lang != "zh" {
|
||||
lang = "en"
|
||||
}
|
||||
cfg := store.GetDefaultStrategyConfig(lang)
|
||||
name := "Default Strategy Template"
|
||||
description := "System default strategy configuration template"
|
||||
if lang == "zh" {
|
||||
name = "默认策略模板"
|
||||
description = "系统默认策略配置模板"
|
||||
}
|
||||
return formatStrategyDetailResponse(lang, &store.Strategy{
|
||||
ID: "default_strategy_template",
|
||||
Name: name,
|
||||
Description: description,
|
||||
}, cfg)
|
||||
}
|
||||
|
||||
func (a *Agent) describeTrader(storeUserID, lang string, target *EntityReference) (string, bool) {
|
||||
raw := a.toolListTraders(storeUserID)
|
||||
var payload struct {
|
||||
Traders []safeTraderToolConfig `json:"traders"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(raw), &payload); err != nil {
|
||||
return "", false
|
||||
}
|
||||
trader := findTraderByReference(payload.Traders, target)
|
||||
if trader == nil {
|
||||
if len(payload.Traders) != 1 {
|
||||
return "", false
|
||||
}
|
||||
trader = &payload.Traders[0]
|
||||
}
|
||||
if lang == "zh" {
|
||||
status := "未运行"
|
||||
if trader.IsRunning {
|
||||
status = "运行中"
|
||||
}
|
||||
return fmt.Sprintf("交易员“%s”详情:\n- 状态:%s\n- 模型:%s\n- 交易所:%s\n- 策略:%s\n- 扫描间隔:%d 分钟\n- 初始余额:%.2f",
|
||||
trader.Name, status, trader.AIModelID, trader.ExchangeID, defaultIfEmpty(trader.StrategyID, "未绑定"), trader.ScanIntervalMinutes, trader.InitialBalance), true
|
||||
}
|
||||
status := "stopped"
|
||||
if trader.IsRunning {
|
||||
status = "running"
|
||||
}
|
||||
return fmt.Sprintf("Trader %q details:\n- Status: %s\n- Model: %s\n- Exchange: %s\n- Strategy: %s\n- Scan interval: %d minutes\n- Initial balance: %.2f",
|
||||
trader.Name, status, trader.AIModelID, trader.ExchangeID, defaultIfEmpty(trader.StrategyID, "none"), trader.ScanIntervalMinutes, trader.InitialBalance), true
|
||||
}
|
||||
|
||||
func (a *Agent) describeExchange(storeUserID, lang string, target *EntityReference) (string, bool) {
|
||||
raw := a.toolGetExchangeConfigs(storeUserID)
|
||||
var payload struct {
|
||||
ExchangeConfigs []safeExchangeToolConfig `json:"exchange_configs"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(raw), &payload); err != nil {
|
||||
return "", false
|
||||
}
|
||||
exchange := findExchangeByReference(payload.ExchangeConfigs, target)
|
||||
if exchange == nil {
|
||||
if len(payload.ExchangeConfigs) != 1 {
|
||||
return "", false
|
||||
}
|
||||
exchange = &payload.ExchangeConfigs[0]
|
||||
}
|
||||
if lang == "zh" {
|
||||
return fmt.Sprintf("交易所配置“%s”详情:\n- 交易所:%s\n- 已启用:%t\n- API Key:%t\n- Secret:%t\n- Passphrase:%t\n- Testnet:%t",
|
||||
defaultIfEmpty(exchange.AccountName, exchange.ID), exchange.ExchangeType, exchange.Enabled, exchange.HasAPIKey, exchange.HasSecretKey, exchange.HasPassphrase, exchange.Testnet), true
|
||||
}
|
||||
return fmt.Sprintf("Exchange config %q details:\n- Exchange: %s\n- Enabled: %t\n- API key present: %t\n- Secret present: %t\n- Passphrase present: %t\n- Testnet: %t",
|
||||
defaultIfEmpty(exchange.AccountName, exchange.ID), exchange.ExchangeType, exchange.Enabled, exchange.HasAPIKey, exchange.HasSecretKey, exchange.HasPassphrase, exchange.Testnet), true
|
||||
}
|
||||
|
||||
func (a *Agent) describeModel(storeUserID, lang string, target *EntityReference) (string, bool) {
|
||||
raw := a.toolGetModelConfigs(storeUserID)
|
||||
var payload struct {
|
||||
ModelConfigs []safeModelToolConfig `json:"model_configs"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(raw), &payload); err != nil {
|
||||
return "", false
|
||||
}
|
||||
model := findModelByReference(payload.ModelConfigs, target)
|
||||
if model == nil {
|
||||
if len(payload.ModelConfigs) != 1 {
|
||||
return "", false
|
||||
}
|
||||
model = &payload.ModelConfigs[0]
|
||||
}
|
||||
if lang == "zh" {
|
||||
return fmt.Sprintf("模型配置“%s”详情:\n- Provider:%s\n- 已启用:%t\n- API Key:%t\n- URL:%s\n- Model Name:%s",
|
||||
defaultIfEmpty(model.Name, model.ID), model.Provider, model.Enabled, model.HasAPIKey, defaultIfEmpty(model.CustomAPIURL, "未设置"), defaultIfEmpty(model.CustomModelName, "未设置")), true
|
||||
}
|
||||
return fmt.Sprintf("Model config %q details:\n- Provider: %s\n- Enabled: %t\n- API key present: %t\n- URL: %s\n- Model name: %s",
|
||||
defaultIfEmpty(model.Name, model.ID), model.Provider, model.Enabled, model.HasAPIKey, defaultIfEmpty(model.CustomAPIURL, "not set"), defaultIfEmpty(model.CustomModelName, "not set")), true
|
||||
}
|
||||
|
||||
func findTraderByReference(items []safeTraderToolConfig, target *EntityReference) *safeTraderToolConfig {
|
||||
if target == nil {
|
||||
return nil
|
||||
}
|
||||
for i := range items {
|
||||
if strings.TrimSpace(target.ID) != "" && items[i].ID == strings.TrimSpace(target.ID) {
|
||||
return &items[i]
|
||||
}
|
||||
if strings.TrimSpace(target.Name) != "" && strings.EqualFold(strings.TrimSpace(items[i].Name), strings.TrimSpace(target.Name)) {
|
||||
return &items[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func findExchangeByReference(items []safeExchangeToolConfig, target *EntityReference) *safeExchangeToolConfig {
|
||||
if target == nil {
|
||||
return nil
|
||||
}
|
||||
for i := range items {
|
||||
name := defaultIfEmpty(items[i].AccountName, items[i].Name)
|
||||
if strings.TrimSpace(target.ID) != "" && items[i].ID == strings.TrimSpace(target.ID) {
|
||||
return &items[i]
|
||||
}
|
||||
if strings.TrimSpace(target.Name) != "" && strings.EqualFold(strings.TrimSpace(name), strings.TrimSpace(target.Name)) {
|
||||
return &items[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func findModelByReference(items []safeModelToolConfig, target *EntityReference) *safeModelToolConfig {
|
||||
if target == nil {
|
||||
return nil
|
||||
}
|
||||
for i := range items {
|
||||
if strings.TrimSpace(target.ID) != "" && items[i].ID == strings.TrimSpace(target.ID) {
|
||||
return &items[i]
|
||||
}
|
||||
if strings.TrimSpace(target.Name) != "" && strings.EqualFold(strings.TrimSpace(items[i].Name), strings.TrimSpace(target.Name)) {
|
||||
return &items[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Agent) loadTraderOptions(storeUserID string) []traderSkillOption {
|
||||
if a.store == nil {
|
||||
return nil
|
||||
@@ -252,6 +676,9 @@ func (a *Agent) handleExchangeCreateSkill(storeUserID string, userID int64, lang
|
||||
if session.Name == "" {
|
||||
session = skillSession{Name: "exchange_management", Action: "create", Phase: "collecting"}
|
||||
}
|
||||
if fieldValue(session, skillDAGStepField) == "" {
|
||||
setSkillDAGStep(&session, "resolve_exchange_type")
|
||||
}
|
||||
if isCancelSkillReply(text) {
|
||||
a.clearSkillSession(userID)
|
||||
if lang == "zh" {
|
||||
@@ -267,6 +694,7 @@ func (a *Agent) handleExchangeCreateSkill(storeUserID string, userID int64, lang
|
||||
}
|
||||
exType := fieldValue(session, "exchange_type")
|
||||
if actionRequiresSlot("exchange_management", "create", "exchange_type") && exType == "" {
|
||||
setSkillDAGStep(&session, "resolve_exchange_type")
|
||||
a.saveSkillSession(userID, session)
|
||||
if lang == "zh" {
|
||||
return "要创建交易所配置,我还需要:" + slotDisplayName("exchange_type", lang) + "。例如:OKX、Binance、Bybit。"
|
||||
@@ -277,6 +705,7 @@ func (a *Agent) handleExchangeCreateSkill(storeUserID string, userID int64, lang
|
||||
if accountName == "" {
|
||||
accountName = "Default"
|
||||
}
|
||||
setSkillDAGStep(&session, "execute_create")
|
||||
args := map[string]any{
|
||||
"action": "create",
|
||||
"exchange_type": exType,
|
||||
@@ -302,6 +731,9 @@ func (a *Agent) handleModelCreateSkill(storeUserID string, userID int64, lang, t
|
||||
if session.Name == "" {
|
||||
session = skillSession{Name: "model_management", Action: "create", Phase: "collecting"}
|
||||
}
|
||||
if fieldValue(session, skillDAGStepField) == "" {
|
||||
setSkillDAGStep(&session, "resolve_provider")
|
||||
}
|
||||
if isCancelSkillReply(text) {
|
||||
a.clearSkillSession(userID)
|
||||
if lang == "zh" {
|
||||
@@ -320,17 +752,19 @@ func (a *Agent) handleModelCreateSkill(storeUserID string, userID int64, lang, t
|
||||
}
|
||||
provider := fieldValue(session, "provider")
|
||||
if actionRequiresSlot("model_management", "create", "provider") && provider == "" {
|
||||
setSkillDAGStep(&session, "resolve_provider")
|
||||
a.saveSkillSession(userID, session)
|
||||
if lang == "zh" {
|
||||
return "要创建模型配置,我还需要:" + slotDisplayName("provider", lang) + ",例如:OpenAI、DeepSeek、Claude、Gemini。"
|
||||
}
|
||||
return "To create a model config, I need the provider first, for example OpenAI, DeepSeek, Claude, or Gemini."
|
||||
}
|
||||
setSkillDAGStep(&session, "execute_create")
|
||||
args := map[string]any{
|
||||
"action": "create",
|
||||
"provider": provider,
|
||||
"name": defaultIfEmpty(fieldValue(session, "name"), provider),
|
||||
"custom_api_url": fieldValue(session, "custom_api_url"),
|
||||
"action": "create",
|
||||
"provider": provider,
|
||||
"name": defaultIfEmpty(fieldValue(session, "name"), provider),
|
||||
"custom_api_url": fieldValue(session, "custom_api_url"),
|
||||
"custom_model_name": fieldValue(session, "custom_model_name"),
|
||||
}
|
||||
raw, _ := json.Marshal(args)
|
||||
@@ -353,6 +787,9 @@ func (a *Agent) handleStrategyCreateSkill(storeUserID string, userID int64, lang
|
||||
if session.Name == "" {
|
||||
session = skillSession{Name: "strategy_management", Action: "create", Phase: "collecting"}
|
||||
}
|
||||
if fieldValue(session, skillDAGStepField) == "" {
|
||||
setSkillDAGStep(&session, "resolve_name")
|
||||
}
|
||||
if isCancelSkillReply(text) {
|
||||
a.clearSkillSession(userID)
|
||||
if lang == "zh" {
|
||||
@@ -371,12 +808,14 @@ func (a *Agent) handleStrategyCreateSkill(storeUserID string, userID int64, lang
|
||||
}
|
||||
}
|
||||
if actionRequiresSlot("strategy_management", "create", "name") && name == "" {
|
||||
setSkillDAGStep(&session, "resolve_name")
|
||||
a.saveSkillSession(userID, session)
|
||||
if lang == "zh" {
|
||||
return "要创建策略,我还需要:" + slotDisplayName("name", lang) + "。你可以直接说:创建一个叫“趋势策略A”的策略。"
|
||||
}
|
||||
return "To create a strategy, I need a strategy name. You can say: create a strategy called 'Trend A'."
|
||||
}
|
||||
setSkillDAGStep(&session, "execute_create")
|
||||
args := map[string]any{"action": "create", "name": name, "lang": "zh"}
|
||||
raw, _ := json.Marshal(args)
|
||||
resp := a.toolManageStrategy(storeUserID, string(raw))
|
||||
@@ -408,22 +847,65 @@ func (a *Agent) handleSimpleEntitySkill(storeUserID string, userID int64, lang,
|
||||
if session.Name != skillName || session.Action != action {
|
||||
return "", false
|
||||
}
|
||||
session.TargetRef = resolveTargetFromText(text, options, session.TargetRef)
|
||||
if session.TargetRef == nil && action != "query" {
|
||||
a.saveSkillSession(userID, session)
|
||||
label := formatOptionList("可选对象:", options)
|
||||
if lang == "zh" {
|
||||
reply := "我还需要你明确要操作的是哪一个对象。"
|
||||
|
||||
if dag, ok := getSkillDAG(skillName, action); ok && len(dag.Steps) > 0 {
|
||||
currentStep, _ := currentSkillDAGStep(session)
|
||||
if currentStep.ID == "resolve_target" {
|
||||
if supportsBulkTargetSelection(skillName, action) && textMeansAllTargets(text) {
|
||||
setField(&session, "bulk_scope", "all")
|
||||
advanceSkillDAGStep(&session, currentStep.ID)
|
||||
} else {
|
||||
session.TargetRef = resolveTargetFromText(text, options, session.TargetRef)
|
||||
}
|
||||
if session.TargetRef == nil {
|
||||
if !(supportsBulkTargetSelection(skillName, action) && fieldValue(session, "bulk_scope") == "all") {
|
||||
setSkillDAGStep(&session, "resolve_target")
|
||||
a.saveSkillSession(userID, session)
|
||||
label := "可选对象:"
|
||||
if lang != "zh" {
|
||||
label = "Available targets:"
|
||||
}
|
||||
optionList := formatOptionList(label, options)
|
||||
if lang == "zh" {
|
||||
reply := "当前这一步需要先确定目标对象。请告诉我你要操作哪一个。"
|
||||
if optionList != "" {
|
||||
reply += "\n" + optionList
|
||||
}
|
||||
return reply, true
|
||||
}
|
||||
reply := "This step needs a target object first. Tell me which one to operate on."
|
||||
if optionList != "" {
|
||||
reply += "\n" + optionList
|
||||
}
|
||||
return reply, true
|
||||
}
|
||||
}
|
||||
if fieldValue(session, skillDAGStepField) == currentStep.ID {
|
||||
advanceSkillDAGStep(&session, currentStep.ID)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if supportsBulkTargetSelection(skillName, action) && textMeansAllTargets(text) {
|
||||
setField(&session, "bulk_scope", "all")
|
||||
} else {
|
||||
session.TargetRef = resolveTargetFromText(text, options, session.TargetRef)
|
||||
}
|
||||
if session.TargetRef == nil && fieldValue(session, "bulk_scope") != "all" && action != "query" && action != "query_list" && action != "query_detail" && action != "query_running" {
|
||||
a.saveSkillSession(userID, session)
|
||||
label := formatOptionList("可选对象:", options)
|
||||
if lang == "zh" {
|
||||
reply := "我还需要你明确要操作的是哪一个对象。"
|
||||
if label != "" {
|
||||
reply += "\n" + label
|
||||
}
|
||||
return reply, true
|
||||
}
|
||||
reply := "I still need you to specify which object to operate on."
|
||||
if label != "" {
|
||||
reply += "\n" + label
|
||||
}
|
||||
return reply, true
|
||||
}
|
||||
reply := "I still need you to specify which object to operate on."
|
||||
if label != "" {
|
||||
reply += "\n" + label
|
||||
}
|
||||
return reply, true
|
||||
}
|
||||
|
||||
switch skillName {
|
||||
|
||||
180
agent/skill_outcome.go
Normal file
180
agent/skill_outcome.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"nofx/mcp"
|
||||
)
|
||||
|
||||
const (
|
||||
skillOutcomeSuccess = "success"
|
||||
skillOutcomeNeedMoreInfo = "need_more_info"
|
||||
skillOutcomeRecoverableError = "recoverable_error"
|
||||
skillOutcomeFatalError = "fatal_error"
|
||||
skillOutcomeNotHandled = "not_handled"
|
||||
)
|
||||
|
||||
type skillOutcome struct {
|
||||
Skill string `json:"skill"`
|
||||
Action string `json:"action"`
|
||||
Status string `json:"status"`
|
||||
GoalAchieved bool `json:"goal_achieved"`
|
||||
UserMessage string `json:"user_message,omitempty"`
|
||||
ErrorCode string `json:"error_code,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Data map[string]any `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
type taskReviewDecision struct {
|
||||
Route string `json:"route"`
|
||||
Answer string `json:"answer,omitempty"`
|
||||
}
|
||||
|
||||
func normalizeAtomicSkillAction(skill, action string) string {
|
||||
action = strings.TrimSpace(strings.ToLower(action))
|
||||
switch skill {
|
||||
case "trader_management":
|
||||
switch action {
|
||||
case "query", "query_list":
|
||||
return "query_list"
|
||||
case "query_running":
|
||||
return "query_running"
|
||||
case "query_detail":
|
||||
return "query_detail"
|
||||
case "update":
|
||||
return "update_name"
|
||||
case "update_name", "update_bindings":
|
||||
return action
|
||||
}
|
||||
case "exchange_management":
|
||||
switch action {
|
||||
case "query", "query_list":
|
||||
return "query_list"
|
||||
case "query_detail":
|
||||
return "query_detail"
|
||||
case "update":
|
||||
return "update_name"
|
||||
case "update_name", "update_status":
|
||||
return action
|
||||
}
|
||||
case "model_management":
|
||||
switch action {
|
||||
case "query", "query_list":
|
||||
return "query_list"
|
||||
case "query_detail":
|
||||
return "query_detail"
|
||||
case "update":
|
||||
return "update_name"
|
||||
case "update_name", "update_endpoint", "update_status":
|
||||
return action
|
||||
}
|
||||
case "strategy_management":
|
||||
switch action {
|
||||
case "query", "query_list":
|
||||
return "query_list"
|
||||
case "query_detail":
|
||||
return "query_detail"
|
||||
case "update":
|
||||
return "update_name"
|
||||
case "update_name", "update_config", "update_prompt":
|
||||
return action
|
||||
}
|
||||
}
|
||||
return action
|
||||
}
|
||||
|
||||
func inferSkillOutcome(skill, action, answer string, activeSession skillSession, data map[string]any) skillOutcome {
|
||||
outcome := skillOutcome{
|
||||
Skill: skill,
|
||||
Action: action,
|
||||
Status: skillOutcomeSuccess,
|
||||
UserMessage: strings.TrimSpace(answer),
|
||||
Data: data,
|
||||
}
|
||||
if activeSession.Name != "" {
|
||||
outcome.Status = skillOutcomeNeedMoreInfo
|
||||
outcome.GoalAchieved = false
|
||||
return outcome
|
||||
}
|
||||
|
||||
lower := strings.ToLower(strings.TrimSpace(answer))
|
||||
switch {
|
||||
case lower == "":
|
||||
outcome.Status = skillOutcomeNotHandled
|
||||
case strings.Contains(lower, "失败") || strings.Contains(lower, "failed") || strings.Contains(lower, "error"):
|
||||
outcome.Status = skillOutcomeRecoverableError
|
||||
outcome.Error = strings.TrimSpace(answer)
|
||||
default:
|
||||
outcome.GoalAchieved = true
|
||||
}
|
||||
return outcome
|
||||
}
|
||||
|
||||
func parseTaskReviewDecision(raw string) (taskReviewDecision, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
raw = strings.TrimPrefix(raw, "```json")
|
||||
raw = strings.TrimPrefix(raw, "```")
|
||||
raw = strings.TrimSuffix(raw, "```")
|
||||
raw = strings.TrimSpace(raw)
|
||||
|
||||
var decision taskReviewDecision
|
||||
if err := json.Unmarshal([]byte(raw), &decision); err == nil {
|
||||
decision.Route = strings.TrimSpace(strings.ToLower(decision.Route))
|
||||
decision.Answer = strings.TrimSpace(decision.Answer)
|
||||
return decision, nil
|
||||
}
|
||||
start := strings.Index(raw, "{")
|
||||
end := strings.LastIndex(raw, "}")
|
||||
if start >= 0 && end > start {
|
||||
if err := json.Unmarshal([]byte(raw[start:end+1]), &decision); err == nil {
|
||||
decision.Route = strings.TrimSpace(strings.ToLower(decision.Route))
|
||||
decision.Answer = strings.TrimSpace(decision.Answer)
|
||||
return decision, nil
|
||||
}
|
||||
}
|
||||
return taskReviewDecision{}, fmt.Errorf("invalid task review json")
|
||||
}
|
||||
|
||||
func (a *Agent) reviewTaskCompletion(ctx context.Context, userID int64, lang, text string, outcome skillOutcome) (taskReviewDecision, error) {
|
||||
if a.aiClient == nil {
|
||||
if outcome.Status == skillOutcomeRecoverableError || outcome.Status == skillOutcomeFatalError || outcome.Status == skillOutcomeNotHandled {
|
||||
return taskReviewDecision{Route: "replan"}, nil
|
||||
}
|
||||
return taskReviewDecision{Route: "complete", Answer: outcome.UserMessage}, nil
|
||||
}
|
||||
|
||||
recentConversationCtx := a.buildRecentConversationContext(userID, text)
|
||||
outcomeJSON, _ := json.Marshal(outcome)
|
||||
systemPrompt := `You are the task-level Plan-Execute-Review supervisor for NOFXi.
|
||||
You are reviewing the JSON result returned by one structured skill execution.
|
||||
Return JSON only. Do not return markdown.
|
||||
|
||||
Rules:
|
||||
- Decide whether the OVERALL user task is finished, not whether the skill itself ran successfully.
|
||||
- Use route "complete" only when the user's task is now complete or the best next message is a final user-facing reply.
|
||||
- Use route "replan" when the user's task is not complete yet and the planner should continue from the new skill outcome.
|
||||
- Prefer route "replan" for recoverable errors, unmet goals, missing prerequisites, or cases where another skill/tool sequence may help.
|
||||
- If you choose "complete", produce the final user-facing answer in the user's language.
|
||||
|
||||
Return JSON with this exact shape:
|
||||
{"route":"complete|replan","answer":""}`
|
||||
userPrompt := fmt.Sprintf("Language: %s\nUser message: %s\n\nRecent conversation:\n%s\n\nSkill outcome JSON:\n%s", lang, text, recentConversationCtx, string(outcomeJSON))
|
||||
|
||||
stageCtx, cancel := withPlannerStageTimeout(ctx, directReplyTimeout)
|
||||
defer cancel()
|
||||
|
||||
raw, err := a.aiClient.CallWithRequest(&mcp.Request{
|
||||
Messages: []mcp.Message{
|
||||
mcp.NewSystemMessage(systemPrompt),
|
||||
mcp.NewUserMessage(userPrompt),
|
||||
},
|
||||
Ctx: stageCtx,
|
||||
})
|
||||
if err != nil {
|
||||
return taskReviewDecision{}, err
|
||||
}
|
||||
return parseTaskReviewDecision(raw)
|
||||
}
|
||||
422
agent/tools.go
422
agent/tools.go
@@ -1,13 +1,17 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"nofx/kernel"
|
||||
"nofx/mcp"
|
||||
"nofx/safe"
|
||||
"nofx/security"
|
||||
@@ -56,6 +60,24 @@ func buildAgentTools() []mcp.Tool {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: mcp.FunctionDef{
|
||||
Name: "get_backend_logs",
|
||||
Description: "Get recent backend log lines for a trader diagnosis. Prefer this when the user asks why a specific trader failed, stopped, or behaved unexpectedly. Returns recent matching log lines for the authenticated user's trader.",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"trader_id": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Trader id to diagnose. The backend verifies that this trader belongs to the authenticated user before returning logs.",
|
||||
},
|
||||
"limit": map[string]any{"type": "number", "description": "Maximum number of recent log lines to return. Default 30."},
|
||||
"errors_only": map[string]any{"type": "boolean", "description": "When true, only return error-like log lines. Default true."},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: mcp.FunctionDef{
|
||||
@@ -92,19 +114,19 @@ func buildAgentTools() []mcp.Tool {
|
||||
"type": "boolean",
|
||||
"description": "Whether this exchange binding should be enabled.",
|
||||
},
|
||||
"api_key": map[string]any{"type": "string"},
|
||||
"secret_key": map[string]any{"type": "string"},
|
||||
"passphrase": map[string]any{"type": "string"},
|
||||
"testnet": map[string]any{"type": "boolean"},
|
||||
"hyperliquid_wallet_addr": map[string]any{"type": "string"},
|
||||
"api_key": map[string]any{"type": "string"},
|
||||
"secret_key": map[string]any{"type": "string"},
|
||||
"passphrase": map[string]any{"type": "string"},
|
||||
"testnet": map[string]any{"type": "boolean"},
|
||||
"hyperliquid_wallet_addr": map[string]any{"type": "string"},
|
||||
"hyperliquid_unified_account": map[string]any{"type": "boolean"},
|
||||
"aster_user": map[string]any{"type": "string"},
|
||||
"aster_signer": map[string]any{"type": "string"},
|
||||
"aster_private_key": map[string]any{"type": "string"},
|
||||
"lighter_wallet_addr": map[string]any{"type": "string"},
|
||||
"lighter_private_key": map[string]any{"type": "string"},
|
||||
"aster_user": map[string]any{"type": "string"},
|
||||
"aster_signer": map[string]any{"type": "string"},
|
||||
"aster_private_key": map[string]any{"type": "string"},
|
||||
"lighter_wallet_addr": map[string]any{"type": "string"},
|
||||
"lighter_private_key": map[string]any{"type": "string"},
|
||||
"lighter_api_key_private_key": map[string]any{"type": "string"},
|
||||
"lighter_api_key_index": map[string]any{"type": "number"},
|
||||
"lighter_api_key_index": map[string]any{"type": "number"},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
},
|
||||
@@ -171,10 +193,10 @@ func buildAgentTools() []mcp.Tool {
|
||||
"type": "string",
|
||||
"enum": []string{"list", "create", "update", "delete", "activate", "duplicate", "get_default_config"},
|
||||
},
|
||||
"strategy_id": map[string]any{"type": "string"},
|
||||
"name": map[string]any{"type": "string"},
|
||||
"description": map[string]any{"type": "string"},
|
||||
"lang": map[string]any{"type": "string", "enum": []string{"zh", "en"}},
|
||||
"strategy_id": map[string]any{"type": "string"},
|
||||
"name": map[string]any{"type": "string"},
|
||||
"description": map[string]any{"type": "string"},
|
||||
"lang": map[string]any{"type": "string", "enum": []string{"zh", "en"}},
|
||||
"is_public": map[string]any{"type": "boolean"},
|
||||
"config_visible": map[string]any{"type": "boolean"},
|
||||
"config": map[string]any{"type": "object", "description": "Full or partial strategy config JSON object, depending on action."},
|
||||
@@ -199,22 +221,22 @@ func buildAgentTools() []mcp.Tool {
|
||||
"type": "string",
|
||||
"description": "Required for update, delete, start, and stop.",
|
||||
},
|
||||
"name": map[string]any{"type": "string"},
|
||||
"ai_model_id": map[string]any{"type": "string"},
|
||||
"exchange_id": map[string]any{"type": "string"},
|
||||
"strategy_id": map[string]any{"type": "string"},
|
||||
"initial_balance": map[string]any{"type": "number"},
|
||||
"scan_interval_minutes": map[string]any{"type": "number"},
|
||||
"is_cross_margin": map[string]any{"type": "boolean"},
|
||||
"show_in_competition": map[string]any{"type": "boolean"},
|
||||
"btc_eth_leverage": map[string]any{"type": "number"},
|
||||
"altcoin_leverage": map[string]any{"type": "number"},
|
||||
"trading_symbols": map[string]any{"type": "string"},
|
||||
"custom_prompt": map[string]any{"type": "string"},
|
||||
"override_base_prompt": map[string]any{"type": "boolean"},
|
||||
"name": map[string]any{"type": "string"},
|
||||
"ai_model_id": map[string]any{"type": "string"},
|
||||
"exchange_id": map[string]any{"type": "string"},
|
||||
"strategy_id": map[string]any{"type": "string"},
|
||||
"initial_balance": map[string]any{"type": "number"},
|
||||
"scan_interval_minutes": map[string]any{"type": "number"},
|
||||
"is_cross_margin": map[string]any{"type": "boolean"},
|
||||
"show_in_competition": map[string]any{"type": "boolean"},
|
||||
"btc_eth_leverage": map[string]any{"type": "number"},
|
||||
"altcoin_leverage": map[string]any{"type": "number"},
|
||||
"trading_symbols": map[string]any{"type": "string"},
|
||||
"custom_prompt": map[string]any{"type": "string"},
|
||||
"override_base_prompt": map[string]any{"type": "boolean"},
|
||||
"system_prompt_template": map[string]any{"type": "string"},
|
||||
"use_ai500": map[string]any{"type": "boolean"},
|
||||
"use_oi_top": map[string]any{"type": "boolean"},
|
||||
"use_ai500": map[string]any{"type": "boolean"},
|
||||
"use_oi_top": map[string]any{"type": "boolean"},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
},
|
||||
@@ -316,6 +338,26 @@ func buildAgentTools() []mcp.Tool {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: mcp.FunctionDef{
|
||||
Name: "get_candidate_coins",
|
||||
Description: "Get the current candidate coin list for a trader or strategy, including AI500 coin-source settings and the selected symbols.",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"trader_id": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional trader id. Prefer this when asking about a running trader.",
|
||||
},
|
||||
"strategy_id": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional strategy id. Use this when asking about a strategy template directly.",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -326,6 +368,8 @@ func (a *Agent) handleToolCall(ctx context.Context, storeUserID string, userID i
|
||||
return a.toolGetPreferences(userID)
|
||||
case "manage_preferences":
|
||||
return a.toolManagePreferences(userID, tc.Function.Arguments)
|
||||
case "get_backend_logs":
|
||||
return a.toolGetBackendLogs(storeUserID, tc.Function.Arguments)
|
||||
case "get_exchange_configs":
|
||||
return a.toolGetExchangeConfigs(storeUserID)
|
||||
case "manage_exchange_config":
|
||||
@@ -352,6 +396,8 @@ func (a *Agent) handleToolCall(ctx context.Context, storeUserID string, userID i
|
||||
return a.toolGetMarketPrice(tc.Function.Arguments)
|
||||
case "get_trade_history":
|
||||
return a.toolGetTradeHistory(tc.Function.Arguments)
|
||||
case "get_candidate_coins":
|
||||
return a.toolGetCandidateCoins(storeUserID, userID, tc.Function.Arguments)
|
||||
default:
|
||||
return fmt.Sprintf(`{"error": "unknown tool: %s"}`, tc.Function.Name)
|
||||
}
|
||||
@@ -388,21 +434,21 @@ type safeModelToolConfig struct {
|
||||
}
|
||||
|
||||
type safeTraderToolConfig struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
AIModelID string `json:"ai_model_id"`
|
||||
ExchangeID string `json:"exchange_id"`
|
||||
StrategyID string `json:"strategy_id,omitempty"`
|
||||
InitialBalance float64 `json:"initial_balance"`
|
||||
ScanIntervalMinutes int `json:"scan_interval_minutes"`
|
||||
IsRunning bool `json:"is_running"`
|
||||
IsCrossMargin bool `json:"is_cross_margin"`
|
||||
ShowInCompetition bool `json:"show_in_competition"`
|
||||
BTCETHLeverage int `json:"btc_eth_leverage,omitempty"`
|
||||
AltcoinLeverage int `json:"altcoin_leverage,omitempty"`
|
||||
TradingSymbols string `json:"trading_symbols,omitempty"`
|
||||
CustomPrompt string `json:"custom_prompt,omitempty"`
|
||||
SystemPromptTemplate string `json:"system_prompt_template,omitempty"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
AIModelID string `json:"ai_model_id"`
|
||||
ExchangeID string `json:"exchange_id"`
|
||||
StrategyID string `json:"strategy_id,omitempty"`
|
||||
InitialBalance float64 `json:"initial_balance"`
|
||||
ScanIntervalMinutes int `json:"scan_interval_minutes"`
|
||||
IsRunning bool `json:"is_running"`
|
||||
IsCrossMargin bool `json:"is_cross_margin"`
|
||||
ShowInCompetition bool `json:"show_in_competition"`
|
||||
BTCETHLeverage int `json:"btc_eth_leverage,omitempty"`
|
||||
AltcoinLeverage int `json:"altcoin_leverage,omitempty"`
|
||||
TradingSymbols string `json:"trading_symbols,omitempty"`
|
||||
CustomPrompt string `json:"custom_prompt,omitempty"`
|
||||
SystemPromptTemplate string `json:"system_prompt_template,omitempty"`
|
||||
}
|
||||
|
||||
type safeStrategyToolConfig struct {
|
||||
@@ -472,6 +518,14 @@ func safeModelForTool(model *store.AIModel) safeModelToolConfig {
|
||||
}
|
||||
}
|
||||
|
||||
func modelConfigUsable(provider, modelID, apiKey, customAPIURL, customModelName string) bool {
|
||||
if strings.TrimSpace(apiKey) == "" {
|
||||
return false
|
||||
}
|
||||
resolvedURL, resolvedModel := resolveModelRuntimeConfig(provider, customAPIURL, customModelName, modelID)
|
||||
return strings.TrimSpace(resolvedURL) != "" && strings.TrimSpace(resolvedModel) != ""
|
||||
}
|
||||
|
||||
func safeTraderForTool(trader *store.Trader, isRunning bool) safeTraderToolConfig {
|
||||
return safeTraderToolConfig{
|
||||
ID: trader.ID,
|
||||
@@ -531,29 +585,131 @@ func (a *Agent) toolGetExchangeConfigs(storeUserID string) string {
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func latestBackendLogFilePath() string {
|
||||
matches, err := filepath.Glob(filepath.Join("data", "nofx_*.log"))
|
||||
if err != nil || len(matches) == 0 {
|
||||
return ""
|
||||
}
|
||||
sort.Strings(matches)
|
||||
return matches[len(matches)-1]
|
||||
}
|
||||
|
||||
func isBackendErrorLikeLogLine(line string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(line))
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(lower, "[erro]") ||
|
||||
strings.Contains(lower, " panic") ||
|
||||
strings.Contains(lower, "🔥") ||
|
||||
strings.Contains(lower, "❌") ||
|
||||
strings.Contains(lower, " failed") ||
|
||||
strings.Contains(lower, " error") ||
|
||||
strings.Contains(lower, "invalid ")
|
||||
}
|
||||
|
||||
func readBackendLogEntries(limit int, contains string, errorsOnly bool) (string, []string, error) {
|
||||
path := latestBackendLogFilePath()
|
||||
if path == "" {
|
||||
return "", nil, fmt.Errorf("backend log file not found")
|
||||
}
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return path, nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
filter := strings.ToLower(strings.TrimSpace(contains))
|
||||
matches := make([]string, 0, max(limit, 1))
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if errorsOnly && !isBackendErrorLikeLogLine(line) {
|
||||
continue
|
||||
}
|
||||
if filter != "" && !strings.Contains(strings.ToLower(line), filter) {
|
||||
continue
|
||||
}
|
||||
matches = append(matches, line)
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return path, nil, err
|
||||
}
|
||||
if limit <= 0 {
|
||||
limit = 30
|
||||
}
|
||||
if len(matches) > limit {
|
||||
matches = matches[len(matches)-limit:]
|
||||
}
|
||||
return path, matches, nil
|
||||
}
|
||||
|
||||
func (a *Agent) toolGetBackendLogs(storeUserID, argsJSON string) string {
|
||||
var args struct {
|
||||
TraderID string `json:"trader_id"`
|
||||
Limit int `json:"limit"`
|
||||
ErrorsOnly *bool `json:"errors_only"`
|
||||
}
|
||||
if strings.TrimSpace(argsJSON) != "" {
|
||||
if err := json.Unmarshal([]byte(argsJSON), &args); err != nil {
|
||||
return fmt.Sprintf(`{"error":"invalid arguments: %s"}`, err)
|
||||
}
|
||||
}
|
||||
errorsOnly := true
|
||||
if args.ErrorsOnly != nil {
|
||||
errorsOnly = *args.ErrorsOnly
|
||||
}
|
||||
traderID := strings.TrimSpace(args.TraderID)
|
||||
if traderID == "" {
|
||||
return `{"error":"trader_id is required"}`
|
||||
}
|
||||
if a.store == nil {
|
||||
return `{"error":"store unavailable"}`
|
||||
}
|
||||
trader, err := a.store.Trader().GetByID(traderID)
|
||||
if err != nil {
|
||||
return fmt.Sprintf(`{"error":"failed to load trader: %s"}`, err)
|
||||
}
|
||||
if trader.UserID != storeUserID {
|
||||
return `{"error":"trader not found for current user"}`
|
||||
}
|
||||
path, entries, err := readBackendLogEntries(args.Limit, traderID, errorsOnly)
|
||||
if err != nil {
|
||||
return fmt.Sprintf(`{"error":"failed to read backend logs: %s"}`, err)
|
||||
}
|
||||
result, _ := json.Marshal(map[string]any{
|
||||
"trader_id": traderID,
|
||||
"log_file": path,
|
||||
"entries": entries,
|
||||
"count": len(entries),
|
||||
"errors_only": errorsOnly,
|
||||
})
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func (a *Agent) toolManageExchangeConfig(storeUserID, argsJSON string) string {
|
||||
if a.store == nil {
|
||||
return `{"error":"store unavailable"}`
|
||||
}
|
||||
var args struct {
|
||||
Action string `json:"action"`
|
||||
ExchangeID string `json:"exchange_id"`
|
||||
ExchangeType string `json:"exchange_type"`
|
||||
AccountName string `json:"account_name"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
APIKey string `json:"api_key"`
|
||||
SecretKey string `json:"secret_key"`
|
||||
Passphrase string `json:"passphrase"`
|
||||
Testnet *bool `json:"testnet"`
|
||||
HyperliquidWalletAddr string `json:"hyperliquid_wallet_addr"`
|
||||
HyperliquidUnifiedAccount *bool `json:"hyperliquid_unified_account"`
|
||||
AsterUser string `json:"aster_user"`
|
||||
AsterSigner string `json:"aster_signer"`
|
||||
AsterPrivateKey string `json:"aster_private_key"`
|
||||
LighterWalletAddr string `json:"lighter_wallet_addr"`
|
||||
LighterPrivateKey string `json:"lighter_private_key"`
|
||||
LighterAPIKeyPrivateKey string `json:"lighter_api_key_private_key"`
|
||||
LighterAPIKeyIndex *int `json:"lighter_api_key_index"`
|
||||
Action string `json:"action"`
|
||||
ExchangeID string `json:"exchange_id"`
|
||||
ExchangeType string `json:"exchange_type"`
|
||||
AccountName string `json:"account_name"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
APIKey string `json:"api_key"`
|
||||
SecretKey string `json:"secret_key"`
|
||||
Passphrase string `json:"passphrase"`
|
||||
Testnet *bool `json:"testnet"`
|
||||
HyperliquidWalletAddr string `json:"hyperliquid_wallet_addr"`
|
||||
HyperliquidUnifiedAccount *bool `json:"hyperliquid_unified_account"`
|
||||
AsterUser string `json:"aster_user"`
|
||||
AsterSigner string `json:"aster_signer"`
|
||||
AsterPrivateKey string `json:"aster_private_key"`
|
||||
LighterWalletAddr string `json:"lighter_wallet_addr"`
|
||||
LighterPrivateKey string `json:"lighter_private_key"`
|
||||
LighterAPIKeyPrivateKey string `json:"lighter_api_key_private_key"`
|
||||
LighterAPIKeyIndex *int `json:"lighter_api_key_index"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(argsJSON), &args); err != nil {
|
||||
return fmt.Sprintf(`{"error":"invalid arguments: %s"}`, err)
|
||||
@@ -802,7 +958,15 @@ func (a *Agent) toolManageModelConfig(storeUserID, argsJSON string) string {
|
||||
if strings.TrimSpace(args.CustomModelName) != "" {
|
||||
customModelName = strings.TrimSpace(args.CustomModelName)
|
||||
}
|
||||
if err := a.store.AIModel().Update(storeUserID, existing.ID, enabled, strings.TrimSpace(args.APIKey), customAPIURL, customModelName); err != nil {
|
||||
apiKey := strings.TrimSpace(args.APIKey)
|
||||
effectiveAPIKey := string(existing.APIKey)
|
||||
if apiKey != "" {
|
||||
effectiveAPIKey = apiKey
|
||||
}
|
||||
if enabled && !modelConfigUsable(existing.Provider, existing.ID, effectiveAPIKey, customAPIURL, customModelName) {
|
||||
return `{"error":"cannot enable model config before API key is configured"}`
|
||||
}
|
||||
if err := a.store.AIModel().Update(storeUserID, existing.ID, enabled, apiKey, customAPIURL, customModelName); err != nil {
|
||||
return fmt.Sprintf(`{"error":"failed to update model config: %s"}`, err)
|
||||
}
|
||||
updated, err := a.store.AIModel().Get(storeUserID, existing.ID)
|
||||
@@ -1893,6 +2057,136 @@ func (a *Agent) toolGetTradeHistory(argsJSON string) string {
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func (a *Agent) toolGetCandidateCoins(storeUserID string, userID int64, argsJSON string) string {
|
||||
if a.store == nil {
|
||||
return `{"error":"store unavailable"}`
|
||||
}
|
||||
|
||||
var args struct {
|
||||
TraderID string `json:"trader_id"`
|
||||
StrategyID string `json:"strategy_id"`
|
||||
}
|
||||
if strings.TrimSpace(argsJSON) != "" {
|
||||
if err := json.Unmarshal([]byte(argsJSON), &args); err != nil {
|
||||
return fmt.Sprintf(`{"error":"invalid arguments: %s"}`, err)
|
||||
}
|
||||
}
|
||||
|
||||
traderID := strings.TrimSpace(args.TraderID)
|
||||
strategyID := strings.TrimSpace(args.StrategyID)
|
||||
state := a.getExecutionState(userID)
|
||||
if traderID == "" && state.CurrentReferences != nil && state.CurrentReferences.Trader != nil {
|
||||
traderID = strings.TrimSpace(state.CurrentReferences.Trader.ID)
|
||||
}
|
||||
if strategyID == "" && state.CurrentReferences != nil && state.CurrentReferences.Strategy != nil {
|
||||
strategyID = strings.TrimSpace(state.CurrentReferences.Strategy.ID)
|
||||
}
|
||||
|
||||
if traderID != "" {
|
||||
return a.toolGetCandidateCoinsForTrader(storeUserID, traderID)
|
||||
}
|
||||
if strategyID != "" {
|
||||
return a.toolGetCandidateCoinsForStrategy(storeUserID, strategyID)
|
||||
}
|
||||
return `{"error":"trader_id or strategy_id is required"}`
|
||||
}
|
||||
|
||||
func (a *Agent) toolGetCandidateCoinsForTrader(storeUserID, traderID string) string {
|
||||
if a.traderManager == nil {
|
||||
return `{"error":"no trader manager configured"}`
|
||||
}
|
||||
record, err := a.store.Trader().GetFullConfig(storeUserID, traderID)
|
||||
if err != nil {
|
||||
return fmt.Sprintf(`{"error":"failed to load trader: %s"}`, err)
|
||||
}
|
||||
memTrader, err := a.traderManager.GetTrader(traderID)
|
||||
if err != nil {
|
||||
return fmt.Sprintf(`{"error":"trader is not loaded in memory: %s"}`, err)
|
||||
}
|
||||
|
||||
coins, coinErr := memTrader.GetCandidateCoins()
|
||||
cfg := memTrader.GetStrategyConfig()
|
||||
status := memTrader.GetStatus()
|
||||
isRunning, _ := status["is_running"].(bool)
|
||||
payload := map[string]any{
|
||||
"trader": safeTraderForTool(record.Trader, isRunning),
|
||||
"coin_source": candidateCoinSourceSummary(cfg),
|
||||
"candidate_count": len(coins),
|
||||
"candidate_symbols": candidateCoinSymbols(coins),
|
||||
"candidates": candidateCoinDetails(coins),
|
||||
}
|
||||
if coinErr != nil {
|
||||
payload["error"] = coinErr.Error()
|
||||
}
|
||||
result, _ := json.Marshal(payload)
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func (a *Agent) toolGetCandidateCoinsForStrategy(storeUserID, strategyID string) string {
|
||||
record, err := a.store.Strategy().Get(storeUserID, strategyID)
|
||||
if err != nil {
|
||||
return fmt.Sprintf(`{"error":"failed to load strategy: %s"}`, err)
|
||||
}
|
||||
cfg, err := record.ParseConfig()
|
||||
if err != nil {
|
||||
return fmt.Sprintf(`{"error":"failed to parse strategy config: %s"}`, err)
|
||||
}
|
||||
|
||||
engine := kernel.NewStrategyEngine(cfg)
|
||||
coins, coinErr := engine.GetCandidateCoins()
|
||||
payload := map[string]any{
|
||||
"strategy": safeStrategyForTool(record),
|
||||
"coin_source": candidateCoinSourceSummary(cfg),
|
||||
"candidate_count": len(coins),
|
||||
"candidate_symbols": candidateCoinSymbols(coins),
|
||||
"candidates": candidateCoinDetails(coins),
|
||||
}
|
||||
if coinErr != nil {
|
||||
payload["error"] = coinErr.Error()
|
||||
}
|
||||
result, _ := json.Marshal(payload)
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func candidateCoinSourceSummary(cfg *store.StrategyConfig) map[string]any {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
return map[string]any{
|
||||
"source_type": cfg.CoinSource.SourceType,
|
||||
"use_ai500": cfg.CoinSource.UseAI500,
|
||||
"ai500_limit": cfg.CoinSource.AI500Limit,
|
||||
"use_oi_top": cfg.CoinSource.UseOITop,
|
||||
"oi_top_limit": cfg.CoinSource.OITopLimit,
|
||||
"use_oi_low": cfg.CoinSource.UseOILow,
|
||||
"oi_low_limit": cfg.CoinSource.OILowLimit,
|
||||
"use_hyper_all": cfg.CoinSource.UseHyperAll,
|
||||
"use_hyper_main": cfg.CoinSource.UseHyperMain,
|
||||
"hyper_main_limit": cfg.CoinSource.HyperMainLimit,
|
||||
"static_coins": cfg.CoinSource.StaticCoins,
|
||||
"excluded_coins": cfg.CoinSource.ExcludedCoins,
|
||||
}
|
||||
}
|
||||
|
||||
func candidateCoinSymbols(coins []kernel.CandidateCoin) []string {
|
||||
out := make([]string, 0, len(coins))
|
||||
for _, coin := range coins {
|
||||
out = append(out, coin.Symbol)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func candidateCoinDetails(coins []kernel.CandidateCoin) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(coins))
|
||||
for _, coin := range coins {
|
||||
out = append(out, map[string]any{
|
||||
"symbol": coin.Symbol,
|
||||
"sources": coin.Sources,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// knownCryptoSymbols is a set of well-known cryptocurrency base symbols.
|
||||
// Without this, isStockSymbol("BTC") would incorrectly return true because
|
||||
// "BTC" is 3 uppercase letters and the suffix check only catches "BTCUSDT"-style pairs.
|
||||
|
||||
521
agent/workflow.go
Normal file
521
agent/workflow.go
Normal file
@@ -0,0 +1,521 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"nofx/mcp"
|
||||
)
|
||||
|
||||
const (
|
||||
workflowTaskPending = "pending"
|
||||
workflowTaskRunning = "running"
|
||||
workflowTaskCompleted = "completed"
|
||||
workflowTaskFailed = "failed"
|
||||
)
|
||||
|
||||
type WorkflowTask struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Skill string `json:"skill,omitempty"`
|
||||
Action string `json:"action,omitempty"`
|
||||
Request string `json:"request,omitempty"`
|
||||
DependsOn []string `json:"depends_on,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type WorkflowSession struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
OriginalRequest string `json:"original_request,omitempty"`
|
||||
Tasks []WorkflowTask `json:"tasks,omitempty"`
|
||||
UpdatedAt string `json:"updated_at,omitempty"`
|
||||
}
|
||||
|
||||
type workflowDecomposition struct {
|
||||
Tasks []WorkflowTask `json:"tasks"`
|
||||
}
|
||||
|
||||
func workflowSessionConfigKey(userID int64) string {
|
||||
return fmt.Sprintf("agent_workflow_session_%d", userID)
|
||||
}
|
||||
|
||||
func normalizeWorkflowSession(session WorkflowSession) WorkflowSession {
|
||||
session.OriginalRequest = strings.TrimSpace(session.OriginalRequest)
|
||||
normalized := make([]WorkflowTask, 0, len(session.Tasks))
|
||||
for i, task := range session.Tasks {
|
||||
task.ID = strings.TrimSpace(task.ID)
|
||||
if task.ID == "" {
|
||||
task.ID = fmt.Sprintf("task_%d", i+1)
|
||||
}
|
||||
task.Skill = strings.TrimSpace(task.Skill)
|
||||
task.Action = normalizeAtomicSkillAction(task.Skill, task.Action)
|
||||
task.Request = strings.TrimSpace(task.Request)
|
||||
task.DependsOn = cleanStringList(task.DependsOn)
|
||||
task.Status = strings.TrimSpace(task.Status)
|
||||
if task.Status == "" {
|
||||
task.Status = workflowTaskPending
|
||||
}
|
||||
task.Error = strings.TrimSpace(task.Error)
|
||||
if task.Skill == "" || task.Action == "" || task.Request == "" {
|
||||
continue
|
||||
}
|
||||
normalized = append(normalized, task)
|
||||
}
|
||||
session.Tasks = normalized
|
||||
if len(session.Tasks) == 0 {
|
||||
return WorkflowSession{}
|
||||
}
|
||||
if session.UpdatedAt == "" {
|
||||
session.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
|
||||
}
|
||||
return session
|
||||
}
|
||||
|
||||
func (a *Agent) getWorkflowSession(userID int64) WorkflowSession {
|
||||
if a.store == nil {
|
||||
return WorkflowSession{}
|
||||
}
|
||||
raw, err := a.store.GetSystemConfig(workflowSessionConfigKey(userID))
|
||||
if err != nil || strings.TrimSpace(raw) == "" {
|
||||
return WorkflowSession{}
|
||||
}
|
||||
var session WorkflowSession
|
||||
if err := json.Unmarshal([]byte(raw), &session); err != nil {
|
||||
return WorkflowSession{}
|
||||
}
|
||||
return normalizeWorkflowSession(session)
|
||||
}
|
||||
|
||||
func (a *Agent) saveWorkflowSession(userID int64, session WorkflowSession) {
|
||||
if a.store == nil {
|
||||
return
|
||||
}
|
||||
session = normalizeWorkflowSession(session)
|
||||
if len(session.Tasks) == 0 {
|
||||
_ = a.store.SetSystemConfig(workflowSessionConfigKey(userID), "")
|
||||
return
|
||||
}
|
||||
session.UserID = userID
|
||||
session.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
|
||||
data, err := json.Marshal(session)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = a.store.SetSystemConfig(workflowSessionConfigKey(userID), string(data))
|
||||
}
|
||||
|
||||
func (a *Agent) clearWorkflowSession(userID int64) {
|
||||
if a.store == nil {
|
||||
return
|
||||
}
|
||||
_ = a.store.SetSystemConfig(workflowSessionConfigKey(userID), "")
|
||||
}
|
||||
|
||||
func hasActiveWorkflowSession(session WorkflowSession) bool {
|
||||
if len(session.Tasks) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, task := range session.Tasks {
|
||||
if task.Status == workflowTaskPending || task.Status == workflowTaskRunning {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func nextRunnableWorkflowTask(session WorkflowSession) (WorkflowTask, int, bool) {
|
||||
for i, task := range session.Tasks {
|
||||
if task.Status != workflowTaskPending && task.Status != workflowTaskRunning {
|
||||
continue
|
||||
}
|
||||
depsReady := true
|
||||
for _, dep := range task.DependsOn {
|
||||
ok := false
|
||||
for _, candidate := range session.Tasks {
|
||||
if candidate.ID == dep && candidate.Status == workflowTaskCompleted {
|
||||
ok = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
depsReady = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if depsReady {
|
||||
return task, i, true
|
||||
}
|
||||
}
|
||||
return WorkflowTask{}, -1, false
|
||||
}
|
||||
|
||||
func supportedWorkflowSkill(skill, action string) bool {
|
||||
skill = strings.TrimSpace(skill)
|
||||
action = normalizeAtomicSkillAction(skill, action)
|
||||
if skill == "" || action == "" {
|
||||
return false
|
||||
}
|
||||
if _, ok := getSkillDAG(skill, action); ok {
|
||||
return true
|
||||
}
|
||||
switch skill {
|
||||
case "trader_management", "strategy_management", "model_management", "exchange_management":
|
||||
switch action {
|
||||
case "create", "query_list", "query_detail", "query_running", "activate":
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *Agent) tryWorkflowIntent(ctx context.Context, storeUserID string, userID int64, lang, text string, onEvent func(event, data string)) (string, bool, error) {
|
||||
if session := a.getWorkflowSession(userID); hasActiveWorkflowSession(session) {
|
||||
return a.handleWorkflowSession(ctx, storeUserID, userID, lang, text, session, onEvent)
|
||||
}
|
||||
|
||||
decomposition, err := a.decomposeWorkflowIntent(ctx, userID, lang, text)
|
||||
if err != nil || len(decomposition.Tasks) <= 1 {
|
||||
return "", false, err
|
||||
}
|
||||
session := WorkflowSession{
|
||||
UserID: userID,
|
||||
OriginalRequest: text,
|
||||
Tasks: decomposition.Tasks,
|
||||
}
|
||||
a.saveWorkflowSession(userID, session)
|
||||
return a.handleWorkflowSession(ctx, storeUserID, userID, lang, text, session, onEvent)
|
||||
}
|
||||
|
||||
func (a *Agent) handleWorkflowSession(ctx context.Context, storeUserID string, userID int64, lang, text string, session WorkflowSession, onEvent func(event, data string)) (string, bool, error) {
|
||||
if isExplicitFlowAbort(text) {
|
||||
a.clearSkillSession(userID)
|
||||
a.clearWorkflowSession(userID)
|
||||
if lang == "zh" {
|
||||
return "已取消当前任务流。", true, nil
|
||||
}
|
||||
return "Cancelled the current workflow.", true, nil
|
||||
}
|
||||
|
||||
if activeSkill := a.getSkillSession(userID); strings.TrimSpace(activeSkill.Name) != "" {
|
||||
answer, handled := a.tryHardSkill(ctx, storeUserID, userID, lang, text, onEvent)
|
||||
if !handled {
|
||||
return "", false, nil
|
||||
}
|
||||
session = a.getWorkflowSession(userID)
|
||||
if hasActiveWorkflowSession(session) && strings.TrimSpace(a.getSkillSession(userID).Name) == "" {
|
||||
session = markCurrentWorkflowTask(session, workflowTaskCompleted, "")
|
||||
a.saveWorkflowSession(userID, session)
|
||||
if final, done, err := a.maybeAdvanceWorkflow(ctx, storeUserID, userID, lang, session, onEvent); done || err != nil {
|
||||
if final != "" && answer != "" {
|
||||
return answer + "\n\n" + final, true, err
|
||||
}
|
||||
if answer != "" {
|
||||
return answer, true, err
|
||||
}
|
||||
return final, true, err
|
||||
}
|
||||
}
|
||||
return answer, true, nil
|
||||
}
|
||||
|
||||
return a.maybeAdvanceWorkflow(ctx, storeUserID, userID, lang, session, onEvent)
|
||||
}
|
||||
|
||||
func (a *Agent) maybeAdvanceWorkflow(ctx context.Context, storeUserID string, userID int64, lang string, session WorkflowSession, onEvent func(event, data string)) (string, bool, error) {
|
||||
task, index, ok := nextRunnableWorkflowTask(session)
|
||||
if !ok {
|
||||
summary := a.generateWorkflowSummary(ctx, userID, lang, session)
|
||||
a.clearWorkflowSession(userID)
|
||||
if summary == "" {
|
||||
if lang == "zh" {
|
||||
summary = "已完成当前任务流。"
|
||||
} else {
|
||||
summary = "Completed the current workflow."
|
||||
}
|
||||
}
|
||||
if onEvent != nil {
|
||||
onEvent(StreamEventPlan, summary)
|
||||
onEvent(StreamEventDelta, summary)
|
||||
}
|
||||
return summary, true, nil
|
||||
}
|
||||
|
||||
session.Tasks[index].Status = workflowTaskRunning
|
||||
a.saveWorkflowSession(userID, session)
|
||||
taskSession := skillSession{Name: task.Skill, Action: task.Action, Phase: "collecting"}
|
||||
a.saveSkillSession(userID, taskSession)
|
||||
|
||||
if onEvent != nil {
|
||||
onEvent(StreamEventPlan, a.formatWorkflowStatus(lang, session))
|
||||
onEvent(StreamEventTool, "workflow:"+task.Skill+":"+task.Action)
|
||||
}
|
||||
|
||||
answer, handled := a.tryHardSkill(ctx, storeUserID, userID, lang, task.Request, onEvent)
|
||||
if !handled {
|
||||
session.Tasks[index].Status = workflowTaskFailed
|
||||
session.Tasks[index].Error = "task_not_handled"
|
||||
a.saveWorkflowSession(userID, session)
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
if strings.TrimSpace(a.getSkillSession(userID).Name) == "" {
|
||||
session = a.getWorkflowSession(userID)
|
||||
session = markCurrentWorkflowTask(session, workflowTaskCompleted, "")
|
||||
a.saveWorkflowSession(userID, session)
|
||||
if more, ok, err := a.maybeAdvanceWorkflow(ctx, storeUserID, userID, lang, session, onEvent); ok || err != nil {
|
||||
if answer != "" && more != "" {
|
||||
return answer + "\n\n" + more, true, err
|
||||
}
|
||||
if answer != "" {
|
||||
return answer, true, err
|
||||
}
|
||||
return more, true, err
|
||||
}
|
||||
}
|
||||
return answer, true, nil
|
||||
}
|
||||
|
||||
func markCurrentWorkflowTask(session WorkflowSession, status, errMsg string) WorkflowSession {
|
||||
for i := range session.Tasks {
|
||||
if session.Tasks[i].Status == workflowTaskRunning {
|
||||
session.Tasks[i].Status = status
|
||||
session.Tasks[i].Error = strings.TrimSpace(errMsg)
|
||||
return session
|
||||
}
|
||||
}
|
||||
return session
|
||||
}
|
||||
|
||||
func (a *Agent) formatWorkflowStatus(lang string, session WorkflowSession) string {
|
||||
parts := make([]string, 0, len(session.Tasks))
|
||||
for _, task := range session.Tasks {
|
||||
label := task.Request
|
||||
if label == "" {
|
||||
label = task.Skill + ":" + task.Action
|
||||
}
|
||||
switch task.Status {
|
||||
case workflowTaskCompleted:
|
||||
label = "✓ " + label
|
||||
case workflowTaskRunning:
|
||||
label = "→ " + label
|
||||
default:
|
||||
label = "· " + label
|
||||
}
|
||||
parts = append(parts, label)
|
||||
}
|
||||
if lang == "zh" {
|
||||
return "任务流:" + strings.Join(parts, " | ")
|
||||
}
|
||||
return "Workflow: " + strings.Join(parts, " | ")
|
||||
}
|
||||
|
||||
func (a *Agent) generateWorkflowSummary(ctx context.Context, userID int64, lang string, session WorkflowSession) string {
|
||||
completed := make([]string, 0, len(session.Tasks))
|
||||
for _, task := range session.Tasks {
|
||||
if task.Status == workflowTaskCompleted {
|
||||
completed = append(completed, task.Request)
|
||||
}
|
||||
}
|
||||
if len(completed) == 0 {
|
||||
return ""
|
||||
}
|
||||
if a.aiClient == nil {
|
||||
if lang == "zh" {
|
||||
return "已完成这些任务:" + strings.Join(completed, ";")
|
||||
}
|
||||
return "Completed these tasks: " + strings.Join(completed, "; ")
|
||||
}
|
||||
stageCtx, cancel := withPlannerStageTimeout(ctx, directReplyTimeout)
|
||||
defer cancel()
|
||||
systemPrompt := `You are summarizing a finished workflow for NOFXi.
|
||||
Return one short user-facing summary in the user's language.
|
||||
Do not mention internal DAG, scheduler, or JSON.`
|
||||
userPrompt := fmt.Sprintf("Language: %s\nOriginal request: %s\nCompleted tasks:\n- %s", lang, session.OriginalRequest, strings.Join(completed, "\n- "))
|
||||
raw, err := a.aiClient.CallWithRequest(&mcp.Request{
|
||||
Messages: []mcp.Message{
|
||||
mcp.NewSystemMessage(systemPrompt),
|
||||
mcp.NewUserMessage(userPrompt),
|
||||
},
|
||||
Ctx: stageCtx,
|
||||
})
|
||||
if err != nil {
|
||||
if lang == "zh" {
|
||||
return "已完成这些任务:" + strings.Join(completed, ";")
|
||||
}
|
||||
return "Completed these tasks: " + strings.Join(completed, "; ")
|
||||
}
|
||||
return strings.TrimSpace(raw)
|
||||
}
|
||||
|
||||
func (a *Agent) decomposeWorkflowIntent(ctx context.Context, userID int64, lang, text string) (workflowDecomposition, error) {
|
||||
if !looksLikeMultiTaskIntent(text) {
|
||||
return workflowDecomposition{}, nil
|
||||
}
|
||||
if a.aiClient != nil {
|
||||
if dec, err := a.decomposeWorkflowIntentWithLLM(ctx, userID, lang, text); err == nil && len(dec.Tasks) > 1 {
|
||||
return dec, nil
|
||||
}
|
||||
}
|
||||
return a.decomposeWorkflowIntentFallback(text), nil
|
||||
}
|
||||
|
||||
func looksLikeMultiTaskIntent(text string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(text))
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
connectors := []string{",", ",", "然后", "再", "并且", "并", "同时", "and", "then"}
|
||||
count := 0
|
||||
for _, c := range connectors {
|
||||
if strings.Contains(lower, c) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (a *Agent) decomposeWorkflowIntentWithLLM(ctx context.Context, userID int64, lang, text string) (workflowDecomposition, error) {
|
||||
stageCtx, cancel := withPlannerStageTimeout(ctx, directReplyTimeout)
|
||||
defer cancel()
|
||||
systemPrompt := `You decompose one NOFXi user request into a small task graph.
|
||||
Return JSON only. No markdown.
|
||||
Only use these skills: trader_management, strategy_management, model_management, exchange_management.
|
||||
Only use one atomic action per task.
|
||||
Each task must include:
|
||||
- id
|
||||
- skill
|
||||
- action
|
||||
- request
|
||||
- depends_on (array, may be empty)
|
||||
If the request is effectively a single task, return one task only.`
|
||||
userPrompt := fmt.Sprintf("Language: %s\nUser request: %s", lang, text)
|
||||
raw, err := a.aiClient.CallWithRequest(&mcp.Request{
|
||||
Messages: []mcp.Message{
|
||||
mcp.NewSystemMessage(systemPrompt),
|
||||
mcp.NewUserMessage(userPrompt),
|
||||
},
|
||||
Ctx: stageCtx,
|
||||
})
|
||||
if err != nil {
|
||||
return workflowDecomposition{}, err
|
||||
}
|
||||
return parseWorkflowDecomposition(raw)
|
||||
}
|
||||
|
||||
func parseWorkflowDecomposition(raw string) (workflowDecomposition, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
raw = strings.TrimPrefix(raw, "```json")
|
||||
raw = strings.TrimPrefix(raw, "```")
|
||||
raw = strings.TrimSuffix(raw, "```")
|
||||
raw = strings.TrimSpace(raw)
|
||||
var out workflowDecomposition
|
||||
if err := json.Unmarshal([]byte(raw), &out); err == nil {
|
||||
out = normalizeWorkflowDecomposition(out)
|
||||
return out, nil
|
||||
}
|
||||
start := strings.Index(raw, "{")
|
||||
end := strings.LastIndex(raw, "}")
|
||||
if start >= 0 && end > start {
|
||||
if err := json.Unmarshal([]byte(raw[start:end+1]), &out); err == nil {
|
||||
out = normalizeWorkflowDecomposition(out)
|
||||
return out, nil
|
||||
}
|
||||
}
|
||||
return workflowDecomposition{}, fmt.Errorf("invalid workflow json")
|
||||
}
|
||||
|
||||
func normalizeWorkflowDecomposition(out workflowDecomposition) workflowDecomposition {
|
||||
normalized := make([]WorkflowTask, 0, len(out.Tasks))
|
||||
for i, task := range out.Tasks {
|
||||
task.ID = strings.TrimSpace(task.ID)
|
||||
if task.ID == "" {
|
||||
task.ID = fmt.Sprintf("task_%d", i+1)
|
||||
}
|
||||
task.Skill = strings.TrimSpace(task.Skill)
|
||||
task.Action = normalizeAtomicSkillAction(task.Skill, task.Action)
|
||||
task.Request = strings.TrimSpace(task.Request)
|
||||
task.DependsOn = cleanStringList(task.DependsOn)
|
||||
if !supportedWorkflowSkill(task.Skill, task.Action) || task.Request == "" {
|
||||
continue
|
||||
}
|
||||
task.Status = workflowTaskPending
|
||||
normalized = append(normalized, task)
|
||||
}
|
||||
out.Tasks = normalized
|
||||
return out
|
||||
}
|
||||
|
||||
func (a *Agent) decomposeWorkflowIntentFallback(text string) workflowDecomposition {
|
||||
segments := splitWorkflowSegments(text)
|
||||
tasks := make([]WorkflowTask, 0, len(segments))
|
||||
for i, segment := range segments {
|
||||
task, ok := classifyWorkflowTask(segment)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
task.ID = fmt.Sprintf("task_%d", i+1)
|
||||
task.Status = workflowTaskPending
|
||||
if len(tasks) > 0 {
|
||||
task.DependsOn = []string{tasks[len(tasks)-1].ID}
|
||||
}
|
||||
tasks = append(tasks, task)
|
||||
}
|
||||
return workflowDecomposition{Tasks: tasks}
|
||||
}
|
||||
|
||||
func splitWorkflowSegments(text string) []string {
|
||||
parts := []string{strings.TrimSpace(text)}
|
||||
separators := []string{",", ",", "然后", "再", "并且", "同时", " and then ", " then ", " and "}
|
||||
for _, sep := range separators {
|
||||
next := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
split := strings.Split(part, sep)
|
||||
for _, candidate := range split {
|
||||
candidate = strings.TrimSpace(candidate)
|
||||
if candidate != "" {
|
||||
next = append(next, candidate)
|
||||
}
|
||||
}
|
||||
}
|
||||
parts = next
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
func classifyWorkflowTask(text string) (WorkflowTask, bool) {
|
||||
segment := strings.TrimSpace(text)
|
||||
if segment == "" {
|
||||
return WorkflowTask{}, false
|
||||
}
|
||||
switch {
|
||||
case detectCreateTraderSkill(segment):
|
||||
return WorkflowTask{Skill: "trader_management", Action: "create", Request: segment}, true
|
||||
case detectTraderManagementIntent(segment):
|
||||
action := normalizeAtomicSkillAction("trader_management", detectManagementAction(segment, "trader"))
|
||||
if supportedWorkflowSkill("trader_management", action) {
|
||||
return WorkflowTask{Skill: "trader_management", Action: action, Request: segment}, true
|
||||
}
|
||||
case detectExchangeManagementIntent(segment):
|
||||
action := normalizeAtomicSkillAction("exchange_management", detectManagementAction(segment, "exchange"))
|
||||
if supportedWorkflowSkill("exchange_management", action) {
|
||||
return WorkflowTask{Skill: "exchange_management", Action: action, Request: segment}, true
|
||||
}
|
||||
case detectModelManagementIntent(segment):
|
||||
action := normalizeAtomicSkillAction("model_management", detectManagementAction(segment, "model"))
|
||||
if supportedWorkflowSkill("model_management", action) {
|
||||
return WorkflowTask{Skill: "model_management", Action: action, Request: segment}, true
|
||||
}
|
||||
case detectStrategyManagementIntent(segment):
|
||||
action := normalizeAtomicSkillAction("strategy_management", detectManagementAction(segment, "strategy"))
|
||||
if action == "" && wantsStrategyDetails(segment) {
|
||||
action = "query_detail"
|
||||
}
|
||||
if supportedWorkflowSkill("strategy_management", action) {
|
||||
return WorkflowTask{Skill: "strategy_management", Action: action, Request: segment}, true
|
||||
}
|
||||
}
|
||||
return WorkflowTask{}, false
|
||||
}
|
||||
37
agent/workflow_test.go
Normal file
37
agent/workflow_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package agent
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSplitWorkflowSegments(t *testing.T) {
|
||||
got := splitWorkflowSegments("把策略删了,再把交易所改名")
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected 2 segments, got %d: %#v", len(got), got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyWorkflowTask(t *testing.T) {
|
||||
task, ok := classifyWorkflowTask("把策略删了")
|
||||
if !ok {
|
||||
t.Fatal("expected task")
|
||||
}
|
||||
if task.Skill != "strategy_management" || task.Action != "delete" {
|
||||
t.Fatalf("unexpected task: %+v", task)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallbackWorkflowDecompositionBuildsTwoTasks(t *testing.T) {
|
||||
a := &Agent{}
|
||||
out := a.decomposeWorkflowIntentFallback("把策略删了,再把交易所改名")
|
||||
if len(out.Tasks) != 2 {
|
||||
t.Fatalf("expected 2 tasks, got %d", len(out.Tasks))
|
||||
}
|
||||
if out.Tasks[0].Skill != "strategy_management" {
|
||||
t.Fatalf("unexpected first task: %+v", out.Tasks[0])
|
||||
}
|
||||
if out.Tasks[1].Skill != "exchange_management" {
|
||||
t.Fatalf("unexpected second task: %+v", out.Tasks[1])
|
||||
}
|
||||
if len(out.Tasks[1].DependsOn) != 1 || out.Tasks[1].DependsOn[0] != out.Tasks[0].ID {
|
||||
t.Fatalf("expected dependency on first task, got %+v", out.Tasks[1].DependsOn)
|
||||
}
|
||||
}
|
||||
@@ -11,6 +11,13 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func traderLogTag(traderID, traderName string) string {
|
||||
if traderName != "" {
|
||||
return fmt.Sprintf("[trader_id=%s trader_name=%s]", traderID, traderName)
|
||||
}
|
||||
return fmt.Sprintf("[trader_id=%s]", traderID)
|
||||
}
|
||||
|
||||
// CompetitionCache competition data cache
|
||||
type CompetitionCache struct {
|
||||
data map[string]interface{}
|
||||
@@ -88,9 +95,9 @@ func (tm *TraderManager) StartAll() {
|
||||
logger.Info("🚀 Starting all traders...")
|
||||
for id, t := range tm.traders {
|
||||
go func(traderID string, at *trader.AutoTrader) {
|
||||
logger.Infof("▶️ Starting %s...", at.GetName())
|
||||
logger.Infof("%s ▶️ Starting trader runtime", traderLogTag(traderID, at.GetName()))
|
||||
if err := at.Run(); err != nil {
|
||||
logger.Infof("❌ %s runtime error: %v", at.GetName(), err)
|
||||
logger.Warnf("%s runtime error: %v", traderLogTag(traderID, at.GetName()), err)
|
||||
}
|
||||
}(id, t)
|
||||
}
|
||||
@@ -136,9 +143,9 @@ func (tm *TraderManager) AutoStartRunningTraders(st *store.Store) {
|
||||
for id, t := range tm.traders {
|
||||
if runningTraderIDs[id] {
|
||||
go func(traderID string, at *trader.AutoTrader) {
|
||||
logger.Infof("▶️ Auto-restoring %s...", at.GetName())
|
||||
logger.Infof("%s ▶️ Auto-restoring trader runtime", traderLogTag(traderID, at.GetName()))
|
||||
if err := at.Run(); err != nil {
|
||||
logger.Infof("❌ %s runtime error: %v", at.GetName(), err)
|
||||
logger.Warnf("%s runtime error: %v", traderLogTag(traderID, at.GetName()), err)
|
||||
}
|
||||
}(id, t)
|
||||
startedCount++
|
||||
@@ -487,7 +494,7 @@ func (tm *TraderManager) LoadUserTradersFromStore(st *store.Store, userID string
|
||||
logger.Infof("📦 Loading trader %s (AI Model: %s, Exchange: %s/%s, Strategy ID: %s)", traderCfg.Name, aiModelCfg.Provider, exchangeCfg.ExchangeType, exchangeCfg.AccountName, traderCfg.StrategyID)
|
||||
err = tm.addTraderFromStore(traderCfg, aiModelCfg, exchangeCfg, st)
|
||||
if err != nil {
|
||||
logger.Infof("❌ Failed to load trader %s: %v", traderCfg.Name, err)
|
||||
logger.Warnf("%s failed to load trader: %v", traderLogTag(traderCfg.ID, traderCfg.Name), err)
|
||||
// Save error for later retrieval
|
||||
tm.loadErrors[traderCfg.ID] = err
|
||||
} else {
|
||||
@@ -592,7 +599,7 @@ func (tm *TraderManager) LoadTradersFromStore(st *store.Store) error {
|
||||
// Add to TraderManager (ai500APIURL/oiTopAPIURL already obtained from strategy config)
|
||||
err = tm.addTraderFromStore(traderCfg, aiModelCfg, exchangeCfg, st)
|
||||
if err != nil {
|
||||
logger.Infof("❌ Failed to add trader %s: %v", traderCfg.Name, err)
|
||||
logger.Warnf("%s failed to add trader: %v", traderLogTag(traderCfg.ID, traderCfg.Name), err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -727,17 +734,17 @@ func (tm *TraderManager) addTraderFromStore(traderCfg *store.Trader, aiModelCfg
|
||||
|
||||
// Auto-start if trader was running before shutdown
|
||||
if traderCfg.IsRunning {
|
||||
logger.Infof("🔄 Auto-starting trader '%s' (was running before shutdown)...", traderCfg.Name)
|
||||
logger.Infof("%s 🔄 Auto-starting trader (was running before shutdown)...", traderLogTag(traderCfg.ID, traderCfg.Name))
|
||||
go func(trader *trader.AutoTrader, traderName, traderID, userID string) {
|
||||
if err := trader.Run(); err != nil {
|
||||
logger.Warnf("⚠️ Trader '%s' stopped with error: %v", traderName, err)
|
||||
logger.Warnf("%s trader stopped with error: %v", traderLogTag(traderID, traderName), err)
|
||||
// Update database to reflect stopped state
|
||||
if st != nil {
|
||||
_ = st.Trader().UpdateStatus(userID, traderID, false)
|
||||
}
|
||||
}
|
||||
}(at, traderCfg.Name, traderCfg.ID, traderCfg.UserID)
|
||||
logger.Infof("✅ Trader '%s' auto-started successfully", traderCfg.Name)
|
||||
logger.Infof("%s ✅ Trader auto-started successfully", traderLogTag(traderCfg.ID, traderCfg.Name))
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -131,7 +131,7 @@ func (s *AIModelStore) GetDefault(userID string) (*AIModel, error) {
|
||||
if userID == "" {
|
||||
userID = "default"
|
||||
}
|
||||
model, err := s.firstEnabled(userID)
|
||||
model, err := s.firstEnabledUsable(userID)
|
||||
if err == nil {
|
||||
return model, nil
|
||||
}
|
||||
@@ -139,14 +139,14 @@ func (s *AIModelStore) GetDefault(userID string) (*AIModel, error) {
|
||||
return nil, err
|
||||
}
|
||||
if userID != "default" {
|
||||
return s.firstEnabled("default")
|
||||
return s.firstEnabledUsable("default")
|
||||
}
|
||||
return nil, fmt.Errorf("please configure an available AI model in the system first")
|
||||
}
|
||||
|
||||
func (s *AIModelStore) firstEnabled(userID string) (*AIModel, error) {
|
||||
func (s *AIModelStore) firstEnabledUsable(userID string) (*AIModel, error) {
|
||||
var model AIModel
|
||||
err := s.db.Where("user_id = ? AND enabled = ?", userID, true).
|
||||
err := s.db.Where("user_id = ? AND enabled = ? AND api_key != ''", userID, true).
|
||||
Order("updated_at DESC, id ASC").
|
||||
First(&model).Error
|
||||
if err != nil {
|
||||
|
||||
@@ -24,6 +24,31 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func (at *AutoTrader) logTag() string {
|
||||
if at == nil {
|
||||
return "[trader_id=unknown]"
|
||||
}
|
||||
if at.name != "" {
|
||||
return fmt.Sprintf("[trader_id=%s trader_name=%s]", at.id, at.name)
|
||||
}
|
||||
return fmt.Sprintf("[trader_id=%s]", at.id)
|
||||
}
|
||||
|
||||
func (at *AutoTrader) logInfof(format string, args ...interface{}) {
|
||||
values := append([]interface{}{at.logTag()}, args...)
|
||||
logger.Infof("%s "+format, values...)
|
||||
}
|
||||
|
||||
func (at *AutoTrader) logWarnf(format string, args ...interface{}) {
|
||||
values := append([]interface{}{at.logTag()}, args...)
|
||||
logger.Warnf("%s "+format, values...)
|
||||
}
|
||||
|
||||
func (at *AutoTrader) logErrorf(format string, args ...interface{}) {
|
||||
values := append([]interface{}{at.logTag()}, args...)
|
||||
logger.Errorf("%s "+format, values...)
|
||||
}
|
||||
|
||||
// AutoTraderConfig auto trading configuration (simplified version - AI makes all decisions)
|
||||
type AutoTraderConfig struct {
|
||||
// Trader identification
|
||||
@@ -381,8 +406,8 @@ func (at *AutoTrader) Run() error {
|
||||
at.startTime = time.Now()
|
||||
|
||||
logger.Info("🚀 AI-driven automatic trading system started")
|
||||
logger.Infof("💰 Initial balance: %.2f USDT", at.initialBalance)
|
||||
logger.Infof("⚙️ Scan interval: %v", at.config.ScanInterval)
|
||||
at.logInfof("💰 Initial balance: %.2f USDT", at.initialBalance)
|
||||
at.logInfof("⚙️ Scan interval: %v", at.config.ScanInterval)
|
||||
logger.Info("🤖 AI will make full decisions on leverage, position size, stop loss/take profit, etc.")
|
||||
|
||||
// Pre-launch checks for claw402 users
|
||||
@@ -397,7 +422,7 @@ func (at *AutoTrader) Run() error {
|
||||
if at.exchange == "lighter" {
|
||||
if lighterTrader, ok := at.trader.(*lighter.LighterTraderV2); ok && at.store != nil {
|
||||
lighterTrader.StartOrderSync(at.id, at.exchangeID, at.exchange, at.store, 30*time.Second)
|
||||
logger.Infof("🔄 [%s] Lighter order+position sync enabled (every 30s)", at.name)
|
||||
at.logInfof("🔄 Lighter order+position sync enabled (every 30s)")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -405,7 +430,7 @@ func (at *AutoTrader) Run() error {
|
||||
if at.exchange == "hyperliquid" {
|
||||
if hyperliquidTrader, ok := at.trader.(*hyperliquid.HyperliquidTrader); ok && at.store != nil {
|
||||
hyperliquidTrader.StartOrderSync(at.id, at.exchangeID, at.exchange, at.store, 30*time.Second)
|
||||
logger.Infof("🔄 [%s] Hyperliquid order+position sync enabled (every 30s)", at.name)
|
||||
at.logInfof("🔄 Hyperliquid order+position sync enabled (every 30s)")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -413,7 +438,7 @@ func (at *AutoTrader) Run() error {
|
||||
if at.exchange == "bybit" {
|
||||
if bybitTrader, ok := at.trader.(*bybit.BybitTrader); ok && at.store != nil {
|
||||
bybitTrader.StartOrderSync(at.id, at.exchangeID, at.exchange, at.store, 30*time.Second)
|
||||
logger.Infof("🔄 [%s] Bybit order+position sync enabled (every 30s)", at.name)
|
||||
at.logInfof("🔄 Bybit order+position sync enabled (every 30s)")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -421,7 +446,7 @@ func (at *AutoTrader) Run() error {
|
||||
if at.exchange == "okx" {
|
||||
if okxTrader, ok := at.trader.(*okx.OKXTrader); ok && at.store != nil {
|
||||
okxTrader.StartOrderSync(at.id, at.exchangeID, at.exchange, at.store, 30*time.Second)
|
||||
logger.Infof("🔄 [%s] OKX order+position sync enabled (every 30s)", at.name)
|
||||
at.logInfof("🔄 OKX order+position sync enabled (every 30s)")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -429,7 +454,7 @@ func (at *AutoTrader) Run() error {
|
||||
if at.exchange == "bitget" {
|
||||
if bitgetTrader, ok := at.trader.(*bitget.BitgetTrader); ok && at.store != nil {
|
||||
bitgetTrader.StartOrderSync(at.id, at.exchangeID, at.exchange, at.store, 30*time.Second)
|
||||
logger.Infof("🔄 [%s] Bitget order+position sync enabled (every 30s)", at.name)
|
||||
at.logInfof("🔄 Bitget order+position sync enabled (every 30s)")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -437,7 +462,7 @@ func (at *AutoTrader) Run() error {
|
||||
if at.exchange == "aster" {
|
||||
if asterTrader, ok := at.trader.(*aster.AsterTrader); ok && at.store != nil {
|
||||
asterTrader.StartOrderSync(at.id, at.exchangeID, at.exchange, at.store, 30*time.Second)
|
||||
logger.Infof("🔄 [%s] Aster order+position sync enabled (every 30s)", at.name)
|
||||
at.logInfof("🔄 Aster order+position sync enabled (every 30s)")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -445,7 +470,7 @@ func (at *AutoTrader) Run() error {
|
||||
if at.exchange == "binance" {
|
||||
if binanceTrader, ok := at.trader.(*binance.FuturesTrader); ok && at.store != nil {
|
||||
binanceTrader.StartOrderSync(at.id, at.exchangeID, at.exchange, at.store, 30*time.Second)
|
||||
logger.Infof("🔄 [%s] Binance order+position sync enabled (every 30s)", at.name)
|
||||
at.logInfof("🔄 Binance order+position sync enabled (every 30s)")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -453,7 +478,7 @@ func (at *AutoTrader) Run() error {
|
||||
if at.exchange == "gate" {
|
||||
if gateTrader, ok := at.trader.(*gate.GateTrader); ok && at.store != nil {
|
||||
gateTrader.StartOrderSync(at.id, at.exchangeID, at.exchange, at.store, 30*time.Second)
|
||||
logger.Infof("🔄 [%s] Gate order+position sync enabled (every 30s)", at.name)
|
||||
at.logInfof("🔄 Gate order+position sync enabled (every 30s)")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -461,7 +486,7 @@ func (at *AutoTrader) Run() error {
|
||||
if at.exchange == "kucoin" {
|
||||
if kucoinTrader, ok := at.trader.(*kucoin.KuCoinTrader); ok && at.store != nil {
|
||||
kucoinTrader.StartOrderSync(at.id, at.exchangeID, at.exchange, at.store, 30*time.Second)
|
||||
logger.Infof("🔄 [%s] KuCoin order+position sync enabled (every 30s)", at.name)
|
||||
at.logInfof("🔄 KuCoin order+position sync enabled (every 30s)")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -471,9 +496,9 @@ func (at *AutoTrader) Run() error {
|
||||
// Check if this is a grid trading strategy
|
||||
isGridStrategy := at.IsGridStrategy()
|
||||
if isGridStrategy {
|
||||
logger.Infof("🔲 [%s] Grid trading strategy detected, initializing grid...", at.name)
|
||||
at.logInfof("🔲 Grid trading strategy detected, initializing grid...")
|
||||
if err := at.InitializeGrid(); err != nil {
|
||||
logger.Errorf("❌ [%s] Failed to initialize grid: %v", at.name, err)
|
||||
at.logErrorf("❌ Failed to initialize grid: %v", err)
|
||||
return fmt.Errorf("grid initialization failed: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -481,11 +506,11 @@ func (at *AutoTrader) Run() error {
|
||||
// Execute immediately on first run
|
||||
if isGridStrategy {
|
||||
if err := at.RunGridCycle(); err != nil {
|
||||
logger.Infof("❌ Grid execution failed: %v", err)
|
||||
at.logErrorf("❌ Grid execution failed: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err := at.runCycle(); err != nil {
|
||||
logger.Infof("❌ Execution failed: %v", err)
|
||||
at.logErrorf("❌ Execution failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -502,15 +527,15 @@ func (at *AutoTrader) Run() error {
|
||||
case <-ticker.C:
|
||||
if isGridStrategy {
|
||||
if err := at.RunGridCycle(); err != nil {
|
||||
logger.Infof("❌ Grid execution failed: %v", err)
|
||||
at.logErrorf("❌ Grid execution failed: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err := at.runCycle(); err != nil {
|
||||
logger.Infof("❌ Execution failed: %v", err)
|
||||
at.logErrorf("❌ Execution failed: %v", err)
|
||||
}
|
||||
}
|
||||
case <-at.stopMonitorCh:
|
||||
logger.Infof("[%s] ⏹ Stop signal received, exiting automatic trading main loop", at.name)
|
||||
at.logInfof("⏹ Stop signal received, exiting automatic trading main loop")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -590,6 +615,22 @@ func (at *AutoTrader) GetSystemPromptTemplate() string {
|
||||
return "strategy"
|
||||
}
|
||||
|
||||
// GetCandidateCoins returns the current candidate coin set from the trader's strategy engine.
|
||||
func (at *AutoTrader) GetCandidateCoins() ([]kernel.CandidateCoin, error) {
|
||||
if at.strategyEngine == nil {
|
||||
return nil, fmt.Errorf("strategy engine not configured")
|
||||
}
|
||||
return at.strategyEngine.GetCandidateCoins()
|
||||
}
|
||||
|
||||
// GetStrategyConfig returns the current strategy config used by the trader.
|
||||
func (at *AutoTrader) GetStrategyConfig() *store.StrategyConfig {
|
||||
if at.strategyEngine == nil {
|
||||
return at.config.StrategyConfig
|
||||
}
|
||||
return at.strategyEngine.GetConfig()
|
||||
}
|
||||
|
||||
// GetStore gets data store (for external access to decision records, etc.)
|
||||
func (at *AutoTrader) GetStore() *store.Store {
|
||||
return at.store
|
||||
|
||||
@@ -24,7 +24,7 @@ func (at *AutoTrader) runCycle() error {
|
||||
running := at.isRunning
|
||||
at.isRunningMutex.RUnlock()
|
||||
if !running {
|
||||
logger.Infof("⏹ Trader is stopped, aborting cycle #%d", at.callCount)
|
||||
at.logInfof("⏹ Trader is stopped, aborting cycle #%d", at.callCount)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ func (at *AutoTrader) runCycle() error {
|
||||
// 1. Check if trading needs to be stopped
|
||||
if time.Now().Before(at.stopUntil) {
|
||||
remaining := at.stopUntil.Sub(time.Now())
|
||||
logger.Infof("⏸ Risk control: Trading paused, remaining %.0f minutes", remaining.Minutes())
|
||||
at.logWarnf("⏸ Risk control: Trading paused, remaining %.0f minutes", remaining.Minutes())
|
||||
record.Success = false
|
||||
record.ErrorMessage = fmt.Sprintf("Risk control paused, remaining %.0f minutes", remaining.Minutes())
|
||||
at.saveDecision(record)
|
||||
@@ -59,6 +59,7 @@ func (at *AutoTrader) runCycle() error {
|
||||
// 4. Collect trading context
|
||||
ctx, err := at.buildTradingContext()
|
||||
if err != nil {
|
||||
at.logErrorf("failed to build trading context: %v", err)
|
||||
record.Success = false
|
||||
record.ErrorMessage = fmt.Sprintf("Failed to build trading context: %v", err)
|
||||
at.saveDecision(record)
|
||||
@@ -71,7 +72,7 @@ func (at *AutoTrader) runCycle() error {
|
||||
|
||||
// If no candidate coins available, log but do not error
|
||||
if len(ctx.CandidateCoins) == 0 {
|
||||
logger.Infof("ℹ️ No candidate coins available, skipping this cycle")
|
||||
at.logInfof("ℹ️ No candidate coins available, skipping this cycle")
|
||||
record.Success = true // Not an error, just no candidate coins
|
||||
record.ExecutionLog = append(record.ExecutionLog, "No candidate coins available, cycle skipped")
|
||||
record.AccountState = store.AccountSnapshot{
|
||||
@@ -90,16 +91,16 @@ func (at *AutoTrader) runCycle() error {
|
||||
record.CandidateCoins = append(record.CandidateCoins, coin.Symbol)
|
||||
}
|
||||
|
||||
logger.Infof("📊 Account equity: %.2f USDT | Available: %.2f USDT | Positions: %d",
|
||||
at.logInfof("📊 Account equity: %.2f USDT | Available: %.2f USDT | Positions: %d",
|
||||
ctx.Account.TotalEquity, ctx.Account.AvailableBalance, ctx.Account.PositionCount)
|
||||
|
||||
// 5. Use strategy engine to call AI for decision
|
||||
logger.Infof("🤖 Requesting AI analysis and decision... [Strategy Engine]")
|
||||
at.logInfof("🤖 Requesting AI analysis and decision... [Strategy Engine]")
|
||||
aiDecision, err := kernel.GetFullDecisionWithStrategy(ctx, at.mcpClient, at.strategyEngine, "balanced")
|
||||
|
||||
if aiDecision != nil && aiDecision.AIRequestDurationMs > 0 {
|
||||
record.AIRequestDurationMs = aiDecision.AIRequestDurationMs
|
||||
logger.Infof("⏱️ AI call duration: %.2f seconds", float64(record.AIRequestDurationMs)/1000)
|
||||
at.logInfof("⏱️ AI call duration: %.2f seconds", float64(record.AIRequestDurationMs)/1000)
|
||||
record.ExecutionLog = append(record.ExecutionLog,
|
||||
fmt.Sprintf("AI call duration: %d ms", record.AIRequestDurationMs))
|
||||
}
|
||||
@@ -119,7 +120,7 @@ func (at *AutoTrader) runCycle() error {
|
||||
// Record AI charge (track cost regardless of decision outcome)
|
||||
if aiDecision != nil && at.store != nil {
|
||||
if chargeErr := at.store.AICharge().Record(at.id, at.aiModel, at.config.AIModel); chargeErr != nil {
|
||||
logger.Warnf("⚠️ Failed to record AI charge: %v", chargeErr)
|
||||
at.logWarnf("⚠️ Failed to record AI charge: %v", chargeErr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -132,10 +133,9 @@ func (at *AutoTrader) runCycle() error {
|
||||
if at.consecutiveAIFailures >= 3 && !at.safeMode {
|
||||
at.safeMode = true
|
||||
at.safeModeReason = fmt.Sprintf("AI failed %d consecutive times: %v", at.consecutiveAIFailures, err)
|
||||
logger.Errorf("🛡️ [%s] SAFE MODE ACTIVATED — AI failed %d times in a row. No new positions will be opened. Existing positions are protected with current stop-loss settings.",
|
||||
at.name, at.consecutiveAIFailures)
|
||||
logger.Errorf("🛡️ [%s] Reason: %v", at.name, err)
|
||||
logger.Errorf("🛡️ [%s] Action: Will keep trying AI each cycle. Safe mode auto-deactivates when AI recovers.", at.name)
|
||||
at.logErrorf("🛡️ SAFE MODE ACTIVATED — AI failed %d times in a row. No new positions will be opened. Existing positions are protected with current stop-loss settings.", at.consecutiveAIFailures)
|
||||
at.logErrorf("🛡️ Reason: %v", err)
|
||||
at.logErrorf("🛡️ Action: Will keep trying AI each cycle. Safe mode auto-deactivates when AI recovers.")
|
||||
}
|
||||
|
||||
// Print system prompt and AI chain of thought (output even with errors for debugging)
|
||||
@@ -159,7 +159,7 @@ func (at *AutoTrader) runCycle() error {
|
||||
|
||||
// In safe mode, don't return error — keep the loop running to retry next cycle
|
||||
if at.safeMode {
|
||||
logger.Warnf("🛡️ [%s] Safe mode: skipping this cycle, will retry in %v", at.name, at.config.ScanInterval)
|
||||
at.logWarnf("🛡️ Safe mode: skipping this cycle, will retry in %v", at.config.ScanInterval)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -168,11 +168,11 @@ func (at *AutoTrader) runCycle() error {
|
||||
|
||||
// AI succeeded — reset failure counter and deactivate safe mode
|
||||
if at.consecutiveAIFailures > 0 {
|
||||
logger.Infof("✅ [%s] AI recovered after %d consecutive failures", at.name, at.consecutiveAIFailures)
|
||||
at.logInfof("✅ AI recovered after %d consecutive failures", at.consecutiveAIFailures)
|
||||
}
|
||||
at.consecutiveAIFailures = 0
|
||||
if at.safeMode {
|
||||
logger.Infof("🛡️ [%s] SAFE MODE DEACTIVATED — AI is working again. Resuming normal trading.", at.name)
|
||||
at.logInfof("🛡️ SAFE MODE DEACTIVATED — AI is working again. Resuming normal trading.")
|
||||
at.safeMode = false
|
||||
at.safeModeReason = ""
|
||||
}
|
||||
@@ -219,7 +219,7 @@ func (at *AutoTrader) runCycle() error {
|
||||
running = at.isRunning
|
||||
at.isRunningMutex.RUnlock()
|
||||
if !running {
|
||||
logger.Infof("⏹ Trader stopped before decision execution, aborting cycle #%d", at.callCount)
|
||||
at.logInfof("⏹ Trader stopped before decision execution, aborting cycle #%d", at.callCount)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -228,14 +228,14 @@ func (at *AutoTrader) runCycle() error {
|
||||
filtered := make([]kernel.Decision, 0)
|
||||
for _, d := range sortedDecisions {
|
||||
if d.Action == "open_long" || d.Action == "open_short" {
|
||||
logger.Warnf("🛡️ [%s] Safe mode: BLOCKED %s %s (no new positions allowed)", at.name, d.Action, d.Symbol)
|
||||
at.logWarnf("🛡️ Safe mode: BLOCKED %s %s (no new positions allowed)", d.Action, d.Symbol)
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, d)
|
||||
}
|
||||
sortedDecisions = filtered
|
||||
if len(sortedDecisions) == 0 {
|
||||
logger.Infof("🛡️ [%s] Safe mode: all decisions were open positions, nothing to execute", at.name)
|
||||
at.logInfof("🛡️ Safe mode: all decisions were open positions, nothing to execute")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -246,7 +246,7 @@ func (at *AutoTrader) runCycle() error {
|
||||
running = at.isRunning
|
||||
at.isRunningMutex.RUnlock()
|
||||
if !running {
|
||||
logger.Infof("⏹ Trader stopped during decision execution, aborting remaining decisions")
|
||||
at.logInfof("⏹ Trader stopped during decision execution, aborting remaining decisions")
|
||||
break
|
||||
}
|
||||
|
||||
@@ -265,7 +265,7 @@ func (at *AutoTrader) runCycle() error {
|
||||
}
|
||||
|
||||
if err := at.executeDecisionWithRecord(&d, &actionRecord); err != nil {
|
||||
logger.Infof("❌ Failed to execute decision (%s %s): %v", d.Symbol, d.Action, err)
|
||||
at.logErrorf("❌ Failed to execute decision (%s %s): %v", d.Symbol, d.Action, err)
|
||||
actionRecord.Error = err.Error()
|
||||
record.ExecutionLog = append(record.ExecutionLog, fmt.Sprintf("❌ %s %s failed: %v", d.Symbol, d.Action, err))
|
||||
} else {
|
||||
@@ -280,7 +280,7 @@ func (at *AutoTrader) runCycle() error {
|
||||
|
||||
// 9. Save decision record
|
||||
if err := at.saveDecision(record); err != nil {
|
||||
logger.Infof("⚠ Failed to save decision record: %v", err)
|
||||
at.logWarnf("⚠ Failed to save decision record: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -417,12 +417,12 @@ func (at *AutoTrader) buildTradingContext() (*kernel.Context, error) {
|
||||
// 3. Use strategy engine to get candidate coins (must have strategy engine)
|
||||
var candidateCoins []kernel.CandidateCoin
|
||||
if at.strategyEngine == nil {
|
||||
logger.Infof("⚠️ [%s] No strategy engine configured, skipping candidate coins", at.name)
|
||||
at.logWarnf("⚠️ No strategy engine configured, skipping candidate coins")
|
||||
} else {
|
||||
coins, err := at.strategyEngine.GetCandidateCoins()
|
||||
if err != nil {
|
||||
// Log warning but don't fail - equity snapshot should still be saved
|
||||
logger.Infof("⚠️ [%s] Failed to get candidate coins: %v (will use empty list)", at.name, err)
|
||||
at.logWarnf("⚠️ Failed to get candidate coins: %v (will use empty list)", err)
|
||||
} else {
|
||||
candidateCoins = coins
|
||||
logger.Infof("📋 [%s] Strategy engine fetched candidate coins: %d", at.name, len(candidateCoins))
|
||||
@@ -473,7 +473,7 @@ func (at *AutoTrader) buildTradingContext() (*kernel.Context, error) {
|
||||
// Get recent 10 closed trades for AI context
|
||||
recentTrades, err := at.store.Position().GetRecentTrades(at.id, 10)
|
||||
if err != nil {
|
||||
logger.Infof("⚠️ [%s] Failed to get recent trades: %v", at.name, err)
|
||||
at.logWarnf("⚠️ Failed to get recent trades: %v", err)
|
||||
} else {
|
||||
logger.Infof("📊 [%s] Found %d recent closed trades for AI context", at.name, len(recentTrades))
|
||||
for _, trade := range recentTrades {
|
||||
@@ -503,11 +503,11 @@ func (at *AutoTrader) buildTradingContext() (*kernel.Context, error) {
|
||||
// Get trading statistics for AI context
|
||||
stats, err := at.store.Position().GetFullStats(at.id)
|
||||
if err != nil {
|
||||
logger.Infof("⚠️ [%s] Failed to get trading stats: %v", at.name, err)
|
||||
at.logWarnf("⚠️ Failed to get trading stats: %v", err)
|
||||
} else if stats == nil {
|
||||
logger.Infof("⚠️ [%s] GetFullStats returned nil", at.name)
|
||||
at.logWarnf("⚠️ GetFullStats returned nil")
|
||||
} else if stats.TotalTrades == 0 {
|
||||
logger.Infof("⚠️ [%s] GetFullStats returned 0 trades (traderID=%s)", at.name, at.id)
|
||||
at.logWarnf("⚠️ GetFullStats returned 0 trades")
|
||||
} else {
|
||||
ctx.TradingStats = &kernel.TradingStats{
|
||||
TotalTrades: stats.TotalTrades,
|
||||
@@ -523,7 +523,7 @@ func (at *AutoTrader) buildTradingContext() (*kernel.Context, error) {
|
||||
at.name, stats.TotalTrades, stats.WinRate, stats.ProfitFactor, stats.SharpeRatio, stats.MaxDrawdownPct)
|
||||
}
|
||||
} else {
|
||||
logger.Infof("⚠️ [%s] Store is nil, cannot get recent trades", at.name)
|
||||
at.logWarnf("⚠️ Store is nil, cannot get recent trades")
|
||||
}
|
||||
|
||||
// 8. Get quantitative data (if enabled in strategy config)
|
||||
@@ -630,15 +630,15 @@ func (at *AutoTrader) checkClaw402Balance() {
|
||||
if at.claw402WalletAddr != "" {
|
||||
balance, err := wallet.QueryUSDCBalance(at.claw402WalletAddr)
|
||||
if err != nil {
|
||||
logger.Warnf("⚠️ [%s] Failed to query USDC balance: %v", at.name, err)
|
||||
at.logWarnf("⚠️ Failed to query USDC balance: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if balance < 1.0 {
|
||||
logger.Warnf("⚠️ [%s] Low USDC balance: $%.2f — AI may stop soon!", at.name, balance)
|
||||
at.logWarnf("⚠️ Low USDC balance: $%.2f — AI may stop soon!", balance)
|
||||
}
|
||||
if balance <= 0 {
|
||||
logger.Errorf("🚨 [%s] USDC balance is ZERO — AI calls will fail!", at.name)
|
||||
at.logErrorf("🚨 USDC balance is ZERO — AI calls will fail!")
|
||||
}
|
||||
|
||||
runway := float64(0)
|
||||
|
||||
@@ -43,7 +43,7 @@ export function AgentStepPanel({ steps, visible }: AgentStepPanelProps) {
|
||||
marginBottom: 10,
|
||||
}}
|
||||
>
|
||||
Agent Steps
|
||||
Live Run
|
||||
</div>
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: 8 }}>
|
||||
{steps.map((step) => {
|
||||
|
||||
@@ -19,6 +19,11 @@ import { WelcomeScreen } from '../components/agent/WelcomeScreen'
|
||||
import { ChatMessages } from '../components/agent/ChatMessages'
|
||||
import { ChatInput, type ChatInputHandle } from '../components/agent/ChatInput'
|
||||
import { UserPreferencesPanel } from '../components/agent/UserPreferencesPanel'
|
||||
import {
|
||||
useAgentChatStore,
|
||||
type AgentMessage as Message,
|
||||
type AgentStep,
|
||||
} from '../stores/agentChatStore'
|
||||
import {
|
||||
chatStorageKey,
|
||||
clearAgentMessages,
|
||||
@@ -29,22 +34,6 @@ import {
|
||||
persistAgentMessages,
|
||||
} from '../lib/agentChatStorage'
|
||||
|
||||
interface Message {
|
||||
id: string
|
||||
role: 'user' | 'bot'
|
||||
text: string
|
||||
time: string
|
||||
streaming?: boolean
|
||||
steps?: AgentStep[]
|
||||
}
|
||||
|
||||
interface AgentStep {
|
||||
id: string
|
||||
label: string
|
||||
status: 'planning' | 'pending' | 'running' | 'completed' | 'replanned'
|
||||
detail?: string
|
||||
}
|
||||
|
||||
let msgIdCounter = 0
|
||||
function nextId() {
|
||||
return `msg-${Date.now()}-${++msgIdCounter}`
|
||||
@@ -66,7 +55,7 @@ function parsePlanSteps(data: string): AgentStep[] {
|
||||
return text.split(/\s*->\s*/).map((part, index) => {
|
||||
const cleaned = part.replace(/^\d+\./, '').trim()
|
||||
return {
|
||||
id: `plan-${index + 1}`,
|
||||
id: `action-${index + 1}`,
|
||||
label: cleaned || `Step ${index + 1}`,
|
||||
status: 'pending',
|
||||
}
|
||||
@@ -76,7 +65,7 @@ function parsePlanSteps(data: string): AgentStep[] {
|
||||
function parseStepEvent(data: string, fallbackIndex: number): AgentStep {
|
||||
const match = data.match(/Step\s+(\d+)\/(\d+):\s+(.+)$/i) || data.match(/步骤\s+(\d+)\/(\d+):\s+(.+)$/)
|
||||
if (match) {
|
||||
const id = `plan-${match[1]}`
|
||||
const id = `action-${match[1]}`
|
||||
return {
|
||||
id,
|
||||
label: match[3].trim(),
|
||||
@@ -110,11 +99,14 @@ export function AgentChatPage() {
|
||||
const [storageUserId, setStorageUserId] = useState<string | undefined>(() => getStoredAuthUserId())
|
||||
const [sidebarOpen, setSidebarOpen] = useState(() => window.innerWidth > 1024)
|
||||
const storageKey = chatStorageKey(user?.id || storageUserId)
|
||||
const [messages, setMessages] = useState<Message[]>(
|
||||
() => loadAgentMessages<Message>(window.localStorage, user?.id || storageUserId).messages
|
||||
)
|
||||
const [historyHydrated, setHistoryHydrated] = useState(false)
|
||||
const [loading, setLoading] = useState(false)
|
||||
const messages = useAgentChatStore((state) => state.messages)
|
||||
const loading = useAgentChatStore((state) => state.loading)
|
||||
const historyHydrated = useAgentChatStore((state) => state.hydrated)
|
||||
const activeUserId = useAgentChatStore((state) => state.activeUserId)
|
||||
const setMessages = useAgentChatStore((state) => state.setMessages)
|
||||
const updateMessages = useAgentChatStore((state) => state.updateMessages)
|
||||
const setLoading = useAgentChatStore((state) => state.setLoading)
|
||||
const resetForUser = useAgentChatStore((state) => state.resetForUser)
|
||||
const messagesEndRef = useRef<HTMLDivElement>(null)
|
||||
const chatInputRef = useRef<ChatInputHandle>(null)
|
||||
const abortRef = useRef<AbortController | null>(null)
|
||||
@@ -147,10 +139,13 @@ export function AgentChatPage() {
|
||||
|
||||
// Restore chat history for the current user when opening the agent page.
|
||||
useEffect(() => {
|
||||
setHistoryHydrated(false)
|
||||
setMessages(loadAgentMessages<Message>(window.localStorage, user?.id || storageUserId).messages)
|
||||
setHistoryHydrated(true)
|
||||
}, [storageKey, storageUserId, user?.id])
|
||||
const nextUserId = user?.id || storageUserId
|
||||
if (activeUserId === nextUserId && historyHydrated) return
|
||||
resetForUser(
|
||||
nextUserId,
|
||||
loadAgentMessages<Message>(window.localStorage, nextUserId).messages
|
||||
)
|
||||
}, [activeUserId, historyHydrated, resetForUser, storageKey, storageUserId, user?.id])
|
||||
|
||||
// Persist chat history locally so page navigation does not wipe the conversation.
|
||||
useEffect(() => {
|
||||
@@ -163,6 +158,26 @@ export function AgentChatPage() {
|
||||
}
|
||||
}, [historyHydrated, messages, storageKey, storageUserId, user?.id])
|
||||
|
||||
const persistMessagesSnapshot = (nextMessages: Message[]) => {
|
||||
const persistable = prepareAgentMessagesForPersistence(nextMessages).slice(-100)
|
||||
persistAgentMessages(window.localStorage, user?.id || storageUserId, persistable)
|
||||
}
|
||||
|
||||
const replaceMessages = (nextMessages: Message[]) => {
|
||||
setMessages(nextMessages)
|
||||
if (historyHydrated) {
|
||||
persistMessagesSnapshot(nextMessages)
|
||||
}
|
||||
}
|
||||
|
||||
const patchMessages = (updater: (prev: Message[]) => Message[]) => {
|
||||
const nextMessages = updater(useAgentChatStore.getState().messages)
|
||||
updateMessages(() => nextMessages)
|
||||
if (useAgentChatStore.getState().hydrated) {
|
||||
persistMessagesSnapshot(nextMessages)
|
||||
}
|
||||
}
|
||||
|
||||
// Responsive sidebar
|
||||
useEffect(() => {
|
||||
const handleResize = () => {
|
||||
@@ -201,10 +216,10 @@ export function AgentChatPage() {
|
||||
streaming: true,
|
||||
},
|
||||
]
|
||||
setMessages((prev) =>
|
||||
replaceMessages(
|
||||
text.trim() === '/clear'
|
||||
? nextConversation
|
||||
: [...prev, ...nextConversation]
|
||||
: [...useAgentChatStore.getState().messages, ...nextConversation]
|
||||
)
|
||||
setLoading(true)
|
||||
|
||||
@@ -275,7 +290,7 @@ export function AgentChatPage() {
|
||||
if (eventType === 'delta') {
|
||||
// data is the accumulated text so far
|
||||
finalText = data
|
||||
setMessages((prev) =>
|
||||
patchMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === botId
|
||||
? { ...m, text: data, time: now() }
|
||||
@@ -284,13 +299,12 @@ export function AgentChatPage() {
|
||||
)
|
||||
} else if (eventType === 'plan') {
|
||||
const parsedSteps = parsePlanSteps(data)
|
||||
setMessages((prev) =>
|
||||
patchMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === botId
|
||||
? {
|
||||
...m,
|
||||
steps: parsedSteps.length > 0 ? parsedSteps : m.steps,
|
||||
text: m.text || data,
|
||||
time: now(),
|
||||
}
|
||||
: m
|
||||
@@ -299,33 +313,31 @@ export function AgentChatPage() {
|
||||
} else if (eventType === 'step_start') {
|
||||
stepCounter += 1
|
||||
const nextStep = parseStepEvent(data, stepCounter)
|
||||
setMessages((prev) =>
|
||||
patchMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === botId
|
||||
? {
|
||||
...m,
|
||||
steps: appendStep(m.steps, nextStep),
|
||||
text: m.text || data,
|
||||
time: now(),
|
||||
}
|
||||
: m
|
||||
)
|
||||
)
|
||||
} else if (eventType === 'step_complete') {
|
||||
setMessages((prev) =>
|
||||
patchMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === botId
|
||||
? {
|
||||
...m,
|
||||
steps: markLatestRunningCompleted(m.steps, data),
|
||||
text: m.text || data,
|
||||
time: now(),
|
||||
}
|
||||
: m
|
||||
)
|
||||
)
|
||||
} else if (eventType === 'replan') {
|
||||
setMessages((prev) =>
|
||||
patchMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === botId
|
||||
? {
|
||||
@@ -336,7 +348,6 @@ export function AgentChatPage() {
|
||||
status: 'replanned',
|
||||
detail: data,
|
||||
}),
|
||||
text: m.text || data,
|
||||
time: now(),
|
||||
}
|
||||
: m
|
||||
@@ -346,12 +357,11 @@ export function AgentChatPage() {
|
||||
eventType === 'tool'
|
||||
) {
|
||||
// Show tool being called as a status indicator
|
||||
setMessages((prev) =>
|
||||
patchMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === botId
|
||||
? {
|
||||
...m,
|
||||
text: m.text || `🔧 _Calling ${data}..._`,
|
||||
steps: appendStep(m.steps, {
|
||||
id: `tool-${Date.now()}`,
|
||||
label: `Tool: ${data}`,
|
||||
@@ -365,7 +375,7 @@ export function AgentChatPage() {
|
||||
)
|
||||
} else if (eventType === 'done') {
|
||||
finalText = data
|
||||
setMessages((prev) =>
|
||||
patchMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === botId
|
||||
? { ...m, text: data, time: now(), streaming: false }
|
||||
@@ -381,7 +391,7 @@ export function AgentChatPage() {
|
||||
}
|
||||
|
||||
// If stream ended without a "done" event, mark as done
|
||||
setMessages((prev) =>
|
||||
patchMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === botId && m.streaming
|
||||
? {
|
||||
@@ -398,9 +408,9 @@ export function AgentChatPage() {
|
||||
} catch (e: any) {
|
||||
if (e.name === 'AbortError') {
|
||||
// Request was cancelled (e.g. user sent a new message), clean up silently
|
||||
setMessages((prev) => prev.filter((m) => m.id !== botId))
|
||||
patchMessages((prev) => prev.filter((m) => m.id !== botId))
|
||||
} else {
|
||||
setMessages((prev) =>
|
||||
patchMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === botId
|
||||
? {
|
||||
|
||||
52
web/src/stores/agentChatStore.ts
Normal file
52
web/src/stores/agentChatStore.ts
Normal file
@@ -0,0 +1,52 @@
|
||||
import { create } from 'zustand'
|
||||
|
||||
export interface AgentStep {
|
||||
id: string
|
||||
label: string
|
||||
status: 'planning' | 'pending' | 'running' | 'completed' | 'replanned'
|
||||
detail?: string
|
||||
}
|
||||
|
||||
export interface AgentMessage {
|
||||
id: string
|
||||
role: 'user' | 'bot'
|
||||
text: string
|
||||
time: string
|
||||
streaming?: boolean
|
||||
steps?: AgentStep[]
|
||||
}
|
||||
|
||||
interface AgentChatStoreState {
|
||||
activeUserId?: string
|
||||
messages: AgentMessage[]
|
||||
loading: boolean
|
||||
hydrated: boolean
|
||||
setActiveUserId: (userId?: string) => void
|
||||
setMessages: (messages: AgentMessage[]) => void
|
||||
updateMessages: (
|
||||
updater: (messages: AgentMessage[]) => AgentMessage[]
|
||||
) => void
|
||||
setLoading: (loading: boolean) => void
|
||||
setHydrated: (hydrated: boolean) => void
|
||||
resetForUser: (userId?: string, messages?: AgentMessage[]) => void
|
||||
}
|
||||
|
||||
export const useAgentChatStore = create<AgentChatStoreState>((set) => ({
|
||||
activeUserId: undefined,
|
||||
messages: [],
|
||||
loading: false,
|
||||
hydrated: false,
|
||||
setActiveUserId: (userId) => set({ activeUserId: userId }),
|
||||
setMessages: (messages) => set({ messages }),
|
||||
updateMessages: (updater) =>
|
||||
set((state) => ({ messages: updater(state.messages) })),
|
||||
setLoading: (loading) => set({ loading }),
|
||||
setHydrated: (hydrated) => set({ hydrated }),
|
||||
resetForUser: (userId, messages = []) =>
|
||||
set({
|
||||
activeUserId: userId,
|
||||
messages,
|
||||
loading: false,
|
||||
hydrated: true,
|
||||
}),
|
||||
}))
|
||||
Reference in New Issue
Block a user