Skip to content

Commit 907b701

Browse files
committed
add relpn v0.1, unit test needed
1 parent 8838b3b commit 907b701

File tree

4 files changed

+294
-45
lines changed

4 files changed

+294
-45
lines changed

lib/scene_parser/rcnn/modeling/relation_heads/loss.py

Lines changed: 80 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,8 @@ def match_targets_to_proposals(self, proposal, target):
9595
# GT in the image, and matched_idxs can be -2, which goes
9696
# out of bounds
9797

98-
if self.use_matched_pairs_only and (matched_idxs >= 0).sum() > self.minimal_matched_pairs:
98+
if self.use_matched_pairs_only and \
99+
(matched_idxs >= 0).sum() > self.minimal_matched_pairs:
99100
# filter all matched_idxs < 0
100101
proposal_pairs = proposal_pairs[matched_idxs >= 0]
101102
matched_idxs = matched_idxs[matched_idxs >= 0]
@@ -137,17 +138,10 @@ def prepare_targets(self, proposals, targets):
137138

138139
return labels, proposal_pairs
139140

140-
def subsample(self, proposals, targets):
141+
def _randomsample_train(self, proposals, targets):
141142
"""
142-
This method performs the positive/negative sampling, and return
143-
the sampled proposals.
144-
Note: this function keeps a state.
145-
146-
Arguments:
147-
proposals (list[BoxList])
148-
targets (list[BoxList])
143+
Perform random positive/negative pair sampling during training.
149144
"""
150-
151145
labels, proposal_pairs = self.prepare_targets(proposals, targets)
152146
sampled_pos_inds, sampled_neg_inds = self.fg_bg_pair_sampler(labels)
153147

@@ -173,6 +167,82 @@ def subsample(self, proposals, targets):
173167
self._proposal_pairs = proposal_pairs
174168
return proposal_pairs
175169

170+
def _fullsample_test(self, proposals):
    """
    Enumerate every ordered subject-object pair of proposals in each image.

    No sampling is done and no labels are attached. For an image with N
    proposals, all ordered pairs whose subject and object indices differ
    are kept. When ``cfg.MODEL.ROI_RELATION_HEAD.FILTER_NON_OVERLAP`` is
    set, pairs whose two boxes have zero IoU are dropped as well.
    This method does not store anything on ``self``.

    Arguments:
        proposals (list[BoxList])

    Returns:
        list[BoxPairList]: one entry per image. Each row of the pair boxes
        is the subject box followed by the object box, and the
        "idx_pairs" field holds the matching (subject index, object index)
        rows into that image's proposals.
    """
    proposal_pairs = []
    for proposals_per_image in proposals:
        boxes = proposals_per_image.bbox
        num_boxes = boxes.shape[0]

        # Build the N x N index grid directly on the boxes' device; expand
        # gives broadcast views, so only the flattened result is allocated.
        indices = torch.arange(num_boxes, device=boxes.device)
        idx_subj = indices.view(-1, 1).expand(num_boxes, num_boxes).reshape(-1)
        idx_obj = indices.view(1, -1).expand(num_boxes, num_boxes).reshape(-1)

        proposal_idx_pairs = torch.stack((idx_subj, idx_obj), 1)
        # assumes bbox is (N, 4), as the original view(-1, 4) implied
        proposal_box_pairs = torch.cat((boxes[idx_subj], boxes[idx_obj]), 1)

        # a proposal is never paired with itself
        keep_idx = (idx_subj != idx_obj).nonzero().view(-1)

        # optionally keep only pairs whose boxes overlap
        if self.cfg.MODEL.ROI_RELATION_HEAD.FILTER_NON_OVERLAP:
            ious = boxlist_iou(proposals_per_image, proposals_per_image).view(-1)
            keep_idx = keep_idx[ious[keep_idx] > 0]

        proposal_idx_pairs = proposal_idx_pairs[keep_idx]
        proposal_box_pairs = proposal_box_pairs[keep_idx]

        proposal_pairs_per_image = BoxPairList(
            proposal_box_pairs, proposals_per_image.size, proposals_per_image.mode
        )
        proposal_pairs_per_image.add_field("idx_pairs", proposal_idx_pairs)
        proposal_pairs.append(proposal_pairs_per_image)
    return proposal_pairs
205+
206+
def subsample(self, proposals, targets=None):
    """
    Select the subject-object proposal pairs for the relation head.

    Dispatches on whether ground truth is available: when ``targets`` is
    given the pairs come from ``_randomsample_train``; when it is ``None``
    they come from ``_fullsample_test``.

    Arguments:
        proposals (list[BoxList])
        targets (list[BoxList], optional): ground truth; omit at test time.

    Returns:
        The proposal pairs produced by the selected sampler.
    """
    # `is not None` (not truthiness) so an empty target list still takes
    # the training path.
    if targets is not None:
        return self._randomsample_train(proposals, targets)
    return self._fullsample_test(proposals)
245+
176246
def __call__(self, class_logits):
177247
"""
178248
Computes the loss for Faster R-CNN.

lib/scene_parser/rcnn/modeling/relation_heads/relation_heads.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -111,15 +111,14 @@ def forward(self, features, proposals, targets=None):
111111
# Faster R-CNN subsamples during training the proposals with a fixed
112112
# positive / negative ratio
113113
if self.cfg.MODEL.USE_RELPN:
114-
proposal_pairs = self.relpn(proposals, targets)
114+
proposal_pairs, loss_relpn = self.relpn(proposals, targets)
115115
else:
116-
with torch.no_grad():
117-
proposal_pairs = self.loss_evaluator.subsample(proposals, targets)
116+
proposal_pairs = self.loss_evaluator.subsample(proposals, targets)
118117
else:
119118
if self.cfg.MODEL.USE_RELPN:
120-
proposal_pairs = self.relpn(proposals, targets)
119+
proposal_pairs, _ = self.relpn(proposals)
121120
else:
122-
proposal_pairs = self._get_proposal_pairs(proposals)
121+
proposal_pairs = self.loss_evaluator.subsample(proposals)
123122

124123
if self.cfg.MODEL.USE_FREQ_PRIOR:
125124
"""
@@ -166,13 +165,24 @@ def forward(self, features, proposals, targets=None):
166165
loss_obj_classifier = 0
167166
if obj_class_logits is not None:
168167
loss_obj_classifier = self.loss_evaluator.obj_classification_loss(proposals, [obj_class_logits])
169-
loss_pred_classifier = self.loss_evaluator([pred_class_logits])
170168

171-
return (
172-
x,
173-
proposal_pairs,
174-
dict(loss_obj_classifier=loss_obj_classifier, loss_pred_classifier=loss_pred_classifier),
175-
)
169+
if self.cfg.MODEL.USE_RELPN:
170+
loss_pred_classifier = self.relpn.pred_classification_loss([pred_class_logits])
171+
return (
172+
x,
173+
proposal_pairs,
174+
dict(loss_obj_classifier=loss_obj_classifier,
175+
loss_relpn = loss_relpn,
176+
loss_pred_classifier=loss_pred_classifier),
177+
)
178+
else:
179+
loss_pred_classifier = self.loss_evaluator([pred_class_logits])
180+
return (
181+
x,
182+
proposal_pairs,
183+
dict(loss_obj_classifier=loss_obj_classifier,
184+
loss_pred_classifier=loss_pred_classifier),
185+
)
176186

177187
def build_roi_relation_head(cfg, in_channels):
178188
"""
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
class Relationshipness(nn.Module):
    """
    Compute relationshipness between subjects and objects.

    Each input feature is projected twice, once as a subject and once as an
    object, by two separate two-layer MLPs. The score of pair (i, j) is the
    sigmoid of the dot product between subject projection i and object
    projection j.

    Arguments:
        dim (int): size of each input feature vector.
        pos_encoding (bool): accepted but currently unused.
        hidden_dim (int): width of both MLP layers; defaults to 64.
    """
    def __init__(self, dim, pos_encoding=False, hidden_dim=64):
        super(Relationshipness, self).__init__()
        self.subj_proj = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # NOTE: the attribute name "obj_prof" (sic, likely meant "obj_proj")
        # is kept as-is because renaming it would change the module's
        # state_dict keys.
        self.obj_prof = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, x, bbox=None):
        """
        Arguments:
            x (Tensor): k x dim features, one row per region.
            bbox: accepted but currently unused.

        Returns:
            Tensor: k x k matrix of sigmoid scores; entry [i, j] pairs
            region i as subject with region j as object.
        """
        x_subj = self.subj_proj(x)  # k x hidden_dim
        x_obj = self.obj_prof(x)  # k x hidden_dim
        scores = torch.mm(x_subj, x_obj.t())  # k x k
        relness = torch.sigmoid(scores)  # k x k
        return relness

0 commit comments

Comments
 (0)