package mcp

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"
)

const (
	// ProviderCustom is the provider name SetAPIKey assigns when the client is
	// pointed at a user-supplied OpenAI-compatible endpoint.
	ProviderCustom = "custom"

	// MCPClientTemperature is a package-level temperature value.
	// NOTE(review): not referenced in this part of the file; presumably the
	// default for Config.Temperature — confirm against the config code.
	MCPClientTemperature = 0.5
)

var (
	// DefaultTimeout and MaxRetryTimes are package-level defaults.
	// NOTE(review): neither is read in this part of the file; presumably they
	// seed DefaultConfig — confirm.
	DefaultTimeout = 120 * time.Second
	MaxRetryTimes  = 3

	// retryableErrors lists error-message substrings.
	// NOTE(review): the retry check below reads config.RetryableErrors, not
	// this slice directly; presumably this is its default value — confirm.
	retryableErrors = []string{
		"EOF",
		"timeout",
		"connection reset",
		"connection refused",
		"temporary failure",
		"no such host",
		"stream error",   // HTTP/2 stream error
		"INTERNAL_ERROR", // Server internal error
	}

	// TokenUsageCallback is called after each AI request with token usage info
	TokenUsageCallback func(usage TokenUsage)
)

// TokenUsage represents token usage from AI API response
type TokenUsage struct {
	Provider         string
	Model            string
	PromptTokens     int
	CompletionTokens int
	TotalTokens      int
}

// Client AI API configuration
type Client struct {
	Provider   string
	APIKey     string
	BaseURL    string
	Model      string
	UseFullURL bool // Whether to use full URL (without appending /chat/completions)
	MaxTokens  int  // Maximum tokens for AI response

	httpClient *http.Client
	logger     Logger  // Logger (replaceable)
	config     *Config // Config object (stores all configurations)

	// hooks are used to implement dynamic dispatch (polymorphism)
	// When DeepSeekClient embeds Client, hooks point to DeepSeekClient
	// This way methods called in call() are automatically dispatched to the overridden version in subclass
	hooks clientHooks
}

// New creates default client (backward compatible)
//
// Deprecated: Recommend using NewClient(...opts) for better flexibility
func New() AIClient {
	return NewClient()
}

// NewClient creates client (supports options pattern)
//
// Usage examples:
//
//	// Basic usage (backward compatible)
//	client := mcp.NewClient()
//
//	// Custom logger
//	client := mcp.NewClient(mcp.WithLogger(customLogger))
//
//	// Custom timeout
//	client := mcp.NewClient(mcp.WithTimeout(60*time.Second))
//
//	// Combine multiple options
//	client := mcp.NewClient(
//		mcp.WithDeepSeekConfig("sk-xxx"),
//		mcp.WithLogger(customLogger),
//		mcp.WithTimeout(60*time.Second),
//	)
func NewClient(opts ...ClientOption) AIClient {
	// 1. Create default config
	cfg := DefaultConfig()

	// 2. Apply user options
	for _, opt := range opts {
		opt(cfg)
	}

	// 3. Create client instance
	client := &Client{
		Provider:   cfg.Provider,
		APIKey:     cfg.APIKey,
		BaseURL:    cfg.BaseURL,
		Model:      cfg.Model,
		MaxTokens:  cfg.MaxTokens,
		UseFullURL: cfg.UseFullURL,
		httpClient: cfg.HTTPClient,
		logger:     cfg.Logger,
		config:     cfg,
	}

	// 4. Set default Provider (if not set). Note that BaseURL and Model are
	// replaced together with the provider in this case.
	if client.Provider == "" {
		client.Provider = ProviderDeepSeek
		client.BaseURL = DefaultDeepSeekBaseURL
		client.Model = DefaultDeepSeekModel
	}

	// 5. Set hooks to point to self
	client.hooks = client

	return client
}

// SetAPIKey configures the client for a custom OpenAI-compatible API.
//
// The provider is switched to ProviderCustom. If apiURL ends with "#", the
// marker is stripped and the remaining URL is used verbatim as the request
// URL; otherwise apiURL is treated as a base URL to which buildUrl appends
// /chat/completions.
func (client *Client) SetAPIKey(apiKey, apiURL, customModel string) {
	client.Provider = ProviderCustom
	client.APIKey = apiKey
	client.Model = customModel

	// A trailing "#" means "use the full URL as-is".
	client.UseFullURL = strings.HasSuffix(apiURL, "#")
	client.BaseURL = strings.TrimSuffix(apiURL, "#")
}

// SetTimeout sets the timeout on the client's underlying *http.Client.
// The http.Client is mutated in place, so the change is visible to anything
// else holding the same pointer.
func (client *Client) SetTimeout(timeout time.Duration) {
	// Guard against a config that supplied no HTTP client; without this the
	// assignment below would dereference nil.
	if client.httpClient == nil {
		client.httpClient = &http.Client{}
	}
	client.httpClient.Timeout = timeout
}

// CallWithMessages template method - fixed retry flow (cannot be overridden)
//
// It performs up to config.MaxRetries attempts of hooks.call. A non-retryable
// error (per hooks.isRetryableError) is returned immediately; between
// retryable failures it sleeps RetryWaitBase * attempt. At least one attempt
// is always made, even when MaxRetries is configured as zero or negative.
func (client *Client) CallWithMessages(systemPrompt, userPrompt string) (string, error) {
	if client.APIKey == "" {
		return "", fmt.Errorf("AI API key not set, please call SetAPIKey first")
	}

	// With MaxRetries <= 0 the loop below would never run, no request would be
	// sent, and the final error would wrap a nil lastErr. Clamp to one attempt.
	maxRetries := client.config.MaxRetries
	if maxRetries < 1 {
		maxRetries = 1
	}

	var lastErr error
	for attempt := 1; attempt <= maxRetries; attempt++ {
		if attempt > 1 {
			client.logger.Warnf("⚠️ AI API call failed, retrying (%d/%d)...", attempt, maxRetries)
		}

		// Call the fixed single-call flow
		result, err := client.hooks.call(systemPrompt, userPrompt)
		if err == nil {
			if attempt > 1 {
				client.logger.Infof("✓ AI API retry succeeded")
			}
			return result, nil
		}
		lastErr = err

		// Check if error is retryable via hooks (supports custom retry strategy in subclass)
		if !client.hooks.isRetryableError(err) {
			return "", err
		}

		// Wait before retry (linear back-off; skipped after the last attempt)
		if attempt < maxRetries {
			waitTime := client.config.RetryWaitBase * time.Duration(attempt)
			client.logger.Infof("⏳ Waiting %v before retry...", waitTime)
			time.Sleep(waitTime)
		}
	}

	return "", fmt.Errorf("still failed after %d retries: %w", maxRetries, lastErr)
}

// setAuthHeader writes the client's API key into reqHeader as a Bearer
// Authorization header.
func (client *Client) setAuthHeader(reqHeader http.Header) {
	reqHeader.Set("Authorization", fmt.Sprintf("Bearer %s", client.APIKey))
}

// buildMCPRequestBody builds the chat-completions request body for a single
// user prompt with an optional system prompt.
//
// The system message is included only when systemPrompt is non-empty. The
// token limit is sent as max_completion_tokens for ProviderOpenAI and as
// max_tokens for every other provider.
func (client *Client) buildMCPRequestBody(systemPrompt, userPrompt string) map[string]any {
	// At most two messages: optional system + user.
	messages := make([]map[string]string, 0, 2)
	if systemPrompt != "" {
		messages = append(messages, map[string]string{
			"role":    "system",
			"content": systemPrompt,
		})
	}
	messages = append(messages, map[string]string{
		"role":    "user",
		"content": userPrompt,
	})

	requestBody := map[string]any{
		"model":       client.Model,
		"messages":    messages,
		"temperature": client.config.Temperature, // Use configured temperature
	}

	// OpenAI newer models use max_completion_tokens instead of max_tokens
	if client.Provider == ProviderOpenAI {
		requestBody["max_completion_tokens"] = client.MaxTokens
	} else {
		requestBody["max_tokens"] = client.MaxTokens
	}

	return requestBody
}

// marshalRequestBody serializes the request body to JSON. It is invoked via
// hooks so it can be overridden.
func (client *Client) marshalRequestBody(requestBody map[string]any) ([]byte, error) {
	jsonData, err := json.Marshal(requestBody)
	if err != nil {
		return nil, fmt.Errorf("failed to serialize request: %w", err)
	}
	return jsonData, nil
}

// parseMCPResponse parses the response body and returns only its text
// content, discarding any tool calls.
func (client *Client) parseMCPResponse(body []byte) (string, error) {
	r, err := client.parseMCPResponseFull(body)
	if err != nil {
		return "", err
	}
	return r.Content, nil
}

// parseMCPResponseFull parses the OpenAI-format response body and returns both
// the text content and any tool calls.
func (client *Client) parseMCPResponseFull(body []byte) (*LLMResponse, error) { var result struct { Choices []struct { Message struct { Content string `json:"content"` ToolCalls []ToolCall `json:"tool_calls"` } `json:"message"` } `json:"choices"` Usage struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` TotalTokens int `json:"total_tokens"` } `json:"usage"` } if err := json.Unmarshal(body, &result); err != nil { return nil, fmt.Errorf("failed to parse response: %w", err) } if len(result.Choices) == 0 { return nil, fmt.Errorf("API returned empty response") } // Report token usage if callback is set if TokenUsageCallback != nil && result.Usage.TotalTokens > 0 { TokenUsageCallback(TokenUsage{ Provider: client.Provider, Model: client.Model, PromptTokens: result.Usage.PromptTokens, CompletionTokens: result.Usage.CompletionTokens, TotalTokens: result.Usage.TotalTokens, }) } msg := result.Choices[0].Message return &LLMResponse{ Content: msg.Content, ToolCalls: msg.ToolCalls, }, nil } func (client *Client) buildUrl() string { if client.UseFullURL { return client.BaseURL } return fmt.Sprintf("%s/chat/completions", client.BaseURL) } func (client *Client) buildRequest(url string, jsonData []byte) (*http.Request, error) { // Create HTTP request req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) if err != nil { return nil, fmt.Errorf("fail to build request: %w", err) } req.Header.Set("Content-Type", "application/json") // Set auth header via hooks (supports overriding in subclass) client.hooks.setAuthHeader(req.Header) return req, nil } // call single AI API call (fixed flow, cannot be overridden) func (client *Client) call(systemPrompt, userPrompt string) (string, error) { // Print current AI configuration client.logger.Infof("📡 [%s] Request AI Server: BaseURL: %s", client.String(), client.BaseURL) client.logger.Debugf("[%s] UseFullURL: %v", client.String(), client.UseFullURL) if len(client.APIKey) > 8 { 
client.logger.Debugf("[%s] API Key: %s...%s", client.String(), client.APIKey[:4], client.APIKey[len(client.APIKey)-4:]) } // Step 1: Build request body (via hooks for dynamic dispatch) requestBody := client.hooks.buildMCPRequestBody(systemPrompt, userPrompt) // Step 2: Serialize request body (via hooks for dynamic dispatch) jsonData, err := client.hooks.marshalRequestBody(requestBody) if err != nil { return "", err } // Step 3: Build URL (via hooks for dynamic dispatch) url := client.hooks.buildUrl() client.logger.Infof("📡 [MCP %s] Request URL: %s", client.String(), url) // Step 4: Create HTTP request (fixed logic) req, err := client.hooks.buildRequest(url, jsonData) if err != nil { return "", fmt.Errorf("failed to create request: %w", err) } // Step 5: Send HTTP request (fixed logic) resp, err := client.httpClient.Do(req) if err != nil { return "", fmt.Errorf("failed to send request: %w", err) } defer resp.Body.Close() // Step 6: Read response body (fixed logic) body, err := io.ReadAll(resp.Body) if err != nil { return "", fmt.Errorf("failed to read response: %w", err) } // Step 7: Check HTTP status code (fixed logic) if resp.StatusCode != http.StatusOK { return "", fmt.Errorf("API returned error (status %d): %s", resp.StatusCode, string(body)) } // Step 8: Parse response (via hooks for dynamic dispatch) result, err := client.hooks.parseMCPResponse(body) if err != nil { return "", fmt.Errorf("fail to parse AI server response: %w", err) } return result, nil } func (client *Client) String() string { return fmt.Sprintf("[Provider: %s, Model: %s]", client.Provider, client.Model) } // isRetryableError determines if error is retryable (network errors, timeouts, etc.) func (client *Client) isRetryableError(err error) bool { errStr := err.Error() // Network errors, timeouts, EOF, etc. 
can be retried for _, retryable := range client.config.RetryableErrors { if strings.Contains(errStr, retryable) { return true } } return false } // ============================================================ // Builder Pattern API (Advanced Features) // ============================================================ // CallWithRequest calls AI API using Request object (supports advanced features) // // This method supports: // - Multi-turn conversation history // - Fine-grained parameter control (temperature, top_p, penalties, etc.) // - Function Calling / Tools // - Streaming response (future support) // // Usage example: // request := NewRequestBuilder(). // WithSystemPrompt("You are helpful"). // WithUserPrompt("Hello"). // WithTemperature(0.8). // Build() // result, err := client.CallWithRequest(request) func (client *Client) CallWithRequest(req *Request) (string, error) { if client.APIKey == "" { return "", fmt.Errorf("AI API key not set, please call SetAPIKey first") } // If Model is not set in Request, use Client's Model if req.Model == "" { req.Model = client.Model } // Fixed retry flow var lastErr error maxRetries := client.config.MaxRetries for attempt := 1; attempt <= maxRetries; attempt++ { if attempt > 1 { client.logger.Warnf("⚠️ AI API call failed, retrying (%d/%d)...", attempt, maxRetries) } // Call single request result, err := client.callWithRequest(req) if err == nil { if attempt > 1 { client.logger.Infof("✓ AI API retry succeeded") } return result, nil } lastErr = err // Check if error is retryable if !client.hooks.isRetryableError(err) { return "", err } // Wait before retry if attempt < maxRetries { waitTime := client.config.RetryWaitBase * time.Duration(attempt) client.logger.Infof("⏳ Waiting %v before retry...", waitTime) time.Sleep(waitTime) } } return "", fmt.Errorf("still failed after %d retries: %w", maxRetries, lastErr) } // CallWithRequestFull calls the AI API and returns both text content and tool calls. 
func (client *Client) CallWithRequestFull(req *Request) (*LLMResponse, error) { if client.APIKey == "" { return nil, fmt.Errorf("AI API key not set, please call SetAPIKey first") } if req.Model == "" { req.Model = client.Model } var lastErr error maxRetries := client.config.MaxRetries for attempt := 1; attempt <= maxRetries; attempt++ { if attempt > 1 { client.logger.Warnf("⚠️ AI API call failed, retrying (%d/%d)...", attempt, maxRetries) } result, err := client.callWithRequestFull(req) if err == nil { return result, nil } lastErr = err if !client.hooks.isRetryableError(err) { return nil, err } if attempt < maxRetries { waitTime := client.config.RetryWaitBase * time.Duration(attempt) time.Sleep(waitTime) } } return nil, fmt.Errorf("still failed after %d retries: %w", maxRetries, lastErr) } // callWithRequestFull single call that returns LLMResponse (content + tool calls). func (client *Client) callWithRequestFull(req *Request) (*LLMResponse, error) { client.logger.Infof("📡 [%s] Request AI Server (full): BaseURL: %s", client.String(), client.BaseURL) requestBody := client.hooks.buildRequestBodyFromRequest(req) jsonData, err := client.hooks.marshalRequestBody(requestBody) if err != nil { return nil, err } url := client.hooks.buildUrl() httpReq, err := client.hooks.buildRequest(url, jsonData) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } resp, err := client.httpClient.Do(httpReq) if err != nil { return nil, fmt.Errorf("failed to send request: %w", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response: %w", err) } if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("API returned error (status %d): %s", resp.StatusCode, string(body)) } return client.hooks.parseMCPResponseFull(body) } // callWithRequest single AI API call (using Request object) func (client *Client) callWithRequest(req *Request) (string, error) { // Print current AI configuration 
client.logger.Infof("📡 [%s] Request AI Server with Builder: BaseURL: %s", client.String(), client.BaseURL) client.logger.Debugf("[%s] Messages count: %d", client.String(), len(req.Messages)) requestBody := client.hooks.buildRequestBodyFromRequest(req) jsonData, err := client.hooks.marshalRequestBody(requestBody) if err != nil { return "", err } url := client.hooks.buildUrl() client.logger.Infof("📡 [MCP %s] Request URL: %s", client.String(), url) httpReq, err := client.hooks.buildRequest(url, jsonData) if err != nil { return "", fmt.Errorf("failed to create request: %w", err) } resp, err := client.httpClient.Do(httpReq) if err != nil { return "", fmt.Errorf("failed to send request: %w", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return "", fmt.Errorf("failed to read response: %w", err) } if resp.StatusCode != http.StatusOK { return "", fmt.Errorf("API returned error (status %d): %s", resp.StatusCode, string(body)) } result, err := client.hooks.parseMCPResponse(body) if err != nil { return "", fmt.Errorf("fail to parse AI server response: %w", err) } return result, nil } // buildRequestBodyFromRequest builds request body from Request object func (client *Client) buildRequestBodyFromRequest(req *Request) map[string]any { // Convert Message to API format — must use map[string]any to support // tool-call messages (tool_calls, tool_call_id fields). messages := make([]map[string]any, 0, len(req.Messages)) for _, msg := range req.Messages { m := map[string]any{"role": msg.Role} if len(msg.ToolCalls) > 0 { // Assistant message that contains tool invocations. // content must be null/omitted for OpenAI compatibility. m["tool_calls"] = msg.ToolCalls } else if msg.ToolCallID != "" { // Tool result message (role="tool"). 
m["tool_call_id"] = msg.ToolCallID m["content"] = msg.Content } else { m["content"] = msg.Content } messages = append(messages, m) } // Build basic request body requestBody := map[string]interface{}{ "model": req.Model, "messages": messages, } // Add optional parameters (only add non-nil parameters) if req.Temperature != nil { requestBody["temperature"] = *req.Temperature } else { // If not set in Request, use Client's configuration requestBody["temperature"] = client.config.Temperature } // OpenAI newer models use max_completion_tokens instead of max_tokens tokenKey := "max_tokens" if client.Provider == ProviderOpenAI { tokenKey = "max_completion_tokens" } if req.MaxTokens != nil { requestBody[tokenKey] = *req.MaxTokens } else { // If not set in Request, use Client's MaxTokens requestBody[tokenKey] = client.MaxTokens } if req.TopP != nil { requestBody["top_p"] = *req.TopP } if req.FrequencyPenalty != nil { requestBody["frequency_penalty"] = *req.FrequencyPenalty } if req.PresencePenalty != nil { requestBody["presence_penalty"] = *req.PresencePenalty } if len(req.Stop) > 0 { requestBody["stop"] = req.Stop } if len(req.Tools) > 0 { requestBody["tools"] = req.Tools } if req.ToolChoice != "" { requestBody["tool_choice"] = req.ToolChoice } if req.Stream { requestBody["stream"] = true } return requestBody } // CallWithRequestStream streams the LLM response via SSE (Server-Sent Events). // onChunk is called with the full accumulated text so far after each received chunk. // Returns the complete final text when the stream ends. // // Idle timeout: if no chunk arrives for 30 seconds the stream is cancelled automatically. // This prevents the scanner from blocking indefinitely on a hung or stalled connection. 
func (client *Client) CallWithRequestStream(req *Request, onChunk func(string)) (string, error) { if client.APIKey == "" { return "", fmt.Errorf("AI API key not set") } if req.Model == "" { req.Model = client.Model } req.Stream = true requestBody := client.hooks.buildRequestBodyFromRequest(req) jsonData, err := client.hooks.marshalRequestBody(requestBody) if err != nil { return "", err } url := client.hooks.buildUrl() httpReq, err := client.hooks.buildRequest(url, jsonData) if err != nil { return "", err } // Idle-timeout watchdog: cancel the request if no SSE line arrives for 30 seconds. // This breaks the scanner out of an indefinitely blocking Read on a hung connection. const idleTimeout = 60 * time.Second ctx, cancel := context.WithCancel(context.Background()) defer cancel() resetCh := make(chan struct{}, 1) go func() { t := time.NewTimer(idleTimeout) defer t.Stop() for { select { case <-ctx.Done(): return case <-t.C: cancel() // idle timeout: kill the connection return case <-resetCh: // received a line — reset the idle timer if !t.Stop() { select { case <-t.C: default: } } t.Reset(idleTimeout) } } }() httpReq = httpReq.WithContext(ctx) resp, err := client.httpClient.Do(httpReq) if err != nil { return "", fmt.Errorf("streaming request failed: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) return "", fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) } var accumulated strings.Builder scanner := bufio.NewScanner(resp.Body) for scanner.Scan() { // Ping the watchdog: we received a line, reset the idle timer. 
select { case resetCh <- struct{}{}: default: } line := scanner.Text() if !strings.HasPrefix(line, "data: ") { continue } data := strings.TrimPrefix(line, "data: ") if data == "[DONE]" { break } // Parse the SSE JSON chunk var chunk struct { Choices []struct { Delta struct { Content string `json:"content"` } `json:"delta"` FinishReason *string `json:"finish_reason"` } `json:"choices"` } if err := json.Unmarshal([]byte(data), &chunk); err != nil { continue // skip malformed chunks } if len(chunk.Choices) == 0 { continue } delta := chunk.Choices[0].Delta.Content if delta == "" { continue } accumulated.WriteString(delta) if onChunk != nil { onChunk(accumulated.String()) } } if err := scanner.Err(); err != nil { return accumulated.String(), fmt.Errorf("stream interrupted: %w", err) } return accumulated.String(), nil }