Commit 17a40e7

Add files via upload

relative position

File tree

1 file changed: +27 −0 lines changed


relative_position.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
import torch
import torch.nn as nn
from torch.nn import Parameter


class relative_position(nn.Module):

    def __init__(self, num_units, max_relative_position):
        super(relative_position, self).__init__()
        self.num_units = num_units
        self.max_relative_position = max_relative_position
        # One learnable embedding per relative distance in [-max_relative_position, max_relative_position]
        self.embeddings_table = Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units))
        nn.init.xavier_uniform_(self.embeddings_table)

    def forward(self, length_q, length_k):
        range_vec_q = torch.arange(length_q)
        range_vec_k = torch.arange(length_k)
        # distance_mat[i, j] = j - i: signed distance from query position i to key position j
        distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
        # Clamp entries of distance_mat below -max_relative_position or above
        # max_relative_position to ±max_relative_position
        distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
        # Shift to [0, 2 * max_relative_position] so the distances index the table directly
        final_mat = (distance_mat_clipped + self.max_relative_position).long()
        # Keep the index on the same device as the embedding table
        embeddings = self.embeddings_table[final_mat.to(self.embeddings_table.device)]

        return embeddings
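
A minimal shape check for the module above, assuming the imports and class as just defined (the values 64, 4, 10 and 12 are arbitrary, for illustration only):

rp = relative_position(num_units=64, max_relative_position=4)
emb = rp(10, 12)
print(emb.shape)  # torch.Size([10, 12, 64]): one num_units vector per (query, key) pair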

# Usage fragment (meant to sit inside a multi-head attention forward pass);
# Q_, K_ and V_ have shape [heads * batch, length, dim // heads].
r_k = self.relative_position(Q_.size()[1], K_.size()[1])
outputs = outputs + torch.bmm(Q_.permute(1, 0, 2), r_k.permute(0, 2, 1)).permute(1, 0, 2)

r_v = self.relative_position(Q_.size()[1], V_.size()[1])
outputs = outputs + torch.bmm(weights.permute(1, 0, 2), r_v).permute(1, 0, 2)
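
For context, a hedged sketch of where those two additions would sit in a scaled dot-product attention forward. The names Q_, K_, V_, weights, outputs and the [heads * batch, length, dim // heads] layout come from the fragment above; the surrounding function, the softmax placement and the scale parameter are assumptions, not the author's code:

import torch
import torch.nn.functional as F

def attention_with_relative_position(Q_, K_, V_, rel_pos, scale):
    # Q_, K_, V_: [heads * batch, length, dim // heads]
    length_q, length_k = Q_.size(1), K_.size(1)

    # Content-content scores plus the content-position term built from r_k
    scores = torch.bmm(Q_, K_.transpose(1, 2))   # [hb, Lq, Lk]
    r_k = rel_pos(length_q, length_k)            # [Lq, Lk, d]
    scores = scores + torch.bmm(Q_.permute(1, 0, 2), r_k.permute(0, 2, 1)).permute(1, 0, 2)
    weights = F.softmax(scores / scale, dim=-1)  # [hb, Lq, Lk]

    # Weighted sum of values plus the position-value term built from r_v
    outputs = torch.bmm(weights, V_)             # [hb, Lq, d]
    r_v = rel_pos(length_q, V_.size(1))          # [Lq, Lk, d]
    outputs = outputs + torch.bmm(weights.permute(1, 0, 2), r_v).permute(1, 0, 2)
    return outputs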
