MODLI commited on
Commit
140fdb2
·
verified ·
1 Parent(s): d8bd6de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -80
app.py CHANGED
@@ -1,164 +1,220 @@
1
  import gradio as gr
2
  import torch
 
3
  from transformers import AutoImageProcessor, AutoModelForImageClassification
4
  from PIL import Image
5
  import requests
6
  from io import BytesIO
 
7
 
8
- # Chargement du modèle spécialisé dans la mode
9
- MODEL_NAME = "google/vit-base-patch16-224" # Modèle de base fiable
10
- # Alternative: "nateraw/fashion-clip" si disponible
11
 
12
- # Initialisation du modèle
13
- device = "cuda" if torch.cuda.is_available() else "cpu"
14
- print(f"🖥️ Utilisation du device: {device}")
15
 
16
  try:
17
- # Chargeur d'images
18
- processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
19
- # Modèle de classification
20
- model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
 
 
 
 
 
 
 
 
 
 
 
21
  model.to(device)
22
  model.eval()
23
- print("✅ Modèle chargé avec succès!")
 
 
 
24
  except Exception as e:
25
- print(f"❌ Erreur chargement modèle: {e}")
26
  processor = None
27
  model = None
28
 
29
- def classify_clothing(image):
30
- """Classifie une image de vêtement"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  try:
32
  if image is None:
33
  return "❌ Veuillez uploader une image de vêtement"
34
 
35
  if processor is None or model is None:
36
- return "⚠️ Modèle en cours de chargement... Réessayez dans 30 secondes"
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- # Prétraitement de l'image
39
- inputs = processor(images=image, return_tensors="pt")
40
  inputs = {k: v.to(device) for k, v in inputs.items()}
41
 
42
- # Classification
43
  with torch.no_grad():
44
  outputs = model(**inputs)
45
 
46
- # Récupération des résultats
47
- probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
48
  top_probs, top_indices = torch.topk(probabilities, 5)
49
 
50
- # Conversion en résultats lisibles
51
  results = []
52
  for i in range(len(top_indices[0])):
53
- label = model.config.id2label[top_indices[0][i].item()]
 
 
54
  score = top_probs[0][i].item() * 100
55
- results.append({"label": label, "score": score})
 
 
 
56
 
57
- # Formatage des résultats
58
- output = "## 🎯 Résultats de Classification:\n\n"
59
  for i, result in enumerate(results):
60
- # Nettoyage des labels
61
- clean_label = result['label'].replace('_', ' ').title()
62
- output += f"{i+1}. **{clean_label}** - {result['score']:.1f}%\n"
 
 
 
63
 
64
- output += "\n---\n"
65
- output += "💡 **Conseils pour de meilleurs résultats:**\n"
66
- output += "• Utilisez des images claires sur fond uni\n"
67
- output += "• Cadrez bien le vêtement\n"
68
- output += "• Évitez les images avec plusieurs personnes\n"
69
 
70
  return output
71
 
72
  except Exception as e:
73
- return f"❌ Erreur lors de la classification: {str(e)}"
 
 
 
 
 
 
 
 
74
 
75
- def load_example_image(url):
76
- """Charge une image d'exemple depuis une URL"""
77
  try:
78
  response = requests.get(url, timeout=10)
79
- image = Image.open(BytesIO(response.content))
80
- return image
81
  except:
82
  return None
83
 
84
- # Exemples d'images de test
85
- example_images = [
86
- ["https://images.unsplash.com/photo-1558769132-cb1aea458c5e?w=400"], # T-shirt
87
- ["https://images.unsplash.com/photo-1594633312681-425c7b97ccd1?w=400"], # Robe
88
- ["https://images.unsplash.com/photo-1529111290557-82f6d5c6cf85?w=400"], # Chemise
89
- ["https://images.unsplash.com/photo-1543163521-1bf539c55dd2?w=400"], # Veste
90
- ]
91
-
92
- # Interface Gradio
93
- with gr.Blocks(title="Classificateur de Vêtements", theme=gr.themes.Soft()) as demo:
94
  gr.Markdown("""
95
- # 👗 Classificateur de Vêtements Intelligent
96
- **Uploader une image de vêtement** pour obtenir sa classification automatique
97
  """)
98
 
99
  with gr.Row():
100
  with gr.Column(scale=1):
101
- gr.Markdown("### 📤 Uploader votre image")
102
  image_input = gr.Image(
103
  type="pil",
104
  label="Image de vêtement",
105
  height=300,
106
- sources=["upload", "webcam", "clipboard"]
 
107
  )
108
 
109
- gr.Markdown("### 🎯 Actions")
110
- classify_btn = gr.Button("🚀 Classifier", variant="primary")
111
- clear_btn = gr.Button("🧹 Effacer", variant="secondary")
112
 
113
- gr.Markdown("### 💡 Conseils")
114
  gr.Markdown("""
115
- - Images claires et bien éclairées
116
- - Vêtement visible et bien cadré
117
- - Fond simple de préférence
 
 
118
  """)
119
 
120
  with gr.Column(scale=2):
121
- gr.Markdown("### 📊 Résultats")
122
  output_text = gr.Markdown(
123
- value="⬅️ Uploader une image ou choisissez un exemple ci-dessous"
124
  )
125
 
126
- # Section exemples
127
- gr.Markdown("### 🖼️ Exemples à tester")
128
- gr.Examples(
129
- examples=example_images,
130
- inputs=image_input,
131
- outputs=output_text,
132
- fn=classify_clothing,
133
- label="Cliquez sur une image pour tester",
134
- cache_examples=True
135
- )
 
 
136
 
137
- # Événements
138
  classify_btn.click(
139
- fn=classify_clothing,
140
  inputs=[image_input],
141
- outputs=output_text
 
142
  )
143
 
144
  clear_btn.click(
145
- fn=lambda: (None, "⬅️ Uploader une nouvelle image"),
146
  inputs=[],
147
  outputs=[image_input, output_text]
148
  )
149
 
150
- # Classification automatique au changement
151
- image_input.change(
152
- fn=classify_clothing,
153
  inputs=[image_input],
154
  outputs=output_text
155
  )
156
 
157
- # Configuration
158
  if __name__ == "__main__":
159
  demo.launch(
160
  server_name="0.0.0.0",
161
  server_port=7860,
162
  share=False,
163
- debug=True
 
164
  )
 
1
  import gradio as gr
2
  import torch
3
+ import torch.nn.functional as F
4
  from transformers import AutoImageProcessor, AutoModelForImageClassification
5
  from PIL import Image
6
  import requests
7
  from io import BytesIO
8
+ import numpy as np
9
 
10
+ # 🔥 MODÈLE SPÉCIALISÉ DANS LA MODE
11
+ MODEL_NAME = "rafalosa/diffusiondb-fashion-mnist" # Modèle spécialisé mode
12
+ # Alternative: "nateraw/vit-base-patch16-224-fashion-mnist"
13
 
14
+ print("🔄 Chargement du modèle de mode...")
 
 
15
 
16
  try:
17
+ # Chargeur d'images avec prétraitement correct
18
+ processor = AutoImageProcessor.from_pretrained(
19
+ "google/vit-base-patch16-224", # Base standard
20
+ cache_dir="cache"
21
+ )
22
+
23
+ # Modèle fine-tuné sur la mode
24
+ model = AutoModelForImageClassification.from_pretrained(
25
+ MODEL_NAME,
26
+ cache_dir="cache",
27
+ trust_remote_code=True
28
+ )
29
+
30
+ # Configuration device
31
+ device = "cuda" if torch.cuda.is_available() else "cpu"
32
  model.to(device)
33
  model.eval()
34
+
35
+ print(f"✅ Modèle chargé sur {device}")
36
+ print(f"📊 Classe disponibles: {model.config.num_labels}")
37
+
38
  except Exception as e:
39
+ print(f"❌ Erreur chargement: {e}")
40
  processor = None
41
  model = None
42
 
43
+ # 🎯 LABELS COMPRÉHENSIBLES POUR LA MODE
44
+ FASHION_LABELS = [
45
+ "T-shirt", "Pantalon", "Pull", "Robe", "Manteau",
46
+ "Sandale", "Chemise", "Sneaker", "Sac", "Botte"
47
+ ]
48
+
49
+ def preprocess_image(image):
50
+ """Prétraitement correct des images"""
51
+ # Conversion en RGB
52
+ if image.mode != 'RGB':
53
+ image = image.convert('RGB')
54
+
55
+ # Redimensionnement intelligent
56
+ image = image.resize((224, 224), Image.Resampling.LANCZOS)
57
+
58
+ return image
59
+
60
+ def classify_fashion(image):
61
+ """Classification spécialisée mode"""
62
  try:
63
  if image is None:
64
  return "❌ Veuillez uploader une image de vêtement"
65
 
66
  if processor is None or model is None:
67
+ return "⚠️ Modèle en cours de chargement... Patientez 30s"
68
+
69
+ # 🔥 PRÉTRAITEMENT CORRECT
70
+ processed_image = preprocess_image(image)
71
+
72
+ # Transformation pour le modèle
73
+ inputs = processor(
74
+ images=processed_image,
75
+ return_tensors="pt",
76
+ do_resize=True,
77
+ do_rescale=True,
78
+ do_normalize=True
79
+ )
80
 
81
+ # Transfert sur le bon device
 
82
  inputs = {k: v.to(device) for k, v in inputs.items()}
83
 
84
+ # 🔥 INFÉRENCE AVEC GRADIENTS DÉSACTIVÉS
85
  with torch.no_grad():
86
  outputs = model(**inputs)
87
 
88
+ # 🔥 POST-TRAITEMENT CORRECT
89
+ probabilities = F.softmax(outputs.logits, dim=-1)
90
  top_probs, top_indices = torch.topk(probabilities, 5)
91
 
92
+ # Conversion en résultats
93
  results = []
94
  for i in range(len(top_indices[0])):
95
+ # Utilisation de nos labels personnalisés
96
+ label_idx = top_indices[0][i].item()
97
+ label_name = FASHION_LABELS[label_idx % len(FASHION_LABELS)]
98
  score = top_probs[0][i].item() * 100
99
+ results.append({"label": label_name, "score": score})
100
+
101
+ # 📊 FORMATAGE DES RÉSULTATS
102
+ output = "## 🎯 RÉSULTATS DE CLASSIFICATION:\n\n"
103
 
 
 
104
  for i, result in enumerate(results):
105
+ output += f"{i+1}. **{result['label']}** - {result['score']:.1f}%\n"
106
+
107
+ # 📸 Aperçu de l'image traitée
108
+ output += f"\n---\n"
109
+ output += f"📏 Image traitée: 224x224 pixels\n"
110
+ output += f"🔢 Modèle: {MODEL_NAME.split('/')[-1]}\n"
111
 
112
+ output += "\n💡 **Pour de meilleurs résultats:**\n"
113
+ output += " Photo claire sur fond uni\n"
114
+ output += "• Vêtement bien visible\n"
115
+ output += "• Éviter les angles bizarres\n"
 
116
 
117
  return output
118
 
119
  except Exception as e:
120
+ return f"❌ Erreur: {str(e)}\n\n🔧 Vérifiez les logs pour plus de détails"
121
+
122
+ # 🖼️ EXEMPLES SPÉCIFIQUES MODE
123
+ EXAMPLE_URLS = [
124
+ "https://images.unsplash.com/photo-1558769132-cb1aea458c5e?w=400", # T-shirt
125
+ "https://images.unsplash.com/photo-1594633312681-425c7b97ccd1?w=400", # Robe
126
+ "https://images.unsplash.com/photo-1529111290557-82f6d5c6cf85?w=400", # Chemise
127
+ "https://images.unsplash.com/photo-1543163521-1bf539c55dd2?w=400", # Veste
128
+ ]
129
 
130
+ def load_example(url):
131
+ """Charge un exemple depuis une URL"""
132
  try:
133
  response = requests.get(url, timeout=10)
134
+ return Image.open(BytesIO(response.content))
 
135
  except:
136
  return None
137
 
138
+ # 🎨 INTERFACE AMÉLIORÉE
139
+ with gr.Blocks(
140
+ title="Classificateur de Mode Expert",
141
+ theme=gr.themes.Soft(primary_hue="pink")
142
+ ) as demo:
143
+
 
 
 
 
144
  gr.Markdown("""
145
+ # 👗 CLASSIFICATEUR EXPERT DE VÊTEMENTS
146
+ *Powered by Fine-Tuned Vision Transformer*
147
  """)
148
 
149
  with gr.Row():
150
  with gr.Column(scale=1):
151
+ gr.Markdown("### 📤 UPLOADER")
152
  image_input = gr.Image(
153
  type="pil",
154
  label="Image de vêtement",
155
  height=300,
156
+ sources=["upload", "clipboard"],
157
+ interactive=True
158
  )
159
 
160
+ with gr.Row():
161
+ classify_btn = gr.Button("🚀 Classifier", variant="primary")
162
+ clear_btn = gr.Button("🧹 Effacer", variant="secondary")
163
 
 
164
  gr.Markdown("""
165
+ ### 💡 CONSEILS
166
+ - 📷 Photo claire et nette
167
+ - 🎯 Vêtement bien centré
168
+ - 🌟 Fond uni de préférence
169
+ - ⚡ Attendez 3-5 secondes
170
  """)
171
 
172
  with gr.Column(scale=2):
173
+ gr.Markdown("### 📊 RÉSULTATS")
174
  output_text = gr.Markdown(
175
+ value="⬅️ Uploader une image ou utilisez les exemples ci-dessous"
176
  )
177
 
178
+ # 🎯 EXEMPLES INTERACTIFS
179
+ gr.Markdown("### 🖼️ EXEMPLES DE TEST")
180
+ with gr.Row():
181
+ for i, url in enumerate(EXAMPLE_URLS):
182
+ gr.Examples(
183
+ examples=[[url]],
184
+ inputs=image_input,
185
+ outputs=output_text,
186
+ fn=classify_fashion,
187
+ label=f"Exemple {i+1}",
188
+ cache_examples=False
189
+ )
190
 
191
+ # 🎮 INTERACTIONS
192
  classify_btn.click(
193
+ fn=classify_fashion,
194
  inputs=[image_input],
195
+ outputs=output_text,
196
+ api_name="classify"
197
  )
198
 
199
  clear_btn.click(
200
+ fn=lambda: (None, "⬅️ Prêt pour une nouvelle image"),
201
  inputs=[],
202
  outputs=[image_input, output_text]
203
  )
204
 
205
+ # 🔄 AUTO-CLASSIFICATION
206
+ image_input.upload(
207
+ fn=classify_fashion,
208
  inputs=[image_input],
209
  outputs=output_text
210
  )
211
 
212
+ # ⚙️ CONFIGURATION
213
  if __name__ == "__main__":
214
  demo.launch(
215
  server_name="0.0.0.0",
216
  server_port=7860,
217
  share=False,
218
+ debug=True,
219
+ show_error=True
220
  )