diff --git a/api/server.go b/api/server.go index e65a8e90..3a5e645e 100644 --- a/api/server.go +++ b/api/server.go @@ -235,9 +235,10 @@ type ExchangeConfig struct { type UpdateModelConfigRequest struct { Models map[string]struct { - Enabled bool `json:"enabled"` - APIKey string `json:"api_key"` - CustomAPIURL string `json:"custom_api_url"` + Enabled bool `json:"enabled"` + APIKey string `json:"api_key"` + CustomAPIURL string `json:"custom_api_url"` + CustomModelName string `json:"custom_model_name"` } `json:"models"` } @@ -612,16 +613,23 @@ func (s *Server) handleUpdateModelConfigs(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - + // 更新每个模型的配置 for modelID, modelData := range req.Models { - err := s.database.UpdateAIModel(userID, modelID, modelData.Enabled, modelData.APIKey, modelData.CustomAPIURL) + err := s.database.UpdateAIModel(userID, modelID, modelData.Enabled, modelData.APIKey, modelData.CustomAPIURL, modelData.CustomModelName) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("更新模型 %s 失败: %v", modelID, err)}) return } } - + + // 重新加载该用户的所有交易员,使新配置立即生效 + err := s.traderManager.LoadUserTraders(s.database, userID) + if err != nil { + log.Printf("⚠️ 重新加载用户交易员到内存失败: %v", err) + // 这里不返回错误,因为模型配置已经成功更新到数据库 + } + log.Printf("✓ AI模型配置已更新: %+v", req.Models) c.JSON(http.StatusOK, gin.H{"message": "模型配置已更新"}) } @@ -649,7 +657,7 @@ func (s *Server) handleUpdateExchangeConfigs(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - + // 更新每个交易所的配置 for exchangeID, exchangeData := range req.Exchanges { err := s.database.UpdateExchange(userID, exchangeID, exchangeData.Enabled, exchangeData.APIKey, exchangeData.SecretKey, exchangeData.Testnet, exchangeData.HyperliquidWalletAddr, exchangeData.AsterUser, exchangeData.AsterSigner, exchangeData.AsterPrivateKey) @@ -658,7 +666,14 @@ func (s *Server) handleUpdateExchangeConfigs(c *gin.Context) { return } } - + + // 重新加载该用户的所有交易员,使新配置立即生效 + err := s.traderManager.LoadUserTraders(s.database, userID) + if err != nil { + log.Printf("⚠️ 重新加载用户交易员到内存失败: %v", err) + // 这里不返回错误,因为交易所配置已经成功更新到数据库 + } + log.Printf("✓ 交易所配置已更新: %+v", req.Exchanges) c.JSON(http.StatusOK, gin.H{"message": "交易所配置已更新"}) } @@ -725,12 +740,21 @@ func (s *Server) handleTraderList(c *gin.Context) { } } + // AIModelID 应该已经是 provider(如 "deepseek"),直接使用 + // 如果是旧数据格式(如 "admin_deepseek"),提取 provider 部分 + aiModelID := trader.AIModelID + // 兼容旧数据:如果包含下划线,提取最后一部分作为 provider + if strings.Contains(aiModelID, "_") { + parts := strings.Split(aiModelID, "_") + aiModelID = parts[len(parts)-1] + } + result = append(result, map[string]interface{}{ - "trader_id": trader.ID, - "trader_name": trader.Name, - "ai_model": trader.AIModelID, - "exchange_id": trader.ExchangeID, - "is_running": isRunning, + "trader_id": trader.ID, + "trader_name": trader.Name, + "ai_model": aiModelID, + "exchange_id": trader.ExchangeID, + "is_running": isRunning, "initial_balance": trader.InitialBalance, }) } @@ -763,21 +787,30 @@ func (s *Server) handleGetTraderConfig(c *gin.Context) { } } + // AIModelID 应该已经是 provider(如 "deepseek"),直接使用 + // 如果是旧数据格式(如 "admin_deepseek"),提取 provider 部分 + aiModelID := traderConfig.AIModelID + // 兼容旧数据:如果包含下划线,提取最后一部分作为 provider + if strings.Contains(aiModelID, "_") { + parts := strings.Split(aiModelID, "_") + aiModelID = parts[len(parts)-1] + } + result := map[string]interface{}{ - "trader_id": traderConfig.ID, - "trader_name": traderConfig.Name, - "ai_model": traderConfig.AIModelID, - "exchange_id": traderConfig.ExchangeID, - "initial_balance": traderConfig.InitialBalance, - "btc_eth_leverage": traderConfig.BTCETHLeverage, - "altcoin_leverage": traderConfig.AltcoinLeverage, - "trading_symbols": traderConfig.TradingSymbols, - "custom_prompt": traderConfig.CustomPrompt, + "trader_id": traderConfig.ID, + "trader_name": traderConfig.Name, + "ai_model": aiModelID, + "exchange_id": traderConfig.ExchangeID, + "initial_balance": traderConfig.InitialBalance, + "btc_eth_leverage": traderConfig.BTCETHLeverage, + "altcoin_leverage": traderConfig.AltcoinLeverage, + "trading_symbols": traderConfig.TradingSymbols, + "custom_prompt": traderConfig.CustomPrompt, "override_base_prompt": traderConfig.OverrideBasePrompt, - "is_cross_margin": traderConfig.IsCrossMargin, - "use_coin_pool": traderConfig.UseCoinPool, - "use_oi_top": traderConfig.UseOITop, - "is_running": isRunning, + "is_cross_margin": traderConfig.IsCrossMargin, + "use_coin_pool": traderConfig.UseCoinPool, + "use_oi_top": traderConfig.UseOITop, + "is_running": isRunning, } c.JSON(http.StatusOK, result) diff --git a/config/database.go b/config/database.go index 70da76b6..a056db88 100644 --- a/config/database.go +++ b/config/database.go @@ -185,6 +185,7 @@ func (d *Database) createTables() error { `ALTER TABLE traders ADD COLUMN use_coin_pool BOOLEAN DEFAULT 0`, // 是否使用COIN POOL信号源 `ALTER TABLE traders ADD COLUMN use_oi_top BOOLEAN DEFAULT 0`, // 是否使用OI TOP信号源 `ALTER TABLE ai_models ADD COLUMN custom_api_url TEXT DEFAULT ''`, // 自定义API地址 + `ALTER TABLE ai_models ADD COLUMN custom_model_name TEXT DEFAULT ''`, // 自定义模型名称 } for _, query := range alterQueries { @@ -362,15 +363,16 @@ type User struct { // AIModelConfig AI模型配置 type AIModelConfig struct { - ID string `json:"id"` - UserID string `json:"user_id"` - Name string `json:"name"` - Provider string `json:"provider"` - Enabled bool `json:"enabled"` - APIKey string `json:"apiKey"` - CustomAPIURL string `json:"customApiUrl"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID string `json:"id"` + UserID string `json:"user_id"` + Name string `json:"name"` + Provider string `json:"provider"` + Enabled bool `json:"enabled"` + APIKey string `json:"apiKey"` + CustomAPIURL string `json:"customApiUrl"` + CustomModelName string `json:"customModelName"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // ExchangeConfig 交易所配置 @@ -530,7 +532,10 @@ func (d *Database) UpdateUserOTPVerified(userID string, verified bool) error { // GetAIModels 获取用户的AI模型配置 func (d *Database) GetAIModels(userID string) ([]*AIModelConfig, error) { rows, err := d.db.Query(` - SELECT id, user_id, name, provider, enabled, api_key, COALESCE(custom_api_url, '') as custom_api_url, created_at, updated_at + SELECT id, user_id, name, provider, enabled, api_key, + COALESCE(custom_api_url, '') as custom_api_url, + COALESCE(custom_model_name, '') as custom_model_name, + created_at, updated_at FROM ai_models WHERE user_id = ? ORDER BY id `, userID) if err != nil { @@ -543,8 +548,8 @@ func (d *Database) GetAIModels(userID string) ([]*AIModelConfig, error) { for rows.Next() { var model AIModelConfig err := rows.Scan( - &model.ID, &model.UserID, &model.Name, &model.Provider, - &model.Enabled, &model.APIKey, &model.CustomAPIURL, + &model.ID, &model.UserID, &model.Name, &model.Provider, + &model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName, &model.CreatedAt, &model.UpdatedAt, ) if err != nil { @@ -557,52 +562,50 @@ func (d *Database) GetAIModels(userID string) ([]*AIModelConfig, error) { } // UpdateAIModel 更新AI模型配置,如果不存在则创建用户特定配置 -func (d *Database) UpdateAIModel(userID, id string, enabled bool, apiKey, customAPIURL string) error { - // 首先尝试更新现有的用户配置 - result, err := d.db.Exec(` - UPDATE ai_models SET enabled = ?, api_key = ?, custom_api_url = ? WHERE id = ? AND user_id = ? - `, enabled, apiKey, customAPIURL, id, userID) - if err != nil { - return err - } - - // 检查是否有行被更新 - rowsAffected, err := result.RowsAffected() - if err != nil { - return err - } - - // 如果没有行被更新,说明用户没有这个模型的配置,需要创建 - if rowsAffected == 0 { - // 获取模型的基本信息 - var name, provider string - err = d.db.QueryRow(` - SELECT name, provider FROM ai_models WHERE provider = ? LIMIT 1 - `, id).Scan(&name, &provider) - if err != nil { - // 如果找不到基本信息,使用默认值 - if id == "deepseek" { - name = "DeepSeek AI" - provider = "deepseek" - } else if id == "qwen" { - name = "Qwen AI" - provider = "qwen" - } else { - name = id + " AI" - provider = id - } - } - - // 创建用户特定的配置 - userModelID := fmt.Sprintf("%s_%s", userID, id) +func (d *Database) UpdateAIModel(userID, id string, enabled bool, apiKey, customAPIURL, customModelName string) error { + // id 参数实际上是 provider(如 "deepseek", "qwen") + provider := id + + // 先查找用户是否已有这个 provider 的配置 + var existingID string + err := d.db.QueryRow(` + SELECT id FROM ai_models WHERE user_id = ? AND provider = ? LIMIT 1 + `, userID, provider).Scan(&existingID) + + if err == nil { + // 找到了现有配置,更新它 _, err = d.db.Exec(` - INSERT INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now')) - `, userModelID, userID, name, provider, enabled, apiKey, customAPIURL) + UPDATE ai_models SET enabled = ?, api_key = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now') + WHERE id = ? AND user_id = ? + `, enabled, apiKey, customAPIURL, customModelName, existingID, userID) return err } - - return nil + + // 没有找到现有配置,创建新的 + // 获取模型的基本信息 + var name string + err = d.db.QueryRow(` + SELECT name FROM ai_models WHERE provider = ? LIMIT 1 + `, provider).Scan(&name) + if err != nil { + // 如果找不到基本信息,使用默认值 + if provider == "deepseek" { + name = "DeepSeek AI" + } else if provider == "qwen" { + name = "Qwen AI" + } else { + name = provider + " AI" + } + } + + // 创建用户特定的配置 + userModelID := fmt.Sprintf("%s_%s", userID, provider) + _, err = d.db.Exec(` + INSERT INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url, custom_model_name, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now')) + `, userModelID, userID, name, provider, enabled, apiKey, customAPIURL, customModelName) + + return err } // GetExchanges 获取用户的交易所配置 diff --git a/manager/trader_manager.go b/manager/trader_manager.go index 014eefd1..e55b4acf 100644 --- a/manager/trader_manager.go +++ b/manager/trader_manager.go @@ -94,7 +94,9 @@ func (tm *TraderManager) LoadTradersFromDatabase(database *config.Database) erro var aiModelCfg *config.AIModelConfig for _, model := range aiModels { - if model.ID == traderCfg.AIModelID { + // 使用 provider 来匹配,因为 AIModelID 存储的是 provider(如 "deepseek") + // 而 model.ID 可能是 "admin_deepseek" + if model.Provider == traderCfg.AIModelID { aiModelCfg = model break } @@ -202,6 +204,8 @@ func (tm *TraderManager) addTraderFromDB(traderCfg *config.TraderRecord, aiModel UseQwen: aiModelCfg.Provider == "qwen", DeepSeekKey: "", QwenKey: "", + CustomAPIURL: aiModelCfg.CustomAPIURL, // 自定义API URL + CustomModelName: aiModelCfg.CustomModelName, // 自定义模型名称 ScanInterval: time.Duration(traderCfg.ScanIntervalMinutes) * time.Minute, InitialBalance: traderCfg.InitialBalance, BTCETHLeverage: traderCfg.BTCETHLeverage, @@ -306,6 +310,8 @@ func (tm *TraderManager) AddTraderFromDB(traderCfg *config.TraderRecord, aiModel UseQwen: aiModelCfg.Provider == "qwen", DeepSeekKey: "", QwenKey: "", + CustomAPIURL: aiModelCfg.CustomAPIURL, // 自定义API URL + CustomModelName: aiModelCfg.CustomModelName, // 自定义模型名称 ScanInterval: time.Duration(traderCfg.ScanIntervalMinutes) * time.Minute, InitialBalance: traderCfg.InitialBalance, BTCETHLeverage: traderCfg.BTCETHLeverage, @@ -616,7 +622,8 @@ func (tm *TraderManager) LoadUserTraders(database *config.Database, userID strin var aiModelCfg *config.AIModelConfig for _, model := range aiModels { - if model.ID == traderCfg.AIModelID { + // 使用 provider 来匹配,因为 AIModelID 存储的是 provider(如 "deepseek") + if model.Provider == traderCfg.AIModelID { aiModelCfg = model break } diff --git a/mcp/client.go b/mcp/client.go index eefddbe2..9191dfaf 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io" + "log" "net/http" "strings" "time" @@ -23,7 +24,6 @@ const ( type Client struct { Provider Provider APIKey string - SecretKey string // 阿里云需要 BaseURL string Model string Timeout time.Duration @@ -41,20 +41,53 @@ func New() *Client { } // SetDeepSeekAPIKey 设置DeepSeek API密钥 -func (client *Client) SetDeepSeekAPIKey(apiKey string) { +// customURL 为空时使用默认URL,customModel 为空时使用默认模型 +func (client *Client) SetDeepSeekAPIKey(apiKey string, customURL string, customModel string) { client.Provider = ProviderDeepSeek client.APIKey = apiKey - client.BaseURL = "https://api.deepseek.com/v1" - client.Model = "deepseek-chat" + if customURL != "" { + client.BaseURL = customURL + log.Printf("🔧 [MCP] DeepSeek 使用自定义 BaseURL: %s", customURL) + } else { + client.BaseURL = "https://api.deepseek.com/v1" + log.Printf("🔧 [MCP] DeepSeek 使用默认 BaseURL: %s", client.BaseURL) + } + if customModel != "" { + client.Model = customModel + log.Printf("🔧 [MCP] DeepSeek 使用自定义 Model: %s", customModel) + } else { + client.Model = "deepseek-chat" + log.Printf("🔧 [MCP] DeepSeek 使用默认 Model: %s", client.Model) + } + // 打印 API Key 的前后各4位用于验证 + if len(apiKey) > 8 { + log.Printf("🔧 [MCP] DeepSeek API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:]) + } } // SetQwenAPIKey 设置阿里云Qwen API密钥 -func (client *Client) SetQwenAPIKey(apiKey, secretKey string) { +// customURL 为空时使用默认URL,customModel 为空时使用默认模型 +func (client *Client) SetQwenAPIKey(apiKey string, customURL string, customModel string) { client.Provider = ProviderQwen client.APIKey = apiKey - client.SecretKey = secretKey - client.BaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1" - client.Model = "qwen-plus" // 可选: qwen-turbo, qwen-plus, qwen-max + if customURL != "" { + client.BaseURL = customURL + log.Printf("🔧 [MCP] Qwen 使用自定义 BaseURL: %s", customURL) + } else { + client.BaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1" + log.Printf("🔧 [MCP] Qwen 使用默认 BaseURL: %s", client.BaseURL) + } + if customModel != "" { + client.Model = customModel + log.Printf("🔧 [MCP] Qwen 使用自定义 Model: %s", customModel) + } else { + client.Model = "qwen-plus" // 可选: qwen-turbo, qwen-plus, qwen-max + log.Printf("🔧 [MCP] Qwen 使用默认 Model: %s", client.Model) + } + // 打印 API Key 的前后各4位用于验证 + if len(apiKey) > 8 { + log.Printf("🔧 [MCP] Qwen API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:]) + } } // SetCustomAPI 设置自定义OpenAI兼容API @@ -125,6 +158,16 @@ func (client *Client) CallWithMessages(systemPrompt, userPrompt string) (string, // callOnce 单次调用AI API(内部使用) func (client *Client) callOnce(systemPrompt, userPrompt string) (string, error) { + // 打印当前 AI 配置 + log.Printf("📡 [MCP] AI 请求配置:") + log.Printf(" Provider: %s", client.Provider) + log.Printf(" BaseURL: %s", client.BaseURL) + log.Printf(" Model: %s", client.Model) + log.Printf(" UseFullURL: %v", client.UseFullURL) + if len(client.APIKey) > 8 { + log.Printf(" API Key: %s...%s", client.APIKey[:4], client.APIKey[len(client.APIKey)-4:]) + } + // 构建 messages 数组 messages := []map[string]string{} @@ -167,6 +210,8 @@ func (client *Client) callOnce(systemPrompt, userPrompt string) (string, error) // 默认行为:添加/chat/completions url = fmt.Sprintf("%s/chat/completions", client.BaseURL) } + log.Printf("📡 [MCP] 请求 URL: %s", url) + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) if err != nil { return "", fmt.Errorf("创建请求失败: %w", err) diff --git a/trader/auto_trader.go b/trader/auto_trader.go index 35472ab5..3c9a5e55 100644 --- a/trader/auto_trader.go +++ b/trader/auto_trader.go @@ -121,13 +121,21 @@ func NewAutoTrader(config AutoTraderConfig) (*AutoTrader, error) { mcpClient.SetCustomAPI(config.CustomAPIURL, config.CustomAPIKey, config.CustomModelName) log.Printf("🤖 [%s] 使用自定义AI API: %s (模型: %s)", config.Name, config.CustomAPIURL, config.CustomModelName) } else if config.UseQwen || config.AIModel == "qwen" { - // 使用Qwen - mcpClient.SetQwenAPIKey(config.QwenKey, "") - log.Printf("🤖 [%s] 使用阿里云Qwen AI", config.Name) + // 使用Qwen (支持自定义URL和Model) + mcpClient.SetQwenAPIKey(config.QwenKey, config.CustomAPIURL, config.CustomModelName) + if config.CustomAPIURL != "" || config.CustomModelName != "" { + log.Printf("🤖 [%s] 使用阿里云Qwen AI (自定义URL: %s, 模型: %s)", config.Name, config.CustomAPIURL, config.CustomModelName) + } else { + log.Printf("🤖 [%s] 使用阿里云Qwen AI", config.Name) + } } else { - // 默认使用DeepSeek - mcpClient.SetDeepSeekAPIKey(config.DeepSeekKey) - log.Printf("🤖 [%s] 使用DeepSeek AI", config.Name) + // 默认使用DeepSeek (支持自定义URL和Model) + mcpClient.SetDeepSeekAPIKey(config.DeepSeekKey, config.CustomAPIURL, config.CustomModelName) + if config.CustomAPIURL != "" || config.CustomModelName != "" { + log.Printf("🤖 [%s] 使用DeepSeek AI (自定义URL: %s, 模型: %s)", config.Name, config.CustomAPIURL, config.CustomModelName) + } else { + log.Printf("🤖 [%s] 使用DeepSeek AI", config.Name) + } } // 初始化币种池API diff --git a/web/src/components/AITradersPage.tsx b/web/src/components/AITradersPage.tsx index 9c3da522..69cbb7b3 100644 --- a/web/src/components/AITradersPage.tsx +++ b/web/src/components/AITradersPage.tsx @@ -131,19 +131,19 @@ export function AITradersPage({ onTraderSelect }: AITradersPageProps) { const handleCreateTrader = async (data: CreateTraderRequest) => { try { - const model = allModels?.find(m => m.id === data.ai_model_id); + const model = allModels?.find(m => m.provider === data.ai_model_id); const exchange = allExchanges?.find(e => e.id === data.exchange_id); - + if (!model?.enabled) { alert(t('modelNotConfigured', language)); return; } - + if (!exchange?.enabled) { alert(t('exchangeNotConfigured', language)); return; } - + await api.createTrader(data); setShowCreateModal(false); mutateTraders(); @@ -166,9 +166,9 @@ export function AITradersPage({ onTraderSelect }: AITradersPageProps) { const handleSaveEditTrader = async (data: CreateTraderRequest) => { if (!editingTrader) return; - + try { - const model = enabledModels?.find(m => m.id === data.ai_model_id); + const model = enabledModels?.find(m => m.provider === data.ai_model_id); const exchange = enabledExchanges?.find(e => e.id === data.exchange_id); if (!model) { @@ -248,24 +248,26 @@ export function AITradersPage({ onTraderSelect }: AITradersPageProps) { const handleDeleteModelConfig = async (modelId: string) => { if (!confirm(t('confirmDeleteModel', language))) return; - + try { - const updatedModels = allModels?.map(m => - m.id === modelId ? { ...m, apiKey: '', enabled: false } : m + const updatedModels = allModels?.map(m => + m.id === modelId ? { ...m, apiKey: '', customApiUrl: '', customModelName: '', enabled: false } : m ) || []; - + const request = { models: Object.fromEntries( updatedModels.map(model => [ - model.id, + model.provider, // 使用 provider 而不是 id { enabled: model.enabled, - api_key: model.apiKey || '' + api_key: model.apiKey || '', + custom_api_url: model.customApiUrl || '', + custom_model_name: model.customModelName || '' } ]) ) }; - + await api.updateModelConfigs(request); setAllModels(updatedModels); setShowModelModal(false); @@ -276,7 +278,7 @@ export function AITradersPage({ onTraderSelect }: AITradersPageProps) { } }; - const handleSaveModelConfig = async (modelId: string, apiKey: string, customApiUrl?: string) => { + const handleSaveModelConfig = async (modelId: string, apiKey: string, customApiUrl?: string, customModelName?: string) => { try { // 找到要配置的模型(从supportedModels中) const modelToUpdate = supportedModels?.find(m => m.id === modelId); @@ -288,37 +290,38 @@ export function AITradersPage({ onTraderSelect }: AITradersPageProps) { // 创建或更新用户的模型配置 const existingModel = allModels?.find(m => m.id === modelId); let updatedModels; - + if (existingModel) { // 更新现有配置 - updatedModels = allModels?.map(m => - m.id === modelId ? { ...m, apiKey, customApiUrl: customApiUrl || '', enabled: true } : m + updatedModels = allModels?.map(m => + m.id === modelId ? { ...m, apiKey, customApiUrl: customApiUrl || '', customModelName: customModelName || '', enabled: true } : m ) || []; } else { // 添加新配置 - const newModel = { ...modelToUpdate, apiKey, customApiUrl: customApiUrl || '', enabled: true }; + const newModel = { ...modelToUpdate, apiKey, customApiUrl: customApiUrl || '', customModelName: customModelName || '', enabled: true }; updatedModels = [...(allModels || []), newModel]; } - + const request = { models: Object.fromEntries( updatedModels.map(model => [ - model.id, + model.provider, // 使用 provider 而不是 id { enabled: model.enabled, api_key: model.apiKey || '', - custom_api_url: model.customApiUrl || '' + custom_api_url: model.customApiUrl || '', + custom_model_name: model.customModelName || '' } ]) ) }; - + await api.updateModelConfigs(request); - + // 重新获取用户配置以确保数据同步 const refreshedModels = await api.getModelConfigs(); setAllModels(refreshedModels); - + setShowModelModal(false); setEditingModel(null); } catch (error) { @@ -910,7 +913,7 @@ function ModelConfigModal({ allModels: AIModel[]; configuredModels: AIModel[]; editingModelId: string | null; - onSave: (modelId: string, apiKey: string, baseUrl?: string) => void; + onSave: (modelId: string, apiKey: string, baseUrl?: string, modelName?: string) => void; onDelete: (modelId: string) => void; onClose: () => void; language: Language; @@ -918,25 +921,27 @@ function ModelConfigModal({ const [selectedModelId, setSelectedModelId] = useState(editingModelId || ''); const [apiKey, setApiKey] = useState(''); const [baseUrl, setBaseUrl] = useState(''); + const [modelName, setModelName] = useState(''); // 获取当前编辑的模型信息 - 编辑时从已配置的模型中查找,新建时从所有支持的模型中查找 - const selectedModel = editingModelId - ? configuredModels?.find(m => m.id === selectedModelId) + const selectedModel = editingModelId + ? configuredModels?.find(m => m.id === selectedModelId) : allModels?.find(m => m.id === selectedModelId); - // 如果是编辑现有模型,初始化API Key和Base URL + // 如果是编辑现有模型,初始化API Key、Base URL和Model Name useEffect(() => { if (editingModelId && selectedModel) { setApiKey(selectedModel.apiKey || ''); setBaseUrl(selectedModel.customApiUrl || ''); + setModelName(selectedModel.customModelName || ''); } }, [editingModelId, selectedModel]); const handleSubmit = (e: React.FormEvent) => { e.preventDefault(); if (!selectedModelId || !apiKey.trim()) return; - - onSave(selectedModelId, apiKey.trim(), baseUrl.trim() || undefined); + + onSave(selectedModelId, apiKey.trim(), baseUrl.trim() || undefined, modelName.trim() || undefined); }; // 可选择的模型列表(所有支持的模型) @@ -1047,6 +1052,23 @@ function ModelConfigModal({ +