From b536265f93e377cb5cb16ffcf0a7d16505f359bf Mon Sep 17 00:00:00 2001 From: lky-spec Date: Tue, 28 Apr 2026 12:22:45 +0800 Subject: [PATCH] Propagate MCP request context to HTTP calls --- mcp/client.go | 52 ++++++++++++++++++++++++---- mcp/payment/x402.go | 57 ++++++++++++++++++++++++++----- mcp/request_builder_test.go | 68 +++++++++++++++++++++++++++++++++++++ provider/nofxos/claw402.go | 1 + 4 files changed, 162 insertions(+), 16 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index 6916de3c..9fee62a0 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -197,7 +197,9 @@ func (client *Client) CallWithMessages(systemPrompt, userPrompt string) (string, if attempt < maxRetries { waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt) client.Log.Infof("⏳ Waiting %v before retry...", waitTime) - time.Sleep(waitTime) + if err := sleepWithContext(context.Background(), waitTime); err != nil { + return "", err + } } } @@ -332,6 +334,38 @@ func (client *Client) BuildRequest(url string, jsonData []byte) (*http.Request, return req, nil } +func contextFromRequest(req *Request) context.Context { + if req != nil && req.Ctx != nil { + return req.Ctx + } + return context.Background() +} + +func (client *Client) buildHTTPRequestWithContext(ctx context.Context, url string, jsonData []byte) (*http.Request, error) { + if ctx == nil { + ctx = context.Background() + } + httpReq, err := client.Hooks.BuildRequest(url, jsonData) + if err != nil { + return nil, err + } + return httpReq.WithContext(ctx), nil +} + +func sleepWithContext(ctx context.Context, d time.Duration) error { + if ctx == nil { + ctx = context.Background() + } + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-timer.C: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + // Call single AI API call (fixed flow, cannot be overridden) func (client *Client) Call(systemPrompt, userPrompt string) (string, error) { // Print current AI configuration @@ -450,7 +484,9 @@ func (client *Client) CallWithRequest(req *Request) (string, error) { if attempt < maxRetries { waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt) client.Log.Infof("⏳ Waiting %v before retry...", waitTime) - time.Sleep(waitTime) + if err := sleepWithContext(contextFromRequest(req), waitTime); err != nil { + return "", err + } } } @@ -482,7 +518,9 @@ func (client *Client) CallWithRequestFull(req *Request) (*LLMResponse, error) { } if attempt < maxRetries { waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt) - time.Sleep(waitTime) + if err := sleepWithContext(contextFromRequest(req), waitTime); err != nil { + return nil, err + } } } return nil, fmt.Errorf("still failed after %d retries: %w", maxRetries, lastErr) @@ -499,7 +537,7 @@ func (client *Client) callWithRequestFull(req *Request) (*LLMResponse, error) { } url := client.Hooks.BuildUrl() - httpReq, err := client.Hooks.BuildRequest(url, jsonData) + httpReq, err := client.buildHTTPRequestWithContext(contextFromRequest(req), url, jsonData) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } @@ -537,7 +575,7 @@ func (client *Client) callWithRequest(req *Request) (string, error) { url := client.Hooks.BuildUrl() client.Log.Infof("📡 [MCP %s] Request URL: %s", client.String(), url) - httpReq, err := client.Hooks.BuildRequest(url, jsonData) + httpReq, err := client.buildHTTPRequestWithContext(contextFromRequest(req), url, jsonData) if err != nil { return "", fmt.Errorf("failed to create request: %w", err) } @@ -679,7 +717,7 @@ func (client *Client) CallWithRequestStream(req *Request, onChunk func(string)) } url := client.Hooks.BuildUrl() - httpReq, err := client.Hooks.BuildRequest(url, jsonData) + httpReq, err := client.buildHTTPRequestWithContext(contextFromRequest(req), url, jsonData) if err != nil { return "", err } @@ -687,7 +725,7 @@ func (client *Client) CallWithRequestStream(req *Request, onChunk func(string)) // Idle-timeout watchdog: cancel the request if no SSE line arrives for 60 seconds. // This breaks the scanner out of an indefinitely blocking Read on a hung connection. const idleTimeout = 60 * time.Second - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(contextFromRequest(req)) defer cancel() resetCh := make(chan struct{}, 1) go func() { diff --git a/mcp/payment/x402.go b/mcp/payment/x402.go index 577da51f..4f2e460f 100644 --- a/mcp/payment/x402.go +++ b/mcp/payment/x402.go @@ -35,13 +35,34 @@ const ( X402Timeout = 5 * time.Minute ) +func x402ContextFromRequest(req *mcp.Request) context.Context { + if req != nil && req.Ctx != nil { + return req.Ctx + } + return context.Background() +} + +func x402Sleep(ctx context.Context, d time.Duration) error { + if ctx == nil { + ctx = context.Background() + } + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-timer.C: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + // ── Shared x402 types ──────────────────────────────────────────────────────── // X402v2PaymentRequired is the structure of the Payment-Required header (x402 v2). type X402v2PaymentRequired struct { - X402Version int `json:"x402Version"` + X402Version int `json:"x402Version"` Accepts []X402AcceptOption `json:"accepts"` - Resource *X402Resource `json:"resource"` + Resource *X402Resource `json:"resource"` } // X402AcceptOption is a payment option from the x402 v2 header. @@ -114,16 +135,21 @@ func SignBasePaymentHeader(privateKey *ecdsa.PrivateKey, paymentHeaderB64 string // DoX402Request executes an HTTP request and handles the x402 v2 payment flow. func DoX402Request( + ctx context.Context, httpClient *http.Client, buildReqFn func() (*http.Request, error), signFn X402SignFunc, providerTag string, logger mcp.Logger, ) ([]byte, error) { + if ctx == nil { + ctx = context.Background() + } req, err := buildReqFn() if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } + req = req.WithContext(ctx) resp, err := httpClient.Do(req) if err != nil { @@ -157,6 +183,7 @@ func DoX402Request( if err != nil { return nil, fmt.Errorf("failed to build retry request: %w", err) } + req2 = req2.WithContext(ctx) req2.Header.Set("X-Payment", paymentSig) req2.Header.Set("Payment-Signature", paymentSig) @@ -166,7 +193,9 @@ func DoX402Request( wait := X402RetryBaseWait * time.Duration(attempt) logger.Warnf("⚠️ [%s] Payment request failed: %v, retrying in %v (%d/%d)...", providerTag, err, wait, attempt+1, X402MaxPaymentRetries) - time.Sleep(wait) + if err := x402Sleep(ctx, wait); err != nil { + return nil, err + } continue } return nil, fmt.Errorf("failed to send payment retry: %w", err) @@ -221,7 +250,9 @@ func DoX402Request( providerTag, resp2.StatusCode, wait, attempt+1, X402MaxPaymentRetries) } - time.Sleep(wait) + if err := x402Sleep(ctx, wait); err != nil { + return nil, err + } continue } @@ -256,11 +287,15 @@ func DoX402RequestStream( providerTag string, logger mcp.Logger, ) (*http.Response, error) { - // Initial request — use background context (no idle timeout yet). + if ctx == nil { + ctx = context.Background() + } + // Initial request also inherits ctx so stage timeouts cancel the 402 handshake. req, err := buildReqFn() if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } + req = req.WithContext(ctx) resp, err := httpClient.Do(req) if err != nil { @@ -314,7 +349,9 @@ func DoX402RequestStream( wait := X402RetryBaseWait * time.Duration(attempt) logger.Warnf("⚠️ [%s] Payment request failed: %v, retrying in %v (%d/%d)...", providerTag, err, wait, attempt+1, X402MaxPaymentRetries) - time.Sleep(wait) + if err := x402Sleep(ctx, wait); err != nil { + return nil, err + } continue } return nil, fmt.Errorf("failed to send payment retry: %w", err) @@ -369,7 +406,9 @@ func DoX402RequestStream( providerTag, resp2.StatusCode, wait, attempt+1, X402MaxPaymentRetries) } - time.Sleep(wait) + if err := x402Sleep(ctx, wait); err != nil { + return nil, err + } continue } @@ -500,7 +539,7 @@ func X402Call(c *mcp.Client, signFn X402SignFunc, tag string, systemPrompt, user return "", err } - body, err := DoX402Request(c.HTTPClient, func() (*http.Request, error) { + body, err := DoX402Request(context.Background(), c.HTTPClient, func() (*http.Request, error) { return c.Hooks.BuildRequest(c.Hooks.BuildUrl(), jsonData) }, signFn, tag, c.Log) if err != nil { @@ -526,7 +565,7 @@ func X402CallFull(c *mcp.Client, signFn X402SignFunc, tag string, req *mcp.Reque return nil, err } - body, err := DoX402Request(c.HTTPClient, func() (*http.Request, error) { + body, err := DoX402Request(x402ContextFromRequest(req), c.HTTPClient, func() (*http.Request, error) { return c.Hooks.BuildRequest(c.Hooks.BuildUrl(), jsonData) }, signFn, tag, c.Log) if err != nil { diff --git a/mcp/request_builder_test.go b/mcp/request_builder_test.go index 4ec10a9f..96ed36e6 100644 --- a/mcp/request_builder_test.go +++ b/mcp/request_builder_test.go @@ -1,8 +1,13 @@ package mcp import ( + "context" "encoding/json" + "io" + "net/http" + "strings" "testing" + "time" ) // ============================================================ @@ -342,6 +347,69 @@ func TestClient_CallWithRequest_Success(t *testing.T) { } } +func TestClient_CallWithRequest_AttachesRequestContextToHTTP(t *testing.T) { + type contextKey string + const key contextKey = "stage" + ctx := context.WithValue(context.Background(), key, "planner") + + mockHTTP := NewMockHTTPClient() + mockHTTP.ResponseFunc = func(req *http.Request) (*http.Response, error) { + if req.Context().Value(key) != "planner" { + t.Fatalf("expected HTTP request to inherit mcp.Request context") + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"choices":[{"message":{"content":"ok"}}]}`)), + Header: make(http.Header), + }, nil + } + + client := NewClient( + WithHTTPClient(mockHTTP.ToHTTPClient()), + WithLogger(NewMockLogger()), + WithAPIKey("sk-test-key"), + ) + request := NewRequestBuilder().WithUserPrompt("Hello").MustBuild() + request.Ctx = ctx + + result, err := client.CallWithRequest(request) + if err != nil { + t.Fatalf("should not error: %v", err) + } + if result != "ok" { + t.Fatalf("expected ok, got %q", result) + } +} + +func TestClient_CallWithRequest_RetrySleepStopsWhenContextCancelled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + mockHTTP := NewMockHTTPClient() + mockHTTP.SetNetworkError(io.EOF) + client := NewClient( + WithHTTPClient(mockHTTP.ToHTTPClient()), + WithLogger(NewMockLogger()), + WithAPIKey("sk-test-key"), + WithMaxRetries(2), + WithRetryWaitBase(time.Hour), + ) + request := NewRequestBuilder().WithUserPrompt("Hello").MustBuild() + request.Ctx = ctx + + start := time.Now() + _, err := client.CallWithRequest(request) + if err == nil || !strings.Contains(err.Error(), "context canceled") { + t.Fatalf("expected context canceled during retry wait, got %v", err) + } + if elapsed := time.Since(start); elapsed > 500*time.Millisecond { + t.Fatalf("retry sleep did not respect context cancellation, elapsed=%v", elapsed) + } + if got := len(mockHTTP.GetRequests()); got != 1 { + t.Fatalf("expected no retry after context cancellation, got %d requests", got) + } +} + func TestClient_CallWithRequest_MultiRound(t *testing.T) { mockHTTP := NewMockHTTPClient() mockHTTP.SetSuccessResponse("Multi-round response") diff --git a/provider/nofxos/claw402.go b/provider/nofxos/claw402.go index 0e4d6107..53f0a31a 100644 --- a/provider/nofxos/claw402.go +++ b/provider/nofxos/claw402.go @@ -98,6 +98,7 @@ func (c *Claw402DataClient) DoRequest(endpoint string) ([]byte, error) { signFn := payment.MakeClaw402SignFunc(c.privateKey) body, err := payment.DoX402Request( + context.Background(), c.httpClient, buildReq, signFn,