initial thinking support for claude #15092

Draft · wants to merge 1 commit into base: master
@@ -17,14 +17,13 @@
import {
AbstractStreamParsingChatAgent,
ChatAgent,
ChatMessage,
ChatModel,
MutableChatRequestModel,
lastProgressMessage,
QuestionResponseContentImpl,
unansweredQuestions
} from '@theia/ai-chat';
import { Agent, PromptTemplate } from '@theia/ai-core';
import { Agent, LanguageModelMessage, PromptTemplate } from '@theia/ai-core';
import { injectable, interfaces, postConstruct } from '@theia/core/shared/inversify';

export function bindAskAndContinueChatAgentContribution(bind: interfaces.Bind): void {
@@ -161,15 +160,15 @@ export class AskAndContinueChatAgent extends AbstractStreamParsingChatAgent {
* As the question and its answer are handled within the same response, we add an additional user message at the end to signal
* to the LLM that it should continue generating.
*/
protected override async getMessages(model: ChatModel): Promise<ChatMessage[]> {
protected override async getMessages(model: ChatModel): Promise<LanguageModelMessage[]> {
const messages = await super.getMessages(model, true);
const requests = model.getRequests();
if (!requests[requests.length - 1].response.isComplete && requests[requests.length - 1].response.response?.content.length > 0) {
return [...messages,
{
type: 'text',
actor: 'user',
query: 'Continue generating based on the user\'s answer or finish the conversation if 5 or more questions were already answered.'
text: 'Continue generating based on the user\'s answer or finish the conversation if 5 or more questions were already answered.'
}];
}
return messages;
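For context, the change above swaps the chat-local `ChatMessage` shape (whose payload lived in a `query` field) for the core `LanguageModelMessage` type (whose text variant uses `text`). A minimal sketch of the renamed field, with the type reduced to only what this diff shows — the real union in `@theia/ai-core` has additional variants such as thinking and tool messages:

```typescript
// Reduced sketch of the text variant of LanguageModelMessage; the actual
// union in @theia/ai-core also covers thinking, tool_use and tool_result.
interface TextMessage {
    actor: 'user' | 'ai' | 'system';
    type: 'text';
    text: string; // previously `query` on the removed ChatMessage interface
}

const continueMessage: TextMessage = {
    actor: 'user',
    type: 'text',
    text: 'Continue generating based on the user\'s answer or finish the conversation if 5 or more questions were already answered.'
};
```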
2 changes: 1 addition & 1 deletion packages/ai-anthropic/package.json
@@ -3,7 +3,7 @@
"version": "1.59.0",
"description": "Theia - Anthropic Integration",
"dependencies": {
"@anthropic-ai/sdk": "^0.32.1",
"@anthropic-ai/sdk": "^0.39.0",
"@theia/ai-core": "1.59.0",
"@theia/core": "1.59.0"
},
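The SDK bump from ^0.32.1 to ^0.39.0 is what brings in the extended-thinking surface used below: `thinking` content blocks, the `thinking_delta`/`signature_delta` stream events, and the `thinking` request parameter. A minimal sketch of turning the feature on — the model id and token budget here are illustrative assumptions, not values taken from this PR:

```typescript
import Anthropic from '@anthropic-ai/sdk';

// Illustrative only: model id and budgets are assumptions, not from this PR.
const anthropic = new Anthropic(); // reads ANTHROPIC_API_KEY from the environment
const stream = anthropic.messages.stream({
    model: 'claude-3-7-sonnet-latest',
    max_tokens: 4096, // must exceed the thinking budget
    thinking: { type: 'enabled', budget_tokens: 1024 }, // 1024 is the documented minimum
    messages: [{ role: 'user', content: 'Why is the sky blue?' }]
});
```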
68 changes: 50 additions & 18 deletions packages/ai-anthropic/src/node/anthropic-language-model.ts
@@ -17,15 +17,15 @@
import {
LanguageModel,
LanguageModelRequest,
LanguageModelRequestMessage,
LanguageModelMessage,
LanguageModelResponse,
LanguageModelStreamResponse,
LanguageModelStreamResponsePart,
LanguageModelTextResponse
} from '@theia/ai-core';
import { CancellationToken, isArray } from '@theia/core';
import { Anthropic } from '@anthropic-ai/sdk';
import { MessageParam } from '@anthropic-ai/sdk/resources';
import { Message, MessageParam } from '@anthropic-ai/sdk/resources';

const DEFAULT_MAX_TOKENS_STREAMING = 4096;
const DEFAULT_MAX_TOKENS_NON_STREAMING = 2048;
@@ -42,23 +42,36 @@ interface ToolCallback {
args: string;
}

const createMessageContent = (message: LanguageModelMessage): MessageParam['content'] => {
if (LanguageModelMessage.isTextMessage(message)) {
return message.text;
} else if (LanguageModelMessage.isThinkingMessage(message)) {
return [{ signature: message.signature, thinking: message.thinking, type: 'thinking' }];
} else if (LanguageModelMessage.isToolUseMessage(message)) {
return [{ id: message.id, input: message.input, name: message.name, type: 'tool_use' }];
} else if (LanguageModelMessage.isToolResultMessage(message)) {
return [{ type: 'tool_result', tool_use_id: message.tool_use_id }];
}
throw new Error(`Unknown message type: '${JSON.stringify(message)}'`);
};

/**
* Transforms Theia language model messages to Anthropic API format
* @param messages Array of LanguageModelMessage to transform
* @returns Object containing transformed messages and optional system message
*/
function transformToAnthropicParams(
messages: readonly LanguageModelRequestMessage[]
messages: readonly LanguageModelMessage[]
): { messages: MessageParam[]; systemMessage?: string } {
// Extract the system message (if any), as it is a separate parameter in the Anthropic API.
const systemMessageObj = messages.find(message => message.actor === 'system');
const systemMessage = systemMessageObj?.query;
const systemMessage = systemMessageObj && LanguageModelMessage.isTextMessage(systemMessageObj) && systemMessageObj.text || '';

const convertedMessages = messages
.filter(message => message.actor !== 'system')
.map(message => ({
role: toAnthropicRole(message),
content: message.query || '',
content: createMessageContent(message)
}));

return {
@@ -74,7 +87,7 @@ export const AnthropicModelIdentifier = Symbol('AnthropicModelIdentifier');
* @param message The message to convert
* @returns Anthropic role ('user' or 'assistant')
*/
function toAnthropicRole(message: LanguageModelRequestMessage): 'user' | 'assistant' {
function toAnthropicRole(message: LanguageModelMessage): 'user' | 'assistant' {
switch (message.actor) {
case 'ai':
return 'assistant';
@@ -152,7 +165,7 @@ export class AnthropicModel implements LanguageModel {
...(systemMessage && { system: systemMessage }),
...settings
};

console.log(JSON.stringify(params));
const stream = anthropic.messages.stream(params);

cancellationToken?.onCancellationRequested(() => {
@@ -165,11 +178,15 @@

const toolCalls: ToolCallback[] = [];
let toolCall: ToolCallback | undefined;
const currentMessages: Message[] = [];

for await (const event of stream) {
if (event.type === 'content_block_start') {
const contentBlock = event.content_block;

if (contentBlock.type === 'thinking') {
yield { thought: contentBlock.thinking, signature: contentBlock.signature ?? '' };
}
if (contentBlock.type === 'text') {
yield { content: contentBlock.text };
}
@@ -179,7 +196,12 @@
}
} else if (event.type === 'content_block_delta') {
const delta = event.delta;

if (delta.type === 'thinking_delta') {
yield { thought: delta.thinking, signature: '' };
}
if (delta.type === 'signature_delta') {
yield { thought: '', signature: delta.signature };
}
if (delta.type === 'text_delta') {
yield { content: delta.text };
}
@@ -199,6 +221,8 @@
}
throw new Error(`The response was stopped because it exceeded the max token limit of ${event.usage.output_tokens}.`);
}
} else if (event.type === 'message_start') {
currentMessages.push(event.message);
}
}
if (toolCalls.length > 0) {
@@ -216,16 +240,16 @@
});
yield { tool_calls: calls };

const toolRequestMessage: Anthropic.Messages.MessageParam = {
role: 'assistant',
content: toolResult.map(call => ({
type: 'tool_use',
id: call.id,
name: call.name,
input: JSON.parse(call.arguments)
}))
};
// const toolRequestMessage: Anthropic.Messages.MessageParam = {
// role: 'assistant',
// content: toolResult.map(call => ({
// type: 'tool_use',
// id: call.id,
// name: call.name,
// input: JSON.parse(call.arguments)
// }))
// };

const toolResponseMessage: Anthropic.Messages.MessageParam = {
role: 'user',
@@ -235,7 +259,15 @@
content: that.formatToolCallResult(call.result)
}))
};
const result = await that.handleStreamingRequest(anthropic, request, cancellationToken, [...(toolMessages ?? []), toolRequestMessage, toolResponseMessage]);
const result = await that.handleStreamingRequest(
anthropic,
request,
cancellationToken,
[
...(toolMessages ?? []),
...currentMessages.map(m => ({ role: m.role, content: m.content })),
toolResponseMessage
]);
for await (const nestedEvent of result.stream) {
yield nestedEvent;
}
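To summarize the streaming changes above: a thinking block opens with a `content_block_start` event carrying any initial `thinking` text, grows through `thinking_delta` events, and is sealed by a `signature_delta`; the handler also records each `message_start` message so the assistant turn can be replayed when recursing into a tool-call round-trip. A sketch of how a consumer might reassemble the streamed thought parts, using the `{ thought, signature }` shape yielded above (the accumulator itself is illustrative):

```typescript
// Reassembles streamed thinking parts into one block. The part shape
// mirrors what handleStreamingRequest yields; this helper is a sketch.
interface ThoughtPart { thought: string; signature: string; }

function mergeThoughtParts(parts: ThoughtPart[]): ThoughtPart {
    return parts.reduce(
        (acc, part) => ({
            thought: acc.thought + part.thought,
            // the signature arrives once, in the final signature_delta
            signature: part.signature || acc.signature
        }),
        { thought: '', signature: '' }
    );
}
```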
2 changes: 2 additions & 0 deletions packages/ai-chat-ui/src/browser/ai-chat-ui-frontend-module.ts
@@ -33,6 +33,7 @@ import {
HorizontalLayoutPartRenderer,
InsertCodeAtCursorButtonAction,
MarkdownPartRenderer,
TextPartRenderer,
ToolCallPartRenderer,
} from './chat-response-renderer';
import {
@@ -79,6 +80,7 @@

bind(ContextVariablePicker).toSelf().inSingletonScope();

bind(ChatResponsePartRenderer).to(TextPartRenderer).inSingletonScope();
bind(ChatResponsePartRenderer).to(HorizontalLayoutPartRenderer).inSingletonScope();
bind(ChatResponsePartRenderer).to(ErrorPartRenderer).inSingletonScope();
bind(ChatResponsePartRenderer).to(MarkdownPartRenderer).inSingletonScope();
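The module change registers the previously unbound `TextPartRenderer` alongside the other renderers; `ChatResponsePartRenderer` is a multi-bound contribution point, so each additional renderer is one more binding. A sketch of contributing a custom renderer the same way — the renderer class and import paths are hypothetical:

```typescript
import { ContainerModule } from '@theia/core/shared/inversify';
// Hypothetical import path and renderer class, shown only to illustrate
// the multi-binding pattern used in the diff above.
import { ChatResponsePartRenderer } from '@theia/ai-chat-ui/lib/browser/chat-response-part-renderer';
import { MyThinkingPartRenderer } from './my-thinking-part-renderer';

export default new ContainerModule(bind => {
    bind(ChatResponsePartRenderer).to(MyThinkingPartRenderer).inSingletonScope();
});
```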
@@ -28,8 +28,8 @@ export class TextPartRenderer implements ChatResponsePartRenderer<ChatResponseContent> {
return 1;
}
render(response: ChatResponseContent): ReactNode {
if (response && ChatResponseContent.hasAsString(response)) {
return <span>{response.asString()}</span>;
if (response && ChatResponseContent.hasDisplayString(response)) {
return <span>{response.asDisplayString()}</span>;
}
return <span>
{nls.localize('theia/ai/chat-ui/text-part-renderer/cantDisplay',
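The renderer now checks `hasDisplayString`/`asDisplayString` rather than `hasAsString`/`asString`, which lets a content type keep its LLM-facing serialization (`asString`) separate from what is shown to the user. A sketch of the guard this presumably relies on — the real definition lives in the `ChatResponseContent` namespace in `@theia/ai-chat`:

```typescript
// Presumed shape of the guard; the actual definition is in @theia/ai-chat.
interface HasDisplayString {
    asDisplayString(): string;
}

function hasDisplayString(content: unknown): content is HasDisplayString {
    return typeof (content as HasDisplayString)?.asDisplayString === 'function';
}
```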
77 changes: 42 additions & 35 deletions packages/ai-chat/src/common/chat-agents.ts
@@ -24,13 +24,18 @@ import {
AIVariableContext,
CommunicationRecordingService,
getTextOfResponse,
isTextResponsePart,
isThinkingResponsePart,
isToolCallResponsePart,
LanguageModel,
LanguageModelMessage,
LanguageModelRequirement,
LanguageModelResponse,
LanguageModelStreamResponse,
PromptService,
PromptTemplate,
ResolvedPromptTemplate,
TextMessage,
ToolCall,
ToolRequest,
} from '@theia/ai-core';
@@ -39,8 +44,7 @@ import {
isLanguageModelStreamResponse,
isLanguageModelTextResponse,
LanguageModelRegistry,
LanguageModelStreamResponsePart,
MessageActor,
LanguageModelStreamResponsePart
} from '@theia/ai-core/lib/common';
import { CancellationToken, ContributionProvider, ILogger, isArray } from '@theia/core';
import { inject, injectable, named, postConstruct } from '@theia/core/shared/inversify';
@@ -52,25 +56,14 @@ import {
ErrorChatResponseContentImpl,
MarkdownChatResponseContentImpl,
ToolCallChatResponseContentImpl,
ChatRequestModel
ChatRequestModel,
ThinkingChatResponseContentImpl
} from './chat-model';
import { findFirstMatch, parseContents } from './parse-contents';
import { DefaultResponseContentFactory, ResponseContentMatcher, ResponseContentMatcherProvider } from './response-content-matcher';
import { ChatHistoryEntry } from './chat-history-entry';
import { ChatToolRequestService } from './chat-tool-request-service';

/**
* A conversation consists of a sequence of ChatMessages.
* Each ChatMessage is either a user message, AI message or a system message.
*
* For now we only support text based messages.
*/
export interface ChatMessage {
actor: MessageActor;
type: 'text';
query: string;
}

/**
* System message content, enriched with function descriptions.
*/
@@ -194,10 +187,10 @@ export abstract class AbstractChatAgent implements ChatAgent {
}

if (systemMessageDescription) {
const systemMsg: ChatMessage = {
const systemMsg: LanguageModelMessage = {
actor: 'system',
type: 'text',
query: systemMessageDescription.text
text: systemMessageDescription.text
};
// insert system message at the beginning of the request messages
messages.unshift(systemMsg);
@@ -266,21 +259,28 @@

protected async getMessages(
model: ChatModel, includeResponseInProgress = false
): Promise<ChatMessage[]> {
): Promise<LanguageModelMessage[]> {
const requestMessages = model.getRequests().flatMap(request => {
const messages: ChatMessage[] = [];
const messages: LanguageModelMessage[] = [];
const text = request.message.parts.map(part => part.promptText).join('');
messages.push({
actor: 'user',
type: 'text',
query: text,
text: text,
});
if (request.response.isComplete || includeResponseInProgress) {
messages.push({
actor: 'ai',
type: 'text',
query: request.response.response.asString(),
const responseMessages: LanguageModelMessage[] = request.response.response.content.flatMap(c => {
if (ChatResponseContent.hasToLanguageModelMessage(c)) {
return c.toLanguageModelMessage();
}

return {
actor: 'ai',
type: 'text',
text: c.asString?.() ?? c.asDisplayString?.() ?? '',
} as TextMessage;
});
messages.push(...responseMessages);
}
return messages;
});
@@ -290,7 +290,7 @@

protected async callLlm(
languageModel: LanguageModel,
messages: ChatMessage[],
messages: LanguageModelMessage[],
tools: ToolRequest[] | undefined,
token: CancellationToken
): Promise<LanguageModelResponse> {
@@ -407,17 +407,24 @@ export abstract class AbstractStreamParsingChatAgent extends AbstractChatAgent {
}

protected parse(token: LanguageModelStreamResponsePart, request: MutableChatRequestModel): ChatResponseContent | ChatResponseContent[] {
const content = token.content;
// eslint-disable-next-line no-null/no-null
if (content !== undefined && content !== null) {
return this.defaultContentFactory.create(content, request);
if (isTextResponsePart(token)) {
const content = token.content;
// eslint-disable-next-line no-null/no-null
if (content !== undefined && content !== null) {
return this.defaultContentFactory.create(content, request);
}
}
const toolCalls = token.tool_calls;
if (toolCalls !== undefined) {
const toolCallContents = toolCalls.map(toolCall =>
this.createToolCallResponseContent(toolCall)
);
return toolCallContents;
if (isToolCallResponsePart(token)) {
const toolCalls = token.tool_calls;
if (toolCalls !== undefined) {
const toolCallContents = toolCalls.map(toolCall =>
this.createToolCallResponseContent(toolCall)
);
return toolCallContents;
}
}
if (isThinkingResponsePart(token)) {
return new ThinkingChatResponseContentImpl(token.thought, token.signature);
}
return this.defaultContentFactory.create('', request);
}
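With these changes, `parse()` no longer assumes every stream part is text: it dispatches on type guards and maps thinking parts to `ThinkingChatResponseContentImpl`. A reduced sketch of the union and one guard, with field names taken from the diff (the canonical definitions live in `@theia/ai-core`):

```typescript
// Reduced sketch; the canonical types and guards are in @theia/ai-core.
type TextPart = { content: string };
type ToolCallPart = { tool_calls: Array<{ id?: string; function?: { name?: string; arguments?: string } }> };
type ThinkingPart = { thought: string; signature: string };
type StreamResponsePart = TextPart | ToolCallPart | ThinkingPart;

const isThinkingResponsePart = (part: StreamResponsePart): part is ThinkingPart =>
    typeof (part as ThinkingPart).thought === 'string';
```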