File size: 4,645 Bytes
f6bb351
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48275ef
 
 
f6bb351
48275ef
 
 
 
f6bb351
 
 
 
 
 
48275ef
f6bb351
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48275ef
 
 
 
 
 
 
f6bb351
 
 
48275ef
 
 
 
 
f6bb351
48275ef
f6bb351
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48275ef
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch
import torch.nn.functional as F
import numpy as np
import os
import time
import gradio as gr
import cv2
from PIL import Image
from model.CyueNet_models import MMS
from utils1.data import transform_image

# Select the compute device: the first CUDA GPU when available, otherwise the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def load_model():
    """Build the MMS network, load its pretrained weights and return it.

    The checkpoint is read from a path relative to the working directory.
    A ``RuntimeError`` raised while reading or applying the checkpoint is
    reported on stdout and otherwise ignored, so the network is returned
    even when the weights could not be applied cleanly.

    Returns:
        The MMS instance, moved to ``device`` and switched to eval mode.
    """
    net = MMS()
    try:
        # Relative path: the checkpoint is expected next to the app
        # (the original note says it is stored on HuggingFace Spaces).
        checkpoint = torch.load(
            'models/CyueNet_EORSSD6.pth.54',
            map_location=device,
        )
        net.load_state_dict(checkpoint)
        print("模型加载成功")
    except RuntimeError as err:
        print(f"加载状态字典时出现部分不匹配,错误信息: {err}")
    net.to(device)
    net.eval()
    return net

def process_image(image, model, testsize=256):
    """Run saliency detection on one image and build display-ready results.

    Args:
        image: Input image as a NumPy array. Grayscale (H, W), RGB (H, W, 3)
            and RGBA (H, W, 4) layouts are handled; colour channels are
            treated as RGB(A) order.
        model: Callable network; ``model(tensor)`` must return a sequence
            whose second element holds the saliency logits.
        testsize: Size forwarded to ``transform_image`` for preprocessing.

    Returns:
        A 6-tuple ``(original, saliency, heatmap, overlay, segmented, info)``:
        a copy of the input, the uint8 saliency map resized to the input
        resolution, the JET heatmap (RGB), the heatmap blended 50/50 over
        the input (RGB), the input masked by the thresholded saliency map
        (RGB), and a string reporting the inference time in seconds.
    """
    # Keep an untouched copy for the overlay/segmentation steps and the return value
    original_image = image.copy()

    # Preprocess: PIL RGB -> project transform -> add batch dim -> target device
    image_pil = Image.fromarray(image).convert('RGB')
    image_tensor = transform_image(image_pil, testsize).unsqueeze(0).to(device)

    # CUDA kernels run asynchronously, so synchronize around the timed region;
    # otherwise only the kernel launch overhead would be measured.
    on_gpu = image_tensor.is_cuda
    if on_gpu:
        torch.cuda.synchronize()
    time_start = time.perf_counter()

    with torch.no_grad():
        outputs = model(image_tensor)

    if on_gpu:
        torch.cuda.synchronize()
    inference_time = time.perf_counter() - time_start

    # Only the second output is used; the remaining outputs are ignored here.
    res = outputs[1]

    # Sigmoid, then min-max normalise to [0, 1] (epsilon guards a constant map)
    res = res.sigmoid().detach().cpu().numpy().squeeze()
    res = (res - res.min()) / (res.max() - res.min() + 1e-8)

    # Resize the map back to the input resolution (cv2 expects (width, height))
    h, w = original_image.shape[:2]
    res_resized = cv2.resize(res, (w, h))

    # 8-bit saliency map for display and thresholding
    res_vis = (res_resized * 255).astype(np.uint8)

    # applyColorMap returns a BGR image
    heatmap = cv2.applyColorMap(res_vis, cv2.COLORMAP_JET)

    # Bring the input into 3-channel BGR so it can be combined with the heatmap
    channels = 1 if original_image.ndim == 2 else original_image.shape[2]
    if channels == 3:
        original_bgr = cv2.cvtColor(original_image, cv2.COLOR_RGB2BGR)
    elif channels == 4:
        # RGBA input: drop alpha instead of failing in the grayscale branch
        original_bgr = cv2.cvtColor(original_image, cv2.COLOR_RGBA2BGR)
    else:
        original_bgr = cv2.cvtColor(original_image, cv2.COLOR_GRAY2BGR)

    # Blend heatmap over the input.
    # NOTE(review): assumes the input is uint8 like the heatmap — confirm for other callers.
    alpha = 0.5
    overlayed = cv2.addWeighted(original_bgr, 1 - alpha, heatmap, alpha, 0)

    # Binarise the saliency map and keep only the salient pixels of the input
    _, binary_mask = cv2.threshold(res_vis, 127, 255, cv2.THRESH_BINARY)
    segmented = cv2.bitwise_and(original_bgr, original_bgr, mask=binary_mask)

    # Convert every colour result back to RGB for display. The heatmap is
    # included: returning it as BGR would show it with red and blue swapped.
    heatmap_rgb = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    overlayed_rgb = cv2.cvtColor(overlayed, cv2.COLOR_BGR2RGB)
    segmented_rgb = cv2.cvtColor(segmented, cv2.COLOR_BGR2RGB)

    return (
        original_image,
        res_vis,
        heatmap_rgb,
        overlayed_rgb,
        segmented_rgb,
        f"推理时间: {inference_time:.4f}秒",
    )

def run_demo(input_image):
    """Gradio callback: run detection on the uploaded image.

    Returns five ``None`` placeholders plus a prompt message when no image
    was supplied; otherwise returns the six results of ``process_image``
    computed with the module-level ``model``.
    """
    if input_image is not None:
        # Delegate to the processing pipeline and pass its results through
        src, sal_map, heat, overlay, seg, info = process_image(input_image, model)
        return src, sal_map, heat, overlay, seg, info

    # Nothing uploaded: clear the image outputs and show the prompt
    return [None, None, None, None, None, "请上传图片"]

# Load the model once at import time; run_demo reads this module-level `model`
print("正在加载模型...")
model = load_model()

# Build the Gradio interface
with gr.Blocks(title="显著性目标检测Demo") as demo:
    # Page heading and a one-line description
    gr.Markdown("# 显著性目标检测Demo")
    gr.Markdown("上传一张图片,系统将自动检测显著性区域")
    
    with gr.Row():
        # First column: image upload (delivered to the callback as a NumPy array)
        # and the button that starts detection
        with gr.Column():
            input_image = gr.Image(label="输入图像", type="numpy")
            submit_btn = gr.Button("开始检测")
        
        # Second column: the five result images and a text box for the timing message
        with gr.Column():
            original_output = gr.Image(label="原始图像")
            saliency_output = gr.Image(label="显著性图")
            heatmap_output = gr.Image(label="热力图")
            overlayed_output = gr.Image(label="叠加结果")
            segmented_output = gr.Image(label="分割结果")
            time_info = gr.Textbox(label="处理信息")
    
    # Wire the button to run_demo; the outputs list follows the order of
    # the six values run_demo returns
    submit_btn.click(
        fn=run_demo,
        inputs=input_image,
        outputs=[original_output, saliency_output, heatmap_output, overlayed_output, segmented_output, time_info]
    )
    
    # Usage instructions shown below the main row
    gr.Markdown("## 使用说明")
    gr.Markdown("1. 点击'输入图像'区域上传一张图片")
    gr.Markdown("2. 点击'开始检测'按钮进行显著性目标检测")
    gr.Markdown("3. 系统将显示原始图像、显著性图、热力图、叠加结果和分割结果")

# Launch the Gradio app when the file is run as a script
if __name__ == "__main__":
    # NOTE(review): share=True presumably requests a public share link —
    # confirm this is wanted for the target hosting environment.
    demo.launch(share=True)