mirror of
https://github.com/openclaw/openclaw.git
synced 2026-06-07 14:31:35 +08:00
Compare commits
26 Commits
fix/codeql
...
codex/boot
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a5d70edf53 | ||
|
|
8205de84a9 | ||
|
|
c65f356ddc | ||
|
|
7e18c07e41 | ||
|
|
3fe8b24c4e | ||
|
|
c95507978f | ||
|
|
59d07f0ab4 | ||
|
|
5c1d6feb33 | ||
|
|
e8fd148437 | ||
|
|
2a283e87a7 | ||
|
|
15b2827fc1 | ||
|
|
65645ec54f | ||
|
|
e8ae3901b6 | ||
|
|
8e444ac5a6 | ||
|
|
6b45ba88a1 | ||
|
|
353950894a | ||
|
|
9da4d5f5df | ||
|
|
c6af0437c9 | ||
|
|
a2f2e5738e | ||
|
|
35fb3f7e1c | ||
|
|
a189394590 | ||
|
|
685f9903ec | ||
|
|
24431e5114 | ||
|
|
ee856ab31f | ||
|
|
acd86a06cd | ||
|
|
77e6e4cf87 |
6
.github/workflows/codeql.yml
vendored
6
.github/workflows/codeql.yml
vendored
@@ -1,12 +1,6 @@
|
||||
name: CodeQL
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
paths-ignore:
|
||||
- "**/*.md"
|
||||
- "**/*.mdx"
|
||||
- "LICENSE"
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: "0 6 * * *"
|
||||
|
||||
14
CHANGELOG.md
14
CHANGELOG.md
@@ -6,7 +6,15 @@ Docs: https://docs.openclaw.ai
|
||||
|
||||
### Fixes
|
||||
|
||||
- Agents/bootstrap: resolve bootstrap from workspace truth instead of stale session transcript markers, keep embedded bootstrap instructions on the hidden user-context path, suppress normal `/new` and `/reset` greetings while `BOOTSTRAP.md` is still pending, and make the embedded runner read the bootstrap ritual before replying normally.
|
||||
- Onboarding/non-interactive: preserve existing gateway auth tokens during re-onboard so active local gateway clients are not disconnected by an implicit token rotation. (#67821) Thanks @BKF-Gitty.
|
||||
- OpenAI Codex/Responses: unify native Responses API capability detection so Codex OAuth requests emit the required `store: false` field on the native Responses path. (#67918) Thanks @obviyus.
|
||||
- WhatsApp/setup: guard personal-phone and allowlist prompt values so setup fails with clear validation errors instead of crashing on undefined prompt text. (#67895) Thanks @lawrence3699.
|
||||
- Models/config: preserve an existing `models.json` provider `baseUrl` during merge-mode regeneration so custom endpoints do not get reset on restart. (#67893) Thanks @lawrence3699.
|
||||
- Plugins/discovery: reuse bundled and global plugin discovery results across workspace cache misses so Windows multi-workspace startup stops redoing the shared synchronous scan. (#67940) Thanks @obviyus.
|
||||
- Plugins/webhooks: enforce synchronous plugin registration with full rollback of failed plugin side effects, and cache SecretRef-backed webhook auth per route so plugin startup and inbound webhook auth stay deterministic. (#67941) Thanks @obviyus.
|
||||
- Telegram/ACP bindings: drop persisted DM bindings that still point at missing or failed ACP sessions on restart, while preserving plugin-owned bindings and uncertain store reads. (#67822) Thanks @chinar-amrutkar.
|
||||
- Telegram/streaming: keep a transient preview on the same Telegram message when auto-compaction retries an in-flight answer, so streamed replies no longer appear duplicated after compaction. (#66939) Thanks @rubencu.
|
||||
|
||||
## 2026.4.15
|
||||
|
||||
@@ -14,12 +22,6 @@ Docs: https://docs.openclaw.ai
|
||||
|
||||
- Anthropic/models: default Anthropic selections, `opus` aliases, Claude CLI defaults, and bundled image understanding to Claude Opus 4.7.
|
||||
- Google/TTS: add Gemini text-to-speech support to the bundled `google` plugin, including provider registration, voice selection, WAV reply output, PCM telephony output, and setup/docs guidance. (#67515) Thanks @barronlroth.
|
||||
- Control UI/Overview: add a Model Auth status card showing OAuth token health and provider rate-limit pressure at a glance, with attention callouts when OAuth tokens are expiring or expired. Backed by a new `models.authStatus` gateway method that strips credentials and caches for 60s. (#66211) Thanks @omarshahine.
|
||||
- Memory/LanceDB: add cloud storage support to `memory-lancedb` so durable memory indexes can run on remote object storage instead of local disk only. (#63502) Thanks @rugvedS07.
|
||||
- GitHub Copilot/memory search: add a GitHub Copilot embedding provider for memory search, and expose a dedicated Copilot embedding host helper so plugins can reuse the transport while honoring remote overrides, token refresh, and safer payload validation. (#61718) Thanks @feiskyer and @vincentkoc.
|
||||
- Agents/local models: add experimental `agents.defaults.experimental.localModelLean: true` to drop heavyweight default tools like `browser`, `cron`, and `message`, reducing prompt size for weaker local-model setups without changing the normal path. (#66495) Thanks @ImLukeF.
|
||||
- Packaging/plugins: localize bundled plugin runtime deps to their owning extensions, trim the published docs payload, and tighten install/package-manager guardrails so published builds stay leaner and core stops carrying extension-owned runtime baggage. (#67099) Thanks @vincentkoc.
|
||||
- QA/Matrix: split Matrix live QA into a source-linked `qa-matrix` runner and keep repo-private `qa-*` surfaces out of packaged and published builds. (#66723) Thanks @gumadeiras.
|
||||
- Docs/showcase: add a scannable hero, complete section jump links, and a responsive video grid for community examples. (#48493) Thanks @jchopard69.
|
||||
|
||||
### Fixes
|
||||
|
||||
@@ -318,7 +318,7 @@ Current bundled provider examples:
|
||||
| `plugin-sdk/memory-core` | Bundled memory-core helpers | Memory manager/config/file/CLI helper surface |
|
||||
| `plugin-sdk/memory-core-engine-runtime` | Memory engine runtime facade | Memory index/search runtime facade |
|
||||
| `plugin-sdk/memory-core-host-engine-foundation` | Memory host foundation engine | Memory host foundation engine exports |
|
||||
| `plugin-sdk/memory-core-host-engine-embeddings` | Memory host embedding engine | Memory host embedding engine exports |
|
||||
| `plugin-sdk/memory-core-host-engine-embeddings` | Memory host embedding engine | Memory embedding contracts, registry access, local provider, and generic batch/remote helpers; concrete remote providers live in their owning plugins |
|
||||
| `plugin-sdk/memory-core-host-engine-qmd` | Memory host QMD engine | Memory host QMD engine exports |
|
||||
| `plugin-sdk/memory-core-host-engine-storage` | Memory host storage engine | Memory host storage engine exports |
|
||||
| `plugin-sdk/memory-core-host-multimodal` | Memory host multimodal helpers | Memory host multimodal helpers |
|
||||
|
||||
@@ -264,7 +264,7 @@ explicitly promotes one as public.
|
||||
| `plugin-sdk/memory-core` | Bundled memory-core helper surface for manager/config/file/CLI helpers |
|
||||
| `plugin-sdk/memory-core-engine-runtime` | Memory index/search runtime facade |
|
||||
| `plugin-sdk/memory-core-host-engine-foundation` | Memory host foundation engine exports |
|
||||
| `plugin-sdk/memory-core-host-engine-embeddings` | Memory host embedding engine exports |
|
||||
| `plugin-sdk/memory-core-host-engine-embeddings` | Memory host embedding contracts, registry access, local provider, and generic batch/remote helpers |
|
||||
| `plugin-sdk/memory-core-host-engine-qmd` | Memory host QMD engine exports |
|
||||
| `plugin-sdk/memory-core-host-engine-storage` | Memory host storage engine exports |
|
||||
| `plugin-sdk/memory-core-host-multimodal` | Memory host multimodal helpers |
|
||||
|
||||
@@ -119,7 +119,7 @@ describe("active-memory plugin", () => {
|
||||
runEmbeddedPiAgent.mockResolvedValue({
|
||||
payloads: [{ text: "- lemon pepper wings\n- blue cheese" }],
|
||||
});
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
@@ -425,7 +425,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
allowedChatTypes: ["direct", "group"],
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
const result = await hooks.before_prompt_build(
|
||||
{ prompt: "what wings should we order?", messages: [] },
|
||||
@@ -513,7 +513,7 @@ describe("active-memory plugin", () => {
|
||||
searchMode: "inherit",
|
||||
},
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{
|
||||
@@ -602,7 +602,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
queryMode: "message",
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{
|
||||
@@ -630,7 +630,7 @@ describe("active-memory plugin", () => {
|
||||
queryMode: "message",
|
||||
promptStyle: "preference-only",
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{
|
||||
@@ -675,7 +675,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
thinking: "medium",
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{
|
||||
@@ -701,7 +701,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
promptAppend: "Prefer stable long-term preferences over one-off events.",
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{
|
||||
@@ -730,7 +730,7 @@ describe("active-memory plugin", () => {
|
||||
promptOverride: "Custom memory prompt. Return NONE or one user fact.",
|
||||
promptAppend: "Extra custom instruction.",
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{
|
||||
@@ -802,7 +802,7 @@ describe("active-memory plugin", () => {
|
||||
api.pluginConfig = {
|
||||
agents: ["main"],
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{ prompt: "what wings should i order? temp transcript", messages: [] },
|
||||
@@ -828,7 +828,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
modelFallbackPolicy: "resolved-only",
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
const result = await hooks.before_prompt_build(
|
||||
{ prompt: "what wings should i order? no fallback", messages: [] },
|
||||
@@ -851,7 +851,7 @@ describe("active-memory plugin", () => {
|
||||
modelFallback: "google/gemini-3-flash",
|
||||
modelFallbackPolicy: "default-remote",
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{ prompt: "what wings should i order? custom fallback", messages: [] },
|
||||
@@ -878,7 +878,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
modelFallbackPolicy: "default-remote",
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
const result = await hooks.before_prompt_build(
|
||||
{ prompt: "what wings should i order? built-in fallback", messages: [] },
|
||||
@@ -1027,7 +1027,7 @@ describe("active-memory plugin", () => {
|
||||
timeoutMs: 250,
|
||||
logging: true,
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
let lastAbortSignal: AbortSignal | undefined;
|
||||
runEmbeddedPiAgent.mockImplementation(async (params: { abortSignal?: AbortSignal }) => {
|
||||
lastAbortSignal = params.abortSignal;
|
||||
@@ -1073,7 +1073,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
logging: true,
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{ prompt: "what wings should i order? session id cache", messages: [] },
|
||||
@@ -1107,7 +1107,7 @@ describe("active-memory plugin", () => {
|
||||
timeoutMs: 250,
|
||||
logging: true,
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
runEmbeddedPiAgent.mockImplementationOnce(async (params: { timeoutMs?: number }) => {
|
||||
await new Promise((resolve) => setTimeout(resolve, (params.timeoutMs ?? 0) + 25));
|
||||
return {
|
||||
@@ -1145,7 +1145,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
logging: true,
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{ prompt: "what wings should i order? log sanitization", messages: [] },
|
||||
@@ -1179,7 +1179,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
logging: true,
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
const hugeSession = `agent:main:${"x".repeat(500)}`;
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
@@ -1423,7 +1423,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
queryMode: "message",
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{
|
||||
@@ -1451,7 +1451,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
queryMode: "full",
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{
|
||||
@@ -1482,7 +1482,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
queryMode: "recent",
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{
|
||||
@@ -1536,7 +1536,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
queryMode: "recent",
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{
|
||||
@@ -1578,7 +1578,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
queryMode: "recent",
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{
|
||||
@@ -1611,7 +1611,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
queryMode: "recent",
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{
|
||||
@@ -1619,8 +1619,7 @@ describe("active-memory plugin", () => {
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content:
|
||||
"Active Memory: I really do want you to remember that I prefer aisle seats.",
|
||||
content: "Active Memory: I really do want you to remember that I prefer aisle seats.",
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
@@ -1674,7 +1673,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
maxSummaryChars: 40,
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
runEmbeddedPiAgent.mockResolvedValueOnce({
|
||||
payloads: [
|
||||
{
|
||||
@@ -1708,7 +1707,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
maxSummaryChars: 90,
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
{ prompt: "what wings should i order? prompt-count-check", messages: [] },
|
||||
@@ -1758,7 +1757,7 @@ describe("active-memory plugin", () => {
|
||||
transcriptDir: "active-memory-subagents",
|
||||
logging: true,
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
const mkdirSpy = vi.spyOn(fs, "mkdir").mockResolvedValue(undefined);
|
||||
const mkdtempSpy = vi.spyOn(fs, "mkdtemp");
|
||||
const rmSpy = vi.spyOn(fs, "rm").mockResolvedValue(undefined);
|
||||
@@ -1802,7 +1801,7 @@ describe("active-memory plugin", () => {
|
||||
transcriptDir: "C:/temp/escape",
|
||||
logging: true,
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
const mkdirSpy = vi.spyOn(fs, "mkdir").mockResolvedValue(undefined);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
@@ -1839,7 +1838,7 @@ describe("active-memory plugin", () => {
|
||||
transcriptDir: "active-memory-subagents",
|
||||
logging: true,
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
const mkdirSpy = vi.spyOn(fs, "mkdir").mockResolvedValue(undefined);
|
||||
|
||||
await hooks.before_prompt_build(
|
||||
@@ -1906,7 +1905,7 @@ describe("active-memory plugin", () => {
|
||||
agents: ["main"],
|
||||
logging: true,
|
||||
};
|
||||
await plugin.register(api as unknown as OpenClawPluginApi);
|
||||
plugin.register(api as unknown as OpenClawPluginApi);
|
||||
|
||||
for (let index = 0; index <= 1000; index += 1) {
|
||||
await hooks.before_prompt_build(
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import { normalizeLowercaseStringOrEmpty } from "../../shared/string-coerce.js";
|
||||
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
|
||||
import { debugEmbeddingsLog } from "./embeddings-debug.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.types.js";
|
||||
import {
|
||||
debugEmbeddingsLog,
|
||||
sanitizeAndNormalizeEmbedding,
|
||||
type MemoryEmbeddingProvider,
|
||||
type MemoryEmbeddingProviderCreateOptions,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import { normalizeLowercaseStringOrEmpty } from "openclaw/plugin-sdk/text-runtime";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types & constants
|
||||
@@ -254,8 +257,8 @@ function parseCohereBatch(family: Family, raw: string): number[][] {
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export async function createBedrockEmbeddingProvider(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<{ provider: EmbeddingProvider; client: BedrockEmbeddingClient }> {
|
||||
options: MemoryEmbeddingProviderCreateOptions,
|
||||
): Promise<{ provider: MemoryEmbeddingProvider; client: BedrockEmbeddingClient }> {
|
||||
const client = resolveBedrockEmbeddingClient(options);
|
||||
const { BedrockRuntimeClient, InvokeModelCommand } = await loadSdk();
|
||||
const sdk = new BedrockRuntimeClient({ region: client.region });
|
||||
@@ -333,7 +336,7 @@ export async function createBedrockEmbeddingProvider(
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export function resolveBedrockEmbeddingClient(
|
||||
options: EmbeddingProviderOptions,
|
||||
options: MemoryEmbeddingProviderCreateOptions,
|
||||
): BedrockEmbeddingClient {
|
||||
const model = normalizeBedrockEmbeddingModel(options.model);
|
||||
const spec = resolveSpec(model);
|
||||
37
extensions/amazon-bedrock/memory-embedding-adapter.ts
Normal file
37
extensions/amazon-bedrock/memory-embedding-adapter.ts
Normal file
@@ -0,0 +1,37 @@
|
||||
import {
|
||||
isMissingEmbeddingApiKeyError,
|
||||
type MemoryEmbeddingProviderAdapter,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import {
|
||||
createBedrockEmbeddingProvider,
|
||||
DEFAULT_BEDROCK_EMBEDDING_MODEL,
|
||||
} from "./embedding-provider.js";
|
||||
|
||||
export const bedrockMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: "bedrock",
|
||||
defaultModel: DEFAULT_BEDROCK_EMBEDDING_MODEL,
|
||||
transport: "remote",
|
||||
authProviderId: "amazon-bedrock",
|
||||
autoSelectPriority: 60,
|
||||
allowExplicitWhenConfiguredAuto: true,
|
||||
shouldContinueAutoSelection: isMissingEmbeddingApiKeyError,
|
||||
create: async (options) => {
|
||||
const { provider, client } = await createBedrockEmbeddingProvider({
|
||||
...options,
|
||||
provider: "bedrock",
|
||||
fallback: "none",
|
||||
});
|
||||
return {
|
||||
provider,
|
||||
runtime: {
|
||||
id: "bedrock",
|
||||
cacheKeyData: {
|
||||
provider: "bedrock",
|
||||
region: client.region,
|
||||
model: client.model,
|
||||
dimensions: client.dimensions,
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
@@ -2,6 +2,9 @@
|
||||
"id": "amazon-bedrock",
|
||||
"enabledByDefault": true,
|
||||
"providers": ["amazon-bedrock"],
|
||||
"contracts": {
|
||||
"memoryEmbeddingProviders": ["bedrock"]
|
||||
},
|
||||
"configSchema": {
|
||||
"type": "object",
|
||||
"additionalProperties": false,
|
||||
|
||||
@@ -5,7 +5,9 @@
|
||||
"description": "OpenClaw Amazon Bedrock provider plugin",
|
||||
"type": "module",
|
||||
"dependencies": {
|
||||
"@aws-sdk/client-bedrock": "3.1028.0"
|
||||
"@aws-sdk/client-bedrock": "3.1028.0",
|
||||
"@aws-sdk/client-bedrock-runtime": "3.1028.0",
|
||||
"@aws-sdk/credential-provider-node": "3.972.30"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@openclaw/plugin-sdk": "workspace:*"
|
||||
|
||||
@@ -14,6 +14,7 @@ import {
|
||||
resolveBedrockConfigApiKey,
|
||||
resolveImplicitBedrockProvider,
|
||||
} from "./api.js";
|
||||
import { bedrockMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js";
|
||||
|
||||
type GuardrailConfig = {
|
||||
guardrailIdentifier: string;
|
||||
@@ -78,6 +79,8 @@ export function registerAmazonBedrockPlugin(api: OpenClawPluginApi): void {
|
||||
const pluginConfig = (api.pluginConfig ?? {}) as AmazonBedrockPluginConfig;
|
||||
const guardrail = pluginConfig.guardrail;
|
||||
|
||||
api.registerMemoryEmbeddingProvider(bedrockMemoryEmbeddingProviderAdapter);
|
||||
|
||||
const baseWrapStreamFn = ({ modelId, streamFn }: { modelId: string; streamFn?: StreamFn }) =>
|
||||
isAnthropicBedrockModel(modelId) ? streamFn : createBedrockNoCacheWrapper(streamFn);
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ describeLive("comfy live", () => {
|
||||
beforeAll(async () => {
|
||||
cfg = withPluginsEnabled(loadConfig());
|
||||
agentDir = resolveOpenClawAgentDir();
|
||||
await plugin.register(
|
||||
plugin.register(
|
||||
createTestPluginApi({
|
||||
config: cfg as never,
|
||||
registerImageGenerationProvider(provider) {
|
||||
|
||||
@@ -92,7 +92,7 @@ function registerPairCommand(params?: {
|
||||
pluginConfig?: Record<string, unknown>;
|
||||
}): OpenClawPluginCommandDefinition {
|
||||
let command: OpenClawPluginCommandDefinition | undefined;
|
||||
void registerDevicePair.register(
|
||||
registerDevicePair.register(
|
||||
createApi({
|
||||
...params,
|
||||
registerCommand: (nextCommand) => {
|
||||
|
||||
@@ -4,7 +4,6 @@ const resolveFirstGithubTokenMock = vi.hoisted(() => vi.fn());
|
||||
const resolveCopilotApiTokenMock = vi.hoisted(() => vi.fn());
|
||||
const resolveConfiguredSecretInputStringMock = vi.hoisted(() => vi.fn());
|
||||
const fetchWithSsrFGuardMock = vi.hoisted(() => vi.fn());
|
||||
const createGitHubCopilotEmbeddingProviderMock = vi.hoisted(() => vi.fn());
|
||||
|
||||
vi.mock("./auth.js", () => ({
|
||||
resolveFirstGithubToken: resolveFirstGithubTokenMock,
|
||||
@@ -19,10 +18,6 @@ vi.mock("openclaw/plugin-sdk/github-copilot-token", () => ({
|
||||
resolveCopilotApiToken: resolveCopilotApiTokenMock,
|
||||
}));
|
||||
|
||||
vi.mock("openclaw/plugin-sdk/memory-core-host-engine-embeddings", () => ({
|
||||
createGitHubCopilotEmbeddingProvider: createGitHubCopilotEmbeddingProviderMock,
|
||||
}));
|
||||
|
||||
vi.mock("openclaw/plugin-sdk/ssrf-runtime", () => ({
|
||||
fetchWithSsrFGuard: fetchWithSsrFGuardMock,
|
||||
}));
|
||||
@@ -73,15 +68,6 @@ describe("githubCopilotMemoryEmbeddingProviderAdapter", () => {
|
||||
source: "test",
|
||||
baseUrl: TEST_BASE_URL,
|
||||
});
|
||||
createGitHubCopilotEmbeddingProviderMock.mockImplementation(async (client) => ({
|
||||
provider: {
|
||||
id: "github-copilot",
|
||||
model: client.model,
|
||||
embedQuery: async () => [0.1, 0.2, 0.3],
|
||||
embedBatch: async (texts: string[]) => texts.map(() => [0.1, 0.2, 0.3]),
|
||||
},
|
||||
client,
|
||||
}));
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -89,7 +75,6 @@ describe("githubCopilotMemoryEmbeddingProviderAdapter", () => {
|
||||
resolveConfiguredSecretInputStringMock.mockReset();
|
||||
resolveFirstGithubTokenMock.mockReset();
|
||||
resolveCopilotApiTokenMock.mockReset();
|
||||
createGitHubCopilotEmbeddingProviderMock.mockReset();
|
||||
fetchWithSsrFGuardMock.mockReset();
|
||||
});
|
||||
|
||||
@@ -113,12 +98,8 @@ describe("githubCopilotMemoryEmbeddingProviderAdapter", () => {
|
||||
const result = await githubCopilotMemoryEmbeddingProviderAdapter.create(defaultCreateOptions());
|
||||
|
||||
expect(result.provider?.model).toBe("text-embedding-3-small");
|
||||
expect(createGitHubCopilotEmbeddingProviderMock).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
baseUrl: TEST_BASE_URL,
|
||||
githubToken: "gh_test_token_123",
|
||||
model: "text-embedding-3-small",
|
||||
}),
|
||||
expect(resolveCopilotApiTokenMock).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ githubToken: "gh_test_token_123" }),
|
||||
);
|
||||
});
|
||||
|
||||
@@ -217,14 +198,12 @@ describe("githubCopilotMemoryEmbeddingProviderAdapter", () => {
|
||||
} as never);
|
||||
|
||||
expect(resolveFirstGithubTokenMock).toHaveBeenCalled();
|
||||
expect(createGitHubCopilotEmbeddingProviderMock).toHaveBeenCalledWith({
|
||||
baseUrl: "https://proxy.example/v1",
|
||||
env: process.env,
|
||||
fetchImpl: fetch,
|
||||
githubToken: "gh_remote_token",
|
||||
headers: { "X-Proxy-Token": "proxy" },
|
||||
model: "text-embedding-3-small",
|
||||
});
|
||||
expect(resolveCopilotApiTokenMock).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
env: process.env,
|
||||
githubToken: "gh_remote_token",
|
||||
}),
|
||||
);
|
||||
|
||||
const discoveryCall = fetchWithSsrFGuardMock.mock.calls[0]?.[0] as {
|
||||
init: { headers: Record<string, string> };
|
||||
|
||||
@@ -4,7 +4,10 @@ import {
|
||||
resolveCopilotApiToken,
|
||||
} from "openclaw/plugin-sdk/github-copilot-token";
|
||||
import {
|
||||
createGitHubCopilotEmbeddingProvider,
|
||||
buildRemoteBaseUrlPolicy,
|
||||
sanitizeAndNormalizeEmbedding,
|
||||
withRemoteHttpResponse,
|
||||
type MemoryEmbeddingProvider,
|
||||
type MemoryEmbeddingProviderAdapter,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import { fetchWithSsrFGuard, type SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime";
|
||||
@@ -44,6 +47,15 @@ type CopilotModelEntry = {
|
||||
supported_endpoints?: unknown;
|
||||
};
|
||||
|
||||
type GitHubCopilotEmbeddingClient = {
|
||||
githubToken: string;
|
||||
model: string;
|
||||
baseUrl?: string;
|
||||
headers?: Record<string, string>;
|
||||
env?: NodeJS.ProcessEnv;
|
||||
fetchImpl?: typeof fetch;
|
||||
};
|
||||
|
||||
function isCopilotSetupError(err: unknown): boolean {
|
||||
if (!(err instanceof Error)) {
|
||||
return false;
|
||||
@@ -147,9 +159,126 @@ function pickBestModel(available: string[], userModel?: string): string {
|
||||
throw new Error("No embedding models available from GitHub Copilot");
|
||||
}
|
||||
|
||||
function parseGitHubCopilotEmbeddingPayload(payload: unknown, expectedCount: number): number[][] {
|
||||
if (!payload || typeof payload !== "object") {
|
||||
throw new Error("GitHub Copilot embeddings response missing data[]");
|
||||
}
|
||||
const data = (payload as { data?: unknown }).data;
|
||||
if (!Array.isArray(data)) {
|
||||
throw new Error("GitHub Copilot embeddings response missing data[]");
|
||||
}
|
||||
|
||||
const vectors = Array.from<number[] | undefined>({ length: expectedCount });
|
||||
for (const entry of data) {
|
||||
if (!entry || typeof entry !== "object") {
|
||||
throw new Error("GitHub Copilot embeddings response contains an invalid entry");
|
||||
}
|
||||
const indexValue = (entry as { index?: unknown }).index;
|
||||
const embedding = (entry as { embedding?: unknown }).embedding;
|
||||
const index = typeof indexValue === "number" ? indexValue : Number.NaN;
|
||||
if (!Number.isInteger(index)) {
|
||||
throw new Error("GitHub Copilot embeddings response contains an invalid index");
|
||||
}
|
||||
if (index < 0 || index >= expectedCount) {
|
||||
throw new Error("GitHub Copilot embeddings response contains an out-of-range index");
|
||||
}
|
||||
if (vectors[index] !== undefined) {
|
||||
throw new Error("GitHub Copilot embeddings response contains duplicate indexes");
|
||||
}
|
||||
if (!Array.isArray(embedding) || !embedding.every((value) => typeof value === "number")) {
|
||||
throw new Error("GitHub Copilot embeddings response contains an invalid embedding");
|
||||
}
|
||||
vectors[index] = sanitizeAndNormalizeEmbedding(embedding);
|
||||
}
|
||||
|
||||
for (let index = 0; index < expectedCount; index += 1) {
|
||||
if (vectors[index] === undefined) {
|
||||
throw new Error("GitHub Copilot embeddings response missing vectors for some inputs");
|
||||
}
|
||||
}
|
||||
return vectors as number[][];
|
||||
}
|
||||
|
||||
async function resolveGitHubCopilotEmbeddingSession(client: GitHubCopilotEmbeddingClient): Promise<{
|
||||
baseUrl: string;
|
||||
headers: Record<string, string>;
|
||||
}> {
|
||||
const token = await resolveCopilotApiToken({
|
||||
githubToken: client.githubToken,
|
||||
env: client.env,
|
||||
fetchImpl: client.fetchImpl,
|
||||
});
|
||||
const baseUrl = client.baseUrl?.trim() || token.baseUrl || DEFAULT_COPILOT_API_BASE_URL;
|
||||
return {
|
||||
baseUrl,
|
||||
headers: {
|
||||
...COPILOT_HEADERS_STATIC,
|
||||
...client.headers,
|
||||
Authorization: `Bearer ${token.token}`,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
async function createGitHubCopilotEmbeddingProvider(
|
||||
client: GitHubCopilotEmbeddingClient,
|
||||
): Promise<{ provider: MemoryEmbeddingProvider; client: GitHubCopilotEmbeddingClient }> {
|
||||
const initialSession = await resolveGitHubCopilotEmbeddingSession(client);
|
||||
|
||||
const embed = async (input: string[]): Promise<number[][]> => {
|
||||
if (input.length === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const session = await resolveGitHubCopilotEmbeddingSession(client);
|
||||
const url = `${session.baseUrl.replace(/\/$/, "")}/embeddings`;
|
||||
return await withRemoteHttpResponse({
|
||||
url,
|
||||
fetchImpl: client.fetchImpl,
|
||||
ssrfPolicy: buildRemoteBaseUrlPolicy(session.baseUrl),
|
||||
init: {
|
||||
method: "POST",
|
||||
headers: session.headers,
|
||||
body: JSON.stringify({ model: client.model, input }),
|
||||
},
|
||||
onResponse: async (response) => {
|
||||
if (!response.ok) {
|
||||
throw new Error(
|
||||
`GitHub Copilot embeddings HTTP ${response.status}: ${await response.text()}`,
|
||||
);
|
||||
}
|
||||
|
||||
let payload: unknown;
|
||||
try {
|
||||
payload = await response.json();
|
||||
} catch {
|
||||
throw new Error("GitHub Copilot embeddings returned invalid JSON");
|
||||
}
|
||||
return parseGitHubCopilotEmbeddingPayload(payload, input.length);
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
return {
|
||||
provider: {
|
||||
id: COPILOT_EMBEDDING_PROVIDER_ID,
|
||||
model: client.model,
|
||||
embedQuery: async (text) => {
|
||||
const [vector] = await embed([text]);
|
||||
return vector ?? [];
|
||||
},
|
||||
embedBatch: embed,
|
||||
},
|
||||
client: {
|
||||
...client,
|
||||
baseUrl: initialSession.baseUrl,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
export const githubCopilotMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: COPILOT_EMBEDDING_PROVIDER_ID,
|
||||
transport: "remote",
|
||||
authProviderId: COPILOT_EMBEDDING_PROVIDER_ID,
|
||||
autoSelectPriority: 15,
|
||||
allowExplicitWhenConfiguredAuto: true,
|
||||
shouldContinueAutoSelection: (err: unknown) => isCopilotSetupError(err),
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import crypto from "node:crypto";
|
||||
import {
|
||||
buildEmbeddingBatchGroupOptions,
|
||||
runEmbeddingBatchGroups,
|
||||
type EmbeddingBatchExecutionParams,
|
||||
} from "./batch-runner.js";
|
||||
import { buildBatchHeaders, normalizeBatchBaseUrl } from "./batch-utils.js";
|
||||
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
|
||||
import { debugEmbeddingsLog } from "./embeddings-debug.js";
|
||||
import type { GeminiEmbeddingClient, GeminiTextEmbeddingRequest } from "./embeddings-gemini.js";
|
||||
import { hashText } from "./internal.js";
|
||||
import { withRemoteHttpResponse } from "./remote-http.js";
|
||||
buildBatchHeaders,
|
||||
debugEmbeddingsLog,
|
||||
normalizeBatchBaseUrl,
|
||||
sanitizeAndNormalizeEmbedding,
|
||||
withRemoteHttpResponse,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import type { GeminiEmbeddingClient, GeminiTextEmbeddingRequest } from "./embedding-provider.js";
|
||||
|
||||
export type GeminiBatchRequest = {
|
||||
custom_id: string;
|
||||
@@ -40,6 +41,10 @@ export type GeminiBatchOutputLine = {
|
||||
};
|
||||
|
||||
const GEMINI_BATCH_MAX_REQUESTS = 50000;
|
||||
function hashText(text: string): string {
|
||||
return crypto.createHash("sha256").update(text).digest("hex");
|
||||
}
|
||||
|
||||
function getGeminiUploadUrl(baseUrl: string): string {
|
||||
if (baseUrl.includes("/v1beta")) {
|
||||
return baseUrl.replace(/\/v1beta\/?$/, "/upload/v1beta");
|
||||
@@ -1,84 +1,40 @@
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import * as authModule from "../../agents/model-auth.js";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import {
|
||||
buildGeminiEmbeddingRequest,
|
||||
buildGeminiTextEmbeddingRequest,
|
||||
createGeminiEmbeddingProvider,
|
||||
DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
GEMINI_EMBEDDING_2_MODELS,
|
||||
isGeminiEmbedding2Model,
|
||||
normalizeGeminiModel,
|
||||
resolveGeminiOutputDimensionality,
|
||||
} from "./embeddings-gemini-request.js";
|
||||
import {
|
||||
createGeminiBatchFetchMock,
|
||||
createJsonResponseFetchMock,
|
||||
installFetchMock,
|
||||
mockResolvedProviderKey,
|
||||
parseFetchBody,
|
||||
readFirstFetchRequest,
|
||||
type JsonFetchMock,
|
||||
} from "./embeddings-provider.test-support.js";
|
||||
|
||||
const { resolveApiKeyForProviderMock } = vi.hoisted(() => ({
|
||||
resolveApiKeyForProviderMock: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("../../agents/model-auth.js", () => {
|
||||
return {
|
||||
resolveApiKeyForProvider: resolveApiKeyForProviderMock,
|
||||
requireApiKey: (auth: { apiKey?: string; mode?: string }, provider: string) => {
|
||||
if (auth.apiKey) {
|
||||
return auth.apiKey;
|
||||
}
|
||||
throw new Error(`No API key resolved for provider "${provider}" (auth mode: ${auth.mode}).`);
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock("../../agents/api-key-rotation.js", () => ({
|
||||
collectProviderApiKeysForExecution: (params: { primaryApiKey?: string }) =>
|
||||
params.primaryApiKey ? [params.primaryApiKey] : [],
|
||||
executeWithApiKeyRotation: async <T>(params: {
|
||||
apiKeys: string[];
|
||||
execute: (apiKey: string) => Promise<T>;
|
||||
}) => {
|
||||
const apiKey = params.apiKeys[0];
|
||||
if (!apiKey) {
|
||||
throw new Error('No API keys configured for provider "google".');
|
||||
}
|
||||
return await params.execute(apiKey);
|
||||
},
|
||||
}));
|
||||
|
||||
beforeEach(() => {
|
||||
vi.useRealTimers();
|
||||
vi.doUnmock("undici");
|
||||
});
|
||||
} from "./embedding-provider.js";
|
||||
|
||||
afterEach(() => {
|
||||
vi.doUnmock("undici");
|
||||
vi.resetAllMocks();
|
||||
vi.restoreAllMocks();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
type GeminiProviderOptions = Parameters<
|
||||
typeof import("./embeddings-gemini.js").createGeminiEmbeddingProvider
|
||||
>[0];
|
||||
|
||||
async function createProviderWithFetch(
|
||||
fetchMock: JsonFetchMock,
|
||||
options: Partial<GeminiProviderOptions> & { model: string },
|
||||
) {
|
||||
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
|
||||
mockResolvedProviderKey(authModule.resolveApiKeyForProvider);
|
||||
const { createGeminiEmbeddingProvider } = await import("./embeddings-gemini.js");
|
||||
const { provider } = await createGeminiEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "gemini",
|
||||
fallback: "none",
|
||||
...options,
|
||||
function installFetchMock(
|
||||
handler: (input: RequestInfo | URL, init?: RequestInit) => unknown,
|
||||
): ReturnType<typeof vi.fn> {
|
||||
const fetchMock = vi.fn(async (input: RequestInfo | URL, init?: RequestInit) => {
|
||||
return new Response(JSON.stringify(handler(input, init)), {
|
||||
status: 200,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
});
|
||||
});
|
||||
return provider;
|
||||
vi.stubGlobal("fetch", fetchMock);
|
||||
return fetchMock;
|
||||
}
|
||||
|
||||
function fetchJsonBody(fetchMock: ReturnType<typeof vi.fn>, index: number): unknown {
|
||||
const init = fetchMock.mock.calls[index]?.[1] as RequestInit | undefined;
|
||||
const body = init?.body;
|
||||
if (typeof body !== "string") {
|
||||
throw new Error("Expected JSON string request body.");
|
||||
}
|
||||
return JSON.parse(body) as unknown;
|
||||
}
|
||||
|
||||
describe("Gemini embedding request helpers", () => {
|
||||
@@ -149,24 +105,9 @@ describe("Gemini embedding request helpers", () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe("gemini embedding provider", () => {
|
||||
describe("Gemini embedding provider", () => {
|
||||
it("handles legacy and v2 request/response behavior", async () => {
|
||||
const legacyFetch = createGeminiBatchFetchMock(2);
|
||||
const legacyProvider = await createProviderWithFetch(legacyFetch, {
|
||||
model: "gemini-embedding-001",
|
||||
});
|
||||
|
||||
await legacyProvider.embedQuery("test query");
|
||||
await legacyProvider.embedBatch(["text1", "text2"]);
|
||||
|
||||
expect(parseFetchBody(legacyFetch, 0)).toMatchObject({
|
||||
taskType: "RETRIEVAL_QUERY",
|
||||
content: { parts: [{ text: "test query" }] },
|
||||
});
|
||||
expect(parseFetchBody(legacyFetch, 0)).not.toHaveProperty("outputDimensionality");
|
||||
expect(parseFetchBody(legacyFetch, 1)).not.toHaveProperty("outputDimensionality");
|
||||
|
||||
const v2Fetch = createJsonResponseFetchMock((input) => {
|
||||
const fetchMock = installFetchMock((input) => {
|
||||
const url = input instanceof URL ? input.href : typeof input === "string" ? input : input.url;
|
||||
return url.endsWith(":batchEmbedContents")
|
||||
? {
|
||||
@@ -176,16 +117,22 @@ describe("gemini embedding provider", () => {
|
||||
}
|
||||
: { embedding: { values: [3, 4, Number.NaN] } };
|
||||
});
|
||||
const v2Provider = await createProviderWithFetch(v2Fetch, {
|
||||
|
||||
const { provider } = await createGeminiEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "gemini",
|
||||
remote: { apiKey: "test-key" },
|
||||
model: "gemini-embedding-2-preview",
|
||||
outputDimensionality: 768,
|
||||
taskType: "SEMANTIC_SIMILARITY",
|
||||
fallback: "none",
|
||||
});
|
||||
await expect(v2Provider.embedQuery(" ")).resolves.toEqual([]);
|
||||
await expect(v2Provider.embedBatch([])).resolves.toEqual([]);
|
||||
await expect(v2Provider.embedQuery("test query")).resolves.toEqual([0.6, 0.8, 0]);
|
||||
|
||||
const structuredBatch = await v2Provider.embedBatchInputs?.([
|
||||
await expect(provider.embedQuery(" ")).resolves.toEqual([]);
|
||||
await expect(provider.embedBatch([])).resolves.toEqual([]);
|
||||
await expect(provider.embedQuery("test query")).resolves.toEqual([0.6, 0.8, 0]);
|
||||
|
||||
const structuredBatch = await provider.embedBatchInputs?.([
|
||||
{
|
||||
text: "Image file: diagram.png",
|
||||
parts: [
|
||||
@@ -206,38 +153,39 @@ describe("gemini embedding provider", () => {
|
||||
[0, 0, 1],
|
||||
]);
|
||||
|
||||
const { url } = readFirstFetchRequest(v2Fetch);
|
||||
expect(url).toBe(
|
||||
expect(fetchMock.mock.calls[0]?.[0]).toBe(
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/gemini-embedding-2-preview:embedContent",
|
||||
);
|
||||
expect(parseFetchBody(v2Fetch, 0)).toMatchObject({
|
||||
expect(fetchJsonBody(fetchMock, 0)).toMatchObject({
|
||||
outputDimensionality: 768,
|
||||
taskType: "SEMANTIC_SIMILARITY",
|
||||
content: { parts: [{ text: "test query" }] },
|
||||
});
|
||||
expect(parseFetchBody(v2Fetch, 1).requests).toEqual([
|
||||
{
|
||||
model: "models/gemini-embedding-2-preview",
|
||||
content: {
|
||||
parts: [
|
||||
{ text: "Image file: diagram.png" },
|
||||
{ inlineData: { mimeType: "image/png", data: "img" } },
|
||||
],
|
||||
expect(fetchJsonBody(fetchMock, 1)).toMatchObject({
|
||||
requests: [
|
||||
{
|
||||
model: "models/gemini-embedding-2-preview",
|
||||
content: {
|
||||
parts: [
|
||||
{ text: "Image file: diagram.png" },
|
||||
{ inlineData: { mimeType: "image/png", data: "img" } },
|
||||
],
|
||||
},
|
||||
taskType: "SEMANTIC_SIMILARITY",
|
||||
outputDimensionality: 768,
|
||||
},
|
||||
taskType: "SEMANTIC_SIMILARITY",
|
||||
outputDimensionality: 768,
|
||||
},
|
||||
{
|
||||
model: "models/gemini-embedding-2-preview",
|
||||
content: {
|
||||
parts: [
|
||||
{ text: "Audio file: note.wav" },
|
||||
{ inlineData: { mimeType: "audio/wav", data: "aud" } },
|
||||
],
|
||||
{
|
||||
model: "models/gemini-embedding-2-preview",
|
||||
content: {
|
||||
parts: [
|
||||
{ text: "Audio file: note.wav" },
|
||||
{ inlineData: { mimeType: "audio/wav", data: "aud" } },
|
||||
],
|
||||
},
|
||||
taskType: "SEMANTIC_SIMILARITY",
|
||||
outputDimensionality: 768,
|
||||
},
|
||||
taskType: "SEMANTIC_SIMILARITY",
|
||||
outputDimensionality: 768,
|
||||
},
|
||||
]);
|
||||
],
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,44 +1,22 @@
|
||||
import { parseGeminiAuth } from "openclaw/plugin-sdk/image-generation-core";
|
||||
import {
|
||||
buildRemoteBaseUrlPolicy,
|
||||
debugEmbeddingsLog,
|
||||
sanitizeAndNormalizeEmbedding,
|
||||
withRemoteHttpResponse,
|
||||
type EmbeddingInput,
|
||||
type MemoryEmbeddingProvider,
|
||||
type MemoryEmbeddingProviderCreateOptions,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import { resolveMemorySecretInputString } from "openclaw/plugin-sdk/memory-core-host-secret";
|
||||
import {
|
||||
collectProviderApiKeysForExecution,
|
||||
executeWithApiKeyRotation,
|
||||
} from "../../agents/api-key-rotation.js";
|
||||
import { requireApiKey, resolveApiKeyForProvider } from "../../agents/model-auth.js";
|
||||
import { parseGeminiAuth } from "../../infra/gemini-auth.js";
|
||||
import {
|
||||
DEFAULT_GOOGLE_API_BASE_URL,
|
||||
normalizeGoogleApiBaseUrl,
|
||||
} from "../../infra/google-api-base-url.js";
|
||||
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
|
||||
import { normalizeOptionalString } from "../../shared/string-coerce.js";
|
||||
import type { EmbeddingInput } from "./embedding-inputs.js";
|
||||
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
|
||||
import { debugEmbeddingsLog } from "./embeddings-debug.js";
|
||||
import {
|
||||
buildGeminiEmbeddingRequest,
|
||||
buildGeminiTextEmbeddingRequest,
|
||||
isGeminiEmbedding2Model,
|
||||
normalizeGeminiModel,
|
||||
resolveGeminiOutputDimensionality,
|
||||
} from "./embeddings-gemini-request.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.types.js";
|
||||
import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js";
|
||||
import { resolveMemorySecretInputString } from "./secret-input.js";
|
||||
|
||||
export {
|
||||
buildGeminiEmbeddingRequest,
|
||||
buildGeminiTextEmbeddingRequest,
|
||||
DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
GEMINI_EMBEDDING_2_MODELS,
|
||||
isGeminiEmbedding2Model,
|
||||
normalizeGeminiModel,
|
||||
resolveGeminiOutputDimensionality,
|
||||
type GeminiEmbeddingRequest,
|
||||
type GeminiInlinePart,
|
||||
type GeminiPart,
|
||||
type GeminiTaskType,
|
||||
type GeminiTextEmbeddingRequest,
|
||||
type GeminiTextPart,
|
||||
} from "./embeddings-gemini-request.js";
|
||||
requireApiKey,
|
||||
resolveApiKeyForProvider,
|
||||
} from "openclaw/plugin-sdk/provider-auth-runtime";
|
||||
import type { SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime";
|
||||
import { normalizeOptionalString } from "openclaw/plugin-sdk/text-runtime";
|
||||
|
||||
export type GeminiEmbeddingClient = {
|
||||
baseUrl: string;
|
||||
@@ -50,9 +28,111 @@ export type GeminiEmbeddingClient = {
|
||||
outputDimensionality?: number;
|
||||
};
|
||||
|
||||
export const DEFAULT_GEMINI_EMBEDDING_MODEL = "gemini-embedding-001";
|
||||
const DEFAULT_GOOGLE_API_BASE_URL = "https://generativelanguage.googleapis.com/v1beta";
|
||||
const GEMINI_MAX_INPUT_TOKENS: Record<string, number> = {
|
||||
"text-embedding-004": 2048,
|
||||
"gemini-embedding-001": 2048,
|
||||
"gemini-embedding-2-preview": 8192,
|
||||
};
|
||||
|
||||
export type GeminiTaskType = NonNullable<MemoryEmbeddingProviderCreateOptions["taskType"]>;
|
||||
|
||||
// --- gemini-embedding-2-preview support ---
|
||||
|
||||
export const GEMINI_EMBEDDING_2_MODELS = new Set([
|
||||
"gemini-embedding-2-preview",
|
||||
// Add the GA model name here once released.
|
||||
]);
|
||||
|
||||
const GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS = 3072;
|
||||
const GEMINI_EMBEDDING_2_VALID_DIMENSIONS = [768, 1536, 3072] as const;
|
||||
|
||||
export type GeminiTextPart = { text: string };
|
||||
export type GeminiInlinePart = {
|
||||
inlineData: { mimeType: string; data: string };
|
||||
};
|
||||
export type GeminiPart = GeminiTextPart | GeminiInlinePart;
|
||||
export type GeminiEmbeddingRequest = {
|
||||
content: { parts: GeminiPart[] };
|
||||
taskType: GeminiTaskType;
|
||||
outputDimensionality?: number;
|
||||
model?: string;
|
||||
};
|
||||
export type GeminiTextEmbeddingRequest = GeminiEmbeddingRequest;
|
||||
|
||||
/** Builds the text-only Gemini embedding request shape used across direct and batch APIs. */
|
||||
export function buildGeminiTextEmbeddingRequest(params: {
|
||||
text: string;
|
||||
taskType: GeminiTaskType;
|
||||
outputDimensionality?: number;
|
||||
modelPath?: string;
|
||||
}): GeminiTextEmbeddingRequest {
|
||||
return buildGeminiEmbeddingRequest({
|
||||
input: { text: params.text },
|
||||
taskType: params.taskType,
|
||||
outputDimensionality: params.outputDimensionality,
|
||||
modelPath: params.modelPath,
|
||||
});
|
||||
}
|
||||
|
||||
export function buildGeminiEmbeddingRequest(params: {
|
||||
input: EmbeddingInput;
|
||||
taskType: GeminiTaskType;
|
||||
outputDimensionality?: number;
|
||||
modelPath?: string;
|
||||
}): GeminiEmbeddingRequest {
|
||||
const request: GeminiEmbeddingRequest = {
|
||||
content: {
|
||||
parts: params.input.parts?.map((part) =>
|
||||
part.type === "text"
|
||||
? ({ text: part.text } satisfies GeminiTextPart)
|
||||
: ({
|
||||
inlineData: { mimeType: part.mimeType, data: part.data },
|
||||
} satisfies GeminiInlinePart),
|
||||
) ?? [{ text: params.input.text }],
|
||||
},
|
||||
taskType: params.taskType,
|
||||
};
|
||||
if (params.modelPath) {
|
||||
request.model = params.modelPath;
|
||||
}
|
||||
if (params.outputDimensionality != null) {
|
||||
request.outputDimensionality = params.outputDimensionality;
|
||||
}
|
||||
return request;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if the given model name is a gemini-embedding-2 variant that
|
||||
* supports `outputDimensionality` and extended task types.
|
||||
*/
|
||||
export function isGeminiEmbedding2Model(model: string): boolean {
|
||||
return GEMINI_EMBEDDING_2_MODELS.has(model);
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate and return the `outputDimensionality` for gemini-embedding-2 models.
|
||||
* Returns `undefined` for older models (they don't support the param).
|
||||
*/
|
||||
export function resolveGeminiOutputDimensionality(
|
||||
model: string,
|
||||
requested?: number,
|
||||
): number | undefined {
|
||||
if (!isGeminiEmbedding2Model(model)) {
|
||||
return undefined;
|
||||
}
|
||||
if (requested == null) {
|
||||
return GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS;
|
||||
}
|
||||
const valid: readonly number[] = GEMINI_EMBEDDING_2_VALID_DIMENSIONS;
|
||||
if (!valid.includes(requested)) {
|
||||
throw new Error(
|
||||
`Invalid outputDimensionality ${requested} for ${model}. Valid values: ${valid.join(", ")}`,
|
||||
);
|
||||
}
|
||||
return requested;
|
||||
}
|
||||
function resolveRemoteApiKey(remoteApiKey: unknown): string | undefined {
|
||||
const trimmed = resolveMemorySecretInputString({
|
||||
value: remoteApiKey,
|
||||
@@ -67,6 +147,21 @@ function resolveRemoteApiKey(remoteApiKey: unknown): string | undefined {
|
||||
return trimmed;
|
||||
}
|
||||
|
||||
export function normalizeGeminiModel(model: string): string {
|
||||
const trimmed = model.trim();
|
||||
if (!trimmed) {
|
||||
return DEFAULT_GEMINI_EMBEDDING_MODEL;
|
||||
}
|
||||
const withoutPrefix = trimmed.replace(/^models\//, "");
|
||||
if (withoutPrefix.startsWith("gemini/")) {
|
||||
return withoutPrefix.slice("gemini/".length);
|
||||
}
|
||||
if (withoutPrefix.startsWith("google/")) {
|
||||
return withoutPrefix.slice("google/".length);
|
||||
}
|
||||
return withoutPrefix;
|
||||
}
|
||||
|
||||
async function fetchGeminiEmbeddingPayload(params: {
|
||||
client: GeminiEmbeddingClient;
|
||||
endpoint: string;
|
||||
@@ -120,9 +215,30 @@ function buildGeminiModelPath(model: string): string {
|
||||
return model.startsWith("models/") ? model : `models/${model}`;
|
||||
}
|
||||
|
||||
function normalizeGoogleApiBaseUrl(baseUrl: string): string {
|
||||
const trimmed = baseUrl.trim().replace(/\/+$/, "");
|
||||
if (!trimmed) {
|
||||
return DEFAULT_GOOGLE_API_BASE_URL;
|
||||
}
|
||||
try {
|
||||
const url = new URL(trimmed);
|
||||
url.hash = "";
|
||||
url.search = "";
|
||||
if (
|
||||
url.origin.toLowerCase() === "https://generativelanguage.googleapis.com" &&
|
||||
url.pathname.replace(/\/+$/, "") === ""
|
||||
) {
|
||||
url.pathname = "/v1beta";
|
||||
}
|
||||
return url.toString().replace(/\/+$/, "");
|
||||
} catch {
|
||||
return trimmed;
|
||||
}
|
||||
}
|
||||
|
||||
export async function createGeminiEmbeddingProvider(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<{ provider: EmbeddingProvider; client: GeminiEmbeddingClient }> {
|
||||
options: MemoryEmbeddingProviderCreateOptions,
|
||||
): Promise<{ provider: MemoryEmbeddingProvider; client: GeminiEmbeddingClient }> {
|
||||
const client = await resolveGeminiEmbeddingClient(options);
|
||||
const baseUrl = client.baseUrl.replace(/\/$/, "");
|
||||
const embedUrl = `${baseUrl}/${client.modelPath}:embedContent`;
|
||||
@@ -190,7 +306,7 @@ export async function createGeminiEmbeddingProvider(
|
||||
}
|
||||
|
||||
export async function resolveGeminiEmbeddingClient(
|
||||
options: EmbeddingProviderOptions,
|
||||
options: MemoryEmbeddingProviderCreateOptions,
|
||||
): Promise<GeminiEmbeddingClient> {
|
||||
const remote = options.remote;
|
||||
const remoteApiKey = resolveRemoteApiKey(remote?.apiKey);
|
||||
@@ -3,6 +3,7 @@ import type { MediaUnderstandingProvider } from "openclaw/plugin-sdk/media-under
|
||||
import { definePluginEntry } from "openclaw/plugin-sdk/plugin-entry";
|
||||
import { buildGoogleGeminiCliBackend } from "./cli-backend.js";
|
||||
import { registerGoogleGeminiCliProvider } from "./gemini-cli-provider.js";
|
||||
import { geminiMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js";
|
||||
import { buildGoogleMusicGenerationProvider } from "./music-generation-provider.js";
|
||||
import { registerGoogleProvider } from "./provider-registration.js";
|
||||
import { buildGoogleSpeechProvider } from "./speech-provider.js";
|
||||
@@ -111,6 +112,7 @@ export default definePluginEntry({
|
||||
api.registerCliBackend(buildGoogleGeminiCliBackend());
|
||||
registerGoogleGeminiCliProvider(api);
|
||||
registerGoogleProvider(api);
|
||||
api.registerMemoryEmbeddingProvider(geminiMemoryEmbeddingProviderAdapter);
|
||||
api.registerImageGenerationProvider(createLazyGoogleImageGenerationProvider());
|
||||
api.registerMediaUnderstandingProvider(createLazyGoogleMediaUnderstandingProvider());
|
||||
api.registerMusicGenerationProvider(buildGoogleMusicGenerationProvider());
|
||||
|
||||
79
extensions/google/memory-embedding-adapter.ts
Normal file
79
extensions/google/memory-embedding-adapter.ts
Normal file
@@ -0,0 +1,79 @@
|
||||
import {
|
||||
hasNonTextEmbeddingParts,
|
||||
isMissingEmbeddingApiKeyError,
|
||||
mapBatchEmbeddingsByIndex,
|
||||
sanitizeEmbeddingCacheHeaders,
|
||||
type MemoryEmbeddingProviderAdapter,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import { runGeminiEmbeddingBatches } from "./embedding-batch.js";
|
||||
import {
|
||||
buildGeminiEmbeddingRequest,
|
||||
createGeminiEmbeddingProvider,
|
||||
DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
} from "./embedding-provider.js";
|
||||
|
||||
function supportsGeminiMultimodalEmbeddings(model: string): boolean {
|
||||
const normalized = model
|
||||
.trim()
|
||||
.replace(/^models\//, "")
|
||||
.replace(/^(gemini|google)\//, "");
|
||||
return normalized === "gemini-embedding-2-preview";
|
||||
}
|
||||
|
||||
export const geminiMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: "gemini",
|
||||
defaultModel: DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
transport: "remote",
|
||||
authProviderId: "google",
|
||||
autoSelectPriority: 30,
|
||||
allowExplicitWhenConfiguredAuto: true,
|
||||
supportsMultimodalEmbeddings: ({ model }) => supportsGeminiMultimodalEmbeddings(model),
|
||||
shouldContinueAutoSelection: isMissingEmbeddingApiKeyError,
|
||||
create: async (options) => {
|
||||
const { provider, client } = await createGeminiEmbeddingProvider({
|
||||
...options,
|
||||
provider: "gemini",
|
||||
fallback: "none",
|
||||
});
|
||||
return {
|
||||
provider,
|
||||
runtime: {
|
||||
id: "gemini",
|
||||
cacheKeyData: {
|
||||
provider: "gemini",
|
||||
baseUrl: client.baseUrl,
|
||||
model: client.model,
|
||||
outputDimensionality: client.outputDimensionality,
|
||||
headers: sanitizeEmbeddingCacheHeaders(client.headers, [
|
||||
"authorization",
|
||||
"x-goog-api-key",
|
||||
]),
|
||||
},
|
||||
batchEmbed: async (batch) => {
|
||||
if (batch.chunks.some((chunk) => hasNonTextEmbeddingParts(chunk.embeddingInput))) {
|
||||
return null;
|
||||
}
|
||||
const byCustomId = await runGeminiEmbeddingBatches({
|
||||
gemini: client,
|
||||
agentId: batch.agentId,
|
||||
requests: batch.chunks.map((chunk, index) => ({
|
||||
custom_id: String(index),
|
||||
request: buildGeminiEmbeddingRequest({
|
||||
input: chunk.embeddingInput ?? { text: chunk.text },
|
||||
taskType: "RETRIEVAL_DOCUMENT",
|
||||
modelPath: client.modelPath,
|
||||
outputDimensionality: client.outputDimensionality,
|
||||
}),
|
||||
})),
|
||||
wait: batch.wait,
|
||||
concurrency: batch.concurrency,
|
||||
pollIntervalMs: batch.pollIntervalMs,
|
||||
timeoutMs: batch.timeoutMs,
|
||||
debug: batch.debug,
|
||||
});
|
||||
return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length);
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
@@ -46,6 +46,7 @@
|
||||
},
|
||||
"contracts": {
|
||||
"mediaUnderstandingProviders": ["google"],
|
||||
"memoryEmbeddingProviders": ["gemini"],
|
||||
"imageGenerationProviders": ["google"],
|
||||
"musicGenerationProviders": ["google"],
|
||||
"speechProviders": ["google"],
|
||||
|
||||
@@ -8,6 +8,7 @@ import {
|
||||
type ProviderRuntimeModel,
|
||||
} from "openclaw/plugin-sdk/plugin-entry";
|
||||
import { CUSTOM_LOCAL_AUTH_MARKER } from "openclaw/plugin-sdk/provider-auth";
|
||||
import { lmstudioMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js";
|
||||
import {
|
||||
LMSTUDIO_DEFAULT_API_KEY_ENV_VAR,
|
||||
LMSTUDIO_LOCAL_API_KEY_PLACEHOLDER,
|
||||
@@ -52,6 +53,7 @@ export default definePluginEntry({
|
||||
name: "LM Studio Provider",
|
||||
description: "Bundled LM Studio provider plugin",
|
||||
register(api: OpenClawPluginApi) {
|
||||
api.registerMemoryEmbeddingProvider(lmstudioMemoryEmbeddingProviderAdapter);
|
||||
api.registerProvider({
|
||||
id: PROVIDER_ID,
|
||||
label: "LM Studio",
|
||||
|
||||
35
extensions/lmstudio/memory-embedding-adapter.ts
Normal file
35
extensions/lmstudio/memory-embedding-adapter.ts
Normal file
@@ -0,0 +1,35 @@
|
||||
import {
|
||||
sanitizeEmbeddingCacheHeaders,
|
||||
type MemoryEmbeddingProviderAdapter,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import {
|
||||
createLmstudioEmbeddingProvider,
|
||||
DEFAULT_LMSTUDIO_EMBEDDING_MODEL,
|
||||
} from "./src/embedding-provider.js";
|
||||
|
||||
export const lmstudioMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: "lmstudio",
|
||||
defaultModel: DEFAULT_LMSTUDIO_EMBEDDING_MODEL,
|
||||
transport: "remote",
|
||||
authProviderId: "lmstudio",
|
||||
allowExplicitWhenConfiguredAuto: true,
|
||||
create: async (options) => {
|
||||
const { provider, client } = await createLmstudioEmbeddingProvider({
|
||||
...options,
|
||||
provider: "lmstudio",
|
||||
fallback: "none",
|
||||
});
|
||||
return {
|
||||
provider,
|
||||
runtime: {
|
||||
id: "lmstudio",
|
||||
cacheKeyData: {
|
||||
provider: "lmstudio",
|
||||
baseUrl: client.baseUrl,
|
||||
model: client.model,
|
||||
headers: sanitizeEmbeddingCacheHeaders(client.headers, ["authorization"]),
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
@@ -21,6 +21,9 @@
|
||||
"groupHint": "Self-hosted open-weight models"
|
||||
}
|
||||
],
|
||||
"contracts": {
|
||||
"memoryEmbeddingProviders": ["lmstudio"]
|
||||
},
|
||||
"configSchema": {
|
||||
"type": "object",
|
||||
"additionalProperties": false,
|
||||
|
||||
@@ -1,20 +1,21 @@
|
||||
import { formatErrorMessage } from "../../infra/errors.js";
|
||||
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
|
||||
import { createSubsystemLogger } from "../../logging/subsystem.js";
|
||||
import { createSubsystemLogger } from "openclaw/plugin-sdk/logging-core";
|
||||
import {
|
||||
buildRemoteBaseUrlPolicy,
|
||||
createRemoteEmbeddingProvider,
|
||||
normalizeEmbeddingModelWithPrefixes,
|
||||
type MemoryEmbeddingProvider,
|
||||
type MemoryEmbeddingProviderCreateOptions,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import { resolveMemorySecretInputString } from "openclaw/plugin-sdk/memory-core-host-secret";
|
||||
import { formatErrorMessage, type SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime";
|
||||
import { LMSTUDIO_DEFAULT_EMBEDDING_MODEL, LMSTUDIO_PROVIDER_ID } from "./defaults.js";
|
||||
import { ensureLmstudioModelLoaded } from "./models.fetch.js";
|
||||
import { resolveLmstudioInferenceBase } from "./models.js";
|
||||
import {
|
||||
buildLmstudioAuthHeaders,
|
||||
ensureLmstudioModelLoaded,
|
||||
LMSTUDIO_DEFAULT_EMBEDDING_MODEL,
|
||||
LMSTUDIO_PROVIDER_ID,
|
||||
resolveLmstudioInferenceBase,
|
||||
resolveLmstudioProviderHeaders,
|
||||
resolveLmstudioRuntimeApiKey,
|
||||
} from "../../plugin-sdk/lmstudio-runtime.js";
|
||||
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
|
||||
import { createRemoteEmbeddingProvider } from "./embeddings-remote-provider.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.types.js";
|
||||
import { buildRemoteBaseUrlPolicy } from "./remote-http.js";
|
||||
import { resolveMemorySecretInputString } from "./secret-input.js";
|
||||
} from "./runtime.js";
|
||||
|
||||
const log = createSubsystemLogger("memory/embeddings");
|
||||
|
||||
@@ -47,7 +48,7 @@ function hasAuthorizationHeader(headers: Record<string, string> | undefined): bo
|
||||
|
||||
/** Resolves API key (real or synthetic placeholder) from runtime/provider auth config. */
|
||||
async function resolveLmstudioApiKey(
|
||||
options: EmbeddingProviderOptions,
|
||||
options: MemoryEmbeddingProviderCreateOptions,
|
||||
): Promise<string | undefined> {
|
||||
try {
|
||||
return await resolveLmstudioRuntimeApiKey({
|
||||
@@ -65,8 +66,8 @@ async function resolveLmstudioApiKey(
|
||||
|
||||
/** Creates the LM Studio embedding provider client and preloads the target model before return. */
|
||||
export async function createLmstudioEmbeddingProvider(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<{ provider: EmbeddingProvider; client: LmstudioEmbeddingClient }> {
|
||||
options: MemoryEmbeddingProviderCreateOptions,
|
||||
): Promise<{ provider: MemoryEmbeddingProvider; client: LmstudioEmbeddingClient }> {
|
||||
const providerConfig = options.config.models?.providers?.lmstudio;
|
||||
const providerBaseUrl = providerConfig?.baseUrl?.trim();
|
||||
const isFallbackActivation = options.fallback === "lmstudio" && options.provider !== "lmstudio";
|
||||
@@ -1,6 +1,6 @@
|
||||
import type { IncomingMessage, ServerResponse } from "node:http";
|
||||
import { PassThrough } from "node:stream";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import type { OpenClawConfig, RuntimeEnv } from "../../runtime-api.js";
|
||||
import type { ResolvedMattermostAccount } from "./accounts.js";
|
||||
import { createSlashCommandHttpHandler } from "./slash-http.js";
|
||||
@@ -133,25 +133,19 @@ describe("slash-http", () => {
|
||||
});
|
||||
|
||||
it("returns 408 when the request body stalls", async () => {
|
||||
vi.useFakeTimers();
|
||||
try {
|
||||
const handler = createSlashCommandHttpHandler({
|
||||
account: accountFixture,
|
||||
cfg: {} as OpenClawConfig,
|
||||
runtime: {} as RuntimeEnv,
|
||||
commandTokens: new Set(["valid-token"]),
|
||||
});
|
||||
const req = createRequest({ autoEnd: false });
|
||||
const response = createResponse();
|
||||
const pending = handler(req, response.res);
|
||||
const handler = createSlashCommandHttpHandler({
|
||||
account: accountFixture,
|
||||
cfg: {} as OpenClawConfig,
|
||||
runtime: {} as RuntimeEnv,
|
||||
commandTokens: new Set(["valid-token"]),
|
||||
bodyTimeoutMs: 1,
|
||||
});
|
||||
const req = createRequest({ autoEnd: false });
|
||||
const response = createResponse();
|
||||
|
||||
await vi.advanceTimersByTimeAsync(5_000);
|
||||
await pending;
|
||||
await handler(req, response.res);
|
||||
|
||||
expect(response.res.statusCode).toBe(408);
|
||||
expect(response.getBody()).toBe("Request body timeout");
|
||||
} finally {
|
||||
vi.useRealTimers();
|
||||
}
|
||||
expect(response.res.statusCode).toBe(408);
|
||||
expect(response.getBody()).toBe("Request body timeout");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -54,6 +54,7 @@ type SlashHttpHandlerParams = {
|
||||
/** Map from trigger to original command name (for skill commands that start with oc_). */
|
||||
triggerMap?: ReadonlyMap<string, string>;
|
||||
log?: (msg: string) => void;
|
||||
bodyTimeoutMs?: number;
|
||||
};
|
||||
|
||||
const MAX_BODY_BYTES = 64 * 1024;
|
||||
@@ -62,10 +63,14 @@ const BODY_READ_TIMEOUT_MS = 5_000;
|
||||
/**
|
||||
* Read the full request body as a string.
|
||||
*/
|
||||
function readBody(req: IncomingMessage, maxBytes: number): Promise<string> {
|
||||
function readBody(
|
||||
req: IncomingMessage,
|
||||
maxBytes: number,
|
||||
timeoutMs = BODY_READ_TIMEOUT_MS,
|
||||
): Promise<string> {
|
||||
return readRequestBodyWithLimit(req, {
|
||||
maxBytes,
|
||||
timeoutMs: BODY_READ_TIMEOUT_MS,
|
||||
timeoutMs,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -219,7 +224,7 @@ async function authorizeSlashInvocation(params: {
|
||||
* from the Mattermost server when a user invokes a registered slash command.
|
||||
*/
|
||||
export function createSlashCommandHttpHandler(params: SlashHttpHandlerParams) {
|
||||
const { account, cfg, runtime, commandTokens, triggerMap, log } = params;
|
||||
const { account, cfg, runtime, commandTokens, triggerMap, log, bodyTimeoutMs } = params;
|
||||
|
||||
return async (req: IncomingMessage, res: ServerResponse): Promise<void> => {
|
||||
if (req.method !== "POST") {
|
||||
@@ -231,7 +236,7 @@ export function createSlashCommandHttpHandler(params: SlashHttpHandlerParams) {
|
||||
|
||||
let body: string;
|
||||
try {
|
||||
body = await readBody(req, MAX_BODY_BYTES);
|
||||
body = await readBody(req, MAX_BODY_BYTES, bodyTimeoutMs);
|
||||
} catch (error) {
|
||||
if (isRequestBodyLimitError(error, "REQUEST_BODY_TIMEOUT")) {
|
||||
res.statusCode = 408;
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
import {
|
||||
DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
DEFAULT_LOCAL_MODEL,
|
||||
DEFAULT_MISTRAL_EMBEDDING_MODEL,
|
||||
DEFAULT_OLLAMA_EMBEDDING_MODEL,
|
||||
DEFAULT_OPENAI_EMBEDDING_MODEL,
|
||||
DEFAULT_VOYAGE_EMBEDDING_MODEL,
|
||||
getMemoryEmbeddingProvider,
|
||||
listMemoryEmbeddingProviders,
|
||||
type MemoryEmbeddingProvider,
|
||||
@@ -15,15 +10,7 @@ import {
|
||||
import { formatErrorMessage } from "../dreaming-shared.js";
|
||||
import { canAutoSelectLocal } from "./provider-adapters.js";
|
||||
|
||||
export {
|
||||
DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
DEFAULT_LMSTUDIO_EMBEDDING_MODEL,
|
||||
DEFAULT_LOCAL_MODEL,
|
||||
DEFAULT_MISTRAL_EMBEDDING_MODEL,
|
||||
DEFAULT_OLLAMA_EMBEDDING_MODEL,
|
||||
DEFAULT_OPENAI_EMBEDDING_MODEL,
|
||||
DEFAULT_VOYAGE_EMBEDDING_MODEL,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
export { DEFAULT_LOCAL_MODEL } from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
|
||||
export type EmbeddingProvider = MemoryEmbeddingProvider;
|
||||
export type EmbeddingProviderId = string;
|
||||
|
||||
@@ -11,9 +11,9 @@ import {
|
||||
} from "../../../../src/plugins/memory-embedding-providers.js";
|
||||
import "./test-runtime-mocks.js";
|
||||
import type { MemoryIndexManager } from "./index.js";
|
||||
import { getMemorySearchManager, closeAllMemorySearchManagers } from "./index.js";
|
||||
import { closeAllMemorySearchManagers, getMemorySearchManager } from "./index.js";
|
||||
import {
|
||||
DEFAULT_OLLAMA_EMBEDDING_MODEL,
|
||||
DEFAULT_LOCAL_MODEL,
|
||||
registerBuiltInMemoryEmbeddingProviders,
|
||||
} from "./provider-adapters.js";
|
||||
|
||||
@@ -112,14 +112,14 @@ vi.mock("./embeddings.js", () => {
|
||||
});
|
||||
|
||||
describe("memory index", () => {
|
||||
it("registers the builtin ollama embedding provider", () => {
|
||||
const adapter = listRegisteredAdapters().find((entry) => entry.id === "ollama");
|
||||
it("registers the builtin local embedding provider", () => {
|
||||
const adapter = listRegisteredAdapters().find((entry) => entry.id === "local");
|
||||
|
||||
expect(adapter).toBeDefined();
|
||||
expect(adapter).toEqual(
|
||||
expect.objectContaining({
|
||||
id: "ollama",
|
||||
defaultModel: DEFAULT_OLLAMA_EMBEDDING_MODEL,
|
||||
id: "local",
|
||||
defaultModel: DEFAULT_LOCAL_MODEL,
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1,31 +1,13 @@
|
||||
import fsSync from "node:fs";
|
||||
import {
|
||||
DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
DEFAULT_LMSTUDIO_EMBEDDING_MODEL,
|
||||
DEFAULT_LOCAL_MODEL,
|
||||
DEFAULT_MISTRAL_EMBEDDING_MODEL,
|
||||
DEFAULT_OLLAMA_EMBEDDING_MODEL,
|
||||
DEFAULT_OPENAI_EMBEDDING_MODEL,
|
||||
DEFAULT_VOYAGE_EMBEDDING_MODEL,
|
||||
OPENAI_BATCH_ENDPOINT,
|
||||
buildGeminiEmbeddingRequest,
|
||||
createGeminiEmbeddingProvider,
|
||||
createLmstudioEmbeddingProvider,
|
||||
createLocalEmbeddingProvider,
|
||||
createMistralEmbeddingProvider,
|
||||
createOllamaEmbeddingProvider,
|
||||
createOpenAiEmbeddingProvider,
|
||||
createVoyageEmbeddingProvider,
|
||||
hasNonTextEmbeddingParts,
|
||||
DEFAULT_LOCAL_MODEL,
|
||||
listMemoryEmbeddingProviders,
|
||||
listRegisteredMemoryEmbeddingProviderAdapters,
|
||||
runGeminiEmbeddingBatches,
|
||||
runOpenAiEmbeddingBatches,
|
||||
runVoyageEmbeddingBatches,
|
||||
type MemoryEmbeddingProviderAdapter,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import { resolveUserPath } from "openclaw/plugin-sdk/memory-core-host-engine-foundation";
|
||||
import { getProviderEnvVars } from "openclaw/plugin-sdk/provider-env-vars";
|
||||
import { normalizeLowercaseStringOrEmpty } from "openclaw/plugin-sdk/text-runtime";
|
||||
import { formatErrorMessage } from "../dreaming-shared.js";
|
||||
import { filterUnregisteredMemoryEmbeddingProviderAdapters } from "./provider-adapter-registration.js";
|
||||
|
||||
@@ -37,31 +19,6 @@ export type BuiltinMemoryEmbeddingProviderDoctorMetadata = {
|
||||
autoSelectPriority?: number;
|
||||
};
|
||||
|
||||
function isMissingApiKeyError(err: unknown): boolean {
|
||||
return formatErrorMessage(err).includes("No API key found for provider");
|
||||
}
|
||||
|
||||
function sanitizeHeaders(
|
||||
headers: Record<string, string>,
|
||||
excludedHeaderNames: string[],
|
||||
): Array<[string, string]> {
|
||||
const excluded = new Set(
|
||||
excludedHeaderNames.map((name) => normalizeLowercaseStringOrEmpty(name)),
|
||||
);
|
||||
return Object.entries(headers)
|
||||
.filter(([key]) => !excluded.has(normalizeLowercaseStringOrEmpty(key)))
|
||||
.toSorted(([a], [b]) => a.localeCompare(b))
|
||||
.map(([key, value]) => [key, value]);
|
||||
}
|
||||
|
||||
function mapBatchEmbeddingsByIndex(byCustomId: Map<string, number[]>, count: number): number[][] {
|
||||
const embeddings: number[][] = [];
|
||||
for (let index = 0; index < count; index += 1) {
|
||||
embeddings.push(byCustomId.get(String(index)) ?? []);
|
||||
}
|
||||
return embeddings;
|
||||
}
|
||||
|
||||
function isNodeLlamaCppMissing(err: unknown): boolean {
|
||||
if (!(err instanceof Error)) {
|
||||
return false;
|
||||
@@ -70,6 +27,20 @@ function isNodeLlamaCppMissing(err: unknown): boolean {
|
||||
return code === "ERR_MODULE_NOT_FOUND" && err.message.includes("node-llama-cpp");
|
||||
}
|
||||
|
||||
function listRemoteEmbeddingSetupHints(): string[] {
|
||||
try {
|
||||
return listMemoryEmbeddingProviders()
|
||||
.filter(
|
||||
(adapter) =>
|
||||
adapter.transport === "remote" && typeof adapter.autoSelectPriority === "number",
|
||||
)
|
||||
.toSorted((a, b) => (a.autoSelectPriority ?? 0) - (b.autoSelectPriority ?? 0))
|
||||
.map((adapter) => `Or set agents.defaults.memorySearch.provider = "${adapter.id}" (remote).`);
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
function formatLocalSetupError(err: unknown): string {
|
||||
const detail = formatErrorMessage(err);
|
||||
const missing = isNodeLlamaCppMissing(err);
|
||||
@@ -87,9 +58,7 @@ function formatLocalSetupError(err: unknown): string {
|
||||
? "2) Reinstall OpenClaw (this should install node-llama-cpp): npm i -g openclaw@latest"
|
||||
: null,
|
||||
"3) If you use pnpm: pnpm approve-builds (select node-llama-cpp), then pnpm rebuild node-llama-cpp",
|
||||
...["openai", "gemini", "voyage", "mistral"].map(
|
||||
(provider) => `Or set agents.defaults.memorySearch.provider = "${provider}" (remote).`,
|
||||
),
|
||||
...listRemoteEmbeddingSetupHints(),
|
||||
]
|
||||
.filter(Boolean)
|
||||
.join("\n");
|
||||
@@ -111,237 +80,6 @@ function canAutoSelectLocal(modelPath?: string): boolean {
|
||||
}
|
||||
}
|
||||
|
||||
function supportsGeminiMultimodalEmbeddings(model: string): boolean {
|
||||
const normalized = model
|
||||
.trim()
|
||||
.replace(/^models\//, "")
|
||||
.replace(/^(gemini|google)\//, "");
|
||||
return normalized === "gemini-embedding-2-preview";
|
||||
}
|
||||
|
||||
function resolveMemoryEmbeddingAuthProviderId(providerId: string): string {
|
||||
return providerId === "gemini" ? "google" : providerId;
|
||||
}
|
||||
|
||||
const openAiAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: "openai",
|
||||
defaultModel: DEFAULT_OPENAI_EMBEDDING_MODEL,
|
||||
transport: "remote",
|
||||
autoSelectPriority: 20,
|
||||
allowExplicitWhenConfiguredAuto: true,
|
||||
shouldContinueAutoSelection: isMissingApiKeyError,
|
||||
create: async (options) => {
|
||||
const { provider, client } = await createOpenAiEmbeddingProvider({
|
||||
...options,
|
||||
provider: "openai",
|
||||
fallback: "none",
|
||||
});
|
||||
return {
|
||||
provider,
|
||||
runtime: {
|
||||
id: "openai",
|
||||
cacheKeyData: {
|
||||
provider: "openai",
|
||||
baseUrl: client.baseUrl,
|
||||
model: client.model,
|
||||
headers: sanitizeHeaders(client.headers, ["authorization"]),
|
||||
},
|
||||
batchEmbed: async (batch) => {
|
||||
const byCustomId = await runOpenAiEmbeddingBatches({
|
||||
openAi: client,
|
||||
agentId: batch.agentId,
|
||||
requests: batch.chunks.map((chunk, index) => ({
|
||||
custom_id: String(index),
|
||||
method: "POST",
|
||||
url: OPENAI_BATCH_ENDPOINT,
|
||||
body: {
|
||||
model: client.model,
|
||||
input: chunk.text,
|
||||
},
|
||||
})),
|
||||
wait: batch.wait,
|
||||
concurrency: batch.concurrency,
|
||||
pollIntervalMs: batch.pollIntervalMs,
|
||||
timeoutMs: batch.timeoutMs,
|
||||
debug: batch.debug,
|
||||
});
|
||||
return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length);
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
const geminiAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: "gemini",
|
||||
defaultModel: DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
transport: "remote",
|
||||
autoSelectPriority: 30,
|
||||
allowExplicitWhenConfiguredAuto: true,
|
||||
supportsMultimodalEmbeddings: ({ model }) => supportsGeminiMultimodalEmbeddings(model),
|
||||
shouldContinueAutoSelection: isMissingApiKeyError,
|
||||
create: async (options) => {
|
||||
const { provider, client } = await createGeminiEmbeddingProvider({
|
||||
...options,
|
||||
provider: "gemini",
|
||||
fallback: "none",
|
||||
});
|
||||
return {
|
||||
provider,
|
||||
runtime: {
|
||||
id: "gemini",
|
||||
cacheKeyData: {
|
||||
provider: "gemini",
|
||||
baseUrl: client.baseUrl,
|
||||
model: client.model,
|
||||
outputDimensionality: client.outputDimensionality,
|
||||
headers: sanitizeHeaders(client.headers, ["authorization", "x-goog-api-key"]),
|
||||
},
|
||||
batchEmbed: async (batch) => {
|
||||
if (batch.chunks.some((chunk) => hasNonTextEmbeddingParts(chunk.embeddingInput))) {
|
||||
return null;
|
||||
}
|
||||
const byCustomId = await runGeminiEmbeddingBatches({
|
||||
gemini: client,
|
||||
agentId: batch.agentId,
|
||||
requests: batch.chunks.map((chunk, index) => ({
|
||||
custom_id: String(index),
|
||||
request: buildGeminiEmbeddingRequest({
|
||||
input: chunk.embeddingInput ?? { text: chunk.text },
|
||||
taskType: "RETRIEVAL_DOCUMENT",
|
||||
modelPath: client.modelPath,
|
||||
outputDimensionality: client.outputDimensionality,
|
||||
}),
|
||||
})),
|
||||
wait: batch.wait,
|
||||
concurrency: batch.concurrency,
|
||||
pollIntervalMs: batch.pollIntervalMs,
|
||||
timeoutMs: batch.timeoutMs,
|
||||
debug: batch.debug,
|
||||
});
|
||||
return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length);
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
const voyageAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: "voyage",
|
||||
defaultModel: DEFAULT_VOYAGE_EMBEDDING_MODEL,
|
||||
transport: "remote",
|
||||
autoSelectPriority: 40,
|
||||
allowExplicitWhenConfiguredAuto: true,
|
||||
shouldContinueAutoSelection: isMissingApiKeyError,
|
||||
create: async (options) => {
|
||||
const { provider, client } = await createVoyageEmbeddingProvider({
|
||||
...options,
|
||||
provider: "voyage",
|
||||
fallback: "none",
|
||||
});
|
||||
return {
|
||||
provider,
|
||||
runtime: {
|
||||
id: "voyage",
|
||||
batchEmbed: async (batch) => {
|
||||
const byCustomId = await runVoyageEmbeddingBatches({
|
||||
client,
|
||||
agentId: batch.agentId,
|
||||
requests: batch.chunks.map((chunk, index) => ({
|
||||
custom_id: String(index),
|
||||
body: {
|
||||
input: chunk.text,
|
||||
},
|
||||
})),
|
||||
wait: batch.wait,
|
||||
concurrency: batch.concurrency,
|
||||
pollIntervalMs: batch.pollIntervalMs,
|
||||
timeoutMs: batch.timeoutMs,
|
||||
debug: batch.debug,
|
||||
});
|
||||
return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length);
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
const mistralAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: "mistral",
|
||||
defaultModel: DEFAULT_MISTRAL_EMBEDDING_MODEL,
|
||||
transport: "remote",
|
||||
autoSelectPriority: 50,
|
||||
allowExplicitWhenConfiguredAuto: true,
|
||||
shouldContinueAutoSelection: isMissingApiKeyError,
|
||||
create: async (options) => {
|
||||
const { provider, client } = await createMistralEmbeddingProvider({
|
||||
...options,
|
||||
provider: "mistral",
|
||||
fallback: "none",
|
||||
});
|
||||
return {
|
||||
provider,
|
||||
runtime: {
|
||||
id: "mistral",
|
||||
cacheKeyData: {
|
||||
provider: "mistral",
|
||||
model: client.model,
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
const ollamaAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: "ollama",
|
||||
defaultModel: DEFAULT_OLLAMA_EMBEDDING_MODEL,
|
||||
transport: "remote",
|
||||
create: async (options) => {
|
||||
const { provider, client } = await createOllamaEmbeddingProvider({
|
||||
...options,
|
||||
provider: "ollama",
|
||||
fallback: "none",
|
||||
});
|
||||
return {
|
||||
provider,
|
||||
runtime: {
|
||||
id: "ollama",
|
||||
cacheKeyData: {
|
||||
provider: "ollama",
|
||||
baseUrl: client.baseUrl,
|
||||
model: client.model,
|
||||
headers: sanitizeHeaders(client.headers, ["authorization"]),
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
const lmstudioAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: "lmstudio",
|
||||
defaultModel: DEFAULT_LMSTUDIO_EMBEDDING_MODEL,
|
||||
transport: "remote",
|
||||
create: async (options) => {
|
||||
const { provider, client } = await createLmstudioEmbeddingProvider({
|
||||
...options,
|
||||
provider: "lmstudio",
|
||||
fallback: "none",
|
||||
});
|
||||
return {
|
||||
provider,
|
||||
runtime: {
|
||||
id: "lmstudio",
|
||||
cacheKeyData: {
|
||||
provider: "lmstudio",
|
||||
baseUrl: client.baseUrl,
|
||||
model: client.model,
|
||||
headers: sanitizeHeaders(client.headers, ["authorization"]),
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
const localAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: "local",
|
||||
defaultModel: DEFAULT_LOCAL_MODEL,
|
||||
@@ -368,24 +106,14 @@ const localAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
},
|
||||
};
|
||||
|
||||
export const builtinMemoryEmbeddingProviderAdapters = [
|
||||
localAdapter,
|
||||
openAiAdapter,
|
||||
geminiAdapter,
|
||||
voyageAdapter,
|
||||
mistralAdapter,
|
||||
ollamaAdapter,
|
||||
lmstudioAdapter,
|
||||
] as const;
|
||||
export const builtinMemoryEmbeddingProviderAdapters = [localAdapter] as const;
|
||||
|
||||
const builtinMemoryEmbeddingProviderAdapterById = new Map(
|
||||
builtinMemoryEmbeddingProviderAdapters.map((adapter) => [adapter.id, adapter]),
|
||||
);
|
||||
export { DEFAULT_LOCAL_MODEL };
|
||||
|
||||
export function getBuiltinMemoryEmbeddingProviderAdapter(
|
||||
id: string,
|
||||
): MemoryEmbeddingProviderAdapter | undefined {
|
||||
return builtinMemoryEmbeddingProviderAdapterById.get(id);
|
||||
return listMemoryEmbeddingProviders().find((adapter) => adapter.id === id);
|
||||
}
|
||||
|
||||
export function registerBuiltInMemoryEmbeddingProviders(register: {
|
||||
@@ -409,7 +137,7 @@ export function getBuiltinMemoryEmbeddingProviderDoctorMetadata(
|
||||
if (!adapter) {
|
||||
return null;
|
||||
}
|
||||
const authProviderId = resolveMemoryEmbeddingAuthProviderId(adapter.id);
|
||||
const authProviderId = adapter.authProviderId ?? adapter.id;
|
||||
return {
|
||||
providerId: adapter.id,
|
||||
authProviderId,
|
||||
@@ -420,27 +148,19 @@ export function getBuiltinMemoryEmbeddingProviderDoctorMetadata(
|
||||
}
|
||||
|
||||
export function listBuiltinAutoSelectMemoryEmbeddingProviderDoctorMetadata(): Array<BuiltinMemoryEmbeddingProviderDoctorMetadata> {
|
||||
return builtinMemoryEmbeddingProviderAdapters
|
||||
return listMemoryEmbeddingProviders()
|
||||
.filter((adapter) => typeof adapter.autoSelectPriority === "number")
|
||||
.toSorted((a, b) => (a.autoSelectPriority ?? 0) - (b.autoSelectPriority ?? 0))
|
||||
.map((adapter) => ({
|
||||
providerId: adapter.id,
|
||||
authProviderId: resolveMemoryEmbeddingAuthProviderId(adapter.id),
|
||||
envVars: getProviderEnvVars(resolveMemoryEmbeddingAuthProviderId(adapter.id)),
|
||||
transport: adapter.transport === "local" ? "local" : "remote",
|
||||
autoSelectPriority: adapter.autoSelectPriority,
|
||||
}));
|
||||
.map((adapter) => {
|
||||
const authProviderId = adapter.authProviderId ?? adapter.id;
|
||||
return {
|
||||
providerId: adapter.id,
|
||||
authProviderId,
|
||||
envVars: getProviderEnvVars(authProviderId),
|
||||
transport: adapter.transport === "local" ? "local" : "remote",
|
||||
autoSelectPriority: adapter.autoSelectPriority,
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
export {
|
||||
DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
DEFAULT_LMSTUDIO_EMBEDDING_MODEL,
|
||||
DEFAULT_LOCAL_MODEL,
|
||||
DEFAULT_MISTRAL_EMBEDDING_MODEL,
|
||||
DEFAULT_OLLAMA_EMBEDDING_MODEL,
|
||||
DEFAULT_OPENAI_EMBEDDING_MODEL,
|
||||
DEFAULT_VOYAGE_EMBEDDING_MODEL,
|
||||
canAutoSelectLocal,
|
||||
formatLocalSetupError,
|
||||
isMissingApiKeyError,
|
||||
};
|
||||
export { canAutoSelectLocal, formatLocalSetupError };
|
||||
|
||||
@@ -16,6 +16,7 @@ import {
|
||||
writeFileWithinRoot,
|
||||
type OpenClawConfig,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-foundation";
|
||||
import { resolveAgentContextLimits } from "openclaw/plugin-sdk/memory-core-host-engine-foundation";
|
||||
import {
|
||||
buildSessionEntry,
|
||||
deriveQmdScopeChannel,
|
||||
@@ -47,7 +48,6 @@ import {
|
||||
type ResolvedQmdConfig,
|
||||
type ResolvedQmdMcporterConfig,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-storage";
|
||||
import { resolveAgentContextLimits } from "openclaw/plugin-sdk/memory-core-host-engine-foundation";
|
||||
import {
|
||||
localeLowercasePreservingWhitespace,
|
||||
normalizeLowercaseStringOrEmpty,
|
||||
@@ -1945,8 +1945,7 @@ export class QmdMemoryManager implements MemorySearchManager {
|
||||
from?: number,
|
||||
lines?: number,
|
||||
): Promise<
|
||||
| { missing: true }
|
||||
| { missing: false; selectedLines: string[]; moreSourceLinesRemain: boolean }
|
||||
{ missing: true } | { missing: false; selectedLines: string[]; moreSourceLinesRemain: boolean }
|
||||
> {
|
||||
const start = Math.max(1, from ?? 1);
|
||||
const count = Math.max(1, lines ?? Number.POSITIVE_INFINITY);
|
||||
|
||||
@@ -51,7 +51,7 @@ describe("memory-wiki cli metadata entry", () => {
|
||||
const resolvedConfig = { vaultMode: "bridge", vault: { path: "/vault" } };
|
||||
mocks.resolveMemoryWikiConfig.mockReturnValue(resolvedConfig);
|
||||
|
||||
await plugin.register(api);
|
||||
plugin.register(api);
|
||||
|
||||
const register = registerCli.mock.calls[0]?.[0];
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ describe("memory-wiki plugin", () => {
|
||||
registerTool,
|
||||
} = createPluginApi();
|
||||
|
||||
await plugin.register(api);
|
||||
plugin.register(api);
|
||||
|
||||
expect(registerMemoryCorpusSupplement).toHaveBeenCalledTimes(1);
|
||||
expect(registerMemoryPromptSupplement).toHaveBeenCalledTimes(1);
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
|
||||
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
|
||||
import {
|
||||
createRemoteEmbeddingProvider,
|
||||
normalizeEmbeddingModelWithPrefixes,
|
||||
resolveRemoteEmbeddingClient,
|
||||
} from "./embeddings-remote-provider.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.types.js";
|
||||
type MemoryEmbeddingProvider,
|
||||
type MemoryEmbeddingProviderCreateOptions,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import type { SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime";
|
||||
|
||||
export type MistralEmbeddingClient = {
|
||||
baseUrl: string;
|
||||
@@ -25,8 +26,8 @@ export function normalizeMistralModel(model: string): string {
|
||||
}
|
||||
|
||||
export async function createMistralEmbeddingProvider(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<{ provider: EmbeddingProvider; client: MistralEmbeddingClient }> {
|
||||
options: MemoryEmbeddingProviderCreateOptions,
|
||||
): Promise<{ provider: MemoryEmbeddingProvider; client: MistralEmbeddingClient }> {
|
||||
const client = await resolveMistralEmbeddingClient(options);
|
||||
|
||||
return {
|
||||
@@ -40,7 +41,7 @@ export async function createMistralEmbeddingProvider(
|
||||
}
|
||||
|
||||
export async function resolveMistralEmbeddingClient(
|
||||
options: EmbeddingProviderOptions,
|
||||
options: MemoryEmbeddingProviderCreateOptions,
|
||||
): Promise<MistralEmbeddingClient> {
|
||||
return await resolveRemoteEmbeddingClient({
|
||||
provider: "mistral",
|
||||
@@ -1,6 +1,7 @@
|
||||
import { defineSingleProviderPluginEntry } from "openclaw/plugin-sdk/provider-entry";
|
||||
import { applyMistralModelCompat } from "./api.js";
|
||||
import { mistralMediaUnderstandingProvider } from "./media-understanding-provider.js";
|
||||
import { mistralMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js";
|
||||
import { applyMistralConfig, MISTRAL_DEFAULT_MODEL_REF } from "./onboard.js";
|
||||
import { buildMistralProvider } from "./provider-catalog.js";
|
||||
import { contributeMistralResolvedModelCompat } from "./provider-compat.js";
|
||||
@@ -48,6 +49,7 @@ export default defineSingleProviderPluginEntry({
|
||||
buildReplayPolicy: () => buildMistralReplayPolicy(),
|
||||
},
|
||||
register(api) {
|
||||
api.registerMemoryEmbeddingProvider(mistralMemoryEmbeddingProviderAdapter);
|
||||
api.registerMediaUnderstandingProvider(mistralMediaUnderstandingProvider);
|
||||
},
|
||||
});
|
||||
|
||||
35
extensions/mistral/memory-embedding-adapter.ts
Normal file
35
extensions/mistral/memory-embedding-adapter.ts
Normal file
@@ -0,0 +1,35 @@
|
||||
import {
|
||||
isMissingEmbeddingApiKeyError,
|
||||
type MemoryEmbeddingProviderAdapter,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import {
|
||||
createMistralEmbeddingProvider,
|
||||
DEFAULT_MISTRAL_EMBEDDING_MODEL,
|
||||
} from "./embedding-provider.js";
|
||||
|
||||
export const mistralMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: "mistral",
|
||||
defaultModel: DEFAULT_MISTRAL_EMBEDDING_MODEL,
|
||||
transport: "remote",
|
||||
authProviderId: "mistral",
|
||||
autoSelectPriority: 50,
|
||||
allowExplicitWhenConfiguredAuto: true,
|
||||
shouldContinueAutoSelection: isMissingEmbeddingApiKeyError,
|
||||
create: async (options) => {
|
||||
const { provider, client } = await createMistralEmbeddingProvider({
|
||||
...options,
|
||||
provider: "mistral",
|
||||
fallback: "none",
|
||||
});
|
||||
return {
|
||||
provider,
|
||||
runtime: {
|
||||
id: "mistral",
|
||||
cacheKeyData: {
|
||||
provider: "mistral",
|
||||
model: client.model,
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
@@ -21,6 +21,7 @@
|
||||
}
|
||||
],
|
||||
"contracts": {
|
||||
"memoryEmbeddingProviders": ["mistral"],
|
||||
"mediaUnderstandingProviders": ["mistral"]
|
||||
},
|
||||
"configSchema": {
|
||||
|
||||
1
extensions/nextcloud-talk/src/api.ts
Normal file
1
extensions/nextcloud-talk/src/api.ts
Normal file
@@ -0,0 +1 @@
|
||||
export { createAuthRateLimiter } from "openclaw/plugin-sdk/nextcloud-talk";
|
||||
@@ -11,13 +11,9 @@ const hoisted = vi.hoisted(() => ({
|
||||
monitorNextcloudTalkProvider: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("./monitor.js", async () => {
|
||||
const actual = await vi.importActual<typeof import("./monitor.js")>("./monitor.js");
|
||||
return {
|
||||
...actual,
|
||||
monitorNextcloudTalkProvider: hoisted.monitorNextcloudTalkProvider,
|
||||
};
|
||||
});
|
||||
vi.mock("./monitor-runtime.js", () => ({
|
||||
monitorNextcloudTalkProvider: hoisted.monitorNextcloudTalkProvider,
|
||||
}));
|
||||
|
||||
const { nextcloudTalkGatewayAdapter } = await import("./gateway.js");
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import {
|
||||
type ChannelPlugin,
|
||||
type OpenClawConfig,
|
||||
} from "./channel-api.js";
|
||||
import { monitorNextcloudTalkProvider } from "./monitor.js";
|
||||
import { monitorNextcloudTalkProvider } from "./monitor-runtime.js";
|
||||
import { getNextcloudTalkRuntime } from "./runtime.js";
|
||||
import type { CoreConfig } from "./types.js";
|
||||
|
||||
|
||||
138
extensions/nextcloud-talk/src/monitor-runtime.ts
Normal file
138
extensions/nextcloud-talk/src/monitor-runtime.ts
Normal file
@@ -0,0 +1,138 @@
|
||||
import os from "node:os";
|
||||
import { resolveLoggerBackedRuntime } from "openclaw/plugin-sdk/extension-shared";
|
||||
import type { RuntimeEnv } from "openclaw/plugin-sdk/runtime";
|
||||
import { normalizeLowercaseStringOrEmpty } from "openclaw/plugin-sdk/text-runtime";
|
||||
import { resolveNextcloudTalkAccount } from "./accounts.js";
|
||||
import { handleNextcloudTalkInbound } from "./inbound.js";
|
||||
import {
|
||||
createNextcloudTalkWebhookServer,
|
||||
processNextcloudTalkReplayGuardedMessage,
|
||||
} from "./monitor.js";
|
||||
import { createNextcloudTalkReplayGuard } from "./replay-guard.js";
|
||||
import { getNextcloudTalkRuntime } from "./runtime.js";
|
||||
import type { CoreConfig, NextcloudTalkInboundMessage } from "./types.js";
|
||||
|
||||
const DEFAULT_WEBHOOK_PORT = 8788;
|
||||
const DEFAULT_WEBHOOK_HOST = "0.0.0.0";
|
||||
const DEFAULT_WEBHOOK_PATH = "/nextcloud-talk-webhook";
|
||||
|
||||
function normalizeOrigin(value: string): string | null {
|
||||
try {
|
||||
return normalizeLowercaseStringOrEmpty(new URL(value).origin);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
export type NextcloudTalkMonitorOptions = {
|
||||
accountId?: string;
|
||||
config?: CoreConfig;
|
||||
runtime?: RuntimeEnv;
|
||||
abortSignal?: AbortSignal;
|
||||
onMessage?: (message: NextcloudTalkInboundMessage) => void | Promise<void>;
|
||||
statusSink?: (patch: { lastInboundAt?: number; lastOutboundAt?: number }) => void;
|
||||
};
|
||||
|
||||
export async function monitorNextcloudTalkProvider(
|
||||
opts: NextcloudTalkMonitorOptions,
|
||||
): Promise<{ stop: () => void }> {
|
||||
const core = getNextcloudTalkRuntime();
|
||||
const cfg = opts.config ?? (core.config.loadConfig() as CoreConfig);
|
||||
const account = resolveNextcloudTalkAccount({
|
||||
cfg,
|
||||
accountId: opts.accountId,
|
||||
});
|
||||
const runtime: RuntimeEnv = resolveLoggerBackedRuntime(
|
||||
opts.runtime,
|
||||
core.logging.getChildLogger(),
|
||||
);
|
||||
|
||||
if (!account.secret) {
|
||||
throw new Error(`Nextcloud Talk bot secret not configured for account "${account.accountId}"`);
|
||||
}
|
||||
|
||||
const port = account.config.webhookPort ?? DEFAULT_WEBHOOK_PORT;
|
||||
const host = account.config.webhookHost ?? DEFAULT_WEBHOOK_HOST;
|
||||
const path = account.config.webhookPath ?? DEFAULT_WEBHOOK_PATH;
|
||||
|
||||
const logger = core.logging.getChildLogger({
|
||||
channel: "nextcloud-talk",
|
||||
accountId: account.accountId,
|
||||
});
|
||||
const expectedBackendOrigin = normalizeOrigin(account.baseUrl);
|
||||
const replayGuard = createNextcloudTalkReplayGuard({
|
||||
stateDir: core.state.resolveStateDir(process.env, os.homedir),
|
||||
onDiskError: (error) => {
|
||||
logger.warn(
|
||||
`[nextcloud-talk:${account.accountId}] replay guard disk error: ${String(error)}`,
|
||||
);
|
||||
},
|
||||
});
|
||||
|
||||
const { start, stop } = createNextcloudTalkWebhookServer({
|
||||
port,
|
||||
host,
|
||||
path,
|
||||
secret: account.secret,
|
||||
isBackendAllowed: (backend) => {
|
||||
if (!expectedBackendOrigin) {
|
||||
return true;
|
||||
}
|
||||
const backendOrigin = normalizeOrigin(backend);
|
||||
return backendOrigin === expectedBackendOrigin;
|
||||
},
|
||||
processMessage: async (message) => {
|
||||
const result = await processNextcloudTalkReplayGuardedMessage({
|
||||
replayGuard,
|
||||
accountId: account.accountId,
|
||||
message,
|
||||
handleMessage: async () => {
|
||||
core.channel.activity.record({
|
||||
channel: "nextcloud-talk",
|
||||
accountId: account.accountId,
|
||||
direction: "inbound",
|
||||
at: message.timestamp,
|
||||
});
|
||||
if (opts.onMessage) {
|
||||
await opts.onMessage(message);
|
||||
} else {
|
||||
await handleNextcloudTalkInbound({
|
||||
message,
|
||||
account,
|
||||
config: cfg,
|
||||
runtime,
|
||||
statusSink: opts.statusSink,
|
||||
});
|
||||
}
|
||||
},
|
||||
});
|
||||
if (result === "duplicate") {
|
||||
logger.warn(
|
||||
`[nextcloud-talk:${account.accountId}] replayed webhook ignored room=${message.roomToken} messageId=${message.messageId}`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
},
|
||||
onMessage: async () => {},
|
||||
onError: (error) => {
|
||||
logger.error(`[nextcloud-talk:${account.accountId}] webhook error: ${error.message}`);
|
||||
},
|
||||
abortSignal: opts.abortSignal,
|
||||
});
|
||||
|
||||
if (opts.abortSignal?.aborted) {
|
||||
return { stop };
|
||||
}
|
||||
await start();
|
||||
if (opts.abortSignal?.aborted) {
|
||||
stop();
|
||||
return { stop };
|
||||
}
|
||||
|
||||
const publicUrl =
|
||||
account.config.webhookPublicUrl ??
|
||||
`http://${host === "0.0.0.0" ? "localhost" : host}:${port}${path}`;
|
||||
logger.info(`[nextcloud-talk:${account.accountId}] webhook listening on ${publicUrl}`);
|
||||
|
||||
return { stop };
|
||||
}
|
||||
@@ -3,7 +3,6 @@ import os from "node:os";
|
||||
import path from "node:path";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { createMockIncomingRequest } from "../../../test/helpers/mock-incoming-request.js";
|
||||
import { WEBHOOK_RATE_LIMIT_DEFAULTS } from "../runtime-api.js";
|
||||
import {
|
||||
NextcloudTalkRetryableWebhookError,
|
||||
processNextcloudTalkReplayGuardedMessage,
|
||||
@@ -274,8 +273,10 @@ describe("createNextcloudTalkWebhookServer payload validation", () => {
|
||||
|
||||
describe("createNextcloudTalkWebhookServer auth rate limiting", () => {
|
||||
it("rate limits repeated invalid signature attempts from the same source", async () => {
|
||||
const maxRequests = 2;
|
||||
const harness = await startWebhookServer({
|
||||
path: "/nextcloud-auth-rate-limit",
|
||||
authRateLimit: { maxRequests },
|
||||
onMessage: vi.fn(),
|
||||
});
|
||||
const { body, headers } = createSignedCreateMessageRequest();
|
||||
@@ -286,7 +287,7 @@ describe("createNextcloudTalkWebhookServer auth rate limiting", () => {
|
||||
|
||||
let firstResponse: Response | undefined;
|
||||
let lastResponse: Response | undefined;
|
||||
for (let attempt = 0; attempt <= WEBHOOK_RATE_LIMIT_DEFAULTS.maxRequests; attempt += 1) {
|
||||
for (let attempt = 0; attempt <= maxRequests; attempt += 1) {
|
||||
const response = await fetch(harness.webhookUrl, {
|
||||
method: "POST",
|
||||
headers: invalidHeaders,
|
||||
@@ -306,14 +307,16 @@ describe("createNextcloudTalkWebhookServer auth rate limiting", () => {
|
||||
});
|
||||
|
||||
it("does not rate limit valid signed webhook bursts from the same source", async () => {
|
||||
const maxRequests = 2;
|
||||
const harness = await startWebhookServer({
|
||||
path: "/nextcloud-auth-rate-limit-valid",
|
||||
authRateLimit: { maxRequests },
|
||||
onMessage: vi.fn(),
|
||||
});
|
||||
const { body, headers } = createSignedCreateMessageRequest();
|
||||
|
||||
let lastResponse: Response | undefined;
|
||||
for (let attempt = 0; attempt <= WEBHOOK_RATE_LIMIT_DEFAULTS.maxRequests; attempt += 1) {
|
||||
for (let attempt = 0; attempt <= maxRequests; attempt += 1) {
|
||||
lastResponse = await fetch(harness.webhookUrl, {
|
||||
method: "POST",
|
||||
headers,
|
||||
|
||||
@@ -1,35 +1,22 @@
|
||||
import { createServer, type IncomingMessage, type Server, type ServerResponse } from "node:http";
|
||||
import os from "node:os";
|
||||
import {
|
||||
resolveLoggerBackedRuntime,
|
||||
safeParseJsonWithSchema,
|
||||
} from "openclaw/plugin-sdk/extension-shared";
|
||||
import { normalizeLowercaseStringOrEmpty } from "openclaw/plugin-sdk/text-runtime";
|
||||
import { z } from "zod";
|
||||
import { safeParseJsonWithSchema } from "openclaw/plugin-sdk/extension-shared";
|
||||
import {
|
||||
WEBHOOK_RATE_LIMIT_DEFAULTS,
|
||||
createAuthRateLimiter,
|
||||
type RuntimeEnv,
|
||||
isRequestBodyLimitError,
|
||||
readRequestBodyWithLimit,
|
||||
requestBodyErrorToText,
|
||||
} from "../runtime-api.js";
|
||||
import { resolveNextcloudTalkAccount } from "./accounts.js";
|
||||
import { handleNextcloudTalkInbound } from "./inbound.js";
|
||||
import { createNextcloudTalkReplayGuard, type NextcloudTalkReplayGuard } from "./replay-guard.js";
|
||||
import { getNextcloudTalkRuntime } from "./runtime.js";
|
||||
} from "openclaw/plugin-sdk/webhook-ingress";
|
||||
import { z } from "zod";
|
||||
import { createAuthRateLimiter } from "./api.js";
|
||||
import type { NextcloudTalkReplayGuard } from "./replay-guard.js";
|
||||
import { extractNextcloudTalkHeaders, verifyNextcloudTalkSignature } from "./signature.js";
|
||||
import type {
|
||||
CoreConfig,
|
||||
NextcloudTalkInboundMessage,
|
||||
NextcloudTalkWebhookHeaders,
|
||||
NextcloudTalkWebhookPayload,
|
||||
NextcloudTalkWebhookServerOptions,
|
||||
} from "./types.js";
|
||||
|
||||
const DEFAULT_WEBHOOK_PORT = 8788;
|
||||
const DEFAULT_WEBHOOK_HOST = "0.0.0.0";
|
||||
const DEFAULT_WEBHOOK_PATH = "/nextcloud-talk-webhook";
|
||||
const DEFAULT_WEBHOOK_MAX_BODY_BYTES = 1024 * 1024;
|
||||
const PREAUTH_WEBHOOK_MAX_BODY_BYTES = 64 * 1024;
|
||||
const PREAUTH_WEBHOOK_BODY_TIMEOUT_MS = 5_000;
|
||||
@@ -122,14 +109,6 @@ function formatError(err: unknown): string {
|
||||
return typeof err === "string" ? err : JSON.stringify(err);
|
||||
}
|
||||
|
||||
function normalizeOrigin(value: string): string | null {
|
||||
try {
|
||||
return normalizeLowercaseStringOrEmpty(new URL(value).origin);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function parseWebhookPayload(body: string): NextcloudTalkWebhookPayload | null {
|
||||
return safeParseJsonWithSchema(NextcloudTalkWebhookPayloadSchema, body);
|
||||
}
|
||||
@@ -262,12 +241,20 @@ export function createNextcloudTalkWebhookServer(opts: NextcloudTalkWebhookServe
|
||||
const isBackendAllowed = opts.isBackendAllowed;
|
||||
const shouldProcessMessage = opts.shouldProcessMessage;
|
||||
const processMessage = opts.processMessage;
|
||||
const authRateLimitMaxRequests =
|
||||
typeof opts.authRateLimit?.maxRequests === "number"
|
||||
? opts.authRateLimit.maxRequests
|
||||
: WEBHOOK_RATE_LIMIT_DEFAULTS.maxRequests;
|
||||
const authRateLimitWindowMs =
|
||||
typeof opts.authRateLimit?.windowMs === "number"
|
||||
? opts.authRateLimit.windowMs
|
||||
: WEBHOOK_RATE_LIMIT_DEFAULTS.windowMs;
|
||||
const webhookAuthRateLimiter = createAuthRateLimiter({
|
||||
maxAttempts: WEBHOOK_RATE_LIMIT_DEFAULTS.maxRequests,
|
||||
windowMs: WEBHOOK_RATE_LIMIT_DEFAULTS.windowMs,
|
||||
lockoutMs: WEBHOOK_RATE_LIMIT_DEFAULTS.windowMs,
|
||||
maxAttempts: authRateLimitMaxRequests,
|
||||
windowMs: authRateLimitWindowMs,
|
||||
lockoutMs: authRateLimitWindowMs,
|
||||
exemptLoopback: false,
|
||||
pruneIntervalMs: WEBHOOK_RATE_LIMIT_DEFAULTS.windowMs,
|
||||
pruneIntervalMs: authRateLimitWindowMs,
|
||||
});
|
||||
|
||||
const server = createServer(async (req: IncomingMessage, res: ServerResponse) => {
|
||||
@@ -396,116 +383,3 @@ export function createNextcloudTalkWebhookServer(opts: NextcloudTalkWebhookServe
|
||||
|
||||
return { server, start, stop };
|
||||
}
|
||||
|
||||
export type NextcloudTalkMonitorOptions = {
|
||||
accountId?: string;
|
||||
config?: CoreConfig;
|
||||
runtime?: RuntimeEnv;
|
||||
abortSignal?: AbortSignal;
|
||||
onMessage?: (message: NextcloudTalkInboundMessage) => void | Promise<void>;
|
||||
statusSink?: (patch: { lastInboundAt?: number; lastOutboundAt?: number }) => void;
|
||||
};
|
||||
|
||||
export async function monitorNextcloudTalkProvider(
|
||||
opts: NextcloudTalkMonitorOptions,
|
||||
): Promise<{ stop: () => void }> {
|
||||
const core = getNextcloudTalkRuntime();
|
||||
const cfg = opts.config ?? (core.config.loadConfig() as CoreConfig);
|
||||
const account = resolveNextcloudTalkAccount({
|
||||
cfg,
|
||||
accountId: opts.accountId,
|
||||
});
|
||||
const runtime: RuntimeEnv = resolveLoggerBackedRuntime(
|
||||
opts.runtime,
|
||||
core.logging.getChildLogger(),
|
||||
);
|
||||
|
||||
if (!account.secret) {
|
||||
throw new Error(`Nextcloud Talk bot secret not configured for account "${account.accountId}"`);
|
||||
}
|
||||
|
||||
const port = account.config.webhookPort ?? DEFAULT_WEBHOOK_PORT;
|
||||
const host = account.config.webhookHost ?? DEFAULT_WEBHOOK_HOST;
|
||||
const path = account.config.webhookPath ?? DEFAULT_WEBHOOK_PATH;
|
||||
|
||||
const logger = core.logging.getChildLogger({
|
||||
channel: "nextcloud-talk",
|
||||
accountId: account.accountId,
|
||||
});
|
||||
const expectedBackendOrigin = normalizeOrigin(account.baseUrl);
|
||||
const replayGuard = createNextcloudTalkReplayGuard({
|
||||
stateDir: core.state.resolveStateDir(process.env, os.homedir),
|
||||
onDiskError: (error) => {
|
||||
logger.warn(
|
||||
`[nextcloud-talk:${account.accountId}] replay guard disk error: ${String(error)}`,
|
||||
);
|
||||
},
|
||||
});
|
||||
|
||||
const { start, stop } = createNextcloudTalkWebhookServer({
|
||||
port,
|
||||
host,
|
||||
path,
|
||||
secret: account.secret,
|
||||
isBackendAllowed: (backend) => {
|
||||
if (!expectedBackendOrigin) {
|
||||
return true;
|
||||
}
|
||||
const backendOrigin = normalizeOrigin(backend);
|
||||
return backendOrigin === expectedBackendOrigin;
|
||||
},
|
||||
processMessage: async (message) => {
|
||||
const result = await processNextcloudTalkReplayGuardedMessage({
|
||||
replayGuard,
|
||||
accountId: account.accountId,
|
||||
message,
|
||||
handleMessage: async () => {
|
||||
core.channel.activity.record({
|
||||
channel: "nextcloud-talk",
|
||||
accountId: account.accountId,
|
||||
direction: "inbound",
|
||||
at: message.timestamp,
|
||||
});
|
||||
if (opts.onMessage) {
|
||||
await opts.onMessage(message);
|
||||
} else {
|
||||
await handleNextcloudTalkInbound({
|
||||
message,
|
||||
account,
|
||||
config: cfg,
|
||||
runtime,
|
||||
statusSink: opts.statusSink,
|
||||
});
|
||||
}
|
||||
},
|
||||
});
|
||||
if (result === "duplicate") {
|
||||
logger.warn(
|
||||
`[nextcloud-talk:${account.accountId}] replayed webhook ignored room=${message.roomToken} messageId=${message.messageId}`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
},
|
||||
onMessage: async () => {},
|
||||
onError: (error) => {
|
||||
logger.error(`[nextcloud-talk:${account.accountId}] webhook error: ${error.message}`);
|
||||
},
|
||||
abortSignal: opts.abortSignal,
|
||||
});
|
||||
|
||||
if (opts.abortSignal?.aborted) {
|
||||
return { stop };
|
||||
}
|
||||
await start();
|
||||
if (opts.abortSignal?.aborted) {
|
||||
stop();
|
||||
return { stop };
|
||||
}
|
||||
|
||||
const publicUrl =
|
||||
account.config.webhookPublicUrl ??
|
||||
`http://${host === "0.0.0.0" ? "localhost" : host}:${port}${path}`;
|
||||
logger.info(`[nextcloud-talk:${account.accountId}] webhook listening on ${publicUrl}`);
|
||||
|
||||
return { stop };
|
||||
}
|
||||
|
||||
@@ -179,6 +179,10 @@ export type NextcloudTalkWebhookServerOptions = {
|
||||
path: string;
|
||||
secret: string;
|
||||
maxBodyBytes?: number;
|
||||
authRateLimit?: {
|
||||
maxRequests?: number;
|
||||
windowMs?: number;
|
||||
};
|
||||
readBody?: (req: import("node:http").IncomingMessage, maxBodyBytes: number) => Promise<string>;
|
||||
isBackendAllowed?: (backend: string) => boolean;
|
||||
shouldProcessMessage?: (message: NextcloudTalkInboundMessage) => boolean | Promise<boolean>;
|
||||
|
||||
@@ -8,6 +8,7 @@ export const ollamaMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapte
|
||||
id: "ollama",
|
||||
defaultModel: DEFAULT_OLLAMA_EMBEDDING_MODEL,
|
||||
transport: "remote",
|
||||
authProviderId: "ollama",
|
||||
create: async (options) => {
|
||||
const { provider, client } = await createOllamaEmbeddingProvider({
|
||||
...options,
|
||||
|
||||
@@ -17,8 +17,8 @@ import {
|
||||
type ProviderBatchOutputLine,
|
||||
uploadBatchJsonlFile,
|
||||
withRemoteHttpResponse,
|
||||
} from "./batch-embedding-common.js";
|
||||
import type { OpenAiEmbeddingClient } from "./embeddings-openai.js";
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import type { OpenAiEmbeddingClient } from "./embedding-provider.js";
|
||||
|
||||
export type OpenAiBatchRequest = {
|
||||
custom_id: string;
|
||||
@@ -1,11 +1,11 @@
|
||||
import { parseStaticModelRef } from "../../agents/model-ref-shared.js";
|
||||
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
|
||||
import { OPENAI_DEFAULT_EMBEDDING_MODEL } from "../../plugins/provider-model-defaults.js";
|
||||
import {
|
||||
createRemoteEmbeddingProvider,
|
||||
resolveRemoteEmbeddingClient,
|
||||
} from "./embeddings-remote-provider.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.types.js";
|
||||
type MemoryEmbeddingProvider,
|
||||
type MemoryEmbeddingProviderCreateOptions,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import type { SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime";
|
||||
import { OPENAI_DEFAULT_EMBEDDING_MODEL } from "./default-models.js";
|
||||
|
||||
export type OpenAiEmbeddingClient = {
|
||||
baseUrl: string;
|
||||
@@ -28,13 +28,12 @@ export function normalizeOpenAiModel(model: string): string {
|
||||
if (!trimmed) {
|
||||
return DEFAULT_OPENAI_EMBEDDING_MODEL;
|
||||
}
|
||||
const parsed = parseStaticModelRef(trimmed, "openai");
|
||||
return parsed && parsed.provider === "openai" ? parsed.model : trimmed;
|
||||
return trimmed.startsWith("openai/") ? trimmed.slice("openai/".length) : trimmed;
|
||||
}
|
||||
|
||||
export async function createOpenAiEmbeddingProvider(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<{ provider: EmbeddingProvider; client: OpenAiEmbeddingClient }> {
|
||||
options: MemoryEmbeddingProviderCreateOptions,
|
||||
): Promise<{ provider: MemoryEmbeddingProvider; client: OpenAiEmbeddingClient }> {
|
||||
const client = await resolveOpenAiEmbeddingClient(options);
|
||||
|
||||
return {
|
||||
@@ -49,7 +48,7 @@ export async function createOpenAiEmbeddingProvider(
|
||||
}
|
||||
|
||||
export async function resolveOpenAiEmbeddingClient(
|
||||
options: EmbeddingProviderOptions,
|
||||
options: MemoryEmbeddingProviderCreateOptions,
|
||||
): Promise<OpenAiEmbeddingClient> {
|
||||
return await resolveRemoteEmbeddingClient({
|
||||
provider: "openai",
|
||||
@@ -54,7 +54,7 @@ const _registerOpenAIPlugin = async () =>
|
||||
async function registerOpenAIPluginWithHook(params?: { pluginConfig?: Record<string, unknown> }) {
|
||||
const on = vi.fn();
|
||||
const providers: ProviderPlugin[] = [];
|
||||
await plugin.register(
|
||||
plugin.register(
|
||||
createTestPluginApi({
|
||||
id: "openai",
|
||||
name: "OpenAI Provider",
|
||||
|
||||
@@ -6,6 +6,7 @@ import {
|
||||
openaiCodexMediaUnderstandingProvider,
|
||||
openaiMediaUnderstandingProvider,
|
||||
} from "./media-understanding-provider.js";
|
||||
import { openAiMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js";
|
||||
import { buildOpenAICodexProviderPlugin } from "./openai-codex-provider.js";
|
||||
import { buildOpenAIProvider } from "./openai-provider.js";
|
||||
import {
|
||||
@@ -39,6 +40,7 @@ export default definePluginEntry({
|
||||
api.registerCliBackend(buildOpenAICodexCliBackend());
|
||||
api.registerProvider(buildProviderWithPromptContribution(buildOpenAIProvider()));
|
||||
api.registerProvider(buildProviderWithPromptContribution(buildOpenAICodexProviderPlugin()));
|
||||
api.registerMemoryEmbeddingProvider(openAiMemoryEmbeddingProviderAdapter);
|
||||
api.registerImageGenerationProvider(buildOpenAIImageGenerationProvider());
|
||||
api.registerRealtimeTranscriptionProvider(buildOpenAIRealtimeTranscriptionProvider());
|
||||
api.registerRealtimeVoiceProvider(buildOpenAIRealtimeVoiceProvider());
|
||||
|
||||
61
extensions/openai/memory-embedding-adapter.ts
Normal file
61
extensions/openai/memory-embedding-adapter.ts
Normal file
@@ -0,0 +1,61 @@
|
||||
import {
|
||||
isMissingEmbeddingApiKeyError,
|
||||
mapBatchEmbeddingsByIndex,
|
||||
sanitizeEmbeddingCacheHeaders,
|
||||
type MemoryEmbeddingProviderAdapter,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import { OPENAI_BATCH_ENDPOINT, runOpenAiEmbeddingBatches } from "./embedding-batch.js";
|
||||
import {
|
||||
createOpenAiEmbeddingProvider,
|
||||
DEFAULT_OPENAI_EMBEDDING_MODEL,
|
||||
} from "./embedding-provider.js";
|
||||
|
||||
export const openAiMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: "openai",
|
||||
defaultModel: DEFAULT_OPENAI_EMBEDDING_MODEL,
|
||||
transport: "remote",
|
||||
authProviderId: "openai",
|
||||
autoSelectPriority: 20,
|
||||
allowExplicitWhenConfiguredAuto: true,
|
||||
shouldContinueAutoSelection: isMissingEmbeddingApiKeyError,
|
||||
create: async (options) => {
|
||||
const { provider, client } = await createOpenAiEmbeddingProvider({
|
||||
...options,
|
||||
provider: "openai",
|
||||
fallback: "none",
|
||||
});
|
||||
return {
|
||||
provider,
|
||||
runtime: {
|
||||
id: "openai",
|
||||
cacheKeyData: {
|
||||
provider: "openai",
|
||||
baseUrl: client.baseUrl,
|
||||
model: client.model,
|
||||
headers: sanitizeEmbeddingCacheHeaders(client.headers, ["authorization"]),
|
||||
},
|
||||
batchEmbed: async (batch) => {
|
||||
const byCustomId = await runOpenAiEmbeddingBatches({
|
||||
openAi: client,
|
||||
agentId: batch.agentId,
|
||||
requests: batch.chunks.map((chunk, index) => ({
|
||||
custom_id: String(index),
|
||||
method: "POST",
|
||||
url: OPENAI_BATCH_ENDPOINT,
|
||||
body: {
|
||||
model: client.model,
|
||||
input: chunk.text,
|
||||
},
|
||||
})),
|
||||
wait: batch.wait,
|
||||
concurrency: batch.concurrency,
|
||||
pollIntervalMs: batch.pollIntervalMs,
|
||||
timeoutMs: batch.timeoutMs,
|
||||
debug: batch.debug,
|
||||
});
|
||||
return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length);
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
@@ -39,6 +39,7 @@
|
||||
"speechProviders": ["openai"],
|
||||
"realtimeTranscriptionProviders": ["openai"],
|
||||
"realtimeVoiceProviders": ["openai"],
|
||||
"memoryEmbeddingProviders": ["openai"],
|
||||
"mediaUnderstandingProviders": ["openai", "openai-codex"],
|
||||
"imageGenerationProviders": ["openai"],
|
||||
"videoGenerationProviders": ["openai"]
|
||||
|
||||
@@ -80,7 +80,7 @@ async function withRegisteredPhoneControl(
|
||||
});
|
||||
|
||||
let command: OpenClawPluginCommandDefinition | undefined;
|
||||
void registerPhoneControl.register(
|
||||
registerPhoneControl.register(
|
||||
createApi({
|
||||
stateDir,
|
||||
getConfig: () => config,
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
"description": "OpenClaw QA lab plugin with private debugger UI and scenario runner",
|
||||
"type": "module",
|
||||
"dependencies": {
|
||||
"@copilotkit/aimock": "1.13.0",
|
||||
"playwright-core": "1.59.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { createPluginSetupWizardStatus } from "../../../test/helpers/plugins/setup-wizard.js";
|
||||
import type { ResolvedSynologyChatAccount } from "./types.js";
|
||||
|
||||
@@ -42,12 +42,18 @@ const getSynologyChatSetupStatus = createPluginSetupWizardStatus(synologyChatPlu
|
||||
|
||||
describe("createSynologyChatPlugin", () => {
|
||||
beforeEach(() => {
|
||||
vi.stubEnv("SYNOLOGY_CHAT_TOKEN", "");
|
||||
vi.stubEnv("SYNOLOGY_CHAT_INCOMING_URL", "");
|
||||
mockSendMessage.mockClear();
|
||||
registerSynologyWebhookRouteMock.mockClear();
|
||||
mockSendMessage.mockResolvedValue(true);
|
||||
registerSynologyWebhookRouteMock.mockImplementation(() => vi.fn());
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.unstubAllEnvs();
|
||||
});
|
||||
|
||||
describe("meta", () => {
|
||||
it("has correct id and label", () => {
|
||||
const plugin = createSynologyChatPlugin();
|
||||
@@ -480,11 +486,17 @@ describe("createSynologyChatPlugin", () => {
|
||||
abortController: AbortController,
|
||||
) {
|
||||
expect(result).toBeInstanceOf(Promise);
|
||||
const resolved = await Promise.race([
|
||||
result,
|
||||
new Promise((r) => setTimeout(() => r("pending"), 50)),
|
||||
]);
|
||||
expect(resolved).toBe("pending");
|
||||
let settled = false;
|
||||
void result.then(
|
||||
() => {
|
||||
settled = true;
|
||||
},
|
||||
() => {
|
||||
settled = true;
|
||||
},
|
||||
);
|
||||
await Promise.resolve();
|
||||
expect(settled).toBe(false);
|
||||
abortController.abort();
|
||||
await result;
|
||||
}
|
||||
@@ -584,8 +596,6 @@ describe("createSynologyChatPlugin", () => {
|
||||
const firstPromise = plugin.gateway.startAccount(makeCtx(abortFirst));
|
||||
const secondPromise = plugin.gateway.startAccount(makeCtx(abortSecond));
|
||||
|
||||
await new Promise((r) => setTimeout(r, 10));
|
||||
|
||||
expect(registerMock).toHaveBeenCalledTimes(2);
|
||||
expect(unregisterFirst).not.toHaveBeenCalled();
|
||||
expect(unregisterSecond).not.toHaveBeenCalled();
|
||||
|
||||
@@ -144,26 +144,19 @@ describe("createWebhookHandler", () => {
|
||||
});
|
||||
|
||||
it("returns 408 when request body times out", async () => {
|
||||
vi.useFakeTimers();
|
||||
try {
|
||||
const handler = createWebhookHandler({
|
||||
account: makeAccount(),
|
||||
deliver: vi.fn(),
|
||||
log,
|
||||
});
|
||||
const handler = createWebhookHandler({
|
||||
account: makeAccount(),
|
||||
deliver: vi.fn(),
|
||||
log,
|
||||
bodyTimeoutMs: 1,
|
||||
});
|
||||
|
||||
const req = makeStalledReq("POST");
|
||||
const res = makeRes();
|
||||
const run = handler(req, res);
|
||||
const req = makeStalledReq("POST");
|
||||
const res = makeRes();
|
||||
await handler(req, res);
|
||||
|
||||
await vi.advanceTimersByTimeAsync(30_000);
|
||||
await run;
|
||||
|
||||
expect(res._status).toBe(408);
|
||||
expect(res._body).toContain("timeout");
|
||||
} finally {
|
||||
vi.useRealTimers();
|
||||
}
|
||||
expect(res._status).toBe(408);
|
||||
expect(res._body).toContain("timeout");
|
||||
});
|
||||
|
||||
it("rejects excess concurrent pre-auth body reads from the same remote IP", async () => {
|
||||
|
||||
@@ -142,7 +142,10 @@ function getSynologyWebhookInFlightKey(account: ResolvedSynologyChatAccount): st
|
||||
}
|
||||
|
||||
/** Read the full request body as a string. */
|
||||
async function readBody(req: IncomingMessage): Promise<
|
||||
async function readBody(
|
||||
req: IncomingMessage,
|
||||
timeoutMs = PREAUTH_BODY_TIMEOUT_MS,
|
||||
): Promise<
|
||||
| { ok: true; body: string }
|
||||
| {
|
||||
ok: false;
|
||||
@@ -153,7 +156,7 @@ async function readBody(req: IncomingMessage): Promise<
|
||||
try {
|
||||
const body = await readRequestBodyWithLimit(req, {
|
||||
maxBytes: PREAUTH_MAX_BODY_BYTES,
|
||||
timeoutMs: PREAUTH_BODY_TIMEOUT_MS,
|
||||
timeoutMs,
|
||||
});
|
||||
return { ok: true, body };
|
||||
} catch (err) {
|
||||
@@ -342,6 +345,7 @@ export interface WebhookHandlerDeps {
|
||||
warn: (...args: unknown[]) => void;
|
||||
error: (...args: unknown[]) => void;
|
||||
};
|
||||
bodyTimeoutMs?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -371,8 +375,9 @@ async function parseWebhookPayloadRequest(params: {
|
||||
req: IncomingMessage;
|
||||
res: ServerResponse;
|
||||
log?: WebhookHandlerDeps["log"];
|
||||
bodyTimeoutMs?: number;
|
||||
}): Promise<{ ok: false } | { ok: true; payload: SynologyWebhookPayload }> {
|
||||
const bodyResult = await readBody(params.req);
|
||||
const bodyResult = await readBody(params.req, params.bodyTimeoutMs);
|
||||
if (!bodyResult.ok) {
|
||||
params.log?.error("Failed to read request body", bodyResult.error);
|
||||
respondJson(params.res, bodyResult.statusCode, { error: bodyResult.error });
|
||||
@@ -465,6 +470,7 @@ async function parseAndAuthorizeSynologyWebhook(params: {
|
||||
invalidTokenRateLimiter: InvalidTokenRateLimiter;
|
||||
rateLimiter: RateLimiter;
|
||||
log?: WebhookHandlerDeps["log"];
|
||||
bodyTimeoutMs?: number;
|
||||
}): Promise<{ ok: false } | { ok: true; message: AuthorizedSynologyWebhook }> {
|
||||
const parsed = await parseWebhookPayloadRequest(params);
|
||||
if (!parsed.ok) {
|
||||
@@ -612,6 +618,7 @@ export function createWebhookHandler(deps: WebhookHandlerDeps) {
|
||||
invalidTokenRateLimiter,
|
||||
rateLimiter,
|
||||
log,
|
||||
bodyTimeoutMs: deps.bodyTimeoutMs,
|
||||
});
|
||||
} finally {
|
||||
// Only bound the pre-auth request pipeline; async reply delivery is outside webhook ingress.
|
||||
|
||||
@@ -20,7 +20,7 @@ function createHarness(config: Record<string, unknown>) {
|
||||
command = definition;
|
||||
}),
|
||||
};
|
||||
void register.register(api as never);
|
||||
register.register(api as never);
|
||||
if (!command) {
|
||||
throw new Error("talk-voice command not registered");
|
||||
}
|
||||
|
||||
@@ -800,6 +800,82 @@ describe("dispatchTelegramMessage draft streaming", () => {
|
||||
);
|
||||
});
|
||||
|
||||
it("preserves pre-rotation skip until queued message-start callbacks flush", async () => {
|
||||
const answerDraftStream = createSequencedDraftStream(1001);
|
||||
const reasoningDraftStream = createDraftStream();
|
||||
createTelegramDraftStream
|
||||
.mockImplementationOnce(() => answerDraftStream)
|
||||
.mockImplementationOnce(() => reasoningDraftStream);
|
||||
dispatchReplyWithBufferedBlockDispatcher.mockImplementation(
|
||||
async ({ dispatcherOptions, replyOptions }) => {
|
||||
await replyOptions?.onPartialReply?.({ text: "Message A partial" });
|
||||
await dispatcherOptions.deliver({ text: "Message A final" }, { kind: "final" });
|
||||
await replyOptions?.onPartialReply?.({ text: "Message B early" });
|
||||
void replyOptions?.onAssistantMessageStart?.();
|
||||
await dispatcherOptions.deliver({ text: "Message B final" }, { kind: "final" });
|
||||
return { queuedFinal: true };
|
||||
},
|
||||
);
|
||||
deliverReplies.mockResolvedValue({ delivered: true });
|
||||
editMessageTelegram.mockResolvedValue({ ok: true, chatId: "123", messageId: "1001" });
|
||||
|
||||
await dispatchWithContext({ context: createContext(), streamMode: "partial" });
|
||||
|
||||
expect(answerDraftStream.forceNewMessage).toHaveBeenCalledTimes(1);
|
||||
expect(editMessageTelegram).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
123,
|
||||
1001,
|
||||
"Message A final",
|
||||
expect.any(Object),
|
||||
);
|
||||
expect(editMessageTelegram).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
123,
|
||||
1002,
|
||||
"Message B final",
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it("does not double-rotate when assistant_message_start arrives after final delivery drains", async () => {
|
||||
const answerDraftStream = createSequencedDraftStream(1001);
|
||||
const reasoningDraftStream = createDraftStream();
|
||||
createTelegramDraftStream
|
||||
.mockImplementationOnce(() => answerDraftStream)
|
||||
.mockImplementationOnce(() => reasoningDraftStream);
|
||||
dispatchReplyWithBufferedBlockDispatcher.mockImplementation(
|
||||
async ({ dispatcherOptions, replyOptions }) => {
|
||||
await replyOptions?.onPartialReply?.({ text: "Message A partial" });
|
||||
await dispatcherOptions.deliver({ text: "Message A final" }, { kind: "final" });
|
||||
await replyOptions?.onPartialReply?.({ text: "Message B early" });
|
||||
await dispatcherOptions.deliver({ text: "Message B final" }, { kind: "final" });
|
||||
await replyOptions?.onAssistantMessageStart?.();
|
||||
return { queuedFinal: true };
|
||||
},
|
||||
);
|
||||
deliverReplies.mockResolvedValue({ delivered: true });
|
||||
editMessageTelegram.mockResolvedValue({ ok: true, chatId: "123", messageId: "1001" });
|
||||
|
||||
await dispatchWithContext({ context: createContext(), streamMode: "partial" });
|
||||
|
||||
expect(answerDraftStream.forceNewMessage).toHaveBeenCalledTimes(1);
|
||||
expect(editMessageTelegram).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
123,
|
||||
1001,
|
||||
"Message A final",
|
||||
expect.any(Object),
|
||||
);
|
||||
expect(editMessageTelegram).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
123,
|
||||
1002,
|
||||
"Message B final",
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it("clears active preview even when an unrelated boundary archive exists", async () => {
|
||||
const answerDraftStream = createDraftStream(999);
|
||||
answerDraftStream.materialize.mockResolvedValue(4321);
|
||||
@@ -1054,6 +1130,204 @@ describe("dispatchTelegramMessage draft streaming", () => {
|
||||
expect(answerDraftStream.update).toHaveBeenNthCalledWith(2, "Message B second chunk");
|
||||
});
|
||||
|
||||
it("does not rotate the streamed preview when compaction retries replay the same assistant message", async () => {
|
||||
const answerDraftStream = createSequencedDraftStream(1001);
|
||||
const reasoningDraftStream = createDraftStream();
|
||||
createTelegramDraftStream
|
||||
.mockImplementationOnce(() => answerDraftStream)
|
||||
.mockImplementationOnce(() => reasoningDraftStream);
|
||||
dispatchReplyWithBufferedBlockDispatcher.mockImplementation(
|
||||
async ({ dispatcherOptions, replyOptions }) => {
|
||||
await replyOptions?.onPartialReply?.({ text: "Message A partial" });
|
||||
await replyOptions?.onCompactionStart?.();
|
||||
await replyOptions?.onCompactionEnd?.();
|
||||
await replyOptions?.onAssistantMessageStart?.();
|
||||
await replyOptions?.onPartialReply?.({ text: "Message A partial" });
|
||||
await replyOptions?.onPartialReply?.({ text: "Message A partial extended" });
|
||||
await dispatcherOptions.deliver({ text: "Message A final" }, { kind: "final" });
|
||||
return { queuedFinal: true };
|
||||
},
|
||||
);
|
||||
deliverReplies.mockResolvedValue({ delivered: true });
|
||||
editMessageTelegram.mockResolvedValue({ ok: true, chatId: "123", messageId: "1001" });
|
||||
|
||||
await dispatchWithContext({ context: createContext(), streamMode: "partial" });
|
||||
|
||||
expect(answerDraftStream.forceNewMessage).not.toHaveBeenCalled();
|
||||
expect(answerDraftStream.materialize).not.toHaveBeenCalled();
|
||||
expect(editMessageTelegram).toHaveBeenCalledTimes(1);
|
||||
expect(editMessageTelegram).toHaveBeenCalledWith(
|
||||
123,
|
||||
1001,
|
||||
"Message A final",
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it("clears the compaction replay skip after the retried message finalizes", async () => {
|
||||
const answerDraftStream = createSequencedDraftStream(1001);
|
||||
const reasoningDraftStream = createDraftStream();
|
||||
createTelegramDraftStream
|
||||
.mockImplementationOnce(() => answerDraftStream)
|
||||
.mockImplementationOnce(() => reasoningDraftStream);
|
||||
dispatchReplyWithBufferedBlockDispatcher.mockImplementation(
|
||||
async ({ dispatcherOptions, replyOptions }) => {
|
||||
await replyOptions?.onPartialReply?.({ text: "Message A partial" });
|
||||
await replyOptions?.onCompactionStart?.();
|
||||
await replyOptions?.onCompactionEnd?.();
|
||||
await replyOptions?.onAssistantMessageStart?.();
|
||||
await replyOptions?.onPartialReply?.({ text: "Message A partial extended" });
|
||||
await dispatcherOptions.deliver({ text: "Message A final" }, { kind: "final" });
|
||||
await replyOptions?.onAssistantMessageStart?.();
|
||||
await replyOptions?.onPartialReply?.({ text: "Message B partial" });
|
||||
await dispatcherOptions.deliver({ text: "Message B final" }, { kind: "final" });
|
||||
return { queuedFinal: true };
|
||||
},
|
||||
);
|
||||
deliverReplies.mockResolvedValue({ delivered: true });
|
||||
editMessageTelegram.mockResolvedValue({ ok: true, chatId: "123", messageId: "1001" });
|
||||
|
||||
await dispatchWithContext({ context: createContext(), streamMode: "partial" });
|
||||
|
||||
expect(answerDraftStream.forceNewMessage).toHaveBeenCalledTimes(1);
|
||||
expect(editMessageTelegram).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
123,
|
||||
1001,
|
||||
"Message A final",
|
||||
expect.any(Object),
|
||||
);
|
||||
expect(editMessageTelegram).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
123,
|
||||
1002,
|
||||
"Message B final",
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it("preserves the compaction replay flag until queued retry callbacks flush", async () => {
|
||||
const answerDraftStream = createSequencedDraftStream(1001);
|
||||
const reasoningDraftStream = createDraftStream();
|
||||
createTelegramDraftStream
|
||||
.mockImplementationOnce(() => answerDraftStream)
|
||||
.mockImplementationOnce(() => reasoningDraftStream);
|
||||
dispatchReplyWithBufferedBlockDispatcher.mockImplementation(
|
||||
async ({ dispatcherOptions, replyOptions }) => {
|
||||
await replyOptions?.onPartialReply?.({ text: "Message A partial" });
|
||||
await replyOptions?.onCompactionStart?.();
|
||||
await replyOptions?.onCompactionEnd?.();
|
||||
void replyOptions?.onAssistantMessageStart?.();
|
||||
await dispatcherOptions.deliver({ text: "Message A final" }, { kind: "final" });
|
||||
return { queuedFinal: true };
|
||||
},
|
||||
);
|
||||
deliverReplies.mockResolvedValue({ delivered: true });
|
||||
editMessageTelegram.mockResolvedValue({ ok: true, chatId: "123", messageId: "1001" });
|
||||
|
||||
await dispatchWithContext({ context: createContext(), streamMode: "partial" });
|
||||
|
||||
expect(answerDraftStream.forceNewMessage).not.toHaveBeenCalled();
|
||||
expect(editMessageTelegram).toHaveBeenCalledTimes(1);
|
||||
expect(editMessageTelegram).toHaveBeenCalledWith(
|
||||
123,
|
||||
1001,
|
||||
"Message A final",
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it("keeps the existing preview when the retried answer only arrives as final text", async () => {
|
||||
const answerDraftStream = createSequencedDraftStream(1001);
|
||||
const reasoningDraftStream = createDraftStream();
|
||||
createTelegramDraftStream
|
||||
.mockImplementationOnce(() => answerDraftStream)
|
||||
.mockImplementationOnce(() => reasoningDraftStream);
|
||||
dispatchReplyWithBufferedBlockDispatcher.mockImplementation(
|
||||
async ({ dispatcherOptions, replyOptions }) => {
|
||||
await replyOptions?.onPartialReply?.({ text: "Message A partial" });
|
||||
await replyOptions?.onCompactionStart?.();
|
||||
await replyOptions?.onCompactionEnd?.();
|
||||
await replyOptions?.onAssistantMessageStart?.();
|
||||
await dispatcherOptions.deliver({ text: "Message B final" }, { kind: "final" });
|
||||
return { queuedFinal: true };
|
||||
},
|
||||
);
|
||||
deliverReplies.mockResolvedValue({ delivered: true });
|
||||
editMessageTelegram.mockResolvedValue({ ok: true, chatId: "123", messageId: "1001" });
|
||||
|
||||
await dispatchWithContext({ context: createContext(), streamMode: "partial" });
|
||||
|
||||
expect(answerDraftStream.forceNewMessage).not.toHaveBeenCalled();
|
||||
expect(answerDraftStream.materialize).not.toHaveBeenCalled();
|
||||
expect(editMessageTelegram).toHaveBeenCalledTimes(1);
|
||||
expect(editMessageTelegram).toHaveBeenCalledWith(
|
||||
123,
|
||||
1001,
|
||||
"Message B final",
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it("keeps the transient preview when a local exec approval prompt is suppressed after compaction", async () => {
|
||||
const answerDraftStream = createSequencedDraftStream(1001);
|
||||
const reasoningDraftStream = createDraftStream();
|
||||
createTelegramDraftStream
|
||||
.mockImplementationOnce(() => answerDraftStream)
|
||||
.mockImplementationOnce(() => reasoningDraftStream);
|
||||
dispatchReplyWithBufferedBlockDispatcher.mockImplementation(
|
||||
async ({ dispatcherOptions, replyOptions }) => {
|
||||
await replyOptions?.onPartialReply?.({ text: "Message A partial" });
|
||||
await replyOptions?.onCompactionStart?.();
|
||||
await replyOptions?.onCompactionEnd?.();
|
||||
await dispatcherOptions.deliver(
|
||||
{
|
||||
text: "Approval required.\n\n```txt\n/approve 7f423fdc allow-once\n```",
|
||||
channelData: {
|
||||
execApproval: {
|
||||
approvalId: "7f423fdc-1111-2222-3333-444444444444",
|
||||
approvalSlug: "7f423fdc",
|
||||
allowedDecisions: ["allow-once", "allow-always", "deny"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{ kind: "tool" },
|
||||
);
|
||||
await replyOptions?.onAssistantMessageStart?.();
|
||||
await replyOptions?.onPartialReply?.({ text: "Message B partial" });
|
||||
await dispatcherOptions.deliver({ text: "Message B final" }, { kind: "final" });
|
||||
return { queuedFinal: true };
|
||||
},
|
||||
);
|
||||
deliverReplies.mockResolvedValue({ delivered: true });
|
||||
editMessageTelegram.mockResolvedValue({ ok: true, chatId: "123", messageId: "1001" });
|
||||
|
||||
await dispatchWithContext({
|
||||
context: createContext(),
|
||||
streamMode: "partial",
|
||||
cfg: {
|
||||
channels: {
|
||||
telegram: {
|
||||
execApprovals: {
|
||||
enabled: true,
|
||||
approvers: ["12345"],
|
||||
target: "dm",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
expect(answerDraftStream.forceNewMessage).not.toHaveBeenCalled();
|
||||
expect(editMessageTelegram).toHaveBeenCalledTimes(1);
|
||||
expect(editMessageTelegram).toHaveBeenCalledWith(
|
||||
123,
|
||||
1001,
|
||||
"Message B final",
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it("finalizes multi-message assistant stream to matching preview messages in order", async () => {
|
||||
const answerDraftStream = createSequencedDraftStream(1001);
|
||||
const reasoningDraftStream = createDraftStream();
|
||||
|
||||
@@ -280,6 +280,10 @@ export const dispatchTelegramMessage = async ({
|
||||
const reasoningLane = lanes.reasoning;
|
||||
let splitReasoningOnNextStream = false;
|
||||
let skipNextAnswerMessageStartRotation = false;
|
||||
// If compaction interrupts a still-transient answer preview, keep the next
|
||||
// assistant-message boundary on that same preview instead of materializing a
|
||||
// duplicate retry message.
|
||||
let pendingCompactionReplayBoundary = false;
|
||||
let draftLaneEventQueue = Promise.resolve();
|
||||
const reasoningStepState = createTelegramReasoningStepState();
|
||||
const enqueueDraftLaneEvent = (task: () => Promise<void>): Promise<void> => {
|
||||
@@ -693,6 +697,9 @@ export const dispatchTelegramMessage = async ({
|
||||
}
|
||||
}
|
||||
if (segments.length > 0) {
|
||||
if (info.kind === "final") {
|
||||
pendingCompactionReplayBoundary = false;
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (split.suppressedReasoningOnly) {
|
||||
@@ -703,6 +710,7 @@ export const dispatchTelegramMessage = async ({
|
||||
}
|
||||
if (info.kind === "final") {
|
||||
await flushBufferedFinalAnswer();
|
||||
pendingCompactionReplayBoundary = false;
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -716,12 +724,14 @@ export const dispatchTelegramMessage = async ({
|
||||
if (!canSendAsIs) {
|
||||
if (info.kind === "final") {
|
||||
await flushBufferedFinalAnswer();
|
||||
pendingCompactionReplayBoundary = false;
|
||||
}
|
||||
return;
|
||||
}
|
||||
await sendPayload(payload);
|
||||
if (info.kind === "final") {
|
||||
await flushBufferedFinalAnswer();
|
||||
pendingCompactionReplayBoundary = false;
|
||||
}
|
||||
},
|
||||
onSkip: (payload, info) => {
|
||||
@@ -793,6 +803,12 @@ export const dispatchTelegramMessage = async ({
|
||||
retainPreviewOnCleanupByLane.answer = false;
|
||||
return;
|
||||
}
|
||||
if (pendingCompactionReplayBoundary) {
|
||||
pendingCompactionReplayBoundary = false;
|
||||
activePreviewLifecycleByLane.answer = "transient";
|
||||
retainPreviewOnCleanupByLane.answer = false;
|
||||
return;
|
||||
}
|
||||
await rotateAnswerLaneForNewAssistantMessage();
|
||||
// Message-start is an explicit assistant-message boundary.
|
||||
// Even when no forceNewMessage happened (e.g. prior answer had no
|
||||
@@ -817,9 +833,20 @@ export const dispatchTelegramMessage = async ({
|
||||
}
|
||||
}
|
||||
: undefined,
|
||||
onCompactionStart: statusReactionController
|
||||
? () => statusReactionController.setCompacting()
|
||||
: undefined,
|
||||
onCompactionStart:
|
||||
statusReactionController || answerLane.stream
|
||||
? async () => {
|
||||
if (
|
||||
answerLane.hasStreamedMessage &&
|
||||
activePreviewLifecycleByLane.answer === "transient"
|
||||
) {
|
||||
pendingCompactionReplayBoundary = true;
|
||||
}
|
||||
if (statusReactionController) {
|
||||
await statusReactionController.setCompacting();
|
||||
}
|
||||
}
|
||||
: undefined,
|
||||
onCompactionEnd: statusReactionController
|
||||
? async () => {
|
||||
statusReactionController.cancelPending();
|
||||
|
||||
@@ -7,6 +7,18 @@ import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { importFreshModule } from "../../../test/helpers/import-fresh.js";
|
||||
|
||||
const writeJsonFileAtomicallyMock = vi.hoisted(() => vi.fn());
|
||||
const readAcpSessionEntryMock = vi.hoisted(() => vi.fn());
|
||||
|
||||
vi.mock("openclaw/plugin-sdk/acp-runtime", async () => {
|
||||
const actual = await vi.importActual<typeof import("openclaw/plugin-sdk/acp-runtime")>(
|
||||
"openclaw/plugin-sdk/acp-runtime",
|
||||
);
|
||||
readAcpSessionEntryMock.mockImplementation(actual.readAcpSessionEntry);
|
||||
return {
|
||||
...actual,
|
||||
readAcpSessionEntry: readAcpSessionEntryMock,
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock("openclaw/plugin-sdk/json-store", async () => {
|
||||
const actual = await vi.importActual<typeof import("openclaw/plugin-sdk/json-store")>(
|
||||
@@ -36,6 +48,11 @@ describe("telegram thread bindings", () => {
|
||||
|
||||
beforeEach(async () => {
|
||||
writeJsonFileAtomicallyMock.mockClear();
|
||||
readAcpSessionEntryMock.mockReset();
|
||||
const acpRuntime = await vi.importActual<typeof import("openclaw/plugin-sdk/acp-runtime")>(
|
||||
"openclaw/plugin-sdk/acp-runtime",
|
||||
);
|
||||
readAcpSessionEntryMock.mockImplementation(acpRuntime.readAcpSessionEntry);
|
||||
await __testing.resetTelegramThreadBindingsForTests();
|
||||
});
|
||||
|
||||
@@ -293,6 +310,136 @@ describe("telegram thread bindings", () => {
|
||||
expect(reloaded.getByConversationId("8460800771")).toBeUndefined();
|
||||
});
|
||||
|
||||
it("cleans up stale ACP bindings before restart routing can reuse them", async () => {
|
||||
stateDirOverride = fs.mkdtempSync(path.join(os.tmpdir(), "openclaw-telegram-bindings-"));
|
||||
process.env.OPENCLAW_STATE_DIR = stateDirOverride;
|
||||
|
||||
createTelegramThreadBindingManager({
|
||||
accountId: "default",
|
||||
persist: true,
|
||||
enableSweeper: false,
|
||||
});
|
||||
|
||||
await getSessionBindingService().bind({
|
||||
targetSessionKey: "agent:main:acp:stale-1",
|
||||
targetKind: "session",
|
||||
conversation: {
|
||||
channel: "telegram",
|
||||
accountId: "default",
|
||||
conversationId: "cleanup-me",
|
||||
},
|
||||
});
|
||||
|
||||
await __testing.resetTelegramThreadBindingsForTests();
|
||||
readAcpSessionEntryMock.mockReturnValue({
|
||||
cfg: {} as never,
|
||||
storePath: "/tmp/acp-store.json",
|
||||
sessionKey: "agent:main:acp:stale-1",
|
||||
storeSessionKey: "agent:main:acp:stale-1",
|
||||
entry: undefined,
|
||||
acp: undefined,
|
||||
storeReadFailed: false,
|
||||
});
|
||||
|
||||
const reloaded = createTelegramThreadBindingManager({
|
||||
accountId: "default",
|
||||
persist: true,
|
||||
enableSweeper: false,
|
||||
});
|
||||
|
||||
expect(reloaded.getByConversationId("cleanup-me")).toBeUndefined();
|
||||
await __testing.resetTelegramThreadBindingsForTests();
|
||||
const persisted = JSON.parse(
|
||||
fs.readFileSync(
|
||||
path.join(
|
||||
resolveStateDir(process.env, os.homedir),
|
||||
"telegram",
|
||||
"thread-bindings-default.json",
|
||||
),
|
||||
"utf8",
|
||||
),
|
||||
) as { bindings?: Array<{ conversationId?: string }> };
|
||||
expect(persisted.bindings?.map((binding) => binding.conversationId)).not.toContain(
|
||||
"cleanup-me",
|
||||
);
|
||||
});
|
||||
|
||||
it("keeps plugin-owned bindings when ACP cleanup runs on startup", async () => {
|
||||
stateDirOverride = fs.mkdtempSync(path.join(os.tmpdir(), "openclaw-telegram-bindings-"));
|
||||
process.env.OPENCLAW_STATE_DIR = stateDirOverride;
|
||||
|
||||
createTelegramThreadBindingManager({
|
||||
accountId: "default",
|
||||
persist: true,
|
||||
enableSweeper: false,
|
||||
});
|
||||
|
||||
await getSessionBindingService().bind({
|
||||
targetSessionKey: "plugin-binding:openclaw-codex-app-server:still-valid",
|
||||
targetKind: "session",
|
||||
conversation: {
|
||||
channel: "telegram",
|
||||
accountId: "default",
|
||||
conversationId: "plugin-binding-convo",
|
||||
},
|
||||
});
|
||||
|
||||
await __testing.resetTelegramThreadBindingsForTests();
|
||||
|
||||
const reloaded = createTelegramThreadBindingManager({
|
||||
accountId: "default",
|
||||
persist: true,
|
||||
enableSweeper: false,
|
||||
});
|
||||
|
||||
expect(reloaded.getByConversationId("plugin-binding-convo")?.targetSessionKey).toBe(
|
||||
"plugin-binding:openclaw-codex-app-server:still-valid",
|
||||
);
|
||||
expect(readAcpSessionEntryMock).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("keeps ACP bindings when the session store cannot be read during startup cleanup", async () => {
|
||||
stateDirOverride = fs.mkdtempSync(path.join(os.tmpdir(), "openclaw-telegram-bindings-"));
|
||||
process.env.OPENCLAW_STATE_DIR = stateDirOverride;
|
||||
|
||||
createTelegramThreadBindingManager({
|
||||
accountId: "default",
|
||||
persist: true,
|
||||
enableSweeper: false,
|
||||
});
|
||||
|
||||
await getSessionBindingService().bind({
|
||||
targetSessionKey: "agent:main:acp:read-failed",
|
||||
targetKind: "session",
|
||||
conversation: {
|
||||
channel: "telegram",
|
||||
accountId: "default",
|
||||
conversationId: "keep-on-read-failure",
|
||||
},
|
||||
});
|
||||
|
||||
await __testing.resetTelegramThreadBindingsForTests();
|
||||
readAcpSessionEntryMock.mockReturnValue({
|
||||
cfg: {} as never,
|
||||
storePath: "/tmp/acp-store.json",
|
||||
sessionKey: "agent:main:acp:read-failed",
|
||||
storeSessionKey: "agent:main:acp:read-failed",
|
||||
entry: undefined,
|
||||
acp: undefined,
|
||||
storeReadFailed: true,
|
||||
});
|
||||
|
||||
const reloaded = createTelegramThreadBindingManager({
|
||||
accountId: "default",
|
||||
persist: true,
|
||||
enableSweeper: false,
|
||||
});
|
||||
|
||||
expect(reloaded.getByConversationId("keep-on-read-failure")?.targetSessionKey).toBe(
|
||||
"agent:main:acp:read-failed",
|
||||
);
|
||||
});
|
||||
|
||||
it("flushes pending lifecycle update persists before test reset", async () => {
|
||||
stateDirOverride = fs.mkdtempSync(path.join(os.tmpdir(), "openclaw-telegram-bindings-"));
|
||||
process.env.OPENCLAW_STATE_DIR = stateDirOverride;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import fs from "node:fs";
|
||||
import os from "node:os";
|
||||
import path from "node:path";
|
||||
import { readAcpSessionEntry } from "openclaw/plugin-sdk/acp-runtime";
|
||||
import { loadConfig } from "openclaw/plugin-sdk/config-runtime";
|
||||
import {
|
||||
formatThreadBindingDurationLabel,
|
||||
@@ -14,7 +15,7 @@ import {
|
||||
} from "openclaw/plugin-sdk/conversation-runtime";
|
||||
import { formatErrorMessage } from "openclaw/plugin-sdk/error-runtime";
|
||||
import { writeJsonFileAtomically } from "openclaw/plugin-sdk/json-store";
|
||||
import { normalizeAccountId } from "openclaw/plugin-sdk/routing";
|
||||
import { normalizeAccountId, isAcpSessionKey } from "openclaw/plugin-sdk/routing";
|
||||
import { logVerbose } from "openclaw/plugin-sdk/runtime-env";
|
||||
import { resolveStateDir } from "openclaw/plugin-sdk/state-paths";
|
||||
import { normalizeOptionalString } from "openclaw/plugin-sdk/text-runtime";
|
||||
@@ -440,6 +441,58 @@ export function createTelegramThreadBindingManager(
|
||||
});
|
||||
}
|
||||
|
||||
const acpSessionKeys = new Set<string>();
|
||||
for (const binding of getThreadBindingsState().bindingsByAccountConversation.values()) {
|
||||
if (binding.targetKind !== "acp" || !isAcpSessionKey(binding.targetSessionKey)) {
|
||||
continue;
|
||||
}
|
||||
acpSessionKeys.add(binding.targetSessionKey);
|
||||
}
|
||||
|
||||
const staleSessionKeys = new Set<string>();
|
||||
for (const targetSessionKey of acpSessionKeys) {
|
||||
const sessionEntry = readAcpSessionEntry({ sessionKey: targetSessionKey });
|
||||
if (!sessionEntry || sessionEntry.storeReadFailed) {
|
||||
continue;
|
||||
}
|
||||
const isStale =
|
||||
!sessionEntry.entry ||
|
||||
sessionEntry.entry.status === "failed" ||
|
||||
sessionEntry.entry.status === "killed" ||
|
||||
sessionEntry.entry.status === "timeout" ||
|
||||
sessionEntry.entry.acp?.state === "error";
|
||||
if (isStale) {
|
||||
staleSessionKeys.add(targetSessionKey);
|
||||
}
|
||||
}
|
||||
|
||||
let needsPersist = false;
|
||||
for (const sessionKey of staleSessionKeys) {
|
||||
const bindingsToRemove = listBindingsForAccount(accountId).filter(
|
||||
(b) => b.targetSessionKey === sessionKey,
|
||||
);
|
||||
for (const binding of bindingsToRemove) {
|
||||
getThreadBindingsState().bindingsByAccountConversation.delete(
|
||||
resolveBindingKey({ accountId, conversationId: binding.conversationId }),
|
||||
);
|
||||
}
|
||||
if (bindingsToRemove.length > 0) {
|
||||
needsPersist = true;
|
||||
logVerbose(
|
||||
`telegram thread binding: cleaned up ${bindingsToRemove.length} stale binding(s) for session ${sessionKey}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (needsPersist && persist) {
|
||||
persistBindingsSafely({
|
||||
accountId,
|
||||
persist: true,
|
||||
bindings: listBindingsForAccount(accountId),
|
||||
reason: "cleanup-stale",
|
||||
});
|
||||
}
|
||||
|
||||
let sweepTimer: NodeJS.Timeout | null = null;
|
||||
|
||||
const manager: TelegramThreadBindingManager = {
|
||||
|
||||
@@ -40,8 +40,8 @@ describe("thread-ownership plugin", () => {
|
||||
});
|
||||
|
||||
describe("message_sending", () => {
|
||||
beforeEach(async () => {
|
||||
await register.register(api as unknown as OpenClawPluginApi);
|
||||
beforeEach(() => {
|
||||
register.register(api as unknown as OpenClawPluginApi);
|
||||
});
|
||||
|
||||
async function sendSlackThreadMessage() {
|
||||
@@ -112,8 +112,8 @@ describe("thread-ownership plugin", () => {
|
||||
});
|
||||
|
||||
describe("message_received @-mention tracking", () => {
|
||||
beforeEach(async () => {
|
||||
await register.register(api as unknown as OpenClawPluginApi);
|
||||
beforeEach(() => {
|
||||
register.register(api as unknown as OpenClawPluginApi);
|
||||
});
|
||||
|
||||
it("tracks @-mentions and skips ownership check for mentioned threads", async () => {
|
||||
|
||||
@@ -1,82 +1,65 @@
|
||||
import { describe, expect, it, vi, afterEach, beforeEach } from "vitest";
|
||||
import { describe, expect, it, vi, beforeEach } from "vitest";
|
||||
import { fetchWithSsrFGuard } from "../../runtime-api.js";
|
||||
import { uploadFile } from "../tlon-api.js";
|
||||
import { uploadImageFromUrl } from "./upload.js";
|
||||
|
||||
// Mock fetchWithSsrFGuard from the local runtime seam.
|
||||
vi.mock("../../runtime-api.js", async () => {
|
||||
const actual =
|
||||
await vi.importActual<typeof import("../../runtime-api.js")>("../../runtime-api.js");
|
||||
return {
|
||||
...actual,
|
||||
fetchWithSsrFGuard: vi.fn(),
|
||||
};
|
||||
});
|
||||
vi.mock("../../runtime-api.js", () => ({
|
||||
fetchWithSsrFGuard: vi.fn(),
|
||||
}));
|
||||
|
||||
// Mock the local Tlon upload seam.
|
||||
vi.mock("../tlon-api.js", () => ({
|
||||
uploadFile: vi.fn(),
|
||||
}));
|
||||
|
||||
const mockFetch = vi.mocked(fetchWithSsrFGuard);
|
||||
const mockUploadFile = vi.mocked(uploadFile);
|
||||
|
||||
type FetchMock = typeof mockFetch;
|
||||
|
||||
function mockSuccessfulFetch(params: {
|
||||
mockFetch: FetchMock;
|
||||
blob: Blob;
|
||||
finalUrl: string;
|
||||
contentType: string;
|
||||
}) {
|
||||
params.mockFetch.mockResolvedValue({
|
||||
response: {
|
||||
ok: true,
|
||||
headers: new Headers({ "content-type": params.contentType }),
|
||||
blob: () => Promise.resolve(params.blob),
|
||||
} as unknown as Response,
|
||||
finalUrl: params.finalUrl,
|
||||
release: vi.fn().mockResolvedValue(undefined),
|
||||
});
|
||||
}
|
||||
|
||||
async function setupSuccessfulUpload(params?: {
|
||||
sourceUrl?: string;
|
||||
contentType?: string;
|
||||
uploadedUrl?: string;
|
||||
}) {
|
||||
const sourceUrl = params?.sourceUrl ?? "https://example.com/image.png";
|
||||
const contentType = params?.contentType ?? "image/png";
|
||||
const mockBlob = new Blob(["fake-image"], { type: contentType });
|
||||
mockSuccessfulFetch({
|
||||
mockFetch,
|
||||
blob: mockBlob,
|
||||
finalUrl: sourceUrl,
|
||||
contentType,
|
||||
});
|
||||
if (params?.uploadedUrl) {
|
||||
mockUploadFile.mockResolvedValue({ url: params.uploadedUrl });
|
||||
}
|
||||
return { mockBlob };
|
||||
}
|
||||
|
||||
describe("uploadImageFromUrl", () => {
|
||||
async function loadUploadMocks() {
|
||||
const { fetchWithSsrFGuard } = await import("../../runtime-api.js");
|
||||
const { uploadFile } = await import("../tlon-api.js");
|
||||
const { uploadImageFromUrl } = await import("./upload.js");
|
||||
return {
|
||||
mockFetch: vi.mocked(fetchWithSsrFGuard),
|
||||
mockUploadFile: vi.mocked(uploadFile),
|
||||
uploadImageFromUrl,
|
||||
};
|
||||
}
|
||||
|
||||
type UploadMocks = Awaited<ReturnType<typeof loadUploadMocks>>;
|
||||
|
||||
function mockSuccessfulFetch(params: {
|
||||
mockFetch: UploadMocks["mockFetch"];
|
||||
blob: Blob;
|
||||
finalUrl: string;
|
||||
contentType: string;
|
||||
}) {
|
||||
params.mockFetch.mockResolvedValue({
|
||||
response: {
|
||||
ok: true,
|
||||
headers: new Headers({ "content-type": params.contentType }),
|
||||
blob: () => Promise.resolve(params.blob),
|
||||
} as unknown as Response,
|
||||
finalUrl: params.finalUrl,
|
||||
release: vi.fn().mockResolvedValue(undefined),
|
||||
});
|
||||
}
|
||||
|
||||
async function setupSuccessfulUpload(params?: {
|
||||
sourceUrl?: string;
|
||||
contentType?: string;
|
||||
uploadedUrl?: string;
|
||||
}) {
|
||||
const { mockFetch, mockUploadFile, uploadImageFromUrl } = await loadUploadMocks();
|
||||
const sourceUrl = params?.sourceUrl ?? "https://example.com/image.png";
|
||||
const contentType = params?.contentType ?? "image/png";
|
||||
const mockBlob = new Blob(["fake-image"], { type: contentType });
|
||||
mockSuccessfulFetch({
|
||||
mockFetch,
|
||||
blob: mockBlob,
|
||||
finalUrl: sourceUrl,
|
||||
contentType,
|
||||
});
|
||||
if (params?.uploadedUrl) {
|
||||
mockUploadFile.mockResolvedValue({ url: params.uploadedUrl });
|
||||
}
|
||||
return { mockBlob, mockUploadFile, uploadImageFromUrl };
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it("fetches image and calls uploadFile, returns uploaded URL", async () => {
|
||||
const { mockBlob, mockUploadFile, uploadImageFromUrl } = await setupSuccessfulUpload({
|
||||
const { mockBlob } = await setupSuccessfulUpload({
|
||||
uploadedUrl: "https://memex.tlon.network/uploaded.png",
|
||||
});
|
||||
|
||||
@@ -93,8 +76,6 @@ describe("uploadImageFromUrl", () => {
|
||||
});
|
||||
|
||||
it("returns original URL if fetch fails", async () => {
|
||||
const { mockFetch, uploadImageFromUrl } = await loadUploadMocks();
|
||||
|
||||
mockFetch.mockResolvedValue({
|
||||
response: {
|
||||
ok: false,
|
||||
@@ -110,7 +91,7 @@ describe("uploadImageFromUrl", () => {
|
||||
});
|
||||
|
||||
it("returns original URL if upload fails", async () => {
|
||||
const { mockUploadFile, uploadImageFromUrl } = await setupSuccessfulUpload();
|
||||
await setupSuccessfulUpload();
|
||||
mockUploadFile.mockRejectedValue(new Error("Upload failed"));
|
||||
|
||||
const result = await uploadImageFromUrl("https://example.com/image.png");
|
||||
@@ -119,28 +100,19 @@ describe("uploadImageFromUrl", () => {
|
||||
});
|
||||
|
||||
it("rejects non-http(s) URLs", async () => {
|
||||
const { uploadImageFromUrl } = await import("./upload.js");
|
||||
|
||||
// file:// URL should be rejected
|
||||
const result = await uploadImageFromUrl("file:///etc/passwd");
|
||||
expect(result).toBe("file:///etc/passwd");
|
||||
|
||||
// ftp:// URL should be rejected
|
||||
const result2 = await uploadImageFromUrl("ftp://example.com/image.png");
|
||||
expect(result2).toBe("ftp://example.com/image.png");
|
||||
});
|
||||
|
||||
it("handles invalid URLs gracefully", async () => {
|
||||
const { uploadImageFromUrl } = await import("./upload.js");
|
||||
|
||||
// Invalid URL should return original
|
||||
const result = await uploadImageFromUrl("not-a-valid-url");
|
||||
expect(result).toBe("not-a-valid-url");
|
||||
});
|
||||
|
||||
it("extracts filename from URL path", async () => {
|
||||
const { mockFetch, mockUploadFile, uploadImageFromUrl } = await loadUploadMocks();
|
||||
|
||||
const mockBlob = new Blob(["fake-image"], { type: "image/jpeg" });
|
||||
mockSuccessfulFetch({
|
||||
mockFetch,
|
||||
@@ -161,8 +133,6 @@ describe("uploadImageFromUrl", () => {
|
||||
});
|
||||
|
||||
it("uses default filename when URL has no path", async () => {
|
||||
const { mockFetch, mockUploadFile, uploadImageFromUrl } = await loadUploadMocks();
|
||||
|
||||
const mockBlob = new Blob(["fake-image"], { type: "image/png" });
|
||||
mockSuccessfulFetch({
|
||||
mockFetch,
|
||||
|
||||
@@ -37,7 +37,7 @@ type Registered = {
|
||||
methods: Map<string, unknown>;
|
||||
tools: unknown[];
|
||||
};
|
||||
type RegisterVoiceCall = (api: Record<string, unknown>) => void | Promise<void>;
|
||||
type RegisterVoiceCall = (api: Record<string, unknown>) => void;
|
||||
type RegisterCliContext = {
|
||||
program: Command;
|
||||
config: Record<string, unknown>;
|
||||
@@ -83,7 +83,7 @@ async function registerVoiceCallCli(program: Command) {
|
||||
const { register } = plugin as unknown as {
|
||||
register: RegisterVoiceCall;
|
||||
};
|
||||
await register({
|
||||
register({
|
||||
id: "voice-call",
|
||||
name: "Voice Call",
|
||||
description: "test",
|
||||
|
||||
@@ -19,8 +19,8 @@ import {
|
||||
type ProviderBatchOutputLine,
|
||||
uploadBatchJsonlFile,
|
||||
withRemoteHttpResponse,
|
||||
} from "./batch-embedding-common.js";
|
||||
import type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import type { VoyageEmbeddingClient } from "./embedding-provider.js";
|
||||
|
||||
/**
|
||||
* Voyage Batch API Input Line format.
|
||||
@@ -1,8 +1,11 @@
|
||||
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
|
||||
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
|
||||
import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js";
|
||||
import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.types.js";
|
||||
import {
|
||||
fetchRemoteEmbeddingVectors,
|
||||
normalizeEmbeddingModelWithPrefixes,
|
||||
resolveRemoteEmbeddingBearerClient,
|
||||
type MemoryEmbeddingProvider,
|
||||
type MemoryEmbeddingProviderCreateOptions,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import type { SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime";
|
||||
|
||||
export type VoyageEmbeddingClient = {
|
||||
baseUrl: string;
|
||||
@@ -28,8 +31,8 @@ export function normalizeVoyageModel(model: string): string {
|
||||
}
|
||||
|
||||
export async function createVoyageEmbeddingProvider(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<{ provider: EmbeddingProvider; client: VoyageEmbeddingClient }> {
|
||||
options: MemoryEmbeddingProviderCreateOptions,
|
||||
): Promise<{ provider: MemoryEmbeddingProvider; client: VoyageEmbeddingClient }> {
|
||||
const client = await resolveVoyageEmbeddingClient(options);
|
||||
const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`;
|
||||
|
||||
@@ -70,7 +73,7 @@ export async function createVoyageEmbeddingProvider(
|
||||
}
|
||||
|
||||
export async function resolveVoyageEmbeddingClient(
|
||||
options: EmbeddingProviderOptions,
|
||||
options: MemoryEmbeddingProviderCreateOptions,
|
||||
): Promise<VoyageEmbeddingClient> {
|
||||
const { baseUrl, headers, ssrfPolicy } = await resolveRemoteEmbeddingBearerClient({
|
||||
provider: "voyage",
|
||||
11
extensions/voyage/index.ts
Normal file
11
extensions/voyage/index.ts
Normal file
@@ -0,0 +1,11 @@
|
||||
import { definePluginEntry } from "openclaw/plugin-sdk/plugin-entry";
|
||||
import { voyageMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js";
|
||||
|
||||
export default definePluginEntry({
|
||||
id: "voyage",
|
||||
name: "Voyage Embeddings",
|
||||
description: "Bundled Voyage memory embedding provider plugin",
|
||||
register(api) {
|
||||
api.registerMemoryEmbeddingProvider(voyageMemoryEmbeddingProviderAdapter);
|
||||
},
|
||||
});
|
||||
56
extensions/voyage/memory-embedding-adapter.ts
Normal file
56
extensions/voyage/memory-embedding-adapter.ts
Normal file
@@ -0,0 +1,56 @@
|
||||
import {
|
||||
isMissingEmbeddingApiKeyError,
|
||||
mapBatchEmbeddingsByIndex,
|
||||
sanitizeEmbeddingCacheHeaders,
|
||||
type MemoryEmbeddingProviderAdapter,
|
||||
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
|
||||
import { runVoyageEmbeddingBatches } from "./embedding-batch.js";
|
||||
import {
|
||||
createVoyageEmbeddingProvider,
|
||||
DEFAULT_VOYAGE_EMBEDDING_MODEL,
|
||||
} from "./embedding-provider.js";
|
||||
|
||||
export const voyageMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: "voyage",
|
||||
defaultModel: DEFAULT_VOYAGE_EMBEDDING_MODEL,
|
||||
transport: "remote",
|
||||
authProviderId: "voyage",
|
||||
autoSelectPriority: 40,
|
||||
allowExplicitWhenConfiguredAuto: true,
|
||||
shouldContinueAutoSelection: isMissingEmbeddingApiKeyError,
|
||||
create: async (options) => {
|
||||
const { provider, client } = await createVoyageEmbeddingProvider({
|
||||
...options,
|
||||
provider: "voyage",
|
||||
fallback: "none",
|
||||
});
|
||||
return {
|
||||
provider,
|
||||
runtime: {
|
||||
id: "voyage",
|
||||
cacheKeyData: {
|
||||
provider: "voyage",
|
||||
baseUrl: client.baseUrl,
|
||||
model: client.model,
|
||||
headers: sanitizeEmbeddingCacheHeaders(client.headers, ["authorization"]),
|
||||
},
|
||||
batchEmbed: async (batch) => {
|
||||
const byCustomId = await runVoyageEmbeddingBatches({
|
||||
client,
|
||||
agentId: batch.agentId,
|
||||
requests: batch.chunks.map((chunk, index) => ({
|
||||
custom_id: String(index),
|
||||
body: { input: chunk.text },
|
||||
})),
|
||||
wait: batch.wait,
|
||||
concurrency: batch.concurrency,
|
||||
pollIntervalMs: batch.pollIntervalMs,
|
||||
timeoutMs: batch.timeoutMs,
|
||||
debug: batch.debug,
|
||||
});
|
||||
return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length);
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
15
extensions/voyage/openclaw.plugin.json
Normal file
15
extensions/voyage/openclaw.plugin.json
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"id": "voyage",
|
||||
"enabledByDefault": true,
|
||||
"contracts": {
|
||||
"memoryEmbeddingProviders": ["voyage"]
|
||||
},
|
||||
"providerAuthEnvVars": {
|
||||
"voyage": ["VOYAGE_API_KEY"]
|
||||
},
|
||||
"configSchema": {
|
||||
"type": "object",
|
||||
"additionalProperties": false,
|
||||
"properties": {}
|
||||
}
|
||||
}
|
||||
15
extensions/voyage/package.json
Normal file
15
extensions/voyage/package.json
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"name": "@openclaw/voyage-provider",
|
||||
"version": "2026.4.15-beta.1",
|
||||
"private": true,
|
||||
"description": "OpenClaw Voyage embedding provider plugin",
|
||||
"type": "module",
|
||||
"devDependencies": {
|
||||
"@openclaw/plugin-sdk": "workspace:*"
|
||||
},
|
||||
"openclaw": {
|
||||
"extensions": [
|
||||
"./index.ts"
|
||||
]
|
||||
}
|
||||
}
|
||||
66
extensions/webhooks/index.test.ts
Normal file
66
extensions/webhooks/index.test.ts
Normal file
@@ -0,0 +1,66 @@
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { createTestPluginApi } from "../../test/helpers/plugins/plugin-api.js";
|
||||
import type { OpenClawPluginApi } from "./api.js";
|
||||
import plugin from "./index.js";
|
||||
|
||||
function createApi(params?: {
|
||||
pluginConfig?: OpenClawPluginApi["pluginConfig"];
|
||||
registerHttpRoute?: OpenClawPluginApi["registerHttpRoute"];
|
||||
logger?: OpenClawPluginApi["logger"];
|
||||
}): OpenClawPluginApi {
|
||||
return createTestPluginApi({
|
||||
id: "webhooks",
|
||||
name: "Webhooks",
|
||||
source: "test",
|
||||
pluginConfig: params?.pluginConfig ?? {},
|
||||
runtime: {
|
||||
taskFlow: {
|
||||
bindSession: vi.fn(({ sessionKey }: { sessionKey: string }) => ({ sessionKey })),
|
||||
},
|
||||
} as unknown as OpenClawPluginApi["runtime"],
|
||||
registerHttpRoute: params?.registerHttpRoute ?? vi.fn(),
|
||||
logger:
|
||||
params?.logger ??
|
||||
({
|
||||
info: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
debug: vi.fn(),
|
||||
} as OpenClawPluginApi["logger"]),
|
||||
});
|
||||
}
|
||||
|
||||
describe("webhooks plugin registration", () => {
|
||||
it("registers SecretRef-backed routes synchronously", () => {
|
||||
const registerHttpRoute = vi.fn();
|
||||
|
||||
const result = plugin.register(
|
||||
createApi({
|
||||
pluginConfig: {
|
||||
routes: {
|
||||
zapier: {
|
||||
sessionKey: "agent:main:main",
|
||||
secret: {
|
||||
source: "env",
|
||||
provider: "default",
|
||||
id: "OPENCLAW_WEBHOOK_SECRET",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
registerHttpRoute,
|
||||
}),
|
||||
);
|
||||
|
||||
expect(result).toBeUndefined();
|
||||
expect(registerHttpRoute).toHaveBeenCalledTimes(1);
|
||||
expect(registerHttpRoute).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
path: "/plugins/webhooks/zapier",
|
||||
auth: "plugin",
|
||||
match: "exact",
|
||||
replaceExisting: true,
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -2,50 +2,52 @@ import { definePluginEntry, type OpenClawPluginApi } from "./api.js";
|
||||
import { resolveWebhooksPluginConfig } from "./src/config.js";
|
||||
import { createTaskFlowWebhookRequestHandler, type TaskFlowWebhookTarget } from "./src/http.js";
|
||||
|
||||
function registerWebhookRoutes(api: OpenClawPluginApi): void {
|
||||
const routes = resolveWebhooksPluginConfig({
|
||||
pluginConfig: api.pluginConfig,
|
||||
});
|
||||
if (routes.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const targetsByPath = new Map<string, TaskFlowWebhookTarget[]>();
|
||||
const handler = createTaskFlowWebhookRequestHandler({
|
||||
cfg: api.config,
|
||||
targetsByPath,
|
||||
});
|
||||
|
||||
for (const route of routes) {
|
||||
const taskFlow = api.runtime.taskFlow.bindSession({
|
||||
sessionKey: route.sessionKey,
|
||||
});
|
||||
const target: TaskFlowWebhookTarget = {
|
||||
routeId: route.routeId,
|
||||
path: route.path,
|
||||
secretInput: route.secret,
|
||||
secretConfigPath: `plugins.entries.webhooks.routes.${route.routeId}.secret`,
|
||||
defaultControllerId: route.controllerId,
|
||||
taskFlow,
|
||||
};
|
||||
targetsByPath.set(target.path, [...(targetsByPath.get(target.path) ?? []), target]);
|
||||
api.registerHttpRoute({
|
||||
path: target.path,
|
||||
auth: "plugin",
|
||||
match: "exact",
|
||||
replaceExisting: true,
|
||||
handler,
|
||||
});
|
||||
api.logger.info?.(
|
||||
`[webhooks] registered route ${route.routeId} on ${route.path} for session ${route.sessionKey}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
export default definePluginEntry({
|
||||
id: "webhooks",
|
||||
name: "Webhooks",
|
||||
description:
|
||||
"Authenticated inbound webhooks that bind external automation to OpenClaw TaskFlows.",
|
||||
async register(api: OpenClawPluginApi) {
|
||||
const routes = await resolveWebhooksPluginConfig({
|
||||
pluginConfig: api.pluginConfig,
|
||||
cfg: api.config,
|
||||
env: process.env,
|
||||
logger: api.logger,
|
||||
});
|
||||
if (routes.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const targetsByPath = new Map<string, TaskFlowWebhookTarget[]>();
|
||||
const handler = createTaskFlowWebhookRequestHandler({
|
||||
cfg: api.config,
|
||||
targetsByPath,
|
||||
});
|
||||
|
||||
for (const route of routes) {
|
||||
const taskFlow = api.runtime.taskFlow.bindSession({
|
||||
sessionKey: route.sessionKey,
|
||||
});
|
||||
const target: TaskFlowWebhookTarget = {
|
||||
routeId: route.routeId,
|
||||
path: route.path,
|
||||
secret: route.secret,
|
||||
defaultControllerId: route.controllerId,
|
||||
taskFlow,
|
||||
};
|
||||
targetsByPath.set(target.path, [...(targetsByPath.get(target.path) ?? []), target]);
|
||||
api.registerHttpRoute({
|
||||
path: target.path,
|
||||
auth: "plugin",
|
||||
match: "exact",
|
||||
replaceExisting: true,
|
||||
handler,
|
||||
});
|
||||
api.logger.info?.(
|
||||
`[webhooks] registered route ${route.routeId} on ${route.path} for session ${route.sessionKey}`,
|
||||
);
|
||||
}
|
||||
register(api: OpenClawPluginApi) {
|
||||
registerWebhookRoutes(api);
|
||||
},
|
||||
});
|
||||
|
||||
@@ -4,6 +4,7 @@ export {
|
||||
normalizeWebhookPath,
|
||||
readJsonWebhookBodyOrReject,
|
||||
resolveRequestClientIp,
|
||||
resolveWebhookTargetWithAuthOrReject,
|
||||
resolveWebhookTargetWithAuthOrRejectSync,
|
||||
withResolvedWebhookRequestPipeline,
|
||||
WEBHOOK_IN_FLIGHT_DEFAULTS,
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import type { OpenClawConfig } from "../runtime-api.js";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { resolveWebhooksPluginConfig } from "./config.js";
|
||||
|
||||
describe("resolveWebhooksPluginConfig", () => {
|
||||
it("resolves default paths and SecretRef-backed secrets", async () => {
|
||||
const routes = await resolveWebhooksPluginConfig({
|
||||
it("keeps SecretRef-backed secrets on the route config", () => {
|
||||
const routes = resolveWebhooksPluginConfig({
|
||||
pluginConfig: {
|
||||
routes: {
|
||||
zapier: {
|
||||
@@ -17,10 +16,6 @@ describe("resolveWebhooksPluginConfig", () => {
|
||||
},
|
||||
},
|
||||
},
|
||||
cfg: {} as OpenClawConfig,
|
||||
env: {
|
||||
OPENCLAW_WEBHOOK_SECRET: "shared-secret",
|
||||
},
|
||||
});
|
||||
|
||||
expect(routes).toEqual([
|
||||
@@ -28,16 +23,18 @@ describe("resolveWebhooksPluginConfig", () => {
|
||||
routeId: "zapier",
|
||||
path: "/plugins/webhooks/zapier",
|
||||
sessionKey: "agent:main:main",
|
||||
secret: "shared-secret",
|
||||
secret: {
|
||||
source: "env",
|
||||
provider: "default",
|
||||
id: "OPENCLAW_WEBHOOK_SECRET",
|
||||
},
|
||||
controllerId: "webhooks/zapier",
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it("skips routes whose secret cannot be resolved", async () => {
|
||||
const warn = vi.fn();
|
||||
|
||||
const routes = await resolveWebhooksPluginConfig({
|
||||
it("keeps routes whose secret needs runtime resolution", () => {
|
||||
const routes = resolveWebhooksPluginConfig({
|
||||
pluginConfig: {
|
||||
routes: {
|
||||
missing: {
|
||||
@@ -50,19 +47,25 @@ describe("resolveWebhooksPluginConfig", () => {
|
||||
},
|
||||
},
|
||||
},
|
||||
cfg: {} as OpenClawConfig,
|
||||
env: {},
|
||||
logger: { warn } as never,
|
||||
});
|
||||
|
||||
expect(routes).toEqual([]);
|
||||
expect(warn).toHaveBeenCalledWith(
|
||||
expect.stringContaining("[webhooks] skipping route missing:"),
|
||||
);
|
||||
expect(routes).toEqual([
|
||||
{
|
||||
routeId: "missing",
|
||||
path: "/plugins/webhooks/missing",
|
||||
sessionKey: "agent:main:main",
|
||||
secret: {
|
||||
source: "env",
|
||||
provider: "default",
|
||||
id: "MISSING_SECRET",
|
||||
},
|
||||
controllerId: "webhooks/missing",
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it("rejects duplicate normalized paths", async () => {
|
||||
await expect(
|
||||
it("rejects duplicate normalized paths", () => {
|
||||
expect(() =>
|
||||
resolveWebhooksPluginConfig({
|
||||
pluginConfig: {
|
||||
routes: {
|
||||
@@ -78,9 +81,7 @@ describe("resolveWebhooksPluginConfig", () => {
|
||||
},
|
||||
},
|
||||
},
|
||||
cfg: {} as OpenClawConfig,
|
||||
env: {},
|
||||
}),
|
||||
).rejects.toThrow(/conflicts with routes\.first\.path/i);
|
||||
).toThrow(/conflicts with routes\.first\.path/i);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
import { z } from "zod";
|
||||
import type { PluginLogger } from "../api.js";
|
||||
import {
|
||||
normalizeWebhookPath,
|
||||
resolveConfiguredSecretInputString,
|
||||
type OpenClawConfig,
|
||||
} from "../runtime-api.js";
|
||||
import { normalizeWebhookPath } from "../runtime-api.js";
|
||||
|
||||
const secretRefSchema = z
|
||||
.object({
|
||||
@@ -33,23 +28,22 @@ const webhooksPluginConfigSchema = z
|
||||
})
|
||||
.strict();
|
||||
|
||||
export type ResolvedWebhookRouteConfig = {
|
||||
export type WebhookSecretInput = z.infer<typeof secretInputSchema>;
|
||||
|
||||
export type ConfiguredWebhookRouteConfig = {
|
||||
routeId: string;
|
||||
path: string;
|
||||
sessionKey: string;
|
||||
secret: string;
|
||||
secret: WebhookSecretInput;
|
||||
controllerId: string;
|
||||
description?: string;
|
||||
};
|
||||
|
||||
export async function resolveWebhooksPluginConfig(params: {
|
||||
export function resolveWebhooksPluginConfig(params: {
|
||||
pluginConfig: unknown;
|
||||
cfg: OpenClawConfig;
|
||||
env: NodeJS.ProcessEnv;
|
||||
logger?: PluginLogger;
|
||||
}): Promise<ResolvedWebhookRouteConfig[]> {
|
||||
}): ConfiguredWebhookRouteConfig[] {
|
||||
const parsed = webhooksPluginConfigSchema.parse(params.pluginConfig ?? {});
|
||||
const resolvedRoutes: ResolvedWebhookRouteConfig[] = [];
|
||||
const configuredRoutes: ConfiguredWebhookRouteConfig[] = [];
|
||||
const seenPaths = new Map<string, string>();
|
||||
|
||||
for (const [routeId, route] of Object.entries(parsed.routes)) {
|
||||
@@ -64,32 +58,16 @@ export async function resolveWebhooksPluginConfig(params: {
|
||||
);
|
||||
}
|
||||
|
||||
const secretResolution = await resolveConfiguredSecretInputString({
|
||||
config: params.cfg,
|
||||
env: params.env,
|
||||
value: route.secret,
|
||||
path: `plugins.entries.webhooks.routes.${routeId}.secret`,
|
||||
});
|
||||
const secret = secretResolution.value?.trim();
|
||||
if (!secret) {
|
||||
params.logger?.warn?.(
|
||||
`[webhooks] skipping route ${routeId}: ${
|
||||
secretResolution.unresolvedRefReason ?? "secret is empty or unresolved"
|
||||
}`,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
seenPaths.set(path, routeId);
|
||||
resolvedRoutes.push({
|
||||
configuredRoutes.push({
|
||||
routeId,
|
||||
path,
|
||||
sessionKey: route.sessionKey,
|
||||
secret,
|
||||
secret: route.secret,
|
||||
controllerId: route.controllerId ?? `webhooks/${routeId}`,
|
||||
...(route.description ? { description: route.description } : {}),
|
||||
});
|
||||
}
|
||||
|
||||
return resolvedRoutes;
|
||||
return configuredRoutes;
|
||||
}
|
||||
|
||||
@@ -10,10 +10,12 @@ const hoisted = vi.hoisted(() => {
|
||||
const sendMessageMock = vi.fn();
|
||||
const cancelSessionMock = vi.fn();
|
||||
const killSubagentRunAdminMock = vi.fn();
|
||||
const resolveConfiguredSecretInputStringMock = vi.fn();
|
||||
return {
|
||||
sendMessageMock,
|
||||
cancelSessionMock,
|
||||
killSubagentRunAdminMock,
|
||||
resolveConfiguredSecretInputStringMock,
|
||||
};
|
||||
});
|
||||
|
||||
@@ -31,6 +33,17 @@ vi.mock("../../../src/agents/subagent-control.js", () => ({
|
||||
killSubagentRunAdmin: (params: unknown) => hoisted.killSubagentRunAdminMock(params),
|
||||
}));
|
||||
|
||||
vi.mock("../runtime-api.js", async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import("../runtime-api.js")>();
|
||||
hoisted.resolveConfiguredSecretInputStringMock.mockImplementation(
|
||||
actual.resolveConfiguredSecretInputString,
|
||||
);
|
||||
return {
|
||||
...actual,
|
||||
resolveConfiguredSecretInputString: hoisted.resolveConfiguredSecretInputStringMock,
|
||||
};
|
||||
});
|
||||
|
||||
type MockIncomingMessage = IncomingMessage & {
|
||||
destroyed?: boolean;
|
||||
destroy: () => MockIncomingMessage;
|
||||
@@ -58,7 +71,7 @@ function createJsonRequest(params: {
|
||||
return req;
|
||||
}) as MockIncomingMessage["destroy"];
|
||||
|
||||
void Promise.resolve().then(() => {
|
||||
setImmediate(() => {
|
||||
req.emit("data", Buffer.from(JSON.stringify(params.body), "utf8"));
|
||||
req.emit("end");
|
||||
});
|
||||
@@ -69,13 +82,16 @@ function createJsonRequest(params: {
|
||||
function createHandler(): {
|
||||
handler: ReturnType<typeof createTaskFlowWebhookRequestHandler>;
|
||||
target: TaskFlowWebhookTarget;
|
||||
secret: string;
|
||||
} {
|
||||
const runtime = createRuntimeTaskFlow();
|
||||
nextSessionId += 1;
|
||||
const secret = "shared-secret";
|
||||
const target: TaskFlowWebhookTarget = {
|
||||
routeId: "zapier",
|
||||
path: "/plugins/webhooks/zapier",
|
||||
secret: "shared-secret",
|
||||
secretInput: secret,
|
||||
secretConfigPath: "plugins.entries.webhooks.routes.zapier.secret",
|
||||
defaultControllerId: "webhooks/zapier",
|
||||
taskFlow: runtime.bindSession({
|
||||
sessionKey: `agent:main:webhook-test-${String(nextSessionId)}`,
|
||||
@@ -88,9 +104,21 @@ function createHandler(): {
|
||||
targetsByPath,
|
||||
}),
|
||||
target,
|
||||
secret,
|
||||
};
|
||||
}
|
||||
|
||||
function createHandlerWithTarget(
|
||||
target: TaskFlowWebhookTarget,
|
||||
cfg: OpenClawConfig = {} as OpenClawConfig,
|
||||
): ReturnType<typeof createTaskFlowWebhookRequestHandler> {
|
||||
const targetsByPath = new Map<string, TaskFlowWebhookTarget[]>([[target.path, [target]]]);
|
||||
return createTaskFlowWebhookRequestHandler({
|
||||
cfg,
|
||||
targetsByPath,
|
||||
});
|
||||
}
|
||||
|
||||
async function dispatchJsonRequest(params: {
|
||||
handler: ReturnType<typeof createTaskFlowWebhookRequestHandler>;
|
||||
path: string;
|
||||
@@ -132,12 +160,53 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
expect(target.taskFlow.list()).toEqual([]);
|
||||
});
|
||||
|
||||
it("caches SecretRef resolution across requests for the same route", async () => {
|
||||
const runtime = createRuntimeTaskFlow();
|
||||
const target: TaskFlowWebhookTarget = {
|
||||
routeId: "cached",
|
||||
path: "/plugins/webhooks/cached",
|
||||
secretInput: {
|
||||
source: "env",
|
||||
provider: "default",
|
||||
id: "OPENCLAW_WEBHOOK_SECRET",
|
||||
},
|
||||
secretConfigPath: "plugins.entries.webhooks.routes.cached.secret",
|
||||
defaultControllerId: "webhooks/cached",
|
||||
taskFlow: runtime.bindSession({
|
||||
sessionKey: "agent:main:webhook-cached",
|
||||
}),
|
||||
};
|
||||
hoisted.resolveConfiguredSecretInputStringMock.mockResolvedValue({ value: "shared-secret" });
|
||||
const handler = createHandlerWithTarget(target);
|
||||
|
||||
const first = await dispatchJsonRequest({
|
||||
handler,
|
||||
path: target.path,
|
||||
secret: "shared-secret",
|
||||
body: {
|
||||
action: "list_flows",
|
||||
},
|
||||
});
|
||||
const second = await dispatchJsonRequest({
|
||||
handler,
|
||||
path: target.path,
|
||||
secret: "shared-secret",
|
||||
body: {
|
||||
action: "list_flows",
|
||||
},
|
||||
});
|
||||
|
||||
expect(first.statusCode).toBe(200);
|
||||
expect(second.statusCode).toBe(200);
|
||||
expect(hoisted.resolveConfiguredSecretInputStringMock).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("creates flows through the bound session and scrubs owner metadata from responses", async () => {
|
||||
const { handler, target } = createHandler();
|
||||
const { handler, target, secret } = createHandler();
|
||||
const res = await dispatchJsonRequest({
|
||||
handler,
|
||||
path: target.path,
|
||||
secret: target.secret,
|
||||
secret,
|
||||
body: {
|
||||
action: "create_flow",
|
||||
goal: "Review inbound queue",
|
||||
@@ -158,7 +227,7 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
});
|
||||
|
||||
it("runs child tasks and scrubs task ownership fields from responses", async () => {
|
||||
const { handler, target } = createHandler();
|
||||
const { handler, target, secret } = createHandler();
|
||||
const flow = target.taskFlow.createManaged({
|
||||
controllerId: "webhooks/zapier",
|
||||
goal: "Triage inbox",
|
||||
@@ -166,7 +235,7 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
const res = await dispatchJsonRequest({
|
||||
handler,
|
||||
path: target.path,
|
||||
secret: target.secret,
|
||||
secret,
|
||||
body: {
|
||||
action: "run_task",
|
||||
flowId: flow.flowId,
|
||||
@@ -193,11 +262,11 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
});
|
||||
|
||||
it("returns 404 for missing flow mutations", async () => {
|
||||
const { handler, target } = createHandler();
|
||||
const { handler, target, secret } = createHandler();
|
||||
const res = await dispatchJsonRequest({
|
||||
handler,
|
||||
path: target.path,
|
||||
secret: target.secret,
|
||||
secret,
|
||||
body: {
|
||||
action: "set_waiting",
|
||||
flowId: "flow-missing",
|
||||
@@ -219,7 +288,7 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
});
|
||||
|
||||
it("returns 409 for revision conflicts", async () => {
|
||||
const { handler, target } = createHandler();
|
||||
const { handler, target, secret } = createHandler();
|
||||
const flow = target.taskFlow.createManaged({
|
||||
controllerId: "webhooks/zapier",
|
||||
goal: "Review inbox",
|
||||
@@ -227,7 +296,7 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
const res = await dispatchJsonRequest({
|
||||
handler,
|
||||
path: target.path,
|
||||
secret: target.secret,
|
||||
secret,
|
||||
body: {
|
||||
action: "set_waiting",
|
||||
flowId: flow.flowId,
|
||||
@@ -252,7 +321,7 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
});
|
||||
|
||||
it("rejects internal runtimes and running-only metadata from external callers", async () => {
|
||||
const { handler, target } = createHandler();
|
||||
const { handler, target, secret } = createHandler();
|
||||
const flow = target.taskFlow.createManaged({
|
||||
controllerId: "webhooks/zapier",
|
||||
goal: "Review inbox",
|
||||
@@ -261,7 +330,7 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
const runtimeRes = await dispatchJsonRequest({
|
||||
handler,
|
||||
path: target.path,
|
||||
secret: target.secret,
|
||||
secret,
|
||||
body: {
|
||||
action: "run_task",
|
||||
flowId: flow.flowId,
|
||||
@@ -278,7 +347,7 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
const queuedMetadataRes = await dispatchJsonRequest({
|
||||
handler,
|
||||
path: target.path,
|
||||
secret: target.secret,
|
||||
secret,
|
||||
body: {
|
||||
action: "run_task",
|
||||
flowId: flow.flowId,
|
||||
@@ -297,7 +366,7 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
});
|
||||
|
||||
it("reuses the same task record when retried with the same runId", async () => {
|
||||
const { handler, target } = createHandler();
|
||||
const { handler, target, secret } = createHandler();
|
||||
const flow = target.taskFlow.createManaged({
|
||||
controllerId: "webhooks/zapier",
|
||||
goal: "Triage inbox",
|
||||
@@ -306,7 +375,7 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
const first = await dispatchJsonRequest({
|
||||
handler,
|
||||
path: target.path,
|
||||
secret: target.secret,
|
||||
secret,
|
||||
body: {
|
||||
action: "run_task",
|
||||
flowId: flow.flowId,
|
||||
@@ -319,7 +388,7 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
const second = await dispatchJsonRequest({
|
||||
handler,
|
||||
path: target.path,
|
||||
secret: target.secret,
|
||||
secret,
|
||||
body: {
|
||||
action: "run_task",
|
||||
flowId: flow.flowId,
|
||||
@@ -339,7 +408,7 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
});
|
||||
|
||||
it("returns 409 when cancellation targets a terminal flow", async () => {
|
||||
const { handler, target } = createHandler();
|
||||
const { handler, target, secret } = createHandler();
|
||||
const flow = target.taskFlow.createManaged({
|
||||
controllerId: "webhooks/zapier",
|
||||
goal: "Review inbox",
|
||||
@@ -353,7 +422,7 @@ describe("createTaskFlowWebhookRequestHandler", () => {
|
||||
const res = await dispatchJsonRequest({
|
||||
handler,
|
||||
path: target.path,
|
||||
secret: target.secret,
|
||||
secret,
|
||||
body: {
|
||||
action: "cancel_flow",
|
||||
flowId: flow.flowId,
|
||||
|
||||
@@ -8,13 +8,15 @@ import {
|
||||
createWebhookInFlightLimiter,
|
||||
readJsonWebhookBodyOrReject,
|
||||
resolveRequestClientIp,
|
||||
resolveWebhookTargetWithAuthOrRejectSync,
|
||||
resolveConfiguredSecretInputString,
|
||||
resolveWebhookTargetWithAuthOrReject,
|
||||
withResolvedWebhookRequestPipeline,
|
||||
WEBHOOK_IN_FLIGHT_DEFAULTS,
|
||||
WEBHOOK_RATE_LIMIT_DEFAULTS,
|
||||
type OpenClawConfig,
|
||||
type WebhookInFlightLimiter,
|
||||
} from "../runtime-api.js";
|
||||
import type { WebhookSecretInput } from "./config.js";
|
||||
|
||||
type BoundTaskFlowRuntime = ReturnType<PluginRuntime["taskFlow"]["bindSession"]>;
|
||||
|
||||
@@ -174,7 +176,8 @@ type WebhookAction = z.infer<typeof webhookActionSchema>;
|
||||
export type TaskFlowWebhookTarget = {
|
||||
routeId: string;
|
||||
path: string;
|
||||
secret: string;
|
||||
secretInput: WebhookSecretInput;
|
||||
secretConfigPath: string;
|
||||
defaultControllerId: string;
|
||||
taskFlow: BoundTaskFlowRuntime;
|
||||
};
|
||||
@@ -664,6 +667,7 @@ export function createTaskFlowWebhookRequestHandler(params: {
|
||||
targetsByPath: Map<string, TaskFlowWebhookTarget[]>;
|
||||
inFlightLimiter?: WebhookInFlightLimiter;
|
||||
}): (req: IncomingMessage, res: ServerResponse) => Promise<boolean> {
|
||||
const secretByTarget = new WeakMap<TaskFlowWebhookTarget, Promise<string | undefined>>();
|
||||
const rateLimiter = createFixedWindowRateLimiter({
|
||||
windowMs: WEBHOOK_RATE_LIMIT_DEFAULTS.windowMs,
|
||||
maxRequests: WEBHOOK_RATE_LIMIT_DEFAULTS.maxRequests,
|
||||
@@ -675,6 +679,20 @@ export function createTaskFlowWebhookRequestHandler(params: {
|
||||
maxInFlightPerKey: WEBHOOK_IN_FLIGHT_DEFAULTS.maxInFlightPerKey,
|
||||
maxTrackedKeys: WEBHOOK_IN_FLIGHT_DEFAULTS.maxTrackedKeys,
|
||||
});
|
||||
const resolveTargetSecret = (target: TaskFlowWebhookTarget): Promise<string | undefined> => {
|
||||
const cached = secretByTarget.get(target);
|
||||
if (cached) {
|
||||
return cached;
|
||||
}
|
||||
const pending = resolveConfiguredSecretInputString({
|
||||
config: params.cfg,
|
||||
env: process.env,
|
||||
value: target.secretInput,
|
||||
path: target.secretConfigPath,
|
||||
}).then((resolved) => resolved.value);
|
||||
secretByTarget.set(target, pending);
|
||||
return pending;
|
||||
};
|
||||
|
||||
return async (req: IncomingMessage, res: ServerResponse): Promise<boolean> => {
|
||||
return await withResolvedWebhookRequestPipeline({
|
||||
@@ -698,11 +716,18 @@ export function createTaskFlowWebhookRequestHandler(params: {
|
||||
inFlightLimiter,
|
||||
handle: async ({ targets }) => {
|
||||
const presentedSecret = extractSharedSecret(req);
|
||||
const target = resolveWebhookTargetWithAuthOrRejectSync({
|
||||
const target = await resolveWebhookTargetWithAuthOrReject({
|
||||
targets,
|
||||
res,
|
||||
isMatch: (candidate) =>
|
||||
presentedSecret.length > 0 && timingSafeEquals(candidate.secret, presentedSecret),
|
||||
isMatch: async (candidate) => {
|
||||
if (presentedSecret.length === 0) {
|
||||
return false;
|
||||
}
|
||||
const resolvedSecret = await resolveTargetSecret(candidate);
|
||||
return Boolean(
|
||||
resolvedSecret && timingSafeEquals(resolvedSecret, presentedSecret),
|
||||
);
|
||||
},
|
||||
});
|
||||
if (!target) {
|
||||
return true;
|
||||
|
||||
@@ -169,6 +169,19 @@ describe("whatsapp setup wizard", () => {
|
||||
expectWhatsAppAllowlistModeSetup(result.cfg);
|
||||
});
|
||||
|
||||
it("throws a user-facing error instead of crashing when allowlist input is undefined", async () => {
|
||||
const harness = createSeparatePhoneHarness({
|
||||
selectValues: ["separate", "allowlist", "list"],
|
||||
});
|
||||
harness.text.mockResolvedValueOnce(undefined as never);
|
||||
|
||||
await expect(
|
||||
runConfigureWithHarness({
|
||||
harness,
|
||||
}),
|
||||
).rejects.toThrow("Invalid WhatsApp allowFrom list");
|
||||
});
|
||||
|
||||
it("enables allowlist self-chat mode for personal-phone setup", async () => {
|
||||
hoisted.pathExists.mockResolvedValue(true);
|
||||
const harness = createWhatsAppPersonalPhoneHarness(createQueuedWizardPrompter);
|
||||
@@ -180,6 +193,18 @@ describe("whatsapp setup wizard", () => {
|
||||
expectWhatsAppPersonalPhoneSetup(result.cfg);
|
||||
});
|
||||
|
||||
it("throws a user-facing error instead of crashing when personal-phone input is undefined", async () => {
|
||||
hoisted.pathExists.mockResolvedValue(true);
|
||||
const harness = createWhatsAppPersonalPhoneHarness(createQueuedWizardPrompter);
|
||||
harness.text.mockResolvedValueOnce(undefined as never);
|
||||
|
||||
await expect(
|
||||
runConfigureWithHarness({
|
||||
harness,
|
||||
}),
|
||||
).rejects.toThrow("Invalid WhatsApp owner number");
|
||||
});
|
||||
|
||||
it("forces wildcard allowFrom for open policy without allowFrom follow-up prompts", async () => {
|
||||
hoisted.pathExists.mockResolvedValue(true);
|
||||
const harness = createSeparatePhoneHarness({
|
||||
|
||||
@@ -23,6 +23,10 @@ type SetupRuntime = Parameters<NonNullable<ChannelSetupWizard["finalize"]>>[0]["
|
||||
type WhatsAppConfig = NonNullable<NonNullable<OpenClawConfig["channels"]>["whatsapp"]>;
|
||||
type WhatsAppAccountConfig = NonNullable<NonNullable<WhatsAppConfig["accounts"]>[string]>;
|
||||
|
||||
function trimPromptText(value: string | null | undefined): string {
|
||||
return value?.trim() ?? "";
|
||||
}
|
||||
|
||||
function mergeWhatsAppConfig(
|
||||
cfg: OpenClawConfig,
|
||||
accountId: string,
|
||||
@@ -124,7 +128,7 @@ async function promptWhatsAppOwnerAllowFrom(params: {
|
||||
placeholder: "+15555550123",
|
||||
initialValue: existingAllowFrom[0],
|
||||
validate: (value) => {
|
||||
const raw = value.trim();
|
||||
const raw = trimPromptText(value);
|
||||
if (!raw) {
|
||||
return "Required";
|
||||
}
|
||||
@@ -136,7 +140,7 @@ async function promptWhatsAppOwnerAllowFrom(params: {
|
||||
},
|
||||
});
|
||||
|
||||
const normalized = normalizeE164(entry.trim());
|
||||
const normalized = normalizeE164(trimPromptText(entry));
|
||||
if (!normalized) {
|
||||
throw new Error("Invalid WhatsApp owner number (expected E.164 after validation).");
|
||||
}
|
||||
@@ -311,7 +315,7 @@ async function promptWhatsAppDmAccess(params: {
|
||||
message: "Allowed sender numbers (comma-separated, E.164)",
|
||||
placeholder: "+15555550123, +447700900123",
|
||||
validate: (value) => {
|
||||
const raw = value.trim();
|
||||
const raw = trimPromptText(value);
|
||||
if (!raw) {
|
||||
return "Required";
|
||||
}
|
||||
@@ -326,7 +330,13 @@ async function promptWhatsAppDmAccess(params: {
|
||||
},
|
||||
});
|
||||
|
||||
const parsed = parseWhatsAppAllowFromEntries(allowRaw);
|
||||
const parsed = parseWhatsAppAllowFromEntries(trimPromptText(allowRaw));
|
||||
if (parsed.invalidEntry) {
|
||||
throw new Error(`Invalid number: ${parsed.invalidEntry}`);
|
||||
}
|
||||
if (parsed.entries.length === 0) {
|
||||
throw new Error("Invalid WhatsApp allowFrom list (expected at least one E.164 number).");
|
||||
}
|
||||
return setWhatsAppAllowFrom(next, accountId, parsed.entries);
|
||||
}
|
||||
|
||||
|
||||
@@ -1346,6 +1346,8 @@
|
||||
"test:parallels:windows": "bash scripts/e2e/parallels-windows-smoke.sh",
|
||||
"test:perf:budget": "node scripts/test-perf-budget.mjs",
|
||||
"test:perf:changed:bench": "node scripts/bench-test-changed.mjs",
|
||||
"test:perf:groups": "node scripts/test-group-report.mjs",
|
||||
"test:perf:groups:compare": "node scripts/test-group-report.mjs --compare",
|
||||
"test:perf:hotspots": "node scripts/test-hotspots.mjs",
|
||||
"test:perf:imports": "OPENCLAW_VITEST_IMPORT_DURATIONS=1 OPENCLAW_VITEST_PRINT_IMPORT_BREAKDOWN=1 node scripts/test-projects.mjs",
|
||||
"test:perf:imports:changed": "OPENCLAW_VITEST_IMPORT_DURATIONS=1 OPENCLAW_VITEST_PRINT_IMPORT_BREAKDOWN=1 node scripts/test-projects.mjs --changed origin/main",
|
||||
|
||||
@@ -1,116 +0,0 @@
|
||||
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import type { GeminiEmbeddingClient } from "./embeddings-gemini.js";
|
||||
|
||||
vi.mock("./remote-http.js", () => ({
|
||||
withRemoteHttpResponse: vi.fn(),
|
||||
}));
|
||||
|
||||
function magnitude(values: number[]) {
|
||||
return Math.sqrt(values.reduce((sum, value) => sum + value * value, 0));
|
||||
}
|
||||
|
||||
describe("runGeminiEmbeddingBatches", () => {
|
||||
let runGeminiEmbeddingBatches: typeof import("./batch-gemini.js").runGeminiEmbeddingBatches;
|
||||
let withRemoteHttpResponse: typeof import("./remote-http.js").withRemoteHttpResponse;
|
||||
let remoteHttpMock: ReturnType<typeof vi.mocked<typeof withRemoteHttpResponse>>;
|
||||
|
||||
beforeAll(async () => {
|
||||
({ runGeminiEmbeddingBatches } = await import("./batch-gemini.js"));
|
||||
({ withRemoteHttpResponse } = await import("./remote-http.js"));
|
||||
remoteHttpMock = vi.mocked(withRemoteHttpResponse);
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.resetAllMocks();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
const mockClient: GeminiEmbeddingClient = {
|
||||
baseUrl: "https://generativelanguage.googleapis.com/v1beta",
|
||||
headers: {},
|
||||
model: "gemini-embedding-2-preview",
|
||||
modelPath: "models/gemini-embedding-2-preview",
|
||||
apiKeys: ["test-key"],
|
||||
outputDimensionality: 1536,
|
||||
};
|
||||
|
||||
it("includes outputDimensionality in batch upload requests", async () => {
|
||||
remoteHttpMock.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toContain("/upload/v1beta/files?uploadType=multipart");
|
||||
const body = params.init?.body;
|
||||
if (!(body instanceof Blob)) {
|
||||
throw new Error("expected multipart blob body");
|
||||
}
|
||||
const text = await body.text();
|
||||
expect(text).toContain('"taskType":"RETRIEVAL_DOCUMENT"');
|
||||
expect(text).toContain('"outputDimensionality":1536');
|
||||
return await params.onResponse(
|
||||
new Response(JSON.stringify({ name: "files/file-123" }), {
|
||||
status: 200,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
}),
|
||||
);
|
||||
});
|
||||
remoteHttpMock.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toMatch(/:asyncBatchEmbedContent$/u);
|
||||
return await params.onResponse(
|
||||
new Response(
|
||||
JSON.stringify({
|
||||
name: "batches/batch-1",
|
||||
state: "COMPLETED",
|
||||
outputConfig: { file: "files/output-1" },
|
||||
}),
|
||||
{
|
||||
status: 200,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
},
|
||||
),
|
||||
);
|
||||
});
|
||||
remoteHttpMock.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toMatch(/\/files\/output-1:download$/u);
|
||||
return await params.onResponse(
|
||||
new Response(
|
||||
JSON.stringify({
|
||||
key: "req-1",
|
||||
response: { embedding: { values: [3, 4] } },
|
||||
}),
|
||||
{
|
||||
status: 200,
|
||||
headers: { "Content-Type": "application/jsonl" },
|
||||
},
|
||||
),
|
||||
);
|
||||
});
|
||||
|
||||
const results = await runGeminiEmbeddingBatches({
|
||||
gemini: mockClient,
|
||||
agentId: "main",
|
||||
requests: [
|
||||
{
|
||||
custom_id: "req-1",
|
||||
request: {
|
||||
content: { parts: [{ text: "hello world" }] },
|
||||
taskType: "RETRIEVAL_DOCUMENT",
|
||||
outputDimensionality: 1536,
|
||||
},
|
||||
},
|
||||
],
|
||||
wait: true,
|
||||
pollIntervalMs: 1,
|
||||
timeoutMs: 1000,
|
||||
concurrency: 1,
|
||||
});
|
||||
|
||||
const embedding = results.get("req-1");
|
||||
expect(embedding).toBeDefined();
|
||||
expect(embedding?.[0]).toBeCloseTo(0.6, 5);
|
||||
expect(embedding?.[1]).toBeCloseTo(0.8, 5);
|
||||
expect(magnitude(embedding ?? [])).toBeCloseTo(1, 5);
|
||||
expect(remoteHttpMock).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
});
|
||||
@@ -1,259 +0,0 @@
|
||||
import {
|
||||
applyEmbeddingBatchOutputLine,
|
||||
buildBatchHeaders,
|
||||
buildEmbeddingBatchGroupOptions,
|
||||
EMBEDDING_BATCH_ENDPOINT,
|
||||
extractBatchErrorMessage,
|
||||
formatUnavailableBatchError,
|
||||
normalizeBatchBaseUrl,
|
||||
postJsonWithRetry,
|
||||
resolveBatchCompletionFromStatus,
|
||||
resolveCompletedBatchResult,
|
||||
runEmbeddingBatchGroups,
|
||||
throwIfBatchTerminalFailure,
|
||||
type EmbeddingBatchExecutionParams,
|
||||
type EmbeddingBatchStatus,
|
||||
type BatchCompletionResult,
|
||||
type ProviderBatchOutputLine,
|
||||
uploadBatchJsonlFile,
|
||||
withRemoteHttpResponse,
|
||||
} from "./batch-embedding-common.js";
|
||||
import type { OpenAiEmbeddingClient } from "./embeddings-openai.js";
|
||||
|
||||
export type OpenAiBatchRequest = {
|
||||
custom_id: string;
|
||||
method: "POST";
|
||||
url: "/v1/embeddings";
|
||||
body: {
|
||||
model: string;
|
||||
input: string;
|
||||
};
|
||||
};
|
||||
|
||||
export type OpenAiBatchStatus = EmbeddingBatchStatus;
|
||||
export type OpenAiBatchOutputLine = ProviderBatchOutputLine;
|
||||
|
||||
export const OPENAI_BATCH_ENDPOINT = EMBEDDING_BATCH_ENDPOINT;
|
||||
const OPENAI_BATCH_COMPLETION_WINDOW = "24h";
|
||||
const OPENAI_BATCH_MAX_REQUESTS = 50000;
|
||||
|
||||
async function submitOpenAiBatch(params: {
|
||||
openAi: OpenAiEmbeddingClient;
|
||||
requests: OpenAiBatchRequest[];
|
||||
agentId: string;
|
||||
}): Promise<OpenAiBatchStatus> {
|
||||
const baseUrl = normalizeBatchBaseUrl(params.openAi);
|
||||
const inputFileId = await uploadBatchJsonlFile({
|
||||
client: params.openAi,
|
||||
requests: params.requests,
|
||||
errorPrefix: "openai batch file upload failed",
|
||||
});
|
||||
|
||||
return await postJsonWithRetry<OpenAiBatchStatus>({
|
||||
url: `${baseUrl}/batches`,
|
||||
headers: buildBatchHeaders(params.openAi, { json: true }),
|
||||
ssrfPolicy: params.openAi.ssrfPolicy,
|
||||
body: {
|
||||
input_file_id: inputFileId,
|
||||
endpoint: OPENAI_BATCH_ENDPOINT,
|
||||
completion_window: OPENAI_BATCH_COMPLETION_WINDOW,
|
||||
metadata: {
|
||||
source: "openclaw-memory",
|
||||
agent: params.agentId,
|
||||
},
|
||||
},
|
||||
errorPrefix: "openai batch create failed",
|
||||
});
|
||||
}
|
||||
|
||||
async function fetchOpenAiBatchStatus(params: {
|
||||
openAi: OpenAiEmbeddingClient;
|
||||
batchId: string;
|
||||
}): Promise<OpenAiBatchStatus> {
|
||||
return await fetchOpenAiBatchResource({
|
||||
openAi: params.openAi,
|
||||
path: `/batches/${params.batchId}`,
|
||||
errorPrefix: "openai batch status",
|
||||
parse: async (res) => (await res.json()) as OpenAiBatchStatus,
|
||||
});
|
||||
}
|
||||
|
||||
async function fetchOpenAiFileContent(params: {
|
||||
openAi: OpenAiEmbeddingClient;
|
||||
fileId: string;
|
||||
}): Promise<string> {
|
||||
return await fetchOpenAiBatchResource({
|
||||
openAi: params.openAi,
|
||||
path: `/files/${params.fileId}/content`,
|
||||
errorPrefix: "openai batch file content",
|
||||
parse: async (res) => await res.text(),
|
||||
});
|
||||
}
|
||||
|
||||
async function fetchOpenAiBatchResource<T>(params: {
|
||||
openAi: OpenAiEmbeddingClient;
|
||||
path: string;
|
||||
errorPrefix: string;
|
||||
parse: (res: Response) => Promise<T>;
|
||||
}): Promise<T> {
|
||||
const baseUrl = normalizeBatchBaseUrl(params.openAi);
|
||||
return await withRemoteHttpResponse({
|
||||
url: `${baseUrl}${params.path}`,
|
||||
ssrfPolicy: params.openAi.ssrfPolicy,
|
||||
init: {
|
||||
headers: buildBatchHeaders(params.openAi, { json: true }),
|
||||
},
|
||||
onResponse: async (res) => {
|
||||
if (!res.ok) {
|
||||
const text = await res.text();
|
||||
throw new Error(`${params.errorPrefix} failed: ${res.status} ${text}`);
|
||||
}
|
||||
return await params.parse(res);
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
function parseOpenAiBatchOutput(text: string): OpenAiBatchOutputLine[] {
|
||||
if (!text.trim()) {
|
||||
return [];
|
||||
}
|
||||
return text
|
||||
.split("\n")
|
||||
.map((line) => line.trim())
|
||||
.filter(Boolean)
|
||||
.map((line) => JSON.parse(line) as OpenAiBatchOutputLine);
|
||||
}
|
||||
|
||||
async function readOpenAiBatchError(params: {
|
||||
openAi: OpenAiEmbeddingClient;
|
||||
errorFileId: string;
|
||||
}): Promise<string | undefined> {
|
||||
try {
|
||||
const content = await fetchOpenAiFileContent({
|
||||
openAi: params.openAi,
|
||||
fileId: params.errorFileId,
|
||||
});
|
||||
const lines = parseOpenAiBatchOutput(content);
|
||||
return extractBatchErrorMessage(lines);
|
||||
} catch (err) {
|
||||
return formatUnavailableBatchError(err);
|
||||
}
|
||||
}
|
||||
|
||||
async function waitForOpenAiBatch(params: {
|
||||
openAi: OpenAiEmbeddingClient;
|
||||
batchId: string;
|
||||
wait: boolean;
|
||||
pollIntervalMs: number;
|
||||
timeoutMs: number;
|
||||
debug?: (message: string, data?: Record<string, unknown>) => void;
|
||||
initial?: OpenAiBatchStatus;
|
||||
}): Promise<BatchCompletionResult> {
|
||||
const start = Date.now();
|
||||
let current: OpenAiBatchStatus | undefined = params.initial;
|
||||
while (true) {
|
||||
const status =
|
||||
current ??
|
||||
(await fetchOpenAiBatchStatus({
|
||||
openAi: params.openAi,
|
||||
batchId: params.batchId,
|
||||
}));
|
||||
const state = status.status ?? "unknown";
|
||||
if (state === "completed") {
|
||||
return resolveBatchCompletionFromStatus({
|
||||
provider: "openai",
|
||||
batchId: params.batchId,
|
||||
status,
|
||||
});
|
||||
}
|
||||
await throwIfBatchTerminalFailure({
|
||||
provider: "openai",
|
||||
status: { ...status, id: params.batchId },
|
||||
readError: async (errorFileId) =>
|
||||
await readOpenAiBatchError({
|
||||
openAi: params.openAi,
|
||||
errorFileId,
|
||||
}),
|
||||
});
|
||||
if (!params.wait) {
|
||||
throw new Error(`openai batch ${params.batchId} still ${state}; wait disabled`);
|
||||
}
|
||||
if (Date.now() - start > params.timeoutMs) {
|
||||
throw new Error(`openai batch ${params.batchId} timed out after ${params.timeoutMs}ms`);
|
||||
}
|
||||
params.debug?.(`openai batch ${params.batchId} ${state}; waiting ${params.pollIntervalMs}ms`);
|
||||
await new Promise((resolve) => setTimeout(resolve, params.pollIntervalMs));
|
||||
current = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
export async function runOpenAiEmbeddingBatches(
|
||||
params: {
|
||||
openAi: OpenAiEmbeddingClient;
|
||||
agentId: string;
|
||||
requests: OpenAiBatchRequest[];
|
||||
} & EmbeddingBatchExecutionParams,
|
||||
): Promise<Map<string, number[]>> {
|
||||
return await runEmbeddingBatchGroups({
|
||||
...buildEmbeddingBatchGroupOptions(params, {
|
||||
maxRequests: OPENAI_BATCH_MAX_REQUESTS,
|
||||
debugLabel: "memory embeddings: openai batch submit",
|
||||
}),
|
||||
runGroup: async ({ group, groupIndex, groups, byCustomId }) => {
|
||||
const batchInfo = await submitOpenAiBatch({
|
||||
openAi: params.openAi,
|
||||
requests: group,
|
||||
agentId: params.agentId,
|
||||
});
|
||||
if (!batchInfo.id) {
|
||||
throw new Error("openai batch create failed: missing batch id");
|
||||
}
|
||||
const batchId = batchInfo.id;
|
||||
|
||||
params.debug?.("memory embeddings: openai batch created", {
|
||||
batchId: batchInfo.id,
|
||||
status: batchInfo.status,
|
||||
group: groupIndex + 1,
|
||||
groups,
|
||||
requests: group.length,
|
||||
});
|
||||
|
||||
const completed = await resolveCompletedBatchResult({
|
||||
provider: "openai",
|
||||
status: batchInfo,
|
||||
wait: params.wait,
|
||||
waitForBatch: async () =>
|
||||
await waitForOpenAiBatch({
|
||||
openAi: params.openAi,
|
||||
batchId,
|
||||
wait: params.wait,
|
||||
pollIntervalMs: params.pollIntervalMs,
|
||||
timeoutMs: params.timeoutMs,
|
||||
debug: params.debug,
|
||||
initial: batchInfo,
|
||||
}),
|
||||
});
|
||||
|
||||
const content = await fetchOpenAiFileContent({
|
||||
openAi: params.openAi,
|
||||
fileId: completed.outputFileId,
|
||||
});
|
||||
const outputLines = parseOpenAiBatchOutput(content);
|
||||
const errors: string[] = [];
|
||||
const remaining = new Set(group.map((request) => request.custom_id));
|
||||
|
||||
for (const line of outputLines) {
|
||||
applyEmbeddingBatchOutputLine({ line, remaining, errors, byCustomId });
|
||||
}
|
||||
|
||||
if (errors.length > 0) {
|
||||
throw new Error(`openai batch ${batchInfo.id} failed: ${errors.join("; ")}`);
|
||||
}
|
||||
if (remaining.size > 0) {
|
||||
throw new Error(
|
||||
`openai batch ${batchInfo.id} missing ${remaining.size} embedding responses`,
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}
|
||||
@@ -1,176 +0,0 @@
|
||||
import { ReadableStream } from "node:stream/web";
|
||||
import { setTimeout as nativeSleep } from "node:timers/promises";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import {
|
||||
runVoyageEmbeddingBatches,
|
||||
type VoyageBatchOutputLine,
|
||||
type VoyageBatchRequest,
|
||||
} from "./batch-voyage.js";
|
||||
import type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
|
||||
|
||||
const realNow = Date.now.bind(Date);
|
||||
|
||||
describe("runVoyageEmbeddingBatches", () => {
|
||||
const mockClient: VoyageEmbeddingClient = {
|
||||
baseUrl: "https://api.voyageai.com/v1",
|
||||
headers: { Authorization: "Bearer test-key" },
|
||||
model: "voyage-4-large",
|
||||
};
|
||||
|
||||
const mockRequests: VoyageBatchRequest[] = [
|
||||
{ custom_id: "req-1", body: { input: "text1" } },
|
||||
{ custom_id: "req-2", body: { input: "text2" } },
|
||||
];
|
||||
|
||||
it("successfully submits batch, waits, and streams results", async () => {
|
||||
const outputLines: VoyageBatchOutputLine[] = [
|
||||
{
|
||||
custom_id: "req-1",
|
||||
response: { status_code: 200, body: { data: [{ embedding: [0.1, 0.1] }] } },
|
||||
},
|
||||
{
|
||||
custom_id: "req-2",
|
||||
response: { status_code: 200, body: { data: [{ embedding: [0.2, 0.2] }] } },
|
||||
},
|
||||
];
|
||||
const withRemoteHttpResponse = vi.fn();
|
||||
const postJsonWithRetry = vi.fn();
|
||||
const uploadBatchJsonlFile = vi.fn();
|
||||
|
||||
// Create a stream that emits the NDJSON lines
|
||||
const stream = new ReadableStream({
|
||||
start(controller) {
|
||||
const text = outputLines.map((l) => JSON.stringify(l)).join("\n");
|
||||
controller.enqueue(new TextEncoder().encode(text));
|
||||
controller.close();
|
||||
},
|
||||
});
|
||||
uploadBatchJsonlFile.mockImplementationOnce(async (params) => {
|
||||
expect(params.errorPrefix).toBe("voyage batch file upload failed");
|
||||
expect(params.requests).toEqual(mockRequests);
|
||||
return "file-123";
|
||||
});
|
||||
postJsonWithRetry.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toContain("/batches");
|
||||
expect(params.body).toMatchObject({
|
||||
input_file_id: "file-123",
|
||||
completion_window: "12h",
|
||||
request_params: {
|
||||
model: "voyage-4-large",
|
||||
input_type: "document",
|
||||
},
|
||||
});
|
||||
return {
|
||||
id: "batch-abc",
|
||||
status: "pending",
|
||||
};
|
||||
});
|
||||
withRemoteHttpResponse.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toContain("/batches/batch-abc");
|
||||
return await params.onResponse(
|
||||
new Response(
|
||||
JSON.stringify({
|
||||
id: "batch-abc",
|
||||
status: "completed",
|
||||
output_file_id: "file-out-999",
|
||||
}),
|
||||
{
|
||||
status: 200,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
},
|
||||
),
|
||||
);
|
||||
});
|
||||
withRemoteHttpResponse.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toContain("/files/file-out-999/content");
|
||||
return await params.onResponse(
|
||||
new Response(stream as unknown as BodyInit, {
|
||||
status: 200,
|
||||
headers: { "Content-Type": "application/x-ndjson" },
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
const results = await runVoyageEmbeddingBatches({
|
||||
client: mockClient,
|
||||
agentId: "agent-1",
|
||||
requests: mockRequests,
|
||||
wait: true,
|
||||
pollIntervalMs: 1, // fast poll
|
||||
timeoutMs: 1000,
|
||||
concurrency: 1,
|
||||
deps: {
|
||||
now: realNow,
|
||||
sleep: async (ms) => {
|
||||
await nativeSleep(ms);
|
||||
},
|
||||
postJsonWithRetry,
|
||||
uploadBatchJsonlFile,
|
||||
withRemoteHttpResponse,
|
||||
},
|
||||
});
|
||||
|
||||
expect(results.size).toBe(2);
|
||||
expect(results.get("req-1")).toEqual([0.1, 0.1]);
|
||||
expect(results.get("req-2")).toEqual([0.2, 0.2]);
|
||||
expect(uploadBatchJsonlFile).toHaveBeenCalledTimes(1);
|
||||
expect(postJsonWithRetry).toHaveBeenCalledTimes(1);
|
||||
expect(withRemoteHttpResponse).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it("handles empty lines and stream chunks correctly", async () => {
|
||||
const withRemoteHttpResponse = vi.fn();
|
||||
const postJsonWithRetry = vi.fn();
|
||||
const uploadBatchJsonlFile = vi.fn();
|
||||
const stream = new ReadableStream({
|
||||
start(controller) {
|
||||
const line1 = JSON.stringify({
|
||||
custom_id: "req-1",
|
||||
response: { body: { data: [{ embedding: [1] }] } },
|
||||
});
|
||||
const line2 = JSON.stringify({
|
||||
custom_id: "req-2",
|
||||
response: { body: { data: [{ embedding: [2] }] } },
|
||||
});
|
||||
|
||||
// Split across chunks
|
||||
controller.enqueue(new TextEncoder().encode(line1 + "\n"));
|
||||
controller.enqueue(new TextEncoder().encode("\n")); // empty line
|
||||
controller.enqueue(new TextEncoder().encode(line2)); // no newline at EOF
|
||||
controller.close();
|
||||
},
|
||||
});
|
||||
uploadBatchJsonlFile.mockResolvedValueOnce("f1");
|
||||
postJsonWithRetry.mockResolvedValueOnce({
|
||||
id: "b1",
|
||||
status: "completed",
|
||||
output_file_id: "out1",
|
||||
});
|
||||
withRemoteHttpResponse.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toContain("/files/out1/content");
|
||||
return await params.onResponse(new Response(stream as unknown as BodyInit, { status: 200 }));
|
||||
});
|
||||
|
||||
const results = await runVoyageEmbeddingBatches({
|
||||
client: mockClient,
|
||||
agentId: "a1",
|
||||
requests: mockRequests,
|
||||
wait: true,
|
||||
pollIntervalMs: 1,
|
||||
timeoutMs: 1000,
|
||||
concurrency: 1,
|
||||
deps: {
|
||||
now: realNow,
|
||||
sleep: async (ms) => {
|
||||
await nativeSleep(ms);
|
||||
},
|
||||
postJsonWithRetry,
|
||||
uploadBatchJsonlFile,
|
||||
withRemoteHttpResponse,
|
||||
},
|
||||
});
|
||||
|
||||
expect(results.get("req-1")).toEqual([1]);
|
||||
expect(results.get("req-2")).toEqual([2]);
|
||||
});
|
||||
});
|
||||
@@ -1,40 +1,14 @@
|
||||
import { normalizeLowercaseStringOrEmpty } from "../../../../src/shared/string-coerce.js";
|
||||
import type { EmbeddingProvider } from "./embeddings.js";
|
||||
|
||||
const DEFAULT_EMBEDDING_MAX_INPUT_TOKENS = 8192;
|
||||
const DEFAULT_LOCAL_EMBEDDING_MAX_INPUT_TOKENS = 2048;
|
||||
|
||||
const KNOWN_EMBEDDING_MAX_INPUT_TOKENS: Record<string, number> = {
|
||||
"openai:text-embedding-3-small": 8192,
|
||||
"openai:text-embedding-3-large": 8192,
|
||||
"openai:text-embedding-ada-002": 8191,
|
||||
"gemini:text-embedding-004": 2048,
|
||||
"gemini:gemini-embedding-001": 2048,
|
||||
"gemini:gemini-embedding-2-preview": 8192,
|
||||
"voyage:voyage-3": 32000,
|
||||
"voyage:voyage-3-lite": 16000,
|
||||
"voyage:voyage-code-3": 32000,
|
||||
};
|
||||
|
||||
export function resolveEmbeddingMaxInputTokens(provider: EmbeddingProvider): number {
|
||||
if (typeof provider.maxInputTokens === "number") {
|
||||
return provider.maxInputTokens;
|
||||
}
|
||||
|
||||
// Provider/model mapping is best-effort; different providers use different
|
||||
// limits and we prefer to be conservative when we don't know.
|
||||
const key = normalizeLowercaseStringOrEmpty(`${provider.id}:${provider.model}`);
|
||||
const known = KNOWN_EMBEDDING_MAX_INPUT_TOKENS[key];
|
||||
if (typeof known === "number") {
|
||||
return known;
|
||||
}
|
||||
|
||||
// Provider-specific conservative fallbacks. This prevents us from accidentally
|
||||
// using the OpenAI default for providers with much smaller limits.
|
||||
if (normalizeLowercaseStringOrEmpty(provider.id) === "gemini") {
|
||||
return 2048;
|
||||
}
|
||||
if (normalizeLowercaseStringOrEmpty(provider.id) === "local") {
|
||||
if (provider.id === "local") {
|
||||
return DEFAULT_LOCAL_EMBEDDING_MAX_INPUT_TOKENS;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,377 +0,0 @@
|
||||
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
const { defaultProviderMock, resolveCredentialsMock, sendMock } = vi.hoisted(() => ({
|
||||
defaultProviderMock: vi.fn(),
|
||||
resolveCredentialsMock: vi.fn(),
|
||||
sendMock: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("@aws-sdk/client-bedrock-runtime", () => {
|
||||
class MockClient {
|
||||
region: string;
|
||||
constructor(config: { region: string }) {
|
||||
this.region = config.region;
|
||||
}
|
||||
send = sendMock;
|
||||
}
|
||||
class MockCommand {
|
||||
input: unknown;
|
||||
constructor(input: unknown) {
|
||||
this.input = input;
|
||||
}
|
||||
}
|
||||
return { BedrockRuntimeClient: MockClient, InvokeModelCommand: MockCommand };
|
||||
});
|
||||
|
||||
vi.mock("@aws-sdk/credential-provider-node", () => ({
|
||||
defaultProvider: defaultProviderMock.mockImplementation(() => resolveCredentialsMock),
|
||||
}));
|
||||
|
||||
let createBedrockEmbeddingProvider: typeof import("./embeddings-bedrock.js").createBedrockEmbeddingProvider;
|
||||
let resolveBedrockEmbeddingClient: typeof import("./embeddings-bedrock.js").resolveBedrockEmbeddingClient;
|
||||
let normalizeBedrockEmbeddingModel: typeof import("./embeddings-bedrock.js").normalizeBedrockEmbeddingModel;
|
||||
let hasAwsCredentials: typeof import("./embeddings-bedrock.js").hasAwsCredentials;
|
||||
|
||||
beforeAll(async () => {
|
||||
({
|
||||
createBedrockEmbeddingProvider,
|
||||
resolveBedrockEmbeddingClient,
|
||||
normalizeBedrockEmbeddingModel,
|
||||
hasAwsCredentials,
|
||||
} = await import("./embeddings-bedrock.js"));
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
defaultProviderMock.mockImplementation(() => resolveCredentialsMock);
|
||||
});
|
||||
|
||||
const enc = (body: unknown) => ({ body: new TextEncoder().encode(JSON.stringify(body)) });
|
||||
const reqBody = (i = 0): Record<string, unknown> =>
|
||||
JSON.parse(sendMock.mock.calls[i][0].input.body);
|
||||
|
||||
describe("bedrock embedding provider", () => {
|
||||
const originalEnv = process.env;
|
||||
afterEach(() => {
|
||||
process.env = originalEnv;
|
||||
vi.restoreAllMocks();
|
||||
defaultProviderMock.mockClear();
|
||||
resolveCredentialsMock.mockReset();
|
||||
sendMock.mockReset();
|
||||
});
|
||||
|
||||
// --- Normalization ---
|
||||
|
||||
it("normalizes model names with prefixes", () => {
|
||||
expect(normalizeBedrockEmbeddingModel("bedrock/amazon.titan-embed-text-v2:0")).toBe(
|
||||
"amazon.titan-embed-text-v2:0",
|
||||
);
|
||||
expect(normalizeBedrockEmbeddingModel("amazon-bedrock/cohere.embed-english-v3")).toBe(
|
||||
"cohere.embed-english-v3",
|
||||
);
|
||||
expect(normalizeBedrockEmbeddingModel("")).toBe("amazon.titan-embed-text-v2:0");
|
||||
});
|
||||
|
||||
// --- Client resolution ---
|
||||
|
||||
it("resolves region from env", () => {
|
||||
process.env = { ...originalEnv, AWS_REGION: "eu-west-1" };
|
||||
const c = resolveBedrockEmbeddingClient({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.titan-embed-text-v2:0",
|
||||
fallback: "none",
|
||||
});
|
||||
expect(c.region).toBe("eu-west-1");
|
||||
expect(c.dimensions).toBe(1024);
|
||||
});
|
||||
|
||||
it("defaults to us-east-1", () => {
|
||||
process.env = { ...originalEnv };
|
||||
delete process.env.AWS_REGION;
|
||||
delete process.env.AWS_DEFAULT_REGION;
|
||||
expect(
|
||||
resolveBedrockEmbeddingClient({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.titan-embed-text-v2:0",
|
||||
fallback: "none",
|
||||
}).region,
|
||||
).toBe("us-east-1");
|
||||
});
|
||||
|
||||
it("extracts region from baseUrl", () => {
|
||||
process.env = { ...originalEnv };
|
||||
delete process.env.AWS_REGION;
|
||||
const c = resolveBedrockEmbeddingClient({
|
||||
config: {
|
||||
models: {
|
||||
providers: {
|
||||
"amazon-bedrock": { baseUrl: "https://bedrock-runtime.ap-southeast-2.amazonaws.com" },
|
||||
},
|
||||
},
|
||||
} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.titan-embed-text-v2:0",
|
||||
fallback: "none",
|
||||
});
|
||||
expect(c.region).toBe("ap-southeast-2");
|
||||
});
|
||||
|
||||
it("validates dimensions", () => {
|
||||
expect(() =>
|
||||
resolveBedrockEmbeddingClient({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.titan-embed-text-v2:0",
|
||||
fallback: "none",
|
||||
outputDimensionality: 768,
|
||||
}),
|
||||
).toThrow("Invalid dimensions 768");
|
||||
});
|
||||
|
||||
it("accepts valid dimensions", () => {
|
||||
expect(
|
||||
resolveBedrockEmbeddingClient({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.titan-embed-text-v2:0",
|
||||
fallback: "none",
|
||||
outputDimensionality: 256,
|
||||
}).dimensions,
|
||||
).toBe(256);
|
||||
});
|
||||
|
||||
it("resolves throughput-suffixed variants", () => {
|
||||
expect(
|
||||
resolveBedrockEmbeddingClient({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.titan-embed-text-v1:2:8k",
|
||||
fallback: "none",
|
||||
}).dimensions,
|
||||
).toBe(1536);
|
||||
});
|
||||
|
||||
// --- Credential detection ---
|
||||
|
||||
it("detects access keys", async () => {
|
||||
await expect(
|
||||
hasAwsCredentials({
|
||||
AWS_ACCESS_KEY_ID: "A",
|
||||
AWS_SECRET_ACCESS_KEY: "s",
|
||||
} as NodeJS.ProcessEnv),
|
||||
).resolves.toBe(true);
|
||||
});
|
||||
it("detects profile", async () => {
|
||||
await expect(hasAwsCredentials({ AWS_PROFILE: "default" } as NodeJS.ProcessEnv)).resolves.toBe(
|
||||
true,
|
||||
);
|
||||
});
|
||||
it("detects ECS task role", async () => {
|
||||
await expect(
|
||||
hasAwsCredentials({ AWS_CONTAINER_CREDENTIALS_RELATIVE_URI: "/v2" } as NodeJS.ProcessEnv),
|
||||
).resolves.toBe(true);
|
||||
});
|
||||
it("detects EKS IRSA", async () => {
|
||||
await expect(
|
||||
hasAwsCredentials({
|
||||
AWS_WEB_IDENTITY_TOKEN_FILE: "/var/run/secrets/token",
|
||||
AWS_ROLE_ARN: "arn:aws:iam::123:role/x",
|
||||
} as NodeJS.ProcessEnv),
|
||||
).resolves.toBe(true);
|
||||
});
|
||||
it("detects credentials via the AWS SDK default provider chain", async () => {
|
||||
resolveCredentialsMock.mockResolvedValue({ accessKeyId: "AKIAEXAMPLE" });
|
||||
await expect(hasAwsCredentials({} as NodeJS.ProcessEnv)).resolves.toBe(true);
|
||||
expect(defaultProviderMock).toHaveBeenCalledWith({ timeout: 1000, maxRetries: 0 });
|
||||
});
|
||||
it("returns false with no creds", async () => {
|
||||
resolveCredentialsMock.mockRejectedValue(new Error("no aws credentials"));
|
||||
await expect(hasAwsCredentials({} as NodeJS.ProcessEnv)).resolves.toBe(false);
|
||||
});
|
||||
|
||||
// --- Titan V2 ---
|
||||
|
||||
it("embeds with Titan V2", async () => {
|
||||
sendMock.mockResolvedValue(enc({ embedding: [0.1, 0.2, 0.3] }));
|
||||
const { provider } = await createBedrockEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.titan-embed-text-v2:0",
|
||||
fallback: "none",
|
||||
});
|
||||
expect(await provider.embedQuery("test")).toHaveLength(3);
|
||||
expect(reqBody()).toMatchObject({ inputText: "test", normalize: true, dimensions: 1024 });
|
||||
});
|
||||
|
||||
it("returns empty for blank text", async () => {
|
||||
const { provider } = await createBedrockEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.titan-embed-text-v2:0",
|
||||
fallback: "none",
|
||||
});
|
||||
expect(await provider.embedQuery(" ")).toEqual([]);
|
||||
expect(sendMock).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("batches Titan V2 concurrently", async () => {
|
||||
sendMock
|
||||
.mockResolvedValueOnce(enc({ embedding: [0.1] }))
|
||||
.mockResolvedValueOnce(enc({ embedding: [0.2] }));
|
||||
const { provider } = await createBedrockEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.titan-embed-text-v2:0",
|
||||
fallback: "none",
|
||||
});
|
||||
expect(await provider.embedBatch(["a", "b"])).toHaveLength(2);
|
||||
expect(sendMock).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
// --- Titan V1 ---
|
||||
|
||||
it("sends only inputText for Titan V1", async () => {
|
||||
sendMock.mockResolvedValue(enc({ embedding: [0.5] }));
|
||||
const { provider } = await createBedrockEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.titan-embed-text-v1",
|
||||
fallback: "none",
|
||||
});
|
||||
await provider.embedQuery("hi");
|
||||
expect(reqBody()).toEqual({ inputText: "hi" });
|
||||
});
|
||||
|
||||
it("handles Titan G1 text variant", async () => {
|
||||
sendMock.mockResolvedValue(enc({ embedding: [0.1] }));
|
||||
const { provider } = await createBedrockEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.titan-embed-g1-text-02",
|
||||
fallback: "none",
|
||||
});
|
||||
await provider.embedQuery("hi");
|
||||
expect(reqBody()).toEqual({ inputText: "hi" });
|
||||
});
|
||||
|
||||
// --- Cohere V3 ---
|
||||
|
||||
it("embeds Cohere V3 batch in single call", async () => {
|
||||
sendMock.mockResolvedValue(enc({ embeddings: [[0.1], [0.2]] }));
|
||||
const { provider } = await createBedrockEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "cohere.embed-english-v3",
|
||||
fallback: "none",
|
||||
});
|
||||
expect(await provider.embedBatch(["a", "b"])).toHaveLength(2);
|
||||
expect(sendMock).toHaveBeenCalledTimes(1);
|
||||
expect(reqBody()).toMatchObject({ texts: ["a", "b"], input_type: "search_document" });
|
||||
});
|
||||
|
||||
it("uses search_query for Cohere embedQuery", async () => {
|
||||
sendMock.mockResolvedValue(enc({ embeddings: [[0.1]] }));
|
||||
const { provider } = await createBedrockEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "cohere.embed-english-v3",
|
||||
fallback: "none",
|
||||
});
|
||||
await provider.embedQuery("q");
|
||||
expect(reqBody().input_type).toBe("search_query");
|
||||
});
|
||||
|
||||
// --- Cohere V4 ---
|
||||
|
||||
it("embeds Cohere V4 with embedding_types + output_dimension", async () => {
|
||||
sendMock.mockResolvedValue(enc({ embeddings: { float: [[0.1], [0.2]] } }));
|
||||
const { provider } = await createBedrockEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "cohere.embed-v4:0",
|
||||
fallback: "none",
|
||||
});
|
||||
expect(await provider.embedBatch(["a", "b"])).toHaveLength(2);
|
||||
expect(reqBody()).toMatchObject({ embedding_types: ["float"], output_dimension: 1536 });
|
||||
});
|
||||
|
||||
it("validates Cohere V4 dimensions", () => {
|
||||
expect(() =>
|
||||
resolveBedrockEmbeddingClient({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "cohere.embed-v4:0",
|
||||
fallback: "none",
|
||||
outputDimensionality: 2048,
|
||||
}),
|
||||
).toThrow("Invalid dimensions 2048");
|
||||
});
|
||||
|
||||
// --- Nova ---
|
||||
|
||||
it("embeds Nova with SINGLE_EMBEDDING format", async () => {
|
||||
sendMock.mockResolvedValue(
|
||||
enc({ embeddings: [{ embeddingType: "TEXT", embedding: [0.1, 0.2] }] }),
|
||||
);
|
||||
const { provider } = await createBedrockEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.nova-2-multimodal-embeddings-v1:0",
|
||||
fallback: "none",
|
||||
});
|
||||
expect(await provider.embedQuery("hi")).toHaveLength(2);
|
||||
expect(reqBody().taskType).toBe("SINGLE_EMBEDDING");
|
||||
});
|
||||
|
||||
it("validates Nova dimensions", () => {
|
||||
expect(() =>
|
||||
resolveBedrockEmbeddingClient({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.nova-2-multimodal-embeddings-v1:0",
|
||||
fallback: "none",
|
||||
outputDimensionality: 512,
|
||||
}),
|
||||
).toThrow("Invalid dimensions 512");
|
||||
});
|
||||
|
||||
it("batches Nova concurrently", async () => {
|
||||
sendMock
|
||||
.mockResolvedValueOnce(enc({ embeddings: [{ embeddingType: "TEXT", embedding: [0.1] }] }))
|
||||
.mockResolvedValueOnce(enc({ embeddings: [{ embeddingType: "TEXT", embedding: [0.2] }] }));
|
||||
const { provider } = await createBedrockEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "amazon.nova-2-multimodal-embeddings-v1:0",
|
||||
fallback: "none",
|
||||
});
|
||||
expect(await provider.embedBatch(["a", "b"])).toHaveLength(2);
|
||||
expect(sendMock).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
// --- TwelveLabs ---
|
||||
|
||||
it("embeds TwelveLabs Marengo", async () => {
|
||||
sendMock.mockResolvedValue(enc({ data: [{ embedding: [0.1, 0.2] }] }));
|
||||
const { provider } = await createBedrockEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "twelvelabs.marengo-embed-3-0-v1:0",
|
||||
fallback: "none",
|
||||
});
|
||||
expect(await provider.embedQuery("hi")).toHaveLength(2);
|
||||
expect(reqBody()).toEqual({ inputType: "text", text: { inputText: "hi" } });
|
||||
});
|
||||
|
||||
it("embeds TwelveLabs object-style responses", async () => {
|
||||
sendMock.mockResolvedValue(enc({ data: { embedding: [0.3, 0.4] } }));
|
||||
const { provider } = await createBedrockEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "bedrock",
|
||||
model: "twelvelabs.marengo-embed-2-7-v1:0",
|
||||
fallback: "none",
|
||||
});
|
||||
expect(await provider.embedQuery("hi")).toEqual([0.6, 0.8]);
|
||||
});
|
||||
});
|
||||
@@ -1,398 +0,0 @@
|
||||
import { normalizeLowercaseStringOrEmpty } from "../../../../src/shared/string-coerce.js";
|
||||
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
|
||||
import { debugEmbeddingsLog } from "./embeddings-debug.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types & constants
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export type BedrockEmbeddingClient = {
|
||||
region: string;
|
||||
model: string;
|
||||
dimensions?: number;
|
||||
};
|
||||
|
||||
export const DEFAULT_BEDROCK_EMBEDDING_MODEL = "amazon.titan-embed-text-v2:0";
|
||||
|
||||
/** Request/response format family — each has a different API shape. */
|
||||
type Family = "titan-v1" | "titan-v2" | "cohere-v3" | "cohere-v4" | "nova" | "twelvelabs";
|
||||
|
||||
interface ModelSpec {
|
||||
maxTokens: number;
|
||||
dims: number;
|
||||
validDims?: number[];
|
||||
family: Family;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Model catalog
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const MODELS: Record<string, ModelSpec> = {
|
||||
"amazon.titan-embed-text-v2:0": {
|
||||
maxTokens: 8192,
|
||||
dims: 1024,
|
||||
validDims: [256, 512, 1024],
|
||||
family: "titan-v2",
|
||||
},
|
||||
"amazon.titan-embed-text-v1": { maxTokens: 8000, dims: 1536, family: "titan-v1" },
|
||||
"amazon.titan-embed-g1-text-02": { maxTokens: 8000, dims: 1536, family: "titan-v1" },
|
||||
"amazon.titan-embed-image-v1": { maxTokens: 128, dims: 1024, family: "titan-v1" },
|
||||
"cohere.embed-english-v3": { maxTokens: 512, dims: 1024, family: "cohere-v3" },
|
||||
"cohere.embed-multilingual-v3": { maxTokens: 512, dims: 1024, family: "cohere-v3" },
|
||||
"cohere.embed-v4:0": {
|
||||
maxTokens: 128000,
|
||||
dims: 1536,
|
||||
validDims: [256, 384, 512, 768, 1024, 1536],
|
||||
family: "cohere-v4",
|
||||
},
|
||||
"amazon.nova-2-multimodal-embeddings-v1:0": {
|
||||
maxTokens: 8192,
|
||||
dims: 1024,
|
||||
validDims: [256, 384, 1024, 3072],
|
||||
family: "nova",
|
||||
},
|
||||
"twelvelabs.marengo-embed-2-7-v1:0": { maxTokens: 512, dims: 1024, family: "twelvelabs" },
|
||||
"twelvelabs.marengo-embed-3-0-v1:0": { maxTokens: 512, dims: 512, family: "twelvelabs" },
|
||||
};
|
||||
|
||||
/** Resolve spec, stripping throughput suffixes like `:2:8k` or `:0:512`. */
|
||||
function resolveSpec(modelId: string): ModelSpec | undefined {
|
||||
if (MODELS[modelId]) {
|
||||
return MODELS[modelId];
|
||||
}
|
||||
const parts = modelId.split(":");
|
||||
for (let i = parts.length - 1; i >= 1; i--) {
|
||||
const spec = MODELS[parts.slice(0, i).join(":")];
|
||||
if (spec) {
|
||||
return spec;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/** Infer family from model ID prefix when not in catalog. */
|
||||
function inferFamily(modelId: string): Family {
|
||||
const id = normalizeLowercaseStringOrEmpty(modelId);
|
||||
if (id.startsWith("amazon.titan-embed-text-v2")) {
|
||||
return "titan-v2";
|
||||
}
|
||||
if (id.startsWith("amazon.titan-embed")) {
|
||||
return "titan-v1";
|
||||
}
|
||||
if (id.startsWith("amazon.nova")) {
|
||||
return "nova";
|
||||
}
|
||||
if (id.startsWith("cohere.embed-v4")) {
|
||||
return "cohere-v4";
|
||||
}
|
||||
if (id.startsWith("cohere.embed")) {
|
||||
return "cohere-v3";
|
||||
}
|
||||
if (id.startsWith("twelvelabs.")) {
|
||||
return "twelvelabs";
|
||||
}
|
||||
return "titan-v1"; // safest default — simplest request format
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// AWS SDK lazy loader
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type SdkClient = import("@aws-sdk/client-bedrock-runtime").BedrockRuntimeClient;
|
||||
type SdkCommand = import("@aws-sdk/client-bedrock-runtime").InvokeModelCommand;
|
||||
|
||||
interface AwsSdk {
|
||||
BedrockRuntimeClient: new (config: { region: string }) => SdkClient;
|
||||
InvokeModelCommand: new (input: {
|
||||
modelId: string;
|
||||
body: string;
|
||||
contentType: string;
|
||||
accept: string;
|
||||
}) => SdkCommand;
|
||||
}
|
||||
|
||||
interface AwsCredentialProviderSdk {
|
||||
defaultProvider: (init?: { timeout?: number; maxRetries?: number }) => () => Promise<{
|
||||
accessKeyId?: string;
|
||||
}>;
|
||||
}
|
||||
|
||||
let sdkCache: AwsSdk | null = null;
|
||||
let credentialProviderSdkCache: AwsCredentialProviderSdk | null | undefined;
|
||||
|
||||
async function loadSdk(): Promise<AwsSdk> {
|
||||
if (sdkCache) {
|
||||
return sdkCache;
|
||||
}
|
||||
try {
|
||||
sdkCache = (await import("@aws-sdk/client-bedrock-runtime")) as unknown as AwsSdk;
|
||||
return sdkCache;
|
||||
} catch {
|
||||
throw new Error(
|
||||
"No API key found for provider bedrock: @aws-sdk/client-bedrock-runtime is not installed. " +
|
||||
"Install it with: npm install @aws-sdk/client-bedrock-runtime",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async function loadCredentialProviderSdk(): Promise<AwsCredentialProviderSdk | null> {
|
||||
if (credentialProviderSdkCache !== undefined) {
|
||||
return credentialProviderSdkCache;
|
||||
}
|
||||
try {
|
||||
credentialProviderSdkCache =
|
||||
(await import("@aws-sdk/credential-provider-node")) as unknown as AwsCredentialProviderSdk;
|
||||
} catch {
|
||||
credentialProviderSdkCache = null;
|
||||
}
|
||||
return credentialProviderSdkCache;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const MODEL_PREFIX_RE = /^(?:bedrock|amazon-bedrock|aws)\//;
|
||||
const REGION_RE = /bedrock-runtime\.([a-z0-9-]+)\./;
|
||||
|
||||
export function normalizeBedrockEmbeddingModel(model: string): string {
|
||||
const trimmed = model.trim();
|
||||
return trimmed ? trimmed.replace(MODEL_PREFIX_RE, "") : DEFAULT_BEDROCK_EMBEDDING_MODEL;
|
||||
}
|
||||
|
||||
function regionFromUrl(url: string | undefined): string | undefined {
|
||||
return url?.trim() ? REGION_RE.exec(url)?.[1] : undefined;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Request builders
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function buildBody(family: Family, text: string, dims?: number): string {
|
||||
switch (family) {
|
||||
case "titan-v2": {
|
||||
const b: Record<string, unknown> = { inputText: text };
|
||||
if (dims != null) {
|
||||
b.dimensions = dims;
|
||||
b.normalize = true;
|
||||
}
|
||||
return JSON.stringify(b);
|
||||
}
|
||||
case "titan-v1":
|
||||
return JSON.stringify({ inputText: text });
|
||||
case "nova":
|
||||
return JSON.stringify({
|
||||
taskType: "SINGLE_EMBEDDING",
|
||||
singleEmbeddingParams: {
|
||||
embeddingPurpose: "GENERIC_INDEX",
|
||||
embeddingDimension: dims ?? 1024,
|
||||
text: { truncationMode: "END", value: text },
|
||||
},
|
||||
});
|
||||
case "twelvelabs":
|
||||
return JSON.stringify({ inputType: "text", text: { inputText: text } });
|
||||
default:
|
||||
return JSON.stringify({ inputText: text });
|
||||
}
|
||||
}
|
||||
|
||||
function buildCohereBody(
|
||||
family: Family,
|
||||
texts: string[],
|
||||
inputType: "search_query" | "search_document",
|
||||
dims?: number,
|
||||
): string {
|
||||
const body: Record<string, unknown> = { texts, input_type: inputType, truncate: "END" };
|
||||
if (family === "cohere-v4") {
|
||||
body.embedding_types = ["float"];
|
||||
if (dims != null) {
|
||||
body.output_dimension = dims;
|
||||
}
|
||||
}
|
||||
return JSON.stringify(body);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Response parsers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function parseSingle(family: Family, raw: string): number[] {
|
||||
const data = JSON.parse(raw);
|
||||
switch (family) {
|
||||
case "nova":
|
||||
return data.embeddings?.[0]?.embedding ?? [];
|
||||
case "twelvelabs": {
|
||||
if (Array.isArray(data.data)) {
|
||||
return data.data[0]?.embedding ?? [];
|
||||
}
|
||||
if (Array.isArray(data.data?.embedding)) {
|
||||
return data.data.embedding;
|
||||
}
|
||||
return data.embedding ?? [];
|
||||
}
|
||||
default:
|
||||
return data.embedding ?? [];
|
||||
}
|
||||
}
|
||||
|
||||
function parseCohereBatch(family: Family, raw: string): number[][] {
|
||||
const data = JSON.parse(raw);
|
||||
const embeddings = data.embeddings;
|
||||
if (!embeddings) {
|
||||
return [];
|
||||
}
|
||||
if (family === "cohere-v4" && !Array.isArray(embeddings)) {
|
||||
return embeddings.float ?? [];
|
||||
}
|
||||
return embeddings;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Provider
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export async function createBedrockEmbeddingProvider(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<{ provider: EmbeddingProvider; client: BedrockEmbeddingClient }> {
|
||||
const client = resolveBedrockEmbeddingClient(options);
|
||||
const { BedrockRuntimeClient, InvokeModelCommand } = await loadSdk();
|
||||
const sdk = new BedrockRuntimeClient({ region: client.region });
|
||||
const spec = resolveSpec(client.model);
|
||||
const family = spec?.family ?? inferFamily(client.model);
|
||||
|
||||
debugEmbeddingsLog("memory embeddings: bedrock client", {
|
||||
region: client.region,
|
||||
model: client.model,
|
||||
dimensions: client.dimensions,
|
||||
family,
|
||||
});
|
||||
|
||||
const invoke = async (body: string): Promise<string> => {
|
||||
const res = await sdk.send(
|
||||
new InvokeModelCommand({
|
||||
modelId: client.model,
|
||||
body,
|
||||
contentType: "application/json",
|
||||
accept: "application/json",
|
||||
}),
|
||||
);
|
||||
return new TextDecoder().decode(res.body);
|
||||
};
|
||||
|
||||
const isCohere = family === "cohere-v3" || family === "cohere-v4";
|
||||
|
||||
const embedSingle = async (text: string): Promise<number[]> => {
|
||||
const raw = await invoke(buildBody(family, text, client.dimensions));
|
||||
return sanitizeAndNormalizeEmbedding(parseSingle(family, raw));
|
||||
};
|
||||
|
||||
const embedCohere = async (
|
||||
texts: string[],
|
||||
inputType: "search_query" | "search_document",
|
||||
): Promise<number[][]> => {
|
||||
const raw = await invoke(buildCohereBody(family, texts, inputType, client.dimensions));
|
||||
return parseCohereBatch(family, raw).map((e) => sanitizeAndNormalizeEmbedding(e));
|
||||
};
|
||||
|
||||
const embedQuery = async (text: string): Promise<number[]> => {
|
||||
if (!text.trim()) {
|
||||
return [];
|
||||
}
|
||||
if (isCohere) {
|
||||
return (await embedCohere([text], "search_query"))[0] ?? [];
|
||||
}
|
||||
return embedSingle(text);
|
||||
};
|
||||
|
||||
const embedBatch = async (texts: string[]): Promise<number[][]> => {
|
||||
if (texts.length === 0) {
|
||||
return [];
|
||||
}
|
||||
if (isCohere) {
|
||||
return embedCohere(texts, "search_document");
|
||||
}
|
||||
return Promise.all(texts.map((t) => (t.trim() ? embedSingle(t) : Promise.resolve([]))));
|
||||
};
|
||||
|
||||
return {
|
||||
provider: {
|
||||
id: "bedrock",
|
||||
model: client.model,
|
||||
maxInputTokens: spec?.maxTokens,
|
||||
embedQuery,
|
||||
embedBatch,
|
||||
},
|
||||
client,
|
||||
};
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Client resolution
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export function resolveBedrockEmbeddingClient(
|
||||
options: EmbeddingProviderOptions,
|
||||
): BedrockEmbeddingClient {
|
||||
const model = normalizeBedrockEmbeddingModel(options.model);
|
||||
const spec = resolveSpec(model);
|
||||
const providerConfig = options.config.models?.providers?.["amazon-bedrock"];
|
||||
|
||||
const region =
|
||||
regionFromUrl(options.remote?.baseUrl) ??
|
||||
regionFromUrl(providerConfig?.baseUrl) ??
|
||||
process.env.AWS_REGION ??
|
||||
process.env.AWS_DEFAULT_REGION ??
|
||||
"us-east-1";
|
||||
|
||||
let dimensions: number | undefined;
|
||||
if (options.outputDimensionality != null) {
|
||||
if (spec?.validDims && !spec.validDims.includes(options.outputDimensionality)) {
|
||||
throw new Error(
|
||||
`Invalid dimensions ${options.outputDimensionality} for ${model}. Valid values: ${spec.validDims.join(", ")}`,
|
||||
);
|
||||
}
|
||||
dimensions = options.outputDimensionality;
|
||||
} else {
|
||||
dimensions = spec?.dims;
|
||||
}
|
||||
|
||||
return { region, model, dimensions };
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Credential detection
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const CREDENTIAL_ENV_VARS = [
|
||||
"AWS_PROFILE",
|
||||
"AWS_BEARER_TOKEN_BEDROCK",
|
||||
"AWS_CONTAINER_CREDENTIALS_RELATIVE_URI",
|
||||
"AWS_CONTAINER_CREDENTIALS_FULL_URI",
|
||||
"AWS_EC2_METADATA_SERVICE_ENDPOINT",
|
||||
"AWS_WEB_IDENTITY_TOKEN_FILE",
|
||||
"AWS_ROLE_ARN",
|
||||
] as const;
|
||||
|
||||
export async function hasAwsCredentials(env: NodeJS.ProcessEnv = process.env): Promise<boolean> {
|
||||
if (env.AWS_ACCESS_KEY_ID?.trim() && env.AWS_SECRET_ACCESS_KEY?.trim()) {
|
||||
return true;
|
||||
}
|
||||
if (CREDENTIAL_ENV_VARS.some((k) => env[k]?.trim())) {
|
||||
return true;
|
||||
}
|
||||
const credentialProviderSdk = await loadCredentialProviderSdk();
|
||||
if (!credentialProviderSdk) {
|
||||
return false;
|
||||
}
|
||||
try {
|
||||
const credentials = await credentialProviderSdk.defaultProvider({
|
||||
timeout: 1000,
|
||||
maxRetries: 0,
|
||||
})();
|
||||
return typeof credentials.accessKeyId === "string" && credentials.accessKeyId.trim().length > 0;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -1,121 +0,0 @@
|
||||
import type { EmbeddingInput } from "./embedding-inputs.js";
|
||||
|
||||
export const DEFAULT_GEMINI_EMBEDDING_MODEL = "gemini-embedding-001";
|
||||
|
||||
export const GEMINI_EMBEDDING_2_MODELS = new Set([
|
||||
"gemini-embedding-2-preview",
|
||||
// Add the GA model name here once released.
|
||||
]);
|
||||
|
||||
const GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS = 3072;
|
||||
const GEMINI_EMBEDDING_2_VALID_DIMENSIONS = [768, 1536, 3072] as const;
|
||||
|
||||
export type GeminiTaskType =
|
||||
| "RETRIEVAL_QUERY"
|
||||
| "RETRIEVAL_DOCUMENT"
|
||||
| "SEMANTIC_SIMILARITY"
|
||||
| "CLASSIFICATION"
|
||||
| "CLUSTERING"
|
||||
| "QUESTION_ANSWERING"
|
||||
| "FACT_VERIFICATION";
|
||||
|
||||
export type GeminiTextPart = { text: string };
|
||||
export type GeminiInlinePart = {
|
||||
inlineData: { mimeType: string; data: string };
|
||||
};
|
||||
export type GeminiPart = GeminiTextPart | GeminiInlinePart;
|
||||
export type GeminiEmbeddingRequest = {
|
||||
content: { parts: GeminiPart[] };
|
||||
taskType: GeminiTaskType;
|
||||
outputDimensionality?: number;
|
||||
model?: string;
|
||||
};
|
||||
export type GeminiTextEmbeddingRequest = GeminiEmbeddingRequest;
|
||||
|
||||
/** Builds the text-only Gemini embedding request shape used across direct and batch APIs. */
|
||||
export function buildGeminiTextEmbeddingRequest(params: {
|
||||
text: string;
|
||||
taskType: GeminiTaskType;
|
||||
outputDimensionality?: number;
|
||||
modelPath?: string;
|
||||
}): GeminiTextEmbeddingRequest {
|
||||
return buildGeminiEmbeddingRequest({
|
||||
input: { text: params.text },
|
||||
taskType: params.taskType,
|
||||
outputDimensionality: params.outputDimensionality,
|
||||
modelPath: params.modelPath,
|
||||
});
|
||||
}
|
||||
|
||||
export function buildGeminiEmbeddingRequest(params: {
|
||||
input: EmbeddingInput;
|
||||
taskType: GeminiTaskType;
|
||||
outputDimensionality?: number;
|
||||
modelPath?: string;
|
||||
}): GeminiEmbeddingRequest {
|
||||
const request: GeminiEmbeddingRequest = {
|
||||
content: {
|
||||
parts: params.input.parts?.map((part) =>
|
||||
part.type === "text"
|
||||
? ({ text: part.text } satisfies GeminiTextPart)
|
||||
: ({
|
||||
inlineData: { mimeType: part.mimeType, data: part.data },
|
||||
} satisfies GeminiInlinePart),
|
||||
) ?? [{ text: params.input.text }],
|
||||
},
|
||||
taskType: params.taskType,
|
||||
};
|
||||
if (params.modelPath) {
|
||||
request.model = params.modelPath;
|
||||
}
|
||||
if (params.outputDimensionality != null) {
|
||||
request.outputDimensionality = params.outputDimensionality;
|
||||
}
|
||||
return request;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if the given model name is a gemini-embedding-2 variant that
|
||||
* supports `outputDimensionality` and extended task types.
|
||||
*/
|
||||
export function isGeminiEmbedding2Model(model: string): boolean {
|
||||
return GEMINI_EMBEDDING_2_MODELS.has(model);
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate and return the `outputDimensionality` for gemini-embedding-2 models.
|
||||
* Returns `undefined` for older models (they don't support the param).
|
||||
*/
|
||||
export function resolveGeminiOutputDimensionality(
|
||||
model: string,
|
||||
requested?: number,
|
||||
): number | undefined {
|
||||
if (!isGeminiEmbedding2Model(model)) {
|
||||
return undefined;
|
||||
}
|
||||
if (requested == null) {
|
||||
return GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS;
|
||||
}
|
||||
const valid: readonly number[] = GEMINI_EMBEDDING_2_VALID_DIMENSIONS;
|
||||
if (!valid.includes(requested)) {
|
||||
throw new Error(
|
||||
`Invalid outputDimensionality ${requested} for ${model}. Valid values: ${valid.join(", ")}`,
|
||||
);
|
||||
}
|
||||
return requested;
|
||||
}
|
||||
|
||||
export function normalizeGeminiModel(model: string): string {
|
||||
const trimmed = model.trim();
|
||||
if (!trimmed) {
|
||||
return DEFAULT_GEMINI_EMBEDDING_MODEL;
|
||||
}
|
||||
const withoutPrefix = trimmed.replace(/^models\//, "");
|
||||
if (withoutPrefix.startsWith("gemini/")) {
|
||||
return withoutPrefix.slice("gemini/".length);
|
||||
}
|
||||
if (withoutPrefix.startsWith("google/")) {
|
||||
return withoutPrefix.slice("google/".length);
|
||||
}
|
||||
return withoutPrefix;
|
||||
}
|
||||
@@ -1,52 +0,0 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import {
|
||||
buildGeminiEmbeddingRequest,
|
||||
DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
normalizeGeminiModel,
|
||||
resolveGeminiOutputDimensionality,
|
||||
} from "./embeddings-gemini-request.js";
|
||||
|
||||
describe("package Gemini embedding request helpers", () => {
|
||||
it("builds multimodal v2 requests and resolves model settings", () => {
|
||||
expect(
|
||||
buildGeminiEmbeddingRequest({
|
||||
input: {
|
||||
text: "Image file: diagram.png",
|
||||
parts: [
|
||||
{ type: "text", text: "Image file: diagram.png" },
|
||||
{ type: "inline-data", mimeType: "image/png", data: "abc123" },
|
||||
],
|
||||
},
|
||||
taskType: "RETRIEVAL_DOCUMENT",
|
||||
modelPath: "models/gemini-embedding-2-preview",
|
||||
outputDimensionality: 1536,
|
||||
}),
|
||||
).toEqual({
|
||||
model: "models/gemini-embedding-2-preview",
|
||||
content: {
|
||||
parts: [
|
||||
{ text: "Image file: diagram.png" },
|
||||
{ inlineData: { mimeType: "image/png", data: "abc123" } },
|
||||
],
|
||||
},
|
||||
taskType: "RETRIEVAL_DOCUMENT",
|
||||
outputDimensionality: 1536,
|
||||
});
|
||||
expect(resolveGeminiOutputDimensionality("gemini-embedding-001")).toBeUndefined();
|
||||
expect(resolveGeminiOutputDimensionality("gemini-embedding-2-preview")).toBe(3072);
|
||||
expect(resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 768)).toBe(768);
|
||||
expect(() => resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 512)).toThrow(
|
||||
/Invalid outputDimensionality 512/,
|
||||
);
|
||||
expect(normalizeGeminiModel("models/gemini-embedding-2-preview")).toBe(
|
||||
"gemini-embedding-2-preview",
|
||||
);
|
||||
expect(normalizeGeminiModel("gemini/gemini-embedding-2-preview")).toBe(
|
||||
"gemini-embedding-2-preview",
|
||||
);
|
||||
expect(normalizeGeminiModel("google/gemini-embedding-2-preview")).toBe(
|
||||
"gemini-embedding-2-preview",
|
||||
);
|
||||
expect(normalizeGeminiModel("")).toBe(DEFAULT_GEMINI_EMBEDDING_MODEL);
|
||||
});
|
||||
});
|
||||
@@ -1,238 +0,0 @@
|
||||
import {
|
||||
collectProviderApiKeysForExecution,
|
||||
executeWithApiKeyRotation,
|
||||
} from "../../../../src/agents/api-key-rotation.js";
|
||||
import { requireApiKey, resolveApiKeyForProvider } from "../../../../src/agents/model-auth.js";
|
||||
import { parseGeminiAuth } from "../../../../src/infra/gemini-auth.js";
|
||||
import {
|
||||
DEFAULT_GOOGLE_API_BASE_URL,
|
||||
normalizeGoogleApiBaseUrl,
|
||||
} from "../../../../src/infra/google-api-base-url.js";
|
||||
import type { SsrFPolicy } from "../../../../src/infra/net/ssrf.js";
|
||||
import type { EmbeddingInput } from "./embedding-inputs.js";
|
||||
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
|
||||
import { debugEmbeddingsLog } from "./embeddings-debug.js";
|
||||
import {
|
||||
buildGeminiEmbeddingRequest,
|
||||
buildGeminiTextEmbeddingRequest,
|
||||
isGeminiEmbedding2Model,
|
||||
normalizeGeminiModel,
|
||||
resolveGeminiOutputDimensionality,
|
||||
} from "./embeddings-gemini-request.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
||||
import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js";
|
||||
import { resolveMemorySecretInputString } from "./secret-input.js";
|
||||
|
||||
export {
|
||||
buildGeminiEmbeddingRequest,
|
||||
buildGeminiTextEmbeddingRequest,
|
||||
DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
GEMINI_EMBEDDING_2_MODELS,
|
||||
isGeminiEmbedding2Model,
|
||||
normalizeGeminiModel,
|
||||
resolveGeminiOutputDimensionality,
|
||||
type GeminiEmbeddingRequest,
|
||||
type GeminiInlinePart,
|
||||
type GeminiPart,
|
||||
type GeminiTaskType,
|
||||
type GeminiTextEmbeddingRequest,
|
||||
type GeminiTextPart,
|
||||
} from "./embeddings-gemini-request.js";
|
||||
|
||||
export type GeminiEmbeddingClient = {
|
||||
baseUrl: string;
|
||||
headers: Record<string, string>;
|
||||
ssrfPolicy?: SsrFPolicy;
|
||||
model: string;
|
||||
modelPath: string;
|
||||
apiKeys: string[];
|
||||
outputDimensionality?: number;
|
||||
};
|
||||
|
||||
const GEMINI_MAX_INPUT_TOKENS: Record<string, number> = {
|
||||
"text-embedding-004": 2048,
|
||||
};
|
||||
function resolveRemoteApiKey(remoteApiKey: unknown): string | undefined {
|
||||
const trimmed = resolveMemorySecretInputString({
|
||||
value: remoteApiKey,
|
||||
path: "agents.*.memorySearch.remote.apiKey",
|
||||
});
|
||||
if (!trimmed) {
|
||||
return undefined;
|
||||
}
|
||||
if (trimmed === "GOOGLE_API_KEY" || trimmed === "GEMINI_API_KEY") {
|
||||
return process.env[trimmed]?.trim();
|
||||
}
|
||||
return trimmed;
|
||||
}
|
||||
|
||||
async function fetchGeminiEmbeddingPayload(params: {
|
||||
client: GeminiEmbeddingClient;
|
||||
endpoint: string;
|
||||
body: unknown;
|
||||
}): Promise<{
|
||||
embedding?: { values?: number[] };
|
||||
embeddings?: Array<{ values?: number[] }>;
|
||||
}> {
|
||||
return await executeWithApiKeyRotation({
|
||||
provider: "google",
|
||||
apiKeys: params.client.apiKeys,
|
||||
execute: async (apiKey) => {
|
||||
const authHeaders = parseGeminiAuth(apiKey);
|
||||
const headers = {
|
||||
...authHeaders.headers,
|
||||
...params.client.headers,
|
||||
};
|
||||
return await withRemoteHttpResponse({
|
||||
url: params.endpoint,
|
||||
ssrfPolicy: params.client.ssrfPolicy,
|
||||
init: {
|
||||
method: "POST",
|
||||
headers,
|
||||
body: JSON.stringify(params.body),
|
||||
},
|
||||
onResponse: async (res) => {
|
||||
if (!res.ok) {
|
||||
const text = await res.text();
|
||||
throw new Error(`gemini embeddings failed: ${res.status} ${text}`);
|
||||
}
|
||||
return (await res.json()) as {
|
||||
embedding?: { values?: number[] };
|
||||
embeddings?: Array<{ values?: number[] }>;
|
||||
};
|
||||
},
|
||||
});
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
function normalizeGeminiBaseUrl(raw: string): string {
|
||||
const trimmed = raw.replace(/\/+$/, "");
|
||||
const openAiIndex = trimmed.indexOf("/openai");
|
||||
if (openAiIndex > -1) {
|
||||
return normalizeGoogleApiBaseUrl(trimmed.slice(0, openAiIndex));
|
||||
}
|
||||
return normalizeGoogleApiBaseUrl(trimmed);
|
||||
}
|
||||
|
||||
function buildGeminiModelPath(model: string): string {
|
||||
return model.startsWith("models/") ? model : `models/${model}`;
|
||||
}
|
||||
|
||||
export async function createGeminiEmbeddingProvider(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<{ provider: EmbeddingProvider; client: GeminiEmbeddingClient }> {
|
||||
const client = await resolveGeminiEmbeddingClient(options);
|
||||
const baseUrl = client.baseUrl.replace(/\/$/, "");
|
||||
const embedUrl = `${baseUrl}/${client.modelPath}:embedContent`;
|
||||
const batchUrl = `${baseUrl}/${client.modelPath}:batchEmbedContents`;
|
||||
const isV2 = isGeminiEmbedding2Model(client.model);
|
||||
const outputDimensionality = client.outputDimensionality;
|
||||
|
||||
const embedQuery = async (text: string): Promise<number[]> => {
|
||||
if (!text.trim()) {
|
||||
return [];
|
||||
}
|
||||
const payload = await fetchGeminiEmbeddingPayload({
|
||||
client,
|
||||
endpoint: embedUrl,
|
||||
body: buildGeminiTextEmbeddingRequest({
|
||||
text,
|
||||
taskType: options.taskType ?? "RETRIEVAL_QUERY",
|
||||
outputDimensionality: isV2 ? outputDimensionality : undefined,
|
||||
}),
|
||||
});
|
||||
return sanitizeAndNormalizeEmbedding(payload.embedding?.values ?? []);
|
||||
};
|
||||
|
||||
const embedBatchInputs = async (inputs: EmbeddingInput[]): Promise<number[][]> => {
|
||||
if (inputs.length === 0) {
|
||||
return [];
|
||||
}
|
||||
const payload = await fetchGeminiEmbeddingPayload({
|
||||
client,
|
||||
endpoint: batchUrl,
|
||||
body: {
|
||||
requests: inputs.map((input) =>
|
||||
buildGeminiEmbeddingRequest({
|
||||
input,
|
||||
modelPath: client.modelPath,
|
||||
taskType: options.taskType ?? "RETRIEVAL_DOCUMENT",
|
||||
outputDimensionality: isV2 ? outputDimensionality : undefined,
|
||||
}),
|
||||
),
|
||||
},
|
||||
});
|
||||
const embeddings = Array.isArray(payload.embeddings) ? payload.embeddings : [];
|
||||
return inputs.map((_, index) => sanitizeAndNormalizeEmbedding(embeddings[index]?.values ?? []));
|
||||
};
|
||||
|
||||
const embedBatch = async (texts: string[]): Promise<number[][]> => {
|
||||
return await embedBatchInputs(
|
||||
texts.map((text) => ({
|
||||
text,
|
||||
})),
|
||||
);
|
||||
};
|
||||
|
||||
return {
|
||||
provider: {
|
||||
id: "gemini",
|
||||
model: client.model,
|
||||
maxInputTokens: GEMINI_MAX_INPUT_TOKENS[client.model],
|
||||
embedQuery,
|
||||
embedBatch,
|
||||
embedBatchInputs,
|
||||
},
|
||||
client,
|
||||
};
|
||||
}
|
||||
|
||||
export async function resolveGeminiEmbeddingClient(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<GeminiEmbeddingClient> {
|
||||
const remote = options.remote;
|
||||
const remoteApiKey = resolveRemoteApiKey(remote?.apiKey);
|
||||
const remoteBaseUrl = remote?.baseUrl?.trim();
|
||||
|
||||
const apiKey = remoteApiKey
|
||||
? remoteApiKey
|
||||
: requireApiKey(
|
||||
await resolveApiKeyForProvider({
|
||||
provider: "google",
|
||||
cfg: options.config,
|
||||
agentDir: options.agentDir,
|
||||
}),
|
||||
"google",
|
||||
);
|
||||
|
||||
const providerConfig = options.config.models?.providers?.google;
|
||||
const rawBaseUrl =
|
||||
remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_GOOGLE_API_BASE_URL;
|
||||
const baseUrl = normalizeGeminiBaseUrl(rawBaseUrl);
|
||||
const ssrfPolicy = buildRemoteBaseUrlPolicy(baseUrl);
|
||||
const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers);
|
||||
const headers: Record<string, string> = {
|
||||
...headerOverrides,
|
||||
};
|
||||
const apiKeys = collectProviderApiKeysForExecution({
|
||||
provider: "google",
|
||||
primaryApiKey: apiKey,
|
||||
});
|
||||
const model = normalizeGeminiModel(options.model);
|
||||
const modelPath = buildGeminiModelPath(model);
|
||||
const outputDimensionality = resolveGeminiOutputDimensionality(
|
||||
model,
|
||||
options.outputDimensionality,
|
||||
);
|
||||
debugEmbeddingsLog("memory embeddings: gemini client", {
|
||||
rawBaseUrl,
|
||||
baseUrl,
|
||||
model,
|
||||
modelPath,
|
||||
outputDimensionality,
|
||||
embedEndpoint: `${baseUrl}/${modelPath}:embedContent`,
|
||||
batchEndpoint: `${baseUrl}/${modelPath}:batchEmbedContents`,
|
||||
});
|
||||
return { baseUrl, headers, ssrfPolicy, model, modelPath, apiKeys, outputDimensionality };
|
||||
}
|
||||
@@ -1 +0,0 @@
|
||||
export * from "../../../../src/memory-host-sdk/host/embeddings-lmstudio.js";
|
||||
@@ -1,19 +0,0 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { DEFAULT_MISTRAL_EMBEDDING_MODEL, normalizeMistralModel } from "./embeddings-mistral.js";
|
||||
|
||||
describe("normalizeMistralModel", () => {
|
||||
it("returns the default model for empty values", () => {
|
||||
expect(normalizeMistralModel("")).toBe(DEFAULT_MISTRAL_EMBEDDING_MODEL);
|
||||
expect(normalizeMistralModel(" ")).toBe(DEFAULT_MISTRAL_EMBEDDING_MODEL);
|
||||
});
|
||||
|
||||
it("strips the mistral/ prefix", () => {
|
||||
expect(normalizeMistralModel("mistral/mistral-embed")).toBe("mistral-embed");
|
||||
expect(normalizeMistralModel(" mistral/custom-embed ")).toBe("custom-embed");
|
||||
});
|
||||
|
||||
it("keeps explicit non-prefixed models", () => {
|
||||
expect(normalizeMistralModel("mistral-embed")).toBe("mistral-embed");
|
||||
expect(normalizeMistralModel("custom-embed-v2")).toBe("custom-embed-v2");
|
||||
});
|
||||
});
|
||||
@@ -1,51 +0,0 @@
|
||||
import type { SsrFPolicy } from "../../../../src/infra/net/ssrf.js";
|
||||
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
|
||||
import {
|
||||
createRemoteEmbeddingProvider,
|
||||
resolveRemoteEmbeddingClient,
|
||||
} from "./embeddings-remote-provider.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
||||
|
||||
export type MistralEmbeddingClient = {
|
||||
baseUrl: string;
|
||||
headers: Record<string, string>;
|
||||
ssrfPolicy?: SsrFPolicy;
|
||||
model: string;
|
||||
};
|
||||
|
||||
export const DEFAULT_MISTRAL_EMBEDDING_MODEL = "mistral-embed";
|
||||
const DEFAULT_MISTRAL_BASE_URL = "https://api.mistral.ai/v1";
|
||||
|
||||
export function normalizeMistralModel(model: string): string {
|
||||
return normalizeEmbeddingModelWithPrefixes({
|
||||
model,
|
||||
defaultModel: DEFAULT_MISTRAL_EMBEDDING_MODEL,
|
||||
prefixes: ["mistral/"],
|
||||
});
|
||||
}
|
||||
|
||||
export async function createMistralEmbeddingProvider(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<{ provider: EmbeddingProvider; client: MistralEmbeddingClient }> {
|
||||
const client = await resolveMistralEmbeddingClient(options);
|
||||
|
||||
return {
|
||||
provider: createRemoteEmbeddingProvider({
|
||||
id: "mistral",
|
||||
client,
|
||||
errorPrefix: "mistral embeddings failed",
|
||||
}),
|
||||
client,
|
||||
};
|
||||
}
|
||||
|
||||
export async function resolveMistralEmbeddingClient(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<MistralEmbeddingClient> {
|
||||
return await resolveRemoteEmbeddingClient({
|
||||
provider: "mistral",
|
||||
options,
|
||||
defaultBaseUrl: DEFAULT_MISTRAL_BASE_URL,
|
||||
normalizeModel: normalizeMistralModel,
|
||||
});
|
||||
}
|
||||
@@ -1,43 +0,0 @@
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
const { createOllamaEmbeddingProviderMock } = vi.hoisted(() => ({
|
||||
createOllamaEmbeddingProviderMock: vi.fn(async (options: unknown) => ({
|
||||
provider: { source: "mock-provider", options },
|
||||
client: { source: "mock-client" },
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock("../../../../src/plugin-sdk/ollama-runtime.js", () => ({
|
||||
DEFAULT_OLLAMA_EMBEDDING_MODEL: "nomic-embed-text",
|
||||
createOllamaEmbeddingProvider: createOllamaEmbeddingProviderMock,
|
||||
}));
|
||||
|
||||
describe("memory-host-sdk Ollama embedding facade", () => {
|
||||
beforeEach(() => {
|
||||
createOllamaEmbeddingProviderMock.mockClear();
|
||||
});
|
||||
|
||||
it("re-exports the default Ollama embedding model", async () => {
|
||||
const mod = await import("./embeddings-ollama.js");
|
||||
expect(mod.DEFAULT_OLLAMA_EMBEDDING_MODEL).toBe("nomic-embed-text");
|
||||
});
|
||||
|
||||
it("delegates provider creation to the plugin-sdk runtime facade", async () => {
|
||||
const mod = await import("./embeddings-ollama.js");
|
||||
const options = {
|
||||
provider: "ollama",
|
||||
model: "nomic-embed-text",
|
||||
fallback: "none",
|
||||
config: {},
|
||||
};
|
||||
|
||||
const result = await mod.createOllamaEmbeddingProvider(options as never);
|
||||
|
||||
expect(createOllamaEmbeddingProviderMock).toHaveBeenCalledTimes(1);
|
||||
expect(createOllamaEmbeddingProviderMock).toHaveBeenCalledWith(options);
|
||||
expect(result).toEqual({
|
||||
provider: { source: "mock-provider", options },
|
||||
client: { source: "mock-client" },
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,5 +0,0 @@
|
||||
export type { OllamaEmbeddingClient } from "../../../../src/plugin-sdk/ollama-runtime.js";
|
||||
export {
|
||||
createOllamaEmbeddingProvider,
|
||||
DEFAULT_OLLAMA_EMBEDDING_MODEL,
|
||||
} from "../../../../src/plugin-sdk/ollama-runtime.js";
|
||||
@@ -1,58 +0,0 @@
|
||||
import type { SsrFPolicy } from "../../../../src/infra/net/ssrf.js";
|
||||
import { OPENAI_DEFAULT_EMBEDDING_MODEL } from "../../../../src/plugins/provider-model-defaults.js";
|
||||
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
|
||||
import {
|
||||
createRemoteEmbeddingProvider,
|
||||
resolveRemoteEmbeddingClient,
|
||||
} from "./embeddings-remote-provider.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
||||
|
||||
export type OpenAiEmbeddingClient = {
|
||||
baseUrl: string;
|
||||
headers: Record<string, string>;
|
||||
ssrfPolicy?: SsrFPolicy;
|
||||
model: string;
|
||||
};
|
||||
|
||||
const DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1";
|
||||
export const DEFAULT_OPENAI_EMBEDDING_MODEL = OPENAI_DEFAULT_EMBEDDING_MODEL;
|
||||
const OPENAI_MAX_INPUT_TOKENS: Record<string, number> = {
|
||||
"text-embedding-3-small": 8192,
|
||||
"text-embedding-3-large": 8192,
|
||||
"text-embedding-ada-002": 8191,
|
||||
};
|
||||
|
||||
export function normalizeOpenAiModel(model: string): string {
|
||||
return normalizeEmbeddingModelWithPrefixes({
|
||||
model,
|
||||
defaultModel: DEFAULT_OPENAI_EMBEDDING_MODEL,
|
||||
prefixes: ["openai/"],
|
||||
});
|
||||
}
|
||||
|
||||
export async function createOpenAiEmbeddingProvider(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<{ provider: EmbeddingProvider; client: OpenAiEmbeddingClient }> {
|
||||
const client = await resolveOpenAiEmbeddingClient(options);
|
||||
|
||||
return {
|
||||
provider: createRemoteEmbeddingProvider({
|
||||
id: "openai",
|
||||
client,
|
||||
errorPrefix: "openai embeddings failed",
|
||||
maxInputTokens: OPENAI_MAX_INPUT_TOKENS[client.model],
|
||||
}),
|
||||
client,
|
||||
};
|
||||
}
|
||||
|
||||
export async function resolveOpenAiEmbeddingClient(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<OpenAiEmbeddingClient> {
|
||||
return await resolveRemoteEmbeddingClient({
|
||||
provider: "openai",
|
||||
options,
|
||||
defaultBaseUrl: DEFAULT_OPENAI_BASE_URL,
|
||||
normalizeModel: normalizeOpenAiModel,
|
||||
});
|
||||
}
|
||||
@@ -4,7 +4,7 @@ import type { EmbeddingProviderOptions } from "./embeddings.js";
|
||||
import { buildRemoteBaseUrlPolicy } from "./remote-http.js";
|
||||
import { resolveMemorySecretInputString } from "./secret-input.js";
|
||||
|
||||
export type RemoteEmbeddingProviderId = "openai" | "voyage" | "mistral";
|
||||
export type RemoteEmbeddingProviderId = string;
|
||||
|
||||
export async function resolveRemoteEmbeddingBearerClient(params: {
|
||||
provider: RemoteEmbeddingProviderId;
|
||||
|
||||
@@ -1,188 +0,0 @@
|
||||
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import * as authModule from "../../../../src/agents/model-auth.js";
|
||||
import { type FetchMock, withFetchPreconnect } from "../../../../src/test-utils/fetch-mock.js";
|
||||
import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js";
|
||||
|
||||
vi.mock("../../../../src/infra/net/fetch-guard.js", () => ({
|
||||
fetchWithSsrFGuard: async (params: {
|
||||
url: string;
|
||||
init?: RequestInit;
|
||||
fetchImpl?: typeof fetch;
|
||||
}) => {
|
||||
const fetchImpl = params.fetchImpl ?? globalThis.fetch;
|
||||
if (!fetchImpl) {
|
||||
throw new Error("fetch is not available");
|
||||
}
|
||||
const response = await fetchImpl(params.url, params.init);
|
||||
return {
|
||||
response,
|
||||
finalUrl: params.url,
|
||||
release: async () => {},
|
||||
};
|
||||
},
|
||||
}));
|
||||
|
||||
const { resolveApiKeyForProviderMock } = vi.hoisted(() => ({
|
||||
resolveApiKeyForProviderMock: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("../../../../src/agents/model-auth.js", () => {
|
||||
return {
|
||||
resolveApiKeyForProvider: resolveApiKeyForProviderMock,
|
||||
requireApiKey: (auth: { apiKey?: string; mode?: string }, provider: string) => {
|
||||
if (auth.apiKey) {
|
||||
return auth.apiKey;
|
||||
}
|
||||
throw new Error(`No API key resolved for provider "${provider}" (auth mode: ${auth.mode}).`);
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
const createFetchMock = () => {
|
||||
const fetchMock = vi.fn<FetchMock>(
|
||||
async (_input: RequestInfo | URL, _init?: RequestInit) =>
|
||||
new Response(JSON.stringify({ data: [{ embedding: [0.1, 0.2, 0.3] }] }), {
|
||||
status: 200,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
}),
|
||||
);
|
||||
return withFetchPreconnect(fetchMock);
|
||||
};
|
||||
|
||||
function installFetchMock(fetchMock: typeof globalThis.fetch) {
|
||||
vi.stubGlobal("fetch", fetchMock);
|
||||
}
|
||||
|
||||
let createVoyageEmbeddingProvider: typeof import("./embeddings-voyage.js").createVoyageEmbeddingProvider;
|
||||
let normalizeVoyageModel: typeof import("./embeddings-voyage.js").normalizeVoyageModel;
|
||||
|
||||
beforeAll(async () => {
|
||||
({ createVoyageEmbeddingProvider, normalizeVoyageModel } =
|
||||
await import("./embeddings-voyage.js"));
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
vi.useRealTimers();
|
||||
vi.doUnmock("undici");
|
||||
});
|
||||
|
||||
function mockVoyageApiKey() {
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({
|
||||
apiKey: "voyage-key-123",
|
||||
mode: "api-key",
|
||||
source: "test",
|
||||
});
|
||||
}
|
||||
|
||||
async function createDefaultVoyageProvider(
|
||||
model: string,
|
||||
fetchMock: ReturnType<typeof createFetchMock>,
|
||||
) {
|
||||
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
|
||||
mockPublicPinnedHostname();
|
||||
mockVoyageApiKey();
|
||||
return createVoyageEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "voyage",
|
||||
model,
|
||||
fallback: "none",
|
||||
});
|
||||
}
|
||||
|
||||
describe("voyage embedding provider", () => {
|
||||
afterEach(() => {
|
||||
vi.doUnmock("undici");
|
||||
vi.resetAllMocks();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
it("configures client with correct defaults and headers", async () => {
|
||||
const fetchMock = createFetchMock();
|
||||
const result = await createDefaultVoyageProvider("voyage-4-large", fetchMock);
|
||||
|
||||
await result.provider.embedQuery("test query");
|
||||
|
||||
expect(authModule.resolveApiKeyForProvider).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ provider: "voyage" }),
|
||||
);
|
||||
|
||||
const call = fetchMock.mock.calls[0];
|
||||
expect(call).toBeDefined();
|
||||
const [url, init] = call as [RequestInfo | URL, RequestInit | undefined];
|
||||
expect(url).toBe("https://api.voyageai.com/v1/embeddings");
|
||||
|
||||
const headers = (init?.headers ?? {}) as Record<string, string>;
|
||||
expect(headers.Authorization).toBe("Bearer voyage-key-123");
|
||||
expect(headers["Content-Type"]).toBe("application/json");
|
||||
|
||||
const body = JSON.parse(init?.body as string);
|
||||
expect(body).toEqual({
|
||||
model: "voyage-4-large",
|
||||
input: ["test query"],
|
||||
input_type: "query",
|
||||
});
|
||||
});
|
||||
|
||||
it("respects remote overrides for baseUrl and apiKey", async () => {
|
||||
const fetchMock = createFetchMock();
|
||||
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
|
||||
mockPublicPinnedHostname();
|
||||
|
||||
const result = await createVoyageEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "voyage",
|
||||
model: "voyage-4-lite",
|
||||
fallback: "none",
|
||||
remote: {
|
||||
baseUrl: "https://example.com",
|
||||
apiKey: "remote-override-key",
|
||||
headers: { "X-Custom": "123" },
|
||||
},
|
||||
});
|
||||
|
||||
await result.provider.embedQuery("test");
|
||||
|
||||
const call = fetchMock.mock.calls[0];
|
||||
expect(call).toBeDefined();
|
||||
const [url, init] = call as [RequestInfo | URL, RequestInit | undefined];
|
||||
expect(url).toBe("https://example.com/embeddings");
|
||||
|
||||
const headers = (init?.headers ?? {}) as Record<string, string>;
|
||||
expect(headers.Authorization).toBe("Bearer remote-override-key");
|
||||
expect(headers["X-Custom"]).toBe("123");
|
||||
});
|
||||
|
||||
it("passes input_type=document for embedBatch", async () => {
|
||||
const fetchMock = withFetchPreconnect(
|
||||
vi.fn<FetchMock>(
|
||||
async (_input: RequestInfo | URL, _init?: RequestInit) =>
|
||||
new Response(
|
||||
JSON.stringify({
|
||||
data: [{ embedding: [0.1, 0.2] }, { embedding: [0.3, 0.4] }],
|
||||
}),
|
||||
{ status: 200, headers: { "Content-Type": "application/json" } },
|
||||
),
|
||||
),
|
||||
);
|
||||
const result = await createDefaultVoyageProvider("voyage-4-large", fetchMock);
|
||||
|
||||
await result.provider.embedBatch(["doc1", "doc2"]);
|
||||
|
||||
const call = fetchMock.mock.calls[0];
|
||||
expect(call).toBeDefined();
|
||||
const [, init] = call as [RequestInfo | URL, RequestInit | undefined];
|
||||
const body = JSON.parse(init?.body as string);
|
||||
expect(body).toEqual({
|
||||
model: "voyage-4-large",
|
||||
input: ["doc1", "doc2"],
|
||||
input_type: "document",
|
||||
});
|
||||
});
|
||||
|
||||
it("normalizes model names", async () => {
|
||||
expect(normalizeVoyageModel("voyage/voyage-large-2")).toBe("voyage-large-2");
|
||||
expect(normalizeVoyageModel("voyage-4-large")).toBe("voyage-4-large");
|
||||
expect(normalizeVoyageModel(" voyage-lite ")).toBe("voyage-lite");
|
||||
expect(normalizeVoyageModel("")).toBe("voyage-4-large"); // Default
|
||||
});
|
||||
});
|
||||
@@ -1,82 +0,0 @@
|
||||
import type { SsrFPolicy } from "../../../../src/infra/net/ssrf.js";
|
||||
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
|
||||
import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js";
|
||||
import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
||||
|
||||
export type VoyageEmbeddingClient = {
|
||||
baseUrl: string;
|
||||
headers: Record<string, string>;
|
||||
ssrfPolicy?: SsrFPolicy;
|
||||
model: string;
|
||||
};
|
||||
|
||||
export const DEFAULT_VOYAGE_EMBEDDING_MODEL = "voyage-4-large";
|
||||
const DEFAULT_VOYAGE_BASE_URL = "https://api.voyageai.com/v1";
|
||||
const VOYAGE_MAX_INPUT_TOKENS: Record<string, number> = {
|
||||
"voyage-3": 32000,
|
||||
"voyage-3-lite": 16000,
|
||||
"voyage-code-3": 32000,
|
||||
};
|
||||
|
||||
export function normalizeVoyageModel(model: string): string {
|
||||
return normalizeEmbeddingModelWithPrefixes({
|
||||
model,
|
||||
defaultModel: DEFAULT_VOYAGE_EMBEDDING_MODEL,
|
||||
prefixes: ["voyage/"],
|
||||
});
|
||||
}
|
||||
|
||||
export async function createVoyageEmbeddingProvider(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<{ provider: EmbeddingProvider; client: VoyageEmbeddingClient }> {
|
||||
const client = await resolveVoyageEmbeddingClient(options);
|
||||
const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`;
|
||||
|
||||
const embed = async (input: string[], input_type?: "query" | "document"): Promise<number[][]> => {
|
||||
if (input.length === 0) {
|
||||
return [];
|
||||
}
|
||||
const body: { model: string; input: string[]; input_type?: "query" | "document" } = {
|
||||
model: client.model,
|
||||
input,
|
||||
};
|
||||
if (input_type) {
|
||||
body.input_type = input_type;
|
||||
}
|
||||
|
||||
return await fetchRemoteEmbeddingVectors({
|
||||
url,
|
||||
headers: client.headers,
|
||||
ssrfPolicy: client.ssrfPolicy,
|
||||
body,
|
||||
errorPrefix: "voyage embeddings failed",
|
||||
});
|
||||
};
|
||||
|
||||
return {
|
||||
provider: {
|
||||
id: "voyage",
|
||||
model: client.model,
|
||||
maxInputTokens: VOYAGE_MAX_INPUT_TOKENS[client.model],
|
||||
embedQuery: async (text) => {
|
||||
const [vec] = await embed([text], "query");
|
||||
return vec ?? [];
|
||||
},
|
||||
embedBatch: async (texts) => embed(texts, "document"),
|
||||
},
|
||||
client,
|
||||
};
|
||||
}
|
||||
|
||||
export async function resolveVoyageEmbeddingClient(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<VoyageEmbeddingClient> {
|
||||
const { baseUrl, headers, ssrfPolicy } = await resolveRemoteEmbeddingBearerClient({
|
||||
provider: "voyage",
|
||||
options,
|
||||
defaultBaseUrl: DEFAULT_VOYAGE_BASE_URL,
|
||||
});
|
||||
const model = normalizeVoyageModel(options.model);
|
||||
return { baseUrl, headers, ssrfPolicy, model };
|
||||
}
|
||||
@@ -1,199 +1,8 @@
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import * as authModule from "../../../../src/agents/model-auth.js";
|
||||
import { createEmbeddingProvider, DEFAULT_LOCAL_MODEL } from "./embeddings.js";
|
||||
import * as nodeLlamaModule from "./node-llama.js";
|
||||
import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { DEFAULT_LOCAL_MODEL } from "./embeddings.js";
|
||||
|
||||
const { resolveApiKeyForProviderMock } = vi.hoisted(() => ({
|
||||
resolveApiKeyForProviderMock: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("../../../../src/agents/model-auth.js", () => {
|
||||
return {
|
||||
resolveApiKeyForProvider: resolveApiKeyForProviderMock,
|
||||
requireApiKey: (auth: { apiKey?: string; mode?: string }, provider: string) => {
|
||||
if (auth.apiKey) {
|
||||
return auth.apiKey;
|
||||
}
|
||||
throw new Error(`No API key resolved for provider "${provider}" (auth mode: ${auth.mode}).`);
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock("../../../../src/infra/net/fetch-guard.js", () => ({
|
||||
fetchWithSsrFGuard: async (params: {
|
||||
url: string;
|
||||
init?: RequestInit;
|
||||
fetchImpl?: typeof fetch;
|
||||
}) => {
|
||||
const fetchImpl = params.fetchImpl ?? globalThis.fetch;
|
||||
if (!fetchImpl) {
|
||||
throw new Error("fetch is not available");
|
||||
}
|
||||
const response = await fetchImpl(params.url, params.init);
|
||||
return {
|
||||
response,
|
||||
finalUrl: params.url,
|
||||
release: async () => {},
|
||||
};
|
||||
},
|
||||
}));
|
||||
|
||||
const createEmbeddingDataFetchMock = () =>
|
||||
vi.fn(async (_input?: unknown, _init?: unknown) => ({
|
||||
ok: true,
|
||||
status: 200,
|
||||
json: async () => ({ data: [{ embedding: [1, 2, 3] }] }),
|
||||
}));
|
||||
|
||||
const createGeminiFetchMock = () =>
|
||||
vi.fn(async (_input?: unknown, _init?: unknown) => ({
|
||||
ok: true,
|
||||
status: 200,
|
||||
json: async () => ({ embedding: { values: [1, 2, 3] } }),
|
||||
}));
|
||||
|
||||
beforeEach(() => {
|
||||
vi.spyOn(authModule, "resolveApiKeyForProvider");
|
||||
vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp");
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.resetAllMocks();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
function installFetchMock(fetchMock: typeof globalThis.fetch) {
|
||||
vi.stubGlobal("fetch", fetchMock);
|
||||
}
|
||||
|
||||
function readFirstFetchRequest(fetchMock: { mock: { calls: unknown[][] } }) {
|
||||
const [url, init] = fetchMock.mock.calls[0] ?? [];
|
||||
return { url, init: init as RequestInit | undefined };
|
||||
}
|
||||
|
||||
function requireProvider(result: Awaited<ReturnType<typeof createEmbeddingProvider>>) {
|
||||
if (!result.provider) {
|
||||
throw new Error("Expected embedding provider");
|
||||
}
|
||||
return result.provider;
|
||||
}
|
||||
|
||||
function mockResolvedProviderKey(apiKey = "provider-key") {
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({
|
||||
apiKey,
|
||||
mode: "api-key",
|
||||
source: "test",
|
||||
});
|
||||
}
|
||||
|
||||
describe("package embedding provider smoke", () => {
|
||||
it("uses remote OpenAI baseUrl/apiKey and merges headers", async () => {
|
||||
const fetchMock = createEmbeddingDataFetchMock();
|
||||
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
|
||||
mockPublicPinnedHostname();
|
||||
mockResolvedProviderKey("provider-key");
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: {
|
||||
models: {
|
||||
providers: {
|
||||
openai: {
|
||||
baseUrl: "https://api.openai.com/v1",
|
||||
headers: { "X-Provider": "p", "X-Shared": "provider" },
|
||||
},
|
||||
},
|
||||
},
|
||||
} as never,
|
||||
provider: "openai",
|
||||
remote: {
|
||||
baseUrl: "https://example.com/v1",
|
||||
apiKey: " remote-key ",
|
||||
headers: { "X-Shared": "remote", "X-Remote": "r" },
|
||||
},
|
||||
model: "text-embedding-3-small",
|
||||
fallback: "openai",
|
||||
});
|
||||
|
||||
await requireProvider(result).embedQuery("hello");
|
||||
|
||||
expect(authModule.resolveApiKeyForProvider).not.toHaveBeenCalled();
|
||||
const { url, init } = readFirstFetchRequest(fetchMock);
|
||||
expect(url).toBe("https://example.com/v1/embeddings");
|
||||
const headers = (init?.headers ?? {}) as Record<string, string>;
|
||||
expect(headers.Authorization).toBe("Bearer remote-key");
|
||||
expect(headers["X-Provider"]).toBe("p");
|
||||
expect(headers["X-Shared"]).toBe("remote");
|
||||
expect(headers["X-Remote"]).toBe("r");
|
||||
});
|
||||
|
||||
it("uses GEMINI_API_KEY env indirection for Gemini remote apiKey", async () => {
|
||||
const fetchMock = createGeminiFetchMock();
|
||||
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
|
||||
mockPublicPinnedHostname();
|
||||
vi.stubEnv("GEMINI_API_KEY", "env-gemini-key");
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "gemini",
|
||||
remote: {
|
||||
apiKey: "GEMINI_API_KEY", // pragma: allowlist secret
|
||||
},
|
||||
model: "text-embedding-004",
|
||||
fallback: "openai",
|
||||
});
|
||||
|
||||
await requireProvider(result).embedQuery("hello");
|
||||
|
||||
const { init } = readFirstFetchRequest(fetchMock);
|
||||
const headers = (init?.headers ?? {}) as Record<string, string>;
|
||||
expect(headers["x-goog-api-key"]).toBe("env-gemini-key");
|
||||
});
|
||||
|
||||
it("normalizes local embeddings and resolves the default local model", async () => {
|
||||
const resolveModelFileMock = vi.fn(async () => "/fake/model.gguf");
|
||||
vi.mocked(nodeLlamaModule.importNodeLlamaCpp).mockResolvedValue({
|
||||
getLlama: async () => ({
|
||||
loadModel: vi.fn().mockResolvedValue({
|
||||
createEmbeddingContext: vi.fn().mockResolvedValue({
|
||||
getEmbeddingFor: vi.fn().mockResolvedValue({
|
||||
vector: new Float32Array([2.35, 3.45, 0.63, 4.3]),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
resolveModelFile: resolveModelFileMock,
|
||||
LlamaLogLevel: { error: 0 },
|
||||
} as never);
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "local",
|
||||
model: "",
|
||||
fallback: "none",
|
||||
});
|
||||
|
||||
const embedding = await requireProvider(result).embedQuery("test query");
|
||||
const magnitude = Math.sqrt(embedding.reduce((sum, value) => sum + value * value, 0));
|
||||
expect(magnitude).toBeCloseTo(1, 5);
|
||||
expect(resolveModelFileMock).toHaveBeenCalledWith(DEFAULT_LOCAL_MODEL, undefined);
|
||||
});
|
||||
|
||||
it("returns null provider when explicit primary and fallback auth paths fail", async () => {
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockRejectedValue(
|
||||
new Error("No API key found for provider"),
|
||||
);
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "openai",
|
||||
model: "text-embedding-3-small",
|
||||
fallback: "gemini",
|
||||
});
|
||||
|
||||
expect(result.provider).toBeNull();
|
||||
expect(result.requestedProvider).toBe("openai");
|
||||
expect(result.fallbackFrom).toBe("openai");
|
||||
expect(result.providerUnavailableReason).toContain("Fallback to gemini failed");
|
||||
describe("package embeddings barrel", () => {
|
||||
it("re-exports the source local embedding contract", () => {
|
||||
expect(DEFAULT_LOCAL_MODEL).toContain("embeddinggemma");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,373 +1 @@
|
||||
import fsSync from "node:fs";
|
||||
import type { Llama, LlamaEmbeddingContext, LlamaModel } from "node-llama-cpp";
|
||||
import type { OpenClawConfig } from "../../../../src/config/config.js";
|
||||
import type { SecretInput } from "../../../../src/config/types.secrets.js";
|
||||
import { formatErrorMessage } from "../../../../src/infra/errors.js";
|
||||
import { resolveUserPath } from "../../../../src/utils.js";
|
||||
import type { EmbeddingInput } from "./embedding-inputs.js";
|
||||
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
|
||||
import {
|
||||
createBedrockEmbeddingProvider,
|
||||
hasAwsCredentials,
|
||||
type BedrockEmbeddingClient,
|
||||
} from "./embeddings-bedrock.js";
|
||||
import {
|
||||
createGeminiEmbeddingProvider,
|
||||
type GeminiEmbeddingClient,
|
||||
type GeminiTaskType,
|
||||
} from "./embeddings-gemini.js";
|
||||
import {
|
||||
createLmstudioEmbeddingProvider,
|
||||
type LmstudioEmbeddingClient,
|
||||
} from "./embeddings-lmstudio.js";
|
||||
import {
|
||||
createMistralEmbeddingProvider,
|
||||
type MistralEmbeddingClient,
|
||||
} from "./embeddings-mistral.js";
|
||||
import { createOllamaEmbeddingProvider, type OllamaEmbeddingClient } from "./embeddings-ollama.js";
|
||||
import { createOpenAiEmbeddingProvider, type OpenAiEmbeddingClient } from "./embeddings-openai.js";
|
||||
import { createVoyageEmbeddingProvider, type VoyageEmbeddingClient } from "./embeddings-voyage.js";
|
||||
import { importNodeLlamaCpp } from "./node-llama.js";
|
||||
|
||||
export type { GeminiEmbeddingClient } from "./embeddings-gemini.js";
|
||||
export type { LmstudioEmbeddingClient } from "./embeddings-lmstudio.js";
|
||||
export type { MistralEmbeddingClient } from "./embeddings-mistral.js";
|
||||
export type { OpenAiEmbeddingClient } from "./embeddings-openai.js";
|
||||
export type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
|
||||
export type { OllamaEmbeddingClient } from "./embeddings-ollama.js";
|
||||
export type { BedrockEmbeddingClient } from "./embeddings-bedrock.js";
|
||||
|
||||
export type EmbeddingProvider = {
|
||||
id: string;
|
||||
model: string;
|
||||
maxInputTokens?: number;
|
||||
embedQuery: (text: string) => Promise<number[]>;
|
||||
embedBatch: (texts: string[]) => Promise<number[][]>;
|
||||
embedBatchInputs?: (inputs: EmbeddingInput[]) => Promise<number[][]>;
|
||||
};
|
||||
|
||||
export type EmbeddingProviderId =
|
||||
| "openai"
|
||||
| "local"
|
||||
| "gemini"
|
||||
| "voyage"
|
||||
| "mistral"
|
||||
| "bedrock"
|
||||
| "lmstudio"
|
||||
| "ollama";
|
||||
export type EmbeddingProviderRequest = EmbeddingProviderId | "auto";
|
||||
export type EmbeddingProviderFallback = EmbeddingProviderId | "none";
|
||||
|
||||
// Remote providers considered for auto-selection when provider === "auto".
|
||||
// LM Studio and Ollama are intentionally excluded here so that "auto" mode does not
|
||||
// implicitly assume either instance is available.
|
||||
// Bedrock is handled separately when AWS credentials are detected.
|
||||
const REMOTE_EMBEDDING_PROVIDER_IDS = ["openai", "gemini", "voyage", "mistral"] as const;
|
||||
|
||||
export type EmbeddingProviderResult = {
|
||||
provider: EmbeddingProvider | null;
|
||||
requestedProvider: EmbeddingProviderRequest;
|
||||
fallbackFrom?: EmbeddingProviderId;
|
||||
fallbackReason?: string;
|
||||
providerUnavailableReason?: string;
|
||||
openAi?: OpenAiEmbeddingClient;
|
||||
gemini?: GeminiEmbeddingClient;
|
||||
voyage?: VoyageEmbeddingClient;
|
||||
mistral?: MistralEmbeddingClient;
|
||||
bedrock?: BedrockEmbeddingClient;
|
||||
lmstudio?: LmstudioEmbeddingClient;
|
||||
ollama?: OllamaEmbeddingClient;
|
||||
};
|
||||
|
||||
export type EmbeddingProviderOptions = {
|
||||
config: OpenClawConfig;
|
||||
agentDir?: string;
|
||||
provider: EmbeddingProviderRequest;
|
||||
remote?: {
|
||||
baseUrl?: string;
|
||||
apiKey?: SecretInput;
|
||||
headers?: Record<string, string>;
|
||||
};
|
||||
model: string;
|
||||
fallback: EmbeddingProviderFallback;
|
||||
local?: {
|
||||
modelPath?: string;
|
||||
modelCacheDir?: string;
|
||||
};
|
||||
/** Provider-specific output vector dimensions for supported embedding families. */
|
||||
outputDimensionality?: number;
|
||||
/** Gemini: override the default task type sent with embedding requests. */
|
||||
taskType?: GeminiTaskType;
|
||||
};
|
||||
|
||||
export const DEFAULT_LOCAL_MODEL =
|
||||
"hf:ggml-org/embeddinggemma-300m-qat-q8_0-GGUF/embeddinggemma-300m-qat-Q8_0.gguf";
|
||||
|
||||
function canAutoSelectLocal(options: EmbeddingProviderOptions): boolean {
|
||||
const modelPath = options.local?.modelPath?.trim();
|
||||
if (!modelPath) {
|
||||
return false;
|
||||
}
|
||||
if (/^(hf:|https?:)/i.test(modelPath)) {
|
||||
return false;
|
||||
}
|
||||
const resolved = resolveUserPath(modelPath);
|
||||
try {
|
||||
return fsSync.statSync(resolved).isFile();
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
function isMissingApiKeyError(err: unknown): boolean {
|
||||
const message = formatErrorMessage(err);
|
||||
return message.includes("No API key found for provider");
|
||||
}
|
||||
|
||||
export async function createLocalEmbeddingProvider(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<EmbeddingProvider> {
|
||||
const modelPath = options.local?.modelPath?.trim() || DEFAULT_LOCAL_MODEL;
|
||||
const modelCacheDir = options.local?.modelCacheDir?.trim();
|
||||
|
||||
// Lazy-load node-llama-cpp to keep startup light unless local is enabled.
|
||||
const { getLlama, resolveModelFile, LlamaLogLevel } = await importNodeLlamaCpp();
|
||||
|
||||
let llama: Llama | null = null;
|
||||
let embeddingModel: LlamaModel | null = null;
|
||||
let embeddingContext: LlamaEmbeddingContext | null = null;
|
||||
let initPromise: Promise<LlamaEmbeddingContext> | null = null;
|
||||
|
||||
const ensureContext = async (): Promise<LlamaEmbeddingContext> => {
|
||||
if (embeddingContext) {
|
||||
return embeddingContext;
|
||||
}
|
||||
if (initPromise) {
|
||||
return initPromise;
|
||||
}
|
||||
initPromise = (async () => {
|
||||
try {
|
||||
if (!llama) {
|
||||
llama = await getLlama({ logLevel: LlamaLogLevel.error });
|
||||
}
|
||||
if (!embeddingModel) {
|
||||
const resolved = await resolveModelFile(modelPath, modelCacheDir || undefined);
|
||||
embeddingModel = await llama.loadModel({ modelPath: resolved });
|
||||
}
|
||||
if (!embeddingContext) {
|
||||
embeddingContext = await embeddingModel.createEmbeddingContext();
|
||||
}
|
||||
return embeddingContext;
|
||||
} catch (err) {
|
||||
initPromise = null;
|
||||
throw err;
|
||||
}
|
||||
})();
|
||||
return initPromise;
|
||||
};
|
||||
|
||||
return {
|
||||
id: "local",
|
||||
model: modelPath,
|
||||
embedQuery: async (text) => {
|
||||
const ctx = await ensureContext();
|
||||
const embedding = await ctx.getEmbeddingFor(text);
|
||||
return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
|
||||
},
|
||||
embedBatch: async (texts) => {
|
||||
const ctx = await ensureContext();
|
||||
const embeddings = await Promise.all(
|
||||
texts.map(async (text) => {
|
||||
const embedding = await ctx.getEmbeddingFor(text);
|
||||
return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
|
||||
}),
|
||||
);
|
||||
return embeddings;
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
export async function createEmbeddingProvider(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<EmbeddingProviderResult> {
|
||||
const requestedProvider = options.provider;
|
||||
const fallback = options.fallback;
|
||||
|
||||
const createProvider = async (id: EmbeddingProviderId) => {
|
||||
if (id === "local") {
|
||||
const provider = await createLocalEmbeddingProvider(options);
|
||||
return { provider };
|
||||
}
|
||||
if (id === "lmstudio") {
|
||||
const { provider, client } = await createLmstudioEmbeddingProvider(options);
|
||||
return { provider, lmstudio: client };
|
||||
}
|
||||
if (id === "ollama") {
|
||||
const { provider, client } = await createOllamaEmbeddingProvider(options);
|
||||
return { provider, ollama: client };
|
||||
}
|
||||
if (id === "gemini") {
|
||||
const { provider, client } = await createGeminiEmbeddingProvider(options);
|
||||
return { provider, gemini: client };
|
||||
}
|
||||
if (id === "voyage") {
|
||||
const { provider, client } = await createVoyageEmbeddingProvider(options);
|
||||
return { provider, voyage: client };
|
||||
}
|
||||
if (id === "mistral") {
|
||||
const { provider, client } = await createMistralEmbeddingProvider(options);
|
||||
return { provider, mistral: client };
|
||||
}
|
||||
if (id === "bedrock") {
|
||||
const { provider, client } = await createBedrockEmbeddingProvider(options);
|
||||
return { provider, bedrock: client };
|
||||
}
|
||||
const { provider, client } = await createOpenAiEmbeddingProvider(options);
|
||||
return { provider, openAi: client };
|
||||
};
|
||||
|
||||
const formatPrimaryError = (err: unknown, provider: EmbeddingProviderId) =>
|
||||
provider === "local" ? formatLocalSetupError(err) : formatErrorMessage(err);
|
||||
|
||||
if (requestedProvider === "auto") {
|
||||
const missingKeyErrors: string[] = [];
|
||||
let localError: string | null = null;
|
||||
|
||||
if (canAutoSelectLocal(options)) {
|
||||
try {
|
||||
const local = await createProvider("local");
|
||||
return { ...local, requestedProvider };
|
||||
} catch (err) {
|
||||
localError = formatLocalSetupError(err);
|
||||
}
|
||||
}
|
||||
|
||||
for (const provider of REMOTE_EMBEDDING_PROVIDER_IDS) {
|
||||
try {
|
||||
const result = await createProvider(provider);
|
||||
return { ...result, requestedProvider };
|
||||
} catch (err) {
|
||||
const message = formatPrimaryError(err, provider);
|
||||
if (isMissingApiKeyError(err)) {
|
||||
missingKeyErrors.push(message);
|
||||
continue;
|
||||
}
|
||||
// Non-auth errors (e.g., network) are still fatal
|
||||
const wrapped = new Error(message) as Error & { cause?: unknown };
|
||||
wrapped.cause = err;
|
||||
throw wrapped;
|
||||
}
|
||||
}
|
||||
|
||||
// Try bedrock if AWS credentials are available
|
||||
if (await hasAwsCredentials()) {
|
||||
try {
|
||||
const result = await createProvider("bedrock");
|
||||
return { ...result, requestedProvider };
|
||||
} catch (err) {
|
||||
const message = formatPrimaryError(err, "bedrock");
|
||||
if (isMissingApiKeyError(err)) {
|
||||
missingKeyErrors.push(message);
|
||||
} else {
|
||||
const wrapped = new Error(message) as Error & { cause?: unknown };
|
||||
wrapped.cause = err;
|
||||
throw wrapped;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// All providers failed due to missing API keys - return null provider for FTS-only mode
|
||||
const details = [...missingKeyErrors, localError].filter(Boolean) as string[];
|
||||
const reason = details.length > 0 ? details.join("\n\n") : "No embeddings provider available.";
|
||||
return {
|
||||
provider: null,
|
||||
requestedProvider,
|
||||
providerUnavailableReason: reason,
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const primary = await createProvider(requestedProvider);
|
||||
return { ...primary, requestedProvider };
|
||||
} catch (primaryErr) {
|
||||
const reason = formatPrimaryError(primaryErr, requestedProvider);
|
||||
if (fallback && fallback !== "none" && fallback !== requestedProvider) {
|
||||
try {
|
||||
const fallbackResult = await createProvider(fallback);
|
||||
return {
|
||||
...fallbackResult,
|
||||
requestedProvider,
|
||||
fallbackFrom: requestedProvider,
|
||||
fallbackReason: reason,
|
||||
};
|
||||
} catch (fallbackErr) {
|
||||
// Both primary and fallback failed - check if it's auth-related
|
||||
const fallbackReason = formatErrorMessage(fallbackErr);
|
||||
const combinedReason = `${reason}\n\nFallback to ${fallback} failed: ${fallbackReason}`;
|
||||
if (isMissingApiKeyError(primaryErr) && isMissingApiKeyError(fallbackErr)) {
|
||||
// Both failed due to missing API keys - return null for FTS-only mode
|
||||
return {
|
||||
provider: null,
|
||||
requestedProvider,
|
||||
fallbackFrom: requestedProvider,
|
||||
fallbackReason: reason,
|
||||
providerUnavailableReason: combinedReason,
|
||||
};
|
||||
}
|
||||
// Non-auth errors are still fatal
|
||||
const wrapped = new Error(combinedReason) as Error & {
|
||||
cause?: unknown;
|
||||
};
|
||||
wrapped.cause = fallbackErr;
|
||||
throw wrapped;
|
||||
}
|
||||
}
|
||||
// No fallback configured - check if we should degrade to FTS-only
|
||||
if (isMissingApiKeyError(primaryErr)) {
|
||||
return {
|
||||
provider: null,
|
||||
requestedProvider,
|
||||
providerUnavailableReason: reason,
|
||||
};
|
||||
}
|
||||
const wrapped = new Error(reason) as Error & { cause?: unknown };
|
||||
wrapped.cause = primaryErr;
|
||||
throw wrapped;
|
||||
}
|
||||
}
|
||||
|
||||
function isNodeLlamaCppMissing(err: unknown): boolean {
|
||||
if (!(err instanceof Error)) {
|
||||
return false;
|
||||
}
|
||||
const code = (err as Error & { code?: unknown }).code;
|
||||
if (code === "ERR_MODULE_NOT_FOUND") {
|
||||
return err.message.includes("node-llama-cpp");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
function formatLocalSetupError(err: unknown): string {
|
||||
const detail = formatErrorMessage(err);
|
||||
const missing = isNodeLlamaCppMissing(err);
|
||||
return [
|
||||
"Local embeddings unavailable.",
|
||||
missing
|
||||
? "Reason: optional dependency node-llama-cpp is missing (or failed to install)."
|
||||
: detail
|
||||
? `Reason: ${detail}`
|
||||
: undefined,
|
||||
missing && detail ? `Detail: ${detail}` : null,
|
||||
"To enable local embeddings:",
|
||||
"1) Use Node 24 (recommended for installs/updates; Node 22 LTS, currently 22.14+, remains supported)",
|
||||
missing
|
||||
? "2) Reinstall OpenClaw (this should install node-llama-cpp): npm i -g openclaw@latest"
|
||||
: null,
|
||||
"3) If you use pnpm: pnpm approve-builds (select node-llama-cpp), then pnpm rebuild node-llama-cpp",
|
||||
...REMOTE_EMBEDDING_PROVIDER_IDS.map(
|
||||
(provider) => `Or set agents.defaults.memorySearch.provider = "${provider}" (remote).`,
|
||||
),
|
||||
]
|
||||
.filter(Boolean)
|
||||
.join("\n");
|
||||
}
|
||||
export * from "../../../../src/memory-host-sdk/host/embeddings.js";
|
||||
|
||||
@@ -100,21 +100,3 @@ export function classifyMemoryMultimodalPath(
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
export function normalizeGeminiEmbeddingModelForMemory(model: string): string {
|
||||
const trimmed = model.trim();
|
||||
if (!trimmed) {
|
||||
return "";
|
||||
}
|
||||
return trimmed.replace(/^models\//, "").replace(/^(gemini|google)\//, "");
|
||||
}
|
||||
|
||||
export function supportsMemoryMultimodalEmbeddings(params: {
|
||||
provider: string;
|
||||
model: string;
|
||||
}): boolean {
|
||||
if (params.provider !== "gemini") {
|
||||
return false;
|
||||
}
|
||||
return normalizeGeminiEmbeddingModelForMemory(params.model) === "gemini-embedding-2-preview";
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user