Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TheiaAi] Support referencing prompt fragments via variable #14985

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 94 additions & 16 deletions packages/ai-chat/src/common/chat-request-parser.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,32 +18,42 @@ import * as sinon from 'sinon';
import { ChatAgentServiceImpl } from './chat-agent-service';
import { ChatRequestParserImpl } from './chat-request-parser';
import { ChatAgentLocation } from './chat-agents';
import { ChatRequest } from './chat-model';
import { ChatContext, ChatRequest } from './chat-model';
import { expect } from 'chai';
import { DefaultAIVariableService, ToolInvocationRegistry, ToolInvocationRegistryImpl } from '@theia/ai-core';
import { AIVariable, DefaultAIVariableService, ResolvedAIVariable, ToolInvocationRegistryImpl, ToolRequest } from '@theia/ai-core';
import { ILogger, Logger } from '@theia/core';
import { ParsedChatRequestTextPart, ParsedChatRequestVariablePart } from './parsed-chat-request';

describe('ChatRequestParserImpl', () => {
const chatAgentService = sinon.createStubInstance(ChatAgentServiceImpl);
const variableService = sinon.createStubInstance(DefaultAIVariableService);
const toolInvocationRegistry: ToolInvocationRegistry = sinon.createStubInstance(ToolInvocationRegistryImpl);
const parser = new ChatRequestParserImpl(chatAgentService, variableService, toolInvocationRegistry);
const toolInvocationRegistry = sinon.createStubInstance(ToolInvocationRegistryImpl);
const logger: ILogger = sinon.createStubInstance(Logger);
const parser = new ChatRequestParserImpl(chatAgentService, variableService, toolInvocationRegistry, logger);

it('parses simple text', () => {
beforeEach(() => {
// Reset our stubs before each test
sinon.reset();
});

it('parses simple text', async () => {
const req: ChatRequest = {
text: 'What is the best pizza topping?'
};
const result = parser.parseChatRequest(req, ChatAgentLocation.Panel);
const context: ChatContext = { variables: [] };
const result = await parser.parseChatRequest(req, ChatAgentLocation.Panel, context);
expect(result.parts).to.deep.contain({
text: 'What is the best pizza topping?',
range: { start: 0, endExclusive: 31 }
});
});

it('parses text with variable name', () => {
it('parses text with variable name', async () => {
const req: ChatRequest = {
text: 'What is the #best pizza topping?'
};
const result = parser.parseChatRequest(req, ChatAgentLocation.Panel);
const context: ChatContext = { variables: [] };
const result = await parser.parseChatRequest(req, ChatAgentLocation.Panel, context);
expect(result).to.deep.contain({
parts: [{
text: 'What is the ',
Expand All @@ -59,11 +69,12 @@ describe('ChatRequestParserImpl', () => {
});
});

it('parses text with variable name with argument', () => {
it('parses text with variable name with argument', async () => {
const req: ChatRequest = {
text: 'What is the #best:by-poll pizza topping?'
};
const result = parser.parseChatRequest(req, ChatAgentLocation.Panel);
const context: ChatContext = { variables: [] };
const result = await parser.parseChatRequest(req, ChatAgentLocation.Panel, context);
expect(result).to.deep.contain({
parts: [{
text: 'What is the ',
Expand All @@ -79,11 +90,12 @@ describe('ChatRequestParserImpl', () => {
});
});

it('parses text with variable name with numeric argument', () => {
it('parses text with variable name with numeric argument', async () => {
const req: ChatRequest = {
text: '#size-class:2'
};
const result = parser.parseChatRequest(req, ChatAgentLocation.Panel);
const context: ChatContext = { variables: [] };
const result = await parser.parseChatRequest(req, ChatAgentLocation.Panel, context);
expect(result.parts[0]).to.contain(
{
variableName: 'size-class',
Expand All @@ -92,11 +104,12 @@ describe('ChatRequestParserImpl', () => {
);
});

it('parses text with variable name with POSIX path argument', () => {
it('parses text with variable name with POSIX path argument', async () => {
const req: ChatRequest = {
text: '#file:/path/to/file.ext'
};
const result = parser.parseChatRequest(req, ChatAgentLocation.Panel);
const context: ChatContext = { variables: [] };
const result = await parser.parseChatRequest(req, ChatAgentLocation.Panel, context);
expect(result.parts[0]).to.contain(
{
variableName: 'file',
Expand All @@ -105,16 +118,81 @@ describe('ChatRequestParserImpl', () => {
);
});

it('parses text with variable name with Win32 path argument', () => {
it('parses text with variable name with Win32 path argument', async () => {
const req: ChatRequest = {
text: '#file:c:\\path\\to\\file.ext'
};
const result = parser.parseChatRequest(req, ChatAgentLocation.Panel);
const context: ChatContext = { variables: [] };
const result = await parser.parseChatRequest(req, ChatAgentLocation.Panel, context);
expect(result.parts[0]).to.contain(
{
variableName: 'file',
variableArg: 'c:\\path\\to\\file.ext'
}
);
});

it('resolves variable and extracts tool functions from resolved variable', async () => {
// Set up two test tool requests that will be referenced in the variable content
const testTool1: ToolRequest = {
id: 'testTool1',
name: 'Test Tool 1',
handler: async () => undefined
};
const testTool2: ToolRequest = {
id: 'testTool2',
name: 'Test Tool 2',
handler: async () => undefined
};
// Configure the tool registry to return our test tools
toolInvocationRegistry.getFunction.withArgs(testTool1.id).returns(testTool1);
toolInvocationRegistry.getFunction.withArgs(testTool2.id).returns(testTool2);

// Set up the test variable to include in the request
const testVariable: AIVariable = {
id: 'testVariable',
name: 'testVariable',
description: 'A test variable',
};
// Configure the variable service to return our test variable
variableService.getVariable.withArgs(testVariable.name).returns(testVariable);
variableService.resolveVariable.withArgs(
{ variable: testVariable.name, arg: 'myarg' },
sinon.match.any
).resolves({
variable: testVariable,
arg: 'myarg',
value: 'This is a test with ~testTool1 and ~testTool2',
});

// Create a request with the test variable
const req: ChatRequest = {
text: 'Test with #testVariable:myarg'
};
const context: ChatContext = { variables: [] };

// Parse the request
const result = await parser.parseChatRequest(req, ChatAgentLocation.Panel, context);

// Verify the variable part contains the correct properties
expect(result.parts.length).to.equal(2);
expect(result.parts[0] instanceof ParsedChatRequestTextPart).to.be.true;
expect(result.parts[1] instanceof ParsedChatRequestVariablePart).to.be.true;
const variablePart = result.parts[1] as ParsedChatRequestVariablePart;
expect(variablePart).to.have.property('resolution');
expect(variablePart.resolution).to.deep.equal({
variable: testVariable,
arg: 'myarg',
value: 'This is a test with ~testTool1 and ~testTool2',
} satisfies ResolvedAIVariable);

// Verify both tool functions were extracted from the variable content
expect(result.toolRequests.size).to.equal(2);
expect(result.toolRequests.has(testTool1.id)).to.be.true;
expect(result.toolRequests.has(testTool2.id)).to.be.true;

// Verify the result contains the tool requests returned by the registry
expect(result.toolRequests.get(testTool1.id)).to.deep.equal(testTool1);
expect(result.toolRequests.get(testTool2.id)).to.deep.equal(testTool2);
});
});
74 changes: 63 additions & 11 deletions packages/ai-chat/src/common/chat-request-parser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import { inject, injectable } from '@theia/core/shared/inversify';
import { ChatAgentService } from './chat-agent-service';
import { ChatAgentLocation } from './chat-agents';
import { ChatRequest } from './chat-model';
import { ChatContext, ChatRequest } from './chat-model';
import {
chatAgentLeader,
chatFunctionLeader,
Expand All @@ -35,15 +35,16 @@ import {
ParsedChatRequest,
ParsedChatRequestPart,
} from './parsed-chat-request';
import { AIVariable, AIVariableService, ToolInvocationRegistry, ToolRequest } from '@theia/ai-core';
import { AIVariable, AIVariableService, PROMPT_FUNCTION_REGEX, ToolInvocationRegistry, ToolRequest } from '@theia/ai-core';
import { ILogger } from '@theia/core';

const agentReg = /^@([\w_\-\.]+)(?=(\s|$|\b))/i; // An @-agent
const functionReg = /^~([\w_\-\.]+)(?=(\s|$|\b))/i; // A ~ tool function
const variableReg = /^#([\w_\-]+)(?::([\w_\-_\/\\.:]+))?(?=(\s|$|\b))/i; // A #-variable with an optional : arg (#file:workspace/path/name.ext)

export const ChatRequestParser = Symbol('ChatRequestParser');
export interface ChatRequestParser {
parseChatRequest(request: ChatRequest, location: ChatAgentLocation): ParsedChatRequest;
parseChatRequest(request: ChatRequest, location: ChatAgentLocation, context: ChatContext): Promise<ParsedChatRequest>;
}

function offsetRange(start: number, endExclusive: number): OffsetRange {
Expand All @@ -57,10 +58,43 @@ export class ChatRequestParserImpl {
constructor(
@inject(ChatAgentService) private readonly agentService: ChatAgentService,
@inject(AIVariableService) private readonly variableService: AIVariableService,
@inject(ToolInvocationRegistry) private readonly toolInvocationRegistry: ToolInvocationRegistry
@inject(ToolInvocationRegistry) private readonly toolInvocationRegistry: ToolInvocationRegistry,
@inject(ILogger) private readonly logger: ILogger
) { }

parseChatRequest(request: ChatRequest, location: ChatAgentLocation): ParsedChatRequest {
async parseChatRequest(request: ChatRequest, location: ChatAgentLocation, context: ChatContext): Promise<ParsedChatRequest> {
// Parse the request into parts
const { parts, toolRequests, variables } = this.parseParts(request, location);

// Resolve all variables and add them to the variable parts.
// Parse resolved variable texts again for tool requests.
// These are not added to parts as they are not visible in the initial chat message.
// However, add they need to be added to the result to be considered by the executing agent.
// TODO [recursive variable resolution] collect recursively resolved variables for result
for (const part of parts) {
if (part instanceof ParsedChatRequestVariablePart) {
const resolvedVariable = await this.variableService.resolveVariable(
{ variable: part.variableName, arg: part.variableArg },
context
);
if (resolvedVariable) {
part.resolution = resolvedVariable;
// Resolve tool requests in resolved variables
this.parseFunctionsFromText(resolvedVariable.value, toolRequests);
} else {
this.logger.warn(`Failed to resolve variable ${part.variableName} for ${location}`);
}
}
}

return { request, parts, toolRequests, variables };
}

protected parseParts(request: ChatRequest, location: ChatAgentLocation): {
parts: ParsedChatRequestPart[];
toolRequests: Map<string, ToolRequest>;
variables: Map<string, AIVariable>;
} {
const parts: ParsedChatRequestPart[] = [];
const variables = new Map<string, AIVariable>();
const toolRequests = new Map<string, ToolRequest>();
Expand All @@ -72,7 +106,7 @@ export class ChatRequestParserImpl {

if (previousChar.match(/\s/) || i === 0) {
if (char === chatFunctionLeader) {
const functionPart = this.tryParseFunction(
const functionPart = this.tryToParseFunction(
message.slice(i),
i
);
Expand Down Expand Up @@ -107,8 +141,7 @@ export class ChatRequestParserImpl {
if (i !== 0) {
// Insert a part for all the text we passed over, then insert the new parsed part
const previousPart = parts.at(-1);
const previousPartEnd =
previousPart?.range.endExclusive ?? 0;
const previousPartEnd = previousPart?.range.endExclusive ?? 0;
parts.push(
new ParsedChatRequestTextPart(
offsetRange(previousPartEnd, i),
Expand All @@ -131,8 +164,26 @@ export class ChatRequestParserImpl {
)
);
}
return { parts, toolRequests, variables };
}

return { request, parts, toolRequests, variables };
/**
* Parse text for tool requests and add them to the given map
*/
private parseFunctionsFromText(text: string, toolRequests: Map<string, ToolRequest>): void {
for (let i = 0; i < text.length; i++) {
const previousChar = i === 0 ? ' ' : text.charAt(i - 1);
const char = text.charAt(i);

// Check for function markers at start of words
if ((previousChar.match(/\s/) || i === 0) && char === chatFunctionLeader) {
const functionPart = this.tryToParseFunction(text.slice(i), i);
if (functionPart) {
// Add the found tool request to the given map
toolRequests.set(functionPart.toolRequest.id, functionPart.toolRequest);
}
}
}
}

private tryToParseAgent(
Expand Down Expand Up @@ -201,8 +252,9 @@ export class ChatRequestParserImpl {
return new ParsedChatRequestVariablePart(varRange, name, variableArg);
}

private tryParseFunction(message: string, offset: number): ParsedChatRequestFunctionPart | undefined {
const nextFunctionMatch = message.match(functionReg);
private tryToParseFunction(message: string, offset: number): ParsedChatRequestFunctionPart | undefined {
// Support both the and chat and prompt formats for functions
const nextFunctionMatch = message.match(functionReg) || message.match(PROMPT_FUNCTION_REGEX);
if (!nextFunctionMatch) {
return;
}
Expand Down
22 changes: 4 additions & 18 deletions packages/ai-chat/src/common/chat-service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import {
ChatContext,
} from './chat-model';
import { ChatRequestParser } from './chat-request-parser';
import { ParsedChatRequest, ParsedChatRequestAgentPart, ParsedChatRequestVariablePart } from './parsed-chat-request';
import { ParsedChatRequest, ParsedChatRequestAgentPart } from './parsed-chat-request';

export interface ChatRequestInvocation {
/**
Expand Down Expand Up @@ -191,7 +191,9 @@ export class ChatServiceImpl implements ChatService {
}
session.title = request.text;

const parsedRequest = this.chatRequestParser.parseChatRequest(request, session.model.location);
const resolutionContext: ChatSessionContext = { model: session.model };
const resolvedContext = await this.resolveChatContext(session.model.context.getVariables(), resolutionContext);
const parsedRequest = await this.chatRequestParser.parseChatRequest(request, session.model.location, resolvedContext);
const agent = this.getAgent(parsedRequest, session);

if (agent === undefined) {
Expand All @@ -205,25 +207,9 @@ export class ChatServiceImpl implements ChatService {
};
}

const resolutionContext: ChatSessionContext = { model: session.model };
const resolvedContext = await this.resolveChatContext(session.model.context.getVariables(), resolutionContext);
const requestModel = session.model.addRequest(parsedRequest, agent?.id, resolvedContext);
resolutionContext.request = requestModel;

for (const part of parsedRequest.parts) {
if (part instanceof ParsedChatRequestVariablePart) {
const resolvedVariable = await this.variableService.resolveVariable(
{ variable: part.variableName, arg: part.variableArg },
resolutionContext
);
if (resolvedVariable) {
part.resolution = resolvedVariable;
} else {
this.logger.warn(`Failed to resolve variable ${part.variableName} for ${session.model.location}`);
}
}
}

let resolveResponseCreated: (responseModel: ChatResponseModel) => void;
let resolveResponseCompleted: (responseModel: ChatResponseModel) => void;
const invocation: ChatRequestInvocation = {
Expand Down
2 changes: 2 additions & 0 deletions packages/ai-core/src/browser/ai-core-frontend-module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ import { AIActivationService } from './ai-activation-service';
import { AgentService, AgentServiceImpl } from '../common/agent-service';
import { AICommandHandlerFactory } from './ai-command-handler-factory';
import { AISettingsService } from '../common/settings-service';
import { PromptVariableContribution } from '../common/prompt-variable-contribution';

export default new ContainerModule(bind => {
bindContributionProvider(bind, LanguageModelProvider);
Expand Down Expand Up @@ -109,6 +110,7 @@ export default new ContainerModule(bind => {
bind(TheiaVariableContribution).toSelf().inSingletonScope();
bind(AIVariableContribution).toService(TheiaVariableContribution);

bind(AIVariableContribution).to(PromptVariableContribution).inSingletonScope();
bind(AIVariableContribution).to(TodayVariableContribution).inSingletonScope();
bind(AIVariableContribution).to(FileVariableContribution).inSingletonScope();
bind(AIVariableContribution).to(AgentsVariableContribution).inSingletonScope();
Expand Down
Loading
Loading