Skip to content

Commit

Permalink
feat: Added llm.tokenCountCallback support to Bedrock
Browse files Browse the repository at this point in the history
  • Loading branch information
jsumners-nr committed Mar 6, 2024
1 parent f6a86cb commit 2b36ae2
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 45 deletions.
14 changes: 7 additions & 7 deletions THIRD_PARTY_NOTICES.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ code, the source code can be found at [https://github.com/newrelic/node-newrelic

### @aws-sdk/client-s3

This product includes source derived from [@aws-sdk/client-s3](https://github.com/aws/aws-sdk-js-v3) ([v3.485.0](https://github.com/aws/aws-sdk-js-v3/tree/v3.485.0)), distributed under the [Apache-2.0 License](https://github.com/aws/aws-sdk-js-v3/blob/v3.485.0/LICENSE):
This product includes source derived from [@aws-sdk/client-s3](https://github.com/aws/aws-sdk-js-v3) ([v3.525.0](https://github.com/aws/aws-sdk-js-v3/tree/v3.525.0)), distributed under the [Apache-2.0 License](https://github.com/aws/aws-sdk-js-v3/blob/v3.525.0/LICENSE):

```
Apache License
Expand Down Expand Up @@ -249,7 +249,7 @@ This product includes source derived from [@aws-sdk/client-s3](https://github.co

### @aws-sdk/s3-request-presigner

This product includes source derived from [@aws-sdk/s3-request-presigner](https://github.com/aws/aws-sdk-js-v3) ([v3.485.0](https://github.com/aws/aws-sdk-js-v3/tree/v3.485.0)), distributed under the [Apache-2.0 License](https://github.com/aws/aws-sdk-js-v3/blob/v3.485.0/LICENSE):
This product includes source derived from [@aws-sdk/s3-request-presigner](https://github.com/aws/aws-sdk-js-v3) ([v3.525.0](https://github.com/aws/aws-sdk-js-v3/tree/v3.525.0)), distributed under the [Apache-2.0 License](https://github.com/aws/aws-sdk-js-v3/blob/v3.525.0/LICENSE):

```
Apache License
Expand Down Expand Up @@ -875,7 +875,7 @@ This product includes source derived from [@newrelic/newrelic-oss-cli](https://g

### @newrelic/test-utilities

This product includes source derived from [@newrelic/test-utilities](https://github.com/newrelic/node-test-utilities) ([v8.2.0](https://github.com/newrelic/node-test-utilities/tree/v8.2.0)), distributed under the [Apache-2.0 License](https://github.com/newrelic/node-test-utilities/blob/v8.2.0/LICENSE):
This product includes source derived from [@newrelic/test-utilities](https://github.com/newrelic/node-test-utilities) ([v8.3.0](https://github.com/newrelic/node-test-utilities/tree/v8.3.0)), distributed under the [Apache-2.0 License](https://github.com/newrelic/node-test-utilities/blob/v8.3.0/LICENSE):

```
Apache License
Expand Down Expand Up @@ -1084,7 +1084,7 @@ This product includes source derived from [@newrelic/test-utilities](https://git

### aws-sdk

This product includes source derived from [aws-sdk](https://github.com/aws/aws-sdk-js) ([v2.1531.0](https://github.com/aws/aws-sdk-js/tree/v2.1531.0)), distributed under the [Apache-2.0 License](https://github.com/aws/aws-sdk-js/blob/v2.1531.0/LICENSE.txt):
This product includes source derived from [aws-sdk](https://github.com/aws/aws-sdk-js) ([v2.1571.0](https://github.com/aws/aws-sdk-js/tree/v2.1571.0)), distributed under the [Apache-2.0 License](https://github.com/aws/aws-sdk-js/blob/v2.1571.0/LICENSE.txt):

```
Expand Down Expand Up @@ -1316,7 +1316,7 @@ ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

### eslint

This product includes source derived from [eslint](https://github.com/eslint/eslint) ([v8.56.0](https://github.com/eslint/eslint/tree/v8.56.0)), distributed under the [MIT License](https://github.com/eslint/eslint/blob/v8.56.0/LICENSE):
This product includes source derived from [eslint](https://github.com/eslint/eslint) ([v8.57.0](https://github.com/eslint/eslint/tree/v8.57.0)), distributed under the [MIT License](https://github.com/eslint/eslint/blob/v8.57.0/LICENSE):

```
Copyright OpenJS Foundation and other contributors, <www.openjsf.org>
Expand Down Expand Up @@ -1401,7 +1401,7 @@ SOFTWARE.

### lockfile-lint

This product includes source derived from [lockfile-lint](https://github.com/lirantal/lockfile-lint) ([v4.12.1](https://github.com/lirantal/lockfile-lint/tree/v4.12.1)), distributed under the [Apache-2.0 License](https://github.com/lirantal/lockfile-lint/blob/v4.12.1/LICENSE):
This product includes source derived from [lockfile-lint](https://github.com/lirantal/lockfile-lint) ([v4.13.2](https://github.com/lirantal/lockfile-lint/tree/v4.13.2)), distributed under the [Apache-2.0 License](https://github.com/lirantal/lockfile-lint/blob/v4.13.2/LICENSE):

```
Expand Down Expand Up @@ -1599,7 +1599,7 @@ This product includes source derived from [lockfile-lint](https://github.com/lir

### newrelic

This product includes source derived from [newrelic](https://github.com/newrelic/node-newrelic) ([v11.11.0](https://github.com/newrelic/node-newrelic/tree/v11.11.0)), distributed under the [Apache-2.0 License](https://github.com/newrelic/node-newrelic/blob/v11.11.0/LICENSE):
This product includes source derived from [newrelic](https://github.com/newrelic/node-newrelic) ([v11.12.0](https://github.com/newrelic/node-newrelic/tree/v11.12.0)), distributed under the [Apache-2.0 License](https://github.com/newrelic/node-newrelic/blob/v11.12.0/LICENSE):

```
Apache License
Expand Down
20 changes: 17 additions & 3 deletions lib/llm/bedrock-response.js
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,20 @@ class BedrockResponse {
#parsedBody
#command
#completions = []
#tokenCountCB
#id

/**
* @param {object} params
* @param {AwsBedrockMiddlewareResponse} params.response
* @param {BedrockCommand} params.bedrockCommand
* @param {function} [params.llmTokenCountCallback]
*/
constructor({ response, bedrockCommand, isError = false }) {
constructor({ response, bedrockCommand, isError = false, llmTokenCountCallback }) {
this.#innerResponse = isError ? response.$response : response.response
this.#command = bedrockCommand
this.isError = isError
this.#tokenCountCB = llmTokenCountCallback

if (this.isError) {
return
Expand Down Expand Up @@ -129,7 +132,7 @@ class BedrockResponse {
* @returns {number}
*/
get inputTokenCount() {
return parseInt(this.headers?.['x-amzn-bedrock-input-token-count'] || 0, 10)
return this.#tokenCount('x-amzn-bedrock-input-token-count')
}

/**
Expand All @@ -138,7 +141,7 @@ class BedrockResponse {
* @returns {number}
*/
get outputTokenCount() {
return parseInt(this.headers?.['x-amzn-bedrock-output-token-count'] || 0, 10)
return this.#tokenCount('x-amzn-bedrock-output-token-count')
}

/**
Expand All @@ -158,6 +161,17 @@ class BedrockResponse {
get statusCode() {
return this.#innerResponse.statusCode
}

#tokenCount(headerName) {
const headerVal = this.headers?.[headerName]
if (headerVal != null) {
return parseInt(headerVal, 10)
}
if (typeof this.#tokenCountCB === 'function') {
return this.#tokenCountCB(this.#parsedBody)
}
return 0
}
}

module.exports = BedrockResponse
35 changes: 31 additions & 4 deletions lib/v3/bedrock.js
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,23 @@ function recordEmbeddingMessage({ agent, segment, bedrockCommand, bedrockRespons
* @param {BedrockCommand} params.bedrockCommand parsed input
* @param {object} params.response response from bedrock
* @param {Error|null} params.err error from request if exists
* @param {function|undefined} [params.llmTokenCountCallback] Callback to invoke
* to retrieve the number of tokens when one has been provided by the user.
*
* @returns {BedrockResponse} parsed response from bedrock
*/
function createBedrockResponse({ bedrockCommand, response, err }) {
function createBedrockResponse({ bedrockCommand, response, err, llmTokenCountCallback }) {
let bedrockResponse

if (err) {
bedrockResponse = new BedrockResponse({ bedrockCommand, response: err, isError: err !== null })
bedrockResponse = new BedrockResponse({
bedrockCommand,
response: err,
isError: err !== null,
llmTokenCountCallback
})
} else {
bedrockResponse = new BedrockResponse({ bedrockCommand, response })
bedrockResponse = new BedrockResponse({ bedrockCommand, response, llmTokenCountCallback })
}
return bedrockResponse
}
Expand Down Expand Up @@ -227,7 +235,26 @@ function getBedrockSpec({ commandName }, shim, _original, _name, args) {

function handleResponse({ shim, err, response, segment, bedrockCommand, modelType }) {
const { agent } = shim
const bedrockResponse = createBedrockResponse({ bedrockCommand, response, err })
let llmTokenCountCallback
if (typeof agent?.llm?.tokenCountCallback === 'function') {
// The user has provided their own function for calculating the
// `token_count` metric. So we wrap it up for usage by BedrockResponse.
//
// Note: because of the fact that the spec requires the callback to accept
// the model and content, and that we get our content from the
// BedrockResponse instance, we establish a contract with BedrockResponse
// such that it will provide the content while we provide the model.
llmTokenCountCallback = function tokenCB(parsedResponseBody) {
return agent.llm.tokenCountCallback(bedrockCommand.modelId, parsedResponseBody)
}
}

const bedrockResponse = createBedrockResponse({
bedrockCommand,
response,
err,
llmTokenCountCallback
})

addLlmMeta({ agent, segment })
if (modelType === 'completion') {
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
"@aws-sdk/s3-request-presigner": "^3.342.0",
"@newrelic/eslint-config": "^0.4.0",
"@newrelic/newrelic-oss-cli": "^0.1.2",
"@newrelic/test-utilities": "^8.2.0",
"@newrelic/test-utilities": "^8.3.0",
"aws-sdk": "^2.1372.0",
"c8": "^7.12.0",
"eslint": "^8.56.0",
Expand Down
39 changes: 39 additions & 0 deletions tests/versioned/v3/bedrock-chat-completions.tap.js
Original file line number Diff line number Diff line change
Expand Up @@ -603,3 +603,42 @@ tap.test('should not instrument stream when disabled', (t) => {
t.end()
})
})

tap.test('should utilize tokenCountCallback when set', (t) => {
  t.plan(9)

  const { bedrock, client, helper } = t.context
  const modelId = 'amazon.titan-text-express-v1'
  const request = requests.amazon('text amazon user token count callback response', modelId)

  // The parsed response body the callback is expected to receive.
  const expectedBody = {
    inputTextTokenCount: 13,
    results: [
      {
        tokenCount: 4,
        outputText: '42',
        completionReason: 'endoftext'
      }
    ]
  }

  // User supplied counter: verifies both of its arguments on every invocation
  // and reports the sum of the input and output token counts.
  helper.agent.llm.tokenCountCallback = (model, content) => {
    t.equal(model, modelId)
    t.same(content, expectedBody)
    const [firstResult] = content.results
    return content.inputTextTokenCount + firstResult.tokenCount
  }
  const invokeCommand = new bedrock.InvokeModelCommand(request)

  helper.runInTransaction(async (tx) => {
    await client.send(invokeCommand)

    // At least one recorded chat completion message must carry the value
    // computed by the callback as its `token_count`.
    const recorded = helper.agent.customEventAggregator.events.toArray()
    const messages = recorded.filter((evt) => evt[0].type === 'LlmChatCompletionMessage')
    const hasExpectedCount = messages.some((evt) => evt[1].token_count === 17)
    t.equal(hasExpectedCount, true)

    tx.end()
    t.end()
  })
})
33 changes: 33 additions & 0 deletions tests/versioned/v3/bedrock-embeddings.tap.js
Original file line number Diff line number Diff line change
Expand Up @@ -251,3 +251,36 @@ tap.afterEach(async (t) => {
})
})
})

tap.test('should utilize tokenCountCallback when set', (t) => {
  t.plan(3)

  const { bedrock, client, helper } = t.context
  const modelId = 'amazon.titan-text-express-v1'
  const request = requests.amazon('embed text amazon token count callback response', modelId)

  // The parsed response body the callback is expected to receive.
  const expectedBody = {
    embedding: [0.1, 0.2, 0.3, 0.4],
    inputTextTokenCount: 13
  }

  // User supplied counter: verifies both of its arguments and reports the
  // input token count found in the response body.
  helper.agent.llm.tokenCountCallback = (model, content) => {
    t.equal(model, modelId)
    t.same(content, expectedBody)
    const { inputTextTokenCount } = content
    return inputTextTokenCount
  }
  const invokeCommand = new bedrock.InvokeModelCommand(request)

  helper.runInTransaction(async (tx) => {
    await client.send(invokeCommand)

    // At least one recorded `LlmChatCompletionMessage` event must carry the
    // value returned by the callback as its `token_count`.
    const recorded = helper.agent.customEventAggregator.events.toArray()
    const messages = recorded.filter((evt) => evt[0].type === 'LlmChatCompletionMessage')
    const hasExpectedCount = messages.some((evt) => evt[1].token_count === 13)
    t.equal(hasExpectedCount, true)

    tx.end()
    t.end()
  })
})
Loading

0 comments on commit 2b36ae2

Please sign in to comment.