diff --git a/src/plugins/contracts/trusted-tool-policy-registration.contract.test.ts b/src/plugins/contracts/trusted-tool-policy-registration.contract.test.ts new file mode 100644 index 000000000000..1bfd4e43f25c --- /dev/null +++ b/src/plugins/contracts/trusted-tool-policy-registration.contract.test.ts @@ -0,0 +1,82 @@ +// Trusted tool policy registration tests cover plugin-owned evaluator snapshotting. +import { + createPluginRegistryFixture, + registerTestPlugin, +} from "openclaw/plugin-sdk/plugin-test-contracts"; +import { afterEach, describe, expect, it } from "vitest"; +import type { PluginTrustedToolPolicyRegistration } from "../host-hooks.js"; +import { createEmptyPluginRegistry } from "../registry-empty.js"; +import { setActivePluginRegistry } from "../runtime.js"; +import { createPluginRecord } from "../status.test-helpers.js"; +import { runTrustedToolPolicies } from "../trusted-tool-policy.js"; + +describe("plugin trusted tool policy registration", () => { + afterEach(() => { + setActivePluginRegistry(createEmptyPluginRegistry()); + }); + + it("snapshots policy evaluators before trusted tool policy execution", async () => { + let idReads = 0; + let descriptionReads = 0; + let evaluateReads = 0; + const evaluatedTools: string[] = []; + const { config, registry } = createPluginRegistryFixture(); + registerTestPlugin({ + registry, + config, + record: createPluginRecord({ + id: "volatile-trusted-policy", + name: "Volatile Trusted Policy", + origin: "bundled", + }), + register(api) { + api.registerTrustedToolPolicy({ + get id() { + idReads += 1; + if (idReads > 1) { + throw new Error("policy id getter re-read"); + } + return "policy"; + }, + get description() { + descriptionReads += 1; + if (descriptionReads > 1) { + throw new Error("policy description getter re-read"); + } + return "Policy"; + }, + get evaluate() { + evaluateReads += 1; + if (evaluateReads > 1) { + throw new Error("policy evaluate getter re-read"); + } + return (event) => { + evaluatedTools.push(event.toolName); + return { block: true, blockReason: "blocked by stored policy" }; + }; + }, + } as PluginTrustedToolPolicyRegistration); + }, + }); + setActivePluginRegistry(registry.registry); + + expect(registry.registry.trustedToolPolicies?.[0]?.policy.description).toBe("Policy"); + expect(idReads).toBe(1); + expect(descriptionReads).toBe(1); + expect(evaluateReads).toBe(1); + + await expect( + runTrustedToolPolicies( + { toolName: "dangerous_tool", params: {} }, + { toolName: "dangerous_tool" }, + ), + ).resolves.toEqual({ + block: true, + blockReason: "blocked by stored policy", + }); + expect(evaluatedTools).toEqual(["dangerous_tool"]); + expect(idReads).toBe(1); + expect(descriptionReads).toBe(1); + expect(evaluateReads).toBe(1); + }); +}); diff --git a/src/plugins/registry.ts b/src/plugins/registry.ts index 21fc5ce98f03..c81ebce60053 100644 --- a/src/plugins/registry.ts +++ b/src/plugins/registry.ts @@ -2190,6 +2190,38 @@ export function createPluginRegistry(registryParams: PluginRegistryParams) { }); }; + const readTrustedToolPolicyFields = ( + record: PluginRecord, + policy: PluginTrustedToolPolicyRegistration, + ): + | { + id: unknown; + description: unknown; + evaluate: unknown; + } + | undefined => { + let id: unknown; + try { + id = policy.id; + return { + id, + description: policy.description, + evaluate: policy.evaluate, + }; + } catch (error) { + const normalizedId = normalizeOptionalHostHookString(id); + pushDiagnostic({ + level: "error", + pluginId: record.id, + source: record.source, + message: + `trusted tool policy registration has unreadable fields` + + `${normalizedId ? `: ${normalizedId}` : ""}: ${formatErrorMessage(error)}`, + }); + return undefined; + } + }; + const registerTrustedToolPolicy = ( record: PluginRecord, policy: PluginTrustedToolPolicyRegistration, @@ -2203,9 +2235,14 @@ export function createPluginRegistry(registryParams: PluginRegistryParams) { }); return; } - const id = normalizeHostHookString(policy.id); - const description = normalizeHostHookString(policy.description); - if (!id || !description || typeof policy.evaluate !== "function") { + const fields = readTrustedToolPolicyFields(record, policy); + if (!fields) { + return; + } + const id = normalizeHostHookString(fields.id); + const description = normalizeHostHookString(fields.description); + const evaluate = fields.evaluate; + if (!id || !description || typeof evaluate !== "function") { pushDiagnostic({ level: "error", pluginId: record.id, @@ -2228,9 +2265,9 @@ export function createPluginRegistry(registryParams: PluginRegistryParams) { pluginId: record.id, pluginName: record.name, policy: { - ...policy, id, description, + evaluate: evaluate as PluginTrustedToolPolicyRegistration["evaluate"], }, source: record.source, rootDir: record.rootDir,