Source code for pose_format.tensorflow.masked.tensorflow_test

from unittest import TestCase

import tensorflow as tf

from pose_format.tensorflow.masked.tensor import MaskedTensor
from pose_format.tensorflow.masked.tensorflow import MaskedTensorflow


[docs]class TestMaskedTensorflow(TestCase): """ A test class for checking the operations in the MaskedTensorflow class. """ # cat
[docs] def test_cat(self): """Test the `concat` concatination method for two tensors along axis=0.""" tensor1 = MaskedTensor(tf.constant([1, 2])) tensor2 = MaskedTensor(tf.constant([3, 4])) stack = MaskedTensorflow.concat([tensor1, tensor2], axis=0) res = MaskedTensor(tf.constant([[1, 2, 3, 4]])) self.assertTrue(tf.reduce_all(stack == res), msg="Cat is not equal to expected")
# stack
[docs] def test_stack(self): """ Test the `stack` method for two tensors along axis=0. """ tensor1 = MaskedTensor(tf.constant([1, 2])) tensor2 = MaskedTensor(tf.constant([3, 4])) stack = MaskedTensorflow.stack([tensor1, tensor2], axis=0) res = MaskedTensor(tf.constant([[1, 2], [3, 4]])) self.assertTrue(tf.reduce_all(stack == res), msg="Stack is not equal to expected")
# zeros
[docs] def test_zeros_tensor_shape(self): """ Test the shape of the tensor obtained from `zeros` method. """ zeros = MaskedTensorflow.zeros((1, 2, 3)) self.assertEqual(zeros.shape, (1, 2, 3))
[docs] def test_zeros_tensor_value(self): """ Test the values of the tensor obtained from `zeros` method. """ zeros = MaskedTensorflow.zeros((1, 2, 3)) self.assertTrue(tf.reduce_all(zeros == 0), msg="Zeros are not all zeros")
[docs] def test_zeros_tensor_type_float(self): """ Test dtype of tensor obtained from `zeros` method with dtype=tf.float32. """ dtype = tf.float32 zeros = MaskedTensorflow.zeros((1, 2, 3), dtype=dtype) self.assertEqual(zeros.tensor.dtype, dtype)
[docs] def test_zeros_tensor_type_bool(self): """Test the dtype of the tensor obtained from `zeros` method with dtype=tf.bool.""" dtype = tf.bool zeros = MaskedTensorflow.zeros((1, 2, 3), dtype=dtype) self.assertEqual(zeros.tensor.dtype, dtype)
[docs] def test_zeros_mask_value(self): """ Test the mask values of the tensor obtained from `zeros` method. """ zeros = MaskedTensorflow.zeros((1, 2, 3)) self.assertTrue(tf.reduce_all(zeros.mask == tf.zeros((1, 2, 3), dtype=tf.dtypes.bool)), msg="Zeros mask are not all zeros")
# Fallback
[docs] def test_not_implemented_method(self): """ Test a method that is not explicitly defined for MaskedTensor but is implemented via fallback. """ tensor = MaskedTensor(tensor=tf.constant([1, 2, 3])) tensor_square = MaskedTensorflow.square(tensor) self.assertTrue(tf.reduce_all(tensor_square == tf.constant([1, 4, 9])), msg="Square is not equal to expected")