MODLI commited on
Commit
fd50bed
·
verified ·
1 Parent(s): e34f43b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -0
app.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
4
+ from PIL import Image
5
+ import requests
6
+ from io import BytesIO
7
+
8
+ # Chargement du modèle spécialisé dans la mode
9
+ MODEL_NAME = "google/vit-base-patch16-224" # Modèle de base fiable
10
+ # Alternative: "nateraw/fashion-clip" si disponible
11
+
12
+ # Initialisation du modèle
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ print(f"🖥️ Utilisation du device: {device}")
15
+
16
+ try:
17
+ # Chargeur d'images
18
+ processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
19
+ # Modèle de classification
20
+ model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
21
+ model.to(device)
22
+ model.eval()
23
+ print("✅ Modèle chargé avec succès!")
24
+ except Exception as e:
25
+ print(f"❌ Erreur chargement modèle: {e}")
26
+ processor = None
27
+ model = None
28
+
29
+ def classify_clothing(image):
30
+ """Classifie une image de vêtement"""
31
+ try:
32
+ if image is None:
33
+ return "❌ Veuillez uploader une image de vêtement"
34
+
35
+ if processor is None or model is None:
36
+ return "⚠️ Modèle en cours de chargement... Réessayez dans 30 secondes"
37
+
38
+ # Prétraitement de l'image
39
+ inputs = processor(images=image, return_tensors="pt")
40
+ inputs = {k: v.to(device) for k, v in inputs.items()}
41
+
42
+ # Classification
43
+ with torch.no_grad():
44
+ outputs = model(**inputs)
45
+
46
+ # Récupération des résultats
47
+ probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
48
+ top_probs, top_indices = torch.topk(probabilities, 5)
49
+
50
+ # Conversion en résultats lisibles
51
+ results = []
52
+ for i in range(len(top_indices[0])):
53
+ label = model.config.id2label[top_indices[0][i].item()]
54
+ score = top_probs[0][i].item() * 100
55
+ results.append({"label": label, "score": score})
56
+
57
+ # Formatage des résultats
58
+ output = "## 🎯 Résultats de Classification:\n\n"
59
+ for i, result in enumerate(results):
60
+ # Nettoyage des labels
61
+ clean_label = result['label'].replace('_', ' ').title()
62
+ output += f"{i+1}. **{clean_label}** - {result['score']:.1f}%\n"
63
+
64
+ output += "\n---\n"
65
+ output += "💡 **Conseils pour de meilleurs résultats:**\n"
66
+ output += "• Utilisez des images claires sur fond uni\n"
67
+ output += "• Cadrez bien le vêtement\n"
68
+ output += "• Évitez les images avec plusieurs personnes\n"
69
+
70
+ return output
71
+
72
+ except Exception as e:
73
+ return f"❌ Erreur lors de la classification: {str(e)}"
74
+
75
+ def load_example_image(url):
76
+ """Charge une image d'exemple depuis une URL"""
77
+ try:
78
+ response = requests.get(url, timeout=10)
79
+ image = Image.open(BytesIO(response.content))
80
+ return image
81
+ except:
82
+ return None
83
+
84
+ # Exemples d'images de test
85
+ example_images = [
86
+ ["https://images.unsplash.com/photo-1558769132-cb1aea458c5e?w=400"], # T-shirt
87
+ ["https://images.unsplash.com/photo-1594633312681-425c7b97ccd1?w=400"], # Robe
88
+ ["https://images.unsplash.com/photo-1529111290557-82f6d5c6cf85?w=400"], # Chemise
89
+ ["https://images.unsplash.com/photo-1543163521-1bf539c55dd2?w=400"], # Veste
90
+ ]
91
+
92
+ # Interface Gradio
93
+ with gr.Blocks(title="Classificateur de Vêtements", theme=gr.themes.Soft()) as demo:
94
+ gr.Markdown("""
95
+ # 👗 Classificateur de Vêtements Intelligent
96
+ **Uploader une image de vêtement** pour obtenir sa classification automatique
97
+ """)
98
+
99
+ with gr.Row():
100
+ with gr.Column(scale=1):
101
+ gr.Markdown("### 📤 Uploader votre image")
102
+ image_input = gr.Image(
103
+ type="pil",
104
+ label="Image de vêtement",
105
+ height=300,
106
+ sources=["upload", "webcam", "clipboard"]
107
+ )
108
+
109
+ gr.Markdown("### 🎯 Actions")
110
+ classify_btn = gr.Button("🚀 Classifier", variant="primary")
111
+ clear_btn = gr.Button("🧹 Effacer", variant="secondary")
112
+
113
+ gr.Markdown("### 💡 Conseils")
114
+ gr.Markdown("""
115
+ - Images claires et bien éclairées
116
+ - Vêtement visible et bien cadré
117
+ - Fond simple de préférence
118
+ """)
119
+
120
+ with gr.Column(scale=2):
121
+ gr.Markdown("### 📊 Résultats")
122
+ output_text = gr.Markdown(
123
+ value="⬅️ Uploader une image ou choisissez un exemple ci-dessous"
124
+ )
125
+
126
+ # Section exemples
127
+ gr.Markdown("### 🖼️ Exemples à tester")
128
+ gr.Examples(
129
+ examples=example_images,
130
+ inputs=image_input,
131
+ outputs=output_text,
132
+ fn=classify_clothing,
133
+ label="Cliquez sur une image pour tester",
134
+ cache_examples=True
135
+ )
136
+
137
+ # Événements
138
+ classify_btn.click(
139
+ fn=classify_clothing,
140
+ inputs=[image_input],
141
+ outputs=output_text
142
+ )
143
+
144
+ clear_btn.click(
145
+ fn=lambda: (None, "⬅️ Uploader une nouvelle image"),
146
+ inputs=[],
147
+ outputs=[image_input, output_text]
148
+ )
149
+
150
+ # Classification automatique au changement
151
+ image_input.change(
152
+ fn=classify_clothing,
153
+ inputs=[image_input],
154
+ outputs=output_text
155
+ )
156
+
157
+ # Configuration
158
+ if __name__ == "__main__":
159
+ demo.launch(
160
+ server_name="0.0.0.0",
161
+ server_port=7860,
162
+ share=False,
163
+ debug=True
164
+ )