from typing import List
import tensorflow as tf
class MaskedTensor:
    """
    A TensorFlow tensor paired with a boolean mask marking its valid elements.

    Indexing, reshaping and arithmetic are applied to ``tensor`` and ``mask`` together;
    reductions such as :meth:`mean` only take elements whose mask is True into account.

    Parameters
    ----------
    tensor : tf.Tensor
        The underlying data.
    mask : tf.Tensor, optional
        Boolean tensor marking valid elements of ``tensor``. Defaults to an all-True mask
        with the same shape as ``tensor``.
    """

    def __init__(self, tensor: "tf.Tensor", mask: "tf.Tensor" = None):
        # tensor: the data; mask: boolean validity flags for the data
        self.tensor = tensor
        # ones_like (rather than ones(tensor.shape)) also works when some dimensions are
        # only known at run time.
        self.mask = mask if mask is not None else tf.ones_like(tensor, dtype=tf.bool)

    def __getattr__(self, item):
        """
        Get attributes from the tensor, unless it's a callable in which case an error is raised.

        Only called for names not found on the MaskedTensor itself.

        Parameters
        ----------
        item : str
            Name of the attribute to fetch.

        Raises
        ------
        AttributeError
            If the attribute does not exist (including when ``tensor``/``mask`` were never set).
        NotImplementedError
            If the requested attribute is callable.
        """
        if item in ("tensor", "mask"):
            # Reached only when the attribute was never assigned, e.g. on an instance created via
            # __new__ by copy/pickle. Without this guard, ``self.tensor`` below would recurse forever.
            raise AttributeError(item)
        val = getattr(self.tensor, item)
        if callable(val):
            raise NotImplementedError(f"callable '{item}' not defined")
        return val

    def __len__(self):
        """
        Return the length of the tensor.

        Returns
        -------
        int
            Length of the tensor along the first dimension (1 for a scalar tensor).
        """
        shape = self.tensor.shape
        return shape[0] if len(shape) > 0 else 1

    def __getitem__(self, key):
        """
        Get elements from tensor and corresponding mask based on a key.

        Parameters
        ----------
        key : list or int or slice or tf.Tensor
            Indexing key used to get the elements. Lists are treated as gather indices
            along the first axis.

        Returns
        -------
        :class:`pose_format.tensorflow.masked.tensor.MaskedTensor`
            A new MaskedTensor containing elements selected by the indexing key.
        """
        if isinstance(key, list):
            tensor = tf.gather(self.tensor, key)
            mask = tf.gather(self.mask, key)
        else:
            tensor = self.tensor[key]
            mask = self.mask[key]
        return MaskedTensor(tensor=tensor, mask=mask)

    def arithmetic(self, action: str, other):
        """
        Apply an element-wise arithmetic operation with another tensor.

        Parameters
        ----------
        action : str
            Name of the tensor's operator method to call (e.g. ``"__add__"``).
        other : :class:`pose_format.tensorflow.masked.tensor.MaskedTensor` or tf.Tensor or scalar
            Operand to perform the operation with.

        Returns
        -------
        :class:`pose_format.tensorflow.masked.tensor.MaskedTensor`
            Result of the operation. When ``other`` is a MaskedTensor the masks are AND-ed;
            otherwise this tensor's mask is kept.
        """
        if isinstance(other, MaskedTensor):
            tensor = getattr(self.tensor, action)(other.tensor)
            mask = self.mask & other.mask
        else:
            tensor = getattr(self.tensor, action)(other)
            mask = self.mask
        return MaskedTensor(tensor=tensor, mask=mask)

    def __float__(self):
        return float(self.tensor)

    def __add__(self, other):
        return self.arithmetic("__add__", other)

    def __radd__(self, other):
        return self.arithmetic("__radd__", other)

    def __sub__(self, other):
        return self.arithmetic("__sub__", other)

    def __rsub__(self, other):
        return self.arithmetic("__rsub__", other)

    def __mul__(self, other):
        return self.arithmetic("__mul__", other)

    def __rmul__(self, other):
        return self.arithmetic("__rmul__", other)

    def __truediv__(self, other):
        return self.arithmetic("__truediv__", other)

    def __rtruediv__(self, other):
        return self.arithmetic("__rtruediv__", other)

    def __eq__(self, other):
        """Element-wise equality of the underlying data (the mask is ignored); returns a tensor."""
        other_tensor = other.tensor if isinstance(other, MaskedTensor) else other
        return self.tensor == other_tensor

    def __pow__(self, power):
        return self.arithmetic("__pow__", power)

    def __round__(self, ndigits=0):
        """
        Round the underlying data to ``ndigits`` decimal places.

        Returns
        -------
        tf.Tensor
            The rounded data (the mask is not attached).
        """
        # Match the tensor's dtype so float64/float16 tensors do not hit a dtype mismatch.
        multiplier = tf.cast(10 ** ndigits, dtype=self.tensor.dtype)
        return tf.round(self.tensor * multiplier) / multiplier

    def square(self):
        """
        Element-wise square of the tensor.

        Returns
        -------
        :class:`pose_format.tensorflow.masked.tensor.MaskedTensor`
            A new MaskedTensor containing the squared values, with the same mask.
        """
        tensor = tf.math.square(self.tensor)
        return MaskedTensor(tensor=tensor, mask=self.mask)

    def float(self):
        """
        Convert tensor's data type to float32 while preserving mask.

        Returns
        -------
        :class:`pose_format.tensorflow.masked.tensor.MaskedTensor`
            A new MaskedTensor with the tensor's data type converted to float32.
        """
        tensor = tf.cast(self.tensor, dtype=tf.float32)
        return MaskedTensor(tensor=tensor, mask=self.mask)

    def sqrt(self):
        """
        Element-wise square root of the tensor.

        Returns
        -------
        :class:`pose_format.tensorflow.masked.tensor.MaskedTensor`
            A new MaskedTensor containing the square roots, with the same mask.
        """
        tensor = tf.math.sqrt(self.tensor)
        return MaskedTensor(tensor=tensor, mask=self.mask)

    def sum(self, axis=None):
        """
        Sum of tensor along specified axis while updating mask.

        Note that masked-out values are included in the sum; the resulting mask is True only
        where every reduced element was valid.

        Parameters
        ----------
        axis : int or None, optional
            Axis along which to compute sum. If None, compute the sum over all elements.

        Returns
        -------
        :class:`pose_format.tensorflow.masked.tensor.MaskedTensor`
            A new MaskedTensor containing the sums of the tensor along the specified axis.
        """
        tensor = tf.math.reduce_sum(self.tensor, axis=axis)
        mask = tf.math.reduce_all(tf.cast(self.mask, tf.bool), axis=axis)
        return MaskedTensor(tensor=tensor, mask=mask)

    def size(self, *args):
        """
        Get tensor's size along dimensions.

        Parameters
        ----------
        *args : int
            Dimensions for which to get size. If none are given, the full shape is returned.

        Returns
        -------
        int or tuple of int
            Size of a single requested dimension, or a tuple of sizes otherwise. Dimensions
            unknown at trace time are reported as None.
        """
        shape = tuple(self.tensor.shape)
        if not args:
            return shape
        if len(args) == 1:
            return shape[args[0]]
        return tuple(shape[dim] for dim in args)

    def fix_nan(self):
        """
        Replace non-finite values (NaN and ±inf) with zeros, in place, keeping the mask.

        Returns
        -------
        :class:`pose_format.tensorflow.masked.tensor.MaskedTensor`
            This same MaskedTensor (``self``), after its ``tensor`` has been replaced.
        """
        self.tensor = tf.where(tf.math.is_finite(self.tensor), self.tensor, tf.zeros_like(self.tensor))
        return self

    def zero_filled(self) -> "tf.Tensor":
        """
        Fill invalid values (as indicated by the mask) with zeros.

        Returns
        -------
        tf.Tensor
            Tensor with the same shape as `self.tensor` but with zeros where the mask is False.
        """
        # tf.where instead of multiplying by the mask: NaN/inf at masked positions would
        # otherwise survive as NaN (NaN * 0 == NaN) and poison reductions such as mean().
        return tf.where(tf.cast(self.mask, tf.bool), self.tensor, tf.zeros_like(self.tensor))

    def div(self, other: "MaskedTensor", in_place=False, update_mask=True) -> "MaskedTensor":
        """
        Divide tensor by another tensor.

        Parameters
        ----------
        other : :class:`pose_format.tensorflow.masked.tensor.MaskedTensor`
            The divisor tensor.
        in_place : bool, optional
            Whether to also store the quotient in ``self.tensor``. TensorFlow tensors are
            immutable, so this rebinds the attribute rather than writing into memory.
            Default is False.
        update_mask : bool, optional
            Whether to AND the divisor's mask into the result mask. Default is True.

        Returns
        -------
        :class:`pose_format.tensorflow.masked.tensor.MaskedTensor`
            Masked tensor after division.
        """
        tensor = tf.math.divide(self.tensor, other.tensor)
        if in_place:
            self.tensor = tensor
        mask = self.mask & other.mask if update_mask else self.mask
        return MaskedTensor(tensor, mask)

    def matmul(self, matrix: "tf.Tensor") -> "MaskedTensor":
        """
        Matrix multiplication with a given matrix.

        Parameters
        ----------
        matrix : tf.Tensor
            Matrix to perform multiplication with.

        Returns
        -------
        :class:`pose_format.tensorflow.masked.tensor.MaskedTensor`
            MaskedTensor with the result of the matrix multiplication and the original mask
            (assumes the mask's shape is compatible with the result — TODO confirm per caller).
        """
        tensor = tf.matmul(self.tensor, matrix)
        return MaskedTensor(tensor=tensor, mask=self.mask)

    def transpose(self, perm: List[int]) -> "MaskedTensor":
        """
        Transpose tensor according to given permutation.

        Parameters
        ----------
        perm : List[int]
            The new order of dimensions/permutation after transposition.

        Returns
        -------
        :class:`pose_format.tensorflow.masked.tensor.MaskedTensor`
            MaskedTensor with dimensions transposed according to the given permutation.
        """
        tensor = tf.transpose(self.tensor, perm=perm)
        mask = tf.transpose(self.mask, perm=perm)
        return MaskedTensor(tensor=tensor, mask=mask)

    def permute(self, dims: tuple) -> "MaskedTensor":
        """
        Permute the dimensions of the tensor according to the provided tuple.

        Parameters
        ----------
        dims : tuple
            The new order of dimensions after permutation.

        Returns
        -------
        :class:`pose_format.tensorflow.masked.tensor.MaskedTensor`
            A new MaskedTensor with dimensions permuted according to the given tuple.
        """
        # tf.Tensor has no .permute(); tf.transpose with an explicit perm is the equivalent.
        return self.transpose(list(dims))

    def squeeze(self, axis=None) -> "MaskedTensor":
        """
        Remove dimensions with size 1 from both tensor and mask.

        Parameters
        ----------
        axis : int or None, optional
            The axis along which to perform squeezing. If None, all size-1 dimensions are removed.

        Returns
        -------
        :class:`pose_format.tensorflow.masked.tensor.MaskedTensor`
            MaskedTensor with dimensions removed.
        """
        tensor = tf.squeeze(self.tensor, axis=axis)
        mask = tf.squeeze(self.mask, axis=axis)
        return MaskedTensor(tensor=tensor, mask=mask)

    def split(self, split_size_or_sections, axis=0):
        """
        Split tensor and mask into several MaskedTensors.

        Parameters
        ----------
        split_size_or_sections : int or tf.Tensor
            Number of splits or sizes of each split/section.
        axis : int, optional
            Axis along which to do the splitting. Default is 0.

        Returns
        -------
        list of :class:`pose_format.tensorflow.masked.tensor.MaskedTensor`
            List of new MaskedTensor objects containing the splits.
        """
        tensors = tf.split(self.tensor, split_size_or_sections, axis)
        masks = tf.split(self.mask, split_size_or_sections, axis)
        return [MaskedTensor(tensor=tensor, mask=mask) for tensor, mask in zip(tensors, masks)]

    def reshape(self, shape: tuple) -> "MaskedTensor":
        """
        Reshape tensor and mask into a custom shape.

        Parameters
        ----------
        shape : tuple
            New shape of tensor.

        Returns
        -------
        :class:`pose_format.tensorflow.masked.tensor.MaskedTensor`
            New MaskedTensor with specified shape.
        """
        tensor = tf.reshape(self.tensor, shape=shape)
        mask = tf.reshape(self.mask, shape=shape)
        return MaskedTensor(tensor=tensor, mask=mask)

    def gather(self, indexes):
        """
        Gather elements from tensor and mask using indexes.

        Parameters
        ----------
        indexes : tf.Tensor or list or int
            Indexes used to select elements along the first axis.

        Returns
        -------
        :class:`pose_format.tensorflow.masked.tensor.MaskedTensor`
            A new MaskedTensor containing elements gathered from the tensor using the indexes.
        """
        tensor = tf.gather(self.tensor, indexes)
        mask = tf.gather(self.mask, indexes)
        return MaskedTensor(tensor=tensor, mask=mask)

    def rename(self, *names) -> "MaskedTensor":
        """
        Rename dimensions (not supported by this backend).

        Parameters
        ----------
        *names : str
            New names of the dimensions.

        Raises
        ------
        NotImplementedError
            Always: TensorFlow tensors have no named dimensions.
        """
        raise NotImplementedError("rename is not supported: TensorFlow tensors have no named dimensions")

    def mean(self, axis=None, keepdims=False) -> "MaskedTensor":
        """
        Compute the mean over valid (mask=True) elements along an axis.

        Parameters
        ----------
        axis : None or int, optional
            Axis along which to compute the mean. If None, compute the mean of the entire tensor.
            Default is None.
        keepdims : bool, optional
            If True, reduced axes are kept with size 1 so the result broadcasts against the
            input. Default is False.

        Returns
        -------
        :class:`pose_format.tensorflow.masked.tensor.MaskedTensor`
            The mean of the masked tensor. Positions with no valid elements get value 0 and
            mask False.
        """
        mt_sum = tf.math.reduce_sum(self.zero_filled(), axis=axis, keepdims=keepdims)
        mt_count = tf.math.reduce_sum(tf.cast(self.mask, mt_sum.dtype), axis=axis, keepdims=keepdims)
        tensor = tf.math.divide(mt_sum, mt_count)
        mask = tf.cast(mt_count, tf.bool)
        # 0/0 for fully-masked positions yields NaN; fix_nan zeroes those.
        return MaskedTensor(tensor=tensor, mask=mask).fix_nan()

    def variance(self, axis=None) -> "MaskedTensor":
        """
        Compute the (population) variance over valid elements along a specified axis.

        Parameters
        ----------
        axis : None or int, optional
            Axis along which to compute the variance. If None, compute the variance of the
            entire tensor. Default is None.

        Returns
        -------
        :class:`pose_format.tensorflow.masked.tensor.MaskedTensor`
            The variance of the masked tensor.
        """
        # keepdims so the means broadcast back against self for any reduced axis.
        means = self.mean(axis=axis, keepdims=True)
        squared_deviations = (self - means).square()
        return squared_deviations.mean(axis=axis)

    def std(self, axis=None) -> "MaskedTensor":
        """
        Compute the standard deviation over valid elements along the specified axis.

        Parameters
        ----------
        axis : None or int, optional
            The axis along which to compute the standard deviation. If None, compute the
            standard deviation of the entire tensor. Default is None.

        Returns
        -------
        :class:`pose_format.tensorflow.masked.tensor.MaskedTensor`
            The standard deviation of the tensor.
        """
        return self.variance(axis=axis).sqrt()