Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
log predictions by themselves in separate file + main stdout log + fixes formatting of displayed results
  • Loading branch information
fmigneault committed Jan 15, 2021
commit 04a5dd2854d69c637306b1931d0de1126787550b
15 changes: 10 additions & 5 deletions slowfast/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,16 @@ def setup_logging(output_dir=None):
logger.addHandler(ch)

if output_dir is not None and du.is_master_proc(du.get_world_size()):
filename = os.path.join(output_dir, "stdout.log")
fh = logging.StreamHandler(_cached_log_stream(filename))
fh.setLevel(logging.DEBUG)
fh.setFormatter(plain_formatter)
logger.addHandler(fh)
setup_file_logger(logger, output_dir, "stdout.log", plain_formatter)


def setup_file_logger(logger, output_dir, file_name, formatter=None):
    """Attach a DEBUG-level file handler to *logger*.

    The handler writes to ``output_dir/file_name`` through the cached log
    stream. When *formatter* is omitted, a bare ``%(message)s`` formatter is
    used. The logger's own level is lowered to DEBUG so that records are not
    filtered out before reaching the new handler.
    """
    log_path = os.path.join(output_dir, file_name)
    handler = logging.StreamHandler(_cached_log_stream(log_path))
    handler.setLevel(logging.DEBUG)
    if formatter is None:
        # Fall back to raw message output (no timestamps or level prefixes).
        formatter = logging.Formatter("%(message)s")
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.setLevel(logging.DEBUG)


def get_logger(name):
Expand Down
37 changes: 25 additions & 12 deletions slowfast/visualization/video_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

logger = logging.get_logger(__name__)
log.getLogger("matplotlib").setLevel(log.ERROR)
pred_log = logging.get_logger("slowfast-predictions")


def _create_text_labels(classes, scores, class_names, ground_truth=False):
Expand Down Expand Up @@ -684,6 +685,7 @@ class VideoLogger(VideoVisualizer):
def __init__(self, *args, **kwargs):
    """Forward every argument to the ``VideoVisualizer`` base class and
    initialize the logging bookkeeping state."""
    super(VideoLogger, self).__init__(*args, **kwargs)
    # Running counter of clips processed so far.
    self.clip_index = 0
    # Absolute [first, last] frame indices seen so far; empty until the
    # first clip is handled.
    self.frame_range = []

def draw_clip_range(
self,
Expand All @@ -697,6 +699,12 @@ def draw_clip_range(
repeat_frame=1,
):
self.clip_index += 1
frame_range = [0, len(frames) - 1]
if not self.frame_range:
self.frame_range = frame_range
else:
self.frame_range[0] = self.frame_range[1] + frame_range[0]
self.frame_range[1] = self.frame_range[1] + frame_range[1]

if isinstance(preds, torch.Tensor):
if preds.ndim == 1:
Expand All @@ -723,6 +731,9 @@ def draw_clip_range(
top_scores.append(pred[mask].tolist())
top_class = torch.squeeze(torch.nonzero(mask), dim=-1).tolist()
top_classes.append(top_class)
else:
logger.error("Unknown mode: %s", self.mode)
return

text_labels = []
for i in range(n_instances):
Expand All @@ -735,21 +746,23 @@ def draw_clip_range(
)
)

frames_info = "{:04d} [{:08d}, {:08d}]:".format(self.clip_index, self.frame_range[0], self.frame_range[1])
if bboxes is not None:
assert len(preds) == len(
bboxes
), "Encounter {} predictions and {} bounding boxes".format(
len(preds), len(bboxes)
)
logger.info("%04d", self.clip_index)
assert len(preds) == len(bboxes), \
"Encounter {} predictions and {} bounding boxes".format(len(preds), len(bboxes))
pred_log.info(frames_info)
for i, box in enumerate(bboxes):
top_labels = [self.class_names[i] for i in top_classes[i]]
txt_scores = [float("{:.4f}".format(float(score))) for score in top_scores[i]]
label = " labeled '{}'".format(text_labels[i]) if ground_truth else ""
text_box = "bbox: {},".format(list(box))
logger.info(" %s %s is predicted to class %s, %s: %s, %s",
text_box, label, top_classes[i], method, list(top_classes[i]), list(top_scores[i]))
text_box = "bbox: {},".format(list(float("{:04.2f}".format(float(c))) for c in list(box)))
pred_log.info(" %s%s is predicted to class %s, %s: %s, %s",
text_box, label, text_labels[i][0], method, top_labels, txt_scores)
else:
label = " labeled '{}'".format(text_labels[0]) if ground_truth else ""
logger.info("%04d%s is predicted to class %s, %s: %s, %s",
self.clip_index, label, top_classes[0], method, list(top_classes), list(top_scores))
top_labels = [self.class_names[i] for i in top_classes[0]]
txt_scores = [float("{:.4f}".format(float(score))) for score in top_scores[0]]
pred_log.info("%s%s is predicted to class %s, %s: %s, %s",
frames_info, label, text_labels[0], method, top_labels, txt_scores)

return []
return [] # drop frames to speed up process (no writing)
28 changes: 16 additions & 12 deletions tools/demo_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import time
import torch
import tqdm
import os

from slowfast.utils import logging
from slowfast.visualization.async_predictor import AsyncDemo, AsyncVis
Expand All @@ -14,7 +15,7 @@
from slowfast.visualization.async_predictor import draw_predictions, log_predictions
from slowfast.visualization.demo_loader import ThreadVideoManager, VideoManager
from slowfast.visualization.predictor import ActionPredictor
from slowfast.visualization.video_visualizer import VideoVisualizer
from slowfast.visualization.video_visualizer import VideoVisualizer, VideoLogger

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -44,21 +45,24 @@ def run_demo(cfg, frame_provider):
)

if not cfg.DEMO.OUTPUT_DISPLAY:
video_vis = ()
video_vis_cls = VideoLogger
pred_processor = log_predictions
pred_log = logging.get_logger("slowfast-predictions")
logging.setup_file_logger(pred_log, cfg.OUTPUT_DIR, "predictions.log")
else:
video_vis = VideoVisualizer(
num_classes=cfg.MODEL.NUM_CLASSES,
class_names_path=cfg.DEMO.LABEL_FILE_PATH,
top_k=cfg.TENSORBOARD.MODEL_VIS.TOPK_PREDS,
thres=cfg.DEMO.COMMON_CLASS_THRES,
lower_thres=cfg.DEMO.UNCOMMON_CLASS_THRES,
common_class_names=common_classes,
colormap=cfg.TENSORBOARD.MODEL_VIS.COLORMAP,
mode=cfg.DEMO.VIS_MODE,
)
video_vis_cls = VideoVisualizer
pred_processor = draw_predictions

video_vis = video_vis_cls(
num_classes=cfg.MODEL.NUM_CLASSES,
class_names_path=cfg.DEMO.LABEL_FILE_PATH,
top_k=cfg.TENSORBOARD.MODEL_VIS.TOPK_PREDS,
thres=cfg.DEMO.COMMON_CLASS_THRES,
lower_thres=cfg.DEMO.UNCOMMON_CLASS_THRES,
common_class_names=common_classes,
colormap=cfg.TENSORBOARD.MODEL_VIS.COLORMAP,
mode=cfg.DEMO.VIS_MODE,
)
async_vis = AsyncVis(
video_vis,
n_workers=cfg.DEMO.NUM_VIS_INSTANCES,
Expand Down