-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_test.py
More file actions
128 lines (103 loc) · 5.02 KB
/
train_test.py
File metadata and controls
128 lines (103 loc) · 5.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import wandb
from equivariant_diffusion.utils import assert_mean_zero_with_mask, remove_mean_with_mask,\
assert_correctly_masked, sample_center_gravity_zero_gaussian_with_mask
import numpy as np
import qm9.visualizer as vis
from qm9.analyze import analyze_stability_for_molecules
from qm9.sampling import sample_chain, sample, sample_sweep_conditional
import utils
import qm9.utils as qm9utils
from qm9 import losses
import time
import torch
def train_epoch(args, loader, epoch, model, model_dp, model_ema, ema, device, dtype, property_norms, optim,
                nodes_dist, gradnorm_queue, dataset_info, prop_dist):
    """Run one training epoch of `model` over `loader`, logging losses to wandb.

    The per-batch objective is `dist_loss + node_loss` as returned by
    `losses.compute_loss_and_nll`. Gradients are optionally clipped
    (``args.clip_grad``) and an EMA copy of the weights is optionally
    maintained (``args.ema_decay > 0``).

    NOTE(review): indentation was reconstructed from a flattened source —
    confirm the per-batch appends/logging sit at loop level, not inside the
    report-step branch. `dataset_info` and `prop_dist` are accepted but not
    used in this function — presumably kept for caller compatibility.

    Args:
        args: experiment config (uses include_charges, conditioning, clip_grad,
            ema_decay, n_report_steps, break_train_epoch).
        loader: iterable of batch dicts with 'positions', 'atom_mask',
            'edge_mask', 'one_hot' and (optionally) 'charges' tensors.
        epoch: current epoch index, used only for progress printing.
        model: the raw model (clipped and EMA-averaged).
        model_dp: the (possibly data-parallel) model actually trained.
        model_ema / ema: EMA weight copy and its updater.
        device, dtype: target device/dtype for all batch tensors.
        property_norms: normalization stats for conditioning properties.
        optim: optimizer stepped once per batch.
        nodes_dist: node-count distribution passed to the loss.
        gradnorm_queue: running gradient-norm history for adaptive clipping.
    """
    model_dp.train()
    model.train()
    dist_loss_epoch = []
    node_loss_epoch = []
    nll_epoch = []
    n_iterations = len(loader)
    for i, data in enumerate(loader):
        x = data['positions'].to(device, dtype)
        node_mask = data['atom_mask'].to(device, dtype).unsqueeze(2)
        edge_mask = data['edge_mask'].to(device, dtype)
        one_hot = data['one_hot'].to(device, dtype)
        # When charges are disabled an empty tensor stands in for them.
        charges = (data['charges'] if args.include_charges else torch.zeros(0)).to(device, dtype)
        # Project positions onto the masked zero-center-of-gravity subspace.
        x = remove_mean_with_mask(x, node_mask)
        h = {'categorical': one_hot, 'integer': charges}
        if len(args.conditioning) > 0:
            # Conditioning context must itself respect the node mask.
            context = qm9utils.prepare_context(args.conditioning, data, property_norms).to(device, dtype)
            assert_correctly_masked(context, node_mask)
        else:
            context = None
        optim.zero_grad()
        dist_loss, node_loss = losses.compute_loss_and_nll(args, model_dp, nodes_dist,
                                                           x, h, node_mask, edge_mask, context)
        loss = dist_loss + node_loss
        loss.backward()
        if args.clip_grad:
            # Clipping is applied to `model` (the underlying weights), not model_dp.
            grad_norm = utils.gradient_clipping(model, gradnorm_queue)
        else:
            grad_norm = 0.
        optim.step()
        # Update EMA if enabled.
        if args.ema_decay > 0:
            ema.update_model_average(model_ema, model)
        if i % args.n_report_steps == 0:
            print(f"\rEpoch: {epoch}, iter: {i}/{n_iterations}, "
                  f"Loss {loss.item():.2f}, "
                  f"GradNorm: {grad_norm:.1f}")
        nll_epoch.append(loss.item())
        dist_loss_epoch.append(dist_loss.item())
        node_loss_epoch.append(node_loss.item())
        wandb.log({"Batch loss": loss.item(),
                   "Batch dist_loss": dist_loss.item(),
                   "Batch node_loss": node_loss.item()}, commit=True)
        # Optional short-circuit for debugging runs.
        if args.break_train_epoch:
            break
    wandb.log({"Train Epoch loss": np.mean(nll_epoch),
               "Train Epoch dist_loss": np.mean(dist_loss_epoch),
               "Train Epoch node_loss": np.mean(node_loss_epoch)}, commit=True)
def check_mask_correct(variables, node_mask):
    """Assert that each non-empty tensor in `variables` is correctly masked.

    Empty tensors (e.g. the zero-size placeholder used when charges are
    disabled) are skipped, since they carry no per-node values to check.

    Args:
        variables: iterable of tensors to validate against the mask.
        node_mask: node mask passed through to `assert_correctly_masked`.
    """
    # The original `enumerate` index was unused — iterate directly.
    for variable in variables:
        if len(variable) > 0:
            assert_correctly_masked(variable, node_mask)
def test(args, loader, epoch, eval_model, device, dtype, property_norms, nodes_dist, partition='Test'):
    """Evaluate `eval_model` on `loader` without gradient tracking.

    Batch losses are weighted by batch size so the returned values are true
    per-sample averages over the whole partition.

    Fix: removed the dead `nll_epoch = 0` accumulator — it was initialized but
    never updated nor returned.

    Args:
        args: experiment config (uses include_charges, augment_noise,
            conditioning, n_report_steps).
        loader: iterable of batch dicts with 'positions', 'atom_mask',
            'edge_mask', 'one_hot' and (optionally) 'charges' tensors.
        epoch: current epoch index, used only for progress printing.
        eval_model: model to evaluate (switched to eval mode).
        device, dtype: target device/dtype for all batch tensors.
        property_norms: normalization stats for conditioning properties.
        nodes_dist: node-count distribution passed to the loss.
        partition: label used in the progress printout (default 'Test').

    Returns:
        Tuple ``(mean_dist_loss, mean_node_loss)`` averaged over all samples.
    """
    eval_model.eval()
    with torch.no_grad():
        dist_loss_epoch = 0
        node_loss_epoch = 0
        n_samples = 0
        n_iterations = len(loader)
        for i, data in enumerate(loader):
            x = data['positions'].to(device, dtype)
            batch_size = x.size(0)
            node_mask = data['atom_mask'].to(device, dtype).unsqueeze(2)
            edge_mask = data['edge_mask'].to(device, dtype)
            one_hot = data['one_hot'].to(device, dtype)
            # When charges are disabled an empty tensor stands in for them.
            charges = (data['charges'] if args.include_charges else torch.zeros(0)).to(device, dtype)
            if args.augment_noise > 0:
                # Add noise eps ~ N(0, augment_noise) around points.
                eps = sample_center_gravity_zero_gaussian_with_mask(x.size(),
                                                                    x.device,
                                                                    node_mask)
                x = x + eps * args.augment_noise
            # Project positions onto the masked zero-center-of-gravity subspace.
            x = remove_mean_with_mask(x, node_mask)
            h = {'categorical': one_hot, 'integer': charges}
            if len(args.conditioning) > 0:
                # Conditioning context must itself respect the node mask.
                context = qm9utils.prepare_context(args.conditioning, data, property_norms).to(device, dtype)
                assert_correctly_masked(context, node_mask)
            else:
                context = None
            dist_loss, node_loss = losses.compute_loss_and_nll(args, eval_model, nodes_dist, x, h,
                                                               node_mask, edge_mask, context)
            # Weight by batch size so the running averages below are per-sample.
            dist_loss_epoch += dist_loss.item() * batch_size
            node_loss_epoch += node_loss.item() * batch_size
            n_samples += batch_size
            if i % args.n_report_steps == 0:
                print(f"\r {partition} NLL \t epoch: {epoch}, iter: {i}/{n_iterations}, "
                      f"dist_loss: {dist_loss_epoch/n_samples:.2f}, "
                      f"node_loss: {node_loss_epoch/n_samples:.2f}")
        return dist_loss_epoch/n_samples, node_loss_epoch/n_samples