-
Notifications
You must be signed in to change notification settings - Fork 2.4k
/
Copy pathtransform.ts
141 lines (130 loc) · 3.99 KB
/
transform.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import { BaseOutputParser } from "./base.js";
import {
type BaseMessage,
isBaseMessage,
isBaseMessageChunk,
} from "../messages/base.js";
import { convertToChunk } from "../messages/utils.js";
import type { BaseCallbackConfig } from "../callbacks/manager.js";
import {
type Generation,
type ChatGeneration,
GenerationChunk,
ChatGenerationChunk,
} from "../outputs.js";
import { deepCompareStrict } from "../utils/@cfworker/json-schema/index.js";
/**
 * Output parser base class that can also consume a stream of inputs.
 *
 * Every item pulled from the input stream is parsed on its own via
 * `parseResult`; nothing is accumulated between items.
 */
export abstract class BaseTransformOutputParser<
  T = unknown
> extends BaseOutputParser<T> {
  /**
   * Parses each streamed item independently and yields the result.
   *
   * A string item is wrapped as a `{ text }` generation. Any other item is
   * treated as a message and wrapped as `{ message, text }`, with the text
   * obtained from `_baseMessageToString`.
   * @param inputGenerator Stream of raw strings or messages to parse.
   * @returns A stream with one parsed value per input item.
   */
  async *_transform(
    inputGenerator: AsyncGenerator<string | BaseMessage>
  ): AsyncGenerator<T> {
    for await (const item of inputGenerator) {
      const generation =
        typeof item === "string"
          ? { text: item }
          : { message: item, text: this._baseMessageToString(item) };
      yield this.parseResult([generation]);
    }
  }

  /**
   * Streams parsed output for a stream of input by delegating to
   * `_transformStreamWithConfig`, using `_transform` (bound to this
   * instance) as the per-stream transformer.
   * @param inputGenerator Stream of raw strings or messages to parse.
   * @param options Callback configuration; it is forwarded with `runType`
   * set to `"parser"`.
   * @returns A stream of parsed output.
   */
  async *transform(
    inputGenerator: AsyncGenerator<string | BaseMessage>,
    options: BaseCallbackConfig
  ): AsyncGenerator<T> {
    const boundTransform = this._transform.bind(this);
    yield* this._transformStreamWithConfig(inputGenerator, boundTransform, {
      ...options,
      runType: "parser",
    });
  }
}
/**
 * Constructor options for `BaseCumulativeTransformOutputParser`.
 */
export type BaseCumulativeTransformOutputParserInput = {
  /**
   * When true, the parser yields the result of `_diff(previous, current)`
   * instead of the full parsed value. Left unset, the parser keeps its
   * own default.
   */
  diff?: boolean;
};
/**
 * Base class for streaming output parsers that work on the accumulated
 * input rather than on each chunk in isolation.
 *
 * Incoming chunks are concatenated into one running generation, which is
 * re-parsed with `parsePartialResult` after every chunk. A value is only
 * emitted when the parse result is non-nullish and differs from the
 * previously emitted one. With `diff` enabled, the emitted value is the
 * result of `_diff(previous, current)` instead of the parse result.
 */
export abstract class BaseCumulativeTransformOutputParser<
  T = unknown
> extends BaseTransformOutputParser<T> {
  /** Whether `_transform` emits `_diff` output instead of full parses. */
  protected diff = false;

  /**
   * @param fields Optional settings; `fields.diff`, when provided,
   * replaces the default `diff` flag.
   */
  constructor(fields?: BaseCumulativeTransformOutputParserInput) {
    super(fields);
    this.diff = fields?.diff ?? this.diff;
  }

  /**
   * Produces the value to emit in diff mode.
   * @param prev The previously emitted parse result, or `undefined` when
   * nothing has been emitted yet.
   * @param next The newly parsed result.
   * @returns The value that `_transform` yields when `diff` is true.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  protected abstract _diff(prev: any | undefined, next: any): any;

  /**
   * Parses the input accumulated so far.
   * @param generations The accumulated generation; `_transform` passes it
   * as a single-element array.
   * @returns The parsed value, or `undefined` when there is nothing to
   * emit yet.
   */
  abstract parsePartialResult(
    generations: Generation[] | ChatGeneration[]
  ): Promise<T | undefined>;

  /**
   * Accumulates streamed chunks and yields a value whenever the parse of
   * the accumulated input changes.
   * @param inputGenerator Stream of raw strings or messages.
   * @returns A stream of parsed values, or of diffs when `diff` is true.
   * @throws Error if a non-string chunk has non-string `content`.
   */
  async *_transform(
    inputGenerator: AsyncGenerator<string | BaseMessage>
  ): AsyncGenerator<T> {
    let lastEmitted: T | undefined;
    let accumulated: GenerationChunk | undefined;

    for await (const piece of inputGenerator) {
      // Reject anything that is neither a string nor carries string content.
      if (typeof piece !== "string" && typeof piece.content !== "string") {
        throw new Error("Cannot handle non-string output.");
      }

      // Wrap the incoming piece as a generation chunk so it can be
      // concatenated onto the running total.
      let current: GenerationChunk;
      if (isBaseMessageChunk(piece)) {
        // Repeated check so the compiler narrows `content` to a string.
        if (typeof piece.content !== "string") {
          throw new Error("Cannot handle non-string message output.");
        }
        current = new ChatGenerationChunk({
          message: piece,
          text: piece.content,
        });
      } else if (isBaseMessage(piece)) {
        if (typeof piece.content !== "string") {
          throw new Error("Cannot handle non-string message output.");
        }
        // Not yet a chunk: convert it before wrapping.
        current = new ChatGenerationChunk({
          message: convertToChunk(piece),
          text: piece.content,
        });
      } else {
        current = new GenerationChunk({ text: piece });
      }

      accumulated =
        accumulated === undefined ? current : accumulated.concat(current);

      const parsed = await this.parsePartialResult([accumulated]);
      const hasValue = parsed !== undefined && parsed !== null;
      if (hasValue && !deepCompareStrict(parsed, lastEmitted)) {
        yield this.diff ? this._diff(lastEmitted, parsed) : parsed;
        // Updated only after the consumer resumes the generator.
        lastEmitted = parsed;
      }
    }
  }

  /**
   * @returns An empty string.
   */
  getFormatInstructions(): string {
    return "";
  }
}