mirror of
https://github.com/NoFxAiOS/nofx.git
synced 2026-06-06 05:51:19 +08:00
Propagate MCP request context to HTTP calls
This commit is contained in:
@@ -197,7 +197,9 @@ func (client *Client) CallWithMessages(systemPrompt, userPrompt string) (string,
|
||||
if attempt < maxRetries {
|
||||
waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt)
|
||||
client.Log.Infof("⏳ Waiting %v before retry...", waitTime)
|
||||
time.Sleep(waitTime)
|
||||
if err := sleepWithContext(context.Background(), waitTime); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -332,6 +334,38 @@ func (client *Client) BuildRequest(url string, jsonData []byte) (*http.Request,
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func contextFromRequest(req *Request) context.Context {
|
||||
if req != nil && req.Ctx != nil {
|
||||
return req.Ctx
|
||||
}
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
func (client *Client) buildHTTPRequestWithContext(ctx context.Context, url string, jsonData []byte) (*http.Request, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
httpReq, err := client.Hooks.BuildRequest(url, jsonData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return httpReq.WithContext(ctx), nil
|
||||
}
|
||||
|
||||
func sleepWithContext(ctx context.Context, d time.Duration) error {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
timer := time.NewTimer(d)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-timer.C:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Call single AI API call (fixed flow, cannot be overridden)
|
||||
func (client *Client) Call(systemPrompt, userPrompt string) (string, error) {
|
||||
// Print current AI configuration
|
||||
@@ -450,7 +484,9 @@ func (client *Client) CallWithRequest(req *Request) (string, error) {
|
||||
if attempt < maxRetries {
|
||||
waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt)
|
||||
client.Log.Infof("⏳ Waiting %v before retry...", waitTime)
|
||||
time.Sleep(waitTime)
|
||||
if err := sleepWithContext(contextFromRequest(req), waitTime); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -482,7 +518,9 @@ func (client *Client) CallWithRequestFull(req *Request) (*LLMResponse, error) {
|
||||
}
|
||||
if attempt < maxRetries {
|
||||
waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt)
|
||||
time.Sleep(waitTime)
|
||||
if err := sleepWithContext(contextFromRequest(req), waitTime); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("still failed after %d retries: %w", maxRetries, lastErr)
|
||||
@@ -499,7 +537,7 @@ func (client *Client) callWithRequestFull(req *Request) (*LLMResponse, error) {
|
||||
}
|
||||
|
||||
url := client.Hooks.BuildUrl()
|
||||
httpReq, err := client.Hooks.BuildRequest(url, jsonData)
|
||||
httpReq, err := client.buildHTTPRequestWithContext(contextFromRequest(req), url, jsonData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
@@ -537,7 +575,7 @@ func (client *Client) callWithRequest(req *Request) (string, error) {
|
||||
url := client.Hooks.BuildUrl()
|
||||
client.Log.Infof("📡 [MCP %s] Request URL: %s", client.String(), url)
|
||||
|
||||
httpReq, err := client.Hooks.BuildRequest(url, jsonData)
|
||||
httpReq, err := client.buildHTTPRequestWithContext(contextFromRequest(req), url, jsonData)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
@@ -679,7 +717,7 @@ func (client *Client) CallWithRequestStream(req *Request, onChunk func(string))
|
||||
}
|
||||
|
||||
url := client.Hooks.BuildUrl()
|
||||
httpReq, err := client.Hooks.BuildRequest(url, jsonData)
|
||||
httpReq, err := client.buildHTTPRequestWithContext(contextFromRequest(req), url, jsonData)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -687,7 +725,7 @@ func (client *Client) CallWithRequestStream(req *Request, onChunk func(string))
|
||||
// Idle-timeout watchdog: cancel the request if no SSE line arrives for 60 seconds.
|
||||
// This breaks the scanner out of an indefinitely blocking Read on a hung connection.
|
||||
const idleTimeout = 60 * time.Second
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(contextFromRequest(req))
|
||||
defer cancel()
|
||||
resetCh := make(chan struct{}, 1)
|
||||
go func() {
|
||||
|
||||
@@ -35,13 +35,34 @@ const (
|
||||
X402Timeout = 5 * time.Minute
|
||||
)
|
||||
|
||||
func x402ContextFromRequest(req *mcp.Request) context.Context {
|
||||
if req != nil && req.Ctx != nil {
|
||||
return req.Ctx
|
||||
}
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
func x402Sleep(ctx context.Context, d time.Duration) error {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
timer := time.NewTimer(d)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-timer.C:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// ── Shared x402 types ────────────────────────────────────────────────────────
|
||||
|
||||
// X402v2PaymentRequired is the structure of the Payment-Required header (x402 v2).
|
||||
type X402v2PaymentRequired struct {
|
||||
X402Version int `json:"x402Version"`
|
||||
X402Version int `json:"x402Version"`
|
||||
Accepts []X402AcceptOption `json:"accepts"`
|
||||
Resource *X402Resource `json:"resource"`
|
||||
Resource *X402Resource `json:"resource"`
|
||||
}
|
||||
|
||||
// X402AcceptOption is a payment option from the x402 v2 header.
|
||||
@@ -114,16 +135,21 @@ func SignBasePaymentHeader(privateKey *ecdsa.PrivateKey, paymentHeaderB64 string
|
||||
|
||||
// DoX402Request executes an HTTP request and handles the x402 v2 payment flow.
|
||||
func DoX402Request(
|
||||
ctx context.Context,
|
||||
httpClient *http.Client,
|
||||
buildReqFn func() (*http.Request, error),
|
||||
signFn X402SignFunc,
|
||||
providerTag string,
|
||||
logger mcp.Logger,
|
||||
) ([]byte, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
req, err := buildReqFn()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -157,6 +183,7 @@ func DoX402Request(
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build retry request: %w", err)
|
||||
}
|
||||
req2 = req2.WithContext(ctx)
|
||||
req2.Header.Set("X-Payment", paymentSig)
|
||||
req2.Header.Set("Payment-Signature", paymentSig)
|
||||
|
||||
@@ -166,7 +193,9 @@ func DoX402Request(
|
||||
wait := X402RetryBaseWait * time.Duration(attempt)
|
||||
logger.Warnf("⚠️ [%s] Payment request failed: %v, retrying in %v (%d/%d)...",
|
||||
providerTag, err, wait, attempt+1, X402MaxPaymentRetries)
|
||||
time.Sleep(wait)
|
||||
if err := x402Sleep(ctx, wait); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
return nil, fmt.Errorf("failed to send payment retry: %w", err)
|
||||
@@ -221,7 +250,9 @@ func DoX402Request(
|
||||
providerTag, resp2.StatusCode, wait, attempt+1, X402MaxPaymentRetries)
|
||||
}
|
||||
|
||||
time.Sleep(wait)
|
||||
if err := x402Sleep(ctx, wait); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -256,11 +287,15 @@ func DoX402RequestStream(
|
||||
providerTag string,
|
||||
logger mcp.Logger,
|
||||
) (*http.Response, error) {
|
||||
// Initial request — use background context (no idle timeout yet).
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
// Initial request also inherits ctx so stage timeouts cancel the 402 handshake.
|
||||
req, err := buildReqFn()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -314,7 +349,9 @@ func DoX402RequestStream(
|
||||
wait := X402RetryBaseWait * time.Duration(attempt)
|
||||
logger.Warnf("⚠️ [%s] Payment request failed: %v, retrying in %v (%d/%d)...",
|
||||
providerTag, err, wait, attempt+1, X402MaxPaymentRetries)
|
||||
time.Sleep(wait)
|
||||
if err := x402Sleep(ctx, wait); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
return nil, fmt.Errorf("failed to send payment retry: %w", err)
|
||||
@@ -369,7 +406,9 @@ func DoX402RequestStream(
|
||||
providerTag, resp2.StatusCode, wait, attempt+1, X402MaxPaymentRetries)
|
||||
}
|
||||
|
||||
time.Sleep(wait)
|
||||
if err := x402Sleep(ctx, wait); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -500,7 +539,7 @@ func X402Call(c *mcp.Client, signFn X402SignFunc, tag string, systemPrompt, user
|
||||
return "", err
|
||||
}
|
||||
|
||||
body, err := DoX402Request(c.HTTPClient, func() (*http.Request, error) {
|
||||
body, err := DoX402Request(context.Background(), c.HTTPClient, func() (*http.Request, error) {
|
||||
return c.Hooks.BuildRequest(c.Hooks.BuildUrl(), jsonData)
|
||||
}, signFn, tag, c.Log)
|
||||
if err != nil {
|
||||
@@ -526,7 +565,7 @@ func X402CallFull(c *mcp.Client, signFn X402SignFunc, tag string, req *mcp.Reque
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err := DoX402Request(c.HTTPClient, func() (*http.Request, error) {
|
||||
body, err := DoX402Request(x402ContextFromRequest(req), c.HTTPClient, func() (*http.Request, error) {
|
||||
return c.Hooks.BuildRequest(c.Hooks.BuildUrl(), jsonData)
|
||||
}, signFn, tag, c.Log)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ============================================================
|
||||
@@ -342,6 +347,69 @@ func TestClient_CallWithRequest_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_CallWithRequest_AttachesRequestContextToHTTP(t *testing.T) {
|
||||
type contextKey string
|
||||
const key contextKey = "stage"
|
||||
ctx := context.WithValue(context.Background(), key, "planner")
|
||||
|
||||
mockHTTP := NewMockHTTPClient()
|
||||
mockHTTP.ResponseFunc = func(req *http.Request) (*http.Response, error) {
|
||||
if req.Context().Value(key) != "planner" {
|
||||
t.Fatalf("expected HTTP request to inherit mcp.Request context")
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(`{"choices":[{"message":{"content":"ok"}}]}`)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}
|
||||
|
||||
client := NewClient(
|
||||
WithHTTPClient(mockHTTP.ToHTTPClient()),
|
||||
WithLogger(NewMockLogger()),
|
||||
WithAPIKey("sk-test-key"),
|
||||
)
|
||||
request := NewRequestBuilder().WithUserPrompt("Hello").MustBuild()
|
||||
request.Ctx = ctx
|
||||
|
||||
result, err := client.CallWithRequest(request)
|
||||
if err != nil {
|
||||
t.Fatalf("should not error: %v", err)
|
||||
}
|
||||
if result != "ok" {
|
||||
t.Fatalf("expected ok, got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_CallWithRequest_RetrySleepStopsWhenContextCancelled(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
mockHTTP := NewMockHTTPClient()
|
||||
mockHTTP.SetNetworkError(io.EOF)
|
||||
client := NewClient(
|
||||
WithHTTPClient(mockHTTP.ToHTTPClient()),
|
||||
WithLogger(NewMockLogger()),
|
||||
WithAPIKey("sk-test-key"),
|
||||
WithMaxRetries(2),
|
||||
WithRetryWaitBase(time.Hour),
|
||||
)
|
||||
request := NewRequestBuilder().WithUserPrompt("Hello").MustBuild()
|
||||
request.Ctx = ctx
|
||||
|
||||
start := time.Now()
|
||||
_, err := client.CallWithRequest(request)
|
||||
if err == nil || !strings.Contains(err.Error(), "context canceled") {
|
||||
t.Fatalf("expected context canceled during retry wait, got %v", err)
|
||||
}
|
||||
if elapsed := time.Since(start); elapsed > 500*time.Millisecond {
|
||||
t.Fatalf("retry sleep did not respect context cancellation, elapsed=%v", elapsed)
|
||||
}
|
||||
if got := len(mockHTTP.GetRequests()); got != 1 {
|
||||
t.Fatalf("expected no retry after context cancellation, got %d requests", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_CallWithRequest_MultiRound(t *testing.T) {
|
||||
mockHTTP := NewMockHTTPClient()
|
||||
mockHTTP.SetSuccessResponse("Multi-round response")
|
||||
|
||||
Reference in New Issue
Block a user