Commit 62e1f05: refactor implementation

jsumners-nr committed Mar 6, 2024
Parent: 863a4e7
Showing 5 changed files with 32 additions and 66 deletions.
lib/llm/bedrock-response.js: 14 changes (4 additions, 10 deletions)

@@ -28,20 +28,17 @@ class BedrockResponse {
   #parsedBody
   #command
   #completions = []
-  #tokenCountCB
   #id

   /**
    * @param {object} params
    * @param {AwsBedrockMiddlewareResponse} params.response
    * @param {BedrockCommand} params.bedrockCommand
-   * @param {function} [params.llmTokenCountCallback]
    */
-  constructor({ response, bedrockCommand, isError = false, llmTokenCountCallback }) {
+  constructor({ response, bedrockCommand, isError = false }) {
     this.#innerResponse = isError ? response.$response : response.response
     this.#command = bedrockCommand
     this.isError = isError
-    this.#tokenCountCB = llmTokenCountCallback

     if (this.isError) {
       return
@@ -129,7 +126,7 @@ class BedrockResponse {
    * The number of tokens present in the prompt as determined by the remote
    * API.
    *
-   * @returns {number}
+   * @returns {number|undefined}
    */
   get inputTokenCount() {
     return this.#tokenCount('x-amzn-bedrock-input-token-count')
@@ -138,7 +135,7 @@
   /**
    * The number of tokens in the LLM response as determined by the remote API.
    *
-   * @returns {number}
+   * @returns {number|undefined}
    */
   get outputTokenCount() {
     return this.#tokenCount('x-amzn-bedrock-output-token-count')
@@ -167,10 +164,7 @@
     if (headerVal != null) {
       return parseInt(headerVal, 10)
     }
-    if (typeof this.#tokenCountCB === 'function') {
-      return this.#tokenCountCB(this.#parsedBody)
-    }
-    return 0
+    return undefined
   }
 }
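Taken together, the bedrock-response.js changes reduce token counting to a pure header lookup: the user callback path is gone, and a missing header now yields undefined instead of 0. A minimal sketch of the resulting behavior (BedrockResponseSketch and its plain headers object are illustrative simplifications, not the actual class):

// Sketch: token counts come only from the Bedrock response headers.
// When a header is absent, the getters return undefined rather than
// consulting a user callback or defaulting to 0.
class BedrockResponseSketch {
  #headers

  constructor(headers = {}) {
    this.#headers = headers
  }

  get inputTokenCount() {
    return this.#tokenCount('x-amzn-bedrock-input-token-count')
  }

  get outputTokenCount() {
    return this.#tokenCount('x-amzn-bedrock-output-token-count')
  }

  #tokenCount(headerName) {
    const headerVal = this.#headers[headerName]
    if (headerVal != null) {
      return parseInt(headerVal, 10)
    }
    return undefined
  }
}

// new BedrockResponseSketch({ 'x-amzn-bedrock-input-token-count': '13' }).inputTokenCount // 13
// new BedrockResponseSketch({}).inputTokenCount // undefined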
lib/llm/chat-completion-message.js: 20 changes (14 additions, 6 deletions)

@@ -40,23 +40,31 @@ class LlmChatCompletionMessage extends LlmEvent {
     super(params)

     const { agent, content, isResponse, index, completionId } = params
+    const recordContent = agent.config?.ai_monitoring?.record_content?.enabled
+    const tokenCB = agent?.llm?.tokenCountCallback

     this.is_response = isResponse
     this.completion_id = completionId
     this.sequence = index
-    this.content = agent.config?.ai_monitoring?.record_content?.enabled ? content : undefined
+    this.content = recordContent === true ? content : undefined
     this.role = ''

     this.#setId(index)
     if (this.is_response === true) {
       this.role = 'assistant'
-      this.token_count = this.bedrockResponse.outputTokenCount
+      if (recordContent === false && typeof tokenCB === 'function') {
+        this.token_count = tokenCB(this.bedrockCommand.modelId, content)
+      } else {
+        this.token_count = this.bedrockResponse.outputTokenCount
+      }
     } else {
       this.role = 'user'
-      this.token_count = this.bedrockResponse.inputTokenCount
-      this.content = agent.config?.ai_monitoring?.record_content?.enabled
-        ? this.bedrockCommand.prompt
-        : undefined
+      this.content = recordContent === true ? this.bedrockCommand.prompt : undefined
+      if (recordContent === false && typeof tokenCB === 'function') {
+        this.token_count = tokenCB(this.bedrockCommand.modelId, this.bedrockCommand.prompt)
+      } else {
+        this.token_count = this.bedrockResponse.inputTokenCount
+      }
     }
   }
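The net effect of the chat-completion-message.js changes is a single decision per message: consult the user-supplied callback only when content recording is disabled, otherwise use the header-derived count. A condensed sketch of that rule (resolveTokenCount is a hypothetical helper written for illustration, not part of the diff):

// Hypothetical helper condensing the new token_count selection. For a
// response message, text is the completion and headerCount is
// outputTokenCount; for a user message, text is the prompt and
// headerCount is inputTokenCount.
function resolveTokenCount({ recordContent, tokenCB, modelId, text, headerCount }) {
  if (recordContent === false && typeof tokenCB === 'function') {
    return tokenCB(modelId, text)
  }
  return headerCount
}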
lib/v3/bedrock.js: 34 changes (4 additions, 30 deletions)

@@ -157,23 +157,16 @@ function recordEmbeddingMessage({ agent, segment, bedrockCommand, bedrockRespons
  * @param {BedrockCommand} params.bedrockCommand parsed input
  * @param {object} params.response response from bedrock
  * @param {Error|null} params.err error from request if exists
- * @param {function|undefined} [params.llmTokenCountCallback] Callback to invoke
- * to retrieve the number of tokens when one has been provided by the user.
  *
  * @returns {BedrockResponse} parsed response from bedrock
  */
-function createBedrockResponse({ bedrockCommand, response, err, llmTokenCountCallback }) {
+function createBedrockResponse({ bedrockCommand, response, err }) {
   let bedrockResponse

   if (err) {
-    bedrockResponse = new BedrockResponse({
-      bedrockCommand,
-      response: err,
-      isError: err !== null,
-      llmTokenCountCallback
-    })
+    bedrockResponse = new BedrockResponse({ bedrockCommand, response: err, isError: err !== null })
   } else {
-    bedrockResponse = new BedrockResponse({ bedrockCommand, response, llmTokenCountCallback })
+    bedrockResponse = new BedrockResponse({ bedrockCommand, response })
   }
   return bedrockResponse
 }
@@ -235,26 +228,7 @@ function getBedrockSpec({ commandName }, shim, _original, _name, args) {

 function handleResponse({ shim, err, response, segment, bedrockCommand, modelType }) {
   const { agent } = shim
-  let llmTokenCountCallback
-  if (typeof agent?.llm?.tokenCountCallback === 'function') {
-    // The user has provided their own function for calculating the
-    // `token_count` metric. So we wrap it up for usage by BedrockResponse.
-    //
-    // Note: because of the fact that the spec requires the callback to accept
-    // the model and content, and that we get our content from the
-    // BedrockResponse instance, we establish a contract with BedrockResponse
-    // such that it will provide the content while we provide the model.
-    llmTokenCountCallback = function tokenCB(parsedResponseBody) {
-      return agent.llm.tokenCountCallback(bedrockCommand.modelId, parsedResponseBody)
-    }
-  }
-
-  const bedrockResponse = createBedrockResponse({
-    bedrockCommand,
-    response,
-    err,
-    llmTokenCountCallback
-  })
+  const bedrockResponse = createBedrockResponse({ bedrockCommand, response, err })

   addLlmMeta({ agent, segment })
   if (modelType === 'completion') {
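The removed comment block described the old adapter contract: handleResponse supplied the model id while BedrockResponse supplied the parsed response body. After this commit that adaptation is unnecessary, since LlmChatCompletionMessage has both the model id and the message text in scope and can call the user callback directly. The shape of the change, as an illustrative fragment (userCB stands in for agent.llm.tokenCountCallback):

// Before: the agent wrapped the user callback so BedrockResponse could
// invoke it with the parsed body it held.
const llmTokenCountCallback = (parsedResponseBody) =>
  userCB(bedrockCommand.modelId, parsedResponseBody)

// After: LlmChatCompletionMessage invokes the user callback directly with
// the prompt or completion text.
const tokenCount = userCB(bedrockCommand.modelId, promptOrCompletionText)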
tests/versioned/v3/bedrock-chat-completions.tap.js: 18 changes (5 additions, 13 deletions)

@@ -605,25 +605,17 @@ tap.test('should not instrument stream when disabled', (t) => {
 })

 tap.test('should utilize tokenCountCallback when set', (t) => {
-  t.plan(9)
+  t.plan(5)

   const { bedrock, client, helper } = t.context
   const prompt = 'text amazon user token count callback response'
   const input = requests.amazon(prompt, 'amazon.titan-text-express-v1')

+  helper.agent.config.ai_monitoring.record_content.enabled = false
   helper.agent.llm.tokenCountCallback = function (model, content) {
     t.equal(model, 'amazon.titan-text-express-v1')
-    t.same(content, {
-      inputTextTokenCount: 13,
-      results: [
-        {
-          tokenCount: 4,
-          outputText: '42',
-          completionReason: 'endoftext'
-        }
-      ]
-    })
-    return content.inputTextTokenCount + content.results[0].tokenCount
+    t.equal([prompt, '42'].includes(content), true)
+    return content?.split(' ')?.length
   }
   const command = new bedrock.InvokeModelCommand(input)

@@ -634,7 +626,7 @@ tap.test('should utilize tokenCountCallback when set', (t) => {
     const events = helper.agent.customEventAggregator.events.toArray()
     const completions = events.filter((e) => e[0].type === 'LlmChatCompletionMessage')
     t.equal(
-      completions.some((e) => e[1].token_count === 17),
+      completions.some((e) => e[1].token_count === 7),
       true
     )
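As the updated test shows, the user-facing contract for tokenCountCallback is now (model, content), where content is the raw prompt or completion text rather than Bedrock's parsed response body. A sketch of registering such a callback (agent is the New Relic agent handle the tests reach via helper.agent; the whitespace split is a naive stand-in for a real tokenizer, mirroring the test):

// Sketch: a user-supplied token counter. It receives the model id and the
// message text and returns a token count; split-on-spaces is only a
// placeholder for an actual tokenizer.
agent.llm.tokenCountCallback = function (model, content) {
  return content?.split(' ')?.length
}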
tests/versioned/v3/bedrock-embeddings.tap.js: 12 changes (5 additions, 7 deletions)

@@ -253,19 +253,17 @@ tap.afterEach(async (t) => {
 })

 tap.test('should utilize tokenCountCallback when set', (t) => {
-  t.plan(7)
+  t.plan(3)

   const { bedrock, client, helper } = t.context
   const prompt = 'embed text amazon token count callback response'
   const input = requests.amazon(prompt, 'amazon.titan-text-express-v1')

+  helper.agent.config.ai_monitoring.record_content.enabled = false
   helper.agent.llm.tokenCountCallback = function (model, content) {
     t.equal(model, 'amazon.titan-text-express-v1')
-    t.same(content, {
-      embedding: [0.1, 0.2, 0.3, 0.4],
-      inputTextTokenCount: 13
-    })
-    return content.inputTextTokenCount
+    t.equal(content, prompt)
+    return content?.split(' ')?.length
   }
   const command = new bedrock.InvokeModelCommand(input)

@@ -276,7 +274,7 @@ tap.test('should utilize tokenCountCallback when set', (t) => {
     const events = helper.agent.customEventAggregator.events.toArray()
     const completions = events.filter((e) => e[0].type === 'LlmChatCompletionMessage')
     t.equal(
-      completions.some((e) => e[1].token_count === 13),
+      completions.some((e) => e[1].token_count === 7),
      true
    )
