import torch
import torch.nn.functional as F
import numpy as np
import os
import time
import gradio as gr
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import concurrent.futures
from model.CyueNet_models import MMS
from utils1.data import transform_image
from datetime import datetime
import io
import base64
# GPU/CPU device selection
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Custom CSS styles
custom_css = """
:root {
--primary-color: #2196F3;
--secondary-color: #21CBF3;
--background-color: #f6f8fa;
--text-color: #333;
--border-radius: 10px;
--glass-bg: rgba(255, 255, 255, 0.25);
--shadow: 0 8px 32px 0 rgba(31, 38, 135, 0.37);
}
.gradio-container {
background: linear-gradient(135deg, var(--background-color), #ffffff);
max-width: 1400px !important;
margin: auto !important;
backdrop-filter: blur(10px);
}
.output-image, .input-image {
border-radius: var(--border-radius);
box-shadow: var(--shadow);
transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
backdrop-filter: blur(10px);
border: 1px solid rgba(255, 255, 255, 0.18);
}
.output-image:hover, .input-image:hover {
transform: scale(1.02) translateY(-2px);
box-shadow: 0 12px 40px 0 rgba(31, 38, 135, 0.5);
}
.custom-button {
background: linear-gradient(45deg, var(--primary-color), var(--secondary-color));
border: none;
color: white;
padding: 12px 24px;
border-radius: var(--border-radius);
cursor: pointer;
transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
font-weight: bold;
text-transform: uppercase;
letter-spacing: 1px;
box-shadow: var(--shadow);
}
.custom-button:hover {
transform: translateY(-3px);
box-shadow: 0 12px 30px rgba(33, 150, 243, 0.4);
}
.advanced-controls {
background: var(--glass-bg);
border-radius: 20px;
padding: 25px;
box-shadow: var(--shadow);
backdrop-filter: blur(10px);
border: 1px solid rgba(255, 255, 255, 0.18);
}
.result-container {
background: var(--glass-bg);
border-radius: 20px;
padding: 20px;
backdrop-filter: blur(15px);
border: 1px solid rgba(255, 255, 255, 0.18);
box-shadow: var(--shadow);
}
.interactive-viz {
border-radius: 15px;
overflow: hidden;
transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
box-shadow: var(--shadow);
}
.interactive-viz:hover {
transform: translateY(-5px);
box-shadow: 0 15px 35px rgba(0,0,0,0.15);
}
.statistics-container {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
gap: 15px;
margin-top: 15px;
}
.statistic-card {
background: var(--glass-bg);
padding: 20px;
border-radius: var(--border-radius);
text-align: center;
box-shadow: var(--shadow);
backdrop-filter: blur(10px);
border: 1px solid rgba(255, 255, 255, 0.18);
transition: all 0.3s ease;
}
.statistic-card:hover {
transform: translateY(-2px);
box-shadow: 0 10px 25px rgba(0,0,0,0.1);
}
.progress-container {
background: var(--glass-bg);
border-radius: 10px;
padding: 15px;
margin: 10px 0;
backdrop-filter: blur(10px);
}
.comparison-slider {
background: var(--glass-bg);
border-radius: 15px;
padding: 20px;
backdrop-filter: blur(10px);
border: 1px solid rgba(255, 255, 255, 0.18);
}
"""
class ImageProcessor:
def __init__(self):
self.model = None
self.load_model()
self.last_results = None
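        # Simple in-memory cache keyed by image hash and processing parameters;
        # entries persist (unbounded) for the lifetime of the session.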
self.cache = {}
def load_model(self):
"""加载预训练的模型"""
self.model = MMS()
try:
self.model.load_state_dict(torch.load('models/CyueNet_EORSSD6.pth.54', map_location=device))
print("Model loaded successfully")
except RuntimeError as e:
print(f"Model loading error: {e}")
except FileNotFoundError:
print("Model file not found. Please check the path.")
self.model.to(device)
self.model.eval()
def adjust_brightness_contrast(self, image, brightness=0, contrast=0):
"""调整图像亮度和对比度"""
if brightness != 0:
if brightness > 0:
shadow = brightness
highlight = 255
else:
shadow = 0
highlight = 255 + brightness
alpha_b = (highlight - shadow)/255
gamma_b = shadow
image = cv2.addWeighted(image, alpha_b, image, 0, gamma_b)
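        # Contrast uses the standard correction-factor formula
        # f = 131*(C + 127) / (127*(131 - C)), applied as out = f*(in - 127) + 127,
        # i.e. pixel values are rescaled around mid-gray (127).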
if contrast != 0:
f = 131*(contrast + 127)/(127*(131-contrast))
alpha_c = f
gamma_c = 127*(1-f)
image = cv2.addWeighted(image, alpha_c, image, 0, gamma_c)
return image
def apply_filters(self, image, filter_type):
"""应用图像滤镜效果"""
if filter_type == "Sharpen":
kernel = np.array([[-1,-1,-1], [-1,9,-1], [-1,-1,-1]])
return cv2.filter2D(image, -1, kernel)
elif filter_type == "Blur":
return cv2.GaussianBlur(image, (5,5), 0)
elif filter_type == "Edge Enhancement":
kernel = np.array([[0,-1,0], [-1,5,-1], [0,-1,0]])
return cv2.filter2D(image, -1, kernel)
return image
def generate_analysis_plots(self, saliency_map):
"""生成英文分析图表"""
plt.style.use('seaborn-v0_8')
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))
# Histogram
ax1.hist(saliency_map.flatten(), bins=50, color='#2196F3', alpha=0.7, edgecolor='black')
ax1.set_title('Saliency Distribution Histogram', fontsize=12, pad=15)
ax1.set_xlabel('Saliency Value', fontsize=10)
ax1.set_ylabel('Frequency', fontsize=10)
ax1.grid(True, alpha=0.3)
# Add statistics
mean_val = np.mean(saliency_map)
median_val = np.median(saliency_map)
ax1.axvline(mean_val, color='red', linestyle='--', alpha=0.7, label=f'Mean: {mean_val:.3f}')
ax1.axvline(median_val, color='green', linestyle='--', alpha=0.7, label=f'Median: {median_val:.3f}')
ax1.legend()
# Cumulative distribution
sorted_vals = np.sort(saliency_map.flatten())
cumulative = np.arange(1, len(sorted_vals) + 1) / len(sorted_vals)
ax2.plot(sorted_vals, cumulative, color='#FF6B35', linewidth=2)
ax2.set_title('Cumulative Distribution Function', fontsize=12)
ax2.set_xlabel('Saliency Value', fontsize=10)
ax2.set_ylabel('Cumulative Probability', fontsize=10)
ax2.grid(True, alpha=0.3)
# Box plot
ax3.boxplot(saliency_map.flatten(), patch_artist=True,
boxprops=dict(facecolor='#21CBF3', alpha=0.7))
ax3.set_title('Saliency Distribution Box Plot', fontsize=12)
ax3.set_ylabel('Saliency Value', fontsize=10)
ax3.grid(True, alpha=0.3)
# Intensity profile (center line)
center_row = saliency_map[saliency_map.shape[0]//2, :]
ax4.plot(center_row, color='#9C27B0', linewidth=2)
ax4.set_title('Center Line Intensity Profile', fontsize=12)
ax4.set_xlabel('Pixel Position', fontsize=10)
ax4.set_ylabel('Saliency Value', fontsize=10)
ax4.grid(True, alpha=0.3)
plt.tight_layout()
# Save to bytes
buf = io.BytesIO()
plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
buf.seek(0)
img_array = np.array(Image.open(buf))
plt.close()
return img_array
def quick_process(self, image, threshold=0.5, testsize=256):
"""快速处理模式,只输出显著性图"""
if image is None:
return None, "Please provide a valid image"
# Check cache
image_hash = hash(image.tobytes())
cache_key = f"{image_hash}_{threshold}_{testsize}_quick"
if cache_key in self.cache:
return self.cache[cache_key]
image_pil = Image.fromarray(image).convert('RGB')
image_tensor = transform_image(image_pil, testsize)
image_tensor = image_tensor.unsqueeze(0).to(device)
time_start = time.time()
with torch.no_grad():
if device.type == 'cuda':
                with torch.amp.autocast(device_type='cuda'):
_, res, *_ = self.model(image_tensor)
else:
with torch.amp.autocast(device_type='cpu'):
_, res, *_ = self.model(image_tensor)
time_end = time.time()
        # Ensure float32 before sigmoid (autocast may produce float16)
res = res.to(torch.float32).sigmoid().cpu().numpy().squeeze()
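        # Min-max normalize to [0, 1]; the epsilon avoids division by zero
        # when the map is constant.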
res = (res - res.min()) / (res.max() - res.min() + 1e-8)
h, w = image.shape[:2]
res_resized = cv2.resize(res, (w, h))
res_vis = (res_resized * 255).astype(np.uint8)
result = (res_vis, f"Quick processing completed in {time_end - time_start:.3f}s")
self.cache[cache_key] = result
return result
def process_image(self, image, threshold=0.5, testsize=256,
enhance_contrast=False, denoise=False,
brightness=0, contrast=0, filter_type="None",
process_mode="Full Analysis"):
"""增强的图像处理函数"""
if image is None:
return [None] * 9 + ["Please provide a valid image"]
# Quick mode check
if process_mode == "Quick Mode":
saliency_map, time_info = self.quick_process(image, threshold, testsize)
return (image, saliency_map, None, None, None, None, time_info, None, None)
# Check cache for full processing
image_hash = hash(image.tobytes())
cache_key = f"{image_hash}_{threshold}_{testsize}_{enhance_contrast}_{denoise}_{brightness}_{contrast}_{filter_type}_full"
if cache_key in self.cache:
return self.cache[cache_key]
# Image preprocessing with threading
def preprocess_image():
processed_image = image.copy()
if denoise:
processed_image = cv2.fastNlMeansDenoisingColored(processed_image, None, 10, 10, 7, 21)
processed_image = self.adjust_brightness_contrast(processed_image, brightness, contrast)
processed_image = self.apply_filters(processed_image, filter_type)
if enhance_contrast:
lab = cv2.cvtColor(processed_image, cv2.COLOR_RGB2LAB)
l, a, b = cv2.split(lab)
clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
l = clahe.apply(l)
lab = cv2.merge((l,a,b))
processed_image = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
return processed_image
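        # Note: a single submit() followed immediately by result() is
        # effectively synchronous; the executor is a hook for adding more
        # parallel preprocessing steps later.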
with concurrent.futures.ThreadPoolExecutor() as executor:
future_preprocess = executor.submit(preprocess_image)
processed_image = future_preprocess.result()
original_image = processed_image.copy()
# Model inference
image_pil = Image.fromarray(processed_image).convert('RGB')
image_tensor = transform_image(image_pil, testsize)
image_tensor = image_tensor.unsqueeze(0).to(device)
time_start = time.time()
with torch.no_grad():
            if device.type == 'cuda':
                with torch.amp.autocast(device_type='cuda'):
                    # Only the primary saliency prediction (res) is used here;
                    # the remaining model outputs are auxiliary predictions.
                    _, res, *_ = self.model(image_tensor)
            else:
                with torch.amp.autocast(device_type='cpu'):
                    _, res, *_ = self.model(image_tensor)
time_end = time.time()
inference_time = time_end - time_start
        # Ensure float32 before sigmoid (autocast may produce float16)
res = res.to(torch.float32).sigmoid().cpu().numpy().squeeze()
res = (res - res.min()) / (res.max() - res.min() + 1e-8)
h, w = original_image.shape[:2]
res_resized = cv2.resize(res, (w, h))
# Generate visualizations
res_vis = (res_resized * 255).astype(np.uint8)
heatmap = cv2.applyColorMap(res_vis, cv2.COLORMAP_JET)
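        # applyColorMap returns BGR (converted back to RGB below); the
        # threshold in [0, 1] is scaled to the 8-bit range for cv2.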
_, binary_mask = cv2.threshold(res_vis, int(255 * threshold), 255, cv2.THRESH_BINARY)
# Create overlays
alpha = 0.5
original_bgr = cv2.cvtColor(original_image, cv2.COLOR_RGB2BGR)
overlayed = cv2.addWeighted(original_bgr, 1-alpha, heatmap, alpha, 0)
segmented = cv2.bitwise_and(original_bgr, original_bgr, mask=binary_mask)
# Convert back to RGB
overlayed_rgb = cv2.cvtColor(overlayed, cv2.COLOR_BGR2RGB)
segmented_rgb = cv2.cvtColor(segmented, cv2.COLOR_BGR2RGB)
heatmap_rgb = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
# Generate analysis plots
analysis_plot = self.generate_analysis_plots(res_resized)
# Calculate statistics
        contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
total_area = w * h
detected_area = cv2.countNonZero(binary_mask)
coverage_ratio = detected_area / total_area
stats = {
"Processing Resolution": f"{w}x{h}",
"Detected Objects": str(len(contours)),
"Average Confidence": f"{np.mean(res_resized):.2%}",
"Max Confidence": f"{np.max(res_resized):.2%}",
"Coverage Ratio": f"{coverage_ratio:.2%}",
"Processing Time": f"{inference_time:.3f}s"
}
# Create comparison image
comparison_img = self.create_comparison_image(original_image, overlayed_rgb)
# Save results
self.last_results = {
'saliency_map': res_resized,
'binary_mask': binary_mask,
'stats': stats
}
result = (original_image, res_vis, heatmap_rgb, overlayed_rgb, segmented_rgb,
comparison_img, f"Processing time: {inference_time:.4f}s", stats, analysis_plot)
# Cache result
self.cache[cache_key] = result
return result
def create_comparison_image(self, original, processed):
"""创建对比图像"""
h, w = original.shape[:2]
comparison = np.zeros((h, w*2, 3), dtype=np.uint8)
comparison[:, :w] = original
comparison[:, w:] = processed
# Add dividing line
cv2.line(comparison, (w, 0), (w, h), (255, 255, 255), 2)
return comparison
def export_results(self, format_type="PNG"):
"""导出结果"""
if self.last_results is None:
return "No results to export"
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
if format_type == "PDF Report":
# Generate PDF report logic here
return f"PDF report saved as saliency_report_{timestamp}.pdf"
else:
return f"Results exported as {format_type.lower()} files"
# Create processor instance
processor = ImageProcessor()
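# A minimal programmatic usage sketch (hypothetical file names), bypassing the UI:
#   img = np.array(Image.open('input.jpg').convert('RGB'))
#   saliency, info = processor.quick_process(img, threshold=0.5, testsize=256)
#   cv2.imwrite('saliency.png', saliency)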
def run_demo(input_image, threshold, enhance_contrast, denoise, show_contours,
brightness, contrast, filter_type, process_mode):
"""主处理函数"""
if input_image is None:
return [None] * 9 + ["Please upload an image"]
# Process image
results = processor.process_image(
input_image,
threshold=threshold/100.0,
enhance_contrast=enhance_contrast,
denoise=denoise,
brightness=brightness,
contrast=contrast,
filter_type=filter_type,
process_mode=process_mode
)
original, saliency_map, heatmap, overlayed, segmented, comparison, time_info, stats, analysis_plot = results
# Add contours if requested
if show_contours and saliency_map is not None and overlayed is not None:
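        # Binarize the 8-bit saliency map at mid-gray and outline detections.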
_, binary = cv2.threshold(saliency_map, 127, 255, cv2.THRESH_BINARY)
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
overlay_with_contours = overlayed.copy()
cv2.drawContours(overlay_with_contours, contours, -1, (0,255,0), 2)
overlayed = overlay_with_contours
# Generate statistics HTML
if stats:
stats_html = "<div class='statistics-container'>"
for key, value in stats.items():
stats_html += f"<div class='statistic-card'><h4>{key}</h4><p>{value}</p></div>"
stats_html += "</div>"
else:
stats_html = "<p>No statistics available</p>"
return (original, saliency_map, heatmap, overlayed, segmented,
comparison, time_info, stats_html, analysis_plot)
def create_comparison_view(original, result, slider_value):
"""创建滑块对比视图"""
if original is None or result is None:
return None
h, w = original.shape[:2]
split_point = int(w * slider_value)
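    # Pixels left of the split come from the original, right of it from the result.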
comparison = original.copy()
comparison[:, split_point:] = result[:, split_point:]
# Add vertical line
cv2.line(comparison, (split_point, 0), (split_point, h), (255, 255, 0), 3)
return comparison
# Create Gradio interface
with gr.Blocks(title="Advanced Saliency Object Detection System", css=custom_css) as demo:
gr.Markdown(
"""
# 🎯 Advanced Saliency Object Detection System
### AI-Powered Image Saliency Detection and Analysis Tool
"""
)
with gr.Tabs() as tabs:
with gr.TabItem("🔍 Main Functions"):
with gr.Row():
with gr.Column(scale=1):
# Input controls
with gr.Group(elem_classes="advanced-controls"):
input_image = gr.Image(
label="Input Image",
type="numpy",
elem_classes="input-image"
)
# Processing mode selection
process_mode = gr.Radio(
choices=["Full Analysis", "Quick Mode"],
value="Full Analysis",
label="Processing Mode",
info="Quick mode only outputs saliency map for faster processing"
)
with gr.Accordion("Basic Settings", open=True):
threshold_slider = gr.Slider(
minimum=0,
maximum=100,
value=50,
step=1,
label="Detection Threshold",
info="Adjust detection sensitivity"
)
enhance_contrast = gr.Checkbox(
label="Enhance Contrast",
value=False
)
denoise = gr.Checkbox(
label="Noise Reduction",
value=False
)
show_contours = gr.Checkbox(
label="Show Contours",
value=True
)
with gr.Accordion("Image Adjustments", open=False):
brightness = gr.Slider(
minimum=-100,
maximum=100,
value=0,
step=1,
label="Brightness"
)
contrast = gr.Slider(
minimum=-100,
maximum=100,
value=0,
step=1,
label="Contrast"
)
filter_type = gr.Radio(
choices=["None", "Sharpen", "Blur", "Edge Enhancement"],
value="None",
label="Image Filter"
)
with gr.Accordion("Export Options", open=False):
export_format = gr.Dropdown(
choices=["PNG", "JPEG", "PDF Report"],
value="PNG",
label="Export Format"
)
                            export_btn = gr.Button(
                                "Export Results",
                                elem_classes="custom-button"
                            )
                            export_status = gr.Textbox(
                                label="Export Status",
                                interactive=False
                            )
with gr.Row():
submit_btn = gr.Button(
"Start Detection",
variant="primary",
elem_classes="custom-button"
)
reset_btn = gr.Button(
"Reset Parameters",
elem_classes="custom-button"
)
with gr.Column(scale=2):
# Results display
with gr.Tabs():
with gr.TabItem("Detection Results"):
with gr.Row(elem_classes="result-container"):
original_output = gr.Image(
label="Original Image",
elem_classes="output-image"
)
saliency_output = gr.Image(
label="Saliency Map",
elem_classes="output-image"
)
with gr.Row(elem_classes="result-container"):
heatmap_output = gr.Image(
label="Heatmap Analysis",
elem_classes="output-image"
)
overlayed_output = gr.Image(
label="Overlay Effect",
elem_classes="output-image"
)
with gr.Row(elem_classes="result-container"):
segmented_output = gr.Image(
label="Object Segmentation",
elem_classes="output-image"
)
comparison_output = gr.Image(
label="Side-by-Side Comparison",
elem_classes="output-image"
)
with gr.TabItem("Interactive Comparison"):
with gr.Group(elem_classes="comparison-slider"):
comparison_slider = gr.Slider(
minimum=0,
maximum=1,
value=0.5,
step=0.01,
label="Original ← → Result",
info="Drag to compare original and processed images"
)
interactive_comparison = gr.Image(
label="Interactive Comparison View",
elem_classes="interactive-viz"
)
with gr.TabItem("Analysis Report"):
with gr.Group(elem_classes="result-container"):
time_info = gr.Textbox(
label="Processing Time",
show_label=True
)
stats_output = gr.HTML(
label="Statistical Information"
)
analysis_plot = gr.Image(
label="Detailed Analysis Charts",
elem_classes="output-image"
)
with gr.TabItem("📖 User Guide"):
gr.Markdown(
"""
## Instructions
1. **Upload Image**: Click the "Input Image" area to upload your image
2. **Select Mode**: Choose between "Full Analysis" or "Quick Mode"
- Full Analysis: Complete processing with all visualizations
- Quick Mode: Fast processing, only outputs saliency map
3. **Adjust Parameters**:
- Use threshold slider to adjust detection sensitivity
- Enable contrast enhancement or noise reduction as needed
- Fine-tune brightness, contrast, and filters in advanced settings
4. **Start Detection**: Click "Start Detection" to begin analysis
5. **View Results**: Check different tabs for various visualization results
6. **Export**: Use export options to save your results
## Features
- **Saliency Map**: Shows importance distribution of image regions
- **Heatmap**: Color-coded intensity visualization
- **Overlay Effect**: Detection results overlaid on original image
- **Object Segmentation**: Extracts key object regions
- **Interactive Comparison**: Slide to compare original and processed images
- **Analysis Report**: Detailed statistics and analysis charts
## Performance Tips
- Use Quick Mode for faster processing when you only need saliency maps
- Lower resolution images process faster
- Enable GPU if available for better performance
"""
)
with gr.TabItem("ℹ️ About"):
gr.Markdown(
"""
## Project Information
- **Version**: 3.0.0
- **Architecture**: PyTorch + Gradio
- **Model**: CyueNet
- **Language**: Multi-language support
## Key Features
- Real-time image processing and analysis
- Multi-dimensional result visualization
- Rich image adjustment options
- Detailed data analysis reports
- Interactive comparison tools
- Export functionality
- Performance optimization with caching
## Update Log
- ✅ Added Quick Mode for faster processing
- ✅ Enhanced image preprocessing options
- ✅ Added statistical analysis functions
- ✅ Improved user interface with glassmorphism design
- ✅ Added interactive comparison slider
- ✅ Performance optimization with caching and threading
- ✅ Multi-language chart support
- ✅ Export functionality
## System Requirements
- Python 3.8+
- PyTorch 1.9+
- CUDA (optional, for GPU acceleration)
- 4GB+ RAM recommended
"""
)
# Event handlers
def reset_params():
return {
threshold_slider: 50,
brightness: 0,
contrast: 0,
filter_type: "None",
enhance_contrast: False,
denoise: False,
show_contours: True,
process_mode: "Full Analysis"
}
# Set up event handling
submit_btn.click(
fn=run_demo,
inputs=[
input_image,
threshold_slider,
enhance_contrast,
denoise,
show_contours,
brightness,
contrast,
filter_type,
process_mode
],
outputs=[
original_output,
saliency_output,
heatmap_output,
overlayed_output,
segmented_output,
comparison_output,
time_info,
stats_output,
analysis_plot
]
)
reset_btn.click(
fn=reset_params,
inputs=[],
outputs=[
threshold_slider,
brightness,
contrast,
filter_type,
enhance_contrast,
denoise,
show_contours,
process_mode
]
)
# Interactive comparison
comparison_slider.change(
fn=create_comparison_view,
inputs=[original_output, overlayed_output, comparison_slider],
outputs=[interactive_comparison]
)
# Export functionality
export_btn.click(
fn=processor.export_results,
inputs=[export_format],
outputs=[gr.Textbox(label="Export Status")]
)
# Launch the application
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=True,
show_error=True
)