kunkk commited on
Commit
f356a44
·
verified ·
1 Parent(s): 6e4f6fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -26
app.py CHANGED
@@ -386,26 +386,26 @@ class ImageProcessor:
386
 
387
 
388
  def process_image(self, image, threshold=0.5, testsize=256,
389
- enhance_contrast=False, denoise=False,
390
- brightness=0, contrast=0, filter_type="无",
391
- process_mode="完整分析"):
392
- """增强的图像处理函数"""
393
- if image is None:
394
- return [None] * 9 + ["请提供有效的图像"]
395
-
396
- # 快速模式检查
397
- if process_mode == "快速模式":
398
- saliency_map, time_info = self.quick_process(image, threshold, testsize)
399
- return (image, saliency_map, None, None, None, None, time_info, None, None)
400
-
401
- # 检查完整处理的缓存
402
- image_hash = hash(image.tobytes())
403
- cache_key = f"{image_hash}_{threshold}_{testsize}_{enhance_contrast}_{denoise}_{brightness}_{contrast}_{filter_type}_full"
404
-
405
- if cache_key in self.cache:
406
- return self.cache[cache_key]
407
 
408
- # 使用线程进行图像预处理
409
  def preprocess_image():
410
  processed_image = image.copy()
411
 
@@ -424,11 +424,11 @@ class ImageProcessor:
424
  processed_image = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
425
 
426
  return processed_image
427
-
428
  with concurrent.futures.ThreadPoolExecutor() as executor:
429
  future_preprocess = executor.submit(preprocess_image)
430
  processed_image = future_preprocess.result()
431
-
432
  original_image = processed_image.copy()
433
 
434
  # 模型推理
@@ -448,7 +448,7 @@ class ImageProcessor:
448
 
449
  time_end = time.time()
450
  inference_time = time_end - time_start
451
-
452
  # 确保转换为float32类型并保持原始显著性值
453
  res = res.to(torch.float32).sigmoid().cpu().numpy().squeeze()
454
 
@@ -480,7 +480,7 @@ class ImageProcessor:
480
  heatmap_rgb = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
481
 
482
  # 使用原始显著性值生成分析图表
483
- analysis_plot = self.generate_analysis_plots(res_resized_original) # 使用未归一化的原始值
484
 
485
  # 计算统计信息
486
  contours = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
@@ -491,8 +491,8 @@ class ImageProcessor:
491
  stats = {
492
  "处理分辨率": f"{w}x{h}",
493
  "检测到对象数": str(len(contours)),
494
- "平均置信度": f"{np.mean(res_resized_original):.2%}", # 使用原始值
495
- "最大置信度": f"{np.max(res_resized_original):.2%}", # 使用原始值
496
  "覆盖率": f"{coverage_ratio:.2%}",
497
  "处理时间": f"{inference_time:.3f}秒"
498
  }
@@ -502,7 +502,7 @@ class ImageProcessor:
502
 
503
  # 保存结果
504
  self.last_results = {
505
- 'saliency_map': res_resized,
506
  'binary_mask': binary_mask,
507
  'stats': stats
508
  }
 
386
 
387
 
388
  def process_image(self, image, threshold=0.5, testsize=256,
389
+ enhance_contrast=False, denoise=False,
390
+ brightness=0, contrast=0, filter_type="无",
391
+ process_mode="完整分析"):
392
+ """增强的图像处理函数"""
393
+ if image is None:
394
+ return [None] * 9 + ["请提供有效的图像"]
395
+
396
+ # 快速模式检查
397
+ if process_mode == "快速模式":
398
+ saliency_map, time_info = self.quick_process(image, threshold, testsize)
399
+ return (image, saliency_map, None, None, None, None, time_info, None, None)
400
+
401
+ # 检查完整处理的缓存
402
+ image_hash = hash(image.tobytes())
403
+ cache_key = f"{image_hash}_{threshold}_{testsize}_{enhance_contrast}_{denoise}_{brightness}_{contrast}_{filter_type}_full"
404
+
405
+ if cache_key in self.cache:
406
+ return self.cache[cache_key]
407
 
408
+ # 使用线程进行图像预处理
409
  def preprocess_image():
410
  processed_image = image.copy()
411
 
 
424
  processed_image = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
425
 
426
  return processed_image
427
+
428
  with concurrent.futures.ThreadPoolExecutor() as executor:
429
  future_preprocess = executor.submit(preprocess_image)
430
  processed_image = future_preprocess.result()
431
+
432
  original_image = processed_image.copy()
433
 
434
  # 模型推理
 
448
 
449
  time_end = time.time()
450
  inference_time = time_end - time_start
451
+
452
  # 确保转换为float32类型并保持原始显著性值
453
  res = res.to(torch.float32).sigmoid().cpu().numpy().squeeze()
454
 
 
480
  heatmap_rgb = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
481
 
482
  # 使用原始显著性值生成分析图表
483
+ analysis_plot = self.generate_analysis_plots(res_resized_original)
484
 
485
  # 计算统计信息
486
  contours = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
 
491
  stats = {
492
  "处理分辨率": f"{w}x{h}",
493
  "检测到对象数": str(len(contours)),
494
+ "平均置信度": f"{np.mean(res_resized_original):.2%}",
495
+ "最大置信度": f"{np.max(res_resized_original):.2%}",
496
  "覆盖率": f"{coverage_ratio:.2%}",
497
  "处理时间": f"{inference_time:.3f}秒"
498
  }
 
502
 
503
  # 保存结果
504
  self.last_results = {
505
+ 'saliency_map': res_resized_original, # 使用原始值
506
  'binary_mask': binary_mask,
507
  'stats': stats
508
  }