mirror of
https://github.com/openclaw/openclaw.git
synced 2026-06-06 05:51:15 +08:00
fix(memory): abort timed-out embedding requests (#82770)
* fix(memory): abort timed-out embedding requests * test: stabilize gateway ci shards * test: pin control ui origin fixture * test: stabilize gateway ci fixtures * test: isolate forged origin fixture * test: decouple setup code from gateway net mocks * test: repair run-node and config preaction CI * test: fix run-node progress fixture typing * test: remove unused pairing setup helper * fix: stabilize embedding timeout errors
This commit is contained in:
committed by
GitHub
parent
b77b3a7ade
commit
a6225060f1
@@ -34,6 +34,7 @@ Docs: https://docs.openclaw.ai
|
|||||||
- Telegram: let catch-all mention patterns match captionless group photos, so media-only group messages reach the agent when the group is intentionally configured to respond to all messages. Fixes #44833. (#82756) Thanks @IWhatsskill.
|
- Telegram: let catch-all mention patterns match captionless group photos, so media-only group messages reach the agent when the group is intentionally configured to respond to all messages. Fixes #44833. (#82756) Thanks @IWhatsskill.
|
||||||
- Gateway/pairing: reject forged loopback Control UI origins from non-local proxy paths, and keep mobile pairing setup on Tailscale bind mode pointing users to Tailscale Serve/Funnel instead of cleartext tailnet WebSockets.
|
- Gateway/pairing: reject forged loopback Control UI origins from non-local proxy paths, and keep mobile pairing setup on Tailscale bind mode pointing users to Tailscale Serve/Funnel instead of cleartext tailnet WebSockets.
|
||||||
- Telegram/Gateway: persist isolated polling offsets only after main-thread dispatch and preserve gateway caller scopes for Telegram message actions, fixing consumed-but-unrouted polling updates and recursive CLI send scope approvals. Fixes #82277. (#82705) Thanks @udaymanish6.
|
- Telegram/Gateway: persist isolated polling offsets only after main-thread dispatch and preserve gateway caller scopes for Telegram message actions, fixing consumed-but-unrouted polling updates and recursive CLI send scope approvals. Fixes #82277. (#82705) Thanks @udaymanish6.
|
||||||
|
- Memory-core: abort timed-out embedding provider calls so remote embedding HTTP requests do not continue running after memory query or indexing timeouts. Fixes #82732. Thanks @adityarya24.
|
||||||
- Channels/stream previews: contain rejected background draft-stream flushes so preview send failures do not surface as fatal unhandled rejections. Fixes #82712. (#82713) Thanks @coygeek.
|
- Channels/stream previews: contain rejected background draft-stream flushes so preview send failures do not surface as fatal unhandled rejections. Fixes #82712. (#82713) Thanks @coygeek.
|
||||||
- Codex/app-server: keep shared native app-server clients isolated per agent runtime key so starting one agent no longer closes another agent's active Codex turn. Fixes #82758. Thanks @PashaGanson.
|
- Codex/app-server: keep shared native app-server clients isolated per agent runtime key so starting one agent no longer closes another agent's active Codex turn. Fixes #82758. Thanks @PashaGanson.
|
||||||
- Providers/OpenAI Codex: include base `gpt-5.5` and `gpt-5.4` reasoning metadata in the bundled Codex catalog so `/think xhigh` remains available for those models. Fixes #82744.
|
- Providers/OpenAI Codex: include base `gpt-5.5` and `gpt-5.4` reasoning metadata in the bundled Codex catalog so `/think xhigh` remains available for those models. Fixes #82744.
|
||||||
|
|||||||
@@ -331,7 +331,7 @@ export async function createBedrockEmbeddingProvider(
|
|||||||
family,
|
family,
|
||||||
});
|
});
|
||||||
|
|
||||||
const invoke = async (body: string): Promise<string> => {
|
const invoke = async (body: string, signal?: AbortSignal): Promise<string> => {
|
||||||
await refreshAwsSharedConfigCacheForBedrock();
|
await refreshAwsSharedConfigCacheForBedrock();
|
||||||
const sdk = new BedrockRuntimeClient({ region: client.region });
|
const sdk = new BedrockRuntimeClient({ region: client.region });
|
||||||
try {
|
try {
|
||||||
@@ -342,6 +342,7 @@ export async function createBedrockEmbeddingProvider(
|
|||||||
contentType: "application/json",
|
contentType: "application/json",
|
||||||
accept: "application/json",
|
accept: "application/json",
|
||||||
}),
|
}),
|
||||||
|
signal ? { abortSignal: signal } : undefined,
|
||||||
);
|
);
|
||||||
return new TextDecoder().decode(res.body);
|
return new TextDecoder().decode(res.body);
|
||||||
} finally {
|
} finally {
|
||||||
@@ -351,37 +352,46 @@ export async function createBedrockEmbeddingProvider(
|
|||||||
|
|
||||||
const isCohere = family === "cohere-v3" || family === "cohere-v4";
|
const isCohere = family === "cohere-v3" || family === "cohere-v4";
|
||||||
|
|
||||||
const embedSingle = async (text: string): Promise<number[]> => {
|
const embedSingle = async (text: string, signal?: AbortSignal): Promise<number[]> => {
|
||||||
const raw = await invoke(buildBody(family, text, client.dimensions));
|
const raw = await invoke(buildBody(family, text, client.dimensions), signal);
|
||||||
return sanitizeAndNormalizeEmbedding(parseSingle(family, raw));
|
return sanitizeAndNormalizeEmbedding(parseSingle(family, raw));
|
||||||
};
|
};
|
||||||
|
|
||||||
const embedCohere = async (
|
const embedCohere = async (
|
||||||
texts: string[],
|
texts: string[],
|
||||||
inputType: "search_query" | "search_document",
|
inputType: "search_query" | "search_document",
|
||||||
|
signal?: AbortSignal,
|
||||||
): Promise<number[][]> => {
|
): Promise<number[][]> => {
|
||||||
const raw = await invoke(buildCohereBody(family, texts, inputType, client.dimensions));
|
const raw = await invoke(buildCohereBody(family, texts, inputType, client.dimensions), signal);
|
||||||
return parseCohereBatch(family, raw).map((e) => sanitizeAndNormalizeEmbedding(e));
|
return parseCohereBatch(family, raw).map((e) => sanitizeAndNormalizeEmbedding(e));
|
||||||
};
|
};
|
||||||
|
|
||||||
const embedQuery = async (text: string): Promise<number[]> => {
|
const embedQuery = async (
|
||||||
|
text: string,
|
||||||
|
options?: { signal?: AbortSignal },
|
||||||
|
): Promise<number[]> => {
|
||||||
if (!text.trim()) {
|
if (!text.trim()) {
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
if (isCohere) {
|
if (isCohere) {
|
||||||
return (await embedCohere([text], "search_query"))[0] ?? [];
|
return (await embedCohere([text], "search_query", options?.signal))[0] ?? [];
|
||||||
}
|
}
|
||||||
return embedSingle(text);
|
return embedSingle(text, options?.signal);
|
||||||
};
|
};
|
||||||
|
|
||||||
const embedBatch = async (texts: string[]): Promise<number[][]> => {
|
const embedBatch = async (
|
||||||
|
texts: string[],
|
||||||
|
options?: { signal?: AbortSignal },
|
||||||
|
): Promise<number[][]> => {
|
||||||
if (texts.length === 0) {
|
if (texts.length === 0) {
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
if (isCohere) {
|
if (isCohere) {
|
||||||
return embedCohere(texts, "search_document");
|
return embedCohere(texts, "search_document", options?.signal);
|
||||||
}
|
}
|
||||||
return Promise.all(texts.map((t) => (t.trim() ? embedSingle(t) : Promise.resolve([]))));
|
return Promise.all(
|
||||||
|
texts.map((t) => (t.trim() ? embedSingle(t, options?.signal) : Promise.resolve([]))),
|
||||||
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -221,7 +221,7 @@ async function createGitHubCopilotEmbeddingProvider(
|
|||||||
): Promise<{ provider: MemoryEmbeddingProvider; client: GitHubCopilotEmbeddingClient }> {
|
): Promise<{ provider: MemoryEmbeddingProvider; client: GitHubCopilotEmbeddingClient }> {
|
||||||
const initialSession = await resolveGitHubCopilotEmbeddingSession(client);
|
const initialSession = await resolveGitHubCopilotEmbeddingSession(client);
|
||||||
|
|
||||||
const embed = async (input: string[]): Promise<number[][]> => {
|
const embed = async (input: string[], signal?: AbortSignal): Promise<number[][]> => {
|
||||||
if (input.length === 0) {
|
if (input.length === 0) {
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
@@ -232,6 +232,7 @@ async function createGitHubCopilotEmbeddingProvider(
|
|||||||
url,
|
url,
|
||||||
fetchImpl: client.fetchImpl,
|
fetchImpl: client.fetchImpl,
|
||||||
ssrfPolicy: buildRemoteBaseUrlPolicy(session.baseUrl),
|
ssrfPolicy: buildRemoteBaseUrlPolicy(session.baseUrl),
|
||||||
|
signal,
|
||||||
init: {
|
init: {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers: session.headers,
|
headers: session.headers,
|
||||||
@@ -259,11 +260,11 @@ async function createGitHubCopilotEmbeddingProvider(
|
|||||||
provider: {
|
provider: {
|
||||||
id: COPILOT_EMBEDDING_PROVIDER_ID,
|
id: COPILOT_EMBEDDING_PROVIDER_ID,
|
||||||
model: client.model,
|
model: client.model,
|
||||||
embedQuery: async (text) => {
|
embedQuery: async (text, options) => {
|
||||||
const [vector] = await embed([text]);
|
const [vector] = await embed([text], options?.signal);
|
||||||
return vector ?? [];
|
return vector ?? [];
|
||||||
},
|
},
|
||||||
embedBatch: embed,
|
embedBatch: async (texts, options) => await embed(texts, options?.signal),
|
||||||
},
|
},
|
||||||
client: {
|
client: {
|
||||||
...client,
|
...client,
|
||||||
|
|||||||
@@ -242,6 +242,7 @@ async function fetchGeminiEmbeddingPayload(params: {
|
|||||||
client: GeminiEmbeddingClient;
|
client: GeminiEmbeddingClient;
|
||||||
endpoint: string;
|
endpoint: string;
|
||||||
body: unknown;
|
body: unknown;
|
||||||
|
signal?: AbortSignal;
|
||||||
}): Promise<Record<string, unknown>> {
|
}): Promise<Record<string, unknown>> {
|
||||||
return await executeWithApiKeyRotation({
|
return await executeWithApiKeyRotation({
|
||||||
provider: "google",
|
provider: "google",
|
||||||
@@ -256,6 +257,7 @@ async function fetchGeminiEmbeddingPayload(params: {
|
|||||||
return await withRemoteHttpResponse({
|
return await withRemoteHttpResponse({
|
||||||
url: params.endpoint,
|
url: params.endpoint,
|
||||||
ssrfPolicy: params.client.ssrfPolicy,
|
ssrfPolicy: params.client.ssrfPolicy,
|
||||||
|
signal: params.signal,
|
||||||
init: {
|
init: {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers,
|
headers,
|
||||||
@@ -316,7 +318,10 @@ export async function createGeminiEmbeddingProvider(
|
|||||||
const isV2 = isGeminiEmbedding2Model(client.model);
|
const isV2 = isGeminiEmbedding2Model(client.model);
|
||||||
const outputDimensionality = client.outputDimensionality;
|
const outputDimensionality = client.outputDimensionality;
|
||||||
|
|
||||||
const embedQuery = async (text: string): Promise<number[]> => {
|
const embedQuery = async (
|
||||||
|
text: string,
|
||||||
|
callOptions?: { signal?: AbortSignal },
|
||||||
|
): Promise<number[]> => {
|
||||||
if (!text.trim()) {
|
if (!text.trim()) {
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
@@ -328,11 +333,15 @@ export async function createGeminiEmbeddingProvider(
|
|||||||
taskType: options.taskType ?? "RETRIEVAL_QUERY",
|
taskType: options.taskType ?? "RETRIEVAL_QUERY",
|
||||||
outputDimensionality: isV2 ? outputDimensionality : undefined,
|
outputDimensionality: isV2 ? outputDimensionality : undefined,
|
||||||
}),
|
}),
|
||||||
|
signal: callOptions?.signal,
|
||||||
});
|
});
|
||||||
return sanitizeAndNormalizeEmbedding(readGeminiSingleEmbedding(payload));
|
return sanitizeAndNormalizeEmbedding(readGeminiSingleEmbedding(payload));
|
||||||
};
|
};
|
||||||
|
|
||||||
const embedBatchInputs = async (inputs: EmbeddingInput[]): Promise<number[][]> => {
|
const embedBatchInputs = async (
|
||||||
|
inputs: EmbeddingInput[],
|
||||||
|
callOptions?: { signal?: AbortSignal },
|
||||||
|
): Promise<number[][]> => {
|
||||||
if (inputs.length === 0) {
|
if (inputs.length === 0) {
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
@@ -349,16 +358,21 @@ export async function createGeminiEmbeddingProvider(
|
|||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
|
signal: callOptions?.signal,
|
||||||
});
|
});
|
||||||
const embeddings = readGeminiBatchEmbeddings(payload, inputs.length);
|
const embeddings = readGeminiBatchEmbeddings(payload, inputs.length);
|
||||||
return embeddings.map((values) => sanitizeAndNormalizeEmbedding(values));
|
return embeddings.map((values) => sanitizeAndNormalizeEmbedding(values));
|
||||||
};
|
};
|
||||||
|
|
||||||
const embedBatch = async (texts: string[]): Promise<number[][]> => {
|
const embedBatch = async (
|
||||||
|
texts: string[],
|
||||||
|
options?: { signal?: AbortSignal },
|
||||||
|
): Promise<number[][]> => {
|
||||||
return await embedBatchInputs(
|
return await embedBatchInputs(
|
||||||
texts.map((text) => ({
|
texts.map((text) => ({
|
||||||
text,
|
text,
|
||||||
})),
|
})),
|
||||||
|
options,
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -111,6 +111,34 @@ export function resolveMemoryIndexConcurrency(params: {
|
|||||||
return params.providerId === "ollama" ? 1 : EMBEDDING_INDEX_CONCURRENCY;
|
return params.providerId === "ollama" ? 1 : EMBEDDING_INDEX_CONCURRENCY;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export async function runEmbeddingOperationWithTimeout<T>(params: {
|
||||||
|
timeoutMs: number;
|
||||||
|
message: string;
|
||||||
|
run: (signal: AbortSignal) => Promise<T>;
|
||||||
|
}): Promise<T> {
|
||||||
|
const controller = new AbortController();
|
||||||
|
if (!Number.isFinite(params.timeoutMs) || params.timeoutMs <= 0) {
|
||||||
|
return await params.run(controller.signal);
|
||||||
|
}
|
||||||
|
let timer: NodeJS.Timeout | null = null;
|
||||||
|
const timeoutPromise = new Promise<never>((_, reject) => {
|
||||||
|
timer = setTimeout(() => {
|
||||||
|
const error = new Error(params.message);
|
||||||
|
reject(error);
|
||||||
|
controller.abort(error);
|
||||||
|
}, params.timeoutMs);
|
||||||
|
timer.unref?.();
|
||||||
|
});
|
||||||
|
try {
|
||||||
|
const operation = params.run(controller.signal);
|
||||||
|
return (await Promise.race([operation, timeoutPromise])) as T;
|
||||||
|
} finally {
|
||||||
|
if (timer) {
|
||||||
|
clearTimeout(timer);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export abstract class MemoryManagerEmbeddingOps extends MemoryManagerSyncOps {
|
export abstract class MemoryManagerEmbeddingOps extends MemoryManagerSyncOps {
|
||||||
protected abstract batchFailureCount: number;
|
protected abstract batchFailureCount: number;
|
||||||
protected abstract batchFailureLastError?: string;
|
protected abstract batchFailureLastError?: string;
|
||||||
@@ -304,11 +332,11 @@ export abstract class MemoryManagerEmbeddingOps extends MemoryManagerSyncOps {
|
|||||||
items: texts.length,
|
items: texts.length,
|
||||||
timeoutMs,
|
timeoutMs,
|
||||||
});
|
});
|
||||||
return await this.withTimeout(
|
return await runEmbeddingOperationWithTimeout({
|
||||||
provider.embedBatch(texts),
|
|
||||||
timeoutMs,
|
timeoutMs,
|
||||||
`memory embeddings batch timed out after ${Math.round(timeoutMs / 1000)}s`,
|
message: `memory embeddings batch timed out after ${Math.round(timeoutMs / 1000)}s`,
|
||||||
);
|
run: async (signal) => await provider.embedBatch(texts, { signal }),
|
||||||
|
});
|
||||||
},
|
},
|
||||||
isRetryable: isRetryableMemoryEmbeddingError,
|
isRetryable: isRetryableMemoryEmbeddingError,
|
||||||
waitForRetry: async (delayMs) => {
|
waitForRetry: async (delayMs) => {
|
||||||
@@ -336,11 +364,11 @@ export abstract class MemoryManagerEmbeddingOps extends MemoryManagerSyncOps {
|
|||||||
items: inputs.length,
|
items: inputs.length,
|
||||||
timeoutMs,
|
timeoutMs,
|
||||||
});
|
});
|
||||||
return await this.withTimeout(
|
return await runEmbeddingOperationWithTimeout({
|
||||||
embedBatchInputs(inputs),
|
|
||||||
timeoutMs,
|
timeoutMs,
|
||||||
`memory embeddings batch timed out after ${Math.round(timeoutMs / 1000)}s`,
|
message: `memory embeddings batch timed out after ${Math.round(timeoutMs / 1000)}s`,
|
||||||
);
|
run: async (signal) => await embedBatchInputs(inputs, { signal }),
|
||||||
|
});
|
||||||
},
|
},
|
||||||
isRetryable: isRetryableMemoryEmbeddingError,
|
isRetryable: isRetryableMemoryEmbeddingError,
|
||||||
waitForRetry: async (delayMs) => {
|
waitForRetry: async (delayMs) => {
|
||||||
@@ -371,16 +399,17 @@ export abstract class MemoryManagerEmbeddingOps extends MemoryManagerSyncOps {
|
|||||||
}
|
}
|
||||||
|
|
||||||
protected async embedQueryWithTimeout(text: string): Promise<number[]> {
|
protected async embedQueryWithTimeout(text: string): Promise<number[]> {
|
||||||
if (!this.provider) {
|
const provider = this.provider;
|
||||||
|
if (!provider) {
|
||||||
throw new Error("Cannot embed query in FTS-only mode (no embedding provider)");
|
throw new Error("Cannot embed query in FTS-only mode (no embedding provider)");
|
||||||
}
|
}
|
||||||
const timeoutMs = this.resolveEmbeddingTimeout("query");
|
const timeoutMs = this.resolveEmbeddingTimeout("query");
|
||||||
log.debug("memory embeddings: query start", { provider: this.provider.id, timeoutMs });
|
log.debug("memory embeddings: query start", { provider: provider.id, timeoutMs });
|
||||||
return await this.withTimeout(
|
return await runEmbeddingOperationWithTimeout({
|
||||||
this.provider.embedQuery(text),
|
|
||||||
timeoutMs,
|
timeoutMs,
|
||||||
`memory embeddings query timed out after ${Math.round(timeoutMs / 1000)}s`,
|
message: `memory embeddings query timed out after ${Math.round(timeoutMs / 1000)}s`,
|
||||||
);
|
run: async (signal) => await provider.embedQuery(text, { signal }),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
protected async withTimeout<T>(
|
protected async withTimeout<T>(
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import { describe, expect, it } from "vitest";
|
|||||||
import {
|
import {
|
||||||
resolveEmbeddingTimeoutMs,
|
resolveEmbeddingTimeoutMs,
|
||||||
resolveMemoryIndexConcurrency,
|
resolveMemoryIndexConcurrency,
|
||||||
|
runEmbeddingOperationWithTimeout,
|
||||||
} from "./manager-embedding-ops.js";
|
} from "./manager-embedding-ops.js";
|
||||||
|
|
||||||
describe("memory embedding timeout resolution", () => {
|
describe("memory embedding timeout resolution", () => {
|
||||||
@@ -37,6 +38,42 @@ describe("memory embedding timeout resolution", () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe("memory embedding timeout abort", () => {
|
||||||
|
it("aborts the provider operation when the timeout fires", async () => {
|
||||||
|
let signalSeen: AbortSignal | undefined;
|
||||||
|
|
||||||
|
await expect(
|
||||||
|
runEmbeddingOperationWithTimeout({
|
||||||
|
timeoutMs: 1,
|
||||||
|
message: "memory embeddings query timed out after 0s",
|
||||||
|
run: async (signal) => {
|
||||||
|
signalSeen = signal;
|
||||||
|
return await new Promise<number[]>((resolve, reject) => {
|
||||||
|
signal.addEventListener("abort", () => reject(signal.reason), { once: true });
|
||||||
|
});
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
).rejects.toThrow("memory embeddings query timed out after 0s");
|
||||||
|
|
||||||
|
expect(signalSeen?.aborted).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("keeps the timeout error when a provider abort listener rejects generically", async () => {
|
||||||
|
await expect(
|
||||||
|
runEmbeddingOperationWithTimeout({
|
||||||
|
timeoutMs: 1,
|
||||||
|
message: "memory embeddings batch timed out after 0s",
|
||||||
|
run: async (signal) =>
|
||||||
|
await new Promise<number[]>((_resolve, reject) => {
|
||||||
|
signal.addEventListener("abort", () => reject(new Error("provider aborted")), {
|
||||||
|
once: true,
|
||||||
|
});
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
).rejects.toThrow("memory embeddings batch timed out after 0s");
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe("memory index concurrency resolution", () => {
|
describe("memory index concurrency resolution", () => {
|
||||||
it("uses the default index concurrency when batch mode is disabled and unconfigured", () => {
|
it("uses the default index concurrency when batch mode is disabled and unconfigured", () => {
|
||||||
expect(
|
expect(
|
||||||
|
|||||||
@@ -25,8 +25,8 @@ export type OllamaEmbeddingProvider = {
|
|||||||
id: string;
|
id: string;
|
||||||
model: string;
|
model: string;
|
||||||
maxInputTokens?: number;
|
maxInputTokens?: number;
|
||||||
embedQuery: (text: string) => Promise<number[]>;
|
embedQuery: (text: string, options?: { signal?: AbortSignal }) => Promise<number[]>;
|
||||||
embedBatch: (texts: string[]) => Promise<number[][]>;
|
embedBatch: (texts: string[], options?: { signal?: AbortSignal }) => Promise<number[][]>;
|
||||||
};
|
};
|
||||||
|
|
||||||
type OllamaEmbeddingOptions = {
|
type OllamaEmbeddingOptions = {
|
||||||
@@ -90,12 +90,14 @@ function sanitizeAndNormalizeEmbedding(vec: unknown[]): number[] {
|
|||||||
async function withRemoteHttpResponse<T>(params: {
|
async function withRemoteHttpResponse<T>(params: {
|
||||||
url: string;
|
url: string;
|
||||||
init?: RequestInit;
|
init?: RequestInit;
|
||||||
|
signal?: AbortSignal;
|
||||||
ssrfPolicy?: SsrFPolicy;
|
ssrfPolicy?: SsrFPolicy;
|
||||||
onResponse: (response: Response) => Promise<T>;
|
onResponse: (response: Response) => Promise<T>;
|
||||||
}): Promise<T> {
|
}): Promise<T> {
|
||||||
const { response, release } = await fetchWithSsrFGuard({
|
const { response, release } = await fetchWithSsrFGuard({
|
||||||
url: params.url,
|
url: params.url,
|
||||||
init: params.init,
|
init: params.init,
|
||||||
|
signal: params.signal,
|
||||||
policy: params.ssrfPolicy,
|
policy: params.ssrfPolicy,
|
||||||
auditContext: "memory-remote",
|
auditContext: "memory-remote",
|
||||||
});
|
});
|
||||||
@@ -322,10 +324,11 @@ export async function createOllamaEmbeddingProvider(
|
|||||||
const client = resolveOllamaEmbeddingClient(options);
|
const client = resolveOllamaEmbeddingClient(options);
|
||||||
const embedUrl = `${client.baseUrl.replace(/\/$/, "")}/api/embed`;
|
const embedUrl = `${client.baseUrl.replace(/\/$/, "")}/api/embed`;
|
||||||
|
|
||||||
const embedMany = async (input: string | string[]): Promise<number[][]> => {
|
const embedMany = async (input: string | string[], signal?: AbortSignal): Promise<number[][]> => {
|
||||||
const json = await withRemoteHttpResponse({
|
const json = await withRemoteHttpResponse({
|
||||||
url: embedUrl,
|
url: embedUrl,
|
||||||
ssrfPolicy: client.ssrfPolicy,
|
ssrfPolicy: client.ssrfPolicy,
|
||||||
|
signal,
|
||||||
init: {
|
init: {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers: client.headers,
|
headers: client.headers,
|
||||||
@@ -355,22 +358,23 @@ export async function createOllamaEmbeddingProvider(
|
|||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
const embedOne = async (text: string): Promise<number[]> => {
|
const embedOne = async (text: string, signal?: AbortSignal): Promise<number[]> => {
|
||||||
const [embedding] = await embedMany(text);
|
const [embedding] = await embedMany(text, signal);
|
||||||
if (!embedding) {
|
if (!embedding) {
|
||||||
throw new Error("Ollama embed response returned no embedding");
|
throw new Error("Ollama embed response returned no embedding");
|
||||||
}
|
}
|
||||||
return embedding;
|
return embedding;
|
||||||
};
|
};
|
||||||
|
|
||||||
const embedQuery = async (text: string): Promise<number[]> =>
|
const embedQuery = async (text: string, options?: { signal?: AbortSignal }): Promise<number[]> =>
|
||||||
await embedOne(applyQueryInstructionTemplate(client.model, text));
|
await embedOne(applyQueryInstructionTemplate(client.model, text), options?.signal);
|
||||||
|
|
||||||
const provider: OllamaEmbeddingProvider = {
|
const provider: OllamaEmbeddingProvider = {
|
||||||
id: "ollama",
|
id: "ollama",
|
||||||
model: client.model,
|
model: client.model,
|
||||||
embedQuery,
|
embedQuery,
|
||||||
embedBatch: async (texts) => (texts.length === 0 ? [] : await embedMany(texts)),
|
embedBatch: async (texts, options) =>
|
||||||
|
texts.length === 0 ? [] : await embedMany(texts, options?.signal),
|
||||||
};
|
};
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ function expectFetchRemoteEmbeddingVectorsBody(body: Record<string, unknown>) {
|
|||||||
headers: { Authorization: "Bearer test" },
|
headers: { Authorization: "Bearer test" },
|
||||||
ssrfPolicy: undefined,
|
ssrfPolicy: undefined,
|
||||||
fetchImpl: undefined,
|
fetchImpl: undefined,
|
||||||
|
signal: undefined,
|
||||||
body,
|
body,
|
||||||
errorPrefix: "openai embeddings failed",
|
errorPrefix: "openai embeddings failed",
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -47,7 +47,11 @@ export async function createOpenAiEmbeddingProvider(
|
|||||||
return typeof value === "string" && value.trim().length > 0 ? value.trim() : undefined;
|
return typeof value === "string" && value.trim().length > 0 ? value.trim() : undefined;
|
||||||
};
|
};
|
||||||
|
|
||||||
const embed = async (input: string[], kind: "query" | "document"): Promise<number[][]> => {
|
const embed = async (
|
||||||
|
input: string[],
|
||||||
|
kind: "query" | "document",
|
||||||
|
signal?: AbortSignal,
|
||||||
|
): Promise<number[][]> => {
|
||||||
if (input.length === 0) {
|
if (input.length === 0) {
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
@@ -57,6 +61,7 @@ export async function createOpenAiEmbeddingProvider(
|
|||||||
headers: client.headers,
|
headers: client.headers,
|
||||||
ssrfPolicy: client.ssrfPolicy,
|
ssrfPolicy: client.ssrfPolicy,
|
||||||
fetchImpl: client.fetchImpl,
|
fetchImpl: client.fetchImpl,
|
||||||
|
signal,
|
||||||
body: {
|
body: {
|
||||||
model: client.model,
|
model: client.model,
|
||||||
input,
|
input,
|
||||||
@@ -76,11 +81,11 @@ export async function createOpenAiEmbeddingProvider(
|
|||||||
...(typeof OPENAI_MAX_INPUT_TOKENS[client.model] === "number"
|
...(typeof OPENAI_MAX_INPUT_TOKENS[client.model] === "number"
|
||||||
? { maxInputTokens: OPENAI_MAX_INPUT_TOKENS[client.model] }
|
? { maxInputTokens: OPENAI_MAX_INPUT_TOKENS[client.model] }
|
||||||
: {}),
|
: {}),
|
||||||
embedQuery: async (text) => {
|
embedQuery: async (text, options) => {
|
||||||
const [vec] = await embed([text], "query");
|
const [vec] = await embed([text], "query", options?.signal);
|
||||||
return vec ?? [];
|
return vec ?? [];
|
||||||
},
|
},
|
||||||
embedBatch: async (texts) => await embed(texts, "document"),
|
embedBatch: async (texts, options) => await embed(texts, "document", options?.signal),
|
||||||
},
|
},
|
||||||
client,
|
client,
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -36,7 +36,11 @@ export async function createVoyageEmbeddingProvider(
|
|||||||
const client = await resolveVoyageEmbeddingClient(options);
|
const client = await resolveVoyageEmbeddingClient(options);
|
||||||
const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`;
|
const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`;
|
||||||
|
|
||||||
const embed = async (input: string[], input_type?: "query" | "document"): Promise<number[][]> => {
|
const embed = async (
|
||||||
|
input: string[],
|
||||||
|
input_type?: "query" | "document",
|
||||||
|
signal?: AbortSignal,
|
||||||
|
): Promise<number[][]> => {
|
||||||
if (input.length === 0) {
|
if (input.length === 0) {
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
@@ -52,6 +56,7 @@ export async function createVoyageEmbeddingProvider(
|
|||||||
url,
|
url,
|
||||||
headers: client.headers,
|
headers: client.headers,
|
||||||
ssrfPolicy: client.ssrfPolicy,
|
ssrfPolicy: client.ssrfPolicy,
|
||||||
|
signal,
|
||||||
body,
|
body,
|
||||||
errorPrefix: "voyage embeddings failed",
|
errorPrefix: "voyage embeddings failed",
|
||||||
});
|
});
|
||||||
@@ -62,11 +67,11 @@ export async function createVoyageEmbeddingProvider(
|
|||||||
id: "voyage",
|
id: "voyage",
|
||||||
model: client.model,
|
model: client.model,
|
||||||
maxInputTokens: VOYAGE_MAX_INPUT_TOKENS[client.model],
|
maxInputTokens: VOYAGE_MAX_INPUT_TOKENS[client.model],
|
||||||
embedQuery: async (text) => {
|
embedQuery: async (text, options) => {
|
||||||
const [vec] = await embed([text], "query");
|
const [vec] = await embed([text], "query", options?.signal);
|
||||||
return vec ?? [];
|
return vec ?? [];
|
||||||
},
|
},
|
||||||
embedBatch: async (texts) => embed(texts, "document"),
|
embedBatch: async (texts, options) => embed(texts, "document", options?.signal),
|
||||||
},
|
},
|
||||||
client,
|
client,
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ export type {
|
|||||||
MemoryEmbeddingBatchOptions,
|
MemoryEmbeddingBatchOptions,
|
||||||
MemoryEmbeddingProvider,
|
MemoryEmbeddingProvider,
|
||||||
MemoryEmbeddingProviderAdapter,
|
MemoryEmbeddingProviderAdapter,
|
||||||
|
MemoryEmbeddingProviderCallOptions,
|
||||||
MemoryEmbeddingProviderCreateOptions,
|
MemoryEmbeddingProviderCreateOptions,
|
||||||
MemoryEmbeddingProviderCreateResult,
|
MemoryEmbeddingProviderCreateResult,
|
||||||
MemoryEmbeddingProviderRuntime,
|
MemoryEmbeddingProviderRuntime,
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ vi.mock("./post-json.js", () => ({
|
|||||||
function requirePostJsonParams(): {
|
function requirePostJsonParams(): {
|
||||||
url?: unknown;
|
url?: unknown;
|
||||||
headers?: unknown;
|
headers?: unknown;
|
||||||
|
signal?: unknown;
|
||||||
body?: unknown;
|
body?: unknown;
|
||||||
errorPrefix?: unknown;
|
errorPrefix?: unknown;
|
||||||
} {
|
} {
|
||||||
@@ -51,6 +52,23 @@ describe("fetchRemoteEmbeddingVectors", () => {
|
|||||||
expect(postJsonParams.errorPrefix).toBe("embedding fetch failed");
|
expect(postJsonParams.errorPrefix).toBe("embedding fetch failed");
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("passes abort signals to the JSON request", async () => {
|
||||||
|
const controller = new AbortController();
|
||||||
|
postJsonMock.mockImplementationOnce(async (params) => {
|
||||||
|
return await params.parse({ data: [{ embedding: [0.1] }] });
|
||||||
|
});
|
||||||
|
|
||||||
|
await fetchRemoteEmbeddingVectors({
|
||||||
|
url: "https://memory.example/v1/embeddings",
|
||||||
|
headers: {},
|
||||||
|
signal: controller.signal,
|
||||||
|
body: { input: ["one"] },
|
||||||
|
errorPrefix: "embedding fetch failed",
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(requirePostJsonParams().signal).toBe(controller.signal);
|
||||||
|
});
|
||||||
|
|
||||||
it("throws a status-rich error on non-ok responses", async () => {
|
it("throws a status-rich error on non-ok responses", async () => {
|
||||||
postJsonMock.mockRejectedValueOnce(new Error("embedding fetch failed: 403 forbidden"));
|
postJsonMock.mockRejectedValueOnce(new Error("embedding fetch failed: 403 forbidden"));
|
||||||
|
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ export async function fetchRemoteEmbeddingVectors(params: {
|
|||||||
headers: Record<string, string>;
|
headers: Record<string, string>;
|
||||||
ssrfPolicy?: SsrFPolicy;
|
ssrfPolicy?: SsrFPolicy;
|
||||||
fetchImpl?: typeof fetch;
|
fetchImpl?: typeof fetch;
|
||||||
|
signal?: AbortSignal;
|
||||||
body: unknown;
|
body: unknown;
|
||||||
errorPrefix: string;
|
errorPrefix: string;
|
||||||
}): Promise<number[][]> {
|
}): Promise<number[][]> {
|
||||||
@@ -41,6 +42,7 @@ export async function fetchRemoteEmbeddingVectors(params: {
|
|||||||
headers: params.headers,
|
headers: params.headers,
|
||||||
ssrfPolicy: params.ssrfPolicy,
|
ssrfPolicy: params.ssrfPolicy,
|
||||||
fetchImpl: params.fetchImpl,
|
fetchImpl: params.fetchImpl,
|
||||||
|
signal: params.signal,
|
||||||
body: params.body,
|
body: params.body,
|
||||||
errorPrefix: params.errorPrefix,
|
errorPrefix: params.errorPrefix,
|
||||||
parse: (payload) => {
|
parse: (payload) => {
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ export function createRemoteEmbeddingProvider(params: {
|
|||||||
const { client } = params;
|
const { client } = params;
|
||||||
const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`;
|
const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`;
|
||||||
|
|
||||||
const embed = async (input: string[]): Promise<number[][]> => {
|
const embed = async (input: string[], signal?: AbortSignal): Promise<number[][]> => {
|
||||||
if (input.length === 0) {
|
if (input.length === 0) {
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
@@ -32,6 +32,7 @@ export function createRemoteEmbeddingProvider(params: {
|
|||||||
headers: client.headers,
|
headers: client.headers,
|
||||||
ssrfPolicy: client.ssrfPolicy,
|
ssrfPolicy: client.ssrfPolicy,
|
||||||
fetchImpl: client.fetchImpl,
|
fetchImpl: client.fetchImpl,
|
||||||
|
signal,
|
||||||
body: { model: client.model, input },
|
body: { model: client.model, input },
|
||||||
errorPrefix: params.errorPrefix,
|
errorPrefix: params.errorPrefix,
|
||||||
});
|
});
|
||||||
@@ -41,11 +42,11 @@ export function createRemoteEmbeddingProvider(params: {
|
|||||||
id: params.id,
|
id: params.id,
|
||||||
model: client.model,
|
model: client.model,
|
||||||
...(typeof params.maxInputTokens === "number" ? { maxInputTokens: params.maxInputTokens } : {}),
|
...(typeof params.maxInputTokens === "number" ? { maxInputTokens: params.maxInputTokens } : {}),
|
||||||
embedQuery: async (text) => {
|
embedQuery: async (text, options) => {
|
||||||
const [vec] = await embed([text]);
|
const [vec] = await embed([text], options?.signal);
|
||||||
return vec ?? [];
|
return vec ?? [];
|
||||||
},
|
},
|
||||||
embedBatch: embed,
|
embedBatch: async (texts, options) => await embed(texts, options?.signal),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -66,15 +66,20 @@ export async function createLocalEmbeddingProvider(
|
|||||||
return {
|
return {
|
||||||
id: "local",
|
id: "local",
|
||||||
model: modelPath,
|
model: modelPath,
|
||||||
embedQuery: async (text) => {
|
embedQuery: async (text, options) => {
|
||||||
|
options?.signal?.throwIfAborted();
|
||||||
const ctx = await ensureContext();
|
const ctx = await ensureContext();
|
||||||
|
options?.signal?.throwIfAborted();
|
||||||
const embedding = await ctx.getEmbeddingFor(text);
|
const embedding = await ctx.getEmbeddingFor(text);
|
||||||
return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
|
return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
|
||||||
},
|
},
|
||||||
embedBatch: async (texts) => {
|
embedBatch: async (texts, options) => {
|
||||||
|
options?.signal?.throwIfAborted();
|
||||||
const ctx = await ensureContext();
|
const ctx = await ensureContext();
|
||||||
|
options?.signal?.throwIfAborted();
|
||||||
const embeddings = await Promise.all(
|
const embeddings = await Promise.all(
|
||||||
texts.map(async (text) => {
|
texts.map(async (text) => {
|
||||||
|
options?.signal?.throwIfAborted();
|
||||||
const embedding = await ctx.getEmbeddingFor(text);
|
const embedding = await ctx.getEmbeddingFor(text);
|
||||||
return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
|
return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
|
||||||
}),
|
}),
|
||||||
|
|||||||
@@ -5,9 +5,16 @@ export type EmbeddingProvider = {
|
|||||||
id: string;
|
id: string;
|
||||||
model: string;
|
model: string;
|
||||||
maxInputTokens?: number;
|
maxInputTokens?: number;
|
||||||
embedQuery: (text: string) => Promise<number[]>;
|
embedQuery: (text: string, options?: EmbeddingProviderCallOptions) => Promise<number[]>;
|
||||||
embedBatch: (texts: string[]) => Promise<number[][]>;
|
embedBatch: (texts: string[], options?: EmbeddingProviderCallOptions) => Promise<number[][]>;
|
||||||
embedBatchInputs?: (inputs: EmbeddingInput[]) => Promise<number[][]>;
|
embedBatchInputs?: (
|
||||||
|
inputs: EmbeddingInput[],
|
||||||
|
options?: EmbeddingProviderCallOptions,
|
||||||
|
) => Promise<number[][]>;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type EmbeddingProviderCallOptions = {
|
||||||
|
signal?: AbortSignal;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type EmbeddingProviderId = string;
|
export type EmbeddingProviderId = string;
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ export type {
|
|||||||
MemoryEmbeddingBatchOptions,
|
MemoryEmbeddingBatchOptions,
|
||||||
MemoryEmbeddingProvider,
|
MemoryEmbeddingProvider,
|
||||||
MemoryEmbeddingProviderAdapter,
|
MemoryEmbeddingProviderAdapter,
|
||||||
|
MemoryEmbeddingProviderCallOptions,
|
||||||
MemoryEmbeddingProviderCreateOptions,
|
MemoryEmbeddingProviderCreateOptions,
|
||||||
MemoryEmbeddingProviderCreateResult,
|
MemoryEmbeddingProviderCreateResult,
|
||||||
MemoryEmbeddingProviderRuntime,
|
MemoryEmbeddingProviderRuntime,
|
||||||
|
|||||||
@@ -112,6 +112,7 @@ export type {
|
|||||||
MemoryEmbeddingBatchOptions,
|
MemoryEmbeddingBatchOptions,
|
||||||
MemoryEmbeddingProvider,
|
MemoryEmbeddingProvider,
|
||||||
MemoryEmbeddingProviderAdapter,
|
MemoryEmbeddingProviderAdapter,
|
||||||
|
MemoryEmbeddingProviderCallOptions,
|
||||||
MemoryEmbeddingProviderCreateOptions,
|
MemoryEmbeddingProviderCreateOptions,
|
||||||
MemoryEmbeddingProviderCreateResult,
|
MemoryEmbeddingProviderCreateResult,
|
||||||
MemoryEmbeddingProviderRuntime,
|
MemoryEmbeddingProviderRuntime,
|
||||||
|
|||||||
@@ -47,6 +47,23 @@ describe("postJson", () => {
|
|||||||
expect(result).toEqual({ data: [{ embedding: [1, 2] }] });
|
expect(result).toEqual({ data: [{ embedding: [1, 2] }] });
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("forwards abort signals to the remote HTTP request", async () => {
|
||||||
|
const controller = new AbortController();
|
||||||
|
remoteHttpMock.mockImplementationOnce(async (params) => {
|
||||||
|
expect(params.signal).toBe(controller.signal);
|
||||||
|
return await params.onResponse(jsonResponse({ ok: true }));
|
||||||
|
});
|
||||||
|
|
||||||
|
await postJson({
|
||||||
|
url: "https://memory.example/v1/post",
|
||||||
|
headers: {},
|
||||||
|
body: {},
|
||||||
|
signal: controller.signal,
|
||||||
|
errorPrefix: "post failed",
|
||||||
|
parse: (payload) => payload,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
it("attaches status to thrown error when requested", async () => {
|
it("attaches status to thrown error when requested", async () => {
|
||||||
remoteHttpMock.mockImplementationOnce(async (params) => {
|
remoteHttpMock.mockImplementationOnce(async (params) => {
|
||||||
return await params.onResponse(textResponse("bad gateway", 502));
|
return await params.onResponse(textResponse("bad gateway", 502));
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ export async function postJson<T>(params: {
|
|||||||
headers: Record<string, string>;
|
headers: Record<string, string>;
|
||||||
ssrfPolicy?: SsrFPolicy;
|
ssrfPolicy?: SsrFPolicy;
|
||||||
fetchImpl?: typeof fetch;
|
fetchImpl?: typeof fetch;
|
||||||
|
signal?: AbortSignal;
|
||||||
body: unknown;
|
body: unknown;
|
||||||
errorPrefix: string;
|
errorPrefix: string;
|
||||||
attachStatus?: boolean;
|
attachStatus?: boolean;
|
||||||
@@ -15,6 +16,7 @@ export async function postJson<T>(params: {
|
|||||||
url: params.url,
|
url: params.url,
|
||||||
ssrfPolicy: params.ssrfPolicy,
|
ssrfPolicy: params.ssrfPolicy,
|
||||||
fetchImpl: params.fetchImpl,
|
fetchImpl: params.fetchImpl,
|
||||||
|
signal: params.signal,
|
||||||
init: {
|
init: {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers: params.headers,
|
headers: params.headers,
|
||||||
|
|||||||
@@ -43,4 +43,18 @@ describe("package withRemoteHttpResponse", () => {
|
|||||||
expect(deps.calls).toHaveLength(1);
|
expect(deps.calls).toHaveLength(1);
|
||||||
expect(deps.calls[0]).not.toHaveProperty("mode");
|
expect(deps.calls[0]).not.toHaveProperty("mode");
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("passes abort signals to the guarded fetch", async () => {
|
||||||
|
const deps = makeFetchDeps();
|
||||||
|
const controller = new AbortController();
|
||||||
|
|
||||||
|
await withRemoteHttpResponse({
|
||||||
|
url: "https://memory.example/v1/embeddings",
|
||||||
|
signal: controller.signal,
|
||||||
|
onResponse: async () => undefined,
|
||||||
|
...deps,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(deps.calls[0]).toHaveProperty("signal", controller.signal);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ export const buildRemoteBaseUrlPolicy: (baseUrl: string) => SsrFPolicy | undefin
|
|||||||
export async function withRemoteHttpResponse<T>(params: {
|
export async function withRemoteHttpResponse<T>(params: {
|
||||||
url: string;
|
url: string;
|
||||||
init?: RequestInit;
|
init?: RequestInit;
|
||||||
|
signal?: AbortSignal;
|
||||||
ssrfPolicy?: SsrFPolicy;
|
ssrfPolicy?: SsrFPolicy;
|
||||||
fetchImpl?: typeof fetch;
|
fetchImpl?: typeof fetch;
|
||||||
fetchWithSsrFGuardImpl?: typeof fetchWithSsrFGuard;
|
fetchWithSsrFGuardImpl?: typeof fetchWithSsrFGuard;
|
||||||
@@ -26,6 +27,7 @@ export async function withRemoteHttpResponse<T>(params: {
|
|||||||
url: params.url,
|
url: params.url,
|
||||||
fetchImpl: params.fetchImpl,
|
fetchImpl: params.fetchImpl,
|
||||||
init: params.init,
|
init: params.init,
|
||||||
|
signal: params.signal,
|
||||||
policy: params.ssrfPolicy,
|
policy: params.ssrfPolicy,
|
||||||
auditContext: params.auditContext ?? "memory-remote",
|
auditContext: params.auditContext ?? "memory-remote",
|
||||||
...(shouldUseEnvProxy(params.url) ? { mode: MEMORY_REMOTE_TRUSTED_ENV_PROXY_MODE } : {}),
|
...(shouldUseEnvProxy(params.url) ? { mode: MEMORY_REMOTE_TRUSTED_ENV_PROXY_MODE } : {}),
|
||||||
|
|||||||
@@ -51,5 +51,9 @@ export function runNodeMain(params?: {
|
|||||||
cwd?: string;
|
cwd?: string;
|
||||||
args?: string[];
|
args?: string[];
|
||||||
env?: NodeJS.ProcessEnv;
|
env?: NodeJS.ProcessEnv;
|
||||||
|
runRuntimePostBuild?: (params?: {
|
||||||
|
cwd?: string;
|
||||||
|
env?: Record<string, string | undefined>;
|
||||||
|
}) => void | Promise<void>;
|
||||||
platform?: NodeJS.Platform;
|
platform?: NodeJS.Platform;
|
||||||
}): Promise<number>;
|
}): Promise<number>;
|
||||||
|
|||||||
@@ -79,6 +79,22 @@ function isGuidedConfigAction(actionCommand: Command): boolean {
|
|||||||
return actionCommand.name() === "config" && !actionCommand.parent?.parent;
|
return actionCommand.name() === "config" && !actionCommand.parent?.parent;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function isGuidedConfigCommandPath(commandPath: string[]): boolean {
|
||||||
|
const [primary, secondary, extra] = commandPath;
|
||||||
|
if (primary !== "config" || extra !== undefined) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
secondary !== "get" &&
|
||||||
|
secondary !== "set" &&
|
||||||
|
secondary !== "patch" &&
|
||||||
|
secondary !== "unset" &&
|
||||||
|
secondary !== "file" &&
|
||||||
|
secondary !== "schema" &&
|
||||||
|
secondary !== "validate"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
export function registerPreActionHooks(program: Command, programVersion: string) {
|
export function registerPreActionHooks(program: Command, programVersion: string) {
|
||||||
program.hook("preAction", async (_thisCommand, actionCommand) => {
|
program.hook("preAction", async (_thisCommand, actionCommand) => {
|
||||||
setProcessTitleForCommand(actionCommand);
|
setProcessTitleForCommand(actionCommand);
|
||||||
@@ -105,7 +121,11 @@ export function registerPreActionHooks(program: Command, programVersion: string)
|
|||||||
if (!verbose) {
|
if (!verbose) {
|
||||||
process.env.NODE_NO_WARNINGS ??= "1";
|
process.env.NODE_NO_WARNINGS ??= "1";
|
||||||
}
|
}
|
||||||
if (shouldBypassConfigGuardForCommandPath(commandPath) || isGuidedConfigAction(actionCommand)) {
|
if (
|
||||||
|
shouldBypassConfigGuardForCommandPath(commandPath) ||
|
||||||
|
isGuidedConfigAction(actionCommand) ||
|
||||||
|
isGuidedConfigCommandPath(commandPath)
|
||||||
|
) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
await ensureCliExecutionBootstrap({
|
await ensureCliExecutionBootstrap({
|
||||||
|
|||||||
@@ -430,10 +430,19 @@ describe("gateway auth browser hardening", () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
test("rejects forged loopback origin for control-ui when proxy headers make client non-local", async () => {
|
test("rejects forged loopback origin for control-ui when proxy headers make client non-local", async () => {
|
||||||
|
const { writeConfigFile } = await import("../config/config.js");
|
||||||
|
await writeConfigFile({
|
||||||
|
gateway: {
|
||||||
|
trustedProxies: ["127.0.0.1"],
|
||||||
|
controlUi: {
|
||||||
|
allowedOrigins: [],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
testState.gatewayAuth = { mode: "token", token: "secret" };
|
testState.gatewayAuth = { mode: "token", token: "secret" };
|
||||||
await withGatewayServer(async ({ port }) => {
|
await withGatewayServer(async ({ port }) => {
|
||||||
const ws = await openWs(port, {
|
const ws = await openWs(port, {
|
||||||
origin: originForPort(port),
|
origin: "http://localhost:5173",
|
||||||
"x-forwarded-for": "203.0.113.50",
|
"x-forwarded-for": "203.0.113.50",
|
||||||
});
|
});
|
||||||
try {
|
try {
|
||||||
@@ -444,6 +453,7 @@ describe("gateway auth browser hardening", () => {
|
|||||||
id: GATEWAY_CLIENT_NAMES.CONTROL_UI,
|
id: GATEWAY_CLIENT_NAMES.CONTROL_UI,
|
||||||
mode: GATEWAY_CLIENT_MODES.UI,
|
mode: GATEWAY_CLIENT_MODES.UI,
|
||||||
},
|
},
|
||||||
|
device: null,
|
||||||
});
|
});
|
||||||
expect(res.ok).toBe(false);
|
expect(res.ok).toBe(false);
|
||||||
expect(res.error?.message ?? "").toContain("origin not allowed");
|
expect(res.error?.message ?? "").toContain("origin not allowed");
|
||||||
|
|||||||
@@ -1415,9 +1415,6 @@ describe("run-node script", () => {
|
|||||||
return true;
|
return true;
|
||||||
}),
|
}),
|
||||||
} as unknown as NodeJS.WriteStream;
|
} as unknown as NodeJS.WriteStream;
|
||||||
const stdout = {
|
|
||||||
write: vi.fn(() => true),
|
|
||||||
} as unknown as NodeJS.WriteStream;
|
|
||||||
|
|
||||||
const exitCode = await runNodeMain({
|
const exitCode = await runNodeMain({
|
||||||
cwd: tmp,
|
cwd: tmp,
|
||||||
@@ -1430,11 +1427,10 @@ describe("run-node script", () => {
|
|||||||
spawn,
|
spawn,
|
||||||
spawnSync,
|
spawnSync,
|
||||||
stderr,
|
stderr,
|
||||||
stdout,
|
|
||||||
runRuntimePostBuild: async () => {},
|
runRuntimePostBuild: async () => {},
|
||||||
execPath: process.execPath,
|
execPath: process.execPath,
|
||||||
platform: process.platform,
|
platform: process.platform,
|
||||||
} as Parameters<typeof runNodeMain>[0] & { stdout: NodeJS.WriteStream });
|
} as Parameters<typeof runNodeMain>[0]);
|
||||||
|
|
||||||
expect(exitCode).toBe(0);
|
expect(exitCode).toBe(0);
|
||||||
const stderrText = stderrChunks.join("");
|
const stderrText = stderrChunks.join("");
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import type { OpenClawConfig } from "../config/types.js";
|
|||||||
import { normalizeSecretInputString, resolveSecretInputRef } from "../config/types.secrets.js";
|
import { normalizeSecretInputString, resolveSecretInputRef } from "../config/types.secrets.js";
|
||||||
import { materializeGatewayAuthSecretRefs } from "../gateway/auth-config-utils.js";
|
import { materializeGatewayAuthSecretRefs } from "../gateway/auth-config-utils.js";
|
||||||
import { assertExplicitGatewayAuthModeWhenBothConfigured } from "../gateway/auth-mode-policy.js";
|
import { assertExplicitGatewayAuthModeWhenBothConfigured } from "../gateway/auth-mode-policy.js";
|
||||||
import { isLoopbackHost } from "../gateway/net.js";
|
|
||||||
import { issueDeviceBootstrapToken } from "../infra/device-bootstrap.js";
|
import { issueDeviceBootstrapToken } from "../infra/device-bootstrap.js";
|
||||||
import {
|
import {
|
||||||
pickMatchingExternalInterfaceAddress,
|
pickMatchingExternalInterfaceAddress,
|
||||||
@@ -16,6 +15,7 @@ import {
|
|||||||
isCarrierGradeNatIpv4Address,
|
isCarrierGradeNatIpv4Address,
|
||||||
isIpv4Address,
|
isIpv4Address,
|
||||||
isIpv6Address,
|
isIpv6Address,
|
||||||
|
isLoopbackIpAddress,
|
||||||
isRfc1918Ipv4Address,
|
isRfc1918Ipv4Address,
|
||||||
parseCanonicalIpAddress,
|
parseCanonicalIpAddress,
|
||||||
} from "../shared/net/ip.js";
|
} from "../shared/net/ip.js";
|
||||||
@@ -124,7 +124,12 @@ function isPrivateLanHost(host: string): boolean {
|
|||||||
|
|
||||||
function isMobilePairingCleartextAllowedHost(host: string): boolean {
|
function isMobilePairingCleartextAllowedHost(host: string): boolean {
|
||||||
const normalized = normalizeMobilePairingHost(host);
|
const normalized = normalizeMobilePairingHost(host);
|
||||||
return isLoopbackHost(normalized) || normalized === "10.0.2.2" || isPrivateLanHost(normalized);
|
return (
|
||||||
|
normalized === "localhost" ||
|
||||||
|
isLoopbackIpAddress(normalized) ||
|
||||||
|
normalized === "10.0.2.2" ||
|
||||||
|
isPrivateLanHost(normalized)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
function validateMobilePairingUrl(url: string, source?: string): string | null {
|
function validateMobilePairingUrl(url: string, source?: string): string | null {
|
||||||
|
|||||||
@@ -66,6 +66,7 @@ export type {
|
|||||||
MemoryEmbeddingBatchOptions,
|
MemoryEmbeddingBatchOptions,
|
||||||
MemoryEmbeddingProvider,
|
MemoryEmbeddingProvider,
|
||||||
MemoryEmbeddingProviderAdapter,
|
MemoryEmbeddingProviderAdapter,
|
||||||
|
MemoryEmbeddingProviderCallOptions,
|
||||||
MemoryEmbeddingProviderCreateOptions,
|
MemoryEmbeddingProviderCreateOptions,
|
||||||
MemoryEmbeddingProviderCreateResult,
|
MemoryEmbeddingProviderCreateResult,
|
||||||
MemoryEmbeddingProviderRuntime,
|
MemoryEmbeddingProviderRuntime,
|
||||||
|
|||||||
@@ -17,6 +17,10 @@ export type MemoryEmbeddingBatchOptions = {
|
|||||||
debug: (message: string, data?: Record<string, unknown>) => void;
|
debug: (message: string, data?: Record<string, unknown>) => void;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type MemoryEmbeddingProviderCallOptions = {
|
||||||
|
signal?: AbortSignal;
|
||||||
|
};
|
||||||
|
|
||||||
export type MemoryEmbeddingProviderRuntime = {
|
export type MemoryEmbeddingProviderRuntime = {
|
||||||
id: string;
|
id: string;
|
||||||
cacheKeyData?: Record<string, unknown>;
|
cacheKeyData?: Record<string, unknown>;
|
||||||
@@ -29,9 +33,15 @@ export type MemoryEmbeddingProvider = {
|
|||||||
id: string;
|
id: string;
|
||||||
model: string;
|
model: string;
|
||||||
maxInputTokens?: number;
|
maxInputTokens?: number;
|
||||||
embedQuery: (text: string) => Promise<number[]>;
|
embedQuery: (text: string, options?: MemoryEmbeddingProviderCallOptions) => Promise<number[]>;
|
||||||
embedBatch: (texts: string[]) => Promise<number[][]>;
|
embedBatch: (
|
||||||
embedBatchInputs?: (inputs: EmbeddingInput[]) => Promise<number[][]>;
|
texts: string[],
|
||||||
|
options?: MemoryEmbeddingProviderCallOptions,
|
||||||
|
) => Promise<number[][]>;
|
||||||
|
embedBatchInputs?: (
|
||||||
|
inputs: EmbeddingInput[],
|
||||||
|
options?: MemoryEmbeddingProviderCallOptions,
|
||||||
|
) => Promise<number[][]>;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type MemoryEmbeddingProviderCreateOptions = {
|
export type MemoryEmbeddingProviderCreateOptions = {
|
||||||
|
|||||||
@@ -964,9 +964,9 @@ export type PluginEmbeddingProvider = {
|
|||||||
id: string;
|
id: string;
|
||||||
model: string;
|
model: string;
|
||||||
maxInputTokens?: number;
|
maxInputTokens?: number;
|
||||||
embedQuery: (text: string) => Promise<number[]>;
|
embedQuery: (text: string, options?: { signal?: AbortSignal }) => Promise<number[]>;
|
||||||
embedBatch: (texts: string[]) => Promise<number[][]>;
|
embedBatch: (texts: string[], options?: { signal?: AbortSignal }) => Promise<number[][]>;
|
||||||
embedBatchInputs?: (inputs: unknown[]) => Promise<number[][]>;
|
embedBatchInputs?: (inputs: unknown[], options?: { signal?: AbortSignal }) => Promise<number[][]>;
|
||||||
client?: unknown;
|
client?: unknown;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user