mirror of
https://github.com/openclaw/openclaw.git
synced 2026-06-06 05:51:15 +08:00
fix(memory): abort timed-out embedding requests (#82770)
* fix(memory): abort timed-out embedding requests * test: stabilize gateway ci shards * test: pin control ui origin fixture * test: stabilize gateway ci fixtures * test: isolate forged origin fixture * test: decouple setup code from gateway net mocks * test: repair run-node and config preaction CI * test: fix run-node progress fixture typing * test: remove unused pairing setup helper * fix: stabilize embedding timeout errors
This commit is contained in:
committed by
GitHub
parent
b77b3a7ade
commit
a6225060f1
@@ -34,6 +34,7 @@ Docs: https://docs.openclaw.ai
|
||||
- Telegram: let catch-all mention patterns match captionless group photos, so media-only group messages reach the agent when the group is intentionally configured to respond to all messages. Fixes #44833. (#82756) Thanks @IWhatsskill.
|
||||
- Gateway/pairing: reject forged loopback Control UI origins from non-local proxy paths, and keep mobile pairing setup on Tailscale bind mode pointing users to Tailscale Serve/Funnel instead of cleartext tailnet WebSockets.
|
||||
- Telegram/Gateway: persist isolated polling offsets only after main-thread dispatch and preserve gateway caller scopes for Telegram message actions, fixing consumed-but-unrouted polling updates and recursive CLI send scope approvals. Fixes #82277. (#82705) Thanks @udaymanish6.
|
||||
- Memory-core: abort timed-out embedding provider calls so remote embedding HTTP requests do not continue running after memory query or indexing timeouts. Fixes #82732. Thanks @adityarya24.
|
||||
- Channels/stream previews: contain rejected background draft-stream flushes so preview send failures do not surface as fatal unhandled rejections. Fixes #82712. (#82713) Thanks @coygeek.
|
||||
- Codex/app-server: keep shared native app-server clients isolated per agent runtime key so starting one agent no longer closes another agent's active Codex turn. Fixes #82758. Thanks @PashaGanson.
|
||||
- Providers/OpenAI Codex: include base `gpt-5.5` and `gpt-5.4` reasoning metadata in the bundled Codex catalog so `/think xhigh` remains available for those models. Fixes #82744.
|
||||
|
||||
@@ -331,7 +331,7 @@ export async function createBedrockEmbeddingProvider(
|
||||
family,
|
||||
});
|
||||
|
||||
const invoke = async (body: string): Promise<string> => {
|
||||
const invoke = async (body: string, signal?: AbortSignal): Promise<string> => {
|
||||
await refreshAwsSharedConfigCacheForBedrock();
|
||||
const sdk = new BedrockRuntimeClient({ region: client.region });
|
||||
try {
|
||||
@@ -342,6 +342,7 @@ export async function createBedrockEmbeddingProvider(
|
||||
contentType: "application/json",
|
||||
accept: "application/json",
|
||||
}),
|
||||
signal ? { abortSignal: signal } : undefined,
|
||||
);
|
||||
return new TextDecoder().decode(res.body);
|
||||
} finally {
|
||||
@@ -351,37 +352,46 @@ export async function createBedrockEmbeddingProvider(
|
||||
|
||||
const isCohere = family === "cohere-v3" || family === "cohere-v4";
|
||||
|
||||
const embedSingle = async (text: string): Promise<number[]> => {
|
||||
const raw = await invoke(buildBody(family, text, client.dimensions));
|
||||
const embedSingle = async (text: string, signal?: AbortSignal): Promise<number[]> => {
|
||||
const raw = await invoke(buildBody(family, text, client.dimensions), signal);
|
||||
return sanitizeAndNormalizeEmbedding(parseSingle(family, raw));
|
||||
};
|
||||
|
||||
const embedCohere = async (
|
||||
texts: string[],
|
||||
inputType: "search_query" | "search_document",
|
||||
signal?: AbortSignal,
|
||||
): Promise<number[][]> => {
|
||||
const raw = await invoke(buildCohereBody(family, texts, inputType, client.dimensions));
|
||||
const raw = await invoke(buildCohereBody(family, texts, inputType, client.dimensions), signal);
|
||||
return parseCohereBatch(family, raw).map((e) => sanitizeAndNormalizeEmbedding(e));
|
||||
};
|
||||
|
||||
const embedQuery = async (text: string): Promise<number[]> => {
|
||||
const embedQuery = async (
|
||||
text: string,
|
||||
options?: { signal?: AbortSignal },
|
||||
): Promise<number[]> => {
|
||||
if (!text.trim()) {
|
||||
return [];
|
||||
}
|
||||
if (isCohere) {
|
||||
return (await embedCohere([text], "search_query"))[0] ?? [];
|
||||
return (await embedCohere([text], "search_query", options?.signal))[0] ?? [];
|
||||
}
|
||||
return embedSingle(text);
|
||||
return embedSingle(text, options?.signal);
|
||||
};
|
||||
|
||||
const embedBatch = async (texts: string[]): Promise<number[][]> => {
|
||||
const embedBatch = async (
|
||||
texts: string[],
|
||||
options?: { signal?: AbortSignal },
|
||||
): Promise<number[][]> => {
|
||||
if (texts.length === 0) {
|
||||
return [];
|
||||
}
|
||||
if (isCohere) {
|
||||
return embedCohere(texts, "search_document");
|
||||
return embedCohere(texts, "search_document", options?.signal);
|
||||
}
|
||||
return Promise.all(texts.map((t) => (t.trim() ? embedSingle(t) : Promise.resolve([]))));
|
||||
return Promise.all(
|
||||
texts.map((t) => (t.trim() ? embedSingle(t, options?.signal) : Promise.resolve([]))),
|
||||
);
|
||||
};
|
||||
|
||||
return {
|
||||
|
||||
@@ -221,7 +221,7 @@ async function createGitHubCopilotEmbeddingProvider(
|
||||
): Promise<{ provider: MemoryEmbeddingProvider; client: GitHubCopilotEmbeddingClient }> {
|
||||
const initialSession = await resolveGitHubCopilotEmbeddingSession(client);
|
||||
|
||||
const embed = async (input: string[]): Promise<number[][]> => {
|
||||
const embed = async (input: string[], signal?: AbortSignal): Promise<number[][]> => {
|
||||
if (input.length === 0) {
|
||||
return [];
|
||||
}
|
||||
@@ -232,6 +232,7 @@ async function createGitHubCopilotEmbeddingProvider(
|
||||
url,
|
||||
fetchImpl: client.fetchImpl,
|
||||
ssrfPolicy: buildRemoteBaseUrlPolicy(session.baseUrl),
|
||||
signal,
|
||||
init: {
|
||||
method: "POST",
|
||||
headers: session.headers,
|
||||
@@ -259,11 +260,11 @@ async function createGitHubCopilotEmbeddingProvider(
|
||||
provider: {
|
||||
id: COPILOT_EMBEDDING_PROVIDER_ID,
|
||||
model: client.model,
|
||||
embedQuery: async (text) => {
|
||||
const [vector] = await embed([text]);
|
||||
embedQuery: async (text, options) => {
|
||||
const [vector] = await embed([text], options?.signal);
|
||||
return vector ?? [];
|
||||
},
|
||||
embedBatch: embed,
|
||||
embedBatch: async (texts, options) => await embed(texts, options?.signal),
|
||||
},
|
||||
client: {
|
||||
...client,
|
||||
|
||||
@@ -242,6 +242,7 @@ async function fetchGeminiEmbeddingPayload(params: {
|
||||
client: GeminiEmbeddingClient;
|
||||
endpoint: string;
|
||||
body: unknown;
|
||||
signal?: AbortSignal;
|
||||
}): Promise<Record<string, unknown>> {
|
||||
return await executeWithApiKeyRotation({
|
||||
provider: "google",
|
||||
@@ -256,6 +257,7 @@ async function fetchGeminiEmbeddingPayload(params: {
|
||||
return await withRemoteHttpResponse({
|
||||
url: params.endpoint,
|
||||
ssrfPolicy: params.client.ssrfPolicy,
|
||||
signal: params.signal,
|
||||
init: {
|
||||
method: "POST",
|
||||
headers,
|
||||
@@ -316,7 +318,10 @@ export async function createGeminiEmbeddingProvider(
|
||||
const isV2 = isGeminiEmbedding2Model(client.model);
|
||||
const outputDimensionality = client.outputDimensionality;
|
||||
|
||||
const embedQuery = async (text: string): Promise<number[]> => {
|
||||
const embedQuery = async (
|
||||
text: string,
|
||||
callOptions?: { signal?: AbortSignal },
|
||||
): Promise<number[]> => {
|
||||
if (!text.trim()) {
|
||||
return [];
|
||||
}
|
||||
@@ -328,11 +333,15 @@ export async function createGeminiEmbeddingProvider(
|
||||
taskType: options.taskType ?? "RETRIEVAL_QUERY",
|
||||
outputDimensionality: isV2 ? outputDimensionality : undefined,
|
||||
}),
|
||||
signal: callOptions?.signal,
|
||||
});
|
||||
return sanitizeAndNormalizeEmbedding(readGeminiSingleEmbedding(payload));
|
||||
};
|
||||
|
||||
const embedBatchInputs = async (inputs: EmbeddingInput[]): Promise<number[][]> => {
|
||||
const embedBatchInputs = async (
|
||||
inputs: EmbeddingInput[],
|
||||
callOptions?: { signal?: AbortSignal },
|
||||
): Promise<number[][]> => {
|
||||
if (inputs.length === 0) {
|
||||
return [];
|
||||
}
|
||||
@@ -349,16 +358,21 @@ export async function createGeminiEmbeddingProvider(
|
||||
}),
|
||||
),
|
||||
},
|
||||
signal: callOptions?.signal,
|
||||
});
|
||||
const embeddings = readGeminiBatchEmbeddings(payload, inputs.length);
|
||||
return embeddings.map((values) => sanitizeAndNormalizeEmbedding(values));
|
||||
};
|
||||
|
||||
const embedBatch = async (texts: string[]): Promise<number[][]> => {
|
||||
const embedBatch = async (
|
||||
texts: string[],
|
||||
options?: { signal?: AbortSignal },
|
||||
): Promise<number[][]> => {
|
||||
return await embedBatchInputs(
|
||||
texts.map((text) => ({
|
||||
text,
|
||||
})),
|
||||
options,
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -111,6 +111,34 @@ export function resolveMemoryIndexConcurrency(params: {
|
||||
return params.providerId === "ollama" ? 1 : EMBEDDING_INDEX_CONCURRENCY;
|
||||
}
|
||||
|
||||
export async function runEmbeddingOperationWithTimeout<T>(params: {
|
||||
timeoutMs: number;
|
||||
message: string;
|
||||
run: (signal: AbortSignal) => Promise<T>;
|
||||
}): Promise<T> {
|
||||
const controller = new AbortController();
|
||||
if (!Number.isFinite(params.timeoutMs) || params.timeoutMs <= 0) {
|
||||
return await params.run(controller.signal);
|
||||
}
|
||||
let timer: NodeJS.Timeout | null = null;
|
||||
const timeoutPromise = new Promise<never>((_, reject) => {
|
||||
timer = setTimeout(() => {
|
||||
const error = new Error(params.message);
|
||||
reject(error);
|
||||
controller.abort(error);
|
||||
}, params.timeoutMs);
|
||||
timer.unref?.();
|
||||
});
|
||||
try {
|
||||
const operation = params.run(controller.signal);
|
||||
return (await Promise.race([operation, timeoutPromise])) as T;
|
||||
} finally {
|
||||
if (timer) {
|
||||
clearTimeout(timer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export abstract class MemoryManagerEmbeddingOps extends MemoryManagerSyncOps {
|
||||
protected abstract batchFailureCount: number;
|
||||
protected abstract batchFailureLastError?: string;
|
||||
@@ -304,11 +332,11 @@ export abstract class MemoryManagerEmbeddingOps extends MemoryManagerSyncOps {
|
||||
items: texts.length,
|
||||
timeoutMs,
|
||||
});
|
||||
return await this.withTimeout(
|
||||
provider.embedBatch(texts),
|
||||
return await runEmbeddingOperationWithTimeout({
|
||||
timeoutMs,
|
||||
`memory embeddings batch timed out after ${Math.round(timeoutMs / 1000)}s`,
|
||||
);
|
||||
message: `memory embeddings batch timed out after ${Math.round(timeoutMs / 1000)}s`,
|
||||
run: async (signal) => await provider.embedBatch(texts, { signal }),
|
||||
});
|
||||
},
|
||||
isRetryable: isRetryableMemoryEmbeddingError,
|
||||
waitForRetry: async (delayMs) => {
|
||||
@@ -336,11 +364,11 @@ export abstract class MemoryManagerEmbeddingOps extends MemoryManagerSyncOps {
|
||||
items: inputs.length,
|
||||
timeoutMs,
|
||||
});
|
||||
return await this.withTimeout(
|
||||
embedBatchInputs(inputs),
|
||||
return await runEmbeddingOperationWithTimeout({
|
||||
timeoutMs,
|
||||
`memory embeddings batch timed out after ${Math.round(timeoutMs / 1000)}s`,
|
||||
);
|
||||
message: `memory embeddings batch timed out after ${Math.round(timeoutMs / 1000)}s`,
|
||||
run: async (signal) => await embedBatchInputs(inputs, { signal }),
|
||||
});
|
||||
},
|
||||
isRetryable: isRetryableMemoryEmbeddingError,
|
||||
waitForRetry: async (delayMs) => {
|
||||
@@ -371,16 +399,17 @@ export abstract class MemoryManagerEmbeddingOps extends MemoryManagerSyncOps {
|
||||
}
|
||||
|
||||
protected async embedQueryWithTimeout(text: string): Promise<number[]> {
|
||||
if (!this.provider) {
|
||||
const provider = this.provider;
|
||||
if (!provider) {
|
||||
throw new Error("Cannot embed query in FTS-only mode (no embedding provider)");
|
||||
}
|
||||
const timeoutMs = this.resolveEmbeddingTimeout("query");
|
||||
log.debug("memory embeddings: query start", { provider: this.provider.id, timeoutMs });
|
||||
return await this.withTimeout(
|
||||
this.provider.embedQuery(text),
|
||||
log.debug("memory embeddings: query start", { provider: provider.id, timeoutMs });
|
||||
return await runEmbeddingOperationWithTimeout({
|
||||
timeoutMs,
|
||||
`memory embeddings query timed out after ${Math.round(timeoutMs / 1000)}s`,
|
||||
);
|
||||
message: `memory embeddings query timed out after ${Math.round(timeoutMs / 1000)}s`,
|
||||
run: async (signal) => await provider.embedQuery(text, { signal }),
|
||||
});
|
||||
}
|
||||
|
||||
protected async withTimeout<T>(
|
||||
|
||||
@@ -2,6 +2,7 @@ import { describe, expect, it } from "vitest";
|
||||
import {
|
||||
resolveEmbeddingTimeoutMs,
|
||||
resolveMemoryIndexConcurrency,
|
||||
runEmbeddingOperationWithTimeout,
|
||||
} from "./manager-embedding-ops.js";
|
||||
|
||||
describe("memory embedding timeout resolution", () => {
|
||||
@@ -37,6 +38,42 @@ describe("memory embedding timeout resolution", () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe("memory embedding timeout abort", () => {
|
||||
it("aborts the provider operation when the timeout fires", async () => {
|
||||
let signalSeen: AbortSignal | undefined;
|
||||
|
||||
await expect(
|
||||
runEmbeddingOperationWithTimeout({
|
||||
timeoutMs: 1,
|
||||
message: "memory embeddings query timed out after 0s",
|
||||
run: async (signal) => {
|
||||
signalSeen = signal;
|
||||
return await new Promise<number[]>((resolve, reject) => {
|
||||
signal.addEventListener("abort", () => reject(signal.reason), { once: true });
|
||||
});
|
||||
},
|
||||
}),
|
||||
).rejects.toThrow("memory embeddings query timed out after 0s");
|
||||
|
||||
expect(signalSeen?.aborted).toBe(true);
|
||||
});
|
||||
|
||||
it("keeps the timeout error when a provider abort listener rejects generically", async () => {
|
||||
await expect(
|
||||
runEmbeddingOperationWithTimeout({
|
||||
timeoutMs: 1,
|
||||
message: "memory embeddings batch timed out after 0s",
|
||||
run: async (signal) =>
|
||||
await new Promise<number[]>((_resolve, reject) => {
|
||||
signal.addEventListener("abort", () => reject(new Error("provider aborted")), {
|
||||
once: true,
|
||||
});
|
||||
}),
|
||||
}),
|
||||
).rejects.toThrow("memory embeddings batch timed out after 0s");
|
||||
});
|
||||
});
|
||||
|
||||
describe("memory index concurrency resolution", () => {
|
||||
it("uses the default index concurrency when batch mode is disabled and unconfigured", () => {
|
||||
expect(
|
||||
|
||||
@@ -25,8 +25,8 @@ export type OllamaEmbeddingProvider = {
|
||||
id: string;
|
||||
model: string;
|
||||
maxInputTokens?: number;
|
||||
embedQuery: (text: string) => Promise<number[]>;
|
||||
embedBatch: (texts: string[]) => Promise<number[][]>;
|
||||
embedQuery: (text: string, options?: { signal?: AbortSignal }) => Promise<number[]>;
|
||||
embedBatch: (texts: string[], options?: { signal?: AbortSignal }) => Promise<number[][]>;
|
||||
};
|
||||
|
||||
type OllamaEmbeddingOptions = {
|
||||
@@ -90,12 +90,14 @@ function sanitizeAndNormalizeEmbedding(vec: unknown[]): number[] {
|
||||
async function withRemoteHttpResponse<T>(params: {
|
||||
url: string;
|
||||
init?: RequestInit;
|
||||
signal?: AbortSignal;
|
||||
ssrfPolicy?: SsrFPolicy;
|
||||
onResponse: (response: Response) => Promise<T>;
|
||||
}): Promise<T> {
|
||||
const { response, release } = await fetchWithSsrFGuard({
|
||||
url: params.url,
|
||||
init: params.init,
|
||||
signal: params.signal,
|
||||
policy: params.ssrfPolicy,
|
||||
auditContext: "memory-remote",
|
||||
});
|
||||
@@ -322,10 +324,11 @@ export async function createOllamaEmbeddingProvider(
|
||||
const client = resolveOllamaEmbeddingClient(options);
|
||||
const embedUrl = `${client.baseUrl.replace(/\/$/, "")}/api/embed`;
|
||||
|
||||
const embedMany = async (input: string | string[]): Promise<number[][]> => {
|
||||
const embedMany = async (input: string | string[], signal?: AbortSignal): Promise<number[][]> => {
|
||||
const json = await withRemoteHttpResponse({
|
||||
url: embedUrl,
|
||||
ssrfPolicy: client.ssrfPolicy,
|
||||
signal,
|
||||
init: {
|
||||
method: "POST",
|
||||
headers: client.headers,
|
||||
@@ -355,22 +358,23 @@ export async function createOllamaEmbeddingProvider(
|
||||
});
|
||||
};
|
||||
|
||||
const embedOne = async (text: string): Promise<number[]> => {
|
||||
const [embedding] = await embedMany(text);
|
||||
const embedOne = async (text: string, signal?: AbortSignal): Promise<number[]> => {
|
||||
const [embedding] = await embedMany(text, signal);
|
||||
if (!embedding) {
|
||||
throw new Error("Ollama embed response returned no embedding");
|
||||
}
|
||||
return embedding;
|
||||
};
|
||||
|
||||
const embedQuery = async (text: string): Promise<number[]> =>
|
||||
await embedOne(applyQueryInstructionTemplate(client.model, text));
|
||||
const embedQuery = async (text: string, options?: { signal?: AbortSignal }): Promise<number[]> =>
|
||||
await embedOne(applyQueryInstructionTemplate(client.model, text), options?.signal);
|
||||
|
||||
const provider: OllamaEmbeddingProvider = {
|
||||
id: "ollama",
|
||||
model: client.model,
|
||||
embedQuery,
|
||||
embedBatch: async (texts) => (texts.length === 0 ? [] : await embedMany(texts)),
|
||||
embedBatch: async (texts, options) =>
|
||||
texts.length === 0 ? [] : await embedMany(texts, options?.signal),
|
||||
};
|
||||
|
||||
return {
|
||||
|
||||
@@ -35,6 +35,7 @@ function expectFetchRemoteEmbeddingVectorsBody(body: Record<string, unknown>) {
|
||||
headers: { Authorization: "Bearer test" },
|
||||
ssrfPolicy: undefined,
|
||||
fetchImpl: undefined,
|
||||
signal: undefined,
|
||||
body,
|
||||
errorPrefix: "openai embeddings failed",
|
||||
});
|
||||
|
||||
@@ -47,7 +47,11 @@ export async function createOpenAiEmbeddingProvider(
|
||||
return typeof value === "string" && value.trim().length > 0 ? value.trim() : undefined;
|
||||
};
|
||||
|
||||
const embed = async (input: string[], kind: "query" | "document"): Promise<number[][]> => {
|
||||
const embed = async (
|
||||
input: string[],
|
||||
kind: "query" | "document",
|
||||
signal?: AbortSignal,
|
||||
): Promise<number[][]> => {
|
||||
if (input.length === 0) {
|
||||
return [];
|
||||
}
|
||||
@@ -57,6 +61,7 @@ export async function createOpenAiEmbeddingProvider(
|
||||
headers: client.headers,
|
||||
ssrfPolicy: client.ssrfPolicy,
|
||||
fetchImpl: client.fetchImpl,
|
||||
signal,
|
||||
body: {
|
||||
model: client.model,
|
||||
input,
|
||||
@@ -76,11 +81,11 @@ export async function createOpenAiEmbeddingProvider(
|
||||
...(typeof OPENAI_MAX_INPUT_TOKENS[client.model] === "number"
|
||||
? { maxInputTokens: OPENAI_MAX_INPUT_TOKENS[client.model] }
|
||||
: {}),
|
||||
embedQuery: async (text) => {
|
||||
const [vec] = await embed([text], "query");
|
||||
embedQuery: async (text, options) => {
|
||||
const [vec] = await embed([text], "query", options?.signal);
|
||||
return vec ?? [];
|
||||
},
|
||||
embedBatch: async (texts) => await embed(texts, "document"),
|
||||
embedBatch: async (texts, options) => await embed(texts, "document", options?.signal),
|
||||
},
|
||||
client,
|
||||
};
|
||||
|
||||
@@ -36,7 +36,11 @@ export async function createVoyageEmbeddingProvider(
|
||||
const client = await resolveVoyageEmbeddingClient(options);
|
||||
const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`;
|
||||
|
||||
const embed = async (input: string[], input_type?: "query" | "document"): Promise<number[][]> => {
|
||||
const embed = async (
|
||||
input: string[],
|
||||
input_type?: "query" | "document",
|
||||
signal?: AbortSignal,
|
||||
): Promise<number[][]> => {
|
||||
if (input.length === 0) {
|
||||
return [];
|
||||
}
|
||||
@@ -52,6 +56,7 @@ export async function createVoyageEmbeddingProvider(
|
||||
url,
|
||||
headers: client.headers,
|
||||
ssrfPolicy: client.ssrfPolicy,
|
||||
signal,
|
||||
body,
|
||||
errorPrefix: "voyage embeddings failed",
|
||||
});
|
||||
@@ -62,11 +67,11 @@ export async function createVoyageEmbeddingProvider(
|
||||
id: "voyage",
|
||||
model: client.model,
|
||||
maxInputTokens: VOYAGE_MAX_INPUT_TOKENS[client.model],
|
||||
embedQuery: async (text) => {
|
||||
const [vec] = await embed([text], "query");
|
||||
embedQuery: async (text, options) => {
|
||||
const [vec] = await embed([text], "query", options?.signal);
|
||||
return vec ?? [];
|
||||
},
|
||||
embedBatch: async (texts) => embed(texts, "document"),
|
||||
embedBatch: async (texts, options) => embed(texts, "document", options?.signal),
|
||||
},
|
||||
client,
|
||||
};
|
||||
|
||||
@@ -11,6 +11,7 @@ export type {
|
||||
MemoryEmbeddingBatchOptions,
|
||||
MemoryEmbeddingProvider,
|
||||
MemoryEmbeddingProviderAdapter,
|
||||
MemoryEmbeddingProviderCallOptions,
|
||||
MemoryEmbeddingProviderCreateOptions,
|
||||
MemoryEmbeddingProviderCreateResult,
|
||||
MemoryEmbeddingProviderRuntime,
|
||||
|
||||
@@ -10,6 +10,7 @@ vi.mock("./post-json.js", () => ({
|
||||
function requirePostJsonParams(): {
|
||||
url?: unknown;
|
||||
headers?: unknown;
|
||||
signal?: unknown;
|
||||
body?: unknown;
|
||||
errorPrefix?: unknown;
|
||||
} {
|
||||
@@ -51,6 +52,23 @@ describe("fetchRemoteEmbeddingVectors", () => {
|
||||
expect(postJsonParams.errorPrefix).toBe("embedding fetch failed");
|
||||
});
|
||||
|
||||
it("passes abort signals to the JSON request", async () => {
|
||||
const controller = new AbortController();
|
||||
postJsonMock.mockImplementationOnce(async (params) => {
|
||||
return await params.parse({ data: [{ embedding: [0.1] }] });
|
||||
});
|
||||
|
||||
await fetchRemoteEmbeddingVectors({
|
||||
url: "https://memory.example/v1/embeddings",
|
||||
headers: {},
|
||||
signal: controller.signal,
|
||||
body: { input: ["one"] },
|
||||
errorPrefix: "embedding fetch failed",
|
||||
});
|
||||
|
||||
expect(requirePostJsonParams().signal).toBe(controller.signal);
|
||||
});
|
||||
|
||||
it("throws a status-rich error on non-ok responses", async () => {
|
||||
postJsonMock.mockRejectedValueOnce(new Error("embedding fetch failed: 403 forbidden"));
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ export async function fetchRemoteEmbeddingVectors(params: {
|
||||
headers: Record<string, string>;
|
||||
ssrfPolicy?: SsrFPolicy;
|
||||
fetchImpl?: typeof fetch;
|
||||
signal?: AbortSignal;
|
||||
body: unknown;
|
||||
errorPrefix: string;
|
||||
}): Promise<number[][]> {
|
||||
@@ -41,6 +42,7 @@ export async function fetchRemoteEmbeddingVectors(params: {
|
||||
headers: params.headers,
|
||||
ssrfPolicy: params.ssrfPolicy,
|
||||
fetchImpl: params.fetchImpl,
|
||||
signal: params.signal,
|
||||
body: params.body,
|
||||
errorPrefix: params.errorPrefix,
|
||||
parse: (payload) => {
|
||||
|
||||
@@ -23,7 +23,7 @@ export function createRemoteEmbeddingProvider(params: {
|
||||
const { client } = params;
|
||||
const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`;
|
||||
|
||||
const embed = async (input: string[]): Promise<number[][]> => {
|
||||
const embed = async (input: string[], signal?: AbortSignal): Promise<number[][]> => {
|
||||
if (input.length === 0) {
|
||||
return [];
|
||||
}
|
||||
@@ -32,6 +32,7 @@ export function createRemoteEmbeddingProvider(params: {
|
||||
headers: client.headers,
|
||||
ssrfPolicy: client.ssrfPolicy,
|
||||
fetchImpl: client.fetchImpl,
|
||||
signal,
|
||||
body: { model: client.model, input },
|
||||
errorPrefix: params.errorPrefix,
|
||||
});
|
||||
@@ -41,11 +42,11 @@ export function createRemoteEmbeddingProvider(params: {
|
||||
id: params.id,
|
||||
model: client.model,
|
||||
...(typeof params.maxInputTokens === "number" ? { maxInputTokens: params.maxInputTokens } : {}),
|
||||
embedQuery: async (text) => {
|
||||
const [vec] = await embed([text]);
|
||||
embedQuery: async (text, options) => {
|
||||
const [vec] = await embed([text], options?.signal);
|
||||
return vec ?? [];
|
||||
},
|
||||
embedBatch: embed,
|
||||
embedBatch: async (texts, options) => await embed(texts, options?.signal),
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -66,15 +66,20 @@ export async function createLocalEmbeddingProvider(
|
||||
return {
|
||||
id: "local",
|
||||
model: modelPath,
|
||||
embedQuery: async (text) => {
|
||||
embedQuery: async (text, options) => {
|
||||
options?.signal?.throwIfAborted();
|
||||
const ctx = await ensureContext();
|
||||
options?.signal?.throwIfAborted();
|
||||
const embedding = await ctx.getEmbeddingFor(text);
|
||||
return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
|
||||
},
|
||||
embedBatch: async (texts) => {
|
||||
embedBatch: async (texts, options) => {
|
||||
options?.signal?.throwIfAborted();
|
||||
const ctx = await ensureContext();
|
||||
options?.signal?.throwIfAborted();
|
||||
const embeddings = await Promise.all(
|
||||
texts.map(async (text) => {
|
||||
options?.signal?.throwIfAborted();
|
||||
const embedding = await ctx.getEmbeddingFor(text);
|
||||
return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
|
||||
}),
|
||||
|
||||
@@ -5,9 +5,16 @@ export type EmbeddingProvider = {
|
||||
id: string;
|
||||
model: string;
|
||||
maxInputTokens?: number;
|
||||
embedQuery: (text: string) => Promise<number[]>;
|
||||
embedBatch: (texts: string[]) => Promise<number[][]>;
|
||||
embedBatchInputs?: (inputs: EmbeddingInput[]) => Promise<number[][]>;
|
||||
embedQuery: (text: string, options?: EmbeddingProviderCallOptions) => Promise<number[]>;
|
||||
embedBatch: (texts: string[], options?: EmbeddingProviderCallOptions) => Promise<number[][]>;
|
||||
embedBatchInputs?: (
|
||||
inputs: EmbeddingInput[],
|
||||
options?: EmbeddingProviderCallOptions,
|
||||
) => Promise<number[][]>;
|
||||
};
|
||||
|
||||
export type EmbeddingProviderCallOptions = {
|
||||
signal?: AbortSignal;
|
||||
};
|
||||
|
||||
export type EmbeddingProviderId = string;
|
||||
|
||||
@@ -15,6 +15,7 @@ export type {
|
||||
MemoryEmbeddingBatchOptions,
|
||||
MemoryEmbeddingProvider,
|
||||
MemoryEmbeddingProviderAdapter,
|
||||
MemoryEmbeddingProviderCallOptions,
|
||||
MemoryEmbeddingProviderCreateOptions,
|
||||
MemoryEmbeddingProviderCreateResult,
|
||||
MemoryEmbeddingProviderRuntime,
|
||||
|
||||
@@ -112,6 +112,7 @@ export type {
|
||||
MemoryEmbeddingBatchOptions,
|
||||
MemoryEmbeddingProvider,
|
||||
MemoryEmbeddingProviderAdapter,
|
||||
MemoryEmbeddingProviderCallOptions,
|
||||
MemoryEmbeddingProviderCreateOptions,
|
||||
MemoryEmbeddingProviderCreateResult,
|
||||
MemoryEmbeddingProviderRuntime,
|
||||
|
||||
@@ -47,6 +47,23 @@ describe("postJson", () => {
|
||||
expect(result).toEqual({ data: [{ embedding: [1, 2] }] });
|
||||
});
|
||||
|
||||
it("forwards abort signals to the remote HTTP request", async () => {
|
||||
const controller = new AbortController();
|
||||
remoteHttpMock.mockImplementationOnce(async (params) => {
|
||||
expect(params.signal).toBe(controller.signal);
|
||||
return await params.onResponse(jsonResponse({ ok: true }));
|
||||
});
|
||||
|
||||
await postJson({
|
||||
url: "https://memory.example/v1/post",
|
||||
headers: {},
|
||||
body: {},
|
||||
signal: controller.signal,
|
||||
errorPrefix: "post failed",
|
||||
parse: (payload) => payload,
|
||||
});
|
||||
});
|
||||
|
||||
it("attaches status to thrown error when requested", async () => {
|
||||
remoteHttpMock.mockImplementationOnce(async (params) => {
|
||||
return await params.onResponse(textResponse("bad gateway", 502));
|
||||
|
||||
@@ -6,6 +6,7 @@ export async function postJson<T>(params: {
|
||||
headers: Record<string, string>;
|
||||
ssrfPolicy?: SsrFPolicy;
|
||||
fetchImpl?: typeof fetch;
|
||||
signal?: AbortSignal;
|
||||
body: unknown;
|
||||
errorPrefix: string;
|
||||
attachStatus?: boolean;
|
||||
@@ -15,6 +16,7 @@ export async function postJson<T>(params: {
|
||||
url: params.url,
|
||||
ssrfPolicy: params.ssrfPolicy,
|
||||
fetchImpl: params.fetchImpl,
|
||||
signal: params.signal,
|
||||
init: {
|
||||
method: "POST",
|
||||
headers: params.headers,
|
||||
|
||||
@@ -43,4 +43,18 @@ describe("package withRemoteHttpResponse", () => {
|
||||
expect(deps.calls).toHaveLength(1);
|
||||
expect(deps.calls[0]).not.toHaveProperty("mode");
|
||||
});
|
||||
|
||||
it("passes abort signals to the guarded fetch", async () => {
|
||||
const deps = makeFetchDeps();
|
||||
const controller = new AbortController();
|
||||
|
||||
await withRemoteHttpResponse({
|
||||
url: "https://memory.example/v1/embeddings",
|
||||
signal: controller.signal,
|
||||
onResponse: async () => undefined,
|
||||
...deps,
|
||||
});
|
||||
|
||||
expect(deps.calls[0]).toHaveProperty("signal", controller.signal);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -13,6 +13,7 @@ export const buildRemoteBaseUrlPolicy: (baseUrl: string) => SsrFPolicy | undefin
|
||||
export async function withRemoteHttpResponse<T>(params: {
|
||||
url: string;
|
||||
init?: RequestInit;
|
||||
signal?: AbortSignal;
|
||||
ssrfPolicy?: SsrFPolicy;
|
||||
fetchImpl?: typeof fetch;
|
||||
fetchWithSsrFGuardImpl?: typeof fetchWithSsrFGuard;
|
||||
@@ -26,6 +27,7 @@ export async function withRemoteHttpResponse<T>(params: {
|
||||
url: params.url,
|
||||
fetchImpl: params.fetchImpl,
|
||||
init: params.init,
|
||||
signal: params.signal,
|
||||
policy: params.ssrfPolicy,
|
||||
auditContext: params.auditContext ?? "memory-remote",
|
||||
...(shouldUseEnvProxy(params.url) ? { mode: MEMORY_REMOTE_TRUSTED_ENV_PROXY_MODE } : {}),
|
||||
|
||||
@@ -51,5 +51,9 @@ export function runNodeMain(params?: {
|
||||
cwd?: string;
|
||||
args?: string[];
|
||||
env?: NodeJS.ProcessEnv;
|
||||
runRuntimePostBuild?: (params?: {
|
||||
cwd?: string;
|
||||
env?: Record<string, string | undefined>;
|
||||
}) => void | Promise<void>;
|
||||
platform?: NodeJS.Platform;
|
||||
}): Promise<number>;
|
||||
|
||||
@@ -79,6 +79,22 @@ function isGuidedConfigAction(actionCommand: Command): boolean {
|
||||
return actionCommand.name() === "config" && !actionCommand.parent?.parent;
|
||||
}
|
||||
|
||||
function isGuidedConfigCommandPath(commandPath: string[]): boolean {
|
||||
const [primary, secondary, extra] = commandPath;
|
||||
if (primary !== "config" || extra !== undefined) {
|
||||
return false;
|
||||
}
|
||||
return (
|
||||
secondary !== "get" &&
|
||||
secondary !== "set" &&
|
||||
secondary !== "patch" &&
|
||||
secondary !== "unset" &&
|
||||
secondary !== "file" &&
|
||||
secondary !== "schema" &&
|
||||
secondary !== "validate"
|
||||
);
|
||||
}
|
||||
|
||||
export function registerPreActionHooks(program: Command, programVersion: string) {
|
||||
program.hook("preAction", async (_thisCommand, actionCommand) => {
|
||||
setProcessTitleForCommand(actionCommand);
|
||||
@@ -105,7 +121,11 @@ export function registerPreActionHooks(program: Command, programVersion: string)
|
||||
if (!verbose) {
|
||||
process.env.NODE_NO_WARNINGS ??= "1";
|
||||
}
|
||||
if (shouldBypassConfigGuardForCommandPath(commandPath) || isGuidedConfigAction(actionCommand)) {
|
||||
if (
|
||||
shouldBypassConfigGuardForCommandPath(commandPath) ||
|
||||
isGuidedConfigAction(actionCommand) ||
|
||||
isGuidedConfigCommandPath(commandPath)
|
||||
) {
|
||||
return;
|
||||
}
|
||||
await ensureCliExecutionBootstrap({
|
||||
|
||||
@@ -430,10 +430,19 @@ describe("gateway auth browser hardening", () => {
|
||||
});
|
||||
|
||||
test("rejects forged loopback origin for control-ui when proxy headers make client non-local", async () => {
|
||||
const { writeConfigFile } = await import("../config/config.js");
|
||||
await writeConfigFile({
|
||||
gateway: {
|
||||
trustedProxies: ["127.0.0.1"],
|
||||
controlUi: {
|
||||
allowedOrigins: [],
|
||||
},
|
||||
},
|
||||
});
|
||||
testState.gatewayAuth = { mode: "token", token: "secret" };
|
||||
await withGatewayServer(async ({ port }) => {
|
||||
const ws = await openWs(port, {
|
||||
origin: originForPort(port),
|
||||
origin: "http://localhost:5173",
|
||||
"x-forwarded-for": "203.0.113.50",
|
||||
});
|
||||
try {
|
||||
@@ -444,6 +453,7 @@ describe("gateway auth browser hardening", () => {
|
||||
id: GATEWAY_CLIENT_NAMES.CONTROL_UI,
|
||||
mode: GATEWAY_CLIENT_MODES.UI,
|
||||
},
|
||||
device: null,
|
||||
});
|
||||
expect(res.ok).toBe(false);
|
||||
expect(res.error?.message ?? "").toContain("origin not allowed");
|
||||
|
||||
@@ -1415,9 +1415,6 @@ describe("run-node script", () => {
|
||||
return true;
|
||||
}),
|
||||
} as unknown as NodeJS.WriteStream;
|
||||
const stdout = {
|
||||
write: vi.fn(() => true),
|
||||
} as unknown as NodeJS.WriteStream;
|
||||
|
||||
const exitCode = await runNodeMain({
|
||||
cwd: tmp,
|
||||
@@ -1430,11 +1427,10 @@ describe("run-node script", () => {
|
||||
spawn,
|
||||
spawnSync,
|
||||
stderr,
|
||||
stdout,
|
||||
runRuntimePostBuild: async () => {},
|
||||
execPath: process.execPath,
|
||||
platform: process.platform,
|
||||
} as Parameters<typeof runNodeMain>[0] & { stdout: NodeJS.WriteStream });
|
||||
} as Parameters<typeof runNodeMain>[0]);
|
||||
|
||||
expect(exitCode).toBe(0);
|
||||
const stderrText = stderrChunks.join("");
|
||||
|
||||
@@ -4,7 +4,6 @@ import type { OpenClawConfig } from "../config/types.js";
|
||||
import { normalizeSecretInputString, resolveSecretInputRef } from "../config/types.secrets.js";
|
||||
import { materializeGatewayAuthSecretRefs } from "../gateway/auth-config-utils.js";
|
||||
import { assertExplicitGatewayAuthModeWhenBothConfigured } from "../gateway/auth-mode-policy.js";
|
||||
import { isLoopbackHost } from "../gateway/net.js";
|
||||
import { issueDeviceBootstrapToken } from "../infra/device-bootstrap.js";
|
||||
import {
|
||||
pickMatchingExternalInterfaceAddress,
|
||||
@@ -16,6 +15,7 @@ import {
|
||||
isCarrierGradeNatIpv4Address,
|
||||
isIpv4Address,
|
||||
isIpv6Address,
|
||||
isLoopbackIpAddress,
|
||||
isRfc1918Ipv4Address,
|
||||
parseCanonicalIpAddress,
|
||||
} from "../shared/net/ip.js";
|
||||
@@ -124,7 +124,12 @@ function isPrivateLanHost(host: string): boolean {
|
||||
|
||||
function isMobilePairingCleartextAllowedHost(host: string): boolean {
|
||||
const normalized = normalizeMobilePairingHost(host);
|
||||
return isLoopbackHost(normalized) || normalized === "10.0.2.2" || isPrivateLanHost(normalized);
|
||||
return (
|
||||
normalized === "localhost" ||
|
||||
isLoopbackIpAddress(normalized) ||
|
||||
normalized === "10.0.2.2" ||
|
||||
isPrivateLanHost(normalized)
|
||||
);
|
||||
}
|
||||
|
||||
function validateMobilePairingUrl(url: string, source?: string): string | null {
|
||||
|
||||
@@ -66,6 +66,7 @@ export type {
|
||||
MemoryEmbeddingBatchOptions,
|
||||
MemoryEmbeddingProvider,
|
||||
MemoryEmbeddingProviderAdapter,
|
||||
MemoryEmbeddingProviderCallOptions,
|
||||
MemoryEmbeddingProviderCreateOptions,
|
||||
MemoryEmbeddingProviderCreateResult,
|
||||
MemoryEmbeddingProviderRuntime,
|
||||
|
||||
@@ -17,6 +17,10 @@ export type MemoryEmbeddingBatchOptions = {
|
||||
debug: (message: string, data?: Record<string, unknown>) => void;
|
||||
};
|
||||
|
||||
export type MemoryEmbeddingProviderCallOptions = {
|
||||
signal?: AbortSignal;
|
||||
};
|
||||
|
||||
export type MemoryEmbeddingProviderRuntime = {
|
||||
id: string;
|
||||
cacheKeyData?: Record<string, unknown>;
|
||||
@@ -29,9 +33,15 @@ export type MemoryEmbeddingProvider = {
|
||||
id: string;
|
||||
model: string;
|
||||
maxInputTokens?: number;
|
||||
embedQuery: (text: string) => Promise<number[]>;
|
||||
embedBatch: (texts: string[]) => Promise<number[][]>;
|
||||
embedBatchInputs?: (inputs: EmbeddingInput[]) => Promise<number[][]>;
|
||||
embedQuery: (text: string, options?: MemoryEmbeddingProviderCallOptions) => Promise<number[]>;
|
||||
embedBatch: (
|
||||
texts: string[],
|
||||
options?: MemoryEmbeddingProviderCallOptions,
|
||||
) => Promise<number[][]>;
|
||||
embedBatchInputs?: (
|
||||
inputs: EmbeddingInput[],
|
||||
options?: MemoryEmbeddingProviderCallOptions,
|
||||
) => Promise<number[][]>;
|
||||
};
|
||||
|
||||
export type MemoryEmbeddingProviderCreateOptions = {
|
||||
|
||||
@@ -964,9 +964,9 @@ export type PluginEmbeddingProvider = {
|
||||
id: string;
|
||||
model: string;
|
||||
maxInputTokens?: number;
|
||||
embedQuery: (text: string) => Promise<number[]>;
|
||||
embedBatch: (texts: string[]) => Promise<number[][]>;
|
||||
embedBatchInputs?: (inputs: unknown[]) => Promise<number[][]>;
|
||||
embedQuery: (text: string, options?: { signal?: AbortSignal }) => Promise<number[]>;
|
||||
embedBatch: (texts: string[], options?: { signal?: AbortSignal }) => Promise<number[][]>;
|
||||
embedBatchInputs?: (inputs: unknown[], options?: { signal?: AbortSignal }) => Promise<number[][]>;
|
||||
client?: unknown;
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user