File tree Expand file tree Collapse file tree 15 files changed +37
-19
lines changed Expand file tree Collapse file tree 15 files changed +37
-19
lines changed Original file line number Diff line number Diff line change @@ -64,7 +64,7 @@ def benchmark(
6464 num_workers = num_workers ,
6565 pin_memory = True ,
6666 )
67- test_results , run_hash = evaluate_classification (
67+ test_results , speed_mem_metrics , run_hash = evaluate_classification (
6868 model = model ,
6969 test_loader = test_loader ,
7070 model_output_transform = model_output_transform ,
@@ -84,6 +84,7 @@ def benchmark(
8484 config = config ,
8585 dataset = cls .dataset .__name__ ,
8686 results = test_results ,
87+ speed_mem_metrics = speed_mem_metrics ,
8788 pytorch_hub_id = pytorch_hub_url ,
8889 model = paper_model_name ,
8990 model_description = model_description ,
Original file line number Diff line number Diff line change @@ -64,7 +64,7 @@ def benchmark(
6464 num_workers = num_workers ,
6565 pin_memory = True ,
6666 )
67- test_results , run_hash = evaluate_classification (
67+ test_results , speed_mem_metrics , run_hash = evaluate_classification (
6868 model = model ,
6969 test_loader = test_loader ,
7070 model_output_transform = model_output_transform ,
@@ -84,6 +84,7 @@ def benchmark(
8484 config = config ,
8585 dataset = cls .dataset .__name__ ,
8686 results = test_results ,
87+ speed_mem_metrics = speed_mem_metrics ,
8788 pytorch_hub_id = pytorch_hub_url ,
8889 model = paper_model_name ,
8990 model_description = model_description ,
Original file line number Diff line number Diff line change @@ -200,7 +200,7 @@ def benchmark(
200200 num_workers = num_workers ,
201201 pin_memory = pin_memory ,
202202 )
203- test_results , run_hash = evaluate_classification (
203+ test_results , speed_mem_metrics , run_hash = evaluate_classification (
204204 model = model ,
205205 test_loader = test_loader ,
206206 model_output_transform = model_output_transform ,
@@ -220,6 +220,7 @@ def benchmark(
220220 config = config ,
221221 dataset = cls .dataset .__name__ ,
222222 results = test_results ,
223+ speed_mem_metrics = speed_mem_metrics ,
223224 pytorch_hub_id = pytorch_hub_url ,
224225 model = paper_model_name ,
225226 model_description = model_description ,
Original file line number Diff line number Diff line change @@ -63,7 +63,7 @@ def benchmark(
6363 num_workers = num_workers ,
6464 pin_memory = True ,
6565 )
66- test_results , run_hash = evaluate_classification (
66+ test_results , speed_mem_metrics , run_hash = evaluate_classification (
6767 model = model ,
6868 test_loader = test_loader ,
6969 model_output_transform = model_output_transform ,
@@ -83,6 +83,7 @@ def benchmark(
8383 config = config ,
8484 dataset = cls .dataset .__name__ ,
8585 results = test_results ,
86+ speed_mem_metrics = speed_mem_metrics ,
8687 pytorch_hub_id = pytorch_hub_url ,
8788 model = paper_model_name ,
8889 model_description = model_description ,
Original file line number Diff line number Diff line change @@ -63,7 +63,7 @@ def benchmark(
6363 num_workers = num_workers ,
6464 pin_memory = True ,
6565 )
66- test_results , run_hash = evaluate_classification (
66+ test_results , speed_mem_metrics , run_hash = evaluate_classification (
6767 model = model ,
6868 test_loader = test_loader ,
6969 model_output_transform = model_output_transform ,
@@ -83,6 +83,7 @@ def benchmark(
8383 config = config ,
8484 dataset = cls .dataset .__name__ ,
8585 results = test_results ,
86+ speed_mem_metrics = speed_mem_metrics ,
8687 pytorch_hub_id = pytorch_hub_url ,
8788 model = paper_model_name ,
8889 model_description = model_description ,
Original file line number Diff line number Diff line change @@ -66,7 +66,7 @@ def benchmark(
6666 num_workers = num_workers ,
6767 pin_memory = True ,
6868 )
69- test_results , run_hash = evaluate_classification (
69+ test_results , speed_mem_metrics , run_hash = evaluate_classification (
7070 model = model ,
7171 test_loader = test_loader ,
7272 model_output_transform = model_output_transform ,
@@ -86,6 +86,7 @@ def benchmark(
8686 config = config ,
8787 dataset = cls .dataset .__name__ ,
8888 results = test_results ,
89+ speed_mem_metrics = speed_mem_metrics ,
8990 pytorch_hub_id = pytorch_hub_url ,
9091 model = paper_model_name ,
9192 model_description = model_description ,
Original file line number Diff line number Diff line change @@ -62,11 +62,14 @@ def evaluate_classification(
6262
6363 end = time .time ()
6464
65+ speed_mem_metrics = {
66+ 'Tasks Per Second' : test_loader .batch_size / inference_time .avg ,
67+ 'Memory Allocated' : memory_allocated
68+ }
69+
6570 return (
6671 {"Top 1 Accuracy" : top1 .avg / 100 ,
67- "Top 5 Accuracy" : top5 .avg / 100 ,
68- 'Tasks Per Second' : test_loader .batch_size / inference_time .avg ,
69- 'Memory Allocated' : memory_allocated },
72+ "Top 5 Accuracy" : top5 .avg / 100 }, speed_mem_metrics ,
7073 run_hash ,
7174 )
7275
Original file line number Diff line number Diff line change @@ -217,7 +217,7 @@ def benchmark(
217217 collate_fn = collate_fn ,
218218 )
219219 test_loader .no_classes = 91 # Number of classes for COCO Detection
220- test_results , run_hash = evaluate_detection_coco (
220+ test_results , speed_mem_metrics , run_hash = evaluate_detection_coco (
221221 model = model ,
222222 test_loader = test_loader ,
223223 model_output_transform = model_output_transform ,
@@ -232,6 +232,7 @@ def benchmark(
232232 config = config ,
233233 dataset = 'COCO minival' ,
234234 results = test_results ,
235+ speed_mem_metrics = speed_mem_metrics ,
235236 pytorch_hub_id = pytorch_hub_url ,
236237 model = paper_model_name ,
237238 model_description = model_description ,
Original file line number Diff line number Diff line change @@ -249,7 +249,7 @@ def evaluate_detection_coco(
249249 'Tasks Per Second' : test_loader .batch_size / inference_time .avg ,
250250 'Memory Allocated' : memory_allocated }
251251
252- return ({ ** get_coco_metrics (coco_evaluator ), ** device_metrics } , run_hash )
252+ return (get_coco_metrics (coco_evaluator ), device_metrics , run_hash )
253253
254254
255255def evaluate_detection_voc (
Original file line number Diff line number Diff line change @@ -85,7 +85,7 @@ def benchmark(
8585 collate_fn = collate_fn ,
8686 )
8787 test_loader .no_classes = 150 # Number of classes for ADE20K
88- test_results , run_hash = evaluate_segmentation (
88+ test_results , speed_mem_metrics , run_hash = evaluate_segmentation (
8989 model = model ,
9090 test_loader = test_loader ,
9191 model_output_transform = model_output_transform ,
@@ -100,6 +100,7 @@ def benchmark(
100100 config = config ,
101101 dataset = cls .dataset .__name__ + " val" ,
102102 results = test_results ,
103+ speed_mem_metrics = speed_mem_metrics ,
103104 pytorch_hub_id = pytorch_hub_url ,
104105 model = paper_model_name ,
105106 model_description = model_description ,
You can’t perform that action at this time.
0 commit comments