Skip to content

Commit ffa5213

Browse files
ppwwyyxxfacebook-github-bot
authored andcommitted
update docs for poolers
Summary: Pull Request resolved: fairinternal/detectron2#322 Differential Revision: D18395920 Pulled By: ppwwyyxx fbshipit-source-id: 6c03e9474f289fcc293de5bfcd373f8cbc393914
1 parent 366f8ba commit ffa5213

File tree

5 files changed

+42
-11
lines changed

5 files changed

+42
-11
lines changed

detectron2/data/dataset_mapper.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
22
import copy
33
import logging
4-
54
import numpy as np
65
import torch
76
from fvcore.common.file_io import PathManager

detectron2/modeling/poolers.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ def assign_boxes_to_levels(box_lists, min_level, max_level, canonical_box_size,
3838
level_assignments = torch.floor(
3939
canonical_level + torch.log2(box_sizes / canonical_box_size + eps)
4040
)
41+
# clamp level to (min, max), in case the box size is too large or too small
42+
# for the available feature maps
4143
level_assignments = torch.clamp(level_assignments, min=min_level, max=max_level)
4244
return level_assignments.to(torch.int64) - min_level
4345

@@ -100,15 +102,25 @@ def __init__(
100102
e.g., 14 x 14. If tuple or list is given, the length must be 2.
101103
scales (list[float]): The scale for each low-level pooling op relative to
102104
the input image. For a feature map with stride s relative to the input
103-
image, scale is defined as a 1 / s.
105+
image, scale is defined as a 1 / s. The stride must be power of 2.
106+
When there are multiple scales, they must form a pyramid, i.e. they must be
107+
a monotically decreasing geometric sequence with a factor of 1/2.
104108
sampling_ratio (int): The `sampling_ratio` parameter for the ROIAlign op.
105109
pooler_type (string): Name of the type of pooling operation that should be applied.
106110
For instance, "ROIPool" or "ROIAlignV2".
107111
canonical_box_size (int): A canonical box size in pixels (sqrt(box area)). The default
108112
is heuristically defined as 224 pixels in the FPN paper (based on ImageNet
109113
pre-training).
110-
canonical_level (int): The feature map level index on which a canonically-sized box
111-
should be placed. The default is defined as level 4 in the FPN paper.
114+
canonical_level (int): The feature map level index from which a canonically-sized box
115+
should be placed. The default is defined as level 4 (stride=16) in the FPN paper,
116+
i.e., a box of size 224x224 will be placed on the feature with stride=16.
117+
The box placement for all boxes will be determined from their sizes w.r.t
118+
canonical_box_size. For example, a box whose area is 4x that of a canonical box
119+
should be used to pool features from feature level ``canonical_level+1``.
120+
121+
Note that the actual input feature maps given to this module may not have
122+
sufficiently many levels for the input boxes. If the boxes are too large or too
123+
small for the input feature maps, the closest level will be used.
112124
"""
113125
super().__init__()
114126

@@ -148,22 +160,32 @@ def __init__(
148160
# assumption that stride is a power of 2.
149161
min_level = -math.log2(scales[0])
150162
max_level = -math.log2(scales[-1])
151-
assert math.isclose(min_level, int(min_level)) and math.isclose(max_level, int(max_level))
163+
assert math.isclose(min_level, int(min_level)) and math.isclose(
164+
max_level, int(max_level)
165+
), "Featuremap stride is not power of 2!"
152166
self.min_level = int(min_level)
153167
self.max_level = int(max_level)
168+
assert (
169+
len(scales) == self.max_level - self.min_level + 1
170+
), "[ROIPooler] Sizes of input featuremaps do not form a pyramid!"
154171
assert 0 < self.min_level and self.min_level <= self.max_level
155-
assert self.min_level <= canonical_level and canonical_level <= self.max_level
172+
if len(scales) > 1:
173+
# When there is only one feature map, canonical_level is redundant and we should not
174+
# require it to be a sensible value. Therefore we skip this assertion
175+
assert self.min_level <= canonical_level and canonical_level <= self.max_level
156176
self.canonical_level = canonical_level
157177
assert canonical_box_size > 0
158178
self.canonical_box_size = canonical_box_size
159179

160180
def forward(self, x, box_lists):
161181
"""
162182
Args:
163-
x (list[Tensor]): A list of feature maps with scales matching those used to
164-
construct this module.
183+
x (list[Tensor]): A list of feature maps of NCHW shape, with scales matching those
184+
used to construct this module.
165185
box_lists (list[Boxes] | list[RotatedBoxes]):
166186
A list of N Boxes or N RotatedBoxes, where N is the number of images in the batch.
187+
The box coordinates are defined on the original image and
188+
will be scaled by the `scales` argument of :class:`ROIPooler`.
167189
168190
Returns:
169191
Tensor:
@@ -172,6 +194,9 @@ def forward(self, x, box_lists):
172194
"""
173195
num_level_assignments = len(self.level_poolers)
174196

197+
assert isinstance(x, list) and isinstance(
198+
box_lists, list
199+
), "Arguments to pooler must be lists"
175200
assert (
176201
len(x) == num_level_assignments
177202
), "unequal value, num_level_assignments={}, but x is list of {} Tensors".format(

detectron2/structures/instances.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,9 @@ def __len__(self) -> int:
131131
return len(v)
132132
raise NotImplementedError("Empty Instances does not support __len__!")
133133

134+
def __iter__(self):
135+
raise NotImplementedError("`Instances` object is not iterable!")
136+
134137
@staticmethod
135138
def cat(instance_lists: List["Instances"]) -> "Instances":
136139
"""

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def get_extensions():
7979
"matplotlib",
8080
"tqdm>4.29.0",
8181
"tensorboard",
82-
"fvcore @ git+https://github.com/facebookresearch/fvcore.git"
82+
"fvcore @ git+https://github.com/facebookresearch/fvcore.git",
8383
],
8484
extras_require={"all": ["shapely", "psutil"]},
8585
ext_modules=get_extensions(),

tests/test_rpn.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ def test_rpn(self):
3232
gt_instances = Instances(image_shape)
3333
gt_instances.gt_boxes = Boxes(gt_boxes)
3434
with EventStorage(): # capture events in a new storage to discard them
35-
proposals, proposal_losses = proposal_generator(images, features, gt_instances)
35+
proposals, proposal_losses = proposal_generator(
36+
images, features, [gt_instances[0], gt_instances[1]]
37+
)
3638

3739
expected_losses = {
3840
"loss_rpn_cls": torch.tensor(0.0804563984),
@@ -92,7 +94,9 @@ def test_rrpn(self):
9294
gt_instances = Instances(image_shape)
9395
gt_instances.gt_boxes = RotatedBoxes(gt_boxes)
9496
with EventStorage(): # capture events in a new storage to discard them
95-
proposals, proposal_losses = proposal_generator(images, features, gt_instances)
97+
proposals, proposal_losses = proposal_generator(
98+
images, features, [gt_instances[0], gt_instances[1]]
99+
)
96100

97101
expected_losses = {
98102
"loss_rpn_cls": torch.tensor(0.0432923734),

0 commit comments

Comments
 (0)