+import torch
+import torch.nn as nn
+from torch.nn import Parameter
+
+
+class relative_position(nn.Module):
+    """Learned relative-position embeddings, clipped to a maximum relative distance."""
+
+    def __init__(self, num_units, max_relative_position):
+        super(relative_position, self).__init__()
+        self.num_units = num_units
+        self.max_relative_position = max_relative_position
+        # one num_units-dimensional embedding per distance in [-max_relative_position, max_relative_position]
+        self.embeddings_table = Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units))
+        nn.init.xavier_uniform_(self.embeddings_table)
+
+    def forward(self, length_q, length_k):
+        range_vec_q = torch.arange(length_q)
+        range_vec_k = torch.arange(length_k)
+        distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
+        # clip distances below -max_relative_position or above +max_relative_position
+        distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
+        # shift into [0, 2 * max_relative_position] so the distances index the embedding table
+        final_mat = distance_mat_clipped + self.max_relative_position
+        final_mat = final_mat.long().cuda()  # the lookup below assumes the module lives on the GPU
+        embeddings = self.embeddings_table[final_mat].cuda()
+
+        return embeddings
+
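For reference, forward(length_q, length_k) returns one num_units-dimensional embedding per (query, key) position pair, i.e. a tensor of shape [length_q, length_k, num_units]. A minimal sketch of exercising the module on its own (the sizes are illustrative, and a CUDA device is assumed because the forward pass hard-codes .cuda()):

rel_pos = relative_position(num_units=64, max_relative_position=4).cuda()
# a 5-token query attending over 7 keys; relative distances are clipped to [-4, 4]
r = rel_pos(5, 7)
print(r.shape)  # torch.Size([5, 7, 64])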
+# Q_, K_, and V_ have shape [heads*batch, length, dim // heads]
+# relative-key term: the dot product q_i . r_k[i, j] for every (query, key) pair
+r_k = self.relative_position(Q_.size()[1], K_.size()[1])
+outputs = outputs + torch.bmm(Q_.permute(1, 0, 2), r_k.permute(0, 2, 1)).permute(1, 0, 2)
+
+# relative-value term: sum_j weights[b, i, j] * r_v[i, j] added to each attended output
+r_v = self.relative_position(Q_.size()[1], V_.size()[1])
+outputs = outputs + torch.bmm(weights.permute(1, 0, 2), r_v).permute(1, 0, 2)
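The permute/bmm pattern above folds the per-query relative lookups into batched matrix multiplies, in the style of Shaw et al.'s relative position representations: the r_k product matches the score-bias term q_i . a^K_ij, and the r_v product matches the output term sum_j alpha_ij * a^V_ij. A self-contained sketch of the same shape arithmetic with stand-in tensors (the sizes are hypothetical, rel reuses the relative_position module defined above, and a CUDA device is assumed):

hb, length, d_head = 16, 10, 64                       # heads*batch, length, dim // heads
Q_ = torch.randn(hb, length, d_head).cuda()
weights = torch.softmax(torch.randn(hb, length, length).cuda(), dim=-1)
rel = relative_position(d_head, max_relative_position=4).cuda()

# relative-key term: q_i . r_k[i, j] for every (query, key) pair
r_k = rel(length, length)                             # [len_q, len_k, d_head]
score_bias = torch.bmm(Q_.permute(1, 0, 2),           # [len_q, hb, d_head]
                       r_k.permute(0, 2, 1))          # [len_q, d_head, len_k]
score_bias = score_bias.permute(1, 0, 2)              # -> [hb, len_q, len_k]

# relative-value term: the attention distribution averaged over r_v[i, j]
r_v = rel(length, length)                             # [len_q, len_k, d_head]
out_term = torch.bmm(weights.permute(1, 0, 2), r_v)   # [len_q, hb, d_head]
out_term = out_term.permute(1, 0, 2)                  # -> [hb, len_q, d_head]

print(score_bias.shape, out_term.shape)               # [16, 10, 10] and [16, 10, 64]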