# train.py from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer from transformers import DefaultDataCollator from datasets import load_dataset, Image import torch # 1. Charger le dataset et le mapper aux classes dataset = load_dataset("ashraq/fashion-product-images-small", name="styles", split="train") dataset = dataset.train_test_split(test_size=0.2) train_ds = dataset["train"] test_ds = dataset["test"] # 2. Créer la liste des labels (catégories uniques) labels = train_ds.unique("articleType") label2id, id2label = {}, {} for i, label in enumerate(labels): label2id[label] = i id2label[i] = label # 3. Charger le processeur et le modèle de base CORRECTS # On prend un modèle pré-entraîné sur ImageNet, pas sur des haricots ! model_ckpt = "google/vit-base-patch16-224" processor = ViTImageProcessor.from_pretrained(model_ckpt) model = ViTForImageClassification.from_pretrained( model_ckpt, num_labels=len(labels), id2label=id2label, label2id=label2id, ignore_mismatched_sizes=True # Important car le nombre de classes change ) # 4. Fonction de preprocessing pour transformer les images def transform(example_batch): inputs = processor([Image.open(img).convert("RGB") for img in example_batch["image_path"]], return_tensors="pt") inputs["labels"] = [label2id[label] for label in example_batch["articleType"]] return inputs # Appliquer le preprocessing train_ds = train_ds.cast_column("image_path", Image()) test_ds = test_ds.cast_column("image_path", Image()) train_ds.set_transform(transform) test_ds.set_transform(transform) # 5. Définir les arguments d'entraînement training_args = TrainingArguments( output_dir="./vit-fashion-classifier", per_device_train_batch_size=16, evaluation_strategy="steps", num_train_epochs=4, fp16=True, save_steps=100, eval_steps=100, logging_steps=10, learning_rate=2e-4, save_total_limit=2, remove_unused_columns=False, push_to_hub=True, # Pour pousser directement sur votre HF Space après l'entraînement hub_model_id="MODLI/vit-fashion-classifier", # Remplacez par votre repo ) # 6. Lancer l'entraînement trainer = Trainer( model=model, args=training_args, data_collator=DefaultDataCollator(), train_dataset=train_ds, eval_dataset=test_ds, tokenizer=processor, ) trainer.train() trainer.push_to_hub()