kunkk commited on
Commit
9e416cb
·
verified ·
1 Parent(s): 8b81b1e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +358 -1
app.py CHANGED
@@ -137,6 +137,7 @@ class ImageProcessor:
137
  def __init__(self):
138
  self.model = None
139
  self.load_model()
 
140
 
141
  def load_model(self):
142
  """加载预训练的模型"""
@@ -175,4 +176,360 @@ class ImageProcessor:
175
  return cv2.filter2D(image, -1, kernel)
176
  elif filter_type == "柔化":
177
  return cv2.GaussianBlur(image, (5,5), 0)
178
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  def __init__(self):
138
  self.model = None
139
  self.load_model()
140
+ self.last_results = None
141
 
142
  def load_model(self):
143
  """加载预训练的模型"""
 
176
  return cv2.filter2D(image, -1, kernel)
177
  elif filter_type == "柔化":
178
  return cv2.GaussianBlur(image, (5,5), 0)
179
+ return image
180
+
181
+ def generate_analysis_plots(self, saliency_map):
182
+ """生成分析图表"""
183
+ # 直方图数据
184
+ hist_data = saliency_map.flatten()
185
+ fig_hist = px.histogram(hist_data, nbins=50,
186
+ title="显著性分布直方图",
187
+ labels={'value': '显著性值', 'count': '频率'})
188
+
189
+ # 计算显著性统计
190
+ regions = np.zeros_like(saliency_map)
191
+ regions[saliency_map > np.mean(saliency_map)] = 1
192
+
193
+ return fig_hist
194
+
195
+ def process_image(self, image, threshold=0.5, testsize=256,
196
+ enhance_contrast=False, denoise=False,
197
+ brightness=0, contrast=0, filter_type="无"):
198
+ """增强的图像处理函数"""
199
+ if image is None:
200
+ return [None] * 7 + ["请提供有效的图像"]
201
+
202
+ # 图像预处理
203
+ if denoise:
204
+ image = cv2.fastNlMeansDenoisingColored(image, None, 10, 10, 7, 21)
205
+
206
+ # 应用亮度和对比度调整
207
+ image = self.adjust_brightness_contrast(image, brightness, contrast)
208
+
209
+ # 应用滤镜
210
+ image = self.apply_filters(image, filter_type)
211
+
212
+ if enhance_contrast:
213
+ lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
214
+ l, a, b = cv2.split(lab)
215
+ clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
216
+ l = clahe.apply(l)
217
+ lab = cv2.merge((l,a,b))
218
+ image = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
219
+
220
+ # 保存原始图像
221
+ original_image = image.copy()
222
+
223
+ # 预处理图像
224
+ image_pil = Image.fromarray(image).convert('RGB')
225
+ image_tensor = transform_image(image_pil, testsize)
226
+ image_tensor = image_tensor.unsqueeze(0)
227
+ image_tensor = image_tensor.to(device)
228
+
229
+ # 计时
230
+ time_start = time.time()
231
+
232
+ # 推理
233
+ with torch.no_grad():
234
+ x1, res, s1_sig, edg1, edg_s, s2, e2, s2_sig, e2_sig, s3, e3, s3_sig, e3_sig, s4, e4, s4_sig, e4_sig, s5, e5, s5_sig, e5_sig, sk1, sk1_sig, sk2, sk2_sig, sk3, sk3_sig, sk4, sk4_sig, sk5, sk5_sig = self.model(image_tensor)
235
+
236
+ time_end = time.time()
237
+ inference_time = time_end - time_start
238
+
239
+ # 处理输出结果
240
+ res = res.sigmoid().data.cpu().numpy().squeeze()
241
+ res = (res - res.min()) / (res.max() - res.min() + 1e-8)
242
+
243
+ # 调整大小
244
+ h, w = original_image.shape[:2]
245
+ res_resized = cv2.resize(res, (w, h))
246
+
247
+ # 生成可视化结果
248
+ res_vis = (res_resized * 255).astype(np.uint8)
249
+ heatmap = cv2.applyColorMap(res_vis, cv2.COLORMAP_JET)
250
+ _, binary_mask = cv2.threshold(res_vis, int(255 * threshold), 255, cv2.THRESH_BINARY)
251
+
252
+ # 叠加结果
253
+ alpha = 0.5
254
+ original_bgr = cv2.cvtColor(original_image, cv2.COLOR_RGB2BGR)
255
+ overlayed = cv2.addWeighted(original_bgr, 1-alpha, heatmap, alpha, 0)
256
+ segmented = cv2.bitwise_and(original_bgr, original_bgr, mask=binary_mask)
257
+
258
+ # 转回RGB格式
259
+ overlayed_rgb = cv2.cvtColor(overlayed, cv2.COLOR_BGR2RGB)
260
+ segmented_rgb = cv2.cvtColor(segmented, cv2.COLOR_BGR2RGB)
261
+
262
+ # 生成分析图表
263
+ analysis_plot = self.generate_analysis_plots(res_resized)
264
+
265
+ # 计算统计信息
266
+ stats = {
267
+ "处理分辨率": f"{w}x{h}",
268
+ "检测目标数量": str(len(cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0])),
269
+ "平均置信度": f"{np.mean(res_resized):.2%}",
270
+ "最大置信度": f"{np.max(res_resized):.2%}",
271
+ "处理时间": f"{inference_time:.3f}秒"
272
+ }
273
+
274
+ # 保存结果供后续分析
275
+ self.last_results = {
276
+ 'saliency_map': res_resized,
277
+ 'binary_mask': binary_mask,
278
+ 'stats': stats
279
+ }
280
+
281
+ return (original_image, res_vis, heatmap, overlayed_rgb, segmented_rgb,
282
+ f"推理时间: {inference_time:.4f}秒", stats, analysis_plot)
283
+ # 创建处理器实例
284
+ processor = ImageProcessor()
285
+
286
+ def run_demo(input_image, threshold, enhance_contrast, denoise, show_contours,
287
+ brightness, contrast, filter_type):
288
+ """增强的主处理函数"""
289
+ if input_image is None:
290
+ return [None] * 8 + ["请上传图片"]
291
+
292
+ # 处理图像
293
+ results = processor.process_image(
294
+ input_image,
295
+ threshold=threshold/100.0,
296
+ enhance_contrast=enhance_contrast,
297
+ denoise=denoise,
298
+ brightness=brightness,
299
+ contrast=contrast,
300
+ filter_type=filter_type
301
+ )
302
+
303
+ original, saliency_map, heatmap, overlayed, segmented, time_info, stats, analysis_plot = results
304
+
305
+ # 添加轮廓显示
306
+ if show_contours and saliency_map is not None:
307
+ _, binary = cv2.threshold(saliency_map, 127, 255, cv2.THRESH_BINARY)
308
+ contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
309
+ overlay_with_contours = overlayed.copy()
310
+ cv2.drawContours(overlay_with_contours, contours, -1, (0,255,0), 2)
311
+ overlayed = overlay_with_contours
312
+
313
+ # 生成统计信息HTML
314
+ stats_html = "<div class='statistics-container'>"
315
+ for key, value in stats.items():
316
+ stats_html += f"<div class='statistic-card'><h4>{key}</h4><p>{value}</p></div>"
317
+ stats_html += "</div>"
318
+
319
+ return (original, saliency_map, heatmap, overlayed, segmented,
320
+ time_info, stats_html, analysis_plot)
321
+
322
+ # 创建Gradio界面
323
+ with gr.Blocks(title="高级显著性目标检测系统", css=custom_css) as demo:
324
+ gr.Markdown(
325
+ """
326
+ # 🎯 智能显著性目标检测系统
327
+ ### 基于深度学习的图像显著性检测与分析工具
328
+ """
329
+ )
330
+
331
+ with gr.Tabs() as tabs:
332
+ with gr.TabItem("🔍 主要功能"):
333
+ with gr.Row():
334
+ with gr.Column(scale=1):
335
+ # 输入控制面板
336
+ with gr.Group(elem_classes="control-panel"):
337
+ input_image = gr.Image(
338
+ label="输入图像",
339
+ type="numpy",
340
+ elem_classes="input-image"
341
+ )
342
+
343
+ with gr.Accordion("基础设置", open=True):
344
+ threshold_slider = gr.Slider(
345
+ minimum=0,
346
+ maximum=100,
347
+ value=50,
348
+ step=1,
349
+ label="检测阈值",
350
+ info="调整检测的灵敏度"
351
+ )
352
+ enhance_contrast = gr.Checkbox(
353
+ label="增强对比度",
354
+ value=False
355
+ )
356
+ denoise = gr.Checkbox(
357
+ label="降噪处理",
358
+ value=False
359
+ )
360
+ show_contours = gr.Checkbox(
361
+ label="显示轮廓",
362
+ value=True
363
+ )
364
+
365
+ with gr.Accordion("图像调整", open=False):
366
+ brightness = gr.Slider(
367
+ minimum=-100,
368
+ maximum=100,
369
+ value=0,
370
+ step=1,
371
+ label="亮度"
372
+ )
373
+ contrast = gr.Slider(
374
+ minimum=-100,
375
+ maximum=100,
376
+ value=0,
377
+ step=1,
378
+ label="对比度"
379
+ )
380
+ filter_type = gr.Radio(
381
+ choices=["无", "锐化", "柔化"],
382
+ value="无",
383
+ label="图像滤镜"
384
+ )
385
+
386
+ with gr.Row():
387
+ submit_btn = gr.Button(
388
+ "开始检测",
389
+ variant="primary",
390
+ elem_classes="custom-button"
391
+ )
392
+ reset_btn = gr.Button(
393
+ "重置参数",
394
+ elem_classes="custom-button"
395
+ )
396
+
397
+ with gr.Column(scale=2):
398
+ # 结果显示区域
399
+ with gr.Tabs():
400
+ with gr.TabItem("检测结果"):
401
+ with gr.Row(elem_classes="result-gallery"):
402
+ original_output = gr.Image(
403
+ label="原始图像",
404
+ elem_classes="output-image"
405
+ )
406
+ saliency_output = gr.Image(
407
+ label="显著性图",
408
+ elem_classes="output-image"
409
+ )
410
+
411
+ with gr.Row(elem_classes="result-gallery"):
412
+ heatmap_output = gr.Image(
413
+ label="热力图分析",
414
+ elem_classes="output-image"
415
+ )
416
+ overlayed_output = gr.Image(
417
+ label="叠加效果",
418
+ elem_classes="output-image"
419
+ )
420
+
421
+ segmented_output = gr.Image(
422
+ label="目标分割",
423
+ elem_classes="output-image"
424
+ )
425
+
426
+ with gr.TabItem("分析报告"):
427
+ with gr.Group(elem_classes="info-box"):
428
+ time_info = gr.Textbox(
429
+ label="处理时间",
430
+ show_label=True
431
+ )
432
+ stats_output = gr.HTML(
433
+ label="统计信息"
434
+ )
435
+ analysis_plot = gr.Plot(
436
+ label="显著性分布分析"
437
+ )
438
+
439
+ with gr.TabItem("📖 使用指南"):
440
+ gr.Markdown(
441
+ """
442
+ ## 使用说明
443
+ 1. **上传图片**: 点击左侧"输入图像"区域上传待分析的图片
444
+ 2. **调整参数**:
445
+ - 使用阈值滑块调整检测的灵敏度
446
+ - 可选择是否增强对比度或进行降噪
447
+ - 在图像调整面板中调整亮度、对比度和滤镜
448
+ 3. **开始检测**: 点击"开始检测"按钮进行分析
449
+ 4. **查看结果**: 在不同标签页中查看分析结果
450
+
451
+ ## 功能说明
452
+ - **显著性图**: 展示目标区域的重要性分布
453
+ - **热力图**: 使用色彩展示检测强度
454
+ - **叠加效果**: 将检测结果与原图叠加展示
455
+ - **目标分割**: 提取关键目标区域
456
+ - **分析报告**: 查看详细的统计信息和分析图表
457
+ """
458
+ )
459
+
460
+ with gr.TabItem("ℹ️ 关于"):
461
+ gr.Markdown(
462
+ """
463
+ ## 项目信息
464
+ - **版本**: 2.0.0
465
+ - **技术架构**: PyTorch + Gradio
466
+ - **模型**: CyueNet
467
+
468
+ ## 主要特点
469
+ - 实时图像处理和分析
470
+ - 多维度结果可视化
471
+ - 丰富的图像调整选项
472
+ - 详细的数据分析报告
473
+
474
+ ## 更新日志
475
+ - 增加了图像预处理选项
476
+ - 添加了统计分析功能
477
+ - 优化了用户界面
478
+ - 提升了处理性能
479
+ """
480
+ )
481
+
482
+ # 重置参数函数
483
+ def reset_params():
484
+ return {
485
+ threshold_slider: 50,
486
+ brightness: 0,
487
+ contrast: 0,
488
+ filter_type: "无",
489
+ enhance_contrast: False,
490
+ denoise: False,
491
+ show_contours: True
492
+ }
493
+
494
+ # 设置事件处理
495
+ submit_btn.click(
496
+ fn=run_demo,
497
+ inputs=[
498
+ input_image,
499
+ threshold_slider,
500
+ enhance_contrast,
501
+ denoise,
502
+ show_contours,
503
+ brightness,
504
+ contrast,
505
+ filter_type
506
+ ],
507
+ outputs=[
508
+ original_output,
509
+ saliency_output,
510
+ heatmap_output,
511
+ overlayed_output,
512
+ segmented_output,
513
+ time_info,
514
+ stats_output,
515
+ analysis_plot
516
+ ]
517
+ )
518
+
519
+ reset_btn.click(
520
+ fn=reset_params,
521
+ inputs=[],
522
+ outputs=[
523
+ threshold_slider,
524
+ brightness,
525
+ contrast,
526
+ filter_type,
527
+ enhance_contrast,
528
+ denoise,
529
+ show_contours
530
+ ]
531
+ )
532
+
533
+ # 启动应用
534
+ if __name__ == "__main__":
535
+ demo.launch(share=True)