Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -10,7 +10,6 @@ import matplotlib.pyplot as plt
|
|
| 10 |
import concurrent.futures
|
| 11 |
from model.CyueNet_models import MMS
|
| 12 |
from utils1.data import transform_image
|
| 13 |
-
import plotly.graph_objects as go
|
| 14 |
from datetime import datetime
|
| 15 |
import io
|
| 16 |
import base64
|
|
@@ -18,7 +17,7 @@ import base64
|
|
| 18 |
# GPU/CPU设置
|
| 19 |
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
| 20 |
|
| 21 |
-
#
|
| 22 |
custom_css = """
|
| 23 |
:root {
|
| 24 |
--primary-color: #2196F3;
|
|
@@ -138,7 +137,6 @@ custom_css = """
|
|
| 138 |
border: 1px solid rgba(255, 255, 255, 0.18);
|
| 139 |
}
|
| 140 |
"""
|
| 141 |
-
|
| 142 |
class ImageProcessor:
|
| 143 |
def __init__(self):
|
| 144 |
self.model = None
|
|
@@ -263,12 +261,17 @@ class ImageProcessor:
|
|
| 263 |
time_start = time.time()
|
| 264 |
|
| 265 |
with torch.no_grad():
|
| 266 |
-
|
| 267 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
|
| 269 |
time_end = time.time()
|
| 270 |
|
| 271 |
-
|
|
|
|
| 272 |
res = (res - res.min()) / (res.max() - res.min() + 1e-8)
|
| 273 |
|
| 274 |
h, w = image.shape[:2]
|
|
@@ -279,7 +282,6 @@ class ImageProcessor:
|
|
| 279 |
self.cache[cache_key] = result
|
| 280 |
|
| 281 |
return result
|
| 282 |
-
|
| 283 |
def process_image(self, image, threshold=0.5, testsize=256,
|
| 284 |
enhance_contrast=False, denoise=False,
|
| 285 |
brightness=0, contrast=0, filter_type="None",
|
|
@@ -288,7 +290,7 @@ class ImageProcessor:
|
|
| 288 |
if image is None:
|
| 289 |
return [None] * 9 + ["Please provide a valid image"]
|
| 290 |
|
| 291 |
-
# Quick mode
|
| 292 |
if process_mode == "Quick Mode":
|
| 293 |
saliency_map, time_info = self.quick_process(image, threshold, testsize)
|
| 294 |
return (image, saliency_map, None, None, None, None, time_info, None, None)
|
|
@@ -334,14 +336,18 @@ class ImageProcessor:
|
|
| 334 |
time_start = time.time()
|
| 335 |
|
| 336 |
with torch.no_grad():
|
| 337 |
-
|
| 338 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
|
| 340 |
time_end = time.time()
|
| 341 |
inference_time = time_end - time_start
|
| 342 |
|
| 343 |
-
#
|
| 344 |
-
res = res.sigmoid().
|
| 345 |
res = (res - res.min()) / (res.max() - res.min() + 1e-8)
|
| 346 |
|
| 347 |
h, w = original_image.shape[:2]
|
|
@@ -423,7 +429,6 @@ class ImageProcessor:
|
|
| 423 |
return f"PDF report saved as saliency_report_{timestamp}.pdf"
|
| 424 |
else:
|
| 425 |
return f"Results exported as {format_type.lower()} files"
|
| 426 |
-
|
| 427 |
# Create processor instance
|
| 428 |
processor = ImageProcessor()
|
| 429 |
|
|
@@ -720,11 +725,6 @@ with gr.Blocks(title="Advanced Saliency Object Detection System", css=custom_css
|
|
| 720 |
process_mode: "Full Analysis"
|
| 721 |
}
|
| 722 |
|
| 723 |
-
def update_comparison(original, overlay, slider_val):
|
| 724 |
-
if original is not None and overlay is not None:
|
| 725 |
-
return create_comparison_view(original, overlay, slider_val)
|
| 726 |
-
return None
|
| 727 |
-
|
| 728 |
# Set up event handling
|
| 729 |
submit_btn.click(
|
| 730 |
fn=run_demo,
|
|
@@ -769,7 +769,7 @@ with gr.Blocks(title="Advanced Saliency Object Detection System", css=custom_css
|
|
| 769 |
|
| 770 |
# Interactive comparison
|
| 771 |
comparison_slider.change(
|
| 772 |
-
fn=
|
| 773 |
inputs=[original_output, overlayed_output, comparison_slider],
|
| 774 |
outputs=[interactive_comparison]
|
| 775 |
)
|
|
@@ -787,6 +787,5 @@ if __name__ == "__main__":
|
|
| 787 |
server_name="0.0.0.0",
|
| 788 |
server_port=7860,
|
| 789 |
share=True,
|
| 790 |
-
inbrowser=True,
|
| 791 |
show_error=True
|
| 792 |
)
|
|
|
|
| 10 |
import concurrent.futures
|
| 11 |
from model.CyueNet_models import MMS
|
| 12 |
from utils1.data import transform_image
|
|
|
|
| 13 |
from datetime import datetime
|
| 14 |
import io
|
| 15 |
import base64
|
|
|
|
| 17 |
# GPU/CPU设置
|
| 18 |
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
| 19 |
|
| 20 |
+
# CSS样式设置
|
| 21 |
custom_css = """
|
| 22 |
:root {
|
| 23 |
--primary-color: #2196F3;
|
|
|
|
| 137 |
border: 1px solid rgba(255, 255, 255, 0.18);
|
| 138 |
}
|
| 139 |
"""
|
|
|
|
| 140 |
class ImageProcessor:
|
| 141 |
def __init__(self):
|
| 142 |
self.model = None
|
|
|
|
| 261 |
time_start = time.time()
|
| 262 |
|
| 263 |
with torch.no_grad():
|
| 264 |
+
if device.type == 'cuda':
|
| 265 |
+
with torch.cuda.amp.autocast():
|
| 266 |
+
_, res, *_ = self.model(image_tensor)
|
| 267 |
+
else:
|
| 268 |
+
with torch.amp.autocast(device_type='cpu'):
|
| 269 |
+
_, res, *_ = self.model(image_tensor)
|
| 270 |
|
| 271 |
time_end = time.time()
|
| 272 |
|
| 273 |
+
# 确保转换为float32类型
|
| 274 |
+
res = res.to(torch.float32).sigmoid().cpu().numpy().squeeze()
|
| 275 |
res = (res - res.min()) / (res.max() - res.min() + 1e-8)
|
| 276 |
|
| 277 |
h, w = image.shape[:2]
|
|
|
|
| 282 |
self.cache[cache_key] = result
|
| 283 |
|
| 284 |
return result
|
|
|
|
| 285 |
def process_image(self, image, threshold=0.5, testsize=256,
|
| 286 |
enhance_contrast=False, denoise=False,
|
| 287 |
brightness=0, contrast=0, filter_type="None",
|
|
|
|
| 290 |
if image is None:
|
| 291 |
return [None] * 9 + ["Please provide a valid image"]
|
| 292 |
|
| 293 |
+
# Quick mode check
|
| 294 |
if process_mode == "Quick Mode":
|
| 295 |
saliency_map, time_info = self.quick_process(image, threshold, testsize)
|
| 296 |
return (image, saliency_map, None, None, None, None, time_info, None, None)
|
|
|
|
| 336 |
time_start = time.time()
|
| 337 |
|
| 338 |
with torch.no_grad():
|
| 339 |
+
if device.type == 'cuda':
|
| 340 |
+
with torch.cuda.amp.autocast():
|
| 341 |
+
x1, res, s1_sig, edg1, edg_s, s2, e2, s2_sig, e2_sig, s3, e3, s3_sig, e3_sig, s4, e4, s4_sig, e4_sig, s5, e5, s5_sig, e5_sig, sk1, sk1_sig, sk2, sk2_sig, sk3, sk3_sig, sk4, sk4_sig, sk5, sk5_sig = self.model(image_tensor)
|
| 342 |
+
else:
|
| 343 |
+
with torch.amp.autocast(device_type='cpu'):
|
| 344 |
+
x1, res, s1_sig, edg1, edg_s, s2, e2, s2_sig, e2_sig, s3, e3, s3_sig, e3_sig, s4, e4, s4_sig, e4_sig, s5, e5, s5_sig, e5_sig, sk1, sk1_sig, sk2, sk2_sig, sk3, sk3_sig, sk4, sk4_sig, sk5, sk5_sig = self.model(image_tensor)
|
| 345 |
|
| 346 |
time_end = time.time()
|
| 347 |
inference_time = time_end - time_start
|
| 348 |
|
| 349 |
+
# 确保转换为float32类型
|
| 350 |
+
res = res.to(torch.float32).sigmoid().cpu().numpy().squeeze()
|
| 351 |
res = (res - res.min()) / (res.max() - res.min() + 1e-8)
|
| 352 |
|
| 353 |
h, w = original_image.shape[:2]
|
|
|
|
| 429 |
return f"PDF report saved as saliency_report_{timestamp}.pdf"
|
| 430 |
else:
|
| 431 |
return f"Results exported as {format_type.lower()} files"
|
|
|
|
| 432 |
# Create processor instance
|
| 433 |
processor = ImageProcessor()
|
| 434 |
|
|
|
|
| 725 |
process_mode: "Full Analysis"
|
| 726 |
}
|
| 727 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 728 |
# Set up event handling
|
| 729 |
submit_btn.click(
|
| 730 |
fn=run_demo,
|
|
|
|
| 769 |
|
| 770 |
# Interactive comparison
|
| 771 |
comparison_slider.change(
|
| 772 |
+
fn=create_comparison_view,
|
| 773 |
inputs=[original_output, overlayed_output, comparison_slider],
|
| 774 |
outputs=[interactive_comparison]
|
| 775 |
)
|
|
|
|
| 787 |
server_name="0.0.0.0",
|
| 788 |
server_port=7860,
|
| 789 |
share=True,
|
|
|
|
| 790 |
show_error=True
|
| 791 |
)
|