# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import gc
import logging
import math
import os
import sys
from contextlib import ExitStack
from copy import deepcopy
from dataclasses import asdict, dataclass
from timeit import default_timer as timer
from typing import Any, TypeVar

import numpy as np
import pyarrow
import torch
import torch.distributed
import torch.nn.functional
import torch.nn.functional as F
import wandb
import xformers.profiler
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.stateful import Stateful
from torch.optim import lr_scheduler

from bytelatent.args import TrainArgs
from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
from bytelatent.config_parser import parse_args_to_pydantic_model
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh
from bytelatent.data.iterators.multiprocess_iterator import (
    MultiprocessIterator,
    MultiprocessIteratorState,
    PersistType,
)
from bytelatent.data.iterators.packing_iterator import PackingIteratorState
from bytelatent.distributed import (
    check_model_value_range,
    clean_env,
    dist_mean,
    dist_sum,
    get_device_mesh,
    get_is_master,
    get_world_size,
    init_signal_handler,
    parallelize_model,
    requeue_slurm_job,
    setup_env,
    setup_torch_distributed,
    to_py_num,
)
from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval
from bytelatent.logger import init_logger
from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.norms import fixed_clip_grad_norm_
from bytelatent.optim import build_optimizer
from bytelatent.probe import AutoProbeD
from bytelatent.profiling import maybe_run_profiler
from bytelatent.stool import StoolArgs, launch_job
from bytelatent.transformer import (
    LMTransformer,
    build_fsdp_grouping_plan,
    get_no_recompute_ops,
    get_num_flop_per_token,
    tp_parallelize,
)

logger = logging.getLogger()

T = TypeVar("T")


def flatten_dict(d, parent_key="", sep="_"):
    items = []
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)
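# Example: flatten_dict({"loss": {"step_per_gpu": 0.5}}, sep="/") returns
# {"loss/step_per_gpu": 0.5}; this is how the nested metric_dict built in the
# train loop is flattened into "a/b" keys before being logged.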


def get_iterator_state_name(iterator_state):
    if isinstance(iterator_state, MultiprocessIteratorState):
        return "multiprocess"
    elif isinstance(iterator_state, PackingIteratorState):
        return "packing"
    else:
        raise ValueError(f"Unsupported iterator to get name from: {iterator_state}")


# TODO: Make this pydantic based instead of data class based
# TODO: Generalize this to any iterator state
@dataclass
class TrainState(Stateful):
    step: int  # Nb of steps taken by the optimizer
    acc_step: int  # Nb of accumulation steps done since last optimizer step
    scheduler: lr_scheduler.LambdaLR
    data_loader_state: MultiprocessIteratorState | PackingIteratorState
    scale: float = 1.0
    data_loader_class: str | None = None

    def state_dict(self) -> dict[str, Any]:
        return {
            "step": self.step,
            "acc_step": self.acc_step,
            "data_loader_state": self.data_loader_state.model_dump(),
            "data_loader_class": get_iterator_state_name(self.data_loader_state),
            "scheduler": self.scheduler.state_dict(),
        }

    def load_state_dict(self, state_dict):
        self.step = state_dict["step"]
        self.acc_step = state_dict["acc_step"]
        self.data_loader_class = state_dict["data_loader_class"]
        if self.data_loader_class == "multiprocess":
            self.data_loader_state = MultiprocessIteratorState(
                **state_dict["data_loader_state"]
            )
        elif self.data_loader_class == "packing":
            self.data_loader_state = PackingIteratorState(
                **state_dict["data_loader_state"]
            )
        else:
            raise ValueError(f"invalid data loader class: {self.data_loader_class}")
        self.scheduler.load_state_dict(state_dict["scheduler"])
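# Note: model and optimizer weights are checkpointed separately by CheckpointManager
# (see checkpoint.save in the train loop); TrainState only carries the step counters,
# LR scheduler state, and data loader state needed to resume iteration exactly.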


def validate_train_args(args: TrainArgs, output_size: int):
    assert args.model is not None or args.entropy_model is not None
    if args.model is not None:
        logger.info(f"Setting model output size to {output_size}")
        args.model.vocab_size = output_size
        assert (
            args.model.max_encoder_seq_length == args.data.max_encoder_seq_length
        ), "max_encoder_seq_length for model and data should match"

    if args.entropy_model is not None:
        logger.info(f"Setting entropy model output size to {output_size}")
        args.entropy_model.vocab_size = output_size

    assert args.dump_dir, "Dump dir not set"

    if args.checkpoint.path is None:
        args.checkpoint.path = os.path.join(args.dump_dir, "checkpoints")
        logger.info(f"Setting checkpoint path to {args.checkpoint.path}")

    if args.data.root_dir is not None:
        data_fs = get_fs(args.data.root_dir, s3_profile=args.data.s3_profile)
        for source in args.data.sources:
            data_path = os.path.join(args.data.root_dir, source)
            assert data_fs.exists(data_path), f"{data_path} doesn't exist"

    args.distributed.configure_world()

    if args.model is not None:
        args.model.max_seqlen = args.data.seq_len
    if args.entropy_model is not None:
        args.entropy_model.max_seqlen = args.data.seq_len

    if args.distributed.tp_size == 1:
        logger.warning(
            "Tensor parallelism has not been tested for a while, use at your own risk"
        )

    assert (
        args.probe_freq != args.profiling.mem_steps
    ), "Don't profile during probe step"
    assert (
        args.probe_freq != args.profiling.profile_steps
    ), "Don't profile during probe step"

    if args.logging.wandb is not None:
        args.logging.wandb.name = args.name

    if args.probe_freq is not None:
        assert (
            args.distributed.tp_size == 1
        ), "Probing not supported with tensor parallelism"
        assert (
            args.distributed.selective_activation_checkpointing is False
        ), "Probing not supported with selective activation checkpointing"


preemption_flag = dict(flag=False)


def set_preemption_flag(signum, frame):
    logger.warning("Signal handler called with signal " + str(signum))
    logger.warning("Preemption! Checkpointing ASAP and exiting.")
    preemption_flag["flag"] = True


def every_n_steps(train_state, freq: int, acc_step=None, acc_freq=None):
    if freq < 0:
        return False
    test = train_state.step % freq == 0
    if acc_step is not None:
        test = test and (train_state.acc_step == acc_step)
    elif acc_freq is not None:
        test = test and ((train_state.acc_step % acc_freq) == 0)
    return test
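# Example: every_n_steps(train_state, freq=10, acc_step=0) is True every 10 optimizer
# steps, and only on micro-batches where acc_step has wrapped back to 0; this is how
# periodic work (GC, logging, checkpointing, evals) is gated in the train loop below.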


def compute_loss(p, y, mask, scale):
    tok_loss = scale * F.cross_entropy(
        p.flatten(0, 1), y.flatten(0, 1), reduction="none"
    )
    if mask is None:
        loss = tok_loss.mean()
    else:
        mask = mask.flatten(0, 1)
        tok_loss = tok_loss * mask
        loss = tok_loss.sum() / (mask.sum() + 1e-6)
    return loss, tok_loss
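# compute_loss returns both the (masked) mean loss used for backprop and the unreduced
# per-token losses of shape [bsz * seqlen]; the latter are accumulated into
# step_tok_losses below and summed at logging time to compute bits-per-byte.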


def train(args: TrainArgs):
    with ExitStack() as context_stack:
        pyarrow.set_io_thread_count(4)
        pyarrow.set_cpu_count(4)
        tokenizer = args.data.tokenizer_args.build()
        validate_train_args(
            args,
            tokenizer.get_vocab_size(),
        )
        dump_fs = get_fs(args.dump_dir, s3_profile=args.checkpoint.s3_profile)
        if get_is_master():
            dump_fs.mkdirs(args.dump_dir, exist_ok=True)
            config_yaml_str = args.dump_to_yaml_str()
            logging.info("TrainArgs: \n%s", config_yaml_str)
            dump_fs.write_text(
                os.path.join(args.dump_dir, "config.yaml"), config_yaml_str
            )
        init_logger(os.path.join(args.dump_dir, "train.log"), fs=dump_fs)
        init_signal_handler(set_preemption_flag)  # For handling preemption signals.
        setup_env(args.env)
        setup_torch_distributed(args.distributed)
        world_mesh = get_device_mesh(args.distributed)
        logger.info(f"Starting job: {args.name}")

        # build dataloader
        # need dp world size and rank
        dp_mesh = world_mesh["dp_replicate"]
        dp_degree = dp_mesh.size()
        dp_rank = dp_mesh.get_local_rank()
        if args.distributed.dp_shard > 1:
            dp_rank = dp_rank * dp_degree + world_mesh["dp_shard"].get_local_rank()
            dp_degree *= world_mesh["dp_shard"].size()

        logger.info(f"Running on dp rank : {dp_rank}")
        logger.info(f"Running on dp size : {dp_degree}")

        torch.manual_seed(args.seed)
        logger.info("Building model")

        # Initializing Model in meta device allows us to initialize models much bigger than 1 gpu's memory
        with torch.device("meta"):
            if args.train_entropy_model:
                assert args.entropy_model is not None
                model = LMTransformer(args.entropy_model)
                model_args = args.entropy_model
            else:
                assert args.model is not None
                model = ByteLatentTransformer(args.model)
                model_args = args.model
        logger.info("Model is built !")

        model_param_count = get_num_params(model)

        model = parallelize_model(
            model,
            world_mesh,
            model_args,
            args.distributed,
            fsdp_grouping_plan=build_fsdp_grouping_plan(model_args),
            tp_parallelize=tp_parallelize,
            no_recompute_ops=get_no_recompute_ops(),
        )
        # Once we shard the model on different gpus we can actually initialize the model
        # First we create empty tensors of the correct shapes
        model = model.to_empty(device="cuda")
        # Then we init the model. Please make sure this function initializes *ALL* parameters
        # and buffers, otherwise you will have random values in the uninitialized tensors
        # which will silently fail (give nan gradients for example)
        if args.checkpoint.init_ckpt_path:
            logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}")
            ckpt_fs = get_fs(
                args.checkpoint.init_ckpt_path, s3_profile=args.checkpoint.s3_profile
            )
            load_from_checkpoint(
                ckpt_fs, args.checkpoint.init_ckpt_path, model, model_key="model"
            )  # Put model_key="" if it's directly the model checkpoint
            model.rope_embeddings.reset_parameters()  # For RoPE initialization: since it's a buffer it might not be loaded
        else:
            with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
                torch.manual_seed(model_args.seed)
                model.init_weights()
        check_model_value_range(model, range=10.0, std=1.0)

        # log model size
        logger.info(model)
        logger.info(f"Model size: {model_param_count:,} total parameters")
        gpu_memory_monitor = GPUMemoryMonitor("cuda")
        logger.info(
            f"GPU capacity: {gpu_memory_monitor.device_name} ({gpu_memory_monitor.device_index}) "
            f"with {gpu_memory_monitor.device_capacity_gib:.2f}GiB memory"
        )
        logger.info(f"GPU memory usage: {gpu_memory_monitor}")

        # build optimizer after applying parallelisms to the model
        optimizer, scheduler = build_optimizer(model, args.optim, args.steps)
        data_loader = args.data.build_from_rank(dp_rank, dp_degree)
        data_loader_state = data_loader.get_state()

        train_state = TrainState(
            step=0,
            acc_step=0,
            data_loader_state=data_loader_state,
            scheduler=scheduler,
            scale=1.0,
        )

        checkpoint = CheckpointManager.instantiate_and_make_dir(args.checkpoint)
        checkpoint.load(model, optimizer, train_state, world_mesh)
        # Either load from latest checkpoint or start from scratch
        if args.probe_freq is not None:
            # TODO: Convert this to fsspec compatible
            if get_is_master():
                os.makedirs(os.path.join(args.dump_dir, "probe"), exist_ok=True)
            torch.distributed.barrier()
            probe = AutoProbeD(
                model,
                (
                    os.path.join(args.dump_dir, "probe", f"probe.{dp_rank}.jsonl")
                    if (dp_rank % 128 == 0)
                    else None
                ),
            )
            probe_mod = model._orig_mod if args.distributed.compile else model

        gc.disable()

        # train loop
        model.train()
        metric_logger = context_stack.enter_context(
            MetricLogger(os.path.join(args.dump_dir, "metrics.jsonl"), args, fs=dump_fs)
        )
        data_loader = train_state.data_loader_state.build()
        batch_iterator = data_loader.create_iter()
        torch_profiler = context_stack.enter_context(
            maybe_run_profiler(args.dump_dir, model, args.profiling)
        )

        nwords_since_last_log = 0
        time_last_log = timer()
        gc.collect()
        saved = False
        step_losses: list[float] = []
        step_tok_losses: list[float] = []
        n_bytes: int = 0
        while train_state.step < args.steps and (
            args.max_steps is None or train_state.step < args.max_steps
        ):
            # We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1
            train_state.acc_step += 1
            train_state.acc_step = train_state.acc_step % args.grad_acc_steps
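            # e.g. with grad_acc_steps=4, acc_step cycles 1, 2, 3, 0, 1, ...; the
            # optimizer only steps (and train_state.step only advances) on the
            # micro-batch where acc_step wraps back to 0.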
            # get batch
            curr_lr = float(optimizer.param_groups[0]["lr"])
            data_load_start = timer()
            batch = next(batch_iterator)
            batch_x = torch.from_numpy(
                batch.x,
            ).cuda()
            batch_y = torch.from_numpy(batch.y).cuda()
            if batch.patch_lengths is None:
                batch_patch_lengths = None
            else:
                batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda()
            mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()

            if args.data.tokenizer_args.name in ["bytes", "blt"]:
                n_bytes += batch_y.numel() if mask is None else mask.sum()
            elif args.data.tokenizer_args.name in ["sp", "tiktoken"]:
                for example in batch.y:
                    target_tokens = tokenizer.decode(example.tolist(), cut_at_eos=False)
                    n_bytes += (
                        len(bytes(target_tokens, encoding="utf-8", errors="ignore"))
                        + sum(example == tokenizer.eos_id)
                        + sum(example == tokenizer.bos_id)
                    )
            else:
                raise ValueError(
                    f"Unexpected tokenizer to count n_bytes for: {args.data.tokenizer_args.name}"
                )

            if (
                not args.train_entropy_model
                and args.model.encoder_enable_byte_ngrams
                and batch.ngram_ids is None
            ):
                raise ValueError(
                    "Cannot enable byte ngrams and have batch.ngram_ids be None"
                )
            ngram_ids = (
                None
                if batch.ngram_ids is None
                else torch.from_numpy(batch.ngram_ids).cuda()
            )

            if every_n_steps(train_state, args.gc_collect_freq, acc_step=0):
                logger.info("garbage collection")
                # we do garbage collection manually otherwise different processes
                # run the GC at different times so they slow down the whole pipeline
                gc.collect()

            data_load_time = round(timer() - data_load_start, 4)
            nwords_since_last_log += batch_x.numel()
            bsz, seqlen = batch_y.shape

            # forward
            start_timer = torch.cuda.Event(enable_timing=True)
            end_timer = torch.cuda.Event(enable_timing=True)
            start_timer.record()

            # This is an automatic probe that will compute statistics
            # of all linears' inputs, weights and outputs
            # along with attention logits and entropy
            # both in forward and backward pass
            tok_loss = None
            if (args.probe_freq is not None) and every_n_steps(
                train_state, args.probe_freq, acc_step=1 % args.grad_acc_steps
            ):
                # Here we do a fake forward and backward pass on a smaller
                # batch size to avoid OOM
                # This assumes the model has no stateful layers (batch norm..)
                assert (
                    next(probe_mod.parameters()).grad is None
                ), "Can't probe model if grads are not reset"

                with probe:
                    probe.metadata = {
                        "it": train_state.step,
                        "global_step": train_state.step,
                        "loop": "lingua",
                    }
                    # Non compiled model uses roughly 2x memory in our exps
                    # So we divide bsz by 2 or seqlen by 2
                    probe_bsz = max(1, bsz // 2)
                    probe_seq = seqlen if (bsz // 2 >= 1) else (seqlen // 2)
                    probe_loss = probe_mod(
                        batch_x[:probe_bsz, :probe_seq],
                        batch_y[:probe_bsz, :probe_seq],
                    )
                    probe_loss.backward()
                    # We zero grads to cancel this fake step
                    optimizer.zero_grad()

                assert (
                    next(probe_mod.parameters()).grad is None
                ), "Probe model shouldn't have grads at this point"

            if args.train_entropy_model:
                pred = model(batch_x)
            else:
                pred = model(
                    batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids
                )
            loss, tok_loss = compute_loss(pred, batch_y, mask, train_state.scale)

            # We scale loss with grad_acc_steps so the gradient is the same
            # regardless of grad_acc_steps
            loss = loss / args.grad_acc_steps

            # backward on scaled loss to create scaled gradients
            loss.backward()
            # For logging we undo that scaling
            loss = loss.detach() * args.grad_acc_steps

            # Undo loss scaling so downstream code doesn't need to worry about it
            step_losses.append((loss / train_state.scale).item())
            step_tok_losses.append(tok_loss / train_state.scale)
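            # Dividing each micro-batch loss by grad_acc_steps means that, once the
            # gradients from grad_acc_steps micro-batches have accumulated in .grad,
            # the result equals the gradient of their mean loss, so the effective
            # update is independent of the accumulation factor.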
            world_size = get_world_size()
            if 1 < world_size <= 8:
                # For some reason, the grad-norm reduction errors out for non-bf16
                # numbers. This function is a patched version that converts gradients
                # to bf16 before computing norms. The error only happens in
                # distributed training on one node, hence the guard.
                grad_norm = fixed_clip_grad_norm_(
                    model.parameters(), max_norm=args.optim.clip, foreach=True
                )
            else:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=args.optim.clip, foreach=True
                )

            grad_norm = (
                grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm
            ).item()

            # optimizer step
            if train_state.acc_step == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                train_state.step += 1

            # updates the scale for next iteration
            # training iteration complete
            end_timer.record()

            torch.cuda.synchronize()

            curr_iter_time = round(start_timer.elapsed_time(end_timer) * 1e-3, 4)

            # if profiler is active
            if torch_profiler:
                xformers.profiler.step()
            # log metrics
            if every_n_steps(
                train_state,
                args.logging.freq,
                acc_step=None if args.logging.acc_freq else 0,
                acc_freq=args.logging.acc_freq,
            ):
                time_delta = timer() - time_last_log
                wps = nwords_since_last_log / (time_delta * args.distributed.tp_size)

                gpu_mem_stats = gpu_memory_monitor.get_peak_stats()

                total_acc_steps = (
                    args.grad_acc_steps * train_state.step + train_state.acc_step
                )
                tokens_per_gpu = (
                    total_acc_steps * args.data.batch_size * args.data.seq_len
                )
                total_tokens = dp_degree * tokens_per_gpu
                # This is an estimate and the correct values may change
                # if you change the architecture
                # Use xformer's analyze profile trace to get actual measurement
                FLOPS = (
                    get_num_flop_per_token(
                        model_param_count - model_args.vocab_size * model_args.dim,
                        model_args.n_layers,
                        model_args.dim,
                        args.data.seq_len,
                    )
                    * wps
                )
                # Below, semantics are:
                # per_gpu: Metrics on a given rank
                # across_gpus: Metrics averaged/summed across all ranks
                # step: Metric at a step
                # interval: Metric averaged/summed across all steps since the last
                #           log interval. Typically, this is 10
                step_loss_per_gpu = loss
                step_loss_across_gpus = dist_mean(step_loss_per_gpu)
                interval_loss_per_gpu = np.mean(step_losses)
                interval_loss_across_gpus = dist_mean(interval_loss_per_gpu)

                stacked_tok_loss = torch.cat(step_tok_losses, dim=0)
                interval_total_tok_loss_per_gpu = stacked_tok_loss.sum()
                interval_total_tok_loss_across_gpus = dist_sum(
                    interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16
                )
                interval_total_n_bytes_per_gpu = n_bytes
                interval_total_n_bytes_across_gpus = dist_sum(
                    n_bytes, reduce_dtype=torch.bfloat16
                )

                interval_bpb_per_gpu = (
                    interval_total_tok_loss_per_gpu
                    / math.log(2)
                    / interval_total_n_bytes_per_gpu
                )
                interval_bpb_across_gpus = (
                    interval_total_tok_loss_across_gpus
                    / math.log(2)
                    / interval_total_n_bytes_across_gpus
                )
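                # Bits-per-byte: bpb = sum(per-token NLL in nats) / (ln(2) * n_bytes).
                # Dividing the summed cross-entropy by ln(2) converts nats to bits, and
                # dividing by the byte count gives a tokenizer-independent loss metric.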
                metric_dict = {
                    "global_step": train_state.step,
                    "acc_step": train_state.acc_step,
                    "speed": {
                        "wps": wps,
                        "FLOPS": FLOPS,
                        "curr_iter_time": curr_iter_time,
                        "data_load_time": data_load_time,
                    },
                    "optim": {
                        "grad_norm": grad_norm,
                        "lr": curr_lr,
                        "total_tokens": total_tokens,
                    },
                    "memory": gpu_mem_stats._asdict(),
                    "loss": {
                        "step_per_gpu": to_py_num(step_loss_per_gpu),
                        "step_across_gpu": to_py_num(step_loss_across_gpus),
                        "interval_per_gpu": to_py_num(interval_loss_per_gpu),
                        "interval_across_gpu": to_py_num(interval_loss_across_gpus),
                    },
                    "bpb": {
                        "interval_per_gpu": to_py_num(interval_bpb_per_gpu),
                        "interval_across_gpus": to_py_num(interval_bpb_across_gpus),
                    },
                    "n_bytes": {
                        "interval_per_gpu": to_py_num(interval_total_n_bytes_per_gpu),
                        "interval_across_gpus": to_py_num(
                            interval_total_n_bytes_across_gpus
                        ),
                    },
                }

                metrics = flatten_dict(
                    metric_dict,
                    sep="/",
                )

                if get_is_master():
                    metric_logger.log(metrics)

                # Below semantics are:
                # step=Metrics at a step
                # interval=Metrics averaged across the logging interval
                # local=On one rank
                # global=Across all ranks
                logger.info(
                    f"step: {train_state.step}"
                    f" acc: {train_state.acc_step}"
                    f" loss_gpu: {round(to_py_num(interval_loss_per_gpu), 4):>7}"
                    f" loss_avg: {round(to_py_num(interval_loss_across_gpus), 4):>7}"
                    f" bpb_gpu: {interval_bpb_per_gpu:3f}"
                    f" bpb_avg: {interval_bpb_across_gpus:3f}"
                    f" grad: {grad_norm:.2e}"
                    f" flops: {FLOPS:.2e}"
                    f" wps: {wps:.2e}"
                    f" iter: {curr_iter_time:>7}"
                    f" data: {data_load_time:>5}"
                    f" lr: {curr_lr:.2e}"
                    f" n_bytes_gpu: {int(interval_total_n_bytes_per_gpu)}"
                    f" n_bytes_sum: {int(interval_total_n_bytes_across_gpus)}"
                    f" mem: {gpu_mem_stats.max_active_pct:.0f}%"
                    f" pow: {gpu_mem_stats.power_draw/1000} W"
                )

                n_bytes = 0
                step_losses = []
                step_tok_losses = []

                gpu_memory_monitor.reset_peak_stats()
                nwords_since_last_log = 0
                time_last_log = timer()
            if every_n_steps(
                train_state, args.checkpoint.dump.every, acc_step=0
            ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
                if (
                    args.data.load_async
                    and args.data.async_persist_type == PersistType.EXACT
                ):
                    train_state.data_loader_state, data_loader, batch_iterator = (
                        get_state_and_refresh(data_loader)
                    )
                else:
                    train_state.data_loader_state = data_loader.get_state()
                saved = checkpoint.save(
                    model,
                    optimizer,
                    train_state,
                    args,
                    device_mesh=world_mesh,
                )

            if args.eval is not None and every_n_steps(
                train_state, args.checkpoint.eval.every, acc_step=0
            ):
                eval_args = args.eval
                eval_args.global_step = train_state.step
                eval_args.ckpt_dir = str(checkpoint.existing_saves[-1])
                eval_args.dump_dir = os.path.join(
                    args.dump_dir,
                    "evals",
                    EVAL_FOLDER_NAME.format(train_state.step),
                )
                eval_args.metric_log_dir = args.dump_dir
                if args.async_eval_gpus is None:
                    launch_eval(eval_args)
                elif get_is_master():
                    if wandb.run is not None and args.logging.wandb is not None:
                        eval_args.wandb = deepcopy(args.logging.wandb)
                    assert args.async_eval_gpus > 0
                    logger.info(f"Launching evals on {args.async_eval_gpus} gpus")
                    with clean_env():
                        launch_job(
                            StoolArgs(
                                asdict(eval_args),
                                script="apps.main.eval",
                                copy_code=False,
                                nodes=args.async_eval_gpus // 8,
                                qos="lowest",
                            )
                        )

            if preemption_flag["flag"]:
                if not saved:
                    if (
                        args.data.load_async
                        and args.data.async_persist_type == PersistType.EXACT
                    ):
                        train_state.data_loader_state, data_loader, batch_iterator = (
                            get_state_and_refresh(data_loader)
                        )
                    else:
                        train_state.data_loader_state = data_loader.get_state()
                    checkpoint.save(
                        model,
                        optimizer,
                        train_state,
                        args,
                        device_mesh=world_mesh,
                    )
                requeue_slurm_job()
                sys.exit(0)
        if not saved:
            if (
                args.data.load_async
                and args.data.async_persist_type == PersistType.EXACT
            ):
                train_state.data_loader_state, data_loader, batch_iterator = (
                    get_state_and_refresh(data_loader)
                )
            else:
                train_state.data_loader_state = data_loader.get_state()
            checkpoint.save(
                model,
                optimizer,
                train_state,
                args,
                device_mesh=world_mesh,
            )

        if isinstance(data_loader, MultiprocessIterator):
            logger.info("Closing MP iterator before exiting")
            data_loader.shutdown()

    gc.collect()


def main():
    """
    The command line interface here uses OmegaConf
    https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments

    This accepts arguments as a dot list.
    So if the dataclass looks like

    @dataclass
    class DummyArgs:
        name: str
        model: LMTransformerArgs

    @dataclass
    class LMTransformerArgs:
        dim: int

    Then you can pass model.dim=32 to change values in LMTransformerArgs
    or just name=tictac for top level attributes.

    The behavior here is as follows:
    1. We instantiate TrainArgs with its default values
    2. We override those default values with the ones in the provided config file
    3. We override the result with the additional arguments provided through command line

    For example, if the config is the following

    model:
        dim: 128
        n_layers: 4

    and you call train.py with train.py model.dim=64

    Then the final TrainArgs will have

    model:
        dim: 64
        n_layers: 4

    Plus all the default values in the TrainArgs dataclass.
    """
    train_args = parse_args_to_pydantic_model(TrainArgs)
    if train_args.debug_dynamo:
        import torch._dynamo

        torch._dynamo.config.suppress_errors = True
    train(train_args)


if __name__ == "__main__":
    main()