User Guide
API Reference
GrayscaleToRGBTransform
ColourMap
SpuCoMNIST
SpuCoAnimals
SpuCoDogs
SpuCoBirds
SpuriousFeatureDifficulty
SpuriousCorrelationStrength
SourceData
BaseSpuCoDataset
GroupLabeledDatasetWrapper
SpuriousTargetDatasetWrapper
WILDSDatasetWrapper
Evaluator
Evaluator.__init__()
Evaluator.evaluate()
Evaluator.evaluate_spurious_attribute_prediction()
Evaluator.worst_group_accuracy
Evaluator.average_accuracy
ClusterAlg
ClusterAlg.KMEANS
ClusterAlg.KMEDOIDS
Cluster
Cluster.__init__()
Cluster.infer_groups()
Cluster.silhouette()
Cluster.kmeans()
Cluster.kmedoids()
JTTInference
JTTInference.__init__()
JTTInference.infer_groups()
EIIL
EIIL.__init__()
EIIL.infer_groups()
SSA
SSA.__init__()
SSA.infer_groups()
SSA.train_ssa()
SSA.label_split()
SSATrainer
SSATrainer.__init__()
SSATrainer.train()
SSATrainer.train_step()
SSATrainer.validate()
CorrectNContrastInference
CorrectNContrastInference.__init__()
CorrectNContrastInference.infer_groups()
SpareInference
SpareInference.__init__()
SpareInference.infer_groups()
BaseGroupInference
BaseGroupInference.__init__()
BaseGroupInference.infer_groups()
BaseGroupInference.process_cluster_partition()
BaseGroupInference.save_group_partition()
DFR
DFR.__init__()
DFR.train_single_model()
DFR.train_multiple_model()
DFR.hyperparam_selection()
DFR.train()
DFR.evaluate_worstgroup_acc()
DFR.encode_dataset()
DISPEL
DISPEL.__init__()
DISPEL.train_single_model()
DISPEL.train_multiple_model()
DISPEL.hyperparam_selection()
DISPEL.train()
Identity
Identity.__init__()
Identity.forward()
SupportedModels
SupportedModels.MLP
SupportedModels.LeNet
SupportedModels.BERT
SupportedModels.DistilBERT
SupportedModels.ResNet18
SupportedModels.ResNet50
model_factory()
SpuCoModel
SpuCoModel.__init__()
SpuCoModel.forward()
ERM
ERM.__init__()
GroupWeightedLoss
GroupWeightedLoss.__init__()
GroupWeightedLoss.forward()
GroupWeightedLoss.update_group_weights()
GroupDRO
GroupDRO.__init__()
GroupDRO.train_epoch()
GroupBalanceBatchERM
GroupBalanceBatchERM.__init__()
GroupBalanceBatchERM.train_epoch()
ClassBalanceBatchERM
ClassBalanceBatchERM.__init__()
ClassBalanceBatchERM.train_epoch()
DownSampleERM
DownSampleERM.__init__()
UpSampleERM
UpSampleERM.__init__()
CustomSampleERM
CustomSampleERM.__init__()
CorrectNContrastTrain
CorrectNContrastTrain.__init__()
CNCTrainer
CNCTrainer.__init__()
CNCTrainer.train()
CNCTrainer.train_epoch()
CNCTrainer.compute_accuracy()
CNCTrainer.get_trainset_outputs()
SpareTrain
SpareTrain.__init__()
SpareTrain.train_epoch()
BaseRobustTrain
BaseRobustTrain.__init__()
BaseRobustTrain.train()
BaseRobustTrain.train_epoch()
BaseRobustTrain.best_model
BaseRobustTrain.best_wg_acc
BaseRobustTrain.best_epoch
Trainer
Trainer.__init__()
Trainer.train()
Trainer.train_epoch()
Trainer.compute_accuracy()
Trainer.get_trainset_outputs()
CustomIndicesSampler
CustomIndicesSampler.__init__()
cluster_by_exemplars()
closest_exemplar()
convert_labels_to_partition()
convert_partition_to_labels()
label_examples()
pairwise_similarity()
get_group_ratios()