Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import math | |
| import torch | |
| import logging | |
| import random | |
| import subprocess | |
| import numpy as np | |
| import torch.distributed as dist | |
| from torch import inf | |
| from PIL import Image | |
| from typing import Union, Iterable | |
| from collections import OrderedDict | |
| from torch.utils.tensorboard import SummaryWriter | |
| from diffusers.utils import is_bs4_available, is_ftfy_available | |
| import html | |
| import re | |
| import urllib.parse as ul | |
| if is_bs4_available(): | |
| from bs4 import BeautifulSoup | |
| if is_ftfy_available(): | |
| import ftfy | |
| _tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]] | |
| ################################################################################# | |
| # Training Clip Gradients # | |
| ################################################################################# | |
def get_grad_norm(
        parameters: _tensor_or_tensors, norm_type: float = 2.0) -> torch.Tensor:
    r"""Compute the total gradient norm of ``parameters`` without clipping.

    Adapted from ``torch.nn.utils.clip_grad_norm_`` with the clipping step
    removed: gradients are only read, never modified. The norm is computed
    over all gradients together, as if they were concatenated into a single
    vector.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor whose ``.grad`` fields are measured. Entries whose
            ``.grad`` is ``None`` are skipped.
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.

    Returns:
        Total norm of the parameter gradients (viewed as a single vector), on
        the device of the first gradient. ``torch.tensor(0.)`` is returned
        when no parameter has a gradient.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if not grads:
        return torch.tensor(0.)
    # Per-tensor norms are moved to one device so they can be stacked.
    device = grads[0].device
    if norm_type == inf:
        norms = [g.detach().abs().max().to(device) for g in grads]
        total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
    else:
        # The p-norm of the per-tensor p-norms equals the p-norm of the
        # concatenation of all gradients.
        total_norm = torch.norm(
            torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]),
            norm_type)
    return total_norm
def clip_grad_norm_(
        parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
        error_if_nonfinite: bool = False, clip_grad = True) -> torch.Tensor:
    r"""Compute the total gradient norm and optionally clip gradients to it.

    Adapted from ``torch.nn.utils.clip_grad_norm_`` with an extra
    ``clip_grad`` switch. The norm is computed over all gradients together,
    as if they were concatenated into a single vector. When ``clip_grad`` is
    true the gradients are rescaled in-place; otherwise they are left
    untouched and only the norm is returned.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        error_if_nonfinite (bool): if True, an error is thrown if the total
            norm of the gradients from :attr:`parameters` is ``nan``,
            ``inf``, or ``-inf``. Only checked when ``clip_grad`` is true.
            Default: False
        clip_grad (bool): if False, skip both the non-finite check and the
            in-place rescaling. Default: True

    Returns:
        Total norm of the parameter gradients (viewed as a single vector),
        measured before any clipping. ``torch.tensor(0.)`` is returned when
        no parameter has a gradient.

    Raises:
        RuntimeError: if ``clip_grad`` and ``error_if_nonfinite`` are both
            true and the total norm is non-finite.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if not grads:
        return torch.tensor(0.)
    # Per-tensor norms are moved to one device so they can be stacked.
    device = grads[0].device
    if norm_type == inf:
        norms = [g.detach().abs().max().to(device) for g in grads]
        total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
    else:
        total_norm = torch.norm(
            torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]),
            norm_type)

    if clip_grad:
        if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
            raise RuntimeError(
                f'The total norm of order {norm_type} for gradients from '
                '`parameters` is non-finite, so it cannot be clipped. To disable '
                'this error and scale the gradients by the non-finite norm anyway, '
                'set `error_if_nonfinite=False`')
        clip_coef = max_norm / (total_norm + 1e-6)
        # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
        # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
        # when the gradients do not reside in CPU memory.
        clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
        for g in grads:
            g.detach().mul_(clip_coef_clamped.to(g.device))
    return total_norm
def get_experiment_dir(root_dir, args):
    """Append one suffix to ``root_dir`` for every enabled training option.

    Suffixes are added in a fixed order: ``-Compile``, ``-FixedSpa``,
    ``-Xfor``, ``-Gc``, ``-Amp`` and finally ``-512`` when
    ``args.image_size`` equals 512. The resulting name is returned.
    """
    option_tags = (
        (args.use_compile, '-Compile'),  # speedup by torch compile
        (args.fixed_spatial, '-FixedSpa'),
        (args.enable_xformers_memory_efficient_attention, '-Xfor'),
        (args.gradient_checkpointing, '-Gc'),
        (args.mixed_precision, '-Amp'),
        (args.image_size == 512, '-512'),
    )
    for enabled, tag in option_tags:
        if enabled:
            root_dir += tag
    return root_dir
| ################################################################################# | |
| # Training Logger # | |
| ################################################################################# | |
def create_logger(logging_dir):
    """
    Return this module's logger, configured according to the process rank.

    On rank 0 the root logger is set up to emit INFO-level records to both
    the console and ``{logging_dir}/log.txt``. On every other rank the
    returned logger only gets a ``NullHandler`` attached.
    """
    is_main_process = dist.get_rank() == 0

    if not is_main_process:  # dummy logger (does nothing)
        silent_logger = logging.getLogger(__name__)
        silent_logger.addHandler(logging.NullHandler())
        return silent_logger

    # real logger
    console_handler = logging.StreamHandler()
    file_handler = logging.FileHandler(f"{logging_dir}/log.txt")
    logging.basicConfig(
        level=logging.INFO,
        format='[%(asctime)s] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=[console_handler, file_handler],
    )
    return logging.getLogger(__name__)
def create_tensorboard(tensorboard_dir):
    """
    Create a tensorboard writer that saves losses.

    Only rank 0 gets a real ``SummaryWriter`` writing to ``tensorboard_dir``.
    Every other rank receives ``None`` (previously those ranks crashed with
    ``UnboundLocalError`` because ``writer`` was never assigned).
    """
    if dist.get_rank() != 0:
        return None
    # real tensorboard
    return SummaryWriter(tensorboard_dir)
def write_tensorboard(writer, *args):
    '''
    Log one scalar to tensorboard from the rank-0 process only.

    The first three positional arguments are forwarded, in order, to
    ``writer.add_scalar``. Other ranks return without touching ``writer``.
    Only for pytorch DDP mode.
    '''
    if dist.get_rank() != 0:
        return
    # real tensorboard
    writer.add_scalar(args[0], args[1], args[2])
| ################################################################################# | |
| # EMA Update/ DDP Training Utils # | |
| ################################################################################# | |
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.

    Each EMA parameter is updated in place as
    ``ema = decay * ema + (1 - decay) * param`` for the parameter of the same
    name in ``model``. Autograd is disabled for the update so the in-place
    ops are legal even when the EMA parameters still have
    ``requires_grad=True``.

    Args:
        ema_model: module holding the averaged weights (modified in place).
        model: module providing the current weights; every parameter name it
            exposes must also exist in ``ema_model``.
        decay: weight kept from the previous EMA value at each step.
    """
    ema_params = OrderedDict(ema_model.named_parameters())
    model_params = OrderedDict(model.named_parameters())
    for name, param in model_params.items():
        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
def requires_grad(model, flag=True):
    """
    Assign ``flag`` to the ``requires_grad`` attribute of every parameter
    yielded by ``model.parameters()``.
    """
    for parameter in model.parameters():
        parameter.requires_grad = flag
def cleanup():
    """
    Tear down the default ``torch.distributed`` process group at the end of
    DDP training.
    """
    dist.destroy_process_group()
def setup_distributed(backend="nccl", port=None):
    """Initialize distributed training environment.

    Supports both SLURM and torch.distributed.launch / torchrun;
    see torch.distributed.init_process_group() for more details.

    Under SLURM (``SLURM_JOB_ID`` present) the rank and world size are read
    from ``SLURM_PROCID`` / ``SLURM_NTASKS`` and the variables ``MASTER_PORT``,
    ``MASTER_ADDR``, ``WORLD_SIZE``, ``LOCAL_RANK`` and ``RANK`` are exported.
    Otherwise ``RANK`` and ``WORLD_SIZE`` must already be set by the launcher.

    Args:
        backend: backend name passed to ``init_process_group``.
        port: master port to export under SLURM; when ``None`` an existing
            ``MASTER_PORT`` is kept, else ``29567 + <visible GPU count>``.

    Raises:
        KeyError: if a required environment variable is missing.
        RuntimeError: if the master host cannot be derived from
            ``SLURM_NODELIST``.
    """
    num_gpus = torch.cuda.device_count()

    if "SLURM_JOB_ID" in os.environ:
        rank = int(os.environ["SLURM_PROCID"])
        world_size = int(os.environ["SLURM_NTASKS"])
        node_list = os.environ["SLURM_NODELIST"]

        # specify master port
        if port is not None:
            os.environ["MASTER_PORT"] = str(port)
        elif "MASTER_PORT" not in os.environ:
            os.environ["MASTER_PORT"] = str(29567 + num_gpus)

        if "MASTER_ADDR" not in os.environ:
            # Expand the compact node list and use its first host as master.
            # Argument-list form avoids passing the env value through a shell.
            hostnames = subprocess.check_output(
                ["scontrol", "show", "hostname", node_list], text=True
            ).split()
            if not hostnames:
                raise RuntimeError(
                    f"scontrol returned no hostnames for SLURM_NODELIST={node_list!r}")
            os.environ["MASTER_ADDR"] = hostnames[0]

        os.environ["WORLD_SIZE"] = str(world_size)
        # max(..., 1) keeps CPU-only nodes (no visible GPU) from dividing by zero.
        os.environ["LOCAL_RANK"] = str(rank % max(num_gpus, 1))
        os.environ["RANK"] = str(rank)
    else:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    # NOTE: the CUDA device is intentionally not selected here; callers are
    # expected to call torch.cuda.set_device themselves.
    dist.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
    )
| ################################################################################# | |
| # Testing Utils # | |
| ################################################################################# | |
def save_video_grid(video, nrow=None):
    """Tile a batch of videos into one grid video.

    Args:
        video: tensor whose shape unpacks as ``(b, t, h, w, c)``; clip ``i``
            is ``video[i]``. Values are stored into a ``uint8`` canvas.
        nrow: number of grid rows; defaults to ``ceil(sqrt(b))``. The number
            of columns is ``ceil(b / nrow)``.

    Returns:
        A ``torch.uint8`` tensor of shape
        ``(t, (h + 1) * nrow + 1, (w + 1) * ncol + 1, c)`` in which the clips
        are laid out row-major, separated and surrounded by a 1-pixel zero
        border. Unused cells stay zero.
    """
    b, t, h, w, c = video.shape

    if nrow is None:
        nrow = math.ceil(math.sqrt(b))
    ncol = math.ceil(b / nrow)
    padding = 1
    video_grid = torch.zeros((t, (padding + h) * nrow + padding,
                              (padding + w) * ncol + padding, c), dtype=torch.uint8)

    for i in range(b):
        row, col = divmod(i, ncol)
        # Offset by one leading `padding` so the border is symmetric; the
        # canvas above is already sized for it.
        top = padding + (padding + h) * row
        left = padding + (padding + w) * col
        video_grid[:, top:top + h, left:left + w] = video[i]

    return video_grid
def find_model(model_name):
    """
    Load a Latte checkpoint from a local path and return its weights.

    No download is attempted: ``model_name`` must point to an existing file.
    The ``"ema"`` entry is preferred when present (checkpoints written by
    train.py); otherwise the ``"model"`` entry is returned.

    Args:
        model_name: path to the checkpoint file.

    Returns:
        The object stored under ``"ema"`` or ``"model"`` in the checkpoint.

    Raises:
        FileNotFoundError: if ``model_name`` is not an existing file.
        KeyError: if the checkpoint has neither an ``"ema"`` nor a
            ``"model"`` entry.
    """
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if not os.path.isfile(model_name):
        raise FileNotFoundError(f'Could not find Latte checkpoint at {model_name}')

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)

    if "ema" in checkpoint:  # supports checkpoints from train.py
        print('Using Ema!')
        return checkpoint["ema"]
    if "model" in checkpoint:
        print('Using model!')
        return checkpoint["model"]
    raise KeyError(f"Checkpoint {model_name} contains neither 'ema' nor 'model'")
| ################################################################################# | |
| # MMCV Utils # | |
| ################################################################################# | |
def collect_env():
    """Collect and print information about the running environment.

    Gathers mmcv's base environment report, adds the first seven characters
    of the current git hash under the ``'MMClassification'`` key, prints
    every entry as ``name: value``, then prints the CUDA architecture list
    and the CUDA version PyTorch reports. Nothing is returned.
    """
    # Copyright (c) OpenMMLab. All rights reserved.
    # Imported lazily so mmcv is only required when this helper is called.
    from mmcv.utils import collect_env as collect_base_env
    from mmcv.utils import get_git_hash

    env_info = collect_base_env()
    env_info['MMClassification'] = get_git_hash()[:7]

    for name, val in env_info.items():
        print(f'{name}: {val}')

    print(torch.cuda.get_arch_list())
    print(torch.version.cuda)
| ################################################################################# | |
| # Pixart-alpha Utils # | |
| ################################################################################# | |
# One or more consecutive "noise" punctuation characters: assorted symbols,
# brackets, braces, pipe, backslash, slash and asterisk. Written as a single
# raw string so the regex escapes are not parsed as (invalid) Python string
# escapes; the compiled pattern is unchanged.
bad_punct_regex = re.compile(
    r"[#®•©™&@·º½¾¿¡§~\)\(\]\[\}\{\|\\/\*]{1,}"
)
| def text_preprocessing(text, clean_caption=False): | |
| if clean_caption and not is_bs4_available(): | |
| clean_caption = False | |
| if clean_caption and not is_ftfy_available(): | |
| clean_caption = False | |
| if not isinstance(text, (tuple, list)): | |
| text = [text] | |
| def process(text: str): | |
| if clean_caption: | |
| text = clean_caption(text) | |
| text = clean_caption(text) | |
| else: | |
| text = text.lower().strip() | |
| return text | |
| return [process(t) for t in text] | |
| # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption | |
| def clean_caption(caption): | |
| caption = str(caption) | |
| caption = ul.unquote_plus(caption) | |
| caption = caption.strip().lower() | |
| caption = re.sub("<person>", "person", caption) | |
| # urls: | |
| caption = re.sub( | |
| r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa | |
| "", | |
| caption, | |
| ) # regex for urls | |
| caption = re.sub( | |
| r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa | |
| "", | |
| caption, | |
| ) # regex for urls | |
| # html: | |
| caption = BeautifulSoup(caption, features="html.parser").text | |
| # @<nickname> | |
| caption = re.sub(r"@[\w\d]+\b", "", caption) | |
| # 31C0—31EF CJK Strokes | |
| # 31F0—31FF Katakana Phonetic Extensions | |
| # 3200—32FF Enclosed CJK Letters and Months | |
| # 3300—33FF CJK Compatibility | |
| # 3400—4DBF CJK Unified Ideographs Extension A | |
| # 4DC0—4DFF Yijing Hexagram Symbols | |
| # 4E00—9FFF CJK Unified Ideographs | |
| caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) | |
| caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) | |
| caption = re.sub(r"[\u3200-\u32ff]+", "", caption) | |
| caption = re.sub(r"[\u3300-\u33ff]+", "", caption) | |
| caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) | |
| caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) | |
| caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) | |
| ####################################################### | |
| # все виды тире / all types of dash --> "-" | |
| caption = re.sub( | |
| r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa | |
| "-", | |
| caption, | |
| ) | |
| # кавычки к одному стандарту | |
| caption = re.sub(r"[`´«»“”¨]", '"', caption) | |
| caption = re.sub(r"[‘’]", "'", caption) | |
| # " | |
| caption = re.sub(r""?", "", caption) | |
| # & | |
| caption = re.sub(r"&", "", caption) | |
| # ip adresses: | |
| caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) | |
| # article ids: | |
| caption = re.sub(r"\d:\d\d\s+$", "", caption) | |
| # \n | |
| caption = re.sub(r"\\n", " ", caption) | |
| # "#123" | |
| caption = re.sub(r"#\d{1,3}\b", "", caption) | |
| # "#12345.." | |
| caption = re.sub(r"#\d{5,}\b", "", caption) | |
| # "123456.." | |
| caption = re.sub(r"\b\d{6,}\b", "", caption) | |
| # filenames: | |
| caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) | |
| # | |
| caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" | |
| caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" | |
| caption = re.sub(bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT | |
| caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " | |
| # this-is-my-cute-cat / this_is_my_cute_cat | |
| regex2 = re.compile(r"(?:\-|\_)") | |
| if len(re.findall(regex2, caption)) > 3: | |
| caption = re.sub(regex2, " ", caption) | |
| caption = ftfy.fix_text(caption) | |
| caption = html.unescape(html.unescape(caption)) | |
| caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 | |
| caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc | |
| caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 | |
| caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) | |
| caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) | |
| caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) | |
| caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) | |
| caption = re.sub(r"\bpage\s+\d+\b", "", caption) | |
| caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... | |
| caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) | |
| caption = re.sub(r"\b\s+\:\s+", r": ", caption) | |
| caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) | |
| caption = re.sub(r"\s+", " ", caption) | |
| caption.strip() | |
| caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) | |
| caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) | |
| caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) | |
| caption = re.sub(r"^\.\S+$", "", caption) | |
| return caption.strip() | |