# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import inspect
# from dataclasses import asdict

import torch.distributed as dist
from torch.utils.data import DistributedSampler
from peft import (
    LoraConfig,
    AdaptionPromptConfig,
    PrefixTuningConfig,
)
from transformers import default_data_collator
from transformers.data import DataCollatorForSeq2Seq

# from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
from slam_llm.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler

from omegaconf import OmegaConf

import logging

logger = logging.getLogger(__name__)

# def update_config(config, **kwargs):
#     if isinstance(config, (tuple, list)):
#         for c in config:
#             update_config(c, **kwargs)
#     else:
#         for k, v in kwargs.items():
#             if hasattr(config, k):
#                 setattr(config, k, v)
#             elif "." in k:
#                 # allow --some_config.some_param=True
#                 config_name, param_name = k.split(".")
#                 if type(config).__name__ == config_name:
#                     if hasattr(config, param_name):
#                         setattr(config, param_name, v)
#                     else:
#                         # In case of a specialized config we can warn the user
#                         logger.warning(f"Warning: {config_name} does not accept parameter: {k}")
#             elif isinstance(config, train_config):
#                 logger.warning(f"Warning: unknown parameter {k}")


def generate_peft_config(train_config):
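    """Build a PEFT config object from ``train_config.peft_config``.

    The ``peft_method`` field selects which PEFT class to instantiate
    ("lora", "llama_adapter", or "prefix"); the remaining fields are passed
    through as keyword arguments to that class. Defaults to LoRA when
    ``peft_method`` is not set.
    """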
    # configs = (lora_config, llama_adapter_config, prefix_config)
    # peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
    peft_configs = {
        "lora": LoraConfig,
        "llama_adapter": AdaptionPromptConfig,
        "prefix": PrefixTuningConfig,
    }
    # names = tuple(c.__name__.rstrip("_config") for c in configs)
    #
    # assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"
    #
    # config = configs[names.index(train_config.peft_method)]()
    config = train_config.peft_config
    params = OmegaConf.to_container(config, resolve=True)
    # peft_config = peft_configs[names.index(train_config.peft_method)](**params)
    params.pop("peft_method", None)  # (FIX:MZY): remove peft_method from params to avoid error
    peft_config = peft_configs[config.get("peft_method", "lora")](**params)

    return peft_config


def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
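    """Build keyword arguments for a ``torch.utils.data.DataLoader``.

    ``train_config.batching_strategy`` selects the sampler/collator pair:
    "padding" uses length-based batch samplers with ``DataCollatorForSeq2Seq``,
    "packing" uses a plain (optionally distributed) sampler with
    ``default_data_collator``, and any other value falls back to the dataset's
    own ``collator``.
    """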
    kwargs = {}
    batch_size = train_config.batch_size_training if mode == "train" else train_config.val_batch_size
    if train_config.batching_strategy == "padding":
        if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed:
            kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
                dataset,
                batch_size=batch_size,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
                shuffle=(mode == "train"),
            )
        else:
            kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=(mode == "train"))
        kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
    elif train_config.batching_strategy == "packing":
        if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed:
            kwargs["sampler"] = DistributedSampler(
                dataset,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
                shuffle=(mode == "train"),
            )
        kwargs["batch_size"] = batch_size
        kwargs["drop_last"] = True
        kwargs["collate_fn"] = default_data_collator
    else:
        # raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
        if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed:
            kwargs["sampler"] = DistributedSampler(
                dataset,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
                shuffle=(mode == "train"),
            )
        kwargs["batch_size"] = batch_size
        kwargs["drop_last"] = True
        kwargs["collate_fn"] = dataset.collator

    logger.info(f"Using batching strategy: {train_config.batching_strategy}")
    return kwargs
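

# Illustrative usage sketch (commented out; not part of the original module).
# The objects below are hypothetical: `my_dataset` is assumed to expose a
# `.collator` attribute, and `tokenizer` to be a Hugging Face tokenizer. The
# config fields simply mirror the attributes accessed by the helpers above.
#
# from omegaconf import OmegaConf
# from torch.utils.data import DataLoader
#
# train_config = OmegaConf.create({
#     "batching_strategy": "custom",   # anything other than "padding"/"packing" uses dataset.collator
#     "batch_size_training": 4,
#     "val_batch_size": 4,
#     "enable_fsdp": False,
#     "enable_ddp": False,
#     "enable_deepspeed": False,
#     "peft_config": {
#         "peft_method": "lora",
#         "r": 8,
#         "lora_alpha": 32,
#         "target_modules": ["q_proj", "v_proj"],
#         "lora_dropout": 0.05,
#     },
# })
#
# peft_config = generate_peft_config(train_config)
# dl_kwargs = get_dataloader_kwargs(train_config, my_dataset, tokenizer, mode="train")
# train_loader = DataLoader(my_dataset, num_workers=2, **dl_kwargs)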