|
10 | 10 | from azureml.automl.runtime.shared.score import scoring, constants |
11 | 11 | from azureml.core import Run |
12 | 12 |
|
| 13 | +import torch |
| 14 | + |
13 | 15 |
|
14 | 16 | def align_outputs(y_predicted, X_trans, X_test, y_test, |
15 | 17 | predicted_column_name='predicted', |
@@ -221,6 +223,10 @@ def MAPE(actual, pred): |
221 | 223 | return np.mean(APE(actual_safe, pred_safe)) |
222 | 224 |
|
223 | 225 |
|
| 226 | +def map_location_cuda(storage, loc): |
| 227 | + return storage.cuda() |
| 228 | + |
| 229 | + |
224 | 230 | parser = argparse.ArgumentParser() |
225 | 231 | parser.add_argument( |
226 | 232 | '--max_horizon', type=int, dest='max_horizon', |
@@ -274,8 +280,13 @@ def MAPE(actual, pred): |
274 | 280 | y_lookback_df = lookback_dataset.with_timestamp_columns( |
275 | 281 | None).keep_columns(columns=[target_column_name]) |
276 | 282 |
|
277 | | -fitted_model = joblib.load(model_path) |
278 | | - |
| 283 | +# Load the trained model with torch. |
| 284 | +if torch.cuda.is_available(): |
| 285 | + map_location = map_location_cuda |
| 286 | +else: |
| 287 | + map_location = 'cpu' |
| 288 | +with open(model_path, 'rb') as fh: |
| 289 | + fitted_model = torch.load(fh, map_location=map_location) |
279 | 290 |
|
280 | 291 | if hasattr(fitted_model, 'get_lookback'): |
281 | 292 | lookback = fitted_model.get_lookback() |
|
0 commit comments