
Evaluation:

  • worst-group accuracy

  • average group accuracy

  • accuracy on the task defined by the spurious attribute


class Evaluator(testset: Dataset, group_partition: Dict[Tuple[int, int], List[int]], group_weights: Dict[Tuple[int, int], float], batch_size: int, model: Module, sklearn_linear_model: Tuple[float, float, float, StandardScaler | None] | None = None, device: device = device(type='cpu'), verbose: bool = False)

Bases: object

__init__(testset: Dataset, group_partition: Dict[Tuple[int, int], List[int]], group_weights: Dict[Tuple[int, int], float], batch_size: int, model: Module, sklearn_linear_model: Tuple[float, float, float, StandardScaler | None] | None = None, device: device = device(type='cpu'), verbose: bool = False)

Initializes an instance of the Evaluator class.

Parameters:

  • testset (Dataset) – Dataset object containing the test set.

  • group_partition (Dict[Tuple[int, int], List[int]]) – Dictionary object mapping group keys to a list of indices corresponding to the test samples in that group.

  • group_weights (Dict[Tuple[int, int], float]) – Dictionary object mapping group keys to their respective weights.

  • batch_size (int) – Batch size for DataLoader.

  • model (nn.Module) – PyTorch model to evaluate.

  • sklearn_linear_model (Optional[Tuple[float, float, float, Optional[StandardScaler]]], optional) – Tuple representing the coefficients and intercept of the linear model from sklearn. Default is None.

  • device (torch.device, optional) – Device to use for computations. Default is torch.device("cpu").

  • verbose (bool, optional) – Whether to print evaluation results. Default is False.
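A minimal sketch of how the `group_partition` and `group_weights` arguments might be built from per-sample class labels and spurious attributes. It assumes, for illustration only, that each group key is a `(class_label, spurious_attribute)` pair and that a group's weight is its share of the test set; the helper `build_groups` is hypothetical and not part of the library.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

GroupKey = Tuple[int, int]


def build_groups(
    labels: List[int], spurious: List[int]
) -> Tuple[Dict[GroupKey, List[int]], Dict[GroupKey, float]]:
    """Hypothetical helper: index test samples by (class label, spurious attribute)."""
    partition: Dict[GroupKey, List[int]] = defaultdict(list)
    for idx, (y, s) in enumerate(zip(labels, spurious)):
        partition[(y, s)].append(idx)
    # Assumption: each group's weight is its fraction of the test set.
    n = len(labels)
    weights = {key: len(idxs) / n for key, idxs in partition.items()}
    return dict(partition), weights


labels = [0, 0, 1, 1, 1, 0]
spurious = [0, 1, 1, 1, 0, 0]
group_partition, group_weights = build_groups(labels, spurious)
print(group_partition)  # → {(0, 0): [0, 5], (0, 1): [1], (1, 1): [2, 3], (1, 0): [4]}
```

Frequency weights are only one choice; uniform weights (one over the number of groups) would turn the weighted average into a plain mean over groups.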


Evaluates the PyTorch model on the test dataset and computes the accuracy for each group.
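The per-group computation can be sketched in plain Python. This is a minimal sketch, not the library's implementation: it assumes the model's predicted class indices have already been collected for every test sample (for example, by running the model in eval mode over a DataLoader and taking the argmax of its outputs), and the helper `per_group_accuracy` is hypothetical. Accuracies are returned as fractions in [0, 1]; the library may report them on a different scale.

```python
from typing import Dict, List, Sequence, Tuple

GroupKey = Tuple[int, int]


def per_group_accuracy(
    preds: Sequence[int],
    labels: Sequence[int],
    group_partition: Dict[GroupKey, List[int]],
) -> Dict[GroupKey, float]:
    """Hypothetical helper: accuracy of `preds` within each group of test indices."""
    accuracies: Dict[GroupKey, float] = {}
    for key, indices in group_partition.items():
        if not indices:
            continue  # skip empty groups to avoid dividing by zero
        correct = sum(preds[i] == labels[i] for i in indices)
        accuracies[key] = correct / len(indices)
    return accuracies


preds = [0, 1, 1, 0, 1, 0]
labels = [0, 0, 1, 1, 1, 0]
partition = {(0, 0): [0, 5], (0, 1): [1], (1, 1): [2, 3], (1, 0): [4]}
accs = per_group_accuracy(preds, labels, partition)
print(accs)  # → {(0, 0): 1.0, (0, 1): 0.0, (1, 1): 0.5, (1, 0): 1.0}
```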


Evaluates the accuracy the model would achieve if the task were to predict the spurious attribute instead of the class label.
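One plausible reading of this metric, sketched below under stated assumptions: each group key is a `(class_label, spurious_attribute)` pair, so every sample in a group shares the spurious label `key[1]`, and the model's class predictions are scored against that label instead of the true class. A high score suggests the predictions track the spurious attribute. The helper `spurious_attribute_accuracy` is hypothetical; the library may compute this differently.

```python
from typing import Dict, List, Sequence, Tuple


def spurious_attribute_accuracy(
    preds: Sequence[int],
    group_partition: Dict[Tuple[int, int], List[int]],
) -> float:
    """Hypothetical helper: score predictions against the spurious attribute.

    Assumes each group key is (class_label, spurious_attribute), so every
    index in a group shares the spurious label key[1].
    """
    correct = 0
    total = 0
    for (_, spurious_label), indices in group_partition.items():
        correct += sum(preds[i] == spurious_label for i in indices)
        total += len(indices)
    return correct / total


preds = [0, 1, 1, 0, 1, 0]
partition = {(0, 0): [0, 5], (0, 1): [1], (1, 1): [2, 3], (1, 0): [4]}
acc = spurious_attribute_accuracy(preds, partition)
print(round(acc, 4))  # → 0.6667
```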

property worst_group_accuracy

Returns the group with the lowest accuracy and its corresponding accuracy.


Returns:

A tuple containing the key of the worst-performing group and its corresponding accuracy.

Return type:

Tuple[Tuple[int, int], float]
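Given a mapping from group keys to accuracies, the worst group is an argmin over the dictionary. This is a minimal sketch assuming per-group accuracies are already available as a dict; the helper `worst_group` is hypothetical, and ties are resolved in favor of the first group in iteration order, which may differ from the library's behavior.

```python
from typing import Dict, Tuple

GroupKey = Tuple[int, int]


def worst_group(accuracies: Dict[GroupKey, float]) -> Tuple[GroupKey, float]:
    """Hypothetical helper: return the (group key, accuracy) pair with the lowest accuracy."""
    key = min(accuracies, key=lambda k: accuracies[k])
    return key, accuracies[key]


accs = {(0, 0): 1.0, (0, 1): 0.0, (1, 1): 0.5, (1, 0): 1.0}
print(worst_group(accs))  # → ((0, 1), 0.0)
```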


property average_accuracy

Returns the weighted average accuracy across all groups.


Returns:

The weighted average accuracy across all groups.

Return type:

float
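The weighted average combines each group's accuracy with its weight from `group_weights`. A minimal sketch, assuming both dicts share the same keys; the helper `weighted_average_accuracy` is hypothetical, and dividing by the total weight is a choice made here so the result stays correct even if the weights do not sum to exactly one (the library may simply take the weighted sum).

```python
from typing import Dict, Tuple

GroupKey = Tuple[int, int]


def weighted_average_accuracy(
    accuracies: Dict[GroupKey, float], weights: Dict[GroupKey, float]
) -> float:
    """Hypothetical helper: weight-normalized average of per-group accuracies."""
    total_weight = sum(weights[k] for k in accuracies)
    return sum(accuracies[k] * weights[k] for k in accuracies) / total_weight


accs = {(0, 0): 1.0, (0, 1): 0.0, (1, 1): 0.5, (1, 0): 1.0}
weights = {(0, 0): 2 / 6, (0, 1): 1 / 6, (1, 1): 2 / 6, (1, 0): 1 / 6}
print(round(weighted_average_accuracy(accs, weights), 4))  # → 0.6667
```

When each weight is the group's share of the test set, as in this example, the weighted average equals ordinary accuracy over the whole test set: here 4 of 6 samples are correct. Uniform weights would instead give the unweighted mean over groups.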
