-
Notifications
You must be signed in to change notification settings - Fork 48
Expand file tree
/
Copy pathplot_single_factor.py
More file actions
105 lines (91 loc) · 3.09 KB
/
plot_single_factor.py
File metadata and controls
105 lines (91 loc) · 3.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
import pandas as pd
import numpy as np
import os, sys
from scripts.constants import *
from absl import flags
FLAGS = flags.FLAGS
## Hparams
# Command-line flags that select the input CSV and control the plot.
flags.DEFINE_string("csv_path", None, "csv path")
flags.DEFINE_integer("max_x", None, "max value of x")
# A rolling window needs at least one point, so reject 0 / negative sizes
# at parse time instead of failing later inside the smoothing step.
flags.DEFINE_integer(
    "window_size", 50, "window size of the plot", lower_bound=1
)
flags.DEFINE_list("factors", None, "selected single factors")
# csv_path has no usable default: without it the script would only fail
# later when the CSV is read, so require it when the flags are parsed.
flags.mark_flag_as_required("csv_path")
# Parse the command line explicitly (the script does not go through
# absl.app.run), then echo the effective flag values.
flags.FLAGS(sys.argv)
print(FLAGS.flags_into_string())
# Load the results table to plot.
df = pd.read_csv(FLAGS.csv_path)

# extract the tags
# x_tag, y_tags and variant_tag_names are not defined in this script
# (presumably they come from the scripts.constants star import -- confirm).
# Element [1] of each tag is used as the CSV column name.
x_tag = x_tag[1]
# Keep only the y columns / variant columns that actually exist in this CSV.
key_of_interests = [y_tag[1] for y_tag in y_tags if y_tag[1] in df.columns]
variant_tag_names = [
    variant_tag_name
    for variant_tag_name in variant_tag_names
    if variant_tag_name in df.columns
]

# Default to plotting every variant column found in the CSV.
if FLAGS.factors is None:
    FLAGS.factors = variant_tag_names
# Validate user input with an explicit error (an assert would be stripped
# under `python -O` and gives no hint about which name is wrong).
unknown_factors = set(FLAGS.factors) - set(variant_tag_names)
if unknown_factors:
    raise ValueError(
        f"--factors contains names that are not variant columns of the CSV: "
        f"{sorted(unknown_factors)}; available: {variant_tag_names}"
    )

if FLAGS.max_x is not None:
    # Crop the x-axis. Take an explicit copy so that later column
    # assignments on df do not write to a slice of the original frame.
    df = df.loc[df[x_tag] <= FLAGS.max_x].copy()
# smoothing
# Each metric column is replaced by its rolling mean, computed separately
# within every (variant columns..., trial_tag) group so that curves from
# different groups are never averaged together.
group_columns = [*variant_tag_names, trial_tag]


def _rolling_mean(series):
    """Return the trailing mean of `series` over FLAGS.window_size points.

    min_periods=1 means the first points are averaged over however many
    values are available instead of becoming NaN.
    """
    return series.rolling(FLAGS.window_size, min_periods=1).mean()


for key in key_of_interests:
    df[key] = df.groupby(group_columns)[key].transform(_rolling_mean)
# make a square-like plot
# One panel per (metric, factor) pair; an empty factor list still reserves
# one panel per metric.
panels_per_key = max(1, len(FLAGS.factors))
num_plots = len(key_of_interests) * panels_per_key
# Smallest column count whose square covers all panels, then just enough
# rows (integer ceiling division) to hold them.
cols = int(np.ceil(np.sqrt(num_plots)))
rows = -(-num_plots // cols)

# seaborn plot
sns.set(font_scale=2.0)
figure_size = (cols * 7, rows * 4)
fig, axes = plt.subplots(rows, cols, figsize=figure_size)
# plt.subplots returns a bare Axes for a 1x1 grid and an ndarray otherwise;
# normalise both cases to a flat, indexable sequence of axes.
if isinstance(axes, np.ndarray):
    axes = axes.flatten()
else:
    axes = [axes]
# sns.lineplot draws the average curve (over rows sharing an x-value) with a
# 95% confidence interval on the y-value:
# https://seaborn.pydata.org/generated/seaborn.lineplot.html
# It can show at most 3 independent dimensions through hue and style, but at
# most 2 are recommended, obtained by giving hue and style the same key.
# NOTE: every seaborn plotting function accepts `ax`, which is what makes
# drawing into subplots possible.

# Only rows whose "method" column equals "ours" are plotted.
df_ours = df.loc[df["method"] == "ours"]

# One panel per (metric, factor) pair, in metric-major order.
panels = [
    (key, variant_tag_name)
    for key in key_of_interests
    for variant_tag_name in FLAGS.factors
]

for ax_id, (key, variant_tag_name) in enumerate(panels):
    ax = axes[ax_id]
    # The "Len" factor uses seaborn's default palette; every other factor
    # looks up its own colours in variant_colors.
    if variant_tag_name == "Len":
        palette = None
    else:
        palette = variant_colors[variant_tag_name]

    sns.lineplot(
        ax=ax,
        data=df_ours,
        x=x_tag,
        y=key,
        palette=palette,
        hue=variant_tag_name,
        # Optional knobs left disabled: hue_order / style / style_order to
        # control ordering and line styles; ci=None skips the error bars
        # and saves a lot of time.
        sort=False,
    )
    # NOTE(review): the original author observed that this call removes the
    # hue title from the legend; presumably ax.legend() builds a new legend
    # without it -- confirm before relying on the legend title.
    ax.legend(framealpha=0.5, loc="upper left")
    ax.set_title(variant_tag_name)
    # To clamp the x-axis as well: ax.set_xlim(0, FLAGS.max_x) when
    # FLAGS.max_x is not None.

# Hide the grid cells that did not receive a panel.
for unused_ax in axes[len(panels):]:
    unused_ax.set_visible(False)
plt.tight_layout()

# Write the image next to the input CSV. os.path.dirname keeps the root of an
# absolute path (splitting on "/" and re-joining dropped the leading "/", so
# "/data/run/x.csv" was saved under the relative directory "data/run") and
# also handles the platform's native path separators.
output_path = os.path.join(
    os.path.dirname(FLAGS.csv_path),
    f"single_factor-{''.join(FLAGS.factors)}-window{FLAGS.window_size}.png",
)
plt.savefig(output_path, dpi=200, bbox_inches="tight")
plt.close()