Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Comment.
  • Loading branch information
belladoreai committed May 28, 2024
commit 686e7a8344f1aa9d27da3055906119f55f085d25
5 changes: 1 addition & 4 deletions modules/sampler_hijack.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,10 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
input_ids = input_ids[:, -self._range:]

for input_ids_row, scores_row in zip(input_ids, scores):

# Use normal Python data types for improved performance
input_ids = input_ids_row.tolist()

# Raw integer must be extracted here to check for set membership.
last_token = input_ids[-1]

last_token = input_ids[-1]
if last_token in self.sequence_breakers:
continue

Expand Down