# Web file-viewer header captured with the file (not part of the program):
#   kunkk's picture | Update app.py | 9e416cb verified | raw | history blame | 18.8 kB
import torch
import torch.nn.functional as F
import numpy as np
import os
import time
import gradio as gr
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from model.CyueNet_models import MMS
from utils1.data import transform_image
# Select the compute device: the first CUDA GPU when one is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
# Enhanced custom CSS for the Gradio UI.
# Passed to gr.Blocks(css=...) below. It defines colour/radius variables on
# :root and styles for the class names attached to components through
# `elem_classes` (e.g. "input-image", "output-image", "custom-button",
# "info-box") as well as the "statistics-container" / "statistic-card"
# classes emitted in the statistics HTML built by run_demo.
custom_css = """
:root {
--primary-color: #2196F3;
--secondary-color: #21CBF3;
--background-color: #f6f8fa;
--text-color: #333;
--border-radius: 10px;
}
.gradio-container {
background: linear-gradient(135deg, var(--background-color), #ffffff);
max-width: 1200px !important;
margin: auto !important;
}
.output-image, .input-image {
border-radius: var(--border-radius);
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
transition: transform 0.3s ease;
}
.output-image:hover, .input-image:hover {
transform: scale(1.02);
}
.custom-button {
background: linear-gradient(45deg, var(--primary-color), var(--secondary-color));
border: none;
color: white;
padding: 10px 20px;
border-radius: var(--border-radius);
cursor: pointer;
transition: all 0.3s ease;
font-weight: bold;
text-transform: uppercase;
letter-spacing: 1px;
}
.custom-button:hover {
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(33, 150, 243, 0.3);
}
.tabs {
border-radius: var(--border-radius);
overflow: hidden;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
}
.slider-component {
background: white;
padding: 15px;
border-radius: var(--border-radius);
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
}
.info-box {
background: white;
padding: 15px;
border-radius: var(--border-radius);
margin: 10px 0;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
}
.statistics-container {
display: flex;
gap: 10px;
margin-top: 10px;
}
.statistic-card {
background: white;
padding: 15px;
border-radius: var(--border-radius);
flex: 1;
text-align: center;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
}
.accordion {
background: white;
border-radius: var(--border-radius);
margin: 10px 0;
box-shadow: 0 2px 4px rgba(0,0,0,0.05);
}
.accordion-header {
padding: 15px;
cursor: pointer;
transition: background-color 0.3s ease;
}
.accordion-header:hover {
background-color: var(--background-color);
}
.result-card {
background: white;
padding: 20px;
border-radius: var(--border-radius);
margin: 10px 0;
box-shadow: 0 2px 4px rgba(0,0,0,0.05);
}
.analysis-container {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
gap: 20px;
margin: 20px 0;
}
.chart-container {
height: 300px;
background: white;
border-radius: var(--border-radius);
padding: 15px;
}
"""
class ImageProcessor:
    """Wrap the MMS saliency model with pre-processing, inference and visualisation.

    The model is loaded once when the instance is constructed.
    ``process_image`` then applies the optional image adjustments, runs a
    forward pass and builds the images, statistics and histogram figure that
    the UI displays.
    """

    def __init__(self):
        self.model = None
        self.load_model()
        # Saliency map, binary mask and stats of the most recent successful
        # ``process_image`` call (None until the first one).
        self.last_results = None

    def load_model(self):
        """Instantiate MMS, load the checkpoint and switch to eval mode.

        A ``RuntimeError`` raised by ``load_state_dict`` is reported on
        stdout and otherwise ignored; the model is moved to ``device`` and
        put in eval mode either way.
        """
        self.model = MMS()
        try:
            self.model.load_state_dict(
                torch.load('models/CyueNet_EORSSD6.pth.54', map_location=device)
            )
            print("模型加载成功")
        except RuntimeError as e:
            print(f"加载状态字典时出现部分不匹配,错误信息: {e}")
        self.model.to(device)
        self.model.eval()

    def adjust_brightness_contrast(self, image, brightness=0, contrast=0):
        """Return ``image`` with brightness and contrast adjusted.

        Args:
            image: Image array accepted by ``cv2.addWeighted``.
            brightness: Offset; 0 leaves the image untouched. The UI supplies
                values in [-100, 100].
            contrast: Contrast amount; 0 leaves the image untouched. The UI
                supplies values in [-100, 100]; 131 would divide by zero.

        Returns:
            The adjusted image, or the input object itself when both
            parameters are 0.
        """
        if brightness != 0:
            if brightness > 0:
                shadow = brightness
                highlight = 255
            else:
                shadow = 0
                highlight = 255 + brightness
            alpha_b = (highlight - shadow) / 255
            gamma_b = shadow
            # addWeighted with a zero second weight computes
            # image * alpha + gamma, saturated to the image dtype.
            image = cv2.addWeighted(image, alpha_b, image, 0, gamma_b)
        if contrast != 0:
            f = 131 * (contrast + 127) / (127 * (131 - contrast))
            alpha_c = f
            gamma_c = 127 * (1 - f)
            image = cv2.addWeighted(image, alpha_c, image, 0, gamma_c)
        return image

    def apply_filters(self, image, filter_type):
        """Apply the named filter and return the result.

        ``"锐化"`` sharpens with a 3x3 kernel, ``"柔化"`` applies a 5x5
        Gaussian blur; any other value returns ``image`` unchanged.
        """
        if filter_type == "锐化":
            kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
            return cv2.filter2D(image, -1, kernel)
        if filter_type == "柔化":
            return cv2.GaussianBlur(image, (5, 5), 0)
        return image

    def generate_analysis_plots(self, saliency_map):
        """Build a 50-bin histogram figure of the saliency values.

        The figure is created with matplotlib (already a dependency of this
        file) instead of the previously referenced but never imported
        ``px``, which raised ``NameError`` on every call.

        Args:
            saliency_map: Array of saliency values; it is flattened before
                binning.

        Returns:
            A ``matplotlib.figure.Figure`` holding the histogram.
        """
        # Build the Figure directly rather than through pyplot so that no
        # global pyplot state accumulates across requests.
        from matplotlib.figure import Figure

        fig_hist = Figure(figsize=(6, 4))
        ax = fig_hist.subplots()
        ax.hist(np.asarray(saliency_map).ravel(), bins=50)
        # NOTE(review): these labels are CJK text; rendering them presumably
        # needs a CJK-capable matplotlib font on the host — confirm.
        ax.set_title("显著性分布直方图")
        ax.set_xlabel("显著性值")
        ax.set_ylabel("频率")
        fig_hist.tight_layout()
        return fig_hist

    def process_image(self, image, threshold=0.5, testsize=256,
                      enhance_contrast=False, denoise=False,
                      brightness=0, contrast=0, filter_type="无"):
        """Run saliency detection on ``image`` and build all visual outputs.

        Args:
            image: RGB image array, or None.
            threshold: Fraction in [0, 1] of the 0-255 range used to binarise
                the saliency map.
            testsize: Size forwarded to ``transform_image``.
            enhance_contrast: Apply CLAHE to the L channel in LAB space.
            denoise: Apply ``cv2.fastNlMeansDenoisingColored`` first.
            brightness: Forwarded to ``adjust_brightness_contrast``.
            contrast: Forwarded to ``adjust_brightness_contrast``.
            filter_type: Forwarded to ``apply_filters``.

        Returns:
            An 8-tuple ``(original_image, saliency_uint8, heatmap_rgb,
            overlay_rgb, segmented_rgb, time_message, stats, histogram)``.
            When ``image`` is None the image and histogram slots are None,
            the message sits in the text slot and ``stats`` is an empty dict.
        """
        if image is None:
            # Keep the 8-slot shape; the message goes in the text slot rather
            # than the plot slot.
            return (None, None, None, None, None, "请提供有效的图像", {}, None)

        # --- Optional pre-processing -------------------------------------
        if denoise:
            image = cv2.fastNlMeansDenoisingColored(image, None, 10, 10, 7, 21)
        image = self.adjust_brightness_contrast(image, brightness, contrast)
        image = self.apply_filters(image, filter_type)
        if enhance_contrast:
            lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
            l, a, b = cv2.split(lab)
            clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
            l = clahe.apply(l)
            lab = cv2.merge((l, a, b))
            image = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)

        # Keep the adjusted image for display and for the overlays.
        original_image = image.copy()

        # --- Inference ---------------------------------------------------
        image_pil = Image.fromarray(image).convert('RGB')
        image_tensor = transform_image(image_pil, testsize)
        image_tensor = image_tensor.unsqueeze(0).to(device)

        time_start = time.time()
        with torch.no_grad():
            outputs = self.model(image_tensor)
        inference_time = time.time() - time_start

        # Only the second model output is used as the saliency prediction.
        res = outputs[1]
        res = res.sigmoid().detach().cpu().numpy().squeeze()
        # Min-max normalise; the epsilon guards against a constant map.
        res = (res - res.min()) / (res.max() - res.min() + 1e-8)

        # --- Visualisations ----------------------------------------------
        h, w = original_image.shape[:2]
        res_resized = cv2.resize(res, (w, h))
        res_vis = (res_resized * 255).astype(np.uint8)
        _, binary_mask = cv2.threshold(
            res_vis, int(255 * threshold), 255, cv2.THRESH_BINARY
        )

        # applyColorMap yields BGR; convert once so every returned image is
        # RGB, matching what the UI displays. Previously the raw BGR heatmap
        # was returned and shown with swapped colours.
        heatmap_bgr = cv2.applyColorMap(res_vis, cv2.COLORMAP_JET)
        heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)

        alpha = 0.5
        overlayed_rgb = cv2.addWeighted(
            original_image, 1 - alpha, heatmap_rgb, alpha, 0
        )
        segmented_rgb = cv2.bitwise_and(
            original_image, original_image, mask=binary_mask
        )

        analysis_plot = self.generate_analysis_plots(res_resized)

        # --- Statistics --------------------------------------------------
        # [-2] selects the contour list under both the 2-value and the
        # 3-value return conventions of cv2.findContours.
        contours = cv2.findContours(
            binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )[-2]
        stats = {
            "处理分辨率": f"{w}x{h}",
            "检测目标数量": str(len(contours)),
            "平均置信度": f"{np.mean(res_resized):.2%}",
            "最大置信度": f"{np.max(res_resized):.2%}",
            "处理时间": f"{inference_time:.3f}秒"
        }

        # Remember the latest results for later analysis.
        self.last_results = {
            'saliency_map': res_resized,
            'binary_mask': binary_mask,
            'stats': stats
        }
        return (original_image, res_vis, heatmap_rgb, overlayed_rgb, segmented_rgb,
                f"推理时间: {inference_time:.4f}秒", stats, analysis_plot)
# Create the module-level processor instance used by run_demo.
# Constructing it calls ImageProcessor.load_model(), so the model is loaded
# when this statement runs.
processor = ImageProcessor()
def run_demo(input_image, threshold, enhance_contrast, denoise, show_contours,
             brightness, contrast, filter_type):
    """Gradio callback: run detection and prepare the eight UI outputs.

    Args:
        input_image: Image array from the input component, or None when
            nothing was uploaded.
        threshold: Detection threshold on a 0-100 scale (the slider value);
            it is divided by 100 before being passed to the processor.
        enhance_contrast: Forwarded to ``processor.process_image``.
        denoise: Forwarded to ``processor.process_image``.
        show_contours: Draw green contours of the thresholded saliency map
            on the overlay image.
        brightness: Forwarded to ``processor.process_image``.
        contrast: Forwarded to ``processor.process_image``.
        filter_type: Forwarded to ``processor.process_image``.

    Returns:
        An 8-tuple ``(original, saliency_map, heatmap, overlayed, segmented,
        time_info, stats_html, analysis_plot)``, one value per output
        component wired to the submit button.
    """
    if input_image is None:
        # Exactly one value per output component (eight); the message goes in
        # the text slot. The previous nine-element return did not match the
        # number of outputs.
        return (None, None, None, None, None, "请上传图片", None, None)

    threshold_fraction = threshold / 100.0
    results = processor.process_image(
        input_image,
        threshold=threshold_fraction,
        enhance_contrast=enhance_contrast,
        denoise=denoise,
        brightness=brightness,
        contrast=contrast,
        filter_type=filter_type
    )
    (original, saliency_map, heatmap, overlayed, segmented,
     time_info, stats, analysis_plot) = results

    if show_contours and saliency_map is not None:
        # Binarise with the user's threshold so the contours agree with the
        # segmentation mask and the object count (50 still maps to 127).
        _, binary = cv2.threshold(
            saliency_map, int(255 * threshold_fraction), 255, cv2.THRESH_BINARY
        )
        # [-2] selects the contour list under both the 2-value and the
        # 3-value return conventions of cv2.findContours.
        contours = cv2.findContours(
            binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )[-2]
        # Draw on a copy so the processor's overlay array is not mutated.
        overlayed = overlayed.copy()
        cv2.drawContours(overlayed, contours, -1, (0, 255, 0), 2)

    # One card per statistic, wrapped in the flex container styled by the CSS.
    cards = "".join(
        f"<div class='statistic-card'><h4>{key}</h4><p>{value}</p></div>"
        for key, value in (stats or {}).items()
    )
    stats_html = f"<div class='statistics-container'>{cards}</div>"

    return (original, saliency_map, heatmap, overlayed, segmented,
            time_info, stats_html, analysis_plot)
# Build the Gradio interface.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a copy of this file whose indentation had been lost. The placement of
# the two Accordions inside the Group, of `segmented_output` in the second
# result Row, and of `analysis_plot` inside the info-box Group should be
# confirmed against the intended layout.
with gr.Blocks(title="高级显著性目标检测系统", css=custom_css) as demo:
    # Page header.
    gr.Markdown(
        """
# 🎯 智能显著性目标检测系统
### 基于深度学习的图像显著性检测与分析工具
"""
    )
    with gr.Tabs() as tabs:
        # Main tab: controls in the left column, results in the right column.
        with gr.TabItem("🔍 主要功能"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Input control panel
                    with gr.Group(elem_classes="control-panel"):
                        input_image = gr.Image(
                            label="输入图像",
                            type="numpy",
                            elem_classes="input-image"
                        )
                        # Basic settings: threshold and pre-processing toggles.
                        with gr.Accordion("基础设置", open=True):
                            # 0-100 slider; run_demo divides the value by 100.
                            threshold_slider = gr.Slider(
                                minimum=0,
                                maximum=100,
                                value=50,
                                step=1,
                                label="检测阈值",
                                info="调整检测的灵敏度"
                            )
                            enhance_contrast = gr.Checkbox(
                                label="增强对比度",
                                value=False
                            )
                            denoise = gr.Checkbox(
                                label="降噪处理",
                                value=False
                            )
                            show_contours = gr.Checkbox(
                                label="显示轮廓",
                                value=True
                            )
                        # Image adjustments: brightness, contrast and filter.
                        with gr.Accordion("图像调整", open=False):
                            brightness = gr.Slider(
                                minimum=-100,
                                maximum=100,
                                value=0,
                                step=1,
                                label="亮度"
                            )
                            contrast = gr.Slider(
                                minimum=-100,
                                maximum=100,
                                value=0,
                                step=1,
                                label="对比度"
                            )
                            filter_type = gr.Radio(
                                choices=["无", "锐化", "柔化"],
                                value="无",
                                label="图像滤镜"
                            )
                    # Action buttons.
                    with gr.Row():
                        submit_btn = gr.Button(
                            "开始检测",
                            variant="primary",
                            elem_classes="custom-button"
                        )
                        reset_btn = gr.Button(
                            "重置参数",
                            elem_classes="custom-button"
                        )
                with gr.Column(scale=2):
                    # Result display area
                    with gr.Tabs():
                        # Detection results: the five output images.
                        with gr.TabItem("检测结果"):
                            with gr.Row(elem_classes="result-gallery"):
                                original_output = gr.Image(
                                    label="原始图像",
                                    elem_classes="output-image"
                                )
                                saliency_output = gr.Image(
                                    label="显著性图",
                                    elem_classes="output-image"
                                )
                            with gr.Row(elem_classes="result-gallery"):
                                heatmap_output = gr.Image(
                                    label="热力图分析",
                                    elem_classes="output-image"
                                )
                                overlayed_output = gr.Image(
                                    label="叠加效果",
                                    elem_classes="output-image"
                                )
                                segmented_output = gr.Image(
                                    label="目标分割",
                                    elem_classes="output-image"
                                )
                        # Analysis report: timing text, statistics HTML and plot.
                        with gr.TabItem("分析报告"):
                            with gr.Group(elem_classes="info-box"):
                                time_info = gr.Textbox(
                                    label="处理时间",
                                    show_label=True
                                )
                                stats_output = gr.HTML(
                                    label="统计信息"
                                )
                                analysis_plot = gr.Plot(
                                    label="显著性分布分析"
                                )
        # Usage guide tab (static Markdown).
        with gr.TabItem("📖 使用指南"):
            gr.Markdown(
                """
## 使用说明
1. **上传图片**: 点击左侧"输入图像"区域上传待分析的图片
2. **调整参数**:
- 使用阈值滑块调整检测的灵敏度
- 可选择是否增强对比度或进行降噪
- 在图像调整面板中调整亮度、对比度和滤镜
3. **开始检测**: 点击"开始检测"按钮进行分析
4. **查看结果**: 在不同标签页中查看分析结果
## 功能说明
- **显著性图**: 展示目标区域的重要性分布
- **热力图**: 使用色彩展示检测强度
- **叠加效果**: 将检测结果与原图叠加展示
- **目标分割**: 提取关键目标区域
- **分析报告**: 查看详细的统计信息和分析图表
"""
            )
        # About tab (static Markdown).
        with gr.TabItem("ℹ️ 关于"):
            gr.Markdown(
                """
## 项目信息
- **版本**: 2.0.0
- **技术架构**: PyTorch + Gradio
- **模型**: CyueNet
## 主要特点
- 实时图像处理和分析
- 多维度结果可视化
- 丰富的图像调整选项
- 详细的数据分析报告
## 更新日志
- 增加了图像预处理选项
- 添加了统计分析功能
- 优化了用户界面
- 提升了处理性能
"""
            )

    # Reset-parameters callback
    def reset_params():
        """Return the initial value for every control, keyed by component.

        The values mirror the `value=` arguments the controls are created
        with above.
        """
        return {
            threshold_slider: 50,
            brightness: 0,
            contrast: 0,
            filter_type: "无",
            enhance_contrast: False,
            denoise: False,
            show_contours: True
        }

    # Wire up event handlers
    # The input order matches run_demo's parameter order; the output order
    # matches the order of the tuple run_demo returns.
    submit_btn.click(
        fn=run_demo,
        inputs=[
            input_image,
            threshold_slider,
            enhance_contrast,
            denoise,
            show_contours,
            brightness,
            contrast,
            filter_type
        ],
        outputs=[
            original_output,
            saliency_output,
            heatmap_output,
            overlayed_output,
            segmented_output,
            time_info,
            stats_output,
            analysis_plot
        ]
    )
    # The reset button takes no inputs and updates every control that
    # reset_params returns a value for.
    reset_btn.click(
        fn=reset_params,
        inputs=[],
        outputs=[
            threshold_slider,
            brightness,
            contrast,
            filter_type,
            enhance_contrast,
            denoise,
            show_contours
        ]
    )
# Launch the application when this file is run as a script (not on import).
if __name__ == "__main__":
    demo.launch(share=True)