diff --git a/packages/ai-chat/src/browser/ai-chat-frontend-module.ts b/packages/ai-chat/src/browser/ai-chat-frontend-module.ts index 396bcb7331a06..72a0cbe78116d 100644 --- a/packages/ai-chat/src/browser/ai-chat-frontend-module.ts +++ b/packages/ai-chat/src/browser/ai-chat-frontend-module.ts @@ -37,11 +37,14 @@ import { aiChatPreferences } from './ai-chat-preferences'; import { AICustomAgentsFrontendApplicationContribution } from './custom-agent-frontend-application-contribution'; import { FrontendChatServiceImpl } from './frontend-chat-service'; import { CustomAgentFactory } from './custom-agent-factory'; +import { ChatToolRequestService } from '../common/chat-tool-request-service'; export default new ContainerModule(bind => { bindContributionProvider(bind, Agent); bindContributionProvider(bind, ChatAgent); + bind(ChatToolRequestService).toSelf().inSingletonScope(); + bind(ChatAgentServiceImpl).toSelf().inSingletonScope(); bind(ChatAgentService).toService(ChatAgentServiceImpl); bind(DefaultChatAgentId).toConstantValue({ id: OrchestratorChatAgentId }); diff --git a/packages/ai-chat/src/common/chat-agents.ts b/packages/ai-chat/src/common/chat-agents.ts index 033ff7dba3296..acd8e68d2ef66 100644 --- a/packages/ai-chat/src/common/chat-agents.ts +++ b/packages/ai-chat/src/common/chat-agents.ts @@ -43,7 +43,6 @@ import { inject, injectable, named, postConstruct, unmanaged } from '@theia/core import { ChatAgentService } from './chat-agent-service'; import { ChatModel, - ChatRequestModel, ChatRequestModelImpl, ChatResponseContent, ErrorChatResponseContentImpl, @@ -53,6 +52,7 @@ import { import { findFirstMatch, parseContents } from './parse-contents'; import { DefaultResponseContentFactory, ResponseContentMatcher, ResponseContentMatcherProvider } from './response-content-matcher'; import { ChatHistoryEntry } from './chat-history-entry'; +import { ChatToolRequestService } from './chat-tool-request-service'; /** * A conversation consists of a sequence of ChatMessages. @@ -123,10 +123,12 @@ export abstract class AbstractChatAgent { @inject(LanguageModelRegistry) protected languageModelRegistry: LanguageModelRegistry; @inject(ILogger) protected logger: ILogger; @inject(CommunicationRecordingService) protected recordingService: CommunicationRecordingService; + @inject(ChatToolRequestService) protected chatToolRequestService: ChatToolRequestService; @inject(PromptService) protected promptService: PromptService; @inject(ContributionProvider) @named(ResponseContentMatcherProvider) protected contentMatcherProviders: ContributionProvider; + protected additionalToolRequests: ToolRequest[] = []; protected contentMatchers: ResponseContentMatcher[] = []; @inject(DefaultResponseContentFactory) @@ -171,7 +173,6 @@ export abstract class AbstractChatAgent { ); } - const tools: Map = new Map(); if (systemMessageDescription) { const systemMsg: ChatMessage = { actor: 'system', @@ -180,16 +181,19 @@ export abstract class AbstractChatAgent { }; // insert system message at the beginning of the request messages messages.unshift(systemMsg); - systemMessageDescription.functionDescriptions?.forEach((tool, id) => { - tools.set(id, tool); - }); } - this.getTools(request)?.forEach(tool => tools.set(tool.id, tool)); + + const systemMessageToolRequests = systemMessageDescription?.functionDescriptions?.values(); + const tools = [ + ...this.chatToolRequestService.getChatToolRequests(request), + ...this.chatToolRequestService.toChatToolRequests(systemMessageToolRequests ? Array.from(systemMessageToolRequests) : [], request), + ...this.chatToolRequestService.toChatToolRequests(this.additionalToolRequests, request) + ]; const languageModelResponse = await this.callLlm( languageModel, messages, - tools.size > 0 ? Array.from(tools.values()) : undefined, + tools.length > 0 ? tools : undefined, request.response.cancellationToken ); await this.addContentsToResponse(languageModelResponse, request); @@ -258,15 +262,6 @@ export abstract class AbstractChatAgent { return requestMessages; } - /** - * @returns the list of tools used by this agent, or undefined if none is needed. - */ - protected getTools(request: ChatRequestModel): ToolRequest[] | undefined { - return request.message.toolRequests.size > 0 - ? [...request.message.toolRequests.values()] - : undefined; - } - protected async callLlm( languageModel: LanguageModel, messages: ChatMessage[], diff --git a/packages/ai-chat/src/common/chat-model.ts b/packages/ai-chat/src/common/chat-model.ts index dce76614cc768..ccd3fd5fbd374 100644 --- a/packages/ai-chat/src/common/chat-model.ts +++ b/packages/ai-chat/src/common/chat-model.ts @@ -76,6 +76,21 @@ export interface ChatRequestModel { readonly data?: { [key: string]: unknown }; } +export namespace ChatRequestModel { + export function is(request: unknown): request is ChatRequestModel { + return !!( + request && + typeof request === 'object' && + 'id' in request && + typeof (request as { id: unknown }).id === 'string' && + 'session' in request && + 'request' in request && + 'response' in request && + 'message' in request + ); + } +} + export interface ChatProgressMessage { kind: 'progressMessage'; id: string; diff --git a/packages/ai-chat/src/common/chat-tool-request-service.ts b/packages/ai-chat/src/common/chat-tool-request-service.ts new file mode 100644 index 0000000000000..ec238f1df5ff7 --- /dev/null +++ b/packages/ai-chat/src/common/chat-tool-request-service.ts @@ -0,0 +1,59 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +import { ToolRequest } from '@theia/ai-core'; +import { injectable } from '@theia/core/shared/inversify'; +import { ChatRequestModelImpl } from './chat-model'; + +export interface ChatToolRequest extends ToolRequest { + handler: ( + arg_string: string, + context: ChatRequestModelImpl, + ) => Promise; +} + +/** + * Wraps tool requests in a chat context. + * + * This service extracts tool requests from a given chat request model and wraps their + * handler functions to provide additional context, such as the chat request model. + */ +@injectable() +export class ChatToolRequestService { + + getChatToolRequests(request: ChatRequestModelImpl): ChatToolRequest[] { + const toolRequests = request.message.toolRequests.size > 0 ? [...request.message.toolRequests.values()] : undefined; + if (!toolRequests) { + return []; + } + return this.toChatToolRequests(toolRequests, request); + } + + toChatToolRequests(toolRequests: ToolRequest[] | undefined, request: ChatRequestModelImpl): ChatToolRequest[] { + if (!toolRequests) { + return []; + } + return toolRequests.map(toolRequest => this.toChatToolRequest(toolRequest, request)); + } + + protected toChatToolRequest(toolRequest: ToolRequest, request: ChatRequestModelImpl): ChatToolRequest { + return { + ...toolRequest, + handler: async (arg_string: string) => toolRequest.handler(arg_string, request) + }; + } + +} diff --git a/packages/ai-core/src/common/language-model.ts b/packages/ai-core/src/common/language-model.ts index 7cfebc5519cb1..d2e42519fb299 100644 --- a/packages/ai-core/src/common/language-model.ts +++ b/packages/ai-core/src/common/language-model.ts @@ -43,7 +43,7 @@ export interface ToolRequest { name: string; parameters?: ToolRequestParameters description?: string; - handler: (arg_string: string) => Promise; + handler: (arg_string: string, ctx?: unknown) => Promise; providerName?: string; }