Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
hitl working with thread
  • Loading branch information
lusu-msft committed Jan 9, 2026
commit daf2cdce4ef561134ea15d81591192772eaf4d21
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Optional, Protocol, Union, List
import inspect

from agent_framework import AgentProtocol, AIFunction
from agent_framework import AgentProtocol, AIFunction, InMemoryCheckpointStorage
from agent_framework.azure import AzureAIClient # pylint: disable=no-name-in-module
from opentelemetry import trace

Expand Down Expand Up @@ -87,6 +87,8 @@ def __init__(self, agent: Union[AgentProtocol, AgentFactory],
self._agent_or_factory: Union[AgentProtocol, AgentFactory] = agent
self._resolved_agent: "Optional[AgentProtocol]" = None
self._hitl_helper = HumanInTheLoopHelper()
self._checkpoint_storage = InMemoryCheckpointStorage()
self._agent_thread_in_memory = {}
# If agent is already instantiated, use it directly
if isinstance(agent, AgentProtocol):
self._resolved_agent = agent
Expand Down Expand Up @@ -189,9 +191,13 @@ def init_tracing(self):
self.tracer = trace.get_tracer(__name__)

def setup_tracing_with_azure_ai_client(self, project_endpoint: str):
logger.info("Setting up tracing with AzureAIClient")
logger.info(f"Project endpoint for tracing credential: {self.credentials}")
async def setup_async():
async with AzureAIClient(
project_endpoint=project_endpoint, async_credential=self.credentials
project_endpoint=project_endpoint,
async_credential=self.credentials,
credential=self.credentials,
) as agent_client:
await agent_client.setup_azure_ai_observability()

Expand Down Expand Up @@ -225,9 +231,20 @@ async def agent_run( # pylint: disable=too-many-statements

logger.info(f"Starting agent_run with stream={context.stream}")
request_input = context.request.get("input")

input_converter = AgentFrameworkInputConverter(agent=agent, hitl_helper=self._hitl_helper)
message = input_converter.transform_input(request_input)
# TODO: load agent thread from storage and deserialize
agent_thread = self._agent_thread_in_memory.get(context.conversation_id, agent.get_new_thread())

last_checkpoint = None
if self._checkpoint_storage:
checkpoints = await self._checkpoint_storage.list_checkpoints()
last_checkpoint = checkpoints[-1] if len(checkpoints) > 0 else None
logger.info(f"Last checkpoint data: {last_checkpoint.to_dict() if last_checkpoint else 'None'}")

input_converter = AgentFrameworkInputConverter(hitl_helper=self._hitl_helper)
message = await input_converter.transform_input(
request_input,
agent_thread=agent_thread,
checkpoint=last_checkpoint)
logger.debug(f"Transformed input message type: {type(message)}")

# Use split converters
Expand All @@ -238,13 +255,23 @@ async def agent_run( # pylint: disable=too-many-statements
async def stream_updates():
try:
update_count = 0
updates = agent.run_stream(message)
updates = agent.run_stream(
message,
thread=agent_thread,
checkpoint_storage=self._checkpoint_storage,
checkpoint_id=last_checkpoint.checkpoint_id if last_checkpoint else None,
)
async for event in streaming_converter.convert(updates):
update_count += 1
yield event


if agent_thread:
self._agent_thread_in_memory[context.conversation_id] = agent_thread
logger.info("Streaming completed with %d updates", update_count)
finally:
if hasattr(agent, "pending_requests"):
logger.info("Clearing agent pending requests after streaming completed")
agent.pending_requests.clear()
# Close tool_client if it was created for this request
if tool_client is not None:
try:
Expand All @@ -258,8 +285,14 @@ async def stream_updates():
# Non-streaming path
logger.info("Running agent in non-streaming mode")
non_streaming_converter = AgentFrameworkOutputNonStreamingConverter(context, hitl_helper=self._hitl_helper)
result = await agent.run(message)
logger.debug(f"Agent run completed, result type: {type(result)}")
result = await agent.run(message,
thread=agent_thread,
checkpoint_storage=self._checkpoint_storage,
checkpoint_id=last_checkpoint.checkpoint_id if last_checkpoint else None,
)
logger.info(f"Agent run completed, result type: {type(result)}")
if agent_thread:
self._agent_thread_in_memory[context.conversation_id] = agent_thread
transformed_result = non_streaming_converter.transform_output_for_response(result)
logger.info("Agent run and transformation completed successfully")
return transformed_result
Expand All @@ -281,3 +314,6 @@ async def oauth_consent_stream(error=e):
logger.debug("Closed tool_client after request processing")
except Exception as ex: # pylint: disable=broad-exception-caught
logger.warning(f"Error closing tool_client: {ex}")
if not context.stream and hasattr(agent, "pending_requests"):
logger.info("Clearing agent pending requests after streaming completed")
# agent.pending_requests.clear()
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@

from typing import Dict, List, Optional

from agent_framework import ChatMessage, RequestInfoEvent, Role as ChatRole
from agent_framework import (
AgentThread,
ChatMessage,
RequestInfoEvent,
Role as ChatRole,
WorkflowCheckpoint,
)
from agent_framework._types import TextContent

from azure.ai.agentserver.core.logger import get_logger
Expand All @@ -21,13 +27,14 @@ class AgentFrameworkInputConverter:
Accepts: str | List | None
Returns: None | str | ChatMessage | list[str] | list[ChatMessage]
"""
def __init__(self, *, agent, hitl_helper=None):
self._agent = agent
def __init__(self, *, hitl_helper=None) -> None:
self._hitl_helper = hitl_helper

def transform_input(
async def transform_input(
self,
input: str | List[Dict] | None,
agent_thread: Optional[AgentThread] = None,
checkpoint: Optional[WorkflowCheckpoint] = None,
) -> str | ChatMessage | list[str] | list[ChatMessage] | None:
logger.debug("Transforming input of type: %s", type(input))

Expand All @@ -37,11 +44,21 @@ def transform_input(
if isinstance(input, str):
return input

pending_requests = getattr(self._agent, 'pending_requests', {})
if self._hitl_helper and pending_requests:
hitl_response = self._validate_and_convert_hitl_response(pending_requests, input)
if self._hitl_helper:
# load pending requests from checkpoint and thread messages if available
thread_messages = []
if agent_thread:
thread_messages = await agent_thread.message_store.list_messages()
logger.info(f"Thread messages count: {len(thread_messages)}")
pending_hitl_requests = self._hitl_helper.get_pending_hitl_request(thread_messages, checkpoint)
logger.info(f"Pending HitL requests: {list(pending_hitl_requests.keys())}")
hitl_response = self._hitl_helper.validate_and_convert_hitl_response(
input,
pending_requests=pending_hitl_requests)
logger.info(f"HitL response validation result: {[m.to_dict() for m in hitl_response]}")
if hitl_response:
return hitl_response

return self._transform_input_internal(input)

def _transform_input_internal(
Expand Down Expand Up @@ -157,7 +174,9 @@ def _validate_and_convert_hitl_response(
logger.warning("Function call output missing valid call_id for HitL response validation.")
return None
request_info = pending_request[call_id]
if not request_info or not isinstance(request_info, RequestInfoEvent):
if isinstance(request_info, dict):
request_info = RequestInfoEvent.from_dict(request_info)
if not isinstance(request_info, RequestInfoEvent):
logger.warning("No valid pending request info found for call_id: %s", call_id)
return None

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from typing import Any, Optional

from agent_framework import AgentThread, BaseAgent


class AgentStateInventory:
    """Base inventory for persisting per-conversation agent/workflow state.

    Subclasses provide real storage; this base implementation stores nothing
    and always yields ``None``.
    """

    async def get(self, conversation_id: str) -> Optional[Any]:
        """Return the saved state for *conversation_id*, or ``None`` if absent.

        Args:
            conversation_id (str): The conversation ID.
        """
        return None

    async def set(self, conversation_id: str, state: Any) -> None:
        """Persist *state* under *conversation_id* (no-op in the base class).

        Args:
            conversation_id (str): The conversation ID.
            state (Any): The state to save.
        """
        return None


class InMemoryThreadAgentStateInventory(AgentStateInventory):
    """In-memory ``AgentStateInventory`` for agent threads.

    Threads are serialized on :meth:`set` and rehydrated through the owning
    agent on :meth:`get`, so the stored value is the serialized payload
    produced by ``AgentThread.serialize()`` — not a live ``AgentThread``.
    """

    def __init__(self, agent: BaseAgent) -> None:
        self._agent = agent
        # Maps conversation_id -> serialized thread payload.
        # NOTE: fixed annotation — values are serialized data (Any), not
        # AgentThread instances, since set() stores state.serialize() output.
        self._inventory: dict[str, Any] = {}

    async def get(self, conversation_id: str) -> Optional[AgentThread]:
        """Retrieve and deserialize the saved thread for a conversation.

        Args:
            conversation_id (str): The conversation ID.

        Returns:
            Optional[AgentThread]: A rehydrated thread, or ``None`` when no
            state was saved for this conversation.
        """
        # Single lookup instead of `in` check followed by subscript.
        serialized_thread = self._inventory.get(conversation_id)
        if serialized_thread is None:
            return None
        return await self._agent.deserialize_thread(serialized_thread)

    async def set(self, conversation_id: str, state: AgentThread) -> None:
        """Serialize and save a thread for a conversation.

        Falsy *conversation_id* or *state* makes this a no-op.

        Args:
            conversation_id (str): The conversation ID.
            state (AgentThread): The live thread to serialize and save.
        """
        if conversation_id and state:
            serialized_thread = await state.serialize()
            self._inventory[conversation_id] = serialized_thread


class InMemoryCheckpointAgentStateInventory(AgentStateInventory):
    """In-memory implementation of ``AgentStateInventory`` for workflow checkpoints."""

    def __init__(self) -> None:
        # conversation_id -> checkpoint state (stored as-is, no serialization).
        self._inventory: dict[str, Any] = {}

    async def get(self, conversation_id: str) -> Optional[Any]:
        """Look up the checkpoint state saved for a conversation.

        Args:
            conversation_id (str): The conversation ID.

        Returns:
            Optional[Any]: The saved state, or ``None`` if nothing was saved.
        """
        return self._inventory.get(conversation_id)

    async def set(self, conversation_id: str, state: Any) -> None:
        """Save checkpoint state for a conversation.

        Args:
            conversation_id (str): The conversation ID.
            state (Any): The state to save; falsy values are ignored.
        """
        if not conversation_id or not state:
            # Guard clause: mirror the original's silent skip of falsy args.
            return
        self._inventory[conversation_id] = state
Original file line number Diff line number Diff line change
@@ -1,14 +1,65 @@
from typing import Any, List, Dict
from typing import Any, List, Dict, Optional, Union
import json

from agent_framework import ChatMessage, FunctionResultContent, RequestInfoEvent
from agent_framework import (
ChatMessage,
FunctionResultContent,
FunctionApprovalResponseContent,
RequestInfoEvent,
WorkflowCheckpoint,
)
from agent_framework._types import UserInputRequestContents

from azure.ai.agentserver.core.logger import get_logger
from azure.ai.agentserver.core.server.common.constants import HUMAN_IN_THE_LOOP_FUNCTION_NAME

logger = get_logger()

class HumanInTheLoopHelper:

def get_pending_hitl_request(self,
thread_messages: List[ChatMessage] = None,
checkpoint: Optional[WorkflowCheckpoint] = None,
) -> dict[str, Union[RequestInfoEvent, Any]]:
res = {}
# if has checkpoint (WorkflowAgent), find pending request info from checkpoint
if checkpoint and checkpoint.pending_request_info_events:
for call_id, request in checkpoint.pending_request_info_events.items():
# find if the request is already responded in the thread messages
if isinstance(request, dict):
request_obj = RequestInfoEvent.from_dict(request)
res[call_id] = request_obj
return res

if not thread_messages:
return res

# if no checkpoint (Agent), find user input request and pair the feedbacks
for message in thread_messages:
for content in message.contents:
print(f" Content {type(content)}: {content.to_dict()}")
if isinstance(content, UserInputRequestContents):
# is a human input request
function_call = content.function_call
call_id = getattr(function_call, "call_id", "")
if call_id:
res[call_id] = RequestInfoEvent(
source_executor_id="agent",
request_id=call_id,
response_type=None,
request_data=function_call,
)
elif isinstance(content, FunctionResultContent):
if content.call_id and content.call_id in res:
# remove requests that already got feedback
res.pop(content.call_id)
elif isinstance(content, FunctionApprovalResponseContent):
function_call = content.function_call
call_id = getattr(function_call, "call_id", "")
if call_id and call_id in res:
res.pop(call_id)
return res

def convert_user_input_request_content(self, content: UserInputRequestContents) -> dict:
function_call = content.function_call
call_id = getattr(function_call, "call_id", "")
Expand All @@ -35,13 +86,34 @@ def convert_request_arguments(self, arguments: Any) -> str:
arguments = str(arguments)
return arguments

def convert_response(self, hitl_request: RequestInfoEvent, input: Dict) -> List[ChatMessage]:
def validate_and_convert_hitl_response(self,
input: str | List[Dict] | None,
pending_requests: Dict[str, RequestInfoEvent],
) -> List[ChatMessage] | None:

if input is None or isinstance(input, str):
logger.warning("Expected list input for HitL response validation, got str.")
return None

res = []
for item in input:
if item.get("type") != "function_call_output":
logger.warning("Expected function_call_output type for HitL response validation.")
return None
call_id = item.get("call_id", None)
if call_id and call_id in pending_requests:
res.append(self.convert_response(pending_requests[call_id], item))
return res

def convert_response(self, hitl_request: RequestInfoEvent, input: Dict) -> ChatMessage:
response_type = hitl_request.response_type
response_result = input.get("output", "")
logger.info(f"response_type {type(response_type)}: %s", response_type)
if response_type and hasattr(response_type, "convert_from_payload"):
response_result = response_type.convert_from_payload(input.get("output", ""))
logger.info(f"response_result {type(response_result)}: %s", response_result)
response_content = FunctionResultContent(
call_id=hitl_request.request_id,
result=response_result,
)
return [ChatMessage(role="tool", contents=[response_content])]
return ChatMessage(role="tool", contents=[response_content])
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
AZURE_OPENAI_ENDPOINT=https://<endpoint-name>.cognitiveservices.azure.com/
OPENAI_API_VERSION=2025-03-01-preview
AZURE_OPENAI_CHAT_DEPLOYMENT_NAME=<deployment-name>
Loading