Spaces:
Sleeping
Sleeping
import os
import time

import torch
import torch.nn.functional as F
import numpy as np
import gradio as gr
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.figure import Figure

from model.CyueNet_models import MMS
from utils1.data import transform_image
# Choose the compute device: the first CUDA GPU when one is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
# Custom CSS passed to gr.Blocks(css=...). Class names here are referenced via
# `elem_classes=` on the UI components (e.g. "input-image", "output-image",
# "custom-button", "info-box"); rules for "statistics-container" and
# "statistic-card" style the HTML generated by run_demo.
custom_css = """
:root {
    --primary-color: #2196F3;
    --secondary-color: #21CBF3;
    --background-color: #f6f8fa;
    --text-color: #333;
    --border-radius: 10px;
}
.gradio-container {
    background: linear-gradient(135deg, var(--background-color), #ffffff);
    max-width: 1200px !important;
    margin: auto !important;
}
.output-image, .input-image {
    border-radius: var(--border-radius);
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    transition: transform 0.3s ease;
}
.output-image:hover, .input-image:hover {
    transform: scale(1.02);
}
.custom-button {
    background: linear-gradient(45deg, var(--primary-color), var(--secondary-color));
    border: none;
    color: white;
    padding: 10px 20px;
    border-radius: var(--border-radius);
    cursor: pointer;
    transition: all 0.3s ease;
    font-weight: bold;
    text-transform: uppercase;
    letter-spacing: 1px;
}
.custom-button:hover {
    transform: translateY(-2px);
    box-shadow: 0 5px 15px rgba(33, 150, 243, 0.3);
}
.tabs {
    border-radius: var(--border-radius);
    overflow: hidden;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
}
.slider-component {
    background: white;
    padding: 15px;
    border-radius: var(--border-radius);
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
}
.info-box {
    background: white;
    padding: 15px;
    border-radius: var(--border-radius);
    margin: 10px 0;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
}
.statistics-container {
    display: flex;
    gap: 10px;
    margin-top: 10px;
}
.statistic-card {
    background: white;
    padding: 15px;
    border-radius: var(--border-radius);
    flex: 1;
    text-align: center;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
}
.accordion {
    background: white;
    border-radius: var(--border-radius);
    margin: 10px 0;
    box-shadow: 0 2px 4px rgba(0,0,0,0.05);
}
.accordion-header {
    padding: 15px;
    cursor: pointer;
    transition: background-color 0.3s ease;
}
.accordion-header:hover {
    background-color: var(--background-color);
}
.result-card {
    background: white;
    padding: 20px;
    border-radius: var(--border-radius);
    margin: 10px 0;
    box-shadow: 0 2px 4px rgba(0,0,0,0.05);
}
.analysis-container {
    display: grid;
    grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
    gap: 20px;
    margin: 20px 0;
}
.chart-container {
    height: 300px;
    background: white;
    border-radius: var(--border-radius);
    padding: 15px;
}
"""
class ImageProcessor:
    """Wrap the CyueNet ``MMS`` saliency model plus its pre/post-processing.

    The model is built and its weights loaded once, at construction time, and
    then reused by every call to ``process_image`` (the entry point used by
    the Gradio callback).
    """

    # Default checkpoint location, relative to the current working directory.
    CHECKPOINT_PATH = 'models/CyueNet_EORSSD6.pth.54'

    def __init__(self, checkpoint_path=None):
        """Build the model and load weights.

        Args:
            checkpoint_path: optional override for ``CHECKPOINT_PATH``.
        """
        self.model = None
        self.load_model(checkpoint_path)
        # Filled in by process_image(): the latest saliency map, binary mask
        # and statistics dict.
        self.last_results = None

    def load_model(self, checkpoint_path=None):
        """Instantiate ``MMS``, load the pretrained state dict and move it to ``device``.

        A ``RuntimeError`` raised while loading (e.g. a key/shape mismatch) is
        printed and otherwise ignored, so the model may end up only partially
        initialised. Other errors (such as a missing file) propagate.

        Args:
            checkpoint_path: path to the state-dict file; defaults to
                ``CHECKPOINT_PATH``.
        """
        if checkpoint_path is None:
            checkpoint_path = self.CHECKPOINT_PATH
        self.model = MMS()
        try:
            self.model.load_state_dict(torch.load(checkpoint_path, map_location=device))
            print("模型加载成功")
        except RuntimeError as e:
            print(f"加载状态字典时出现部分不匹配,错误信息: {e}")
        self.model.to(device)
        self.model.eval()

    def adjust_brightness_contrast(self, image, brightness=0, contrast=0):
        """Return ``image`` with a linear brightness and contrast adjustment applied.

        Args:
            image: uint8 image array (the UI passes RGB).
            brightness: positive values lift the shadows toward ``brightness``;
                negative values lower the highlight ceiling to ``255 + brightness``.
                The UI slider range is [-100, 100].
            contrast: contrast factor input, applied as a linear stretch pivoting
                around 127. The UI slider range is [-100, 100].

        Returns:
            The adjusted image. This is the same object when both values are 0.
        """
        if brightness != 0:
            if brightness > 0:
                shadow = brightness
                highlight = 255
            else:
                shadow = 0
                highlight = 255 + brightness
            alpha_b = (highlight - shadow) / 255
            gamma_b = shadow
            image = cv2.addWeighted(image, alpha_b, image, 0, gamma_b)
        if contrast != 0:
            # Classic contrast curve: out = f * x + 127 * (1 - f).
            f = 131 * (contrast + 127) / (127 * (131 - contrast))
            alpha_c = f
            gamma_c = 127 * (1 - f)
            image = cv2.addWeighted(image, alpha_c, image, 0, gamma_c)
        return image

    def apply_filters(self, image, filter_type):
        """Apply the named filter.

        "锐化" applies a 3x3 sharpening kernel. "柔化" applies a 5x5 Gaussian
        blur. Any other value returns ``image`` unchanged.
        """
        if filter_type == "锐化":
            kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
            return cv2.filter2D(image, -1, kernel)
        elif filter_type == "柔化":
            return cv2.GaussianBlur(image, (5, 5), 0)
        return image

    def generate_analysis_plots(self, saliency_map):
        """Return a matplotlib ``Figure`` with a 50-bin histogram of the saliency values.

        The figure is built with ``matplotlib.figure.Figure`` rather than
        ``pyplot``, so it is not retained in pyplot's global figure registry
        across requests. ``gr.Plot`` accepts matplotlib figures.

        NOTE(review): matplotlib's default font may lack CJK glyphs. Configure
        a CJK-capable font if the Chinese title and labels render as boxes.
        """
        values = np.asarray(saliency_map).ravel()
        fig = Figure(figsize=(6, 4))
        ax = fig.add_subplot(1, 1, 1)
        ax.hist(values, bins=50)
        ax.set_title("显著性分布直方图")
        ax.set_xlabel('显著性值')
        ax.set_ylabel('频率')
        fig.tight_layout()
        return fig

    def _preprocess(self, image, enhance_contrast, denoise, brightness, contrast, filter_type):
        """Apply the optional steps in order: denoise, brightness/contrast, filter, CLAHE."""
        if denoise:
            image = cv2.fastNlMeansDenoisingColored(image, None, 10, 10, 7, 21)
        image = self.adjust_brightness_contrast(image, brightness, contrast)
        image = self.apply_filters(image, filter_type)
        if enhance_contrast:
            # Equalise only the lightness channel so hues are preserved.
            lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
            l, a, b = cv2.split(lab)
            clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
            lab = cv2.merge((clahe.apply(l), a, b))
            image = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
        return image

    def _predict_saliency(self, image, testsize):
        """Run the model on an RGB image.

        Returns:
            ``(saliency, seconds)``. ``saliency`` is a 2-D float map
            min-max-normalised to [0, 1], at the model's output resolution.
        """
        image_tensor = transform_image(Image.fromarray(image).convert('RGB'), testsize)
        image_tensor = image_tensor.unsqueeze(0).to(device)

        time_start = time.perf_counter()
        with torch.no_grad():
            outputs = self.model(image_tensor)
            # CUDA kernels run asynchronously; wait for them so the measured
            # time covers the actual forward pass.
            if device.type == 'cuda':
                torch.cuda.synchronize()
        inference_time = time.perf_counter() - time_start

        # The model returns a tuple of 31 tensors. Only the second one (`res`
        # in the original unpacking) is used for the final saliency map.
        res = outputs[1].sigmoid().cpu().numpy().squeeze()
        res = (res - res.min()) / (res.max() - res.min() + 1e-8)
        return res, inference_time

    def process_image(self, image, threshold=0.5, testsize=256,
                      enhance_contrast=False, denoise=False,
                      brightness=0, contrast=0, filter_type="无"):
        """Preprocess ``image``, run saliency detection and build all visualisations.

        Args:
            image: RGB uint8 array (as delivered by ``gr.Image(type="numpy")``), or None.
            threshold: cut-off in [0, 1] for the binary segmentation mask.
            testsize: model input size forwarded to ``transform_image``.
            enhance_contrast, denoise, brightness, contrast, filter_type:
                preprocessing options; see ``_preprocess``.

        Returns:
            An 8-tuple ``(preprocessed_rgb, saliency_u8, heatmap_rgb,
            overlay_rgb, segmented_rgb, time_text, stats_dict, histogram_fig)``.
            When ``image`` is None, the image slots are None, the message is
            in the text slot and the stats dict is empty.
        """
        if image is None:
            return (None, None, None, None, None, "请提供有效的图像", {}, None)

        image = self._preprocess(image, enhance_contrast, denoise, brightness, contrast, filter_type)
        original_image = image.copy()

        res, inference_time = self._predict_saliency(image, testsize)

        # Bring the saliency map back to the input resolution.
        h, w = original_image.shape[:2]
        res_resized = cv2.resize(res, (w, h))

        res_vis = (res_resized * 255).astype(np.uint8)
        # applyColorMap produces BGR; convert so Gradio (RGB) shows true colours.
        heatmap_rgb = cv2.cvtColor(cv2.applyColorMap(res_vis, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
        _, binary_mask = cv2.threshold(res_vis, int(255 * threshold), 255, cv2.THRESH_BINARY)

        # Blend and mask directly in RGB. These ops are per-channel, so this
        # equals doing them in BGR and converting back.
        alpha = 0.5
        overlayed_rgb = cv2.addWeighted(original_image, 1 - alpha, heatmap_rgb, alpha, 0)
        segmented_rgb = cv2.bitwise_and(original_image, original_image, mask=binary_mask)

        analysis_plot = self.generate_analysis_plots(res_resized)

        # [-2] selects the contour list on both OpenCV 3 and OpenCV 4.
        contours = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        stats = {
            "处理分辨率": f"{w}x{h}",
            "检测目标数量": str(len(contours)),
            "平均置信度": f"{np.mean(res_resized):.2%}",
            "最大置信度": f"{np.max(res_resized):.2%}",
            "处理时间": f"{inference_time:.3f}秒"
        }

        # Keep the latest results for later analysis.
        self.last_results = {
            'saliency_map': res_resized,
            'binary_mask': binary_mask,
            'stats': stats
        }
        return (original_image, res_vis, heatmap_rgb, overlayed_rgb, segmented_rgb,
                f"推理时间: {inference_time:.4f}秒", stats, analysis_plot)
# Shared processor instance used by the Gradio callbacks. Constructing it
# builds the model and loads its weights (ImageProcessor.__init__ calls
# load_model), so this happens once when the module is imported.
processor = ImageProcessor()
def run_demo(input_image, threshold, enhance_contrast, denoise, show_contours,
             brightness, contrast, filter_type):
    """Gradio callback for the "开始检测" button.

    Args:
        input_image: RGB numpy image from the input component, or None.
        threshold: detection threshold from the slider, in percent (0-100).
        enhance_contrast, denoise: preprocessing toggles.
        show_contours: draw green outlines of the detected regions on the overlay.
        brightness, contrast, filter_type: image adjustment options.

    Returns:
        An 8-tuple, one value per wired output component: original, saliency
        map, heatmap, overlay, segmentation, time text, stats HTML and
        histogram figure.
    """
    if input_image is None:
        # Exactly one value per output component; the message goes to the
        # time-info textbox.
        return (None, None, None, None, None, "请上传图片", "", None)

    threshold_ratio = threshold / 100.0
    results = processor.process_image(
        input_image,
        threshold=threshold_ratio,
        enhance_contrast=enhance_contrast,
        denoise=denoise,
        brightness=brightness,
        contrast=contrast,
        filter_type=filter_type
    )
    original, saliency_map, heatmap, overlayed, segmented, time_info, stats, analysis_plot = results

    if show_contours and saliency_map is not None:
        # Use the same cut-off as the segmentation mask, so the outlines
        # match the detected regions and the reported target count.
        _, binary = cv2.threshold(saliency_map, int(255 * threshold_ratio), 255, cv2.THRESH_BINARY)
        contours = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        overlayed = overlayed.copy()
        cv2.drawContours(overlayed, contours, -1, (0, 255, 0), 2)

    # One card per statistic, laid out by the .statistics-container flex CSS.
    cards = "".join(
        f"<div class='statistic-card'><h4>{key}</h4><p>{value}</p></div>"
        for key, value in (stats or {}).items()
    )
    stats_html = f"<div class='statistics-container'>{cards}</div>"

    return (original, saliency_map, heatmap, overlayed, segmented,
            time_info, stats_html, analysis_plot)
# Build the Gradio UI: a main tab (controls on the left, results on the
# right), a usage-guide tab and an about tab, then wire the buttons.
with gr.Blocks(title="高级显著性目标检测系统", css=custom_css) as demo:
    gr.Markdown(
        """
        # 🎯 智能显著性目标检测系统
        ### 基于深度学习的图像显著性检测与分析工具
        """
    )

    with gr.Tabs() as tabs:
        with gr.TabItem("🔍 主要功能"):
            with gr.Row():
                # Left column: input image and all adjustable parameters.
                with gr.Column(scale=1):
                    with gr.Group(elem_classes="control-panel"):
                        input_image = gr.Image(label="输入图像", type="numpy", elem_classes="input-image")

                        with gr.Accordion("基础设置", open=True):
                            threshold_slider = gr.Slider(
                                minimum=0, maximum=100, value=50, step=1,
                                label="检测阈值", info="调整检测的灵敏度",
                            )
                            enhance_contrast = gr.Checkbox(label="增强对比度", value=False)
                            denoise = gr.Checkbox(label="降噪处理", value=False)
                            show_contours = gr.Checkbox(label="显示轮廓", value=True)

                        with gr.Accordion("图像调整", open=False):
                            brightness = gr.Slider(minimum=-100, maximum=100, value=0, step=1, label="亮度")
                            contrast = gr.Slider(minimum=-100, maximum=100, value=0, step=1, label="对比度")
                            filter_type = gr.Radio(choices=["无", "锐化", "柔化"], value="无", label="图像滤镜")

                        with gr.Row():
                            submit_btn = gr.Button("开始检测", variant="primary", elem_classes="custom-button")
                            reset_btn = gr.Button("重置参数", elem_classes="custom-button")

                # Right column: visual results and the analysis report.
                with gr.Column(scale=2):
                    with gr.Tabs():
                        with gr.TabItem("检测结果"):
                            with gr.Row(elem_classes="result-gallery"):
                                original_output = gr.Image(label="原始图像", elem_classes="output-image")
                                saliency_output = gr.Image(label="显著性图", elem_classes="output-image")
                            with gr.Row(elem_classes="result-gallery"):
                                heatmap_output = gr.Image(label="热力图分析", elem_classes="output-image")
                                overlayed_output = gr.Image(label="叠加效果", elem_classes="output-image")
                                segmented_output = gr.Image(label="目标分割", elem_classes="output-image")
                        with gr.TabItem("分析报告"):
                            with gr.Group(elem_classes="info-box"):
                                time_info = gr.Textbox(label="处理时间", show_label=True)
                                stats_output = gr.HTML(label="统计信息")
                                analysis_plot = gr.Plot(label="显著性分布分析")

        with gr.TabItem("📖 使用指南"):
            gr.Markdown(
                """
                ## 使用说明
                1. **上传图片**: 点击左侧"输入图像"区域上传待分析的图片
                2. **调整参数**:
                   - 使用阈值滑块调整检测的灵敏度
                   - 可选择是否增强对比度或进行降噪
                   - 在图像调整面板中调整亮度、对比度和滤镜
                3. **开始检测**: 点击"开始检测"按钮进行分析
                4. **查看结果**: 在不同标签页中查看分析结果
                ## 功能说明
                - **显著性图**: 展示目标区域的重要性分布
                - **热力图**: 使用色彩展示检测强度
                - **叠加效果**: 将检测结果与原图叠加展示
                - **目标分割**: 提取关键目标区域
                - **分析报告**: 查看详细的统计信息和分析图表
                """
            )

        with gr.TabItem("ℹ️ 关于"):
            gr.Markdown(
                """
                ## 项目信息
                - **版本**: 2.0.0
                - **技术架构**: PyTorch + Gradio
                - **模型**: CyueNet
                ## 主要特点
                - 实时图像处理和分析
                - 多维度结果可视化
                - 丰富的图像调整选项
                - 详细的数据分析报告
                ## 更新日志
                - 增加了图像预处理选项
                - 添加了统计分析功能
                - 优化了用户界面
                - 提升了处理性能
                """
            )

    # Default value for every resettable control. The insertion order also
    # defines the reset button's output order.
    _reset_defaults = {
        threshold_slider: 50,
        brightness: 0,
        contrast: 0,
        filter_type: "无",
        enhance_contrast: False,
        denoise: False,
        show_contours: True,
    }

    def reset_params():
        """Return the default value for each resettable control, keyed by component."""
        return dict(_reset_defaults)

    # Argument order must match run_demo's positional parameters, and the
    # output order must match its 8-tuple return value.
    _detect_inputs = [
        input_image, threshold_slider, enhance_contrast, denoise,
        show_contours, brightness, contrast, filter_type,
    ]
    _detect_outputs = [
        original_output, saliency_output, heatmap_output, overlayed_output,
        segmented_output, time_info, stats_output, analysis_plot,
    ]
    submit_btn.click(fn=run_demo, inputs=_detect_inputs, outputs=_detect_outputs)
    reset_btn.click(fn=reset_params, inputs=[], outputs=list(_reset_defaults))
# Start the Gradio server when this file is executed directly.
# NOTE(review): share=True asks Gradio for a public share link; hosted
# environments may not honour it — confirm whether it is still wanted.
if __name__ == "__main__":
    demo.launch(share=True)