Spaces:
Running
Running
| # sankey_plot.py | |
| import plotly.graph_objects as go | |
| import matplotlib.cm as cm | |
| import matplotlib.colors as mcolors | |
| import base64 | |
| import pandas as pd | |
| import os | |
# ========== Core font configuration (same as before) ==========
import plotly.io as pio

FONT_FILE_PATH = "./SourceHanSansCN-Light.otf"  # font file lives in the project root
# Font family applied to every text element of the exported figures.
# Previously tried alternatives: "Source Han Sans CN Light", "Noto Sans SC"
# (the latter to stay consistent with the cluster plot).
# NOTE(review): FONT_FILE_PATH presumably ships the "Source Han Sans CN" family,
# not "Noto Sans CJK SC"; rendering relies on the latter being installed — confirm.
CHINESE_FONT = "Noto Sans CJK SC"

# Best-effort: hand the font to Kaleido's scope. pio.kaleido.scope is None when
# Kaleido cannot be imported (plotly 5.x), and the unguarded assignment used to
# crash this module at import time.
# NOTE(review): default_font is not a documented Kaleido scope option, so this
# may have no effect; the per-figure font family is what is known to apply.
_kaleido_scope = getattr(pio.kaleido, "scope", None)
if _kaleido_scope is not None:
    try:
        _kaleido_scope.default_font = CHINESE_FONT
    except (AttributeError, ValueError) as err:
        print(f"⚠️ 无法设置 Kaleido 默认字体:{err}")

# Point the font search path at the working directory, where FONT_FILE_PATH lives.
# NOTE(review): confirm that the installed Kaleido actually reads KALEIDO_FONT_SEARCH_PATH.
os.environ["KALEIDO_FONT_SEARCH_PATH"] = os.getcwd()
print(f"🔧 字体搜索路径:{os.getcwd()}")
print(f"🔧 字体文件是否存在:{os.path.exists(FONT_FILE_PATH)}")
def _sample_colormap_hex(name, n):
    """Return ``n`` hex colors sampled evenly from the matplotlib colormap ``name``."""
    if n <= 0:
        return []
    try:
        import matplotlib

        # matplotlib >= 3.6: registry + resampled() replace cm.get_cmap,
        # which was removed in matplotlib 3.9.
        cmap = matplotlib.colormaps[name].resampled(n)
    except AttributeError:
        # Older matplotlib (no registry / no resampled()): cm.get_cmap exists there.
        cmap = cm.get_cmap(name, n)
    return [mcolors.to_hex(cmap(i)) for i in range(n)]


def _blank_png_base64(width, height):
    """Return a base64 PNG of an empty figure, or ``""`` if even that cannot be rendered."""
    try:
        img_bytes = go.Figure().to_image(format="png", width=width, height=height, scale=1)
    except Exception as err:  # same renderer as the failed export, so it may fail too
        print(f"⚠️ 空白占位图生成失败:{err}")
        return ""
    return base64.b64encode(img_bytes).decode("utf-8")


def _figure_to_base64(fig, width, height):
    """Render ``fig`` as a 2x-scaled PNG and return it base64-encoded.

    On any export failure the error is printed and a blank placeholder from
    ``_blank_png_base64`` is returned instead of raising.
    """
    try:
        img_bytes = fig.to_image(
            format="png",
            width=width,
            height=height,
            scale=2,  # double pixel density so the image is not blurry
            engine="kaleido",  # explicit engine instead of auto-detection
        )
    except Exception as img_err:  # Kaleido/Chromium failures raise assorted types
        print(f"⚠️ 桑基图转base64失败:{str(img_err)}")
        return _blank_png_base64(width, height)
    b64 = base64.b64encode(img_bytes).decode("utf-8")
    print(f"✅ 桑基图base64生成成功,长度:{len(b64)} 字符")
    return b64


def plot_sankey_from_df(
    sankey_df: pd.DataFrame,
    title="问题 → 关键词共现",
    *,
    top_n: int = 15,
) -> "tuple[go.Figure, str] | str":
    """Build a question-category → keyword Sankey diagram and export it as base64 PNG.

    Only the ``top_n`` targets with the largest summed ``value`` are kept, and
    only rows whose ``source`` is one of the four known categories
    (``S1_难点``, ``S2_讲解需求``, ``S3_操作疑惑``, ``S4_应用场景``) are drawn.

    Args:
        sankey_df: Link table with ``source``, ``target`` and ``value`` columns.
        title: Figure title.
        top_n: Number of heaviest targets to keep (default 15).

    Returns:
        ``(fig, b64)``: the Plotly figure and its PNG rendering as a base64
        string (a blank placeholder or ``""`` if export failed). Returns the
        sentinel string ``"无数据"`` when there is nothing to draw, so callers
        must check for it before unpacking.
    """
    if sankey_df.empty:
        return "无数据"

    # Keep the top_n targets with the largest total link value, heaviest first.
    top_targets = (
        sankey_df.groupby("target")["value"]
        .sum()
        .sort_values(ascending=False)
        .head(top_n)
        .index
    )
    df = sankey_df[sankey_df["target"].isin(top_targets)]

    # Fixed display order of the known source categories; keep only those present.
    source_order = ["S4_应用场景", "S3_操作疑惑", "S2_讲解需求", "S1_难点"]
    present_sources = set(df["source"])
    sources = [s for s in source_order if s in present_sources]
    # Rows from any other source have no node; left in, they would map to NaN
    # link indices.
    df = df[df["source"].isin(sources)]
    if df.empty:
        return "无数据"
    # Drop top targets whose only links came from the removed sources.
    present_targets = set(df["target"])
    targets = [t for t in top_targets if t in present_targets]

    all_nodes = sources + targets
    node_index = {node: i for i, node in enumerate(all_nodes)}

    source_color_map = {
        "S1_难点": "#345DA7",
        "S2_讲解需求": "#3B8AC4",
        "S3_操作疑惑": "#4BB4DE",
        "S4_应用场景": "#EFDBCB",
    }
    target_color_map = dict(zip(targets, _sample_colormap_hex("Set3", len(targets))))
    # "gray" is only a safety net: every node is a known source or a kept target.
    node_colors = [
        source_color_map.get(n, target_color_map.get(n, "gray")) for n in all_nodes
    ]
    # Each link is tinted with its target's color.
    link_colors = [target_color_map.get(t, "gray") for t in df["target"]]

    width, height = 900, 600
    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=all_nodes,
            color=node_colors,
        ),
        link=dict(
            source=df["source"].map(node_index).tolist(),
            target=df["target"].map(node_index).tolist(),
            value=df["value"].tolist(),
            color=link_colors,
        ),
    )])
    fig.update_layout(
        # layout.titlefont is deprecated and rejected by Plotly 6; title.font
        # is the supported spelling. The title gets a larger font to stand out.
        title=dict(text=title, font=dict(family=CHINESE_FONT, size=22)),
        font=dict(family=CHINESE_FONT, size=18),
        width=width,
        height=height,
        margin=dict(l=50, r=50, t=80, b=50),
        autosize=False,  # keep the fixed canvas size
    )
    print("✅ 桑基图生成成功")

    b64 = _figure_to_base64(fig, width, height)
    return fig, b64