Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(core,openai): Adds streaming support for OpenAI withStructuredOutput #6721

Merged
merged 6 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions langchain-core/src/messages/ai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ export function isAIMessage(x: BaseMessage): x is AIMessage {
return x._getType() === "ai";
}

/**
 * Type guard for message chunks: reports whether the chunk identifies itself
 * as an AI message, i.e. whether its `_getType()` result is the string "ai".
 *
 * @param x The message chunk to inspect.
 * @returns `true` when `x._getType()` is "ai", otherwise `false`.
 */
export function isAIMessageChunk(x: BaseMessageChunk): x is AIMessageChunk {
  const messageType = x._getType();
  return messageType === "ai";
}

/**
 * `AIMessageFields` extended with an optional `tool_call_chunks` array of
 * `ToolCallChunk` entries.
 */
export type AIMessageChunkFields = AIMessageFields & {
tool_call_chunks?: ToolCallChunk[];
};
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import { z } from "zod";
import { ChatGeneration } from "../../outputs.js";
import { BaseLLMOutputParser, OutputParserException } from "../base.js";
import { ChatGeneration, ChatGenerationChunk } from "../../outputs.js";
import { OutputParserException } from "../base.js";
import { parsePartialJson } from "../json.js";
import { InvalidToolCall, ToolCall } from "../../messages/tool.js";
import {
BaseCumulativeTransformOutputParser,
BaseCumulativeTransformOutputParserInput,
} from "../transform.js";
import { isAIMessage } from "../../messages/ai.js";

export type ParsedToolCall = {
id?: string;
Expand All @@ -23,7 +28,7 @@ export type ParsedToolCall = {
export type JsonOutputToolsParserParams = {
/** Whether to return the tool call id. */
returnId?: boolean;
};
} & BaseCumulativeTransformOutputParserInput;

export function parseToolCall(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
Expand All @@ -35,6 +40,11 @@ export function parseToolCall(
rawToolCall: Record<string, any>,
options?: { returnId?: boolean; partial?: false }
): ToolCall;
export function parseToolCall(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
rawToolCall: Record<string, any>,
options?: { returnId?: boolean; partial?: boolean }
): ToolCall | undefined;
export function parseToolCall(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
rawToolCall: Record<string, any>,
Expand Down Expand Up @@ -112,9 +122,9 @@ export function makeInvalidToolCall(
/**
* Class for parsing the output of a tool-calling LLM into a JSON object.
*/
export class JsonOutputToolsParser extends BaseLLMOutputParser<
ParsedToolCall[]
> {
export class JsonOutputToolsParser<
T
> extends BaseCumulativeTransformOutputParser<T> {
static lc_name() {
return "JsonOutputToolsParser";
}
Expand All @@ -130,31 +140,64 @@ export class JsonOutputToolsParser extends BaseLLMOutputParser<
this.returnId = fields?.returnId ?? this.returnId;
}

/**
 * Diffing is not provided by this parser.
 * @throws {Error} Always, with the message "Not supported."
 */
protected _diff() {
throw new Error("Not supported.");
}

/**
 * Not implemented for this parser. Because the method is `async`, the error
 * surfaces as a rejected promise rather than a synchronous throw.
 * @throws {Error} Always, with the message "Not implemented."
 */
async parse(): Promise<T> {
throw new Error("Not implemented.");
}

/**
 * Parses the given generations by delegating to `parsePartialResult` with
 * its `partial` argument set to `false`.
 *
 * @param generations The chat generations to parse.
 * @returns Whatever `parsePartialResult` resolves to for these generations.
 */
async parseResult(generations: ChatGeneration[]): Promise<T> {
  return this.parsePartialResult(generations, false);
}

/**
* Parses the output and returns a JSON object. If `argsOnly` is true,
* only the arguments of the function call are returned.
* @param generations The output of the LLM to parse.
* @returns A JSON object representation of the function call or its arguments.
*/
async parseResult(generations: ChatGeneration[]): Promise<ParsedToolCall[]> {
const toolCalls = generations[0].message.additional_kwargs.tool_calls;
if (!toolCalls) {
throw new Error(
`No tools_call in message ${JSON.stringify(generations)}`
async parsePartialResult(
generations: ChatGenerationChunk[] | ChatGeneration[],
partial = true
// eslint-disable-next-line @typescript-eslint/no-explicit-any
): Promise<any> {
const message = generations[0].message;
let toolCalls;
if (isAIMessage(message) && message.tool_calls?.length) {
toolCalls = message.tool_calls.map((toolCall) => {
const { id, ...rest } = toolCall;
if (!this.returnId) {
return rest;
}
return {
id,
...rest,
};
});
} else if (message.additional_kwargs.tool_calls !== undefined) {
const rawToolCalls = JSON.parse(
JSON.stringify(message.additional_kwargs.tool_calls)
);
toolCalls = rawToolCalls.map((rawToolCall: Record<string, unknown>) => {
return parseToolCall(rawToolCall, { returnId: this.returnId, partial });
});
}
if (!toolCalls) {
return [];
}
const clonedToolCalls = JSON.parse(JSON.stringify(toolCalls));
const parsedToolCalls = [];
for (const toolCall of clonedToolCalls) {
const parsedToolCall = parseToolCall(toolCall, { partial: true });
if (parsedToolCall !== undefined) {
for (const toolCall of toolCalls) {
if (toolCall !== undefined) {
// backward-compatibility with previous
// versions of LangChain JS, which used `name` and `arguments`
// @ts-expect-error name and arguments are defined by Object.defineProperty
const backwardsCompatibleToolCall: ParsedToolCall = {
type: parsedToolCall.name,
args: parsedToolCall.args,
id: parsedToolCall.id,
type: toolCall.name,
args: toolCall.args,
id: toolCall.id,
};
Object.defineProperty(backwardsCompatibleToolCall, "name", {
get() {
Expand All @@ -180,10 +223,8 @@ export type JsonOutputKeyToolsParserParams<
> = {
keyName: string;
returnSingle?: boolean;
/** Whether to return the tool call id. */
returnId?: boolean;
zodSchema?: z.ZodType<T>;
};
} & JsonOutputToolsParserParams;

/**
* Class for parsing the output of a tool-calling LLM into a JSON object if you are
Expand All @@ -192,7 +233,7 @@ export type JsonOutputKeyToolsParserParams<
export class JsonOutputKeyToolsParser<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
T extends Record<string, any> = Record<string, any>
> extends BaseLLMOutputParser<T> {
> extends JsonOutputToolsParser<T> {
static lc_name() {
return "JsonOutputKeyToolsParser";
}
Expand All @@ -209,15 +250,12 @@ export class JsonOutputKeyToolsParser<
/** Whether to return only the first tool call. */
returnSingle = false;

initialParser: JsonOutputToolsParser;

zodSchema?: z.ZodType<T>;

constructor(params: JsonOutputKeyToolsParserParams<T>) {
super(params);
this.keyName = params.keyName;
this.returnSingle = params.returnSingle ?? this.returnSingle;
this.initialParser = new JsonOutputToolsParser(params);
this.zodSchema = params.zodSchema;
}

Expand All @@ -240,17 +278,45 @@ export class JsonOutputKeyToolsParser<
}
}

/**
 * Parses the generations with the parent parser's `parsePartialResult`, then
 * keeps only the tool calls whose `type` equals this parser's `keyName`.
 *
 * @param generations The chat generations to parse.
 * @returns `undefined` when no tool call matches `keyName`. Otherwise the
 * matching tool calls (when `returnId` is set) or just their `args` (when it
 * is not); only the first entry is returned when `returnSingle` is set.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
async parsePartialResult(generations: ChatGeneration[]): Promise<any> {
  const parsedToolCalls = await super.parsePartialResult(generations);
  const callsForKey = parsedToolCalls.filter(
    (toolCall: ParsedToolCall) => toolCall.type === this.keyName
  );
  // Nothing for this key yet: signal "no value" rather than an empty list.
  if (!callsForKey.length) {
    return undefined;
  }
  const output = this.returnId
    ? callsForKey
    : callsForKey.map((toolCall: ParsedToolCall) => toolCall.args);
  return this.returnSingle ? output[0] : output;
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
async parseResult(generations: ChatGeneration[]): Promise<any> {
const results = await this.initialParser.parseResult(generations);
const results = await super.parsePartialResult(generations, false);
const matchingResults = results.filter(
(result) => result.type === this.keyName
(result: ParsedToolCall) => result.type === this.keyName
);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let returnedValues: ParsedToolCall[] | Record<string, any>[] =
matchingResults;
if (!matchingResults.length) {
return undefined;
}
if (!this.returnId) {
returnedValues = matchingResults.map((result) => result.args);
returnedValues = matchingResults.map(
(result: ParsedToolCall) => result.args
);
}
if (this.returnSingle) {
return this._validateResult(returnedValues[0]);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { test, expect } from "@jest/globals";
import { z } from "zod";
import { JsonOutputKeyToolsParser } from "../json_output_tools_parsers.js";
import { AIMessage } from "../../../messages/index.js";
import { OutputParserException } from "../../base.js";
import { AIMessage, AIMessageChunk } from "../../../messages/ai.js";
import { RunnableLambda } from "../../../runnables/base.js";

test("JSONOutputKeyToolsParser invoke", async () => {
const outputParser = new JsonOutputKeyToolsParser({
Expand Down Expand Up @@ -87,3 +89,144 @@ test("JSONOutputKeyToolsParser can validate a proper input", async () => {
);
expect(result).toEqual({ testKey: "testval" });
});

// A tool call placed directly on `AIMessage.tool_calls` should be picked up by
// the parser, which returns the args of the call named by `keyName`.
test("JSONOutputKeyToolsParser invoke with a top-level tool call", async () => {
  const message = new AIMessage({
    content: "",
    tool_calls: [{ id: "test", name: "testing", args: { testKey: 9 } }],
  });
  const outputParser = new JsonOutputKeyToolsParser({
    keyName: "testing",
    returnSingle: true,
  });

  const result = await outputParser.invoke(message);

  expect(result).toEqual({ testKey: 9 });
});

// The tool call carries a numeric `testKey` while the schema requires a string,
// so invoking the parser must reject with an OutputParserException.
test("JSONOutputKeyToolsParser with a top-level tool call and passed schema throws", async () => {
  const outputParser = new JsonOutputKeyToolsParser({
    keyName: "testing",
    returnSingle: true,
    zodSchema: z.object({
      testKey: z.string(),
    }),
  });
  // Assert on the rejection itself. A try/catch whose only assertion lives in
  // the catch block passes silently when nothing is thrown.
  await expect(
    outputParser.invoke(
      new AIMessage({
        content: "",
        tool_calls: [
          {
            id: "test",
            name: "testing",
            args: { testKey: 9 },
          },
        ],
      })
    )
  ).rejects.toBeInstanceOf(OutputParserException);
});

// With a schema that matches the tool call's args, the parser should return
// those args unchanged.
test("JSONOutputKeyToolsParser with a top-level tool call can validate a proper input", async () => {
  const schema = z.object({
    testKey: z.string(),
  });
  const message = new AIMessage({
    content: "",
    tool_calls: [{ id: "test", name: "testing", args: { testKey: "testval" } }],
  });
  const outputParser = new JsonOutputKeyToolsParser({
    keyName: "testing",
    returnSingle: true,
    zodSchema: schema,
  });

  const result = await outputParser.invoke(message);

  expect(result).toEqual({ testKey: "testval" });
});

// Pipes a fake streaming model into the parser and checks both the streamed
// chunks and the result of a plain `invoke` on the same chain.
test("JSONOutputKeyToolsParser can handle streaming input", async () => {
const outputParser = new JsonOutputKeyToolsParser({
keyName: "testing",
returnSingle: true,
zodSchema: z.object({
testKey: z.string(),
}),
});
// Stand-in for a model: an async generator that declares no parameters and
// yields four AIMessageChunks. The `args` string fragments below concatenate
// to `{ "testKey": "testval" }`.
const fakeModel = RunnableLambda.from(async function* () {
// First fragment: carries the tool name and the opening of the JSON args.
yield new AIMessageChunk({
content: "",
tool_call_chunks: [
{
index: 0,
id: "test",
name: "testing",
args: `{ "testKey":`,
type: "tool_call_chunk",
},
],
});
// A chunk with no tool call chunks at all.
yield new AIMessageChunk({
content: "",
tool_call_chunks: [],
});
// Remaining fragments omit `name` and only extend the args string.
yield new AIMessageChunk({
content: "",
tool_call_chunks: [
{
index: 0,
id: "test",
args: ` "testv`,
type: "tool_call_chunk",
},
],
});
yield new AIMessageChunk({
content: "",
tool_call_chunks: [
{
index: 0,
id: "test",
args: `al" }`,
type: "tool_call_chunk",
},
],
});
});
const stream = await (fakeModel as any).pipe(outputParser).stream();
const chunks = [];
for await (const chunk of stream) {
chunks.push(chunk);
}
// More than one emitted value is expected, and the last one must equal the
// fully parsed args.
expect(chunks.length).toBeGreaterThan(1);
expect(chunks.at(-1)).toEqual({ testKey: "testval" });
// TODO: Fix typing issue
// The generator above declares no parameters, so the AIMessage passed to
// `invoke` here is not read by fakeModel.
// NOTE(review): presumably the result is built from the same yielded chunks — confirm.
const result = await (fakeModel as any).pipe(outputParser).invoke(
new AIMessage({
content: "",
tool_calls: [
{
id: "test",
name: "testing",
args: { testKey: "testval" },
type: "tool_call",
},
],
})
);
expect(result).toEqual({ testKey: "testval" });
});
4 changes: 4 additions & 0 deletions langchain-core/src/output_parsers/transform.ts
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,8 @@ export abstract class BaseCumulativeTransformOutputParser<
}
}
}

/**
 * Returns the format instructions for this parser, which are always the
 * empty string.
 * @returns An empty string.
 */
getFormatInstructions(): string {
return "";
}
}
Loading
Loading