pose_format.torch.masked.torch module

Classes:

MaskedTorch

Class mimicking torch functions, adding support for MaskedTensor.

TorchFallback

Metaclass that provides a fallback mechanism for using torch functions on MaskedTensor objects.

class pose_format.torch.masked.torch.MaskedTorch[source]

Bases: object

Class mimicking torch functions, adding support for MaskedTensor.

Methods:

cat(tensors, dim)

Concatenate MaskedTensor objects along a specified dimension.

squeeze(masked_tensor)

Remove dimensions of size 1 from MaskedTensor.

stack(tensors, dim)

Stack MaskedTensor objects along a new dimension.

zeros(*size[, dtype])

Create a MaskedTensor of zeros with a given shape and data type.

static cat(tensors, dim)[source]

Concatenate MaskedTensor objects along a specified dimension.

Parameters:
  • tensors (list) – List of tensors or MaskedTensor objects to be concatenated.

  • dim (int) – Dimension along which to concatenate.

Returns:

Concatenated tensor.

Return type:

MaskedTensor
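One way such a concatenation can work is sketched below, assuming a MaskedTensor pairs a data tensor with a boolean mask of the same shape. ToyMaskedTensor and masked_cat are hypothetical stand-ins, not the library's API, and wrapping a plain torch.Tensor with an all-True mask is an assumption of this sketch.

```python
from dataclasses import dataclass
from typing import List, Union

import torch


@dataclass
class ToyMaskedTensor:
    # Hypothetical stand-in for MaskedTensor: data plus a same-shaped boolean mask.
    tensor: torch.Tensor
    mask: torch.Tensor


def masked_cat(tensors: List[Union[ToyMaskedTensor, torch.Tensor]], dim: int) -> ToyMaskedTensor:
    # Assumption: a plain torch.Tensor is wrapped with an all-True mask.
    wrapped = [
        t if isinstance(t, ToyMaskedTensor)
        else ToyMaskedTensor(t, torch.ones_like(t, dtype=torch.bool))
        for t in tensors
    ]
    # Concatenate data and masks along the same dimension so they stay aligned.
    tensor = torch.cat([t.tensor for t in wrapped], dim=dim)
    mask = torch.cat([t.mask for t in wrapped], dim=dim)
    return ToyMaskedTensor(tensor, mask)


a = ToyMaskedTensor(torch.tensor([[1.0, 2.0]]), torch.tensor([[True, False]]))
b = torch.tensor([[3.0, 4.0]])
out = masked_cat([a, b], dim=0)
print(out.tensor.shape)  # torch.Size([2, 2])
```

The key design point is that every operation on the data is mirrored on the mask, so each value keeps its validity flag after concatenation.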

static squeeze(masked_tensor)[source]

Remove dimensions of size 1 from MaskedTensor.

Parameters:

masked_tensor (MaskedTensor) – Tensor from which size-1 dimensions are to be removed.

Returns:

Squeezed masked tensor.

Return type:

MaskedTensor
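A minimal sketch of squeezing, assuming a MaskedTensor holds a data tensor and a same-shaped boolean mask: both are squeezed identically so their shapes remain equal. ToyMaskedTensor and masked_squeeze are hypothetical names used only for illustration.

```python
from dataclasses import dataclass

import torch


@dataclass
class ToyMaskedTensor:
    # Hypothetical stand-in for MaskedTensor: data plus a same-shaped boolean mask.
    tensor: torch.Tensor
    mask: torch.Tensor


def masked_squeeze(masked_tensor: ToyMaskedTensor) -> ToyMaskedTensor:
    # Drop every size-1 dimension from data and mask alike, so shapes stay equal.
    return ToyMaskedTensor(masked_tensor.tensor.squeeze(), masked_tensor.mask.squeeze())


x = ToyMaskedTensor(torch.zeros(1, 3, 1), torch.ones(1, 3, 1, dtype=torch.bool))
y = masked_squeeze(x)
print(y.tensor.shape)  # torch.Size([3])
```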

static stack(tensors, dim)[source]

Stack MaskedTensor objects along a new dimension.

Parameters:
  • tensors (list) – List of MaskedTensor objects to be stacked.

  • dim (int) – New dimension along which to stack.

Returns:

Stacked masked tensor.

Return type:

MaskedTensor
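Unlike cat, stacking inserts a new dimension, so all inputs must share the same shape. The sketch below assumes the same data-plus-mask layout as above; ToyMaskedTensor and masked_stack are hypothetical names, not the library's API.

```python
from dataclasses import dataclass
from typing import List

import torch


@dataclass
class ToyMaskedTensor:
    # Hypothetical stand-in for MaskedTensor: data plus a same-shaped boolean mask.
    tensor: torch.Tensor
    mask: torch.Tensor


def masked_stack(tensors: List[ToyMaskedTensor], dim: int) -> ToyMaskedTensor:
    # torch.stack inserts a new dimension at `dim`; apply it to data and masks alike.
    tensor = torch.stack([t.tensor for t in tensors], dim=dim)
    mask = torch.stack([t.mask for t in tensors], dim=dim)
    return ToyMaskedTensor(tensor, mask)


a = ToyMaskedTensor(torch.tensor([1.0, 2.0]), torch.tensor([True, False]))
b = ToyMaskedTensor(torch.tensor([3.0, 4.0]), torch.tensor([True, True]))
out = masked_stack([a, b], dim=1)
print(out.tensor.shape)  # torch.Size([2, 2])
```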

static zeros(*size, dtype=None)[source]

Create a MaskedTensor of zeros with a given shape and data type.

Parameters:
  • *size (ints) – Dimensions of desired tensor.

  • dtype (torch.dtype, optional) – Data type of the tensor. If None, defaults to torch.float.

Returns:

Masked tensor filled with zeros.

Return type:

MaskedTensor
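A minimal sketch of a zeros constructor under the same data-plus-mask assumption. ToyMaskedTensor and masked_zeros are hypothetical; in particular, filling the mask with False is an assumption of this sketch, since this page does not say whether freshly created zeros count as valid or missing.

```python
from dataclasses import dataclass

import torch


@dataclass
class ToyMaskedTensor:
    # Hypothetical stand-in for MaskedTensor: data plus a same-shaped boolean mask.
    tensor: torch.Tensor
    mask: torch.Tensor


def masked_zeros(*size, dtype=None) -> ToyMaskedTensor:
    # With dtype=None, torch.zeros uses the default dtype (torch.float32 unless
    # torch.set_default_dtype has changed it).
    tensor = torch.zeros(*size, dtype=dtype)
    # Assumption: the mask is all-False; the real library's choice is not documented here.
    mask = torch.zeros(*size, dtype=torch.bool)
    return ToyMaskedTensor(tensor, mask)


z = masked_zeros(2, 3)
print(z.tensor.shape)  # torch.Size([2, 3])
```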

class pose_format.torch.masked.torch.TorchFallback[source]

Bases: type

Metaclass that provides a fallback mechanism for using torch functions on MaskedTensor objects.

Attributes:

doesnt_change_mask

doesnt_change_mask = {'acos', 'asin', 'atan', 'cos', 'sin', 'sqrt', 'square', 'tan', 'unsqueeze'}
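The fallback idea can be sketched with a metaclass __getattr__, which Python calls only when normal attribute lookup on the class fails, so undefined names are forwarded to torch. This assumes the metaclass is attached to a MaskedTorch-like class. ToyTorchFallback, ToyMaskedTorch, and ToyMaskedTensor are hypothetical; returning a plain torch.Tensor for functions outside doesnt_change_mask, and reapplying shape-changing calls such as unsqueeze to the mask, are assumptions of this sketch rather than documented behavior.

```python
from dataclasses import dataclass

import torch


@dataclass
class ToyMaskedTensor:
    # Hypothetical stand-in for MaskedTensor: data plus a same-shaped boolean mask.
    tensor: torch.Tensor
    mask: torch.Tensor


class ToyTorchFallback(type):
    # Names copied from the documented doesnt_change_mask set.
    doesnt_change_mask = {"acos", "asin", "atan", "cos", "sin", "sqrt", "square", "tan", "unsqueeze"}

    def __getattr__(cls, name):
        # Reached only when `name` is not defined on the class itself.
        torch_fn = getattr(torch, name)  # raises AttributeError for unknown names

        def fallback(*args, **kwargs):
            if args and isinstance(args[0], ToyMaskedTensor):
                masked, rest = args[0], args[1:]
                result = torch_fn(masked.tensor, *rest, **kwargs)
                if name in cls.doesnt_change_mask:
                    mask = masked.mask
                    if result.shape != mask.shape:
                        # Assumption: shape-changing ops (e.g. unsqueeze) are
                        # applied to the mask too, keeping shapes aligned.
                        mask = torch_fn(mask, *rest, **kwargs)
                    return ToyMaskedTensor(result, mask)
                # Assumption: other functions drop the mask.
                return result
            # No masked input: behave exactly like torch.
            return torch_fn(*args, **kwargs)

        return fallback


class ToyMaskedTorch(metaclass=ToyTorchFallback):
    pass


x = ToyMaskedTensor(torch.tensor([4.0, 9.0]), torch.tensor([True, False]))
r = ToyMaskedTorch.sqrt(x)  # mask carried over unchanged
print(type(r).__name__)  # ToyMaskedTensor
```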