# Copyright (c) Meta Platforms, Inc. and affiliates.

import json
import logging
import os
import re
from pathlib import Path
from typing import List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
import torch.optim.optimizer
from pydantic import BaseModel, ConfigDict
from torch.distributed._tensor import DeviceMesh
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    get_state_dict,
    set_state_dict,
)

from bytelatent.distributed import get_is_master

logger = logging.getLogger("CHECKPOINT")

FOLDER_NAME = "{:010d}"
RE_FOLDER = r"\d{10}"
RE_CKPT = r"__\d_\d\.distcp"
CONSOLIDATE_FOLDER = "consolidated"
CONSOLIDATE_NAME = "consolidated.pth"
CONFIG_NAME = "params.json"
TRAIN_STATE_NAME = "train_state_{:05d}.json"
RE_DIGITS = re.compile(r"\d+")
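# For illustration only (the step value is hypothetical): FOLDER_NAME.format(1000)
# yields "0000001000" and TRAIN_STATE_NAME.format(0) yields "train_state_00000.json",
# the per-data-parallel-rank file that get_last_step_path() looks for below.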

class SaveEvery(BaseModel):
    model_config = ConfigDict(extra="forbid")
    every: int = 1000
    keep: int = 0


class CheckpointArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    dump: SaveEvery = SaveEvery()
    eval: SaveEvery = SaveEvery()
    path: str | None = None
    init_ckpt_path: str | None = None
    continue_training_from_init: bool = False
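# A minimal construction sketch (the path and intervals below are hypothetical,
# not defaults taken from this file):
#
#   args = CheckpointArgs(
#       path="/checkpoints/run0",
#       dump=SaveEvery(every=2000, keep=3),
#       eval=SaveEvery(every=1000, keep=0),
#   )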

def _get_key_step(name: str):
    return int(re.findall(RE_DIGITS, name)[-1])

def consolidate_checkpoints(ckpt_dir: str):
    """
    Consolidates all FSDP checkpoints in a directory into a single file.
    The consolidated checkpoint is saved in a subdirectory of ckpt_dir.

    Parameters:
        ckpt_dir: str - path to the directory containing the checkpoints

    Returns the path to the consolidated checkpoint.
    """
    consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER
    if not (consolidate_path / CONSOLIDATE_NAME).exists():
        consolidate_path.mkdir(exist_ok=True)
        logger.info(f"Consolidating to: {str(consolidate_path)}")
        dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME))
        (consolidate_path / CONFIG_NAME).write_text(
            (Path(ckpt_dir) / CONFIG_NAME).read_text()
        )
        logger.info("Consolidated!")
    return consolidate_path
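# A usage sketch (the checkpoint directory below is hypothetical): consolidate a
# distributed (DCP) checkpoint folder into a single .pth file, then load it on CPU.
#
#   consolidated_dir = consolidate_checkpoints("/checkpoints/run0/0000001000")
#   state = torch.load(consolidated_dir / CONSOLIDATE_NAME, map_location="cpu")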

def load_from_checkpoint(
    ckpt_dir: str,
    model: nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    model_key: str = "model",
    optim_key: str = "optim",
):
    if not (Path(ckpt_dir) / ".metadata").exists():
        raise ValueError(
            "Please convert the checkpoint to the distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it"
        )

    state_dict = {}
    if optimizer is not None:
        state_dict[model_key], state_dict[optim_key] = get_state_dict(model, optimizer)
    else:
        state_dict[model_key] = get_model_state_dict(model)
        if model_key == "":  # If only loading a model directly, the key should be empty
            state_dict = state_dict.pop(model_key)

    dcp.load(state_dict, checkpoint_id=ckpt_dir)
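# A usage sketch (MyModel and the directory are hypothetical): the model's
# parameters are filled in place from the sharded DCP files under ckpt_dir.
#
#   model = MyModel(...)
#   load_from_checkpoint("/checkpoints/run0/0000001000", model)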

class CheckpointManager:
    def __init__(self, args: CheckpointArgs):
        self.path = args.path
        self.dump_every = args.dump
        self.eval_every = args.eval
        self.init_ckpt_path = args.init_ckpt_path
        self.continue_training_from_init = args.continue_training_from_init

        assert os.path.exists(
            self.path
        ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)"

        self.existing_saves = self.get_existing_saves()

    def get_existing_saves(self) -> List[Path]:
        folders = [
            p
            for p in Path(self.path).iterdir()
            if p.is_dir() and re.match(RE_FOLDER, p.name)
        ]
        folders.sort(key=lambda p: _get_key_step(p.name))
        return folders

    def clean_up(self):
        logger.info("Cleaning up checkpoints...")
        dump_folders = []
        eval_folders = []
        other_folders = []
        for p in self.existing_saves:
            is_dump = _get_key_step(p.name) % self.dump_every.every == 0
            is_eval = _get_key_step(p.name) % self.eval_every.every == 0
            if is_dump:
                dump_folders.append(p)
            if is_eval:
                eval_folders.append(p)
            if not (is_dump or is_eval):
                other_folders.append(p)

        logger.info(f"Dump folders: {dump_folders}")
        logger.info(f"Eval folders: {eval_folders}")
        logger.info(f"Other folders: {other_folders}")

        if self.dump_every.keep > 0:
            dump_folders = dump_folders[-self.dump_every.keep :]
        if self.eval_every.keep > 0:
            eval_folders = eval_folders[-self.eval_every.keep :]

        folder_to_keep = set(other_folders + dump_folders + eval_folders)
        folder_to_remove = set(self.existing_saves) - folder_to_keep

        logger.info(f"Removing folders: {folder_to_remove}")

        if dist.get_rank() == 0:
            for folder in folder_to_remove:
                for file in folder.iterdir():
                    if file.is_file():
                        file.unlink()
                    elif file.is_dir():
                        assert file.name in [CONSOLIDATE_FOLDER]
                        for f in file.iterdir():
                            f.unlink()
                        file.rmdir()
                folder.rmdir()

        dist.barrier()

        self.existing_saves = list(folder_to_keep)
        self.existing_saves.sort(key=lambda p: _get_key_step(p.name))
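    # For illustration (hypothetical settings): if both dump and eval are
    # SaveEvery(every=1000, keep=2) and saves exist for steps 1000 through 5000,
    # only the folders for steps 4000 and 5000 are kept; keep=0 means that
    # category is never pruned.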

    def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]:
        path = None
        for p in reversed(self.existing_saves):
            if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file():
                path = p
                break
        return path

    def _create_folder(self, base_path: Path, folder_name: str) -> Path:
        folder = base_path / folder_name
        if get_is_master():
            folder.mkdir(parents=False, exist_ok=True)
        if dist.is_initialized():
            dist.barrier()
        return folder

    def _get_dp_tp_mesh(
        self, device_mesh: Optional[DeviceMesh] = None
    ) -> Tuple[int, int]:
        dp_rank = 0
        tp_rank = 0
        if device_mesh is not None:
            if "dp_replicate" in device_mesh.mesh_dim_names:
                dp_rank = device_mesh.get_local_rank("dp_replicate")
                if "dp_shard" in device_mesh.mesh_dim_names:
                    dp_rank = dp_rank * device_mesh[
                        "dp_replicate"
                    ].size() + device_mesh.get_local_rank("dp_shard")
            if "tp" in device_mesh.mesh_dim_names:
                tp_rank = device_mesh.get_local_rank("tp")
        return dp_rank, tp_rank
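    # For illustration (a hypothetical 2x4 mesh): with dp_replicate local rank 1,
    # dp_replicate size 2, and dp_shard local rank 2, the flattened dp_rank above
    # is 1 * 2 + 2 = 4; tp_rank stays 0 unless the mesh has a "tp" dimension.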

    def get_state_dict(
        self,
        model,
        optimizer,
    ):
        model_sd, optim_sd = get_state_dict(model, optimizer)
        return {"model": model_sd, "optim": optim_sd}

    def save(
        self,
        model,
        optimizer,
        train_state,
        config,
        device_mesh: Optional[DeviceMesh] = None,
    ) -> bool:
        # TODO: when creating the directory, check whether only rank 0 should do it
        # or whether there is a better solution
        path = Path(self.path)
        curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step))
        logger.info(f"Saving to: {str(curr_save_dir)}")

        if dist.is_initialized():
            dist.barrier()

        logger.info("Saving...")
        state_dict = self.get_state_dict(model, optimizer)
        dcp.save(state_dict, checkpoint_id=curr_save_dir)
        logger.info("State dict saved!")

        if dist.is_initialized():
            dist.barrier()

        if get_is_master():
            config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME)

        # Add json dump here
        dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
        if tp_rank == 0:
            train_state_name = TRAIN_STATE_NAME.format(dp_rank)
            logger.info(
                f"Saving train state to: {str(curr_save_dir / train_state_name)}"
            )
            with open(curr_save_dir / train_state_name, "w") as f:
                json.dump(train_state.state_dict(), f)
            logger.info("Train state saved!")

        self.existing_saves.append(curr_save_dir)

        self.clean_up()

        if dist.is_initialized():
            dist.barrier()
        return True
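    # A usage sketch inside a training loop (checkpointer, train_state, cfg, and
    # mesh are hypothetical names; cfg must provide the dump_to_yaml_file method
    # used above):
    #
    #   if train_state.step % checkpointer.dump_every.every == 0:
    #       checkpointer.save(model, optimizer, train_state, cfg, device_mesh=mesh)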

    def load(
        self,
        model: nn.Module,
        optimizer,
        train_state,
        device_mesh: DeviceMesh,
        path: Optional[Path] = None,
    ):
        dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
        # Loading tries the provided path first, then the last saved step, and finally the init path
        path = path or self.get_last_step_path(dp_rank=dp_rank)
        if path is None:
            # If no checkpoints exist, do nothing
            return

        # Only load the train state if it's provided, the files exist, and we're not loading from the init path
        train_state_name = TRAIN_STATE_NAME.format(dp_rank)
        logger.info("Reloading train state")
        with open(path / train_state_name, "r") as f:
            train_state_dict = json.load(f)
        train_state.load_state_dict(train_state_dict)
        logger.info("Train state reloaded")

        logger.info(f"Loading from: {str(path)}")
        state_dict = self.get_state_dict(
            model=model,
            optimizer=optimizer,
        )
        dcp.load(state_dict, checkpoint_id=path)
        logger.info("State dict loaded.")

        logger.info("Reloading model and optim")
        set_state_dict(
            model,
            optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )
        logger.info("Model and optim reloaded")

    @classmethod
    def instantiate_and_make_dir(cls, args: CheckpointArgs):
        if get_is_master():
            os.makedirs(args.path, exist_ok=True)
        dist.barrier()

        return cls(args)
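# An end-to-end sketch (model, optimizer, train_state, and mesh are hypothetical
# names): create the checkpoint directory on the master rank, then resume from
# the latest save if one exists.
#
#   checkpointer = CheckpointManager.instantiate_and_make_dir(args)
#   checkpointer.load(model, optimizer, train_state, device_mesh=mesh)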