egisinsight / sankey_plot.py
wxy01giser's picture
Update sankey_plot.py
4c17f4e verified
# sankey_plot.py
import plotly.graph_objects as go
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import base64
import pandas as pd
import os
# ========== 核心配置(和之前一致) ==========
import plotly.io as pio
pio.kaleido.scope.default_font = "Noto Sans CJK SC"
FONT_FILE_PATH = "./SourceHanSansCN-Light.otf" # 字体文件在根目录
# CHINESE_FONT = "Source Han Sans CN Light"
CHINESE_FONT = "Noto Sans CJK SC"
# ========== 关键:设置环境变量,让 Plotly/Kaleido 找到字体 ==========
os.environ["KALEIDO_FONT_SEARCH_PATH"] = os.getcwd() # 字体搜索路径 = 当前目录
print(f"🔧 字体搜索路径:{os.getcwd()}")
print(f"🔧 字体文件是否存在:{os.path.exists(FONT_FILE_PATH)}")
# CHINESE_FONT = "Noto Sans SC" # 和聚类图保持一致,避免混乱
def plot_sankey_from_df(sankey_df: pd.DataFrame, title="问题 → 关键词共现") -> str:
if sankey_df.empty:
return "无数据"
# 过滤 top targets
top_targets = sankey_df.groupby('target')['value'].sum().sort_values(ascending=False).head(15).index
df = sankey_df[sankey_df['target'].isin(top_targets)].copy()
# 节点顺序
sources = ['S4_应用场景', 'S3_操作疑惑', 'S2_讲解需求', 'S1_难点']
sources = [s for s in sources if s in df['source'].unique()]
targets = top_targets.tolist()
all_nodes = sources + targets
node_index = {n: i for i, n in enumerate(all_nodes)}
# 颜色
source_color_map = {
'S1_难点': '#345DA7', 'S2_讲解需求': '#3B8AC4',
'S3_操作疑惑': '#4BB4DE', 'S4_应用场景': '#EFDBCB'
}
cmap = cm.get_cmap('Set3', len(targets))
target_colors = [mcolors.to_hex(cmap(i)) for i in range(len(targets))]
target_color_map = dict(zip(targets, target_colors))
node_colors = [source_color_map.get(n, target_color_map.get(n, '#gray')) for n in all_nodes]
link_colors = [target_color_map.get(t, '#gray') for t in df['target']]
fig = go.Figure(data=[go.Sankey(
node=dict(pad=15, thickness=20, line=dict(color="black", width=0.5),
label=all_nodes, color=node_colors),
link=dict(
source=df['source'].map(node_index),
target=df['target'].map(node_index),
value=df['value'],
color=link_colors
)
)])
fig.update_layout(title_text=title,
# 标题字体单独设置(更醒目,中文适配)
titlefont=dict(family=CHINESE_FONT, size=22),
font=dict(family=CHINESE_FONT, size=18), width=900, height=600,
margin=dict(l=50, r=50, t=80, b=50),autosize=False ) # 关闭自动缩放
print(f"✅ 桑基图{fig}生成成功")
# === 5. 导出高清 PNG(关键!)===
# 6. 关键:生成base64(增加异常捕获,确保生成成功)
try:
# 生成高清PNG字节流(scale=2 → 2倍DPI,避免模糊)
img_bytes = fig.to_image(
format="png",
width=900,
height=600,
scale=2,
engine="kaleido" # 显式指定引擎,避免依赖自动检测
)
# 转base64字符串
b64 = base64.b64encode(img_bytes).decode("utf-8")
print(f"✅ 桑基图base64生成成功,长度:{len(b64)} 字符") # 正常长度约10万+字符
except Exception as img_err:
# 图片生成失败时,用空图表的base64兜底
print(f"⚠️ 桑基图转base64失败:{str(img_err)}")
empty_img_bytes = fig.to_image(format="png", width=900, height=600, scale=1)
b64 = base64.b64encode(empty_img_bytes).decode("utf-8")
return fig, b64