import gradio as gr from transformers import ViTImageProcessor, ViTForImageClassification from PIL import Image import torch import os # --- Chargement du modèle et du processeur --- print("Loading model and processor...") model_name = "google/vit-base-patch16-224" processor = ViTImageProcessor.from_pretrained(model_name) model = ViTForImageClassification.from_pretrained(model_name) print("Model loaded successfully!") def predict(image): """Fonction de prédiction avec gestion d'erreurs et seuil de confiance""" try: # Conversion vers RGB pour éviter les erreurs de canaux if image.mode != 'RGB': image = image.convert('RGB') # Pré-traitement de l'image inputs = processor(images=image, return_tensors="pt") # Prédiction with torch.no_grad(): outputs = model(**inputs) logits = outputs.logits # Application de softmax pour obtenir les probabilités probabilities = torch.nn.functional.softmax(logits, dim=-1)[0] top_probs, top_indices = torch.topk(probabilities, 5) # Top 5 predictions # Formatage des résultats sous forme de dictionnaire pour l'affichage results = {} for prob, idx in zip(top_probs, top_indices): pred_label = model.config.id2label[idx.item()] confidence = prob.item() if confidence > 0.01: # Seuil de confiance à 1% results[pred_label] = confidence if not results: return {"Aucune prédiction fiable": 0.0}, "Je ne suis pas sûr de reconnaître cet item. Essayez avec une image plus claire." # Créer un message de résultat top_prediction = list(results.items())[0] message = f"🏷️ Prédiction principale: {top_prediction[0]} ({top_prediction[1]:.2%})" return results, message except Exception as e: return {"Erreur": 0.0}, f"Une erreur s'est produite: {str(e)}" # Interface Gradio améliorée with gr.Blocks(title="Fashion Classifier", theme=gr.themes.Soft()) as demo: gr.Markdown("# 👗 Fashion Item Classifier") gr.Markdown("Téléchargez une image de vêtement pour le classer automatiquement") with gr.Row(): with gr.Column(scale=1): image_input = gr.Image( type="pil", label="Image du vêtement", height=300, sources=["upload", "webcam", "clipboard"] ) upload_btn = gr.Button("🚀 Analyser l'image", variant="primary") with gr.Column(scale=1): label_output = gr.Label( label="Résultats de classification", num_top_classes=5 ) text_output = gr.Textbox( label="Conclusion", interactive=False ) # Exemples gr.Examples( examples=[ ["https://images.unsplash.com/photo-1552374196-c4e7ffc6e126?w=300"], # T-shirt ["https://images.unsplash.com/photo-1543163521-1bf539c55dd2?w=300"], # Chaussures ["https://images.unsplash.com/photo-1594633312681-425c7b97ccd1?w=300"] # Robe ], inputs=image_input, label="Exemples d'images à tester" ) # Instructions gr.Markdown(""" ### 📋 Instructions - Téléchargez une image claire d'un vêtement - L'image doit montrer le vêtement de face - Fond uni recommandé pour de meilleurs résultats - Cliquez sur 'Analyser l'image' pour obtenir la classification """) # Liaison du bouton upload_btn.click( fn=predict, inputs=image_input, outputs=[label_output, text_output] ) # Liaison aussi quand on upload une image image_input.upload( fn=predict, inputs=image_input, outputs=[label_output, text_output] ) # Lancement de l'application if __name__ == "__main__": demo.launch( debug=True, server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)) )