spuco.datasets
Datasets
SpuCo MNIST
- class SpuCoMNIST(root: str, spurious_feature_difficulty: SpuriousFeatureDifficulty, classes: List[List[int]], spurious_correlation_strength=0.0, label_noise: float = 0.0, core_feature_noise: float = 0.0, color_map: ColourMap = ColourMap.HSV, split: str = 'train', transform: Callable | None = None, verbose: bool = False, download: bool = True)
Bases:
BaseSpuCoDataset
A dataset consisting of images from the MNIST dataset with spurious features added.
- __init__(root: str, spurious_feature_difficulty: SpuriousFeatureDifficulty, classes: List[List[int]], spurious_correlation_strength=0.0, label_noise: float = 0.0, core_feature_noise: float = 0.0, color_map: ColourMap = ColourMap.HSV, split: str = 'train', transform: Callable | None = None, verbose: bool = False, download: bool = True)
Initializes the SpuCoMNIST dataset.
- Parameters:
root (str) – The root directory of the dataset.
spurious_feature_difficulty (SpuriousFeatureDifficulty) – The difficulty level of the spurious feature.
classes (List[List[int]]) – The digits assigned to each class: one inner list of digits (0 to 9) per class. The lists must be disjoint (see validate_classes).
spurious_correlation_strength (float, optional) – The strength of the spurious feature correlation. Default is 0.0.
label_noise (float, optional) – The amount of label noise to apply. Default is 0.0.
core_feature_noise (float, optional) – The amount of noise to add to the core features. Default is 0.0.
color_map (ColourMap, optional) – The color map to use. Default is ColourMap.HSV.
split (str, optional) – The dataset split to load. Default is “train”.
transform (Optional[Callable], optional) – The data transformation function. Default is None.
verbose (bool, optional) – Whether to print verbose information during dataset initialization. Default is False.
download (bool, optional) – Whether to download the dataset. Default is True.
- load_data() SourceData
Loads the MNIST dataset and generates the spurious correlation dataset.
- Returns:
The spurious correlation dataset.
- Return type:
SourceData
- init_colors(color_map: ColourMap) List[List[float]]
Initializes the color values for the spurious features.
- static validate_classes(classes: List[List[int]]) bool
Validates that the classes provided to the SpuCoMNIST dataset are disjoint and only contain integers between 0 and 9.
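The check described above can be sketched in plain Python. This is a minimal illustration of the stated rule (disjoint groups, integers from 0 to 9), not the library's implementation; the names validate_classes_sketch and digit_to_class are hypothetical.

```python
from typing import List


def validate_classes_sketch(classes: List[List[int]]) -> bool:
    """Return True if the groups are disjoint and contain only digits 0-9."""
    seen = set()
    for group in classes:
        for digit in group:
            # bool is a subclass of int, so reject it explicitly.
            if isinstance(digit, bool) or not isinstance(digit, int):
                return False
            if not 0 <= digit <= 9:
                return False
            if digit in seen:  # digit already used, in this or an earlier group
                return False
            seen.add(digit)
    return True


# Five classes, two digits each.
classes = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]

# Map each digit to the index of the class it belongs to.
digit_to_class = {d: i for i, group in enumerate(classes) for d in group}
```

With this grouping, digit 7 falls in class index 3, and a grouping such as [[0, 1], [1, 2]] is rejected because digit 1 appears twice.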
- static create_background(spurious_feature_difficulty: SpuriousFeatureDifficulty, hex_code: str) Tensor
Generates a tensor representing a background image with a specified spurious feature difficulty and hex code color.
- Parameters:
spurious_feature_difficulty (SpuriousFeatureDifficulty) – The difficulty level of the spurious feature to add to the background image.
hex_code (str) – The hex code of the color to use for the background image.
- Returns:
A tensor representing the generated background image.
- Return type:
torch.Tensor
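As a rough illustration of the hex_code argument, the sketch below converts a hex color string into RGB floats. It covers only the color parsing step, not background generation. The helper name hex_to_rgb_sketch is hypothetical, and accepting both "RRGGBB" and "#RRGGBB" is an assumption, since the source does not state the expected format.

```python
from typing import Tuple


def hex_to_rgb_sketch(hex_code: str) -> Tuple[float, float, float]:
    """Convert 'RRGGBB' or '#RRGGBB' into RGB floats in the range [0, 1]."""
    code = hex_code.lstrip("#")
    if len(code) != 6:
        raise ValueError(f"expected 6 hex digits, got {hex_code!r}")
    # Parse each pair of hex digits as one 8-bit channel and rescale it.
    r, g, b = (int(code[i:i + 2], 16) / 255.0 for i in (0, 2, 4))
    return (r, g, b)
```

A background tensor could then be filled with these three channel values; how create_background varies them for each difficulty level is not described here.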
- static compute_mask(unmask_points: Tensor) Tensor
Computes a binary mask based on the unmasked points.
- Parameters:
unmask_points (torch.Tensor) – The coordinates of the unmasked points.
- Returns:
The binary mask with 1s at the unmasked points and 0s elsewhere.
- Return type:
torch.Tensor
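The idea of a binary mask built from a set of points can be sketched with nested lists instead of tensors. This is an assumption-heavy illustration: the documented method takes and returns a torch.Tensor, and the source does not specify the coordinate layout or the mask size. The (row, col) ordering and the explicit height and width parameters below are hypothetical.

```python
from typing import List, Sequence, Tuple


def compute_mask_sketch(
    unmask_points: Sequence[Tuple[int, int]], height: int, width: int
) -> List[List[int]]:
    """Build a height x width grid with 1 at each given point and 0 elsewhere."""
    mask = [[0] * width for _ in range(height)]
    for row, col in unmask_points:  # assumed (row, col) ordering
        mask[row][col] = 1
    return mask


mask = compute_mask_sketch([(0, 1), (2, 2)], height=3, width=3)
```

Here mask has a 1 at positions (0, 1) and (2, 2) and 0 everywhere else.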
Base Spuco Dataset
- class SpuriousFeatureDifficulty(value)
Bases:
Enum
Enumeration class for spurious feature difficulty levels.
Each level corresponds to a combination of the magnitude and variance of the spurious feature.
- Magnitude definition of difficulty:
Large Magnitude <-> Easy
Medium Magnitude <-> Medium
Small Magnitude <-> Hard
- Variance definition of difficulty:
Low Variance <-> Easy
Medium Variance <-> Medium
High Variance <-> Hard
- MAGNITUDE_LARGE = 'magnitude_large'
- MAGNITUDE_MEDIUM = 'magnitude_medium'
- MAGNITUDE_SMALL = 'magnitude_small'
- VARIANCE_LOW = 'variance_low'
- VARIANCE_MEDIUM = 'variance_medium'
- VARIANCE_HIGH = 'variance_high'
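To make the mapping between members and difficulty concrete, here is a small self-contained sketch. DifficultySketch is a local stand-in that mirrors the documented member names and values; it is not imported from spuco, and the describe helper is hypothetical.

```python
from enum import Enum


class DifficultySketch(Enum):
    """Local stand-in mirroring the documented members."""
    MAGNITUDE_LARGE = "magnitude_large"
    MAGNITUDE_MEDIUM = "magnitude_medium"
    MAGNITUDE_SMALL = "magnitude_small"
    VARIANCE_LOW = "variance_low"
    VARIANCE_MEDIUM = "variance_medium"
    VARIANCE_HIGH = "variance_high"


def describe(level: DifficultySketch) -> str:
    """Translate a member into the easy/medium/hard reading listed above."""
    axis, size = level.value.split("_")
    if axis == "magnitude":
        difficulty = {"large": "easy", "medium": "medium", "small": "hard"}[size]
    else:
        difficulty = {"low": "easy", "medium": "medium", "high": "hard"}[size]
    return f"{size} {axis}: {difficulty}"


# Standard Enum lookup by value returns the matching member.
level = DifficultySketch("variance_high")
```

Because the members are ordinary Enum values, a level stored as a string in a config file can be turned back into a member with a lookup by value, as in the last line.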
- class SpuriousCorrelationStrength(value)
Bases:
Enum
An enumeration.
- UNIFORM = 'unform'
- LINEAR = 'linear'
- class SourceData(data=None)
Bases:
object
Class representing the source data.
This class contains the input data and corresponding labels.
- __init__(data=None)
Initialize the SourceData object.
- Parameters:
data (List[Tuple]) – The input data and labels.
- class BaseSpuCoDataset(root: str, num_classes: int, split: str = 'train', transform: Callable | None = None, verbose: bool = False)
Bases:
BaseSpuCoCompatibleDataset, ABC
- __init__(root: str, num_classes: int, split: str = 'train', transform: Callable | None = None, verbose: bool = False)
Initializes the dataset.
- Parameters:
root (str) – Root directory of the dataset.
num_classes (int) – Number of classes in the dataset.
split (str, optional) – Split of the dataset (e.g., “train”, “test”, “val”). Defaults to “train”.
transform (Callable, optional) – Optional transform to be applied to the data. Defaults to None.
verbose (bool, optional) – Whether to print verbose information during dataset initialization. Defaults to False.
- initialize()
Initializes the dataset.
- property group_partition: Dict[Tuple[int, int], List[int]]
Dictionary partitioning indices into groups
- property clean_group_partition: Dict[Tuple[int, int], List[int]]
Dictionary partitioning indices into groups based on clean labels
- property group_weights: Dict[Tuple[int, int], float]
Dictionary containing the fractional weights of each group
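The shape of these group properties can be illustrated with a short standalone sketch. It assumes each group key is a (class label, spurious label) pair, which the type signature suggests but the text does not state, and that a group's fractional weight is its share of all examples. The function names are hypothetical and this is not the library's code.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

Group = Tuple[int, int]  # assumed to be (class label, spurious label)


def partition_by_group(
    labels: Sequence[int], spurious: Sequence[int]
) -> Dict[Group, List[int]]:
    """Collect example indices under their (label, spurious) key."""
    partition = defaultdict(list)
    for index, group in enumerate(zip(labels, spurious)):
        partition[group].append(index)
    return dict(partition)


def fractional_weights(partition: Dict[Group, List[int]]) -> Dict[Group, float]:
    """Compute each group's share of the total number of examples."""
    total = sum(len(indices) for indices in partition.values())
    return {group: len(indices) / total for group, indices in partition.items()}


labels = [0, 0, 1, 1]
spurious = [0, 0, 1, 0]
partition = partition_by_group(labels, spurious)
weights = fractional_weights(partition)
```

In this toy example, group (0, 0) holds indices 0 and 1 and so has weight 0.5, while groups (1, 1) and (1, 0) each hold one index with weight 0.25.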
Wrappers
Group Labeled Dataset Wrapper
- class GroupLabeledDatasetWrapper(dataset: Dataset, group_partition: Dict[Tuple[int, int], int], subset_indices: List[int] | None = None)
Bases:
Dataset
Spurious Target Dataset
WILDSDatasetWrapper
- class WILDSDatasetWrapper(dataset: WILDSDataset, metadata_spurious_label: str, verbose=False)
Bases:
BaseSpuCoCompatibleDataset
Wrapper class that wraps WILDSDataset into a Dataset to be compatible with SpuCo.
- __init__(dataset: WILDSDataset, metadata_spurious_label: str, verbose=False)
Wraps WILDS Dataset into a Dataset object.
- property group_partition: Dict[Tuple[int, int], List[int]]
Dictionary partitioning indices into groups