| import os | |
| from glob import glob | |
| from collections import defaultdict | |
| import numpy as np | |
| from PIL import Image | |
class MaskDataset(object):
    """Index per-sequence mask image files under a root directory.

    Two on-disk layouts are supported, selected by ``is_label``:

    * ``is_label=True``: every ``*.png`` in ``<root>/<seq>/``, ordered by
      plain lexicographic path order. When read, non-zero pixels are
      binarized to 255.
    * ``is_label=False``: files named ``dynamic_mask_<N>.png`` in
      ``<root>/<seq>/``, ordered by the integer ``N``. When read, they are
      returned unmodified.

    Attributes:
        is_label: Which of the two layouts above is in use.
        sequences: Dict mapping each sequence name to its ordered list of
            mask file paths.
    """

    def __init__(self, root, sequences, is_label=True):
        """Collect and order the mask paths for each sequence.

        Args:
            root: Directory containing one sub-directory per sequence.
            sequences: Iterable of sequence (sub-directory) names.
            is_label: Selects the label layout (``*.png``, lexicographic
                order) or the dynamic-mask layout (``dynamic_mask_<N>.png``,
                numeric order).
        """
        self.is_label = is_label
        self.sequences = {}
        for seq in sequences:
            # Progress output kept from the original implementation.
            print(root, seq)
            if is_label:
                masks = sorted(glob(os.path.join(root, seq, '*.png')))
            else:
                # Sort numerically so that frame 10 comes after frame 2
                # (plain string order would put "10" before "2").
                masks = sorted(
                    glob(os.path.join(root, seq, 'dynamic_mask_*.png')),
                    key=self._frame_index,
                )
            self.sequences[seq] = masks

    @staticmethod
    def _frame_index(path):
        """Return the integer N from a file named ``dynamic_mask_<N>.png``."""
        return int(os.path.basename(path).split('_')[-1].split('.')[0])

    def read_masks(self, seq):
        """Load every mask image of ``seq`` in index order.

        Each file is opened in a ``with`` block and its pixel data is read
        into memory before the block exits. This releases the file handle
        immediately instead of holding one open descriptor per returned
        image, which PIL's lazy loading would otherwise do.

        Args:
            seq: A sequence name that was passed to the constructor.

        Returns:
            A list of PIL images. In label mode each image is rebuilt from
            a numpy array in which every non-zero pixel is set to 255.

        Raises:
            KeyError: If ``seq`` was not indexed by the constructor.
        """
        masks = []
        for path in self.sequences[seq]:
            with Image.open(path) as img:
                if self.is_label:
                    # np.array copies the pixel data, so the file can close.
                    arr = np.array(img)
                    arr[arr > 0] = 255
                    masks.append(Image.fromarray(arr))
                else:
                    # Force the decode now; the image stays usable after
                    # the `with` block closes the underlying file.
                    img.load()
                    masks.append(img)
        return masks