diff --git a/sdk/openai/azure-openai/tests/conftest.py b/sdk/openai/azure-openai/tests/conftest.py index 9a1516065255..43d853fead3e 100644 --- a/sdk/openai/azure-openai/tests/conftest.py +++ b/sdk/openai/azure-openai/tests/conftest.py @@ -16,11 +16,12 @@ DefaultAzureCredential as AsyncDefaultAzureCredential, get_bearer_token_provider as get_bearer_token_provider_async, ) +from ci_tools.variables import in_ci # for pytest.parametrize GA = "2024-02-01" -PREVIEW = "2024-03-01-preview" +PREVIEW = "2024-05-01-preview" LATEST = PREVIEW AZURE = "azure" @@ -65,8 +66,15 @@ ENV_OPENAI_TTS_MODEL = "tts-1" +def skip_openai_test(api_type) -> bool: + return in_ci() and "openai" in api_type and "tests-weekly" not in os.getenv("SYSTEM_DEFINITIONNAME", "") + + @pytest.fixture def client(api_type, api_version): + if skip_openai_test(api_type): + pytest.skip("Skipping openai tests - they only run on tests-weekly.") + if api_type == "azure": client = openai.AzureOpenAI( azure_endpoint=os.getenv(ENV_AZURE_OPENAI_ENDPOINT), @@ -100,6 +108,9 @@ def client(api_type, api_version): @pytest.fixture def client_async(api_type, api_version): + if skip_openai_test(api_type): + pytest.skip("Skipping openai tests - they only run on tests-weekly.") + if api_type == "azure": client = openai.AsyncAzureOpenAI( azure_endpoint=os.getenv(ENV_AZURE_OPENAI_ENDPOINT), diff --git a/sdk/openai/azure-openai/tests/test_assistants.py b/sdk/openai/azure-openai/tests/test_assistants.py index 6cb2c3f6ad6f..e4e19dd4e7d4 100644 --- a/sdk/openai/azure-openai/tests/test_assistants.py +++ b/sdk/openai/azure-openai/tests/test_assistants.py @@ -4,18 +4,192 @@ # ------------------------------------ import os -import time import pytest import pathlib import uuid +import openai from devtools_testutils import AzureRecordedTestCase from conftest import ASST_AZURE, PREVIEW, GPT_4_OPENAI, configure - -TIMEOUT = 300 +from openai import AssistantEventHandler +from openai.types.beta.threads import ( + Text, + Message, + ImageFile, + TextDelta, + MessageDelta, +) +from openai.types.beta.threads import Run +from openai.types.beta.threads.runs import RunStep, ToolCall, RunStepDelta, ToolCallDelta + + +class EventHandler(AssistantEventHandler): + def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None: + if delta.value: + assert delta.value is not None + if delta.annotations: + for annotation in delta.annotations: + if annotation.type == "file_citation": + assert annotation.index is not None + assert annotation.file_citation.file_id + assert annotation.file_citation.quote + elif annotation.type == "file_path": + assert annotation.index is not None + assert annotation.file_path.file_id + + def on_run_step_done(self, run_step: RunStep) -> None: + details = run_step.step_details + if details.type == "tool_calls": + for tool in details.tool_calls: + if tool.type == "code_interpreter": + assert tool.id + assert tool.code_interpreter.outputs + assert tool.code_interpreter.input is not None + elif tool.type == "function": + assert tool.id + assert tool.function.arguments is not None + assert tool.function.name is not None + + def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None: + details = delta.step_details + if details is not None: + if details.type == "tool_calls": + for tool in details.tool_calls or []: + if tool.type == "code_interpreter" and tool.code_interpreter and tool.code_interpreter.input: + assert tool.index is not None + assert tool.code_interpreter.input is not None + elif details.type == "message_creation": + assert details.message_creation.message_id + + def on_run_step_created(self, run_step: RunStep): + assert run_step.object == "thread.run.step" + assert run_step.id + assert run_step.type + assert run_step.created_at + assert run_step.assistant_id + assert run_step.thread_id + assert run_step.run_id + assert run_step.status + assert run_step.step_details + + def on_message_created(self, message: Message): + assert message.object == "thread.message" + assert message.id + assert message.created_at + assert message.attachments is not None + assert message.status + assert message.thread_id + + def on_message_delta(self, delta: MessageDelta, snapshot: Message): + if delta.content: + for content in delta.content: + if content.type == "text": + assert content.index is not None + if content.text: + if content.text.value: + assert content.text.value is not None + if content.text.annotations: + for annot in content.text.annotations: + if annot.type == "file_citation": + assert annot.end_index is not None + assert annot.file_citation.file_id + assert annot.file_citation.quote + assert annot.start_index is not None + elif annot.type == "file_path": + assert annot.end_index is not None + assert annot.file_path.file_id + assert annot.start_index is not None + elif content.type == "image_file": + assert content.index is not None + assert content.image_file.file_id + + + def on_message_done(self, message: Message): + for msg in message.content: + if msg.type == "image_file": + assert msg.image_file.file_id + if msg.type == "text": + assert msg.text.value + if msg.text.annotations: + for annot in msg.text.annotations: + if annot.type == "file_citation": + assert annot.end_index is not None + assert annot.file_citation.file_id + assert annot.file_citation.quote + assert annot.start_index is not None + assert annot.text is not None + elif annot.type == "file_path": + assert annot.end_index is not None + assert annot.file_path.file_id + assert annot.start_index is not None + assert annot.text is not None + + def on_text_created(self, text: Text): + assert text.value is not None + + def on_text_done(self, text: Text): + assert text.value is not None + for annot in text.annotations: + if annot.type == "file_citation": + assert annot.end_index is not None + assert annot.file_citation.file_id + assert annot.file_citation.quote + assert annot.start_index is not None + assert annot.text is not None + elif annot.type == "file_path": + assert annot.end_index is not None + assert annot.file_path.file_id + assert annot.start_index is not None + assert annot.text is not None + + def on_image_file_done(self, image_file: ImageFile): + assert image_file.file_id + + def on_tool_call_created(self, tool_call: ToolCall): + assert tool_call.id + + def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall): + if delta.type == "code_interpreter": + assert delta.index is not None + if delta.code_interpreter: + if delta.code_interpreter.input: + assert delta.code_interpreter.input is not None + if delta.code_interpreter.outputs: + for output in delta.code_interpreter.outputs: + if output.type == "image": + assert output.image.file_id + elif output.type == "logs": + assert output.logs + if delta.type == "function": + assert delta.id + if delta.function: + assert delta.function.arguments is not None + assert delta.function.name is not None + + def on_tool_call_done(self, tool_call: ToolCall): + if tool_call.type == "code_interpreter": + assert tool_call.id + assert tool_call.code_interpreter.input is not None + for output in tool_call.code_interpreter.outputs: + if output.type == "image": + assert output.image.file_id + elif output.type == "logs": + assert output.logs + if tool_call.type == "function": + assert tool_call.id + assert tool_call.function.arguments is not None + assert tool_call.function.name is not None class TestAssistants(AzureRecordedTestCase): + def handle_run_failure(self, run: Run): + if run.status == "failed": + if "Rate limit" in run.last_error.message: + pytest.skip("Skipping - Rate limit reached.") + raise openai.OpenAIError(run.last_error.message) + if run.status not in ["completed", "requires_action"]: + raise openai.OpenAIError(f"Run in unexpected status: {run.status}") + @configure @pytest.mark.parametrize( "api_type, api_version", @@ -93,7 +267,6 @@ def test_assistants_threads_crud(self, client, api_type, api_version, **kwargs): assert delete_thread.id == thread.id assert delete_thread.deleted is True - @pytest.mark.skip(reason="AOAI doesn't support assistants v2 yet") @configure @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) def test_assistants_messages_crud(self, client, api_type, api_version, **kwargs): @@ -163,6 +336,149 @@ def test_assistants_messages_crud(self, client, api_type, api_version, **kwargs) ) assert delete_thread.id == thread.id assert delete_thread.deleted is True + delete_file = client.files.delete(file.id) + assert delete_file.deleted is True + + @configure + @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) + def test_assistants_vector_stores_crud(self, client, api_type, api_version, **kwargs): + file_name = f"test{uuid.uuid4()}.txt" + with open(file_name, "w") as f: + f.write("test") + + path = pathlib.Path(file_name) + + file = client.files.create( + file=open(path, "rb"), + purpose="assistants" + ) + + try: + vector_store = client.beta.vector_stores.create( + name="Support FAQ" + ) + assert vector_store.name == "Support FAQ" + assert vector_store.id + assert vector_store.object == "vector_store" + assert vector_store.created_at + assert vector_store.file_counts.total == 0 + + vectors = client.beta.vector_stores.list() + for vector in vectors: + assert vector.id + assert vector_store.object == "vector_store" + assert vector_store.created_at + + vector_store = client.beta.vector_stores.update( + vector_store_id=vector_store.id, + name="Support FAQ and more", + metadata={"Q": "A"} + ) + retrieved_vector = client.beta.vector_stores.retrieve( + vector_store_id=vector_store.id + ) + assert retrieved_vector.id == vector_store.id + assert retrieved_vector.name == "Support FAQ and more" + assert retrieved_vector.metadata == {"Q": "A"} + + vector_store_file = client.beta.vector_stores.files.create( + vector_store_id=vector_store.id, + file_id=file.id + ) + assert vector_store_file.id + assert vector_store_file.object == "vector_store.file" + assert vector_store_file.created_at + assert vector_store_file.vector_store_id == vector_store.id + + vector_store_files = client.beta.vector_stores.files.list( + vector_store_id=vector_store.id + ) + for vector_file in vector_store_files: + assert vector_file.id + assert vector_file.object == "vector_store.file" + assert vector_store_file.created_at + assert vector_store_file.vector_store_id == vector_store.id + + vector_store_file_2 = client.beta.vector_stores.files.retrieve( + vector_store_id=vector_store.id, + file_id=file.id + ) + assert vector_store_file_2.id == vector_store_file.id + assert vector_store_file.vector_store_id == vector_store.id + + finally: + os.remove(path) + deleted_vector_store_file = client.beta.vector_stores.files.delete( + vector_store_id=vector_store.id, + file_id=file.id + ) + assert deleted_vector_store_file.deleted is True + deleted_vector_store = client.beta.vector_stores.delete( + vector_store_id=vector_store.id + ) + assert deleted_vector_store.deleted is True + + @configure + @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) + def test_assistants_vector_stores_batch_crud(self, client, api_type, api_version, **kwargs): + file_name = f"test{uuid.uuid4()}.txt" + file_name_2 = f"test{uuid.uuid4()}.txt" + with open(file_name, "w") as f: + f.write("test") + + path = pathlib.Path(file_name) + + file = client.files.create( + file=open(path, "rb"), + purpose="assistants" + ) + with open(file_name_2, "w") as f: + f.write("test") + path_2 = pathlib.Path(file_name_2) + + file_2 = client.files.create( + file=open(path_2, "rb"), + purpose="assistants" + ) + try: + vector_store = client.beta.vector_stores.create( + name="Support FAQ" + ) + vector_store_file_batch = client.beta.vector_stores.file_batches.create( + vector_store_id=vector_store.id, + file_ids=[file.id, file_2.id] + ) + assert vector_store_file_batch.id + assert vector_store_file_batch.object == "vector_store.file_batch" + assert vector_store_file_batch.created_at + assert vector_store_file_batch.status + + vectors = client.beta.vector_stores.file_batches.list_files( + vector_store_id=vector_store.id, + batch_id=vector_store_file_batch.id + ) + for vector in vectors: + assert vector.id + assert vector.object == "vector_store.file" + assert vector.created_at + + retrieved_vector_store_file_batch = client.beta.vector_stores.file_batches.retrieve( + vector_store_id=vector_store.id, + batch_id=vector_store_file_batch.id + ) + assert retrieved_vector_store_file_batch.id == vector_store_file_batch.id + + finally: + os.remove(path) + os.remove(path_2) + delete_file = client.files.delete(file.id) + assert delete_file.deleted is True + delete_file = client.files.delete(file_2.id) + assert delete_file.deleted is True + deleted_vector_store = client.beta.vector_stores.delete( + vector_store_id=vector_store.id + ) + assert deleted_vector_store.deleted is True @configure @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) @@ -183,32 +499,20 @@ def test_assistants_runs_code(self, client, api_type, api_version, **kwargs): content="I need to solve the equation `3x + 11 = 14`. Can you help me?", ) - run = client.beta.threads.runs.create( + run = client.beta.threads.runs.create_and_poll( thread_id=thread.id, assistant_id=assistant.id, instructions="Please address the user as Jane Doe.", - # additional_instructions="After solving each equation, say 'Isn't math fun?'", # not supported by AOAI yet + additional_instructions="After solving each equation, say 'Isn't math fun?'", ) + self.handle_run_failure(run) + if run.status == "completed": + messages = client.beta.threads.messages.list(thread_id=thread.id) - start_time = time.time() - - while True: - if time.time() - start_time > TIMEOUT: - raise TimeoutError("Run timed out") - - run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id) - - if run.status == "completed": - messages = client.beta.threads.messages.list(thread_id=thread.id) - - for message in messages: - assert message.content[0].type == "text" - assert message.content[0].text.value + for message in messages: + assert message.content[0].type == "text" + assert message.content[0].text.value - break - else: - time.sleep(5) - run = client.beta.threads.runs.update( thread_id=thread.id, run_id=run.id, @@ -229,31 +533,36 @@ def test_assistants_runs_code(self, client, api_type, api_version, **kwargs): assert delete_thread.id == thread.id assert delete_thread.deleted is True - @pytest.mark.skip("AOAI does not support retrieval tools yet") @configure @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) - def test_assistants_runs_retrieval(self, client, api_type, api_version, **kwargs): + def test_assistants_runs_file_search(self, client, api_type, api_version, **kwargs): file_name = f"test{uuid.uuid4()}.txt" with open(file_name, "w") as f: f.write("Contoso company policy requires that all employees take at least 10 vacation days a year.") path = pathlib.Path(file_name) - file = client.files.create( - file=open(path, "rb"), - purpose="assistants" - ) - try: + vector_store = client.beta.vector_stores.create( + name="Support FAQ" + ) + client.beta.vector_stores.files.upload_and_poll( + vector_store_id=vector_store.id, + file=path + ) assistant = client.beta.assistants.create( name="python test", instructions="You help answer questions about Contoso company policy.", - tools=[{"type": "retrieval"}], - file_ids=[file.id], + tools=[{"type": "file_search"}], + tool_resources={ + "file_search": { + "vector_store_ids": [vector_store.id] + } + }, **kwargs ) - run = client.beta.threads.create_and_run( + run = client.beta.threads.create_and_run_poll( assistant_id=assistant.id, thread={ "messages": [ @@ -261,25 +570,13 @@ def test_assistants_runs_retrieval(self, client, api_type, api_version, **kwargs ] } ) + self.handle_run_failure(run) + if run.status == "completed": + messages = client.beta.threads.messages.list(thread_id=run.thread_id) - start_time = time.time() - - while True: - if time.time() - start_time > TIMEOUT: - raise TimeoutError("Run timed out") - - run = client.beta.threads.runs.retrieve(thread_id=run.thread_id, run_id=run.id) - - if run.status == "completed": - messages = client.beta.threads.messages.list(thread_id=run.thread_id) - - for message in messages: - assert message.content[0].type == "text" - assert message.content[0].text.value - - break - - time.sleep(5) + for message in messages: + assert message.content[0].type == "text" + assert message.content[0].text.value finally: os.remove(path) @@ -294,6 +591,10 @@ def test_assistants_runs_retrieval(self, client, api_type, api_version, **kwargs ) assert delete_thread.id assert delete_thread.deleted is True + deleted_vector_store = client.beta.vector_stores.delete( + vector_store_id=vector_store.id + ) + assert deleted_vector_store.deleted is True @configure @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) @@ -329,7 +630,7 @@ def test_assistants_runs_functions(self, client, api_type, api_version, **kwargs **kwargs, ) - run = client.beta.threads.create_and_run( + run = client.beta.threads.create_and_run_poll( assistant_id=assistant.id, thread={ "messages": [ @@ -337,36 +638,25 @@ def test_assistants_runs_functions(self, client, api_type, api_version, **kwargs ] } ) - start_time = time.time() - - while True: - if time.time() - start_time > TIMEOUT: - raise TimeoutError("Run timed out") - - run = client.beta.threads.runs.retrieve(thread_id=run.thread_id, run_id=run.id) - - if run.status == "requires_action": - run = client.beta.threads.runs.submit_tool_outputs( - thread_id=run.thread_id, - run_id=run.id, - tool_outputs=[ - { - "tool_call_id": run.required_action.submit_tool_outputs.tool_calls[0].id, - "output": "{\"temperature\": \"22\", \"unit\": \"celsius\", \"description\": \"Sunny\"}" - } - ] - ) - - if run.status == "completed": - messages = client.beta.threads.messages.list(thread_id=run.thread_id) - - for message in messages: - assert message.content[0].type == "text" - assert message.content[0].text.value - - break + self.handle_run_failure(run) + if run.status == "requires_action": + run = client.beta.threads.runs.submit_tool_outputs_and_poll( + thread_id=run.thread_id, + run_id=run.id, + tool_outputs=[ + { + "tool_call_id": run.required_action.submit_tool_outputs.tool_calls[0].id, + "output": "{\"temperature\": \"22\", \"unit\": \"celsius\", \"description\": \"Sunny\"}" + } + ] + ) + self.handle_run_failure(run) + if run.status == "completed": + messages = client.beta.threads.messages.list(thread_id=run.thread_id) - time.sleep(5) + for message in messages: + assert message.content[0].type == "text" + assert message.content[0].text.value runs = client.beta.threads.runs.list(thread_id=run.thread_id) for r in runs: @@ -382,10 +672,13 @@ def test_assistants_runs_functions(self, client, api_type, api_version, **kwargs thread_id=run.thread_id, run_id=r.id ) + for step in run_steps: + assert step.id + retrieved_step = client.beta.threads.runs.steps.retrieve( thread_id=run.thread_id, run_id=r.id, - step_id=run_steps.data[0].id + step_id=step.id ) assert retrieved_step.id assert retrieved_step.created_at @@ -407,3 +700,63 @@ def test_assistants_runs_functions(self, client, api_type, api_version, **kwargs ) assert delete_thread.id assert delete_thread.deleted is True + + @configure + @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) + def test_assistants_streaming(self, client, api_type, api_version, **kwargs): + assistant = client.beta.assistants.create( + name="Math Tutor", + instructions="You are a personal math tutor. Write and run code to answer math questions.", + tools=[{"type": "code_interpreter"}], + **kwargs, + ) + try: + thread = client.beta.threads.create() + client.beta.threads.messages.create( + thread_id=thread.id, + role="user", + content="I need to solve the equation `3x + 11 = 14`. Can you help me?", + ) + stream = client.beta.threads.runs.create( + thread_id=thread.id, + assistant_id=assistant.id, + instructions="Please address the user as Jane Doe. The user has a premium account.", + stream=True, + ) + + for event in stream: + assert event + finally: + client.beta.assistants.delete(assistant.id) + + @configure + @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) + def test_assistants_stream_event_handler(self, client, api_type, api_version, **kwargs): + assistant = client.beta.assistants.create( + name="Math Tutor", + instructions="You are a personal math tutor. Write and run code to answer math questions.", + tools=[{"type": "code_interpreter"}], + **kwargs + ) + + try: + question = "I need to solve the equation `3x + 11 = 14`. Can you help me and then generate an image with the answer?" + + thread = client.beta.threads.create( + messages=[ + { + "role": "user", + "content": question, + }, + ] + ) + + with client.beta.threads.runs.stream( + thread_id=thread.id, + assistant_id=assistant.id, + instructions="Please address the user as Jane Doe. The user has a premium account.", + event_handler=EventHandler(), + ) as stream: + stream.until_done() + finally: + client.beta.assistants.delete(assistant.id) diff --git a/sdk/openai/azure-openai/tests/test_assistants_async.py b/sdk/openai/azure-openai/tests/test_assistants_async.py index 4eaf03fc48e1..3ceabd3dd9b3 100644 --- a/sdk/openai/azure-openai/tests/test_assistants_async.py +++ b/sdk/openai/azure-openai/tests/test_assistants_async.py @@ -4,18 +4,191 @@ # ------------------------------------ import os -import time import pytest import pathlib import uuid +import openai from devtools_testutils import AzureRecordedTestCase from conftest import ASST_AZURE, PREVIEW, GPT_4_OPENAI, configure_async - -TIMEOUT = 300 +from openai import AsyncAssistantEventHandler +from openai.types.beta.threads import ( + Text, + Message, + ImageFile, + TextDelta, + MessageDelta, +) +from openai.types.beta.threads import Run +from openai.types.beta.threads.runs import RunStep, ToolCall, RunStepDelta, ToolCallDelta + + +class AsyncEventHandler(AsyncAssistantEventHandler): + async def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None: + if delta.value: + assert delta.value is not None + if delta.annotations: + for annotation in delta.annotations: + if annotation.type == "file_citation": + assert annotation.index is not None + assert annotation.file_citation.file_id + assert annotation.file_citation.quote + elif annotation.type == "file_path": + assert annotation.index is not None + assert annotation.file_path.file_id + + async def on_run_step_done(self, run_step: RunStep) -> None: + details = run_step.step_details + if details.type == "tool_calls": + for tool in details.tool_calls: + if tool.type == "code_interpreter": + assert tool.id + assert tool.code_interpreter.outputs + assert tool.code_interpreter.input is not None + elif tool.type == "function": + assert tool.id + assert tool.function.arguments is not None + assert tool.function.name is not None + + async def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None: + details = delta.step_details + if details is not None: + if details.type == "tool_calls": + for tool in details.tool_calls or []: + if tool.type == "code_interpreter" and tool.code_interpreter and tool.code_interpreter.input: + assert tool.index is not None + assert tool.code_interpreter.input is not None + elif details.type == "message_creation": + assert details.message_creation.message_id + + async def on_run_step_created(self, run_step: RunStep): + assert run_step.object == "thread.run.step" + assert run_step.id + assert run_step.type + assert run_step.created_at + assert run_step.assistant_id + assert run_step.thread_id + assert run_step.run_id + assert run_step.status + assert run_step.step_details + + async def on_message_created(self, message: Message): + assert message.object == "thread.message" + assert message.id + assert message.created_at + assert message.attachments is not None + assert message.status + assert message.thread_id + + async def on_message_delta(self, delta: MessageDelta, snapshot: Message): + if delta.content: + for content in delta.content: + if content.type == "text": + assert content.index is not None + if content.text: + if content.text.value: + assert content.text.value is not None + if content.text.annotations: + for annot in content.text.annotations: + if annot.type == "file_citation": + assert annot.end_index is not None + assert annot.file_citation.file_id + assert annot.file_citation.quote + assert annot.start_index is not None + elif annot.type == "file_path": + assert annot.end_index is not None + assert annot.file_path.file_id + assert annot.start_index is not None + elif content.type == "image_file": + assert content.index is not None + assert content.image_file.file_id + + async def on_message_done(self, message: Message): + for msg in message.content: + if msg.type == "image_file": + assert msg.image_file.file_id + if msg.type == "text": + assert msg.text.value + if msg.text.annotations: + for annot in msg.text.annotations: + if annot.type == "file_citation": + assert annot.end_index is not None + assert annot.file_citation.file_id + assert annot.file_citation.quote + assert annot.start_index is not None + assert annot.text is not None + elif annot.type == "file_path": + assert annot.end_index is not None + assert annot.file_path.file_id + assert annot.start_index is not None + assert annot.text is not None + + async def on_text_created(self, text: Text): + assert text.value is not None + + async def on_text_done(self, text: Text): + assert text.value is not None + for annot in text.annotations: + if annot.type == "file_citation": + assert annot.end_index is not None + assert annot.file_citation.file_id + assert annot.file_citation.quote + assert annot.start_index is not None + assert annot.text is not None + elif annot.type == "file_path": + assert annot.end_index is not None + assert annot.file_path.file_id + assert annot.start_index is not None + assert annot.text is not None + + async def on_image_file_done(self, image_file: ImageFile): + assert image_file.file_id + + async def on_tool_call_created(self, tool_call: ToolCall): + assert tool_call.id + + async def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall): + if delta.type == "code_interpreter": + assert delta.index is not None + if delta.code_interpreter: + if delta.code_interpreter.input: + assert delta.code_interpreter.input is not None + if delta.code_interpreter.outputs: + for output in delta.code_interpreter.outputs: + if output.type == "image": + assert output.image.file_id + elif output.type == "logs": + assert output.logs + if delta.type == "function": + assert delta.id + if delta.function: + assert delta.function.arguments is not None + assert delta.function.name is not None + + async def on_tool_call_done(self, tool_call: ToolCall): + if tool_call.type == "code_interpreter": + assert tool_call.id + assert tool_call.code_interpreter.input is not None + for output in tool_call.code_interpreter.outputs: + if output.type == "image": + assert output.image.file_id + elif output.type == "logs": + assert output.logs + if tool_call.type == "function": + assert tool_call.id + assert tool_call.function.arguments is not None + assert tool_call.function.name is not None class TestAssistantsAsync(AzureRecordedTestCase): + def handle_run_failure(self, run: Run): + if run.status == "failed": + if "Rate limit" in run.last_error.message: + pytest.skip("Skipping - Rate limit reached.") + raise openai.OpenAIError(run.last_error.message) + if run.status not in ["completed", "requires_action"]: + raise openai.OpenAIError(f"Run in unexpected status: {run.status}") + @configure_async @pytest.mark.asyncio @pytest.mark.parametrize( @@ -73,7 +246,6 @@ async def test_assistants_threads_crud(self, client_async, api_type, api_version ], metadata={"key": "value"}, ) - retrieved_thread = await client_async.beta.threads.retrieve( thread_id=thread.id, ) @@ -95,7 +267,6 @@ async def test_assistants_threads_crud(self, client_async, api_type, api_version assert delete_thread.id == thread.id assert delete_thread.deleted is True - @pytest.mark.skip(reason="AOAI doesn't support assistants v2 yet") @configure_async @pytest.mark.asyncio @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) @@ -166,12 +337,156 @@ async def test_assistants_messages_crud(self, client_async, api_type, api_versio ) assert delete_thread.id == thread.id assert delete_thread.deleted is True + delete_file = await client_async.files.delete(file.id) + assert delete_file.deleted is True @configure_async @pytest.mark.asyncio @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) - async def test_assistants_runs_code(self, client_async, api_type, api_version, **kwargs): + async def test_assistants_vector_stores_crud(self, client_async, api_type, api_version, **kwargs): + file_name = f"test{uuid.uuid4()}.txt" + with open(file_name, "w") as f: + f.write("test") + + path = pathlib.Path(file_name) + + file = await client_async.files.create( + file=open(path, "rb"), + purpose="assistants" + ) + + try: + vector_store = await client_async.beta.vector_stores.create( + name="Support FAQ" + ) + assert vector_store.name == "Support FAQ" + assert vector_store.id + assert vector_store.object == "vector_store" + assert vector_store.created_at + assert vector_store.file_counts.total == 0 + + vectors = client_async.beta.vector_stores.list() + async for vector in vectors: + assert vector.id + assert vector_store.object == "vector_store" + assert vector_store.created_at + + vector_store = await client_async.beta.vector_stores.update( + vector_store_id=vector_store.id, + name="Support FAQ and more", + metadata={"Q": "A"} + ) + retrieved_vector = await client_async.beta.vector_stores.retrieve( + vector_store_id=vector_store.id + ) + assert retrieved_vector.id == vector_store.id + assert retrieved_vector.name == "Support FAQ and more" + assert retrieved_vector.metadata == {"Q": "A"} + + vector_store_file = await client_async.beta.vector_stores.files.create( + vector_store_id=vector_store.id, + file_id=file.id + ) + assert vector_store_file.id + assert vector_store_file.object == "vector_store.file" + assert vector_store_file.created_at + assert vector_store_file.vector_store_id == vector_store.id + + vector_store_files = client_async.beta.vector_stores.files.list( + vector_store_id=vector_store.id + ) + async for vector_file in vector_store_files: + assert vector_file.id + assert vector_file.object == "vector_store.file" + assert vector_store_file.created_at + assert vector_store_file.vector_store_id == vector_store.id + + vector_store_file_2 = await client_async.beta.vector_stores.files.retrieve( + vector_store_id=vector_store.id, + file_id=file.id + ) + assert vector_store_file_2.id == vector_store_file.id + assert vector_store_file.vector_store_id == vector_store.id + + finally: + os.remove(path) + deleted_vector_store_file = await client_async.beta.vector_stores.files.delete( + vector_store_id=vector_store.id, + file_id=file.id + ) + assert deleted_vector_store_file.deleted is True + deleted_vector_store = await client_async.beta.vector_stores.delete( + vector_store_id=vector_store.id + ) + assert deleted_vector_store.deleted is True + + @configure_async + @pytest.mark.asyncio + @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) + async def test_assistants_vector_stores_batch_crud(self, client_async, api_type, api_version, **kwargs): + file_name = f"test{uuid.uuid4()}.txt" + file_name_2 = f"test{uuid.uuid4()}.txt" + with open(file_name, "w") as f: + f.write("test") + + path = pathlib.Path(file_name) + + file = await client_async.files.create( + file=open(path, "rb"), + purpose="assistants" + ) + with open(file_name_2, "w") as f: + f.write("test") + path_2 = pathlib.Path(file_name_2) + file_2 = await client_async.files.create( + file=open(path_2, "rb"), + purpose="assistants" + ) + try: + vector_store = await client_async.beta.vector_stores.create( + name="Support FAQ" + ) + vector_store_file_batch = await client_async.beta.vector_stores.file_batches.create( + vector_store_id=vector_store.id, + file_ids=[file.id, file_2.id] + ) + assert vector_store_file_batch.id + assert vector_store_file_batch.object == "vector_store.file_batch" + assert vector_store_file_batch.created_at + assert vector_store_file_batch.status + + vectors = await client_async.beta.vector_stores.file_batches.list_files( + vector_store_id=vector_store.id, + batch_id=vector_store_file_batch.id + ) + for vector in vectors: + assert vector.id + assert vector.object == "vector_store.file" + assert vector.created_at + + retrieved_vector_store_file_batch = await client_async.beta.vector_stores.file_batches.retrieve( + vector_store_id=vector_store.id, + batch_id=vector_store_file_batch.id + ) + assert retrieved_vector_store_file_batch.id == vector_store_file_batch.id + + finally: + os.remove(path) + os.remove(path_2) + delete_file = await client_async.files.delete(file.id) + assert delete_file.deleted is True + delete_file = await client_async.files.delete(file_2.id) + assert delete_file.deleted is True + deleted_vector_store = await client_async.beta.vector_stores.delete( + vector_store_id=vector_store.id + ) + assert deleted_vector_store.deleted is True + + @configure_async + @pytest.mark.asyncio + @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) + async def test_assistants_runs_code(self, client_async, api_type, api_version, **kwargs): try: assistant = await client_async.beta.assistants.create( name="python test", @@ -187,31 +502,19 @@ async def test_assistants_runs_code(self, client_async, api_type, api_version, * content="I need to solve the equation `3x + 11 = 14`. Can you help me?", ) - run = await client_async.beta.threads.runs.create( + run = await client_async.beta.threads.runs.create_and_poll( thread_id=thread.id, assistant_id=assistant.id, instructions="Please address the user as Jane Doe.", - # additional_instructions="After solving each equation, say 'Isn't math fun?'", # not supported by AOAI yet + additional_instructions="After solving each equation, say 'Isn't math fun?'", ) + self.handle_run_failure(run) + if run.status == "completed": + messages = client_async.beta.threads.messages.list(thread_id=thread.id) - start_time = time.time() - - while True: - if time.time() - start_time > TIMEOUT: - raise TimeoutError("Run timed out") - - run = await client_async.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id) - - if run.status == "completed": - messages = client_async.beta.threads.messages.list(thread_id=thread.id) - - async for message in messages: - assert message.content[0].type == "text" - assert message.content[0].text.value - - break - else: - time.sleep(5) + async for message in messages: + assert message.content[0].type == "text" + assert message.content[0].text.value run = await client_async.beta.threads.runs.update( thread_id=thread.id, @@ -233,31 +536,37 @@ async def test_assistants_runs_code(self, client_async, api_type, api_version, * assert delete_thread.id == thread.id assert delete_thread.deleted is True - @pytest.mark.skip("AOAI does not support retrieval tools yet") @configure_async @pytest.mark.asyncio @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) - async def test_assistants_runs_retrieval(self, client_async, api_type, api_version, **kwargs): + async def test_assistants_runs_file_search(self, client_async, api_type, api_version, **kwargs): file_name = f"test{uuid.uuid4()}.txt" with open(file_name, "w") as f: f.write("Contoso company policy requires that all employees take at least 10 vacation days a year.") path = pathlib.Path(file_name) - file = await client_async.files.create( - file=open(path, "rb"), - purpose="assistants" - ) try: + vector_store = await client_async.beta.vector_stores.create( + name="Support FAQ", + ) + await client_async.beta.vector_stores.files.upload_and_poll( + vector_store_id=vector_store.id, + file=path + ) assistant = await client_async.beta.assistants.create( name="python test", instructions="You help answer questions about Contoso company policy.", - tools=[{"type": "retrieval"}], - file_ids=[file.id], + tools=[{"type": "file_search"}], + tool_resources={ + "file_search": { + "vector_store_ids": [vector_store.id] + } + }, **kwargs ) - run = await client_async.beta.threads.create_and_run( + run = await client_async.beta.threads.create_and_run_poll( assistant_id=assistant.id, thread={ "messages": [ @@ -265,25 +574,13 @@ async def test_assistants_runs_retrieval(self, client_async, api_type, api_versi ] } ) + self.handle_run_failure(run) + if run.status == "completed": + messages = client_async.beta.threads.messages.list(thread_id=run.thread_id) - start_time = time.time() - - while True: - if time.time() - start_time > TIMEOUT: - raise TimeoutError("Run timed out") - - run = await client_async.beta.threads.runs.retrieve(thread_id=run.thread_id, run_id=run.id) - - if run.status == "completed": - messages = client_async.beta.threads.messages.list(thread_id=run.thread_id) - - async for message in messages: - assert message.content[0].type == "text" - assert message.content[0].text.value - - break - - time.sleep(5) + async for message in messages: + assert message.content[0].type == "text" + assert message.content[0].text.value finally: os.remove(path) @@ -298,6 +595,10 @@ async def test_assistants_runs_retrieval(self, client_async, api_type, api_versi ) assert delete_thread.id assert delete_thread.deleted is True + deleted_vector_store = await client_async.beta.vector_stores.delete( + vector_store_id=vector_store.id + ) + assert deleted_vector_store.deleted is True @configure_async @pytest.mark.asyncio @@ -334,7 +635,7 @@ async def test_assistants_runs_functions(self, client_async, api_type, api_versi **kwargs, ) - run = await client_async.beta.threads.create_and_run( + run = await client_async.beta.threads.create_and_run_poll( assistant_id=assistant.id, thread={ "messages": [ @@ -342,36 +643,26 @@ async def test_assistants_runs_functions(self, client_async, api_type, api_versi ] } ) - start_time = time.time() - - while True: - if time.time() - start_time > TIMEOUT: - raise TimeoutError("Run timed out") - - run = await client_async.beta.threads.runs.retrieve(thread_id=run.thread_id, run_id=run.id) - - if run.status == "requires_action": - run = await client_async.beta.threads.runs.submit_tool_outputs( - thread_id=run.thread_id, - run_id=run.id, - tool_outputs=[ - { - "tool_call_id": run.required_action.submit_tool_outputs.tool_calls[0].id, - "output": "{\"temperature\": \"22\", \"unit\": \"celsius\", \"description\": \"Sunny\"}" - } - ] - ) - - if run.status == "completed": - messages = client_async.beta.threads.messages.list(thread_id=run.thread_id) - - async for message in messages: - assert message.content[0].type == "text" - assert message.content[0].text.value + self.handle_run_failure(run) + if run.status == "requires_action": + run = await client_async.beta.threads.runs.submit_tool_outputs_and_poll( + thread_id=run.thread_id, + run_id=run.id, + tool_outputs=[ + { + "tool_call_id": run.required_action.submit_tool_outputs.tool_calls[0].id, + "output": "{\"temperature\": \"22\", \"unit\": \"celsius\", \"description\": \"Sunny\"}" + } + ] + ) + self.handle_run_failure(run) + if run.status == "completed": + messages = client_async.beta.threads.messages.list(thread_id=run.thread_id) - break + async for message in messages: + assert message.content[0].type == "text" + assert message.content[0].text.value - time.sleep(5) runs = client_async.beta.threads.runs.list(thread_id=run.thread_id) async for r in runs: @@ -415,3 +706,65 @@ async def test_assistants_runs_functions(self, client_async, api_type, api_versi ) assert delete_thread.id assert delete_thread.deleted is True + + @configure_async + @pytest.mark.asyncio + @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) + async def test_assistants_streaming(self, client_async, api_type, api_version, **kwargs): + assistant = await client_async.beta.assistants.create( + name="Math Tutor", + instructions="You are a personal math tutor. Write and run code to answer math questions.", + tools=[{"type": "code_interpreter"}], + **kwargs, + ) + try: + thread = await client_async.beta.threads.create() + await client_async.beta.threads.messages.create( + thread_id=thread.id, + role="user", + content="I need to solve the equation `3x + 11 = 14`. Can you help me?", + ) + stream = await client_async.beta.threads.runs.create( + thread_id=thread.id, + assistant_id=assistant.id, + instructions="Please address the user as Jane Doe. The user has a premium account.", + stream=True, + ) + + async for event in stream: + assert event + finally: + await client_async.beta.assistants.delete(assistant.id) + + @configure_async + @pytest.mark.asyncio + @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) + async def test_assistants_stream_event_handler(self, client_async, api_type, api_version, **kwargs): + assistant = await client_async.beta.assistants.create( + name="Math Tutor", + instructions="You are a personal math tutor. Write and run code to answer math questions.", + tools=[{"type": "code_interpreter"}], + **kwargs + ) + + try: + question = "I need to solve the equation `3x + 11 = 14`. Can you help me and then generate an image with the answer?" + + thread = await client_async.beta.threads.create( + messages=[ + { + "role": "user", + "content": question, + }, + ] + ) + + async with client_async.beta.threads.runs.stream( + thread_id=thread.id, + assistant_id=assistant.id, + instructions="Please address the user as Jane Doe. The user has a premium account.", + event_handler=AsyncEventHandler(), + ) as stream: + await stream.until_done() + finally: + await client_async.beta.assistants.delete(assistant.id) diff --git a/sdk/openai/azure-openai/tests/test_chat_completions.py b/sdk/openai/azure-openai/tests/test_chat_completions.py index 135fcfe11154..ce8e526a0d46 100644 --- a/sdk/openai/azure-openai/tests/test_chat_completions.py +++ b/sdk/openai/azure-openai/tests/test_chat_completions.py @@ -814,7 +814,7 @@ def test_chat_completion_seed(self, client, api_type, api_version, **kwargs): assert completion.system_fingerprint @configure - @pytest.mark.parametrize("api_type, api_version", [(GPT_4_AZURE, GA), (GPT_4_AZURE, GA), (GPT_4_OPENAI, "v1")]) + @pytest.mark.parametrize("api_type, api_version", [(GPT_4_AZURE, GA), (GPT_4_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) def test_chat_completion_json_response(self, client, api_type, api_version, **kwargs): messages = [ {"role": "system", "content": "You are a helpful assistant."}, diff --git a/sdk/openai/azure-openai/tests/test_chat_completions_async.py b/sdk/openai/azure-openai/tests/test_chat_completions_async.py index 46919ab4101e..1ebd605514b2 100644 --- a/sdk/openai/azure-openai/tests/test_chat_completions_async.py +++ b/sdk/openai/azure-openai/tests/test_chat_completions_async.py @@ -834,7 +834,7 @@ async def test_chat_completion_seed(self, client_async, api_type, api_version, * @configure_async @pytest.mark.asyncio - @pytest.mark.parametrize("api_type, api_version", [(GPT_4_AZURE, GA), (GPT_4_AZURE, GA), (GPT_4_OPENAI, "v1")]) + @pytest.mark.parametrize("api_type, api_version", [(GPT_4_AZURE, GA), (GPT_4_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) async def test_chat_completion_json_response(self, client_async, api_type, api_version, **kwargs): messages = [ {"role": "system", "content": "You are a helpful assistant."},