Skip to content

Commit e0d53ba

Browse files
lamblinnotoraptor
authored andcommitted
Changes to help unet run
1 parent 2a9af6f commit e0d53ba

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

code/unet/Unet_lasagne_recipes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# end-snippet-1
1313

1414
# start-snippet-downsampling
15-
def build_UNet(n_input_channels=1, BATCH_SIZE=None, num_output_classes=2, pad='same', nonlinearity=lasagne.nonlinearities.elu, input_dim=(128, 128), base_n_filters=64, do_dropout=False):
15+
def build_UNet(n_input_channels=1, BATCH_SIZE=None, num_output_classes=2, pad='same', nonlinearity=lasagne.nonlinearities.elu, input_dim=(None, None), base_n_filters=64, do_dropout=False):
1616
net = OrderedDict()
1717
net['input'] = InputLayer((BATCH_SIZE, n_input_channels, input_dim[0], input_dim[1]))
1818

code/unet/train_unet.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def train(dataset, learn_step=0.005,
207207

208208
net = build_UNet(n_input_channels= nb_in_channels,# BATCH_SIZE = batch_size,
209209
num_output_classes = n_classes, base_n_filters = 64, do_dropout=False,
210-
input_dim =input_dim)
210+
input_dim = (None, None))
211211

212212
output_layer = net["output_flattened"]
213213
#
@@ -383,7 +383,8 @@ def main():
383383
help='Max patience')
384384
parser.add_argument('-batch_size',
385385
type=int,
386-
default=[10, 1, 1],
386+
nargs='+',
387+
default=[5, 1, 1],
387388
help='Batch size [train, val, test]')
388389
parser.add_argument('-data_augmentation',
389390
type=dict,
@@ -394,7 +395,8 @@ def main():
394395
'fill_mode':'reflect',
395396
'spline_warp':True,
396397
'warp_sigma':10,
397-
'warp_grid_size':3},
398+
'warp_grid_size':3,
399+
'crop_size': (224, 224)},
398400
help='use data augmentation')
399401
parser.add_argument('-early_stop_class',
400402
type=int,

0 commit comments

Comments
 (0)