Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. | |
| import random | |
| from itertools import islice | |
| import numpy as np | |
| import torch | |
class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
    """Batch sampler that groups samples of similar length into the same batch.

    Sample indices are sorted by length, split into consecutive chunks of
    ``batch_size``, and (optionally) the order of those chunks is shuffled.
    The contents of each batch are never shuffled, so every batch holds
    neighbouring lengths.

    Args:
        data_source: Iterable of samples. A sample is either a sized object
            (its ``len`` is used) or a dict. For dicts, the length of the
            value under the first key of the *first* sample is used for
            every sample.
        batch_size: Number of indices per batch; must be positive.
        drop_last: If True, drop the trailing incomplete batch.
        shuffle: If True, shuffle the order of batches on every iteration,
            using the global ``random`` module.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(
        self, data_source, batch_size: int, drop_last: bool, shuffle: bool = True
    ) -> None:
        if batch_size <= 0:
            raise ValueError(
                f"batch_size should be a positive integer, got {batch_size}"
            )
        # Peek at the first sample once to decide how lengths are measured.
        # An empty data source simply yields no lengths (and thus no batches)
        # instead of leaking a bare StopIteration out of the constructor.
        try:
            first_sample = next(iter(data_source))
        except StopIteration:
            first_sample = None
        if isinstance(first_sample, dict):
            first_key = next(iter(first_sample))
            self.lengths = [len(d[first_key]) for d in data_source]
        else:
            self.lengths = [len(d) for d in data_source]
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.shuffle = shuffle

    def __iter__(self):
        """Yield batches as NumPy arrays of sample indices."""
        # NumPy's default sort is not stable, so the relative order of
        # equal-length samples is unspecified. A stable sort keeps ties in
        # index order, so every process computing the batches from the same
        # lengths gets the same batches.
        ids = np.argsort(self.lengths, kind="stable")
        if self.drop_last:
            ids = ids[: len(ids) // self.batch_size * self.batch_size]
        batches = [
            ids[i : i + self.batch_size] for i in range(0, len(ids), self.batch_size)
        ]
        if self.shuffle:
            random.shuffle(batches)
        yield from batches

    def __len__(self):
        """Return the number of batches one iteration yields."""
        full_batches, remainder = divmod(len(self.lengths), self.batch_size)
        if self.drop_last or remainder == 0:
            return full_batches
        return full_batches + 1
class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler):
    """Shard length-grouped batches across distributed replicas.

    Each replica builds the same list of full batches (the inner
    ``LengthBasedBatchSampler`` is created with ``drop_last=True``). It then
    takes every ``num_replicas``-th batch, starting at ``rank``. The batch
    list is first truncated to a multiple of ``num_replicas``, so every
    replica yields the same number of batches.

    Args:
        data_source: Samples forwarded to ``LengthBasedBatchSampler``.
        batch_size: Number of indices per batch.
        num_replicas: Total number of replicas; must be positive.
        rank: Index of this replica, in ``[0, num_replicas)``.
        shuffle: Whether to shuffle the batch order (forwarded to the inner
            sampler).
        seed: Value passed to ``random.seed``. NOTE: this reseeds the
            *global* ``random`` module. The inner sampler shuffles with that
            global RNG, so replicas only see identical batch orders while
            their global RNG states stay in sync.

    Raises:
        ValueError: If ``num_replicas`` is not positive or ``rank`` is out
            of range.
    """

    def __init__(
        self,
        data_source,
        batch_size: int,
        num_replicas: int,
        rank: int,
        shuffle: bool = True,
        seed: int = 0,
    ) -> None:
        # Without this check, an out-of-range rank would make __iter__ yield
        # nothing while __len__ still reports a non-zero count.
        if num_replicas <= 0:
            raise ValueError(
                f"num_replicas should be a positive integer, got {num_replicas}"
            )
        if not 0 <= rank < num_replicas:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
            )
        random.seed(seed)
        self.batch_sampler = LengthBasedBatchSampler(
            data_source, batch_size=batch_size, drop_last=True, shuffle=shuffle
        )
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self):
        """Return an iterator over this replica's share of the batches."""
        # Truncate to a multiple of num_replicas so all replicas get the
        # same number of batches.
        max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas
        return islice(self.batch_sampler, self.rank, max_length, self.num_replicas)

    def __len__(self):
        """Return the number of batches this replica yields."""
        return len(self.batch_sampler) // self.num_replicas