Merged
Simulator update
kicha0 committed Feb 9, 2024
commit 883e73bfb604ea574be1ea0857ed74bf3c367c56
@@ -5,6 +5,6 @@
 _template_dir = os.path.join(os.path.dirname(__file__), 'templates')

 from .simulator.simulator import Simulator
-from .templates.simulator_templates import SimulatorTemplates
+from .templates.simulator_templates import SimulatorTemplates, Template

-__all__ = ["Simulator", "SimulatorTemplates"]
+__all__ = ["Simulator", "SimulatorTemplates", "Template"]
@@ -12,12 +12,26 @@
 from .constants import ConversationRole


-def is_closing_message(response: str):
+def is_closing_message(response: any, recursion_depth: int = 0):
+    if (recursion_depth > 10):
+        raise Exception("Exceeded max call depth in is_closing_message")
+
+    # recursively go through each inner dictionary in the JSON dict and check if any value entry contains a closing message
+    if type(response) is dict:
+        for value in response.values():
+            if is_closing_message(value, recursion_depth=recursion_depth + 1):
+                return True
+    elif type(response) is str:
+        return is_closing_message_helper(response)
+
+    return False
+
+def is_closing_message_helper(response: str):
     message = response.lower()
     if "?" in message.lower():
         return False
-    punctuation = [".", ",", "!", ";", ":"]
-    for p in punctuation:
+    punc = [".", ",", "!", ";", ":"]
+    for p in punc:
         message = message.replace(p, "")
     if (
         "bye" not in message.lower().split()
@@ -36,9 +50,7 @@ async def simulate_conversation(
     history_limit: int = 5,
     api_call_delay_sec: float = 0,
     logger: logging.Logger = logging.getLogger(__name__),
-    mlflow_logger: Optional[Any] = None,
-    template_paramaters: Optional[dict] = None,
-    simulate_callback: Optional[Callable[[str, Sequence[Union[Dict, ConversationTurn]], Optional[dict]], str]] = None,
+    mlflow_logger=None,
 ):
     """
     Simulate a conversation between the given bots.
@@ -82,45 +94,30 @@ async def simulate_conversation(
         (current_turn < turn_limit)
     ):
         try:
-            current_character_idx = current_turn % 2
-            # if there is only one bot, means using customized simulate callback
-            # in the customer bot turn, instead of using the customer bot, need to invoke the simulate callback
-            if len(bots) < 2 and current_character_idx == 1:
-                question = conversation_history[-1].message
-                # TODO: Fix Bug 2816997
-                response = await simulate_callback(question, conversation_history, template_paramaters)  # type: ignore[misc]
-                # add the generated response to the list of generated responses
-                conversation_history.append(
-                    ConversationTurn(
-                        role=ConversationRole.ASSISTANT,
-                        name="ChatBot",
-                        message=response,
-                    ))
-            else:
-                current_bot = bots[current_character_idx]
-                # invoke Bot to generate response given the input request
-                logger.info(f"-- Sending to {current_bot.role.value}")
-                # pass only the last generated turn without passing the bot name.
-                response, request, time_taken, full_response = await current_bot.generate_response(
-                    session=session,
-                    conversation_history=conversation_history,
-                    max_history=history_limit,
-                    turn_number=current_turn,
-                )
-                # add the generated response to the list of generated responses
-                conversation_history.append(
-                    ConversationTurn(
-                        role=current_bot.role,
-                        name=current_bot.name,
-                        message=response["samples"][0],
-                        full_response=full_response,
-                        request=request,
-                    ))
+            current_character_idx = current_turn % len(bots)
+            current_bot = bots[current_character_idx]
+            # invoke Bot to generate response given the input request
+            logger.info(f"-- Sending to {current_bot.role.value}")
+            # pass only the last generated turn without passing the bot name.
+            response, request, time_taken, full_response = await current_bot.generate_response(
+                session=session,
+                conversation_history=conversation_history,
+                max_history=history_limit,
+                turn_number=current_turn,
+            )
+
+            # check if conversation id is null, which means conversation starter was used. use id from next turn
+            if conversation_id is None and 'id' in response:
+                conversation_id = response["id"]
+
+            # add the generated response to the list of generated responses
+            conversation_history.append(
+                ConversationTurn(
+                    role=current_bot.role,
+                    name=current_bot.name,
+                    message=response["samples"][0],
+                    full_response=full_response,
+                    request=request,
+                ))
             logger.info(f"Last turn: {conversation_history[-1]}")
             if mlflow_logger is not None:
                 logger_tasks.append(  # schedule logging but don't get blocked by it
@@ -129,8 +126,7 @@ async def simulate_conversation(
                 )
             )
         except Exception as e:
-            logger.warning(f"Error: {e}")
-            raise e
+            logger.warning("Error:" + str(e))
             if mlflow_logger is not None:
                 logger_tasks.append(  # schedule logging but don't get blocked by it
                     asyncio.create_task(mlflow_logger.log_error())
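Reviewer note: the hard-coded current_turn % 2 becomes current_turn % len(bots), and the special-cased simulate_callback branch is removed; the callback path is instead wrapped by the new CallbackConversationBot added later in this diff. A hypothetical sketch of the new round-robin turn assignment:

    # Stand-ins for ConversationBot instances; with two bots this matches the
    # old "% 2" behaviour, and with a single bot every turn goes to that bot.
    bots = ["user_bot", "assistant_bot"]
    for current_turn in range(4):
        current_bot = bots[current_turn % len(bots)]
        print(current_turn, current_bot)  # 0 user_bot, 1 assistant_bot, 2 user_bot, ...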
@@ -50,17 +50,25 @@ def __init__(

         self.logger = logging.getLogger(repr(self))

+        self.conversation_starter = None  # can either be a dictionary or jinja template
         if role == ConversationRole.USER:
             if "conversation_starter" in self.persona_template_args:
-                self.logger.info(
-                    'This simulated bot will use the provided conversation starter '
-                    f'"{repr(self.persona_template_args["conversation_starter"])[:400]}"'
-                    'instead of generating a turn using a LLM'
-                )
-                self.conversation_starter = self.persona_template_args["conversation_starter"]
+                conversation_starter_content = self.persona_template_args["conversation_starter"]
+                if type(conversation_starter_content) is dict:
+                    self.logger.info(f'This simulated bot will use the provided conversation starter (passed in as dictionary): {conversation_starter_content} instead of generating a turn using a LLM')
+                    self.conversation_starter = conversation_starter_content
+                else:
+                    self.logger.info(
+                        'This simulated bot will use the provided conversation starter '
+                        f'{repr(conversation_starter_content)[:400]}'
+                        ' instead of generating a turn using a LLM'
+                    )
+                    self.conversation_starter = jinja2.Template(
+                        conversation_starter_content, undefined=jinja2.StrictUndefined
+                    )
             else:
                 self.logger.info('This simulated bot will generate the first turn as no conversation starter is provided')
-                self.conversation_starter = ""


+
     async def generate_response(
@@ -88,11 +96,16 @@ async def generate_response(

         # check if this is the first turn and the conversation_starter is not None,
         # return the conversations starter rather than generating turn using LLM
-        if turn_number == 0 and self.conversation_starter is not None and self.conversation_starter != "":
-            self.logger.info(f"Returning conversation starter: {self.conversation_starter}")
+        if turn_number == 0 and self.conversation_starter is not None:
+            # if conversation_starter is a dictionary, pass it into samples as is
+            if type(self.conversation_starter) is dict:
+                self.logger.info(f"Returning conversation starter: {self.conversation_starter}")
+                samples = [self.conversation_starter]
+            else:
+                self.logger.info(f"Returning conversation starter: {repr(self.persona_template_args['conversation_starter'])[:400]}")
+                samples = [self.conversation_starter.render(**self.persona_template_args)]
             time_taken = 0

-            samples = [self.conversation_starter]
             finish_reason = ["stop"]

@@ -103,6 +116,8 @@ async def generate_response(
             full_response = parsed_response
             return parsed_response, {}, time_taken, full_response

+        print(f"{self.role} is going to be simulated now with params {self.persona_template_args}")
+
         prompt = self.conversation_template.render(
             conversation_turns=conversation_history[-max_history:],
             role=self.role.value,
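Reviewer note: conversation_starter now accepts two shapes. A dict is returned as the turn-0 sample unchanged; a string is compiled into a jinja2 template with StrictUndefined and rendered against persona_template_args. A hypothetical sketch of both shapes (the keys and parameter names are invented for illustration):

    import jinja2

    # string form: compiled once, rendered on turn 0; StrictUndefined makes a
    # missing template parameter raise jinja2.exceptions.UndefinedError
    starter = jinja2.Template(
        "Hi, I need help with {{ product }}.", undefined=jinja2.StrictUndefined
    )
    print(starter.render(product="my subscription"))

    # dict form: passed through as samples[0] exactly as provided
    starter_dict = {"content": "Hi, I need help.", "context": "billing docs"}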
@@ -40,7 +40,7 @@ def get_model_class_from_url(endpoint_url: str):

 # ===================== HTTP Retry ======================
 class AsyncHTTPClientWithRetry:
-    def __init__(self, n_retry, retry_timeout, logger):
+    def __init__(self, n_retry, retry_timeout, logger, retry_options=None):
         self.attempts = n_retry
         self.logger = logger

@@ -49,14 +49,14 @@ def __init__(self, n_retry, retry_timeout, logger):
         trace_config = TraceConfig()  # set up request logging
         trace_config.on_request_start.append(self.on_request_start)
         trace_config.on_request_end.append(self.on_request_end)
-
-        retry_options = RandomRetry(  # set up retry configuration
-            statuses=[104, 408, 409, 424, 429, 500, 502,
-                      503, 504],  # on which statuses to retry
-            attempts=n_retry,
-            min_timeout=retry_timeout,
-            max_timeout=retry_timeout,
-        )
+        if retry_options is None:
+            retry_options = RandomRetry(  # set up retry configuration
+                statuses=[104, 408, 409, 424, 429, 500, 502,
+                          503, 504],  # on which statuses to retry
+                attempts=n_retry,
+                min_timeout=retry_timeout,
+                max_timeout=retry_timeout,
+            )

         self.client = RetryClient(
             trace_configs=[trace_config], retry_options=retry_options)

@@ -641,6 +641,7 @@ def _parse_response(self, response_data: dict, request_data: Optional[dict] = None):
         # https://platform.openai.com/docs/api-reference/chat
         samples = []
         finish_reason = []
+
         for choice in response_data["choices"]:
             if 'message' in choice and 'content' in choice['message']:
                 samples.append(choice['message']['content'])
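Reviewer note: the new retry_options parameter lets callers replace the default RandomRetry with any aiohttp_retry policy, such as the JitterRetry imported by the new RAI client module below. A hypothetical override (parameter values are illustrative, not defaults from this codebase):

    import logging
    from aiohttp_retry import JitterRetry

    jitter = JitterRetry(
        attempts=6,
        start_timeout=5,
        statuses={429, 500, 502, 503, 504},
    )
    client = AsyncHTTPClientWithRetry(
        n_retry=6,
        retry_timeout=5,
        logger=logging.getLogger(__name__),
        retry_options=jitter,  # overrides the RandomRetry default
    )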
@@ -0,0 +1,3 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
@@ -0,0 +1,77 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from azure.ai.generative.synthetic.simulator._model_tools.models import (
+    AsyncHTTPClientWithRetry,
+)
+from aiohttp_retry import JitterRetry
+import logging
+
+import os
+
+api_url = None
+if "rai_svc_url" in os.environ:
+    api_url = os.environ["rai_svc_url"]
+    api_url = api_url.rstrip("/")
+    print(
+        f"Found rai_svc_url in environment variable, using {api_url} for rai service endpoint."
+    )
+
+
+class RAIClient:
+    def __init__(self, ml_client, token_manager):
+        self.ml_client = ml_client
+        self.token_manager = token_manager
+
+        self.contentharm_parameters = None
+        self.jailbreaks_dataset = None
+
+        if api_url is not None:
+            host = api_url
+        else:
+            host = self.ml_client.jobs._api_url
+
+        self.api_url = (
+            f"{host}/"
+            + f"raisvc/v1.0/subscriptions/{self.ml_client.subscription_id}/"
+            + f"resourceGroups/{self.ml_client.resource_group_name}/"
+            + f"providers/Microsoft.MachineLearningServices/workspaces/{self.ml_client.workspace_name}/"
+        )
+
+        self.parameter_json_endpoint = self.api_url + "simulation/template/parameters"
+        self.jailbreaks_json_endpoint = self.api_url + "simulation/jailbreak"
+        self.simulation_submit_endpoint = (
+            self.api_url + "simulation/chat/completions/submit"
+        )
+
+    def _create_async_client(self):
+        return AsyncHTTPClientWithRetry(
+            n_retry=6, retry_timeout=5, logger=logging.getLogger()
+        )
+
+    async def get_contentharm_parameters(self):
+        if self.contentharm_parameters is None:
+            self.contentharm_parameters = await self.get(self.parameter_json_endpoint)
+
+        return self.contentharm_parameters
+
+    async def get_jailbreaks_dataset(self):
+        if self.jailbreaks_dataset is None:
+            self.jailbreaks_dataset = await self.get(self.jailbreaks_json_endpoint)
+
+        return self.jailbreaks_dataset
+
+    async def get(self, url):
+        token = await self.token_manager.get_token()
+        headers = {
+            "Authorization": f"Bearer {token}",
+            "Content-Type": "application/json",
+        }
+
+        async with self._create_async_client().client as session:
+            async with session.get(url=url, headers=headers) as response:
+                if response.status == 200:
+                    response = await response.json()
+                    return response
+
+        raise ValueError("Unable to retrieve requested resource from rai service.")
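Reviewer note: both getters fetch once and then serve the cached payload, and get() raises ValueError on any non-200 status. A hypothetical usage sketch (construction of ml_client and token_manager is out of scope here):

    import asyncio

    async def fetch(ml_client, token_manager):
        rai = RAIClient(ml_client, token_manager)
        # first call hits the service and caches the JSON payload
        params = await rai.get_contentharm_parameters()
        # repeated calls return the cached value without a new request
        params_again = await rai.get_contentharm_parameters()
        return params, params_again

    # asyncio.run(fetch(ml_client, token_manager))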
@@ -0,0 +1,68 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from azure.ai.generative.synthetic.simulator._conversation import (
+    ConversationBot,
+    ConversationRole,
+    ConversationTurn,
+    simulate_conversation,
+)
+
+import copy
+from typing import List, Tuple
+
+
+class CallbackConversationBot(ConversationBot):
+    def __init__(
+        self, callback, user_template, user_template_parameters, *args, **kwargs
+    ):
+        self.callback = callback
+        self.user_template = user_template
+        self.user_template_parameters = user_template_parameters
+
+        super().__init__(*args, **kwargs)
+
+    async def generate_response(
+        self,
+        session: "RetryClient",
+        conversation_history: List[ConversationTurn],
+        max_history: int,
+        turn_number: int = 0,
+    ) -> Tuple[dict, dict, int, dict]:
+        chat_protocol_message = self._to_chat_protocol(
+            self.user_template, conversation_history, self.user_template_parameters
+        )
+        msg_copy = copy.deepcopy(chat_protocol_message)
+        result = await self.callback(msg_copy)
+
+        self.logger.info("Using user-provided callback to generate the response.")
+
+        time_taken = 0
+        request = {}
+        try:
+            response = {
+                "samples": [result["messages"][-1]["content"]],
+                "finish_reason": ["stop"],
+                "id": None,
+            }
+        except Exception as exc:
+            raise TypeError(
+                "User-provided callback does not conform to the chat protocol standard."
+            ) from exc
+
+        self.logger.info("Parsed callback response")
+
+        return response, request, time_taken, response
+
+    def _to_chat_protocol(self, template, conversation_history, template_parameters):
+        messages = []
+
+        for m in conversation_history:
+            messages.append({"content": m.message, "role": m.role.value})
+
+        return {
+            "template_parameters": template_parameters,
+            "messages": messages,
+            "$schema": "http://azureml/sdk-2-0/ChatConversation.json",
+        }
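Reviewer note: the callback contract is implicit in generate_response: it must accept a chat-protocol dict and return a dict whose "messages" list ends with the reply, otherwise a TypeError is raised. A hypothetical conforming callback:

    # The callback receives a deep copy of the chat-protocol message, so it can
    # mutate the messages list freely before returning it.
    async def my_callback(chat_protocol_message: dict) -> dict:
        messages = chat_protocol_message["messages"]
        messages.append({"role": "assistant", "content": "Hello from my app!"})
        return {"messages": messages}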