import gradio as gr import torch import torch.nn.functional as F from transformers import AutoImageProcessor, AutoModelForImageClassification from PIL import Image import requests from io import BytesIO import numpy as np import os from pathlib import Path import tempfile # 🔥 MODÈLE SPÉCIALISÉ DANS LA MODE MODEL_NAME = "google/vit-base-patch16-224" # Modèle fiable et rapide print("🔄 Chargement du modèle de mode...") try: processor = AutoImageProcessor.from_pretrained(MODEL_NAME) model = AutoModelForImageClassification.from_pretrained(MODEL_NAME) device = "cuda" if torch.cuda.is_available() else "cpu" model.to(device) model.eval() print(f"✅ Modèle chargé sur {device}") except Exception as e: print(f"❌ Erreur chargement: {e}") processor = None model = None # 🎯 LABELS COMPRÉHENSIBLES POUR LA MODE (adaptés au modèle) FASHION_LABELS = { 0: "T-shirt", 1: "Pantalon", 2: "Pull", 3: "Robe", 4: "Manteau", 5: "Sandale", 6: "Chemise", 7: "Sneaker", 8: "Sac", 9: "Botte", 10: "Veste", 11: "Jupe", 12: "Short", 13: "Chaussures", 14: "Accessoire" } def convert_heic_to_jpeg(image_path): """Convertit les HEIC en JPEG si nécessaire""" try: if isinstance(image_path, str) and image_path.lower().endswith('.heic'): # Conversion HEIC → JPEG img = Image.open(image_path) jpeg_path = image_path.replace('.heic', '.jpeg') img.convert('RGB').save(jpeg_path, 'JPEG') return jpeg_path except: pass return image_path def preprocess_image(image): """Prétraitement robuste des images""" try: # Si c'est un chemin de fichier (HEIC) if isinstance(image, str): image = convert_heic_to_jpeg(image) image = Image.open(image) # Conversion en RGB if image.mode != 'RGB': image = image.convert('RGB') # Redimensionnement image = image.resize((224, 224), Image.Resampling.LANCZOS) return image except Exception as e: raise Exception(f"Erreur prétraitement: {str(e)}") def classify_fashion(image): """Classification avec gestion robuste des formats""" try: if image is None: return "❌ Veuillez uploader une image de vêtement" if processor is None or model is None: return "⚠️ Modèle en cours de chargement... Patientez 30s" # 📸 Gestion spéciale HEIC et formats complexes try: # Si l'image est un chemin temporaire (format HEIC) if isinstance(image, str) and ('gradio' in image or 'tmp' in image): if image.lower().endswith('.heic'): # Conversion HEIC → JPEG img = Image.open(image) with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp: img.convert('RGB').save(tmp.name, 'JPEG', quality=95) processed_image = Image.open(tmp.name) os.unlink(tmp.name) # Nettoyage else: processed_image = Image.open(image) else: # Image normale processed_image = image # Conversion en RGB si nécessaire if processed_image.mode != 'RGB': processed_image = processed_image.convert('RGB') except Exception as e: return f"❌ Format d'image non supporté: {str(e)}\n\n💡 Utilisez JPEG, PNG ou WebP" # 🔥 PRÉTRAITEMENT CORRECT processed_image = processed_image.resize((224, 224), Image.Resampling.LANCZOS) # Transformation pour le modèle inputs = processor(images=processed_image, return_tensors="pt") inputs = {k: v.to(device) for k, v in inputs.items()} # 🔥 INFÉRENCE with torch.no_grad(): outputs = model(**inputs) # 📊 POST-TRAITEMENT probabilities = F.softmax(outputs.logits, dim=-1) top_probs, top_indices = torch.topk(probabilities, 5) # Conversion en résultats results = [] for i in range(len(top_indices[0])): label_idx = top_indices[0][i].item() label_name = FASHION_LABELS.get(label_idx, f"Catégorie {label_idx}")