kunkk commited on
Commit
944f056
·
verified ·
1 Parent(s): 699f6a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -13
app.py CHANGED
@@ -245,10 +245,9 @@ class ImageProcessor:
245
  return image
246
 
247
  def generate_analysis_plots(self, saliency_map):
248
- """生成分析图表 - 使用原始显著性图而非二值化"""
249
- plt.style.use('default') # 使用默认样式以确保打印清晰
250
 
251
- # 设置全局字体和颜色
252
  plt.rcParams.update({
253
  'font.size': 12,
254
  'font.weight': 'bold',
@@ -269,9 +268,9 @@ class ImageProcessor:
269
  # 直方图 - 使用原始显著性值
270
  ax1.hist(saliency_map.flatten(), bins=50, color='black', alpha=0.7,
271
  edgecolor='black', linewidth=1.5)
272
- ax1.set_title('Histogram of Saliency Distribution', fontsize=14,
273
  fontweight='bold', color='black', pad=15)
274
- ax1.set_xlabel('Saliency Value', fontsize=12, fontweight='bold', color='black')
275
  ax1.set_ylabel('Frequency', fontsize=12, fontweight='bold', color='black')
276
  ax1.grid(True, alpha=0.5, color='gray', linewidth=1)
277
  ax1.set_facecolor('white')
@@ -289,9 +288,9 @@ class ImageProcessor:
289
  sorted_vals = np.sort(saliency_map.flatten())
290
  cumulative = np.arange(1, len(sorted_vals) + 1) / len(sorted_vals)
291
  ax2.plot(sorted_vals, cumulative, color='black', linewidth=3)
292
- ax2.set_title('Cumulative Distribution Function', fontsize=14,
293
  fontweight='bold', color='black')
294
- ax2.set_xlabel('Saliency Value', fontsize=12, fontweight='bold', color='black')
295
  ax2.set_ylabel('Cumulative Probability', fontsize=12, fontweight='bold', color='black')
296
  ax2.grid(True, alpha=0.5, color='gray', linewidth=1)
297
  ax2.set_facecolor('white')
@@ -452,14 +451,22 @@ class ImageProcessor:
452
 
453
  # 确保转换为float32类型
454
  # 确保转换为float32类型
 
455
  res = res.to(torch.float32).sigmoid().cpu().numpy().squeeze()
456
- res = (res - res.min()) / (res.max() - res.min() + 1e-8)
 
 
 
 
 
457
 
458
  h, w = original_image.shape[:2]
459
- res_resized = cv2.resize(res, (w, h))
 
 
460
 
461
- # 生成可视化
462
- res_vis = (res_resized * 255).astype(np.uint8)
463
  heatmap = cv2.applyColorMap(res_vis, cv2.COLORMAP_JET)
464
  _, binary_mask = cv2.threshold(res_vis, int(255 * threshold), 255, cv2.THRESH_BINARY)
465
 
@@ -474,8 +481,8 @@ class ImageProcessor:
474
  segmented_rgb = cv2.cvtColor(segmented, cv2.COLOR_BGR2RGB)
475
  heatmap_rgb = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
476
 
477
- # 生成分析图表 - 使用原始的显著性值(未二值化)
478
- analysis_plot = self.generate_analysis_plots(res_resized) # 直接使用res_resized,而不是二值化后的mask
479
 
480
  # 生成分析图表
481
  analysis_plot = self.generate_analysis_plots(res_resized)
 
245
  return image
246
 
247
  def generate_analysis_plots(self, saliency_map):
248
+ """生成分析图表 - 使用原始显著性值"""
249
+ plt.style.use('default')
250
 
 
251
  plt.rcParams.update({
252
  'font.size': 12,
253
  'font.weight': 'bold',
 
268
  # 直方图 - 使用原始显著性值
269
  ax1.hist(saliency_map.flatten(), bins=50, color='black', alpha=0.7,
270
  edgecolor='black', linewidth=1.5)
271
+ ax1.set_title('Distribution of Raw Saliency Values', fontsize=14,
272
  fontweight='bold', color='black', pad=15)
273
+ ax1.set_xlabel('Raw Saliency Value', fontsize=12, fontweight='bold', color='black')
274
  ax1.set_ylabel('Frequency', fontsize=12, fontweight='bold', color='black')
275
  ax1.grid(True, alpha=0.5, color='gray', linewidth=1)
276
  ax1.set_facecolor('white')
 
288
  sorted_vals = np.sort(saliency_map.flatten())
289
  cumulative = np.arange(1, len(sorted_vals) + 1) / len(sorted_vals)
290
  ax2.plot(sorted_vals, cumulative, color='black', linewidth=3)
291
+ ax2.set_title('Cumulative Distribution of Raw Saliency', fontsize=14,
292
  fontweight='bold', color='black')
293
+ ax2.set_xlabel('Raw Saliency Value', fontsize=12, fontweight='bold', color='black')
294
  ax2.set_ylabel('Cumulative Probability', fontsize=12, fontweight='bold', color='black')
295
  ax2.grid(True, alpha=0.5, color='gray', linewidth=1)
296
  ax2.set_facecolor('white')
 
451
 
452
  # 确保转换为float32类型
453
  # 确保转换为float32类型
454
+ # 确保转换为float32类型并保持原始显著性值
455
  res = res.to(torch.float32).sigmoid().cpu().numpy().squeeze()
456
+
457
+ # 存储原始显著性值(未归一化)用于分析
458
+ original_saliency = res.copy()
459
+
460
+ # 归一化用于可视化
461
+ res_normalized = (res - res.min()) / (res.max() - res.min() + 1e-8)
462
 
463
  h, w = original_image.shape[:2]
464
+ # 调整大小但保持原始值范围
465
+ res_resized_original = cv2.resize(original_saliency, (w, h))
466
+ res_resized_normalized = cv2.resize(res_normalized, (w, h))
467
 
468
+ # 生成可视化用的图像
469
+ res_vis = (res_resized_normalized * 255).astype(np.uint8)
470
  heatmap = cv2.applyColorMap(res_vis, cv2.COLORMAP_JET)
471
  _, binary_mask = cv2.threshold(res_vis, int(255 * threshold), 255, cv2.THRESH_BINARY)
472
 
 
481
  segmented_rgb = cv2.cvtColor(segmented, cv2.COLOR_BGR2RGB)
482
  heatmap_rgb = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
483
 
484
+ # 使用原始显著性值生成分析图表
485
+ analysis_plot = self.generate_analysis_plots(res_resized_original) # 使用未归一化的原始值
486
 
487
  # 生成分析图表
488
  analysis_plot = self.generate_analysis_plots(res_resized)