
Binary-Attribute Encoder training

The BA-LR toolkit provides a BinaryAttributeEncoderTrainer class that trains an AutoEncoder to extract binary-attribute representations from embedding vectors.

BinaryAttributeEncoderTrainer

class BinaryAttributeEncoderTrainer:
    def __init__(
        self,
        train: AudioDataset,
        val: AudioDataset,
        data_config: DataConfig,
        loss_funcs: list[BaseAutoEncoderLoss],
        input_dim: int = 256,
        internal_dim: int = 512,
        learning_rate: float = 0.001,
        epochs: int = 100,
        seed: int = 1234,
        save_dir: Path | None = None,
        save_period: int = 0,
        log_period: int = 2,
        val_period: int = 10,
        device: str | torch.device = "cpu",
        **kwargs,
    ):

Parameters:

  • train: the training dataset.
  • val: the validation dataset.
  • data_config: config parameters for the Dataloader used with the datasets.
  • loss_funcs: a list of loss functions to use during training and validation.
  • input_dim: the input dimension for the AutoEncoder (i.e. the embedding dimension).
  • internal_dim: the internal or output dimension for the AutoEncoder (i.e. the dimension of the binary-attribute representation).
  • learning_rate: the learning rate.
  • epochs: the number of epochs for training.
  • seed: the seed for random functions.
  • save_dir: directory where the training output (logs, metrics and model checkpoints) is saved.
  • save_period: save a model checkpoint every x epochs.
  • log_period: log metrics every x epochs.
  • val_period: run validation every x epochs.
  • device: the device to use the model on.
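
The three period parameters describe actions that repeat every x epochs. As a rough illustration of how such a schedule could be evaluated, here is a minimal sketch. The helper `is_due` is hypothetical, and two conventions in it are assumptions rather than documented behavior: epochs are counted from 1, and a period of 0 (the default for save_period) disables the action.

```python
def is_due(epoch: int, period: int) -> bool:
    """Return True when a periodic action should run at this epoch.

    Assumptions (not confirmed by the toolkit): epochs are counted from 1,
    and a period of 0 or less disables the action.
    """
    return period > 0 and epoch % period == 0


# Default values taken from the constructor signature above.
epochs, save_period, log_period, val_period = 100, 0, 2, 10

log_epochs = [e for e in range(1, epochs + 1) if is_due(e, log_period)]
val_epochs = [e for e in range(1, epochs + 1) if is_due(e, val_period)]
save_epochs = [e for e in range(1, epochs + 1) if is_due(e, save_period)]

print(len(log_epochs), len(val_epochs), len(save_epochs))  # 50 10 0
```

Under these assumptions, the defaults would log 50 times, validate 10 times, and never write a periodic checkpoint over 100 epochs.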

Training

Once a BinaryAttributeEncoderTrainer has been instantiated with the appropriate parameters, calling its train method runs training on the training dataset for the configured number of epochs.

CLI

The BA-LR CLI provides a train command to train a BinaryAttributeEncoder using the BinaryAttributeEncoderTrainer.

Parameters:

  • train: the path to the training dataset.
  • val: the path to the validation dataset.
  • save_dir: directory where the training output (logs, metrics and model checkpoints) is saved. By default, results are saved to ./runs/trainX, where X is incremented as needed (train2, train3, etc. on successive runs).
  • device: the device to use the model on.
  • overrides: optional hydra config overrides.
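
The default save_dir naming described above (train, then train2, train3, and so on) can be pictured with a short sketch. This is an illustration of the naming scheme only, not the toolkit's implementation; the function `next_run_dir` is a hypothetical name, and the example works in a temporary directory instead of ./runs.

```python
import tempfile
from pathlib import Path


def next_run_dir(root: Path, prefix: str = "train") -> Path:
    """Return the first unused directory among train, train2, train3, ..."""
    candidate = root / prefix
    suffix = 2
    while candidate.exists():
        candidate = root / f"{prefix}{suffix}"
        suffix += 1
    return candidate


# Simulate three successive runs inside a throwaway directory.
with tempfile.TemporaryDirectory() as tmp:
    runs = Path(tmp) / "runs"
    created = []
    for _ in range(3):
        run_dir = next_run_dir(runs)
        run_dir.mkdir(parents=True)
        created.append(run_dir.name)

print(created)  # ['train', 'train2', 'train3']
```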

Warning

Both the training and validation datasets must provide embeddings for all their samples. Since training only needs embeddings, you can train on datasets that provide embeddings alone (i.e. without audio files).

Example

balr train resources/data/voxceleb2/train.csv resources/data/voxceleb2/test.csv

will train a BinaryAttributeAutoEncoder model on the voxceleb2/train.csv dataset and use the voxceleb2/test.csv dataset for validation. Results will be saved by default to the ./runs/train directory. Training will run by default on the cpu device.

Example

balr train --save-dir training_output --device cuda resources/data/voxceleb2/train.csv resources/data/voxceleb2/test.csv trainer.epochs=10 'trainer.losses=[mse, arcface]'

This more complex command runs training on the same datasets, but specifies the device (cuda) and the output directory, and uses hydra overrides to change the trainer's epochs and losses config parameters.
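
To make the effect of the two overrides concrete, the sketch below applies dotted key=value strings to a nested dictionary. It is a deliberately simplified stand-in for illustration, not Hydra's actual override parser, and it only handles integers, plain strings and flat lists. The starting config is hypothetical; the toolkit's real default config is not shown on this page.

```python
def apply_override(config: dict, override: str) -> None:
    """Apply one 'dotted.key=value' override to a nested dict, in place."""
    key, _, raw = override.partition("=")
    *parents, leaf = key.split(".")
    node = config
    for name in parents:
        node = node.setdefault(name, {})
    node[leaf] = parse_value(raw)


def parse_value(raw: str):
    """Parse ints and flat bracketed lists; leave everything else as a string."""
    raw = raw.strip()
    if raw.startswith("[") and raw.endswith("]"):
        return [parse_value(item) for item in raw[1:-1].split(",") if item.strip()]
    try:
        return int(raw)
    except ValueError:
        return raw


# Hypothetical starting config, used only for this illustration.
config = {"trainer": {"epochs": 100, "losses": ["mse"]}}
for override in ["trainer.epochs=10", "trainer.losses=[mse, arcface]"]:
    apply_override(config, override)

print(config)  # {'trainer': {'epochs': 10, 'losses': ['mse', 'arcface']}}
```

Note that the losses override is quoted on the command line because it contains brackets and a space, which the shell would otherwise interpret or split.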

Distributed training

Training can be run on multiple GPUs with PyTorch's torchrun launcher script.

Note

When using torchrun, you must call the CLI's balr/cli/main.py module instead of the balr command.

Example

torchrun --nproc_per_node 2 balr/cli/main.py train resources/data/voxceleb2/train.csv resources/data/voxceleb2/test.csv --device cuda:0,1

will run training on two GPUs. The samplers used with the dataloaders must support distributed sampling.
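
To illustrate what distributed sampling means here, the sketch below splits dataset indices across processes in the same strided way as PyTorch's DistributedSampler does when shuffling is off and no samples are dropped. It is written in plain Python for clarity; `shard_indices` is a hypothetical helper, and whether the toolkit's own samplers follow exactly this scheme is an assumption. torchrun exports the RANK and WORLD_SIZE environment variables to each process it launches.

```python
import math
import os


def shard_indices(num_samples: int, rank: int, world_size: int) -> list[int]:
    """Return the sample indices one process handles, without shuffling."""
    per_rank = math.ceil(num_samples / world_size)
    total_size = per_rank * world_size
    # Wrap around to the first indices so every rank gets the same count.
    indices = [i % num_samples for i in range(total_size)]
    return indices[rank:total_size:world_size]


# torchrun sets RANK and WORLD_SIZE; the fallbacks cover a plain launch.
rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
print(rank, world_size, shard_indices(5, rank, world_size))

# Two processes sharing five samples: one index is repeated as padding.
shards = [shard_indices(5, r, 2) for r in range(2)]
print(shards)  # [[0, 2, 4], [1, 3, 0]]
```

Each process sees a disjoint slice of the data (apart from padding), which is why a sampler that ignores rank would make every GPU train on identical batches.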