import os
import argparse
import pandas as pd
from datasets import Dataset
from sacrebleu.metrics import BLEU, CHRF
from peft import LoraConfig, get_peft_model
from IndicTransToolkit import IndicProcessor, IndicDataCollator
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    EarlyStoppingCallback,
)

bleu_metric = BLEU()
chrf_metric = CHRF()

def get_arg_parse():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
    )
    parser.add_argument(
        "--src_lang_list",
        type=str,
        help="comma separated list of source languages",
    )
    parser.add_argument(
        "--tgt_lang_list",
        type=str,
        help="comma separated list of target languages",
    )
    parser.add_argument("--data_dir", type=str)
    parser.add_argument("--output_dir", type=str)
    parser.add_argument("--save_steps", type=int, default=1000)
    parser.add_argument("--eval_steps", type=int, default=1000)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument("--max_steps", type=int, default=1000000)
    parser.add_argument("--grad_accum_steps", type=int, default=4)
    parser.add_argument("--warmup_steps", type=int, default=4000)
    parser.add_argument("--warmup_ratio", type=float, default=0.0)
    parser.add_argument("--max_grad_norm", type=float, default=1.0)
    parser.add_argument("--learning_rate", type=float, default=5e-4)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument("--adam_beta1", type=float, default=0.9)
    parser.add_argument("--adam_beta2", type=float, default=0.98)
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument("--print_samples", action="store_true")
    parser.add_argument(
        "--optimizer",
        type=str,
        default="adamw_torch",
        choices=[
            "adamw_hf",
            "adamw_torch",
            "adamw_torch_fused",
            "adamw_apex_fused",
            "adafactor",
        ],
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="inverse_sqrt",
        choices=[
            "inverse_sqrt",
            "linear",
            "polynomial",
            "cosine",
            "constant",
            "constant_with_warmup",
        ],
    )
    parser.add_argument("--label_smoothing", type=float, default=0.0)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--metric_for_best_model", type=str, default="eval_loss")
    parser.add_argument("--greater_is_better", action="store_true")
    parser.add_argument("--lora_target_modules", type=str, default="q_proj,k_proj")
    parser.add_argument("--lora_dropout", type=float, default=0.1)
    parser.add_argument("--lora_r", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=32)
    parser.add_argument(
        "--report_to",
        type=str,
        default="none",
        choices=["wandb", "tensorboard", "azure_ml", "none"],
    )
    parser.add_argument("--patience", type=int, default=5)
    parser.add_argument("--threshold", type=float, default=1e-3)
    return parser

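# Example invocation (an illustrative sketch only; the script name, paths and
# language codes below are placeholders, substitute your own):
#
#   python train_lora.py \
#       --model ai4bharat/indictrans2-en-indic-dist-200M \
#       --src_lang_list eng_Latn \
#       --tgt_lang_list hin_Deva,tam_Taml \
#       --data_dir ./data \
#       --output_dir ./output \
#       --batch_size 32 \
#       --learning_rate 5e-4 \
#       --print_samples
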
def load_and_process_translation_dataset(
    data_dir,
    split="train",
    tokenizer=None,
    processor=None,
    src_lang_list=None,
    tgt_lang_list=None,
    num_proc=8,
    seed=42,
):
    complete_dataset = {
        "sentence_SRC": [],
        "sentence_TGT": [],
    }

    for src_lang in src_lang_list:
        for tgt_lang in tgt_lang_list:
            if src_lang == tgt_lang:
                continue
            src_path = os.path.join(
                data_dir, split, f"{src_lang}-{tgt_lang}", f"{split}.{src_lang}"
            )
            tgt_path = os.path.join(
                data_dir, split, f"{src_lang}-{tgt_lang}", f"{split}.{tgt_lang}"
            )
            if not os.path.exists(src_path) or not os.path.exists(tgt_path):
                raise FileNotFoundError(
                    f"Source ({split}.{src_lang}) or Target ({split}.{tgt_lang}) file not found in {data_dir}"
                )

            with open(src_path, encoding="utf-8") as src_file, open(
                tgt_path, encoding="utf-8"
            ) as tgt_file:
                src_lines = src_file.readlines()
                tgt_lines = tgt_file.readlines()

            # Ensure both files have the same number of lines
            assert len(src_lines) == len(
                tgt_lines
            ), f"Source and Target files have different number of lines for {split}.{src_lang} and {split}.{tgt_lang}"

            complete_dataset["sentence_SRC"] += processor.preprocess_batch(
                src_lines, src_lang=src_lang, tgt_lang=tgt_lang, is_target=False
            )
            complete_dataset["sentence_TGT"] += processor.preprocess_batch(
                tgt_lines, src_lang=tgt_lang, tgt_lang=src_lang, is_target=True
            )

    complete_dataset = Dataset.from_dict(complete_dataset).shuffle(seed=seed)

    return complete_dataset.map(
        lambda example: preprocess_fn(example, tokenizer=tokenizer),
        batched=True,
        num_proc=num_proc,
    )

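# Expected data layout, as implied by the path construction above (directory and
# file names here are illustrative):
#
#   <data_dir>/
#       train/<src_lang>-<tgt_lang>/train.<src_lang>
#       train/<src_lang>-<tgt_lang>/train.<tgt_lang>
#       dev/<src_lang>-<tgt_lang>/dev.<src_lang>
#       dev/<src_lang>-<tgt_lang>/dev.<tgt_lang>
#
# e.g. data/train/eng_Latn-hin_Deva/train.eng_Latn paired with
#      data/train/eng_Latn-hin_Deva/train.hin_Deva
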
def compute_metrics_factory(
    tokenizer, metric_dict=None, print_samples=False, n_samples=10
):
    def compute_metrics(eval_preds):
        preds, labels = eval_preds

        labels[labels == -100] = tokenizer.pad_token_id
        preds[preds == -100] = tokenizer.pad_token_id

        with tokenizer.as_target_tokenizer():
            preds = [
                x.strip()
                for x in tokenizer.batch_decode(
                    preds, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )
            ]
            labels = [
                x.strip()
                for x in tokenizer.batch_decode(
                    labels, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )
            ]

        assert len(preds) == len(
            labels
        ), "Predictions and Labels have different lengths"

        df = pd.DataFrame({"Predictions": preds, "References": labels}).sample(
            n=n_samples
        )

        if print_samples:
            for pred, label in zip(df["Predictions"].values, df["References"].values):
                print(f" | > Prediction: {pred}")
                print(f" | > Reference: {label}\n")

        return {
            metric_name: metric.corpus_score(preds, [labels]).score
            for (metric_name, metric) in metric_dict.items()
        }

    return compute_metrics

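# Note: sacrebleu's corpus_score(hypotheses, references) expects `references` to be
# a list of reference streams, which is why the single reference set is wrapped as
# [labels] above.
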
def preprocess_fn(example, tokenizer, **kwargs):
    model_inputs = tokenizer(
        example["sentence_SRC"], truncation=True, padding=False, max_length=256
    )

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            example["sentence_TGT"], truncation=True, padding=False, max_length=256
        )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

def main(args):
    print(f" | > Loading {args.model} and tokenizer ...")
    model = AutoModelForSeq2SeqLM.from_pretrained(
        args.model,
        trust_remote_code=True,
        attn_implementation="eager",
        dropout=args.dropout,
    )

    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)

    processor = IndicProcessor(inference=False)  # pre-process before tokenization

    data_collator = IndicDataCollator(
        tokenizer=tokenizer,
        model=model,
        padding="longest",  # saves padding tokens
        pad_to_multiple_of=8,  # better to have it as 8 when using fp16
        label_pad_token_id=-100,
    )

    if args.data_dir is not None:
        train_dataset = load_and_process_translation_dataset(
            args.data_dir,
            split="train",
            tokenizer=tokenizer,
            processor=processor,
            src_lang_list=args.src_lang_list.split(","),
            tgt_lang_list=args.tgt_lang_list.split(","),
        )
        print(f" | > Loaded train dataset from {args.data_dir}. Size: {len(train_dataset)} ...")

        eval_dataset = load_and_process_translation_dataset(
            args.data_dir,
            split="dev",
            tokenizer=tokenizer,
            processor=processor,
            src_lang_list=args.src_lang_list.split(","),
            tgt_lang_list=args.tgt_lang_list.split(","),
        )
        print(f" | > Loaded eval dataset from {args.data_dir}. Size: {len(eval_dataset)} ...")
    else:
        raise ValueError(" | > Data directory not provided")

    lora_config = LoraConfig(
        r=args.lora_r,
        bias="none",
        inference_mode=False,
        task_type="SEQ_2_SEQ_LM",
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=args.lora_target_modules.split(","),
    )

    model.set_label_smoothing(args.label_smoothing)

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    print(" | > Loading metrics factory with BLEU and chrF ...")
    seq2seq_compute_metrics = compute_metrics_factory(
        tokenizer=tokenizer,
        print_samples=args.print_samples,
        metric_dict={"BLEU": bleu_metric, "chrF": chrf_metric},
    )

    training_args = Seq2SeqTrainingArguments(
        output_dir=args.output_dir,
        do_train=True,
        do_eval=True,
        fp16=True,  # use fp16 for faster training
        logging_strategy="steps",
        evaluation_strategy="steps",
        save_strategy="steps",
        logging_steps=100,
        save_total_limit=1,
        predict_with_generate=True,
        load_best_model_at_end=True,
        max_steps=args.max_steps,  # max_steps overrides num_train_epochs
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum_steps,
        eval_accumulation_steps=args.grad_accum_steps,
        weight_decay=args.weight_decay,
        adam_beta1=args.adam_beta1,
        adam_beta2=args.adam_beta2,
        max_grad_norm=args.max_grad_norm,
        optim=args.optimizer,
        lr_scheduler_type=args.lr_scheduler,
        warmup_ratio=args.warmup_ratio,
        warmup_steps=args.warmup_steps,
        learning_rate=args.learning_rate,
        num_train_epochs=args.num_train_epochs,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        dataloader_num_workers=args.num_workers,
        metric_for_best_model=args.metric_for_best_model,
        greater_is_better=args.greater_is_better,
        report_to=args.report_to,
        generation_max_length=256,
        generation_num_beams=5,
        sortish_sampler=True,
        group_by_length=True,
        include_tokens_per_second=True,
        include_num_input_tokens_seen=True,
        dataloader_prefetch_factor=2,
    )

    # Create Trainer instance
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=seq2seq_compute_metrics,
        callbacks=[
            EarlyStoppingCallback(
                early_stopping_patience=args.patience,
                early_stopping_threshold=args.threshold,
            )
        ],
    )

    print(" | > Starting training ...")
    try:
        trainer.train()
    except KeyboardInterrupt:
        print(" | > Training interrupted ...")

    # this will only save the LoRA adapter weights
    model.save_pretrained(args.output_dir)


if __name__ == "__main__":
    parser = get_arg_parse()
    args = parser.parse_args()
    main(args)
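
# Reloading the saved adapter for inference (a minimal sketch, assuming the same
# base model is used and that only the LoRA adapter weights were written to
# --output_dir; the identifiers in angle brackets are placeholders):
#
#   from transformers import AutoModelForSeq2SeqLM
#   from peft import PeftModel
#
#   base_model = AutoModelForSeq2SeqLM.from_pretrained(
#       "<base-model-id>", trust_remote_code=True
#   )
#   model = PeftModel.from_pretrained(base_model, "<output_dir>")
#   model = model.merge_and_unload()  # optionally fold the LoRA weights into the base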