kunkk committed on
Commit 38a20c8 · verified · 1 Parent(s): a0fd9c9

Update app.py

Files changed (1)
  1. app.py +119 -119
app.py CHANGED
@@ -390,130 +390,130 @@ class ImageProcessor:
                        brightness=0, contrast=0, filter_type="无",
                        process_mode="完整分析"):

-        if image is None:
-            return [None] * 9 + ["请提供有效的图像"]
-
-        # Quick mode check
-        if process_mode == "快速模式":
-            saliency_map, time_info = self.quick_process(image, threshold, testsize)
-            return (image, saliency_map, None, None, None, None, time_info, None, None)
-
-        # Check the cache for a full-analysis result
-        image_hash = hash(image.tobytes())
-        cache_key = f"{image_hash}_{threshold}_{testsize}_{enhance_contrast}_{denoise}_{brightness}_{contrast}_{filter_type}_full"
-
-        if cache_key in self.cache:
-            return self.cache[cache_key]
-
-        # Run image preprocessing in a worker thread
-        def preprocess_image():
-            processed_image = image.copy()
-
-            if denoise:
-                processed_image = cv2.fastNlMeansDenoisingColored(processed_image, None, 10, 10, 7, 21)
-
-            processed_image = self.adjust_brightness_contrast(processed_image, brightness, contrast)
-            processed_image = self.apply_filters(processed_image, filter_type)
-
-            if enhance_contrast:
-                lab = cv2.cvtColor(processed_image, cv2.COLOR_RGB2LAB)
-                l, a, b = cv2.split(lab)
-                clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
-                l = clahe.apply(l)
-                lab = cv2.merge((l,a,b))
-                processed_image = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
-
-            return processed_image
-
-        with concurrent.futures.ThreadPoolExecutor() as executor:
-            future_preprocess = executor.submit(preprocess_image)
-            processed_image = future_preprocess.result()
-
-        original_image = processed_image.copy()
-
-        # Model inference
-        image_pil = Image.fromarray(processed_image).convert('RGB')
-        image_tensor = transform_image(image_pil, testsize)
-        image_tensor = image_tensor.unsqueeze(0).to(device)
-
-        time_start = time.time()
-
-        with torch.no_grad():
-            if device.type == 'cuda':
-                with torch.cuda.amp.autocast():
-                    x1, res, s1_sig, edg1, edg_s, s2, e2, s2_sig, e2_sig, s3, e3, s3_sig, e3_sig, s4, e4, s4_sig, e4_sig, s5, e5, s5_sig, e5_sig, sk1, sk1_sig, sk2, sk2_sig, sk3, sk3_sig, sk4, sk4_sig, sk5, sk5_sig = self.model(image_tensor)
-            else:
-                with torch.amp.autocast(device_type='cpu'):
-                    x1, res, s1_sig, edg1, edg_s, s2, e2, s2_sig, e2_sig, s3, e3, s3_sig, e3_sig, s4, e4, s4_sig, e4_sig, s5, e5, s5_sig, e5_sig, sk1, sk1_sig, sk2, sk2_sig, sk3, sk3_sig, sk4, sk4_sig, sk5, sk5_sig = self.model(image_tensor)
-
-        time_end = time.time()
-        inference_time = time_end - time_start
-
-        # Make sure the output is float32 and keep the raw saliency values
-        res = res.to(torch.float32).sigmoid().cpu().numpy().squeeze()
-
-        # Store the raw (unnormalized) saliency values for analysis
-        original_saliency = res.copy()
-
-        # Normalize for visualization
-        res_normalized = (res - res.min()) / (res.max() - res.min() + 1e-8)
-
-        h, w = original_image.shape[:2]
-        # Resize while keeping the original value range
-        res_resized_original = cv2.resize(original_saliency, (w, h))
-        res_resized_normalized = cv2.resize(res_normalized, (w, h))
-
-        # Generate the visualization images
-        res_vis = (res_resized_normalized * 255).astype(np.uint8)
-        heatmap = cv2.applyColorMap(res_vis, cv2.COLORMAP_JET)
-        _, binary_mask = cv2.threshold(res_vis, int(255 * threshold), 255, cv2.THRESH_BINARY)
-
-        # Create the overlay
-        alpha = 0.5
-        original_bgr = cv2.cvtColor(original_image, cv2.COLOR_RGB2BGR)
-        overlayed = cv2.addWeighted(original_bgr, 1-alpha, heatmap, alpha, 0)
-        segmented = cv2.bitwise_and(original_bgr, original_bgr, mask=binary_mask)
-
-        # Convert back to RGB
-        overlayed_rgb = cv2.cvtColor(overlayed, cv2.COLOR_BGR2RGB)
-        segmented_rgb = cv2.cvtColor(segmented, cv2.COLOR_BGR2RGB)
-        heatmap_rgb = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
-
-        # Generate analysis plots from the raw saliency values
-        analysis_plot = self.generate_analysis_plots(res_resized_original)
-
-        # Compute statistics
-        contours = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
-        total_area = w * h
-        detected_area = cv2.countNonZero(binary_mask)
-        coverage_ratio = detected_area / total_area
-
-        stats = {
-            "处理分辨率": f"{w}x{h}",
-            "检测到对象数": str(len(contours)),
-            "平均置信度": f"{np.mean(res_resized_original):.2%}",
-            "最大置信度": f"{np.max(res_resized_original):.2%}",
-            "覆盖率": f"{coverage_ratio:.2%}",
-            "处理时间": f"{inference_time:.3f}秒"
-        }
-
-        # Create the comparison image
-        comparison_img = self.create_comparison_image(original_image, overlayed_rgb)
-
-        # Save the results
-        self.last_results = {
-            'saliency_map': res_resized_original,  # use the raw values
-            'binary_mask': binary_mask,
-            'stats': stats
-        }
-
-        result = (original_image, res_vis, heatmap_rgb, overlayed_rgb, segmented_rgb,
-                  comparison_img, f"处理时间: {inference_time:.4f}秒", stats, analysis_plot)
-
-        # Cache the result
-        self.cache[cache_key] = result
-
-        return result
+        if image is None:
+            return [None] * 9 + ["请提供有效的图像"]
+
+        # Quick mode check
+        if process_mode == "快速模式":
+            saliency_map, time_info = self.quick_process(image, threshold, testsize)
+            return (image, saliency_map, None, None, None, None, time_info, None, None)
+
+        # Check the cache for a full-analysis result
+        image_hash = hash(image.tobytes())
+        cache_key = f"{image_hash}_{threshold}_{testsize}_{enhance_contrast}_{denoise}_{brightness}_{contrast}_{filter_type}_full"
+
+        if cache_key in self.cache:
+            return self.cache[cache_key]
+
+        # Run image preprocessing in a worker thread
+        def preprocess_image():
+            processed_image = image.copy()
+
+            if denoise:
+                processed_image = cv2.fastNlMeansDenoisingColored(processed_image, None, 10, 10, 7, 21)
+
+            processed_image = self.adjust_brightness_contrast(processed_image, brightness, contrast)
+            processed_image = self.apply_filters(processed_image, filter_type)
+
+            if enhance_contrast:
+                lab = cv2.cvtColor(processed_image, cv2.COLOR_RGB2LAB)
+                l, a, b = cv2.split(lab)
+                clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
+                l = clahe.apply(l)
+                lab = cv2.merge((l,a,b))
+                processed_image = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
+
+            return processed_image
+
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            future_preprocess = executor.submit(preprocess_image)
+            processed_image = future_preprocess.result()
+
+        original_image = processed_image.copy()
+
+        # Model inference
+        image_pil = Image.fromarray(processed_image).convert('RGB')
+        image_tensor = transform_image(image_pil, testsize)
+        image_tensor = image_tensor.unsqueeze(0).to(device)
+
+        time_start = time.time()
+
+        with torch.no_grad():
+            if device.type == 'cuda':
+                with torch.cuda.amp.autocast():
+                    x1, res, s1_sig, edg1, edg_s, s2, e2, s2_sig, e2_sig, s3, e3, s3_sig, e3_sig, s4, e4, s4_sig, e4_sig, s5, e5, s5_sig, e5_sig, sk1, sk1_sig, sk2, sk2_sig, sk3, sk3_sig, sk4, sk4_sig, sk5, sk5_sig = self.model(image_tensor)
+            else:
+                with torch.amp.autocast(device_type='cpu'):
+                    x1, res, s1_sig, edg1, edg_s, s2, e2, s2_sig, e2_sig, s3, e3, s3_sig, e3_sig, s4, e4, s4_sig, e4_sig, s5, e5, s5_sig, e5_sig, sk1, sk1_sig, sk2, sk2_sig, sk3, sk3_sig, sk4, sk4_sig, sk5, sk5_sig = self.model(image_tensor)
+
+        time_end = time.time()
+        inference_time = time_end - time_start
+
+        # Make sure the output is float32 and keep the raw saliency values
+        res = res.to(torch.float32).sigmoid().cpu().numpy().squeeze()
+
+        # Store the raw (unnormalized) saliency values for analysis
+        original_saliency = res.copy()
+
+        # Normalize for visualization
+        res_normalized = (res - res.min()) / (res.max() - res.min() + 1e-8)
+
+        h, w = original_image.shape[:2]
+        # Resize while keeping the original value range
+        res_resized_original = cv2.resize(original_saliency, (w, h))
+        res_resized_normalized = cv2.resize(res_normalized, (w, h))
+
+        # Generate the visualization images
+        res_vis = (res_resized_normalized * 255).astype(np.uint8)
+        heatmap = cv2.applyColorMap(res_vis, cv2.COLORMAP_JET)
+        _, binary_mask = cv2.threshold(res_vis, int(255 * threshold), 255, cv2.THRESH_BINARY)
+
+        # Create the overlay
+        alpha = 0.5
+        original_bgr = cv2.cvtColor(original_image, cv2.COLOR_RGB2BGR)
+        overlayed = cv2.addWeighted(original_bgr, 1-alpha, heatmap, alpha, 0)
+        segmented = cv2.bitwise_and(original_bgr, original_bgr, mask=binary_mask)
+
+        # Convert back to RGB
+        overlayed_rgb = cv2.cvtColor(overlayed, cv2.COLOR_BGR2RGB)
+        segmented_rgb = cv2.cvtColor(segmented, cv2.COLOR_BGR2RGB)
+        heatmap_rgb = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
+
+        # Generate analysis plots from the raw saliency values
+        analysis_plot = self.generate_analysis_plots(res_resized_original)
+
+        # Compute statistics
+        contours = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
+        total_area = w * h
+        detected_area = cv2.countNonZero(binary_mask)
+        coverage_ratio = detected_area / total_area
+
+        stats = {
+            "处理分辨率": f"{w}x{h}",
+            "检测到对象数": str(len(contours)),
+            "平均置信度": f"{np.mean(res_resized_original):.2%}",
+            "最大置信度": f"{np.max(res_resized_original):.2%}",
+            "覆盖率": f"{coverage_ratio:.2%}",
+            "处理时间": f"{inference_time:.3f}秒"
+        }
+
+        # Create the comparison image
+        comparison_img = self.create_comparison_image(original_image, overlayed_rgb)
+
+        # Save the results
+        self.last_results = {
+            'saliency_map': res_resized_original,  # use the raw values
+            'binary_mask': binary_mask,
+            'stats': stats
+        }
+
+        result = (original_image, res_vis, heatmap_rgb, overlayed_rgb, segmented_rgb,
+                  comparison_img, f"处理时间: {inference_time:.4f}秒", stats, analysis_plot)
+
+        # Cache the result
+        self.cache[cache_key] = result
+
+        return result

     def create_comparison_image(self, original, processed):
         """创建对比图像"""