mirror of
https://github.com/openclaw/openclaw.git
synced 2026-06-06 05:51:15 +08:00
feat: promote provider tool call stream wrapper (#86489)
This commit is contained in:
committed by
GitHub
parent
6eb46ceac8
commit
5d018034f6
@@ -368,7 +368,7 @@ API key auth, and dynamic model resolution.
|
||||
Each family builder is composed from lower-level public helpers exported from the same package, which you can reach for when a provider needs to go off the common pattern:
|
||||
|
||||
- `openclaw/plugin-sdk/provider-model-shared` - `ProviderReplayFamily`, `buildProviderReplayFamilyHooks(...)`, and the raw replay builders (`buildOpenAICompatibleReplayPolicy`, `buildAnthropicReplayPolicyForModel`, `buildGoogleGeminiReplayPolicy`, `buildHybridAnthropicOrOpenAIReplayPolicy`). Also exports Gemini replay helpers (`sanitizeGoogleGeminiReplayHistory`, `resolveTaggedReasoningOutputMode`) and endpoint/model helpers (`resolveProviderEndpoint`, `normalizeProviderId`, `normalizeGooglePreviewModelId`).
|
||||
- `openclaw/plugin-sdk/provider-stream` - `ProviderStreamFamily`, `buildProviderStreamFamilyHooks(...)`, `composeProviderStreamWrappers(...)`, plus the shared OpenAI/Codex wrappers (`createOpenAIAttributionHeadersWrapper`, `createOpenAIFastModeWrapper`, `createOpenAIServiceTierWrapper`, `createOpenAIResponsesContextManagementWrapper`, `createCodexNativeWebSearchWrapper`), DeepSeek V4 OpenAI-compatible wrapper (`createDeepSeekV4OpenAICompatibleThinkingWrapper`), Anthropic Messages thinking prefill cleanup (`createAnthropicThinkingPrefillPayloadWrapper`), and shared proxy/provider wrappers (`createOpenRouterWrapper`, `createToolStreamWrapper`, `createMinimaxFastModeWrapper`).
|
||||
- `openclaw/plugin-sdk/provider-stream` - `ProviderStreamFamily`, `buildProviderStreamFamilyHooks(...)`, `composeProviderStreamWrappers(...)`, plus the shared OpenAI/Codex wrappers (`createOpenAIAttributionHeadersWrapper`, `createOpenAIFastModeWrapper`, `createOpenAIServiceTierWrapper`, `createOpenAIResponsesContextManagementWrapper`, `createCodexNativeWebSearchWrapper`), DeepSeek V4 OpenAI-compatible wrapper (`createDeepSeekV4OpenAICompatibleThinkingWrapper`), Anthropic Messages thinking prefill cleanup (`createAnthropicThinkingPrefillPayloadWrapper`), plain-text tool-call promotion (`createPlainTextToolCallPromotionWrapper`), and shared proxy/provider wrappers (`createOpenRouterWrapper`, `createToolStreamWrapper`, `createMinimaxFastModeWrapper`).
|
||||
- `openclaw/plugin-sdk/provider-tools` - `ProviderToolCompatFamily`, `buildProviderToolCompatFamilyHooks("deepseek" | "gemini" | "openai")`, and underlying provider schema helpers.
|
||||
|
||||
Some stream helpers stay provider-local on purpose. `@openclaw/anthropic-provider` keeps `wrapAnthropicProviderStream`, `resolveAnthropicBetas`, `resolveAnthropicFastMode`, `resolveAnthropicServiceTier`, and the lower-level Anthropic wrapper builders in its own public `api.ts` / `contract-api.ts` seam because they encode Claude OAuth beta handling and `context1m` gating. The xAI plugin similarly keeps native xAI Responses shaping in its own `wrapStreamFn` (`/fast` aliases, default `tool_stream`, unsupported strict-tool cleanup, xAI-specific reasoning-payload removal).
|
||||
|
||||
@@ -179,7 +179,8 @@ focused channel/runtime subpaths, `config-contracts`, `string-coerce-runtime`,
|
||||
| `plugin-sdk/embedding-providers` | General embedding provider types and read helpers, including `EmbeddingProviderAdapter`, `getEmbeddingProvider(...)`, and `listEmbeddingProviders(...)`; plugins register providers through `api.registerEmbeddingProvider(...)` so manifest ownership is enforced |
|
||||
| `plugin-sdk/provider-tools` | `ProviderToolCompatFamily`, `buildProviderToolCompatFamilyHooks`, and DeepSeek/Gemini/OpenAI schema cleanup + diagnostics |
|
||||
| `plugin-sdk/provider-usage` | `fetchClaudeUsage` and similar |
|
||||
| `plugin-sdk/provider-stream` | `ProviderStreamFamily`, `buildProviderStreamFamilyHooks`, `composeProviderStreamWrappers`, stream wrapper types, and shared Anthropic/Bedrock/DeepSeek V4/Google/Kilocode/Moonshot/OpenAI/OpenRouter/Z.A.I/MiniMax/Copilot wrapper helpers |
|
||||
| `plugin-sdk/provider-stream` | `ProviderStreamFamily`, `buildProviderStreamFamilyHooks`, `composeProviderStreamWrappers`, stream wrapper types, plain-text tool-call promotion, and shared Anthropic/Bedrock/DeepSeek V4/Google/Kilocode/Moonshot/OpenAI/OpenRouter/Z.A.I/MiniMax/Copilot wrapper helpers |
|
||||
| `plugin-sdk/provider-stream-shared` | Public shared provider stream wrapper helpers including `composeProviderStreamWrappers`, `createPlainTextToolCallPromotionWrapper`, `createPayloadPatchStreamWrapper`, `createToolStreamWrapper`, and Anthropic/DeepSeek/OpenAI-compatible stream utilities |
|
||||
| `plugin-sdk/provider-transport-runtime` | Native provider transport helpers such as guarded fetch, transport message transforms, and writable transport event streams |
|
||||
| `plugin-sdk/provider-onboard` | Onboarding config patch helpers |
|
||||
| `plugin-sdk/global-singleton` | Process-local singleton/map/cache helpers |
|
||||
|
||||
@@ -2,7 +2,7 @@ import type { StreamFn } from "@earendil-works/pi-agent-core";
|
||||
import { streamSimple } from "@earendil-works/pi-ai";
|
||||
import { createSubsystemLogger } from "openclaw/plugin-sdk/logging-core";
|
||||
import type { ProviderWrapStreamFnContext } from "openclaw/plugin-sdk/plugin-entry";
|
||||
import { createPlainTextToolCallPromotionWrapper } from "openclaw/plugin-sdk/provider-stream-runtime-internal";
|
||||
import { createPlainTextToolCallPromotionWrapper } from "openclaw/plugin-sdk/provider-stream-shared";
|
||||
import { ssrfPolicyFromHttpBaseUrlAllowedHostname } from "openclaw/plugin-sdk/ssrf-runtime";
|
||||
import { LMSTUDIO_PROVIDER_ID } from "./defaults.js";
|
||||
import { ensureLmstudioModelLoaded } from "./models.fetch.js";
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import type { StreamFn } from "@earendil-works/pi-agent-core";
|
||||
import { streamSimple } from "@earendil-works/pi-ai";
|
||||
import type { ProviderWrapStreamFnContext } from "openclaw/plugin-sdk/plugin-entry";
|
||||
import { createPlainTextToolCallPromotionWrapper } from "openclaw/plugin-sdk/provider-stream-runtime-internal";
|
||||
import {
|
||||
composeProviderStreamWrappers,
|
||||
createPlainTextToolCallPromotionWrapper,
|
||||
createToolStreamWrapper,
|
||||
} from "openclaw/plugin-sdk/provider-stream-shared";
|
||||
|
||||
|
||||
@@ -1,417 +1,5 @@
|
||||
import { randomUUID } from "node:crypto";
|
||||
import type { StreamFn } from "@earendil-works/pi-agent-core";
|
||||
import { createAssistantMessageEventStream, streamSimple } from "@earendil-works/pi-ai";
|
||||
import { parseStandalonePlainTextToolCallBlocks } from "./tool-payload.js";
|
||||
|
||||
function toRecord(value: unknown): Record<string, unknown> | undefined {
|
||||
return value && typeof value === "object" ? (value as Record<string, unknown>) : undefined;
|
||||
}
|
||||
|
||||
function resolveContextToolNames(context: Parameters<StreamFn>[1]): Set<string> {
|
||||
const tools = (context as { tools?: unknown }).tools;
|
||||
if (!Array.isArray(tools)) {
|
||||
return new Set();
|
||||
}
|
||||
const names = tools
|
||||
.map((tool) => {
|
||||
const record = toRecord(tool);
|
||||
return typeof record?.name === "string" && record.name.trim() ? record.name : undefined;
|
||||
})
|
||||
.filter((name): name is string => Boolean(name));
|
||||
return new Set(names);
|
||||
}
|
||||
|
||||
function couldStillBePlainTextToolCall(text: string, toolNames: Set<string>): boolean {
|
||||
if (text.length > 256_000) {
|
||||
return false;
|
||||
}
|
||||
const trimmed = text.trimStart();
|
||||
return (
|
||||
trimmed.length === 0 ||
|
||||
couldStillBeBracketedToolCall(trimmed, toolNames) ||
|
||||
couldStillBeHarmonyToolCall(trimmed, toolNames)
|
||||
);
|
||||
}
|
||||
|
||||
function matchesLiteralPrefix(text: string, literal: string): boolean {
|
||||
return literal.startsWith(text) || text.startsWith(literal);
|
||||
}
|
||||
|
||||
function skipHorizontalWhitespace(text: string, start: number): number {
|
||||
let cursor = start;
|
||||
while (text[cursor] === " " || text[cursor] === "\t") {
|
||||
cursor += 1;
|
||||
}
|
||||
return cursor;
|
||||
}
|
||||
|
||||
function isToolNameChar(char: string | undefined): boolean {
|
||||
return Boolean(char && /[A-Za-z0-9_-]/.test(char));
|
||||
}
|
||||
|
||||
function hasToolNamePrefix(toolNames: Set<string>, prefix: string): boolean {
|
||||
for (const toolName of toolNames) {
|
||||
if (toolName.startsWith(prefix)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
function couldStillBeJsonPayload(text: string, start: number): boolean {
|
||||
let cursor = start;
|
||||
while (cursor < text.length && /\s/.test(text[cursor] ?? "")) {
|
||||
cursor += 1;
|
||||
}
|
||||
return cursor >= text.length || text[cursor] === "{";
|
||||
}
|
||||
|
||||
function couldStillBeBracketedToolCall(text: string, toolNames: Set<string>): boolean {
|
||||
if (!text.startsWith("[")) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const toolPrefix = "[tool:";
|
||||
if (matchesLiteralPrefix(text, toolPrefix)) {
|
||||
if (text.length <= toolPrefix.length) {
|
||||
return true;
|
||||
}
|
||||
let cursor = toolPrefix.length;
|
||||
while (isToolNameChar(text[cursor])) {
|
||||
cursor += 1;
|
||||
}
|
||||
const name = text.slice(toolPrefix.length, cursor);
|
||||
if (!name || !hasToolNamePrefix(toolNames, name)) {
|
||||
return false;
|
||||
}
|
||||
if (cursor >= text.length) {
|
||||
return true;
|
||||
}
|
||||
if (text[cursor] !== "]") {
|
||||
return false;
|
||||
}
|
||||
return couldStillBeJsonPayload(text, cursor + 1);
|
||||
}
|
||||
|
||||
let cursor = 1;
|
||||
while (isToolNameChar(text[cursor])) {
|
||||
cursor += 1;
|
||||
}
|
||||
const name = text.slice(1, cursor);
|
||||
if (!name || !hasToolNamePrefix(toolNames, name)) {
|
||||
return false;
|
||||
}
|
||||
if (cursor >= text.length) {
|
||||
return true;
|
||||
}
|
||||
if (text[cursor] !== "]") {
|
||||
return false;
|
||||
}
|
||||
|
||||
cursor = skipHorizontalWhitespace(text, cursor + 1);
|
||||
if (cursor >= text.length) {
|
||||
return true;
|
||||
}
|
||||
if (text[cursor] === "\r") {
|
||||
if (cursor + 1 >= text.length) {
|
||||
return true;
|
||||
}
|
||||
return couldStillBeJsonPayload(text, text[cursor + 1] === "\n" ? cursor + 2 : cursor + 1);
|
||||
}
|
||||
if (text[cursor] !== "\n") {
|
||||
return false;
|
||||
}
|
||||
return couldStillBeJsonPayload(text, cursor + 1);
|
||||
}
|
||||
|
||||
function couldStillBeHarmonyToolCall(text: string, toolNames: Set<string>): boolean {
|
||||
const channelMarker = "<|channel|>";
|
||||
let cursor = 0;
|
||||
if (matchesLiteralPrefix(text, channelMarker)) {
|
||||
if (text.length <= channelMarker.length) {
|
||||
return true;
|
||||
}
|
||||
cursor = channelMarker.length;
|
||||
}
|
||||
|
||||
const rest = text.slice(cursor);
|
||||
const channel = ["commentary", "analysis", "final"].find((candidate) =>
|
||||
matchesLiteralPrefix(rest, candidate),
|
||||
);
|
||||
if (!channel) {
|
||||
return false;
|
||||
}
|
||||
if (rest.length <= channel.length) {
|
||||
return true;
|
||||
}
|
||||
|
||||
cursor += channel.length;
|
||||
cursor = skipHorizontalWhitespace(text, cursor);
|
||||
if (cursor >= text.length) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const toMarker = "to=";
|
||||
const toRest = text.slice(cursor);
|
||||
if (!matchesLiteralPrefix(toRest, toMarker)) {
|
||||
return false;
|
||||
}
|
||||
if (toRest.length <= toMarker.length) {
|
||||
return true;
|
||||
}
|
||||
|
||||
cursor += toMarker.length;
|
||||
const nameStart = cursor;
|
||||
while (isToolNameChar(text[cursor])) {
|
||||
cursor += 1;
|
||||
}
|
||||
const name = text.slice(nameStart, cursor);
|
||||
if (!name || !hasToolNamePrefix(toolNames, name)) {
|
||||
return false;
|
||||
}
|
||||
if (cursor >= text.length) {
|
||||
return true;
|
||||
}
|
||||
|
||||
cursor = skipHorizontalWhitespace(text, cursor);
|
||||
if (cursor >= text.length) {
|
||||
return true;
|
||||
}
|
||||
if (!toolNames.has(name)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const codeMarker = "code";
|
||||
const codeRest = text.slice(cursor);
|
||||
if (!matchesLiteralPrefix(codeRest, codeMarker)) {
|
||||
return false;
|
||||
}
|
||||
if (codeRest.length <= codeMarker.length) {
|
||||
return true;
|
||||
}
|
||||
|
||||
cursor += codeMarker.length;
|
||||
while (cursor < text.length && /\s/.test(text[cursor] ?? "")) {
|
||||
cursor += 1;
|
||||
}
|
||||
if (cursor >= text.length) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const messageMarker = "<|message|>";
|
||||
const messageRest = text.slice(cursor);
|
||||
if (matchesLiteralPrefix(messageRest, messageMarker)) {
|
||||
return true;
|
||||
}
|
||||
return text[cursor] === "{";
|
||||
}
|
||||
|
||||
function createSyntheticToolCallId(): string {
|
||||
return `call_${randomUUID().replace(/-/g, "").slice(0, 24)}`;
|
||||
}
|
||||
|
||||
function createPlainTextToolCallBlock(parsed: {
|
||||
arguments: Record<string, unknown>;
|
||||
name: string;
|
||||
}): Record<string, unknown> {
|
||||
return {
|
||||
type: "toolCall",
|
||||
id: createSyntheticToolCallId(),
|
||||
name: parsed.name,
|
||||
arguments: parsed.arguments,
|
||||
partialArgs: JSON.stringify(parsed.arguments),
|
||||
};
|
||||
}
|
||||
|
||||
function promotePlainTextToolCalls(
|
||||
message: unknown,
|
||||
toolNames: Set<string>,
|
||||
): Record<string, unknown> | undefined {
|
||||
const messageRecord = toRecord(message);
|
||||
if (!messageRecord) {
|
||||
return undefined;
|
||||
}
|
||||
if (!Array.isArray(messageRecord.content)) {
|
||||
if (typeof messageRecord.content !== "string" || !messageRecord.content.trim()) {
|
||||
return undefined;
|
||||
}
|
||||
const parsed = parseStandalonePlainTextToolCallBlocks(messageRecord.content, {
|
||||
allowedToolNames: toolNames,
|
||||
});
|
||||
if (!parsed) {
|
||||
return undefined;
|
||||
}
|
||||
return {
|
||||
...messageRecord,
|
||||
content: parsed.map(createPlainTextToolCallBlock),
|
||||
stopReason: "toolUse",
|
||||
};
|
||||
}
|
||||
if (
|
||||
messageRecord.content.some((block) => toRecord(block)?.type === "toolCall") ||
|
||||
messageRecord.content.length === 0
|
||||
) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
let promoted = false;
|
||||
const nextContent: Array<Record<string, unknown>> = [];
|
||||
for (const block of messageRecord.content) {
|
||||
const blockRecord = toRecord(block);
|
||||
if (!blockRecord) {
|
||||
return undefined;
|
||||
}
|
||||
if (blockRecord.type !== "text") {
|
||||
nextContent.push(blockRecord);
|
||||
continue;
|
||||
}
|
||||
const text = typeof blockRecord.text === "string" ? blockRecord.text : "";
|
||||
if (!text.trim()) {
|
||||
continue;
|
||||
}
|
||||
const parsed = parseStandalonePlainTextToolCallBlocks(text, {
|
||||
allowedToolNames: toolNames,
|
||||
});
|
||||
if (!parsed) {
|
||||
return undefined;
|
||||
}
|
||||
nextContent.push(...parsed.map(createPlainTextToolCallBlock));
|
||||
promoted = true;
|
||||
}
|
||||
|
||||
if (!promoted) {
|
||||
return undefined;
|
||||
}
|
||||
return {
|
||||
...messageRecord,
|
||||
content: nextContent,
|
||||
stopReason: "toolUse",
|
||||
};
|
||||
}
|
||||
|
||||
function emitPromotedToolCallEvents(
|
||||
stream: { push(event: unknown): void },
|
||||
message: Record<string, unknown>,
|
||||
): void {
|
||||
const content = Array.isArray(message.content) ? message.content : [];
|
||||
content.forEach((block, contentIndex) => {
|
||||
const record = toRecord(block);
|
||||
if (record?.type !== "toolCall") {
|
||||
return;
|
||||
}
|
||||
stream.push({ type: "toolcall_start", contentIndex, partial: message });
|
||||
stream.push({
|
||||
type: "toolcall_delta",
|
||||
contentIndex,
|
||||
delta: typeof record.partialArgs === "string" ? record.partialArgs : "{}",
|
||||
partial: message,
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
function wrapPlainTextToolCallStream(
|
||||
source: ReturnType<StreamFn>,
|
||||
context: Parameters<StreamFn>[1],
|
||||
): ReturnType<StreamFn> {
|
||||
const toolNames = resolveContextToolNames(context);
|
||||
if (toolNames.size === 0) {
|
||||
return source;
|
||||
}
|
||||
const output = createAssistantMessageEventStream();
|
||||
const stream = output as unknown as { push(event: unknown): void; end(): void };
|
||||
|
||||
void (async () => {
|
||||
const bufferedTextEvents: unknown[] = [];
|
||||
let bufferedText = "";
|
||||
let ended = false;
|
||||
const endStream = () => {
|
||||
if (!ended) {
|
||||
ended = true;
|
||||
stream.end();
|
||||
}
|
||||
};
|
||||
const flushBufferedTextEvents = () => {
|
||||
for (const event of bufferedTextEvents.splice(0)) {
|
||||
stream.push(event);
|
||||
}
|
||||
bufferedText = "";
|
||||
};
|
||||
|
||||
try {
|
||||
for await (const event of source as AsyncIterable<unknown>) {
|
||||
const record = toRecord(event);
|
||||
const type = typeof record?.type === "string" ? record.type : "";
|
||||
|
||||
if (type === "text_start" || type === "text_delta" || type === "text_end") {
|
||||
bufferedTextEvents.push(event);
|
||||
if (typeof record?.delta === "string") {
|
||||
bufferedText += record.delta;
|
||||
} else if (typeof record?.content === "string" && !bufferedText) {
|
||||
bufferedText = record.content;
|
||||
}
|
||||
if (!couldStillBePlainTextToolCall(bufferedText, toolNames)) {
|
||||
flushBufferedTextEvents();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (type === "done") {
|
||||
const promotedMessage = promotePlainTextToolCalls(record?.message, toolNames);
|
||||
if (promotedMessage) {
|
||||
bufferedTextEvents.splice(0);
|
||||
bufferedText = "";
|
||||
emitPromotedToolCallEvents(stream, promotedMessage);
|
||||
stream.push({ ...record, reason: "toolUse", message: promotedMessage });
|
||||
} else {
|
||||
flushBufferedTextEvents();
|
||||
stream.push(event);
|
||||
}
|
||||
endStream();
|
||||
return;
|
||||
}
|
||||
|
||||
flushBufferedTextEvents();
|
||||
stream.push(event);
|
||||
if (type === "error") {
|
||||
endStream();
|
||||
return;
|
||||
}
|
||||
}
|
||||
flushBufferedTextEvents();
|
||||
} catch (error) {
|
||||
stream.push({
|
||||
type: "error",
|
||||
reason: "error",
|
||||
error: {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
stopReason: "error",
|
||||
errorMessage: error instanceof Error ? error.message : String(error),
|
||||
},
|
||||
});
|
||||
} finally {
|
||||
endStream();
|
||||
}
|
||||
})();
|
||||
|
||||
return output as ReturnType<StreamFn>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Bundled-provider runtime hygiene for providers that can leak tool-use syntax
|
||||
* as assistant text even when native tool calling is enabled.
|
||||
* @deprecated Use `createPlainTextToolCallPromotionWrapper` from
|
||||
* `openclaw/plugin-sdk/provider-stream-shared`.
|
||||
*/
|
||||
export function createPlainTextToolCallPromotionWrapper(
|
||||
baseStreamFn: StreamFn | undefined,
|
||||
): StreamFn {
|
||||
const underlying = baseStreamFn ?? streamSimple;
|
||||
return (model, context, options) => {
|
||||
const maybeStream = underlying(model, context, options);
|
||||
if (maybeStream && typeof maybeStream === "object" && "then" in maybeStream) {
|
||||
return Promise.resolve(maybeStream).then((stream) =>
|
||||
wrapPlainTextToolCallStream(stream, context),
|
||||
) as ReturnType<StreamFn>;
|
||||
}
|
||||
return wrapPlainTextToolCallStream(maybeStream, context);
|
||||
};
|
||||
}
|
||||
export { createPlainTextToolCallPromotionWrapper } from "./provider-stream-shared.js";
|
||||
|
||||
@@ -1,14 +1,28 @@
|
||||
import type { StreamFn } from "@earendil-works/pi-agent-core";
|
||||
import { createAssistantMessageEventStream } from "@earendil-works/pi-ai";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import {
|
||||
createDeepSeekV4OpenAICompatibleThinkingWrapper,
|
||||
createAnthropicThinkingPrefillPayloadWrapper,
|
||||
createPayloadPatchStreamWrapper,
|
||||
createPlainTextToolCallPromotionWrapper,
|
||||
defaultToolStreamExtraParams,
|
||||
isOpenAICompatibleThinkingEnabled,
|
||||
stripTrailingAnthropicAssistantPrefillWhenThinking,
|
||||
} from "./provider-stream-shared.js";
|
||||
|
||||
function createEventStream(events: unknown[]): ReturnType<StreamFn> {
|
||||
const output = createAssistantMessageEventStream();
|
||||
const stream = output as unknown as { push(event: unknown): void; end(): void };
|
||||
queueMicrotask(() => {
|
||||
for (const event of events) {
|
||||
stream.push(event);
|
||||
}
|
||||
stream.end();
|
||||
});
|
||||
return output as ReturnType<StreamFn>;
|
||||
}
|
||||
|
||||
describe("defaultToolStreamExtraParams", () => {
|
||||
it("defaults tool_stream on when absent", () => {
|
||||
expect(defaultToolStreamExtraParams()).toEqual({ tool_stream: true });
|
||||
@@ -139,6 +153,80 @@ describe("createPayloadPatchStreamWrapper", () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe("createPlainTextToolCallPromotionWrapper", () => {
|
||||
it("promotes standalone text tool calls into tool-call stream events", async () => {
|
||||
const baseStreamFn: StreamFn = () =>
|
||||
createEventStream([
|
||||
{ type: "text_start", content: "" },
|
||||
{ type: "text_delta", delta: '[tool:read] {"path":"/tmp/file.txt"}' },
|
||||
{ type: "text_end" },
|
||||
{
|
||||
type: "done",
|
||||
reason: "stop",
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: '[tool:read] {"path":"/tmp/file.txt"}',
|
||||
},
|
||||
},
|
||||
]);
|
||||
const wrapped = createPlainTextToolCallPromotionWrapper(baseStreamFn);
|
||||
const events: unknown[] = [];
|
||||
|
||||
for await (const event of wrapped(
|
||||
{} as never,
|
||||
{ tools: [{ name: "read" }] } as never,
|
||||
{},
|
||||
) as AsyncIterable<unknown>) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
expect(events.map((event) => (event as { type?: string }).type)).toEqual([
|
||||
"toolcall_start",
|
||||
"toolcall_delta",
|
||||
"done",
|
||||
]);
|
||||
const done = events.at(-1) as { message?: { content?: unknown; stopReason?: unknown } };
|
||||
expect(done.message?.stopReason).toBe("toolUse");
|
||||
expect(done.message?.content).toEqual([
|
||||
expect.objectContaining({
|
||||
type: "toolCall",
|
||||
name: "read",
|
||||
arguments: { path: "/tmp/file.txt" },
|
||||
}),
|
||||
]);
|
||||
});
|
||||
|
||||
it("passes through bracketed text when no configured tool names match", async () => {
|
||||
const baseStreamFn: StreamFn = () =>
|
||||
createEventStream([
|
||||
{ type: "text_delta", delta: "[note] keep streaming" },
|
||||
{
|
||||
type: "done",
|
||||
reason: "stop",
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: "[note] keep streaming",
|
||||
},
|
||||
},
|
||||
]);
|
||||
const wrapped = createPlainTextToolCallPromotionWrapper(baseStreamFn);
|
||||
const events: unknown[] = [];
|
||||
|
||||
for await (const event of wrapped(
|
||||
{} as never,
|
||||
{ tools: [{ name: "read" }] } as never,
|
||||
{},
|
||||
) as AsyncIterable<unknown>) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
expect(events.map((event) => (event as { type?: string }).type)).toEqual([
|
||||
"text_delta",
|
||||
"done",
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("stripTrailingAnthropicAssistantPrefillWhenThinking", () => {
|
||||
it("removes trailing assistant text turns when Anthropic thinking is enabled", () => {
|
||||
const payload = {
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import { randomUUID } from "node:crypto";
|
||||
import type { StreamFn } from "@earendil-works/pi-agent-core";
|
||||
import { streamSimple } from "@earendil-works/pi-ai";
|
||||
import { createAssistantMessageEventStream, streamSimple } from "@earendil-works/pi-ai";
|
||||
import { streamWithPayloadPatch } from "../agents/pi-embedded-runner/stream-payload-utils.js";
|
||||
import { normalizeLowercaseStringOrEmpty } from "../shared/string-coerce.js";
|
||||
import type { ProviderWrapStreamFnContext } from "./plugin-entry.js";
|
||||
import { parseStandalonePlainTextToolCallBlocks } from "./tool-payload.js";
|
||||
|
||||
export type ProviderStreamWrapperFactory =
|
||||
| ((streamFn: StreamFn | undefined) => StreamFn | undefined)
|
||||
@@ -20,6 +22,419 @@ export function composeProviderStreamWrappers(
|
||||
);
|
||||
}
|
||||
|
||||
function toRecord(value: unknown): Record<string, unknown> | undefined {
|
||||
return value && typeof value === "object" ? (value as Record<string, unknown>) : undefined;
|
||||
}
|
||||
|
||||
function resolveContextToolNames(context: Parameters<StreamFn>[1]): Set<string> {
|
||||
const tools = (context as { tools?: unknown }).tools;
|
||||
if (!Array.isArray(tools)) {
|
||||
return new Set();
|
||||
}
|
||||
const names = tools
|
||||
.map((tool) => {
|
||||
const record = toRecord(tool);
|
||||
return typeof record?.name === "string" && record.name.trim() ? record.name : undefined;
|
||||
})
|
||||
.filter((name): name is string => Boolean(name));
|
||||
return new Set(names);
|
||||
}
|
||||
|
||||
function couldStillBePlainTextToolCall(text: string, toolNames: Set<string>): boolean {
|
||||
if (text.length > 256_000) {
|
||||
return false;
|
||||
}
|
||||
const trimmed = text.trimStart();
|
||||
return (
|
||||
trimmed.length === 0 ||
|
||||
couldStillBeBracketedToolCall(trimmed, toolNames) ||
|
||||
couldStillBeHarmonyToolCall(trimmed, toolNames)
|
||||
);
|
||||
}
|
||||
|
||||
function matchesLiteralPrefix(text: string, literal: string): boolean {
|
||||
return literal.startsWith(text) || text.startsWith(literal);
|
||||
}
|
||||
|
||||
function skipHorizontalWhitespace(text: string, start: number): number {
|
||||
let cursor = start;
|
||||
while (text[cursor] === " " || text[cursor] === "\t") {
|
||||
cursor += 1;
|
||||
}
|
||||
return cursor;
|
||||
}
|
||||
|
||||
function isToolNameChar(char: string | undefined): boolean {
|
||||
return Boolean(char && /[A-Za-z0-9_-]/.test(char));
|
||||
}
|
||||
|
||||
function hasToolNamePrefix(toolNames: Set<string>, prefix: string): boolean {
|
||||
for (const toolName of toolNames) {
|
||||
if (toolName.startsWith(prefix)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
function couldStillBeJsonPayload(text: string, start: number): boolean {
|
||||
let cursor = start;
|
||||
while (cursor < text.length && /\s/.test(text[cursor] ?? "")) {
|
||||
cursor += 1;
|
||||
}
|
||||
return cursor >= text.length || text[cursor] === "{";
|
||||
}
|
||||
|
||||
function couldStillBeBracketedToolCall(text: string, toolNames: Set<string>): boolean {
|
||||
if (!text.startsWith("[")) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const toolPrefix = "[tool:";
|
||||
if (matchesLiteralPrefix(text, toolPrefix)) {
|
||||
if (text.length <= toolPrefix.length) {
|
||||
return true;
|
||||
}
|
||||
let cursor = toolPrefix.length;
|
||||
while (isToolNameChar(text[cursor])) {
|
||||
cursor += 1;
|
||||
}
|
||||
const name = text.slice(toolPrefix.length, cursor);
|
||||
if (!name || !hasToolNamePrefix(toolNames, name)) {
|
||||
return false;
|
||||
}
|
||||
if (cursor >= text.length) {
|
||||
return true;
|
||||
}
|
||||
if (text[cursor] !== "]") {
|
||||
return false;
|
||||
}
|
||||
return couldStillBeJsonPayload(text, cursor + 1);
|
||||
}
|
||||
|
||||
let cursor = 1;
|
||||
while (isToolNameChar(text[cursor])) {
|
||||
cursor += 1;
|
||||
}
|
||||
const name = text.slice(1, cursor);
|
||||
if (!name || !hasToolNamePrefix(toolNames, name)) {
|
||||
return false;
|
||||
}
|
||||
if (cursor >= text.length) {
|
||||
return true;
|
||||
}
|
||||
if (text[cursor] !== "]") {
|
||||
return false;
|
||||
}
|
||||
|
||||
cursor = skipHorizontalWhitespace(text, cursor + 1);
|
||||
if (cursor >= text.length) {
|
||||
return true;
|
||||
}
|
||||
if (text[cursor] === "\r") {
|
||||
if (cursor + 1 >= text.length) {
|
||||
return true;
|
||||
}
|
||||
return couldStillBeJsonPayload(text, text[cursor + 1] === "\n" ? cursor + 2 : cursor + 1);
|
||||
}
|
||||
if (text[cursor] !== "\n") {
|
||||
return false;
|
||||
}
|
||||
return couldStillBeJsonPayload(text, cursor + 1);
|
||||
}
|
||||
|
||||
function couldStillBeHarmonyToolCall(text: string, toolNames: Set<string>): boolean {
|
||||
const channelMarker = "<|channel|>";
|
||||
let cursor = 0;
|
||||
if (matchesLiteralPrefix(text, channelMarker)) {
|
||||
if (text.length <= channelMarker.length) {
|
||||
return true;
|
||||
}
|
||||
cursor = channelMarker.length;
|
||||
}
|
||||
|
||||
const rest = text.slice(cursor);
|
||||
const channel = ["commentary", "analysis", "final"].find((candidate) =>
|
||||
matchesLiteralPrefix(rest, candidate),
|
||||
);
|
||||
if (!channel) {
|
||||
return false;
|
||||
}
|
||||
if (rest.length <= channel.length) {
|
||||
return true;
|
||||
}
|
||||
|
||||
cursor += channel.length;
|
||||
cursor = skipHorizontalWhitespace(text, cursor);
|
||||
if (cursor >= text.length) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const toMarker = "to=";
|
||||
const toRest = text.slice(cursor);
|
||||
if (!matchesLiteralPrefix(toRest, toMarker)) {
|
||||
return false;
|
||||
}
|
||||
if (toRest.length <= toMarker.length) {
|
||||
return true;
|
||||
}
|
||||
|
||||
cursor += toMarker.length;
|
||||
const nameStart = cursor;
|
||||
while (isToolNameChar(text[cursor])) {
|
||||
cursor += 1;
|
||||
}
|
||||
const name = text.slice(nameStart, cursor);
|
||||
if (!name || !hasToolNamePrefix(toolNames, name)) {
|
||||
return false;
|
||||
}
|
||||
if (cursor >= text.length) {
|
||||
return true;
|
||||
}
|
||||
|
||||
cursor = skipHorizontalWhitespace(text, cursor);
|
||||
if (cursor >= text.length) {
|
||||
return true;
|
||||
}
|
||||
if (!toolNames.has(name)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const codeMarker = "code";
|
||||
const codeRest = text.slice(cursor);
|
||||
if (!matchesLiteralPrefix(codeRest, codeMarker)) {
|
||||
return false;
|
||||
}
|
||||
if (codeRest.length <= codeMarker.length) {
|
||||
return true;
|
||||
}
|
||||
|
||||
cursor += codeMarker.length;
|
||||
while (cursor < text.length && /\s/.test(text[cursor] ?? "")) {
|
||||
cursor += 1;
|
||||
}
|
||||
if (cursor >= text.length) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const messageMarker = "<|message|>";
|
||||
const messageRest = text.slice(cursor);
|
||||
if (matchesLiteralPrefix(messageRest, messageMarker)) {
|
||||
return true;
|
||||
}
|
||||
return text[cursor] === "{";
|
||||
}
|
||||
|
||||
function createSyntheticToolCallId(): string {
|
||||
return `call_${randomUUID().replace(/-/g, "").slice(0, 24)}`;
|
||||
}
|
||||
|
||||
function createPlainTextToolCallBlock(parsed: {
|
||||
arguments: Record<string, unknown>;
|
||||
name: string;
|
||||
}): Record<string, unknown> {
|
||||
return {
|
||||
type: "toolCall",
|
||||
id: createSyntheticToolCallId(),
|
||||
name: parsed.name,
|
||||
arguments: parsed.arguments,
|
||||
partialArgs: JSON.stringify(parsed.arguments),
|
||||
};
|
||||
}
|
||||
|
||||
function promotePlainTextToolCalls(
|
||||
message: unknown,
|
||||
toolNames: Set<string>,
|
||||
): Record<string, unknown> | undefined {
|
||||
const messageRecord = toRecord(message);
|
||||
if (!messageRecord) {
|
||||
return undefined;
|
||||
}
|
||||
if (!Array.isArray(messageRecord.content)) {
|
||||
if (typeof messageRecord.content !== "string" || !messageRecord.content.trim()) {
|
||||
return undefined;
|
||||
}
|
||||
const parsed = parseStandalonePlainTextToolCallBlocks(messageRecord.content, {
|
||||
allowedToolNames: toolNames,
|
||||
});
|
||||
if (!parsed) {
|
||||
return undefined;
|
||||
}
|
||||
return {
|
||||
...messageRecord,
|
||||
content: parsed.map(createPlainTextToolCallBlock),
|
||||
stopReason: "toolUse",
|
||||
};
|
||||
}
|
||||
if (
|
||||
messageRecord.content.some((block) => toRecord(block)?.type === "toolCall") ||
|
||||
messageRecord.content.length === 0
|
||||
) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
let promoted = false;
|
||||
const nextContent: Array<Record<string, unknown>> = [];
|
||||
for (const block of messageRecord.content) {
|
||||
const blockRecord = toRecord(block);
|
||||
if (!blockRecord) {
|
||||
return undefined;
|
||||
}
|
||||
if (blockRecord.type !== "text") {
|
||||
nextContent.push(blockRecord);
|
||||
continue;
|
||||
}
|
||||
const text = typeof blockRecord.text === "string" ? blockRecord.text : "";
|
||||
if (!text.trim()) {
|
||||
continue;
|
||||
}
|
||||
const parsed = parseStandalonePlainTextToolCallBlocks(text, {
|
||||
allowedToolNames: toolNames,
|
||||
});
|
||||
if (!parsed) {
|
||||
return undefined;
|
||||
}
|
||||
nextContent.push(...parsed.map(createPlainTextToolCallBlock));
|
||||
promoted = true;
|
||||
}
|
||||
|
||||
if (!promoted) {
|
||||
return undefined;
|
||||
}
|
||||
return {
|
||||
...messageRecord,
|
||||
content: nextContent,
|
||||
stopReason: "toolUse",
|
||||
};
|
||||
}
|
||||
|
||||
function emitPromotedToolCallEvents(
|
||||
stream: { push(event: unknown): void },
|
||||
message: Record<string, unknown>,
|
||||
): void {
|
||||
const content = Array.isArray(message.content) ? message.content : [];
|
||||
content.forEach((block, contentIndex) => {
|
||||
const record = toRecord(block);
|
||||
if (record?.type !== "toolCall") {
|
||||
return;
|
||||
}
|
||||
stream.push({ type: "toolcall_start", contentIndex, partial: message });
|
||||
stream.push({
|
||||
type: "toolcall_delta",
|
||||
contentIndex,
|
||||
delta: typeof record.partialArgs === "string" ? record.partialArgs : "{}",
|
||||
partial: message,
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
function wrapPlainTextToolCallStream(
|
||||
source: ReturnType<StreamFn>,
|
||||
context: Parameters<StreamFn>[1],
|
||||
): ReturnType<StreamFn> {
|
||||
const toolNames = resolveContextToolNames(context);
|
||||
if (toolNames.size === 0) {
|
||||
return source;
|
||||
}
|
||||
const output = createAssistantMessageEventStream();
|
||||
const stream = output as unknown as { push(event: unknown): void; end(): void };
|
||||
|
||||
void (async () => {
|
||||
const bufferedTextEvents: unknown[] = [];
|
||||
let bufferedText = "";
|
||||
let ended = false;
|
||||
const endStream = () => {
|
||||
if (!ended) {
|
||||
ended = true;
|
||||
stream.end();
|
||||
}
|
||||
};
|
||||
const flushBufferedTextEvents = () => {
|
||||
for (const event of bufferedTextEvents.splice(0)) {
|
||||
stream.push(event);
|
||||
}
|
||||
bufferedText = "";
|
||||
};
|
||||
|
||||
try {
|
||||
for await (const event of source as AsyncIterable<unknown>) {
|
||||
const record = toRecord(event);
|
||||
const type = typeof record?.type === "string" ? record.type : "";
|
||||
|
||||
if (type === "text_start" || type === "text_delta" || type === "text_end") {
|
||||
bufferedTextEvents.push(event);
|
||||
if (typeof record?.delta === "string") {
|
||||
bufferedText += record.delta;
|
||||
} else if (typeof record?.content === "string" && !bufferedText) {
|
||||
bufferedText = record.content;
|
||||
}
|
||||
if (!couldStillBePlainTextToolCall(bufferedText, toolNames)) {
|
||||
flushBufferedTextEvents();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (type === "done") {
|
||||
const promotedMessage = promotePlainTextToolCalls(record?.message, toolNames);
|
||||
if (promotedMessage) {
|
||||
bufferedTextEvents.splice(0);
|
||||
bufferedText = "";
|
||||
emitPromotedToolCallEvents(stream, promotedMessage);
|
||||
stream.push({ ...record, reason: "toolUse", message: promotedMessage });
|
||||
} else {
|
||||
flushBufferedTextEvents();
|
||||
stream.push(event);
|
||||
}
|
||||
endStream();
|
||||
return;
|
||||
}
|
||||
|
||||
flushBufferedTextEvents();
|
||||
stream.push(event);
|
||||
if (type === "error") {
|
||||
endStream();
|
||||
return;
|
||||
}
|
||||
}
|
||||
flushBufferedTextEvents();
|
||||
} catch (error) {
|
||||
stream.push({
|
||||
type: "error",
|
||||
reason: "error",
|
||||
error: {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
stopReason: "error",
|
||||
errorMessage: error instanceof Error ? error.message : String(error),
|
||||
},
|
||||
});
|
||||
} finally {
|
||||
endStream();
|
||||
}
|
||||
})();
|
||||
|
||||
return output as ReturnType<StreamFn>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Provider stream wrapper for local/proxy providers that sometimes emit a
|
||||
* standalone textual tool-call block even when native tool calling is enabled.
|
||||
*/
|
||||
export function createPlainTextToolCallPromotionWrapper(
|
||||
baseStreamFn: StreamFn | undefined,
|
||||
): StreamFn {
|
||||
const underlying = baseStreamFn ?? streamSimple;
|
||||
return (model, context, options) => {
|
||||
const maybeStream = underlying(model, context, options);
|
||||
if (maybeStream && typeof maybeStream === "object" && "then" in maybeStream) {
|
||||
return Promise.resolve(maybeStream).then((stream) =>
|
||||
wrapPlainTextToolCallStream(stream, context),
|
||||
) as ReturnType<StreamFn>;
|
||||
}
|
||||
return wrapPlainTextToolCallStream(maybeStream, context);
|
||||
};
|
||||
}
|
||||
|
||||
/** @deprecated Bundled provider stream helper; do not use from third-party plugins. */
|
||||
export function defaultToolStreamExtraParams(
|
||||
extraParams?: Record<string, unknown>,
|
||||
|
||||
@@ -4,12 +4,14 @@ import { VERSION } from "../version.js";
|
||||
import {
|
||||
composeProviderStreamWrappers as composeProviderStreamWrappersShared,
|
||||
createMoonshotThinkingWrapper as createMoonshotThinkingWrapperShared,
|
||||
createPlainTextToolCallPromotionWrapper as createPlainTextToolCallPromotionWrapperShared,
|
||||
createToolStreamWrapper as createToolStreamWrapperShared,
|
||||
} from "./provider-stream-shared.js";
|
||||
import {
|
||||
buildProviderStreamFamilyHooks,
|
||||
composeProviderStreamWrappers,
|
||||
createMoonshotThinkingWrapper,
|
||||
createPlainTextToolCallPromotionWrapper,
|
||||
createToolStreamWrapper,
|
||||
GOOGLE_THINKING_STREAM_HOOKS,
|
||||
KILOCODE_THINKING_STREAM_HOOKS,
|
||||
@@ -65,6 +67,9 @@ describe("composeProviderStreamWrappers", () => {
|
||||
|
||||
it("re-exports shared helper wrappers", () => {
|
||||
expect(createMoonshotThinkingWrapper).toBe(createMoonshotThinkingWrapperShared);
|
||||
expect(createPlainTextToolCallPromotionWrapper).toBe(
|
||||
createPlainTextToolCallPromotionWrapperShared,
|
||||
);
|
||||
expect(createToolStreamWrapper).toBe(createToolStreamWrapperShared);
|
||||
});
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ export {
|
||||
composeProviderStreamWrappers,
|
||||
createAnthropicThinkingPrefillPayloadWrapper,
|
||||
createMoonshotThinkingWrapper,
|
||||
createPlainTextToolCallPromotionWrapper,
|
||||
createToolStreamWrapper,
|
||||
createZaiToolStreamWrapper,
|
||||
defaultToolStreamExtraParams,
|
||||
|
||||
Reference in New Issue
Block a user