Merged
Changes from 1 commit
Commits (29)
6a57cea
Reverting models to make sure calls to the simulator work
nagkumar91 Mar 7, 2024
dd01ecf
merge
nagkumar91 Mar 7, 2024
8cea9c3
quotes
nagkumar91 Mar 7, 2024
bea237e
Spellcheck fixes
nagkumar91 Mar 7, 2024
45073cc
ignore the models for doc generation
nagkumar91 Mar 7, 2024
08af5b3
Fixed the quotes on f strings
nagkumar91 Mar 7, 2024
7584cc9
pylint skip file
nagkumar91 Mar 7, 2024
e10fe6f
Merge branch 'Azure:main' into main
nagkumar91 Mar 7, 2024
304d506
Merge branch 'Azure:main' into main
nagkumar91 Mar 11, 2024
d727177
Support for summarization
nagkumar91 Mar 11, 2024
8b895ee
Adding a limit of 2 conversation turns for all but conversation simul…
nagkumar91 Mar 11, 2024
92d6d8e
exclude synthetic from mypy
nagkumar91 Mar 11, 2024
4742b04
Another lint fix
nagkumar91 Mar 11, 2024
975b0b3
Skip the file causing linting issues
nagkumar91 Mar 12, 2024
a00871f
Merge branch 'Azure:main' into main
nagkumar91 Mar 12, 2024
6bf1de0
Bugfix on output to json_qa_lines and empty response from callbacks
nagkumar91 Mar 13, 2024
5a974ce
Merge branch 'main' into main
nagkumar91 Mar 13, 2024
3f9c000
Skip pylint
nagkumar91 Mar 13, 2024
5ab6ab2
Merge branch 'main' of https://github.com/nagkumar91/azure-sdk-for-py…
nagkumar91 Mar 13, 2024
0c76fb0
Add if/else on message to eval json util
nagkumar91 Mar 14, 2024
fad8599
Merge branch 'Azure:main' into main
nagkumar91 Mar 21, 2024
a6d8d0f
Merge branch 'main' of https://github.com/nagkumar91/azure-sdk-for-py…
nagkumar91 Mar 25, 2024
a1e9c9d
adding max_simulation_results for sync call
nagkumar91 Mar 25, 2024
10ba426
Merge branch 'main' of https://github.com/nagkumar91/azure-sdk-for-py…
nagkumar91 Mar 25, 2024
3634d7c
Merge branch 'main' of https://github.com/nagkumar91/azure-sdk-for-py…
nagkumar91 Apr 1, 2024
1cbc55c
Merge branch 'main' of https://github.com/nagkumar91/azure-sdk-for-py…
nagkumar91 May 23, 2024
ac39d75
Merge branch 'main' of https://github.com/nagkumar91/azure-sdk-for-py…
nagkumar91 Jun 12, 2024
85e91bf
Merge branch 'main' of https://github.com/nagkumar91/azure-sdk-for-py…
nagkumar91 Jun 13, 2024
503634d
Bugfix: None was being added to the end of the output path
nagkumar91 Jun 14, 2024
quotes
nagkumar91 committed Mar 7, 2024
commit 8cea9c33a47a35a9b8f74a68a019b052495f3945
@@ -25,7 +25,7 @@


def get_model_class_from_url(endpoint_url: str):
'''Convert an endpoint URL to the appropriate model class.'''
"""Convert an endpoint URL to the appropriate model class."""
endpoint_path = urlparse(endpoint_url).path # remove query params

if endpoint_path.endswith("chat/completions"):
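As an aside, the routing above can be sketched in isolation — a minimal example with a hypothetical Azure OpenAI endpoint URL (the host and deployment name are illustrative, not from this PR):

from urllib.parse import urlparse

url = "https://example.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2023-03-15-preview"
endpoint_path = urlparse(url).path  # query params are dropped, as in the function above
print(endpoint_path.endswith("chat/completions"))  # -> True, so the chat-completions model class is chosen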
@@ -83,9 +83,9 @@ async def on_request_end(self, session, trace_config_ctx, params):
# ===========================================================

class LLMBase(ABC):
'''
"""
Base class for all LLM models.
'''
"""

def __init__(self, endpoint_url: str, name: str = "unknown", additional_headers: Optional[dict] = {}):
self.endpoint_url = endpoint_url
@@ -113,15 +113,15 @@ async def get_completion(
session: RetryClient,
**request_params,
) -> dict:
'''
"""
Query the model a single time with a prompt.

Parameters
----------
prompt: Prompt str to query model with.
session: aiohttp RetryClient object to use for the request.
**request_params: Additional parameters to pass to the request.
'''
"""
request_data = self.format_request_data(prompt, **request_params)
return await self.request_api(
session=session,
@@ -211,9 +211,9 @@ def __repr__(self):
# ===========================================================

class OpenAICompletionsModel(LLMBase):
'''
"""
Object for calling a Completions-style API for OpenAI models.
'''
"""
prompt_idx_key = "__prompt_idx__"

max_stop_tokens = 4
@@ -230,7 +230,7 @@ class OpenAICompletionsModel(LLMBase):
def __init__(
self, *,
endpoint_url: str,
name: str = 'OpenAICompletionsModel',
name: str = "OpenAICompletionsModel",
additional_headers: Optional[dict] = {},
api_version: Optional[str] = "2023-03-15-preview",
token_manager: APITokenManager,
@@ -263,7 +263,7 @@ def __init__(
if not stop:
stop = []
# Else if the stop sequence is given as a string (e.g. '["\n", "<im_end>"]'), convert it
elif type(stop) is str and stop.startswith('[') and stop.endswith(']'):
elif type(stop) is str and stop.startswith("[") and stop.endswith("]"):
stop = eval(stop)
elif type(stop) is str:
stop = [stop]
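A standalone sketch of the stop-sequence coercion above; ast.literal_eval is shown here as a safer stand-in for the eval() call in the diff (that substitution is this sketch's assumption, not the module's actual behavior):

import ast

stop = '["\\n", "<im_end>"]'  # hypothetical stop sequence supplied as one string
if isinstance(stop, str) and stop.startswith("[") and stop.endswith("]"):
    stop = ast.literal_eval(stop)  # parse the list literal without executing arbitrary code
elif isinstance(stop, str):
    stop = [stop]
print(stop)  # -> ['\n', '<im_end>']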
@@ -287,9 +287,9 @@ def get_model_params(self):


def format_request_data(self, prompt: str, **request_params) -> Dict[str, str]:
'''
"""
Format the request data for the OpenAI API.
'''
"""
# Caption images if available
if len(self.image_captions.keys()):
prompt = replace_prompt_captions(
@@ -309,7 +309,7 @@ async def get_conversation_completion(
role: str = "assistant",
**request_params,
) -> dict:
'''
"""
Query the model a single time with a message.

Parameters
@@ -318,10 +318,10 @@
session: aiohttp RetryClient object to query the model with.
role: Role of the user sending the message.
request_params: Additional parameters to pass to the model.
'''
"""
prompt = []
for message in messages:
prompt.append(f"{self.CHAT_START_TOKEN}{message['role']}\n{message['content']}\n{self.CHAT_END_TOKEN}\n")
prompt_string: str = "".join(prompt)
prompt_string += f"{self.CHAT_START_TOKEN}{role}\n"
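A note on the f-string above: before Python 3.12, a replacement field cannot reuse the quote character that delimits the f-string itself, which is why the inner subscripts must keep single quotes. A minimal sketch (hypothetical message):

message = {"role": "user", "content": "hello"}
# f"{message["role"]}" is a SyntaxError on Python < 3.12 -- the inner quotes end the string early.
print(f"{message['role']}\n{message['content']}")  # portable on all supported Python versions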

@@ -341,7 +341,7 @@ async def get_all_completions( # type: ignore[override]
request_error_rate_threshold: float = 0.5,
**request_params,
) -> List[dict]:
'''
"""
Run a batch of prompts through the model and return the results in the order given.

Parameters
@@ -352,7 +352,7 @@ async def get_all_completions( # type: ignore[override]
api_call_delay_seconds: Number of seconds to wait between API requests.
request_error_rate_threshold: Maximum error rate allowed before raising an error.
request_params: Additional parameters to pass to the API.
'''
"""
if api_call_max_parallel_count > 1:
self.logger.info(f"Using {api_call_max_parallel_count} parallel workers to query the API..")

@@ -406,7 +406,7 @@ async def request_api_parallel(
"""
logger_tasks: List = [] # to await for logging to finish

while True: # process data from queue until it's empty
try:
request_data = request_datas.pop()
prompt_idx = request_data.pop(self.prompt_idx_key)
@@ -416,7 +416,7 @@ async def request_api_parallel(
session=session,
request_data=request_data,
)
await self._add_successful_response(response['time_taken'])
await self._add_successful_response(response["time_taken"])
except Exception as e:
response = {
"request": request_data,
@@ -469,7 +469,7 @@ async def request_api(
headers = {
"Content-Type": "application/json",
"X-CV": f"{uuid.uuid4()}",
"X-ModelType": self.model or '',
"X-ModelType": self.model or "",
}

if self.token_manager.auth_header == "Bearer":
@@ -539,21 +539,21 @@ def _parse_response(self, response_data: dict, request_data: Optional[dict] = No
# ===========================================================

class OpenAIChatCompletionsModel(OpenAICompletionsModel):
'''
"""
OpenAIChatCompletionsModel is a wrapper around OpenAICompletionsModel that
formats the prompt for chat completion.
'''
"""

def __init__(self, name='OpenAIChatCompletionsModel', *args, **kwargs):
def __init__(self, name="OpenAIChatCompletionsModel", *args, **kwargs):
super().__init__(name=name, *args, **kwargs)


def format_request_data(self, messages: List[dict], **request_params): # type: ignore[override]
# Caption images if available
if len(self.image_captions.keys()):
for message in messages:
message['content'] = replace_prompt_captions(
message['content'],
message["content"] = replace_prompt_captions(
message["content"],
captions=self.image_captions,
)

@@ -569,7 +569,7 @@ async def get_conversation_completion(
role: str = "assistant",
**request_params,
) -> dict:
'''
"""
Query the model a single time with a message.

Parameters
@@ -578,7 +578,7 @@
session: aiohttp RetryClient object to query the model with.
role: Not used for this model, since it is a chat model.
request_params: Additional parameters to pass to the model.
'''
"""
request_data = self.format_request_data(
messages=messages,
**request_params,
@@ -595,15 +595,15 @@ async def get_completion(
session: RetryClient,
**request_params,
) -> dict:
'''
"""
Query a ChatCompletions model with a single prompt. Note: the entire prompt will be inserted as a single "system" message.

Parameters
----------
prompt: Prompt str to query model with.
session: aiohttp RetryClient object to use for the request.
**request_params: Additional parameters to pass to the request.
'''
"""
messages = [{"role": "system", "content": prompt}]

request_data = self.format_request_data(
@@ -643,10 +643,10 @@ def _parse_response(self, response_data: dict, request_data: Optional[dict] = No
finish_reason = []

for choice in response_data["choices"]:
if 'message' in choice and 'content' in choice['message']:
samples.append(choice['message']['content'])
if 'message' in choice and 'finish_reason' in choice['message']:
finish_reason.append(choice['message']['finish_reason'])
if "message" in choice and "content" in choice["message"]:
samples.append(choice["message"]["content"])
if "message" in choice and "finish_reason" in choice["message"]:
finish_reason.append(choice["message"]["finish_reason"])

return {
"samples": samples,
@@ -659,13 +659,13 @@ def _parse_response(self, response_data: dict, request_data: Optional[dict] = No
# ===========================================================

class OpenAIMultiModalCompletionsModel(OpenAICompletionsModel):
'''
"""
Wrapper around OpenAICompletionsModel that formats the prompt for multimodal
completions containing images.
'''
"""
model_param_names = ["temperature", "max_tokens", "top_p", "n", "stop"]

def __init__(self, name='OpenAIMultiModalCompletionsModel', images_dir: Optional[str] = None, *args, **kwargs):
def __init__(self, name="OpenAIMultiModalCompletionsModel", images_dir: Optional[str] = None, *args, **kwargs):
self.images_dir = images_dir

super().__init__(name=name, *args, **kwargs)
@@ -683,13 +683,13 @@ def format_request_data(self, prompt: str, **request_params) -> dict:


def _log_request(self, request: dict) -> None:
'''Log prompt, ignoring image data if multimodal.'''
"""Log prompt, ignoring image data if multimodal."""
loggable_prompt_transcript = {
'transcript': [
(c if c['type'] != 'image' else {'type': 'image', 'data': '...'})
for c in request['transcript']
"transcript": [
(c if c["type"] != "image" else {"type": "image", "data": "..."})
for c in request["transcript"]
],
**{k: v for k, v in request.items() if k != 'transcript'}
**{k: v for k, v in request.items() if k != "transcript"}
}
super()._log_request(loggable_prompt_transcript)
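The redaction above can be exercised on its own — a sketch with a hypothetical multimodal request:

request = {
    "transcript": [
        {"type": "text", "data": "describe this image"},
        {"type": "image", "data": "<base64 payload>"},
    ],
    "n": 1,
}
loggable = {
    "transcript": [
        (c if c["type"] != "image" else {"type": "image", "data": "..."})
        for c in request["transcript"]
    ],
    **{k: v for k, v in request.items() if k != "transcript"},
}
print(loggable)  # image bytes replaced with "..." so the log stays readable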

@@ -699,20 +699,20 @@ def _log_request(self, request: dict) -> None:
# ===========================================================

class LLAMACompletionsModel(OpenAICompletionsModel):
'''
"""
Object for calling a Completions-style API for LLAMA models.
'''
"""

def __init__(
self, name: str = 'LLAMACompletionsModel', *args, **kwargs):
self, name: str = "LLAMACompletionsModel", *args, **kwargs):
super().__init__(name=name, *args, **kwargs)
# set the authentication header to Bearer, as the LLaMa APIs always use the Bearer auth header
self.token_manager.auth_header = "Bearer"

def format_request_data(self, prompt: str, **request_params):
'''
"""
Format the request data for the OpenAI API.
'''
"""
# Caption images if available
if len(self.image_captions.keys()):
prompt = replace_prompt_captions(
@@ -731,12 +731,12 @@ def format_request_data(self, prompt: str, **request_params):
return request_data

def _parse_response(self, response_data: dict, request_data: dict) -> dict: # type: ignore[override]
prompt = request_data['input_data']['input_string'][0]
prompt = request_data["input_data"]["input_string"][0]

# remove the prompt text from each response, as the llama model returns prompt + completion instead of only the completion
# remove any text after the stop tokens, since llama doesn't support stop tokens
for idx, response in enumerate(response_data["samples"]):
response_data["samples"][idx] = response_data["samples"][idx].replace(prompt, '').strip()
response_data["samples"][idx] = response_data["samples"][idx].replace(prompt, "").strip()
for stop_token in self.stop:
if stop_token in response_data["samples"][idx]:
response_data["samples"][idx] = response_data["samples"][idx].split(stop_token)[0].strip()
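A minimal sketch of the two cleanup steps above, with hypothetical prompt and sample values:

prompt = "Q: What is 2+2?\nA:"
sample = prompt + " 4<im_end> trailing text"  # LLaMa echoes the prompt before the completion
sample = sample.replace(prompt, "").strip()   # step 1: drop the echoed prompt
for stop_token in ["<im_end>"]:               # step 2: trim at stop tokens manually
    if stop_token in sample:
        sample = sample.split(stop_token)[0].strip()
print(sample)  # -> 4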
@@ -746,7 +746,7 @@ def _parse_response(self, response_data: dict, request_data: dict) -> dict: # t
for choice in response_data:
if "0" in choice:
samples.append(choice["0"])
finish_reason.append('Stop')
finish_reason.append("Stop")

return {
"samples": samples,
@@ -758,15 +758,15 @@ def _parse_response(self, response_data: dict, request_data: dict) -> dict: # t
# ============== LLAMA ChatCompletionsModel =================
# ===========================================================
class LLAMAChatCompletionsModel(LLAMACompletionsModel):
'''
"""
LLaMa ChatCompletionsModel is a wrapper around LLaMaCompletionsModel that
formats the prompt for chat completion.
This chat completion model should only be used as the assistant, and shouldn't be used to simulate a user. It is not
possible to pass a system prompt to describe how the model would behave, so we only use the model as the assistant to
reply to questions made by GPT-simulated users.
'''
"""

def __init__(self, name='LLAMAChatCompletionsModel', *args, **kwargs):
def __init__(self, name="LLAMAChatCompletionsModel", *args, **kwargs):
super().__init__(name=name, *args, **kwargs)
# set authentication header to Bearer, as llama apis always uses the bearer auth_header
self.token_manager.auth_header = "Bearer"
@@ -775,21 +775,21 @@ def format_request_data(self, messages: List[dict], **request_params): # type:
# Caption images if available
if len(self.image_captions.keys()):
for message in messages:
message['content'] = replace_prompt_captions(
message['content'],
message["content"] = replace_prompt_captions(
message["content"],
captions=self.image_captions,
)

# For LLaMa we don't pass the prompt (user persona) as a system message, since LLaMa doesn't support system messages.
# LLaMa only supports user and assistant messages. The message sequence has to start with a user message, and it can't
# have two consecutive user or assistant messages.
# So if we set the system meta prompt as a user message and the first two messages both come from the user, we
# combine them into one message.
for idx, x in enumerate(messages):
if x['role'] == 'system':
x['role'] = 'user'
if len(messages) > 1 and messages[0]['role'] == 'user' and messages[1]['role'] == 'user':
messages[0] = {'role': 'user', 'content': messages[0]['content'] + '\n' + messages[1]['content']}
if x["role"] == "system":
x["role"] = "user"
if len(messages) > 1 and messages[0]["role"] == "user" and messages[1]["role"] == "user":
messages[0] = {"role": "user", "content": messages[0]["content"] + "\n" + messages[1]["content"]}
del messages[1]

# request_data = {"messages": messages, **self.get_model_params()}
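The normalization above, restated as a standalone helper (a hypothetical extraction for illustration; the real method mutates the incoming list in place):

from typing import Dict, List

def normalize_for_llama(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
    # LLaMa has no system role, so demote system messages to user messages.
    msgs = [{**m, "role": "user" if m["role"] == "system" else m["role"]} for m in messages]
    # The sequence must alternate starting with a user message, so merge two leading user messages.
    if len(msgs) > 1 and msgs[0]["role"] == "user" and msgs[1]["role"] == "user":
        msgs[0] = {"role": "user", "content": msgs[0]["content"] + "\n" + msgs[1]["content"]}
        del msgs[1]
    return msgs

print(normalize_for_llama([{"role": "system", "content": "Be terse."},
                           {"role": "user", "content": "Hi"}]))
# -> [{'role': 'user', 'content': 'Be terse.\nHi'}]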
@@ -810,7 +810,7 @@ async def get_conversation_completion(
role: str = "assistant",
**request_params,
) -> dict:
'''
"""
Query the model a single time with a message.

Parameters
@@ -819,7 +819,7 @@
session: aiohttp RetryClient object to query the model with.
role: Not used for this model, since it is a chat model.
request_params: Additional parameters to pass to the model.
'''
"""

request_data = self.format_request_data(
messages=messages,
@@ -835,9 +835,9 @@ def _parse_response(self, response_data: dict) -> dict: # type: ignore[override
samples = []
finish_reason = []
# for choice in response_data:
if 'output' in response_data:
samples.append(response_data['output'])
finish_reason.append('Stop')
if "output" in response_data:
samples.append(response_data["output"])
finish_reason.append("Stop")

return {
"samples": samples,