spuco.evaluate

Evaluation:

  • worst group accuracy

  • average group accuracy

  • accuracy on task defined by spurious attribute

Evaluator

class Evaluator(testset: Dataset, group_partition: Dict[Tuple[int, int], List[int]], group_weights: Dict[Tuple[int, int], float], batch_size: int, model: Module, sklearn_linear_model: Tuple[float, float, float, StandardScaler | None] | None = None, device: device = device(type='cpu'), verbose: bool = False)

Bases: object

__init__(testset: Dataset, group_partition: Dict[Tuple[int, int], List[int]], group_weights: Dict[Tuple[int, int], float], batch_size: int, model: Module, sklearn_linear_model: Tuple[float, float, float, StandardScaler | None] | None = None, device: device = device(type='cpu'), verbose: bool = False)

Initializes an instance of the Evaluator class.

Parameters:
  • testset (Dataset) – Dataset object containing the test set.

  • group_partition (Dict[Tuple[int, int], List[int]]) – Dictionary object mapping group keys to a list of indices corresponding to the test samples in that group.

  • group_weights (Dict[Tuple[int, int], float]) – Dictionary object mapping group keys to their respective weights.

  • batch_size (int) – Batch size for DataLoader.

  • model (nn.Module) – PyTorch model to evaluate.

  • sklearn_linear_model (Optional[Tuple[float, float, float, Optional[StandardScaler]]], optional) – Tuple holding the coefficients and intercept of a linear model from sklearn, with an optional StandardScaler as its last element. Default is None.

  • device (torch.device, optional) – Device to use for computations. Default is torch.device("cpu").

  • verbose (bool, optional) – Whether to print evaluation results. Default is False.
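The shapes of group_partition and group_weights can be illustrated with a minimal sketch. It assumes that each group key is a (class label, spurious attribute) pair and that a group's weight is its share of the samples; neither assumption is stated by the signature above. The helper names build_group_partition and build_group_weights are hypothetical and are not part of spuco.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

GroupKey = Tuple[int, int]


def build_group_partition(
    labels: List[int], spurious: List[int]
) -> Dict[GroupKey, List[int]]:
    """Map each (label, spurious) pair to the indices of its samples.

    Hypothetical helper; the key layout is an assumption.
    """
    partition: Dict[GroupKey, List[int]] = defaultdict(list)
    for index, key in enumerate(zip(labels, spurious)):
        partition[key].append(index)
    return dict(partition)


def build_group_weights(
    partition: Dict[GroupKey, List[int]]
) -> Dict[GroupKey, float]:
    """Weight each group by its fraction of all samples (an assumption)."""
    total = sum(len(indices) for indices in partition.values())
    return {key: len(indices) / total for key, indices in partition.items()}


# Toy data made up for illustration: five test samples.
labels = [0, 0, 1, 1, 1]
spurious = [0, 1, 1, 1, 0]

group_partition = build_group_partition(labels, spurious)
group_weights = build_group_weights(group_partition)
```

Dictionaries of this form match the annotated types Dict[Tuple[int, int], List[int]] and Dict[Tuple[int, int], float].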

evaluate()

Evaluates the PyTorch model on the test dataset and computes the accuracy for each group.
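Per-group accuracy can be sketched in plain Python, assuming the model's predictions have already been collected into a list aligned with the test set. The function per_group_accuracy is a hypothetical illustration, not the method's actual implementation, and reporting accuracy as a fraction in [0, 1] is an assumption.

```python
from typing import Dict, List, Tuple

GroupKey = Tuple[int, int]


def per_group_accuracy(
    predictions: List[int],
    labels: List[int],
    group_partition: Dict[GroupKey, List[int]],
) -> Dict[GroupKey, float]:
    """Return the fraction of correct predictions within each group.

    Hypothetical helper; empty groups are skipped to avoid dividing by zero.
    """
    accuracies: Dict[GroupKey, float] = {}
    for key, indices in group_partition.items():
        if not indices:
            continue
        correct = sum(1 for i in indices if predictions[i] == labels[i])
        accuracies[key] = correct / len(indices)
    return accuracies


# Toy data made up for illustration.
labels = [0, 0, 1, 1, 1]
predictions = [0, 1, 1, 0, 1]
group_partition = {(0, 0): [0], (0, 1): [1], (1, 1): [2, 3], (1, 0): [4]}

accuracies = per_group_accuracy(predictions, labels, group_partition)
```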

evaluate_spurious_attribute_prediction()

Evaluates the model's accuracy when the task is to predict the spurious attribute.

property worst_group_accuracy

Returns the group with the lowest accuracy and its corresponding accuracy.

Returns:

A tuple containing the key of the worst-performing group and its corresponding accuracy.

Return type:

tuple
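Selecting the worst group from a dictionary of per-group accuracies can be sketched as follows. The function worst_group and the accuracy values are hypothetical and serve only to show the shape of the returned tuple.

```python
from typing import Dict, Tuple

GroupKey = Tuple[int, int]


def worst_group(accuracies: Dict[GroupKey, float]) -> Tuple[GroupKey, float]:
    """Return (group key, accuracy) for the group with the lowest accuracy."""
    key = min(accuracies, key=accuracies.get)
    return key, accuracies[key]


# Accuracy values made up for illustration.
accuracies = {(0, 0): 0.98, (0, 1): 0.71, (1, 0): 0.64, (1, 1): 0.95}
worst_key, worst_accuracy = worst_group(accuracies)
```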

property average_accuracy

Returns the weighted average accuracy across all groups.

Returns:

The weighted average accuracy across all groups.

Return type:

float
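A weighted average over groups can be sketched as below, assuming the group weights sum to 1. The function weighted_average and the numbers are hypothetical illustrations rather than the property's actual implementation.

```python
from typing import Dict, Tuple

GroupKey = Tuple[int, int]


def weighted_average(
    accuracies: Dict[GroupKey, float], weights: Dict[GroupKey, float]
) -> float:
    """Sum each group's accuracy scaled by its weight.

    Assumes every group in `accuracies` has a weight and weights sum to 1.
    """
    return sum(weights[key] * accuracies[key] for key in accuracies)


# Values made up for illustration.
accuracies = {(0, 0): 1.0, (0, 1): 0.5}
weights = {(0, 0): 0.75, (0, 1): 0.25}
average = weighted_average(accuracies, weights)
```

With these toy values, the large group dominates the average, which is why a high average accuracy can coexist with a low worst group accuracy.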