Source code for pose_format.torch.masked.tensor

import torch


class MaskedTensor:
    """
    Container for a PyTorch tensor, providing utility functions for tensor masking.

    Parameters
    ----------
    tensor : torch.Tensor
        Tensor data.
    mask : torch.Tensor, optional
        A boolean mask tensor of the same shape as `tensor`. If specified, elements of
        `tensor` corresponding to `True` values in the mask are considered valid.
        Defaults to a tensor of all `True` values on the same device as `tensor`.

    Notes
    -----
    Attributes not defined on this class are looked up on the wrapped tensor: plain
    attributes (e.g. ``shape``, ``dtype``, ``device``) are returned as-is, while callable
    attributes raise ``NotImplementedError``. Instances are unhashable because
    ``__eq__`` is overridden.
    """

    # Overriding __eq__ already sets __hash__ to None implicitly; spell it out.
    __hash__ = None

    def __init__(self, tensor: torch.Tensor, mask: torch.Tensor = None):
        self.tensor = tensor
        if mask is None:
            # Allocate directly on the tensor's device rather than building the
            # mask on the CPU and copying it across.
            mask = torch.ones(tensor.shape, dtype=torch.bool, device=tensor.device)
        self.mask = mask

    def __repr__(self) -> str:
        return f"{type(self).__name__}(tensor={self.tensor!r}, mask={self.mask!r})"

    def __getattr__(self, item):
        """
        Gets non-callable attributes of the wrapped tensor.

        Only invoked when normal attribute lookup on the instance fails.

        Raises
        ------
        AttributeError
            If `item` is ``tensor``/``mask`` and it has not been set yet, or the
            wrapped tensor has no such attribute.
        NotImplementedError
            If the tensor attribute is callable.
        """
        # If 'tensor' or 'mask' is missing (e.g. on an instance created by
        # copy/pickle before its __dict__ is restored), reading self.tensor below
        # would re-enter __getattr__ forever; fail cleanly instead.
        if item in ("tensor", "mask"):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {item!r}")

        val = self.tensor.__getattribute__(item)
        if callable(val):
            raise NotImplementedError("callbable '%s' not defined" % item)
        return val

    def __len__(self):
        """
        Gets size of first dimension of the tensor.

        Returns
        -------
        int
            Size of first dimension of tensor.
        """
        return self.tensor.shape[0]

    def __getitem__(self, key):
        """
        Get a subset of a tensor based on a key or slice.

        The same key is applied to both the tensor and the mask.

        Returns
        -------
        :class:`~pose_format.torch.masked.tensor.MaskedTensor`
            Subset of the tensor.
        """
        tensor = self.tensor[key]
        mask = self.mask[key]
        return MaskedTensor(tensor=tensor, mask=mask)

    def arithmetic(self, action: str, other):
        """
        Helper method to perform arithmetic operations on tensors.

        Parameters
        ----------
        action : str
            Name of the ``torch.Tensor`` operator method to call (e.g. ``"__add__"``).
        other : Union[:class:`~pose_format.torch.masked.tensor.MaskedTensor`, torch.Tensor, float, int]
            The second operand.

        Returns
        -------
        :class:`~pose_format.torch.masked.tensor.MaskedTensor`
            New `MaskedTensor` after the operation. When `other` is a `MaskedTensor`
            the masks are combined with logical AND; otherwise `self.mask` is kept.
        """
        if isinstance(other, MaskedTensor):
            tensor = getattr(self.tensor, action)(other.tensor)
            mask = self.mask & other.mask
        else:
            tensor = getattr(self.tensor, action)(other)
            mask = self.mask
        return MaskedTensor(tensor=tensor, mask=mask)

    def __add__(self, other):
        """
        Performs element-wise addition with another tensor or scalar.

        Parameters
        ----------
        other : Union[:class:`~pose_format.torch.masked.tensor.MaskedTensor`, torch.Tensor, float, int]
            The tensor or scalar to add.

        Returns
        -------
        :class:`~pose_format.torch.masked.tensor.MaskedTensor`
            Resultant tensor after addition.
        """
        return self.arithmetic("__add__", other)

    def __radd__(self, other):
        """Reflected addition (``other + self``), e.g. when `other` is a scalar."""
        return self.arithmetic("__radd__", other)

    def __sub__(self, other):
        """
        Performs element-wise subtraction with another tensor or scalar.

        Parameters
        ----------
        other : Union[:class:`~pose_format.torch.masked.tensor.MaskedTensor`, torch.Tensor, float, int]
            The tensor or scalar to subtract.

        Returns
        -------
        :class:`~pose_format.torch.masked.tensor.MaskedTensor`
            Resultant tensor after subtraction.
        """
        return self.arithmetic("__sub__", other)

    def __rsub__(self, other):
        """Reflected subtraction (``other - self``), e.g. when `other` is a scalar."""
        return self.arithmetic("__rsub__", other)

    def __mul__(self, other):
        """
        Performs element-wise multiplication with another tensor or scalar.

        Parameters
        ----------
        other : Union[:class:`~pose_format.torch.masked.tensor.MaskedTensor`, torch.Tensor, float, int]
            The tensor or scalar to multiply.

        Returns
        -------
        :class:`~pose_format.torch.masked.tensor.MaskedTensor`
            Resultant tensor after multiplication.
        """
        return self.arithmetic("__mul__", other)

    def __rmul__(self, other):
        """Reflected multiplication (``other * self``), e.g. when `other` is a scalar."""
        return self.arithmetic("__rmul__", other)

    def __truediv__(self, other):
        """
        Performs element-wise division with another tensor or scalar.

        Parameters
        ----------
        other : Union[:class:`~pose_format.torch.masked.tensor.MaskedTensor`, torch.Tensor, float, int]
            The tensor or scalar to divide by.

        Returns
        -------
        :class:`~pose_format.torch.masked.tensor.MaskedTensor`
            Resultant tensor after division.
        """
        return self.arithmetic("__truediv__", other)

    def __rtruediv__(self, other):
        """Reflected division (``other / self``), e.g. when `other` is a scalar."""
        return self.arithmetic("__rtruediv__", other)

    def __eq__(self, other):
        """
        Compares the tensor for element-wise equality with another tensor.

        The mask is ignored; only tensor values are compared.

        Parameters
        ----------
        other : Union[:class:`~pose_format.torch.masked.tensor.MaskedTensor`, torch.Tensor, float, int]
            The value to compare against.

        Returns
        -------
        torch.Tensor
            A boolean tensor with `True` where elements are equal and `False` otherwise.
        """
        if isinstance(other, MaskedTensor):
            other = other.tensor
        return self.tensor == other

    def __ne__(self, other):
        """
        Element-wise inequality; the counterpart of :meth:`__eq__` (mask ignored).

        Returns
        -------
        torch.Tensor
            A boolean tensor with `True` where elements differ.
        """
        if isinstance(other, MaskedTensor):
            other = other.tensor
        return self.tensor != other

    def pow_(self, exponent: float):
        """
        Raises tensor to power of a given exponent in-place.

        Parameters
        ----------
        exponent : float
            The exponent value.

        Returns
        -------
        :class:`~pose_format.torch.masked.tensor.MaskedTensor`
            `self`, with its tensor raised to the given exponent.
        """
        self.tensor.pow_(exponent)
        return self

    def sum(self, dim: int):
        """
        Sums along a specified dimension.

        Masked-out values are included in the sum (they are not zero-filled first);
        the resulting mask is `True` only where every summed element was valid.

        Parameters
        ----------
        dim : int
            dimension to sum over.

        Returns
        -------
        :class:`~pose_format.torch.masked.tensor.MaskedTensor`
            Summed tensor along the specified dimension.
        """
        tensor = self.tensor.sum(dim=dim)
        mask = self.mask.prod(dim=dim).bool()
        return MaskedTensor(tensor=tensor, mask=mask)

    def size(self, *args):
        """
        Get size of tensor for specified dimensions.

        Returns
        -------
        torch.Size or int
            Size of tensor (or of the requested dimension).
        """
        return self.tensor.size(*args)

    def fix_nan(self):
        """
        Replaces any NaN values in the tensor with zeros, in-place.

        Infinite values are left untouched.

        Returns
        -------
        :class:`~pose_format.torch.masked.tensor.MaskedTensor`
            `self`, with NaN values replaced by zeros.
        """
        # masked_fill_ avoids the boolean-index assignment (index_put_) path.
        self.tensor.masked_fill_(torch.isnan(self.tensor), 0)
        return self

    def to(self, device):
        """
        Moves tensor to a custom device.

        Parameters
        ----------
        device : str or torch.device
            The target device.

        Returns
        -------
        :class:`~pose_format.torch.masked.tensor.MaskedTensor`
            Tensor on the other device.
        """
        tensor = self.tensor.to(device)
        mask = self.mask.to(device)
        return MaskedTensor(tensor=tensor, mask=mask)

    def cuda(self, device=None, non_blocking: bool = False):
        """
        Moves tensor to the GPU.

        Parameters
        ----------
        device : str or torch.device, optional
            The target CUDA device.
        non_blocking : bool, optional
            Whether to perform an operation asynchronously. Default is False.

        Returns
        -------
        :class:`~pose_format.torch.masked.tensor.MaskedTensor`
            Tensor on CUDA device.
        """
        tensor = self.tensor.cuda(device=device, non_blocking=non_blocking)
        mask = self.mask.cuda(device=device, non_blocking=non_blocking)
        return MaskedTensor(tensor=tensor, mask=mask)

    def zero_filled(self) -> torch.Tensor:
        """
        Get tensor with masked values set to zero.

        Returns
        -------
        torch.Tensor
            Tensor with masked-out values set to zero, including positions that hold
            NaN or inf.
        """
        # Select instead of multiplying: NaN * 0 and inf * 0 are NaN, so a
        # multiplication would leave such values in masked-out positions.
        return torch.where(self.mask.bool(), self.tensor, self.tensor.new_zeros(()))

    def div(self, other, in_place=False, update_mask=True):
        """
        Performs element-wise division with another tensor.

        Parameters
        ----------
        other : Union[:class:`~pose_format.torch.masked.tensor.MaskedTensor`, torch.Tensor, float, int]
            The divisor. A plain tensor or scalar divides values and leaves the mask
            as-is.
        in_place : bool, optional
            If True, writes the result into `self.tensor`. Default is False.
            `self.mask` itself is never modified; the combined mask is only set on
            the returned object.
        update_mask : bool, optional
            If True and `other` is a `MaskedTensor`, AND its mask into the result's
            mask. Default is True.

        Returns
        -------
        :class:`~pose_format.torch.masked.tensor.MaskedTensor`
            Resultant tensor after division.
        """
        if isinstance(other, MaskedTensor):
            other_tensor, other_mask = other.tensor, other.mask
        else:
            other_tensor, other_mask = other, None

        tensor = torch.div(self.tensor, other_tensor, out=self.tensor if in_place else None)
        if update_mask and other_mask is not None:
            mask = self.mask & other_mask
        else:
            mask = self.mask
        return MaskedTensor(tensor, mask)

    def matmul(self, matrix: torch.Tensor):
        """
        Perform matrix multiplication.

        Parameters
        ----------
        matrix : torch.Tensor
            matrix to multiply with; it is moved to this tensor's device first.

        Returns
        -------
        :class:`~pose_format.torch.masked.tensor.MaskedTensor`
            New masked tensor after multiplication, carrying the original mask.

        Notes
        -----
        The mask is reused unchanged, so the result only stays consistent when the
        product keeps the shape of ``self.tensor``.
        NOTE(review): a non-square `matrix` would yield a tensor whose shape no longer
        matches the mask — confirm callers only pass square matrices.
        """
        tensor = torch.matmul(self.tensor, matrix.to(self.tensor.device))
        return MaskedTensor(tensor, self.mask)

    def transpose(self, dim0, dim1):
        """
        Transposes tensor along two dimensions.

        Parameters
        ----------
        dim0, dim1 : int
            The two dimensions to swap.

        Returns
        -------
        :class:`~pose_format.torch.masked.tensor.MaskedTensor`
            Transposed masked tensor.
        """
        tensor = self.tensor.transpose(dim0, dim1)
        mask = self.mask.transpose(dim0, dim1)
        return MaskedTensor(tensor=tensor, mask=mask)

    def permute(self, dims: tuple):
        """
        Permute dimensions of tensor.

        Parameters
        ----------
        dims : tuple
            Desired ordering of dimensions.

        Returns
        -------
        :class:`~pose_format.torch.masked.tensor.MaskedTensor`
            Permuted masked tensor.
        """
        tensor = self.tensor.permute(dims)
        mask = self.mask.permute(dims)
        return MaskedTensor(tensor=tensor, mask=mask)

    def squeeze(self, dim):
        """
        Squeeze tensor along chosen dimension.

        Parameters
        ----------
        dim : int
            Dimension to squeeze.

        Returns
        -------
        :class:`~pose_format.torch.masked.tensor.MaskedTensor`
            Squeezed masked tensor.
        """
        tensor = self.tensor.squeeze(dim)
        mask = self.mask.squeeze(dim)
        return MaskedTensor(tensor=tensor, mask=mask)

    def split(self, split_size_or_sections, dim=0):
        """
        Split tensor into multiple tensors.

        Parameters
        ----------
        split_size_or_sections : int or tuple
            Size or sections to split tensor.
        dim : int, optional
            Dimension along which to split tensor. Default is 0.

        Returns
        -------
        list[:class:`~pose_format.torch.masked.tensor.MaskedTensor`]
            List of split tensors, each paired with the matching slice of the mask.
        """
        tensors = torch.split(self.tensor, split_size_or_sections, dim)
        masks = torch.split(self.mask, split_size_or_sections, dim)
        return [MaskedTensor(tensor=tensor, mask=mask) for tensor, mask in zip(tensors, masks)]

    def reshape(self, shape: tuple):
        """
        Reshape tensor to given shape.

        Parameters
        ----------
        shape : tuple
            Desired shape.

        Returns
        -------
        :class:`~pose_format.torch.masked.tensor.MaskedTensor`
            Reshaped tensor.
        """
        tensor = self.tensor.reshape(shape=shape)
        mask = self.mask.reshape(shape=shape)
        return MaskedTensor(tensor=tensor, mask=mask)

    def rename(self, *names):
        """
        Rename tensor's dimensions.

        Parameters
        ----------
        names : tuple
            Desired names for each dimension.

        Returns
        -------
        :class:`~pose_format.torch.masked.tensor.MaskedTensor`
            Renamed masked tensor.
        """
        tensor = self.tensor.rename(*names)
        mask = self.mask.rename(*names)
        return MaskedTensor(tensor=tensor, mask=mask)