| | |
| | |
| | |
| | import torch |
| | from torchvision import datasets |
| | from torchvision.transforms import v2 |
| | from torch.utils.data import DataLoader |
| |
|
| | import utils |
| | from typing import Tuple |
| |
|
| | import io |
| | import base64 |
| | from PIL import Image |
| | import numpy as np |
| |
|
| | |
# Deterministic pipeline (no augmentation): convert the input to an image
# tensor, rescale it to float32, then normalize with the MNIST mean / std.
# Used for the test split and for images handled by `mnist_preprocess`.
BASE_TRANSFORMS = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.1307], std=[0.3081]),
    ]
)
| |
|
# Training pipeline: a random affine augmentation (rotation, translation,
# scaling and shear) followed by the same tensor conversion and MNIST
# normalization steps that BASE_TRANSFORMS applies.
TRAIN_TRANSFORMS = v2.Compose(
    [
        v2.RandomAffine(
            degrees=15,
            translate=(0.08, 0.08),
            scale=(0.8, 1.2),
            shear=10,
        ),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.1307], std=[0.3081]),
    ]
)
| |
|
| |
|
| | |
| | |
| | |
def get_dataloaders(root: str,
                    batch_size: int,
                    num_workers: int = 0) -> Tuple[DataLoader, DataLoader]:
    '''
    Builds the MNIST training and testing dataloaders.

    The training split is wrapped with TRAIN_TRANSFORMS (augmented) and shuffled;
    the testing split is wrapped with BASE_TRANSFORMS and kept in order.

    Args:
        root (str): Directory where the MNIST data is downloaded / read from.
        batch_size (int): Number of samples per batch for both loaders.
        num_workers (int): Number of worker processes used for loading. Default is 0.

    Returns:
        Tuple[DataLoader, DataLoader]: The (train, test) dataloaders, in that order.
    '''
    train_set = datasets.MNIST(root, download=True, train=True,
                               transform=TRAIN_TRANSFORMS)
    test_set = datasets.MNIST(root, download=True, train=False,
                              transform=BASE_TRANSFORMS)

    # A multiprocessing context and persistent workers are only requested when
    # worker processes are actually used (num_workers > 0).
    has_workers = bool(num_workers > 0)

    # Settings shared by both loaders; only the dataset and shuffling differ.
    shared_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        multiprocessing_context=utils.MP_CONTEXT if has_workers else None,
        pin_memory=utils.PIN_MEM,
        persistent_workers=has_workers,
    )

    train_dl = DataLoader(dataset=train_set, shuffle=True, **shared_kwargs)
    test_dl = DataLoader(dataset=test_set, shuffle=False, **shared_kwargs)

    return train_dl, test_dl
| |
|
def _center_on_canvas(img: np.ndarray) -> np.ndarray:
    '''
    Pastes a 2D uint8 image onto a zero-filled 28x28 canvas so that the image's
    intensity-weighted center of mass (COM) lands on canvas position (14, 14).

    Parts of the image that would fall outside the canvas are cropped. If the
    image has no mass at all (every pixel is 0), the COM is undefined, so the
    geometric center of the image is used instead.

    Args:
        img (np.ndarray): 2D array where larger values mean more "ink".

    Returns:
        np.ndarray: A (28, 28) uint8 array containing the re-centered image.
    '''
    canvas_size, half = 28, 14
    height, width = img.shape

    # Work in int64 so the index-weighted sums are exact.
    mass = img.astype(np.int64)
    total_mass = int(mass.sum())

    if total_mass == 0:
        # Blank input: 0 / 0 would yield NaN and a garbage integer index.
        com_y, com_x = height // 2, width // 2
    else:
        rows, cols = np.indices(img.shape)
        com_x = int(np.round((cols * mass).sum() / total_mass))
        com_y = int(np.round((rows * mass).sum() / total_mass))

    # How far we can copy on each side of the COM: limited both by the source
    # image's edges and by half of the canvas.
    before_x = min(half, com_x)
    after_x = min(half, width - com_x)
    before_y = min(half, com_y)
    after_y = min(half, height - com_y)

    canvas = np.zeros((canvas_size, canvas_size), dtype=np.uint8)
    canvas[half - before_y:half + after_y,
           half - before_x:half + after_x] = img[com_y - before_y:com_y + after_y,
                                                 com_x - before_x:com_x + after_x]
    return canvas


def mnist_preprocess(uri: str) -> torch.Tensor:
    '''
    Preprocesses a data URI representing a handwritten digit image according to the pipeline used in the MNIST dataset.
    The pipeline includes:
        1. Converting the image to grayscale.
        2. Shrinking the image to fit within 20x20, preserving the aspect ratio, and using anti-aliasing
           (images already smaller than 20x20 are not enlarged).
        3. Inverting the pixel values so that dark strokes on a light background become the bright "ink".
        4. Centering the resized image in a 28x28 image based on the center of mass (COM).
           A completely blank image is centered geometrically, since it has no COM.
        5. Converting the image to a tensor (pixel values between 0 and 1) and normalizing it using MNIST statistics.

    Reference: https://paperswithcode.com/dataset/mnist

    Args:
        uri (str): A string representing the full data URI ("<header>,<base64 payload>").

    Returns:
        Tensor: A tensor of shape (1, 28, 28) representing the preprocessed image, normalized using MNIST statistics.

    Raises:
        IndexError: If `uri` contains no comma separating the header from the payload.
        Errors raised while base64-decoding or opening the image propagate unchanged.
    '''
    # Everything after the first comma is the base64-encoded image payload.
    encoded_img = uri.split(',', 1)[1]
    image_bytes = io.BytesIO(base64.b64decode(encoded_img))

    # NOTE(review): convert('L') discards any alpha channel - confirm the caller
    # sends dark strokes on an opaque light background.
    pil_img = Image.open(image_bytes).convert('L')

    # In-place resize to fit within 20x20 while keeping the aspect ratio.
    pil_img.thumbnail((20, 20), Image.Resampling.LANCZOS)

    # Invert so the background is 0 and the strokes carry the mass.
    img = 255 - np.array(pil_img)

    return BASE_TRANSFORMS(_center_on_canvas(img))
| |
|