Source code for pose_format.torch.representation.inner_angle

import torch
from torch import nn

from ..masked.tensor import MaskedTensor
from ..masked.torch import MaskedTorch


[docs]def get_vectors_norm(vectors: MaskedTensor): """ Computes the normalized vectors from the given masked vectors. Parameters ---------- vectors : :class:`~pose_format.torch.masked.tensor.MaskedTensor` The input masked vectors with any shape. Returns ------- :class:`~pose_format.torch.masked.tensor.MaskedTensor` The normalized masked vectors of the same shape as the input. Notes ----- The function squares the input vectors, then sums along the last dimension. Taking the square root of the sum provides the magnitude. The original vectors are then divided by this magnitude to normalize. """ square = MaskedTorch.square(vectors) summed = square.sum(dim=-1) v_mag = MaskedTorch.sqrt(summed) mag_stack = MaskedTorch.stack([v_mag] * vectors.shape[-1], dim=-1) return vectors.div(mag_stack)
[docs]class InnerAngleRepresentation(nn.Module): """ A neural network module to compute the inner angle at a point for a triangle. """
[docs] def forward(self, p1s: MaskedTensor, p2s: MaskedTensor, p3s: MaskedTensor) -> torch.Tensor: """ Computes the angle in point `p2s` for the triangle defined by the points <p1s, p2s, p3s>. Parameters ---------- p1s : :class:`~pose_format.torch.masked.tensor.MaskedTensor` A tensor representing the first set of points, with shape (Points, Batch, Len, Dims). p2s : :class:`~pose_format.torch.masked.tensor.MaskedTensor` A tensor representing the second set of points (at which the angle is calculated), with shape (Points, Batch, Len, Dims). p3s : :class:`~pose_format.torch.masked.tensor.MaskedTensor` A tensor representing the third set of points, with shape (Points, Batch, Len, Dims). Returns ------- torch.Tensor A tensor representing the computed angles at point `p2s`, with shape (Points, Batch, Len). Note ---- The method is based on the approach suggested in: https://stackoverflow.com/questions/19729831/angle-between-3-points-in-3d-space The function first computes the vectors v1 and v2 by subtracting points p1s and p3s from p2s, respectively. The vectors are then normalized. The angle is calculated by finding the arccosine of the dot product of the normalized vectors. NaN values in the resulting tensor are set to zero. """ # Following https://stackoverflow.com/questions/19729831/angle-between-3-points-in-3d-space v1 = p1s - p2s # (Points, Batch, Len, Dims) v2 = p3s - p2s # (Points, Batch, Len, Dims) v1_norm = get_vectors_norm(v1) v2_norm = get_vectors_norm(v2) slopes = (v1_norm * v2_norm).sum(dim=-1) angles = MaskedTorch.acos(slopes) angles = angles.zero_filled() angles[angles != angles] = 0 # Fix NaN, TODO think of faster way return angles