Source code for tests.optical_flow_test

"""
Module for testing optical flow computations. Provides test cases for optical flow computation from pose landmarks data.
"""
import tempfile
from unittest import TestCase

from matplotlib.testing.compare import compare_images

from pose_format.numpy.representation.distance import DistanceRepresentation
from pose_format.pose import Pose
from pose_format.utils.optical_flow import OpticalFlowCalculator


[docs]class TestOpticalFlow(TestCase): """ Test cases for optical flow computation. """
[docs] def test_optical_flow(self): """ Tests optical flow from pose landmarks data and visualize the computed flow using a heatmap. Compares the generated heatmap with the reference image provided in `data/optical_flow.png` to validate the correctness of the omputed optical flow. Raises: AssertionError: If the computed optical flow visualization does not match the reference image. """ calculator = OpticalFlowCalculator(fps=30, distance=DistanceRepresentation()) with open('data/mediapipe.pose', 'rb') as f: pose = Pose.read(f.read()) pose = pose.get_components(["POSE_LANDMARKS", "RIGHT_HAND_LANDMARKS", "LEFT_HAND_LANDMARKS"]) flow = calculator(pose.body.data) flow = flow.squeeze(axis=1) print(flow.shape) # Matplotlib heatmap import matplotlib.pyplot as plt plt.imshow(flow.T) plt.tight_layout() fp = tempfile.NamedTemporaryFile() plt.savefig(fp.name, format='png') self.assertTrue(compare_images('data/optical_flow.png', fp.name, 0.001) is None)