@@ -33,7 +33,7 @@ def __init__(self, tracking_uri: str, s3_file_system: S3FileSystem):
3333 self ._client = MlflowClient (tracking_uri = tracking_uri )
3434 self ._s3_file_system = s3_file_system
3535
36- def create_experiment (
36+ async def create_experiment (
3737 self ,
3838 name : str ,
3939 description : str ,
@@ -79,15 +79,15 @@ def create_experiment(
7979 created_at = datetime .fromtimestamp (experiment .creation_time / 1000 ),
8080 )
8181
82- def delete_experiment (self , experiment_id : str ) -> None :
82+ async def delete_experiment (self , experiment_id : str ) -> None :
8383 """Delete an experiment. Although Mflow has a delete_experiment method,
8484 We will use the functions of this class instead, so that we make sure we correctly
8585 clean up all the artifacts/runs/etc. associated with the experiment.
8686 """
8787 workflow_ids = self ._find_workflows (experiment_id )
8888 # delete all the workflows
8989 for workflow_id in workflow_ids :
90- self .delete_workflow (workflow_id )
90+ await self .delete_workflow (workflow_id )
9191 # delete the experiment
9292 self ._client .delete_experiment (experiment_id )
9393
@@ -164,7 +164,7 @@ async def _format_experiment(self, experiment: MlflowExperiment) -> GetExperimen
164164 workflows = workflows ,
165165 )
166166
167- def update_experiment (self , experiment_id : str , new_name : str ) -> None :
167+ async def update_experiment (self , experiment_id : str , new_name : str ) -> None :
168168 """Update the name of an experiment."""
169169 raise NotImplementedError
170170
@@ -199,7 +199,7 @@ async def experiments_count(self):
199199 # this corresponds to creating a run in MLflow.
200200 # The run will have n number of nested runs,
201201 # which correspond to what we call "jobs" in our system
202- def create_workflow (
202+ async def create_workflow (
203203 self , experiment_id : str , description : str , name : str , model : str , system_prompt : str
204204 ) -> WorkflowResponse :
205205 """Create a new workflow."""
@@ -256,7 +256,7 @@ async def get_workflow(self, workflow_id: str) -> WorkflowDetailsResponse | None
256256 system_prompt = workflow .data .tags .get ("system_prompt" ),
257257 status = WorkflowStatus (workflow .data .tags .get ("status" )),
258258 created_at = datetime .fromtimestamp (workflow .info .start_time / 1000 ),
259- jobs = [self .get_job (job_id ) for job_id in all_job_ids ],
259+ jobs = [await self .get_job (job_id ) for job_id in all_job_ids ],
260260 metrics = self ._compile_metrics (all_job_ids ),
261261 parameters = self ._compile_parameters (all_job_ids ),
262262 )
@@ -302,7 +302,7 @@ async def get_workflow(self, workflow_id: str) -> WorkflowDetailsResponse | None
302302 workflow_details .artifacts_download_url = download_url
303303 return workflow_details
304304
305- def update_workflow_status (self , workflow_id : str , status : WorkflowStatus ) -> None :
305+ async def update_workflow_status (self , workflow_id : str , status : WorkflowStatus ) -> None :
306306 """Update the status of a workflow."""
307307 self ._client .set_tag (workflow_id , "status" , status .value )
308308
@@ -328,7 +328,7 @@ def _get_ray_job_logs(self, ray_job_id: str):
328328 loguru .logger .error (f"Response text: { resp .text } " )
329329 raise JobUpstreamError (ray_job_id , "JSON decode error in Ray response" ) from e
330330
331- def get_workflow_logs (self , workflow_id : str ) -> JobLogsResponse :
331+ async def get_workflow_logs (self , workflow_id : str ) -> JobLogsResponse :
332332 workflow_run = self ._client .get_run (workflow_id )
333333 # get the jobs associated with the workflow
334334 all_jobs = self ._client .search_runs (
@@ -343,7 +343,7 @@ def get_workflow_logs(self, workflow_id: str) -> JobLogsResponse:
343343 # TODO: This is not a great solution but it matches the current API
344344 return JobLogsResponse (logs = "\n ================\n " .join ([log .logs for log in logs ]))
345345
346- def delete_workflow (self , workflow_id : str ) -> WorkflowResponse :
346+ async def delete_workflow (self , workflow_id : str ) -> WorkflowResponse :
347347 """Delete a workflow."""
348348 # first, get the workflow
349349 workflow = self ._client .get_run (workflow_id )
@@ -370,11 +370,11 @@ def delete_workflow(self, workflow_id: str) -> WorkflowResponse:
370370 created_at = datetime .fromtimestamp (workflow .info .start_time / 1000 ),
371371 )
372372
373- def list_workflows (self , experiment_id : str ) -> list :
373+ async def list_workflows (self , experiment_id : str ) -> list :
374374 """List all workflows in an experiment."""
375375 raise NotImplementedError
376376
377- def create_job (self , experiment_id : str , workflow_id : str , name : str , job_id : str ) -> str :
377+ async def create_job (self , experiment_id : str , workflow_id : str , name : str , job_id : str ) -> str :
378378 """Link a started job to an experiment and a workflow."""
379379 run = self ._client .create_run (
380380 experiment_id = experiment_id ,
@@ -384,14 +384,14 @@ def create_job(self, experiment_id: str, workflow_id: str, name: str, job_id: st
384384 self ._client .log_param (run .info .run_id , "ray_job_id" , job_id )
385385 return run .info .run_id
386386
387- def update_job (self , job_id : str , data : RunOutputs ):
387+ async def update_job (self , job_id : str , data : RunOutputs ):
388388 """Update the metrics and parameters of a job."""
389389 for metric , value in data .metrics .items ():
390390 self ._client .log_metric (job_id , metric , value )
391391 for parameter , value in data .parameters .items ():
392392 self ._client .log_param (job_id , parameter , value )
393393
394- def get_job (self , job_id : str ):
394+ async def get_job (self , job_id : str ):
395395 """Get the results of a job."""
396396 run = self ._client .get_run (job_id )
397397 if run .info .lifecycle_stage == "deleted" :
@@ -404,11 +404,11 @@ def get_job(self, job_id: str):
404404 artifact_url = "TODO" ,
405405 )
406406
407- def delete_job (self , job_id : str ):
407+ async def delete_job (self , job_id : str ):
408408 """Delete a job."""
409409 self ._client .delete_run (job_id )
410410
411- def list_jobs (self , workflow_id : str ):
411+ async def list_jobs (self , workflow_id : str ):
412412 """List all jobs in a workflow."""
413413 workflow_run = self ._client .get_run (workflow_id )
414414 # get the jobs associated with the workflow
0 commit comments