Skip to content

Commit 174f764

Browse files
committed
TSP added
1 parent 122c69c commit 174f764

File tree

5 files changed

+529
-113
lines changed

5 files changed

+529
-113
lines changed

beam_search.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# beam search implementation in PyTorch."""
2+
#
3+
#
4+
# hyp1#-hyp1---hyp1 -hyp1
5+
# \ /
6+
# hyp2 \-hyp2 /-hyp2#hyp2
7+
# / \
8+
# hyp3#-hyp3---hyp3 -hyp3
9+
# ========================
10+
#
11+
# Takes care of beams, back pointers, and scores.
12+
13+
# Code borrowed from https://github.com/MaximumEntropy/Seq2Seq-PyTorch/blob/master/beam_search.py,
14+
# who borrowed it from PyTorch OpenNMT example
15+
# https://github.com/pytorch/examples/blob/master/OpenNMT/onmt/Beam.py
16+
# :-)
17+
18+
import torch


class Beam(object):
    """Ordered beam of candidate outputs. Fixed length.

    Tracks, for one sequence, the `size` best partial hypotheses, the
    backpointers linking each step to the previous one, and the cumulative
    score of each hypothesis. The search terminates after a fixed number
    of `steps` rather than on an end-of-sequence token.
    """

    def __init__(self, size, steps, cuda=False):
        """Initialize params.

        size  -- number of hypotheses kept on the beam (K)
        steps -- fixed number of decoding steps before the beam is done
        cuda  -- allocate tensors on the GPU when True
        """
        self.size = size
        self.done = False
        self.pad = -1
        self.steps = steps
        self.current_step = 0
        # Tensor factory namespace: torch.cuda.* vs torch.* constructors.
        self.tt = torch.cuda if cuda else torch

        # The score for each translation on the beam.
        self.scores = self.tt.FloatTensor(size).zero_()

        # The backpointers at each time-step.
        self.prevKs = []

        # The outputs at each time-step. Slot 0 is a pad-filled sentinel
        # so nextYs[t] aligns with prevKs[t - 1].
        self.nextYs = [self.tt.LongTensor(size).fill_(self.pad)]

        # The attentions (matrix) for each time.
        # NOTE(review): never populated by advance(); kept for interface
        # compatibility with the original OpenNMT Beam.
        self.attn = []

    # Get the outputs for the current timestep.
    def get_current_state(self):
        """Get state of beam (token ids chosen at the latest step)."""
        return self.nextYs[-1]

    # Get the backpointers for the current timestep.
    def get_current_origin(self):
        """Get the backpointer to the beam at this step."""
        return self.prevKs[-1]

    def advance(self, workd_lk):
        """Advance the beam by one step.

        workd_lk -- (K x num_words) scores of advancing each hypothesis on
                    the beam with every vocabulary word. (Parameter name is
                    a typo for `word_lk`; kept for backward compatibility
                    with existing keyword callers.)

        Returns True once the fixed number of steps has been taken.
        """
        num_words = workd_lk.size(1)

        # Sum the previous cumulative scores. On the very first step every
        # beam slot is identical, so only row 0 is considered -- otherwise
        # topk would return K copies of the same hypothesis.
        if len(self.prevKs) > 0:
            beam_lk = workd_lk + self.scores.unsqueeze(1).expand_as(workd_lk)
        else:
            beam_lk = workd_lk[0]

        flat_beam_lk = beam_lk.view(-1)

        bestScores, bestScoresId = flat_beam_lk.topk(self.size, 0, True, True)
        self.scores = bestScores

        # bestScoresId indexes the flattened (beam x word) array, so
        # recover which beam and which word each score came from.
        # BUGFIX: use explicit floor division -- plain `/` is true division
        # on tensors in modern PyTorch and would produce float indices,
        # breaking the index arithmetic below and get_hyp().
        prev_k = torch.div(bestScoresId, num_words, rounding_mode='floor')
        self.prevKs.append(prev_k)
        self.nextYs.append(bestScoresId - prev_k * num_words)

        self.current_step += 1
        # Fixed-length search: done after exactly `steps` advances.
        if self.current_step == self.steps:
            self.done = True

        return self.done

    def sort_best(self):
        """Sort the beam scores descending; returns (scores, indices)."""
        return torch.sort(self.scores, 0, True)

    # Get the score of the best in the beam.
    def get_best(self):
        """Get the most likely candidate's (score, beam index)."""
        scores, ids = self.sort_best()
        # BUGFIX: after a descending sort the best entry is at index 0.
        # The original `scores[1], ids[1]` (a leftover of 1-based Lua
        # indexing from the OpenNMT port) returned the *second*-best
        # hypothesis, contradicting this method's contract.
        return scores[0], ids[0]

    def get_hyp(self, k):
        """Walk back through the backpointers to build hypothesis `k`.

        k -- the position in the beam to construct.

        Returns the list of token ids from the first step to the last.
        (The attention trace mentioned in the original comments is not
        returned because self.attn is never populated.)
        """
        hyp = []
        # nextYs has one extra leading sentinel entry, hence the j + 1.
        for j in range(len(self.prevKs) - 1, -1, -1):
            hyp.append(self.nextYs[j + 1][k])
            k = self.prevKs[j][k]

        return hyp[::-1]

0 commit comments

Comments
 (0)