Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Further tweaks to support not tying system message overwriting solely…
… into the context inclusion mechanism.
  • Loading branch information
ckpearson committed May 27, 2025
commit c4cde133e81eb12c4cc1348ed41df0a919d60332
1 change: 1 addition & 0 deletions dotnet-sdk/AGUIDotnet/AGUIDotnet.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

<ItemGroup>
<PackageReference Include="Microsoft.Extensions.AI" Version="9.5.0" />
<PackageReference Include="JsonPatch.Net" Version="3.3.0" />
</ItemGroup>

<ItemGroup>
Expand Down
77 changes: 56 additions & 21 deletions dotnet-sdk/AGUIDotnet/Agent/ChatClientAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

namespace AGUIDotnet.Agent;

public sealed record ChatClientAgentOptions
public record ChatClientAgentOptions
{
/// <summary>
/// Options to provide when using the provided chat client.
Expand All @@ -34,6 +34,9 @@ public sealed record ChatClientAgentOptions
/// <para>
/// Switching this on will cause the agent to perform an initial typed extraction of the context if available, and then use that context for the agent run.
/// </para>
/// <para>
/// This is useful e.g. for frontends like CopilotKit that do not make useCopilotReadable context available to agents, instead relying on shared agent state - it does however provide the context in the system message.
/// </para>
/// </summary>
public bool PerformAiContextExtraction { get; init; } = false;

Expand All @@ -58,7 +61,7 @@ public class ChatClientAgent : IAGUIAgent
{
private readonly IChatClient _chatClient;
private readonly ChatClientAgentOptions _agentOptions;
private static readonly JsonSerializerOptions _jsonSerOpts = new(JsonSerializerDefaults.Web)
protected static readonly JsonSerializerOptions _jsonSerOpts = new(JsonSerializerDefaults.Web)
{
WriteIndented = false,
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
Expand Down Expand Up @@ -92,18 +95,14 @@ public async Task RunAsync(RunAgentInput input, ChannelWriter<BaseEvent> events,
// Ensure we have an empty tools list to start with.
chatOpts.Tools ??= [];

// Do NOT support multiple tool calls (even if the caller has requested it).
// Frontend tools require us to short-circuit the function invocation loop provided out of the box by the abstraction library.
// todo: We may need to create a custom chat client wrapper that handles function invocation without the limitation.
chatOpts.AllowMultipleToolCalls = false;

/*
Prepare the backend tools by filtering out any frontend tools (there shouldn't be any, but just in case),
and allow the derived type to modify the backend tools if needed.
*/
var backendTools = (await PrepareBackendTools(
[.. chatOpts.Tools.OfType<AIFunction>().Where(f => f is not FrontendTool)],
input,
events,
cancellationToken
)).Where(t => t is not FrontendTool).ToImmutableList();

Expand All @@ -126,20 +125,24 @@ [.. chatOpts.Tools.OfType<AIFunction>().Where(f => f is not FrontendTool)],
}
}

// Re-init the tools list to only include the discovered backend and provided frontend tools
chatOpts.Tools = [.. backendTools, .. frontendTools];
if (frontendTools.IsEmpty && backendTools.IsEmpty)
{
chatOpts.Tools = null;
chatOpts.AllowMultipleToolCalls = null;
}
else
{
chatOpts.Tools = [.. backendTools, .. frontendTools];
chatOpts.AllowMultipleToolCalls = false;
}

var context = await PrepareContext(input, cancellationToken);
var mappedMessages = await MapAGUIMessagesToChatClientMessages(input, context, cancellationToken);

var inFlightFrontendCalls = new HashSet<string>();

await events.WriteAsync(new RunStartedEvent
{
ThreadId = input.ThreadId,
RunId = input.RunId,
Timestamp = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(),
}, cancellationToken);
// Handle the run starting
await OnRunStartedAsync(input, events, cancellationToken);

string? currentResponseId = null;

Expand Down Expand Up @@ -329,6 +332,7 @@ You are an expert at extracting context from a provided system message.
2. The user's age: 30
</input>
<output>
```json
[
{
"name": "The name of the user,
Expand All @@ -339,11 +343,15 @@ You are an expert at extracting context from a provided system message.
"value": "30"
}
]
```
</output>
</example>
</examples>

<existingContext>
```json
{{JsonSerializer.Serialize(input.Context, _jsonSerOpts)}}
```
</existingContext>

<providedSystemMessage>
Expand Down Expand Up @@ -371,21 +379,26 @@ You are an expert at extracting context from a provided system message.
}

/// <summary>
/// When overridden in a derived class, allows for customisation of how the context is injected into the system message.
/// When overridden in a derived class, allows for customisation of the system message used for the agent run.
/// </summary>
/// <remarks>
/// This is only used when the system message is overridden by the agent options. Default behaviour is to just append JSON-serialized context to the message.
/// This is only used when the system message is overridden by the agent options. Default behaviour is to honour the agent option for including context in the system message, falling back to the provided system message if not set.
/// </remarks>
/// <param name="input">The input to the agent for the run</param>
/// <param name="systemMessage">The system message to use</param>
/// <param name="context">The final context prepared for the agent</param>
/// <returns>The final system message to use</returns>
protected virtual ValueTask<string> InjectContextIntoSystemMessage(
protected virtual ValueTask<string> PrepareSystemMessage(
RunAgentInput input,
string systemMessage,
ImmutableList<Context> context
)
{
if (!_agentOptions.IncludeContextInSystemMessage)
{
return ValueTask.FromResult(systemMessage);
}

return ValueTask.FromResult(
$"{systemMessage}\n\nThe following context is available to you:\n```{JsonSerializer.Serialize(context, _jsonSerOpts)}```"
);
Expand Down Expand Up @@ -415,9 +428,7 @@ [.. input.Messages.Where(m => m is not SystemMessage)
.Prepend(new SystemMessage
{
Id = Guid.NewGuid().ToString(),
Content = _agentOptions.IncludeContextInSystemMessage
? await InjectContextIntoSystemMessage(input, sysMessage, context)
: sysMessage,
Content = await PrepareSystemMessage(input, sysMessage, context)
})],

// Fallback to just preserving inbound messages as-is.
Expand Down Expand Up @@ -454,14 +465,38 @@ protected virtual ValueTask<ImmutableList<FrontendTool>> PrepareFrontendTools(
/// </summary>
/// <param name="backendTools">The backend tools already available via the provided chat client options</param>
/// <param name="input">The run input provided to the agent</param>
/// <param name="events">The events channel writer to push AG-UI events into
/// <returns>The backend tools to make available to the agent for the run</returns>
protected virtual ValueTask<ImmutableList<AIFunction>> PrepareBackendTools(
ImmutableList<AIFunction> backendTools,
RunAgentInput input,
ChannelWriter<BaseEvent> events,
CancellationToken cancellationToken = default
)
{
// By default, zero modification.
return ValueTask.FromResult(backendTools);
}

/// <summary>
/// When overridden in a derived class, allows for customisation of the handling for a run starting.
/// </summary>
/// <remarks>
/// The default implementation will emit a <see cref="RunStartedEvent"/> to the provided events channel.
/// </remarks>
/// <param name="input">The input to the agent for the current run.</param>
/// <param name="events">The events channel writer to push events into</param>
protected virtual async ValueTask OnRunStartedAsync(
RunAgentInput input,
ChannelWriter<BaseEvent> events,
CancellationToken cancellationToken = default
)
{
await events.WriteAsync(new RunStartedEvent
{
ThreadId = input.ThreadId,
RunId = input.RunId,
Timestamp = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds()
}, cancellationToken);
}
}
131 changes: 131 additions & 0 deletions dotnet-sdk/AGUIDotnet/Agent/StatefulChatClientAgent.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
using System.Collections.Immutable;
using System.Text.Json;
using System.Threading.Channels;
using AGUIDotnet.Events;
using AGUIDotnet.Types;
using Json.Patch;
using Microsoft.Extensions.AI;

namespace AGUIDotnet.Agent;

public record StatefulChatClientAgentOptions<TState> : ChatClientAgentOptions where TState : notnull
{
}

/// <summary>
/// Much like <see cref="ChatClientAgent"/> but tailored for scenarios where the agent and frontend collaborate on shared state.
/// </summary>
/// <remarks>
/// This agent is NOT guaranteed to be thread-safe, nor is it resilient to shared use across multiple threads / runs, a separate instance should be used for each invocation.
/// </remarks>
/// <typeparam name="TState"></typeparam>
public class StatefulChatClientAgent<TState> : ChatClientAgent where TState : notnull
{
private TState _currentState = default!;

public StatefulChatClientAgent(IChatClient chatClient, TState initialState, StatefulChatClientAgentOptions<TState> agentOptions) : base(chatClient, agentOptions)
{
if (agentOptions?.SystemMessage is null)
{
throw new ArgumentException("System message must be provided for a stateful agent.", nameof(agentOptions));
}

_currentState = initialState;
}

private TState RetrieveState()
{
return _currentState;
}

private void UpdateState(TState newState)
{
_currentState = newState;
}

protected override async ValueTask<string> PrepareSystemMessage(RunAgentInput input, string systemMessage, ImmutableList<Context> context)
{
var coreMessage = await base.PrepareSystemMessage(input, systemMessage, context);

// Hijack the original system message to include some context to the LLM about the stateful nature of this agent.
// Nudging it to use the state collaboration tools available to it.
return $"""
<persona>
You are a stateful agent that wraps an existing agent, allowing it to collaborate with a human in the frontend on shared state to achieve a goal.
</persona>

<tools>
You may have a variety of tools available to you to help achieve your goal, and state collaboration is one of them.

You can retrieve the current shared state of the agent using the `retrieve_state` tool, and update the shared state using the `update_state` tool.
</tools>

<rules>
- Wherever necessary (e.g. it is aligned with your stated goal), you MUST make use of the state collaboration tools.
- Inspect the state of the agent to understand both the current state and the schema / purpose of the state in alignment with the agent's goal.
- Liberally use the `update_state` tool to update the shared state as you progress towards your goal.
- Avoid making assumptions about the state, always retrieve it first.
- Avoid making unnecessary updates to the state, e.g. if the user intent does not require it.
</rules>

<underlying_agent>
{coreMessage}
</underlying_agent>
""";
}

protected override async ValueTask<ImmutableList<AIFunction>> PrepareBackendTools(ImmutableList<AIFunction> backendTools, RunAgentInput input, ChannelWriter<BaseEvent> events, CancellationToken cancellationToken = default)
{
return [
.. await base.PrepareBackendTools(backendTools, input, events, cancellationToken),
AIFunctionFactory.Create(
RetrieveState,
name: "retrieve_state",
description: "Retrieves the current shared state of the agent."
),
AIFunctionFactory.Create(
async (TState newState) => {
var delta = _currentState.CreatePatch(newState, _jsonSerOpts);
if (delta.Operations.Count > 0) {
UpdateState(newState);
await events.WriteAsync(new StateDeltaEvent {
Delta = [.. delta.Operations.Cast<object>()],
Timestamp = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(),
}, cancellationToken);
}
},
name: "update_state",
description: "Updates the current shared state of the agent."
)
];
}

protected override async ValueTask OnRunStartedAsync(RunAgentInput input, ChannelWriter<BaseEvent> events, CancellationToken cancellationToken = default)
{
// Allow the base behaviour of emitting the RunStartedEvent
await base.OnRunStartedAsync(input, events, cancellationToken);

// Take the initial state from the input if possible
try
{
if (input.State.ValueKind == JsonValueKind.Object)
{
var state = input.State.Deserialize<TState>(_jsonSerOpts);
if (state is not null)
{
_currentState = state;
}
}
}
catch (JsonException)
{

Comment on lines +189 to +191
Copy link

Copilot AI May 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The empty catch block may hide potential JSON deserialization issues; consider adding a comment or logging the exception to aid future troubleshooting.

Suggested change
catch (JsonException)
{
catch (JsonException ex)
{
// Log the exception to aid troubleshooting during JSON deserialization
Console.Error.WriteLine($"JSON deserialization error: {ex.Message}");

Copilot uses AI. Check for mistakes.
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is deliberate, for now it's just to avoid failure to extract the context causing an actual problem, it just swallows the exception, but we could perhaps surface it somehow so the consumer decides what behaviour to exhibit.

}

await events.WriteAsync(new StateSnapshotEvent
{
Snapshot = _currentState,
Timestamp = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(),
}, cancellationToken);
}
}
Loading