Module facetorch.utils

Expand source code
import omegaconf
import torch
import torchvision


def rgb2bgr(tensor: torch.Tensor) -> torch.Tensor:
    """Converts a batch of RGB tensors to BGR tensors or vice versa.

    The order of the three entries along dimension 1 is reversed, so applying
    the function twice restores the original channel order.

    Args:
        tensor (torch.Tensor): Batch of RGB (or BGR) channeled tensors
            with shape (dim0, channels, dim2, dim3). Any tensor with at least
            two dimensions and exactly 3 entries along dimension 1 is accepted.

    Returns:
        torch.Tensor: Batch of BGR (or RGB) tensors with the same shape as the
        input. The result is a new tensor produced by index selection.

    Raises:
        ValueError: If ``tensor`` has fewer than 2 dimensions or its
            dimension 1 does not have size 3.
    """
    # Explicit checks instead of ``assert``: asserts are stripped under
    # ``python -O``, which would let e.g. a 4-channel tensor through and
    # silently drop its last channel in the indexing below.
    if tensor.dim() < 2:
        raise ValueError(
            f"Tensor must have at least 2 dimensions, got {tensor.dim()}."
        )
    if tensor.shape[1] != 3:
        raise ValueError("Tensor must have 3 channels.")
    return tensor[:, [2, 1, 0]]


def fix_transform_list_attr(
    transform: torchvision.transforms.Compose,
) -> torchvision.transforms.Compose:
    """Fix the transform attributes by converting ListConfig attributes to lists.

    This makes it possible to optimize the transform using TorchScript. Every
    instance attribute of each transform in ``transform.transforms`` that is an
    ``omegaconf`` ``ListConfig`` is replaced, in place, by a plain Python
    container. The conversion is recursive, so nested ``ListConfig`` /
    ``DictConfig`` elements also become plain lists / dicts. Interpolations
    are resolved.

    Args:
        transform (torchvision.transforms.Compose): Transform to be fixed.
            Its sub-transforms are modified in place.

    Returns:
        torchvision.transforms.Compose: The same ``transform`` object, fixed.
    """
    for transform_x in transform.transforms:
        attrs = vars(transform_x)
        # Iterate over a snapshot so the attribute dict is never written to
        # while it is being iterated.
        for key, value in list(attrs.items()):
            if isinstance(value, omegaconf.listconfig.ListConfig):
                # Unlike ``list(value)``, ``to_container`` also converts nested
                # configs, which would otherwise still break TorchScript.
                attrs[key] = omegaconf.OmegaConf.to_container(value, resolve=True)
    return transform

Functions

def rgb2bgr(tensor: torch.Tensor) -> torch.Tensor

Converts a batch of RGB tensors to BGR tensors or vice versa.

Args

tensor : torch.Tensor
Batch of RGB (or BGR) channeled tensors with shape (dim0, channels, dim2, dim3).

Returns

torch.Tensor
Batch of BGR (or RGB) tensors with shape (dim0, channels, dim2, dim3).
Expand source code
def rgb2bgr(tensor: torch.Tensor) -> torch.Tensor:
    """Converts a batch of RGB tensors to BGR tensors or vice versa.

    The order of the three entries along dimension 1 is reversed, so applying
    the function twice restores the original channel order.

    Args:
        tensor (torch.Tensor): Batch of RGB (or BGR) channeled tensors
            with shape (dim0, channels, dim2, dim3). Any tensor with at least
            two dimensions and exactly 3 entries along dimension 1 is accepted.

    Returns:
        torch.Tensor: Batch of BGR (or RGB) tensors with the same shape as the
        input. The result is a new tensor produced by index selection.

    Raises:
        ValueError: If ``tensor`` has fewer than 2 dimensions or its
            dimension 1 does not have size 3.
    """
    # Explicit checks instead of ``assert``: asserts are stripped under
    # ``python -O``, which would let e.g. a 4-channel tensor through and
    # silently drop its last channel in the indexing below.
    if tensor.dim() < 2:
        raise ValueError(
            f"Tensor must have at least 2 dimensions, got {tensor.dim()}."
        )
    if tensor.shape[1] != 3:
        raise ValueError("Tensor must have 3 channels.")
    return tensor[:, [2, 1, 0]]
def fix_transform_list_attr(transform: torchvision.transforms.transforms.Compose) -> torchvision.transforms.transforms.Compose

Fix the transform attributes by converting ListConfig attributes to lists. This makes it possible to optimize the transform using TorchScript.

Args

transform : torchvision.transforms.Compose
Transform to be fixed.

Returns

torchvision.transforms.Compose
Fixed transform.
Expand source code
def fix_transform_list_attr(
    transform: torchvision.transforms.Compose,
) -> torchvision.transforms.Compose:
    """Fix the transform attributes by converting ListConfig attributes to lists.

    This makes it possible to optimize the transform using TorchScript. Every
    instance attribute of each transform in ``transform.transforms`` that is an
    ``omegaconf`` ``ListConfig`` is replaced, in place, by a plain Python
    container. The conversion is recursive, so nested ``ListConfig`` /
    ``DictConfig`` elements also become plain lists / dicts. Interpolations
    are resolved.

    Args:
        transform (torchvision.transforms.Compose): Transform to be fixed.
            Its sub-transforms are modified in place.

    Returns:
        torchvision.transforms.Compose: The same ``transform`` object, fixed.
    """
    for transform_x in transform.transforms:
        attrs = vars(transform_x)
        # Iterate over a snapshot so the attribute dict is never written to
        # while it is being iterated.
        for key, value in list(attrs.items()):
            if isinstance(value, omegaconf.listconfig.ListConfig):
                # Unlike ``list(value)``, ``to_container`` also converts nested
                # configs, which would otherwise still break TorchScript.
                attrs[key] = omegaconf.OmegaConf.to_container(value, resolve=True)
    return transform