@@ -278,7 +278,7 @@ def str2bool(v):
278278 # Use beam search decoding for validation
279279 model .actor_net .decoder .decode_type = "beam_search"
280280
281- print ('\n ~Validating with beam search decoding ~\n ' )
281+ print ('\n ~Validating~\n ' )
282282
283283 example_input = []
284284 example_output = []
@@ -331,9 +331,18 @@ def str2bool(v):
331331
332332 torch .save (model , os .path .join (save_dir , 'epoch-{}.pt' .format (i )))
333333
334+ # If the task requires generating new data after each epoch, do that here!
334335 if COP == 'tsp' :
335- print ('Generating and loading new TSP training set' )
336336 training_dataset = tsp_task .TSPDataset (train = True , size = size ,
337337 num_samples = int (args ['train_size' ]))
338338 training_dataloader = DataLoader (training_dataset , batch_size = int (args ['batch_size' ]),
339339 shuffle = True , num_workers = 1 )
340+ if COP == 'sort' :
341+ train_fname , _ = sorting_task .create_dataset (
342+ int (args ['train_size' ]),
343+ int (args ['val_size' ]),
344+ data_dir ,
345+ data_len = size )
346+ training_dataset = sorting_task .SortingDataset (train_fname )
347+ training_dataloader = DataLoader (training_dataset , batch_size = int (args ['batch_size' ]),
348+ shuffle = True , num_workers = 1 )
0 commit comments