-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
93 lines (74 loc) · 2.74 KB
/
utils.py
File metadata and controls
93 lines (74 loc) · 2.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import argparse
import functools

import torch
from einops import rearrange
from torch import Tensor

from data.schemas import TokenizedSeqBatch
def reset_kv_cache(fn):
    """Decorator: clear ``self.decoder``'s KV cache before and after ``fn``.

    The pre-call reset guarantees ``fn`` starts from a clean cache; the
    post-call reset (in ``finally``, so it runs even if ``fn`` raises)
    guarantees no stale entries leak into the next caller.
    """
    @functools.wraps(fn)  # preserve fn's __name__/__doc__ for debugging
    def inner(self, *args, **kwargs):
        self.decoder.reset_kv_cache()
        try:
            return fn(self, *args, **kwargs)
        finally:
            self.decoder.reset_kv_cache()
    return inner
def reset_encoder_cache(fn):
    """Decorator: drop ``self.transformer``'s cached encoder output around ``fn``.

    Only active when ``self.jagged_mode`` is set. The post-call drop runs in
    ``finally`` so a raising ``fn`` cannot leave a stale cached encoding behind.
    """
    @functools.wraps(fn)  # preserve fn's __name__/__doc__ for debugging
    def inner(self, *args, **kwargs):
        if self.jagged_mode:
            self.transformer.cached_enc_output = None
        try:
            return fn(self, *args, **kwargs)
        finally:
            if self.jagged_mode:
                self.transformer.cached_enc_output = None
    return inner
def eval_mode(fn):
    """Decorator: run ``fn`` with the module in eval mode, then restore.

    Records ``self.training``, switches to eval, and restores the original
    train/eval state in ``finally`` — previously an exception inside ``fn``
    would leave the model stuck in eval mode.
    """
    @functools.wraps(fn)  # preserve fn's __name__/__doc__ for debugging
    def inner(self, *args, **kwargs):
        was_training = self.training
        self.eval()
        try:
            return fn(self, *args, **kwargs)
        finally:
            self.train(was_training)
    return inner
def select_columns_per_row(x: Tensor, indices: Tensor) -> torch.Tensor:
    """Gather, for each row of ``x``, the columns named by that row of ``indices``.

    ``x`` is (B, C, ...) and ``indices`` is (B, K) with K <= C; the result is
    (B, K, ...) where ``out[b, k] == x[b, indices[b, k]]``.
    """
    num_rows = x.shape[0]
    assert indices.shape[0] == num_rows
    assert indices.shape[1] <= x.shape[1]
    # (B, 1) row ids broadcast against (B, K) column ids for advanced indexing.
    row_ids = torch.arange(num_rows, device=x.device).unsqueeze(-1)
    return x[row_ids, indices]
def maybe_repeat_interleave(x, repeats, dim):
    """Repeat-interleave ``x`` along ``dim`` when it is a tensor.

    Non-tensor values (e.g. ``None`` placeholders) pass through unchanged.
    """
    if isinstance(x, Tensor):
        return x.repeat_interleave(repeats, dim=dim)
    return x
def parse_config():
    """Parse CLI arguments and load the gin config file they name.

    Reads ``sys.argv`` for a single positional ``config_path`` argument and
    applies it via ``gin.parse_config_file`` (side effect: mutates global gin
    configuration state). Raises ``SystemExit`` on bad CLI args.
    """
    # BUG FIX: `gin` was referenced here but never imported anywhere in the
    # file, so every call raised NameError. Import locally to keep the
    # dependency scoped to the one helper that needs it.
    import gin

    parser = argparse.ArgumentParser()
    parser.add_argument("config_path", type=str, help="Path to gin config file.")
    args = parser.parse_args()
    gin.parse_config_file(args.config_path)
@torch.no_grad
def compute_debug_metrics(batch: TokenizedSeqBatch, model_output = None, prefix: str = "") -> dict:
    """Collect scalar debug metrics for logging.

    Always reports quantiles of the per-example sequence length (sum of
    ``batch.seq_mask``). When ``model_output`` is given, also reports the
    per-digit loss from ``model_output.loss_d``. Every key is prefixed with
    ``prefix + "_"`` and every value is a plain Python float.
    """
    seq_lens = batch.seq_mask.sum(axis=1).to(torch.float32)
    tag = prefix + "_"
    metrics = {}
    for q in (0.25, 0.5, 0.75, 0.9, 1):
        metrics[f"{tag}seq_length_p{q}"] = torch.quantile(seq_lens, q=q).detach().cpu().item()
    if model_output is not None:
        # One loss scalar per semantic-ID digit position.
        for d in range(batch.sem_ids_fut.shape[1]):
            metrics[f"{tag}loss_{d}"] = model_output.loss_d[d].detach().cpu().item()
    return metrics
def gumbel_softmax(logits: Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1) -> Tensor:
    """Sample from the Gumbel-Softmax (Concrete) distribution over ``logits``.

    Args:
        logits: Unnormalized log-probabilities.
        tau: Softmax temperature; lower values give sharper (more one-hot) samples.
        hard: If True, return a one-hot argmax sample with a straight-through
            gradient estimator (gradients flow through the soft sample).
        eps: Lower clamp on the uniform noise so ``log`` cannot see 0.
            BUG FIX: previously accepted but unused, while ``torch.rand_like``
            can return exactly 0.0, turning the inner ``log`` into ``-inf``
            and the sample into NaN.
        dim: Dimension along which the softmax is applied.

    Returns:
        A tensor shaped like ``logits``: a soft sample, or a one-hot sample
        when ``hard=True``.
    """
    uniform = torch.rand_like(logits, memory_format=torch.legacy_contiguous_format)
    # -log(-log(U)) with U ~ Uniform(0, 1) is a standard Gumbel(0, 1) sample;
    # clamping U to [eps, 1) keeps both logs finite.
    gumbels = -(-uniform.clamp_min(eps).log()).log()
    gumbels = (logits + gumbels) / tau
    y_soft = gumbels.softmax(dim)
    if hard:
        # Straight-through estimator: forward pass uses the one-hot argmax,
        # backward pass uses the soft sample's gradient.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(
            logits, memory_format=torch.legacy_contiguous_format
        ).scatter_(dim, index, 1.0)
        return y_hard - y_soft.detach() + y_soft
    # Reparametrization trick: the soft sample itself is differentiable.
    return y_soft