| 
1 | 1 | import argparse  | 
 | 2 | +import os  | 
2 | 3 | 
 
  | 
3 | 4 | import numpy as np  | 
4 | 5 | import pandas as pd  | 
 | 
10 | 11 | from azureml.automl.runtime.shared.score import scoring, constants  | 
11 | 12 | from azureml.core import Run  | 
12 | 13 | 
 
  | 
13 |  | -import torch  | 
 | 14 | +try:  | 
 | 15 | +    import torch  | 
 | 16 | + | 
 | 17 | +    _torch_present = True  | 
 | 18 | +except ImportError:  | 
 | 19 | +    _torch_present = False  | 
14 | 20 | 
 
  | 
15 | 21 | 
 
  | 
16 | 22 | def align_outputs(y_predicted, X_trans, X_test, y_test,  | 
@@ -50,7 +56,7 @@ def align_outputs(y_predicted, X_trans, X_test, y_test,  | 
50 | 56 |     # or at edges of time due to lags/rolling windows  | 
51 | 57 |     clean = together[together[[target_column_name,  | 
52 | 58 |                                predicted_column_name]].notnull().all(axis=1)]  | 
53 |  | -    return(clean)  | 
 | 59 | +    return (clean)  | 
54 | 60 | 
 
  | 
55 | 61 | 
 
  | 
56 | 62 | def do_rolling_forecast_with_lookback(fitted_model, X_test, y_test,  | 
@@ -85,8 +91,7 @@ def do_rolling_forecast_with_lookback(fitted_model, X_test, y_test,  | 
85 | 91 |         if origin_time != X[time_column_name].min():  | 
86 | 92 |             # Set the context by including actuals up-to the origin time  | 
87 | 93 |             test_context_expand_wind = (X[time_column_name] < origin_time)  | 
88 |  | -            context_expand_wind = (  | 
89 |  | -                X_test_expand[time_column_name] < origin_time)  | 
 | 94 | +            context_expand_wind = (X_test_expand[time_column_name] < origin_time)  | 
90 | 95 |             y_query_expand[context_expand_wind] = y[test_context_expand_wind]  | 
91 | 96 | 
 
  | 
92 | 97 |         # Print some debug info  | 
@@ -117,8 +122,7 @@ def do_rolling_forecast_with_lookback(fitted_model, X_test, y_test,  | 
117 | 122 |         # Align forecast with test set for dates within  | 
118 | 123 |         # the current rolling window  | 
119 | 124 |         trans_tindex = X_trans.index.get_level_values(time_column_name)  | 
120 |  | -        trans_roll_wind = (trans_tindex >= origin_time) & (  | 
121 |  | -            trans_tindex < horizon_time)  | 
 | 125 | +        trans_roll_wind = (trans_tindex >= origin_time) & (trans_tindex < horizon_time)  | 
122 | 126 |         test_roll_wind = expand_wind & (X[time_column_name] >= origin_time)  | 
123 | 127 |         df_list.append(align_outputs(  | 
124 | 128 |             y_fcst[trans_roll_wind], X_trans[trans_roll_wind],  | 
@@ -157,8 +161,7 @@ def do_rolling_forecast(fitted_model, X_test, y_test, max_horizon, freq='D'):  | 
157 | 161 |         if origin_time != X_test[time_column_name].min():  | 
158 | 162 |             # Set the context by including actuals up-to the origin time  | 
159 | 163 |             test_context_expand_wind = (X_test[time_column_name] < origin_time)  | 
160 |  | -            context_expand_wind = (  | 
161 |  | -                X_test_expand[time_column_name] < origin_time)  | 
 | 164 | +            context_expand_wind = (X_test_expand[time_column_name] < origin_time)  | 
162 | 165 |             y_query_expand[context_expand_wind] = y_test[  | 
163 | 166 |                 test_context_expand_wind]  | 
164 | 167 | 
 
  | 
@@ -188,10 +191,8 @@ def do_rolling_forecast(fitted_model, X_test, y_test, max_horizon, freq='D'):  | 
188 | 191 |         # Align forecast with test set for dates within the  | 
189 | 192 |         # current rolling window  | 
190 | 193 |         trans_tindex = X_trans.index.get_level_values(time_column_name)  | 
191 |  | -        trans_roll_wind = (trans_tindex >= origin_time) & (  | 
192 |  | -            trans_tindex < horizon_time)  | 
193 |  | -        test_roll_wind = expand_wind & (  | 
194 |  | -            X_test[time_column_name] >= origin_time)  | 
 | 194 | +        trans_roll_wind = (trans_tindex >= origin_time) & (trans_tindex < horizon_time)  | 
 | 195 | +        test_roll_wind = expand_wind & (X_test[time_column_name] >= origin_time)  | 
195 | 196 |         df_list.append(align_outputs(y_fcst[trans_roll_wind],  | 
196 | 197 |                                      X_trans[trans_roll_wind],  | 
197 | 198 |                                      X_test[test_roll_wind],  | 
@@ -244,15 +245,13 @@ def map_location_cuda(storage, loc):  | 
244 | 245 |     '--model_path', type=str, dest='model_path',  | 
245 | 246 |     default='model.pkl', help='Filename of model to be loaded')  | 
246 | 247 | 
 
  | 
247 |  | - | 
248 | 248 | args = parser.parse_args()  | 
249 | 249 | max_horizon = args.max_horizon  | 
250 | 250 | target_column_name = args.target_column_name  | 
251 | 251 | time_column_name = args.time_column_name  | 
252 | 252 | freq = args.freq  | 
253 | 253 | model_path = args.model_path  | 
254 | 254 | 
 
  | 
255 |  | - | 
256 | 255 | print('args passed are: ')  | 
257 | 256 | print(max_horizon)  | 
258 | 257 | print(target_column_name)  | 
@@ -280,13 +279,19 @@ def map_location_cuda(storage, loc):  | 
280 | 279 | y_lookback_df = lookback_dataset.with_timestamp_columns(  | 
281 | 280 |     None).keep_columns(columns=[target_column_name])  | 
282 | 281 | 
 
  | 
283 |  | -# Load the trained model with torch.  | 
284 |  | -if torch.cuda.is_available():  | 
285 |  | -    map_location = map_location_cuda  | 
 | 282 | +_, ext = os.path.splitext(model_path)  | 
 | 283 | +if ext == '.pt':  | 
 | 284 | +    # Load the fc-tcn torch model.  | 
 | 285 | +    assert _torch_present  | 
 | 286 | +    if torch.cuda.is_available():  | 
 | 287 | +        map_location = map_location_cuda  | 
 | 288 | +    else:  | 
 | 289 | +        map_location = 'cpu'  | 
 | 290 | +    with open(model_path, 'rb') as fh:  | 
 | 291 | +        fitted_model = torch.load(fh, map_location=map_location)  | 
286 | 292 | else:  | 
287 |  | -    map_location = 'cpu'  | 
288 |  | -with open(model_path, 'rb') as fh:  | 
289 |  | -    fitted_model = torch.load(fh, map_location=map_location)  | 
 | 293 | +    # Load the sklearn pipeline.  | 
 | 294 | +    fitted_model = joblib.load(model_path)  | 
290 | 295 | 
 
  | 
291 | 296 | if hasattr(fitted_model, 'get_lookback'):  | 
292 | 297 |     lookback = fitted_model.get_lookback()  | 
 | 
0 commit comments