.Net: Add prompt execution settings to AutoFunctionInvocationContext (#10551)

### Motivation, Context and Description
Today, the prompt execution settings supplied to the chat completion
service are not available in `AutoFunctionInvocationContext`. This PR
adds a new `AutoFunctionInvocationContext.ExecutionSettings` property
and populates it with the settings the chat completion service was
invoked with.

Closes: #10475
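
For illustration, here is a minimal sketch of an auto function invocation filter that consumes the new property; the filter name and logging are illustrative and not part of this change:

```csharp
using System;
using System.Threading.Tasks;
using Microsoft.SemanticKernel;

// Illustrative filter: observes the execution settings the chat completion
// service was invoked with before allowing the function call to proceed.
public sealed class ExecutionSettingsLoggingFilter : IAutoFunctionInvocationFilter
{
    public async Task OnAutoFunctionInvocationAsync(
        AutoFunctionInvocationContext context,
        Func<AutoFunctionInvocationContext, Task> next)
    {
        // Added by this PR; null when no settings were supplied to the service.
        PromptExecutionSettings? settings = context.ExecutionSettings;

        Console.WriteLine(
            $"Auto-invoking '{context.Function.Name}' with model '{settings?.ModelId ?? "default"}'");

        await next(context);
    }
}
```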
SergeyMenshykh authored Feb 24, 2025
1 parent 534d860 commit 7b83ffd
Showing 9 changed files with 173 additions and 3 deletions.
@@ -661,6 +661,65 @@ public async Task FilterContextHasValidStreamingFlagAsync(bool isStreaming)
Assert.Equal(isStreaming, actualStreamingFlag);
}

[Fact]
public async Task PromptExecutionSettingsArePropagatedFromInvokePromptToFilterContextAsync()
{
// Arrange
this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses();

var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [KernelFunctionFactory.CreateFromMethod(() => { }, "Function1")]);

AutoFunctionInvocationContext? actualContext = null;

var kernel = this.GetKernelWithFilter(plugin, (context, next) =>
{
actualContext = context;
return Task.CompletedTask;
});

var expectedExecutionSettings = new OpenAIPromptExecutionSettings
{
ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
};

// Act
var result = await kernel.InvokePromptAsync("Test prompt", new(expectedExecutionSettings));

// Assert
Assert.NotNull(actualContext);
Assert.Same(expectedExecutionSettings, actualContext!.ExecutionSettings);
}

[Fact]
public async Task PromptExecutionSettingsArePropagatedFromInvokePromptStreamingToFilterContextAsync()
{
// Arrange
this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses();

var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [KernelFunctionFactory.CreateFromMethod(() => { }, "Function1")]);

AutoFunctionInvocationContext? actualContext = null;

var kernel = this.GetKernelWithFilter(plugin, (context, next) =>
{
actualContext = context;
return Task.CompletedTask;
});

var expectedExecutionSettings = new OpenAIPromptExecutionSettings
{
ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
};

// Act
await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(expectedExecutionSettings)))
{ }

// Assert
Assert.NotNull(actualContext);
Assert.Same(expectedExecutionSettings, actualContext!.ExecutionSettings);
}

public void Dispose()
{
this._httpClient.Dispose();
@@ -200,6 +200,7 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
// In such cases, we'll return the last message in the chat history.
var lastMessage = await this.FunctionCallsProcessor.ProcessFunctionCallsAsync(
chatMessageContent,
chatExecutionSettings,
chatHistory,
requestIndex,
(FunctionCallContent content) => IsRequestableTool(chatOptions.Tools, content),
@@ -384,6 +385,7 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
// In such cases, we'll return the last message in the chat history.
var lastMessage = await this.FunctionCallsProcessor.ProcessFunctionCallsAsync(
chatMessageContent,
chatExecutionSettings,
chatHistory,
requestIndex,
(FunctionCallContent content) => IsRequestableTool(chatOptions.Tools, content),
@@ -119,6 +119,7 @@ public FunctionCallsProcessor(ILogger? logger = null)
/// Processes AI function calls by iterating over the function calls, invoking them and adding the results to the chat history.
/// </summary>
/// <param name="chatMessageContent">The chat message content representing AI model response and containing function calls.</param>
/// <param name="executionSettings">The prompt execution settings.</param>
/// <param name="chatHistory">The chat history to add function invocation results to.</param>
/// <param name="requestIndex">AI model function(s) call request sequence index.</param>
/// <param name="checkIfFunctionAdvertised">Callback to check if a function was advertised to AI model or not.</param>
@@ -129,6 +130,7 @@ public FunctionCallsProcessor(ILogger? logger = null)
/// <returns>Last chat history message if function invocation filter requested processing termination, otherwise null.</returns>
public async Task<ChatMessageContent?> ProcessFunctionCallsAsync(
ChatMessageContent chatMessageContent,
PromptExecutionSettings? executionSettings,
ChatHistory chatHistory,
int requestIndex,
Func<FunctionCallContent, bool> checkIfFunctionAdvertised,
@@ -177,7 +179,8 @@ public FunctionCallsProcessor(ILogger? logger = null)
FunctionCount = functionCalls.Length,
CancellationToken = cancellationToken,
IsStreaming = isStreaming,
- ToolCallId = functionCall.Id
+ ToolCallId = functionCall.Id,
+ ExecutionSettings = executionSettings
};

s_inflightAutoInvokes.Value++;
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Diagnostics.CodeAnalysis;
using System.Threading;
using Microsoft.SemanticKernel.ChatCompletion;

@@ -79,6 +80,12 @@ public AutoFunctionInvocationContext(
/// </summary>
public ChatMessageContent ChatMessageContent { get; }

/// <summary>
/// The execution settings associated with the operation.
/// </summary>
[Experimental("SKEXP0001")]
public PromptExecutionSettings? ExecutionSettings { get; init; }

/// <summary>
/// Gets the <see cref="Microsoft.SemanticKernel.ChatCompletion.ChatHistory"/> associated with automatic function invocation.
/// </summary>
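A short usage sketch for the new property, assuming the illustrative filter above and an OpenAI connector (the model id and API key are placeholders); the tests in this PR assert that the same settings instance reaches the filter:

```csharp
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.OpenAI;

var builder = Kernel.CreateBuilder();
builder.AddOpenAIChatCompletion(modelId: "gpt-4o-mini", apiKey: "<api-key>"); // placeholders
var kernel = builder.Build();

kernel.AutoFunctionInvocationFilters.Add(new ExecutionSettingsLoggingFilter());

var settings = new OpenAIPromptExecutionSettings
{
    ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
};

// The filter receives this exact instance via context.ExecutionSettings.
var result = await kernel.InvokePromptAsync("Test prompt", new KernelArguments(settings));
```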
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Diagnostics.CodeAnalysis;
using System.Threading;

namespace Microsoft.SemanticKernel;
@@ -54,6 +55,12 @@ internal PromptRenderContext(Kernel kernel, KernelFunction function, KernelArgum
/// </summary>
public KernelArguments Arguments { get; }

/// <summary>
/// The execution settings associated with the operation.
/// </summary>
[Experimental("SKEXP0001")]
public PromptExecutionSettings? ExecutionSettings { get; init; }

/// <summary>
/// Gets or sets the rendered prompt.
/// </summary>
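`PromptRenderContext` gains the same property, so prompt render filters can also observe the settings. A minimal sketch, with an illustrative filter name:

```csharp
using System;
using System.Threading.Tasks;
using Microsoft.SemanticKernel;

// Illustrative filter: inspects the execution settings a prompt is rendered for.
public sealed class RenderSettingsFilter : IPromptRenderFilter
{
    public async Task OnPromptRenderAsync(
        PromptRenderContext context,
        Func<PromptRenderContext, Task> next)
    {
        // Added by this PR; null when the caller supplied no settings.
        PromptExecutionSettings? settings = context.ExecutionSettings;

        Console.WriteLine(
            $"Rendering '{context.Function.Name}' for model '{settings?.ModelId ?? "default"}'");

        await next(context); // context.RenderedPrompt is populated after this
    }
}
```

Such a filter is registered with `kernel.PromptRenderFilters.Add(new RenderSettingsFilter());`; the parameterized test at the end of this diff verifies the propagation on both streaming and non-streaming paths.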
4 changes: 3 additions & 1 deletion dotnet/src/SemanticKernel.Abstractions/Kernel.cs
@@ -349,13 +349,15 @@ internal async Task<PromptRenderContext> OnPromptRenderAsync(
KernelFunction function,
KernelArguments arguments,
bool isStreaming,
PromptExecutionSettings? executionSettings,
Func<PromptRenderContext, Task> renderCallback,
CancellationToken cancellationToken)
{
PromptRenderContext context = new(this, function, arguments)
{
CancellationToken = cancellationToken,
- IsStreaming = isStreaming
+ IsStreaming = isStreaming,
+ ExecutionSettings = executionSettings
};

await InvokeFilterOrPromptRenderAsync(this._promptRenderFilters, renderCallback, context).ConfigureAwait(false);
@@ -500,7 +500,7 @@ private async Task<PromptRenderingResult> RenderPromptAsync(

Verify.NotNull(aiService);

- var renderingContext = await kernel.OnPromptRenderAsync(this, arguments, isStreaming, async (context) =>
+ var renderingContext = await kernel.OnPromptRenderAsync(this, arguments, isStreaming, executionSettings, async (context) =>
{
renderedPrompt = await this._promptTemplate.RenderAsync(kernel, context.Arguments, cancellationToken).ConfigureAwait(false);

@@ -321,4 +321,42 @@ public async Task FilterContextHasValidStreamingFlagAsync(bool isStreaming)
// Assert
Assert.Equal(isStreaming, actualStreamingFlag);
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task PromptExecutionSettingsArePropagatedToFilterContextAsync(bool isStreaming)
{
// Arrange
PromptExecutionSettings? actualExecutionSettings = null;

var mockTextGeneration = this.GetMockTextGeneration();

var function = KernelFunctionFactory.CreateFromPrompt("Prompt");

var kernel = this.GetKernelWithFilters(textGenerationService: mockTextGeneration.Object,
onPromptRender: (context, next) =>
{
actualExecutionSettings = context.ExecutionSettings;
return next(context);
});

var expectedExecutionSettings = new PromptExecutionSettings();

var arguments = new KernelArguments(expectedExecutionSettings);

// Act
if (isStreaming)
{
await foreach (var item in kernel.InvokeStreamingAsync(function, arguments))
{ }
}
else
{
await kernel.InvokeAsync(function, arguments);
}

// Assert
Assert.Same(expectedExecutionSettings, actualExecutionSettings);
}
}