spuco.robust_train
Robust Training Methods.
ERM
- class ERM(model: Module, trainset: Dataset, batch_size: int, optimizer: Optimizer, num_epochs: int, criterion=CrossEntropyLoss(), device: device = device(type='cpu'), lr_scheduler=None, max_grad_norm=None, val_evaluator: Evaluator | None = None, verbose=False)
Bases: BaseRobustTrain
Empirical Risk Minimization (ERM) Trainer.
- __init__(model: Module, trainset: Dataset, batch_size: int, optimizer: Optimizer, num_epochs: int, criterion=CrossEntropyLoss(), device: device = device(type='cpu'), lr_scheduler=None, max_grad_norm=None, val_evaluator: Evaluator | None = None, verbose=False)
Initializes an ERM instance.
- Parameters:
model (nn.Module) – The neural network model to train.
trainset (Dataset) – The trainset to use for training.
batch_size (int) – The batch size to use during training.
optimizer (torch.optim.Optimizer) – The optimizer to use during training.
num_epochs (int) – The number of epochs to train for.
criterion (nn.Module, optional) – The loss function to use. Default is nn.CrossEntropyLoss().
device (torch.device, optional) – The device to use for training. Default is torch.device("cpu").
verbose (bool, optional) – If True, prints verbose training information. Default is False.
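For comparison with the group-aware trainers in this module, the ERM objective is the unweighted mean loss over training examples, so large groups dominate it and a rare group's high loss is diluted. The following is a minimal plain-Python sketch of that objective (no PyTorch, so it runs standalone). The helper names cross_entropy and empirical_risk are hypothetical and not part of spuco.

```python
import math
from typing import List, Tuple


def cross_entropy(logits: List[float], label: int) -> float:
    """Per-example cross-entropy: -log softmax(logits)[label]."""
    m = max(logits)  # subtract the max logit for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def empirical_risk(examples: List[Tuple[List[float], int]]) -> float:
    """ERM objective: the unweighted mean loss, so every example counts equally."""
    return sum(cross_entropy(logits, y) for logits, y in examples) / len(examples)


# Three "majority" examples the model gets right and one "minority" example it gets
# wrong: the minority example's large loss is averaged away by the majority.
examples = [([2.0, 0.0], 0)] * 3 + [([2.0, 0.0], 1)]
print(empirical_risk(examples))
```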
Group DRO
- class GroupWeightedLoss(criterion: Callable[[tensor, tensor], tensor], num_groups: int, group_weight_lr: float = 0.01, device: device = device(type='cpu'))
Bases: Module
A module for computing group-weighted loss.
- __init__(criterion: Callable[[tensor, tensor], tensor], num_groups: int, group_weight_lr: float = 0.01, device: device = device(type='cpu'))
Initializes GroupWeightedLoss.
- Parameters:
criterion (Callable[[torch.tensor, torch.tensor], torch.tensor]) – The loss criterion function.
num_groups (int) – The number of groups to consider.
group_weight_lr (float) – The learning rate for updating group weights (default: 0.01).
device (torch.device) – The device on which to perform computations. Defaults to CPU.
- forward(outputs, labels, groups)
Computes the group-weighted loss.
- update_group_weights(group_loss)
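GroupWeightedLoss keeps one weight per group and shifts weight toward groups with high loss. The sketch below assumes the exponentiated-gradient update from the Group DRO paper, with group_weight_lr acting as the step size and group_loss holding each group's mean loss. spuco's implementation may differ in details such as how groups absent from a batch are handled, and both function names are hypothetical.

```python
import math
from typing import List


def update_group_weights(weights: List[float], group_loss: List[float], lr: float) -> List[float]:
    """Exponentiated-gradient step: upweight high-loss groups, then renormalize to sum to 1."""
    scaled = [w * math.exp(lr * loss) for w, loss in zip(weights, group_loss)]
    total = sum(scaled)
    return [s / total for s in scaled]


def group_weighted_loss(weights: List[float], group_loss: List[float]) -> float:
    """Training loss: per-group mean losses combined with the current group weights."""
    return sum(w * loss for w, loss in zip(weights, group_loss))


num_groups = 4
weights = [1.0 / num_groups] * num_groups  # start from uniform weights
group_loss = [0.2, 0.2, 0.2, 1.5]          # the last group is the hardest
for _ in range(100):
    weights = update_group_weights(weights, group_loss, lr=0.01)
print([round(w, 3) for w in weights])
print(group_weighted_loss(weights, group_loss))
```

With a fixed loss vector the weights converge toward the hardest group, so the weighted loss ends up well above the plain average of the group losses.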
- class GroupDRO(model: Module, trainset: GroupLabeledDatasetWrapper, batch_size: int, optimizer: Optimizer, num_epochs: int, device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)
Bases: BaseRobustTrain
Group DRO (https://arxiv.org/abs/1911.08731)
- __init__(model: Module, trainset: GroupLabeledDatasetWrapper, batch_size: int, optimizer: Optimizer, num_epochs: int, device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)
Initializes GroupDRO.
- Parameters:
model (nn.Module) – The PyTorch model to be trained.
trainset (GroupLabeledDatasetWrapper) – The training dataset containing group-labeled samples.
batch_size (int) – The batch size for training.
optimizer (optim.Optimizer) – The optimizer used for training.
num_epochs (int) – The number of training epochs.
device (torch.device) – The device to be used for training (default: CPU).
verbose (bool) – Whether to print training progress (default: False).
Group Balance Batch ERM
- class GroupBalanceBatchERM(model: Module, trainset: Dataset, group_partition: Dict, batch_size: int, optimizer: Optimizer, num_epochs: int, device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)
Bases: BaseRobustTrain
GroupBalanceBatchERM class for training a model using group-balanced sampling.
- __init__(model: Module, trainset: Dataset, group_partition: Dict, batch_size: int, optimizer: Optimizer, num_epochs: int, device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)
Initializes GroupBalanceBatchERM.
- Parameters:
model (nn.Module) – The PyTorch model to be trained.
trainset (Dataset) – The training dataset.
group_partition (Dict) – A dictionary mapping group labels to the indices of examples belonging to each group.
batch_size (int) – The batch size for training.
optimizer (optim.Optimizer) – The optimizer used for training.
num_epochs (int) – The number of training epochs.
device (torch.device) – The device to be used for training (default: CPU).
verbose (bool) – Whether to print training progress (default: False).
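Group-balanced sampling can be realized by first picking a group uniformly at random and then picking an example from that group, so each group contributes roughly equally to every batch regardless of its size. The sketch below assumes group_partition maps group labels to index lists as documented above. The function name and the choice to sample with replacement are assumptions, not a description of spuco's exact sampler.

```python
import random
from typing import Dict, Hashable, List


def group_balanced_batch(
    group_partition: Dict[Hashable, List[int]], batch_size: int, rng: random.Random
) -> List[int]:
    """Draw a batch in which every group is equally likely, whatever its size."""
    groups = list(group_partition)
    batch = []
    for _ in range(batch_size):
        group = rng.choice(groups)                        # 1) pick a group uniformly
        batch.append(rng.choice(group_partition[group]))  # 2) pick an example within it
    return batch


# Hypothetical partition keyed by (class, spurious attribute): 90 vs. 10 examples.
group_partition = {(0, 0): list(range(90)), (0, 1): list(range(90, 100))}
batch = group_balanced_batch(group_partition, batch_size=1000, rng=random.Random(0))
minority_share = sum(i >= 90 for i in batch) / len(batch)
print(minority_share)  # close to 0.5 rather than the natural 0.1
```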
Class Balance Batch ERM
- class ClassBalanceBatchERM(model: Module, trainset: BaseSpuCoCompatibleDataset, batch_size: int, optimizer: Optimizer, num_epochs: int, device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)
Bases: BaseRobustTrain
ClassBalanceBatchERM class for training a model using class-balanced sampling.
- __init__(model: Module, trainset: BaseSpuCoCompatibleDataset, batch_size: int, optimizer: Optimizer, num_epochs: int, device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)
Initializes ClassBalanceBatchERM.
- Parameters:
model (nn.Module) – The PyTorch model to be trained.
trainset (BaseSpuCoCompatibleDataset) – The training dataset.
batch_size (int) – The batch size for training.
optimizer (optim.Optimizer) – The optimizer used for training.
num_epochs (int) – The number of training epochs.
device (torch.device) – The device to be used for training (default: CPU).
verbose (bool) – Whether to print training progress (default: False).
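Class-balanced sampling is commonly implemented by weighting each example by the inverse frequency of its class, so every class receives the same total sampling mass. In PyTorch such weights would typically be passed to torch.utils.data.WeightedRandomSampler; the standalone sketch below uses random.choices instead. Whether ClassBalanceBatchERM uses exactly this scheme is an assumption, and the helper name is hypothetical.

```python
import random
from collections import Counter
from typing import List


def class_balanced_weights(labels: List[int]) -> List[float]:
    """Weight each example by 1 / (number of examples in its class)."""
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]


labels = [0] * 9 + [1]                     # 9:1 class imbalance
weights = class_balanced_weights(labels)  # each class now has total weight 1.0
rng = random.Random(0)
sampled = rng.choices(range(len(labels)), weights=weights, k=1000)
class_1_share = sum(labels[i] == 1 for i in sampled) / len(sampled)
print(class_1_share)  # close to 0.5
```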
DownSample ERM
- class DownSampleERM(model: Module, trainset: Dataset, batch_size: int, optimizer: Optimizer, num_epochs: int, group_partition: Dict[Tuple[int, int], List[int]], criterion=CrossEntropyLoss(), device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)
Bases: BaseRobustTrain
DownSampleERM class for training a model by downsampling all groups to the size of the smallest group.
- __init__(model: Module, trainset: Dataset, batch_size: int, optimizer: Optimizer, num_epochs: int, group_partition: Dict[Tuple[int, int], List[int]], criterion=CrossEntropyLoss(), device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)
Initializes DownSampleERM.
- Parameters:
model (nn.Module) – The PyTorch model to be trained.
trainset (Dataset) – The training dataset.
batch_size (int) – The batch size for training.
optimizer (optim.Optimizer) – The optimizer used for training.
num_epochs (int) – The number of training epochs.
group_partition (Dict[Tuple[int, int], List[int]]) – A dictionary mapping group labels to the indices of examples belonging to each group.
criterion (nn.Module) – The loss criterion used for training (default: CrossEntropyLoss).
device (torch.device) – The device to be used for training (default: CPU).
verbose (bool) – Whether to print training progress (default: False).
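Downsampling keeps, for every group in group_partition, a random subset whose size equals that of the smallest group, and trains on the union of those subsets. A minimal sketch, assuming sampling without replacement; downsample_indices is a hypothetical helper, not a spuco function.

```python
import random
from typing import Dict, List, Tuple


def downsample_indices(
    group_partition: Dict[Tuple[int, int], List[int]], rng: random.Random
) -> List[int]:
    """Keep a random subset of each group, sized to match the smallest group."""
    min_size = min(len(idx) for idx in group_partition.values())
    indices: List[int] = []
    for idx in group_partition.values():
        indices.extend(rng.sample(idx, min_size))  # drawn without replacement
    return indices


group_partition = {
    (0, 0): list(range(0, 10)),   # 10 examples
    (0, 1): [10, 11, 12],         # 3 examples (smallest group)
    (1, 0): list(range(13, 20)),  # 7 examples
}
indices = downsample_indices(group_partition, random.Random(0))
print(len(indices))  # 3 groups x 3 examples = 9
```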
UpSample ERM
- class UpSampleERM(model: Module, trainset: Dataset, batch_size: int, optimizer: Optimizer, num_epochs: int, group_partition: Dict[Tuple[int, int], List[int]], criterion=CrossEntropyLoss(), device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)
Bases: BaseRobustTrain
UpSampleERM class for training a model by upsampling all groups to the size of the largest group.
- __init__(model: Module, trainset: Dataset, batch_size: int, optimizer: Optimizer, num_epochs: int, group_partition: Dict[Tuple[int, int], List[int]], criterion=CrossEntropyLoss(), device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)
Initializes UpSampleERM.
- Parameters:
model (nn.Module) – The PyTorch model to be trained.
trainset (Dataset) – The training dataset.
batch_size (int) – The batch size for training.
optimizer (optim.Optimizer) – The optimizer used for training.
num_epochs (int) – The number of training epochs.
group_partition (Dict[Tuple[int, int], List[int]]) – A dictionary mapping group labels to the indices of examples belonging to each group.
criterion (nn.Module) – The loss criterion used for training (default: CrossEntropyLoss).
device (torch.device) – The device to be used for training (default: CPU).
verbose (bool) – Whether to print training progress (default: False).
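Upsampling instead grows every group to the size of the largest group by repeating some of its examples. A minimal sketch, assuming every original example is kept and the extra copies are drawn at random with replacement; spuco may choose the repeats differently, and upsample_indices is a hypothetical helper.

```python
import random
from typing import Dict, List, Tuple


def upsample_indices(
    group_partition: Dict[Tuple[int, int], List[int]], rng: random.Random
) -> List[int]:
    """Keep every example and pad each group with random repeats up to the largest group's size."""
    max_size = max(len(idx) for idx in group_partition.values())
    indices: List[int] = []
    for idx in group_partition.values():
        extra = rng.choices(idx, k=max_size - len(idx))  # repeats, drawn with replacement
        indices.extend(idx)
        indices.extend(extra)
    return indices


group_partition = {
    (0, 0): list(range(0, 10)),   # 10 examples (largest group)
    (0, 1): [10, 11, 12],         # 3 examples
    (1, 0): list(range(13, 20)),  # 7 examples
}
indices = upsample_indices(group_partition, random.Random(0))
print(len(indices))  # 3 groups x 10 examples = 30
```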
Custom Sample ERM
- class CustomSampleERM(model: Module, trainset: Dataset, batch_size: int, optimizer: Optimizer, num_epochs: int, indices: List[int], criterion=CrossEntropyLoss(), device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)
Bases: BaseRobustTrain
CustomSampleERM class for training a model using custom sampling of the dataset.
- __init__(model: Module, trainset: Dataset, batch_size: int, optimizer: Optimizer, num_epochs: int, indices: List[int], criterion=CrossEntropyLoss(), device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)
Initializes CustomSampleERM.
- Parameters:
model (nn.Module) – The PyTorch model to be trained.
trainset (Dataset) – The training dataset.
batch_size (int) – The batch size for training.
optimizer (optim.Optimizer) – The optimizer used for training.
num_epochs (int) – The number of training epochs.
indices (List[int]) – A list of indices specifying the samples to be shown to the model in 1 epoch.
criterion (nn.Module) – The loss criterion used for training (default: CrossEntropyLoss).
device (torch.device) – The device to be used for training (default: CPU).
verbose (bool) – Whether to print training progress (default: False).
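Because indices may contain repeats, a custom index list can encode any resampling scheme, including the downsampled and upsampled lists used by the two classes above. The sketch below shows how one epoch over such a list might be batched, assuming the indices are reshuffled every epoch; in PyTorch, torch.utils.data.Subset(trainset, indices) wrapped in a DataLoader with shuffle=True behaves similarly. The helper epoch_batches is hypothetical.

```python
import random
from typing import Iterator, List


def epoch_batches(indices: List[int], batch_size: int, rng: random.Random) -> Iterator[List[int]]:
    """Shuffle the (possibly repeated) indices once per epoch and yield batches."""
    order = list(indices)  # copy so the caller's list is left untouched
    rng.shuffle(order)
    for start in range(0, len(order), batch_size):
        yield order[start:start + batch_size]


# Index 0 appears twice, so that example is shown twice per epoch.
indices = [0, 0, 1, 2, 3]
batches = list(epoch_batches(indices, batch_size=2, rng=random.Random(0)))
print([len(b) for b in batches])  # [2, 2, 1]
```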
Correct-N-Contrast Train
- class CorrectNContrastTrain(trainset: GroupLabeledDatasetWrapper, model: SpuCoModel, batch_size: int, optimizer_encoder: Optimizer, optimizer_classifier: Optimizer, num_pos: int, num_neg: int, num_epochs: int, lambda_ce: float, temp: float, device: device = device(type='cpu'), accum: int = 32, val_evaluator: Evaluator | None = None, verbose: bool = False)
Bases: BaseRobustTrain
CorrectNContrastTrain class for training a model using CNC’s combination of cross-entropy and a modified supervised contrastive loss.
- __init__(trainset: GroupLabeledDatasetWrapper, model: SpuCoModel, batch_size: int, optimizer_encoder: Optimizer, optimizer_classifier: Optimizer, num_pos: int, num_neg: int, num_epochs: int, lambda_ce: float, temp: float, device: device = device(type='cpu'), accum: int = 32, val_evaluator: Evaluator | None = None, verbose: bool = False)
Initializes CorrectNContrastTrain.
- Parameters:
trainset (GroupLabeledDatasetWrapper) – The training dataset containing group-labeled samples.
model (SpuCoModel) – The SpuCoModel to be trained.
batch_size (int) – The batch size for training.
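optimizer_encoder (optim.Optimizer) – The optimizer used for the model's encoder.
optimizer_classifier (optim.Optimizer) – The optimizer used for the model's classifier.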
num_pos (int) – The number of positive examples for contrastive loss.
num_neg (int) – The number of negative examples for contrastive loss.
num_epochs (int) – The number of training epochs.
lambda_ce (float) – The weight of the regular cross-entropy loss.
temp (float) – The temperature for the contrastive loss.
device (torch.device) – The device to be used for training (default: CPU).
verbose (bool) – Whether to print training progress (default: False).
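In the CNC paper, the contrastive term pulls an anchor's representation toward positives (same class as the anchor but a different, possibly inferred, spurious attribute) and pushes it away from negatives (different class but the same spurious attribute), with temp scaling the similarities. The sketch below computes a loss of that form for a single anchor in plain Python, assuming num_pos positives and num_neg negatives have already been selected and that cosine similarity is used. It is not spuco's implementation, and contrastive_loss is a hypothetical name.

```python
import math
from typing import List, Sequence


def _normalize(v: Sequence[float]) -> List[float]:
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def _cosine(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(_normalize(a), _normalize(b)))


def contrastive_loss(
    anchor: Sequence[float],
    positives: List[Sequence[float]],
    negatives: List[Sequence[float]],
    temp: float,
) -> float:
    """Supervised-contrastive-style loss for one anchor.

    Each positive is scored against a denominator containing all positives and
    all negatives; the loss is the mean negative log-probability of the positives.
    """
    pos = [_cosine(anchor, p) / temp for p in positives]
    neg = [_cosine(anchor, n) / temp for n in negatives]
    logits = pos + neg
    m = max(logits)  # log-sum-exp with max subtraction for numerical stability
    log_denom = m + math.log(sum(math.exp(z - m) for z in logits))
    return -sum(z - log_denom for z in pos) / len(pos)


anchor = [1.0, 0.0]
loss = contrastive_loss(anchor, positives=[[1.0, 0.0]], negatives=[[-1.0, 0.0]], temp=1.0)
print(loss)
```

How this term is combined with the cross-entropy loss through lambda_ce is not shown here; the CNC paper trains on a weighted sum of the two losses.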
- class CNCTrainer(trainset: Dataset, model: Module, batch_size: int, optimizer_1: Optimizer, optimizer_2: Optimizer, accum_1: int, accum_2: int, lr_scheduler: _LRScheduler | None = None, max_grad_norm: float | None = None, criterion: Module = CrossEntropyLoss(), forward_pass: Callable[[Any], Tuple[Tensor, Tensor, Tensor]] | None = None, sampler: Sampler | None = None, device: device = device(type='cpu'), verbose: bool = False)
Bases: object
- __init__(trainset: Dataset, model: Module, batch_size: int, optimizer_1: Optimizer, optimizer_2: Optimizer, accum_1: int, accum_2: int, lr_scheduler: _LRScheduler | None = None, max_grad_norm: float | None = None, criterion: Module = CrossEntropyLoss(), forward_pass: Callable[[Any], Tuple[Tensor, Tensor, Tensor]] | None = None, sampler: Sampler | None = None, device: device = device(type='cpu'), verbose: bool = False) → None
Initializes a CNCTrainer instance.
- Parameters:
trainset (torch.utils.data.Dataset) – The training set.
model (torch.nn.Module) – The PyTorch model to train.
batch_size (int) – The batch size to use during training.
optimizer_1 (torch.optim.Optimizer) – The first optimizer used during training.
optimizer_2 (torch.optim.Optimizer) – The second optimizer used during training.
criterion (torch.nn.Module, optional) – The loss function to use during training. Default is nn.CrossEntropyLoss().
forward_pass (Callable[[Any], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]], optional) – The forward pass function to use during training. Default is None.
sampler (torch.utils.data.Sampler, optional) – The sampler to use for creating batches. Default is None.
device (torch.device, optional) – The device to use for computations. Default is torch.device("cpu").
verbose (bool, optional) – Whether to print training progress. Default is False.
- train(num_epochs: int)
Trains for the given number of epochs.
- Parameters:
num_epochs (int) – The number of epochs to train for.
- train_epoch(epoch: int) → None
Trains the model for one epoch using the CNC method.
- Parameters:
epoch (int) – The epoch number being trained (used only for logging).
- static compute_accuracy(outputs: Tensor, labels: Tensor) → float
Computes the accuracy of the PyTorch model.
- Parameters:
outputs (torch.Tensor) – The predicted outputs of the model.
labels (torch.Tensor) – The ground truth labels.
- Returns:
The accuracy of the model.
- Return type:
float
- get_trainset_outputs()
Gets the model's outputs on the training set.
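compute_accuracy presumably returns the share of examples whose highest-scoring class matches the label; with tensors this is typically written as (outputs.argmax(dim=1) == labels).float().mean().item(). The plain-Python sketch below performs the same computation on nested lists. Whether spuco reports a fraction or a percentage is not stated in this reference, so the fractional return value is an assumption.

```python
from typing import Sequence


def compute_accuracy(outputs: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of rows whose arg-max score equals the label."""
    preds = [max(range(len(row)), key=lambda k: row[k]) for row in outputs]
    correct = sum(int(p == y) for p, y in zip(preds, labels))
    return correct / len(labels)


outputs = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]
labels = [1, 0, 0]
print(compute_accuracy(outputs, labels))  # 2 of 3 correct
```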
SPARE Train
- class SpareTrain(model: Module, trainset: Dataset, group_partition: Dict, sampling_powers: List, batch_size: int, optimizer: Optimizer, num_epochs: int, device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)
Bases: GroupBalanceBatchERM
SpareTrain class for training a model using group-balanced sampling.
- __init__(model: Module, trainset: Dataset, group_partition: Dict, sampling_powers: List, batch_size: int, optimizer: Optimizer, num_epochs: int, device: device = device(type='cpu'), val_evaluator: Evaluator | None = None, verbose=False)
Initializes SpareTrain.
- Parameters:
model (nn.Module) – The PyTorch model to be trained.
trainset (Dataset) – The training dataset.
group_partition (Dict) – A dictionary mapping group labels to the indices of examples belonging to each group.
batch_size (int) – The batch size for training.
optimizer (optim.Optimizer) – The optimizer used for training.
num_epochs (int) – The number of training epochs.
device (torch.device) – The device to be used for training (default: CPU).
verbose (bool) – Whether to print training progress (default: False).
Base Robust Train
- class BaseRobustTrain(val_evaluator: Evaluator | None = None, verbose: bool = False)
Bases: ABC
Abstract base class for robust training methods. Provides support for early stopping based on worst-group validation accuracy.
- __init__(val_evaluator: Evaluator | None = None, verbose: bool = False)
Initializes the model trainer.
- train()
Trains for the specified number of epochs, with early stopping if a val_evaluator is given.
- train_epoch(epoch: int)
Trains the model for a single epoch.
- Parameters:
epoch (int) – The current epoch number.
- property best_model
Property for accessing the best model.
- Returns:
The best model.
- Return type:
Any
- Raises:
NotImplementedError – If no val_evaluator is set to compute worst-group validation accuracy.
- property best_wg_acc
Property for accessing the best worst-group validation accuracy.
- Returns:
The best worst-group validation accuracy.
- Return type:
Any
- Raises:
NotImplementedError – If no val_evaluator is passed.
- property best_epoch
Property for accessing the best epoch number.
- Returns:
The best epoch number.
- Return type:
Any
- Raises:
NotImplementedError – If no val_evaluator is passed.
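The early stopping described above can be sketched as a loop that measures worst-group validation accuracy after every epoch and snapshots the model whenever that accuracy improves, which is the information best_model, best_wg_acc, and best_epoch expose. The sketch assumes a hypothetical evaluate callable returning a dict of per-group accuracies; it does not use spuco's Evaluator API, and train_with_early_stopping is a hypothetical name.

```python
import copy
from typing import Any, Callable, Dict, Hashable, Optional, Tuple


def train_with_early_stopping(
    model: Any,
    num_epochs: int,
    train_epoch: Callable[[int], None],
    evaluate: Callable[[Any], Dict[Hashable, float]],
) -> Tuple[Any, float, Optional[int]]:
    """Return (best_model, best_wg_acc, best_epoch) under worst-group model selection."""
    best_model, best_wg_acc, best_epoch = None, float("-inf"), None
    for epoch in range(num_epochs):
        train_epoch(epoch)
        wg_acc = min(evaluate(model).values())  # worst-group validation accuracy
        if wg_acc > best_wg_acc:
            best_model = copy.deepcopy(model)   # snapshot, unaffected by later epochs
            best_wg_acc, best_epoch = wg_acc, epoch
    return best_model, best_wg_acc, best_epoch


# Toy stand-ins: the "model" is a dict, and each epoch yields preset per-group accuracies.
model = {"epochs_seen": 0}
per_epoch_acc = [
    {(0, 0): 0.90, (0, 1): 0.30},
    {(0, 0): 0.80, (0, 1): 0.60},
    {(0, 0): 0.95, (0, 1): 0.50},
]


def toy_train_epoch(epoch: int) -> None:
    model["epochs_seen"] += 1


def toy_evaluate(m: dict) -> dict:
    return per_epoch_acc[m["epochs_seen"] - 1]


best_model, best_wg_acc, best_epoch = train_with_early_stopping(model, 3, toy_train_epoch, toy_evaluate)
print(best_epoch, best_wg_acc)  # 1 0.6
```

Note that epoch 2 has the highest average accuracy but is not selected, because its worst group is weaker than at epoch 1.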