| 
10 | 10 | from azureml.automl.runtime.shared.score import scoring, constants  | 
11 | 11 | from azureml.core import Run  | 
12 | 12 | 
 
  | 
 | 13 | +import torch  | 
 | 14 | + | 
13 | 15 | 
 
  | 
14 | 16 | def align_outputs(y_predicted, X_trans, X_test, y_test,  | 
15 | 17 |                   predicted_column_name='predicted',  | 
@@ -221,6 +223,10 @@ def MAPE(actual, pred):  | 
221 | 223 |     return np.mean(APE(actual_safe, pred_safe))  | 
222 | 224 | 
 
  | 
223 | 225 | 
 
  | 
 | 226 | +def map_location_cuda(storage, loc):  | 
 | 227 | +    return storage.cuda()  | 
 | 228 | + | 
 | 229 | + | 
224 | 230 | parser = argparse.ArgumentParser()  | 
225 | 231 | parser.add_argument(  | 
226 | 232 |     '--max_horizon', type=int, dest='max_horizon',  | 
@@ -274,8 +280,13 @@ def MAPE(actual, pred):  | 
274 | 280 | y_lookback_df = lookback_dataset.with_timestamp_columns(  | 
275 | 281 |     None).keep_columns(columns=[target_column_name])  | 
276 | 282 | 
 
  | 
277 |  | -fitted_model = joblib.load(model_path)  | 
278 |  | - | 
 | 283 | +# Load the trained model with torch.  | 
 | 284 | +if torch.cuda.is_available():  | 
 | 285 | +    map_location = map_location_cuda  | 
 | 286 | +else:  | 
 | 287 | +    map_location = 'cpu'  | 
 | 288 | +with open(model_path, 'rb') as fh:  | 
 | 289 | +    fitted_model = torch.load(fh, map_location=map_location)  | 
279 | 290 | 
 
  | 
280 | 291 | if hasattr(fitted_model, 'get_lookback'):  | 
281 | 292 |     lookback = fitted_model.get_lookback()  | 
 | 
0 commit comments