kunkk committed on
Commit b8af151 · verified · 1 Parent(s): 249ac00

Update app.py

Files changed (1)
  1. app.py +168 -168
app.py CHANGED
@@ -149,11 +149,11 @@ class ImageProcessor:
149
  self.model = MMS()
150
  try:
151
  self.model.load_state_dict(torch.load('models/CyueNet_EORSSD6.pth.54', map_location=device))
152
- print("Model loaded successfully")
153
  except RuntimeError as e:
154
- print(f"Model loading error: {e}")
155
  except FileNotFoundError:
156
- print("Model file not found. Please check the path.")
157
  self.model.to(device)
158
  self.model.eval()
159
 
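For readers skimming the diff, the hunk above reduces to the following standalone sketch of guarded checkpoint loading. `load_model` is a hypothetical helper written for illustration; only `MMS` and the weight path come from app.py.

```python
# Sketch of the guarded checkpoint-loading pattern above; `load_model` is not defined in app.py.
import torch

def load_model(model, weight_path, device):
    try:
        state_dict = torch.load(weight_path, map_location=device)
        model.load_state_dict(state_dict)
        print("Model loaded successfully")
    except (RuntimeError, FileNotFoundError) as exc:
        # The app continues with uninitialized weights if loading fails.
        print(f"Model loading skipped: {exc}")
    model.to(device)
    model.eval()
    return model

# Usage mirroring the diff:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = load_model(MMS(), 'models/CyueNet_EORSSD6.pth.54', device)
```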
@@ -178,62 +178,62 @@ class ImageProcessor:
178
 
179
  def apply_filters(self, image, filter_type):
180
  """应用图像滤镜效果"""
181
- if filter_type == "Sharpen":
182
  kernel = np.array([[-1,-1,-1], [-1,9,-1], [-1,-1,-1]])
183
  return cv2.filter2D(image, -1, kernel)
184
- elif filter_type == "Blur":
185
  return cv2.GaussianBlur(image, (5,5), 0)
186
- elif filter_type == "Edge Enhancement":
187
  kernel = np.array([[0,-1,0], [-1,5,-1], [0,-1,0]])
188
  return cv2.filter2D(image, -1, kernel)
189
  return image
190
 
191
  def generate_analysis_plots(self, saliency_map):
192
- """生成英文分析图表"""
193
  plt.style.use('seaborn-v0_8')
194
  fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))
195
 
196
- # Histogram
197
  ax1.hist(saliency_map.flatten(), bins=50, color='#2196F3', alpha=0.7, edgecolor='black')
198
- ax1.set_title('Saliency Distribution Histogram', fontsize=12, pad=15)
199
- ax1.set_xlabel('Saliency Value', fontsize=10)
200
- ax1.set_ylabel('Frequency', fontsize=10)
201
  ax1.grid(True, alpha=0.3)
202
 
203
- # Add statistics
204
  mean_val = np.mean(saliency_map)
205
  median_val = np.median(saliency_map)
206
- ax1.axvline(mean_val, color='red', linestyle='--', alpha=0.7, label=f'Mean: {mean_val:.3f}')
207
- ax1.axvline(median_val, color='green', linestyle='--', alpha=0.7, label=f'Median: {median_val:.3f}')
208
  ax1.legend()
209
 
210
- # Cumulative distribution
211
  sorted_vals = np.sort(saliency_map.flatten())
212
  cumulative = np.arange(1, len(sorted_vals) + 1) / len(sorted_vals)
213
  ax2.plot(sorted_vals, cumulative, color='#FF6B35', linewidth=2)
214
- ax2.set_title('Cumulative Distribution Function', fontsize=12)
215
- ax2.set_xlabel('Saliency Value', fontsize=10)
216
- ax2.set_ylabel('Cumulative Probability', fontsize=10)
217
  ax2.grid(True, alpha=0.3)
218
 
219
- # Box plot
220
  ax3.boxplot(saliency_map.flatten(), patch_artist=True,
221
  boxprops=dict(facecolor='#21CBF3', alpha=0.7))
222
- ax3.set_title('Saliency Distribution Box Plot', fontsize=12)
223
- ax3.set_ylabel('Saliency Value', fontsize=10)
224
  ax3.grid(True, alpha=0.3)
225
 
226
- # Intensity profile (center line)
227
  center_row = saliency_map[saliency_map.shape[0]//2, :]
228
  ax4.plot(center_row, color='#9C27B0', linewidth=2)
229
- ax4.set_title('Center Line Intensity Profile', fontsize=12)
230
- ax4.set_xlabel('Pixel Position', fontsize=10)
231
- ax4.set_ylabel('Saliency Value', fontsize=10)
232
  ax4.grid(True, alpha=0.3)
233
 
234
  plt.tight_layout()
235
 
236
- # Save to bytes
237
  buf = io.BytesIO()
238
  plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
239
  buf.seek(0)
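Two reusable pieces sit in this hunk: kernel-based filtering with `cv2.filter2D` and rendering a matplotlib figure into an in-memory PNG. A minimal sketch of both; the function names are illustrative, not from app.py.

```python
# Minimal sketch of the filtering and plot-export steps used above.
import io
import cv2
import numpy as np
import matplotlib.pyplot as plt

def sharpen(image):
    # Same 3x3 sharpening kernel as in apply_filters; ddepth=-1 keeps the input depth.
    kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
    return cv2.filter2D(image, -1, kernel)

def histogram_png(saliency_map):
    # Histogram of saliency values, exported to PNG bytes like generate_analysis_plots.
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.hist(saliency_map.flatten(), bins=50)
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    plt.close(fig)
    buf.seek(0)
    return buf.getvalue()
```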
@@ -245,9 +245,9 @@ class ImageProcessor:
245
  def quick_process(self, image, threshold=0.5, testsize=256):
246
  """快速处理模式,只输出显著性图"""
247
  if image is None:
248
- return None, "Please provide a valid image"
249
 
250
- # Check cache
251
  image_hash = hash(image.tobytes())
252
  cache_key = f"{image_hash}_{threshold}_{testsize}_quick"
253
 
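The quick path memoizes results keyed on the raw image bytes plus the processing parameters. Roughly, as sketched below; the `compute` callable is a stand-in for the model inference and is not an app.py function.

```python
# Sketch of the hash-based result cache used by quick_process.
import numpy as np

cache = {}

def cached_quick_process(image: np.ndarray, threshold: float, testsize: int, compute):
    key = f"{hash(image.tobytes())}_{threshold}_{testsize}_quick"
    if key not in cache:
        cache[key] = compute(image, threshold, testsize)  # expensive model call
    return cache[key]
```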
@@ -278,31 +278,31 @@ class ImageProcessor:
278
  res_resized = cv2.resize(res, (w, h))
279
  res_vis = (res_resized * 255).astype(np.uint8)
280
 
281
- result = (res_vis, f"Quick processing completed in {time_end - time_start:.3f}s")
282
  self.cache[cache_key] = result
283
 
284
  return result
285
  def process_image(self, image, threshold=0.5, testsize=256,
286
  enhance_contrast=False, denoise=False,
287
- brightness=0, contrast=0, filter_type="None",
288
- process_mode="Full Analysis"):
289
  """增强的图像处理函数"""
290
  if image is None:
291
- return [None] * 9 + ["Please provide a valid image"]
292
 
293
- # Quick mode check
294
- if process_mode == "Quick Mode":
295
  saliency_map, time_info = self.quick_process(image, threshold, testsize)
296
  return (image, saliency_map, None, None, None, None, time_info, None, None)
297
 
298
- # Check cache for full processing
299
  image_hash = hash(image.tobytes())
300
  cache_key = f"{image_hash}_{threshold}_{testsize}_{enhance_contrast}_{denoise}_{brightness}_{contrast}_{filter_type}_full"
301
 
302
  if cache_key in self.cache:
303
  return self.cache[cache_key]
304
 
305
- # Image preprocessing with threading
306
  def preprocess_image():
307
  processed_image = image.copy()
308
 
@@ -328,7 +328,7 @@ class ImageProcessor:
328
 
329
  original_image = processed_image.copy()
330
 
331
- # Model inference
332
  image_pil = Image.fromarray(processed_image).convert('RGB')
333
  image_tensor = transform_image(image_pil, testsize)
334
  image_tensor = image_tensor.unsqueeze(0).to(device)
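`transform_image` is defined elsewhere in app.py and not shown in this diff; a plausible equivalent built from standard torchvision transforms is sketched below. The resize and normalization constants are assumptions, not taken from the repository.

```python
# Hedged sketch of the tensor preparation and inference steps; the real
# transform_image may differ (e.g. different normalization).
import torch
from PIL import Image
from torchvision import transforms

def prepare_input(array, testsize=256):
    image_pil = Image.fromarray(array).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((testsize, testsize)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    return transform(image_pil).unsqueeze(0)  # add the batch dimension

@torch.no_grad()
def run_inference(model, batch):
    # Eval-mode saliency prediction without gradient tracking.
    return model(batch)
```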
@@ -353,44 +353,44 @@ class ImageProcessor:
353
  h, w = original_image.shape[:2]
354
  res_resized = cv2.resize(res, (w, h))
355
 
356
- # Generate visualizations
357
  res_vis = (res_resized * 255).astype(np.uint8)
358
  heatmap = cv2.applyColorMap(res_vis, cv2.COLORMAP_JET)
359
  _, binary_mask = cv2.threshold(res_vis, int(255 * threshold), 255, cv2.THRESH_BINARY)
360
 
361
- # Create overlays
362
  alpha = 0.5
363
  original_bgr = cv2.cvtColor(original_image, cv2.COLOR_RGB2BGR)
364
  overlayed = cv2.addWeighted(original_bgr, 1-alpha, heatmap, alpha, 0)
365
  segmented = cv2.bitwise_and(original_bgr, original_bgr, mask=binary_mask)
366
 
367
- # Convert back to RGB
368
  overlayed_rgb = cv2.cvtColor(overlayed, cv2.COLOR_BGR2RGB)
369
  segmented_rgb = cv2.cvtColor(segmented, cv2.COLOR_BGR2RGB)
370
  heatmap_rgb = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
371
 
372
- # Generate analysis plots
373
  analysis_plot = self.generate_analysis_plots(res_resized)
374
 
375
- # Calculate statistics
376
  contours = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
377
  total_area = w * h
378
  detected_area = cv2.countNonZero(binary_mask)
379
  coverage_ratio = detected_area / total_area
380
 
381
  stats = {
382
- "Processing Resolution": f"{w}x{h}",
383
- "Detected Objects": str(len(contours)),
384
- "Average Confidence": f"{np.mean(res_resized):.2%}",
385
- "Max Confidence": f"{np.max(res_resized):.2%}",
386
- "Coverage Ratio": f"{coverage_ratio:.2%}",
387
- "Processing Time": f"{inference_time:.3f}s"
388
  }
389
 
390
- # Create comparison image
391
  comparison_img = self.create_comparison_image(original_image, overlayed_rgb)
392
 
393
- # Save results
394
  self.last_results = {
395
  'saliency_map': res_resized,
396
  'binary_mask': binary_mask,
@@ -398,9 +398,9 @@ class ImageProcessor:
398
  }
399
 
400
  result = (original_image, res_vis, heatmap_rgb, overlayed_rgb, segmented_rgb,
401
- comparison_img, f"Processing time: {inference_time:.4f}s", stats, analysis_plot)
402
 
403
- # Cache result
404
  self.cache[cache_key] = result
405
 
406
  return result
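The visualization chain in the hunks above (saliency map → heatmap → binary mask → overlay → coverage statistic) can be reproduced in isolation. A sketch, assuming an RGB uint8 image and a float saliency map in [0, 1] of matching size:

```python
# Sketch of the visualization pipeline: saliency -> heatmap, binary mask,
# alpha-blended overlay, and a coverage statistic.
import cv2
import numpy as np

def visualize(original_rgb, saliency, threshold=0.5, alpha=0.5):
    res_vis = (saliency * 255).astype(np.uint8)
    heatmap = cv2.applyColorMap(res_vis, cv2.COLORMAP_JET)  # BGR output
    _, mask = cv2.threshold(res_vis, int(255 * threshold), 255, cv2.THRESH_BINARY)

    original_bgr = cv2.cvtColor(original_rgb, cv2.COLOR_RGB2BGR)
    overlay = cv2.addWeighted(original_bgr, 1 - alpha, heatmap, alpha, 0)

    coverage = cv2.countNonZero(mask) / mask.size  # fraction of salient pixels
    return cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB), mask, coverage
```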
@@ -412,7 +412,7 @@ class ImageProcessor:
412
  comparison[:, :w] = original
413
  comparison[:, w:] = processed
414
 
415
- # Add dividing line
416
  cv2.line(comparison, (w, 0), (w, h), (255, 255, 255), 2)
417
 
418
  return comparison
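`create_comparison_image` amounts to horizontal concatenation plus a divider. A compact sketch, assuming both frames share height, width, and dtype:

```python
# Sketch of the side-by-side comparison with a white divider at the seam.
import cv2
import numpy as np

def side_by_side(original, processed):
    h, w = original.shape[:2]
    comparison = np.hstack([original, processed])
    cv2.line(comparison, (w, 0), (w, h), (255, 255, 255), 2)
    return comparison
```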
@@ -420,15 +420,15 @@ class ImageProcessor:
420
  def export_results(self, format_type="PNG"):
421
  """导出结果"""
422
  if self.last_results is None:
423
- return "No results to export"
424
 
425
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
426
 
427
- if format_type == "PDF Report":
428
- # Generate PDF report logic here
429
- return f"PDF report saved as saliency_report_{timestamp}.pdf"
430
  else:
431
- return f"Results exported as {format_type.lower()} files"
432
  # Create processor instance
433
  processor = ImageProcessor()
434
 
@@ -436,9 +436,9 @@ def run_demo(input_image, threshold, enhance_contrast, denoise, show_contours,
436
  brightness, contrast, filter_type, process_mode):
437
  """主处理函数"""
438
  if input_image is None:
439
- return [None] * 9 + ["Please upload an image"]
440
 
441
- # Process image
442
  results = processor.process_image(
443
  input_image,
444
  threshold=threshold/100.0,
@@ -452,7 +452,7 @@ def run_demo(input_image, threshold, enhance_contrast, denoise, show_contours,
452
 
453
  original, saliency_map, heatmap, overlayed, segmented, comparison, time_info, stats, analysis_plot = results
454
 
455
- # Add contours if requested
456
  if show_contours and saliency_map is not None and overlayed is not None:
457
  _, binary = cv2.threshold(saliency_map, 127, 255, cv2.THRESH_BINARY)
458
  contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
@@ -460,14 +460,14 @@ def run_demo(input_image, threshold, enhance_contrast, denoise, show_contours,
460
  cv2.drawContours(overlay_with_contours, contours, -1, (0,255,0), 2)
461
  overlayed = overlay_with_contours
462
 
463
- # Generate statistics HTML
464
  if stats:
465
  stats_html = "<div class='statistics-container'>"
466
  for key, value in stats.items():
467
  stats_html += f"<div class='statistic-card'><h4>{key}</h4><p>{value}</p></div>"
468
  stats_html += "</div>"
469
  else:
470
- stats_html = "<p>No statistics available</p>"
471
 
472
  return (original, saliency_map, heatmap, overlayed, segmented,
473
  comparison, time_info, stats_html, analysis_plot)
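The optional contour pass in run_demo thresholds the saliency map, extracts external contours, and draws them in green on a copy of the overlay. As a standalone sketch:

```python
# Sketch of the contour overlay step from run_demo.
import cv2

def draw_salient_contours(saliency_uint8, overlay_rgb):
    _, binary = cv2.threshold(saliency_uint8, 127, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    out = overlay_rgb.copy()
    cv2.drawContours(out, contours, -1, (0, 255, 0), 2)
    return out
```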
@@ -483,249 +483,249 @@ def create_comparison_view(original, result, slider_value):
483
  comparison = original.copy()
484
  comparison[:, split_point:] = result[:, split_point:]
485
 
486
- # Add vertical line
487
  cv2.line(comparison, (split_point, 0), (split_point, h), (255, 255, 0), 3)
488
 
489
  return comparison
490
 
491
  # Create Gradio interface
492
- with gr.Blocks(title="Advanced Saliency Object Detection System", css=custom_css) as demo:
493
  gr.Markdown(
494
  """
495
- # 🎯 Advanced Saliency Object Detection System
496
- ### AI-Powered Image Saliency Detection and Analysis Tool
497
  """
498
  )
499
 
500
  with gr.Tabs() as tabs:
501
- with gr.TabItem("🔍 Main Functions"):
502
  with gr.Row():
503
  with gr.Column(scale=1):
504
- # Input controls
505
  with gr.Group(elem_classes="advanced-controls"):
506
  input_image = gr.Image(
507
- label="Input Image",
508
  type="numpy",
509
  elem_classes="input-image"
510
  )
511
 
512
- # Processing mode selection
513
  process_mode = gr.Radio(
514
- choices=["Full Analysis", "Quick Mode"],
515
- value="Full Analysis",
516
- label="Processing Mode",
517
- info="Quick mode only outputs saliency map for faster processing"
518
  )
519
 
520
- with gr.Accordion("Basic Settings", open=True):
521
  threshold_slider = gr.Slider(
522
  minimum=0,
523
  maximum=100,
524
  value=50,
525
  step=1,
526
- label="Detection Threshold",
527
- info="Adjust detection sensitivity"
528
  )
529
  enhance_contrast = gr.Checkbox(
530
- label="Enhance Contrast",
531
  value=False
532
  )
533
  denoise = gr.Checkbox(
534
- label="Noise Reduction",
535
  value=False
536
  )
537
  show_contours = gr.Checkbox(
538
- label="Show Contours",
539
  value=True
540
  )
541
 
542
- with gr.Accordion("Image Adjustments", open=False):
543
  brightness = gr.Slider(
544
  minimum=-100,
545
  maximum=100,
546
  value=0,
547
  step=1,
548
- label="Brightness"
549
  )
550
  contrast = gr.Slider(
551
  minimum=-100,
552
  maximum=100,
553
  value=0,
554
  step=1,
555
- label="Contrast"
556
  )
557
  filter_type = gr.Radio(
558
- choices=["None", "Sharpen", "Blur", "Edge Enhancement"],
559
- value="None",
560
- label="Image Filter"
561
  )
562
 
563
- with gr.Accordion("Export Options", open=False):
564
  export_format = gr.Dropdown(
565
- choices=["PNG", "JPEG", "PDF Report"],
566
  value="PNG",
567
- label="Export Format"
568
  )
569
  export_btn = gr.Button(
570
- "Export Results",
571
  elem_classes="custom-button"
572
  )
573
 
574
  with gr.Row():
575
  submit_btn = gr.Button(
576
- "Start Detection",
577
  variant="primary",
578
  elem_classes="custom-button"
579
  )
580
  reset_btn = gr.Button(
581
- "Reset Parameters",
582
  elem_classes="custom-button"
583
  )
584
 
585
  with gr.Column(scale=2):
586
- # Results display
587
  with gr.Tabs():
588
- with gr.TabItem("Detection Results"):
589
  with gr.Row(elem_classes="result-container"):
590
  original_output = gr.Image(
591
- label="Original Image",
592
  elem_classes="output-image"
593
  )
594
  saliency_output = gr.Image(
595
- label="Saliency Map",
596
  elem_classes="output-image"
597
  )
598
 
599
  with gr.Row(elem_classes="result-container"):
600
  heatmap_output = gr.Image(
601
- label="Heatmap Analysis",
602
  elem_classes="output-image"
603
  )
604
  overlayed_output = gr.Image(
605
- label="Overlay Effect",
606
  elem_classes="output-image"
607
  )
608
 
609
  with gr.Row(elem_classes="result-container"):
610
  segmented_output = gr.Image(
611
- label="Object Segmentation",
612
  elem_classes="output-image"
613
  )
614
  comparison_output = gr.Image(
615
- label="Side-by-Side Comparison",
616
  elem_classes="output-image"
617
  )
618
 
619
- with gr.TabItem("Interactive Comparison"):
620
  with gr.Group(elem_classes="comparison-slider"):
621
  comparison_slider = gr.Slider(
622
  minimum=0,
623
  maximum=1,
624
  value=0.5,
625
  step=0.01,
626
- label="Original ← → Result",
627
- info="Drag to compare original and processed images"
628
  )
629
  interactive_comparison = gr.Image(
630
- label="Interactive Comparison View",
631
  elem_classes="interactive-viz"
632
  )
633
 
634
- with gr.TabItem("Analysis Report"):
635
  with gr.Group(elem_classes="result-container"):
636
  time_info = gr.Textbox(
637
- label="Processing Time",
638
  show_label=True
639
  )
640
  stats_output = gr.HTML(
641
- label="Statistical Information"
642
  )
643
  analysis_plot = gr.Image(
644
- label="Detailed Analysis Charts",
645
  elem_classes="output-image"
646
  )
647
 
648
- with gr.TabItem("📖 User Guide"):
649
  gr.Markdown(
650
  """
651
- ## Instructions
652
- 1. **Upload Image**: Click the "Input Image" area to upload your image
653
- 2. **Select Mode**: Choose between "Full Analysis" or "Quick Mode"
654
- - Full Analysis: Complete processing with all visualizations
655
- - Quick Mode: Fast processing, only outputs saliency map
656
- 3. **Adjust Parameters**:
657
- - Use threshold slider to adjust detection sensitivity
658
- - Enable contrast enhancement or noise reduction as needed
659
- - Fine-tune brightness, contrast, and filters in advanced settings
660
- 4. **Start Detection**: Click "Start Detection" to begin analysis
661
- 5. **View Results**: Check different tabs for various visualization results
662
- 6. **Export**: Use export options to save your results
663
 
664
- ## Features
665
- - **Saliency Map**: Shows importance distribution of image regions
666
- - **Heatmap**: Color-coded intensity visualization
667
- - **Overlay Effect**: Detection results overlaid on original image
668
- - **Object Segmentation**: Extracts key object regions
669
- - **Interactive Comparison**: Slide to compare original and processed images
670
- - **Analysis Report**: Detailed statistics and analysis charts
671
 
672
- ## Performance Tips
673
- - Use Quick Mode for faster processing when you only need saliency maps
674
- - Lower resolution images process faster
675
- - Enable GPU if available for better performance
676
  """
677
  )
678
 
679
- with gr.TabItem("ℹ️ About"):
680
  gr.Markdown(
681
  """
682
- ## Project Information
683
- - **Version**: 3.0.0
684
- - **Architecture**: PyTorch + Gradio
685
- - **Model**: CyueNet
686
- - **Language**: Multi-language support
687
 
688
- ## Key Features
689
- - Real-time image processing and analysis
690
- - Multi-dimensional result visualization
691
- - Rich image adjustment options
692
- - Detailed data analysis reports
693
- - Interactive comparison tools
694
- - Export functionality
695
- - Performance optimization with caching
696
 
697
- ## Update Log
698
- - ✅ Added Quick Mode for faster processing
699
- - ✅ Enhanced image preprocessing options
700
- - ✅ Added statistical analysis functions
701
- - ✅ Improved user interface with glassmorphism design
702
- - ✅ Added interactive comparison slider
703
- - ✅ Performance optimization with caching and threading
704
- - ✅ Multi-language chart support
705
- - ✅ Export functionality
706
 
707
- ## System Requirements
708
  - Python 3.8+
709
  - PyTorch 1.9+
710
- - CUDA (optional, for GPU acceleration)
711
- - 4GB+ RAM recommended
712
  """
713
  )
714
 
715
- # Event handlers
716
  def reset_params():
717
  return {
718
  threshold_slider: 50,
719
  brightness: 0,
720
  contrast: 0,
721
- filter_type: "None",
722
  enhance_contrast: False,
723
  denoise: False,
724
  show_contours: True,
725
- process_mode: "Full Analysis"
726
  }
727
 
728
- # Set up event handling
729
  submit_btn.click(
730
  fn=run_demo,
731
  inputs=[
@@ -767,25 +767,25 @@ with gr.Blocks(title="Advanced Saliency Object Detection System", css=custom_css
767
  ]
768
  )
769
 
770
- # Interactive comparison
771
  comparison_slider.change(
772
  fn=create_comparison_view,
773
  inputs=[original_output, overlayed_output, comparison_slider],
774
  outputs=[interactive_comparison]
775
  )
776
 
777
- # Export functionality
778
  export_btn.click(
779
  fn=processor.export_results,
780
  inputs=[export_format],
781
- outputs=[gr.Textbox(label="Export Status")]
782
  )
783
 
784
- # Launch the application
785
  if __name__ == "__main__":
786
  demo.launch(
787
  server_name="0.0.0.0",
788
  server_port=7860,
789
  share=True,
790
  show_error=True
791
- )
 
149
  self.model = MMS()
150
  try:
151
  self.model.load_state_dict(torch.load('models/CyueNet_EORSSD6.pth.54', map_location=device))
152
+ print("模型加载成功")
153
  except RuntimeError as e:
154
+ print(f"模型加载错误: {e}")
155
  except FileNotFoundError:
156
+ print("未找到模型文件,请检查路径。")
157
  self.model.to(device)
158
  self.model.eval()
159
 
 
178
 
179
  def apply_filters(self, image, filter_type):
180
  """应用图像滤镜效果"""
181
+ if filter_type == "锐化":
182
  kernel = np.array([[-1,-1,-1], [-1,9,-1], [-1,-1,-1]])
183
  return cv2.filter2D(image, -1, kernel)
184
+ elif filter_type == "模糊":
185
  return cv2.GaussianBlur(image, (5,5), 0)
186
+ elif filter_type == "边缘增强":
187
  kernel = np.array([[0,-1,0], [-1,5,-1], [0,-1,0]])
188
  return cv2.filter2D(image, -1, kernel)
189
  return image
190
 
191
  def generate_analysis_plots(self, saliency_map):
192
+ """生成分析图表"""
193
  plt.style.use('seaborn-v0_8')
194
  fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))
195
 
196
+ # 直方图
197
  ax1.hist(saliency_map.flatten(), bins=50, color='#2196F3', alpha=0.7, edgecolor='black')
198
+ ax1.set_title('显著性分布直方图', fontsize=12, pad=15)
199
+ ax1.set_xlabel('显著性值', fontsize=10)
200
+ ax1.set_ylabel('频率', fontsize=10)
201
  ax1.grid(True, alpha=0.3)
202
 
203
+ # 添加统计信息
204
  mean_val = np.mean(saliency_map)
205
  median_val = np.median(saliency_map)
206
+ ax1.axvline(mean_val, color='red', linestyle='--', alpha=0.7, label=f'均值: {mean_val:.3f}')
207
+ ax1.axvline(median_val, color='green', linestyle='--', alpha=0.7, label=f'中位数: {median_val:.3f}')
208
  ax1.legend()
209
 
210
+ # 累积分布
211
  sorted_vals = np.sort(saliency_map.flatten())
212
  cumulative = np.arange(1, len(sorted_vals) + 1) / len(sorted_vals)
213
  ax2.plot(sorted_vals, cumulative, color='#FF6B35', linewidth=2)
214
+ ax2.set_title('累积分布函数', fontsize=12)
215
+ ax2.set_xlabel('显著性值', fontsize=10)
216
+ ax2.set_ylabel('累积概率', fontsize=10)
217
  ax2.grid(True, alpha=0.3)
218
 
219
+ # 箱线图
220
  ax3.boxplot(saliency_map.flatten(), patch_artist=True,
221
  boxprops=dict(facecolor='#21CBF3', alpha=0.7))
222
+ ax3.set_title('显著性分布箱线图', fontsize=12)
223
+ ax3.set_ylabel('显著性值', fontsize=10)
224
  ax3.grid(True, alpha=0.3)
225
 
226
+ # 强度剖面(中心线)
227
  center_row = saliency_map[saliency_map.shape[0]//2, :]
228
  ax4.plot(center_row, color='#9C27B0', linewidth=2)
229
+ ax4.set_title('中心线强度剖面', fontsize=12)
230
+ ax4.set_xlabel('像素位置', fontsize=10)
231
+ ax4.set_ylabel('显著性值', fontsize=10)
232
  ax4.grid(True, alpha=0.3)
233
 
234
  plt.tight_layout()
235
 
236
+ # 保存为字节
237
  buf = io.BytesIO()
238
  plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
239
  buf.seek(0)
 
245
  def quick_process(self, image, threshold=0.5, testsize=256):
246
  """快速处理模式,只输出显著性图"""
247
  if image is None:
248
+ return None, "请提供有效的图像"
249
 
250
+ # 检查缓存
251
  image_hash = hash(image.tobytes())
252
  cache_key = f"{image_hash}_{threshold}_{testsize}_quick"
253
 
 
278
  res_resized = cv2.resize(res, (w, h))
279
  res_vis = (res_resized * 255).astype(np.uint8)
280
 
281
+ result = (res_vis, f"快速处理完成,耗时 {time_end - time_start:.3f}")
282
  self.cache[cache_key] = result
283
 
284
  return result
285
  def process_image(self, image, threshold=0.5, testsize=256,
286
  enhance_contrast=False, denoise=False,
287
+ brightness=0, contrast=0, filter_type="",
288
+ process_mode="完整分析"):
289
  """增强的图像处理函数"""
290
  if image is None:
291
+ return [None] * 9 + ["请提供有效的图像"]
292
 
293
+ # 快速模式检查
294
+ if process_mode == "快速模式":
295
  saliency_map, time_info = self.quick_process(image, threshold, testsize)
296
  return (image, saliency_map, None, None, None, None, time_info, None, None)
297
 
298
+ # 检查完整处理的缓存
299
  image_hash = hash(image.tobytes())
300
  cache_key = f"{image_hash}_{threshold}_{testsize}_{enhance_contrast}_{denoise}_{brightness}_{contrast}_{filter_type}_full"
301
 
302
  if cache_key in self.cache:
303
  return self.cache[cache_key]
304
 
305
+ # 使用线程进行图像预处理
306
  def preprocess_image():
307
  processed_image = image.copy()
308
 
 
328
 
329
  original_image = processed_image.copy()
330
 
331
+ # 模型推理
332
  image_pil = Image.fromarray(processed_image).convert('RGB')
333
  image_tensor = transform_image(image_pil, testsize)
334
  image_tensor = image_tensor.unsqueeze(0).to(device)
 
353
  h, w = original_image.shape[:2]
354
  res_resized = cv2.resize(res, (w, h))
355
 
356
+ # 生成可视化
357
  res_vis = (res_resized * 255).astype(np.uint8)
358
  heatmap = cv2.applyColorMap(res_vis, cv2.COLORMAP_JET)
359
  _, binary_mask = cv2.threshold(res_vis, int(255 * threshold), 255, cv2.THRESH_BINARY)
360
 
361
+ # 创建叠加效果
362
  alpha = 0.5
363
  original_bgr = cv2.cvtColor(original_image, cv2.COLOR_RGB2BGR)
364
  overlayed = cv2.addWeighted(original_bgr, 1-alpha, heatmap, alpha, 0)
365
  segmented = cv2.bitwise_and(original_bgr, original_bgr, mask=binary_mask)
366
 
367
+ # 转换回RGB
368
  overlayed_rgb = cv2.cvtColor(overlayed, cv2.COLOR_BGR2RGB)
369
  segmented_rgb = cv2.cvtColor(segmented, cv2.COLOR_BGR2RGB)
370
  heatmap_rgb = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
371
 
372
+ # 生成分析图表
373
  analysis_plot = self.generate_analysis_plots(res_resized)
374
 
375
+ # 计算统计信息
376
  contours = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
377
  total_area = w * h
378
  detected_area = cv2.countNonZero(binary_mask)
379
  coverage_ratio = detected_area / total_area
380
 
381
  stats = {
382
+ "处理分辨率": f"{w}x{h}",
383
+ "检测到对象数": str(len(contours)),
384
+ "平均置信度": f"{np.mean(res_resized):.2%}",
385
+ "最大置信度": f"{np.max(res_resized):.2%}",
386
+ "覆盖率": f"{coverage_ratio:.2%}",
387
+ "处理时间": f"{inference_time:.3f}"
388
  }
389
 
390
+ # 创建对比图像
391
  comparison_img = self.create_comparison_image(original_image, overlayed_rgb)
392
 
393
+ # 保存结果
394
  self.last_results = {
395
  'saliency_map': res_resized,
396
  'binary_mask': binary_mask,
 
398
  }
399
 
400
  result = (original_image, res_vis, heatmap_rgb, overlayed_rgb, segmented_rgb,
401
+ comparison_img, f"处理时间: {inference_time:.4f}", stats, analysis_plot)
402
 
403
+ # 缓存结果
404
  self.cache[cache_key] = result
405
 
406
  return result
 
412
  comparison[:, :w] = original
413
  comparison[:, w:] = processed
414
 
415
+ # 添加分界线
416
  cv2.line(comparison, (w, 0), (w, h), (255, 255, 255), 2)
417
 
418
  return comparison
 
420
  def export_results(self, format_type="PNG"):
421
  """导出结果"""
422
  if self.last_results is None:
423
+ return "没有结果可供导出"
424
 
425
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
426
 
427
+ if format_type == "PDF报告":
428
+ # 生成PDF报告逻辑
429
+ return f"PDF报告已保存为 saliency_report_{timestamp}.pdf"
430
  else:
431
+ return f"结果已导出为 {format_type.lower()} 文件"
432
  # Create processor instance
433
  processor = ImageProcessor()
434
 
 
436
  brightness, contrast, filter_type, process_mode):
437
  """主处理函数"""
438
  if input_image is None:
439
+ return [None] * 9 + ["请上传图像"]
440
 
441
+ # 处理图像
442
  results = processor.process_image(
443
  input_image,
444
  threshold=threshold/100.0,
 
452
 
453
  original, saliency_map, heatmap, overlayed, segmented, comparison, time_info, stats, analysis_plot = results
454
 
455
+ # 如果需要显示轮廓
456
  if show_contours and saliency_map is not None and overlayed is not None:
457
  _, binary = cv2.threshold(saliency_map, 127, 255, cv2.THRESH_BINARY)
458
  contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
 
460
  cv2.drawContours(overlay_with_contours, contours, -1, (0,255,0), 2)
461
  overlayed = overlay_with_contours
462
 
463
+ # 生成统计信息HTML
464
  if stats:
465
  stats_html = "<div class='statistics-container'>"
466
  for key, value in stats.items():
467
  stats_html += f"<div class='statistic-card'><h4>{key}</h4><p>{value}</p></div>"
468
  stats_html += "</div>"
469
  else:
470
+ stats_html = "<p>无可用统计信息</p>"
471
 
472
  return (original, saliency_map, heatmap, overlayed, segmented,
473
  comparison, time_info, stats_html, analysis_plot)
 
483
  comparison = original.copy()
484
  comparison[:, split_point:] = result[:, split_point:]
485
 
486
+ # 添加垂直线
487
  cv2.line(comparison, (split_point, 0), (split_point, h), (255, 255, 0), 3)
488
 
489
  return comparison
490
 
491
  # Create Gradio interface
492
+ with gr.Blocks(title="高级显著性对象检测系统", css=custom_css) as demo:
493
  gr.Markdown(
494
  """
495
+ # 🎯 高级显著性对象检测系统
496
+ ### AI驱动的图像显著性检测与分析工具
497
  """
498
  )
499
 
500
  with gr.Tabs() as tabs:
501
+ with gr.TabItem("🔍 主功能"):
502
  with gr.Row():
503
  with gr.Column(scale=1):
504
+ # 输入控件
505
  with gr.Group(elem_classes="advanced-controls"):
506
  input_image = gr.Image(
507
+ label="输入图像",
508
  type="numpy",
509
  elem_classes="input-image"
510
  )
511
 
512
+ # 处理模式选择
513
  process_mode = gr.Radio(
514
+ choices=["完整分析", "快速模式"],
515
+ value="完整分析",
516
+ label="处理模式",
517
+ info="快速模式仅输出显著性图,处理速度更快"
518
  )
519
 
520
+ with gr.Accordion("基本设置", open=True):
521
  threshold_slider = gr.Slider(
522
  minimum=0,
523
  maximum=100,
524
  value=50,
525
  step=1,
526
+ label="检测阈值",
527
+ info="调整检测灵敏度"
528
  )
529
  enhance_contrast = gr.Checkbox(
530
+ label="增强对比度",
531
  value=False
532
  )
533
  denoise = gr.Checkbox(
534
+ label="降噪",
535
  value=False
536
  )
537
  show_contours = gr.Checkbox(
538
+ label="显示轮廓",
539
  value=True
540
  )
541
 
542
+ with gr.Accordion("图像调整", open=False):
543
  brightness = gr.Slider(
544
  minimum=-100,
545
  maximum=100,
546
  value=0,
547
  step=1,
548
+ label="亮度"
549
  )
550
  contrast = gr.Slider(
551
  minimum=-100,
552
  maximum=100,
553
  value=0,
554
  step=1,
555
+ label="对比度"
556
  )
557
  filter_type = gr.Radio(
558
+ choices=["", "锐化", "模糊", "边缘增强"],
559
+ value="",
560
+ label="图像滤镜"
561
  )
562
 
563
+ with gr.Accordion("导出选项", open=False):
564
  export_format = gr.Dropdown(
565
+ choices=["PNG", "JPEG", "PDF报告"],
566
  value="PNG",
567
+ label="导出格式"
568
  )
569
  export_btn = gr.Button(
570
+ "导出结果",
571
  elem_classes="custom-button"
572
  )
573
 
574
  with gr.Row():
575
  submit_btn = gr.Button(
576
+ "开始检测",
577
  variant="primary",
578
  elem_classes="custom-button"
579
  )
580
  reset_btn = gr.Button(
581
+ "重置参数",
582
  elem_classes="custom-button"
583
  )
584
 
585
  with gr.Column(scale=2):
586
+ # 结果显示
587
  with gr.Tabs():
588
+ with gr.TabItem("检测结果"):
589
  with gr.Row(elem_classes="result-container"):
590
  original_output = gr.Image(
591
+ label="原始图像",
592
  elem_classes="output-image"
593
  )
594
  saliency_output = gr.Image(
595
+ label="显著性图",
596
  elem_classes="output-image"
597
  )
598
 
599
  with gr.Row(elem_classes="result-container"):
600
  heatmap_output = gr.Image(
601
+ label="热力图分析",
602
  elem_classes="output-image"
603
  )
604
  overlayed_output = gr.Image(
605
+ label="叠加效果",
606
  elem_classes="output-image"
607
  )
608
 
609
  with gr.Row(elem_classes="result-container"):
610
  segmented_output = gr.Image(
611
+ label="对象分割",
612
  elem_classes="output-image"
613
  )
614
  comparison_output = gr.Image(
615
+ label="并排对比",
616
  elem_classes="output-image"
617
  )
618
 
619
+ with gr.TabItem("交互式对比"):
620
  with gr.Group(elem_classes="comparison-slider"):
621
  comparison_slider = gr.Slider(
622
  minimum=0,
623
  maximum=1,
624
  value=0.5,
625
  step=0.01,
626
+ label="原始 ← → 结果",
627
+ info="拖动滑块对比原始图像和处理结果"
628
  )
629
  interactive_comparison = gr.Image(
630
+ label="交互式对比视图",
631
  elem_classes="interactive-viz"
632
  )
633
 
634
+ with gr.TabItem("分析报告"):
635
  with gr.Group(elem_classes="result-container"):
636
  time_info = gr.Textbox(
637
+ label="处理时间",
638
  show_label=True
639
  )
640
  stats_output = gr.HTML(
641
+ label="统计信息"
642
  )
643
  analysis_plot = gr.Image(
644
+ label="详细分析图表",
645
  elem_classes="output-image"
646
  )
647
 
648
+ with gr.TabItem("📖 用户指南"):
649
  gr.Markdown(
650
  """
651
+ ## 使用说明
652
+ 1. **上传图像**:点击"输入图像"区域上传您的图像
653
+ 2. **选择模式**:选择"完整分析"或"快速模式"
654
+ - 完整分析:完整处理流程,包含所有可视化结果
655
+ - 快速模式:快速处理,仅输出显著性图
656
+ 3. **调整参数**:
657
+ - 使用阈值滑块调整检测灵敏度
658
+ - 根据需要启用对比度增强或降噪
659
+ - 在高级设置中微调亮度、对比度和滤镜
660
+ 4. **开始检测**:点击"开始检测"按钮开始分析
661
+ 5. **查看结果**:在不同标签页查看各种可视化结果
662
+ 6. **导出**:使用导出选项保存您的结果
663
 
664
+ ## 功能特点
665
+ - **显著性图**:显示图像区域的显著性分布
666
+ - **热力图**:彩色编码的强度可视化
667
+ - **叠加效果**:在原始图像上叠加检测结果
668
+ - **对象分割**:提取关键对象区域
669
+ - **交互式对比**:滑动比较原始图像和处理结果
670
+ - **分析报告**:详细的统计信息和分析图表
671
 
672
+ ## 性能提示
673
+ - 当只需要显著性图时使用快速模式
674
+ - 分辨率较低的图像处理速度更快
675
+ - 启用GPU可获得更好的性能
676
  """
677
  )
678
 
679
+ with gr.TabItem("ℹ️ 关于"):
680
  gr.Markdown(
681
  """
682
+ ## 项目信息
683
+ - **版本**:3.0.0
684
+ - **架构**:PyTorch + Gradio
685
+ - **模型**:CyueNet
686
+ - **语言**:多语言支持
687
 
688
+ ## 主要特点
689
+ - 实时图像处理和分析
690
+ - 多维结果可视化
691
+ - 丰富的图像调整选项
692
+ - 详细的数据分析报告
693
+ - 交互式对比工具
694
+ - 导出功能
695
+ - 缓存优化性能
696
 
697
+ ## 更新日志
698
+ - ✅ 新增快速模式,提高处理速度
699
+ - ✅ 增强图像预处理选项
700
+ - ✅ 新增统计分析功能
701
+ - ✅ 改进用户界面,采用玻璃拟态设计
702
+ - ✅ 增加交互式对比滑块
703
+ - ✅ 使用缓存和线程优化性能
704
+ - ✅ 多语言图表支持
705
+ - ✅ 导出功能
706
 
707
+ ## 系统要求
708
  - Python 3.8+
709
  - PyTorch 1.9+
710
+ - CUDA(可选,用于GPU加速)
711
+ - 推荐4GB以上内存
712
  """
713
  )
714
 
715
+ # 事件处理
716
  def reset_params():
717
  return {
718
  threshold_slider: 50,
719
  brightness: 0,
720
  contrast: 0,
721
+ filter_type: "",
722
  enhance_contrast: False,
723
  denoise: False,
724
  show_contours: True,
725
+ process_mode: "完整分析"
726
  }
727
 
728
+ # 设置事件处理
729
  submit_btn.click(
730
  fn=run_demo,
731
  inputs=[
 
767
  ]
768
  )
769
 
770
+ # 交互式对比
771
  comparison_slider.change(
772
  fn=create_comparison_view,
773
  inputs=[original_output, overlayed_output, comparison_slider],
774
  outputs=[interactive_comparison]
775
  )
776
 
777
+ # 导出功能
778
  export_btn.click(
779
  fn=processor.export_results,
780
  inputs=[export_format],
781
+ outputs=[gr.Textbox(label="导出状态")]
782
  )
783
 
784
+ # 启动应用
785
  if __name__ == "__main__":
786
  demo.launch(
787
  server_name="0.0.0.0",
788
  server_port=7860,
789
  share=True,
790
  show_error=True
791
+ )
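For context, the Blocks/event wiring that the second half of this diff localizes follows the standard Gradio pattern. A stripped-down sketch; the `detect` function is a placeholder, not the real run_demo:

```python
# Minimal sketch of the Blocks + event-handler pattern used by the interface.
import gradio as gr

def detect(image, threshold):
    return image  # placeholder: the real app returns saliency visualizations

with gr.Blocks(title="Saliency demo (sketch)") as demo:
    inp = gr.Image(type="numpy", label="Input Image")
    thr = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Detection Threshold")
    out = gr.Image(label="Result")
    btn = gr.Button("Start Detection", variant="primary")
    btn.click(fn=detect, inputs=[inp, thr], outputs=[out])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
```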