spuco.datasets

Datasets

SpuCo MNIST

class GrayscaleToRGBTransform

Bases: object

class ColourMap(value)

Bases: Enum

An enumeration.

HSV = 'hsv'

class SpuCoMNIST(root: str, spurious_feature_difficulty: SpuriousFeatureDifficulty, classes: List[List[int]], spurious_correlation_strength=0.0, label_noise: float = 0.0, core_feature_noise: float = 0.0, color_map: ColourMap = ColourMap.HSV, split: str = 'train', transform: Callable | None = None, verbose: bool = False, download: bool = True)

Bases: BaseSpuCoDataset

A dataset consisting of images from the MNIST dataset with added spurious features to create a spurious MNIST dataset.

__init__(root: str, spurious_feature_difficulty: SpuriousFeatureDifficulty, classes: List[List[int]], spurious_correlation_strength=0.0, label_noise: float = 0.0, core_feature_noise: float = 0.0, color_map: ColourMap = ColourMap.HSV, split: str = 'train', transform: Callable | None = None, verbose: bool = False, download: bool = True)

Initializes the SpuCoMNIST dataset.

Parameters:
  • root (str) – The root directory of the dataset.

  • spurious_feature_difficulty (SpuriousFeatureDifficulty) – The difficulty level of the spurious feature.

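  • classes (List[List[int]]) – The digits that make up each class. Each inner list holds the digits (integers between 0 and 9) assigned to one class, and the inner lists must be disjoint.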

  • spurious_correlation_strength (float, optional) – The strength of the spurious feature correlation. Default is 0.0.

  • label_noise (float, optional) – The amount of label noise to apply. Default is 0.0.

  • core_feature_noise (float, optional) – The amount of noise to add to the core features. Default is 0.0.

  • color_map (ColourMap, optional) – The color map to use. Default is ColourMap.HSV.

  • split (str, optional) – The dataset split to load. Default is “train”.

  • transform (Optional[Callable], optional) – The data transformation function. Default is None.

  • verbose (bool, optional) – Whether to print verbose information during dataset initialization. Default is False.

  • download (bool, optional) – Whether to download the dataset. Default is True.
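
As a usage illustration of the constructor above, the following is a minimal sketch rather than an official example. It assumes the classes are importable from spuco.datasets (the module named at the top of this page) and that initialize(), documented on BaseSpuCoDataset, should be called before the dataset is used. The helper name build_train_set, the class grouping, and the correlation strength of 0.995 are illustrative choices, not library defaults.

```python
from typing import List

# Five two-digit classes: digits 0-1 form class 0, digits 2-3 form class 1, and so on.
# The inner lists are disjoint and only contain digits between 0 and 9.
CLASSES: List[List[int]] = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]


def build_train_set(root: str):
    """Construct and initialize a SpuCoMNIST training split (hypothetical helper)."""
    # Imported lazily so the sketch can be loaded without spuco installed.
    # Assumption: both names are exported by spuco.datasets.
    from spuco.datasets import SpuCoMNIST, SpuriousFeatureDifficulty

    dataset = SpuCoMNIST(
        root=root,
        spurious_feature_difficulty=SpuriousFeatureDifficulty.MAGNITUDE_LARGE,
        classes=CLASSES,
        spurious_correlation_strength=0.995,  # illustrative value only
        split="train",
    )
    # Assumption: initialize() must run before indexing the dataset.
    dataset.initialize()
    return dataset
```

Calling build_train_set("/path/to/data") downloads MNIST into that directory when download is left at its default of True.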

load_data() SourceData

Loads the MNIST dataset and generates the spurious correlation dataset.

Returns:

The spurious correlation dataset.

Return type:

SourceData

init_colors(color_map: ColourMap) List[List[float]]

Initializes the color values for the spurious features.

Parameters:

color_map (ColourMap) – The color map to use for the spurious features. Should be a value from the ColourMap enum class.

Returns:

The color values for the spurious features.

Return type:

List[List[float]]
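
To make the List[List[float]] return shape concrete, here is a sketch of how evenly spaced colors could be drawn from an HSV color wheel using only the standard library. This is not the library's implementation: the function name hsv_colors, the num_colors parameter, and the choice of full saturation and value are all assumptions made for illustration.

```python
import colorsys
from typing import List


def hsv_colors(num_colors: int) -> List[List[float]]:
    """Return num_colors RGB triples with evenly spaced hues (hypothetical helper)."""
    colors: List[List[float]] = []
    for i in range(num_colors):
        hue = i / num_colors  # hue in [0, 1); 0 corresponds to red
        r, g, b = colorsys.hsv_to_rgb(hue, 1.0, 1.0)  # full saturation and value
        colors.append([r, g, b])
    return colors
```

Each inner list holds red, green, and blue components in the range 0 to 1, one list per spurious feature value.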

static validate_classes(classes: List[List[int]]) bool

Validates that the classes provided to the SpuCoMNIST dataset are disjoint and only contain integers between 0 and 9.

Parameters:

classes (List[List[int]]) – The classes to be included in the dataset, where each element is a list of integers representing the digits to be included in a single class.

Returns:

Whether the classes are valid.

Return type:

bool
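
The two conditions described above (disjoint classes, digits between 0 and 9) can be sketched as follows. This is an illustrative reimplementation under the stated description, not the library's code, and the name validate_classes_sketch is hypothetical.

```python
from typing import List


def validate_classes_sketch(classes: List[List[int]]) -> bool:
    """Return True if every digit is an int in 0..9 and appears at most once overall."""
    seen = set()
    for digits in classes:
        for digit in digits:
            # bool is a subclass of int in Python, so reject it explicitly.
            if not isinstance(digit, int) or isinstance(digit, bool):
                return False
            if digit < 0 or digit > 9:
                return False
            # A digit seen before means two classes overlap
            # (or one class lists the same digit twice).
            if digit in seen:
                return False
            seen.add(digit)
    return True
```

For example, [[0, 1], [2, 3]] passes, while [[0, 1], [1, 2]] fails because the digit 1 is shared, and [[0, 10]] fails because 10 is out of range.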

static create_background(spurious_feature_difficulty: SpuriousFeatureDifficulty, hex_code: str) Tensor

Generates a tensor representing a background image with a specified spurious feature difficulty and hex code color.

Parameters:
  • spurious_feature_difficulty (SpuriousFeatureDifficulty) – The difficulty level of the spurious feature to add to the background image.

  • hex_code (str) – The hex code of the color to use for the background image.

Returns:

A tensor representing the generated background image.

Return type:

torch.Tensor

static compute_mask(unmask_points: Tensor) Tensor

Computes a binary mask based on the unmasked points.

Parameters:

unmask_points (torch.Tensor) – The coordinates of the unmasked points.

Returns:

The binary mask with 1s at the unmasked points and 0s elsewhere.

Return type:

torch.Tensor

static rgb_to_mnist_background(rgb: List[float]) Tensor

Converts an RGB color to a MNIST background tensor.

Parameters:

rgb (List[float]) – The RGB color values.

Returns:

The MNIST background tensor with the specified RGB color.

Return type:

torch.Tensor
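
The color helpers above accept a hex code string or a list of RGB floats. The sketch below shows one way those two representations relate and what a solid-color background looks like as a 3 x height x width array. It rests on several assumptions: hex codes are taken to be six hexadecimal digits with an optional leading "#", RGB components are taken to lie between 0 and 1, nested Python lists stand in for torch.Tensor, and both function names are hypothetical. The 28 x 28 default matches the size of MNIST images.

```python
from typing import List


def hex_to_rgb(hex_code: str) -> List[float]:
    """Convert a hex color such as '#FF0000' to RGB floats in [0, 1]."""
    digits = hex_code.lstrip("#")
    if len(digits) != 6:
        raise ValueError(f"expected 6 hex digits, got {hex_code!r}")
    # Each pair of hex digits encodes one channel in 0..255.
    return [int(digits[i:i + 2], 16) / 255.0 for i in (0, 2, 4)]


def solid_background(rgb: List[float], height: int = 28, width: int = 28) -> List[List[List[float]]]:
    """Build a channels x height x width nested list filled with one color."""
    return [[[channel] * width for _ in range(height)] for channel in rgb]
```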

SpuCo Animals

class SpuCoAnimals(root: str, download: bool = True, label_noise: float = 0.0, split: str = 'train', transform: Callable | None = None, mask_type: str | None = None, verbose: bool = False)

Bases: BaseSpuCoDataset

SpuCoAnimals is a large-scale vision dataset curated from ImageNet with two realistic spurious correlations.

SpuCoAnimals has 4 classes:

  • landbirds

  • waterbirds

  • small dog breeds

  • big dog breeds

Waterbirds and landbirds are spuriously correlated with water and land backgrounds, respectively. Small dogs and big dogs are spuriously correlated with indoor and outdoor backgrounds, respectively.

__init__(root: str, download: bool = True, label_noise: float = 0.0, split: str = 'train', transform: Callable | None = None, mask_type: str | None = None, verbose: bool = False)

Initializes the dataset.

Parameters:
  • root (str) – Root directory of the dataset.

  • download (bool, optional) – Whether to download the dataset. Defaults to True.

  • label_noise (float, optional) – The amount of label noise to apply. Defaults to 0.0.

  • split (str, optional) – The split of the dataset. Defaults to “train”.

  • transform (Callable, optional) – Optional transform to be applied to the data. Defaults to None.

  • mask_type (str, optional) – Optionally mask out the spurious or core feature. Defaults to None.

  • verbose (bool, optional) – Whether to print verbose information during dataset initialization. Defaults to False.

load_data() SourceData

Loads SpuCoAnimals and sets the spurious labels and label noise.

Returns:

The spurious correlation dataset.

Return type:

SourceData, List[int], List[int]

load_image(filename: str)

SpuCo Dogs

class SpuCoDogs(root: str, download: bool = True, label_noise: float = 0.0, split: str = 'train', transform: Callable | None = None, verbose: bool = False)

Bases: BaseSpuCoDataset

Subset of SpuCoAnimals only including Dog classes.

__init__(root: str, download: bool = True, label_noise: float = 0.0, split: str = 'train', transform: Callable | None = None, verbose: bool = False)

Initializes the dataset.

Parameters:
  • root (str) – Root directory of the dataset.

  • download (bool, optional) – Whether to download the dataset. Defaults to True.

  • label_noise (float, optional) – The amount of label noise to apply. Defaults to 0.0.

  • split (str, optional) – The split of the dataset. Defaults to “train”.

  • transform (Callable, optional) – Optional transform to be applied to the data. Defaults to None.

  • verbose (bool, optional) – Whether to print verbose information during dataset initialization. Defaults to False.

load_data() SourceData

Loads SpuCoDogs and sets the spurious labels and label noise.

Returns:

The spurious correlation dataset.

Return type:

SourceData, List[int], List[int]

SpuCo Birds

class SpuCoBirds(root: str, download: bool = True, label_noise: float = 0.0, split: str = 'train', transform: Callable | None = None, verbose: bool = False)

Bases: BaseSpuCoDataset

Subset of SpuCoAnimals only including Bird classes.

__init__(root: str, download: bool = True, label_noise: float = 0.0, split: str = 'train', transform: Callable | None = None, verbose: bool = False)

Initializes the dataset.

Parameters:
  • root (str) – Root directory of the dataset.

  • download (bool, optional) – Whether to download the dataset. Defaults to True.

  • label_noise (float, optional) – The amount of label noise to apply. Defaults to 0.0.

  • split (str, optional) – The split of the dataset. Defaults to “train”.

  • transform (Callable, optional) – Optional transform to be applied to the data. Defaults to None.

  • verbose (bool, optional) – Whether to print verbose information during dataset initialization. Defaults to False.

load_data() SourceData

Loads SpuCoBirds and sets the spurious labels and label noise.

Returns:

The spurious correlation dataset.

Return type:

SourceData, List[int], List[int]

Base SpuCo Dataset

class SpuriousFeatureDifficulty(value)

Bases: Enum

Enumeration class for spurious feature difficulty levels.

Each level corresponds to a combination of the magnitude and variance of the spurious feature.

Magnitude definition of difficulty:

Large Magnitude <-> Easy

Medium Magnitude <-> Medium

Small Magnitude <-> Hard

Variance definition of difficulty:

Low Variance <-> Easy

Medium Variance <-> Medium

High Variance <-> Hard

MAGNITUDE_LARGE = 'magnitude_large'
MAGNITUDE_MEDIUM = 'magnitude_medium'
MAGNITUDE_SMALL = 'magnitude_small'
VARIANCE_LOW = 'variance_low'
VARIANCE_MEDIUM = 'variance_medium'
VARIANCE_HIGH = 'variance_high'

class SpuriousCorrelationStrength(value)

Bases: Enum

An enumeration.

UNIFORM = 'unform'
LINEAR = 'linear'

class SourceData(data=None)

Bases: object

Class representing the source data.

This class contains the input data and corresponding labels.

__init__(data=None)

Initialize the SourceData object.

Parameters:

data (List[Tuple]) – The input data and labels.

class BaseSpuCoDataset(root: str, num_classes: int, split: str = 'train', transform: Callable | None = None, verbose: bool = False)

Bases: BaseSpuCoCompatibleDataset, ABC

__init__(root: str, num_classes: int, split: str = 'train', transform: Callable | None = None, verbose: bool = False)

Initializes the dataset.

Parameters:
  • root (str) – Root directory of the dataset.

  • num_classes (int) – Number of classes in the dataset.

  • split (str, optional) – Split of the dataset (e.g., “train”, “test”, “val”). Defaults to “train”.

  • transform (Callable, optional) – Optional transform to be applied to the data. Defaults to None.

  • verbose (bool, optional) – Whether to print verbose information during dataset initialization. Defaults to False.

initialize()

Initializes the dataset.

property group_partition: Dict[Tuple[int, int], List[int]]

Dictionary partitioning indices into groups

Return type:

Dict[Tuple[int, int], List[int]]

property clean_group_partition: Dict[Tuple[int, int], List[int]]

Dictionary partitioning indices into groups based on clean labels

Return type:

Dict[Tuple[int, int], List[int]]

property group_weights: Dict[Tuple[int, int], float]

Dictionary containing the fractional weights of each group

Return type:

Dict[Tuple[int, int], float]
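
The relationship between labels, spurious, group_partition, and group_weights can be sketched as below. This is an illustration of the documented types, not the library's implementation. It assumes each group key is the pair (class label, spurious label) and that a group's weight is its share of all examples; both function names are hypothetical.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def build_group_partition(labels: List[int], spurious: List[int]) -> Dict[Tuple[int, int], List[int]]:
    """Map each (label, spurious) pair to the indices of the examples in that group."""
    if len(labels) != len(spurious):
        raise ValueError("labels and spurious must have the same length")
    partition: Dict[Tuple[int, int], List[int]] = defaultdict(list)
    for index, group in enumerate(zip(labels, spurious)):
        partition[group].append(index)
    return dict(partition)


def build_group_weights(partition: Dict[Tuple[int, int], List[int]]) -> Dict[Tuple[int, int], float]:
    """Return the fraction of all examples that falls into each group."""
    total = sum(len(indices) for indices in partition.values())
    return {group: len(indices) / total for group, indices in partition.items()}
```

With labels [0, 0, 1, 1] and spurious labels [0, 1, 1, 1], group (1, 1) holds indices 2 and 3 and therefore has weight 0.5, while groups (0, 0) and (0, 1) each have weight 0.25.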

property spurious: List[int]

List containing spurious labels for each example

Return type:

List[int]

property labels: List[int]

List containing class labels for each example

Return type:

List[int]

property num_classes: int

Number of classes

Return type:

int

Wrappers

Group Labeled Dataset Wrapper

class GroupLabeledDatasetWrapper(dataset: Dataset, group_partition: Dict[Tuple[int, int], int], subset_indices: List[int] | None = None)

Bases: Dataset

__init__(dataset: Dataset, group_partition: Dict[Tuple[int, int], int], subset_indices: List[int] | None = None)

Initializes a GroupLabeledDatasetWrapper.

Parameters:
  • dataset (torch.utils.data.Dataset) – The underlying dataset.

  • group_partition (Dict[Tuple[int, int], int]) – The group partition dictionary mapping indices to group labels.

  • subset_indices (Optional[List[int]]) – Optional list of subset indices to consider from the dataset. Defaults to None.

Spurious Target Dataset

class SpuriousTargetDatasetWrapper(dataset: Dataset, spurious_labels: List[int])

Bases: Dataset

Wrapper class that takes a Dataset and the spurious labels of the data and returns a dataset where the labels are the spurious labels.

__init__(dataset: Dataset, spurious_labels: List[int])

Initialize an instance of SpuriousTargetDatasetWrapper.

Parameters:
  • dataset (Dataset) – The original dataset.

  • spurious_labels (List[int]) – The spurious labels corresponding to the data.
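
The label-swapping behavior described above can be sketched without PyTorch as follows. This is a simplified stand-in, not the library class: it assumes the wrapped dataset yields (input, label) pairs, it uses a plain Python sequence in place of torch.utils.data.Dataset, and the class name SpuriousTargetSketch is hypothetical.

```python
from typing import Any, List, Sequence, Tuple


class SpuriousTargetSketch:
    """Return each example's input paired with its spurious label instead of its class label."""

    def __init__(self, dataset: Sequence[Tuple[Any, int]], spurious_labels: List[int]):
        if len(dataset) != len(spurious_labels):
            raise ValueError("dataset and spurious_labels must have the same length")
        self.dataset = dataset
        self.spurious_labels = spurious_labels

    def __getitem__(self, index: int) -> Tuple[Any, int]:
        # Keep the input, discard the original class label.
        features, _ = self.dataset[index]
        return features, self.spurious_labels[index]

    def __len__(self) -> int:
        return len(self.dataset)
```

Wrapping [("a", 0), ("b", 1)] with spurious labels [1, 0] yields ("a", 1) at index 0 and ("b", 0) at index 1.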

WILDSDatasetWrapper

class WILDSDatasetWrapper(dataset: WILDSDataset, metadata_spurious_label: str, verbose=False, subset_indices: List[int] | None = None)

Bases: BaseSpuCoCompatibleDataset

Wrapper class that wraps WILDSDataset into a Dataset to be compatible with SpuCo.

__init__(dataset: WILDSDataset, metadata_spurious_label: str, verbose=False, subset_indices: List[int] | None = None)

Wraps WILDS Dataset into a Dataset object.

Parameters:
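  • dataset (WILDSDataset) – The source WILDS dataset.

  • metadata_spurious_label (str) – String name of the property in metadata_map corresponding to the spurious target.

  • verbose (bool) – Whether to show logs. Defaults to False.

  • subset_indices (Optional[List[int]]) – Optional list of subset indices to consider from the dataset. Defaults to None.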

property group_partition: Dict[Tuple[int, int], List[int]]

Dictionary partitioning indices into groups

property group_weights: Dict[Tuple[int, int], float]

Dictionary containing the fractional weights of each group

property spurious: List[int]

List containing spurious labels for each example

property labels: List[int]

List containing class labels for each example

property num_classes: int

Number of classes