Skip to content

Commit 6c3abe2

Browse files
authored
Update train.py
1 parent 4627080 commit 6c3abe2

File tree

1 file changed

+11
-2
lines changed
  • 01.getting-started/04.train-on-remote-vm

1 file changed

+11
-2
lines changed

01.getting-started/04.train-on-remote-vm/train.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
# Licensed under the MIT license.
33

44
import os
5-
from sklearn.datasets import load_diabetes
5+
import argparse
6+
67
from sklearn.linear_model import Ridge
78
from sklearn.metrics import mean_squared_error
89
from sklearn.model_selection import train_test_split
@@ -12,8 +13,16 @@
1213
import numpy as np
1314

1415
os.makedirs('./outputs', exist_ok=True)
16+
parser = argparse.ArgumentParser()
17+
parser.add_argument('--data-folder', type=str,
18+
dest='data_folder', help='data folder')
19+
args = parser.parse_args()
20+
21+
print('Data folder is at:', args.data_folder)
22+
print('List all files: ', os.listdir(args.data_folder))
1523

16-
X, y = load_diabetes(return_X_y=True)
24+
X = np.load(os.path.join(args.data_folder, 'features.npy'))
25+
y = np.load(os.path.join(args.data_folder, 'labels.npy'))
1726

1827
run = Run.get_submitted_run()
1928

0 commit comments

Comments
 (0)