@@ -31,7 +31,7 @@ def __init__(
31
31
self .cls_agnostic_bbox_reg = cls_agnostic_bbox_reg
32
32
self .use_matched_pairs_only = use_matched_pairs_only
33
33
self .minimal_matched_pairs = minimal_matched_pairs
34
- self .relationshipness = Relationshipness (self .cfg .MODEL .ROI_BOX_HEAD .NUM_CLASSES )
34
+ self .relationshipness = Relationshipness (self .cfg .MODEL .ROI_BOX_HEAD .NUM_CLASSES , pos_encoding = True )
35
35
36
36
def match_targets_to_proposals (self , proposal , target ):
37
37
match_quality_matrix = boxlist_iou (target , proposal )
@@ -146,20 +146,22 @@ def _relpnsample_train(self, proposals, targets):
146
146
enumerate (zip (proposals , sampled_pos_inds , sampled_neg_inds )):
147
147
obj_logits = proposals_per_image .get_field ('logits' )
148
148
obj_bboxes = proposals_per_image .bbox
149
- relness = self .relationshipness (obj_logits , obj_bboxes )
149
+ relness = self .relationshipness (obj_logits , obj_bboxes , proposals_per_image . size )
150
150
nondiag = (1 - torch .eye (obj_logits .shape [0 ]).to (relness .device )).view (- 1 )
151
151
relness = relness .view (- 1 )[nondiag .nonzero ()]
152
152
relness_sorted , order = torch .sort (relness .view (- 1 ), descending = True )
153
153
img_sampled_inds = order [:self .cfg .MODEL .ROI_RELATION_HEAD .BATCH_SIZE_PER_IMAGE ].view (- 1 )
154
154
proposal_pairs_per_image = proposal_pairs [img_idx ][img_sampled_inds ]
155
155
proposal_pairs [img_idx ] = proposal_pairs_per_image
156
156
157
- img_sampled_inds = torch .nonzero (pos_inds_img | neg_inds_img ).squeeze (1 )
158
- relness = relness [img_sampled_inds ]
159
- pos_labels = torch .ones (len (pos_inds_img .nonzero ()))
160
- neg_labels = torch .zeros (len (neg_inds_img .nonzero ()))
161
- rellabels = torch .cat ((pos_labels , neg_labels ), 0 ).view (- 1 , 1 )
162
- losses += F .binary_cross_entropy (relness , rellabels .to (relness .device ))
157
+ # import pdb; pdb.set_trace()
158
+ # img_sampled_inds = torch.nonzero(pos_inds_img | neg_inds_img).squeeze(1)
159
+ # relness = relness[img_sampled_inds]
160
+ # pos_labels = torch.ones(len(pos_inds_img.nonzero()))
161
+ # neg_labels = torch.zeros(len(neg_inds_img.nonzero()))
162
+ # rellabels = torch.cat((pos_labels, neg_labels), 0).view(-1, 1)
163
+ # losses += F.binary_cross_entropy(relness, rellabels.to(relness.device))
164
+ losses += F .binary_cross_entropy (relness , (labels [img_idx ] > 0 ).view (- 1 , 1 ).float ())
163
165
164
166
# distributed sampled proposals, that were obtained on all feature maps
165
167
# concatenated via the fg_bg_sampler, into individual feature map levels
@@ -174,42 +176,70 @@ def _relpnsample_train(self, proposals, targets):
174
176
175
177
return proposal_pairs , losses
176
178
179
+ def _fullsample_test (self , proposals ):
180
+ """
181
+ This method get all subject-object pairs, and return the proposals.
182
+ Note: this function keeps a state.
183
+
184
+ Arguments:
185
+ proposals (list[BoxList])
186
+ """
187
+ proposal_pairs = []
188
+ for i , proposals_per_image in enumerate (proposals ):
189
+ box_subj = proposals_per_image .bbox
190
+ box_obj = proposals_per_image .bbox
191
+
192
+ box_subj = box_subj .unsqueeze (1 ).repeat (1 , box_subj .shape [0 ], 1 )
193
+ box_obj = box_obj .unsqueeze (0 ).repeat (box_obj .shape [0 ], 1 , 1 )
194
+ proposal_box_pairs = torch .cat ((box_subj .view (- 1 , 4 ), box_obj .view (- 1 , 4 )), 1 )
195
+
196
+ idx_subj = torch .arange (box_subj .shape [0 ]).view (- 1 , 1 , 1 ).repeat (1 , box_obj .shape [0 ], 1 ).to (proposals_per_image .bbox .device )
197
+ idx_obj = torch .arange (box_obj .shape [0 ]).view (1 , - 1 , 1 ).repeat (box_subj .shape [0 ], 1 , 1 ).to (proposals_per_image .bbox .device )
198
+ proposal_idx_pairs = torch .cat ((idx_subj .view (- 1 , 1 ), idx_obj .view (- 1 , 1 )), 1 )
199
+
200
+ keep_idx = (proposal_idx_pairs [:, 0 ] != proposal_idx_pairs [:, 1 ]).nonzero ().view (- 1 )
201
+
202
+ # if we filter non overlap bounding boxes
203
+ if self .cfg .MODEL .ROI_RELATION_HEAD .FILTER_NON_OVERLAP :
204
+ ious = boxlist_iou (proposals_per_image , proposals_per_image ).view (- 1 )
205
+ ious = ious [keep_idx ]
206
+ keep_idx = keep_idx [(ious > 0 ).nonzero ().view (- 1 )]
207
+ proposal_idx_pairs = proposal_idx_pairs [keep_idx ]
208
+ proposal_box_pairs = proposal_box_pairs [keep_idx ]
209
+ proposal_pairs_per_image = BoxPairList (proposal_box_pairs , proposals_per_image .size , proposals_per_image .mode )
210
+ proposal_pairs_per_image .add_field ("idx_pairs" , proposal_idx_pairs )
211
+
212
+ proposal_pairs .append (proposal_pairs_per_image )
213
+ return proposal_pairs
214
+
177
215
def _relpnsample_test (self , proposals ):
178
216
"""
179
217
perform relpn based sampling during testing
180
218
"""
181
- labels , proposal_pairs = self .prepare_targets (proposals , targets )
219
+ proposals [0 ] = proposals [0 ]
220
+ proposal_pairs = self ._fullsample_test (proposals )
182
221
proposal_pairs = list (proposal_pairs )
183
222
184
- import pdb ; pdb .set_trace ()
185
-
186
- losses = 0
187
- for img_idx , (proposals_per_image ) in \
188
- enumerate (zip (proposals )):
223
+ relnesses = []
224
+ for img_idx , proposals_per_image in enumerate (proposals ):
189
225
obj_logits = proposals_per_image .get_field ('logits' )
190
226
obj_bboxes = proposals_per_image .bbox
191
- relness = self .relationshipness (obj_logits , obj_bboxes )
227
+ relness = self .relationshipness (obj_logits , obj_bboxes , proposals_per_image . size )
192
228
nondiag = (1 - torch .eye (obj_logits .shape [0 ]).to (relness .device )).view (- 1 )
193
229
relness = relness .view (- 1 )[nondiag .nonzero ()]
194
- relness_sorted , order = torch .sort (relness , descending = True )
195
- img_sampled_inds = order [:self .cfg .MODEL .ROI_RELATION_HEAD .BATCH_SIZE_PER_IMAGE ].view (- 1 )
230
+ relness_sorted , order = torch .sort (relness .view (- 1 ), descending = True )
231
+ img_sampled_inds = order [:196 ].view (- 1 )
232
+ relness = relness_sorted [:196 ].view (- 1 )
196
233
proposal_pairs_per_image = proposal_pairs [img_idx ][img_sampled_inds ]
197
234
proposal_pairs [img_idx ] = proposal_pairs_per_image
235
+ relnesses .append (relness )
198
236
199
- # distributed sampled proposals, that were obtained on all feature maps
200
- # concatenated via the fg_bg_sampler, into individual feature map levels
201
- # for img_idx, (pos_inds_img, neg_inds_img) in enumerate(
202
- # zip(sampled_pos_inds, sampled_neg_inds)
203
- # ):
204
- # img_sampled_inds = torch.nonzero(pos_inds_img | neg_inds_img).squeeze(1)
205
- # proposal_pairs_per_image = proposal_pairs[img_idx][img_sampled_inds]
206
- # proposal_pairs[img_idx] = proposal_pairs_per_image
207
-
237
+ # self.cfg.MODEL.ROI_RELATION_HEAD.BATCH_SIZE_PER_IMAGE
208
238
self ._proposal_pairs = proposal_pairs
209
239
210
- return proposal_pairs , {}
240
+ return proposal_pairs , relnesses
211
241
212
- def forward (self , proposals , targets ):
242
+ def forward (self , proposals , targets = None ):
213
243
"""
214
244
This method performs the positive/negative sampling, and return
215
245
the sampled proposals.
0 commit comments