Skip to content

Commit d6b32d7

Browse files
committed
fix: 修复 show_cumulative_returns 函数以处理 NaN 值并确保累计收益计算正确
1 parent cd996b6 commit d6b32d7

3 files changed

Lines changed: 79 additions & 7 deletions

File tree

czsc/svc/returns.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def show_cumulative_returns(df, key=None, **kwargs):
127127
display_legend = kwargs.get("display_legend", True)
128128
fig_title = kwargs.get("fig_title", "累计收益")
129129

130-
df_cumsum = df.cumsum()
130+
df_cumsum = df.fillna(0).cumsum()
131131
fig = px.line(df_cumsum, y=df_cumsum.columns.to_list(), title=fig_title)
132132
fig.update_xaxes(title="")
133133

czsc/utils/plotting/backtest.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
从 czsc.svc 模块中提取的 WeightBacktest 相关的绘图代码,按功能整理
55
"""
66

7+
import re
78
from typing import Union, Tuple
89
import numpy as np
910
import pandas as pd
@@ -32,6 +33,55 @@
3233
# ==================== 辅助函数 ====================
3334

3435

36+
TABLE_LIGHT_TEXT = "#f8fafc"
37+
TABLE_DARK_TEXT = "#0f172a"
38+
39+
40+
def _parse_rgb_color(color: str) -> Tuple[int, int, int, float] | None:
41+
"""解析 rgb/rgba/hex 颜色字符串。"""
42+
if not isinstance(color, str):
43+
return None
44+
45+
value = color.strip()
46+
if value.startswith("#"):
47+
hex_color = value[1:]
48+
if len(hex_color) == 3:
49+
hex_color = "".join(ch * 2 for ch in hex_color)
50+
if len(hex_color) == 6:
51+
return int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16), 1.0
52+
return None
53+
54+
match = re.fullmatch(
55+
r"rgba?\(\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)(?:\s*,\s*(\d+(?:\.\d+)?))?\s*\)",
56+
value,
57+
)
58+
if not match:
59+
return None
60+
61+
red, green, blue = (int(float(match.group(index))) for index in range(1, 4))
62+
alpha = float(match.group(4)) if match.group(4) is not None else 1.0
63+
return red, green, blue, alpha
64+
65+
66+
def _get_table_text_color(background_color: str, default_color: str) -> str:
67+
"""根据背景亮度返回更易读的文字颜色。"""
68+
parsed = _parse_rgb_color(background_color)
69+
if parsed is None:
70+
return default_color
71+
72+
red, green, blue, alpha = parsed
73+
if alpha <= 0:
74+
return default_color
75+
76+
brightness = (red * 299 + green * 587 + blue * 114) / 1000
77+
return TABLE_DARK_TEXT if brightness >= 150 else TABLE_LIGHT_TEXT
78+
79+
80+
def _get_default_table_text_color(template: TemplateType) -> str:
81+
"""根据模板推断表格默认文字颜色。"""
82+
return TABLE_LIGHT_TEXT if "dark" in str(template).lower() else TABLE_DARK_TEXT
83+
84+
3585
def _calculate_drawdown(
3686
returns: pd.Series,
3787
fillna: bool = True
@@ -529,14 +579,18 @@ def plot_colored_table(
529579
row_height = kwargs.get("row_height", 30)
530580
border_color = kwargs.get("border_color", COLOR_BORDER)
531581
header_bgcolor = kwargs.get("header_bgcolor", COLOR_HEADER_BG)
582+
default_text_color = _get_default_table_text_color(template)
583+
header_text_color = _get_table_text_color(header_bgcolor, default_text_color)
532584

533585
# 准备数据和颜色
534586
cell_values = []
535587
cell_colors = []
588+
cell_font_colors = []
536589

537590
# 处理索引列
538591
cell_values.append(df.index.tolist())
539592
cell_colors.append([header_bgcolor] * len(df))
593+
cell_font_colors.append([header_text_color] * len(df))
540594

541595
# 处理数据列
542596
for col in df.columns:
@@ -565,23 +619,25 @@ def plot_colored_table(
565619
colors = px.colors.sample_colorscale("RdYlGn_r", sample_vals)
566620

567621
cell_colors.append(colors)
622+
cell_font_colors.append([_get_table_text_color(str(color), default_text_color) for color in colors])
568623
else:
569624
cell_colors.append(['rgba(0,0,0,0)'] * len(series))
625+
cell_font_colors.append([default_text_color] * len(series))
570626

571627
fig = go.Figure(data=[go.Table(
572628
header=dict(
573629
values=headers,
574630
fill_color=header_bgcolor,
575631
align='center',
576-
font=dict(color='white', size=12),
632+
font=dict(color=header_text_color, size=12),
577633
height=row_height,
578634
line=dict(color=border_color, width=1)
579635
),
580636
cells=dict(
581637
values=cell_values,
582638
fill_color=cell_colors,
583639
align='center',
584-
font=dict(color='white', size=12),
640+
font=dict(color=cell_font_colors, size=12),
585641
height=row_height,
586642
line=dict(color=border_color, width=1)
587643
)

test/test_plot_colored_table.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pandas as pd
2-
import numpy as np
3-
from czsc.utils.plotting.backtest import plot_colored_table
2+
from czsc.utils.plotting.backtest import plot_colored_table, TABLE_DARK_TEXT, TABLE_LIGHT_TEXT
3+
44

55
def test_plot_colored_table():
66
# 构造测试数据
@@ -16,8 +16,8 @@ def test_plot_colored_table():
1616

1717
# 生成 HTML
1818
html_content = plot_colored_table(
19-
df,
20-
title="策略绩效对比测试",
19+
df,
20+
title="策略绩效对比测试",
2121
to_html=True,
2222
is_good_high_columns=["年化收益率", "夏普比率", "胜率"],
2323
row_height=40,
@@ -52,5 +52,21 @@ def test_plot_colored_table():
5252

5353
print("测试完成,结果已保存至 test_plot_colored_table_result.html")
5454

55+
56+
def test_plot_colored_table_auto_adjusts_font_color():
57+
df = pd.DataFrame({
58+
"收益率": [-0.10, 0.00, 0.12],
59+
"交易次数": [10, 20, 30],
60+
}, index=["策略A", "策略B", "策略C"])
61+
62+
fig = plot_colored_table(df, to_html=False, template="plotly")
63+
table = fig.data[0]
64+
font_colors = table.cells.font.color
65+
66+
assert TABLE_DARK_TEXT in font_colors[1]
67+
assert TABLE_LIGHT_TEXT in font_colors[1]
68+
assert font_colors[2] == [TABLE_LIGHT_TEXT, TABLE_DARK_TEXT, TABLE_LIGHT_TEXT]
69+
70+
5571
if __name__ == "__main__":
5672
test_plot_colored_table()

0 commit comments

Comments
 (0)