Skip to content

Commit e4a6fc7

Browse files
author
Turkka Helinoja
committed
Add initial file for running the network and visualizing the results in real time for single images or videos
1 parent f940c83 commit e4a6fc7

File tree

1 file changed

+147
-0
lines changed

1 file changed

+147
-0
lines changed

tools/run_and_visualize.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import print_function
4+
5+
import argparse
6+
import os
7+
import pprint
8+
9+
import torch
10+
import torch.nn.parallel
11+
import torch.backends.cudnn as cudnn
12+
import torch.optim
13+
import torch.utils.data
14+
import torch.utils.data.distributed
15+
import torchvision.transforms as transforms
16+
17+
import _init_paths
18+
from config import cfg
19+
from config import update_config
20+
from core.loss import JointsMSELoss
21+
from core.function import validate
22+
from utils.utils import create_logger
23+
24+
import dataset
25+
import models
26+
27+
def parse_args():
28+
parser = argparse.ArgumentParser(description="Run and visualize keypoints network for image or video.")
29+
# general
30+
parser.add_argument("--cfg",
31+
help="Configuration file name",
32+
required=True,
33+
type=str)
34+
35+
parser.add_argument("opts",
36+
help="Modify config options using the command-line",
37+
default=None,
38+
nargs=argparse.REMAINDER)
39+
40+
parser.add_argument("--modelDir",
41+
help="model directory",
42+
type=str,
43+
default="")
44+
parser.add_argument("--logDir",
45+
help="log directory",
46+
type=str,
47+
default="")
48+
parser.add_argument("--dataDir",
49+
help="data directory",
50+
type=str,
51+
default="")
52+
parser.add_argument("--prevModelDir",
53+
help="prev Model directory",
54+
type=str,
55+
default="")
56+
parser.add_argument("--visualize",
57+
help="Visualize the results",
58+
type=bool,
59+
default=False)
60+
parser.add_argument("--input",
61+
help="Input image file",
62+
type=str,
63+
default="")
64+
parser.add_argument("--video",
65+
help="Input video file",
66+
type=str,
67+
default="")
68+
69+
args = parser.parse_args()
70+
return args
71+
72+
def main():
73+
args = parse_args()
74+
update_config(cfg, args)
75+
76+
# Create a logger
77+
logger, final_output_dir, tb_log_dir = create_logger(
78+
cfg, args.cfg, 'valid')
79+
logger.info(pprint.pformat(args))
80+
logger.info(cfg)
81+
82+
# cudnn related setting
83+
cudnn.benchmark = cfg.CUDNN.BENCHMARK
84+
torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
85+
torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED
86+
87+
# Configure model
88+
model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
89+
cfg, is_train=False
90+
)
91+
92+
if cfg.TEST.MODEL_FILE:
93+
logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
94+
model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=False)
95+
else:
96+
model_state_file = os.path.join(
97+
final_output_dir, 'final_state.pth'
98+
)
99+
logger.info('=> loading model from {}'.format(model_state_file))
100+
model.load_state_dict(torch.load(model_state_file))
101+
102+
model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
103+
104+
# define loss function (criterion) and optimizer
105+
criterion = JointsMSELoss(
106+
use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT
107+
).cuda()
108+
109+
# Data loading code
110+
normalize = transforms.Normalize(
111+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
112+
)
113+
114+
# Load data
115+
if args.input != "":
116+
with open(args.input, "r") as image:
117+
# TODO: Write a way to handle single images
118+
# TODO: Handle visualization
119+
pass
120+
elif args.video:
121+
# TODO: Write a way to handle videos image by image
122+
# TODO: Handle visualization
123+
pass
124+
else:
125+
# Original dataset way
126+
valid_dataset = eval('dataset.'+cfg.DATASET.DATASET)(
127+
cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
128+
transforms.Compose([
129+
transforms.ToTensor(),
130+
normalize,
131+
])
132+
)
133+
134+
valid_loader = torch.utils.data.DataLoader(
135+
valid_dataset,
136+
batch_size=cfg.TEST.BATCH_SIZE_PER_GPU*len(cfg.GPUS),
137+
shuffle=False,
138+
num_workers=cfg.WORKERS,
139+
pin_memory=True
140+
)
141+
142+
# evaluate on validation set
143+
validate(cfg, valid_loader, valid_dataset, model, criterion,
144+
final_output_dir, tb_log_dir)
145+
146+
if __name__ == '__main__':
147+
main()

0 commit comments

Comments
 (0)