kunkk commited on
Commit
249ac00
·
verified ·
1 Parent(s): 6749254

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -20
app.py CHANGED
@@ -10,7 +10,6 @@ import matplotlib.pyplot as plt
10
  import concurrent.futures
11
  from model.CyueNet_models import MMS
12
  from utils1.data import transform_image
13
- import plotly.graph_objects as go
14
  from datetime import datetime
15
  import io
16
  import base64
@@ -18,7 +17,7 @@ import base64
18
  # GPU/CPU设置
19
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
20
 
21
- # 更新后的CSS样式
22
  custom_css = """
23
  :root {
24
  --primary-color: #2196F3;
@@ -138,7 +137,6 @@ custom_css = """
138
  border: 1px solid rgba(255, 255, 255, 0.18);
139
  }
140
  """
141
-
142
  class ImageProcessor:
143
  def __init__(self):
144
  self.model = None
@@ -263,12 +261,17 @@ class ImageProcessor:
263
  time_start = time.time()
264
 
265
  with torch.no_grad():
266
- with torch.cuda.amp.autocast() if device.type == 'cuda' else torch.cpu.amp.autocast():
267
- _, res, *_ = self.model(image_tensor)
 
 
 
 
268
 
269
  time_end = time.time()
270
 
271
- res = res.sigmoid().data.cpu().numpy().squeeze()
 
272
  res = (res - res.min()) / (res.max() - res.min() + 1e-8)
273
 
274
  h, w = image.shape[:2]
@@ -279,7 +282,6 @@ class ImageProcessor:
279
  self.cache[cache_key] = result
280
 
281
  return result
282
-
283
  def process_image(self, image, threshold=0.5, testsize=256,
284
  enhance_contrast=False, denoise=False,
285
  brightness=0, contrast=0, filter_type="None",
@@ -288,7 +290,7 @@ class ImageProcessor:
288
  if image is None:
289
  return [None] * 9 + ["Please provide a valid image"]
290
 
291
- # Quick mode
292
  if process_mode == "Quick Mode":
293
  saliency_map, time_info = self.quick_process(image, threshold, testsize)
294
  return (image, saliency_map, None, None, None, None, time_info, None, None)
@@ -334,14 +336,18 @@ class ImageProcessor:
334
  time_start = time.time()
335
 
336
  with torch.no_grad():
337
- with torch.cuda.amp.autocast() if device.type == 'cuda' else torch.cpu.amp.autocast():
338
- x1, res, s1_sig, edg1, edg_s, s2, e2, s2_sig, e2_sig, s3, e3, s3_sig, e3_sig, s4, e4, s4_sig, e4_sig, s5, e5, s5_sig, e5_sig, sk1, sk1_sig, sk2, sk2_sig, sk3, sk3_sig, sk4, sk4_sig, sk5, sk5_sig = self.model(image_tensor)
 
 
 
 
339
 
340
  time_end = time.time()
341
  inference_time = time_end - time_start
342
 
343
- # Process outputs
344
- res = res.sigmoid().data.cpu().numpy().squeeze()
345
  res = (res - res.min()) / (res.max() - res.min() + 1e-8)
346
 
347
  h, w = original_image.shape[:2]
@@ -423,7 +429,6 @@ class ImageProcessor:
423
  return f"PDF report saved as saliency_report_{timestamp}.pdf"
424
  else:
425
  return f"Results exported as {format_type.lower()} files"
426
-
427
  # Create processor instance
428
  processor = ImageProcessor()
429
 
@@ -720,11 +725,6 @@ with gr.Blocks(title="Advanced Saliency Object Detection System", css=custom_css
720
  process_mode: "Full Analysis"
721
  }
722
 
723
- def update_comparison(original, overlay, slider_val):
724
- if original is not None and overlay is not None:
725
- return create_comparison_view(original, overlay, slider_val)
726
- return None
727
-
728
  # Set up event handling
729
  submit_btn.click(
730
  fn=run_demo,
@@ -769,7 +769,7 @@ with gr.Blocks(title="Advanced Saliency Object Detection System", css=custom_css
769
 
770
  # Interactive comparison
771
  comparison_slider.change(
772
- fn=update_comparison,
773
  inputs=[original_output, overlayed_output, comparison_slider],
774
  outputs=[interactive_comparison]
775
  )
@@ -787,6 +787,5 @@ if __name__ == "__main__":
787
  server_name="0.0.0.0",
788
  server_port=7860,
789
  share=True,
790
- inbrowser=True,
791
  show_error=True
792
  )
 
10
  import concurrent.futures
11
  from model.CyueNet_models import MMS
12
  from utils1.data import transform_image
 
13
  from datetime import datetime
14
  import io
15
  import base64
 
17
  # GPU/CPU设置
18
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
19
 
20
+ # CSS样式设置
21
  custom_css = """
22
  :root {
23
  --primary-color: #2196F3;
 
137
  border: 1px solid rgba(255, 255, 255, 0.18);
138
  }
139
  """
 
140
  class ImageProcessor:
141
  def __init__(self):
142
  self.model = None
 
261
  time_start = time.time()
262
 
263
  with torch.no_grad():
264
+ if device.type == 'cuda':
265
+ with torch.cuda.amp.autocast():
266
+ _, res, *_ = self.model(image_tensor)
267
+ else:
268
+ with torch.amp.autocast(device_type='cpu'):
269
+ _, res, *_ = self.model(image_tensor)
270
 
271
  time_end = time.time()
272
 
273
+ # 确保转换为float32类型
274
+ res = res.to(torch.float32).sigmoid().cpu().numpy().squeeze()
275
  res = (res - res.min()) / (res.max() - res.min() + 1e-8)
276
 
277
  h, w = image.shape[:2]
 
282
  self.cache[cache_key] = result
283
 
284
  return result
 
285
  def process_image(self, image, threshold=0.5, testsize=256,
286
  enhance_contrast=False, denoise=False,
287
  brightness=0, contrast=0, filter_type="None",
 
290
  if image is None:
291
  return [None] * 9 + ["Please provide a valid image"]
292
 
293
+ # Quick mode check
294
  if process_mode == "Quick Mode":
295
  saliency_map, time_info = self.quick_process(image, threshold, testsize)
296
  return (image, saliency_map, None, None, None, None, time_info, None, None)
 
336
  time_start = time.time()
337
 
338
  with torch.no_grad():
339
+ if device.type == 'cuda':
340
+ with torch.cuda.amp.autocast():
341
+ x1, res, s1_sig, edg1, edg_s, s2, e2, s2_sig, e2_sig, s3, e3, s3_sig, e3_sig, s4, e4, s4_sig, e4_sig, s5, e5, s5_sig, e5_sig, sk1, sk1_sig, sk2, sk2_sig, sk3, sk3_sig, sk4, sk4_sig, sk5, sk5_sig = self.model(image_tensor)
342
+ else:
343
+ with torch.amp.autocast(device_type='cpu'):
344
+ x1, res, s1_sig, edg1, edg_s, s2, e2, s2_sig, e2_sig, s3, e3, s3_sig, e3_sig, s4, e4, s4_sig, e4_sig, s5, e5, s5_sig, e5_sig, sk1, sk1_sig, sk2, sk2_sig, sk3, sk3_sig, sk4, sk4_sig, sk5, sk5_sig = self.model(image_tensor)
345
 
346
  time_end = time.time()
347
  inference_time = time_end - time_start
348
 
349
+ # 确保转换为float32类型
350
+ res = res.to(torch.float32).sigmoid().cpu().numpy().squeeze()
351
  res = (res - res.min()) / (res.max() - res.min() + 1e-8)
352
 
353
  h, w = original_image.shape[:2]
 
429
  return f"PDF report saved as saliency_report_{timestamp}.pdf"
430
  else:
431
  return f"Results exported as {format_type.lower()} files"
 
432
  # Create processor instance
433
  processor = ImageProcessor()
434
 
 
725
  process_mode: "Full Analysis"
726
  }
727
 
 
 
 
 
 
728
  # Set up event handling
729
  submit_btn.click(
730
  fn=run_demo,
 
769
 
770
  # Interactive comparison
771
  comparison_slider.change(
772
+ fn=create_comparison_view,
773
  inputs=[original_output, overlayed_output, comparison_slider],
774
  outputs=[interactive_comparison]
775
  )
 
787
  server_name="0.0.0.0",
788
  server_port=7860,
789
  share=True,
 
790
  show_error=True
791
  )