Loss functions
The BA-LR toolkit provides several loss functions that can be used when training a Binary Attribute Encoder.
BaseAutoEncoderLoss
All loss functions which can be used with a BinaryAttributeEncoderTrainer must be subclasses of BaseAutoEncoderLoss.
class BaseAutoEncoderLoss(torch.nn.Module, ABC):
    def __init__(self, name: str, weight: torch.Tensor | None = None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        ...
A BaseAutoEncoderLoss function can have a weight parameter, which is applied to its value when computing the total loss during training (the total loss is the weighted sum of the losses computed by the trainer's loss functions).
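For illustration, this combination can be pictured as the following sketch (the variable names and the exact reduction are assumptions, not the trainer's actual code):

# Illustrative only: how the trainer conceptually combines the weighted losses.
# `losses` is the list of BaseAutoEncoderLoss instances; the tensors follow the
# forward signature described below.
total_loss = sum(
    (loss.weight if loss.weight is not None else 1.0) * loss(input, labels, output, recon, Z)
    for loss in losses
)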
Every BaseAutoEncoderLoss subclass must implement the forward method, which computes the loss from the following parameters.
def forward(
    self, input: torch.Tensor, labels: torch.Tensor, output: torch.Tensor, recon: torch.Tensor, Z: torch.Tensor,
) -> torch.Tensor:
Parameters:
- input: input features fed to the model (i.e. embeddings). A torch.Tensor of size (batch_size, embedding_dim).
- labels: labels associated with the input. A torch.Tensor of size (batch_size,).
- output: the model's encoded representation of the input. A torch.Tensor of size (batch_size, encoder_dim).
- recon: the input reconstructed by the model. A torch.Tensor of size (batch_size, embedding_dim).
- Z: the model's latent space representation of the input. A torch.Tensor of size (batch_size, encoder_dim).
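As a minimal sketch, a custom loss only needs to subclass BaseAutoEncoderLoss and implement forward with this signature. The class name, the import path and the choice of an L1 reconstruction error below are illustrative assumptions, not part of the toolkit:

import torch

from balr.losses import BaseAutoEncoderLoss  # hypothetical import path

class L1ReconstructionLoss(BaseAutoEncoderLoss):
    """Illustrative custom loss: mean absolute error between input and reconstruction."""

    def __init__(self, weight: torch.Tensor | None = None):
        super().__init__(name="l1_reconstruction", weight=weight)

    def forward(
        self, input: torch.Tensor, labels: torch.Tensor, output: torch.Tensor, recon: torch.Tensor, Z: torch.Tensor,
    ) -> torch.Tensor:
        # Only `input` and `recon` are used; the other arguments are part of the
        # common signature and are simply ignored by this loss.
        return torch.mean(torch.abs(recon - input))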
MSE Loss
An implementation of torch.nn.MSELoss.
Triplet Margin Loss
class TripletMarginLoss(BaseAutoEncoderLoss):
    def __init__(
        self, margin: float = 0.3, type_of_triplets: str = "all", weight: torch.Tensor | None = None,
    ):
Parameters:
- margin: the difference between the anchor-positive distance and the anchor-negative distance.
- type_of_triplets:
  - "all" selects all triplets that violate the margin.
  - "hard" selects the subset of "all" where the negative is closer to the anchor than the positive.
  - "semihard" selects the subset of "all" where the negative is further from the anchor than the positive.
  - "easy" selects all triplets that do not violate the margin.
An implementation of Triplet Margin Loss. The TripletMarginLoss uses a Triplet Margin Miner to compute all possible triplets within the batch based on the labels. Anchor-positive pairs are formed by embeddings that share the same label, and anchor-negative pairs are formed by embeddings with different labels. The miner then selects the positive and negative pairs that are particularly difficult, i.e. all triplets that violate the margin parameter (the difference between the anchor-positive distance and the anchor-negative distance).
Note
In order to be able to select triplets, each batch must contain at least 2 distinct classes (to form the anchor-negative pairs) and at least 2 samples per class (to form the anchor-positive pairs). Shuffling which classes are grouped together in a batch is also recommended, to avoid overfitting to specific pairs.
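A minimal sketch of instantiating the loss and of a batch layout that satisfies these constraints (the import path and the concrete values are illustrative assumptions):

import torch

from balr.losses import TripletMarginLoss  # hypothetical import path

# Instantiate with the documented parameters.
triplet_loss = TripletMarginLoss(margin=0.3, type_of_triplets="semihard", weight=torch.tensor(1.0))

# A batch layout suitable for triplet mining: at least 2 distinct classes and
# at least 2 samples per class, e.g. 2 speakers with 2 utterances each.
labels = torch.tensor([0, 0, 1, 1])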
ArcFace Loss
class ArcFaceLoss(BaseAutoEncoderLoss):
    def __init__(
        self, margin: float = 0.3, type_of_triplets: str = "all", weight: torch.Tensor | None = None,
    ):
        ...

    def setup(self, nb_train_classes: int, internal_dim: int):
Parameters:
- margin: the difference between the anchor-positive distance and the anchor-negative distance.
- type_of_triplets:
  - "all" selects all triplets that violate the margin.
  - "hard" selects the subset of "all" where the negative is closer to the anchor than the positive.
  - "semihard" selects the subset of "all" where the negative is further from the anchor than the positive.
  - "easy" selects all triplets that do not violate the margin.
- nb_train_classes: the number of unique classes in the training dataset.
- internal_dim: the encoder's dimension.
An implementation of ArcFace Loss. The ArcFaceLoss uses the same Triplet Margin Miner as the TripletMarginLoss to select triplets.
The loss's setup method must be called once the trainer is instantiated to provide additional parameters. The trainer calls setup on each loss function with named parameters that can be used to set up the loss.
Note
This loss's parameters must be passed to the optimizer during training.
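A minimal sketch of wiring this up manually (the import path, the stand-in encoder and the concrete values are illustrative assumptions; during normal training the trainer calls setup for you and you only extend the optimizer's parameter list):

import torch

from balr.losses import ArcFaceLoss  # hypothetical import path

model = torch.nn.Linear(192, 256)  # stand-in for the Binary Attribute Encoder

arcface_loss = ArcFaceLoss(margin=0.3, type_of_triplets="all")
# Provide the additional parameters once the dataset and encoder are known.
arcface_loss.setup(nb_train_classes=1211, internal_dim=256)

# ArcFace keeps trainable parameters, so they must be optimized together
# with the model's parameters.
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(arcface_loss.parameters()),
    lr=1e-3,
)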
Sparsity Loss
The goal of the sparsity loss is to push the model to encode binary representations modeled by shared discriminant attributes between speakers. An attribute is considered present in a speaker's profile if the sum of that attribute's activations over all of the speaker's utterances is non-zero, while its typicality is the frequency with which the attribute is present among speaker profiles. We can therefore regulate the activations of the latent space dimensions before binarization so that only a subset of speakers has a particular dimension present in their profile.

The sparsity loss drives the binary-attribute dimensions towards a desired presence frequency among speakers. In particular, to ensure the absence of an attribute in a speaker's profile, the sum of its activations across all of that speaker's utterance vectors must be driven to 0. The sparsity loss thus pushes each dimension to follow a specific sparsity while taking speakers into account.

Note
The sparsity loss requires each batch to contain N speakers and M samples per speaker.
The sparsity loss works with the latent space representations of the speaker embeddings from the encoder (Tanh activations ranging from -1 to 1; Z in the figure above). The vector V representing the desired presence frequency of the attributes is generated randomly when the loss is initialized.
class SparsityLoss(BaseAutoEncoderLoss):
    def __init__(self, weight: torch.Tensor | None = None):
        ...

    def setup(self, M_samples_per_class: int, internal_dim: int):
The SparsityLoss class does not take parameters at initialization (other than its weight), but its setup method must be called once the trainer is instantiated. The trainer calls setup on each loss function with named parameters that can be used to set up the loss.
Parameters:
- M_samples_per_class: the number of samples per speaker in each batch.
- internal_dim: the encoder's dimension.
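The following is a conceptual sketch of the idea rather than the toolkit's actual formula: per-speaker attribute presence is derived from the summed activations over each speaker's M utterances, and the resulting presence frequencies are pushed towards the target vector V. All names, the soft presence score and the penalty form are assumptions for illustration:

import torch

def sparsity_penalty_sketch(Z: torch.Tensor, V: torch.Tensor, M_samples_per_class: int) -> torch.Tensor:
    # Illustrative only. Z: (N_speakers * M_samples_per_class, internal_dim) latent
    # activations in [-1, 1], ordered speaker by speaker; V: (internal_dim,) desired
    # presence frequencies.
    internal_dim = Z.shape[-1]
    # Group utterances by speaker: (N_speakers, M_samples_per_class, internal_dim).
    per_speaker = Z.view(-1, M_samples_per_class, internal_dim)
    # Sum each attribute's activations over a speaker's utterances and squash the
    # sum into a soft "presence" score in (0, 1).
    presence = torch.sigmoid(per_speaker.sum(dim=1))
    # Fraction of speakers for which each attribute is (softly) present.
    observed_freq = presence.mean(dim=0)
    # Penalize the gap between observed and desired presence frequencies.
    return torch.mean((observed_freq - V) ** 2)

For example, with N = 8 speakers, M_samples_per_class = 4 and internal_dim = 512, Z has shape (32, 512) and V has shape (512,).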
CLI
The losses used during training can be set by overriding the trainer.losses parameter.
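For illustration only, the kind of loss combination such an override could describe might look as follows in Python; the class name MSELoss, the import path and the idea that trainer.losses accepts loss instances are assumptions, only the constructors documented above are given by the toolkit:

import torch

from balr.losses import MSELoss, SparsityLoss, TripletMarginLoss  # hypothetical import path

losses = [
    MSELoss(weight=torch.tensor(1.0)),
    SparsityLoss(weight=torch.tensor(0.5)),
    TripletMarginLoss(margin=0.3, type_of_triplets="all", weight=torch.tensor(1.0)),
]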