fix(memory): abort timed-out embedding requests (#82770)

* fix(memory): abort timed-out embedding requests

* test: stabilize gateway ci shards

* test: pin control ui origin fixture

* test: stabilize gateway ci fixtures

* test: isolate forged origin fixture

* test: decouple setup code from gateway net mocks

* test: repair run-node and config preaction CI

* test: fix run-node progress fixture typing

* test: remove unused pairing setup helper

* fix: stabilize embedding timeout errors
This commit is contained in:
Peter Steinberger
2026-05-17 02:04:17 +01:00
committed by GitHub
parent b77b3a7ade
commit a6225060f1
30 changed files with 295 additions and 71 deletions

View File

@@ -34,6 +34,7 @@ Docs: https://docs.openclaw.ai
- Telegram: let catch-all mention patterns match captionless group photos, so media-only group messages reach the agent when the group is intentionally configured to respond to all messages. Fixes #44833. (#82756) Thanks @IWhatsskill. - Telegram: let catch-all mention patterns match captionless group photos, so media-only group messages reach the agent when the group is intentionally configured to respond to all messages. Fixes #44833. (#82756) Thanks @IWhatsskill.
- Gateway/pairing: reject forged loopback Control UI origins from non-local proxy paths, and keep mobile pairing setup on Tailscale bind mode pointing users to Tailscale Serve/Funnel instead of cleartext tailnet WebSockets. - Gateway/pairing: reject forged loopback Control UI origins from non-local proxy paths, and keep mobile pairing setup on Tailscale bind mode pointing users to Tailscale Serve/Funnel instead of cleartext tailnet WebSockets.
- Telegram/Gateway: persist isolated polling offsets only after main-thread dispatch and preserve gateway caller scopes for Telegram message actions, fixing consumed-but-unrouted polling updates and recursive CLI send scope approvals. Fixes #82277. (#82705) Thanks @udaymanish6. - Telegram/Gateway: persist isolated polling offsets only after main-thread dispatch and preserve gateway caller scopes for Telegram message actions, fixing consumed-but-unrouted polling updates and recursive CLI send scope approvals. Fixes #82277. (#82705) Thanks @udaymanish6.
- Memory-core: abort timed-out embedding provider calls so remote embedding HTTP requests do not continue running after memory query or indexing timeouts. Fixes #82732. Thanks @adityarya24.
- Channels/stream previews: contain rejected background draft-stream flushes so preview send failures do not surface as fatal unhandled rejections. Fixes #82712. (#82713) Thanks @coygeek. - Channels/stream previews: contain rejected background draft-stream flushes so preview send failures do not surface as fatal unhandled rejections. Fixes #82712. (#82713) Thanks @coygeek.
- Codex/app-server: keep shared native app-server clients isolated per agent runtime key so starting one agent no longer closes another agent's active Codex turn. Fixes #82758. Thanks @PashaGanson. - Codex/app-server: keep shared native app-server clients isolated per agent runtime key so starting one agent no longer closes another agent's active Codex turn. Fixes #82758. Thanks @PashaGanson.
- Providers/OpenAI Codex: include base `gpt-5.5` and `gpt-5.4` reasoning metadata in the bundled Codex catalog so `/think xhigh` remains available for those models. Fixes #82744. - Providers/OpenAI Codex: include base `gpt-5.5` and `gpt-5.4` reasoning metadata in the bundled Codex catalog so `/think xhigh` remains available for those models. Fixes #82744.

View File

@@ -331,7 +331,7 @@ export async function createBedrockEmbeddingProvider(
family, family,
}); });
const invoke = async (body: string): Promise<string> => { const invoke = async (body: string, signal?: AbortSignal): Promise<string> => {
await refreshAwsSharedConfigCacheForBedrock(); await refreshAwsSharedConfigCacheForBedrock();
const sdk = new BedrockRuntimeClient({ region: client.region }); const sdk = new BedrockRuntimeClient({ region: client.region });
try { try {
@@ -342,6 +342,7 @@ export async function createBedrockEmbeddingProvider(
contentType: "application/json", contentType: "application/json",
accept: "application/json", accept: "application/json",
}), }),
signal ? { abortSignal: signal } : undefined,
); );
return new TextDecoder().decode(res.body); return new TextDecoder().decode(res.body);
} finally { } finally {
@@ -351,37 +352,46 @@ export async function createBedrockEmbeddingProvider(
const isCohere = family === "cohere-v3" || family === "cohere-v4"; const isCohere = family === "cohere-v3" || family === "cohere-v4";
const embedSingle = async (text: string): Promise<number[]> => { const embedSingle = async (text: string, signal?: AbortSignal): Promise<number[]> => {
const raw = await invoke(buildBody(family, text, client.dimensions)); const raw = await invoke(buildBody(family, text, client.dimensions), signal);
return sanitizeAndNormalizeEmbedding(parseSingle(family, raw)); return sanitizeAndNormalizeEmbedding(parseSingle(family, raw));
}; };
const embedCohere = async ( const embedCohere = async (
texts: string[], texts: string[],
inputType: "search_query" | "search_document", inputType: "search_query" | "search_document",
signal?: AbortSignal,
): Promise<number[][]> => { ): Promise<number[][]> => {
const raw = await invoke(buildCohereBody(family, texts, inputType, client.dimensions)); const raw = await invoke(buildCohereBody(family, texts, inputType, client.dimensions), signal);
return parseCohereBatch(family, raw).map((e) => sanitizeAndNormalizeEmbedding(e)); return parseCohereBatch(family, raw).map((e) => sanitizeAndNormalizeEmbedding(e));
}; };
const embedQuery = async (text: string): Promise<number[]> => { const embedQuery = async (
text: string,
options?: { signal?: AbortSignal },
): Promise<number[]> => {
if (!text.trim()) { if (!text.trim()) {
return []; return [];
} }
if (isCohere) { if (isCohere) {
return (await embedCohere([text], "search_query"))[0] ?? []; return (await embedCohere([text], "search_query", options?.signal))[0] ?? [];
} }
return embedSingle(text); return embedSingle(text, options?.signal);
}; };
const embedBatch = async (texts: string[]): Promise<number[][]> => { const embedBatch = async (
texts: string[],
options?: { signal?: AbortSignal },
): Promise<number[][]> => {
if (texts.length === 0) { if (texts.length === 0) {
return []; return [];
} }
if (isCohere) { if (isCohere) {
return embedCohere(texts, "search_document"); return embedCohere(texts, "search_document", options?.signal);
} }
return Promise.all(texts.map((t) => (t.trim() ? embedSingle(t) : Promise.resolve([])))); return Promise.all(
texts.map((t) => (t.trim() ? embedSingle(t, options?.signal) : Promise.resolve([]))),
);
}; };
return { return {

View File

@@ -221,7 +221,7 @@ async function createGitHubCopilotEmbeddingProvider(
): Promise<{ provider: MemoryEmbeddingProvider; client: GitHubCopilotEmbeddingClient }> { ): Promise<{ provider: MemoryEmbeddingProvider; client: GitHubCopilotEmbeddingClient }> {
const initialSession = await resolveGitHubCopilotEmbeddingSession(client); const initialSession = await resolveGitHubCopilotEmbeddingSession(client);
const embed = async (input: string[]): Promise<number[][]> => { const embed = async (input: string[], signal?: AbortSignal): Promise<number[][]> => {
if (input.length === 0) { if (input.length === 0) {
return []; return [];
} }
@@ -232,6 +232,7 @@ async function createGitHubCopilotEmbeddingProvider(
url, url,
fetchImpl: client.fetchImpl, fetchImpl: client.fetchImpl,
ssrfPolicy: buildRemoteBaseUrlPolicy(session.baseUrl), ssrfPolicy: buildRemoteBaseUrlPolicy(session.baseUrl),
signal,
init: { init: {
method: "POST", method: "POST",
headers: session.headers, headers: session.headers,
@@ -259,11 +260,11 @@ async function createGitHubCopilotEmbeddingProvider(
provider: { provider: {
id: COPILOT_EMBEDDING_PROVIDER_ID, id: COPILOT_EMBEDDING_PROVIDER_ID,
model: client.model, model: client.model,
embedQuery: async (text) => { embedQuery: async (text, options) => {
const [vector] = await embed([text]); const [vector] = await embed([text], options?.signal);
return vector ?? []; return vector ?? [];
}, },
embedBatch: embed, embedBatch: async (texts, options) => await embed(texts, options?.signal),
}, },
client: { client: {
...client, ...client,

View File

@@ -242,6 +242,7 @@ async function fetchGeminiEmbeddingPayload(params: {
client: GeminiEmbeddingClient; client: GeminiEmbeddingClient;
endpoint: string; endpoint: string;
body: unknown; body: unknown;
signal?: AbortSignal;
}): Promise<Record<string, unknown>> { }): Promise<Record<string, unknown>> {
return await executeWithApiKeyRotation({ return await executeWithApiKeyRotation({
provider: "google", provider: "google",
@@ -256,6 +257,7 @@ async function fetchGeminiEmbeddingPayload(params: {
return await withRemoteHttpResponse({ return await withRemoteHttpResponse({
url: params.endpoint, url: params.endpoint,
ssrfPolicy: params.client.ssrfPolicy, ssrfPolicy: params.client.ssrfPolicy,
signal: params.signal,
init: { init: {
method: "POST", method: "POST",
headers, headers,
@@ -316,7 +318,10 @@ export async function createGeminiEmbeddingProvider(
const isV2 = isGeminiEmbedding2Model(client.model); const isV2 = isGeminiEmbedding2Model(client.model);
const outputDimensionality = client.outputDimensionality; const outputDimensionality = client.outputDimensionality;
const embedQuery = async (text: string): Promise<number[]> => { const embedQuery = async (
text: string,
callOptions?: { signal?: AbortSignal },
): Promise<number[]> => {
if (!text.trim()) { if (!text.trim()) {
return []; return [];
} }
@@ -328,11 +333,15 @@ export async function createGeminiEmbeddingProvider(
taskType: options.taskType ?? "RETRIEVAL_QUERY", taskType: options.taskType ?? "RETRIEVAL_QUERY",
outputDimensionality: isV2 ? outputDimensionality : undefined, outputDimensionality: isV2 ? outputDimensionality : undefined,
}), }),
signal: callOptions?.signal,
}); });
return sanitizeAndNormalizeEmbedding(readGeminiSingleEmbedding(payload)); return sanitizeAndNormalizeEmbedding(readGeminiSingleEmbedding(payload));
}; };
const embedBatchInputs = async (inputs: EmbeddingInput[]): Promise<number[][]> => { const embedBatchInputs = async (
inputs: EmbeddingInput[],
callOptions?: { signal?: AbortSignal },
): Promise<number[][]> => {
if (inputs.length === 0) { if (inputs.length === 0) {
return []; return [];
} }
@@ -349,16 +358,21 @@ export async function createGeminiEmbeddingProvider(
}), }),
), ),
}, },
signal: callOptions?.signal,
}); });
const embeddings = readGeminiBatchEmbeddings(payload, inputs.length); const embeddings = readGeminiBatchEmbeddings(payload, inputs.length);
return embeddings.map((values) => sanitizeAndNormalizeEmbedding(values)); return embeddings.map((values) => sanitizeAndNormalizeEmbedding(values));
}; };
const embedBatch = async (texts: string[]): Promise<number[][]> => { const embedBatch = async (
texts: string[],
options?: { signal?: AbortSignal },
): Promise<number[][]> => {
return await embedBatchInputs( return await embedBatchInputs(
texts.map((text) => ({ texts.map((text) => ({
text, text,
})), })),
options,
); );
}; };

View File

@@ -111,6 +111,34 @@ export function resolveMemoryIndexConcurrency(params: {
return params.providerId === "ollama" ? 1 : EMBEDDING_INDEX_CONCURRENCY; return params.providerId === "ollama" ? 1 : EMBEDDING_INDEX_CONCURRENCY;
} }
export async function runEmbeddingOperationWithTimeout<T>(params: {
timeoutMs: number;
message: string;
run: (signal: AbortSignal) => Promise<T>;
}): Promise<T> {
const controller = new AbortController();
if (!Number.isFinite(params.timeoutMs) || params.timeoutMs <= 0) {
return await params.run(controller.signal);
}
let timer: NodeJS.Timeout | null = null;
const timeoutPromise = new Promise<never>((_, reject) => {
timer = setTimeout(() => {
const error = new Error(params.message);
reject(error);
controller.abort(error);
}, params.timeoutMs);
timer.unref?.();
});
try {
const operation = params.run(controller.signal);
return (await Promise.race([operation, timeoutPromise])) as T;
} finally {
if (timer) {
clearTimeout(timer);
}
}
}
export abstract class MemoryManagerEmbeddingOps extends MemoryManagerSyncOps { export abstract class MemoryManagerEmbeddingOps extends MemoryManagerSyncOps {
protected abstract batchFailureCount: number; protected abstract batchFailureCount: number;
protected abstract batchFailureLastError?: string; protected abstract batchFailureLastError?: string;
@@ -304,11 +332,11 @@ export abstract class MemoryManagerEmbeddingOps extends MemoryManagerSyncOps {
items: texts.length, items: texts.length,
timeoutMs, timeoutMs,
}); });
return await this.withTimeout( return await runEmbeddingOperationWithTimeout({
provider.embedBatch(texts),
timeoutMs, timeoutMs,
`memory embeddings batch timed out after ${Math.round(timeoutMs / 1000)}s`, message: `memory embeddings batch timed out after ${Math.round(timeoutMs / 1000)}s`,
); run: async (signal) => await provider.embedBatch(texts, { signal }),
});
}, },
isRetryable: isRetryableMemoryEmbeddingError, isRetryable: isRetryableMemoryEmbeddingError,
waitForRetry: async (delayMs) => { waitForRetry: async (delayMs) => {
@@ -336,11 +364,11 @@ export abstract class MemoryManagerEmbeddingOps extends MemoryManagerSyncOps {
items: inputs.length, items: inputs.length,
timeoutMs, timeoutMs,
}); });
return await this.withTimeout( return await runEmbeddingOperationWithTimeout({
embedBatchInputs(inputs),
timeoutMs, timeoutMs,
`memory embeddings batch timed out after ${Math.round(timeoutMs / 1000)}s`, message: `memory embeddings batch timed out after ${Math.round(timeoutMs / 1000)}s`,
); run: async (signal) => await embedBatchInputs(inputs, { signal }),
});
}, },
isRetryable: isRetryableMemoryEmbeddingError, isRetryable: isRetryableMemoryEmbeddingError,
waitForRetry: async (delayMs) => { waitForRetry: async (delayMs) => {
@@ -371,16 +399,17 @@ export abstract class MemoryManagerEmbeddingOps extends MemoryManagerSyncOps {
} }
protected async embedQueryWithTimeout(text: string): Promise<number[]> { protected async embedQueryWithTimeout(text: string): Promise<number[]> {
if (!this.provider) { const provider = this.provider;
if (!provider) {
throw new Error("Cannot embed query in FTS-only mode (no embedding provider)"); throw new Error("Cannot embed query in FTS-only mode (no embedding provider)");
} }
const timeoutMs = this.resolveEmbeddingTimeout("query"); const timeoutMs = this.resolveEmbeddingTimeout("query");
log.debug("memory embeddings: query start", { provider: this.provider.id, timeoutMs }); log.debug("memory embeddings: query start", { provider: provider.id, timeoutMs });
return await this.withTimeout( return await runEmbeddingOperationWithTimeout({
this.provider.embedQuery(text),
timeoutMs, timeoutMs,
`memory embeddings query timed out after ${Math.round(timeoutMs / 1000)}s`, message: `memory embeddings query timed out after ${Math.round(timeoutMs / 1000)}s`,
); run: async (signal) => await provider.embedQuery(text, { signal }),
});
} }
protected async withTimeout<T>( protected async withTimeout<T>(

View File

@@ -2,6 +2,7 @@ import { describe, expect, it } from "vitest";
import { import {
resolveEmbeddingTimeoutMs, resolveEmbeddingTimeoutMs,
resolveMemoryIndexConcurrency, resolveMemoryIndexConcurrency,
runEmbeddingOperationWithTimeout,
} from "./manager-embedding-ops.js"; } from "./manager-embedding-ops.js";
describe("memory embedding timeout resolution", () => { describe("memory embedding timeout resolution", () => {
@@ -37,6 +38,42 @@ describe("memory embedding timeout resolution", () => {
}); });
}); });
describe("memory embedding timeout abort", () => {
it("aborts the provider operation when the timeout fires", async () => {
let signalSeen: AbortSignal | undefined;
await expect(
runEmbeddingOperationWithTimeout({
timeoutMs: 1,
message: "memory embeddings query timed out after 0s",
run: async (signal) => {
signalSeen = signal;
return await new Promise<number[]>((resolve, reject) => {
signal.addEventListener("abort", () => reject(signal.reason), { once: true });
});
},
}),
).rejects.toThrow("memory embeddings query timed out after 0s");
expect(signalSeen?.aborted).toBe(true);
});
it("keeps the timeout error when a provider abort listener rejects generically", async () => {
await expect(
runEmbeddingOperationWithTimeout({
timeoutMs: 1,
message: "memory embeddings batch timed out after 0s",
run: async (signal) =>
await new Promise<number[]>((_resolve, reject) => {
signal.addEventListener("abort", () => reject(new Error("provider aborted")), {
once: true,
});
}),
}),
).rejects.toThrow("memory embeddings batch timed out after 0s");
});
});
describe("memory index concurrency resolution", () => { describe("memory index concurrency resolution", () => {
it("uses the default index concurrency when batch mode is disabled and unconfigured", () => { it("uses the default index concurrency when batch mode is disabled and unconfigured", () => {
expect( expect(

View File

@@ -25,8 +25,8 @@ export type OllamaEmbeddingProvider = {
id: string; id: string;
model: string; model: string;
maxInputTokens?: number; maxInputTokens?: number;
embedQuery: (text: string) => Promise<number[]>; embedQuery: (text: string, options?: { signal?: AbortSignal }) => Promise<number[]>;
embedBatch: (texts: string[]) => Promise<number[][]>; embedBatch: (texts: string[], options?: { signal?: AbortSignal }) => Promise<number[][]>;
}; };
type OllamaEmbeddingOptions = { type OllamaEmbeddingOptions = {
@@ -90,12 +90,14 @@ function sanitizeAndNormalizeEmbedding(vec: unknown[]): number[] {
async function withRemoteHttpResponse<T>(params: { async function withRemoteHttpResponse<T>(params: {
url: string; url: string;
init?: RequestInit; init?: RequestInit;
signal?: AbortSignal;
ssrfPolicy?: SsrFPolicy; ssrfPolicy?: SsrFPolicy;
onResponse: (response: Response) => Promise<T>; onResponse: (response: Response) => Promise<T>;
}): Promise<T> { }): Promise<T> {
const { response, release } = await fetchWithSsrFGuard({ const { response, release } = await fetchWithSsrFGuard({
url: params.url, url: params.url,
init: params.init, init: params.init,
signal: params.signal,
policy: params.ssrfPolicy, policy: params.ssrfPolicy,
auditContext: "memory-remote", auditContext: "memory-remote",
}); });
@@ -322,10 +324,11 @@ export async function createOllamaEmbeddingProvider(
const client = resolveOllamaEmbeddingClient(options); const client = resolveOllamaEmbeddingClient(options);
const embedUrl = `${client.baseUrl.replace(/\/$/, "")}/api/embed`; const embedUrl = `${client.baseUrl.replace(/\/$/, "")}/api/embed`;
const embedMany = async (input: string | string[]): Promise<number[][]> => { const embedMany = async (input: string | string[], signal?: AbortSignal): Promise<number[][]> => {
const json = await withRemoteHttpResponse({ const json = await withRemoteHttpResponse({
url: embedUrl, url: embedUrl,
ssrfPolicy: client.ssrfPolicy, ssrfPolicy: client.ssrfPolicy,
signal,
init: { init: {
method: "POST", method: "POST",
headers: client.headers, headers: client.headers,
@@ -355,22 +358,23 @@ export async function createOllamaEmbeddingProvider(
}); });
}; };
const embedOne = async (text: string): Promise<number[]> => { const embedOne = async (text: string, signal?: AbortSignal): Promise<number[]> => {
const [embedding] = await embedMany(text); const [embedding] = await embedMany(text, signal);
if (!embedding) { if (!embedding) {
throw new Error("Ollama embed response returned no embedding"); throw new Error("Ollama embed response returned no embedding");
} }
return embedding; return embedding;
}; };
const embedQuery = async (text: string): Promise<number[]> => const embedQuery = async (text: string, options?: { signal?: AbortSignal }): Promise<number[]> =>
await embedOne(applyQueryInstructionTemplate(client.model, text)); await embedOne(applyQueryInstructionTemplate(client.model, text), options?.signal);
const provider: OllamaEmbeddingProvider = { const provider: OllamaEmbeddingProvider = {
id: "ollama", id: "ollama",
model: client.model, model: client.model,
embedQuery, embedQuery,
embedBatch: async (texts) => (texts.length === 0 ? [] : await embedMany(texts)), embedBatch: async (texts, options) =>
texts.length === 0 ? [] : await embedMany(texts, options?.signal),
}; };
return { return {

View File

@@ -35,6 +35,7 @@ function expectFetchRemoteEmbeddingVectorsBody(body: Record<string, unknown>) {
headers: { Authorization: "Bearer test" }, headers: { Authorization: "Bearer test" },
ssrfPolicy: undefined, ssrfPolicy: undefined,
fetchImpl: undefined, fetchImpl: undefined,
signal: undefined,
body, body,
errorPrefix: "openai embeddings failed", errorPrefix: "openai embeddings failed",
}); });

View File

@@ -47,7 +47,11 @@ export async function createOpenAiEmbeddingProvider(
return typeof value === "string" && value.trim().length > 0 ? value.trim() : undefined; return typeof value === "string" && value.trim().length > 0 ? value.trim() : undefined;
}; };
const embed = async (input: string[], kind: "query" | "document"): Promise<number[][]> => { const embed = async (
input: string[],
kind: "query" | "document",
signal?: AbortSignal,
): Promise<number[][]> => {
if (input.length === 0) { if (input.length === 0) {
return []; return [];
} }
@@ -57,6 +61,7 @@ export async function createOpenAiEmbeddingProvider(
headers: client.headers, headers: client.headers,
ssrfPolicy: client.ssrfPolicy, ssrfPolicy: client.ssrfPolicy,
fetchImpl: client.fetchImpl, fetchImpl: client.fetchImpl,
signal,
body: { body: {
model: client.model, model: client.model,
input, input,
@@ -76,11 +81,11 @@ export async function createOpenAiEmbeddingProvider(
...(typeof OPENAI_MAX_INPUT_TOKENS[client.model] === "number" ...(typeof OPENAI_MAX_INPUT_TOKENS[client.model] === "number"
? { maxInputTokens: OPENAI_MAX_INPUT_TOKENS[client.model] } ? { maxInputTokens: OPENAI_MAX_INPUT_TOKENS[client.model] }
: {}), : {}),
embedQuery: async (text) => { embedQuery: async (text, options) => {
const [vec] = await embed([text], "query"); const [vec] = await embed([text], "query", options?.signal);
return vec ?? []; return vec ?? [];
}, },
embedBatch: async (texts) => await embed(texts, "document"), embedBatch: async (texts, options) => await embed(texts, "document", options?.signal),
}, },
client, client,
}; };

View File

@@ -36,7 +36,11 @@ export async function createVoyageEmbeddingProvider(
const client = await resolveVoyageEmbeddingClient(options); const client = await resolveVoyageEmbeddingClient(options);
const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`; const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`;
const embed = async (input: string[], input_type?: "query" | "document"): Promise<number[][]> => { const embed = async (
input: string[],
input_type?: "query" | "document",
signal?: AbortSignal,
): Promise<number[][]> => {
if (input.length === 0) { if (input.length === 0) {
return []; return [];
} }
@@ -52,6 +56,7 @@ export async function createVoyageEmbeddingProvider(
url, url,
headers: client.headers, headers: client.headers,
ssrfPolicy: client.ssrfPolicy, ssrfPolicy: client.ssrfPolicy,
signal,
body, body,
errorPrefix: "voyage embeddings failed", errorPrefix: "voyage embeddings failed",
}); });
@@ -62,11 +67,11 @@ export async function createVoyageEmbeddingProvider(
id: "voyage", id: "voyage",
model: client.model, model: client.model,
maxInputTokens: VOYAGE_MAX_INPUT_TOKENS[client.model], maxInputTokens: VOYAGE_MAX_INPUT_TOKENS[client.model],
embedQuery: async (text) => { embedQuery: async (text, options) => {
const [vec] = await embed([text], "query"); const [vec] = await embed([text], "query", options?.signal);
return vec ?? []; return vec ?? [];
}, },
embedBatch: async (texts) => embed(texts, "document"), embedBatch: async (texts, options) => embed(texts, "document", options?.signal),
}, },
client, client,
}; };

View File

@@ -11,6 +11,7 @@ export type {
MemoryEmbeddingBatchOptions, MemoryEmbeddingBatchOptions,
MemoryEmbeddingProvider, MemoryEmbeddingProvider,
MemoryEmbeddingProviderAdapter, MemoryEmbeddingProviderAdapter,
MemoryEmbeddingProviderCallOptions,
MemoryEmbeddingProviderCreateOptions, MemoryEmbeddingProviderCreateOptions,
MemoryEmbeddingProviderCreateResult, MemoryEmbeddingProviderCreateResult,
MemoryEmbeddingProviderRuntime, MemoryEmbeddingProviderRuntime,

View File

@@ -10,6 +10,7 @@ vi.mock("./post-json.js", () => ({
function requirePostJsonParams(): { function requirePostJsonParams(): {
url?: unknown; url?: unknown;
headers?: unknown; headers?: unknown;
signal?: unknown;
body?: unknown; body?: unknown;
errorPrefix?: unknown; errorPrefix?: unknown;
} { } {
@@ -51,6 +52,23 @@ describe("fetchRemoteEmbeddingVectors", () => {
expect(postJsonParams.errorPrefix).toBe("embedding fetch failed"); expect(postJsonParams.errorPrefix).toBe("embedding fetch failed");
}); });
it("passes abort signals to the JSON request", async () => {
const controller = new AbortController();
postJsonMock.mockImplementationOnce(async (params) => {
return await params.parse({ data: [{ embedding: [0.1] }] });
});
await fetchRemoteEmbeddingVectors({
url: "https://memory.example/v1/embeddings",
headers: {},
signal: controller.signal,
body: { input: ["one"] },
errorPrefix: "embedding fetch failed",
});
expect(requirePostJsonParams().signal).toBe(controller.signal);
});
it("throws a status-rich error on non-ok responses", async () => { it("throws a status-rich error on non-ok responses", async () => {
postJsonMock.mockRejectedValueOnce(new Error("embedding fetch failed: 403 forbidden")); postJsonMock.mockRejectedValueOnce(new Error("embedding fetch failed: 403 forbidden"));

View File

@@ -33,6 +33,7 @@ export async function fetchRemoteEmbeddingVectors(params: {
headers: Record<string, string>; headers: Record<string, string>;
ssrfPolicy?: SsrFPolicy; ssrfPolicy?: SsrFPolicy;
fetchImpl?: typeof fetch; fetchImpl?: typeof fetch;
signal?: AbortSignal;
body: unknown; body: unknown;
errorPrefix: string; errorPrefix: string;
}): Promise<number[][]> { }): Promise<number[][]> {
@@ -41,6 +42,7 @@ export async function fetchRemoteEmbeddingVectors(params: {
headers: params.headers, headers: params.headers,
ssrfPolicy: params.ssrfPolicy, ssrfPolicy: params.ssrfPolicy,
fetchImpl: params.fetchImpl, fetchImpl: params.fetchImpl,
signal: params.signal,
body: params.body, body: params.body,
errorPrefix: params.errorPrefix, errorPrefix: params.errorPrefix,
parse: (payload) => { parse: (payload) => {

View File

@@ -23,7 +23,7 @@ export function createRemoteEmbeddingProvider(params: {
const { client } = params; const { client } = params;
const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`; const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`;
const embed = async (input: string[]): Promise<number[][]> => { const embed = async (input: string[], signal?: AbortSignal): Promise<number[][]> => {
if (input.length === 0) { if (input.length === 0) {
return []; return [];
} }
@@ -32,6 +32,7 @@ export function createRemoteEmbeddingProvider(params: {
headers: client.headers, headers: client.headers,
ssrfPolicy: client.ssrfPolicy, ssrfPolicy: client.ssrfPolicy,
fetchImpl: client.fetchImpl, fetchImpl: client.fetchImpl,
signal,
body: { model: client.model, input }, body: { model: client.model, input },
errorPrefix: params.errorPrefix, errorPrefix: params.errorPrefix,
}); });
@@ -41,11 +42,11 @@ export function createRemoteEmbeddingProvider(params: {
id: params.id, id: params.id,
model: client.model, model: client.model,
...(typeof params.maxInputTokens === "number" ? { maxInputTokens: params.maxInputTokens } : {}), ...(typeof params.maxInputTokens === "number" ? { maxInputTokens: params.maxInputTokens } : {}),
embedQuery: async (text) => { embedQuery: async (text, options) => {
const [vec] = await embed([text]); const [vec] = await embed([text], options?.signal);
return vec ?? []; return vec ?? [];
}, },
embedBatch: embed, embedBatch: async (texts, options) => await embed(texts, options?.signal),
}; };
} }

View File

@@ -66,15 +66,20 @@ export async function createLocalEmbeddingProvider(
return { return {
id: "local", id: "local",
model: modelPath, model: modelPath,
embedQuery: async (text) => { embedQuery: async (text, options) => {
options?.signal?.throwIfAborted();
const ctx = await ensureContext(); const ctx = await ensureContext();
options?.signal?.throwIfAborted();
const embedding = await ctx.getEmbeddingFor(text); const embedding = await ctx.getEmbeddingFor(text);
return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector)); return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
}, },
embedBatch: async (texts) => { embedBatch: async (texts, options) => {
options?.signal?.throwIfAborted();
const ctx = await ensureContext(); const ctx = await ensureContext();
options?.signal?.throwIfAborted();
const embeddings = await Promise.all( const embeddings = await Promise.all(
texts.map(async (text) => { texts.map(async (text) => {
options?.signal?.throwIfAborted();
const embedding = await ctx.getEmbeddingFor(text); const embedding = await ctx.getEmbeddingFor(text);
return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector)); return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
}), }),

View File

@@ -5,9 +5,16 @@ export type EmbeddingProvider = {
id: string; id: string;
model: string; model: string;
maxInputTokens?: number; maxInputTokens?: number;
embedQuery: (text: string) => Promise<number[]>; embedQuery: (text: string, options?: EmbeddingProviderCallOptions) => Promise<number[]>;
embedBatch: (texts: string[]) => Promise<number[][]>; embedBatch: (texts: string[], options?: EmbeddingProviderCallOptions) => Promise<number[][]>;
embedBatchInputs?: (inputs: EmbeddingInput[]) => Promise<number[][]>; embedBatchInputs?: (
inputs: EmbeddingInput[],
options?: EmbeddingProviderCallOptions,
) => Promise<number[][]>;
};
export type EmbeddingProviderCallOptions = {
signal?: AbortSignal;
}; };
export type EmbeddingProviderId = string; export type EmbeddingProviderId = string;

View File

@@ -15,6 +15,7 @@ export type {
MemoryEmbeddingBatchOptions, MemoryEmbeddingBatchOptions,
MemoryEmbeddingProvider, MemoryEmbeddingProvider,
MemoryEmbeddingProviderAdapter, MemoryEmbeddingProviderAdapter,
MemoryEmbeddingProviderCallOptions,
MemoryEmbeddingProviderCreateOptions, MemoryEmbeddingProviderCreateOptions,
MemoryEmbeddingProviderCreateResult, MemoryEmbeddingProviderCreateResult,
MemoryEmbeddingProviderRuntime, MemoryEmbeddingProviderRuntime,

View File

@@ -112,6 +112,7 @@ export type {
MemoryEmbeddingBatchOptions, MemoryEmbeddingBatchOptions,
MemoryEmbeddingProvider, MemoryEmbeddingProvider,
MemoryEmbeddingProviderAdapter, MemoryEmbeddingProviderAdapter,
MemoryEmbeddingProviderCallOptions,
MemoryEmbeddingProviderCreateOptions, MemoryEmbeddingProviderCreateOptions,
MemoryEmbeddingProviderCreateResult, MemoryEmbeddingProviderCreateResult,
MemoryEmbeddingProviderRuntime, MemoryEmbeddingProviderRuntime,

View File

@@ -47,6 +47,23 @@ describe("postJson", () => {
expect(result).toEqual({ data: [{ embedding: [1, 2] }] }); expect(result).toEqual({ data: [{ embedding: [1, 2] }] });
}); });
it("forwards abort signals to the remote HTTP request", async () => {
const controller = new AbortController();
remoteHttpMock.mockImplementationOnce(async (params) => {
expect(params.signal).toBe(controller.signal);
return await params.onResponse(jsonResponse({ ok: true }));
});
await postJson({
url: "https://memory.example/v1/post",
headers: {},
body: {},
signal: controller.signal,
errorPrefix: "post failed",
parse: (payload) => payload,
});
});
it("attaches status to thrown error when requested", async () => { it("attaches status to thrown error when requested", async () => {
remoteHttpMock.mockImplementationOnce(async (params) => { remoteHttpMock.mockImplementationOnce(async (params) => {
return await params.onResponse(textResponse("bad gateway", 502)); return await params.onResponse(textResponse("bad gateway", 502));

View File

@@ -6,6 +6,7 @@ export async function postJson<T>(params: {
headers: Record<string, string>; headers: Record<string, string>;
ssrfPolicy?: SsrFPolicy; ssrfPolicy?: SsrFPolicy;
fetchImpl?: typeof fetch; fetchImpl?: typeof fetch;
signal?: AbortSignal;
body: unknown; body: unknown;
errorPrefix: string; errorPrefix: string;
attachStatus?: boolean; attachStatus?: boolean;
@@ -15,6 +16,7 @@ export async function postJson<T>(params: {
url: params.url, url: params.url,
ssrfPolicy: params.ssrfPolicy, ssrfPolicy: params.ssrfPolicy,
fetchImpl: params.fetchImpl, fetchImpl: params.fetchImpl,
signal: params.signal,
init: { init: {
method: "POST", method: "POST",
headers: params.headers, headers: params.headers,

View File

@@ -43,4 +43,18 @@ describe("package withRemoteHttpResponse", () => {
expect(deps.calls).toHaveLength(1); expect(deps.calls).toHaveLength(1);
expect(deps.calls[0]).not.toHaveProperty("mode"); expect(deps.calls[0]).not.toHaveProperty("mode");
}); });
it("passes abort signals to the guarded fetch", async () => {
const deps = makeFetchDeps();
const controller = new AbortController();
await withRemoteHttpResponse({
url: "https://memory.example/v1/embeddings",
signal: controller.signal,
onResponse: async () => undefined,
...deps,
});
expect(deps.calls[0]).toHaveProperty("signal", controller.signal);
});
}); });

View File

@@ -13,6 +13,7 @@ export const buildRemoteBaseUrlPolicy: (baseUrl: string) => SsrFPolicy | undefin
export async function withRemoteHttpResponse<T>(params: { export async function withRemoteHttpResponse<T>(params: {
url: string; url: string;
init?: RequestInit; init?: RequestInit;
signal?: AbortSignal;
ssrfPolicy?: SsrFPolicy; ssrfPolicy?: SsrFPolicy;
fetchImpl?: typeof fetch; fetchImpl?: typeof fetch;
fetchWithSsrFGuardImpl?: typeof fetchWithSsrFGuard; fetchWithSsrFGuardImpl?: typeof fetchWithSsrFGuard;
@@ -26,6 +27,7 @@ export async function withRemoteHttpResponse<T>(params: {
url: params.url, url: params.url,
fetchImpl: params.fetchImpl, fetchImpl: params.fetchImpl,
init: params.init, init: params.init,
signal: params.signal,
policy: params.ssrfPolicy, policy: params.ssrfPolicy,
auditContext: params.auditContext ?? "memory-remote", auditContext: params.auditContext ?? "memory-remote",
...(shouldUseEnvProxy(params.url) ? { mode: MEMORY_REMOTE_TRUSTED_ENV_PROXY_MODE } : {}), ...(shouldUseEnvProxy(params.url) ? { mode: MEMORY_REMOTE_TRUSTED_ENV_PROXY_MODE } : {}),

View File

@@ -51,5 +51,9 @@ export function runNodeMain(params?: {
cwd?: string; cwd?: string;
args?: string[]; args?: string[];
env?: NodeJS.ProcessEnv; env?: NodeJS.ProcessEnv;
runRuntimePostBuild?: (params?: {
cwd?: string;
env?: Record<string, string | undefined>;
}) => void | Promise<void>;
platform?: NodeJS.Platform; platform?: NodeJS.Platform;
}): Promise<number>; }): Promise<number>;

View File

@@ -79,6 +79,22 @@ function isGuidedConfigAction(actionCommand: Command): boolean {
return actionCommand.name() === "config" && !actionCommand.parent?.parent; return actionCommand.name() === "config" && !actionCommand.parent?.parent;
} }
function isGuidedConfigCommandPath(commandPath: string[]): boolean {
const [primary, secondary, extra] = commandPath;
if (primary !== "config" || extra !== undefined) {
return false;
}
return (
secondary !== "get" &&
secondary !== "set" &&
secondary !== "patch" &&
secondary !== "unset" &&
secondary !== "file" &&
secondary !== "schema" &&
secondary !== "validate"
);
}
export function registerPreActionHooks(program: Command, programVersion: string) { export function registerPreActionHooks(program: Command, programVersion: string) {
program.hook("preAction", async (_thisCommand, actionCommand) => { program.hook("preAction", async (_thisCommand, actionCommand) => {
setProcessTitleForCommand(actionCommand); setProcessTitleForCommand(actionCommand);
@@ -105,7 +121,11 @@ export function registerPreActionHooks(program: Command, programVersion: string)
if (!verbose) { if (!verbose) {
process.env.NODE_NO_WARNINGS ??= "1"; process.env.NODE_NO_WARNINGS ??= "1";
} }
if (shouldBypassConfigGuardForCommandPath(commandPath) || isGuidedConfigAction(actionCommand)) { if (
shouldBypassConfigGuardForCommandPath(commandPath) ||
isGuidedConfigAction(actionCommand) ||
isGuidedConfigCommandPath(commandPath)
) {
return; return;
} }
await ensureCliExecutionBootstrap({ await ensureCliExecutionBootstrap({

View File

@@ -430,10 +430,19 @@ describe("gateway auth browser hardening", () => {
}); });
test("rejects forged loopback origin for control-ui when proxy headers make client non-local", async () => { test("rejects forged loopback origin for control-ui when proxy headers make client non-local", async () => {
const { writeConfigFile } = await import("../config/config.js");
await writeConfigFile({
gateway: {
trustedProxies: ["127.0.0.1"],
controlUi: {
allowedOrigins: [],
},
},
});
testState.gatewayAuth = { mode: "token", token: "secret" }; testState.gatewayAuth = { mode: "token", token: "secret" };
await withGatewayServer(async ({ port }) => { await withGatewayServer(async ({ port }) => {
const ws = await openWs(port, { const ws = await openWs(port, {
origin: originForPort(port), origin: "http://localhost:5173",
"x-forwarded-for": "203.0.113.50", "x-forwarded-for": "203.0.113.50",
}); });
try { try {
@@ -444,6 +453,7 @@ describe("gateway auth browser hardening", () => {
id: GATEWAY_CLIENT_NAMES.CONTROL_UI, id: GATEWAY_CLIENT_NAMES.CONTROL_UI,
mode: GATEWAY_CLIENT_MODES.UI, mode: GATEWAY_CLIENT_MODES.UI,
}, },
device: null,
}); });
expect(res.ok).toBe(false); expect(res.ok).toBe(false);
expect(res.error?.message ?? "").toContain("origin not allowed"); expect(res.error?.message ?? "").toContain("origin not allowed");

View File

@@ -1415,9 +1415,6 @@ describe("run-node script", () => {
return true; return true;
}), }),
} as unknown as NodeJS.WriteStream; } as unknown as NodeJS.WriteStream;
const stdout = {
write: vi.fn(() => true),
} as unknown as NodeJS.WriteStream;
const exitCode = await runNodeMain({ const exitCode = await runNodeMain({
cwd: tmp, cwd: tmp,
@@ -1430,11 +1427,10 @@ describe("run-node script", () => {
spawn, spawn,
spawnSync, spawnSync,
stderr, stderr,
stdout,
runRuntimePostBuild: async () => {}, runRuntimePostBuild: async () => {},
execPath: process.execPath, execPath: process.execPath,
platform: process.platform, platform: process.platform,
} as Parameters<typeof runNodeMain>[0] & { stdout: NodeJS.WriteStream }); } as Parameters<typeof runNodeMain>[0]);
expect(exitCode).toBe(0); expect(exitCode).toBe(0);
const stderrText = stderrChunks.join(""); const stderrText = stderrChunks.join("");

View File

@@ -4,7 +4,6 @@ import type { OpenClawConfig } from "../config/types.js";
import { normalizeSecretInputString, resolveSecretInputRef } from "../config/types.secrets.js"; import { normalizeSecretInputString, resolveSecretInputRef } from "../config/types.secrets.js";
import { materializeGatewayAuthSecretRefs } from "../gateway/auth-config-utils.js"; import { materializeGatewayAuthSecretRefs } from "../gateway/auth-config-utils.js";
import { assertExplicitGatewayAuthModeWhenBothConfigured } from "../gateway/auth-mode-policy.js"; import { assertExplicitGatewayAuthModeWhenBothConfigured } from "../gateway/auth-mode-policy.js";
import { isLoopbackHost } from "../gateway/net.js";
import { issueDeviceBootstrapToken } from "../infra/device-bootstrap.js"; import { issueDeviceBootstrapToken } from "../infra/device-bootstrap.js";
import { import {
pickMatchingExternalInterfaceAddress, pickMatchingExternalInterfaceAddress,
@@ -16,6 +15,7 @@ import {
isCarrierGradeNatIpv4Address, isCarrierGradeNatIpv4Address,
isIpv4Address, isIpv4Address,
isIpv6Address, isIpv6Address,
isLoopbackIpAddress,
isRfc1918Ipv4Address, isRfc1918Ipv4Address,
parseCanonicalIpAddress, parseCanonicalIpAddress,
} from "../shared/net/ip.js"; } from "../shared/net/ip.js";
@@ -124,7 +124,12 @@ function isPrivateLanHost(host: string): boolean {
function isMobilePairingCleartextAllowedHost(host: string): boolean { function isMobilePairingCleartextAllowedHost(host: string): boolean {
const normalized = normalizeMobilePairingHost(host); const normalized = normalizeMobilePairingHost(host);
return isLoopbackHost(normalized) || normalized === "10.0.2.2" || isPrivateLanHost(normalized); return (
normalized === "localhost" ||
isLoopbackIpAddress(normalized) ||
normalized === "10.0.2.2" ||
isPrivateLanHost(normalized)
);
} }
function validateMobilePairingUrl(url: string, source?: string): string | null { function validateMobilePairingUrl(url: string, source?: string): string | null {

View File

@@ -66,6 +66,7 @@ export type {
MemoryEmbeddingBatchOptions, MemoryEmbeddingBatchOptions,
MemoryEmbeddingProvider, MemoryEmbeddingProvider,
MemoryEmbeddingProviderAdapter, MemoryEmbeddingProviderAdapter,
MemoryEmbeddingProviderCallOptions,
MemoryEmbeddingProviderCreateOptions, MemoryEmbeddingProviderCreateOptions,
MemoryEmbeddingProviderCreateResult, MemoryEmbeddingProviderCreateResult,
MemoryEmbeddingProviderRuntime, MemoryEmbeddingProviderRuntime,

View File

@@ -17,6 +17,10 @@ export type MemoryEmbeddingBatchOptions = {
debug: (message: string, data?: Record<string, unknown>) => void; debug: (message: string, data?: Record<string, unknown>) => void;
}; };
export type MemoryEmbeddingProviderCallOptions = {
signal?: AbortSignal;
};
export type MemoryEmbeddingProviderRuntime = { export type MemoryEmbeddingProviderRuntime = {
id: string; id: string;
cacheKeyData?: Record<string, unknown>; cacheKeyData?: Record<string, unknown>;
@@ -29,9 +33,15 @@ export type MemoryEmbeddingProvider = {
id: string; id: string;
model: string; model: string;
maxInputTokens?: number; maxInputTokens?: number;
embedQuery: (text: string) => Promise<number[]>; embedQuery: (text: string, options?: MemoryEmbeddingProviderCallOptions) => Promise<number[]>;
embedBatch: (texts: string[]) => Promise<number[][]>; embedBatch: (
embedBatchInputs?: (inputs: EmbeddingInput[]) => Promise<number[][]>; texts: string[],
options?: MemoryEmbeddingProviderCallOptions,
) => Promise<number[][]>;
embedBatchInputs?: (
inputs: EmbeddingInput[],
options?: MemoryEmbeddingProviderCallOptions,
) => Promise<number[][]>;
}; };
export type MemoryEmbeddingProviderCreateOptions = { export type MemoryEmbeddingProviderCreateOptions = {

View File

@@ -964,9 +964,9 @@ export type PluginEmbeddingProvider = {
id: string; id: string;
model: string; model: string;
maxInputTokens?: number; maxInputTokens?: number;
embedQuery: (text: string) => Promise<number[]>; embedQuery: (text: string, options?: { signal?: AbortSignal }) => Promise<number[]>;
embedBatch: (texts: string[]) => Promise<number[][]>; embedBatch: (texts: string[], options?: { signal?: AbortSignal }) => Promise<number[][]>;
embedBatchInputs?: (inputs: unknown[]) => Promise<number[][]>; embedBatchInputs?: (inputs: unknown[], options?: { signal?: AbortSignal }) => Promise<number[][]>;
client?: unknown; client?: unknown;
}; };