spuco.evaluate
Evaluation:
* worst group accuracy
* average group accuracy
* accuracy on task defined by spurious attribute
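Both group-level metrics reduce to simple operations on a dictionary of per-group accuracies. The sketch below rests on two assumptions: group keys are `(class label, spurious attribute)` pairs, and average group accuracy is the mean weighted by `group_weights`, with weights summing to 1. `worst_group` and `average_group_accuracy` are hypothetical helpers, not part of `spuco.evaluate`:

```python
from typing import Dict, Tuple

# Hypothetical group key: (class label, spurious attribute).
Group = Tuple[int, int]


def worst_group(accuracies: Dict[Group, float]) -> Tuple[Group, float]:
    """Return (group key, accuracy) for the lowest-accuracy group."""
    # On ties, min() keeps the first group in iteration order.
    return min(accuracies.items(), key=lambda item: item[1])


def average_group_accuracy(
    accuracies: Dict[Group, float], group_weights: Dict[Group, float]
) -> float:
    """Per-group accuracies averaged with group_weights (assumed to sum to 1)."""
    return sum(acc * group_weights[g] for g, acc in accuracies.items())


accs = {(0, 0): 0.95, (0, 1): 0.60, (1, 0): 0.70, (1, 1): 0.90}
weights = {(0, 0): 0.4, (0, 1): 0.1, (1, 0): 0.1, (1, 1): 0.4}
print(worst_group(accs))  # ((0, 1), 0.6)
print(average_group_accuracy(accs, weights))
```

With these weights the average is about 0.87. That is well above the worst group's 0.6, which is why worst-group accuracy is worth reporting separately.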
Evaluator
- class Evaluator(testset: Dataset, group_partition: Dict[Tuple[int, int], List[int]], group_weights: Dict[Tuple[int, int], float], batch_size: int, model: Module, sklearn_linear_model: Tuple[float, float, float, StandardScaler | None] | None = None, device: device = device(type='cpu'), verbose: bool = False)
Bases: object
- __init__(testset: Dataset, group_partition: Dict[Tuple[int, int], List[int]], group_weights: Dict[Tuple[int, int], float], batch_size: int, model: Module, sklearn_linear_model: Tuple[float, float, float, StandardScaler | None] | None = None, device: device = device(type='cpu'), verbose: bool = False)
Initializes an instance of the Evaluator class.
- Parameters:
testset (Dataset) – Dataset object containing the test set.
group_partition (Dict[Tuple[int, int], List[int]]) – Dictionary object mapping group keys to a list of indices corresponding to the test samples in that group.
group_weights (Dict[Tuple[int, int], float]) – Dictionary object mapping group keys to their respective weights.
batch_size (int) – Batch size for DataLoader.
model (nn.Module) – PyTorch model to evaluate.
sklearn_linear_model (Optional[Tuple[float, float, float, Optional[StandardScaler]]], optional) – Tuple describing a linear model fitted with sklearn (its coefficients and intercept), with an optional StandardScaler as the last element. Default is None.
device (torch.device, optional) – Device to use for computations. Default is torch.device("cpu").
verbose (bool, optional) – Whether to print evaluation results. Default is False.
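`group_partition` and `group_weights` are plain dictionaries keyed by group. The following minimal sketch builds them from per-sample labels. It assumes each key is a `(class label, spurious attribute)` pair and each weight is the group's share of the test set. Both conventions and the helper names `build_group_partition` / `build_group_weights` are hypothetical, not taken from this page:

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

# Hypothetical group key: (class label, spurious attribute).
Group = Tuple[int, int]


def build_group_partition(
    labels: Sequence[int], spurious: Sequence[int]
) -> Dict[Group, List[int]]:
    """Map each (label, spurious attribute) pair to the indices of its samples."""
    partition: Dict[Group, List[int]] = defaultdict(list)
    for idx, (y, s) in enumerate(zip(labels, spurious)):
        partition[(y, s)].append(idx)
    return dict(partition)


def build_group_weights(partition: Dict[Group, List[int]]) -> Dict[Group, float]:
    """Weight each group by its fraction of the total number of samples."""
    total = sum(len(idxs) for idxs in partition.values())
    return {g: len(idxs) / total for g, idxs in partition.items()}


labels = [0, 0, 1, 1, 1, 0]
spurious = [0, 1, 1, 1, 0, 0]
partition = build_group_partition(labels, spurious)
weights = build_group_weights(partition)
print(partition)  # {(0, 0): [0, 5], (0, 1): [1], (1, 1): [2, 3], (1, 0): [4]}
```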
- evaluate()
Evaluates the PyTorch model on the test dataset and computes the accuracy for each group.
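Conceptually, per-group evaluation loops over the groups in `group_partition` and scores the model on each group's samples. This simplified sketch replaces the model with a plain `predict` function and the test set with a list of `(input, label)` pairs. Both are assumptions: the documented class instead takes a PyTorch `model`, `batch_size`, and `device`, which the sketch omits. `per_group_accuracy` is a hypothetical name:

```python
from typing import Callable, Dict, List, Sequence, Tuple

Group = Tuple[int, int]


def per_group_accuracy(
    predict: Callable[[float], int],
    testset: Sequence[Tuple[float, int]],
    group_partition: Dict[Group, List[int]],
) -> Dict[Group, float]:
    """Fraction of correctly classified samples in each (non-empty) group."""
    accuracies: Dict[Group, float] = {}
    for group, indices in group_partition.items():
        # Compare the prediction for each sample in the group with its label.
        correct = sum(predict(testset[i][0]) == testset[i][1] for i in indices)
        accuracies[group] = correct / len(indices)
    return accuracies


def predict(x: float) -> int:
    """Toy stand-in for a model: thresholds a scalar input."""
    return int(x > 0.5)


testset = [(0.1, 0), (0.9, 1), (0.7, 0), (0.2, 1)]
group_partition = {(0, 0): [0], (1, 1): [1], (0, 1): [2], (1, 0): [3]}
print(per_group_accuracy(predict, testset, group_partition))
# {(0, 0): 1.0, (1, 1): 1.0, (0, 1): 0.0, (1, 0): 0.0}
```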
- evaluate_spurious_attribute_prediction()
Evaluates the model's accuracy as if the task were to predict the spurious attribute.
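Scoring the spurious-attribute task amounts to swapping the target: each prediction is compared with its group's spurious attribute instead of its class label. This minimal sketch rests on several assumptions. The spurious attribute is taken to be the second element of each group key. The model is replaced by a plain `predict` function. A single overall accuracy is returned as a fraction; the return value is not documented on this page. `spurious_attribute_accuracy` is a hypothetical name:

```python
from typing import Callable, Dict, List, Sequence, Tuple

Group = Tuple[int, int]  # assumed: (class label, spurious attribute)


def spurious_attribute_accuracy(
    predict: Callable[[float], int],
    testset: Sequence[Tuple[float, int]],
    group_partition: Dict[Group, List[int]],
) -> float:
    """Overall accuracy when the target is each sample's spurious attribute."""
    correct = 0
    total = 0
    for (_, spurious), indices in group_partition.items():
        for i in indices:
            # The class label stored in testset is ignored; the group's
            # spurious attribute serves as the target instead.
            correct += int(predict(testset[i][0]) == spurious)
            total += 1
    return correct / total


def predict(x: float) -> int:
    """Toy stand-in for a model: thresholds a scalar input."""
    return int(x > 0.5)


testset = [(0.1, 0), (0.9, 1), (0.7, 0), (0.2, 1)]
group_partition = {(0, 0): [0], (1, 1): [1], (0, 1): [2], (1, 0): [3]}
print(spurious_attribute_accuracy(predict, testset, group_partition))  # 1.0
```

In this toy data the predictor recovers every spurious attribute but only two of the four class labels. A high score on this task is a sign that a model is relying on the spurious feature.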
- property worst_group_accuracy
Returns the group with the lowest accuracy and its corresponding accuracy.
- Returns:
A tuple containing the key of the worst-performing group and its corresponding accuracy.
- Return type: