Skip to content

Commit d049be1

Browse files
committed
Updates to fix sort performance
1 parent b0a58f8 commit d049be1

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

trainer.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def str2bool(v):
278278
# Use beam search decoding for validation
279279
model.actor_net.decoder.decode_type = "beam_search"
280280

281-
print('\n~Validating with beam search decoding~\n')
281+
print('\n~Validating~\n')
282282

283283
example_input = []
284284
example_output = []
@@ -331,9 +331,18 @@ def str2bool(v):
331331

332332
torch.save(model, os.path.join(save_dir, 'epoch-{}.pt'.format(i)))
333333

334+
# If the task requires generating new data after each epoch, do that here!
334335
if COP == 'tsp':
335-
print('Generating and loading new TSP training set')
336336
training_dataset = tsp_task.TSPDataset(train=True, size=size,
337337
num_samples=int(args['train_size']))
338338
training_dataloader = DataLoader(training_dataset, batch_size=int(args['batch_size']),
339339
shuffle=True, num_workers=1)
340+
if COP == 'sort':
341+
train_fname, _ = sorting_task.create_dataset(
342+
int(args['train_size']),
343+
int(args['val_size']),
344+
data_dir,
345+
data_len=size)
346+
training_dataset = sorting_task.SortingDataset(train_fname)
347+
training_dataloader = DataLoader(training_dataset, batch_size=int(args['batch_size']),
348+
shuffle=True, num_workers=1)

0 commit comments

Comments
 (0)