Skip to content

Commit 5ea25d0

Browse files
fix o_prop
1 parent f583077 commit 5ea25d0

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

flair/models/sequence_tagger_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def forward(self, sentence_tensor: torch.Tensor, lengths: torch.LongTensor):
342342
if self.o_count != 1:
343343
props = F.softmax(features, dim=-1)
344344
o_prop = props[..., 0] + props[..., len(self.label_dictionary) :].sum(dim=-1)
345-
real_props = torch.cat([o_prop.unsqueeze(0), props[..., 1 : len(self.label_dictionary)]], dim=-1)
345+
real_props = torch.cat([o_prop.unsqueeze(-1), props[..., 1 : len(self.label_dictionary)]], dim=-1)
346346
features = torch.log(real_props)
347347

348348
# Depending on whether we are using CRF or a linear layer, scores is either:

0 commit comments

Comments
 (0)