import os
import types
import torch
import soundfile as sf
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from typing import List, Optional, Tuple, Union
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModel, AutoModelForSeq2SeqLM, T5ForConditionalGeneration
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from slam_llm.utils.config_utils import generate_peft_config
from slam_llm.utils.train_utils import print_module_size, print_model_size
from peft import PeftModel, PeftConfig
from torch.nn import CrossEntropyLoss
from slam_llm.utils.metric import compute_accuracy

import logging
logger = logging.getLogger(__name__)
def model_factory(train_config, model_config, **kwargs):
    # return necessary components for training
    tokenizer = setup_tokenizer(train_config, model_config, **kwargs)

    encoder = setup_encoder(train_config, model_config, **kwargs)

    # llm
    llm = setup_llm(train_config, model_config, **kwargs)

    # projector
    encoder_projector = setup_encoder_projector(
        train_config, model_config, **kwargs
    )

    model = slam_model(
        encoder,
        llm,
        encoder_projector,
        tokenizer,
        train_config,
        model_config,
        **kwargs,
    )

    ckpt_path = kwargs.get("ckpt_path", None)  # FIX(MZY): load model ckpt (mainly projector, related to model_checkpointing/checkpoint_handler.py: save_model_checkpoint_peft)
    if ckpt_path is not None:
        logger.info("loading other parts from: {}".format(ckpt_path))
        ckpt_dict = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(ckpt_dict, strict=False)

    print_model_size(model, train_config, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)
    return model, tokenizer
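# A minimal usage sketch (an illustration, not part of the original file). `model_factory` only
# needs the two OmegaConf-style config objects plus optional checkpoint paths via kwargs; the
# concrete values below are placeholders, and `encoder_path` is a hypothetical field consumed by
# the encoder's own `load()` rather than by this module.
#
#   from omegaconf import OmegaConf
#
#   train_config = OmegaConf.create({
#       "enable_fsdp": False, "enable_ddp": False, "low_cpu_fsdp": False,
#       "quantization": False, "use_fast_kernels": False,
#       "freeze_encoder": True, "freeze_llm": True, "use_peft": False,
#   })
#   model_config = OmegaConf.create({
#       "llm_name": "vicuna-7b", "llm_path": "/path/to/llm",
#       "encoder_name": "whisper", "encoder_path": "/path/to/whisper",
#       "encoder_projector": "linear",
#   })
#   model, tokenizer = model_factory(train_config, model_config, ckpt_path=None)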
def setup_tokenizer(train_config, model_config, **kwargs):
    # Load the tokenizer and add special tokens
    if "vallex" in model_config.llm_name.lower():
        return None
    elif "mupt" in model_config.llm_name.lower():
        tokenizer = AutoTokenizer.from_pretrained(model_config.llm_path,
                                                  trust_remote_code=True,
                                                  use_fast=False)
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_config.llm_path)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer
def setup_encoder(train_config, model_config, **kwargs):
    encoder_list = model_config.encoder_name.split(",") if model_config.encoder_name else []
    if len(encoder_list) == 0:
        return None
    if len(encoder_list) == 1:
        encoder_name = encoder_list[0]
        if encoder_name == "whisper" or encoder_name == "qwen-audio":
            from slam_llm.models.encoder import WhisperWrappedEncoder
            encoder = WhisperWrappedEncoder.load(model_config)
        if encoder_name == "beats":
            from slam_llm.models.encoder import BEATsEncoder
            encoder = BEATsEncoder.load(model_config)
        if encoder_name == "eat":
            from slam_llm.models.encoder import EATEncoder
            encoder = EATEncoder.load(model_config)
        if encoder_name == "SpatialAST":
            from slam_llm.models.encoder import SpatialASTEncoder
            encoder = SpatialASTEncoder.load(model_config)
        if encoder_name == "wavlm":
            from slam_llm.models.encoder import WavLMEncoder
            encoder = WavLMEncoder.load(model_config)
        if encoder_name == "av_hubert":
            from slam_llm.models.encoder import AVHubertEncoder
            encoder = AVHubertEncoder.load(model_config)
        if encoder_name == "hubert":
            from slam_llm.models.encoder import HubertEncoder
            encoder = HubertEncoder.load(model_config)
        if encoder_name == "musicfm":
            from slam_llm.models.encoder import MusicFMEncoder
            encoder = MusicFMEncoder.load(model_config)
        if "llama" in encoder_name.lower():
            from slam_llm.models.encoder import HfTextEncoder
            encoder = HfTextEncoder.load(model_config)
    print_module_size(encoder, encoder_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)

    if train_config.freeze_encoder:
        for name, param in encoder.named_parameters():
            param.requires_grad = False
        encoder.eval()
    print_module_size(encoder, encoder_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)

    return encoder
def setup_llm(train_config, model_config, **kwargs):
    from pkg_resources import packaging

    use_cache = False if train_config.enable_fsdp or train_config.enable_ddp else None
    if (train_config.enable_fsdp or train_config.enable_ddp) and train_config.low_cpu_fsdp:
        """
        For FSDP, we can save CPU memory by loading the pretrained model on rank 0 only.
        This avoids CPU OOM when loading large models like Llama 70B, in which case the
        model alone would consume 2+ TB of CPU memory (70 * 4 * 8). It adds some comms
        overhead and currently requires a recent PyTorch nightly (a wrapping sketch
        follows after this function).
        """
        # v = packaging.version.parse(torch.__version__)
        # verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
        # if not verify_latest_nightly:
        #     raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
        #                     "please install latest nightly.")
        rank = int(os.environ["RANK"])
        if rank == 0:
            if "vallex" in model_config.llm_name.lower():
                from src.slam_llm.models.vallex.vallex_config import VallexConfig
                from src.slam_llm.models.vallex.vallex_model import VALLE
                vallex_config = VallexConfig(
                    **model_config
                )
                model = VALLE(vallex_config)
            elif "aya" in model_config.llm_name.lower():
                model = AutoModelForSeq2SeqLM.from_pretrained(
                    model_config.llm_path,
                    load_in_8bit=True if train_config.quantization else None,
                    device_map="auto" if train_config.quantization else None,
                    use_cache=use_cache,
                )
            else:
                model = AutoModelForCausalLM.from_pretrained(
                    model_config.llm_path,
                    load_in_8bit=True if train_config.quantization else None,
                    device_map="auto" if train_config.quantization else None,
                    use_cache=use_cache,
                )
        else:
            llama_config = AutoConfig.from_pretrained(model_config.llm_path)
            llama_config.use_cache = use_cache
            # with torch.device("meta"):
            if "aya" in model_config.llm_name.lower():
                model = AutoModelForSeq2SeqLM.from_config(llama_config)
            else:
                model = AutoModelForCausalLM.from_config(llama_config)  # (FIX:MZY): torch 2.0.1 does not support `meta`; Auto* classes must be built via from_config
    else:
        if "vallex" in model_config.llm_name.lower():
            from src.slam_llm.models.vallex.vallex_config import VallexConfig
            from src.slam_llm.models.vallex.vallex_model import VALLE
            vallex_config = VallexConfig(
                **model_config
            )
            model = VALLE(vallex_config)
        elif "aya" in model_config.llm_name.lower():
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_config.llm_path,
                load_in_8bit=True if train_config.quantization else None,
                device_map="auto" if train_config.quantization else None,
                use_cache=use_cache,
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_config.llm_path,
                load_in_8bit=True if train_config.quantization else None,
                device_map="auto" if train_config.quantization else None,
                use_cache=use_cache,
            )
    if (train_config.enable_fsdp or train_config.enable_ddp) and train_config.use_fast_kernels:
        """
        For FSDP and FSDP+PEFT, setting 'use_fast_kernels' enables the use of Flash Attention
        or Xformers memory-efficient kernels based on the hardware being used. This speeds up
        fine-tuning.
        """
        try:
            from optimum.bettertransformer import BetterTransformer
            model = BetterTransformer.transform(model)
        except ImportError:
            logger.warning("Module 'optimum' not found. Please install 'optimum' before proceeding.")

    print_module_size(model, model_config.llm_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)

    # Prepare the model for int8 training if quantization is enabled
    if train_config.quantization:
        model = prepare_model_for_kbit_training(model)

    if train_config.freeze_llm:  # TODO: test the official `freeze_layers` and `num_freeze_layers`
        for name, param in model.named_parameters():
            param.requires_grad = False
        model.eval()

    if kwargs.get("peft_ckpt", None):  # (FIX:MZY): reloading will give wrong results when decoding
        logger.info("loading peft_ckpt from: {}".format(kwargs.get("peft_ckpt")))
        model = PeftModel.from_pretrained(model=model, model_id=kwargs.get("peft_ckpt"), is_trainable=True)
        model.print_trainable_parameters()
    elif train_config.use_peft:
        logger.info("setup peft...")
        peft_config = generate_peft_config(train_config)
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

    print_module_size(model, model_config.llm_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)
    return model
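# Sketch of how the `low_cpu_fsdp` path above is typically completed by the training wrapper
# (an assumption about code outside this module, not the project's actual wrapping call):
# rank 0 holds the pretrained weights, the other ranks hold an uninitialized copy built from the
# config, and FSDP broadcasts rank 0's parameters to every rank at wrap time.
#
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#
#   model = FSDP(
#       model,
#       device_id=torch.cuda.current_device(),
#       sync_module_states=True,   # broadcast rank 0's weights so every rank starts identical
#   )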
def setup_encoder_projector(train_config, model_config, **kwargs):
    if model_config.encoder_projector == "linear":
        from slam_llm.models.projector import EncoderProjectorConcat
        encoder_projector = EncoderProjectorConcat(model_config)
    elif model_config.encoder_projector == "cov1d-linear":
        from slam_llm.models.projector import EncoderProjectorCov1d
        encoder_projector = EncoderProjectorCov1d(model_config)
    elif model_config.encoder_projector == "q-former":
        from slam_llm.models.projector import EncoderProjectorQFormer
        encoder_projector = EncoderProjectorQFormer(model_config)
    else:
        return None
    print_module_size(encoder_projector, model_config.encoder_projector, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)
    return encoder_projector
class slam_model(nn.Module):
    def __init__(
        self,
        encoder: nn.Module,
        llm: nn.Module,
        encoder_projector: nn.Module,
        tokenizer,
        train_config,
        model_config,
        **kwargs
    ):
        super().__init__()
        # modality encoder
        self.encoder = encoder

        # llm
        self.llm = llm

        # projector
        self.encoder_projector = encoder_projector

        # tokenizer
        self.tokenizer = tokenizer
        self.metric = kwargs.get("metric", "acc")

        self.train_config = train_config
        self.model_config = model_config

        if train_config.get("enable_deepspeed", False):
            # Under DeepSpeed mixed precision, run every LayerNorm in fp32 for numerical
            # stability and cast the result back to the input dtype.
            def new_forward(self, input):
                output = F.layer_norm(
                    input.float(),
                    self.normalized_shape,
                    self.weight.float() if self.weight is not None else None,
                    self.bias.float() if self.bias is not None else None,
                    self.eps,
                )
                return output.type_as(input)

            for item in self.modules():
                if isinstance(item, nn.LayerNorm):
                    item.forward = types.MethodType(new_forward, item)
    def forward(self,
                input_ids: torch.LongTensor = None,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_values: Optional[List[torch.FloatTensor]] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                labels: Optional[torch.LongTensor] = None,
                use_cache: Optional[bool] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                **kwargs,
                ):
        audio_mel = kwargs.get("audio_mel", None)
        audio_mel_mask = kwargs.get("audio_mel_mask", None)
        audio_mel_post_mask = kwargs.get("audio_mel_post_mask", None)  # 2x downsample for whisper

        audio = kwargs.get("audio", None)
        audio_mask = kwargs.get("audio_mask", None)
        visual = kwargs.get("visual", None)
        visual_mask = kwargs.get("visual_mask", None)

        # for text encoder
        instruct_ids = kwargs.get("instruct_ids", None)
        instruct_mask = kwargs.get("instruct_mask", None)

        modality_mask = kwargs.get("modality_mask", None)

        zh_data = kwargs.get("zh", None)
        en_data = kwargs.get("en", None)

        encoder_outs = None
        if audio_mel is not None or audio is not None or visual is not None:
            if self.train_config.freeze_encoder:  # freeze encoder
                self.encoder.eval()

            if self.model_config.encoder_name == "whisper":
                encoder_outs = self.encoder.extract_variable_length_features(audio_mel.permute(0, 2, 1))  # bs*seq*dim
            if self.model_config.encoder_name == "beats":
                encoder_outs, audio_mel_post_mask = self.encoder.extract_features(audio_mel, audio_mel_mask)  # bs*seq*dim
            if self.model_config.encoder_name == "eat":
                encoder_outs = self.encoder.model.extract_features(audio_mel.unsqueeze(dim=1), padding_mask=None, mask=False, remove_extra_tokens=False)['x']
            if self.model_config.encoder_name == "SpatialAST":
                encoder_outs = self.encoder(audio)  # output: [bs, seq_len=3+512, dim=768]
            if self.model_config.encoder_name == "wavlm":
                encoder_outs = self.encoder.extract_features(audio, 1 - audio_mask)  # (FIX:MZY): 1 - audio_mask is needed for wavlm as the padding mask
            if self.model_config.encoder_name == "hubert":
                results = self.encoder(source=audio, padding_mask=1 - audio_mask)
                if self.model_config.encoder_type == "pretrain":
                    encoder_outs, audio_mel_post_mask = results["x"], results["padding_mask"]
                if self.model_config.encoder_type == "finetune":
                    encoder_outs, audio_mel_post_mask = results["encoder_out"], results["padding_mask"]
                    encoder_outs = encoder_outs.transpose(0, 1)
            if self.model_config.encoder_name == "av_hubert":
                results = self.encoder(source={'video': visual, 'audio': audio}, padding_mask=visual_mask)  # bs*seq*dim
                encoder_outs, audio_mel_post_mask = results["encoder_out"], results["padding_mask"]
                encoder_outs = encoder_outs.transpose(0, 1)
                audio_mel_post_mask = (~audio_mel_post_mask).float()
            if self.model_config.encoder_name == 'musicfm':
                encoder_outs = self.encoder.extract_features(audio, padding_mask=None)  # MusicFM doesn't support a padding mask
            if self.encoder is None:
                encoder_outs = audio_mel if audio_mel is not None else audio

            if self.model_config.encoder_projector == "q-former":
                encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask)
            if self.model_config.encoder_projector == "linear":
                encoder_outs = self.encoder_projector(encoder_outs)
            if self.model_config.encoder_projector == "cov1d-linear":
                encoder_outs = self.encoder_projector(encoder_outs)

        if instruct_ids is not None:
            if self.encoder is not None:
                encoder_outs = self.encoder(input_ids=instruct_ids, attention_mask=instruct_mask).last_hidden_state

            if self.model_config.encoder_projector == "q-former":
                encoder_outs = self.encoder_projector(encoder_outs, instruct_mask)
            if self.model_config.encoder_projector == "linear":
                encoder_outs = self.encoder_projector(encoder_outs)
        if input_ids is not None:
            input_ids[input_ids == -1] = 0
            if isinstance(self.llm, T5ForConditionalGeneration):
                inputs_embeds = self.llm.shared(input_ids)
            else:
                if hasattr(self.llm.model, "embed_tokens"):
                    inputs_embeds = self.llm.model.embed_tokens(input_ids)
                elif hasattr(self.llm.model.model, "embed_tokens"):
                    inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
                else:
                    inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)

        if modality_mask is not None:
            # Splice the projected encoder features into the placeholder span of the prompt
            # embeddings (a small worked example follows after this method).
            modality_mask_start_indices = (modality_mask == True).float().argmax(dim=1)
            modality_lengths = torch.clamp(modality_mask.sum(dim=1), max=encoder_outs.shape[1]).tolist()

            encoder_outs_pad = torch.zeros_like(inputs_embeds)
            for i in range(encoder_outs.shape[0]):
                encoder_outs_pad[
                    i, modality_mask_start_indices[i]:modality_mask_start_indices[i] + modality_lengths[i]
                ] = encoder_outs[i][:modality_lengths[i]]

            inputs_embeds = encoder_outs_pad + inputs_embeds * (~modality_mask[:, :, None])

        if kwargs.get("inference_mode", False):
            return inputs_embeds, attention_mask

        if zh_data is not None and en_data is not None:
            model_outputs, acc = self.llm(zh=zh_data, en=en_data)
        else:
            model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)

            acc = -1
            if self.metric:
                with torch.no_grad():
                    preds = torch.argmax(model_outputs.logits, -1)
                    acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=-100)

        return model_outputs, acc
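    # A toy, commented-out walkthrough of the modality_mask splice above (illustration only);
    # uncommented, it reproduces the same indexing on tiny tensors.
    #
    #   B, T, D = 1, 6, 4
    #   inputs_embeds = torch.zeros(B, T, D)                  # stand-in for token embeddings
    #   encoder_outs = torch.ones(B, 3, D)                    # stand-in for projected audio features
    #   modality_mask = torch.tensor([[False, False, True, True, True, False]])
    #   start = (modality_mask == True).float().argmax(dim=1)                               # tensor([2])
    #   length = torch.clamp(modality_mask.sum(dim=1), max=encoder_outs.shape[1]).tolist()  # [3]
    #   pad = torch.zeros_like(inputs_embeds)
    #   pad[0, start[0]:start[0] + length[0]] = encoder_outs[0][:length[0]]
    #   fused = pad + inputs_embeds * (~modality_mask[:, :, None])   # positions 2..4 now hold audio features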
    def generate(self,
                 input_ids: torch.LongTensor = None,
                 attention_mask: Optional[torch.Tensor] = None,
                 position_ids: Optional[torch.LongTensor] = None,
                 past_key_values: Optional[List[torch.FloatTensor]] = None,
                 inputs_embeds: Optional[torch.FloatTensor] = None,
                 labels: Optional[torch.LongTensor] = None,
                 use_cache: Optional[bool] = None,
                 output_attentions: Optional[bool] = None,
                 output_hidden_states: Optional[bool] = None,
                 return_dict: Optional[bool] = None,
                 **kwargs,
                 ):
        kwargs["inference_mode"] = True

        inputs_embeds, attention_mask = self.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )

        model_outputs = self.llm.generate(
            inputs_embeds=inputs_embeds,
            # max_length=kwargs.get("max_length", 200),
            max_new_tokens=kwargs.get("max_new_tokens", 200),
            num_beams=kwargs.get("num_beams", 4),
            do_sample=kwargs.get("do_sample", False),
            min_length=kwargs.get("min_length", 1),
            top_p=kwargs.get("top_p", 1.0),
            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
            length_penalty=kwargs.get("length_penalty", 1.0),
            temperature=kwargs.get("temperature", 1.0),
            attention_mask=attention_mask,
            bos_token_id=self.tokenizer.bos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id
        )

        return model_outputs
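# Decoding usage sketch (an assumption about the surrounding inference script, not part of this
# file). The dataloader is expected to provide the modality tensors read in `forward()`
# (`audio_mel`, `audio_mel_mask`, `modality_mask`, ...) together with the prompt `input_ids`,
# whose audio placeholder positions are filled with -1 and flagged in `modality_mask`.
#
#   batch = next(iter(dataloader))   # hypothetical dataloader built elsewhere
#   outputs = model.generate(
#       input_ids=batch["input_ids"],
#       attention_mask=batch["attention_mask"],
#       audio_mel=batch["audio_mel"],
#       audio_mel_mask=batch["audio_mel_mask"],
#       modality_mask=batch["modality_mask"],
#       max_new_tokens=200,
#   )
#   text = tokenizer.batch_decode(outputs, skip_special_tokens=True)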