Source code for pose_format.torch.masked.torch_test

from unittest import TestCase

import torch

from pose_format.torch.masked.tensor import MaskedTensor
from pose_format.torch.masked.torch import MaskedTorch


[docs]class TestMaskedTorch(TestCase): """Test cases for the :class:`~pose_format.torch.masked.tensor.MaskedTensor` class """ # cat
[docs] def test_cat(self): """Test `cat` method for concatenating :class:`~pose_format.torch.masked.tensor.MaskedTensor` objects along a specified dimension.""" tensor1 = MaskedTensor(torch.tensor([1, 2])) tensor2 = MaskedTensor(torch.tensor([3, 4])) stack = MaskedTorch.cat([tensor1, tensor2], dim=0) res = MaskedTensor(torch.tensor([[1, 2, 3, 4]])) self.assertTrue(torch.all(stack == res), msg="Cat is not equal to expected")
# stack
[docs] def test_stack(self): """Tests `stack` method for stacking :class:`~pose_format.torch.masked.tensor.MaskedTensor` objects along a new dimension.""" tensor1 = MaskedTensor(torch.tensor([1, 2])) tensor2 = MaskedTensor(torch.tensor([3, 4])) stack = MaskedTorch.stack([tensor1, tensor2], dim=0) res = MaskedTensor(torch.tensor([[1, 2], [3, 4]])) self.assertTrue(torch.all(stack == res), msg="Stack is not equal to expected")
# zeros
[docs] def test_zeros_tensor_shape(self): """Test if `zeros` method correctly produces a :class:`~pose_format.torch.masked.tensor.MaskedTensor` with the desired shape.""" zeros = MaskedTorch.zeros(1, 2, 3) self.assertEqual(zeros.shape, (1, 2, 3))
[docs] def test_zeros_tensor_value(self): """Test if the `zeros` method produces a :class:`~pose_format.torch.masked.tensor.MaskedTensor` with all zero values.""" zeros = MaskedTorch.zeros(1, 2, 3) self.assertTrue(torch.all(zeros == 0), msg="Zeros are not all zeros")
[docs] def test_zeros_tensor_type_float(self): """Test if the `zeros` method produces a :class:`~pose_format.torch.masked.tensor.MaskedTensor` with the correct float data type.""" dtype = torch.float zeros = MaskedTorch.zeros(1, 2, 3, dtype=dtype) self.assertEqual(zeros.tensor.dtype, dtype)
[docs] def test_zeros_tensor_type_bool(self): """Test if the `zeros` method produces a :class:`~pose_format.torch.masked.tensor.MaskedTensor` with the correct boolean data type.""" dtype = torch.bool zeros = MaskedTorch.zeros(1, 2, 3, dtype=dtype) self.assertEqual(zeros.tensor.dtype, dtype)
[docs] def test_zeros_mask_value(self): """Test if the mask in the produced `zeros` :class:`~pose_format.torch.masked.tensor.MaskedTensor` is initialized with zero values.""" zeros = MaskedTorch.zeros(1, 2, 3) self.assertTrue(torch.all(zeros.mask == 0), msg="Zeros mask are not all zeros")
# Fallback
[docs] def test_not_implemented_method(self): """Tests behavior when invoking an unimplemented method on a :class:`~pose_format.torch.masked.tensor.MaskedTensor`.""" tensor = MaskedTensor(tensor=torch.tensor([1, 2, 3])) torch_sum = MaskedTorch.sum(tensor) self.assertEqual(torch_sum, torch.tensor(6))