from unittest import TestCase
import numpy as np
from numpy import ma
from pose_format.pose import Pose
from pose_format.pose_header import PoseNormalizationInfo
from pose_format.utils.normalization_3d import PoseNormalizer
class Test3DNormalization(TestCase):
    """
    Test cases for 3D Normalization of pose data.

    The tests cover the plane-normal computation, vector rotation, and the
    end-to-end normalization of hand landmarks against stored gold values.
    """

    # (input, expected output) pairs of 21 hand landmarks each (x, y, z).
    # Normalization results taken from sign translate.
    NORMALIZATIONS = [
        [
            [
                [248.66639614105225, 297.9550790786743, -0.08617399376817048],
                [313.59346628189087, 319.78081941604614, 62.890407741069794],
                [349.8637819290161, 310.3189992904663, 104.9409145116806],
                [356.05104207992554, 306.42510890960693, 139.85478281974792],
                [354.0298342704773, 308.7967085838318, 171.05390071868896],
                [368.4513807296753, 211.83621287345886, 118.45286011695862],
                [391.60820722579956, 184.4913125038147, 179.64956641197205],
                [408.85960578918457, 166.19417786598206, 214.12095069885254],
                [418.72357845306396, 153.56018781661987, 241.38201117515564],
                [326.7276906967163, 189.34702157974243, 121.67787551879883],
                [334.95264768600464, 154.9160075187683, 211.92330718040466],
                [338.9851450920105, 145.05938053131104, 276.44052267074585],
                [340.1783061027527, 143.3175253868103, 319.86109495162964],
                [281.5382671356201, 184.43811058998108, 123.64041864871979],
                [288.3403015136719, 189.4317603111267, 214.801185131073],
                [288.1521964073181, 213.1494927406311, 234.93002772331238],
                [285.52247524261475, 227.414470911026, 237.48689651489258],
                [242.2289800643921, 187.8385078907013, 128.50980401039124],
                [246.59330368041992, 202.5400733947754, 196.4439833164215],
                [247.15078473091125, 223.4764289855957, 208.9101231098175],
                [244.32156801223755, 234.57125186920166, 211.09885573387146],
            ],
            [
                [0.0, 0.0, 0.0],
                [-53.87460867272204, -66.96863447520207, 56.796408965708935],
                [-73.72715283958682, -123.0910500437069, 75.20279417200432],
                [-67.90506838734477, -155.96881391642683, 96.28995223458003],
                [-57.40134100627323, -178.19614573765904, 120.84104646210801],
                [-50.965779993758964, -202.95828207912479, -1.509903313490213e-14],
                [-45.260072782063325, -278.98544696979855, 18.405850951171615],
                [-45.1386394307463, -325.57145063545363, 25.90184983567872],
                [-41.976756680249935, -359.5037258132442, 33.766177611013816],
                [-1.0464064364675199e-13, -199.6193245295997, -12.333907505180624],
                [31.975990322228963, -295.8878599646533, 22.44664078562626],
                [51.10130687673754, -354.78374037306935, 59.9347720675275],
                [63.63850727086251, -390.7021194588825, 89.44223523582745],
                [47.39541430043801, -182.86125824655022, -10.404736003411728],
                [66.17638974840436, -254.99061010637854, 58.3901382417529],
                [63.46550446035822, -256.24342044957825, 92.80607773270395],
                [61.464656345964585, -248.2688837751719, 106.9255087795864],
                [86.67900485407291, -166.0742098923334, -1.3322676295501878e-14],
                [97.22986231282657, -212.75611112692522, 60.58942546045143],
                [92.521597827552, -210.01652471672483, 87.09737972755666],
                [91.80539351622573, -203.60235428380585, 98.30312193502695],
            ],
        ],
        [
            [
                [266.67819833755493, 322.85115814208984, -0.04363307822495699],
                [220.46386861801147, 237.49959421157837, -6.432308048009872],
                [212.651225566864, 175.74850010871887, 27.4957894384861],
                [253.3728199005127, 150.05188155174255, 57.6715869307518],
                [290.9366660118103, 158.14260911941528, 97.49647510051727],
                [196.53551816940308, 183.21415996551514, 117.52972221374512],
                [180.0925350189209, 143.85282921791077, 150.1890881061554],
                [170.33970069885254, 121.17275738716125, 186.817076921463],
                [165.10530018806458, 101.94312655925751, 224.22807431221008],
                [232.6294367313385, 191.9170138835907, 124.57139778137207],
                [241.98445081710815, 147.45505690574646, 172.8555293083191],
                [247.0248293876648, 124.05966544151306, 195.77275156974792],
                [248.24022006988525, 114.18103969097137, 227.97385263442993],
                [267.1261944770813, 208.61302876472473, 117.41860234737396],
                [277.41753005981445, 168.1429328918457, 110.84358942508698],
                [274.5585584640503, 183.77042746543884, 73.29665148258209],
                [267.15885734558105, 202.62700247764587, 68.34884583950043],
                [300.5849003791809, 226.72553181648254, 112.3670688867569],
                [307.9186568260193, 195.0516927242279, 104.11798548698425],
                [297.57420539855957, 204.44280195236206, 80.27082800865173],
                [286.06309032440186, 217.2321858406067, 73.97288680076599],
            ],
            [
                [0.0, 0.0, 0.0],
                [-54.479260198097805, -71.11846055086438, 58.17594848846076],
                [-54.39223415540774, -145.6437901306979, 78.54576783252529],
                [-3.545339962524018, -179.63888643031996, 87.61257240473797],
                [49.14794301931342, -194.95344154721718, 62.06021751253273],
                [-40.09854676422485, -208.82027169070966, 5.551115123125783e-16],
                [-48.13049962464756, -266.6680014062149, 1.3281840284453539],
                [-47.00941869901266, -313.07001717228695, -11.447169097370368],
                [-40.773818965790895, -356.4530129101176, -26.203968571403916],
                [9.134429734859657e-15, -199.99294837298712, -1.6794645217193471],
                [23.603763878904314, -268.1038602115065, -0.7067006484911872],
                [35.301688782987696, -302.1004761871728, 1.720669069409734],
                [46.93353477547328, -333.09716362043764, -14.541999919797409],
                [34.07355106941323, -174.87454913669825, 0.6581788197466907],
                [40.2663401403231, -199.59555645909575, 40.242655464274804],
                [25.438899509841193, -160.53426324779335, 55.575838830821],
                [17.1508731893172, -143.6797255930755, 42.43293329268436],
                [67.86518085896179, -150.39315561348292, 9.103828801926284e-15],
                [70.91657215646, -167.61675658262482, 33.11742629384811],
                [52.6709032860349, -144.92204539941508, 40.80547628453034],
                [39.35404239471322, -132.63262286274133, 32.23603583001071],
            ],
        ],
        [
            [
                [2.60705627e2, 3.37643646e2, -9.28593054e-2],
                [2.10784332e2, 3.16998627e2, -8.27554855e1],
                [1.93835342e2, 2.66705475e2, -1.24799034e2],
                [2.35526871e2, 2.17327255e2, -1.48760818e2],
                [2.75498596e2, 1.88349747e2, -1.49933273e2],
                [1.71491211e2, 1.91392395e2, -5.14710464e1],
                [1.61075333e2, 1.05774673e2, -1.07832794e2],
                [1.63360901e2, 5.55294151e1, -1.40600769e2],
                [1.73754684e2, 1.11177626e1, -1.5721608e2],
                [2.14831955e2, 1.99762024e2, -4.1304863e1],
                [2.44172714e2, 1.61405777e2, -1.30912308e2],
                [2.38794632e2, 2.27534744e2, -1.27142723e2],
                [2.19911545e2, 2.51557541e2, -8.53535767e1],
                [2.56865509e2, 2.13906311e2, -4.56171799e1],
                [2.84798065e2, 1.86186142e2, -1.3756131e2],
                [2.6911795e2, 2.45126968e2, -1.312836e2],
                [2.4843277e2, 2.59336426e2, -9.2029335e1],
                [2.92065948e2, 2.32734634e2, -5.8856884e1],
                [3.13212036e2, 2.08028992e2, -1.22773369e2],
                [2.95430847e2, 2.51139603e2, -1.24858963e2],
                [2.76066742e2, 2.66142792e2, -9.7968811e1],
            ],
            [
                [0.0, 0.0, 0.0],
                [-2.05852102e1, -8.43542239e1, 9.97880671e1],
                [-9.5163617, -1.69808429e2, 1.25001397e2],
                [6.54616798e1, -2.20523798e2, 1.13256765e2],
                [1.23026573e2, -2.37556703e2, 8.68171243e1],
                [-4.64094101e1, -2.32186465e2, -1.82238355e-6],
                [-1.13881838e1, -3.62578954e2, 1.98150407e1],
                [1.93110838e1, -4.35353543e2, 2.87847134e1],
                [5.1511044e1, -4.90263213e2, 1.97556152e1],
                [2.06700921e-14, -1.9909322e2, -1.90234006e1],
                [8.4304534e1, -2.73281319e2, 5.70360486e1],
                [5.75141926e1, -1.97081808e2, 9.26397117e1],
                [1.01073372e1, -1.57595191e2, 6.21706932e1],
                [4.92581982e1, -1.66689775e2, -1.7115462e1],
                [1.29823974e2, -2.30243201e2, 6.83153915e1],
                [9.14003112e1, -1.65618917e2, 9.95540985e1],
                [4.56163097e1, -1.39487004e2, 6.68586078e1],
                [9.25816015e1, -1.35992949e2, 2.44384259e-6],
                [1.52234167e2, -1.85738665e2, 5.57429433e1],
                [1.19207937e2, -1.44327761e2, 8.82483889e1],
                [8.00046381e1, -1.22540409e2, 7.03460515e1],
            ],
        ],
    ]

    def test_normal(self) -> None:
        """
        Test the calculation of the normal vector for given 3 points on a plane.

        Note
        ----
        See the description and calculations on `link`_.

        Example (Plane Equation Example revisited) Given,
        P = (1, 1, 1), Q = (1, 2, 0), R = (-1, 2, 1).
        The normal vector A is the cross product (Q - P) x (R - P) = (1, 2, 2)

        .. _link: https://sites.math.washington.edu/~king/coursedir/m445w04/notes/vector/normals-planes.html#:~:text=Thus%20for%20a%20plane%20(or,4%2B4)%20%3D%203.
        """
        p1 = (1, 1, 1)
        p2 = (1, 2, 0)
        p3 = (-1, 2, 1)
        gold_normal = (1, 2, 2)

        # The plane info holds indices into the points axis of the tensor below.
        plane = PoseNormalizationInfo(p1=0, p2=1, p3=2)
        normalizer = PoseNormalizer(plane=plane, line=None)

        tensor = ma.array([[p1, p2, p3]], dtype=np.float32)
        normal, _ = normalizer.get_normal(tensor)

        # The result is compared against the unit-length version of the gold normal.
        gold_vec = ma.array(gold_normal) / np.linalg.norm(gold_normal)
        self.assertTrue(
            ma.allclose(normal, gold_vec),
            msg=f"normal {normal} does not match expected unit normal {gold_vec}",
        )

    def test_rotate_vector_by_90_degrees(self) -> None:
        """
        Test the rotation of a vector by 90 degrees.

        The two unit vectors (1, 0, 0) and (0, 1, 0) are expected to map to
        (0, -1, 0) and (1, 0, 0) respectively; the z component stays 0
        (presumably a rotation about the z axis -- confirm against
        ``PoseNormalizer.rotate``).
        """
        plane = PoseNormalizationInfo(p1=0, p2=1, p3=2)
        line = PoseNormalizationInfo(p1=0, p2=1)
        normalizer = PoseNormalizer(plane, line)

        vector = ma.array([[[1, 0, 0]], [[0, 1, 0]]], dtype=float)  # Shape (2, 1, 3)
        rotated_vector = normalizer.rotate(vector, np.array(90))
        expected_rotated_vector = np.array([[[0, -1, 0]], [[1, 0, 0]]], dtype=float)

        # Use a TestCase assertion rather than a bare ``assert`` so the check
        # is not stripped when Python runs with -O; the shapes that used to be
        # printed are reported in the failure message instead.
        self.assertTrue(
            np.allclose(rotated_vector, expected_rotated_vector, atol=1e-5),
            msg=(
                f"rotated_vector shape {rotated_vector.shape}, "
                f"expected_rotated_vector shape {expected_rotated_vector.shape}; "
                f"rotated values:\n{rotated_vector}"
            ),
        )

    def test_hand_normalization(self) -> None:
        """
        Test the normalization of hand pose data using the PoseNormalizer.

        Reads ``data/mediapipe.pose``, normalizes its right-hand landmarks and
        compares the result with ``data/mediapipe_hand_normalized.pose``.
        Both paths are relative, so they resolve against the current working
        directory of the test run.
        """
        with open('data/mediapipe.pose', 'rb') as f:
            pose = Pose.read(f.read())
        pose = pose.get_components(["RIGHT_HAND_LANDMARKS"])

        plane = pose.header.normalization_info(p1=("RIGHT_HAND_LANDMARKS", "WRIST"),
                                               p2=("RIGHT_HAND_LANDMARKS", "PINKY_MCP"),
                                               p3=("RIGHT_HAND_LANDMARKS", "INDEX_FINGER_MCP"))
        line = pose.header.normalization_info(p1=("RIGHT_HAND_LANDMARKS", "WRIST"),
                                              p2=("RIGHT_HAND_LANDMARKS", "MIDDLE_FINGER_MCP"))
        normalizer = PoseNormalizer(plane=plane, line=line, size=100)

        pose.body.data = normalizer(pose.body.data)
        pose.focus()

        with open('data/mediapipe_hand_normalized.pose', 'rb') as f:
            pose_gold = Pose.read(f.read())

        self.assertTrue(
            ma.allclose(pose.body.data, pose_gold.body.data),
            msg="normalized hand differs from data/mediapipe_hand_normalized.pose",
        )

    def test_hand_normalization_results(self) -> None:
        """
        Test the normalization results of hand data taken from sign translate.

        Every (input, expected output) pair in ``NORMALIZATIONS`` is reshaped
        to (1, 1, 21, 3), normalized, and compared with an absolute tolerance
        of 1e-4.
        """
        plane = PoseNormalizationInfo(p1=0, p2=17, p3=5)
        line = PoseNormalizationInfo(p1=0, p2=9)
        normalizer = PoseNormalizer(plane=plane, line=line, size=200)

        for index, (hand_input, hand_output) in enumerate(self.NORMALIZATIONS):
            # subTest reports which sample failed and keeps checking the rest.
            with self.subTest(hand=index):
                hand_input_np = ma.array(hand_input).reshape((1, 1, 21, 3))
                hand_output_np = ma.array(hand_output).reshape((1, 1, 21, 3))
                tensor = normalizer(hand_input_np)
                self.assertTrue(
                    ma.allclose(tensor, hand_output_np, atol=1e-4),
                    msg=(
                        f"first two points, tensor:\n{tensor[0, 0, :2]}\n"
                        f"gold:\n{hand_output_np[0, 0, :2]}"
                    ),
                )