11from  azureml .pipeline .core .graph  import  PipelineParameter 
22from  azureml .pipeline .steps  import  PythonScriptStep 
33from  azureml .pipeline .core  import  Pipeline , PipelineData 
4- from  azureml .core  import  Workspace 
4+ from  azureml .core  import  Workspace ,  Dataset ,  Datastore 
55from  azureml .core .runconfig  import  RunConfiguration 
6- from  azureml .core  import  Dataset 
76from  ml_service .util .attach_compute  import  get_compute 
87from  ml_service .util .env_variables  import  Env 
98from  ml_service .util .manage_environment  import  get_environment 
@@ -39,8 +38,20 @@ def main():
3938    run_config  =  RunConfiguration ()
4039    run_config .environment  =  environment 
4140
41+     if  (e .datastore_name ):
42+         datastore_name  =  e .datastore_name 
43+     else :
44+         datastore_name  =  aml_workspace .get_default_datastore ().name 
45+     run_config .environment .environment_variables ["DATASTORE_NAME" ] =  datastore_name   # NOQA: E501 
46+ 
4247    model_name_param  =  PipelineParameter (
4348        name = "model_name" , default_value = e .model_name )
49+     dataset_version_param  =  PipelineParameter (
50+         name = "dataset_version" , default_value = e .dataset_version )
51+     data_file_path_param  =  PipelineParameter (
52+         name = "data_file_path" , default_value = "none" )
53+     caller_run_id_param  =  PipelineParameter (
54+         name = "caller_run_id" , default_value = "none" )
4455
4556    # Get dataset name 
4657    dataset_name  =  e .dataset_name 
@@ -57,9 +68,9 @@ def main():
5768        df .to_csv (file_name , index = False )
5869
5970        # Upload file to default datastore in workspace 
60-         default_ds  =  aml_workspace . get_default_datastore ( )
71+         datatstore  =  Datastore . get ( aml_workspace ,  datastore_name )
6172        target_path  =  'training-data/' 
62-         default_ds .upload_files (
73+         datatstore .upload_files (
6374            files = [file_name ],
6475            target_path = target_path ,
6576            overwrite = True ,
@@ -68,17 +79,14 @@ def main():
6879        # Register dataset 
6980        path_on_datastore  =  os .path .join (target_path , file_name )
7081        dataset  =  Dataset .Tabular .from_delimited_files (
71-             path = (default_ds , path_on_datastore ))
82+             path = (datatstore , path_on_datastore ))
7283        dataset  =  dataset .register (
7384            workspace = aml_workspace ,
7485            name = dataset_name ,
7586            description = 'diabetes training data' ,
7687            tags = {'format' : 'CSV' },
7788            create_new_version = True )
7889
79-     # Get the dataset 
80-     dataset  =  Dataset .get_by_name (aml_workspace , dataset_name )
81- 
8290    # Create a PipelineData to pass data between steps 
8391    pipeline_data  =  PipelineData (
8492        'pipeline_data' ,
@@ -89,11 +97,14 @@ def main():
8997        script_name = e .train_script_path ,
9098        compute_target = aml_compute ,
9199        source_directory = e .sources_directory_train ,
92-         inputs = [dataset .as_named_input ('training_data' )],
93100        outputs = [pipeline_data ],
94101        arguments = [
95102            "--model_name" , model_name_param ,
96-             "--step_output" , pipeline_data 
103+             "--step_output" , pipeline_data ,
104+             "--dataset_version" , dataset_version_param ,
105+             "--data_file_path" , data_file_path_param ,
106+             "--caller_run_id" , caller_run_id_param ,
107+             "--dataset_name" , dataset_name ,
97108        ],
98109        runconfig = run_config ,
99110        allow_reuse = False ,
0 commit comments