TensorFlow implementation of [**ENet: A Deep Neural Network Architecture for Real-Time Semantic Segmentation**](https://arxiv.org/pdf/1606.02147.pdf).
This model was tested on the CamVid dataset with street scenes taken from Cambridge, UK. For more information, please visit: http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/
### Visualizations
Note that the gifs may be out of sync if the network doesn't load them together.
### Contents
**enet.py**: The ENet model definition, including the argument scope.

**train_enet.py**: The file for training. Includes saving of images for visualization and tunable hyperparameters.

**test_enet.py**: The file for evaluating on the test dataset. Includes an option to visualize images as well.

**preprocessing.py**: Performs image resizing only, in case anyone wants to use a smaller image size due to memory issues or to adapt the pipeline to other datasets.

**predict_segmentation.py**: Obtains the segmentation output for visualization purposes. You can create your own gif with these outputs.

**get_class_weights.py**: The file to obtain either the median frequency balancing class weights or the custom ENet function class weights.

### Training Arguments
**Note:** To use the checkpoint model, please set the argument `--stage_two_repeat=3` in both `train_enet.py` and `test_enet.py` as the checkpoint was trained on a slightly deeper version of ENet using 3 stage_two bottleneck series instead of the default 2.
### Evaluation Arguments
### Important Notes
1. As the Max Unpooling layer is not officially available from TensorFlow, a manual implementation was used to build the decoder portion of the network. This was based on the implementation suggested in this [TensorFlow github issue](https://github.com/tensorflow/tensorflow/issues/2169).
2. Batch normalization and 2D Spatial Dropout are still retained during testing for good performance.
3. Class weights are used to tackle the problem of imbalanced classes, as certain classes appear more dominantly than others. Most notably, the background class has a weight of 0.0, in order to not reward the model for predicting background.
4. On the labels and colouring scheme: the dataset consists of only 12 labels, with the road-marking class merged into the road class. The last class is the unlabelled class.
5. No preprocessing is done to the images for ENet (see the author's clarifications in the references below).
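
The manual max-unpooling trick mentioned in note 1 can be sketched outside TensorFlow: the encoder's pooling op records the flat index of each maximum (as `tf.nn.max_pool_with_argmax` does), and the decoder scatters the pooled values back to those indices, zero-filling everything else. A minimal NumPy illustration of the idea (the function names here are illustrative, not from this repo):

```python
import numpy as np

def max_pool_with_argmax(x, k=2):
    """2x2 max pool over a 2-D array, also returning the flat index of
    each maximum (mimicking what tf.nn.max_pool_with_argmax provides)."""
    h, w = x.shape
    pooled = np.zeros((h // k, w // k), dtype=x.dtype)
    argmax = np.zeros((h // k, w // k), dtype=np.int64)
    for i in range(h // k):
        for j in range(w // k):
            window = x[i*k:(i+1)*k, j*k:(j+1)*k]
            r, c = np.unravel_index(window.argmax(), window.shape)
            pooled[i, j] = window[r, c]
            argmax[i, j] = (i*k + r) * w + (j*k + c)
    return pooled, argmax

def max_unpool(pooled, argmax, out_shape):
    """Decoder side: scatter pooled values back to their recorded
    positions, leaving zeros everywhere else."""
    flat = np.zeros(int(np.prod(out_shape)), dtype=pooled.dtype)
    flat[argmax.ravel()] = pooled.ravel()
    return flat.reshape(out_shape)

x = np.array([[1., 3., 2., 0.],
              [4., 2., 1., 1.],
              [0., 1., 5., 6.],
              [2., 0., 0., 3.]])
p, idx = max_pool_with_argmax(x)
u = max_unpool(p, idx, x.shape)
```

Each maximum returns to its original location in the upsampled map, which is why this unpooling preserves spatial detail better than plain interpolation.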
### Implementation and Architectural Changes
1. By default, skip connections are added to connect the corresponding encoder and decoder portions for better performance.
3. Fused batch normalization is used over standard batch normalization for faster computations. See [TensorFlow's best practices](https://www.tensorflow.org/performance/performance_guide).
4. To obtain the class weights for computing the weighted loss, Median Frequency Balancing (MFB) is used by default instead of the custom ENet class weighting function. This is due to an observation that MFB gives a slightly better performance than the custom function, at least on my machine. However, the option of using the ENet custom class weights is still possible.
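
For reference, the two weighting schemes differ only in their formula: MFB computes `w_c = median(freq) / freq_c` over the per-class pixel frequencies, while the ENet paper's custom scheme computes `w_c = 1 / ln(c + p_c)` with `c = 1.02`. A NumPy sketch with made-up pixel counts (the real counts come from scanning the CamVid annotation images):

```python
import numpy as np

# Hypothetical per-class pixel counts; real values come from counting
# label pixels across the dataset's annotation images.
pixel_counts = np.array([5_000_000, 2_000_000, 400_000, 100_000], dtype=np.float64)
freq = pixel_counts / pixel_counts.sum()

# Median frequency balancing: w_c = median(freq) / freq_c
mfb_weights = np.median(freq) / freq

# ENet paper's custom weighting: w_c = 1 / ln(c + p_c), with c = 1.02
enet_weights = 1.0 / np.log(1.02 + freq)

# Zero the last (unlabelled/background) class so the model is not
# rewarded for predicting it, as described in the notes above.
mfb_weights[-1] = 0.0
enet_weights[-1] = 0.0
```

Both schemes give rarer classes larger weights; MFB's weights span a narrower range, which may explain the slightly different behaviour observed in training.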
### References
1. [ENet: A Deep Neural Network Architecture for Real-Time Semantic Segmentation](https://arxiv.org/pdf/1606.02147.pdf)
4. [Clarifications from the ENet author](https://github.com/e-lab/ENet-training/issues/56)
5. [Original Torch implementation of ENet](https://github.com/e-lab/ENet-training)
6. [ResNet paper for clarification on residual bottlenecks](https://arxiv.org/pdf/1512.03385.pdf)