
Commit 77898ac

use all pos and neg samples for training relpn

1 parent 907b701 commit 77898ac

File tree (5 files changed, +114 −34 lines):

lib/scene_parser/parser.py
lib/scene_parser/rcnn/modeling/relation_heads/relation_heads.py
lib/scene_parser/rcnn/modeling/relation_heads/relpn/relationshipness.py
lib/scene_parser/rcnn/modeling/relation_heads/relpn/relpn.py
lib/scene_parser/rcnn/modeling/relation_heads/relpn/utils.py

lib/scene_parser/parser.py

Lines changed: 6 additions & 1 deletion
@@ -111,7 +111,12 @@ def forward(self, images, targets=None):
         if self.training and targets is None:
             raise ValueError("In training mode, targets should be passed")
         images = to_image_list(images)
-        features = self.backbone(images.tensors)
+        if self.training:
+            with torch.no_grad():
+                features = self.backbone(images.tensors)
+        else:
+            features = self.backbone(images.tensors)
+
         proposals, proposal_losses = self.rpn(images, features, targets)
         scene_parser_losses = {}
         if self.roi_heads:
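
The training branch above freezes the detector backbone while the relation components train: running it under torch.no_grad() skips autograd graph construction, so no gradients reach the backbone. A minimal standalone sketch of the pattern (the toy backbone and tensor shapes are illustrative, not from this repo):

import torch
import torch.nn as nn

backbone = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(True))  # toy stand-in
images = torch.randn(2, 3, 32, 32)                                      # fake batch

training = True
if training:
    with torch.no_grad():            # no autograd graph; backbone stays frozen
        features = backbone(images)
else:
    features = backbone(images)

print(features.requires_grad)        # False in the training branch

Freezing the backbone this way also saves memory, since none of its activations need to be kept for backpropagation.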

lib/scene_parser/rcnn/modeling/relation_heads/relation_heads.py

Lines changed: 10 additions & 4 deletions
@@ -115,10 +115,11 @@ def forward(self, features, proposals, targets=None):
             else:
                 proposal_pairs = self.loss_evaluator.subsample(proposals, targets)
         else:
-            if self.cfg.MODEL.USE_RELPN:
-                proposal_pairs, _ = self.relpn(proposals)
-            else:
-                proposal_pairs = self.loss_evaluator.subsample(proposals)
+            with torch.no_grad():
+                if self.cfg.MODEL.USE_RELPN:
+                    proposal_pairs, relnesses = self.relpn(proposals)
+                else:
+                    proposal_pairs = self.loss_evaluator.subsample(proposals)

         if self.cfg.MODEL.USE_FREQ_PRIOR:
             """
@@ -160,6 +161,11 @@ def forward(self, features, proposals, targets=None):
             # proposal.add_field("scores", obj_score)
             # proposal.add_field("labels", obj_label)
             result = self.post_processor((pred_class_logits), proposal_pairs, use_freq_prior=self.cfg.MODEL.USE_FREQ_PRIOR)
+
+            # if self.cfg.MODEL.USE_RELPN:
+            #     for res, relness in zip(result, relnesses):
+            #         res.add_field("scores", res.get_field("scores") * relness.view(-1, 1))
+
             return x, result, {}

         loss_obj_classifier = 0
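
The block the commit adds but leaves commented out sketches a possible rescoring step: multiplying each pair's predicate score distribution by the scalar relatedness returned by the relpn. A hedged, self-contained illustration of that broadcast (the sizes are made up):

import torch

num_pairs, num_predicates = 6, 51               # illustrative sizes
scores = torch.rand(num_pairs, num_predicates)  # per-pair predicate scores
relness = torch.rand(num_pairs)                 # one relatedness score per pair

rescored = scores * relness.view(-1, 1)         # broadcast over the predicate dim
print(rescored.shape)                           # torch.Size([6, 51])

Note that this rescoring stays disabled in the commit; only the relnesses themselves are plumbed through.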

lib/scene_parser/rcnn/modeling/relation_heads/relpn/relationshipness.py

Lines changed: 24 additions & 1 deletion
@@ -1,12 +1,14 @@
 import torch
 import torch.nn as nn
+from .utils import box_pos_encoder

 class Relationshipness(nn.Module):
     """
     compute relationshipness between subjects and objects
     """
     def __init__(self, dim, pos_encoding=False):
         super(Relationshipness, self).__init__()
+
         self.subj_proj = nn.Sequential(
             nn.Linear(dim, 64),
             nn.ReLU(True),
@@ -19,9 +21,30 @@ def __init__(self, dim, pos_encoding=False):
             nn.Linear(64, 64)
         )

-    def forward(self, x, bbox=None):
+        self.pos_encoding = False
+        if pos_encoding:
+            self.pos_encoding = True
+            self.sub_pos_encoder = nn.Sequential(
+                nn.Linear(6, 64),
+                nn.ReLU(True),
+                nn.Linear(64, 64)
+            )
+
+            self.obj_pos_encoder = nn.Sequential(
+                nn.Linear(6, 64),
+                nn.ReLU(True),
+                nn.Linear(64, 64)
+            )
+
+    def forward(self, x, bbox=None, imsize=None):
         x_subj = self.subj_proj(x)  # k x 64
         x_obj = self.obj_prof(x)  # k x 64
         scores = torch.mm(x_subj, x_obj.t())  # k x k
+        if self.pos_encoding:
+            pos = box_pos_encoder(bbox, imsize[0], imsize[1])
+            pos_subj = self.sub_pos_encoder(pos)
+            pos_obj = self.obj_pos_encoder(pos)
+            pos_scores = torch.mm(pos_subj, pos_obj.t())  # k x k
+            scores = scores + pos_scores
         relness = torch.sigmoid(scores)  # k x k
         return relness
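
The module scores every ordered (subject, object) pair in one matrix multiply: each box's class logits are projected into separate subject and object embeddings, and all their inner products give a k x k score grid; the new positional branch does the same over the 6-d box encoding and adds its scores before the sigmoid. A minimal sketch of the core computation (the sizes are illustrative):

import torch
import torch.nn as nn

k, dim = 5, 151                       # 5 boxes, 151 object classes (illustrative)
subj_proj = nn.Linear(dim, 64)
obj_proj = nn.Linear(dim, 64)

logits = torch.randn(k, dim)          # per-box class logits
scores = subj_proj(logits) @ obj_proj(logits).t()  # k x k: entry (i, j) scores pair i -> j
relness = torch.sigmoid(scores)       # relatedness in (0, 1)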

lib/scene_parser/rcnn/modeling/relation_heads/relpn/relpn.py

Lines changed: 58 additions & 28 deletions
@@ -31,7 +31,7 @@ def __init__(
         self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg
         self.use_matched_pairs_only = use_matched_pairs_only
         self.minimal_matched_pairs = minimal_matched_pairs
-        self.relationshipness = Relationshipness(self.cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES)
+        self.relationshipness = Relationshipness(self.cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES, pos_encoding=True)

     def match_targets_to_proposals(self, proposal, target):
         match_quality_matrix = boxlist_iou(target, proposal)
@@ -146,20 +146,22 @@ def _relpnsample_train(self, proposals, targets):
             enumerate(zip(proposals, sampled_pos_inds, sampled_neg_inds)):
             obj_logits = proposals_per_image.get_field('logits')
             obj_bboxes = proposals_per_image.bbox
-            relness = self.relationshipness(obj_logits, obj_bboxes)
+            relness = self.relationshipness(obj_logits, obj_bboxes, proposals_per_image.size)
             nondiag = (1 - torch.eye(obj_logits.shape[0]).to(relness.device)).view(-1)
             relness = relness.view(-1)[nondiag.nonzero()]
             relness_sorted, order = torch.sort(relness.view(-1), descending=True)
             img_sampled_inds = order[:self.cfg.MODEL.ROI_RELATION_HEAD.BATCH_SIZE_PER_IMAGE].view(-1)
             proposal_pairs_per_image = proposal_pairs[img_idx][img_sampled_inds]
             proposal_pairs[img_idx] = proposal_pairs_per_image

-            img_sampled_inds = torch.nonzero(pos_inds_img | neg_inds_img).squeeze(1)
-            relness = relness[img_sampled_inds]
-            pos_labels = torch.ones(len(pos_inds_img.nonzero()))
-            neg_labels = torch.zeros(len(neg_inds_img.nonzero()))
-            rellabels = torch.cat((pos_labels, neg_labels), 0).view(-1, 1)
-            losses += F.binary_cross_entropy(relness, rellabels.to(relness.device))
+            # import pdb; pdb.set_trace()
+            # img_sampled_inds = torch.nonzero(pos_inds_img | neg_inds_img).squeeze(1)
+            # relness = relness[img_sampled_inds]
+            # pos_labels = torch.ones(len(pos_inds_img.nonzero()))
+            # neg_labels = torch.zeros(len(neg_inds_img.nonzero()))
+            # rellabels = torch.cat((pos_labels, neg_labels), 0).view(-1, 1)
+            # losses += F.binary_cross_entropy(relness, rellabels.to(relness.device))
+            losses += F.binary_cross_entropy(relness, (labels[img_idx] > 0).view(-1, 1).float())

         # distributed sampled proposals, that were obtained on all feature maps
         # concatenated via the fg_bg_sampler, into individual feature map levels
@@ -174,42 +176,70 @@ def _relpnsample_train(self, proposals, targets):

         return proposal_pairs, losses

+    def _fullsample_test(self, proposals):
+        """
+        This method get all subject-object pairs, and return the proposals.
+        Note: this function keeps a state.
+
+        Arguments:
+            proposals (list[BoxList])
+        """
+        proposal_pairs = []
+        for i, proposals_per_image in enumerate(proposals):
+            box_subj = proposals_per_image.bbox
+            box_obj = proposals_per_image.bbox
+
+            box_subj = box_subj.unsqueeze(1).repeat(1, box_subj.shape[0], 1)
+            box_obj = box_obj.unsqueeze(0).repeat(box_obj.shape[0], 1, 1)
+            proposal_box_pairs = torch.cat((box_subj.view(-1, 4), box_obj.view(-1, 4)), 1)
+
+            idx_subj = torch.arange(box_subj.shape[0]).view(-1, 1, 1).repeat(1, box_obj.shape[0], 1).to(proposals_per_image.bbox.device)
+            idx_obj = torch.arange(box_obj.shape[0]).view(1, -1, 1).repeat(box_subj.shape[0], 1, 1).to(proposals_per_image.bbox.device)
+            proposal_idx_pairs = torch.cat((idx_subj.view(-1, 1), idx_obj.view(-1, 1)), 1)
+
+            keep_idx = (proposal_idx_pairs[:, 0] != proposal_idx_pairs[:, 1]).nonzero().view(-1)
+
+            # if we filter non overlap bounding boxes
+            if self.cfg.MODEL.ROI_RELATION_HEAD.FILTER_NON_OVERLAP:
+                ious = boxlist_iou(proposals_per_image, proposals_per_image).view(-1)
+                ious = ious[keep_idx]
+                keep_idx = keep_idx[(ious > 0).nonzero().view(-1)]
+            proposal_idx_pairs = proposal_idx_pairs[keep_idx]
+            proposal_box_pairs = proposal_box_pairs[keep_idx]
+            proposal_pairs_per_image = BoxPairList(proposal_box_pairs, proposals_per_image.size, proposals_per_image.mode)
+            proposal_pairs_per_image.add_field("idx_pairs", proposal_idx_pairs)
+
+            proposal_pairs.append(proposal_pairs_per_image)
+        return proposal_pairs
+
     def _relpnsample_test(self, proposals):
         """
         perform relpn based sampling during testing
         """
-        labels, proposal_pairs = self.prepare_targets(proposals, targets)
+        proposals[0] = proposals[0]
+        proposal_pairs = self._fullsample_test(proposals)
         proposal_pairs = list(proposal_pairs)

-        import pdb; pdb.set_trace()
-
-        losses = 0
-        for img_idx, (proposals_per_image) in \
-            enumerate(zip(proposals)):
+        relnesses = []
+        for img_idx, proposals_per_image in enumerate(proposals):
             obj_logits = proposals_per_image.get_field('logits')
             obj_bboxes = proposals_per_image.bbox
-            relness = self.relationshipness(obj_logits, obj_bboxes)
+            relness = self.relationshipness(obj_logits, obj_bboxes, proposals_per_image.size)
             nondiag = (1 - torch.eye(obj_logits.shape[0]).to(relness.device)).view(-1)
             relness = relness.view(-1)[nondiag.nonzero()]
-            relness_sorted, order = torch.sort(relness, descending=True)
-            img_sampled_inds = order[:self.cfg.MODEL.ROI_RELATION_HEAD.BATCH_SIZE_PER_IMAGE].view(-1)
+            relness_sorted, order = torch.sort(relness.view(-1), descending=True)
+            img_sampled_inds = order[:196].view(-1)
+            relness = relness_sorted[:196].view(-1)
             proposal_pairs_per_image = proposal_pairs[img_idx][img_sampled_inds]
             proposal_pairs[img_idx] = proposal_pairs_per_image
+            relnesses.append(relness)

-        # distributed sampled proposals, that were obtained on all feature maps
-        # concatenated via the fg_bg_sampler, into individual feature map levels
-        # for img_idx, (pos_inds_img, neg_inds_img) in enumerate(
-        #     zip(sampled_pos_inds, sampled_neg_inds)
-        # ):
-        #     img_sampled_inds = torch.nonzero(pos_inds_img | neg_inds_img).squeeze(1)
-        #     proposal_pairs_per_image = proposal_pairs[img_idx][img_sampled_inds]
-        #     proposal_pairs[img_idx] = proposal_pairs_per_image
-
+        # self.cfg.MODEL.ROI_RELATION_HEAD.BATCH_SIZE_PER_IMAGE
         self._proposal_pairs = proposal_pairs

-        return proposal_pairs, {}
+        return proposal_pairs, relnesses

-    def forward(self, proposals, targets):
+    def forward(self, proposals, targets=None):
         """
         This method performs the positive/negative sampling, and return
         the sampled proposals.
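
Two sketches of what this diff does. First, the new _fullsample_test builds every ordered pair of the k proposals by broadcasting the box tensor against itself and dropping the diagonal (self-pairs); at test time the top 196 pairs by relatedness are then kept, a hard-coded cap standing in for BATCH_SIZE_PER_IMAGE. A compact illustration of the enumeration with made-up boxes:

import torch

k = 4
boxes = torch.rand(k, 4)                           # illustrative xyxy boxes

box_subj = boxes.unsqueeze(1).repeat(1, k, 1)      # k x k x 4, row i = subject i
box_obj = boxes.unsqueeze(0).repeat(k, 1, 1)       # k x k x 4, col j = object j
pairs = torch.cat((box_subj.view(-1, 4), box_obj.view(-1, 4)), 1)  # k*k x 8

idx = torch.arange(k)
keep = (idx.view(-1, 1) != idx.view(1, -1)).view(-1)  # mask out self-pairs
pairs = pairs[keep]                                # k*(k-1) x 8 ordered pairs

Second, the training change that gives the commit its title: instead of subsampling positive and negative pairs for the binary cross-entropy, the relpn is now supervised on all candidate pairs, with the target simply indicating whether a pair carries a non-background label. A hedged sketch with made-up shapes:

import torch
import torch.nn.functional as F

num_pairs = 12
relness = torch.rand(num_pairs, 1)               # predicted relatedness in (0, 1)
labels = torch.randint(0, 51, (num_pairs,))      # 0 = background (no relation)

target = (labels > 0).view(-1, 1).float()        # every pos and neg pair supervises
loss = F.binary_cross_entropy(relness, target)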
lib/scene_parser/rcnn/modeling/relation_heads/relpn/utils.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+import torch
+
+def box_pos_encoder(bboxes, width, height):
+    """
+    bounding box encoding
+    """
+    bboxes_enc = bboxes.clone()
+
+    dim0 = bboxes_enc[:, 0] / width
+    dim1 = bboxes_enc[:, 1] / height
+    dim2 = bboxes_enc[:, 2] / width
+    dim3 = bboxes_enc[:, 3] / height
+    dim4 = (bboxes_enc[:, 2] - bboxes_enc[:, 0]) * (bboxes_enc[:, 3] - bboxes_enc[:, 1]) / height / width
+    dim5 = (bboxes_enc[:, 3] - bboxes_enc[:, 1]) / (bboxes_enc[:, 2] - bboxes_enc[:, 0] + 1)
+
+    return torch.stack((dim0,dim1,dim2,dim3,dim4,dim5), 1)
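
box_pos_encoder maps each xyxy box to a 6-d geometry feature: the four coordinates normalized by image size, the normalized area, and the height-to-width aspect ratio (with +1 in the denominator to guard against zero-width boxes). A quick usage sketch, assuming the function above is importable and using made-up boxes:

import torch

boxes = torch.tensor([[10., 20., 110., 70.],
                      [ 0.,  0., 640., 480.]])    # illustrative xyxy boxes
pos = box_pos_encoder(boxes, 640, 480)            # width=640, height=480
print(pos.shape)                                  # torch.Size([2, 6])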
