| """Adapted from https://github.com/cvg/GeoCalib""" |
|
|
| import collections.abc as collections |
| import functools |
| import inspect |
| from typing import Callable, List, Tuple |
|
|
| import numpy as np |
| import torch |
|
|
| |
| |
|
|
|
|
# Sequence types that ``map_tensor`` must NOT recurse into: values of these types are
# returned unchanged instead of being iterated element by element.
string_classes: Tuple[type, ...] = (str, bytes)
|
|
|
|
def autocast(func: Callable) -> Callable:
    """Cast numpy-array arguments of a TensorWrapper method to PyTorch tensors.

    Arrays are converted with ``torch.from_numpy`` and moved to the device and dtype of
    the wrapper's underlying tensor. When the wrapper holds no data yet (``_data`` is
    ``None``, e.g. while ``__init__`` runs) or the method receives a TensorWrapper
    subclass instead of an instance, arrays are placed on the CPU and keep their dtype.
    Both positional and keyword arguments are cast; anything that is not a numpy array
    is passed through unchanged.

    Args:
        func (Callable): Method of a TensorWrapper class.

    Returns:
        Callable: Wrapped method.

    Raises:
        ValueError: If the first argument is neither a TensorWrapper instance nor a
            TensorWrapper subclass.
    """

    @functools.wraps(func)
    def wrap(self, *args, **kwargs):
        device = torch.device("cpu")
        dtype = None  # ``Tensor.to(dtype=None)`` keeps the array's own dtype
        if isinstance(self, TensorWrapper):
            if self._data is not None:
                device = self.device
                dtype = self.dtype
        elif not inspect.isclass(self) or not issubclass(self, TensorWrapper):
            raise ValueError(self)

        def _cast(arg):
            """Convert a numpy array to a tensor on the target device/dtype."""
            if isinstance(arg, np.ndarray):
                return torch.from_numpy(arg).to(device=device, dtype=dtype)
            return arg

        cast_args = [_cast(arg) for arg in args]
        # Keyword arguments used to be rejected outright; forward (and cast) them too.
        cast_kwargs = {key: _cast(value) for key, value in kwargs.items()}
        return func(self, *cast_args, **cast_kwargs)

    return wrap
|
|
|
|
class TensorWrapper:
    """Wrapper around a PyTorch tensor stored in ``_data``.

    Methods that move, cast, index or reshape the tensor return a new instance of the
    same (sub)class around the resulting tensor. ``shape`` excludes the tensor's last
    dimension.
    """

    # Wrapped tensor; ``None`` until ``__init__`` assigns it. ``autocast`` checks this to
    # decide whether the wrapper's device/dtype are available yet.
    _data = None

    @autocast
    def __init__(self, data: torch.Tensor):
        """Wrap ``data`` (numpy arrays are converted to tensors by ``autocast``)."""
        self._data = data

    @property
    def shape(self) -> torch.Size:
        """Shape of the underlying tensor without its last dimension."""
        return self._data.shape[:-1]

    @property
    def device(self) -> torch.device:
        """Device of the underlying tensor."""
        return self._data.device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the underlying tensor."""
        return self._data.dtype

    def __getitem__(self, index) -> "TensorWrapper":
        """Index the underlying tensor and wrap the result in the same class."""
        return self.__class__(self._data[index])

    def __setitem__(self, index, item):
        """Assign ``item`` into the underlying tensor at ``index``.

        ``item`` may be another TensorWrapper (its underlying tensor is used), a tensor,
        or any other value the underlying tensor's ``__setitem__`` accepts (e.g. a
        scalar). Tensors are assigned through ``.data``.
        """
        if isinstance(item, TensorWrapper):
            # Wrappers expose their tensor as ``_data``; there is no ``data`` attribute.
            item = item._data
        if isinstance(item, torch.Tensor):
            # ``.data`` assigns the values without linking them into the autograd graph.
            item = item.data
        self._data[index] = item

    def to(self, *args, **kwargs):
        """Return a new wrapper around ``Tensor.to(*args, **kwargs)``."""
        return self.__class__(self._data.to(*args, **kwargs))

    def cpu(self):
        """Return a new wrapper whose underlying tensor is on the CPU."""
        return self.__class__(self._data.cpu())

    def cuda(self, *args, **kwargs):
        """Return a new wrapper whose underlying tensor is on a CUDA device.

        Arguments (e.g. ``device``, ``non_blocking``) are forwarded to ``Tensor.cuda``.
        """
        return self.__class__(self._data.cuda(*args, **kwargs))

    def pin_memory(self):
        """Return a new wrapper whose underlying tensor is in pinned memory."""
        return self.__class__(self._data.pin_memory())

    def float(self):
        """Return a new wrapper with the underlying tensor cast to float32."""
        return self.__class__(self._data.float())

    def double(self):
        """Return a new wrapper with the underlying tensor cast to float64."""
        return self.__class__(self._data.double())

    def detach(self):
        """Return a new wrapper around the detached underlying tensor."""
        return self.__class__(self._data.detach())

    def numpy(self):
        """Convert the underlying tensor to a numpy array (detached, on the CPU)."""
        return self._data.detach().cpu().numpy()

    def new_tensor(self, *args, **kwargs):
        """Create a plain tensor via ``Tensor.new_tensor`` on the underlying tensor."""
        return self._data.new_tensor(*args, **kwargs)

    def new_zeros(self, *args, **kwargs):
        """Create a plain tensor via ``Tensor.new_zeros`` on the underlying tensor."""
        return self._data.new_zeros(*args, **kwargs)

    def new_ones(self, *args, **kwargs):
        """Create a plain tensor via ``Tensor.new_ones`` on the underlying tensor."""
        return self._data.new_ones(*args, **kwargs)

    def new_full(self, *args, **kwargs):
        """Create a plain tensor via ``Tensor.new_full`` on the underlying tensor."""
        return self._data.new_full(*args, **kwargs)

    def new_empty(self, *args, **kwargs):
        """Create a plain tensor via ``Tensor.new_empty`` on the underlying tensor."""
        return self._data.new_empty(*args, **kwargs)

    def unsqueeze(self, *args, **kwargs):
        """Return a new wrapper around ``Tensor.unsqueeze(*args, **kwargs)``."""
        return self.__class__(self._data.unsqueeze(*args, **kwargs))

    def squeeze(self, *args, **kwargs):
        """Return a new wrapper around ``Tensor.squeeze(*args, **kwargs)``."""
        return self.__class__(self._data.squeeze(*args, **kwargs))

    @classmethod
    def stack(cls, objects: List, dim=0, *, out=None):
        """Stack the underlying tensors of ``objects`` along ``dim`` into one ``cls``."""
        data = torch.stack([obj._data for obj in objects], dim=dim, out=out)
        return cls(data)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Route ``torch.stack`` to ``cls.stack``; other torch functions are unsupported."""
        if kwargs is None:
            kwargs = {}
        return cls.stack(*args, **kwargs) if func is torch.stack else NotImplemented
|
|
|
|
def map_tensor(input_, func):
    """Recursively apply ``func`` to every leaf of a nested structure.

    Mappings are rebuilt as plain dicts and other sequences (lists, tuples, ...) as
    lists, with ``func`` applied to their elements. Strings/bytes and ``None`` are
    returned unchanged; any other object is a leaf and is replaced by ``func(input_)``.
    """
    # Leaves that must be kept verbatim (strings are sequences but not containers).
    if input_ is None or isinstance(input_, string_classes):
        return input_
    if isinstance(input_, collections.Mapping):
        return {key: map_tensor(value, func) for key, value in input_.items()}
    if isinstance(input_, collections.Sequence):
        return [map_tensor(element, func) for element in input_]
    return func(input_)
|
|
|
|
def batch_to_numpy(batch):
    """Convert every tensor leaf of a nested batch to a numpy array.

    Leaves are detached and moved to the CPU before ``.numpy()`` is called. This means
    tensors that require grad are supported; plain ``.cpu().numpy()`` raises for those.
    ``map_tensor`` rebuilds containers and passes strings/bytes and ``None`` through.
    Every other leaf must provide ``detach``/``cpu``/``numpy`` (e.g. ``torch.Tensor``
    or ``TensorWrapper``).
    """
    return map_tensor(batch, lambda tensor: tensor.detach().cpu().numpy())
|
|
|
|
def batch_to_device(batch, device, non_blocking=True, detach=False, dtype=torch.float32):
    """Move (and cast) every tensor leaf of a nested batch.

    Args:
        batch: Nested dict/sequence structure with tensor leaves. ``map_tensor`` passes
            strings/bytes and ``None`` through and rebuilds sequences as lists.
        device: Target device, forwarded to ``Tensor.to``.
        non_blocking (bool, optional): Forwarded to ``Tensor.to``. Defaults to True.
        detach (bool, optional): Whether to detach the converted tensors.
            Defaults to False.
        dtype (torch.dtype, optional): dtype every leaf is cast to. Defaults to
            ``torch.float32``. Pass ``None`` to keep each tensor's own dtype, e.g. to
            preserve integer indices or boolean masks.

    Returns:
        The converted structure.
    """

    def _func(tensor):
        t = tensor.to(device=device, non_blocking=non_blocking, dtype=dtype)
        return t.detach() if detach else t

    return map_tensor(batch, _func)
|
|
|
|
def remove_batch_dim(data: dict) -> dict:
    """Strip the leading batch dimension from the array-like values of ``data``.

    Tensors, numpy arrays and lists are replaced by their first element (``v[0]``).
    All other values are copied over unchanged. A new dict is returned.
    """
    batched_types = (torch.Tensor, np.ndarray, list)
    unbatched = {}
    for key, value in data.items():
        unbatched[key] = value[0] if isinstance(value, batched_types) else value
    return unbatched
|
|
|
|
def add_batch_dim(data: dict) -> dict:
    """Add a leading batch dimension of size 1 to the array-like values of ``data``.

    Tensors and numpy arrays are indexed with ``None`` (``v[None]``), which inserts a
    new leading axis. Lists are wrapped in a one-element list, so indexing the result
    with ``[0]`` gives back the original list. All other values are copied over
    unchanged. A new dict is returned.
    """
    batched = {}
    for key, value in data.items():
        if isinstance(value, (torch.Tensor, np.ndarray)):
            value = value[None]
        elif isinstance(value, list):
            # ``list[None]`` raises TypeError, so wrap the list instead of indexing it.
            value = [value]
        batched[key] = value
    return batched
|
|
|
|
def fit_to_multiple(
    x: torch.Tensor, multiple: int, mode: str = "center", crop: bool = False
) -> Tuple[int, int, int, int]:
    """Get the padding that makes the spatial size of ``x`` a multiple of ``multiple``.

    Only the last two dimensions of ``x`` (height, width) are considered.

    - With ``crop=False`` the padding is non-negative and grows the size up to the
      next multiple.
    - With ``crop=True`` it is non-positive and shrinks the size down to the previous
      multiple.

    The result is ordered ``(left, right, top, bottom)``, as expected by
    ``torch.nn.functional.pad`` for the last two dimensions.

    Args:
        x (torch.Tensor): Input tensor; only ``x.shape[-2:]`` is used.
        multiple (int): Target multiple; must be positive.
        mode (str, optional): How the padding is distributed. ``"center"`` splits it
            between both sides; when padding an odd amount, the extra pixel goes
            right/bottom. ``"left"`` keeps the content at the top-left and puts
            everything on the right/bottom. Defaults to "center".
        crop (bool, optional): Whether to crop or pad. Defaults to False.

    Returns:
        Tuple[int, int, int, int]: Padding ``(left, right, top, bottom)``.

    Raises:
        ValueError: If ``multiple`` is not positive or ``mode`` is unknown.
    """
    if multiple <= 0:
        # Avoids ZeroDivisionError for 0 and meaningless padding for negatives.
        raise ValueError(f"multiple must be positive, got {multiple}")

    h, w = x.shape[-2:]

    if crop:
        pad_w = (w // multiple) * multiple - w
        pad_h = (h // multiple) * multiple - h
    else:
        pad_w = (multiple - w % multiple) % multiple
        pad_h = (multiple - h % multiple) % multiple

    if mode == "center":
        pad_l = pad_w // 2
        pad_r = pad_w - pad_l
        pad_t = pad_h // 2
        pad_b = pad_h - pad_t
    elif mode == "left":
        pad_l = 0
        pad_r = pad_w
        pad_t = 0
        pad_b = pad_h
    else:
        raise ValueError(f"Unknown mode {mode}")

    return (pad_l, pad_r, pad_t, pad_b)
|
|
|
|
def fit_features_to_multiple(
    features: torch.Tensor, multiple: int = 32, crop: bool = False, mode: str = "center"
) -> Tuple[torch.Tensor, Tuple[int, int, int, int]]:
    """Pad features with reflection so their spatial size is a multiple of ``multiple``.

    The padding comes from ``fit_to_multiple``. With ``crop=True`` it is non-positive,
    i.e. the features are meant to be cropped down to the previous multiple.

    Args:
        features (torch.Tensor): Input features; the last two dimensions are treated as
            height and width.
        multiple (int, optional): Multiple. Defaults to 32.
        crop (bool, optional): Whether to crop or pad. Defaults to False.
        mode (str, optional): How the padding is distributed, forwarded to
            ``fit_to_multiple`` ("center" or "left"). Defaults to "center".

    Returns:
        Tuple[torch.Tensor, Tuple[int, int, int, int]]: Padded features and the padding
        ``(left, right, top, bottom)`` that was applied.
    """
    pad = fit_to_multiple(features, multiple, mode=mode, crop=crop)
    # NOTE(review): reflect mode requires each padding to be smaller than the matching
    # input size; with crop=True the padding is negative — confirm the installed
    # PyTorch version accepts negative values in reflect mode.
    return torch.nn.functional.pad(features, pad, mode="reflect"), pad
|
|