// File: qwen-code/packages/core/src/core/turn.ts
// Listing metadata: 227 lines, 6.3 KiB, TypeScript

/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
PartListUnion,
GenerateContentResponse,
FunctionCall,
FunctionDeclaration,
GenerateContentResponseUsageMetadata,
} from '@google/genai';
import {
ToolCallConfirmationDetails,
ToolResult,
ToolResultDisplay,
} from '../tools/tools.js';
import { getResponseText } from '../utils/generateContentResponseUtilities.js';
import { reportError } from '../utils/errorReporting.js';
import { getErrorMessage } from '../utils/errors.js';
import { GeminiChat } from './geminiChat.js';
/**
 * Structure for tools passed to the server.
 *
 * A tool exposes its name, a function declaration describing it, and two
 * asynchronous operations: executing the tool and deciding whether execution
 * needs confirmation.
 */
export interface ServerTool {
/** Name of the tool. */
name: string;
/** Function declaration (from `@google/genai`) describing the tool. */
schema: FunctionDeclaration;
/**
 * Runs the tool with the given parameters and resolves to a `ToolResult`.
 * An abort signal may optionally be supplied.
 */
// The execute method signature might differ slightly or be wrapped
execute(
params: Record<string, unknown>,
signal?: AbortSignal,
): Promise<ToolResult>;
/**
 * Resolves to confirmation details for the given parameters, or `false`.
 * NOTE(review): `false` presumably means no confirmation is required —
 * confirm against the implementations.
 */
shouldConfirmExecute(
params: Record<string, unknown>,
abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false>;
}
/**
 * Discriminant values for the `type` field of the server stream events
 * declared in this file. Each member is a string enum so the value survives
 * serialization as a readable tag.
 */
export enum GeminiEventType {
// Text content produced by the model.
Content = 'content',
// The model requested a tool call.
ToolCallRequest = 'tool_call_request',
// Result of a tool call.
ToolCallResponse = 'tool_call_response',
// A tool call carrying confirmation details.
ToolCallConfirmation = 'tool_call_confirmation',
// The operation was cancelled (emitted when the abort signal is set).
UserCancelled = 'user_cancelled',
// An error occurred; the event carries a message.
Error = 'error',
// Carries no payload; NOTE(review): presumably signals that chat history was compressed — confirm with the emitter.
ChatCompressed = 'chat_compressed',
// Usage metadata for the request.
UsageMetadata = 'usage_metadata',
}
/** Payload of an error event. */
export interface GeminiErrorEventValue {
/** Human-readable error message. */
message: string;
}
/** Describes a single tool call requested by the model. */
export interface ToolCallRequestInfo {
/** Identifier of this call; also carried by `ToolCallResponseInfo.callId`. */
callId: string;
/** Name of the tool to invoke. */
name: string;
/** Arguments for the tool, keyed by parameter name. */
args: Record<string, unknown>;
}
/** Describes the outcome of a tool call. */
export interface ToolCallResponseInfo {
/** Identifier of the call; NOTE(review): presumably matches the originating `ToolCallRequestInfo.callId` — confirm. */
callId: string;
/** Response parts produced by the tool call. */
responseParts: PartListUnion;
/** Display form of the result, or `undefined` when there is none. */
resultDisplay: ToolResultDisplay | undefined;
/** Error associated with the call, or `undefined` when there is none. */
error: Error | undefined;
}
/** Pairs a tool call request with its confirmation details. */
export interface ServerToolCallConfirmationDetails {
/** The tool call the confirmation refers to. */
request: ToolCallRequestInfo;
/** Confirmation details for that call. */
details: ToolCallConfirmationDetails;
}
/** Stream event carrying a piece of text content. */
export type ServerGeminiContentEvent = {
type: GeminiEventType.Content;
// The text content.
value: string;
};
/** Stream event announcing a tool call requested by the model. */
export type ServerGeminiToolCallRequestEvent = {
type: GeminiEventType.ToolCallRequest;
// Call ID, tool name and arguments of the requested call.
value: ToolCallRequestInfo;
};
/** Stream event carrying the outcome of a tool call. */
export type ServerGeminiToolCallResponseEvent = {
type: GeminiEventType.ToolCallResponse;
// Response parts, display result and error (if any) for the call.
value: ToolCallResponseInfo;
};
/** Stream event carrying a tool call together with its confirmation details. */
export type ServerGeminiToolCallConfirmationEvent = {
type: GeminiEventType.ToolCallConfirmation;
// The request and its confirmation details.
value: ServerToolCallConfirmationDetails;
};
/** Stream event signalling cancellation; it carries no payload. */
export type ServerGeminiUserCancelledEvent = {
type: GeminiEventType.UserCancelled;
};
/** Stream event reporting an error. */
export type ServerGeminiErrorEvent = {
type: GeminiEventType.Error;
// Holds the error message.
value: GeminiErrorEventValue;
};
/**
 * Stream event tagged `ChatCompressed`; it carries no payload.
 * NOTE(review): presumably emitted after chat history compression — confirm
 * with the code that produces it.
 */
export type ServerGeminiChatCompressedEvent = {
type: GeminiEventType.ChatCompressed;
};
/** Stream event carrying usage metadata for a request. */
export type ServerGeminiUsageMetadataEvent = {
type: GeminiEventType.UsageMetadata;
// Usage metadata from the API response, optionally extended with
// `apiTimeMs`, a duration in milliseconds.
value: GenerateContentResponseUsageMetadata & { apiTimeMs?: number };
};
/**
 * Union of every event that can appear on the server stream, composed of the
 * individual event types. The `type` field is the discriminant, so narrowing
 * on it gives access to the matching `value` payload (where one exists).
 */
export type ServerGeminiStreamEvent =
| ServerGeminiContentEvent
| ServerGeminiToolCallRequestEvent
| ServerGeminiToolCallResponseEvent
| ServerGeminiToolCallConfirmationEvent
| ServerGeminiUserCancelledEvent
| ServerGeminiErrorEvent
| ServerGeminiChatCompressedEvent
| ServerGeminiUsageMetadataEvent;
/**
 * Manages a single turn of the agentic loop within the server context.
 *
 * A turn sends one request through the wrapped chat, translates every streamed
 * response chunk into {@link ServerGeminiStreamEvent}s, and records the tool
 * calls the model requested in {@link Turn.pendingToolCalls}.
 */
export class Turn {
  /** Tool calls requested by the model during this turn, in arrival order. */
  readonly pendingToolCalls: ToolCallRequestInfo[] = [];
  /** Every response chunk processed by {@link Turn.run}, kept for debugging. */
  private readonly debugResponses: GenerateContentResponse[] = [];
  /** Usage metadata from the most recent chunk that carried any. */
  private lastUsageMetadata: GenerateContentResponseUsageMetadata | null = null;

  /**
   * @param chat Chat used to send the request and to read history for error
   *   reports.
   */
  constructor(private readonly chat: GeminiChat) {}

  /**
   * Sends `req` to the model and yields simplified events suitable for server
   * logic.
   *
   * For each streamed chunk this yields, in order: a `Content` event when the
   * chunk has text, then one `ToolCallRequest` event per function call. After
   * the stream ends, a single `UsageMetadata` event is yielded if any chunk
   * carried usage metadata; its `apiTimeMs` is the wall-clock time from the
   * start of this call to the end of the stream.
   *
   * The generator never throws for API failures: if the signal is aborted it
   * yields `UserCancelled` and stops, otherwise the error is reported and an
   * `Error` event with the error message is yielded.
   *
   * @param req Message parts to send.
   * @param signal Abort signal forwarded to the chat and polled between chunks.
   * @returns An async generator of stream events.
   */
  async *run(
    req: PartListUnion,
    signal: AbortSignal,
  ): AsyncGenerator<ServerGeminiStreamEvent> {
    const startTime = Date.now();
    try {
      const responseStream = await this.chat.sendMessageStream({
        message: req,
        config: {
          abortSignal: signal,
        },
      });

      for await (const resp of responseStream) {
        if (signal?.aborted) {
          yield { type: GeminiEventType.UserCancelled };
          // Do not add resp to debugResponses if aborted before processing
          return;
        }
        this.debugResponses.push(resp);

        const text = getResponseText(resp);
        if (text) {
          yield { type: GeminiEventType.Content, value: text };
        }

        // Handle function calls (requesting tool execution)
        for (const fnCall of resp.functionCalls ?? []) {
          yield this.handlePendingFunctionCall(fnCall);
        }

        if (resp.usageMetadata) {
          this.lastUsageMetadata = resp.usageMetadata;
        }
      }

      if (this.lastUsageMetadata) {
        const durationMs = Date.now() - startTime;
        yield {
          type: GeminiEventType.UsageMetadata,
          value: { ...this.lastUsageMetadata, apiTimeMs: durationMs },
        };
      }
    } catch (error) {
      // Same tolerant check as inside the loop, so a missing signal cannot
      // raise a TypeError here and mask the original error.
      if (signal?.aborted) {
        yield { type: GeminiEventType.UserCancelled };
        // Regular cancellation error, fail gracefully.
        return;
      }

      const contextForReport = [...this.chat.getHistory(/*curated*/ true), req];
      await reportError(
        error,
        'Error when talking to Gemini API',
        contextForReport,
        'Turn.run-sendMessageStream',
      );
      yield {
        type: GeminiEventType.Error,
        value: { message: getErrorMessage(error) },
      };
    }
  }

  /**
   * Records a function call requested by the model and builds the matching
   * `ToolCallRequest` event.
   *
   * @param fnCall Function call taken from a response chunk.
   * @returns The event to yield for this call.
   */
  private handlePendingFunctionCall(
    fnCall: FunctionCall,
  ): ServerGeminiStreamEvent {
    // Resolve the name first so a generated ID uses the same fallback name
    // that is reported, instead of the literal text "undefined".
    const name = fnCall.name || 'undefined_tool_name';
    // `||` rather than `??`: an empty-string ID cannot identify a call, so it
    // is treated like a missing one and a unique ID is generated.
    const callId =
      fnCall.id ||
      `${name}-${Date.now()}-${Math.random().toString(16).slice(2)}`;
    const args = (fnCall.args ?? {}) as Record<string, unknown>;

    this.pendingToolCalls.push({ callId, name, args });

    // Yield a request for the tool call, not the pending/confirming status
    const value: ToolCallRequestInfo = { callId, name, args };
    return { type: GeminiEventType.ToolCallRequest, value };
  }

  /**
   * @returns The response chunks processed so far. This is the live internal
   *   array, not a copy.
   */
  getDebugResponses(): GenerateContentResponse[] {
    return this.debugResponses;
  }

  /**
   * @returns The most recent usage metadata seen on a chunk, or `null` if no
   *   chunk carried any.
   */
  getUsageMetadata(): GenerateContentResponseUsageMetadata | null {
    return this.lastUsageMetadata;
  }
}