@@ -175,6 +175,7 @@ def __init__(self,
175175 divide_by_per_img_std = False , # img stats
176176 raise_IOErrors = False ,
177177 rng = None ,
178+ preload = False ,
178179 ** kwargs ):
179180
180181 if len (kwargs ):
@@ -299,6 +300,7 @@ def __init__(self,
299300 self .divide_by_per_img_std = divide_by_per_img_std
300301 self .raise_IOErrors = raise_IOErrors
301302 self .rng = rng if rng is not None else RandomState (0xbeef )
303+ self .preload = preload
302304
303305 self .set_has_GT = getattr (self , 'set_has_GT' , True )
304306 self .mean = getattr (self , 'mean' , [])
@@ -324,6 +326,20 @@ def __init__(self,
324326 raise RuntimeError ('The name list cannot be empty' )
325327 self ._fill_names_batches (shuffle_at_each_epoch )
326328
329+ # Cache for already loaded data
330+ if self .preload :
331+ self .image_raw = self ._preload_data (
332+ self .image_path_raw , dtype = 'floatX' , expand = True )
333+ self .image_smooth = self ._preload_data (
334+ self .image_path_smooth , dtype = 'floatX' , expand = True )
335+ self .mask = self ._preload_data (self .mask_path , dtype = 'int32' )
336+ self .regions = self ._preload_data (self .regions_path , dtype = 'int32' )
337+ else :
338+ self .image_raw = None
339+ self .image_smooth = None
340+ self .mask = None
341+ self .regions = None
342+
327343 if self .use_threads :
328344 # Initialize the queues
329345 self .names_queue = Queue .Queue (maxsize = self .queues_size )
@@ -344,9 +360,28 @@ def __init__(self,
344360 # Give time to the data fetcher to die, in case of errors
345361 # sleep(1)
346362
347-
348363 # super(ThreadedDataset_1D, self).__init__(*args, **kwargs)
349364
def _preload_data(self, path, dtype, expand=False):
    """Read a whitespace-separated numeric text file fully into memory.

    Every data line of the file becomes one row of the returned array, so
    row ``k`` of the result corresponds to line ``k`` of the file.

    Parameters
    ----------
    path : str
        Path of the text file to read.
    dtype : str
        ``'floatX'`` parses each value with ``float`` and stores it with
        the module-level ``floatX`` dtype; ``'int32'`` parses each value
        with ``int`` and stores it as int32.
    expand : bool
        If True, append a singleton axis so that an array of shape
        (rows, cols) becomes (rows, cols, 1).

    Returns
    -------
    numpy.ndarray
        The stacked rows, with the extra trailing axis when ``expand`` is
        set.

    Raises
    ------
    ValueError
        If ``dtype`` is not supported, if the file holds no data, if a
        blank line is followed by more data, if a value cannot be parsed,
        or if the rows do not all have the same length.
    """
    if dtype == 'floatX':
        py_type = float
        dtype = floatX
    elif dtype == 'int32':
        py_type = int
    else:
        raise ValueError('dtype not supported', dtype)

    rows = []
    blank_seen = False
    with open(path) as fp:
        for line in fp:
            # str.split() with no argument collapses runs of whitespace and
            # drops the trailing newline, so trailing/double spaces no
            # longer produce unparsable empty tokens.
            tokens = line.split()
            if not tokens:
                blank_seen = True
                continue
            if blank_seen:
                # Skipping a blank line in the middle would silently shift
                # row indices with respect to the line numbers of the file.
                raise ValueError('blank line inside data', path)
            rows.append(np.array([py_type(el) for el in tokens],
                                 dtype=dtype))

    if not rows:
        raise ValueError('no data found in file', path)

    ret = np.vstack(rows)
    if expand:
        # b,0 to b,0,c
        ret = np.expand_dims(ret, axis=2)
    return ret
384+
350385 def fetch_from_dataset (self , batch_to_load ):
351386 """
352387 Return *batches* of 1D data.
@@ -367,35 +402,41 @@ def fetch_from_dataset(self, batch_to_load):
367402 ret ['indices' ] = []#np.sort(batch_to_load)
368403
369404 if self .smooth_raw_both == 'raw' or self .smooth_raw_both == 'both' :
370- raw = []
371- with open (self .image_path_raw ) as fp :
372- for i , line in enumerate (fp ):
373- if i in batch_to_load :
374- line = re .split (' ' , line )
375- line = np .array ([float (el ) for el in line ])
376- line = line .astype (floatX )
377- raw .append (line )
378- if len (raw ) == len (batch_to_load ):
379- break
380- raw = np .vstack (raw )
381- # b,0 to b,0,c
382- raw = np .expand_dims (raw , axis = 2 )
405+ if self .preload :
406+ raw = self .image_raw [batch_to_load ]
407+ else :
408+ raw = []
409+ with open (self .image_path_raw ) as fp :
410+ for i , line in enumerate (fp ):
411+ if i in batch_to_load :
412+ line = re .split (' ' , line )
413+ line = np .array ([float (el ) for el in line ])
414+ line = line .astype (floatX )
415+ raw .append (line )
416+ if len (raw ) == len (batch_to_load ):
417+ break
418+ raw = np .vstack (raw )
419+ # b,0 to b,0,c
420+ raw = np .expand_dims (raw , axis = 2 )
383421
384422 if self .smooth_raw_both == 'smooth' or self .smooth_raw_both == 'both' :
385- smooth = []
386- with open (self .image_path_smooth ) as fp :
387- for i , line in enumerate (fp ):
388- if i in batch_to_load :
389- line = re .split (' ' , line )
390- line = np .array ([float (el ) for el in line ])
391- line = line .astype (floatX )
392- smooth .append (line )
393- if len (smooth ) == len (batch_to_load ):
394- break
395-
396- smooth = np .vstack (smooth )
397- # b,0 to b,0,c
398- smooth = np .expand_dims (smooth , axis = 2 )
423+ if self .preload :
424+ smooth = self .image_smooth [batch_to_load ]
425+ else :
426+ smooth = []
427+ with open (self .image_path_smooth ) as fp :
428+ for i , line in enumerate (fp ):
429+ if i in batch_to_load :
430+ line = re .split (' ' , line )
431+ line = np .array ([float (el ) for el in line ])
432+ line = line .astype (floatX )
433+ smooth .append (line )
434+ if len (smooth ) == len (batch_to_load ):
435+ break
436+
437+ smooth = np .vstack (smooth )
438+ # b,0 to b,0,c
439+ smooth = np .expand_dims (smooth , axis = 2 )
399440
400441 if self .smooth_raw_both == 'raw' :
401442 ret ['data' ] = raw
@@ -409,31 +450,34 @@ def fetch_from_dataset(self, batch_to_load):
409450 # Load mask
410451 ret ['labels' ] = []
411452 if self .task == 'segmentation' :
412- with open (self .mask_path ) as fp :
413- for i , line in enumerate (fp ):
414- if i in batch_to_load :
415- line = re .split (' ' , line )
416- line = np .array ([int (el ) for el in line ])
417- line = line .astype ('int32' )
418- ret ['labels' ].append (line )
419- if len (ret ['labels' ]) == len (batch_to_load ):
420- break
421- ret ['labels' ] = np .vstack (ret ['labels' ])
422-
453+ if self .preload :
454+ ret ['labels' ] = self .mask [batch_to_load ]
455+ else :
456+ with open (self .mask_path ) as fp :
457+ for i , line in enumerate (fp ):
458+ if i in batch_to_load :
459+ line = re .split (' ' , line )
460+ line = np .array ([int (el ) for el in line ])
461+ line = line .astype ('int32' )
462+ ret ['labels' ].append (line )
463+ if len (ret ['labels' ]) == len (batch_to_load ):
464+ break
465+ ret ['labels' ] = np .vstack (ret ['labels' ])
423466
424467 elif self .task == 'classification' :
425- with open (self .mask_path ) as fp :
426- for i , line in enumerate (fp ):
427- if i in batch_to_load :
428- line = re .split (' ' , line )
429- line = np .array ([int (el ) for el in line ])
430- line = line .astype ('int32' )
431- ret ['labels' ].append (line )
432- if len (ret ['labels' ]) == len (batch_to_load ):
433- break
434- ret ['labels' ] = np .vstack (ret ['labels' ])
435-
436-
468+ if self .preload :
469+ ret ['labels' ] = self .mask [batch_to_load ]
470+ else :
471+ with open (self .mask_path ) as fp :
472+ for i , line in enumerate (fp ):
473+ if i in batch_to_load :
474+ line = re .split (' ' , line )
475+ line = np .array ([int (el ) for el in line ])
476+ line = line .astype ('int32' )
477+ ret ['labels' ].append (line )
478+ if len (ret ['labels' ]) == len (batch_to_load ):
479+ break
480+ ret ['labels' ] = np .vstack (ret ['labels' ])
437481
438482
439483 ret ['filenames' ] = batch_to_load
0 commit comments