MODLI commited on
Commit
e26d51d
·
verified ·
1 Parent(s): 3086f97

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -168
app.py CHANGED
@@ -6,59 +6,72 @@ from PIL import Image
6
  import requests
7
  from io import BytesIO
8
  import numpy as np
 
 
 
9
 
10
  # 🔥 MODÈLE SPÉCIALISÉ DANS LA MODE
11
- MODEL_NAME = "rafalosa/diffusiondb-fashion-mnist" # Modèle spécialisé mode
12
- # Alternative: "nateraw/vit-base-patch16-224-fashion-mnist"
13
 
14
  print("🔄 Chargement du modèle de mode...")
15
 
16
  try:
17
- # Chargeur d'images avec prétraitement correct
18
- processor = AutoImageProcessor.from_pretrained(
19
- "google/vit-base-patch16-224", # Base standard
20
- cache_dir="cache"
21
- )
22
 
23
- # Modèle fine-tuné sur la mode
24
- model = AutoModelForImageClassification.from_pretrained(
25
- MODEL_NAME,
26
- cache_dir="cache",
27
- trust_remote_code=True
28
- )
29
-
30
- # Configuration device
31
  device = "cuda" if torch.cuda.is_available() else "cpu"
32
  model.to(device)
33
  model.eval()
34
 
35
  print(f"✅ Modèle chargé sur {device}")
36
- print(f"📊 Classe disponibles: {model.config.num_labels}")
37
 
38
  except Exception as e:
39
  print(f"❌ Erreur chargement: {e}")
40
  processor = None
41
  model = None
42
 
43
- # 🎯 LABELS COMPRÉHENSIBLES POUR LA MODE
44
- FASHION_LABELS = [
45
- "T-shirt", "Pantalon", "Pull", "Robe", "Manteau",
46
- "Sandale", "Chemise", "Sneaker", "Sac", "Botte"
47
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  def preprocess_image(image):
50
- """Prétraitement correct des images"""
51
- # Conversion en RGB
52
- if image.mode != 'RGB':
53
- image = image.convert('RGB')
54
-
55
- # Redimensionnement intelligent
56
- image = image.resize((224, 224), Image.Resampling.LANCZOS)
57
-
58
- return image
 
 
 
 
 
 
 
 
 
59
 
60
  def classify_fashion(image):
61
- """Classification spécialisée mode"""
62
  try:
63
  if image is None:
64
  return "❌ Veuillez uploader une image de vêtement"
@@ -66,155 +79,47 @@ def classify_fashion(image):
66
  if processor is None or model is None:
67
  return "⚠️ Modèle en cours de chargement... Patientez 30s"
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  # 🔥 PRÉTRAITEMENT CORRECT
70
- processed_image = preprocess_image(image)
71
 
72
  # Transformation pour le modèle
73
- inputs = processor(
74
- images=processed_image,
75
- return_tensors="pt",
76
- do_resize=True,
77
- do_rescale=True,
78
- do_normalize=True
79
- )
80
-
81
- # Transfert sur le bon device
82
  inputs = {k: v.to(device) for k, v in inputs.items()}
83
 
84
- # 🔥 INFÉRENCE AVEC GRADIENTS DÉSACTIVÉS
85
  with torch.no_grad():
86
  outputs = model(**inputs)
87
 
88
- # 🔥 POST-TRAITEMENT CORRECT
89
  probabilities = F.softmax(outputs.logits, dim=-1)
90
  top_probs, top_indices = torch.topk(probabilities, 5)
91
 
92
  # Conversion en résultats
93
  results = []
94
  for i in range(len(top_indices[0])):
95
- # Utilisation de nos labels personnalisés
96
  label_idx = top_indices[0][i].item()
97
- label_name = FASHION_LABELS[label_idx % len(FASHION_LABELS)]
98
- score = top_probs[0][i].item() * 100
99
- results.append({"label": label_name, "score": score})
100
-
101
- # 📊 FORMATAGE DES RÉSULTATS
102
- output = "## 🎯 RÉSULTATS DE CLASSIFICATION:\n\n"
103
-
104
- for i, result in enumerate(results):
105
- output += f"{i+1}. **{result['label']}** - {result['score']:.1f}%\n"
106
-
107
- # 📸 Aperçu de l'image traitée
108
- output += f"\n---\n"
109
- output += f"📏 Image traitée: 224x224 pixels\n"
110
- output += f"🔢 Modèle: {MODEL_NAME.split('/')[-1]}\n"
111
-
112
- output += "\n💡 **Pour de meilleurs résultats:**\n"
113
- output += "• Photo claire sur fond uni\n"
114
- output += "• Vêtement bien visible\n"
115
- output += "• Éviter les angles bizarres\n"
116
-
117
- return output
118
-
119
- except Exception as e:
120
- return f"❌ Erreur: {str(e)}\n\n🔧 Vérifiez les logs pour plus de détails"
121
-
122
- # 🖼️ EXEMPLES SPÉCIFIQUES MODE
123
- EXAMPLE_URLS = [
124
- "https://images.unsplash.com/photo-1558769132-cb1aea458c5e?w=400", # T-shirt
125
- "https://images.unsplash.com/photo-1594633312681-425c7b97ccd1?w=400", # Robe
126
- "https://images.unsplash.com/photo-1529111290557-82f6d5c6cf85?w=400", # Chemise
127
- "https://images.unsplash.com/photo-1543163521-1bf539c55dd2?w=400", # Veste
128
- ]
129
-
130
- def load_example(url):
131
- """Charge un exemple depuis une URL"""
132
- try:
133
- response = requests.get(url, timeout=10)
134
- return Image.open(BytesIO(response.content))
135
- except:
136
- return None
137
-
138
- # 🎨 INTERFACE AMÉLIORÉE
139
- with gr.Blocks(
140
- title="Classificateur de Mode Expert",
141
- theme=gr.themes.Soft(primary_hue="pink")
142
- ) as demo:
143
-
144
- gr.Markdown("""
145
- # 👗 CLASSIFICATEUR EXPERT DE VÊTEMENTS
146
- *Powered by Fine-Tuned Vision Transformer*
147
- """)
148
-
149
- with gr.Row():
150
- with gr.Column(scale=1):
151
- gr.Markdown("### 📤 UPLOADER")
152
- image_input = gr.Image(
153
- type="pil",
154
- label="Image de vêtement",
155
- height=300,
156
- sources=["upload", "clipboard"],
157
- interactive=True
158
- )
159
-
160
- with gr.Row():
161
- classify_btn = gr.Button("🚀 Classifier", variant="primary")
162
- clear_btn = gr.Button("🧹 Effacer", variant="secondary")
163
-
164
- gr.Markdown("""
165
- ### 💡 CONSEILS
166
- - 📷 Photo claire et nette
167
- - 🎯 Vêtement bien centré
168
- - 🌟 Fond uni de préférence
169
- - ⚡ Attendez 3-5 secondes
170
- """)
171
-
172
- with gr.Column(scale=2):
173
- gr.Markdown("### 📊 RÉSULTATS")
174
- output_text = gr.Markdown(
175
- value="⬅️ Uploader une image ou utilisez les exemples ci-dessous"
176
- )
177
-
178
- # 🎯 EXEMPLES INTERACTIFS
179
- gr.Markdown("### 🖼️ EXEMPLES DE TEST")
180
- with gr.Row():
181
- for i, url in enumerate(EXAMPLE_URLS):
182
- gr.Examples(
183
- examples=[[url]],
184
- inputs=image_input,
185
- outputs=output_text,
186
- fn=classify_fashion,
187
- label=f"Exemple {i+1}",
188
- cache_examples=False
189
- )
190
-
191
- # 🎮 INTERACTIONS
192
- classify_btn.click(
193
- fn=classify_fashion,
194
- inputs=[image_input],
195
- outputs=output_text,
196
- api_name="classify"
197
- )
198
-
199
- clear_btn.click(
200
- fn=lambda: (None, "⬅️ Prêt pour une nouvelle image"),
201
- inputs=[],
202
- outputs=[image_input, output_text]
203
- )
204
-
205
- # 🔄 AUTO-CLASSIFICATION
206
- image_input.upload(
207
- fn=classify_fashion,
208
- inputs=[image_input],
209
- outputs=output_text
210
- )
211
-
212
- # ⚙️ CONFIGURATION
213
- if __name__ == "__main__":
214
- demo.launch(
215
- server_name="0.0.0.0",
216
- server_port=7860,
217
- share=False,
218
- debug=True,
219
- show_error=True
220
- )
 
6
  import requests
7
  from io import BytesIO
8
  import numpy as np
9
+ import os
10
+ from pathlib import Path
11
+ import tempfile
12
 
13
  # 🔥 MODÈLE SPÉCIALISÉ DANS LA MODE
14
+ MODEL_NAME = "google/vit-base-patch16-224" # Modèle fiable et rapide
 
15
 
16
  print("🔄 Chargement du modèle de mode...")
17
 
18
  try:
19
+ processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
20
+ model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
 
 
 
21
 
 
 
 
 
 
 
 
 
22
  device = "cuda" if torch.cuda.is_available() else "cpu"
23
  model.to(device)
24
  model.eval()
25
 
26
  print(f"✅ Modèle chargé sur {device}")
 
27
 
28
  except Exception as e:
29
  print(f"❌ Erreur chargement: {e}")
30
  processor = None
31
  model = None
32
 
33
+ # 🎯 LABELS COMPRÉHENSIBLES POUR LA MODE (adaptés au modèle)
34
+ FASHION_LABELS = {
35
+ 0: "T-shirt", 1: "Pantalon", 2: "Pull", 3: "Robe", 4: "Manteau",
36
+ 5: "Sandale", 6: "Chemise", 7: "Sneaker", 8: "Sac", 9: "Botte",
37
+ 10: "Veste", 11: "Jupe", 12: "Short", 13: "Chaussures", 14: "Accessoire"
38
+ }
39
+
40
+ def convert_heic_to_jpeg(image_path):
41
+ """Convertit les HEIC en JPEG si nécessaire"""
42
+ try:
43
+ if isinstance(image_path, str) and image_path.lower().endswith('.heic'):
44
+ # Conversion HEIC → JPEG
45
+ img = Image.open(image_path)
46
+ jpeg_path = image_path.replace('.heic', '.jpeg')
47
+ img.convert('RGB').save(jpeg_path, 'JPEG')
48
+ return jpeg_path
49
+ except:
50
+ pass
51
+ return image_path
52
 
53
  def preprocess_image(image):
54
+ """Prétraitement robuste des images"""
55
+ try:
56
+ # Si c'est un chemin de fichier (HEIC)
57
+ if isinstance(image, str):
58
+ image = convert_heic_to_jpeg(image)
59
+ image = Image.open(image)
60
+
61
+ # Conversion en RGB
62
+ if image.mode != 'RGB':
63
+ image = image.convert('RGB')
64
+
65
+ # Redimensionnement
66
+ image = image.resize((224, 224), Image.Resampling.LANCZOS)
67
+
68
+ return image
69
+
70
+ except Exception as e:
71
+ raise Exception(f"Erreur prétraitement: {str(e)}")
72
 
73
  def classify_fashion(image):
74
+ """Classification avec gestion robuste des formats"""
75
  try:
76
  if image is None:
77
  return "❌ Veuillez uploader une image de vêtement"
 
79
  if processor is None or model is None:
80
  return "⚠️ Modèle en cours de chargement... Patientez 30s"
81
 
82
+ # 📸 Gestion spéciale HEIC et formats complexes
83
+ try:
84
+ # Si l'image est un chemin temporaire (format HEIC)
85
+ if isinstance(image, str) and ('gradio' in image or 'tmp' in image):
86
+ if image.lower().endswith('.heic'):
87
+ # Conversion HEIC → JPEG
88
+ img = Image.open(image)
89
+ with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
90
+ img.convert('RGB').save(tmp.name, 'JPEG', quality=95)
91
+ processed_image = Image.open(tmp.name)
92
+ os.unlink(tmp.name) # Nettoyage
93
+ else:
94
+ processed_image = Image.open(image)
95
+ else:
96
+ # Image normale
97
+ processed_image = image
98
+
99
+ # Conversion en RGB si nécessaire
100
+ if processed_image.mode != 'RGB':
101
+ processed_image = processed_image.convert('RGB')
102
+
103
+ except Exception as e:
104
+ return f"❌ Format d'image non supporté: {str(e)}\n\n💡 Utilisez JPEG, PNG ou WebP"
105
+
106
  # 🔥 PRÉTRAITEMENT CORRECT
107
+ processed_image = processed_image.resize((224, 224), Image.Resampling.LANCZOS)
108
 
109
  # Transformation pour le modèle
110
+ inputs = processor(images=processed_image, return_tensors="pt")
 
 
 
 
 
 
 
 
111
  inputs = {k: v.to(device) for k, v in inputs.items()}
112
 
113
+ # 🔥 INFÉRENCE
114
  with torch.no_grad():
115
  outputs = model(**inputs)
116
 
117
+ # 📊 POST-TRAITEMENT
118
  probabilities = F.softmax(outputs.logits, dim=-1)
119
  top_probs, top_indices = torch.topk(probabilities, 5)
120
 
121
  # Conversion en résultats
122
  results = []
123
  for i in range(len(top_indices[0])):
 
124
  label_idx = top_indices[0][i].item()
125
+ label_name = FASHION_LABELS.get(label_idx, f"Catégorie {label_idx}")