MODLI commited on
Commit
e005a79
·
verified ·
1 Parent(s): 205b07b

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +73 -0
train.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # train.py
2
+ from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer
3
+ from transformers import DefaultDataCollator
4
+ from datasets import load_dataset, Image
5
+ import torch
6
+
7
+ # 1. Charger le dataset et le mapper aux classes
8
+ dataset = load_dataset("ashraq/fashion-product-images-small", name="styles", split="train")
9
+ dataset = dataset.train_test_split(test_size=0.2)
10
+ train_ds = dataset["train"]
11
+ test_ds = dataset["test"]
12
+
13
+ # 2. Créer la liste des labels (catégories uniques)
14
+ labels = train_ds.unique("articleType")
15
+ label2id, id2label = {}, {}
16
+ for i, label in enumerate(labels):
17
+ label2id[label] = i
18
+ id2label[i] = label
19
+
20
+ # 3. Charger le processeur et le modèle de base CORRECTS
21
+ # On prend un modèle pré-entraîné sur ImageNet, pas sur des haricots !
22
+ model_ckpt = "google/vit-base-patch16-224"
23
+ processor = ViTImageProcessor.from_pretrained(model_ckpt)
24
+ model = ViTForImageClassification.from_pretrained(
25
+ model_ckpt,
26
+ num_labels=len(labels),
27
+ id2label=id2label,
28
+ label2id=label2id,
29
+ ignore_mismatched_sizes=True # Important car le nombre de classes change
30
+ )
31
+
32
+ # 4. Fonction de preprocessing pour transformer les images
33
+ def transform(example_batch):
34
+ inputs = processor([Image.open(img).convert("RGB") for img in example_batch["image_path"]], return_tensors="pt")
35
+ inputs["labels"] = [label2id[label] for label in example_batch["articleType"]]
36
+ return inputs
37
+
38
+ # Appliquer le preprocessing
39
+ train_ds = train_ds.cast_column("image_path", Image())
40
+ test_ds = test_ds.cast_column("image_path", Image())
41
+
42
+ train_ds.set_transform(transform)
43
+ test_ds.set_transform(transform)
44
+
45
+ # 5. Définir les arguments d'entraînement
46
+ training_args = TrainingArguments(
47
+ output_dir="./vit-fashion-classifier",
48
+ per_device_train_batch_size=16,
49
+ evaluation_strategy="steps",
50
+ num_train_epochs=4,
51
+ fp16=True,
52
+ save_steps=100,
53
+ eval_steps=100,
54
+ logging_steps=10,
55
+ learning_rate=2e-4,
56
+ save_total_limit=2,
57
+ remove_unused_columns=False,
58
+ push_to_hub=True, # Pour pousser directement sur votre HF Space après l'entraînement
59
+ hub_model_id="MODLI/vit-fashion-classifier", # Remplacez par votre repo
60
+ )
61
+
62
+ # 6. Lancer l'entraînement
63
+ trainer = Trainer(
64
+ model=model,
65
+ args=training_args,
66
+ data_collator=DefaultDataCollator(),
67
+ train_dataset=train_ds,
68
+ eval_dataset=test_ds,
69
+ tokenizer=processor,
70
+ )
71
+
72
+ trainer.train()
73
+ trainer.push_to_hub()