mirror of
https://github.com/NoFxAiOS/nofx.git
synced 2026-06-06 05:51:19 +08:00
Propagate MCP request context to HTTP calls
This commit is contained in:
@@ -197,7 +197,9 @@ func (client *Client) CallWithMessages(systemPrompt, userPrompt string) (string,
|
|||||||
if attempt < maxRetries {
|
if attempt < maxRetries {
|
||||||
waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt)
|
waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt)
|
||||||
client.Log.Infof("⏳ Waiting %v before retry...", waitTime)
|
client.Log.Infof("⏳ Waiting %v before retry...", waitTime)
|
||||||
time.Sleep(waitTime)
|
if err := sleepWithContext(context.Background(), waitTime); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -332,6 +334,38 @@ func (client *Client) BuildRequest(url string, jsonData []byte) (*http.Request,
|
|||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func contextFromRequest(req *Request) context.Context {
|
||||||
|
if req != nil && req.Ctx != nil {
|
||||||
|
return req.Ctx
|
||||||
|
}
|
||||||
|
return context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client *Client) buildHTTPRequestWithContext(ctx context.Context, url string, jsonData []byte) (*http.Request, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
httpReq, err := client.Hooks.BuildRequest(url, jsonData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return httpReq.WithContext(ctx), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sleepWithContext(ctx context.Context, d time.Duration) error {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
timer := time.NewTimer(d)
|
||||||
|
defer timer.Stop()
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Call single AI API call (fixed flow, cannot be overridden)
|
// Call single AI API call (fixed flow, cannot be overridden)
|
||||||
func (client *Client) Call(systemPrompt, userPrompt string) (string, error) {
|
func (client *Client) Call(systemPrompt, userPrompt string) (string, error) {
|
||||||
// Print current AI configuration
|
// Print current AI configuration
|
||||||
@@ -450,7 +484,9 @@ func (client *Client) CallWithRequest(req *Request) (string, error) {
|
|||||||
if attempt < maxRetries {
|
if attempt < maxRetries {
|
||||||
waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt)
|
waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt)
|
||||||
client.Log.Infof("⏳ Waiting %v before retry...", waitTime)
|
client.Log.Infof("⏳ Waiting %v before retry...", waitTime)
|
||||||
time.Sleep(waitTime)
|
if err := sleepWithContext(contextFromRequest(req), waitTime); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -482,7 +518,9 @@ func (client *Client) CallWithRequestFull(req *Request) (*LLMResponse, error) {
|
|||||||
}
|
}
|
||||||
if attempt < maxRetries {
|
if attempt < maxRetries {
|
||||||
waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt)
|
waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt)
|
||||||
time.Sleep(waitTime)
|
if err := sleepWithContext(contextFromRequest(req), waitTime); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("still failed after %d retries: %w", maxRetries, lastErr)
|
return nil, fmt.Errorf("still failed after %d retries: %w", maxRetries, lastErr)
|
||||||
@@ -499,7 +537,7 @@ func (client *Client) callWithRequestFull(req *Request) (*LLMResponse, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
url := client.Hooks.BuildUrl()
|
url := client.Hooks.BuildUrl()
|
||||||
httpReq, err := client.Hooks.BuildRequest(url, jsonData)
|
httpReq, err := client.buildHTTPRequestWithContext(contextFromRequest(req), url, jsonData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
}
|
}
|
||||||
@@ -537,7 +575,7 @@ func (client *Client) callWithRequest(req *Request) (string, error) {
|
|||||||
url := client.Hooks.BuildUrl()
|
url := client.Hooks.BuildUrl()
|
||||||
client.Log.Infof("📡 [MCP %s] Request URL: %s", client.String(), url)
|
client.Log.Infof("📡 [MCP %s] Request URL: %s", client.String(), url)
|
||||||
|
|
||||||
httpReq, err := client.Hooks.BuildRequest(url, jsonData)
|
httpReq, err := client.buildHTTPRequestWithContext(contextFromRequest(req), url, jsonData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to create request: %w", err)
|
return "", fmt.Errorf("failed to create request: %w", err)
|
||||||
}
|
}
|
||||||
@@ -679,7 +717,7 @@ func (client *Client) CallWithRequestStream(req *Request, onChunk func(string))
|
|||||||
}
|
}
|
||||||
|
|
||||||
url := client.Hooks.BuildUrl()
|
url := client.Hooks.BuildUrl()
|
||||||
httpReq, err := client.Hooks.BuildRequest(url, jsonData)
|
httpReq, err := client.buildHTTPRequestWithContext(contextFromRequest(req), url, jsonData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -687,7 +725,7 @@ func (client *Client) CallWithRequestStream(req *Request, onChunk func(string))
|
|||||||
// Idle-timeout watchdog: cancel the request if no SSE line arrives for 60 seconds.
|
// Idle-timeout watchdog: cancel the request if no SSE line arrives for 60 seconds.
|
||||||
// This breaks the scanner out of an indefinitely blocking Read on a hung connection.
|
// This breaks the scanner out of an indefinitely blocking Read on a hung connection.
|
||||||
const idleTimeout = 60 * time.Second
|
const idleTimeout = 60 * time.Second
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(contextFromRequest(req))
|
||||||
defer cancel()
|
defer cancel()
|
||||||
resetCh := make(chan struct{}, 1)
|
resetCh := make(chan struct{}, 1)
|
||||||
go func() {
|
go func() {
|
||||||
|
|||||||
@@ -35,13 +35,34 @@ const (
|
|||||||
X402Timeout = 5 * time.Minute
|
X402Timeout = 5 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func x402ContextFromRequest(req *mcp.Request) context.Context {
|
||||||
|
if req != nil && req.Ctx != nil {
|
||||||
|
return req.Ctx
|
||||||
|
}
|
||||||
|
return context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
func x402Sleep(ctx context.Context, d time.Duration) error {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
timer := time.NewTimer(d)
|
||||||
|
defer timer.Stop()
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ── Shared x402 types ────────────────────────────────────────────────────────
|
// ── Shared x402 types ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
// X402v2PaymentRequired is the structure of the Payment-Required header (x402 v2).
|
// X402v2PaymentRequired is the structure of the Payment-Required header (x402 v2).
|
||||||
type X402v2PaymentRequired struct {
|
type X402v2PaymentRequired struct {
|
||||||
X402Version int `json:"x402Version"`
|
X402Version int `json:"x402Version"`
|
||||||
Accepts []X402AcceptOption `json:"accepts"`
|
Accepts []X402AcceptOption `json:"accepts"`
|
||||||
Resource *X402Resource `json:"resource"`
|
Resource *X402Resource `json:"resource"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// X402AcceptOption is a payment option from the x402 v2 header.
|
// X402AcceptOption is a payment option from the x402 v2 header.
|
||||||
@@ -114,16 +135,21 @@ func SignBasePaymentHeader(privateKey *ecdsa.PrivateKey, paymentHeaderB64 string
|
|||||||
|
|
||||||
// DoX402Request executes an HTTP request and handles the x402 v2 payment flow.
|
// DoX402Request executes an HTTP request and handles the x402 v2 payment flow.
|
||||||
func DoX402Request(
|
func DoX402Request(
|
||||||
|
ctx context.Context,
|
||||||
httpClient *http.Client,
|
httpClient *http.Client,
|
||||||
buildReqFn func() (*http.Request, error),
|
buildReqFn func() (*http.Request, error),
|
||||||
signFn X402SignFunc,
|
signFn X402SignFunc,
|
||||||
providerTag string,
|
providerTag string,
|
||||||
logger mcp.Logger,
|
logger mcp.Logger,
|
||||||
) ([]byte, error) {
|
) ([]byte, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
req, err := buildReqFn()
|
req, err := buildReqFn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
}
|
}
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
resp, err := httpClient.Do(req)
|
resp, err := httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -157,6 +183,7 @@ func DoX402Request(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to build retry request: %w", err)
|
return nil, fmt.Errorf("failed to build retry request: %w", err)
|
||||||
}
|
}
|
||||||
|
req2 = req2.WithContext(ctx)
|
||||||
req2.Header.Set("X-Payment", paymentSig)
|
req2.Header.Set("X-Payment", paymentSig)
|
||||||
req2.Header.Set("Payment-Signature", paymentSig)
|
req2.Header.Set("Payment-Signature", paymentSig)
|
||||||
|
|
||||||
@@ -166,7 +193,9 @@ func DoX402Request(
|
|||||||
wait := X402RetryBaseWait * time.Duration(attempt)
|
wait := X402RetryBaseWait * time.Duration(attempt)
|
||||||
logger.Warnf("⚠️ [%s] Payment request failed: %v, retrying in %v (%d/%d)...",
|
logger.Warnf("⚠️ [%s] Payment request failed: %v, retrying in %v (%d/%d)...",
|
||||||
providerTag, err, wait, attempt+1, X402MaxPaymentRetries)
|
providerTag, err, wait, attempt+1, X402MaxPaymentRetries)
|
||||||
time.Sleep(wait)
|
if err := x402Sleep(ctx, wait); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("failed to send payment retry: %w", err)
|
return nil, fmt.Errorf("failed to send payment retry: %w", err)
|
||||||
@@ -221,7 +250,9 @@ func DoX402Request(
|
|||||||
providerTag, resp2.StatusCode, wait, attempt+1, X402MaxPaymentRetries)
|
providerTag, resp2.StatusCode, wait, attempt+1, X402MaxPaymentRetries)
|
||||||
}
|
}
|
||||||
|
|
||||||
time.Sleep(wait)
|
if err := x402Sleep(ctx, wait); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -256,11 +287,15 @@ func DoX402RequestStream(
|
|||||||
providerTag string,
|
providerTag string,
|
||||||
logger mcp.Logger,
|
logger mcp.Logger,
|
||||||
) (*http.Response, error) {
|
) (*http.Response, error) {
|
||||||
// Initial request — use background context (no idle timeout yet).
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
// Initial request also inherits ctx so stage timeouts cancel the 402 handshake.
|
||||||
req, err := buildReqFn()
|
req, err := buildReqFn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
}
|
}
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
resp, err := httpClient.Do(req)
|
resp, err := httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -314,7 +349,9 @@ func DoX402RequestStream(
|
|||||||
wait := X402RetryBaseWait * time.Duration(attempt)
|
wait := X402RetryBaseWait * time.Duration(attempt)
|
||||||
logger.Warnf("⚠️ [%s] Payment request failed: %v, retrying in %v (%d/%d)...",
|
logger.Warnf("⚠️ [%s] Payment request failed: %v, retrying in %v (%d/%d)...",
|
||||||
providerTag, err, wait, attempt+1, X402MaxPaymentRetries)
|
providerTag, err, wait, attempt+1, X402MaxPaymentRetries)
|
||||||
time.Sleep(wait)
|
if err := x402Sleep(ctx, wait); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("failed to send payment retry: %w", err)
|
return nil, fmt.Errorf("failed to send payment retry: %w", err)
|
||||||
@@ -369,7 +406,9 @@ func DoX402RequestStream(
|
|||||||
providerTag, resp2.StatusCode, wait, attempt+1, X402MaxPaymentRetries)
|
providerTag, resp2.StatusCode, wait, attempt+1, X402MaxPaymentRetries)
|
||||||
}
|
}
|
||||||
|
|
||||||
time.Sleep(wait)
|
if err := x402Sleep(ctx, wait); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -500,7 +539,7 @@ func X402Call(c *mcp.Client, signFn X402SignFunc, tag string, systemPrompt, user
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := DoX402Request(c.HTTPClient, func() (*http.Request, error) {
|
body, err := DoX402Request(context.Background(), c.HTTPClient, func() (*http.Request, error) {
|
||||||
return c.Hooks.BuildRequest(c.Hooks.BuildUrl(), jsonData)
|
return c.Hooks.BuildRequest(c.Hooks.BuildUrl(), jsonData)
|
||||||
}, signFn, tag, c.Log)
|
}, signFn, tag, c.Log)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -526,7 +565,7 @@ func X402CallFull(c *mcp.Client, signFn X402SignFunc, tag string, req *mcp.Reque
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := DoX402Request(c.HTTPClient, func() (*http.Request, error) {
|
body, err := DoX402Request(x402ContextFromRequest(req), c.HTTPClient, func() (*http.Request, error) {
|
||||||
return c.Hooks.BuildRequest(c.Hooks.BuildUrl(), jsonData)
|
return c.Hooks.BuildRequest(c.Hooks.BuildUrl(), jsonData)
|
||||||
}, signFn, tag, c.Log)
|
}, signFn, tag, c.Log)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,8 +1,13 @@
|
|||||||
package mcp
|
package mcp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ============================================================
|
// ============================================================
|
||||||
@@ -342,6 +347,69 @@ func TestClient_CallWithRequest_Success(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClient_CallWithRequest_AttachesRequestContextToHTTP(t *testing.T) {
|
||||||
|
type contextKey string
|
||||||
|
const key contextKey = "stage"
|
||||||
|
ctx := context.WithValue(context.Background(), key, "planner")
|
||||||
|
|
||||||
|
mockHTTP := NewMockHTTPClient()
|
||||||
|
mockHTTP.ResponseFunc = func(req *http.Request) (*http.Response, error) {
|
||||||
|
if req.Context().Value(key) != "planner" {
|
||||||
|
t.Fatalf("expected HTTP request to inherit mcp.Request context")
|
||||||
|
}
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"choices":[{"message":{"content":"ok"}}]}`)),
|
||||||
|
Header: make(http.Header),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
client := NewClient(
|
||||||
|
WithHTTPClient(mockHTTP.ToHTTPClient()),
|
||||||
|
WithLogger(NewMockLogger()),
|
||||||
|
WithAPIKey("sk-test-key"),
|
||||||
|
)
|
||||||
|
request := NewRequestBuilder().WithUserPrompt("Hello").MustBuild()
|
||||||
|
request.Ctx = ctx
|
||||||
|
|
||||||
|
result, err := client.CallWithRequest(request)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("should not error: %v", err)
|
||||||
|
}
|
||||||
|
if result != "ok" {
|
||||||
|
t.Fatalf("expected ok, got %q", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_CallWithRequest_RetrySleepStopsWhenContextCancelled(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
mockHTTP := NewMockHTTPClient()
|
||||||
|
mockHTTP.SetNetworkError(io.EOF)
|
||||||
|
client := NewClient(
|
||||||
|
WithHTTPClient(mockHTTP.ToHTTPClient()),
|
||||||
|
WithLogger(NewMockLogger()),
|
||||||
|
WithAPIKey("sk-test-key"),
|
||||||
|
WithMaxRetries(2),
|
||||||
|
WithRetryWaitBase(time.Hour),
|
||||||
|
)
|
||||||
|
request := NewRequestBuilder().WithUserPrompt("Hello").MustBuild()
|
||||||
|
request.Ctx = ctx
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
_, err := client.CallWithRequest(request)
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "context canceled") {
|
||||||
|
t.Fatalf("expected context canceled during retry wait, got %v", err)
|
||||||
|
}
|
||||||
|
if elapsed := time.Since(start); elapsed > 500*time.Millisecond {
|
||||||
|
t.Fatalf("retry sleep did not respect context cancellation, elapsed=%v", elapsed)
|
||||||
|
}
|
||||||
|
if got := len(mockHTTP.GetRequests()); got != 1 {
|
||||||
|
t.Fatalf("expected no retry after context cancellation, got %d requests", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestClient_CallWithRequest_MultiRound(t *testing.T) {
|
func TestClient_CallWithRequest_MultiRound(t *testing.T) {
|
||||||
mockHTTP := NewMockHTTPClient()
|
mockHTTP := NewMockHTTPClient()
|
||||||
mockHTTP.SetSuccessResponse("Multi-round response")
|
mockHTTP.SetSuccessResponse("Multi-round response")
|
||||||
|
|||||||
@@ -98,6 +98,7 @@ func (c *Claw402DataClient) DoRequest(endpoint string) ([]byte, error) {
|
|||||||
signFn := payment.MakeClaw402SignFunc(c.privateKey)
|
signFn := payment.MakeClaw402SignFunc(c.privateKey)
|
||||||
|
|
||||||
body, err := payment.DoX402Request(
|
body, err := payment.DoX402Request(
|
||||||
|
context.Background(),
|
||||||
c.httpClient,
|
c.httpClient,
|
||||||
buildReq,
|
buildReq,
|
||||||
signFn,
|
signFn,
|
||||||
|
|||||||
Reference in New Issue
Block a user