# Source code for pose_format.tensorflow.pose_representation

from typing import List

import tensorflow as tf

from ..pose_representation import PoseRepresentation


class TensorflowPoseRepresentation(PoseRepresentation):
    """
    Class for pose representations using TensorFlow tensors.

    * Inherits from ``PoseRepresentation``

    This class extends ``PoseRepresentation`` and provides methods for
    manipulating pose representations using TensorFlow tensors.
    """

    def group_embeds(self, embeds: List[tf.Tensor]) -> tf.Tensor:
        """
        Group embeddings (list of tensors) along the first dimension.

        Parameters
        ----------
        embeds : List[tf.Tensor]
            List of tensors, each with shape (embed_size, Batch, Len).

        Returns
        -------
        tf.Tensor
            Tensor with shape (Batch, Len, embed_size), where embed_size is
            the sum of the first-dimension sizes of the inputs.
        """
        # Concatenate on the leading (embedding) axis:
        # -> (total_embed_size, Batch, Len)
        group = tf.concat(embeds, axis=0)
        # Move the embedding axis last: -> (Batch, Len, total_embed_size)
        return tf.transpose(group, perm=[1, 2, 0])

    def get_points(self, tensor: tf.Tensor, points: List[int]) -> tf.Tensor:
        """
        Get specific points from a tensor.

        Parameters
        ----------
        tensor : tf.Tensor
            Tensor to index.
        points : List[int]
            Indices/points needed from the tensor, taken along its first axis.

        Returns
        -------
        tf.Tensor
            Values gathered from ``tensor`` at the given indices/points.
        """
        # tf.gather selects along axis 0 when no axis is given.
        return tf.gather(tensor, points)

    def permute(self, src: tf.Tensor, shape: tuple) -> tf.Tensor:
        """
        Permute dimensions of a tensor according to a given axis order.

        Parameters
        ----------
        src : tf.Tensor
            Tensor to permute.
        shape : tuple
            Permutation of ``src``'s axes (an axis order, not a target
            shape): output axis ``i`` is input axis ``shape[i]``.

        Returns
        -------
        tf.Tensor
            The permuted tensor.
        """
        return tf.transpose(src, perm=shape)