diff --git a/package-lock.json b/package-lock.json
index ef109c01..4aabf3cf 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -11997,6 +11997,7 @@
         "strip-ansi": "^7.1.0",
         "tiktoken": "^1.0.21",
         "undici": "^7.10.0",
+        "uuid": "^9.0.1",
         "ws": "^8.18.0"
       },
       "devDependencies": {
diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts
index 0ec6bd07..c141be39 100644
--- a/packages/cli/src/config/config.ts
+++ b/packages/cli/src/config/config.ts
@@ -517,7 +517,6 @@ export async function loadCliConfig(
       (typeof argv.openaiLogging === 'undefined'
         ? settings.enableOpenAILogging
         : argv.openaiLogging) ?? false,
-    sampling_params: settings.sampling_params,
     systemPromptMappings: (settings.systemPromptMappings ?? [
       {
         baseUrls: [
diff --git a/packages/cli/src/config/settingsSchema.ts b/packages/cli/src/config/settingsSchema.ts
index 73ffebdc..30f16bf7 100644
--- a/packages/cli/src/config/settingsSchema.ts
+++ b/packages/cli/src/config/settingsSchema.ts
@@ -503,7 +503,6 @@ export const SETTINGS_SCHEMA = {
     description: 'Show line numbers in the chat.',
     showInDialog: true,
   },
-
   contentGenerator: {
     type: 'object',
     label: 'Content Generator',
@@ -513,15 +512,6 @@
     description: 'Content generator settings.',
     showInDialog: false,
   },
-  sampling_params: {
-    type: 'object',
-    label: 'Sampling Params',
-    category: 'General',
-    requiresRestart: false,
-    default: undefined as Record<string, unknown> | undefined,
-    description: 'Sampling parameters for the model.',
-    showInDialog: false,
-  },
   enableOpenAILogging: {
     type: 'boolean',
     label: 'Enable OpenAI Logging',
diff --git a/packages/core/package.json b/packages/core/package.json
index 7b84fd01..0555bf99 100644
--- a/packages/core/package.json
+++ b/packages/core/package.json
@@ -52,6 +52,7 @@
     "strip-ansi": "^7.1.0",
     "tiktoken": "^1.0.21",
     "undici": "^7.10.0",
+    "uuid": "^9.0.1",
    "ws": "^8.18.0"
  },
  "devDependencies": {
diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts
index b1a2a096..f474a2dc 100644
--- a/packages/core/src/config/config.ts
+++ b/packages/core/src/config/config.ts
@@ -204,7 +204,6 @@ export interface ConfigParameters {
   folderTrust?: boolean;
   ideMode?: boolean;
   enableOpenAILogging?: boolean;
-  sampling_params?: Record<string, unknown>;
   systemPromptMappings?: Array<{
     baseUrls: string[];
     modelNames: string[];
@@ -213,6 +212,9 @@
   contentGenerator?: {
     timeout?: number;
     maxRetries?: number;
+    samplingParams?: {
+      [key: string]: unknown;
+    };
   };
   cliVersion?: string;
   loadMemoryFromIncludeDirectories?: boolean;
@@ -289,10 +291,10 @@ export class Config {
     | undefined;
   private readonly experimentalAcp: boolean = false;
   private readonly enableOpenAILogging: boolean;
-  private readonly sampling_params?: Record<string, unknown>;
   private readonly contentGenerator?: {
     timeout?: number;
     maxRetries?: number;
+    samplingParams?: Record<string, unknown>;
   };
   private readonly cliVersion?: string;
   private readonly loadMemoryFromIncludeDirectories: boolean = false;
@@ -367,7 +369,6 @@ export class Config {
     this.ideClient = IdeClient.getInstance();
     this.systemPromptMappings = params.systemPromptMappings;
     this.enableOpenAILogging = params.enableOpenAILogging ?? false;
-    this.sampling_params = params.sampling_params;
     this.contentGenerator = params.contentGenerator;
     this.cliVersion = params.cliVersion;
@@ -757,10 +758,6 @@ export class Config {
     return this.enableOpenAILogging;
   }

-  getSamplingParams(): Record<string, unknown> | undefined {
-    return this.sampling_params;
-  }
-
   getContentGeneratorTimeout(): number | undefined {
     return this.contentGenerator?.timeout;
   }
@@ -769,6 +766,12 @@
     return this.contentGenerator?.maxRetries;
   }

+  getContentGeneratorSamplingParams(): ContentGeneratorConfig['samplingParams'] {
+    return this.contentGenerator?.samplingParams as
+      | ContentGeneratorConfig['samplingParams']
+      | undefined;
+  }
+
   getCliVersion(): string | undefined {
     return this.cliVersion;
   }
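Taken together, the config changes above move sampling parameters out of the removed top-level `sampling_params` setting and into the `contentGenerator` block, exposed through the new `getContentGeneratorSamplingParams()` getter. A minimal sketch of the new shape as a caller might wire it up; the literal values are illustrative, not defaults taken from this diff, and the remaining required `ConfigParameters` fields are elided:

    const params = {
      // ...other ConfigParameters fields elided...
      enableOpenAILogging: false,
      contentGenerator: {
        timeout: 120000,
        maxRetries: 3,
        samplingParams: { temperature: 0.7, top_p: 0.9, max_tokens: 1000 },
      },
    } as ConfigParameters;

    const config = new Config(params);
    config.getContentGeneratorSamplingParams(); // => { temperature: 0.7, ... }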
diff --git a/packages/core/src/core/__tests__/openaiTimeoutHandling.test.ts b/packages/core/src/core/__tests__/openaiTimeoutHandling.test.ts
index c743c9b5..bb46b09b 100644
--- a/packages/core/src/core/__tests__/openaiTimeoutHandling.test.ts
+++ b/packages/core/src/core/__tests__/openaiTimeoutHandling.test.ts
@@ -7,6 +7,7 @@
 import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
 import { OpenAIContentGenerator } from '../openaiContentGenerator.js';
 import { Config } from '../../config/config.js';
+import { AuthType } from '../contentGenerator.js';
 import OpenAI from 'openai';

 // Mock OpenAI
@@ -41,9 +42,6 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
     mockConfig = {
       getContentGeneratorConfig: vi.fn().mockReturnValue({
         authType: 'openai',
-        enableOpenAILogging: false,
-        timeout: 120000,
-        maxRetries: 3,
       }),
       getCliVersion: vi.fn().mockReturnValue('1.0.0'),
     } as unknown as Config;
@@ -60,7 +58,12 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
     vi.mocked(OpenAI).mockImplementation(() => mockOpenAIClient);

     // Create generator instance
-    generator = new OpenAIContentGenerator('test-key', 'gpt-4', mockConfig);
+    const contentGeneratorConfig = {
+      model: 'gpt-4',
+      apiKey: 'test-key',
+      authType: AuthType.USE_OPENAI,
+    };
+    generator = new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
   });

   afterEach(() => {
@@ -237,12 +240,18 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
   describe('timeout configuration', () => {
     it('should use default timeout configuration', () => {
-      new OpenAIContentGenerator('test-key', 'gpt-4', mockConfig);
+      const contentGeneratorConfig = {
+        model: 'gpt-4',
+        apiKey: 'test-key',
+        authType: AuthType.USE_OPENAI,
+        baseUrl: 'http://localhost:8080',
+      };
+      new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);

       // Verify OpenAI client was created with timeout config
       expect(OpenAI).toHaveBeenCalledWith({
         apiKey: 'test-key',
-        baseURL: '',
+        baseURL: 'http://localhost:8080',
         timeout: 120000,
         maxRetries: 3,
         defaultHeaders: {
@@ -253,18 +262,23 @@
     });

     it('should use custom timeout from config', () => {
       const customConfig = {
-        getContentGeneratorConfig: vi.fn().mockReturnValue({
-          timeout: 300000, // 5 minutes
-          maxRetries: 5,
-        }),
+        getContentGeneratorConfig: vi.fn().mockReturnValue({}),
         getCliVersion: vi.fn().mockReturnValue('1.0.0'),
       } as unknown as Config;

-      new OpenAIContentGenerator('test-key', 'gpt-4', customConfig);
+      const contentGeneratorConfig = {
+        model: 'gpt-4',
+        apiKey: 'test-key',
+        baseUrl: 'http://localhost:8080',
+        authType: AuthType.USE_OPENAI,
+        timeout: 300000,
+        maxRetries: 5,
+      };
+      new OpenAIContentGenerator(contentGeneratorConfig, customConfig);

       expect(OpenAI).toHaveBeenCalledWith({
         apiKey: 'test-key',
-        baseURL: '',
+        baseURL: 'http://localhost:8080',
         timeout: 300000,
         maxRetries: 5,
         defaultHeaders: {
@@ -279,11 +293,17 @@
       getCliVersion: vi.fn().mockReturnValue('1.0.0'),
     } as unknown as Config;

-      new OpenAIContentGenerator('test-key', 'gpt-4', noTimeoutConfig);
+      const contentGeneratorConfig = {
+        model: 'gpt-4',
+        apiKey: 'test-key',
+        authType: AuthType.USE_OPENAI,
+        baseUrl: 'http://localhost:8080',
+      };
+      new OpenAIContentGenerator(contentGeneratorConfig, noTimeoutConfig);

       expect(OpenAI).toHaveBeenCalledWith({
         apiKey: 'test-key',
-        baseURL: '',
+        baseURL: 'http://localhost:8080',
         timeout: 120000, // default
         maxRetries: 3, // default
         defaultHeaders: {
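These tests now hand the generator a `ContentGeneratorConfig` object instead of positional `(apiKey, model, config)` arguments, and the assertions match the fallback behavior implemented later in this diff (`timeout ?? 120000`, `maxRetries ?? 3`). A condensed sketch of the new construction contract, using only fields that appear in this diff (`gcConfig` stands in for a real `Config` instance):

    import { AuthType, ContentGeneratorConfig } from './contentGenerator.js';

    const cfg: ContentGeneratorConfig = {
      model: 'gpt-4',
      apiKey: 'test-key',
      baseUrl: 'http://localhost:8080', // forwarded to the OpenAI client as baseURL
      authType: AuthType.USE_OPENAI,
      // timeout/maxRetries omitted: the client falls back to 120000 ms and 3 retries
    };
    const generator = new OpenAIContentGenerator(cfg, gcConfig);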
diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts
index f190df10..876e9027 100644
--- a/packages/core/src/core/client.ts
+++ b/packages/core/src/core/client.ts
@@ -565,10 +565,7 @@ export class GeminiClient {
       model || this.config.getModel() || DEFAULT_GEMINI_FLASH_MODEL;
     try {
       const userMemory = this.config.getUserMemory();
-      const systemPromptMappings = this.config.getSystemPromptMappings();
-      const systemInstruction = getCoreSystemPrompt(userMemory, {
-        systemPromptMappings,
-      });
+      const systemInstruction = getCoreSystemPrompt(userMemory);
       const requestConfig = {
         abortSignal,
         ...this.generateContentConfig,
@@ -656,10 +653,7 @@ export class GeminiClient {
     try {
       const userMemory = this.config.getUserMemory();
-      const systemPromptMappings = this.config.getSystemPromptMappings();
-      const systemInstruction = getCoreSystemPrompt(userMemory, {
-        systemPromptMappings,
-      });
+      const systemInstruction = getCoreSystemPrompt(userMemory);

       const requestConfig = {
         abortSignal,
diff --git a/packages/core/src/core/contentGenerator.test.ts b/packages/core/src/core/contentGenerator.test.ts
index 5d735beb..2761c0c5 100644
--- a/packages/core/src/core/contentGenerator.test.ts
+++ b/packages/core/src/core/contentGenerator.test.ts
@@ -84,6 +84,7 @@ describe('createContentGeneratorConfig', () => {
     getSamplingParams: vi.fn().mockReturnValue(undefined),
     getContentGeneratorTimeout: vi.fn().mockReturnValue(undefined),
     getContentGeneratorMaxRetries: vi.fn().mockReturnValue(undefined),
+    getContentGeneratorSamplingParams: vi.fn().mockReturnValue(undefined),
     getCliVersion: vi.fn().mockReturnValue('1.0.0'),
   } as unknown as Config;
diff --git a/packages/core/src/core/contentGenerator.ts b/packages/core/src/core/contentGenerator.ts
index 2c90e9c6..582ffbe4 100644
--- a/packages/core/src/core/contentGenerator.ts
+++ b/packages/core/src/core/contentGenerator.ts
@@ -53,6 +53,7 @@ export enum AuthType {
 export type ContentGeneratorConfig = {
   model: string;
   apiKey?: string;
+  baseUrl?: string;
   vertexai?: boolean;
   authType?: AuthType | undefined;
   enableOpenAILogging?: boolean;
@@ -76,11 +77,16 @@ export function createContentGeneratorConfig(
   config: Config,
   authType: AuthType | undefined,
 ): ContentGeneratorConfig {
+  // google auth
   const geminiApiKey = process.env.GEMINI_API_KEY || undefined;
   const googleApiKey = process.env.GOOGLE_API_KEY || undefined;
   const googleCloudProject = process.env.GOOGLE_CLOUD_PROJECT || undefined;
   const googleCloudLocation = process.env.GOOGLE_CLOUD_LOCATION || undefined;
+
+  // openai auth
   const openaiApiKey = process.env.OPENAI_API_KEY;
+  const openaiBaseUrl = process.env.OPENAI_BASE_URL || undefined;
+  const openaiModel = process.env.OPENAI_MODEL || undefined;

   // Use runtime model from config if available; otherwise, fall back to parameter or default
   const effectiveModel = config.getModel() || DEFAULT_GEMINI_MODEL;
@@ -92,7 +98,7 @@ export function createContentGeneratorConfig(
     enableOpenAILogging: config.getEnableOpenAILogging(),
     timeout: config.getContentGeneratorTimeout(),
     maxRetries: config.getContentGeneratorMaxRetries(),
-    samplingParams: config.getSamplingParams(),
+    samplingParams: config.getContentGeneratorSamplingParams(),
   };

   // If we are using Google auth or we are in Cloud Shell, there is nothing else to validate for now
@@ -127,8 +133,8 @@ export function createContentGeneratorConfig(
   if (authType === AuthType.USE_OPENAI && openaiApiKey) {
     contentGeneratorConfig.apiKey = openaiApiKey;
-    contentGeneratorConfig.model =
-      process.env.OPENAI_MODEL || DEFAULT_GEMINI_MODEL;
+    contentGeneratorConfig.baseUrl = openaiBaseUrl;
+    contentGeneratorConfig.model = openaiModel || DEFAULT_QWEN_MODEL;

     return contentGeneratorConfig;
   }
@@ -196,7 +202,7 @@ export async function createContentGenerator(
     );

     // Always use OpenAIContentGenerator, logging is controlled by enableOpenAILogging flag
-    return new OpenAIContentGenerator(config.apiKey, config.model, gcConfig);
+    return new OpenAIContentGenerator(config, gcConfig);
   }

   if (config.authType === AuthType.QWEN_OAUTH) {
@@ -217,7 +223,7 @@ export async function createContentGenerator(
     const qwenClient = await getQwenOauthClient(gcConfig);

     // Create the content generator with dynamic token management
-    return new QwenContentGenerator(qwenClient, config.model, gcConfig);
+    return new QwenContentGenerator(qwenClient, config, gcConfig);
   } catch (error) {
     throw new Error(
       `Failed to initialize Qwen: ${error instanceof Error ? error.message : String(error)}`,
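For the OpenAI-compatible path, the base URL and model now come from environment variables captured once at the top of the function, with `DEFAULT_QWEN_MODEL` replacing `DEFAULT_GEMINI_MODEL` as the model fallback. A condensed restatement of the `USE_OPENAI` branch above (not the full function):

    if (authType === AuthType.USE_OPENAI && process.env.OPENAI_API_KEY) {
      contentGeneratorConfig.apiKey = process.env.OPENAI_API_KEY;
      // An undefined baseUrl keeps the OpenAI SDK's default endpoint.
      contentGeneratorConfig.baseUrl = process.env.OPENAI_BASE_URL || undefined;
      contentGeneratorConfig.model = process.env.OPENAI_MODEL || DEFAULT_QWEN_MODEL;
    }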
diff --git a/packages/core/src/core/openaiContentGenerator.test.ts b/packages/core/src/core/openaiContentGenerator.test.ts
index 92de6235..ac255a42 100644
--- a/packages/core/src/core/openaiContentGenerator.test.ts
+++ b/packages/core/src/core/openaiContentGenerator.test.ts
@@ -7,6 +7,7 @@
 import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
 import { OpenAIContentGenerator } from './openaiContentGenerator.js';
 import { Config } from '../config/config.js';
+import { AuthType } from './contentGenerator.js';
 import OpenAI from 'openai';
 import type {
   GenerateContentParameters,
@@ -84,7 +85,20 @@ describe('OpenAIContentGenerator', () => {
     vi.mocked(OpenAI).mockImplementation(() => mockOpenAIClient);

     // Create generator instance
-    generator = new OpenAIContentGenerator('test-key', 'gpt-4', mockConfig);
+    const contentGeneratorConfig = {
+      model: 'gpt-4',
+      apiKey: 'test-key',
+      authType: AuthType.USE_OPENAI,
+      enableOpenAILogging: false,
+      timeout: 120000,
+      maxRetries: 3,
+      samplingParams: {
+        temperature: 0.7,
+        max_tokens: 1000,
+        top_p: 0.9,
+      },
+    };
+    generator = new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
   });

   afterEach(() => {
@@ -95,7 +109,7 @@
   it('should initialize with basic configuration', () => {
     expect(OpenAI).toHaveBeenCalledWith({
       apiKey: 'test-key',
-      baseURL: '',
+      baseURL: undefined,
       timeout: 120000,
       maxRetries: 3,
       defaultHeaders: {
@@ -105,9 +119,16 @@
   });

   it('should handle custom base URL', () => {
-    vi.stubEnv('OPENAI_BASE_URL', 'https://api.custom.com');
-
-    new OpenAIContentGenerator('test-key', 'gpt-4', mockConfig);
+    const contentGeneratorConfig = {
+      model: 'gpt-4',
+      apiKey: 'test-key',
+      baseUrl: 'https://api.custom.com',
+      authType: AuthType.USE_OPENAI,
+      enableOpenAILogging: false,
+      timeout: 120000,
+      maxRetries: 3,
+    };
+    new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);

     expect(OpenAI).toHaveBeenCalledWith({
       apiKey: 'test-key',
@@ -121,9 +142,16 @@
   });

   it('should configure OpenRouter headers when using OpenRouter', () => {
-    vi.stubEnv('OPENAI_BASE_URL', 'https://openrouter.ai/api/v1');
-
-    new OpenAIContentGenerator('test-key', 'gpt-4', mockConfig);
+    const contentGeneratorConfig = {
+      model: 'gpt-4',
+      apiKey: 'test-key',
+      baseUrl: 'https://openrouter.ai/api/v1',
+      authType: AuthType.USE_OPENAI,
+      enableOpenAILogging: false,
+      timeout: 120000,
+      maxRetries: 3,
+    };
+    new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);

     expect(OpenAI).toHaveBeenCalledWith({
       apiKey: 'test-key',
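With the base URL injected through `ContentGeneratorConfig` rather than `vi.stubEnv`, provider detection becomes a pure function of the config object. A condensed sketch of the header selection these tests exercise, mirroring the constructor logic later in this diff (`userAgent` construction elided, and `isDashScope` is a stand-in for the `isDashScopeProvider` method defined further down):

    const baseUrl = contentGeneratorConfig.baseUrl || '';
    const defaultHeaders = {
      'User-Agent': userAgent,
      ...(baseUrl.includes('openrouter.ai')
        ? {
            'HTTP-Referer': 'https://github.com/QwenLM/qwen-code.git',
            'X-Title': 'Qwen Code',
          }
        : isDashScope(contentGeneratorConfig) // QWEN_OAUTH auth or a DashScope base URL
          ? { 'X-DashScope-CacheControl': 'enable' }
          : {}),
    };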
@@ -147,11 +175,18 @@
       getCliVersion: vi.fn().mockReturnValue('1.0.0'),
     } as unknown as Config;

-    new OpenAIContentGenerator('test-key', 'gpt-4', customConfig);
+    const contentGeneratorConfig = {
+      model: 'gpt-4',
+      apiKey: 'test-key',
+      authType: AuthType.USE_OPENAI,
+      timeout: 300000,
+      maxRetries: 5,
+    };
+    new OpenAIContentGenerator(contentGeneratorConfig, customConfig);

     expect(OpenAI).toHaveBeenCalledWith({
       apiKey: 'test-key',
-      baseURL: '',
+      baseURL: undefined,
       timeout: 300000,
       maxRetries: 5,
       defaultHeaders: {
@@ -906,9 +941,14 @@
       getCliVersion: vi.fn().mockReturnValue('1.0.0'),
     } as unknown as Config;

+      const contentGeneratorConfig = {
+        model: 'gpt-4',
+        apiKey: 'test-key',
+        authType: AuthType.USE_OPENAI,
+        enableOpenAILogging: true,
+      };
       const loggingGenerator = new OpenAIContentGenerator(
-        'test-key',
-        'gpt-4',
+        contentGeneratorConfig,
         loggingConfig,
       );
@@ -1029,9 +1069,14 @@
       getCliVersion: vi.fn().mockReturnValue('1.0.0'),
     } as unknown as Config;

+      const contentGeneratorConfig = {
+        model: 'gpt-4',
+        apiKey: 'test-key',
+        authType: AuthType.USE_OPENAI,
+        enableOpenAILogging: true,
+      };
       const loggingGenerator = new OpenAIContentGenerator(
-        'test-key',
-        'gpt-4',
+        contentGeneratorConfig,
         loggingConfig,
       );
@@ -1587,7 +1632,23 @@
       }
     }

-    const testGenerator = new TestGenerator('test-key', 'gpt-4', mockConfig);
+    const contentGeneratorConfig = {
+      model: 'gpt-4',
+      apiKey: 'test-key',
+      authType: AuthType.USE_OPENAI,
+      enableOpenAILogging: false,
+      timeout: 120000,
+      maxRetries: 3,
+      samplingParams: {
+        temperature: 0.7,
+        max_tokens: 1000,
+        top_p: 0.9,
+      },
+    };
+    const testGenerator = new TestGenerator(
+      contentGeneratorConfig,
+      mockConfig,
+    );
     const consoleSpy = vi
       .spyOn(console, 'error')
       .mockImplementation(() => {});
@@ -1908,9 +1969,18 @@
       getCliVersion: vi.fn().mockReturnValue('1.0.0'),
     } as unknown as Config;

+      const contentGeneratorConfig = {
+        model: 'gpt-4',
+        apiKey: 'test-key',
+        authType: AuthType.USE_OPENAI,
+        enableOpenAILogging: true,
+        samplingParams: {
+          temperature: 0.8,
+          max_tokens: 500,
+        },
+      };
       const loggingGenerator = new OpenAIContentGenerator(
-        'test-key',
-        'gpt-4',
+        contentGeneratorConfig,
         loggingConfig,
       );
@@ -2093,9 +2163,14 @@
       getCliVersion: vi.fn().mockReturnValue('1.0.0'),
     } as unknown as Config;

+      const contentGeneratorConfig = {
+        model: 'gpt-4',
+        apiKey: 'test-key',
+        authType: AuthType.USE_OPENAI,
+        enableOpenAILogging: true,
+      };
       const loggingGenerator = new OpenAIContentGenerator(
-        'test-key',
-        'gpt-4',
+        contentGeneratorConfig,
         loggingConfig,
       );
@@ -2350,9 +2425,18 @@
       getCliVersion: vi.fn().mockReturnValue('1.0.0'),
     } as unknown as Config;

+      const contentGeneratorConfig = {
+        model: 'gpt-4',
+        apiKey: 'test-key',
+        authType: AuthType.USE_OPENAI,
+        samplingParams: {
+          temperature: undefined,
+          max_tokens: undefined,
+          top_p: undefined,
+        },
+      };
       const testGenerator = new OpenAIContentGenerator(
-        'test-key',
-        'gpt-4',
+        contentGeneratorConfig,
         configWithUndefined,
       );
@@ -2408,9 +2492,22 @@
       getCliVersion: vi.fn().mockReturnValue('1.0.0'),
     } as unknown as Config;

+      const contentGeneratorConfig = {
+        model: 'gpt-4',
+        apiKey: 'test-key',
+        authType: AuthType.USE_OPENAI,
+        samplingParams: {
+          temperature: 0.8,
+          max_tokens: 1500,
+          top_p: 0.95,
+          top_k: 40,
+          repetition_penalty: 1.1,
+          presence_penalty: 0.5,
+          frequency_penalty: 0.3,
+        },
+      };
       const testGenerator = new OpenAIContentGenerator(
-        'test-key',
-        'gpt-4',
+        contentGeneratorConfig,
         fullSamplingConfig,
       );
@@ -2489,9 +2586,14 @@
       getCliVersion: vi.fn().mockReturnValue('1.0.0'),
     } as unknown as Config;

+      const contentGeneratorConfig = {
+        model: 'qwen-turbo',
+        apiKey: 'test-key',
+        authType: AuthType.QWEN_OAUTH,
+        enableOpenAILogging: false,
+      };
       const qwenGenerator = new OpenAIContentGenerator(
-        'test-key',
-        'qwen-turbo',
+        contentGeneratorConfig,
         qwenConfig,
       );
@@ -2528,12 +2630,6 @@
     });

     it('should include metadata when baseURL is dashscope openai compatible mode', async () => {
-      // Mock environment to set dashscope base URL BEFORE creating the generator
-      vi.stubEnv(
-        'OPENAI_BASE_URL',
-        'https://dashscope.aliyuncs.com/compatible-mode/v1',
-      );
-
       const dashscopeConfig = {
         getContentGeneratorConfig: vi.fn().mockReturnValue({
           authType: 'openai', // Not QWEN_OAUTH
@@ -2543,9 +2639,15 @@
       getCliVersion: vi.fn().mockReturnValue('1.0.0'),
     } as unknown as Config;

+      const contentGeneratorConfig = {
+        model: 'qwen-turbo',
+        apiKey: 'test-key',
+        baseUrl: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
+        authType: AuthType.USE_OPENAI,
+        enableOpenAILogging: false,
+      };
       const dashscopeGenerator = new OpenAIContentGenerator(
-        'test-key',
-        'qwen-turbo',
+        contentGeneratorConfig,
         dashscopeConfig,
       );
@@ -2604,9 +2706,18 @@
       getCliVersion: vi.fn().mockReturnValue('1.0.0'),
     } as unknown as Config;

+      const contentGeneratorConfig = {
+        model: 'gpt-4',
+
+        apiKey: 'test-key',
+
+        authType: AuthType.USE_OPENAI,
+
+        enableOpenAILogging: false,
+      };
+
       const regularGenerator = new OpenAIContentGenerator(
-        'test-key',
-        'gpt-4',
+        contentGeneratorConfig,
         regularConfig,
       );
@@ -2650,9 +2761,18 @@
       getCliVersion: vi.fn().mockReturnValue('1.0.0'),
     } as unknown as Config;

+      const contentGeneratorConfig = {
+        model: 'gpt-4',
+
+        apiKey: 'test-key',
+
+        authType: AuthType.USE_OPENAI,
+
+        enableOpenAILogging: false,
+      };
+
       const otherGenerator = new OpenAIContentGenerator(
-        'test-key',
-        'gpt-4',
+        contentGeneratorConfig,
         otherAuthConfig,
       );
@@ -2699,9 +2819,18 @@
       getCliVersion: vi.fn().mockReturnValue('1.0.0'),
     } as unknown as Config;

+      const contentGeneratorConfig = {
+        model: 'gpt-4',
+
+        apiKey: 'test-key',
+
+        authType: AuthType.USE_OPENAI,
+
+        enableOpenAILogging: false,
+      };
+
       const otherBaseUrlGenerator = new OpenAIContentGenerator(
-        'test-key',
-        'gpt-4',
+        contentGeneratorConfig,
         otherBaseUrlConfig,
       );
@@ -2748,9 +2877,18 @@
       getCliVersion: vi.fn().mockReturnValue('1.0.0'),
     } as unknown as Config;

+      const contentGeneratorConfig = {
+        model: 'qwen-turbo',
+
+        apiKey: 'test-key',
+
+        authType: AuthType.QWEN_OAUTH,
+
+        enableOpenAILogging: false,
+      };
+
       const qwenGenerator = new OpenAIContentGenerator(
-        'test-key',
-        'qwen-turbo',
+        contentGeneratorConfig,
         qwenConfig,
       );
@@ -2804,8 +2942,6 @@
             sessionId: 'streaming-session-id',
             promptId: 'streaming-prompt-id',
           },
-          stream: true,
-          stream_options: { include_usage: true },
         }),
       );
@@ -2827,9 +2963,18 @@
       getCliVersion: vi.fn().mockReturnValue('1.0.0'),
     } as unknown as Config;

+      const contentGeneratorConfig = {
+        model: 'gpt-4',
+
+        apiKey: 'test-key',
+
+        authType: AuthType.USE_OPENAI,
+
+        enableOpenAILogging: false,
+      };
+
       const regularGenerator = new OpenAIContentGenerator(
-        'test-key',
-        'gpt-4',
+        contentGeneratorConfig,
         regularConfig,
       );
@@ -2901,9 +3046,18 @@
       getCliVersion: vi.fn().mockReturnValue('1.0.0'),
     } as unknown as Config;

+      const contentGeneratorConfig = {
+        model: 'qwen-turbo',
+
+        apiKey: 'test-key',
+
+        authType: AuthType.QWEN_OAUTH,
+
+        enableOpenAILogging: false,
+      };
+
       const qwenGenerator = new OpenAIContentGenerator(
-        'test-key',
-        'qwen-turbo',
+        contentGeneratorConfig,
         qwenConfig,
       );
@@ -2955,9 +3109,18 @@
       getCliVersion: vi.fn().mockReturnValue('1.0.0'),
     } as unknown as Config;

+      const contentGeneratorConfig = {
+        model: 'gpt-4',
+
+        apiKey: 'test-key',
+
+        authType: AuthType.USE_OPENAI,
+
+        enableOpenAILogging: false,
+      };
+
       const noBaseUrlGenerator = new OpenAIContentGenerator(
-        'test-key',
-        'gpt-4',
+        contentGeneratorConfig,
         noBaseUrlConfig,
       );
@@ -3004,9 +3167,18 @@
       getCliVersion: vi.fn().mockReturnValue('1.0.0'),
     } as unknown as Config;

+      const contentGeneratorConfig = {
+        model: 'gpt-4',
+
+        apiKey: 'test-key',
+
+        authType: AuthType.USE_OPENAI,
+
+        enableOpenAILogging: false,
+      };
+
       const undefinedAuthGenerator = new OpenAIContentGenerator(
-        'test-key',
-        'gpt-4',
+        contentGeneratorConfig,
         undefinedAuthConfig,
       );
@@ -3050,9 +3222,18 @@
       getCliVersion: vi.fn().mockReturnValue('1.0.0'),
     } as unknown as Config;

+      const contentGeneratorConfig = {
+        model: 'gpt-4',
+
+        apiKey: 'test-key',
+
+        authType: AuthType.USE_OPENAI,
+
+        enableOpenAILogging: false,
+      };
+
       const undefinedConfigGenerator = new OpenAIContentGenerator(
-        'test-key',
-        'gpt-4',
+        contentGeneratorConfig,
         undefinedConfig,
       );
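Several of these tests pin down when the request carries the OpenAI-incompatible `metadata` field. Per the `buildMetadata` implementation later in this diff, a DashScope-bound request gains roughly this shape (the IDs are the fixture values used above):

    const createParams = {
      model: 'qwen-turbo',
      messages,
      metadata: {
        sessionId: 'dashscope-session-id', // from Config.getSessionId()
        promptId: 'dashscope-prompt-id',   // the userPromptId argument
      },
    };

Non-DashScope providers omit the field entirely rather than sending an empty object, since `buildMetadata` returns `undefined` and the spread of `{}` adds nothing.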
@@ -3089,4 +3270,232 @@
       );
     });
   });
+
+  describe('cache control for DashScope', () => {
+    it('should add cache control to system message for DashScope providers', async () => {
+      // Mock environment to set dashscope base URL
+      vi.stubEnv(
+        'OPENAI_BASE_URL',
+        'https://dashscope.aliyuncs.com/compatible-mode/v1',
+      );
+
+      const dashscopeConfig = {
+        getContentGeneratorConfig: vi.fn().mockReturnValue({
+          authType: 'openai',
+          enableOpenAILogging: false,
+        }),
+        getSessionId: vi.fn().mockReturnValue('dashscope-session-id'),
+        getCliVersion: vi.fn().mockReturnValue('1.0.0'),
+      } as unknown as Config;
+
+      const contentGeneratorConfig = {
+        model: 'qwen-turbo',
+
+        apiKey: 'test-key',
+
+        authType: AuthType.QWEN_OAUTH,
+
+        enableOpenAILogging: false,
+      };
+
+      const dashscopeGenerator = new OpenAIContentGenerator(
+        contentGeneratorConfig,
+        dashscopeConfig,
+      );
+
+      // Mock the client's baseURL property to return the expected value
+      Object.defineProperty(dashscopeGenerator['client'], 'baseURL', {
+        value: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
+        writable: true,
+      });
+
+      const mockResponse = {
+        id: 'chatcmpl-123',
+        choices: [
+          {
+            index: 0,
+            message: { role: 'assistant', content: 'Response' },
+            finish_reason: 'stop',
+          },
+        ],
+        created: 1677652288,
+        model: 'qwen-turbo',
+      };
+
+      mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
+
+      const request: GenerateContentParameters = {
+        contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
+        config: {
+          systemInstruction: 'You are a helpful assistant.',
+        },
+        model: 'qwen-turbo',
+      };
+
+      await dashscopeGenerator.generateContent(request, 'dashscope-prompt-id');
+
+      // Should include cache control in system message
+      expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
+        expect.objectContaining({
+          messages: expect.arrayContaining([
+            expect.objectContaining({
+              role: 'system',
+              content: expect.arrayContaining([
+                expect.objectContaining({
+                  type: 'text',
+                  text: 'You are a helpful assistant.',
+                  cache_control: { type: 'ephemeral' },
+                }),
+              ]),
+            }),
+          ]),
+        }),
+      );
+    });
+
+    it('should add cache control to last message for DashScope providers', async () => {
+      // Mock environment to set dashscope base URL
+      vi.stubEnv(
+        'OPENAI_BASE_URL',
+        'https://dashscope.aliyuncs.com/compatible-mode/v1',
+      );
+
+      const dashscopeConfig = {
+        getContentGeneratorConfig: vi.fn().mockReturnValue({
+          authType: 'openai',
+          enableOpenAILogging: false,
+        }),
+        getSessionId: vi.fn().mockReturnValue('dashscope-session-id'),
+        getCliVersion: vi.fn().mockReturnValue('1.0.0'),
+      } as unknown as Config;
+
+      const contentGeneratorConfig = {
+        model: 'qwen-turbo',
+
+        apiKey: 'test-key',
+
+        authType: AuthType.QWEN_OAUTH,
+
+        enableOpenAILogging: false,
+      };
+
+      const dashscopeGenerator = new OpenAIContentGenerator(
+        contentGeneratorConfig,
+        dashscopeConfig,
+      );
+
+      // Mock the client's baseURL property to return the expected value
+      Object.defineProperty(dashscopeGenerator['client'], 'baseURL', {
+        value: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
+        writable: true,
+      });
+
+      const mockResponse = {
+        id: 'chatcmpl-123',
+        choices: [
+          {
+            index: 0,
+            message: { role: 'assistant', content: 'Response' },
+            finish_reason: 'stop',
+          },
+        ],
+        created: 1677652288,
+        model: 'qwen-turbo',
+      };
+
+      mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
+
+      const request: GenerateContentParameters = {
+        contents: [{ role: 'user', parts: [{ text: 'Hello, how are you?' }] }],
+        model: 'qwen-turbo',
+      };
+
+      await dashscopeGenerator.generateContent(request, 'dashscope-prompt-id');
+
+      // Should include cache control in last message
+      expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
+        expect.objectContaining({
+          messages: expect.arrayContaining([
+            expect.objectContaining({
+              role: 'user',
+              content: expect.arrayContaining([
+                expect.objectContaining({
+                  type: 'text',
+                  text: 'Hello, how are you?',
+                  cache_control: { type: 'ephemeral' },
+                }),
+              ]),
+            }),
+          ]),
+        }),
+      );
+    });
+
+    it('should NOT add cache control for non-DashScope providers', async () => {
+      const regularConfig = {
+        getContentGeneratorConfig: vi.fn().mockReturnValue({
+          authType: 'openai',
+          enableOpenAILogging: false,
+        }),
+        getSessionId: vi.fn().mockReturnValue('regular-session-id'),
+        getCliVersion: vi.fn().mockReturnValue('1.0.0'),
+      } as unknown as Config;
+
+      const contentGeneratorConfig = {
+        model: 'gpt-4',
+
+        apiKey: 'test-key',
+
+        authType: AuthType.USE_OPENAI,
+
+        enableOpenAILogging: false,
+      };
+
+      const regularGenerator = new OpenAIContentGenerator(
+        contentGeneratorConfig,
+        regularConfig,
+      );
+
+      const mockResponse = {
+        id: 'chatcmpl-123',
+        choices: [
+          {
+            index: 0,
+            message: { role: 'assistant', content: 'Response' },
+            finish_reason: 'stop',
+          },
+        ],
+        created: 1677652288,
+        model: 'gpt-4',
+      };
+
+      mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
+
+      const request: GenerateContentParameters = {
+        contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
+        config: {
+          systemInstruction: 'You are a helpful assistant.',
+        },
+        model: 'gpt-4',
+      };
+
+      await regularGenerator.generateContent(request, 'regular-prompt-id');
+
+      // Should NOT include cache control (messages should be strings, not arrays)
+      expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
+        expect.objectContaining({
+          messages: expect.arrayContaining([
+            expect.objectContaining({
+              role: 'system',
+              content: 'You are a helpful assistant.',
+            }),
+            expect.objectContaining({
+              role: 'user',
+              content: 'Hello',
+            }),
+          ]),
+        }),
+      );
+    });
+  });
 });
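The transformation these tests assert: for DashScope providers, plain string content is rewritten into the content-parts array form so a `cache_control` marker can ride along. A before/after sketch of a system message, using exactly the field names from this diff:

    // Before (all providers):
    { role: 'system', content: 'You are a helpful assistant.' }

    // After (DashScope only):
    {
      role: 'system',
      content: [
        {
          type: 'text',
          text: 'You are a helpful assistant.',
          cache_control: { type: 'ephemeral' },
        },
      ],
    }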
diff --git a/packages/core/src/core/openaiContentGenerator.ts b/packages/core/src/core/openaiContentGenerator.ts
index e24bd0c3..3f223f3e 100644
--- a/packages/core/src/core/openaiContentGenerator.ts
+++ b/packages/core/src/core/openaiContentGenerator.ts
@@ -20,7 +20,11 @@ import {
   FunctionCall,
   FunctionResponse,
 } from '@google/genai';
-import { AuthType, ContentGenerator } from './contentGenerator.js';
+import {
+  AuthType,
+  ContentGenerator,
+  ContentGeneratorConfig,
+} from './contentGenerator.js';
 import OpenAI from 'openai';
 import { logApiError, logApiResponse } from '../telemetry/loggers.js';
 import { ApiErrorEvent, ApiResponseEvent } from '../telemetry/types.js';
@@ -28,6 +32,17 @@
 import { Config } from '../config/config.js';
 import { openaiLogger } from '../utils/openaiLogger.js';
 import { safeJsonParse } from '../utils/safeJsonParse.js';

+// Extended types to support cache_control
+interface ChatCompletionContentPartTextWithCache
+  extends OpenAI.Chat.ChatCompletionContentPartText {
+  cache_control?: { type: 'ephemeral' };
+}
+
+type ChatCompletionContentPartWithCache =
+  | ChatCompletionContentPartTextWithCache
+  | OpenAI.Chat.ChatCompletionContentPartImage
+  | OpenAI.Chat.ChatCompletionContentPartRefusal;
+
 // OpenAI API type definitions for logging
 interface OpenAIToolCall {
   id: string;
@@ -38,9 +53,15 @@
   };
 }

+interface OpenAIContentItem {
+  type: 'text';
+  text: string;
+  cache_control?: { type: 'ephemeral' };
+}
+
 interface OpenAIMessage {
   role: 'system' | 'user' | 'assistant' | 'tool';
-  content: string | null;
+  content: string | null | OpenAIContentItem[];
   tool_calls?: OpenAIToolCall[];
   tool_call_id?: string;
 }
@@ -60,15 +81,6 @@
   finish_reason: string;
 }

-interface OpenAIRequestFormat {
-  model: string;
-  messages: OpenAIMessage[];
-  temperature?: number;
-  max_tokens?: number;
-  top_p?: number;
-  tools?: unknown[];
-}
-
 interface OpenAIResponseFormat {
   id: string;
   object: string;
@@ -81,6 +93,7 @@
 export class OpenAIContentGenerator implements ContentGenerator {
   protected client: OpenAI;
   private model: string;
+  private contentGeneratorConfig: ContentGeneratorConfig;
   private config: Config;
   private streamingToolCalls: Map<
     number,
@@ -91,50 +104,40 @@
     }
   > = new Map();

-  constructor(apiKey: string, model: string, config: Config) {
-    this.model = model;
-    this.config = config;
-    const baseURL = process.env.OPENAI_BASE_URL || '';
+  constructor(
+    contentGeneratorConfig: ContentGeneratorConfig,
+    gcConfig: Config,
+  ) {
+    this.model = contentGeneratorConfig.model;
+    this.contentGeneratorConfig = contentGeneratorConfig;
+    this.config = gcConfig;

-    // Configure timeout settings - using progressive timeouts
-    const timeoutConfig = {
-      // Base timeout for most requests (2 minutes)
-      timeout: 120000,
-      // Maximum retries for failed requests
-      maxRetries: 3,
-      // HTTP client options
-      httpAgent: undefined, // Let the client use default agent
-    };
-
-    // Allow config to override timeout settings
-    const contentGeneratorConfig = this.config.getContentGeneratorConfig();
-    if (contentGeneratorConfig?.timeout) {
-      timeoutConfig.timeout = contentGeneratorConfig.timeout;
-    }
-    if (contentGeneratorConfig?.maxRetries !== undefined) {
-      timeoutConfig.maxRetries = contentGeneratorConfig.maxRetries;
-    }
-
-    const version = config.getCliVersion() || 'unknown';
+    const version = gcConfig.getCliVersion() || 'unknown';
     const userAgent = `QwenCode/${version} (${process.platform}; ${process.arch})`;

     // Check if using OpenRouter and add required headers
-    const isOpenRouter = baseURL.includes('openrouter.ai');
+    const isOpenRouterProvider = this.isOpenRouterProvider();
+    const isDashScopeProvider = this.isDashScopeProvider();
+
     const defaultHeaders = {
       'User-Agent': userAgent,
-      ...(isOpenRouter
+      ...(isOpenRouterProvider
         ? {
             'HTTP-Referer': 'https://github.com/QwenLM/qwen-code.git',
             'X-Title': 'Qwen Code',
           }
-        : {}),
+        : isDashScopeProvider
+          ? {
+              'X-DashScope-CacheControl': 'enable',
+            }
+          : {}),
     };

     this.client = new OpenAI({
-      apiKey,
-      baseURL,
-      timeout: timeoutConfig.timeout,
-      maxRetries: timeoutConfig.maxRetries,
+      apiKey: contentGeneratorConfig.apiKey,
+      baseURL: contentGeneratorConfig.baseUrl,
+      timeout: contentGeneratorConfig.timeout ?? 120000,
+      maxRetries: contentGeneratorConfig.maxRetries ?? 3,
       defaultHeaders,
     });
   }
@@ -185,22 +188,25 @@
     );
   }

+  private isOpenRouterProvider(): boolean {
+    const baseURL = this.contentGeneratorConfig.baseUrl || '';
+    return baseURL.includes('openrouter.ai');
+  }
+
   /**
-   * Determine if metadata should be included in the request.
-   * Only include the `metadata` field if the provider is QWEN_OAUTH
-   * or the baseUrl is 'https://dashscope.aliyuncs.com/compatible-mode/v1'.
-   * This is because some models/providers do not support metadata or need extra configuration.
+   * Determine if this is a DashScope provider.
+   * DashScope providers include QWEN_OAUTH auth type or specific DashScope base URLs.
    *
-   * @returns true if metadata should be included, false otherwise
+   * @returns true if this is a DashScope provider, false otherwise
    */
-  private shouldIncludeMetadata(): boolean {
-    const authType = this.config.getContentGeneratorConfig?.()?.authType;
-    // baseUrl may be undefined; default to empty string if so
-    const baseUrl = this.client?.baseURL || '';
+  private isDashScopeProvider(): boolean {
+    const authType = this.contentGeneratorConfig.authType;
+    const baseUrl = this.contentGeneratorConfig.baseUrl;
     return (
       authType === AuthType.QWEN_OAUTH ||
-      baseUrl === 'https://dashscope.aliyuncs.com/compatible-mode/v1'
+      baseUrl === 'https://dashscope.aliyuncs.com/compatible-mode/v1' ||
+      baseUrl === 'https://dashscope-intl.aliyuncs.com/compatible-mode/v1'
     );
   }
@@ -213,7 +219,7 @@
   private buildMetadata(
     userPromptId: string,
   ): { metadata: { sessionId?: string; promptId: string } } | undefined {
-    if (!this.shouldIncludeMetadata()) {
+    if (!this.isDashScopeProvider()) {
       return undefined;
     }
@@ -225,35 +231,44 @@
     };
   }

+  private async buildCreateParams(
+    request: GenerateContentParameters,
+    userPromptId: string,
+  ): Promise<Parameters<typeof this.client.chat.completions.create>[0]> {
+    const messages = this.convertToOpenAIFormat(request);
+
+    // Build sampling parameters with clear priority:
+    // 1. Request-level parameters (highest priority)
+    // 2. Config-level sampling parameters (medium priority)
+    // 3. Default values (lowest priority)
+    const samplingParams = this.buildSamplingParameters(request);
+
+    const createParams: Parameters<
+      typeof this.client.chat.completions.create
+    >[0] = {
+      model: this.model,
+      messages,
+      ...samplingParams,
+      ...(this.buildMetadata(userPromptId) || {}),
+    };
+
+    if (request.config?.tools) {
+      createParams.tools = await this.convertGeminiToolsToOpenAI(
+        request.config.tools,
+      );
+    }
+
+    return createParams;
+  }
+
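The new `buildCreateParams` centralizes request assembly that `generateContent` and `generateContentStream` previously duplicated. The streaming path then just layers its flags on top, as the stream method later in this hunk shows:

    const createParams = await this.buildCreateParams(request, userPromptId);
    // Streaming is the only divergence between the two call paths:
    createParams.stream = true;
    createParams.stream_options = { include_usage: true };

A single request-construction path also means the exact payload that was sent can be handed to `openaiLogger.logInteraction`, replacing the removed `convertGeminiRequestToOpenAI` re-derivation further down in this diff.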
   async generateContent(
     request: GenerateContentParameters,
     userPromptId: string,
   ): Promise<GenerateContentResponse> {
     const startTime = Date.now();
-    const messages = this.convertToOpenAIFormat(request);
+    const createParams = await this.buildCreateParams(request, userPromptId);

     try {
-      // Build sampling parameters with clear priority:
-      // 1. Request-level parameters (highest priority)
-      // 2. Config-level sampling parameters (medium priority)
-      // 3. Default values (lowest priority)
-      const samplingParams = this.buildSamplingParameters(request);
-
-      const createParams: Parameters<
-        typeof this.client.chat.completions.create
-      >[0] = {
-        model: this.model,
-        messages,
-        ...samplingParams,
-        ...(this.buildMetadata(userPromptId) || {}),
-      };
-
-      if (request.config?.tools) {
-        createParams.tools = await this.convertGeminiToolsToOpenAI(
-          request.config.tools,
-        );
-      }
-
-      // console.log('createParams', createParams);
       const completion = (await this.client.chat.completions.create(
         createParams,
       )) as OpenAI.Chat.ChatCompletion;
@@ -267,15 +282,15 @@
         this.model,
         durationMs,
         userPromptId,
-        this.config.getContentGeneratorConfig()?.authType,
+        this.contentGeneratorConfig.authType,
         response.usageMetadata,
       );
       logApiResponse(this.config, responseEvent);

       // Log interaction if enabled
-      if (this.config.getContentGeneratorConfig()?.enableOpenAILogging) {
-        const openaiRequest = await this.convertGeminiRequestToOpenAI(request);
+      if (this.contentGeneratorConfig.enableOpenAILogging) {
+        const openaiRequest = createParams;
         const openaiResponse = this.convertGeminiResponseToOpenAI(response);
         await openaiLogger.logInteraction(openaiRequest, openaiResponse);
       }
@@ -300,7 +315,7 @@
         errorMessage,
         durationMs,
         userPromptId,
-        this.config.getContentGeneratorConfig()?.authType,
+        this.contentGeneratorConfig.authType,
         // eslint-disable-next-line @typescript-eslint/no-explicit-any
         (error as any).type,
         // eslint-disable-next-line @typescript-eslint/no-explicit-any
@@ -309,10 +324,9 @@
       logApiError(this.config, errorEvent);

       // Log error interaction if enabled
-      if (this.config.getContentGeneratorConfig()?.enableOpenAILogging) {
-        const openaiRequest = await this.convertGeminiRequestToOpenAI(request);
+      if (this.contentGeneratorConfig.enableOpenAILogging) {
         await openaiLogger.logInteraction(
-          openaiRequest,
+          createParams,
           undefined,
           error as Error,
         );
@@ -343,29 +357,12 @@
     userPromptId: string,
   ): Promise<AsyncGenerator<GenerateContentResponse>> {
     const startTime = Date.now();
-    const messages = this.convertToOpenAIFormat(request);
+    const createParams = await this.buildCreateParams(request, userPromptId);
+
+    createParams.stream = true;
+    createParams.stream_options = { include_usage: true };

     try {
-      // Build sampling parameters with clear priority
-      const samplingParams = this.buildSamplingParameters(request);
-
-      const createParams: Parameters<
-        typeof this.client.chat.completions.create
-      >[0] = {
-        model: this.model,
-        messages,
-        ...samplingParams,
-        stream: true,
-        stream_options: { include_usage: true },
-        ...(this.buildMetadata(userPromptId) || {}),
-      };
-
-      if (request.config?.tools) {
-        createParams.tools = await this.convertGeminiToolsToOpenAI(
-          request.config.tools,
-        );
-      }
-
       const stream = (await this.client.chat.completions.create(
         createParams,
       )) as AsyncIterable<OpenAI.Chat.ChatCompletionChunk>;
@@ -397,16 +394,15 @@
         this.model,
         durationMs,
         userPromptId,
-        this.config.getContentGeneratorConfig()?.authType,
+        this.contentGeneratorConfig.authType,
         finalUsageMetadata,
       );
       logApiResponse(this.config, responseEvent);

       // Log interaction if enabled (same as generateContent method)
-      if (this.config.getContentGeneratorConfig()?.enableOpenAILogging) {
-        const openaiRequest =
-          await this.convertGeminiRequestToOpenAI(request);
+      if (this.contentGeneratorConfig.enableOpenAILogging) {
+        const openaiRequest = createParams;
         // For streaming, we combine all responses into a single response for logging
         const combinedResponse =
           this.combineStreamResponsesForLogging(responses);
@@ -433,7 +429,7 @@
         errorMessage,
         durationMs,
         userPromptId,
-        this.config.getContentGeneratorConfig()?.authType,
+        this.contentGeneratorConfig.authType,
         // eslint-disable-next-line @typescript-eslint/no-explicit-any
         (error as any).type,
         // eslint-disable-next-line @typescript-eslint/no-explicit-any
@@ -442,11 +438,9 @@
       logApiError(this.config, errorEvent);

       // Log error interaction if enabled
-      if (this.config.getContentGeneratorConfig()?.enableOpenAILogging) {
-        const openaiRequest =
-          await this.convertGeminiRequestToOpenAI(request);
+      if (this.contentGeneratorConfig.enableOpenAILogging) {
         await openaiLogger.logInteraction(
-          openaiRequest,
+          createParams,
           undefined,
           error as Error,
         );
@@ -487,7 +481,7 @@
         errorMessage,
         durationMs,
         userPromptId,
-        this.config.getContentGeneratorConfig()?.authType,
+        this.contentGeneratorConfig.authType,
         // eslint-disable-next-line @typescript-eslint/no-explicit-any
         (error as any).type,
         // eslint-disable-next-line @typescript-eslint/no-explicit-any
@@ -944,7 +938,114 @@
     // Clean up orphaned tool calls and merge consecutive assistant messages
     const cleanedMessages = this.cleanOrphanedToolCalls(messages);
-    return this.mergeConsecutiveAssistantMessages(cleanedMessages);
+    const mergedMessages =
+      this.mergeConsecutiveAssistantMessages(cleanedMessages);
+
+    // Add cache control to system and last messages for DashScope providers
+    return this.addCacheControlFlag(mergedMessages, 'both');
+  }
+
+  /**
+   * Add cache control flag to specified message(s) for DashScope providers
+   */
+  private addCacheControlFlag(
+    messages: OpenAI.Chat.ChatCompletionMessageParam[],
+    target: 'system' | 'last' | 'both' = 'both',
+  ): OpenAI.Chat.ChatCompletionMessageParam[] {
+    if (!this.isDashScopeProvider() || messages.length === 0) {
+      return messages;
+    }
+
+    let updatedMessages = [...messages];
+
+    // Add cache control to system message if requested
+    if (target === 'system' || target === 'both') {
+      updatedMessages = this.addCacheControlToMessage(
+        updatedMessages,
+        'system',
+      );
+    }
+
+    // Add cache control to last message if requested
+    if (target === 'last' || target === 'both') {
+      updatedMessages = this.addCacheControlToMessage(updatedMessages, 'last');
+    }
+
+    return updatedMessages;
+  }
+
+  /**
+   * Helper method to add cache control to a specific message
+   */
+  private addCacheControlToMessage(
+    messages: OpenAI.Chat.ChatCompletionMessageParam[],
+    target: 'system' | 'last',
+  ): OpenAI.Chat.ChatCompletionMessageParam[] {
+    const updatedMessages = [...messages];
+    let messageIndex: number;
+
+    if (target === 'system') {
+      // Find the first system message
+      messageIndex = messages.findIndex((msg) => msg.role === 'system');
+      if (messageIndex === -1) {
+        return updatedMessages;
+      }
+    } else {
+      // Get the last message
+      messageIndex = messages.length - 1;
+    }
+
+    const message = updatedMessages[messageIndex];
+
+    // Only process messages that have content
+    if ('content' in message && message.content !== null) {
+      if (typeof message.content === 'string') {
+        // Convert string content to array format with cache control
+        const messageWithArrayContent = {
+          ...message,
+          content: [
+            {
+              type: 'text',
+              text: message.content,
+              cache_control: { type: 'ephemeral' },
+            } as ChatCompletionContentPartTextWithCache,
+          ],
+        };
+        updatedMessages[messageIndex] =
+          messageWithArrayContent as OpenAI.Chat.ChatCompletionMessageParam;
+      } else if (Array.isArray(message.content)) {
+        // If content is already an array, add cache_control to the last item
+        const contentArray = [
+          ...message.content,
+        ] as ChatCompletionContentPartWithCache[];
+        if (contentArray.length > 0) {
+          const lastItem = contentArray[contentArray.length - 1];
+          if (lastItem.type === 'text') {
+            // Add cache_control to the last text item
+            contentArray[contentArray.length - 1] = {
+              ...lastItem,
+              cache_control: { type: 'ephemeral' },
+            } as ChatCompletionContentPartTextWithCache;
+          } else {
+            // If the last item is not text, add a new text item with cache_control
+            contentArray.push({
+              type: 'text',
+              text: '',
+              cache_control: { type: 'ephemeral' },
+            } as ChatCompletionContentPartTextWithCache);
+          }
+
+          const messageWithCache = {
+            ...message,
+            content: contentArray,
+          };
+          updatedMessages[messageIndex] =
+            messageWithCache as OpenAI.Chat.ChatCompletionMessageParam;
+        }
+      }
+    }
+
+    return updatedMessages;
   }

   /**
@@ -1368,8 +1469,7 @@
   private buildSamplingParameters(
     request: GenerateContentParameters,
   ): Record<string, unknown> {
-    const configSamplingParams =
-      this.config.getContentGeneratorConfig()?.samplingParams;
+    const configSamplingParams = this.contentGeneratorConfig.samplingParams;

     const params = {
       // Temperature: config > request > default
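For temperature (and the parameters that follow in the full method, which this hunk truncates), the inline comment states the merge order as config value first, then the per-request value, then a built-in default. A minimal sketch of that pattern; the default literal is a placeholder, since it sits outside the hunk:

    const params = {
      // Config-level sampling params win, then request-level, then the default.
      temperature:
        configSamplingParams?.temperature ??
        request.config?.temperature ??
        DEFAULT_TEMPERATURE, // placeholder name, not from this diff
      // ...max_tokens, top_p, etc. follow the same pattern...
    };

Note that this ordering (config before request) is what the comment here describes, while the priority list in `buildCreateParams` earlier words it the other way around.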
@@ -1431,313 +1531,6 @@
     return mapping[openaiReason] || FinishReason.FINISH_REASON_UNSPECIFIED;
   }

-  /**
-   * Convert Gemini request format to OpenAI chat completion format for logging
-   */
-  private async convertGeminiRequestToOpenAI(
-    request: GenerateContentParameters,
-  ): Promise<OpenAIRequestFormat> {
-    const messages: OpenAIMessage[] = [];
-
-    // Handle system instruction
-    if (request.config?.systemInstruction) {
-      const systemInstruction = request.config.systemInstruction;
-      let systemText = '';
-
-      if (Array.isArray(systemInstruction)) {
-        systemText = systemInstruction
-          .map((content) => {
-            if (typeof content === 'string') return content;
-            if ('parts' in content) {
-              const contentObj = content as Content;
-              return (
-                contentObj.parts
-                  ?.map((p: Part) =>
-                    typeof p === 'string' ? p : 'text' in p ? p.text : '',
-                  )
-                  .join('\n') || ''
-              );
-            }
-            return '';
-          })
-          .join('\n');
-      } else if (typeof systemInstruction === 'string') {
-        systemText = systemInstruction;
-      } else if (
-        typeof systemInstruction === 'object' &&
-        'parts' in systemInstruction
-      ) {
-        const systemContent = systemInstruction as Content;
-        systemText =
-          systemContent.parts
-            ?.map((p: Part) =>
-              typeof p === 'string' ? p : 'text' in p ? p.text : '',
-            )
-            .join('\n') || '';
-      }
-
-      if (systemText) {
-        messages.push({
-          role: 'system',
-          content: systemText,
-        });
-      }
-    }
-
-    // Handle contents
-    if (Array.isArray(request.contents)) {
-      for (const content of request.contents) {
-        if (typeof content === 'string') {
-          messages.push({ role: 'user', content });
-        } else if ('role' in content && 'parts' in content) {
-          const functionCalls: FunctionCall[] = [];
-          const functionResponses: FunctionResponse[] = [];
-          const textParts: string[] = [];
-
-          for (const part of content.parts || []) {
-            if (typeof part === 'string') {
-              textParts.push(part);
-            } else if ('text' in part && part.text) {
-              textParts.push(part.text);
-            } else if ('functionCall' in part && part.functionCall) {
-              functionCalls.push(part.functionCall);
-            } else if ('functionResponse' in part && part.functionResponse) {
-              functionResponses.push(part.functionResponse);
-            }
-          }
-
-          // Handle function responses (tool results)
-          if (functionResponses.length > 0) {
-            for (const funcResponse of functionResponses) {
-              messages.push({
-                role: 'tool',
-                tool_call_id: funcResponse.id || '',
-                content:
-                  typeof funcResponse.response === 'string'
-                    ? funcResponse.response
-                    : JSON.stringify(funcResponse.response),
-              });
-            }
-          }
-          // Handle model messages with function calls
-          else if (content.role === 'model' && functionCalls.length > 0) {
-            const toolCalls = functionCalls.map((fc, index) => ({
-              id: fc.id || `call_${index}`,
-              type: 'function' as const,
-              function: {
-                name: fc.name || '',
-                arguments: JSON.stringify(fc.args || {}),
-              },
-            }));
-
-            messages.push({
-              role: 'assistant',
-              content: textParts.join('\n') || null,
-              tool_calls: toolCalls,
-            });
-          }
-          // Handle regular text messages
-          else {
-            const role = content.role === 'model' ? 'assistant' : 'user';
-            const text = textParts.join('\n');
-            if (text) {
-              messages.push({ role, content: text });
-            }
-          }
-        }
-      }
-    } else if (request.contents) {
-      if (typeof request.contents === 'string') {
-        messages.push({ role: 'user', content: request.contents });
-      } else if ('role' in request.contents && 'parts' in request.contents) {
-        const content = request.contents;
-        const role = content.role === 'model' ? 'assistant' : 'user';
-        const text =
-          content.parts
-            ?.map((p: Part) =>
-              typeof p === 'string' ? p : 'text' in p ? p.text : '',
-            )
-            .join('\n') || '';
-        messages.push({ role, content: text });
-      }
-    }
-
-    // Clean up orphaned tool calls and merge consecutive assistant messages
-    const cleanedMessages = this.cleanOrphanedToolCallsForLogging(messages);
-    const mergedMessages =
-      this.mergeConsecutiveAssistantMessagesForLogging(cleanedMessages);
-
-    const openaiRequest: OpenAIRequestFormat = {
-      model: this.model,
-      messages: mergedMessages,
-    };
-
-    // Add sampling parameters using the same logic as actual API calls
-    const samplingParams = this.buildSamplingParameters(request);
-    Object.assign(openaiRequest, samplingParams);
-
-    // Convert tools if present
-    if (request.config?.tools) {
-      openaiRequest.tools = await this.convertGeminiToolsToOpenAI(
-        request.config.tools,
-      );
-    }
-
-    return openaiRequest;
-  }
-
-  /**
-   * Clean up orphaned tool calls for logging purposes
-   */
-  private cleanOrphanedToolCallsForLogging(
-    messages: OpenAIMessage[],
-  ): OpenAIMessage[] {
-    const cleaned: OpenAIMessage[] = [];
-    const toolCallIds = new Set<string>();
-    const toolResponseIds = new Set<string>();
-
-    // First pass: collect all tool call IDs and tool response IDs
-    for (const message of messages) {
-      if (message.role === 'assistant' && message.tool_calls) {
-        for (const toolCall of message.tool_calls) {
-          if (toolCall.id) {
-            toolCallIds.add(toolCall.id);
-          }
-        }
-      } else if (message.role === 'tool' && message.tool_call_id) {
-        toolResponseIds.add(message.tool_call_id);
-      }
-    }
-
-    // Second pass: filter out orphaned messages
-    for (const message of messages) {
-      if (message.role === 'assistant' && message.tool_calls) {
-        // Filter out tool calls that don't have corresponding responses
-        const validToolCalls = message.tool_calls.filter(
-          (toolCall) => toolCall.id && toolResponseIds.has(toolCall.id),
-        );
-
-        if (validToolCalls.length > 0) {
-          // Keep the message but only with valid tool calls
-          const cleanedMessage = { ...message };
-          cleanedMessage.tool_calls = validToolCalls;
-          cleaned.push(cleanedMessage);
-        } else if (
-          typeof message.content === 'string' &&
-          message.content.trim()
-        ) {
-          // Keep the message if it has text content, but remove tool calls
-          const cleanedMessage = { ...message };
-          delete cleanedMessage.tool_calls;
-          cleaned.push(cleanedMessage);
-        }
-        // If no valid tool calls and no content, skip the message entirely
-      } else if (message.role === 'tool' && message.tool_call_id) {
-        // Only keep tool responses that have corresponding tool calls
-        if (toolCallIds.has(message.tool_call_id)) {
-          cleaned.push(message);
-        }
-      } else {
-        // Keep all other messages as-is
-        cleaned.push(message);
-      }
-    }
-
-    // Final validation: ensure every assistant message with tool_calls has corresponding tool responses
-    const finalCleaned: OpenAIMessage[] = [];
-    const finalToolCallIds = new Set<string>();
-
-    // Collect all remaining tool call IDs
-    for (const message of cleaned) {
-      if (message.role === 'assistant' && message.tool_calls) {
-        for (const toolCall of message.tool_calls) {
-          if (toolCall.id) {
-            finalToolCallIds.add(toolCall.id);
-          }
-        }
-      }
-    }
-
-    // Verify all tool calls have responses
-    const finalToolResponseIds = new Set<string>();
-    for (const message of cleaned) {
-      if (message.role === 'tool' && message.tool_call_id) {
-        finalToolResponseIds.add(message.tool_call_id);
-      }
-    }
-
-    // Remove any remaining orphaned tool calls
-    for (const message of cleaned) {
-      if (message.role === 'assistant' && message.tool_calls) {
-        const finalValidToolCalls = message.tool_calls.filter(
-          (toolCall) => toolCall.id && finalToolResponseIds.has(toolCall.id),
-        );
-
-        if (finalValidToolCalls.length > 0) {
-          const cleanedMessage = { ...message };
-          cleanedMessage.tool_calls = finalValidToolCalls;
-          finalCleaned.push(cleanedMessage);
-        } else if (
-          typeof message.content === 'string' &&
-          message.content.trim()
-        ) {
-          const cleanedMessage = { ...message };
-          delete cleanedMessage.tool_calls;
-          finalCleaned.push(cleanedMessage);
-        }
-      } else {
-        finalCleaned.push(message);
-      }
-    }
-
-    return finalCleaned;
-  }
-
-  /**
-   * Merge consecutive assistant messages to combine split text and tool calls for logging
-   */
-  private mergeConsecutiveAssistantMessagesForLogging(
-    messages: OpenAIMessage[],
-  ): OpenAIMessage[] {
-    const merged: OpenAIMessage[] = [];
-
-    for (const message of messages) {
-      if (message.role === 'assistant' && merged.length > 0) {
-        const lastMessage = merged[merged.length - 1];
-
-        // If the last message is also an assistant message, merge them
-        if (lastMessage.role === 'assistant') {
-          // Combine content
-          const combinedContent = [
-            lastMessage.content || '',
-            message.content || '',
-          ]
-            .filter(Boolean)
-            .join('');
-
-          // Combine tool calls
-          const combinedToolCalls = [
-            ...(lastMessage.tool_calls || []),
-            ...(message.tool_calls || []),
-          ];
-
-          // Update the last message with combined data
-          lastMessage.content = combinedContent || null;
-          if (combinedToolCalls.length > 0) {
-            lastMessage.tool_calls = combinedToolCalls;
-          }
-
-          continue; // Skip adding the current message since it's been merged
-        }
-      }
-
-      // Add the message as-is if no merging is needed
-      merged.push(message);
-    }
-
-    return merged;
-  }
-
   /**
    * Convert Gemini response format to OpenAI chat completion format for logging
    */
diff --git a/packages/core/src/qwen/qwenContentGenerator.test.ts b/packages/core/src/qwen/qwenContentGenerator.test.ts
index d2878f93..a56aed81 100644
--- a/packages/core/src/qwen/qwenContentGenerator.test.ts
+++ b/packages/core/src/qwen/qwenContentGenerator.test.ts
@@ -21,6 +21,7 @@ import {
 } from '@google/genai';
 import { QwenContentGenerator } from './qwenContentGenerator.js';
 import { Config } from '../config/config.js';
+import { AuthType, ContentGeneratorConfig } from '../core/contentGenerator.js';

 // Mock the OpenAIContentGenerator parent class
 vi.mock('../core/openaiContentGenerator.js', () => ({
@@ -30,10 +31,13 @@
       baseURL: string;
     };

-    constructor(apiKey: string, _model: string, _config: Config) {
+    constructor(
+      contentGeneratorConfig: ContentGeneratorConfig,
+      _config: Config,
+    ) {
       this.client = {
-        apiKey,
-        baseURL: 'https://api.openai.com/v1',
+        apiKey: contentGeneratorConfig.apiKey || 'test-key',
+        baseURL: contentGeneratorConfig.baseUrl || 'https://api.openai.com/v1',
       };
     }
@@ -131,9 +135,13 @@ describe('QwenContentGenerator', () => {
     };

     // Create QwenContentGenerator instance
+    const contentGeneratorConfig = {
+      model: 'qwen-turbo',
+      authType: AuthType.QWEN_OAUTH,
+    };
     qwenContentGenerator = new QwenContentGenerator(
       mockQwenClient,
-      'qwen-turbo',
+      contentGeneratorConfig,
       mockConfig,
     );
   });
diff --git a/packages/core/src/qwen/qwenContentGenerator.ts b/packages/core/src/qwen/qwenContentGenerator.ts
index 1158d547..2a9468bd 100644
--- a/packages/core/src/qwen/qwenContentGenerator.ts
+++ b/packages/core/src/qwen/qwenContentGenerator.ts
@@ -20,6 +20,7 @@ import {
   EmbedContentParameters,
   EmbedContentResponse,
 } from '@google/genai';
+import { ContentGeneratorConfig } from '../core/contentGenerator.js';

 // Default fallback base URL if no endpoint is provided
 const DEFAULT_QWEN_BASE_URL =
@@ -36,9 +37,13 @@ export class QwenContentGenerator extends OpenAIContentGenerator {
   private currentEndpoint: string | null = null;
   private refreshPromise: Promise<string> | null = null;

-  constructor(qwenClient: IQwenOAuth2Client, model: string, config: Config) {
+  constructor(
+    qwenClient: IQwenOAuth2Client,
+    contentGeneratorConfig: ContentGeneratorConfig,
+    config: Config,
+  ) {
     // Initialize with empty API key, we'll override it dynamically
-    super('', model, config);
+    super(contentGeneratorConfig, config);
     this.qwenClient = qwenClient;

     // Set default base URL, will be updated dynamically
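Under the new signature, the OAuth path builds a `ContentGeneratorConfig` first (as the updated `createContentGenerator` earlier in this diff shows) and the subclass forwards it to `OpenAIContentGenerator`, with credentials and endpoint filled in later from token refresh. A rough usage sketch consistent with this diff; the dynamic-override behavior is per the constructor comment above:

    // As wired in createContentGenerator for AuthType.QWEN_OAUTH:
    const qwenClient = await getQwenOauthClient(gcConfig);
    // `config` is the ContentGeneratorConfig; apiKey/baseUrl may start unset and
    // are overridden dynamically once the OAuth token and endpoint are known.
    const generator = new QwenContentGenerator(qwenClient, config, gcConfig);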