mirror of
https://github.com/openclaw/openclaw.git
synced 2026-06-09 15:31:18 +08:00
Compare commits
4 Commits
codex/boot
...
fix/codeql
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
80f65fce4a | ||
|
|
9fc5f061e2 | ||
|
|
9b055ee2a3 | ||
|
|
457b2ee175 |
6
.github/workflows/codeql.yml
vendored
6
.github/workflows/codeql.yml
vendored
@@ -1,6 +1,12 @@
|
||||
name: CodeQL
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
paths-ignore:
|
||||
- "**/*.md"
|
||||
- "**/*.mdx"
|
||||
- "LICENSE"
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: "0 6 * * *"
|
||||
|
||||
14
CHANGELOG.md
14
CHANGELOG.md
@@ -6,15 +6,7 @@ Docs: https://docs.openclaw.ai
|
||||
|
||||
### Fixes
|
||||
|
||||
- Agents/bootstrap: resolve bootstrap from workspace truth instead of stale session transcript markers, keep embedded bootstrap instructions on the hidden user-context path, suppress normal `/new` and `/reset` greetings while `BOOTSTRAP.md` is still pending, and make the embedded runner read the bootstrap ritual before replying normally.
|
||||
- Onboarding/non-interactive: preserve existing gateway auth tokens during re-onboard so active local gateway clients are not disconnected by an implicit token rotation. (#67821) Thanks @BKF-Gitty.
|
||||
- OpenAI Codex/Responses: unify native Responses API capability detection so Codex OAuth requests emit the required `store: false` field on the native Responses path. (#67918) Thanks @obviyus.
|
||||
- WhatsApp/setup: guard personal-phone and allowlist prompt values so setup fails with clear validation errors instead of crashing on undefined prompt text. (#67895) Thanks @lawrence3699.
|
||||
- Models/config: preserve an existing `models.json` provider `baseUrl` during merge-mode regeneration so custom endpoints do not get reset on restart. (#67893) Thanks @lawrence3699.
|
||||
- Plugins/discovery: reuse bundled and global plugin discovery results across workspace cache misses so Windows multi-workspace startup stops redoing the shared synchronous scan. (#67940) Thanks @obviyus.
|
||||
- Plugins/webhooks: enforce synchronous plugin registration with full rollback of failed plugin side effects, and cache SecretRef-backed webhook auth per route so plugin startup and inbound webhook auth stay deterministic. (#67941) Thanks @obviyus.
|
||||
- Telegram/ACP bindings: drop persisted DM bindings that still point at missing or failed ACP sessions on restart, while preserving plugin-owned bindings and uncertain store reads. (#67822) Thanks @chinar-amrutkar.
|
||||
- Telegram/streaming: keep a transient preview on the same Telegram message when auto-compaction retries an in-flight answer, so streamed replies no longer appear duplicated after compaction. (#66939) Thanks @rubencu.
|
||||
|
||||
## 2026.4.15
|
||||
|
||||
@@ -22,6 +14,12 @@ Docs: https://docs.openclaw.ai
|
||||
|
||||
- Anthropic/models: default Anthropic selections, `opus` aliases, Claude CLI defaults, and bundled image understanding to Claude Opus 4.7.
|
||||
- Google/TTS: add Gemini text-to-speech support to the bundled `google` plugin, including provider registration, voice selection, WAV reply output, PCM telephony output, and setup/docs guidance. (#67515) Thanks @barronlroth.
|
||||
- Control UI/Overview: add a Model Auth status card showing OAuth token health and provider rate-limit pressure at a glance, with attention callouts when OAuth tokens are expiring or expired. Backed by a new `models.authStatus` gateway method that strips credentials and caches for 60s. (#66211) Thanks @omarshahine.
|
||||
- Memory/LanceDB: add cloud storage support to `memory-lancedb` so durable memory indexes can run on remote object storage instead of local disk only. (#63502) Thanks @rugvedS07.
|
||||
- GitHub Copilot/memory search: add a GitHub Copilot embedding provider for memory search, and expose a dedicated Copilot embedding host helper so plugins can reuse the transport while honoring remote overrides, token refresh, and safer payload validation. (#61718) Thanks @feiskyer and @vincentkoc.
|
||||
- Agents/local models: add experimental `agents.defaults.experimental.localModelLean: true` to drop heavyweight default tools like `browser`, `cron`, and `message`, reducing prompt size for weaker local-model setups without changing the normal path. (#66495) Thanks @ImLukeF.
|
||||
- Packaging/plugins: localize bundled plugin runtime deps to their owning extensions, trim the published docs payload, and tighten install/package-manager guardrails so published builds stay leaner and core stops carrying extension-owned runtime baggage. (#67099) Thanks @vincentkoc.
|
||||
- QA/Matrix: split Matrix live QA into a source-linked `qa-matrix` runner and keep repo-private `qa-*` surfaces out of packaged and published builds. (#66723) Thanks @gumadeiras.
|
||||
- Docs/showcase: add a scannable hero, complete section jump links, and a responsive video grid for community examples. (#48493) Thanks @jchopard69.
|
||||
|
||||
### Fixes
|
||||
|
||||
@@ -318,7 +318,7 @@ Current bundled provider examples:
|
||||
| `plugin-sdk/memory-core` | Bundled memory-core helpers | Memory manager/config/file/CLI helper surface |
|
||||
| `plugin-sdk/memory-core-engine-runtime` | Memory engine runtime facade | Memory index/search runtime facade |
|
||||
| `plugin-sdk/memory-core-host-engine-foundation` | Memory host foundation engine | Memory host foundation engine exports |
|
||||
| `plugin-sdk/memory-core-host-engine-embeddings` | Memory host embedding engine | Memory embedding contracts, registry access, local provider, and generic batch/remote helpers; concrete remote providers live in their owning plugins |
|
||||
| `plugin-sdk/memory-core-host-engine-embeddings` | Memory host embedding engine | Memory host embedding engine exports |
|
||||
| `plugin-sdk/memory-core-host-engine-qmd` | Memory host QMD engine | Memory host QMD engine exports |
|
||||
| `plugin-sdk/memory-core-host-engine-storage` | Memory host storage engine | Memory host storage engine exports |
|
||||
| `plugin-sdk/memory-core-host-multimodal` | Memory host multimodal helpers | Memory host multimodal helpers |
|
||||
|
||||
@@ -264,7 +264,7 @@ explicitly promotes one as public.
|
||||
| `plugin-sdk/memory-core` | Bundled memory-core helper surface for manager/config/file/CLI helpers |
|
||||
| `plugin-sdk/memory-core-engine-runtime` | Memory index/search runtime facade |
|
||||
| `plugin-sdk/memory-core-host-engine-foundation` | Memory host foundation engine exports |
|
||||
| `plugin-sdk/memory-core-host-engine-embeddings` | Memory host embedding contracts, registry access, local provider, and generic batch/remote helpers |
|
||||
| `plugin-sdk/memory-core-host-engine-embeddings` | Memory host embedding engine exports |
|
||||
| `plugin-sdk/memory-core-host-engine-qmd` | Memory host QMD engine exports |
|
||||
| `plugin-sdk/memory-core-host-engine-storage` | Memory host storage engine exports |
|
||||
| `plugin-sdk/memory-core-host-multimodal` | Memory host multimodal helpers |
|
||||
|
||||
@@ -119,7 +119,7 @@ describe("active-memory plugin", () => {
|
||||
runEmbeddedPiAgent.mockResolvedValue({
|
||||
payloads: [{ text: "- lemon pepper wings\n- blue cheese" }],
|
||||
});
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
@@ -425,7 +425,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
allowedChatTypes: ["direct", "group"],
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
const result = await hooks.before_prompt_build(
|
||||
{ prompt: "what wings should we order?", messages: [] },
|
||||
@@ -513,7 +513,7 @@ describe("active-memory plugin", () => {
|
||||
searchMode: "inherit",
|
||||
},
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{
|
||||
@@ -602,7 +602,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
queryMode: "message",
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{
|
||||
@@ -630,7 +630,7 @@ describe("active-memory plugin", () => {
|
||||
queryMode: "message",
|
||||
promptStyle: "preference-only",
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{
|
||||
@@ -675,7 +675,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
thinking: "medium",
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{
|
||||
@@ -701,7 +701,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
promptAppend: "Prefer stable long-term preferences over one-off events.",
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{
|
||||
@@ -730,7 +730,7 @@ describe("active-memory plugin", () => {
|
||||
promptOverride: "Custom memory prompt. Return NONE or one user fact.",
|
||||
promptAppend: "Extra custom instruction.",
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{
|
||||
@@ -802,7 +802,7 @@ describe("active-memory plugin", () => {
|
||||
api.pluginConfig = {
|
||||
agents: ["main"],
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{ prompt: "what wings should i order? temp transcript", messages: [] },
|
||||
@@ -828,7 +828,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
modelFallbackPolicy: "resolved-only",
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
const result = await hooks.before_prompt_build(
|
||||
{ prompt: "what wings should i order? no fallback", messages: [] },
|
||||
@@ -851,7 +851,7 @@ describe("active-memory plugin", () => {
|
||||
modelFallback: "google/gemini-3-flash",
|
||||
modelFallbackPolicy: "default-remote",
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{ prompt: "what wings should i order? custom fallback", messages: [] },
|
||||
@@ -878,7 +878,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
modelFallbackPolicy: "default-remote",
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
const result = await hooks.before_prompt_build(
|
||||
{ prompt: "what wings should i order? built-in fallback", messages: [] },
|
||||
@@ -1027,7 +1027,7 @@ describe("active-memory plugin", () => {
|
||||
timeoutMs: 250,
|
||||
logging: true,
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
let lastAbortSignal: AbortSignal | undefined;
|
||||
runEmbeddedPiAgent.mockImplementation(async (params: { abortSignal?: AbortSignal }) => {
|
||||
lastAbortSignal = params.abortSignal;
|
||||
@@ -1073,7 +1073,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
logging: true,
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{ prompt: "what wings should i order? session id cache", messages: [] },
|
||||
@@ -1107,7 +1107,7 @@ describe("active-memory plugin", () => {
|
||||
timeoutMs: 250,
|
||||
logging: true,
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
runEmbeddedPiAgent.mockImplementationOnce(async (params: { timeoutMs?: number }) => {
|
||||
await new Promise((resolve) => setTimeout(resolve, (params.timeoutMs ?? 0) + 25));
|
||||
return {
|
||||
@@ -1145,7 +1145,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
logging: true,
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{ prompt: "what wings should i order? log sanitization", messages: [] },
|
||||
@@ -1179,7 +1179,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
logging: true,
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
const hugeSession = `agent:main:${"x".repeat(500)}`;
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
@@ -1423,7 +1423,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
queryMode: "message",
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{
|
||||
@@ -1451,7 +1451,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
queryMode: "full",
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{
|
||||
@@ -1482,7 +1482,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
queryMode: "recent",
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{
|
||||
@@ -1536,7 +1536,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
queryMode: "recent",
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{
|
||||
@@ -1578,7 +1578,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
queryMode: "recent",
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{
|
||||
@@ -1611,7 +1611,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
queryMode: "recent",
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{
|
||||
@@ -1619,7 +1619,8 @@ describe("active-memory plugin", () => {
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: "Active Memory: I really do want you to remember that I prefer aisle seats.",
|
||||
content:
|
||||
"Active Memory: I really do want you to remember that I prefer aisle seats.",
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
@@ -1673,7 +1674,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
maxSummaryChars: 40,
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
runEmbeddedPiAgent.mockResolvedValueOnce({
|
||||
payloads: [
|
||||
{
|
||||
@@ -1707,7 +1708,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
maxSummaryChars: 90,
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{ prompt: "what wings should i order? prompt-count-check", messages: [] },
|
||||
@@ -1757,7 +1758,7 @@ describe("active-memory plugin", () => {
|
||||
transcriptDir: "active-memory-subagents",
|
||||
logging: true,
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
const mkdirSpy = vi.spyOn(fs, "mkdir").mockResolvedValue(undefined);
|
||||
const mkdtempSpy = vi.spyOn(fs, "mkdtemp");
|
||||
const rmSpy = vi.spyOn(fs, "rm").mockResolvedValue(undefined);
|
||||
@@ -1801,7 +1802,7 @@ describe("active-memory plugin", () => {
|
||||
transcriptDir: "C:/temp/escape",
|
||||
logging: true,
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
const mkdirSpy = vi.spyOn(fs, "mkdir").mockResolvedValue(undefined);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
@@ -1838,7 +1839,7 @@ describe("active-memory plugin", () => {
|
||||
transcriptDir: "active-memory-subagents",
|
||||
logging: true,
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
const mkdirSpy = vi.spyOn(fs, "mkdir").mockResolvedValue(undefined);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
@@ -1905,7 +1906,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
logging: true,
|
||||
};
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
for (let index = 0; index <= 1000; index += 1) {
|
||||
await hooks.before_prompt_build(
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
import {
|
||||
isMissingEmbeddingApiKeyError,
|
||||
type MemoryEmbeddingProviderAdapter,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import {
|
||||
createBedrockEmbeddingProvider,
|
||||
DEFAULT_BEDROCK_EMBEDDING_MODEL,
|
||||
} from "./embedding-provider.js";
|
||||
|
||||
export const bedrockMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: "bedrock",
|
||||
defaultModel: DEFAULT_BEDROCK_EMBEDDING_MODEL,
|
||||
transport: "remote",
|
||||
authProviderId: "amazon-bedrock",
|
||||
autoSelectPriority: 60,
|
||||
allowExplicitWhenConfiguredAuto: true,
|
||||
shouldContinueAutoSelection: isMissingEmbeddingApiKeyError,
|
||||
create: async (options) => {
|
||||
const { provider, client } = await createBedrockEmbeddingProvider({
|
||||
...options,
|
||||
provider: "bedrock",
|
||||
fallback: "none",
|
||||
});
|
||||
return {
|
||||
provider,
|
||||
runtime: {
|
||||
id: "bedrock",
|
||||
cacheKeyData: {
|
||||
provider: "bedrock",
|
||||
region: client.region,
|
||||
model: client.model,
|
||||
dimensions: client.dimensions,
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
@@ -2,9 +2,6 @@
|
||||
"id": "amazon-bedrock",
|
||||
"enabledByDefault": true,
|
||||
"providers": ["amazon-bedrock"],
|
||||
"contracts": {
|
||||
"memoryEmbeddingProviders": ["bedrock"]
|
||||
},
|
||||
"configSchema": {
|
||||
"type": "object",
|
||||
"additionalProperties": false,
|
||||
|
||||
@@ -5,9 +5,7 @@
|
||||
"description": "OpenClaw Amazon Bedrock provider plugin",
|
||||
"type": "module",
|
||||
"dependencies": {
|
||||
"@aws-sdk/client-bedrock": "3.1028.0",
|
||||
"@aws-sdk/client-bedrock-runtime": "3.1028.0",
|
||||
"@aws-sdk/credential-provider-node": "3.972.30"
|
||||
"@aws-sdk/client-bedrock": "3.1028.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@openclaw/plugin-sdk": "workspace:*"
|
||||
|
||||
@@ -14,7 +14,6 @@ import {
|
||||
resolveBedrockConfigApiKey,
|
||||
resolveImplicitBedrockProvider,
|
||||
} from "./api.js";
|
||||
import { bedrockMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js";
|
||||
|
||||
type GuardrailConfig = {
|
||||
guardrailIdentifier: string;
|
||||
@@ -79,8 +78,6 @@ export function registerAmazonBedrockPlugin(api: OpenClawPluginApi): void {
|
||||
const pluginConfig = (api.pluginConfig ?? {}) as AmazonBedrockPluginConfig;
|
||||
const guardrail = pluginConfig.guardrail;
|
||||
|
||||
api.registerMemoryEmbeddingProvider(bedrockMemoryEmbeddingProviderAdapter);
|
||||
|
||||
const baseWrapStreamFn = ({ modelId, streamFn }: { modelId: string; streamFn?: StreamFn }) =>
|
||||
isAnthropicBedrockModel(modelId) ? streamFn : createBedrockNoCacheWrapper(streamFn);
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ describeLive("comfy live", () => {
|
||||
beforeAll(async () => {
|
||||
cfg = withPluginsEnabled(loadConfig());
|
||||
agentDir = resolveOpenClawAgentDir();
|
||||
plugin.register(
|
||||
await plugin.register(
|
||||
createTestPluginApi({
|
||||
config: cfg as never,
|
||||
registerImageGenerationProvider(provider) {
|
||||
|
||||
@@ -92,7 +92,7 @@ function registerPairCommand(params?: {
|
||||
pluginConfig?: Record<string, unknown>;
|
||||
}): OpenClawPluginCommandDefinition {
|
||||
let command: OpenClawPluginCommandDefinition | undefined;
|
||||
registerDevicePair.register(
|
||||
void registerDevicePair.register(
|
||||
createApi({
|
||||
...params,
|
||||
registerCommand: (nextCommand) => {
|
||||
|
||||
@@ -4,6 +4,7 @@ const resolveFirstGithubTokenMock = vi.hoisted(() => vi.fn());
|
||||
const resolveCopilotApiTokenMock = vi.hoisted(() => vi.fn());
|
||||
const resolveConfiguredSecretInputStringMock = vi.hoisted(() => vi.fn());
|
||||
const fetchWithSsrFGuardMock = vi.hoisted(() => vi.fn());
|
||||
const createGitHubCopilotEmbeddingProviderMock = vi.hoisted(() => vi.fn());
|
||||
|
||||
vi.mock("./auth.js", () => ({
|
||||
resolveFirstGithubToken: resolveFirstGithubTokenMock,
|
||||
@@ -18,6 +19,10 @@ vi.mock("openclaw/plugin-sdk/github-copilot-token", () => ({
|
||||
resolveCopilotApiToken: resolveCopilotApiTokenMock,
|
||||
}));
|
||||
|
||||
vi.mock("openclaw/plugin-sdk/memory-core-host-engine-embeddings", () => ({
|
||||
createGitHubCopilotEmbeddingProvider: createGitHubCopilotEmbeddingProviderMock,
|
||||
}));
|
||||
|
||||
vi.mock("openclaw/plugin-sdk/ssrf-runtime", () => ({
|
||||
fetchWithSsrFGuard: fetchWithSsrFGuardMock,
|
||||
}));
|
||||
@@ -68,6 +73,15 @@ describe("githubCopilotMemoryEmbeddingProviderAdapter", () => {
|
||||
source: "test",
|
||||
baseUrl: TEST_BASE_URL,
|
||||
});
|
||||
createGitHubCopilotEmbeddingProviderMock.mockImplementation(async (client) => ({
|
||||
provider: {
|
||||
id: "github-copilot",
|
||||
model: client.model,
|
||||
embedQuery: async () => [0.1, 0.2, 0.3],
|
||||
embedBatch: async (texts: string[]) => texts.map(() => [0.1, 0.2, 0.3]),
|
||||
},
|
||||
client,
|
||||
}));
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -75,6 +89,7 @@ describe("githubCopilotMemoryEmbeddingProviderAdapter", () => {
|
||||
resolveConfiguredSecretInputStringMock.mockReset();
|
||||
resolveFirstGithubTokenMock.mockReset();
|
||||
resolveCopilotApiTokenMock.mockReset();
|
||||
createGitHubCopilotEmbeddingProviderMock.mockReset();
|
||||
fetchWithSsrFGuardMock.mockReset();
|
||||
});
|
||||
|
||||
@@ -98,8 +113,12 @@ describe("githubCopilotMemoryEmbeddingProviderAdapter", () => {
|
||||
const result = await githubCopilotMemoryEmbeddingProviderAdapter.create(defaultCreateOptions());
|
||||
|
||||
expect(result.provider?.model).toBe("text-embedding-3-small");
|
||||
expect(resolveCopilotApiTokenMock).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ githubToken: "gh_test_token_123" }),
|
||||
expect(createGitHubCopilotEmbeddingProviderMock).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
baseUrl: TEST_BASE_URL,
|
||||
githubToken: "gh_test_token_123",
|
||||
model: "text-embedding-3-small",
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
@@ -198,12 +217,14 @@ describe("githubCopilotMemoryEmbeddingProviderAdapter", () => {
|
||||
} as never);
|
||||
|
||||
expect(resolveFirstGithubTokenMock).toHaveBeenCalled();
|
||||
expect(resolveCopilotApiTokenMock).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
env: process.env,
|
||||
githubToken: "gh_remote_token",
|
||||
}),
|
||||
);
|
||||
expect(createGitHubCopilotEmbeddingProviderMock).toHaveBeenCalledWith({
|
||||
baseUrl: "https://proxy.example/v1",
|
||||
env: process.env,
|
||||
fetchImpl: fetch,
|
||||
githubToken: "gh_remote_token",
|
||||
headers: { "X-Proxy-Token": "proxy" },
|
||||
model: "text-embedding-3-small",
|
||||
});
|
||||
|
||||
const discoveryCall = fetchWithSsrFGuardMock.mock.calls[0]?.[0] as {
|
||||
init: { headers: Record<string, string> };
|
||||
|
||||
@@ -4,10 +4,7 @@ import {
|
||||
resolveCopilotApiToken,
|
||||
} from "openclaw/plugin-sdk/github-copilot-token";
|
||||
import {
|
||||
buildRemoteBaseUrlPolicy,
|
||||
sanitizeAndNormalizeEmbedding,
|
||||
withRemoteHttpResponse,
|
||||
type MemoryEmbeddingProvider,
|
||||
createGitHubCopilotEmbeddingProvider,
|
||||
type MemoryEmbeddingProviderAdapter,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import { fetchWithSsrFGuard, type SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime";
|
||||
@@ -47,15 +44,6 @@ type CopilotModelEntry = {
|
||||
supported_endpoints?: unknown;
|
||||
};
|
||||
|
||||
type GitHubCopilotEmbeddingClient = {
|
||||
githubToken: string;
|
||||
model: string;
|
||||
baseUrl?: string;
|
||||
headers?: Record<string, string>;
|
||||
env?: NodeJS.ProcessEnv;
|
||||
fetchImpl?: typeof fetch;
|
||||
};
|
||||
|
||||
function isCopilotSetupError(err: unknown): boolean {
|
||||
if (!(err instanceof Error)) {
|
||||
return false;
|
||||
@@ -159,126 +147,9 @@ function pickBestModel(available: string[], userModel?: string): string {
|
||||
throw new Error("No embedding models available from GitHub Copilot");
|
||||
}
|
||||
|
||||
function parseGitHubCopilotEmbeddingPayload(payload: unknown, expectedCount: number): number[][] {
|
||||
if (!payload || typeof payload !== "object") {
|
||||
throw new Error("GitHub Copilot embeddings response missing data[]");
|
||||
}
|
||||
const data = (payload as { data?: unknown }).data;
|
||||
if (!Array.isArray(data)) {
|
||||
throw new Error("GitHub Copilot embeddings response missing data[]");
|
||||
}
|
||||
|
||||
const vectors = Array.from<number[] | undefined>({ length: expectedCount });
|
||||
for (const entry of data) {
|
||||
if (!entry || typeof entry !== "object") {
|
||||
throw new Error("GitHub Copilot embeddings response contains an invalid entry");
|
||||
}
|
||||
const indexValue = (entry as { index?: unknown }).index;
|
||||
const embedding = (entry as { embedding?: unknown }).embedding;
|
||||
const index = typeof indexValue === "number" ? indexValue : Number.NaN;
|
||||
if (!Number.isInteger(index)) {
|
||||
throw new Error("GitHub Copilot embeddings response contains an invalid index");
|
||||
}
|
||||
if (index < 0 || index >= expectedCount) {
|
||||
throw new Error("GitHub Copilot embeddings response contains an out-of-range index");
|
||||
}
|
||||
if (vectors[index] !== undefined) {
|
||||
throw new Error("GitHub Copilot embeddings response contains duplicate indexes");
|
||||
}
|
||||
if (!Array.isArray(embedding) || !embedding.every((value) => typeof value === "number")) {
|
||||
throw new Error("GitHub Copilot embeddings response contains an invalid embedding");
|
||||
}
|
||||
vectors[index] = sanitizeAndNormalizeEmbedding(embedding);
|
||||
}
|
||||
|
||||
for (let index = 0; index < expectedCount; index += 1) {
|
||||
if (vectors[index] === undefined) {
|
||||
throw new Error("GitHub Copilot embeddings response missing vectors for some inputs");
|
||||
}
|
||||
}
|
||||
return vectors as number[][];
|
||||
}
|
||||
|
||||
async function resolveGitHubCopilotEmbeddingSession(client: GitHubCopilotEmbeddingClient): Promise<{
|
||||
baseUrl: string;
|
||||
headers: Record<string, string>;
|
||||
}> {
|
||||
const token = await resolveCopilotApiToken({
|
||||
githubToken: client.githubToken,
|
||||
env: client.env,
|
||||
fetchImpl: client.fetchImpl,
|
||||
});
|
||||
const baseUrl = client.baseUrl?.trim() || token.baseUrl || DEFAULT_COPILOT_API_BASE_URL;
|
||||
return {
|
||||
baseUrl,
|
||||
headers: {
|
||||
...COPILOT_HEADERS_STATIC,
|
||||
...client.headers,
|
||||
Authorization: `Bearer ${token.token}`,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
async function createGitHubCopilotEmbeddingProvider(
|
||||
client: GitHubCopilotEmbeddingClient,
|
||||
): Promise<{ provider: MemoryEmbeddingProvider; client: GitHubCopilotEmbeddingClient }> {
|
||||
const initialSession = await resolveGitHubCopilotEmbeddingSession(client);
|
||||
|
||||
const embed = async (input: string[]): Promise<number[][]> => {
|
||||
if (input.length === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const session = await resolveGitHubCopilotEmbeddingSession(client);
|
||||
const url = `${session.baseUrl.replace(/\/$/, "")}/embeddings`;
|
||||
return await withRemoteHttpResponse({
|
||||
url,
|
||||
fetchImpl: client.fetchImpl,
|
||||
ssrfPolicy: buildRemoteBaseUrlPolicy(session.baseUrl),
|
||||
init: {
|
||||
method: "POST",
|
||||
headers: session.headers,
|
||||
body: JSON.stringify({ model: client.model, input }),
|
||||
},
|
||||
onResponse: async (response) => {
|
||||
if (!response.ok) {
|
||||
throw new Error(
|
||||
`GitHub Copilot embeddings HTTP ${response.status}: ${await response.text()}`,
|
||||
);
|
||||
}
|
||||
|
||||
let payload: unknown;
|
||||
try {
|
||||
payload = await response.json();
|
||||
} catch {
|
||||
throw new Error("GitHub Copilot embeddings returned invalid JSON");
|
||||
}
|
||||
return parseGitHubCopilotEmbeddingPayload(payload, input.length);
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
return {
|
||||
provider: {
|
||||
id: COPILOT_EMBEDDING_PROVIDER_ID,
|
||||
model: client.model,
|
||||
embedQuery: async (text) => {
|
||||
const [vector] = await embed([text]);
|
||||
return vector ?? [];
|
||||
},
|
||||
embedBatch: embed,
|
||||
},
|
||||
client: {
|
||||
...client,
|
||||
baseUrl: initialSession.baseUrl,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
export const githubCopilotMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: COPILOT_EMBEDDING_PROVIDER_ID,
|
||||
transport: "remote",
|
||||
authProviderId: COPILOT_EMBEDDING_PROVIDER_ID,
|
||||
autoSelectPriority: 15,
|
||||
allowExplicitWhenConfiguredAuto: true,
|
||||
shouldContinueAutoSelection: (err: unknown) => isCopilotSetupError(err),
|
||||
|
||||
@@ -3,7 +3,6 @@ import type { MediaUnderstandingProvider } from "openclaw/plugin-sdk/media-under
|
||||
import { definePluginEntry } from "openclaw/plugin-sdk/plugin-entry";
|
||||
import { buildGoogleGeminiCliBackend } from "./cli-backend.js";
|
||||
import { registerGoogleGeminiCliProvider } from "./gemini-cli-provider.js";
|
||||
import { geminiMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js";
|
||||
import { buildGoogleMusicGenerationProvider } from "./music-generation-provider.js";
|
||||
import { registerGoogleProvider } from "./provider-registration.js";
|
||||
import { buildGoogleSpeechProvider } from "./speech-provider.js";
|
||||
@@ -112,7 +111,6 @@ export default definePluginEntry({
|
||||
api.registerCliBackend(buildGoogleGeminiCliBackend());
|
||||
registerGoogleGeminiCliProvider(api);
|
||||
registerGoogleProvider(api);
|
||||
api.registerMemoryEmbeddingProvider(geminiMemoryEmbeddingProviderAdapter);
|
||||
api.registerImageGenerationProvider(createLazyGoogleImageGenerationProvider());
|
||||
api.registerMediaUnderstandingProvider(createLazyGoogleMediaUnderstandingProvider());
|
||||
api.registerMusicGenerationProvider(buildGoogleMusicGenerationProvider());
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
import {
|
||||
hasNonTextEmbeddingParts,
|
||||
isMissingEmbeddingApiKeyError,
|
||||
mapBatchEmbeddingsByIndex,
|
||||
sanitizeEmbeddingCacheHeaders,
|
||||
type MemoryEmbeddingProviderAdapter,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import { runGeminiEmbeddingBatches } from "./embedding-batch.js";
|
||||
import {
|
||||
buildGeminiEmbeddingRequest,
|
||||
createGeminiEmbeddingProvider,
|
||||
DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
} from "./embedding-provider.js";
|
||||
|
||||
function supportsGeminiMultimodalEmbeddings(model: string): boolean {
|
||||
const normalized = model
|
||||
.trim()
|
||||
.replace(/^models\//, "")
|
||||
.replace(/^(gemini|google)\//, "");
|
||||
return normalized === "gemini-embedding-2-preview";
|
||||
}
|
||||
|
||||
export const geminiMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: "gemini",
|
||||
defaultModel: DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
transport: "remote",
|
||||
authProviderId: "google",
|
||||
autoSelectPriority: 30,
|
||||
allowExplicitWhenConfiguredAuto: true,
|
||||
supportsMultimodalEmbeddings: ({ model }) => supportsGeminiMultimodalEmbeddings(model),
|
||||
shouldContinueAutoSelection: isMissingEmbeddingApiKeyError,
|
||||
create: async (options) => {
|
||||
const { provider, client } = await createGeminiEmbeddingProvider({
|
||||
...options,
|
||||
provider: "gemini",
|
||||
fallback: "none",
|
||||
});
|
||||
return {
|
||||
provider,
|
||||
runtime: {
|
||||
id: "gemini",
|
||||
cacheKeyData: {
|
||||
provider: "gemini",
|
||||
baseUrl: client.baseUrl,
|
||||
model: client.model,
|
||||
outputDimensionality: client.outputDimensionality,
|
||||
headers: sanitizeEmbeddingCacheHeaders(client.headers, [
|
||||
"authorization",
|
||||
"x-goog-api-key",
|
||||
]),
|
||||
},
|
||||
batchEmbed: async (batch) => {
|
||||
if (batch.chunks.some((chunk) => hasNonTextEmbeddingParts(chunk.embeddingInput))) {
|
||||
return null;
|
||||
}
|
||||
const byCustomId = await runGeminiEmbeddingBatches({
|
||||
gemini: client,
|
||||
agentId: batch.agentId,
|
||||
requests: batch.chunks.map((chunk, index) => ({
|
||||
custom_id: String(index),
|
||||
request: buildGeminiEmbeddingRequest({
|
||||
input: chunk.embeddingInput ?? { text: chunk.text },
|
||||
taskType: "RETRIEVAL_DOCUMENT",
|
||||
modelPath: client.modelPath,
|
||||
outputDimensionality: client.outputDimensionality,
|
||||
}),
|
||||
})),
|
||||
wait: batch.wait,
|
||||
concurrency: batch.concurrency,
|
||||
pollIntervalMs: batch.pollIntervalMs,
|
||||
timeoutMs: batch.timeoutMs,
|
||||
debug: batch.debug,
|
||||
});
|
||||
return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length);
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
@@ -46,7 +46,6 @@
|
||||
},
|
||||
"contracts": {
|
||||
"mediaUnderstandingProviders": ["google"],
|
||||
"memoryEmbeddingProviders": ["gemini"],
|
||||
"imageGenerationProviders": ["google"],
|
||||
"musicGenerationProviders": ["google"],
|
||||
"speechProviders": ["google"],
|
||||
|
||||
@@ -8,7 +8,6 @@ import {
|
||||
type ProviderRuntimeModel,
|
||||
} from "openclaw/plugin-sdk/plugin-entry";
|
||||
import { CUSTOM_LOCAL_AUTH_MARKER } from "openclaw/plugin-sdk/provider-auth";
|
||||
import { lmstudioMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js";
|
||||
import {
|
||||
LMSTUDIO_DEFAULT_API_KEY_ENV_VAR,
|
||||
LMSTUDIO_LOCAL_API_KEY_PLACEHOLDER,
|
||||
@@ -53,7 +52,6 @@ export default definePluginEntry({
|
||||
name: "LM Studio Provider",
|
||||
description: "Bundled LM Studio provider plugin",
|
||||
register(api: OpenClawPluginApi) {
|
||||
api.registerMemoryEmbeddingProvider(lmstudioMemoryEmbeddingProviderAdapter);
|
||||
api.registerProvider({
|
||||
id: PROVIDER_ID,
|
||||
label: "LM Studio",
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
import {
|
||||
sanitizeEmbeddingCacheHeaders,
|
||||
type MemoryEmbeddingProviderAdapter,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import {
|
||||
createLmstudioEmbeddingProvider,
|
||||
DEFAULT_LMSTUDIO_EMBEDDING_MODEL,
|
||||
} from "./src/embedding-provider.js";
|
||||
|
||||
export const lmstudioMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: "lmstudio",
|
||||
defaultModel: DEFAULT_LMSTUDIO_EMBEDDING_MODEL,
|
||||
transport: "remote",
|
||||
authProviderId: "lmstudio",
|
||||
allowExplicitWhenConfiguredAuto: true,
|
||||
create: async (options) => {
|
||||
const { provider, client } = await createLmstudioEmbeddingProvider({
|
||||
...options,
|
||||
provider: "lmstudio",
|
||||
fallback: "none",
|
||||
});
|
||||
return {
|
||||
provider,
|
||||
runtime: {
|
||||
id: "lmstudio",
|
||||
cacheKeyData: {
|
||||
provider: "lmstudio",
|
||||
baseUrl: client.baseUrl,
|
||||
model: client.model,
|
||||
headers: sanitizeEmbeddingCacheHeaders(client.headers, ["authorization"]),
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
@@ -21,9 +21,6 @@
|
||||
"groupHint": "Self-hosted open-weight models"
|
||||
}
|
||||
],
|
||||
"contracts": {
|
||||
"memoryEmbeddingProviders": ["lmstudio"]
|
||||
},
|
||||
"configSchema": {
|
||||
"type": "object",
|
||||
"additionalProperties": false,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import type { IncomingMessage, ServerResponse } from "node:http";
|
||||
import { PassThrough } from "node:stream";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import type { OpenClawConfig, RuntimeEnv } from "../../runtime-api.js";
|
||||
import type { ResolvedMattermostAccount } from "./accounts.js";
|
||||
import { createSlashCommandHttpHandler } from "./slash-http.js";
|
||||
@@ -133,19 +133,25 @@ describe("slash-http", () => {
|
||||
});
|
||||
|
||||
it("returns 408 when the request body stalls", async () => {
|
||||
const handler = createSlashCommandHttpHandler({
|
||||
account: accountFixture,
|
||||
cfg: {} as OpenClawConfig,
|
||||
runtime: {} as RuntimeEnv,
|
||||
commandTokens: new Set(["valid-token"]),
|
||||
bodyTimeoutMs: 1,
|
||||
});
|
||||
const req = createRequest({ autoEnd: false });
|
||||
const response = createResponse();
|
||||
vi.useFakeTimers();
|
||||
try {
|
||||
const handler = createSlashCommandHttpHandler({
|
||||
account: accountFixture,
|
||||
cfg: {} as OpenClawConfig,
|
||||
runtime: {} as RuntimeEnv,
|
||||
commandTokens: new Set(["valid-token"]),
|
||||
});
|
||||
const req = createRequest({ autoEnd: false });
|
||||
const response = createResponse();
|
||||
const pending = handler(req, response.res);
|
||||
|
||||
await handler(req, response.res);
|
||||
await vi.advanceTimersByTimeAsync(5_000);
|
||||
await pending;
|
||||
|
||||
expect(response.res.statusCode).toBe(408);
|
||||
expect(response.getBody()).toBe("Request body timeout");
|
||||
expect(response.res.statusCode).toBe(408);
|
||||
expect(response.getBody()).toBe("Request body timeout");
|
||||
} finally {
|
||||
vi.useRealTimers();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
@@ -54,7 +54,6 @@ type SlashHttpHandlerParams = {
|
||||
/** Map from trigger to original command name (for skill commands that start with oc_). */
|
||||
triggerMap?: ReadonlyMap<string, string>;
|
||||
log?: (msg: string) => void;
|
||||
bodyTimeoutMs?: number;
|
||||
};
|
||||
|
||||
const MAX_BODY_BYTES = 64 * 1024;
|
||||
@@ -63,14 +62,10 @@ const BODY_READ_TIMEOUT_MS = 5_000;
|
||||
/**
|
||||
* Read the full request body as a string.
|
||||
*/
|
||||
function readBody(
|
||||
req: IncomingMessage,
|
||||
maxBytes: number,
|
||||
timeoutMs = BODY_READ_TIMEOUT_MS,
|
||||
): Promise<string> {
|
||||
function readBody(req: IncomingMessage, maxBytes: number): Promise<string> {
|
||||
return readRequestBodyWithLimit(req, {
|
||||
maxBytes,
|
||||
timeoutMs,
|
||||
timeoutMs: BODY_READ_TIMEOUT_MS,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -224,7 +219,7 @@ async function authorizeSlashInvocation(params: {
|
||||
* from the Mattermost server when a user invokes a registered slash command.
|
||||
*/
|
||||
export function createSlashCommandHttpHandler(params: SlashHttpHandlerParams) {
|
||||
const { account, cfg, runtime, commandTokens, triggerMap, log, bodyTimeoutMs } = params;
|
||||
const { account, cfg, runtime, commandTokens, triggerMap, log } = params;
|
||||
|
||||
return async (req: IncomingMessage, res: ServerResponse): Promise<void> => {
|
||||
if (req.method !== "POST") {
|
||||
@@ -236,7 +231,7 @@ export function createSlashCommandHttpHandler(params: SlashHttpHandlerParams) {
|
||||
|
||||
let body: string;
|
||||
try {
|
||||
body = await readBody(req, MAX_BODY_BYTES, bodyTimeoutMs);
|
||||
body = await readBody(req, MAX_BODY_BYTES);
|
||||
} catch (error) {
|
||||
if (isRequestBodyLimitError(error, "REQUEST_BODY_TIMEOUT")) {
|
||||
res.statusCode = 408;
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
import {
|
||||
DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
DEFAULT_LOCAL_MODEL,
|
||||
DEFAULT_MISTRAL_EMBEDDING_MODEL,
|
||||
DEFAULT_OLLAMA_EMBEDDING_MODEL,
|
||||
DEFAULT_OPENAI_EMBEDDING_MODEL,
|
||||
DEFAULT_VOYAGE_EMBEDDING_MODEL,
|
||||
getMemoryEmbeddingProvider,
|
||||
listMemoryEmbeddingProviders,
|
||||
type MemoryEmbeddingProvider,
|
||||
@@ -10,7 +15,15 @@ import {
|
||||
import { formatErrorMessage } from "../dreaming-shared.js";
|
||||
import { canAutoSelectLocal } from "./provider-adapters.js";
|
||||
|
||||
export { DEFAULT_LOCAL_MODEL } from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
export {
|
||||
DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
DEFAULT_LMSTUDIO_EMBEDDING_MODEL,
|
||||
DEFAULT_LOCAL_MODEL,
|
||||
DEFAULT_MISTRAL_EMBEDDING_MODEL,
|
||||
DEFAULT_OLLAMA_EMBEDDING_MODEL,
|
||||
DEFAULT_OPENAI_EMBEDDING_MODEL,
|
||||
DEFAULT_VOYAGE_EMBEDDING_MODEL,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
|
||||
export type EmbeddingProvider = MemoryEmbeddingProvider;
|
||||
export type EmbeddingProviderId = string;
|
||||
|
||||
@@ -11,9 +11,9 @@ import {
|
||||
} from "../../../../src/plugins/memory-embedding-providers.js";
|
||||
import "./test-runtime-mocks.js";
|
||||
import type { MemoryIndexManager } from "./index.js";
|
||||
import { closeAllMemorySearchManagers, getMemorySearchManager } from "./index.js";
|
||||
import { getMemorySearchManager, closeAllMemorySearchManagers } from "./index.js";
|
||||
import {
|
||||
DEFAULT_LOCAL_MODEL,
|
||||
DEFAULT_OLLAMA_EMBEDDING_MODEL,
|
||||
registerBuiltInMemoryEmbeddingProviders,
|
||||
} from "./provider-adapters.js";
|
||||
|
||||
@@ -112,14 +112,14 @@ vi.mock("./embeddings.js", () => {
|
||||
});
|
||||
|
||||
describe("memory index", () => {
|
||||
it("registers the builtin local embedding provider", () => {
|
||||
const adapter = listRegisteredAdapters().find((entry) => entry.id === "local");
|
||||
it("registers the builtin ollama embedding provider", () => {
|
||||
const adapter = listRegisteredAdapters().find((entry) => entry.id === "ollama");
|
||||
|
||||
expect(adapter).toBeDefined();
|
||||
expect(adapter).toEqual(
|
||||
expect.objectContaining({
|
||||
id: "local",
|
||||
defaultModel: DEFAULT_LOCAL_MODEL,
|
||||
id: "ollama",
|
||||
defaultModel: DEFAULT_OLLAMA_EMBEDDING_MODEL,
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1,13 +1,31 @@
|
||||
import fsSync from "node:fs";
|
||||
import {
|
||||
createLocalEmbeddingProvider,
|
||||
DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
DEFAULT_LMSTUDIO_EMBEDDING_MODEL,
|
||||
DEFAULT_LOCAL_MODEL,
|
||||
listMemoryEmbeddingProviders,
|
||||
DEFAULT_MISTRAL_EMBEDDING_MODEL,
|
||||
DEFAULT_OLLAMA_EMBEDDING_MODEL,
|
||||
DEFAULT_OPENAI_EMBEDDING_MODEL,
|
||||
DEFAULT_VOYAGE_EMBEDDING_MODEL,
|
||||
OPENAI_BATCH_ENDPOINT,
|
||||
buildGeminiEmbeddingRequest,
|
||||
createGeminiEmbeddingProvider,
|
||||
createLmstudioEmbeddingProvider,
|
||||
createLocalEmbeddingProvider,
|
||||
createMistralEmbeddingProvider,
|
||||
createOllamaEmbeddingProvider,
|
||||
createOpenAiEmbeddingProvider,
|
||||
createVoyageEmbeddingProvider,
|
||||
hasNonTextEmbeddingParts,
|
||||
listRegisteredMemoryEmbeddingProviderAdapters,
|
||||
runGeminiEmbeddingBatches,
|
||||
runOpenAiEmbeddingBatches,
|
||||
runVoyageEmbeddingBatches,
|
||||
type MemoryEmbeddingProviderAdapter,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import { resolveUserPath } from "openclaw/plugin-sdk/memory-core-host-engine-foundation";
|
||||
import { getProviderEnvVars } from "openclaw/plugin-sdk/provider-env-vars";
|
||||
import { normalizeLowercaseStringOrEmpty } from "openclaw/plugin-sdk/text-runtime";
|
||||
import { formatErrorMessage } from "../dreaming-shared.js";
|
||||
import { filterUnregisteredMemoryEmbeddingProviderAdapters } from "./provider-adapter-registration.js";
|
||||
|
||||
@@ -19,6 +37,31 @@ export type BuiltinMemoryEmbeddingProviderDoctorMetadata = {
|
||||
autoSelectPriority?: number;
|
||||
};
|
||||
|
||||
function isMissingApiKeyError(err: unknown): boolean {
|
||||
return formatErrorMessage(err).includes("No API key found for provider");
|
||||
}
|
||||
|
||||
function sanitizeHeaders(
|
||||
headers: Record<string, string>,
|
||||
excludedHeaderNames: string[],
|
||||
): Array<[string, string]> {
|
||||
const excluded = new Set(
|
||||
excludedHeaderNames.map((name) => normalizeLowercaseStringOrEmpty(name)),
|
||||
);
|
||||
return Object.entries(headers)
|
||||
.filter(([key]) => !excluded.has(normalizeLowercaseStringOrEmpty(key)))
|
||||
.toSorted(([a], [b]) => a.localeCompare(b))
|
||||
.map(([key, value]) => [key, value]);
|
||||
}
|
||||
|
||||
function mapBatchEmbeddingsByIndex(byCustomId: Map<string, number[]>, count: number): number[][] {
|
||||
const embeddings: number[][] = [];
|
||||
for (let index = 0; index < count; index += 1) {
|
||||
embeddings.push(byCustomId.get(String(index)) ?? []);
|
||||
}
|
||||
return embeddings;
|
||||
}
|
||||
|
||||
function isNodeLlamaCppMissing(err: unknown): boolean {
|
||||
if (!(err instanceof Error)) {
|
||||
return false;
|
||||
@@ -27,20 +70,6 @@ function isNodeLlamaCppMissing(err: unknown): boolean {
|
||||
return code === "ERR_MODULE_NOT_FOUND" && err.message.includes("node-llama-cpp");
|
||||
}
|
||||
|
||||
function listRemoteEmbeddingSetupHints(): string[] {
|
||||
try {
|
||||
return listMemoryEmbeddingProviders()
|
||||
.filter(
|
||||
(adapter) =>
|
||||
adapter.transport === "remote" && typeof adapter.autoSelectPriority === "number",
|
||||
)
|
||||
.toSorted((a, b) => (a.autoSelectPriority ?? 0) - (b.autoSelectPriority ?? 0))
|
||||
.map((adapter) => `Or set agents.defaults.memorySearch.provider = "${adapter.id}" (remote).`);
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
function formatLocalSetupError(err: unknown): string {
|
||||
const detail = formatErrorMessage(err);
|
||||
const missing = isNodeLlamaCppMissing(err);
|
||||
@@ -58,7 +87,9 @@ function formatLocalSetupError(err: unknown): string {
|
||||
? "2) Reinstall OpenClaw (this should install node-llama-cpp): npm i -g openclaw@latest"
|
||||
: null,
|
||||
"3) If you use pnpm: pnpm approve-builds (select node-llama-cpp), then pnpm rebuild node-llama-cpp",
|
||||
...listRemoteEmbeddingSetupHints(),
|
||||
...["openai", "gemini", "voyage", "mistral"].map(
|
||||
(provider) => `Or set agents.defaults.memorySearch.provider = "${provider}" (remote).`,
|
||||
),
|
||||
]
|
||||
.filter(Boolean)
|
||||
.join("\n");
|
||||
@@ -80,6 +111,237 @@ function canAutoSelectLocal(modelPath?: string): boolean {
|
||||
}
|
||||
}
|
||||
|
||||
function supportsGeminiMultimodalEmbeddings(model: string): boolean {
|
||||
const normalized = model
|
||||
.trim()
|
||||
.replace(/^models\//, "")
|
||||
.replace(/^(gemini|google)\//, "");
|
||||
return normalized === "gemini-embedding-2-preview";
|
||||
}
|
||||
|
||||
function resolveMemoryEmbeddingAuthProviderId(providerId: string): string {
|
||||
return providerId === "gemini" ? "google" : providerId;
|
||||
}
|
||||
|
||||
const openAiAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: "openai",
|
||||
defaultModel: DEFAULT_OPENAI_EMBEDDING_MODEL,
|
||||
transport: "remote",
|
||||
autoSelectPriority: 20,
|
||||
allowExplicitWhenConfiguredAuto: true,
|
||||
shouldContinueAutoSelection: isMissingApiKeyError,
|
||||
create: async (options) => {
|
||||
const { provider, client } = await createOpenAiEmbeddingProvider({
|
||||
...options,
|
||||
provider: "openai",
|
||||
fallback: "none",
|
||||
});
|
||||
return {
|
||||
provider,
|
||||
runtime: {
|
||||
id: "openai",
|
||||
cacheKeyData: {
|
||||
provider: "openai",
|
||||
baseUrl: client.baseUrl,
|
||||
model: client.model,
|
||||
headers: sanitizeHeaders(client.headers, ["authorization"]),
|
||||
},
|
||||
batchEmbed: async (batch) => {
|
||||
const byCustomId = await runOpenAiEmbeddingBatches({
|
||||
openAi: client,
|
||||
agentId: batch.agentId,
|
||||
requests: batch.chunks.map((chunk, index) => ({
|
||||
custom_id: String(index),
|
||||
method: "POST",
|
||||
url: OPENAI_BATCH_ENDPOINT,
|
||||
body: {
|
||||
model: client.model,
|
||||
input: chunk.text,
|
||||
},
|
||||
})),
|
||||
wait: batch.wait,
|
||||
concurrency: batch.concurrency,
|
||||
pollIntervalMs: batch.pollIntervalMs,
|
||||
timeoutMs: batch.timeoutMs,
|
||||
debug: batch.debug,
|
||||
});
|
||||
return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length);
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
const geminiAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: "gemini",
|
||||
defaultModel: DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
transport: "remote",
|
||||
autoSelectPriority: 30,
|
||||
allowExplicitWhenConfiguredAuto: true,
|
||||
supportsMultimodalEmbeddings: ({ model }) => supportsGeminiMultimodalEmbeddings(model),
|
||||
shouldContinueAutoSelection: isMissingApiKeyError,
|
||||
create: async (options) => {
|
||||
const { provider, client } = await createGeminiEmbeddingProvider({
|
||||
...options,
|
||||
provider: "gemini",
|
||||
fallback: "none",
|
||||
});
|
||||
return {
|
||||
provider,
|
||||
runtime: {
|
||||
id: "gemini",
|
||||
cacheKeyData: {
|
||||
provider: "gemini",
|
||||
baseUrl: client.baseUrl,
|
||||
model: client.model,
|
||||
outputDimensionality: client.outputDimensionality,
|
||||
headers: sanitizeHeaders(client.headers, ["authorization", "x-goog-api-key"]),
|
||||
},
|
||||
batchEmbed: async (batch) => {
|
||||
if (batch.chunks.some((chunk) => hasNonTextEmbeddingParts(chunk.embeddingInput))) {
|
||||
return null;
|
||||
}
|
||||
const byCustomId = await runGeminiEmbeddingBatches({
|
||||
gemini: client,
|
||||
agentId: batch.agentId,
|
||||
requests: batch.chunks.map((chunk, index) => ({
|
||||
custom_id: String(index),
|
||||
request: buildGeminiEmbeddingRequest({
|
||||
input: chunk.embeddingInput ?? { text: chunk.text },
|
||||
taskType: "RETRIEVAL_DOCUMENT",
|
||||
modelPath: client.modelPath,
|
||||
outputDimensionality: client.outputDimensionality,
|
||||
}),
|
||||
})),
|
||||
wait: batch.wait,
|
||||
concurrency: batch.concurrency,
|
||||
pollIntervalMs: batch.pollIntervalMs,
|
||||
timeoutMs: batch.timeoutMs,
|
||||
debug: batch.debug,
|
||||
});
|
||||
return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length);
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
const voyageAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: "voyage",
|
||||
defaultModel: DEFAULT_VOYAGE_EMBEDDING_MODEL,
|
||||
transport: "remote",
|
||||
autoSelectPriority: 40,
|
||||
allowExplicitWhenConfiguredAuto: true,
|
||||
shouldContinueAutoSelection: isMissingApiKeyError,
|
||||
create: async (options) => {
|
||||
const { provider, client } = await createVoyageEmbeddingProvider({
|
||||
...options,
|
||||
provider: "voyage",
|
||||
fallback: "none",
|
||||
});
|
||||
return {
|
||||
provider,
|
||||
runtime: {
|
||||
id: "voyage",
|
||||
batchEmbed: async (batch) => {
|
||||
const byCustomId = await runVoyageEmbeddingBatches({
|
||||
client,
|
||||
agentId: batch.agentId,
|
||||
requests: batch.chunks.map((chunk, index) => ({
|
||||
custom_id: String(index),
|
||||
body: {
|
||||
input: chunk.text,
|
||||
},
|
||||
})),
|
||||
wait: batch.wait,
|
||||
concurrency: batch.concurrency,
|
||||
pollIntervalMs: batch.pollIntervalMs,
|
||||
timeoutMs: batch.timeoutMs,
|
||||
debug: batch.debug,
|
||||
});
|
||||
return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length);
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
const mistralAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: "mistral",
|
||||
defaultModel: DEFAULT_MISTRAL_EMBEDDING_MODEL,
|
||||
transport: "remote",
|
||||
autoSelectPriority: 50,
|
||||
allowExplicitWhenConfiguredAuto: true,
|
||||
shouldContinueAutoSelection: isMissingApiKeyError,
|
||||
create: async (options) => {
|
||||
const { provider, client } = await createMistralEmbeddingProvider({
|
||||
...options,
|
||||
provider: "mistral",
|
||||
fallback: "none",
|
||||
});
|
||||
return {
|
||||
provider,
|
||||
runtime: {
|
||||
id: "mistral",
|
||||
cacheKeyData: {
|
||||
provider: "mistral",
|
||||
model: client.model,
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
const ollamaAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: "ollama",
|
||||
defaultModel: DEFAULT_OLLAMA_EMBEDDING_MODEL,
|
||||
transport: "remote",
|
||||
create: async (options) => {
|
||||
const { provider, client } = await createOllamaEmbeddingProvider({
|
||||
...options,
|
||||
provider: "ollama",
|
||||
fallback: "none",
|
||||
});
|
||||
return {
|
||||
provider,
|
||||
runtime: {
|
||||
id: "ollama",
|
||||
cacheKeyData: {
|
||||
provider: "ollama",
|
||||
baseUrl: client.baseUrl,
|
||||
model: client.model,
|
||||
headers: sanitizeHeaders(client.headers, ["authorization"]),
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
const lmstudioAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: "lmstudio",
|
||||
defaultModel: DEFAULT_LMSTUDIO_EMBEDDING_MODEL,
|
||||
transport: "remote",
|
||||
create: async (options) => {
|
||||
const { provider, client } = await createLmstudioEmbeddingProvider({
|
||||
...options,
|
||||
provider: "lmstudio",
|
||||
fallback: "none",
|
||||
});
|
||||
return {
|
||||
provider,
|
||||
runtime: {
|
||||
id: "lmstudio",
|
||||
cacheKeyData: {
|
||||
provider: "lmstudio",
|
||||
baseUrl: client.baseUrl,
|
||||
model: client.model,
|
||||
headers: sanitizeHeaders(client.headers, ["authorization"]),
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
const localAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: "local",
|
||||
defaultModel: DEFAULT_LOCAL_MODEL,
|
||||
@@ -106,14 +368,24 @@ const localAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
},
|
||||
};
|
||||
|
||||
export const builtinMemoryEmbeddingProviderAdapters = [localAdapter] as const;
|
||||
export const builtinMemoryEmbeddingProviderAdapters = [
|
||||
localAdapter,
|
||||
openAiAdapter,
|
||||
geminiAdapter,
|
||||
voyageAdapter,
|
||||
mistralAdapter,
|
||||
ollamaAdapter,
|
||||
lmstudioAdapter,
|
||||
] as const;
|
||||
|
||||
export { DEFAULT_LOCAL_MODEL };
|
||||
const builtinMemoryEmbeddingProviderAdapterById = new Map(
|
||||
builtinMemoryEmbeddingProviderAdapters.map((adapter) => [adapter.id, adapter]),
|
||||
);
|
||||
|
||||
export function getBuiltinMemoryEmbeddingProviderAdapter(
|
||||
id: string,
|
||||
): MemoryEmbeddingProviderAdapter | undefined {
|
||||
return listMemoryEmbeddingProviders().find((adapter) => adapter.id === id);
|
||||
return builtinMemoryEmbeddingProviderAdapterById.get(id);
|
||||
}
|
||||
|
||||
export function registerBuiltInMemoryEmbeddingProviders(register: {
|
||||
@@ -137,7 +409,7 @@ export function getBuiltinMemoryEmbeddingProviderDoctorMetadata(
|
||||
if (!adapter) {
|
||||
return null;
|
||||
}
|
||||
const authProviderId = adapter.authProviderId ?? adapter.id;
|
||||
const authProviderId = resolveMemoryEmbeddingAuthProviderId(adapter.id);
|
||||
return {
|
||||
providerId: adapter.id,
|
||||
authProviderId,
|
||||
@@ -148,19 +420,27 @@ export function getBuiltinMemoryEmbeddingProviderDoctorMetadata(
|
||||
}
|
||||
|
||||
export function listBuiltinAutoSelectMemoryEmbeddingProviderDoctorMetadata(): Array<BuiltinMemoryEmbeddingProviderDoctorMetadata> {
|
||||
return listMemoryEmbeddingProviders()
|
||||
return builtinMemoryEmbeddingProviderAdapters
|
||||
.filter((adapter) => typeof adapter.autoSelectPriority === "number")
|
||||
.toSorted((a, b) => (a.autoSelectPriority ?? 0) - (b.autoSelectPriority ?? 0))
|
||||
.map((adapter) => {
|
||||
const authProviderId = adapter.authProviderId ?? adapter.id;
|
||||
return {
|
||||
providerId: adapter.id,
|
||||
authProviderId,
|
||||
envVars: getProviderEnvVars(authProviderId),
|
||||
transport: adapter.transport === "local" ? "local" : "remote",
|
||||
autoSelectPriority: adapter.autoSelectPriority,
|
||||
};
|
||||
});
|
||||
.map((adapter) => ({
|
||||
providerId: adapter.id,
|
||||
authProviderId: resolveMemoryEmbeddingAuthProviderId(adapter.id),
|
||||
envVars: getProviderEnvVars(resolveMemoryEmbeddingAuthProviderId(adapter.id)),
|
||||
transport: adapter.transport === "local" ? "local" : "remote",
|
||||
autoSelectPriority: adapter.autoSelectPriority,
|
||||
}));
|
||||
}
|
||||
|
||||
export { canAutoSelectLocal, formatLocalSetupError };
|
||||
export {
|
||||
DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
DEFAULT_LMSTUDIO_EMBEDDING_MODEL,
|
||||
DEFAULT_LOCAL_MODEL,
|
||||
DEFAULT_MISTRAL_EMBEDDING_MODEL,
|
||||
DEFAULT_OLLAMA_EMBEDDING_MODEL,
|
||||
DEFAULT_OPENAI_EMBEDDING_MODEL,
|
||||
DEFAULT_VOYAGE_EMBEDDING_MODEL,
|
||||
canAutoSelectLocal,
|
||||
formatLocalSetupError,
|
||||
isMissingApiKeyError,
|
||||
};
|
||||
|
||||
@@ -16,7 +16,6 @@ import {
|
||||
writeFileWithinRoot,
|
||||
type OpenClawConfig,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-foundation";
|
||||
import { resolveAgentContextLimits } from "openclaw/plugin-sdk/memory-core-host-engine-foundation";
|
||||
import {
|
||||
buildSessionEntry,
|
||||
deriveQmdScopeChannel,
|
||||
@@ -48,6 +47,7 @@ import {
|
||||
type ResolvedQmdConfig,
|
||||
type ResolvedQmdMcporterConfig,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-storage";
|
||||
import { resolveAgentContextLimits } from "openclaw/plugin-sdk/memory-core-host-engine-foundation";
|
||||
import {
|
||||
localeLowercasePreservingWhitespace,
|
||||
normalizeLowercaseStringOrEmpty,
|
||||
@@ -1945,7 +1945,8 @@ export class QmdMemoryManager implements MemorySearchManager {
|
||||
from?: number,
|
||||
lines?: number,
|
||||
): Promise<
|
||||
{ missing: true } | { missing: false; selectedLines: string[]; moreSourceLinesRemain: boolean }
|
||||
| { missing: true }
|
||||
| { missing: false; selectedLines: string[]; moreSourceLinesRemain: boolean }
|
||||
> {
|
||||
const start = Math.max(1, from ?? 1);
|
||||
const count = Math.max(1, lines ?? Number.POSITIVE_INFINITY);
|
||||
|
||||
@@ -51,7 +51,7 @@ describe("memory-wiki cli metadata entry", () => {
|
||||
const resolvedConfig = { vaultMode: "bridge", vault: { path: "/vault" } };
|
||||
mocks.resolveMemoryWikiConfig.mockReturnValue(resolvedConfig);
|
||||
|
||||
plugin.register(api);
|
||||
await plugin.register(api);
|
||||
|
||||
const register = registerCli.mock.calls[0]?.[0];
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ describe("memory-wiki plugin", () => {
|
||||
registerTool,
|
||||
} = createPluginApi();
|
||||
|
||||
plugin.register(api);
|
||||
await plugin.register(api);
|
||||
|
||||
expect(registerMemoryCorpusSupplement).toHaveBeenCalledTimes(1);
|
||||
expect(registerMemoryPromptSupplement).toHaveBeenCalledTimes(1);
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { defineSingleProviderPluginEntry } from "openclaw/plugin-sdk/provider-entry";
|
||||
import { applyMistralModelCompat } from "./api.js";
|
||||
import { mistralMediaUnderstandingProvider } from "./media-understanding-provider.js";
|
||||
import { mistralMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js";
|
||||
import { applyMistralConfig, MISTRAL_DEFAULT_MODEL_REF } from "./onboard.js";
|
||||
import { buildMistralProvider } from "./provider-catalog.js";
|
||||
import { contributeMistralResolvedModelCompat } from "./provider-compat.js";
|
||||
@@ -49,7 +48,6 @@ export default defineSingleProviderPluginEntry({
|
||||
buildReplayPolicy: () => buildMistralReplayPolicy(),
|
||||
},
|
||||
register(api) {
|
||||
api.registerMemoryEmbeddingProvider(mistralMemoryEmbeddingProviderAdapter);
|
||||
api.registerMediaUnderstandingProvider(mistralMediaUnderstandingProvider);
|
||||
},
|
||||
});
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
import {
|
||||
isMissingEmbeddingApiKeyError,
|
||||
type MemoryEmbeddingProviderAdapter,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import {
|
||||
createMistralEmbeddingProvider,
|
||||
DEFAULT_MISTRAL_EMBEDDING_MODEL,
|
||||
} from "./embedding-provider.js";
|
||||
|
||||
export const mistralMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: "mistral",
|
||||
defaultModel: DEFAULT_MISTRAL_EMBEDDING_MODEL,
|
||||
transport: "remote",
|
||||
authProviderId: "mistral",
|
||||
autoSelectPriority: 50,
|
||||
allowExplicitWhenConfiguredAuto: true,
|
||||
shouldContinueAutoSelection: isMissingEmbeddingApiKeyError,
|
||||
create: async (options) => {
|
||||
const { provider, client } = await createMistralEmbeddingProvider({
|
||||
...options,
|
||||
provider: "mistral",
|
||||
fallback: "none",
|
||||
});
|
||||
return {
|
||||
provider,
|
||||
runtime: {
|
||||
id: "mistral",
|
||||
cacheKeyData: {
|
||||
provider: "mistral",
|
||||
model: client.model,
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
@@ -21,7 +21,6 @@
|
||||
}
|
||||
],
|
||||
"contracts": {
|
||||
"memoryEmbeddingProviders": ["mistral"],
|
||||
"mediaUnderstandingProviders": ["mistral"]
|
||||
},
|
||||
"configSchema": {
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
export { createAuthRateLimiter } from "openclaw/plugin-sdk/nextcloud-talk";
|
||||
@@ -11,9 +11,13 @@ const hoisted = vi.hoisted(() => ({
|
||||
monitorNextcloudTalkProvider: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("./monitor-runtime.js", () => ({
|
||||
monitorNextcloudTalkProvider: hoisted.monitorNextcloudTalkProvider,
|
||||
}));
|
||||
vi.mock("./monitor.js", async () => {
|
||||
const actual = await vi.importActual<typeof import("./monitor.js")>("./monitor.js");
|
||||
return {
|
||||
...actual,
|
||||
monitorNextcloudTalkProvider: hoisted.monitorNextcloudTalkProvider,
|
||||
};
|
||||
});
|
||||
|
||||
const { nextcloudTalkGatewayAdapter } = await import("./gateway.js");
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import {
|
||||
type ChannelPlugin,
|
||||
type OpenClawConfig,
|
||||
} from "./channel-api.js";
|
||||
import { monitorNextcloudTalkProvider } from "./monitor-runtime.js";
|
||||
import { monitorNextcloudTalkProvider } from "./monitor.js";
|
||||
import { getNextcloudTalkRuntime } from "./runtime.js";
|
||||
import type { CoreConfig } from "./types.js";
|
||||
|
||||
|
||||
@@ -1,138 +0,0 @@
|
||||
import os from "node:os";
|
||||
import { resolveLoggerBackedRuntime } from "openclaw/plugin-sdk/extension-shared";
|
||||
import type { RuntimeEnv } from "openclaw/plugin-sdk/runtime";
|
||||
import { normalizeLowercaseStringOrEmpty } from "openclaw/plugin-sdk/text-runtime";
|
||||
import { resolveNextcloudTalkAccount } from "./accounts.js";
|
||||
import { handleNextcloudTalkInbound } from "./inbound.js";
|
||||
import {
|
||||
createNextcloudTalkWebhookServer,
|
||||
processNextcloudTalkReplayGuardedMessage,
|
||||
} from "./monitor.js";
|
||||
import { createNextcloudTalkReplayGuard } from "./replay-guard.js";
|
||||
import { getNextcloudTalkRuntime } from "./runtime.js";
|
||||
import type { CoreConfig, NextcloudTalkInboundMessage } from "./types.js";
|
||||
|
||||
const DEFAULT_WEBHOOK_PORT = 8788;
|
||||
const DEFAULT_WEBHOOK_HOST = "0.0.0.0";
|
||||
const DEFAULT_WEBHOOK_PATH = "/nextcloud-talk-webhook";
|
||||
|
||||
function normalizeOrigin(value: string): string | null {
|
||||
try {
|
||||
return normalizeLowercaseStringOrEmpty(new URL(value).origin);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
export type NextcloudTalkMonitorOptions = {
|
||||
accountId?: string;
|
||||
config?: CoreConfig;
|
||||
runtime?: RuntimeEnv;
|
||||
abortSignal?: AbortSignal;
|
||||
onMessage?: (message: NextcloudTalkInboundMessage) => void | Promise<void>;
|
||||
statusSink?: (patch: { lastInboundAt?: number; lastOutboundAt?: number }) => void;
|
||||
};
|
||||
|
||||
export async function monitorNextcloudTalkProvider(
|
||||
opts: NextcloudTalkMonitorOptions,
|
||||
): Promise<{ stop: () => void }> {
|
||||
const core = getNextcloudTalkRuntime();
|
||||
const cfg = opts.config ?? (core.config.loadConfig() as CoreConfig);
|
||||
const account = resolveNextcloudTalkAccount({
|
||||
cfg,
|
||||
accountId: opts.accountId,
|
||||
});
|
||||
const runtime: RuntimeEnv = resolveLoggerBackedRuntime(
|
||||
opts.runtime,
|
||||
core.logging.getChildLogger(),
|
||||
);
|
||||
|
||||
if (!account.secret) {
|
||||
throw new Error(`Nextcloud Talk bot secret not configured for account "${account.accountId}"`);
|
||||
}
|
||||
|
||||
const port = account.config.webhookPort ?? DEFAULT_WEBHOOK_PORT;
|
||||
const host = account.config.webhookHost ?? DEFAULT_WEBHOOK_HOST;
|
||||
const path = account.config.webhookPath ?? DEFAULT_WEBHOOK_PATH;
|
||||
|
||||
const logger = core.logging.getChildLogger({
|
||||
channel: "nextcloud-talk",
|
||||
accountId: account.accountId,
|
||||
});
|
||||
const expectedBackendOrigin = normalizeOrigin(account.baseUrl);
|
||||
const replayGuard = createNextcloudTalkReplayGuard({
|
||||
stateDir: core.state.resolveStateDir(process.env, os.homedir),
|
||||
onDiskError: (error) => {
|
||||
logger.warn(
|
||||
`[nextcloud-talk:${account.accountId}] replay guard disk error: ${String(error)}`,
|
||||
);
|
||||
},
|
||||
});
|
||||
|
||||
const { start, stop } = createNextcloudTalkWebhookServer({
|
||||
port,
|
||||
host,
|
||||
path,
|
||||
secret: account.secret,
|
||||
isBackendAllowed: (backend) => {
|
||||
if (!expectedBackendOrigin) {
|
||||
return true;
|
||||
}
|
||||
const backendOrigin = normalizeOrigin(backend);
|
||||
return backendOrigin === expectedBackendOrigin;
|
||||
},
|
||||
processMessage: async (message) => {
|
||||
const result = await processNextcloudTalkReplayGuardedMessage({
|
||||
replayGuard,
|
||||
accountId: account.accountId,
|
||||
message,
|
||||
handleMessage: async () => {
|
||||
core.channel.activity.record({
|
||||
channel: "nextcloud-talk",
|
||||
accountId: account.accountId,
|
||||
direction: "inbound",
|
||||
at: message.timestamp,
|
||||
});
|
||||
if (opts.onMessage) {
|
||||
await opts.onMessage(message);
|
||||
} else {
|
||||
await handleNextcloudTalkInbound({
|
||||
message,
|
||||
account,
|
||||
config: cfg,
|
||||
runtime,
|
||||
statusSink: opts.statusSink,
|
||||
});
|
||||
}
|
||||
},
|
||||
});
|
||||
if (result === "duplicate") {
|
||||
logger.warn(
|
||||
`[nextcloud-talk:${account.accountId}] replayed webhook ignored room=${message.roomToken} messageId=${message.messageId}`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
},
|
||||
onMessage: async () => {},
|
||||
onError: (error) => {
|
||||
logger.error(`[nextcloud-talk:${account.accountId}] webhook error: ${error.message}`);
|
||||
},
|
||||
abortSignal: opts.abortSignal,
|
||||
});
|
||||
|
||||
if (opts.abortSignal?.aborted) {
|
||||
return { stop };
|
||||
}
|
||||
await start();
|
||||
if (opts.abortSignal?.aborted) {
|
||||
stop();
|
||||
return { stop };
|
||||
}
|
||||
|
||||
const publicUrl =
|
||||
account.config.webhookPublicUrl ??
|
||||
`http://${host === "0.0.0.0" ? "localhost" : host}:${port}${path}`;
|
||||
logger.info(`[nextcloud-talk:${account.accountId}] webhook listening on ${publicUrl}`);
|
||||
|
||||
return { stop };
|
||||
}
|
||||
@@ -3,6 +3,7 @@ import os from "node:os";
|
||||
import path from "node:path";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { createMockIncomingRequest } from "../../../test/helpers/mock-incoming-request.js";
|
||||
import { WEBHOOK_RATE_LIMIT_DEFAULTS } from "../runtime-api.js";
|
||||
import {
|
||||
NextcloudTalkRetryableWebhookError,
|
||||
processNextcloudTalkReplayGuardedMessage,
|
||||
@@ -273,10 +274,8 @@ describe("createNextcloudTalkWebhookServer payload validation", () => {
|
||||
|
||||
describe("createNextcloudTalkWebhookServer auth rate limiting", () => {
|
||||
it("rate limits repeated invalid signature attempts from the same source", async () => {
|
||||
const maxRequests = 2;
|
||||
const harness = await startWebhookServer({
|
||||
path: "/nextcloud-auth-rate-limit",
|
||||
authRateLimit: { maxRequests },
|
||||
onMessage: vi.fn(),
|
||||
});
|
||||
const { body, headers } = createSignedCreateMessageRequest();
|
||||
@@ -287,7 +286,7 @@ describe("createNextcloudTalkWebhookServer auth rate limiting", () => {
|
||||
|
||||
let firstResponse: Response | undefined;
|
||||
let lastResponse: Response | undefined;
|
||||
for (let attempt = 0; attempt <= maxRequests; attempt += 1) {
|
||||
for (let attempt = 0; attempt <= WEBHOOK_RATE_LIMIT_DEFAULTS.maxRequests; attempt += 1) {
|
||||
const response = await fetch(harness.webhookUrl, {
|
||||
method: "POST",
|
||||
headers: invalidHeaders,
|
||||
@@ -307,16 +306,14 @@ describe("createNextcloudTalkWebhookServer auth rate limiting", () => {
|
||||
});
|
||||
|
||||
it("does not rate limit valid signed webhook bursts from the same source", async () => {
|
||||
const maxRequests = 2;
|
||||
const harness = await startWebhookServer({
|
||||
path: "/nextcloud-auth-rate-limit-valid",
|
||||
authRateLimit: { maxRequests },
|
||||
onMessage: vi.fn(),
|
||||
});
|
||||
const { body, headers } = createSignedCreateMessageRequest();
|
||||
|
||||
let lastResponse: Response | undefined;
|
||||
for (let attempt = 0; attempt <= maxRequests; attempt += 1) {
|
||||
for (let attempt = 0; attempt <= WEBHOOK_RATE_LIMIT_DEFAULTS.maxRequests; attempt += 1) {
|
||||
lastResponse = await fetch(harness.webhookUrl, {
|
||||
method: "POST",
|
||||
headers,
|
||||
|
||||
@@ -1,22 +1,35 @@
|
||||
import { createServer, type IncomingMessage, type Server, type ServerResponse } from "node:http";
|
||||
import { safeParseJsonWithSchema } from "openclaw/plugin-sdk/extension-shared";
|
||||
import os from "node:os";
|
||||
import {
|
||||
resolveLoggerBackedRuntime,
|
||||
safeParseJsonWithSchema,
|
||||
} from "openclaw/plugin-sdk/extension-shared";
|
||||
import { normalizeLowercaseStringOrEmpty } from "openclaw/plugin-sdk/text-runtime";
|
||||
import { z } from "zod";
|
||||
import {
|
||||
WEBHOOK_RATE_LIMIT_DEFAULTS,
|
||||
createAuthRateLimiter,
|
||||
type RuntimeEnv,
|
||||
isRequestBodyLimitError,
|
||||
readRequestBodyWithLimit,
|
||||
requestBodyErrorToText,
|
||||
} from "openclaw/plugin-sdk/webhook-ingress";
|
||||
import { z } from "zod";
|
||||
import { createAuthRateLimiter } from "./api.js";
|
||||
import type { NextcloudTalkReplayGuard } from "./replay-guard.js";
|
||||
} from "../runtime-api.js";
|
||||
import { resolveNextcloudTalkAccount } from "./accounts.js";
|
||||
import { handleNextcloudTalkInbound } from "./inbound.js";
|
||||
import { createNextcloudTalkReplayGuard, type NextcloudTalkReplayGuard } from "./replay-guard.js";
|
||||
import { getNextcloudTalkRuntime } from "./runtime.js";
|
||||
import { extractNextcloudTalkHeaders, verifyNextcloudTalkSignature } from "./signature.js";
|
||||
import type {
|
||||
CoreConfig,
|
||||
NextcloudTalkInboundMessage,
|
||||
NextcloudTalkWebhookHeaders,
|
||||
NextcloudTalkWebhookPayload,
|
||||
NextcloudTalkWebhookServerOptions,
|
||||
} from "./types.js";
|
||||
|
||||
const DEFAULT_WEBHOOK_PORT = 8788;
|
||||
const DEFAULT_WEBHOOK_HOST = "0.0.0.0";
|
||||
const DEFAULT_WEBHOOK_PATH = "/nextcloud-talk-webhook";
|
||||
const DEFAULT_WEBHOOK_MAX_BODY_BYTES = 1024 * 1024;
|
||||
const PREAUTH_WEBHOOK_MAX_BODY_BYTES = 64 * 1024;
|
||||
const PREAUTH_WEBHOOK_BODY_TIMEOUT_MS = 5_000;
|
||||
@@ -109,6 +122,14 @@ function formatError(err: unknown): string {
|
||||
return typeof err === "string" ? err : JSON.stringify(err);
|
||||
}
|
||||
|
||||
function normalizeOrigin(value: string): string | null {
|
||||
try {
|
||||
return normalizeLowercaseStringOrEmpty(new URL(value).origin);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function parseWebhookPayload(body: string): NextcloudTalkWebhookPayload | null {
|
||||
return safeParseJsonWithSchema(NextcloudTalkWebhookPayloadSchema, body);
|
||||
}
|
||||
@@ -241,20 +262,12 @@ export function createNextcloudTalkWebhookServer(opts: NextcloudTalkWebhookServe
|
||||
const isBackendAllowed = opts.isBackendAllowed;
|
||||
const shouldProcessMessage = opts.shouldProcessMessage;
|
||||
const processMessage = opts.processMessage;
|
||||
const authRateLimitMaxRequests =
|
||||
typeof opts.authRateLimit?.maxRequests === "number"
|
||||
? opts.authRateLimit.maxRequests
|
||||
: WEBHOOK_RATE_LIMIT_DEFAULTS.maxRequests;
|
||||
const authRateLimitWindowMs =
|
||||
typeof opts.authRateLimit?.windowMs === "number"
|
||||
? opts.authRateLimit.windowMs
|
||||
: WEBHOOK_RATE_LIMIT_DEFAULTS.windowMs;
|
||||
const webhookAuthRateLimiter = createAuthRateLimiter({
|
||||
maxAttempts: authRateLimitMaxRequests,
|
||||
windowMs: authRateLimitWindowMs,
|
||||
lockoutMs: authRateLimitWindowMs,
|
||||
maxAttempts: WEBHOOK_RATE_LIMIT_DEFAULTS.maxRequests,
|
||||
windowMs: WEBHOOK_RATE_LIMIT_DEFAULTS.windowMs,
|
||||
lockoutMs: WEBHOOK_RATE_LIMIT_DEFAULTS.windowMs,
|
||||
exemptLoopback: false,
|
||||
pruneIntervalMs: authRateLimitWindowMs,
|
||||
pruneIntervalMs: WEBHOOK_RATE_LIMIT_DEFAULTS.windowMs,
|
||||
});
|
||||
|
||||
const server = createServer(async (req: IncomingMessage, res: ServerResponse) => {
|
||||
@@ -383,3 +396,116 @@ export function createNextcloudTalkWebhookServer(opts: NextcloudTalkWebhookServe
|
||||
|
||||
return { server, start, stop };
|
||||
}
|
||||
|
||||
export type NextcloudTalkMonitorOptions = {
|
||||
accountId?: string;
|
||||
config?: CoreConfig;
|
||||
runtime?: RuntimeEnv;
|
||||
abortSignal?: AbortSignal;
|
||||
onMessage?: (message: NextcloudTalkInboundMessage) => void | Promise<void>;
|
||||
statusSink?: (patch: { lastInboundAt?: number; lastOutboundAt?: number }) => void;
|
||||
};
|
||||
|
||||
export async function monitorNextcloudTalkProvider(
|
||||
opts: NextcloudTalkMonitorOptions,
|
||||
): Promise<{ stop: () => void }> {
|
||||
const core = getNextcloudTalkRuntime();
|
||||
const cfg = opts.config ?? (core.config.loadConfig() as CoreConfig);
|
||||
const account = resolveNextcloudTalkAccount({
|
||||
cfg,
|
||||
accountId: opts.accountId,
|
||||
});
|
||||
const runtime: RuntimeEnv = resolveLoggerBackedRuntime(
|
||||
opts.runtime,
|
||||
core.logging.getChildLogger(),
|
||||
);
|
||||
|
||||
if (!account.secret) {
|
||||
throw new Error(`Nextcloud Talk bot secret not configured for account "${account.accountId}"`);
|
||||
}
|
||||
|
||||
const port = account.config.webhookPort ?? DEFAULT_WEBHOOK_PORT;
|
||||
const host = account.config.webhookHost ?? DEFAULT_WEBHOOK_HOST;
|
||||
const path = account.config.webhookPath ?? DEFAULT_WEBHOOK_PATH;
|
||||
|
||||
const logger = core.logging.getChildLogger({
|
||||
channel: "nextcloud-talk",
|
||||
accountId: account.accountId,
|
||||
});
|
||||
const expectedBackendOrigin = normalizeOrigin(account.baseUrl);
|
||||
const replayGuard = createNextcloudTalkReplayGuard({
|
||||
stateDir: core.state.resolveStateDir(process.env, os.homedir),
|
||||
onDiskError: (error) => {
|
||||
logger.warn(
|
||||
`[nextcloud-talk:${account.accountId}] replay guard disk error: ${String(error)}`,
|
||||
);
|
||||
},
|
||||
});
|
||||
|
||||
const { start, stop } = createNextcloudTalkWebhookServer({
|
||||
port,
|
||||
host,
|
||||
path,
|
||||
secret: account.secret,
|
||||
isBackendAllowed: (backend) => {
|
||||
if (!expectedBackendOrigin) {
|
||||
return true;
|
||||
}
|
||||
const backendOrigin = normalizeOrigin(backend);
|
||||
return backendOrigin === expectedBackendOrigin;
|
||||
},
|
||||
processMessage: async (message) => {
|
||||
const result = await processNextcloudTalkReplayGuardedMessage({
|
||||
replayGuard,
|
||||
accountId: account.accountId,
|
||||
message,
|
||||
handleMessage: async () => {
|
||||
core.channel.activity.record({
|
||||
channel: "nextcloud-talk",
|
||||
accountId: account.accountId,
|
||||
direction: "inbound",
|
||||
at: message.timestamp,
|
||||
});
|
||||
if (opts.onMessage) {
|
||||
await opts.onMessage(message);
|
||||
} else {
|
||||
await handleNextcloudTalkInbound({
|
||||
message,
|
||||
account,
|
||||
config: cfg,
|
||||
runtime,
|
||||
statusSink: opts.statusSink,
|
||||
});
|
||||
}
|
||||
},
|
||||
});
|
||||
if (result === "duplicate") {
|
||||
logger.warn(
|
||||
`[nextcloud-talk:${account.accountId}] replayed webhook ignored room=${message.roomToken} messageId=${message.messageId}`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
},
|
||||
onMessage: async () => {},
|
||||
onError: (error) => {
|
||||
logger.error(`[nextcloud-talk:${account.accountId}] webhook error: ${error.message}`);
|
||||
},
|
||||
abortSignal: opts.abortSignal,
|
||||
});
|
||||
|
||||
if (opts.abortSignal?.aborted) {
|
||||
return { stop };
|
||||
}
|
||||
await start();
|
||||
if (opts.abortSignal?.aborted) {
|
||||
stop();
|
||||
return { stop };
|
||||
}
|
||||
|
||||
const publicUrl =
|
||||
account.config.webhookPublicUrl ??
|
||||
`http://${host === "0.0.0.0" ? "localhost" : host}:${port}${path}`;
|
||||
logger.info(`[nextcloud-talk:${account.accountId}] webhook listening on ${publicUrl}`);
|
||||
|
||||
return { stop };
|
||||
}
|
||||
|
||||
@@ -179,10 +179,6 @@ export type NextcloudTalkWebhookServerOptions = {
|
||||
path: string;
|
||||
secret: string;
|
||||
maxBodyBytes?: number;
|
||||
authRateLimit?: {
|
||||
maxRequests?: number;
|
||||
windowMs?: number;
|
||||
};
|
||||
readBody?: (req: import("node:http").IncomingMessage, maxBodyBytes: number) => Promise<string>;
|
||||
isBackendAllowed?: (backend: string) => boolean;
|
||||
shouldProcessMessage?: (message: NextcloudTalkInboundMessage) => boolean | Promise<boolean>;
|
||||
|
||||
@@ -8,7 +8,6 @@ export const ollamaMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapte
|
||||
id: "ollama",
|
||||
defaultModel: DEFAULT_OLLAMA_EMBEDDING_MODEL,
|
||||
transport: "remote",
|
||||
authProviderId: "ollama",
|
||||
create: async (options) => {
|
||||
const { provider, client } = await createOllamaEmbeddingProvider({
|
||||
...options,
|
||||
|
||||
@@ -54,7 +54,7 @@ const _registerOpenAIPlugin = async () =>
|
||||
async function registerOpenAIPluginWithHook(params?: { pluginConfig?: Record<string, unknown> }) {
|
||||
const on = vi.fn();
|
||||
const providers: ProviderPlugin[] = [];
|
||||
plugin.register(
|
||||
await plugin.register(
|
||||
createTestPluginApi({
|
||||
id: "openai",
|
||||
name: "OpenAI Provider",
|
||||
|
||||
@@ -6,7 +6,6 @@ import {
|
||||
openaiCodexMediaUnderstandingProvider,
|
||||
openaiMediaUnderstandingProvider,
|
||||
} from "./media-understanding-provider.js";
|
||||
import { openAiMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js";
|
||||
import { buildOpenAICodexProviderPlugin } from "./openai-codex-provider.js";
|
||||
import { buildOpenAIProvider } from "./openai-provider.js";
|
||||
import {
|
||||
@@ -40,7 +39,6 @@ export default definePluginEntry({
|
||||
api.registerCliBackend(buildOpenAICodexCliBackend());
|
||||
api.registerProvider(buildProviderWithPromptContribution(buildOpenAIProvider()));
|
||||
api.registerProvider(buildProviderWithPromptContribution(buildOpenAICodexProviderPlugin()));
|
||||
api.registerMemoryEmbeddingProvider(openAiMemoryEmbeddingProviderAdapter);
|
||||
api.registerImageGenerationProvider(buildOpenAIImageGenerationProvider());
|
||||
api.registerRealtimeTranscriptionProvider(buildOpenAIRealtimeTranscriptionProvider());
|
||||
api.registerRealtimeVoiceProvider(buildOpenAIRealtimeVoiceProvider());
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
import {
|
||||
isMissingEmbeddingApiKeyError,
|
||||
mapBatchEmbeddingsByIndex,
|
||||
sanitizeEmbeddingCacheHeaders,
|
||||
type MemoryEmbeddingProviderAdapter,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import { OPENAI_BATCH_ENDPOINT, runOpenAiEmbeddingBatches } from "./embedding-batch.js";
|
||||
import {
|
||||
createOpenAiEmbeddingProvider,
|
||||
DEFAULT_OPENAI_EMBEDDING_MODEL,
|
||||
} from "./embedding-provider.js";
|
||||
|
||||
export const openAiMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: "openai",
|
||||
defaultModel: DEFAULT_OPENAI_EMBEDDING_MODEL,
|
||||
transport: "remote",
|
||||
authProviderId: "openai",
|
||||
autoSelectPriority: 20,
|
||||
allowExplicitWhenConfiguredAuto: true,
|
||||
shouldContinueAutoSelection: isMissingEmbeddingApiKeyError,
|
||||
create: async (options) => {
|
||||
const { provider, client } = await createOpenAiEmbeddingProvider({
|
||||
...options,
|
||||
provider: "openai",
|
||||
fallback: "none",
|
||||
});
|
||||
return {
|
||||
provider,
|
||||
runtime: {
|
||||
id: "openai",
|
||||
cacheKeyData: {
|
||||
provider: "openai",
|
||||
baseUrl: client.baseUrl,
|
||||
model: client.model,
|
||||
headers: sanitizeEmbeddingCacheHeaders(client.headers, ["authorization"]),
|
||||
},
|
||||
batchEmbed: async (batch) => {
|
||||
const byCustomId = await runOpenAiEmbeddingBatches({
|
||||
openAi: client,
|
||||
agentId: batch.agentId,
|
||||
requests: batch.chunks.map((chunk, index) => ({
|
||||
custom_id: String(index),
|
||||
method: "POST",
|
||||
url: OPENAI_BATCH_ENDPOINT,
|
||||
body: {
|
||||
model: client.model,
|
||||
input: chunk.text,
|
||||
},
|
||||
})),
|
||||
wait: batch.wait,
|
||||
concurrency: batch.concurrency,
|
||||
pollIntervalMs: batch.pollIntervalMs,
|
||||
timeoutMs: batch.timeoutMs,
|
||||
debug: batch.debug,
|
||||
});
|
||||
return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length);
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
@@ -39,7 +39,6 @@
|
||||
"speechProviders": ["openai"],
|
||||
"realtimeTranscriptionProviders": ["openai"],
|
||||
"realtimeVoiceProviders": ["openai"],
|
||||
"memoryEmbeddingProviders": ["openai"],
|
||||
"mediaUnderstandingProviders": ["openai", "openai-codex"],
|
||||
"imageGenerationProviders": ["openai"],
|
||||
"videoGenerationProviders": ["openai"]
|
||||
|
||||
@@ -80,7 +80,7 @@ async function withRegisteredPhoneControl(
|
||||
});
|
||||
|
||||
let command: OpenClawPluginCommandDefinition | undefined;
|
||||
registerPhoneControl.register(
|
||||
void registerPhoneControl.register(
|
||||
createApi({
|
||||
stateDir,
|
||||
getConfig: () => config,
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
"description": "OpenClaw QA lab plugin with private debugger UI and scenario runner",
|
||||
"type": "module",
|
||||
"dependencies": {
|
||||
"@copilotkit/aimock": "1.13.0",
|
||||
"playwright-core": "1.59.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { createPluginSetupWizardStatus } from "../../../test/helpers/plugins/setup-wizard.js";
|
||||
import type { ResolvedSynologyChatAccount } from "./types.js";
|
||||
|
||||
@@ -42,18 +42,12 @@ const getSynologyChatSetupStatus = createPluginSetupWizardStatus(synologyChatPlu
|
||||
|
||||
describe("createSynologyChatPlugin", () => {
|
||||
beforeEach(() => {
|
||||
vi.stubEnv("SYNOLOGY_CHAT_TOKEN", "");
|
||||
vi.stubEnv("SYNOLOGY_CHAT_INCOMING_URL", "");
|
||||
mockSendMessage.mockClear();
|
||||
registerSynologyWebhookRouteMock.mockClear();
|
||||
mockSendMessage.mockResolvedValue(true);
|
||||
registerSynologyWebhookRouteMock.mockImplementation(() => vi.fn());
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.unstubAllEnvs();
|
||||
});
|
||||
|
||||
describe("meta", () => {
|
||||
it("has correct id and label", () => {
|
||||
const plugin = createSynologyChatPlugin();
|
||||
@@ -486,17 +480,11 @@ describe("createSynologyChatPlugin", () => {
|
||||
abortController: AbortController,
|
||||
) {
|
||||
expect(result).toBeInstanceOf(Promise);
|
||||
let settled = false;
|
||||
void result.then(
|
||||
() => {
|
||||
settled = true;
|
||||
},
|
||||
() => {
|
||||
settled = true;
|
||||
},
|
||||
);
|
||||
await Promise.resolve();
|
||||
expect(settled).toBe(false);
|
||||
const resolved = await Promise.race([
|
||||
result,
|
||||
new Promise((r) => setTimeout(() => r("pending"), 50)),
|
||||
]);
|
||||
expect(resolved).toBe("pending");
|
||||
abortController.abort();
|
||||
await result;
|
||||
}
|
||||
@@ -596,6 +584,8 @@ describe("createSynologyChatPlugin", () => {
|
||||
const firstPromise = plugin.gateway.startAccount(makeCtx(abortFirst));
|
||||
const secondPromise = plugin.gateway.startAccount(makeCtx(abortSecond));
|
||||
|
||||
await new Promise((r) => setTimeout(r, 10));
|
||||
|
||||
expect(registerMock).toHaveBeenCalledTimes(2);
|
||||
expect(unregisterFirst).not.toHaveBeenCalled();
|
||||
expect(unregisterSecond).not.toHaveBeenCalled();
|
||||
|
||||
@@ -144,19 +144,26 @@ describe("createWebhookHandler", () => {
|
||||
});
|
||||
|
||||
it("returns 408 when request body times out", async () => {
|
||||
const handler = createWebhookHandler({
|
||||
account: makeAccount(),
|
||||
deliver: vi.fn(),
|
||||
log,
|
||||
bodyTimeoutMs: 1,
|
||||
});
|
||||
vi.useFakeTimers();
|
||||
try {
|
||||
const handler = createWebhookHandler({
|
||||
account: makeAccount(),
|
||||
deliver: vi.fn(),
|
||||
log,
|
||||
});
|
||||
|
||||
const req = makeStalledReq("POST");
|
||||
const res = makeRes();
|
||||
await handler(req, res);
|
||||
const req = makeStalledReq("POST");
|
||||
const res = makeRes();
|
||||
const run = handler(req, res);
|
||||
|
||||
expect(res._status).toBe(408);
|
||||
expect(res._body).toContain("timeout");
|
||||
await vi.advanceTimersByTimeAsync(30_000);
|
||||
await run;
|
||||
|
||||
expect(res._status).toBe(408);
|
||||
expect(res._body).toContain("timeout");
|
||||
} finally {
|
||||
vi.useRealTimers();
|
||||
}
|
||||
});
|
||||
|
||||
it("rejects excess concurrent pre-auth body reads from the same remote IP", async () => {
|
||||
|
||||
@@ -142,10 +142,7 @@ function getSynologyWebhookInFlightKey(account: ResolvedSynologyChatAccount): st
|
||||
}
|
||||
|
||||
/** Read the full request body as a string. */
|
||||
async function readBody(
|
||||
req: IncomingMessage,
|
||||
timeoutMs = PREAUTH_BODY_TIMEOUT_MS,
|
||||
): Promise<
|
||||
async function readBody(req: IncomingMessage): Promise<
|
||||
| { ok: true; body: string }
|
||||
| {
|
||||
ok: false;
|
||||
@@ -156,7 +153,7 @@ async function readBody(
|
||||
try {
|
||||
const body = await readRequestBodyWithLimit(req, {
|
||||
maxBytes: PREAUTH_MAX_BODY_BYTES,
|
||||
timeoutMs,
|
||||
timeoutMs: PREAUTH_BODY_TIMEOUT_MS,
|
||||
});
|
||||
return { ok: true, body };
|
||||
} catch (err) {
|
||||
@@ -345,7 +342,6 @@ export interface WebhookHandlerDeps {
|
||||
warn: (...args: unknown[]) => void;
|
||||
error: (...args: unknown[]) => void;
|
||||
};
|
||||
bodyTimeoutMs?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -375,9 +371,8 @@ async function parseWebhookPayloadRequest(params: {
|
||||
req: IncomingMessage;
|
||||
res: ServerResponse;
|
||||
log?: WebhookHandlerDeps["log"];
|
||||
bodyTimeoutMs?: number;
|
||||
}): Promise<{ ok: false } | { ok: true; payload: SynologyWebhookPayload }> {
|
||||
const bodyResult = await readBody(params.req, params.bodyTimeoutMs);
|
||||
const bodyResult = await readBody(params.req);
|
||||
if (!bodyResult.ok) {
|
||||
params.log?.error("Failed to read request body", bodyResult.error);
|
||||
respondJson(params.res, bodyResult.statusCode, { error: bodyResult.error });
|
||||
@@ -470,7 +465,6 @@ async function parseAndAuthorizeSynologyWebhook(params: {
|
||||
invalidTokenRateLimiter: InvalidTokenRateLimiter;
|
||||
rateLimiter: RateLimiter;
|
||||
log?: WebhookHandlerDeps["log"];
|
||||
bodyTimeoutMs?: number;
|
||||
}): Promise<{ ok: false } | { ok: true; message: AuthorizedSynologyWebhook }> {
|
||||
const parsed = await parseWebhookPayloadRequest(params);
|
||||
if (!parsed.ok) {
|
||||
@@ -618,7 +612,6 @@ export function createWebhookHandler(deps: WebhookHandlerDeps) {
|
||||
invalidTokenRateLimiter,
|
||||
rateLimiter,
|
||||
log,
|
||||
bodyTimeoutMs: deps.bodyTimeoutMs,
|
||||
});
|
||||
} finally {
|
||||
// Only bound the pre-auth request pipeline; async reply delivery is outside webhook ingress.
|
||||
|
||||
@@ -20,7 +20,7 @@ function createHarness(config: Record<string, unknown>) {
|
||||
command = definition;
|
||||
}),
|
||||
};
|
||||
register.register(api as never);
|
||||
void register.register(api as never);
|
||||
if (!command) {
|
||||
throw new Error("talk-voice command not registered");
|
||||
}
|
||||
|
||||
@@ -800,82 +800,6 @@ describe("dispatchTelegramMessage draft streaming", () => {
|
||||
);
|
||||
});
|
||||
|
||||
it("preserves pre-rotation skip until queued message-start callbacks flush", async () => {
|
||||
const answerDraftStream = createSequencedDraftStream(1001);
|
||||
const reasoningDraftStream = createDraftStream();
|
||||
createTelegramDraftStream
|
||||
.mockImplementationOnce(() => answerDraftStream)
|
||||
.mockImplementationOnce(() => reasoningDraftStream);
|
||||
dispatchReplyWithBufferedBlockDispatcher.mockImplementation(
|
||||
async ({ dispatcherOptions, replyOptions }) => {
|
||||
await replyOptions?.onPartialReply?.({ text: "Message A partial" });
|
||||
await dispatcherOptions.deliver({ text: "Message A final" }, { kind: "final" });
|
||||
await replyOptions?.onPartialReply?.({ text: "Message B early" });
|
||||
void replyOptions?.onAssistantMessageStart?.();
|
||||
await dispatcherOptions.deliver({ text: "Message B final" }, { kind: "final" });
|
||||
return { queuedFinal: true };
|
||||
},
|
||||
);
|
||||
deliverReplies.mockResolvedValue({ delivered: true });
|
||||
editMessageTelegram.mockResolvedValue({ ok: true, chatId: "123", messageId: "1001" });
|
||||
|
||||
await dispatchWithContext({ context: createContext(), streamMode: "partial" });
|
||||
|
||||
expect(answerDraftStream.forceNewMessage).toHaveBeenCalledTimes(1);
|
||||
expect(editMessageTelegram).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
123,
|
||||
1001,
|
||||
"Message A final",
|
||||
expect.any(Object),
|
||||
);
|
||||
expect(editMessageTelegram).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
123,
|
||||
1002,
|
||||
"Message B final",
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it("does not double-rotate when assistant_message_start arrives after final delivery drains", async () => {
|
||||
const answerDraftStream = createSequencedDraftStream(1001);
|
||||
const reasoningDraftStream = createDraftStream();
|
||||
createTelegramDraftStream
|
||||
.mockImplementationOnce(() => answerDraftStream)
|
||||
.mockImplementationOnce(() => reasoningDraftStream);
|
||||
dispatchReplyWithBufferedBlockDispatcher.mockImplementation(
|
||||
async ({ dispatcherOptions, replyOptions }) => {
|
||||
await replyOptions?.onPartialReply?.({ text: "Message A partial" });
|
||||
await dispatcherOptions.deliver({ text: "Message A final" }, { kind: "final" });
|
||||
await replyOptions?.onPartialReply?.({ text: "Message B early" });
|
||||
await dispatcherOptions.deliver({ text: "Message B final" }, { kind: "final" });
|
||||
await replyOptions?.onAssistantMessageStart?.();
|
||||
return { queuedFinal: true };
|
||||
},
|
||||
);
|
||||
deliverReplies.mockResolvedValue({ delivered: true });
|
||||
editMessageTelegram.mockResolvedValue({ ok: true, chatId: "123", messageId: "1001" });
|
||||
|
||||
await dispatchWithContext({ context: createContext(), streamMode: "partial" });
|
||||
|
||||
expect(answerDraftStream.forceNewMessage).toHaveBeenCalledTimes(1);
|
||||
expect(editMessageTelegram).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
123,
|
||||
1001,
|
||||
"Message A final",
|
||||
expect.any(Object),
|
||||
);
|
||||
expect(editMessageTelegram).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
123,
|
||||
1002,
|
||||
"Message B final",
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it("clears active preview even when an unrelated boundary archive exists", async () => {
|
||||
const answerDraftStream = createDraftStream(999);
|
||||
answerDraftStream.materialize.mockResolvedValue(4321);
|
||||
@@ -1130,204 +1054,6 @@ describe("dispatchTelegramMessage draft streaming", () => {
|
||||
expect(answerDraftStream.update).toHaveBeenNthCalledWith(2, "Message B second chunk");
|
||||
});
|
||||
|
||||
it("does not rotate the streamed preview when compaction retries replay the same assistant message", async () => {
|
||||
const answerDraftStream = createSequencedDraftStream(1001);
|
||||
const reasoningDraftStream = createDraftStream();
|
||||
createTelegramDraftStream
|
||||
.mockImplementationOnce(() => answerDraftStream)
|
||||
.mockImplementationOnce(() => reasoningDraftStream);
|
||||
dispatchReplyWithBufferedBlockDispatcher.mockImplementation(
|
||||
async ({ dispatcherOptions, replyOptions }) => {
|
||||
await replyOptions?.onPartialReply?.({ text: "Message A partial" });
|
||||
await replyOptions?.onCompactionStart?.();
|
||||
await replyOptions?.onCompactionEnd?.();
|
||||
await replyOptions?.onAssistantMessageStart?.();
|
||||
await replyOptions?.onPartialReply?.({ text: "Message A partial" });
|
||||
await replyOptions?.onPartialReply?.({ text: "Message A partial extended" });
|
||||
await dispatcherOptions.deliver({ text: "Message A final" }, { kind: "final" });
|
||||
return { queuedFinal: true };
|
||||
},
|
||||
);
|
||||
deliverReplies.mockResolvedValue({ delivered: true });
|
||||
editMessageTelegram.mockResolvedValue({ ok: true, chatId: "123", messageId: "1001" });
|
||||
|
||||
await dispatchWithContext({ context: createContext(), streamMode: "partial" });
|
||||
|
||||
expect(answerDraftStream.forceNewMessage).not.toHaveBeenCalled();
|
||||
expect(answerDraftStream.materialize).not.toHaveBeenCalled();
|
||||
expect(editMessageTelegram).toHaveBeenCalledTimes(1);
|
||||
expect(editMessageTelegram).toHaveBeenCalledWith(
|
||||
123,
|
||||
1001,
|
||||
"Message A final",
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it("clears the compaction replay skip after the retried message finalizes", async () => {
|
||||
const answerDraftStream = createSequencedDraftStream(1001);
|
||||
const reasoningDraftStream = createDraftStream();
|
||||
createTelegramDraftStream
|
||||
.mockImplementationOnce(() => answerDraftStream)
|
||||
.mockImplementationOnce(() => reasoningDraftStream);
|
||||
dispatchReplyWithBufferedBlockDispatcher.mockImplementation(
|
||||
async ({ dispatcherOptions, replyOptions }) => {
|
||||
await replyOptions?.onPartialReply?.({ text: "Message A partial" });
|
||||
await replyOptions?.onCompactionStart?.();
|
||||
await replyOptions?.onCompactionEnd?.();
|
||||
await replyOptions?.onAssistantMessageStart?.();
|
||||
await replyOptions?.onPartialReply?.({ text: "Message A partial extended" });
|
||||
await dispatcherOptions.deliver({ text: "Message A final" }, { kind: "final" });
|
||||
await replyOptions?.onAssistantMessageStart?.();
|
||||
await replyOptions?.onPartialReply?.({ text: "Message B partial" });
|
||||
await dispatcherOptions.deliver({ text: "Message B final" }, { kind: "final" });
|
||||
return { queuedFinal: true };
|
||||
},
|
||||
);
|
||||
deliverReplies.mockResolvedValue({ delivered: true });
|
||||
editMessageTelegram.mockResolvedValue({ ok: true, chatId: "123", messageId: "1001" });
|
||||
|
||||
await dispatchWithContext({ context: createContext(), streamMode: "partial" });
|
||||
|
||||
expect(answerDraftStream.forceNewMessage).toHaveBeenCalledTimes(1);
|
||||
expect(editMessageTelegram).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
123,
|
||||
1001,
|
||||
"Message A final",
|
||||
expect.any(Object),
|
||||
);
|
||||
expect(editMessageTelegram).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
123,
|
||||
1002,
|
||||
"Message B final",
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it("preserves the compaction replay flag until queued retry callbacks flush", async () => {
|
||||
const answerDraftStream = createSequencedDraftStream(1001);
|
||||
const reasoningDraftStream = createDraftStream();
|
||||
createTelegramDraftStream
|
||||
.mockImplementationOnce(() => answerDraftStream)
|
||||
.mockImplementationOnce(() => reasoningDraftStream);
|
||||
dispatchReplyWithBufferedBlockDispatcher.mockImplementation(
|
||||
async ({ dispatcherOptions, replyOptions }) => {
|
||||
await replyOptions?.onPartialReply?.({ text: "Message A partial" });
|
||||
await replyOptions?.onCompactionStart?.();
|
||||
await replyOptions?.onCompactionEnd?.();
|
||||
void replyOptions?.onAssistantMessageStart?.();
|
||||
await dispatcherOptions.deliver({ text: "Message A final" }, { kind: "final" });
|
||||
return { queuedFinal: true };
|
||||
},
|
||||
);
|
||||
deliverReplies.mockResolvedValue({ delivered: true });
|
||||
editMessageTelegram.mockResolvedValue({ ok: true, chatId: "123", messageId: "1001" });
|
||||
|
||||
await dispatchWithContext({ context: createContext(), streamMode: "partial" });
|
||||
|
||||
expect(answerDraftStream.forceNewMessage).not.toHaveBeenCalled();
|
||||
expect(editMessageTelegram).toHaveBeenCalledTimes(1);
|
||||
expect(editMessageTelegram).toHaveBeenCalledWith(
|
||||
123,
|
||||
1001,
|
||||
"Message A final",
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it("keeps the existing preview when the retried answer only arrives as final text", async () => {
|
||||
const answerDraftStream = createSequencedDraftStream(1001);
|
||||
const reasoningDraftStream = createDraftStream();
|
||||
createTelegramDraftStream
|
||||
.mockImplementationOnce(() => answerDraftStream)
|
||||
.mockImplementationOnce(() => reasoningDraftStream);
|
||||
dispatchReplyWithBufferedBlockDispatcher.mockImplementation(
|
||||
async ({ dispatcherOptions, replyOptions }) => {
|
||||
await replyOptions?.onPartialReply?.({ text: "Message A partial" });
|
||||
await replyOptions?.onCompactionStart?.();
|
||||
await replyOptions?.onCompactionEnd?.();
|
||||
await replyOptions?.onAssistantMessageStart?.();
|
||||
await dispatcherOptions.deliver({ text: "Message B final" }, { kind: "final" });
|
||||
return { queuedFinal: true };
|
||||
},
|
||||
);
|
||||
deliverReplies.mockResolvedValue({ delivered: true });
|
||||
editMessageTelegram.mockResolvedValue({ ok: true, chatId: "123", messageId: "1001" });
|
||||
|
||||
await dispatchWithContext({ context: createContext(), streamMode: "partial" });
|
||||
|
||||
expect(answerDraftStream.forceNewMessage).not.toHaveBeenCalled();
|
||||
expect(answerDraftStream.materialize).not.toHaveBeenCalled();
|
||||
expect(editMessageTelegram).toHaveBeenCalledTimes(1);
|
||||
expect(editMessageTelegram).toHaveBeenCalledWith(
|
||||
123,
|
||||
1001,
|
||||
"Message B final",
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it("keeps the transient preview when a local exec approval prompt is suppressed after compaction", async () => {
|
||||
const answerDraftStream = createSequencedDraftStream(1001);
|
||||
const reasoningDraftStream = createDraftStream();
|
||||
createTelegramDraftStream
|
||||
.mockImplementationOnce(() => answerDraftStream)
|
||||
.mockImplementationOnce(() => reasoningDraftStream);
|
||||
dispatchReplyWithBufferedBlockDispatcher.mockImplementation(
|
||||
async ({ dispatcherOptions, replyOptions }) => {
|
||||
await replyOptions?.onPartialReply?.({ text: "Message A partial" });
|
||||
await replyOptions?.onCompactionStart?.();
|
||||
await replyOptions?.onCompactionEnd?.();
|
||||
await dispatcherOptions.deliver(
|
||||
{
|
||||
text: "Approval required.\n\n```txt\n/approve 7f423fdc allow-once\n```",
|
||||
channelData: {
|
||||
execApproval: {
|
||||
approvalId: "7f423fdc-1111-2222-3333-444444444444",
|
||||
approvalSlug: "7f423fdc",
|
||||
allowedDecisions: ["allow-once", "allow-always", "deny"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{ kind: "tool" },
|
||||
);
|
||||
await replyOptions?.onAssistantMessageStart?.();
|
||||
await replyOptions?.onPartialReply?.({ text: "Message B partial" });
|
||||
await dispatcherOptions.deliver({ text: "Message B final" }, { kind: "final" });
|
||||
return { queuedFinal: true };
|
||||
},
|
||||
);
|
||||
deliverReplies.mockResolvedValue({ delivered: true });
|
||||
editMessageTelegram.mockResolvedValue({ ok: true, chatId: "123", messageId: "1001" });
|
||||
|
||||
await dispatchWithContext({
|
||||
context: createContext(),
|
||||
streamMode: "partial",
|
||||
cfg: {
|
||||
channels: {
|
||||
telegram: {
|
||||
execApprovals: {
|
||||
enabled: true,
|
||||
approvers: ["12345"],
|
||||
target: "dm",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
expect(answerDraftStream.forceNewMessage).not.toHaveBeenCalled();
|
||||
expect(editMessageTelegram).toHaveBeenCalledTimes(1);
|
||||
expect(editMessageTelegram).toHaveBeenCalledWith(
|
||||
123,
|
||||
1001,
|
||||
"Message B final",
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it("finalizes multi-message assistant stream to matching preview messages in order", async () => {
|
||||
const answerDraftStream = createSequencedDraftStream(1001);
|
||||
const reasoningDraftStream = createDraftStream();
|
||||
|
||||
@@ -280,10 +280,6 @@ export const dispatchTelegramMessage = async ({
|
||||
const reasoningLane = lanes.reasoning;
|
||||
let splitReasoningOnNextStream = false;
|
||||
let skipNextAnswerMessageStartRotation = false;
|
||||
// If compaction interrupts a still-transient answer preview, keep the next
|
||||
// assistant-message boundary on that same preview instead of materializing a
|
||||
// duplicate retry message.
|
||||
let pendingCompactionReplayBoundary = false;
|
||||
let draftLaneEventQueue = Promise.resolve();
|
||||
const reasoningStepState = createTelegramReasoningStepState();
|
||||
const enqueueDraftLaneEvent = (task: () => Promise<void>): Promise<void> => {
|
||||
@@ -697,9 +693,6 @@ export const dispatchTelegramMessage = async ({
|
||||
}
|
||||
}
|
||||
if (segments.length > 0) {
|
||||
if (info.kind === "final") {
|
||||
pendingCompactionReplayBoundary = false;
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (split.suppressedReasoningOnly) {
|
||||
@@ -710,7 +703,6 @@ export const dispatchTelegramMessage = async ({
|
||||
}
|
||||
if (info.kind === "final") {
|
||||
await flushBufferedFinalAnswer();
|
||||
pendingCompactionReplayBoundary = false;
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -724,14 +716,12 @@ export const dispatchTelegramMessage = async ({
|
||||
if (!canSendAsIs) {
|
||||
if (info.kind === "final") {
|
||||
await flushBufferedFinalAnswer();
|
||||
pendingCompactionReplayBoundary = false;
|
||||
}
|
||||
return;
|
||||
}
|
||||
await sendPayload(payload);
|
||||
if (info.kind === "final") {
|
||||
await flushBufferedFinalAnswer();
|
||||
pendingCompactionReplayBoundary = false;
|
||||
}
|
||||
},
|
||||
onSkip: (payload, info) => {
|
||||
@@ -803,12 +793,6 @@ export const dispatchTelegramMessage = async ({
|
||||
retainPreviewOnCleanupByLane.answer = false;
|
||||
return;
|
||||
}
|
||||
if (pendingCompactionReplayBoundary) {
|
||||
pendingCompactionReplayBoundary = false;
|
||||
activePreviewLifecycleByLane.answer = "transient";
|
||||
retainPreviewOnCleanupByLane.answer = false;
|
||||
return;
|
||||
}
|
||||
await rotateAnswerLaneForNewAssistantMessage();
|
||||
// Message-start is an explicit assistant-message boundary.
|
||||
// Even when no forceNewMessage happened (e.g. prior answer had no
|
||||
@@ -833,20 +817,9 @@ export const dispatchTelegramMessage = async ({
|
||||
}
|
||||
}
|
||||
: undefined,
|
||||
onCompactionStart:
|
||||
statusReactionController || answerLane.stream
|
||||
? async () => {
|
||||
if (
|
||||
answerLane.hasStreamedMessage &&
|
||||
activePreviewLifecycleByLane.answer === "transient"
|
||||
) {
|
||||
pendingCompactionReplayBoundary = true;
|
||||
}
|
||||
if (statusReactionController) {
|
||||
await statusReactionController.setCompacting();
|
||||
}
|
||||
}
|
||||
: undefined,
|
||||
onCompactionStart: statusReactionController
|
||||
? () => statusReactionController.setCompacting()
|
||||
: undefined,
|
||||
onCompactionEnd: statusReactionController
|
||||
? async () => {
|
||||
statusReactionController.cancelPending();
|
||||
|
||||
@@ -7,18 +7,6 @@ import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { importFreshModule } from "../../../test/helpers/import-fresh.js";
|
||||
|
||||
const writeJsonFileAtomicallyMock = vi.hoisted(() => vi.fn());
|
||||
const readAcpSessionEntryMock = vi.hoisted(() => vi.fn());
|
||||
|
||||
vi.mock("openclaw/plugin-sdk/acp-runtime", async () => {
|
||||
const actual = await vi.importActual<typeof import("openclaw/plugin-sdk/acp-runtime")>(
|
||||
"openclaw/plugin-sdk/acp-runtime",
|
||||
);
|
||||
readAcpSessionEntryMock.mockImplementation(actual.readAcpSessionEntry);
|
||||
return {
|
||||
...actual,
|
||||
readAcpSessionEntry: readAcpSessionEntryMock,
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock("openclaw/plugin-sdk/json-store", async () => {
|
||||
const actual = await vi.importActual<typeof import("openclaw/plugin-sdk/json-store")>(
|
||||
@@ -48,11 +36,6 @@ describe("telegram thread bindings", () => {
|
||||
|
||||
beforeEach(async () => {
|
||||
writeJsonFileAtomicallyMock.mockClear();
|
||||
readAcpSessionEntryMock.mockReset();
|
||||
const acpRuntime = await vi.importActual<typeof import("openclaw/plugin-sdk/acp-runtime")>(
|
||||
"openclaw/plugin-sdk/acp-runtime",
|
||||
);
|
||||
readAcpSessionEntryMock.mockImplementation(acpRuntime.readAcpSessionEntry);
|
||||
await __testing.resetTelegramThreadBindingsForTests();
|
||||
});
|
||||
|
||||
@@ -310,136 +293,6 @@ describe("telegram thread bindings", () => {
|
||||
expect(reloaded.getByConversationId("8460800771")).toBeUndefined();
|
||||
});
|
||||
|
||||
it("cleans up stale ACP bindings before restart routing can reuse them", async () => {
|
||||
stateDirOverride = fs.mkdtempSync(path.join(os.tmpdir(), "openclaw-telegram-bindings-"));
|
||||
process.env.OPENCLAW_STATE_DIR = stateDirOverride;
|
||||
|
||||
createTelegramThreadBindingManager({
|
||||
accountId: "default",
|
||||
persist: true,
|
||||
enableSweeper: false,
|
||||
});
|
||||
|
||||
await getSessionBindingService().bind({
|
||||
targetSessionKey: "agent:main:acp:stale-1",
|
||||
targetKind: "session",
|
||||
conversation: {
|
||||
channel: "telegram",
|
||||
accountId: "default",
|
||||
conversationId: "cleanup-me",
|
||||
},
|
||||
});
|
||||
|
||||
await __testing.resetTelegramThreadBindingsForTests();
|
||||
readAcpSessionEntryMock.mockReturnValue({
|
||||
cfg: {} as never,
|
||||
storePath: "/tmp/acp-store.json",
|
||||
sessionKey: "agent:main:acp:stale-1",
|
||||
storeSessionKey: "agent:main:acp:stale-1",
|
||||
entry: undefined,
|
||||
acp: undefined,
|
||||
storeReadFailed: false,
|
||||
});
|
||||
|
||||
const reloaded = createTelegramThreadBindingManager({
|
||||
accountId: "default",
|
||||
persist: true,
|
||||
enableSweeper: false,
|
||||
});
|
||||
|
||||
expect(reloaded.getByConversationId("cleanup-me")).toBeUndefined();
|
||||
await __testing.resetTelegramThreadBindingsForTests();
|
||||
const persisted = JSON.parse(
|
||||
fs.readFileSync(
|
||||
path.join(
|
||||
resolveStateDir(process.env, os.homedir),
|
||||
"telegram",
|
||||
"thread-bindings-default.json",
|
||||
),
|
||||
"utf8",
|
||||
),
|
||||
) as { bindings?: Array<{ conversationId?: string }> };
|
||||
expect(persisted.bindings?.map((binding) => binding.conversationId)).not.toContain(
|
||||
"cleanup-me",
|
||||
);
|
||||
});
|
||||
|
||||
it("keeps plugin-owned bindings when ACP cleanup runs on startup", async () => {
|
||||
stateDirOverride = fs.mkdtempSync(path.join(os.tmpdir(), "openclaw-telegram-bindings-"));
|
||||
process.env.OPENCLAW_STATE_DIR = stateDirOverride;
|
||||
|
||||
createTelegramThreadBindingManager({
|
||||
accountId: "default",
|
||||
persist: true,
|
||||
enableSweeper: false,
|
||||
});
|
||||
|
||||
await getSessionBindingService().bind({
|
||||
targetSessionKey: "plugin-binding:openclaw-codex-app-server:still-valid",
|
||||
targetKind: "session",
|
||||
conversation: {
|
||||
channel: "telegram",
|
||||
accountId: "default",
|
||||
conversationId: "plugin-binding-convo",
|
||||
},
|
||||
});
|
||||
|
||||
await __testing.resetTelegramThreadBindingsForTests();
|
||||
|
||||
const reloaded = createTelegramThreadBindingManager({
|
||||
accountId: "default",
|
||||
persist: true,
|
||||
enableSweeper: false,
|
||||
});
|
||||
|
||||
expect(reloaded.getByConversationId("plugin-binding-convo")?.targetSessionKey).toBe(
|
||||
"plugin-binding:openclaw-codex-app-server:still-valid",
|
||||
);
|
||||
expect(readAcpSessionEntryMock).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("keeps ACP bindings when the session store cannot be read during startup cleanup", async () => {
|
||||
stateDirOverride = fs.mkdtempSync(path.join(os.tmpdir(), "openclaw-telegram-bindings-"));
|
||||
process.env.OPENCLAW_STATE_DIR = stateDirOverride;
|
||||
|
||||
createTelegramThreadBindingManager({
|
||||
accountId: "default",
|
||||
persist: true,
|
||||
enableSweeper: false,
|
||||
});
|
||||
|
||||
await getSessionBindingService().bind({
|
||||
targetSessionKey: "agent:main:acp:read-failed",
|
||||
targetKind: "session",
|
||||
conversation: {
|
||||
channel: "telegram",
|
||||
accountId: "default",
|
||||
conversationId: "keep-on-read-failure",
|
||||
},
|
||||
});
|
||||
|
||||
await __testing.resetTelegramThreadBindingsForTests();
|
||||
readAcpSessionEntryMock.mockReturnValue({
|
||||
cfg: {} as never,
|
||||
storePath: "/tmp/acp-store.json",
|
||||
sessionKey: "agent:main:acp:read-failed",
|
||||
storeSessionKey: "agent:main:acp:read-failed",
|
||||
entry: undefined,
|
||||
acp: undefined,
|
||||
storeReadFailed: true,
|
||||
});
|
||||
|
||||
const reloaded = createTelegramThreadBindingManager({
|
||||
accountId: "default",
|
||||
persist: true,
|
||||
enableSweeper: false,
|
||||
});
|
||||
|
||||
expect(reloaded.getByConversationId("keep-on-read-failure")?.targetSessionKey).toBe(
|
||||
"agent:main:acp:read-failed",
|
||||
);
|
||||
});
|
||||
|
||||
it("flushes pending lifecycle update persists before test reset", async () => {
|
||||
stateDirOverride = fs.mkdtempSync(path.join(os.tmpdir(), "openclaw-telegram-bindings-"));
|
||||
process.env.OPENCLAW_STATE_DIR = stateDirOverride;
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import fs from "node:fs";
|
||||
import os from "node:os";
|
||||
import path from "node:path";
|
||||
import { readAcpSessionEntry } from "openclaw/plugin-sdk/acp-runtime";
|
||||
import { loadConfig } from "openclaw/plugin-sdk/config-runtime";
|
||||
import {
|
||||
formatThreadBindingDurationLabel,
|
||||
@@ -15,7 +14,7 @@ import {
|
||||
} from "openclaw/plugin-sdk/conversation-runtime";
|
||||
import { formatErrorMessage } from "openclaw/plugin-sdk/error-runtime";
|
||||
import { writeJsonFileAtomically } from "openclaw/plugin-sdk/json-store";
|
||||
import { normalizeAccountId, isAcpSessionKey } from "openclaw/plugin-sdk/routing";
|
||||
import { normalizeAccountId } from "openclaw/plugin-sdk/routing";
|
||||
import { logVerbose } from "openclaw/plugin-sdk/runtime-env";
|
||||
import { resolveStateDir } from "openclaw/plugin-sdk/state-paths";
|
||||
import { normalizeOptionalString } from "openclaw/plugin-sdk/text-runtime";
|
||||
@@ -441,58 +440,6 @@ export function createTelegramThreadBindingManager(
|
||||
});
|
||||
}
|
||||
|
||||
const acpSessionKeys = new Set<string>();
|
||||
for (const binding of getThreadBindingsState().bindingsByAccountConversation.values()) {
|
||||
if (binding.targetKind !== "acp" || !isAcpSessionKey(binding.targetSessionKey)) {
|
||||
continue;
|
||||
}
|
||||
acpSessionKeys.add(binding.targetSessionKey);
|
||||
}
|
||||
|
||||
const staleSessionKeys = new Set<string>();
|
||||
for (const targetSessionKey of acpSessionKeys) {
|
||||
const sessionEntry = readAcpSessionEntry({ sessionKey: targetSessionKey });
|
||||
if (!sessionEntry || sessionEntry.storeReadFailed) {
|
||||
continue;
|
||||
}
|
||||
const isStale =
|
||||
!sessionEntry.entry ||
|
||||
sessionEntry.entry.status === "failed" ||
|
||||
sessionEntry.entry.status === "killed" ||
|
||||
sessionEntry.entry.status === "timeout" ||
|
||||
sessionEntry.entry.acp?.state === "error";
|
||||
if (isStale) {
|
||||
staleSessionKeys.add(targetSessionKey);
|
||||
}
|
||||
}
|
||||
|
||||
let needsPersist = false;
|
||||
for (const sessionKey of staleSessionKeys) {
|
||||
const bindingsToRemove = listBindingsForAccount(accountId).filter(
|
||||
(b) => b.targetSessionKey === sessionKey,
|
||||
);
|
||||
for (const binding of bindingsToRemove) {
|
||||
getThreadBindingsState().bindingsByAccountConversation.delete(
|
||||
resolveBindingKey({ accountId, conversationId: binding.conversationId }),
|
||||
);
|
||||
}
|
||||
if (bindingsToRemove.length > 0) {
|
||||
needsPersist = true;
|
||||
logVerbose(
|
||||
`telegram thread binding: cleaned up ${bindingsToRemove.length} stale binding(s) for session ${sessionKey}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (needsPersist && persist) {
|
||||
persistBindingsSafely({
|
||||
accountId,
|
||||
persist: true,
|
||||
bindings: listBindingsForAccount(accountId),
|
||||
reason: "cleanup-stale",
|
||||
});
|
||||
}
|
||||
|
||||
let sweepTimer: NodeJS.Timeout | null = null;
|
||||
|
||||
const manager: TelegramThreadBindingManager = {
|
||||
|
||||
@@ -40,8 +40,8 @@ describe("thread-ownership plugin", () => {
|
||||
});
|
||||
|
||||
describe("message_sending", () => {
|
||||
beforeEach(() => {
|
||||
register.register(api as unknown as OpenClawPluginApi);
|
||||
beforeEach(async () => {
|
||||
await register.register(api as unknown as OpenClawPluginApi);
|
||||
});
|
||||
|
||||
async function sendSlackThreadMessage() {
|
||||
@@ -112,8 +112,8 @@ describe("thread-ownership plugin", () => {
|
||||
});
|
||||
|
||||
describe("message_received @-mention tracking", () => {
|
||||
beforeEach(() => {
|
||||
register.register(api as unknown as OpenClawPluginApi);
|
||||
beforeEach(async () => {
|
||||
await register.register(api as unknown as OpenClawPluginApi);
|
||||
});
|
||||
|
||||
it("tracks @-mentions and skips ownership check for mentioned threads", async () => {
|
||||
|
||||
@@ -1,65 +1,82 @@
|
||||
import { describe, expect, it, vi, beforeEach } from "vitest";
|
||||
import { fetchWithSsrFGuard } from "../../runtime-api.js";
|
||||
import { uploadFile } from "../tlon-api.js";
|
||||
import { uploadImageFromUrl } from "./upload.js";
|
||||
import { describe, expect, it, vi, afterEach, beforeEach } from "vitest";
|
||||
|
||||
vi.mock("../../runtime-api.js", () => ({
|
||||
fetchWithSsrFGuard: vi.fn(),
|
||||
}));
|
||||
// Mock fetchWithSsrFGuard from the local runtime seam.
|
||||
vi.mock("../../runtime-api.js", async () => {
|
||||
const actual =
|
||||
await vi.importActual<typeof import("../../runtime-api.js")>("../../runtime-api.js");
|
||||
return {
|
||||
...actual,
|
||||
fetchWithSsrFGuard: vi.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
// Mock the local Tlon upload seam.
|
||||
vi.mock("../tlon-api.js", () => ({
|
||||
uploadFile: vi.fn(),
|
||||
}));
|
||||
|
||||
const mockFetch = vi.mocked(fetchWithSsrFGuard);
|
||||
const mockUploadFile = vi.mocked(uploadFile);
|
||||
|
||||
type FetchMock = typeof mockFetch;
|
||||
|
||||
function mockSuccessfulFetch(params: {
|
||||
mockFetch: FetchMock;
|
||||
blob: Blob;
|
||||
finalUrl: string;
|
||||
contentType: string;
|
||||
}) {
|
||||
params.mockFetch.mockResolvedValue({
|
||||
response: {
|
||||
ok: true,
|
||||
headers: new Headers({ "content-type": params.contentType }),
|
||||
blob: () => Promise.resolve(params.blob),
|
||||
} as unknown as Response,
|
||||
finalUrl: params.finalUrl,
|
||||
release: vi.fn().mockResolvedValue(undefined),
|
||||
});
|
||||
}
|
||||
|
||||
async function setupSuccessfulUpload(params?: {
|
||||
sourceUrl?: string;
|
||||
contentType?: string;
|
||||
uploadedUrl?: string;
|
||||
}) {
|
||||
const sourceUrl = params?.sourceUrl ?? "https://example.com/image.png";
|
||||
const contentType = params?.contentType ?? "image/png";
|
||||
const mockBlob = new Blob(["fake-image"], { type: contentType });
|
||||
mockSuccessfulFetch({
|
||||
mockFetch,
|
||||
blob: mockBlob,
|
||||
finalUrl: sourceUrl,
|
||||
contentType,
|
||||
});
|
||||
if (params?.uploadedUrl) {
|
||||
mockUploadFile.mockResolvedValue({ url: params.uploadedUrl });
|
||||
}
|
||||
return { mockBlob };
|
||||
}
|
||||
|
||||
describe("uploadImageFromUrl", () => {
|
||||
async function loadUploadMocks() {
|
||||
const { fetchWithSsrFGuard } = await import("../../runtime-api.js");
|
||||
const { uploadFile } = await import("../tlon-api.js");
|
||||
const { uploadImageFromUrl } = await import("./upload.js");
|
||||
return {
|
||||
mockFetch: vi.mocked(fetchWithSsrFGuard),
|
||||
mockUploadFile: vi.mocked(uploadFile),
|
||||
uploadImageFromUrl,
|
||||
};
|
||||
}
|
||||
|
||||
type UploadMocks = Awaited<ReturnType<typeof loadUploadMocks>>;
|
||||
|
||||
function mockSuccessfulFetch(params: {
|
||||
mockFetch: UploadMocks["mockFetch"];
|
||||
blob: Blob;
|
||||
finalUrl: string;
|
||||
contentType: string;
|
||||
}) {
|
||||
params.mockFetch.mockResolvedValue({
|
||||
response: {
|
||||
ok: true,
|
||||
headers: new Headers({ "content-type": params.contentType }),
|
||||
blob: () => Promise.resolve(params.blob),
|
||||
} as unknown as Response,
|
||||
finalUrl: params.finalUrl,
|
||||
release: vi.fn().mockResolvedValue(undefined),
|
||||
});
|
||||
}
|
||||
|
||||
async function setupSuccessfulUpload(params?: {
|
||||
sourceUrl?: string;
|
||||
contentType?: string;
|
||||
uploadedUrl?: string;
|
||||
}) {
|
||||
const { mockFetch, mockUploadFile, uploadImageFromUrl } = await loadUploadMocks();
|
||||
const sourceUrl = params?.sourceUrl ?? "https://example.com/image.png";
|
||||
const contentType = params?.contentType ?? "image/png";
|
||||
const mockBlob = new Blob(["fake-image"], { type: contentType });
|
||||
mockSuccessfulFetch({
|
||||
mockFetch,
|
||||
blob: mockBlob,
|
||||
finalUrl: sourceUrl,
|
||||
contentType,
|
||||
});
|
||||
if (params?.uploadedUrl) {
|
||||
mockUploadFile.mockResolvedValue({ url: params.uploadedUrl });
|
||||
}
|
||||
return { mockBlob, mockUploadFile, uploadImageFromUrl };
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it("fetches image and calls uploadFile, returns uploaded URL", async () => {
|
||||
const { mockBlob } = await setupSuccessfulUpload({
|
||||
const { mockBlob, mockUploadFile, uploadImageFromUrl } = await setupSuccessfulUpload({
|
||||
uploadedUrl: "https://memex.tlon.network/uploaded.png",
|
||||
});
|
||||
|
||||
@@ -76,6 +93,8 @@ describe("uploadImageFromUrl", () => {
|
||||
});
|
||||
|
||||
it("returns original URL if fetch fails", async () => {
|
||||
const { mockFetch, uploadImageFromUrl } = await loadUploadMocks();
|
||||
|
||||
mockFetch.mockResolvedValue({
|
||||
response: {
|
||||
ok: false,
|
||||
@@ -91,7 +110,7 @@ describe("uploadImageFromUrl", () => {
|
||||
});
|
||||
|
||||
it("returns original URL if upload fails", async () => {
|
||||
await setupSuccessfulUpload();
|
||||
const { mockUploadFile, uploadImageFromUrl } = await setupSuccessfulUpload();
|
||||
mockUploadFile.mockRejectedValue(new Error("Upload failed"));
|
||||
|
||||
const result = await uploadImageFromUrl("https://example.com/image.png");
|
||||
@@ -100,19 +119,28 @@ describe("uploadImageFromUrl", () => {
|
||||
});
|
||||
|
||||
it("rejects non-http(s) URLs", async () => {
|
||||
const { uploadImageFromUrl } = await import("./upload.js");
|
||||
|
||||
// file:// URL should be rejected
|
||||
const result = await uploadImageFromUrl("file:///etc/passwd");
|
||||
expect(result).toBe("file:///etc/passwd");
|
||||
|
||||
// ftp:// URL should be rejected
|
||||
const result2 = await uploadImageFromUrl("ftp://example.com/image.png");
|
||||
expect(result2).toBe("ftp://example.com/image.png");
|
||||
});
|
||||
|
||||
it("handles invalid URLs gracefully", async () => {
|
||||
const { uploadImageFromUrl } = await import("./upload.js");
|
||||
|
||||
// Invalid URL should return original
|
||||
const result = await uploadImageFromUrl("not-a-valid-url");
|
||||
expect(result).toBe("not-a-valid-url");
|
||||
});
|
||||
|
||||
it("extracts filename from URL path", async () => {
|
||||
const { mockFetch, mockUploadFile, uploadImageFromUrl } = await loadUploadMocks();
|
||||
|
||||
const mockBlob = new Blob(["fake-image"], { type: "image/jpeg" });
|
||||
mockSuccessfulFetch({
|
||||
mockFetch,
|
||||
@@ -133,6 +161,8 @@ describe("uploadImageFromUrl", () => {
|
||||
});
|
||||
|
||||
it("uses default filename when URL has no path", async () => {
|
||||
const { mockFetch, mockUploadFile, uploadImageFromUrl } = await loadUploadMocks();
|
||||
|
||||
const mockBlob = new Blob(["fake-image"], { type: "image/png" });
|
||||
mockSuccessfulFetch({
|
||||
mockFetch,
|
||||
|
||||
@@ -37,7 +37,7 @@ type Registered = {
|
||||
methods: Map<string, unknown>;
|
||||
tools: unknown[];
|
||||
};
|
||||
type RegisterVoiceCall = (api: Record<string, unknown>) => void;
|
||||
type RegisterVoiceCall = (api: Record<string, unknown>) => void | Promise<void>;
|
||||
type RegisterCliContext = {
|
||||
program: Command;
|
||||
config: Record<string, unknown>;
|
||||
@@ -83,7 +83,7 @@ async function registerVoiceCallCli(program: Command) {
|
||||
const { register } = plugin as unknown as {
|
||||
register: RegisterVoiceCall;
|
||||
};
|
||||
register({
|
||||
await register({
|
||||
id: "voice-call",
|
||||
name: "Voice Call",
|
||||
description: "test",
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
import { definePluginEntry } from "openclaw/plugin-sdk/plugin-entry";
|
||||
import { voyageMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js";
|
||||
|
||||
export default definePluginEntry({
|
||||
id: "voyage",
|
||||
name: "Voyage Embeddings",
|
||||
description: "Bundled Voyage memory embedding provider plugin",
|
||||
register(api) {
|
||||
api.registerMemoryEmbeddingProvider(voyageMemoryEmbeddingProviderAdapter);
|
||||
},
|
||||
});
|
||||
@@ -1,56 +0,0 @@
|
||||
import {
|
||||
isMissingEmbeddingApiKeyError,
|
||||
mapBatchEmbeddingsByIndex,
|
||||
sanitizeEmbeddingCacheHeaders,
|
||||
type MemoryEmbeddingProviderAdapter,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import { runVoyageEmbeddingBatches } from "./embedding-batch.js";
|
||||
import {
|
||||
createVoyageEmbeddingProvider,
|
||||
DEFAULT_VOYAGE_EMBEDDING_MODEL,
|
||||
} from "./embedding-provider.js";
|
||||
|
||||
export const voyageMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: "voyage",
|
||||
defaultModel: DEFAULT_VOYAGE_EMBEDDING_MODEL,
|
||||
transport: "remote",
|
||||
authProviderId: "voyage",
|
||||
autoSelectPriority: 40,
|
||||
allowExplicitWhenConfiguredAuto: true,
|
||||
shouldContinueAutoSelection: isMissingEmbeddingApiKeyError,
|
||||
create: async (options) => {
|
||||
const { provider, client } = await createVoyageEmbeddingProvider({
|
||||
...options,
|
||||
provider: "voyage",
|
||||
fallback: "none",
|
||||
});
|
||||
return {
|
||||
provider,
|
||||
runtime: {
|
||||
id: "voyage",
|
||||
cacheKeyData: {
|
||||
provider: "voyage",
|
||||
baseUrl: client.baseUrl,
|
||||
model: client.model,
|
||||
headers: sanitizeEmbeddingCacheHeaders(client.headers, ["authorization"]),
|
||||
},
|
||||
batchEmbed: async (batch) => {
|
||||
const byCustomId = await runVoyageEmbeddingBatches({
|
||||
client,
|
||||
agentId: batch.agentId,
|
||||
requests: batch.chunks.map((chunk, index) => ({
|
||||
custom_id: String(index),
|
||||
body: { input: chunk.text },
|
||||
})),
|
||||
wait: batch.wait,
|
||||
concurrency: batch.concurrency,
|
||||
pollIntervalMs: batch.pollIntervalMs,
|
||||
timeoutMs: batch.timeoutMs,
|
||||
debug: batch.debug,
|
||||
});
|
||||
return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length);
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
@@ -1,15 +0,0 @@
|
||||
{
|
||||
"id": "voyage",
|
||||
"enabledByDefault": true,
|
||||
"contracts": {
|
||||
"memoryEmbeddingProviders": ["voyage"]
|
||||
},
|
||||
"providerAuthEnvVars": {
|
||||
"voyage": ["VOYAGE_API_KEY"]
|
||||
},
|
||||
"configSchema": {
|
||||
"type": "object",
|
||||
"additionalProperties": false,
|
||||
"properties": {}
|
||||
}
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
{
|
||||
"name": "@openclaw/voyage-provider",
|
||||
"version": "2026.4.15-beta.1",
|
||||
"private": true,
|
||||
"description": "OpenClaw Voyage embedding provider plugin",
|
||||
"type": "module",
|
||||
"devDependencies": {
|
||||
"@openclaw/plugin-sdk": "workspace:*"
|
||||
},
|
||||
"openclaw": {
|
||||
"extensions": [
|
||||
"./index.ts"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -1,66 +0,0 @@
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { createTestPluginApi } from "../../test/helpers/plugins/plugin-api.js";
|
||||
import type { OpenClawPluginApi } from "./api.js";
|
||||
import plugin from "./index.js";
|
||||
|
||||
function createApi(params?: {
|
||||
pluginConfig?: OpenClawPluginApi["pluginConfig"];
|
||||
registerHttpRoute?: OpenClawPluginApi["registerHttpRoute"];
|
||||
logger?: OpenClawPluginApi["logger"];
|
||||
}): OpenClawPluginApi {
|
||||
return createTestPluginApi({
|
||||
id: "webhooks",
|
||||
name: "Webhooks",
|
||||
source: "test",
|
||||
pluginConfig: params?.pluginConfig ?? {},
|
||||
runtime: {
|
||||
taskFlow: {
|
||||
bindSession: vi.fn(({ sessionKey }: { sessionKey: string }) => ({ sessionKey })),
|
||||
},
|
||||
} as unknown as OpenClawPluginApi["runtime"],
|
||||
registerHttpRoute: params?.registerHttpRoute ?? vi.fn(),
|
||||
logger:
|
||||
params?.logger ??
|
||||
({
|
||||
info: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
debug: vi.fn(),
|
||||
} as OpenClawPluginApi["logger"]),
|
||||
});
|
||||
}
|
||||
|
||||
describe("webhooks plugin registration", () => {
|
||||
it("registers SecretRef-backed routes synchronously", () => {
|
||||
const registerHttpRoute = vi.fn();
|
||||
|
||||
const result = plugin.register(
|
||||
createApi({
|
||||
pluginConfig: {
|
||||
routes: {
|
||||
zapier: {
|
||||
sessionKey: "agent:main:main",
|
||||
secret: {
|
||||
source: "env",
|
||||
provider: "default",
|
||||
id: "OPENCLAW_WEBHOOK_SECRET",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
registerHttpRoute,
|
||||
}),
|
||||
);
|
||||
|
||||
expect(result).toBeUndefined();
|
||||
expect(registerHttpRoute).toHaveBeenCalledTimes(1);
|
||||
expect(registerHttpRoute).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
path: "/plugins/webhooks/zapier",
|
||||
auth: "plugin",
|
||||
match: "exact",
|
||||
replaceExisting: true,
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -2,52 +2,50 @@ import { definePluginEntry, type OpenClawPluginApi } from "./api.js";
|
||||
import { resolveWebhooksPluginConfig } from "./src/config.js";
|
||||
import { createTaskFlowWebhookRequestHandler, type TaskFlowWebhookTarget } from "./src/http.js";
|
||||
|
||||
function registerWebhookRoutes(api: OpenClawPluginApi): void {
|
||||
const routes = resolveWebhooksPluginConfig({
|
||||
pluginConfig: api.pluginConfig,
|
||||
});
|
||||
if (routes.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const targetsByPath = new Map<string, TaskFlowWebhookTarget[]>();
|
||||
const handler = createTaskFlowWebhookRequestHandler({
|
||||
cfg: api.config,
|
||||
targetsByPath,
|
||||
});
|
||||
|
||||
for (const route of routes) {
|
||||
const taskFlow = api.runtime.taskFlow.bindSession({
|
||||
sessionKey: route.sessionKey,
|
||||
});
|
||||
const target: TaskFlowWebhookTarget = {
|
||||
routeId: route.routeId,
|
||||
path: route.path,
|
||||
secretInput: route.secret,
|
||||
secretConfigPath: `plugins.entries.webhooks.routes.${route.routeId}.secret`,
|
||||
defaultControllerId: route.controllerId,
|
||||
taskFlow,
|
||||
};
|
||||
targetsByPath.set(target.path, [...(targetsByPath.get(target.path) ?? []), target]);
|
||||
api.registerHttpRoute({
|
||||
path: target.path,
|
||||
auth: "plugin",
|
||||
match: "exact",
|
||||
replaceExisting: true,
|
||||
handler,
|
||||
});
|
||||
api.logger.info?.(
|
||||
`[webhooks] registered route ${route.routeId} on ${route.path} for session ${route.sessionKey}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
export default definePluginEntry({
|
||||
id: "webhooks",
|
||||
name: "Webhooks",
|
||||
description:
|
||||
"Authenticated inbound webhooks that bind external automation to OpenClaw TaskFlows.",
|
||||
register(api: OpenClawPluginApi) {
|
||||
registerWebhookRoutes(api);
|
||||
async register(api: OpenClawPluginApi) {
|
||||
const routes = await resolveWebhooksPluginConfig({
|
||||
pluginConfig: api.pluginConfig,
|
||||
cfg: api.config,
|
||||
env: process.env,
|
||||
logger: api.logger,
|
||||
});
|
||||
if (routes.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const targetsByPath = new Map<string, TaskFlowWebhookTarget[]>();
|
||||
const handler = createTaskFlowWebhookRequestHandler({
|
||||
cfg: api.config,
|
||||
targetsByPath,
|
||||
});
|
||||
|
||||
for (const route of routes) {
|
||||
const taskFlow = api.runtime.taskFlow.bindSession({
|
||||
sessionKey: route.sessionKey,
|
||||
});
|
||||
const target: TaskFlowWebhookTarget = {
|
||||
routeId: route.routeId,
|
||||
path: route.path,
|
||||
secret: route.secret,
|
||||
defaultControllerId: route.controllerId,
|
||||
taskFlow,
|
||||
};
|
||||
targetsByPath.set(target.path, [...(targetsByPath.get(target.path) ?? []), target]);
|
||||
api.registerHttpRoute({
|
||||
path: target.path,
|
||||
auth: "plugin",
|
||||
match: "exact",
|
||||
replaceExisting: true,
|
||||
handler,
|
||||
});
|
||||
api.logger.info?.(
|
||||
`[webhooks] registered route ${route.routeId} on ${route.path} for session ${route.sessionKey}`,
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
@@ -4,7 +4,6 @@ export {
|
||||
normalizeWebhookPath,
|
||||
readJsonWebhookBodyOrReject,
|
||||
resolveRequestClientIp,
|
||||
resolveWebhookTargetWithAuthOrReject,
|
||||
resolveWebhookTargetWithAuthOrRejectSync,
|
||||
withResolvedWebhookRequestPipeline,
|
||||
WEBHOOK_IN_FLIGHT_DEFAULTS,
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import type { OpenClawConfig } from "../runtime-api.js";
|
||||
import { resolveWebhooksPluginConfig } from "./config.js";
|
||||
|
||||
describe("resolveWebhooksPluginConfig", () => {
|
||||
it("keeps SecretRef-backed secrets on the route config", () => {
|
||||
const routes = resolveWebhooksPluginConfig({
|
||||
it("resolves default paths and SecretRef-backed secrets", async () => {
|
||||
const routes = await resolveWebhooksPluginConfig({
|
||||
pluginConfig: {
|
||||
routes: {
|
||||
zapier: {
|
||||
@@ -16,6 +17,10 @@ describe("resolveWebhooksPluginConfig", () => {
|
||||
},
|
||||
},
|
||||
},
|
||||
cfg: {} as OpenClawConfig,
|
||||
env: {
|
||||
OPENCLAW_WEBHOOK_SECRET: "shared-secret",
|
||||
},
|
||||
});
|
||||
|
||||
expect(routes).toEqual([
|
||||
@@ -23,18 +28,16 @@ describe("resolveWebhooksPluginConfig", () => {
|
||||
routeId: "zapier",
|
||||
path: "/plugins/webhooks/zapier",
|
||||
sessionKey: "agent:main:main",
|
||||
secret: {
|
||||
source: "env",
|
||||
provider: "default",
|
||||
id: "OPENCLAW_WEBHOOK_SECRET",
|
||||
},
|
||||
secret: "shared-secret",
|
||||
controllerId: "webhooks/zapier",
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it("keeps routes whose secret needs runtime resolution", () => {
|
||||
const routes = resolveWebhooksPluginConfig({
|
||||
it("skips routes whose secret cannot be resolved", async () => {
|
||||
const warn = vi.fn();
|
||||
|
||||
const routes = await resolveWebhooksPluginConfig({
|
||||
pluginConfig: {
|
||||
routes: {
|
||||
missing: {
|
||||
@@ -47,25 +50,19 @@ describe("resolveWebhooksPluginConfig", () => {
|
||||
},
|
||||
},
|
||||
},
|
||||
cfg: {} as OpenClawConfig,
|
||||
env: {},
|
||||
logger: { warn } as never,
|
||||
});
|
||||
|
||||
expect(routes).toEqual([
|
||||
{
|
||||
routeId: "missing",
|
||||
path: "/plugins/webhooks/missing",
|
||||
sessionKey: "agent:main:main",
|
||||
secret: {
|
||||
source: "env",
|
||||
provider: "default",
|
||||
id: "MISSING_SECRET",
|
||||
},
|
||||
controllerId: "webhooks/missing",
|
||||
},
|
||||
]);
|
||||
expect(routes).toEqual([]);
|
||||
expect(warn).toHaveBeenCalledWith(
|
||||
expect.stringContaining("[webhooks] skipping route missing:"),
|
||||
);
|
||||
});
|
||||
|
||||
it("rejects duplicate normalized paths", () => {
|
||||
expect(() =>
|
||||
it("rejects duplicate normalized paths", async () => {
|
||||
await expect(
|
||||
resolveWebhooksPluginConfig({
|
||||
pluginConfig: {
|
||||
routes: {
|
||||
@@ -81,7 +78,9 @@ describe("resolveWebhooksPluginConfig", () => {
|
||||
},
|
||||
},
|
||||
},
|
||||
cfg: {} as OpenClawConfig,
|
||||
env: {},
|
||||
}),
|
||||
).toThrow(/conflicts with routes\.first\.path/i);
|
||||
).rejects.toThrow(/conflicts with routes\.first\.path/i);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
import { z } from "zod";
|
||||
import { normalizeWebhookPath } from "../runtime-api.js";
|
||||
import type { PluginLogger } from "../api.js";
|
||||
import {
|
||||
normalizeWebhookPath,
|
||||
resolveConfiguredSecretInputString,
|
||||
type OpenClawConfig,
|
||||
} from "../runtime-api.js";
|
||||
|
||||
const secretRefSchema = z
|
||||
.object({
|
||||
@@ -28,22 +33,23 @@ const webhooksPluginConfigSchema = z
|
||||
})
|
||||
.strict();
|
||||
|
||||
export type WebhookSecretInput = z.infer<typeof secretInputSchema>;
|
||||
|
||||
export type ConfiguredWebhookRouteConfig = {
|
||||
export type ResolvedWebhookRouteConfig = {
|
||||
routeId: string;
|
||||
path: string;
|
||||
sessionKey: string;
|
||||
secret: WebhookSecretInput;
|
||||
secret: string;
|
||||
controllerId: string;
|
||||
description?: string;
|
||||
};
|
||||
|
||||
export function resolveWebhooksPluginConfig(params: {
|
||||
export async function resolveWebhooksPluginConfig(params: {
|
||||
pluginConfig: unknown;
|
||||
}): ConfiguredWebhookRouteConfig[] {
|
||||
cfg: OpenClawConfig;
|
||||
env: NodeJS.ProcessEnv;
|
||||
logger?: PluginLogger;
|
||||
}): Promise<ResolvedWebhookRouteConfig[]> {
|
||||
const parsed = webhooksPluginConfigSchema.parse(params.pluginConfig ?? {});
|
||||
const configuredRoutes: ConfiguredWebhookRouteConfig[] = [];
|
||||
const resolvedRoutes: ResolvedWebhookRouteConfig[] = [];
|
||||
const seenPaths = new Map<string, string>();
|
||||
|
||||
for (const [routeId, route] of Object.entries(parsed.routes)) {
|
||||
@@ -58,16 +64,32 @@ export function resolveWebhooksPluginConfig(params: {
|
||||
);
|
||||
}
|
||||
|
||||
const secretResolution = await resolveConfiguredSecretInputString({
|
||||
config: params.cfg,
|
||||
env: params.env,
|
||||
value: route.secret,
|
||||
path: `plugins.entries.webhooks.routes.${routeId}.secret`,
|
||||
});
|
||||
const secret = secretResolution.value?.trim();
|
||||
if (!secret) {
|
||||
params.logger?.warn?.(
|
||||
`[webhooks] skipping route ${routeId}: ${
|
||||
secretResolution.unresolvedRefReason ?? "secret is empty or unresolved"
|
||||
}`,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
seenPaths.set(path, routeId);
|
||||
configuredRoutes.push({
|
||||
resolvedRoutes.push({
|
||||
routeId,
|
||||
path,
|
||||
sessionKey: route.sessionKey,
|
||||
secret: route.secret,
|
||||
secret,
|
||||
controllerId: route.controllerId ?? `webhooks/${routeId}`,
|
||||
...(route.description ? { description: route.description } : {}),
|
||||
});
|
||||
}
|
||||
|
||||
return configuredRoutes;
|
||||
return resolvedRoutes;
|
||||
}
|
||||
|
||||
@@ -10,12 +10,10 @@ const hoisted = vi.hoisted(() => {
|
||||
const sendMessageMock = vi.fn();
|
||||
const cancelSessionMock = vi.fn();
|
||||
const killSubagentRunAdminMock = vi.fn();
|
||||
const resolveConfiguredSecretInputStringMock = vi.fn();
|
||||
return {
|
||||
sendMessageMock,
|
||||
cancelSessionMock,
|
||||
killSubagentRunAdminMock,
|
||||
resolveConfiguredSecretInputStringMock,
|
||||
};
|
||||
});
|
||||
|
||||
@@ -33,17 +31,6 @@ vi.mock("../../../src/agents/subagent-control.js", () => ({
|
||||
killSubagentRunAdmin: (params: unknown) => hoisted.killSubagentRunAdminMock(params),
|
||||
}));
|
||||
|
||||
vi.mock("../runtime-api.js", async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import("../runtime-api.js")>();
|
||||
hoisted.resolveConfiguredSecretInputStringMock.mockImplementation(
|
||||
actual.resolveConfiguredSecretInputString,
|
||||
);
|
||||
return {
|
||||
...actual,
|
||||
resolveConfiguredSecretInputString: hoisted.resolveConfiguredSecretInputStringMock,
|
||||
};
|
||||
});
|
||||
|
||||
type MockIncomingMessage = IncomingMessage & {
|
||||
destroyed?: boolean;
|
||||
destroy: () => MockIncomingMessage;
|
||||
@@ -71,7 +58,7 @@ function createJsonRequest(params: {
|
||||
return req;
|
||||
}) as MockIncomingMessage["destroy"];
|
||||
|
||||
setImmediate(() => {
|
||||
void Promise.resolve().then(() => {
|
||||
req.emit("data", Buffer.from(JSON.stringify(params.body), "utf8"));
|
||||
req.emit("end");
|
||||
});
|
||||
@@ -82,16 +69,13 @@ function createJsonRequest(params: {
|
||||
function createHandler(): {
|
||||
handler: ReturnType<typeof createTaskFlowWebhookRequestHandler>;
|
||||
target: TaskFlowWebhookTarget;
|
||||
secret: string;
|
||||
} {
|
||||
const runtime = createRuntimeTaskFlow();
|
||||
nextSessionId += 1;
|
||||
const secret = "shared-secret";
|
||||
const target: TaskFlowWebhookTarget = {
|
||||
routeId: "zapier",
|
||||
path: "/plugins/webhooks/zapier",
|
||||
secretInput: secret,
|
||||
secretConfigPath: "plugins.entries.webhooks.routes.zapier.secret",
|
||||
secret: "shared-secret",
|
||||
defaultControllerId: "webhooks/zapier",
|
||||
taskFlow: runtime.bindSession({
|
||||
sessionKey: `agent:main:webhook-test-${String(nextSessionId)}`,
|
||||
@@ -104,21 +88,9 @@ function createHandler(): {
|
||||
targetsByPath,
|
||||
}),
|
||||
target,
|
||||
secret,
|
||||
};
|
||||
}
|
||||
|
||||
function createHandlerWithTarget(
|
||||
target: TaskFlowWebhookTarget,
|
||||
cfg: OpenClawConfig = {} as OpenClawConfig,
|
||||
): ReturnType<typeof createTaskFlowWebhookRequestHandler> {
|
||||
const targetsByPath = new Map<string, TaskFlowWebhookTarget[]>([[target.path, [target]]]);
|
||||
return createTaskFlowWebhookRequestHandler({
|
||||
cfg,
|
||||
targetsByPath,
|
||||
});
|
||||
}
|
||||
|
||||
async function dispatchJsonRequest(params: {
|
||||
handler: ReturnType<typeof createTaskFlowWebhookRequestHandler>;
|
||||
path: string;
|
||||
@@ -160,53 +132,12 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
expect(target.taskFlow.list()).toEqual([]);
|
||||
});
|
||||
|
||||
it("caches SecretRef resolution across requests for the same route", async () => {
|
||||
const runtime = createRuntimeTaskFlow();
|
||||
const target: TaskFlowWebhookTarget = {
|
||||
routeId: "cached",
|
||||
path: "/plugins/webhooks/cached",
|
||||
secretInput: {
|
||||
source: "env",
|
||||
provider: "default",
|
||||
id: "OPENCLAW_WEBHOOK_SECRET",
|
||||
},
|
||||
secretConfigPath: "plugins.entries.webhooks.routes.cached.secret",
|
||||
defaultControllerId: "webhooks/cached",
|
||||
taskFlow: runtime.bindSession({
|
||||
sessionKey: "agent:main:webhook-cached",
|
||||
}),
|
||||
};
|
||||
hoisted.resolveConfiguredSecretInputStringMock.mockResolvedValue({ value: "shared-secret" });
|
||||
const handler = createHandlerWithTarget(target);
|
||||
|
||||
const first = await dispatchJsonRequest({
|
||||
handler,
|
||||
path: target.path,
|
||||
secret: "shared-secret",
|
||||
body: {
|
||||
action: "list_flows",
|
||||
},
|
||||
});
|
||||
const second = await dispatchJsonRequest({
|
||||
handler,
|
||||
path: target.path,
|
||||
secret: "shared-secret",
|
||||
body: {
|
||||
action: "list_flows",
|
||||
},
|
||||
});
|
||||
|
||||
expect(first.statusCode).toBe(200);
|
||||
expect(second.statusCode).toBe(200);
|
||||
expect(hoisted.resolveConfiguredSecretInputStringMock).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("creates flows through the bound session and scrubs owner metadata from responses", async () => {
|
||||
const { handler, target, secret } = createHandler();
|
||||
const { handler, target } = createHandler();
|
||||
const res = await dispatchJsonRequest({
|
||||
handler,
|
||||
path: target.path,
|
||||
secret,
|
||||
secret: target.secret,
|
||||
body: {
|
||||
action: "create_flow",
|
||||
goal: "Review inbound queue",
|
||||
@@ -227,7 +158,7 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
});
|
||||
|
||||
it("runs child tasks and scrubs task ownership fields from responses", async () => {
|
||||
const { handler, target, secret } = createHandler();
|
||||
const { handler, target } = createHandler();
|
||||
const flow = target.taskFlow.createManaged({
|
||||
controllerId: "webhooks/zapier",
|
||||
goal: "Triage inbox",
|
||||
@@ -235,7 +166,7 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
const res = await dispatchJsonRequest({
|
||||
handler,
|
||||
path: target.path,
|
||||
secret,
|
||||
secret: target.secret,
|
||||
body: {
|
||||
action: "run_task",
|
||||
flowId: flow.flowId,
|
||||
@@ -262,11 +193,11 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
});
|
||||
|
||||
it("returns 404 for missing flow mutations", async () => {
|
||||
const { handler, target, secret } = createHandler();
|
||||
const { handler, target } = createHandler();
|
||||
const res = await dispatchJsonRequest({
|
||||
handler,
|
||||
path: target.path,
|
||||
secret,
|
||||
secret: target.secret,
|
||||
body: {
|
||||
action: "set_waiting",
|
||||
flowId: "flow-missing",
|
||||
@@ -288,7 +219,7 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
});
|
||||
|
||||
it("returns 409 for revision conflicts", async () => {
|
||||
const { handler, target, secret } = createHandler();
|
||||
const { handler, target } = createHandler();
|
||||
const flow = target.taskFlow.createManaged({
|
||||
controllerId: "webhooks/zapier",
|
||||
goal: "Review inbox",
|
||||
@@ -296,7 +227,7 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
const res = await dispatchJsonRequest({
|
||||
handler,
|
||||
path: target.path,
|
||||
secret,
|
||||
secret: target.secret,
|
||||
body: {
|
||||
action: "set_waiting",
|
||||
flowId: flow.flowId,
|
||||
@@ -321,7 +252,7 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
});
|
||||
|
||||
it("rejects internal runtimes and running-only metadata from external callers", async () => {
|
||||
const { handler, target, secret } = createHandler();
|
||||
const { handler, target } = createHandler();
|
||||
const flow = target.taskFlow.createManaged({
|
||||
controllerId: "webhooks/zapier",
|
||||
goal: "Review inbox",
|
||||
@@ -330,7 +261,7 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
const runtimeRes = await dispatchJsonRequest({
|
||||
handler,
|
||||
path: target.path,
|
||||
secret,
|
||||
secret: target.secret,
|
||||
body: {
|
||||
action: "run_task",
|
||||
flowId: flow.flowId,
|
||||
@@ -347,7 +278,7 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
const queuedMetadataRes = await dispatchJsonRequest({
|
||||
handler,
|
||||
path: target.path,
|
||||
secret,
|
||||
secret: target.secret,
|
||||
body: {
|
||||
action: "run_task",
|
||||
flowId: flow.flowId,
|
||||
@@ -366,7 +297,7 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
});
|
||||
|
||||
it("reuses the same task record when retried with the same runId", async () => {
|
||||
const { handler, target, secret } = createHandler();
|
||||
const { handler, target } = createHandler();
|
||||
const flow = target.taskFlow.createManaged({
|
||||
controllerId: "webhooks/zapier",
|
||||
goal: "Triage inbox",
|
||||
@@ -375,7 +306,7 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
const first = await dispatchJsonRequest({
|
||||
handler,
|
||||
path: target.path,
|
||||
secret,
|
||||
secret: target.secret,
|
||||
body: {
|
||||
action: "run_task",
|
||||
flowId: flow.flowId,
|
||||
@@ -388,7 +319,7 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
const second = await dispatchJsonRequest({
|
||||
handler,
|
||||
path: target.path,
|
||||
secret,
|
||||
secret: target.secret,
|
||||
body: {
|
||||
action: "run_task",
|
||||
flowId: flow.flowId,
|
||||
@@ -408,7 +339,7 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
});
|
||||
|
||||
it("returns 409 when cancellation targets a terminal flow", async () => {
|
||||
const { handler, target, secret } = createHandler();
|
||||
const { handler, target } = createHandler();
|
||||
const flow = target.taskFlow.createManaged({
|
||||
controllerId: "webhooks/zapier",
|
||||
goal: "Review inbox",
|
||||
@@ -422,7 +353,7 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
const res = await dispatchJsonRequest({
|
||||
handler,
|
||||
path: target.path,
|
||||
secret,
|
||||
secret: target.secret,
|
||||
body: {
|
||||
action: "cancel_flow",
|
||||
flowId: flow.flowId,
|
||||
|
||||
@@ -8,15 +8,13 @@ import {
|
||||
createWebhookInFlightLimiter,
|
||||
readJsonWebhookBodyOrReject,
|
||||
resolveRequestClientIp,
|
||||
resolveConfiguredSecretInputString,
|
||||
resolveWebhookTargetWithAuthOrReject,
|
||||
resolveWebhookTargetWithAuthOrRejectSync,
|
||||
withResolvedWebhookRequestPipeline,
|
||||
WEBHOOK_IN_FLIGHT_DEFAULTS,
|
||||
WEBHOOK_RATE_LIMIT_DEFAULTS,
|
||||
type OpenClawConfig,
|
||||
type WebhookInFlightLimiter,
|
||||
} from "../runtime-api.js";
|
||||
import type { WebhookSecretInput } from "./config.js";
|
||||
|
||||
type BoundTaskFlowRuntime = ReturnType<PluginRuntime["taskFlow"]["bindSession"]>;
|
||||
|
||||
@@ -176,8 +174,7 @@ type WebhookAction = z.infer<typeof webhookActionSchema>;
|
||||
export type TaskFlowWebhookTarget = {
|
||||
routeId: string;
|
||||
path: string;
|
||||
secretInput: WebhookSecretInput;
|
||||
secretConfigPath: string;
|
||||
secret: string;
|
||||
defaultControllerId: string;
|
||||
taskFlow: BoundTaskFlowRuntime;
|
||||
};
|
||||
@@ -667,7 +664,6 @@ export function createTaskFlowWebhookRequestHandler(params: {
|
||||
targetsByPath: Map<string, TaskFlowWebhookTarget[]>;
|
||||
inFlightLimiter?: WebhookInFlightLimiter;
|
||||
}): (req: IncomingMessage, res: ServerResponse) => Promise<boolean> {
|
||||
const secretByTarget = new WeakMap<TaskFlowWebhookTarget, Promise<string | undefined>>();
|
||||
const rateLimiter = createFixedWindowRateLimiter({
|
||||
windowMs: WEBHOOK_RATE_LIMIT_DEFAULTS.windowMs,
|
||||
maxRequests: WEBHOOK_RATE_LIMIT_DEFAULTS.maxRequests,
|
||||
@@ -679,20 +675,6 @@ export function createTaskFlowWebhookRequestHandler(params: {
|
||||
maxInFlightPerKey: WEBHOOK_IN_FLIGHT_DEFAULTS.maxInFlightPerKey,
|
||||
maxTrackedKeys: WEBHOOK_IN_FLIGHT_DEFAULTS.maxTrackedKeys,
|
||||
});
|
||||
const resolveTargetSecret = (target: TaskFlowWebhookTarget): Promise<string | undefined> => {
|
||||
const cached = secretByTarget.get(target);
|
||||
if (cached) {
|
||||
return cached;
|
||||
}
|
||||
const pending = resolveConfiguredSecretInputString({
|
||||
config: params.cfg,
|
||||
env: process.env,
|
||||
value: target.secretInput,
|
||||
path: target.secretConfigPath,
|
||||
}).then((resolved) => resolved.value);
|
||||
secretByTarget.set(target, pending);
|
||||
return pending;
|
||||
};
|
||||
|
||||
return async (req: IncomingMessage, res: ServerResponse): Promise<boolean> => {
|
||||
return await withResolvedWebhookRequestPipeline({
|
||||
@@ -716,18 +698,11 @@ export function createTaskFlowWebhookRequestHandler(params: {
|
||||
inFlightLimiter,
|
||||
handle: async ({ targets }) => {
|
||||
const presentedSecret = extractSharedSecret(req);
|
||||
const target = await resolveWebhookTargetWithAuthOrReject({
|
||||
const target = resolveWebhookTargetWithAuthOrRejectSync({
|
||||
targets,
|
||||
res,
|
||||
isMatch: async (candidate) => {
|
||||
if (presentedSecret.length === 0) {
|
||||
return false;
|
||||
}
|
||||
const resolvedSecret = await resolveTargetSecret(candidate);
|
||||
return Boolean(
|
||||
resolvedSecret && timingSafeEquals(resolvedSecret, presentedSecret),
|
||||
);
|
||||
},
|
||||
isMatch: (candidate) =>
|
||||
presentedSecret.length > 0 && timingSafeEquals(candidate.secret, presentedSecret),
|
||||
});
|
||||
if (!target) {
|
||||
return true;
|
||||
|
||||
@@ -169,19 +169,6 @@ describe("whatsapp setup wizard", () => {
|
||||
expectWhatsAppAllowlistModeSetup(result.cfg);
|
||||
});
|
||||
|
||||
it("throws a user-facing error instead of crashing when allowlist input is undefined", async () => {
|
||||
const harness = createSeparatePhoneHarness({
|
||||
selectValues: ["separate", "allowlist", "list"],
|
||||
});
|
||||
harness.text.mockResolvedValueOnce(undefined as never);
|
||||
|
||||
await expect(
|
||||
runConfigureWithHarness({
|
||||
harness,
|
||||
}),
|
||||
).rejects.toThrow("Invalid WhatsApp allowFrom list");
|
||||
});
|
||||
|
||||
it("enables allowlist self-chat mode for personal-phone setup", async () => {
|
||||
hoisted.pathExists.mockResolvedValue(true);
|
||||
const harness = createWhatsAppPersonalPhoneHarness(createQueuedWizardPrompter);
|
||||
@@ -193,18 +180,6 @@ describe("whatsapp setup wizard", () => {
|
||||
expectWhatsAppPersonalPhoneSetup(result.cfg);
|
||||
});
|
||||
|
||||
it("throws a user-facing error instead of crashing when personal-phone input is undefined", async () => {
|
||||
hoisted.pathExists.mockResolvedValue(true);
|
||||
const harness = createWhatsAppPersonalPhoneHarness(createQueuedWizardPrompter);
|
||||
harness.text.mockResolvedValueOnce(undefined as never);
|
||||
|
||||
await expect(
|
||||
runConfigureWithHarness({
|
||||
harness,
|
||||
}),
|
||||
).rejects.toThrow("Invalid WhatsApp owner number");
|
||||
});
|
||||
|
||||
it("forces wildcard allowFrom for open policy without allowFrom follow-up prompts", async () => {
|
||||
hoisted.pathExists.mockResolvedValue(true);
|
||||
const harness = createSeparatePhoneHarness({
|
||||
|
||||
@@ -23,10 +23,6 @@ type SetupRuntime = Parameters<NonNullable<ChannelSetupWizard["finalize"]>>[0]["
|
||||
type WhatsAppConfig = NonNullable<NonNullable<OpenClawConfig["channels"]>["whatsapp"]>;
|
||||
type WhatsAppAccountConfig = NonNullable<NonNullable<WhatsAppConfig["accounts"]>[string]>;
|
||||
|
||||
function trimPromptText(value: string | null | undefined): string {
|
||||
return value?.trim() ?? "";
|
||||
}
|
||||
|
||||
function mergeWhatsAppConfig(
|
||||
cfg: OpenClawConfig,
|
||||
accountId: string,
|
||||
@@ -128,7 +124,7 @@ async function promptWhatsAppOwnerAllowFrom(params: {
|
||||
placeholder: "+15555550123",
|
||||
initialValue: existingAllowFrom[0],
|
||||
validate: (value) => {
|
||||
const raw = trimPromptText(value);
|
||||
const raw = value.trim();
|
||||
if (!raw) {
|
||||
return "Required";
|
||||
}
|
||||
@@ -140,7 +136,7 @@ async function promptWhatsAppOwnerAllowFrom(params: {
|
||||
},
|
||||
});
|
||||
|
||||
const normalized = normalizeE164(trimPromptText(entry));
|
||||
const normalized = normalizeE164(entry.trim());
|
||||
if (!normalized) {
|
||||
throw new Error("Invalid WhatsApp owner number (expected E.164 after validation).");
|
||||
}
|
||||
@@ -315,7 +311,7 @@ async function promptWhatsAppDmAccess(params: {
|
||||
message: "Allowed sender numbers (comma-separated, E.164)",
|
||||
placeholder: "+15555550123, +447700900123",
|
||||
validate: (value) => {
|
||||
const raw = trimPromptText(value);
|
||||
const raw = value.trim();
|
||||
if (!raw) {
|
||||
return "Required";
|
||||
}
|
||||
@@ -330,13 +326,7 @@ async function promptWhatsAppDmAccess(params: {
|
||||
},
|
||||
});
|
||||
|
||||
const parsed = parseWhatsAppAllowFromEntries(trimPromptText(allowRaw));
|
||||
if (parsed.invalidEntry) {
|
||||
throw new Error(`Invalid number: ${parsed.invalidEntry}`);
|
||||
}
|
||||
if (parsed.entries.length === 0) {
|
||||
throw new Error("Invalid WhatsApp allowFrom list (expected at least one E.164 number).");
|
||||
}
|
||||
const parsed = parseWhatsAppAllowFromEntries(allowRaw);
|
||||
return setWhatsAppAllowFrom(next, accountId, parsed.entries);
|
||||
}
|
||||
|
||||
|
||||
@@ -1346,8 +1346,6 @@
|
||||
"test:parallels:windows": "bash scripts/e2e/parallels-windows-smoke.sh",
|
||||
"test:perf:budget": "node scripts/test-perf-budget.mjs",
|
||||
"test:perf:changed:bench": "node scripts/bench-test-changed.mjs",
|
||||
"test:perf:groups": "node scripts/test-group-report.mjs",
|
||||
"test:perf:groups:compare": "node scripts/test-group-report.mjs --compare",
|
||||
"test:perf:hotspots": "node scripts/test-hotspots.mjs",
|
||||
"test:perf:imports": "OPENCLAW_VITEST_IMPORT_DURATIONS=1 OPENCLAW_VITEST_PRINT_IMPORT_BREAKDOWN=1 node scripts/test-projects.mjs",
|
||||
"test:perf:imports:changed": "OPENCLAW_VITEST_IMPORT_DURATIONS=1 OPENCLAW_VITEST_PRINT_IMPORT_BREAKDOWN=1 node scripts/test-projects.mjs --changed origin/main",
|
||||
|
||||
116
packages/memory-host-sdk/src/host/batch-gemini.test.ts
Normal file
116
packages/memory-host-sdk/src/host/batch-gemini.test.ts
Normal file
@@ -0,0 +1,116 @@
|
||||
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import type { GeminiEmbeddingClient } from "./embeddings-gemini.js";
|
||||
|
||||
vi.mock("./remote-http.js", () => ({
|
||||
withRemoteHttpResponse: vi.fn(),
|
||||
}));
|
||||
|
||||
function magnitude(values: number[]) {
|
||||
return Math.sqrt(values.reduce((sum, value) => sum + value * value, 0));
|
||||
}
|
||||
|
||||
describe("runGeminiEmbeddingBatches", () => {
|
||||
let runGeminiEmbeddingBatches: typeof import("./batch-gemini.js").runGeminiEmbeddingBatches;
|
||||
let withRemoteHttpResponse: typeof import("./remote-http.js").withRemoteHttpResponse;
|
||||
let remoteHttpMock: ReturnType<typeof vi.mocked<typeof withRemoteHttpResponse>>;
|
||||
|
||||
beforeAll(async () => {
|
||||
({ runGeminiEmbeddingBatches } = await import("./batch-gemini.js"));
|
||||
({ withRemoteHttpResponse } = await import("./remote-http.js"));
|
||||
remoteHttpMock = vi.mocked(withRemoteHttpResponse);
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.resetAllMocks();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
const mockClient: GeminiEmbeddingClient = {
|
||||
baseUrl: "https://generativelanguage.googleapis.com/v1beta",
|
||||
headers: {},
|
||||
model: "gemini-embedding-2-preview",
|
||||
modelPath: "models/gemini-embedding-2-preview",
|
||||
apiKeys: ["test-key"],
|
||||
outputDimensionality: 1536,
|
||||
};
|
||||
|
||||
it("includes outputDimensionality in batch upload requests", async () => {
|
||||
remoteHttpMock.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toContain("/upload/v1beta/files?uploadType=multipart");
|
||||
const body = params.init?.body;
|
||||
if (!(body instanceof Blob)) {
|
||||
throw new Error("expected multipart blob body");
|
||||
}
|
||||
const text = await body.text();
|
||||
expect(text).toContain('"taskType":"RETRIEVAL_DOCUMENT"');
|
||||
expect(text).toContain('"outputDimensionality":1536');
|
||||
return await params.onResponse(
|
||||
new Response(JSON.stringify({ name: "files/file-123" }), {
|
||||
status: 200,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
}),
|
||||
);
|
||||
});
|
||||
remoteHttpMock.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toMatch(/:asyncBatchEmbedContent$/u);
|
||||
return await params.onResponse(
|
||||
new Response(
|
||||
JSON.stringify({
|
||||
name: "batches/batch-1",
|
||||
state: "COMPLETED",
|
||||
outputConfig: { file: "files/output-1" },
|
||||
}),
|
||||
{
|
||||
status: 200,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
},
|
||||
),
|
||||
);
|
||||
});
|
||||
remoteHttpMock.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toMatch(/\/files\/output-1:download$/u);
|
||||
return await params.onResponse(
|
||||
new Response(
|
||||
JSON.stringify({
|
||||
key: "req-1",
|
||||
response: { embedding: { values: [3, 4] } },
|
||||
}),
|
||||
{
|
||||
status: 200,
|
||||
headers: { "Content-Type": "application/jsonl" },
|
||||
},
|
||||
),
|
||||
);
|
||||
});
|
||||
|
||||
const results = await runGeminiEmbeddingBatches({
|
||||
gemini: mockClient,
|
||||
agentId: "main",
|
||||
requests: [
|
||||
{
|
||||
custom_id: "req-1",
|
||||
request: {
|
||||
content: { parts: [{ text: "hello world" }] },
|
||||
taskType: "RETRIEVAL_DOCUMENT",
|
||||
outputDimensionality: 1536,
|
||||
},
|
||||
},
|
||||
],
|
||||
wait: true,
|
||||
pollIntervalMs: 1,
|
||||
timeoutMs: 1000,
|
||||
concurrency: 1,
|
||||
});
|
||||
|
||||
const embedding = results.get("req-1");
|
||||
expect(embedding).toBeDefined();
|
||||
expect(embedding?.[0]).toBeCloseTo(0.6, 5);
|
||||
expect(embedding?.[1]).toBeCloseTo(0.8, 5);
|
||||
expect(magnitude(embedding ?? [])).toBeCloseTo(1, 5);
|
||||
expect(remoteHttpMock).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
});
|
||||
@@ -1,15 +1,14 @@
|
||||
import crypto from "node:crypto";
|
||||
import {
|
||||
buildEmbeddingBatchGroupOptions,
|
||||
runEmbeddingBatchGroups,
|
||||
type EmbeddingBatchExecutionParams,
|
||||
buildBatchHeaders,
|
||||
debugEmbeddingsLog,
|
||||
normalizeBatchBaseUrl,
|
||||
sanitizeAndNormalizeEmbedding,
|
||||
withRemoteHttpResponse,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import type { GeminiEmbeddingClient, GeminiTextEmbeddingRequest } from "./embedding-provider.js";
|
||||
} from "./batch-runner.js";
|
||||
import { buildBatchHeaders, normalizeBatchBaseUrl } from "./batch-utils.js";
|
||||
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
|
||||
import { debugEmbeddingsLog } from "./embeddings-debug.js";
|
||||
import type { GeminiEmbeddingClient, GeminiTextEmbeddingRequest } from "./embeddings-gemini.js";
|
||||
import { hashText } from "./internal.js";
|
||||
import { withRemoteHttpResponse } from "./remote-http.js";
|
||||
|
||||
export type GeminiBatchRequest = {
|
||||
custom_id: string;
|
||||
@@ -41,10 +40,6 @@ export type GeminiBatchOutputLine = {
|
||||
};
|
||||
|
||||
const GEMINI_BATCH_MAX_REQUESTS = 50000;
|
||||
function hashText(text: string): string {
|
||||
return crypto.createHash("sha256").update(text).digest("hex");
|
||||
}
|
||||
|
||||
function getGeminiUploadUrl(baseUrl: string): string {
|
||||
if (baseUrl.includes("/v1beta")) {
|
||||
return baseUrl.replace(/\/v1beta\/?$/, "/upload/v1beta");
|
||||
259
packages/memory-host-sdk/src/host/batch-openai.ts
Normal file
259
packages/memory-host-sdk/src/host/batch-openai.ts
Normal file
@@ -0,0 +1,259 @@
|
||||
import {
|
||||
applyEmbeddingBatchOutputLine,
|
||||
buildBatchHeaders,
|
||||
buildEmbeddingBatchGroupOptions,
|
||||
EMBEDDING_BATCH_ENDPOINT,
|
||||
extractBatchErrorMessage,
|
||||
formatUnavailableBatchError,
|
||||
normalizeBatchBaseUrl,
|
||||
postJsonWithRetry,
|
||||
resolveBatchCompletionFromStatus,
|
||||
resolveCompletedBatchResult,
|
||||
runEmbeddingBatchGroups,
|
||||
throwIfBatchTerminalFailure,
|
||||
type EmbeddingBatchExecutionParams,
|
||||
type EmbeddingBatchStatus,
|
||||
type BatchCompletionResult,
|
||||
type ProviderBatchOutputLine,
|
||||
uploadBatchJsonlFile,
|
||||
withRemoteHttpResponse,
|
||||
} from "./batch-embedding-common.js";
|
||||
import type { OpenAiEmbeddingClient } from "./embeddings-openai.js";
|
||||
|
||||
export type OpenAiBatchRequest = {
|
||||
custom_id: string;
|
||||
method: "POST";
|
||||
url: "/v1/embeddings";
|
||||
body: {
|
||||
model: string;
|
||||
input: string;
|
||||
};
|
||||
};
|
||||
|
||||
export type OpenAiBatchStatus = EmbeddingBatchStatus;
|
||||
export type OpenAiBatchOutputLine = ProviderBatchOutputLine;
|
||||
|
||||
export const OPENAI_BATCH_ENDPOINT = EMBEDDING_BATCH_ENDPOINT;
|
||||
const OPENAI_BATCH_COMPLETION_WINDOW = "24h";
|
||||
const OPENAI_BATCH_MAX_REQUESTS = 50000;
|
||||
|
||||
async function submitOpenAiBatch(params: {
|
||||
openAi: OpenAiEmbeddingClient;
|
||||
requests: OpenAiBatchRequest[];
|
||||
agentId: string;
|
||||
}): Promise<OpenAiBatchStatus> {
|
||||
const baseUrl = normalizeBatchBaseUrl(params.openAi);
|
||||
const inputFileId = await uploadBatchJsonlFile({
|
||||
client: params.openAi,
|
||||
requests: params.requests,
|
||||
errorPrefix: "openai batch file upload failed",
|
||||
});
|
||||
|
||||
return await postJsonWithRetry<OpenAiBatchStatus>({
|
||||
url: `${baseUrl}/batches`,
|
||||
headers: buildBatchHeaders(params.openAi, { json: true }),
|
||||
ssrfPolicy: params.openAi.ssrfPolicy,
|
||||
body: {
|
||||
input_file_id: inputFileId,
|
||||
endpoint: OPENAI_BATCH_ENDPOINT,
|
||||
completion_window: OPENAI_BATCH_COMPLETION_WINDOW,
|
||||
metadata: {
|
||||
source: "openclaw-memory",
|
||||
agent: params.agentId,
|
||||
},
|
||||
},
|
||||
errorPrefix: "openai batch create failed",
|
||||
});
|
||||
}
|
||||
|
||||
async function fetchOpenAiBatchStatus(params: {
|
||||
openAi: OpenAiEmbeddingClient;
|
||||
batchId: string;
|
||||
}): Promise<OpenAiBatchStatus> {
|
||||
return await fetchOpenAiBatchResource({
|
||||
openAi: params.openAi,
|
||||
path: `/batches/${params.batchId}`,
|
||||
errorPrefix: "openai batch status",
|
||||
parse: async (res) => (await res.json()) as OpenAiBatchStatus,
|
||||
});
|
||||
}
|
||||
|
||||
async function fetchOpenAiFileContent(params: {
|
||||
openAi: OpenAiEmbeddingClient;
|
||||
fileId: string;
|
||||
}): Promise<string> {
|
||||
return await fetchOpenAiBatchResource({
|
||||
openAi: params.openAi,
|
||||
path: `/files/${params.fileId}/content`,
|
||||
errorPrefix: "openai batch file content",
|
||||
parse: async (res) => await res.text(),
|
||||
});
|
||||
}
|
||||
|
||||
async function fetchOpenAiBatchResource<T>(params: {
|
||||
openAi: OpenAiEmbeddingClient;
|
||||
path: string;
|
||||
errorPrefix: string;
|
||||
parse: (res: Response) => Promise<T>;
|
||||
}): Promise<T> {
|
||||
const baseUrl = normalizeBatchBaseUrl(params.openAi);
|
||||
return await withRemoteHttpResponse({
|
||||
url: `${baseUrl}${params.path}`,
|
||||
ssrfPolicy: params.openAi.ssrfPolicy,
|
||||
init: {
|
||||
headers: buildBatchHeaders(params.openAi, { json: true }),
|
||||
},
|
||||
onResponse: async (res) => {
|
||||
if (!res.ok) {
|
||||
const text = await res.text();
|
||||
throw new Error(`${params.errorPrefix} failed: ${res.status} ${text}`);
|
||||
}
|
||||
return await params.parse(res);
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
function parseOpenAiBatchOutput(text: string): OpenAiBatchOutputLine[] {
|
||||
if (!text.trim()) {
|
||||
return [];
|
||||
}
|
||||
return text
|
||||
.split("\n")
|
||||
.map((line) => line.trim())
|
||||
.filter(Boolean)
|
||||
.map((line) => JSON.parse(line) as OpenAiBatchOutputLine);
|
||||
}
|
||||
|
||||
async function readOpenAiBatchError(params: {
|
||||
openAi: OpenAiEmbeddingClient;
|
||||
errorFileId: string;
|
||||
}): Promise<string | undefined> {
|
||||
try {
|
||||
const content = await fetchOpenAiFileContent({
|
||||
openAi: params.openAi,
|
||||
fileId: params.errorFileId,
|
||||
});
|
||||
const lines = parseOpenAiBatchOutput(content);
|
||||
return extractBatchErrorMessage(lines);
|
||||
} catch (err) {
|
||||
return formatUnavailableBatchError(err);
|
||||
}
|
||||
}
|
||||
|
||||
async function waitForOpenAiBatch(params: {
|
||||
openAi: OpenAiEmbeddingClient;
|
||||
batchId: string;
|
||||
wait: boolean;
|
||||
pollIntervalMs: number;
|
||||
timeoutMs: number;
|
||||
debug?: (message: string, data?: Record<string, unknown>) => void;
|
||||
initial?: OpenAiBatchStatus;
|
||||
}): Promise<BatchCompletionResult> {
|
||||
const start = Date.now();
|
||||
let current: OpenAiBatchStatus | undefined = params.initial;
|
||||
while (true) {
|
||||
const status =
|
||||
current ??
|
||||
(await fetchOpenAiBatchStatus({
|
||||
openAi: params.openAi,
|
||||
batchId: params.batchId,
|
||||
}));
|
||||
const state = status.status ?? "unknown";
|
||||
if (state === "completed") {
|
||||
return resolveBatchCompletionFromStatus({
|
||||
provider: "openai",
|
||||
batchId: params.batchId,
|
||||
status,
|
||||
});
|
||||
}
|
||||
await throwIfBatchTerminalFailure({
|
||||
provider: "openai",
|
||||
status: { ...status, id: params.batchId },
|
||||
readError: async (errorFileId) =>
|
||||
await readOpenAiBatchError({
|
||||
openAi: params.openAi,
|
||||
errorFileId,
|
||||
}),
|
||||
});
|
||||
if (!params.wait) {
|
||||
throw new Error(`openai batch ${params.batchId} still ${state}; wait disabled`);
|
||||
}
|
||||
if (Date.now() - start > params.timeoutMs) {
|
||||
throw new Error(`openai batch ${params.batchId} timed out after ${params.timeoutMs}ms`);
|
||||
}
|
||||
params.debug?.(`openai batch ${params.batchId} ${state}; waiting ${params.pollIntervalMs}ms`);
|
||||
await new Promise((resolve) => setTimeout(resolve, params.pollIntervalMs));
|
||||
current = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
export async function runOpenAiEmbeddingBatches(
|
||||
params: {
|
||||
openAi: OpenAiEmbeddingClient;
|
||||
agentId: string;
|
||||
requests: OpenAiBatchRequest[];
|
||||
} & EmbeddingBatchExecutionParams,
|
||||
): Promise<Map<string, number[]>> {
|
||||
return await runEmbeddingBatchGroups({
|
||||
...buildEmbeddingBatchGroupOptions(params, {
|
||||
maxRequests: OPENAI_BATCH_MAX_REQUESTS,
|
||||
debugLabel: "memory embeddings: openai batch submit",
|
||||
}),
|
||||
runGroup: async ({ group, groupIndex, groups, byCustomId }) => {
|
||||
const batchInfo = await submitOpenAiBatch({
|
||||
openAi: params.openAi,
|
||||
requests: group,
|
||||
agentId: params.agentId,
|
||||
});
|
||||
if (!batchInfo.id) {
|
||||
throw new Error("openai batch create failed: missing batch id");
|
||||
}
|
||||
const batchId = batchInfo.id;
|
||||
|
||||
params.debug?.("memory embeddings: openai batch created", {
|
||||
batchId: batchInfo.id,
|
||||
status: batchInfo.status,
|
||||
group: groupIndex + 1,
|
||||
groups,
|
||||
requests: group.length,
|
||||
});
|
||||
|
||||
const completed = await resolveCompletedBatchResult({
|
||||
provider: "openai",
|
||||
status: batchInfo,
|
||||
wait: params.wait,
|
||||
waitForBatch: async () =>
|
||||
await waitForOpenAiBatch({
|
||||
openAi: params.openAi,
|
||||
batchId,
|
||||
wait: params.wait,
|
||||
pollIntervalMs: params.pollIntervalMs,
|
||||
timeoutMs: params.timeoutMs,
|
||||
debug: params.debug,
|
||||
initial: batchInfo,
|
||||
}),
|
||||
});
|
||||
|
||||
const content = await fetchOpenAiFileContent({
|
||||
openAi: params.openAi,
|
||||
fileId: completed.outputFileId,
|
||||
});
|
||||
const outputLines = parseOpenAiBatchOutput(content);
|
||||
const errors: string[] = [];
|
||||
const remaining = new Set(group.map((request) => request.custom_id));
|
||||
|
||||
for (const line of outputLines) {
|
||||
applyEmbeddingBatchOutputLine({ line, remaining, errors, byCustomId });
|
||||
}
|
||||
|
||||
if (errors.length > 0) {
|
||||
throw new Error(`openai batch ${batchInfo.id} failed: ${errors.join("; ")}`);
|
||||
}
|
||||
if (remaining.size > 0) {
|
||||
throw new Error(
|
||||
`openai batch ${batchInfo.id} missing ${remaining.size} embedding responses`,
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}
|
||||
176
packages/memory-host-sdk/src/host/batch-voyage.test.ts
Normal file
176
packages/memory-host-sdk/src/host/batch-voyage.test.ts
Normal file
@@ -0,0 +1,176 @@
|
||||
import { ReadableStream } from "node:stream/web";
|
||||
import { setTimeout as nativeSleep } from "node:timers/promises";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import {
|
||||
runVoyageEmbeddingBatches,
|
||||
type VoyageBatchOutputLine,
|
||||
type VoyageBatchRequest,
|
||||
} from "./batch-voyage.js";
|
||||
import type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
|
||||
|
||||
const realNow = Date.now.bind(Date);
|
||||
|
||||
describe("runVoyageEmbeddingBatches", () => {
|
||||
const mockClient: VoyageEmbeddingClient = {
|
||||
baseUrl: "https://api.voyageai.com/v1",
|
||||
headers: { Authorization: "Bearer test-key" },
|
||||
model: "voyage-4-large",
|
||||
};
|
||||
|
||||
const mockRequests: VoyageBatchRequest[] = [
|
||||
{ custom_id: "req-1", body: { input: "text1" } },
|
||||
{ custom_id: "req-2", body: { input: "text2" } },
|
||||
];
|
||||
|
||||
it("successfully submits batch, waits, and streams results", async () => {
|
||||
const outputLines: VoyageBatchOutputLine[] = [
|
||||
{
|
||||
custom_id: "req-1",
|
||||
response: { status_code: 200, body: { data: [{ embedding: [0.1, 0.1] }] } },
|
||||
},
|
||||
{
|
||||
custom_id: "req-2",
|
||||
response: { status_code: 200, body: { data: [{ embedding: [0.2, 0.2] }] } },
|
||||
},
|
||||
];
|
||||
const withRemoteHttpResponse = vi.fn();
|
||||
const postJsonWithRetry = vi.fn();
|
||||
const uploadBatchJsonlFile = vi.fn();
|
||||
|
||||
// Create a stream that emits the NDJSON lines
|
||||
const stream = new ReadableStream({
|
||||
start(controller) {
|
||||
const text = outputLines.map((l) => JSON.stringify(l)).join("\n");
|
||||
controller.enqueue(new TextEncoder().encode(text));
|
||||
controller.close();
|
||||
},
|
||||
});
|
||||
uploadBatchJsonlFile.mockImplementationOnce(async (params) => {
|
||||
expect(params.errorPrefix).toBe("voyage batch file upload failed");
|
||||
expect(params.requests).toEqual(mockRequests);
|
||||
return "file-123";
|
||||
});
|
||||
postJsonWithRetry.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toContain("/batches");
|
||||
expect(params.body).toMatchObject({
|
||||
input_file_id: "file-123",
|
||||
completion_window: "12h",
|
||||
request_params: {
|
||||
model: "voyage-4-large",
|
||||
input_type: "document",
|
||||
},
|
||||
});
|
||||
return {
|
||||
id: "batch-abc",
|
||||
status: "pending",
|
||||
};
|
||||
});
|
||||
withRemoteHttpResponse.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toContain("/batches/batch-abc");
|
||||
return await params.onResponse(
|
||||
new Response(
|
||||
JSON.stringify({
|
||||
id: "batch-abc",
|
||||
status: "completed",
|
||||
output_file_id: "file-out-999",
|
||||
}),
|
||||
{
|
||||
status: 200,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
},
|
||||
),
|
||||
);
|
||||
});
|
||||
withRemoteHttpResponse.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toContain("/files/file-out-999/content");
|
||||
return await params.onResponse(
|
||||
new Response(stream as unknown as BodyInit, {
|
||||
status: 200,
|
||||
headers: { "Content-Type": "application/x-ndjson" },
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
const results = await runVoyageEmbeddingBatches({
|
||||
client: mockClient,
|
||||
agentId: "agent-1",
|
||||
requests: mockRequests,
|
||||
wait: true,
|
||||
pollIntervalMs: 1, // fast poll
|
||||
timeoutMs: 1000,
|
||||
concurrency: 1,
|
||||
deps: {
|
||||
now: realNow,
|
||||
sleep: async (ms) => {
|
||||
await nativeSleep(ms);
|
||||
},
|
||||
postJsonWithRetry,
|
||||
uploadBatchJsonlFile,
|
||||
withRemoteHttpResponse,
|
||||
},
|
||||
});
|
||||
|
||||
expect(results.size).toBe(2);
|
||||
expect(results.get("req-1")).toEqual([0.1, 0.1]);
|
||||
expect(results.get("req-2")).toEqual([0.2, 0.2]);
|
||||
expect(uploadBatchJsonlFile).toHaveBeenCalledTimes(1);
|
||||
expect(postJsonWithRetry).toHaveBeenCalledTimes(1);
|
||||
expect(withRemoteHttpResponse).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it("handles empty lines and stream chunks correctly", async () => {
|
||||
const withRemoteHttpResponse = vi.fn();
|
||||
const postJsonWithRetry = vi.fn();
|
||||
const uploadBatchJsonlFile = vi.fn();
|
||||
const stream = new ReadableStream({
|
||||
start(controller) {
|
||||
const line1 = JSON.stringify({
|
||||
custom_id: "req-1",
|
||||
response: { body: { data: [{ embedding: [1] }] } },
|
||||
});
|
||||
const line2 = JSON.stringify({
|
||||
custom_id: "req-2",
|
||||
response: { body: { data: [{ embedding: [2] }] } },
|
||||
});
|
||||
|
||||
// Split across chunks
|
||||
controller.enqueue(new TextEncoder().encode(line1 + "\n"));
|
||||
controller.enqueue(new TextEncoder().encode("\n")); // empty line
|
||||
controller.enqueue(new TextEncoder().encode(line2)); // no newline at EOF
|
||||
controller.close();
|
||||
},
|
||||
});
|
||||
uploadBatchJsonlFile.mockResolvedValueOnce("f1");
|
||||
postJsonWithRetry.mockResolvedValueOnce({
|
||||
id: "b1",
|
||||
status: "completed",
|
||||
output_file_id: "out1",
|
||||
});
|
||||
withRemoteHttpResponse.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toContain("/files/out1/content");
|
||||
return await params.onResponse(new Response(stream as unknown as BodyInit, { status: 200 }));
|
||||
});
|
||||
|
||||
const results = await runVoyageEmbeddingBatches({
|
||||
client: mockClient,
|
||||
agentId: "a1",
|
||||
requests: mockRequests,
|
||||
wait: true,
|
||||
pollIntervalMs: 1,
|
||||
timeoutMs: 1000,
|
||||
concurrency: 1,
|
||||
deps: {
|
||||
now: realNow,
|
||||
sleep: async (ms) => {
|
||||
await nativeSleep(ms);
|
||||
},
|
||||
postJsonWithRetry,
|
||||
uploadBatchJsonlFile,
|
||||
withRemoteHttpResponse,
|
||||
},
|
||||
});
|
||||
|
||||
expect(results.get("req-1")).toEqual([1]);
|
||||
expect(results.get("req-2")).toEqual([2]);
|
||||
});
|
||||
});
|
||||
@@ -19,8 +19,8 @@ import {
|
||||
type ProviderBatchOutputLine,
|
||||
uploadBatchJsonlFile,
|
||||
withRemoteHttpResponse,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import type { VoyageEmbeddingClient } from "./embedding-provider.js";
|
||||
} from "./batch-embedding-common.js";
|
||||
import type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
|
||||
|
||||
/**
|
||||
* Voyage Batch API Input Line format.
|
||||
@@ -1,14 +1,40 @@
|
||||
import { normalizeLowercaseStringOrEmpty } from "../../../../src/shared/string-coerce.js";
|
||||
import type { EmbeddingProvider } from "./embeddings.js";
|
||||
|
||||
const DEFAULT_EMBEDDING_MAX_INPUT_TOKENS = 8192;
|
||||
const DEFAULT_LOCAL_EMBEDDING_MAX_INPUT_TOKENS = 2048;
|
||||
|
||||
const KNOWN_EMBEDDING_MAX_INPUT_TOKENS: Record<string, number> = {
|
||||
"openai:text-embedding-3-small": 8192,
|
||||
"openai:text-embedding-3-large": 8192,
|
||||
"openai:text-embedding-ada-002": 8191,
|
||||
"gemini:text-embedding-004": 2048,
|
||||
"gemini:gemini-embedding-001": 2048,
|
||||
"gemini:gemini-embedding-2-preview": 8192,
|
||||
"voyage:voyage-3": 32000,
|
||||
"voyage:voyage-3-lite": 16000,
|
||||
"voyage:voyage-code-3": 32000,
|
||||
};
|
||||
|
||||
export function resolveEmbeddingMaxInputTokens(provider: EmbeddingProvider): number {
|
||||
if (typeof provider.maxInputTokens === "number") {
|
||||
return provider.maxInputTokens;
|
||||
}
|
||||
|
||||
if (provider.id === "local") {
|
||||
// Provider/model mapping is best-effort; different providers use different
|
||||
// limits and we prefer to be conservative when we don't know.
|
||||
const key = normalizeLowercaseStringOrEmpty(`${provider.id}:${provider.model}`);
|
||||
const known = KNOWN_EMBEDDING_MAX_INPUT_TOKENS[key];
|
||||
if (typeof known === "number") {
|
||||
return known;
|
||||
}
|
||||
|
||||
// Provider-specific conservative fallbacks. This prevents us from accidentally
|
||||
// using the OpenAI default for providers with much smaller limits.
|
||||
if (normalizeLowercaseStringOrEmpty(provider.id) === "gemini") {
|
||||
return 2048;
|
||||
}
|
||||
if (normalizeLowercaseStringOrEmpty(provider.id) === "local") {
|
||||
return DEFAULT_LOCAL_EMBEDDING_MAX_INPUT_TOKENS;
|
||||
}
|
||||
|
||||
|
||||
377
packages/memory-host-sdk/src/host/embeddings-bedrock.test.ts
Normal file
377
packages/memory-host-sdk/src/host/embeddings-bedrock.test.ts
Normal file
@@ -0,0 +1,377 @@
|
||||
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
const { defaultProviderMock, resolveCredentialsMock, sendMock } = vi.hoisted(() => ({
|
||||
defaultProviderMock: vi.fn(),
|
||||
resolveCredentialsMock: vi.fn(),
|
||||
sendMock: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("@aws-sdk/client-bedrock-runtime", () => {
|
||||
class MockClient {
|
||||
region: string;
|
||||
constructor(config: { region: string }) {
|
||||
this.region = config.region;
|
||||
}
|
||||
send = sendMock;
|
||||
}
|
||||
class MockCommand {
|
||||
input: unknown;
|
||||
constructor(input: unknown) {
|
||||
this.input = input;
|
||||
}
|
||||
}
|
||||
return { BedrockRuntimeClient: MockClient, InvokeModelCommand: MockCommand };
|
||||
});
|
||||
|
||||
vi.mock("@aws-sdk/credential-provider-node", () => ({
|
||||
defaultProvider: defaultProviderMock.mockImplementation(() => resolveCredentialsMock),
|
||||
}));
|
||||
|
||||
let createBedrockEmbeddingProvider: typeof import("./embeddings-bedrock.js").createBedrockEmbeddingProvider;
|
||||
let resolveBedrockEmbeddingClient: typeof import("./embeddings-bedrock.js").resolveBedrockEmbeddingClient;
|
||||
let normalizeBedrockEmbeddingModel: typeof import("./embeddings-bedrock.js").normalizeBedrockEmbeddingModel;
|
||||
let hasAwsCredentials: typeof import("./embeddings-bedrock.js").hasAwsCredentials;
|
||||
|
||||
beforeAll(async () => {
|
||||
({
|
||||
createBedrockEmbeddingProvider,
|
||||
resolveBedrockEmbeddingClient,
|
||||
normalizeBedrockEmbeddingModel,
|
||||
hasAwsCredentials,
|
||||
} = await import("./embeddings-bedrock.js"));
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
defaultProviderMock.mockImplementation(() => resolveCredentialsMock);
|
||||
});
|
||||
|
||||
const enc = (body: unknown) => ({ body: new TextEncoder().encode(JSON.stringify(body)) });
|
||||
const reqBody = (i = 0): Record<string, unknown> =>
|
||||
JSON.parse(sendMock.mock.calls[i][0].input.body);
|
||||
|
||||
describe("bedrock embedding provider", () => {
|
||||
const originalEnv = process.env;
|
||||
afterEach(() => {
|
||||
process.env = originalEnv;
|
||||
vi.restoreAllMocks();
|
||||
defaultProviderMock.mockClear();
|
||||
resolveCredentialsMock.mockReset();
|
||||
sendMock.mockReset();
|
||||
});
|
||||
|
||||
// --- Normalization ---
|
||||
|
||||
it("normalizes model names with prefixes", () => {
|
||||
expect(normalizeBedrockEmbeddingModel("bedrock/amazon.titan-embed-text-v2:0")).toBe(
|
||||
"amazon.titan-embed-text-v2:0",
|
||||
);
|
||||
expect(normalizeBedrockEmbeddingModel("amazon-bedrock/cohere.embed-english-v3")).toBe(
|
||||
"cohere.embed-english-v3",
|
||||
);
|
||||
expect(normalizeBedrockEmbeddingModel("")).toBe("amazon.titan-embed-text-v2:0");
|
||||
});
|
||||
|
||||
// --- Client resolution ---
|
||||
|
||||
it("resolves region from env", () => {
|
||||
process.env = { ...originalEnv, AWS_REGION: "eu-west-1" };
|
||||
const c = resolveBedrockEmbeddingClient({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.titan-embed-text-v2:0",
|
||||
fallback: "none",
|
||||
});
|
||||
expect(c.region).toBe("eu-west-1");
|
||||
expect(c.dimensions).toBe(1024);
|
||||
});
|
||||
|
||||
it("defaults to us-east-1", () => {
|
||||
process.env = { ...originalEnv };
|
||||
delete process.env.AWS_REGION;
|
||||
delete process.env.AWS_DEFAULT_REGION;
|
||||
expect(
|
||||
resolveBedrockEmbeddingClient({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.titan-embed-text-v2:0",
|
||||
fallback: "none",
|
||||
}).region,
|
||||
).toBe("us-east-1");
|
||||
});
|
||||
|
||||
it("extracts region from baseUrl", () => {
|
||||
process.env = { ...originalEnv };
|
||||
delete process.env.AWS_REGION;
|
||||
const c = resolveBedrockEmbeddingClient({
|
||||
config: {
|
||||
models: {
|
||||
providers: {
|
||||
"amazon-bedrock": { baseUrl: "https://bedrock-runtime.ap-southeast-2.amazonaws.com" },
|
||||
},
|
||||
},
|
||||
} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.titan-embed-text-v2:0",
|
||||
fallback: "none",
|
||||
});
|
||||
expect(c.region).toBe("ap-southeast-2");
|
||||
});
|
||||
|
||||
it("validates dimensions", () => {
|
||||
expect(() =>
|
||||
resolveBedrockEmbeddingClient({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.titan-embed-text-v2:0",
|
||||
fallback: "none",
|
||||
outputDimensionality: 768,
|
||||
}),
|
||||
).toThrow("Invalid dimensions 768");
|
||||
});
|
||||
|
||||
it("accepts valid dimensions", () => {
|
||||
expect(
|
||||
resolveBedrockEmbeddingClient({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.titan-embed-text-v2:0",
|
||||
fallback: "none",
|
||||
outputDimensionality: 256,
|
||||
}).dimensions,
|
||||
).toBe(256);
|
||||
});
|
||||
|
||||
it("resolves throughput-suffixed variants", () => {
|
||||
expect(
|
||||
resolveBedrockEmbeddingClient({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.titan-embed-text-v1:2:8k",
|
||||
fallback: "none",
|
||||
}).dimensions,
|
||||
).toBe(1536);
|
||||
});
|
||||
|
||||
// --- Credential detection ---
|
||||
|
||||
it("detects access keys", async () => {
|
||||
await expect(
|
||||
hasAwsCredentials({
|
||||
AWS_ACCESS_KEY_ID: "A",
|
||||
AWS_SECRET_ACCESS_KEY: "s",
|
||||
} as NodeJS.ProcessEnv),
|
||||
).resolves.toBe(true);
|
||||
});
|
||||
it("detects profile", async () => {
|
||||
await expect(hasAwsCredentials({ AWS_PROFILE: "default" } as NodeJS.ProcessEnv)).resolves.toBe(
|
||||
true,
|
||||
);
|
||||
});
|
||||
it("detects ECS task role", async () => {
|
||||
await expect(
|
||||
hasAwsCredentials({ AWS_CONTAINER_CREDENTIALS_RELATIVE_URI: "/v2" } as NodeJS.ProcessEnv),
|
||||
).resolves.toBe(true);
|
||||
});
|
||||
it("detects EKS IRSA", async () => {
|
||||
await expect(
|
||||
hasAwsCredentials({
|
||||
AWS_WEB_IDENTITY_TOKEN_FILE: "/var/run/secrets/token",
|
||||
AWS_ROLE_ARN: "arn:aws:iam::123:role/x",
|
||||
} as NodeJS.ProcessEnv),
|
||||
).resolves.toBe(true);
|
||||
});
|
||||
it("detects credentials via the AWS SDK default provider chain", async () => {
|
||||
resolveCredentialsMock.mockResolvedValue({ accessKeyId: "AKIAEXAMPLE" });
|
||||
await expect(hasAwsCredentials({} as NodeJS.ProcessEnv)).resolves.toBe(true);
|
||||
expect(defaultProviderMock).toHaveBeenCalledWith({ timeout: 1000, maxRetries: 0 });
|
||||
});
|
||||
it("returns false with no creds", async () => {
|
||||
resolveCredentialsMock.mockRejectedValue(new Error("no aws credentials"));
|
||||
await expect(hasAwsCredentials({} as NodeJS.ProcessEnv)).resolves.toBe(false);
|
||||
});
|
||||
|
||||
// --- Titan V2 ---
|
||||
|
||||
it("embeds with Titan V2", async () => {
|
||||
sendMock.mockResolvedValue(enc({ embedding: [0.1, 0.2, 0.3] }));
|
||||
const { provider } = await createBedrockEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.titan-embed-text-v2:0",
|
||||
fallback: "none",
|
||||
});
|
||||
expect(await provider.embedQuery("test")).toHaveLength(3);
|
||||
expect(reqBody()).toMatchObject({ inputText: "test", normalize: true, dimensions: 1024 });
|
||||
});
|
||||
|
||||
it("returns empty for blank text", async () => {
|
||||
const { provider } = await createBedrockEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.titan-embed-text-v2:0",
|
||||
fallback: "none",
|
||||
});
|
||||
expect(await provider.embedQuery(" ")).toEqual([]);
|
||||
expect(sendMock).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("batches Titan V2 concurrently", async () => {
|
||||
sendMock
|
||||
.mockResolvedValueOnce(enc({ embedding: [0.1] }))
|
||||
.mockResolvedValueOnce(enc({ embedding: [0.2] }));
|
||||
const { provider } = await createBedrockEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.titan-embed-text-v2:0",
|
||||
fallback: "none",
|
||||
});
|
||||
expect(await provider.embedBatch(["a", "b"])).toHaveLength(2);
|
||||
expect(sendMock).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
// --- Titan V1 ---
|
||||
|
||||
it("sends only inputText for Titan V1", async () => {
|
||||
sendMock.mockResolvedValue(enc({ embedding: [0.5] }));
|
||||
const { provider } = await createBedrockEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.titan-embed-text-v1",
|
||||
fallback: "none",
|
||||
});
|
||||
await provider.embedQuery("hi");
|
||||
expect(reqBody()).toEqual({ inputText: "hi" });
|
||||
});
|
||||
|
||||
it("handles Titan G1 text variant", async () => {
|
||||
sendMock.mockResolvedValue(enc({ embedding: [0.1] }));
|
||||
const { provider } = await createBedrockEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.titan-embed-g1-text-02",
|
||||
fallback: "none",
|
||||
});
|
||||
await provider.embedQuery("hi");
|
||||
expect(reqBody()).toEqual({ inputText: "hi" });
|
||||
});
|
||||
|
||||
// --- Cohere V3 ---
|
||||
|
||||
it("embeds Cohere V3 batch in single call", async () => {
|
||||
sendMock.mockResolvedValue(enc({ embeddings: [[0.1], [0.2]] }));
|
||||
const { provider } = await createBedrockEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "cohere.embed-english-v3",
|
||||
fallback: "none",
|
||||
});
|
||||
expect(await provider.embedBatch(["a", "b"])).toHaveLength(2);
|
||||
expect(sendMock).toHaveBeenCalledTimes(1);
|
||||
expect(reqBody()).toMatchObject({ texts: ["a", "b"], input_type: "search_document" });
|
||||
});
|
||||
|
||||
it("uses search_query for Cohere embedQuery", async () => {
|
||||
sendMock.mockResolvedValue(enc({ embeddings: [[0.1]] }));
|
||||
const { provider } = await createBedrockEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "cohere.embed-english-v3",
|
||||
fallback: "none",
|
||||
});
|
||||
await provider.embedQuery("q");
|
||||
expect(reqBody().input_type).toBe("search_query");
|
||||
});
|
||||
|
||||
// --- Cohere V4 ---
|
||||
|
||||
it("embeds Cohere V4 with embedding_types + output_dimension", async () => {
|
||||
sendMock.mockResolvedValue(enc({ embeddings: { float: [[0.1], [0.2]] } }));
|
||||
const { provider } = await createBedrockEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "cohere.embed-v4:0",
|
||||
fallback: "none",
|
||||
});
|
||||
expect(await provider.embedBatch(["a", "b"])).toHaveLength(2);
|
||||
expect(reqBody()).toMatchObject({ embedding_types: ["float"], output_dimension: 1536 });
|
||||
});
|
||||
|
||||
it("validates Cohere V4 dimensions", () => {
|
||||
expect(() =>
|
||||
resolveBedrockEmbeddingClient({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "cohere.embed-v4:0",
|
||||
fallback: "none",
|
||||
outputDimensionality: 2048,
|
||||
}),
|
||||
).toThrow("Invalid dimensions 2048");
|
||||
});
|
||||
|
||||
// --- Nova ---
|
||||
|
||||
it("embeds Nova with SINGLE_EMBEDDING format", async () => {
|
||||
sendMock.mockResolvedValue(
|
||||
enc({ embeddings: [{ embeddingType: "TEXT", embedding: [0.1, 0.2] }] }),
|
||||
);
|
||||
const { provider } = await createBedrockEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.nova-2-multimodal-embeddings-v1:0",
|
||||
fallback: "none",
|
||||
});
|
||||
expect(await provider.embedQuery("hi")).toHaveLength(2);
|
||||
expect(reqBody().taskType).toBe("SINGLE_EMBEDDING");
|
||||
});
|
||||
|
||||
it("validates Nova dimensions", () => {
|
||||
expect(() =>
|
||||
resolveBedrockEmbeddingClient({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.nova-2-multimodal-embeddings-v1:0",
|
||||
fallback: "none",
|
||||
outputDimensionality: 512,
|
||||
}),
|
||||
).toThrow("Invalid dimensions 512");
|
||||
});
|
||||
|
||||
it("batches Nova concurrently", async () => {
|
||||
sendMock
|
||||
.mockResolvedValueOnce(enc({ embeddings: [{ embeddingType: "TEXT", embedding: [0.1] }] }))
|
||||
.mockResolvedValueOnce(enc({ embeddings: [{ embeddingType: "TEXT", embedding: [0.2] }] }));
|
||||
const { provider } = await createBedrockEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.nova-2-multimodal-embeddings-v1:0",
|
||||
fallback: "none",
|
||||
});
|
||||
expect(await provider.embedBatch(["a", "b"])).toHaveLength(2);
|
||||
expect(sendMock).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
// --- TwelveLabs ---
|
||||
|
||||
it("embeds TwelveLabs Marengo", async () => {
|
||||
sendMock.mockResolvedValue(enc({ data: [{ embedding: [0.1, 0.2] }] }));
|
||||
const { provider } = await createBedrockEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "twelvelabs.marengo-embed-3-0-v1:0",
|
||||
fallback: "none",
|
||||
});
|
||||
expect(await provider.embedQuery("hi")).toHaveLength(2);
|
||||
expect(reqBody()).toEqual({ inputType: "text", text: { inputText: "hi" } });
|
||||
});
|
||||
|
||||
it("embeds TwelveLabs object-style responses", async () => {
|
||||
sendMock.mockResolvedValue(enc({ data: { embedding: [0.3, 0.4] } }));
|
||||
const { provider } = await createBedrockEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "twelvelabs.marengo-embed-2-7-v1:0",
|
||||
fallback: "none",
|
||||
});
|
||||
expect(await provider.embedQuery("hi")).toEqual([0.6, 0.8]);
|
||||
});
|
||||
});
|
||||
@@ -1,10 +1,7 @@
|
||||
import {
|
||||
debugEmbeddingsLog,
|
||||
sanitizeAndNormalizeEmbedding,
|
||||
type MemoryEmbeddingProvider,
|
||||
type MemoryEmbeddingProviderCreateOptions,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import { normalizeLowercaseStringOrEmpty } from "openclaw/plugin-sdk/text-runtime";
|
||||
import { normalizeLowercaseStringOrEmpty } from "../../../../src/shared/string-coerce.js";
|
||||
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
|
||||
import { debugEmbeddingsLog } from "./embeddings-debug.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types & constants
|
||||
@@ -257,8 +254,8 @@ function parseCohereBatch(family: Family, raw: string): number[][] {
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export async function createBedrockEmbeddingProvider(
|
||||
options: MemoryEmbeddingProviderCreateOptions,
|
||||
): Promise<{ provider: MemoryEmbeddingProvider; client: BedrockEmbeddingClient }> {
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<{ provider: EmbeddingProvider; client: BedrockEmbeddingClient }> {
|
||||
const client = resolveBedrockEmbeddingClient(options);
|
||||
const { BedrockRuntimeClient, InvokeModelCommand } = await loadSdk();
|
||||
const sdk = new BedrockRuntimeClient({ region: client.region });
|
||||
@@ -336,7 +333,7 @@ export async function createBedrockEmbeddingProvider(
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export function resolveBedrockEmbeddingClient(
|
||||
options: MemoryEmbeddingProviderCreateOptions,
|
||||
options: EmbeddingProviderOptions,
|
||||
): BedrockEmbeddingClient {
|
||||
const model = normalizeBedrockEmbeddingModel(options.model);
|
||||
const spec = resolveSpec(model);
|
||||
121
packages/memory-host-sdk/src/host/embeddings-gemini-request.ts
Normal file
121
packages/memory-host-sdk/src/host/embeddings-gemini-request.ts
Normal file
@@ -0,0 +1,121 @@
|
||||
import type { EmbeddingInput } from "./embedding-inputs.js";
|
||||
|
||||
export const DEFAULT_GEMINI_EMBEDDING_MODEL = "gemini-embedding-001";
|
||||
|
||||
export const GEMINI_EMBEDDING_2_MODELS = new Set([
|
||||
"gemini-embedding-2-preview",
|
||||
// Add the GA model name here once released.
|
||||
]);
|
||||
|
||||
const GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS = 3072;
|
||||
const GEMINI_EMBEDDING_2_VALID_DIMENSIONS = [768, 1536, 3072] as const;
|
||||
|
||||
export type GeminiTaskType =
|
||||
| "RETRIEVAL_QUERY"
|
||||
| "RETRIEVAL_DOCUMENT"
|
||||
| "SEMANTIC_SIMILARITY"
|
||||
| "CLASSIFICATION"
|
||||
| "CLUSTERING"
|
||||
| "QUESTION_ANSWERING"
|
||||
| "FACT_VERIFICATION";
|
||||
|
||||
export type GeminiTextPart = { text: string };
|
||||
export type GeminiInlinePart = {
|
||||
inlineData: { mimeType: string; data: string };
|
||||
};
|
||||
export type GeminiPart = GeminiTextPart | GeminiInlinePart;
|
||||
export type GeminiEmbeddingRequest = {
|
||||
content: { parts: GeminiPart[] };
|
||||
taskType: GeminiTaskType;
|
||||
outputDimensionality?: number;
|
||||
model?: string;
|
||||
};
|
||||
export type GeminiTextEmbeddingRequest = GeminiEmbeddingRequest;
|
||||
|
||||
/** Builds the text-only Gemini embedding request shape used across direct and batch APIs. */
|
||||
export function buildGeminiTextEmbeddingRequest(params: {
|
||||
text: string;
|
||||
taskType: GeminiTaskType;
|
||||
outputDimensionality?: number;
|
||||
modelPath?: string;
|
||||
}): GeminiTextEmbeddingRequest {
|
||||
return buildGeminiEmbeddingRequest({
|
||||
input: { text: params.text },
|
||||
taskType: params.taskType,
|
||||
outputDimensionality: params.outputDimensionality,
|
||||
modelPath: params.modelPath,
|
||||
});
|
||||
}
|
||||
|
||||
export function buildGeminiEmbeddingRequest(params: {
|
||||
input: EmbeddingInput;
|
||||
taskType: GeminiTaskType;
|
||||
outputDimensionality?: number;
|
||||
modelPath?: string;
|
||||
}): GeminiEmbeddingRequest {
|
||||
const request: GeminiEmbeddingRequest = {
|
||||
content: {
|
||||
parts: params.input.parts?.map((part) =>
|
||||
part.type === "text"
|
||||
? ({ text: part.text } satisfies GeminiTextPart)
|
||||
: ({
|
||||
inlineData: { mimeType: part.mimeType, data: part.data },
|
||||
} satisfies GeminiInlinePart),
|
||||
) ?? [{ text: params.input.text }],
|
||||
},
|
||||
taskType: params.taskType,
|
||||
};
|
||||
if (params.modelPath) {
|
||||
request.model = params.modelPath;
|
||||
}
|
||||
if (params.outputDimensionality != null) {
|
||||
request.outputDimensionality = params.outputDimensionality;
|
||||
}
|
||||
return request;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if the given model name is a gemini-embedding-2 variant that
|
||||
* supports `outputDimensionality` and extended task types.
|
||||
*/
|
||||
export function isGeminiEmbedding2Model(model: string): boolean {
|
||||
return GEMINI_EMBEDDING_2_MODELS.has(model);
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate and return the `outputDimensionality` for gemini-embedding-2 models.
|
||||
* Returns `undefined` for older models (they don't support the param).
|
||||
*/
|
||||
export function resolveGeminiOutputDimensionality(
|
||||
model: string,
|
||||
requested?: number,
|
||||
): number | undefined {
|
||||
if (!isGeminiEmbedding2Model(model)) {
|
||||
return undefined;
|
||||
}
|
||||
if (requested == null) {
|
||||
return GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS;
|
||||
}
|
||||
const valid: readonly number[] = GEMINI_EMBEDDING_2_VALID_DIMENSIONS;
|
||||
if (!valid.includes(requested)) {
|
||||
throw new Error(
|
||||
`Invalid outputDimensionality ${requested} for ${model}. Valid values: ${valid.join(", ")}`,
|
||||
);
|
||||
}
|
||||
return requested;
|
||||
}
|
||||
|
||||
export function normalizeGeminiModel(model: string): string {
|
||||
const trimmed = model.trim();
|
||||
if (!trimmed) {
|
||||
return DEFAULT_GEMINI_EMBEDDING_MODEL;
|
||||
}
|
||||
const withoutPrefix = trimmed.replace(/^models\//, "");
|
||||
if (withoutPrefix.startsWith("gemini/")) {
|
||||
return withoutPrefix.slice("gemini/".length);
|
||||
}
|
||||
if (withoutPrefix.startsWith("google/")) {
|
||||
return withoutPrefix.slice("google/".length);
|
||||
}
|
||||
return withoutPrefix;
|
||||
}
|
||||
52
packages/memory-host-sdk/src/host/embeddings-gemini.test.ts
Normal file
52
packages/memory-host-sdk/src/host/embeddings-gemini.test.ts
Normal file
@@ -0,0 +1,52 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import {
|
||||
buildGeminiEmbeddingRequest,
|
||||
DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
normalizeGeminiModel,
|
||||
resolveGeminiOutputDimensionality,
|
||||
} from "./embeddings-gemini-request.js";
|
||||
|
||||
describe("package Gemini embedding request helpers", () => {
|
||||
it("builds multimodal v2 requests and resolves model settings", () => {
|
||||
expect(
|
||||
buildGeminiEmbeddingRequest({
|
||||
input: {
|
||||
text: "Image file: diagram.png",
|
||||
parts: [
|
||||
{ type: "text", text: "Image file: diagram.png" },
|
||||
{ type: "inline-data", mimeType: "image/png", data: "abc123" },
|
||||
],
|
||||
},
|
||||
taskType: "RETRIEVAL_DOCUMENT",
|
||||
modelPath: "models/gemini-embedding-2-preview",
|
||||
outputDimensionality: 1536,
|
||||
}),
|
||||
).toEqual({
|
||||
model: "models/gemini-embedding-2-preview",
|
||||
content: {
|
||||
parts: [
|
||||
{ text: "Image file: diagram.png" },
|
||||
{ inlineData: { mimeType: "image/png", data: "abc123" } },
|
||||
],
|
||||
},
|
||||
taskType: "RETRIEVAL_DOCUMENT",
|
||||
outputDimensionality: 1536,
|
||||
});
|
||||
expect(resolveGeminiOutputDimensionality("gemini-embedding-001")).toBeUndefined();
|
||||
expect(resolveGeminiOutputDimensionality("gemini-embedding-2-preview")).toBe(3072);
|
||||
expect(resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 768)).toBe(768);
|
||||
expect(() => resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 512)).toThrow(
|
||||
/Invalid outputDimensionality 512/,
|
||||
);
|
||||
expect(normalizeGeminiModel("models/gemini-embedding-2-preview")).toBe(
|
||||
"gemini-embedding-2-preview",
|
||||
);
|
||||
expect(normalizeGeminiModel("gemini/gemini-embedding-2-preview")).toBe(
|
||||
"gemini-embedding-2-preview",
|
||||
);
|
||||
expect(normalizeGeminiModel("google/gemini-embedding-2-preview")).toBe(
|
||||
"gemini-embedding-2-preview",
|
||||
);
|
||||
expect(normalizeGeminiModel("")).toBe(DEFAULT_GEMINI_EMBEDDING_MODEL);
|
||||
});
|
||||
});
|
||||
238
packages/memory-host-sdk/src/host/embeddings-gemini.ts
Normal file
238
packages/memory-host-sdk/src/host/embeddings-gemini.ts
Normal file
@@ -0,0 +1,238 @@
|
||||
import {
|
||||
collectProviderApiKeysForExecution,
|
||||
executeWithApiKeyRotation,
|
||||
} from "../../../../src/agents/api-key-rotation.js";
|
||||
import { requireApiKey, resolveApiKeyForProvider } from "../../../../src/agents/model-auth.js";
|
||||
import { parseGeminiAuth } from "../../../../src/infra/gemini-auth.js";
|
||||
import {
|
||||
DEFAULT_GOOGLE_API_BASE_URL,
|
||||
normalizeGoogleApiBaseUrl,
|
||||
} from "../../../../src/infra/google-api-base-url.js";
|
||||
import type { SsrFPolicy } from "../../../../src/infra/net/ssrf.js";
|
||||
import type { EmbeddingInput } from "./embedding-inputs.js";
|
||||
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
|
||||
import { debugEmbeddingsLog } from "./embeddings-debug.js";
|
||||
import {
|
||||
buildGeminiEmbeddingRequest,
|
||||
buildGeminiTextEmbeddingRequest,
|
||||
isGeminiEmbedding2Model,
|
||||
normalizeGeminiModel,
|
||||
resolveGeminiOutputDimensionality,
|
||||
} from "./embeddings-gemini-request.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
||||
import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js";
|
||||
import { resolveMemorySecretInputString } from "./secret-input.js";
|
||||
|
||||
export {
|
||||
buildGeminiEmbeddingRequest,
|
||||
buildGeminiTextEmbeddingRequest,
|
||||
DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
GEMINI_EMBEDDING_2_MODELS,
|
||||
isGeminiEmbedding2Model,
|
||||
normalizeGeminiModel,
|
||||
resolveGeminiOutputDimensionality,
|
||||
type GeminiEmbeddingRequest,
|
||||
type GeminiInlinePart,
|
||||
type GeminiPart,
|
||||
type GeminiTaskType,
|
||||
type GeminiTextEmbeddingRequest,
|
||||
type GeminiTextPart,
|
||||
} from "./embeddings-gemini-request.js";
|
||||
|
||||
export type GeminiEmbeddingClient = {
|
||||
baseUrl: string;
|
||||
headers: Record<string, string>;
|
||||
ssrfPolicy?: SsrFPolicy;
|
||||
model: string;
|
||||
modelPath: string;
|
||||
apiKeys: string[];
|
||||
outputDimensionality?: number;
|
||||
};
|
||||
|
||||
const GEMINI_MAX_INPUT_TOKENS: Record<string, number> = {
|
||||
"text-embedding-004": 2048,
|
||||
};
|
||||
function resolveRemoteApiKey(remoteApiKey: unknown): string | undefined {
|
||||
const trimmed = resolveMemorySecretInputString({
|
||||
value: remoteApiKey,
|
||||
path: "agents.*.memorySearch.remote.apiKey",
|
||||
});
|
||||
if (!trimmed) {
|
||||
return undefined;
|
||||
}
|
||||
if (trimmed === "GOOGLE_API_KEY" || trimmed === "GEMINI_API_KEY") {
|
||||
return process.env[trimmed]?.trim();
|
||||
}
|
||||
return trimmed;
|
||||
}
|
||||
|
||||
async function fetchGeminiEmbeddingPayload(params: {
|
||||
client: GeminiEmbeddingClient;
|
||||
endpoint: string;
|
||||
body: unknown;
|
||||
}): Promise<{
|
||||
embedding?: { values?: number[] };
|
||||
embeddings?: Array<{ values?: number[] }>;
|
||||
}> {
|
||||
return await executeWithApiKeyRotation({
|
||||
provider: "google",
|
||||
apiKeys: params.client.apiKeys,
|
||||
execute: async (apiKey) => {
|
||||
const authHeaders = parseGeminiAuth(apiKey);
|
||||
const headers = {
|
||||
...authHeaders.headers,
|
||||
...params.client.headers,
|
||||
};
|
||||
return await withRemoteHttpResponse({
|
||||
url: params.endpoint,
|
||||
ssrfPolicy: params.client.ssrfPolicy,
|
||||
init: {
|
||||
method: "POST",
|
||||
headers,
|
||||
body: JSON.stringify(params.body),
|
||||
},
|
||||
onResponse: async (res) => {
|
||||
if (!res.ok) {
|
||||
const text = await res.text();
|
||||
throw new Error(`gemini embeddings failed: ${res.status} ${text}`);
|
||||
}
|
||||
return (await res.json()) as {
|
||||
embedding?: { values?: number[] };
|
||||
embeddings?: Array<{ values?: number[] }>;
|
||||
};
|
||||
},
|
||||
});
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
function normalizeGeminiBaseUrl(raw: string): string {
|
||||
const trimmed = raw.replace(/\/+$/, "");
|
||||
const openAiIndex = trimmed.indexOf("/openai");
|
||||
if (openAiIndex > -1) {
|
||||
return normalizeGoogleApiBaseUrl(trimmed.slice(0, openAiIndex));
|
||||
}
|
||||
return normalizeGoogleApiBaseUrl(trimmed);
|
||||
}
|
||||
|
||||
function buildGeminiModelPath(model: string): string {
|
||||
return model.startsWith("models/") ? model : `models/${model}`;
|
||||
}
|
||||
|
||||
export async function createGeminiEmbeddingProvider(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<{ provider: EmbeddingProvider; client: GeminiEmbeddingClient }> {
|
||||
const client = await resolveGeminiEmbeddingClient(options);
|
||||
const baseUrl = client.baseUrl.replace(/\/$/, "");
|
||||
const embedUrl = `${baseUrl}/${client.modelPath}:embedContent`;
|
||||
const batchUrl = `${baseUrl}/${client.modelPath}:batchEmbedContents`;
|
||||
const isV2 = isGeminiEmbedding2Model(client.model);
|
||||
const outputDimensionality = client.outputDimensionality;
|
||||
|
||||
const embedQuery = async (text: string): Promise<number[]> => {
|
||||
if (!text.trim()) {
|
||||
return [];
|
||||
}
|
||||
const payload = await fetchGeminiEmbeddingPayload({
|
||||
client,
|
||||
endpoint: embedUrl,
|
||||
body: buildGeminiTextEmbeddingRequest({
|
||||
text,
|
||||
taskType: options.taskType ?? "RETRIEVAL_QUERY",
|
||||
outputDimensionality: isV2 ? outputDimensionality : undefined,
|
||||
}),
|
||||
});
|
||||
return sanitizeAndNormalizeEmbedding(payload.embedding?.values ?? []);
|
||||
};
|
||||
|
||||
const embedBatchInputs = async (inputs: EmbeddingInput[]): Promise<number[][]> => {
|
||||
if (inputs.length === 0) {
|
||||
return [];
|
||||
}
|
||||
const payload = await fetchGeminiEmbeddingPayload({
|
||||
client,
|
||||
endpoint: batchUrl,
|
||||
body: {
|
||||
requests: inputs.map((input) =>
|
||||
buildGeminiEmbeddingRequest({
|
||||
input,
|
||||
modelPath: client.modelPath,
|
||||
taskType: options.taskType ?? "RETRIEVAL_DOCUMENT",
|
||||
outputDimensionality: isV2 ? outputDimensionality : undefined,
|
||||
}),
|
||||
),
|
||||
},
|
||||
});
|
||||
const embeddings = Array.isArray(payload.embeddings) ? payload.embeddings : [];
|
||||
return inputs.map((_, index) => sanitizeAndNormalizeEmbedding(embeddings[index]?.values ?? []));
|
||||
};
|
||||
|
||||
const embedBatch = async (texts: string[]): Promise<number[][]> => {
|
||||
return await embedBatchInputs(
|
||||
texts.map((text) => ({
|
||||
text,
|
||||
})),
|
||||
);
|
||||
};
|
||||
|
||||
return {
|
||||
provider: {
|
||||
id: "gemini",
|
||||
model: client.model,
|
||||
maxInputTokens: GEMINI_MAX_INPUT_TOKENS[client.model],
|
||||
embedQuery,
|
||||
embedBatch,
|
||||
embedBatchInputs,
|
||||
},
|
||||
client,
|
||||
};
|
||||
}
|
||||
|
||||
export async function resolveGeminiEmbeddingClient(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<GeminiEmbeddingClient> {
|
||||
const remote = options.remote;
|
||||
const remoteApiKey = resolveRemoteApiKey(remote?.apiKey);
|
||||
const remoteBaseUrl = remote?.baseUrl?.trim();
|
||||
|
||||
const apiKey = remoteApiKey
|
||||
? remoteApiKey
|
||||
: requireApiKey(
|
||||
await resolveApiKeyForProvider({
|
||||
provider: "google",
|
||||
cfg: options.config,
|
||||
agentDir: options.agentDir,
|
||||
}),
|
||||
"google",
|
||||
);
|
||||
|
||||
const providerConfig = options.config.models?.providers?.google;
|
||||
const rawBaseUrl =
|
||||
remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_GOOGLE_API_BASE_URL;
|
||||
const baseUrl = normalizeGeminiBaseUrl(rawBaseUrl);
|
||||
const ssrfPolicy = buildRemoteBaseUrlPolicy(baseUrl);
|
||||
const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers);
|
||||
const headers: Record<string, string> = {
|
||||
...headerOverrides,
|
||||
};
|
||||
const apiKeys = collectProviderApiKeysForExecution({
|
||||
provider: "google",
|
||||
primaryApiKey: apiKey,
|
||||
});
|
||||
const model = normalizeGeminiModel(options.model);
|
||||
const modelPath = buildGeminiModelPath(model);
|
||||
const outputDimensionality = resolveGeminiOutputDimensionality(
|
||||
model,
|
||||
options.outputDimensionality,
|
||||
);
|
||||
debugEmbeddingsLog("memory embeddings: gemini client", {
|
||||
rawBaseUrl,
|
||||
baseUrl,
|
||||
model,
|
||||
modelPath,
|
||||
outputDimensionality,
|
||||
embedEndpoint: `${baseUrl}/${modelPath}:embedContent`,
|
||||
batchEndpoint: `${baseUrl}/${modelPath}:batchEmbedContents`,
|
||||
});
|
||||
return { baseUrl, headers, ssrfPolicy, model, modelPath, apiKeys, outputDimensionality };
|
||||
}
|
||||
1
packages/memory-host-sdk/src/host/embeddings-lmstudio.ts
Normal file
1
packages/memory-host-sdk/src/host/embeddings-lmstudio.ts
Normal file
@@ -0,0 +1 @@
|
||||
export * from "../../../../src/memory-host-sdk/host/embeddings-lmstudio.js";
|
||||
19
packages/memory-host-sdk/src/host/embeddings-mistral.test.ts
Normal file
19
packages/memory-host-sdk/src/host/embeddings-mistral.test.ts
Normal file
@@ -0,0 +1,19 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { DEFAULT_MISTRAL_EMBEDDING_MODEL, normalizeMistralModel } from "./embeddings-mistral.js";
|
||||
|
||||
describe("normalizeMistralModel", () => {
|
||||
it("returns the default model for empty values", () => {
|
||||
expect(normalizeMistralModel("")).toBe(DEFAULT_MISTRAL_EMBEDDING_MODEL);
|
||||
expect(normalizeMistralModel(" ")).toBe(DEFAULT_MISTRAL_EMBEDDING_MODEL);
|
||||
});
|
||||
|
||||
it("strips the mistral/ prefix", () => {
|
||||
expect(normalizeMistralModel("mistral/mistral-embed")).toBe("mistral-embed");
|
||||
expect(normalizeMistralModel(" mistral/custom-embed ")).toBe("custom-embed");
|
||||
});
|
||||
|
||||
it("keeps explicit non-prefixed models", () => {
|
||||
expect(normalizeMistralModel("mistral-embed")).toBe("mistral-embed");
|
||||
expect(normalizeMistralModel("custom-embed-v2")).toBe("custom-embed-v2");
|
||||
});
|
||||
});
|
||||
51
packages/memory-host-sdk/src/host/embeddings-mistral.ts
Normal file
51
packages/memory-host-sdk/src/host/embeddings-mistral.ts
Normal file
@@ -0,0 +1,51 @@
|
||||
import type { SsrFPolicy } from "../../../../src/infra/net/ssrf.js";
|
||||
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
|
||||
import {
|
||||
createRemoteEmbeddingProvider,
|
||||
resolveRemoteEmbeddingClient,
|
||||
} from "./embeddings-remote-provider.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
||||
|
||||
export type MistralEmbeddingClient = {
|
||||
baseUrl: string;
|
||||
headers: Record<string, string>;
|
||||
ssrfPolicy?: SsrFPolicy;
|
||||
model: string;
|
||||
};
|
||||
|
||||
export const DEFAULT_MISTRAL_EMBEDDING_MODEL = "mistral-embed";
|
||||
const DEFAULT_MISTRAL_BASE_URL = "https://api.mistral.ai/v1";
|
||||
|
||||
export function normalizeMistralModel(model: string): string {
|
||||
return normalizeEmbeddingModelWithPrefixes({
|
||||
model,
|
||||
defaultModel: DEFAULT_MISTRAL_EMBEDDING_MODEL,
|
||||
prefixes: ["mistral/"],
|
||||
});
|
||||
}
|
||||
|
||||
export async function createMistralEmbeddingProvider(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<{ provider: EmbeddingProvider; client: MistralEmbeddingClient }> {
|
||||
const client = await resolveMistralEmbeddingClient(options);
|
||||
|
||||
return {
|
||||
provider: createRemoteEmbeddingProvider({
|
||||
id: "mistral",
|
||||
client,
|
||||
errorPrefix: "mistral embeddings failed",
|
||||
}),
|
||||
client,
|
||||
};
|
||||
}
|
||||
|
||||
export async function resolveMistralEmbeddingClient(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<MistralEmbeddingClient> {
|
||||
return await resolveRemoteEmbeddingClient({
|
||||
provider: "mistral",
|
||||
options,
|
||||
defaultBaseUrl: DEFAULT_MISTRAL_BASE_URL,
|
||||
normalizeModel: normalizeMistralModel,
|
||||
});
|
||||
}
|
||||
43
packages/memory-host-sdk/src/host/embeddings-ollama.test.ts
Normal file
43
packages/memory-host-sdk/src/host/embeddings-ollama.test.ts
Normal file
@@ -0,0 +1,43 @@
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
const { createOllamaEmbeddingProviderMock } = vi.hoisted(() => ({
|
||||
createOllamaEmbeddingProviderMock: vi.fn(async (options: unknown) => ({
|
||||
provider: { source: "mock-provider", options },
|
||||
client: { source: "mock-client" },
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock("../../../../src/plugin-sdk/ollama-runtime.js", () => ({
|
||||
DEFAULT_OLLAMA_EMBEDDING_MODEL: "nomic-embed-text",
|
||||
createOllamaEmbeddingProvider: createOllamaEmbeddingProviderMock,
|
||||
}));
|
||||
|
||||
describe("memory-host-sdk Ollama embedding facade", () => {
|
||||
beforeEach(() => {
|
||||
createOllamaEmbeddingProviderMock.mockClear();
|
||||
});
|
||||
|
||||
it("re-exports the default Ollama embedding model", async () => {
|
||||
const mod = await import("./embeddings-ollama.js");
|
||||
expect(mod.DEFAULT_OLLAMA_EMBEDDING_MODEL).toBe("nomic-embed-text");
|
||||
});
|
||||
|
||||
it("delegates provider creation to the plugin-sdk runtime facade", async () => {
|
||||
const mod = await import("./embeddings-ollama.js");
|
||||
const options = {
|
||||
provider: "ollama",
|
||||
model: "nomic-embed-text",
|
||||
fallback: "none",
|
||||
config: {},
|
||||
};
|
||||
|
||||
const result = await mod.createOllamaEmbeddingProvider(options as never);
|
||||
|
||||
expect(createOllamaEmbeddingProviderMock).toHaveBeenCalledTimes(1);
|
||||
expect(createOllamaEmbeddingProviderMock).toHaveBeenCalledWith(options);
|
||||
expect(result).toEqual({
|
||||
provider: { source: "mock-provider", options },
|
||||
client: { source: "mock-client" },
|
||||
});
|
||||
});
|
||||
});
|
||||
5
packages/memory-host-sdk/src/host/embeddings-ollama.ts
Normal file
5
packages/memory-host-sdk/src/host/embeddings-ollama.ts
Normal file
@@ -0,0 +1,5 @@
|
||||
export type { OllamaEmbeddingClient } from "../../../../src/plugin-sdk/ollama-runtime.js";
|
||||
export {
|
||||
createOllamaEmbeddingProvider,
|
||||
DEFAULT_OLLAMA_EMBEDDING_MODEL,
|
||||
} from "../../../../src/plugin-sdk/ollama-runtime.js";
|
||||
58
packages/memory-host-sdk/src/host/embeddings-openai.ts
Normal file
58
packages/memory-host-sdk/src/host/embeddings-openai.ts
Normal file
@@ -0,0 +1,58 @@
|
||||
import type { SsrFPolicy } from "../../../../src/infra/net/ssrf.js";
|
||||
import { OPENAI_DEFAULT_EMBEDDING_MODEL } from "../../../../src/plugins/provider-model-defaults.js";
|
||||
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
|
||||
import {
|
||||
createRemoteEmbeddingProvider,
|
||||
resolveRemoteEmbeddingClient,
|
||||
} from "./embeddings-remote-provider.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
||||
|
||||
export type OpenAiEmbeddingClient = {
|
||||
baseUrl: string;
|
||||
headers: Record<string, string>;
|
||||
ssrfPolicy?: SsrFPolicy;
|
||||
model: string;
|
||||
};
|
||||
|
||||
const DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1";
|
||||
export const DEFAULT_OPENAI_EMBEDDING_MODEL = OPENAI_DEFAULT_EMBEDDING_MODEL;
|
||||
const OPENAI_MAX_INPUT_TOKENS: Record<string, number> = {
|
||||
"text-embedding-3-small": 8192,
|
||||
"text-embedding-3-large": 8192,
|
||||
"text-embedding-ada-002": 8191,
|
||||
};
|
||||
|
||||
export function normalizeOpenAiModel(model: string): string {
|
||||
return normalizeEmbeddingModelWithPrefixes({
|
||||
model,
|
||||
defaultModel: DEFAULT_OPENAI_EMBEDDING_MODEL,
|
||||
prefixes: ["openai/"],
|
||||
});
|
||||
}
|
||||
|
||||
export async function createOpenAiEmbeddingProvider(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<{ provider: EmbeddingProvider; client: OpenAiEmbeddingClient }> {
|
||||
const client = await resolveOpenAiEmbeddingClient(options);
|
||||
|
||||
return {
|
||||
provider: createRemoteEmbeddingProvider({
|
||||
id: "openai",
|
||||
client,
|
||||
errorPrefix: "openai embeddings failed",
|
||||
maxInputTokens: OPENAI_MAX_INPUT_TOKENS[client.model],
|
||||
}),
|
||||
client,
|
||||
};
|
||||
}
|
||||
|
||||
export async function resolveOpenAiEmbeddingClient(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<OpenAiEmbeddingClient> {
|
||||
return await resolveRemoteEmbeddingClient({
|
||||
provider: "openai",
|
||||
options,
|
||||
defaultBaseUrl: DEFAULT_OPENAI_BASE_URL,
|
||||
normalizeModel: normalizeOpenAiModel,
|
||||
});
|
||||
}
|
||||
@@ -4,7 +4,7 @@ import type { EmbeddingProviderOptions } from "./embeddings.js";
|
||||
import { buildRemoteBaseUrlPolicy } from "./remote-http.js";
|
||||
import { resolveMemorySecretInputString } from "./secret-input.js";
|
||||
|
||||
export type RemoteEmbeddingProviderId = string;
|
||||
export type RemoteEmbeddingProviderId = "openai" | "voyage" | "mistral";
|
||||
|
||||
export async function resolveRemoteEmbeddingBearerClient(params: {
|
||||
provider: RemoteEmbeddingProviderId;
|
||||
|
||||
188
packages/memory-host-sdk/src/host/embeddings-voyage.test.ts
Normal file
188
packages/memory-host-sdk/src/host/embeddings-voyage.test.ts
Normal file
@@ -0,0 +1,188 @@
|
||||
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import * as authModule from "../../../../src/agents/model-auth.js";
|
||||
import { type FetchMock, withFetchPreconnect } from "../../../../src/test-utils/fetch-mock.js";
|
||||
import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js";
|
||||
|
||||
vi.mock("../../../../src/infra/net/fetch-guard.js", () => ({
|
||||
fetchWithSsrFGuard: async (params: {
|
||||
url: string;
|
||||
init?: RequestInit;
|
||||
fetchImpl?: typeof fetch;
|
||||
}) => {
|
||||
const fetchImpl = params.fetchImpl ?? globalThis.fetch;
|
||||
if (!fetchImpl) {
|
||||
throw new Error("fetch is not available");
|
||||
}
|
||||
const response = await fetchImpl(params.url, params.init);
|
||||
return {
|
||||
response,
|
||||
finalUrl: params.url,
|
||||
release: async () => {},
|
||||
};
|
||||
},
|
||||
}));
|
||||
|
||||
const { resolveApiKeyForProviderMock } = vi.hoisted(() => ({
|
||||
resolveApiKeyForProviderMock: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("../../../../src/agents/model-auth.js", () => {
|
||||
return {
|
||||
resolveApiKeyForProvider: resolveApiKeyForProviderMock,
|
||||
requireApiKey: (auth: { apiKey?: string; mode?: string }, provider: string) => {
|
||||
if (auth.apiKey) {
|
||||
return auth.apiKey;
|
||||
}
|
||||
throw new Error(`No API key resolved for provider "${provider}" (auth mode: ${auth.mode}).`);
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
const createFetchMock = () => {
|
||||
const fetchMock = vi.fn<FetchMock>(
|
||||
async (_input: RequestInfo | URL, _init?: RequestInit) =>
|
||||
new Response(JSON.stringify({ data: [{ embedding: [0.1, 0.2, 0.3] }] }), {
|
||||
status: 200,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
}),
|
||||
);
|
||||
return withFetchPreconnect(fetchMock);
|
||||
};
|
||||
|
||||
function installFetchMock(fetchMock: typeof globalThis.fetch) {
|
||||
vi.stubGlobal("fetch", fetchMock);
|
||||
}
|
||||
|
||||
let createVoyageEmbeddingProvider: typeof import("./embeddings-voyage.js").createVoyageEmbeddingProvider;
|
||||
let normalizeVoyageModel: typeof import("./embeddings-voyage.js").normalizeVoyageModel;
|
||||
|
||||
beforeAll(async () => {
|
||||
({ createVoyageEmbeddingProvider, normalizeVoyageModel } =
|
||||
await import("./embeddings-voyage.js"));
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
vi.useRealTimers();
|
||||
vi.doUnmock("undici");
|
||||
});
|
||||
|
||||
function mockVoyageApiKey() {
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({
|
||||
apiKey: "voyage-key-123",
|
||||
mode: "api-key",
|
||||
source: "test",
|
||||
});
|
||||
}
|
||||
|
||||
async function createDefaultVoyageProvider(
|
||||
model: string,
|
||||
fetchMock: ReturnType<typeof createFetchMock>,
|
||||
) {
|
||||
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
|
||||
mockPublicPinnedHostname();
|
||||
mockVoyageApiKey();
|
||||
return createVoyageEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "voyage",
|
||||
model,
|
||||
fallback: "none",
|
||||
});
|
||||
}
|
||||
|
||||
describe("voyage embedding provider", () => {
|
||||
afterEach(() => {
|
||||
vi.doUnmock("undici");
|
||||
vi.resetAllMocks();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
it("configures client with correct defaults and headers", async () => {
|
||||
const fetchMock = createFetchMock();
|
||||
const result = await createDefaultVoyageProvider("voyage-4-large", fetchMock);
|
||||
|
||||
await result.provider.embedQuery("test query");
|
||||
|
||||
expect(authModule.resolveApiKeyForProvider).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ provider: "voyage" }),
|
||||
);
|
||||
|
||||
const call = fetchMock.mock.calls[0];
|
||||
expect(call).toBeDefined();
|
||||
const [url, init] = call as [RequestInfo | URL, RequestInit | undefined];
|
||||
expect(url).toBe("https://api.voyageai.com/v1/embeddings");
|
||||
|
||||
const headers = (init?.headers ?? {}) as Record<string, string>;
|
||||
expect(headers.Authorization).toBe("Bearer voyage-key-123");
|
||||
expect(headers["Content-Type"]).toBe("application/json");
|
||||
|
||||
const body = JSON.parse(init?.body as string);
|
||||
expect(body).toEqual({
|
||||
model: "voyage-4-large",
|
||||
input: ["test query"],
|
||||
input_type: "query",
|
||||
});
|
||||
});
|
||||
|
||||
it("respects remote overrides for baseUrl and apiKey", async () => {
|
||||
const fetchMock = createFetchMock();
|
||||
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
|
||||
mockPublicPinnedHostname();
|
||||
|
||||
const result = await createVoyageEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "voyage",
|
||||
model: "voyage-4-lite",
|
||||
fallback: "none",
|
||||
remote: {
|
||||
baseUrl: "https://example.com",
|
||||
apiKey: "remote-override-key",
|
||||
headers: { "X-Custom": "123" },
|
||||
},
|
||||
});
|
||||
|
||||
await result.provider.embedQuery("test");
|
||||
|
||||
const call = fetchMock.mock.calls[0];
|
||||
expect(call).toBeDefined();
|
||||
const [url, init] = call as [RequestInfo | URL, RequestInit | undefined];
|
||||
expect(url).toBe("https://example.com/embeddings");
|
||||
|
||||
const headers = (init?.headers ?? {}) as Record<string, string>;
|
||||
expect(headers.Authorization).toBe("Bearer remote-override-key");
|
||||
expect(headers["X-Custom"]).toBe("123");
|
||||
});
|
||||
|
||||
it("passes input_type=document for embedBatch", async () => {
|
||||
const fetchMock = withFetchPreconnect(
|
||||
vi.fn<FetchMock>(
|
||||
async (_input: RequestInfo | URL, _init?: RequestInit) =>
|
||||
new Response(
|
||||
JSON.stringify({
|
||||
data: [{ embedding: [0.1, 0.2] }, { embedding: [0.3, 0.4] }],
|
||||
}),
|
||||
{ status: 200, headers: { "Content-Type": "application/json" } },
|
||||
),
|
||||
),
|
||||
);
|
||||
const result = await createDefaultVoyageProvider("voyage-4-large", fetchMock);
|
||||
|
||||
await result.provider.embedBatch(["doc1", "doc2"]);
|
||||
|
||||
const call = fetchMock.mock.calls[0];
|
||||
expect(call).toBeDefined();
|
||||
const [, init] = call as [RequestInfo | URL, RequestInit | undefined];
|
||||
const body = JSON.parse(init?.body as string);
|
||||
expect(body).toEqual({
|
||||
model: "voyage-4-large",
|
||||
input: ["doc1", "doc2"],
|
||||
input_type: "document",
|
||||
});
|
||||
});
|
||||
|
||||
it("normalizes model names", async () => {
|
||||
expect(normalizeVoyageModel("voyage/voyage-large-2")).toBe("voyage-large-2");
|
||||
expect(normalizeVoyageModel("voyage-4-large")).toBe("voyage-4-large");
|
||||
expect(normalizeVoyageModel(" voyage-lite ")).toBe("voyage-lite");
|
||||
expect(normalizeVoyageModel("")).toBe("voyage-4-large"); // Default
|
||||
});
|
||||
});
|
||||
82
packages/memory-host-sdk/src/host/embeddings-voyage.ts
Normal file
82
packages/memory-host-sdk/src/host/embeddings-voyage.ts
Normal file
@@ -0,0 +1,82 @@
|
||||
import type { SsrFPolicy } from "../../../../src/infra/net/ssrf.js";
|
||||
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
|
||||
import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js";
|
||||
import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
||||
|
||||
export type VoyageEmbeddingClient = {
|
||||
baseUrl: string;
|
||||
headers: Record<string, string>;
|
||||
ssrfPolicy?: SsrFPolicy;
|
||||
model: string;
|
||||
};
|
||||
|
||||
export const DEFAULT_VOYAGE_EMBEDDING_MODEL = "voyage-4-large";
|
||||
const DEFAULT_VOYAGE_BASE_URL = "https://api.voyageai.com/v1";
|
||||
const VOYAGE_MAX_INPUT_TOKENS: Record<string, number> = {
|
||||
"voyage-3": 32000,
|
||||
"voyage-3-lite": 16000,
|
||||
"voyage-code-3": 32000,
|
||||
};
|
||||
|
||||
export function normalizeVoyageModel(model: string): string {
|
||||
return normalizeEmbeddingModelWithPrefixes({
|
||||
model,
|
||||
defaultModel: DEFAULT_VOYAGE_EMBEDDING_MODEL,
|
||||
prefixes: ["voyage/"],
|
||||
});
|
||||
}
|
||||
|
||||
export async function createVoyageEmbeddingProvider(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<{ provider: EmbeddingProvider; client: VoyageEmbeddingClient }> {
|
||||
const client = await resolveVoyageEmbeddingClient(options);
|
||||
const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`;
|
||||
|
||||
const embed = async (input: string[], input_type?: "query" | "document"): Promise<number[][]> => {
|
||||
if (input.length === 0) {
|
||||
return [];
|
||||
}
|
||||
const body: { model: string; input: string[]; input_type?: "query" | "document" } = {
|
||||
model: client.model,
|
||||
input,
|
||||
};
|
||||
if (input_type) {
|
||||
body.input_type = input_type;
|
||||
}
|
||||
|
||||
return await fetchRemoteEmbeddingVectors({
|
||||
url,
|
||||
headers: client.headers,
|
||||
ssrfPolicy: client.ssrfPolicy,
|
||||
body,
|
||||
errorPrefix: "voyage embeddings failed",
|
||||
});
|
||||
};
|
||||
|
||||
return {
|
||||
provider: {
|
||||
id: "voyage",
|
||||
model: client.model,
|
||||
maxInputTokens: VOYAGE_MAX_INPUT_TOKENS[client.model],
|
||||
embedQuery: async (text) => {
|
||||
const [vec] = await embed([text], "query");
|
||||
return vec ?? [];
|
||||
},
|
||||
embedBatch: async (texts) => embed(texts, "document"),
|
||||
},
|
||||
client,
|
||||
};
|
||||
}
|
||||
|
||||
export async function resolveVoyageEmbeddingClient(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<VoyageEmbeddingClient> {
|
||||
const { baseUrl, headers, ssrfPolicy } = await resolveRemoteEmbeddingBearerClient({
|
||||
provider: "voyage",
|
||||
options,
|
||||
defaultBaseUrl: DEFAULT_VOYAGE_BASE_URL,
|
||||
});
|
||||
const model = normalizeVoyageModel(options.model);
|
||||
return { baseUrl, headers, ssrfPolicy, model };
|
||||
}
|
||||
@@ -1,8 +1,199 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { DEFAULT_LOCAL_MODEL } from "./embeddings.js";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import * as authModule from "../../../../src/agents/model-auth.js";
|
||||
import { createEmbeddingProvider, DEFAULT_LOCAL_MODEL } from "./embeddings.js";
|
||||
import * as nodeLlamaModule from "./node-llama.js";
|
||||
import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js";
|
||||
|
||||
describe("package embeddings barrel", () => {
|
||||
it("re-exports the source local embedding contract", () => {
|
||||
expect(DEFAULT_LOCAL_MODEL).toContain("embeddinggemma");
|
||||
const { resolveApiKeyForProviderMock } = vi.hoisted(() => ({
|
||||
resolveApiKeyForProviderMock: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("../../../../src/agents/model-auth.js", () => {
|
||||
return {
|
||||
resolveApiKeyForProvider: resolveApiKeyForProviderMock,
|
||||
requireApiKey: (auth: { apiKey?: string; mode?: string }, provider: string) => {
|
||||
if (auth.apiKey) {
|
||||
return auth.apiKey;
|
||||
}
|
||||
throw new Error(`No API key resolved for provider "${provider}" (auth mode: ${auth.mode}).`);
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock("../../../../src/infra/net/fetch-guard.js", () => ({
|
||||
fetchWithSsrFGuard: async (params: {
|
||||
url: string;
|
||||
init?: RequestInit;
|
||||
fetchImpl?: typeof fetch;
|
||||
}) => {
|
||||
const fetchImpl = params.fetchImpl ?? globalThis.fetch;
|
||||
if (!fetchImpl) {
|
||||
throw new Error("fetch is not available");
|
||||
}
|
||||
const response = await fetchImpl(params.url, params.init);
|
||||
return {
|
||||
response,
|
||||
finalUrl: params.url,
|
||||
release: async () => {},
|
||||
};
|
||||
},
|
||||
}));
|
||||
|
||||
const createEmbeddingDataFetchMock = () =>
|
||||
vi.fn(async (_input?: unknown, _init?: unknown) => ({
|
||||
ok: true,
|
||||
status: 200,
|
||||
json: async () => ({ data: [{ embedding: [1, 2, 3] }] }),
|
||||
}));
|
||||
|
||||
const createGeminiFetchMock = () =>
|
||||
vi.fn(async (_input?: unknown, _init?: unknown) => ({
|
||||
ok: true,
|
||||
status: 200,
|
||||
json: async () => ({ embedding: { values: [1, 2, 3] } }),
|
||||
}));
|
||||
|
||||
beforeEach(() => {
|
||||
vi.spyOn(authModule, "resolveApiKeyForProvider");
|
||||
vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp");
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.resetAllMocks();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
function installFetchMock(fetchMock: typeof globalThis.fetch) {
|
||||
vi.stubGlobal("fetch", fetchMock);
|
||||
}
|
||||
|
||||
function readFirstFetchRequest(fetchMock: { mock: { calls: unknown[][] } }) {
|
||||
const [url, init] = fetchMock.mock.calls[0] ?? [];
|
||||
return { url, init: init as RequestInit | undefined };
|
||||
}
|
||||
|
||||
function requireProvider(result: Awaited<ReturnType<typeof createEmbeddingProvider>>) {
|
||||
if (!result.provider) {
|
||||
throw new Error("Expected embedding provider");
|
||||
}
|
||||
return result.provider;
|
||||
}
|
||||
|
||||
function mockResolvedProviderKey(apiKey = "provider-key") {
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({
|
||||
apiKey,
|
||||
mode: "api-key",
|
||||
source: "test",
|
||||
});
|
||||
}
|
||||
|
||||
describe("package embedding provider smoke", () => {
|
||||
it("uses remote OpenAI baseUrl/apiKey and merges headers", async () => {
|
||||
const fetchMock = createEmbeddingDataFetchMock();
|
||||
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
|
||||
mockPublicPinnedHostname();
|
||||
mockResolvedProviderKey("provider-key");
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: {
|
||||
models: {
|
||||
providers: {
|
||||
openai: {
|
||||
baseUrl: "https://api.openai.com/v1",
|
||||
headers: { "X-Provider": "p", "X-Shared": "provider" },
|
||||
},
|
||||
},
|
||||
},
|
||||
} as never,
|
||||
provider: "openai",
|
||||
remote: {
|
||||
baseUrl: "https://example.com/v1",
|
||||
apiKey: " remote-key ",
|
||||
headers: { "X-Shared": "remote", "X-Remote": "r" },
|
||||
},
|
||||
model: "text-embedding-3-small",
|
||||
fallback: "openai",
|
||||
});
|
||||
|
||||
await requireProvider(result).embedQuery("hello");
|
||||
|
||||
expect(authModule.resolveApiKeyForProvider).not.toHaveBeenCalled();
|
||||
const { url, init } = readFirstFetchRequest(fetchMock);
|
||||
expect(url).toBe("https://example.com/v1/embeddings");
|
||||
const headers = (init?.headers ?? {}) as Record<string, string>;
|
||||
expect(headers.Authorization).toBe("Bearer remote-key");
|
||||
expect(headers["X-Provider"]).toBe("p");
|
||||
expect(headers["X-Shared"]).toBe("remote");
|
||||
expect(headers["X-Remote"]).toBe("r");
|
||||
});
|
||||
|
||||
it("uses GEMINI_API_KEY env indirection for Gemini remote apiKey", async () => {
|
||||
const fetchMock = createGeminiFetchMock();
|
||||
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
|
||||
mockPublicPinnedHostname();
|
||||
vi.stubEnv("GEMINI_API_KEY", "env-gemini-key");
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "gemini",
|
||||
remote: {
|
||||
apiKey: "GEMINI_API_KEY", // pragma: allowlist secret
|
||||
},
|
||||
model: "text-embedding-004",
|
||||
fallback: "openai",
|
||||
});
|
||||
|
||||
await requireProvider(result).embedQuery("hello");
|
||||
|
||||
const { init } = readFirstFetchRequest(fetchMock);
|
||||
const headers = (init?.headers ?? {}) as Record<string, string>;
|
||||
expect(headers["x-goog-api-key"]).toBe("env-gemini-key");
|
||||
});
|
||||
|
||||
it("normalizes local embeddings and resolves the default local model", async () => {
|
||||
const resolveModelFileMock = vi.fn(async () => "/fake/model.gguf");
|
||||
vi.mocked(nodeLlamaModule.importNodeLlamaCpp).mockResolvedValue({
|
||||
getLlama: async () => ({
|
||||
loadModel: vi.fn().mockResolvedValue({
|
||||
createEmbeddingContext: vi.fn().mockResolvedValue({
|
||||
getEmbeddingFor: vi.fn().mockResolvedValue({
|
||||
vector: new Float32Array([2.35, 3.45, 0.63, 4.3]),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
resolveModelFile: resolveModelFileMock,
|
||||
LlamaLogLevel: { error: 0 },
|
||||
} as never);
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "local",
|
||||
model: "",
|
||||
fallback: "none",
|
||||
});
|
||||
|
||||
const embedding = await requireProvider(result).embedQuery("test query");
|
||||
const magnitude = Math.sqrt(embedding.reduce((sum, value) => sum + value * value, 0));
|
||||
expect(magnitude).toBeCloseTo(1, 5);
|
||||
expect(resolveModelFileMock).toHaveBeenCalledWith(DEFAULT_LOCAL_MODEL, undefined);
|
||||
});
|
||||
|
||||
it("returns null provider when explicit primary and fallback auth paths fail", async () => {
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockRejectedValue(
|
||||
new Error("No API key found for provider"),
|
||||
);
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "openai",
|
||||
model: "text-embedding-3-small",
|
||||
fallback: "gemini",
|
||||
});
|
||||
|
||||
expect(result.provider).toBeNull();
|
||||
expect(result.requestedProvider).toBe("openai");
|
||||
expect(result.fallbackFrom).toBe("openai");
|
||||
expect(result.providerUnavailableReason).toContain("Fallback to gemini failed");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1 +1,373 @@
|
||||
export * from "../../../../src/memory-host-sdk/host/embeddings.js";
|
||||
import fsSync from "node:fs";
|
||||
import type { Llama, LlamaEmbeddingContext, LlamaModel } from "node-llama-cpp";
|
||||
import type { OpenClawConfig } from "../../../../src/config/config.js";
|
||||
import type { SecretInput } from "../../../../src/config/types.secrets.js";
|
||||
import { formatErrorMessage } from "../../../../src/infra/errors.js";
|
||||
import { resolveUserPath } from "../../../../src/utils.js";
|
||||
import type { EmbeddingInput } from "./embedding-inputs.js";
|
||||
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
|
||||
import {
|
||||
createBedrockEmbeddingProvider,
|
||||
hasAwsCredentials,
|
||||
type BedrockEmbeddingClient,
|
||||
} from "./embeddings-bedrock.js";
|
||||
import {
|
||||
createGeminiEmbeddingProvider,
|
||||
type GeminiEmbeddingClient,
|
||||
type GeminiTaskType,
|
||||
} from "./embeddings-gemini.js";
|
||||
import {
|
||||
createLmstudioEmbeddingProvider,
|
||||
type LmstudioEmbeddingClient,
|
||||
} from "./embeddings-lmstudio.js";
|
||||
import {
|
||||
createMistralEmbeddingProvider,
|
||||
type MistralEmbeddingClient,
|
||||
} from "./embeddings-mistral.js";
|
||||
import { createOllamaEmbeddingProvider, type OllamaEmbeddingClient } from "./embeddings-ollama.js";
|
||||
import { createOpenAiEmbeddingProvider, type OpenAiEmbeddingClient } from "./embeddings-openai.js";
|
||||
import { createVoyageEmbeddingProvider, type VoyageEmbeddingClient } from "./embeddings-voyage.js";
|
||||
import { importNodeLlamaCpp } from "./node-llama.js";
|
||||
|
||||
export type { GeminiEmbeddingClient } from "./embeddings-gemini.js";
|
||||
export type { LmstudioEmbeddingClient } from "./embeddings-lmstudio.js";
|
||||
export type { MistralEmbeddingClient } from "./embeddings-mistral.js";
|
||||
export type { OpenAiEmbeddingClient } from "./embeddings-openai.js";
|
||||
export type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
|
||||
export type { OllamaEmbeddingClient } from "./embeddings-ollama.js";
|
||||
export type { BedrockEmbeddingClient } from "./embeddings-bedrock.js";
|
||||
|
||||
export type EmbeddingProvider = {
|
||||
id: string;
|
||||
model: string;
|
||||
maxInputTokens?: number;
|
||||
embedQuery: (text: string) => Promise<number[]>;
|
||||
embedBatch: (texts: string[]) => Promise<number[][]>;
|
||||
embedBatchInputs?: (inputs: EmbeddingInput[]) => Promise<number[][]>;
|
||||
};
|
||||
|
||||
export type EmbeddingProviderId =
|
||||
| "openai"
|
||||
| "local"
|
||||
| "gemini"
|
||||
| "voyage"
|
||||
| "mistral"
|
||||
| "bedrock"
|
||||
| "lmstudio"
|
||||
| "ollama";
|
||||
export type EmbeddingProviderRequest = EmbeddingProviderId | "auto";
|
||||
export type EmbeddingProviderFallback = EmbeddingProviderId | "none";
|
||||
|
||||
// Remote providers considered for auto-selection when provider === "auto".
|
||||
// LM Studio and Ollama are intentionally excluded here so that "auto" mode does not
|
||||
// implicitly assume either instance is available.
|
||||
// Bedrock is handled separately when AWS credentials are detected.
|
||||
const REMOTE_EMBEDDING_PROVIDER_IDS = ["openai", "gemini", "voyage", "mistral"] as const;
|
||||
|
||||
export type EmbeddingProviderResult = {
|
||||
provider: EmbeddingProvider | null;
|
||||
requestedProvider: EmbeddingProviderRequest;
|
||||
fallbackFrom?: EmbeddingProviderId;
|
||||
fallbackReason?: string;
|
||||
providerUnavailableReason?: string;
|
||||
openAi?: OpenAiEmbeddingClient;
|
||||
gemini?: GeminiEmbeddingClient;
|
||||
voyage?: VoyageEmbeddingClient;
|
||||
mistral?: MistralEmbeddingClient;
|
||||
bedrock?: BedrockEmbeddingClient;
|
||||
lmstudio?: LmstudioEmbeddingClient;
|
||||
ollama?: OllamaEmbeddingClient;
|
||||
};
|
||||
|
||||
export type EmbeddingProviderOptions = {
|
||||
config: OpenClawConfig;
|
||||
agentDir?: string;
|
||||
provider: EmbeddingProviderRequest;
|
||||
remote?: {
|
||||
baseUrl?: string;
|
||||
apiKey?: SecretInput;
|
||||
headers?: Record<string, string>;
|
||||
};
|
||||
model: string;
|
||||
fallback: EmbeddingProviderFallback;
|
||||
local?: {
|
||||
modelPath?: string;
|
||||
modelCacheDir?: string;
|
||||
};
|
||||
/** Provider-specific output vector dimensions for supported embedding families. */
|
||||
outputDimensionality?: number;
|
||||
/** Gemini: override the default task type sent with embedding requests. */
|
||||
taskType?: GeminiTaskType;
|
||||
};
|
||||
|
||||
export const DEFAULT_LOCAL_MODEL =
|
||||
"hf:ggml-org/embeddinggemma-300m-qat-q8_0-GGUF/embeddinggemma-300m-qat-Q8_0.gguf";
|
||||
|
||||
function canAutoSelectLocal(options: EmbeddingProviderOptions): boolean {
|
||||
const modelPath = options.local?.modelPath?.trim();
|
||||
if (!modelPath) {
|
||||
return false;
|
||||
}
|
||||
if (/^(hf:|https?:)/i.test(modelPath)) {
|
||||
return false;
|
||||
}
|
||||
const resolved = resolveUserPath(modelPath);
|
||||
try {
|
||||
return fsSync.statSync(resolved).isFile();
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
function isMissingApiKeyError(err: unknown): boolean {
|
||||
const message = formatErrorMessage(err);
|
||||
return message.includes("No API key found for provider");
|
||||
}
|
||||
|
||||
export async function createLocalEmbeddingProvider(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<EmbeddingProvider> {
|
||||
const modelPath = options.local?.modelPath?.trim() || DEFAULT_LOCAL_MODEL;
|
||||
const modelCacheDir = options.local?.modelCacheDir?.trim();
|
||||
|
||||
// Lazy-load node-llama-cpp to keep startup light unless local is enabled.
|
||||
const { getLlama, resolveModelFile, LlamaLogLevel } = await importNodeLlamaCpp();
|
||||
|
||||
let llama: Llama | null = null;
|
||||
let embeddingModel: LlamaModel | null = null;
|
||||
let embeddingContext: LlamaEmbeddingContext | null = null;
|
||||
let initPromise: Promise<LlamaEmbeddingContext> | null = null;
|
||||
|
||||
const ensureContext = async (): Promise<LlamaEmbeddingContext> => {
|
||||
if (embeddingContext) {
|
||||
return embeddingContext;
|
||||
}
|
||||
if (initPromise) {
|
||||
return initPromise;
|
||||
}
|
||||
initPromise = (async () => {
|
||||
try {
|
||||
if (!llama) {
|
||||
llama = await getLlama({ logLevel: LlamaLogLevel.error });
|
||||
}
|
||||
if (!embeddingModel) {
|
||||
const resolved = await resolveModelFile(modelPath, modelCacheDir || undefined);
|
||||
embeddingModel = await llama.loadModel({ modelPath: resolved });
|
||||
}
|
||||
if (!embeddingContext) {
|
||||
embeddingContext = await embeddingModel.createEmbeddingContext();
|
||||
}
|
||||
return embeddingContext;
|
||||
} catch (err) {
|
||||
initPromise = null;
|
||||
throw err;
|
||||
}
|
||||
})();
|
||||
return initPromise;
|
||||
};
|
||||
|
||||
return {
|
||||
id: "local",
|
||||
model: modelPath,
|
||||
embedQuery: async (text) => {
|
||||
const ctx = await ensureContext();
|
||||
const embedding = await ctx.getEmbeddingFor(text);
|
||||
return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
|
||||
},
|
||||
embedBatch: async (texts) => {
|
||||
const ctx = await ensureContext();
|
||||
const embeddings = await Promise.all(
|
||||
texts.map(async (text) => {
|
||||
const embedding = await ctx.getEmbeddingFor(text);
|
||||
return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
|
||||
}),
|
||||
);
|
||||
return embeddings;
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
export async function createEmbeddingProvider(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<EmbeddingProviderResult> {
|
||||
const requestedProvider = options.provider;
|
||||
const fallback = options.fallback;
|
||||
|
||||
const createProvider = async (id: EmbeddingProviderId) => {
|
||||
if (id === "local") {
|
||||
const provider = await createLocalEmbeddingProvider(options);
|
||||
return { provider };
|
||||
}
|
||||
if (id === "lmstudio") {
|
||||
const { provider, client } = await createLmstudioEmbeddingProvider(options);
|
||||
return { provider, lmstudio: client };
|
||||
}
|
||||
if (id === "ollama") {
|
||||
const { provider, client } = await createOllamaEmbeddingProvider(options);
|
||||
return { provider, ollama: client };
|
||||
}
|
||||
if (id === "gemini") {
|
||||
const { provider, client } = await createGeminiEmbeddingProvider(options);
|
||||
return { provider, gemini: client };
|
||||
}
|
||||
if (id === "voyage") {
|
||||
const { provider, client } = await createVoyageEmbeddingProvider(options);
|
||||
return { provider, voyage: client };
|
||||
}
|
||||
if (id === "mistral") {
|
||||
const { provider, client } = await createMistralEmbeddingProvider(options);
|
||||
return { provider, mistral: client };
|
||||
}
|
||||
if (id === "bedrock") {
|
||||
const { provider, client } = await createBedrockEmbeddingProvider(options);
|
||||
return { provider, bedrock: client };
|
||||
}
|
||||
const { provider, client } = await createOpenAiEmbeddingProvider(options);
|
||||
return { provider, openAi: client };
|
||||
};
|
||||
|
||||
const formatPrimaryError = (err: unknown, provider: EmbeddingProviderId) =>
|
||||
provider === "local" ? formatLocalSetupError(err) : formatErrorMessage(err);
|
||||
|
||||
if (requestedProvider === "auto") {
|
||||
const missingKeyErrors: string[] = [];
|
||||
let localError: string | null = null;
|
||||
|
||||
if (canAutoSelectLocal(options)) {
|
||||
try {
|
||||
const local = await createProvider("local");
|
||||
return { ...local, requestedProvider };
|
||||
} catch (err) {
|
||||
localError = formatLocalSetupError(err);
|
||||
}
|
||||
}
|
||||
|
||||
for (const provider of REMOTE_EMBEDDING_PROVIDER_IDS) {
|
||||
try {
|
||||
const result = await createProvider(provider);
|
||||
return { ...result, requestedProvider };
|
||||
} catch (err) {
|
||||
const message = formatPrimaryError(err, provider);
|
||||
if (isMissingApiKeyError(err)) {
|
||||
missingKeyErrors.push(message);
|
||||
continue;
|
||||
}
|
||||
// Non-auth errors (e.g., network) are still fatal
|
||||
const wrapped = new Error(message) as Error & { cause?: unknown };
|
||||
wrapped.cause = err;
|
||||
throw wrapped;
|
||||
}
|
||||
}
|
||||
|
||||
// Try bedrock if AWS credentials are available
|
||||
if (await hasAwsCredentials()) {
|
||||
try {
|
||||
const result = await createProvider("bedrock");
|
||||
return { ...result, requestedProvider };
|
||||
} catch (err) {
|
||||
const message = formatPrimaryError(err, "bedrock");
|
||||
if (isMissingApiKeyError(err)) {
|
||||
missingKeyErrors.push(message);
|
||||
} else {
|
||||
const wrapped = new Error(message) as Error & { cause?: unknown };
|
||||
wrapped.cause = err;
|
||||
throw wrapped;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// All providers failed due to missing API keys - return null provider for FTS-only mode
|
||||
const details = [...missingKeyErrors, localError].filter(Boolean) as string[];
|
||||
const reason = details.length > 0 ? details.join("\n\n") : "No embeddings provider available.";
|
||||
return {
|
||||
provider: null,
|
||||
requestedProvider,
|
||||
providerUnavailableReason: reason,
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const primary = await createProvider(requestedProvider);
|
||||
return { ...primary, requestedProvider };
|
||||
} catch (primaryErr) {
|
||||
const reason = formatPrimaryError(primaryErr, requestedProvider);
|
||||
if (fallback && fallback !== "none" && fallback !== requestedProvider) {
|
||||
try {
|
||||
const fallbackResult = await createProvider(fallback);
|
||||
return {
|
||||
...fallbackResult,
|
||||
requestedProvider,
|
||||
fallbackFrom: requestedProvider,
|
||||
fallbackReason: reason,
|
||||
};
|
||||
} catch (fallbackErr) {
|
||||
// Both primary and fallback failed - check if it's auth-related
|
||||
const fallbackReason = formatErrorMessage(fallbackErr);
|
||||
const combinedReason = `${reason}\n\nFallback to ${fallback} failed: ${fallbackReason}`;
|
||||
if (isMissingApiKeyError(primaryErr) && isMissingApiKeyError(fallbackErr)) {
|
||||
// Both failed due to missing API keys - return null for FTS-only mode
|
||||
return {
|
||||
provider: null,
|
||||
requestedProvider,
|
||||
fallbackFrom: requestedProvider,
|
||||
fallbackReason: reason,
|
||||
providerUnavailableReason: combinedReason,
|
||||
};
|
||||
}
|
||||
// Non-auth errors are still fatal
|
||||
const wrapped = new Error(combinedReason) as Error & {
|
||||
cause?: unknown;
|
||||
};
|
||||
wrapped.cause = fallbackErr;
|
||||
throw wrapped;
|
||||
}
|
||||
}
|
||||
// No fallback configured - check if we should degrade to FTS-only
|
||||
if (isMissingApiKeyError(primaryErr)) {
|
||||
return {
|
||||
provider: null,
|
||||
requestedProvider,
|
||||
providerUnavailableReason: reason,
|
||||
};
|
||||
}
|
||||
const wrapped = new Error(reason) as Error & { cause?: unknown };
|
||||
wrapped.cause = primaryErr;
|
||||
throw wrapped;
|
||||
}
|
||||
}
|
||||
|
||||
function isNodeLlamaCppMissing(err: unknown): boolean {
|
||||
if (!(err instanceof Error)) {
|
||||
return false;
|
||||
}
|
||||
const code = (err as Error & { code?: unknown }).code;
|
||||
if (code === "ERR_MODULE_NOT_FOUND") {
|
||||
return err.message.includes("node-llama-cpp");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
function formatLocalSetupError(err: unknown): string {
|
||||
const detail = formatErrorMessage(err);
|
||||
const missing = isNodeLlamaCppMissing(err);
|
||||
return [
|
||||
"Local embeddings unavailable.",
|
||||
missing
|
||||
? "Reason: optional dependency node-llama-cpp is missing (or failed to install)."
|
||||
: detail
|
||||
? `Reason: ${detail}`
|
||||
: undefined,
|
||||
missing && detail ? `Detail: ${detail}` : null,
|
||||
"To enable local embeddings:",
|
||||
"1) Use Node 24 (recommended for installs/updates; Node 22 LTS, currently 22.14+, remains supported)",
|
||||
missing
|
||||
? "2) Reinstall OpenClaw (this should install node-llama-cpp): npm i -g openclaw@latest"
|
||||
: null,
|
||||
"3) If you use pnpm: pnpm approve-builds (select node-llama-cpp), then pnpm rebuild node-llama-cpp",
|
||||
...REMOTE_EMBEDDING_PROVIDER_IDS.map(
|
||||
(provider) => `Or set agents.defaults.memorySearch.provider = "${provider}" (remote).`,
|
||||
),
|
||||
]
|
||||
.filter(Boolean)
|
||||
.join("\n");
|
||||
}
|
||||
|
||||
@@ -100,3 +100,21 @@ export function classifyMemoryMultimodalPath(
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
export function normalizeGeminiEmbeddingModelForMemory(model: string): string {
|
||||
const trimmed = model.trim();
|
||||
if (!trimmed) {
|
||||
return "";
|
||||
}
|
||||
return trimmed.replace(/^models\//, "").replace(/^(gemini|google)\//, "");
|
||||
}
|
||||
|
||||
export function supportsMemoryMultimodalEmbeddings(params: {
|
||||
provider: string;
|
||||
model: string;
|
||||
}): boolean {
|
||||
if (params.provider !== "gemini") {
|
||||
return false;
|
||||
}
|
||||
return normalizeGeminiEmbeddingModelForMemory(params.model) === "gemini-embedding-2-preview";
|
||||
}
|
||||
|
||||
15
pnpm-lock.yaml
generated
15
pnpm-lock.yaml
generated
@@ -362,12 +362,6 @@ importers:
|
||||
'@aws-sdk/client-bedrock':
|
||||
specifier: 3.1028.0
|
||||
version: 3.1028.0
|
||||
'@aws-sdk/client-bedrock-runtime':
|
||||
specifier: 3.1028.0
|
||||
version: 3.1028.0
|
||||
'@aws-sdk/credential-provider-node':
|
||||
specifier: 3.972.30
|
||||
version: 3.972.30
|
||||
devDependencies:
|
||||
'@openclaw/plugin-sdk':
|
||||
specifier: workspace:*
|
||||
@@ -995,9 +989,6 @@ importers:
|
||||
|
||||
extensions/qa-lab:
|
||||
dependencies:
|
||||
'@copilotkit/aimock':
|
||||
specifier: 1.13.0
|
||||
version: 1.13.0
|
||||
playwright-core:
|
||||
specifier: 1.59.1
|
||||
version: 1.59.1
|
||||
@@ -1234,12 +1225,6 @@ importers:
|
||||
specifier: workspace:*
|
||||
version: link:../../packages/plugin-sdk
|
||||
|
||||
extensions/voyage:
|
||||
devDependencies:
|
||||
'@openclaw/plugin-sdk':
|
||||
specifier: workspace:*
|
||||
version: link:../../packages/plugin-sdk
|
||||
|
||||
extensions/vydra:
|
||||
devDependencies:
|
||||
'@openclaw/plugin-sdk':
|
||||
|
||||
@@ -1,486 +0,0 @@
|
||||
import path from "node:path";
|
||||
import { collectVitestFileDurations, normalizeTrackedRepoPath } from "../test-report-utils.mjs";
|
||||
import { formatMs } from "./vitest-report-cli-utils.mjs";
|
||||
|
||||
export function formatBytesAsMb(valueBytes) {
|
||||
return valueBytes === null || valueBytes === undefined
|
||||
? "n/a"
|
||||
: `${(valueBytes / 1024 / 1024).toFixed(1)}MB`;
|
||||
}
|
||||
|
||||
export function formatSignedMs(value, digits = 1) {
|
||||
return `${value > 0 ? "+" : ""}${formatMs(value, digits)}`;
|
||||
}
|
||||
|
||||
export function formatSignedBytesAsMb(valueBytes) {
|
||||
return valueBytes === null || valueBytes === undefined
|
||||
? "n/a"
|
||||
: `${valueBytes > 0 ? "+" : ""}${formatBytesAsMb(valueBytes)}`;
|
||||
}
|
||||
|
||||
export function normalizeConfigLabel(config) {
|
||||
return config.replace(/^test\/vitest\/vitest\./u, "").replace(/\.config\.ts$/u, "");
|
||||
}
|
||||
|
||||
export function resolveTestArea(file) {
|
||||
const normalized = normalizeTrackedRepoPath(file);
|
||||
const parts = normalized.split("/");
|
||||
if (parts[0] === "extensions" && parts[1]) {
|
||||
return `extensions/${parts[1]}`;
|
||||
}
|
||||
if (parts[0] === "src" && parts[1]) {
|
||||
return `src/${parts[1]}`;
|
||||
}
|
||||
if (parts[0] === "packages" && parts[1]) {
|
||||
return `packages/${parts[1]}`;
|
||||
}
|
||||
if (parts[0] === "apps" && parts[1]) {
|
||||
return `apps/${parts[1]}`;
|
||||
}
|
||||
if (parts[0] === "ui") {
|
||||
return parts[3] ? `ui/${parts[3]}` : "ui";
|
||||
}
|
||||
if (parts[0] === "test" && parts[1]) {
|
||||
return `test/${parts[1]}`;
|
||||
}
|
||||
return parts[0] || normalized;
|
||||
}
|
||||
|
||||
export function resolveTestFolder(file, depth = 2) {
|
||||
const normalized = normalizeTrackedRepoPath(file);
|
||||
const dir = path.posix.dirname(normalized);
|
||||
if (dir === ".") {
|
||||
return normalized;
|
||||
}
|
||||
return dir.split("/").slice(0, Math.max(1, depth)).join("/");
|
||||
}
|
||||
|
||||
export function resolveGroupKey(file, mode = "area") {
|
||||
if (mode === "folder") {
|
||||
return resolveTestFolder(file, 3);
|
||||
}
|
||||
if (mode === "top") {
|
||||
return normalizeTrackedRepoPath(file).split("/")[0] || file;
|
||||
}
|
||||
return resolveTestArea(file);
|
||||
}
|
||||
|
||||
function createCounter(key) {
|
||||
return {
|
||||
key,
|
||||
durationMs: 0,
|
||||
fileCount: 0,
|
||||
testCount: 0,
|
||||
configs: new Set(),
|
||||
};
|
||||
}
|
||||
|
||||
function addFileEntry(target, entry, config) {
|
||||
target.durationMs += entry.durationMs;
|
||||
target.fileCount += 1;
|
||||
target.testCount += entry.testCount;
|
||||
target.configs.add(config);
|
||||
}
|
||||
|
||||
function finalizeCounter(counter) {
|
||||
return {
|
||||
key: counter.key,
|
||||
durationMs: counter.durationMs,
|
||||
fileCount: counter.fileCount,
|
||||
testCount: counter.testCount,
|
||||
configs: [...counter.configs].toSorted((left, right) => left.localeCompare(right)),
|
||||
};
|
||||
}
|
||||
|
||||
export function buildGroupedTestReport(params) {
|
||||
const byGroup = new Map();
|
||||
const byConfig = new Map();
|
||||
const files = [];
|
||||
|
||||
for (const input of params.reports) {
|
||||
const config = normalizeConfigLabel(input.config);
|
||||
const fileEntries = collectVitestFileDurations(input.report, normalizeTrackedRepoPath);
|
||||
const configCounter = byConfig.get(config) ?? createCounter(config);
|
||||
byConfig.set(config, configCounter);
|
||||
|
||||
for (const entry of fileEntries) {
|
||||
const groupKey = resolveGroupKey(entry.file, params.groupBy);
|
||||
const groupCounter = byGroup.get(groupKey) ?? createCounter(groupKey);
|
||||
byGroup.set(groupKey, groupCounter);
|
||||
addFileEntry(groupCounter, entry, config);
|
||||
addFileEntry(configCounter, entry, config);
|
||||
files.push({ ...entry, config, group: groupKey });
|
||||
}
|
||||
}
|
||||
|
||||
const sortByDuration = (left, right) =>
|
||||
right.durationMs - left.durationMs || left.key.localeCompare(right.key);
|
||||
const sortFilesByDuration = (left, right) =>
|
||||
right.durationMs - left.durationMs || left.file.localeCompare(right.file);
|
||||
|
||||
const groups = [...byGroup.values()].map(finalizeCounter).toSorted(sortByDuration);
|
||||
const configs = [...byConfig.values()].map(finalizeCounter).toSorted(sortByDuration);
|
||||
const topFiles = files.toSorted(sortFilesByDuration);
|
||||
const totals = groups.reduce(
|
||||
(acc, group) => ({
|
||||
durationMs: acc.durationMs + group.durationMs,
|
||||
fileCount: acc.fileCount + group.fileCount,
|
||||
testCount: acc.testCount + group.testCount,
|
||||
}),
|
||||
{ durationMs: 0, fileCount: 0, testCount: 0 },
|
||||
);
|
||||
|
||||
return {
|
||||
generatedAt: new Date().toISOString(),
|
||||
groupBy: params.groupBy,
|
||||
totals,
|
||||
groups,
|
||||
configs,
|
||||
topFiles,
|
||||
};
|
||||
}
|
||||
|
||||
function percentDelta(beforeValue, afterValue) {
|
||||
if (beforeValue === 0) {
|
||||
return afterValue === 0 ? 0 : null;
|
||||
}
|
||||
return ((afterValue - beforeValue) / beforeValue) * 100;
|
||||
}
|
||||
|
||||
function formatPercent(value) {
|
||||
if (value === null || value === undefined) {
|
||||
return "new";
|
||||
}
|
||||
return `${value > 0 ? "+" : ""}${value.toFixed(1)}%`;
|
||||
}
|
||||
|
||||
function normalizeCounter(item) {
|
||||
return {
|
||||
durationMs: item?.durationMs ?? 0,
|
||||
fileCount: item?.fileCount ?? 0,
|
||||
testCount: item?.testCount ?? 0,
|
||||
};
|
||||
}
|
||||
|
||||
function compareStatus(beforeItem, afterItem) {
|
||||
if (beforeItem && afterItem) {
|
||||
return "changed";
|
||||
}
|
||||
return beforeItem ? "removed" : "added";
|
||||
}
|
||||
|
||||
function compareCounters(beforeItems = [], afterItems = []) {
|
||||
const beforeByKey = new Map(beforeItems.map((item) => [item.key, item]));
|
||||
const afterByKey = new Map(afterItems.map((item) => [item.key, item]));
|
||||
const keys = new Set([...beforeByKey.keys(), ...afterByKey.keys()]);
|
||||
|
||||
return [...keys]
|
||||
.map((key) => {
|
||||
const beforeItem = beforeByKey.get(key);
|
||||
const afterItem = afterByKey.get(key);
|
||||
const before = normalizeCounter(beforeItem);
|
||||
const after = normalizeCounter(afterItem);
|
||||
return {
|
||||
key,
|
||||
status: compareStatus(beforeItem, afterItem),
|
||||
before,
|
||||
after,
|
||||
delta: {
|
||||
durationMs: after.durationMs - before.durationMs,
|
||||
fileCount: after.fileCount - before.fileCount,
|
||||
testCount: after.testCount - before.testCount,
|
||||
},
|
||||
percent: {
|
||||
durationMs: percentDelta(before.durationMs, after.durationMs),
|
||||
},
|
||||
};
|
||||
})
|
||||
.toSorted(
|
||||
(left, right) =>
|
||||
Math.abs(right.delta.durationMs) - Math.abs(left.delta.durationMs) ||
|
||||
left.key.localeCompare(right.key),
|
||||
);
|
||||
}
|
||||
|
||||
function normalizeFileCounter(item) {
|
||||
return {
|
||||
durationMs: item?.durationMs ?? 0,
|
||||
testCount: item?.testCount ?? 0,
|
||||
};
|
||||
}
|
||||
|
||||
function fileKey(item) {
|
||||
return `${item.config}\0${item.file}`;
|
||||
}
|
||||
|
||||
function compareFiles(beforeFiles = [], afterFiles = []) {
|
||||
const beforeByKey = new Map(beforeFiles.map((item) => [fileKey(item), item]));
|
||||
const afterByKey = new Map(afterFiles.map((item) => [fileKey(item), item]));
|
||||
const keys = new Set([...beforeByKey.keys(), ...afterByKey.keys()]);
|
||||
|
||||
return [...keys]
|
||||
.map((key) => {
|
||||
const beforeItem = beforeByKey.get(key);
|
||||
const afterItem = afterByKey.get(key);
|
||||
const before = normalizeFileCounter(beforeItem);
|
||||
const after = normalizeFileCounter(afterItem);
|
||||
const source = afterItem ?? beforeItem;
|
||||
return {
|
||||
key,
|
||||
config: source.config,
|
||||
file: source.file,
|
||||
group: source.group,
|
||||
status: compareStatus(beforeItem, afterItem),
|
||||
before,
|
||||
after,
|
||||
delta: {
|
||||
durationMs: after.durationMs - before.durationMs,
|
||||
testCount: after.testCount - before.testCount,
|
||||
},
|
||||
percent: {
|
||||
durationMs: percentDelta(before.durationMs, after.durationMs),
|
||||
},
|
||||
};
|
||||
})
|
||||
.toSorted(
|
||||
(left, right) =>
|
||||
Math.abs(right.delta.durationMs) - Math.abs(left.delta.durationMs) ||
|
||||
left.file.localeCompare(right.file) ||
|
||||
left.config.localeCompare(right.config),
|
||||
);
|
||||
}
|
||||
|
||||
function runKey(run) {
|
||||
return normalizeConfigLabel(run.config);
|
||||
}
|
||||
|
||||
function compareOptionalNumber(beforeValue, afterValue) {
|
||||
if (typeof beforeValue !== "number" || typeof afterValue !== "number") {
|
||||
return null;
|
||||
}
|
||||
return afterValue - beforeValue;
|
||||
}
|
||||
|
||||
function normalizeRun(run) {
|
||||
return run
|
||||
? {
|
||||
elapsedMs: typeof run.elapsedMs === "number" ? run.elapsedMs : null,
|
||||
maxRssBytes: typeof run.maxRssBytes === "number" ? run.maxRssBytes : null,
|
||||
status: typeof run.status === "number" ? run.status : null,
|
||||
}
|
||||
: {
|
||||
elapsedMs: null,
|
||||
maxRssBytes: null,
|
||||
status: null,
|
||||
};
|
||||
}
|
||||
|
||||
function compareRuns(beforeRuns = [], afterRuns = []) {
|
||||
const beforeByKey = new Map(beforeRuns.map((run) => [runKey(run), run]));
|
||||
const afterByKey = new Map(afterRuns.map((run) => [runKey(run), run]));
|
||||
const keys = new Set([...beforeByKey.keys(), ...afterByKey.keys()]);
|
||||
|
||||
return [...keys]
|
||||
.map((key) => {
|
||||
const beforeRun = beforeByKey.get(key);
|
||||
const afterRun = afterByKey.get(key);
|
||||
const before = normalizeRun(beforeRun);
|
||||
const after = normalizeRun(afterRun);
|
||||
return {
|
||||
key,
|
||||
status: compareStatus(beforeRun, afterRun),
|
||||
before,
|
||||
after,
|
||||
delta: {
|
||||
elapsedMs: compareOptionalNumber(before.elapsedMs, after.elapsedMs),
|
||||
maxRssBytes: compareOptionalNumber(before.maxRssBytes, after.maxRssBytes),
|
||||
},
|
||||
};
|
||||
})
|
||||
.toSorted((left, right) => {
|
||||
const leftMagnitude = Math.abs(left.delta.elapsedMs ?? left.delta.maxRssBytes ?? 0);
|
||||
const rightMagnitude = Math.abs(right.delta.elapsedMs ?? right.delta.maxRssBytes ?? 0);
|
||||
return rightMagnitude - leftMagnitude || left.key.localeCompare(right.key);
|
||||
});
|
||||
}
|
||||
|
||||
export function buildGroupedTestComparison(params) {
|
||||
const before = params.before;
|
||||
const after = params.after;
|
||||
const beforeTotals = normalizeCounter(before.totals);
|
||||
const afterTotals = normalizeCounter(after.totals);
|
||||
const warnings = [];
|
||||
|
||||
if (before.groupBy !== after.groupBy) {
|
||||
warnings.push(`groupBy differs: before=${before.groupBy} after=${after.groupBy}`);
|
||||
}
|
||||
|
||||
return {
|
||||
generatedAt: new Date().toISOString(),
|
||||
command: "test-group-report:compare",
|
||||
groupBy: after.groupBy ?? before.groupBy,
|
||||
warnings,
|
||||
totals: {
|
||||
before: beforeTotals,
|
||||
after: afterTotals,
|
||||
delta: {
|
||||
durationMs: afterTotals.durationMs - beforeTotals.durationMs,
|
||||
fileCount: afterTotals.fileCount - beforeTotals.fileCount,
|
||||
testCount: afterTotals.testCount - beforeTotals.testCount,
|
||||
},
|
||||
percent: {
|
||||
durationMs: percentDelta(beforeTotals.durationMs, afterTotals.durationMs),
|
||||
},
|
||||
},
|
||||
groups: compareCounters(before.groups, after.groups),
|
||||
configs: compareCounters(before.configs, after.configs),
|
||||
files: compareFiles(before.topFiles, after.topFiles),
|
||||
runs: compareRuns(before.runs, after.runs),
|
||||
inputs: {
|
||||
before: params.beforePath ?? null,
|
||||
after: params.afterPath ?? null,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
function formatCountDelta(value) {
|
||||
return `${value > 0 ? "+" : ""}${value}`;
|
||||
}
|
||||
|
||||
function formatOptionalMs(value) {
|
||||
return typeof value === "number" ? formatMs(value) : "n/a";
|
||||
}
|
||||
|
||||
function formatOptionalSignedMs(value) {
|
||||
return typeof value === "number" ? formatSignedMs(value) : "n/a";
|
||||
}
|
||||
|
||||
function formatOptionalBytes(value) {
|
||||
return typeof value === "number" ? formatBytesAsMb(value) : "n/a";
|
||||
}
|
||||
|
||||
function formatOptionalSignedBytes(value) {
|
||||
return typeof value === "number" ? formatSignedBytesAsMb(value) : "n/a";
|
||||
}
|
||||
|
||||
function pushChangeRows(lines, entries, options) {
|
||||
const selected = entries.slice(0, options.limit);
|
||||
if (selected.length === 0) {
|
||||
lines.push(" (none)");
|
||||
return;
|
||||
}
|
||||
|
||||
for (const [index, entry] of selected.entries()) {
|
||||
lines.push(
|
||||
`${String(index + 1).padStart(2, " ")}. ${formatSignedMs(entry.delta.durationMs).padStart(11, " ")} (${formatPercent(entry.percent.durationMs).padStart(7, " ")}) | before=${formatMs(entry.before.durationMs).padStart(10, " ")} after=${formatMs(entry.after.durationMs).padStart(10, " ")} | files=${formatCountDelta(entry.delta.fileCount ?? 0).padStart(4, " ")} tests=${formatCountDelta(entry.delta.testCount ?? 0).padStart(5, " ")} | ${entry.key}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
function pushFileChangeRows(lines, entries, options) {
|
||||
const selected = entries.slice(0, options.limit);
|
||||
if (selected.length === 0) {
|
||||
lines.push(" (none)");
|
||||
return;
|
||||
}
|
||||
|
||||
for (const [index, entry] of selected.entries()) {
|
||||
lines.push(
|
||||
`${String(index + 1).padStart(2, " ")}. ${formatSignedMs(entry.delta.durationMs).padStart(11, " ")} (${formatPercent(entry.percent.durationMs).padStart(7, " ")}) | before=${formatMs(entry.before.durationMs).padStart(10, " ")} after=${formatMs(entry.after.durationMs).padStart(10, " ")} | tests=${formatCountDelta(entry.delta.testCount).padStart(4, " ")} | ${entry.config} | ${entry.file}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
export function renderGroupedTestComparison(comparison, options = {}) {
|
||||
const limit = options.limit ?? 25;
|
||||
const topFiles = options.topFiles ?? 25;
|
||||
const groupRegressions = comparison.groups.filter((entry) => entry.delta.durationMs > 0);
|
||||
const groupGains = comparison.groups.filter((entry) => entry.delta.durationMs < 0);
|
||||
const fileRegressions = comparison.files.filter((entry) => entry.delta.durationMs > 0);
|
||||
const fileGains = comparison.files.filter((entry) => entry.delta.durationMs < 0);
|
||||
const addedFiles = comparison.files.filter((entry) => entry.status === "added").length;
|
||||
const removedFiles = comparison.files.filter((entry) => entry.status === "removed").length;
|
||||
const lines = [
|
||||
`[test-group-report:compare] groupBy=${comparison.groupBy} file-sum=${formatMs(comparison.totals.before.durationMs)} -> ${formatMs(comparison.totals.after.durationMs)} (${formatSignedMs(comparison.totals.delta.durationMs)}, ${formatPercent(comparison.totals.percent.durationMs)}) files=${comparison.totals.before.fileCount}->${comparison.totals.after.fileCount} (${formatCountDelta(comparison.totals.delta.fileCount)}) tests=${comparison.totals.before.testCount}->${comparison.totals.after.testCount} (${formatCountDelta(comparison.totals.delta.testCount)}) addedFiles=${addedFiles} removedFiles=${removedFiles}`,
|
||||
];
|
||||
|
||||
for (const warning of comparison.warnings) {
|
||||
lines.push(`[test-group-report:compare] warning: ${warning}`);
|
||||
}
|
||||
|
||||
lines.push(
|
||||
"",
|
||||
`Top group regressions (${Math.min(limit, groupRegressions.length)} of ${groupRegressions.length})`,
|
||||
);
|
||||
pushChangeRows(lines, groupRegressions, { limit });
|
||||
|
||||
lines.push("", `Top group gains (${Math.min(limit, groupGains.length)} of ${groupGains.length})`);
|
||||
pushChangeRows(lines, groupGains, { limit });
|
||||
|
||||
lines.push(
|
||||
"",
|
||||
`Config duration deltas (${Math.min(limit, comparison.configs.length)} of ${comparison.configs.length})`,
|
||||
);
|
||||
pushChangeRows(lines, comparison.configs, { limit });
|
||||
|
||||
if (comparison.runs.length > 0) {
|
||||
lines.push(
|
||||
"",
|
||||
`Config wall/RSS deltas (${Math.min(limit, comparison.runs.length)} of ${comparison.runs.length})`,
|
||||
);
|
||||
for (const [index, run] of comparison.runs.slice(0, limit).entries()) {
|
||||
lines.push(
|
||||
`${String(index + 1).padStart(2, " ")}. wall=${formatOptionalSignedMs(run.delta.elapsedMs).padStart(11, " ")} before=${formatOptionalMs(run.before.elapsedMs).padStart(10, " ")} after=${formatOptionalMs(run.after.elapsedMs).padStart(10, " ")} | rss=${formatOptionalSignedBytes(run.delta.maxRssBytes).padStart(10, " ")} before=${formatOptionalBytes(run.before.maxRssBytes).padStart(9, " ")} after=${formatOptionalBytes(run.after.maxRssBytes).padStart(9, " ")} | status=${run.before.status ?? "n/a"}->${run.after.status ?? "n/a"} | ${run.key}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
lines.push(
|
||||
"",
|
||||
`Top file regressions (${Math.min(topFiles, fileRegressions.length)} of ${fileRegressions.length})`,
|
||||
);
|
||||
pushFileChangeRows(lines, fileRegressions, { limit: topFiles });
|
||||
|
||||
lines.push("", `Top file gains (${Math.min(topFiles, fileGains.length)} of ${fileGains.length})`);
|
||||
pushFileChangeRows(lines, fileGains, { limit: topFiles });
|
||||
|
||||
return lines.join("\n");
|
||||
}
|
||||
|
||||
export function renderGroupedTestReport(report, options = {}) {
|
||||
const limit = options.limit ?? 25;
|
||||
const topFiles = options.topFiles ?? 25;
|
||||
const lines = [
|
||||
`[test-group-report] groupBy=${report.groupBy} files=${report.totals.fileCount} tests=${report.totals.testCount} file-sum=${formatMs(report.totals.durationMs)}`,
|
||||
"",
|
||||
`Top groups (${Math.min(limit, report.groups.length)} of ${report.groups.length})`,
|
||||
];
|
||||
|
||||
for (const [index, group] of report.groups.slice(0, limit).entries()) {
|
||||
lines.push(
|
||||
`${String(index + 1).padStart(2, " ")}. ${formatMs(group.durationMs).padStart(10, " ")} | files=${String(group.fileCount).padStart(4, " ")} | tests=${String(group.testCount).padStart(5, " ")} | ${group.key}`,
|
||||
);
|
||||
}
|
||||
|
||||
lines.push(
|
||||
"",
|
||||
`Top configs (${Math.min(limit, report.configs.length)} of ${report.configs.length})`,
|
||||
);
|
||||
for (const [index, config] of report.configs.slice(0, limit).entries()) {
|
||||
lines.push(
|
||||
`${String(index + 1).padStart(2, " ")}. ${formatMs(config.durationMs).padStart(10, " ")} | files=${String(config.fileCount).padStart(4, " ")} | tests=${String(config.testCount).padStart(5, " ")} | ${config.key}`,
|
||||
);
|
||||
}
|
||||
|
||||
lines.push(
|
||||
"",
|
||||
`Top files (${Math.min(topFiles, report.topFiles.length)} of ${report.topFiles.length})`,
|
||||
);
|
||||
for (const [index, file] of report.topFiles.slice(0, topFiles).entries()) {
|
||||
lines.push(
|
||||
`${String(index + 1).padStart(2, " ")}. ${formatMs(file.durationMs).padStart(10, " ")} | tests=${String(file.testCount).padStart(4, " ")} | ${file.config} | ${file.file}`,
|
||||
);
|
||||
}
|
||||
|
||||
return lines.join("\n");
|
||||
}
|
||||
@@ -660,8 +660,8 @@ function shouldRunBundledPluginPostinstall(params) {
|
||||
|
||||
export function runBundledPluginPostinstall(params = {}) {
|
||||
const env = params.env ?? process.env;
|
||||
const extensionsDir = params.extensionsDir ?? DEFAULT_EXTENSIONS_DIR;
|
||||
const packageRoot = params.packageRoot ?? DEFAULT_PACKAGE_ROOT;
|
||||
const extensionsDir = params.extensionsDir ?? join(packageRoot, "dist", "extensions");
|
||||
const spawn = params.spawnSync ?? spawnSync;
|
||||
const pathExists = params.existsSync ?? existsSync;
|
||||
const log = params.log ?? console;
|
||||
|
||||
@@ -1,372 +0,0 @@
|
||||
import { spawnSync } from "node:child_process";
|
||||
import fs from "node:fs";
|
||||
import os from "node:os";
|
||||
import path from "node:path";
|
||||
import { pathToFileURL } from "node:url";
|
||||
import {
|
||||
buildGroupedTestComparison,
|
||||
buildGroupedTestReport,
|
||||
formatBytesAsMb,
|
||||
normalizeConfigLabel,
|
||||
renderGroupedTestComparison,
|
||||
renderGroupedTestReport,
|
||||
} from "./lib/test-group-report.mjs";
|
||||
import { formatMs } from "./lib/vitest-report-cli-utils.mjs";
|
||||
import { resolveVitestNodeArgs } from "./run-vitest.mjs";
|
||||
import { buildFullSuiteVitestRunPlans } from "./test-projects.test-support.mjs";
|
||||
|
||||
const DEFAULT_OUTPUT = ".artifacts/test-perf/group-report.json";
|
||||
const DEFAULT_COMPARE_OUTPUT = ".artifacts/test-perf/group-report-compare.json";
|
||||
|
||||
function usage() {
|
||||
return [
|
||||
"Usage: node scripts/test-group-report.mjs [options] [-- <vitest args>]",
|
||||
"",
|
||||
"Build a grouped Vitest duration report from one or more JSON reports.",
|
||||
"",
|
||||
"Options:",
|
||||
" --config <path> Vitest config to run (repeatable)",
|
||||
" --compare <before> <after>",
|
||||
" Compare two grouped report JSON files",
|
||||
" --report <path> Existing Vitest JSON report to read (repeatable)",
|
||||
" --full-suite Run every full-suite leaf Vitest config serially",
|
||||
" --group-by <mode> area | folder | top (default: area)",
|
||||
" --output <path> JSON report path (default: .artifacts/test-perf/group-report.json)",
|
||||
" --limit <count> Number of groups/configs to print (default: 25)",
|
||||
" --top-files <count> Number of files to print (default: 25)",
|
||||
" --allow-failures Write a report even when a Vitest run exits non-zero",
|
||||
" --no-rss Skip macOS max RSS measurement",
|
||||
" --help Show this help",
|
||||
"",
|
||||
"Examples:",
|
||||
" pnpm test:perf:groups --config test/vitest/vitest.unit-fast.config.ts",
|
||||
" pnpm test:perf:groups --full-suite --allow-failures",
|
||||
" pnpm test:perf:groups:compare .artifacts/test-perf/baseline-before.json .artifacts/test-perf/after-first-fix.json",
|
||||
].join("\n");
|
||||
}
|
||||
|
||||
function parsePositiveInt(value, fallback) {
|
||||
const parsed = Number.parseInt(value ?? "", 10);
|
||||
return Number.isFinite(parsed) && parsed > 0 ? parsed : fallback;
|
||||
}
|
||||
|
||||
export function parseTestGroupReportArgs(argv) {
|
||||
const args = {
|
||||
allowFailures: false,
|
||||
compare: null,
|
||||
configs: [],
|
||||
fullSuite: false,
|
||||
groupBy: "area",
|
||||
limit: 25,
|
||||
output: null,
|
||||
reports: [],
|
||||
rss: process.platform === "darwin",
|
||||
topFiles: 25,
|
||||
vitestArgs: [],
|
||||
};
|
||||
|
||||
for (let index = 0; index < argv.length; index += 1) {
|
||||
const arg = argv[index];
|
||||
if (arg === "--") {
|
||||
args.vitestArgs = argv.slice(index + 1);
|
||||
break;
|
||||
}
|
||||
if (arg === "--help") {
|
||||
args.help = true;
|
||||
continue;
|
||||
}
|
||||
if (arg === "--allow-failures") {
|
||||
args.allowFailures = true;
|
||||
continue;
|
||||
}
|
||||
if (arg === "--full-suite") {
|
||||
args.fullSuite = true;
|
||||
continue;
|
||||
}
|
||||
if (arg === "--no-rss") {
|
||||
args.rss = false;
|
||||
continue;
|
||||
}
|
||||
if (arg === "--config") {
|
||||
args.configs.push(argv[index + 1] ?? "");
|
||||
index += 1;
|
||||
continue;
|
||||
}
|
||||
if (arg === "--compare") {
|
||||
args.compare = {
|
||||
before: argv[index + 1] ?? "",
|
||||
after: argv[index + 2] ?? "",
|
||||
};
|
||||
index += 2;
|
||||
continue;
|
||||
}
|
||||
if (arg === "--report") {
|
||||
args.reports.push(argv[index + 1] ?? "");
|
||||
index += 1;
|
||||
continue;
|
||||
}
|
||||
if (arg === "--group-by") {
|
||||
args.groupBy = argv[index + 1] ?? args.groupBy;
|
||||
index += 1;
|
||||
continue;
|
||||
}
|
||||
if (arg === "--output") {
|
||||
args.output = argv[index + 1] ?? args.output;
|
||||
index += 1;
|
||||
continue;
|
||||
}
|
||||
if (arg === "--limit") {
|
||||
args.limit = parsePositiveInt(argv[index + 1], args.limit);
|
||||
index += 1;
|
||||
continue;
|
||||
}
|
||||
if (arg === "--top-files") {
|
||||
args.topFiles = parsePositiveInt(argv[index + 1], args.topFiles);
|
||||
index += 1;
|
||||
continue;
|
||||
}
|
||||
throw new Error(`Unknown option: ${arg}`);
|
||||
}
|
||||
|
||||
if (!["area", "folder", "top"].includes(args.groupBy)) {
|
||||
throw new Error(`Unsupported --group-by value: ${args.groupBy}`);
|
||||
}
|
||||
if (args.compare && (!args.compare.before || !args.compare.after)) {
|
||||
throw new Error("--compare requires before and after report paths");
|
||||
}
|
||||
if (
|
||||
args.compare &&
|
||||
(args.configs.length > 0 ||
|
||||
args.fullSuite ||
|
||||
args.reports.length > 0 ||
|
||||
args.vitestArgs.length > 0)
|
||||
) {
|
||||
throw new Error("--compare cannot be combined with test run or report input options");
|
||||
}
|
||||
|
||||
return args;
|
||||
}
|
||||
|
||||
function sanitizePathSegment(value) {
|
||||
return (
|
||||
value
|
||||
.replace(/[^A-Za-z0-9._-]+/gu, "-")
|
||||
.replace(/^-+|-+$/gu, "")
|
||||
.slice(0, 180) || "report"
|
||||
);
|
||||
}
|
||||
|
||||
function parseMaxRssBytes(output) {
|
||||
const match = output.match(/(\d+)\s+maximum resident set size/u);
|
||||
return match ? Number.parseInt(match[1], 10) : null;
|
||||
}
|
||||
|
||||
function runVitestJsonReport(params) {
|
||||
fs.mkdirSync(path.dirname(params.reportPath), { recursive: true });
|
||||
fs.mkdirSync(path.dirname(params.logPath), { recursive: true });
|
||||
const command = [
|
||||
process.execPath,
|
||||
"scripts/run-vitest.mjs",
|
||||
"run",
|
||||
"--config",
|
||||
params.config,
|
||||
"--reporter=json",
|
||||
"--outputFile",
|
||||
params.reportPath,
|
||||
...params.vitestArgs,
|
||||
];
|
||||
const startedAt = process.hrtime.bigint();
|
||||
const result = spawnSync(
|
||||
params.rss ? "/usr/bin/time" : command[0],
|
||||
params.rss ? ["-l", ...command] : command.slice(1),
|
||||
{
|
||||
cwd: process.cwd(),
|
||||
encoding: "utf8",
|
||||
env: {
|
||||
...process.env,
|
||||
NODE_OPTIONS: [
|
||||
process.env.NODE_OPTIONS?.trim(),
|
||||
...resolveVitestNodeArgs(process.env).filter((arg) => arg !== "--no-maglev"),
|
||||
]
|
||||
.filter(Boolean)
|
||||
.join(" "),
|
||||
},
|
||||
maxBuffer: 1024 * 1024 * 64,
|
||||
},
|
||||
);
|
||||
const elapsedMs = Number.parseFloat(String(process.hrtime.bigint() - startedAt)) / 1_000_000;
|
||||
const output = `${result.stdout ?? ""}${result.stderr ?? ""}`;
|
||||
fs.writeFileSync(params.logPath, output, "utf8");
|
||||
return {
|
||||
config: params.config,
|
||||
elapsedMs,
|
||||
logPath: params.logPath,
|
||||
maxRssBytes: params.rss ? parseMaxRssBytes(output) : null,
|
||||
reportPath: params.reportPath,
|
||||
status: result.status ?? 1,
|
||||
};
|
||||
}
|
||||
|
||||
function readReportInput(entry) {
|
||||
return {
|
||||
config: entry.config,
|
||||
report: JSON.parse(fs.readFileSync(entry.reportPath, "utf8")),
|
||||
reportPath: entry.reportPath,
|
||||
run: entry.run ?? null,
|
||||
};
|
||||
}
|
||||
|
||||
function readGroupedReport(reportPath) {
|
||||
return JSON.parse(fs.readFileSync(reportPath, "utf8"));
|
||||
}
|
||||
|
||||
export function resolveReportArtifactDirs(outputPath) {
|
||||
const outputDir = path.dirname(outputPath);
|
||||
const outputExt = path.extname(outputPath);
|
||||
const outputStem = path.basename(outputPath, outputExt) || "group-report";
|
||||
const artifactDir = path.join(outputDir, outputStem);
|
||||
return {
|
||||
reportDir: path.join(artifactDir, "vitest-json"),
|
||||
logDir: path.join(artifactDir, "logs"),
|
||||
};
|
||||
}
|
||||
|
||||
function resolveConfigs(args) {
|
||||
if (args.reports.length > 0) {
|
||||
return [];
|
||||
}
|
||||
if (args.fullSuite) {
|
||||
return buildFullSuiteVitestRunPlans([], process.cwd()).map((plan) => plan.config);
|
||||
}
|
||||
return args.configs.length > 0 ? args.configs : ["test/vitest/vitest.unit.config.ts"];
|
||||
}
|
||||
|
||||
function printRunLine(run) {
|
||||
const label = normalizeConfigLabel(run.config);
|
||||
console.log(
|
||||
`[test-group-report] ${label} status=${run.status} wall=${formatMs(run.elapsedMs)} rss=${formatBytesAsMb(run.maxRssBytes)} report=${run.reportPath}`,
|
||||
);
|
||||
}
|
||||
|
||||
async function main() {
|
||||
const args = parseTestGroupReportArgs(process.argv.slice(2));
|
||||
if (args.help) {
|
||||
console.log(usage());
|
||||
return;
|
||||
}
|
||||
|
||||
const output = path.resolve(
|
||||
args.output ?? (args.compare ? DEFAULT_COMPARE_OUTPUT : DEFAULT_OUTPUT),
|
||||
);
|
||||
|
||||
if (args.compare) {
|
||||
const beforePath = path.resolve(args.compare.before);
|
||||
const afterPath = path.resolve(args.compare.after);
|
||||
const comparison = buildGroupedTestComparison({
|
||||
before: readGroupedReport(beforePath),
|
||||
after: readGroupedReport(afterPath),
|
||||
beforePath,
|
||||
afterPath,
|
||||
});
|
||||
|
||||
fs.mkdirSync(path.dirname(output), { recursive: true });
|
||||
fs.writeFileSync(output, `${JSON.stringify(comparison, null, 2)}\n`, "utf8");
|
||||
console.log(
|
||||
renderGroupedTestComparison(comparison, { limit: args.limit, topFiles: args.topFiles }),
|
||||
);
|
||||
console.log(`[test-group-report:compare] wrote ${path.relative(process.cwd(), output)}`);
|
||||
return;
|
||||
}
|
||||
|
||||
const { reportDir, logDir } = resolveReportArtifactDirs(output);
|
||||
const runEntries = [];
|
||||
const configs = resolveConfigs(args);
|
||||
let failed = false;
|
||||
let exitCode = 0;
|
||||
|
||||
for (const reportPath of args.reports) {
|
||||
runEntries.push({
|
||||
config: path.basename(reportPath).replace(/\.json$/u, ""),
|
||||
reportPath: path.resolve(reportPath),
|
||||
});
|
||||
}
|
||||
|
||||
for (const config of configs) {
|
||||
const slug = sanitizePathSegment(normalizeConfigLabel(config));
|
||||
const run = runVitestJsonReport({
|
||||
config,
|
||||
logPath: path.join(logDir, `${slug}.log`),
|
||||
reportPath: path.join(reportDir, `${slug}.json`),
|
||||
rss: args.rss,
|
||||
vitestArgs: args.vitestArgs,
|
||||
});
|
||||
printRunLine(run);
|
||||
if (run.status !== 0) {
|
||||
failed = true;
|
||||
if (!fs.existsSync(run.reportPath)) {
|
||||
console.error(
|
||||
`[test-group-report] missing JSON report for failed config; see ${run.logPath}`,
|
||||
);
|
||||
if (!args.allowFailures) {
|
||||
exitCode = run.status;
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
console.error(
|
||||
`[test-group-report] config failed; keeping partial report from ${run.reportPath}`,
|
||||
);
|
||||
if (!args.allowFailures) {
|
||||
exitCode = run.status;
|
||||
break;
|
||||
}
|
||||
}
|
||||
runEntries.push({ config, reportPath: run.reportPath, run });
|
||||
}
|
||||
|
||||
if (exitCode !== 0) {
|
||||
process.exit(exitCode);
|
||||
}
|
||||
|
||||
const reportInputs = runEntries
|
||||
.filter((entry) => fs.existsSync(entry.reportPath))
|
||||
.map(readReportInput);
|
||||
const report = buildGroupedTestReport({
|
||||
groupBy: args.groupBy,
|
||||
reports: reportInputs,
|
||||
});
|
||||
const envelope = {
|
||||
...report,
|
||||
command: "test-group-report",
|
||||
failed,
|
||||
runs: reportInputs.map((entry) => entry.run).filter(Boolean),
|
||||
system: {
|
||||
node: process.version,
|
||||
platform: process.platform,
|
||||
arch: process.arch,
|
||||
cpuCount: os.availableParallelism?.() ?? os.cpus().length,
|
||||
totalMemoryBytes: os.totalmem(),
|
||||
},
|
||||
};
|
||||
|
||||
fs.mkdirSync(path.dirname(output), { recursive: true });
|
||||
fs.writeFileSync(output, `${JSON.stringify(envelope, null, 2)}\n`, "utf8");
|
||||
console.log(renderGroupedTestReport(report, { limit: args.limit, topFiles: args.topFiles }));
|
||||
console.log(`[test-group-report] wrote ${path.relative(process.cwd(), output)}`);
|
||||
|
||||
if (failed && !args.allowFailures) {
|
||||
process.exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
const isMain =
|
||||
typeof process.argv[1] === "string" &&
|
||||
process.argv[1].length > 0 &&
|
||||
import.meta.url === pathToFileURL(path.resolve(process.argv[1])).href;
|
||||
|
||||
if (isMain) {
|
||||
main().catch((error) => {
|
||||
console.error(error instanceof Error ? error.message : String(error));
|
||||
process.exit(1);
|
||||
});
|
||||
}
|
||||
@@ -8,10 +8,8 @@ import {
|
||||
} from "../hooks/internal-hooks.js";
|
||||
import { makeTempWorkspace } from "../test-helpers/workspace.js";
|
||||
import {
|
||||
_resetBootstrapWarningCacheForTest,
|
||||
FULL_BOOTSTRAP_COMPLETED_CUSTOM_TYPE,
|
||||
hasCompletedBootstrapTurn,
|
||||
makeBootstrapWarn,
|
||||
resolveBootstrapContextForRun,
|
||||
resolveBootstrapFilesForRun,
|
||||
resolveContextInjectionMode,
|
||||
@@ -107,18 +105,6 @@ describe("resolveBootstrapContextForRun", () => {
|
||||
expect(extra?.content).toBe("extra");
|
||||
});
|
||||
|
||||
it("keeps BOOTSTRAP.md available in shared injected context for non-attempt consumers", async () => {
|
||||
const workspaceDir = await makeTempWorkspace("openclaw-bootstrap-");
|
||||
await fs.writeFile(path.join(workspaceDir, "BOOTSTRAP.md"), "ritual", "utf8");
|
||||
await fs.writeFile(path.join(workspaceDir, "AGENTS.md"), "rules", "utf8");
|
||||
|
||||
const result = await resolveBootstrapContextForRun({ workspaceDir });
|
||||
|
||||
expect(result.bootstrapFiles.some((file) => file.name === "BOOTSTRAP.md")).toBe(true);
|
||||
expect(result.contextFiles.some((file) => file.path.endsWith("BOOTSTRAP.md"))).toBe(true);
|
||||
expect(result.contextFiles.some((file) => file.path.endsWith("AGENTS.md"))).toBe(true);
|
||||
});
|
||||
|
||||
it("uses heartbeat-only bootstrap files in lightweight heartbeat mode", async () => {
|
||||
const workspaceDir = await makeTempWorkspace("openclaw-bootstrap-");
|
||||
await fs.writeFile(path.join(workspaceDir, "HEARTBEAT.md"), "check inbox", "utf8");
|
||||
@@ -376,69 +362,6 @@ describe("hasCompletedBootstrapTurn", () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe("makeBootstrapWarn", () => {
|
||||
afterEach(() => {
|
||||
_resetBootstrapWarningCacheForTest();
|
||||
});
|
||||
|
||||
it("deduplicates repeated warnings for the same session and message", () => {
|
||||
const warnings: string[] = [];
|
||||
const warn = makeBootstrapWarn({
|
||||
sessionLabel: "agent:main:test-session",
|
||||
warn: (message) => warnings.push(message),
|
||||
});
|
||||
|
||||
warn?.("workspace bootstrap file MEMORY.md is 36697 chars (limit 12000); truncating");
|
||||
warn?.("workspace bootstrap file MEMORY.md is 36697 chars (limit 12000); truncating");
|
||||
|
||||
expect(warnings).toEqual([
|
||||
"workspace bootstrap file MEMORY.md is 36697 chars (limit 12000); truncating (sessionKey=agent:main:test-session)",
|
||||
]);
|
||||
});
|
||||
|
||||
it("keeps warnings distinct across sessions", () => {
|
||||
const warnings: string[] = [];
|
||||
const first = makeBootstrapWarn({
|
||||
sessionLabel: "agent:main:first-session",
|
||||
warn: (message) => warnings.push(message),
|
||||
});
|
||||
const second = makeBootstrapWarn({
|
||||
sessionLabel: "agent:main:second-session",
|
||||
warn: (message) => warnings.push(message),
|
||||
});
|
||||
|
||||
first?.("workspace bootstrap file MEMORY.md is 36697 chars (limit 12000); truncating");
|
||||
second?.("workspace bootstrap file MEMORY.md is 36697 chars (limit 12000); truncating");
|
||||
|
||||
expect(warnings).toEqual([
|
||||
"workspace bootstrap file MEMORY.md is 36697 chars (limit 12000); truncating (sessionKey=agent:main:first-session)",
|
||||
"workspace bootstrap file MEMORY.md is 36697 chars (limit 12000); truncating (sessionKey=agent:main:second-session)",
|
||||
]);
|
||||
});
|
||||
|
||||
it("keeps warnings distinct across workspaces with the same session", () => {
|
||||
const warnings: string[] = [];
|
||||
const first = makeBootstrapWarn({
|
||||
sessionLabel: "agent:main:shared-session",
|
||||
workspaceDir: "/tmp/workspace-a",
|
||||
warn: (message) => warnings.push(message),
|
||||
});
|
||||
const second = makeBootstrapWarn({
|
||||
sessionLabel: "agent:main:shared-session",
|
||||
workspaceDir: "/tmp/workspace-b",
|
||||
warn: (message) => warnings.push(message),
|
||||
});
|
||||
|
||||
first?.("workspace bootstrap file MEMORY.md is 36697 chars (limit 12000); truncating");
|
||||
second?.("workspace bootstrap file MEMORY.md is 36697 chars (limit 12000); truncating");
|
||||
|
||||
expect(warnings).toEqual([
|
||||
"workspace bootstrap file MEMORY.md is 36697 chars (limit 12000); truncating (sessionKey=agent:main:shared-session)",
|
||||
"workspace bootstrap file MEMORY.md is 36697 chars (limit 12000); truncating (sessionKey=agent:main:shared-session)",
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("resolveContextInjectionMode", () => {
|
||||
it("defaults to always when config is missing", () => {
|
||||
expect(resolveContextInjectionMode(undefined)).toBe("always");
|
||||
|
||||
@@ -15,7 +15,6 @@ import {
|
||||
import {
|
||||
DEFAULT_HEARTBEAT_FILENAME,
|
||||
filterBootstrapFilesForSession,
|
||||
isWorkspaceBootstrapPending,
|
||||
loadWorkspaceBootstrapFiles,
|
||||
type WorkspaceBootstrapFile,
|
||||
} from "./workspace.js";
|
||||
@@ -26,29 +25,6 @@ export type BootstrapContextRunKind = "default" | "heartbeat" | "cron";
|
||||
const CONTINUATION_SCAN_MAX_TAIL_BYTES = 256 * 1024;
|
||||
const CONTINUATION_SCAN_MAX_RECORDS = 500;
|
||||
export const FULL_BOOTSTRAP_COMPLETED_CUSTOM_TYPE = "openclaw:bootstrap-context:full";
|
||||
const BOOTSTRAP_WARNING_DEDUPE_LIMIT = 1024;
|
||||
const seenBootstrapWarnings = new Set<string>();
|
||||
const bootstrapWarningOrder: string[] = [];
|
||||
|
||||
function rememberBootstrapWarning(key: string): boolean {
|
||||
if (seenBootstrapWarnings.has(key)) {
|
||||
return false;
|
||||
}
|
||||
if (seenBootstrapWarnings.size >= BOOTSTRAP_WARNING_DEDUPE_LIMIT) {
|
||||
const oldest = bootstrapWarningOrder.shift();
|
||||
if (oldest) {
|
||||
seenBootstrapWarnings.delete(oldest);
|
||||
}
|
||||
}
|
||||
seenBootstrapWarnings.add(key);
|
||||
bootstrapWarningOrder.push(key);
|
||||
return true;
|
||||
}
|
||||
|
||||
export function _resetBootstrapWarningCacheForTest(): void {
|
||||
seenBootstrapWarnings.clear();
|
||||
bootstrapWarningOrder.length = 0;
|
||||
}
|
||||
|
||||
export function resolveContextInjectionMode(config?: OpenClawConfig): AgentContextInjection {
|
||||
return config?.agents?.defaults?.contextInjection ?? "always";
|
||||
@@ -127,21 +103,12 @@ export async function hasCompletedBootstrapTurn(sessionFile: string): Promise<bo
|
||||
|
||||
export function makeBootstrapWarn(params: {
|
||||
sessionLabel: string;
|
||||
workspaceDir?: string;
|
||||
warn?: (message: string) => void;
|
||||
}): ((message: string) => void) | undefined {
|
||||
const warn = params.warn;
|
||||
if (!warn) {
|
||||
if (!params.warn) {
|
||||
return undefined;
|
||||
}
|
||||
const workspacePrefix = params.workspaceDir ?? "";
|
||||
return (message: string) => {
|
||||
const key = `${workspacePrefix}\u0000${params.sessionLabel}\u0000${message}`;
|
||||
if (!rememberBootstrapWarning(key)) {
|
||||
return;
|
||||
}
|
||||
warn(`${message} (sessionKey=${params.sessionLabel})`);
|
||||
};
|
||||
return (message: string) => params.warn?.(`${message} (sessionKey=${params.sessionLabel})`);
|
||||
}
|
||||
|
||||
function sanitizeBootstrapFiles(
|
||||
@@ -273,5 +240,3 @@ export async function resolveBootstrapContextForRun(params: {
|
||||
});
|
||||
return { bootstrapFiles, contextFiles };
|
||||
}
|
||||
|
||||
export { isWorkspaceBootstrapPending };
|
||||
|
||||
@@ -91,7 +91,6 @@ export async function prepareCliRunContext(
|
||||
sessionId: params.sessionId,
|
||||
warn: prepareDeps.makeBootstrapWarn({
|
||||
sessionLabel,
|
||||
workspaceDir,
|
||||
warn: (message) => cliBackendLog.warn(message),
|
||||
}),
|
||||
});
|
||||
|
||||
@@ -6,6 +6,7 @@ import type { SecretInput } from "../config/types.secrets.js";
|
||||
import {
|
||||
isMemoryMultimodalEnabled,
|
||||
normalizeMemoryMultimodalSettings,
|
||||
supportsMemoryMultimodalEmbeddings,
|
||||
type MemoryMultimodalSettings,
|
||||
} from "../memory-host-sdk/multimodal.js";
|
||||
import { getMemoryEmbeddingProvider } from "../plugins/memory-embedding-provider-runtime.js";
|
||||
@@ -388,9 +389,24 @@ export function resolveMemorySearchConfig(
|
||||
const multimodalActive = isMemoryMultimodalEnabled(resolved.multimodal);
|
||||
const multimodalProvider =
|
||||
resolved.provider === "auto" ? undefined : getMemoryEmbeddingProvider(resolved.provider);
|
||||
const builtinMultimodalSupport =
|
||||
resolved.provider === "auto"
|
||||
? false
|
||||
: supportsMemoryMultimodalEmbeddings({
|
||||
provider: resolved.provider,
|
||||
model: resolved.model,
|
||||
});
|
||||
if (
|
||||
multimodalActive &&
|
||||
!(multimodalProvider?.supportsMultimodalEmbeddings?.({ model: resolved.model }) ?? false)
|
||||
!(
|
||||
// Fall back to the built-in helper when the provider is not registered yet
|
||||
// or when a registered adapter does not implement multimodal capability checks.
|
||||
(
|
||||
multimodalProvider?.supportsMultimodalEmbeddings?.({
|
||||
model: resolved.model,
|
||||
}) ?? builtinMultimodalSupport
|
||||
)
|
||||
)
|
||||
) {
|
||||
throw new Error(
|
||||
"agents.*.memorySearch.multimodal requires a provider adapter that supports multimodal embeddings for the configured model.",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user