-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy patht2m_sample.py
More file actions
91 lines (77 loc) · 3.26 KB
/
Copy patht2m_sample.py
File metadata and controls
91 lines (77 loc) · 3.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from pathlib import Path
import hydra
from omegaconf import DictConfig
import torch as th
from src.models.diffusion_module import DiffusionLitModule
from omegaconf import OmegaConf
import json
from pytorch3d.transforms.rotation_conversions import (
matrix_to_quaternion,
rotation_6d_to_matrix,
)
from scipy.ndimage import gaussian_filter1d
from src.utils.sample_util import sample_motion
from src.utils.misc import remove_special_characters
from src.utils.vis_util import render_video_summary
@hydra.main(config_path="configs/", config_name="t2m_sample.yaml")
def main(config: DictConfig):
    """Sample motions from text prompts using a trained diffusion checkpoint.

    One motion is sampled per entry of ``config["labels_for_gen"]`` using the
    checkpoint's ``ema_model.model`` network. Depending on the config flags,
    each result is optionally smoothed along the time axis, passed to
    ``render_video_summary`` and/or written as a JSON file into the current
    working directory.

    Args:
        config: Hydra config providing ``ckpt_path``, ``motion_length``,
            ``guidance_scale``, ``labels_for_gen``, ``plot_gif``,
            ``export_json``, ``use_smoothing`` and ``fps``.
    """
    # Hydra may switch the process working directory to its run directory
    # before main() executes; resolving against the launch directory keeps
    # relative checkpoint paths working. Absolute paths pass through as-is.
    ckpt_path = Path(hydra.utils.to_absolute_path(str(config["ckpt_path"])))
    motion_length = config["motion_length"]
    NUM_JOINTS = 24
    guidance_scale = config["guidance_scale"]
    labels_for_gen = list(config["labels_for_gen"])
    plot_gif = config["plot_gif"]
    export_json = config["export_json"]
    use_smoothing = config["use_smoothing"]
    fps = config["fps"]

    print(OmegaConf.to_yaml(config))

    # Pick the device once and use it for both the model and the sampler, so
    # the script falls back to CPU instead of crashing on an unconditional
    # .cuda() when no GPU is present.
    device = th.device("cuda" if th.cuda.is_available() else "cpu")
    model = DiffusionLitModule.load_from_checkpoint(
        ckpt_path, map_location=device
    ).to(device)

    motion_dim = (
        model.net.motion_dim
    )  # 147 = translation (3) + rotation with 6D representation format (24 * 6 = 144)

    ema_model = model.ema_model.model
    ema_model.eval()

    for sample_id, ann in enumerate(labels_for_gen):
        generated_motions = sample_motion(
            sampling_texts=[ann],
            motion_lengths=[motion_length],
            sample_fn=model.diffusion.p_sample_loop,
            ema_model=ema_model,
            device=device,
            motion_dim=motion_dim,
            guidance_scale=guidance_scale,
            progress=True,
        )  # This will output (Batch, Motion Length, Representation dim)

        if use_smoothing:
            # Gaussian smoothing over axis 1 (the motion-length axis). The
            # filter runs in NumPy, so the smoothed result lives on the CPU.
            smoothed = gaussian_filter1d(
                generated_motions.detach().cpu().numpy(), sigma=1, axis=1
            )
            # Same dtype the legacy th.Tensor(...) constructor would produce.
            generated_motions = th.as_tensor(
                smoothed, dtype=th.get_default_dtype()
            )

        # Prompt text with special characters removed, used in file/image ids.
        label_wos = remove_special_characters(ann)

        if plot_gif:
            render_video_summary(
                img_ids=[f"{label_wos}_guidance{guidance_scale}_L{motion_length}"],
                translations=generated_motions[:, :motion_length, :3],
                rotation_6ds=generated_motions[:, :motion_length, 3:],
                annotations=[ann],
                fps=fps,
                write=True,
            )

        if export_json:
            # Only one prompt is sampled per call, so take batch element 0.
            translation = generated_motions[0, :motion_length, :3]
            rotation_6d_gen = generated_motions[0, :motion_length, 3:]
            rotation_matrix_gen = rotation_6d_to_matrix(
                rotation_6d_gen.reshape(motion_length, NUM_JOINTS, 6)
            )
            quaternion_gen = matrix_to_quaternion(
                rotation_matrix_gen
            )  # Real-part first for quaternion representation (w, x, y, z)

            export_dict = {
                "label": ann,
                "translation": translation.tolist(),  # (L, 3)
                "rotation_quat": quaternion_gen.tolist(),  # (L, 24, 4)
                "quaternion_order": "wxyz",
                "guidance_scale": guidance_scale,
            }

            file_name = f"{sample_id}_{label_wos}_L{motion_length}.json"
            with open(file_name, "w", encoding="utf-8") as f:
                json.dump(export_dict, f, indent=4)
# Script entry point: run the Hydra-decorated main() only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()