Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
fix more flakiness and only run openai on weekly
  • Loading branch information
kristapratico committed Jun 7, 2024
commit ee5553a3af40d4e29125059166444d7dbc8f212e
11 changes: 11 additions & 0 deletions sdk/openai/azure-openai/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
DefaultAzureCredential as AsyncDefaultAzureCredential,
get_bearer_token_provider as get_bearer_token_provider_async,
)
from ci_tools.variables import in_ci


# for pytest.parametrize
Expand Down Expand Up @@ -65,8 +66,15 @@
ENV_OPENAI_TTS_MODEL = "tts-1"


def skip_openai_test(api_type) -> bool:
    """Return True when a plain-OpenAI test should be skipped in this run.

    Non-Azure ("openai") tests are only exercised by the weekly CI pipeline;
    in any other CI pipeline they are skipped to cut flakiness and quota use.
    Outside CI (local runs) nothing is skipped.
    """
    running_weekly = "tests-weekly" in os.getenv("SYSTEM_DEFINITIONNAME", "")
    is_openai_api = "openai" in api_type
    return in_ci() and is_openai_api and not running_weekly


@pytest.fixture
def client(api_type, api_version):
if skip_openai_test(api_type):
pytest.skip("Skipping openai tests - they only run on tests-weekly.")

if api_type == "azure":
client = openai.AzureOpenAI(
azure_endpoint=os.getenv(ENV_AZURE_OPENAI_ENDPOINT),
Expand Down Expand Up @@ -100,6 +108,9 @@ def client(api_type, api_version):

@pytest.fixture
def client_async(api_type, api_version):
if skip_openai_test(api_type):
pytest.skip("Skipping openai tests - they only run on tests-weekly.")

if api_type == "azure":
client = openai.AsyncAzureOpenAI(
azure_endpoint=os.getenv(ENV_AZURE_OPENAI_ENDPOINT),
Expand Down
27 changes: 18 additions & 9 deletions sdk/openai/azure-openai/tests/test_assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
TextDelta,
MessageDelta,
)
from openai.types.beta.threads import Run
from openai.types.beta.threads.runs import RunStep, ToolCall, RunStepDelta, ToolCallDelta


Expand Down Expand Up @@ -181,6 +182,14 @@ def on_tool_call_done(self, tool_call: ToolCall):

class TestAssistants(AzureRecordedTestCase):

def handle_run_failure(self, run: Run):
    """Inspect a polled run and skip or fail the test on bad terminal states.

    - "failed" due to rate limiting -> pytest.skip (not a product bug).
    - "failed" for any other reason -> raise with the service error message.
    - any status other than "completed"/"requires_action" -> raise, since the
      test expects the run to have finished (or be awaiting tool output).
    """
    if run.status == "failed":
        error_message = run.last_error.message
        if "Rate limit" in error_message:
            pytest.skip("Skipping - Rate limit reached.")
        raise openai.OpenAIError(error_message)
    if run.status not in ("completed", "requires_action"):
        raise openai.OpenAIError(f"Run in unexpected status: {run.status}")

@configure
@pytest.mark.parametrize(
"api_type, api_version",
Expand Down Expand Up @@ -496,8 +505,7 @@ def test_assistants_runs_code(self, client, api_type, api_version, **kwargs):
instructions="Please address the user as Jane Doe.",
additional_instructions="After solving each equation, say 'Isn't math fun?'",
)
if run.status == "failed":
raise openai.OpenAIError(run.last_error.message)
self.handle_run_failure(run)
if run.status == "completed":
messages = client.beta.threads.messages.list(thread_id=thread.id)

Expand Down Expand Up @@ -540,7 +548,7 @@ def test_assistants_runs_file_search(self, client, api_type, api_version, **kwar
)
client.beta.vector_stores.files.upload_and_poll(
vector_store_id=vector_store.id,
file_id=path
file=path
)
assistant = client.beta.assistants.create(
name="python test",
Expand All @@ -562,8 +570,7 @@ def test_assistants_runs_file_search(self, client, api_type, api_version, **kwar
]
}
)
if run.status == "failed":
raise openai.OpenAIError(run.last_error.message)
self.handle_run_failure(run)
if run.status == "completed":
messages = client.beta.threads.messages.list(thread_id=run.thread_id)

Expand Down Expand Up @@ -631,8 +638,7 @@ def test_assistants_runs_functions(self, client, api_type, api_version, **kwargs
]
}
)
if run.status == "failed":
raise openai.OpenAIError(run.last_error.message)
self.handle_run_failure(run)
if run.status == "requires_action":
run = client.beta.threads.runs.submit_tool_outputs_and_poll(
thread_id=run.thread_id,
Expand All @@ -644,7 +650,7 @@ def test_assistants_runs_functions(self, client, api_type, api_version, **kwargs
}
]
)

self.handle_run_failure(run)
if run.status == "completed":
messages = client.beta.threads.messages.list(thread_id=run.thread_id)

Expand All @@ -666,10 +672,13 @@ def test_assistants_runs_functions(self, client, api_type, api_version, **kwargs
thread_id=run.thread_id,
run_id=r.id
)
for step in run_steps:
assert step.id

retrieved_step = client.beta.threads.runs.steps.retrieve(
thread_id=run.thread_id,
run_id=r.id,
step_id=run_steps.data[0].id
step_id=step.id
)
assert retrieved_step.id
assert retrieved_step.created_at
Expand Down
62 changes: 23 additions & 39 deletions sdk/openai/azure-openai/tests/test_assistants_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# ------------------------------------

import os
import time
import pytest
import pathlib
import uuid
Expand All @@ -19,10 +18,9 @@
TextDelta,
MessageDelta,
)
from openai.types.beta.threads import Run
from openai.types.beta.threads.runs import RunStep, ToolCall, RunStepDelta, ToolCallDelta

TIMEOUT = 300


class AsyncEventHandler(AsyncAssistantEventHandler):
async def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
Expand Down Expand Up @@ -183,6 +181,14 @@ async def on_tool_call_done(self, tool_call: ToolCall):

class TestAssistantsAsync(AzureRecordedTestCase):

def handle_run_failure(self, run: Run):
    """Validate a run's terminal status, skipping or raising as appropriate.

    A rate-limited failure skips the test (environment issue, not a bug);
    any other failure, or an unexpected non-terminal status, raises an
    OpenAIError so the test reports the service-side problem.
    """
    failed = run.status == "failed"
    acceptable = run.status in ["completed", "requires_action"]
    if failed:
        if "Rate limit" in run.last_error.message:
            pytest.skip("Skipping - Rate limit reached.")
        raise openai.OpenAIError(run.last_error.message)
    if not acceptable:
        raise openai.OpenAIError(f"Run in unexpected status: {run.status}")

@configure_async
@pytest.mark.asyncio
@pytest.mark.parametrize(
Expand Down Expand Up @@ -240,7 +246,6 @@ async def test_assistants_threads_crud(self, client_async, api_type, api_version
],
metadata={"key": "value"},
)

retrieved_thread = await client_async.beta.threads.retrieve(
thread_id=thread.id,
)
Expand Down Expand Up @@ -482,7 +487,6 @@ async def test_assistants_vector_stores_batch_crud(self, client_async, api_type,
@pytest.mark.asyncio
@pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")])
async def test_assistants_runs_code(self, client_async, api_type, api_version, **kwargs):

try:
assistant = await client_async.beta.assistants.create(
name="python test",
Expand All @@ -498,31 +502,19 @@ async def test_assistants_runs_code(self, client_async, api_type, api_version, *
content="I need to solve the equation `3x + 11 = 14`. Can you help me?",
)

run = await client_async.beta.threads.runs.create(
run = await client_async.beta.threads.runs.create_and_poll(
thread_id=thread.id,
assistant_id=assistant.id,
instructions="Please address the user as Jane Doe.",
additional_instructions="After solving each equation, say 'Isn't math fun?'",
)
self.handle_run_failure(run)
if run.status == "completed":
messages = client_async.beta.threads.messages.list(thread_id=thread.id)

start_time = time.time()

while True:
if time.time() - start_time > TIMEOUT:
raise TimeoutError("Run timed out")

run = await client_async.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)

if run.status == "completed":
messages = client_async.beta.threads.messages.list(thread_id=thread.id)

async for message in messages:
assert message.content[0].type == "text"
assert message.content[0].text.value

break
else:
time.sleep(5)
async for message in messages:
assert message.content[0].type == "text"
assert message.content[0].text.value

run = await client_async.beta.threads.runs.update(
thread_id=thread.id,
Expand Down Expand Up @@ -554,14 +546,13 @@ async def test_assistants_runs_file_search(self, client_async, api_type, api_ver

path = pathlib.Path(file_name)

file = await client_async.files.create(
file=open(path, "rb"),
purpose="assistants"
)
try:
vector_store = await client_async.beta.vector_stores.create(
name="Support FAQ",
file_ids=[file.id]
)
await client_async.beta.vector_stores.files.upload_and_poll(
vector_store_id=vector_store.id,
file=path
)
assistant = await client_async.beta.assistants.create(
name="python test",
Expand All @@ -583,7 +574,7 @@ async def test_assistants_runs_file_search(self, client_async, api_type, api_ver
]
}
)

self.handle_run_failure(run)
if run.status == "completed":
messages = client_async.beta.threads.messages.list(thread_id=run.thread_id)

Expand All @@ -604,11 +595,6 @@ async def test_assistants_runs_file_search(self, client_async, api_type, api_ver
)
assert delete_thread.id
assert delete_thread.deleted is True
deleted_vector_store_file = await client_async.beta.vector_stores.files.delete(
vector_store_id=vector_store.id,
file_id=file.id
)
assert deleted_vector_store_file.deleted is True
deleted_vector_store = await client_async.beta.vector_stores.delete(
vector_store_id=vector_store.id
)
Expand Down Expand Up @@ -657,9 +643,7 @@ async def test_assistants_runs_functions(self, client_async, api_type, api_versi
]
}
)

if run.status == "failed":
raise openai.OpenAIError(run.last_error.message)
self.handle_run_failure(run)
if run.status == "requires_action":
run = await client_async.beta.threads.runs.submit_tool_outputs_and_poll(
thread_id=run.thread_id,
Expand All @@ -671,7 +655,7 @@ async def test_assistants_runs_functions(self, client_async, api_type, api_versi
}
]
)

self.handle_run_failure(run)
if run.status == "completed":
messages = client_async.beta.threads.messages.list(thread_id=run.thread_id)

Expand Down