# TensorFlow-ENet
TensorFlow implementation of [**ENet: A Deep Neural Network Architecture for Real-Time Semantic Segmentation**](https://arxiv.org/pdf/1606.02147.pdf).
This model was tested on the CamVid dataset with street scenes taken from Cambridge, UK. For more information on this dataset, please visit: http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/
## Visualizations
Note that the gifs may play out of sync if they don't all load at the same time. You can refresh the page to see them in sync.
### Test Dataset Output

### TensorBoard Visualizations
Execute `tensorboard --logdir=log` from your root directory to monitor your training and watch your segmentation output form against the ground truth and the original image as you train your model.
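
For context, these TensorBoard images come from TF 1.x image summaries. Below is a minimal, illustrative sketch of how such summaries are typically written; the tensor names and the 360x480 CamVid shape are assumptions, not necessarily this repository's exact code.

```python
import tensorflow as tf

# Illustrative sketch (assumed names and shapes): log the input image, the
# ground truth annotation, and the predicted segmentation as image summaries
# so all three appear under TensorBoard's Images tab during training.
images = tf.placeholder(tf.float32, [None, 360, 480, 3], name='images')
annotations = tf.placeholder(tf.float32, [None, 360, 480, 1], name='annotations')
predictions = tf.placeholder(tf.float32, [None, 360, 480, 1], name='predictions')

tf.summary.image('input/image', images, max_outputs=1)
tf.summary.image('input/ground_truth', annotations, max_outputs=1)
tf.summary.image('output/segmentation', predictions, max_outputs=1)

summary_op = tf.summary.merge_all()
writer = tf.summary.FileWriter('./log')  # the same directory passed to --logdir
```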
## Contents
#### Code
- **enet.py**: The ENet model definition, including the argument scope.
- **train_enet.py**: The file for training. Includes saving of images for visualization and tunable hyperparameters.
- **predict_segmentation.py**: Obtains the segmentation output for visualization purposes. You can create your own gif with these outputs.
- **get_class_weights.py**: The file to obtain either the median frequency balancing class weights or the custom ENet function class weights (a sketch of both formulas appears after this list).
- **train.sh**: Example training script to train the different variations of the model.
- **test.sh**: Example testing script to test the different variants you trained.
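
For reference, here is a minimal NumPy sketch of the two weighting schemes mentioned above, written directly from the published formulas rather than taken from `get_class_weights.py` (the helper name is hypothetical): median frequency balancing (Eigen & Fergus), and the ENet paper's custom weighting `w_class = 1 / ln(c + p_class)` with `c = 1.02`.

```python
import numpy as np

def compute_class_weights(label_images, num_classes, c=1.02):
    """Sketch of both weighting schemes (hypothetical helper, not the repo's API).

    Median frequency balancing: weight = median(freq) / freq_class, where
    freq_class is the pixel count of a class divided by the total pixels of
    the images in which that class appears. ENet custom weighting:
    weight = 1 / ln(c + p_class), with p_class the overall class probability.
    """
    class_pixels = np.zeros(num_classes)    # pixels belonging to each class
    present_pixels = np.zeros(num_classes)  # total pixels of images containing the class
    total_pixels = 0
    for label in label_images:
        total_pixels += label.size
        for cls in range(num_classes):
            count = np.sum(label == cls)
            if count > 0:
                class_pixels[cls] += count
                present_pixels[cls] += label.size
    freq = class_pixels / np.maximum(present_pixels, 1)  # guard against absent classes
    median_freq_weights = np.median(freq) / np.maximum(freq, 1e-12)
    enet_weights = 1.0 / np.log(c + class_pixels / total_pixels)
    return median_freq_weights, enet_weights
```

For example, with the 12 CamVid classes you would call `compute_class_weights(labels, num_classes=12)` and feed the chosen weight vector into a weighted cross-entropy loss.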
#### Folders
- **dataset**: Contains 6 folders that hold the original train-val-test images and their corresponding ground truth annotations.

- **checkpoint**: The checkpoint directory that can be used for predicting the segmentation output. The model was trained using the default parameters mentioned in the paper, except that it uses median frequency balancing to obtain the class weights. **Note:** To use this checkpoint, please set the argument `--stage_two_repeat=3` in both `train_enet.py` and `test_enet.py`, as the checkpoint was trained on a slightly deeper version of ENet that uses 3 stage-two bottleneck series instead of the default 2.

- **visualizations**: Contains the gif files that were created from the output of `predict_segmentation.py`.
## Important Notes
4. On the labels and colouring scheme: The dataset consists of only 12 labels, with the road-marking class merged with the road class. The last class is the unlabelled class.
5. No preprocessing is done to the images for ENet (see the references below for clarifications with the author).
6. Once you've fine-tuned to get your best hyperparameters, there's an option to combine the training and validation datasets. However, if your training dataset is large enough, this won't make much of a difference.
## Implementation and Architectural Changes
1. By default, skip connections are added to connect the corresponding encoder and decoder portions for better performance.
2. The number of initial blocks and the depth of stage 2 residual bottlenecks are tunable hyperparameters. This allows you to build a deeper network if required, since ENet is rather lightweight.
3. Fused batch normalization is used over standard batch normalization for faster computations. See [TensorFlow's best practices](https://www.tensorflow.org/performance/performance_guide).
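
For illustration, enabling the fused implementation in TF 1.x is a one-flag change. The sketch below uses `tf.layers.batch_normalization` with an assumed feature-map shape; it is not this repository's exact call site.

```python
import tensorflow as tf

# Minimal sketch: fused=True runs batch normalization as a single fused kernel
# instead of a composition of smaller ops, which is typically faster on GPU.
x = tf.placeholder(tf.float32, [None, 360, 480, 16])  # assumed feature map shape
y = tf.layers.batch_normalization(x, fused=True, training=True)
```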
## References
5. [Original Torch implementation of ENet](https://github.com/e-lab/ENet-training)
6. [ResNet paper for clarification on residual bottlenecks](https://arxiv.org/pdf/1512.03385.pdf)

This implementation may not be entirely correct and may contain bugs. It would be great if the open source community could spot any bugs and raise a GitHub issue or submit a pull request to fix them!