# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
import os
from typing import Any

import numpy as np
import yaml
from pydantic import BaseModel, ConfigDict

from bytelatent.checkpoint import CheckpointArgs
from bytelatent.data.data_types import Batch
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
from bytelatent.data.iterators.arrow_iterator import (
    ArrowFileIterator,
    find_and_sanitize_chunks,
)
from bytelatent.data.iterators.looping_iterator import LoopingIterator
from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
from bytelatent.data.iterators.sampling_iterator import SamplingIterator
from bytelatent.data.iterators.sequence_iterator import (
    SequenceIterator,
    SequencePackingArgs,
)
from bytelatent.data.patcher import PatcherArgs
from bytelatent.distributed import DistributedArgs, EnvironmentArgs
from bytelatent.metrics import LoggingArgs
from bytelatent.model.blt import ByteLatentTransformerArgs
from bytelatent.optim import OptimArgs
from bytelatent.profiling import ProfilerArgs
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs

logger = logging.getLogger()


def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]:
    # Seed a fresh Generator from the (seed, rank, world_size) triple and
    # return its serializable bit-generator state.
    return np.random.default_rng((seed, rank, world_size)).bit_generator.state
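

# Illustrative usage (not part of the original module): each
# (seed, rank, world_size) triple seeds an independent numpy Generator, so
# ranks draw decorrelated random streams, and the returned state dict can be
# saved in a checkpoint and restored later:
#
#   state = get_rng_state(seed=42, rank=0, world_size=8)
#   rng = np.random.default_rng()
#   rng.bit_generator.state = state  # restore the saved state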


def distribute_data_to_rank(
    *,
    dataset_path: str,
    preprocess_dir: str,
    entropy_model_name: str | None,
    arrow_batch_size: int,
    rank: int,
    world_size: int,
) -> ArrowFileIterator:
    # Split the dataset's chunk files across ranks: each chunk is shared by
    # world_size // n_chunks workers, and this rank gets exactly one iterator.
    dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size)
    n_workers_per_chunk = world_size // len(dataset_chunks)
    rank_to_arrow_iterator_params = []
    for chunk_path in dataset_chunks:
        for worker_id in range(n_workers_per_chunk):
            rank_to_arrow_iterator_params.append(
                ArrowFileIterator(
                    file_path=chunk_path,
                    worker_id=worker_id,
                    num_workers=n_workers_per_chunk,
                    preprocess_dir=preprocess_dir,
                    dataset_files=None,
                    entropy_model_name=entropy_model_name,
                    arrow_batch_size=arrow_batch_size,
                )
            )
    return rank_to_arrow_iterator_params[rank]
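

# Illustrative example (paths are hypothetical): with 2 dataset chunks and
# world_size=4, the chunk/worker grid is chunk0 -> ranks {0, 1} and
# chunk1 -> ranks {2, 3}, so rank 2 gets chunk1 with worker_id=0 of 2:
#
#   it = distribute_data_to_rank(
#       dataset_path="data/my_dataset",
#       preprocess_dir="data/preprocessed",
#       entropy_model_name="transformer_100m",
#       arrow_batch_size=100,
#       rank=2,
#       world_size=4,
#   )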


class DataloaderArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    root_dir: str | None = None
    sources: dict[str, float] = {}
    batch_size: int = 2
    seq_len: int = 2048
    seed: int = 42
    add_bos: bool = True
    add_eos: bool = True
    load_async: bool = True
    prefetch_size: int = 64
    preprocess_dir: str | None = None
    dataset_files: list[str] | None = None
    entropy_model_name: str | None = "transformer_100m"
    arrow_batch_size: int = 100
    buffer_size: int = 64

    pad_to_max_length: bool = True
    max_encoder_seq_length: int = 12288
    enable_byte_ngrams: bool = False

    tokenizer_args: TokenizerArgs = TokenizerArgs()
    patcher_args: PatcherArgs = PatcherArgs()

    def _create_sequence_iterators(
        self, rank: int, world_size: int
    ) -> dict[str, SequenceIterator]:
        sequence_packing_args = SequencePackingArgs(
            output_seq_len=self.seq_len,
            buffer_size=self.buffer_size,
        )
        source_to_sequence_iterator: dict[str, SequenceIterator] = {}
        for dataset_path in self.sources:
            shuffle_rng_state = get_rng_state(self.seed + 1, rank, world_size)
            arrow_iterator = distribute_data_to_rank(
                dataset_path=os.path.join(self.root_dir, dataset_path),
                preprocess_dir=self.preprocess_dir,
                entropy_model_name=self.entropy_model_name,
                arrow_batch_size=self.arrow_batch_size,
                rank=rank,
                world_size=world_size,
            )
            # Per-source pipeline: loop over this rank's shard indefinitely,
            # tokenize and patch each document, then pack documents into
            # fixed-length shuffled sequences.
            looping_iterator = LoopingIterator(arrow_iterator)
            preprocess_iterator = PreprocessIterator(
                looping_iterator,
                patcher_args=self.patcher_args,
                tokenizer_args=self.tokenizer_args,
            )
            sequence_iterator = SequenceIterator(
                preprocess_iterator,
                sequence_packing_args=sequence_packing_args,
                rng_state=shuffle_rng_state,
            )
            source_to_sequence_iterator[dataset_path] = sequence_iterator
        return source_to_sequence_iterator

    def build_from_rank(
        self, rank: int, world_size: int
    ) -> StatefulIterator[Batch, Any]:
        source_to_sequence_iterators = self._create_sequence_iterators(rank, world_size)
        weight_rng_state = get_rng_state(self.seed + 1, rank, world_size)
        # Mix the per-source iterators according to their sampling weights,
        # pack sequences into batches, and prefetch batches asynchronously.
        sampling_iterator = SamplingIterator(
            rng_state=weight_rng_state,
            source_to_weight=self.sources,
            source_to_iterator=source_to_sequence_iterators,
        )
        tokenizer = self.tokenizer_args.build()
        packing_args = PackingArgs(
            batch_size=self.batch_size,
            seq_len=self.seq_len,
            pad_id=tokenizer.boe_id,
            max_length=self.max_encoder_seq_length,
            pad_to_max_length=self.pad_to_max_length,
            enable_byte_ngrams=self.enable_byte_ngrams,
        )
        packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
        mp_iterator = MultiprocessIterator(
            packing_iterator, n_batches_to_prefetch=self.prefetch_size
        )
        return mp_iterator
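

# Illustrative usage (source name and weight are hypothetical): sources maps
# dataset subdirectories of root_dir to sampling weights, and build_from_rank
# wires the full pipeline (shard -> loop -> preprocess -> pack -> sample ->
# batch -> prefetch) for one data-parallel rank:
#
#   data_args = DataloaderArgs(
#       root_dir="data",
#       sources={"my_dataset": 1.0},
#       batch_size=2,
#       seq_len=2048,
#   )
#   batch_iterator = data_args.build_from_rank(rank=0, world_size=1)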


class TrainArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    name: str = "lingua"
    dump_dir: str = ""

    seed: int = 42

    # Number of gradient accumulation steps
    # Total batch size is batch_size * grad_acc_steps
    grad_acc_steps: int = 1

    gc_collect_freq: int = 1000
    probe_freq: int | None = None

    # Number of optimizer steps to take
    steps: int = 1000

    data: DataloaderArgs = DataloaderArgs()
    optim: OptimArgs = OptimArgs()
    model: ByteLatentTransformerArgs = ByteLatentTransformerArgs()
    distributed: DistributedArgs = DistributedArgs()
    env: EnvironmentArgs = EnvironmentArgs()

    checkpoint: CheckpointArgs = CheckpointArgs()
    profiling: ProfilerArgs = ProfilerArgs()
    logging: LoggingArgs = LoggingArgs()

    # If set to None, eval is run locally; otherwise a new eval job is
    # launched with the given number of GPUs.
    async_eval_gpus: int | None = None
    eval: Any | None = None
    eval_on_gpus: int | None = None

    def dump_to_yaml_file(
        self, path: str, log_config: bool = True, sort_keys: bool = True
    ):
        model_dict = self.model_dump(mode="json")
        yaml_str = yaml.dump(
            model_dict,
            allow_unicode=True,
            sort_keys=sort_keys,
            default_flow_style=False,
        )
        with open(path, "w") as f:
            if log_config:
                logger.info("Using the following config for this run:")
                logger.info(yaml_str)
            f.write(yaml_str)
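

if __name__ == "__main__":
    # Minimal sketch, not part of the original module: build a default config,
    # override a couple of fields, and dump the resolved YAML. Assumes all
    # default sub-configs are constructible without arguments, as the field
    # defaults above imply.
    args = TrainArgs(name="debug_run", steps=10)
    args.dump_to_yaml_file("train_config.yaml", log_config=False)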