Skip to content

Commit a1303d6

Browse files
committed
Implement getNumTokens on GoogleBaseLLM and ChatGoogleBase
1 parent 7820824 commit a1303d6

File tree

3 files changed

+34
-14
lines changed

3 files changed

+34
-14
lines changed

libs/langchain-google-common/src/chat_models.ts

+5
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,11 @@ export abstract class ChatGoogleBase<AuthOptions>
306306
return copyAIModelParams(this, options);
307307
}
308308

309+
async getNumTokens(messages: BaseMessage[]): Promise<number> {
310+
const parameters = this.invocationParams();
311+
return this.connection.requestCountTokens(messages, parameters);
312+
}
313+
309314
async _generate(
310315
messages: BaseMessage[],
311316
options: this["ParsedCallOptions"],

libs/langchain-google-common/src/connection.ts

+24-14
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,9 @@ export abstract class GoogleConnection<
9696
data: unknown | undefined,
9797
_options: CallOptions,
9898
requestHeaders: Record<string, string> = {},
99-
callMethod?: string
99+
callMethod: string | undefined = undefined
100100
): Promise<GoogleAbstractedClientOps> {
101-
const url = await this.buildUrl();
101+
const url = await this.buildUrl(callMethod);
102102
const method = this.buildMethod();
103103
const infoHeaders = (await this._clientInfoHeaders()) ?? {};
104104
const additionalHeaders = (await this.additionalHeaders()) ?? {};
@@ -127,17 +127,25 @@ export abstract class GoogleConnection<
127127
async _request(
128128
data: unknown | undefined,
129129
options: CallOptions,
130-
requestHeaders: Record<string, string> = {},
131-
callMethod?: string
130+
requestHeaders: Record<string, string> = {}
132131
): Promise<ResponseType> {
133-
const opts = await this._buildOpts(data, options, requestHeaders, callMethod);
132+
const opts = await this._buildOpts(data, options, requestHeaders);
134133
const callResponse = await this.caller.callWithOptions(
135134
{ signal: options?.signal },
136135
async () => this.client.request(opts)
137136
);
138137
const response: unknown = callResponse; // Done for typecast safety, I guess
139138
return <ResponseType>response;
140139
}
140+
141+
async _requestCountTokens(data: unknown | undefined): Promise<number> {
142+
const opts = await this._buildOpts(data, {} as CallOptions, {}, 'countTokens');
143+
const { totalTokens } = await this.caller.callWithOptions(
144+
{},
145+
async () => this.client.request(opts)
146+
) as { totalTokens: number };
147+
return totalTokens;
148+
}
141149
}
142150

143151
export abstract class GoogleHostConnection<
@@ -334,11 +342,11 @@ export abstract class GoogleAIConnection<
334342
return url;
335343
}
336344

337-
async buildUrlVertex(): Promise<string> {
345+
async buildUrlVertex(callMethod?: string): Promise<string> {
338346
if (this.isApiKey) {
339-
return this.buildUrlVertexExpress();
347+
return this.buildUrlVertexExpress(callMethod);
340348
} else {
341-
return this.buildUrlVertexLocation();
349+
return this.buildUrlVertexLocation(callMethod);
342350
}
343351
}
344352

@@ -356,6 +364,14 @@ export abstract class GoogleAIConnection<
356364
parameters: GoogleAIModelRequestParams
357365
): Promise<unknown>;
358366

367+
async requestCountTokens(
368+
input: InputType,
369+
parameters: GoogleAIModelRequestParams
370+
): Promise<number> {
371+
const data = await this.formatData(input, parameters);
372+
return await this._requestCountTokens(data);
373+
}
374+
359375
async request(
360376
input: InputType,
361377
parameters: GoogleAIModelRequestParams,
@@ -428,12 +444,6 @@ export abstract class AbstractGoogleLLMConnection<
428444
): Promise<unknown> {
429445
return this.api.formatData(input, parameters);
430446
}
431-
432-
async getNumTokens(input: MessageType) {
433-
const data = this.formatData(input)
434-
const { totalTokens } = await this._request(data, {}, {}, 'countTokens');
435-
return totalTokens
436-
}
437447
}
438448

439449
export interface GoogleCustomEventInfo {

libs/langchain-google-common/src/llms.ts

+5
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,11 @@ export abstract class GoogleBaseLLM<AuthOptions>
177177
return prompt;
178178
}
179179

180+
async getNumTokens(input: string): Promise<number> {
181+
const parameters = copyAIModelParams(this, undefined);
182+
return this.connection.requestCountTokens(input, parameters);
183+
}
184+
180185
/**
181186
* For some given input string and options, return a string output.
182187
*

0 commit comments

Comments
 (0)