spuco.robust_train

Robust Training Methods.

ERM

class ERM(model: Module, trainset: Dataset, batch_size: int, optimizer: Optimizer, num_epochs: int, criterion=CrossEntropyLoss(), device: device = device(type='cpu'), lr_scheduler=None, max_grad_norm=None, val_evaluator: Evaluator | None = None, verbose=False)

Bases: BaseRobustTrain

Empirical Risk Minimization (ERM) Trainer.

__init__(model: Module, trainset: Dataset, batch_size: int, optimizer: Optimizer, num_epochs: int, criterion=CrossEntropyLoss(), device: device = device(type='cpu'), lr_scheduler=None, max_grad_norm=None, val_evaluator: Evaluator | None = None, verbose=False)

Initializes an ERM instance.

Parameters:
  • model (nn.Module) – The neural network model to train.

  • trainset (Dataset) – The dataset to use for training.

  • batch_size (int) – The batch size to use during training.

  • optimizer (torch.optim.Optimizer) – The optimizer to use during training.

  • num_epochs (int) – The number of epochs to train for.

  • criterion (nn.Module, optional) – The loss function to use. Default is nn.CrossEntropyLoss().

  • device (torch.device, optional) – The device to use for training. Default is torch.device("cpu").

  • verbose (bool, optional) – If True, prints verbose training information. Default is False.
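ERM minimizes the average loss over all training examples and ignores group membership. The following is a minimal plain-Python sketch, not the spuco implementation, that uses lists instead of tensors. Under the default cross-entropy criterion, the empirical risk is the mean negative log-softmax probability of each example's true label:

```python
import math
from typing import List, Sequence


def cross_entropy(logits: Sequence[float], label: int) -> float:
    """Negative log-softmax probability of `label` (numerically stable)."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def empirical_risk(batch_logits: List[Sequence[float]], labels: List[int]) -> float:
    """Average cross-entropy over all examples: the quantity ERM minimizes."""
    losses = [cross_entropy(z, y) for z, y in zip(batch_logits, labels)]
    return sum(losses) / len(losses)


# Two examples: one uncertain (equal logits), one confidently correct.
print(empirical_risk([[0.0, 0.0], [5.0, -5.0]], [0, 0]))
```

Most of the methods below depart from this uniform average. They either change which examples are sampled or change how each example's loss is weighted.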

Group DRO

class GroupWeightedLoss(criterion: Callable[[tensor, tensor], tensor], num_groups: int, group_weight_lr: float = 0.01, device: device = device(type='cpu'))

Bases: Module

A module for computing group-weighted loss.

__init__(criterion: Callable[[tensor, tensor], tensor], num_groups: int, group_weight_lr: float = 0.01, device: device = device(type='cpu'))

Initializes GroupWeightedLoss.

Parameters:
  • criterion (Callable[[torch.tensor, torch.tensor], torch.tensor]) – The loss criterion function.

  • num_groups (int) – The number of groups to consider.

  • group_weight_lr (float) – The learning rate for updating group weights (default: 0.01).

  • device (torch.device) – The device on which to perform computations. Defaults to CPU.

forward(outputs, labels, groups)

Computes the group-weighted loss.

update_group_weights(group_loss)

Updates the group weights from the per-group losses group_loss.

class GroupDRO(model: Module, trainset: GroupLabeledDatasetWrapper, batch_size: int, optimizer: Optimizer, num_epochs: int, device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)

Bases: BaseRobustTrain

Group DRO (https://arxiv.org/abs/1911.08731)
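The linked paper minimizes the worst-case expected loss over groups. It keeps a weight q_g for each group, raises the weight of high-loss groups with an exponentiated-gradient step, and trains on the q-weighted loss. The plain-Python sketch below illustrates that update on per-group mean losses. It is not spuco's code. It also assumes that GroupWeightedLoss's group_weight_lr acts as the step size for q, which the documentation above does not state.

```python
import math
from typing import List


def group_losses(losses: List[float], groups: List[int], num_groups: int) -> List[float]:
    """Mean loss of each group in the batch (0.0 if a group is absent)."""
    totals = [0.0] * num_groups
    counts = [0] * num_groups
    for loss, g in zip(losses, groups):
        totals[g] += loss
        counts[g] += 1
    return [t / c if c else 0.0 for t, c in zip(totals, counts)]


def update_group_weights(q: List[float], g_loss: List[float], lr: float) -> List[float]:
    """Exponentiated-gradient step: upweight high-loss groups, then renormalize."""
    q = [w * math.exp(lr * loss_g) for w, loss_g in zip(q, g_loss)]
    total = sum(q)
    return [w / total for w in q]


def group_weighted_loss(q: List[float], g_loss: List[float]) -> float:
    """The objective the model minimizes: group losses weighted by q."""
    return sum(w * loss_g for w, loss_g in zip(q, g_loss))


num_groups = 2
q = [1.0 / num_groups] * num_groups       # start from uniform group weights
per_example_loss = [0.2, 0.3, 1.5, 0.1]   # toy losses; in training these come from the model
groups = [0, 0, 1, 0]                     # group 1 is rare and has high loss
for _ in range(3):
    g_loss = group_losses(per_example_loss, groups, num_groups)
    q = update_group_weights(q, g_loss, lr=0.01)  # lr stands in for group_weight_lr (assumption)
    loss = group_weighted_loss(q, g_loss)         # a real trainer would backpropagate this
print(q, loss)
```

As the rare, high-loss group accumulates weight, it dominates the objective. This weighting is what distinguishes Group DRO from ERM.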

__init__(model: Module, trainset: GroupLabeledDatasetWrapper, batch_size: int, optimizer: Optimizer, num_epochs: int, device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)

Initializes GroupDRO.

Parameters:
  • model (nn.Module) – The PyTorch model to be trained.

  • trainset (GroupLabeledDatasetWrapper) – The training dataset containing group-labeled samples.

  • batch_size (int) – The batch size for training.

  • optimizer (optim.Optimizer) – The optimizer used for training.

  • num_epochs (int) – The number of training epochs.

  • device (torch.device) – The device to be used for training (default: CPU).

  • verbose (bool) – Whether to print training progress (default: False).

train_epoch(epoch)

Trains the model for a single epoch with a group-balanced batch (in expectation).

Parameters:

epoch (int) – The current epoch number.

Group Balance Batch ERM

class GroupBalanceBatchERM(model: Module, trainset: Dataset, group_partition: Dict, batch_size: int, optimizer: Optimizer, num_epochs: int, device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)

Bases: BaseRobustTrain

GroupBalanceBatchERM class for training a model using group balanced sampling.

__init__(model: Module, trainset: Dataset, group_partition: Dict, batch_size: int, optimizer: Optimizer, num_epochs: int, device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)

Initializes GroupBalanceBatchERM.

Parameters:
  • model (nn.Module) – The PyTorch model to be trained.

  • trainset (Dataset) – The training dataset.

  • group_partition (Dict) – A dictionary mapping group labels to the indices of examples belonging to each group.

  • batch_size (int) – The batch size for training.

  • optimizer (optim.Optimizer) – The optimizer used for training.

  • num_epochs (int) – The number of training epochs.

  • device (torch.device) – The device to be used for training (default: CPU).

  • verbose (bool) – Whether to print training progress (default: False).

train_epoch(epoch: int)

Trains the model for a single epoch with a group-balanced batch (in expectation).

Parameters:

epoch (int) – The current epoch number.
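One way to make batches group-balanced in expectation is to sample with replacement. Each example in group g gets probability 1 / (num_groups * |g|), so every group receives the same total probability mass. The sketch below makes three assumptions: the groups in group_partition are disjoint, sampling uses Python's random.choices, and the (class, spurious attribute) group keys are hypothetical. It is an illustration, not necessarily how GroupBalanceBatchERM builds its sampler.

```python
import random
from collections import Counter
from collections.abc import Hashable
from typing import Dict, List


def group_balanced_weights(group_partition: Dict[Hashable, List[int]]) -> Dict[int, float]:
    """Per-example sampling probability giving every group equal total mass.

    Assumes the groups are disjoint (each index appears in exactly one group).
    """
    num_groups = len(group_partition)
    weights = {}
    for indices in group_partition.values():
        for idx in indices:
            weights[idx] = 1.0 / (num_groups * len(indices))
    return weights


def sample_batch(weights: Dict[int, float], batch_size: int, rng: random.Random) -> List[int]:
    """Draw a batch with replacement according to `weights`."""
    indices = list(weights)
    return rng.choices(indices, weights=[weights[i] for i in indices], k=batch_size)


# Toy partition keyed by hypothetical (class, spurious attribute) pairs.
partition = {(0, 0): [0, 1, 2, 3, 4, 5], (0, 1): [6], (1, 0): [7], (1, 1): [8, 9, 10]}
weights = group_balanced_weights(partition)
batch = sample_batch(weights, batch_size=4000, rng=random.Random(0))
index_to_group = {i: g for g, idxs in partition.items() for i in idxs}
print(Counter(index_to_group[i] for i in batch))  # each group should appear roughly equally often
```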

Class Balance Batch ERM

class ClassBalanceBatchERM(model: Module, trainset: BaseSpuCoCompatibleDataset, batch_size: int, optimizer: Optimizer, num_epochs: int, device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)

Bases: BaseRobustTrain

ClassBalanceBatchERM class for training a model using class-balanced sampling.

__init__(model: Module, trainset: BaseSpuCoCompatibleDataset, batch_size: int, optimizer: Optimizer, num_epochs: int, device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)

Initializes ClassBalanceBatchERM.

Parameters:
  • model (nn.Module) – The PyTorch model to be trained.

  • trainset (BaseSpuCoCompatibleDataset) – The training dataset.

  • batch_size (int) – The batch size for training.

  • optimizer (optim.Optimizer) – The optimizer used for training.

  • num_epochs (int) – The number of training epochs.

  • device (torch.device) – The device to be used for training (default: CPU).

  • verbose (bool) – Whether to print training progress (default: False).

train_epoch(epoch: int)

Trains the model for a single epoch.

Parameters:

epoch (int) – The current epoch number.
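Class-balanced sampling applies the same idea to class labels instead of groups. Each example of class y is drawn with probability 1 / (num_classes * count(y)). The sketch below assumes the labels are available as a plain Python list. It does not show how BaseSpuCoCompatibleDataset actually exposes them.

```python
import random
from collections import Counter
from typing import List


def class_balanced_weights(labels: List[int]) -> List[float]:
    """Weight each example by 1 / (num_classes * size of its class)."""
    counts = Counter(labels)
    num_classes = len(counts)
    return [1.0 / (num_classes * counts[y]) for y in labels]


labels = [0] * 90 + [1] * 10  # toy 9:1 class imbalance
weights = class_balanced_weights(labels)
batch = random.Random(0).choices(range(len(labels)), weights=weights, k=2000)
print(Counter(labels[i] for i in batch))  # both classes should appear about equally often
```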

DownSample ERM

class DownSampleERM(model: Module, trainset: Dataset, batch_size: int, optimizer: Optimizer, num_epochs: int, group_partition: Dict[Tuple[int, int], List[int]], criterion=CrossEntropyLoss(), device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)

Bases: BaseRobustTrain

DownSampleERM class for training a model by downsampling all groups to the size of the smallest group.

__init__(model: Module, trainset: Dataset, batch_size: int, optimizer: Optimizer, num_epochs: int, group_partition: Dict[Tuple[int, int], List[int]], criterion=CrossEntropyLoss(), device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)

Initializes DownSampleERM.

Parameters:
  • model (nn.Module) – The PyTorch model to be trained.

  • trainset (Dataset) – The training dataset.

  • batch_size (int) – The batch size for training.

  • optimizer (optim.Optimizer) – The optimizer used for training.

  • num_epochs (int) – The number of training epochs.

  • group_partition (Dict[Tuple[int, int], List[int]]) – A dictionary mapping group labels to the indices of examples belonging to each group.

  • criterion (nn.Module) – The loss criterion used for training (default: CrossEntropyLoss).

  • device (torch.device) – The device to be used for training (default: CPU).

  • verbose (bool) – Whether to print training progress (default: False).
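Downsampling to the smallest group amounts to drawing min_g |g| indices without replacement from every group in group_partition. This plain-Python version uses a hypothetical function name, downsample_indices, and returns the shuffled index list for one epoch. DownSampleERM's internals may differ, for example in whether it redraws the subset every epoch.

```python
import random
from typing import Dict, List, Tuple


def downsample_indices(
    group_partition: Dict[Tuple[int, int], List[int]], seed: int = 0
) -> List[int]:
    """Keep a random subset of each group, equal in size to the smallest group."""
    rng = random.Random(seed)
    min_size = min(len(idxs) for idxs in group_partition.values())
    indices: List[int] = []
    for idxs in group_partition.values():
        indices.extend(rng.sample(idxs, min_size))  # sample without replacement
    rng.shuffle(indices)
    return indices


partition = {
    (0, 0): list(range(0, 50)),
    (0, 1): [50, 51],
    (1, 0): [52, 53, 54],
    (1, 1): list(range(55, 90)),
}
epoch_indices = downsample_indices(partition)
print(len(epoch_indices))  # 4 groups x 2 examples each = 8
```

Downsampling discards most majority-group data. This discarding is its main cost compared with upsampling.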

UpSample ERM

class UpSampleERM(model: Module, trainset: Dataset, batch_size: int, optimizer: Optimizer, num_epochs: int, group_partition: Dict[Tuple[int, int], List[int]], criterion=CrossEntropyLoss(), device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)

Bases: BaseRobustTrain

UpSampleERM class for training a model by upsampling all groups to the size of the largest group.

__init__(model: Module, trainset: Dataset, batch_size: int, optimizer: Optimizer, num_epochs: int, group_partition: Dict[Tuple[int, int], List[int]], criterion=CrossEntropyLoss(), device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)

Initializes UpSampleERM.

Parameters:
  • model (nn.Module) – The PyTorch model to be trained.

  • trainset (Dataset) – The training dataset.

  • batch_size (int) – The batch size for training.

  • optimizer (optim.Optimizer) – The optimizer used for training.

  • num_epochs (int) – The number of training epochs.

  • group_partition (Dict[Tuple[int, int], List[int]]) – A dictionary mapping group labels to the indices of examples belonging to each group.

  • criterion (nn.Module) – The loss criterion used for training (default: CrossEntropyLoss).

  • device (torch.device) – The device to be used for training (default: CPU).

  • verbose (bool) – Whether to print training progress (default: False).
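Upsampling to the largest group works the other way round. Each group keeps all of its examples, and randomly repeated examples top it up until it reaches max_g |g|. The helper name upsample_indices is hypothetical, and keeping every original example is a design choice of this sketch. Sampling each group purely with replacement would also reach the target size.

```python
import random
from typing import Dict, List, Tuple


def upsample_indices(
    group_partition: Dict[Tuple[int, int], List[int]], seed: int = 0
) -> List[int]:
    """Grow every group to the size of the largest group by repeating examples."""
    rng = random.Random(seed)
    max_size = max(len(idxs) for idxs in group_partition.values())
    indices: List[int] = []
    for idxs in group_partition.values():
        indices.extend(idxs)                                    # keep every original example
        indices.extend(rng.choices(idxs, k=max_size - len(idxs)))  # top up with repeats
    rng.shuffle(indices)
    return indices


partition = {
    (0, 0): list(range(0, 50)),
    (0, 1): [50, 51],
    (1, 0): [52, 53, 54],
    (1, 1): list(range(55, 90)),
}
up = upsample_indices(partition)
print(len(up))  # 4 groups x 50 examples each = 200
```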

Custom Sample ERM

class CustomSampleERM(model: Module, trainset: Dataset, batch_size: int, optimizer: Optimizer, num_epochs: int, indices: List[int], criterion=CrossEntropyLoss(), device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)

Bases: BaseRobustTrain

CustomSampleERM class for training a model using a custom sampling of the dataset.

__init__(model: Module, trainset: Dataset, batch_size: int, optimizer: Optimizer, num_epochs: int, indices: List[int], criterion=CrossEntropyLoss(), device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)

Initializes CustomSampleERM.

Parameters:
  • model (nn.Module) – The PyTorch model to be trained.

  • trainset (Dataset) – The training dataset.

  • batch_size (int) – The batch size for training.

  • optimizer (optim.Optimizer) – The optimizer used for training.

  • num_epochs (int) – The number of training epochs.

  • indices (List[int]) – A list of indices specifying the samples to be shown to the model in 1 epoch.

  • criterion (nn.Module) – The loss criterion used for training (default: CrossEntropyLoss).

  • device (torch.device) – The device to be used for training (default: CPU).

  • verbose (bool) – Whether to print training progress (default: False).
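Because indices fixes exactly which samples appear in one epoch, repeating an index shows that example more than once per epoch, which upweights it. The minimal plain-Python sketch below turns such a list into shuffled mini-batches. epoch_batches is a hypothetical helper, not part of spuco.

```python
import random
from typing import Iterator, List


def epoch_batches(indices: List[int], batch_size: int, seed: int = 0) -> Iterator[List[int]]:
    """Yield shuffled mini-batches covering `indices` exactly once per epoch."""
    order = list(indices)  # copy so the caller's list is untouched
    random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield order[start:start + batch_size]


# Index 7 is repeated, so that example is seen twice per epoch.
indices = [0, 1, 2, 7, 7, 9]
for batch in epoch_batches(indices, batch_size=4):
    print(batch)
```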

Correct-N-Contrast Train

class CorrectNContrastTrain(trainset: GroupLabeledDatasetWrapper, model: SpuCoModel, batch_size: int, optimizer_encoder: Optimizer, optimizer_classifier: Optimizer, num_pos: int, num_neg: int, num_epochs: int, lambda_ce: float, temp: float, device: device = device(type='cpu'), accum: int = 32, val_evaluator: Evaluator | None = None, verbose: bool = False)

Bases: BaseRobustTrain

CorrectNContrastTrain class for training a model using CNC's objective, which combines cross-entropy loss with a modified supervised contrastive loss.

__init__(trainset: GroupLabeledDatasetWrapper, model: SpuCoModel, batch_size: int, optimizer_encoder: Optimizer, optimizer_classifier: Optimizer, num_pos: int, num_neg: int, num_epochs: int, lambda_ce: float, temp: float, device: device = device(type='cpu'), accum: int = 32, val_evaluator: Evaluator | None = None, verbose: bool = False)

Initializes CorrectNContrastTrain.

Parameters:
  • trainset (GroupLabeledDatasetWrapper) – The training dataset containing group-labeled samples.

  • model (SpuCoModel) – The SpuCoModel to be trained.

  • batch_size (int) – The batch size for training.

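  • optimizer_encoder (optim.Optimizer) – The optimizer for the encoder (feature extractor) parameters.

  • optimizer_classifier (optim.Optimizer) – The optimizer for the classifier parameters.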

  • num_pos (int) – The number of positive examples for contrastive loss.

  • num_neg (int) – The number of negative examples for contrastive loss.

  • num_epochs (int) – The number of training epochs.

  • lambda_ce (float) – The weight of the regular cross-entropy loss.

  • temp (float) – The temperature used in the contrastive loss.

  • device (torch.device) – The device to be used for training (default: CPU).

  • verbose (bool) – Whether to print training progress (default: False).
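In the Correct-N-Contrast paper, an anchor's positives share its class but have a different (inferred) spurious attribute. Its negatives have a different class but the same spurious attribute. The sketch below computes a supervised-contrastive-style loss for a single anchor with temperature temp, on plain Python vectors. It approximates the idea and is not spuco's implementation. It omits the cross-entropy term because the documentation above does not specify exactly how lambda_ce combines the two losses.

```python
import math
from typing import List, Sequence


def normalize(v: Sequence[float]) -> List[float]:
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def contrastive_loss(anchor: Sequence[float], positives: List[Sequence[float]],
                     negatives: List[Sequence[float]], temp: float) -> float:
    """Average over positives of -log(exp(sim_pos / temp) / sum over all candidates)."""
    a = normalize(anchor)
    pos_logits = [dot(a, normalize(p)) / temp for p in positives]
    neg_logits = [dot(a, normalize(n)) / temp for n in negatives]
    all_logits = pos_logits + neg_logits
    m = max(all_logits)  # subtract the max for numerical stability
    log_denom = m + math.log(sum(math.exp(z - m) for z in all_logits))
    return -sum(z - log_denom for z in pos_logits) / len(pos_logits)


# An anchor close to its positive and far from its negative gets a low loss.
aligned = contrastive_loss([1.0, 0.0], [[1.0, 0.0]], [[0.0, 1.0]], temp=0.1)
misaligned = contrastive_loss([1.0, 0.0], [[0.0, 1.0]], [[1.0, 0.0]], temp=0.1)
print(aligned, misaligned)
```

A smaller temp sharpens the softmax, so the loss penalizes misaligned positives more heavily.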

class CNCTrainer(trainset: Dataset, model: Module, batch_size: int, optimizer_1: Optimizer, optimizer_2: Optimizer, accum_1: int, accum_2: int, lr_scheduler: _LRScheduler | None = None, max_grad_norm: float | None = None, criterion: Module = CrossEntropyLoss(), forward_pass: Callable[[Any], Tuple[Tensor, Tensor, Tensor]] | None = None, sampler: Sampler | None = None, device: device = device(type='cpu'), verbose: bool = False)

Bases: object

__init__(trainset: Dataset, model: Module, batch_size: int, optimizer_1: Optimizer, optimizer_2: Optimizer, accum_1: int, accum_2: int, lr_scheduler: _LRScheduler | None = None, max_grad_norm: float | None = None, criterion: Module = CrossEntropyLoss(), forward_pass: Callable[[Any], Tuple[Tensor, Tensor, Tensor]] | None = None, sampler: Sampler | None = None, device: device = device(type='cpu'), verbose: bool = False) → None

Initializes an instance of the CNCTrainer class.

Parameters:
  • trainset (torch.utils.data.Dataset) – The training set.

  • model (torch.nn.Module) – The PyTorch model to train.

  • batch_size (int) – The batch size to use during training.

  • optimizer_1 (torch.optim.Optimizer) – The first optimizer to use for training.

  • optimizer_2 (torch.optim.Optimizer) – The second optimizer to use for training.

  • criterion (torch.nn.Module, optional) – The loss function to use during training. Default is nn.CrossEntropyLoss().

  • forward_pass (Callable[[Any], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]], optional) – The forward pass function to use during training. Default is None.

  • sampler (torch.utils.data.Sampler, optional) – The sampler to use for creating batches. Default is None.

  • device (torch.device, optional) – The device to use for computations. Default is torch.device("cpu").

  • verbose (bool, optional) – Whether to print training progress. Default is False.

train(num_epochs: int)

Trains for the given number of epochs.

Parameters:

num_epochs (int) – The number of epochs to train for.

train_epoch(epoch: int) → None

Trains the model for one epoch using the CNC method.

Parameters:

epoch (int) – The epoch number being trained (used only for logging).

static compute_accuracy(outputs: Tensor, labels: Tensor) → float

Computes the accuracy of the model's predicted outputs against the ground-truth labels.

Parameters:
  • outputs (torch.Tensor) – The predicted outputs of the model.

  • labels (torch.Tensor) – The ground truth labels.

Returns:

The accuracy of the model.

Return type:

float
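compute_accuracy takes the argmax over each row of outputs and compares it with labels. With tensors this is commonly written as (outputs.argmax(dim=1) == labels).float().mean().item(). The plain-list sketch below computes the same quantity. The documentation above does not say whether spuco returns a fraction or a percentage, so the sketch's fraction in [0, 1] is an assumption.

```python
from typing import List, Sequence


def compute_accuracy(outputs: List[Sequence[float]], labels: List[int]) -> float:
    """Fraction of rows whose argmax matches the label (returns a value in [0, 1])."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in outputs]
    correct = sum(int(p == y) for p, y in zip(preds, labels))
    return correct / len(labels)


# Predictions are [1, 0, 0]; two of the three match the labels.
print(compute_accuracy([[0.1, 0.9], [2.0, -1.0], [0.3, 0.2]], [1, 1, 0]))
```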

get_trainset_outputs()

Gets the model's outputs on the training set.

SPARE Train

class SpareTrain(model: Module, trainset: Dataset, group_partition: Dict, sampling_powers: List, batch_size: int, optimizer: Optimizer, num_epochs: int, device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)

Bases: GroupBalanceBatchERM

SpareTrain class for training a model using group-balanced sampling.

__init__(model: Module, trainset: Dataset, group_partition: Dict, sampling_powers: List, batch_size: int, optimizer: Optimizer, num_epochs: int, device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)

Initializes SpareTrain.

Parameters:
  • model (nn.Module) – The PyTorch model to be trained.

  • trainset (Dataset) – The training dataset.

  • group_partition (Dict) – A dictionary mapping group labels to the indices of examples belonging to each group.

  • sampling_powers (List) – The sampling powers used to construct the custom upsampled batches.

  • batch_size (int) – The batch size for training.

  • optimizer (optim.Optimizer) – The optimizer used for training.

  • num_epochs (int) – The number of training epochs.

  • device (torch.device) – The device to be used for training (default: CPU).

  • verbose (bool) – Whether to print training progress (default: False).

train_epoch(epoch: int)

Trains the model for a single epoch with a custom upsampled batch.

Parameters:

epoch (int) – The current epoch number.

Base Robust Train

class BaseRobustTrain(val_evaluator: Evaluator | None = None, verbose: bool = False)

Bases: ABC

Abstract base class for robust training methods. Provides support for early stopping based on worst-group validation accuracy.

__init__(val_evaluator: Evaluator | None = None, verbose: bool = False)

Initializes the model trainer.

Parameters:
  • val_evaluator (Evaluator, optional) – Evaluator object for validation evaluation. Default is None.

  • verbose (bool) – Whether to print training progress. Default is False.

train()

Trains for the specified number of epochs, performing early stopping if a val_evaluator is given.

train_epoch(epoch: int)

Trains the model for a single epoch.

Parameters:

epoch (int) – The current epoch number.

property best_model

Property for accessing the best model.

Returns:

The best model.

Return type:

Any

Raises:

NotImplementedError – If no val_evaluator was passed, since worst-group validation accuracy is needed to select the best model.

property best_wg_acc

Property for accessing the best worst group validation accuracy.

Returns:

The best worst group validation accuracy.

Return type:

Any

Raises:

NotImplementedError – If no val_evaluator is passed.

property best_epoch

Property for accessing the best epoch number.

Returns:

The best epoch number.

Return type:

Any

Raises:

NotImplementedError – If no val_evaluator is passed.
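The early-stopping behavior behind train(), best_model, best_wg_acc, and best_epoch can be sketched as checkpoint selection. Training runs for every epoch, worst-group validation accuracy (the minimum of the per-group accuracies) is evaluated after each one, and a copy of the model from the best epoch is kept. The helpers below are hypothetical plain-Python stand-ins for train_epoch and the val_evaluator, not the spuco API. The 0-based epoch numbering is also an assumption.

```python
import copy
from collections import defaultdict
from collections.abc import Hashable
from typing import Callable, Dict, List


def worst_group_accuracy(preds: List[int], labels: List[int], groups: List[Hashable]) -> float:
    """Minimum over groups of the per-group accuracy."""
    correct: Dict[Hashable, int] = defaultdict(int)
    total: Dict[Hashable, int] = defaultdict(int)
    for p, y, g in zip(preds, labels, groups):
        total[g] += 1
        correct[g] += int(p == y)
    return min(correct[g] / total[g] for g in total)


def train_with_early_stopping(model, train_epoch: Callable[[int], None],
                              evaluate_wg: Callable[[], float], num_epochs: int):
    """Train all epochs; keep a snapshot of the model with the best worst-group accuracy."""
    best_wg_acc, best_epoch, best_model = -1.0, None, None
    for epoch in range(num_epochs):  # epochs numbered from 0 (assumption)
        train_epoch(epoch)
        wg_acc = evaluate_wg()
        if wg_acc > best_wg_acc:
            best_wg_acc, best_epoch = wg_acc, epoch
            best_model = copy.deepcopy(model)  # snapshot, not a live reference
    return best_model, best_wg_acc, best_epoch


model = {"step": 0}                            # stand-in for a real model's state
scripted_wg_acc = iter([0.40, 0.55, 0.50, 0.52])  # stand-in for val_evaluator results


def toy_train_epoch(epoch: int) -> None:
    model["step"] += 1  # stand-in for one epoch of parameter updates


best_model, best_acc, best_epoch = train_with_early_stopping(
    model, toy_train_epoch, lambda: next(scripted_wg_acc), num_epochs=4)
print(best_model, best_acc, best_epoch)  # {'step': 2} 0.55 1
```

Selecting the checkpoint by worst-group rather than average accuracy matches the goal of these methods. Average accuracy can stay high while a minority group is misclassified.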