Skip to content

Commit 8a9a83d

Browse files
committed
Fix the dataset loading when in the code directory.
1 parent c475870 commit 8a9a83d

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

code/logistic_sgd.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,12 @@ def load_data(dataset):
159159
data_dir, data_file = os.path.split(dataset)
160160
if data_dir == "" and not os.path.isfile(dataset):
161161
# Check if dataset is in the data directory.
162-
new_path = os.path.join(os.path.split(os.path.split(__file__)[0])[0],
163-
"data", dataset)
162+
new_path = os.path.split(__file__)
163+
if new_path[0] == "":
164+
new_path = os.path.join('..', "data", dataset)
165+
else:
166+
new_path = os.path.join(os.path.split(os.path.split(__file__)[0])[0],
167+
"data", dataset)
164168
if os.path.isfile(new_path) or data_file == 'mnist.pkl.gz':
165169
dataset = new_path
166170

0 commit comments

Comments
 (0)