kunkk commited on
Commit
6e571b2
·
verified ·
1 Parent(s): cde7ae3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +392 -221
app.py CHANGED
@@ -1,3 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn.functional as F
3
  import numpy as np
@@ -17,126 +32,208 @@ import base64
17
  # GPU/CPU设置
18
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
19
 
20
- # CSS样式设置
21
  custom_css = """
 
 
22
  :root {
23
- --primary-color: #2196F3;
24
- --secondary-color: #21CBF3;
25
- --background-color: #f6f8fa;
26
- --text-color: #333;
27
- --border-radius: 10px;
28
- --glass-bg: rgba(255, 255, 255, 0.25);
29
- --shadow: 0 8px 32px 0 rgba(31, 38, 135, 0.37);
 
 
 
 
 
 
 
 
30
  }
31
 
32
  .gradio-container {
33
- background: linear-gradient(135deg, var(--background-color), #ffffff);
34
- max-width: 1400px !important;
35
- margin: auto !important;
36
- backdrop-filter: blur(10px);
 
37
  }
38
 
39
- .output-image, .input-image {
40
- border-radius: var(--border-radius);
41
- box-shadow: var(--shadow);
42
- transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
43
- backdrop-filter: blur(10px);
44
- border: 1px solid rgba(255, 255, 255, 0.18);
45
  }
46
 
47
- .output-image:hover, .input-image:hover {
48
- transform: scale(1.02) translateY(-2px);
49
- box-shadow: 0 12px 40px 0 rgba(31, 38, 135, 0.5);
 
 
50
  }
51
 
52
- .custom-button {
53
- background: linear-gradient(45deg, var(--primary-color), var(--secondary-color));
54
- border: none;
55
- color: white;
56
- padding: 12px 24px;
57
- border-radius: var(--border-radius);
58
- cursor: pointer;
59
- transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
60
- font-weight: bold;
61
- text-transform: uppercase;
62
- letter-spacing: 1px;
63
- box-shadow: var(--shadow);
64
  }
65
 
66
- .custom-button:hover {
67
- transform: translateY(-3px);
68
- box-shadow: 0 12px 30px rgba(33, 150, 243, 0.4);
69
  }
70
 
71
- .advanced-controls {
72
- background: var(--glass-bg);
73
- border-radius: 20px;
74
- padding: 25px;
75
- box-shadow: var(--shadow);
76
- backdrop-filter: blur(10px);
77
- border: 1px solid rgba(255, 255, 255, 0.18);
78
  }
79
 
80
- .result-container {
81
- background: var(--glass-bg);
82
- border-radius: 20px;
83
- padding: 20px;
84
- backdrop-filter: blur(15px);
85
- border: 1px solid rgba(255, 255, 255, 0.18);
86
  box-shadow: var(--shadow);
 
87
  }
88
 
89
- .interactive-viz {
90
- border-radius: 15px;
91
- overflow: hidden;
92
- transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
 
 
 
 
 
 
 
 
93
  box-shadow: var(--shadow);
94
  }
95
 
96
- .interactive-viz:hover {
97
- transform: translateY(-5px);
98
- box-shadow: 0 15px 35px rgba(0,0,0,0.15);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  }
100
 
 
101
  .statistics-container {
102
  display: grid;
103
- grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
104
- gap: 15px;
105
- margin-top: 15px;
106
  }
107
 
108
  .statistic-card {
109
- background: var(--glass-bg);
110
- padding: 20px;
111
  border-radius: var(--border-radius);
 
112
  text-align: center;
113
  box-shadow: var(--shadow);
114
- backdrop-filter: blur(10px);
115
- border: 1px solid rgba(255, 255, 255, 0.18);
116
- transition: all 0.3s ease;
117
  }
118
 
119
  .statistic-card:hover {
120
  transform: translateY(-2px);
121
- box-shadow: 0 10px 25px rgba(0,0,0,0.1);
122
  }
123
 
124
- .progress-container {
125
- background: var(--glass-bg);
126
- border-radius: 10px;
127
- padding: 15px;
128
- margin: 10px 0;
129
- backdrop-filter: blur(10px);
 
130
  }
131
 
132
- .comparison-slider {
133
- background: var(--glass-bg);
134
- border-radius: 15px;
135
- padding: 20px;
136
- backdrop-filter: blur(10px);
137
- border: 1px solid rgba(255, 255, 255, 0.18);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  }
139
  """
 
140
  class ImageProcessor:
141
  def __init__(self):
142
  self.model = None
@@ -189,53 +286,94 @@ class ImageProcessor:
189
  return image
190
 
191
  def generate_analysis_plots(self, saliency_map):
192
- """生成分析图表"""
193
- plt.style.use('seaborn-v0_8')
 
 
 
 
 
 
 
 
 
194
  fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))
195
- # 直方图
196
- ax1.hist(saliency_map.flatten(), bins=50, color='#2196F3', alpha=0.7, edgecolor='black')
197
- ax1.set_title('Histogram of Saliency Distribution', fontsize=12, pad=15)
198
- ax1.set_xlabel('Saliency Value', fontsize=10)
199
- ax1.set_ylabel('Frequency', fontsize=10)
200
- ax1.grid(True, alpha=0.3)
 
 
 
 
 
 
201
  # 添加统计信息
202
- mean_val = np.mean(saliency_map)
203
- median_val = np.median(saliency_map)
204
- ax1.axvline(mean_val, color='red', linestyle='--', alpha=0.7, label=f'Mean: {mean_val:.3f}')
205
- ax1.axvline(median_val, color='green', linestyle='--', alpha=0.7, label=f'Median: {median_val:.3f}')
206
- ax1.legend()
207
- # 累积分布
208
- sorted_vals = np.sort(saliency_map.flatten())
 
 
 
 
209
  cumulative = np.arange(1, len(sorted_vals) + 1) / len(sorted_vals)
210
- ax2.plot(sorted_vals, cumulative, color='#FF6B35', linewidth=2)
211
- ax2.set_title('Cumulative Distribution Function', fontsize=12)
212
- ax2.set_xlabel('Saliency Value', fontsize=10)
213
- ax2.set_ylabel('Cumulative Probability', fontsize=10)
214
- ax2.grid(True, alpha=0.3)
215
- # 箱线图
216
- ax3.boxplot(saliency_map.flatten(), patch_artist=True,
217
- boxprops=dict(facecolor='#21CBF3', alpha=0.7))
218
- ax3.set_title('Boxplot of Saliency Distribution', fontsize=12)
219
- ax3.set_ylabel('Saliency Value', fontsize=10)
220
- ax3.grid(True, alpha=0.3)
221
- # 强度剖面(中心线)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  center_row = saliency_map[saliency_map.shape[0]//2, :]
223
- ax4.plot(center_row, color='#9C27B0', linewidth=2)
224
- ax4.set_title('Intensity Profile along Center Line', fontsize=12)
225
- ax4.set_xlabel('Pixel Position', fontsize=10)
226
- ax4.set_ylabel('Saliency Value', fontsize=10)
227
- ax4.grid(True, alpha=0.3)
228
- plt.tight_layout()
229
- # 保存为字节
 
 
 
 
 
 
 
230
  buf = io.BytesIO()
231
- plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
 
232
  buf.seek(0)
233
  img_array = np.array(Image.open(buf))
234
  plt.close()
235
  return img_array
236
 
237
  def quick_process(self, image, threshold=0.5, testsize=256):
238
-
239
  if image is None:
240
  return None, "请提供有效的图像"
241
 
@@ -253,17 +391,15 @@ class ImageProcessor:
253
  time_start = time.time()
254
 
255
  with torch.no_grad():
256
- # 关键修改:只计算必要的输出,避免完整模型计算
257
- if device.type == 'cuda':
258
  with torch.cuda.amp.autocast():
259
- _, res = self.model.forward_quick(image_tensor) # 使用简化版前向传播
260
  else:
261
- with torch.amp.autocast(device_type='cpu'):
262
- _, res = self.model.forward_quick(image_tensor) # 使用简化版前向传播
263
 
264
  time_end = time.time()
265
 
266
- # 确保转换为float32类型
267
  res = res.to(torch.float32).sigmoid().cpu().numpy().squeeze()
268
  res = (res - res.min()) / (res.max() - res.min() + 1e-8)
269
 
@@ -276,7 +412,6 @@ class ImageProcessor:
276
 
277
  return result
278
 
279
-
280
  def process_image(self, image, threshold=0.5, testsize=256,
281
  enhance_contrast=False, denoise=False,
282
  brightness=0, contrast=0, filter_type="无",
@@ -297,7 +432,7 @@ class ImageProcessor:
297
  if cache_key in self.cache:
298
  return self.cache[cache_key]
299
 
300
- # 使用线程进行图像预处理
301
  def preprocess_image():
302
  processed_image = image.copy()
303
 
@@ -331,24 +466,26 @@ class ImageProcessor:
331
  time_start = time.time()
332
 
333
  with torch.no_grad():
334
- if device.type == 'cuda':
335
  with torch.cuda.amp.autocast():
336
- x1, res, s1_sig, edg1, edg_s, s2, e2, s2_sig, e2_sig, s3, e3, s3_sig, e3_sig, s4, e4, s4_sig, e4_sig, s5, e5, s5_sig, e5_sig, sk1, sk1_sig, sk2, sk2_sig, sk3, sk3_sig, sk4, sk4_sig, sk5, sk5_sig = self.model(image_tensor)
337
  else:
338
- with torch.amp.autocast(device_type='cpu'):
339
- x1, res, s1_sig, edg1, edg_s, s2, e2, s2_sig, e2_sig, s3, e3, s3_sig, e3_sig, s4, e4, s4_sig, e4_sig, s5, e5, s5_sig, e5_sig, sk1, sk1_sig, sk2, sk2_sig, sk3, sk3_sig, sk4, sk4_sig, sk5, sk5_sig = self.model(image_tensor)
 
 
340
 
341
  time_end = time.time()
342
  inference_time = time_end - time_start
343
 
344
- # 确保转换为float32类型
345
  res = res.to(torch.float32).sigmoid().cpu().numpy().squeeze()
346
  res = (res - res.min()) / (res.max() - res.min() + 1e-8)
347
 
348
  h, w = original_image.shape[:2]
349
  res_resized = cv2.resize(res, (w, h))
350
 
351
- # 生成可视化
352
  res_vis = (res_resized * 255).astype(np.uint8)
353
  heatmap = cv2.applyColorMap(res_vis, cv2.COLORMAP_JET)
354
  _, binary_mask = cv2.threshold(res_vis, int(255 * threshold), 255, cv2.THRESH_BINARY)
@@ -364,21 +501,22 @@ class ImageProcessor:
364
  segmented_rgb = cv2.cvtColor(segmented, cv2.COLOR_BGR2RGB)
365
  heatmap_rgb = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
366
 
367
- # 生成分析图表
368
  analysis_plot = self.generate_analysis_plots(res_resized)
369
 
370
- # 计算统计信息
371
  contours = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
372
  total_area = w * h
373
  detected_area = cv2.countNonZero(binary_mask)
374
  coverage_ratio = detected_area / total_area
375
 
376
  stats = {
377
- "处理分辨率": f"{w}x{h}",
378
- "检测到对象数": str(len(contours)),
379
- "平均置信度": f"{np.mean(res_resized):.2%}",
380
- "最大置信度": f"{np.max(res_resized):.2%}",
381
- "覆盖率": f"{coverage_ratio:.2%}",
 
382
  "处理时间": f"{inference_time:.3f}秒"
383
  }
384
 
@@ -420,11 +558,11 @@ class ImageProcessor:
420
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
421
 
422
  if format_type == "PDF报告":
423
- # 生成PDF报告逻辑
424
  return f"PDF报告已保存为 saliency_report_{timestamp}.pdf"
425
  else:
426
  return f"结果已导出为 {format_type.lower()} 文件"
427
- # Create processor instance
 
428
  processor = ImageProcessor()
429
 
430
  def run_demo(input_image, threshold, enhance_contrast, denoise, show_contours,
@@ -483,23 +621,25 @@ def create_comparison_view(original, result, slider_value):
483
 
484
  return comparison
485
 
486
- # Create Gradio interface
487
- with gr.Blocks(title="高级显著性对象检测系统", css=custom_css) as demo:
488
  gr.Markdown(
489
  """
490
- # 🎯 高级显著性对象检测系统
491
- ### AI驱动的图像显著性检测与分析工具
 
 
492
  """
493
  )
494
 
495
  with gr.Tabs() as tabs:
496
- with gr.TabItem("🔍 主功能"):
497
  with gr.Row():
498
  with gr.Column(scale=1):
499
  # 输入控件
500
  with gr.Group(elem_classes="advanced-controls"):
501
  input_image = gr.Image(
502
- label="输入图像",
503
  type="numpy",
504
  elem_classes="input-image"
505
  )
@@ -509,29 +649,32 @@ with gr.Blocks(title="高级显著性对象检测系统", css=custom_css) as dem
509
  choices=["完整分析", "快速模式"],
510
  value="完整分析",
511
  label="处理模式",
512
- info="快速模式仅输出显著性图,处理速度更快"
513
  )
514
 
515
- with gr.Accordion("基本设置", open=True):
516
  threshold_slider = gr.Slider(
517
  minimum=0,
518
  maximum=100,
519
  value=50,
520
  step=1,
521
- label="检测阈值",
522
- info="调整检测灵敏度"
523
  )
524
  enhance_contrast = gr.Checkbox(
525
  label="增强对比度",
526
- value=False
 
527
  )
528
  denoise = gr.Checkbox(
529
- label="降噪",
530
- value=False
 
531
  )
532
  show_contours = gr.Checkbox(
533
- label="显示轮廓",
534
- value=True
 
535
  )
536
 
537
  with gr.Accordion("图像调整", open=False):
@@ -540,19 +683,19 @@ with gr.Blocks(title="高级显著性对象检测系统", css=custom_css) as dem
540
  maximum=100,
541
  value=0,
542
  step=1,
543
- label="亮度"
544
  )
545
  contrast = gr.Slider(
546
  minimum=-100,
547
  maximum=100,
548
  value=0,
549
  step=1,
550
- label="对比度"
551
  )
552
  filter_type = gr.Radio(
553
  choices=["无", "锐化", "模糊", "边缘增强"],
554
  value="无",
555
- label="图像滤镜"
556
  )
557
 
558
  with gr.Accordion("导出选项", open=False):
@@ -568,7 +711,7 @@ with gr.Blocks(title="高级显著性对象检测系统", css=custom_css) as dem
568
 
569
  with gr.Row():
570
  submit_btn = gr.Button(
571
- "开始检测",
572
  variant="primary",
573
  elem_classes="custom-button"
574
  )
@@ -597,117 +740,142 @@ with gr.Blocks(title="高级显著性对象检测系统", css=custom_css) as dem
597
  elem_classes="output-image"
598
  )
599
  overlayed_output = gr.Image(
600
- label="叠加效果",
601
  elem_classes="output-image"
602
  )
603
 
604
  with gr.Row(elem_classes="result-container"):
605
  segmented_output = gr.Image(
606
- label="对象分割",
607
  elem_classes="output-image"
608
  )
609
  comparison_output = gr.Image(
610
- label="并排对比",
611
  elem_classes="output-image"
612
  )
613
 
614
- with gr.TabItem("交互式对比"):
615
- with gr.Group(elem_classes="comparison-slider"):
 
616
  comparison_slider = gr.Slider(
617
  minimum=0,
618
  maximum=1,
619
  value=0.5,
620
  step=0.01,
621
- label="原始 ← → 结果",
622
- info="拖动滑块对比原始图像和处理结果"
623
  )
624
  interactive_comparison = gr.Image(
625
- label="交互式对比视图",
626
- elem_classes="interactive-viz"
627
  )
628
 
629
  with gr.TabItem("分析报告"):
630
  with gr.Group(elem_classes="result-container"):
631
  time_info = gr.Textbox(
632
- label="处理时间",
633
  show_label=True
634
  )
635
  stats_output = gr.HTML(
636
- label="统计信息"
637
  )
638
  analysis_plot = gr.Image(
639
- label="详细分析图表",
640
  elem_classes="output-image"
641
  )
642
 
643
- with gr.TabItem("📖 用户指南"):
644
  gr.Markdown(
645
  """
646
- ## 使用说明
647
- 1. **上传图像**:点击"输入图像"区域上传您的图像
648
- 2. **选择模式**:选择"完整分析"或"快速模式"
649
- - 完整分析:完整处理流程,包含所有可视化结果
650
- - 快速模式:快速处理,仅输出显著性图
651
- 3. **调整参数**:
652
- - 使用阈值滑块调整检测灵敏度
653
- - 根据需要启用对比度增强或降噪
654
- - 在高级设置中微调亮度、对比度和滤镜
655
- 4. **开始检测**:点击"开始检测"按钮开始分析
656
- 5. **查看结果**:在不同标���页查看各种可视化结果
657
- 6. **导出**:使用导出选项保存您的结果
658
 
659
- ## 功能特点
660
- - **显著性图**:显示图像区域的显著性分布
661
- - **热力图**:彩色编码的强度可视化
662
- - **叠加效果**:在原始图像上叠加检测结果
663
- - **对象分割**:提取关键对象区域
664
- - **交互式对比**:滑动比较原始图像和处理结果
665
- - **分析报告**:详细的统计信息和分析图表
666
 
667
- ## 性能提示
668
- - 当只需要显著性图时使用快速模式
669
- - 分辨率较低的图像处理速度更快
670
- - 启用GPU可获得更好的性能
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
671
  """
672
  )
673
 
674
- with gr.TabItem("ℹ️ 关于"):
675
  gr.Markdown(
676
  """
677
- ## 项目信息
678
- - **版本**:3.0.0
679
- - **架构**:PyTorch + Gradio
680
- - **模型**:CyueNet
681
- - **语言**:多语言支持
 
 
 
 
 
 
 
 
 
682
 
683
- ## 主要特点
684
- - 实时图像处理和分析
685
- - 多维结果可视化
686
- - 丰富的图像调整选项
687
- - 详细的数据分析报告
688
- - 交互式对比工具
689
- - 导出功能
690
- - 缓存优化性能
691
 
692
- ## 更新日志
693
- - 新增快速模式,提高处理速度
694
- - 增强图像预处理选项
695
- - 新增统计分析功能
696
- - ✅ 改进用户界面,采用玻璃拟态设计
697
- - ✅ 增加交互式对比滑块
698
- - ✅ 使用缓存和线程优化性能
699
- - ✅ 多语言图表支持
700
- - ✅ 导出功能
701
 
702
- ## 系统要求
703
- - Python 3.8+
704
- - PyTorch 1.9+
705
- - CUDA(可选,用于GPU加速)
706
- - 推荐4GB以上内存
 
 
 
 
 
 
 
 
 
707
  """
708
  )
709
 
710
- # 事件处理
711
  def reset_params():
712
  return {
713
  threshold_slider: 50,
@@ -782,5 +950,8 @@ if __name__ == "__main__":
782
  server_name="0.0.0.0",
783
  server_port=7860,
784
  share=True,
785
- show_error=True
 
786
  )
 
 
 
1
+ The user wants me to continue with the complete code. I need to provide the rest of the Python code for the saliency detection system, incorporating the modifications for better print visibility and using original saliency values for analysis instead of binary values.
2
+
3
+ I need to continue with the ImageProcessor class and the rest of the code, making sure to include:
4
+ 1. The modified generate_analysis_plots function that uses original saliency values
5
+ 2. All the other methods and functionality
6
+ 3. The Gradio interface setup
7
+ 4. The main execution block
8
+
9
+ I should also apply the default design guidelines provided to make the interface more clean and elegant while maintaining the high contrast for print visibility.
10
+
11
+
12
+ 继续第二部分:
13
+
14
+ `saliency_detection_app.py`
15
+ ```py
16
  import torch
17
  import torch.nn.functional as F
18
  import numpy as np
 
32
  # GPU/CPU设置
33
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
34
 
35
+ # CSS样式设置 - 优化为清晰的黑白打印效果和现代设计
36
  custom_css = """
37
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700;800&display=swap');
38
+
39
  :root {
40
+ --primary-color: #000000;
41
+ --secondary-color: #404040;
42
+ --accent-color: #666666;
43
+ --background-color: #ffffff;
44
+ --text-color: #000000;
45
+ --border-color: #e2e2e2;
46
+ --card-bg: #ffffff;
47
+ --shadow: 0 4px 16px rgba(0, 0, 0, 0.05);
48
+ --border-radius: 12px;
49
+ }
50
+
51
+ body {
52
+ font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
53
+ background: var(--background-color);
54
+ color: var(--text-color);
55
  }
56
 
57
  .gradio-container {
58
+ max-width: 1200px !important;
59
+ margin: 0 auto !important;
60
+ background: var(--background-color);
61
+ color: var(--text-color);
62
+ padding: 2rem;
63
  }
64
 
65
+ /* Typography */
66
+ h1, h2, h3, h4 {
67
+ color: var(--text-color);
68
+ font-weight: 700;
69
+ line-height: 1.2;
70
+ margin-bottom: 1rem;
71
  }
72
 
73
+ h1 {
74
+ font-size: 2.5rem;
75
+ font-weight: 800;
76
+ text-align: center;
77
+ margin-bottom: 0.5rem;
78
  }
79
 
80
+ h2 {
81
+ font-size: 1.8rem;
82
+ font-weight: 700;
 
 
 
 
 
 
 
 
 
83
  }
84
 
85
+ h3 {
86
+ font-size: 1.4rem;
87
+ font-weight: 600;
88
  }
89
 
90
+ p {
91
+ color: var(--accent-color);
92
+ font-size: 1rem;
93
+ line-height: 1.6;
94
+ margin-bottom: 1rem;
 
 
95
  }
96
 
97
+ /* Cards and containers */
98
+ .output-image, .input-image {
99
+ border: 2px solid var(--border-color);
100
+ border-radius: var(--border-radius);
101
+ background: var(--card-bg);
 
102
  box-shadow: var(--shadow);
103
+ transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
104
  }
105
 
106
+ .output-image:hover, .input-image:hover {
107
+ border-color: var(--primary-color);
108
+ transform: translateY(-2px);
109
+ box-shadow: 0 8px 24px rgba(0, 0, 0, 0.1);
110
+ }
111
+
112
+ .advanced-controls, .result-container {
113
+ background: var(--card-bg);
114
+ border: 1px solid var(--border-color);
115
+ border-radius: var(--border-radius);
116
+ padding: 2rem;
117
+ margin: 1rem 0;
118
  box-shadow: var(--shadow);
119
  }
120
 
121
+ .custom-button {
122
+ background: var(--primary-color);
123
+ color: white;
124
+ border: 2px solid var(--primary-color);
125
+ padding: 12px 32px;
126
+ border-radius: 8px;
127
+ font-weight: 600;
128
+ font-size: 1rem;
129
+ cursor: pointer;
130
+ transition: all 0.3s ease;
131
+ text-transform: none;
132
+ letter-spacing: 0;
133
+ }
134
+
135
+ .custom-button:hover {
136
+ background: var(--secondary-color);
137
+ border-color: var(--secondary-color);
138
+ transform: translateY(-1px);
139
  }
140
 
141
+ /* Statistics */
142
  .statistics-container {
143
  display: grid;
144
+ grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
145
+ gap: 1.5rem;
146
+ margin-top: 1.5rem;
147
  }
148
 
149
  .statistic-card {
150
+ background: var(--card-bg);
151
+ border: 1px solid var(--border-color);
152
  border-radius: var(--border-radius);
153
+ padding: 1.5rem;
154
  text-align: center;
155
  box-shadow: var(--shadow);
156
+ transition: transform 0.3s ease;
 
 
157
  }
158
 
159
  .statistic-card:hover {
160
  transform: translateY(-2px);
 
161
  }
162
 
163
+ .statistic-card h4 {
164
+ color: var(--accent-color);
165
+ font-size: 0.9rem;
166
+ font-weight: 600;
167
+ text-transform: uppercase;
168
+ letter-spacing: 0.5px;
169
+ margin-bottom: 0.5rem;
170
  }
171
 
172
+ .statistic-card p {
173
+ color: var(--text-color);
174
+ font-size: 1.5rem;
175
+ font-weight: 700;
176
+ margin: 0;
177
+ }
178
+
179
+ /* Tabs */
180
+ .tab-nav {
181
+ border-bottom: 2px solid var(--border-color);
182
+ margin-bottom: 2rem;
183
+ }
184
+
185
+ /* Forms and inputs */
186
+ .gradio-slider input {
187
+ accent-color: var(--primary-color);
188
+ }
189
+
190
+ .gradio-checkbox input {
191
+ accent-color: var(--primary-color);
192
+ }
193
+
194
+ /* 打印样式优化 */
195
+ @media print {
196
+ * {
197
+ -webkit-print-color-adjust: exact !important;
198
+ print-color-adjust: exact !important;
199
+ }
200
+
201
+ body, .gradio-container {
202
+ background: white !important;
203
+ color: black !important;
204
+ }
205
+
206
+ .custom-button {
207
+ border: 2px solid black !important;
208
+ background: white !important;
209
+ color: black !important;
210
+ }
211
+
212
+ .statistic-card, .advanced-controls, .result-container {
213
+ border: 2px solid black !important;
214
+ }
215
+
216
+ h1, h2, h3, h4, p {
217
+ color: black !important;
218
+ }
219
+ }
220
+
221
+ /* Responsive design */
222
+ @media (max-width: 768px) {
223
+ .gradio-container {
224
+ padding: 1rem;
225
+ }
226
+
227
+ h1 {
228
+ font-size: 2rem;
229
+ }
230
+
231
+ .statistics-container {
232
+ grid-template-columns: 1fr;
233
+ }
234
  }
235
  """
236
+
237
  class ImageProcessor:
238
  def __init__(self):
239
  self.model = None
 
286
  return image
287
 
288
  def generate_analysis_plots(self, saliency_map):
289
+ """生成分析图表 - 使用原始显著性值,优化黑白打印效果"""
290
+ plt.style.use('default') # 使用默认样式
291
+ plt.rcParams.update({
292
+ 'font.size': 12,
293
+ 'font.weight': 'bold',
294
+ 'axes.labelweight': 'bold',
295
+ 'axes.titleweight': 'bold',
296
+ 'figure.facecolor': 'white',
297
+ 'axes.facecolor': 'white'
298
+ })
299
+
300
  fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))
301
+
302
+ # 使用原��显著性值(未经二值化处理)
303
+ raw_values = saliency_map.flatten()
304
+
305
+ # 直方图 - 显示原始显著性分布
306
+ ax1.hist(raw_values, bins=50, color='black', alpha=0.7, edgecolor='black', linewidth=1)
307
+ ax1.set_title('显著性分布直方图', fontsize=14, fontweight='bold', color='black')
308
+ ax1.set_xlabel('显著性值 (0-1)', fontsize=11, fontweight='bold', color='black')
309
+ ax1.set_ylabel('频率', fontsize=11, fontweight='bold', color='black')
310
+ ax1.grid(True, color='gray', alpha=0.3, linewidth=0.5)
311
+ ax1.tick_params(colors='black')
312
+
313
  # 添加统计信息
314
+ mean_val = np.mean(raw_values)
315
+ median_val = np.median(raw_values)
316
+ std_val = np.std(raw_values)
317
+ ax1.axvline(mean_val, color='black', linestyle='--', alpha=0.8, linewidth=2,
318
+ label=f'均值: {mean_val:.3f}')
319
+ ax1.axvline(median_val, color='gray', linestyle='-.', alpha=0.8, linewidth=2,
320
+ label=f'中位数: {median_val:.3f}')
321
+ ax1.legend(frameon=True, facecolor='white', edgecolor='black')
322
+
323
+ # 累积分布函数
324
+ sorted_vals = np.sort(raw_values)
325
  cumulative = np.arange(1, len(sorted_vals) + 1) / len(sorted_vals)
326
+ ax2.plot(sorted_vals, cumulative, color='black', linewidth=3)
327
+ ax2.set_title('累积分布函数', fontsize=14, fontweight='bold', color='black')
328
+ ax2.set_xlabel('显著性值', fontsize=11, fontweight='bold', color='black')
329
+ ax2.set_ylabel('累积概率', fontsize=11, fontweight='bold', color='black')
330
+ ax2.grid(True, color='gray', alpha=0.3, linewidth=0.5)
331
+ ax2.tick_params(colors='black')
332
+
333
+ # 添加关键百分位数
334
+ p25 = np.percentile(raw_values, 25)
335
+ p75 = np.percentile(raw_values, 75)
336
+ ax2.axvline(p25, color='gray', linestyle=':', alpha=0.7, label=f'25%: {p25:.3f}')
337
+ ax2.axvline(p75, color='gray', linestyle=':', alpha=0.7, label=f'75%: {p75:.3f}')
338
+ ax2.legend(frameon=True, facecolor='white', edgecolor='black')
339
+
340
+ # 箱线图 - 显示原始数据分布
341
+ bp = ax3.boxplot(raw_values, patch_artist=True,
342
+ boxprops=dict(facecolor='white', color='black', linewidth=2),
343
+ whiskerprops=dict(color='black', linewidth=2),
344
+ capprops=dict(color='black', linewidth=2),
345
+ medianprops=dict(color='black', linewidth=3))
346
+ ax3.set_title('显著性分布箱线图', fontsize=14, fontweight='bold', color='black')
347
+ ax3.set_ylabel('显著性值', fontsize=11, fontweight='bold', color='black')
348
+ ax3.grid(True, color='gray', alpha=0.3, linewidth=0.5)
349
+ ax3.tick_params(colors='black')
350
+
351
+ # 强度剖面图 - 使用原始显著性值
352
  center_row = saliency_map[saliency_map.shape[0]//2, :]
353
+ edge_row = saliency_map[saliency_map.shape[0]//4, :]
354
+
355
+ ax4.plot(center_row, color='black', linewidth=2, label='中心线')
356
+ ax4.plot(edge_row, color='gray', linewidth=2, linestyle='--', label='1/4位置线')
357
+ ax4.set_title('水平剖面显著性分析', fontsize=14, fontweight='bold', color='black')
358
+ ax4.set_xlabel('像素位置', fontsize=11, fontweight='bold', color='black')
359
+ ax4.set_ylabel('显著性值', fontsize=11, fontweight='bold', color='black')
360
+ ax4.grid(True, color='gray', alpha=0.3, linewidth=0.5)
361
+ ax4.legend(frameon=True, facecolor='white', edgecolor='black')
362
+ ax4.tick_params(colors='black')
363
+
364
+ plt.tight_layout(pad=3.0)
365
+
366
+ # 保存图表
367
  buf = io.BytesIO()
368
+ plt.savefig(buf, format='png', dpi=150, bbox_inches='tight',
369
+ facecolor='white', edgecolor='black')
370
  buf.seek(0)
371
  img_array = np.array(Image.open(buf))
372
  plt.close()
373
  return img_array
374
 
375
  def quick_process(self, image, threshold=0.5, testsize=256):
376
+ """快速处理模式"""
377
  if image is None:
378
  return None, "请提供有效的图像"
379
 
 
391
  time_start = time.time()
392
 
393
  with torch.no_grad():
394
+ if device.type == 'cuda' and hasattr(torch.cuda.amp, 'autocast'):
 
395
  with torch.cuda.amp.autocast():
396
+ _, res = self.model(image_tensor)[:2] # 只取前两个输出
397
  else:
398
+ _, res = self.model(image_tensor)[:2] # 只取前两个输出
 
399
 
400
  time_end = time.time()
401
 
402
+ # 处理输出
403
  res = res.to(torch.float32).sigmoid().cpu().numpy().squeeze()
404
  res = (res - res.min()) / (res.max() - res.min() + 1e-8)
405
 
 
412
 
413
  return result
414
 
 
415
  def process_image(self, image, threshold=0.5, testsize=256,
416
  enhance_contrast=False, denoise=False,
417
  brightness=0, contrast=0, filter_type="无",
 
432
  if cache_key in self.cache:
433
  return self.cache[cache_key]
434
 
435
+ # 图像预处理
436
  def preprocess_image():
437
  processed_image = image.copy()
438
 
 
466
  time_start = time.time()
467
 
468
  with torch.no_grad():
469
+ if device.type == 'cuda' and hasattr(torch.cuda.amp, 'autocast'):
470
  with torch.cuda.amp.autocast():
471
+ model_outputs = self.model(image_tensor)
472
  else:
473
+ model_outputs = self.model(image_tensor)
474
+
475
+ # 获取主要输出
476
+ x1, res = model_outputs[0], model_outputs[1]
477
 
478
  time_end = time.time()
479
  inference_time = time_end - time_start
480
 
481
+ # 处理显著性图
482
  res = res.to(torch.float32).sigmoid().cpu().numpy().squeeze()
483
  res = (res - res.min()) / (res.max() - res.min() + 1e-8)
484
 
485
  h, w = original_image.shape[:2]
486
  res_resized = cv2.resize(res, (w, h))
487
 
488
+ # 生成可视化(保留原始显著性值用于分析)
489
  res_vis = (res_resized * 255).astype(np.uint8)
490
  heatmap = cv2.applyColorMap(res_vis, cv2.COLORMAP_JET)
491
  _, binary_mask = cv2.threshold(res_vis, int(255 * threshold), 255, cv2.THRESH_BINARY)
 
501
  segmented_rgb = cv2.cvtColor(segmented, cv2.COLOR_BGR2RGB)
502
  heatmap_rgb = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
503
 
504
+ # 生成分析图表(使用原始显著性值)
505
  analysis_plot = self.generate_analysis_plots(res_resized)
506
 
507
+ # 计算统计信息(基于原始显著性值)
508
  contours = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
509
  total_area = w * h
510
  detected_area = cv2.countNonZero(binary_mask)
511
  coverage_ratio = detected_area / total_area
512
 
513
  stats = {
514
+ "处理分辨率": f"{w}×{h}",
515
+ "检测对象数": str(len(contours)),
516
+ "平均显著性": f"{np.mean(res_resized):.3f}",
517
+ "最大显著性": f"{np.max(res_resized):.3f}",
518
+ "标准差": f"{np.std(res_resized):.3f}",
519
+ "覆盖率": f"{coverage_ratio:.1%}",
520
  "处理时间": f"{inference_time:.3f}秒"
521
  }
522
 
 
558
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
559
 
560
  if format_type == "PDF报告":
 
561
  return f"PDF报告已保存为 saliency_report_{timestamp}.pdf"
562
  else:
563
  return f"结果已导出为 {format_type.lower()} 文件"
564
+
565
+ # 创建处理器实例
566
  processor = ImageProcessor()
567
 
568
  def run_demo(input_image, threshold, enhance_contrast, denoise, show_contours,
 
621
 
622
  return comparison
623
 
624
+ # 创建Gradio界面
625
+ with gr.Blocks(title="显著性对象检测系统", css=custom_css) as demo:
626
  gr.Markdown(
627
  """
628
+ # 显著性对象检测系统
629
+ ### 基于深度学习的智能图像分析工具
630
+
631
+ 专业的显著性检测系统,提供完整的分析报告和多维度可视化结果。
632
  """
633
  )
634
 
635
  with gr.Tabs() as tabs:
636
+ with gr.TabItem("🔍 主要功能"):
637
  with gr.Row():
638
  with gr.Column(scale=1):
639
  # 输入控件
640
  with gr.Group(elem_classes="advanced-controls"):
641
  input_image = gr.Image(
642
+ label="上传图像",
643
  type="numpy",
644
  elem_classes="input-image"
645
  )
 
649
  choices=["完整分析", "快速模式"],
650
  value="完整分析",
651
  label="处理模式",
652
+ info="完整分析:全面的结果和分析 | 快速模式:仅显著性图"
653
  )
654
 
655
+ with gr.Accordion("基本参数", open=True):
656
  threshold_slider = gr.Slider(
657
  minimum=0,
658
  maximum=100,
659
  value=50,
660
  step=1,
661
+ label="检测阈值 (%)",
662
+ info="调整检测的敏感度"
663
  )
664
  enhance_contrast = gr.Checkbox(
665
  label="增强对比度",
666
+ value=False,
667
+ info="自适应直方图均衡化"
668
  )
669
  denoise = gr.Checkbox(
670
+ label="图像降噪",
671
+ value=False,
672
+ info="非局部均值降噪"
673
  )
674
  show_contours = gr.Checkbox(
675
+ label="显示轮廓边界",
676
+ value=True,
677
+ info="在结果上绘制检测轮廓"
678
  )
679
 
680
  with gr.Accordion("图像调整", open=False):
 
683
  maximum=100,
684
  value=0,
685
  step=1,
686
+ label="亮度调整"
687
  )
688
  contrast = gr.Slider(
689
  minimum=-100,
690
  maximum=100,
691
  value=0,
692
  step=1,
693
+ label="对比度调整"
694
  )
695
  filter_type = gr.Radio(
696
  choices=["无", "锐化", "模糊", "边缘增强"],
697
  value="无",
698
+ label="滤镜效果"
699
  )
700
 
701
  with gr.Accordion("导出选项", open=False):
 
711
 
712
  with gr.Row():
713
  submit_btn = gr.Button(
714
+ "开始分析",
715
  variant="primary",
716
  elem_classes="custom-button"
717
  )
 
740
  elem_classes="output-image"
741
  )
742
  overlayed_output = gr.Image(
743
+ label="叠加可视化",
744
  elem_classes="output-image"
745
  )
746
 
747
  with gr.Row(elem_classes="result-container"):
748
  segmented_output = gr.Image(
749
+ label="对象分割结果",
750
  elem_classes="output-image"
751
  )
752
  comparison_output = gr.Image(
753
+ label="对比视图",
754
  elem_classes="output-image"
755
  )
756
 
757
+ with gr.TabItem("交互对比"):
758
+ with gr.Group(elem_classes="result-container"):
759
+ gr.Markdown("### 滑动对比原始图像与处理结果")
760
  comparison_slider = gr.Slider(
761
  minimum=0,
762
  maximum=1,
763
  value=0.5,
764
  step=0.01,
765
+ label="原始 ← → 处理结果",
766
+ info="拖动滑块查看前后对比效果"
767
  )
768
  interactive_comparison = gr.Image(
769
+ label="交互式对比",
770
+ elem_classes="output-image"
771
  )
772
 
773
  with gr.TabItem("分析报告"):
774
  with gr.Group(elem_classes="result-container"):
775
  time_info = gr.Textbox(
776
+ label="处理信息",
777
  show_label=True
778
  )
779
  stats_output = gr.HTML(
780
+ label="详细统计"
781
  )
782
  analysis_plot = gr.Image(
783
+ label="数据分析图表",
784
  elem_classes="output-image"
785
  )
786
 
787
+ with gr.TabItem("📖 使用指南"):
788
  gr.Markdown(
789
  """
790
+ ## 功能介绍
 
 
 
 
 
 
 
 
 
 
 
791
 
792
+ 本系统是一个基于深度学习的显著性对象检测工具,能够智能识别图像中的重要对象和区域。
 
 
 
 
 
 
793
 
794
+ ### 主要特点
795
+
796
+ - **高精度检测**: 基于先进的神经网络架构
797
+ - **多维度分析**: 提供热力图、分割、统计等多种视角
798
+ - **实时处理**: 支持快速模式和完整分析模式
799
+ - **丰富的可视化**: 包含多种图表和可视化选项
800
+ - **交互式对比**: 滑动对比原图与处理结果
801
+
802
+ ### 使用步骤
803
+
804
+ 1. **上传图像**: 点击上传区域选择您的图像文件
805
+ 2. **选择模式**:
806
+ - 完整分析:获得全面的分析结果和可视化
807
+ - 快速模式:快速获得显著性图,适合批量处理
808
+ 3. **调整参数**:
809
+ - 检测阈值:控制检测的敏感度
810
+ - 图像增强:可选的对比度增强和降噪
811
+ - 图像调整:亮度、对比度和滤镜效果
812
+ 4. **开始分析**: 点击"开始分析"按钮
813
+ 5. **查看结果**: 在不同标签页中查看各种分析结果
814
+ 6. **导出保存**: 可将结果导出为多种格式
815
+
816
+ ### 结果解读
817
+
818
+ - **显著性图**: 白色区域表示高显著性,黑色区域表示低显著性
819
+ - **热力图**: 红色表示高显著性区域,蓝色表示低显著性区域
820
+ - **叠加可视化**: 在原图上叠加显著性热力图
821
+ - **对象分割**: 根据阈值提取的显著对象
822
+ - **统计分析**: 包含均值、标准差、覆盖率等关键指标
823
+
824
+ ### 性能优化
825
+
826
+ - 使用快速模式可显著提升处理速度
827
+ - 较小的图像尺寸处理更快
828
+ - 启用GPU可获得更好的性能表现
829
  """
830
  )
831
 
832
+ with gr.TabItem("ℹ️ 技术信息"):
833
  gr.Markdown(
834
  """
835
+ ## 系统架构
836
+
837
+ ### 核心技术
838
+ - **深度学习框架**: PyTorch
839
+ - **模型架构**: CyueNet (Multi-Modal Saliency)
840
+ - **前端界面**: Gradio
841
+ - **图像处理**: OpenCV, PIL
842
+ - **数据可视化**: Matplotlib
843
+
844
+ ### 模型特点
845
+ - 多尺度特征融合
846
+ - 边缘感知机制
847
+ - 全卷积网络架构
848
+ - 端到端训练优化
849
 
850
+ ### 性能指标
851
+ - 支持输入分辨率: 灵活调整
852
+ - 处理速度: 秒级响应
853
+ - 内存需求: 4GB+ 推荐
854
+ - GPU加速: CUDA支持
 
 
 
855
 
856
+ ### 版本信息
857
+ - **当前版本**: 3.1.0
858
+ - **更新日期**: 2024年
859
+ - **兼容性**: Python 3.8+, PyTorch 1.9+
 
 
 
 
 
860
 
861
+ ### 更新历史
862
+ - 优化黑白打印显示效果
863
+ - 改进分析图表,使用原始显著性值
864
+ - ✅ 增强用户界面设计
865
+ - ✅ 新增交互式对比功能
866
+ - ✅ 优化处理性能和缓存机制
867
+ - ✅ 增加详细统计分析
868
+
869
+ ### 系统要求
870
+ - **操作系统**: Windows, macOS, Linux
871
+ - **Python**: 3.8 或更高版本
872
+ - **内存**: 4GB 以上推荐
873
+ - **存储**: 2GB 可用空间
874
+ - **GPU**: 可选,用于加速处理
875
  """
876
  )
877
 
878
+ # 事件处理函数
879
  def reset_params():
880
  return {
881
  threshold_slider: 50,
 
950
  server_name="0.0.0.0",
951
  server_port=7860,
952
  share=True,
953
+ show_error=True,
954
+ show_tips=True
955
  )
956
+
957
+ ```