Propagate MCP request context to HTTP calls

This commit is contained in:
lky-spec
2026-04-28 12:22:45 +08:00
parent 30a703a827
commit b536265f93
4 changed files with 162 additions and 16 deletions

View File

@@ -197,7 +197,9 @@ func (client *Client) CallWithMessages(systemPrompt, userPrompt string) (string,
if attempt < maxRetries {
waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt)
client.Log.Infof("⏳ Waiting %v before retry...", waitTime)
time.Sleep(waitTime)
if err := sleepWithContext(context.Background(), waitTime); err != nil {
return "", err
}
}
}
@@ -332,6 +334,38 @@ func (client *Client) BuildRequest(url string, jsonData []byte) (*http.Request,
return req, nil
}
func contextFromRequest(req *Request) context.Context {
if req != nil && req.Ctx != nil {
return req.Ctx
}
return context.Background()
}
func (client *Client) buildHTTPRequestWithContext(ctx context.Context, url string, jsonData []byte) (*http.Request, error) {
if ctx == nil {
ctx = context.Background()
}
httpReq, err := client.Hooks.BuildRequest(url, jsonData)
if err != nil {
return nil, err
}
return httpReq.WithContext(ctx), nil
}
func sleepWithContext(ctx context.Context, d time.Duration) error {
if ctx == nil {
ctx = context.Background()
}
timer := time.NewTimer(d)
defer timer.Stop()
select {
case <-timer.C:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// Call single AI API call (fixed flow, cannot be overridden)
func (client *Client) Call(systemPrompt, userPrompt string) (string, error) {
// Print current AI configuration
@@ -450,7 +484,9 @@ func (client *Client) CallWithRequest(req *Request) (string, error) {
if attempt < maxRetries {
waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt)
client.Log.Infof("⏳ Waiting %v before retry...", waitTime)
time.Sleep(waitTime)
if err := sleepWithContext(contextFromRequest(req), waitTime); err != nil {
return "", err
}
}
}
@@ -482,7 +518,9 @@ func (client *Client) CallWithRequestFull(req *Request) (*LLMResponse, error) {
}
if attempt < maxRetries {
waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt)
time.Sleep(waitTime)
if err := sleepWithContext(contextFromRequest(req), waitTime); err != nil {
return nil, err
}
}
}
return nil, fmt.Errorf("still failed after %d retries: %w", maxRetries, lastErr)
@@ -499,7 +537,7 @@ func (client *Client) callWithRequestFull(req *Request) (*LLMResponse, error) {
}
url := client.Hooks.BuildUrl()
httpReq, err := client.Hooks.BuildRequest(url, jsonData)
httpReq, err := client.buildHTTPRequestWithContext(contextFromRequest(req), url, jsonData)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
@@ -537,7 +575,7 @@ func (client *Client) callWithRequest(req *Request) (string, error) {
url := client.Hooks.BuildUrl()
client.Log.Infof("📡 [MCP %s] Request URL: %s", client.String(), url)
httpReq, err := client.Hooks.BuildRequest(url, jsonData)
httpReq, err := client.buildHTTPRequestWithContext(contextFromRequest(req), url, jsonData)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
@@ -679,7 +717,7 @@ func (client *Client) CallWithRequestStream(req *Request, onChunk func(string))
}
url := client.Hooks.BuildUrl()
httpReq, err := client.Hooks.BuildRequest(url, jsonData)
httpReq, err := client.buildHTTPRequestWithContext(contextFromRequest(req), url, jsonData)
if err != nil {
return "", err
}
@@ -687,7 +725,7 @@ func (client *Client) CallWithRequestStream(req *Request, onChunk func(string))
// Idle-timeout watchdog: cancel the request if no SSE line arrives for 60 seconds.
// This breaks the scanner out of an indefinitely blocking Read on a hung connection.
const idleTimeout = 60 * time.Second
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(contextFromRequest(req))
defer cancel()
resetCh := make(chan struct{}, 1)
go func() {

View File

@@ -35,13 +35,34 @@ const (
X402Timeout = 5 * time.Minute
)
func x402ContextFromRequest(req *mcp.Request) context.Context {
if req != nil && req.Ctx != nil {
return req.Ctx
}
return context.Background()
}
func x402Sleep(ctx context.Context, d time.Duration) error {
if ctx == nil {
ctx = context.Background()
}
timer := time.NewTimer(d)
defer timer.Stop()
select {
case <-timer.C:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// ── Shared x402 types ────────────────────────────────────────────────────────
// X402v2PaymentRequired is the structure of the Payment-Required header (x402 v2).
type X402v2PaymentRequired struct {
X402Version int `json:"x402Version"`
X402Version int `json:"x402Version"`
Accepts []X402AcceptOption `json:"accepts"`
Resource *X402Resource `json:"resource"`
Resource *X402Resource `json:"resource"`
}
// X402AcceptOption is a payment option from the x402 v2 header.
@@ -114,16 +135,21 @@ func SignBasePaymentHeader(privateKey *ecdsa.PrivateKey, paymentHeaderB64 string
// DoX402Request executes an HTTP request and handles the x402 v2 payment flow.
func DoX402Request(
ctx context.Context,
httpClient *http.Client,
buildReqFn func() (*http.Request, error),
signFn X402SignFunc,
providerTag string,
logger mcp.Logger,
) ([]byte, error) {
if ctx == nil {
ctx = context.Background()
}
req, err := buildReqFn()
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req = req.WithContext(ctx)
resp, err := httpClient.Do(req)
if err != nil {
@@ -157,6 +183,7 @@ func DoX402Request(
if err != nil {
return nil, fmt.Errorf("failed to build retry request: %w", err)
}
req2 = req2.WithContext(ctx)
req2.Header.Set("X-Payment", paymentSig)
req2.Header.Set("Payment-Signature", paymentSig)
@@ -166,7 +193,9 @@ func DoX402Request(
wait := X402RetryBaseWait * time.Duration(attempt)
logger.Warnf("⚠️ [%s] Payment request failed: %v, retrying in %v (%d/%d)...",
providerTag, err, wait, attempt+1, X402MaxPaymentRetries)
time.Sleep(wait)
if err := x402Sleep(ctx, wait); err != nil {
return nil, err
}
continue
}
return nil, fmt.Errorf("failed to send payment retry: %w", err)
@@ -221,7 +250,9 @@ func DoX402Request(
providerTag, resp2.StatusCode, wait, attempt+1, X402MaxPaymentRetries)
}
time.Sleep(wait)
if err := x402Sleep(ctx, wait); err != nil {
return nil, err
}
continue
}
@@ -256,11 +287,15 @@ func DoX402RequestStream(
providerTag string,
logger mcp.Logger,
) (*http.Response, error) {
// Initial request — use background context (no idle timeout yet).
if ctx == nil {
ctx = context.Background()
}
// Initial request also inherits ctx so stage timeouts cancel the 402 handshake.
req, err := buildReqFn()
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req = req.WithContext(ctx)
resp, err := httpClient.Do(req)
if err != nil {
@@ -314,7 +349,9 @@ func DoX402RequestStream(
wait := X402RetryBaseWait * time.Duration(attempt)
logger.Warnf("⚠️ [%s] Payment request failed: %v, retrying in %v (%d/%d)...",
providerTag, err, wait, attempt+1, X402MaxPaymentRetries)
time.Sleep(wait)
if err := x402Sleep(ctx, wait); err != nil {
return nil, err
}
continue
}
return nil, fmt.Errorf("failed to send payment retry: %w", err)
@@ -369,7 +406,9 @@ func DoX402RequestStream(
providerTag, resp2.StatusCode, wait, attempt+1, X402MaxPaymentRetries)
}
time.Sleep(wait)
if err := x402Sleep(ctx, wait); err != nil {
return nil, err
}
continue
}
@@ -500,7 +539,7 @@ func X402Call(c *mcp.Client, signFn X402SignFunc, tag string, systemPrompt, user
return "", err
}
body, err := DoX402Request(c.HTTPClient, func() (*http.Request, error) {
body, err := DoX402Request(context.Background(), c.HTTPClient, func() (*http.Request, error) {
return c.Hooks.BuildRequest(c.Hooks.BuildUrl(), jsonData)
}, signFn, tag, c.Log)
if err != nil {
@@ -526,7 +565,7 @@ func X402CallFull(c *mcp.Client, signFn X402SignFunc, tag string, req *mcp.Reque
return nil, err
}
body, err := DoX402Request(c.HTTPClient, func() (*http.Request, error) {
body, err := DoX402Request(x402ContextFromRequest(req), c.HTTPClient, func() (*http.Request, error) {
return c.Hooks.BuildRequest(c.Hooks.BuildUrl(), jsonData)
}, signFn, tag, c.Log)
if err != nil {

View File

@@ -1,8 +1,13 @@
package mcp
import (
"context"
"encoding/json"
"io"
"net/http"
"strings"
"testing"
"time"
)
// ============================================================
@@ -342,6 +347,69 @@ func TestClient_CallWithRequest_Success(t *testing.T) {
}
}
func TestClient_CallWithRequest_AttachesRequestContextToHTTP(t *testing.T) {
type contextKey string
const key contextKey = "stage"
ctx := context.WithValue(context.Background(), key, "planner")
mockHTTP := NewMockHTTPClient()
mockHTTP.ResponseFunc = func(req *http.Request) (*http.Response, error) {
if req.Context().Value(key) != "planner" {
t.Fatalf("expected HTTP request to inherit mcp.Request context")
}
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(`{"choices":[{"message":{"content":"ok"}}]}`)),
Header: make(http.Header),
}, nil
}
client := NewClient(
WithHTTPClient(mockHTTP.ToHTTPClient()),
WithLogger(NewMockLogger()),
WithAPIKey("sk-test-key"),
)
request := NewRequestBuilder().WithUserPrompt("Hello").MustBuild()
request.Ctx = ctx
result, err := client.CallWithRequest(request)
if err != nil {
t.Fatalf("should not error: %v", err)
}
if result != "ok" {
t.Fatalf("expected ok, got %q", result)
}
}
func TestClient_CallWithRequest_RetrySleepStopsWhenContextCancelled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
mockHTTP := NewMockHTTPClient()
mockHTTP.SetNetworkError(io.EOF)
client := NewClient(
WithHTTPClient(mockHTTP.ToHTTPClient()),
WithLogger(NewMockLogger()),
WithAPIKey("sk-test-key"),
WithMaxRetries(2),
WithRetryWaitBase(time.Hour),
)
request := NewRequestBuilder().WithUserPrompt("Hello").MustBuild()
request.Ctx = ctx
start := time.Now()
_, err := client.CallWithRequest(request)
if err == nil || !strings.Contains(err.Error(), "context canceled") {
t.Fatalf("expected context canceled during retry wait, got %v", err)
}
if elapsed := time.Since(start); elapsed > 500*time.Millisecond {
t.Fatalf("retry sleep did not respect context cancellation, elapsed=%v", elapsed)
}
if got := len(mockHTTP.GetRequests()); got != 1 {
t.Fatalf("expected no retry after context cancellation, got %d requests", got)
}
}
func TestClient_CallWithRequest_MultiRound(t *testing.T) {
mockHTTP := NewMockHTTPClient()
mockHTTP.SetSuccessResponse("Multi-round response")

View File

@@ -98,6 +98,7 @@ func (c *Claw402DataClient) DoRequest(endpoint string) ([]byte, error) {
signFn := payment.MakeClaw402SignFunc(c.privateKey)
body, err := payment.DoX402Request(
context.Background(),
c.httpClient,
buildReq,
signFn,