diff --git a/src/plugins/contracts/registry.retry.test.ts b/src/plugins/contracts/registry.retry.test.ts index d6cd3dd000f0..298be48bbb26 100644 --- a/src/plugins/contracts/registry.retry.test.ts +++ b/src/plugins/contracts/registry.retry.test.ts @@ -35,6 +35,30 @@ function createMockRuntimeRegistry(params: { }; } +function createMockContractSnapshot(pluginId: string, providerIds: string[]) { + return { + pluginId, + cliBackendIds: [], + providerIds, + providerEnvVars: {}, + embeddingProviderIds: [], + speechProviderIds: [], + realtimeTranscriptionProviderIds: [], + realtimeVoiceProviderIds: [], + mediaUnderstandingProviderIds: [], + transcriptSourceProviderIds: [], + documentExtractorIds: [], + imageGenerationProviderIds: [], + videoGenerationProviderIds: [], + musicGenerationProviderIds: [], + webContentExtractorIds: [], + webFetchProviderIds: [], + webSearchProviderIds: [], + migrationProviderIds: [], + toolNames: [], + }; +} + afterEach(() => { vi.resetModules(); vi.restoreAllMocks(); @@ -259,6 +283,68 @@ describe("plugin contract registry scoped retries", () => { expect(loadBundledCapabilityRuntimeRegistry).not.toHaveBeenCalled(); }); + it("ignores poisoned provider metadata while resolving provider aliases", async () => { + const poisonedProvider = Object.defineProperties( + { + label: "Poisoned", + docsPath: "/providers/poisoned", + auth: [], + }, + { + id: { + enumerable: true, + get() { + throw new Error("provider contract alias metadata exploded"); + }, + }, + aliases: { + enumerable: true, + get() { + throw new Error("provider contract aliases exploded"); + }, + }, + hookAliases: { + enumerable: true, + get() { + throw new Error("provider contract hook aliases exploded"); + }, + }, + }, + ) as ProviderPlugin; + const healthyProvider = { + id: "healthy-provider", + label: "Healthy", + docsPath: "/providers/healthy", + aliases: ["healthy-alias"], + auth: [], + } as ProviderPlugin; + const resolveBundledExplicitProviderContractsFromPublicArtifacts = vi.fn( + ({ onlyPluginIds }: { onlyPluginIds: readonly string[] }) => + onlyPluginIds[0] === "poisoned" + ? [{ pluginId: "poisoned", provider: poisonedProvider }] + : [{ pluginId: "healthy", provider: healthyProvider }], + ); + + vi.doMock("./inventory/bundled-capability-metadata.js", () => ({ + BUNDLED_PLUGIN_CONTRACT_SNAPSHOTS: [ + createMockContractSnapshot("poisoned", ["poisoned"]), + createMockContractSnapshot("healthy", ["healthy-provider"]), + ], + })); + vi.doMock("../bundled-capability-runtime.js", () => ({ + loadBundledCapabilityRuntimeRegistry: vi.fn(() => { + throw new Error("provider public artifacts should be enough"); + }), + })); + vi.doMock("../provider-contract-public-artifacts.js", () => ({ + resolveBundledExplicitProviderContractsFromPublicArtifacts, + })); + + const { resolveProviderContractPluginIdsForProviderAlias } = await import("./registry.js"); + + expect(resolveProviderContractPluginIdsForProviderAlias("healthy-alias")).toEqual(["healthy"]); + }); + it("uses web search public artifacts before falling back to the bundled runtime registry", async () => { const loadBundledCapabilityRuntimeRegistry = vi.fn(() => { throw new Error( diff --git a/src/plugins/contracts/registry.ts b/src/plugins/contracts/registry.ts index f8f0a9626f55..3ab5b0eddeb2 100644 --- a/src/plugins/contracts/registry.ts +++ b/src/plugins/contracts/registry.ts @@ -391,11 +391,7 @@ function loadProviderContractRegistry(): ProviderContractEntry[] { } function loadUniqueProviderContractProviders(): ProviderPlugin[] { - return [ - ...new Map( - loadProviderContractRegistry().map((entry) => [entry.provider.id, entry.provider]), - ).values(), - ]; + return uniqueProviderContractProvidersFromEntries(loadProviderContractRegistry()); } function loadProviderContractPluginIds(): string[] { @@ -666,14 +662,49 @@ export const providerContractCompatPluginIds: string[] = createLazyArrayView( loadProviderContractCompatPluginIds, ); +function readProviderContractId(provider: ProviderPlugin): string | undefined { + try { + return typeof provider.id === "string" ? provider.id : undefined; + } catch { + return undefined; + } +} + +function readProviderContractAliases( + provider: ProviderPlugin, + key: "aliases" | "hookAliases", +): string[] { + try { + const aliases = provider[key]; + return Array.isArray(aliases) + ? aliases.filter((alias): alias is string => typeof alias === "string") + : []; + } catch { + return []; + } +} + +function uniqueProviderContractProvidersFromEntries( + entries: readonly ProviderContractEntry[], +): ProviderPlugin[] { + const providers = new Map(); + for (const entry of entries) { + const providerId = readProviderContractId(entry.provider); + if (providerId) { + providers.set(providerId, entry.provider); + } + } + return [...providers.values()]; +} + export function requireProviderContractProvider(providerId: string): ProviderPlugin { const pluginIds = resolveBundledProviderContractPluginIdsByProviderId().get(providerId) ?? []; const entries = loadProviderContractEntriesForPluginIds(pluginIds); - const provider = entries.find((entry) => entry.provider.id === providerId)?.provider; + const provider = entries.find( + (entry) => readProviderContractId(entry.provider) === providerId, + )?.provider; if (!provider) { - const pluginScopedProviders = [ - ...new Map(entries.map((entry) => [entry.provider.id, entry.provider])).values(), - ]; + const pluginScopedProviders = uniqueProviderContractProvidersFromEntries(entries); if (pluginIds.length === 1 && pluginScopedProviders.length === 1) { return pluginScopedProviders[0]; } @@ -705,10 +736,10 @@ export function resolveProviderContractPluginIdsForProviderAlias( loadProviderContractEntriesForPluginIds(resolveBundledProviderContractPluginIds()) .filter((entry) => { const providerIds = [ - entry.provider.id, - ...(entry.provider.aliases ?? []), - ...(entry.provider.hookAliases ?? []), - ]; + readProviderContractId(entry.provider), + ...readProviderContractAliases(entry.provider, "aliases"), + ...readProviderContractAliases(entry.provider, "hookAliases"), + ].filter((candidate): candidate is string => typeof candidate === "string"); return providerIds.some( (candidate) => normalizeProviderId(candidate) === normalizedProvider, ); @@ -722,13 +753,11 @@ export function resolveProviderContractProvidersForPluginIds( pluginIds: readonly string[], ): ProviderPlugin[] { const allowed = new Set(pluginIds); - return [ - ...new Map( - loadProviderContractEntriesForPluginIds([...allowed]) - .filter((entry) => allowed.has(entry.pluginId)) - .map((entry) => [entry.provider.id, entry.provider]), - ).values(), - ]; + return uniqueProviderContractProvidersFromEntries( + loadProviderContractEntriesForPluginIds([...allowed]).filter((entry) => + allowed.has(entry.pluginId), + ), + ); } export const webSearchProviderContractRegistry: WebSearchProviderContractEntry[] =