diff --git a/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/__init__.py b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/__init__.py index 95a557e07bf5..514de6f5d387 100644 --- a/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/__init__.py +++ b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/__init__.py @@ -5,6 +5,5 @@ _template_dir = os.path.join(os.path.dirname(__file__), 'templates') from .simulator.simulator import Simulator -from .templates.simulator_templates import SimulatorTemplates -__all__ = ["Simulator", "SimulatorTemplates"] \ No newline at end of file +__all__ = ["Simulator"] \ No newline at end of file diff --git a/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/_conversation/conversation.py b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/_conversation/conversation.py index 28c100deaff0..7c391cda5ee6 100644 --- a/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/_conversation/conversation.py +++ b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/_conversation/conversation.py @@ -12,12 +12,26 @@ from .constants import ConversationRole -def is_closing_message(response: str): +def is_closing_message(response:Any, recursion_depth: int = 0): + if (recursion_depth > 10): + raise Exception("Exceeded max call depth in is_closing_message") + + # recursively go through each inner dictionary in the JSON dict and check if any value entry contains a closing message + if type(response) is dict: + for value in response.values(): + if is_closing_message(value, recursion_depth=recursion_depth+1): + return True + elif type(response) is str: + return is_closing_message_helper(response) + + return False + +def is_closing_message_helper(response: str): message = response.lower() if "?" in message.lower(): return False - punctuation = [".", ",", "!", ";", ":"] - for p in punctuation: + punc = [".", ",", "!", ";", ":"] + for p in punc: message = message.replace(p, "") if ( "bye" not in message.lower().split() @@ -36,9 +50,7 @@ async def simulate_conversation( history_limit: int = 5, api_call_delay_sec: float = 0, logger: logging.Logger = logging.getLogger(__name__), - mlflow_logger: Optional[Any] = None, - template_paramaters: Optional[dict] = None, - simulate_callback: Optional[Callable[[str, Sequence[Union[Dict, ConversationTurn]], Optional[dict]], str]] = None, + mlflow_logger=None, ): """ Simulate a conversation between the given bots. @@ -82,45 +94,30 @@ async def simulate_conversation( (current_turn < turn_limit) ): try: - current_character_idx = current_turn % 2 - # if there is only one bot, means using customized simulate callback - # in the customer bot turn, instead of using the customer bot, need to invoke the simulate callback - if len(bots) < 2 and current_character_idx == 1: - question = conversation_history[-1].message - # TODO: Fix Bug 2816997 - response = await simulate_callback(question, conversation_history, template_paramaters) # type: ignore[misc] - # add the generated response to the list of generated responses - conversation_history.append( - ConversationTurn( - role=ConversationRole.ASSISTANT, - name="ChatBot", - message=response, - )) - else: - current_bot = bots[current_character_idx] - # invoke Bot to generate response given the input request - logger.info(f"-- Sending to {current_bot.role.value}") - # pass only the last generated turn without passing the bot name. - response, request, time_taken, full_response = await current_bot.generate_response( - session=session, - conversation_history=conversation_history, - max_history=history_limit, - turn_number=current_turn, - ) - # add the generated response to the list of generated responses - conversation_history.append( - ConversationTurn( - role=current_bot.role, - name=current_bot.name, - message=response["samples"][0], - full_response=full_response, - request=request, - )) + current_character_idx = current_turn % len(bots) + current_bot = bots[current_character_idx] + # invoke Bot to generate response given the input request + logger.info(f"-- Sending to {current_bot.role.value}") + # pass only the last generated turn without passing the bot name. + response, request, time_taken, full_response = await current_bot.generate_response( + session=session, + conversation_history=conversation_history, + max_history=history_limit, + turn_number=current_turn, + ) # check if conversation id is null, which means conversation starter was used. use id from next turn if conversation_id is None and 'id' in response: conversation_id = response["id"] - + # add the generated response to the list of generated responses + conversation_history.append( + ConversationTurn( + role=current_bot.role, + name=current_bot.name, + message=response["samples"][0], + full_response=full_response, + request=request, + )) logger.info(f"Last turn: {conversation_history[-1]}") if mlflow_logger is not None: logger_tasks.append( # schedule logging but don't get blocked by it @@ -129,8 +126,7 @@ async def simulate_conversation( ) ) except Exception as e: - logger.warning(f"Error: {e}") - raise e + logger.warning("Error:" + str(e)) if mlflow_logger is not None: logger_tasks.append( # schedule logging but don't get blocked by it asyncio.create_task(mlflow_logger.log_error()) diff --git a/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/_conversation/conversation_bot.py b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/_conversation/conversation_bot.py index 8737f357af20..f3ac70480602 100644 --- a/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/_conversation/conversation_bot.py +++ b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/_conversation/conversation_bot.py @@ -38,6 +38,7 @@ def __init__( self.role = role + self.conversation_template_orig = conversation_template self.conversation_template: jinja2.Template = jinja2.Template( conversation_template, undefined=jinja2.StrictUndefined ) @@ -50,17 +51,25 @@ def __init__( self.logger = logging.getLogger(repr(self)) + self.conversation_starter = None # can either be a dictionary or jinja template if role == ConversationRole.USER: if "conversation_starter" in self.persona_template_args: - self.logger.info( - 'This simulated bot will use the provided conversation starter ' - f'"{repr(self.persona_template_args["conversation_starter"])[:400]}"' - 'instead of generating a turn using a LLM' - ) - self.conversation_starter = self.persona_template_args["conversation_starter"] + conversation_starter_content = self.persona_template_args["conversation_starter"] + if type(conversation_starter_content) is dict: + self.logger.info(f'This simulated bot will use the provided conversation starter (passed in as dictionary): {conversation_starter_content} instead of generating a turn using a LLM') + self.conversation_starter = conversation_starter_content + else: + self.logger.info( + 'This simulated bot will use the provided conversation starter ' + f'{repr(conversation_starter_content)[:400]}' + ' instead of generating a turn using a LLM' + ) + self.conversation_starter = jinja2.Template( + conversation_starter_content, undefined=jinja2.StrictUndefined + ) else: self.logger.info('This simulated bot will generate the first turn as no conversation starter is provided') - self.conversation_starter = "" + async def generate_response( @@ -88,11 +97,16 @@ async def generate_response( # check if this is the first turn and the conversation_starter is not None, # return the conversations starter rather than generating turn using LLM - if turn_number == 0 and self.conversation_starter is not None and self.conversation_starter != "": - self.logger.info(f"Returning conversation starter: {self.conversation_starter}") + if turn_number == 0 and self.conversation_starter is not None: + # if conversation_starter is a dictionary, pass it into samples as is + if type(self.conversation_starter) is dict: + self.logger.info(f"Returning conversation starter: {self.conversation_starter}") + samples = [self.conversation_starter] + else: + self.logger.info(f"Returning conversation starter: {repr(self.persona_template_args['conversation_starter'])[:400]}") + samples = [self.conversation_starter.render(**self.persona_template_args)] # type: ignore[attr-defined] time_taken = 0 - samples = [self.conversation_starter] finish_reason = ["stop"] parsed_response = { @@ -103,11 +117,15 @@ async def generate_response( full_response = parsed_response return parsed_response, {}, time_taken, full_response - prompt = self.conversation_template.render( - conversation_turns=conversation_history[-max_history:], - role=self.role.value, - **self.persona_template_args - ) + try: + prompt = self.conversation_template.render( + conversation_turns=conversation_history[-max_history:], + role=self.role.value, + **self.persona_template_args + ) + except: + import code + code.interact(local=locals()) messages = [{"role": "system", "content": prompt}] diff --git a/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/_model_tools/models.py b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/_model_tools/models.py index 99e4c78f683a..b1af023cfdf6 100644 --- a/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/_model_tools/models.py +++ b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/_model_tools/models.py @@ -40,7 +40,7 @@ def get_model_class_from_url(endpoint_url: str): # ===================== HTTP Retry ====================== class AsyncHTTPClientWithRetry: - def __init__(self, n_retry, retry_timeout, logger): + def __init__(self, n_retry, retry_timeout, logger, retry_options=None): self.attempts = n_retry self.logger = logger @@ -49,14 +49,14 @@ def __init__(self, n_retry, retry_timeout, logger): trace_config = TraceConfig() # set up request logging trace_config.on_request_start.append(self.on_request_start) trace_config.on_request_end.append(self.on_request_end) - - retry_options = RandomRetry( # set up retry configuration - statuses=[104, 408, 409, 424, 429, 500, 502, - 503, 504], # on which statuses to retry - attempts=n_retry, - min_timeout=retry_timeout, - max_timeout=retry_timeout, - ) + if retry_options is None: + retry_options = RandomRetry( # set up retry configuration + statuses=[104, 408, 409, 424, 429, 500, 502, + 503, 504], # on which statuses to retry + attempts=n_retry, + min_timeout=retry_timeout, + max_timeout=retry_timeout, + ) self.client = RetryClient( trace_configs=[trace_config], retry_options=retry_options) @@ -641,6 +641,7 @@ def _parse_response(self, response_data: dict, request_data: Optional[dict] = No # https://platform.openai.com/docs/api-reference/chat samples = [] finish_reason = [] + for choice in response_data["choices"]: if 'message' in choice and 'content' in choice['message']: samples.append(choice['message']['content']) diff --git a/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/_rai_rest_client/__init__.py b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/_rai_rest_client/__init__.py new file mode 100644 index 000000000000..d540fd20468c --- /dev/null +++ b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/_rai_rest_client/__init__.py @@ -0,0 +1,3 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- diff --git a/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/_rai_rest_client/rai_client.py b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/_rai_rest_client/rai_client.py new file mode 100644 index 000000000000..3dfd11958a1d --- /dev/null +++ b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/_rai_rest_client/rai_client.py @@ -0,0 +1,77 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from azure.ai.generative.synthetic.simulator._model_tools.models import ( + AsyncHTTPClientWithRetry, +) +from aiohttp_retry import JitterRetry +import logging + +import os + +api_url = None +if "rai_svc_url" in os.environ: + api_url = os.environ["rai_svc_url"] + api_url = api_url.rstrip("/") + print( + f"Found rai_svc_url in environment variable, using {api_url} for rai service endpoint." + ) + + +class RAIClient: + def __init__(self, ml_client, token_manager): + self.ml_client = ml_client + self.token_manager = token_manager + + self.contentharm_parameters = None + self.jailbreaks_dataset = None + + if api_url is not None: + host = api_url + else: + host = self.ml_client.jobs._api_url + + self.api_url = ( + f"{host}/" + + f"raisvc/v1.0/subscriptions/{self.ml_client.subscription_id}/" + + f"resourceGroups/{self.ml_client.resource_group_name}/" + + f"providers/Microsoft.MachineLearningServices/workspaces/{self.ml_client.workspace_name}/" + ) + + self.parameter_json_endpoint = self.api_url + "simulation/template/parameters" + self.jailbreaks_json_endpoint = self.api_url + "simulation/jailbreak" + self.simulation_submit_endpoint = ( + self.api_url + "simulation/chat/completions/submit" + ) + + def _create_async_client(self): + return AsyncHTTPClientWithRetry( + n_retry=6, retry_timeout=5, logger=logging.getLogger() + ) + + async def get_contentharm_parameters(self): + if self.contentharm_parameters is None: + self.contentharm_parameters = await self.get(self.parameter_json_endpoint) + + return self.contentharm_parameters + + async def get_jailbreaks_dataset(self): + if self.jailbreaks_dataset is None: + self.jailbreaks_dataset = await self.get(self.jailbreaks_json_endpoint) + + return self.jailbreaks_dataset + + async def get(self, url): + token = await self.token_manager.get_token() + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + } + + async with self._create_async_client().client as session: + async with session.get(url=url, headers=headers) as response: + if response.status == 200: + response = await response.json() + return response + + raise ValueError("Unable to retrieve requested resource from rai service.") diff --git a/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/simulator/_callback_conversation_bot.py b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/simulator/_callback_conversation_bot.py new file mode 100644 index 000000000000..532ffb605803 --- /dev/null +++ b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/simulator/_callback_conversation_bot.py @@ -0,0 +1,67 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from azure.ai.generative.synthetic.simulator._conversation import ( + ConversationBot, + ConversationRole, + ConversationTurn, + simulate_conversation, +) + +import copy +from typing import List, Tuple + + +class CallbackConversationBot(ConversationBot): + def __init__( + self, callback, user_template, user_template_parameters, *args, **kwargs + ): + self.callback = callback + self.user_template = user_template + self.user_template_parameters = user_template_parameters + + super().__init__(*args, **kwargs) + + async def generate_response( + self, + session: "RetryClient", # type: ignore[name-defined] + conversation_history: List[ConversationTurn], + max_history: int, + turn_number: int = 0, + ) -> Tuple[dict, dict, int, dict]: + chat_protocol_message = self._to_chat_protocol( + self.user_template, conversation_history, self.user_template_parameters + ) + msg_copy = copy.deepcopy(chat_protocol_message) + result = await self.callback(msg_copy) + + self.logger.info(f"Using user provided callback returning response.") + + time_taken = 0 + try: + response = { + "samples": [result["messages"][-1]["content"]], + "finish_reason": ["stop"], + "id": None, + } + except: + raise TypeError( + "User provided callback do not conform to chat protocol standard." + ) + + self.logger.info(f"Parsed callback response") + + return response, {}, time_taken, response + + def _to_chat_protocol(self, template, conversation_history, template_parameters): + messages = [] + + for i, m in enumerate(conversation_history): + messages.append({"content": m.message, "role": m.role.value}) + + return { + "template_parameters": template_parameters, + "messages": messages, + "$schema": "http://azureml/sdk-2-0/ChatConversation.json", + } diff --git a/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/simulator/_proxy_completion_model.py b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/simulator/_proxy_completion_model.py new file mode 100644 index 000000000000..7c82c01c4114 --- /dev/null +++ b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/simulator/_proxy_completion_model.py @@ -0,0 +1,165 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from aiohttp_retry import RetryClient, RandomRetry, JitterRetry +from aiohttp.web import HTTPException +from azure.ai.generative.synthetic.simulator._model_tools.models import ( + OpenAIChatCompletionsModel, + AsyncHTTPClientWithRetry, +) +from typing import List +import uuid + +from azure.ai.generative.synthetic.simulator.simulator._simulation_request_dto import ( + SimulationRequestDTO, +) + +import time +import logging +import copy + +import asyncio + + +class ProxyChatCompletionsModel(OpenAIChatCompletionsModel): + def __init__(self, name, template_key, template_parameters, *args, **kwargs): + self.tkey = template_key + self.tparam = template_parameters + self.result_url = None + + super().__init__(name=name, *args, **kwargs) + + def format_request_data(self, messages: List[dict], **request_params): # type: ignore[override] + request_data = {"messages": messages, **self.get_model_params()} + request_data.update(request_params) + return request_data + + async def get_conversation_completion( + self, + messages: List[dict], + session: RetryClient, + role: str = "assistant", + **request_params, + ) -> dict: + """ + Query the model a single time with a message. + + Parameters + ---------- + messages: List of messages to query the model with. Expected format: [{"role": "user", "content": "Hello!"}, ...] + session: aiohttp RetryClient object to query the model with. + role: Not used for this model, since it is a chat model. + request_params: Additional parameters to pass to the model. + """ + request_data = self.format_request_data( + messages=messages, + **request_params, + ) + return await self.request_api( + session=session, + request_data=request_data, + ) + + async def request_api( + self, + session: RetryClient, + request_data: dict, + ) -> dict: + """ + Request the model with a body of data. + + Parameters + ---------- + session: HTTPS Session for invoking the endpoint. + request_data: Prompt dictionary to query the model with. (Pass {"prompt": prompt} instead of prompt.) + """ + + self._log_request(request_data) + + token = await self.token_manager.get_token() + + proxy_headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + } + + headers = { + "Content-Type": "application/json", + "X-CV": f"{uuid.uuid4()}", + "X-ModelType": self.model or "", + } + # add all additional headers + headers.update(self.additional_headers) # type: ignore[arg-type] + + params = {} + if self.api_version: + params["api-version"] = self.api_version + + sim_request_dto = SimulationRequestDTO( + url=self.endpoint_url, + headers=headers, + payload=request_data, + params=params, + templatekey=self.tkey, + template_parameters=self.tparam, + ) + + time_start = time.time() + full_response = None + + async with session.post( + url=self.endpoint_url, headers=proxy_headers, json=sim_request_dto.to_dict() + ) as response: + if response.status == 202: + response = await response.json() + self.result_url = response["location"] + else: + raise HTTPException( + reason=f"Received unexpected HTTP status: {response.status} {await response.text()}" + ) + + retry_options = JitterRetry( # set up retry configuration + statuses=[202], # on which statuses to retry + attempts=7, + start_timeout=10, + max_timeout=180, + retry_all_server_errors=False + ) + + exp_retry_client = AsyncHTTPClientWithRetry( + n_retry=None, + retry_timeout=None, + logger=logging.getLogger(), + retry_options=retry_options, + ) + + # initial 10 seconds wait before attempting to fetch result + await asyncio.sleep(10) + + async with exp_retry_client.client as expsession: + async with expsession.get( + url=self.result_url, headers=proxy_headers + ) as response: + if response.status == 200: + response_data = await response.json() + self.logger.info(f"Response: {response_data}") + + # Copy the full response and return it to be saved in jsonl. + full_response = copy.copy(response_data) + + time_taken = time.time() - time_start + + parsed_response = self._parse_response( + response_data, request_data=request_data + ) + else: + raise HTTPException( + reason=f"Received unexpected HTTP status: {response.status} {await response.text()}" + ) + + return { + "request": request_data, + "response": parsed_response, + "time_taken": time_taken, + "full_response": full_response, + } diff --git a/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/simulator/_simulation_request_dto.py b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/simulator/_simulation_request_dto.py new file mode 100644 index 000000000000..b81fef7dd650 --- /dev/null +++ b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/simulator/_simulation_request_dto.py @@ -0,0 +1,20 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import json + + +class SimulationRequestDTO: + def __init__(self, url, headers, payload, params, templatekey, template_parameters): + self.url = url + self.headers = headers + self.json = json.dumps(payload) + self.params = params + self.templatekey = templatekey + self.templateParameters = template_parameters + + def to_dict(self): + return self.__dict__ + + def to_json(self): + return json.dumps(self.__dict__) diff --git a/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/simulator/_token_manager.py b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/simulator/_token_manager.py new file mode 100644 index 000000000000..a26533592d7c --- /dev/null +++ b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/simulator/_token_manager.py @@ -0,0 +1,18 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from enum import Enum +from azure.ai.generative.synthetic.simulator._model_tools import APITokenManager + + +class TokenScope(Enum): + DEFAULT_AZURE_MANAGEMENT = "https://management.azure.com/.default" + + +class PlainTokenManager(APITokenManager): + def __init__(self, openapi_key, logger, **kwargs): + super().__init__(logger, **kwargs) + self.token = openapi_key + + async def get_token(self): + return self.token diff --git a/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/simulator/_utils.py b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/simulator/_utils.py new file mode 100644 index 000000000000..9b83c04132cf --- /dev/null +++ b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/simulator/_utils.py @@ -0,0 +1,12 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import json + + +class JsonLineList(list): + def to_json_lines(self): + json_lines = "" + for item in self: + json_lines += json.dumps(item) + "\n" + return json_lines diff --git a/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/simulator/simulator.py b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/simulator/simulator.py index 5d1ce6a2f4b2..ef6adbe96ca8 100644 --- a/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/simulator/simulator.py +++ b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/simulator/simulator.py @@ -2,7 +2,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -from typing import Callable, Dict, List, Union, Optional, Sequence +# needed for 'list' type annotations on 3.8 +from __future__ import annotations + +from typing import Callable, Dict, List, Union, Optional, Sequence, Any from azure.ai.generative.synthetic.simulator._conversation import ( ConversationBot, ConversationRole, @@ -14,14 +17,43 @@ AsyncHTTPClientWithRetry, ) from azure.ai.generative.synthetic.simulator import _template_dir as template_dir -from azure.ai.generative.synthetic.simulator._model_tools import APITokenManager, OpenAIChatCompletionsModel, LLMBase +from azure.ai.generative.synthetic.simulator._model_tools import ( + OpenAIChatCompletionsModel, + LLMBase, + ManagedIdentityAPITokenManager, +) +from azure.ai.generative.synthetic.simulator.templates.simulator_templates import ( + SimulatorTemplates, + Template, +) +from azure.ai.generative.synthetic.simulator.simulator._simulation_request_dto import ( + SimulationRequestDTO, +) +from azure.ai.generative.synthetic.simulator.simulator._token_manager import ( + PlainTokenManager, + TokenScope, +) +from azure.ai.generative.synthetic.simulator.simulator._proxy_completion_model import ( + ProxyChatCompletionsModel, +) + +from azure.ai.generative.synthetic.simulator.simulator._callback_conversation_bot import ( + CallbackConversationBot, +) + +from azure.ai.generative.synthetic.simulator._rai_rest_client.rai_client import ( + RAIClient, +) + +from azure.ai.generative.synthetic.simulator.simulator._utils import JsonLineList import logging import os import asyncio import threading import json +import random BASIC_MD = os.path.join(template_dir, "basic.md") # type: ignore[has-type] USER_MD = os.path.join(template_dir, "user.md") # type: ignore[has-type] @@ -30,112 +62,307 @@ class Simulator: def __init__( self, - systemConnection: Optional["AzureOpenAIModelConfiguration"] = None, # type: ignore[name-defined] - userConnection: Optional["AzureOpenAIModelConfiguration"] = None, # type: ignore[name-defined] - simulate_callback: Optional[Callable[[str, Sequence[Union[Dict, ConversationTurn]], Optional[Dict]], str]] = None, + simulator_connection: "AzureOpenAIModelConfiguration" = None, # type: ignore[name-defined] + ai_client: "AIClient" = None, # type: ignore[name-defined] + simulate_callback: Optional[Callable[[Dict], Dict]] = None, ): - self.userConnection = self._to_openai_chat_completion_model(userConnection) - self.systemConnection = self._to_openai_chat_completion_model(systemConnection) + """ + Initialize the instance with the given parameters. + + :keyword simulator_connection: An instance of AzureOpenAIModelConfiguration representing the connection + for simulating user response. Defaults to None. + :paramtype simulator_connection: Optional[AzureOpenAIModelConfiguration] + :keyword ai_client: An instance of AIClient for interacting with the AI service. Defaults to None. + :paramtype ai_client: Optional[AIClient] + :keyword simulate_callback: A callback function that takes a dictionary as input and returns a dictionary. + This function is called to simulate the assistant response. Defaults to None. + :paramtype simulate_callback: Optional[Callable[[Dict], Dict]] + + :raises ValueError: If both `simulator_connection` and `ai_client` are not provided (i.e., both are None). + """ + if (ai_client is None and simulator_connection is None) or ( + ai_client is not None and simulator_connection is not None + ): + raise ValueError( + "One and only one of the parameters [ai_client, simulator_connection] has to be set." + ) + + if simulate_callback is None: + raise ValueError("Callback cannot be None.") + + if not asyncio.iscoroutinefunction(simulate_callback): + raise ValueError("Callback has to be an async function.") + + self.ai_client = ai_client + self.simulator_connection = self._to_openai_chat_completion_model( + simulator_connection + ) + self.adversarial = False + self.rai_client = None + if ai_client: + self.ml_client = ai_client._ml_client + self.token_manager = ManagedIdentityAPITokenManager( + token_scope=TokenScope.DEFAULT_AZURE_MANAGEMENT, + logger=logging.getLogger("managed identity token manager"), + ) + self.rai_client = RAIClient(self.ml_client, self.token_manager) + self.template_handler = SimulatorTemplates(self.rai_client) + self.simulate_callback = simulate_callback - def _to_openai_chat_completion_model(self, config: "AzureOpenAIModelConfiguration"): # type: ignore[name-defined] - if config == None: + def _get_user_proxy_completion_model(self, tkey, tparam): + return ProxyChatCompletionsModel( + name="raisvc_proxy_model", + template_key=tkey, + template_parameters=tparam, + endpoint_url=self.rai_client.simulation_submit_endpoint, + token_manager=self.token_manager, + api_version="2023-07-01-preview", + max_tokens=1200, + temperature=0.0, + ) + + def _to_openai_chat_completion_model( + self, + config: "AzureOpenAIModelConfiguration" # type: ignore[name-defined] + ): + if config is None: return None token_manager = PlainTokenManager( openapi_key=config.api_key, auth_header="api-key", - logger=logging.getLogger(f"{config.deployment_name}_bot_token_manager") - ) + logger=logging.getLogger(f"{config.deployment_name}_bot_token_manager"), + ) return OpenAIChatCompletionsModel( endpoint_url=f"{config.api_base}openai/deployments/{config.deployment_name}/chat/completions", token_manager=token_manager, api_version=config.api_version, name=config.model_name, - **config.model_kwargs + **config.model_kwargs, ) - def create_bot( + def _create_bot( self, role: ConversationRole, conversation_template: str, instantiation_parameters: dict, - model: Union[LLMBase, OpenAIChatCompletionsModel], # type: ignore[arg-type] + adversarial_template_key: Optional[str] = None, + model: Union[LLMBase, OpenAIChatCompletionsModel] = None, # type: ignore[arg-type,assignment] ): - bot = ConversationBot( + if role == ConversationRole.USER and self.adversarial: + model = self._get_user_proxy_completion_model( + tkey=adversarial_template_key, tparam=instantiation_parameters + ) + + return ConversationBot( + role=role, + model=model, + conversation_template=conversation_template, + instantiation_parameters=instantiation_parameters, + ) + if role == ConversationRole.ASSISTANT: + dummy_model = lambda: None + dummy_model.name = "dummy_model" # type: ignore[attr-defined] + return CallbackConversationBot( + callback=self.simulate_callback, + role=role, + model=dummy_model, + user_template=conversation_template, + user_template_parameters=instantiation_parameters, + conversation_template="", + instantiation_parameters={}, + ) + + return ConversationBot( role=role, model=model, conversation_template=conversation_template, instantiation_parameters=instantiation_parameters, ) - return bot - def setup_bot( - self, - role: Union[str, ConversationRole], - template: str, - parameters: dict + def _setup_bot( + self, role: Union[str, ConversationRole], template: "Template", parameters: dict ): if role == ConversationRole.ASSISTANT: - with open(BASIC_MD, "r") as f: - chatbot_name_key = "chatbot_name" - assistant_template = f.read() - assistant_parameters = {chatbot_name_key: "ChatBot"} - if parameters.get(chatbot_name_key) is not None: - assistant_parameters[chatbot_name_key] = parameters[chatbot_name_key] - - return self.create_bot( - role, assistant_template, assistant_parameters, self.userConnection + return self._create_bot(role, str(template), parameters) + elif role == ConversationRole.USER: + if template.content_harm: + return self._create_bot( + role, str(template), parameters, template.template_name + ) + + return self._create_bot( + role, str(template), parameters, model=self.simulator_connection ) - elif role == ConversationRole.USER: - return self.create_bot(role, template, parameters, self.systemConnection) + def _ensure_service_dependencies(self): + if self.rai_client is None: + raise ValueError( + "Simulation options require rai services but ai client is not provided." + ) + + def _join_conversation_starter(self, parameters, to_join): + key = "conversation_starter" + if key in parameters.keys(): + parameters[key] = f"{to_join} {parameters[key]}" + else: + parameters[key] = to_join + + return parameters async def simulate_async( self, - template: str, - parameters: dict, + template: "Template", max_conversation_turns: int, + parameters: List[dict] = [], + jailbreak: bool = False, api_call_retry_max_count: int = 3, api_call_retry_sleep_sec: int = 1, api_call_delay_sec: float = 0, + concurrent_async_task: int = 3 ): - # create user bot - gpt_bot = self.setup_bot(ConversationRole.USER, str(template), parameters) + """Asynchronously simulate conversations using the provided template and parameters + + :keyword template: An instance of the Template class defining the conversation structure. + :paramtype template: Template + :keyword max_conversation_turns: The maximum number of conversation turns to simulate. + :paramtype max_conversation_turns: int + :keyword parameters: A list of dictionaries containing the parameter values to be used in the simulations. + Defaults to an empty list. + :paramtype parameters: list[dict], optional + :keyword jailbreak: If set to True, allows breaking out of the conversation flow defined by the template. + Defaults to False. + :paramtype jailbreak: bool, optional + :keyword api_call_retry_max_count: The maximum number of API call retries in case of errors. Defaults to 3. + :paramtype api_call_retry_max_count: int, optional + :keyword api_call_retry_sleep_sec: The time in seconds to wait between API call retries. Defaults to 1. + :paramtype api_call_retry_sleep_sec: int, optional + :keyword api_call_delay_sec: The time in seconds to wait between API calls. Defaults to 0. + :paramtype api_call_delay_sec: float, optional + :keyword concurrent_async_task: The maximum number of asynchronous tasks to run concurrently. Defaults to 3. + :paramtype concurrent_async_task: int, optional + + :return: A list of dictionaries containing the simulation results. + :rtype: List[Dict] + + Note: api_call_* parameters are only valid for simulation_connection defined. + The parameters cannot be used to configure behavior for calling user provided callback. + """ + if not isinstance(template, Template): + raise ValueError( + f"Please use simulator to construct template. Found {type(template)}" + ) - if self.userConnection == None: - bots = [gpt_bot] - else: - customer_bot = self.setup_bot( - ConversationRole.ASSISTANT, str(template), parameters + if not isinstance(parameters, list): + raise ValueError( + f"Expect parameters to be a list of dictionary, but found {type(parameters)}" ) - bots = [gpt_bot, customer_bot] + + if template.content_harm: + self._ensure_service_dependencies() + self.adversarial = True + templates = await self.template_handler._get_ch_template_collections( + template.template_name + ) + else: + template.template_parameters = parameters + templates = [template] + + semaphore = asyncio.Semaphore(concurrent_async_task) + sim_results = [] + tasks = [] + + for t in templates: + for p in t.template_parameters: + if jailbreak: + self._ensure_service_dependencies() + jailbreak_dataset = await self.rai_client.get_jailbreaks_dataset() # type: ignore[union-attr] + p = self._join_conversation_starter( + p, random.choice(jailbreak_dataset) + ) + + tasks.append( + asyncio.create_task( + self._simulate_async( + template=t, + parameters=p, + max_conversation_turns=max_conversation_turns, + api_call_retry_max_count=api_call_retry_max_count, + api_call_delay_sec=api_call_delay_sec, + sem=semaphore, + ) + ) + ) + + sim_results = await asyncio.gather(*tasks) + + return JsonLineList(sim_results) + + async def _simulate_async( + self, + template: "Template", + max_conversation_turns: int, + parameters: dict = {}, + api_call_retry_max_count: int = 3, + api_call_retry_sleep_sec: int = 1, + api_call_delay_sec: float = 0, + sem: "asyncio.Semaphore" = asyncio.Semaphore(3), + ): + """ + Asynchronously simulate conversations using the provided template and parameters. + + Args: + template (Template): An instance of the Template class defining the conversation structure. + max_conversation_turns (int): The maximum number of conversation turns to simulate. + parameters (list[dict], optional): A list of dictionaries containing the parameter values to be used in + the simulations. Defaults to an empty list. + jailbreak (bool, optional): If set to True, allows breaking out of the conversation flow defined by the + template. Defaults to False. + api_call_retry_max_count (int, optional): The maximum number of API call retries in case of errors. + Defaults to 3. + api_call_retry_sleep_sec (int, optional): The time in seconds to wait between API call retries. Defaults to 1. + api_call_delay_sec (float, optional): The time in seconds to wait between API calls. Defaults to 0. + concurrent_async_task (int, optional): The maximum number of asynchronous tasks to run concurrently. + Defaults to 3. + Returns: + List[Dict]: A list of dictionaries containing the simulation results. + + Raises: + Exception: If an error occurs during the simulation process. + """ + # create user bot + user_bot = self._setup_bot(ConversationRole.USER, template, parameters) + system_bot = self._setup_bot(ConversationRole.ASSISTANT, template, parameters) + + bots = [user_bot, system_bot] + # simulate the conversation asyncHttpClient = AsyncHTTPClientWithRetry( n_retry=api_call_retry_max_count, retry_timeout=api_call_retry_sleep_sec, - logger=logging.getLogger() + logger=logging.getLogger(), ) - - async with asyncHttpClient.client as session: - conversation_id, conversation_history = await simulate_conversation( - bots=bots, - simulate_callback=self.simulate_callback, - session=session, - turn_limit=max_conversation_turns, - api_call_delay_sec=api_call_delay_sec, - template_paramaters=parameters, - ) + async with sem: + async with asyncHttpClient.client as session: + conversation_id, conversation_history = await simulate_conversation( + bots=bots, + session=session, + turn_limit=max_conversation_turns, + api_call_delay_sec=api_call_delay_sec, + ) return self._to_chat_protocol(template, conversation_history, parameters) - def _get_citations(self, parameters, context_keys, turn_num = None): + def _get_citations(self, parameters, context_keys, turn_num=None): citations = [] for c_key in context_keys: if isinstance(parameters[c_key], dict): if "callback_citation_key" in parameters[c_key]: callback_citation_key = parameters[c_key]["callback_citation_key"] - callback_citations = self._get_callback_citations(parameters[c_key][callback_citation_key], turn_num) + callback_citations = self._get_callback_citations( + parameters[c_key][callback_citation_key], turn_num + ) else: callback_citations = [] if callback_citations: @@ -144,22 +371,17 @@ def _get_citations(self, parameters, context_keys, turn_num = None): for k, v in parameters[c_key].items(): if k not in ["callback_citations", "callback_citation_key"]: citations.append( - { - "id": k, - "content": self._to_citation_content(v) - } + {"id": k, "content": self._to_citation_content(v)} ) else: citations.append( { "id": c_key, - "content": self._to_citation_content(parameters[c_key]) + "content": self._to_citation_content(parameters[c_key]), } ) - return { - "citations": citations - } + return {"citations": citations} def _to_citation_content(self, obj): if isinstance(obj, str): @@ -167,7 +389,9 @@ def _to_citation_content(self, obj): else: return json.dumps(obj) - def _get_callback_citations(self, callback_citations: dict, turn_num: Optional[int] = None): + def _get_callback_citations( + self, callback_citations: dict, turn_num: Optional[int] = None + ): if turn_num == None: return [] current_turn_citations = [] @@ -175,18 +399,15 @@ def _get_callback_citations(self, callback_citations: dict, turn_num: Optional[i if current_turn_str in callback_citations.keys(): citations = callback_citations[current_turn_str] if isinstance(citations, dict): - for k, v in citations.items(): + for k, v in citations.items(): current_turn_citations.append( - { - "id": k, - "content": self._to_citation_content(v) - } + {"id": k, "content": self._to_citation_content(v)} ) else: current_turn_citations.append( { "id": current_turn_str, - "content": self._to_citation_content(citations) + "content": self._to_citation_content(citations), } ) return current_turn_citations @@ -195,65 +416,91 @@ def _to_chat_protocol(self, template, conversation_history, template_parameters) messages = [] for i, m in enumerate(conversation_history): - citations = self._get_citations(template_parameters, template.context_key, i) - messages.append( - { - "content": m.message, - "role": m.role.value, - "turn_number": i, - "template_parameters": template_parameters, - "context": citations - } - ) + message = {"content": m.message, "role": m.role.value} + if len(template.context_key) > 0: + citations = self._get_citations( + template_parameters, template.context_key, i + ) + message["context"] = citations + messages.append(message) return { + "template_parameters": template_parameters, "messages": messages, "$schema": "http://azureml/sdk-2-0/ChatConversation.json", } - def wrap_async( + def _wrap_async( self, results, - template: str, - parameters: dict, + template: "Template", max_conversation_turns: int, + parameters: List[dict] = [], + jailbreak: bool = False, api_call_retry_max_count: int = 3, api_call_retry_sleep_sec: int = 1, api_call_delay_sec: float = 0, + concurrent_async_task: int = 1, ): result = asyncio.run( self.simulate_async( template=template, parameters=parameters, max_conversation_turns=max_conversation_turns, + jailbreak=jailbreak, api_call_retry_max_count=api_call_retry_max_count, api_call_retry_sleep_sec=api_call_retry_sleep_sec, api_call_delay_sec=api_call_delay_sec, + concurrent_async_task=concurrent_async_task, ) ) results[0] = result def simulate( self, - template: str, - parameters: dict, + template: "Template", max_conversation_turns: int, + parameters: List[dict] = [], + jailbreak: bool = False, api_call_retry_max_count: int = 3, api_call_retry_sleep_sec: int = 1, api_call_delay_sec: float = 0, ): + """ + Simulates a conversation using a predefined template with customizable parameters and control over API behavior. + + :param template: The template object that defines the structure and flow of the conversation. + :type template: Template + :param max_conversation_turns: The maximum number of conversation turns to simulate. + :type max_conversation_turns: int + :param parameters: A list of dictionaries where each dictionary contains parameters specific to a single turn. + :type parameters: List[dict], optional + :param jailbreak: A flag to determine if the simulation should continue when encountering API errors. + :type jailbreak: bool, optional + :param api_call_retry_max_count: The maximum number of retries for API calls upon encountering an error. + :type api_call_retry_max_count: int, optional + :param api_call_retry_sleep_sec: The number of seconds to wait between retry attempts of an API call. + :type api_call_retry_sleep_sec: int, optional + :param api_call_delay_sec: The number of seconds to wait before making a new API call to simulate conversation delay. + :type api_call_delay_sec: float, optional + :return: The outcome of the simulated conversations as a list. + :rtype: List[Dict] + """ results = [None] + concurrent_async_task = 1 thread = threading.Thread( - target=self.wrap_async, + target=self._wrap_async, args=( results, template, - parameters, max_conversation_turns, + parameters, + jailbreak, api_call_retry_max_count, api_call_retry_sleep_sec, api_call_delay_sec, + concurrent_async_task, ), ) @@ -262,12 +509,199 @@ def simulate( return results[0] + @staticmethod + def from_fn( + fn: Callable[[Any], dict], + simulator_connection: "AzureOpenAIModelConfiguration" = None, # type: ignore[name-defined] + ai_client: "AIClient" = None, # type: ignore[name-defined] + **kwargs, + ): + """ + Creates an instance from a function that defines certain behaviors or configurations, along with connections to simulation and AI services. + + :param fn: The function to be used for configuring or defining behavior. This function should accept a single argument and return a dictionary of configurations. + :type fn: Callable[[Any], dict] + :param simulator_connection: Configuration for the connection to the simulation service, if any. + :type simulator_connection: AzureOpenAIModelConfiguration, optional + :param ai_client: The AI client to be used for interacting with AI services. + :type ai_client: AIClient, optional + :return: An instance of simulator configured with the specified function, simulation connection, and AI client. + :rtype: Simulator + :raises ValueError: If both `simulator_connection` and `ai_client` are not provided (i.e., both are None). + + Any additional keyword arguments (`**kwargs`) will be passed directly to the function `fn`. + """ + if hasattr(fn, "__wrapped__"): + func_module = fn.__wrapped__.__module__ + func_name = fn.__wrapped__.__name__ + if ( + func_module == "openai.resources.chat.completions" + and func_name == "create" + ): + return Simulator._from_openai_chat_completions( + fn, simulator_connection, ai_client, **kwargs + ) + return Simulator( + simulator_connection=simulator_connection, + ai_client=ai_client, + simulate_callback=fn, + ) -class PlainTokenManager(APITokenManager): - def __init__(self, openapi_key, logger, **kwargs): - super().__init__(logger, **kwargs) - self.token = openapi_key + @staticmethod + def _from_openai_chat_completions( + fn: Callable[[Any], dict], simulator_connection=None, ai_client=None, **kwargs + ): + return Simulator( + simulator_connection=simulator_connection, + ai_client=ai_client, + simulate_callback=Simulator._wrap_openai_chat_completion(fn, **kwargs), + ) + + @staticmethod + def _wrap_openai_chat_completion(fn, **kwargs): + async def callback(chat_protocol_message): + response = await fn(messages=chat_protocol_message["messages"], **kwargs) + + message = response.choices[0].message + + formatted_response = {"role": message.role, "content": message.content} + + chat_protocol_message["messages"].append(formatted_response) + + return chat_protocol_message + + return callback + + @staticmethod + def from_pf_path( + pf_path: str, + simulator_connection: "AzureOpenAIModelConfiguration" = None, # type: ignore[name-defined] + ai_client: "AIClient" = None, # type: ignore[name-defined] + **kwargs, + ): + """ + Creates an instance of Simulator from a specified promptflow path. + + :param pf_path: The path to the promptflow folder + :type pf_path: str + :param simulator_connection: Configuration for the connection to the simulation service, if any. + :type simulator_connection: AzureOpenAIModelConfiguration, optional + :param ai_client: The AI client to be used for interacting with AI services. + :type ai_client: AIClient, optional + :return: An instance of the class configured with the specified policy file, simulation connection, and AI client. + :rtype: The class which this static method is part of. + :return: An instance of simulator configured with the specified function, simulation connection, and AI client. + :rtype: Simulator + :raises ValueError: If both `simulator_connection` and `ai_client` are not provided (i.e., both are None). + + Any additional keyword arguments (`**kwargs`) will be passed to the underlying configuration or initialization methods. + """ + try: + from promptflow import load_flow + except: + raise EnvironmentError( + "Unable to import from promptflow. Have you installed promptflow in the python environment?" + ) + flow = load_flow(pf_path) + return Simulator( + simulator_connection=simulator_connection, + ai_client=ai_client, + simulate_callback=Simulator._wrap_pf(flow), + ) + + @staticmethod + def _wrap_pf(flow): + flow_ex = flow._init_executable() + for k, v in flow_ex.inputs.items(): + if v.is_chat_history: + chat_history_key = k + if v.type.value != "list": + raise TypeError(f"Chat history {k} not a list.") + + if v.is_chat_input: + chat_input_key = k + if v.type.value != "string": + raise TypeError(f"Chat input {k} not a string.") + + for k, v in flow_ex.outputs.items(): + if v.is_chat_output: + chat_output_key = k + if v.type.value != "string": + raise TypeError(f"Chat output {k} not a string.") + + if chat_output_key is None or chat_input_key is None: + raise ValueError("Prompflow has no required chat input and/or chat output.") + + async def callback(chat_protocol_message): + all_messages = chat_protocol_message["messages"] + input_data = {chat_input_key: all_messages[-1]} + if chat_history_key: + input_data[chat_history_key] = all_messages + + response = flow.invoke(input_data).output + chat_protocol_message["messages"].append( + {"role": "assistant", "content": response[chat_output_key]} + ) + + return chat_protocol_message + + return callback + + @staticmethod + def create_template( + name: str, + template: Optional[str], + template_path: Optional[str], + context_key: Optional[list[str]], + ): + """ + Creates a template instance either from a string or from a file path provided. + + :param name: The name to assign to the created template. + :type name: str + :param template: The string representation of the template content. + :type template: Optional[str] + :param template_path: The file system path to a file containing the template content. + :type template_path: Optional[str] + :param context_key: A list of keys that define the context used within the template. + :type context_key: Optional[list[str]] + :return: A new instance of a Template configured with the provided details. + :rtype: Template + + :raises ValueError: If both or neither of the parameters 'template' and 'template_path' are set. + + One of 'template' or 'template_path' must be provided to create a template. If 'template' is provided, it is used directly; if 'template_path' is provided, the content is read from the file at that path. + """ + if (template is None and template_path is None) or ( + template is not None and template_path is not None + ): + raise ValueError( + "One and only one of the parameters [template, template_path] has to be set." + ) + + if template is not None: + return Template(template_name=name, text=template, context_key=context_key) + + if template_path is not None: + with open(template_path, "r") as f: + tc = f.read() + + return Template(template_name=name, text=tc, context_key=context_key) + + raise ValueError( + "Condition not met for creating template, please check examples and parameter list." + ) - async def get_token(self): - return self.token + @staticmethod + def get_template(template_name: str): + """ + Retrieves a template instance by its name. + + :param template_name: The name of the template to retrieve. + :type template_name: str + :return: The Template instance corresponding to the given name. + :rtype: Template + """ + st = SimulatorTemplates() + return st.get_template(template_name) diff --git a/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/templates/_templates.py b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/templates/_templates.py index 6f9c8e1b65db..a721d2695e0d 100644 --- a/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/templates/_templates.py +++ b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/templates/_templates.py @@ -23,3 +23,12 @@ "summarization": SUMMARIZATION_PATH, "search": SEARCH_PATH } + +CH_TEMPLATES_COLLECTION_KEY = set([ + "adv_qa", + "adv_conversation", + "adv_summarization", + "adv_search", + "adv_rewrite", + "adv_content_gen_ungrounded", + "adv_content_gen_grounded"]) \ No newline at end of file diff --git a/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/templates/simulator_templates.py b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/templates/simulator_templates.py index 70048ba2dbdd..296e86cb4e23 100644 --- a/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/templates/simulator_templates.py +++ b/sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/templates/simulator_templates.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -from azure.ai.generative.synthetic.simulator.templates._templates import ALL_TEMPLATES, CONTEXT_KEY +from azure.ai.generative.synthetic.simulator.templates._templates import ALL_TEMPLATES, CONTEXT_KEY, CH_TEMPLATES_COLLECTION_KEY from azure.ai.generative.synthetic.simulator import _template_dir as template_dir from jinja2 import ( Environment as JinjaEnvironment, @@ -10,22 +10,54 @@ meta as JinjaMeta, ) import os +import asyncio class Template: - def __init__(self, template_name, text, context_key): + def __init__(self, template_name, text, context_key, content_harm=False, template_parameters=None): self.text = text self.context_key = context_key self.template_name = template_name + self.content_harm = content_harm + self.template_parameters = template_parameters def __str__(self): + if self.content_harm: + return "{{ch_template_placeholder}}" return self.text + def __to_ch_templates(self): + pass + +class ContentHarmTemplatesUtils: + @staticmethod + def get_template_category(key): + return key.split("/")[0] + + @staticmethod + def get_template_key(key): + filepath = key.rsplit(".json")[0] + parts = str(filepath).split("/") + filename = ContentHarmTemplatesUtils.json_name_to_md_name(parts[-1]) + prefix = parts[:-1] + prefix.append(filename) + + return "/".join(prefix) + + @staticmethod + def json_name_to_md_name(name): + result = name.replace("_aml", "") + + return result+".md" + + class SimulatorTemplates: - def __init__(self): + def __init__(self, rai_client=None): self.cached_templates_source = {} self.template_env = JinjaEnvironment( loader=JinjaFileSystemLoader(searchpath=template_dir) ) + self.rai_client = rai_client + self.categorized_ch_parameters = None def get_templates_list(self): return ALL_TEMPLATES.keys() @@ -33,7 +65,59 @@ def get_templates_list(self): def _get_template_context_key(self, template_name): return CONTEXT_KEY.get(template_name) + async def _get_ch_template_collections(self, collection_key): + if self.rai_client is None: + raise EnvironmentError("Service client is unavailable. Ai client is required to use rai service.") + + if self.categorized_ch_parameters is None: + categorized_parameters = {} + util = ContentHarmTemplatesUtils + + parameters = await self.rai_client.get_contentharm_parameters() + + for k in parameters.keys(): + template_key = util.get_template_key(k) + categorized_parameters[template_key] = { + "parameters": parameters[k], + "category": util.get_template_category(k), + "parameters_key": k + } + self.categorized_ch_parameters = categorized_parameters + + template_category = collection_key.split("adv_")[-1] + + plist = self.categorized_ch_parameters + ch_templates = [] + for tkey in [k for k in plist.keys() if plist[k]["category"] == template_category]: + params = plist[tkey]["parameters"] + for p in params: + p.update( + { + "ch_template_placeholder" : "{{ch_template_placeholder}}" + } + ) + + template = Template( + template_name=tkey, + text=None, + context_key=[], + content_harm=True, + template_parameters=params + ) + + ch_templates.append(template) + return ch_templates + def get_template(self, template_name): + if template_name in CH_TEMPLATES_COLLECTION_KEY: + return Template( + template_name=template_name, + text=None, + context_key=[], + content_harm=True, + template_parameters=None + ) + if template_name in self.cached_templates_source: template, template_path, loader_func = self.cached_templates_source[template_name] return Template(template_name, template, self._get_template_context_key(template_name)) diff --git a/sdk/ai/azure-ai-generative/cspell.json b/sdk/ai/azure-ai-generative/cspell.json index d5d75d454834..0f49bbd407e5 100644 --- a/sdk/ai/azure-ai-generative/cspell.json +++ b/sdk/ai/azure-ai-generative/cspell.json @@ -1,4 +1,4 @@ { - "ignoreWords": ["cmpl", "uqkvl", "redef", "datas", "unbatched", "endofprompt", "unlabel", "pydash"], + "ignoreWords": ["cmpl", "uqkvl", "redef", "datas", "unbatched", "endofprompt", "unlabel", "pydash", "raisvc", "tkey", "tparam", "punc"], "ignorePaths": ["sdk/ai/azure-ai-generative/azure/ai/generative/evaluate/pf_templates/**/*"] } diff --git a/sdk/ai/azure-ai-generative/tests/simulator/unittests/test_simulator.py b/sdk/ai/azure-ai-generative/tests/simulator/unittests/test_simulator.py index efd2224d421f..4c76d36eb0c2 100644 --- a/sdk/ai/azure-ai-generative/tests/simulator/unittests/test_simulator.py +++ b/sdk/ai/azure-ai-generative/tests/simulator/unittests/test_simulator.py @@ -5,18 +5,75 @@ import pytest import tempfile import os -from unittest.mock import Mock, patch +import asyncio -from azure.ai.generative.synthetic.simulator import Simulator, _template_dir as template_dir, SimulatorTemplates +from unittest.mock import Mock, patch, AsyncMock, MagicMock + +from azure.ai.generative.synthetic.simulator import Simulator, _template_dir as template_dir +from azure.ai.generative.synthetic.simulator.templates.simulator_templates import SimulatorTemplates from azure.ai.generative.synthetic.simulator._conversation.conversation_turn import ConversationTurn from azure.ai.generative.synthetic.simulator._conversation import ConversationRole from azure.ai.generative.synthetic.simulator.templates._templates import CONVERSATION +@pytest.fixture() +def mock_config(): + mock_config = Mock() + mock_config.api_key = "apikey" + mock_config.deployment_name = "deployment" + mock_config.api_version = "api-version" + mock_config.api_base = "api-base" + mock_config.model_name = "model-name" + mock_config.model_kwargs = {} + yield mock_config + +@pytest.fixture() +def system_model_completion(): + model = Mock() + model.get_conversation_completion = AsyncMock() + response = { + "samples": ["message content"], + "finish_reason": ["stop"], + "id": None, + } + + model.get_conversation_completion.return_value = { + "request": {}, + "response": response, + "time_taken": 0, + "full_response": response, + } + + yield model + +@pytest.fixture() +def task_parameters(): + yield { + "name": "Jake", + "profile": "Jake is a 10 years old boy", + "tone": "friendly", + "metadata": {"k1": "v1", "k2": "v2"}, + "task": "this is task description", + "chatbot_name": "chatbot_name" + } + +@pytest.fixture() +def conv_template(): + st = SimulatorTemplates() + + conv_template = st.get_template(CONVERSATION) + yield conv_template + +@pytest.fixture() +def async_callback(): + async def callback(x): + return x + yield callback + @pytest.mark.unittest class TestSimulator: @patch("azure.ai.generative.synthetic.simulator.simulator.simulator.simulate_conversation") @patch("azure.ai.generative.synthetic.simulator.simulator.simulator.Simulator._to_openai_chat_completion_model") - def test_simulator_returns_formatted_conversations(self, _, simulate_conversation_mock): + def test_simulator_returns_formatted_conversations(self, _, simulate_conversation_mock, mock_config, task_parameters, conv_template, async_callback): ct1 = ConversationTurn( role=ConversationRole.USER, @@ -36,32 +93,29 @@ def test_simulator_returns_formatted_conversations(self, _, simulate_conversatio conv_history = [ct1, ct2] simulate_conversation_mock.return_value = ("conversation_id", conv_history) - simulator = Simulator(None, None) - - st = SimulatorTemplates() - task_parameters = { - "name": "Jake", - "profile": "Jake is a 10 years old boy", - "tone": "friendly", - "metadata": {"k1": "v1", "k2": "v2"}, - "task": "this is task description", - "chatbot_name": "chatbot_name" - } + simulator = Simulator( + simulator_connection=mock_config, + ai_client=None, + simulate_callback=async_callback + ) - conv_template = st.get_template(CONVERSATION) + st = SimulatorTemplates() conv_params = st.get_template_parameters(CONVERSATION) assert set(task_parameters.keys()) == set(conv_params.keys()) - conv = simulator.simulate(conv_template, task_parameters, 2) - - expected_keys = set(["messages", "$schema"]) - assert type(conv) == dict - assert set(conv) == expected_keys + conv = simulator.simulate( + template=conv_template, + parameters=[task_parameters], + max_conversation_turns=2) - def test_simulator_parse_callback_citations(self, ): + expected_keys = set(["messages", "$schema", "template_parameters"]) + assert issubclass(type(conv), list) + assert len(conv) == 1 + assert set(conv[0]) == expected_keys + def test_simulator_parse_callback_citations(self, mock_config, async_callback): tempalte_parameters = {'name': 'Jane', 'tone': 'happy', 'metadata': {'customer_info': '## customer_info name: Jane Doe age: 28', @@ -75,7 +129,10 @@ def test_simulator_parse_callback_citations(self, ): 'content': '## customer_info name: Jane Doe age: 28'}]} expected_turn_2_citations = {'citations': [{'id': 'documents', 'content': "\n>>> From: wohdjewodhfjevwdjfywlemfhe==\n# Information about product item_number: 3"}]} - simulator = Simulator(None, None) + simulator = Simulator( + simulator_connection=mock_config, + simulate_callback=async_callback + ) turn_0_citations = simulator._get_citations(tempalte_parameters, context_keys=['metadata'], turn_num = 0) turn_1_citations = simulator._get_citations(tempalte_parameters, context_keys=['metadata'], turn_num = 1) @@ -83,4 +140,138 @@ def test_simulator_parse_callback_citations(self, ): assert turn_0_citations == expected_turn_0_citations, "incorrect turn_0 citations" assert turn_1_citations == expected_turn_1_citations, "incorrect turn_1 citations" - assert turn_2_citations == expected_turn_2_citations, "incorrect turn_2 citations" \ No newline at end of file + assert turn_2_citations == expected_turn_2_citations, "incorrect turn_2 citations" + + @patch("azure.ai.generative.synthetic.simulator.simulator.simulator.Simulator._to_openai_chat_completion_model") + def test_simulator_from_openai_callback(self, to_chat_completion_model, mock_config, system_model_completion, task_parameters, conv_template): + oai_mock = AsyncMock() + oai_mock.__wrapped__ = Mock() + oai_mock.__wrapped__.__module__ = "openai.resources.chat.completions" + oai_mock.__wrapped__.__name__ = "create" + + content = "oai magic mock" + response = MagicMock() + response.choices[0].message.role = "user" + response.choices[0].message.content = content + + oai_mock.return_value = response + + to_chat_completion_model.return_value = system_model_completion + + sim = Simulator.from_fn( + fn=oai_mock, + simulator_connection=mock_config) + + conv = sim.simulate( + template=conv_template, + parameters=[task_parameters], + max_conversation_turns=2) + + oai_mock.assert_called_once() + assert(len(conv) == 1) + assert(conv[0]["messages"][1]["content"] == "oai magic mock") + + # disabled for now. Azure sdk for python test pipeline import error in promptflow + # from opencensus.ext.azure.log_exporter import AzureEventHandler + # E ImportError: cannot import name 'AzureEventHandler' from 'opencensus.ext.azure.log_exporter' (D:\a\_work\1\s\sdk\ai\azure-ai-generative\.tox\mindependency\lib\site-packages\opencensus\ext\azure\log_exporter\__init__.py) + @patch("azure.ai.generative.synthetic.simulator.simulator.simulator.Simulator._to_openai_chat_completion_model") + @patch("promptflow.load_flow") + @patch("azure.ai.generative.synthetic.simulator.simulator.simulator.Simulator._wrap_pf") + def simulator_from_pf(self, wrap_pf, load_flow, to_chat_completion_model, mock_config, system_model_completion, task_parameters, conv_template): + content = "pf_mock" + + async def callback(cm): + cm["messages"].append( + { + "role": "assistant", + "content": content + } + ) + return cm + + wrap_pf.return_value = callback + load_flow.return_value = "dontcare" + + to_chat_completion_model.return_value = system_model_completion + + sim = Simulator.from_pf_path( + pf_path="don't care", + simulator_connection=mock_config) + + conv = sim.simulate( + template=conv_template, + parameters=[task_parameters], + max_conversation_turns=2) + + assert(len(conv) == 1) + assert(conv[0]["messages"][1]["content"] == content) + + @patch("azure.ai.generative.synthetic.simulator.simulator.simulator.Simulator._to_openai_chat_completion_model") + def test_simulator_from_custom_callback(self, to_chat_completion_model, mock_config, system_model_completion, task_parameters, conv_template): + to_chat_completion_model.return_value = system_model_completion + + content = "async callback" + async def callback(cm): + cm["messages"].append( + { + "role": "assistant", + "content": content + } + ) + return cm + sim = Simulator.from_fn( + fn=callback, + simulator_connection=mock_config) + + conv = sim.simulate( + template=conv_template, + parameters=[task_parameters], + max_conversation_turns=2) + + assert(len(conv) == 1) + assert(conv[0]["messages"][1]["content"] == content) + + def test_simulator_throws_expected_error_from_incorrect_template_type(self, mock_config, task_parameters, async_callback): + simulator = Simulator( + simulator_connection=mock_config, + ai_client=None, + simulate_callback=async_callback + ) + with pytest.raises(ValueError) as exc_info: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(simulator.simulate_async( + template="wrong template type", + max_conversation_turns=2, + parameters=[task_parameters] + ) + ) + + assert(str(exc_info.value).startswith("Please use simulator to construct template")) + + def test_simulator_throws_expected_error_from_sync_callback(self, mock_config): + with pytest.raises(ValueError) as exc_info: + simulator = Simulator( + simulator_connection=mock_config, + ai_client=None, + simulate_callback=lambda x:x + ) + + assert(str(exc_info.value).startswith("Callback has to be an async function.")) + + def test_simulator_throws_expected_error_from_unset_ai_client_or_connection(self): + with pytest.raises(ValueError) as all_none_exc_info: + simulator = Simulator( + simulator_connection=None, + ai_client=None, + simulate_callback=lambda x:x + ) + with pytest.raises(ValueError) as all_set_exc_info: + simulator = Simulator( + simulator_connection="some value", + ai_client="some value", + simulate_callback=lambda x:x + ) + + assert(str(all_none_exc_info.value).startswith("One and only one of the parameters [ai_client, simulator_connection]")) + assert(str(all_set_exc_info.value).startswith("One and only one of the parameters [ai_client, simulator_connection]")) \ No newline at end of file diff --git a/sdk/ai/azure-ai-generative/tests/simulator/unittests/test_simulator_templates.py b/sdk/ai/azure-ai-generative/tests/simulator/unittests/test_simulator_templates.py index 9898b64fdaa2..48c340364a3f 100644 --- a/sdk/ai/azure-ai-generative/tests/simulator/unittests/test_simulator_templates.py +++ b/sdk/ai/azure-ai-generative/tests/simulator/unittests/test_simulator_templates.py @@ -5,8 +5,9 @@ import pytest import tempfile import os -from azure.ai.generative.synthetic.simulator import SimulatorTemplates + from unittest.mock import Mock, patch +from azure.ai.generative.synthetic.simulator.templates.simulator_templates import SimulatorTemplates from azure.ai.generative.synthetic.simulator import _template_dir as template_dir from azure.ai.generative.synthetic.simulator.templates._templates import SUMMARIZATION_PATH, SUMMARIZATION