@@ -95,7 +95,8 @@ def match_targets_to_proposals(self, proposal, target):
95
95
# GT in the image, and matched_idxs can be -2, which goes
96
96
# out of bounds
97
97
98
- if self .use_matched_pairs_only and (matched_idxs >= 0 ).sum () > self .minimal_matched_pairs :
98
+ if self .use_matched_pairs_only and \
99
+ (matched_idxs >= 0 ).sum () > self .minimal_matched_pairs :
99
100
# filter all matched_idxs < 0
100
101
proposal_pairs = proposal_pairs [matched_idxs >= 0 ]
101
102
matched_idxs = matched_idxs [matched_idxs >= 0 ]
@@ -137,17 +138,10 @@ def prepare_targets(self, proposals, targets):
137
138
138
139
return labels , proposal_pairs
139
140
140
- def subsample (self , proposals , targets ):
141
+ def _randomsample_train (self , proposals , targets ):
141
142
"""
142
- This method performs the positive/negative sampling, and return
143
- the sampled proposals.
144
- Note: this function keeps a state.
145
-
146
- Arguments:
147
- proposals (list[BoxList])
148
- targets (list[BoxList])
143
+ perform relpn based sampling during training
149
144
"""
150
-
151
145
labels , proposal_pairs = self .prepare_targets (proposals , targets )
152
146
sampled_pos_inds , sampled_neg_inds = self .fg_bg_pair_sampler (labels )
153
147
@@ -173,6 +167,82 @@ def subsample(self, proposals, targets):
173
167
self ._proposal_pairs = proposal_pairs
174
168
return proposal_pairs
175
169
170
+ def _fullsample_test (self , proposals ):
171
+ """
172
+ This method get all subject-object pairs, and return the proposals.
173
+ Note: this function keeps a state.
174
+
175
+ Arguments:
176
+ proposals (list[BoxList])
177
+ """
178
+ proposal_pairs = []
179
+ for i , proposals_per_image in enumerate (proposals ):
180
+ box_subj = proposals_per_image .bbox
181
+ box_obj = proposals_per_image .bbox
182
+
183
+ box_subj = box_subj .unsqueeze (1 ).repeat (1 , box_subj .shape [0 ], 1 )
184
+ box_obj = box_obj .unsqueeze (0 ).repeat (box_obj .shape [0 ], 1 , 1 )
185
+ proposal_box_pairs = torch .cat ((box_subj .view (- 1 , 4 ), box_obj .view (- 1 , 4 )), 1 )
186
+
187
+ idx_subj = torch .arange (box_subj .shape [0 ]).view (- 1 , 1 , 1 ).repeat (1 , box_obj .shape [0 ], 1 ).to (proposals_per_image .bbox .device )
188
+ idx_obj = torch .arange (box_obj .shape [0 ]).view (1 , - 1 , 1 ).repeat (box_subj .shape [0 ], 1 , 1 ).to (proposals_per_image .bbox .device )
189
+ proposal_idx_pairs = torch .cat ((idx_subj .view (- 1 , 1 ), idx_obj .view (- 1 , 1 )), 1 )
190
+
191
+ keep_idx = (proposal_idx_pairs [:, 0 ] != proposal_idx_pairs [:, 1 ]).nonzero ().view (- 1 )
192
+
193
+ # if we filter non overlap bounding boxes
194
+ if self .cfg .MODEL .ROI_RELATION_HEAD .FILTER_NON_OVERLAP :
195
+ ious = boxlist_iou (proposals_per_image , proposals_per_image ).view (- 1 )
196
+ ious = ious [keep_idx ]
197
+ keep_idx = keep_idx [(ious > 0 ).nonzero ().view (- 1 )]
198
+ proposal_idx_pairs = proposal_idx_pairs [keep_idx ]
199
+ proposal_box_pairs = proposal_box_pairs [keep_idx ]
200
+ proposal_pairs_per_image = BoxPairList (proposal_box_pairs , proposals_per_image .size , proposals_per_image .mode )
201
+ proposal_pairs_per_image .add_field ("idx_pairs" , proposal_idx_pairs )
202
+
203
+ proposal_pairs .append (proposal_pairs_per_image )
204
+ return proposal_pairs
205
+
206
+ def subsample (self , proposals , targets = None ):
207
+ """
208
+ This method performs the random positive/negative sampling, and return
209
+ the sampled proposals.
210
+ Note: this function keeps a state.
211
+
212
+ Arguments:
213
+ proposals (list[BoxList])
214
+ targets (list[BoxList])
215
+ """
216
+ if targets is not None :
217
+ proposal_pairs = self ._randomsample_train (proposals , targets )
218
+ else :
219
+ proposal_pairs = self ._fullsample_test (proposals )
220
+ return proposal_pairs
221
+ # labels, proposal_pairs = self.prepare_targets(proposals, targets)
222
+ # sampled_pos_inds, sampled_neg_inds = self.fg_bg_pair_sampler(labels)
223
+ #
224
+ # proposal_pairs = list(proposal_pairs)
225
+ # # add corresponding label and regression_targets information to the bounding boxes
226
+ # for labels_per_image, proposal_pairs_per_image in zip(
227
+ # labels, proposal_pairs
228
+ # ):
229
+ # proposal_pairs_per_image.add_field("labels", labels_per_image)
230
+ # # proposals_per_image.add_field(
231
+ # # "regression_targets", regression_targets_per_image
232
+ # # )
233
+ #
234
+ # # distributed sampled proposals, that were obtained on all feature maps
235
+ # # concatenated via the fg_bg_sampler, into individual feature map levels
236
+ # for img_idx, (pos_inds_img, neg_inds_img) in enumerate(
237
+ # zip(sampled_pos_inds, sampled_neg_inds)
238
+ # ):
239
+ # img_sampled_inds = torch.nonzero(pos_inds_img | neg_inds_img).squeeze(1)
240
+ # proposal_pairs_per_image = proposal_pairs[img_idx][img_sampled_inds]
241
+ # proposal_pairs[img_idx] = proposal_pairs_per_image
242
+ #
243
+ # self._proposal_pairs = proposal_pairs
244
+ # return proposal_pairs
245
+
176
246
def __call__ (self , class_logits ):
177
247
"""
178
248
Computes the loss for Faster R-CNN.
0 commit comments